Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ license = "MIT"

[dependencies]
tokio = { version = "1", features = ["rt", "net", "time", "sync", "io-util"] }
socket2 = { version = "0.5.2", features = ["all"] }

[dev-dependencies]
tokio = { version = "1", features = ["full"] }
24 changes: 14 additions & 10 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,21 @@ use tokio::net::UdpSocket;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use udpflow::{UdpListener, UdpStreamLocal, UdpStreamRemote};
async fn server() {
let socket = UdpSocket::bind("127.0.0.1:5000").await.unwrap();
let listener = UdpListener::new(socket);
let mut buf = vec![0u8; 0x2000];
// listener must be continuously polled to recv packets or accept new streams
while let Ok((stream, addr)) = listener.accept(&mut buf).await {
tokio::spawn(handle(stream));
let addr = "127.0.0.1:5000".parse().unwrap();
let listener = UdpListener::new(addr).unwrap();
loop {
let mut buf = vec![0u8; 0x2000];
let (n, stream, addr) = listener.accept(&mut buf).await.unwrap();
buf.truncate(n);
tokio::spawn(handle(stream, buf));
}
}
async fn handle(mut stream1: UdpStreamLocal) {
let socket = UdpSocket::bind("127.0.0.1:0").await.unwrap();
let mut stream2 = UdpStreamRemote::new(socket, "127.0.0.1:10000".parse().unwrap());

async fn handle(mut stream1: UdpStreamLocal, first_packet: Vec<u8>) {
let local = "127.0.0.1:0".parse().unwrap();
let remote = "127.0.0.1:10000".parse().unwrap();
let mut stream2 = UdpStreamRemote::new(local, remote).await.unwrap();
stream2.write_all(&first_packet).await.unwrap();
let mut buf = vec![0u8; 256];
stream1.read(&mut buf).await; stream2.write(&buf).await;
stream2.read(&mut buf).await; stream1.write(&buf).await;
Expand Down Expand Up @@ -56,4 +60,4 @@ async {
+------+----------+
```

*LEN is a 16-bit unsigned integer in big endian byte order.
\*LEN is a 16-bit unsigned integer in big endian byte order.
15 changes: 12 additions & 3 deletions src/frame.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,10 @@ where
loop {
match this.rd {
State::Len => {
let to_read_bytes = 2 - buf.filled().len();
assert!(buf.remaining() >= to_read_bytes);
let mut read_buf =
ReadBuf::new(buf.initialize_unfilled_to(2 - buf.filled().len()));
ReadBuf::uninit(unsafe { &mut buf.unfilled_mut()[..to_read_bytes] });
let n = ready!(Pin::new(&mut this.buf).poll_read(cx, &mut read_buf))
.map(|_| read_buf.filled().len())?;
if n == 0 {
Expand All @@ -110,7 +112,9 @@ where
buf.clear();
}
State::Data(length) => {
let mut read_buf = ReadBuf::new(buf.initialize_unfilled_to(length as usize));
assert!(buf.remaining() >= length as usize);
let mut read_buf =
ReadBuf::uninit(unsafe { &mut buf.unfilled_mut()[..length as usize] });
let n = ready!(Pin::new(&mut this.buf).poll_read(cx, &mut read_buf))
.map(|_| read_buf.filled().len())?;
if n == 0 {
Expand All @@ -119,7 +123,7 @@ where
return Poll::Ready(Ok(()));
}
buf.advance(n);
if n != length as usize {
if n < length as usize {
this.rd = State::Data(length - n as u16);
continue;
}
Expand All @@ -141,6 +145,11 @@ where

let this = self.get_mut();

// Zero-sized datagrams are not allowed
if buf.is_empty() {
return Poll::Ready(Ok(0));
}

loop {
match this.wr {
State::Len => {
Expand Down
41 changes: 31 additions & 10 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,21 @@
//! use tokio::io::{AsyncReadExt, AsyncWriteExt};
//! use udpflow::{UdpListener, UdpStreamLocal, UdpStreamRemote};
//! async fn server() {
//! let socket = UdpSocket::bind("127.0.0.1:5000").await.unwrap();
//! let listener = UdpListener::new(socket);
//! let mut buf = vec![0u8; 0x2000];
//! // listener must be continuously polled to recv packets or accept new streams
//! while let Ok((stream, addr)) = listener.accept(&mut buf).await {
//! tokio::spawn(handle(stream));
//! let addr = "127.0.0.1:5000".parse().unwrap();
//! let listener = UdpListener::new(addr).unwrap();
//! loop {
//! let mut buf = vec![0u8; 0x2000];
//! let (n, stream, addr) = listener.accept(&mut buf).await.unwrap();
//! buf.truncate(n);
//! tokio::spawn(handle(stream, buf));
//! }
//! }
//!
//! async fn handle(mut stream1: UdpStreamLocal) {
//! let socket = UdpSocket::bind("127.0.0.1:0").await.unwrap();
//! let mut stream2 = UdpStreamRemote::new(socket, "127.0.0.1:10000".parse().unwrap());
//! async fn handle(mut stream1: UdpStreamLocal, first_packet: Vec<u8>) {
//! let local = "127.0.0.1:0".parse().unwrap();
//! let remote = "127.0.0.1:10000".parse().unwrap();
//! let mut stream2 = UdpStreamRemote::new(local, remote).await.unwrap();
//! stream2.write_all(&first_packet).await.unwrap();
//! let mut buf = vec![0u8; 256];
//! stream1.read(&mut buf).await; stream2.write(&buf).await;
//! stream2.read(&mut buf).await; stream1.write(&buf).await;
Expand All @@ -43,7 +46,6 @@
//! ```
//!

mod sockmap;
mod streaml;
mod streamr;
mod listener;
Expand Down Expand Up @@ -75,3 +77,22 @@ mod statics {
}

pub use statics::{set_timeout, get_timeout};

pub(crate) fn new_udp_socket(local_addr: std::net::SocketAddr) -> std::io::Result<UdpSocket> {
let udp_sock = socket2::Socket::new(
if local_addr.is_ipv4() {
socket2::Domain::IPV4
} else {
socket2::Domain::IPV6
},
socket2::Type::DGRAM,
Some(socket2::Protocol::UDP),
)?;
udp_sock.set_reuse_address(true)?;
#[cfg(not(windows))]
udp_sock.set_reuse_port(true)?;
udp_sock.set_nonblocking(true)?;
udp_sock.bind(&socket2::SockAddr::from(local_addr))?;
let udp_sock: std::net::UdpSocket = udp_sock.into();
udp_sock.try_into()
}
48 changes: 15 additions & 33 deletions src/listener.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,52 +3,34 @@ use std::net::SocketAddr;
use std::sync::Arc;

use tokio::net::UdpSocket;
use tokio::sync::mpsc;

use crate::UdpStreamLocal;

use crate::sockmap::{SockMap, Packet};
use crate::{UdpStreamLocal, new_udp_socket};

/// Udp packet listener.
pub struct UdpListener {
socket: Arc<UdpSocket>,
sockmap: SockMap,
}

impl UdpListener {
/// Create from a **bound** udp socket.
pub fn new(socket: UdpSocket) -> Self {
Self {
socket: Arc::new(socket),
sockmap: SockMap::new(),
}
pub fn new(local_address: SocketAddr) -> std::io::Result<Self> {
Ok(Self {
socket: Arc::new(new_udp_socket(local_address)?),
})
}

/// Accept a new stream.
///
/// A listener must be continuously polled to recv packets or accept new streams.
///
/// When receiving a packet from a known peer, this function does not return,
/// and the packet will be copied then sent to the associated
/// On success, it returns peer stream socket, peer address and
/// the number of bytes read.
/// [`UdpStreamLocal`](super::UdpStreamLocal).
pub async fn accept(&self, buf: &mut [u8]) -> Result<(UdpStreamLocal, SocketAddr)> {
loop {
let (n, addr) = self.socket.recv_from(buf).await?;
debug_assert!(n != 0);

// existed session
if let Some(tx) = self.sockmap.get(&addr) {
let _ = tx.send(Vec::from(&buf[..n])).await;
continue;
}

// new session
let (tx, rx) = mpsc::channel::<Packet>(32);
let _ = tx.send(Vec::from(&buf[..n])).await;
self.sockmap.insert(addr, tx);

let stream = UdpStreamLocal::new(rx, self.socket.clone(), self.sockmap.clone(), addr);
return Ok((stream, addr));
}
pub async fn accept(&self, buf: &mut [u8]) -> Result<(usize, UdpStreamLocal, SocketAddr)> {
let (n, addr) = self.socket.recv_from(buf).await?;

debug_assert!(n != 0);

let stream = UdpStreamLocal::new(self.socket.local_addr().unwrap(), addr).await?;

Ok((n, stream, addr))
}
}
45 changes: 0 additions & 45 deletions src/sockmap.rs

This file was deleted.

59 changes: 24 additions & 35 deletions src/streaml.rs
Original file line number Diff line number Diff line change
@@ -1,65 +1,52 @@
use std::io::Result;
use std::sync::Arc;
use std::net::SocketAddr;

use std::net::{SocketAddr};
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};

use tokio::net::UdpSocket;
use tokio::time::{sleep, Sleep, Instant};
use tokio::sync::mpsc::Receiver;
use tokio::io::{ReadBuf, AsyncRead, AsyncWrite};

use crate::sockmap::{SockMap, Packet};
use crate::get_timeout;
use crate::{get_timeout, new_udp_socket};

/// Udp stream accepted from local listener.
///
/// A `Read` call times out when there is no packet received
/// during a period of time. This is treated as `EOF`, and
/// a `Ok(0)` will be returned.
pub struct UdpStreamLocal {
rx: Receiver<Packet>,
socket: Arc<UdpSocket>,
socket: UdpSocket,
timeout: Pin<Box<Sleep>>,
sockmap: SockMap,
addr: SocketAddr,
}

impl UdpStreamLocal {
pub(crate) fn new(
rx: Receiver<Packet>,
socket: Arc<UdpSocket>,
sockmap: SockMap,
addr: SocketAddr,
) -> Self {
Self {
rx,
/// Create from a **bound** udp socket.
#[inline]
pub(crate) async fn new(
local_addr: SocketAddr,
remote_addr: SocketAddr,
) -> std::io::Result<Self> {
let socket = new_udp_socket(local_addr)?;
socket.connect(remote_addr).await?;
Ok(Self {
socket,
addr,
sockmap,
timeout: Box::pin(sleep(get_timeout())),
}
})
}

/// Get peer sockaddr.
#[inline]
pub const fn peer_addr(&self) -> SocketAddr { self.addr }
pub fn peer_addr(&self) -> SocketAddr { self.socket.peer_addr().unwrap() }

/// Get local sockaddr.
#[inline]
pub fn local_addr(&self) -> SocketAddr { self.socket.local_addr().unwrap() }

/// Get inner udp socket.
#[inline]
pub const fn inner_socket(&self) -> &Arc<UdpSocket> { &self.socket }
}

impl Drop for UdpStreamLocal {
fn drop(&mut self) {
self.sockmap.remove(&self.addr);
// left elements are popped
}
pub const fn inner_socket(&self) -> &UdpSocket { &self.socket }
}

impl AsyncRead for UdpStreamLocal {
Expand All @@ -70,17 +57,20 @@ impl AsyncRead for UdpStreamLocal {
) -> Poll<Result<()>> {
let this = self.get_mut();

if let Poll::Ready(Some(pkt)) = this.rx.poll_recv(cx) {
buf.put_slice(&pkt);

if let Poll::Ready(result) = this.socket.poll_recv(cx, buf) {
// reset timer
this.timeout.as_mut().reset(Instant::now() + get_timeout());

return Poll::Ready(Ok(()));
return match result {
Ok(_) => Poll::Ready(Ok(())),
Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => Poll::Pending,
Err(e) => Poll::Ready(Err(e)),
};
}

// EOF
if this.timeout.as_mut().poll(cx).is_ready() {
buf.clear();
return Poll::Ready(Ok(()));
}

Expand All @@ -91,15 +81,14 @@ impl AsyncRead for UdpStreamLocal {
impl AsyncWrite for UdpStreamLocal {
fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize>> {
let this = self.get_mut();
this.socket.poll_send_to(cx, buf, this.addr)
this.socket.poll_send(cx, buf)
}

fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<()>> {
Poll::Ready(Ok(()))
}

fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<()>> {
self.get_mut().rx.close();
Poll::Ready(Ok(()))
}
}
Loading