mirror of
https://github.com/htrefil/rkvm.git
synced 2025-01-13 20:01:29 +01:00
Refactor rkvm-net
This commit is contained in:
parent
867cfc4b94
commit
37d741eb97
7 changed files with 126 additions and 75 deletions
13
Cargo.lock
generated
13
Cargo.lock
generated
|
@ -69,6 +69,17 @@ dependencies = [
|
||||||
"windows-sys 0.48.0",
|
"windows-sys 0.48.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "async-trait"
|
||||||
|
version = "0.1.68"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "b9ccdd8f2a161be9bd5c023df56f1b2a0bd1d83872ae53b71a84a12c9bf6e842"
|
||||||
|
dependencies = [
|
||||||
|
"proc-macro2",
|
||||||
|
"quote",
|
||||||
|
"syn 2.0.15",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "atty"
|
name = "atty"
|
||||||
version = "0.2.14"
|
version = "0.2.14"
|
||||||
|
@ -723,9 +734,11 @@ dependencies = [
|
||||||
name = "rkvm-net"
|
name = "rkvm-net"
|
||||||
version = "0.2.0"
|
version = "0.2.0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
|
"async-trait",
|
||||||
"bincode",
|
"bincode",
|
||||||
"rkvm-input",
|
"rkvm-input",
|
||||||
"serde",
|
"serde",
|
||||||
|
"thiserror",
|
||||||
"tokio",
|
"tokio",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
|
@ -5,11 +5,11 @@ use clap::Parser;
|
||||||
use config::Config;
|
use config::Config;
|
||||||
use log::LevelFilter;
|
use log::LevelFilter;
|
||||||
use rkvm_input::EventWriter;
|
use rkvm_input::EventWriter;
|
||||||
use std::io::{Error, ErrorKind};
|
use rkvm_net::Message;
|
||||||
|
use std::io::Error;
|
||||||
use std::path::PathBuf;
|
use std::path::PathBuf;
|
||||||
use std::process::ExitCode;
|
use std::process::ExitCode;
|
||||||
use tokio::fs;
|
use tokio::fs;
|
||||||
use tokio::io::AsyncWriteExt;
|
|
||||||
use tokio::net::TcpStream;
|
use tokio::net::TcpStream;
|
||||||
use tokio::signal;
|
use tokio::signal;
|
||||||
use tokio_rustls::rustls::ServerName;
|
use tokio_rustls::rustls::ServerName;
|
||||||
|
@ -86,20 +86,11 @@ async fn run(hostname: &ServerName, port: u16, connector: TlsConnector) -> Resul
|
||||||
let mut stream = connector.connect(hostname.clone(), stream).await?;
|
let mut stream = connector.connect(hostname.clone(), stream).await?;
|
||||||
log::info!("Connected to server");
|
log::info!("Connected to server");
|
||||||
|
|
||||||
rkvm_net::write_version(&mut stream, rkvm_net::PROTOCOL_VERSION).await?;
|
rkvm_net::negotiate(&mut stream).await?;
|
||||||
stream.flush().await?;
|
|
||||||
|
|
||||||
let version = rkvm_net::read_version(&mut stream).await?;
|
|
||||||
if version != rkvm_net::PROTOCOL_VERSION {
|
|
||||||
return Err(Error::new(
|
|
||||||
ErrorKind::InvalidData,
|
|
||||||
"Invalid server protocol version",
|
|
||||||
));
|
|
||||||
}
|
|
||||||
|
|
||||||
let mut writer = EventWriter::new().await?;
|
let mut writer = EventWriter::new().await?;
|
||||||
loop {
|
loop {
|
||||||
let event = rkvm_net::read_message(&mut stream).await?;
|
let event = Message::decode(&mut stream).await?;
|
||||||
writer.write(event).await?;
|
writer.write(event).await?;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,12 +2,14 @@
|
||||||
name = "rkvm-net"
|
name = "rkvm-net"
|
||||||
version = "0.2.0"
|
version = "0.2.0"
|
||||||
authors = ["Jan Trefil <8711792+htrefil@users.noreply.github.com>"]
|
authors = ["Jan Trefil <8711792+htrefil@users.noreply.github.com>"]
|
||||||
edition = "2018"
|
edition = "2021"
|
||||||
|
|
||||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
rkvm-input = { path = "../rkvm-input" }
|
rkvm-input = { path = "../rkvm-input" }
|
||||||
serde = { version = "1.0.117", features = ["derive"] }
|
serde = { version = "1.0.117", features = ["derive"] }
|
||||||
bincode = "1.3.1"
|
bincode = "1.3.3"
|
||||||
tokio = { version = "1.0.1", features = ["io-util"] }
|
tokio = { version = "1.0.1", features = ["io-util"] }
|
||||||
|
async-trait = "0.1.68"
|
||||||
|
thiserror = "1.0.40"
|
||||||
|
|
|
@ -1,59 +1,36 @@
|
||||||
use serde::de::DeserializeOwned;
|
mod message;
|
||||||
use serde::{Deserialize, Serialize};
|
mod version;
|
||||||
use std::convert::TryInto;
|
|
||||||
use std::io::{Error, ErrorKind};
|
use std::io::{Error, ErrorKind};
|
||||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
use thiserror::Error;
|
||||||
|
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
|
||||||
|
use version::Version;
|
||||||
|
|
||||||
// Is it bold to assume there won't be more than 65536 protocol versions?
|
pub use message::Message;
|
||||||
pub const PROTOCOL_VERSION: u16 = 2;
|
|
||||||
|
|
||||||
pub async fn read_version<R>(mut reader: R) -> Result<u16, Error>
|
pub async fn negotiate<T: AsyncRead + AsyncWrite + Send + Unpin>(
|
||||||
where
|
stream: &mut T,
|
||||||
R: AsyncRead + Unpin,
|
) -> Result<(), Error> {
|
||||||
{
|
#[derive(Error, Debug)]
|
||||||
let mut bytes = [0; 2];
|
#[error("Invalid version (expected {expected}, got {got} instead)")]
|
||||||
reader.read_exact(&mut bytes).await?;
|
struct InvalidVersionError {
|
||||||
|
expected: Version,
|
||||||
|
got: Version,
|
||||||
|
}
|
||||||
|
|
||||||
Ok(u16::from_le_bytes(bytes))
|
Version::CURRENT.encode(stream).await?;
|
||||||
}
|
stream.flush().await?;
|
||||||
|
|
||||||
pub async fn write_version<W>(mut writer: W, version: u16) -> Result<(), Error>
|
let version = Version::decode(stream).await?;
|
||||||
where
|
if version != Version::CURRENT {
|
||||||
W: AsyncWrite + Unpin,
|
return Err(Error::new(
|
||||||
{
|
ErrorKind::InvalidData,
|
||||||
writer.write_all(&version.to_le_bytes()).await
|
InvalidVersionError {
|
||||||
}
|
expected: Version::CURRENT,
|
||||||
|
got: version,
|
||||||
pub async fn read_message<R, T>(mut reader: R) -> Result<T, Error>
|
},
|
||||||
where
|
));
|
||||||
R: AsyncRead + Unpin,
|
}
|
||||||
T: DeserializeOwned,
|
|
||||||
{
|
|
||||||
let length = {
|
|
||||||
let mut bytes = [0; 1];
|
|
||||||
reader.read_exact(&mut bytes).await?;
|
|
||||||
|
|
||||||
bytes[0]
|
|
||||||
};
|
|
||||||
|
|
||||||
let mut data = vec![0; length as usize];
|
|
||||||
reader.read_exact(&mut data).await?;
|
|
||||||
|
|
||||||
bincode::deserialize(&data).map_err(|err| Error::new(ErrorKind::InvalidData, err))
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn write_message<W, T: Serialize>(mut writer: W, message: &T) -> Result<(), Error>
|
|
||||||
where
|
|
||||||
W: AsyncWrite + Unpin,
|
|
||||||
{
|
|
||||||
let data =
|
|
||||||
bincode::serialize(&message).map_err(|err| Error::new(ErrorKind::InvalidInput, err))?;
|
|
||||||
let length: u8 = data
|
|
||||||
.len()
|
|
||||||
.try_into()
|
|
||||||
.map_err(|_| Error::new(ErrorKind::InvalidInput, "Serialized data is too large"))?;
|
|
||||||
writer.write_all(&length.to_le_bytes()).await?;
|
|
||||||
writer.write_all(&data).await?;
|
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
44
rkvm-net/src/message.rs
Normal file
44
rkvm-net/src/message.rs
Normal file
|
@ -0,0 +1,44 @@
|
||||||
|
use bincode::{DefaultOptions, Options};
|
||||||
|
use serde::de::DeserializeOwned;
|
||||||
|
use serde::Serialize;
|
||||||
|
use std::io::{Error, ErrorKind};
|
||||||
|
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
||||||
|
|
||||||
|
#[async_trait::async_trait]
|
||||||
|
pub trait Message: Sized {
|
||||||
|
async fn decode<R: AsyncRead + Send + Unpin>(stream: &mut R) -> Result<Self, Error>;
|
||||||
|
|
||||||
|
async fn encode<W: AsyncWrite + Send + Unpin>(&self, stream: &mut W) -> Result<(), Error>;
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait::async_trait]
|
||||||
|
impl<T: DeserializeOwned + Serialize + Sync> Message for T {
|
||||||
|
async fn decode<R: AsyncRead + Send + Unpin>(stream: &mut R) -> Result<Self, Error> {
|
||||||
|
let length = stream.read_u16().await?;
|
||||||
|
|
||||||
|
let mut data = vec![0; length.into()];
|
||||||
|
stream.read_exact(&mut data).await?;
|
||||||
|
|
||||||
|
options()
|
||||||
|
.deserialize(&data)
|
||||||
|
.map_err(|err| Error::new(ErrorKind::InvalidData, err))
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn encode<W: AsyncWrite + Send + Unpin>(&self, stream: &mut W) -> Result<(), Error> {
|
||||||
|
let data = options()
|
||||||
|
.serialize(self)
|
||||||
|
.map_err(|err| Error::new(ErrorKind::InvalidInput, err))?;
|
||||||
|
|
||||||
|
let length = data
|
||||||
|
.len()
|
||||||
|
.try_into()
|
||||||
|
.map_err(|_| Error::new(ErrorKind::InvalidInput, "Data too large"))?;
|
||||||
|
stream.write_u16(length).await?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn options() -> impl Options {
|
||||||
|
DefaultOptions::new().with_limit(u16::MAX.into())
|
||||||
|
}
|
29
rkvm-net/src/version.rs
Normal file
29
rkvm-net/src/version.rs
Normal file
|
@ -0,0 +1,29 @@
|
||||||
|
use crate::message::Message;
|
||||||
|
|
||||||
|
use std::fmt::{self, Display, Formatter};
|
||||||
|
use std::io::Error;
|
||||||
|
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
||||||
|
|
||||||
|
#[derive(Clone, Copy, Debug, PartialEq)]
|
||||||
|
pub struct Version(u16);
|
||||||
|
|
||||||
|
impl Version {
|
||||||
|
pub const CURRENT: Self = Self(2);
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Display for Version {
|
||||||
|
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
|
||||||
|
write!(f, "v{}", self.0)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait::async_trait]
|
||||||
|
impl Message for Version {
|
||||||
|
async fn decode<R: AsyncRead + Send + Unpin>(stream: &mut R) -> Result<Self, Error> {
|
||||||
|
stream.read_u16_le().await.map(Self)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn encode<W: AsyncWrite + Send + Unpin>(&self, stream: &mut W) -> Result<(), Error> {
|
||||||
|
stream.write_u16_le(self.0).await
|
||||||
|
}
|
||||||
|
}
|
|
@ -5,8 +5,9 @@ use clap::Parser;
|
||||||
use config::Config;
|
use config::Config;
|
||||||
use log::LevelFilter;
|
use log::LevelFilter;
|
||||||
use rkvm_input::{Direction, Event, EventManager, Key, KeyKind};
|
use rkvm_input::{Direction, Event, EventManager, Key, KeyKind};
|
||||||
|
use rkvm_net::Message;
|
||||||
use slab::Slab;
|
use slab::Slab;
|
||||||
use std::io::{Error, ErrorKind};
|
use std::io::Error;
|
||||||
use std::net::SocketAddr;
|
use std::net::SocketAddr;
|
||||||
use std::path::PathBuf;
|
use std::path::PathBuf;
|
||||||
use std::process::ExitCode;
|
use std::process::ExitCode;
|
||||||
|
@ -115,13 +116,7 @@ async fn run(listen: SocketAddr, acceptor: TlsAcceptor, switch_key: Key) -> Resu
|
||||||
let result = async {
|
let result = async {
|
||||||
let mut stream = BufStream::with_capacity(1024, 1024, stream);
|
let mut stream = BufStream::with_capacity(1024, 1024, stream);
|
||||||
|
|
||||||
rkvm_net::write_version(&mut stream, rkvm_net::PROTOCOL_VERSION).await?;
|
rkvm_net::negotiate(&mut stream).await?;
|
||||||
stream.flush().await?;
|
|
||||||
|
|
||||||
let version = rkvm_net::read_version(&mut stream).await?;
|
|
||||||
if version != rkvm_net::PROTOCOL_VERSION {
|
|
||||||
return Err(Error::new(ErrorKind::InvalidData, "Invalid client protocol version"));
|
|
||||||
}
|
|
||||||
|
|
||||||
loop {
|
loop {
|
||||||
let event = match receiver.recv().await {
|
let event = match receiver.recv().await {
|
||||||
|
@ -129,7 +124,7 @@ async fn run(listen: SocketAddr, acceptor: TlsAcceptor, switch_key: Key) -> Resu
|
||||||
None => break,
|
None => break,
|
||||||
};
|
};
|
||||||
|
|
||||||
rkvm_net::write_message(&mut stream, &event).await?;
|
event.encode(&mut stream).await?;
|
||||||
stream.flush().await?;
|
stream.flush().await?;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue