Implement multitouch events

This commit is contained in:
Jan Trefil 2023-06-20 19:47:37 +02:00
parent 2b20ba1f1b
commit 9b9a63ace8
13 changed files with 267 additions and 203 deletions

View file

@ -140,22 +140,17 @@ pub async fn run(
log::info!("Destroyed device {}", id); log::info!("Destroyed device {}", id);
} }
Update::EventBatch { id, events } => { Update::Event { id, event } => {
let writer = writers.get_mut(&id).ok_or_else(|| { let writer = writers.get_mut(&id).ok_or_else(|| {
Error::Network(io::Error::new( Error::Network(io::Error::new(
io::ErrorKind::InvalidData, io::ErrorKind::InvalidData,
"Server sent events to a nonexistent device", "Server sent an event to a nonexistent device",
)) ))
})?; })?;
writer.write(&events).await.map_err(Error::Input)?; writer.write(&event).await.map_err(Error::Input)?;
log::trace!( log::trace!("Wrote an event to device {}", id);
"Wrote {} event{} to device {}",
events.len(),
if events.len() == 1 { "" } else { "s" },
id
);
} }
} }
} }

View file

@ -7,8 +7,7 @@ use config::Config;
use log::LevelFilter; use log::LevelFilter;
use std::path::PathBuf; use std::path::PathBuf;
use std::process::ExitCode; use std::process::ExitCode;
use tokio::fs; use tokio::{fs, signal};
use tokio::signal;
#[derive(Parser)] #[derive(Parser)]
#[structopt(name = "rkvm-client", about = "The rkvm client application")] #[structopt(name = "rkvm-client", about = "The rkvm client application")]

View file

@ -1,5 +1,6 @@
use std::io;
use std::path::Path;
use std::sync::Arc; use std::sync::Arc;
use std::{io, path::Path};
use thiserror::Error; use thiserror::Error;
use tokio::fs; use tokio::fs;
use tokio_rustls::rustls::{self, Certificate, ClientConfig, RootCertStore}; use tokio_rustls::rustls::{self, Certificate, ClientConfig, RootCertStore};

View file

