From fe686af815d6902dc10af0725583b878da321db3 Mon Sep 17 00:00:00 2001 From: Ottatop Date: Mon, 7 Aug 2023 11:53:56 -0500 Subject: [PATCH] Add SOCKET_DIR env var --- api/lua/pinnacle.lua | 24 +++++++++++++++++++----- src/api.rs | 23 ++++++++++++++++++----- 2 files changed, 37 insertions(+), 10 deletions(-) diff --git a/api/lua/pinnacle.lua b/api/lua/pinnacle.lua index 914f51c..e6095b7 100644 --- a/api/lua/pinnacle.lua +++ b/api/lua/pinnacle.lua @@ -3,7 +3,13 @@ local socket = require("posix.sys.socket") local msgpack = require("msgpack") -local SOCKET_PATH = "/tmp/pinnacle_socket" +local socket_dir = os.getenv("SOCKET_DIR") +if socket_dir then + if socket_dir:match("/$") then + socket_dir = socket_dir:sub(0, socket_dir:len() - 1) + end +end +local SOCKET_PATH = (socket_dir or "/tmp") .. "/pinnacle_socket" ---From https://gist.github.com/stuby/5445834#file-rprint-lua ---rPrint(struct, [limit], [indent]) Recursively print arbitrary data. @@ -98,7 +104,10 @@ end ---@param config_func fun(pinnacle: Pinnacle) function pinnacle.setup(config_func) ---@type integer - local socket_fd = assert(socket.socket(socket.AF_UNIX, socket.SOCK_STREAM, 0), "Failed to create socket") + local socket_fd = assert( + socket.socket(socket.AF_UNIX, socket.SOCK_STREAM, 0), + "Failed to create socket" + ) print("created socket at fd " .. socket_fd) assert(0 == socket.connect(socket_fd, { @@ -184,7 +193,8 @@ function pinnacle.setup(config_func) if inc_msg.CallCallback then unread_cb_msgs[inc_msg.CallCallback.callback_id] = inc_msg elseif inc_msg.RequestResponse.request_id ~= req_id then - unread_req_msgs[inc_msg.RequestResponse.request_id] = inc_msg + unread_req_msgs[inc_msg.RequestResponse.request_id] = + inc_msg else return inc_msg end @@ -198,7 +208,9 @@ function pinnacle.setup(config_func) while true do for cb_id, inc_msg in pairs(unread_cb_msgs) do - CallbackTable[inc_msg.CallCallback.callback_id](inc_msg.CallCallback.args) + CallbackTable[inc_msg.CallCallback.callback_id]( + inc_msg.CallCallback.args + ) unread_cb_msgs[cb_id] = nil -- INFO: does this shift the table and frick everything up? end @@ -208,7 +220,9 @@ function pinnacle.setup(config_func) if inc_msg.CallCallback and inc_msg.CallCallback.callback_id then if inc_msg.CallCallback.args then -- TODO: can just inline - CallbackTable[inc_msg.CallCallback.callback_id](inc_msg.CallCallback.args) + CallbackTable[inc_msg.CallCallback.callback_id]( + inc_msg.CallCallback.args + ) else CallbackTable[inc_msg.CallCallback.callback_id](nil) end diff --git a/src/api.rs b/src/api.rs index 77d03d3..2653973 100644 --- a/src/api.rs +++ b/src/api.rs @@ -48,7 +48,7 @@ use smithay::reexports::calloop::{ use self::msg::{Msg, OutgoingMsg}; -const SOCKET_PATH: &str = "/tmp/pinnacle_socket"; +const DEFAULT_SOCKET_DIR: &str = "/tmp"; fn handle_client( mut stream: UnixStream, @@ -87,21 +87,34 @@ pub struct PinnacleSocketSource { impl PinnacleSocketSource { pub fn new(sender: Sender) -> Result { - let socket_path = Path::new(SOCKET_PATH); + let socket_path = std::env::var("SOCKET_DIR").unwrap_or(DEFAULT_SOCKET_DIR.to_string()); + let socket_path = Path::new(&socket_path); + if !socket_path.is_dir() { + tracing::error!("SOCKET_DIR must be a directory"); + return Err(io::Error::new( + io::ErrorKind::Other, + "SOCKET_DIR must be a directory", + )); + } + + let socket_path = socket_path.join("pinnacle_socket"); // TODO: use anyhow if let Ok(exists) = socket_path.try_exists() { if exists { - if let Err(err) = std::fs::remove_file(socket_path) { + if let Err(err) = std::fs::remove_file(&socket_path) { tracing::error!("Failed to remove old socket: {err}"); return Err(err); } } } - let listener = match UnixListener::bind(socket_path) { - Ok(listener) => listener, + let listener = match UnixListener::bind(&socket_path) { + Ok(listener) => { + tracing::info!("Bound to socket at {socket_path:?}"); + listener + } Err(err) => { tracing::error!("Failed to bind to socket: {err}"); return Err(err);