mirror of
https://github.com/htrefil/rkvm.git
synced 2025-01-13 20:01:29 +01:00
Refactor rkvm-net
This commit is contained in:
parent
867cfc4b94
commit
37d741eb97
7 changed files with 126 additions and 75 deletions
13
Cargo.lock
generated
13
Cargo.lock
generated
|
@ -69,6 +69,17 @@ dependencies = [
|
|||
"windows-sys 0.48.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "async-trait"
|
||||
version = "0.1.68"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b9ccdd8f2a161be9bd5c023df56f1b2a0bd1d83872ae53b71a84a12c9bf6e842"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.15",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "atty"
|
||||
version = "0.2.14"
|
||||
|
@ -723,9 +734,11 @@ dependencies = [
|
|||
name = "rkvm-net"
|
||||
version = "0.2.0"
|
||||
dependencies = [
|
||||
"async-trait",
|
||||
"bincode",
|
||||
"rkvm-input",
|
||||
"serde",
|
||||
"thiserror",
|
||||
"tokio",
|
||||
]
|
||||
|
||||
|
|
|
@ -5,11 +5,11 @@ use clap::Parser;
|
|||
use config::Config;
|
||||
use log::LevelFilter;
|
||||
use rkvm_input::EventWriter;
|
||||
use std::io::{Error, ErrorKind};
|
||||
use rkvm_net::Message;
|
||||
use std::io::Error;
|
||||
use std::path::PathBuf;
|
||||
use std::process::ExitCode;
|
||||
use tokio::fs;
|
||||
use tokio::io::AsyncWriteExt;
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::signal;
|
||||
use tokio_rustls::rustls::ServerName;
|
||||
|
@ -86,20 +86,11 @@ async fn run(hostname: &ServerName, port: u16, connector: TlsConnector) -> Resul
|
|||
let mut stream = connector.connect(hostname.clone(), stream).await?;
|
||||
log::info!("Connected to server");
|
||||
|
||||
rkvm_net::write_version(&mut stream, rkvm_net::PROTOCOL_VERSION).await?;
|
||||
stream.flush().await?;
|
||||
|
||||
let version = rkvm_net::read_version(&mut stream).await?;
|
||||
if version != rkvm_net::PROTOCOL_VERSION {
|
||||
return Err(Error::new(
|
||||
ErrorKind::InvalidData,
|
||||
"Invalid server protocol version",
|
||||
));
|
||||
}
|
||||
rkvm_net::negotiate(&mut stream).await?;
|
||||
|
||||
let mut writer = EventWriter::new().await?;
|
||||
loop {
|
||||
let event = rkvm_net::read_message(&mut stream).await?;
|
||||
let event = Message::decode(&mut stream).await?;
|
||||
writer.write(event).await?;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -2,12 +2,14 @@
|
|||
name = "rkvm-net"
|
||||
version = "0.2.0"
|
||||
authors = ["Jan Trefil <8711792+htrefil@users.noreply.github.com>"]
|
||||
edition = "2018"
|
||||
edition = "2021"
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[dependencies]
|
||||
rkvm-input = { path = "../rkvm-input" }
|
||||
serde = { version = "1.0.117", features = ["derive"] }
|
||||
bincode = "1.3.1"
|
||||
bincode = "1.3.3"
|
||||
tokio = { version = "1.0.1", features = ["io-util"] }
|
||||
async-trait = "0.1.68"
|
||||
thiserror = "1.0.40"
|
||||
|
|
|
@ -1,59 +1,36 @@
|
|||
use serde::de::DeserializeOwned;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::convert::TryInto;
|
||||
mod message;
|
||||
mod version;
|
||||
|
||||
use std::io::{Error, ErrorKind};
|
||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
||||
use thiserror::Error;
|
||||
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
|
||||
use version::Version;
|
||||
|
||||
// Is it bold to assume there won't be more than 65536 protocol versions?
|
||||
pub const PROTOCOL_VERSION: u16 = 2;
|
||||
pub use message::Message;
|
||||
|
||||
pub async fn read_version<R>(mut reader: R) -> Result<u16, Error>
|
||||
where
|
||||
R: AsyncRead + Unpin,
|
||||
{
|
||||
let mut bytes = [0; 2];
|
||||
reader.read_exact(&mut bytes).await?;
|
||||
pub async fn negotiate<T: AsyncRead + AsyncWrite + Send + Unpin>(
|
||||
stream: &mut T,
|
||||
) -> Result<(), Error> {
|
||||
#[derive(Error, Debug)]
|
||||
#[error("Invalid version (expected {expected}, got {got} instead)")]
|
||||
struct InvalidVersionError {
|
||||
expected: Version,
|
||||
got: Version,
|
||||
}
|
||||
|
||||
Ok(u16::from_le_bytes(bytes))
|
||||
}
|
||||
Version::CURRENT.encode(stream).await?;
|
||||
stream.flush().await?;
|
||||
|
||||
pub async fn write_version<W>(mut writer: W, version: u16) -> Result<(), Error>
|
||||
where
|
||||
W: AsyncWrite + Unpin,
|
||||
{
|
||||
writer.write_all(&version.to_le_bytes()).await
|
||||
}
|
||||
|
||||
pub async fn read_message<R, T>(mut reader: R) -> Result<T, Error>
|
||||
where
|
||||
R: AsyncRead + Unpin,
|
||||
T: DeserializeOwned,
|
||||
{
|
||||
let length = {
|
||||
let mut bytes = [0; 1];
|
||||
reader.read_exact(&mut bytes).await?;
|
||||
|
||||
bytes[0]
|
||||
};
|
||||
|
||||
let mut data = vec![0; length as usize];
|
||||
reader.read_exact(&mut data).await?;
|
||||
|
||||
bincode::deserialize(&data).map_err(|err| Error::new(ErrorKind::InvalidData, err))
|
||||
}
|
||||
|
||||
pub async fn write_message<W, T: Serialize>(mut writer: W, message: &T) -> Result<(), Error>
|
||||
where
|
||||
W: AsyncWrite + Unpin,
|
||||
{
|
||||
let data =
|
||||
bincode::serialize(&message).map_err(|err| Error::new(ErrorKind::InvalidInput, err))?;
|
||||
let length: u8 = data
|
||||
.len()
|
||||
.try_into()
|
||||
.map_err(|_| Error::new(ErrorKind::InvalidInput, "Serialized data is too large"))?;
|
||||
writer.write_all(&length.to_le_bytes()).await?;
|
||||
writer.write_all(&data).await?;
|
||||
let version = Version::decode(stream).await?;
|
||||
if version != Version::CURRENT {
|
||||
return Err(Error::new(
|
||||
ErrorKind::InvalidData,
|
||||
InvalidVersionError {
|
||||
expected: Version::CURRENT,
|
||||
got: version,
|
||||
},
|
||||
));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
|
44
rkvm-net/src/message.rs
Normal file
44
rkvm-net/src/message.rs
Normal file
|
@ -0,0 +1,44 @@
|
|||
use bincode::{DefaultOptions, Options};
|
||||
use serde::de::DeserializeOwned;
|
||||
use serde::Serialize;
|
||||
use std::io::{Error, ErrorKind};
|
||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
||||
|
||||
#[async_trait::async_trait]
|
||||
pub trait Message: Sized {
|
||||
async fn decode<R: AsyncRead + Send + Unpin>(stream: &mut R) -> Result<Self, Error>;
|
||||
|
||||
async fn encode<W: AsyncWrite + Send + Unpin>(&self, stream: &mut W) -> Result<(), Error>;
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl<T: DeserializeOwned + Serialize + Sync> Message for T {
|
||||
async fn decode<R: AsyncRead + Send + Unpin>(stream: &mut R) -> Result<Self, Error> {
|
||||
let length = stream.read_u16().await?;
|
||||
|
||||
let mut data = vec![0; length.into()];
|
||||
stream.read_exact(&mut data).await?;
|
||||
|
||||
options()
|
||||
.deserialize(&data)
|
||||
.map_err(|err| Error::new(ErrorKind::InvalidData, err))
|
||||
}
|
||||
|
||||
async fn encode<W: AsyncWrite + Send + Unpin>(&self, stream: &mut W) -> Result<(), Error> {
|
||||
let data = options()
|
||||
.serialize(self)
|
||||
.map_err(|err| Error::new(ErrorKind::InvalidInput, err))?;
|
||||
|
||||
let length = data
|
||||
.len()
|
||||
.try_into()
|
||||
.map_err(|_| Error::new(ErrorKind::InvalidInput, "Data too large"))?;
|
||||
stream.write_u16(length).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
fn options() -> impl Options {
|
||||
DefaultOptions::new().with_limit(u16::MAX.into())
|
||||
}
|
29
rkvm-net/src/version.rs
Normal file
29
rkvm-net/src/version.rs
Normal file
|
@ -0,0 +1,29 @@
|
|||
use crate::message::Message;
|
||||
|
||||
use std::fmt::{self, Display, Formatter};
|
||||
use std::io::Error;
|
||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
||||
|
||||
#[derive(Clone, Copy, Debug, PartialEq)]
|
||||
pub struct Version(u16);
|
||||
|
||||
impl Version {
|
||||
pub const CURRENT: Self = Self(2);
|
||||
}
|
||||
|
||||
impl Display for Version {
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "v{}", self.0)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl Message for Version {
|
||||
async fn decode<R: AsyncRead + Send + Unpin>(stream: &mut R) -> Result<Self, Error> {
|
||||
stream.read_u16_le().await.map(Self)
|
||||
}
|
||||
|
||||
async fn encode<W: AsyncWrite + Send + Unpin>(&self, stream: &mut W) -> Result<(), Error> {
|
||||
stream.write_u16_le(self.0).await
|
||||
}
|
||||
}
|
|
@ -5,8 +5,9 @@ use clap::Parser;
|
|||
use config::Config;
|
||||
use log::LevelFilter;
|
||||
use rkvm_input::{Direction, Event, EventManager, Key, KeyKind};
|
||||
use rkvm_net::Message;
|
||||
use slab::Slab;
|
||||
use std::io::{Error, ErrorKind};
|
||||
use std::io::Error;
|
||||
use std::net::SocketAddr;
|
||||
use std::path::PathBuf;
|
||||
use std::process::ExitCode;
|
||||
|
@ -115,13 +116,7 @@ async fn run(listen: SocketAddr, acceptor: TlsAcceptor, switch_key: Key) -> Resu
|
|||
let result = async {
|
||||
let mut stream = BufStream::with_capacity(1024, 1024, stream);
|
||||
|
||||
rkvm_net::write_version(&mut stream, rkvm_net::PROTOCOL_VERSION).await?;
|
||||
stream.flush().await?;
|
||||
|
||||
let version = rkvm_net::read_version(&mut stream).await?;
|
||||
if version != rkvm_net::PROTOCOL_VERSION {
|
||||
return Err(Error::new(ErrorKind::InvalidData, "Invalid client protocol version"));
|
||||
}
|
||||
rkvm_net::negotiate(&mut stream).await?;
|
||||
|
||||
loop {
|
||||
let event = match receiver.recv().await {
|
||||
|
@ -129,7 +124,7 @@ async fn run(listen: SocketAddr, acceptor: TlsAcceptor, switch_key: Key) -> Resu
|
|||
None => break,
|
||||
};
|
||||
|
||||
rkvm_net::write_message(&mut stream, &event).await?;
|
||||
event.encode(&mut stream).await?;
|
||||
stream.flush().await?;
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in a new issue