mirror of
https://github.com/htrefil/rkvm.git
synced 2024-12-26 09:58:32 +01:00
Port rkvm to QUIC
This commit is contained in:
parent
5ec294dc03
commit
897eb0a68e
13 changed files with 862 additions and 416 deletions
303
Cargo.lock
generated
303
Cargo.lock
generated
|
@ -109,9 +109,9 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "base64"
|
||||
version = "0.21.7"
|
||||
version = "0.22.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567"
|
||||
checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6"
|
||||
|
||||
[[package]]
|
||||
name = "bincode"
|
||||
|
@ -178,6 +178,12 @@ version = "1.1.2"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "47de7e88bbbd467951ae7f5a6f34f70d1b4d9cfce53d5fd70f74ebe118b3db56"
|
||||
|
||||
[[package]]
|
||||
name = "cesu8"
|
||||
version = "1.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6d43a04d8753f35258c91f8ec639f792891f748a1edbd759cf1dcea3382ad83c"
|
||||
|
||||
[[package]]
|
||||
name = "cexpr"
|
||||
version = "0.6.0"
|
||||
|
@ -250,6 +256,32 @@ version = "1.0.1"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0b6a852b24ab71dffc585bcb46eaf7959d175cb865a7152e35b348d1b2960422"
|
||||
|
||||
[[package]]
|
||||
name = "combine"
|
||||
version = "4.6.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ba5a308b75df32fe02788e748662718f03fde005016435c444eea572398219fd"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"memchr",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "core-foundation"
|
||||
version = "0.9.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "91e195e091a93c46f7102ec7818a2aa394e1e1771c3ab4825963fa03e45afb8f"
|
||||
dependencies = [
|
||||
"core-foundation-sys",
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "core-foundation-sys"
|
||||
version = "0.8.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "06ea2b9bc92be3c2baa9334a323ebca2d6f074ff852cd1d7b11064035cd3868f"
|
||||
|
||||
[[package]]
|
||||
name = "cpufeatures"
|
||||
version = "0.2.12"
|
||||
|
@ -510,6 +542,26 @@ version = "1.70.0"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f8478577c03552c21db0e2724ffb8986a5ce7af88107e6be5d2ee6e158c12800"
|
||||
|
||||
[[package]]
|
||||
name = "jni"
|
||||
version = "0.19.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c6df18c2e3db7e453d3c6ac5b3e9d5182664d28788126d39b91f2d1e22b017ec"
|
||||
dependencies = [
|
||||
"cesu8",
|
||||
"combine",
|
||||
"jni-sys",
|
||||
"log",
|
||||
"thiserror",
|
||||
"walkdir",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "jni-sys"
|
||||
version = "0.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8eaf4bc02d17cbdd7ff4c7438cafcdf7fb9a4613313ad11b4f8fefe7d3fa0130"
|
||||
|
||||
[[package]]
|
||||
name = "lazy_static"
|
||||
version = "1.5.0"
|
||||
|
@ -611,6 +663,34 @@ dependencies = [
|
|||
"winapi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num-bigint"
|
||||
version = "0.4.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a5e44f723f1133c9deac646763579fdb3ac745e418f2a7af9cd0c431da1f20b9"
|
||||
dependencies = [
|
||||
"num-integer",
|
||||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num-integer"
|
||||
version = "0.1.46"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f"
|
||||
dependencies = [
|
||||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num-traits"
|
||||
version = "0.2.19"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841"
|
||||
dependencies = [
|
||||
"autocfg",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num_cpus"
|
||||
version = "1.16.0"
|
||||
|
@ -636,6 +716,12 @@ version = "1.19.0"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92"
|
||||
|
||||
[[package]]
|
||||
name = "openssl-probe"
|
||||
version = "0.1.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf"
|
||||
|
||||
[[package]]
|
||||
name = "overload"
|
||||
version = "0.1.1"
|
||||
|
@ -691,6 +777,54 @@ dependencies = [
|
|||
"unicode-ident",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "quinn"
|
||||
version = "0.11.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e4ceeeeabace7857413798eb1ffa1e9c905a9946a57d81fb69b4b71c4d8eb3ad"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"pin-project-lite",
|
||||
"quinn-proto",
|
||||
"quinn-udp",
|
||||
"rustc-hash",
|
||||
"rustls",
|
||||
"thiserror",
|
||||
"tokio",
|
||||
"tracing",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "quinn-proto"
|
||||
version = "0.11.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ddf517c03a109db8100448a4be38d498df8a210a99fe0e1b9eaf39e78c640efe"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"rand",
|
||||
"ring",
|
||||
"rustc-hash",
|
||||
"rustls",
|
||||
"rustls-platform-verifier",
|
||||
"slab",
|
||||
"thiserror",
|
||||
"tinyvec",
|
||||
"tracing",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "quinn-udp"
|
||||
version = "0.5.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9096629c45860fc7fb143e125eb826b5e721e10be3263160c7d60ca832cf8c46"
|
||||
dependencies = [
|
||||
"libc",
|
||||
"once_cell",
|
||||
"socket2",
|
||||
"tracing",
|
||||
"windows-sys 0.52.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "quote"
|
||||
version = "1.0.36"
|
||||
|
@ -803,14 +937,13 @@ name = "rkvm-client"
|
|||
version = "0.6.1"
|
||||
dependencies = [
|
||||
"clap",
|
||||
"env_logger",
|
||||
"quinn",
|
||||
"rkvm-input",
|
||||
"rkvm-net",
|
||||
"rustls-pemfile",
|
||||
"serde",
|
||||
"thiserror",
|
||||
"tokio",
|
||||
"tokio-rustls",
|
||||
"toml",
|
||||
"tracing",
|
||||
"tracing-subscriber",
|
||||
|
@ -839,6 +972,7 @@ version = "0.1.0"
|
|||
dependencies = [
|
||||
"bincode",
|
||||
"hmac",
|
||||
"quinn",
|
||||
"rand",
|
||||
"rkvm-input",
|
||||
"serde",
|
||||
|
@ -854,6 +988,7 @@ version = "0.6.1"
|
|||
dependencies = [
|
||||
"clap",
|
||||
"env_logger",
|
||||
"quinn",
|
||||
"rand",
|
||||
"rkvm-input",
|
||||
"rkvm-net",
|
||||
|
@ -862,7 +997,6 @@ dependencies = [
|
|||
"slab",
|
||||
"thiserror",
|
||||
"tokio",
|
||||
"tokio-rustls",
|
||||
"toml",
|
||||
"tracing",
|
||||
"tracing-subscriber",
|
||||
|
@ -895,43 +1029,125 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "rustls"
|
||||
version = "0.21.12"
|
||||
version = "0.23.11"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3f56a14d1f48b391359b22f731fd4bd7e43c97f3c50eee276f3aa09c94784d3e"
|
||||
checksum = "4828ea528154ae444e5a642dbb7d5623354030dc9822b83fd9bb79683c7399d0"
|
||||
dependencies = [
|
||||
"log",
|
||||
"once_cell",
|
||||
"ring",
|
||||
"rustls-pki-types",
|
||||
"rustls-webpki",
|
||||
"sct",
|
||||
"subtle",
|
||||
"zeroize",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustls-native-certs"
|
||||
version = "0.7.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a88d6d420651b496bdd98684116959239430022a115c1240e6c3993be0b15fba"
|
||||
dependencies = [
|
||||
"openssl-probe",
|
||||
"rustls-pemfile",
|
||||
"rustls-pki-types",
|
||||
"schannel",
|
||||
"security-framework",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustls-pemfile"
|
||||
version = "1.0.4"
|
||||
version = "2.1.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1c74cae0a4cf6ccbbf5f359f08efdf8ee7e1dc532573bf0db71968cb56b1448c"
|
||||
checksum = "29993a25686778eb88d4189742cd713c9bce943bc54251a33509dc63cbacf73d"
|
||||
dependencies = [
|
||||
"base64",
|
||||
"rustls-pki-types",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustls-pki-types"
|
||||
version = "1.7.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "976295e77ce332211c0d24d92c0e83e50f5c5f046d11082cea19f3df13a3562d"
|
||||
|
||||
[[package]]
|
||||
name = "rustls-platform-verifier"
|
||||
version = "0.3.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3e3beb939bcd33c269f4bf946cc829fcd336370267c4a927ac0399c84a3151a1"
|
||||
dependencies = [
|
||||
"core-foundation",
|
||||
"core-foundation-sys",
|
||||
"jni",
|
||||
"log",
|
||||
"once_cell",
|
||||
"rustls",
|
||||
"rustls-native-certs",
|
||||
"rustls-platform-verifier-android",
|
||||
"rustls-webpki",
|
||||
"security-framework",
|
||||
"security-framework-sys",
|
||||
"webpki-roots",
|
||||
"winapi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustls-platform-verifier-android"
|
||||
version = "0.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "84e217e7fdc8466b5b35d30f8c0a30febd29173df4a3a0c2115d306b9c4117ad"
|
||||
|
||||
[[package]]
|
||||
name = "rustls-webpki"
|
||||
version = "0.101.7"
|
||||
version = "0.102.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8b6275d1ee7a1cd780b64aca7726599a1dbc893b1e64144529e55c3c2f745765"
|
||||
checksum = "f9a6fccd794a42c2c105b513a2f62bc3fd8f3ba57a4593677ceb0bd035164d78"
|
||||
dependencies = [
|
||||
"ring",
|
||||
"rustls-pki-types",
|
||||
"untrusted",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "sct"
|
||||
version = "0.7.1"
|
||||
name = "same-file"
|
||||
version = "1.0.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "da046153aa2352493d6cb7da4b6e5c0c057d8a1d0a9aa8560baffdd945acd414"
|
||||
checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502"
|
||||
dependencies = [
|
||||
"ring",
|
||||
"untrusted",
|
||||
"winapi-util",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "schannel"
|
||||
version = "0.1.23"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "fbc91545643bcf3a0bbb6569265615222618bdf33ce4ffbbd13c4bbd4c093534"
|
||||
dependencies = [
|
||||
"windows-sys 0.52.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "security-framework"
|
||||
version = "2.11.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c627723fd09706bacdb5cf41499e95098555af3c3c29d014dc3c458ef6be11c0"
|
||||
dependencies = [
|
||||
"bitflags 2.6.0",
|
||||
"core-foundation",
|
||||
"core-foundation-sys",
|
||||
"libc",
|
||||
"num-bigint",
|
||||
"security-framework-sys",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "security-framework-sys"
|
||||
version = "2.11.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "317936bbbd05227752583946b9e66d7ce3b489f84e11a94a510b4437fef407d7"
|
||||
dependencies = [
|
||||
"core-foundation-sys",
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
@ -1097,6 +1313,21 @@ dependencies = [
|
|||
"once_cell",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tinyvec"
|
||||
version = "1.8.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "445e881f4f6d382d5f27c034e25eb92edd7c784ceab92a0937db7f2e9471b938"
|
||||
dependencies = [
|
||||
"tinyvec_macros",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tinyvec_macros"
|
||||
version = "0.1.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20"
|
||||
|
||||
[[package]]
|
||||
name = "tokio"
|
||||
version = "1.38.0"
|
||||
|
@ -1126,16 +1357,6 @@ dependencies = [
|
|||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tokio-rustls"
|
||||
version = "0.24.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c28327cf380ac148141087fbfb9de9d7bd4e84ab5d2c28fbc911d753de8a7081"
|
||||
dependencies = [
|
||||
"rustls",
|
||||
"tokio",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "toml"
|
||||
version = "0.5.11"
|
||||
|
@ -1151,6 +1372,7 @@ version = "0.1.40"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c3523ab5a71916ccf420eebdf5521fcef02141234bbc0b8a49f2fdc4544364ef"
|
||||
dependencies = [
|
||||
"log",
|
||||
"pin-project-lite",
|
||||
"tracing-attributes",
|
||||
"tracing-core",
|
||||
|
@ -1242,12 +1464,31 @@ version = "0.9.4"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f"
|
||||
|
||||
[[package]]
|
||||
name = "walkdir"
|
||||
version = "2.5.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "29790946404f91d9c5d06f9874efddea1dc06c5efe94541a7d6863108e3a5e4b"
|
||||
dependencies = [
|
||||
"same-file",
|
||||
"winapi-util",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "wasi"
|
||||
version = "0.11.0+wasi-snapshot-preview1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423"
|
||||
|
||||
[[package]]
|
||||
name = "webpki-roots"
|
||||
version = "0.26.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "bd7c23921eeb1713a4e851530e9b9756e4fb0e89978582942612524cf09f01cd"
|
||||
dependencies = [
|
||||
"rustls-pki-types",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "which"
|
||||
version = "4.4.2"
|
||||
|
@ -1429,3 +1670,9 @@ name = "windows_x86_64_msvc"
|
|||
version = "0.52.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec"
|
||||
|
||||
[[package]]
|
||||
name = "zeroize"
|
||||
version = "1.8.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ced3678a2879b30306d323f4542626697a464a97c0a07c9aebf7ebca65cd4dde"
|
||||
|
|
|
@ -13,13 +13,12 @@ rkvm-input = { path = "../rkvm-input" }
|
|||
rkvm-net = { path = "../rkvm-net" }
|
||||
serde = { version = "1.0.117", features = ["derive"] }
|
||||
toml = "0.5.7"
|
||||
env_logger = "0.8.1"
|
||||
clap = { version = "4.2.2", features = ["derive"] }
|
||||
thiserror = "1.0.40"
|
||||
tokio-rustls = "0.24.0"
|
||||
rustls-pemfile = "1.0.2"
|
||||
rustls-pemfile = "2.1.2"
|
||||
tracing = "0.1.37"
|
||||
tracing-subscriber = { version = "0.3.17", features = ["env-filter"] }
|
||||
quinn = "0.11.2"
|
||||
|
||||
[package.metadata.rpm]
|
||||
package = "rkvm-client"
|
||||
|
|
|
@ -1,23 +1,31 @@
|
|||
use quinn::ClientConfig;
|
||||
use quinn::ConnectError;
|
||||
use quinn::Connection;
|
||||
use quinn::ConnectionError;
|
||||
use quinn::Endpoint;
|
||||
use rkvm_input::event::Event;
|
||||
use rkvm_input::writer::Writer;
|
||||
use rkvm_net::auth::{AuthChallenge, AuthStatus};
|
||||
use rkvm_net::message::Message;
|
||||
use rkvm_net::version::Version;
|
||||
use rkvm_net::{Pong, Update};
|
||||
use std::collections::hash_map::Entry;
|
||||
use rkvm_net::Datagram;
|
||||
use rkvm_net::DeviceInfo;
|
||||
use std::collections::HashMap;
|
||||
use std::io;
|
||||
use std::time::Instant;
|
||||
use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr};
|
||||
use thiserror::Error;
|
||||
use tokio::io::{AsyncWriteExt, BufStream};
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::time;
|
||||
use tokio_rustls::rustls::ServerName;
|
||||
use tokio_rustls::TlsConnector;
|
||||
use tokio::io::AsyncRead;
|
||||
use tokio::io::AsyncWriteExt;
|
||||
use tokio::io::BufReader;
|
||||
use tokio::io::BufWriter;
|
||||
use tokio::net;
|
||||
use tokio::sync::mpsc::{self, Sender};
|
||||
use tracing::Instrument;
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
pub enum Error {
|
||||
#[error("Network error: {0}")]
|
||||
Network(io::Error),
|
||||
Network(#[from] NetworkError),
|
||||
#[error("Input error: {0}")]
|
||||
Input(io::Error),
|
||||
#[error("Incompatible server version (got {server}, expected {client})")]
|
||||
|
@ -26,45 +34,39 @@ pub enum Error {
|
|||
Auth,
|
||||
}
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
pub enum NetworkError {
|
||||
#[error(transparent)]
|
||||
Io(#[from] io::Error),
|
||||
#[error(transparent)]
|
||||
Connect(#[from] ConnectError),
|
||||
#[error(transparent)]
|
||||
Connection(#[from] ConnectionError),
|
||||
}
|
||||
|
||||
pub async fn run(
|
||||
hostname: &ServerName,
|
||||
hostname: &str,
|
||||
port: u16,
|
||||
connector: TlsConnector,
|
||||
mut config: ClientConfig,
|
||||
password: &str,
|
||||
) -> Result<(), Error> {
|
||||
// Intentionally don't impose any timeout for TCP connect.
|
||||
let stream = match hostname {
|
||||
ServerName::DnsName(name) => TcpStream::connect(&(name.as_ref(), port)).await,
|
||||
ServerName::IpAddress(address) => TcpStream::connect(&(*address, port)).await,
|
||||
_ => unimplemented!("Unhandled rustls ServerName variant: {:?}", hostname),
|
||||
}
|
||||
.map_err(Error::Network)?;
|
||||
config.transport_config(rkvm_net::transport_config().into());
|
||||
|
||||
tracing::info!("Connected to server");
|
||||
let connection = connect(hostname, port, config).await?;
|
||||
let (data_write, data_read) = connection.accept_bi().await.map_err(NetworkError::from)?;
|
||||
|
||||
let stream = rkvm_net::timeout(
|
||||
rkvm_net::TLS_TIMEOUT,
|
||||
connector.connect(hostname.clone(), stream),
|
||||
)
|
||||
.await
|
||||
.map_err(Error::Network)?;
|
||||
let mut data_write = BufWriter::new(data_write);
|
||||
let mut data_read = BufReader::new(data_read);
|
||||
|
||||
tracing::info!("TLS connected");
|
||||
|
||||
let mut stream = BufStream::with_capacity(1024, 1024, stream);
|
||||
|
||||
rkvm_net::timeout(rkvm_net::WRITE_TIMEOUT, async {
|
||||
Version::CURRENT.encode(&mut stream).await?;
|
||||
stream.flush().await?;
|
||||
|
||||
Ok(())
|
||||
})
|
||||
.await
|
||||
.map_err(Error::Network)?;
|
||||
|
||||
let version = rkvm_net::timeout(rkvm_net::READ_TIMEOUT, Version::decode(&mut stream))
|
||||
Version::CURRENT
|
||||
.encode(&mut data_write)
|
||||
.await
|
||||
.map_err(Error::Network)?;
|
||||
.map_err(NetworkError::from)?;
|
||||
data_write.flush().await.map_err(NetworkError::from)?;
|
||||
|
||||
let version = Version::decode(&mut data_read)
|
||||
.await
|
||||
.map_err(NetworkError::from)?;
|
||||
|
||||
if version != Version::CURRENT {
|
||||
return Err(Error::Version {
|
||||
|
@ -73,24 +75,21 @@ pub async fn run(
|
|||
});
|
||||
}
|
||||
|
||||
let challenge = rkvm_net::timeout(rkvm_net::READ_TIMEOUT, AuthChallenge::decode(&mut stream))
|
||||
let challenge = AuthChallenge::decode(&mut data_read)
|
||||
.await
|
||||
.map_err(Error::Network)?;
|
||||
.map_err(NetworkError::from)?;
|
||||
|
||||
let response = challenge.respond(password);
|
||||
|
||||
rkvm_net::timeout(rkvm_net::WRITE_TIMEOUT, async {
|
||||
response.encode(&mut stream).await?;
|
||||
stream.flush().await?;
|
||||
|
||||
Ok(())
|
||||
})
|
||||
.await
|
||||
.map_err(Error::Network)?;
|
||||
|
||||
let status = rkvm_net::timeout(rkvm_net::READ_TIMEOUT, AuthStatus::decode(&mut stream))
|
||||
response
|
||||
.encode(&mut data_write)
|
||||
.await
|
||||
.map_err(Error::Network)?;
|
||||
.map_err(NetworkError::from)?;
|
||||
data_write.flush().await.map_err(NetworkError::from)?;
|
||||
|
||||
let status = AuthStatus::decode(&mut data_read)
|
||||
.await
|
||||
.map_err(NetworkError::from)?;
|
||||
|
||||
match status {
|
||||
AuthStatus::Passed => {}
|
||||
|
@ -99,110 +98,235 @@ pub async fn run(
|
|||
|
||||
tracing::info!("Authenticated successfully");
|
||||
|
||||
let mut start = Instant::now();
|
||||
data_write.shutdown().await.map_err(NetworkError::from)?;
|
||||
|
||||
let mut interval = time::interval(rkvm_net::PING_INTERVAL + rkvm_net::READ_TIMEOUT);
|
||||
let mut writers = HashMap::new();
|
||||
let (error_sender, mut error_receiver) = mpsc::channel(1);
|
||||
let (device_sender, mut device_receiver) = mpsc::channel(1);
|
||||
|
||||
// Interval ticks immediately after creation.
|
||||
interval.tick().await;
|
||||
let mut devices = HashMap::new();
|
||||
|
||||
loop {
|
||||
let update = tokio::select! {
|
||||
update = Update::decode(&mut stream) => update.map_err(Error::Network)?,
|
||||
_ = interval.tick() => return Err(Error::Network(io::Error::new(io::ErrorKind::TimedOut, "Ping timed out"))),
|
||||
let device = async { device_receiver.recv().await.unwrap() };
|
||||
let read = tokio::select! {
|
||||
read = connection.accept_uni() => read.map_err(NetworkError::from)?,
|
||||
device = device => {
|
||||
match device {
|
||||
DeviceEvent::Create { id, sender } => {
|
||||
if devices.insert(id, sender).is_some() {
|
||||
return Err(NetworkError::from(io::Error::new(io::ErrorKind::AlreadyExists, "Device already exists")).into());
|
||||
}
|
||||
|
||||
continue;
|
||||
}
|
||||
DeviceEvent::Destroy { id } => {
|
||||
devices.remove(&id).unwrap();
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
datagram = connection.read_datagram() => {
|
||||
let datagram = datagram.map_err(NetworkError::from)?;
|
||||
let datagram = Datagram::decode(&mut &*datagram).await.map_err(NetworkError::from)?;
|
||||
|
||||
let sender = match devices.get(&datagram.id) {
|
||||
Some(sender) => sender,
|
||||
None => {
|
||||
tracing::warn!(id = %datagram.id, "Received datagram for unknown device");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let _ = sender.send(datagram.events.into_owned()).await;
|
||||
continue;
|
||||
}
|
||||
err = error_receiver.recv() => return Err(err.unwrap()),
|
||||
};
|
||||
|
||||
match update {
|
||||
Update::CreateDevice {
|
||||
id,
|
||||
name,
|
||||
vendor,
|
||||
product,
|
||||
version,
|
||||
rel,
|
||||
abs,
|
||||
keys,
|
||||
delay,
|
||||
period,
|
||||
} => {
|
||||
let entry = writers.entry(id);
|
||||
if let Entry::Occupied(_) = entry {
|
||||
return Err(Error::Network(io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
"Server created the same device twice",
|
||||
)));
|
||||
let stream_id = read.id();
|
||||
|
||||
let read = BufReader::new(read);
|
||||
|
||||
let error_sender = error_sender.clone();
|
||||
let device_sender = device_sender.clone();
|
||||
let span = tracing::debug_span!("stream", id = %stream_id);
|
||||
|
||||
tokio::spawn(
|
||||
async move {
|
||||
tracing::debug!("Stream connected");
|
||||
|
||||
match stream(read, device_sender).await {
|
||||
Ok(()) => {
|
||||
tracing::debug!("Stream disconnected");
|
||||
}
|
||||
Err(err) => {
|
||||
tracing::debug!("Stream disconnected: {}", err);
|
||||
let _ = error_sender.send(err).await;
|
||||
}
|
||||
}
|
||||
|
||||
let writer = async {
|
||||
Writer::builder()?
|
||||
.name(&name)
|
||||
.vendor(vendor)
|
||||
.product(product)
|
||||
.version(version)
|
||||
.rel(rel)?
|
||||
.abs(abs)?
|
||||
.key(keys)?
|
||||
.delay(delay)?
|
||||
.period(period)?
|
||||
.build()
|
||||
.await
|
||||
}
|
||||
.await
|
||||
.map_err(Error::Input)?;
|
||||
|
||||
entry.or_insert(writer);
|
||||
|
||||
tracing::info!(
|
||||
id = %id,
|
||||
name = ?name,
|
||||
vendor = %vendor,
|
||||
product = %product,
|
||||
version = %version,
|
||||
"Created new device"
|
||||
);
|
||||
}
|
||||
Update::DestroyDevice { id } => {
|
||||
if writers.remove(&id).is_none() {
|
||||
return Err(Error::Network(io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
"Server destroyed a nonexistent device",
|
||||
)));
|
||||
}
|
||||
|
||||
tracing::info!(id = %id, "Destroyed device");
|
||||
}
|
||||
Update::Event { id, event } => {
|
||||
let writer = writers.get_mut(&id).ok_or_else(|| {
|
||||
Error::Network(io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
"Server sent an event to a nonexistent device",
|
||||
))
|
||||
})?;
|
||||
|
||||
writer.write(&event).await.map_err(Error::Input)?;
|
||||
|
||||
tracing::trace!(id = %id, "Wrote an event to device");
|
||||
}
|
||||
Update::Ping => {
|
||||
let duration = start.elapsed();
|
||||
tracing::debug!(duration = ?duration, "Received ping");
|
||||
|
||||
start = Instant::now();
|
||||
interval.reset();
|
||||
|
||||
rkvm_net::timeout(rkvm_net::WRITE_TIMEOUT, async {
|
||||
Pong.encode(&mut stream).await?;
|
||||
stream.flush().await?;
|
||||
|
||||
Ok(())
|
||||
})
|
||||
.await
|
||||
.map_err(Error::Network)?;
|
||||
|
||||
let duration = start.elapsed();
|
||||
tracing::debug!(duration = ?duration, "Sent pong");
|
||||
}
|
||||
}
|
||||
.instrument(span),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
async fn connect(
|
||||
hostname: &str,
|
||||
port: u16,
|
||||
config: ClientConfig,
|
||||
) -> Result<Connection, NetworkError> {
|
||||
let mut last_err = None;
|
||||
|
||||
let addrs = net::lookup_host((hostname, port))
|
||||
.await
|
||||
.map_err(NetworkError::from)?;
|
||||
|
||||
for addr in addrs {
|
||||
let bind = match addr {
|
||||
SocketAddr::V4(_) => (Ipv4Addr::UNSPECIFIED, 0).into(),
|
||||
SocketAddr::V6(_) => (Ipv6Addr::UNSPECIFIED, 0).into(),
|
||||
};
|
||||
|
||||
let endpoint = match Endpoint::client(bind) {
|
||||
Ok(endpoint) => endpoint,
|
||||
Err(err) => {
|
||||
tracing::debug!(addr = %addr, "Error binding: {}", err);
|
||||
last_err = Some(err.into());
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let connection = match endpoint.connect_with(config.clone(), addr, hostname) {
|
||||
Ok(connection) => connection,
|
||||
Err(err) => {
|
||||
tracing::debug!(addr = %addr, "Error connecting: {}", err);
|
||||
last_err = Some(err.into());
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let connection = match connection.await {
|
||||
Ok(connection) => connection,
|
||||
Err(err) => {
|
||||
tracing::debug!(addr = %addr, "Error connecting: {}", err);
|
||||
last_err = Some(err.into());
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
tracing::info!(addr = %addr, "Connected");
|
||||
|
||||
return Ok(connection);
|
||||
}
|
||||
|
||||
Err(last_err.unwrap_or_else(|| {
|
||||
io::Error::new(io::ErrorKind::InvalidInput, "No addresses resolved").into()
|
||||
}))
|
||||
}
|
||||
|
||||
enum DeviceEvent {
|
||||
Create {
|
||||
id: usize,
|
||||
sender: Sender<Vec<Event>>,
|
||||
},
|
||||
Destroy {
|
||||
id: usize,
|
||||
},
|
||||
}
|
||||
|
||||
async fn stream<T: AsyncRead + Send + Unpin + 'static>(
|
||||
mut read: T,
|
||||
device_sender: Sender<DeviceEvent>,
|
||||
) -> Result<(), Error> {
|
||||
let device_info = DeviceInfo::decode(&mut read)
|
||||
.await
|
||||
.map_err(NetworkError::from)?;
|
||||
|
||||
let id = device_info.id;
|
||||
let span = tracing::info_span!("device", id = ?device_info.id);
|
||||
|
||||
let mut writer = build(device_info).await.map_err(Error::Input)?;
|
||||
|
||||
let (datagram_sender, mut datagram_receiver) = mpsc::channel(1);
|
||||
|
||||
let event = DeviceEvent::Create {
|
||||
id,
|
||||
sender: datagram_sender,
|
||||
};
|
||||
|
||||
if device_sender.send(event).await.is_err() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let (event_sender, mut event_receiver) = mpsc::channel(1);
|
||||
tokio::spawn(
|
||||
async move {
|
||||
loop {
|
||||
let event = tokio::select! {
|
||||
event = Event::decode(&mut read) => event,
|
||||
_ = event_sender.closed() => break,
|
||||
};
|
||||
|
||||
if event.is_err() | event_sender.send(event).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
.instrument(span.clone()),
|
||||
);
|
||||
|
||||
let result = async move {
|
||||
loop {
|
||||
let event = async { event_receiver.recv().await.unwrap() };
|
||||
|
||||
tokio::select! {
|
||||
event = event => {
|
||||
let event = event.map_err(NetworkError::from)?;
|
||||
writer.write(&event).await.map_err(Error::Input)?;
|
||||
}
|
||||
datagram = datagram_receiver.recv() => {
|
||||
let datagram = match datagram {
|
||||
Some(datagram) => datagram,
|
||||
None => break,
|
||||
};
|
||||
|
||||
for event in datagram {
|
||||
writer.write(&event).await.map_err(Error::Input)?;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
.instrument(span)
|
||||
.await;
|
||||
|
||||
let _ = device_sender.send(DeviceEvent::Destroy { id }).await;
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
async fn build(device_info: DeviceInfo) -> Result<Writer, io::Error> {
|
||||
let writer = Writer::builder()?
|
||||
.name(&device_info.name)
|
||||
.vendor(device_info.vendor)
|
||||
.product(device_info.product)
|
||||
.version(device_info.version)
|
||||
.rel(device_info.rel)?
|
||||
.abs(device_info.abs)?
|
||||
.key(device_info.keys)?
|
||||
.delay(device_info.delay)?
|
||||
.period(device_info.period)?
|
||||
.build()
|
||||
.await?;
|
||||
|
||||
tracing::info!(
|
||||
name = ?device_info.name,
|
||||
vendor = %device_info.vendor,
|
||||
product = %device_info.product,
|
||||
version = %device_info.version,
|
||||
"Created new device"
|
||||
);
|
||||
|
||||
Ok(writer)
|
||||
}
|
||||
|
|
|
@ -1,10 +1,7 @@
|
|||
use serde::de::{self, Visitor};
|
||||
use serde::{Deserialize, Deserializer};
|
||||
use std::fmt::{self, Formatter};
|
||||
use std::net::SocketAddr;
|
||||
use std::path::PathBuf;
|
||||
use std::str::FromStr;
|
||||
use tokio_rustls::rustls::ServerName;
|
||||
|
||||
#[derive(Deserialize)]
|
||||
#[serde(rename_all = "kebab-case")]
|
||||
|
@ -15,7 +12,7 @@ pub struct Config {
|
|||
}
|
||||
|
||||
pub struct Server {
|
||||
pub hostname: ServerName,
|
||||
pub hostname: String,
|
||||
pub port: u16,
|
||||
}
|
||||
|
||||
|
@ -41,19 +38,11 @@ impl<'de> Visitor<'de> for ServerVisitor {
|
|||
where
|
||||
E: de::Error,
|
||||
{
|
||||
// Parsing IPv6 socket addresses can get quite hairy, so let the SocketAddr parser do it for us.
|
||||
if let Ok(socket_addr) = SocketAddr::from_str(data) {
|
||||
return Ok(Server {
|
||||
hostname: ServerName::IpAddress(socket_addr.ip()),
|
||||
port: socket_addr.port(),
|
||||
});
|
||||
}
|
||||
|
||||
let (hostname, port) = data
|
||||
.split_once(':')
|
||||
.rsplit_once(':')
|
||||
.ok_or_else(|| E::custom("No port provided"))?;
|
||||
|
||||
let hostname = hostname.try_into().map_err(E::custom)?;
|
||||
let hostname = hostname.to_owned();
|
||||
let port = port.parse().map_err(E::custom)?;
|
||||
|
||||
Ok(Server { hostname, port })
|
||||
|
@ -62,8 +51,6 @@ impl<'de> Visitor<'de> for ServerVisitor {
|
|||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::net::Ipv6Addr;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[derive(Deserialize)]
|
||||
|
@ -91,7 +78,7 @@ mod tests {
|
|||
.unwrap()
|
||||
.server;
|
||||
let expected = Server {
|
||||
hostname: "127.0.0.1".try_into().unwrap(),
|
||||
hostname: "127.0.0.1".to_owned(),
|
||||
port: 8523,
|
||||
};
|
||||
|
||||
|
@ -105,19 +92,12 @@ mod tests {
|
|||
.unwrap()
|
||||
.server;
|
||||
let expected = Server {
|
||||
hostname: "::1".try_into().unwrap(),
|
||||
hostname: "[::1]".to_owned(),
|
||||
port: 8523,
|
||||
};
|
||||
|
||||
assert_eq!(parsed.hostname, expected.hostname);
|
||||
assert_eq!(parsed.port, expected.port);
|
||||
|
||||
let parsed_ip = match parsed.hostname {
|
||||
ServerName::IpAddress(parsed_ip) => parsed_ip,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
assert_eq!(parsed_ip, Ipv6Addr::from_str("::1").unwrap());
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
|
@ -48,8 +48,8 @@ async fn main() -> ExitCode {
|
|||
}
|
||||
};
|
||||
|
||||
let connector = match tls::configure(&config.certificate).await {
|
||||
Ok(connector) => connector,
|
||||
let client_config = match tls::configure(&config.certificate).await {
|
||||
Ok(client_config) => client_config,
|
||||
Err(err) => {
|
||||
tracing::error!("Error configuring TLS: {}", err);
|
||||
return ExitCode::FAILURE;
|
||||
|
@ -57,7 +57,7 @@ async fn main() -> ExitCode {
|
|||
};
|
||||
|
||||
tokio::select! {
|
||||
result = client::run(&config.server.hostname, config.server.port, connector, &config.password) => {
|
||||
result = client::run(&config.server.hostname, config.server.port, client_config, &config.password) => {
|
||||
if let Err(err) = result {
|
||||
tracing::error!("Error: {}", err);
|
||||
return ExitCode::FAILURE;
|
||||
|
|
|
@ -1,10 +1,12 @@
|
|||
use quinn::crypto::rustls::{NoInitialCipherSuite, QuicClientConfig};
|
||||
use quinn::rustls;
|
||||
use quinn::rustls::RootCertStore;
|
||||
use quinn::ClientConfig;
|
||||
use std::io;
|
||||
use std::path::Path;
|
||||
use std::sync::Arc;
|
||||
use thiserror::Error;
|
||||
use tokio::fs;
|
||||
use tokio_rustls::rustls::{self, Certificate, ClientConfig, RootCertStore};
|
||||
use tokio_rustls::TlsConnector;
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
pub enum Error {
|
||||
|
@ -12,23 +14,26 @@ pub enum Error {
|
|||
Rustls(#[from] rustls::Error),
|
||||
#[error(transparent)]
|
||||
Io(#[from] io::Error),
|
||||
#[error(transparent)]
|
||||
NoInitialCipherSuite(#[from] NoInitialCipherSuite),
|
||||
}
|
||||
|
||||
pub async fn configure(certificate: &Path) -> Result<TlsConnector, Error> {
|
||||
pub async fn configure(certificate: &Path) -> Result<ClientConfig, Error> {
|
||||
let certificate = fs::read(certificate).await?;
|
||||
let certificates = rustls_pemfile::certs(&mut certificate.as_slice())?;
|
||||
let certificate = rustls_pemfile::certs(&mut &*certificate).collect::<Result<Vec<_>, _>>()?;
|
||||
|
||||
let mut store = RootCertStore::empty();
|
||||
for certificate in certificates {
|
||||
store.add(&Certificate(certificate))?;
|
||||
for certificate in certificate {
|
||||
store.add(certificate)?;
|
||||
}
|
||||
|
||||
let config = Arc::new(
|
||||
ClientConfig::builder()
|
||||
.with_safe_defaults()
|
||||
.with_root_certificates(store)
|
||||
.with_no_client_auth(),
|
||||
);
|
||||
let config = rustls::ClientConfig::builder()
|
||||
.with_root_certificates(store)
|
||||
.with_no_client_auth();
|
||||
|
||||
Ok(config.into())
|
||||
let config = QuicClientConfig::try_from(config)?;
|
||||
let config = Arc::new(config);
|
||||
let config = ClientConfig::new(config);
|
||||
|
||||
Ok(config)
|
||||
}
|
||||
|
|
|
@ -5,7 +5,7 @@ use crate::sync::SyncEvent;
|
|||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[derive(Clone, Copy, Debug, Serialize, Deserialize)]
|
||||
pub enum Event {
|
||||
Rel(RelEvent),
|
||||
Abs(AbsEvent),
|
||||
|
|
|
@ -16,3 +16,4 @@ hmac = "0.12.1"
|
|||
sha2 = "0.10.6"
|
||||
rand = "0.8.5"
|
||||
tracing = "0.1.37"
|
||||
quinn = "0.11.2"
|
||||
|
|
|
@ -5,75 +5,42 @@ pub mod auth;
|
|||
pub mod message;
|
||||
pub mod version;
|
||||
|
||||
use quinn::TransportConfig;
|
||||
use rkvm_input::abs::{AbsAxis, AbsInfo};
|
||||
use rkvm_input::event::Event;
|
||||
use rkvm_input::key::Key;
|
||||
use rkvm_input::rel::RelAxis;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::borrow::Cow;
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::ffi::CString;
|
||||
use std::future::Future;
|
||||
use std::io::{Error, ErrorKind};
|
||||
use std::time::Duration;
|
||||
use tokio::time;
|
||||
|
||||
pub const PING_INTERVAL: Duration = Duration::from_secs(1);
|
||||
|
||||
// Message read timeout (does not apply to updates, only auth negotiation and replies).
|
||||
pub const READ_TIMEOUT: Duration = Duration::from_millis(500);
|
||||
|
||||
// Message write timeout (applies to all messages).
|
||||
pub const WRITE_TIMEOUT: Duration = Duration::from_millis(500);
|
||||
|
||||
// TLS negotiation timeout.
|
||||
pub const TLS_TIMEOUT: Duration = Duration::from_millis(500);
|
||||
|
||||
#[derive(Deserialize, Serialize, Debug)]
|
||||
pub enum Update {
|
||||
CreateDevice {
|
||||
id: usize,
|
||||
name: CString,
|
||||
vendor: u16,
|
||||
product: u16,
|
||||
version: u16,
|
||||
rel: HashSet<RelAxis>,
|
||||
abs: HashMap<AbsAxis, AbsInfo>,
|
||||
keys: HashSet<Key>,
|
||||
delay: Option<i32>,
|
||||
period: Option<i32>,
|
||||
},
|
||||
DestroyDevice {
|
||||
id: usize,
|
||||
},
|
||||
Event {
|
||||
id: usize,
|
||||
event: Event,
|
||||
},
|
||||
Ping,
|
||||
pub struct DeviceInfo {
|
||||
// ID generated by rkvm-server, sent for debugging purposes.
|
||||
pub id: usize,
|
||||
pub name: CString,
|
||||
pub vendor: u16,
|
||||
pub product: u16,
|
||||
pub version: u16,
|
||||
pub rel: HashSet<RelAxis>,
|
||||
pub abs: HashMap<AbsAxis, AbsInfo>,
|
||||
pub keys: HashSet<Key>,
|
||||
pub delay: Option<i32>,
|
||||
pub period: Option<i32>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Serialize, Debug)]
|
||||
pub struct Pong;
|
||||
|
||||
pub async fn timeout<T: Future<Output = Result<U, Error>>, U>(
|
||||
duration: Duration,
|
||||
future: T,
|
||||
) -> Result<U, Error> {
|
||||
time::timeout(duration, future)
|
||||
.await
|
||||
.map_err(|_| Error::new(ErrorKind::TimedOut, "Message timeout"))?
|
||||
pub struct Datagram<'a> {
|
||||
pub id: usize,
|
||||
pub events: Cow<'a, [Event]>,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::message::Message;
|
||||
use super::*;
|
||||
pub fn transport_config() -> TransportConfig {
|
||||
let mut transport_config = TransportConfig::default();
|
||||
transport_config.keep_alive_interval(Some(Duration::from_millis(500)));
|
||||
transport_config.max_idle_timeout(Some(Duration::from_secs(1).try_into().unwrap()));
|
||||
|
||||
#[tokio::test]
|
||||
async fn pong_is_not_empty() {
|
||||
let mut data = Vec::new();
|
||||
Pong.encode(&mut data).await.unwrap();
|
||||
|
||||
assert!(!data.is_empty());
|
||||
}
|
||||
transport_config
|
||||
}
|
||||
|
|
|
@ -13,8 +13,7 @@ serde = { version = "1.0.117", features = ["derive"] }
|
|||
toml = "0.5.7"
|
||||
env_logger = "0.8.1"
|
||||
clap = { version = "4.2.2", features = ["derive"] }
|
||||
tokio-rustls = "0.24.0"
|
||||
rustls-pemfile = "1.0.2"
|
||||
rustls-pemfile = "2.1.2"
|
||||
thiserror = "1.0.40"
|
||||
slab = "0.4.8"
|
||||
rand = "0.8.5"
|
||||
|
@ -22,6 +21,7 @@ tracing-subscriber = { version = "0.3.17", features = ["env-filter"] }
|
|||
tracing = "0.1.37"
|
||||
rkvm-net = { path = "../rkvm-net" }
|
||||
rkvm-input = { path = "../rkvm-input" }
|
||||
quinn = "0.11.2"
|
||||
|
||||
[package.metadata.rpm]
|
||||
package = "rkvm-server"
|
||||
|
|
|
@ -52,8 +52,8 @@ async fn main() -> ExitCode {
|
|||
}
|
||||
};
|
||||
|
||||
let acceptor = match tls::configure(&config.certificate, &config.key).await {
|
||||
Ok(acceptor) => acceptor,
|
||||
let server_config = match tls::configure(&config.certificate, &config.key).await {
|
||||
Ok(server_config) => server_config,
|
||||
Err(err) => {
|
||||
tracing::error!("Error configuring TLS: {}", err);
|
||||
return ExitCode::FAILURE;
|
||||
|
@ -71,7 +71,7 @@ async fn main() -> ExitCode {
|
|||
let propagate_switch_keys = config.propagate_switch_keys.unwrap_or(true);
|
||||
|
||||
tokio::select! {
|
||||
result = server::run(config.listen, acceptor, &config.password, &switch_keys, propagate_switch_keys) => {
|
||||
result = server::run(config.listen, server_config, &config.password, &switch_keys, propagate_switch_keys) => {
|
||||
if let Err(err) = result {
|
||||
tracing::error!("Error: {}", err);
|
||||
return ExitCode::FAILURE;
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
use quinn::{ConnectionError, Endpoint, Incoming, SendDatagramError, ServerConfig};
|
||||
use rkvm_input::abs::{AbsAxis, AbsInfo};
|
||||
use rkvm_input::event::Event;
|
||||
use rkvm_input::key::{Key, KeyEvent};
|
||||
|
@ -7,20 +8,17 @@ use rkvm_input::sync::SyncEvent;
|
|||
use rkvm_net::auth::{AuthChallenge, AuthResponse, AuthStatus};
|
||||
use rkvm_net::message::Message;
|
||||
use rkvm_net::version::Version;
|
||||
use rkvm_net::{Pong, Update};
|
||||
use rkvm_net::{Datagram, DeviceInfo};
|
||||
use slab::Slab;
|
||||
use std::collections::{HashMap, HashSet, VecDeque};
|
||||
use std::ffi::CString;
|
||||
use std::io::{self, ErrorKind};
|
||||
use std::iter;
|
||||
use std::net::SocketAddr;
|
||||
use std::time::Instant;
|
||||
use thiserror::Error;
|
||||
use tokio::io::{AsyncWriteExt, BufStream};
|
||||
use tokio::net::{TcpListener, TcpStream};
|
||||
use tokio::io::{AsyncWrite, AsyncWriteExt, BufReader, BufWriter};
|
||||
use tokio::sync::mpsc::error::TrySendError;
|
||||
use tokio::sync::mpsc::{self, Receiver, Sender};
|
||||
use tokio::time;
|
||||
use tokio_rustls::TlsAcceptor;
|
||||
use tracing::Instrument;
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
|
@ -35,14 +33,35 @@ pub enum Error {
|
|||
|
||||
pub async fn run(
|
||||
listen: SocketAddr,
|
||||
acceptor: TlsAcceptor,
|
||||
mut config: ServerConfig,
|
||||
password: &str,
|
||||
switch_keys: &HashSet<Key>,
|
||||
propagate_switch_keys: bool,
|
||||
) -> Result<(), Error> {
|
||||
let listener = TcpListener::bind(&listen).await.map_err(Error::Network)?;
|
||||
config.transport_config(rkvm_net::transport_config().into());
|
||||
|
||||
let endpoint = Endpoint::server(config, listen).map_err(Error::Network)?;
|
||||
tracing::info!("Listening on {}", listen);
|
||||
|
||||
let (connection_sender, mut connection_receiver) = mpsc::channel(1);
|
||||
tokio::spawn(async move {
|
||||
loop {
|
||||
tokio::select! {
|
||||
connection = endpoint.accept() => {
|
||||
let connection = match connection {
|
||||
Some(connection) => connection,
|
||||
None => break,
|
||||
};
|
||||
|
||||
if connection_sender.send(connection).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
_ = connection_sender.closed() => break,
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let mut monitor = Monitor::new();
|
||||
let mut devices = Slab::<Device>::new();
|
||||
let mut clients = Slab::<(Sender<_>, SocketAddr)>::new();
|
||||
|
@ -52,14 +71,17 @@ pub async fn run(
|
|||
let mut pressed_keys = HashSet::new();
|
||||
|
||||
let (events_sender, mut events_receiver) = mpsc::channel(1);
|
||||
|
||||
loop {
|
||||
let event = async { events_receiver.recv().await.unwrap() };
|
||||
|
||||
tokio::select! {
|
||||
result = listener.accept() => {
|
||||
let (stream, addr) = result.map_err(Error::Network)?;
|
||||
let acceptor = acceptor.clone();
|
||||
connection = connection_receiver.recv() => {
|
||||
let connection = match connection {
|
||||
Some(connection) => connection,
|
||||
None => break,
|
||||
};
|
||||
|
||||
let addr = connection.remote_address();
|
||||
let password = password.to_owned();
|
||||
|
||||
// Remove dead clients.
|
||||
|
@ -87,12 +109,13 @@ pub async fn run(
|
|||
let (sender, receiver) = mpsc::channel(1);
|
||||
clients.insert((sender, addr));
|
||||
|
||||
let span = tracing::info_span!("connection", addr = %addr);
|
||||
let span = tracing::info_span!("client", addr = %addr);
|
||||
|
||||
tokio::spawn(
|
||||
async move {
|
||||
tracing::info!("Connected");
|
||||
|
||||
match client(init_updates, receiver, stream, acceptor, &password).await {
|
||||
match client(init_updates, receiver, connection, &password).await {
|
||||
Ok(()) => tracing::info!("Disconnected"),
|
||||
Err(err) => tracing::error!("Disconnected: {}", err),
|
||||
}
|
||||
|
@ -100,8 +123,8 @@ pub async fn run(
|
|||
.instrument(span),
|
||||
);
|
||||
}
|
||||
result = monitor.read() => {
|
||||
let mut interceptor = result.map_err(Error::Input)?;
|
||||
interceptor = monitor.read() => {
|
||||
let mut interceptor = interceptor.map_err(Error::Input)?;
|
||||
|
||||
let name = interceptor.name().to_owned();
|
||||
let id = devices.vacant_key();
|
||||
|
@ -275,8 +298,30 @@ pub async fn run(
|
|||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
enum Update {
|
||||
CreateDevice {
|
||||
id: usize,
|
||||
name: CString,
|
||||
vendor: u16,
|
||||
product: u16,
|
||||
version: u16,
|
||||
rel: HashSet<RelAxis>,
|
||||
abs: HashMap<AbsAxis, AbsInfo>,
|
||||
keys: HashSet<Key>,
|
||||
delay: Option<i32>,
|
||||
period: Option<i32>,
|
||||
},
|
||||
DestroyDevice {
|
||||
id: usize,
|
||||
},
|
||||
Event {
|
||||
id: usize,
|
||||
event: Event,
|
||||
},
|
||||
}
|
||||
struct Device {
|
||||
name: CString,
|
||||
vendor: u16,
|
||||
|
@ -300,29 +345,27 @@ enum ClientError {
|
|||
Auth,
|
||||
#[error(transparent)]
|
||||
Rand(#[from] rand::Error),
|
||||
#[error(transparent)]
|
||||
Connection(#[from] ConnectionError),
|
||||
}
|
||||
|
||||
async fn client(
|
||||
mut init_updates: VecDeque<Update>,
|
||||
mut receiver: Receiver<Update>,
|
||||
stream: TcpStream,
|
||||
acceptor: TlsAcceptor,
|
||||
connection: Incoming,
|
||||
password: &str,
|
||||
) -> Result<(), ClientError> {
|
||||
let stream = rkvm_net::timeout(rkvm_net::TLS_TIMEOUT, acceptor.accept(stream)).await?;
|
||||
tracing::info!("TLS connected");
|
||||
let connection = connection.await?;
|
||||
|
||||
let mut stream = BufStream::with_capacity(1024, 1024, stream);
|
||||
let (data_write, data_read) = connection.open_bi().await?;
|
||||
|
||||
rkvm_net::timeout(rkvm_net::WRITE_TIMEOUT, async {
|
||||
Version::CURRENT.encode(&mut stream).await?;
|
||||
stream.flush().await?;
|
||||
let mut data_write = BufWriter::new(data_write);
|
||||
let mut data_read = BufReader::new(data_read);
|
||||
|
||||
Ok(())
|
||||
})
|
||||
.await?;
|
||||
Version::CURRENT.encode(&mut data_write).await?;
|
||||
data_write.flush().await?;
|
||||
|
||||
let version = rkvm_net::timeout(rkvm_net::READ_TIMEOUT, Version::decode(&mut stream)).await?;
|
||||
let version = Version::decode(&mut data_read).await?;
|
||||
if version != Version::CURRENT {
|
||||
return Err(ClientError::Version {
|
||||
server: Version::CURRENT,
|
||||
|
@ -332,28 +375,18 @@ async fn client(
|
|||
|
||||
let challenge = AuthChallenge::generate().await?;
|
||||
|
||||
rkvm_net::timeout(rkvm_net::WRITE_TIMEOUT, async {
|
||||
challenge.encode(&mut stream).await?;
|
||||
stream.flush().await?;
|
||||
challenge.encode(&mut data_write).await?;
|
||||
data_write.flush().await?;
|
||||
|
||||
Ok(())
|
||||
})
|
||||
.await?;
|
||||
let response = AuthResponse::decode(&mut data_read).await?;
|
||||
|
||||
let response =
|
||||
rkvm_net::timeout(rkvm_net::READ_TIMEOUT, AuthResponse::decode(&mut stream)).await?;
|
||||
let status = match response.verify(&challenge, password) {
|
||||
true => AuthStatus::Passed,
|
||||
false => AuthStatus::Failed,
|
||||
};
|
||||
|
||||
rkvm_net::timeout(rkvm_net::WRITE_TIMEOUT, async {
|
||||
status.encode(&mut stream).await?;
|
||||
stream.flush().await?;
|
||||
|
||||
Ok(())
|
||||
})
|
||||
.await?;
|
||||
status.encode(&mut data_write).await?;
|
||||
data_write.flush().await?;
|
||||
|
||||
if status == AuthStatus::Failed {
|
||||
return Err(ClientError::Auth);
|
||||
|
@ -361,10 +394,16 @@ async fn client(
|
|||
|
||||
tracing::info!("Authenticated successfully");
|
||||
|
||||
let mut interval = time::interval(rkvm_net::PING_INTERVAL);
|
||||
let mut senders = HashMap::new();
|
||||
let (error_sender, mut error_receiver) = mpsc::channel(1);
|
||||
|
||||
data_write.shutdown().await?;
|
||||
|
||||
let mut enable_datagrams = true;
|
||||
let mut datagram_events = Vec::new();
|
||||
|
||||
loop {
|
||||
let recv = async {
|
||||
let update = async {
|
||||
match init_updates.pop_front() {
|
||||
Some(update) => Some(update),
|
||||
None => receiver.recv().await,
|
||||
|
@ -372,12 +411,9 @@ async fn client(
|
|||
};
|
||||
|
||||
let update = tokio::select! {
|
||||
// Make sure pings have priority.
|
||||
// The client could time out otherwise.
|
||||
biased;
|
||||
|
||||
_ = interval.tick() => Some(Update::Ping),
|
||||
recv = recv => recv,
|
||||
update = update => update,
|
||||
err = connection.closed() => return Err(err.into()),
|
||||
err = error_receiver.recv() => return Err(err.unwrap()),
|
||||
};
|
||||
|
||||
let update = match update {
|
||||
|
@ -385,29 +421,161 @@ async fn client(
|
|||
None => break,
|
||||
};
|
||||
|
||||
let start = Instant::now();
|
||||
rkvm_net::timeout(rkvm_net::WRITE_TIMEOUT, async {
|
||||
update.encode(&mut stream).await?;
|
||||
stream.flush().await?;
|
||||
match update {
|
||||
Update::CreateDevice {
|
||||
id,
|
||||
name,
|
||||
vendor,
|
||||
product,
|
||||
version,
|
||||
rel,
|
||||
abs,
|
||||
keys,
|
||||
delay,
|
||||
period,
|
||||
} => {
|
||||
let write = connection.open_uni().await?;
|
||||
let stream_id = write.id();
|
||||
|
||||
Ok(())
|
||||
})
|
||||
.await?;
|
||||
let duration = start.elapsed();
|
||||
let mut write = BufWriter::new(write);
|
||||
|
||||
if let Update::Ping = update {
|
||||
// Keeping these as debug because it's not as frequent as other updates.
|
||||
tracing::debug!(duration = ?duration, "Sent ping");
|
||||
let device_info = DeviceInfo {
|
||||
id,
|
||||
name,
|
||||
vendor,
|
||||
product,
|
||||
version,
|
||||
rel,
|
||||
abs,
|
||||
keys,
|
||||
delay,
|
||||
period,
|
||||
};
|
||||
|
||||
let start = Instant::now();
|
||||
rkvm_net::timeout(rkvm_net::READ_TIMEOUT, Pong::decode(&mut stream)).await?;
|
||||
let duration = start.elapsed();
|
||||
let (device_sender, stream_receiver) = mpsc::channel(1);
|
||||
senders.insert(id, device_sender);
|
||||
|
||||
tracing::debug!(duration = ?duration, "Received pong");
|
||||
let error_sender = error_sender.clone();
|
||||
let span = tracing::debug_span!("stream", id = %stream_id);
|
||||
|
||||
tokio::spawn(
|
||||
async move {
|
||||
tracing::debug!("Stream connected");
|
||||
|
||||
match stream(&mut write, &device_info, stream_receiver).await {
|
||||
Ok(()) => tracing::debug!("Stream disconnected"),
|
||||
Err(err) => {
|
||||
tracing::debug!("Stream disconnected: {}", err);
|
||||
let _ = error_sender.send(err).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
.instrument(span),
|
||||
);
|
||||
}
|
||||
Update::DestroyDevice { id } => {
|
||||
senders.remove(&id).unwrap();
|
||||
}
|
||||
Update::Event { id, event } => match event {
|
||||
Event::Rel(_) if enable_datagrams => {
|
||||
datagram_events.push(event);
|
||||
}
|
||||
Event::Sync(SyncEvent::All) if enable_datagrams && !datagram_events.is_empty() => {
|
||||
datagram_events.push(event);
|
||||
|
||||
let mut message = Vec::new();
|
||||
Datagram {
|
||||
id,
|
||||
events: datagram_events.as_slice().into(),
|
||||
}
|
||||
.encode(&mut message)
|
||||
.await?;
|
||||
|
||||
let length = message.len();
|
||||
|
||||
let err = match connection.send_datagram(message.into()) {
|
||||
Ok(()) => {
|
||||
tracing::trace!(
|
||||
"Wrote {} unreliable event{}",
|
||||
datagram_events.len(),
|
||||
if datagram_events.len() == 1 { "" } else { "s" }
|
||||
);
|
||||
|
||||
datagram_events.clear();
|
||||
continue;
|
||||
}
|
||||
Err(err) => err,
|
||||
};
|
||||
|
||||
match err {
|
||||
SendDatagramError::UnsupportedByPeer
|
||||
| SendDatagramError::Disabled
|
||||
| SendDatagramError::TooLarge => {
|
||||
let sender = &senders[&id];
|
||||
for event in datagram_events.drain(..) {
|
||||
let _ = sender.send(event).await;
|
||||
}
|
||||
|
||||
if matches!(err, SendDatagramError::TooLarge) {
|
||||
tracing::warn!(length = %length, "Datagram too large");
|
||||
} else {
|
||||
tracing::warn!("Disabling datagram support: {}", err);
|
||||
enable_datagrams = false;
|
||||
}
|
||||
}
|
||||
SendDatagramError::ConnectionLost(err) => return Err(err.into()),
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
// Send only consecutive relative events as datagrams.
|
||||
let sender = &senders[&id];
|
||||
for event in datagram_events.drain(..).chain(iter::once(event)) {
|
||||
let _ = sender.send(event).await;
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
tracing::trace!("Wrote an update");
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn stream<T: AsyncWrite + Send + Unpin>(
|
||||
write: &mut T,
|
||||
device_info: &DeviceInfo,
|
||||
mut receiver: Receiver<Event>,
|
||||
) -> Result<(), ClientError> {
|
||||
let span = tracing::info_span!("device", id = device_info.id);
|
||||
|
||||
async {
|
||||
device_info.encode(write).await?;
|
||||
write.flush().await?;
|
||||
|
||||
let mut events = 0usize;
|
||||
|
||||
while let Some(event) = receiver.recv().await {
|
||||
event.encode(write).await?;
|
||||
events += 1;
|
||||
|
||||
// Coalesce multiple events into a single QUIC packet.
|
||||
// The `Interceptor` won't emit them until it receives a sync event anyway.
|
||||
if let Event::Sync(_) = event {
|
||||
write.flush().await?;
|
||||
|
||||
tracing::trace!(
|
||||
"Wrote {} event{}",
|
||||
events,
|
||||
if events == 1 { "" } else { "s" }
|
||||
);
|
||||
|
||||
events = 0;
|
||||
}
|
||||
}
|
||||
|
||||
write.shutdown().await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
.instrument(span)
|
||||
.await
|
||||
}
|
||||
|
|
|
@ -1,11 +1,10 @@
|
|||
use rustls_pemfile::Item;
|
||||
use quinn::crypto::rustls::{NoInitialCipherSuite, QuicServerConfig};
|
||||
use quinn::{rustls, ServerConfig};
|
||||
use std::io;
|
||||
use std::path::Path;
|
||||
use std::sync::Arc;
|
||||
use std::{io, iter};
|
||||
use thiserror::Error;
|
||||
use tokio::fs;
|
||||
use tokio_rustls::rustls::{self, Certificate, PrivateKey, ServerConfig};
|
||||
use tokio_rustls::TlsAcceptor;
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
pub enum Error {
|
||||
|
@ -13,70 +12,26 @@ pub enum Error {
|
|||
Rustls(#[from] rustls::Error),
|
||||
#[error(transparent)]
|
||||
Io(#[from] io::Error),
|
||||
#[error("Multiple private keys provided")]
|
||||
MultipleKeys,
|
||||
#[error("No suitable private keys provided")]
|
||||
#[error("No private key provided")]
|
||||
NoKeys,
|
||||
#[error(transparent)]
|
||||
NoInitialCipherSuite(#[from] NoInitialCipherSuite),
|
||||
}
|
||||
|
||||
pub async fn configure(certificate: &Path, key: &Path) -> Result<TlsAcceptor, Error> {
|
||||
enum LoadedItem {
|
||||
Certificate(Vec<u8>),
|
||||
Key(Vec<u8>),
|
||||
}
|
||||
pub async fn configure(certificate: &Path, key: &Path) -> Result<ServerConfig, Error> {
|
||||
let certificate = fs::read(certificate).await?;
|
||||
let certificate = rustls_pemfile::certs(&mut &*certificate).collect::<Result<_, _>>()?;
|
||||
|
||||
let certificate = fs::read_to_string(certificate).await?;
|
||||
let key = fs::read_to_string(key).await?;
|
||||
let key = fs::read(key).await?;
|
||||
let key = rustls_pemfile::private_key(&mut &*key)?.ok_or(Error::NoKeys)?;
|
||||
|
||||
let certificates_iter = iter::from_fn({
|
||||
let mut buffer = certificate.as_bytes();
|
||||
|
||||
move || rustls_pemfile::read_one(&mut buffer).transpose()
|
||||
})
|
||||
.filter_map(|item| match item {
|
||||
Ok(Item::X509Certificate(data)) => Some(Ok(LoadedItem::Certificate(data))),
|
||||
Err(err) => Some(Err(err)),
|
||||
_ => None,
|
||||
});
|
||||
|
||||
let keys_iter = iter::from_fn({
|
||||
let mut buffer = key.as_bytes();
|
||||
|
||||
move || rustls_pemfile::read_one(&mut buffer).transpose()
|
||||
})
|
||||
.filter_map(|item| match item {
|
||||
Ok(Item::RSAKey(data)) | Ok(Item::PKCS8Key(data)) | Ok(Item::ECKey(data)) => {
|
||||
Some(Ok(LoadedItem::Key(data)))
|
||||
}
|
||||
Err(err) => Some(Err(err)),
|
||||
_ => None,
|
||||
});
|
||||
|
||||
let mut certificates = Vec::new();
|
||||
let mut key = None;
|
||||
|
||||
for item in certificates_iter.chain(keys_iter) {
|
||||
let item = item?;
|
||||
|
||||
match item {
|
||||
LoadedItem::Certificate(data) => certificates.push(Certificate(data)),
|
||||
LoadedItem::Key(data) => {
|
||||
if key.is_some() {
|
||||
return Err(Error::MultipleKeys);
|
||||
}
|
||||
|
||||
key = Some(PrivateKey(data));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let key = key.ok_or(Error::NoKeys)?;
|
||||
|
||||
ServerConfig::builder()
|
||||
.with_safe_defaults()
|
||||
let config = rustls::ServerConfig::builder()
|
||||
.with_no_client_auth()
|
||||
.with_single_cert(certificates, key)
|
||||
.map(Arc::new)
|
||||
.map(Into::into)
|
||||
.map_err(Into::into)
|
||||
.with_single_cert(certificate, key)?;
|
||||
|
||||
let config = QuicServerConfig::try_from(config)?;
|
||||
let config = Arc::new(config);
|
||||
let config = ServerConfig::with_crypto(config);
|
||||
|
||||
Ok(config)
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue