mirror of
https://github.com/htrefil/rkvm.git
synced 2024-11-16 07:47:24 +01:00
Implement authentication
This commit is contained in:
parent
e56ea03351
commit
26e9e78271
14 changed files with 441 additions and 154 deletions
132
Cargo.lock
generated
132
Cargo.lock
generated
|
@ -142,6 +142,15 @@ version = "1.3.2"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a"
|
checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "block-buffer"
|
||||||
|
version = "0.10.4"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "3078c7629b62d3f0439517fa394996acacc5cbc91c5a20d8c658e77abd503a71"
|
||||||
|
dependencies = [
|
||||||
|
"generic-array",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "bumpalo"
|
name = "bumpalo"
|
||||||
version = "3.12.0"
|
version = "3.12.0"
|
||||||
|
@ -255,6 +264,36 @@ version = "1.0.0"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "acbf1af155f9b9ef647e42cdc158db4b64a1b61f743629225fde6f3e0be2a7c7"
|
checksum = "acbf1af155f9b9ef647e42cdc158db4b64a1b61f743629225fde6f3e0be2a7c7"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "cpufeatures"
|
||||||
|
version = "0.2.6"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "280a9f2d8b3a38871a3c8a46fb80db65e5e5ed97da80c4d08bf27fb63e35e181"
|
||||||
|
dependencies = [
|
||||||
|
"libc",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "crypto-common"
|
||||||
|
version = "0.1.6"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "1bfb12502f3fc46cca1bb51ac28df9d618d813cdc3d2f25b9fe775a34af26bb3"
|
||||||
|
dependencies = [
|
||||||
|
"generic-array",
|
||||||
|
"typenum",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "digest"
|
||||||
|
version = "0.10.6"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "8168378f4e5023e7218c89c891c0fd8ecdb5e5e4f18cb78f38cf245dd021e76f"
|
||||||
|
dependencies = [
|
||||||
|
"block-buffer",
|
||||||
|
"crypto-common",
|
||||||
|
"subtle",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "env_logger"
|
name = "env_logger"
|
||||||
version = "0.7.1"
|
version = "0.7.1"
|
||||||
|
@ -400,6 +439,27 @@ dependencies = [
|
||||||
"slab",
|
"slab",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "generic-array"
|
||||||
|
version = "0.14.7"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "85649ca51fd72272d7821adaf274ad91c288277713d9c18820d8499a7ff69e9a"
|
||||||
|
dependencies = [
|
||||||
|
"typenum",
|
||||||
|
"version_check",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "getrandom"
|
||||||
|
version = "0.2.9"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "c85e1d9ab2eadba7e5040d4e09cbd6d072b76a557ad64e797c2cb9d4da21d7e4"
|
||||||
|
dependencies = [
|
||||||
|
"cfg-if 1.0.0",
|
||||||
|
"libc",
|
||||||
|
"wasi",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "glob"
|
name = "glob"
|
||||||
version = "0.3.1"
|
version = "0.3.1"
|
||||||
|
@ -436,6 +496,15 @@ version = "0.3.1"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "fed44880c466736ef9a5c5b5facefb5ed0785676d0c02d612db14e54f0d84286"
|
checksum = "fed44880c466736ef9a5c5b5facefb5ed0785676d0c02d612db14e54f0d84286"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "hmac"
|
||||||
|
version = "0.12.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "6c49c37c09c17a53d937dfbb742eb3a961d65a994e6bcdcf37e7399d0cc8ab5e"
|
||||||
|
dependencies = [
|
||||||
|
"digest",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "humantime"
|
name = "humantime"
|
||||||
version = "1.3.0"
|
version = "1.3.0"
|
||||||
|
@ -625,6 +694,12 @@ version = "0.3.26"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "6ac9a59f73473f1b8d852421e59e64809f025994837ef743615c6d0c5b305160"
|
checksum = "6ac9a59f73473f1b8d852421e59e64809f025994837ef743615c6d0c5b305160"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "ppv-lite86"
|
||||||
|
version = "0.2.17"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "proc-macro2"
|
name = "proc-macro2"
|
||||||
version = "1.0.56"
|
version = "1.0.56"
|
||||||
|
@ -649,6 +724,36 @@ dependencies = [
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "rand"
|
||||||
|
version = "0.8.5"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404"
|
||||||
|
dependencies = [
|
||||||
|
"libc",
|
||||||
|
"rand_chacha",
|
||||||
|
"rand_core",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "rand_chacha"
|
||||||
|
version = "0.3.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88"
|
||||||
|
dependencies = [
|
||||||
|
"ppv-lite86",
|
||||||
|
"rand_core",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "rand_core"
|
||||||
|
version = "0.6.4"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c"
|
||||||
|
dependencies = [
|
||||||
|
"getrandom",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "redox_syscall"
|
name = "redox_syscall"
|
||||||
version = "0.3.5"
|
version = "0.3.5"
|
||||||
|
@ -736,8 +841,11 @@ version = "0.2.0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"async-trait",
|
"async-trait",
|
||||||
"bincode",
|
"bincode",
|
||||||
|
"hmac",
|
||||||
|
"rand",
|
||||||
"rkvm-input",
|
"rkvm-input",
|
||||||
"serde",
|
"serde",
|
||||||
|
"sha2",
|
||||||
"thiserror",
|
"thiserror",
|
||||||
"tokio",
|
"tokio",
|
||||||
]
|
]
|
||||||
|
@ -749,6 +857,7 @@ dependencies = [
|
||||||
"clap 4.2.2",
|
"clap 4.2.2",
|
||||||
"env_logger 0.8.4",
|
"env_logger 0.8.4",
|
||||||
"log",
|
"log",
|
||||||
|
"rand",
|
||||||
"rkvm-input",
|
"rkvm-input",
|
||||||
"rkvm-net",
|
"rkvm-net",
|
||||||
"rustls-pemfile",
|
"rustls-pemfile",
|
||||||
|
@ -841,6 +950,17 @@ dependencies = [
|
||||||
"syn 2.0.15",
|
"syn 2.0.15",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "sha2"
|
||||||
|
version = "0.10.6"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "82e6b795fe2e3b1e845bafcb27aa35405c4d47cdfc92af5fc8d3002f76cebdc0"
|
||||||
|
dependencies = [
|
||||||
|
"cfg-if 1.0.0",
|
||||||
|
"cpufeatures",
|
||||||
|
"digest",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "shlex"
|
name = "shlex"
|
||||||
version = "0.1.1"
|
version = "0.1.1"
|
||||||
|
@ -893,6 +1013,12 @@ version = "0.10.0"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623"
|
checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "subtle"
|
||||||
|
version = "2.4.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "6bdef32e8150c2a081110b42772ffe7d7c9032b606bc226c8260fd97e0976601"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "syn"
|
name = "syn"
|
||||||
version = "1.0.109"
|
version = "1.0.109"
|
||||||
|
@ -1014,6 +1140,12 @@ dependencies = [
|
||||||
"serde",
|
"serde",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "typenum"
|
||||||
|
version = "1.16.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "497961ef93d974e23eb6f433eb5fe1b7930b659f06d12dec6fc44a8f554c0bba"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "unicode-ident"
|
name = "unicode-ident"
|
||||||
version = "1.0.8"
|
version = "1.0.8"
|
||||||
|
|
|
@ -1,2 +1,3 @@
|
||||||
server = "localhost:5258"
|
server = "localhost:5258"
|
||||||
certificate = "example/certificate.crt"
|
certificate = "example/certificate.crt"
|
||||||
|
password = "123456789"
|
|
@ -7,6 +7,6 @@ edition = "2021"
|
||||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
clap = "4.2.2"
|
clap = { version = "4.2.2", features = ["derive"] }
|
||||||
tempfile = "3.1.0"
|
tempfile = "3.1.0"
|
||||||
thiserror = "1.0.40"
|
thiserror = "1.0.40"
|
||||||
|
|
86
rkvm-client/src/client.rs
Normal file
86
rkvm-client/src/client.rs
Normal file
|
@ -0,0 +1,86 @@
|
||||||
|
use rkvm_input::EventWriter;
|
||||||
|
use rkvm_net::auth::{AuthChallenge, AuthStatus};
|
||||||
|
use rkvm_net::message::Message;
|
||||||
|
use rkvm_net::version::Version;
|
||||||
|
use std::io;
|
||||||
|
use thiserror::Error;
|
||||||
|
use tokio::io::{AsyncWriteExt, BufStream};
|
||||||
|
use tokio::net::TcpStream;
|
||||||
|
use tokio_rustls::rustls::ServerName;
|
||||||
|
use tokio_rustls::TlsConnector;
|
||||||
|
|
||||||
|
#[derive(Error, Debug)]
|
||||||
|
pub enum Error {
|
||||||
|
#[error("Network error: {0}")]
|
||||||
|
Network(io::Error),
|
||||||
|
#[error("Input error: {0}")]
|
||||||
|
Input(io::Error),
|
||||||
|
#[error("Incompatible server version (got {server}, expected {client})")]
|
||||||
|
Version { server: Version, client: Version },
|
||||||
|
#[error("Auth challenge failed (possibly wrong password)")]
|
||||||
|
Auth,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn run(
|
||||||
|
hostname: &ServerName,
|
||||||
|
port: u16,
|
||||||
|
connector: TlsConnector,
|
||||||
|
password: &str,
|
||||||
|
) -> Result<(), Error> {
|
||||||
|
let stream = match hostname {
|
||||||
|
ServerName::DnsName(name) => TcpStream::connect(&(name.as_ref(), port))
|
||||||
|
.await
|
||||||
|
.map_err(Error::Network)?,
|
||||||
|
ServerName::IpAddress(address) => TcpStream::connect(&(*address, port))
|
||||||
|
.await
|
||||||
|
.map_err(Error::Network)?,
|
||||||
|
_ => unimplemented!("Unhandled rustls ServerName variant"),
|
||||||
|
};
|
||||||
|
|
||||||
|
let stream = connector
|
||||||
|
.connect(hostname.clone(), stream)
|
||||||
|
.await
|
||||||
|
.map_err(Error::Network)?;
|
||||||
|
|
||||||
|
log::info!("Connected to server");
|
||||||
|
|
||||||
|
let mut stream = BufStream::with_capacity(1024, 1024, stream);
|
||||||
|
|
||||||
|
Version::CURRENT
|
||||||
|
.encode(&mut stream)
|
||||||
|
.await
|
||||||
|
.map_err(Error::Network)?;
|
||||||
|
stream.flush().await.map_err(Error::Network)?;
|
||||||
|
|
||||||
|
let version = Version::decode(&mut stream).await.map_err(Error::Network)?;
|
||||||
|
if version != Version::CURRENT {
|
||||||
|
return Err(Error::Version {
|
||||||
|
server: Version::CURRENT,
|
||||||
|
client: version,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
let challenge = AuthChallenge::decode(&mut stream)
|
||||||
|
.await
|
||||||
|
.map_err(Error::Network)?;
|
||||||
|
let response = challenge.respond(password);
|
||||||
|
|
||||||
|
response.encode(&mut stream).await.map_err(Error::Network)?;
|
||||||
|
stream.flush().await.map_err(Error::Network)?;
|
||||||
|
|
||||||
|
match Message::decode(&mut stream).await.map_err(Error::Network)? {
|
||||||
|
AuthStatus::Passed => {}
|
||||||
|
AuthStatus::Failed => return Err(Error::Auth),
|
||||||
|
}
|
||||||
|
|
||||||
|
log::info!("Passed auth check");
|
||||||
|
|
||||||
|
let mut writer = EventWriter::new().await.map_err(Error::Input)?;
|
||||||
|
loop {
|
||||||
|
let event = Message::decode(&mut stream).await.map_err(Error::Network)?;
|
||||||
|
log::trace!("Received event");
|
||||||
|
|
||||||
|
writer.write(event).await.map_err(Error::Input)?;
|
||||||
|
log::trace!("Wrote event");
|
||||||
|
}
|
||||||
|
}
|
|
@ -9,7 +9,9 @@ use tokio_rustls::rustls::ServerName;
|
||||||
pub struct Config {
|
pub struct Config {
|
||||||
pub server: Server,
|
pub server: Server,
|
||||||
pub certificate: PathBuf,
|
pub certificate: PathBuf,
|
||||||
|
pub password: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct Server {
|
pub struct Server {
|
||||||
pub hostname: ServerName,
|
pub hostname: ServerName,
|
||||||
pub port: u16,
|
pub port: u16,
|
||||||
|
|
|
@ -1,19 +1,14 @@
|
||||||
|
mod client;
|
||||||
mod config;
|
mod config;
|
||||||
mod tls;
|
mod tls;
|
||||||
|
|
||||||
use clap::Parser;
|
use clap::Parser;
|
||||||
use config::Config;
|
use config::Config;
|
||||||
use log::LevelFilter;
|
use log::LevelFilter;
|
||||||
use rkvm_input::EventWriter;
|
|
||||||
use rkvm_net::Message;
|
|
||||||
use std::io::Error;
|
|
||||||
use std::path::PathBuf;
|
use std::path::PathBuf;
|
||||||
use std::process::ExitCode;
|
use std::process::ExitCode;
|
||||||
use tokio::fs;
|
use tokio::fs;
|
||||||
use tokio::net::TcpStream;
|
|
||||||
use tokio::signal;
|
use tokio::signal;
|
||||||
use tokio_rustls::rustls::ServerName;
|
|
||||||
use tokio_rustls::TlsConnector;
|
|
||||||
|
|
||||||
#[derive(Parser)]
|
#[derive(Parser)]
|
||||||
#[structopt(name = "rkvm-client", about = "The rkvm client application")]
|
#[structopt(name = "rkvm-client", about = "The rkvm client application")]
|
||||||
|
@ -56,9 +51,9 @@ async fn main() -> ExitCode {
|
||||||
};
|
};
|
||||||
|
|
||||||
tokio::select! {
|
tokio::select! {
|
||||||
result = run(&config.server.hostname, config.server.port, connector) => {
|
result = client::run(&config.server.hostname, config.server.port, connector, &config.password) => {
|
||||||
if let Err(err) = result {
|
if let Err(err) = result {
|
||||||
log::error!("Error running client: {}", err);
|
log::error!("Error: {}", err);
|
||||||
return ExitCode::FAILURE;
|
return ExitCode::FAILURE;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -75,22 +70,3 @@ async fn main() -> ExitCode {
|
||||||
|
|
||||||
ExitCode::SUCCESS
|
ExitCode::SUCCESS
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn run(hostname: &ServerName, port: u16, connector: TlsConnector) -> Result<(), Error> {
|
|
||||||
let stream = match hostname {
|
|
||||||
ServerName::DnsName(name) => TcpStream::connect(&(name.as_ref(), port)).await?,
|
|
||||||
ServerName::IpAddress(address) => TcpStream::connect(&(*address, port)).await?,
|
|
||||||
_ => unimplemented!("Unhandled rustls ServerName variant"),
|
|
||||||
};
|
|
||||||
|
|
||||||
let mut stream = connector.connect(hostname.clone(), stream).await?;
|
|
||||||
log::info!("Connected to server");
|
|
||||||
|
|
||||||
rkvm_net::negotiate(&mut stream).await?;
|
|
||||||
|
|
||||||
let mut writer = EventWriter::new().await?;
|
|
||||||
loop {
|
|
||||||
let event = Message::decode(&mut stream).await?;
|
|
||||||
writer.write(event).await?;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
|
@ -13,3 +13,6 @@ bincode = "1.3.3"
|
||||||
tokio = { version = "1.0.1", features = ["io-util"] }
|
tokio = { version = "1.0.1", features = ["io-util"] }
|
||||||
async-trait = "0.1.68"
|
async-trait = "0.1.68"
|
||||||
thiserror = "1.0.40"
|
thiserror = "1.0.40"
|
||||||
|
hmac = "0.12.1"
|
||||||
|
sha2 = "0.10.6"
|
||||||
|
rand = "0.8.5"
|
||||||
|
|
54
rkvm-net/src/auth.rs
Normal file
54
rkvm-net/src/auth.rs
Normal file
|
@ -0,0 +1,54 @@
|
||||||
|
use hmac::{Hmac, Mac};
|
||||||
|
use rand::rngs::OsRng;
|
||||||
|
use rand::Error;
|
||||||
|
use rand::Rng;
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use sha2::Sha256;
|
||||||
|
use tokio::task;
|
||||||
|
|
||||||
|
type ChallengeHmac = Hmac<Sha256>;
|
||||||
|
|
||||||
|
#[derive(Clone, Copy, Debug, PartialEq, Eq, Deserialize, Serialize)]
|
||||||
|
pub struct AuthChallenge([u8; 32]);
|
||||||
|
|
||||||
|
impl AuthChallenge {
|
||||||
|
pub async fn generate() -> Result<Self, Error> {
|
||||||
|
task::spawn_blocking(|| {
|
||||||
|
let mut data = [0; 32];
|
||||||
|
OsRng.try_fill(&mut data)?;
|
||||||
|
|
||||||
|
Ok(Self(data))
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn respond(&self, password: &str) -> AuthResponse {
|
||||||
|
let mut mac = ChallengeHmac::new_from_slice(password.as_bytes()).unwrap();
|
||||||
|
mac.update(&self.0);
|
||||||
|
|
||||||
|
let result = mac.finalize();
|
||||||
|
let result = result.into_bytes();
|
||||||
|
let result = result[..].try_into().unwrap();
|
||||||
|
|
||||||
|
AuthResponse(result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Copy, Debug, PartialEq, Eq, Deserialize, Serialize)]
|
||||||
|
pub struct AuthResponse([u8; 32]);
|
||||||
|
|
||||||
|
impl AuthResponse {
|
||||||
|
pub fn verify(&self, challenge: &AuthChallenge, password: &str) -> bool {
|
||||||
|
let mut mac = ChallengeHmac::new_from_slice(password.as_bytes()).unwrap();
|
||||||
|
mac.update(&challenge.0);
|
||||||
|
|
||||||
|
mac.verify_slice(&self.0).is_ok()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Copy, Debug, PartialEq, Eq, Deserialize, Serialize)]
|
||||||
|
pub enum AuthStatus {
|
||||||
|
Passed,
|
||||||
|
Failed,
|
||||||
|
}
|
|
@ -1,36 +1,3 @@
|
||||||
mod message;
|
pub mod auth;
|
||||||
mod version;
|
pub mod message;
|
||||||
|
pub mod version;
|
||||||
use std::io::{Error, ErrorKind};
|
|
||||||
use thiserror::Error;
|
|
||||||
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
|
|
||||||
use version::Version;
|
|
||||||
|
|
||||||
pub use message::Message;
|
|
||||||
|
|
||||||
pub async fn negotiate<T: AsyncRead + AsyncWrite + Send + Unpin>(
|
|
||||||
stream: &mut T,
|
|
||||||
) -> Result<(), Error> {
|
|
||||||
#[derive(Error, Debug)]
|
|
||||||
#[error("Invalid version (expected {expected}, got {got} instead)")]
|
|
||||||
struct InvalidVersionError {
|
|
||||||
expected: Version,
|
|
||||||
got: Version,
|
|
||||||
}
|
|
||||||
|
|
||||||
Version::CURRENT.encode(stream).await?;
|
|
||||||
stream.flush().await?;
|
|
||||||
|
|
||||||
let version = Version::decode(stream).await?;
|
|
||||||
if version != Version::CURRENT {
|
|
||||||
return Err(Error::new(
|
|
||||||
ErrorKind::InvalidData,
|
|
||||||
InvalidVersionError {
|
|
||||||
expected: Version::CURRENT,
|
|
||||||
got: version,
|
|
||||||
},
|
|
||||||
));
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
|
@ -33,7 +33,9 @@ impl<T: DeserializeOwned + Serialize + Sync> Message for T {
|
||||||
.len()
|
.len()
|
||||||
.try_into()
|
.try_into()
|
||||||
.map_err(|_| Error::new(ErrorKind::InvalidInput, "Data too large"))?;
|
.map_err(|_| Error::new(ErrorKind::InvalidInput, "Data too large"))?;
|
||||||
|
|
||||||
stream.write_u16(length).await?;
|
stream.write_u16(length).await?;
|
||||||
|
stream.write_all(&data).await?;
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
|
@ -20,6 +20,7 @@ tokio-rustls = "0.24.0"
|
||||||
rustls-pemfile = "1.0.2"
|
rustls-pemfile = "1.0.2"
|
||||||
thiserror = "1.0.40"
|
thiserror = "1.0.40"
|
||||||
slab = "0.4.8"
|
slab = "0.4.8"
|
||||||
|
rand = "0.8.5"
|
||||||
|
|
||||||
[package.metadata.rpm]
|
[package.metadata.rpm]
|
||||||
package = "rkvm-server"
|
package = "rkvm-server"
|
||||||
|
|
|
@ -7,7 +7,8 @@ use std::path::PathBuf;
|
||||||
#[serde(rename_all = "kebab-case")]
|
#[serde(rename_all = "kebab-case")]
|
||||||
pub struct Config {
|
pub struct Config {
|
||||||
pub listen: SocketAddr,
|
pub listen: SocketAddr,
|
||||||
pub switch_key: Key,
|
|
||||||
pub certificate: PathBuf,
|
pub certificate: PathBuf,
|
||||||
pub key: PathBuf,
|
pub key: PathBuf,
|
||||||
|
pub password: String,
|
||||||
|
pub switch_key: Key,
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,22 +1,14 @@
|
||||||
mod config;
|
mod config;
|
||||||
|
mod server;
|
||||||
mod tls;
|
mod tls;
|
||||||
|
|
||||||
use clap::Parser;
|
use clap::Parser;
|
||||||
use config::Config;
|
use config::Config;
|
||||||
use log::LevelFilter;
|
use log::LevelFilter;
|
||||||
use rkvm_input::{Direction, Event, EventManager, Key, KeyKind};
|
|
||||||
use rkvm_net::Message;
|
|
||||||
use slab::Slab;
|
|
||||||
use std::io::Error;
|
|
||||||
use std::net::SocketAddr;
|
|
||||||
use std::path::PathBuf;
|
use std::path::PathBuf;
|
||||||
use std::process::ExitCode;
|
use std::process::ExitCode;
|
||||||
use tokio::fs;
|
use tokio::fs;
|
||||||
use tokio::io::{AsyncWriteExt, BufStream};
|
|
||||||
use tokio::net::TcpListener;
|
|
||||||
use tokio::signal;
|
use tokio::signal;
|
||||||
use tokio::sync::mpsc::{self, Sender};
|
|
||||||
use tokio_rustls::TlsAcceptor;
|
|
||||||
|
|
||||||
#[derive(Parser)]
|
#[derive(Parser)]
|
||||||
#[structopt(name = "rkvm-server", about = "The rkvm server application")]
|
#[structopt(name = "rkvm-server", about = "The rkvm server application")]
|
||||||
|
@ -59,9 +51,9 @@ async fn main() -> ExitCode {
|
||||||
};
|
};
|
||||||
|
|
||||||
tokio::select! {
|
tokio::select! {
|
||||||
result = run(config.listen, acceptor, config.switch_key) => {
|
result = server::run(config.listen, acceptor, &config.password, config.switch_key) => {
|
||||||
if let Err(err) = result {
|
if let Err(err) = result {
|
||||||
log::error!("Error running server: {}", err);
|
log::error!("Error: {}", err);
|
||||||
return ExitCode::FAILURE;
|
return ExitCode::FAILURE;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -78,80 +70,3 @@ async fn main() -> ExitCode {
|
||||||
|
|
||||||
ExitCode::SUCCESS
|
ExitCode::SUCCESS
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn run(listen: SocketAddr, acceptor: TlsAcceptor, switch_key: Key) -> Result<(), Error> {
|
|
||||||
let listener = TcpListener::bind(&listen).await?;
|
|
||||||
log::info!("Listening on {}", listen);
|
|
||||||
|
|
||||||
let mut clients = Slab::<Sender<_>>::new();
|
|
||||||
let mut current = 0;
|
|
||||||
let mut manager = EventManager::new().await?;
|
|
||||||
|
|
||||||
loop {
|
|
||||||
tokio::select! {
|
|
||||||
result = listener.accept() => {
|
|
||||||
// Remove dead clients.
|
|
||||||
clients.retain(|_, client| !client.is_closed());
|
|
||||||
if !clients.contains(current) {
|
|
||||||
current = 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
let (stream, addr) = result?;
|
|
||||||
let acceptor = acceptor.clone();
|
|
||||||
|
|
||||||
let (sender, mut receiver) = mpsc::channel::<Event>(1);
|
|
||||||
clients.insert(sender);
|
|
||||||
|
|
||||||
tokio::spawn(async move {
|
|
||||||
let stream = match acceptor.accept(stream).await {
|
|
||||||
Ok(stream) => stream,
|
|
||||||
Err(err) => {
|
|
||||||
log::error!("{}: TLS accept error: {}", addr, err);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
log::info!("{}: Connected", addr);
|
|
||||||
|
|
||||||
let result = async {
|
|
||||||
let mut stream = BufStream::with_capacity(1024, 1024, stream);
|
|
||||||
|
|
||||||
rkvm_net::negotiate(&mut stream).await?;
|
|
||||||
|
|
||||||
loop {
|
|
||||||
let event = match receiver.recv().await {
|
|
||||||
Some(event) => event,
|
|
||||||
None => break,
|
|
||||||
};
|
|
||||||
|
|
||||||
event.encode(&mut stream).await?;
|
|
||||||
stream.flush().await?;
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok::<_, Error>(())
|
|
||||||
}
|
|
||||||
.await;
|
|
||||||
|
|
||||||
match result {
|
|
||||||
Ok(()) => log::info!("{}: Disconnected", addr),
|
|
||||||
Err(err) => log::error!("{}: Disconnected: {}", addr, err),
|
|
||||||
}
|
|
||||||
});
|
|
||||||
}
|
|
||||||
result = manager.read() => {
|
|
||||||
let event = result?;
|
|
||||||
if let Event::Key { direction: Direction::Down, kind: KeyKind::Key(key) } = event {
|
|
||||||
if key == switch_key {
|
|
||||||
current = (current + 1) % (clients.len() + 1);
|
|
||||||
log::info!("Switching to client {}", current);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if current == 0 || clients[current].send(event).await.is_err() {
|
|
||||||
current = 0;
|
|
||||||
manager.write(event).await?;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
147
rkvm-server/src/server.rs
Normal file
147
rkvm-server/src/server.rs
Normal file
|
@ -0,0 +1,147 @@
|
||||||
|
use rkvm_input::{Direction, Event, EventManager, Key, KeyKind};
|
||||||
|
use rkvm_net::auth::{AuthChallenge, AuthResponse, AuthStatus};
|
||||||
|
use rkvm_net::message::Message;
|
||||||
|
use rkvm_net::version::Version;
|
||||||
|
use slab::Slab;
|
||||||
|
use std::io;
|
||||||
|
use std::net::SocketAddr;
|
||||||
|
use thiserror::Error;
|
||||||
|
use tokio::io::{AsyncWriteExt, BufStream};
|
||||||
|
use tokio::net::{TcpListener, TcpStream};
|
||||||
|
use tokio::sync::mpsc::{self, Receiver, Sender};
|
||||||
|
use tokio_rustls::server::TlsStream;
|
||||||
|
use tokio_rustls::TlsAcceptor;
|
||||||
|
|
||||||
|
#[derive(Error, Debug)]
|
||||||
|
pub enum Error {
|
||||||
|
#[error("Network error: {0}")]
|
||||||
|
Network(io::Error),
|
||||||
|
#[error("Input error: {0}")]
|
||||||
|
Input(io::Error),
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn run(
|
||||||
|
listen: SocketAddr,
|
||||||
|
acceptor: TlsAcceptor,
|
||||||
|
password: &str,
|
||||||
|
switch_key: Key,
|
||||||
|
) -> Result<(), Error> {
|
||||||
|
let listener = TcpListener::bind(&listen).await.map_err(Error::Network)?;
|
||||||
|
log::info!("Listening on {}", listen);
|
||||||
|
|
||||||
|
let mut clients = Slab::<Sender<_>>::new();
|
||||||
|
let mut current = 0;
|
||||||
|
let mut manager = EventManager::new().await.map_err(Error::Input)?;
|
||||||
|
|
||||||
|
loop {
|
||||||
|
tokio::select! {
|
||||||
|
result = listener.accept() => {
|
||||||
|
let (stream, addr) = result.map_err(Error::Network)?;
|
||||||
|
let acceptor = acceptor.clone();
|
||||||
|
let password = password.to_owned();
|
||||||
|
|
||||||
|
// Remove dead clients.
|
||||||
|
clients.retain(|_, client| !client.is_closed());
|
||||||
|
if !clients.contains(current) {
|
||||||
|
current = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
let (sender, receiver) = mpsc::channel(1);
|
||||||
|
clients.insert(sender);
|
||||||
|
|
||||||
|
tokio::spawn(async move {
|
||||||
|
let stream = match acceptor.accept(stream).await {
|
||||||
|
Ok(stream) => stream,
|
||||||
|
Err(err) => {
|
||||||
|
log::error!("{}: TLS accept error: {}", addr, err);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
log::info!("{}: Connected", addr);
|
||||||
|
|
||||||
|
match client(receiver, stream, addr, &password).await {
|
||||||
|
Ok(()) => log::info!("{}: Disconnected", addr),
|
||||||
|
Err(err) => log::error!("{}: Disconnected: {}", addr, err),
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
result = manager.read() => {
|
||||||
|
let event = result.map_err(Error::Input)?;
|
||||||
|
if let Event::Key { direction: Direction::Down, kind: KeyKind::Key(key) } = event {
|
||||||
|
if key == switch_key {
|
||||||
|
current = (current + 1) % (clients.len() + 1);
|
||||||
|
log::info!("Switching to client {}", current);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if current == 0 || clients[current - 1].send(event).await.is_err() {
|
||||||
|
current = 0;
|
||||||
|
manager.write(event).await.map_err(Error::Input)?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Error, Debug)]
|
||||||
|
enum ClientError {
|
||||||
|
#[error(transparent)]
|
||||||
|
Io(#[from] io::Error),
|
||||||
|
#[error("Incompatible client version (got {client}, expected {server})")]
|
||||||
|
Version { server: Version, client: Version },
|
||||||
|
#[error("Auth challenge failed (possibly wrong password)")]
|
||||||
|
Auth,
|
||||||
|
#[error(transparent)]
|
||||||
|
Rand(#[from] rand::Error),
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn client(
|
||||||
|
mut receiver: Receiver<Event>,
|
||||||
|
stream: TlsStream<TcpStream>,
|
||||||
|
addr: SocketAddr,
|
||||||
|
password: &str,
|
||||||
|
) -> Result<(), ClientError> {
|
||||||
|
let mut stream = BufStream::with_capacity(1024, 1024, stream);
|
||||||
|
|
||||||
|
Version::CURRENT.encode(&mut stream).await?;
|
||||||
|
stream.flush().await?;
|
||||||
|
|
||||||
|
let version = Version::decode(&mut stream).await?;
|
||||||
|
if version != Version::CURRENT {
|
||||||
|
return Err(ClientError::Version {
|
||||||
|
server: Version::CURRENT,
|
||||||
|
client: version,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
let challenge = AuthChallenge::generate().await?;
|
||||||
|
|
||||||
|
challenge.encode(&mut stream).await?;
|
||||||
|
stream.flush().await?;
|
||||||
|
|
||||||
|
let response = AuthResponse::decode(&mut stream).await?;
|
||||||
|
let status = match response.verify(&challenge, password) {
|
||||||
|
true => AuthStatus::Passed,
|
||||||
|
false => AuthStatus::Failed,
|
||||||
|
};
|
||||||
|
|
||||||
|
status.encode(&mut stream).await?;
|
||||||
|
stream.flush().await?;
|
||||||
|
|
||||||
|
if status == AuthStatus::Failed {
|
||||||
|
return Err(ClientError::Auth);
|
||||||
|
}
|
||||||
|
|
||||||
|
log::info!("{}: Passed auth check", addr);
|
||||||
|
|
||||||
|
while let Some(event) = receiver.recv().await {
|
||||||
|
event.encode(&mut stream).await?;
|
||||||
|
stream.flush().await?;
|
||||||
|
|
||||||
|
log::trace!("{}: Sent event", addr);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
Loading…
Reference in a new issue