From 863cf5ef633b0eec5d2661ae8309d6a530bb0c08 Mon Sep 17 00:00:00 2001 From: Alexandre Careil Date: Fri, 6 Feb 2026 11:14:20 +0100 Subject: [PATCH 01/18] wip: add middleware support for tcp server --- crates/hyli-net/src/tcp.rs | 1 + crates/hyli-net/src/tcp/middleware/impls.rs | 194 ++++++++++++++++ crates/hyli-net/src/tcp/middleware/mod.rs | 237 ++++++++++++++++++++ 3 files changed, 432 insertions(+) create mode 100644 crates/hyli-net/src/tcp/middleware/impls.rs create mode 100644 crates/hyli-net/src/tcp/middleware/mod.rs diff --git a/crates/hyli-net/src/tcp.rs b/crates/hyli-net/src/tcp.rs index 76804604f..bfcd53076 100644 --- a/crates/hyli-net/src/tcp.rs +++ b/crates/hyli-net/src/tcp.rs @@ -1,6 +1,7 @@ #[cfg(feature = "turmoil")] pub mod intercept; pub mod p2p_server; +pub mod middleware; pub mod tcp_client; pub mod tcp_server; diff --git a/crates/hyli-net/src/tcp/middleware/impls.rs b/crates/hyli-net/src/tcp/middleware/impls.rs new file mode 100644 index 000000000..de56bec98 --- /dev/null +++ b/crates/hyli-net/src/tcp/middleware/impls.rs @@ -0,0 +1,194 @@ +use std::collections::VecDeque; +use std::time::Duration; + +use tokio::time::Instant; + +use crate::tcp::{tcp_server::TcpServer, TcpEvent, TcpHeaders}; + +use super::{SendErrorContext, TcpServerMiddleware}; + +#[derive(Default)] +pub struct DropOnError; + +impl TcpServerMiddleware for DropOnError { + type EventOut = TcpEvent; + + fn on_event( + &mut self, + server: &mut TcpServer, + event: TcpEvent, + ) -> Option { + match &event { + TcpEvent::Error { socket_addr, .. } | TcpEvent::Closed { socket_addr } => { + server.drop_peer_stream(socket_addr.clone()); + } + _ => {} + } + Some(event) + } +} + +struct PendingSend { + socket_addr: String, + msg: Res, + headers: TcpHeaders, + retries: usize, + next_attempt_at: Instant, +} + +pub struct RetryingSend { + max_retries: usize, + base_delay: Duration, + max_per_tick: usize, + queue: VecDeque>, +} + +impl RetryingSend { + pub fn new(max_retries: usize, base_delay: Duration) -> Self { + Self { + max_retries, + base_delay, + max_per_tick: 64, + queue: VecDeque::new(), + } + } + + pub fn max_per_tick(mut self, max_per_tick: usize) -> Self { + self.max_per_tick = max_per_tick.max(1); + self + } +} + +impl TcpServerMiddleware for RetryingSend +where + Res: Clone, +{ + type EventOut = TcpEvent; + + fn on_event( + &mut self, + _server: &mut TcpServer, + event: TcpEvent, + ) -> Option { + Some(event) + } + + fn on_send_error( + &mut self, + server: &mut TcpServer, + ctx: SendErrorContext, + ) -> anyhow::Result<()> { + if !server.connected(&ctx.socket_addr) { + return Err(ctx.error); + } + self.queue.push_back(PendingSend { + socket_addr: ctx.socket_addr, + msg: ctx.msg, + headers: ctx.headers, + retries: 0, + next_attempt_at: Instant::now() + self.base_delay, + }); + Ok(()) + } + + fn on_tick(&mut self, server: &mut TcpServer) { + if self.queue.is_empty() { + return; + } + + let now = Instant::now(); + let mut processed = 0usize; + let mut remaining = VecDeque::with_capacity(self.queue.len()); + + while let Some(mut pending) = self.queue.pop_front() { + if pending.next_attempt_at > now || processed >= self.max_per_tick { + remaining.push_back(pending); + continue; + } + + if !server.connected(&pending.socket_addr) { + continue; + } + + match server.send( + pending.socket_addr.clone(), + pending.msg.clone(), + pending.headers.clone(), + ) { + Ok(()) => {} + Err(_) => { + let next_retries = pending.retries + 1; + if next_retries > self.max_retries { + server.drop_peer_stream(pending.socket_addr); + } else { + pending.retries = next_retries; + pending.next_attempt_at = + now + self.base_delay.mul_f64(next_retries as f64); + remaining.push_back(pending); + } + } + } + + processed += 1; + } + + self.queue = remaining; + } + + fn next_wakeup(&self) -> Option { + self.queue + .iter() + .map(|pending| pending.next_attempt_at) + .min() + } +} + +pub struct DropOnErrorAndRetry { + drop_on_error: DropOnError, + retrying_send: RetryingSend, +} + +impl DropOnErrorAndRetry { + pub fn new(max_retries: usize, base_delay: Duration) -> Self { + Self { + drop_on_error: DropOnError, + retrying_send: RetryingSend::new(max_retries, base_delay), + } + } + + pub fn max_per_tick(mut self, max_per_tick: usize) -> Self { + self.retrying_send = self.retrying_send.max_per_tick(max_per_tick); + self + } +} + +impl TcpServerMiddleware for DropOnErrorAndRetry +where + Res: Clone, +{ + type EventOut = TcpEvent; + + fn on_event( + &mut self, + server: &mut TcpServer, + event: TcpEvent, + ) -> Option { + self.drop_on_error.on_event(server, event) + } + + fn on_send_error( + &mut self, + server: &mut TcpServer, + ctx: SendErrorContext, + ) -> anyhow::Result<()> { + self.retrying_send.on_send_error(server, ctx) + } + + fn on_tick(&mut self, server: &mut TcpServer) { + self.retrying_send.on_tick(server) + } + + fn next_wakeup(&self) -> Option { + self.retrying_send.next_wakeup() + } +} diff --git a/crates/hyli-net/src/tcp/middleware/mod.rs b/crates/hyli-net/src/tcp/middleware/mod.rs new file mode 100644 index 000000000..8c35558e9 --- /dev/null +++ b/crates/hyli-net/src/tcp/middleware/mod.rs @@ -0,0 +1,237 @@ +//! TcpServer middleware helpers. +//! +//! This module provides a wrapper around `TcpServer` that preserves the +//! `listen_next()` API while allowing synchronous middleware actions +//! (drop-on-error) and listen-driven retries (send retry queue progressed +//! inside `listen_next()`). +//! +//! # Example +//! ```no_run +//! use std::time::Duration; +//! use hyli_net::tcp::{ +//! tcp_server::TcpServer, +//! middleware::{TcpServerWithMiddleware, DropOnErrorAndRetry}, +//! }; +//! # use hyli_net::tcp::{TcpEvent, TcpMessageLabel}; +//! # use borsh::{BorshDeserialize, BorshSerialize}; +//! # +//! # #[derive(Clone, Debug, BorshSerialize, BorshDeserialize)] +//! # struct Req; +//! # impl TcpMessageLabel for Req { +//! # fn message_label(&self) -> &'static str { "Req" } +//! # } +//! # #[derive(Clone, Debug, BorshSerialize, BorshDeserialize)] +//! # struct Res; +//! # impl TcpMessageLabel for Res { +//! # fn message_label(&self) -> &'static str { "Res" } +//! # } +//! # +//! # async fn example() -> anyhow::Result<()> { +//! let inner = TcpServer::::start(0, "Example").await?; +//! let middleware = DropOnErrorAndRetry::new(10, Duration::from_millis(100)); +//! let mut server = TcpServerWithMiddleware::new(inner, middleware); +//! +//! while let Some(event) = server.listen_next().await { +//! match event { +//! TcpEvent::Message { socket_addr, data, headers } => { +//! // Handle inbound message... +//! let _ = server.send(socket_addr, Res, headers); +//! } +//! TcpEvent::Closed { .. } | TcpEvent::Error { .. } => { +//! // Drop-on-error is handled by middleware. +//! } +//! } +//! } +//! # Ok(()) +//! # } +//! ``` +//! +//! You can also map events to a different output type. This example maps +//! `TcpEvent::Message` to the `Req` payload and filters out `Error/Closed`. +//! ```no_run +//! # use hyli_net::tcp::{tcp_server::TcpServer, TcpEvent, TcpMessageLabel}; +//! # use hyli_net::tcp::middleware::{TcpServerWithMiddleware, TcpServerMiddleware}; +//! # use borsh::{BorshDeserialize, BorshSerialize}; +//! # #[derive(Clone, Debug, BorshSerialize, BorshDeserialize)] +//! # struct Req; +//! # impl TcpMessageLabel for Req { +//! # fn message_label(&self) -> &'static str { "Req" } +//! # } +//! # #[derive(Clone, Debug, BorshSerialize, BorshDeserialize)] +//! # struct Res; +//! # impl TcpMessageLabel for Res { +//! # fn message_label(&self) -> &'static str { "Res" } +//! # } +//! # +//! # struct MessageOnly; +//! # impl TcpServerMiddleware for MessageOnly { +//! # type EventOut = Req; +//! # fn on_event(&mut self, _server: &mut TcpServer, event: TcpEvent) -> Option { +//! # match event { TcpEvent::Message { data, .. } => Some(data), _ => None } +//! # } +//! # } +//! # +//! # async fn example() -> anyhow::Result<()> { +//! let inner = TcpServer::::start(0, "Example").await?; +//! let mut server = TcpServerWithMiddleware::new(inner, MessageOnly); +//! while let Some(req) = server.listen_next().await { +//! // req is already the decoded payload +//! } +//! # Ok(()) +//! # } +//! ``` + +use tokio::time::Instant; + +use crate::tcp::{tcp_server::TcpServer, TcpEvent, TcpHeaders}; + +mod impls; + +pub use impls::{DropOnError, DropOnErrorAndRetry, RetryingSend}; + +pub struct SendErrorContext { + pub socket_addr: String, + pub msg: Res, + pub headers: TcpHeaders, + pub error: anyhow::Error, +} + +pub trait TcpServerMiddleware { + type EventOut = TcpEvent; + + /// Transform or filter inbound events before they are exposed to callers. + /// Returning `None` will cause the wrapper to keep listening. + fn on_event( + &mut self, + _server: &mut TcpServer, + event: TcpEvent, + ) -> Option; + + /// Handle outbound send errors. The default behavior is to surface the error. + /// Implementations can enqueue retries or drop peers. + fn on_send_error( + &mut self, + _server: &mut TcpServer, + ctx: SendErrorContext, + ) -> anyhow::Result<()> { + Err(ctx.error) + } + + /// Called on each `listen_next()` iteration before waiting for events. + /// Use this to drive retry queues or housekeeping. + fn on_tick(&mut self, _server: &mut TcpServer) {} + + /// Optional wakeup time for the next middleware action. If present, the + /// wrapper will `select!` between the next event and this deadline. + fn next_wakeup(&self) -> Option { + None + } +} + +/// Common interface for `TcpServer` and middleware wrappers. +pub trait TcpServerLike { + type EventOut; + + /// Receive the next inbound event (or mapped output if wrapped). + async fn listen_next(&mut self) -> Option; + /// Send a response to a peer. + fn send( + &mut self, + socket_addr: String, + msg: Res, + headers: TcpHeaders, + ) -> anyhow::Result<()>; +} + +pub struct TcpServerWithMiddleware { + inner: TcpServer, + middleware: M, +} + +impl TcpServerWithMiddleware +where + M: TcpServerMiddleware, +{ + pub fn new(inner: TcpServer, middleware: M) -> Self { + Self { inner, middleware } + } + + pub fn inner(&self) -> &TcpServer { + &self.inner + } + + pub fn inner_mut(&mut self) -> &mut TcpServer { + &mut self.inner + } +} + +impl TcpServerLike for TcpServer { + type EventOut = TcpEvent; + + async fn listen_next(&mut self) -> Option { + TcpServer::listen_next(self).await + } + + fn send( + &mut self, + socket_addr: String, + msg: Res, + headers: TcpHeaders, + ) -> anyhow::Result<()> { + TcpServer::send(self, socket_addr, msg, headers) + } +} + +impl TcpServerLike for TcpServerWithMiddleware +where + M: TcpServerMiddleware, + Res: Clone, +{ + type EventOut = M::EventOut; + + async fn listen_next(&mut self) -> Option { + loop { + self.middleware.on_tick(&mut self.inner); + if let Some(deadline) = self.middleware.next_wakeup() { + let now = Instant::now(); + if deadline <= now { + continue; + } + tokio::select! { + event = self.inner.listen_next() => { + let event = event?; + return self.middleware.on_event(&mut self.inner, event); + } + _ = tokio::time::sleep_until(deadline) => { + continue; + } + } + } else { + let event = self.inner.listen_next().await?; + return self.middleware.on_event(&mut self.inner, event); + } + } + } + + fn send( + &mut self, + socket_addr: String, + msg: Res, + headers: TcpHeaders, + ) -> anyhow::Result<()> { + let msg_clone = msg.clone(); + let headers_clone = headers.clone(); + match self.inner.send(socket_addr.clone(), msg, headers) { + Ok(()) => Ok(()), + Err(error) => self.middleware.on_send_error( + &mut self.inner, + SendErrorContext { + socket_addr, + msg: msg_clone, + headers: headers_clone, + error, + }, + ), + } + } +} From 46b948eb37e463c7da4296653c84adf40efc7d41 Mon Sep 17 00:00:00 2001 From: Alexandre Careil Date: Thu, 12 Feb 2026 17:38:54 +0100 Subject: [PATCH 02/18] Extract DA retry logic into a dedicated middleware --- crates/hyli-net/src/tcp/middleware/impls.rs | 380 +++++++++++++++++++- crates/hyli-net/src/tcp/middleware/mod.rs | 203 +++++++++-- crates/hyli-net/src/tcp/p2p_server.rs | 8 +- crates/hyli-net/src/tcp/tcp_client.rs | 4 +- crates/hyli-net/src/tcp/tcp_server.rs | 20 +- crates/hyli-net/tests/basic.rs | 10 +- src/data_availability.rs | 328 ++++++----------- 7 files changed, 677 insertions(+), 276 deletions(-) diff --git a/crates/hyli-net/src/tcp/middleware/impls.rs b/crates/hyli-net/src/tcp/middleware/impls.rs index de56bec98..39b6a1043 100644 --- a/crates/hyli-net/src/tcp/middleware/impls.rs +++ b/crates/hyli-net/src/tcp/middleware/impls.rs @@ -1,16 +1,22 @@ -use std::collections::VecDeque; +use std::collections::{HashSet, VecDeque}; +use std::marker::PhantomData; use std::time::Duration; +use borsh::{BorshDeserialize, BorshSerialize}; use tokio::time::Instant; -use crate::tcp::{tcp_server::TcpServer, TcpEvent, TcpHeaders}; +use crate::tcp::{tcp_server::TcpServer, TcpEvent, TcpHeaders, TcpMessageLabel}; -use super::{SendErrorContext, TcpServerMiddleware}; +use super::{SendErrorContext, SendErrorOutcome, TcpServerMiddleware}; #[derive(Default)] pub struct DropOnError; -impl TcpServerMiddleware for DropOnError { +impl TcpServerMiddleware for DropOnError +where + Req: BorshSerialize + BorshDeserialize + std::fmt::Debug + Send + TcpMessageLabel + 'static, + Res: BorshSerialize + BorshDeserialize + std::fmt::Debug + TcpMessageLabel, +{ type EventOut = TcpEvent; fn on_event( @@ -28,6 +34,346 @@ impl TcpServerMiddleware for DropOnError { } } +fn min_wakeup(lhs: Option, rhs: Option) -> Option { + match (lhs, rhs) { + (Some(a), Some(b)) => Some(a.min(b)), + (Some(a), None) => Some(a), + (None, Some(b)) => Some(b), + (None, None) => None, + } +} + +pub struct EventPipeline { + first: A, + second: B, +} + +impl EventPipeline { + pub fn new(first: A, second: B) -> Self { + Self { first, second } + } + + pub(crate) fn second_mut(&mut self) -> &mut B { + &mut self.second + } +} + +impl TcpServerMiddleware for EventPipeline +where + Req: BorshSerialize + BorshDeserialize + std::fmt::Debug + Send + TcpMessageLabel + 'static, + Res: BorshSerialize + BorshDeserialize + std::fmt::Debug + TcpMessageLabel + Clone, + A: TcpServerMiddleware>, + B: TcpServerMiddleware, +{ + type EventOut = B::EventOut; + + fn on_event( + &mut self, + server: &mut TcpServer, + event: TcpEvent, + ) -> Option { + let event = self.first.on_event(server, event)?; + self.second.on_event(server, event) + } + + fn on_send_error( + &mut self, + server: &mut TcpServer, + ctx: &SendErrorContext, + ) -> SendErrorOutcome { + match self.first.on_send_error(server, ctx) { + SendErrorOutcome::Unhandled(_) => self.second.on_send_error(server, ctx), + outcome => outcome, + } + } + + fn on_tick(&mut self, server: &mut TcpServer) { + self.first.on_tick(server); + self.second.on_tick(server); + } + + fn next_wakeup(&self) -> Option { + min_wakeup(self.first.next_wakeup(), self.second.next_wakeup()) + } +} + +pub struct TcpInboundMessage { + pub socket_addr: String, + pub data: Req, + pub headers: TcpHeaders, +} + +#[derive(Default)] +pub struct MessageOnly; + +impl TcpServerMiddleware for MessageOnly +where + Req: BorshSerialize + BorshDeserialize + std::fmt::Debug + Send + TcpMessageLabel + 'static, + Res: BorshSerialize + BorshDeserialize + std::fmt::Debug + TcpMessageLabel, +{ + type EventOut = Req; + + fn on_event( + &mut self, + _server: &mut TcpServer, + event: TcpEvent, + ) -> Option { + match event { + TcpEvent::Message { data, .. } => Some(data), + TcpEvent::Closed { .. } | TcpEvent::Error { .. } => None, + } + } +} + +#[derive(Default)] +pub struct MessageWithMeta; + +impl TcpServerMiddleware for MessageWithMeta +where + Req: BorshSerialize + BorshDeserialize + std::fmt::Debug + Send + TcpMessageLabel + 'static, + Res: BorshSerialize + BorshDeserialize + std::fmt::Debug + TcpMessageLabel, +{ + type EventOut = TcpInboundMessage; + + fn on_event( + &mut self, + _server: &mut TcpServer, + event: TcpEvent, + ) -> Option { + match event { + TcpEvent::Message { + socket_addr, + data, + headers, + } => Some(TcpInboundMessage { + socket_addr, + data, + headers, + }), + TcpEvent::Closed { .. } | TcpEvent::Error { .. } => None, + } + } +} + +struct QueuedOutbound { + msg: Res, + headers: TcpHeaders, + retries: usize, + next_attempt_at: Instant, +} + +pub struct QueuedSendWithRetry { + max_retries: usize, + base_delay: Duration, + max_per_tick: usize, + streaming_peers: HashSet, + queues: std::collections::HashMap>>, + _marker: PhantomData, +} + +impl QueuedSendWithRetry { + pub fn new(max_retries: usize, base_delay: Duration) -> Self { + Self { + max_retries, + base_delay, + max_per_tick: 64, + streaming_peers: HashSet::new(), + queues: std::collections::HashMap::new(), + _marker: PhantomData, + } + } + + pub fn max_per_tick(mut self, max_per_tick: usize) -> Self { + self.max_per_tick = max_per_tick.max(1); + self + } + + pub fn register_streaming_peer(&mut self, socket_addr: String) { + self.streaming_peers.insert(socket_addr); + } + + pub fn unregister_streaming_peer(&mut self, socket_addr: &str) { + self.streaming_peers.remove(socket_addr); + self.queues.remove(socket_addr); + } + + pub fn enqueue_to_peer(&mut self, socket_addr: String, msg: Res, headers: TcpHeaders) { + self.queues + .entry(socket_addr) + .or_default() + .push_back(QueuedOutbound { + msg, + headers, + retries: 0, + next_attempt_at: Instant::now(), + }); + } + + pub fn enqueue_to_streaming_peers(&mut self, msg: Res, headers: TcpHeaders) + where + Res: Clone, + { + for peer in self.streaming_peers.clone() { + self.enqueue_to_peer(peer, msg.clone(), headers.clone()); + } + } +} + +pub trait QueuedSenderMiddleware +where + Req: BorshSerialize + BorshDeserialize + std::fmt::Debug + Send + TcpMessageLabel + 'static, + Res: BorshSerialize + BorshDeserialize + std::fmt::Debug + TcpMessageLabel + Clone, +{ + fn enqueue_to_peer(&mut self, socket_addr: String, msg: Res, headers: TcpHeaders); + fn register_streaming_peer(&mut self, socket_addr: String); + fn enqueue_to_streaming_peers(&mut self, msg: Res, headers: TcpHeaders); +} + +impl QueuedSenderMiddleware for QueuedSendWithRetry +where + Req: BorshSerialize + BorshDeserialize + std::fmt::Debug + Send + TcpMessageLabel + 'static, + Res: BorshSerialize + BorshDeserialize + std::fmt::Debug + TcpMessageLabel + Clone, +{ + fn enqueue_to_peer(&mut self, socket_addr: String, msg: Res, headers: TcpHeaders) { + QueuedSendWithRetry::enqueue_to_peer(self, socket_addr, msg, headers); + } + + fn register_streaming_peer(&mut self, socket_addr: String) { + QueuedSendWithRetry::register_streaming_peer(self, socket_addr); + } + + fn enqueue_to_streaming_peers(&mut self, msg: Res, headers: TcpHeaders) { + QueuedSendWithRetry::enqueue_to_streaming_peers(self, msg, headers); + } +} + +impl QueuedSenderMiddleware for EventPipeline +where + Req: BorshSerialize + BorshDeserialize + std::fmt::Debug + Send + TcpMessageLabel + 'static, + Res: BorshSerialize + BorshDeserialize + std::fmt::Debug + TcpMessageLabel + Clone, + A: TcpServerMiddleware>, + B: QueuedSenderMiddleware, +{ + fn enqueue_to_peer(&mut self, socket_addr: String, msg: Res, headers: TcpHeaders) { + self.second_mut().enqueue_to_peer(socket_addr, msg, headers); + } + + fn register_streaming_peer(&mut self, socket_addr: String) { + self.second_mut().register_streaming_peer(socket_addr); + } + + fn enqueue_to_streaming_peers(&mut self, msg: Res, headers: TcpHeaders) { + self.second_mut().enqueue_to_streaming_peers(msg, headers); + } +} + +impl TcpServerMiddleware for QueuedSendWithRetry +where + Req: BorshSerialize + BorshDeserialize + std::fmt::Debug + Send + TcpMessageLabel + 'static, + Res: BorshSerialize + BorshDeserialize + std::fmt::Debug + TcpMessageLabel + Clone, +{ + type EventOut = TcpInboundMessage; + + fn on_event( + &mut self, + _server: &mut TcpServer, + event: TcpEvent, + ) -> Option { + match event { + TcpEvent::Message { + socket_addr, + data, + headers, + } => Some(TcpInboundMessage { + socket_addr, + data, + headers, + }), + TcpEvent::Closed { socket_addr } => { + self.unregister_streaming_peer(&socket_addr); + None + } + TcpEvent::Error { socket_addr, .. } => { + self.unregister_streaming_peer(&socket_addr); + None + } + } + } + + fn on_send_error( + &mut self, + server: &mut TcpServer, + ctx: &SendErrorContext, + ) -> SendErrorOutcome { + if !server.connected(&ctx.socket_addr) { + self.unregister_streaming_peer(&ctx.socket_addr); + return SendErrorOutcome::Unhandled(anyhow::anyhow!(ctx.error.to_string())); + } + self.enqueue_to_peer(ctx.socket_addr.clone(), ctx.msg.clone(), ctx.headers.clone()); + SendErrorOutcome::RetryScheduled + } + + fn on_tick(&mut self, server: &mut TcpServer) { + if self.queues.is_empty() { + return; + } + + let mut processed = 0usize; + let now = Instant::now(); + let peers: Vec = self.queues.keys().cloned().collect(); + + for peer in peers { + if processed >= self.max_per_tick { + break; + } + + if !server.connected(&peer) { + self.unregister_streaming_peer(&peer); + continue; + } + + let Some(queue) = self.queues.get_mut(&peer) else { + continue; + }; + + let Some(front) = queue.front_mut() else { + continue; + }; + + if front.next_attempt_at > now { + continue; + } + + match server.send(peer.clone(), front.msg.clone(), front.headers.clone()) { + Ok(()) => { + queue.pop_front(); + } + Err(_) => { + front.retries += 1; + if front.retries > self.max_retries { + server.drop_peer_stream(peer.clone()); + self.unregister_streaming_peer(&peer); + } else { + front.next_attempt_at = + now + self.base_delay.mul_f64(front.retries as f64); + } + } + } + + processed += 1; + } + + self.queues.retain(|_, queue| !queue.is_empty()); + } + + fn next_wakeup(&self) -> Option { + self.queues + .values() + .filter_map(|queue| queue.front().map(|pending| pending.next_attempt_at)) + .min() + } +} + struct PendingSend { socket_addr: String, msg: Res, @@ -61,7 +407,8 @@ impl RetryingSend { impl TcpServerMiddleware for RetryingSend where - Res: Clone, + Req: BorshSerialize + BorshDeserialize + std::fmt::Debug + Send + TcpMessageLabel + 'static, + Res: BorshSerialize + BorshDeserialize + std::fmt::Debug + TcpMessageLabel + Clone, { type EventOut = TcpEvent; @@ -76,19 +423,19 @@ where fn on_send_error( &mut self, server: &mut TcpServer, - ctx: SendErrorContext, - ) -> anyhow::Result<()> { + ctx: &SendErrorContext, + ) -> SendErrorOutcome { if !server.connected(&ctx.socket_addr) { - return Err(ctx.error); + return SendErrorOutcome::Unhandled(anyhow::anyhow!(ctx.error.to_string())); } self.queue.push_back(PendingSend { - socket_addr: ctx.socket_addr, - msg: ctx.msg, - headers: ctx.headers, + socket_addr: ctx.socket_addr.clone(), + msg: ctx.msg.clone(), + headers: ctx.headers.clone(), retries: 0, next_attempt_at: Instant::now() + self.base_delay, }); - Ok(()) + SendErrorOutcome::RetryScheduled } fn on_tick(&mut self, server: &mut TcpServer) { @@ -164,7 +511,8 @@ impl DropOnErrorAndRetry { impl TcpServerMiddleware for DropOnErrorAndRetry where - Res: Clone, + Req: BorshSerialize + BorshDeserialize + std::fmt::Debug + Send + TcpMessageLabel + 'static, + Res: BorshSerialize + BorshDeserialize + std::fmt::Debug + TcpMessageLabel + Clone, { type EventOut = TcpEvent; @@ -179,8 +527,8 @@ where fn on_send_error( &mut self, server: &mut TcpServer, - ctx: SendErrorContext, - ) -> anyhow::Result<()> { + ctx: &SendErrorContext, + ) -> SendErrorOutcome { self.retrying_send.on_send_error(server, ctx) } @@ -189,6 +537,6 @@ where } fn next_wakeup(&self) -> Option { - self.retrying_send.next_wakeup() + as TcpServerMiddleware>::next_wakeup(&self.retrying_send) } } diff --git a/crates/hyli-net/src/tcp/middleware/mod.rs b/crates/hyli-net/src/tcp/middleware/mod.rs index 8c35558e9..6a0b373c8 100644 --- a/crates/hyli-net/src/tcp/middleware/mod.rs +++ b/crates/hyli-net/src/tcp/middleware/mod.rs @@ -81,13 +81,20 @@ //! # } //! ``` +use std::ops::{Deref, DerefMut}; + use tokio::time::Instant; -use crate::tcp::{tcp_server::TcpServer, TcpEvent, TcpHeaders}; +use borsh::{BorshDeserialize, BorshSerialize}; + +use crate::tcp::{tcp_server::TcpServer, TcpEvent, TcpHeaders, TcpMessageLabel}; mod impls; -pub use impls::{DropOnError, DropOnErrorAndRetry, RetryingSend}; +pub use impls::{ + DropOnError, DropOnErrorAndRetry, EventPipeline, MessageOnly, MessageWithMeta, + QueuedSendWithRetry, QueuedSenderMiddleware, RetryingSend, TcpInboundMessage, +}; pub struct SendErrorContext { pub socket_addr: String, @@ -96,8 +103,23 @@ pub struct SendErrorContext { pub error: anyhow::Error, } -pub trait TcpServerMiddleware { - type EventOut = TcpEvent; +pub enum SendErrorOutcome { + /// Middleware absorbed the error (e.g. logged only). + Handled, + /// Middleware scheduled a retry. + RetryScheduled, + /// Middleware requests dropping the peer. + DropPeer, + /// Middleware did not handle the error; propagate upstream. + Unhandled(anyhow::Error), +} + +pub trait TcpServerMiddleware +where + Req: BorshSerialize + BorshDeserialize + std::fmt::Debug + Send + TcpMessageLabel + 'static, + Res: BorshSerialize + BorshDeserialize + std::fmt::Debug + TcpMessageLabel, +{ + type EventOut; /// Transform or filter inbound events before they are exposed to callers. /// Returning `None` will cause the wrapper to keep listening. @@ -112,9 +134,9 @@ pub trait TcpServerMiddleware { fn on_send_error( &mut self, _server: &mut TcpServer, - ctx: SendErrorContext, - ) -> anyhow::Result<()> { - Err(ctx.error) + ctx: &SendErrorContext, + ) -> SendErrorOutcome { + SendErrorOutcome::Unhandled(anyhow::anyhow!(ctx.error.to_string())) } /// Called on each `listen_next()` iteration before waiting for events. @@ -141,15 +163,45 @@ pub trait TcpServerLike { msg: Res, headers: TcpHeaders, ) -> anyhow::Result<()>; + /// Return the currently connected peer socket addresses. + fn connected_clients(&self) -> Box + '_>; + /// Check whether a peer socket is currently connected. + fn connected(&self, socket_addr: &str) -> bool { + self.connected_clients() + .any(|addr| addr == socket_addr) + } + /// Drop and disconnect a peer socket. + fn drop_peer_stream(&mut self, peer_ip: String); + + /// Broadcast by fanout over `connected_clients()` using `send()`. + fn broadcast(&mut self, msg: Res, headers: TcpHeaders) -> Vec<(String, anyhow::Error)> + where + Res: Clone, + { + let peers: Vec = self.connected_clients().cloned().collect(); + let mut errors = Vec::new(); + for peer in peers { + if let Err(error) = self.send(peer.clone(), msg.clone(), headers.clone()) { + errors.push((peer, error)); + } + } + errors + } } -pub struct TcpServerWithMiddleware { +pub struct TcpServerWithMiddleware +where + Req: BorshSerialize + BorshDeserialize + std::fmt::Debug + Send + TcpMessageLabel + 'static, + Res: BorshSerialize + BorshDeserialize + std::fmt::Debug + TcpMessageLabel, +{ inner: TcpServer, middleware: M, } impl TcpServerWithMiddleware where + Req: BorshSerialize + BorshDeserialize + std::fmt::Debug + Send + TcpMessageLabel + 'static, + Res: BorshSerialize + BorshDeserialize + std::fmt::Debug + TcpMessageLabel, M: TcpServerMiddleware, { pub fn new(inner: TcpServer, middleware: M) -> Self { @@ -163,29 +215,96 @@ where pub fn inner_mut(&mut self) -> &mut TcpServer { &mut self.inner } + + pub fn middleware(&self) -> &M { + &self.middleware + } + + pub fn middleware_mut(&mut self) -> &mut M { + &mut self.middleware + } } -impl TcpServerLike for TcpServer { - type EventOut = TcpEvent; +impl Deref for TcpServerWithMiddleware +where + Req: BorshSerialize + BorshDeserialize + std::fmt::Debug + Send + TcpMessageLabel + 'static, + Res: BorshSerialize + BorshDeserialize + std::fmt::Debug + TcpMessageLabel, +{ + type Target = TcpServer; - async fn listen_next(&mut self) -> Option { - TcpServer::listen_next(self).await + fn deref(&self) -> &Self::Target { + &self.inner } +} - fn send( +impl DerefMut for TcpServerWithMiddleware +where + Req: BorshSerialize + BorshDeserialize + std::fmt::Debug + Send + TcpMessageLabel + 'static, + Res: BorshSerialize + BorshDeserialize + std::fmt::Debug + TcpMessageLabel, +{ + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.inner + } +} + +impl TcpServerWithMiddleware +where + Req: BorshSerialize + BorshDeserialize + std::fmt::Debug + Send + TcpMessageLabel + 'static, + Res: BorshSerialize + BorshDeserialize + std::fmt::Debug + TcpMessageLabel + Clone, + M: TcpServerMiddleware + QueuedSenderMiddleware, +{ + /// Enqueue a message for ordered, retrying delivery to a specific peer. + pub fn enqueue( &mut self, socket_addr: String, msg: Res, headers: TcpHeaders, ) -> anyhow::Result<()> { - TcpServer::send(self, socket_addr, msg, headers) + self.middleware.enqueue_to_peer(socket_addr, msg, headers); + Ok(()) + } + + /// Immediate send through the underlying TCP server without middleware queueing. + pub fn send_now( + &mut self, + socket_addr: String, + msg: Res, + headers: TcpHeaders, + ) -> anyhow::Result<()> { + self.inner.send(socket_addr, msg, headers) + } + + /// Queue a message for ordered, retrying delivery to a specific peer. + pub fn send( + &mut self, + socket_addr: String, + msg: Res, + headers: TcpHeaders, + ) -> anyhow::Result<()> { + self.enqueue(socket_addr, msg, headers) + } + + /// Mark a peer as a streaming subscriber. + pub fn register_streaming_peer(&mut self, socket_addr: String) { + self.middleware.register_streaming_peer(socket_addr); + } + + /// Queue a message to all registered streaming peers. + pub fn enqueue_to_streaming_peers(&mut self, msg: Res, headers: TcpHeaders) { + self.middleware.enqueue_to_streaming_peers(msg, headers); + } + + /// Backward-compatible alias for `enqueue_to_streaming_peers`. + pub fn send_to_streaming_peers(&mut self, msg: Res, headers: TcpHeaders) { + self.enqueue_to_streaming_peers(msg, headers) } } impl TcpServerLike for TcpServerWithMiddleware where + Req: BorshSerialize + BorshDeserialize + std::fmt::Debug + Send + TcpMessageLabel + 'static, + Res: BorshSerialize + BorshDeserialize + std::fmt::Debug + TcpMessageLabel + Clone, M: TcpServerMiddleware, - Res: Clone, { type EventOut = M::EventOut; @@ -223,15 +342,59 @@ where let headers_clone = headers.clone(); match self.inner.send(socket_addr.clone(), msg, headers) { Ok(()) => Ok(()), - Err(error) => self.middleware.on_send_error( - &mut self.inner, - SendErrorContext { + Err(error) => { + let ctx = SendErrorContext { socket_addr, msg: msg_clone, headers: headers_clone, error, - }, - ), + }; + match self.middleware.on_send_error(&mut self.inner, &ctx) { + SendErrorOutcome::Handled | SendErrorOutcome::RetryScheduled => Ok(()), + SendErrorOutcome::DropPeer => { + self.inner.drop_peer_stream(ctx.socket_addr.clone()); + Ok(()) + } + SendErrorOutcome::Unhandled(error) => Err(error), + } + } } } + + fn connected_clients(&self) -> Box + '_> { + Box::new(self.inner.connected_clients()) + } + + fn drop_peer_stream(&mut self, peer_ip: String) { + self.inner.drop_peer_stream(peer_ip) + } +} + +impl TcpServerLike for TcpServer +where + Req: BorshSerialize + BorshDeserialize + std::fmt::Debug + Send + TcpMessageLabel + 'static, + Res: BorshSerialize + BorshDeserialize + std::fmt::Debug + TcpMessageLabel, +{ + type EventOut = TcpEvent; + + async fn listen_next(&mut self) -> Option { + TcpServer::listen_next(self).await + } + + fn send( + &mut self, + socket_addr: String, + msg: Res, + headers: TcpHeaders, + ) -> anyhow::Result<()> { + TcpServer::send(self, socket_addr, msg, headers) + } + + fn connected_clients(&self) -> Box + '_> { + Box::new(TcpServer::connected_clients(self)) + } + + fn drop_peer_stream(&mut self, peer_ip: String) { + TcpServer::drop_peer_stream(self, peer_ip) + } } diff --git a/crates/hyli-net/src/tcp/p2p_server.rs b/crates/hyli-net/src/tcp/p2p_server.rs index ee22117bc..b2fb099ce 100644 --- a/crates/hyli-net/src/tcp/p2p_server.rs +++ b/crates/hyli-net/src/tcp/p2p_server.rs @@ -1536,7 +1536,7 @@ pub mod tests { .flat_map(|v| v.canals.values().map(|v2| v2.socket_addr.clone())) .collect::>() ), - HashSet::from_iter(p2p_server1.tcp_server.connected_clients()) + HashSet::from_iter(p2p_server1.tcp_server.connected_clients().cloned()) ); assert_eq!( HashSet::::from_iter( @@ -1546,7 +1546,7 @@ pub mod tests { .flat_map(|v| v.canals.values().map(|v2| v2.socket_addr.clone())) .collect::>() ), - HashSet::from_iter(p2p_server2.tcp_server.connected_clients()) + HashSet::from_iter(p2p_server2.tcp_server.connected_clients().cloned()) ); // Both peers should have each other's ValidatorPublicKey in their maps @@ -1810,9 +1810,9 @@ pub mod tests { assert_eq!(p2p_server1.peers.len(), 1); assert_eq!(p2p_server2.peers.len(), 1); - let connected = p2p_server1.tcp_server.connected_clients(); + let mut connected = p2p_server1.tcp_server.connected_clients(); assert_eq!(connected.len(), 1, "Expected a single client socket"); - let socket_addr = connected.first().cloned().unwrap(); + let socket_addr = connected.next().cloned().unwrap(); let send_errors = p2p_server1 diff --git a/crates/hyli-net/src/tcp/tcp_client.rs b/crates/hyli-net/src/tcp/tcp_client.rs index 452e56810..03ab406d8 100644 --- a/crates/hyli-net/src/tcp/tcp_client.rs +++ b/crates/hyli-net/src/tcp/tcp_client.rs @@ -217,14 +217,14 @@ mod tests { client.socket_addr }); - while server.connected_clients().is_empty() { + while server.connected_clients().len() == 0 { _ = tokio::time::timeout(Duration::from_millis(100), server.listen_next()).await; } let client_socket = client_socket.await?; assert_eq!(client_socket.port(), server_socket.port()); - let clients = server.connected_clients(); + let clients: Vec = server.connected_clients().cloned().collect(); assert_eq!(clients.len(), 1); assert_ne!(clients, vec![server_socket.to_string()]); diff --git a/crates/hyli-net/src/tcp/tcp_server.rs b/crates/hyli-net/src/tcp/tcp_server.rs index c0bcf4ef9..20c525baa 100644 --- a/crates/hyli-net/src/tcp/tcp_server.rs +++ b/crates/hyli-net/src/tcp/tcp_server.rs @@ -194,8 +194,10 @@ where } /// Adresses of currently connected clients (no health check) - pub fn connected_clients(&self) -> Vec { - self.sockets.keys().cloned().collect::>() + pub fn connected_clients( + &self, + ) -> impl Iterator { + self.sockets.keys() } pub fn connected(&self, socket_addr: &str) -> bool { @@ -705,7 +707,7 @@ pub mod tests { DataAvailabilityEvent::SignedBlock(Default::default()) ); - let client_socket_addr = server.connected_clients().first().unwrap().clone(); + let client_socket_addr = server.connected_clients().next().unwrap().clone(); server.ping(client_socket_addr)?; @@ -774,7 +776,7 @@ pub mod tests { .await?; _ = tokio::time::timeout(Duration::from_millis(200), server.listen_next()).await; - let client1_addr = server.connected_clients().clone().first().unwrap().clone(); + let client1_addr = server.connected_clients().next().unwrap().clone(); let mut client2 = DAClient::connect( "me2".to_string(), @@ -784,8 +786,7 @@ pub mod tests { _ = tokio::time::timeout(Duration::from_millis(200), server.listen_next()).await; let client2_addr = server .connected_clients() - .clone() - .into_iter() + .cloned() .rfind(|addr| addr != &client1_addr) .unwrap(); @@ -825,7 +826,7 @@ pub mod tests { ) .await?; _ = tokio::time::timeout(Duration::from_millis(200), server.listen_next()).await; - let client1_addr = server.connected_clients().first().unwrap().clone(); + let client1_addr = server.connected_clients().next().unwrap().clone(); let mut client2 = DAClient::connect( "me2".to_string(), @@ -835,8 +836,7 @@ pub mod tests { _ = tokio::time::timeout(Duration::from_millis(200), server.listen_next()).await; let client2_addr = server .connected_clients() - .clone() - .into_iter() + .cloned() .rfind(|addr| addr != &client1_addr) .unwrap(); @@ -968,7 +968,7 @@ pub mod tests { .await?; _ = tokio::time::timeout(Duration::from_millis(200), server.listen_next()).await; - let socket_addr = server.connected_clients().first().unwrap().clone(); + let socket_addr = server.connected_clients().next().unwrap().clone(); { let stored = server .sockets diff --git a/crates/hyli-net/tests/basic.rs b/crates/hyli-net/tests/basic.rs index 680b66c52..5eda0aae1 100644 --- a/crates/hyli-net/tests/basic.rs +++ b/crates/hyli-net/tests/basic.rs @@ -214,7 +214,8 @@ async fn setup_drop_host( // Peers map should match all_other_peers assert_eq!(all_other_peers.len(), p2p.peers.keys().len()); // All current peer sockets should be in tcp server sockets - let connected_tcp_clients = p2p.tcp_server.connected_clients().clone(); + let connected_tcp_clients: Vec = + p2p.tcp_server.connected_clients().cloned().collect(); assert!(p2p.peers.values().flat_map(|t| t.canals.values()).all(|v| connected_tcp_clients.contains(&v.socket_addr))); } } @@ -267,7 +268,8 @@ async fn setup_drop_client( // Peers map should match all_other_peers assert_eq!(all_other_peers.len(), p2p.peers.keys().len()); // All current peer sockets should be in tcp server sockets - let connected_tcp_clients = p2p.tcp_server.connected_clients().clone(); + let connected_tcp_clients: Vec = + p2p.tcp_server.connected_clients().cloned().collect(); assert!(p2p.peers.values().flat_map(|t| t.canals.values()).all(|v| connected_tcp_clients.contains(&v.socket_addr))); break Ok(()) @@ -413,7 +415,7 @@ async fn setup_decode_error_host(peer: String, peers: Vec) -> Result<(), } if armed && !sent_error && peer == "peer-1" { - if let Some(socket) = p2p.tcp_server.connected_clients().first().cloned() { + if let Some(socket) = p2p.tcp_server.connected_clients().next().cloned() { let errors = p2p .tcp_server .raw_send_parallel(vec![socket], vec![255], vec![], "raw") @@ -526,7 +528,7 @@ async fn setup_poisoned_socket_host( && start.elapsed() > Duration::from_millis(500) && p2p.peers.len() == all_other_peers.len() { - if let Some(socket) = p2p.tcp_server.connected_clients().first().cloned() { + if let Some(socket) = p2p.tcp_server.connected_clients().next().cloned() { let errors = p2p .tcp_server .raw_send_parallel(vec![socket], vec![255], vec![], "raw") diff --git a/src/data_availability.rs b/src/data_availability.rs index 1e447e43b..515e5cadb 100644 --- a/src/data_availability.rs +++ b/src/data_availability.rs @@ -2,13 +2,16 @@ // Pick one of the two implementations use hyli_modules::modules::data_availability::blocks_fjall::Blocks; -use hyli_modules::utils::da_codec::DataAvailabilityServer; +use hyli_modules::utils::da_codec::DataAvailabilityServer as RawDataAvailabilityServer; //use hyli_modules::modules::data_availability::blocks_memory::Blocks; use hyli_modules::modules::da_listener::{DaStreamPoll, SignedDaStream}; use hyli_modules::telemetry::{global_meter_or_panic, Counter, Gauge, KeyValue}; use hyli_modules::{bus::SharedMessageBus, modules::Module}; use hyli_modules::{log_error, module_bus_client, module_handle_messages}; -use hyli_net::tcp::TcpEvent; +use hyli_net::tcp::middleware::{ + DropOnError, EventPipeline, QueuedSendWithRetry, TcpInboundMessage, TcpServerLike, + TcpServerWithMiddleware, +}; use tokio::task::JoinHandle; use crate::{ @@ -23,14 +26,22 @@ use anyhow::{Context, Result}; use core::str; use rand::seq::IndexedRandom; use std::{ - collections::{BTreeSet, HashMap, VecDeque}, + collections::{BTreeSet, VecDeque}, time::Duration, }; -use tokio::task::JoinSet; use tracing::{debug, error, info, trace, warn}; use crate::model::SharedRunContext; +type DataAvailabilityServer = TcpServerWithMiddleware< + EventPipeline< + DropOnError, + QueuedSendWithRetry, + >, + DataAvailabilityRequest, + DataAvailabilityEvent, +>; + impl Module for DataAvailability { type Context = SharedRunContext; @@ -74,7 +85,6 @@ impl Module for DataAvailability { blocks, buffered_signed_blocks: BTreeSet::new(), catchupper: DaCatchupper::new(catchup_policy, ctx.config.da_max_frame_length), - peer_send_queues: HashMap::new(), }) } @@ -105,9 +115,6 @@ pub struct DataAvailability { buffered_signed_blocks: BTreeSet, catchupper: DaCatchupper, - - // Track blocks to send to each streaming peer (ensures ordering) - peer_send_queues: HashMap>, } /// Catchup configuration for the Data Availability module. @@ -483,23 +490,25 @@ impl DataAvailability { self.config.da_server_port ); - let mut server = DataAvailabilityServer::start_with_opts( + let inner_server = RawDataAvailabilityServer::start_with_opts( self.config.da_server_port, Some(self.config.da_max_frame_length), format!("DAServer-{}", self.config.id.clone()).as_str(), ) .await?; + let mut server = DataAvailabilityServer::new( + inner_server, + EventPipeline::new( + DropOnError, + QueuedSendWithRetry::new(10, Duration::from_millis(100)).max_per_tick(256), + ), + ); let (catchup_block_sender, mut catchup_block_receiver) = tokio::sync::mpsc::channel::(100); let mut first_hole_receiver = self.start_scanning_for_first_hole(); - // Used to send blocks to clients (indexers/peers) - // This is a JoinSet of tuples containing: - // - The peer IP address to send the blocks to - // - The number of retries for sending the blocks - let mut catchup_joinset: JoinSet<(String, usize)> = tokio::task::JoinSet::new(); let mut catchup_task_checker_ticker = tokio::time::interval(std::time::Duration::from_millis(5000)); let mut storage_metrics_ticker = tokio::time::interval(std::time::Duration::from_secs(30)); @@ -507,7 +516,7 @@ impl DataAvailability { module_handle_messages! { on_self self, listen evt => { - _ = log_error!(self.handle_mempool_event(evt, &mut server, &catchup_block_sender, &mut catchup_joinset).await, "Handling Mempool Event"); + _ = log_error!(self.handle_mempool_event(evt, &mut server, &catchup_block_sender).await, "Handling Mempool Event"); } listen evt => { @@ -517,7 +526,7 @@ impl DataAvailability { listen cmd => { if let GenesisEvent::GenesisBlock(signed_block) = cmd { debug!("🌱 Genesis block received with validators {:?}", signed_block.consensus_proposal.staking_actions.clone()); - _ = log_error!(self.handle_signed_block(signed_block, &mut server, &mut catchup_joinset).await.context("Handling Genesis block"), "Handling GenesisBlock Event"); + _ = log_error!(self.handle_signed_block(signed_block, &mut server).await.context("Handling Genesis block"), "Handling GenesisBlock Event"); } else { _ = log_error!( @@ -548,59 +557,33 @@ impl DataAvailability { } Some(streamed_block) = catchup_block_receiver.recv() => { - if let Some(height) = self.handle_signed_block(streamed_block, &mut server, &mut catchup_joinset).await { + if let Some(height) = self.handle_signed_block(streamed_block, &mut server).await { _ = log_error!(self.catchupper.manage_catchup(height, &catchup_block_sender), "Catchup transition after streamed block"); } } - Some(tcp_event) = server.listen_next() => { - match tcp_event { - TcpEvent::Message { socket_addr, data, .. } => { - match data { - DataAvailabilityRequest::StreamFromHeight(start_height) => { - _ = log_error!( - self.start_streaming_to_peer(start_height, &mut catchup_joinset, &socket_addr).await, - "Starting streaming to peer" - ); - } - DataAvailabilityRequest::BlockRequest(block_height) => { - _ = log_error!( - self.handle_block_request(block_height, &socket_addr, &mut server).await, - "Handling block request" - ); - } - } - } - TcpEvent::Closed { socket_addr } => { - server.drop_peer_stream(socket_addr.clone()); - self.peer_send_queues.remove(&socket_addr); + Some(TcpInboundMessage { socket_addr, data, .. }) = server.listen_next() => { + match data { + DataAvailabilityRequest::StreamFromHeight(start_height) => { + _ = log_error!( + self.start_streaming_to_peer( + start_height, + &mut server, + &socket_addr + ) + .await, + "Starting streaming to peer" + ); } - TcpEvent::Error { socket_addr, error } => { - warn!("TCP error from {}: {}. Dropping socket.", socket_addr, error); - server.drop_peer_stream(socket_addr.clone()); - self.peer_send_queues.remove(&socket_addr); + DataAvailabilityRequest::BlockRequest(block_height) => { + _ = log_error!( + self.handle_block_request(block_height, &socket_addr, &mut server).await, + "Handling block request" + ); } } } - // Send one block to a peer as part of "catchup", - // once we have sent all blocks the peer is presumably synchronised. - Some(Ok((peer_ip, retries))) = catchup_joinset.join_next() => { - - #[cfg(test)] - tokio::time::sleep(std::time::Duration::from_millis(100)).await; - - _ = log_error!( - self.handle_send_next_block_to_peer( - peer_ip.clone(), - retries, - &mut catchup_joinset, - &mut server - ).await, - "Send next block to peer" - ); - } - Some(hole) = first_hole_receiver.recv() => { info!("Setting backfill start height as {:?}", &hole); self.catchupper.backfill_start_height = hole; @@ -617,80 +600,6 @@ impl DataAvailability { Ok(()) } - async fn handle_send_next_block_to_peer( - &mut self, - peer_ip: String, - retries: usize, - catchup_joinset: &mut JoinSet<(String, usize)>, - server: &mut DataAvailabilityServer, - ) -> Result<()> { - if !server.connected(&peer_ip) { - debug!("Peer {} disconnected, removing from send queues", peer_ip); - self.peer_send_queues.remove(&peer_ip); - return Ok(()); - } - - if retries > 10 { - warn!( - "Failed to send block, too many retries for peer {}", - &peer_ip - ); - server.drop_peer_stream(peer_ip.clone()); - self.peer_send_queues.remove(&peer_ip); - return Ok(()); - } - - // Get next block from this peer's queue - let hash = match self.peer_send_queues.get_mut(&peer_ip) { - Some(queue) => match queue.pop_front() { - Some(h) => h, - None => { - // Queue is empty - peer is caught up and waiting for new blocks - // Keep them in the map but don't spawn a new task yet - debug!("Peer {} caught up, waiting for new blocks", peer_ip); - return Ok(()); - } - }, - None => { - debug!("Peer {} not in send queues", peer_ip); - return Ok(()); - } - }; - - debug!("📡 Sending block {} to peer {}", &hash, &peer_ip); - if let Ok(Some(signed_block)) = self.blocks.get(&hash) { - // Errors will be handled when sending new blocks, ignore here. - match server.send( - peer_ip.clone(), - DataAvailabilityEvent::SignedBlock(signed_block), - vec![], - ) { - Ok(()) => { - // Successfully sent, continue with next block - catchup_joinset.spawn(async move { (peer_ip, 0) }); - } - Err(_) => { - // Retry sending the same block (put it back at front of queue) - if let Some(queue) = self.peer_send_queues.get_mut(&peer_ip) { - queue.push_front(hash); - } - catchup_joinset.spawn(async move { - tokio::time::sleep(Duration::from_millis(100 * (retries as u64))).await; - (peer_ip, retries + 1) - }); - } - } - } else { - error!( - "Block {} not found in storage while sending to peer {}. Should not happen", - &hash, &peer_ip - ); - // Continue anyway with next block - catchup_joinset.spawn(async move { (peer_ip, 0) }); - } - Ok(()) - } - async fn handle_block_request( &mut self, block_height: BlockHeight, @@ -709,22 +618,18 @@ impl DataAvailability { "📦 Found block at height {}, sending to {}", block_height, socket_addr ); - // Send immediately - this is inserted next in the send queue if let Err(e) = server.send( socket_addr.to_string(), DataAvailabilityEvent::SignedBlock(block), vec![], ) { warn!( - "📦 Error while responding to block request at height {} for {}: {:#}. Dropping socket.", + "📦 Error while responding to block request at height {} for {}: {:#}.", block_height, socket_addr, e ); - server.drop_peer_stream(socket_addr.to_string()); - return Ok(()); } } Ok(None) => { - // Block not in storage - this is a gap error!( "📦 Block at height {} not found in storage, sending BlockNotFound to {}", block_height, socket_addr @@ -735,11 +640,9 @@ impl DataAvailability { vec![], ) { warn!( - "📦 Error while responding BlockNotFound at height {} for {}: {:#}. Dropping socket.", + "📦 Error while responding BlockNotFound at height {} for {}: {:#}.", block_height, socket_addr, e ); - server.drop_peer_stream(socket_addr.to_string()); - return Ok(()); } } Err(e) => { @@ -753,11 +656,9 @@ impl DataAvailability { vec![], ) { warn!( - "📦 Error while responding BlockNotFound at height {} for {}: {:#}. Dropping socket.", + "📦 Error while responding BlockNotFound at height {} for {}: {:#}.", block_height, socket_addr, e ); - server.drop_peer_stream(socket_addr.to_string()); - return Ok(()); } } } @@ -770,7 +671,6 @@ impl DataAvailability { evt: MempoolBlockEvent, tcp_server: &mut DataAvailabilityServer, sender: &tokio::sync::mpsc::Sender, - catchup_joinset: &mut JoinSet<(String, usize)>, ) -> Result<()> { match evt { MempoolBlockEvent::BuiltSignedBlock(signed_block) => { @@ -779,7 +679,7 @@ impl DataAvailability { signed_block.height() ); if let Some(height) = self - .handle_signed_block(signed_block, tcp_server, catchup_joinset) + .handle_signed_block(signed_block, tcp_server) .await { self.catchupper.manage_catchup(height, sender)?; @@ -802,11 +702,16 @@ impl DataAvailability { evt: MempoolStatusEvent, tcp_server: &mut DataAvailabilityServer, ) { - let errors = tcp_server.broadcast(DataAvailabilityEvent::MempoolStatusEvent(evt)); - + let errors = TcpServerLike::broadcast( + tcp_server, + DataAvailabilityEvent::MempoolStatusEvent(evt), + vec![], + ); for (peer, error) in errors { - warn!("Error while broadcasting mempool status event {:#}", error); - tcp_server.drop_peer_stream(peer.clone()); + warn!( + "Error while queueing mempool status event for {}: {:#}", + peer, error + ); } } @@ -815,7 +720,6 @@ impl DataAvailability { &mut self, block: SignedBlock, tcp_server: &mut DataAvailabilityServer, - catchup_joinset: &mut JoinSet<(String, usize)>, ) -> Option { let hash = block.hashed(); // if new block is already handled, ignore it @@ -858,13 +762,12 @@ impl DataAvailability { } else { // store block _ = log_error!( - self.add_processed_block(block.clone(), tcp_server, catchup_joinset) - .await, + self.add_processed_block(block.clone(), tcp_server).await, "Adding processed block" ); } - let highest_processed_height = self.pop_buffer(hash, tcp_server, catchup_joinset).await; + let highest_processed_height = self.pop_buffer(hash, tcp_server).await; _ = log_error!(self.blocks.persist(), "Persisting blocks"); let height = block.height(); @@ -877,7 +780,6 @@ impl DataAvailability { &mut self, mut last_block_hash: ConsensusProposalHash, tcp_server: &mut DataAvailabilityServer, - catchup_joinset: &mut JoinSet<(String, usize)>, ) -> Option { let mut res = None; @@ -901,7 +803,7 @@ impl DataAvailability { let height = first_buffered.height(); if self - .add_processed_block(first_buffered.clone(), tcp_server, catchup_joinset) + .add_processed_block(first_buffered.clone(), tcp_server) .await .is_ok() { @@ -948,36 +850,11 @@ impl DataAvailability { async fn add_processed_block( &mut self, block: SignedBlock, - _tcp_server: &mut DataAvailabilityServer, - catchup_joinset: &mut JoinSet<(String, usize)>, + tcp_server: &mut DataAvailabilityServer, ) -> anyhow::Result<()> { self.store_block(&block)?; - let block_hash = block.hashed(); - - // Add new block to all streaming peer queues to ensure ordering - // (instead of broadcasting which can cause out-of-order delivery) - for (peer, queue) in self.peer_send_queues.iter_mut() { - let was_empty = queue.is_empty(); - queue.push_back(block_hash.clone()); - - // If queue was empty (peer was caught up), restart their send task - if was_empty { - debug!( - "Restarting send task for caught-up peer {} with new block {}", - peer, block_hash - ); - let peer_clone = peer.clone(); - catchup_joinset.spawn(async move { (peer_clone, 0) }); - } else { - debug!( - "Appending block {} to queue for peer {} (queue size: {})", - block_hash, - peer, - queue.len() - ); - } - } + tcp_server.enqueue_to_streaming_peers(DataAvailabilityEvent::SignedBlock(block.clone()), vec![]); // Send the block to NodeState for processing _ = log_error!( @@ -993,7 +870,7 @@ impl DataAvailability { async fn start_streaming_to_peer( &mut self, start_height: BlockHeight, - catchup_joinset: &mut JoinSet<(String, usize)>, + server: &mut DataAvailabilityServer, peer_ip: &str, ) -> Result<()> { let range_start = std::time::Instant::now(); @@ -1019,13 +896,28 @@ impl DataAvailability { processed_block_hashes.len() ); - // Store queue for this peer - new blocks will be appended here - let peer_ip_string = peer_ip.to_string(); - self.peer_send_queues - .insert(peer_ip_string.clone(), processed_block_hashes); - - // Start the send task for this peer - catchup_joinset.spawn(async move { (peer_ip_string, 0) }); + let peer_ip = peer_ip.to_string(); + server.register_streaming_peer(peer_ip.clone()); + for hash in processed_block_hashes { + match self.blocks.get(&hash) { + Ok(Some(block)) => { + _ = server.enqueue( + peer_ip.clone(), + DataAvailabilityEvent::SignedBlock(block), + vec![], + ); + } + Ok(None) => { + warn!("Missing block {} while starting stream to {}", hash, peer_ip); + } + Err(e) => { + warn!( + "Error loading block {} while starting stream to {}: {:#}", + hash, peer_ip, e + ); + } + } + } Ok(()) } @@ -1050,10 +942,10 @@ pub mod tests { use hyli_modules::node_state::module::NodeStateBusClient; use hyli_modules::node_state::NodeState; use hyli_modules::utils::da_codec::DataAvailabilityClient; - use hyli_modules::utils::da_codec::DataAvailabilityServer; + use hyli_modules::utils::da_codec::DataAvailabilityServer as RawDataAvailabilityServer; + use hyli_net::tcp::middleware::{DropOnError, EventPipeline, QueuedSendWithRetry}; use hyli_net::tcp::TcpEvent; use staking::state::Staking; - use tokio::task::JoinSet; struct DataAvailabilityTestCtx { pub node_state_bus: NodeStateBusClient, @@ -1061,6 +953,17 @@ pub mod tests { pub node_state: NodeState, } + async fn make_da_server(port: u16, name: &str) -> super::DataAvailabilityServer { + let inner = RawDataAvailabilityServer::start(port, name).await.unwrap(); + super::DataAvailabilityServer::new( + inner, + EventPipeline::new( + DropOnError, + QueuedSendWithRetry::new(10, Duration::from_millis(100)).max_per_tick(256), + ), + ) + } + impl DataAvailabilityTestCtx { pub async fn new(shared_bus: crate::bus::SharedMessageBus) -> Self { let path = tempfile::tempdir().unwrap().keep(); @@ -1082,7 +985,6 @@ pub mod tests { blocks, buffered_signed_blocks: Default::default(), catchupper: Default::default(), - peer_send_queues: HashMap::new(), }; DataAvailabilityTestCtx { @@ -1095,12 +997,9 @@ pub mod tests { pub async fn handle_signed_block( &mut self, block: SignedBlock, - tcp_server: &mut DataAvailabilityServer, + tcp_server: &mut super::DataAvailabilityServer, ) { - let mut catchup_joinset: JoinSet<(String, usize)> = JoinSet::new(); - self.da - .handle_signed_block(block.clone(), tcp_server, &mut catchup_joinset) - .await; + self.da.handle_signed_block(block.clone(), tcp_server).await; let block_hash = block.hashed(); let Ok(full_block) = self.node_state.handle_signed_block(block) else { tracing::warn!("Error while handling signed block {}", block_hash); @@ -1133,9 +1032,7 @@ pub mod tests { let tmpdir = tempfile::tempdir().unwrap().keep(); let blocks = Blocks::new(&tmpdir).unwrap(); - let mut server = DataAvailabilityServer::start(7898, "DaServer") - .await - .unwrap(); + let mut server = make_da_server(7898, "DaServer").await; let bus = super::DABusClient::new_from_bus(crate::bus::SharedMessageBus::new()).await; let mut da = super::DataAvailability { @@ -1144,7 +1041,6 @@ pub mod tests { blocks, buffered_signed_blocks: Default::default(), catchupper: Default::default(), - peer_send_queues: HashMap::new(), }; let mut block = SignedBlock::default(); let mut blocks = vec![]; @@ -1154,18 +1050,15 @@ pub mod tests { block.consensus_proposal.slot = i; } blocks.reverse(); - let mut catchup_joinset: JoinSet<(String, usize)> = JoinSet::new(); for block in blocks { if block.height().0 == 0 { assert_eq!( - da.handle_signed_block(block, &mut server, &mut catchup_joinset) - .await, + da.handle_signed_block(block, &mut server).await, Some(BlockHeight(9998)) ); } else { assert_eq!( - da.handle_signed_block(block, &mut server, &mut catchup_joinset) - .await, + da.handle_signed_block(block, &mut server).await, None ); } @@ -1197,7 +1090,6 @@ pub mod tests { blocks, buffered_signed_blocks: Default::default(), catchupper: Default::default(), - peer_send_queues: HashMap::new(), }; let mut block = SignedBlock::default(); @@ -1299,7 +1191,7 @@ pub mod tests { #[test_log::test(tokio::test)] async fn test_da_many_clients_only_last_connected() { let port = find_available_port().await; - let mut server = DataAvailabilityServer::start(port, "DaServer") + let mut server = RawDataAvailabilityServer::start(port, "DaServer") .await .unwrap(); @@ -1372,13 +1264,14 @@ pub mod tests { let deadline = tokio::time::Instant::now() + Duration::from_secs(2); loop { - if server.connected_clients().len() == 1 && server.connected(&last_addr) { + let connected_clients: Vec = server.connected_clients().cloned().collect(); + if connected_clients.len() == 1 && server.connected(&last_addr) { break; } if tokio::time::Instant::now() >= deadline { panic!( "Expected only last client connected, got {:?}", - server.connected_clients() + connected_clients ); } if let Ok(Some( @@ -1397,9 +1290,7 @@ pub mod tests { let sender_global_bus = crate::bus::SharedMessageBus::new(); let mut block_sender = TestBusClient::new_from_bus(sender_global_bus.new_handle()).await; let mut da_sender = DataAvailabilityTestCtx::new(sender_global_bus).await; - let mut server = DataAvailabilityServer::start(7890, "DaServer") - .await - .unwrap(); + let mut server = make_da_server(7890, "DaServer").await; let receiver_global_bus = crate::bus::SharedMessageBus::new(); let mut da_receiver = DataAvailabilityTestCtx::new(receiver_global_bus).await; @@ -1544,9 +1435,7 @@ pub mod tests { let sender_global_bus = crate::bus::SharedMessageBus::new(); let mut block_sender = TestBusClient::new_from_bus(sender_global_bus.new_handle()).await; let mut da_sender = DataAvailabilityTestCtx::new(sender_global_bus).await; - let mut server = DataAvailabilityServer::start(7891, "DaServer") - .await - .unwrap(); + let mut server = make_da_server(7891, "DaServer").await; let receiver_global_bus = crate::bus::SharedMessageBus::new(); let mut da_receiver = DataAvailabilityTestCtx::new(receiver_global_bus).await; @@ -1680,7 +1569,6 @@ pub mod tests { blocks: blocks_storage, buffered_signed_blocks: Default::default(), catchupper: Default::default(), - peer_send_queues: HashMap::new(), }; // Start DA server From 909d555a98783b7acee99a018514fde90e1e2305 Mon Sep 17 00:00:00 2001 From: Alexandre Careil Date: Thu, 12 Feb 2026 17:51:54 +0100 Subject: [PATCH 03/18] Simplify --- crates/hyli-net/src/tcp/middleware/mod.rs | 51 ----------------------- src/data_availability.rs | 6 +-- 2 files changed, 3 insertions(+), 54 deletions(-) diff --git a/crates/hyli-net/src/tcp/middleware/mod.rs b/crates/hyli-net/src/tcp/middleware/mod.rs index 6a0b373c8..e33579a8d 100644 --- a/crates/hyli-net/src/tcp/middleware/mod.rs +++ b/crates/hyli-net/src/tcp/middleware/mod.rs @@ -81,8 +81,6 @@ //! # } //! ``` -use std::ops::{Deref, DerefMut}; - use tokio::time::Instant; use borsh::{BorshDeserialize, BorshSerialize}; @@ -202,49 +200,10 @@ impl TcpServerWithMiddleware where Req: BorshSerialize + BorshDeserialize + std::fmt::Debug + Send + TcpMessageLabel + 'static, Res: BorshSerialize + BorshDeserialize + std::fmt::Debug + TcpMessageLabel, - M: TcpServerMiddleware, { pub fn new(inner: TcpServer, middleware: M) -> Self { Self { inner, middleware } } - - pub fn inner(&self) -> &TcpServer { - &self.inner - } - - pub fn inner_mut(&mut self) -> &mut TcpServer { - &mut self.inner - } - - pub fn middleware(&self) -> &M { - &self.middleware - } - - pub fn middleware_mut(&mut self) -> &mut M { - &mut self.middleware - } -} - -impl Deref for TcpServerWithMiddleware -where - Req: BorshSerialize + BorshDeserialize + std::fmt::Debug + Send + TcpMessageLabel + 'static, - Res: BorshSerialize + BorshDeserialize + std::fmt::Debug + TcpMessageLabel, -{ - type Target = TcpServer; - - fn deref(&self) -> &Self::Target { - &self.inner - } -} - -impl DerefMut for TcpServerWithMiddleware -where - Req: BorshSerialize + BorshDeserialize + std::fmt::Debug + Send + TcpMessageLabel + 'static, - Res: BorshSerialize + BorshDeserialize + std::fmt::Debug + TcpMessageLabel, -{ - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.inner - } } impl TcpServerWithMiddleware @@ -274,16 +233,6 @@ where self.inner.send(socket_addr, msg, headers) } - /// Queue a message for ordered, retrying delivery to a specific peer. - pub fn send( - &mut self, - socket_addr: String, - msg: Res, - headers: TcpHeaders, - ) -> anyhow::Result<()> { - self.enqueue(socket_addr, msg, headers) - } - /// Mark a peer as a streaming subscriber. pub fn register_streaming_peer(&mut self, socket_addr: String) { self.middleware.register_streaming_peer(socket_addr); diff --git a/src/data_availability.rs b/src/data_availability.rs index 515e5cadb..ead23c242 100644 --- a/src/data_availability.rs +++ b/src/data_availability.rs @@ -618,7 +618,7 @@ impl DataAvailability { "📦 Found block at height {}, sending to {}", block_height, socket_addr ); - if let Err(e) = server.send( + if let Err(e) = server.enqueue( socket_addr.to_string(), DataAvailabilityEvent::SignedBlock(block), vec![], @@ -634,7 +634,7 @@ impl DataAvailability { "📦 Block at height {} not found in storage, sending BlockNotFound to {}", block_height, socket_addr ); - if let Err(e) = server.send( + if let Err(e) = server.enqueue( socket_addr.to_string(), DataAvailabilityEvent::BlockNotFound(block_height), vec![], @@ -650,7 +650,7 @@ impl DataAvailability { "📦 Error retrieving block at height {}: {:#}", block_height, e ); - if let Err(e) = server.send( + if let Err(e) = server.enqueue( socket_addr.to_string(), DataAvailabilityEvent::BlockNotFound(block_height), vec![], From 5212e37781005061b0e0b51b75deb7d96fbd1271 Mon Sep 17 00:00:00 2001 From: Alexandre Careil Date: Fri, 13 Feb 2026 17:20:38 +0100 Subject: [PATCH 04/18] refine api --- crates/hyli-net/src/tcp.rs | 2 +- crates/hyli-net/src/tcp/middleware/impls.rs | 132 ++++++------ crates/hyli-net/src/tcp/middleware/mod.rs | 217 +++++++++++++++----- crates/hyli-net/src/tcp/tcp_server.rs | 47 ++++- src/data_availability.rs | 52 +++-- 5 files changed, 304 insertions(+), 146 deletions(-) diff --git a/crates/hyli-net/src/tcp.rs b/crates/hyli-net/src/tcp.rs index bfcd53076..b208dd0c4 100644 --- a/crates/hyli-net/src/tcp.rs +++ b/crates/hyli-net/src/tcp.rs @@ -1,7 +1,7 @@ #[cfg(feature = "turmoil")] pub mod intercept; -pub mod p2p_server; pub mod middleware; +pub mod p2p_server; pub mod tcp_client; pub mod tcp_server; diff --git a/crates/hyli-net/src/tcp/middleware/impls.rs b/crates/hyli-net/src/tcp/middleware/impls.rs index 39b6a1043..dd28ffbd7 100644 --- a/crates/hyli-net/src/tcp/middleware/impls.rs +++ b/crates/hyli-net/src/tcp/middleware/impls.rs @@ -5,9 +5,9 @@ use std::time::Duration; use borsh::{BorshDeserialize, BorshSerialize}; use tokio::time::Instant; -use crate::tcp::{tcp_server::TcpServer, TcpEvent, TcpHeaders, TcpMessageLabel}; +use crate::tcp::{TcpEvent, TcpHeaders, TcpMessageLabel}; -use super::{SendErrorContext, SendErrorOutcome, TcpServerMiddleware}; +use super::{SendErrorContext, SendErrorOutcome, TcpServerLike, TcpServerMiddleware}; #[derive(Default)] pub struct DropOnError; @@ -19,11 +19,10 @@ where { type EventOut = TcpEvent; - fn on_event( - &mut self, - server: &mut TcpServer, - event: TcpEvent, - ) -> Option { + fn on_event(&mut self, server: &mut S, event: TcpEvent) -> Option + where + S: TcpServerLike>, + { match &event { TcpEvent::Error { socket_addr, .. } | TcpEvent::Closed { socket_addr } => { server.drop_peer_stream(socket_addr.clone()); @@ -67,27 +66,28 @@ where { type EventOut = B::EventOut; - fn on_event( - &mut self, - server: &mut TcpServer, - event: TcpEvent, - ) -> Option { + fn on_event(&mut self, server: &mut S, event: TcpEvent) -> Option + where + S: TcpServerLike>, + { let event = self.first.on_event(server, event)?; self.second.on_event(server, event) } - fn on_send_error( - &mut self, - server: &mut TcpServer, - ctx: &SendErrorContext, - ) -> SendErrorOutcome { + fn on_send_error(&mut self, server: &mut S, ctx: &SendErrorContext) -> SendErrorOutcome + where + S: TcpServerLike>, + { match self.first.on_send_error(server, ctx) { SendErrorOutcome::Unhandled(_) => self.second.on_send_error(server, ctx), outcome => outcome, } } - fn on_tick(&mut self, server: &mut TcpServer) { + fn on_tick(&mut self, server: &mut S) + where + S: TcpServerLike>, + { self.first.on_tick(server); self.second.on_tick(server); } @@ -113,11 +113,10 @@ where { type EventOut = Req; - fn on_event( - &mut self, - _server: &mut TcpServer, - event: TcpEvent, - ) -> Option { + fn on_event(&mut self, _server: &mut S, event: TcpEvent) -> Option + where + S: TcpServerLike>, + { match event { TcpEvent::Message { data, .. } => Some(data), TcpEvent::Closed { .. } | TcpEvent::Error { .. } => None, @@ -135,11 +134,10 @@ where { type EventOut = TcpInboundMessage; - fn on_event( - &mut self, - _server: &mut TcpServer, - event: TcpEvent, - ) -> Option { + fn on_event(&mut self, _server: &mut S, event: TcpEvent) -> Option + where + S: TcpServerLike>, + { match event { TcpEvent::Message { socket_addr, @@ -274,11 +272,10 @@ where { type EventOut = TcpInboundMessage; - fn on_event( - &mut self, - _server: &mut TcpServer, - event: TcpEvent, - ) -> Option { + fn on_event(&mut self, _server: &mut S, event: TcpEvent) -> Option + where + S: TcpServerLike>, + { match event { TcpEvent::Message { socket_addr, @@ -300,20 +297,26 @@ where } } - fn on_send_error( - &mut self, - server: &mut TcpServer, - ctx: &SendErrorContext, - ) -> SendErrorOutcome { + fn on_send_error(&mut self, server: &mut S, ctx: &SendErrorContext) -> SendErrorOutcome + where + S: TcpServerLike>, + { if !server.connected(&ctx.socket_addr) { self.unregister_streaming_peer(&ctx.socket_addr); return SendErrorOutcome::Unhandled(anyhow::anyhow!(ctx.error.to_string())); } - self.enqueue_to_peer(ctx.socket_addr.clone(), ctx.msg.clone(), ctx.headers.clone()); + self.enqueue_to_peer( + ctx.socket_addr.clone(), + ctx.msg.clone(), + ctx.headers.clone(), + ); SendErrorOutcome::RetryScheduled } - fn on_tick(&mut self, server: &mut TcpServer) { + fn on_tick(&mut self, server: &mut S) + where + S: TcpServerLike>, + { if self.queues.is_empty() { return; } @@ -354,8 +357,7 @@ where server.drop_peer_stream(peer.clone()); self.unregister_streaming_peer(&peer); } else { - front.next_attempt_at = - now + self.base_delay.mul_f64(front.retries as f64); + front.next_attempt_at = now + self.base_delay.mul_f64(front.retries as f64); } } } @@ -412,19 +414,17 @@ where { type EventOut = TcpEvent; - fn on_event( - &mut self, - _server: &mut TcpServer, - event: TcpEvent, - ) -> Option { + fn on_event(&mut self, _server: &mut S, event: TcpEvent) -> Option + where + S: TcpServerLike>, + { Some(event) } - fn on_send_error( - &mut self, - server: &mut TcpServer, - ctx: &SendErrorContext, - ) -> SendErrorOutcome { + fn on_send_error(&mut self, server: &mut S, ctx: &SendErrorContext) -> SendErrorOutcome + where + S: TcpServerLike>, + { if !server.connected(&ctx.socket_addr) { return SendErrorOutcome::Unhandled(anyhow::anyhow!(ctx.error.to_string())); } @@ -438,7 +438,10 @@ where SendErrorOutcome::RetryScheduled } - fn on_tick(&mut self, server: &mut TcpServer) { + fn on_tick(&mut self, server: &mut S) + where + S: TcpServerLike>, + { if self.queue.is_empty() { return; } @@ -516,23 +519,24 @@ where { type EventOut = TcpEvent; - fn on_event( - &mut self, - server: &mut TcpServer, - event: TcpEvent, - ) -> Option { + fn on_event(&mut self, server: &mut S, event: TcpEvent) -> Option + where + S: TcpServerLike>, + { self.drop_on_error.on_event(server, event) } - fn on_send_error( - &mut self, - server: &mut TcpServer, - ctx: &SendErrorContext, - ) -> SendErrorOutcome { + fn on_send_error(&mut self, server: &mut S, ctx: &SendErrorContext) -> SendErrorOutcome + where + S: TcpServerLike>, + { self.retrying_send.on_send_error(server, ctx) } - fn on_tick(&mut self, server: &mut TcpServer) { + fn on_tick(&mut self, server: &mut S) + where + S: TcpServerLike>, + { self.retrying_send.on_tick(server) } diff --git a/crates/hyli-net/src/tcp/middleware/mod.rs b/crates/hyli-net/src/tcp/middleware/mod.rs index e33579a8d..fd7b7c150 100644 --- a/crates/hyli-net/src/tcp/middleware/mod.rs +++ b/crates/hyli-net/src/tcp/middleware/mod.rs @@ -81,6 +81,7 @@ //! # } //! ``` +use std::marker::PhantomData; use tokio::time::Instant; use borsh::{BorshDeserialize, BorshSerialize}; @@ -94,6 +95,69 @@ pub use impls::{ QueuedSendWithRetry, QueuedSenderMiddleware, RetryingSend, TcpInboundMessage, }; +pub trait Layer { + type Service; + fn layer(self, inner: S) -> Self::Service; +} + +pub struct MiddlewareLayer { + middleware: M, +} + +impl MiddlewareLayer { + pub fn new(middleware: M) -> Self { + Self { middleware } + } +} + +pub fn middleware_layer(middleware: M) -> MiddlewareLayer { + MiddlewareLayer::new(middleware) +} + +/// Compose an arbitrary number of TCP middlewares into nested [`EventPipeline`]s. +/// +/// This is compile-time composition (no dynamic dispatch/runtime middleware list). +/// +/// # Example +/// ```ignore +/// let middleware = hyli_net::compose_tcp_middleware!( +/// DropOnError, +/// RetryingSend::new(10, Duration::from_millis(100)), +/// MessageWithMeta, +/// ); +/// ``` +#[macro_export] +macro_rules! compose_tcp_middleware { + ($single:expr $(,)?) => { + $single + }; + ($first:expr, $($rest:expr),+ $(,)?) => { + $crate::tcp::middleware::EventPipeline::new( + $first, + $crate::compose_tcp_middleware!($($rest),+), + ) + }; +} + +/// Compose middleware types into a nested [`EventPipeline`] type. +/// +/// # Example +/// ```ignore +/// type MyMiddleware = hyli_net::compose_tcp_middleware_type!(DropOnError, MessageOnly); +/// ``` +#[macro_export] +macro_rules! compose_tcp_middleware_type { + ($single:ty $(,)?) => { + $single + }; + ($first:ty, $($rest:ty),+ $(,)?) => { + $crate::tcp::middleware::EventPipeline< + $first, + $crate::compose_tcp_middleware_type!($($rest),+), + > + }; +} + pub struct SendErrorContext { pub socket_addr: String, pub msg: Res, @@ -121,25 +185,26 @@ where /// Transform or filter inbound events before they are exposed to callers. /// Returning `None` will cause the wrapper to keep listening. - fn on_event( - &mut self, - _server: &mut TcpServer, - event: TcpEvent, - ) -> Option; + fn on_event(&mut self, _server: &mut S, event: TcpEvent) -> Option + where + S: TcpServerLike>; /// Handle outbound send errors. The default behavior is to surface the error. /// Implementations can enqueue retries or drop peers. - fn on_send_error( - &mut self, - _server: &mut TcpServer, - ctx: &SendErrorContext, - ) -> SendErrorOutcome { + fn on_send_error(&mut self, _server: &mut S, ctx: &SendErrorContext) -> SendErrorOutcome + where + S: TcpServerLike>, + { SendErrorOutcome::Unhandled(anyhow::anyhow!(ctx.error.to_string())) } /// Called on each `listen_next()` iteration before waiting for events. /// Use this to drive retry queues or housekeeping. - fn on_tick(&mut self, _server: &mut TcpServer) {} + fn on_tick(&mut self, _server: &mut S) + where + S: TcpServerLike>, + { + } /// Optional wakeup time for the next middleware action. If present, the /// wrapper will `select!` between the next event and this deadline. @@ -151,22 +216,26 @@ where /// Common interface for `TcpServer` and middleware wrappers. pub trait TcpServerLike { type EventOut; + type ConnectedClients<'a>: Iterator + where + Self: 'a; /// Receive the next inbound event (or mapped output if wrapped). async fn listen_next(&mut self) -> Option; /// Send a response to a peer. - fn send( - &mut self, - socket_addr: String, - msg: Res, - headers: TcpHeaders, - ) -> anyhow::Result<()>; + fn send(&mut self, socket_addr: String, msg: Res, headers: TcpHeaders) -> anyhow::Result<()>; + /// Send using borrowed payload/headers to avoid cloning on the success path. + fn send_ref(&mut self, socket_addr: &str, msg: &Res, headers: &TcpHeaders) -> anyhow::Result<()> + where + Res: Clone, + { + self.send(socket_addr.to_string(), msg.clone(), headers.clone()) + } /// Return the currently connected peer socket addresses. - fn connected_clients(&self) -> Box + '_>; + fn connected_clients(&self) -> Self::ConnectedClients<'_>; /// Check whether a peer socket is currently connected. fn connected(&self, socket_addr: &str) -> bool { - self.connected_clients() - .any(|addr| addr == socket_addr) + self.connected_clients().any(|addr| addr == socket_addr) } /// Drop and disconnect a peer socket. fn drop_peer_stream(&mut self, peer_ip: String); @@ -187,29 +256,55 @@ pub trait TcpServerLike { } } -pub struct TcpServerWithMiddleware +/// Tower-style layering helper for TCP servers and already-layered services. +/// +/// # Example +/// ```ignore +/// let server = TcpServer::::start(0, "Example").await?; +/// let mut server = server +/// .layer(middleware_layer(DropOnError)) +/// .layer(middleware_layer(RetryingSend::new(10, Duration::from_millis(100)))); +/// ``` +pub trait TcpServerExt: TcpServerLike + Sized { + fn layer(self, layer: L) -> L::Service + where + L: Layer, + { + layer.layer(self) + } +} + +impl TcpServerExt for T where T: TcpServerLike + Sized {} + +pub struct TcpServerWithMiddleware> where Req: BorshSerialize + BorshDeserialize + std::fmt::Debug + Send + TcpMessageLabel + 'static, Res: BorshSerialize + BorshDeserialize + std::fmt::Debug + TcpMessageLabel, { - inner: TcpServer, + inner: S, middleware: M, + _marker: PhantomData<(Req, Res)>, } -impl TcpServerWithMiddleware +impl TcpServerWithMiddleware where Req: BorshSerialize + BorshDeserialize + std::fmt::Debug + Send + TcpMessageLabel + 'static, Res: BorshSerialize + BorshDeserialize + std::fmt::Debug + TcpMessageLabel, { - pub fn new(inner: TcpServer, middleware: M) -> Self { - Self { inner, middleware } + pub fn new(inner: S, middleware: M) -> Self { + Self { + inner, + middleware, + _marker: PhantomData, + } } } -impl TcpServerWithMiddleware +impl TcpServerWithMiddleware where Req: BorshSerialize + BorshDeserialize + std::fmt::Debug + Send + TcpMessageLabel + 'static, Res: BorshSerialize + BorshDeserialize + std::fmt::Debug + TcpMessageLabel + Clone, + S: TcpServerLike>, M: TcpServerMiddleware + QueuedSenderMiddleware, { /// Enqueue a message for ordered, retrying delivery to a specific peer. @@ -249,13 +344,19 @@ where } } -impl TcpServerLike for TcpServerWithMiddleware +impl TcpServerLike for TcpServerWithMiddleware where Req: BorshSerialize + BorshDeserialize + std::fmt::Debug + Send + TcpMessageLabel + 'static, Res: BorshSerialize + BorshDeserialize + std::fmt::Debug + TcpMessageLabel + Clone, + S: TcpServerLike>, M: TcpServerMiddleware, { type EventOut = M::EventOut; + type ConnectedClients<'a> + = S::ConnectedClients<'a> + where + Self: 'a, + S: 'a; async fn listen_next(&mut self) -> Option { loop { @@ -281,21 +382,14 @@ where } } - fn send( - &mut self, - socket_addr: String, - msg: Res, - headers: TcpHeaders, - ) -> anyhow::Result<()> { - let msg_clone = msg.clone(); - let headers_clone = headers.clone(); - match self.inner.send(socket_addr.clone(), msg, headers) { + fn send(&mut self, socket_addr: String, msg: Res, headers: TcpHeaders) -> anyhow::Result<()> { + match self.inner.send_ref(&socket_addr, &msg, &headers) { Ok(()) => Ok(()), Err(error) => { let ctx = SendErrorContext { socket_addr, - msg: msg_clone, - headers: headers_clone, + msg, + headers, error, }; match self.middleware.on_send_error(&mut self.inner, &ctx) { @@ -310,8 +404,15 @@ where } } - fn connected_clients(&self) -> Box + '_> { - Box::new(self.inner.connected_clients()) + fn send_ref(&mut self, socket_addr: &str, msg: &Res, headers: &TcpHeaders) -> anyhow::Result<()> + where + Res: Clone, + { + self.inner.send_ref(socket_addr, msg, headers) + } + + fn connected_clients(&self) -> Self::ConnectedClients<'_> { + self.inner.connected_clients() } fn drop_peer_stream(&mut self, peer_ip: String) { @@ -319,28 +420,48 @@ where } } +impl Layer for MiddlewareLayer +where + Req: BorshSerialize + BorshDeserialize + std::fmt::Debug + Send + TcpMessageLabel + 'static, + Res: BorshSerialize + BorshDeserialize + std::fmt::Debug + TcpMessageLabel + Clone, + S: TcpServerLike>, + M: TcpServerMiddleware, +{ + type Service = TcpServerWithMiddleware; + + fn layer(self, inner: S) -> Self::Service { + TcpServerWithMiddleware::new(inner, self.middleware) + } +} + impl TcpServerLike for TcpServer where Req: BorshSerialize + BorshDeserialize + std::fmt::Debug + Send + TcpMessageLabel + 'static, Res: BorshSerialize + BorshDeserialize + std::fmt::Debug + TcpMessageLabel, { type EventOut = TcpEvent; + type ConnectedClients<'a> + = crate::tcp::tcp_server::ConnectedClients<'a> + where + Self: 'a; async fn listen_next(&mut self) -> Option { TcpServer::listen_next(self).await } - fn send( - &mut self, - socket_addr: String, - msg: Res, - headers: TcpHeaders, - ) -> anyhow::Result<()> { + fn send(&mut self, socket_addr: String, msg: Res, headers: TcpHeaders) -> anyhow::Result<()> { TcpServer::send(self, socket_addr, msg, headers) } - fn connected_clients(&self) -> Box + '_> { - Box::new(TcpServer::connected_clients(self)) + fn send_ref(&mut self, socket_addr: &str, msg: &Res, headers: &TcpHeaders) -> anyhow::Result<()> + where + Res: Clone, + { + TcpServer::send_ref(self, socket_addr, msg, headers) + } + + fn connected_clients(&self) -> Self::ConnectedClients<'_> { + TcpServer::connected_clients(self) } fn drop_peer_stream(&mut self, peer_ip: String) { diff --git a/crates/hyli-net/src/tcp/tcp_server.rs b/crates/hyli-net/src/tcp/tcp_server.rs index 20c525baa..f085ea076 100644 --- a/crates/hyli-net/src/tcp/tcp_server.rs +++ b/crates/hyli-net/src/tcp/tcp_server.rs @@ -33,6 +33,16 @@ use super::{tcp_client::TcpClient, SocketStream, TcpEvent}; type TcpSender = SplitSink; type TcpReceiver = SplitStream; +pub struct ConnectedClients<'a>(std::collections::hash_map::Keys<'a, String, SocketStream>); + +impl<'a> Iterator for ConnectedClients<'a> { + type Item = &'a String; + + fn next(&mut self) -> Option { + self.0.next() + } +} + // Best-effort enqueue into the main TcpServer event loop. If the queue is full, log once and apply // backpressure by awaiting. If the queue is closed, do whatever the caller decides (typically // break out of the task loop). @@ -194,10 +204,8 @@ where } /// Adresses of currently connected clients (no health check) - pub fn connected_clients( - &self, - ) -> impl Iterator { - self.sockets.keys() + pub fn connected_clients(&self) -> ConnectedClients<'_> { + ConnectedClients(self.sockets.keys()) } pub fn connected(&self, socket_addr: &str) -> bool { @@ -327,6 +335,37 @@ where Ok(()) } + pub fn send_ref( + &mut self, + socket_addr: &str, + msg: &Res, + headers: &TcpHeaders, + ) -> anyhow::Result<()> { + debug!(pool = %self.pool_name, "Sending msg {:?} to {}", msg, socket_addr); + let message_label = msg.message_label(); + let stream = self + .sockets + .get_mut(socket_addr) + .context(format!("Retrieving client {socket_addr}"))?; + + let binary_data = to_tcp_message_with_headers(msg, headers.clone())?; + stream + .sender + .try_send(TcpOutboundMessage { + message: binary_data, + message_label, + }) + .map_err(|e| { + anyhow::anyhow!( + "Outbound TCP channel full/closed while sending msg to client {}: {}", + socket_addr, + e + ) + })?; + self.metrics.event_loop_message_sent(message_label); + Ok(()) + } + pub fn ping(&mut self, socket_addr: String) -> anyhow::Result<()> { let stream = self .sockets diff --git a/src/data_availability.rs b/src/data_availability.rs index ead23c242..fb7a983af 100644 --- a/src/data_availability.rs +++ b/src/data_availability.rs @@ -9,8 +9,8 @@ use hyli_modules::telemetry::{global_meter_or_panic, Counter, Gauge, KeyValue}; use hyli_modules::{bus::SharedMessageBus, modules::Module}; use hyli_modules::{log_error, module_bus_client, module_handle_messages}; use hyli_net::tcp::middleware::{ - DropOnError, EventPipeline, QueuedSendWithRetry, TcpInboundMessage, TcpServerLike, - TcpServerWithMiddleware, + middleware_layer, DropOnError, QueuedSendWithRetry, TcpInboundMessage, TcpServerExt, + TcpServerLike, TcpServerWithMiddleware, }; use tokio::task::JoinHandle; @@ -34,12 +34,10 @@ use tracing::{debug, error, info, trace, warn}; use crate::model::SharedRunContext; type DataAvailabilityServer = TcpServerWithMiddleware< - EventPipeline< - DropOnError, - QueuedSendWithRetry, - >, + QueuedSendWithRetry, DataAvailabilityRequest, DataAvailabilityEvent, + TcpServerWithMiddleware, >; impl Module for DataAvailability { @@ -496,13 +494,11 @@ impl DataAvailability { format!("DAServer-{}", self.config.id.clone()).as_str(), ) .await?; - let mut server = DataAvailabilityServer::new( - inner_server, - EventPipeline::new( - DropOnError, + let mut server = inner_server + .layer(middleware_layer(DropOnError)) + .layer(middleware_layer( QueuedSendWithRetry::new(10, Duration::from_millis(100)).max_per_tick(256), - ), - ); + )); let (catchup_block_sender, mut catchup_block_receiver) = tokio::sync::mpsc::channel::(100); @@ -678,10 +674,7 @@ impl DataAvailability { "📦 Received built block (height {}) from Mempool", signed_block.height() ); - if let Some(height) = self - .handle_signed_block(signed_block, tcp_server) - .await - { + if let Some(height) = self.handle_signed_block(signed_block, tcp_server).await { self.catchupper.manage_catchup(height, sender)?; } } @@ -854,7 +847,8 @@ impl DataAvailability { ) -> anyhow::Result<()> { self.store_block(&block)?; - tcp_server.enqueue_to_streaming_peers(DataAvailabilityEvent::SignedBlock(block.clone()), vec![]); + tcp_server + .enqueue_to_streaming_peers(DataAvailabilityEvent::SignedBlock(block.clone()), vec![]); // Send the block to NodeState for processing _ = log_error!( @@ -908,7 +902,10 @@ impl DataAvailability { ); } Ok(None) => { - warn!("Missing block {} while starting stream to {}", hash, peer_ip); + warn!( + "Missing block {} while starting stream to {}", + hash, peer_ip + ); } Err(e) => { warn!( @@ -943,7 +940,9 @@ pub mod tests { use hyli_modules::node_state::NodeState; use hyli_modules::utils::da_codec::DataAvailabilityClient; use hyli_modules::utils::da_codec::DataAvailabilityServer as RawDataAvailabilityServer; - use hyli_net::tcp::middleware::{DropOnError, EventPipeline, QueuedSendWithRetry}; + use hyli_net::tcp::middleware::{ + middleware_layer, DropOnError, QueuedSendWithRetry, TcpServerExt, + }; use hyli_net::tcp::TcpEvent; use staking::state::Staking; @@ -955,13 +954,11 @@ pub mod tests { async fn make_da_server(port: u16, name: &str) -> super::DataAvailabilityServer { let inner = RawDataAvailabilityServer::start(port, name).await.unwrap(); - super::DataAvailabilityServer::new( - inner, - EventPipeline::new( - DropOnError, + inner + .layer(middleware_layer(DropOnError)) + .layer(middleware_layer( QueuedSendWithRetry::new(10, Duration::from_millis(100)).max_per_tick(256), - ), - ) + )) } impl DataAvailabilityTestCtx { @@ -1057,10 +1054,7 @@ pub mod tests { Some(BlockHeight(9998)) ); } else { - assert_eq!( - da.handle_signed_block(block, &mut server).await, - None - ); + assert_eq!(da.handle_signed_block(block, &mut server).await, None); } } } From 942f934bdf98c84eb4d04ee5f6c1b85997affd0c Mon Sep 17 00:00:00 2001 From: Alexandre Careil Date: Fri, 13 Feb 2026 17:24:58 +0100 Subject: [PATCH 05/18] Simplify TCP server middleware --- crates/hyli-net/src/tcp/middleware/impls.rs | 84 --------------------- crates/hyli-net/src/tcp/middleware/mod.rs | 48 +----------- 2 files changed, 2 insertions(+), 130 deletions(-) diff --git a/crates/hyli-net/src/tcp/middleware/impls.rs b/crates/hyli-net/src/tcp/middleware/impls.rs index dd28ffbd7..6e2cb1d6a 100644 --- a/crates/hyli-net/src/tcp/middleware/impls.rs +++ b/crates/hyli-net/src/tcp/middleware/impls.rs @@ -33,70 +33,6 @@ where } } -fn min_wakeup(lhs: Option, rhs: Option) -> Option { - match (lhs, rhs) { - (Some(a), Some(b)) => Some(a.min(b)), - (Some(a), None) => Some(a), - (None, Some(b)) => Some(b), - (None, None) => None, - } -} - -pub struct EventPipeline { - first: A, - second: B, -} - -impl EventPipeline { - pub fn new(first: A, second: B) -> Self { - Self { first, second } - } - - pub(crate) fn second_mut(&mut self) -> &mut B { - &mut self.second - } -} - -impl TcpServerMiddleware for EventPipeline -where - Req: BorshSerialize + BorshDeserialize + std::fmt::Debug + Send + TcpMessageLabel + 'static, - Res: BorshSerialize + BorshDeserialize + std::fmt::Debug + TcpMessageLabel + Clone, - A: TcpServerMiddleware>, - B: TcpServerMiddleware, -{ - type EventOut = B::EventOut; - - fn on_event(&mut self, server: &mut S, event: TcpEvent) -> Option - where - S: TcpServerLike>, - { - let event = self.first.on_event(server, event)?; - self.second.on_event(server, event) - } - - fn on_send_error(&mut self, server: &mut S, ctx: &SendErrorContext) -> SendErrorOutcome - where - S: TcpServerLike>, - { - match self.first.on_send_error(server, ctx) { - SendErrorOutcome::Unhandled(_) => self.second.on_send_error(server, ctx), - outcome => outcome, - } - } - - fn on_tick(&mut self, server: &mut S) - where - S: TcpServerLike>, - { - self.first.on_tick(server); - self.second.on_tick(server); - } - - fn next_wakeup(&self) -> Option { - min_wakeup(self.first.next_wakeup(), self.second.next_wakeup()) - } -} - pub struct TcpInboundMessage { pub socket_addr: String, pub data: Req, @@ -245,26 +181,6 @@ where } } -impl QueuedSenderMiddleware for EventPipeline -where - Req: BorshSerialize + BorshDeserialize + std::fmt::Debug + Send + TcpMessageLabel + 'static, - Res: BorshSerialize + BorshDeserialize + std::fmt::Debug + TcpMessageLabel + Clone, - A: TcpServerMiddleware>, - B: QueuedSenderMiddleware, -{ - fn enqueue_to_peer(&mut self, socket_addr: String, msg: Res, headers: TcpHeaders) { - self.second_mut().enqueue_to_peer(socket_addr, msg, headers); - } - - fn register_streaming_peer(&mut self, socket_addr: String) { - self.second_mut().register_streaming_peer(socket_addr); - } - - fn enqueue_to_streaming_peers(&mut self, msg: Res, headers: TcpHeaders) { - self.second_mut().enqueue_to_streaming_peers(msg, headers); - } -} - impl TcpServerMiddleware for QueuedSendWithRetry where Req: BorshSerialize + BorshDeserialize + std::fmt::Debug + Send + TcpMessageLabel + 'static, diff --git a/crates/hyli-net/src/tcp/middleware/mod.rs b/crates/hyli-net/src/tcp/middleware/mod.rs index fd7b7c150..5afd9bfc3 100644 --- a/crates/hyli-net/src/tcp/middleware/mod.rs +++ b/crates/hyli-net/src/tcp/middleware/mod.rs @@ -91,8 +91,8 @@ use crate::tcp::{tcp_server::TcpServer, TcpEvent, TcpHeaders, TcpMessageLabel}; mod impls; pub use impls::{ - DropOnError, DropOnErrorAndRetry, EventPipeline, MessageOnly, MessageWithMeta, - QueuedSendWithRetry, QueuedSenderMiddleware, RetryingSend, TcpInboundMessage, + DropOnError, DropOnErrorAndRetry, MessageOnly, MessageWithMeta, QueuedSendWithRetry, + QueuedSenderMiddleware, RetryingSend, TcpInboundMessage, }; pub trait Layer { @@ -114,50 +114,6 @@ pub fn middleware_layer(middleware: M) -> MiddlewareLayer { MiddlewareLayer::new(middleware) } -/// Compose an arbitrary number of TCP middlewares into nested [`EventPipeline`]s. -/// -/// This is compile-time composition (no dynamic dispatch/runtime middleware list). -/// -/// # Example -/// ```ignore -/// let middleware = hyli_net::compose_tcp_middleware!( -/// DropOnError, -/// RetryingSend::new(10, Duration::from_millis(100)), -/// MessageWithMeta, -/// ); -/// ``` -#[macro_export] -macro_rules! compose_tcp_middleware { - ($single:expr $(,)?) => { - $single - }; - ($first:expr, $($rest:expr),+ $(,)?) => { - $crate::tcp::middleware::EventPipeline::new( - $first, - $crate::compose_tcp_middleware!($($rest),+), - ) - }; -} - -/// Compose middleware types into a nested [`EventPipeline`] type. -/// -/// # Example -/// ```ignore -/// type MyMiddleware = hyli_net::compose_tcp_middleware_type!(DropOnError, MessageOnly); -/// ``` -#[macro_export] -macro_rules! compose_tcp_middleware_type { - ($single:ty $(,)?) => { - $single - }; - ($first:ty, $($rest:ty),+ $(,)?) => { - $crate::tcp::middleware::EventPipeline< - $first, - $crate::compose_tcp_middleware_type!($($rest),+), - > - }; -} - pub struct SendErrorContext { pub socket_addr: String, pub msg: Res, From bf1a08f6319d07a057820e56023ee852f7fd641e Mon Sep 17 00:00:00 2001 From: Alexandre Careil Date: Mon, 16 Feb 2026 13:49:15 +0100 Subject: [PATCH 06/18] use more macros to remove boilerplate in tcp server middleware implementations --- Cargo.lock | 9 + Cargo.toml | 2 + crates/hyli-net-macros/Cargo.toml | 16 + crates/hyli-net-macros/src/lib.rs | 499 ++++++++++++++++++++ crates/hyli-net/Cargo.toml | 1 + crates/hyli-net/src/tcp/middleware/impls.rs | 223 ++++----- crates/hyli-net/src/tcp/middleware/mod.rs | 187 +++++++- 7 files changed, 771 insertions(+), 166 deletions(-) create mode 100644 crates/hyli-net-macros/Cargo.toml create mode 100644 crates/hyli-net-macros/src/lib.rs diff --git a/Cargo.lock b/Cargo.lock index 6cf0f8750..edf3be2c1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5671,6 +5671,7 @@ dependencies = [ "hyli-bus", "hyli-contract-sdk", "hyli-crypto", + "hyli-net-macros", "hyli-turmoil-shims", "hyper 1.8.1", "hyper-util", @@ -5691,6 +5692,14 @@ dependencies = [ "turmoil", ] +[[package]] +name = "hyli-net-macros" +version = "0.14.0" +dependencies = [ + "quote", + "syn 2.0.114", +] + [[package]] name = "hyli-noir-tools" version = "0.14.0" diff --git a/Cargo.toml b/Cargo.toml index 6460cb8eb..9084d68c2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,6 +14,7 @@ members = [ "crates/contract-sdk", "crates/hyli-loadtest", "crates/hyli-model", + "crates/hyli-net-macros", "crates/hyli-crypto", "crates/hyli-turmoil-shims", "crates/hyli-verifiers", @@ -51,6 +52,7 @@ sdk = { version = "0.14.0", default-features = false, path = "crates/contract-sd hyli-contract-sdk = { version = "0.14.0", default-features = false, path = "crates/contract-sdk", package = "hyli-contract-sdk" } client-sdk = { version = "0.14.0", default-features = false, path = "crates/client-sdk", package = "hyli-client-sdk" } hyli-net = { version = "0.14.0", default-features = false, path = "crates/hyli-net", package = "hyli-net" } +hyli-net-macros = { version = "0.14.0", default-features = false, path = "crates/hyli-net-macros", package = "hyli-net-macros" } hyli-model = { version = "0.14.0", default-features = false, path = "crates/hyli-model", package = "hyli-model" } hyli-crypto = { version = "0.14.0", default-features = false, path = "crates/hyli-crypto", package = "hyli-crypto" } hyli-turmoil-shims = { version = "0.14.0", default-features = false, path = "crates/hyli-turmoil-shims", package = "hyli-turmoil-shims", features = [ diff --git a/crates/hyli-net-macros/Cargo.toml b/crates/hyli-net-macros/Cargo.toml new file mode 100644 index 000000000..13c2c615d --- /dev/null +++ b/crates/hyli-net-macros/Cargo.toml @@ -0,0 +1,16 @@ +[package] +name = "hyli-net-macros" +description = "Proc macros for hyli-net" +license = "MIT" +version = { workspace = true } +edition = { workspace = true } +homepage = { workspace = true } +repository = { workspace = true } +rust-version = { workspace = true } + +[lib] +proc-macro = true + +[dependencies] +quote = { workspace = true } +syn = { workspace = true, features = ["full"] } diff --git a/crates/hyli-net-macros/src/lib.rs b/crates/hyli-net-macros/src/lib.rs new file mode 100644 index 000000000..703c36ea4 --- /dev/null +++ b/crates/hyli-net-macros/src/lib.rs @@ -0,0 +1,499 @@ +use proc_macro::TokenStream; +use quote::{format_ident, quote}; +use syn::{ + parse::Parser, parse_macro_input, parse_quote, punctuated::Punctuated, GenericParam, ImplItem, + ImplItemFn, ItemImpl, Meta, Token, Type, TypeParamBound, WherePredicate, +}; + +fn has_method(item: &ItemImpl, name: &str) -> bool { + item.items.iter().any(|it| { + if let ImplItem::Fn(f) = it { + f.sig.ident == name + } else { + false + } + }) +} + +fn parse_mode_and_event_out(attr: TokenStream) -> syn::Result<(String, Option)> { + let parser = Punctuated::::parse_terminated; + let metas = parser.parse2(attr.into())?; + let mut mode: Option = None; + let mut event_out: Option = None; + + for meta in metas { + match meta { + Meta::Path(path) => { + let Some(seg) = path.segments.last() else { + continue; + }; + let ident = seg.ident.to_string(); + if ident == "inbound" || ident == "outbound_tick" || ident == "all" { + mode = Some(ident); + } else { + return Err(syn::Error::new_spanned( + path, + "unknown mode: expected `inbound`, `outbound_tick`, or `all`", + )); + } + } + Meta::List(list) if list.path.is_ident("event_out") => { + event_out = Some(syn::parse2::(list.tokens)?); + } + other => { + return Err(syn::Error::new_spanned( + other, + "invalid attribute args: expected `inbound`, `outbound_tick`, or `all`, optionally `event_out(Type)`", + )); + } + } + } + + let Some(mode) = mode else { + return Err(syn::Error::new_spanned( + quote!(tcp_middleware), + "missing mode: use `inbound`, `outbound_tick`, or `all`", + )); + }; + + Ok((mode, event_out)) +} + +fn add_req_res_bounds(generics: &mut syn::Generics, req: &syn::Ident, res: &syn::Ident) { + let where_clause = generics.make_where_clause(); + where_clause.predicates.push(parse_quote!( + #req: borsh::BorshSerialize + + borsh::BorshDeserialize + + std::fmt::Debug + + Send + + crate::tcp::TcpMessageLabel + + 'static + )); + where_clause.predicates.push(parse_quote!( + #res: borsh::BorshSerialize + borsh::BorshDeserialize + std::fmt::Debug + crate::tcp::TcpMessageLabel + )); +} + +fn has_type_param(generics: &syn::Generics, name: &str) -> bool { + generics.params.iter().any(|p| match p { + GenericParam::Type(t) => t.ident == name, + _ => false, + }) +} + +fn has_declared_bound_for_type(item_impl: &ItemImpl, name: &str) -> bool { + // Generic parameter bound: `impl ...` + let in_params = item_impl.generics.params.iter().any(|p| match p { + GenericParam::Type(t) if t.ident == name => !t.bounds.is_empty(), + _ => false, + }); + if in_params { + return true; + } + + // Where-clause bound: `where T: Bound` + let Some(where_clause) = &item_impl.generics.where_clause else { + return false; + }; + + where_clause.predicates.iter().any(|pred| match pred { + WherePredicate::Type(ty_pred) => { + let is_target_type = if let Type::Path(type_path) = &ty_pred.bounded_ty { + type_path.qself.is_none() + && type_path.path.segments.len() == 1 + && type_path.path.segments[0].ident == name + } else { + false + }; + let has_any_bound = ty_pred + .bounds + .iter() + .any(|b| matches!(b, TypeParamBound::Trait(_) | TypeParamBound::Lifetime(_))); + is_target_type && has_any_bound + } + _ => false, + }) +} + +fn maybe_add_default_impl_req_bound(item_impl: &mut ItemImpl, req: &syn::Ident) { + if has_declared_bound_for_type(item_impl, &req.to_string()) { + return; + } + let where_clause = item_impl.generics.make_where_clause(); + where_clause.predicates.push(parse_quote!( + #req: crate::tcp::middleware::TcpReqBound + )); +} + +fn maybe_add_default_impl_res_bound(item_impl: &mut ItemImpl, res: &syn::Ident) { + if has_declared_bound_for_type(item_impl, &res.to_string()) { + return; + } + let where_clause = item_impl.generics.make_where_clause(); + where_clause.predicates.push(parse_quote!( + #res: crate::tcp::middleware::TcpResBound + )); +} + +fn method_has_type_param(method: &ImplItemFn, name: &str) -> bool { + method.sig.generics.params.iter().any(|p| match p { + GenericParam::Type(t) => t.ident == name, + _ => false, + }) +} + +fn add_req_bound_if_present(method: &mut ImplItemFn, req: &syn::Ident) { + if !method_has_type_param(method, &req.to_string()) { + return; + } + let where_clause = method.sig.generics.make_where_clause(); + where_clause.predicates.push(parse_quote!( + #req: borsh::BorshSerialize + + borsh::BorshDeserialize + + std::fmt::Debug + + Send + + crate::tcp::TcpMessageLabel + + 'static + )); +} + +fn add_res_bound_if_present(method: &mut ImplItemFn, res: &syn::Ident) { + if !method_has_type_param(method, &res.to_string()) { + return; + } + let where_clause = method.sig.generics.make_where_clause(); + where_clause.predicates.push(parse_quote!( + #res: borsh::BorshSerialize + borsh::BorshDeserialize + std::fmt::Debug + crate::tcp::TcpMessageLabel + )); +} + +fn add_s_bound_if_present( + method: &mut ImplItemFn, + s: &syn::Ident, + req: &syn::Ident, + res: &syn::Ident, +) { + if !method_has_type_param(method, &s.to_string()) { + return; + } + let where_clause = method.sig.generics.make_where_clause(); + where_clause.predicates.push(parse_quote!( + #s: crate::tcp::middleware::TcpServerLike<#req, #res, EventOut = crate::tcp::TcpEvent<#req>> + )); +} + +fn inject_hook_bounds( + item_impl: &mut ItemImpl, + mode: &str, + req_ident: &syn::Ident, + res_ident: &syn::Ident, +) { + let s_ident = format_ident!("S"); + for item in &mut item_impl.items { + let ImplItem::Fn(method) = item else { + continue; + }; + + let name = method.sig.ident.to_string(); + if !matches!(name.as_str(), "inbound" | "outbound_error" | "tick") { + continue; + } + + match mode { + "inbound" => { + add_req_bound_if_present(method, req_ident); + add_res_bound_if_present(method, res_ident); + add_s_bound_if_present(method, &s_ident, req_ident, res_ident); + } + "outbound_tick" => { + add_req_bound_if_present(method, req_ident); + add_s_bound_if_present(method, &s_ident, req_ident, res_ident); + } + "all" => { + add_s_bound_if_present(method, &s_ident, req_ident, res_ident); + } + _ => {} + } + } +} + +#[proc_macro_attribute] +pub fn tcp_middleware(attr: TokenStream, item: TokenStream) -> TokenStream { + let mut item_impl = parse_macro_input!(item as ItemImpl); + if item_impl.trait_.is_some() { + return syn::Error::new_spanned( + &item_impl, + "#[tcp_middleware(...)] must be used on an inherent impl block", + ) + .to_compile_error() + .into(); + } + + let (mode, event_out) = match parse_mode_and_event_out(attr) { + Ok(parsed) => parsed, + Err(err) => return err.to_compile_error().into(), + }; + + let self_ty = item_impl.self_ty.clone(); + let req_ident = format_ident!("__TcpReq"); + let res_ident = format_ident!("__TcpRes"); + + match mode.as_str() { + "inbound" => { + let req_method_ident = format_ident!("Req"); + let res_method_ident = format_ident!("Res"); + inject_hook_bounds( + &mut item_impl, + "inbound", + &req_method_ident, + &res_method_ident, + ); + + let Some(event_out) = event_out else { + return syn::Error::new_spanned( + &item_impl, + "missing `event_out(Type)` for inbound mode, e.g. #[tcp_middleware(inbound, event_out(TcpEvent<__TcpReq>))]", + ) + .to_compile_error() + .into(); + }; + + let mut in_g = item_impl.generics.clone(); + in_g.params.push(parse_quote!(#req_ident)); + in_g.params.push(parse_quote!(#res_ident)); + add_req_res_bounds(&mut in_g, &req_ident, &res_ident); + let (in_impl_g, _, in_where) = in_g.split_for_impl(); + + let mut out_g = item_impl.generics.clone(); + out_g.params.push(parse_quote!(#req_ident)); + out_g.params.push(parse_quote!(#res_ident)); + add_req_res_bounds(&mut out_g, &req_ident, &res_ident); + let (out_impl_g, _, out_where) = out_g.split_for_impl(); + + let mut tick_g = item_impl.generics.clone(); + tick_g.params.push(parse_quote!(#req_ident)); + tick_g.params.push(parse_quote!(#res_ident)); + add_req_res_bounds(&mut tick_g, &req_ident, &res_ident); + let (tick_impl_g, _, tick_where) = tick_g.split_for_impl(); + + quote! { + #item_impl + + impl #in_impl_g crate::tcp::middleware::TcpInboundMiddleware<#req_ident, #res_ident> for #self_ty #in_where { + type EventOut = #event_out; + + fn on_event(&mut self, server: &mut S, event: crate::tcp::TcpEvent<#req_ident>) -> Option + where + S: crate::tcp::middleware::TcpServerLike<#req_ident, #res_ident, EventOut = crate::tcp::TcpEvent<#req_ident>>, + { + let mut cx = crate::tcp::middleware::InboundCx::<#req_ident, #res_ident, S>::new(server); + self.inbound(&mut cx, event) + } + } + + impl #out_impl_g crate::tcp::middleware::TcpOutboundMiddleware<#req_ident, #res_ident> for #self_ty #out_where {} + + impl #tick_impl_g crate::tcp::middleware::TcpTickMiddleware<#req_ident, #res_ident> for #self_ty #tick_where {} + } + .into() + } + "outbound_tick" => { + if !has_type_param(&item_impl.generics, "Res") { + return syn::Error::new_spanned( + &item_impl, + "`outbound_tick`/`all` mode requires a `Res` type parameter on the impl (e.g. impl Type)", + ) + .to_compile_error() + .into(); + } + + let res_self_ident = format_ident!("Res"); + maybe_add_default_impl_res_bound(&mut item_impl, &res_self_ident); + let req_method_ident = format_ident!("Req"); + inject_hook_bounds( + &mut item_impl, + "outbound_tick", + &req_method_ident, + &res_self_ident, + ); + let has_outbound = has_method(&item_impl, "outbound_error"); + let has_tick = has_method(&item_impl, "tick"); + let has_next_wakeup = has_method(&item_impl, "next_wakeup"); + + let mut in_g = item_impl.generics.clone(); + in_g.params.push(parse_quote!(#req_ident)); + add_req_res_bounds(&mut in_g, &req_ident, &res_self_ident); + let (in_impl_g, _, in_where) = in_g.split_for_impl(); + + let mut out_g = item_impl.generics.clone(); + out_g.params.push(parse_quote!(#req_ident)); + add_req_res_bounds(&mut out_g, &req_ident, &res_self_ident); + let (out_impl_g, _, out_where) = out_g.split_for_impl(); + + let mut tick_g = item_impl.generics.clone(); + tick_g.params.push(parse_quote!(#req_ident)); + add_req_res_bounds(&mut tick_g, &req_ident, &res_self_ident); + let (tick_impl_g, _, tick_where) = tick_g.split_for_impl(); + + let outbound_body = if has_outbound { + quote! { self.outbound_error(&mut cx, ctx) } + } else { + quote! { crate::tcp::middleware::SendErrorOutcome::Unhandled(anyhow::anyhow!(ctx.error.to_string())) } + }; + let tick_body = if has_tick { + quote! { self.tick(&mut cx) } + } else { + quote! {} + }; + let next_wakeup_body = if has_next_wakeup { + quote! { self.next_wakeup() } + } else { + quote! { None } + }; + + quote! { + #item_impl + + impl #in_impl_g crate::tcp::middleware::TcpInboundMiddleware<#req_ident, #res_self_ident> for #self_ty #in_where { + type EventOut = crate::tcp::TcpEvent<#req_ident>; + + fn on_event(&mut self, server: &mut S, event: crate::tcp::TcpEvent<#req_ident>) -> Option + where + S: crate::tcp::middleware::TcpServerLike<#req_ident, #res_self_ident, EventOut = crate::tcp::TcpEvent<#req_ident>>, + { + Some(event) + } + } + + impl #out_impl_g crate::tcp::middleware::TcpOutboundMiddleware<#req_ident, #res_self_ident> for #self_ty #out_where { + fn on_send_error(&mut self, server: &mut S, ctx: &crate::tcp::middleware::SendErrorContext<#res_self_ident>) -> crate::tcp::middleware::SendErrorOutcome + where + S: crate::tcp::middleware::TcpServerLike<#req_ident, #res_self_ident, EventOut = crate::tcp::TcpEvent<#req_ident>>, + { + let mut cx = crate::tcp::middleware::OutboundCx::<#req_ident, #res_self_ident, S>::new(server); + #outbound_body + } + } + + impl #tick_impl_g crate::tcp::middleware::TcpTickMiddleware<#req_ident, #res_self_ident> for #self_ty #tick_where { + fn on_tick(&mut self, server: &mut S) + where + S: crate::tcp::middleware::TcpServerLike<#req_ident, #res_self_ident, EventOut = crate::tcp::TcpEvent<#req_ident>>, + { + let mut cx = crate::tcp::middleware::TickCx::<#req_ident, #res_self_ident, S>::new(server); + #tick_body + } + + fn next_wakeup(&self) -> Option { + #next_wakeup_body + } + } + } + .into() + } + "all" => { + if !has_type_param(&item_impl.generics, "Req") || !has_type_param(&item_impl.generics, "Res") { + return syn::Error::new_spanned( + &item_impl, + "`all` mode requires `Req` and `Res` type parameters on the impl (e.g. impl Type)", + ) + .to_compile_error() + .into(); + } + + let Some(event_out_ty) = event_out else { + return syn::Error::new_spanned( + &item_impl, + "missing `event_out(Type)` for all mode, e.g. #[tcp_middleware(all, event_out(TcpInboundMessage))]", + ) + .to_compile_error() + .into(); + }; + + let has_inbound = has_method(&item_impl, "inbound"); + let has_outbound = has_method(&item_impl, "outbound_error"); + let has_tick = has_method(&item_impl, "tick"); + let has_next_wakeup = has_method(&item_impl, "next_wakeup"); + + let req_self_ident = format_ident!("Req"); + let res_self_ident = format_ident!("Res"); + maybe_add_default_impl_req_bound(&mut item_impl, &req_self_ident); + maybe_add_default_impl_res_bound(&mut item_impl, &res_self_ident); + inject_hook_bounds(&mut item_impl, "all", &req_self_ident, &res_self_ident); + + let mut all_g = item_impl.generics.clone(); + add_req_res_bounds(&mut all_g, &req_self_ident, &res_self_ident); + let (all_impl_g, _, all_where) = all_g.split_for_impl(); + + let inbound_body = if has_inbound { + quote! { + let mut cx = crate::tcp::middleware::InboundCx::::new(server); + self.inbound(&mut cx, event) + } + } else { + quote! { Some(event) } + }; + let outbound_body = if has_outbound { + quote! { self.outbound_error(&mut cx, ctx) } + } else { + quote! { crate::tcp::middleware::SendErrorOutcome::Unhandled(anyhow::anyhow!(ctx.error.to_string())) } + }; + let tick_body = if has_tick { + quote! { self.tick(&mut cx) } + } else { + quote! {} + }; + let next_wakeup_body = if has_next_wakeup { + quote! { self.next_wakeup() } + } else { + quote! { None } + }; + + quote! { + #item_impl + + impl #all_impl_g crate::tcp::middleware::TcpInboundMiddleware for #self_ty #all_where { + type EventOut = #event_out_ty; + + fn on_event(&mut self, server: &mut S, event: crate::tcp::TcpEvent) -> Option + where + S: crate::tcp::middleware::TcpServerLike>, + { + #inbound_body + } + } + + impl #all_impl_g crate::tcp::middleware::TcpOutboundMiddleware for #self_ty #all_where { + fn on_send_error(&mut self, server: &mut S, ctx: &crate::tcp::middleware::SendErrorContext) -> crate::tcp::middleware::SendErrorOutcome + where + S: crate::tcp::middleware::TcpServerLike>, + { + let mut cx = crate::tcp::middleware::OutboundCx::::new(server); + #outbound_body + } + } + + impl #all_impl_g crate::tcp::middleware::TcpTickMiddleware for #self_ty #all_where { + fn on_tick(&mut self, server: &mut S) + where + S: crate::tcp::middleware::TcpServerLike>, + { + let mut cx = crate::tcp::middleware::TickCx::::new(server); + #tick_body + } + + fn next_wakeup(&self) -> Option { + #next_wakeup_body + } + } + } + .into() + } + _ => syn::Error::new_spanned( + &item_impl, + format!("unknown tcp_middleware mode `{mode}`: expected `inbound`, `outbound_tick`, or `all`"), + ) + .to_compile_error() + .into(), + } +} diff --git a/crates/hyli-net/Cargo.toml b/crates/hyli-net/Cargo.toml index 348da412b..d85342618 100644 --- a/crates/hyli-net/Cargo.toml +++ b/crates/hyli-net/Cargo.toml @@ -13,6 +13,7 @@ sdk = { workspace = true, features = ["full-model"] } hyli-crypto = { workspace = true } hyli-turmoil-shims = { workspace = true, features = ["otlp"] } hyli-bus = { workspace = true, optional = true } +hyli-net-macros = { workspace = true } anyhow = { workspace = true } borsh = { workspace = true } diff --git a/crates/hyli-net/src/tcp/middleware/impls.rs b/crates/hyli-net/src/tcp/middleware/impls.rs index 6e2cb1d6a..eee7bbacc 100644 --- a/crates/hyli-net/src/tcp/middleware/impls.rs +++ b/crates/hyli-net/src/tcp/middleware/impls.rs @@ -2,30 +2,26 @@ use std::collections::{HashSet, VecDeque}; use std::marker::PhantomData; use std::time::Duration; -use borsh::{BorshDeserialize, BorshSerialize}; +use hyli_net_macros::tcp_middleware; use tokio::time::Instant; -use crate::tcp::{TcpEvent, TcpHeaders, TcpMessageLabel}; +use crate::tcp::{TcpEvent, TcpHeaders}; -use super::{SendErrorContext, SendErrorOutcome, TcpServerLike, TcpServerMiddleware}; +use super::{InboundCx, OutboundCx, SendErrorContext, SendErrorOutcome}; #[derive(Default)] pub struct DropOnError; -impl TcpServerMiddleware for DropOnError -where - Req: BorshSerialize + BorshDeserialize + std::fmt::Debug + Send + TcpMessageLabel + 'static, - Res: BorshSerialize + BorshDeserialize + std::fmt::Debug + TcpMessageLabel, -{ - type EventOut = TcpEvent; - - fn on_event(&mut self, server: &mut S, event: TcpEvent) -> Option - where - S: TcpServerLike>, - { +#[tcp_middleware(inbound, event_out(TcpEvent<__TcpReq>))] +impl DropOnError { + fn inbound( + &mut self, + cx: &mut InboundCx, + event: TcpEvent, + ) -> Option> { match &event { TcpEvent::Error { socket_addr, .. } | TcpEvent::Closed { socket_addr } => { - server.drop_peer_stream(socket_addr.clone()); + cx.drop_peer(socket_addr.clone()); } _ => {} } @@ -42,17 +38,13 @@ pub struct TcpInboundMessage { #[derive(Default)] pub struct MessageOnly; -impl TcpServerMiddleware for MessageOnly -where - Req: BorshSerialize + BorshDeserialize + std::fmt::Debug + Send + TcpMessageLabel + 'static, - Res: BorshSerialize + BorshDeserialize + std::fmt::Debug + TcpMessageLabel, -{ - type EventOut = Req; - - fn on_event(&mut self, _server: &mut S, event: TcpEvent) -> Option - where - S: TcpServerLike>, - { +#[tcp_middleware(inbound, event_out(__TcpReq))] +impl MessageOnly { + fn inbound( + &mut self, + _cx: &mut InboundCx, + event: TcpEvent, + ) -> Option { match event { TcpEvent::Message { data, .. } => Some(data), TcpEvent::Closed { .. } | TcpEvent::Error { .. } => None, @@ -63,17 +55,13 @@ where #[derive(Default)] pub struct MessageWithMeta; -impl TcpServerMiddleware for MessageWithMeta -where - Req: BorshSerialize + BorshDeserialize + std::fmt::Debug + Send + TcpMessageLabel + 'static, - Res: BorshSerialize + BorshDeserialize + std::fmt::Debug + TcpMessageLabel, -{ - type EventOut = TcpInboundMessage; - - fn on_event(&mut self, _server: &mut S, event: TcpEvent) -> Option - where - S: TcpServerLike>, - { +#[tcp_middleware(inbound, event_out(TcpInboundMessage<__TcpReq>))] +impl MessageWithMeta { + fn inbound( + &mut self, + _cx: &mut InboundCx, + event: TcpEvent, + ) -> Option> { match event { TcpEvent::Message { socket_addr, @@ -155,8 +143,8 @@ impl QueuedSendWithRetry { pub trait QueuedSenderMiddleware where - Req: BorshSerialize + BorshDeserialize + std::fmt::Debug + Send + TcpMessageLabel + 'static, - Res: BorshSerialize + BorshDeserialize + std::fmt::Debug + TcpMessageLabel + Clone, + Req: super::TcpReqBound, + Res: super::TcpResBound + Clone, { fn enqueue_to_peer(&mut self, socket_addr: String, msg: Res, headers: TcpHeaders); fn register_streaming_peer(&mut self, socket_addr: String); @@ -165,33 +153,51 @@ where impl QueuedSenderMiddleware for QueuedSendWithRetry where - Req: BorshSerialize + BorshDeserialize + std::fmt::Debug + Send + TcpMessageLabel + 'static, - Res: BorshSerialize + BorshDeserialize + std::fmt::Debug + TcpMessageLabel + Clone, + Req: super::TcpReqBound, + Res: super::TcpResBound + Clone, { fn enqueue_to_peer(&mut self, socket_addr: String, msg: Res, headers: TcpHeaders) { - QueuedSendWithRetry::enqueue_to_peer(self, socket_addr, msg, headers); + self.queues + .entry(socket_addr) + .or_default() + .push_back(QueuedOutbound { + msg, + headers, + retries: 0, + next_attempt_at: Instant::now(), + }); } fn register_streaming_peer(&mut self, socket_addr: String) { - QueuedSendWithRetry::register_streaming_peer(self, socket_addr); + self.streaming_peers.insert(socket_addr); } fn enqueue_to_streaming_peers(&mut self, msg: Res, headers: TcpHeaders) { - QueuedSendWithRetry::enqueue_to_streaming_peers(self, msg, headers); + for peer in self.streaming_peers.clone() { + self.queues + .entry(peer) + .or_default() + .push_back(QueuedOutbound { + msg: msg.clone(), + headers: headers.clone(), + retries: 0, + next_attempt_at: Instant::now(), + }); + } } } -impl TcpServerMiddleware for QueuedSendWithRetry +#[tcp_middleware(all, event_out(TcpInboundMessage))] +impl QueuedSendWithRetry where - Req: BorshSerialize + BorshDeserialize + std::fmt::Debug + Send + TcpMessageLabel + 'static, - Res: BorshSerialize + BorshDeserialize + std::fmt::Debug + TcpMessageLabel + Clone, + Req: super::TcpReqBound, + Res: super::TcpResBound + Clone, { - type EventOut = TcpInboundMessage; - - fn on_event(&mut self, _server: &mut S, event: TcpEvent) -> Option - where - S: TcpServerLike>, - { + fn inbound( + &mut self, + _cx: &mut InboundCx, + event: TcpEvent, + ) -> Option> { match event { TcpEvent::Message { socket_addr, @@ -213,11 +219,12 @@ where } } - fn on_send_error(&mut self, server: &mut S, ctx: &SendErrorContext) -> SendErrorOutcome - where - S: TcpServerLike>, - { - if !server.connected(&ctx.socket_addr) { + fn outbound_error( + &mut self, + cx: &mut OutboundCx, + ctx: &SendErrorContext, + ) -> SendErrorOutcome { + if !cx.connected(&ctx.socket_addr) { self.unregister_streaming_peer(&ctx.socket_addr); return SendErrorOutcome::Unhandled(anyhow::anyhow!(ctx.error.to_string())); } @@ -229,10 +236,7 @@ where SendErrorOutcome::RetryScheduled } - fn on_tick(&mut self, server: &mut S) - where - S: TcpServerLike>, - { + fn tick(&mut self, cx: &mut super::TickCx) { if self.queues.is_empty() { return; } @@ -246,7 +250,7 @@ where break; } - if !server.connected(&peer) { + if !cx.connected(&peer) { self.unregister_streaming_peer(&peer); continue; } @@ -263,14 +267,14 @@ where continue; } - match server.send(peer.clone(), front.msg.clone(), front.headers.clone()) { + match cx.send(peer.clone(), front.msg.clone(), front.headers.clone()) { Ok(()) => { queue.pop_front(); } Err(_) => { front.retries += 1; if front.retries > self.max_retries { - server.drop_peer_stream(peer.clone()); + cx.drop_peer(peer.clone()); self.unregister_streaming_peer(&peer); } else { front.next_attempt_at = now + self.base_delay.mul_f64(front.retries as f64); @@ -323,25 +327,17 @@ impl RetryingSend { } } -impl TcpServerMiddleware for RetryingSend +#[tcp_middleware(outbound_tick)] +impl RetryingSend where - Req: BorshSerialize + BorshDeserialize + std::fmt::Debug + Send + TcpMessageLabel + 'static, - Res: BorshSerialize + BorshDeserialize + std::fmt::Debug + TcpMessageLabel + Clone, + Res: super::TcpResBound + Clone, { - type EventOut = TcpEvent; - - fn on_event(&mut self, _server: &mut S, event: TcpEvent) -> Option - where - S: TcpServerLike>, - { - Some(event) - } - - fn on_send_error(&mut self, server: &mut S, ctx: &SendErrorContext) -> SendErrorOutcome - where - S: TcpServerLike>, - { - if !server.connected(&ctx.socket_addr) { + fn outbound_error( + &mut self, + cx: &mut OutboundCx, + ctx: &SendErrorContext, + ) -> SendErrorOutcome { + if !cx.connected(&ctx.socket_addr) { return SendErrorOutcome::Unhandled(anyhow::anyhow!(ctx.error.to_string())); } self.queue.push_back(PendingSend { @@ -354,10 +350,7 @@ where SendErrorOutcome::RetryScheduled } - fn on_tick(&mut self, server: &mut S) - where - S: TcpServerLike>, - { + fn tick(&mut self, cx: &mut super::TickCx) { if self.queue.is_empty() { return; } @@ -372,11 +365,11 @@ where continue; } - if !server.connected(&pending.socket_addr) { + if !cx.connected(&pending.socket_addr) { continue; } - match server.send( + match cx.send( pending.socket_addr.clone(), pending.msg.clone(), pending.headers.clone(), @@ -385,7 +378,7 @@ where Err(_) => { let next_retries = pending.retries + 1; if next_retries > self.max_retries { - server.drop_peer_stream(pending.socket_addr); + cx.drop_peer(pending.socket_addr); } else { pending.retries = next_retries; pending.next_attempt_at = @@ -408,55 +401,3 @@ where .min() } } - -pub struct DropOnErrorAndRetry { - drop_on_error: DropOnError, - retrying_send: RetryingSend, -} - -impl DropOnErrorAndRetry { - pub fn new(max_retries: usize, base_delay: Duration) -> Self { - Self { - drop_on_error: DropOnError, - retrying_send: RetryingSend::new(max_retries, base_delay), - } - } - - pub fn max_per_tick(mut self, max_per_tick: usize) -> Self { - self.retrying_send = self.retrying_send.max_per_tick(max_per_tick); - self - } -} - -impl TcpServerMiddleware for DropOnErrorAndRetry -where - Req: BorshSerialize + BorshDeserialize + std::fmt::Debug + Send + TcpMessageLabel + 'static, - Res: BorshSerialize + BorshDeserialize + std::fmt::Debug + TcpMessageLabel + Clone, -{ - type EventOut = TcpEvent; - - fn on_event(&mut self, server: &mut S, event: TcpEvent) -> Option - where - S: TcpServerLike>, - { - self.drop_on_error.on_event(server, event) - } - - fn on_send_error(&mut self, server: &mut S, ctx: &SendErrorContext) -> SendErrorOutcome - where - S: TcpServerLike>, - { - self.retrying_send.on_send_error(server, ctx) - } - - fn on_tick(&mut self, server: &mut S) - where - S: TcpServerLike>, - { - self.retrying_send.on_tick(server) - } - - fn next_wakeup(&self) -> Option { - as TcpServerMiddleware>::next_wakeup(&self.retrying_send) - } -} diff --git a/crates/hyli-net/src/tcp/middleware/mod.rs b/crates/hyli-net/src/tcp/middleware/mod.rs index 5afd9bfc3..46f883e99 100644 --- a/crates/hyli-net/src/tcp/middleware/mod.rs +++ b/crates/hyli-net/src/tcp/middleware/mod.rs @@ -10,7 +10,7 @@ //! use std::time::Duration; //! use hyli_net::tcp::{ //! tcp_server::TcpServer, -//! middleware::{TcpServerWithMiddleware, DropOnErrorAndRetry}, +//! middleware::{preset, TcpServerExt}, //! }; //! # use hyli_net::tcp::{TcpEvent, TcpMessageLabel}; //! # use borsh::{BorshDeserialize, BorshSerialize}; @@ -28,8 +28,11 @@ //! # //! # async fn example() -> anyhow::Result<()> { //! let inner = TcpServer::::start(0, "Example").await?; -//! let middleware = DropOnErrorAndRetry::new(10, Duration::from_millis(100)); -//! let mut server = TcpServerWithMiddleware::new(inner, middleware); +//! let mut server = hyli_net::tcp_stack!( +//! inner, +//! preset::drop_on_error(), +//! preset::retrying_send::(10, Duration::from_millis(100)), +//! ); //! //! while let Some(event) = server.listen_next().await { //! match event { @@ -50,7 +53,7 @@ //! `TcpEvent::Message` to the `Req` payload and filters out `Error/Closed`. //! ```no_run //! # use hyli_net::tcp::{tcp_server::TcpServer, TcpEvent, TcpMessageLabel}; -//! # use hyli_net::tcp::middleware::{TcpServerWithMiddleware, TcpServerMiddleware}; +//! # use hyli_net::tcp::middleware::{TcpInboundMiddleware, TcpServerWithMiddleware}; //! # use borsh::{BorshDeserialize, BorshSerialize}; //! # #[derive(Clone, Debug, BorshSerialize, BorshDeserialize)] //! # struct Req; @@ -64,7 +67,7 @@ //! # } //! # //! # struct MessageOnly; -//! # impl TcpServerMiddleware for MessageOnly { +//! # impl TcpInboundMiddleware for MessageOnly { //! # type EventOut = Req; //! # fn on_event(&mut self, _server: &mut TcpServer, event: TcpEvent) -> Option { //! # match event { TcpEvent::Message { data, .. } => Some(data), _ => None } @@ -90,11 +93,50 @@ use crate::tcp::{tcp_server::TcpServer, TcpEvent, TcpHeaders, TcpMessageLabel}; mod impls; +pub use hyli_net_macros::tcp_middleware; pub use impls::{ - DropOnError, DropOnErrorAndRetry, MessageOnly, MessageWithMeta, QueuedSendWithRetry, - QueuedSenderMiddleware, RetryingSend, TcpInboundMessage, + DropOnError, MessageOnly, MessageWithMeta, QueuedSendWithRetry, QueuedSenderMiddleware, + RetryingSend, TcpInboundMessage, }; +pub mod preset { + use std::time::Duration; + + use super::{DropOnError, QueuedSendWithRetry, RetryingSend}; + + pub fn drop_on_error() -> DropOnError { + DropOnError + } + + pub fn retrying_send(max_retries: usize, base_delay: Duration) -> RetryingSend { + RetryingSend::new(max_retries, base_delay) + } + + pub fn drop_and_retry( + max_retries: usize, + base_delay: Duration, + ) -> (DropOnError, RetryingSend) { + (drop_on_error(), retrying_send(max_retries, base_delay)) + } + + pub fn queued_send_with_retry( + max_retries: usize, + base_delay: Duration, + ) -> QueuedSendWithRetry { + QueuedSendWithRetry::new(max_retries, base_delay) + } +} + +#[macro_export] +macro_rules! tcp_stack { + ($server:expr, $($middleware:expr),+ $(,)?) => {{ + use $crate::tcp::middleware::{middleware_layer, TcpServerExt}; + let server = $server; + $(let server = server.layer(middleware_layer($middleware));)+ + server + }}; +} + pub trait Layer { type Service; fn layer(self, inner: S) -> Self::Service; @@ -132,10 +174,28 @@ pub enum SendErrorOutcome { Unhandled(anyhow::Error), } -pub trait TcpServerMiddleware +pub trait TcpReqBound: + BorshSerialize + BorshDeserialize + std::fmt::Debug + Send + TcpMessageLabel + 'static +{ +} +impl TcpReqBound for T where + T: BorshSerialize + BorshDeserialize + std::fmt::Debug + Send + TcpMessageLabel + 'static +{ +} + +pub trait TcpResBound: + BorshSerialize + BorshDeserialize + std::fmt::Debug + TcpMessageLabel +{ +} +impl TcpResBound for T where + T: BorshSerialize + BorshDeserialize + std::fmt::Debug + TcpMessageLabel +{ +} + +pub trait TcpInboundMiddleware where - Req: BorshSerialize + BorshDeserialize + std::fmt::Debug + Send + TcpMessageLabel + 'static, - Res: BorshSerialize + BorshDeserialize + std::fmt::Debug + TcpMessageLabel, + Req: TcpReqBound, + Res: TcpResBound, { type EventOut; @@ -144,7 +204,13 @@ where fn on_event(&mut self, _server: &mut S, event: TcpEvent) -> Option where S: TcpServerLike>; +} +pub trait TcpOutboundMiddleware +where + Req: TcpReqBound, + Res: TcpResBound, +{ /// Handle outbound send errors. The default behavior is to surface the error. /// Implementations can enqueue retries or drop peers. fn on_send_error(&mut self, _server: &mut S, ctx: &SendErrorContext) -> SendErrorOutcome @@ -153,7 +219,13 @@ where { SendErrorOutcome::Unhandled(anyhow::anyhow!(ctx.error.to_string())) } +} +pub trait TcpTickMiddleware +where + Req: TcpReqBound, + Res: TcpResBound, +{ /// Called on each `listen_next()` iteration before waiting for events. /// Use this to drive retry queues or housekeeping. fn on_tick(&mut self, _server: &mut S) @@ -212,6 +284,64 @@ pub trait TcpServerLike { } } +pub struct InboundCx<'a, Req, Res, S> +where + Req: BorshDeserialize, + S: TcpServerLike>, +{ + server: &'a mut S, + _marker: PhantomData<(Req, Res)>, +} + +impl<'a, Req, Res, S> InboundCx<'a, Req, Res, S> +where + Req: BorshDeserialize, + S: TcpServerLike>, +{ + pub fn new(server: &'a mut S) -> Self { + Self { + server, + _marker: PhantomData, + } + } + + pub fn connected(&self, socket_addr: &str) -> bool { + self.server.connected(socket_addr) + } + + pub fn drop_peer(&mut self, peer_ip: String) { + self.server.drop_peer_stream(peer_ip); + } + + pub fn send( + &mut self, + socket_addr: String, + msg: Res, + headers: TcpHeaders, + ) -> anyhow::Result<()> { + self.server.send(socket_addr, msg, headers) + } + + pub fn send_ref( + &mut self, + socket_addr: &str, + msg: &Res, + headers: &TcpHeaders, + ) -> anyhow::Result<()> + where + Res: Clone, + { + self.server.send_ref(socket_addr, msg, headers) + } + + pub fn server_mut(&mut self) -> &mut S { + self.server + } +} + +pub type OutboundCx<'a, Req, Res, S> = InboundCx<'a, Req, Res, S>; +pub type TickCx<'a, Req, Res, S> = InboundCx<'a, Req, Res, S>; + /// Tower-style layering helper for TCP servers and already-layered services. /// /// # Example @@ -234,8 +364,8 @@ impl TcpServerExt for T where T: TcpServerLike pub struct TcpServerWithMiddleware> where - Req: BorshSerialize + BorshDeserialize + std::fmt::Debug + Send + TcpMessageLabel + 'static, - Res: BorshSerialize + BorshDeserialize + std::fmt::Debug + TcpMessageLabel, + Req: TcpReqBound, + Res: TcpResBound, { inner: S, middleware: M, @@ -244,8 +374,8 @@ where impl TcpServerWithMiddleware where - Req: BorshSerialize + BorshDeserialize + std::fmt::Debug + Send + TcpMessageLabel + 'static, - Res: BorshSerialize + BorshDeserialize + std::fmt::Debug + TcpMessageLabel, + Req: TcpReqBound, + Res: TcpResBound, { pub fn new(inner: S, middleware: M) -> Self { Self { @@ -258,10 +388,13 @@ where impl TcpServerWithMiddleware where - Req: BorshSerialize + BorshDeserialize + std::fmt::Debug + Send + TcpMessageLabel + 'static, - Res: BorshSerialize + BorshDeserialize + std::fmt::Debug + TcpMessageLabel + Clone, + Req: TcpReqBound, + Res: TcpResBound + Clone, S: TcpServerLike>, - M: TcpServerMiddleware + QueuedSenderMiddleware, + M: TcpInboundMiddleware + + TcpOutboundMiddleware + + TcpTickMiddleware + + QueuedSenderMiddleware, { /// Enqueue a message for ordered, retrying delivery to a specific peer. pub fn enqueue( @@ -302,10 +435,12 @@ where impl TcpServerLike for TcpServerWithMiddleware where - Req: BorshSerialize + BorshDeserialize + std::fmt::Debug + Send + TcpMessageLabel + 'static, - Res: BorshSerialize + BorshDeserialize + std::fmt::Debug + TcpMessageLabel + Clone, + Req: TcpReqBound, + Res: TcpResBound + Clone, S: TcpServerLike>, - M: TcpServerMiddleware, + M: TcpInboundMiddleware + + TcpOutboundMiddleware + + TcpTickMiddleware, { type EventOut = M::EventOut; type ConnectedClients<'a> @@ -378,10 +513,12 @@ where impl Layer for MiddlewareLayer where - Req: BorshSerialize + BorshDeserialize + std::fmt::Debug + Send + TcpMessageLabel + 'static, - Res: BorshSerialize + BorshDeserialize + std::fmt::Debug + TcpMessageLabel + Clone, + Req: TcpReqBound, + Res: TcpResBound + Clone, S: TcpServerLike>, - M: TcpServerMiddleware, + M: TcpInboundMiddleware + + TcpOutboundMiddleware + + TcpTickMiddleware, { type Service = TcpServerWithMiddleware; @@ -392,8 +529,8 @@ where impl TcpServerLike for TcpServer where - Req: BorshSerialize + BorshDeserialize + std::fmt::Debug + Send + TcpMessageLabel + 'static, - Res: BorshSerialize + BorshDeserialize + std::fmt::Debug + TcpMessageLabel, + Req: TcpReqBound, + Res: TcpResBound, { type EventOut = TcpEvent; type ConnectedClients<'a> From 134912e8841e4d10fedb3327c2340311369081c3 Mon Sep 17 00:00:00 2001 From: Alexandre Careil Date: Wed, 18 Feb 2026 17:31:27 +0100 Subject: [PATCH 07/18] remove macro, use simple middleware instead --- Cargo.lock | 9 - Cargo.toml | 2 - crates/hyli-net-macros/Cargo.toml | 16 - crates/hyli-net-macros/src/lib.rs | 499 -------------------- crates/hyli-net/Cargo.toml | 1 - crates/hyli-net/src/tcp/middleware/impls.rs | 130 ++--- crates/hyli-net/src/tcp/middleware/mod.rs | 42 +- 7 files changed, 85 insertions(+), 614 deletions(-) delete mode 100644 crates/hyli-net-macros/Cargo.toml delete mode 100644 crates/hyli-net-macros/src/lib.rs diff --git a/Cargo.lock b/Cargo.lock index edf3be2c1..6cf0f8750 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5671,7 +5671,6 @@ dependencies = [ "hyli-bus", "hyli-contract-sdk", "hyli-crypto", - "hyli-net-macros", "hyli-turmoil-shims", "hyper 1.8.1", "hyper-util", @@ -5692,14 +5691,6 @@ dependencies = [ "turmoil", ] -[[package]] -name = "hyli-net-macros" -version = "0.14.0" -dependencies = [ - "quote", - "syn 2.0.114", -] - [[package]] name = "hyli-noir-tools" version = "0.14.0" diff --git a/Cargo.toml b/Cargo.toml index 9084d68c2..6460cb8eb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,7 +14,6 @@ members = [ "crates/contract-sdk", "crates/hyli-loadtest", "crates/hyli-model", - "crates/hyli-net-macros", "crates/hyli-crypto", "crates/hyli-turmoil-shims", "crates/hyli-verifiers", @@ -52,7 +51,6 @@ sdk = { version = "0.14.0", default-features = false, path = "crates/contract-sd hyli-contract-sdk = { version = "0.14.0", default-features = false, path = "crates/contract-sdk", package = "hyli-contract-sdk" } client-sdk = { version = "0.14.0", default-features = false, path = "crates/client-sdk", package = "hyli-client-sdk" } hyli-net = { version = "0.14.0", default-features = false, path = "crates/hyli-net", package = "hyli-net" } -hyli-net-macros = { version = "0.14.0", default-features = false, path = "crates/hyli-net-macros", package = "hyli-net-macros" } hyli-model = { version = "0.14.0", default-features = false, path = "crates/hyli-model", package = "hyli-model" } hyli-crypto = { version = "0.14.0", default-features = false, path = "crates/hyli-crypto", package = "hyli-crypto" } hyli-turmoil-shims = { version = "0.14.0", default-features = false, path = "crates/hyli-turmoil-shims", package = "hyli-turmoil-shims", features = [ diff --git a/crates/hyli-net-macros/Cargo.toml b/crates/hyli-net-macros/Cargo.toml deleted file mode 100644 index 13c2c615d..000000000 --- a/crates/hyli-net-macros/Cargo.toml +++ /dev/null @@ -1,16 +0,0 @@ -[package] -name = "hyli-net-macros" -description = "Proc macros for hyli-net" -license = "MIT" -version = { workspace = true } -edition = { workspace = true } -homepage = { workspace = true } -repository = { workspace = true } -rust-version = { workspace = true } - -[lib] -proc-macro = true - -[dependencies] -quote = { workspace = true } -syn = { workspace = true, features = ["full"] } diff --git a/crates/hyli-net-macros/src/lib.rs b/crates/hyli-net-macros/src/lib.rs deleted file mode 100644 index 703c36ea4..000000000 --- a/crates/hyli-net-macros/src/lib.rs +++ /dev/null @@ -1,499 +0,0 @@ -use proc_macro::TokenStream; -use quote::{format_ident, quote}; -use syn::{ - parse::Parser, parse_macro_input, parse_quote, punctuated::Punctuated, GenericParam, ImplItem, - ImplItemFn, ItemImpl, Meta, Token, Type, TypeParamBound, WherePredicate, -}; - -fn has_method(item: &ItemImpl, name: &str) -> bool { - item.items.iter().any(|it| { - if let ImplItem::Fn(f) = it { - f.sig.ident == name - } else { - false - } - }) -} - -fn parse_mode_and_event_out(attr: TokenStream) -> syn::Result<(String, Option)> { - let parser = Punctuated::::parse_terminated; - let metas = parser.parse2(attr.into())?; - let mut mode: Option = None; - let mut event_out: Option = None; - - for meta in metas { - match meta { - Meta::Path(path) => { - let Some(seg) = path.segments.last() else { - continue; - }; - let ident = seg.ident.to_string(); - if ident == "inbound" || ident == "outbound_tick" || ident == "all" { - mode = Some(ident); - } else { - return Err(syn::Error::new_spanned( - path, - "unknown mode: expected `inbound`, `outbound_tick`, or `all`", - )); - } - } - Meta::List(list) if list.path.is_ident("event_out") => { - event_out = Some(syn::parse2::(list.tokens)?); - } - other => { - return Err(syn::Error::new_spanned( - other, - "invalid attribute args: expected `inbound`, `outbound_tick`, or `all`, optionally `event_out(Type)`", - )); - } - } - } - - let Some(mode) = mode else { - return Err(syn::Error::new_spanned( - quote!(tcp_middleware), - "missing mode: use `inbound`, `outbound_tick`, or `all`", - )); - }; - - Ok((mode, event_out)) -} - -fn add_req_res_bounds(generics: &mut syn::Generics, req: &syn::Ident, res: &syn::Ident) { - let where_clause = generics.make_where_clause(); - where_clause.predicates.push(parse_quote!( - #req: borsh::BorshSerialize - + borsh::BorshDeserialize - + std::fmt::Debug - + Send - + crate::tcp::TcpMessageLabel - + 'static - )); - where_clause.predicates.push(parse_quote!( - #res: borsh::BorshSerialize + borsh::BorshDeserialize + std::fmt::Debug + crate::tcp::TcpMessageLabel - )); -} - -fn has_type_param(generics: &syn::Generics, name: &str) -> bool { - generics.params.iter().any(|p| match p { - GenericParam::Type(t) => t.ident == name, - _ => false, - }) -} - -fn has_declared_bound_for_type(item_impl: &ItemImpl, name: &str) -> bool { - // Generic parameter bound: `impl ...` - let in_params = item_impl.generics.params.iter().any(|p| match p { - GenericParam::Type(t) if t.ident == name => !t.bounds.is_empty(), - _ => false, - }); - if in_params { - return true; - } - - // Where-clause bound: `where T: Bound` - let Some(where_clause) = &item_impl.generics.where_clause else { - return false; - }; - - where_clause.predicates.iter().any(|pred| match pred { - WherePredicate::Type(ty_pred) => { - let is_target_type = if let Type::Path(type_path) = &ty_pred.bounded_ty { - type_path.qself.is_none() - && type_path.path.segments.len() == 1 - && type_path.path.segments[0].ident == name - } else { - false - }; - let has_any_bound = ty_pred - .bounds - .iter() - .any(|b| matches!(b, TypeParamBound::Trait(_) | TypeParamBound::Lifetime(_))); - is_target_type && has_any_bound - } - _ => false, - }) -} - -fn maybe_add_default_impl_req_bound(item_impl: &mut ItemImpl, req: &syn::Ident) { - if has_declared_bound_for_type(item_impl, &req.to_string()) { - return; - } - let where_clause = item_impl.generics.make_where_clause(); - where_clause.predicates.push(parse_quote!( - #req: crate::tcp::middleware::TcpReqBound - )); -} - -fn maybe_add_default_impl_res_bound(item_impl: &mut ItemImpl, res: &syn::Ident) { - if has_declared_bound_for_type(item_impl, &res.to_string()) { - return; - } - let where_clause = item_impl.generics.make_where_clause(); - where_clause.predicates.push(parse_quote!( - #res: crate::tcp::middleware::TcpResBound - )); -} - -fn method_has_type_param(method: &ImplItemFn, name: &str) -> bool { - method.sig.generics.params.iter().any(|p| match p { - GenericParam::Type(t) => t.ident == name, - _ => false, - }) -} - -fn add_req_bound_if_present(method: &mut ImplItemFn, req: &syn::Ident) { - if !method_has_type_param(method, &req.to_string()) { - return; - } - let where_clause = method.sig.generics.make_where_clause(); - where_clause.predicates.push(parse_quote!( - #req: borsh::BorshSerialize - + borsh::BorshDeserialize - + std::fmt::Debug - + Send - + crate::tcp::TcpMessageLabel - + 'static - )); -} - -fn add_res_bound_if_present(method: &mut ImplItemFn, res: &syn::Ident) { - if !method_has_type_param(method, &res.to_string()) { - return; - } - let where_clause = method.sig.generics.make_where_clause(); - where_clause.predicates.push(parse_quote!( - #res: borsh::BorshSerialize + borsh::BorshDeserialize + std::fmt::Debug + crate::tcp::TcpMessageLabel - )); -} - -fn add_s_bound_if_present( - method: &mut ImplItemFn, - s: &syn::Ident, - req: &syn::Ident, - res: &syn::Ident, -) { - if !method_has_type_param(method, &s.to_string()) { - return; - } - let where_clause = method.sig.generics.make_where_clause(); - where_clause.predicates.push(parse_quote!( - #s: crate::tcp::middleware::TcpServerLike<#req, #res, EventOut = crate::tcp::TcpEvent<#req>> - )); -} - -fn inject_hook_bounds( - item_impl: &mut ItemImpl, - mode: &str, - req_ident: &syn::Ident, - res_ident: &syn::Ident, -) { - let s_ident = format_ident!("S"); - for item in &mut item_impl.items { - let ImplItem::Fn(method) = item else { - continue; - }; - - let name = method.sig.ident.to_string(); - if !matches!(name.as_str(), "inbound" | "outbound_error" | "tick") { - continue; - } - - match mode { - "inbound" => { - add_req_bound_if_present(method, req_ident); - add_res_bound_if_present(method, res_ident); - add_s_bound_if_present(method, &s_ident, req_ident, res_ident); - } - "outbound_tick" => { - add_req_bound_if_present(method, req_ident); - add_s_bound_if_present(method, &s_ident, req_ident, res_ident); - } - "all" => { - add_s_bound_if_present(method, &s_ident, req_ident, res_ident); - } - _ => {} - } - } -} - -#[proc_macro_attribute] -pub fn tcp_middleware(attr: TokenStream, item: TokenStream) -> TokenStream { - let mut item_impl = parse_macro_input!(item as ItemImpl); - if item_impl.trait_.is_some() { - return syn::Error::new_spanned( - &item_impl, - "#[tcp_middleware(...)] must be used on an inherent impl block", - ) - .to_compile_error() - .into(); - } - - let (mode, event_out) = match parse_mode_and_event_out(attr) { - Ok(parsed) => parsed, - Err(err) => return err.to_compile_error().into(), - }; - - let self_ty = item_impl.self_ty.clone(); - let req_ident = format_ident!("__TcpReq"); - let res_ident = format_ident!("__TcpRes"); - - match mode.as_str() { - "inbound" => { - let req_method_ident = format_ident!("Req"); - let res_method_ident = format_ident!("Res"); - inject_hook_bounds( - &mut item_impl, - "inbound", - &req_method_ident, - &res_method_ident, - ); - - let Some(event_out) = event_out else { - return syn::Error::new_spanned( - &item_impl, - "missing `event_out(Type)` for inbound mode, e.g. #[tcp_middleware(inbound, event_out(TcpEvent<__TcpReq>))]", - ) - .to_compile_error() - .into(); - }; - - let mut in_g = item_impl.generics.clone(); - in_g.params.push(parse_quote!(#req_ident)); - in_g.params.push(parse_quote!(#res_ident)); - add_req_res_bounds(&mut in_g, &req_ident, &res_ident); - let (in_impl_g, _, in_where) = in_g.split_for_impl(); - - let mut out_g = item_impl.generics.clone(); - out_g.params.push(parse_quote!(#req_ident)); - out_g.params.push(parse_quote!(#res_ident)); - add_req_res_bounds(&mut out_g, &req_ident, &res_ident); - let (out_impl_g, _, out_where) = out_g.split_for_impl(); - - let mut tick_g = item_impl.generics.clone(); - tick_g.params.push(parse_quote!(#req_ident)); - tick_g.params.push(parse_quote!(#res_ident)); - add_req_res_bounds(&mut tick_g, &req_ident, &res_ident); - let (tick_impl_g, _, tick_where) = tick_g.split_for_impl(); - - quote! { - #item_impl - - impl #in_impl_g crate::tcp::middleware::TcpInboundMiddleware<#req_ident, #res_ident> for #self_ty #in_where { - type EventOut = #event_out; - - fn on_event(&mut self, server: &mut S, event: crate::tcp::TcpEvent<#req_ident>) -> Option - where - S: crate::tcp::middleware::TcpServerLike<#req_ident, #res_ident, EventOut = crate::tcp::TcpEvent<#req_ident>>, - { - let mut cx = crate::tcp::middleware::InboundCx::<#req_ident, #res_ident, S>::new(server); - self.inbound(&mut cx, event) - } - } - - impl #out_impl_g crate::tcp::middleware::TcpOutboundMiddleware<#req_ident, #res_ident> for #self_ty #out_where {} - - impl #tick_impl_g crate::tcp::middleware::TcpTickMiddleware<#req_ident, #res_ident> for #self_ty #tick_where {} - } - .into() - } - "outbound_tick" => { - if !has_type_param(&item_impl.generics, "Res") { - return syn::Error::new_spanned( - &item_impl, - "`outbound_tick`/`all` mode requires a `Res` type parameter on the impl (e.g. impl Type)", - ) - .to_compile_error() - .into(); - } - - let res_self_ident = format_ident!("Res"); - maybe_add_default_impl_res_bound(&mut item_impl, &res_self_ident); - let req_method_ident = format_ident!("Req"); - inject_hook_bounds( - &mut item_impl, - "outbound_tick", - &req_method_ident, - &res_self_ident, - ); - let has_outbound = has_method(&item_impl, "outbound_error"); - let has_tick = has_method(&item_impl, "tick"); - let has_next_wakeup = has_method(&item_impl, "next_wakeup"); - - let mut in_g = item_impl.generics.clone(); - in_g.params.push(parse_quote!(#req_ident)); - add_req_res_bounds(&mut in_g, &req_ident, &res_self_ident); - let (in_impl_g, _, in_where) = in_g.split_for_impl(); - - let mut out_g = item_impl.generics.clone(); - out_g.params.push(parse_quote!(#req_ident)); - add_req_res_bounds(&mut out_g, &req_ident, &res_self_ident); - let (out_impl_g, _, out_where) = out_g.split_for_impl(); - - let mut tick_g = item_impl.generics.clone(); - tick_g.params.push(parse_quote!(#req_ident)); - add_req_res_bounds(&mut tick_g, &req_ident, &res_self_ident); - let (tick_impl_g, _, tick_where) = tick_g.split_for_impl(); - - let outbound_body = if has_outbound { - quote! { self.outbound_error(&mut cx, ctx) } - } else { - quote! { crate::tcp::middleware::SendErrorOutcome::Unhandled(anyhow::anyhow!(ctx.error.to_string())) } - }; - let tick_body = if has_tick { - quote! { self.tick(&mut cx) } - } else { - quote! {} - }; - let next_wakeup_body = if has_next_wakeup { - quote! { self.next_wakeup() } - } else { - quote! { None } - }; - - quote! { - #item_impl - - impl #in_impl_g crate::tcp::middleware::TcpInboundMiddleware<#req_ident, #res_self_ident> for #self_ty #in_where { - type EventOut = crate::tcp::TcpEvent<#req_ident>; - - fn on_event(&mut self, server: &mut S, event: crate::tcp::TcpEvent<#req_ident>) -> Option - where - S: crate::tcp::middleware::TcpServerLike<#req_ident, #res_self_ident, EventOut = crate::tcp::TcpEvent<#req_ident>>, - { - Some(event) - } - } - - impl #out_impl_g crate::tcp::middleware::TcpOutboundMiddleware<#req_ident, #res_self_ident> for #self_ty #out_where { - fn on_send_error(&mut self, server: &mut S, ctx: &crate::tcp::middleware::SendErrorContext<#res_self_ident>) -> crate::tcp::middleware::SendErrorOutcome - where - S: crate::tcp::middleware::TcpServerLike<#req_ident, #res_self_ident, EventOut = crate::tcp::TcpEvent<#req_ident>>, - { - let mut cx = crate::tcp::middleware::OutboundCx::<#req_ident, #res_self_ident, S>::new(server); - #outbound_body - } - } - - impl #tick_impl_g crate::tcp::middleware::TcpTickMiddleware<#req_ident, #res_self_ident> for #self_ty #tick_where { - fn on_tick(&mut self, server: &mut S) - where - S: crate::tcp::middleware::TcpServerLike<#req_ident, #res_self_ident, EventOut = crate::tcp::TcpEvent<#req_ident>>, - { - let mut cx = crate::tcp::middleware::TickCx::<#req_ident, #res_self_ident, S>::new(server); - #tick_body - } - - fn next_wakeup(&self) -> Option { - #next_wakeup_body - } - } - } - .into() - } - "all" => { - if !has_type_param(&item_impl.generics, "Req") || !has_type_param(&item_impl.generics, "Res") { - return syn::Error::new_spanned( - &item_impl, - "`all` mode requires `Req` and `Res` type parameters on the impl (e.g. impl Type)", - ) - .to_compile_error() - .into(); - } - - let Some(event_out_ty) = event_out else { - return syn::Error::new_spanned( - &item_impl, - "missing `event_out(Type)` for all mode, e.g. #[tcp_middleware(all, event_out(TcpInboundMessage))]", - ) - .to_compile_error() - .into(); - }; - - let has_inbound = has_method(&item_impl, "inbound"); - let has_outbound = has_method(&item_impl, "outbound_error"); - let has_tick = has_method(&item_impl, "tick"); - let has_next_wakeup = has_method(&item_impl, "next_wakeup"); - - let req_self_ident = format_ident!("Req"); - let res_self_ident = format_ident!("Res"); - maybe_add_default_impl_req_bound(&mut item_impl, &req_self_ident); - maybe_add_default_impl_res_bound(&mut item_impl, &res_self_ident); - inject_hook_bounds(&mut item_impl, "all", &req_self_ident, &res_self_ident); - - let mut all_g = item_impl.generics.clone(); - add_req_res_bounds(&mut all_g, &req_self_ident, &res_self_ident); - let (all_impl_g, _, all_where) = all_g.split_for_impl(); - - let inbound_body = if has_inbound { - quote! { - let mut cx = crate::tcp::middleware::InboundCx::::new(server); - self.inbound(&mut cx, event) - } - } else { - quote! { Some(event) } - }; - let outbound_body = if has_outbound { - quote! { self.outbound_error(&mut cx, ctx) } - } else { - quote! { crate::tcp::middleware::SendErrorOutcome::Unhandled(anyhow::anyhow!(ctx.error.to_string())) } - }; - let tick_body = if has_tick { - quote! { self.tick(&mut cx) } - } else { - quote! {} - }; - let next_wakeup_body = if has_next_wakeup { - quote! { self.next_wakeup() } - } else { - quote! { None } - }; - - quote! { - #item_impl - - impl #all_impl_g crate::tcp::middleware::TcpInboundMiddleware for #self_ty #all_where { - type EventOut = #event_out_ty; - - fn on_event(&mut self, server: &mut S, event: crate::tcp::TcpEvent) -> Option - where - S: crate::tcp::middleware::TcpServerLike>, - { - #inbound_body - } - } - - impl #all_impl_g crate::tcp::middleware::TcpOutboundMiddleware for #self_ty #all_where { - fn on_send_error(&mut self, server: &mut S, ctx: &crate::tcp::middleware::SendErrorContext) -> crate::tcp::middleware::SendErrorOutcome - where - S: crate::tcp::middleware::TcpServerLike>, - { - let mut cx = crate::tcp::middleware::OutboundCx::::new(server); - #outbound_body - } - } - - impl #all_impl_g crate::tcp::middleware::TcpTickMiddleware for #self_ty #all_where { - fn on_tick(&mut self, server: &mut S) - where - S: crate::tcp::middleware::TcpServerLike>, - { - let mut cx = crate::tcp::middleware::TickCx::::new(server); - #tick_body - } - - fn next_wakeup(&self) -> Option { - #next_wakeup_body - } - } - } - .into() - } - _ => syn::Error::new_spanned( - &item_impl, - format!("unknown tcp_middleware mode `{mode}`: expected `inbound`, `outbound_tick`, or `all`"), - ) - .to_compile_error() - .into(), - } -} diff --git a/crates/hyli-net/Cargo.toml b/crates/hyli-net/Cargo.toml index d85342618..348da412b 100644 --- a/crates/hyli-net/Cargo.toml +++ b/crates/hyli-net/Cargo.toml @@ -13,7 +13,6 @@ sdk = { workspace = true, features = ["full-model"] } hyli-crypto = { workspace = true } hyli-turmoil-shims = { workspace = true, features = ["otlp"] } hyli-bus = { workspace = true, optional = true } -hyli-net-macros = { workspace = true } anyhow = { workspace = true } borsh = { workspace = true } diff --git a/crates/hyli-net/src/tcp/middleware/impls.rs b/crates/hyli-net/src/tcp/middleware/impls.rs index eee7bbacc..274325928 100644 --- a/crates/hyli-net/src/tcp/middleware/impls.rs +++ b/crates/hyli-net/src/tcp/middleware/impls.rs @@ -2,26 +2,29 @@ use std::collections::{HashSet, VecDeque}; use std::marker::PhantomData; use std::time::Duration; -use hyli_net_macros::tcp_middleware; use tokio::time::Instant; use crate::tcp::{TcpEvent, TcpHeaders}; -use super::{InboundCx, OutboundCx, SendErrorContext, SendErrorOutcome}; +use super::{SendErrorContext, SendErrorOutcome, TcpMiddleware, TcpServerLike}; #[derive(Default)] pub struct DropOnError; -#[tcp_middleware(inbound, event_out(TcpEvent<__TcpReq>))] -impl DropOnError { - fn inbound( - &mut self, - cx: &mut InboundCx, - event: TcpEvent, - ) -> Option> { +impl TcpMiddleware for DropOnError +where + Req: super::TcpReqBound, + Res: super::TcpResBound, +{ + type EventOut = TcpEvent; + + fn on_event(&mut self, server: &mut S, event: TcpEvent) -> Option + where + S: TcpServerLike>, + { match &event { TcpEvent::Error { socket_addr, .. } | TcpEvent::Closed { socket_addr } => { - cx.drop_peer(socket_addr.clone()); + server.drop_peer_stream(socket_addr.clone()); } _ => {} } @@ -38,13 +41,17 @@ pub struct TcpInboundMessage { #[derive(Default)] pub struct MessageOnly; -#[tcp_middleware(inbound, event_out(__TcpReq))] -impl MessageOnly { - fn inbound( - &mut self, - _cx: &mut InboundCx, - event: TcpEvent, - ) -> Option { +impl TcpMiddleware for MessageOnly +where + Req: super::TcpReqBound, + Res: super::TcpResBound, +{ + type EventOut = Req; + + fn on_event(&mut self, _server: &mut S, event: TcpEvent) -> Option + where + S: TcpServerLike>, + { match event { TcpEvent::Message { data, .. } => Some(data), TcpEvent::Closed { .. } | TcpEvent::Error { .. } => None, @@ -55,13 +62,17 @@ impl MessageOnly { #[derive(Default)] pub struct MessageWithMeta; -#[tcp_middleware(inbound, event_out(TcpInboundMessage<__TcpReq>))] -impl MessageWithMeta { - fn inbound( - &mut self, - _cx: &mut InboundCx, - event: TcpEvent, - ) -> Option> { +impl TcpMiddleware for MessageWithMeta +where + Req: super::TcpReqBound, + Res: super::TcpResBound, +{ + type EventOut = TcpInboundMessage; + + fn on_event(&mut self, _server: &mut S, event: TcpEvent) -> Option + where + S: TcpServerLike>, + { match event { TcpEvent::Message { socket_addr, @@ -187,17 +198,17 @@ where } } -#[tcp_middleware(all, event_out(TcpInboundMessage))] -impl QueuedSendWithRetry +impl TcpMiddleware for QueuedSendWithRetry where Req: super::TcpReqBound, Res: super::TcpResBound + Clone, { - fn inbound( - &mut self, - _cx: &mut InboundCx, - event: TcpEvent, - ) -> Option> { + type EventOut = TcpInboundMessage; + + fn on_event(&mut self, _server: &mut S, event: TcpEvent) -> Option + where + S: TcpServerLike>, + { match event { TcpEvent::Message { socket_addr, @@ -219,12 +230,11 @@ where } } - fn outbound_error( - &mut self, - cx: &mut OutboundCx, - ctx: &SendErrorContext, - ) -> SendErrorOutcome { - if !cx.connected(&ctx.socket_addr) { + fn on_send_error(&mut self, server: &mut S, ctx: &SendErrorContext) -> SendErrorOutcome + where + S: TcpServerLike>, + { + if !server.connected(&ctx.socket_addr) { self.unregister_streaming_peer(&ctx.socket_addr); return SendErrorOutcome::Unhandled(anyhow::anyhow!(ctx.error.to_string())); } @@ -236,7 +246,10 @@ where SendErrorOutcome::RetryScheduled } - fn tick(&mut self, cx: &mut super::TickCx) { + fn on_tick(&mut self, server: &mut S) + where + S: TcpServerLike>, + { if self.queues.is_empty() { return; } @@ -250,7 +263,7 @@ where break; } - if !cx.connected(&peer) { + if !server.connected(&peer) { self.unregister_streaming_peer(&peer); continue; } @@ -267,14 +280,14 @@ where continue; } - match cx.send(peer.clone(), front.msg.clone(), front.headers.clone()) { + match server.send(peer.clone(), front.msg.clone(), front.headers.clone()) { Ok(()) => { queue.pop_front(); } Err(_) => { front.retries += 1; if front.retries > self.max_retries { - cx.drop_peer(peer.clone()); + server.drop_peer_stream(peer.clone()); self.unregister_streaming_peer(&peer); } else { front.next_attempt_at = now + self.base_delay.mul_f64(front.retries as f64); @@ -327,17 +340,25 @@ impl RetryingSend { } } -#[tcp_middleware(outbound_tick)] -impl RetryingSend +impl TcpMiddleware for RetryingSend where + Req: super::TcpReqBound, Res: super::TcpResBound + Clone, { - fn outbound_error( - &mut self, - cx: &mut OutboundCx, - ctx: &SendErrorContext, - ) -> SendErrorOutcome { - if !cx.connected(&ctx.socket_addr) { + type EventOut = TcpEvent; + + fn on_event(&mut self, _server: &mut S, event: TcpEvent) -> Option + where + S: TcpServerLike>, + { + Some(event) + } + + fn on_send_error(&mut self, server: &mut S, ctx: &SendErrorContext) -> SendErrorOutcome + where + S: TcpServerLike>, + { + if !server.connected(&ctx.socket_addr) { return SendErrorOutcome::Unhandled(anyhow::anyhow!(ctx.error.to_string())); } self.queue.push_back(PendingSend { @@ -350,7 +371,10 @@ where SendErrorOutcome::RetryScheduled } - fn tick(&mut self, cx: &mut super::TickCx) { + fn on_tick(&mut self, server: &mut S) + where + S: TcpServerLike>, + { if self.queue.is_empty() { return; } @@ -365,11 +389,11 @@ where continue; } - if !cx.connected(&pending.socket_addr) { + if !server.connected(&pending.socket_addr) { continue; } - match cx.send( + match server.send( pending.socket_addr.clone(), pending.msg.clone(), pending.headers.clone(), @@ -378,7 +402,7 @@ where Err(_) => { let next_retries = pending.retries + 1; if next_retries > self.max_retries { - cx.drop_peer(pending.socket_addr); + server.drop_peer_stream(pending.socket_addr); } else { pending.retries = next_retries; pending.next_attempt_at = diff --git a/crates/hyli-net/src/tcp/middleware/mod.rs b/crates/hyli-net/src/tcp/middleware/mod.rs index 46f883e99..3857d224f 100644 --- a/crates/hyli-net/src/tcp/middleware/mod.rs +++ b/crates/hyli-net/src/tcp/middleware/mod.rs @@ -53,7 +53,7 @@ //! `TcpEvent::Message` to the `Req` payload and filters out `Error/Closed`. //! ```no_run //! # use hyli_net::tcp::{tcp_server::TcpServer, TcpEvent, TcpMessageLabel}; -//! # use hyli_net::tcp::middleware::{TcpInboundMiddleware, TcpServerWithMiddleware}; +//! # use hyli_net::tcp::middleware::{TcpMiddleware, TcpServerWithMiddleware}; //! # use borsh::{BorshDeserialize, BorshSerialize}; //! # #[derive(Clone, Debug, BorshSerialize, BorshDeserialize)] //! # struct Req; @@ -67,7 +67,7 @@ //! # } //! # //! # struct MessageOnly; -//! # impl TcpInboundMiddleware for MessageOnly { +//! # impl TcpMiddleware for MessageOnly { //! # type EventOut = Req; //! # fn on_event(&mut self, _server: &mut TcpServer, event: TcpEvent) -> Option { //! # match event { TcpEvent::Message { data, .. } => Some(data), _ => None } @@ -93,7 +93,6 @@ use crate::tcp::{tcp_server::TcpServer, TcpEvent, TcpHeaders, TcpMessageLabel}; mod impls; -pub use hyli_net_macros::tcp_middleware; pub use impls::{ DropOnError, MessageOnly, MessageWithMeta, QueuedSendWithRetry, QueuedSenderMiddleware, RetryingSend, TcpInboundMessage, @@ -192,50 +191,32 @@ impl TcpResBound for T where { } -pub trait TcpInboundMiddleware +/// Unified middleware trait. Implement this trait when one type needs to handle +/// inbound mapping, outbound errors and tick-based housekeeping. +pub trait TcpMiddleware where Req: TcpReqBound, Res: TcpResBound, { type EventOut; - /// Transform or filter inbound events before they are exposed to callers. - /// Returning `None` will cause the wrapper to keep listening. fn on_event(&mut self, _server: &mut S, event: TcpEvent) -> Option where S: TcpServerLike>; -} -pub trait TcpOutboundMiddleware -where - Req: TcpReqBound, - Res: TcpResBound, -{ - /// Handle outbound send errors. The default behavior is to surface the error. - /// Implementations can enqueue retries or drop peers. fn on_send_error(&mut self, _server: &mut S, ctx: &SendErrorContext) -> SendErrorOutcome where S: TcpServerLike>, { SendErrorOutcome::Unhandled(anyhow::anyhow!(ctx.error.to_string())) } -} -pub trait TcpTickMiddleware -where - Req: TcpReqBound, - Res: TcpResBound, -{ - /// Called on each `listen_next()` iteration before waiting for events. - /// Use this to drive retry queues or housekeeping. fn on_tick(&mut self, _server: &mut S) where S: TcpServerLike>, { } - /// Optional wakeup time for the next middleware action. If present, the - /// wrapper will `select!` between the next event and this deadline. fn next_wakeup(&self) -> Option { None } @@ -391,10 +372,7 @@ where Req: TcpReqBound, Res: TcpResBound + Clone, S: TcpServerLike>, - M: TcpInboundMiddleware - + TcpOutboundMiddleware - + TcpTickMiddleware - + QueuedSenderMiddleware, + M: TcpMiddleware + QueuedSenderMiddleware, { /// Enqueue a message for ordered, retrying delivery to a specific peer. pub fn enqueue( @@ -438,9 +416,7 @@ where Req: TcpReqBound, Res: TcpResBound + Clone, S: TcpServerLike>, - M: TcpInboundMiddleware - + TcpOutboundMiddleware - + TcpTickMiddleware, + M: TcpMiddleware, { type EventOut = M::EventOut; type ConnectedClients<'a> @@ -516,9 +492,7 @@ where Req: TcpReqBound, Res: TcpResBound + Clone, S: TcpServerLike>, - M: TcpInboundMiddleware - + TcpOutboundMiddleware - + TcpTickMiddleware, + M: TcpMiddleware, { type Service = TcpServerWithMiddleware; From f86ce1ff79739f3ffc9f00a3f9c65212965ced36 Mon Sep 17 00:00:00 2001 From: Alexandre Careil Date: Wed, 18 Feb 2026 17:58:08 +0100 Subject: [PATCH 08/18] clean a bit --- crates/hyli-net/src/tcp/middleware/impls.rs | 39 +----- crates/hyli-net/src/tcp/middleware/mod.rs | 125 +------------------- crates/hyli-net/src/tcp/tcp_server.rs | 36 ++++++ src/data_availability.rs | 6 +- 4 files changed, 44 insertions(+), 162 deletions(-) diff --git a/crates/hyli-net/src/tcp/middleware/impls.rs b/crates/hyli-net/src/tcp/middleware/impls.rs index 274325928..be6948a5e 100644 --- a/crates/hyli-net/src/tcp/middleware/impls.rs +++ b/crates/hyli-net/src/tcp/middleware/impls.rs @@ -32,37 +32,12 @@ where } } -pub struct TcpInboundMessage { - pub socket_addr: String, - pub data: Req, - pub headers: TcpHeaders, -} +pub type TcpInboundMessage = (String, Req, TcpHeaders); #[derive(Default)] pub struct MessageOnly; impl TcpMiddleware for MessageOnly -where - Req: super::TcpReqBound, - Res: super::TcpResBound, -{ - type EventOut = Req; - - fn on_event(&mut self, _server: &mut S, event: TcpEvent) -> Option - where - S: TcpServerLike>, - { - match event { - TcpEvent::Message { data, .. } => Some(data), - TcpEvent::Closed { .. } | TcpEvent::Error { .. } => None, - } - } -} - -#[derive(Default)] -pub struct MessageWithMeta; - -impl TcpMiddleware for MessageWithMeta where Req: super::TcpReqBound, Res: super::TcpResBound, @@ -78,11 +53,7 @@ where socket_addr, data, headers, - } => Some(TcpInboundMessage { - socket_addr, - data, - headers, - }), + } => Some((socket_addr, data, headers)), TcpEvent::Closed { .. } | TcpEvent::Error { .. } => None, } } @@ -214,11 +185,7 @@ where socket_addr, data, headers, - } => Some(TcpInboundMessage { - socket_addr, - data, - headers, - }), + } => Some((socket_addr, data, headers)), TcpEvent::Closed { socket_addr } => { self.unregister_streaming_peer(&socket_addr); None diff --git a/crates/hyli-net/src/tcp/middleware/mod.rs b/crates/hyli-net/src/tcp/middleware/mod.rs index 3857d224f..70210f2c6 100644 --- a/crates/hyli-net/src/tcp/middleware/mod.rs +++ b/crates/hyli-net/src/tcp/middleware/mod.rs @@ -1,89 +1,3 @@ -//! TcpServer middleware helpers. -//! -//! This module provides a wrapper around `TcpServer` that preserves the -//! `listen_next()` API while allowing synchronous middleware actions -//! (drop-on-error) and listen-driven retries (send retry queue progressed -//! inside `listen_next()`). -//! -//! # Example -//! ```no_run -//! use std::time::Duration; -//! use hyli_net::tcp::{ -//! tcp_server::TcpServer, -//! middleware::{preset, TcpServerExt}, -//! }; -//! # use hyli_net::tcp::{TcpEvent, TcpMessageLabel}; -//! # use borsh::{BorshDeserialize, BorshSerialize}; -//! # -//! # #[derive(Clone, Debug, BorshSerialize, BorshDeserialize)] -//! # struct Req; -//! # impl TcpMessageLabel for Req { -//! # fn message_label(&self) -> &'static str { "Req" } -//! # } -//! # #[derive(Clone, Debug, BorshSerialize, BorshDeserialize)] -//! # struct Res; -//! # impl TcpMessageLabel for Res { -//! # fn message_label(&self) -> &'static str { "Res" } -//! # } -//! # -//! # async fn example() -> anyhow::Result<()> { -//! let inner = TcpServer::::start(0, "Example").await?; -//! let mut server = hyli_net::tcp_stack!( -//! inner, -//! preset::drop_on_error(), -//! preset::retrying_send::(10, Duration::from_millis(100)), -//! ); -//! -//! while let Some(event) = server.listen_next().await { -//! match event { -//! TcpEvent::Message { socket_addr, data, headers } => { -//! // Handle inbound message... -//! let _ = server.send(socket_addr, Res, headers); -//! } -//! TcpEvent::Closed { .. } | TcpEvent::Error { .. } => { -//! // Drop-on-error is handled by middleware. -//! } -//! } -//! } -//! # Ok(()) -//! # } -//! ``` -//! -//! You can also map events to a different output type. This example maps -//! `TcpEvent::Message` to the `Req` payload and filters out `Error/Closed`. -//! ```no_run -//! # use hyli_net::tcp::{tcp_server::TcpServer, TcpEvent, TcpMessageLabel}; -//! # use hyli_net::tcp::middleware::{TcpMiddleware, TcpServerWithMiddleware}; -//! # use borsh::{BorshDeserialize, BorshSerialize}; -//! # #[derive(Clone, Debug, BorshSerialize, BorshDeserialize)] -//! # struct Req; -//! # impl TcpMessageLabel for Req { -//! # fn message_label(&self) -> &'static str { "Req" } -//! # } -//! # #[derive(Clone, Debug, BorshSerialize, BorshDeserialize)] -//! # struct Res; -//! # impl TcpMessageLabel for Res { -//! # fn message_label(&self) -> &'static str { "Res" } -//! # } -//! # -//! # struct MessageOnly; -//! # impl TcpMiddleware for MessageOnly { -//! # type EventOut = Req; -//! # fn on_event(&mut self, _server: &mut TcpServer, event: TcpEvent) -> Option { -//! # match event { TcpEvent::Message { data, .. } => Some(data), _ => None } -//! # } -//! # } -//! # -//! # async fn example() -> anyhow::Result<()> { -//! let inner = TcpServer::::start(0, "Example").await?; -//! let mut server = TcpServerWithMiddleware::new(inner, MessageOnly); -//! while let Some(req) = server.listen_next().await { -//! // req is already the decoded payload -//! } -//! # Ok(()) -//! # } -//! ``` - use std::marker::PhantomData; use tokio::time::Instant; @@ -94,8 +8,8 @@ use crate::tcp::{tcp_server::TcpServer, TcpEvent, TcpHeaders, TcpMessageLabel}; mod impls; pub use impls::{ - DropOnError, MessageOnly, MessageWithMeta, QueuedSendWithRetry, QueuedSenderMiddleware, - RetryingSend, TcpInboundMessage, + DropOnError, MessageOnly, QueuedSendWithRetry, QueuedSenderMiddleware, RetryingSend, + TcpInboundMessage, }; pub mod preset { @@ -500,38 +414,3 @@ where TcpServerWithMiddleware::new(inner, self.middleware) } } - -impl TcpServerLike for TcpServer -where - Req: TcpReqBound, - Res: TcpResBound, -{ - type EventOut = TcpEvent; - type ConnectedClients<'a> - = crate::tcp::tcp_server::ConnectedClients<'a> - where - Self: 'a; - - async fn listen_next(&mut self) -> Option { - TcpServer::listen_next(self).await - } - - fn send(&mut self, socket_addr: String, msg: Res, headers: TcpHeaders) -> anyhow::Result<()> { - TcpServer::send(self, socket_addr, msg, headers) - } - - fn send_ref(&mut self, socket_addr: &str, msg: &Res, headers: &TcpHeaders) -> anyhow::Result<()> - where - Res: Clone, - { - TcpServer::send_ref(self, socket_addr, msg, headers) - } - - fn connected_clients(&self) -> Self::ConnectedClients<'_> { - TcpServer::connected_clients(self) - } - - fn drop_peer_stream(&mut self, peer_ip: String) { - TcpServer::drop_peer_stream(self, peer_ip) - } -} diff --git a/crates/hyli-net/src/tcp/tcp_server.rs b/crates/hyli-net/src/tcp/tcp_server.rs index f085ea076..2f7909e51 100644 --- a/crates/hyli-net/src/tcp/tcp_server.rs +++ b/crates/hyli-net/src/tcp/tcp_server.rs @@ -29,6 +29,7 @@ use hyli_turmoil_shims::collections::HashMap; use tracing::{debug, error, trace, warn}; use super::{tcp_client::TcpClient, SocketStream, TcpEvent}; +use crate::tcp::middleware::{TcpReqBound, TcpResBound, TcpServerLike}; type TcpSender = SplitSink; type TcpReceiver = SplitStream; @@ -698,6 +699,41 @@ where } } +impl TcpServerLike for TcpServer +where + Req: TcpReqBound, + Res: TcpResBound, +{ + type EventOut = TcpEvent; + type ConnectedClients<'a> + = ConnectedClients<'a> + where + Self: 'a; + + async fn listen_next(&mut self) -> Option { + TcpServer::listen_next(self).await + } + + fn send(&mut self, socket_addr: String, msg: Res, headers: TcpHeaders) -> anyhow::Result<()> { + TcpServer::send(self, socket_addr, msg, headers) + } + + fn send_ref(&mut self, socket_addr: &str, msg: &Res, headers: &TcpHeaders) -> anyhow::Result<()> + where + Res: Clone, + { + TcpServer::send_ref(self, socket_addr, msg, headers) + } + + fn connected_clients(&self) -> Self::ConnectedClients<'_> { + TcpServer::connected_clients(self) + } + + fn drop_peer_stream(&mut self, peer_ip: String) { + TcpServer::drop_peer_stream(self, peer_ip) + } +} + #[cfg(test)] pub mod tests { use std::time::Duration; diff --git a/src/data_availability.rs b/src/data_availability.rs index fb7a983af..4a71f8f00 100644 --- a/src/data_availability.rs +++ b/src/data_availability.rs @@ -9,8 +9,8 @@ use hyli_modules::telemetry::{global_meter_or_panic, Counter, Gauge, KeyValue}; use hyli_modules::{bus::SharedMessageBus, modules::Module}; use hyli_modules::{log_error, module_bus_client, module_handle_messages}; use hyli_net::tcp::middleware::{ - middleware_layer, DropOnError, QueuedSendWithRetry, TcpInboundMessage, TcpServerExt, - TcpServerLike, TcpServerWithMiddleware, + middleware_layer, DropOnError, QueuedSendWithRetry, TcpServerExt, TcpServerLike, + TcpServerWithMiddleware, }; use tokio::task::JoinHandle; @@ -558,7 +558,7 @@ impl DataAvailability { } } - Some(TcpInboundMessage { socket_addr, data, .. }) = server.listen_next() => { + Some((socket_addr, data, _headers)) = server.listen_next() => { match data { DataAvailabilityRequest::StreamFromHeight(start_height) => { _ = log_error!( From 977c4d807bdaabff7d94d626fa0d423636133c0b Mon Sep 17 00:00:00 2001 From: Alexandre Careil Date: Wed, 18 Feb 2026 18:02:18 +0100 Subject: [PATCH 09/18] more cleaning --- crates/hyli-net/src/tcp/middleware/mod.rs | 96 ----------------------- 1 file changed, 96 deletions(-) diff --git a/crates/hyli-net/src/tcp/middleware/mod.rs b/crates/hyli-net/src/tcp/middleware/mod.rs index 70210f2c6..f9faf13a7 100644 --- a/crates/hyli-net/src/tcp/middleware/mod.rs +++ b/crates/hyli-net/src/tcp/middleware/mod.rs @@ -12,44 +12,6 @@ pub use impls::{ TcpInboundMessage, }; -pub mod preset { - use std::time::Duration; - - use super::{DropOnError, QueuedSendWithRetry, RetryingSend}; - - pub fn drop_on_error() -> DropOnError { - DropOnError - } - - pub fn retrying_send(max_retries: usize, base_delay: Duration) -> RetryingSend { - RetryingSend::new(max_retries, base_delay) - } - - pub fn drop_and_retry( - max_retries: usize, - base_delay: Duration, - ) -> (DropOnError, RetryingSend) { - (drop_on_error(), retrying_send(max_retries, base_delay)) - } - - pub fn queued_send_with_retry( - max_retries: usize, - base_delay: Duration, - ) -> QueuedSendWithRetry { - QueuedSendWithRetry::new(max_retries, base_delay) - } -} - -#[macro_export] -macro_rules! tcp_stack { - ($server:expr, $($middleware:expr),+ $(,)?) => {{ - use $crate::tcp::middleware::{middleware_layer, TcpServerExt}; - let server = $server; - $(let server = server.layer(middleware_layer($middleware));)+ - server - }}; -} - pub trait Layer { type Service; fn layer(self, inner: S) -> Self::Service; @@ -179,64 +141,6 @@ pub trait TcpServerLike { } } -pub struct InboundCx<'a, Req, Res, S> -where - Req: BorshDeserialize, - S: TcpServerLike>, -{ - server: &'a mut S, - _marker: PhantomData<(Req, Res)>, -} - -impl<'a, Req, Res, S> InboundCx<'a, Req, Res, S> -where - Req: BorshDeserialize, - S: TcpServerLike>, -{ - pub fn new(server: &'a mut S) -> Self { - Self { - server, - _marker: PhantomData, - } - } - - pub fn connected(&self, socket_addr: &str) -> bool { - self.server.connected(socket_addr) - } - - pub fn drop_peer(&mut self, peer_ip: String) { - self.server.drop_peer_stream(peer_ip); - } - - pub fn send( - &mut self, - socket_addr: String, - msg: Res, - headers: TcpHeaders, - ) -> anyhow::Result<()> { - self.server.send(socket_addr, msg, headers) - } - - pub fn send_ref( - &mut self, - socket_addr: &str, - msg: &Res, - headers: &TcpHeaders, - ) -> anyhow::Result<()> - where - Res: Clone, - { - self.server.send_ref(socket_addr, msg, headers) - } - - pub fn server_mut(&mut self) -> &mut S { - self.server - } -} - -pub type OutboundCx<'a, Req, Res, S> = InboundCx<'a, Req, Res, S>; -pub type TickCx<'a, Req, Res, S> = InboundCx<'a, Req, Res, S>; - /// Tower-style layering helper for TCP servers and already-layered services. /// /// # Example From e4e9bf1eaf020de081a6f2ad77e33b1c205cfe28 Mon Sep 17 00:00:00 2001 From: Alexandre Careil Date: Wed, 18 Feb 2026 19:34:59 +0100 Subject: [PATCH 10/18] Refacto queued sender to use retrysender --- crates/hyli-net/src/tcp/middleware/impls.rs | 149 ++++++-------------- crates/hyli-net/src/tcp/middleware/mod.rs | 81 +++++------ src/data_availability.rs | 8 +- 3 files changed, 84 insertions(+), 154 deletions(-) diff --git a/crates/hyli-net/src/tcp/middleware/impls.rs b/crates/hyli-net/src/tcp/middleware/impls.rs index be6948a5e..c34d745ed 100644 --- a/crates/hyli-net/src/tcp/middleware/impls.rs +++ b/crates/hyli-net/src/tcp/middleware/impls.rs @@ -59,58 +59,37 @@ where } } -struct QueuedOutbound { - msg: Res, - headers: TcpHeaders, - retries: usize, - next_attempt_at: Instant, -} - pub struct QueuedSendWithRetry { - max_retries: usize, - base_delay: Duration, - max_per_tick: usize, + retrying: RetryingSend, streaming_peers: HashSet, - queues: std::collections::HashMap>>, _marker: PhantomData, } impl QueuedSendWithRetry { pub fn new(max_retries: usize, base_delay: Duration) -> Self { Self { - max_retries, - base_delay, - max_per_tick: 64, + retrying: RetryingSend::new(max_retries, base_delay), streaming_peers: HashSet::new(), - queues: std::collections::HashMap::new(), _marker: PhantomData, } } pub fn max_per_tick(mut self, max_per_tick: usize) -> Self { - self.max_per_tick = max_per_tick.max(1); + self.retrying = self.retrying.max_per_tick(max_per_tick); self } + pub fn enqueue_to_peer(&mut self, socket_addr: String, msg: Res, headers: TcpHeaders) { + self.retrying.enqueue_to_peer(socket_addr, msg, headers); + } + pub fn register_streaming_peer(&mut self, socket_addr: String) { self.streaming_peers.insert(socket_addr); } pub fn unregister_streaming_peer(&mut self, socket_addr: &str) { self.streaming_peers.remove(socket_addr); - self.queues.remove(socket_addr); - } - - pub fn enqueue_to_peer(&mut self, socket_addr: String, msg: Res, headers: TcpHeaders) { - self.queues - .entry(socket_addr) - .or_default() - .push_back(QueuedOutbound { - msg, - headers, - retries: 0, - next_attempt_at: Instant::now(), - }); + self.retrying.drop_peer(socket_addr); } pub fn enqueue_to_streaming_peers(&mut self, msg: Res, headers: TcpHeaders) @@ -128,7 +107,6 @@ where Req: super::TcpReqBound, Res: super::TcpResBound + Clone, { - fn enqueue_to_peer(&mut self, socket_addr: String, msg: Res, headers: TcpHeaders); fn register_streaming_peer(&mut self, socket_addr: String); fn enqueue_to_streaming_peers(&mut self, msg: Res, headers: TcpHeaders); } @@ -138,33 +116,14 @@ where Req: super::TcpReqBound, Res: super::TcpResBound + Clone, { - fn enqueue_to_peer(&mut self, socket_addr: String, msg: Res, headers: TcpHeaders) { - self.queues - .entry(socket_addr) - .or_default() - .push_back(QueuedOutbound { - msg, - headers, - retries: 0, - next_attempt_at: Instant::now(), - }); - } - fn register_streaming_peer(&mut self, socket_addr: String) { self.streaming_peers.insert(socket_addr); } fn enqueue_to_streaming_peers(&mut self, msg: Res, headers: TcpHeaders) { for peer in self.streaming_peers.clone() { - self.queues - .entry(peer) - .or_default() - .push_back(QueuedOutbound { - msg: msg.clone(), - headers: headers.clone(), - retries: 0, - next_attempt_at: Instant::now(), - }); + self.retrying + .enqueue_to_peer(peer, msg.clone(), headers.clone()); } } } @@ -197,6 +156,21 @@ where } } + fn on_send( + &mut self, + _server: &mut S, + socket_addr: String, + msg: Res, + headers: TcpHeaders, + ) -> anyhow::Result<()> + where + S: TcpServerLike>, + Res: Clone, + { + self.enqueue_to_peer(socket_addr, msg, headers); + Ok(()) + } + fn on_send_error(&mut self, server: &mut S, ctx: &SendErrorContext) -> SendErrorOutcome where S: TcpServerLike>, @@ -205,7 +179,7 @@ where self.unregister_streaming_peer(&ctx.socket_addr); return SendErrorOutcome::Unhandled(anyhow::anyhow!(ctx.error.to_string())); } - self.enqueue_to_peer( + self.retrying.enqueue_to_peer( ctx.socket_addr.clone(), ctx.msg.clone(), ctx.headers.clone(), @@ -217,62 +191,12 @@ where where S: TcpServerLike>, { - if self.queues.is_empty() { - return; - } - - let mut processed = 0usize; - let now = Instant::now(); - let peers: Vec = self.queues.keys().cloned().collect(); - - for peer in peers { - if processed >= self.max_per_tick { - break; - } - - if !server.connected(&peer) { - self.unregister_streaming_peer(&peer); - continue; - } - - let Some(queue) = self.queues.get_mut(&peer) else { - continue; - }; - - let Some(front) = queue.front_mut() else { - continue; - }; - - if front.next_attempt_at > now { - continue; - } - - match server.send(peer.clone(), front.msg.clone(), front.headers.clone()) { - Ok(()) => { - queue.pop_front(); - } - Err(_) => { - front.retries += 1; - if front.retries > self.max_retries { - server.drop_peer_stream(peer.clone()); - self.unregister_streaming_peer(&peer); - } else { - front.next_attempt_at = now + self.base_delay.mul_f64(front.retries as f64); - } - } - } - - processed += 1; - } - - self.queues.retain(|_, queue| !queue.is_empty()); + self.streaming_peers.retain(|peer| server.connected(peer)); + self.retrying.on_tick(server); } fn next_wakeup(&self) -> Option { - self.queues - .values() - .filter_map(|queue| queue.front().map(|pending| pending.next_attempt_at)) - .min() + as TcpMiddleware>::next_wakeup(&self.retrying) } } @@ -305,6 +229,21 @@ impl RetryingSend { self.max_per_tick = max_per_tick.max(1); self } + + pub fn enqueue_to_peer(&mut self, socket_addr: String, msg: Res, headers: TcpHeaders) { + self.queue.push_back(PendingSend { + socket_addr, + msg, + headers, + retries: 0, + next_attempt_at: Instant::now(), + }); + } + + pub fn drop_peer(&mut self, socket_addr: &str) { + self.queue + .retain(|pending| pending.socket_addr != socket_addr); + } } impl TcpMiddleware for RetryingSend diff --git a/crates/hyli-net/src/tcp/middleware/mod.rs b/crates/hyli-net/src/tcp/middleware/mod.rs index f9faf13a7..4f6a9916c 100644 --- a/crates/hyli-net/src/tcp/middleware/mod.rs +++ b/crates/hyli-net/src/tcp/middleware/mod.rs @@ -80,6 +80,40 @@ where where S: TcpServerLike>; + /// Handle outbound sends. Default behavior is immediate send with + /// `on_send_error` fallback. + fn on_send( + &mut self, + server: &mut S, + socket_addr: String, + msg: Res, + headers: TcpHeaders, + ) -> anyhow::Result<()> + where + S: TcpServerLike>, + Res: Clone, + { + match server.send_ref(&socket_addr, &msg, &headers) { + Ok(()) => Ok(()), + Err(error) => { + let ctx = SendErrorContext { + socket_addr, + msg, + headers, + error, + }; + match self.on_send_error(server, &ctx) { + SendErrorOutcome::Handled | SendErrorOutcome::RetryScheduled => Ok(()), + SendErrorOutcome::DropPeer => { + server.drop_peer_stream(ctx.socket_addr.clone()); + Ok(()) + } + SendErrorOutcome::Unhandled(error) => Err(error), + } + } + } + } + fn on_send_error(&mut self, _server: &mut S, ctx: &SendErrorContext) -> SendErrorOutcome where S: TcpServerLike>, @@ -192,27 +226,6 @@ where S: TcpServerLike>, M: TcpMiddleware + QueuedSenderMiddleware, { - /// Enqueue a message for ordered, retrying delivery to a specific peer. - pub fn enqueue( - &mut self, - socket_addr: String, - msg: Res, - headers: TcpHeaders, - ) -> anyhow::Result<()> { - self.middleware.enqueue_to_peer(socket_addr, msg, headers); - Ok(()) - } - - /// Immediate send through the underlying TCP server without middleware queueing. - pub fn send_now( - &mut self, - socket_addr: String, - msg: Res, - headers: TcpHeaders, - ) -> anyhow::Result<()> { - self.inner.send(socket_addr, msg, headers) - } - /// Mark a peer as a streaming subscriber. pub fn register_streaming_peer(&mut self, socket_addr: String) { self.middleware.register_streaming_peer(socket_addr); @@ -222,11 +235,6 @@ where pub fn enqueue_to_streaming_peers(&mut self, msg: Res, headers: TcpHeaders) { self.middleware.enqueue_to_streaming_peers(msg, headers); } - - /// Backward-compatible alias for `enqueue_to_streaming_peers`. - pub fn send_to_streaming_peers(&mut self, msg: Res, headers: TcpHeaders) { - self.enqueue_to_streaming_peers(msg, headers) - } } impl TcpServerLike for TcpServerWithMiddleware @@ -268,25 +276,8 @@ where } fn send(&mut self, socket_addr: String, msg: Res, headers: TcpHeaders) -> anyhow::Result<()> { - match self.inner.send_ref(&socket_addr, &msg, &headers) { - Ok(()) => Ok(()), - Err(error) => { - let ctx = SendErrorContext { - socket_addr, - msg, - headers, - error, - }; - match self.middleware.on_send_error(&mut self.inner, &ctx) { - SendErrorOutcome::Handled | SendErrorOutcome::RetryScheduled => Ok(()), - SendErrorOutcome::DropPeer => { - self.inner.drop_peer_stream(ctx.socket_addr.clone()); - Ok(()) - } - SendErrorOutcome::Unhandled(error) => Err(error), - } - } - } + self.middleware + .on_send(&mut self.inner, socket_addr, msg, headers) } fn send_ref(&mut self, socket_addr: &str, msg: &Res, headers: &TcpHeaders) -> anyhow::Result<()> diff --git a/src/data_availability.rs b/src/data_availability.rs index 4a71f8f00..773bd7d34 100644 --- a/src/data_availability.rs +++ b/src/data_availability.rs @@ -614,7 +614,7 @@ impl DataAvailability { "📦 Found block at height {}, sending to {}", block_height, socket_addr ); - if let Err(e) = server.enqueue( + if let Err(e) = server.send( socket_addr.to_string(), DataAvailabilityEvent::SignedBlock(block), vec![], @@ -630,7 +630,7 @@ impl DataAvailability { "📦 Block at height {} not found in storage, sending BlockNotFound to {}", block_height, socket_addr ); - if let Err(e) = server.enqueue( + if let Err(e) = server.send( socket_addr.to_string(), DataAvailabilityEvent::BlockNotFound(block_height), vec![], @@ -646,7 +646,7 @@ impl DataAvailability { "📦 Error retrieving block at height {}: {:#}", block_height, e ); - if let Err(e) = server.enqueue( + if let Err(e) = server.send( socket_addr.to_string(), DataAvailabilityEvent::BlockNotFound(block_height), vec![], @@ -895,7 +895,7 @@ impl DataAvailability { for hash in processed_block_hashes { match self.blocks.get(&hash) { Ok(Some(block)) => { - _ = server.enqueue( + _ = server.send( peer_ip.clone(), DataAvailabilityEvent::SignedBlock(block), vec![], From 9b60e82e7f1e0786045e596243f113a070ed5d84 Mon Sep 17 00:00:00 2001 From: Alexandre Careil Date: Thu, 19 Feb 2026 15:25:51 +0100 Subject: [PATCH 11/18] simplify --- crates/hyli-net/src/tcp/middleware/impls.rs | 144 +- crates/hyli-net/src/tcp/middleware/mod.rs | 76 +- crates/hyli-net/src/tcp/tcp_server.rs | 10 +- src/data_availability.rs | 1694 ++++++++++++++----- 4 files changed, 1309 insertions(+), 615 deletions(-) diff --git a/crates/hyli-net/src/tcp/middleware/impls.rs b/crates/hyli-net/src/tcp/middleware/impls.rs index c34d745ed..95e3ecbf7 100644 --- a/crates/hyli-net/src/tcp/middleware/impls.rs +++ b/crates/hyli-net/src/tcp/middleware/impls.rs @@ -1,5 +1,4 @@ -use std::collections::{HashSet, VecDeque}; -use std::marker::PhantomData; +use std::collections::VecDeque; use std::time::Duration; use tokio::time::Instant; @@ -59,147 +58,6 @@ where } } -pub struct QueuedSendWithRetry { - retrying: RetryingSend, - streaming_peers: HashSet, - _marker: PhantomData, -} - -impl QueuedSendWithRetry { - pub fn new(max_retries: usize, base_delay: Duration) -> Self { - Self { - retrying: RetryingSend::new(max_retries, base_delay), - streaming_peers: HashSet::new(), - _marker: PhantomData, - } - } - - pub fn max_per_tick(mut self, max_per_tick: usize) -> Self { - self.retrying = self.retrying.max_per_tick(max_per_tick); - self - } - - pub fn enqueue_to_peer(&mut self, socket_addr: String, msg: Res, headers: TcpHeaders) { - self.retrying.enqueue_to_peer(socket_addr, msg, headers); - } - - pub fn register_streaming_peer(&mut self, socket_addr: String) { - self.streaming_peers.insert(socket_addr); - } - - pub fn unregister_streaming_peer(&mut self, socket_addr: &str) { - self.streaming_peers.remove(socket_addr); - self.retrying.drop_peer(socket_addr); - } - - pub fn enqueue_to_streaming_peers(&mut self, msg: Res, headers: TcpHeaders) - where - Res: Clone, - { - for peer in self.streaming_peers.clone() { - self.enqueue_to_peer(peer, msg.clone(), headers.clone()); - } - } -} - -pub trait QueuedSenderMiddleware -where - Req: super::TcpReqBound, - Res: super::TcpResBound + Clone, -{ - fn register_streaming_peer(&mut self, socket_addr: String); - fn enqueue_to_streaming_peers(&mut self, msg: Res, headers: TcpHeaders); -} - -impl QueuedSenderMiddleware for QueuedSendWithRetry -where - Req: super::TcpReqBound, - Res: super::TcpResBound + Clone, -{ - fn register_streaming_peer(&mut self, socket_addr: String) { - self.streaming_peers.insert(socket_addr); - } - - fn enqueue_to_streaming_peers(&mut self, msg: Res, headers: TcpHeaders) { - for peer in self.streaming_peers.clone() { - self.retrying - .enqueue_to_peer(peer, msg.clone(), headers.clone()); - } - } -} - -impl TcpMiddleware for QueuedSendWithRetry -where - Req: super::TcpReqBound, - Res: super::TcpResBound + Clone, -{ - type EventOut = TcpInboundMessage; - - fn on_event(&mut self, _server: &mut S, event: TcpEvent) -> Option - where - S: TcpServerLike>, - { - match event { - TcpEvent::Message { - socket_addr, - data, - headers, - } => Some((socket_addr, data, headers)), - TcpEvent::Closed { socket_addr } => { - self.unregister_streaming_peer(&socket_addr); - None - } - TcpEvent::Error { socket_addr, .. } => { - self.unregister_streaming_peer(&socket_addr); - None - } - } - } - - fn on_send( - &mut self, - _server: &mut S, - socket_addr: String, - msg: Res, - headers: TcpHeaders, - ) -> anyhow::Result<()> - where - S: TcpServerLike>, - Res: Clone, - { - self.enqueue_to_peer(socket_addr, msg, headers); - Ok(()) - } - - fn on_send_error(&mut self, server: &mut S, ctx: &SendErrorContext) -> SendErrorOutcome - where - S: TcpServerLike>, - { - if !server.connected(&ctx.socket_addr) { - self.unregister_streaming_peer(&ctx.socket_addr); - return SendErrorOutcome::Unhandled(anyhow::anyhow!(ctx.error.to_string())); - } - self.retrying.enqueue_to_peer( - ctx.socket_addr.clone(), - ctx.msg.clone(), - ctx.headers.clone(), - ); - SendErrorOutcome::RetryScheduled - } - - fn on_tick(&mut self, server: &mut S) - where - S: TcpServerLike>, - { - self.streaming_peers.retain(|peer| server.connected(peer)); - self.retrying.on_tick(server); - } - - fn next_wakeup(&self) -> Option { - as TcpMiddleware>::next_wakeup(&self.retrying) - } -} - struct PendingSend { socket_addr: String, msg: Res, diff --git a/crates/hyli-net/src/tcp/middleware/mod.rs b/crates/hyli-net/src/tcp/middleware/mod.rs index 4f6a9916c..e3f7d0289 100644 --- a/crates/hyli-net/src/tcp/middleware/mod.rs +++ b/crates/hyli-net/src/tcp/middleware/mod.rs @@ -7,10 +7,56 @@ use crate::tcp::{tcp_server::TcpServer, TcpEvent, TcpHeaders, TcpMessageLabel}; mod impls; -pub use impls::{ - DropOnError, MessageOnly, QueuedSendWithRetry, QueuedSenderMiddleware, RetryingSend, - TcpInboundMessage, -}; +pub use impls::{DropOnError, MessageOnly, RetryingSend, TcpInboundMessage}; + +#[macro_export] +macro_rules! tcp_middleware_chain_type { + ($req:ty, $res:ty; $($mw:ty),+ $(,)?) => { + $crate::tcp_middleware_chain_type!($req, $res, $($mw),+) + }; + ($req:ty, $res:ty, $($mw:ty),+ $(,)?) => { + $crate::tcp_middleware_chain_type!( + @inner + $req, + $res, + $crate::tcp::tcp_server::TcpServer<$req, $res>; + $($mw),+ + ) + }; + ($req:ty, $res:ty, base = $base:ty, $($mw:ty),* $(,)?) => { + $crate::tcp_middleware_chain_type!(@inner $req, $res, $base; $($mw),*) + }; + (@inner $req:ty, $res:ty, $inner:ty; ) => { + $inner + }; + (@inner $req:ty, $res:ty, $inner:ty; $head:ty $(, $tail:ty)*) => { + $crate::tcp::middleware::TcpServerWithMiddleware< + $head, + $req, + $res, + $crate::tcp_middleware_chain_type!(@inner $req, $res, $inner; $($tail),*) + > + }; +} + +#[macro_export] +macro_rules! tcp_server { + ( + request: $req:ty, + response: $res:ty, + middlewares: [$($mw:ty),+ $(,)?] + $(,)? + ) => { + $crate::tcp_middleware_chain_type!($req, $res; $($mw),+) + }; + ( + request: $req:ty, + response: $res:ty + $(,)? + ) => { + $crate::tcp::tcp_server::TcpServer<$req, $res> + }; +} pub trait Layer { type Service; @@ -93,7 +139,7 @@ where S: TcpServerLike>, Res: Clone, { - match server.send_ref(&socket_addr, &msg, &headers) { + match server.send(socket_addr.clone(), msg.clone(), headers.clone()) { Ok(()) => Ok(()), Err(error) => { let ctx = SendErrorContext { @@ -219,24 +265,6 @@ where } } -impl TcpServerWithMiddleware -where - Req: TcpReqBound, - Res: TcpResBound + Clone, - S: TcpServerLike>, - M: TcpMiddleware + QueuedSenderMiddleware, -{ - /// Mark a peer as a streaming subscriber. - pub fn register_streaming_peer(&mut self, socket_addr: String) { - self.middleware.register_streaming_peer(socket_addr); - } - - /// Queue a message to all registered streaming peers. - pub fn enqueue_to_streaming_peers(&mut self, msg: Res, headers: TcpHeaders) { - self.middleware.enqueue_to_streaming_peers(msg, headers); - } -} - impl TcpServerLike for TcpServerWithMiddleware where Req: TcpReqBound, @@ -284,7 +312,7 @@ where where Res: Clone, { - self.inner.send_ref(socket_addr, msg, headers) + self.send(socket_addr.to_string(), msg.clone(), headers.clone()) } fn connected_clients(&self) -> Self::ConnectedClients<'_> { diff --git a/crates/hyli-net/src/tcp/tcp_server.rs b/crates/hyli-net/src/tcp/tcp_server.rs index 2f7909e51..404444128 100644 --- a/crates/hyli-net/src/tcp/tcp_server.rs +++ b/crates/hyli-net/src/tcp/tcp_server.rs @@ -36,6 +36,12 @@ type TcpReceiver = SplitStream; pub struct ConnectedClients<'a>(std::collections::hash_map::Keys<'a, String, SocketStream>); +impl<'a> ConnectedClients<'a> { + pub fn len(&self) -> usize { + self.0.len() + } +} + impl<'a> Iterator for ConnectedClients<'a> { type Item = &'a String; @@ -862,7 +868,7 @@ pub mod tests { let client2_addr = server .connected_clients() .cloned() - .rfind(|addr| addr != &client1_addr) + .find(|addr| addr != &client1_addr) .unwrap(); server.raw_send_parallel( @@ -912,7 +918,7 @@ pub mod tests { let client2_addr = server .connected_clients() .cloned() - .rfind(|addr| addr != &client1_addr) + .find(|addr| addr != &client1_addr) .unwrap(); _ = server.send( diff --git a/src/data_availability.rs b/src/data_availability.rs index 773bd7d34..145b232a9 100644 --- a/src/data_availability.rs +++ b/src/data_availability.rs @@ -5,12 +5,10 @@ use hyli_modules::modules::data_availability::blocks_fjall::Blocks; use hyli_modules::utils::da_codec::DataAvailabilityServer as RawDataAvailabilityServer; //use hyli_modules::modules::data_availability::blocks_memory::Blocks; use hyli_modules::modules::da_listener::{DaStreamPoll, SignedDaStream}; -use hyli_modules::telemetry::{global_meter_or_panic, Counter, Gauge, KeyValue}; use hyli_modules::{bus::SharedMessageBus, modules::Module}; use hyli_modules::{log_error, module_bus_client, module_handle_messages}; use hyli_net::tcp::middleware::{ - middleware_layer, DropOnError, QueuedSendWithRetry, TcpServerExt, TcpServerLike, - TcpServerWithMiddleware, + middleware_layer, DropOnError, MessageOnly, RetryingSend, TcpServerExt, TcpServerLike, }; use tokio::task::JoinHandle; @@ -20,25 +18,38 @@ use crate::{ genesis::GenesisEvent, model::*, p2p::network::{OutboundMessage, PeerEvent}, - utils::{conf::SharedConf, rng::deterministic_rng}, + utils::conf::SharedConf, }; use anyhow::{Context, Result}; use core::str; -use rand::seq::IndexedRandom; use std::{ - collections::{BTreeSet, VecDeque}, + collections::{BTreeSet, HashMap, VecDeque}, time::Duration, }; +use strum_macros::AsRefStr; +use tokio::task::JoinSet; use tracing::{debug, error, info, trace, warn}; use crate::model::SharedRunContext; -type DataAvailabilityServer = TcpServerWithMiddleware< - QueuedSendWithRetry, - DataAvailabilityRequest, - DataAvailabilityEvent, - TcpServerWithMiddleware, ->; +type DaServerStack = hyli_net::tcp_server!( + request: DataAvailabilityRequest, + response: DataAvailabilityEvent, + middlewares: [ + MessageOnly, + RetryingSend, + DropOnError, + ] +); + +fn with_da_middlewares(server: RawDataAvailabilityServer) -> DaServerStack { + server + .layer(middleware_layer(DropOnError)) + .layer(middleware_layer( + RetryingSend::new(10, Duration::from_millis(100)).max_per_tick(256), + )) + .layer(middleware_layer(MessageOnly)) +} impl Module for DataAvailability { type Context = SharedRunContext; @@ -55,20 +66,25 @@ impl Module for DataAvailability { let catchup_policy = if ctx.config.consensus.solo { None } else { - Some(DaCatchupPolicy { - floor: if ctx.config.run_fast_catchup { - ctx.start_height.and_then(|start_height| { - // Avoid fast catchup reexecution - if highest_block < start_height { - Some(start_height + 1) - } else { - None - } - }) - } else { - None - }, - backfill: ctx.config.fast_catchup_backfill, + let floor = if ctx.config.run_fast_catchup { + ctx.start_height.and_then(|start_height| { + // Avoid fast catchup reexecution + if highest_block < start_height { + Some(start_height + 1) + } else { + None + } + }) + } else { + None + }; + Some(DaCatchupPolicy::Regular { + floor, + ceiling: None, + backfill_enabled: ctx.config.run_fast_catchup + && ctx.config.fast_catchup_backfill + && floor.is_some(), + backfill_start: None, }) }; @@ -76,13 +92,14 @@ impl Module for DataAvailability { "📦 DataAvailability module built with policy {:?}", catchup_policy ); - Ok(DataAvailability { config: ctx.config.clone(), bus, blocks, buffered_signed_blocks: BTreeSet::new(), catchupper: DaCatchupper::new(catchup_policy, ctx.config.da_max_frame_length), + allow_peer_catchup: false, + peer_send_queues: HashMap::new(), }) } @@ -113,287 +130,412 @@ pub struct DataAvailability { buffered_signed_blocks: BTreeSet, catchupper: DaCatchupper, -} - -/// Catchup configuration for the Data Availability module. -#[derive(Default, Debug, Clone)] -struct DaCatchupPolicy { - floor: Option, - backfill: bool, -} + // Gate peer-triggered catchup until genesis outcome is known. + allow_peer_catchup: bool, -#[derive(Debug, Clone)] -struct DaCatchupMetrics { - start: Counter, - restart: Counter, - timeout: Counter, - stream_closed: Counter, - start_height: Gauge, -} - -impl Default for DaCatchupMetrics { - fn default() -> Self { - Self::global() - } + // Track blocks to send to each streaming peer (ensures ordering) + peer_send_queues: HashMap>, } -impl DaCatchupMetrics { - pub fn global() -> DaCatchupMetrics { - let my_meter = global_meter_or_panic(); - DaCatchupMetrics { - start: my_meter.u64_counter("da_catchup_start").build(), - restart: my_meter.u64_counter("da_catchup_restart").build(), - timeout: my_meter.u64_counter("da_catchup_timeout").build(), - stream_closed: my_meter.u64_counter("da_catchup_stream_closed").build(), - start_height: my_meter.u64_gauge("da_catchup_start_height").build(), - } - } - - fn start(&self, peer: &str, height: u64) { - let labels = [KeyValue::new("peer", peer.to_string())]; - self.start.add(1, &labels); - self.start_height.record(height, &labels); - } - - fn restart(&self, peer: &str, height: u64) { - let labels = [KeyValue::new("peer", peer.to_string())]; - self.restart.add(1, &labels); - self.start_height.record(height, &labels); - } - - fn timeout(&self, peer: &str) { - self.timeout - .add(1, &[KeyValue::new("peer", peer.to_string())]); - } - - fn stream_closed(&self, peer: &str) { - self.stream_closed - .add(1, &[KeyValue::new("peer", peer.to_string())]); - } +#[derive(Debug, Clone, AsRefStr)] +#[strum(serialize_all = "kebab-case")] +enum DaCatchupPolicy { + Regular { + floor: Option, + ceiling: Option, + backfill_enabled: bool, + backfill_start: Option, + }, + BackfillPending { + ceiling: BlockHeight, + }, + Backfill { + start: BlockHeight, + ceiling: BlockHeight, + }, } -#[derive(Debug, Default)] +#[derive(Debug)] struct DaCatchupper { policy: Option, - status: Option<(tokio::task::JoinHandle>, BlockHeight)>, - backfill_start_height: Option, + task: Option>>, + last_height: Option, + sender: tokio::sync::mpsc::Sender, + receiver: Option>, pub peers: Vec, - pub stop_height: Option, da_max_frame_length: usize, - metrics: DaCatchupMetrics, + restart_attempts: usize, } impl DaCatchupper { pub fn new(policy: Option, da_max_frame_length: usize) -> Self { + let (sender, receiver) = tokio::sync::mpsc::channel::(100); DaCatchupper { policy, - status: None, - backfill_start_height: None, + task: None, + last_height: None, + sender, + receiver: Some(receiver), peers: vec![], da_max_frame_length, - stop_height: None, - metrics: DaCatchupMetrics::global(), + restart_attempts: 0, } } + pub fn take_receiver(&mut self) -> Option> { + self.receiver.take() + } + pub fn is_fast_catchup_initial_block(&self, height: &BlockHeight) -> bool { matches!( self.policy, - Some(DaCatchupPolicy { floor: Some(floor), .. }) if height == &floor + Some(DaCatchupPolicy::Regular { + floor: Some(floor), + .. + }) if height == &floor ) } - pub fn need_to_tick(&self) -> bool { - self.policy.as_ref().is_some_and(|p| p.backfill) || self.status.is_some() - } - #[cfg(test)] pub fn stop_task(&mut self) { - if let Some((task, _)) = &mut self.status { + if let Some(task) = &mut self.task { task.abort(); - self.status = None; + self.task = None; } } - pub fn choose_random_peer(&self) -> Vec { - let Some(primary) = self.peers.choose(&mut deterministic_rng()).cloned() else { - return vec![]; - }; - let mut ordered = Vec::with_capacity(self.peers.len()); - ordered.push(primary.clone()); - for peer in &self.peers { - if peer != &primary { - ordered.push(peer.clone()); + pub fn ensure_started(&mut self, from_height: BlockHeight) -> anyhow::Result<()> { + self.ensure_task_running(Some(from_height)) + } + + fn max_restart_attempts() -> usize { + std::env::var("HYLI_DA_CATCHUP_MAX_RESTARTS") + .ok() + .and_then(|v| v.parse::().ok()) + .unwrap_or(12) + } + + pub fn add_peer_and_maybe_restart(&mut self, peer: String) -> anyhow::Result { + if self.peers.contains(&peer) { + return Ok(false); + } + self.peers.push(peer); + + if self.task.is_some() { + let restart_height = self.last_height.unwrap_or(BlockHeight(0)); + info!( + "Catchup peer set changed, restarting task from height {}", + restart_height + ); + if let Some(task) = &mut self.task { + task.abort(); } + self.task = None; + self.restart_attempts = 0; + self.ensure_task_running(Some(restart_height))?; } - ordered - } - pub fn init_catchup( - &mut self, - from_height: BlockHeight, - sender: &tokio::sync::mpsc::Sender, - ) -> anyhow::Result<()> { - let mut start_height = from_height; + Ok(true) + } - if let Some(DaCatchupPolicy { - floor: Some(floor), .. - }) = &self.policy - { - start_height = *floor; + fn mode_start_height( + policy: &DaCatchupPolicy, + requested_start: Option, + ) -> Option { + match policy { + DaCatchupPolicy::Regular { + floor: Some(floor), .. + } => Some(*floor), + DaCatchupPolicy::Regular { floor: None, .. } => requested_start, + DaCatchupPolicy::BackfillPending { .. } => None, + DaCatchupPolicy::Backfill { start, .. } => Some(*start), } + } - self.catchup_from(start_height, sender) + fn mode_ceiling(policy: &DaCatchupPolicy) -> Option { + match policy { + DaCatchupPolicy::Regular { ceiling, .. } => *ceiling, + DaCatchupPolicy::BackfillPending { .. } => None, + DaCatchupPolicy::Backfill { ceiling, .. } => Some(*ceiling), + } } - /// Start catchup workflow based on the current policy - pub fn catchup_from( - &mut self, - from_height: BlockHeight, - sender: &tokio::sync::mpsc::Sender, - ) -> anyhow::Result<()> { - if self.policy.is_none() { - debug!("No catchup policy set, stopping catchup task"); + fn ensure_task_running(&mut self, requested_start: Option) -> anyhow::Result<()> { + if self.task.is_some() { + trace!("Catchup task already running, skipping spawn"); return Ok(()); } + let sender = self.sender.clone(); - if self.status.is_some() { - debug!("Catchup is already in progress, no need to start a new task"); + let Some(policy) = self.policy.as_ref() else { + debug!("No catchup policy configured, skipping catchup start"); return Ok(()); - } + }; + let Some(from_height) = Self::mode_start_height(policy, requested_start) else { + debug!( + "Catchup mode {} has no start height yet, waiting", + policy.as_ref() + ); + return Ok(()); + }; + self.spawn_for_mode(from_height, sender) + } - if self.stop_height.is_some_and(|height| height <= from_height) { - debug!("Catchup is already done, no need to start a new task"); + fn spawn_for_mode( + &mut self, + from_height: BlockHeight, + sender: tokio::sync::mpsc::Sender, + ) -> anyhow::Result<()> { + let Some(policy) = self.policy.as_ref() else { return Ok(()); + }; + let target_height = Self::mode_ceiling(policy); + let mode_name = policy.as_ref(); + if let Some(height) = target_height { + if height <= from_height { + debug!( + "Skipping {} catchup spawn: empty range (from={}, to={})", + mode_name, from_height, height + ); + return self.complete_mode_if_reached(height); + } } - let peers = self.choose_random_peer(); + let peers = self.peers.clone(); if peers.is_empty() { - info!("choose_random_peer returned no peers"); + info!("No peers available for catchup"); return Ok(()); } - #[expect(clippy::unwrap_used, reason = "gated above")] - let peer = peers.first().unwrap(); - debug!( - "Starting catchup from height {} to {:?} on peer {}", - from_height, self.stop_height, peer + info!( + "Starting {} catchup from height {} to {:?}", + mode_name, from_height, target_height ); - self.status = Some(( - Self::start_task( - peers, - self.da_max_frame_length, - from_height, - sender.clone(), - self.metrics.clone(), - ), + self.task = Some(Self::spawn_stream_task( + peers, + self.da_max_frame_length, from_height, + sender, )); + self.last_height = Some(from_height); Ok(()) } - /// Try transition the catchup state based on the current status and policy. - pub fn manage_catchup( - &mut self, - processed_height: BlockHeight, - sender: &tokio::sync::mpsc::Sender, - ) -> anyhow::Result<()> { - if self.policy.is_none() { - debug!("No catchup policy set, skipping catchup"); - return Ok(()); + fn transition_after_regular_done(&mut self) { + let Some(DaCatchupPolicy::Regular { + floor, + backfill_enabled, + backfill_start, + .. + }) = self.policy.clone() + else { + self.policy = None; + return; }; - if self.status.is_none() { - if let Some(policy) = &mut self.policy { - // In case status is None, we check if we need to start a new catchup task up to the floor height - if policy.backfill && policy.floor.is_some() { - if let Some(start_height) = self.backfill_start_height { - policy.backfill = false; // Disable backfill after the first catchup - self.stop_height = policy.floor; // Set stop height to the floor if backfill is enabled - - debug!( - "Starting backfill catchup from height {} to {:?}", - start_height, policy.floor + debug!( + "Regular catchup completion: floor={:?}, backfill_enabled={}, backfill_start={:?}", + floor, backfill_enabled, backfill_start + ); + self.policy = None; + if backfill_enabled { + if let Some(floor) = floor { + if let Some(start) = backfill_start { + if start < floor { + info!( + "Transitioning to backfill mode: start={}, ceiling={}", + start, floor + ); + self.policy = Some(DaCatchupPolicy::Backfill { + start, + ceiling: floor, + }); + } else { + info!( + "Skipping backfill: discovered start {} is not below floor {}", + start, floor ); - - self.catchup_from(start_height, sender)?; } } else { - trace!("Catchup is already done"); + info!( + "Transitioning to backfill-pending mode at ceiling {}", + floor + ); + self.policy = Some(DaCatchupPolicy::BackfillPending { ceiling: floor }); } + } else { + debug!("Backfill is enabled but regular floor is unknown, no transition"); } + } else { + debug!("Backfill disabled, clearing catchup policy"); + } + } + fn complete_mode_if_reached(&mut self, processed_height: BlockHeight) -> anyhow::Result<()> { + let Some(policy) = self.policy.as_ref() else { + return Ok(()); + }; + let Some(ceiling) = Self::mode_ceiling(policy) else { return Ok(()); }; + if processed_height < ceiling { + return Ok(()); + } - let peers = self.choose_random_peer(); - if peers.is_empty() { - info!("choose_random_peer returned no peers"); + if let Some(task) = &mut self.task { + task.abort(); + } + self.task = None; + match policy { + DaCatchupPolicy::Regular { .. } => { + info!( + "Regular catchup done at height {}, evaluating backfill", + processed_height + ); + self.transition_after_regular_done(); + } + DaCatchupPolicy::BackfillPending { .. } => { + debug!("Backfill is pending first-hole discovery"); + } + DaCatchupPolicy::Backfill { .. } => { + info!("Backfill catchup done at height {}", processed_height); + self.policy = None; + } + } + self.ensure_task_running(None) + } + + pub fn on_first_hole_discovered(&mut self, hole: Option) { + match &mut self.policy { + Some(DaCatchupPolicy::Regular { backfill_start, .. }) => { + debug!( + "First-hole discovery during regular mode: hole={:?}, existing_backfill_start={:?}", + hole, backfill_start + ); + *backfill_start = backfill_start.or(hole); + } + Some(DaCatchupPolicy::BackfillPending { ceiling }) => { + debug!( + "First-hole discovery during backfill-pending: hole={:?}, ceiling={}", + hole, ceiling + ); + if let Some(start) = hole { + if start < *ceiling { + info!( + "Resolved backfill-pending -> backfill (start={}, ceiling={})", + start, ceiling + ); + self.policy = Some(DaCatchupPolicy::Backfill { + start, + ceiling: *ceiling, + }); + } else { + info!( + "Dropping catchup policy: first hole {} is not below pending ceiling {}", + start, ceiling + ); + self.policy = None; + } + } else { + info!("Dropping catchup policy: no first hole found for pending backfill"); + self.policy = None; + } + } + _ => { + debug!( + "Ignoring first-hole discovery for non-catchup state: hole={:?}", + hole + ); + } + } + } + + pub fn on_catchup_progress(&mut self, processed_height: BlockHeight) -> anyhow::Result<()> { + if self.policy.is_none() { return Ok(()); } - #[expect(clippy::unwrap_used, reason = "gated above")] - let peer = peers.first().unwrap(); - let Some((task, old_height)) = &mut self.status else { - unreachable!("Status was already checked"); + if let Some(last_height) = &mut self.last_height { + *last_height = processed_height.max(*last_height); + } + self.restart_attempts = 0; + self.complete_mode_if_reached(processed_height) + } + + pub fn on_tick(&mut self) -> anyhow::Result<()> { + if self.policy.is_none() { + return Ok(()); + } + self.ensure_task_running(None)?; + let Some(task) = &self.task else { + return Ok(()); }; + if !task.is_finished() { + return Ok(()); + } + let restart_height = self.last_height.unwrap_or(BlockHeight(0)); + self.restart_attempts += 1; + let max_restart_attempts = Self::max_restart_attempts(); + if self.restart_attempts > max_restart_attempts { + self.task = None; + return Err(anyhow::anyhow!( + "Catchup failed after {} restarts (last height {}). Aborting catchup.", + self.restart_attempts, + restart_height + )); + } - if self - .stop_height - .is_some_and(|height| height <= processed_height) - { - info!( - "Catchup task finished, last processed height {}", - processed_height - ); - task.abort(); - self.status = None; - } else if task.is_finished() { - info!( - "Catchup task finished, but catchup is not done yet, restarting from height {}", - processed_height - ); - let from = processed_height.max(*old_height); + info!( + "Catchup task finished before reaching target, restarting from height {} (attempt {}/{})", + restart_height, + self.restart_attempts, + max_restart_attempts + ); + self.task = None; + self.spawn_for_mode(restart_height, self.sender.clone()) + } - self.metrics.restart(peer, from.0); - let new_task = Self::start_task( - peers, - self.da_max_frame_length, - from, - sender.clone(), - self.metrics.clone(), - ); - self.status = Some((new_task, from)); + pub fn on_mempool_started_building(&mut self, height: BlockHeight) -> anyhow::Result<()> { + if let Some(DaCatchupPolicy::Regular { + floor, + ceiling, + backfill_enabled, + backfill_start: _, + }) = &mut self.policy + { + if ceiling.is_none() { + *ceiling = Some(height); + info!( + "Bounded regular catchup at ceiling {} (floor={:?}, backfill_enabled={})", + height, floor, backfill_enabled + ); + } else { + debug!( + "Ignoring started-building event at {}: regular ceiling already set to {:?}", + height, ceiling + ); + } } else { debug!( - "Catchup task is still running, last processed height {}", - processed_height + "Ignoring started-building event at {}: catchup is not in regular mode", + height ); - *old_height = processed_height; } - + if let Some(progress) = self.last_height { + self.complete_mode_if_reached(progress)?; + } Ok(()) } - fn start_task( + fn spawn_stream_task( peers: Vec, da_max_frame_length: usize, start_height: BlockHeight, sender: tokio::sync::mpsc::Sender, - metrics: DaCatchupMetrics, ) -> JoinHandle> { - let peer_label = peers - .first() - .cloned() - .unwrap_or_else(|| "unknown".to_string()); - info!( - "Starting catchup from height {} on peer {}", - start_height, peer_label - ); - - metrics.start(&peer_label, start_height.0); + info!("Starting catchup from height {}", start_height); tokio::spawn(async move { let timeout_duration = std::env::var("HYLI_DA_SLEEP_TIMEOUT") @@ -411,7 +553,7 @@ impl DaCatchupper { timeout_duration, ); log_error!( - stream.start_client().await, + stream.start_client_with_metrics().await, "Error occurred setting up the DA listener" )?; @@ -419,16 +561,17 @@ impl DaCatchupper { match stream.listen_next().await? { DaStreamPoll::Timeout => { warn!("Timeout expired while waiting for block."); - metrics.timeout(&peer_label); + stream.reconnect("timeout").await?; } DaStreamPoll::StreamClosed => { - metrics.stream_closed(&peer_label); + tokio::time::sleep(Duration::from_secs(1)).await; + stream.reconnect("stream_closed").await?; } DaStreamPoll::Event(event) => match event { DataAvailabilityEvent::SignedBlock(block) => { let blocks = stream.on_signed_block(block).await?; for block in blocks { - info!( + debug!( "📦 Received block (height {}) from stream", block.consensus_proposal.slot ); @@ -452,6 +595,12 @@ impl DaCatchupper { } } +impl Default for DaCatchupper { + fn default() -> Self { + Self::new(None, 0) + } +} + impl DataAvailability { pub fn start_scanning_for_first_hole( &self, @@ -461,7 +610,11 @@ impl DataAvailability { let (first_hole_sender, first_hole_receiver) = tokio::sync::mpsc::channel::>(10); - if let Some(DaCatchupPolicy { backfill: true, .. }) = self.catchupper.policy { + if let Some(DaCatchupPolicy::Regular { + backfill_enabled: true, + .. + }) = self.catchupper.policy + { // Start scanning local storage for first hole, if any _ = tokio::task::spawn(async move { loop { @@ -488,31 +641,35 @@ impl DataAvailability { self.config.da_server_port ); - let inner_server = RawDataAvailabilityServer::start_with_opts( - self.config.da_server_port, - Some(self.config.da_max_frame_length), - format!("DAServer-{}", self.config.id.clone()).as_str(), - ) - .await?; - let mut server = inner_server - .layer(middleware_layer(DropOnError)) - .layer(middleware_layer( - QueuedSendWithRetry::new(10, Duration::from_millis(100)).max_per_tick(256), - )); + let mut server = with_da_middlewares( + RawDataAvailabilityServer::start_with_opts( + self.config.da_server_port, + Some(self.config.da_max_frame_length), + format!("DAServer-{}", self.config.id.clone()).as_str(), + ) + .await?, + ); - let (catchup_block_sender, mut catchup_block_receiver) = - tokio::sync::mpsc::channel::(100); + let mut catchup_block_receiver = self + .catchupper + .take_receiver() + .ok_or_else(|| anyhow::anyhow!("Catchup receiver already taken"))?; let mut first_hole_receiver = self.start_scanning_for_first_hole(); + // Used to send blocks to clients (indexers/peers) + // This is a JoinSet of tuples containing: + // - The peer IP address to send the blocks to + // - The number of retries for sending the blocks + let mut catchup_joinset: JoinSet<(String, usize)> = tokio::task::JoinSet::new(); let mut catchup_task_checker_ticker = - tokio::time::interval(std::time::Duration::from_millis(5000)); + tokio::time::interval(std::time::Duration::from_secs(5)); let mut storage_metrics_ticker = tokio::time::interval(std::time::Duration::from_secs(30)); module_handle_messages! { on_self self, listen evt => { - _ = log_error!(self.handle_mempool_event(evt, &mut server, &catchup_block_sender).await, "Handling Mempool Event"); + _ = log_error!(self.handle_mempool_event(evt, &mut server, &mut catchup_joinset).await, "Handling Mempool Event"); } listen evt => { @@ -520,41 +677,44 @@ impl DataAvailability { } listen cmd => { - if let GenesisEvent::GenesisBlock(signed_block) = cmd { - debug!("🌱 Genesis block received with validators {:?}", signed_block.consensus_proposal.staking_actions.clone()); - _ = log_error!(self.handle_signed_block(signed_block, &mut server).await.context("Handling Genesis block"), "Handling GenesisBlock Event"); - } - else { - _ = log_error!( - self.catchupper.init_catchup( - self.blocks.highest(), - &catchup_block_sender, - ), - "Init catchup on new peer" - ); + match cmd { + GenesisEvent::GenesisBlock(signed_block) => { + debug!("🌱 Genesis block received with validators {:?}", signed_block.consensus_proposal.staking_actions.clone()); + _ = log_error!(self.handle_signed_block(signed_block, &mut server, &mut catchup_joinset).await.context("Handling Genesis block"), "Handling GenesisBlock Event"); + } + GenesisEvent::NoGenesis => { + self.allow_peer_catchup = true; + _ = log_error!( + self.catchupper.ensure_started( + self.blocks.highest(), + ), + "Init catchup after NoGenesis" + ); + } } } listen PeerEvent::NewPeer { da_address, .. } => { - self.catchupper.peers.push(da_address.clone()); - info!("New peer {}", da_address); - _ = log_error!( - self.catchupper.init_catchup( - self.blocks.highest(), - &catchup_block_sender, - ), - "Init catchup on new peer" - ); + let added = self.catchupper.add_peer_and_maybe_restart(da_address.clone())?; + if added { + info!("New peer {}", da_address); + } else { + debug!("Known peer announced again: {}", da_address); + } + if self.allow_peer_catchup { + self.catchupper.ensure_started(self.blocks.highest())?; + } else { + debug!("Skipping catchup init on new peer while genesis path is unresolved"); + } } - _ = catchup_task_checker_ticker.tick(), if self.catchupper.need_to_tick() => { - let highest_block = self.blocks.highest(); - _ = log_error!(self.catchupper.manage_catchup(highest_block, &catchup_block_sender), "Catchup transition after tick"); + _ = catchup_task_checker_ticker.tick(), if self.catchupper.policy.is_some() => { + self.catchupper.on_tick()?; } Some(streamed_block) = catchup_block_receiver.recv() => { - if let Some(height) = self.handle_signed_block(streamed_block, &mut server).await { - _ = log_error!(self.catchupper.manage_catchup(height, &catchup_block_sender), "Catchup transition after streamed block"); + if let Some(height) = self.handle_signed_block(streamed_block, &mut server, &mut catchup_joinset).await { + _ = log_error!(self.catchupper.on_catchup_progress(height), "Catchup transition after streamed block"); } } @@ -564,10 +724,10 @@ impl DataAvailability { _ = log_error!( self.start_streaming_to_peer( start_height, + &mut catchup_joinset, + &socket_addr, &mut server, - &socket_addr - ) - .await, + ).await, "Starting streaming to peer" ); } @@ -580,12 +740,28 @@ impl DataAvailability { } } + // Send one block to a peer as part of "catchup", + // once we have sent all blocks the peer is presumably synchronised. + Some(Ok((peer_ip, retries))) = catchup_joinset.join_next() => { + + #[cfg(test)] + tokio::time::sleep(std::time::Duration::from_millis(100)).await; + + _ = log_error!( + self.handle_send_next_block_to_peer( + peer_ip.clone(), + retries, + &mut catchup_joinset, + &mut server + ).await, + "Send next block to peer" + ); + } + Some(hole) = first_hole_receiver.recv() => { info!("Setting backfill start height as {:?}", &hole); - self.catchupper.backfill_start_height = hole; - let highest_block = self.blocks.highest(); - _ = log_error!(self.catchupper.manage_catchup(highest_block, &catchup_block_sender), "Catchup transition after tick"); - + self.catchupper.on_first_hole_discovered(hole); + self.catchupper.on_tick()?; } _ = storage_metrics_ticker.tick() => { @@ -596,12 +772,92 @@ impl DataAvailability { Ok(()) } - async fn handle_block_request( + async fn handle_send_next_block_to_peer( + &mut self, + peer_ip: String, + retries: usize, + catchup_joinset: &mut JoinSet<(String, usize)>, + server: &mut S, + ) -> Result<()> + where + S: TcpServerLike, + { + if !server.connected(&peer_ip) { + debug!("Peer {} disconnected, removing from send queues", peer_ip); + self.peer_send_queues.remove(&peer_ip); + return Ok(()); + } + + if retries > 10 { + warn!( + "Failed to send block, too many retries for peer {}", + &peer_ip + ); + server.drop_peer_stream(peer_ip.clone()); + self.peer_send_queues.remove(&peer_ip); + return Ok(()); + } + + // Get next block from this peer's queue + let hash = match self.peer_send_queues.get_mut(&peer_ip) { + Some(queue) => match queue.pop_front() { + Some(h) => h, + None => { + // Queue is empty - peer is caught up and waiting for new blocks + // Keep them in the map but don't spawn a new task yet + debug!("Peer {} caught up, waiting for new blocks", peer_ip); + return Ok(()); + } + }, + None => { + debug!("Peer {} not in send queues", peer_ip); + return Ok(()); + } + }; + + debug!("📡 Sending block {} to peer {}", &hash, &peer_ip); + if let Ok(Some(signed_block)) = self.blocks.get(&hash) { + // Errors will be handled when sending new blocks, ignore here. + match server.send( + peer_ip.clone(), + DataAvailabilityEvent::SignedBlock(signed_block), + vec![], + ) { + Ok(()) => { + // Successfully sent, continue with next block + catchup_joinset.spawn(async move { (peer_ip, 0) }); + } + Err(_) => { + // Retry sending the same block (put it back at front of queue) + if let Some(queue) = self.peer_send_queues.get_mut(&peer_ip) { + queue.push_front(hash); + } + catchup_joinset.spawn(async move { + tokio::time::sleep(Duration::from_millis(100 * (retries as u64))).await; + (peer_ip, retries + 1) + }); + } + } + } else { + error!( + "Block {} not found in storage while sending to peer {}. Should not happen", + &hash, &peer_ip + ); + // Continue anyway with next block + catchup_joinset.spawn(async move { (peer_ip, 0) }); + } + Ok(()) + } + + async fn handle_block_request( &mut self, block_height: BlockHeight, socket_addr: &str, - server: &mut DataAvailabilityServer, - ) -> Result<()> { + server: &mut S, + ) -> Result<()> + where + S: TcpServerLike, + { debug!( "📦 Received block request for height {} from {}", block_height, socket_addr @@ -614,18 +870,22 @@ impl DataAvailability { "📦 Found block at height {}, sending to {}", block_height, socket_addr ); + // Send immediately - this is inserted next in the send queue if let Err(e) = server.send( socket_addr.to_string(), DataAvailabilityEvent::SignedBlock(block), vec![], ) { warn!( - "📦 Error while responding to block request at height {} for {}: {:#}.", + "📦 Error while responding to block request at height {} for {}: {:#}. Dropping socket.", block_height, socket_addr, e ); + server.drop_peer_stream(socket_addr.to_string()); + return Ok(()); } } Ok(None) => { + // Block not in storage - this is a gap error!( "📦 Block at height {} not found in storage, sending BlockNotFound to {}", block_height, socket_addr @@ -636,9 +896,11 @@ impl DataAvailability { vec![], ) { warn!( - "📦 Error while responding BlockNotFound at height {} for {}: {:#}.", + "📦 Error while responding BlockNotFound at height {} for {}: {:#}. Dropping socket.", block_height, socket_addr, e ); + server.drop_peer_stream(socket_addr.to_string()); + return Ok(()); } } Err(e) => { @@ -652,9 +914,11 @@ impl DataAvailability { vec![], ) { warn!( - "📦 Error while responding BlockNotFound at height {} for {}: {:#}.", + "📦 Error while responding BlockNotFound at height {} for {}: {:#}. Dropping socket.", block_height, socket_addr, e ); + server.drop_peer_stream(socket_addr.to_string()); + return Ok(()); } } } @@ -662,58 +926,61 @@ impl DataAvailability { Ok(()) } - async fn handle_mempool_event( + async fn handle_mempool_event( &mut self, evt: MempoolBlockEvent, - tcp_server: &mut DataAvailabilityServer, - sender: &tokio::sync::mpsc::Sender, - ) -> Result<()> { + tcp_server: &mut S, + catchup_joinset: &mut JoinSet<(String, usize)>, + ) -> Result<()> + where + S: TcpServerLike, + { match evt { MempoolBlockEvent::BuiltSignedBlock(signed_block) => { debug!( "📦 Received built block (height {}) from Mempool", signed_block.height() ); - if let Some(height) = self.handle_signed_block(signed_block, tcp_server).await { - self.catchupper.manage_catchup(height, sender)?; - } + // Mempool-produced blocks are local tip updates, not catchup-stream progress. + // Feeding them into catchup progress can prematurely complete backfill. + _ = self + .handle_signed_block(signed_block, tcp_server, catchup_joinset) + .await; } MempoolBlockEvent::StartedBuildingBlocks(height) => { debug!( "Received started building block (at height {}) from Mempool", height ); - self.catchupper.stop_height = Some(height); + self.catchupper.on_mempool_started_building(height)?; } } Ok(()) } - async fn handle_mempool_status_event( - &mut self, - evt: MempoolStatusEvent, - tcp_server: &mut DataAvailabilityServer, - ) { - let errors = TcpServerLike::broadcast( - tcp_server, - DataAvailabilityEvent::MempoolStatusEvent(evt), - vec![], - ); + async fn handle_mempool_status_event(&mut self, evt: MempoolStatusEvent, tcp_server: &mut S) + where + S: TcpServerLike, + { + let errors = tcp_server.broadcast(DataAvailabilityEvent::MempoolStatusEvent(evt), vec![]); + for (peer, error) in errors { - warn!( - "Error while queueing mempool status event for {}: {:#}", - peer, error - ); + warn!("Error while broadcasting mempool status event {:#}", error); + tcp_server.drop_peer_stream(peer.clone()); } } /// if handled, returns the highest height of the processed blocks - async fn handle_signed_block( + async fn handle_signed_block( &mut self, block: SignedBlock, - tcp_server: &mut DataAvailabilityServer, - ) -> Option { + tcp_server: &mut S, + catchup_joinset: &mut JoinSet<(String, usize)>, + ) -> Option + where + S: TcpServerLike, + { let hash = block.hashed(); // if new block is already handled, ignore it if self.blocks.contains(&hash) { @@ -755,12 +1022,13 @@ impl DataAvailability { } else { // store block _ = log_error!( - self.add_processed_block(block.clone(), tcp_server).await, + self.add_processed_block(block.clone(), tcp_server, catchup_joinset) + .await, "Adding processed block" ); } - let highest_processed_height = self.pop_buffer(hash, tcp_server).await; + let highest_processed_height = self.pop_buffer(hash, tcp_server, catchup_joinset).await; _ = log_error!(self.blocks.persist(), "Persisting blocks"); let height = block.height(); @@ -769,11 +1037,15 @@ impl DataAvailability { } /// Returns the highest height of the processed blocks - async fn pop_buffer( + async fn pop_buffer( &mut self, mut last_block_hash: ConsensusProposalHash, - tcp_server: &mut DataAvailabilityServer, - ) -> Option { + tcp_server: &mut S, + catchup_joinset: &mut JoinSet<(String, usize)>, + ) -> Option + where + S: TcpServerLike, + { let mut res = None; // Iterative loop to avoid stack overflows @@ -796,7 +1068,7 @@ impl DataAvailability { let height = first_buffered.height(); if self - .add_processed_block(first_buffered.clone(), tcp_server) + .add_processed_block(first_buffered.clone(), tcp_server, catchup_joinset) .await .is_ok() { @@ -814,7 +1086,9 @@ impl DataAvailability { trace!("Block {} {}: {:#?}", block.height(), block.hashed(), block); - if block.height().0.is_multiple_of(10) || block.has_txs() { + let height = block.height().0; + let info_log_interval = if height < 1_000 { 10 } else { 1_000 }; + if height.is_multiple_of(info_log_interval) { info!( "new block #{} 0x{} with {} txs", block.height(), @@ -840,15 +1114,42 @@ impl DataAvailability { Ok(()) } - async fn add_processed_block( + async fn add_processed_block( &mut self, block: SignedBlock, - tcp_server: &mut DataAvailabilityServer, - ) -> anyhow::Result<()> { + _tcp_server: &mut S, + catchup_joinset: &mut JoinSet<(String, usize)>, + ) -> anyhow::Result<()> + where + S: TcpServerLike, + { self.store_block(&block)?; - tcp_server - .enqueue_to_streaming_peers(DataAvailabilityEvent::SignedBlock(block.clone()), vec![]); + let block_hash = block.hashed(); + + // Add new block to all streaming peer queues to ensure ordering + // (instead of broadcasting which can cause out-of-order delivery) + for (peer, queue) in self.peer_send_queues.iter_mut() { + let was_empty = queue.is_empty(); + queue.push_back(block_hash.clone()); + + // If queue was empty (peer was caught up), restart their send task + if was_empty { + debug!( + "Restarting send task for caught-up peer {} with new block {}", + peer, block_hash + ); + let peer_clone = peer.clone(); + catchup_joinset.spawn(async move { (peer_clone, 0) }); + } else { + debug!( + "Appending block {} to queue for peer {} (queue size: {})", + block_hash, + peer, + queue.len() + ); + } + } // Send the block to NodeState for processing _ = log_error!( @@ -861,28 +1162,67 @@ impl DataAvailability { Ok(()) } - async fn start_streaming_to_peer( + async fn start_streaming_to_peer( &mut self, start_height: BlockHeight, - server: &mut DataAvailabilityServer, + catchup_joinset: &mut JoinSet<(String, usize)>, peer_ip: &str, - ) -> Result<()> { + server: &mut S, + ) -> Result<()> + where + S: TcpServerLike, + { let range_start = std::time::Instant::now(); + let highest = self + .blocks + .last() + .map_or(start_height, |block| block.height()); + // Collect all blocks from start_height to current highest let processed_block_hashes: VecDeque<_> = self .blocks - .range( - start_height, - self.blocks - .last() - .map_or(start_height, |block| block.height()) - + 1, - ) + .range(start_height, highest + 1) .filter_map(|item| item.ok()) .collect(); self.blocks .record_op("range_collect", "by_height", range_start.elapsed()); + let expected = highest.0.saturating_sub(start_height.0).saturating_add(1); + // If requester starts beyond our current tip, they are already caught up: + // the valid stream response is an empty queue (wait for future blocks), not BlockNotFound. + let expected = if start_height > highest { 0 } else { expected }; + if processed_block_hashes.len() as u64 != expected { + let first_missing = (start_height.0..=highest.0) + .find(|height| { + self.blocks + .get_by_height(BlockHeight(*height)) + .map(|block| block.is_none()) + .unwrap_or(true) + }) + .map(BlockHeight) + .unwrap_or(start_height); + + info!( + "Rejecting stream for peer {}: local gap detected at height {} while serving [{}..={}]", + peer_ip, first_missing, start_height, highest + ); + + if let Err(e) = server.send( + peer_ip.to_string(), + DataAvailabilityEvent::BlockNotFound(first_missing), + vec![], + ) { + warn!( + "Error sending BlockNotFound at height {} to {}: {:#}", + first_missing, peer_ip, e + ); + } + + server.drop_peer_stream(peer_ip.to_string()); + self.peer_send_queues.remove(peer_ip); + return Ok(()); + } + info!( "Starting stream to peer {} from height {} ({} blocks queued)", peer_ip, @@ -890,31 +1230,13 @@ impl DataAvailability { processed_block_hashes.len() ); - let peer_ip = peer_ip.to_string(); - server.register_streaming_peer(peer_ip.clone()); - for hash in processed_block_hashes { - match self.blocks.get(&hash) { - Ok(Some(block)) => { - _ = server.send( - peer_ip.clone(), - DataAvailabilityEvent::SignedBlock(block), - vec![], - ); - } - Ok(None) => { - warn!( - "Missing block {} while starting stream to {}", - hash, peer_ip - ); - } - Err(e) => { - warn!( - "Error loading block {} while starting stream to {}: {:#}", - hash, peer_ip, e - ); - } - } - } + // Store queue for this peer - new blocks will be appended here + let peer_ip_string = peer_ip.to_string(); + self.peer_send_queues + .insert(peer_ip_string.clone(), processed_block_hashes); + + // Start the send task for this peer + catchup_joinset.spawn(async move { (peer_ip_string, 0) }); Ok(()) } @@ -926,7 +1248,7 @@ pub mod tests { use std::{collections::HashMap, time::Duration}; use super::module_bus_client; - use super::Blocks; + use super::{Blocks, RawDataAvailabilityServer}; use crate::data_availability::DaCatchupPolicy; use crate::{ bus::BusClientSender, @@ -939,12 +1261,9 @@ pub mod tests { use hyli_modules::node_state::module::NodeStateBusClient; use hyli_modules::node_state::NodeState; use hyli_modules::utils::da_codec::DataAvailabilityClient; - use hyli_modules::utils::da_codec::DataAvailabilityServer as RawDataAvailabilityServer; - use hyli_net::tcp::middleware::{ - middleware_layer, DropOnError, QueuedSendWithRetry, TcpServerExt, - }; - use hyli_net::tcp::TcpEvent; + use hyli_net::tcp::middleware::TcpServerLike; use staking::state::Staking; + use tokio::task::JoinSet; struct DataAvailabilityTestCtx { pub node_state_bus: NodeStateBusClient, @@ -952,15 +1271,6 @@ pub mod tests { pub node_state: NodeState, } - async fn make_da_server(port: u16, name: &str) -> super::DataAvailabilityServer { - let inner = RawDataAvailabilityServer::start(port, name).await.unwrap(); - inner - .layer(middleware_layer(DropOnError)) - .layer(middleware_layer( - QueuedSendWithRetry::new(10, Duration::from_millis(100)).max_per_tick(256), - )) - } - impl DataAvailabilityTestCtx { pub async fn new(shared_bus: crate::bus::SharedMessageBus) -> Self { let path = tempfile::tempdir().unwrap().keep(); @@ -982,6 +1292,8 @@ pub mod tests { blocks, buffered_signed_blocks: Default::default(), catchupper: Default::default(), + allow_peer_catchup: false, + peer_send_queues: HashMap::new(), }; DataAvailabilityTestCtx { @@ -991,12 +1303,15 @@ pub mod tests { } } - pub async fn handle_signed_block( - &mut self, - block: SignedBlock, - tcp_server: &mut super::DataAvailabilityServer, - ) { - self.da.handle_signed_block(block.clone(), tcp_server).await; + pub async fn handle_signed_block(&mut self, block: SignedBlock, tcp_server: &mut S) + where + S: TcpServerLike, + { + let mut catchup_joinset: JoinSet<(String, usize)> = JoinSet::new(); + _ = self + .da + .handle_signed_block(block.clone(), tcp_server, &mut catchup_joinset) + .await; let block_hash = block.hashed(); let Ok(full_block) = self.node_state.handle_signed_block(block) else { tracing::warn!("Error while handling signed block {}", block_hash); @@ -1029,8 +1344,13 @@ pub mod tests { let tmpdir = tempfile::tempdir().unwrap().keep(); let blocks = Blocks::new(&tmpdir).unwrap(); - let mut server = make_da_server(7898, "DaServer").await; - + let port = find_available_port().await; + let mut server = super::with_da_middlewares( + RawDataAvailabilityServer::start(port, "DaServer") + .await + .unwrap(), + ); + let bus = super::DABusClient::new_from_bus(crate::bus::SharedMessageBus::new()).await; let mut da = super::DataAvailability { config: Default::default(), @@ -1038,6 +1358,8 @@ pub mod tests { blocks, buffered_signed_blocks: Default::default(), catchupper: Default::default(), + allow_peer_catchup: false, + peer_send_queues: HashMap::new(), }; let mut block = SignedBlock::default(); let mut blocks = vec![]; @@ -1047,14 +1369,20 @@ pub mod tests { block.consensus_proposal.slot = i; } blocks.reverse(); + let mut catchup_joinset: JoinSet<(String, usize)> = JoinSet::new(); for block in blocks { if block.height().0 == 0 { assert_eq!( - da.handle_signed_block(block, &mut server).await, + da.handle_signed_block(block, &mut server, &mut catchup_joinset) + .await, Some(BlockHeight(9998)) ); } else { - assert_eq!(da.handle_signed_block(block, &mut server).await, None); + assert_eq!( + da.handle_signed_block(block, &mut server, &mut catchup_joinset) + .await, + None + ); } } } @@ -1084,6 +1412,8 @@ pub mod tests { blocks, buffered_signed_blocks: Default::default(), catchupper: Default::default(), + allow_peer_catchup: false, + peer_send_queues: HashMap::new(), }; let mut block = SignedBlock::default(); @@ -1185,9 +1515,11 @@ pub mod tests { #[test_log::test(tokio::test)] async fn test_da_many_clients_only_last_connected() { let port = find_available_port().await; - let mut server = RawDataAvailabilityServer::start(port, "DaServer") - .await - .unwrap(); + let mut server = super::with_da_middlewares( + RawDataAvailabilityServer::start(port, "DaServer") + .await + .unwrap(), + ); let client_count = 5usize; let mut clients = Vec::with_capacity(client_count); @@ -1210,23 +1542,17 @@ pub mod tests { .unwrap() .unwrap(); - match event { - TcpEvent::Message { - socket_addr, data, .. - } => { - assert_eq!( - data, - DataAvailabilityRequest::StreamFromHeight(BlockHeight(i as u64)) - ); - assert!( - server.connected(&socket_addr), - "Server should track connected client {}", - socket_addr - ); - addr_by_idx.insert(i, socket_addr); - } - other => panic!("Expected Message event, got {other:?}"), - } + let (socket_addr, data, _headers) = event; + assert_eq!( + data, + DataAvailabilityRequest::StreamFromHeight(BlockHeight(i as u64)) + ); + assert!( + server.connected(&socket_addr), + "Server should track connected client {}", + socket_addr + ); + addr_by_idx.insert(i, socket_addr); clients.push(client); } @@ -1245,37 +1571,20 @@ pub mod tests { if tokio::time::Instant::now() >= deadline { panic!("Expected client {} to be dropped", dropped_addr); } - if let Ok(Some( - TcpEvent::Closed { socket_addr } | TcpEvent::Error { socket_addr, .. }, - )) = tokio::time::timeout(Duration::from_millis(200), server.listen_next()).await - { - if socket_addr == dropped_addr { - server.drop_peer_stream(socket_addr); - } - } + _ = tokio::time::timeout(Duration::from_millis(200), server.listen_next()).await; } } let deadline = tokio::time::Instant::now() + Duration::from_secs(2); loop { - let connected_clients: Vec = server.connected_clients().cloned().collect(); - if connected_clients.len() == 1 && server.connected(&last_addr) { + if server.connected_clients().count() == 1 && server.connected(&last_addr) { break; } if tokio::time::Instant::now() >= deadline { - panic!( - "Expected only last client connected, got {:?}", - connected_clients - ); - } - if let Ok(Some( - TcpEvent::Closed { socket_addr } | TcpEvent::Error { socket_addr, .. }, - )) = tokio::time::timeout(Duration::from_millis(200), server.listen_next()).await - { - if socket_addr != last_addr { - server.drop_peer_stream(socket_addr); - } + let peers: Vec<_> = server.connected_clients().cloned().collect(); + panic!("Expected only last client connected, got {:?}", peers); } + _ = tokio::time::timeout(Duration::from_millis(200), server.listen_next()).await; } } @@ -1284,13 +1593,20 @@ pub mod tests { let sender_global_bus = crate::bus::SharedMessageBus::new(); let mut block_sender = TestBusClient::new_from_bus(sender_global_bus.new_handle()).await; let mut da_sender = DataAvailabilityTestCtx::new(sender_global_bus).await; - let mut server = make_da_server(7890, "DaServer").await; + let port = find_available_port().await; + let mut server = super::with_da_middlewares( + RawDataAvailabilityServer::start(port, "DaServer") + .await + .unwrap(), + ); let receiver_global_bus = crate::bus::SharedMessageBus::new(); let mut da_receiver = DataAvailabilityTestCtx::new(receiver_global_bus).await; - da_receiver.da.catchupper.policy = Some(DaCatchupPolicy { + da_receiver.da.catchupper.policy = Some(DaCatchupPolicy::Regular { floor: None, - backfill: false, + ceiling: None, + backfill_enabled: false, + backfill_start: None, }); da_receiver.da.catchupper.da_max_frame_length = da_sender.da.config.da_max_frame_length; @@ -1321,14 +1637,18 @@ pub mod tests { tokio::time::sleep(std::time::Duration::from_millis(100)).await; // Setup done - let (tx, mut rx) = tokio::sync::mpsc::channel(200); + let mut rx = da_receiver + .da + .catchupper + .take_receiver() + .expect("catchup receiver should be available"); da_receiver .da .catchupper .peers .push(da_sender_address.clone()); - _ = da_receiver.da.catchupper.catchup_from(BlockHeight(0), &tx); + _ = da_receiver.da.catchupper.ensure_started(BlockHeight(0)); // Waiting a bit to push the block ten in the middle of all other 1..9 blocks tokio::time::sleep(Duration::from_millis(200)).await; @@ -1339,6 +1659,10 @@ pub mod tests { da_receiver .handle_signed_block(streamed_block.clone(), &mut server) .await; + let _ = da_receiver + .da + .catchupper + .on_catchup_progress(streamed_block.height()); received_blocks.push(streamed_block); if received_blocks.len() == 11 { break; @@ -1374,6 +1698,10 @@ pub mod tests { da_receiver .handle_signed_block(streamed_block.clone(), &mut server) .await; + let _ = da_receiver + .da + .catchupper + .on_catchup_progress(streamed_block.height()); received_blocks.push(streamed_block); if received_blocks.len() == 15 { break; @@ -1409,7 +1737,7 @@ pub mod tests { da_receiver .da .catchupper - .init_catchup(BlockHeight(15), &tx) + .ensure_started(BlockHeight(15)) .expect("Error while asking for catchup blocks"); let mut received_blocks = vec![]; @@ -1429,13 +1757,20 @@ pub mod tests { let sender_global_bus = crate::bus::SharedMessageBus::new(); let mut block_sender = TestBusClient::new_from_bus(sender_global_bus.new_handle()).await; let mut da_sender = DataAvailabilityTestCtx::new(sender_global_bus).await; - let mut server = make_da_server(7891, "DaServer").await; + let port = find_available_port().await; + let mut server = super::with_da_middlewares( + RawDataAvailabilityServer::start(port, "DaServer") + .await + .unwrap(), + ); let receiver_global_bus = crate::bus::SharedMessageBus::new(); let mut da_receiver = DataAvailabilityTestCtx::new(receiver_global_bus).await; - da_receiver.da.catchupper.policy = Some(DaCatchupPolicy { + da_receiver.da.catchupper.policy = Some(DaCatchupPolicy::Regular { floor: Some(BlockHeight(8)), - backfill: true, + ceiling: None, + backfill_enabled: true, + backfill_start: None, }); da_receiver.da.catchupper.da_max_frame_length = da_sender.da.config.da_max_frame_length; @@ -1466,15 +1801,19 @@ pub mod tests { tokio::time::sleep(std::time::Duration::from_millis(100)).await; // Setup done - let (tx, mut rx) = tokio::sync::mpsc::channel(200); + let mut rx = da_receiver + .da + .catchupper + .take_receiver() + .expect("catchup receiver should be available"); da_receiver .da .catchupper .peers .push(da_sender_address.clone()); - // first init catchup should get last blocks after the floor = 8 - _ = da_receiver.da.catchupper.init_catchup(BlockHeight(0), &tx); + // Initial fast catchup starts at floor and remains unbounded until mempool starts building. + _ = da_receiver.da.catchupper.ensure_started(BlockHeight(0)); _ = block_sender.send(MempoolBlockEvent::BuiltSignedBlock(block_ten.clone())); let mut received_blocks = vec![]; @@ -1482,6 +1821,10 @@ pub mod tests { da_receiver .handle_signed_block(streamed_block.clone(), &mut server) .await; + let _ = da_receiver + .da + .catchupper + .on_catchup_progress(streamed_block.height()); received_blocks.push(streamed_block); if received_blocks.len() == 3 { break; @@ -1494,34 +1837,25 @@ pub mod tests { assert!(received_blocks.iter().any(|b| b.height().0 == i)); } - // Stop the task - da_receiver.da.catchupper.stop_height = Some(BlockHeight(10)); - _ = da_receiver - .da - .catchupper - .manage_catchup(BlockHeight(10), &tx); - - // should not start backfill - _ = da_receiver + // Mempool starts producing blocks; this sets regular mode ceiling and transitions to backfill. + da_receiver .da .catchupper - .manage_catchup(BlockHeight(10), &tx); - - assert!(rx.try_recv().is_err()); - - da_receiver.da.catchupper.backfill_start_height = Some(BlockHeight(5)); - - // should start backfill from height 5 + .on_first_hole_discovered(Some(BlockHeight(5))); _ = da_receiver .da .catchupper - .manage_catchup(BlockHeight(10), &tx); + .on_mempool_started_building(BlockHeight(10)); let mut received_blocks = vec![]; while let Some(streamed_block) = rx.recv().await { da_receiver .handle_signed_block(streamed_block.clone(), &mut server) .await; + let _ = da_receiver + .da + .catchupper + .on_catchup_progress(streamed_block.height()); received_blocks.push(streamed_block); if received_blocks.len() == 3 { break; @@ -1563,6 +1897,8 @@ pub mod tests { blocks: blocks_storage, buffered_signed_blocks: Default::default(), catchupper: Default::default(), + allow_peer_catchup: false, + peer_send_queues: HashMap::new(), }; // Start DA server @@ -1680,4 +2016,470 @@ pub mod tests { client.close().await.unwrap(); } + + #[test_log::test(tokio::test)] + async fn test_stream_rejected_when_start_height_missing() { + let tmpdir = tempfile::tempdir().unwrap().keep(); + let mut blocks_storage = Blocks::new(&tmpdir).unwrap(); + + let global_bus = crate::bus::SharedMessageBus::new(); + let bus = super::DABusClient::new_from_bus(global_bus.new_handle()).await; + + let mut config: Conf = Conf::new(vec![], None, None).unwrap(); + config.da_server_port = find_available_port().await; + config.da_public_address = format!("127.0.0.1:{}", config.da_server_port); + + // Only store high blocks (10k+), leaving low heights unavailable. + let mut block = SignedBlock::default(); + block.consensus_proposal.slot = 10_000; + blocks_storage.put(block.clone()).unwrap(); + for i in 10_001..10_006 { + block.consensus_proposal.parent_hash = block.hashed(); + block.consensus_proposal.slot = i; + blocks_storage.put(block.clone()).unwrap(); + } + + let mut da = super::DataAvailability { + config: config.clone().into(), + bus, + blocks: blocks_storage, + buffered_signed_blocks: Default::default(), + catchupper: Default::default(), + allow_peer_catchup: false, + peer_send_queues: HashMap::new(), + }; + + tokio::spawn(async move { + da.start().await.unwrap(); + }); + + tokio::time::sleep(std::time::Duration::from_millis(100)).await; + + let mut client = + DataAvailabilityClient::connect("client_id", config.da_public_address.clone()) + .await + .unwrap(); + + client + .send(DataAvailabilityRequest::StreamFromHeight(BlockHeight(10))) + .await + .unwrap(); + + let first_event = tokio::time::timeout(Duration::from_secs(1), client.recv()) + .await + .expect("Timed out waiting for stream rejection"); + if let Some(event) = first_event { + assert_eq!( + event, + DataAvailabilityEvent::BlockNotFound(BlockHeight(10)), + "Only BlockNotFound should be emitted before stream closes" + ); + let second_event = tokio::time::timeout(Duration::from_secs(1), client.recv()) + .await + .expect("Timed out waiting for stream closure"); + assert!( + second_event.is_none(), + "Stream should close after rejecting invalid start height" + ); + } + } + + #[test_log::test(tokio::test)] + async fn test_stream_rejected_when_requested_range_has_gap() { + let tmpdir = tempfile::tempdir().unwrap().keep(); + let mut blocks_storage = Blocks::new(&tmpdir).unwrap(); + + let global_bus = crate::bus::SharedMessageBus::new(); + let bus = super::DABusClient::new_from_bus(global_bus.new_handle()).await; + + let mut config: Conf = Conf::new(vec![], None, None).unwrap(); + config.da_server_port = find_available_port().await; + config.da_public_address = format!("127.0.0.1:{}", config.da_server_port); + + // Build sparse heights 0,1,3,4 so start height exists but range is not contiguous. + let mut block = SignedBlock::default(); + block.consensus_proposal.slot = 0; + blocks_storage.put(block.clone()).unwrap(); + block.consensus_proposal.parent_hash = block.hashed(); + block.consensus_proposal.slot = 1; + blocks_storage.put(block.clone()).unwrap(); + block.consensus_proposal.parent_hash = block.hashed(); + block.consensus_proposal.slot = 3; + blocks_storage.put(block.clone()).unwrap(); + block.consensus_proposal.parent_hash = block.hashed(); + block.consensus_proposal.slot = 4; + blocks_storage.put(block.clone()).unwrap(); + + let mut da = super::DataAvailability { + config: config.clone().into(), + bus, + blocks: blocks_storage, + buffered_signed_blocks: Default::default(), + catchupper: Default::default(), + allow_peer_catchup: false, + peer_send_queues: HashMap::new(), + }; + + tokio::spawn(async move { + da.start().await.unwrap(); + }); + + tokio::time::sleep(std::time::Duration::from_millis(100)).await; + + let mut client = + DataAvailabilityClient::connect("client_id", config.da_public_address.clone()) + .await + .unwrap(); + + client + .send(DataAvailabilityRequest::StreamFromHeight(BlockHeight(0))) + .await + .unwrap(); + + let first_event = tokio::time::timeout(Duration::from_secs(1), client.recv()) + .await + .expect("Timed out waiting for stream rejection"); + if let Some(event) = first_event { + assert_eq!( + event, + DataAvailabilityEvent::BlockNotFound(BlockHeight(2)), + "Only BlockNotFound should be emitted before stream closes" + ); + let second_event = tokio::time::timeout(Duration::from_secs(1), client.recv()) + .await + .expect("Timed out waiting for stream closure"); + assert!( + second_event.is_none(), + "Stream should close after rejecting sparse range" + ); + } + } + + #[test_log::test(tokio::test)] + async fn test_regular_mode_stops_only_when_reaching_ceiling() { + let (tx, _rx) = tokio::sync::mpsc::channel::(1); + let mut catchupper = super::DaCatchupper { + policy: Some(DaCatchupPolicy::Regular { + floor: Some(BlockHeight(180)), + ceiling: Some(BlockHeight(75_001)), + backfill_enabled: false, + backfill_start: Some(BlockHeight(0)), + }), + task: Some(tokio::spawn(async { + futures::future::pending::>().await + })), + last_height: Some(BlockHeight(180)), + sender: tx.clone(), + receiver: None, + peers: vec!["127.0.0.1:12345".to_string()], + da_max_frame_length: 1024, + restart_attempts: 0, + }; + + // Simulate stale external progress lower than floor. + catchupper + .on_catchup_progress(BlockHeight(180)) + .expect("on_catchup_progress should succeed"); + assert!( + catchupper.task.is_some(), + "catchup should keep running until it reaches ceiling" + ); + + // Explicit catchup progress to the ceiling should complete it. + catchupper + .on_catchup_progress(BlockHeight(75_001)) + .expect("on_catchup_progress should succeed"); + assert!( + catchupper.task.is_none(), + "catchup should stop once progress reaches ceiling" + ); + assert!( + catchupper.policy.is_none(), + "regular mode should clear policy when done and no backfill is configured" + ); + } + + #[test_log::test(tokio::test)] + async fn test_started_building_only_sets_regular_ceiling() { + let (tx, _rx) = tokio::sync::mpsc::channel::(1); + let mut catchupper = super::DaCatchupper { + policy: Some(DaCatchupPolicy::Regular { + floor: Some(BlockHeight(180)), + ceiling: None, + backfill_enabled: true, + backfill_start: None, + }), + task: Some(tokio::spawn(async { + futures::future::pending::>().await + })), + last_height: Some(BlockHeight(180)), + sender: tx.clone(), + receiver: None, + peers: vec!["127.0.0.1:12345".to_string()], + da_max_frame_length: 1024, + restart_attempts: 0, + }; + + catchupper + .on_mempool_started_building(BlockHeight(75_001)) + .expect("on_mempool_started_building should succeed"); + + assert!( + catchupper.task.is_some(), + "started-building event should not complete regular catchup" + ); + assert!( + matches!( + catchupper.policy, + Some(DaCatchupPolicy::Regular { + ceiling: Some(BlockHeight(75_001)), + backfill_start: None, + .. + }) + ), + "started-building event should only bound regular mode" + ); + } + + #[test_log::test(tokio::test)] + async fn test_started_building_transitions_when_regular_range_is_empty_without_progress() { + let (tx, _rx) = tokio::sync::mpsc::channel::(1); + let mut catchupper = super::DaCatchupper { + policy: Some(DaCatchupPolicy::Regular { + floor: Some(BlockHeight(8)), + ceiling: None, + backfill_enabled: true, + backfill_start: Some(BlockHeight(5)), + }), + task: None, + last_height: None, + sender: tx.clone(), + receiver: None, + peers: vec![], + da_max_frame_length: 1024, + restart_attempts: 0, + }; + + catchupper + .on_mempool_started_building(BlockHeight(8)) + .expect("on_mempool_started_building should succeed"); + catchupper.on_tick().expect("on_tick should succeed"); + + assert!( + matches!( + catchupper.policy, + Some(DaCatchupPolicy::Backfill { + start: BlockHeight(5), + ceiling: BlockHeight(8) + }) + ), + "empty regular range should immediately transition to backfill mode" + ); + assert!( + catchupper.task.is_none(), + "backfill task should not start without peers" + ); + } + + #[test_log::test(tokio::test)] + async fn test_started_building_completes_regular_when_progress_already_reached_ceiling() { + let (tx, _rx) = tokio::sync::mpsc::channel::(1); + let mut catchupper = super::DaCatchupper { + policy: Some(DaCatchupPolicy::Regular { + floor: Some(BlockHeight(8)), + ceiling: None, + backfill_enabled: true, + backfill_start: Some(BlockHeight(5)), + }), + task: Some(tokio::spawn(async { + futures::future::pending::>().await + })), + last_height: Some(BlockHeight(10)), + sender: tx.clone(), + receiver: None, + peers: vec![], + da_max_frame_length: 1024, + restart_attempts: 0, + }; + + catchupper + .on_mempool_started_building(BlockHeight(10)) + .expect("on_mempool_started_building should succeed"); + + assert!( + catchupper.task.is_none(), + "regular catchup should complete when known progress already reached new ceiling" + ); + assert!( + matches!( + catchupper.policy, + Some(DaCatchupPolicy::Backfill { + start: BlockHeight(5), + ceiling: BlockHeight(8) + }) + ), + "regular mode should transition to backfill after completion" + ); + } + + #[test_log::test(tokio::test)] + async fn test_late_first_hole_transitions_pending_to_backfill() { + let (tx, _rx) = tokio::sync::mpsc::channel::(1); + let mut catchupper = super::DaCatchupper { + policy: Some(DaCatchupPolicy::Regular { + floor: Some(BlockHeight(10)), + ceiling: Some(BlockHeight(10)), + backfill_enabled: true, + backfill_start: None, + }), + task: Some(tokio::spawn(async { + futures::future::pending::>().await + })), + last_height: Some(BlockHeight(10)), + sender: tx.clone(), + receiver: None, + peers: vec![], + da_max_frame_length: 1024, + restart_attempts: 0, + }; + + catchupper + .on_catchup_progress(BlockHeight(10)) + .expect("regular mode completion should succeed"); + assert!( + matches!( + catchupper.policy, + Some(DaCatchupPolicy::BackfillPending { + ceiling: BlockHeight(10) + }) + ), + "regular completion without known hole should wait for hole discovery" + ); + + catchupper.on_first_hole_discovered(Some(BlockHeight(5))); + assert!( + matches!( + catchupper.policy, + Some(DaCatchupPolicy::Backfill { + start: BlockHeight(5), + ceiling: BlockHeight(10) + }) + ), + "late hole discovery should transition pending state into backfill mode" + ); + } + + #[test_log::test(tokio::test)] + async fn test_on_tick_restarts_finished_task() { + let (tx, _rx) = tokio::sync::mpsc::channel::(1); + let mut catchupper = super::DaCatchupper { + policy: Some(DaCatchupPolicy::Regular { + floor: Some(BlockHeight(180)), + ceiling: Some(BlockHeight(75_001)), + backfill_enabled: false, + backfill_start: Some(BlockHeight(0)), + }), + task: Some(tokio::spawn(async { Ok(()) })), + last_height: Some(BlockHeight(180)), + sender: tx.clone(), + receiver: None, + peers: vec!["127.0.0.1:12345".to_string()], + da_max_frame_length: 1024, + restart_attempts: 0, + }; + + tokio::time::sleep(Duration::from_millis(5)).await; + catchupper.on_tick().expect("on_tick should succeed"); + let restart_height = catchupper + .last_height + .expect("catchup should have restarted"); + assert_eq!( + restart_height, + BlockHeight(180), + "catchup should restart from the last known catchup height" + ); + } + + #[test_log::test(tokio::test)] + async fn test_on_tick_keeps_backfill_mode_when_task_not_started() { + let (tx, _rx) = tokio::sync::mpsc::channel::(1); + let mut catchupper = super::DaCatchupper { + policy: Some(DaCatchupPolicy::Backfill { + start: BlockHeight(180), + ceiling: BlockHeight(75_001), + }), + task: None, + last_height: None, + sender: tx.clone(), + receiver: None, + peers: vec![], + da_max_frame_length: 1024, + restart_attempts: 0, + }; + + catchupper.on_tick().expect("on_tick should succeed"); + assert!( + matches!(catchupper.policy, Some(DaCatchupPolicy::Backfill { .. })), + "backfill mode should remain active until a task can start" + ); + assert!( + catchupper.task.is_none(), + "no task should start when no peers are available" + ); + } + + #[test_log::test(tokio::test)] + async fn test_stream_accepts_when_peer_is_already_up_to_date() { + let tmpdir = tempfile::tempdir().unwrap().keep(); + let mut blocks_storage = Blocks::new(&tmpdir).unwrap(); + + let global_bus = crate::bus::SharedMessageBus::new(); + let bus = super::DABusClient::new_from_bus(global_bus.new_handle()).await; + + let mut config: Conf = Conf::new(vec![], None, None).unwrap(); + config.da_server_port = find_available_port().await; + config.da_public_address = format!("127.0.0.1:{}", config.da_server_port); + + let mut block = SignedBlock::default(); + blocks_storage.put(block.clone()).unwrap(); + for i in 1..5 { + block.consensus_proposal.parent_hash = block.hashed(); + block.consensus_proposal.slot = i; + blocks_storage.put(block.clone()).unwrap(); + } + + let mut da = super::DataAvailability { + config: config.clone().into(), + bus, + blocks: blocks_storage, + buffered_signed_blocks: Default::default(), + catchupper: Default::default(), + allow_peer_catchup: false, + peer_send_queues: HashMap::new(), + }; + + tokio::spawn(async move { + da.start().await.unwrap(); + }); + + tokio::time::sleep(std::time::Duration::from_millis(100)).await; + + let mut client = + DataAvailabilityClient::connect("client_id", config.da_public_address.clone()) + .await + .unwrap(); + + // Request starts at the next block after current highest (peer is already up to date). + client + .send(DataAvailabilityRequest::StreamFromHeight(BlockHeight(5))) + .await + .unwrap(); + + // Should not be rejected with BlockNotFound. + let next = tokio::time::timeout(Duration::from_millis(300), client.recv()).await; + assert!( + next.is_err(), + "Did not expect an immediate stream event/rejection for up-to-date peer" + ); + } } From fed53e7a179b85a6edfafd3d745cb3243e4fd9b6 Mon Sep 17 00:00:00 2001 From: Alexandre Careil Date: Thu, 19 Feb 2026 15:43:26 +0100 Subject: [PATCH 12/18] more simplify --- crates/hyli-net/src/tcp.rs | 43 ++++ crates/hyli-net/src/tcp/middleware/impls.rs | 4 +- crates/hyli-net/src/tcp/middleware/mod.rs | 45 +--- crates/hyli-net/src/tcp/p2p_server.rs | 6 +- crates/hyli-net/src/tcp/tcp_client.rs | 2 +- crates/hyli-net/src/tcp/tcp_server.rs | 221 +++++++++----------- src/data_availability.rs | 5 +- src/tcp_server.rs | 2 +- 8 files changed, 151 insertions(+), 177 deletions(-) diff --git a/crates/hyli-net/src/tcp.rs b/crates/hyli-net/src/tcp.rs index b208dd0c4..d9e06a539 100644 --- a/crates/hyli-net/src/tcp.rs +++ b/crates/hyli-net/src/tcp.rs @@ -25,6 +25,49 @@ use crate::net::TcpStream; pub type TcpHeaders = Vec<(String, String)>; +/// Common interface for `TcpServer` and middleware wrappers. +pub trait TcpServerLike { + type EventOut; + type ConnectedClients<'a>: Iterator + where + Self: 'a; + + /// Receive the next inbound event (or mapped output if wrapped). + async fn listen_next(&mut self) -> Option; + /// Send a response to a peer. + fn send(&mut self, socket_addr: String, msg: Res, headers: TcpHeaders) -> anyhow::Result<()>; + /// Send using borrowed payload/headers to avoid cloning on the success path. + fn send_ref(&mut self, socket_addr: &str, msg: &Res, headers: &TcpHeaders) -> anyhow::Result<()> + where + Res: Clone, + { + self.send(socket_addr.to_string(), msg.clone(), headers.clone()) + } + /// Return the currently connected peer socket addresses. + fn connected_clients(&self) -> Self::ConnectedClients<'_>; + /// Check whether a peer socket is currently connected. + fn connected(&self, socket_addr: &str) -> bool { + self.connected_clients().any(|addr| addr == socket_addr) + } + /// Drop and disconnect a peer socket. + fn drop_peer_stream(&mut self, peer_ip: String); + + /// Broadcast by fanout over `connected_clients()` using `send()`. + fn broadcast(&mut self, msg: Res, headers: TcpHeaders) -> Vec<(String, anyhow::Error)> + where + Res: Clone, + { + let peers: Vec = self.connected_clients().cloned().collect(); + let mut errors = Vec::new(); + for peer in peers { + if let Err(error) = self.send(peer.clone(), msg.clone(), headers.clone()) { + errors.push((peer, error)); + } + } + errors + } +} + #[macro_export] macro_rules! impl_tcp_message_label_with_prefix { ($ty:ty, $prefix:literal, { $( $variant:ident ),+ $(,)? }) => { diff --git a/crates/hyli-net/src/tcp/middleware/impls.rs b/crates/hyli-net/src/tcp/middleware/impls.rs index 95e3ecbf7..1b36fed37 100644 --- a/crates/hyli-net/src/tcp/middleware/impls.rs +++ b/crates/hyli-net/src/tcp/middleware/impls.rs @@ -3,9 +3,9 @@ use std::time::Duration; use tokio::time::Instant; -use crate::tcp::{TcpEvent, TcpHeaders}; +use crate::tcp::{TcpEvent, TcpHeaders, TcpServerLike}; -use super::{SendErrorContext, SendErrorOutcome, TcpMiddleware, TcpServerLike}; +use super::{SendErrorContext, SendErrorOutcome, TcpMiddleware}; #[derive(Default)] pub struct DropOnError; diff --git a/crates/hyli-net/src/tcp/middleware/mod.rs b/crates/hyli-net/src/tcp/middleware/mod.rs index e3f7d0289..5a54381f4 100644 --- a/crates/hyli-net/src/tcp/middleware/mod.rs +++ b/crates/hyli-net/src/tcp/middleware/mod.rs @@ -3,7 +3,7 @@ use tokio::time::Instant; use borsh::{BorshDeserialize, BorshSerialize}; -use crate::tcp::{tcp_server::TcpServer, TcpEvent, TcpHeaders, TcpMessageLabel}; +use crate::tcp::{tcp_server::TcpServer, TcpEvent, TcpHeaders, TcpMessageLabel, TcpServerLike}; mod impls; @@ -178,49 +178,6 @@ where } } -/// Common interface for `TcpServer` and middleware wrappers. -pub trait TcpServerLike { - type EventOut; - type ConnectedClients<'a>: Iterator - where - Self: 'a; - - /// Receive the next inbound event (or mapped output if wrapped). - async fn listen_next(&mut self) -> Option; - /// Send a response to a peer. - fn send(&mut self, socket_addr: String, msg: Res, headers: TcpHeaders) -> anyhow::Result<()>; - /// Send using borrowed payload/headers to avoid cloning on the success path. - fn send_ref(&mut self, socket_addr: &str, msg: &Res, headers: &TcpHeaders) -> anyhow::Result<()> - where - Res: Clone, - { - self.send(socket_addr.to_string(), msg.clone(), headers.clone()) - } - /// Return the currently connected peer socket addresses. - fn connected_clients(&self) -> Self::ConnectedClients<'_>; - /// Check whether a peer socket is currently connected. - fn connected(&self, socket_addr: &str) -> bool { - self.connected_clients().any(|addr| addr == socket_addr) - } - /// Drop and disconnect a peer socket. - fn drop_peer_stream(&mut self, peer_ip: String); - - /// Broadcast by fanout over `connected_clients()` using `send()`. - fn broadcast(&mut self, msg: Res, headers: TcpHeaders) -> Vec<(String, anyhow::Error)> - where - Res: Clone, - { - let peers: Vec = self.connected_clients().cloned().collect(); - let mut errors = Vec::new(); - for peer in peers { - if let Err(error) = self.send(peer.clone(), msg.clone(), headers.clone()) { - errors.push((peer, error)); - } - } - errors - } -} - /// Tower-style layering helper for TCP servers and already-layered services. /// /// # Example diff --git a/crates/hyli-net/src/tcp/p2p_server.rs b/crates/hyli-net/src/tcp/p2p_server.rs index b2fb099ce..c826f1593 100644 --- a/crates/hyli-net/src/tcp/p2p_server.rs +++ b/crates/hyli-net/src/tcp/p2p_server.rs @@ -14,7 +14,7 @@ use crate::{ clock::TimestampMsClock, metrics::P2PMetrics, ordered_join_set::OrderedJoinSet, - tcp::{tcp_client::TcpClient, Handshake, TcpHeaders, TcpMessageLabel}, + tcp::{tcp_client::TcpClient, Handshake, TcpHeaders, TcpMessageLabel, TcpServerLike}, }; use hyli_turmoil_shims::collections::HashMap; @@ -1221,7 +1221,9 @@ pub mod tests { use tokio::net::TcpListener; use crate::clock::TimestampMsClock; - use crate::tcp::{p2p_server::P2PServer, Canal, Handshake, P2PTcpMessage, TcpEvent}; + use crate::tcp::{ + p2p_server::P2PServer, Canal, Handshake, P2PTcpMessage, TcpEvent, TcpServerLike, + }; use super::P2PTcpEvent; diff --git a/crates/hyli-net/src/tcp/tcp_client.rs b/crates/hyli-net/src/tcp/tcp_client.rs index 03ab406d8..32d8ade8a 100644 --- a/crates/hyli-net/src/tcp/tcp_client.rs +++ b/crates/hyli-net/src/tcp/tcp_client.rs @@ -201,7 +201,7 @@ mod tests { use std::time::Duration; use super::TcpClient; - use crate::tcp::tcp_server::TcpServer; + use crate::tcp::{tcp_server::TcpServer, TcpServerLike}; type TestTCPServer = TcpServer; type TestTCPClient = TcpClient; diff --git a/crates/hyli-net/src/tcp/tcp_server.rs b/crates/hyli-net/src/tcp/tcp_server.rs index 404444128..36e405cfb 100644 --- a/crates/hyli-net/src/tcp/tcp_server.rs +++ b/crates/hyli-net/src/tcp/tcp_server.rs @@ -29,7 +29,8 @@ use hyli_turmoil_shims::collections::HashMap; use tracing::{debug, error, trace, warn}; use super::{tcp_client::TcpClient, SocketStream, TcpEvent}; -use crate::tcp::middleware::{TcpReqBound, TcpResBound, TcpServerLike}; +use crate::tcp::middleware::{TcpReqBound, TcpResBound}; +use crate::tcp::TcpServerLike; type TcpSender = SplitSink; type TcpReceiver = SplitStream; @@ -166,42 +167,6 @@ where }) } - pub async fn listen_next(&mut self) -> Option> { - loop { - hyli_turmoil_shims::tokio_select_biased! { - Ok((stream, socket_addr)) = self.tcp_listener.accept() => { - if let Some(len) = self.max_frame_length { - debug!("Setting max frame length to {}", len); - } - let (sender, receiver) = framed_stream(stream, self.max_frame_length).split(); - self.setup_stream(sender, receiver, &socket_addr.to_string()); - } - - Some(socket_addr) = self.ping_receiver.recv() => { - trace!("Received ping from {}", socket_addr); - if let Some(socket) = self.sockets.get_mut(&socket_addr) { - socket.last_ping = TimestampMsClock::now(); - } - } - message = self.pool_receiver.recv() => { - let queued = self.pool_receiver.len(); - if let Some(msg) = message.as_ref() { - match msg.as_ref() { - TcpEvent::Message { socket_addr, data, .. } => { - self.metrics - .event_loop_message_received(data.message_label()); - trace!(pool = %self.pool_name, "TcpServer event queue: message for {} ({} remaining)", socket_addr, queued) - } - TcpEvent::Closed { socket_addr } => trace!(pool = %self.pool_name, "TcpServer event queue: closed for {} ({} remaining)", socket_addr, queued), - TcpEvent::Error { socket_addr, error } => trace!(pool = %self.pool_name, "TcpServer event queue: error for {}: {} ({} remaining)", socket_addr, error, queued), - } - } - return message.map(|message| *message); - } - } - } - } - #[cfg(test)] /// Local_addr of the underlying tcp_listener pub fn local_addr(&self) -> anyhow::Result { @@ -210,11 +175,6 @@ where .context("Getting local_addr from TcpListener in TcpServer") } - /// Adresses of currently connected clients (no health check) - pub fn connected_clients(&self) -> ConnectedClients<'_> { - ConnectedClients(self.sockets.keys()) - } - pub fn connected(&self, socket_addr: &str) -> bool { self.sockets.contains_key(socket_addr) } @@ -311,68 +271,6 @@ where result } - pub fn send( - &mut self, - socket_addr: String, - msg: Res, - headers: TcpHeaders, - ) -> anyhow::Result<()> { - debug!(pool = %self.pool_name, "Sending msg {:?} to {}", msg, socket_addr); - let message_label = msg.message_label(); - let stream = self - .sockets - .get_mut(&socket_addr) - .context(format!("Retrieving client {socket_addr}"))?; - - let binary_data = to_tcp_message_with_headers(&msg, headers)?; - stream - .sender - .try_send(TcpOutboundMessage { - message: binary_data, - message_label, - }) - .map_err(|e| { - anyhow::anyhow!( - "Outbound TCP channel full/closed while sending msg to client {}: {}", - socket_addr, - e - ) - })?; - self.metrics.event_loop_message_sent(message_label); - Ok(()) - } - - pub fn send_ref( - &mut self, - socket_addr: &str, - msg: &Res, - headers: &TcpHeaders, - ) -> anyhow::Result<()> { - debug!(pool = %self.pool_name, "Sending msg {:?} to {}", msg, socket_addr); - let message_label = msg.message_label(); - let stream = self - .sockets - .get_mut(socket_addr) - .context(format!("Retrieving client {socket_addr}"))?; - - let binary_data = to_tcp_message_with_headers(msg, headers.clone())?; - stream - .sender - .try_send(TcpOutboundMessage { - message: binary_data, - message_label, - }) - .map_err(|e| { - anyhow::anyhow!( - "Outbound TCP channel full/closed while sending msg to client {}: {}", - socket_addr, - e - ) - })?; - self.metrics.event_loop_message_sent(message_label); - Ok(()) - } - pub fn ping(&mut self, socket_addr: String) -> anyhow::Result<()> { let stream = self .sockets @@ -679,21 +577,6 @@ where self.setup_stream(sender, receiver, &addr); } - pub fn drop_peer_stream(&mut self, peer_ip: String) { - if let Some(peer_stream) = self.sockets.remove(&peer_ip) { - tracing::debug!( - pool = %self.pool_name, - "Dropping peer stream {} (remaining sockets: {})", - peer_ip, - self.sockets.len() - ); - peer_stream.abort_sender_task.abort(); - peer_stream.abort_receiver_task.abort(); - tracing::debug!(pool = %self.pool_name, "Client {} dropped & disconnected", peer_ip); - self.metrics.peers_snapshot(self.sockets.len() as u64); - } - } - pub fn set_peer_label(&mut self, socket_addr: &str, label: String) { if let Some(stream) = self.sockets.get_mut(socket_addr) { let mut guard = match stream.socket_label.write() { @@ -717,26 +600,113 @@ where Self: 'a; async fn listen_next(&mut self) -> Option { - TcpServer::listen_next(self).await + loop { + hyli_turmoil_shims::tokio_select_biased! { + Ok((stream, socket_addr)) = self.tcp_listener.accept() => { + if let Some(len) = self.max_frame_length { + debug!("Setting max frame length to {}", len); + } + let (sender, receiver) = framed_stream(stream, self.max_frame_length).split(); + self.setup_stream(sender, receiver, &socket_addr.to_string()); + } + + Some(socket_addr) = self.ping_receiver.recv() => { + trace!("Received ping from {}", socket_addr); + if let Some(socket) = self.sockets.get_mut(&socket_addr) { + socket.last_ping = TimestampMsClock::now(); + } + } + message = self.pool_receiver.recv() => { + let queued = self.pool_receiver.len(); + if let Some(msg) = message.as_ref() { + match msg.as_ref() { + TcpEvent::Message { socket_addr, data, .. } => { + self.metrics + .event_loop_message_received(data.message_label()); + trace!(pool = %self.pool_name, "TcpServer event queue: message for {} ({} remaining)", socket_addr, queued) + } + TcpEvent::Closed { socket_addr } => trace!(pool = %self.pool_name, "TcpServer event queue: closed for {} ({} remaining)", socket_addr, queued), + TcpEvent::Error { socket_addr, error } => trace!(pool = %self.pool_name, "TcpServer event queue: error for {}: {} ({} remaining)", socket_addr, error, queued), + } + } + return message.map(|message| *message); + } + } + } } fn send(&mut self, socket_addr: String, msg: Res, headers: TcpHeaders) -> anyhow::Result<()> { - TcpServer::send(self, socket_addr, msg, headers) + debug!(pool = %self.pool_name, "Sending msg {:?} to {}", msg, socket_addr); + let message_label = msg.message_label(); + let stream = self + .sockets + .get_mut(&socket_addr) + .context(format!("Retrieving client {socket_addr}"))?; + + let binary_data = to_tcp_message_with_headers(&msg, headers)?; + stream + .sender + .try_send(TcpOutboundMessage { + message: binary_data, + message_label, + }) + .map_err(|e| { + anyhow::anyhow!( + "Outbound TCP channel full/closed while sending msg to client {}: {}", + socket_addr, + e + ) + })?; + self.metrics.event_loop_message_sent(message_label); + Ok(()) } fn send_ref(&mut self, socket_addr: &str, msg: &Res, headers: &TcpHeaders) -> anyhow::Result<()> where Res: Clone, { - TcpServer::send_ref(self, socket_addr, msg, headers) + debug!(pool = %self.pool_name, "Sending msg {:?} to {}", msg, socket_addr); + let message_label = msg.message_label(); + let stream = self + .sockets + .get_mut(socket_addr) + .context(format!("Retrieving client {socket_addr}"))?; + + let binary_data = to_tcp_message_with_headers(msg, headers.clone())?; + stream + .sender + .try_send(TcpOutboundMessage { + message: binary_data, + message_label, + }) + .map_err(|e| { + anyhow::anyhow!( + "Outbound TCP channel full/closed while sending msg to client {}: {}", + socket_addr, + e + ) + })?; + self.metrics.event_loop_message_sent(message_label); + Ok(()) } fn connected_clients(&self) -> Self::ConnectedClients<'_> { - TcpServer::connected_clients(self) + ConnectedClients(self.sockets.keys()) } fn drop_peer_stream(&mut self, peer_ip: String) { - TcpServer::drop_peer_stream(self, peer_ip) + if let Some(peer_stream) = self.sockets.remove(&peer_ip) { + tracing::debug!( + pool = %self.pool_name, + "Dropping peer stream {} (remaining sockets: {})", + peer_ip, + self.sockets.len() + ); + peer_stream.abort_sender_task.abort(); + peer_stream.abort_receiver_task.abort(); + tracing::debug!(pool = %self.pool_name, "Client {} dropped & disconnected", peer_ip); + self.metrics.peers_snapshot(self.sockets.len() as u64); + } } } @@ -746,7 +716,8 @@ pub mod tests { use super::TcpServer; use crate::tcp::{ - tcp_client::TcpClient, tcp_server::peer_label_or_addr, to_tcp_message, TcpEvent, TcpMessage, + tcp_client::TcpClient, tcp_server::peer_label_or_addr, to_tcp_message, TcpEvent, + TcpMessage, TcpServerLike, }; use anyhow::Result; use bytes::Bytes; diff --git a/src/data_availability.rs b/src/data_availability.rs index 145b232a9..a5259b3ce 100644 --- a/src/data_availability.rs +++ b/src/data_availability.rs @@ -8,8 +8,9 @@ use hyli_modules::modules::da_listener::{DaStreamPoll, SignedDaStream}; use hyli_modules::{bus::SharedMessageBus, modules::Module}; use hyli_modules::{log_error, module_bus_client, module_handle_messages}; use hyli_net::tcp::middleware::{ - middleware_layer, DropOnError, MessageOnly, RetryingSend, TcpServerExt, TcpServerLike, + middleware_layer, DropOnError, MessageOnly, RetryingSend, TcpServerExt, }; +use hyli_net::tcp::TcpServerLike; use tokio::task::JoinHandle; use crate::{ @@ -1261,7 +1262,7 @@ pub mod tests { use hyli_modules::node_state::module::NodeStateBusClient; use hyli_modules::node_state::NodeState; use hyli_modules::utils::da_codec::DataAvailabilityClient; - use hyli_net::tcp::middleware::TcpServerLike; + use hyli_net::tcp::TcpServerLike; use staking::state::Staking; use tokio::task::JoinSet; diff --git a/src/tcp_server.rs b/src/tcp_server.rs index af1b85da3..5530a0fad 100644 --- a/src/tcp_server.rs +++ b/src/tcp_server.rs @@ -7,7 +7,7 @@ use hyli_modules::{ log_error, module_handle_messages, modules::{module_bus_client, Module}, }; -use hyli_net::tcp::TcpEvent; +use hyli_net::tcp::{TcpEvent, TcpServerLike}; use tracing::{info, warn}; module_bus_client! { From 7b0f325377554da225ba57ca1d36cb8e6879f69c Mon Sep 17 00:00:00 2001 From: Alexandre Careil Date: Thu, 19 Feb 2026 16:09:44 +0100 Subject: [PATCH 13/18] simplify --- crates/hyli-net/src/tcp.rs | 106 ++++++++----------------------------- src/data_availability.rs | 83 +++++++++++------------------ 2 files changed, 55 insertions(+), 134 deletions(-) diff --git a/crates/hyli-net/src/tcp.rs b/crates/hyli-net/src/tcp.rs index d9e06a539..e6d78d7c0 100644 --- a/crates/hyli-net/src/tcp.rs +++ b/crates/hyli-net/src/tcp.rs @@ -12,16 +12,33 @@ use std::{ use borsh::{BorshDeserialize, BorshSerialize}; use bytes::Bytes; -use sdk::{ - hyli_model_utils::TimestampMs, DataAvailabilityEvent, DataAvailabilityRequest, DataProposal, -}; -use strum_macros::IntoStaticStr; +use sdk::hyli_model_utils::TimestampMs; use tokio::task::JoinHandle; use tokio_util::codec::{Framed, LengthDelimitedCodec}; use anyhow::Result; use crate::net::TcpStream; +/// Derive macro for [`TcpMessageLabel`]. +/// +/// # Defaults +/// - Structs return `"TypeName"`. +/// - Enums return `"TypeName::VariantName"`. +/// +/// # Example +/// ```rust +/// use hyli_net::tcp::TcpMessageLabel; +/// +/// #[derive(TcpMessageLabel)] +/// enum Msg { +/// Ping, +/// Data(Vec), +/// } +/// +/// assert_eq!(Msg::Ping.message_label(), "Msg::Ping"); +/// assert_eq!(Msg::Data(vec![]).message_label(), "Msg::Data"); +/// ``` +pub use hyli_net_traits::TcpMessageLabel; pub type TcpHeaders = Vec<(String, String)>; @@ -32,27 +49,20 @@ pub trait TcpServerLike { where Self: 'a; - /// Receive the next inbound event (or mapped output if wrapped). async fn listen_next(&mut self) -> Option; - /// Send a response to a peer. fn send(&mut self, socket_addr: String, msg: Res, headers: TcpHeaders) -> anyhow::Result<()>; - /// Send using borrowed payload/headers to avoid cloning on the success path. fn send_ref(&mut self, socket_addr: &str, msg: &Res, headers: &TcpHeaders) -> anyhow::Result<()> where Res: Clone, { self.send(socket_addr.to_string(), msg.clone(), headers.clone()) } - /// Return the currently connected peer socket addresses. fn connected_clients(&self) -> Self::ConnectedClients<'_>; - /// Check whether a peer socket is currently connected. fn connected(&self, socket_addr: &str) -> bool { self.connected_clients().any(|addr| addr == socket_addr) } - /// Drop and disconnect a peer socket. fn drop_peer_stream(&mut self, peer_ip: String); - /// Broadcast by fanout over `connected_clients()` using `send()`. fn broadcast(&mut self, msg: Res, headers: TcpHeaders) -> Vec<(String, anyhow::Error)> where Res: Clone, @@ -68,23 +78,6 @@ pub trait TcpServerLike { } } -#[macro_export] -macro_rules! impl_tcp_message_label_with_prefix { - ($ty:ty, $prefix:literal, { $( $variant:ident ),+ $(,)? }) => { - impl $crate::tcp::TcpMessageLabel for $ty { - fn message_label(&self) -> &'static str { - match self { - $( Self::$variant(..) => concat!($prefix, "::", stringify!($variant)), )+ - } - } - } - }; -} - -pub trait TcpMessageLabel { - fn message_label(&self) -> &'static str; -} - pub(crate) type FramedStream = Framed; pub(crate) fn framed_stream(stream: TcpStream, max_frame_length: Option) -> FramedStream { @@ -116,7 +109,7 @@ pub fn headers_from_span() -> TcpHeaders { } } -#[derive(Clone, BorshDeserialize, BorshSerialize, PartialEq, IntoStaticStr)] +#[derive(Clone, BorshDeserialize, BorshSerialize, PartialEq, TcpMessageLabel)] pub enum TcpMessage { Ping, Data(TcpData), @@ -240,12 +233,6 @@ impl std::fmt::Debug for TcpMessage { } } -impl TcpMessageLabel for TcpMessage { - fn message_label(&self) -> &'static str { - self.clone().into() - } -} - #[expect(clippy::large_enum_variant)] #[derive(Debug, Clone, BorshDeserialize, BorshSerialize, PartialEq)] pub enum P2PTcpMessage { @@ -264,7 +251,7 @@ impl TcpMessageLabel } } -#[derive(Debug, Clone, BorshSerialize, BorshDeserialize, PartialEq, IntoStaticStr)] +#[derive(Debug, Clone, BorshSerialize, BorshDeserialize, PartialEq, TcpMessageLabel)] pub enum Handshake { Hello( ( @@ -282,53 +269,6 @@ pub enum Handshake { ), } -impl TcpMessageLabel for Handshake { - fn message_label(&self) -> &'static str { - self.clone().into() - } -} - -impl TcpMessageLabel for Vec { - fn message_label(&self) -> &'static str { - "bytes" - } -} - -impl TcpMessageLabel for String { - fn message_label(&self) -> &'static str { - "string" - } -} - -impl TcpMessageLabel for DataAvailabilityRequest { - fn message_label(&self) -> &'static str { - match self { - DataAvailabilityRequest::StreamFromHeight(_) => { - "DataAvailabilityRequest::StreamFromHeight" - } - DataAvailabilityRequest::BlockRequest(_) => "DataAvailabilityRequest::BlockRequest", - } - } -} - -impl TcpMessageLabel for DataAvailabilityEvent { - fn message_label(&self) -> &'static str { - match self { - DataAvailabilityEvent::SignedBlock(_) => "DataAvailabilityEvent::SignedBlock", - DataAvailabilityEvent::MempoolStatusEvent(_) => { - "DataAvailabilityEvent::MempoolStatusEvent" - } - DataAvailabilityEvent::BlockNotFound(_) => "DataAvailabilityEvent::BlockNotFound", - } - } -} - -impl TcpMessageLabel for DataProposal { - fn message_label(&self) -> &'static str { - "DataProposal" - } -} - #[derive( Default, Debug, Clone, BorshSerialize, BorshDeserialize, Hash, PartialEq, Eq, PartialOrd, Ord, )] diff --git a/src/data_availability.rs b/src/data_availability.rs index a5259b3ce..4315f21cb 100644 --- a/src/data_availability.rs +++ b/src/data_availability.rs @@ -773,16 +773,13 @@ impl DataAvailability { Ok(()) } - async fn handle_send_next_block_to_peer( + async fn handle_send_next_block_to_peer( &mut self, peer_ip: String, retries: usize, catchup_joinset: &mut JoinSet<(String, usize)>, - server: &mut S, - ) -> Result<()> - where - S: TcpServerLike, - { + server: &mut DaServerStack, + ) -> Result<()> { if !server.connected(&peer_ip) { debug!("Peer {} disconnected, removing from send queues", peer_ip); self.peer_send_queues.remove(&peer_ip); @@ -850,15 +847,12 @@ impl DataAvailability { Ok(()) } - async fn handle_block_request( + async fn handle_block_request( &mut self, block_height: BlockHeight, socket_addr: &str, - server: &mut S, - ) -> Result<()> - where - S: TcpServerLike, - { + server: &mut DaServerStack, + ) -> Result<()> { debug!( "📦 Received block request for height {} from {}", block_height, socket_addr @@ -927,15 +921,12 @@ impl DataAvailability { Ok(()) } - async fn handle_mempool_event( + async fn handle_mempool_event( &mut self, evt: MempoolBlockEvent, - tcp_server: &mut S, + tcp_server: &mut DaServerStack, catchup_joinset: &mut JoinSet<(String, usize)>, - ) -> Result<()> - where - S: TcpServerLike, - { + ) -> Result<()> { match evt { MempoolBlockEvent::BuiltSignedBlock(signed_block) => { debug!( @@ -960,10 +951,11 @@ impl DataAvailability { Ok(()) } - async fn handle_mempool_status_event(&mut self, evt: MempoolStatusEvent, tcp_server: &mut S) - where - S: TcpServerLike, - { + async fn handle_mempool_status_event( + &mut self, + evt: MempoolStatusEvent, + tcp_server: &mut DaServerStack, + ) { let errors = tcp_server.broadcast(DataAvailabilityEvent::MempoolStatusEvent(evt), vec![]); for (peer, error) in errors { @@ -973,15 +965,12 @@ impl DataAvailability { } /// if handled, returns the highest height of the processed blocks - async fn handle_signed_block( + async fn handle_signed_block( &mut self, block: SignedBlock, - tcp_server: &mut S, + tcp_server: &mut DaServerStack, catchup_joinset: &mut JoinSet<(String, usize)>, - ) -> Option - where - S: TcpServerLike, - { + ) -> Option { let hash = block.hashed(); // if new block is already handled, ignore it if self.blocks.contains(&hash) { @@ -1038,15 +1027,12 @@ impl DataAvailability { } /// Returns the highest height of the processed blocks - async fn pop_buffer( + async fn pop_buffer( &mut self, mut last_block_hash: ConsensusProposalHash, - tcp_server: &mut S, + tcp_server: &mut DaServerStack, catchup_joinset: &mut JoinSet<(String, usize)>, - ) -> Option - where - S: TcpServerLike, - { + ) -> Option { let mut res = None; // Iterative loop to avoid stack overflows @@ -1115,15 +1101,12 @@ impl DataAvailability { Ok(()) } - async fn add_processed_block( + async fn add_processed_block( &mut self, block: SignedBlock, - _tcp_server: &mut S, + _tcp_server: &mut DaServerStack, catchup_joinset: &mut JoinSet<(String, usize)>, - ) -> anyhow::Result<()> - where - S: TcpServerLike, - { + ) -> anyhow::Result<()> { self.store_block(&block)?; let block_hash = block.hashed(); @@ -1163,16 +1146,13 @@ impl DataAvailability { Ok(()) } - async fn start_streaming_to_peer( + async fn start_streaming_to_peer( &mut self, start_height: BlockHeight, catchup_joinset: &mut JoinSet<(String, usize)>, peer_ip: &str, - server: &mut S, - ) -> Result<()> - where - S: TcpServerLike, - { + server: &mut DaServerStack, + ) -> Result<()> { let range_start = std::time::Instant::now(); let highest = self .blocks @@ -1249,7 +1229,7 @@ pub mod tests { use std::{collections::HashMap, time::Duration}; use super::module_bus_client; - use super::{Blocks, RawDataAvailabilityServer}; + use super::{Blocks, DaServerStack, RawDataAvailabilityServer}; use crate::data_availability::DaCatchupPolicy; use crate::{ bus::BusClientSender, @@ -1304,10 +1284,11 @@ pub mod tests { } } - pub async fn handle_signed_block(&mut self, block: SignedBlock, tcp_server: &mut S) - where - S: TcpServerLike, - { + pub async fn handle_signed_block( + &mut self, + block: SignedBlock, + tcp_server: &mut DaServerStack, + ) { let mut catchup_joinset: JoinSet<(String, usize)> = JoinSet::new(); _ = self .da From ff64f8662b8ba04753078707926e8e736df05d70 Mon Sep 17 00:00:00 2001 From: Alexandre Careil Date: Thu, 19 Feb 2026 16:18:57 +0100 Subject: [PATCH 14/18] extract retry send logic --- src/data_availability.rs | 72 ++++++++++++++-------------------------- 1 file changed, 25 insertions(+), 47 deletions(-) diff --git a/src/data_availability.rs b/src/data_availability.rs index 4315f21cb..0b9807ab1 100644 --- a/src/data_availability.rs +++ b/src/data_availability.rs @@ -659,10 +659,8 @@ impl DataAvailability { let mut first_hole_receiver = self.start_scanning_for_first_hole(); // Used to send blocks to clients (indexers/peers) - // This is a JoinSet of tuples containing: - // - The peer IP address to send the blocks to - // - The number of retries for sending the blocks - let mut catchup_joinset: JoinSet<(String, usize)> = tokio::task::JoinSet::new(); + // JoinSet of peer addresses to process one queued send at a time. + let mut catchup_joinset: JoinSet = tokio::task::JoinSet::new(); let mut catchup_task_checker_ticker = tokio::time::interval(std::time::Duration::from_secs(5)); let mut storage_metrics_ticker = tokio::time::interval(std::time::Duration::from_secs(30)); @@ -743,7 +741,7 @@ impl DataAvailability { // Send one block to a peer as part of "catchup", // once we have sent all blocks the peer is presumably synchronised. - Some(Ok((peer_ip, retries))) = catchup_joinset.join_next() => { + Some(Ok(peer_ip)) = catchup_joinset.join_next() => { #[cfg(test)] tokio::time::sleep(std::time::Duration::from_millis(100)).await; @@ -751,7 +749,6 @@ impl DataAvailability { _ = log_error!( self.handle_send_next_block_to_peer( peer_ip.clone(), - retries, &mut catchup_joinset, &mut server ).await, @@ -776,8 +773,7 @@ impl DataAvailability { async fn handle_send_next_block_to_peer( &mut self, peer_ip: String, - retries: usize, - catchup_joinset: &mut JoinSet<(String, usize)>, + catchup_joinset: &mut JoinSet, server: &mut DaServerStack, ) -> Result<()> { if !server.connected(&peer_ip) { @@ -786,16 +782,6 @@ impl DataAvailability { return Ok(()); } - if retries > 10 { - warn!( - "Failed to send block, too many retries for peer {}", - &peer_ip - ); - server.drop_peer_stream(peer_ip.clone()); - self.peer_send_queues.remove(&peer_ip); - return Ok(()); - } - // Get next block from this peer's queue let hash = match self.peer_send_queues.get_mut(&peer_ip) { Some(queue) => match queue.pop_front() { @@ -823,17 +809,11 @@ impl DataAvailability { ) { Ok(()) => { // Successfully sent, continue with next block - catchup_joinset.spawn(async move { (peer_ip, 0) }); + catchup_joinset.spawn(async move { peer_ip }); } - Err(_) => { - // Retry sending the same block (put it back at front of queue) - if let Some(queue) = self.peer_send_queues.get_mut(&peer_ip) { - queue.push_front(hash); - } - catchup_joinset.spawn(async move { - tokio::time::sleep(Duration::from_millis(100 * (retries as u64))).await; - (peer_ip, retries + 1) - }); + Err(e) => { + warn!("Error sending block {} to peer {}: {:#}", hash, peer_ip, e); + self.peer_send_queues.remove(&peer_ip); } } } else { @@ -842,7 +822,7 @@ impl DataAvailability { &hash, &peer_ip ); // Continue anyway with next block - catchup_joinset.spawn(async move { (peer_ip, 0) }); + catchup_joinset.spawn(async move { peer_ip }); } Ok(()) } @@ -872,10 +852,9 @@ impl DataAvailability { vec![], ) { warn!( - "📦 Error while responding to block request at height {} for {}: {:#}. Dropping socket.", + "📦 Error while responding to block request at height {} for {}: {:#}", block_height, socket_addr, e ); - server.drop_peer_stream(socket_addr.to_string()); return Ok(()); } } @@ -891,10 +870,9 @@ impl DataAvailability { vec![], ) { warn!( - "📦 Error while responding BlockNotFound at height {} for {}: {:#}. Dropping socket.", + "📦 Error while responding BlockNotFound at height {} for {}: {:#}", block_height, socket_addr, e ); - server.drop_peer_stream(socket_addr.to_string()); return Ok(()); } } @@ -909,10 +887,9 @@ impl DataAvailability { vec![], ) { warn!( - "📦 Error while responding BlockNotFound at height {} for {}: {:#}. Dropping socket.", + "📦 Error while responding BlockNotFound at height {} for {}: {:#}", block_height, socket_addr, e ); - server.drop_peer_stream(socket_addr.to_string()); return Ok(()); } } @@ -925,7 +902,7 @@ impl DataAvailability { &mut self, evt: MempoolBlockEvent, tcp_server: &mut DaServerStack, - catchup_joinset: &mut JoinSet<(String, usize)>, + catchup_joinset: &mut JoinSet, ) -> Result<()> { match evt { MempoolBlockEvent::BuiltSignedBlock(signed_block) => { @@ -959,8 +936,10 @@ impl DataAvailability { let errors = tcp_server.broadcast(DataAvailabilityEvent::MempoolStatusEvent(evt), vec![]); for (peer, error) in errors { - warn!("Error while broadcasting mempool status event {:#}", error); - tcp_server.drop_peer_stream(peer.clone()); + warn!( + "Error while broadcasting mempool status event to {}: {:#}", + peer, error + ); } } @@ -969,7 +948,7 @@ impl DataAvailability { &mut self, block: SignedBlock, tcp_server: &mut DaServerStack, - catchup_joinset: &mut JoinSet<(String, usize)>, + catchup_joinset: &mut JoinSet, ) -> Option { let hash = block.hashed(); // if new block is already handled, ignore it @@ -1031,7 +1010,7 @@ impl DataAvailability { &mut self, mut last_block_hash: ConsensusProposalHash, tcp_server: &mut DaServerStack, - catchup_joinset: &mut JoinSet<(String, usize)>, + catchup_joinset: &mut JoinSet, ) -> Option { let mut res = None; @@ -1105,7 +1084,7 @@ impl DataAvailability { &mut self, block: SignedBlock, _tcp_server: &mut DaServerStack, - catchup_joinset: &mut JoinSet<(String, usize)>, + catchup_joinset: &mut JoinSet, ) -> anyhow::Result<()> { self.store_block(&block)?; @@ -1124,7 +1103,7 @@ impl DataAvailability { peer, block_hash ); let peer_clone = peer.clone(); - catchup_joinset.spawn(async move { (peer_clone, 0) }); + catchup_joinset.spawn(async move { peer_clone }); } else { debug!( "Appending block {} to queue for peer {} (queue size: {})", @@ -1149,7 +1128,7 @@ impl DataAvailability { async fn start_streaming_to_peer( &mut self, start_height: BlockHeight, - catchup_joinset: &mut JoinSet<(String, usize)>, + catchup_joinset: &mut JoinSet, peer_ip: &str, server: &mut DaServerStack, ) -> Result<()> { @@ -1199,7 +1178,6 @@ impl DataAvailability { ); } - server.drop_peer_stream(peer_ip.to_string()); self.peer_send_queues.remove(peer_ip); return Ok(()); } @@ -1217,7 +1195,7 @@ impl DataAvailability { .insert(peer_ip_string.clone(), processed_block_hashes); // Start the send task for this peer - catchup_joinset.spawn(async move { (peer_ip_string, 0) }); + catchup_joinset.spawn(async move { peer_ip_string }); Ok(()) } @@ -1289,7 +1267,7 @@ pub mod tests { block: SignedBlock, tcp_server: &mut DaServerStack, ) { - let mut catchup_joinset: JoinSet<(String, usize)> = JoinSet::new(); + let mut catchup_joinset: JoinSet = JoinSet::new(); _ = self .da .handle_signed_block(block.clone(), tcp_server, &mut catchup_joinset) @@ -1351,7 +1329,7 @@ pub mod tests { block.consensus_proposal.slot = i; } blocks.reverse(); - let mut catchup_joinset: JoinSet<(String, usize)> = JoinSet::new(); + let mut catchup_joinset: JoinSet = JoinSet::new(); for block in blocks { if block.height().0 == 0 { assert_eq!( From 62bc098e04c79be02c7045fa13aef703d8052373 Mon Sep 17 00:00:00 2001 From: Alexandre Careil Date: Fri, 20 Feb 2026 18:22:46 +0100 Subject: [PATCH 15/18] clean middlewares --- crates/hyli-net/src/tcp.rs | 11 +- crates/hyli-net/src/tcp/middleware/impls.rs | 209 ++++++++++++++- crates/hyli-net/src/tcp/middleware/mod.rs | 56 +++- crates/hyli-net/src/tcp/tcp_server.rs | 10 +- src/data_availability.rs | 281 +++++++------------- 5 files changed, 372 insertions(+), 195 deletions(-) diff --git a/crates/hyli-net/src/tcp.rs b/crates/hyli-net/src/tcp.rs index e6d78d7c0..3d7e6deec 100644 --- a/crates/hyli-net/src/tcp.rs +++ b/crates/hyli-net/src/tcp.rs @@ -50,18 +50,27 @@ pub trait TcpServerLike { Self: 'a; async fn listen_next(&mut self) -> Option; - fn send(&mut self, socket_addr: String, msg: Res, headers: TcpHeaders) -> anyhow::Result<()>; + fn send( + &mut self, + socket_addr: String, + msg: Res, + headers: TcpHeaders, + ) -> anyhow::Result; fn send_ref(&mut self, socket_addr: &str, msg: &Res, headers: &TcpHeaders) -> anyhow::Result<()> where Res: Clone, { self.send(socket_addr.to_string(), msg.clone(), headers.clone()) + .map(|_| ()) } fn connected_clients(&self) -> Self::ConnectedClients<'_>; fn connected(&self, socket_addr: &str) -> bool { self.connected_clients().any(|addr| addr == socket_addr) } fn drop_peer_stream(&mut self, peer_ip: String); + fn poll_send_completion(&mut self) -> Option { + None + } fn broadcast(&mut self, msg: Res, headers: TcpHeaders) -> Vec<(String, anyhow::Error)> where diff --git a/crates/hyli-net/src/tcp/middleware/impls.rs b/crates/hyli-net/src/tcp/middleware/impls.rs index 1b36fed37..9ef52d795 100644 --- a/crates/hyli-net/src/tcp/middleware/impls.rs +++ b/crates/hyli-net/src/tcp/middleware/impls.rs @@ -1,11 +1,11 @@ -use std::collections::VecDeque; +use std::collections::{HashMap, HashSet, VecDeque}; use std::time::Duration; use tokio::time::Instant; use crate::tcp::{TcpEvent, TcpHeaders, TcpServerLike}; -use super::{SendErrorContext, SendErrorOutcome, TcpMiddleware}; +use super::{SendCompletion, SendErrorContext, SendErrorOutcome, SendStatus, TcpMiddleware}; #[derive(Default)] pub struct DropOnError; @@ -63,6 +63,7 @@ struct PendingSend { msg: Res, headers: TcpHeaders, retries: usize, + ticket: u64, next_attempt_at: Instant, } @@ -71,6 +72,8 @@ pub struct RetryingSend { base_delay: Duration, max_per_tick: usize, queue: VecDeque>, + next_ticket: u64, + completions: VecDeque, } impl RetryingSend { @@ -80,6 +83,8 @@ impl RetryingSend { base_delay, max_per_tick: 64, queue: VecDeque::new(), + next_ticket: 1, + completions: VecDeque::new(), } } @@ -89,11 +94,14 @@ impl RetryingSend { } pub fn enqueue_to_peer(&mut self, socket_addr: String, msg: Res, headers: TcpHeaders) { + let ticket = self.next_ticket; + self.next_ticket = self.next_ticket.wrapping_add(1).max(1); self.queue.push_back(PendingSend { socket_addr, msg, headers, retries: 0, + ticket, next_attempt_at: Instant::now(), }); } @@ -125,14 +133,17 @@ where if !server.connected(&ctx.socket_addr) { return SendErrorOutcome::Unhandled(anyhow::anyhow!(ctx.error.to_string())); } + let ticket = self.next_ticket; + self.next_ticket = self.next_ticket.wrapping_add(1).max(1); self.queue.push_back(PendingSend { socket_addr: ctx.socket_addr.clone(), msg: ctx.msg.clone(), headers: ctx.headers.clone(), retries: 0, + ticket, next_attempt_at: Instant::now() + self.base_delay, }); - SendErrorOutcome::RetryScheduled + SendErrorOutcome::RetryScheduled { ticket } } fn on_tick(&mut self, server: &mut S) @@ -154,6 +165,9 @@ where } if !server.connected(&pending.socket_addr) { + self.completions.push_back(SendCompletion::Failed { + ticket: pending.ticket, + }); continue; } @@ -162,11 +176,18 @@ where pending.msg.clone(), pending.headers.clone(), ) { - Ok(()) => {} + Ok(_) => { + self.completions.push_back(SendCompletion::Delivered { + ticket: pending.ticket, + }); + } Err(_) => { let next_retries = pending.retries + 1; if next_retries > self.max_retries { server.drop_peer_stream(pending.socket_addr); + self.completions.push_back(SendCompletion::Failed { + ticket: pending.ticket, + }); } else { pending.retries = next_retries; pending.next_attempt_at = @@ -188,4 +209,184 @@ where .map(|pending| pending.next_attempt_at) .min() } + + fn poll_send_completion(&mut self) -> Option { + self.completions.pop_front() + } +} + +#[derive(Debug, Clone, Copy)] +pub enum AdvanceOn { + Accepted, + Delivered, +} + +struct QueuedItem { + input: In, + headers: TcpHeaders, +} + +pub struct DequeDispatch { + state: State, + resolver: Resolve, + advance_on: AdvanceOn, + tick_interval: Duration, + max_queue_len: Option, + queues: HashMap>>, + in_flight: HashSet, + ticket_to_peer: HashMap, + _res: std::marker::PhantomData, +} + +impl DequeDispatch { + pub fn new(state: State, resolver: Resolve) -> Self { + Self { + state, + resolver, + advance_on: AdvanceOn::Delivered, + tick_interval: Duration::from_millis(5), + max_queue_len: None, + queues: HashMap::new(), + in_flight: HashSet::new(), + ticket_to_peer: HashMap::new(), + _res: std::marker::PhantomData, + } + } + + pub fn advance_on(mut self, advance_on: AdvanceOn) -> Self { + self.advance_on = advance_on; + self + } + + pub fn tick_interval(mut self, interval: Duration) -> Self { + self.tick_interval = interval; + self + } + + pub fn max_queue_len(mut self, max_queue_len: usize) -> Self { + self.max_queue_len = Some(max_queue_len.max(1)); + self + } + + pub fn enqueue( + &mut self, + socket_addr: String, + input: In, + headers: TcpHeaders, + ) -> anyhow::Result<()> { + if let Some(max_len) = self.max_queue_len { + let total_len: usize = self.queues.values().map(VecDeque::len).sum(); + if total_len >= max_len { + return Err(anyhow::anyhow!( + "DequeDispatch queue is full (max_len={max_len})" + )); + } + } + + self.queues + .entry(socket_addr) + .or_default() + .push_back(QueuedItem { input, headers }); + Ok(()) + } + + pub fn send_now( + &mut self, + socket_addr: String, + input: In, + headers: TcpHeaders, + ) -> anyhow::Result<()> { + if let Some(max_len) = self.max_queue_len { + let total_len: usize = self.queues.values().map(VecDeque::len).sum(); + if total_len >= max_len { + return Err(anyhow::anyhow!( + "DequeDispatch queue is full (max_len={max_len})" + )); + } + } + + self.queues + .entry(socket_addr) + .or_default() + .push_front(QueuedItem { input, headers }); + Ok(()) + } +} + +impl TcpMiddleware + for DequeDispatch +where + Req: super::TcpReqBound, + Res: super::TcpResBound + Clone + Send + 'static, + Resolve: Fn(&State, In) -> anyhow::Result + Send + 'static, +{ + type EventOut = TcpEvent; + + fn on_event(&mut self, _server: &mut S, event: TcpEvent) -> Option + where + S: TcpServerLike>, + { + Some(event) + } + + fn on_tick(&mut self, server: &mut S) + where + S: TcpServerLike>, + { + while let Some(completion) = server.poll_send_completion() { + let ticket = match completion { + SendCompletion::Delivered { ticket } | SendCompletion::Failed { ticket } => ticket, + }; + if let Some(peer) = self.ticket_to_peer.remove(&ticket) { + self.in_flight.remove(&peer); + } + } + + let eligible: Vec = self + .queues + .iter() + .filter(|(peer, queue)| !queue.is_empty() && !self.in_flight.contains(*peer)) + .map(|(peer, _)| peer.clone()) + .collect(); + + for peer in eligible { + let Some(queue) = self.queues.get_mut(&peer) else { + continue; + }; + let Some(item) = queue.pop_front() else { + continue; + }; + if queue.is_empty() { + self.queues.remove(&peer); + } + + let headers = item.headers; + match (self.resolver)(&self.state, item.input) { + Ok(msg) => match server.send(peer.clone(), msg, headers) { + Ok(SendStatus::SentNow) => {} + Ok(SendStatus::RetryScheduled { ticket }) => match self.advance_on { + AdvanceOn::Accepted => {} + AdvanceOn::Delivered => { + self.ticket_to_peer.insert(ticket, peer.clone()); + self.in_flight.insert(peer); + } + }, + Err(err) => { + tracing::warn!("DequeDispatch send error for peer {}: {:#}", peer, err); + } + }, + Err(err) => { + tracing::warn!("DequeDispatch resolver error for peer {}: {:#}", peer, err); + } + } + } + } + + fn next_wakeup(&self) -> Option { + if self.queues.values().any(|q| !q.is_empty()) || !self.in_flight.is_empty() { + Some(Instant::now() + self.tick_interval) + } else { + None + } + } } diff --git a/crates/hyli-net/src/tcp/middleware/mod.rs b/crates/hyli-net/src/tcp/middleware/mod.rs index 5a54381f4..bdd4f2991 100644 --- a/crates/hyli-net/src/tcp/middleware/mod.rs +++ b/crates/hyli-net/src/tcp/middleware/mod.rs @@ -7,7 +7,9 @@ use crate::tcp::{tcp_server::TcpServer, TcpEvent, TcpHeaders, TcpMessageLabel, T mod impls; -pub use impls::{DropOnError, MessageOnly, RetryingSend, TcpInboundMessage}; +pub use impls::{ + AdvanceOn, DequeDispatch, DropOnError, MessageOnly, RetryingSend, TcpInboundMessage, +}; #[macro_export] macro_rules! tcp_middleware_chain_type { @@ -84,11 +86,23 @@ pub struct SendErrorContext { pub error: anyhow::Error, } +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum SendStatus { + SentNow, + RetryScheduled { ticket: u64 }, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum SendCompletion { + Delivered { ticket: u64 }, + Failed { ticket: u64 }, +} + pub enum SendErrorOutcome { /// Middleware absorbed the error (e.g. logged only). Handled, /// Middleware scheduled a retry. - RetryScheduled, + RetryScheduled { ticket: u64 }, /// Middleware requests dropping the peer. DropPeer, /// Middleware did not handle the error; propagate upstream. @@ -134,13 +148,14 @@ where socket_addr: String, msg: Res, headers: TcpHeaders, - ) -> anyhow::Result<()> + ) -> anyhow::Result where S: TcpServerLike>, Res: Clone, { match server.send(socket_addr.clone(), msg.clone(), headers.clone()) { - Ok(()) => Ok(()), + Ok(SendStatus::SentNow) => Ok(SendStatus::SentNow), + Ok(SendStatus::RetryScheduled { ticket }) => Ok(SendStatus::RetryScheduled { ticket }), Err(error) => { let ctx = SendErrorContext { socket_addr, @@ -149,10 +164,13 @@ where error, }; match self.on_send_error(server, &ctx) { - SendErrorOutcome::Handled | SendErrorOutcome::RetryScheduled => Ok(()), + SendErrorOutcome::Handled => Ok(SendStatus::SentNow), + SendErrorOutcome::RetryScheduled { ticket } => { + Ok(SendStatus::RetryScheduled { ticket }) + } SendErrorOutcome::DropPeer => { server.drop_peer_stream(ctx.socket_addr.clone()); - Ok(()) + Ok(SendStatus::SentNow) } SendErrorOutcome::Unhandled(error) => Err(error), } @@ -176,6 +194,10 @@ where fn next_wakeup(&self) -> Option { None } + + fn poll_send_completion(&mut self) -> Option { + None + } } /// Tower-style layering helper for TCP servers and already-layered services. @@ -220,6 +242,14 @@ where _marker: PhantomData, } } + + pub fn middleware_mut(&mut self) -> &mut M { + &mut self.middleware + } + + pub fn inner_mut(&mut self) -> &mut S { + &mut self.inner + } } impl TcpServerLike for TcpServerWithMiddleware @@ -260,7 +290,12 @@ where } } - fn send(&mut self, socket_addr: String, msg: Res, headers: TcpHeaders) -> anyhow::Result<()> { + fn send( + &mut self, + socket_addr: String, + msg: Res, + headers: TcpHeaders, + ) -> anyhow::Result { self.middleware .on_send(&mut self.inner, socket_addr, msg, headers) } @@ -270,6 +305,7 @@ where Res: Clone, { self.send(socket_addr.to_string(), msg.clone(), headers.clone()) + .map(|_| ()) } fn connected_clients(&self) -> Self::ConnectedClients<'_> { @@ -279,6 +315,12 @@ where fn drop_peer_stream(&mut self, peer_ip: String) { self.inner.drop_peer_stream(peer_ip) } + + fn poll_send_completion(&mut self) -> Option { + self.middleware + .poll_send_completion() + .or_else(|| self.inner.poll_send_completion()) + } } impl Layer for MiddlewareLayer diff --git a/crates/hyli-net/src/tcp/tcp_server.rs b/crates/hyli-net/src/tcp/tcp_server.rs index 36e405cfb..28ee02cf8 100644 --- a/crates/hyli-net/src/tcp/tcp_server.rs +++ b/crates/hyli-net/src/tcp/tcp_server.rs @@ -29,6 +29,7 @@ use hyli_turmoil_shims::collections::HashMap; use tracing::{debug, error, trace, warn}; use super::{tcp_client::TcpClient, SocketStream, TcpEvent}; +use crate::tcp::middleware::SendStatus; use crate::tcp::middleware::{TcpReqBound, TcpResBound}; use crate::tcp::TcpServerLike; @@ -635,7 +636,12 @@ where } } - fn send(&mut self, socket_addr: String, msg: Res, headers: TcpHeaders) -> anyhow::Result<()> { + fn send( + &mut self, + socket_addr: String, + msg: Res, + headers: TcpHeaders, + ) -> anyhow::Result { debug!(pool = %self.pool_name, "Sending msg {:?} to {}", msg, socket_addr); let message_label = msg.message_label(); let stream = self @@ -658,7 +664,7 @@ where ) })?; self.metrics.event_loop_message_sent(message_label); - Ok(()) + Ok(SendStatus::SentNow) } fn send_ref(&mut self, socket_addr: &str, msg: &Res, headers: &TcpHeaders) -> anyhow::Result<()> diff --git a/src/data_availability.rs b/src/data_availability.rs index 0b9807ab1..615a33b5f 100644 --- a/src/data_availability.rs +++ b/src/data_availability.rs @@ -8,7 +8,7 @@ use hyli_modules::modules::da_listener::{DaStreamPoll, SignedDaStream}; use hyli_modules::{bus::SharedMessageBus, modules::Module}; use hyli_modules::{log_error, module_bus_client, module_handle_messages}; use hyli_net::tcp::middleware::{ - middleware_layer, DropOnError, MessageOnly, RetryingSend, TcpServerExt, + middleware_layer, DequeDispatch, DropOnError, MessageOnly, RetryingSend, TcpServerExt, }; use hyli_net::tcp::TcpServerLike; use tokio::task::JoinHandle; @@ -24,11 +24,10 @@ use crate::{ use anyhow::{Context, Result}; use core::str; use std::{ - collections::{BTreeSet, HashMap, VecDeque}, + collections::{BTreeSet, VecDeque}, time::Duration, }; use strum_macros::AsRefStr; -use tokio::task::JoinSet; use tracing::{debug, error, info, trace, warn}; use crate::model::SharedRunContext; @@ -38,17 +37,41 @@ type DaServerStack = hyli_net::tcp_server!( response: DataAvailabilityEvent, middlewares: [ MessageOnly, + DaDequeDispatch, RetryingSend, DropOnError, ] ); -fn with_da_middlewares(server: RawDataAvailabilityServer) -> DaServerStack { +type DaResolverFn = fn(&Blocks, ConsensusProposalHash) -> anyhow::Result; +type DaDequeDispatch = + DequeDispatch; + +fn resolve_da_dispatch_input( + blocks: &Blocks, + hash: ConsensusProposalHash, +) -> anyhow::Result { + match blocks.get(&hash) { + Ok(Some(signed_block)) => Ok(DataAvailabilityEvent::SignedBlock(signed_block)), + Ok(None) => Err(anyhow::anyhow!( + "DequeDispatch missing block for hash {}", + hash + )), + Err(err) => Err(err), + } +} + +fn with_da_middlewares(server: RawDataAvailabilityServer, blocks: Blocks) -> DaServerStack { server .layer(middleware_layer(DropOnError)) .layer(middleware_layer( RetryingSend::new(10, Duration::from_millis(100)).max_per_tick(256), )) + .layer(middleware_layer( + DequeDispatch::new(blocks, resolve_da_dispatch_input as DaResolverFn) + .tick_interval(Duration::from_millis(5)) + .max_queue_len(50_000), + )) .layer(middleware_layer(MessageOnly)) } @@ -100,7 +123,6 @@ impl Module for DataAvailability { buffered_signed_blocks: BTreeSet::new(), catchupper: DaCatchupper::new(catchup_policy, ctx.config.da_max_frame_length), allow_peer_catchup: false, - peer_send_queues: HashMap::new(), }) } @@ -133,9 +155,6 @@ pub struct DataAvailability { catchupper: DaCatchupper, // Gate peer-triggered catchup until genesis outcome is known. allow_peer_catchup: bool, - - // Track blocks to send to each streaming peer (ensures ordering) - peer_send_queues: HashMap>, } #[derive(Debug, Clone, AsRefStr)] @@ -649,6 +668,7 @@ impl DataAvailability { format!("DAServer-{}", self.config.id.clone()).as_str(), ) .await?, + self.blocks.new_handle(), ); let mut catchup_block_receiver = self @@ -658,9 +678,6 @@ impl DataAvailability { let mut first_hole_receiver = self.start_scanning_for_first_hole(); - // Used to send blocks to clients (indexers/peers) - // JoinSet of peer addresses to process one queued send at a time. - let mut catchup_joinset: JoinSet = tokio::task::JoinSet::new(); let mut catchup_task_checker_ticker = tokio::time::interval(std::time::Duration::from_secs(5)); let mut storage_metrics_ticker = tokio::time::interval(std::time::Duration::from_secs(30)); @@ -668,7 +685,7 @@ impl DataAvailability { module_handle_messages! { on_self self, listen evt => { - _ = log_error!(self.handle_mempool_event(evt, &mut server, &mut catchup_joinset).await, "Handling Mempool Event"); + _ = log_error!(self.handle_mempool_event(evt, &mut server).await, "Handling Mempool Event"); } listen evt => { @@ -679,7 +696,7 @@ impl DataAvailability { match cmd { GenesisEvent::GenesisBlock(signed_block) => { debug!("🌱 Genesis block received with validators {:?}", signed_block.consensus_proposal.staking_actions.clone()); - _ = log_error!(self.handle_signed_block(signed_block, &mut server, &mut catchup_joinset).await.context("Handling Genesis block"), "Handling GenesisBlock Event"); + _ = log_error!(self.handle_signed_block(signed_block, &mut server).await.context("Handling Genesis block"), "Handling GenesisBlock Event"); } GenesisEvent::NoGenesis => { self.allow_peer_catchup = true; @@ -712,7 +729,7 @@ impl DataAvailability { } Some(streamed_block) = catchup_block_receiver.recv() => { - if let Some(height) = self.handle_signed_block(streamed_block, &mut server, &mut catchup_joinset).await { + if let Some(height) = self.handle_signed_block(streamed_block, &mut server).await { _ = log_error!(self.catchupper.on_catchup_progress(height), "Catchup transition after streamed block"); } } @@ -723,7 +740,6 @@ impl DataAvailability { _ = log_error!( self.start_streaming_to_peer( start_height, - &mut catchup_joinset, &socket_addr, &mut server, ).await, @@ -739,23 +755,6 @@ impl DataAvailability { } } - // Send one block to a peer as part of "catchup", - // once we have sent all blocks the peer is presumably synchronised. - Some(Ok(peer_ip)) = catchup_joinset.join_next() => { - - #[cfg(test)] - tokio::time::sleep(std::time::Duration::from_millis(100)).await; - - _ = log_error!( - self.handle_send_next_block_to_peer( - peer_ip.clone(), - &mut catchup_joinset, - &mut server - ).await, - "Send next block to peer" - ); - } - Some(hole) = first_hole_receiver.recv() => { info!("Setting backfill start height as {:?}", &hole); self.catchupper.on_first_hole_discovered(hole); @@ -770,63 +769,6 @@ impl DataAvailability { Ok(()) } - async fn handle_send_next_block_to_peer( - &mut self, - peer_ip: String, - catchup_joinset: &mut JoinSet, - server: &mut DaServerStack, - ) -> Result<()> { - if !server.connected(&peer_ip) { - debug!("Peer {} disconnected, removing from send queues", peer_ip); - self.peer_send_queues.remove(&peer_ip); - return Ok(()); - } - - // Get next block from this peer's queue - let hash = match self.peer_send_queues.get_mut(&peer_ip) { - Some(queue) => match queue.pop_front() { - Some(h) => h, - None => { - // Queue is empty - peer is caught up and waiting for new blocks - // Keep them in the map but don't spawn a new task yet - debug!("Peer {} caught up, waiting for new blocks", peer_ip); - return Ok(()); - } - }, - None => { - debug!("Peer {} not in send queues", peer_ip); - return Ok(()); - } - }; - - debug!("📡 Sending block {} to peer {}", &hash, &peer_ip); - if let Ok(Some(signed_block)) = self.blocks.get(&hash) { - // Errors will be handled when sending new blocks, ignore here. - match server.send( - peer_ip.clone(), - DataAvailabilityEvent::SignedBlock(signed_block), - vec![], - ) { - Ok(()) => { - // Successfully sent, continue with next block - catchup_joinset.spawn(async move { peer_ip }); - } - Err(e) => { - warn!("Error sending block {} to peer {}: {:#}", hash, peer_ip, e); - self.peer_send_queues.remove(&peer_ip); - } - } - } else { - error!( - "Block {} not found in storage while sending to peer {}. Should not happen", - &hash, &peer_ip - ); - // Continue anyway with next block - catchup_joinset.spawn(async move { peer_ip }); - } - Ok(()) - } - async fn handle_block_request( &mut self, block_height: BlockHeight, @@ -845,11 +787,10 @@ impl DataAvailability { "📦 Found block at height {}, sending to {}", block_height, socket_addr ); - // Send immediately - this is inserted next in the send queue - if let Err(e) = server.send( + if let Err(e) = self.send_immediate_to_peer( + server, socket_addr.to_string(), DataAvailabilityEvent::SignedBlock(block), - vec![], ) { warn!( "📦 Error while responding to block request at height {} for {}: {:#}", @@ -864,10 +805,10 @@ impl DataAvailability { "📦 Block at height {} not found in storage, sending BlockNotFound to {}", block_height, socket_addr ); - if let Err(e) = server.send( + if let Err(e) = self.send_immediate_to_peer( + server, socket_addr.to_string(), DataAvailabilityEvent::BlockNotFound(block_height), - vec![], ) { warn!( "📦 Error while responding BlockNotFound at height {} for {}: {:#}", @@ -881,10 +822,10 @@ impl DataAvailability { "📦 Error retrieving block at height {}: {:#}", block_height, e ); - if let Err(e) = server.send( + if let Err(e) = self.send_immediate_to_peer( + server, socket_addr.to_string(), DataAvailabilityEvent::BlockNotFound(block_height), - vec![], ) { warn!( "📦 Error while responding BlockNotFound at height {} for {}: {:#}", @@ -902,7 +843,6 @@ impl DataAvailability { &mut self, evt: MempoolBlockEvent, tcp_server: &mut DaServerStack, - catchup_joinset: &mut JoinSet, ) -> Result<()> { match evt { MempoolBlockEvent::BuiltSignedBlock(signed_block) => { @@ -912,9 +852,7 @@ impl DataAvailability { ); // Mempool-produced blocks are local tip updates, not catchup-stream progress. // Feeding them into catchup progress can prematurely complete backfill. - _ = self - .handle_signed_block(signed_block, tcp_server, catchup_joinset) - .await; + _ = self.handle_signed_block(signed_block, tcp_server).await; } MempoolBlockEvent::StartedBuildingBlocks(height) => { debug!( @@ -948,7 +886,6 @@ impl DataAvailability { &mut self, block: SignedBlock, tcp_server: &mut DaServerStack, - catchup_joinset: &mut JoinSet, ) -> Option { let hash = block.hashed(); // if new block is already handled, ignore it @@ -991,13 +928,12 @@ impl DataAvailability { } else { // store block _ = log_error!( - self.add_processed_block(block.clone(), tcp_server, catchup_joinset) - .await, + self.add_processed_block(block.clone(), tcp_server).await, "Adding processed block" ); } - let highest_processed_height = self.pop_buffer(hash, tcp_server, catchup_joinset).await; + let highest_processed_height = self.pop_buffer(hash, tcp_server).await; _ = log_error!(self.blocks.persist(), "Persisting blocks"); let height = block.height(); @@ -1010,7 +946,6 @@ impl DataAvailability { &mut self, mut last_block_hash: ConsensusProposalHash, tcp_server: &mut DaServerStack, - catchup_joinset: &mut JoinSet, ) -> Option { let mut res = None; @@ -1034,7 +969,7 @@ impl DataAvailability { let height = first_buffered.height(); if self - .add_processed_block(first_buffered.clone(), tcp_server, catchup_joinset) + .add_processed_block(first_buffered.clone(), tcp_server) .await .is_ok() { @@ -1083,35 +1018,18 @@ impl DataAvailability { async fn add_processed_block( &mut self, block: SignedBlock, - _tcp_server: &mut DaServerStack, - catchup_joinset: &mut JoinSet, + tcp_server: &mut DaServerStack, ) -> anyhow::Result<()> { self.store_block(&block)?; - let block_hash = block.hashed(); - - // Add new block to all streaming peer queues to ensure ordering - // (instead of broadcasting which can cause out-of-order delivery) - for (peer, queue) in self.peer_send_queues.iter_mut() { - let was_empty = queue.is_empty(); - queue.push_back(block_hash.clone()); - - // If queue was empty (peer was caught up), restart their send task - if was_empty { - debug!( - "Restarting send task for caught-up peer {} with new block {}", - peer, block_hash - ); - let peer_clone = peer.clone(); - catchup_joinset.spawn(async move { peer_clone }); - } else { - debug!( - "Appending block {} to queue for peer {} (queue size: {})", - block_hash, - peer, - queue.len() - ); - } + // Broadcast live updates to all connected peers. + let errors = + tcp_server.broadcast(DataAvailabilityEvent::SignedBlock(block.clone()), vec![]); + for (peer, error) in errors { + warn!( + "Error while broadcasting signed block to {}: {:#}", + peer, error + ); } // Send the block to NodeState for processing @@ -1128,7 +1046,6 @@ impl DataAvailability { async fn start_streaming_to_peer( &mut self, start_height: BlockHeight, - catchup_joinset: &mut JoinSet, peer_ip: &str, server: &mut DaServerStack, ) -> Result<()> { @@ -1177,8 +1094,6 @@ impl DataAvailability { first_missing, peer_ip, e ); } - - self.peer_send_queues.remove(peer_ip); return Ok(()); } @@ -1189,14 +1104,34 @@ impl DataAvailability { processed_block_hashes.len() ); - // Store queue for this peer - new blocks will be appended here let peer_ip_string = peer_ip.to_string(); - self.peer_send_queues - .insert(peer_ip_string.clone(), processed_block_hashes); + for hash in processed_block_hashes { + self.enqueue_hash_for_peer(server, peer_ip_string.clone(), hash)?; + } - // Start the send task for this peer - catchup_joinset.spawn(async move { peer_ip_string }); + Ok(()) + } + fn enqueue_hash_for_peer( + &self, + server: &mut DaServerStack, + peer_ip: String, + hash: ConsensusProposalHash, + ) -> anyhow::Result<()> { + server + .inner_mut() + .middleware_mut() + .enqueue(peer_ip, hash, vec![]) + } + + fn send_immediate_to_peer( + &self, + server: &mut DaServerStack, + peer_ip: String, + event: DataAvailabilityEvent, + ) -> anyhow::Result<()> { + // Bypass DequeDispatch for request/response priority while keeping RetryingSend. + server.inner_mut().inner_mut().send(peer_ip, event, vec![])?; Ok(()) } } @@ -1222,7 +1157,6 @@ pub mod tests { use hyli_modules::utils::da_codec::DataAvailabilityClient; use hyli_net::tcp::TcpServerLike; use staking::state::Staking; - use tokio::task::JoinSet; struct DataAvailabilityTestCtx { pub node_state_bus: NodeStateBusClient, @@ -1252,7 +1186,6 @@ pub mod tests { buffered_signed_blocks: Default::default(), catchupper: Default::default(), allow_peer_catchup: false, - peer_send_queues: HashMap::new(), }; DataAvailabilityTestCtx { @@ -1267,11 +1200,7 @@ pub mod tests { block: SignedBlock, tcp_server: &mut DaServerStack, ) { - let mut catchup_joinset: JoinSet = JoinSet::new(); - _ = self - .da - .handle_signed_block(block.clone(), tcp_server, &mut catchup_joinset) - .await; + _ = self.da.handle_signed_block(block.clone(), tcp_server).await; let block_hash = block.hashed(); let Ok(full_block) = self.node_state.handle_signed_block(block) else { tracing::warn!("Error while handling signed block {}", block_hash); @@ -1309,6 +1238,7 @@ pub mod tests { RawDataAvailabilityServer::start(port, "DaServer") .await .unwrap(), + blocks.new_handle(), ); let bus = super::DABusClient::new_from_bus(crate::bus::SharedMessageBus::new()).await; @@ -1319,7 +1249,6 @@ pub mod tests { buffered_signed_blocks: Default::default(), catchupper: Default::default(), allow_peer_catchup: false, - peer_send_queues: HashMap::new(), }; let mut block = SignedBlock::default(); let mut blocks = vec![]; @@ -1329,20 +1258,14 @@ pub mod tests { block.consensus_proposal.slot = i; } blocks.reverse(); - let mut catchup_joinset: JoinSet = JoinSet::new(); for block in blocks { if block.height().0 == 0 { assert_eq!( - da.handle_signed_block(block, &mut server, &mut catchup_joinset) - .await, + da.handle_signed_block(block, &mut server).await, Some(BlockHeight(9998)) ); } else { - assert_eq!( - da.handle_signed_block(block, &mut server, &mut catchup_joinset) - .await, - None - ); + assert_eq!(da.handle_signed_block(block, &mut server).await, None); } } } @@ -1373,7 +1296,6 @@ pub mod tests { buffered_signed_blocks: Default::default(), catchupper: Default::default(), allow_peer_catchup: false, - peer_send_queues: HashMap::new(), }; let mut block = SignedBlock::default(); @@ -1475,10 +1397,13 @@ pub mod tests { #[test_log::test(tokio::test)] async fn test_da_many_clients_only_last_connected() { let port = find_available_port().await; + let tmpdir = tempfile::tempdir().unwrap().keep(); + let blocks = Blocks::new(&tmpdir).unwrap(); let mut server = super::with_da_middlewares( RawDataAvailabilityServer::start(port, "DaServer") .await .unwrap(), + blocks, ); let client_count = 5usize; @@ -1558,6 +1483,7 @@ pub mod tests { RawDataAvailabilityServer::start(port, "DaServer") .await .unwrap(), + da_sender.da.blocks.new_handle(), ); let receiver_global_bus = crate::bus::SharedMessageBus::new(); @@ -1722,6 +1648,7 @@ pub mod tests { RawDataAvailabilityServer::start(port, "DaServer") .await .unwrap(), + da_sender.da.blocks.new_handle(), ); let receiver_global_bus = crate::bus::SharedMessageBus::new(); @@ -1858,7 +1785,6 @@ pub mod tests { buffered_signed_blocks: Default::default(), catchupper: Default::default(), allow_peer_catchup: false, - peer_send_queues: HashMap::new(), }; // Start DA server @@ -1873,9 +1799,9 @@ pub mod tests { .await .unwrap(); - // Start streaming from block 0 + // Start streaming from block 8 so block 7 can only come from BlockRequest. client - .send(DataAvailabilityRequest::StreamFromHeight(BlockHeight(0))) + .send(DataAvailabilityRequest::StreamFromHeight(BlockHeight(8))) .await .unwrap(); @@ -1897,7 +1823,7 @@ pub mod tests { // Collect responses (use a set to track unique blocks received) let mut received_block_heights = std::collections::HashSet::new(); - let mut received_block_7_from_request = false; + let mut received_block_7 = false; let mut received_block_not_found = false; let mut event_count = 0; let start_time = tokio::time::Instant::now(); @@ -1909,9 +1835,8 @@ pub mod tests { let height = block.height().0; tracing::info!("Received block {} (event #{})", height, event_count); - // Track if block 7 arrives early (from request, not just stream) - if height == 7 && received_block_heights.len() < 5 { - received_block_7_from_request = true; + if height == 7 { + received_block_7 = true; } received_block_heights.insert(height); @@ -1924,8 +1849,8 @@ pub mod tests { DataAvailabilityEvent::MempoolStatusEvent(_) => {} } - // Stop after receiving enough events (at least 8 blocks and BlockNotFound) - if received_block_heights.len() >= 8 && received_block_not_found { + // Stop after receiving the requested block, some streamed blocks, and BlockNotFound. + if received_block_7 && received_block_heights.len() >= 3 && received_block_not_found { break; } @@ -1937,28 +1862,25 @@ pub mod tests { } // Verify results - assert!( - received_block_7_from_request, - "Block 7 should have arrived early (from BlockRequest, not just stream)" - ); + assert!(received_block_7, "Should have received requested block 7"); assert!( received_block_not_found, "Should have received BlockNotFound for block 100" ); assert!( - received_block_heights.len() >= 8, - "Should have received at least 8 different blocks, got {}", + received_block_heights.len() >= 3, + "Should have received at least 3 different blocks, got {}", received_block_heights.len() ); - // Verify we got essential blocks including block 7 + // Verify we got the request block and stream blocks from the requested stream start. assert!( - received_block_heights.contains(&0), - "Should have received block 0" + received_block_heights.contains(&7), + "Should have received block 7 from BlockRequest" ); assert!( - received_block_heights.contains(&7), - "Should have received block 7 (from request)" + received_block_heights.iter().any(|h| *h >= 8), + "Should have received streamed blocks at or above 8" ); tracing::info!("✅ Test passed: BlockRequest works while streaming"); @@ -1966,7 +1888,7 @@ pub mod tests { " - Received {} unique blocks", received_block_heights.len() ); - tracing::info!(" - Block 7 arrived early via BlockRequest (not just stream)"); + tracing::info!(" - Received requested block 7 while streaming from 8"); tracing::info!(" - Got BlockNotFound for non-existent block 100"); tracing::info!(" - Blocks received: {:?}", { let mut v: Vec<_> = received_block_heights.iter().collect(); @@ -2006,7 +1928,6 @@ pub mod tests { buffered_signed_blocks: Default::default(), catchupper: Default::default(), allow_peer_catchup: false, - peer_send_queues: HashMap::new(), }; tokio::spawn(async move { @@ -2077,7 +1998,6 @@ pub mod tests { buffered_signed_blocks: Default::default(), catchupper: Default::default(), allow_peer_catchup: false, - peer_send_queues: HashMap::new(), }; tokio::spawn(async move { @@ -2415,7 +2335,6 @@ pub mod tests { buffered_signed_blocks: Default::default(), catchupper: Default::default(), allow_peer_catchup: false, - peer_send_queues: HashMap::new(), }; tokio::spawn(async move { From 08b46dfaa1251136aaa76c755bb0d9e60b2d7d1b Mon Sep 17 00:00:00 2001 From: Alexandre Careil Date: Fri, 20 Feb 2026 18:27:12 +0100 Subject: [PATCH 16/18] fix turmoil tests --- crates/hyli-net/tests/basic.rs | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/crates/hyli-net/tests/basic.rs b/crates/hyli-net/tests/basic.rs index 3736d0d0d..b341c9e4c 100644 --- a/crates/hyli-net/tests/basic.rs +++ b/crates/hyli-net/tests/basic.rs @@ -10,7 +10,7 @@ use hyli_net::{ net::Sim, tcp::{ p2p_server::{P2PServer, P2PTimeouts}, - Canal, P2PTcpMessage, TcpMessageLabel, + Canal, P2PTcpMessage, TcpMessageLabel, TcpServerLike, }, }; use hyli_turmoil_shims::init_test_meter_provider; @@ -203,7 +203,10 @@ async fn setup_drop_host( _ = interval_start_shutdown.tick() => { if turmoil::elapsed() > Duration::from_millis(duration) { tracing::error!("Current peers {:?}", p2p.peers.keys()); - tracing::error!("Current tcp peers {:?}", p2p.tcp_server.connected_clients()); + tracing::error!( + "Current tcp peers {:?}", + p2p.tcp_server.connected_clients().cloned().collect::>() + ); // Peers map should match all_other_peers assert_eq!(all_other_peers.len(), p2p.peers.keys().len()); From 91d3d9c557e9f3640be683749775e5b269bb9630 Mon Sep 17 00:00:00 2001 From: Alexandre Careil Date: Fri, 20 Feb 2026 18:32:59 +0100 Subject: [PATCH 17/18] fix clippy --- crates/hyli-bus/src/utils/logger.rs | 4 +++- crates/hyli-net/src/tcp.rs | 1 + crates/hyli-net/src/tcp/tcp_client.rs | 2 +- crates/hyli-net/src/tcp/tcp_server.rs | 8 ++++++-- crates/hyli-net/tests/mempool_drop.rs | 4 +++- src/data_availability.rs | 5 ++++- 6 files changed, 18 insertions(+), 6 deletions(-) diff --git a/crates/hyli-bus/src/utils/logger.rs b/crates/hyli-bus/src/utils/logger.rs index 7415c0fad..c093111a3 100644 --- a/crates/hyli-bus/src/utils/logger.rs +++ b/crates/hyli-bus/src/utils/logger.rs @@ -1,4 +1,6 @@ -use anyhow::{Context, Result}; +#[cfg(feature = "instrumentation")] +use anyhow::Context; +use anyhow::Result; #[cfg(feature = "instrumentation")] use opentelemetry::trace::TracerProvider; use tracing::level_filters::LevelFilter; diff --git a/crates/hyli-net/src/tcp.rs b/crates/hyli-net/src/tcp.rs index 3d7e6deec..a9acc5f66 100644 --- a/crates/hyli-net/src/tcp.rs +++ b/crates/hyli-net/src/tcp.rs @@ -43,6 +43,7 @@ pub use hyli_net_traits::TcpMessageLabel; pub type TcpHeaders = Vec<(String, String)>; /// Common interface for `TcpServer` and middleware wrappers. +#[allow(async_fn_in_trait)] pub trait TcpServerLike { type EventOut; type ConnectedClients<'a>: Iterator diff --git a/crates/hyli-net/src/tcp/tcp_client.rs b/crates/hyli-net/src/tcp/tcp_client.rs index c8a2649ff..0b783225a 100644 --- a/crates/hyli-net/src/tcp/tcp_client.rs +++ b/crates/hyli-net/src/tcp/tcp_client.rs @@ -223,7 +223,7 @@ mod tests { client.socket_addr }); - while server.connected_clients().len() == 0 { + while server.connected_clients().is_empty() { _ = tokio::time::timeout(Duration::from_millis(100), server.listen_next()).await; } diff --git a/crates/hyli-net/src/tcp/tcp_server.rs b/crates/hyli-net/src/tcp/tcp_server.rs index 28ee02cf8..167ddb2a7 100644 --- a/crates/hyli-net/src/tcp/tcp_server.rs +++ b/crates/hyli-net/src/tcp/tcp_server.rs @@ -42,6 +42,10 @@ impl<'a> ConnectedClients<'a> { pub fn len(&self) -> usize { self.0.len() } + + pub fn is_empty(&self) -> bool { + self.0.len() == 0 + } } impl<'a> Iterator for ConnectedClients<'a> { @@ -844,8 +848,8 @@ pub mod tests { _ = tokio::time::timeout(Duration::from_millis(200), server.listen_next()).await; let client2_addr = server .connected_clients() + .find(|&addr| addr != &client1_addr) .cloned() - .find(|addr| addr != &client1_addr) .unwrap(); server.raw_send_parallel( @@ -894,8 +898,8 @@ pub mod tests { _ = tokio::time::timeout(Duration::from_millis(200), server.listen_next()).await; let client2_addr = server .connected_clients() + .find(|&addr| addr != &client1_addr) .cloned() - .find(|addr| addr != &client1_addr) .unwrap(); _ = server.send( diff --git a/crates/hyli-net/tests/mempool_drop.rs b/crates/hyli-net/tests/mempool_drop.rs index f8ffbd068..203d7fbfd 100644 --- a/crates/hyli-net/tests/mempool_drop.rs +++ b/crates/hyli-net/tests/mempool_drop.rs @@ -5,7 +5,9 @@ use std::sync::mpsc; use std::time::Duration; use hyli_net::tcp::intercept::{set_message_hook_scoped, MessageAction}; -use hyli_net::tcp::{decode_tcp_payload, tcp_client::TcpClient, tcp_server::TcpServer, TcpEvent}; +use hyli_net::tcp::{ + decode_tcp_payload, tcp_client::TcpClient, tcp_server::TcpServer, TcpEvent, TcpServerLike, +}; use hyli_turmoil_shims::init_test_meter_provider; use sdk::{DataProposal, Transaction}; diff --git a/src/data_availability.rs b/src/data_availability.rs index 615a33b5f..6d12f6882 100644 --- a/src/data_availability.rs +++ b/src/data_availability.rs @@ -1131,7 +1131,10 @@ impl DataAvailability { event: DataAvailabilityEvent, ) -> anyhow::Result<()> { // Bypass DequeDispatch for request/response priority while keeping RetryingSend. - server.inner_mut().inner_mut().send(peer_ip, event, vec![])?; + server + .inner_mut() + .inner_mut() + .send(peer_ip, event, vec![])?; Ok(()) } } From 6cfb176f0041dc356f313a64e94efa8b7a33b03b Mon Sep 17 00:00:00 2001 From: Alexandre Careil Date: Fri, 20 Feb 2026 18:43:37 +0100 Subject: [PATCH 18/18] fix test --- src/data_availability.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/src/data_availability.rs b/src/data_availability.rs index 6d12f6882..bc1ec16ba 100644 --- a/src/data_availability.rs +++ b/src/data_availability.rs @@ -1094,6 +1094,7 @@ impl DataAvailability { first_missing, peer_ip, e ); } + server.drop_peer_stream(peer_ip.to_string()); return Ok(()); }