Refactor rkvm-net

This commit is contained in:
Jan Trefil 2023-04-16 19:29:27 +02:00
parent 867cfc4b94
commit 37d741eb97
7 changed files with 126 additions and 75 deletions

13
Cargo.lock generated
View file

@ -69,6 +69,17 @@ dependencies = [
"windows-sys 0.48.0",
]
[[package]]
name = "async-trait"
version = "0.1.68"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b9ccdd8f2a161be9bd5c023df56f1b2a0bd1d83872ae53b71a84a12c9bf6e842"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.15",
]
[[package]]
name = "atty"
version = "0.2.14"
@ -723,9 +734,11 @@ dependencies = [
name = "rkvm-net"
version = "0.2.0"
dependencies = [
"async-trait",
"bincode",
"rkvm-input",
"serde",
"thiserror",
"tokio",
]

View file

@ -5,11 +5,11 @@ use clap::Parser;
use config::Config;
use log::LevelFilter;
use rkvm_input::EventWriter;
use std::io::{Error, ErrorKind};
use rkvm_net::Message;
use std::io::Error;
use std::path::PathBuf;
use std::process::ExitCode;
use tokio::fs;
use tokio::io::AsyncWriteExt;
use tokio::net::TcpStream;
use tokio::signal;
use tokio_rustls::rustls::ServerName;
@ -86,20 +86,11 @@ async fn run(hostname: &ServerName, port: u16, connector: TlsConnector) -> Resul
let mut stream = connector.connect(hostname.clone(), stream).await?;
log::info!("Connected to server");
rkvm_net::write_version(&mut stream, rkvm_net::PROTOCOL_VERSION).await?;
stream.flush().await?;
let version = rkvm_net::read_version(&mut stream).await?;
if version != rkvm_net::PROTOCOL_VERSION {
return Err(Error::new(
ErrorKind::InvalidData,
"Invalid server protocol version",
));
}
rkvm_net::negotiate(&mut stream).await?;
let mut writer = EventWriter::new().await?;
loop {
let event = rkvm_net::read_message(&mut stream).await?;
let event = Message::decode(&mut stream).await?;
writer.write(event).await?;
}
}

View file

@ -2,12 +2,14 @@
name = "rkvm-net"
version = "0.2.0"
authors = ["Jan Trefil <8711792+htrefil@users.noreply.github.com>"]
edition = "2018"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
rkvm-input = { path = "../rkvm-input" }
serde = { version = "1.0.117", features = ["derive"] }
bincode = "1.3.1"
bincode = "1.3.3"
tokio = { version = "1.0.1", features = ["io-util"] }
async-trait = "0.1.68"
thiserror = "1.0.40"

View file

