Add SOCKET_DIR env var

This commit is contained in:
Ottatop 2023-08-07 11:53:56 -05:00
parent edba4d2424
commit fe686af815
2 changed files with 37 additions and 10 deletions

View file

@ -3,7 +3,13 @@
local socket = require("posix.sys.socket") local socket = require("posix.sys.socket")
local msgpack = require("msgpack") local msgpack = require("msgpack")
local SOCKET_PATH = "/tmp/pinnacle_socket" local socket_dir = os.getenv("SOCKET_DIR")
if socket_dir then
if socket_dir:match("/$") then
socket_dir = socket_dir:sub(0, socket_dir:len() - 1)
end
end
local SOCKET_PATH = (socket_dir or "/tmp") .. "/pinnacle_socket"
---From https://gist.github.com/stuby/5445834#file-rprint-lua ---From https://gist.github.com/stuby/5445834#file-rprint-lua
---rPrint(struct, [limit], [indent]) Recursively print arbitrary data. ---rPrint(struct, [limit], [indent]) Recursively print arbitrary data.
@ -98,7 +104,10 @@ end
---@param config_func fun(pinnacle: Pinnacle) ---@param config_func fun(pinnacle: Pinnacle)
function pinnacle.setup(config_func) function pinnacle.setup(config_func)
---@type integer ---@type integer
local socket_fd = assert(socket.socket(socket.AF_UNIX, socket.SOCK_STREAM, 0), "Failed to create socket") local socket_fd = assert(
socket.socket(socket.AF_UNIX, socket.SOCK_STREAM, 0),
"Failed to create socket"
)
print("created socket at fd " .. socket_fd) print("created socket at fd " .. socket_fd)
assert(0 == socket.connect(socket_fd, { assert(0 == socket.connect(socket_fd, {
@ -184,7 +193,8 @@ function pinnacle.setup(config_func)
if inc_msg.CallCallback then if inc_msg.CallCallback then
unread_cb_msgs[inc_msg.CallCallback.callback_id] = inc_msg unread_cb_msgs[inc_msg.CallCallback.callback_id] = inc_msg
elseif inc_msg.RequestResponse.request_id ~= req_id then elseif inc_msg.RequestResponse.request_id ~= req_id then
unread_req_msgs[inc_msg.RequestResponse.request_id] = inc_msg unread_req_msgs[inc_msg.RequestResponse.request_id] =
inc_msg
else else
return inc_msg return inc_msg
end end
@ -198,7 +208,9 @@ function pinnacle.setup(config_func)
while true do while true do
for cb_id, inc_msg in pairs(unread_cb_msgs) do for cb_id, inc_msg in pairs(unread_cb_msgs) do
CallbackTable[inc_msg.CallCallback.callback_id](inc_msg.CallCallback.args) CallbackTable[inc_msg.CallCallback.callback_id](
inc_msg.CallCallback.args
)
unread_cb_msgs[cb_id] = nil -- INFO: does this shift the table and frick everything up? unread_cb_msgs[cb_id] = nil -- INFO: does this shift the table and frick everything up?
end end
@ -208,7 +220,9 @@ function pinnacle.setup(config_func)
if inc_msg.CallCallback and inc_msg.CallCallback.callback_id then if inc_msg.CallCallback and inc_msg.CallCallback.callback_id then
if inc_msg.CallCallback.args then -- TODO: can just inline if inc_msg.CallCallback.args then -- TODO: can just inline
CallbackTable[inc_msg.CallCallback.callback_id](inc_msg.CallCallback.args) CallbackTable[inc_msg.CallCallback.callback_id](
inc_msg.CallCallback.args
)
else else
CallbackTable[inc_msg.CallCallback.callback_id](nil) CallbackTable[inc_msg.CallCallback.callback_id](nil)
end end

View file

@ -48,7 +48,7 @@ use smithay::reexports::calloop::{
use self::msg::{Msg, OutgoingMsg}; use self::msg::{Msg, OutgoingMsg};
const SOCKET_PATH: &str = "/tmp/pinnacle_socket"; const DEFAULT_SOCKET_DIR: &str = "/tmp";
fn handle_client( fn handle_client(
mut stream: UnixStream, mut stream: UnixStream,
@ -87,21 +87,34 @@ pub struct PinnacleSocketSource {
impl PinnacleSocketSource { impl PinnacleSocketSource {
pub fn new(sender: Sender<Msg>) -> Result<Self, io::Error> { pub fn new(sender: Sender<Msg>) -> Result<Self, io::Error> {
let socket_path = Path::new(SOCKET_PATH); let socket_path = std::env::var("SOCKET_DIR").unwrap_or(DEFAULT_SOCKET_DIR.to_string());
let socket_path = Path::new(&socket_path);
if !socket_path.is_dir() {
tracing::error!("SOCKET_DIR must be a directory");
return Err(io::Error::new(
io::ErrorKind::Other,
"SOCKET_DIR must be a directory",
));
}
let socket_path = socket_path.join("pinnacle_socket");
// TODO: use anyhow // TODO: use anyhow
if let Ok(exists) = socket_path.try_exists() { if let Ok(exists) = socket_path.try_exists() {
if exists { if exists {
if let Err(err) = std::fs::remove_file(socket_path) { if let Err(err) = std::fs::remove_file(&socket_path) {
tracing::error!("Failed to remove old socket: {err}"); tracing::error!("Failed to remove old socket: {err}");
return Err(err); return Err(err);
} }
} }
} }
let listener = match UnixListener::bind(socket_path) { let listener = match UnixListener::bind(&socket_path) {
Ok(listener) => listener, Ok(listener) => {
tracing::info!("Bound to socket at {socket_path:?}");
listener
}
Err(err) => { Err(err) => {
tracing::error!("Failed to bind to socket: {err}"); tracing::error!("Failed to bind to socket: {err}");
return Err(err); return Err(err);