@ -2,25 +2,12 @@ use crate::glue;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
// #define ABS_MT_SLOT 0x2f /* MT slot being modified */
// #define ABS_MT_TOUCH_MAJOR 0x30 /* Major axis of touching ellipse */
// #define ABS_MT_TOUCH_MINOR 0x31 /* Minor axis (omit if circular) */
// #define ABS_MT_WIDTH_MAJOR 0x32 /* Major axis of approaching ellipse */
// #define ABS_MT_WIDTH_MINOR 0x33 /* Minor axis (omit if circular) */
// #define ABS_MT_ORIENTATION 0x34 /* Ellipse orientation */
// #define ABS_MT_POSITION_X 0x35 /* Center X touch position */
// #define ABS_MT_POSITION_Y 0x36 /* Center Y touch position */
// #define ABS_MT_TOOL_TYPE 0x37 /* Type of touching device */
// #define ABS_MT_BLOB_ID 0x38 /* Group a set of packets as a blob */
// #define ABS_MT_TRACKING_ID 0x39 /* Unique ID of initiated contact */
// #define ABS_MT_PRESSURE 0x3a /* Pressure on contact area */
// #define ABS_MT_DISTANCE 0x3b /* Contact hover distance */
// #define ABS_MT_TOOL_X 0x3c /* Center X tool position */
// #define ABS_MT_TOOL_Y 0x3d /* Center Y tool position */
#[derive(Clone, Copy, Debug, Serialize, Deserialize)] #[derive(Clone, Copy, Debug, Serialize, Deserialize)]
pub struct AbsEvent { pub enum AbsEvent {
pub axis: AbsAxis, Axis { axis: AbsAxis, value: i32 },
pub value: i32, MtToolType { value: ToolType },
// TODO: This might actually belong to the Axis variant.
MtBlobId { value: i32 },
} }
#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)] #[derive(Clone, Copy, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
@ -52,6 +39,19 @@ pub enum AbsAxis {
Volume, Volume,
Profile, Profile,
Misc, Misc,
MtSlot,
MtTouchMajor,
MtTouchMinor,
MtWidthMajor,
MtWidthMinor,
MtOrientation,
MtPositionX,
MtPositionY,
MtTrackingId,
MtPressure,
MtDistance,
MtToolX,
MtToolY,
} }
impl AbsAxis { impl AbsAxis {
@ -84,6 +84,19 @@ impl AbsAxis {
glue::ABS_VOLUME => Self::Volume, glue::ABS_VOLUME => Self::Volume,
glue::ABS_PROFILE => Self::Profile, glue::ABS_PROFILE => Self::Profile,
glue::ABS_MISC => Self::Misc, glue::ABS_MISC => Self::Misc,
glue::ABS_MT_SLOT => Self::MtSlot,
glue::ABS_MT_TOUCH_MAJOR => Self::MtTouchMajor,
glue::ABS_MT_TOUCH_MINOR => Self::MtTouchMinor,
glue::ABS_MT_WIDTH_MAJOR => Self::MtWidthMajor,
glue::ABS_MT_WIDTH_MINOR => Self::MtWidthMinor,
glue::ABS_MT_ORIENTATION => Self::MtOrientation,
glue::ABS_MT_POSITION_X => Self::MtPositionX,
glue::ABS_MT_POSITION_Y => Self::MtPositionY,
glue::ABS_MT_TRACKING_ID => Self::MtTrackingId,
glue::ABS_MT_PRESSURE => Self::MtPressure,
glue::ABS_MT_DISTANCE => Self::MtDistance,
glue::ABS_MT_TOOL_X => Self::MtToolX,
glue::ABS_MT_TOOL_Y => Self::MtToolY,
_ => return None, _ => return None,
}; };
@ -119,6 +132,19 @@ impl AbsAxis {
Self::Volume => glue::ABS_VOLUME, Self::Volume => glue::ABS_VOLUME,
Self::Profile => glue::ABS_PROFILE, Self::Profile => glue::ABS_PROFILE,
Self::Misc => glue::ABS_MISC, Self::Misc => glue::ABS_MISC,
Self::MtSlot => glue::ABS_MT_SLOT,
Self::MtTouchMajor => glue::ABS_MT_TOUCH_MAJOR,
Self::MtTouchMinor => glue::ABS_MT_TOUCH_MINOR,
Self::MtWidthMajor => glue::ABS_MT_WIDTH_MAJOR,
Self::MtWidthMinor => glue::ABS_MT_WIDTH_MINOR,
Self::MtOrientation => glue::ABS_MT_ORIENTATION,
Self::MtPositionX => glue::ABS_MT_POSITION_X,
Self::MtPositionY => glue::ABS_MT_POSITION_Y,
Self::MtTrackingId => glue::ABS_MT_TRACKING_ID,
Self::MtPressure => glue::ABS_MT_PRESSURE,
Self::MtDistance => glue::ABS_MT_DISTANCE,
Self::MtToolX => glue::ABS_MT_TOOL_X,
Self::MtToolY => glue::ABS_MT_TOOL_Y,
}; };
code as _ code as _
@ -135,10 +161,35 @@ pub struct AbsInfo {
pub resolution: i32, pub resolution: i32,
} }
#[derive(Serialize, Deserialize, Debug)] #[derive(Clone, Copy, Serialize, Deserialize, Debug)]
pub enum ToolType { pub enum ToolType {
Finger, Finger,
Pen, Pen,
Palm, Palm,
Dial, Dial,
} }
impl ToolType {
pub(crate) fn from_raw(value: i32) -> Option<Self> {
let value = match value as _ {
glue::MT_TOOL_FINGER => Self::Finger,
glue::MT_TOOL_PEN => Self::Pen,
glue::MT_TOOL_PALM => Self::Palm,
glue::MT_TOOL_DIAL => Self::Dial,
_ => return None,
};
Some(value)
}
pub(crate) fn to_raw(&self) -> i32 {
let value = match self {
Self::Finger => glue::MT_TOOL_FINGER,
Self::Pen => glue::MT_TOOL_PEN,
Self::Palm => glue::MT_TOOL_PALM,
Self::Dial => glue::MT_TOOL_DIAL,
};
value as _
}
}

