Implement client IO timeout, improve logs

This commit is contained in:
Jan Trefil 2023-04-22 14:03:57 +02:00
parent ca3733e357
commit c210f2c85a
2 changed files with 56 additions and 41 deletions

View file

@ -34,15 +34,17 @@ pub async fn run(
ServerName::IpAddress(address) => TcpStream::connect(&(*address, port))
.await
.map_err(Error::Network)?,
_ => unimplemented!("Unhandled rustls ServerName variant"),
_ => unimplemented!("Unhandled rustls ServerName variant: {:?}", hostname),
};
log::info!("Connected to server");
let stream = connector
.connect(hostname.clone(), stream)
.await
.map_err(Error::Network)?;
log::info!("Connected to server");
log::info!("TLS connected");
let mut stream = BufStream::with_capacity(1024, 1024, stream);
@ -73,7 +75,7 @@ pub async fn run(
AuthStatus::Failed => return Err(Error::Auth),
}
log::info!("Passed auth check");
log::info!("Authenticated successfully");
let mut writer = EventWriter::new().await.map_err(Error::Input)?;
loop {

View file

@ -6,11 +6,12 @@ use slab::Slab;
use std::collections::HashSet;
use std::io;
use std::net::SocketAddr;
use std::time::Duration;
use thiserror::Error;
use tokio::io::{AsyncWriteExt, BufStream};
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::mpsc::{self, Receiver, Sender};
use tokio_rustls::server::TlsStream;
use tokio::time;
use tokio_rustls::TlsAcceptor;
#[derive(Error, Debug)]
@ -53,17 +54,9 @@ pub async fn run(
clients.insert(sender);
tokio::spawn(async move {
let stream = match acceptor.accept(stream).await {
Ok(stream) => stream,
Err(err) => {
log::error!("{}: TLS accept error: {}", addr, err);
return;
}
};
log::info!("{}: Connected", addr);
match client(receiver, stream, addr, &password).await {
match client(receiver, stream, addr, acceptor, &password).await {
Ok(()) => log::info!("{}: Disconnected", addr),
Err(err) => log::error!("{}: Disconnected: {}", addr, err),
}
@ -135,46 +128,66 @@ enum ClientError {
async fn client(
mut receiver: Receiver<EventPack>,
stream: TlsStream<TcpStream>,
stream: TcpStream,
addr: SocketAddr,
acceptor: TlsAcceptor,
password: &str,
) -> Result<(), ClientError> {
let mut stream = BufStream::with_capacity(1024, 1024, stream);
let negotiate = async {
let stream = acceptor.accept(stream).await?;
log::info!("{}: TLS connected", addr);
Version::CURRENT.encode(&mut stream).await?;
stream.flush().await?;
let mut stream = BufStream::with_capacity(1024, 1024, stream);
let version = Version::decode(&mut stream).await?;
if version != Version::CURRENT {
return Err(ClientError::Version {
server: Version::CURRENT,
client: version,
});
}
Version::CURRENT.encode(&mut stream).await?;
stream.flush().await?;
let challenge = AuthChallenge::generate().await?;
let version = Version::decode(&mut stream).await?;
if version != Version::CURRENT {
return Err(ClientError::Version {
server: Version::CURRENT,
client: version,
});
}
challenge.encode(&mut stream).await?;
stream.flush().await?;
let challenge = AuthChallenge::generate().await?;
let response = AuthResponse::decode(&mut stream).await?;
let status = match response.verify(&challenge, password) {
true => AuthStatus::Passed,
false => AuthStatus::Failed,
challenge.encode(&mut stream).await?;
stream.flush().await?;
let response = AuthResponse::decode(&mut stream).await?;
let status = match response.verify(&challenge, password) {
true => AuthStatus::Passed,
false => AuthStatus::Failed,
};
status.encode(&mut stream).await?;
stream.flush().await?;
if status == AuthStatus::Failed {
return Err(ClientError::Auth);
}
log::info!("{}: Authenticated successfully", addr);
Ok(stream)
};
status.encode(&mut stream).await?;
stream.flush().await?;
if status == AuthStatus::Failed {
return Err(ClientError::Auth);
}
log::info!("{}: Passed auth check", addr);
let mut stream = time::timeout(Duration::from_secs(1), negotiate)
.await
.map_err(|_| io::Error::new(io::ErrorKind::TimedOut, "Negotiation took too long"))??;
while let Some(events) = receiver.recv().await {
events.encode(&mut stream).await?;
stream.flush().await?;
let write = async {
events.encode(&mut stream).await?;
stream.flush().await
};
time::timeout(Duration::from_millis(500), write)
.await
.map_err(|_| {
io::Error::new(io::ErrorKind::TimedOut, "Event writing took too long")
})??;
log::trace!(
"{}: Sent {} event{}",