From 897eb0a68e6e748d4e754a640c5267e1e30344a2 Mon Sep 17 00:00:00 2001 From: Jan Trefil <8711792+htrefil@users.noreply.github.com> Date: Sun, 21 Jul 2024 15:50:50 +0200 Subject: [PATCH] Port rkvm to QUIC --- Cargo.lock | 303 ++++++++++++++++++++++++--- rkvm-client/Cargo.toml | 5 +- rkvm-client/src/client.rs | 424 ++++++++++++++++++++++++-------------- rkvm-client/src/config.rs | 30 +-- rkvm-client/src/main.rs | 6 +- rkvm-client/src/tls.rs | 31 +-- rkvm-input/src/event.rs | 2 +- rkvm-net/Cargo.toml | 1 + rkvm-net/src/lib.rs | 77 ++----- rkvm-server/Cargo.toml | 4 +- rkvm-server/src/main.rs | 6 +- rkvm-server/src/server.rs | 306 ++++++++++++++++++++------- rkvm-server/src/tls.rs | 83 ++------ 13 files changed, 862 insertions(+), 416 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 12f950d..48f87a3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -109,9 +109,9 @@ dependencies = [ [[package]] name = "base64" -version = "0.21.7" +version = "0.22.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567" +checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" [[package]] name = "bincode" @@ -178,6 +178,12 @@ version = "1.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "47de7e88bbbd467951ae7f5a6f34f70d1b4d9cfce53d5fd70f74ebe118b3db56" +[[package]] +name = "cesu8" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6d43a04d8753f35258c91f8ec639f792891f748a1edbd759cf1dcea3382ad83c" + [[package]] name = "cexpr" version = "0.6.0" @@ -250,6 +256,32 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0b6a852b24ab71dffc585bcb46eaf7959d175cb865a7152e35b348d1b2960422" +[[package]] +name = "combine" +version = "4.6.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba5a308b75df32fe02788e748662718f03fde005016435c444eea572398219fd" +dependencies = [ + "bytes", + "memchr", +] + +[[package]] +name = "core-foundation" +version = "0.9.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91e195e091a93c46f7102ec7818a2aa394e1e1771c3ab4825963fa03e45afb8f" +dependencies = [ + "core-foundation-sys", + "libc", +] + +[[package]] +name = "core-foundation-sys" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06ea2b9bc92be3c2baa9334a323ebca2d6f074ff852cd1d7b11064035cd3868f" + [[package]] name = "cpufeatures" version = "0.2.12" @@ -510,6 +542,26 @@ version = "1.70.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f8478577c03552c21db0e2724ffb8986a5ce7af88107e6be5d2ee6e158c12800" +[[package]] +name = "jni" +version = "0.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c6df18c2e3db7e453d3c6ac5b3e9d5182664d28788126d39b91f2d1e22b017ec" +dependencies = [ + "cesu8", + "combine", + "jni-sys", + "log", + "thiserror", + "walkdir", +] + +[[package]] +name = "jni-sys" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8eaf4bc02d17cbdd7ff4c7438cafcdf7fb9a4613313ad11b4f8fefe7d3fa0130" + [[package]] name = "lazy_static" version = "1.5.0" @@ -611,6 +663,34 @@ dependencies = [ "winapi", ] +[[package]] +name = "num-bigint" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a5e44f723f1133c9deac646763579fdb3ac745e418f2a7af9cd0c431da1f20b9" +dependencies = [ + "num-integer", + "num-traits", +] + +[[package]] +name = "num-integer" +version = "0.1.46" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f" +dependencies = [ + "num-traits", +] + +[[package]] +name = "num-traits" +version = "0.2.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" +dependencies = [ + "autocfg", +] + [[package]] name = "num_cpus" version = "1.16.0" @@ -636,6 +716,12 @@ version = "1.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" +[[package]] +name = "openssl-probe" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf" + [[package]] name = "overload" version = "0.1.1" @@ -691,6 +777,54 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "quinn" +version = "0.11.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e4ceeeeabace7857413798eb1ffa1e9c905a9946a57d81fb69b4b71c4d8eb3ad" +dependencies = [ + "bytes", + "pin-project-lite", + "quinn-proto", + "quinn-udp", + "rustc-hash", + "rustls", + "thiserror", + "tokio", + "tracing", +] + +[[package]] +name = "quinn-proto" +version = "0.11.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ddf517c03a109db8100448a4be38d498df8a210a99fe0e1b9eaf39e78c640efe" +dependencies = [ + "bytes", + "rand", + "ring", + "rustc-hash", + "rustls", + "rustls-platform-verifier", + "slab", + "thiserror", + "tinyvec", + "tracing", +] + +[[package]] +name = "quinn-udp" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9096629c45860fc7fb143e125eb826b5e721e10be3263160c7d60ca832cf8c46" +dependencies = [ + "libc", + "once_cell", + "socket2", + "tracing", + "windows-sys 0.52.0", +] + [[package]] name = "quote" version = "1.0.36" @@ -803,14 +937,13 @@ name = "rkvm-client" version = "0.6.1" dependencies = [ "clap", - "env_logger", + "quinn", "rkvm-input", "rkvm-net", "rustls-pemfile", "serde", "thiserror", "tokio", - "tokio-rustls", "toml", "tracing", "tracing-subscriber", @@ -839,6 +972,7 @@ version = "0.1.0" dependencies = [ "bincode", "hmac", + "quinn", "rand", "rkvm-input", "serde", @@ -854,6 +988,7 @@ version = "0.6.1" dependencies = [ "clap", "env_logger", + "quinn", "rand", "rkvm-input", "rkvm-net", @@ -862,7 +997,6 @@ dependencies = [ "slab", "thiserror", "tokio", - "tokio-rustls", "toml", "tracing", "tracing-subscriber", @@ -895,43 +1029,125 @@ dependencies = [ [[package]] name = "rustls" -version = "0.21.12" +version = "0.23.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3f56a14d1f48b391359b22f731fd4bd7e43c97f3c50eee276f3aa09c94784d3e" +checksum = "4828ea528154ae444e5a642dbb7d5623354030dc9822b83fd9bb79683c7399d0" dependencies = [ - "log", + "once_cell", "ring", + "rustls-pki-types", "rustls-webpki", - "sct", + "subtle", + "zeroize", +] + +[[package]] +name = "rustls-native-certs" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a88d6d420651b496bdd98684116959239430022a115c1240e6c3993be0b15fba" +dependencies = [ + "openssl-probe", + "rustls-pemfile", + "rustls-pki-types", + "schannel", + "security-framework", ] [[package]] name = "rustls-pemfile" -version = "1.0.4" +version = "2.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1c74cae0a4cf6ccbbf5f359f08efdf8ee7e1dc532573bf0db71968cb56b1448c" +checksum = "29993a25686778eb88d4189742cd713c9bce943bc54251a33509dc63cbacf73d" dependencies = [ "base64", + "rustls-pki-types", ] +[[package]] +name = "rustls-pki-types" +version = "1.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "976295e77ce332211c0d24d92c0e83e50f5c5f046d11082cea19f3df13a3562d" + +[[package]] +name = "rustls-platform-verifier" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3e3beb939bcd33c269f4bf946cc829fcd336370267c4a927ac0399c84a3151a1" +dependencies = [ + "core-foundation", + "core-foundation-sys", + "jni", + "log", + "once_cell", + "rustls", + "rustls-native-certs", + "rustls-platform-verifier-android", + "rustls-webpki", + "security-framework", + "security-framework-sys", + "webpki-roots", + "winapi", +] + +[[package]] +name = "rustls-platform-verifier-android" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "84e217e7fdc8466b5b35d30f8c0a30febd29173df4a3a0c2115d306b9c4117ad" + [[package]] name = "rustls-webpki" -version = "0.101.7" +version = "0.102.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8b6275d1ee7a1cd780b64aca7726599a1dbc893b1e64144529e55c3c2f745765" +checksum = "f9a6fccd794a42c2c105b513a2f62bc3fd8f3ba57a4593677ceb0bd035164d78" dependencies = [ "ring", + "rustls-pki-types", "untrusted", ] [[package]] -name = "sct" -version = "0.7.1" +name = "same-file" +version = "1.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "da046153aa2352493d6cb7da4b6e5c0c057d8a1d0a9aa8560baffdd945acd414" +checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502" dependencies = [ - "ring", - "untrusted", + "winapi-util", +] + +[[package]] +name = "schannel" +version = "0.1.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fbc91545643bcf3a0bbb6569265615222618bdf33ce4ffbbd13c4bbd4c093534" +dependencies = [ + "windows-sys 0.52.0", +] + +[[package]] +name = "security-framework" +version = "2.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c627723fd09706bacdb5cf41499e95098555af3c3c29d014dc3c458ef6be11c0" +dependencies = [ + "bitflags 2.6.0", + "core-foundation", + "core-foundation-sys", + "libc", + "num-bigint", + "security-framework-sys", +] + +[[package]] +name = "security-framework-sys" +version = "2.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "317936bbbd05227752583946b9e66d7ce3b489f84e11a94a510b4437fef407d7" +dependencies = [ + "core-foundation-sys", + "libc", ] [[package]] @@ -1097,6 +1313,21 @@ dependencies = [ "once_cell", ] +[[package]] +name = "tinyvec" +version = "1.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "445e881f4f6d382d5f27c034e25eb92edd7c784ceab92a0937db7f2e9471b938" +dependencies = [ + "tinyvec_macros", +] + +[[package]] +name = "tinyvec_macros" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" + [[package]] name = "tokio" version = "1.38.0" @@ -1126,16 +1357,6 @@ dependencies = [ "syn", ] -[[package]] -name = "tokio-rustls" -version = "0.24.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c28327cf380ac148141087fbfb9de9d7bd4e84ab5d2c28fbc911d753de8a7081" -dependencies = [ - "rustls", - "tokio", -] - [[package]] name = "toml" version = "0.5.11" @@ -1151,6 +1372,7 @@ version = "0.1.40" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c3523ab5a71916ccf420eebdf5521fcef02141234bbc0b8a49f2fdc4544364ef" dependencies = [ + "log", "pin-project-lite", "tracing-attributes", "tracing-core", @@ -1242,12 +1464,31 @@ version = "0.9.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" +[[package]] +name = "walkdir" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "29790946404f91d9c5d06f9874efddea1dc06c5efe94541a7d6863108e3a5e4b" +dependencies = [ + "same-file", + "winapi-util", +] + [[package]] name = "wasi" version = "0.11.0+wasi-snapshot-preview1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" +[[package]] +name = "webpki-roots" +version = "0.26.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bd7c23921eeb1713a4e851530e9b9756e4fb0e89978582942612524cf09f01cd" +dependencies = [ + "rustls-pki-types", +] + [[package]] name = "which" version = "4.4.2" @@ -1429,3 +1670,9 @@ name = "windows_x86_64_msvc" version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" + +[[package]] +name = "zeroize" +version = "1.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ced3678a2879b30306d323f4542626697a464a97c0a07c9aebf7ebca65cd4dde" diff --git a/rkvm-client/Cargo.toml b/rkvm-client/Cargo.toml index 272a814..fdc657d 100644 --- a/rkvm-client/Cargo.toml +++ b/rkvm-client/Cargo.toml @@ -13,13 +13,12 @@ rkvm-input = { path = "../rkvm-input" } rkvm-net = { path = "../rkvm-net" } serde = { version = "1.0.117", features = ["derive"] } toml = "0.5.7" -env_logger = "0.8.1" clap = { version = "4.2.2", features = ["derive"] } thiserror = "1.0.40" -tokio-rustls = "0.24.0" -rustls-pemfile = "1.0.2" +rustls-pemfile = "2.1.2" tracing = "0.1.37" tracing-subscriber = { version = "0.3.17", features = ["env-filter"] } +quinn = "0.11.2" [package.metadata.rpm] package = "rkvm-client" diff --git a/rkvm-client/src/client.rs b/rkvm-client/src/client.rs index 9cefbf1..4616541 100644 --- a/rkvm-client/src/client.rs +++ b/rkvm-client/src/client.rs @@ -1,23 +1,31 @@ +use quinn::ClientConfig; +use quinn::ConnectError; +use quinn::Connection; +use quinn::ConnectionError; +use quinn::Endpoint; +use rkvm_input::event::Event; use rkvm_input::writer::Writer; use rkvm_net::auth::{AuthChallenge, AuthStatus}; use rkvm_net::message::Message; use rkvm_net::version::Version; -use rkvm_net::{Pong, Update}; -use std::collections::hash_map::Entry; +use rkvm_net::Datagram; +use rkvm_net::DeviceInfo; use std::collections::HashMap; use std::io; -use std::time::Instant; +use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr}; use thiserror::Error; -use tokio::io::{AsyncWriteExt, BufStream}; -use tokio::net::TcpStream; -use tokio::time; -use tokio_rustls::rustls::ServerName; -use tokio_rustls::TlsConnector; +use tokio::io::AsyncRead; +use tokio::io::AsyncWriteExt; +use tokio::io::BufReader; +use tokio::io::BufWriter; +use tokio::net; +use tokio::sync::mpsc::{self, Sender}; +use tracing::Instrument; #[derive(Error, Debug)] pub enum Error { #[error("Network error: {0}")] - Network(io::Error), + Network(#[from] NetworkError), #[error("Input error: {0}")] Input(io::Error), #[error("Incompatible server version (got {server}, expected {client})")] @@ -26,45 +34,39 @@ pub enum Error { Auth, } +#[derive(Error, Debug)] +pub enum NetworkError { + #[error(transparent)] + Io(#[from] io::Error), + #[error(transparent)] + Connect(#[from] ConnectError), + #[error(transparent)] + Connection(#[from] ConnectionError), +} + pub async fn run( - hostname: &ServerName, + hostname: &str, port: u16, - connector: TlsConnector, + mut config: ClientConfig, password: &str, ) -> Result<(), Error> { - // Intentionally don't impose any timeout for TCP connect. - let stream = match hostname { - ServerName::DnsName(name) => TcpStream::connect(&(name.as_ref(), port)).await, - ServerName::IpAddress(address) => TcpStream::connect(&(*address, port)).await, - _ => unimplemented!("Unhandled rustls ServerName variant: {:?}", hostname), - } - .map_err(Error::Network)?; + config.transport_config(rkvm_net::transport_config().into()); - tracing::info!("Connected to server"); + let connection = connect(hostname, port, config).await?; + let (data_write, data_read) = connection.accept_bi().await.map_err(NetworkError::from)?; - let stream = rkvm_net::timeout( - rkvm_net::TLS_TIMEOUT, - connector.connect(hostname.clone(), stream), - ) - .await - .map_err(Error::Network)?; + let mut data_write = BufWriter::new(data_write); + let mut data_read = BufReader::new(data_read); - tracing::info!("TLS connected"); - - let mut stream = BufStream::with_capacity(1024, 1024, stream); - - rkvm_net::timeout(rkvm_net::WRITE_TIMEOUT, async { - Version::CURRENT.encode(&mut stream).await?; - stream.flush().await?; - - Ok(()) - }) - .await - .map_err(Error::Network)?; - - let version = rkvm_net::timeout(rkvm_net::READ_TIMEOUT, Version::decode(&mut stream)) + Version::CURRENT + .encode(&mut data_write) .await - .map_err(Error::Network)?; + .map_err(NetworkError::from)?; + data_write.flush().await.map_err(NetworkError::from)?; + + let version = Version::decode(&mut data_read) + .await + .map_err(NetworkError::from)?; if version != Version::CURRENT { return Err(Error::Version { @@ -73,24 +75,21 @@ pub async fn run( }); } - let challenge = rkvm_net::timeout(rkvm_net::READ_TIMEOUT, AuthChallenge::decode(&mut stream)) + let challenge = AuthChallenge::decode(&mut data_read) .await - .map_err(Error::Network)?; + .map_err(NetworkError::from)?; let response = challenge.respond(password); - rkvm_net::timeout(rkvm_net::WRITE_TIMEOUT, async { - response.encode(&mut stream).await?; - stream.flush().await?; - - Ok(()) - }) - .await - .map_err(Error::Network)?; - - let status = rkvm_net::timeout(rkvm_net::READ_TIMEOUT, AuthStatus::decode(&mut stream)) + response + .encode(&mut data_write) .await - .map_err(Error::Network)?; + .map_err(NetworkError::from)?; + data_write.flush().await.map_err(NetworkError::from)?; + + let status = AuthStatus::decode(&mut data_read) + .await + .map_err(NetworkError::from)?; match status { AuthStatus::Passed => {} @@ -99,110 +98,235 @@ pub async fn run( tracing::info!("Authenticated successfully"); - let mut start = Instant::now(); + data_write.shutdown().await.map_err(NetworkError::from)?; - let mut interval = time::interval(rkvm_net::PING_INTERVAL + rkvm_net::READ_TIMEOUT); - let mut writers = HashMap::new(); + let (error_sender, mut error_receiver) = mpsc::channel(1); + let (device_sender, mut device_receiver) = mpsc::channel(1); - // Interval ticks immediately after creation. - interval.tick().await; + let mut devices = HashMap::new(); loop { - let update = tokio::select! { - update = Update::decode(&mut stream) => update.map_err(Error::Network)?, - _ = interval.tick() => return Err(Error::Network(io::Error::new(io::ErrorKind::TimedOut, "Ping timed out"))), + let device = async { device_receiver.recv().await.unwrap() }; + let read = tokio::select! { + read = connection.accept_uni() => read.map_err(NetworkError::from)?, + device = device => { + match device { + DeviceEvent::Create { id, sender } => { + if devices.insert(id, sender).is_some() { + return Err(NetworkError::from(io::Error::new(io::ErrorKind::AlreadyExists, "Device already exists")).into()); + } + + continue; + } + DeviceEvent::Destroy { id } => { + devices.remove(&id).unwrap(); + continue; + } + } + } + datagram = connection.read_datagram() => { + let datagram = datagram.map_err(NetworkError::from)?; + let datagram = Datagram::decode(&mut &*datagram).await.map_err(NetworkError::from)?; + + let sender = match devices.get(&datagram.id) { + Some(sender) => sender, + None => { + tracing::warn!(id = %datagram.id, "Received datagram for unknown device"); + continue; + } + }; + + let _ = sender.send(datagram.events.into_owned()).await; + continue; + } + err = error_receiver.recv() => return Err(err.unwrap()), }; - match update { - Update::CreateDevice { - id, - name, - vendor, - product, - version, - rel, - abs, - keys, - delay, - period, - } => { - let entry = writers.entry(id); - if let Entry::Occupied(_) = entry { - return Err(Error::Network(io::Error::new( - io::ErrorKind::InvalidData, - "Server created the same device twice", - ))); + let stream_id = read.id(); + + let read = BufReader::new(read); + + let error_sender = error_sender.clone(); + let device_sender = device_sender.clone(); + let span = tracing::debug_span!("stream", id = %stream_id); + + tokio::spawn( + async move { + tracing::debug!("Stream connected"); + + match stream(read, device_sender).await { + Ok(()) => { + tracing::debug!("Stream disconnected"); + } + Err(err) => { + tracing::debug!("Stream disconnected: {}", err); + let _ = error_sender.send(err).await; + } } - - let writer = async { - Writer::builder()? - .name(&name) - .vendor(vendor) - .product(product) - .version(version) - .rel(rel)? - .abs(abs)? - .key(keys)? - .delay(delay)? - .period(period)? - .build() - .await - } - .await - .map_err(Error::Input)?; - - entry.or_insert(writer); - - tracing::info!( - id = %id, - name = ?name, - vendor = %vendor, - product = %product, - version = %version, - "Created new device" - ); } - Update::DestroyDevice { id } => { - if writers.remove(&id).is_none() { - return Err(Error::Network(io::Error::new( - io::ErrorKind::InvalidData, - "Server destroyed a nonexistent device", - ))); - } - - tracing::info!(id = %id, "Destroyed device"); - } - Update::Event { id, event } => { - let writer = writers.get_mut(&id).ok_or_else(|| { - Error::Network(io::Error::new( - io::ErrorKind::InvalidData, - "Server sent an event to a nonexistent device", - )) - })?; - - writer.write(&event).await.map_err(Error::Input)?; - - tracing::trace!(id = %id, "Wrote an event to device"); - } - Update::Ping => { - let duration = start.elapsed(); - tracing::debug!(duration = ?duration, "Received ping"); - - start = Instant::now(); - interval.reset(); - - rkvm_net::timeout(rkvm_net::WRITE_TIMEOUT, async { - Pong.encode(&mut stream).await?; - stream.flush().await?; - - Ok(()) - }) - .await - .map_err(Error::Network)?; - - let duration = start.elapsed(); - tracing::debug!(duration = ?duration, "Sent pong"); - } - } + .instrument(span), + ); } } + +async fn connect( + hostname: &str, + port: u16, + config: ClientConfig, +) -> Result { + let mut last_err = None; + + let addrs = net::lookup_host((hostname, port)) + .await + .map_err(NetworkError::from)?; + + for addr in addrs { + let bind = match addr { + SocketAddr::V4(_) => (Ipv4Addr::UNSPECIFIED, 0).into(), + SocketAddr::V6(_) => (Ipv6Addr::UNSPECIFIED, 0).into(), + }; + + let endpoint = match Endpoint::client(bind) { + Ok(endpoint) => endpoint, + Err(err) => { + tracing::debug!(addr = %addr, "Error binding: {}", err); + last_err = Some(err.into()); + continue; + } + }; + + let connection = match endpoint.connect_with(config.clone(), addr, hostname) { + Ok(connection) => connection, + Err(err) => { + tracing::debug!(addr = %addr, "Error connecting: {}", err); + last_err = Some(err.into()); + continue; + } + }; + + let connection = match connection.await { + Ok(connection) => connection, + Err(err) => { + tracing::debug!(addr = %addr, "Error connecting: {}", err); + last_err = Some(err.into()); + continue; + } + }; + + tracing::info!(addr = %addr, "Connected"); + + return Ok(connection); + } + + Err(last_err.unwrap_or_else(|| { + io::Error::new(io::ErrorKind::InvalidInput, "No addresses resolved").into() + })) +} + +enum DeviceEvent { + Create { + id: usize, + sender: Sender>, + }, + Destroy { + id: usize, + }, +} + +async fn stream( + mut read: T, + device_sender: Sender, +) -> Result<(), Error> { + let device_info = DeviceInfo::decode(&mut read) + .await + .map_err(NetworkError::from)?; + + let id = device_info.id; + let span = tracing::info_span!("device", id = ?device_info.id); + + let mut writer = build(device_info).await.map_err(Error::Input)?; + + let (datagram_sender, mut datagram_receiver) = mpsc::channel(1); + + let event = DeviceEvent::Create { + id, + sender: datagram_sender, + }; + + if device_sender.send(event).await.is_err() { + return Ok(()); + } + + let (event_sender, mut event_receiver) = mpsc::channel(1); + tokio::spawn( + async move { + loop { + let event = tokio::select! { + event = Event::decode(&mut read) => event, + _ = event_sender.closed() => break, + }; + + if event.is_err() | event_sender.send(event).await.is_err() { + break; + } + } + } + .instrument(span.clone()), + ); + + let result = async move { + loop { + let event = async { event_receiver.recv().await.unwrap() }; + + tokio::select! { + event = event => { + let event = event.map_err(NetworkError::from)?; + writer.write(&event).await.map_err(Error::Input)?; + } + datagram = datagram_receiver.recv() => { + let datagram = match datagram { + Some(datagram) => datagram, + None => break, + }; + + for event in datagram { + writer.write(&event).await.map_err(Error::Input)?; + } + } + } + } + + Ok(()) + } + .instrument(span) + .await; + + let _ = device_sender.send(DeviceEvent::Destroy { id }).await; + + result +} + +async fn build(device_info: DeviceInfo) -> Result { + let writer = Writer::builder()? + .name(&device_info.name) + .vendor(device_info.vendor) + .product(device_info.product) + .version(device_info.version) + .rel(device_info.rel)? + .abs(device_info.abs)? + .key(device_info.keys)? + .delay(device_info.delay)? + .period(device_info.period)? + .build() + .await?; + + tracing::info!( + name = ?device_info.name, + vendor = %device_info.vendor, + product = %device_info.product, + version = %device_info.version, + "Created new device" + ); + + Ok(writer) +} diff --git a/rkvm-client/src/config.rs b/rkvm-client/src/config.rs index 8d8145e..fe8c455 100644 --- a/rkvm-client/src/config.rs +++ b/rkvm-client/src/config.rs @@ -1,10 +1,7 @@ use serde::de::{self, Visitor}; use serde::{Deserialize, Deserializer}; use std::fmt::{self, Formatter}; -use std::net::SocketAddr; use std::path::PathBuf; -use std::str::FromStr; -use tokio_rustls::rustls::ServerName; #[derive(Deserialize)] #[serde(rename_all = "kebab-case")] @@ -15,7 +12,7 @@ pub struct Config { } pub struct Server { - pub hostname: ServerName, + pub hostname: String, pub port: u16, } @@ -41,19 +38,11 @@ impl<'de> Visitor<'de> for ServerVisitor { where E: de::Error, { - // Parsing IPv6 socket addresses can get quite hairy, so let the SocketAddr parser do it for us. - if let Ok(socket_addr) = SocketAddr::from_str(data) { - return Ok(Server { - hostname: ServerName::IpAddress(socket_addr.ip()), - port: socket_addr.port(), - }); - } - let (hostname, port) = data - .split_once(':') + .rsplit_once(':') .ok_or_else(|| E::custom("No port provided"))?; - let hostname = hostname.try_into().map_err(E::custom)?; + let hostname = hostname.to_owned(); let port = port.parse().map_err(E::custom)?; Ok(Server { hostname, port }) @@ -62,8 +51,6 @@ impl<'de> Visitor<'de> for ServerVisitor { #[cfg(test)] mod tests { - use std::net::Ipv6Addr; - use super::*; #[derive(Deserialize)] @@ -91,7 +78,7 @@ mod tests { .unwrap() .server; let expected = Server { - hostname: "127.0.0.1".try_into().unwrap(), + hostname: "127.0.0.1".to_owned(), port: 8523, }; @@ -105,19 +92,12 @@ mod tests { .unwrap() .server; let expected = Server { - hostname: "::1".try_into().unwrap(), + hostname: "[::1]".to_owned(), port: 8523, }; assert_eq!(parsed.hostname, expected.hostname); assert_eq!(parsed.port, expected.port); - - let parsed_ip = match parsed.hostname { - ServerName::IpAddress(parsed_ip) => parsed_ip, - _ => unreachable!(), - }; - - assert_eq!(parsed_ip, Ipv6Addr::from_str("::1").unwrap()); } #[test] diff --git a/rkvm-client/src/main.rs b/rkvm-client/src/main.rs index 5e15b42..fc05fe4 100644 --- a/rkvm-client/src/main.rs +++ b/rkvm-client/src/main.rs @@ -48,8 +48,8 @@ async fn main() -> ExitCode { } }; - let connector = match tls::configure(&config.certificate).await { - Ok(connector) => connector, + let client_config = match tls::configure(&config.certificate).await { + Ok(client_config) => client_config, Err(err) => { tracing::error!("Error configuring TLS: {}", err); return ExitCode::FAILURE; @@ -57,7 +57,7 @@ async fn main() -> ExitCode { }; tokio::select! { - result = client::run(&config.server.hostname, config.server.port, connector, &config.password) => { + result = client::run(&config.server.hostname, config.server.port, client_config, &config.password) => { if let Err(err) = result { tracing::error!("Error: {}", err); return ExitCode::FAILURE; diff --git a/rkvm-client/src/tls.rs b/rkvm-client/src/tls.rs index a32832a..a4956d0 100644 --- a/rkvm-client/src/tls.rs +++ b/rkvm-client/src/tls.rs @@ -1,10 +1,12 @@ +use quinn::crypto::rustls::{NoInitialCipherSuite, QuicClientConfig}; +use quinn::rustls; +use quinn::rustls::RootCertStore; +use quinn::ClientConfig; use std::io; use std::path::Path; use std::sync::Arc; use thiserror::Error; use tokio::fs; -use tokio_rustls::rustls::{self, Certificate, ClientConfig, RootCertStore}; -use tokio_rustls::TlsConnector; #[derive(Error, Debug)] pub enum Error { @@ -12,23 +14,26 @@ pub enum Error { Rustls(#[from] rustls::Error), #[error(transparent)] Io(#[from] io::Error), + #[error(transparent)] + NoInitialCipherSuite(#[from] NoInitialCipherSuite), } -pub async fn configure(certificate: &Path) -> Result { +pub async fn configure(certificate: &Path) -> Result { let certificate = fs::read(certificate).await?; - let certificates = rustls_pemfile::certs(&mut certificate.as_slice())?; + let certificate = rustls_pemfile::certs(&mut &*certificate).collect::, _>>()?; let mut store = RootCertStore::empty(); - for certificate in certificates { - store.add(&Certificate(certificate))?; + for certificate in certificate { + store.add(certificate)?; } - let config = Arc::new( - ClientConfig::builder() - .with_safe_defaults() - .with_root_certificates(store) - .with_no_client_auth(), - ); + let config = rustls::ClientConfig::builder() + .with_root_certificates(store) + .with_no_client_auth(); - Ok(config.into()) + let config = QuicClientConfig::try_from(config)?; + let config = Arc::new(config); + let config = ClientConfig::new(config); + + Ok(config) } diff --git a/rkvm-input/src/event.rs b/rkvm-input/src/event.rs index 38744a6..c7426b4 100644 --- a/rkvm-input/src/event.rs +++ b/rkvm-input/src/event.rs @@ -5,7 +5,7 @@ use crate::sync::SyncEvent; use serde::{Deserialize, Serialize}; -#[derive(Debug, Serialize, Deserialize)] +#[derive(Clone, Copy, Debug, Serialize, Deserialize)] pub enum Event { Rel(RelEvent), Abs(AbsEvent), diff --git a/rkvm-net/Cargo.toml b/rkvm-net/Cargo.toml index 635153f..036e7bf 100644 --- a/rkvm-net/Cargo.toml +++ b/rkvm-net/Cargo.toml @@ -16,3 +16,4 @@ hmac = "0.12.1" sha2 = "0.10.6" rand = "0.8.5" tracing = "0.1.37" +quinn = "0.11.2" diff --git a/rkvm-net/src/lib.rs b/rkvm-net/src/lib.rs index 9425e38..96130e0 100644 --- a/rkvm-net/src/lib.rs +++ b/rkvm-net/src/lib.rs @@ -5,75 +5,42 @@ pub mod auth; pub mod message; pub mod version; +use quinn::TransportConfig; use rkvm_input::abs::{AbsAxis, AbsInfo}; use rkvm_input::event::Event; use rkvm_input::key::Key; use rkvm_input::rel::RelAxis; use serde::{Deserialize, Serialize}; +use std::borrow::Cow; use std::collections::{HashMap, HashSet}; use std::ffi::CString; -use std::future::Future; -use std::io::{Error, ErrorKind}; use std::time::Duration; -use tokio::time; - -pub const PING_INTERVAL: Duration = Duration::from_secs(1); - -// Message read timeout (does not apply to updates, only auth negotiation and replies). -pub const READ_TIMEOUT: Duration = Duration::from_millis(500); - -// Message write timeout (applies to all messages). -pub const WRITE_TIMEOUT: Duration = Duration::from_millis(500); - -// TLS negotiation timeout. -pub const TLS_TIMEOUT: Duration = Duration::from_millis(500); #[derive(Deserialize, Serialize, Debug)] -pub enum Update { - CreateDevice { - id: usize, - name: CString, - vendor: u16, - product: u16, - version: u16, - rel: HashSet, - abs: HashMap, - keys: HashSet, - delay: Option, - period: Option, - }, - DestroyDevice { - id: usize, - }, - Event { - id: usize, - event: Event, - }, - Ping, +pub struct DeviceInfo { + // ID generated by rkvm-server, sent for debugging purposes. + pub id: usize, + pub name: CString, + pub vendor: u16, + pub product: u16, + pub version: u16, + pub rel: HashSet, + pub abs: HashMap, + pub keys: HashSet, + pub delay: Option, + pub period: Option, } #[derive(Deserialize, Serialize, Debug)] -pub struct Pong; - -pub async fn timeout>, U>( - duration: Duration, - future: T, -) -> Result { - time::timeout(duration, future) - .await - .map_err(|_| Error::new(ErrorKind::TimedOut, "Message timeout"))? +pub struct Datagram<'a> { + pub id: usize, + pub events: Cow<'a, [Event]>, } -#[cfg(test)] -mod test { - use super::message::Message; - use super::*; +pub fn transport_config() -> TransportConfig { + let mut transport_config = TransportConfig::default(); + transport_config.keep_alive_interval(Some(Duration::from_millis(500))); + transport_config.max_idle_timeout(Some(Duration::from_secs(1).try_into().unwrap())); - #[tokio::test] - async fn pong_is_not_empty() { - let mut data = Vec::new(); - Pong.encode(&mut data).await.unwrap(); - - assert!(!data.is_empty()); - } + transport_config } diff --git a/rkvm-server/Cargo.toml b/rkvm-server/Cargo.toml index e3f32e9..9a9ad79 100644 --- a/rkvm-server/Cargo.toml +++ b/rkvm-server/Cargo.toml @@ -13,8 +13,7 @@ serde = { version = "1.0.117", features = ["derive"] } toml = "0.5.7" env_logger = "0.8.1" clap = { version = "4.2.2", features = ["derive"] } -tokio-rustls = "0.24.0" -rustls-pemfile = "1.0.2" +rustls-pemfile = "2.1.2" thiserror = "1.0.40" slab = "0.4.8" rand = "0.8.5" @@ -22,6 +21,7 @@ tracing-subscriber = { version = "0.3.17", features = ["env-filter"] } tracing = "0.1.37" rkvm-net = { path = "../rkvm-net" } rkvm-input = { path = "../rkvm-input" } +quinn = "0.11.2" [package.metadata.rpm] package = "rkvm-server" diff --git a/rkvm-server/src/main.rs b/rkvm-server/src/main.rs index eb3f402..36ec79e 100644 --- a/rkvm-server/src/main.rs +++ b/rkvm-server/src/main.rs @@ -52,8 +52,8 @@ async fn main() -> ExitCode { } }; - let acceptor = match tls::configure(&config.certificate, &config.key).await { - Ok(acceptor) => acceptor, + let server_config = match tls::configure(&config.certificate, &config.key).await { + Ok(server_config) => server_config, Err(err) => { tracing::error!("Error configuring TLS: {}", err); return ExitCode::FAILURE; @@ -71,7 +71,7 @@ async fn main() -> ExitCode { let propagate_switch_keys = config.propagate_switch_keys.unwrap_or(true); tokio::select! { - result = server::run(config.listen, acceptor, &config.password, &switch_keys, propagate_switch_keys) => { + result = server::run(config.listen, server_config, &config.password, &switch_keys, propagate_switch_keys) => { if let Err(err) = result { tracing::error!("Error: {}", err); return ExitCode::FAILURE; diff --git a/rkvm-server/src/server.rs b/rkvm-server/src/server.rs index b783181..0270226 100644 --- a/rkvm-server/src/server.rs +++ b/rkvm-server/src/server.rs @@ -1,3 +1,4 @@ +use quinn::{ConnectionError, Endpoint, Incoming, SendDatagramError, ServerConfig}; use rkvm_input::abs::{AbsAxis, AbsInfo}; use rkvm_input::event::Event; use rkvm_input::key::{Key, KeyEvent}; @@ -7,20 +8,17 @@ use rkvm_input::sync::SyncEvent; use rkvm_net::auth::{AuthChallenge, AuthResponse, AuthStatus}; use rkvm_net::message::Message; use rkvm_net::version::Version; -use rkvm_net::{Pong, Update}; +use rkvm_net::{Datagram, DeviceInfo}; use slab::Slab; use std::collections::{HashMap, HashSet, VecDeque}; use std::ffi::CString; use std::io::{self, ErrorKind}; +use std::iter; use std::net::SocketAddr; -use std::time::Instant; use thiserror::Error; -use tokio::io::{AsyncWriteExt, BufStream}; -use tokio::net::{TcpListener, TcpStream}; +use tokio::io::{AsyncWrite, AsyncWriteExt, BufReader, BufWriter}; use tokio::sync::mpsc::error::TrySendError; use tokio::sync::mpsc::{self, Receiver, Sender}; -use tokio::time; -use tokio_rustls::TlsAcceptor; use tracing::Instrument; #[derive(Error, Debug)] @@ -35,14 +33,35 @@ pub enum Error { pub async fn run( listen: SocketAddr, - acceptor: TlsAcceptor, + mut config: ServerConfig, password: &str, switch_keys: &HashSet, propagate_switch_keys: bool, ) -> Result<(), Error> { - let listener = TcpListener::bind(&listen).await.map_err(Error::Network)?; + config.transport_config(rkvm_net::transport_config().into()); + + let endpoint = Endpoint::server(config, listen).map_err(Error::Network)?; tracing::info!("Listening on {}", listen); + let (connection_sender, mut connection_receiver) = mpsc::channel(1); + tokio::spawn(async move { + loop { + tokio::select! { + connection = endpoint.accept() => { + let connection = match connection { + Some(connection) => connection, + None => break, + }; + + if connection_sender.send(connection).await.is_err() { + break; + } + } + _ = connection_sender.closed() => break, + } + } + }); + let mut monitor = Monitor::new(); let mut devices = Slab::::new(); let mut clients = Slab::<(Sender<_>, SocketAddr)>::new(); @@ -52,14 +71,17 @@ pub async fn run( let mut pressed_keys = HashSet::new(); let (events_sender, mut events_receiver) = mpsc::channel(1); - loop { let event = async { events_receiver.recv().await.unwrap() }; tokio::select! { - result = listener.accept() => { - let (stream, addr) = result.map_err(Error::Network)?; - let acceptor = acceptor.clone(); + connection = connection_receiver.recv() => { + let connection = match connection { + Some(connection) => connection, + None => break, + }; + + let addr = connection.remote_address(); let password = password.to_owned(); // Remove dead clients. @@ -87,12 +109,13 @@ pub async fn run( let (sender, receiver) = mpsc::channel(1); clients.insert((sender, addr)); - let span = tracing::info_span!("connection", addr = %addr); + let span = tracing::info_span!("client", addr = %addr); + tokio::spawn( async move { tracing::info!("Connected"); - match client(init_updates, receiver, stream, acceptor, &password).await { + match client(init_updates, receiver, connection, &password).await { Ok(()) => tracing::info!("Disconnected"), Err(err) => tracing::error!("Disconnected: {}", err), } @@ -100,8 +123,8 @@ pub async fn run( .instrument(span), ); } - result = monitor.read() => { - let mut interceptor = result.map_err(Error::Input)?; + interceptor = monitor.read() => { + let mut interceptor = interceptor.map_err(Error::Input)?; let name = interceptor.name().to_owned(); let id = devices.vacant_key(); @@ -275,8 +298,30 @@ pub async fn run( } } } -} + Ok(()) +} +enum Update { + CreateDevice { + id: usize, + name: CString, + vendor: u16, + product: u16, + version: u16, + rel: HashSet, + abs: HashMap, + keys: HashSet, + delay: Option, + period: Option, + }, + DestroyDevice { + id: usize, + }, + Event { + id: usize, + event: Event, + }, +} struct Device { name: CString, vendor: u16, @@ -300,29 +345,27 @@ enum ClientError { Auth, #[error(transparent)] Rand(#[from] rand::Error), + #[error(transparent)] + Connection(#[from] ConnectionError), } async fn client( mut init_updates: VecDeque, mut receiver: Receiver, - stream: TcpStream, - acceptor: TlsAcceptor, + connection: Incoming, password: &str, ) -> Result<(), ClientError> { - let stream = rkvm_net::timeout(rkvm_net::TLS_TIMEOUT, acceptor.accept(stream)).await?; - tracing::info!("TLS connected"); + let connection = connection.await?; - let mut stream = BufStream::with_capacity(1024, 1024, stream); + let (data_write, data_read) = connection.open_bi().await?; - rkvm_net::timeout(rkvm_net::WRITE_TIMEOUT, async { - Version::CURRENT.encode(&mut stream).await?; - stream.flush().await?; + let mut data_write = BufWriter::new(data_write); + let mut data_read = BufReader::new(data_read); - Ok(()) - }) - .await?; + Version::CURRENT.encode(&mut data_write).await?; + data_write.flush().await?; - let version = rkvm_net::timeout(rkvm_net::READ_TIMEOUT, Version::decode(&mut stream)).await?; + let version = Version::decode(&mut data_read).await?; if version != Version::CURRENT { return Err(ClientError::Version { server: Version::CURRENT, @@ -332,28 +375,18 @@ async fn client( let challenge = AuthChallenge::generate().await?; - rkvm_net::timeout(rkvm_net::WRITE_TIMEOUT, async { - challenge.encode(&mut stream).await?; - stream.flush().await?; + challenge.encode(&mut data_write).await?; + data_write.flush().await?; - Ok(()) - }) - .await?; + let response = AuthResponse::decode(&mut data_read).await?; - let response = - rkvm_net::timeout(rkvm_net::READ_TIMEOUT, AuthResponse::decode(&mut stream)).await?; let status = match response.verify(&challenge, password) { true => AuthStatus::Passed, false => AuthStatus::Failed, }; - rkvm_net::timeout(rkvm_net::WRITE_TIMEOUT, async { - status.encode(&mut stream).await?; - stream.flush().await?; - - Ok(()) - }) - .await?; + status.encode(&mut data_write).await?; + data_write.flush().await?; if status == AuthStatus::Failed { return Err(ClientError::Auth); @@ -361,10 +394,16 @@ async fn client( tracing::info!("Authenticated successfully"); - let mut interval = time::interval(rkvm_net::PING_INTERVAL); + let mut senders = HashMap::new(); + let (error_sender, mut error_receiver) = mpsc::channel(1); + + data_write.shutdown().await?; + + let mut enable_datagrams = true; + let mut datagram_events = Vec::new(); loop { - let recv = async { + let update = async { match init_updates.pop_front() { Some(update) => Some(update), None => receiver.recv().await, @@ -372,12 +411,9 @@ async fn client( }; let update = tokio::select! { - // Make sure pings have priority. - // The client could time out otherwise. - biased; - - _ = interval.tick() => Some(Update::Ping), - recv = recv => recv, + update = update => update, + err = connection.closed() => return Err(err.into()), + err = error_receiver.recv() => return Err(err.unwrap()), }; let update = match update { @@ -385,29 +421,161 @@ async fn client( None => break, }; - let start = Instant::now(); - rkvm_net::timeout(rkvm_net::WRITE_TIMEOUT, async { - update.encode(&mut stream).await?; - stream.flush().await?; + match update { + Update::CreateDevice { + id, + name, + vendor, + product, + version, + rel, + abs, + keys, + delay, + period, + } => { + let write = connection.open_uni().await?; + let stream_id = write.id(); - Ok(()) - }) - .await?; - let duration = start.elapsed(); + let mut write = BufWriter::new(write); - if let Update::Ping = update { - // Keeping these as debug because it's not as frequent as other updates. - tracing::debug!(duration = ?duration, "Sent ping"); + let device_info = DeviceInfo { + id, + name, + vendor, + product, + version, + rel, + abs, + keys, + delay, + period, + }; - let start = Instant::now(); - rkvm_net::timeout(rkvm_net::READ_TIMEOUT, Pong::decode(&mut stream)).await?; - let duration = start.elapsed(); + let (device_sender, stream_receiver) = mpsc::channel(1); + senders.insert(id, device_sender); - tracing::debug!(duration = ?duration, "Received pong"); + let error_sender = error_sender.clone(); + let span = tracing::debug_span!("stream", id = %stream_id); + + tokio::spawn( + async move { + tracing::debug!("Stream connected"); + + match stream(&mut write, &device_info, stream_receiver).await { + Ok(()) => tracing::debug!("Stream disconnected"), + Err(err) => { + tracing::debug!("Stream disconnected: {}", err); + let _ = error_sender.send(err).await; + } + } + } + .instrument(span), + ); + } + Update::DestroyDevice { id } => { + senders.remove(&id).unwrap(); + } + Update::Event { id, event } => match event { + Event::Rel(_) if enable_datagrams => { + datagram_events.push(event); + } + Event::Sync(SyncEvent::All) if enable_datagrams && !datagram_events.is_empty() => { + datagram_events.push(event); + + let mut message = Vec::new(); + Datagram { + id, + events: datagram_events.as_slice().into(), + } + .encode(&mut message) + .await?; + + let length = message.len(); + + let err = match connection.send_datagram(message.into()) { + Ok(()) => { + tracing::trace!( + "Wrote {} unreliable event{}", + datagram_events.len(), + if datagram_events.len() == 1 { "" } else { "s" } + ); + + datagram_events.clear(); + continue; + } + Err(err) => err, + }; + + match err { + SendDatagramError::UnsupportedByPeer + | SendDatagramError::Disabled + | SendDatagramError::TooLarge => { + let sender = &senders[&id]; + for event in datagram_events.drain(..) { + let _ = sender.send(event).await; + } + + if matches!(err, SendDatagramError::TooLarge) { + tracing::warn!(length = %length, "Datagram too large"); + } else { + tracing::warn!("Disabling datagram support: {}", err); + enable_datagrams = false; + } + } + SendDatagramError::ConnectionLost(err) => return Err(err.into()), + } + } + _ => { + // Send only consecutive relative events as datagrams. + let sender = &senders[&id]; + for event in datagram_events.drain(..).chain(iter::once(event)) { + let _ = sender.send(event).await; + } + } + }, } - - tracing::trace!("Wrote an update"); } Ok(()) } + +async fn stream( + write: &mut T, + device_info: &DeviceInfo, + mut receiver: Receiver, +) -> Result<(), ClientError> { + let span = tracing::info_span!("device", id = device_info.id); + + async { + device_info.encode(write).await?; + write.flush().await?; + + let mut events = 0usize; + + while let Some(event) = receiver.recv().await { + event.encode(write).await?; + events += 1; + + // Coalesce multiple events into a single QUIC packet. + // The `Interceptor` won't emit them until it receives a sync event anyway. + if let Event::Sync(_) = event { + write.flush().await?; + + tracing::trace!( + "Wrote {} event{}", + events, + if events == 1 { "" } else { "s" } + ); + + events = 0; + } + } + + write.shutdown().await?; + + Ok(()) + } + .instrument(span) + .await +} diff --git a/rkvm-server/src/tls.rs b/rkvm-server/src/tls.rs index ff0d00c..694a905 100644 --- a/rkvm-server/src/tls.rs +++ b/rkvm-server/src/tls.rs @@ -1,11 +1,10 @@ -use rustls_pemfile::Item; +use quinn::crypto::rustls::{NoInitialCipherSuite, QuicServerConfig}; +use quinn::{rustls, ServerConfig}; +use std::io; use std::path::Path; use std::sync::Arc; -use std::{io, iter}; use thiserror::Error; use tokio::fs; -use tokio_rustls::rustls::{self, Certificate, PrivateKey, ServerConfig}; -use tokio_rustls::TlsAcceptor; #[derive(Error, Debug)] pub enum Error { @@ -13,70 +12,26 @@ pub enum Error { Rustls(#[from] rustls::Error), #[error(transparent)] Io(#[from] io::Error), - #[error("Multiple private keys provided")] - MultipleKeys, - #[error("No suitable private keys provided")] + #[error("No private key provided")] NoKeys, + #[error(transparent)] + NoInitialCipherSuite(#[from] NoInitialCipherSuite), } -pub async fn configure(certificate: &Path, key: &Path) -> Result { - enum LoadedItem { - Certificate(Vec), - Key(Vec), - } +pub async fn configure(certificate: &Path, key: &Path) -> Result { + let certificate = fs::read(certificate).await?; + let certificate = rustls_pemfile::certs(&mut &*certificate).collect::>()?; - let certificate = fs::read_to_string(certificate).await?; - let key = fs::read_to_string(key).await?; + let key = fs::read(key).await?; + let key = rustls_pemfile::private_key(&mut &*key)?.ok_or(Error::NoKeys)?; - let certificates_iter = iter::from_fn({ - let mut buffer = certificate.as_bytes(); - - move || rustls_pemfile::read_one(&mut buffer).transpose() - }) - .filter_map(|item| match item { - Ok(Item::X509Certificate(data)) => Some(Ok(LoadedItem::Certificate(data))), - Err(err) => Some(Err(err)), - _ => None, - }); - - let keys_iter = iter::from_fn({ - let mut buffer = key.as_bytes(); - - move || rustls_pemfile::read_one(&mut buffer).transpose() - }) - .filter_map(|item| match item { - Ok(Item::RSAKey(data)) | Ok(Item::PKCS8Key(data)) | Ok(Item::ECKey(data)) => { - Some(Ok(LoadedItem::Key(data))) - } - Err(err) => Some(Err(err)), - _ => None, - }); - - let mut certificates = Vec::new(); - let mut key = None; - - for item in certificates_iter.chain(keys_iter) { - let item = item?; - - match item { - LoadedItem::Certificate(data) => certificates.push(Certificate(data)), - LoadedItem::Key(data) => { - if key.is_some() { - return Err(Error::MultipleKeys); - } - - key = Some(PrivateKey(data)); - } - } - } - - let key = key.ok_or(Error::NoKeys)?; - - ServerConfig::builder() - .with_safe_defaults() + let config = rustls::ServerConfig::builder() .with_no_client_auth() - .with_single_cert(certificates, key) - .map(Arc::new) - .map(Into::into) - .map_err(Into::into) + .with_single_cert(certificate, key)?; + + let config = QuicServerConfig::try_from(config)?; + let config = Arc::new(config); + let config = ServerConfig::with_crypto(config); + + Ok(config) }