From fb4188e7c118c0dc252cd533afbcd9b0754d64cc Mon Sep 17 00:00:00 2001 From: Ryan Butler Date: Tue, 13 Oct 2020 00:08:27 -0400 Subject: [PATCH 1/4] Added IPC to Endpoint enum and refactored into new partial skeleton --- src/endpoint/mod.rs | 8 ++++++++ src/endpoint/transport.rs | 3 +++ src/transport/ipc/mod.rs | 24 ++++++++++++++++++++++++ src/transport/ipc/tokio.rs | 19 +++++++++++++++++++ src/transport/mod.rs | 3 +++ src/util.rs | 22 ++++++++++++++-------- 6 files changed, 71 insertions(+), 8 deletions(-) create mode 100644 src/transport/ipc/mod.rs create mode 100644 src/transport/ipc/tokio.rs diff --git a/src/endpoint/mod.rs b/src/endpoint/mod.rs index bd59fbc..9c34221 100644 --- a/src/endpoint/mod.rs +++ b/src/endpoint/mod.rs @@ -8,6 +8,7 @@ use lazy_static::lazy_static; use regex::Regex; use std::fmt; use std::net::SocketAddr; +use std::path::PathBuf; use std::str::FromStr; use crate::error::EndpointError; @@ -31,12 +32,14 @@ pub type Port = u16; pub enum Endpoint { // TODO: Add endpoints for the other transport variants Tcp(Host, Port), + Ipc(PathBuf), } impl Endpoint { pub fn transport(&self) -> Transport { match self { Self::Tcp(_, _) => Transport::Tcp, + Self::Ipc(_) => Transport::Ipc, } } @@ -81,6 +84,10 @@ impl FromStr for Endpoint { let (host, port) = extract_host_port(address)?; Endpoint::Tcp(host, port) } + Transport::Ipc => { + let path: PathBuf = address.to_string().into(); + Endpoint::Ipc(path) + } }; Ok(endpoint) @@ -97,6 +104,7 @@ impl fmt::Display for Endpoint { write!(f, "tcp://{}:{}", host, port) } } + Endpoint::Ipc(path) => write!(f, "ipc://{}", path.display()), } } } diff --git a/src/endpoint/transport.rs b/src/endpoint/transport.rs index 9b000ad..a629931 100644 --- a/src/endpoint/transport.rs +++ b/src/endpoint/transport.rs @@ -10,6 +10,7 @@ use super::EndpointError; pub enum Transport { /// TCP transport Tcp, + Ipc, } impl FromStr for Transport { @@ -18,6 +19,7 @@ impl FromStr for Transport { fn from_str(s: &str) -> Result { let result = match s { "tcp" => Transport::Tcp, + "ipc" => Transport::Ipc, _ => return Err(EndpointError::UnknownTransport(s.to_string())), }; Ok(result) @@ -35,6 +37,7 @@ impl fmt::Display for Transport { fn fmt(&self, f: &mut fmt::Formatter) -> std::result::Result<(), std::fmt::Error> { let s = match self { Transport::Tcp => "tcp", + Transport::Ipc => "ipc", }; write!(f, "{}", s) } diff --git a/src/transport/ipc/mod.rs b/src/transport/ipc/mod.rs new file mode 100644 index 0000000..91ff25a --- /dev/null +++ b/src/transport/ipc/mod.rs @@ -0,0 +1,24 @@ +// TODO: Conditionally compile things +mod tokio; + +use self::tokio as tk; +use crate::codec::FramedIo; +use crate::endpoint::Endpoint; +use crate::transport::AcceptStopChannel; +use crate::ZmqResult; + +use std::path::PathBuf; + +pub(crate) async fn connect(path: PathBuf) -> ZmqResult<(FramedIo, Endpoint)> { + tk::connect(path).await +} + +pub(crate) async fn begin_accept( + path: PathBuf, + cback: impl Fn(ZmqResult<(FramedIo, Endpoint)>) -> T + Send + 'static, +) -> ZmqResult<(Endpoint, AcceptStopChannel)> +where + T: std::future::Future + Send + 'static, +{ + tk::begin_accept(path, cback).await +} diff --git a/src/transport/ipc/tokio.rs b/src/transport/ipc/tokio.rs new file mode 100644 index 0000000..5c3d1f8 --- /dev/null +++ b/src/transport/ipc/tokio.rs @@ -0,0 +1,19 @@ +use crate::codec::FramedIo; +use crate::endpoint::Endpoint; +use crate::transport::AcceptStopChannel; +use crate::ZmqResult; + +use std::path::PathBuf; + +pub(crate) async fn connect(_path: PathBuf) -> ZmqResult<(FramedIo, Endpoint)> { + todo!() +} +pub(crate) async fn begin_accept( + _path: PathBuf, + _cback: impl Fn(ZmqResult<(FramedIo, Endpoint)>) -> T + Send + 'static, +) -> ZmqResult<(Endpoint, AcceptStopChannel)> +where + T: std::future::Future + Send + 'static, +{ + todo!() +} diff --git a/src/transport/mod.rs b/src/transport/mod.rs index f90449c..010383c 100644 --- a/src/transport/mod.rs +++ b/src/transport/mod.rs @@ -1,3 +1,4 @@ +mod ipc; mod tcp; use crate::codec::FramedIo; @@ -7,6 +8,7 @@ use crate::ZmqResult; pub(crate) async fn connect(endpoint: Endpoint) -> ZmqResult<(FramedIo, Endpoint)> { match endpoint { Endpoint::Tcp(host, port) => tcp::connect(host, port).await, + Endpoint::Ipc(path) => ipc::connect(path).await, } } @@ -29,5 +31,6 @@ where { match endpoint { Endpoint::Tcp(host, port) => tcp::begin_accept(host, port, cback).await, + Endpoint::Ipc(path) => ipc::begin_accept(path, cback).await, } } diff --git a/src/util.rs b/src/util.rs index f19a130..206056a 100644 --- a/src/util.rs +++ b/src/util.rs @@ -268,9 +268,12 @@ pub(crate) mod tests { let mut port_set = std::collections::HashSet::new(); for b in bound_to.keys() { - let Endpoint::Tcp(host, port) = b; - assert_eq!(host, &any.into()); - port_set.insert(*port); + if let Endpoint::Tcp(host, port) = b { + assert_eq!(host, &any.into()); + port_set.insert(*port); + } else { + unreachable!() + } } (start_port..start_port + 4).for_each(|p| assert!(port_set.contains(&p))); @@ -288,11 +291,14 @@ pub(crate) mod tests { assert_eq!(bound_to.len(), 4); let mut port_set = std::collections::HashSet::new(); for b in bound_to.keys() { - let Endpoint::Tcp(host, port) = b; - assert_eq!(host, &Host::Domain("localhost".to_string())); - assert_ne!(*port, 0); - // Insert and check that it wasn't already present - assert!(port_set.insert(*port)); + if let Endpoint::Tcp(host, port) = b { + assert_eq!(host, &Host::Domain("localhost".to_string())); + assert_ne!(*port, 0); + // Insert and check that it wasn't already present + assert!(port_set.insert(*port)); + } else { + unreachable!() + } } Ok(()) From a727c270de84bda3349e91275b59a2a4954b1ba5 Mon Sep 17 00:00:00 2001 From: Ryan Butler Date: Tue, 13 Oct 2020 00:48:49 -0400 Subject: [PATCH 2/4] Added tests for ipc endpoint display --- src/endpoint/mod.rs | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/endpoint/mod.rs b/src/endpoint/mod.rs index 9c34221..11ed1e6 100644 --- a/src/endpoint/mod.rs +++ b/src/endpoint/mod.rs @@ -147,6 +147,15 @@ mod tests { lazy_static! { static ref PAIRS: Vec<(Endpoint, &'static str)> = vec![ + (Endpoint::Ipc(PathBuf::from("/tmp/asdf")), "ipc:///tmp/asdf"), + ( + Endpoint::Ipc(PathBuf::from("my/dir_1/dir-2")), + "ipc://my/dir_1/dir-2" + ), + ( + Endpoint::Ipc(PathBuf::from("@abstract/namespace")), + "ipc://@abstract/namespace" + ), ( Endpoint::Tcp(Host::Domain("www.example.com".to_string()), 1234), "tcp://www.example.com:1234", From 87e002b75e849c29430223754ca824e1fc2d5402 Mon Sep 17 00:00:00 2001 From: Ryan Butler Date: Tue, 13 Oct 2020 01:23:28 -0400 Subject: [PATCH 3/4] Added IPC implementation --- Cargo.toml | 2 +- src/endpoint/mod.rs | 23 +++++++++++----- src/transport/ipc/tokio.rs | 56 ++++++++++++++++++++++++++++++++++---- src/transport/mod.rs | 9 ++++-- src/transport/tcp/tokio.rs | 6 ++-- tests/pub_sub.rs | 11 ++++---- 6 files changed, 82 insertions(+), 25 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index d0297ea..8fb20cb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -35,7 +35,7 @@ bench = false [[bench]] name = "pub_sub" harness = false -bench = false # Don't actually benchmark this, until we fix it +bench = false [[bench]] name = "req_rep" diff --git a/src/endpoint/mod.rs b/src/endpoint/mod.rs index 11ed1e6..acc0f9b 100644 --- a/src/endpoint/mod.rs +++ b/src/endpoint/mod.rs @@ -32,7 +32,7 @@ pub type Port = u16; pub enum Endpoint { // TODO: Add endpoints for the other transport variants Tcp(Host, Port), - Ipc(PathBuf), + Ipc(Option), } impl Endpoint { @@ -44,9 +44,13 @@ impl Endpoint { } /// Creates an `Endpoint::Tcp` from a [`SocketAddr`] - pub fn from_tcp_sock_addr(addr: SocketAddr) -> Self { + pub fn from_tcp_addr(addr: SocketAddr) -> Self { Endpoint::Tcp(addr.ip().into(), addr.port()) } + + pub fn from_tcp_domain(addr: String, port: u16) -> Self { + Endpoint::Tcp(Host::Domain(addr), port) + } } impl FromStr for Endpoint { @@ -86,7 +90,7 @@ impl FromStr for Endpoint { } Transport::Ipc => { let path: PathBuf = address.to_string().into(); - Endpoint::Ipc(path) + Endpoint::Ipc(Some(path)) } }; @@ -104,7 +108,8 @@ impl fmt::Display for Endpoint { write!(f, "tcp://{}:{}", host, port) } } - Endpoint::Ipc(path) => write!(f, "ipc://{}", path.display()), + Endpoint::Ipc(Some(path)) => write!(f, "ipc://{}", path.display()), + Endpoint::Ipc(None) => write!(f, "ipc://????"), } } } @@ -147,13 +152,16 @@ mod tests { lazy_static! { static ref PAIRS: Vec<(Endpoint, &'static str)> = vec![ - (Endpoint::Ipc(PathBuf::from("/tmp/asdf")), "ipc:///tmp/asdf"), ( - Endpoint::Ipc(PathBuf::from("my/dir_1/dir-2")), + Endpoint::Ipc(Some(PathBuf::from("/tmp/asdf"))), + "ipc:///tmp/asdf" + ), + ( + Endpoint::Ipc(Some(PathBuf::from("my/dir_1/dir-2"))), "ipc://my/dir_1/dir-2" ), ( - Endpoint::Ipc(PathBuf::from("@abstract/namespace")), + Endpoint::Ipc(Some(PathBuf::from("@abstract/namespace"))), "ipc://@abstract/namespace" ), ( @@ -196,6 +204,7 @@ mod tests { for (e, s) in PAIRS.iter() { assert_eq!(&format!("{}", e), s); } + assert_eq!(&format!("{}", Endpoint::Ipc(None)), "ipc://????"); } #[test] diff --git a/src/transport/ipc/tokio.rs b/src/transport/ipc/tokio.rs index 5c3d1f8..a1ccbf6 100644 --- a/src/transport/ipc/tokio.rs +++ b/src/transport/ipc/tokio.rs @@ -3,17 +3,61 @@ use crate::endpoint::Endpoint; use crate::transport::AcceptStopChannel; use crate::ZmqResult; -use std::path::PathBuf; +use futures::{select, FutureExt}; +use std::path::{Path, PathBuf}; +use tokio_util::compat::Tokio02AsyncReadCompatExt; -pub(crate) async fn connect(_path: PathBuf) -> ZmqResult<(FramedIo, Endpoint)> { - todo!() +pub(crate) async fn connect(path: PathBuf) -> ZmqResult<(FramedIo, Endpoint)> { + let raw_socket = tokio::net::UnixStream::connect(&path).await?; + let peer_addr = raw_socket.peer_addr()?; + let peer_addr = peer_addr.as_pathname().map(|a| a.to_owned()); + let boxed_sock = Box::new(raw_socket.compat()); + Ok((FramedIo::new(boxed_sock), Endpoint::Ipc(peer_addr))) } + pub(crate) async fn begin_accept( - _path: PathBuf, - _cback: impl Fn(ZmqResult<(FramedIo, Endpoint)>) -> T + Send + 'static, + path: PathBuf, + cback: impl Fn(ZmqResult<(FramedIo, Endpoint)>) -> T + Send + 'static, ) -> ZmqResult<(Endpoint, AcceptStopChannel)> where T: std::future::Future + Send + 'static, { - todo!() + let wildcard: &Path = "*".as_ref(); + if path == wildcard { + todo!("Need to implement support for wildcard paths!"); + } + let mut listener = tokio::net::UnixListener::bind(path)?; + let resolved_addr = listener.local_addr()?; + let resolved_addr = resolved_addr.as_pathname().map(|a| a.to_owned()); + let listener_addr = resolved_addr.clone(); + let (stop_handle, stop_callback) = futures::channel::oneshot::channel::<()>(); + tokio::spawn(async move { + let mut stop_callback = stop_callback.fuse(); + loop { + select! { + incoming = listener.accept().fuse() => { + let maybe_accepted: Result<_, _> = incoming.map(|(raw_sock, peer_addr)| { + let raw_sock = FramedIo::new(Box::new(raw_sock.compat())); + let peer_addr = peer_addr.as_pathname().map(|a| a.to_owned()); + (raw_sock, Endpoint::Ipc(peer_addr)) + }).map_err(|err| err.into()); + tokio::spawn(cback(maybe_accepted.into())); + }, + _ = stop_callback => { + break + } + } + } + drop(listener); + if let Some(listener_addr) = listener_addr { + if let Err(err) = tokio::fs::remove_file(&listener_addr).await { + log::warn!( + "Could not delete unix socket at {}: {}", + listener_addr.display(), + err + ); + } + } + }); + Ok((Endpoint::Ipc(resolved_addr), AcceptStopChannel(stop_handle))) } diff --git a/src/transport/mod.rs b/src/transport/mod.rs index 010383c..c16adbb 100644 --- a/src/transport/mod.rs +++ b/src/transport/mod.rs @@ -3,12 +3,14 @@ mod tcp; use crate::codec::FramedIo; use crate::endpoint::Endpoint; +use crate::error::ZmqError; use crate::ZmqResult; pub(crate) async fn connect(endpoint: Endpoint) -> ZmqResult<(FramedIo, Endpoint)> { match endpoint { Endpoint::Tcp(host, port) => tcp::connect(host, port).await, - Endpoint::Ipc(path) => ipc::connect(path).await, + Endpoint::Ipc(Some(path)) => ipc::connect(path).await, + Endpoint::Ipc(None) => Err(ZmqError::Socket("Cannot connect to an unnamed ipc socket")), } } @@ -31,6 +33,9 @@ where { match endpoint { Endpoint::Tcp(host, port) => tcp::begin_accept(host, port, cback).await, - Endpoint::Ipc(path) => ipc::begin_accept(path, cback).await, + Endpoint::Ipc(Some(path)) => ipc::begin_accept(path, cback).await, + Endpoint::Ipc(None) => Err(ZmqError::Socket( + "Cannot begin accepting peers at an unnamed ipc socket", + )), } } diff --git a/src/transport/tcp/tokio.rs b/src/transport/tcp/tokio.rs index 75874d6..7111f0b 100644 --- a/src/transport/tcp/tokio.rs +++ b/src/transport/tcp/tokio.rs @@ -10,11 +10,11 @@ use tokio_util::compat::Tokio02AsyncReadCompatExt; pub(crate) async fn connect(host: Host, port: Port) -> ZmqResult<(FramedIo, Endpoint)> { let raw_socket = tokio::net::TcpStream::connect((host.to_string().as_str(), port)).await?; - let remote_addr = raw_socket.peer_addr()?; + let peer_addr = raw_socket.peer_addr()?; let boxed_sock = Box::new(raw_socket.compat()); Ok(( FramedIo::new(boxed_sock), - Endpoint::from_tcp_sock_addr(remote_addr), + Endpoint::from_tcp_addr(peer_addr), )) } @@ -36,7 +36,7 @@ where incoming = listener.accept().fuse() => { let maybe_accepted: Result<_, _> = incoming.map(|(raw_sock, remote_addr)| { let raw_sock = FramedIo::new(Box::new(raw_sock.compat())); - (raw_sock, Endpoint::from_tcp_sock_addr(remote_addr)) + (raw_sock, Endpoint::from_tcp_addr(remote_addr)) }).map_err(|err| err.into()); tokio::spawn(cback(maybe_accepted.into())); }, diff --git a/tests/pub_sub.rs b/tests/pub_sub.rs index 0a5e098..a560a74 100644 --- a/tests/pub_sub.rs +++ b/tests/pub_sub.rs @@ -22,7 +22,7 @@ async fn test_pub_sub_sockets() { let bound_to = pub_socket .bind(bind_addr) .await - .unwrap_or_else(|_| panic!("Failed to bind to {}", bind_addr)); + .unwrap_or_else(|e| panic!("Failed to bind to {}: {}", bind_addr, e)); has_bound_sender .send(bound_to) .expect("channel was dropped"); @@ -42,11 +42,9 @@ async fn test_pub_sub_sockets() { // TODO: ZMQ sockets should not care about this sort of ordering. // See https://github.com/zeromq/zmq.rs/issues/73 let bound_addr = has_bound.await.expect("channel was cancelled"); - assert!(if let Endpoint::Tcp(_host, port) = bound_addr.clone() { - port != 0 - } else { - unreachable!() - }); + if let Endpoint::Tcp(_host, port) = bound_addr.clone() { + assert_ne!(port, 0); + } let (sub_results_sender, sub_results) = mpsc::channel(100); for _ in 0..10 { @@ -94,6 +92,7 @@ async fn test_pub_sub_sockets() { "tcp://localhost:0", "tcp://127.0.0.1:0", "tcp://[::1]:0", + "ipc://asdf.sock", ]; futures::future::join_all(addrs.into_iter().map(helper)).await; } From 6280334ee416f71f5990dcb02d0770b68c583aa5 Mon Sep 17 00:00:00 2001 From: Ryan Butler Date: Tue, 13 Oct 2020 12:00:58 -0400 Subject: [PATCH 4/4] Added pretty_env_logger to tests --- Cargo.toml | 1 + src/transport/ipc/tokio.rs | 1 + tests/pub_sub.rs | 8 +++++--- tests/rep_req.rs | 4 ++++ 4 files changed, 11 insertions(+), 3 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 8fb20cb..adf958a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -28,6 +28,7 @@ futures_codec = "0.4" [dev-dependencies] chrono = "^0.4" criterion = "0.3" +pretty_env_logger = "0.4" [lib] bench = false diff --git a/src/transport/ipc/tokio.rs b/src/transport/ipc/tokio.rs index a1ccbf6..ebbaa89 100644 --- a/src/transport/ipc/tokio.rs +++ b/src/transport/ipc/tokio.rs @@ -44,6 +44,7 @@ where tokio::spawn(cback(maybe_accepted.into())); }, _ = stop_callback => { + log::debug!("Accept task received stop signal. {:?}", listener_addr); break } } diff --git a/tests/pub_sub.rs b/tests/pub_sub.rs index a560a74..d8d2a01 100644 --- a/tests/pub_sub.rs +++ b/tests/pub_sub.rs @@ -1,13 +1,15 @@ +use zeromq::prelude::*; +use zeromq::Endpoint; + use futures::channel::{mpsc, oneshot}; use futures::{SinkExt, StreamExt}; use std::convert::TryInto; use std::time::Duration; -use zeromq::prelude::*; -use zeromq::Endpoint; - #[tokio::test] async fn test_pub_sub_sockets() { + pretty_env_logger::try_init().ok(); + async fn helper(bind_addr: &'static str) { // We will join on these at the end to determine if any tasks we spawned // panicked diff --git a/tests/rep_req.rs b/tests/rep_req.rs index 74eb7bf..905064b 100644 --- a/tests/rep_req.rs +++ b/tests/rep_req.rs @@ -19,6 +19,8 @@ async fn run_rep_server(mut rep_socket: RepSocket) -> Result<(), Box> #[tokio::test] async fn test_req_rep_sockets() -> Result<(), Box> { + pretty_env_logger::try_init().ok(); + let mut rep_socket = zeromq::RepSocket::new(); let endpoint = rep_socket.bind("tcp://localhost:0").await?; println!("Started rep server on {}", endpoint); @@ -40,6 +42,8 @@ async fn test_req_rep_sockets() -> Result<(), Box> { #[tokio::test] async fn test_many_req_rep_sockets() -> Result<(), Box> { + pretty_env_logger::try_init().ok(); + let mut rep_socket = zeromq::RepSocket::new(); let endpoint = rep_socket.bind("tcp://localhost:0").await?; println!("Started rep server on {}", endpoint);