From 1539f73e459771a9cf2d10c5597ff52bfa4dbdaf Mon Sep 17 00:00:00 2001 From: Ottatop Date: Wed, 21 Feb 2024 19:40:11 -0600 Subject: [PATCH] Get a signal to work --- api/lua/pinnacle-api-dev-1.rockspec | 1 + api/lua/pinnacle/grpc/client.lua | 70 ++++++++--- api/lua/pinnacle/grpc/protobuf.lua | 22 ++++ api/lua/pinnacle/signal.lua | 177 ++++++++++++++++++++++++++++ api/lua/pinnacle/tag.lua | 5 + api/lua/pinnacle/window.lua | 2 + src/api.rs | 32 +++-- src/api/signal.rs | 23 +++- 8 files changed, 296 insertions(+), 36 deletions(-) create mode 100644 api/lua/pinnacle/signal.lua diff --git a/api/lua/pinnacle-api-dev-1.rockspec b/api/lua/pinnacle-api-dev-1.rockspec index 3a727d9..13e56f9 100644 --- a/api/lua/pinnacle-api-dev-1.rockspec +++ b/api/lua/pinnacle-api-dev-1.rockspec @@ -26,5 +26,6 @@ build = { ["pinnacle.tag"] = "pinnacle/tag.lua", ["pinnacle.window"] = "pinnacle/window.lua", ["pinnacle.util"] = "pinnacle/util.lua", + ["pinnacle.signal"] = "pinnacle/signal.lua", }, } diff --git a/api/lua/pinnacle/grpc/client.lua b/api/lua/pinnacle/grpc/client.lua index eb62da5..7cbaacd 100644 --- a/api/lua/pinnacle/grpc/client.lua +++ b/api/lua/pinnacle/grpc/client.lua @@ -5,6 +5,7 @@ local socket = require("cqueues.socket") local headers = require("http.headers") local h2_connection = require("http.h2_connection") +local protobuf = require("pinnacle.grpc.protobuf") local pb = require("pb") ---@nodoc @@ -40,6 +41,8 @@ end ---@class H2Connection ---@field new_stream function +---@class H2Stream + ---@nodoc ---@class Client ---@field conn H2Connection @@ -76,12 +79,7 @@ function client.unary_request(grpc_request_params) local response_type = grpc_request_params.response_type or "google.protobuf.Empty" local data = grpc_request_params.data - local encoded_protobuf = assert(pb.encode(request_type, data), "wrong table schema") - - local packed_prefix = string.pack("I1", 0) - local payload_len = string.pack(">I4", encoded_protobuf:len()) - - local body = packed_prefix .. payload_len .. encoded_protobuf + local body = protobuf.encode(request_type, data) stream:write_headers(create_request_headers(service, method), false) stream:write_chunk(body, true) @@ -126,18 +124,7 @@ function client.server_streaming_request(grpc_request_params, callback) local response_type = grpc_request_params.response_type or "google.protobuf.Empty" local data = grpc_request_params.data - local success, obj = pcall(pb.encode, request_type, data) - if not success then - print("failed to encode:", obj, "for", service, method, request_type, response_type) - os.exit(1) - end - - local encoded_protobuf = obj - - local packed_prefix = string.pack("I1", 0) - local payload_len = string.pack(">I4", encoded_protobuf:len()) - - local body = packed_prefix .. payload_len .. encoded_protobuf + local body = protobuf.encode(request_type, data) stream:write_headers(create_request_headers(service, method), false) stream:write_chunk(body, true) @@ -175,6 +162,51 @@ end ---@param callback fun(response: table) --- ---@return H2Stream -function client.bidirectional_streaming_request(grpc_request_params, callback) end +function client.bidirectional_streaming_request(grpc_request_params, callback) + local stream = client.conn:new_stream() + + local service = grpc_request_params.service + local method = grpc_request_params.method + local request_type = grpc_request_params.request_type + local response_type = grpc_request_params.response_type or "google.protobuf.Empty" + local data = grpc_request_params.data + + local body = protobuf.encode(request_type, data) + + stream:write_headers(create_request_headers(service, method), false) + stream:write_chunk(body, false) + + -- TODO: check response headers for errors + local _ = stream:get_headers() + + client.loop:wrap(function() + for response_body in stream:each_chunk() do + -- Skip the 1-byte compressed flag and the 4-byte message length + ---@diagnostic disable-next-line: redefined-local + local response_body = response_body:sub(6) + + ---@diagnostic disable-next-line: redefined-local + local success, obj = pcall(pb.decode, response_type, response_body) + if not success then + print(obj) + os.exit(1) + end + + local response = obj + callback(response) + end + + local trailers = stream:get_headers() + if trailers then + for name, value, never_index in trailers:each() do + print(name, value, never_index) + end + end + + print("AFTER bidirectional_streaming_request ENDS") + end) + + return stream +end return client diff --git a/api/lua/pinnacle/grpc/protobuf.lua b/api/lua/pinnacle/grpc/protobuf.lua index 187888d..7c2c359 100644 --- a/api/lua/pinnacle/grpc/protobuf.lua +++ b/api/lua/pinnacle/grpc/protobuf.lua @@ -17,6 +17,7 @@ function protobuf.build_protos() PINNACLE_PROTO_DIR .. "/pinnacle/output/" .. version .. "/output.proto", PINNACLE_PROTO_DIR .. "/pinnacle/process/" .. version .. "/process.proto", PINNACLE_PROTO_DIR .. "/pinnacle/window/" .. version .. "/window.proto", + PINNACLE_PROTO_DIR .. "/pinnacle/signal/" .. version .. "/signal.proto", } local cmd = "protoc --descriptor_set_out=/tmp/pinnacle.pb --proto_path=" .. PINNACLE_PROTO_DIR .. " " @@ -38,4 +39,25 @@ function protobuf.build_protos() pb.option("enum_as_value") end +---Encode the given `data` as the protobuf `type`. +---@param type string The absolute protobuf type +---@param data table The table of data, conforming to its protobuf definition +---@return string buffer The encoded buffer +function protobuf.encode(type, data) + local success, obj = pcall(pb.encode, type, data) + if not success then + print("failed to encode:", obj, "type:", type) + os.exit(1) + end + + local encoded_protobuf = obj + + local packed_prefix = string.pack("I1", 0) + local payload_len = string.pack(">I4", encoded_protobuf:len()) + + local body = packed_prefix .. payload_len .. encoded_protobuf + + return body +end + return protobuf diff --git a/api/lua/pinnacle/signal.lua b/api/lua/pinnacle/signal.lua new file mode 100644 index 0000000..4783064 --- /dev/null +++ b/api/lua/pinnacle/signal.lua @@ -0,0 +1,177 @@ +-- This Source Code Form is subject to the terms of the Mozilla Public +-- License, v. 2.0. If a copy of the MPL was not distributed with this +-- file, You can obtain one at https://mozilla.org/MPL/2.0/. + +local client = require("pinnacle.grpc.client") + +---The protobuf absolute path prefix +local prefix = "pinnacle.signal." .. client.version .. "." +local service = prefix .. "SignalService" + +---@type table +---@enum (key) SignalServiceMethod +local rpc_types = { + OutputConnect = { + response_type = "OutputConnectResponse", + }, + Layout = { + response_type = "LayoutResponse", + }, + WindowPointerEnter = { + response_type = "WindowPointerEnterResponse", + }, + WindowPointerLeave = { + response_type = "WindowPointerLeaveResponse", + }, +} + +---Build GrpcRequestParams +---@param method SignalServiceMethod +---@param data table +---@return GrpcRequestParams +local function build_grpc_request_params(method, data) + local req_type = rpc_types[method].request_type + local resp_type = rpc_types[method].response_type + + ---@type GrpcRequestParams + return { + service = service, + method = method, + request_type = req_type and prefix .. req_type or prefix .. method .. "Request", + response_type = resp_type and prefix .. resp_type, + data = data, + } +end + +local stream_control = { + UNSPECIFIED = 0, + READY = 1, + DISCONNECT = 2, +} + +local signals = { + output_connect = { + ---@type H2Stream? + sender = nil, + ---@type (fun(output: OutputHandle))[] + callbacks = {}, + }, + layout = { + ---@type H2Stream? + sender = nil, + ---@type (fun(windows: WindowHandle[], tag: TagHandle))[] + callbacks = {}, + }, + window_pointer_enter = { + ---@type H2Stream? + sender = nil, + ---@type (fun(output: OutputHandle))[] + callbacks = {}, + }, + window_pointer_leave = { + ---@type H2Stream? + sender = nil, + ---@type (fun(output: OutputHandle))[] + callbacks = {}, + }, +} + +---@class Signal +local signal = {} + +---@param fn fun(windows: WindowHandle[], tag: TagHandle) +function signal.layout_add(fn) + if #signals.layout.callbacks == 0 then + signal.layout_connect() + end + + table.insert(signals.layout.callbacks, fn) +end + +function signal.layout_dc() + signal.layout_disconnect() +end + +function signal.output_connect_connect() + local stream = client.bidirectional_streaming_request( + build_grpc_request_params("OutputConnect", { + control = stream_control.READY, + }), + function(response) + ---@diagnostic disable-next-line: invisible + local handle = require("pinnacle.output").handle.new(response.output_name) + for _, callback in ipairs(signals.output_connect.callbacks) do + callback(handle) + end + + local chunk = require("pinnacle.grpc.protobuf").encode(prefix .. "OutputConnectRequest", { + control = stream_control.READY, + }) + + if signals.layout.sender then + signals.layout.sender:write_chunk(chunk) + end + end + ) + + signals.output_connect.sender = stream +end + +function signal.output_connect_disconnect() + if signals.output_connect.sender then + local chunk = require("pinnacle.grpc.protobuf").encode(prefix .. "OutputConnectRequest", { + control = stream_control.DISCONNECT, + }) + + signals.output_connect.sender:write_chunk(chunk) + signals.output_connect.sender = nil + end +end + +function signal.layout_connect() + local stream = client.bidirectional_streaming_request( + build_grpc_request_params("Layout", { + control = stream_control.READY, + }), + function(response) + ---@diagnostic disable-next-line: invisible + local window_handles = require("pinnacle.window").handle.new_from_table(response.window_ids or {}) + ---@diagnostic disable-next-line: invisible + local tag_handle = require("pinnacle.tag").handle.new(response.tag_id) + + for _, callback in ipairs(signals.layout.callbacks) do + print("calling layout callback") + callback(window_handles, tag_handle) + end + + print("creating control request") + local chunk = require("pinnacle.grpc.protobuf").encode(prefix .. "LayoutRequest", { + control = stream_control.READY, + }) + + if signals.layout.sender then + local success, err = pcall(signals.layout.sender.write_chunk, signals.layout.sender, chunk) + if not success then + print("error sending to stream:", err) + os.exit(1) + end + end + end + ) + + signals.layout.sender = stream +end + +function signal.layout_disconnect() + if signals.layout.sender then + local chunk = require("pinnacle.grpc.protobuf").encode(prefix .. "LayoutRequest", { + control = stream_control.DISCONNECT, + }) + + signals.layout.sender:write_chunk(chunk) + signals.layout.sender = nil + end + signals.layout.callbacks = {} +end + +return signal diff --git a/api/lua/pinnacle/tag.lua b/api/lua/pinnacle/tag.lua index 7497389..e1b61d2 100644 --- a/api/lua/pinnacle/tag.lua +++ b/api/lua/pinnacle/tag.lua @@ -319,6 +319,11 @@ function tag.new_layout_cycler(layouts) } end +---@param fn fun(windows: WindowHandle[], tag: TagHandle) +function tag.connect_layout(fn) + require("pinnacle.signal").layout_add(fn) +end + ---Remove this tag. --- ---### Example diff --git a/api/lua/pinnacle/window.lua b/api/lua/pinnacle/window.lua index 7e287fd..7a7668c 100644 --- a/api/lua/pinnacle/window.lua +++ b/api/lua/pinnacle/window.lua @@ -69,7 +69,9 @@ local WindowHandle = {} ---This module helps you deal with setting windows to fullscreen and maximized, setting their size, ---moving them between tags, and various other actions. ---@class Window +---@field private handle WindowHandleModule local window = {} +window.handle = window_handle ---Get all windows. --- diff --git a/src/api.rs b/src/api.rs index 5e11e99..92887b2 100644 --- a/src/api.rs +++ b/src/api.rs @@ -46,6 +46,7 @@ use sysinfo::ProcessRefreshKind; use tokio::{ io::AsyncBufReadExt, sync::mpsc::{unbounded_channel, UnboundedSender}, + task::JoinHandle, }; use tokio_stream::{Stream, StreamExt}; use tonic::{Request, Response, Status, Streaming}; @@ -60,8 +61,6 @@ use crate::{ window::{window_state::WindowId, WindowElement}, }; -use self::signal::SignalData; - type ResponseStream = Pin> + Send>>; pub type StateFnSender = calloop::channel::Sender>; @@ -134,31 +133,33 @@ fn run_bidirectional_streaming( ) -> Result>, Status> where F1: Fn(&mut State, Result) + Clone + Send + 'static, - F2: FnOnce(&mut State, UnboundedSender>) + Send + 'static, + F2: FnOnce(&mut State, UnboundedSender>, JoinHandle<()>) + Send + 'static, I: Send + 'static, O: Send + 'static, { let (sender, receiver) = unbounded_channel::>(); - let with_out_stream = Box::new(|state: &mut State| { - with_out_stream(state, sender); - }); - - fn_sender - .send(with_out_stream) - .map_err(|_| Status::internal("failed to execute request"))?; + let fn_sender_clone = fn_sender.clone(); let with_in_stream = async move { while let Some(t) = in_stream.next().await { let with_client_item = with_client_item.clone(); // TODO: handle error - let _ = fn_sender.send(Box::new(move |state: &mut State| { + let _ = fn_sender_clone.send(Box::new(move |state: &mut State| { with_client_item(state, t); })); } }; - tokio::spawn(with_in_stream); + let join_handle = tokio::spawn(with_in_stream); + + let with_out_stream = Box::new(|state: &mut State| { + with_out_stream(state, sender, join_handle); + }); + + fn_sender + .send(with_out_stream) + .map_err(|_| Status::internal("failed to execute request"))?; let receiver_stream = tokio_stream::wrappers::UnboundedReceiverStream::new(receiver); Ok(Response::new(Box::pin(receiver_stream))) @@ -732,6 +733,13 @@ impl tag_service_server::TagService for TagService { state.update_windows(&output); state.update_focus(&output); state.schedule_render(&output); + + state.signal_state.layout.signal(|_| { + pinnacle_api_defs::pinnacle::signal::v0alpha1::LayoutResponse { + window_ids: vec![1, 2, 3], + tag_id: Some(1), + } + }); }) .await } diff --git a/src/api/signal.rs b/src/api/signal.rs index 5b10c74..452982c 100644 --- a/src/api/signal.rs +++ b/src/api/signal.rs @@ -3,7 +3,7 @@ use pinnacle_api_defs::pinnacle::signal::v0alpha1::{ OutputConnectResponse, StreamControl, WindowPointerEnterRequest, WindowPointerEnterResponse, WindowPointerLeaveRequest, WindowPointerLeaveResponse, }; -use tokio::sync::mpsc::UnboundedSender; +use tokio::{sync::mpsc::UnboundedSender, task::JoinHandle}; use tonic::{Request, Response, Status, Streaming}; use crate::state::State; @@ -21,6 +21,7 @@ pub struct SignalState { #[derive(Debug, Default)] pub struct SignalData { sender: Option>>, + join_handle: Option>, ready: bool, value: Option, } @@ -41,12 +42,22 @@ impl SignalData { } } - pub fn connect(&mut self, sender: UnboundedSender>) { + pub fn connect( + &mut self, + sender: UnboundedSender>, + join_handle: JoinHandle<()>, + ) { self.sender.replace(sender); + if let Some(handle) = self.join_handle.replace(join_handle) { + handle.abort(); + } } fn disconnect(&mut self) { self.sender.take(); + if let Some(handle) = self.join_handle.take() { + handle.abort(); + } self.ready = false; self.value.take(); } @@ -88,7 +99,7 @@ impl_signal_request!( WindowPointerLeaveRequest ); -fn start_signal_stream( +fn start_signal_stream( sender: StateFnSender, in_stream: Streaming, signal: impl Fn(&mut State) -> &mut SignalData + Clone + Send + 'static, @@ -111,6 +122,8 @@ where } }; + tracing::info!("GOT {request:?} FROM CLIENT STREAM"); + let signal = signal(state); match request.control() { StreamControl::Ready => signal.ready(), @@ -118,9 +131,9 @@ where StreamControl::Unspecified => tracing::warn!("Received unspecified stream control"), } }, - move |state, sender| { + move |state, sender, join_handle| { let signal = signal_clone(state); - signal.connect(sender); + signal.connect(sender, join_handle); }, ) }