diff --git a/Cargo.lock b/Cargo.lock index ac98d773..7d973fc3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1330,6 +1330,7 @@ dependencies = [ "ctrlc", "delegate", "educe", + "io-uring", "ipnet", "jsonwebtoken", "libc", @@ -1354,6 +1355,7 @@ dependencies = [ "tracing", "tracing-log", "tracing-subscriber", + "tun", "twelf", ] diff --git a/Cargo.toml b/Cargo.toml index 6afc7e2d..38760725 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -34,6 +34,7 @@ clap = { version = "4.4.7", features = ["derive"] } ctrlc = { version = "3.4.2", features = ["termination"] } delegate = "0.12.0" educe = { version = "0.6.0", default-features = false, features = ["Debug"] } +io-uring = "0.7.0" ipnet = { version = "2.8.0", features = ["serde"]} libc = "0.2.152" lightway-app-utils = { path = "./lightway-app-utils" } @@ -52,3 +53,4 @@ tokio-util = "0.7.10" tracing = "0.1.37" tracing-subscriber = "0.3.17" twelf = { version = "0.15.0", default-features = false, features = ["env", "clap", "yaml"]} +tun = { version = "0.7.1" } diff --git a/README.md b/README.md index 8c43cb02..84f7c081 100644 --- a/README.md +++ b/README.md @@ -25,7 +25,7 @@ Protocol and design documentation can be found in the Lightway rust implementation currently supports Linux OS. Both x86_64 and arm64 platforms are supported and built as part of CI. -Support for other platforms will be added soon. +Support for other client platforms will be added soon. ## Development steps diff --git a/lightway-app-utils/Cargo.toml b/lightway-app-utils/Cargo.toml index 3a73a13d..c152beef 100644 --- a/lightway-app-utils/Cargo.toml +++ b/lightway-app-utils/Cargo.toml @@ -23,7 +23,7 @@ bytes.workspace = true clap.workspace = true fs-mistrust = { version = "0.8.0", default-features = false } humantime = "2.1.0" -io-uring = { version = "0.7.0", optional = true } +io-uring = { workspace = true, optional = true } ipnet.workspace = true libc.workspace = true lightway-core.workspace = true @@ -38,11 +38,12 @@ tokio-stream = { workspace = true, optional = true } tokio-util.workspace = true tracing.workspace = true tracing-subscriber = { workspace = true, features = ["json"] } -tun = { version = "0.7", features = ["async"] } +tun = { workspace = true, features = ["async"] } [[example]] name = "udprelay" path = "examples/udprelay.rs" +required-features = ["io-uring"] [dev-dependencies] async-trait.workspace = true diff --git a/lightway-app-utils/src/lib.rs b/lightway-app-utils/src/lib.rs index 4e48b6c1..5e5b0215 100644 --- a/lightway-app-utils/src/lib.rs +++ b/lightway-app-utils/src/lib.rs @@ -14,6 +14,9 @@ mod event_stream; mod iouring; mod tun; +mod net; +pub use net::{sockaddr_from_socket_addr, socket_addr_from_sockaddr}; + #[cfg(feature = "tokio")] pub use connection_ticker::{ connection_ticker_cb, ConnectionTicker, ConnectionTickerState, ConnectionTickerTask, Tickable, diff --git a/lightway-app-utils/src/net.rs b/lightway-app-utils/src/net.rs new file mode 100644 index 00000000..b306c6df --- /dev/null +++ b/lightway-app-utils/src/net.rs @@ -0,0 +1,179 @@ +use std::{io, net::SocketAddr}; + +/// Convert from `libc::sockaddr_storage` to `std::net::SocketAddr` +#[allow(unsafe_code)] +pub fn socket_addr_from_sockaddr( + storage: &libc::sockaddr_storage, + len: libc::socklen_t, +) -> io::Result { + match storage.ss_family as libc::c_int { + libc::AF_INET => { + if (len as usize) < std::mem::size_of::() { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "invalid argument (inet len)", + )); + } + + // SAFETY: Casting from sockaddr_storage to sockaddr_in is safe since we have validated the len. + let addr = + unsafe { &*(storage as *const libc::sockaddr_storage as *const libc::sockaddr_in) }; + + let ip = u32::from_be(addr.sin_addr.s_addr); + let ip = std::net::Ipv4Addr::from_bits(ip); + let port = u16::from_be(addr.sin_port); + + Ok((ip, port).into()) + } + libc::AF_INET6 => { + if (len as usize) < std::mem::size_of::() { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "invalid argument (inet6 len)", + )); + } + // SAFETY: Casting from sockaddr_storage to sockaddr_in6 is safe since we have validated the len. + let addr = unsafe { + &*(storage as *const libc::sockaddr_storage as *const libc::sockaddr_in6) + }; + + let ip = u128::from_be_bytes(addr.sin6_addr.s6_addr); + let ip = std::net::Ipv6Addr::from_bits(ip); + let port = u16::from_be(addr.sin6_port); + + Ok((ip, port).into()) + } + _ => Err(io::Error::new( + std::io::ErrorKind::InvalidInput, + "invalid argument (ss_family)", + )), + } +} + +/// Convert from `std::net::SocketAddr` to `libc::sockaddr_storage`+`libc::socklen_t` +#[allow(unsafe_code)] +pub fn sockaddr_from_socket_addr(addr: SocketAddr) -> (libc::sockaddr_storage, libc::socklen_t) { + // SAFETY: All zeroes is a valid sockaddr_storage + let mut storage: libc::sockaddr_storage = unsafe { std::mem::zeroed() }; + + let len = match addr { + SocketAddr::V4(v4) => { + let p = &mut storage as *mut libc::sockaddr_storage as *mut libc::sockaddr_in; + // SAFETY: sockaddr_storage is defined to be big enough for any sockaddr_*. + unsafe { + p.write(libc::sockaddr_in { + sin_family: libc::AF_INET as _, + sin_port: v4.port().to_be(), + sin_addr: libc::in_addr { + s_addr: v4.ip().to_bits().to_be(), + }, + sin_zero: Default::default(), + }) + }; + std::mem::size_of::() as libc::socklen_t + } + SocketAddr::V6(v6) => { + let p = &mut storage as *mut libc::sockaddr_storage as *mut libc::sockaddr_in6; + // SAFETY: sockaddr_storage is defined to be big enough for any sockaddr_*. + unsafe { + p.write(libc::sockaddr_in6 { + sin6_family: libc::AF_INET6 as _, + sin6_port: v6.port().to_be(), + sin6_flowinfo: v6.flowinfo().to_be(), + sin6_addr: libc::in6_addr { + s6_addr: v6.ip().to_bits().to_be_bytes(), + }, + sin6_scope_id: v6.scope_id().to_be(), + }) + }; + std::mem::size_of::() as libc::socklen_t + } + }; + + (storage, len) +} + +#[cfg(test)] +mod tests { + #![allow(unsafe_code, clippy::undocumented_unsafe_blocks)] + + use std::{ + net::{IpAddr, Ipv4Addr, Ipv6Addr}, + str::FromStr as _, + }; + + use super::*; + + use test_case::test_case; + + #[test] + fn socket_addr_from_sockaddr_unknown_af() { + // Test assumes these don't match the zero initialized + // libc::sockaddr_storage::ss_family. + assert_ne!(libc::AF_INET, 0); + assert_ne!(libc::AF_INET6, 0); + + let storage = unsafe { std::mem::zeroed() }; + let err = + socket_addr_from_sockaddr(&storage, std::mem::size_of::() as _) + .unwrap_err(); + + assert!(matches!(err.kind(), std::io::ErrorKind::InvalidInput)); + assert!(err.to_string().contains("invalid argument (ss_family)")); + } + + #[test] + fn socket_addr_from_sockaddr_unknown_af_inet_short() { + let mut storage: libc::sockaddr_storage = unsafe { std::mem::zeroed() }; + storage.ss_family = libc::AF_INET as libc::sa_family_t; + + let err = socket_addr_from_sockaddr( + &storage, + (std::mem::size_of::() - 1) as _, + ) + .unwrap_err(); + + assert!(matches!(err.kind(), std::io::ErrorKind::InvalidInput)); + assert!(err.to_string().contains("invalid argument (inet len)")); + } + + #[test] + fn socket_addr_from_sockaddr_unknown_af_inet6_short() { + let mut storage: libc::sockaddr_storage = unsafe { std::mem::zeroed() }; + storage.ss_family = libc::AF_INET6 as libc::sa_family_t; + + let err = socket_addr_from_sockaddr( + &storage, + (std::mem::size_of::() - 1) as _, + ) + .unwrap_err(); + + assert!(matches!(err.kind(), std::io::ErrorKind::InvalidInput)); + assert!(err.to_string().contains("invalid argument (inet6 len)")); + } + + #[test] + fn sockaddr_from_socket_addr_inet() { + let socket_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8080); + let (storage, len) = sockaddr_from_socket_addr(socket_addr); + assert_eq!(storage.ss_family, libc::AF_INET as libc::sa_family_t); + assert_eq!(len as usize, std::mem::size_of::()); + } + + #[test] + fn sockaddr_from_socket_addr_inet6() { + let socket_addr = SocketAddr::new(IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)), 8080); + let (storage, len) = sockaddr_from_socket_addr(socket_addr); + assert_eq!(storage.ss_family, libc::AF_INET6 as libc::sa_family_t); + assert_eq!(len as usize, std::mem::size_of::()); + } + + #[test_case("127.0.0.1:443")] + #[test_case("[::1]:8888")] + fn round_trip(addr: &str) { + let orig = SocketAddr::from_str(addr).unwrap(); + let (storage, len) = sockaddr_from_socket_addr(orig); + let round_tripped = socket_addr_from_sockaddr(&storage, len).unwrap(); + assert_eq!(orig, round_tripped) + } +} diff --git a/lightway-client/src/io/outside/tcp.rs b/lightway-client/src/io/outside/tcp.rs index 3ab8133e..204de44b 100644 --- a/lightway-client/src/io/outside/tcp.rs +++ b/lightway-client/src/io/outside/tcp.rs @@ -4,7 +4,7 @@ use std::{net::SocketAddr, sync::Arc}; use tokio::net::TcpStream; use super::OutsideIO; -use lightway_core::{IOCallbackResult, OutsideIOSendCallback, OutsideIOSendCallbackArg}; +use lightway_core::{CowBytes, IOCallbackResult, OutsideIOSendCallback, OutsideIOSendCallbackArg}; pub struct Tcp(tokio::net::TcpStream, SocketAddr); @@ -58,8 +58,8 @@ impl OutsideIO for Tcp { } impl OutsideIOSendCallback for Tcp { - fn send(&self, buf: &[u8]) -> IOCallbackResult { - match self.0.try_write(buf) { + fn send(&self, buf: CowBytes) -> IOCallbackResult { + match self.0.try_write(buf.as_bytes()) { Ok(nr) => IOCallbackResult::Ok(nr), Err(err) if matches!(err.kind(), std::io::ErrorKind::WouldBlock) => { IOCallbackResult::WouldBlock diff --git a/lightway-client/src/io/outside/udp.rs b/lightway-client/src/io/outside/udp.rs index 112059d9..69e34636 100644 --- a/lightway-client/src/io/outside/udp.rs +++ b/lightway-client/src/io/outside/udp.rs @@ -5,7 +5,7 @@ use tokio::net::UdpSocket; use super::OutsideIO; use lightway_app_utils::sockopt; -use lightway_core::{IOCallbackResult, OutsideIOSendCallback, OutsideIOSendCallbackArg}; +use lightway_core::{CowBytes, IOCallbackResult, OutsideIOSendCallback, OutsideIOSendCallbackArg}; pub struct Udp { sock: tokio::net::UdpSocket, @@ -67,8 +67,8 @@ impl OutsideIO for Udp { } impl OutsideIOSendCallback for Udp { - fn send(&self, buf: &[u8]) -> IOCallbackResult { - match self.sock.try_send_to(buf, self.peer_addr) { + fn send(&self, buf: CowBytes) -> IOCallbackResult { + match self.sock.try_send_to(buf.as_bytes(), self.peer_addr) { Ok(nr) => IOCallbackResult::Ok(nr), Err(err) if matches!(err.kind(), std::io::ErrorKind::WouldBlock) => { IOCallbackResult::WouldBlock diff --git a/lightway-core/src/connection/io_adapter.rs b/lightway-core/src/connection/io_adapter.rs index 7b2e22ca..03fc4b32 100644 --- a/lightway-core/src/connection/io_adapter.rs +++ b/lightway-core/src/connection/io_adapter.rs @@ -5,7 +5,8 @@ use more_asserts::*; use wolfssl::IOCallbackResult; use crate::{ - plugin::PluginList, wire, ConnectionType, OutsideIOSendCallbackArg, PluginResult, Version, + plugin::PluginList, wire, ConnectionType, CowBytes, OutsideIOSendCallbackArg, PluginResult, + Version, }; pub(crate) struct SendBuffer { @@ -164,26 +165,28 @@ impl WolfSSLIOAdapter { } } + let b = b.freeze(); + // Send header + buf. If we are in aggressive mode we send it // a total of three times. On any send error we return // immediately without the remaining tries, otherwise we // return the result of the final attempt. if self.aggressive_send { - match self.io.send(&b[..]) { + match self.io.send(CowBytes::Owned(b.clone())) { IOCallbackResult::Ok(_) => {} wb @ IOCallbackResult::WouldBlock => return wb, err @ IOCallbackResult::Err(_) => return err, } - match self.io.send(&b[..]) { + match self.io.send(CowBytes::Owned(b.clone())) { IOCallbackResult::Ok(_) => {} wb @ IOCallbackResult::WouldBlock => return wb, err @ IOCallbackResult::Err(_) => return err, } } - match self.io.send(&b[..]) { + match self.io.send(CowBytes::Owned(b)) { IOCallbackResult::Ok(n) => { // We've sent `n` bytes successfully out of // `wire::Header::WIRE_SIZE` + `b.len()` that we @@ -250,7 +253,7 @@ impl WolfSSLIOAdapter { debug_assert_le!(send_buffer.original_len(), buf.len()); } - match self.io.send(send_buffer.as_bytes()) { + match self.io.send(CowBytes::Borrowed(send_buffer.as_bytes())) { IOCallbackResult::Ok(n) if n == send_buffer.actual_len() => { // We've now sent everything we were originally // asked to, so signal completion of that original @@ -335,7 +338,8 @@ mod tests { } impl OutsideIOSendCallback for FakeOutsideIOSend { - fn send(&self, buf: &[u8]) -> IOCallbackResult { + fn send(&self, buf: CowBytes) -> IOCallbackResult { + let buf = buf.as_bytes(); let (fakes, sent) = &mut *self.0.lock().unwrap(); match fakes.pop_front() { Some(IOCallbackResult::Ok(n)) => { diff --git a/lightway-core/src/io.rs b/lightway-core/src/io.rs index 747a4162..42d5a335 100644 --- a/lightway-core/src/io.rs +++ b/lightway-core/src/io.rs @@ -1,6 +1,6 @@ use std::{net::SocketAddr, sync::Arc}; -use bytes::BytesMut; +use bytes::{Bytes, BytesMut}; use wolfssl::IOCallbackResult; /// Application provided callback used to send inside data. @@ -20,6 +20,33 @@ pub trait InsideIOSendCallback { /// Convenience type to use as function arguments pub type InsideIOSendCallbackArg = Arc + Send + Sync>; +/// A byte buffer to be sent, may be owned or borrowed. +pub enum CowBytes<'a> { + /// An owned buffer + Owned(Bytes), + /// A borrowed buffer + Borrowed(&'a [u8]), +} + +impl CowBytes<'_> { + /// Convert this buffer into an owned `Bytes`. Cheap if this + /// instance if `::Owned`, but copied if not. + pub fn into_owned(self) -> Bytes { + match self { + CowBytes::Owned(b) => b, + CowBytes::Borrowed(b) => Bytes::copy_from_slice(b), + } + } + + /// Gain access to the underlying byte buffer. + pub fn as_bytes(&self) -> &[u8] { + match self { + CowBytes::Owned(b) => b.as_ref(), + CowBytes::Borrowed(b) => b, + } + } +} + /// Application provided callback used to send outside data. pub trait OutsideIOSendCallback { /// Called when Lightway wishes to send some outside data @@ -30,7 +57,7 @@ pub trait OutsideIOSendCallback { /// [`IOCallbackResult::WouldBlock`]. /// /// This is the same method as [`wolfssl::IOCallbacks::send`]. - fn send(&self, buf: &[u8]) -> IOCallbackResult; + fn send(&self, buf: CowBytes) -> IOCallbackResult; /// Get the peer's [`SocketAddr`] fn peer_addr(&self) -> SocketAddr; diff --git a/lightway-core/src/lib.rs b/lightway-core/src/lib.rs index 3930ef20..d89dfa69 100644 --- a/lightway-core/src/lib.rs +++ b/lightway-core/src/lib.rs @@ -38,7 +38,8 @@ pub use context::{ ServerAuthArg, ServerAuthHandle, ServerAuthResult, ServerContext, ServerContextBuilder, }; pub use io::{ - InsideIOSendCallback, InsideIOSendCallbackArg, OutsideIOSendCallback, OutsideIOSendCallbackArg, + CowBytes, InsideIOSendCallback, InsideIOSendCallbackArg, OutsideIOSendCallback, + OutsideIOSendCallbackArg, }; pub use packet::OutsidePacket; pub use plugin::{ diff --git a/lightway-core/tests/connection.rs b/lightway-core/tests/connection.rs index 50ee400c..aca1ab4e 100644 --- a/lightway-core/tests/connection.rs +++ b/lightway-core/tests/connection.rs @@ -106,7 +106,8 @@ impl TestSock for TestDatagramSock { } impl OutsideIOSendCallback for TestDatagramSock { - fn send(&self, buf: &[u8]) -> IOCallbackResult { + fn send(&self, buf: CowBytes) -> IOCallbackResult { + let buf = buf.as_bytes(); match self.0.try_send(buf) { Ok(nr) => IOCallbackResult::Ok(nr), Err(err) if matches!(err.kind(), std::io::ErrorKind::WouldBlock) => { @@ -156,8 +157,8 @@ impl TestSock for TestStreamSock { } impl OutsideIOSendCallback for TestStreamSock { - fn send(&self, buf: &[u8]) -> IOCallbackResult { - match self.0.try_write(buf) { + fn send(&self, buf: CowBytes) -> IOCallbackResult { + match self.0.try_write(buf.as_bytes()) { Ok(nr) => IOCallbackResult::Ok(nr), Err(err) if matches!(err.kind(), std::io::ErrorKind::WouldBlock) => { IOCallbackResult::WouldBlock diff --git a/lightway-server/Cargo.toml b/lightway-server/Cargo.toml index 02274171..18756df8 100644 --- a/lightway-server/Cargo.toml +++ b/lightway-server/Cargo.toml @@ -9,9 +9,8 @@ license = "AGPL-3.0-only" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [features] -default = ["io-uring"] +default = [] debug = ["lightway-core/debug"] -io-uring = ["lightway-app-utils/io-uring"] [lints] workspace = true @@ -26,6 +25,7 @@ clap.workspace = true ctrlc.workspace = true delegate.workspace = true educe.workspace = true +io-uring.workspace = true ipnet.workspace = true jsonwebtoken = "9.3.0" libc.workspace = true @@ -48,6 +48,7 @@ tokio-stream = { workspace = true, features = ["time"] } tracing.workspace = true tracing-log = "0.2.0" tracing-subscriber = { workspace = true, features = ["json"] } +tun.workspace = true twelf.workspace = true [dev-dependencies] diff --git a/lightway-server/src/args.rs b/lightway-server/src/args.rs index f4b87ad7..23786993 100644 --- a/lightway-server/src/args.rs +++ b/lightway-server/src/args.rs @@ -71,13 +71,25 @@ pub struct Config { #[clap(long, default_value_t)] pub enable_pqc: bool, - /// Enable IO-uring interface for Tunnel - #[clap(long, default_value_t)] - pub enable_tun_iouring: bool, - - /// IO-uring submission queue count. Only applicable when - /// `enable_tun_iouring` is `true` - // Any value more than 1024 negatively impact the throughput + /// Total IO-uring submission queue count. + /// + /// Must be larger than the total of: + /// + /// UDP: + /// + /// iouring_tun_rx_count + iouring_udp_rx_count + + /// iouring_tx_count + 1 (cancellation request) + /// + /// TCP: + /// + /// iouring_tun_rx_count + iouring_tx_count + 1 (cancellation + /// request) + 2 * maximum number of connections. + /// + /// Each connection actually uses up to 3 slots, a persistent + /// recv request and on demand slots for TX and cancellation + /// (teardown). + /// + /// There is no downside to setting this much larger. #[clap(long, default_value_t = 1024)] pub iouring_entry_count: usize, @@ -87,6 +99,36 @@ pub struct Config { #[clap(long, default_value = "100ms")] pub iouring_sqpoll_idle_time: Duration, + /// Number of concurrent TUN device read requests to issue to + /// IO-uring. Setting this too large may negatively impact + /// performance. + #[clap(long, default_value_t = 64)] + pub iouring_tun_rx_count: u32, + + /// Configure TUN device in blocking mode. This can allow + /// equivalent performance with fewer `ìouring-tun-rx-count` + /// entries but can significantly harm performance on some kernels + /// where the kernel does not indicate that the tun device handles + /// `FMODE_NOWAIT`. + /// + /// If blocking mode is enabled then `iouring_tun_rx_count` may be + /// set much lower. + /// + /// This was fixed by + /// which was part of v6.4-rc1. + #[clap(long, default_value_t = false)] + pub iouring_tun_blocking: bool, + + /// Number of concurrent UDP socket recvmsg requests to issue to + /// IO-uring. + #[clap(long, default_value_t = 32)] + pub iouring_udp_rx_count: u32, + + /// Maximum number of concurrent UDP + TUN sendmsg/write requests + /// to issue to IO-uring. + #[clap(long, default_value_t = 512)] + pub iouring_tx_count: u32, + /// Log format #[clap(long, value_enum, default_value_t = LogFormat::Full)] pub log_format: LogFormat, @@ -111,6 +153,10 @@ pub struct Config { #[clap(long, default_value_t = ByteSize::mib(15))] pub udp_buffer_size: ByteSize, + /// Set UDP buffer size. Default value is 256 KiB. + #[clap(long, default_value_t = ByteSize::kib(256))] + pub tcp_buffer_size: ByteSize, + /// Enable WolfSSL debug logging #[cfg(feature = "debug")] #[clap(long)] diff --git a/lightway-server/src/io.rs b/lightway-server/src/io.rs index c32ee10e..36f6d332 100644 --- a/lightway-server/src/io.rs +++ b/lightway-server/src/io.rs @@ -1,2 +1,280 @@ pub(crate) mod inside; pub(crate) mod outside; + +mod ffi; +mod tx; + +use std::{ + os::fd::{AsRawFd, OwnedFd, RawFd}, + sync::{Arc, Mutex}, + time::Duration, +}; + +use anyhow::{anyhow, Context as _, Result}; +use io_uring::{ + cqueue::Entry as CEntry, + opcode, + squeue::Entry as SEntry, + types::{Fd, Fixed}, + Builder, IoUring, SubmissionQueue, Submitter, +}; + +use ffi::{iovec, msghdr}; +pub use tx::TxQueue; + +/// Convenience function to handle errors in a uring result codes +/// (which are negative errno codes). +fn io_uring_res(res: i32) -> std::io::Result { + if res < 0 { + Err(std::io::Error::from_raw_os_error(-res)) + } else { + Ok(res) + } +} + +/// An I/O source pushing requests to a uring instance +pub(crate) trait UringIoSource: Send { + /// Return the raw file descriptor. This will be registered as an + /// fd with the ring, allowing the use of io_uring::types::Fixed. + fn as_raw_fd(&self) -> RawFd; + + /// Push the initial set of requests to `sq`. + fn push_initial_ops(&mut self, sq: &mut io_uring::SubmissionQueue) -> Result<()>; + + /// Complete an rx request + fn complete_rx( + &mut self, + sq: &mut io_uring::SubmissionQueue, + cqe: io_uring::cqueue::Entry, + idx: u32, + ) -> Result<()>; + + /// Complete a tx request + fn complete_tx( + &mut self, + sq: &mut io_uring::SubmissionQueue, + cqe: io_uring::cqueue::Entry, + idx: u32, + ) -> Result<()>; +} + +pub(crate) enum OutsideIoSource { + Udp(outside::udp::UdpServer), + Tcp(outside::tcp::TcpServer), +} + +// Avoiding `dyn`amic dispatch is a small performance win. +impl UringIoSource for OutsideIoSource { + fn as_raw_fd(&self) -> RawFd { + match self { + OutsideIoSource::Udp(udp) => udp.as_raw_fd(), + OutsideIoSource::Tcp(tcp) => tcp.as_raw_fd(), + } + } + + fn push_initial_ops(&mut self, sq: &mut io_uring::SubmissionQueue) -> Result<()> { + match self { + OutsideIoSource::Udp(udp) => udp.push_initial_ops(sq), + OutsideIoSource::Tcp(tcp) => tcp.push_initial_ops(sq), + } + } + + fn complete_rx( + &mut self, + sq: &mut io_uring::SubmissionQueue, + cqe: io_uring::cqueue::Entry, + idx: u32, + ) -> Result<()> { + match self { + OutsideIoSource::Udp(udp) => udp.complete_rx(sq, cqe, idx), + OutsideIoSource::Tcp(tcp) => tcp.complete_rx(sq, cqe, idx), + } + } + + fn complete_tx( + &mut self, + sq: &mut io_uring::SubmissionQueue, + cqe: io_uring::cqueue::Entry, + idx: u32, + ) -> Result<()> { + match self { + OutsideIoSource::Udp(udp) => udp.complete_tx(sq, cqe, idx), + OutsideIoSource::Tcp(tcp) => tcp.complete_tx(sq, cqe, idx), + } + } +} + +pub(crate) struct Loop { + ring: IoUring, + + tx: Arc>, + + cancel_buf: u8, + + outside: OutsideIoSource, + inside: inside::tun::Tun, +} + +impl Loop { + /// Use for outside IO requests, `self.outside.as_raw_fd` will be registered in this slot. + const FIXED_OUTSIDE_FD: Fixed = Fixed(0); + /// Use for inside IO requests, `self.inside.as_raw_fd` will be registered in this slot. + const FIXED_INSIDE_FD: Fixed = Fixed(1); + + /// Masks the bits used by `*_USER_DATA_BASE` + const USER_DATA_TYPE_MASK: u64 = 0xe000_0000_0000_0000; + + /// Indexes in this range will result in a call to `self.outside.complete_rx` + const OUTSIDE_RX_USER_DATA_BASE: u64 = 0xc000_0000_0000_0000; + /// Indexes in this range will result in a call to `self.outside.complete_tx` + const OUTSIDE_TX_USER_DATA_BASE: u64 = 0x8000_0000_0000_0000; + + /// Indexes in this range will result in a call to `self.inside.complete_rx` + const INSIDE_RX_USER_DATA_BASE: u64 = 0x4000_0000_0000_0000; + /// Indexes in this range will result in a call to `self.inside.complete_tx` + const INSIDE_TX_USER_DATA_BASE: u64 = 0x2000_0000_0000_0000; + + /// Indexes in this range are used by `Loop` itself. + const CONTROL_USER_DATA_BASE: u64 = 0x0000_0000_0000_0000; + + /// A read request on the cancellation fd (used to exit the io loop) + const CANCEL_USER_DATA: u64 = Self::CONTROL_USER_DATA_BASE + 1; + + /// Return user data for a particular outside rx index. + fn outside_rx_user_data(idx: u32) -> u64 { + Self::OUTSIDE_RX_USER_DATA_BASE + (idx as u64) + } + + /// Return user data for a particular inside rx index. + fn inside_rx_user_data(idx: u32) -> u64 { + Self::INSIDE_RX_USER_DATA_BASE + (idx as u64) + } + + /// Return user data for a particular inside tx index. + fn inside_tx_user_data(idx: u32) -> u64 { + Self::INSIDE_TX_USER_DATA_BASE + (idx as u64) + } + + /// Return user data for a particular outside tx index. + fn outside_tx_user_data(idx: u32) -> u64 { + Self::OUTSIDE_TX_USER_DATA_BASE + (idx as u64) + } + + pub(crate) fn new( + ring_size: usize, + sqpoll_idle_time: Duration, + tx: Arc>, + outside: OutsideIoSource, + inside: inside::tun::Tun, + ) -> Result { + tracing::info!(ring_size, "creating IoUring"); + let mut builder: Builder = IoUring::builder(); + + builder.dontfork(); + + if sqpoll_idle_time.as_millis() > 0 { + let idle_time: u32 = sqpoll_idle_time + .as_millis() + .try_into() + .with_context(|| "invalid sqpoll idle time")?; + // This setting makes CPU go 100% when there is continuous traffic + builder.setup_sqpoll(idle_time); // Needs 5.13 + } + + let ring = builder + .build(ring_size as u32) + .inspect_err(|e| tracing::error!("iouring setup failed: {e}"))?; + + Ok(Self { + ring, + tx, + cancel_buf: 0, + outside, + inside, + }) + } + + pub(crate) fn run(mut self, cancel: OwnedFd) -> Result<()> { + let (submitter, mut sq, mut cq) = self.ring.split(); + + submitter.register_files(&[self.outside.as_raw_fd(), self.inside.as_raw_fd()])?; + + let sqe = opcode::Read::new( + Fd(cancel.as_raw_fd()), + &mut self.cancel_buf as *mut _, + std::mem::size_of_val(&self.cancel_buf) as _, + ) + .build() + .user_data(Self::CANCEL_USER_DATA); + + #[allow(unsafe_code)] + // SAFETY: The buffer is owned by `self.cancel_buf` and `self` is owned + unsafe { + sq.push(&sqe)? + }; + + self.outside.push_initial_ops(&mut sq)?; + self.inside.push_initial_ops(&mut sq)?; + sq.sync(); + + loop { + let _ = submitter.submit_and_wait(1)?; + + cq.sync(); + + for cqe in &mut cq { + let user_data = cqe.user_data(); + + match user_data & Self::USER_DATA_TYPE_MASK { + Self::CONTROL_USER_DATA_BASE => { + match user_data - Self::CONTROL_USER_DATA_BASE { + Self::CANCEL_USER_DATA => { + let res = cqe.result(); + tracing::debug!(?res, "Uring cancelled"); + return Ok(()); + } + idx => { + return Err(anyhow!( + "Unknown control data {user_data:016x} => {idx:016x}" + )) + } + } + } + Self::OUTSIDE_RX_USER_DATA_BASE => { + self.outside.complete_rx( + &mut sq, + cqe, + (user_data - Self::OUTSIDE_RX_USER_DATA_BASE) as u32, + )?; + } + Self::OUTSIDE_TX_USER_DATA_BASE => { + self.outside.complete_tx( + &mut sq, + cqe, + (user_data - Self::OUTSIDE_TX_USER_DATA_BASE) as u32, + )?; + } + + Self::INSIDE_RX_USER_DATA_BASE => { + self.inside.complete_rx( + &mut sq, + cqe, + (user_data - Self::INSIDE_RX_USER_DATA_BASE) as u32, + )?; + } + Self::INSIDE_TX_USER_DATA_BASE => { + self.inside.complete_tx( + &mut sq, + cqe, + (user_data - Self::INSIDE_TX_USER_DATA_BASE) as u32, + )?; + } + + _ => unreachable!(), + } + + self.tx.lock().unwrap().drain(&submitter, &mut sq)?; + } + } + } +} diff --git a/lightway-server/src/io/ffi.rs b/lightway-server/src/io/ffi.rs new file mode 100644 index 00000000..ddd0ad42 --- /dev/null +++ b/lightway-server/src/io/ffi.rs @@ -0,0 +1,53 @@ +#![allow(unsafe_code)] +#![allow(non_camel_case_types, reason = "Using POSIX/libc naming")] + +/// Marker for types which are usable with syscalls +/// +/// # Safety +/// +/// Implement only for types containing raw pointers which are +/// passed to syscalls where the concept of Sync/Send is orthogonal to +/// Rust's model. +pub(super) unsafe trait IsSyscallSafe {} + +// SAFETY: iovec is used with syscalls +unsafe impl IsSyscallSafe for libc::iovec {} +// SAFETY: msghdr is used with syscalls +unsafe impl IsSyscallSafe for libc::msghdr {} + +pub(super) struct SyscallSafe(T); + +impl SyscallSafe { + pub fn new(t: T) -> Self { + Self(t) + } + + pub fn as_mut_ptr(&mut self) -> *mut T { + &mut self.0 as *mut T + } +} + +impl std::ops::Deref for SyscallSafe { + type Target = T; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl std::ops::DerefMut for SyscallSafe { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + +// SAFETY: T must be e.g. a libc type which contains raw pointers for syscall use. +// The `pub` aliases below all satisfy this. +unsafe impl Send for SyscallSafe {} + +// SAFETY: T must be e.g. a libc type which contains raw pointers for syscall use. +// The `pub` aliases below all satisfy this. +unsafe impl Sync for SyscallSafe {} + +pub type iovec = SyscallSafe; +pub type msghdr = SyscallSafe; diff --git a/lightway-server/src/io/inside.rs b/lightway-server/src/io/inside.rs index decf5c57..430cbcab 100644 --- a/lightway-server/src/io/inside.rs +++ b/lightway-server/src/io/inside.rs @@ -2,14 +2,4 @@ pub(crate) mod tun; pub(crate) use tun::Tun; -use crate::connection::ConnectionState; -use async_trait::async_trait; -use lightway_core::{IOCallbackResult, InsideIOSendCallbackArg}; -use std::sync::Arc; - -#[async_trait] -pub(crate) trait InsideIO: Sync + Send { - async fn recv_buf(&self) -> IOCallbackResult; - - fn into_io_send_callback(self: Arc) -> InsideIOSendCallbackArg; -} +use super::{io_uring_res, Loop, TxQueue, UringIoSource}; diff --git a/lightway-server/src/io/inside/tun.rs b/lightway-server/src/io/inside/tun.rs index 909b8355..ef96a7da 100644 --- a/lightway-server/src/io/inside/tun.rs +++ b/lightway-server/src/io/inside/tun.rs @@ -1,56 +1,213 @@ -use crate::{io::inside::InsideIO, metrics}; +//! Tun UringIoSource +//! +//! Uses uring indexes: +//! +//! Loop::inside_rx_user_data: +//! - 0..Tun::rx.len(): A set of recv requests +//! +//! Loop::inside_tx_user_data: +//! - Managed by TxQueue + +use crate::ip_manager::IpManager; +use crate::metrics; + +use super::{io_uring_res, Loop, TxQueue, UringIoSource}; use crate::connection::ConnectionState; -use anyhow::Result; -use async_trait::async_trait; + +use anyhow::{Context as _, Result}; use bytes::BytesMut; -use lightway_app_utils::{Tun as AppUtilsTun, TunConfig}; +use io_uring::opcode; use lightway_core::{ - ipv4_update_source, IOCallbackResult, InsideIOSendCallback, InsideIOSendCallbackArg, + ipv4_update_destination, ipv4_update_source, ConnectionError, IOCallbackResult, + InsideIOSendCallback, InsideIOSendCallbackArg, }; -use std::os::fd::{AsRawFd, RawFd}; -use std::sync::Arc; -use std::time::Duration; +use pnet::packet::ipv4::Ipv4Packet; +use std::net::Ipv4Addr; +use std::os::fd::{AsRawFd as _, RawFd}; +use std::sync::{Arc, Mutex}; +use tun::{AbstractDevice as _, Configuration as TunConfig, Device as TunDevice}; + +pub(crate) struct Tun { + tun: TunDevice, + lightway_client_ip: Ipv4Addr, + ip_manager: Arc, -pub(crate) struct Tun(AppUtilsTun); + tx_queue: Arc>, + + mtu: usize, + + rx: Vec, +} impl Tun { - pub async fn new(tun: TunConfig, iouring: Option<(usize, Duration)>) -> Result { - let tun = match iouring { - Some((ring_size, sqpoll_idle_time)) => { - AppUtilsTun::iouring(tun, ring_size, sqpoll_idle_time).await? - } - None => AppUtilsTun::direct(tun).await?, - }; - Ok(Tun(tun)) + pub fn new( + nr_slots: u32, + blocking: bool, + mut tun: TunConfig, + lightway_client_ip: Ipv4Addr, + ip_manager: Arc, + tx_queue: Arc>, + ) -> Result { + tracing::info!("Tun with {nr_slots} slots (blocking: {blocking})"); + + tun.platform_config(|cfg| { + cfg.napi(true); + }); + + let tun = tun::create(&tun)?; + if !blocking { + tun.set_nonblock()?; + } + + let mtu = tun.mtu()? as usize; + + let rx = (0..nr_slots).map(|_| BytesMut::new()).collect(); + + Ok(Tun { + tun, + lightway_client_ip, + ip_manager, + tx_queue, + mtu, + rx, + }) + } + + pub fn inside_io_sender(&self) -> InsideIOSendCallbackArg { + Arc::new(TunInsideIO::new(self.tx_queue.clone(), self)) + } + + fn push_rx(&mut self, sq: &mut io_uring::SubmissionQueue, idx: u32) -> Result<()> { + let buf = &mut self.rx[idx as usize]; + + // Recover full capacity + buf.clear(); + buf.reserve(self.mtu); + + let sqe = opcode::Read::new( + Loop::FIXED_INSIDE_FD, + buf.as_mut_ptr() as *mut _, + buf.capacity() as _, + ) + .build() + .user_data(Loop::inside_rx_user_data(idx)); + + #[allow(unsafe_code)] + // SAFETY: The buffer is owned by `self.rx` and `self` is owned by the `io::Loop` + unsafe { + sq.push(&sqe)?; + } + + sq.sync(); + + Ok(()) } } -impl AsRawFd for Tun { +impl UringIoSource for Tun { fn as_raw_fd(&self) -> RawFd { - self.0.as_raw_fd() + self.tun.as_raw_fd() } -} -#[async_trait] -impl InsideIO for Tun { - async fn recv_buf(&self) -> IOCallbackResult { - match self.0.recv_buf().await { - IOCallbackResult::Ok(buf) => { - metrics::tun_to_client(buf.len()); - IOCallbackResult::Ok(buf) + fn push_initial_ops(&mut self, sq: &mut io_uring::SubmissionQueue) -> Result<()> { + for idx in 0..self.rx.len() as u32 { + self.push_rx(sq, idx)? + } + Ok(()) + } + + fn complete_rx( + &mut self, + sq: &mut io_uring::SubmissionQueue, + cqe: io_uring::cqueue::Entry, + idx: u32, + ) -> Result<()> { + let res = match io_uring_res(cqe.result()) { + Ok(res) => res, + Err(err) if matches!(err.kind(), std::io::ErrorKind::WouldBlock) => { + self.push_rx(sq, idx)?; + return Ok(()); } - e => e, + Err(err) => return Err(err).with_context(|| "inside read completion"), + }; + + let buf = &mut self.rx[idx as usize]; + + metrics::tun_to_client(res as usize); + + #[allow(unsafe_code)] + // SAFETY: We rely on recv_from giving us the correct size + unsafe { + buf.set_len(res as usize); } + + // Find connection based on client ip (dest ip) and forward packet + let packet = Ipv4Packet::new(buf.as_ref()); + let Some(packet) = packet else { + eprintln!("Invalid inside packet size (less than Ipv4 header)!"); + // Queue another recv + self.push_rx(sq, idx)?; + return Ok(()); + }; + let conn = self.ip_manager.find_connection(packet.get_destination()); + + // Update destination IP address to client's ip + ipv4_update_destination(buf.as_mut(), self.lightway_client_ip); + + if let Some(conn) = conn { + match conn.inside_data_received(buf) { + Ok(()) => {} + Err(ConnectionError::InvalidState) => { + // Skip forwarding packet when offline + metrics::tun_rejected_packet_invalid_state(); + } + Err(ConnectionError::InvalidInsidePacket(_)) => { + // Skip processing invalid packet + metrics::tun_rejected_packet_invalid_inside_packet(); + } + Err(err) => { + let fatal = err.is_fatal(conn.connection_type()); + metrics::tun_rejected_packet_invalid_other(fatal); + if fatal { + conn.handle_end_of_stream(); + return Ok(()); + } + } + } + } else { + metrics::tun_rejected_packet_no_connection(); + }; + + // Queue another recv + self.push_rx(sq, idx)?; + + Ok(()) } - fn into_io_send_callback(self: Arc) -> InsideIOSendCallbackArg { - self + fn complete_tx( + &mut self, + _sq: &mut io_uring::SubmissionQueue, + cqe: io_uring::cqueue::Entry, + idx: u32, + ) -> Result<()> { + let _ = self.tx_queue.lock().unwrap().complete(cqe, idx); + Ok(()) } } -impl InsideIOSendCallback for Tun { +pub(crate) struct TunInsideIO(Arc>, usize); + +impl TunInsideIO { + pub(crate) fn new(queue: Arc>, tun: &Tun) -> Self { + Self(queue, tun.mtu) + } +} + +impl InsideIOSendCallback for TunInsideIO { fn send(&self, mut buf: BytesMut, state: &mut ConnectionState) -> IOCallbackResult { + let len = buf.len(); + let Some(client_ip) = state.internal_ip else { metrics::tun_rejected_packet_no_client_ip(); // Ip address not found, dropping the packet @@ -58,11 +215,37 @@ impl InsideIOSendCallback for Tun { }; ipv4_update_source(buf.as_mut(), client_ip); - metrics::tun_from_client(buf.len()); - self.0.try_send(buf) + metrics::tun_from_client(len); + + let buf = buf.freeze(); + + let mut tx_queue = self.0.lock().unwrap(); + + let Some((slot, state)) = tx_queue.take_slot() else { + return IOCallbackResult::WouldBlock; + }; + + let sqe = opcode::Write::new( + Loop::FIXED_INSIDE_FD, + buf.as_ptr() as *mut _, + buf.len() as _, + ) + .build(); + + state.buf = Some(buf); + + #[allow(unsafe_code)] + // SAFETY: + // - slot was optained from take_slot above + // - The buffer is owned by `state` and which is owned by the `TxRing` + unsafe { + tx_queue.push_inside_slot(slot, sqe) + }; + + IOCallbackResult::Ok(len) } fn mtu(&self) -> usize { - self.0.mtu() + self.1 } } diff --git a/lightway-server/src/io/outside.rs b/lightway-server/src/io/outside.rs index e233a80f..f4168d82 100644 --- a/lightway-server/src/io/outside.rs +++ b/lightway-server/src/io/outside.rs @@ -4,10 +4,4 @@ pub(crate) mod udp; pub(crate) use tcp::TcpServer; pub(crate) use udp::UdpServer; -use anyhow::Result; -use async_trait::async_trait; - -#[async_trait] -pub(crate) trait Server { - async fn run(&mut self) -> Result<()>; -} +use super::{io_uring_res, iovec, msghdr, Loop, TxQueue, UringIoSource}; diff --git a/lightway-server/src/io/outside/tcp.rs b/lightway-server/src/io/outside/tcp.rs index 9d6fdb5f..0994fea8 100644 --- a/lightway-server/src/io/outside/tcp.rs +++ b/lightway-server/src/io/outside/tcp.rs @@ -1,231 +1,596 @@ -use std::{net::SocketAddr, sync::Arc}; +//! TcpServer UringIoSource +//! +//! Uses uring indexes: +//! +//! Loop::outside_rx_user_data: +//! - TcpServer::ACCEPT_IDX: +//! The accept request. +//! - The fd for a connection (positive i32): +//! The RX request for that connection. +//! - The fd for a connection (positive i32) + TcpServer::RX_CANCEL_IDX_BIT: +//! A cancellation request for that connection +//! +//! Loop::outside_tx_user_data: +//! - The fd for a connection (positive i32): +//! The TX request for that connection. -use anyhow::{anyhow, Result}; -use async_trait::async_trait; +use std::{ + collections::HashMap, + net::{SocketAddr, TcpStream}, + os::fd::{AsRawFd, FromRawFd as _, RawFd}, + sync::{Arc, Mutex}, +}; + +use anyhow::{anyhow, Context as _, Result}; use bytes::BytesMut; +use bytesize::ByteSize; +use io_uring::{ + opcode, + types::{CancelBuilder, Fd}, +}; +use lightway_app_utils::socket_addr_from_sockaddr; use lightway_core::{ - ConnectionType, IOCallbackResult, OutsideIOSendCallback, OutsidePacket, Version, - MAX_OUTSIDE_MTU, + ConnectionType, CowBytes, IOCallbackResult, OutsideIOSendCallback, OutsidePacket, Version, }; -use socket2::SockRef; -use tokio::io::AsyncReadExt as _; -use tracing::{debug, info, instrument, warn}; +use tracing::{debug, info, warn}; -use crate::{connection_manager::ConnectionManager, metrics}; +use crate::{connection::Connection, connection_manager::ConnectionManager, metrics}; -use super::Server; +use super::{io_uring_res, Loop, TxQueue, UringIoSource}; -struct TcpStream { - sock: Arc, - peer_addr: SocketAddr, +enum ConnectionPhase { + ProxyInitial { + local_addr: SocketAddr, + }, + Proxy { + local_addr: SocketAddr, + rest: usize, + }, + Connected { + conn: Arc, + buffer: Arc>, + }, } -impl OutsideIOSendCallback for TcpStream { - fn send(&self, buf: &[u8]) -> IOCallbackResult { - match self.sock.try_write(buf) { - Ok(nr) => IOCallbackResult::Ok(nr), - Err(err) if matches!(err.kind(), std::io::ErrorKind::WouldBlock) => { - IOCallbackResult::WouldBlock +struct ConnectionState { + sock: TcpStream, + rx_buf: BytesMut, + tx_buffer_size: usize, + phase: ConnectionPhase, +} + +impl ConnectionState { + const RX_BUFFER_SIZE: usize = 15 * 1024 * 1024; // 15M + + fn push_rx(&mut self, sq: &mut io_uring::SubmissionQueue) -> Result<()> { + use ConnectionPhase::*; + let (buf, len) = match &mut self.phase { + ProxyInitial { .. } => (self.rx_buf.as_mut_ptr(), 16), + Proxy { rest, .. } => (self.rx_buf[16..].as_mut_ptr(), *rest), + Connected { .. } => { + // Recover full capacity + self.rx_buf.clear(); + self.rx_buf.reserve(Self::RX_BUFFER_SIZE); + (self.rx_buf.as_mut_ptr(), self.rx_buf.capacity()) } - Err(err) => IOCallbackResult::Err(err), - } - } + }; + let fd = self.sock.as_raw_fd(); - fn peer_addr(&self) -> SocketAddr { - self.peer_addr + let sqe = opcode::Recv::new(Fd(fd), buf, len as _) + .build() + .user_data(Loop::outside_rx_user_data(fd as u32)); + + #[allow(unsafe_code)] + // SAFETY: The buffer is owned by `self` and `self` is owned by `TcpServer::fd_map` + unsafe { + sq.push(&sqe)? + }; + + sq.sync(); + + Ok(()) } -} -async fn handle_proxy_protocol(sock: &mut tokio::net::TcpStream) -> Result { - use ppp::v2::{Header, ParseError}; + fn push_cancel(&mut self, sq: &mut io_uring::SubmissionQueue) -> Result<()> { + let fd = self.sock.as_raw_fd(); + info!(fd, "Cancelling"); + let builder = CancelBuilder::fd(Fd(fd)).all(); + let sqe = opcode::AsyncCancel2::new(builder) + .build() + .user_data(Loop::outside_rx_user_data( + fd as u32 + TcpServer::RX_CANCEL_IDX_BIT, + )); - // https://www.haproxy.org/download/1.8/doc/proxy-protocol.txt §2.2 - const MINIMUM_LENGTH: usize = 16; + #[allow(unsafe_code)] + // SAFETY: The cancel sqe is well formed above + unsafe { + sq.push(&sqe)? + }; - let mut header: Vec = [0; MINIMUM_LENGTH].into(); - if let Err(err) = sock.read_exact(&mut header[..MINIMUM_LENGTH]).await { - return Err(anyhow!(err).context("Failed to read initial PROXY header")); - }; - let rest = match Header::try_from(&header[..]) { - // Failure tells us exactly how many more bytes are required. - Err(ParseError::Partial(_, rest)) => rest, + sq.sync(); - Ok(_) => { - // The initial 16 bytes is never enough to actually succeed. - return Err(anyhow!("Unexpectedly parsed initial PROXY header")); - } - Err(err) => { - return Err(anyhow!(err).context("Failed to parse initial PROXY header")); + Ok(()) + } + + fn complete_tx( + &mut self, + sq: &mut io_uring::SubmissionQueue, + cqe: io_uring::cqueue::Entry, + ) -> Result<()> { + match &mut self.phase { + // Nothing to do for either of these cases. + ConnectionPhase::ProxyInitial { .. } | ConnectionPhase::Proxy { .. } => Ok(()), + ConnectionPhase::Connected { + conn: _conn, + buffer, + } => buffer.lock().unwrap().complete_tx(sq, cqe), } - }; + } - header.resize(MINIMUM_LENGTH + rest, 0); + fn complete_rx( + &mut self, + sq: &mut io_uring::SubmissionQueue, + cqe: io_uring::cqueue::Entry, + tx_queue: &Arc>, + conn_manager: &Arc, + ) -> Result<()> { + use ppp::v2::{Header, ParseError}; + use ConnectionPhase::*; - if let Err(err) = sock.read_exact(&mut header[MINIMUM_LENGTH..]).await { - return Err(anyhow!(err).context("Failed to read remainder of PROXY header")); - }; + let res = io_uring_res(cqe.result()).with_context(|| "outside recv completion")?; - let header = match Header::try_from(&header[..]) { - Ok(h) => h, - Err(err) => { - return Err(anyhow!(err).context("Failed to parse complete PROXY header")); - } - }; + match &mut self.phase { + ProxyInitial { local_addr } => { + assert!(16 == res); + #[allow(unsafe_code)] + // SAFETY: We rely on recv_from giving us the correct size + unsafe { + self.rx_buf.set_len(res as usize); + } - let addr = match header.addresses { - ppp::v2::Addresses::Unspecified => { - return Err(anyhow!("Unspecified PROXY connection")); - } - ppp::v2::Addresses::IPv4(addr) => { - SocketAddr::new(addr.source_address.into(), addr.source_port) - } - ppp::v2::Addresses::IPv6(_) => { - return Err(anyhow!("IPv6 PROXY connection")); - } - ppp::v2::Addresses::Unix(_) => { - return Err(anyhow!("Unix PROXY connection")); - } - }; - Ok(addr) -} + let rest = match Header::try_from(&self.rx_buf[..]) { + // Failure tells us exactly how many more bytes are required. + Err(ParseError::Partial(_, rest)) => rest, -#[instrument(level = "trace", skip_all)] -async fn handle_connection( - mut sock: tokio::net::TcpStream, - mut peer_addr: SocketAddr, - local_addr: SocketAddr, - conn_manager: Arc, - proxy_protocol: bool, -) { - if proxy_protocol { - peer_addr = match handle_proxy_protocol(&mut sock).await { - Ok(real_addr) => real_addr, - Err(err) => { - debug!(?err, "Failed to process PROXY header"); - metrics::connection_accept_proxy_header_failed(); - return; + Ok(_) => { + // The initial 16 bytes is never enough to actually succeed. + return Err(anyhow!("Unexpectedly parsed initial PROXY header")); + } + Err(err) => { + return Err(anyhow!(err).context("Failed to parse initial PROXY header")); + } + }; + + self.phase = Proxy { + local_addr: *local_addr, + rest, + } } - }; - } + Proxy { local_addr, rest } => { + assert!(*rest == res as usize); + #[allow(unsafe_code)] + // SAFETY: We rely on recv_from giving us the correct size + // We read 16 bytes in state ProxyInitial + unsafe { + self.rx_buf.set_len((res + 16) as usize); + } + let header = match Header::try_from(&self.rx_buf[..]) { + Ok(h) => h, + Err(err) => { + return Err(anyhow!(err).context("Failed to parse complete PROXY header")); + } + }; - let sock = Arc::new(sock); - - let outside_io = Arc::new(TcpStream { - sock: sock.clone(), - peer_addr, - }); - // TCP has no version indication, default to the minimum - // supported version. - let Ok(conn) = - conn_manager.create_streaming_connection(Version::MINIMUM, local_addr, outside_io) - else { - return; - }; - - // We no longer need to hold this reference. - drop(conn_manager); - - let mut buf = BytesMut::with_capacity(MAX_OUTSIDE_MTU); - let err: anyhow::Error = loop { - // Recover full capacity - buf.clear(); - buf.reserve(MAX_OUTSIDE_MTU); - if let Err(e) = sock.readable().await { - break anyhow!(e).context("Sock readable error"); - } + let peer_addr = match header.addresses { + ppp::v2::Addresses::Unspecified => { + return Err(anyhow!("Unspecified PROXY connection")); + } + ppp::v2::Addresses::IPv4(addr) => { + SocketAddr::new(addr.source_address.into(), addr.source_port) + } + ppp::v2::Addresses::IPv6(_) => { + return Err(anyhow!("IPv6 PROXY connection")); + } + ppp::v2::Addresses::Unix(_) => { + return Err(anyhow!("Unix PROXY connection")); + } + }; - match sock.try_read_buf(&mut buf) { - Ok(0) => { - // EOF - break anyhow!("End of stream"); + let buffer = + TcpSocketBuffer::new(tx_queue.clone(), self.tx_buffer_size, &self.sock); + let outside_io = Arc::new(TcpSocket { + buffer: buffer.clone(), + peer_addr, + }); + let conn = conn_manager.create_streaming_connection( + Version::MINIMUM, + *local_addr, + outside_io, + )?; + self.phase = ConnectionPhase::Connected { conn, buffer } } - Ok(_nr) => {} - Err(err) if matches!(err.kind(), std::io::ErrorKind::WouldBlock) => { - // Spuriously failed to read, keep waiting - continue; + Connected { + conn, + buffer: _buffer, + } => { + if res == 0 { + // EOF + conn.handle_end_of_stream(); + return Err(anyhow!("End of stream")); + } + + #[allow(unsafe_code)] + // SAFETY: We rely on recv_from giving us the correct size + unsafe { + self.rx_buf.set_len(res as usize); + } + let pkt = OutsidePacket::Wire(&mut self.rx_buf, ConnectionType::Stream); + if let Err(err) = conn.outside_data_received(pkt) { + warn!("Failed to process outside data: {err}"); + if conn.handle_outside_data_error(&err).is_break() { + return Err(anyhow!(err).context("Outside data fatal error")); + } + } } - Err(err) => break anyhow!(err).context("TCP read error"), }; + self.push_rx(sq)?; + Ok(()) + } +} + +pub(in super::super) struct TcpSocketBuffer { + tx_queue: Arc>, + fd: Fd, + // We double buffer the tx. + tx_in_flight: BytesMut, + tx_buffer: BytesMut, + tx_buffer_size: usize, +} + +impl TcpSocketBuffer { + fn new( + tx_queue: Arc>, + tx_buffer_size: usize, + sock: &impl AsRawFd, + ) -> Arc> { + Arc::new(Mutex::new(TcpSocketBuffer { + tx_queue, + fd: Fd(sock.as_raw_fd()), + tx_in_flight: BytesMut::new(), + tx_buffer: BytesMut::new(), + tx_buffer_size, + })) + } + + fn push_tx(&mut self) { + let mut tx_queue = self.tx_queue.lock().unwrap(); + let len = self.tx_in_flight.len(); + + let sqe = opcode::Send::new(self.fd, self.tx_in_flight.as_ptr() as *const _, len as _) + .flags(libc::MSG_WAITALL) + .build() + .user_data(Loop::outside_tx_user_data(self.fd.0 as u32)); + + #[allow(unsafe_code)] + // SAFETY: + // - The buffer is owned by `self` and which is owned by the connection and ultimately by `TcpServer::fd_map` + unsafe { + tx_queue.push(sqe) + }; + } - let pkt = OutsidePacket::Wire(&mut buf, ConnectionType::Stream); - if let Err(err) = conn.outside_data_received(pkt) { - warn!("Failed to process outside data: {err}"); - if conn.handle_outside_data_error(&err).is_break() { - break anyhow!(err).context("Outside data fatal error"); + fn send(&mut self, buf: CowBytes) -> IOCallbackResult { + let bytes = buf.as_bytes(); + + if !self.tx_in_flight.is_empty() { + // tx_buffer_size is not a strict limit, but once we have + // exceeded it we stop adding more. + if self.tx_buffer.len() > self.tx_buffer_size { + return IOCallbackResult::WouldBlock; } + + self.tx_buffer.extend_from_slice(bytes); + return IOCallbackResult::Ok(bytes.len()); } - }; - conn.handle_end_of_stream(); + self.tx_in_flight.extend_from_slice(bytes); + self.push_tx(); + + IOCallbackResult::Ok(bytes.len()) + } + + pub fn complete_tx( + &mut self, + _sq: &mut io_uring::SubmissionQueue, + cqe: io_uring::cqueue::Entry, + ) -> Result<()> { + let res = io_uring_res(cqe.result()).with_context(|| "outside send completion")? as usize; + + // We use MSG_WAITALL so this should not happen + assert!(res == self.tx_in_flight.len(), "Unexpected short send"); + + self.tx_in_flight.clear(); + + std::mem::swap(&mut self.tx_buffer, &mut self.tx_in_flight); + + if !self.tx_in_flight.is_empty() { + self.push_tx(); + } + + Ok(()) + } +} + +struct TcpSocket { + buffer: Arc>, + peer_addr: SocketAddr, +} - info!("Connection closed: {:?}", err); +impl OutsideIOSendCallback for TcpSocket { + fn send(&self, buf: CowBytes) -> IOCallbackResult { + self.buffer.lock().unwrap().send(buf) + } + + fn peer_addr(&self) -> SocketAddr { + self.peer_addr + } } pub(crate) struct TcpServer { conn_manager: Arc, - sock: Arc, + sock: Arc, + tx_queue: Arc>, + tx_buffer_size: usize, proxy_protocol: bool, + + // Buffers passed to opcode::Accept + accept_addr: Box<(libc::sockaddr_storage, libc::socklen_t)>, + // Map from accepted fds to connections + fd_map: HashMap, } impl TcpServer { + // idx reserved for the accept request. Cannot clash with indexes + // for connections since those are fd numbers which are positive + // i32 values. + const ACCEPT_IDX: u32 = 0x8000_0000; + + // Signals a cancelation request for a connection when added to + // the idx for a rx request (which is an fd number). Since fd is + // never 0 (that is stdin) cannot clash with ACCEPT_IDX. + // + // We must cancel any in flight requests before destroying the + // connection state since they may be reading from owned data or, + // worse, writing to it! + const RX_CANCEL_IDX_BIT: u32 = 0x8000_0000; + pub(crate) async fn new( conn_manager: Arc, + tx_queue: Arc>, bind_address: SocketAddr, proxy_protocol: bool, + tcp_buffer_size: ByteSize, ) -> Result { - let sock = Arc::new(tokio::net::TcpListener::bind(bind_address).await?); + eprintln!("Binding to {bind_address}"); + let sock = tokio::net::TcpListener::bind(bind_address).await?; + eprintln!("Bound to {bind_address}"); + + let sock = sock.into_std()?; + sock.set_nonblocking(false)?; + let sock = Arc::new(sock); + + let tx_buffer_size = tcp_buffer_size.as_u64().try_into()?; Ok(Self { conn_manager, sock, + tx_queue, + tx_buffer_size, proxy_protocol, + + #[allow(unsafe_code)] + // SAFETY: All zeroes is a valid sockaddr_storage + accept_addr: Box::new((unsafe { std::mem::zeroed() }, 0)), + + fd_map: Default::default(), }) } -} -#[async_trait] -impl Server for TcpServer { - async fn run(&mut self) -> Result<()> { + fn push_accept(&mut self, sq: &mut io_uring::SubmissionQueue) -> Result<()> { info!("Accepting traffic on {}", self.sock.local_addr()?); - loop { - let (sock, peer_addr) = match self.sock.accept().await { - Ok(r) => r, - Err(err) => { - // Some of the errors which accept(2) can return - // - // while never a good thing needn't necessarily be - // fatal to the entire server and prevent us from - // servicing existing connections or potentially - // new connections in the future. - warn!(?err, "Failed to accept a new connection"); - metrics::connection_accept_failed(); - continue; - } - }; - - sock.set_nodelay(true)?; - let local_addr = match SockRef::from(&sock).local_addr() { - Ok(local_addr) => local_addr, - Err(err) => { - // Since we have a bound socket this shouldn't happen. - debug!(?err, "Failed to get local addr"); - return Err(err.into()); - } - }; - let Some(local_addr) = local_addr.as_socket() else { - // Since we only bind to IP sockets this shouldn't happen. - debug!("Failed to convert local addr to socketaddr"); - return Err(anyhow!("Failed to convert local addr to socketaddr")); - }; - - tokio::spawn(handle_connection( - sock, + let (addr, len) = &mut *self.accept_addr; + *len = std::mem::size_of_val(addr) as _; + + let sqe = opcode::Accept::new( + Loop::FIXED_OUTSIDE_FD, + addr as *mut libc::sockaddr_storage as *mut _, + len as *mut libc::socklen_t as *mut _, + ) + .build() + .user_data(Loop::outside_rx_user_data(Self::ACCEPT_IDX)); + + #[allow(unsafe_code)] + // SAFETY: The address buffers are owned by `self` and`` self` is owned by the `io::Loop` + unsafe { + sq.push(&sqe)? + }; + + sq.sync(); + + Ok(()) + } + + fn complete_accept( + &mut self, + sq: &mut io_uring::SubmissionQueue, + cqe: io_uring::cqueue::Entry, + ) -> Result<()> { + let res = io_uring_res(cqe.result()).with_context(|| "outside accept")?; + // Should be impossible since as a twos complement i32 it would be negative. + assert!(res as u32 != Self::ACCEPT_IDX); + + let peer_addr = socket_addr_from_sockaddr(&self.accept_addr.0, self.accept_addr.1)?; + + #[allow(unsafe_code)] + // SAFETY: We trust that on success `accept(2)` returns a + // valid socket fd. + let sock = unsafe { TcpStream::from_raw_fd(res) }; + sock.set_nodelay(true)?; + + let local_addr = match sock.local_addr() { + Ok(local_addr) => local_addr, + Err(err) => { + // Since we have a bound socket this shouldn't happen. + debug!(?err, "Failed to get local addr"); + return Err(err.into()); + } + }; + + let rx_buf = BytesMut::with_capacity(ConnectionState::RX_BUFFER_SIZE); + + let phase = if self.proxy_protocol { + ConnectionPhase::ProxyInitial { local_addr } + } else { + let buffer = TcpSocketBuffer::new(self.tx_queue.clone(), self.tx_buffer_size, &sock); + let outside_io = Arc::new(TcpSocket { + buffer: buffer.clone(), peer_addr, + }); + let conn = self.conn_manager.create_streaming_connection( + Version::MINIMUM, local_addr, - self.conn_manager.clone(), - self.proxy_protocol, - )); + outside_io, + )?; + ConnectionPhase::Connected { conn, buffer } + }; + + let mut state = ConnectionState { + sock, + rx_buf, + phase, + tx_buffer_size: self.tx_buffer_size, + }; + + // Before we add to the hash, due to insert taking ownership + // of state, but we cannot complete anything until we return + // so that's ok. + state.push_rx(sq)?; + + self.fd_map.insert(res as u32, state); + + Ok(()) + } +} + +impl UringIoSource for TcpServer { + fn as_raw_fd(&self) -> RawFd { + self.sock.as_raw_fd() + } + + fn push_initial_ops(&mut self, sq: &mut io_uring::SubmissionQueue) -> Result<()> { + self.push_accept(sq) + } + + fn complete_rx( + &mut self, + sq: &mut io_uring::SubmissionQueue, + cqe: io_uring::cqueue::Entry, + idx: u32, + ) -> Result<()> { + if idx == Self::ACCEPT_IDX { + if let Err(err) = self.complete_accept(sq, cqe) { + // Some of the errors which accept(2) can return + // + // while never a good thing needn't necessarily be + // fatal to the entire server and prevent us from + // servicing existing connections or potentially + // new connections in the future. + warn!(?err, "Failed to accept a new connection"); + metrics::connection_accept_failed(); + } + self.push_accept(sq)?; + return Ok(()); + } + + let (idx, cancelling) = if (idx & Self::RX_CANCEL_IDX_BIT) != 0 { + (idx - Self::RX_CANCEL_IDX_BIT, true) + } else { + (idx, false) + }; + + use std::collections::hash_map::Entry; + + match self.fd_map.entry(idx) { + Entry::Occupied(entry) if cancelling => { + let nr = io_uring_res(cqe.result()).with_context(|| "Cancelling")?; + info!(fd = idx, nr, "Cancelled"); + entry.remove_entry(); + Ok(()) + } + + Entry::Occupied(mut entry) => { + let state = entry.get_mut(); + match state.complete_rx(sq, cqe, &self.tx_queue, &self.conn_manager) { + Ok(()) => Ok(()), + Err(err) => { + if matches!( + state.phase, + ConnectionPhase::ProxyInitial { .. } | ConnectionPhase::Proxy { .. } + ) { + metrics::connection_accept_proxy_header_failed(); + } + info!("Connection closed: {:?}", err); + state.push_cancel(sq)?; + + if let ConnectionPhase::Connected { conn, .. } = &state.phase { + conn.handle_end_of_stream(); + } + + Ok(()) // Error is for the connection, not the process + } + } + } + + // Likely we raced with a cancellation request + Entry::Vacant(_) => { + match io_uring_res(cqe.result()) { + Err(err) => info!("complete unknown tcp rx {idx} with {err}"), + Ok(res) => info!("complete unknown tcp rx {idx} with {res}"), + }; + Ok(()) + } + } + } + + fn complete_tx( + &mut self, + sq: &mut io_uring::SubmissionQueue, + cqe: io_uring::cqueue::Entry, + idx: u32, + ) -> Result<()> { + use std::collections::hash_map::Entry; + match self.fd_map.entry(idx) { + Entry::Occupied(mut entry) => { + let state = entry.get_mut(); + match state.complete_tx(sq, cqe) { + Ok(()) => Ok(()), + Err(err) => { + info!("Connection closed: {:?}", err); + state.push_cancel(sq)?; + Ok(()) // Error is for the connection, not the process + } + } + } + + // Likely we raced with a cancellation request + Entry::Vacant(_) => { + match io_uring_res(cqe.result()) { + Err(err) => info!("complete unknown tcp tx {idx} with {err}"), + Ok(res) => info!("complete unknown tcp tx {idx} with {res}"), + }; + Ok(()) + } } } } diff --git a/lightway-server/src/io/outside/udp.rs b/lightway-server/src/io/outside/udp.rs index 2554108f..436c3d8d 100644 --- a/lightway-server/src/io/outside/udp.rs +++ b/lightway-server/src/io/outside/udp.rs @@ -1,26 +1,37 @@ -mod cmsg; +//! UdpServer UringIoSource +//! +//! Uses uring indexes: +//! +//! Loop::outside_rx_user_data: +//! - 0..UdpServer::rx.len(): A set of recv requests +//! +//! Loop::outside_tx_user_data: +//! - Managed by TxQueue + +pub(crate) mod cmsg; use std::{ net::{IpAddr, Ipv4Addr, SocketAddr}, - sync::{Arc, RwLock}, + os::fd::{AsRawFd as _, RawFd}, + sync::{Arc, Mutex, MutexGuard, RwLock}, }; -use anyhow::Result; -use async_trait::async_trait; -use bytes::BytesMut; +use anyhow::{Context as _, Result}; +use bytes::{Bytes, BytesMut}; use bytesize::ByteSize; -use lightway_app_utils::sockopt::socket_enable_pktinfo; +use io_uring::opcode; +use lightway_app_utils::{ + sockaddr_from_socket_addr, socket_addr_from_sockaddr, sockopt::socket_enable_pktinfo, +}; use lightway_core::{ - ConnectionType, Header, IOCallbackResult, OutsideIOSendCallback, OutsidePacket, SessionId, - Version, MAX_OUTSIDE_MTU, + ConnectionType, CowBytes, Header, IOCallbackResult, OutsideIOSendCallback, OutsidePacket, + SessionId, Version, MAX_OUTSIDE_MTU, }; -use socket2::{MaybeUninitSlice, MsgHdr, MsgHdrMut, SockAddr, SockRef}; -use tokio::io::Interest; -use tracing::{info, warn}; +use tracing::warn; use crate::{connection_manager::ConnectionManager, metrics}; -use super::Server; +use super::{io_uring_res, iovec, msghdr, Loop, TxQueue, UringIoSource}; enum BindMode { UnspecifiedAddress { local_port: u16 }, @@ -44,52 +55,70 @@ impl std::fmt::Display for BindMode { } } -fn send_to_socket( - sock: &Arc, - buf: &[u8], - peer_addr: &SockAddr, +fn queue_tx( + mut tx_queue: MutexGuard, + buf: Bytes, + peer_addr: libc::sockaddr_storage, + peer_addr_len: libc::socklen_t, pktinfo: Option, ) -> IOCallbackResult { - let res = sock.try_io(Interest::WRITABLE, || { - let sock = SockRef::from(sock.as_ref()); - let bufs = [std::io::IoSlice::new(buf)]; - - let msghdr = MsgHdr::new().with_addr(peer_addr).with_buffers(&bufs); + let len = buf.len(); - const CMSG_SIZE: usize = cmsg::Message::space::(); - let mut cmsg = cmsg::BufferMut::::zeroed(); + let Some((slot, state)) = tx_queue.take_slot() else { + return IOCallbackResult::WouldBlock; + }; - let msghdr = if let Some(pktinfo) = pktinfo { - let mut builder = cmsg.builder(); - builder.fill_next(libc::SOL_IP, libc::IP_PKTINFO, pktinfo)?; + state.iov[0].iov_base = buf.as_ptr() as *mut _; + state.iov[0].iov_len = buf.len(); + state.addr = peer_addr; + state.addr_len = peer_addr_len; - msghdr.with_control(cmsg.as_ref()) - } else { - msghdr - }; + state.buf = Some(buf); - sock.sendmsg(&msghdr, 0) - }); + state.msghdr.msg_name = &mut state.addr as *mut libc::sockaddr_storage as *mut _; + state.msghdr.msg_namelen = state.addr_len; + state.msghdr.msg_iov = state.iov.as_mut_ptr() as *mut _; + state.msghdr.msg_iovlen = state.iov.len(); - match res { - Ok(nr) => IOCallbackResult::Ok(nr), - Err(err) if matches!(err.kind(), std::io::ErrorKind::WouldBlock) => { - IOCallbackResult::WouldBlock + if let Some(pktinfo) = pktinfo { + let mut builder = state.control.builder(); + if let Err(err) = builder.fill_next(libc::SOL_IP, libc::IP_PKTINFO, pktinfo) { + return IOCallbackResult::Err(err); } - Err(err) => IOCallbackResult::Err(err), + state.msghdr.msg_control = state.control.as_mut_ptr() as *mut _; + // Get from builder? + state.msghdr.msg_controllen = std::mem::size_of_val(&state.control) as _; + } else { + state.msghdr.msg_control = std::ptr::null_mut(); + state.msghdr.msg_controllen = 0; } + + let sqe = opcode::SendMsg::new(Loop::FIXED_OUTSIDE_FD, state.msghdr.as_mut_ptr()).build(); + + #[allow(unsafe_code)] + // SAFETY: + // - slot was optained from take_slot above + // - The buffer is owned by `state` and which is owned by the `TxRing` + unsafe { + tx_queue.push_outside_slot(slot, sqe) + }; + + IOCallbackResult::Ok(len) } struct UdpSocket { - sock: Arc, - peer_addr: RwLock<(SocketAddr, SockAddr)>, + tx_queue: Arc>, + peer_addr: RwLock<(SocketAddr, libc::sockaddr_storage, libc::socklen_t)>, reply_pktinfo: Option, } impl OutsideIOSendCallback for UdpSocket { - fn send(&self, buf: &[u8]) -> IOCallbackResult { + fn send(&self, buf: CowBytes) -> IOCallbackResult { + let buf = buf.into_owned(); let peer_addr = self.peer_addr.read().unwrap(); - send_to_socket(&self.sock, buf, &peer_addr.1, self.reply_pktinfo) + let tx_queue = self.tx_queue.lock().unwrap(); + + queue_tx(tx_queue, buf, peer_addr.1, peer_addr.2, self.reply_pktinfo) } fn peer_addr(&self) -> SocketAddr { @@ -99,24 +128,67 @@ impl OutsideIOSendCallback for UdpSocket { fn set_peer_addr(&self, addr: SocketAddr) -> SocketAddr { let mut peer_addr = self.peer_addr.write().unwrap(); let old_addr = peer_addr.0; - *peer_addr = (addr, addr.into()); + + let (raw_addr, raw_addr_len) = sockaddr_from_socket_addr(addr); + + *peer_addr = (addr, raw_addr, raw_addr_len); old_addr } } +struct RxState { + buf: BytesMut, + addr: libc::sockaddr_storage, + control: cmsg::Buffer<{ Self::CONTROL_SIZE }>, + iov: [iovec; 1], + msghdr: msghdr, +} + +impl RxState { + const CONTROL_SIZE: usize = cmsg::Message::space::(); + + fn new() -> Self { + let mut buf = BytesMut::with_capacity(MAX_OUTSIDE_MTU); + let iov = iovec::new(libc::iovec { + iov_base: buf.as_mut_ptr() as *mut _, + iov_len: buf.capacity(), + }); + #[allow(unsafe_code)] + Self { + buf, + // SAFETY: All zeroes is a valid sockaddr + addr: unsafe { std::mem::zeroed() }, + control: cmsg::Buffer::new(), + iov: [iov], + // SAFETY: All zeroes is a valid msghdr + msghdr: unsafe { std::mem::zeroed() }, + } + } +} pub(crate) struct UdpServer { conn_manager: Arc, - sock: Arc, + sock: Arc, bind_mode: BindMode, + tx_queue: Arc>, + // The contents are used for I/O syscalls, ensure they stay put. + rx: Vec, } impl UdpServer { pub(crate) async fn new( + nr_slots: u32, conn_manager: Arc, + tx_queue: Arc>, bind_address: SocketAddr, udp_buffer_size: ByteSize, ) -> Result { - let sock = Arc::new(tokio::net::UdpSocket::bind(bind_address).await?); + tracing::info!("UdpServer with {nr_slots} slots"); + + let sock = tokio::net::UdpSocket::bind(bind_address).await?; + + let sock = sock.into_std()?; + sock.set_nonblocking(false)?; + let sock = Arc::new(sock); let bind_mode = if bind_address.ip().is_unspecified() { BindMode::UnspecifiedAddress { @@ -137,20 +209,31 @@ impl UdpServer { socket_enable_pktinfo(&sock)?; } + let rx = (0..nr_slots).map(|_| RxState::new()).collect(); + + #[allow(unsafe_code)] Ok(Self { conn_manager, sock, bind_mode, + tx_queue, + rx, }) } - async fn data_received( + fn data_received( &mut self, peer_addr: SocketAddr, + raw_peer_addr: libc::sockaddr_storage, + raw_peer_addr_len: libc::socklen_t, local_addr: SocketAddr, reply_pktinfo: Option, - buf: &mut BytesMut, + idx: u32, ) { + #[allow(unsafe_code)] + // SAFETY: The caller must already have validated this. + let buf = &mut unsafe { self.rx.get_unchecked_mut(idx as usize) }.buf; + let pkt = OutsidePacket::Wire(buf, ConnectionType::Datagram); let pkt = match self.conn_manager.parse_raw_outside_packet(pkt) { Ok(hdr) => hdr, @@ -184,8 +267,8 @@ impl UdpServer { local_addr, || { Arc::new(UdpSocket { - sock: self.sock.clone(), - peer_addr: RwLock::new((peer_addr, peer_addr.into())), + tx_queue: self.tx_queue.clone(), + peer_addr: RwLock::new((peer_addr, raw_peer_addr, raw_peer_addr_len)), reply_pktinfo, }) }, @@ -194,7 +277,7 @@ impl UdpServer { match conn_result { Ok(conn) => conn, Err(_e) => { - self.send_reject(peer_addr.into(), reply_pktinfo).await; + self.send_reject(raw_peer_addr, raw_peer_addr_len, reply_pktinfo); return; } } @@ -233,7 +316,12 @@ impl UdpServer { } } - async fn send_reject(&self, peer_addr: SockAddr, reply_pktinfo: Option) { + fn send_reject( + &self, + peer_addr: libc::sockaddr_storage, + peer_addr_len: libc::socklen_t, + pktinfo: Option, + ) { metrics::udp_rejected_session(); let msg = Header { version: Version::MINIMUM, @@ -244,92 +332,98 @@ impl UdpServer { let mut buf = BytesMut::with_capacity(Header::WIRE_SIZE); msg.append_to_wire(&mut buf); + let tx_queue = self.tx_queue.lock().unwrap(); + // Ignore failure to send. - let _ = send_to_socket(&self.sock, &buf, &peer_addr, reply_pktinfo); + + let _ = queue_tx(tx_queue, buf.freeze(), peer_addr, peer_addr_len, pktinfo); + } + + fn push_rx(&mut self, sq: &mut io_uring::SubmissionQueue, idx: u32) -> Result<()> { + let rx = &mut self.rx[idx as usize]; + + // Recover full capacity in case this is a resubmit + rx.buf.clear(); + rx.buf.reserve(MAX_OUTSIDE_MTU); + + rx.msghdr = msghdr::new(libc::msghdr { + msg_name: &mut rx.addr as *mut libc::sockaddr_storage as *mut _, + msg_namelen: std::mem::size_of::() as _, + msg_iov: rx.iov.as_mut_ptr() as *mut libc::msghdr as *mut _, + msg_iovlen: rx.iov.len(), + msg_control: rx.control.as_mut_ptr() as *mut _, + msg_controllen: RxState::CONTROL_SIZE, + msg_flags: 0, + }); + let sqe = opcode::RecvMsg::new(Loop::FIXED_OUTSIDE_FD, rx.msghdr.as_mut_ptr()) + .build() + .user_data(Loop::outside_rx_user_data(idx)); + + #[allow(unsafe_code)] + // SAFETY: The buffer is owned by `self.rx` and `self` is owned by the `io::Loop` + unsafe { + sq.push(&sqe)? + }; + + sq.sync(); + + Ok(()) } } -#[async_trait] -impl Server for UdpServer { - async fn run(&mut self) -> Result<()> { - info!("Accepting traffic on {}", self.bind_mode); - let mut buf = BytesMut::with_capacity(MAX_OUTSIDE_MTU); - loop { - // Recover full capacity - buf.clear(); - buf.reserve(MAX_OUTSIDE_MTU); - - let (peer_addr, local_addr, reply_pktinfo) = self - .sock - .async_io(Interest::READABLE, || { - let sock = SockRef::from(self.sock.as_ref()); - let mut raw_buf = [MaybeUninitSlice::new(buf.spare_capacity_mut())]; - - #[allow(unsafe_code)] - let mut peer_sock_addr = { - // SAFETY: sockaddr_storage is defined - // () - // as being a suitable size and alignment for - // "all supported protocol-specific address - // structures" in the underlying OS APIs. - // - // All zeros is a valid representation, - // corresponding to the `ss_family` having a - // value of `AF_UNSPEC`. - let addr_storage: libc::sockaddr_storage = unsafe { std::mem::zeroed() }; - let len = std::mem::size_of_val(&addr_storage) as libc::socklen_t; - // SAFETY: We initialized above as `AF_UNSPEC` - // so the storage is correct from that - // angle. The `recvmsg` call will change this - // which should be ok since `sockaddr_storage` - // is big enough. - unsafe { SockAddr::new(addr_storage, len) } - }; - - // We only need this control buffer if - // `self.bind_mode.needs_pktinfo()`. However the hit - // on reserving a fairly small on stack buffer - // should be small compared with the conditional - // logic and dynamically sized buffer needed to - // allow omitting it. - const SIZE: usize = cmsg::Message::space::(); - let mut control = cmsg::Buffer::::new(); - - let mut msg = MsgHdrMut::new() - .with_addr(&mut peer_sock_addr) - .with_buffers(&mut raw_buf) - .with_control(control.as_mut()); - - let len = sock.recvmsg(&mut msg, 0)?; - - if msg.flags().is_truncated() { - metrics::udp_recv_truncated(); - } +impl UringIoSource for UdpServer { + fn as_raw_fd(&self) -> RawFd { + self.sock.as_raw_fd() + } + + fn push_initial_ops(&mut self, sq: &mut io_uring::SubmissionQueue) -> Result<()> { + for idx in 0..self.rx.len() as u32 { + self.push_rx(sq, idx)? + } + Ok(()) + } + + fn complete_rx( + &mut self, + sq: &mut io_uring::SubmissionQueue, + cqe: io_uring::cqueue::Entry, + idx: u32, + ) -> Result<()> { + let res = { + let res = io_uring_res(cqe.result()).with_context(|| "outside recvmsg completion")?; + + let rx = &mut self.rx[idx as usize]; + + #[allow(unsafe_code)] + // SAFETY: We rely on recv_from giving us the correct size + unsafe { + rx.buf.set_len(res as usize); + } - let control_len = msg.control_len(); - - // SAFETY: We rely on recv_from giving us the correct size - #[allow(unsafe_code)] - unsafe { - buf.set_len(len) - }; - - let Some(peer_addr) = peer_sock_addr.as_socket() else { - // Since we only bind to IP sockets this shouldn't happen. - metrics::udp_recv_invalid_addr(); - return Err(std::io::Error::new( - std::io::ErrorKind::InvalidInput, - "failed to convert local addr to socketaddr", - )); - }; - - #[allow(unsafe_code)] - let (local_addr, reply_pktinfo) = match self.bind_mode { - BindMode::UnspecifiedAddress { local_port } => { - let Some((local_addr, reply_pktinfo)) = + let raw_peer_addr = rx.addr; + let raw_peer_addr_len = rx.msghdr.msg_namelen; + + let peer_addr = match socket_addr_from_sockaddr(&raw_peer_addr, raw_peer_addr_len) { + Ok(a) => a, + Err(err) => { + metrics::udp_recv_invalid_addr(); + return Err(err.into()); + } + }; + + if (rx.msghdr.msg_flags & libc::MSG_TRUNC) != 0 { + metrics::udp_recv_truncated(); + } + + let control_len = rx.msghdr.msg_controllen; + + #[allow(unsafe_code)] + let (local_addr, reply_pktinfo) = match self.bind_mode { + BindMode::UnspecifiedAddress { local_port } => { + let Some((local_addr, reply_pktinfo)) = // SAFETY: The call to `recvmsg` above updated // the control buffer length field. - unsafe { control.iter(control_len) }.find_map(|cmsg| { + unsafe { rx.control.iter(control_len) }.find_map(|cmsg| { match cmsg { cmsg::Message::IpPktinfo(pi) => { // From https://pubs.opengroup.org/onlinepubs/009695399/basedefs/netinet/in.h.html @@ -355,22 +449,42 @@ impl Server for UdpServer { // and we have set IP_PKTINFO // sockopt this shouldn't happen. metrics::udp_recv_missing_pktinfo(); + println!("outside user data {:016x}, idx {:x} had no PKTINFO", cqe.user_data(),idx); return Err(std::io::Error::new( std::io::ErrorKind::Other, "recvmsg did not return IP_PKTINFO", - )); + ).into()); }; - (local_addr, Some(reply_pktinfo)) - } - BindMode::SpecificAddress { local_addr } => (local_addr, None), - }; + (local_addr, Some(reply_pktinfo)) + } + BindMode::SpecificAddress { local_addr } => (local_addr, None), + }; + + self.data_received( + peer_addr, + raw_peer_addr, + raw_peer_addr_len, + local_addr, + reply_pktinfo, + idx, + ); + + Ok(()) + }; - Ok((peer_addr, local_addr, reply_pktinfo)) - }) - .await?; + // Queue another recv + self.push_rx(sq, idx)?; - self.data_received(peer_addr, local_addr, reply_pktinfo, &mut buf) - .await; - } + res + } + + fn complete_tx( + &mut self, + _sq: &mut io_uring::SubmissionQueue, + cqe: io_uring::cqueue::Entry, + idx: u32, + ) -> Result<()> { + let _ = self.tx_queue.lock().unwrap().complete(cqe, idx); + Ok(()) } } diff --git a/lightway-server/src/io/outside/udp/cmsg.rs b/lightway-server/src/io/outside/udp/cmsg.rs index 0aa94c8c..9e553185 100644 --- a/lightway-server/src/io/outside/udp/cmsg.rs +++ b/lightway-server/src/io/outside/udp/cmsg.rs @@ -8,8 +8,8 @@ impl Buffer { Self([std::mem::MaybeUninit::::uninit(); N]) } - pub(crate) fn as_mut(&mut self) -> &mut [std::mem::MaybeUninit] { - &mut self.0 + pub(crate) fn as_mut_ptr(&mut self) -> *mut u8 { + self.0.as_mut_ptr() as *mut _ } /// # Safety @@ -137,6 +137,10 @@ impl BufferMut { _phantom: std::marker::PhantomData, } } + + pub(crate) fn as_mut_ptr(&mut self) -> *mut u8 { + self.0.as_mut_ptr() as *mut _ + } } impl AsRef<[u8]> for BufferMut { diff --git a/lightway-server/src/io/tx.rs b/lightway-server/src/io/tx.rs new file mode 100644 index 00000000..fa027313 --- /dev/null +++ b/lightway-server/src/io/tx.rs @@ -0,0 +1,170 @@ +//! TxQueue, helper/queue for UringIoSource tx implementations +//! +//! Uses uring indexes: +//! +//! Loop::outside_rx_user_data: +//! - None +//! +//! Loop::outside_tx_user_data: +//! - 0..TxQueue::state.len() + +use std::collections::VecDeque; + +use anyhow::{Context as _, Result}; +use bytes::Bytes; +use io_uring::squeue::Entry as SEntry; + +use super::{ + ffi::{iovec, msghdr}, + io_uring_res, + outside::udp::cmsg, + Loop, SubmissionQueue, Submitter, +}; + +pub(super) struct TxState { + pub buf: Option, + pub addr: libc::sockaddr_storage, + pub addr_len: libc::socklen_t, + pub control: cmsg::BufferMut<{ Self::CONTROL_SIZE }>, + pub iov: [iovec; 1], + pub msghdr: msghdr, +} + +impl TxState { + const CONTROL_SIZE: usize = cmsg::Message::space::(); + fn new() -> Self { + #[allow(unsafe_code)] + Self { + buf: None, + // SAFETY: All zeroes is a valid sockaddr + addr: unsafe { std::mem::zeroed() }, + addr_len: 0, + control: cmsg::BufferMut::zeroed(), + + // SAFETY: All zeroes is a valid iov + iov: [unsafe { std::mem::zeroed() }], + // SAFETY: All zeroes is a valid msghdr + msghdr: unsafe { std::mem::zeroed() }, + } + } +} + +pub struct TxQueue { + sqe_ring: VecDeque, + slots: Vec, + state: Vec, +} + +impl TxQueue { + pub fn new(nr_slots: u32) -> Self { + tracing::info!("TxQueue with {nr_slots} slots"); + let sqe_ring = VecDeque::with_capacity(nr_slots as usize); + let (slots, state) = (0..nr_slots).map(|nr| (nr, TxState::new())).unzip(); + + Self { + sqe_ring, + slots, + state, + } + } + + /// Reserve a slot, the returned value should be passed to + /// `push_*_slot` after setting up the state and constructing an + /// sqe. + pub(super) fn take_slot(&mut self) -> Option<(u32, &mut TxState)> { + let slot = self.slots.pop()?; + let state = &mut self.state[slot as usize]; + Some((slot, state)) + } + + #[allow(unsafe_code)] + /// Push an inside request entry to the tx queue. + /// + /// Callers are responsible for calling `::complete` when the + /// request completes to free the slot. + /// + /// # Safety: + /// + /// - idx must have been previously obtained from `take_slot` + /// - sqe must meet the safety requirements + /// + /// Any sqe userdata will be overwritten + pub(super) unsafe fn push_inside_slot(&mut self, idx: u32, sqe: SEntry) { + let sqe = sqe.user_data(Loop::inside_tx_user_data(idx)); + self.sqe_ring.push_back(sqe); + } + + #[allow(unsafe_code)] + /// Push an outside request entry to the tx queue. + /// + /// Callers are responsible for calling `::complete` when the + /// request completes to free the slot. + /// + /// # Safety: + /// + /// - idx must have been previously obtained from `take_slot` + /// - sqe must meet the safety requirements + /// + /// Any sqe userdata will be overwritten + pub(super) unsafe fn push_outside_slot(&mut self, idx: u32, sqe: SEntry) { + let sqe = sqe.user_data(Loop::outside_tx_user_data(idx)); + self.sqe_ring.push_back(sqe); + } + + #[allow(unsafe_code)] + /// Push an arbitrary entry to the tx queue. Does not consume a slot. + /// + /// Callers are responsible for completion and should not call + /// `::complete`. + /// + /// Use this for SQEs which do not require an entry in `::state` + /// to keep buffers live and/or for which the calling code wants + /// to manage the idx space itself. + /// + /// # Safety: + /// + /// - sqe must meet the safety requirements + pub(super) unsafe fn push(&mut self, sqe: SEntry) { + self.sqe_ring.push_back(sqe); + } + + /// Push all entries (added by `push_*_slot` or `push`) to the uring. + pub(super) fn drain(&mut self, submitter: &Submitter, sq: &mut SubmissionQueue) -> Result<()> { + while let Some(sqe) = self.sqe_ring.pop_front() { + if sq.is_full() { + match submitter.submit() { + Ok(_) => (), + Err(ref err) if err.raw_os_error() == Some(libc::EBUSY) => break, + Err(err) => return Err(err.into()), + } + sq.sync(); + } + + #[allow(unsafe_code)] + // SAFETY: Safe according to the safety requirements of `push_*_slot` or `push` + unsafe { + sq.push(&sqe)? + }; + + sq.sync() + } + + Ok(()) + } + + /// Complete an entry added with `push_*_slot`, intended to be + /// called from the IoUringSource's `complete_*` method. Note that + /// users of plain `push` are responsible for their own + /// completion. + pub(super) fn complete(&mut self, cqe: io_uring::cqueue::Entry, idx: u32) -> Result<()> { + let _res = io_uring_res(cqe.result()).with_context(|| "tx completion")?; + + let slot = &mut self.state[idx as usize]; + + slot.buf = None; + + self.slots.push(idx); + + Ok(()) + } +} diff --git a/lightway-server/src/lib.rs b/lightway-server/src/lib.rs index 0bcd6a3e..9eaeda8c 100644 --- a/lightway-server/src/lib.rs +++ b/lightway-server/src/lib.rs @@ -17,26 +17,19 @@ pub use lightway_core::{ use anyhow::{anyhow, Context, Result}; use ipnet::Ipv4Net; use lightway_app_utils::{connection_ticker_cb, TunConfig}; -use lightway_core::{ - ipv4_update_destination, AuthMethod, BuilderPredicates, ConnectionError, IOCallbackResult, - InsideIpConfig, Secret, ServerContextBuilder, -}; -use pnet::packet::ipv4::Ipv4Packet; +use lightway_core::{AuthMethod, BuilderPredicates, InsideIpConfig, Secret, ServerContextBuilder}; use std::{ collections::HashMap, net::{IpAddr, Ipv4Addr, SocketAddr}, path::PathBuf, - sync::Arc, + sync::{Arc, Mutex}, time::Duration, }; -use tokio::task::JoinHandle; use tracing::{info, warn}; -use crate::io::inside::InsideIO; use crate::ip_manager::IpManager; use connection_manager::ConnectionManager; -use io::outside::Server; fn debug_fmt_plugin_list( list: &PluginFactoryList, @@ -112,15 +105,28 @@ pub struct ServerConfig ServerAuth>> { /// Enable Post Quantum Crypto pub enable_pqc: bool, - /// Enable IO-uring interface for Tunnel - pub enable_tun_iouring: bool, - /// IO-uring submission queue count pub iouring_entry_count: usize, /// IO-uring sqpoll idle time. pub iouring_sqpoll_idle_time: Duration, + /// Number of concurrent TUN device read requests to issue to + /// IO-uring. Setting this too large may negatively impact + /// performance. + pub iouring_tun_rx_count: u32, + + /// Configure TUN in blocking mode. + pub iouring_tun_blocking: bool, + + /// Number of concurrent UDP socket recvmsg requests to issue to + /// IO-uring. + pub iouring_udp_rx_count: u32, + + /// Maximum number of concurrent UDP + TUN sendmsg/write requests + /// to issue to IO-uring. + pub iouring_tx_count: u32, + /// The key update interval for DTLS/TLS 1.3 connections pub key_update_interval: Duration, @@ -140,11 +146,40 @@ pub struct ServerConfig ServerAuth>> { /// UDP Buffer size for the server pub udp_buffer_size: ByteSize, + + /// TCP Buffer size for the server + pub tcp_buffer_size: ByteSize, +} + +impl ServerAuth> + Sync + Send + 'static> ServerConfig { + fn validate(&self) -> Result<()> { + let mut required_uring_slots = + self.iouring_tun_rx_count as usize + self.iouring_tx_count as usize + 1; // cancellation request + + required_uring_slots += match self.connection_type { + // this should be 2 * max connections, but max connections + // is unknown, assume at least 1. + ConnectionType::Stream => 2, + ConnectionType::Datagram => self.iouring_udp_rx_count as usize, + }; + + if self.iouring_entry_count < required_uring_slots { + return Err(anyhow!( + "iouring_entry_count too small {} < {}", + self.iouring_entry_count, + required_uring_slots + )); + } + + Ok(()) + } } pub async fn server ServerAuth> + Sync + Send + 'static>( config: ServerConfig, ) -> Result<()> { + config.validate()?; + let server_key = Secret::PemFile(&config.server_key); let server_cert = Secret::PemFile(&config.server_cert); @@ -175,12 +210,16 @@ pub async fn server ServerAuth> + Sync + Send + 'stati let connection_type = config.connection_type; let auth = Arc::new(AuthAdapter(config.auth)); - let iouring = if config.enable_tun_iouring { - Some((config.iouring_entry_count, config.iouring_sqpoll_idle_time)) - } else { - None - }; - let inside_io = Arc::new(io::inside::Tun::new(config.tun_config, iouring).await?); + let tx_queue = Arc::new(Mutex::new(io::TxQueue::new(config.iouring_tx_count))); + + let tun = io::inside::Tun::new( + config.iouring_tun_rx_count, + config.iouring_tun_blocking, + config.tun_config, + config.lightway_client_ip, + ip_manager.clone(), + tx_queue.clone(), + )?; let ctx = ServerContextBuilder::new( connection_type, @@ -188,7 +227,7 @@ pub async fn server ServerAuth> + Sync + Send + 'stati server_key, auth, ip_manager.clone(), - inside_io.clone().into_io_send_callback(), + tun.inside_io_sender(), )? .with_schedule_tick_cb(connection_ticker_cb) .with_key_update_interval(config.key_update_interval) @@ -201,69 +240,43 @@ pub async fn server ServerAuth> + Sync + Send + 'stati tokio::spawn(statistics::run(conn_manager.clone(), ip_manager.clone())); - let mut server: Box = match connection_type { - ConnectionType::Datagram => Box::new( + let server = match connection_type { + ConnectionType::Datagram => io::OutsideIoSource::Udp( io::outside::UdpServer::new( + config.iouring_udp_rx_count, conn_manager.clone(), + tx_queue.clone(), config.bind_address, config.udp_buffer_size, ) .await?, ), - ConnectionType::Stream => Box::new( + ConnectionType::Stream => io::OutsideIoSource::Tcp( io::outside::TcpServer::new( conn_manager.clone(), + tx_queue.clone(), config.bind_address, config.proxy_protocol, + config.tcp_buffer_size, ) .await?, ), }; - let inside_io_loop: JoinHandle> = tokio::spawn(async move { - loop { - let mut buf = match inside_io.recv_buf().await { - IOCallbackResult::Ok(buf) => buf, - IOCallbackResult::WouldBlock => continue, // Spuriously failed to read, keep waiting - IOCallbackResult::Err(err) => { - break Err(anyhow!(err).context("InsideIO recv buf error")); - } - }; - - // Find connection based on client ip (dest ip) and forward packet - let packet = Ipv4Packet::new(buf.as_ref()); - let Some(packet) = packet else { - eprintln!("Invalid inside packet size (less than Ipv4 header)!"); - continue; - }; - let conn = ip_manager.find_connection(packet.get_destination()); - - // Update destination IP address to client's ip - ipv4_update_destination(buf.as_mut(), config.lightway_client_ip); - - if let Some(conn) = conn { - match conn.inside_data_received(&mut buf) { - Ok(()) => {} - Err(ConnectionError::InvalidState) => { - // Skip forwarding packet when offline - metrics::tun_rejected_packet_invalid_state(); - } - Err(ConnectionError::InvalidInsidePacket(_)) => { - // Skip processing invalid packet - metrics::tun_rejected_packet_invalid_inside_packet(); - } - Err(err) => { - let fatal = err.is_fatal(conn.connection_type()); - metrics::tun_rejected_packet_invalid_other(fatal); - if fatal { - conn.handle_end_of_stream(); - } - } - } - } else { - metrics::tun_rejected_packet_no_connection(); - } - } + // On exit dropping _io_handle will cause EPIPE to be delivered to + // io_cancel. This causes the corresponding read request on the + // ring to complete and signal the loop should exit. + let (_io_handle, io_cancel) = tokio::net::unix::pipe::pipe()?; + let io_cancel = io_cancel.into_blocking_fd()?; + let io_task = tokio::task::spawn_blocking(move || { + let io_loop = io::Loop::new( + config.iouring_entry_count, + config.iouring_sqpoll_idle_time, + tx_queue, + server, + tun, + )?; + io_loop.run(io_cancel) }); let (ctrlc_tx, ctrlc_rx) = tokio::sync::oneshot::channel(); @@ -275,8 +288,7 @@ pub async fn server ServerAuth> + Sync + Send + 'stati })?; tokio::select! { - err = server.run() => err.context("Outside IO loop exited"), - io = inside_io_loop => io.map_err(|e| anyhow!(e).context("Inside IO loop panicked"))?.context("Inside IO loop exited"), + r = io_task => r?.context("IO task exited"), _ = ctrlc_rx => { info!("Sigterm or Sigint received"); conn_manager.close_all_connections(); @@ -284,3 +296,62 @@ pub async fn server ServerAuth> + Sync + Send + 'stati } } } + +#[cfg(test)] +mod tests { + use super::*; + + use test_case::test_case; + + struct Auth; + + impl ServerAuth> for Auth {} + + #[test_case(ConnectionType::Stream, 0, 0, 0, 0 => panics "iouring_entry_count too small")] + #[test_case(ConnectionType::Stream, 3, 0, 0, 0 => ())] + #[test_case(ConnectionType::Stream, 20, 5, 0, 13 => panics "iouring_entry_count too small")] + #[test_case(ConnectionType::Stream, 21, 5, 0, 13 => ())] + #[test_case(ConnectionType::Stream, 22, 5, 0, 13 => ())] + #[test_case(ConnectionType::Stream, 7, 1, 10_000, 3 => ())] // udp rx count irrelevant for stream + #[test_case(ConnectionType::Datagram, 0, 0, 0, 0 => panics "iouring_entry_count too small")] + #[test_case(ConnectionType::Datagram, 1, 0, 0, 0 => ())] + #[test_case(ConnectionType::Datagram, 25, 5, 7, 13 => panics "iouring_entry_count too small")] + #[test_case(ConnectionType::Datagram, 26, 5, 7, 13 => ())] + #[test_case(ConnectionType::Datagram, 27, 5, 7, 13 => ())] + fn validate_iouring_entry_count( + connection_type: ConnectionType, + iouring_entry_count: usize, + iouring_tun_rx_count: u32, + iouring_udp_rx_count: u32, + iouring_tx_count: u32, + ) { + let config = ServerConfig { + connection_type, + auth: Auth, + server_cert: "".into(), + server_key: "".into(), + tun_config: Default::default(), + ip_pool: "10.0.0.0/8".parse().unwrap(), + ip_map: Default::default(), + tun_ip: None, + lightway_server_ip: "1.1.1.1".parse().unwrap(), + lightway_client_ip: "2.2.2.2".parse().unwrap(), + lightway_dns_ip: "3.3.3.3".parse().unwrap(), + enable_pqc: false, + iouring_entry_count, + iouring_sqpoll_idle_time: Default::default(), + iouring_tun_rx_count, + iouring_tun_blocking: false, + iouring_udp_rx_count, + iouring_tx_count, + key_update_interval: Default::default(), + inside_plugins: Default::default(), + outside_plugins: Default::default(), + bind_address: "0.0.0.0:0".parse().unwrap(), + proxy_protocol: false, + udp_buffer_size: Default::default(), + tcp_buffer_size: Default::default(), + }; + config.validate().unwrap(); + } +} diff --git a/lightway-server/src/main.rs b/lightway-server/src/main.rs index 649f2ed3..7fde294c 100644 --- a/lightway-server/src/main.rs +++ b/lightway-server/src/main.rs @@ -130,15 +130,19 @@ async fn main() -> Result<()> { lightway_client_ip: config.lightway_client_ip, lightway_dns_ip: config.lightway_dns_ip, enable_pqc: config.enable_pqc, - enable_tun_iouring: config.enable_tun_iouring, iouring_entry_count: config.iouring_entry_count, iouring_sqpoll_idle_time: config.iouring_sqpoll_idle_time.into(), + iouring_tun_rx_count: config.iouring_tun_rx_count, + iouring_tun_blocking: config.iouring_tun_blocking, + iouring_udp_rx_count: config.iouring_udp_rx_count, + iouring_tx_count: config.iouring_tx_count, key_update_interval: config.key_update_interval.into(), inside_plugins: Default::default(), outside_plugins: Default::default(), bind_address: config.bind_address, proxy_protocol: config.proxy_protocol, udp_buffer_size: config.udp_buffer_size, + tcp_buffer_size: config.tcp_buffer_size, }; server(config).await diff --git a/tests/Earthfile b/tests/Earthfile index 19669363..fb01446d 100644 --- a/tests/Earthfile +++ b/tests/Earthfile @@ -83,13 +83,13 @@ run-udp-floating-ip-test: run-udp-pmtud-test: DO +TEST --MODE=udp --SERVER_PORT=27690 --CLIENT_EXTRA_ARGS="--enable-pmtud" -# run-udp-iouring-test runs e2e test using UDP and default cipher with io-uring enabled +# run-udp-iouring-test runs e2e test using UDP and default cipher with client io-uring enabled run-udp-iouring-test: - DO +TEST --MODE=udp --SERVER_PORT=27690 --SERVER_EXTRA_ARGS="--enable-tun-iouring" --CLIENT_EXTRA_ARGS="--enable-tun-iouring" + DO +TEST --MODE=udp --SERVER_PORT=27690 --CLIENT_EXTRA_ARGS="--enable-tun-iouring" -# run-tcp-iouring-test runs e2e test using TCP and default cipher with io-uring enabled +# run-tcp-iouring-test runs e2e test using TCP and default cipher with client io-uring enabled run-tcp-iouring-test: - DO +TEST --MODE=tcp --SERVER_PORT=27690 --SERVER_EXTRA_ARGS="--enable-tun-iouring" --CLIENT_EXTRA_ARGS="--enable-tun-iouring" + DO +TEST --MODE=tcp --SERVER_PORT=27690 --CLIENT_EXTRA_ARGS="--enable-tun-iouring" # run-udp-min-inside-mtu-test runs e2e test of UDP with client using smallest valid inside MTU run-udp-min-inside-mtu-test: @@ -121,7 +121,7 @@ run-tcp-keepalive-test: # run-udp-single-threaded-test runs e2e test of UDP with server and client using a single Tokio worker thread run-udp-single-threaded-test: - DO +TEST --MODE=udp --SERVER_PORT=27690 --SERVER_TOKIO_WORKER_THREADS=1 --SERVER_EXTRA_ARGS="--enable-tun-iouring" --CLIENT_TOKIO_WORKER_THREADS=1 --CLIENT_EXTRA_ARGS="--keepalive-interval=2s --keepalive-timeout=6s --enable-tun-iouring --enable-pmtud" + DO +TEST --MODE=udp --SERVER_PORT=27690 --SERVER_TOKIO_WORKER_THREADS=1 --CLIENT_TOKIO_WORKER_THREADS=1 --CLIENT_EXTRA_ARGS="--keepalive-interval=2s --keepalive-timeout=6s --enable-tun-iouring --enable-pmtud" # run-tcp-single-threaded-test runs e2e test of TCP with server and client using a single Tokio worker thread run-tcp-single-threaded-test: