diff --git a/Cargo.toml b/Cargo.toml index d0297ea..adf958a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -28,6 +28,7 @@ futures_codec = "0.4" [dev-dependencies] chrono = "^0.4" criterion = "0.3" +pretty_env_logger = "0.4" [lib] bench = false @@ -35,7 +36,7 @@ bench = false [[bench]] name = "pub_sub" harness = false -bench = false # Don't actually benchmark this, until we fix it +bench = false [[bench]] name = "req_rep" diff --git a/src/endpoint/mod.rs b/src/endpoint/mod.rs index bd59fbc..acc0f9b 100644 --- a/src/endpoint/mod.rs +++ b/src/endpoint/mod.rs @@ -8,6 +8,7 @@ use lazy_static::lazy_static; use regex::Regex; use std::fmt; use std::net::SocketAddr; +use std::path::PathBuf; use std::str::FromStr; use crate::error::EndpointError; @@ -31,19 +32,25 @@ pub type Port = u16; pub enum Endpoint { // TODO: Add endpoints for the other transport variants Tcp(Host, Port), + Ipc(Option), } impl Endpoint { pub fn transport(&self) -> Transport { match self { Self::Tcp(_, _) => Transport::Tcp, + Self::Ipc(_) => Transport::Ipc, } } /// Creates an `Endpoint::Tcp` from a [`SocketAddr`] - pub fn from_tcp_sock_addr(addr: SocketAddr) -> Self { + pub fn from_tcp_addr(addr: SocketAddr) -> Self { Endpoint::Tcp(addr.ip().into(), addr.port()) } + + pub fn from_tcp_domain(addr: String, port: u16) -> Self { + Endpoint::Tcp(Host::Domain(addr), port) + } } impl FromStr for Endpoint { @@ -81,6 +88,10 @@ impl FromStr for Endpoint { let (host, port) = extract_host_port(address)?; Endpoint::Tcp(host, port) } + Transport::Ipc => { + let path: PathBuf = address.to_string().into(); + Endpoint::Ipc(Some(path)) + } }; Ok(endpoint) @@ -97,6 +108,8 @@ impl fmt::Display for Endpoint { write!(f, "tcp://{}:{}", host, port) } } + Endpoint::Ipc(Some(path)) => write!(f, "ipc://{}", path.display()), + Endpoint::Ipc(None) => write!(f, "ipc://????"), } } } @@ -139,6 +152,18 @@ mod tests { lazy_static! { static ref PAIRS: Vec<(Endpoint, &'static str)> = vec![ + ( + Endpoint::Ipc(Some(PathBuf::from("/tmp/asdf"))), + "ipc:///tmp/asdf" + ), + ( + Endpoint::Ipc(Some(PathBuf::from("my/dir_1/dir-2"))), + "ipc://my/dir_1/dir-2" + ), + ( + Endpoint::Ipc(Some(PathBuf::from("@abstract/namespace"))), + "ipc://@abstract/namespace" + ), ( Endpoint::Tcp(Host::Domain("www.example.com".to_string()), 1234), "tcp://www.example.com:1234", @@ -179,6 +204,7 @@ mod tests { for (e, s) in PAIRS.iter() { assert_eq!(&format!("{}", e), s); } + assert_eq!(&format!("{}", Endpoint::Ipc(None)), "ipc://????"); } #[test] diff --git a/src/endpoint/transport.rs b/src/endpoint/transport.rs index 9b000ad..a629931 100644 --- a/src/endpoint/transport.rs +++ b/src/endpoint/transport.rs @@ -10,6 +10,7 @@ use super::EndpointError; pub enum Transport { /// TCP transport Tcp, + Ipc, } impl FromStr for Transport { @@ -18,6 +19,7 @@ impl FromStr for Transport { fn from_str(s: &str) -> Result { let result = match s { "tcp" => Transport::Tcp, + "ipc" => Transport::Ipc, _ => return Err(EndpointError::UnknownTransport(s.to_string())), }; Ok(result) @@ -35,6 +37,7 @@ impl fmt::Display for Transport { fn fmt(&self, f: &mut fmt::Formatter) -> std::result::Result<(), std::fmt::Error> { let s = match self { Transport::Tcp => "tcp", + Transport::Ipc => "ipc", }; write!(f, "{}", s) } diff --git a/src/transport/ipc/mod.rs b/src/transport/ipc/mod.rs new file mode 100644 index 0000000..91ff25a --- /dev/null +++ b/src/transport/ipc/mod.rs @@ -0,0 +1,24 @@ +// TODO: Conditionally compile things +mod tokio; + +use self::tokio as tk; +use crate::codec::FramedIo; +use crate::endpoint::Endpoint; +use crate::transport::AcceptStopChannel; +use crate::ZmqResult; + +use std::path::PathBuf; + +pub(crate) async fn connect(path: PathBuf) -> ZmqResult<(FramedIo, Endpoint)> { + tk::connect(path).await +} + +pub(crate) async fn begin_accept( + path: PathBuf, + cback: impl Fn(ZmqResult<(FramedIo, Endpoint)>) -> T + Send + 'static, +) -> ZmqResult<(Endpoint, AcceptStopChannel)> +where + T: std::future::Future + Send + 'static, +{ + tk::begin_accept(path, cback).await +} diff --git a/src/transport/ipc/tokio.rs b/src/transport/ipc/tokio.rs new file mode 100644 index 0000000..ebbaa89 --- /dev/null +++ b/src/transport/ipc/tokio.rs @@ -0,0 +1,64 @@ +use crate::codec::FramedIo; +use crate::endpoint::Endpoint; +use crate::transport::AcceptStopChannel; +use crate::ZmqResult; + +use futures::{select, FutureExt}; +use std::path::{Path, PathBuf}; +use tokio_util::compat::Tokio02AsyncReadCompatExt; + +pub(crate) async fn connect(path: PathBuf) -> ZmqResult<(FramedIo, Endpoint)> { + let raw_socket = tokio::net::UnixStream::connect(&path).await?; + let peer_addr = raw_socket.peer_addr()?; + let peer_addr = peer_addr.as_pathname().map(|a| a.to_owned()); + let boxed_sock = Box::new(raw_socket.compat()); + Ok((FramedIo::new(boxed_sock), Endpoint::Ipc(peer_addr))) +} + +pub(crate) async fn begin_accept( + path: PathBuf, + cback: impl Fn(ZmqResult<(FramedIo, Endpoint)>) -> T + Send + 'static, +) -> ZmqResult<(Endpoint, AcceptStopChannel)> +where + T: std::future::Future + Send + 'static, +{ + let wildcard: &Path = "*".as_ref(); + if path == wildcard { + todo!("Need to implement support for wildcard paths!"); + } + let mut listener = tokio::net::UnixListener::bind(path)?; + let resolved_addr = listener.local_addr()?; + let resolved_addr = resolved_addr.as_pathname().map(|a| a.to_owned()); + let listener_addr = resolved_addr.clone(); + let (stop_handle, stop_callback) = futures::channel::oneshot::channel::<()>(); + tokio::spawn(async move { + let mut stop_callback = stop_callback.fuse(); + loop { + select! { + incoming = listener.accept().fuse() => { + let maybe_accepted: Result<_, _> = incoming.map(|(raw_sock, peer_addr)| { + let raw_sock = FramedIo::new(Box::new(raw_sock.compat())); + let peer_addr = peer_addr.as_pathname().map(|a| a.to_owned()); + (raw_sock, Endpoint::Ipc(peer_addr)) + }).map_err(|err| err.into()); + tokio::spawn(cback(maybe_accepted.into())); + }, + _ = stop_callback => { + log::debug!("Accept task received stop signal. {:?}", listener_addr); + break + } + } + } + drop(listener); + if let Some(listener_addr) = listener_addr { + if let Err(err) = tokio::fs::remove_file(&listener_addr).await { + log::warn!( + "Could not delete unix socket at {}: {}", + listener_addr.display(), + err + ); + } + } + }); + Ok((Endpoint::Ipc(resolved_addr), AcceptStopChannel(stop_handle))) +} diff --git a/src/transport/mod.rs b/src/transport/mod.rs index f90449c..c16adbb 100644 --- a/src/transport/mod.rs +++ b/src/transport/mod.rs @@ -1,12 +1,16 @@ +mod ipc; mod tcp; use crate::codec::FramedIo; use crate::endpoint::Endpoint; +use crate::error::ZmqError; use crate::ZmqResult; pub(crate) async fn connect(endpoint: Endpoint) -> ZmqResult<(FramedIo, Endpoint)> { match endpoint { Endpoint::Tcp(host, port) => tcp::connect(host, port).await, + Endpoint::Ipc(Some(path)) => ipc::connect(path).await, + Endpoint::Ipc(None) => Err(ZmqError::Socket("Cannot connect to an unnamed ipc socket")), } } @@ -29,5 +33,9 @@ where { match endpoint { Endpoint::Tcp(host, port) => tcp::begin_accept(host, port, cback).await, + Endpoint::Ipc(Some(path)) => ipc::begin_accept(path, cback).await, + Endpoint::Ipc(None) => Err(ZmqError::Socket( + "Cannot begin accepting peers at an unnamed ipc socket", + )), } } diff --git a/src/transport/tcp/tokio.rs b/src/transport/tcp/tokio.rs index 75874d6..7111f0b 100644 --- a/src/transport/tcp/tokio.rs +++ b/src/transport/tcp/tokio.rs @@ -10,11 +10,11 @@ use tokio_util::compat::Tokio02AsyncReadCompatExt; pub(crate) async fn connect(host: Host, port: Port) -> ZmqResult<(FramedIo, Endpoint)> { let raw_socket = tokio::net::TcpStream::connect((host.to_string().as_str(), port)).await?; - let remote_addr = raw_socket.peer_addr()?; + let peer_addr = raw_socket.peer_addr()?; let boxed_sock = Box::new(raw_socket.compat()); Ok(( FramedIo::new(boxed_sock), - Endpoint::from_tcp_sock_addr(remote_addr), + Endpoint::from_tcp_addr(peer_addr), )) } @@ -36,7 +36,7 @@ where incoming = listener.accept().fuse() => { let maybe_accepted: Result<_, _> = incoming.map(|(raw_sock, remote_addr)| { let raw_sock = FramedIo::new(Box::new(raw_sock.compat())); - (raw_sock, Endpoint::from_tcp_sock_addr(remote_addr)) + (raw_sock, Endpoint::from_tcp_addr(remote_addr)) }).map_err(|err| err.into()); tokio::spawn(cback(maybe_accepted.into())); }, diff --git a/src/util.rs b/src/util.rs index f19a130..206056a 100644 --- a/src/util.rs +++ b/src/util.rs @@ -268,9 +268,12 @@ pub(crate) mod tests { let mut port_set = std::collections::HashSet::new(); for b in bound_to.keys() { - let Endpoint::Tcp(host, port) = b; - assert_eq!(host, &any.into()); - port_set.insert(*port); + if let Endpoint::Tcp(host, port) = b { + assert_eq!(host, &any.into()); + port_set.insert(*port); + } else { + unreachable!() + } } (start_port..start_port + 4).for_each(|p| assert!(port_set.contains(&p))); @@ -288,11 +291,14 @@ pub(crate) mod tests { assert_eq!(bound_to.len(), 4); let mut port_set = std::collections::HashSet::new(); for b in bound_to.keys() { - let Endpoint::Tcp(host, port) = b; - assert_eq!(host, &Host::Domain("localhost".to_string())); - assert_ne!(*port, 0); - // Insert and check that it wasn't already present - assert!(port_set.insert(*port)); + if let Endpoint::Tcp(host, port) = b { + assert_eq!(host, &Host::Domain("localhost".to_string())); + assert_ne!(*port, 0); + // Insert and check that it wasn't already present + assert!(port_set.insert(*port)); + } else { + unreachable!() + } } Ok(()) diff --git a/tests/pub_sub.rs b/tests/pub_sub.rs index 0a5e098..d8d2a01 100644 --- a/tests/pub_sub.rs +++ b/tests/pub_sub.rs @@ -1,13 +1,15 @@ +use zeromq::prelude::*; +use zeromq::Endpoint; + use futures::channel::{mpsc, oneshot}; use futures::{SinkExt, StreamExt}; use std::convert::TryInto; use std::time::Duration; -use zeromq::prelude::*; -use zeromq::Endpoint; - #[tokio::test] async fn test_pub_sub_sockets() { + pretty_env_logger::try_init().ok(); + async fn helper(bind_addr: &'static str) { // We will join on these at the end to determine if any tasks we spawned // panicked @@ -22,7 +24,7 @@ async fn test_pub_sub_sockets() { let bound_to = pub_socket .bind(bind_addr) .await - .unwrap_or_else(|_| panic!("Failed to bind to {}", bind_addr)); + .unwrap_or_else(|e| panic!("Failed to bind to {}: {}", bind_addr, e)); has_bound_sender .send(bound_to) .expect("channel was dropped"); @@ -42,11 +44,9 @@ async fn test_pub_sub_sockets() { // TODO: ZMQ sockets should not care about this sort of ordering. // See https://github.com/zeromq/zmq.rs/issues/73 let bound_addr = has_bound.await.expect("channel was cancelled"); - assert!(if let Endpoint::Tcp(_host, port) = bound_addr.clone() { - port != 0 - } else { - unreachable!() - }); + if let Endpoint::Tcp(_host, port) = bound_addr.clone() { + assert_ne!(port, 0); + } let (sub_results_sender, sub_results) = mpsc::channel(100); for _ in 0..10 { @@ -94,6 +94,7 @@ async fn test_pub_sub_sockets() { "tcp://localhost:0", "tcp://127.0.0.1:0", "tcp://[::1]:0", + "ipc://asdf.sock", ]; futures::future::join_all(addrs.into_iter().map(helper)).await; } diff --git a/tests/rep_req.rs b/tests/rep_req.rs index 74eb7bf..905064b 100644 --- a/tests/rep_req.rs +++ b/tests/rep_req.rs @@ -19,6 +19,8 @@ async fn run_rep_server(mut rep_socket: RepSocket) -> Result<(), Box> #[tokio::test] async fn test_req_rep_sockets() -> Result<(), Box> { + pretty_env_logger::try_init().ok(); + let mut rep_socket = zeromq::RepSocket::new(); let endpoint = rep_socket.bind("tcp://localhost:0").await?; println!("Started rep server on {}", endpoint); @@ -40,6 +42,8 @@ async fn test_req_rep_sockets() -> Result<(), Box> { #[tokio::test] async fn test_many_req_rep_sockets() -> Result<(), Box> { + pretty_env_logger::try_init().ok(); + let mut rep_socket = zeromq::RepSocket::new(); let endpoint = rep_socket.bind("tcp://localhost:0").await?; println!("Started rep server on {}", endpoint);