diff --git a/rkvm-client/src/config.rs b/rkvm-client/src/config.rs index d44c04f..9a538d0 100644 --- a/rkvm-client/src/config.rs +++ b/rkvm-client/src/config.rs @@ -1,7 +1,9 @@ use serde::de::{self, Visitor}; use serde::{Deserialize, Deserializer}; use std::fmt::{self, Formatter}; +use std::net::SocketAddr; use std::path::PathBuf; +use std::str::FromStr; use tokio_rustls::rustls::ServerName; #[derive(Deserialize)] @@ -39,8 +41,16 @@ impl<'de> Visitor<'de> for ServerVisitor { where E: de::Error, { + // Parsing IPv6 socket addresses can get quite hairy, so let the SocketAddr parser do it for us. + if let Ok(socket_addr) = SocketAddr::from_str(data) { + return Ok(Server { + hostname: ServerName::IpAddress(socket_addr.ip()), + port: socket_addr.port(), + }); + } + let (hostname, port) = data - .rsplit_once(':') + .split_once(':') .ok_or_else(|| E::custom("No port provided"))?; let hostname = hostname.try_into().map_err(E::custom)?; @@ -49,3 +59,64 @@ impl<'de> Visitor<'de> for ServerVisitor { Ok(Server { hostname, port }) } } + +#[cfg(test)] +mod tests { + use std::net::Ipv6Addr; + + use super::*; + + #[derive(Deserialize)] + struct Data { + server: Server, + } + + #[test] + fn server_dns() { + let parsed = toml::from_str::(r#"server = "example.com:8523""#) + .unwrap() + .server; + let expected = Server { + hostname: "example.com".try_into().unwrap(), + port: 8523, + }; + + assert_eq!(parsed.hostname, expected.hostname); + assert_eq!(parsed.port, expected.port); + } + + #[test] + fn server_ipv4() { + let parsed = toml::from_str::(r#"server = "127.0.0.1:8523""#) + .unwrap() + .server; + let expected = Server { + hostname: "127.0.0.1".try_into().unwrap(), + port: 8523, + }; + + assert_eq!(parsed.hostname, expected.hostname); + assert_eq!(parsed.port, expected.port); + } + + #[test] + fn server_ipv6() { + let parsed = toml::from_str::(r#"server = "[::1]:8523""#) + .unwrap() + .server; + let expected = Server { + hostname: "::1".try_into().unwrap(), + port: 8523, + }; + + assert_eq!(parsed.hostname, expected.hostname); + assert_eq!(parsed.port, expected.port); + + let parsed_ip = match parsed.hostname { + ServerName::IpAddress(parsed_ip) => parsed_ip, + _ => unreachable!(), + }; + + assert_eq!(parsed_ip, Ipv6Addr::from_str("::1").unwrap()); + } +}