Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

IPC Implementation! #84

Merged
merged 4 commits into from
Oct 13, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,15 @@ futures_codec = "0.4"
[dev-dependencies]
chrono = "^0.4"
criterion = "0.3"
pretty_env_logger = "0.4"

[lib]
bench = false

[[bench]]
name = "pub_sub"
harness = false
bench = false # Don't actually benchmark this, until we fix it
bench = false

[[bench]]
name = "req_rep"
Expand Down
28 changes: 27 additions & 1 deletion src/endpoint/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use lazy_static::lazy_static;
use regex::Regex;
use std::fmt;
use std::net::SocketAddr;
use std::path::PathBuf;
use std::str::FromStr;

use crate::error::EndpointError;
Expand All @@ -31,19 +32,25 @@ pub type Port = u16;
pub enum Endpoint {
// TODO: Add endpoints for the other transport variants
Tcp(Host, Port),
Ipc(Option<PathBuf>),
}

impl Endpoint {
pub fn transport(&self) -> Transport {
match self {
Self::Tcp(_, _) => Transport::Tcp,
Self::Ipc(_) => Transport::Ipc,
}
}

/// Creates an `Endpoint::Tcp` from a [`SocketAddr`]
pub fn from_tcp_sock_addr(addr: SocketAddr) -> Self {
pub fn from_tcp_addr(addr: SocketAddr) -> Self {
Endpoint::Tcp(addr.ip().into(), addr.port())
}

pub fn from_tcp_domain(addr: String, port: u16) -> Self {
Endpoint::Tcp(Host::Domain(addr), port)
}
}

impl FromStr for Endpoint {
Expand Down Expand Up @@ -81,6 +88,10 @@ impl FromStr for Endpoint {
let (host, port) = extract_host_port(address)?;
Endpoint::Tcp(host, port)
}
Transport::Ipc => {
let path: PathBuf = address.to_string().into();
Endpoint::Ipc(Some(path))
}
};

Ok(endpoint)
Expand All @@ -97,6 +108,8 @@ impl fmt::Display for Endpoint {
write!(f, "tcp://{}:{}", host, port)
}
}
Endpoint::Ipc(Some(path)) => write!(f, "ipc://{}", path.display()),
Endpoint::Ipc(None) => write!(f, "ipc://????"),
}
}
}
Expand Down Expand Up @@ -139,6 +152,18 @@ mod tests {

lazy_static! {
static ref PAIRS: Vec<(Endpoint, &'static str)> = vec![
(
Endpoint::Ipc(Some(PathBuf::from("/tmp/asdf"))),
"ipc:///tmp/asdf"
),
(
Endpoint::Ipc(Some(PathBuf::from("my/dir_1/dir-2"))),
"ipc://my/dir_1/dir-2"
),
(
Endpoint::Ipc(Some(PathBuf::from("@abstract/namespace"))),
"ipc://@abstract/namespace"
),
(
Endpoint::Tcp(Host::Domain("www.example.com".to_string()), 1234),
"tcp://www.example.com:1234",
Expand Down Expand Up @@ -179,6 +204,7 @@ mod tests {
for (e, s) in PAIRS.iter() {
assert_eq!(&format!("{}", e), s);
}
assert_eq!(&format!("{}", Endpoint::Ipc(None)), "ipc://????");
}

#[test]
Expand Down
3 changes: 3 additions & 0 deletions src/endpoint/transport.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use super::EndpointError;
pub enum Transport {
/// TCP transport
Tcp,
Ipc,
}

impl FromStr for Transport {
Expand All @@ -18,6 +19,7 @@ impl FromStr for Transport {
fn from_str(s: &str) -> Result<Self, Self::Err> {
let result = match s {
"tcp" => Transport::Tcp,
"ipc" => Transport::Ipc,
_ => return Err(EndpointError::UnknownTransport(s.to_string())),
};
Ok(result)
Expand All @@ -35,6 +37,7 @@ impl fmt::Display for Transport {
fn fmt(&self, f: &mut fmt::Formatter) -> std::result::Result<(), std::fmt::Error> {
let s = match self {
Transport::Tcp => "tcp",
Transport::Ipc => "ipc",
};
write!(f, "{}", s)
}
Expand Down
24 changes: 24 additions & 0 deletions src/transport/ipc/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
// TODO: Conditionally compile things
mod tokio;

use self::tokio as tk;
use crate::codec::FramedIo;
use crate::endpoint::Endpoint;
use crate::transport::AcceptStopChannel;
use crate::ZmqResult;

use std::path::PathBuf;

pub(crate) async fn connect(path: PathBuf) -> ZmqResult<(FramedIo, Endpoint)> {
tk::connect(path).await
}

pub(crate) async fn begin_accept<T>(
path: PathBuf,
cback: impl Fn(ZmqResult<(FramedIo, Endpoint)>) -> T + Send + 'static,
) -> ZmqResult<(Endpoint, AcceptStopChannel)>
where
T: std::future::Future<Output = ()> + Send + 'static,
{
tk::begin_accept(path, cback).await
}
64 changes: 64 additions & 0 deletions src/transport/ipc/tokio.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
use crate::codec::FramedIo;
use crate::endpoint::Endpoint;
use crate::transport::AcceptStopChannel;
use crate::ZmqResult;

use futures::{select, FutureExt};
use std::path::{Path, PathBuf};
use tokio_util::compat::Tokio02AsyncReadCompatExt;

pub(crate) async fn connect(path: PathBuf) -> ZmqResult<(FramedIo, Endpoint)> {
let raw_socket = tokio::net::UnixStream::connect(&path).await?;
let peer_addr = raw_socket.peer_addr()?;
let peer_addr = peer_addr.as_pathname().map(|a| a.to_owned());
let boxed_sock = Box::new(raw_socket.compat());
Ok((FramedIo::new(boxed_sock), Endpoint::Ipc(peer_addr)))
}

pub(crate) async fn begin_accept<T>(
path: PathBuf,
cback: impl Fn(ZmqResult<(FramedIo, Endpoint)>) -> T + Send + 'static,
) -> ZmqResult<(Endpoint, AcceptStopChannel)>
where
T: std::future::Future<Output = ()> + Send + 'static,
{
let wildcard: &Path = "*".as_ref();
if path == wildcard {
todo!("Need to implement support for wildcard paths!");
}
let mut listener = tokio::net::UnixListener::bind(path)?;
let resolved_addr = listener.local_addr()?;
let resolved_addr = resolved_addr.as_pathname().map(|a| a.to_owned());
let listener_addr = resolved_addr.clone();
let (stop_handle, stop_callback) = futures::channel::oneshot::channel::<()>();
tokio::spawn(async move {
let mut stop_callback = stop_callback.fuse();
loop {
select! {
incoming = listener.accept().fuse() => {
let maybe_accepted: Result<_, _> = incoming.map(|(raw_sock, peer_addr)| {
let raw_sock = FramedIo::new(Box::new(raw_sock.compat()));
let peer_addr = peer_addr.as_pathname().map(|a| a.to_owned());
(raw_sock, Endpoint::Ipc(peer_addr))
}).map_err(|err| err.into());
tokio::spawn(cback(maybe_accepted.into()));
},
_ = stop_callback => {
log::debug!("Accept task received stop signal. {:?}", listener_addr);
break
}
}
}
drop(listener);
if let Some(listener_addr) = listener_addr {
if let Err(err) = tokio::fs::remove_file(&listener_addr).await {
log::warn!(
"Could not delete unix socket at {}: {}",
listener_addr.display(),
err
);
}
}
});
Ok((Endpoint::Ipc(resolved_addr), AcceptStopChannel(stop_handle)))
}
8 changes: 8 additions & 0 deletions src/transport/mod.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
mod ipc;
mod tcp;

use crate::codec::FramedIo;
use crate::endpoint::Endpoint;
use crate::error::ZmqError;
use crate::ZmqResult;

pub(crate) async fn connect(endpoint: Endpoint) -> ZmqResult<(FramedIo, Endpoint)> {
match endpoint {
Endpoint::Tcp(host, port) => tcp::connect(host, port).await,
Endpoint::Ipc(Some(path)) => ipc::connect(path).await,
Endpoint::Ipc(None) => Err(ZmqError::Socket("Cannot connect to an unnamed ipc socket")),
}
}

Expand All @@ -29,5 +33,9 @@ where
{
match endpoint {
Endpoint::Tcp(host, port) => tcp::begin_accept(host, port, cback).await,
Endpoint::Ipc(Some(path)) => ipc::begin_accept(path, cback).await,
Endpoint::Ipc(None) => Err(ZmqError::Socket(
"Cannot begin accepting peers at an unnamed ipc socket",
)),
}
}
6 changes: 3 additions & 3 deletions src/transport/tcp/tokio.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@ use tokio_util::compat::Tokio02AsyncReadCompatExt;

pub(crate) async fn connect(host: Host, port: Port) -> ZmqResult<(FramedIo, Endpoint)> {
let raw_socket = tokio::net::TcpStream::connect((host.to_string().as_str(), port)).await?;
let remote_addr = raw_socket.peer_addr()?;
let peer_addr = raw_socket.peer_addr()?;
let boxed_sock = Box::new(raw_socket.compat());
Ok((
FramedIo::new(boxed_sock),
Endpoint::from_tcp_sock_addr(remote_addr),
Endpoint::from_tcp_addr(peer_addr),
))
}

Expand All @@ -36,7 +36,7 @@ where
incoming = listener.accept().fuse() => {
let maybe_accepted: Result<_, _> = incoming.map(|(raw_sock, remote_addr)| {
let raw_sock = FramedIo::new(Box::new(raw_sock.compat()));
(raw_sock, Endpoint::from_tcp_sock_addr(remote_addr))
(raw_sock, Endpoint::from_tcp_addr(remote_addr))
}).map_err(|err| err.into());
tokio::spawn(cback(maybe_accepted.into()));
},
Expand Down
22 changes: 14 additions & 8 deletions src/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -268,9 +268,12 @@ pub(crate) mod tests {

let mut port_set = std::collections::HashSet::new();
for b in bound_to.keys() {
let Endpoint::Tcp(host, port) = b;
assert_eq!(host, &any.into());
port_set.insert(*port);
if let Endpoint::Tcp(host, port) = b {
assert_eq!(host, &any.into());
port_set.insert(*port);
} else {
unreachable!()
}
}

(start_port..start_port + 4).for_each(|p| assert!(port_set.contains(&p)));
Expand All @@ -288,11 +291,14 @@ pub(crate) mod tests {
assert_eq!(bound_to.len(), 4);
let mut port_set = std::collections::HashSet::new();
for b in bound_to.keys() {
let Endpoint::Tcp(host, port) = b;
assert_eq!(host, &Host::Domain("localhost".to_string()));
assert_ne!(*port, 0);
// Insert and check that it wasn't already present
assert!(port_set.insert(*port));
if let Endpoint::Tcp(host, port) = b {
assert_eq!(host, &Host::Domain("localhost".to_string()));
assert_ne!(*port, 0);
// Insert and check that it wasn't already present
assert!(port_set.insert(*port));
} else {
unreachable!()
}
}

Ok(())
Expand Down
19 changes: 10 additions & 9 deletions tests/pub_sub.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
use zeromq::prelude::*;
use zeromq::Endpoint;

use futures::channel::{mpsc, oneshot};
use futures::{SinkExt, StreamExt};
use std::convert::TryInto;
use std::time::Duration;

use zeromq::prelude::*;
use zeromq::Endpoint;

#[tokio::test]
async fn test_pub_sub_sockets() {
pretty_env_logger::try_init().ok();

async fn helper(bind_addr: &'static str) {
// We will join on these at the end to determine if any tasks we spawned
// panicked
Expand All @@ -22,7 +24,7 @@ async fn test_pub_sub_sockets() {
let bound_to = pub_socket
.bind(bind_addr)
.await
.unwrap_or_else(|_| panic!("Failed to bind to {}", bind_addr));
.unwrap_or_else(|e| panic!("Failed to bind to {}: {}", bind_addr, e));
has_bound_sender
.send(bound_to)
.expect("channel was dropped");
Expand All @@ -42,11 +44,9 @@ async fn test_pub_sub_sockets() {
// TODO: ZMQ sockets should not care about this sort of ordering.
// See https://github.com/zeromq/zmq.rs/issues/73
let bound_addr = has_bound.await.expect("channel was cancelled");
assert!(if let Endpoint::Tcp(_host, port) = bound_addr.clone() {
port != 0
} else {
unreachable!()
});
if let Endpoint::Tcp(_host, port) = bound_addr.clone() {
assert_ne!(port, 0);
}

let (sub_results_sender, sub_results) = mpsc::channel(100);
for _ in 0..10 {
Expand Down Expand Up @@ -94,6 +94,7 @@ async fn test_pub_sub_sockets() {
"tcp://localhost:0",
"tcp://127.0.0.1:0",
"tcp://[::1]:0",
"ipc://asdf.sock",
];
futures::future::join_all(addrs.into_iter().map(helper)).await;
}
4 changes: 4 additions & 0 deletions tests/rep_req.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ async fn run_rep_server(mut rep_socket: RepSocket) -> Result<(), Box<dyn Error>>

#[tokio::test]
async fn test_req_rep_sockets() -> Result<(), Box<dyn Error>> {
pretty_env_logger::try_init().ok();

let mut rep_socket = zeromq::RepSocket::new();
let endpoint = rep_socket.bind("tcp://localhost:0").await?;
println!("Started rep server on {}", endpoint);
Expand All @@ -40,6 +42,8 @@ async fn test_req_rep_sockets() -> Result<(), Box<dyn Error>> {

#[tokio::test]
async fn test_many_req_rep_sockets() -> Result<(), Box<dyn Error>> {
pretty_env_logger::try_init().ok();

let mut rep_socket = zeromq::RepSocket::new();
let endpoint = rep_socket.bind("tcp://localhost:0").await?;
println!("Started rep server on {}", endpoint);
Expand Down