mirror of
https://github.com/htrefil/rkvm.git
synced 2024-12-27 09:58:13 +01:00
Implement authentication
This commit is contained in:
parent
e56ea03351
commit
26e9e78271
14 changed files with 441 additions and 154 deletions
132
Cargo.lock
generated
132
Cargo.lock
generated
|
@ -142,6 +142,15 @@ version = "1.3.2"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a"
|
||||
|
||||
[[package]]
|
||||
name = "block-buffer"
|
||||
version = "0.10.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3078c7629b62d3f0439517fa394996acacc5cbc91c5a20d8c658e77abd503a71"
|
||||
dependencies = [
|
||||
"generic-array",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "bumpalo"
|
||||
version = "3.12.0"
|
||||
|
@ -255,6 +264,36 @@ version = "1.0.0"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "acbf1af155f9b9ef647e42cdc158db4b64a1b61f743629225fde6f3e0be2a7c7"
|
||||
|
||||
[[package]]
|
||||
name = "cpufeatures"
|
||||
version = "0.2.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "280a9f2d8b3a38871a3c8a46fb80db65e5e5ed97da80c4d08bf27fb63e35e181"
|
||||
dependencies = [
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "crypto-common"
|
||||
version = "0.1.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1bfb12502f3fc46cca1bb51ac28df9d618d813cdc3d2f25b9fe775a34af26bb3"
|
||||
dependencies = [
|
||||
"generic-array",
|
||||
"typenum",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "digest"
|
||||
version = "0.10.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8168378f4e5023e7218c89c891c0fd8ecdb5e5e4f18cb78f38cf245dd021e76f"
|
||||
dependencies = [
|
||||
"block-buffer",
|
||||
"crypto-common",
|
||||
"subtle",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "env_logger"
|
||||
version = "0.7.1"
|
||||
|
@ -400,6 +439,27 @@ dependencies = [
|
|||
"slab",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "generic-array"
|
||||
version = "0.14.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "85649ca51fd72272d7821adaf274ad91c288277713d9c18820d8499a7ff69e9a"
|
||||
dependencies = [
|
||||
"typenum",
|
||||
"version_check",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "getrandom"
|
||||
version = "0.2.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c85e1d9ab2eadba7e5040d4e09cbd6d072b76a557ad64e797c2cb9d4da21d7e4"
|
||||
dependencies = [
|
||||
"cfg-if 1.0.0",
|
||||
"libc",
|
||||
"wasi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "glob"
|
||||
version = "0.3.1"
|
||||
|
@ -436,6 +496,15 @@ version = "0.3.1"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "fed44880c466736ef9a5c5b5facefb5ed0785676d0c02d612db14e54f0d84286"
|
||||
|
||||
[[package]]
|
||||
name = "hmac"
|
||||
version = "0.12.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6c49c37c09c17a53d937dfbb742eb3a961d65a994e6bcdcf37e7399d0cc8ab5e"
|
||||
dependencies = [
|
||||
"digest",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "humantime"
|
||||
version = "1.3.0"
|
||||
|
@ -625,6 +694,12 @@ version = "0.3.26"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6ac9a59f73473f1b8d852421e59e64809f025994837ef743615c6d0c5b305160"
|
||||
|
||||
[[package]]
|
||||
name = "ppv-lite86"
|
||||
version = "0.2.17"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de"
|
||||
|
||||
[[package]]
|
||||
name = "proc-macro2"
|
||||
version = "1.0.56"
|
||||
|
@ -649,6 +724,36 @@ dependencies = [
|
|||
"proc-macro2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rand"
|
||||
version = "0.8.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404"
|
||||
dependencies = [
|
||||
"libc",
|
||||
"rand_chacha",
|
||||
"rand_core",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rand_chacha"
|
||||
version = "0.3.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88"
|
||||
dependencies = [
|
||||
"ppv-lite86",
|
||||
"rand_core",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rand_core"
|
||||
version = "0.6.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c"
|
||||
dependencies = [
|
||||
"getrandom",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "redox_syscall"
|
||||
version = "0.3.5"
|
||||
|
@ -736,8 +841,11 @@ version = "0.2.0"
|
|||
dependencies = [
|
||||
"async-trait",
|
||||
"bincode",
|
||||
"hmac",
|
||||
"rand",
|
||||
"rkvm-input",
|
||||
"serde",
|
||||
"sha2",
|
||||
"thiserror",
|
||||
"tokio",
|
||||
]
|
||||
|
@ -749,6 +857,7 @@ dependencies = [
|
|||
"clap 4.2.2",
|
||||
"env_logger 0.8.4",
|
||||
"log",
|
||||
"rand",
|
||||
"rkvm-input",
|
||||
"rkvm-net",
|
||||
"rustls-pemfile",
|
||||
|
@ -841,6 +950,17 @@ dependencies = [
|
|||
"syn 2.0.15",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "sha2"
|
||||
version = "0.10.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "82e6b795fe2e3b1e845bafcb27aa35405c4d47cdfc92af5fc8d3002f76cebdc0"
|
||||
dependencies = [
|
||||
"cfg-if 1.0.0",
|
||||
"cpufeatures",
|
||||
"digest",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "shlex"
|
||||
version = "0.1.1"
|
||||
|
@ -893,6 +1013,12 @@ version = "0.10.0"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623"
|
||||
|
||||
[[package]]
|
||||
name = "subtle"
|
||||
version = "2.4.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6bdef32e8150c2a081110b42772ffe7d7c9032b606bc226c8260fd97e0976601"
|
||||
|
||||
[[package]]
|
||||
name = "syn"
|
||||
version = "1.0.109"
|
||||
|
@ -1014,6 +1140,12 @@ dependencies = [
|
|||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "typenum"
|
||||
version = "1.16.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "497961ef93d974e23eb6f433eb5fe1b7930b659f06d12dec6fc44a8f554c0bba"
|
||||
|
||||
[[package]]
|
||||
name = "unicode-ident"
|
||||
version = "1.0.8"
|
||||
|
|
|
@ -1,2 +1,3 @@
|
|||
server = "localhost:5258"
|
||||
certificate = "example/certificate.crt"
|
||||
password = "123456789"
|
|
@ -7,6 +7,6 @@ edition = "2021"
|
|||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[dependencies]
|
||||
clap = "4.2.2"
|
||||
clap = { version = "4.2.2", features = ["derive"] }
|
||||
tempfile = "3.1.0"
|
||||
thiserror = "1.0.40"
|
||||
|
|
86
rkvm-client/src/client.rs
Normal file
86
rkvm-client/src/client.rs
Normal file
|
@ -0,0 +1,86 @@
|
|||
use rkvm_input::EventWriter;
|
||||
use rkvm_net::auth::{AuthChallenge, AuthStatus};
|
||||
use rkvm_net::message::Message;
|
||||
use rkvm_net::version::Version;
|
||||
use std::io;
|
||||
use thiserror::Error;
|
||||
use tokio::io::{AsyncWriteExt, BufStream};
|
||||
use tokio::net::TcpStream;
|
||||
use tokio_rustls::rustls::ServerName;
|
||||
use tokio_rustls::TlsConnector;
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
pub enum Error {
|
||||
#[error("Network error: {0}")]
|
||||
Network(io::Error),
|
||||
#[error("Input error: {0}")]
|
||||
Input(io::Error),
|
||||
#[error("Incompatible server version (got {server}, expected {client})")]
|
||||
Version { server: Version, client: Version },
|
||||
#[error("Auth challenge failed (possibly wrong password)")]
|
||||
Auth,
|
||||
}
|
||||
|
||||
pub async fn run(
|
||||
hostname: &ServerName,
|
||||
port: u16,
|
||||
connector: TlsConnector,
|
||||
password: &str,
|
||||
) -> Result<(), Error> {
|
||||
let stream = match hostname {
|
||||
ServerName::DnsName(name) => TcpStream::connect(&(name.as_ref(), port))
|
||||
.await
|
||||
.map_err(Error::Network)?,
|
||||
ServerName::IpAddress(address) => TcpStream::connect(&(*address, port))
|
||||
.await
|
||||
.map_err(Error::Network)?,
|
||||
_ => unimplemented!("Unhandled rustls ServerName variant"),
|
||||
};
|
||||
|
||||
let stream = connector
|
||||
.connect(hostname.clone(), stream)
|
||||
.await
|
||||
.map_err(Error::Network)?;
|
||||
|
||||
log::info!("Connected to server");
|
||||
|
||||
let mut stream = BufStream::with_capacity(1024, 1024, stream);
|
||||
|
||||
Version::CURRENT
|
||||
.encode(&mut stream)
|
||||
.await
|
||||
.map_err(Error::Network)?;
|
||||
stream.flush().await.map_err(Error::Network)?;
|
||||
|
||||
let version = Version::decode(&mut stream).await.map_err(Error::Network)?;
|
||||
if version != Version::CURRENT {
|
||||
return Err(Error::Version {
|
||||
server: Version::CURRENT,
|
||||
client: version,
|
||||
});
|
||||
}
|
||||
|
||||
let challenge = AuthChallenge::decode(&mut stream)
|
||||
.await
|
||||
.map_err(Error::Network)?;
|
||||
let response = challenge.respond(password);
|
||||
|
||||
response.encode(&mut stream).await.map_err(Error::Network)?;
|
||||
stream.flush().await.map_err(Error::Network)?;
|
||||
|
||||
match Message::decode(&mut stream).await.map_err(Error::Network)? {
|
||||
AuthStatus::Passed => {}
|
||||
AuthStatus::Failed => return Err(Error::Auth),
|
||||
}
|
||||
|
||||
log::info!("Passed auth check");
|
||||
|
||||
let mut writer = EventWriter::new().await.map_err(Error::Input)?;
|
||||
loop {
|
||||
let event = Message::decode(&mut stream).await.map_err(Error::Network)?;
|
||||
log::trace!("Received event");
|
||||
|
||||
writer.write(event).await.map_err(Error::Input)?;
|
||||
log::trace!("Wrote event");
|
||||
}
|
||||
}
|
|
@ -9,7 +9,9 @@ use tokio_rustls::rustls::ServerName;
|
|||
pub struct Config {
|
||||
pub server: Server,
|
||||
pub certificate: PathBuf,
|
||||
pub password: String,
|
||||
}
|
||||
|
||||
pub struct Server {
|
||||
pub hostname: ServerName,
|
||||
pub port: u16,
|
||||
|
|
|
@ -1,19 +1,14 @@
|
|||
mod client;
|
||||
mod config;
|
||||
mod tls;
|
||||
|
||||
use clap::Parser;
|
||||
use config::Config;
|
||||
use log::LevelFilter;
|
||||
use rkvm_input::EventWriter;
|
||||
use rkvm_net::Message;
|
||||
use std::io::Error;
|
||||
use std::path::PathBuf;
|
||||
use std::process::ExitCode;
|
||||
use tokio::fs;
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::signal;
|
||||
use tokio_rustls::rustls::ServerName;
|
||||
use tokio_rustls::TlsConnector;
|
||||
|
||||
#[derive(Parser)]
|
||||
#[structopt(name = "rkvm-client", about = "The rkvm client application")]
|
||||
|
@ -56,9 +51,9 @@ async fn main() -> ExitCode {
|
|||
};
|
||||
|
||||
tokio::select! {
|
||||
result = run(&config.server.hostname, config.server.port, connector) => {
|
||||
result = client::run(&config.server.hostname, config.server.port, connector, &config.password) => {
|
||||
if let Err(err) = result {
|
||||
log::error!("Error running client: {}", err);
|
||||
log::error!("Error: {}", err);
|
||||
return ExitCode::FAILURE;
|
||||
}
|
||||
}
|
||||
|
@ -75,22 +70,3 @@ async fn main() -> ExitCode {
|
|||
|
||||
ExitCode::SUCCESS
|
||||
}
|
||||
|
||||
async fn run(hostname: &ServerName, port: u16, connector: TlsConnector) -> Result<(), Error> {
|
||||
let stream = match hostname {
|
||||
ServerName::DnsName(name) => TcpStream::connect(&(name.as_ref(), port)).await?,
|
||||
ServerName::IpAddress(address) => TcpStream::connect(&(*address, port)).await?,
|
||||
_ => unimplemented!("Unhandled rustls ServerName variant"),
|
||||
};
|
||||
|
||||
let mut stream = connector.connect(hostname.clone(), stream).await?;
|
||||
log::info!("Connected to server");
|
||||
|
||||
rkvm_net::negotiate(&mut stream).await?;
|
||||
|
||||
let mut writer = EventWriter::new().await?;
|
||||
loop {
|
||||
let event = Message::decode(&mut stream).await?;
|
||||
writer.write(event).await?;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -13,3 +13,6 @@ bincode = "1.3.3"
|
|||
tokio = { version = "1.0.1", features = ["io-util"] }
|
||||
async-trait = "0.1.68"
|
||||
thiserror = "1.0.40"
|
||||
hmac = "0.12.1"
|
||||
sha2 = "0.10.6"
|
||||
rand = "0.8.5"
|
||||
|
|
54
rkvm-net/src/auth.rs
Normal file
54
rkvm-net/src/auth.rs
Normal file
|
@ -0,0 +1,54 @@
|
|||
use hmac::{Hmac, Mac};
|
||||
use rand::rngs::OsRng;
|
||||
use rand::Error;
|
||||
use rand::Rng;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sha2::Sha256;
|
||||
use tokio::task;
|
||||
|
||||
type ChallengeHmac = Hmac<Sha256>;
|
||||
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq, Deserialize, Serialize)]
|
||||
pub struct AuthChallenge([u8; 32]);
|
||||
|
||||
impl AuthChallenge {
|
||||
pub async fn generate() -> Result<Self, Error> {
|
||||
task::spawn_blocking(|| {
|
||||
let mut data = [0; 32];
|
||||
OsRng.try_fill(&mut data)?;
|
||||
|
||||
Ok(Self(data))
|
||||
})
|
||||
.await
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
pub fn respond(&self, password: &str) -> AuthResponse {
|
||||
let mut mac = ChallengeHmac::new_from_slice(password.as_bytes()).unwrap();
|
||||
mac.update(&self.0);
|
||||
|
||||
let result = mac.finalize();
|
||||
let result = result.into_bytes();
|
||||
let result = result[..].try_into().unwrap();
|
||||
|
||||
AuthResponse(result)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq, Deserialize, Serialize)]
|
||||
pub struct AuthResponse([u8; 32]);
|
||||
|
||||
impl AuthResponse {
|
||||
pub fn verify(&self, challenge: &AuthChallenge, password: &str) -> bool {
|
||||
let mut mac = ChallengeHmac::new_from_slice(password.as_bytes()).unwrap();
|
||||
mac.update(&challenge.0);
|
||||
|
||||
mac.verify_slice(&self.0).is_ok()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq, Deserialize, Serialize)]
|
||||
pub enum AuthStatus {
|
||||
Passed,
|
||||
Failed,
|
||||
}
|
|
@ -1,36 +1,3 @@
|
|||
mod message;
|
||||
mod version;
|
||||
|
||||
use std::io::{Error, ErrorKind};
|
||||
use thiserror::Error;
|
||||
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
|
||||
use version::Version;
|
||||
|
||||
pub use message::Message;
|
||||
|
||||
pub async fn negotiate<T: AsyncRead + AsyncWrite + Send + Unpin>(
|
||||
stream: &mut T,
|
||||
) -> Result<(), Error> {
|
||||
#[derive(Error, Debug)]
|
||||
#[error("Invalid version (expected {expected}, got {got} instead)")]
|
||||
struct InvalidVersionError {
|
||||
expected: Version,
|
||||
got: Version,
|
||||
}
|
||||
|
||||
Version::CURRENT.encode(stream).await?;
|
||||
stream.flush().await?;
|
||||
|
||||
let version = Version::decode(stream).await?;
|
||||
if version != Version::CURRENT {
|
||||
return Err(Error::new(
|
||||
ErrorKind::InvalidData,
|
||||
InvalidVersionError {
|
||||
expected: Version::CURRENT,
|
||||
got: version,
|
||||
},
|
||||
));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
pub mod auth;
|
||||
pub mod message;
|
||||
pub mod version;
|
||||
|
|
|
@ -33,7 +33,9 @@ impl<T: DeserializeOwned + Serialize + Sync> Message for T {
|
|||
.len()
|
||||
.try_into()
|
||||
.map_err(|_| Error::new(ErrorKind::InvalidInput, "Data too large"))?;
|
||||
|
||||
stream.write_u16(length).await?;
|
||||
stream.write_all(&data).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
|
|
@ -20,6 +20,7 @@ tokio-rustls = "0.24.0"
|
|||
rustls-pemfile = "1.0.2"
|
||||
thiserror = "1.0.40"
|
||||
slab = "0.4.8"
|
||||
rand = "0.8.5"
|
||||
|
||||
[package.metadata.rpm]
|
||||
package = "rkvm-server"
|
||||
|
|
|
@ -7,7 +7,8 @@ use std::path::PathBuf;
|
|||
#[serde(rename_all = "kebab-case")]
|
||||
pub struct Config {
|
||||
pub listen: SocketAddr,
|
||||
pub switch_key: Key,
|
||||
pub certificate: PathBuf,
|
||||
pub key: PathBuf,
|
||||
pub password: String,
|
||||
pub switch_key: Key,
|
||||
}
|
||||
|
|
|
@ -1,22 +1,14 @@
|
|||
mod config;
|
||||
mod server;
|
||||
mod tls;
|
||||
|
||||
use clap::Parser;
|
||||
use config::Config;
|
||||
use log::LevelFilter;
|
||||
use rkvm_input::{Direction, Event, EventManager, Key, KeyKind};
|
||||
use rkvm_net::Message;
|
||||
use slab::Slab;
|
||||
use std::io::Error;
|
||||
use std::net::SocketAddr;
|
||||
use std::path::PathBuf;
|
||||
use std::process::ExitCode;
|
||||
use tokio::fs;
|
||||
use tokio::io::{AsyncWriteExt, BufStream};
|
||||
use tokio::net::TcpListener;
|
||||
use tokio::signal;
|
||||
use tokio::sync::mpsc::{self, Sender};
|
||||
use tokio_rustls::TlsAcceptor;
|
||||
|
||||
#[derive(Parser)]
|
||||
#[structopt(name = "rkvm-server", about = "The rkvm server application")]
|
||||
|
@ -59,9 +51,9 @@ async fn main() -> ExitCode {
|
|||
};
|
||||
|
||||
tokio::select! {
|
||||
result = run(config.listen, acceptor, config.switch_key) => {
|
||||
result = server::run(config.listen, acceptor, &config.password, config.switch_key) => {
|
||||
if let Err(err) = result {
|
||||
log::error!("Error running server: {}", err);
|
||||
log::error!("Error: {}", err);
|
||||
return ExitCode::FAILURE;
|
||||
}
|
||||
}
|
||||
|
@ -78,80 +70,3 @@ async fn main() -> ExitCode {
|
|||
|
||||
ExitCode::SUCCESS
|
||||
}
|
||||
|
||||
async fn run(listen: SocketAddr, acceptor: TlsAcceptor, switch_key: Key) -> Result<(), Error> {
|
||||
let listener = TcpListener::bind(&listen).await?;
|
||||
log::info!("Listening on {}", listen);
|
||||
|
||||
let mut clients = Slab::<Sender<_>>::new();
|
||||
let mut current = 0;
|
||||
let mut manager = EventManager::new().await?;
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
result = listener.accept() => {
|
||||
// Remove dead clients.
|
||||
clients.retain(|_, client| !client.is_closed());
|
||||
if !clients.contains(current) {
|
||||
current = 0;
|
||||
}
|
||||
|
||||
let (stream, addr) = result?;
|
||||
let acceptor = acceptor.clone();
|
||||
|
||||
let (sender, mut receiver) = mpsc::channel::<Event>(1);
|
||||
clients.insert(sender);
|
||||
|
||||
tokio::spawn(async move {
|
||||
let stream = match acceptor.accept(stream).await {
|
||||
Ok(stream) => stream,
|
||||
Err(err) => {
|
||||
log::error!("{}: TLS accept error: {}", addr, err);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
log::info!("{}: Connected", addr);
|
||||
|
||||
let result = async {
|
||||
let mut stream = BufStream::with_capacity(1024, 1024, stream);
|
||||
|
||||
rkvm_net::negotiate(&mut stream).await?;
|
||||
|
||||
loop {
|
||||
let event = match receiver.recv().await {
|
||||
Some(event) => event,
|
||||
None => break,
|
||||
};
|
||||
|
||||
event.encode(&mut stream).await?;
|
||||
stream.flush().await?;
|
||||
}
|
||||
|
||||
Ok::<_, Error>(())
|
||||
}
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(()) => log::info!("{}: Disconnected", addr),
|
||||
Err(err) => log::error!("{}: Disconnected: {}", addr, err),
|
||||
}
|
||||
});
|
||||
}
|
||||
result = manager.read() => {
|
||||
let event = result?;
|
||||
if let Event::Key { direction: Direction::Down, kind: KeyKind::Key(key) } = event {
|
||||
if key == switch_key {
|
||||
current = (current + 1) % (clients.len() + 1);
|
||||
log::info!("Switching to client {}", current);
|
||||
}
|
||||
}
|
||||
|
||||
if current == 0 || clients[current].send(event).await.is_err() {
|
||||
current = 0;
|
||||
manager.write(event).await?;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
147
rkvm-server/src/server.rs
Normal file
147
rkvm-server/src/server.rs
Normal file
|
@ -0,0 +1,147 @@
|
|||
use rkvm_input::{Direction, Event, EventManager, Key, KeyKind};
|
||||
use rkvm_net::auth::{AuthChallenge, AuthResponse, AuthStatus};
|
||||
use rkvm_net::message::Message;
|
||||
use rkvm_net::version::Version;
|
||||
use slab::Slab;
|
||||
use std::io;
|
||||
use std::net::SocketAddr;
|
||||
use thiserror::Error;
|
||||
use tokio::io::{AsyncWriteExt, BufStream};
|
||||
use tokio::net::{TcpListener, TcpStream};
|
||||
use tokio::sync::mpsc::{self, Receiver, Sender};
|
||||
use tokio_rustls::server::TlsStream;
|
||||
use tokio_rustls::TlsAcceptor;
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
pub enum Error {
|
||||
#[error("Network error: {0}")]
|
||||
Network(io::Error),
|
||||
#[error("Input error: {0}")]
|
||||
Input(io::Error),
|
||||
}
|
||||
|
||||
pub async fn run(
|
||||
listen: SocketAddr,
|
||||
acceptor: TlsAcceptor,
|
||||
password: &str,
|
||||
switch_key: Key,
|
||||
) -> Result<(), Error> {
|
||||
let listener = TcpListener::bind(&listen).await.map_err(Error::Network)?;
|
||||
log::info!("Listening on {}", listen);
|
||||
|
||||
let mut clients = Slab::<Sender<_>>::new();
|
||||
let mut current = 0;
|
||||
let mut manager = EventManager::new().await.map_err(Error::Input)?;
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
result = listener.accept() => {
|
||||
let (stream, addr) = result.map_err(Error::Network)?;
|
||||
let acceptor = acceptor.clone();
|
||||
let password = password.to_owned();
|
||||
|
||||
// Remove dead clients.
|
||||
clients.retain(|_, client| !client.is_closed());
|
||||
if !clients.contains(current) {
|
||||
current = 0;
|
||||
}
|
||||
|
||||
let (sender, receiver) = mpsc::channel(1);
|
||||
clients.insert(sender);
|
||||
|
||||
tokio::spawn(async move {
|
||||
let stream = match acceptor.accept(stream).await {
|
||||
Ok(stream) => stream,
|
||||
Err(err) => {
|
||||
log::error!("{}: TLS accept error: {}", addr, err);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
log::info!("{}: Connected", addr);
|
||||
|
||||
match client(receiver, stream, addr, &password).await {
|
||||
Ok(()) => log::info!("{}: Disconnected", addr),
|
||||
Err(err) => log::error!("{}: Disconnected: {}", addr, err),
|
||||
}
|
||||
});
|
||||
}
|
||||
result = manager.read() => {
|
||||
let event = result.map_err(Error::Input)?;
|
||||
if let Event::Key { direction: Direction::Down, kind: KeyKind::Key(key) } = event {
|
||||
if key == switch_key {
|
||||
current = (current + 1) % (clients.len() + 1);
|
||||
log::info!("Switching to client {}", current);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
if current == 0 || clients[current - 1].send(event).await.is_err() {
|
||||
current = 0;
|
||||
manager.write(event).await.map_err(Error::Input)?;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
enum ClientError {
|
||||
#[error(transparent)]
|
||||
Io(#[from] io::Error),
|
||||
#[error("Incompatible client version (got {client}, expected {server})")]
|
||||
Version { server: Version, client: Version },
|
||||
#[error("Auth challenge failed (possibly wrong password)")]
|
||||
Auth,
|
||||
#[error(transparent)]
|
||||
Rand(#[from] rand::Error),
|
||||
}
|
||||
|
||||
async fn client(
|
||||
mut receiver: Receiver<Event>,
|
||||
stream: TlsStream<TcpStream>,
|
||||
addr: SocketAddr,
|
||||
password: &str,
|
||||
) -> Result<(), ClientError> {
|
||||
let mut stream = BufStream::with_capacity(1024, 1024, stream);
|
||||
|
||||
Version::CURRENT.encode(&mut stream).await?;
|
||||
stream.flush().await?;
|
||||
|
||||
let version = Version::decode(&mut stream).await?;
|
||||
if version != Version::CURRENT {
|
||||
return Err(ClientError::Version {
|
||||
server: Version::CURRENT,
|
||||
client: version,
|
||||
});
|
||||
}
|
||||
|
||||
let challenge = AuthChallenge::generate().await?;
|
||||
|
||||
challenge.encode(&mut stream).await?;
|
||||
stream.flush().await?;
|
||||
|
||||
let response = AuthResponse::decode(&mut stream).await?;
|
||||
let status = match response.verify(&challenge, password) {
|
||||
true => AuthStatus::Passed,
|
||||
false => AuthStatus::Failed,
|
||||
};
|
||||
|
||||
status.encode(&mut stream).await?;
|
||||
stream.flush().await?;
|
||||
|
||||
if status == AuthStatus::Failed {
|
||||
return Err(ClientError::Auth);
|
||||
}
|
||||
|
||||
log::info!("{}: Passed auth check", addr);
|
||||
|
||||
while let Some(event) = receiver.recv().await {
|
||||
event.encode(&mut stream).await?;
|
||||
stream.flush().await?;
|
||||
|
||||
log::trace!("{}: Sent event", addr);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
Loading…
Reference in a new issue