mirror of
https://github.com/htrefil/rkvm.git
synced 2024-11-16 07:47:24 +01:00
Implement application layer timeouts
This commit is contained in:
parent
65083f47c1
commit
9560c761e5
3 changed files with 180 additions and 73 deletions
|
@ -2,13 +2,15 @@ use rkvm_input::writer::Writer;
|
|||
use rkvm_net::auth::{AuthChallenge, AuthStatus};
|
||||
use rkvm_net::message::Message;
|
||||
use rkvm_net::version::Version;
|
||||
use rkvm_net::Update;
|
||||
use rkvm_net::{Pong, Update};
|
||||
use std::collections::hash_map::Entry;
|
||||
use std::collections::HashMap;
|
||||
use std::io;
|
||||
use std::time::Instant;
|
||||
use thiserror::Error;
|
||||
use tokio::io::{AsyncWriteExt, BufStream};
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::time;
|
||||
use tokio_rustls::rustls::ServerName;
|
||||
use tokio_rustls::TlsConnector;
|
||||
|
||||
|
@ -30,20 +32,20 @@ pub async fn run(
|
|||
connector: TlsConnector,
|
||||
password: &str,
|
||||
) -> Result<(), Error> {
|
||||
// Intentionally don't impose any timeout for TCP connect.
|
||||
let stream = match hostname {
|
||||
ServerName::DnsName(name) => TcpStream::connect(&(name.as_ref(), port))
|
||||
.await
|
||||
.map_err(Error::Network)?,
|
||||
ServerName::IpAddress(address) => TcpStream::connect(&(*address, port))
|
||||
.await
|
||||
.map_err(Error::Network)?,
|
||||
ServerName::DnsName(name) => TcpStream::connect(&(name.as_ref(), port)).await,
|
||||
ServerName::IpAddress(address) => TcpStream::connect(&(*address, port)).await,
|
||||
_ => unimplemented!("Unhandled rustls ServerName variant: {:?}", hostname),
|
||||
};
|
||||
}
|
||||
.map_err(Error::Network)?;
|
||||
|
||||
tracing::info!("Connected to server");
|
||||
|
||||
let stream = connector
|
||||
.connect(hostname.clone(), stream)
|
||||
let stream = rkvm_net::timeout(
|
||||
rkvm_net::TLS_TIMEOUT,
|
||||
connector.connect(hostname.clone(), stream),
|
||||
)
|
||||
.await
|
||||
.map_err(Error::Network)?;
|
||||
|
||||
|
@ -51,13 +53,19 @@ pub async fn run(
|
|||
|
||||
let mut stream = BufStream::with_capacity(1024, 1024, stream);
|
||||
|
||||
Version::CURRENT
|
||||
.encode(&mut stream)
|
||||
rkvm_net::timeout(rkvm_net::WRITE_TIMEOUT, async {
|
||||
Version::CURRENT.encode(&mut stream).await?;
|
||||
stream.flush().await?;
|
||||
|
||||
Ok(())
|
||||
})
|
||||
.await
|
||||
.map_err(Error::Network)?;
|
||||
|
||||
let version = rkvm_net::timeout(rkvm_net::READ_TIMEOUT, Version::decode(&mut stream))
|
||||
.await
|
||||
.map_err(Error::Network)?;
|
||||
stream.flush().await.map_err(Error::Network)?;
|
||||
|
||||
let version = Version::decode(&mut stream).await.map_err(Error::Network)?;
|
||||
if version != Version::CURRENT {
|
||||
return Err(Error::Version {
|
||||
server: Version::CURRENT,
|
||||
|
@ -65,25 +73,47 @@ pub async fn run(
|
|||
});
|
||||
}
|
||||
|
||||
let challenge = AuthChallenge::decode(&mut stream)
|
||||
let challenge = rkvm_net::timeout(rkvm_net::READ_TIMEOUT, AuthChallenge::decode(&mut stream))
|
||||
.await
|
||||
.map_err(Error::Network)?;
|
||||
|
||||
let response = challenge.respond(password);
|
||||
|
||||
response.encode(&mut stream).await.map_err(Error::Network)?;
|
||||
stream.flush().await.map_err(Error::Network)?;
|
||||
rkvm_net::timeout(rkvm_net::WRITE_TIMEOUT, async {
|
||||
response.encode(&mut stream).await?;
|
||||
stream.flush().await?;
|
||||
|
||||
match Message::decode(&mut stream).await.map_err(Error::Network)? {
|
||||
Ok(())
|
||||
})
|
||||
.await
|
||||
.map_err(Error::Network)?;
|
||||
|
||||
let status = rkvm_net::timeout(rkvm_net::READ_TIMEOUT, AuthStatus::decode(&mut stream))
|
||||
.await
|
||||
.map_err(Error::Network)?;
|
||||
|
||||
match status {
|
||||
AuthStatus::Passed => {}
|
||||
AuthStatus::Failed => return Err(Error::Auth),
|
||||
}
|
||||
|
||||
tracing::info!("Authenticated successfully");
|
||||
|
||||
let mut start = Instant::now();
|
||||
|
||||
let mut interval = time::interval(rkvm_net::PING_INTERVAL + rkvm_net::READ_TIMEOUT);
|
||||
let mut writers = HashMap::new();
|
||||
|
||||
// Interval ticks immediately after creation.
|
||||
interval.tick().await;
|
||||
|
||||
loop {
|
||||
match Update::decode(&mut stream).await.map_err(Error::Network)? {
|
||||
let update = tokio::select! {
|
||||
update = Update::decode(&mut stream) => update.map_err(Error::Network)?,
|
||||
_ = interval.tick() => return Err(Error::Network(io::Error::new(io::ErrorKind::TimedOut, "Ping timed out"))),
|
||||
};
|
||||
|
||||
match update {
|
||||
Update::CreateDevice {
|
||||
id,
|
||||
name,
|
||||
|
@ -150,6 +180,25 @@ pub async fn run(
|
|||
|
||||
tracing::trace!("Wrote an event to device {}", id);
|
||||
}
|
||||
Update::Ping => {
|
||||
let duration = start.elapsed();
|
||||
tracing::debug!(duration = ?duration, "Received ping");
|
||||
|
||||
start = Instant::now();
|
||||
interval.reset();
|
||||
|
||||
rkvm_net::timeout(rkvm_net::WRITE_TIMEOUT, async {
|
||||
Pong.encode(&mut stream).await?;
|
||||
stream.flush().await?;
|
||||
|
||||
Ok(())
|
||||
})
|
||||
.await
|
||||
.map_err(Error::Network)?;
|
||||
|
||||
let duration = start.elapsed();
|
||||
tracing::debug!(duration = ?duration, "Sent pong");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -9,6 +9,21 @@ use rkvm_input::rel::RelAxis;
|
|||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::ffi::CString;
|
||||
use std::future::Future;
|
||||
use std::io::{Error, ErrorKind};
|
||||
use std::time::Duration;
|
||||
use tokio::time;
|
||||
|
||||
pub const PING_INTERVAL: Duration = Duration::from_secs(1);
|
||||
|
||||
// Message read timeout (does not apply to updates, only auth negotiation and replies).
|
||||
pub const READ_TIMEOUT: Duration = Duration::from_millis(500);
|
||||
|
||||
// Message write timeout (applies to all messages).
|
||||
pub const WRITE_TIMEOUT: Duration = Duration::from_millis(500);
|
||||
|
||||
// TLS negotiation timeout.
|
||||
pub const TLS_TIMEOUT: Duration = Duration::from_millis(500);
|
||||
|
||||
#[derive(Deserialize, Serialize, Debug)]
|
||||
pub enum Update {
|
||||
|
@ -29,4 +44,31 @@ pub enum Update {
|
|||
id: usize,
|
||||
event: Event,
|
||||
},
|
||||
Ping,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Serialize, Debug)]
|
||||
pub struct Pong;
|
||||
|
||||
pub async fn timeout<T: Future<Output = Result<U, Error>>, U>(
|
||||
duration: Duration,
|
||||
future: T,
|
||||
) -> Result<U, Error> {
|
||||
time::timeout(duration, future)
|
||||
.await
|
||||
.map_err(|err| Error::new(ErrorKind::TimedOut, err))?
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::message::Message;
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn pong_is_not_empty() {
|
||||
let mut data = Vec::new();
|
||||
Pong.encode(&mut data).await.unwrap();
|
||||
|
||||
assert!(!data.is_empty());
|
||||
}
|
||||
}
|
||||
|
|
|
@ -7,13 +7,13 @@ use rkvm_input::sync::SyncEvent;
|
|||
use rkvm_net::auth::{AuthChallenge, AuthResponse, AuthStatus};
|
||||
use rkvm_net::message::Message;
|
||||
use rkvm_net::version::Version;
|
||||
use rkvm_net::Update;
|
||||
use rkvm_net::{Pong, Update};
|
||||
use slab::Slab;
|
||||
use std::collections::{HashMap, HashSet, VecDeque};
|
||||
use std::ffi::CString;
|
||||
use std::io::{self, ErrorKind};
|
||||
use std::net::SocketAddr;
|
||||
use std::time::Duration;
|
||||
use std::time::Instant;
|
||||
use thiserror::Error;
|
||||
use tokio::io::{AsyncWriteExt, BufStream};
|
||||
use tokio::net::{TcpListener, TcpStream};
|
||||
|
@ -291,16 +291,20 @@ async fn client(
|
|||
acceptor: TlsAcceptor,
|
||||
password: &str,
|
||||
) -> Result<(), ClientError> {
|
||||
let negotiate = async {
|
||||
let stream = acceptor.accept(stream).await?;
|
||||
let stream = rkvm_net::timeout(rkvm_net::TLS_TIMEOUT, acceptor.accept(stream)).await?;
|
||||
tracing::info!("TLS connected");
|
||||
|
||||
let mut stream = BufStream::with_capacity(1024, 1024, stream);
|
||||
|
||||
rkvm_net::timeout(rkvm_net::WRITE_TIMEOUT, async {
|
||||
Version::CURRENT.encode(&mut stream).await?;
|
||||
stream.flush().await?;
|
||||
|
||||
let version = Version::decode(&mut stream).await?;
|
||||
Ok(())
|
||||
})
|
||||
.await?;
|
||||
|
||||
let version = rkvm_net::timeout(rkvm_net::READ_TIMEOUT, Version::decode(&mut stream)).await?;
|
||||
if version != Version::CURRENT {
|
||||
return Err(ClientError::Version {
|
||||
server: Version::CURRENT,
|
||||
|
@ -310,69 +314,81 @@ async fn client(
|
|||
|
||||
let challenge = AuthChallenge::generate().await?;
|
||||
|
||||
rkvm_net::timeout(rkvm_net::WRITE_TIMEOUT, async {
|
||||
challenge.encode(&mut stream).await?;
|
||||
stream.flush().await?;
|
||||
|
||||
let response = AuthResponse::decode(&mut stream).await?;
|
||||
Ok(())
|
||||
})
|
||||
.await?;
|
||||
|
||||
let response =
|
||||
rkvm_net::timeout(rkvm_net::READ_TIMEOUT, AuthResponse::decode(&mut stream)).await?;
|
||||
let status = match response.verify(&challenge, password) {
|
||||
true => AuthStatus::Passed,
|
||||
false => AuthStatus::Failed,
|
||||
};
|
||||
|
||||
rkvm_net::timeout(rkvm_net::WRITE_TIMEOUT, async {
|
||||
status.encode(&mut stream).await?;
|
||||
stream.flush().await?;
|
||||
|
||||
Ok(())
|
||||
})
|
||||
.await?;
|
||||
|
||||
if status == AuthStatus::Failed {
|
||||
return Err(ClientError::Auth);
|
||||
}
|
||||
|
||||
tracing::info!("Authenticated successfully");
|
||||
|
||||
Ok(stream)
|
||||
};
|
||||
|
||||
let mut stream = time::timeout(Duration::from_secs(1), negotiate)
|
||||
.await
|
||||
.map_err(|_| io::Error::new(ErrorKind::TimedOut, "Negotiation took too long"))??;
|
||||
let mut interval = time::interval(rkvm_net::PING_INTERVAL);
|
||||
|
||||
loop {
|
||||
let recv = async {
|
||||
match init_updates.pop_front() {
|
||||
Some(update) => Some((update, false)),
|
||||
None => receiver.recv().await.map(|update| (update, true)),
|
||||
Some(update) => Some(update),
|
||||
None => receiver.recv().await,
|
||||
}
|
||||
};
|
||||
|
||||
let (update, more) = match recv.await {
|
||||
Some(um) => um,
|
||||
let update = tokio::select! {
|
||||
// Make sure pings have priority.
|
||||
// The client could time out otherwise.
|
||||
biased;
|
||||
|
||||
_ = interval.tick() => Some(Update::Ping),
|
||||
recv = recv => recv,
|
||||
};
|
||||
|
||||
let update = match update {
|
||||
Some(update) => update,
|
||||
None => break,
|
||||
};
|
||||
|
||||
let mut count = 1;
|
||||
|
||||
let write = async {
|
||||
let start = Instant::now();
|
||||
rkvm_net::timeout(rkvm_net::WRITE_TIMEOUT, async {
|
||||
update.encode(&mut stream).await?;
|
||||
stream.flush().await?;
|
||||
|
||||
// Coalesce multiple consecutive updates into one chunk.
|
||||
if more {
|
||||
while let Ok(update) = receiver.try_recv() {
|
||||
update.encode(&mut stream).await?;
|
||||
count += 1;
|
||||
}
|
||||
Ok(())
|
||||
})
|
||||
.await?;
|
||||
let duration = start.elapsed();
|
||||
|
||||
if let Update::Ping = update {
|
||||
// Keeping these as debug because it's not as frequent as other updates.
|
||||
tracing::debug!(duration = ?duration, "Sent ping");
|
||||
|
||||
let start = Instant::now();
|
||||
rkvm_net::timeout(rkvm_net::READ_TIMEOUT, Pong::decode(&mut stream)).await?;
|
||||
let duration = start.elapsed();
|
||||
|
||||
tracing::debug!(duration = ?duration, "Received pong");
|
||||
}
|
||||
|
||||
stream.flush().await
|
||||
};
|
||||
|
||||
time::timeout(Duration::from_millis(500), write)
|
||||
.await
|
||||
.map_err(|_| io::Error::new(ErrorKind::TimedOut, "Update writing took too long"))??;
|
||||
|
||||
tracing::trace!(
|
||||
"Wrote {} update{}",
|
||||
count,
|
||||
if count == 1 { "" } else { "s" }
|
||||
);
|
||||
tracing::trace!("Wrote an update");
|
||||
}
|
||||
|
||||
Ok(())
|
||||
|
|
Loading…
Reference in a new issue