From 16f4a72b30b1e2232253fe5616415672e794b21c Mon Sep 17 00:00:00 2001 From: Jan Trefil <8711792+htrefil@users.noreply.github.com> Date: Sat, 13 Jul 2024 19:30:58 +0200 Subject: [PATCH] Make timeouts configurable --- Cargo.lock | 90 +++++++++++++++++++++++++++++++++++- Cargo.toml | 2 +- example/client.toml | 10 ++++ example/server.toml | 10 ++++ rkvm-client/Cargo.toml | 3 +- rkvm-client/src/client.rs | 25 +++++----- rkvm-client/src/config.rs | 3 ++ rkvm-client/src/main.rs | 2 +- rkvm-config/Cargo.toml | 13 ++++++ rkvm-config/src/lib.rs | 97 +++++++++++++++++++++++++++++++++++++++ rkvm-net/src/lib.rs | 9 ---- rkvm-server/Cargo.toml | 2 + rkvm-server/src/config.rs | 3 ++ rkvm-server/src/main.rs | 2 +- rkvm-server/src/server.rs | 22 +++++---- 15 files changed, 255 insertions(+), 38 deletions(-) create mode 100644 rkvm-config/Cargo.toml create mode 100644 rkvm-config/src/lib.rs diff --git a/Cargo.lock b/Cargo.lock index 0e1355d..4c601f4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -299,6 +299,12 @@ dependencies = [ "termcolor", ] +[[package]] +name = "equivalent" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5" + [[package]] name = "errno" version = "0.3.9" @@ -437,6 +443,12 @@ version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b" +[[package]] +name = "hashbrown" +version = "0.14.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" + [[package]] name = "heck" version = "0.5.0" @@ -482,6 +494,16 @@ version = "2.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4" +[[package]] +name = "indexmap" +version = "2.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "168fb715dda47215e360912c096649d23d58bf392ac62f73919e831745e40f26" +dependencies = [ + "equivalent", + "hashbrown", +] + [[package]] name = "inotify" version = "0.10.2" @@ -804,6 +826,7 @@ version = "0.6.0" dependencies = [ "clap", "env_logger", + "rkvm-config", "rkvm-input", "rkvm-net", "rustls-pemfile", @@ -811,11 +834,20 @@ dependencies = [ "thiserror", "tokio", "tokio-rustls", - "toml", + "toml 0.5.11", "tracing", "tracing-subscriber", ] +[[package]] +name = "rkvm-config" +version = "0.1.0" +dependencies = [ + "humantime", + "serde", + "toml 0.8.14", +] + [[package]] name = "rkvm-input" version = "0.1.0" @@ -854,7 +886,9 @@ version = "0.6.0" dependencies = [ "clap", "env_logger", + "humantime", "rand", + "rkvm-config", "rkvm-input", "rkvm-net", "rustls-pemfile", @@ -863,7 +897,7 @@ dependencies = [ "thiserror", "tokio", "tokio-rustls", - "toml", + "toml 0.5.11", "tracing", "tracing-subscriber", ] @@ -954,6 +988,15 @@ dependencies = [ "syn", ] +[[package]] +name = "serde_spanned" +version = "0.6.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "79e674e01f999af37c49f70a6ede167a8a60b2503e56c5599532a65baa5969a0" +dependencies = [ + "serde", +] + [[package]] name = "sha2" version = "0.10.8" @@ -1145,6 +1188,40 @@ dependencies = [ "serde", ] +[[package]] +name = "toml" +version = "0.8.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6f49eb2ab21d2f26bd6db7bf383edc527a7ebaee412d17af4d40fdccd442f335" +dependencies = [ + "serde", + "serde_spanned", + "toml_datetime", + "toml_edit", +] + +[[package]] +name = "toml_datetime" +version = "0.6.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4badfd56924ae69bcc9039335b2e017639ce3f9b001c393c1b2d1ef846ce2cbf" +dependencies = [ + "serde", +] + +[[package]] +name = "toml_edit" +version = "0.22.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d59a3a72298453f564e2b111fa896f8d07fabb36f51f06d7e875fc5e0b5a3ef1" +dependencies = [ + "indexmap", + "serde", + "serde_spanned", + "toml_datetime", + "winnow", +] + [[package]] name = "tracing" version = "0.1.40" @@ -1429,3 +1506,12 @@ name = "windows_x86_64_msvc" version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" + +[[package]] +name = "winnow" +version = "0.6.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "59b5e5f6c299a3c7890b876a2a587f3115162487e704907d9b6cd29473052ba1" +dependencies = [ + "memchr", +] diff --git a/Cargo.toml b/Cargo.toml index 230d58b..f6b91c0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,3 +1,3 @@ [workspace] resolver = "2" -members = ["rkvm-client", "rkvm-server", "rkvm-input", "rkvm-net", "rkvm-certificate-gen"] +members = ["rkvm-client", "rkvm-server", "rkvm-input", "rkvm-net", "rkvm-certificate-gen", "rkvm-config"] diff --git a/example/client.toml b/example/client.toml index b3a880f..9e0a3ae 100644 --- a/example/client.toml +++ b/example/client.toml @@ -6,3 +6,13 @@ certificate = "/etc/rkvm/certificate.pem" # # Change this to your own value before deploying rkvm. password = "123456789" + +# Optional values. +# Make sure these match what you have in your server's config. + +# Message read timeout. +# read-timeout = "500ms" +# Message write timeout. +# write-timeout = "500ms" +# TLS handshake timeout. +# tls-timeout = "500ms" \ No newline at end of file diff --git a/example/server.toml b/example/server.toml index 1c00141..a0a614c 100644 --- a/example/server.toml +++ b/example/server.toml @@ -12,3 +12,13 @@ key = "/etc/rkvm/key.pem" # # Change this to your own value before deploying rkvm. password = "123456789" + +# Optional values. +# Make sure these match what you have in your client's config. + +# Message write timeout. Increase this if you are getting timeout errors. +# read-timeout = "500ms" +# Message write timeout. Increase this if you are getting timeout errors. +# write-timeout = "500ms" +# TLS handshake timeout. +# tls-timeout = "500ms" \ No newline at end of file diff --git a/rkvm-client/Cargo.toml b/rkvm-client/Cargo.toml index 819788a..a05d1b2 100644 --- a/rkvm-client/Cargo.toml +++ b/rkvm-client/Cargo.toml @@ -10,7 +10,6 @@ edition = "2021" [dependencies] tokio = { version = "1.0.1", features = ["macros", "time", "fs", "net", "signal", "rt-multi-thread", "sync"] } rkvm-input = { path = "../rkvm-input" } -rkvm-net = { path = "../rkvm-net" } serde = { version = "1.0.117", features = ["derive"] } toml = "0.5.7" env_logger = "0.8.1" @@ -20,6 +19,8 @@ tokio-rustls = "0.24.0" rustls-pemfile = "1.0.2" tracing = "0.1.37" tracing-subscriber = { version = "0.3.17", features = ["env-filter"] } +rkvm-net = { path = "../rkvm-net" } +rkvm-config = { path = "../rkvm-config" } [package.metadata.rpm] package = "rkvm-client" diff --git a/rkvm-client/src/client.rs b/rkvm-client/src/client.rs index 9cefbf1..06bb80f 100644 --- a/rkvm-client/src/client.rs +++ b/rkvm-client/src/client.rs @@ -1,3 +1,4 @@ +use rkvm_config::Timeout; use rkvm_input::writer::Writer; use rkvm_net::auth::{AuthChallenge, AuthStatus}; use rkvm_net::message::Message; @@ -31,6 +32,7 @@ pub async fn run( port: u16, connector: TlsConnector, password: &str, + timeout: Timeout, ) -> Result<(), Error> { // Intentionally don't impose any timeout for TCP connect. let stream = match hostname { @@ -42,18 +44,15 @@ pub async fn run( tracing::info!("Connected to server"); - let stream = rkvm_net::timeout( - rkvm_net::TLS_TIMEOUT, - connector.connect(hostname.clone(), stream), - ) - .await - .map_err(Error::Network)?; + let stream = rkvm_net::timeout(timeout.tls, connector.connect(hostname.clone(), stream)) + .await + .map_err(Error::Network)?; tracing::info!("TLS connected"); let mut stream = BufStream::with_capacity(1024, 1024, stream); - rkvm_net::timeout(rkvm_net::WRITE_TIMEOUT, async { + rkvm_net::timeout(timeout.write, async { Version::CURRENT.encode(&mut stream).await?; stream.flush().await?; @@ -62,7 +61,7 @@ pub async fn run( .await .map_err(Error::Network)?; - let version = rkvm_net::timeout(rkvm_net::READ_TIMEOUT, Version::decode(&mut stream)) + let version = rkvm_net::timeout(timeout.read, Version::decode(&mut stream)) .await .map_err(Error::Network)?; @@ -73,13 +72,13 @@ pub async fn run( }); } - let challenge = rkvm_net::timeout(rkvm_net::READ_TIMEOUT, AuthChallenge::decode(&mut stream)) + let challenge = rkvm_net::timeout(timeout.read, AuthChallenge::decode(&mut stream)) .await .map_err(Error::Network)?; let response = challenge.respond(password); - rkvm_net::timeout(rkvm_net::WRITE_TIMEOUT, async { + rkvm_net::timeout(timeout.write, async { response.encode(&mut stream).await?; stream.flush().await?; @@ -88,7 +87,7 @@ pub async fn run( .await .map_err(Error::Network)?; - let status = rkvm_net::timeout(rkvm_net::READ_TIMEOUT, AuthStatus::decode(&mut stream)) + let status = rkvm_net::timeout(timeout.read, AuthStatus::decode(&mut stream)) .await .map_err(Error::Network)?; @@ -101,7 +100,7 @@ pub async fn run( let mut start = Instant::now(); - let mut interval = time::interval(rkvm_net::PING_INTERVAL + rkvm_net::READ_TIMEOUT); + let mut interval = time::interval(rkvm_net::PING_INTERVAL + timeout.read); let mut writers = HashMap::new(); // Interval ticks immediately after creation. @@ -191,7 +190,7 @@ pub async fn run( start = Instant::now(); interval.reset(); - rkvm_net::timeout(rkvm_net::WRITE_TIMEOUT, async { + rkvm_net::timeout(timeout.write, async { Pong.encode(&mut stream).await?; stream.flush().await?; diff --git a/rkvm-client/src/config.rs b/rkvm-client/src/config.rs index 8d8145e..6b45291 100644 --- a/rkvm-client/src/config.rs +++ b/rkvm-client/src/config.rs @@ -1,3 +1,4 @@ +use rkvm_config::Timeout; use serde::de::{self, Visitor}; use serde::{Deserialize, Deserializer}; use std::fmt::{self, Formatter}; @@ -12,6 +13,8 @@ pub struct Config { pub server: Server, pub certificate: PathBuf, pub password: String, + #[serde(flatten)] + pub timeout: Timeout, } pub struct Server { diff --git a/rkvm-client/src/main.rs b/rkvm-client/src/main.rs index 5e15b42..853ef64 100644 --- a/rkvm-client/src/main.rs +++ b/rkvm-client/src/main.rs @@ -57,7 +57,7 @@ async fn main() -> ExitCode { }; tokio::select! { - result = client::run(&config.server.hostname, config.server.port, connector, &config.password) => { + result = client::run(&config.server.hostname, config.server.port, connector, &config.password, config.timeout) => { if let Err(err) = result { tracing::error!("Error: {}", err); return ExitCode::FAILURE; diff --git a/rkvm-config/Cargo.toml b/rkvm-config/Cargo.toml new file mode 100644 index 0000000..0f2066a --- /dev/null +++ b/rkvm-config/Cargo.toml @@ -0,0 +1,13 @@ +[package] +name = "rkvm-config" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +humantime = "2.1.0" +serde = { version = "1.0.204", features = ["derive"] } + +[dev-dependencies] +toml = "0.8.14" diff --git a/rkvm-config/src/lib.rs b/rkvm-config/src/lib.rs new file mode 100644 index 0000000..3b30584 --- /dev/null +++ b/rkvm-config/src/lib.rs @@ -0,0 +1,97 @@ +use serde::de::{Deserializer, Error, Visitor}; +use serde::Deserialize; +use std::fmt::{self, Formatter}; +use std::str::FromStr; +use std::time::Duration; + +#[derive(Deserialize, Clone, Copy, PartialEq, Eq, Debug)] +pub struct Timeout { + #[serde( + rename = "read-timeout", + default = "default_timeout", + deserialize_with = "deserialize_duration" + )] + pub read: Duration, + #[serde( + rename = "write-timeout", + default = "default_timeout", + deserialize_with = "deserialize_duration" + )] + pub write: Duration, + #[serde( + rename = "tls-timeout", + default = "default_timeout", + deserialize_with = "deserialize_duration" + )] + pub tls: Duration, +} + +fn default_timeout() -> Duration { + Duration::from_millis(500) +} + +fn deserialize_duration<'de, D>(deserializer: D) -> Result +where + D: Deserializer<'de>, +{ + struct DurationVisitor; + + impl Visitor<'_> for DurationVisitor { + type Value = Duration; + + fn expecting(&self, formatter: &mut Formatter) -> fmt::Result { + write!(formatter, "a duration of time (for example \"500ms\")") + } + + fn visit_str(self, v: &str) -> Result + where + E: Error, + { + humantime::Duration::from_str(v) + .map_err(E::custom) + .map(Into::into) + } + } + + deserializer.deserialize_str(DurationVisitor) +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn timeout_deserialize() { + let parsed = toml::from_str::( + r#" + read-timeout = "1s" + write-timeout = "200ms" + tls-timeout = "500ms" + "#, + ) + .unwrap(); + + assert_eq!( + parsed, + Timeout { + read: Duration::from_secs(1), + write: Duration::from_millis(200), + tls: Duration::from_millis(500), + } + ); + } + + #[test] + fn timeout_missing_values() { + let parsed = toml::from_str::("").unwrap(); + + assert_eq!( + parsed, + Timeout { + read: default_timeout(), + write: default_timeout(), + tls: default_timeout(), + } + ); + } +} diff --git a/rkvm-net/src/lib.rs b/rkvm-net/src/lib.rs index 9425e38..f7332e3 100644 --- a/rkvm-net/src/lib.rs +++ b/rkvm-net/src/lib.rs @@ -19,15 +19,6 @@ use tokio::time; pub const PING_INTERVAL: Duration = Duration::from_secs(1); -// Message read timeout (does not apply to updates, only auth negotiation and replies). -pub const READ_TIMEOUT: Duration = Duration::from_millis(500); - -// Message write timeout (applies to all messages). -pub const WRITE_TIMEOUT: Duration = Duration::from_millis(500); - -// TLS negotiation timeout. -pub const TLS_TIMEOUT: Duration = Duration::from_millis(500); - #[derive(Deserialize, Serialize, Debug)] pub enum Update { CreateDevice { diff --git a/rkvm-server/Cargo.toml b/rkvm-server/Cargo.toml index 0dac7fd..c84c0f0 100644 --- a/rkvm-server/Cargo.toml +++ b/rkvm-server/Cargo.toml @@ -20,8 +20,10 @@ slab = "0.4.8" rand = "0.8.5" tracing-subscriber = { version = "0.3.17", features = ["env-filter"] } tracing = "0.1.37" +humantime = "2.1.0" rkvm-net = { path = "../rkvm-net" } rkvm-input = { path = "../rkvm-input" } +rkvm-config = { path = "../rkvm-config" } [package.metadata.rpm] package = "rkvm-server" diff --git a/rkvm-server/src/config.rs b/rkvm-server/src/config.rs index 98ef27f..6cd483e 100644 --- a/rkvm-server/src/config.rs +++ b/rkvm-server/src/config.rs @@ -1,3 +1,4 @@ +use rkvm_config::Timeout; use rkvm_input::key::{Button, Key, Keyboard}; use serde::Deserialize; use std::collections::HashSet; @@ -13,6 +14,8 @@ pub struct Config { pub password: String, pub switch_keys: HashSet, pub propagate_switch_keys: Option, + #[serde(flatten)] + pub timeout: Timeout, } #[derive(Deserialize, Clone, Copy, PartialEq, Eq, Hash)] diff --git a/rkvm-server/src/main.rs b/rkvm-server/src/main.rs index eb3f402..686d046 100644 --- a/rkvm-server/src/main.rs +++ b/rkvm-server/src/main.rs @@ -71,7 +71,7 @@ async fn main() -> ExitCode { let propagate_switch_keys = config.propagate_switch_keys.unwrap_or(true); tokio::select! { - result = server::run(config.listen, acceptor, &config.password, &switch_keys, propagate_switch_keys) => { + result = server::run(config.listen, acceptor, &config.password, &switch_keys, propagate_switch_keys, config.timeout) => { if let Err(err) = result { tracing::error!("Error: {}", err); return ExitCode::FAILURE; diff --git a/rkvm-server/src/server.rs b/rkvm-server/src/server.rs index b783181..eea4b3c 100644 --- a/rkvm-server/src/server.rs +++ b/rkvm-server/src/server.rs @@ -1,3 +1,4 @@ +use rkvm_config::Timeout; use rkvm_input::abs::{AbsAxis, AbsInfo}; use rkvm_input::event::Event; use rkvm_input::key::{Key, KeyEvent}; @@ -39,6 +40,7 @@ pub async fn run( password: &str, switch_keys: &HashSet, propagate_switch_keys: bool, + timeout: Timeout, ) -> Result<(), Error> { let listener = TcpListener::bind(&listen).await.map_err(Error::Network)?; tracing::info!("Listening on {}", listen); @@ -92,7 +94,7 @@ pub async fn run( async move { tracing::info!("Connected"); - match client(init_updates, receiver, stream, acceptor, &password).await { + match client(init_updates, receiver, stream, acceptor, &password, timeout).await { Ok(()) => tracing::info!("Disconnected"), Err(err) => tracing::error!("Disconnected: {}", err), } @@ -308,13 +310,14 @@ async fn client( stream: TcpStream, acceptor: TlsAcceptor, password: &str, + timeout: Timeout, ) -> Result<(), ClientError> { - let stream = rkvm_net::timeout(rkvm_net::TLS_TIMEOUT, acceptor.accept(stream)).await?; + let stream = rkvm_net::timeout(timeout.tls, acceptor.accept(stream)).await?; tracing::info!("TLS connected"); let mut stream = BufStream::with_capacity(1024, 1024, stream); - rkvm_net::timeout(rkvm_net::WRITE_TIMEOUT, async { + rkvm_net::timeout(timeout.write, async { Version::CURRENT.encode(&mut stream).await?; stream.flush().await?; @@ -322,7 +325,7 @@ async fn client( }) .await?; - let version = rkvm_net::timeout(rkvm_net::READ_TIMEOUT, Version::decode(&mut stream)).await?; + let version = rkvm_net::timeout(timeout.read, Version::decode(&mut stream)).await?; if version != Version::CURRENT { return Err(ClientError::Version { server: Version::CURRENT, @@ -332,7 +335,7 @@ async fn client( let challenge = AuthChallenge::generate().await?; - rkvm_net::timeout(rkvm_net::WRITE_TIMEOUT, async { + rkvm_net::timeout(timeout.write, async { challenge.encode(&mut stream).await?; stream.flush().await?; @@ -340,14 +343,13 @@ async fn client( }) .await?; - let response = - rkvm_net::timeout(rkvm_net::READ_TIMEOUT, AuthResponse::decode(&mut stream)).await?; + let response = rkvm_net::timeout(timeout.read, AuthResponse::decode(&mut stream)).await?; let status = match response.verify(&challenge, password) { true => AuthStatus::Passed, false => AuthStatus::Failed, }; - rkvm_net::timeout(rkvm_net::WRITE_TIMEOUT, async { + rkvm_net::timeout(timeout.write, async { status.encode(&mut stream).await?; stream.flush().await?; @@ -386,7 +388,7 @@ async fn client( }; let start = Instant::now(); - rkvm_net::timeout(rkvm_net::WRITE_TIMEOUT, async { + rkvm_net::timeout(timeout.write, async { update.encode(&mut stream).await?; stream.flush().await?; @@ -400,7 +402,7 @@ async fn client( tracing::debug!(duration = ?duration, "Sent ping"); let start = Instant::now(); - rkvm_net::timeout(rkvm_net::READ_TIMEOUT, Pong::decode(&mut stream)).await?; + rkvm_net::timeout(timeout.read, Pong::decode(&mut stream)).await?; let duration = start.elapsed(); tracing::debug!(duration = ?duration, "Received pong");