diff --git a/Cargo.lock b/Cargo.lock index 1d0ba26..3bb01f0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -69,6 +69,17 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "async-trait" +version = "0.1.68" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9ccdd8f2a161be9bd5c023df56f1b2a0bd1d83872ae53b71a84a12c9bf6e842" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.15", +] + [[package]] name = "atty" version = "0.2.14" @@ -723,9 +734,11 @@ dependencies = [ name = "rkvm-net" version = "0.2.0" dependencies = [ + "async-trait", "bincode", "rkvm-input", "serde", + "thiserror", "tokio", ] diff --git a/rkvm-client/src/main.rs b/rkvm-client/src/main.rs index 3a69d66..97d1ee8 100644 --- a/rkvm-client/src/main.rs +++ b/rkvm-client/src/main.rs @@ -5,11 +5,11 @@ use clap::Parser; use config::Config; use log::LevelFilter; use rkvm_input::EventWriter; -use std::io::{Error, ErrorKind}; +use rkvm_net::Message; +use std::io::Error; use std::path::PathBuf; use std::process::ExitCode; use tokio::fs; -use tokio::io::AsyncWriteExt; use tokio::net::TcpStream; use tokio::signal; use tokio_rustls::rustls::ServerName; @@ -86,20 +86,11 @@ async fn run(hostname: &ServerName, port: u16, connector: TlsConnector) -> Resul let mut stream = connector.connect(hostname.clone(), stream).await?; log::info!("Connected to server"); - rkvm_net::write_version(&mut stream, rkvm_net::PROTOCOL_VERSION).await?; - stream.flush().await?; - - let version = rkvm_net::read_version(&mut stream).await?; - if version != rkvm_net::PROTOCOL_VERSION { - return Err(Error::new( - ErrorKind::InvalidData, - "Invalid server protocol version", - )); - } + rkvm_net::negotiate(&mut stream).await?; let mut writer = EventWriter::new().await?; loop { - let event = rkvm_net::read_message(&mut stream).await?; + let event = Message::decode(&mut stream).await?; writer.write(event).await?; } } diff --git a/rkvm-net/Cargo.toml b/rkvm-net/Cargo.toml index 24035c4..2c81dde 100644 --- a/rkvm-net/Cargo.toml +++ b/rkvm-net/Cargo.toml @@ -2,12 +2,14 @@ name = "rkvm-net" version = "0.2.0" authors = ["Jan Trefil <8711792+htrefil@users.noreply.github.com>"] -edition = "2018" +edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] rkvm-input = { path = "../rkvm-input" } serde = { version = "1.0.117", features = ["derive"] } -bincode = "1.3.1" +bincode = "1.3.3" tokio = { version = "1.0.1", features = ["io-util"] } +async-trait = "0.1.68" +thiserror = "1.0.40" diff --git a/rkvm-net/src/lib.rs b/rkvm-net/src/lib.rs index 6cb6ece..3afaa95 100644 --- a/rkvm-net/src/lib.rs +++ b/rkvm-net/src/lib.rs @@ -1,59 +1,36 @@ -use serde::de::DeserializeOwned; -use serde::{Deserialize, Serialize}; -use std::convert::TryInto; +mod message; +mod version; + use std::io::{Error, ErrorKind}; -use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; +use thiserror::Error; +use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; +use version::Version; -// Is it bold to assume there won't be more than 65536 protocol versions? -pub const PROTOCOL_VERSION: u16 = 2; +pub use message::Message; -pub async fn read_version(mut reader: R) -> Result -where - R: AsyncRead + Unpin, -{ - let mut bytes = [0; 2]; - reader.read_exact(&mut bytes).await?; +pub async fn negotiate( + stream: &mut T, +) -> Result<(), Error> { + #[derive(Error, Debug)] + #[error("Invalid version (expected {expected}, got {got} instead)")] + struct InvalidVersionError { + expected: Version, + got: Version, + } - Ok(u16::from_le_bytes(bytes)) -} + Version::CURRENT.encode(stream).await?; + stream.flush().await?; -pub async fn write_version(mut writer: W, version: u16) -> Result<(), Error> -where - W: AsyncWrite + Unpin, -{ - writer.write_all(&version.to_le_bytes()).await -} - -pub async fn read_message(mut reader: R) -> Result -where - R: AsyncRead + Unpin, - T: DeserializeOwned, -{ - let length = { - let mut bytes = [0; 1]; - reader.read_exact(&mut bytes).await?; - - bytes[0] - }; - - let mut data = vec![0; length as usize]; - reader.read_exact(&mut data).await?; - - bincode::deserialize(&data).map_err(|err| Error::new(ErrorKind::InvalidData, err)) -} - -pub async fn write_message(mut writer: W, message: &T) -> Result<(), Error> -where - W: AsyncWrite + Unpin, -{ - let data = - bincode::serialize(&message).map_err(|err| Error::new(ErrorKind::InvalidInput, err))?; - let length: u8 = data - .len() - .try_into() - .map_err(|_| Error::new(ErrorKind::InvalidInput, "Serialized data is too large"))?; - writer.write_all(&length.to_le_bytes()).await?; - writer.write_all(&data).await?; + let version = Version::decode(stream).await?; + if version != Version::CURRENT { + return Err(Error::new( + ErrorKind::InvalidData, + InvalidVersionError { + expected: Version::CURRENT, + got: version, + }, + )); + } Ok(()) } diff --git a/rkvm-net/src/message.rs b/rkvm-net/src/message.rs new file mode 100644 index 0000000..018e51b --- /dev/null +++ b/rkvm-net/src/message.rs @@ -0,0 +1,44 @@ +use bincode::{DefaultOptions, Options}; +use serde::de::DeserializeOwned; +use serde::Serialize; +use std::io::{Error, ErrorKind}; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; + +#[async_trait::async_trait] +pub trait Message: Sized { + async fn decode(stream: &mut R) -> Result; + + async fn encode(&self, stream: &mut W) -> Result<(), Error>; +} + +#[async_trait::async_trait] +impl Message for T { + async fn decode(stream: &mut R) -> Result { + let length = stream.read_u16().await?; + + let mut data = vec![0; length.into()]; + stream.read_exact(&mut data).await?; + + options() + .deserialize(&data) + .map_err(|err| Error::new(ErrorKind::InvalidData, err)) + } + + async fn encode(&self, stream: &mut W) -> Result<(), Error> { + let data = options() + .serialize(self) + .map_err(|err| Error::new(ErrorKind::InvalidInput, err))?; + + let length = data + .len() + .try_into() + .map_err(|_| Error::new(ErrorKind::InvalidInput, "Data too large"))?; + stream.write_u16(length).await?; + + Ok(()) + } +} + +fn options() -> impl Options { + DefaultOptions::new().with_limit(u16::MAX.into()) +} diff --git a/rkvm-net/src/version.rs b/rkvm-net/src/version.rs new file mode 100644 index 0000000..ac01ac5 --- /dev/null +++ b/rkvm-net/src/version.rs @@ -0,0 +1,29 @@ +use crate::message::Message; + +use std::fmt::{self, Display, Formatter}; +use std::io::Error; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; + +#[derive(Clone, Copy, Debug, PartialEq)] +pub struct Version(u16); + +impl Version { + pub const CURRENT: Self = Self(2); +} + +impl Display for Version { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + write!(f, "v{}", self.0) + } +} + +#[async_trait::async_trait] +impl Message for Version { + async fn decode(stream: &mut R) -> Result { + stream.read_u16_le().await.map(Self) + } + + async fn encode(&self, stream: &mut W) -> Result<(), Error> { + stream.write_u16_le(self.0).await + } +} diff --git a/rkvm-server/src/main.rs b/rkvm-server/src/main.rs index c66a933..a64a2f2 100644 --- a/rkvm-server/src/main.rs +++ b/rkvm-server/src/main.rs @@ -5,8 +5,9 @@ use clap::Parser; use config::Config; use log::LevelFilter; use rkvm_input::{Direction, Event, EventManager, Key, KeyKind}; +use rkvm_net::Message; use slab::Slab; -use std::io::{Error, ErrorKind}; +use std::io::Error; use std::net::SocketAddr; use std::path::PathBuf; use std::process::ExitCode; @@ -115,13 +116,7 @@ async fn run(listen: SocketAddr, acceptor: TlsAcceptor, switch_key: Key) -> Resu let result = async { let mut stream = BufStream::with_capacity(1024, 1024, stream); - rkvm_net::write_version(&mut stream, rkvm_net::PROTOCOL_VERSION).await?; - stream.flush().await?; - - let version = rkvm_net::read_version(&mut stream).await?; - if version != rkvm_net::PROTOCOL_VERSION { - return Err(Error::new(ErrorKind::InvalidData, "Invalid client protocol version")); - } + rkvm_net::negotiate(&mut stream).await?; loop { let event = match receiver.recv().await { @@ -129,7 +124,7 @@ async fn run(listen: SocketAddr, acceptor: TlsAcceptor, switch_key: Key) -> Resu None => break, }; - rkvm_net::write_message(&mut stream, &event).await?; + event.encode(&mut stream).await?; stream.flush().await?; }