View file

@ -1,15 +1,14 @@
use crate::abs::AbsEvent; use crate::abs::AbsEvent;
use crate::key::KeyEvent; use crate::key::KeyEvent;
use crate::rel::RelEvent; use crate::rel::RelEvent;
use crate::sync::SyncEvent;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use smallvec::SmallVec;
pub type Packet = SmallVec<[Event; 2]>;
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
pub enum Event { pub enum Event {
Rel(RelEvent), Rel(RelEvent),
Abs(AbsEvent), Abs(AbsEvent),
Key(KeyEvent), Key(KeyEvent),
Sync(SyncEvent),
} }

View file

@ -2,17 +2,18 @@ mod caps;
pub use caps::{AbsCaps, KeyCaps, RelCaps}; pub use caps::{AbsCaps, KeyCaps, RelCaps};
use crate::abs::{AbsAxis, AbsEvent}; use crate::abs::{AbsAxis, AbsEvent, ToolType};
use crate::event::{Event, Packet}; use crate::event::Event;
use crate::glue::{self, libevdev}; use crate::glue::{self, libevdev};
use crate::key::{Key, KeyEvent}; use crate::key::{Key, KeyEvent};
use crate::rel::{RelAxis, RelEvent}; use crate::rel::{RelAxis, RelEvent};
use crate::sync::SyncEvent;
use crate::writer::Writer; use crate::writer::Writer;
use std::collections::VecDeque;
use std::ffi::CStr; use std::ffi::CStr;
use std::fs::{File, OpenOptions}; use std::fs::{File, OpenOptions};
use std::io::{Error, ErrorKind}; use std::io::{Error, ErrorKind};
use std::mem;
use std::mem::MaybeUninit; use std::mem::MaybeUninit;
use std::os::fd::AsRawFd; use std::os::fd::AsRawFd;
use std::os::unix::prelude::OpenOptionsExt; use std::os::unix::prelude::OpenOptionsExt;
@ -27,113 +28,82 @@ pub struct Interceptor {
evdev: NonNull<libevdev>, evdev: NonNull<libevdev>,
writer: Writer, writer: Writer,
// The state of `read` is stored here to make it cancel safe. // The state of `read` is stored here to make it cancel safe.
events: Packet, events: VecDeque<Event>,
wrote: bool, writing: Option<(u16, u16, i32)>,
dropped: bool, dropped: bool,
writing: Option<Writing>,
} }
impl Interceptor { impl Interceptor {
pub async fn read(&mut self) -> Result<Packet, Error> { pub async fn read(&mut self) -> Result<Event, Error> {
if let Some(writing) = self.writing { if let Some((r#type, code, value)) = self.writing {
let (r#type, code, value) = match writing { log::trace!("Resuming interrupted write");
Writing::Event {
r#type,
code,
value,
} => (r#type, code, value),
Writing::Sync => (glue::EV_SYN as _, glue::SYN_REPORT as _, 0),
};
self.writer.write_raw(r#type, code, value).await?; self.writer.write_raw(r#type, code, value).await?;
self.writing = None; self.writing = None;
} }
loop { while !matches!(self.events.back(), Some(Event::Sync(SyncEvent::All))) {
loop { let (r#type, code, value) = self.read_raw().await?;
let (r#type, code, value) = self.read_raw().await?; let event = match r#type as _ {
let event = match r#type as _ { glue::EV_REL if !self.dropped => {
glue::EV_REL if !self.dropped => { RelAxis::from_raw(code).map(|axis| Event::Rel(RelEvent { axis, value }))
RelAxis::from_raw(code).map(|axis| Event::Rel(RelEvent { axis, value })) }
glue::EV_ABS if !self.dropped => match code as _ {
glue::ABS_MT_TOOL_TYPE => {
ToolType::from_raw(value).map(|value| AbsEvent::MtToolType { value })
} }
glue::EV_ABS if !self.dropped => { glue::ABS_MT_BLOB_ID => Some(AbsEvent::MtBlobId { value }),
AbsAxis::from_raw(code).map(|axis| Event::Abs(AbsEvent { axis, value })) _ => AbsAxis::from_raw(code).map(|axis| AbsEvent::Axis { axis, value }),
} }
glue::EV_KEY if !self.dropped && (value == 0 || value == 1) => { .map(Event::Abs),
Key::from_raw(code).map(|key| { glue::EV_KEY if !self.dropped && (value == 0 || value == 1) => Key::from_raw(code)
Event::Key(KeyEvent { .map(|key| {
key, Event::Key(KeyEvent {
down: value == 1, key,
}) down: value == 1,
}) })
} }),
glue::EV_SYN => match code as _ { glue::EV_SYN => match code as _ {
glue::SYN_REPORT => { glue::SYN_REPORT => {
if self.dropped { if self.dropped {
self.dropped = false; self.dropped = false;
continue;
}
break;
}
glue::SYN_DROPPED => {
log::warn!(
"Dropped {} event{}",
self.events.len(),
if self.events.len() == 1 { "" } else { "s" }
);
self.events.clear();
self.dropped = true;
continue; continue;
} }
_ => continue,
},
_ => None,
};
if let Some(event) = event { Some(Event::Sync(SyncEvent::All))
self.events.push(event); }
continue; glue::SYN_DROPPED => {
} log::warn!(
"Dropped {} event{}",
self.events.len(),
if self.events.len() == 1 { "" } else { "s" }
);
log::trace!( self.events.clear();
"Writing back unknown event (type {}, code {}, value {})", self.dropped = true;
r#type, continue;
code, }
value glue::SYN_MT_REPORT if !self.dropped => Some(Event::Sync(SyncEvent::Mt)),
); _ => continue,
},
_ => None,
};
self.writing = Some(Writing::Event { if let Some(event) = event {
r#type, self.events.push_back(event);
code, continue;
value,
});
self.writer.write_raw(r#type, code, value).await?;
self.writing = None;
self.wrote = true;
} }
// Write an EV_SYN only if we actually wrote something back. self.writing = Some((r#type, code, value));
if self.wrote { self.writer.write_raw(r#type, code, value).await?;
self.writing = Some(Writing::Sync); self.writing = None;
self.writer
.write_raw(glue::EV_SYN as _, glue::SYN_REPORT as _, 0)
.await?;
self.writing = None;
self.wrote = false;
}
if !self.events.is_empty() {
return Ok(mem::take(&mut self.events));
}
// At this point, we received an EV_SYN, but no actual events useful to us, so try again.
} }
Ok(self.events.pop_front().unwrap())
} }
pub async fn write(&mut self, events: &[Event]) -> Result<(), Error> { pub async fn write(&mut self, event: &Event) -> Result<(), Error> {
self.writer.write(events).await self.writer.write(event).await
} }
pub fn name(&self) -> &CStr { pub fn name(&self) -> &CStr {
@ -276,9 +246,7 @@ impl Interceptor {
file, file,
evdev, evdev,
writer, writer,
events: VecDeque::new(),
events: Packet::new(),
wrote: false,
dropped: false, dropped: false,
writing: None, writing: None,
}) })
@ -295,12 +263,6 @@ impl Drop for Interceptor {
unsafe impl Send for Interceptor {} unsafe impl Send for Interceptor {}
#[derive(Clone, Copy)]
enum Writing {
Event { r#type: u16, code: u16, value: i32 },
Sync,
}
#[derive(Error, Debug)] #[derive(Error, Debug)]
pub(crate) enum OpenError { pub(crate) enum OpenError {
#[error("Not appliable")] #[error("Not appliable")]

View file

@ -4,10 +4,7 @@ pub mod interceptor;
pub mod key; pub mod key;
pub mod monitor; pub mod monitor;
pub mod rel; pub mod rel;
pub mod sync;
pub mod writer; pub mod writer;
mod glue; mod glue;
pub use event::{Event, Packet};
pub use interceptor::Interceptor;
pub use monitor::Monitor;

20
rkvm-input/src/sync.rs Normal file
View file

@ -0,0 +1,20 @@
use crate::glue;
use serde::{Deserialize, Serialize};
#[derive(Clone, Copy, Debug, Serialize, Deserialize)]
pub enum SyncEvent {
All,
Mt,
}
impl SyncEvent {
pub(crate) fn to_raw(&self) -> u16 {
let code = match self {
Self::All => glue::SYN_REPORT,
Self::Mt => glue::SYN_MT_REPORT,
};
code as _
}
}

View file

@ -1,7 +1,7 @@
use crate::abs::{AbsAxis, AbsEvent, AbsInfo}; use crate::abs::{AbsAxis, AbsEvent, AbsInfo};
use crate::event::Event; use crate::event::Event;
use crate::glue::{self, input_absinfo, libevdev, libevdev_uinput}; use crate::glue::{self, input_absinfo, libevdev, libevdev_uinput};
use crate::key::{Key, KeyEvent, Keyboard}; use crate::key::{Key, KeyEvent};
use crate::rel::{RelAxis, RelEvent}; use crate::rel::{RelAxis, RelEvent};
use std::ffi::{CStr, OsStr}; use std::ffi::{CStr, OsStr};
@ -12,8 +12,8 @@ use std::os::fd::AsRawFd;
use std::os::unix::ffi::OsStrExt; use std::os::unix::ffi::OsStrExt;
use std::os::unix::prelude::OpenOptionsExt; use std::os::unix::prelude::OpenOptionsExt;
use std::path::Path; use std::path::Path;
use std::ptr;
use std::ptr::NonNull; use std::ptr::NonNull;
use std::{iter, ptr};
use tokio::io::unix::AsyncFd; use tokio::io::unix::AsyncFd;
pub struct Writer { pub struct Writer {
@ -26,20 +26,21 @@ impl Writer {
WriterBuilder::new() WriterBuilder::new()
} }
pub async fn write(&mut self, events: &[Event]) -> Result<(), Error> { pub async fn write(&mut self, event: &Event) -> Result<(), Error> {
let events = events let (r#type, code, value) = match event {
.iter() Event::Rel(RelEvent { axis, value }) => (glue::EV_REL, axis.to_raw(), *value),
.map(|event| match event { Event::Abs(event) => match event {
Event::Rel(RelEvent { axis, value }) => (glue::EV_REL, axis.to_raw(), *value), AbsEvent::Axis { axis, value } => (glue::EV_ABS, axis.to_raw(), *value),
Event::Abs(AbsEvent { axis, value }) => (glue::EV_ABS, axis.to_raw(), *value), AbsEvent::MtToolType { value } => {
Event::Key(KeyEvent { down, key }) => (glue::EV_KEY, key.to_raw(), *down as _), (glue::EV_ABS, glue::ABS_MT_TOOL_TYPE as _, value.to_raw())
}) }
.chain(iter::once((glue::EV_SYN, glue::SYN_REPORT as _, 0))); AbsEvent::MtBlobId { value } => (glue::EV_ABS, glue::ABS_MT_BLOB_ID as _, *value),
},
for (r#type, code, value) in events { Event::Key(KeyEvent { down, key }) => (glue::EV_KEY, key.to_raw(), *down as _),
self.write_raw(r#type as _, code, value).await?; Event::Sync(event) => (glue::EV_SYN, event.to_raw(), 0),
} };
self.write_raw(r#type as _, code, value).await?;
Ok(()) Ok(())
} }
@ -184,6 +185,19 @@ impl WriterBuilder {
&mut self, &mut self,
items: T, items: T,
) -> Result<&mut Self, Error> { ) -> Result<&mut Self, Error> {
let ret = unsafe {
glue::libevdev_enable_event_code(
self.evdev.as_ptr(),
glue::EV_SYN,
glue::SYN_MT_REPORT,
ptr::null(),
)
};
if ret < 0 {
return Err(Error::from_raw_os_error(-ret));
}
for (axis, info) in items { for (axis, info) in items {
let info = input_absinfo { let info = input_absinfo {
value: info.min, value: info.min,

View file

@ -1,7 +1,6 @@
use hmac::{Hmac, Mac}; use hmac::{Hmac, Mac};
use rand::rngs::OsRng; use rand::rngs::OsRng;
use rand::Error; use rand::{Error, Rng};
use rand::Rng;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use sha2::Sha256; use sha2::Sha256;
use tokio::task; use tokio::task;

View file

@ -3,9 +3,9 @@ pub mod message;
pub mod version; pub mod version;
use rkvm_input::abs::{AbsAxis, AbsInfo}; use rkvm_input::abs::{AbsAxis, AbsInfo};
use rkvm_input::event::Event;
use rkvm_input::key::Key; use rkvm_input::key::Key;
use rkvm_input::rel::RelAxis; use rkvm_input::rel::RelAxis;
use rkvm_input::Packet;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet}; use std::collections::{HashMap, HashSet};
use std::ffi::CString; use std::ffi::CString;
@ -21,13 +21,12 @@ pub enum Update {
rel: HashSet<RelAxis>, rel: HashSet<RelAxis>,
abs: HashMap<AbsAxis, AbsInfo>, abs: HashMap<AbsAxis, AbsInfo>,
keys: HashSet<Key>, keys: HashSet<Key>,
// Multitouch events intentionally omitted.
}, },
DestroyDevice { DestroyDevice {
id: usize, id: usize,
}, },
EventBatch { Event {
id: usize, id: usize,
events: Packet, event: Event,
}, },
} }

View file

@ -9,9 +9,7 @@ use std::future;
use std::path::PathBuf; use std::path::PathBuf;
use std::process::ExitCode; use std::process::ExitCode;
use std::time::Duration; use std::time::Duration;
use tokio::fs; use tokio::{fs, signal, time};
use tokio::signal;
use tokio::time;
#[derive(Parser)] #[derive(Parser)]
#[structopt(name = "rkvm-server", about = "The rkvm server application")] #[structopt(name = "rkvm-server", about = "The rkvm server application")]

View file

@ -1,13 +1,14 @@
use rkvm_input::abs::{AbsAxis, AbsInfo}; use rkvm_input::abs::{AbsAxis, AbsInfo};
use rkvm_input::event::Event;
use rkvm_input::key::{Key, KeyEvent, Keyboard}; use rkvm_input::key::{Key, KeyEvent, Keyboard};
use rkvm_input::monitor::Monitor;
use rkvm_input::rel::RelAxis; use rkvm_input::rel::RelAxis;
use rkvm_input::{Event, Interceptor, Monitor, Packet};
use rkvm_net::auth::{AuthChallenge, AuthResponse, AuthStatus}; use rkvm_net::auth::{AuthChallenge, AuthResponse, AuthStatus};
use rkvm_net::message::Message; use rkvm_net::message::Message;
use rkvm_net::version::Version; use rkvm_net::version::Version;
use rkvm_net::Update; use rkvm_net::Update;
use slab::Slab; use slab::Slab;
use std::collections::{HashMap, HashSet}; use std::collections::{HashMap, HashSet, VecDeque};
use std::ffi::CString; use std::ffi::CString;
use std::io::{self, ErrorKind}; use std::io::{self, ErrorKind};
use std::net::SocketAddr; use std::net::SocketAddr;
@ -15,6 +16,7 @@ use std::time::Duration;
use thiserror::Error; use thiserror::Error;
use tokio::io::{AsyncWriteExt, BufStream}; use tokio::io::{AsyncWriteExt, BufStream};
use tokio::net::{TcpListener, TcpStream}; use tokio::net::{TcpListener, TcpStream};
use tokio::sync::mpsc::error::TrySendError;
use tokio::sync::mpsc::{self, Receiver, Sender}; use tokio::sync::mpsc::{self, Receiver, Sender};
use tokio::time; use tokio::time;
use tokio_rustls::TlsAcceptor; use tokio_rustls::TlsAcceptor;
@ -25,6 +27,8 @@ pub enum Error {
Network(io::Error), Network(io::Error),
#[error("Input error: {0}")] #[error("Input error: {0}")]
Input(io::Error), Input(io::Error),
#[error("Event queue overflow")]
Overflow,
} }
pub async fn run( pub async fn run(
@ -59,9 +63,9 @@ pub async fn run(
current = 0; current = 0;
} }
let (sender, receiver) = mpsc::channel(devices.len()); let init_updates = devices
for (id, device) in &devices { .iter()
let update = Update::CreateDevice { .map(|(id, device)| Update::CreateDevice {
id, id,
name: device.name.clone(), name: device.name.clone(),
version: device.version, version: device.version,
@ -70,17 +74,16 @@ pub async fn run(
rel: device.rel.clone(), rel: device.rel.clone(),
abs: device.abs.clone(), abs: device.abs.clone(),
keys: device.keys.clone(), keys: device.keys.clone(),
}; })
.collect();
sender.try_send(update).unwrap();
}
let (sender, receiver) = mpsc::channel(1);
clients.insert(sender); clients.insert(sender);
tokio::spawn(async move { tokio::spawn(async move {
log::info!("{}: Connected", addr); log::info!("{}: Connected", addr);
match client(receiver, stream, addr, acceptor, &password).await { match client(init_updates, receiver, stream, addr, acceptor, &password).await {
Ok(()) => log::info!("{}: Disconnected", addr), Ok(()) => log::info!("{}: Disconnected", addr),
Err(err) => log::error!("{}: Disconnected: {}", addr, err), Err(err) => log::error!("{}: Disconnected: {}", addr, err),
} }
@ -113,7 +116,7 @@ pub async fn run(
let _ = sender.send(update).await; let _ = sender.send(update).await;
} }
let (interceptor_sender, mut interceptor_receiver) = mpsc::channel::<Packet>(1); let (interceptor_sender, mut interceptor_receiver) = mpsc::channel(32);
devices.insert(Device { devices.insert(Device {
name, name,
version, version,
@ -129,18 +132,18 @@ pub async fn run(
tokio::spawn(async move { tokio::spawn(async move {
loop { loop {
tokio::select! { tokio::select! {
events = interceptor.read() => { event = interceptor.read() => {
if events.is_err() | events_sender.send((id, events)).await.is_err() { if event.is_err() | events_sender.send((id, event)).await.is_err() {
break; break;
} }
} }
events = interceptor_receiver.recv() => { event = interceptor_receiver.recv() => {
let events = match events { let event = match event {
Some(events) => events, Some(event) => event,
None => break, None => break,
}; };
match interceptor.write(&events).await { match interceptor.write(&event).await {
Ok(()) => {}, Ok(()) => {},
Err(err) => { Err(err) => {
let _ = events_sender.send((id, Err(err))).await; let _ = events_sender.send((id, Err(err))).await;
@ -148,11 +151,7 @@ pub async fn run(
} }
} }
log::trace!( log::trace!("Wrote an event to device {}", id);
"Wrote {} event{}",
events.len(),
if events.len() == 1 { "" } else { "s" }
);
} }
} }
} }
@ -170,25 +169,18 @@ pub async fn run(
); );
} }
(id, result) = event => match result { (id, result) = event => match result {
Ok(events) => { Ok(event) => {
let mut changed = false; let mut changed = false;
for event in &events { if let Event::Key(KeyEvent { key: Key::Key(key), down }) = event {
let (key, down) = match event { if switch_keys.contains(&key) {
Event::Key(KeyEvent { key: Key::Key(key), down }) => (key, down), changed = true;
_ => continue,
};
if !switch_keys.contains(key) { match down {
continue; true => pressed_keys.insert(key),
false => pressed_keys.remove(&key),
};
} }
changed = true;
match down {
true => pressed_keys.insert(*key),
false => pressed_keys.remove(key),
};
} }
// Who to send this batch of events to. // Who to send this batch of events to.
@ -202,14 +194,23 @@ pub async fn run(
break; break;
} }
} }
log::debug!("Switched to client {}", current);
} }
// Index 0 - special case to keep the modular arithmetic above working.
if idx == 0 { if idx == 0 {
let _ = devices[id].sender.send(events).await; // We do a try_send() here rather than a "blocking" send in order to prevent deadlocks.
continue; // In this scenario, the interceptor task is sending events to the main task,
// while the main task is simultaneously sending events back to the interceptor.
// This creates a classic deadlock situation where both tasks are waiting for each other.
match devices[id].sender.try_send(event) {
Ok(()) | Err(TrySendError::Closed(_)) => continue,
Err(TrySendError::Full(_)) => return Err(Error::Overflow),
}
} }
if clients[idx - 1].send(Update::EventBatch { id, events }).await.is_err() { if clients[idx - 1].send(Update::Event { id, event }).await.is_err() {
clients.remove(idx - 1); clients.remove(idx - 1);
if current == idx { if current == idx {
@ -239,7 +240,7 @@ struct Device {
rel: HashSet<RelAxis>, rel: HashSet<RelAxis>,
abs: HashMap<AbsAxis, AbsInfo>, abs: HashMap<AbsAxis, AbsInfo>,
keys: HashSet<Key>, keys: HashSet<Key>,
sender: Sender<Packet>, sender: Sender<Event>,
} }
#[derive(Error, Debug)] #[derive(Error, Debug)]
@ -255,6 +256,7 @@ enum ClientError {
} }
async fn client( async fn client(
mut init_updates: VecDeque<Update>,
mut receiver: Receiver<Update>, mut receiver: Receiver<Update>,
stream: TcpStream, stream: TcpStream,
addr: SocketAddr, addr: SocketAddr,
@ -308,9 +310,32 @@ async fn client(
.await .await
.map_err(|_| io::Error::new(ErrorKind::TimedOut, "Negotiation took too long"))??; .map_err(|_| io::Error::new(ErrorKind::TimedOut, "Negotiation took too long"))??;
while let Some(update) = receiver.recv().await { loop {
let recv = async {
match init_updates.pop_front() {
Some(update) => Some((update, false)),
None => receiver.recv().await.map(|update| (update, true)),
}
};
let (update, more) = match recv.await {
Some(update) => update,
None => break,
};
let mut count = 1;
let write = async { let write = async {
update.encode(&mut stream).await?; update.encode(&mut stream).await?;
// Coalesce multiple updates into one chunk.
if more {
while let Ok(update) = receiver.try_recv() {
update.encode(&mut stream).await?;
count += 1;
}
}
stream.flush().await stream.flush().await
}; };
@ -318,7 +343,12 @@ async fn client(
.await .await
.map_err(|_| io::Error::new(ErrorKind::TimedOut, "Update writing took too long"))??; .map_err(|_| io::Error::new(ErrorKind::TimedOut, "Update writing took too long"))??;
log::trace!("{}: Wrote an update", addr); log::trace!(
"{}: Wrote {} update{}",
addr,
count,
if count == 1 { "" } else { "s" }
);
} }
Ok(()) Ok(())