@ -1,59 +1,36 @@
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use std::convert::TryInto;
mod message;
mod version;
use std::io::{Error, ErrorKind};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use version::Version;
// Is it bold to assume there won't be more than 65536 protocol versions?
pub const PROTOCOL_VERSION: u16 = 2;
pub use message::Message;
pub async fn read_version<R>(mut reader: R) -> Result<u16, Error>
where
R: AsyncRead + Unpin,
{
let mut bytes = [0; 2];
reader.read_exact(&mut bytes).await?;
pub async fn negotiate<T: AsyncRead + AsyncWrite + Send + Unpin>(
stream: &mut T,
) -> Result<(), Error> {
#[derive(Error, Debug)]
#[error("Invalid version (expected {expected}, got {got} instead)")]
struct InvalidVersionError {
expected: Version,
got: Version,
}
Ok(u16::from_le_bytes(bytes))
}
Version::CURRENT.encode(stream).await?;
stream.flush().await?;
pub async fn write_version<W>(mut writer: W, version: u16) -> Result<(), Error>
where
W: AsyncWrite + Unpin,
{
writer.write_all(&version.to_le_bytes()).await
}
pub async fn read_message<R, T>(mut reader: R) -> Result<T, Error>
where
R: AsyncRead + Unpin,
T: DeserializeOwned,
{
let length = {
let mut bytes = [0; 1];
reader.read_exact(&mut bytes).await?;
bytes[0]
};
let mut data = vec![0; length as usize];
reader.read_exact(&mut data).await?;
bincode::deserialize(&data).map_err(|err| Error::new(ErrorKind::InvalidData, err))
}
pub async fn write_message<W, T: Serialize>(mut writer: W, message: &T) -> Result<(), Error>
where
W: AsyncWrite + Unpin,
{
let data =
bincode::serialize(&message).map_err(|err| Error::new(ErrorKind::InvalidInput, err))?;
let length: u8 = data
.len()
.try_into()
.map_err(|_| Error::new(ErrorKind::InvalidInput, "Serialized data is too large"))?;
writer.write_all(&length.to_le_bytes()).await?;
writer.write_all(&data).await?;
let version = Version::decode(stream).await?;
if version != Version::CURRENT {
return Err(Error::new(
ErrorKind::InvalidData,
InvalidVersionError {
expected: Version::CURRENT,
got: version,
},
));
}
Ok(())
}

44
rkvm-net/src/message.rs Normal file
View file

@ -0,0 +1,44 @@
use bincode::{DefaultOptions, Options};
use serde::de::DeserializeOwned;
use serde::Serialize;
use std::io::{Error, ErrorKind};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
#[async_trait::async_trait]
pub trait Message: Sized {
async fn decode<R: AsyncRead + Send + Unpin>(stream: &mut R) -> Result<Self, Error>;
async fn encode<W: AsyncWrite + Send + Unpin>(&self, stream: &mut W) -> Result<(), Error>;
}
#[async_trait::async_trait]
impl<T: DeserializeOwned + Serialize + Sync> Message for T {
async fn decode<R: AsyncRead + Send + Unpin>(stream: &mut R) -> Result<Self, Error> {
let length = stream.read_u16().await?;
let mut data = vec![0; length.into()];
stream.read_exact(&mut data).await?;
options()
.deserialize(&data)
.map_err(|err| Error::new(ErrorKind::InvalidData, err))
}
async fn encode<W: AsyncWrite + Send + Unpin>(&self, stream: &mut W) -> Result<(), Error> {
let data = options()
.serialize(self)
.map_err(|err| Error::new(ErrorKind::InvalidInput, err))?;
let length = data
.len()
.try_into()
.map_err(|_| Error::new(ErrorKind::InvalidInput, "Data too large"))?;
stream.write_u16(length).await?;
Ok(())
}
}
fn options() -> impl Options {
DefaultOptions::new().with_limit(u16::MAX.into())
}

29
rkvm-net/src/version.rs Normal file
View file

@ -0,0 +1,29 @@
use crate::message::Message;
use std::fmt::{self, Display, Formatter};
use std::io::Error;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct Version(u16);
impl Version {
pub const CURRENT: Self = Self(2);
}
impl Display for Version {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
write!(f, "v{}", self.0)
}
}
#[async_trait::async_trait]
impl Message for Version {
async fn decode<R: AsyncRead + Send + Unpin>(stream: &mut R) -> Result<Self, Error> {
stream.read_u16_le().await.map(Self)
}
async fn encode<W: AsyncWrite + Send + Unpin>(&self, stream: &mut W) -> Result<(), Error> {
stream.write_u16_le(self.0).await
}
}

View file

@ -5,8 +5,9 @@ use clap::Parser;
use config::Config;
use log::LevelFilter;
use rkvm_input::{Direction, Event, EventManager, Key, KeyKind};
use rkvm_net::Message;
use slab::Slab;
use std::io::{Error, ErrorKind};
use std::io::Error;
use std::net::SocketAddr;
use std::path::PathBuf;
use std::process::ExitCode;
@ -115,13 +116,7 @@ async fn run(listen: SocketAddr, acceptor: TlsAcceptor, switch_key: Key) -> Resu
let result = async {
let mut stream = BufStream::with_capacity(1024, 1024, stream);
rkvm_net::write_version(&mut stream, rkvm_net::PROTOCOL_VERSION).await?;
stream.flush().await?;
let version = rkvm_net::read_version(&mut stream).await?;
if version != rkvm_net::PROTOCOL_VERSION {
return Err(Error::new(ErrorKind::InvalidData, "Invalid client protocol version"));
}
rkvm_net::negotiate(&mut stream).await?;
loop {
let event = match receiver.recv().await {
@ -129,7 +124,7 @@ async fn run(listen: SocketAddr, acceptor: TlsAcceptor, switch_key: Key) -> Resu
None => break,
};
rkvm_net::write_message(&mut stream, &event).await?;
event.encode(&mut stream).await?;
stream.flush().await?;
}