diff --git a/server/Cargo.toml b/server/Cargo.toml index 22396f0..e821392 100644 --- a/server/Cargo.toml +++ b/server/Cargo.toml @@ -15,3 +15,4 @@ toml = "0.5.7" structopt = "0.3.20" log = "0.4.11" env_logger = "0.8.1" +tokio-native-tls = "0.2.0" diff --git a/server/src/config.rs b/server/src/config.rs index 8275102..7c9401c 100644 --- a/server/src/config.rs +++ b/server/src/config.rs @@ -1,10 +1,13 @@ use serde::{Deserialize, Serialize}; use std::collections::HashSet; use std::net::SocketAddr; +use std::path::PathBuf; #[derive(Serialize, Deserialize)] #[serde(rename_all = "kebab-case")] pub struct Config { pub listen_address: SocketAddr, pub switch_keys: HashSet, + pub identity_path: PathBuf, + pub identity_password: String, } diff --git a/server/src/main.rs b/server/src/main.rs index b065d29..b5ecb1a 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -7,7 +7,7 @@ use std::collections::{HashMap, HashSet}; use std::convert::Infallible; use std::io::{Error, ErrorKind}; use std::net::SocketAddr; -use std::path::PathBuf; +use std::path::{Path, PathBuf}; use std::process; use std::time::Duration; use structopt::StructOpt; @@ -16,6 +16,7 @@ use tokio::io::{AsyncRead, AsyncWrite}; use tokio::net::TcpListener; use tokio::sync::mpsc::{self, UnboundedReceiver, UnboundedSender}; use tokio::time; +use tokio_native_tls::native_tls::{Identity, TlsAcceptor}; async fn handle_connection( mut stream: T, @@ -48,7 +49,18 @@ where } } -async fn run(listen_address: SocketAddr, switch_keys: &HashSet) -> Result { +async fn run( + listen_address: SocketAddr, + switch_keys: &HashSet, + identity_path: &Path, + identity_password: &str, +) -> Result { + let identity = fs::read(identity_path).await?; + let identity = Identity::from_pkcs12(&identity, identity_password) + .map_err(|err| Error::new(ErrorKind::InvalidData, err))?; + let acceptor: tokio_native_tls::TlsAcceptor = TlsAcceptor::new(identity) + .map_err(|err| Error::new(ErrorKind::InvalidData, err)) + .map(Into::into)?; let listener = TcpListener::bind(listen_address).await?; log::info!("Listening on {}", listen_address); @@ -64,6 +76,14 @@ async fn run(listen_address: SocketAddr, switch_keys: &HashSet) -> Result stream, + Err(err) => { + log::error!("{}: TLS error: {}", address, err); + continue; + } + }; + let (sender, receiver) = mpsc::unbounded_channel(); if client_sender.send(Ok(sender)).is_err() { return; @@ -163,7 +183,7 @@ async fn main() { }; tokio::select! { - result = run(config.listen_address, &config.switch_keys) => { + result = run(config.listen_address, &config.switch_keys, &config.identity_path, &config.identity_password) => { if let Err(err) = result { log::error!("Error: {}", err); process::exit(1);