diff --git a/Cargo.toml b/Cargo.toml index 325c66d..50ab020 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,6 +12,7 @@ license = "MIT" [dependencies] tokio = { version = "1", features = ["rt", "net", "time", "sync", "io-util"] } +socket2 = { version = "0.5.2", features = ["all"] } [dev-dependencies] tokio = { version = "1", features = ["full"] } diff --git a/README.md b/README.md index 7ac9a92..a8f82c1 100644 --- a/README.md +++ b/README.md @@ -12,17 +12,21 @@ use tokio::net::UdpSocket; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use udpflow::{UdpListener, UdpStreamLocal, UdpStreamRemote}; async fn server() { - let socket = UdpSocket::bind("127.0.0.1:5000").await.unwrap(); - let listener = UdpListener::new(socket); - let mut buf = vec![0u8; 0x2000]; - // listener must be continuously polled to recv packets or accept new streams - while let Ok((stream, addr)) = listener.accept(&mut buf).await { - tokio::spawn(handle(stream)); + let addr = "127.0.0.1:5000".parse().unwrap(); + let listener = UdpListener::new(addr).unwrap(); + loop { + let mut buf = vec![0u8; 0x2000]; + let (n, stream, addr) = listener.accept(&mut buf).await.unwrap(); + buf.truncate(n); + tokio::spawn(handle(stream, buf)); } } -async fn handle(mut stream1: UdpStreamLocal) { - let socket = UdpSocket::bind("127.0.0.1:0").await.unwrap(); - let mut stream2 = UdpStreamRemote::new(socket, "127.0.0.1:10000".parse().unwrap()); + +async fn handle(mut stream1: UdpStreamLocal, first_packet: Vec) { + let local = "127.0.0.1:0".parse().unwrap(); + let remote = "127.0.0.1:10000".parse().unwrap(); + let mut stream2 = UdpStreamRemote::new(local, remote).await.unwrap(); + stream2.write_all(&first_packet).await.unwrap(); let mut buf = vec![0u8; 256]; stream1.read(&mut buf).await; stream2.write(&buf).await; stream2.read(&mut buf).await; stream1.write(&buf).await; @@ -56,4 +60,4 @@ async { +------+----------+ ``` -*LEN is a 16-bit unsigned integer in big endian byte order. +\*LEN is a 16-bit unsigned integer in big endian byte order. diff --git a/src/frame.rs b/src/frame.rs index a185688..18b8104 100644 --- a/src/frame.rs +++ b/src/frame.rs @@ -93,8 +93,10 @@ where loop { match this.rd { State::Len => { + let to_read_bytes = 2 - buf.filled().len(); + assert!(buf.remaining() >= to_read_bytes); let mut read_buf = - ReadBuf::new(buf.initialize_unfilled_to(2 - buf.filled().len())); + ReadBuf::uninit(unsafe { &mut buf.unfilled_mut()[..to_read_bytes] }); let n = ready!(Pin::new(&mut this.buf).poll_read(cx, &mut read_buf)) .map(|_| read_buf.filled().len())?; if n == 0 { @@ -110,7 +112,9 @@ where buf.clear(); } State::Data(length) => { - let mut read_buf = ReadBuf::new(buf.initialize_unfilled_to(length as usize)); + assert!(buf.remaining() >= length as usize); + let mut read_buf = + ReadBuf::uninit(unsafe { &mut buf.unfilled_mut()[..length as usize] }); let n = ready!(Pin::new(&mut this.buf).poll_read(cx, &mut read_buf)) .map(|_| read_buf.filled().len())?; if n == 0 { @@ -119,7 +123,7 @@ where return Poll::Ready(Ok(())); } buf.advance(n); - if n != length as usize { + if n < length as usize { this.rd = State::Data(length - n as u16); continue; } @@ -141,6 +145,11 @@ where let this = self.get_mut(); + // Zero-sized datagrams are not allowed + if buf.is_empty() { + return Poll::Ready(Ok(0)); + } + loop { match this.wr { State::Len => { diff --git a/src/lib.rs b/src/lib.rs index 0e3adae..c2a43d4 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -7,18 +7,21 @@ //! use tokio::io::{AsyncReadExt, AsyncWriteExt}; //! use udpflow::{UdpListener, UdpStreamLocal, UdpStreamRemote}; //! async fn server() { -//! let socket = UdpSocket::bind("127.0.0.1:5000").await.unwrap(); -//! let listener = UdpListener::new(socket); -//! let mut buf = vec![0u8; 0x2000]; -//! // listener must be continuously polled to recv packets or accept new streams -//! while let Ok((stream, addr)) = listener.accept(&mut buf).await { -//! tokio::spawn(handle(stream)); +//! let addr = "127.0.0.1:5000".parse().unwrap(); +//! let listener = UdpListener::new(addr).unwrap(); +//! loop { +//! let mut buf = vec![0u8; 0x2000]; +//! let (n, stream, addr) = listener.accept(&mut buf).await.unwrap(); +//! buf.truncate(n); +//! tokio::spawn(handle(stream, buf)); //! } //! } //! -//! async fn handle(mut stream1: UdpStreamLocal) { -//! let socket = UdpSocket::bind("127.0.0.1:0").await.unwrap(); -//! let mut stream2 = UdpStreamRemote::new(socket, "127.0.0.1:10000".parse().unwrap()); +//! async fn handle(mut stream1: UdpStreamLocal, first_packet: Vec) { +//! let local = "127.0.0.1:0".parse().unwrap(); +//! let remote = "127.0.0.1:10000".parse().unwrap(); +//! let mut stream2 = UdpStreamRemote::new(local, remote).await.unwrap(); +//! stream2.write_all(&first_packet).await.unwrap(); //! let mut buf = vec![0u8; 256]; //! stream1.read(&mut buf).await; stream2.write(&buf).await; //! stream2.read(&mut buf).await; stream1.write(&buf).await; @@ -43,7 +46,6 @@ //! ``` //! -mod sockmap; mod streaml; mod streamr; mod listener; @@ -75,3 +77,22 @@ mod statics { } pub use statics::{set_timeout, get_timeout}; + +pub(crate) fn new_udp_socket(local_addr: std::net::SocketAddr) -> std::io::Result { + let udp_sock = socket2::Socket::new( + if local_addr.is_ipv4() { + socket2::Domain::IPV4 + } else { + socket2::Domain::IPV6 + }, + socket2::Type::DGRAM, + Some(socket2::Protocol::UDP), + )?; + udp_sock.set_reuse_address(true)?; + #[cfg(not(windows))] + udp_sock.set_reuse_port(true)?; + udp_sock.set_nonblocking(true)?; + udp_sock.bind(&socket2::SockAddr::from(local_addr))?; + let udp_sock: std::net::UdpSocket = udp_sock.into(); + udp_sock.try_into() +} diff --git a/src/listener.rs b/src/listener.rs index a10c027..a49afc6 100644 --- a/src/listener.rs +++ b/src/listener.rs @@ -3,52 +3,34 @@ use std::net::SocketAddr; use std::sync::Arc; use tokio::net::UdpSocket; -use tokio::sync::mpsc; -use crate::UdpStreamLocal; - -use crate::sockmap::{SockMap, Packet}; +use crate::{UdpStreamLocal, new_udp_socket}; /// Udp packet listener. pub struct UdpListener { socket: Arc, - sockmap: SockMap, } impl UdpListener { /// Create from a **bound** udp socket. - pub fn new(socket: UdpSocket) -> Self { - Self { - socket: Arc::new(socket), - sockmap: SockMap::new(), - } + pub fn new(local_address: SocketAddr) -> std::io::Result { + Ok(Self { + socket: Arc::new(new_udp_socket(local_address)?), + }) } /// Accept a new stream. /// - /// A listener must be continuously polled to recv packets or accept new streams. - /// - /// When receiving a packet from a known peer, this function does not return, - /// and the packet will be copied then sent to the associated + /// On success, it returns peer stream socket, peer address and + /// the number of bytes read. /// [`UdpStreamLocal`](super::UdpStreamLocal). - pub async fn accept(&self, buf: &mut [u8]) -> Result<(UdpStreamLocal, SocketAddr)> { - loop { - let (n, addr) = self.socket.recv_from(buf).await?; - debug_assert!(n != 0); - - // existed session - if let Some(tx) = self.sockmap.get(&addr) { - let _ = tx.send(Vec::from(&buf[..n])).await; - continue; - } - - // new session - let (tx, rx) = mpsc::channel::(32); - let _ = tx.send(Vec::from(&buf[..n])).await; - self.sockmap.insert(addr, tx); - - let stream = UdpStreamLocal::new(rx, self.socket.clone(), self.sockmap.clone(), addr); - return Ok((stream, addr)); - } + pub async fn accept(&self, buf: &mut [u8]) -> Result<(usize, UdpStreamLocal, SocketAddr)> { + let (n, addr) = self.socket.recv_from(buf).await?; + + debug_assert!(n != 0); + + let stream = UdpStreamLocal::new(self.socket.local_addr().unwrap(), addr).await?; + + Ok((n, stream, addr)) } } diff --git a/src/sockmap.rs b/src/sockmap.rs deleted file mode 100644 index 27381d9..0000000 --- a/src/sockmap.rs +++ /dev/null @@ -1,45 +0,0 @@ -use std::net::SocketAddr; -use std::sync::{Arc, RwLock}; -use std::collections::HashMap; - -use tokio::sync::mpsc::Sender; - -pub(crate) type Packet = Vec; - -#[derive(Clone)] -pub(crate) struct SockMap(Arc>>>); - -impl SockMap { - pub fn new() -> Self { Self(Arc::new(RwLock::new(HashMap::new()))) } - - #[inline] - pub fn get(&self, addr: &SocketAddr) -> Option> { - // fetch the lock - - let sockmap = self.0.read().unwrap(); - - sockmap.get(addr).cloned() - - // drop the lock - } - - #[inline] - pub fn insert(&self, addr: SocketAddr, tx: Sender) { - // fetch the lock - let mut sockmap = self.0.write().unwrap(); - - let _ = sockmap.insert(addr, tx); - - // drop the lock - } - - #[inline] - pub fn remove(&self, addr: &SocketAddr) { - // fetch the lock - let mut sockmap = self.0.write().unwrap(); - - let _ = sockmap.remove(addr); - - // drop the lock - } -} diff --git a/src/streaml.rs b/src/streaml.rs index c777692..70eeed5 100644 --- a/src/streaml.rs +++ b/src/streaml.rs @@ -1,17 +1,15 @@ use std::io::Result; -use std::sync::Arc; -use std::net::SocketAddr; + +use std::net::{SocketAddr}; use std::future::Future; use std::pin::Pin; use std::task::{Context, Poll}; use tokio::net::UdpSocket; use tokio::time::{sleep, Sleep, Instant}; -use tokio::sync::mpsc::Receiver; use tokio::io::{ReadBuf, AsyncRead, AsyncWrite}; -use crate::sockmap::{SockMap, Packet}; -use crate::get_timeout; +use crate::{get_timeout, new_udp_socket}; /// Udp stream accepted from local listener. /// @@ -19,32 +17,28 @@ use crate::get_timeout; /// during a period of time. This is treated as `EOF`, and /// a `Ok(0)` will be returned. pub struct UdpStreamLocal { - rx: Receiver, - socket: Arc, + socket: UdpSocket, timeout: Pin>, - sockmap: SockMap, - addr: SocketAddr, } impl UdpStreamLocal { - pub(crate) fn new( - rx: Receiver, - socket: Arc, - sockmap: SockMap, - addr: SocketAddr, - ) -> Self { - Self { - rx, + /// Create from a **bound** udp socket. + #[inline] + pub(crate) async fn new( + local_addr: SocketAddr, + remote_addr: SocketAddr, + ) -> std::io::Result { + let socket = new_udp_socket(local_addr)?; + socket.connect(remote_addr).await?; + Ok(Self { socket, - addr, - sockmap, timeout: Box::pin(sleep(get_timeout())), - } + }) } /// Get peer sockaddr. #[inline] - pub const fn peer_addr(&self) -> SocketAddr { self.addr } + pub fn peer_addr(&self) -> SocketAddr { self.socket.peer_addr().unwrap() } /// Get local sockaddr. #[inline] @@ -52,14 +46,7 @@ impl UdpStreamLocal { /// Get inner udp socket. #[inline] - pub const fn inner_socket(&self) -> &Arc { &self.socket } -} - -impl Drop for UdpStreamLocal { - fn drop(&mut self) { - self.sockmap.remove(&self.addr); - // left elements are popped - } + pub const fn inner_socket(&self) -> &UdpSocket { &self.socket } } impl AsyncRead for UdpStreamLocal { @@ -70,17 +57,20 @@ impl AsyncRead for UdpStreamLocal { ) -> Poll> { let this = self.get_mut(); - if let Poll::Ready(Some(pkt)) = this.rx.poll_recv(cx) { - buf.put_slice(&pkt); - + if let Poll::Ready(result) = this.socket.poll_recv(cx, buf) { // reset timer this.timeout.as_mut().reset(Instant::now() + get_timeout()); - return Poll::Ready(Ok(())); + return match result { + Ok(_) => Poll::Ready(Ok(())), + Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => Poll::Pending, + Err(e) => Poll::Ready(Err(e)), + }; } // EOF if this.timeout.as_mut().poll(cx).is_ready() { + buf.clear(); return Poll::Ready(Ok(())); } @@ -91,7 +81,7 @@ impl AsyncRead for UdpStreamLocal { impl AsyncWrite for UdpStreamLocal { fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll> { let this = self.get_mut(); - this.socket.poll_send_to(cx, buf, this.addr) + this.socket.poll_send(cx, buf) } fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { @@ -99,7 +89,6 @@ impl AsyncWrite for UdpStreamLocal { } fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { - self.get_mut().rx.close(); Poll::Ready(Ok(())) } } diff --git a/src/streamr.rs b/src/streamr.rs index 9864cff..0512e58 100644 --- a/src/streamr.rs +++ b/src/streamr.rs @@ -8,7 +8,7 @@ use tokio::net::UdpSocket; use tokio::time::{sleep, Sleep, Instant}; use tokio::io::{ReadBuf, AsyncRead, AsyncWrite}; -use crate::get_timeout; +use crate::{get_timeout, new_udp_socket}; /// Udp stream which is actively established. /// @@ -18,23 +18,23 @@ use crate::get_timeout; pub struct UdpStreamRemote { socket: UdpSocket, timeout: Pin>, - addr: SocketAddr, } impl UdpStreamRemote { /// Create from a **bound** udp socket. #[inline] - pub fn new(socket: UdpSocket, addr: SocketAddr) -> Self { - Self { + pub async fn new(local_addr: SocketAddr, remote_addr: SocketAddr) -> std::io::Result { + let socket = new_udp_socket(local_addr)?; + socket.connect(remote_addr).await?; + Ok(Self { socket, - addr, timeout: Box::pin(sleep(get_timeout())), - } + }) } /// Get peer sockaddr. #[inline] - pub const fn peer_addr(&self) -> SocketAddr { self.addr } + pub fn peer_addr(&self) -> SocketAddr { self.socket.peer_addr().unwrap() } /// Get local sockaddr. #[inline] @@ -53,15 +53,20 @@ impl AsyncRead for UdpStreamRemote { ) -> Poll> { let this = self.get_mut(); - if let Poll::Ready(x) = this.socket.poll_recv_from(cx, buf) { + if let Poll::Ready(result) = this.socket.poll_recv(cx, buf) { // reset timer this.timeout.as_mut().reset(Instant::now() + get_timeout()); - return Poll::Ready(x.map(|_| ())); + return match result { + Ok(_) => Poll::Ready(Ok(())), + Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => Poll::Pending, + Err(e) => Poll::Ready(Err(e)), + }; } // EOF if this.timeout.as_mut().poll(cx).is_ready() { + buf.clear(); return Poll::Ready(Ok(())); } @@ -72,7 +77,7 @@ impl AsyncRead for UdpStreamRemote { impl AsyncWrite for UdpStreamRemote { fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll> { let this = self.get_mut(); - this.socket.poll_send_to(cx, buf, this.addr) + this.socket.poll_send(cx, buf) } fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { diff --git a/tests/bidi_copy.rs b/tests/bidi_copy.rs index d303aa9..6eb43e7 100644 --- a/tests/bidi_copy.rs +++ b/tests/bidi_copy.rs @@ -1,6 +1,6 @@ use std::net::SocketAddr; use std::time::Duration; -use tokio::time::sleep; +use tokio::{time::sleep, io::AsyncWriteExt}; use udpflow::{UdpSocket, UdpListener, UdpStreamLocal, UdpStreamRemote}; const BIND: &str = "127.0.0.1:10000"; @@ -37,20 +37,23 @@ async fn client() { } async fn relay_server() { - let socket = UdpSocket::bind(BIND).await.unwrap(); - let listener = UdpListener::new(socket); - - let mut buf = vec![0u8; 0x2000]; + let addr = BIND.parse::().unwrap(); + let listener = UdpListener::new(addr).unwrap(); - while let Ok((stream, addr)) = listener.accept(&mut buf).await { + loop { + let mut buf = vec![0u8; 0x2000]; + let (n, stream, addr) = listener.accept(&mut buf).await.unwrap(); + buf.truncate(n); assert_eq!(addr, SENDER.parse().unwrap()); - tokio::spawn(handle(stream)); + tokio::spawn(handle(stream, buf)); } } -async fn handle(mut stream1: UdpStreamLocal) { - let socket = UdpSocket::bind("0.0.0.0:0").await.unwrap(); - let mut stream2 = UdpStreamRemote::new(socket, RECVER.parse().unwrap()); +async fn handle(mut stream1: UdpStreamLocal, buf: Vec) { + let local = "0.0.0.0:0".parse::().unwrap(); + let remote = RECVER.parse::().unwrap(); + let mut stream2 = UdpStreamRemote::new(local, remote).await.unwrap(); + stream2.write_all(&buf).await.unwrap(); let _ = tokio::io::copy_bidirectional(&mut stream1, &mut stream2).await; } diff --git a/tests/bidi_copy_reassociate.rs b/tests/bidi_copy_reassociate.rs index bbd74ea..93ac29c 100644 --- a/tests/bidi_copy_reassociate.rs +++ b/tests/bidi_copy_reassociate.rs @@ -1,6 +1,6 @@ use std::net::SocketAddr; use std::time::Duration; -use tokio::time::sleep; +use tokio::{time::sleep, io::AsyncWriteExt}; use udpflow::{UdpSocket, UdpListener, UdpStreamLocal, UdpStreamRemote}; const BIND: &str = "127.0.0.1:10000"; @@ -9,14 +9,14 @@ const RECVER: &str = "127.0.0.1:15000"; const MSG: &[u8] = b"Ciallo"; const INTV: Duration = Duration::from_millis(20); const WAIT: Duration = Duration::from_millis(200); -const TIMEOUT: Duration = Duration::from_millis(300); +const TIMEOUT: Duration = Duration::from_secs(60); #[tokio::test] async fn bidi_copy_reassociate() { udpflow::set_timeout(TIMEOUT); tokio::select! { _ = client() => {}, - _ = async { tokio::join!(echo_server(),relay_server()) } => {} + _ = async { tokio::join!(echo_server(), relay_server()) } => {} }; } @@ -40,21 +40,24 @@ async fn client() { } async fn relay_server() { - let socket = UdpSocket::bind(BIND).await.unwrap(); - let listener = UdpListener::new(socket); - - let mut buf = vec![0u8; 0x2000]; + let addr = BIND.parse::().unwrap(); + let listener = UdpListener::new(addr).unwrap(); - while let Ok((stream, addr)) = listener.accept(&mut buf).await { + loop { + let mut buf = vec![0u8; 0x2000]; + let (n, stream, addr) = listener.accept(&mut buf).await.unwrap(); + buf.truncate(n); assert_eq!(addr, SENDER.parse().unwrap()); - tokio::spawn(handle(stream)); + tokio::spawn(handle(stream, buf)); } } -async fn handle(mut stream1: UdpStreamLocal) { +async fn handle(mut stream1: UdpStreamLocal, buf: Vec) { println!("relay: spawned"); - let socket = UdpSocket::bind("0.0.0.0:0").await.unwrap(); - let mut stream2 = UdpStreamRemote::new(socket, RECVER.parse().unwrap()); + let local = "0.0.0.0:0".parse::().unwrap(); + let remote = RECVER.parse::().unwrap(); + let mut stream2 = UdpStreamRemote::new(local, remote).await.unwrap(); + stream2.write_all(&buf).await.unwrap(); let _ = tokio::io::copy_bidirectional(&mut stream1, &mut stream2).await; println!("relay: timeout"); } diff --git a/tests/bidi_copy_uot.rs b/tests/bidi_copy_uot.rs index 4084975..90e6998 100644 --- a/tests/bidi_copy_uot.rs +++ b/tests/bidi_copy_uot.rs @@ -1,5 +1,6 @@ use std::net::SocketAddr; use std::time::Duration; +use tokio::io::AsyncWriteExt; use tokio::time::sleep; use tokio::net::{TcpStream, TcpListener}; use udpflow::{UdpSocket, UdpListener, UdpStreamLocal, UdpStreamRemote, UotStream}; @@ -16,9 +17,11 @@ async fn bidi_copy_uot() { tokio::select! { _ = client() => {}, _ = async { - tokio::join!(async { - tokio::join!(relay_server1(), relay_server2()) - }, echo_server()) + tokio::join!( + tokio::spawn(relay_server1()), + tokio::spawn(relay_server2()), + tokio::spawn(echo_server()) + ) } => {} }; } @@ -44,20 +47,22 @@ async fn client() { // udp -> tcp async fn relay_server1() { - let socket = UdpSocket::bind(RELAY1).await.unwrap(); - let listener = UdpListener::new(socket); - - let mut buf = vec![0u8; 0x2000]; + let addr = RELAY1.parse::().unwrap(); + let listener = UdpListener::new(addr).unwrap(); - while let Ok((stream, addr)) = listener.accept(&mut buf).await { + loop { + let mut buf = vec![0u8; 0x2000]; + let (n, stream, addr) = listener.accept(&mut buf).await.unwrap(); + buf.truncate(n); assert_eq!(addr, SENDER.parse().unwrap()); - tokio::spawn(handle1(stream)); + tokio::spawn(handle1(stream, buf)); } } // recv packet, send framed data -async fn handle1(mut stream1: UdpStreamLocal) { +async fn handle1(mut stream1: UdpStreamLocal, buf: Vec) { let mut stream2 = UotStream::new(TcpStream::connect(RELAY2).await.unwrap()); + stream2.write_all(&buf).await.unwrap(); let _ = tokio::io::copy_bidirectional(&mut stream1, &mut stream2).await; } @@ -65,15 +70,17 @@ async fn handle1(mut stream1: UdpStreamLocal) { async fn relay_server2() { let listener = TcpListener::bind(RELAY2).await.unwrap(); - while let Ok((stream, _)) = listener.accept().await { + loop { + let (stream, _) = listener.accept().await.unwrap(); tokio::spawn(handle2(UotStream::new(stream))); } } // recv framed data, send packet async fn handle2(mut stream1: UotStream) { - let socket = UdpSocket::bind("0.0.0.0:0").await.unwrap(); - let mut stream2 = UdpStreamRemote::new(socket, RECVER.parse().unwrap()); + let local = "0.0.0.0:0".parse::().unwrap(); + let remote = RECVER.parse::().unwrap(); + let mut stream2 = UdpStreamRemote::new(local, remote).await.unwrap(); let _ = tokio::io::copy_bidirectional(&mut stream1, &mut stream2).await; } diff --git a/tests/local.rs b/tests/local.rs index 53dc8a0..51774aa 100644 --- a/tests/local.rs +++ b/tests/local.rs @@ -37,23 +37,30 @@ async fn client() { } async fn server() { - let socket = UdpSocket::bind(BIND).await.unwrap(); - let listener = UdpListener::new(socket); - - let mut buf = vec![0u8; 0x2000]; + let addr = BIND.parse::().unwrap(); + let listener = UdpListener::new(addr).unwrap(); - while let Ok((stream, addr)) = listener.accept(&mut buf).await { + loop { + let mut buf = vec![0u8; 0x2000]; + let (n, stream, addr) = listener.accept(&mut buf).await.unwrap(); + buf.truncate(n); assert_eq!(addr, SENDER.parse().unwrap()); - tokio::spawn(handle(stream)); + tokio::spawn(handle(stream, buf)); } } -async fn handle(mut stream: UdpStreamLocal) { +async fn handle(mut stream: UdpStreamLocal, first_packet: Vec) { let mut buf = [0u8; 32]; let mut i = 0; loop { println!("server: recv[{}]..", i); - let n = stream.read(&mut buf).await.unwrap(); + let n = if i == 0 { + let len = buf.len().min(first_packet.len()); + buf[..len].copy_from_slice(&first_packet[..len]); + len + } else { + stream.read(&mut buf).await.unwrap() + }; assert_eq!(&buf[..n], MSG); println!("server: send[{}]..", i); diff --git a/tests/local_multi.rs b/tests/local_multi.rs index 8feeff3..7d282ab 100644 --- a/tests/local_multi.rs +++ b/tests/local_multi.rs @@ -43,25 +43,33 @@ async fn client(laddr: &'static str, idx: usize) { } async fn server() { - let socket = UdpSocket::bind(BIND).await.unwrap(); - let listener = UdpListener::new(socket); + let addr = BIND.parse::().unwrap(); + let listener = UdpListener::new(addr).unwrap(); - let mut buf = vec![0u8; 0x2000]; let mut idx = 0; - while let Ok((stream, addr)) = listener.accept(&mut buf).await { + loop { + let mut buf = vec![0u8; 0x2000]; + let (n, stream, addr) = listener.accept(&mut buf).await.unwrap(); + buf.truncate(n); println!("server: handle {}", addr); - tokio::spawn(handle(stream, idx)); + tokio::spawn(handle(stream, idx, buf)); idx += 1; } } -async fn handle(mut stream: UdpStreamLocal, idx: usize) { +async fn handle(mut stream: UdpStreamLocal, idx: usize, first_packet: Vec) { let mut buf = [0u8; 32]; let mut i = 0; loop { println!("handle[{idx}]: recv[{}]..", i); - let n = stream.read(&mut buf).await.unwrap(); + let n = if i == 0 { + let len = buf.len().min(first_packet.len()); + buf[..len].copy_from_slice(&first_packet[..len]); + len + } else { + stream.read(&mut buf).await.unwrap() + }; assert_eq!(&buf[..n], MSG); println!("handle[{idx}]: send[{}]..", i); diff --git a/tests/local_timeout.rs b/tests/local_timeout.rs index c55e503..d220814 100644 --- a/tests/local_timeout.rs +++ b/tests/local_timeout.rs @@ -44,23 +44,30 @@ async fn client() { } async fn server() { - let socket = UdpSocket::bind(BIND).await.unwrap(); - let listener = UdpListener::new(socket); - - let mut buf = vec![0u8; 0x2000]; + let addr = BIND.parse::().unwrap(); + let listener = UdpListener::new(addr).unwrap(); - while let Ok((stream, addr)) = listener.accept(&mut buf).await { + loop { + let mut buf = vec![0u8; 0x2000]; + let (n, stream, addr) = listener.accept(&mut buf).await.unwrap(); + buf.truncate(n); assert_eq!(addr, SENDER.parse().unwrap()); - tokio::spawn(handle(stream)); + tokio::spawn(handle(stream, buf)); } } -async fn handle(mut stream: UdpStreamLocal) { +async fn handle(mut stream: UdpStreamLocal, first_packet: Vec) { let mut buf = [0u8; 32]; let mut i = 0; loop { println!("server: recv[{}]..", i); - let n = stream.read(&mut buf).await.unwrap(); + let n = if i == 0 { + let len = buf.len().min(first_packet.len()); + buf[..len].copy_from_slice(&first_packet[..len]); + len + } else { + stream.read(&mut buf).await.unwrap() + }; if n == 0 { println!("server: recv[{}].. EOF", i); diff --git a/tests/remote.rs b/tests/remote.rs index 2097e42..ecf74cc 100644 --- a/tests/remote.rs +++ b/tests/remote.rs @@ -20,9 +20,9 @@ async fn remote() { async fn client() { sleep(WAIT).await; - let addr = BIND.parse::().unwrap(); - let socket = UdpSocket::bind(SENDER).await.unwrap(); - let mut stream = UdpStreamRemote::new(socket, addr); + let local = SENDER.parse::().unwrap(); + let remote = BIND.parse::().unwrap(); + let mut stream = UdpStreamRemote::new(local, remote).await.unwrap(); let mut buf = [0u8; 32]; for i in 0..5 { diff --git a/tests/remote_timeout.rs b/tests/remote_timeout.rs index f44e510..f572022 100644 --- a/tests/remote_timeout.rs +++ b/tests/remote_timeout.rs @@ -22,9 +22,9 @@ async fn remote_timeout() { async fn client() { sleep(WAIT).await; - let addr = BIND.parse::().unwrap(); - let socket = UdpSocket::bind(SENDER).await.unwrap(); - let mut stream = UdpStreamRemote::new(socket, addr); + let local = SENDER.parse::().unwrap(); + let remote = BIND.parse::().unwrap(); + let mut stream = UdpStreamRemote::new(local, remote).await.unwrap(); let mut buf = [0u8; 32]; for i in 0..3 {