Implement application layer timeouts

This commit is contained in:
Jan Trefil 2023-10-11 09:09:18 +02:00
parent 65083f47c1
commit 9560c761e5
3 changed files with 180 additions and 73 deletions

View file

@ -2,13 +2,15 @@ use rkvm_input::writer::Writer;
use rkvm_net::auth::{AuthChallenge, AuthStatus};
use rkvm_net::message::Message;
use rkvm_net::version::Version;
use rkvm_net::Update;
use rkvm_net::{Pong, Update};
use std::collections::hash_map::Entry;
use std::collections::HashMap;
use std::io;
use std::time::Instant;
use thiserror::Error;
use tokio::io::{AsyncWriteExt, BufStream};
use tokio::net::TcpStream;
use tokio::time;
use tokio_rustls::rustls::ServerName;
use tokio_rustls::TlsConnector;
@ -30,34 +32,40 @@ pub async fn run(
connector: TlsConnector,
password: &str,
) -> Result<(), Error> {
// Intentionally don't impose any timeout for TCP connect.
let stream = match hostname {
ServerName::DnsName(name) => TcpStream::connect(&(name.as_ref(), port))
.await
.map_err(Error::Network)?,
ServerName::IpAddress(address) => TcpStream::connect(&(*address, port))
.await
.map_err(Error::Network)?,
ServerName::DnsName(name) => TcpStream::connect(&(name.as_ref(), port)).await,
ServerName::IpAddress(address) => TcpStream::connect(&(*address, port)).await,
_ => unimplemented!("Unhandled rustls ServerName variant: {:?}", hostname),
};
}
.map_err(Error::Network)?;
tracing::info!("Connected to server");
let stream = connector
.connect(hostname.clone(), stream)
.await
.map_err(Error::Network)?;
let stream = rkvm_net::timeout(
rkvm_net::TLS_TIMEOUT,
connector.connect(hostname.clone(), stream),
)
.await
.map_err(Error::Network)?;
tracing::info!("TLS connected");
let mut stream = BufStream::with_capacity(1024, 1024, stream);
Version::CURRENT
.encode(&mut stream)
rkvm_net::timeout(rkvm_net::WRITE_TIMEOUT, async {
Version::CURRENT.encode(&mut stream).await?;
stream.flush().await?;
Ok(())
})
.await
.map_err(Error::Network)?;
let version = rkvm_net::timeout(rkvm_net::READ_TIMEOUT, Version::decode(&mut stream))
.await
.map_err(Error::Network)?;
stream.flush().await.map_err(Error::Network)?;
let version = Version::decode(&mut stream).await.map_err(Error::Network)?;
if version != Version::CURRENT {
return Err(Error::Version {
server: Version::CURRENT,
@ -65,25 +73,47 @@ pub async fn run(
});
}
let challenge = AuthChallenge::decode(&mut stream)
let challenge = rkvm_net::timeout(rkvm_net::READ_TIMEOUT, AuthChallenge::decode(&mut stream))
.await
.map_err(Error::Network)?;
let response = challenge.respond(password);
response.encode(&mut stream).await.map_err(Error::Network)?;
stream.flush().await.map_err(Error::Network)?;
rkvm_net::timeout(rkvm_net::WRITE_TIMEOUT, async {
response.encode(&mut stream).await?;
stream.flush().await?;
match Message::decode(&mut stream).await.map_err(Error::Network)? {
Ok(())
})
.await
.map_err(Error::Network)?;
let status = rkvm_net::timeout(rkvm_net::READ_TIMEOUT, AuthStatus::decode(&mut stream))
.await
.map_err(Error::Network)?;
match status {
AuthStatus::Passed => {}
AuthStatus::Failed => return Err(Error::Auth),
}
tracing::info!("Authenticated successfully");
let mut start = Instant::now();
let mut interval = time::interval(rkvm_net::PING_INTERVAL + rkvm_net::READ_TIMEOUT);
let mut writers = HashMap::new();
// Interval ticks immediately after creation.
interval.tick().await;
loop {
match Update::decode(&mut stream).await.map_err(Error::Network)? {
let update = tokio::select! {
update = Update::decode(&mut stream) => update.map_err(Error::Network)?,
_ = interval.tick() => return Err(Error::Network(io::Error::new(io::ErrorKind::TimedOut, "Ping timed out"))),
};
match update {
Update::CreateDevice {
id,
name,
@ -150,6 +180,25 @@ pub async fn run(
tracing::trace!("Wrote an event to device {}", id);
}
Update::Ping => {
let duration = start.elapsed();
tracing::debug!(duration = ?duration, "Received ping");
start = Instant::now();
interval.reset();
rkvm_net::timeout(rkvm_net::WRITE_TIMEOUT, async {
Pong.encode(&mut stream).await?;
stream.flush().await?;
Ok(())
})
.await
.map_err(Error::Network)?;
let duration = start.elapsed();
tracing::debug!(duration = ?duration, "Sent pong");
}
}
}
}

View file

@ -9,6 +9,21 @@ use rkvm_input::rel::RelAxis;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use std::ffi::CString;
use std::future::Future;
use std::io::{Error, ErrorKind};
use std::time::Duration;
use tokio::time;
pub const PING_INTERVAL: Duration = Duration::from_secs(1);
// Message read timeout (does not apply to updates, only auth negotiation and replies).
pub const READ_TIMEOUT: Duration = Duration::from_millis(500);
// Message write timeout (applies to all messages).
pub const WRITE_TIMEOUT: Duration = Duration::from_millis(500);
// TLS negotiation timeout.
pub const TLS_TIMEOUT: Duration = Duration::from_millis(500);
#[derive(Deserialize, Serialize, Debug)]
pub enum Update {
@ -29,4 +44,31 @@ pub enum Update {
id: usize,
event: Event,
},
Ping,
}
#[derive(Deserialize, Serialize, Debug)]
pub struct Pong;
pub async fn timeout<T: Future<Output = Result<U, Error>>, U>(
duration: Duration,
future: T,
) -> Result<U, Error> {
time::timeout(duration, future)
.await
.map_err(|err| Error::new(ErrorKind::TimedOut, err))?
}
#[cfg(test)]
mod test {
use super::message::Message;
use super::*;
#[tokio::test]
async fn pong_is_not_empty() {
let mut data = Vec::new();
Pong.encode(&mut data).await.unwrap();
assert!(!data.is_empty());
}
}

View file

@ -7,13 +7,13 @@ use rkvm_input::sync::SyncEvent;
use rkvm_net::auth::{AuthChallenge, AuthResponse, AuthStatus};
use rkvm_net::message::Message;
use rkvm_net::version::Version;
use rkvm_net::Update;
use rkvm_net::{Pong, Update};
use slab::Slab;
use std::collections::{HashMap, HashSet, VecDeque};
use std::ffi::CString;
use std::io::{self, ErrorKind};
use std::net::SocketAddr;
use std::time::Duration;
use std::time::Instant;
use thiserror::Error;
use tokio::io::{AsyncWriteExt, BufStream};
use tokio::net::{TcpListener, TcpStream};
@ -291,88 +291,104 @@ async fn client(
acceptor: TlsAcceptor,
password: &str,
) -> Result<(), ClientError> {
let negotiate = async {
let stream = acceptor.accept(stream).await?;
tracing::info!("TLS connected");
let stream = rkvm_net::timeout(rkvm_net::TLS_TIMEOUT, acceptor.accept(stream)).await?;
tracing::info!("TLS connected");
let mut stream = BufStream::with_capacity(1024, 1024, stream);
let mut stream = BufStream::with_capacity(1024, 1024, stream);
rkvm_net::timeout(rkvm_net::WRITE_TIMEOUT, async {
Version::CURRENT.encode(&mut stream).await?;
stream.flush().await?;
let version = Version::decode(&mut stream).await?;
if version != Version::CURRENT {
return Err(ClientError::Version {
server: Version::CURRENT,
client: version,
});
}
Ok(())
})
.await?;
let challenge = AuthChallenge::generate().await?;
let version = rkvm_net::timeout(rkvm_net::READ_TIMEOUT, Version::decode(&mut stream)).await?;
if version != Version::CURRENT {
return Err(ClientError::Version {
server: Version::CURRENT,
client: version,
});
}
let challenge = AuthChallenge::generate().await?;
rkvm_net::timeout(rkvm_net::WRITE_TIMEOUT, async {
challenge.encode(&mut stream).await?;
stream.flush().await?;
let response = AuthResponse::decode(&mut stream).await?;
let status = match response.verify(&challenge, password) {
true => AuthStatus::Passed,
false => AuthStatus::Failed,
};
Ok(())
})
.await?;
let response =
rkvm_net::timeout(rkvm_net::READ_TIMEOUT, AuthResponse::decode(&mut stream)).await?;
let status = match response.verify(&challenge, password) {
true => AuthStatus::Passed,
false => AuthStatus::Failed,
};
rkvm_net::timeout(rkvm_net::WRITE_TIMEOUT, async {
status.encode(&mut stream).await?;
stream.flush().await?;
if status == AuthStatus::Failed {
return Err(ClientError::Auth);
}
Ok(())
})
.await?;
tracing::info!("Authenticated successfully");
if status == AuthStatus::Failed {
return Err(ClientError::Auth);
}
Ok(stream)
};
tracing::info!("Authenticated successfully");
let mut stream = time::timeout(Duration::from_secs(1), negotiate)
.await
.map_err(|_| io::Error::new(ErrorKind::TimedOut, "Negotiation took too long"))??;
let mut interval = time::interval(rkvm_net::PING_INTERVAL);
loop {
let recv = async {
match init_updates.pop_front() {
Some(update) => Some((update, false)),
None => receiver.recv().await.map(|update| (update, true)),
Some(update) => Some(update),
None => receiver.recv().await,
}
};
let (update, more) = match recv.await {
Some(um) => um,
let update = tokio::select! {
// Make sure pings have priority.
// The client could time out otherwise.
biased;
_ = interval.tick() => Some(Update::Ping),
recv = recv => recv,
};
let update = match update {
Some(update) => update,
None => break,
};
let mut count = 1;
let write = async {
let start = Instant::now();
rkvm_net::timeout(rkvm_net::WRITE_TIMEOUT, async {
update.encode(&mut stream).await?;
stream.flush().await?;
// Coalesce multiple consecutive updates into one chunk.
if more {
while let Ok(update) = receiver.try_recv() {
update.encode(&mut stream).await?;
count += 1;
}
}
Ok(())
})
.await?;
let duration = start.elapsed();
stream.flush().await
};
if let Update::Ping = update {
// Keeping these as debug because it's not as frequent as other updates.
tracing::debug!(duration = ?duration, "Sent ping");
time::timeout(Duration::from_millis(500), write)
.await
.map_err(|_| io::Error::new(ErrorKind::TimedOut, "Update writing took too long"))??;
let start = Instant::now();
rkvm_net::timeout(rkvm_net::READ_TIMEOUT, Pong::decode(&mut stream)).await?;
let duration = start.elapsed();
tracing::trace!(
"Wrote {} update{}",
count,
if count == 1 { "" } else { "s" }
);
tracing::debug!(duration = ?duration, "Received pong");
}
tracing::trace!("Wrote an update");
}
Ok(())