Implement authentication

This commit is contained in:
Jan Trefil 2023-04-17 19:52:57 +02:00
parent e56ea03351
commit 26e9e78271
14 changed files with 441 additions and 154 deletions

132
Cargo.lock generated
View file

@ -142,6 +142,15 @@ version = "1.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a"
[[package]]
name = "block-buffer"
version = "0.10.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3078c7629b62d3f0439517fa394996acacc5cbc91c5a20d8c658e77abd503a71"
dependencies = [
"generic-array",
]
[[package]]
name = "bumpalo"
version = "3.12.0"
@ -255,6 +264,36 @@ version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "acbf1af155f9b9ef647e42cdc158db4b64a1b61f743629225fde6f3e0be2a7c7"
[[package]]
name = "cpufeatures"
version = "0.2.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "280a9f2d8b3a38871a3c8a46fb80db65e5e5ed97da80c4d08bf27fb63e35e181"
dependencies = [
"libc",
]
[[package]]
name = "crypto-common"
version = "0.1.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1bfb12502f3fc46cca1bb51ac28df9d618d813cdc3d2f25b9fe775a34af26bb3"
dependencies = [
"generic-array",
"typenum",
]
[[package]]
name = "digest"
version = "0.10.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8168378f4e5023e7218c89c891c0fd8ecdb5e5e4f18cb78f38cf245dd021e76f"
dependencies = [
"block-buffer",
"crypto-common",
"subtle",
]
[[package]]
name = "env_logger"
version = "0.7.1"
@ -400,6 +439,27 @@ dependencies = [
"slab",
]
[[package]]
name = "generic-array"
version = "0.14.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "85649ca51fd72272d7821adaf274ad91c288277713d9c18820d8499a7ff69e9a"
dependencies = [
"typenum",
"version_check",
]
[[package]]
name = "getrandom"
version = "0.2.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c85e1d9ab2eadba7e5040d4e09cbd6d072b76a557ad64e797c2cb9d4da21d7e4"
dependencies = [
"cfg-if 1.0.0",
"libc",
"wasi",
]
[[package]]
name = "glob"
version = "0.3.1"
@ -436,6 +496,15 @@ version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fed44880c466736ef9a5c5b5facefb5ed0785676d0c02d612db14e54f0d84286"
[[package]]
name = "hmac"
version = "0.12.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6c49c37c09c17a53d937dfbb742eb3a961d65a994e6bcdcf37e7399d0cc8ab5e"
dependencies = [
"digest",
]
[[package]]
name = "humantime"
version = "1.3.0"
@ -625,6 +694,12 @@ version = "0.3.26"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6ac9a59f73473f1b8d852421e59e64809f025994837ef743615c6d0c5b305160"
[[package]]
name = "ppv-lite86"
version = "0.2.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de"
[[package]]
name = "proc-macro2"
version = "1.0.56"
@ -649,6 +724,36 @@ dependencies = [
"proc-macro2",
]
[[package]]
name = "rand"
version = "0.8.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404"
dependencies = [
"libc",
"rand_chacha",
"rand_core",
]
[[package]]
name = "rand_chacha"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88"
dependencies = [
"ppv-lite86",
"rand_core",
]
[[package]]
name = "rand_core"
version = "0.6.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c"
dependencies = [
"getrandom",
]
[[package]]
name = "redox_syscall"
version = "0.3.5"
@ -736,8 +841,11 @@ version = "0.2.0"
dependencies = [
"async-trait",
"bincode",
"hmac",
"rand",
"rkvm-input",
"serde",
"sha2",
"thiserror",
"tokio",
]
@ -749,6 +857,7 @@ dependencies = [
"clap 4.2.2",
"env_logger 0.8.4",
"log",
"rand",
"rkvm-input",
"rkvm-net",
"rustls-pemfile",
@ -841,6 +950,17 @@ dependencies = [
"syn 2.0.15",
]
[[package]]
name = "sha2"
version = "0.10.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "82e6b795fe2e3b1e845bafcb27aa35405c4d47cdfc92af5fc8d3002f76cebdc0"
dependencies = [
"cfg-if 1.0.0",
"cpufeatures",
"digest",
]
[[package]]
name = "shlex"
version = "0.1.1"
@ -893,6 +1013,12 @@ version = "0.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623"
[[package]]
name = "subtle"
version = "2.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6bdef32e8150c2a081110b42772ffe7d7c9032b606bc226c8260fd97e0976601"
[[package]]
name = "syn"
version = "1.0.109"
@ -1014,6 +1140,12 @@ dependencies = [
"serde",
]
[[package]]
name = "typenum"
version = "1.16.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "497961ef93d974e23eb6f433eb5fe1b7930b659f06d12dec6fc44a8f554c0bba"
[[package]]
name = "unicode-ident"
version = "1.0.8"

View file

@ -1,2 +1,3 @@
server = "localhost:5258"
certificate = "example/certificate.crt"
password = "123456789"

View file

@ -7,6 +7,6 @@ edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
clap = "4.2.2"
clap = { version = "4.2.2", features = ["derive"] }
tempfile = "3.1.0"
thiserror = "1.0.40"

86
rkvm-client/src/client.rs Normal file
View file

@ -0,0 +1,86 @@
use rkvm_input::EventWriter;
use rkvm_net::auth::{AuthChallenge, AuthStatus};
use rkvm_net::message::Message;
use rkvm_net::version::Version;
use std::io;
use thiserror::Error;
use tokio::io::{AsyncWriteExt, BufStream};
use tokio::net::TcpStream;
use tokio_rustls::rustls::ServerName;
use tokio_rustls::TlsConnector;
#[derive(Error, Debug)]
pub enum Error {
#[error("Network error: {0}")]
Network(io::Error),
#[error("Input error: {0}")]
Input(io::Error),
#[error("Incompatible server version (got {server}, expected {client})")]
Version { server: Version, client: Version },
#[error("Auth challenge failed (possibly wrong password)")]
Auth,
}
pub async fn run(
hostname: &ServerName,
port: u16,
connector: TlsConnector,
password: &str,
) -> Result<(), Error> {
let stream = match hostname {
ServerName::DnsName(name) => TcpStream::connect(&(name.as_ref(), port))
.await
.map_err(Error::Network)?,
ServerName::IpAddress(address) => TcpStream::connect(&(*address, port))
.await
.map_err(Error::Network)?,
_ => unimplemented!("Unhandled rustls ServerName variant"),
};
let stream = connector
.connect(hostname.clone(), stream)
.await
.map_err(Error::Network)?;
log::info!("Connected to server");
let mut stream = BufStream::with_capacity(1024, 1024, stream);
Version::CURRENT
.encode(&mut stream)
.await
.map_err(Error::Network)?;
stream.flush().await.map_err(Error::Network)?;
let version = Version::decode(&mut stream).await.map_err(Error::Network)?;
if version != Version::CURRENT {
return Err(Error::Version {
server: Version::CURRENT,
client: version,
});
}
let challenge = AuthChallenge::decode(&mut stream)
.await
.map_err(Error::Network)?;
let response = challenge.respond(password);
response.encode(&mut stream).await.map_err(Error::Network)?;
stream.flush().await.map_err(Error::Network)?;
match Message::decode(&mut stream).await.map_err(Error::Network)? {
AuthStatus::Passed => {}
AuthStatus::Failed => return Err(Error::Auth),
}
log::info!("Passed auth check");
let mut writer = EventWriter::new().await.map_err(Error::Input)?;
loop {
let event = Message::decode(&mut stream).await.map_err(Error::Network)?;
log::trace!("Received event");
writer.write(event).await.map_err(Error::Input)?;
log::trace!("Wrote event");
}
}

View file

@ -9,7 +9,9 @@ use tokio_rustls::rustls::ServerName;
pub struct Config {
pub server: Server,
pub certificate: PathBuf,
pub password: String,
}
pub struct Server {
pub hostname: ServerName,
pub port: u16,

View file

@ -1,19 +1,14 @@
mod client;
mod config;
mod tls;
use clap::Parser;
use config::Config;
use log::LevelFilter;
use rkvm_input::EventWriter;
use rkvm_net::Message;
use std::io::Error;
use std::path::PathBuf;
use std::process::ExitCode;
use tokio::fs;
use tokio::net::TcpStream;
use tokio::signal;
use tokio_rustls::rustls::ServerName;
use tokio_rustls::TlsConnector;
#[derive(Parser)]
#[structopt(name = "rkvm-client", about = "The rkvm client application")]
@ -56,9 +51,9 @@ async fn main() -> ExitCode {
};
tokio::select! {
result = run(&config.server.hostname, config.server.port, connector) => {
result = client::run(&config.server.hostname, config.server.port, connector, &config.password) => {
if let Err(err) = result {
log::error!("Error running client: {}", err);
log::error!("Error: {}", err);
return ExitCode::FAILURE;
}
}
@ -75,22 +70,3 @@ async fn main() -> ExitCode {
ExitCode::SUCCESS
}
async fn run(hostname: &ServerName, port: u16, connector: TlsConnector) -> Result<(), Error> {
let stream = match hostname {
ServerName::DnsName(name) => TcpStream::connect(&(name.as_ref(), port)).await?,
ServerName::IpAddress(address) => TcpStream::connect(&(*address, port)).await?,
_ => unimplemented!("Unhandled rustls ServerName variant"),
};
let mut stream = connector.connect(hostname.clone(), stream).await?;
log::info!("Connected to server");
rkvm_net::negotiate(&mut stream).await?;
let mut writer = EventWriter::new().await?;
loop {
let event = Message::decode(&mut stream).await?;
writer.write(event).await?;
}
}

View file

@ -13,3 +13,6 @@ bincode = "1.3.3"
tokio = { version = "1.0.1", features = ["io-util"] }
async-trait = "0.1.68"
thiserror = "1.0.40"
hmac = "0.12.1"
sha2 = "0.10.6"
rand = "0.8.5"

54
rkvm-net/src/auth.rs Normal file
View file

@ -0,0 +1,54 @@
use hmac::{Hmac, Mac};
use rand::rngs::OsRng;
use rand::Error;
use rand::Rng;
use serde::{Deserialize, Serialize};
use sha2::Sha256;
use tokio::task;
type ChallengeHmac = Hmac<Sha256>;
#[derive(Clone, Copy, Debug, PartialEq, Eq, Deserialize, Serialize)]
pub struct AuthChallenge([u8; 32]);
impl AuthChallenge {
pub async fn generate() -> Result<Self, Error> {
task::spawn_blocking(|| {
let mut data = [0; 32];
OsRng.try_fill(&mut data)?;
Ok(Self(data))
})
.await
.unwrap()
}
pub fn respond(&self, password: &str) -> AuthResponse {
let mut mac = ChallengeHmac::new_from_slice(password.as_bytes()).unwrap();
mac.update(&self.0);
let result = mac.finalize();
let result = result.into_bytes();
let result = result[..].try_into().unwrap();
AuthResponse(result)
}
}
#[derive(Clone, Copy, Debug, PartialEq, Eq, Deserialize, Serialize)]
pub struct AuthResponse([u8; 32]);
impl AuthResponse {
pub fn verify(&self, challenge: &AuthChallenge, password: &str) -> bool {
let mut mac = ChallengeHmac::new_from_slice(password.as_bytes()).unwrap();
mac.update(&challenge.0);
mac.verify_slice(&self.0).is_ok()
}
}
#[derive(Clone, Copy, Debug, PartialEq, Eq, Deserialize, Serialize)]
pub enum AuthStatus {
Passed,
Failed,
}

View file

@ -1,36 +1,3 @@
mod message;
mod version;
use std::io::{Error, ErrorKind};
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use version::Version;
pub use message::Message;
pub async fn negotiate<T: AsyncRead + AsyncWrite + Send + Unpin>(
stream: &mut T,
) -> Result<(), Error> {
#[derive(Error, Debug)]
#[error("Invalid version (expected {expected}, got {got} instead)")]
struct InvalidVersionError {
expected: Version,
got: Version,
}
Version::CURRENT.encode(stream).await?;
stream.flush().await?;
let version = Version::decode(stream).await?;
if version != Version::CURRENT {
return Err(Error::new(
ErrorKind::InvalidData,
InvalidVersionError {
expected: Version::CURRENT,
got: version,
},
));
}
Ok(())
}
pub mod auth;
pub mod message;
pub mod version;

View file

@ -33,7 +33,9 @@ impl<T: DeserializeOwned + Serialize + Sync> Message for T {
.len()
.try_into()
.map_err(|_| Error::new(ErrorKind::InvalidInput, "Data too large"))?;
stream.write_u16(length).await?;
stream.write_all(&data).await?;
Ok(())
}

View file

@ -20,6 +20,7 @@ tokio-rustls = "0.24.0"
rustls-pemfile = "1.0.2"
thiserror = "1.0.40"
slab = "0.4.8"
rand = "0.8.5"
[package.metadata.rpm]
package = "rkvm-server"

View file

@ -7,7 +7,8 @@ use std::path::PathBuf;
#[serde(rename_all = "kebab-case")]
pub struct Config {
pub listen: SocketAddr,
pub switch_key: Key,
pub certificate: PathBuf,
pub key: PathBuf,
pub password: String,
pub switch_key: Key,
}

View file

@ -1,22 +1,14 @@
mod config;
mod server;
mod tls;
use clap::Parser;
use config::Config;
use log::LevelFilter;
use rkvm_input::{Direction, Event, EventManager, Key, KeyKind};
use rkvm_net::Message;
use slab::Slab;
use std::io::Error;
use std::net::SocketAddr;
use std::path::PathBuf;
use std::process::ExitCode;
use tokio::fs;
use tokio::io::{AsyncWriteExt, BufStream};
use tokio::net::TcpListener;
use tokio::signal;
use tokio::sync::mpsc::{self, Sender};
use tokio_rustls::TlsAcceptor;
#[derive(Parser)]
#[structopt(name = "rkvm-server", about = "The rkvm server application")]
@ -59,9 +51,9 @@ async fn main() -> ExitCode {
};
tokio::select! {
result = run(config.listen, acceptor, config.switch_key) => {
result = server::run(config.listen, acceptor, &config.password, config.switch_key) => {
if let Err(err) = result {
log::error!("Error running server: {}", err);
log::error!("Error: {}", err);
return ExitCode::FAILURE;
}
}
@ -78,80 +70,3 @@ async fn main() -> ExitCode {
ExitCode::SUCCESS
}
async fn run(listen: SocketAddr, acceptor: TlsAcceptor, switch_key: Key) -> Result<(), Error> {
let listener = TcpListener::bind(&listen).await?;
log::info!("Listening on {}", listen);
let mut clients = Slab::<Sender<_>>::new();
let mut current = 0;
let mut manager = EventManager::new().await?;
loop {
tokio::select! {
result = listener.accept() => {
// Remove dead clients.
clients.retain(|_, client| !client.is_closed());
if !clients.contains(current) {
current = 0;
}
let (stream, addr) = result?;
let acceptor = acceptor.clone();
let (sender, mut receiver) = mpsc::channel::<Event>(1);
clients.insert(sender);
tokio::spawn(async move {
let stream = match acceptor.accept(stream).await {
Ok(stream) => stream,
Err(err) => {
log::error!("{}: TLS accept error: {}", addr, err);
return;
}
};
log::info!("{}: Connected", addr);
let result = async {
let mut stream = BufStream::with_capacity(1024, 1024, stream);
rkvm_net::negotiate(&mut stream).await?;
loop {
let event = match receiver.recv().await {
Some(event) => event,
None => break,
};
event.encode(&mut stream).await?;
stream.flush().await?;
}
Ok::<_, Error>(())
}
.await;
match result {
Ok(()) => log::info!("{}: Disconnected", addr),
Err(err) => log::error!("{}: Disconnected: {}", addr, err),
}
});
}
result = manager.read() => {
let event = result?;
if let Event::Key { direction: Direction::Down, kind: KeyKind::Key(key) } = event {
if key == switch_key {
current = (current + 1) % (clients.len() + 1);
log::info!("Switching to client {}", current);
}
}
if current == 0 || clients[current].send(event).await.is_err() {
current = 0;
manager.write(event).await?;
}
}
}
}
}

147
rkvm-server/src/server.rs Normal file
View file

@ -0,0 +1,147 @@
use rkvm_input::{Direction, Event, EventManager, Key, KeyKind};
use rkvm_net::auth::{AuthChallenge, AuthResponse, AuthStatus};
use rkvm_net::message::Message;
use rkvm_net::version::Version;
use slab::Slab;
use std::io;
use std::net::SocketAddr;
use thiserror::Error;
use tokio::io::{AsyncWriteExt, BufStream};
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::mpsc::{self, Receiver, Sender};
use tokio_rustls::server::TlsStream;
use tokio_rustls::TlsAcceptor;
#[derive(Error, Debug)]
pub enum Error {
#[error("Network error: {0}")]
Network(io::Error),
#[error("Input error: {0}")]
Input(io::Error),
}
pub async fn run(
listen: SocketAddr,
acceptor: TlsAcceptor,
password: &str,
switch_key: Key,
) -> Result<(), Error> {
let listener = TcpListener::bind(&listen).await.map_err(Error::Network)?;
log::info!("Listening on {}", listen);
let mut clients = Slab::<Sender<_>>::new();
let mut current = 0;
let mut manager = EventManager::new().await.map_err(Error::Input)?;
loop {
tokio::select! {
result = listener.accept() => {
let (stream, addr) = result.map_err(Error::Network)?;
let acceptor = acceptor.clone();
let password = password.to_owned();
// Remove dead clients.
clients.retain(|_, client| !client.is_closed());
if !clients.contains(current) {
current = 0;
}
let (sender, receiver) = mpsc::channel(1);
clients.insert(sender);
tokio::spawn(async move {
let stream = match acceptor.accept(stream).await {
Ok(stream) => stream,
Err(err) => {
log::error!("{}: TLS accept error: {}", addr, err);
return;
}
};
log::info!("{}: Connected", addr);
match client(receiver, stream, addr, &password).await {
Ok(()) => log::info!("{}: Disconnected", addr),
Err(err) => log::error!("{}: Disconnected: {}", addr, err),
}
});
}
result = manager.read() => {
let event = result.map_err(Error::Input)?;
if let Event::Key { direction: Direction::Down, kind: KeyKind::Key(key) } = event {
if key == switch_key {
current = (current + 1) % (clients.len() + 1);
log::info!("Switching to client {}", current);
continue;
}
}
if current == 0 || clients[current - 1].send(event).await.is_err() {
current = 0;
manager.write(event).await.map_err(Error::Input)?;
}
}
}
}
}
#[derive(Error, Debug)]
enum ClientError {
#[error(transparent)]
Io(#[from] io::Error),
#[error("Incompatible client version (got {client}, expected {server})")]
Version { server: Version, client: Version },
#[error("Auth challenge failed (possibly wrong password)")]
Auth,
#[error(transparent)]
Rand(#[from] rand::Error),
}
async fn client(
mut receiver: Receiver<Event>,
stream: TlsStream<TcpStream>,
addr: SocketAddr,
password: &str,
) -> Result<(), ClientError> {
let mut stream = BufStream::with_capacity(1024, 1024, stream);
Version::CURRENT.encode(&mut stream).await?;
stream.flush().await?;
let version = Version::decode(&mut stream).await?;
if version != Version::CURRENT {
return Err(ClientError::Version {
server: Version::CURRENT,
client: version,
});
}
let challenge = AuthChallenge::generate().await?;
challenge.encode(&mut stream).await?;
stream.flush().await?;
let response = AuthResponse::decode(&mut stream).await?;
let status = match response.verify(&challenge, password) {
true => AuthStatus::Passed,
false => AuthStatus::Failed,
};
status.encode(&mut stream).await?;
stream.flush().await?;
if status == AuthStatus::Failed {
return Err(ClientError::Auth);
}
log::info!("{}: Passed auth check", addr);
while let Some(event) = receiver.recv().await {
event.encode(&mut stream).await?;
stream.flush().await?;
log::trace!("{}: Sent event", addr);
}
Ok(())
}