Refactor client

This commit is contained in:
Jan Trefil 2023-04-16 18:21:13 +02:00
parent b6c7ff6f85
commit b8a403ae7d
7 changed files with 116 additions and 260 deletions

161
Cargo.lock generated
View file

@ -250,22 +250,6 @@ version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "acbf1af155f9b9ef647e42cdc158db4b64a1b61f743629225fde6f3e0be2a7c7"
[[package]]
name = "core-foundation"
version = "0.9.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "194a7a9e6de53fa55116934067c844d9d749312f75c6f6d0980e8c252f8c2146"
dependencies = [
"core-foundation-sys",
"libc",
]
[[package]]
name = "core-foundation-sys"
version = "0.8.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e496a50fda8aacccc86d7529e2c1e0892dbd0f898a6b5645b5561b89c3210efa"
[[package]]
name = "env_logger"
version = "0.7.1"
@ -322,21 +306,6 @@ dependencies = [
"instant",
]
[[package]]
name = "foreign-types"
version = "0.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f6f339eb8adc052cd2ca78910fda869aefa38d22d5cb648e6485e4d3fc06f3b1"
dependencies = [
"foreign-types-shared",
]
[[package]]
name = "foreign-types-shared"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b"
[[package]]
name = "futures"
version = "0.3.28"
@ -610,24 +579,6 @@ dependencies = [
"windows-sys 0.45.0",
]
[[package]]
name = "native-tls"
version = "0.2.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "07226173c32f2926027b63cce4bcd8076c3552846cbe7925f3aaffeac0a3b92e"
dependencies = [
"lazy_static",
"libc",
"log",
"openssl",
"openssl-probe",
"openssl-sys",
"schannel",
"security-framework",
"security-framework-sys",
"tempfile",
]
[[package]]
name = "nom"
version = "5.1.2"
@ -654,50 +605,6 @@ version = "1.17.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b7e5500299e16ebb147ae15a00a942af264cf3688f47923b8fc2cd5858f23ad3"
[[package]]
name = "openssl"
version = "0.10.50"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7e30d8bc91859781f0a943411186324d580f2bbeb71b452fe91ae344806af3f1"
dependencies = [
"bitflags",
"cfg-if 1.0.0",
"foreign-types",
"libc",
"once_cell",
"openssl-macros",
"openssl-sys",
]
[[package]]
name = "openssl-macros"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.15",
]
[[package]]
name = "openssl-probe"
version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf"
[[package]]
name = "openssl-sys"
version = "0.9.85"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0d3d193fb1488ad46ffe3aaabc912cc931d02ee8518fe2959aea8ef52718b0c0"
dependencies = [
"cc",
"libc",
"pkg-config",
"vcpkg",
]
[[package]]
name = "peeking_take_while"
version = "0.1.2"
@ -824,15 +731,16 @@ dependencies = [
name = "rkvm-client"
version = "0.2.0"
dependencies = [
"anyhow",
"clap 4.2.2",
"env_logger 0.8.4",
"log",
"rkvm-input",
"rkvm-net",
"rustls-pemfile",
"serde",
"thiserror",
"tokio",
"tokio-native-tls",
"tokio-rustls",
"toml",
]
@ -929,15 +837,6 @@ dependencies = [
"untrusted",
]
[[package]]
name = "schannel"
version = "0.1.21"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "713cfb06c7059f3588fb8044c0fad1d09e3c01d225e25b9220dbfdcf16dbb1b3"
dependencies = [
"windows-sys 0.42.0",
]
[[package]]
name = "sct"
version = "0.7.0"
@ -948,29 +847,6 @@ dependencies = [
"untrusted",
]
[[package]]
name = "security-framework"
version = "2.8.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a332be01508d814fed64bf28f798a146d73792121129962fdf335bb3c49a4254"
dependencies = [
"bitflags",
"core-foundation",
"core-foundation-sys",
"libc",
"security-framework-sys",
]
[[package]]
name = "security-framework-sys"
version = "2.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "31c9bb296072e961fcbd8853511dd39c2d8be2deb1e17c6860b1d30732b323b4"
dependencies = [
"core-foundation-sys",
"libc",
]
[[package]]
name = "serde"
version = "1.0.160"
@ -1169,16 +1045,6 @@ dependencies = [
"syn 2.0.15",
]
[[package]]
name = "tokio-native-tls"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bbae76ab933c85776efabc971569dd6119c580d8f5d448769dec1764bf796ef2"
dependencies = [
"native-tls",
"tokio",
]
[[package]]
name = "tokio-rustls"
version = "0.24.0"
@ -1228,12 +1094,6 @@ version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "711b9620af191e0cdc7468a8d14e709c3dcdb115b36f838e601583af800a370a"
[[package]]
name = "vcpkg"
version = "0.2.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426"
[[package]]
name = "vec_map"
version = "0.8.2"
@ -1356,21 +1216,6 @@ version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f"
[[package]]
name = "windows-sys"
version = "0.42.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5a3e1820f08b8513f676f7ab6c1f99ff312fb97b553d30ff4dd86f9f15728aa7"
dependencies = [
"windows_aarch64_gnullvm 0.42.2",
"windows_aarch64_msvc 0.42.2",
"windows_i686_gnu 0.42.2",
"windows_i686_msvc 0.42.2",
"windows_x86_64_gnu 0.42.2",
"windows_x86_64_gnullvm 0.42.2",
"windows_x86_64_msvc 0.42.2",
]
[[package]]
name = "windows-sys"
version = "0.45.0"

View file

@ -3,7 +3,7 @@ name = "rkvm-client"
license = "MIT"
version = "0.2.0"
authors = ["Jan Trefil <8711792+htrefil@users.noreply.github.com>"]
edition = "2018"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
@ -15,9 +15,10 @@ serde = { version = "1.0.117", features = ["derive"] }
toml = "0.5.7"
log = "0.4.11"
env_logger = "0.8.1"
tokio-native-tls = "0.3.0"
anyhow = "1.0.33"
clap = { version = "4.2.2", features = ["derive"] }
thiserror = "1.0.40"
tokio-rustls = "0.24.0"
rustls-pemfile = "1.0.2"
[package.metadata.rpm]
package = "rkvm-client"

View file

@ -2,16 +2,16 @@ use serde::de::{self, Visitor};
use serde::{Deserialize, Deserializer};
use std::fmt::{self, Formatter};
use std::path::PathBuf;
use tokio_rustls::rustls::ServerName;
#[derive(Deserialize)]
#[serde(rename_all = "kebab-case")]
pub struct Config {
pub server: Server,
pub certificate_path: PathBuf,
pub certificate: PathBuf,
}
pub struct Server {
pub hostname: String,
pub hostname: ServerName,
pub port: u16,
}
@ -37,22 +37,13 @@ impl<'de> Visitor<'de> for ServerVisitor {
where
E: de::Error,
{
let err = || E::custom("Invalid server description");
let (hostname, port) = data
.rsplit_once(':')
.ok_or_else(|| E::custom("No port provided"))?;
let mut split = data.split(':');
let hostname = split.next().ok_or_else(err)?;
let port = split
.next()
.and_then(|data| data.parse().ok())
.ok_or_else(err)?;
let hostname = hostname.try_into().map_err(E::custom)?;
let port = port.parse().map_err(E::custom)?;
if split.next().is_some() {
return Err(E::custom("Extraneous data"));
}
Ok(Server {
hostname: hostname.to_owned(),
port,
})
Ok(Server { hostname, port })
}
}

View file

@ -1,65 +1,19 @@
mod config;
mod tls;
use anyhow::{Context, Error};
use config::Config;
use rkvm_input::EventWriter;
use log::LevelFilter;
use rkvm_net::{self, Message, PROTOCOL_VERSION};
use std::convert::Infallible;
use std::path::{Path, PathBuf};
use std::process;
use clap::Parser;
use config::Config;
use log::LevelFilter;
use rkvm_input::EventWriter;
use std::io::{Error, ErrorKind};
use std::path::PathBuf;
use std::process::ExitCode;
use tokio::fs;
use tokio::io::BufReader;
use tokio::io::AsyncWriteExt;
use tokio::net::TcpStream;
use tokio::time;
use tokio_native_tls::native_tls::{Certificate, TlsConnector};
async fn run(server: &str, port: u16, certificate_path: &Path) -> Result<Infallible, Error> {
let certificate = fs::read(certificate_path)
.await
.context("Failed to read certificate")?;
let certificate = Certificate::from_der(&certificate)
.or_else(|_| Certificate::from_pem(&certificate))
.context("Failed to parse certificate")?;
let connector: tokio_native_tls::TlsConnector = TlsConnector::builder()
.add_root_certificate(certificate)
.build()
.context("Failed to create connector")?
.into();
let stream = TcpStream::connect((server, port)).await?;
let stream = BufReader::new(stream);
let mut stream = connector
.connect(server, stream)
.await
.context("Failed to connect")?;
log::info!("Connected to {}:{}", server, port);
rkvm_net::write_version(&mut stream, PROTOCOL_VERSION).await?;
let version = rkvm_net::read_version(&mut stream).await?;
if version != PROTOCOL_VERSION {
return Err(anyhow::anyhow!(
"Incompatible protocol version (got {}, expecting {})",
version,
PROTOCOL_VERSION
));
}
let mut writer = EventWriter::new().await?;
loop {
let message = time::timeout(rkvm_net::MESSAGE_TIMEOUT, rkvm_net::read_message(&mut stream))
.await
.context("Read timed out")??;
match message {
Message::Event(event) => writer.write(event).await?,
Message::KeepAlive => {}
}
}
}
use tokio::signal;
use tokio_rustls::rustls::ServerName;
use tokio_rustls::TlsConnector;
#[derive(Parser)]
#[structopt(name = "rkvm-client", about = "The rkvm client application")]
@ -69,43 +23,83 @@ struct Args {
}
#[tokio::main]
async fn main() {
async fn main() -> ExitCode {
env_logger::builder()
.format_timestamp(None)
.filter(None, LevelFilter::Info)
.parse_default_env()
.init();
let args = Args::parse();
let config = match fs::read_to_string(&args.config_path).await {
Ok(config) => config,
Err(err) => {
log::error!("Error loading config: {}", err);
process::exit(1);
log::error!("Error reading config: {}", err);
return ExitCode::FAILURE;
}
};
let config: Config = match toml::from_str(&config) {
let config = match toml::from_str::<Config>(&config) {
Ok(config) => config,
Err(err) => {
log::error!("Error parsing config: {}", err);
process::exit(1);
return ExitCode::FAILURE;
}
};
let connector = match tls::configure(&config.certificate).await {
Ok(connector) => connector,
Err(err) => {
log::error!("Error configuring TLS: {}", err);
return ExitCode::FAILURE;
}
};
tokio::select! {
result = run(&config.server.hostname, config.server.port, &config.certificate_path) => {
result = run(&config.server.hostname, config.server.port, connector) => {
if let Err(err) = result {
log::error!("Error: {:#}", err);
process::exit(1);
log::error!("Error running client: {}", err);
return ExitCode::FAILURE;
}
}
result = tokio::signal::ctrl_c() => {
// This is needed to properly clean libevent stuff up.
result = signal::ctrl_c() => {
if let Err(err) = result {
log::error!("Error setting up signal handler: {}", err);
process::exit(1);
return ExitCode::FAILURE;
}
log::info!("Exiting on signal");
}
}
ExitCode::SUCCESS
}
async fn run(hostname: &ServerName, port: u16, connector: TlsConnector) -> Result<(), Error> {
let stream = match hostname {
ServerName::DnsName(name) => TcpStream::connect(&(name.as_ref(), port)).await?,
ServerName::IpAddress(address) => TcpStream::connect(&(*address, port)).await?,
_ => unimplemented!("Unhandled rustls ServerName variant"),
};
let mut stream = connector.connect(hostname.clone(), stream).await?;
log::info!("Connected to server");
rkvm_net::write_version(&mut stream, rkvm_net::PROTOCOL_VERSION).await?;
stream.flush().await?;
let version = rkvm_net::read_version(&mut stream).await?;
if version != rkvm_net::PROTOCOL_VERSION {
return Err(Error::new(
ErrorKind::InvalidData,
"Invalid server protocol version",
));
}
let mut writer = EventWriter::new().await?;
loop {
let event = rkvm_net::read_message(&mut stream).await?;
writer.write(event).await?;
}
}

33
rkvm-client/src/tls.rs Normal file
View file

@ -0,0 +1,33 @@
use std::sync::Arc;
use std::{io, path::Path};
use thiserror::Error;
use tokio::fs;
use tokio_rustls::rustls::{self, Certificate, ClientConfig, RootCertStore};
use tokio_rustls::TlsConnector;
#[derive(Error, Debug)]
pub enum Error {
#[error(transparent)]
Rustls(#[from] rustls::Error),
#[error(transparent)]
Io(#[from] io::Error),
}
pub async fn configure(certificate: &Path) -> Result<TlsConnector, Error> {
let certificate = fs::read(certificate).await?;
let certificates = rustls_pemfile::certs(&mut certificate.as_slice())?;
let mut store = RootCertStore::empty();
for certificate in certificates {
store.add(&Certificate(certificate))?;
}
let config = Arc::new(
ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates(store)
.with_no_client_auth(),
);
Ok(config.into())
}

View file

@ -1,8 +1,7 @@
use rkvm_input::Event;
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use std::convert::TryInto;
use std::io::{Error, ErrorKind};
use std::time::Duration;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
// Is it bold to assume there won't be more than 65536 protocol versions?
@ -25,9 +24,10 @@ where
writer.write_all(&version.to_le_bytes()).await
}
pub async fn read_message<R>(mut reader: R) -> Result<Message, Error>
pub async fn read_message<R, T>(mut reader: R) -> Result<T, Error>
where
R: AsyncRead + Unpin,
T: DeserializeOwned,
{
let length = {
let mut bytes = [0; 1];
@ -42,7 +42,7 @@ where
bincode::deserialize(&data).map_err(|err| Error::new(ErrorKind::InvalidData, err))
}
pub async fn write_message<W>(mut writer: W, message: &Message) -> Result<(), Error>
pub async fn write_message<W, T: Serialize>(mut writer: W, message: &T) -> Result<(), Error>
where
W: AsyncWrite + Unpin,
{
@ -57,10 +57,3 @@ where
Ok(())
}
#[derive(Clone, Copy, Debug, Serialize, Deserialize)]
pub enum Message {
Event(Event),
// Sent only to keep the connection alive.
KeepAlive,
}

View file

@ -5,7 +5,6 @@ use clap::Parser;
use config::Config;
use log::LevelFilter;
use rkvm_input::{Direction, Event, EventManager, Key, KeyKind};
use rkvm_net::{self, Message};
use slab::Slab;
use std::io::{Error, ErrorKind};
use std::net::SocketAddr;
@ -65,7 +64,7 @@ async fn main() -> ExitCode {
return ExitCode::FAILURE;
}
}
// This is needed to properly clean evdev stuff up.
// This is needed to properly clean libevent stuff up.
result = signal::ctrl_c() => {
if let Err(err) = result {
log::error!("Error setting up signal handler: {}", err);
@ -130,7 +129,7 @@ async fn run(listen: SocketAddr, acceptor: TlsAcceptor, switch_key: Key) -> Resu
None => break,
};
rkvm_net::write_message(&mut stream, &Message::Event(event)).await?;
rkvm_net::write_message(&mut stream, &event).await?;
stream.flush().await?;
}