Port rkvm to QUIC

This commit is contained in:
Jan Trefil 2024-07-21 15:50:50 +02:00
parent 5ec294dc03
commit 897eb0a68e
13 changed files with 862 additions and 416 deletions

303
Cargo.lock generated
View file

@ -109,9 +109,9 @@ dependencies = [
[[package]] [[package]]
name = "base64" name = "base64"
version = "0.21.7" version = "0.22.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567" checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6"
[[package]] [[package]]
name = "bincode" name = "bincode"
@ -178,6 +178,12 @@ version = "1.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "47de7e88bbbd467951ae7f5a6f34f70d1b4d9cfce53d5fd70f74ebe118b3db56" checksum = "47de7e88bbbd467951ae7f5a6f34f70d1b4d9cfce53d5fd70f74ebe118b3db56"
[[package]]
name = "cesu8"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6d43a04d8753f35258c91f8ec639f792891f748a1edbd759cf1dcea3382ad83c"
[[package]] [[package]]
name = "cexpr" name = "cexpr"
version = "0.6.0" version = "0.6.0"
@ -250,6 +256,32 @@ version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0b6a852b24ab71dffc585bcb46eaf7959d175cb865a7152e35b348d1b2960422" checksum = "0b6a852b24ab71dffc585bcb46eaf7959d175cb865a7152e35b348d1b2960422"
[[package]]
name = "combine"
version = "4.6.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ba5a308b75df32fe02788e748662718f03fde005016435c444eea572398219fd"
dependencies = [
"bytes",
"memchr",
]
[[package]]
name = "core-foundation"
version = "0.9.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "91e195e091a93c46f7102ec7818a2aa394e1e1771c3ab4825963fa03e45afb8f"
dependencies = [
"core-foundation-sys",
"libc",
]
[[package]]
name = "core-foundation-sys"
version = "0.8.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "06ea2b9bc92be3c2baa9334a323ebca2d6f074ff852cd1d7b11064035cd3868f"
[[package]] [[package]]
name = "cpufeatures" name = "cpufeatures"
version = "0.2.12" version = "0.2.12"
@ -510,6 +542,26 @@ version = "1.70.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f8478577c03552c21db0e2724ffb8986a5ce7af88107e6be5d2ee6e158c12800" checksum = "f8478577c03552c21db0e2724ffb8986a5ce7af88107e6be5d2ee6e158c12800"
[[package]]
name = "jni"
version = "0.19.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c6df18c2e3db7e453d3c6ac5b3e9d5182664d28788126d39b91f2d1e22b017ec"
dependencies = [
"cesu8",
"combine",
"jni-sys",
"log",
"thiserror",
"walkdir",
]
[[package]]
name = "jni-sys"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8eaf4bc02d17cbdd7ff4c7438cafcdf7fb9a4613313ad11b4f8fefe7d3fa0130"
[[package]] [[package]]
name = "lazy_static" name = "lazy_static"
version = "1.5.0" version = "1.5.0"
@ -611,6 +663,34 @@ dependencies = [
"winapi", "winapi",
] ]
[[package]]
name = "num-bigint"
version = "0.4.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a5e44f723f1133c9deac646763579fdb3ac745e418f2a7af9cd0c431da1f20b9"
dependencies = [
"num-integer",
"num-traits",
]
[[package]]
name = "num-integer"
version = "0.1.46"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f"
dependencies = [
"num-traits",
]
[[package]]
name = "num-traits"
version = "0.2.19"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841"
dependencies = [
"autocfg",
]
[[package]] [[package]]
name = "num_cpus" name = "num_cpus"
version = "1.16.0" version = "1.16.0"
@ -636,6 +716,12 @@ version = "1.19.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92"
[[package]]
name = "openssl-probe"
version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf"
[[package]] [[package]]
name = "overload" name = "overload"
version = "0.1.1" version = "0.1.1"
@ -691,6 +777,54 @@ dependencies = [
"unicode-ident", "unicode-ident",
] ]
[[package]]
name = "quinn"
version = "0.11.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e4ceeeeabace7857413798eb1ffa1e9c905a9946a57d81fb69b4b71c4d8eb3ad"
dependencies = [
"bytes",
"pin-project-lite",
"quinn-proto",
"quinn-udp",
"rustc-hash",
"rustls",
"thiserror",
"tokio",
"tracing",
]
[[package]]
name = "quinn-proto"
version = "0.11.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ddf517c03a109db8100448a4be38d498df8a210a99fe0e1b9eaf39e78c640efe"
dependencies = [
"bytes",
"rand",
"ring",
"rustc-hash",
"rustls",
"rustls-platform-verifier",
"slab",
"thiserror",
"tinyvec",
"tracing",
]
[[package]]
name = "quinn-udp"
version = "0.5.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9096629c45860fc7fb143e125eb826b5e721e10be3263160c7d60ca832cf8c46"
dependencies = [
"libc",
"once_cell",
"socket2",
"tracing",
"windows-sys 0.52.0",
]
[[package]] [[package]]
name = "quote" name = "quote"
version = "1.0.36" version = "1.0.36"
@ -803,14 +937,13 @@ name = "rkvm-client"
version = "0.6.1" version = "0.6.1"
dependencies = [ dependencies = [
"clap", "clap",
"env_logger", "quinn",
"rkvm-input", "rkvm-input",
"rkvm-net", "rkvm-net",
"rustls-pemfile", "rustls-pemfile",
"serde", "serde",
"thiserror", "thiserror",
"tokio", "tokio",
"tokio-rustls",
"toml", "toml",
"tracing", "tracing",
"tracing-subscriber", "tracing-subscriber",
@ -839,6 +972,7 @@ version = "0.1.0"
dependencies = [ dependencies = [
"bincode", "bincode",
"hmac", "hmac",
"quinn",
"rand", "rand",
"rkvm-input", "rkvm-input",
"serde", "serde",
@ -854,6 +988,7 @@ version = "0.6.1"
dependencies = [ dependencies = [
"clap", "clap",
"env_logger", "env_logger",
"quinn",
"rand", "rand",
"rkvm-input", "rkvm-input",
"rkvm-net", "rkvm-net",
@ -862,7 +997,6 @@ dependencies = [
"slab", "slab",
"thiserror", "thiserror",
"tokio", "tokio",
"tokio-rustls",
"toml", "toml",
"tracing", "tracing",
"tracing-subscriber", "tracing-subscriber",
@ -895,43 +1029,125 @@ dependencies = [
[[package]] [[package]]
name = "rustls" name = "rustls"
version = "0.21.12" version = "0.23.11"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3f56a14d1f48b391359b22f731fd4bd7e43c97f3c50eee276f3aa09c94784d3e" checksum = "4828ea528154ae444e5a642dbb7d5623354030dc9822b83fd9bb79683c7399d0"
dependencies = [ dependencies = [
"log", "once_cell",
"ring", "ring",
"rustls-pki-types",
"rustls-webpki", "rustls-webpki",
"sct", "subtle",
"zeroize",
]
[[package]]
name = "rustls-native-certs"
version = "0.7.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a88d6d420651b496bdd98684116959239430022a115c1240e6c3993be0b15fba"
dependencies = [
"openssl-probe",
"rustls-pemfile",
"rustls-pki-types",
"schannel",
"security-framework",
] ]
[[package]] [[package]]
name = "rustls-pemfile" name = "rustls-pemfile"
version = "1.0.4" version = "2.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1c74cae0a4cf6ccbbf5f359f08efdf8ee7e1dc532573bf0db71968cb56b1448c" checksum = "29993a25686778eb88d4189742cd713c9bce943bc54251a33509dc63cbacf73d"
dependencies = [ dependencies = [
"base64", "base64",
"rustls-pki-types",
] ]
[[package]]
name = "rustls-pki-types"
version = "1.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "976295e77ce332211c0d24d92c0e83e50f5c5f046d11082cea19f3df13a3562d"
[[package]]
name = "rustls-platform-verifier"
version = "0.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3e3beb939bcd33c269f4bf946cc829fcd336370267c4a927ac0399c84a3151a1"
dependencies = [
"core-foundation",
"core-foundation-sys",
"jni",
"log",
"once_cell",
"rustls",
"rustls-native-certs",
"rustls-platform-verifier-android",
"rustls-webpki",
"security-framework",
"security-framework-sys",
"webpki-roots",
"winapi",
]
[[package]]
name = "rustls-platform-verifier-android"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "84e217e7fdc8466b5b35d30f8c0a30febd29173df4a3a0c2115d306b9c4117ad"
[[package]] [[package]]
name = "rustls-webpki" name = "rustls-webpki"
version = "0.101.7" version = "0.102.5"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8b6275d1ee7a1cd780b64aca7726599a1dbc893b1e64144529e55c3c2f745765" checksum = "f9a6fccd794a42c2c105b513a2f62bc3fd8f3ba57a4593677ceb0bd035164d78"
dependencies = [ dependencies = [
"ring", "ring",
"rustls-pki-types",
"untrusted", "untrusted",
] ]
[[package]] [[package]]
name = "sct" name = "same-file"
version = "0.7.1" version = "1.0.6"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "da046153aa2352493d6cb7da4b6e5c0c057d8a1d0a9aa8560baffdd945acd414" checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502"
dependencies = [ dependencies = [
"ring", "winapi-util",
"untrusted", ]
[[package]]
name = "schannel"
version = "0.1.23"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fbc91545643bcf3a0bbb6569265615222618bdf33ce4ffbbd13c4bbd4c093534"
dependencies = [
"windows-sys 0.52.0",
]
[[package]]
name = "security-framework"
version = "2.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c627723fd09706bacdb5cf41499e95098555af3c3c29d014dc3c458ef6be11c0"
dependencies = [
"bitflags 2.6.0",
"core-foundation",
"core-foundation-sys",
"libc",
"num-bigint",
"security-framework-sys",
]
[[package]]
name = "security-framework-sys"
version = "2.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "317936bbbd05227752583946b9e66d7ce3b489f84e11a94a510b4437fef407d7"
dependencies = [
"core-foundation-sys",
"libc",
] ]
[[package]] [[package]]
@ -1097,6 +1313,21 @@ dependencies = [
"once_cell", "once_cell",
] ]
[[package]]
name = "tinyvec"
version = "1.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "445e881f4f6d382d5f27c034e25eb92edd7c784ceab92a0937db7f2e9471b938"
dependencies = [
"tinyvec_macros",
]
[[package]]
name = "tinyvec_macros"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20"
[[package]] [[package]]
name = "tokio" name = "tokio"
version = "1.38.0" version = "1.38.0"
@ -1126,16 +1357,6 @@ dependencies = [
"syn", "syn",
] ]
[[package]]
name = "tokio-rustls"
version = "0.24.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c28327cf380ac148141087fbfb9de9d7bd4e84ab5d2c28fbc911d753de8a7081"
dependencies = [
"rustls",
"tokio",
]
[[package]] [[package]]
name = "toml" name = "toml"
version = "0.5.11" version = "0.5.11"
@ -1151,6 +1372,7 @@ version = "0.1.40"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c3523ab5a71916ccf420eebdf5521fcef02141234bbc0b8a49f2fdc4544364ef" checksum = "c3523ab5a71916ccf420eebdf5521fcef02141234bbc0b8a49f2fdc4544364ef"
dependencies = [ dependencies = [
"log",
"pin-project-lite", "pin-project-lite",
"tracing-attributes", "tracing-attributes",
"tracing-core", "tracing-core",
@ -1242,12 +1464,31 @@ version = "0.9.4"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f"
[[package]]
name = "walkdir"
version = "2.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "29790946404f91d9c5d06f9874efddea1dc06c5efe94541a7d6863108e3a5e4b"
dependencies = [
"same-file",
"winapi-util",
]
[[package]] [[package]]
name = "wasi" name = "wasi"
version = "0.11.0+wasi-snapshot-preview1" version = "0.11.0+wasi-snapshot-preview1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423"
[[package]]
name = "webpki-roots"
version = "0.26.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bd7c23921eeb1713a4e851530e9b9756e4fb0e89978582942612524cf09f01cd"
dependencies = [
"rustls-pki-types",
]
[[package]] [[package]]
name = "which" name = "which"
version = "4.4.2" version = "4.4.2"
@ -1429,3 +1670,9 @@ name = "windows_x86_64_msvc"
version = "0.52.6" version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec"
[[package]]
name = "zeroize"
version = "1.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ced3678a2879b30306d323f4542626697a464a97c0a07c9aebf7ebca65cd4dde"

View file

@ -13,13 +13,12 @@ rkvm-input = { path = "../rkvm-input" }
rkvm-net = { path = "../rkvm-net" } rkvm-net = { path = "../rkvm-net" }
serde = { version = "1.0.117", features = ["derive"] } serde = { version = "1.0.117", features = ["derive"] }
toml = "0.5.7" toml = "0.5.7"
env_logger = "0.8.1"
clap = { version = "4.2.2", features = ["derive"] } clap = { version = "4.2.2", features = ["derive"] }
thiserror = "1.0.40" thiserror = "1.0.40"
tokio-rustls = "0.24.0" rustls-pemfile = "2.1.2"
rustls-pemfile = "1.0.2"
tracing = "0.1.37" tracing = "0.1.37"
tracing-subscriber = { version = "0.3.17", features = ["env-filter"] } tracing-subscriber = { version = "0.3.17", features = ["env-filter"] }
quinn = "0.11.2"
[package.metadata.rpm] [package.metadata.rpm]
package = "rkvm-client" package = "rkvm-client"

View file

@ -1,23 +1,31 @@
use quinn::ClientConfig;
use quinn::ConnectError;
use quinn::Connection;
use quinn::ConnectionError;
use quinn::Endpoint;
use rkvm_input::event::Event;
use rkvm_input::writer::Writer; use rkvm_input::writer::Writer;
use rkvm_net::auth::{AuthChallenge, AuthStatus}; use rkvm_net::auth::{AuthChallenge, AuthStatus};
use rkvm_net::message::Message; use rkvm_net::message::Message;
use rkvm_net::version::Version; use rkvm_net::version::Version;
use rkvm_net::{Pong, Update}; use rkvm_net::Datagram;
use std::collections::hash_map::Entry; use rkvm_net::DeviceInfo;
use std::collections::HashMap; use std::collections::HashMap;
use std::io; use std::io;
use std::time::Instant; use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr};
use thiserror::Error; use thiserror::Error;
use tokio::io::{AsyncWriteExt, BufStream}; use tokio::io::AsyncRead;
use tokio::net::TcpStream; use tokio::io::AsyncWriteExt;
use tokio::time; use tokio::io::BufReader;
use tokio_rustls::rustls::ServerName; use tokio::io::BufWriter;
use tokio_rustls::TlsConnector; use tokio::net;
use tokio::sync::mpsc::{self, Sender};
use tracing::Instrument;
#[derive(Error, Debug)] #[derive(Error, Debug)]
pub enum Error { pub enum Error {
#[error("Network error: {0}")] #[error("Network error: {0}")]
Network(io::Error), Network(#[from] NetworkError),
#[error("Input error: {0}")] #[error("Input error: {0}")]
Input(io::Error), Input(io::Error),
#[error("Incompatible server version (got {server}, expected {client})")] #[error("Incompatible server version (got {server}, expected {client})")]
@ -26,45 +34,39 @@ pub enum Error {
Auth, Auth,
} }
#[derive(Error, Debug)]
pub enum NetworkError {
#[error(transparent)]
Io(#[from] io::Error),
#[error(transparent)]
Connect(#[from] ConnectError),
#[error(transparent)]
Connection(#[from] ConnectionError),
}
pub async fn run( pub async fn run(
hostname: &ServerName, hostname: &str,
port: u16, port: u16,
connector: TlsConnector, mut config: ClientConfig,
password: &str, password: &str,
) -> Result<(), Error> { ) -> Result<(), Error> {
// Intentionally don't impose any timeout for TCP connect. config.transport_config(rkvm_net::transport_config().into());
let stream = match hostname {
ServerName::DnsName(name) => TcpStream::connect(&(name.as_ref(), port)).await,
ServerName::IpAddress(address) => TcpStream::connect(&(*address, port)).await,
_ => unimplemented!("Unhandled rustls ServerName variant: {:?}", hostname),
}
.map_err(Error::Network)?;
tracing::info!("Connected to server"); let connection = connect(hostname, port, config).await?;
let (data_write, data_read) = connection.accept_bi().await.map_err(NetworkError::from)?;
let stream = rkvm_net::timeout( let mut data_write = BufWriter::new(data_write);
rkvm_net::TLS_TIMEOUT, let mut data_read = BufReader::new(data_read);
connector.connect(hostname.clone(), stream),
)
.await
.map_err(Error::Network)?;
tracing::info!("TLS connected"); Version::CURRENT
.encode(&mut data_write)
let mut stream = BufStream::with_capacity(1024, 1024, stream);
rkvm_net::timeout(rkvm_net::WRITE_TIMEOUT, async {
Version::CURRENT.encode(&mut stream).await?;
stream.flush().await?;
Ok(())
})
.await
.map_err(Error::Network)?;
let version = rkvm_net::timeout(rkvm_net::READ_TIMEOUT, Version::decode(&mut stream))
.await .await
.map_err(Error::Network)?; .map_err(NetworkError::from)?;
data_write.flush().await.map_err(NetworkError::from)?;
let version = Version::decode(&mut data_read)
.await
.map_err(NetworkError::from)?;
if version != Version::CURRENT { if version != Version::CURRENT {
return Err(Error::Version { return Err(Error::Version {
@ -73,24 +75,21 @@ pub async fn run(
}); });
} }
let challenge = rkvm_net::timeout(rkvm_net::READ_TIMEOUT, AuthChallenge::decode(&mut stream)) let challenge = AuthChallenge::decode(&mut data_read)
.await .await
.map_err(Error::Network)?; .map_err(NetworkError::from)?;
let response = challenge.respond(password); let response = challenge.respond(password);
rkvm_net::timeout(rkvm_net::WRITE_TIMEOUT, async { response
response.encode(&mut stream).await?; .encode(&mut data_write)
stream.flush().await?;
Ok(())
})
.await
.map_err(Error::Network)?;
let status = rkvm_net::timeout(rkvm_net::READ_TIMEOUT, AuthStatus::decode(&mut stream))
.await .await
.map_err(Error::Network)?; .map_err(NetworkError::from)?;
data_write.flush().await.map_err(NetworkError::from)?;
let status = AuthStatus::decode(&mut data_read)
.await
.map_err(NetworkError::from)?;
match status { match status {
AuthStatus::Passed => {} AuthStatus::Passed => {}
@ -99,110 +98,235 @@ pub async fn run(
tracing::info!("Authenticated successfully"); tracing::info!("Authenticated successfully");
let mut start = Instant::now(); data_write.shutdown().await.map_err(NetworkError::from)?;
let mut interval = time::interval(rkvm_net::PING_INTERVAL + rkvm_net::READ_TIMEOUT); let (error_sender, mut error_receiver) = mpsc::channel(1);
let mut writers = HashMap::new(); let (device_sender, mut device_receiver) = mpsc::channel(1);
// Interval ticks immediately after creation. let mut devices = HashMap::new();
interval.tick().await;
loop { loop {
let update = tokio::select! { let device = async { device_receiver.recv().await.unwrap() };
update = Update::decode(&mut stream) => update.map_err(Error::Network)?, let read = tokio::select! {
_ = interval.tick() => return Err(Error::Network(io::Error::new(io::ErrorKind::TimedOut, "Ping timed out"))), read = connection.accept_uni() => read.map_err(NetworkError::from)?,
device = device => {
match device {
DeviceEvent::Create { id, sender } => {
if devices.insert(id, sender).is_some() {
return Err(NetworkError::from(io::Error::new(io::ErrorKind::AlreadyExists, "Device already exists")).into());
}
continue;
}
DeviceEvent::Destroy { id } => {
devices.remove(&id).unwrap();
continue;
}
}
}
datagram = connection.read_datagram() => {
let datagram = datagram.map_err(NetworkError::from)?;
let datagram = Datagram::decode(&mut &*datagram).await.map_err(NetworkError::from)?;
let sender = match devices.get(&datagram.id) {
Some(sender) => sender,
None => {
tracing::warn!(id = %datagram.id, "Received datagram for unknown device");
continue;
}
};
let _ = sender.send(datagram.events.into_owned()).await;
continue;
}
err = error_receiver.recv() => return Err(err.unwrap()),
}; };
match update { let stream_id = read.id();
Update::CreateDevice {
id, let read = BufReader::new(read);
name,
vendor, let error_sender = error_sender.clone();
product, let device_sender = device_sender.clone();
version, let span = tracing::debug_span!("stream", id = %stream_id);
rel,
abs, tokio::spawn(
keys, async move {
delay, tracing::debug!("Stream connected");
period,
} => { match stream(read, device_sender).await {
let entry = writers.entry(id); Ok(()) => {
if let Entry::Occupied(_) = entry { tracing::debug!("Stream disconnected");
return Err(Error::Network(io::Error::new( }
io::ErrorKind::InvalidData, Err(err) => {
"Server created the same device twice", tracing::debug!("Stream disconnected: {}", err);
))); let _ = error_sender.send(err).await;
}
} }
let writer = async {
Writer::builder()?
.name(&name)
.vendor(vendor)
.product(product)
.version(version)
.rel(rel)?
.abs(abs)?
.key(keys)?
.delay(delay)?
.period(period)?
.build()
.await
}
.await
.map_err(Error::Input)?;
entry.or_insert(writer);
tracing::info!(
id = %id,
name = ?name,
vendor = %vendor,
product = %product,
version = %version,
"Created new device"
);
} }
Update::DestroyDevice { id } => { .instrument(span),
if writers.remove(&id).is_none() { );
return Err(Error::Network(io::Error::new(
io::ErrorKind::InvalidData,
"Server destroyed a nonexistent device",
)));
}
tracing::info!(id = %id, "Destroyed device");
}
Update::Event { id, event } => {
let writer = writers.get_mut(&id).ok_or_else(|| {
Error::Network(io::Error::new(
io::ErrorKind::InvalidData,
"Server sent an event to a nonexistent device",
))
})?;
writer.write(&event).await.map_err(Error::Input)?;
tracing::trace!(id = %id, "Wrote an event to device");
}
Update::Ping => {
let duration = start.elapsed();
tracing::debug!(duration = ?duration, "Received ping");
start = Instant::now();
interval.reset();
rkvm_net::timeout(rkvm_net::WRITE_TIMEOUT, async {
Pong.encode(&mut stream).await?;
stream.flush().await?;
Ok(())
})
.await
.map_err(Error::Network)?;
let duration = start.elapsed();
tracing::debug!(duration = ?duration, "Sent pong");
}
}
} }
} }
async fn connect(
hostname: &str,
port: u16,
config: ClientConfig,
) -> Result<Connection, NetworkError> {
let mut last_err = None;
let addrs = net::lookup_host((hostname, port))
.await
.map_err(NetworkError::from)?;
for addr in addrs {
let bind = match addr {
SocketAddr::V4(_) => (Ipv4Addr::UNSPECIFIED, 0).into(),
SocketAddr::V6(_) => (Ipv6Addr::UNSPECIFIED, 0).into(),
};
let endpoint = match Endpoint::client(bind) {
Ok(endpoint) => endpoint,
Err(err) => {
tracing::debug!(addr = %addr, "Error binding: {}", err);
last_err = Some(err.into());
continue;
}
};
let connection = match endpoint.connect_with(config.clone(), addr, hostname) {
Ok(connection) => connection,
Err(err) => {
tracing::debug!(addr = %addr, "Error connecting: {}", err);
last_err = Some(err.into());
continue;
}
};
let connection = match connection.await {
Ok(connection) => connection,
Err(err) => {
tracing::debug!(addr = %addr, "Error connecting: {}", err);
last_err = Some(err.into());
continue;
}
};
tracing::info!(addr = %addr, "Connected");
return Ok(connection);
}
Err(last_err.unwrap_or_else(|| {
io::Error::new(io::ErrorKind::InvalidInput, "No addresses resolved").into()
}))
}
enum DeviceEvent {
Create {
id: usize,
sender: Sender<Vec<Event>>,
},
Destroy {
id: usize,
},
}
async fn stream<T: AsyncRead + Send + Unpin + 'static>(
mut read: T,
device_sender: Sender<DeviceEvent>,
) -> Result<(), Error> {
let device_info = DeviceInfo::decode(&mut read)
.await
.map_err(NetworkError::from)?;
let id = device_info.id;
let span = tracing::info_span!("device", id = ?device_info.id);
let mut writer = build(device_info).await.map_err(Error::Input)?;
let (datagram_sender, mut datagram_receiver) = mpsc::channel(1);
let event = DeviceEvent::Create {
id,
sender: datagram_sender,
};
if device_sender.send(event).await.is_err() {
return Ok(());
}
let (event_sender, mut event_receiver) = mpsc::channel(1);
tokio::spawn(
async move {
loop {
let event = tokio::select! {
event = Event::decode(&mut read) => event,
_ = event_sender.closed() => break,
};
if event.is_err() | event_sender.send(event).await.is_err() {
break;
}
}
}
.instrument(span.clone()),
);
let result = async move {
loop {
let event = async { event_receiver.recv().await.unwrap() };
tokio::select! {
event = event => {
let event = event.map_err(NetworkError::from)?;
writer.write(&event).await.map_err(Error::Input)?;
}
datagram = datagram_receiver.recv() => {
let datagram = match datagram {
Some(datagram) => datagram,
None => break,
};
for event in datagram {
writer.write(&event).await.map_err(Error::Input)?;
}
}
}
}
Ok(())
}
.instrument(span)
.await;
let _ = device_sender.send(DeviceEvent::Destroy { id }).await;
result
}
async fn build(device_info: DeviceInfo) -> Result<Writer, io::Error> {
let writer = Writer::builder()?
.name(&device_info.name)
.vendor(device_info.vendor)
.product(device_info.product)
.version(device_info.version)
.rel(device_info.rel)?
.abs(device_info.abs)?
.key(device_info.keys)?
.delay(device_info.delay)?
.period(device_info.period)?
.build()
.await?;
tracing::info!(
name = ?device_info.name,
vendor = %device_info.vendor,
product = %device_info.product,
version = %device_info.version,
"Created new device"
);
Ok(writer)
}

View file

@ -1,10 +1,7 @@
use serde::de::{self, Visitor}; use serde::de::{self, Visitor};
use serde::{Deserialize, Deserializer}; use serde::{Deserialize, Deserializer};
use std::fmt::{self, Formatter}; use std::fmt::{self, Formatter};
use std::net::SocketAddr;
use std::path::PathBuf; use std::path::PathBuf;
use std::str::FromStr;
use tokio_rustls::rustls::ServerName;
#[derive(Deserialize)] #[derive(Deserialize)]
#[serde(rename_all = "kebab-case")] #[serde(rename_all = "kebab-case")]
@ -15,7 +12,7 @@ pub struct Config {
} }
pub struct Server { pub struct Server {
pub hostname: ServerName, pub hostname: String,
pub port: u16, pub port: u16,
} }
@ -41,19 +38,11 @@ impl<'de> Visitor<'de> for ServerVisitor {
where where
E: de::Error, E: de::Error,
{ {
// Parsing IPv6 socket addresses can get quite hairy, so let the SocketAddr parser do it for us.
if let Ok(socket_addr) = SocketAddr::from_str(data) {
return Ok(Server {
hostname: ServerName::IpAddress(socket_addr.ip()),
port: socket_addr.port(),
});
}
let (hostname, port) = data let (hostname, port) = data
.split_once(':') .rsplit_once(':')
.ok_or_else(|| E::custom("No port provided"))?; .ok_or_else(|| E::custom("No port provided"))?;
let hostname = hostname.try_into().map_err(E::custom)?; let hostname = hostname.to_owned();
let port = port.parse().map_err(E::custom)?; let port = port.parse().map_err(E::custom)?;
Ok(Server { hostname, port }) Ok(Server { hostname, port })
@ -62,8 +51,6 @@ impl<'de> Visitor<'de> for ServerVisitor {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use std::net::Ipv6Addr;
use super::*; use super::*;
#[derive(Deserialize)] #[derive(Deserialize)]
@ -91,7 +78,7 @@ mod tests {
.unwrap() .unwrap()
.server; .server;
let expected = Server { let expected = Server {
hostname: "127.0.0.1".try_into().unwrap(), hostname: "127.0.0.1".to_owned(),
port: 8523, port: 8523,
}; };
@ -105,19 +92,12 @@ mod tests {
.unwrap() .unwrap()
.server; .server;
let expected = Server { let expected = Server {
hostname: "::1".try_into().unwrap(), hostname: "[::1]".to_owned(),
port: 8523, port: 8523,
}; };
assert_eq!(parsed.hostname, expected.hostname); assert_eq!(parsed.hostname, expected.hostname);
assert_eq!(parsed.port, expected.port); assert_eq!(parsed.port, expected.port);
let parsed_ip = match parsed.hostname {
ServerName::IpAddress(parsed_ip) => parsed_ip,
_ => unreachable!(),
};
assert_eq!(parsed_ip, Ipv6Addr::from_str("::1").unwrap());
} }
#[test] #[test]

View file

@ -48,8 +48,8 @@ async fn main() -> ExitCode {
} }
}; };
let connector = match tls::configure(&config.certificate).await { let client_config = match tls::configure(&config.certificate).await {
Ok(connector) => connector, Ok(client_config) => client_config,
Err(err) => { Err(err) => {
tracing::error!("Error configuring TLS: {}", err); tracing::error!("Error configuring TLS: {}", err);
return ExitCode::FAILURE; return ExitCode::FAILURE;
@ -57,7 +57,7 @@ async fn main() -> ExitCode {
}; };
tokio::select! { tokio::select! {
result = client::run(&config.server.hostname, config.server.port, connector, &config.password) => { result = client::run(&config.server.hostname, config.server.port, client_config, &config.password) => {
if let Err(err) = result { if let Err(err) = result {
tracing::error!("Error: {}", err); tracing::error!("Error: {}", err);
return ExitCode::FAILURE; return ExitCode::FAILURE;

View file

@ -1,10 +1,12 @@
use quinn::crypto::rustls::{NoInitialCipherSuite, QuicClientConfig};
use quinn::rustls;
use quinn::rustls::RootCertStore;
use quinn::ClientConfig;
use std::io; use std::io;
use std::path::Path; use std::path::Path;
use std::sync::Arc; use std::sync::Arc;
use thiserror::Error; use thiserror::Error;
use tokio::fs; use tokio::fs;
use tokio_rustls::rustls::{self, Certificate, ClientConfig, RootCertStore};
use tokio_rustls::TlsConnector;
#[derive(Error, Debug)] #[derive(Error, Debug)]
pub enum Error { pub enum Error {
@ -12,23 +14,26 @@ pub enum Error {
Rustls(#[from] rustls::Error), Rustls(#[from] rustls::Error),
#[error(transparent)] #[error(transparent)]
Io(#[from] io::Error), Io(#[from] io::Error),
#[error(transparent)]
NoInitialCipherSuite(#[from] NoInitialCipherSuite),
} }
pub async fn configure(certificate: &Path) -> Result<TlsConnector, Error> { pub async fn configure(certificate: &Path) -> Result<ClientConfig, Error> {
let certificate = fs::read(certificate).await?; let certificate = fs::read(certificate).await?;
let certificates = rustls_pemfile::certs(&mut certificate.as_slice())?; let certificate = rustls_pemfile::certs(&mut &*certificate).collect::<Result<Vec<_>, _>>()?;
let mut store = RootCertStore::empty(); let mut store = RootCertStore::empty();
for certificate in certificates { for certificate in certificate {
store.add(&Certificate(certificate))?; store.add(certificate)?;
} }
let config = Arc::new( let config = rustls::ClientConfig::builder()
ClientConfig::builder() .with_root_certificates(store)
.with_safe_defaults() .with_no_client_auth();
.with_root_certificates(store)
.with_no_client_auth(),
);
Ok(config.into()) let config = QuicClientConfig::try_from(config)?;
let config = Arc::new(config);
let config = ClientConfig::new(config);
Ok(config)
} }

View file

@ -5,7 +5,7 @@ use crate::sync::SyncEvent;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
#[derive(Debug, Serialize, Deserialize)] #[derive(Clone, Copy, Debug, Serialize, Deserialize)]
pub enum Event { pub enum Event {
Rel(RelEvent), Rel(RelEvent),
Abs(AbsEvent), Abs(AbsEvent),

View file

@ -16,3 +16,4 @@ hmac = "0.12.1"
sha2 = "0.10.6" sha2 = "0.10.6"
rand = "0.8.5" rand = "0.8.5"
tracing = "0.1.37" tracing = "0.1.37"
quinn = "0.11.2"

View file

@ -5,75 +5,42 @@ pub mod auth;
pub mod message; pub mod message;
pub mod version; pub mod version;
use quinn::TransportConfig;
use rkvm_input::abs::{AbsAxis, AbsInfo}; use rkvm_input::abs::{AbsAxis, AbsInfo};
use rkvm_input::event::Event; use rkvm_input::event::Event;
use rkvm_input::key::Key; use rkvm_input::key::Key;
use rkvm_input::rel::RelAxis; use rkvm_input::rel::RelAxis;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::borrow::Cow;
use std::collections::{HashMap, HashSet}; use std::collections::{HashMap, HashSet};
use std::ffi::CString; use std::ffi::CString;
use std::future::Future;
use std::io::{Error, ErrorKind};
use std::time::Duration; use std::time::Duration;
use tokio::time;
pub const PING_INTERVAL: Duration = Duration::from_secs(1);
// Message read timeout (does not apply to updates, only auth negotiation and replies).
pub const READ_TIMEOUT: Duration = Duration::from_millis(500);
// Message write timeout (applies to all messages).
pub const WRITE_TIMEOUT: Duration = Duration::from_millis(500);
// TLS negotiation timeout.
pub const TLS_TIMEOUT: Duration = Duration::from_millis(500);
#[derive(Deserialize, Serialize, Debug)] #[derive(Deserialize, Serialize, Debug)]
pub enum Update { pub struct DeviceInfo {
CreateDevice { // ID generated by rkvm-server, sent for debugging purposes.
id: usize, pub id: usize,
name: CString, pub name: CString,
vendor: u16, pub vendor: u16,
product: u16, pub product: u16,
version: u16, pub version: u16,
rel: HashSet<RelAxis>, pub rel: HashSet<RelAxis>,
abs: HashMap<AbsAxis, AbsInfo>, pub abs: HashMap<AbsAxis, AbsInfo>,
keys: HashSet<Key>, pub keys: HashSet<Key>,
delay: Option<i32>, pub delay: Option<i32>,
period: Option<i32>, pub period: Option<i32>,
},
DestroyDevice {
id: usize,
},
Event {
id: usize,
event: Event,
},
Ping,
} }
#[derive(Deserialize, Serialize, Debug)] #[derive(Deserialize, Serialize, Debug)]
pub struct Pong; pub struct Datagram<'a> {
pub id: usize,
pub async fn timeout<T: Future<Output = Result<U, Error>>, U>( pub events: Cow<'a, [Event]>,
duration: Duration,
future: T,
) -> Result<U, Error> {
time::timeout(duration, future)
.await
.map_err(|_| Error::new(ErrorKind::TimedOut, "Message timeout"))?
} }
#[cfg(test)] pub fn transport_config() -> TransportConfig {
mod test { let mut transport_config = TransportConfig::default();
use super::message::Message; transport_config.keep_alive_interval(Some(Duration::from_millis(500)));
use super::*; transport_config.max_idle_timeout(Some(Duration::from_secs(1).try_into().unwrap()));
#[tokio::test] transport_config
async fn pong_is_not_empty() {
let mut data = Vec::new();
Pong.encode(&mut data).await.unwrap();
assert!(!data.is_empty());
}
} }

View file

@ -13,8 +13,7 @@ serde = { version = "1.0.117", features = ["derive"] }
toml = "0.5.7" toml = "0.5.7"
env_logger = "0.8.1" env_logger = "0.8.1"
clap = { version = "4.2.2", features = ["derive"] } clap = { version = "4.2.2", features = ["derive"] }
tokio-rustls = "0.24.0" rustls-pemfile = "2.1.2"
rustls-pemfile = "1.0.2"
thiserror = "1.0.40" thiserror = "1.0.40"
slab = "0.4.8" slab = "0.4.8"
rand = "0.8.5" rand = "0.8.5"
@ -22,6 +21,7 @@ tracing-subscriber = { version = "0.3.17", features = ["env-filter"] }
tracing = "0.1.37" tracing = "0.1.37"
rkvm-net = { path = "../rkvm-net" } rkvm-net = { path = "../rkvm-net" }
rkvm-input = { path = "../rkvm-input" } rkvm-input = { path = "../rkvm-input" }
quinn = "0.11.2"
[package.metadata.rpm] [package.metadata.rpm]
package = "rkvm-server" package = "rkvm-server"

View file

@ -52,8 +52,8 @@ async fn main() -> ExitCode {
} }
}; };
let acceptor = match tls::configure(&config.certificate, &config.key).await { let server_config = match tls::configure(&config.certificate, &config.key).await {
Ok(acceptor) => acceptor, Ok(server_config) => server_config,
Err(err) => { Err(err) => {
tracing::error!("Error configuring TLS: {}", err); tracing::error!("Error configuring TLS: {}", err);
return ExitCode::FAILURE; return ExitCode::FAILURE;
@ -71,7 +71,7 @@ async fn main() -> ExitCode {
let propagate_switch_keys = config.propagate_switch_keys.unwrap_or(true); let propagate_switch_keys = config.propagate_switch_keys.unwrap_or(true);
tokio::select! { tokio::select! {
result = server::run(config.listen, acceptor, &config.password, &switch_keys, propagate_switch_keys) => { result = server::run(config.listen, server_config, &config.password, &switch_keys, propagate_switch_keys) => {
if let Err(err) = result { if let Err(err) = result {
tracing::error!("Error: {}", err); tracing::error!("Error: {}", err);
return ExitCode::FAILURE; return ExitCode::FAILURE;

View file

@ -1,3 +1,4 @@
use quinn::{ConnectionError, Endpoint, Incoming, SendDatagramError, ServerConfig};
use rkvm_input::abs::{AbsAxis, AbsInfo}; use rkvm_input::abs::{AbsAxis, AbsInfo};
use rkvm_input::event::Event; use rkvm_input::event::Event;
use rkvm_input::key::{Key, KeyEvent}; use rkvm_input::key::{Key, KeyEvent};
@ -7,20 +8,17 @@ use rkvm_input::sync::SyncEvent;
use rkvm_net::auth::{AuthChallenge, AuthResponse, AuthStatus}; use rkvm_net::auth::{AuthChallenge, AuthResponse, AuthStatus};
use rkvm_net::message::Message; use rkvm_net::message::Message;
use rkvm_net::version::Version; use rkvm_net::version::Version;
use rkvm_net::{Pong, Update}; use rkvm_net::{Datagram, DeviceInfo};
use slab::Slab; use slab::Slab;
use std::collections::{HashMap, HashSet, VecDeque}; use std::collections::{HashMap, HashSet, VecDeque};
use std::ffi::CString; use std::ffi::CString;
use std::io::{self, ErrorKind}; use std::io::{self, ErrorKind};
use std::iter;
use std::net::SocketAddr; use std::net::SocketAddr;
use std::time::Instant;
use thiserror::Error; use thiserror::Error;
use tokio::io::{AsyncWriteExt, BufStream}; use tokio::io::{AsyncWrite, AsyncWriteExt, BufReader, BufWriter};
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::mpsc::error::TrySendError; use tokio::sync::mpsc::error::TrySendError;
use tokio::sync::mpsc::{self, Receiver, Sender}; use tokio::sync::mpsc::{self, Receiver, Sender};
use tokio::time;
use tokio_rustls::TlsAcceptor;
use tracing::Instrument; use tracing::Instrument;
#[derive(Error, Debug)] #[derive(Error, Debug)]
@ -35,14 +33,35 @@ pub enum Error {
pub async fn run( pub async fn run(
listen: SocketAddr, listen: SocketAddr,
acceptor: TlsAcceptor, mut config: ServerConfig,
password: &str, password: &str,
switch_keys: &HashSet<Key>, switch_keys: &HashSet<Key>,
propagate_switch_keys: bool, propagate_switch_keys: bool,
) -> Result<(), Error> { ) -> Result<(), Error> {
let listener = TcpListener::bind(&listen).await.map_err(Error::Network)?; config.transport_config(rkvm_net::transport_config().into());
let endpoint = Endpoint::server(config, listen).map_err(Error::Network)?;
tracing::info!("Listening on {}", listen); tracing::info!("Listening on {}", listen);
let (connection_sender, mut connection_receiver) = mpsc::channel(1);
tokio::spawn(async move {
loop {
tokio::select! {
connection = endpoint.accept() => {
let connection = match connection {
Some(connection) => connection,
None => break,
};
if connection_sender.send(connection).await.is_err() {
break;
}
}
_ = connection_sender.closed() => break,
}
}
});
let mut monitor = Monitor::new(); let mut monitor = Monitor::new();
let mut devices = Slab::<Device>::new(); let mut devices = Slab::<Device>::new();
let mut clients = Slab::<(Sender<_>, SocketAddr)>::new(); let mut clients = Slab::<(Sender<_>, SocketAddr)>::new();
@ -52,14 +71,17 @@ pub async fn run(
let mut pressed_keys = HashSet::new(); let mut pressed_keys = HashSet::new();
let (events_sender, mut events_receiver) = mpsc::channel(1); let (events_sender, mut events_receiver) = mpsc::channel(1);
loop { loop {
let event = async { events_receiver.recv().await.unwrap() }; let event = async { events_receiver.recv().await.unwrap() };
tokio::select! { tokio::select! {
result = listener.accept() => { connection = connection_receiver.recv() => {
let (stream, addr) = result.map_err(Error::Network)?; let connection = match connection {
let acceptor = acceptor.clone(); Some(connection) => connection,
None => break,
};
let addr = connection.remote_address();
let password = password.to_owned(); let password = password.to_owned();
// Remove dead clients. // Remove dead clients.
@ -87,12 +109,13 @@ pub async fn run(
let (sender, receiver) = mpsc::channel(1); let (sender, receiver) = mpsc::channel(1);
clients.insert((sender, addr)); clients.insert((sender, addr));
let span = tracing::info_span!("connection", addr = %addr); let span = tracing::info_span!("client", addr = %addr);
tokio::spawn( tokio::spawn(
async move { async move {
tracing::info!("Connected"); tracing::info!("Connected");
match client(init_updates, receiver, stream, acceptor, &password).await { match client(init_updates, receiver, connection, &password).await {
Ok(()) => tracing::info!("Disconnected"), Ok(()) => tracing::info!("Disconnected"),
Err(err) => tracing::error!("Disconnected: {}", err), Err(err) => tracing::error!("Disconnected: {}", err),
} }
@ -100,8 +123,8 @@ pub async fn run(
.instrument(span), .instrument(span),
); );
} }
result = monitor.read() => { interceptor = monitor.read() => {
let mut interceptor = result.map_err(Error::Input)?; let mut interceptor = interceptor.map_err(Error::Input)?;
let name = interceptor.name().to_owned(); let name = interceptor.name().to_owned();
let id = devices.vacant_key(); let id = devices.vacant_key();
@ -275,8 +298,30 @@ pub async fn run(
} }
} }
} }
}
Ok(())
}
enum Update {
CreateDevice {
id: usize,
name: CString,
vendor: u16,
product: u16,
version: u16,
rel: HashSet<RelAxis>,
abs: HashMap<AbsAxis, AbsInfo>,
keys: HashSet<Key>,
delay: Option<i32>,
period: Option<i32>,
},
DestroyDevice {
id: usize,
},
Event {
id: usize,
event: Event,
},
}
struct Device { struct Device {
name: CString, name: CString,
vendor: u16, vendor: u16,
@ -300,29 +345,27 @@ enum ClientError {
Auth, Auth,
#[error(transparent)] #[error(transparent)]
Rand(#[from] rand::Error), Rand(#[from] rand::Error),
#[error(transparent)]
Connection(#[from] ConnectionError),
} }
async fn client( async fn client(
mut init_updates: VecDeque<Update>, mut init_updates: VecDeque<Update>,
mut receiver: Receiver<Update>, mut receiver: Receiver<Update>,
stream: TcpStream, connection: Incoming,
acceptor: TlsAcceptor,
password: &str, password: &str,
) -> Result<(), ClientError> { ) -> Result<(), ClientError> {
let stream = rkvm_net::timeout(rkvm_net::TLS_TIMEOUT, acceptor.accept(stream)).await?; let connection = connection.await?;
tracing::info!("TLS connected");
let mut stream = BufStream::with_capacity(1024, 1024, stream); let (data_write, data_read) = connection.open_bi().await?;
rkvm_net::timeout(rkvm_net::WRITE_TIMEOUT, async { let mut data_write = BufWriter::new(data_write);
Version::CURRENT.encode(&mut stream).await?; let mut data_read = BufReader::new(data_read);
stream.flush().await?;
Ok(()) Version::CURRENT.encode(&mut data_write).await?;
}) data_write.flush().await?;
.await?;
let version = rkvm_net::timeout(rkvm_net::READ_TIMEOUT, Version::decode(&mut stream)).await?; let version = Version::decode(&mut data_read).await?;
if version != Version::CURRENT { if version != Version::CURRENT {
return Err(ClientError::Version { return Err(ClientError::Version {
server: Version::CURRENT, server: Version::CURRENT,
@ -332,28 +375,18 @@ async fn client(
let challenge = AuthChallenge::generate().await?; let challenge = AuthChallenge::generate().await?;
rkvm_net::timeout(rkvm_net::WRITE_TIMEOUT, async { challenge.encode(&mut data_write).await?;
challenge.encode(&mut stream).await?; data_write.flush().await?;
stream.flush().await?;
Ok(()) let response = AuthResponse::decode(&mut data_read).await?;
})
.await?;
let response =
rkvm_net::timeout(rkvm_net::READ_TIMEOUT, AuthResponse::decode(&mut stream)).await?;
let status = match response.verify(&challenge, password) { let status = match response.verify(&challenge, password) {
true => AuthStatus::Passed, true => AuthStatus::Passed,
false => AuthStatus::Failed, false => AuthStatus::Failed,
}; };
rkvm_net::timeout(rkvm_net::WRITE_TIMEOUT, async { status.encode(&mut data_write).await?;
status.encode(&mut stream).await?; data_write.flush().await?;
stream.flush().await?;
Ok(())
})
.await?;
if status == AuthStatus::Failed { if status == AuthStatus::Failed {
return Err(ClientError::Auth); return Err(ClientError::Auth);
@ -361,10 +394,16 @@ async fn client(
tracing::info!("Authenticated successfully"); tracing::info!("Authenticated successfully");
let mut interval = time::interval(rkvm_net::PING_INTERVAL); let mut senders = HashMap::new();
let (error_sender, mut error_receiver) = mpsc::channel(1);
data_write.shutdown().await?;
let mut enable_datagrams = true;
let mut datagram_events = Vec::new();
loop { loop {
let recv = async { let update = async {
match init_updates.pop_front() { match init_updates.pop_front() {
Some(update) => Some(update), Some(update) => Some(update),
None => receiver.recv().await, None => receiver.recv().await,
@ -372,12 +411,9 @@ async fn client(
}; };
let update = tokio::select! { let update = tokio::select! {
// Make sure pings have priority. update = update => update,
// The client could time out otherwise. err = connection.closed() => return Err(err.into()),
biased; err = error_receiver.recv() => return Err(err.unwrap()),
_ = interval.tick() => Some(Update::Ping),
recv = recv => recv,
}; };
let update = match update { let update = match update {
@ -385,29 +421,161 @@ async fn client(
None => break, None => break,
}; };
let start = Instant::now(); match update {
rkvm_net::timeout(rkvm_net::WRITE_TIMEOUT, async { Update::CreateDevice {
update.encode(&mut stream).await?; id,
stream.flush().await?; name,
vendor,
product,
version,
rel,
abs,
keys,
delay,
period,
} => {
let write = connection.open_uni().await?;
let stream_id = write.id();
Ok(()) let mut write = BufWriter::new(write);
})
.await?;
let duration = start.elapsed();
if let Update::Ping = update { let device_info = DeviceInfo {
// Keeping these as debug because it's not as frequent as other updates. id,
tracing::debug!(duration = ?duration, "Sent ping"); name,
vendor,
product,
version,
rel,
abs,
keys,
delay,
period,
};
let start = Instant::now(); let (device_sender, stream_receiver) = mpsc::channel(1);
rkvm_net::timeout(rkvm_net::READ_TIMEOUT, Pong::decode(&mut stream)).await?; senders.insert(id, device_sender);
let duration = start.elapsed();
tracing::debug!(duration = ?duration, "Received pong"); let error_sender = error_sender.clone();
let span = tracing::debug_span!("stream", id = %stream_id);
tokio::spawn(
async move {
tracing::debug!("Stream connected");
match stream(&mut write, &device_info, stream_receiver).await {
Ok(()) => tracing::debug!("Stream disconnected"),
Err(err) => {
tracing::debug!("Stream disconnected: {}", err);
let _ = error_sender.send(err).await;
}
}
}
.instrument(span),
);
}
Update::DestroyDevice { id } => {
senders.remove(&id).unwrap();
}
Update::Event { id, event } => match event {
Event::Rel(_) if enable_datagrams => {
datagram_events.push(event);
}
Event::Sync(SyncEvent::All) if enable_datagrams && !datagram_events.is_empty() => {
datagram_events.push(event);
let mut message = Vec::new();
Datagram {
id,
events: datagram_events.as_slice().into(),
}
.encode(&mut message)
.await?;
let length = message.len();
let err = match connection.send_datagram(message.into()) {
Ok(()) => {
tracing::trace!(
"Wrote {} unreliable event{}",
datagram_events.len(),
if datagram_events.len() == 1 { "" } else { "s" }
);
datagram_events.clear();
continue;
}
Err(err) => err,
};
match err {
SendDatagramError::UnsupportedByPeer
| SendDatagramError::Disabled
| SendDatagramError::TooLarge => {
let sender = &senders[&id];
for event in datagram_events.drain(..) {
let _ = sender.send(event).await;
}
if matches!(err, SendDatagramError::TooLarge) {
tracing::warn!(length = %length, "Datagram too large");
} else {
tracing::warn!("Disabling datagram support: {}", err);
enable_datagrams = false;
}
}
SendDatagramError::ConnectionLost(err) => return Err(err.into()),
}
}
_ => {
// Send only consecutive relative events as datagrams.
let sender = &senders[&id];
for event in datagram_events.drain(..).chain(iter::once(event)) {
let _ = sender.send(event).await;
}
}
},
} }
tracing::trace!("Wrote an update");
} }
Ok(()) Ok(())
} }
async fn stream<T: AsyncWrite + Send + Unpin>(
write: &mut T,
device_info: &DeviceInfo,
mut receiver: Receiver<Event>,
) -> Result<(), ClientError> {
let span = tracing::info_span!("device", id = device_info.id);
async {
device_info.encode(write).await?;
write.flush().await?;
let mut events = 0usize;
while let Some(event) = receiver.recv().await {
event.encode(write).await?;
events += 1;
// Coalesce multiple events into a single QUIC packet.
// The `Interceptor` won't emit them until it receives a sync event anyway.
if let Event::Sync(_) = event {
write.flush().await?;
tracing::trace!(
"Wrote {} event{}",
events,
if events == 1 { "" } else { "s" }
);
events = 0;
}
}
write.shutdown().await?;
Ok(())
}
.instrument(span)
.await
}

View file

@ -1,11 +1,10 @@
use rustls_pemfile::Item; use quinn::crypto::rustls::{NoInitialCipherSuite, QuicServerConfig};
use quinn::{rustls, ServerConfig};
use std::io;
use std::path::Path; use std::path::Path;
use std::sync::Arc; use std::sync::Arc;
use std::{io, iter};
use thiserror::Error; use thiserror::Error;
use tokio::fs; use tokio::fs;
use tokio_rustls::rustls::{self, Certificate, PrivateKey, ServerConfig};
use tokio_rustls::TlsAcceptor;
#[derive(Error, Debug)] #[derive(Error, Debug)]
pub enum Error { pub enum Error {
@ -13,70 +12,26 @@ pub enum Error {
Rustls(#[from] rustls::Error), Rustls(#[from] rustls::Error),
#[error(transparent)] #[error(transparent)]
Io(#[from] io::Error), Io(#[from] io::Error),
#[error("Multiple private keys provided")] #[error("No private key provided")]
MultipleKeys,
#[error("No suitable private keys provided")]
NoKeys, NoKeys,
#[error(transparent)]
NoInitialCipherSuite(#[from] NoInitialCipherSuite),
} }
pub async fn configure(certificate: &Path, key: &Path) -> Result<TlsAcceptor, Error> { pub async fn configure(certificate: &Path, key: &Path) -> Result<ServerConfig, Error> {
enum LoadedItem { let certificate = fs::read(certificate).await?;
Certificate(Vec<u8>), let certificate = rustls_pemfile::certs(&mut &*certificate).collect::<Result<_, _>>()?;
Key(Vec<u8>),
}
let certificate = fs::read_to_string(certificate).await?; let key = fs::read(key).await?;
let key = fs::read_to_string(key).await?; let key = rustls_pemfile::private_key(&mut &*key)?.ok_or(Error::NoKeys)?;
let certificates_iter = iter::from_fn({ let config = rustls::ServerConfig::builder()
let mut buffer = certificate.as_bytes();
move || rustls_pemfile::read_one(&mut buffer).transpose()
})
.filter_map(|item| match item {
Ok(Item::X509Certificate(data)) => Some(Ok(LoadedItem::Certificate(data))),
Err(err) => Some(Err(err)),
_ => None,
});
let keys_iter = iter::from_fn({
let mut buffer = key.as_bytes();
move || rustls_pemfile::read_one(&mut buffer).transpose()
})
.filter_map(|item| match item {
Ok(Item::RSAKey(data)) | Ok(Item::PKCS8Key(data)) | Ok(Item::ECKey(data)) => {
Some(Ok(LoadedItem::Key(data)))
}
Err(err) => Some(Err(err)),
_ => None,
});
let mut certificates = Vec::new();
let mut key = None;
for item in certificates_iter.chain(keys_iter) {
let item = item?;
match item {
LoadedItem::Certificate(data) => certificates.push(Certificate(data)),
LoadedItem::Key(data) => {
if key.is_some() {
return Err(Error::MultipleKeys);
}
key = Some(PrivateKey(data));
}
}
}
let key = key.ok_or(Error::NoKeys)?;
ServerConfig::builder()
.with_safe_defaults()
.with_no_client_auth() .with_no_client_auth()
.with_single_cert(certificates, key) .with_single_cert(certificate, key)?;
.map(Arc::new)
.map(Into::into) let config = QuicServerConfig::try_from(config)?;
.map_err(Into::into) let config = Arc::new(config);
let config = ServerConfig::with_crypto(config);
Ok(config)
} }