From 4f37955e587641366e3368f5d59a745a8b238fb7 Mon Sep 17 00:00:00 2001 From: xinyuan <213633968+xinyuan-dev@users.noreply.github.com> Date: Thu, 8 Jan 2026 19:53:38 +0900 Subject: [PATCH] remove peer lookup from packet builder --- monad-raptor/src/r10/nonsystematic/encoder.rs | 2 +- monad-raptorcast/src/auth/protocol.rs | 6 +- monad-raptorcast/src/auth/socket.rs | 2 +- monad-raptorcast/src/decoding.rs | 2 +- monad-raptorcast/src/lib.rs | 57 +++--- monad-raptorcast/src/packet/assembler.rs | 181 ++++++++++-------- monad-raptorcast/src/packet/assigner.rs | 12 +- monad-raptorcast/src/packet/builder.rs | 59 ++---- monad-raptorcast/src/packet/mod.rs | 52 +++-- .../src/raptorcast_secondary/mod.rs | 2 +- monad-raptorcast/src/udp.rs | 4 +- monad-raptorcast/src/util.rs | 8 +- monad-wireauth/src/api.rs | 2 +- 13 files changed, 194 insertions(+), 195 deletions(-) diff --git a/monad-raptor/src/r10/nonsystematic/encoder.rs b/monad-raptor/src/r10/nonsystematic/encoder.rs index ed0c971e8e..4ee9be48ef 100644 --- a/monad-raptor/src/r10/nonsystematic/encoder.rs +++ b/monad-raptor/src/r10/nonsystematic/encoder.rs @@ -55,7 +55,7 @@ pub struct Encoder<'a> { } impl Encoder<'_> { - pub fn new(src: &[u8], symbol_len: usize) -> Result { + pub fn new(src: &[u8], symbol_len: usize) -> Result, Error> { if symbol_len == 0 { return Err(Error::new( ErrorKind::InvalidInput, diff --git a/monad-raptorcast/src/auth/protocol.rs b/monad-raptorcast/src/auth/protocol.rs index e67f378354..a75c9d17b8 100644 --- a/monad-raptorcast/src/auth/protocol.rs +++ b/monad-raptorcast/src/auth/protocol.rs @@ -75,7 +75,7 @@ pub trait AuthenticationProtocol { fn next_deadline(&self) -> Option; - fn metrics(&self) -> ExecutorMetricsChain; + fn metrics(&self) -> ExecutorMetricsChain<'_>; } pub struct WireAuthProtocol { @@ -185,7 +185,7 @@ impl AuthenticationProtocol for WireAuthProtocol { self.api.has_any_session_by_public_key(public_key) } - fn metrics(&self) -> ExecutorMetricsChain { + fn metrics(&self) -> ExecutorMetricsChain<'_> { self.api.metrics() } } @@ -288,7 +288,7 @@ impl AuthenticationProtocol for NoopAuthProtocol

{ false } - fn metrics(&self) -> ExecutorMetricsChain { + fn metrics(&self) -> ExecutorMetricsChain<'_> { ExecutorMetricsChain::default() } } diff --git a/monad-raptorcast/src/auth/socket.rs b/monad-raptorcast/src/auth/socket.rs index 9d10ff13ec..09be4827b2 100644 --- a/monad-raptorcast/src/auth/socket.rs +++ b/monad-raptorcast/src/auth/socket.rs @@ -209,7 +209,7 @@ where } } - pub fn metrics(&self) -> ExecutorMetricsChain { + pub fn metrics(&self) -> ExecutorMetricsChain<'_> { let mut chain = ExecutorMetricsChain::default().push(self.metrics.as_ref()); if let Some(authenticated) = &self.authenticated { chain = chain.chain(authenticated.auth_protocol.metrics()); diff --git a/monad-raptorcast/src/decoding.rs b/monad-raptorcast/src/decoding.rs index 7b9160117e..40b1cc8524 100644 --- a/monad-raptorcast/src/decoding.rs +++ b/monad-raptorcast/src/decoding.rs @@ -330,7 +330,7 @@ where self.pending_messages.consistency_breaches() } - pub fn metrics(&self) -> ExecutorMetricsChain { + pub fn metrics(&self) -> ExecutorMetricsChain<'_> { ExecutorMetricsChain::default() .push(&self.metrics) .push(self.pending_messages.validator.metrics()) diff --git a/monad-raptorcast/src/lib.rs b/monad-raptorcast/src/lib.rs index 18e59ca56e..18183f64f1 100644 --- a/monad-raptorcast/src/lib.rs +++ b/monad-raptorcast/src/lib.rs @@ -86,8 +86,7 @@ pub const UNICAST_MSG_BATCH_SIZE: usize = 32; pub const RAPTORCAST_SOCKET: &str = "raptorcast"; pub const AUTHENTICATED_RAPTORCAST_SOCKET: &str = "authenticated_raptorcast"; -pub(crate) type OwnedMessageBuilder = - packet::MessageBuilder<'static, ST, Arc>>>; +pub(crate) type OwnedMessageBuilder = packet::MessageBuilder<'static, ST>; pub struct RaptorCast where @@ -109,8 +108,8 @@ where current_epoch: Epoch, udp_state: udp::UdpState, - message_builder: OwnedMessageBuilder, - secondary_message_builder: Option>, + message_builder: OwnedMessageBuilder, + secondary_message_builder: Option>, tcp_reader: TcpSocketReader, tcp_writer: TcpSocketWriter, @@ -182,11 +181,10 @@ where let redundancy = Redundancy::from_f32(config.primary_instance.raptor10_redundancy) .expect("primary raptor10_redundancy doesn't fit"); let segment_size = dual_socket.segment_size(config.mtu); - let message_builder = - OwnedMessageBuilder::new(config.shared_key.clone(), peer_discovery_driver.clone()) - .segment_size(segment_size) - .group_id(GroupId::Primary(current_epoch)) - .redundancy(redundancy); + let message_builder = OwnedMessageBuilder::new(config.shared_key.clone()) + .segment_size(segment_size) + .group_id(GroupId::Primary(current_epoch)) + .redundancy(redundancy); let secondary_redundancy = Redundancy::from_f32( config @@ -194,11 +192,10 @@ where .raptor10_fullnode_redundancy_factor, ) .expect("secondary raptor10_redundancy doesn't fit"); - let secondary_message_builder = - OwnedMessageBuilder::new(config.shared_key.clone(), peer_discovery_driver.clone()) - .segment_size(segment_size) - .group_id(GroupId::Primary(current_epoch)) - .redundancy(secondary_redundancy); + let secondary_message_builder = OwnedMessageBuilder::new(config.shared_key.clone()) + .segment_size(segment_size) + .group_id(GroupId::Primary(current_epoch)) + .redundancy(secondary_redundancy); Self { is_dynamic_fullnode, @@ -1358,7 +1355,7 @@ where fn send( dual_socket: &mut auth::DualSocketHandle, peer_discovery_driver: &Arc>>, - message_builder: &mut OwnedMessageBuilder, + message_builder: &mut OwnedMessageBuilder, message: &Bytes, build_target: &BuildTarget, priority: UdpPriority, @@ -1370,14 +1367,18 @@ fn send( { { let dual_socket_cell = std::cell::RefCell::new(&mut *dual_socket); - let mut sink = packet::UdpMessageBatcher::new(UNICAST_MSG_BATCH_SIZE, |rc_chunks| { - dual_socket_cell - .borrow_mut() - .write_unicast_with_priority(rc_chunks, priority); - }); + let mut sink = packet::UdpMessageBatcher::new( + UNICAST_MSG_BATCH_SIZE, + (peer_discovery_driver, &dual_socket_cell), + |rc_chunks| { + dual_socket_cell + .borrow_mut() + .write_unicast_with_priority(rc_chunks, priority); + }, + ); message_builder - .prepare_with_peer_lookup((peer_discovery_driver, &dual_socket_cell)) + .prepare() .group_id(group_id) .build_into(message, build_target, &mut sink) .unwrap_log_on_error(message, build_target); @@ -1389,7 +1390,7 @@ fn send( fn send_with_record( dual_socket: &mut auth::DualSocketHandle, peer_discovery_driver: &Arc>>, - message_builder: &mut OwnedMessageBuilder, + message_builder: &mut OwnedMessageBuilder, message: &Bytes, priority: UdpPriority, target: &NodeId>, @@ -1409,14 +1410,14 @@ fn send_with_record( name_record, dual_socket: &dual_socket_cell, }; - let mut sink = packet::UdpMessageBatcher::new(UNICAST_MSG_BATCH_SIZE, |rc_chunks| { - dual_socket_cell - .borrow_mut() - .write_unicast_with_priority(rc_chunks, priority); - }); + let mut sink = + packet::UdpMessageBatcher::new(UNICAST_MSG_BATCH_SIZE, lookup, |rc_chunks| { + dual_socket_cell + .borrow_mut() + .write_unicast_with_priority(rc_chunks, priority); + }); message_builder - .prepare_with_peer_lookup(&lookup) .build_into(message, &build_target, &mut sink) .unwrap_log_on_error(message, &build_target); } diff --git a/monad-raptorcast/src/packet/assembler.rs b/monad-raptorcast/src/packet/assembler.rs index d7f7663173..3f9e0cf2e4 100644 --- a/monad-raptorcast/src/packet/assembler.rs +++ b/monad-raptorcast/src/packet/assembler.rs @@ -53,19 +53,17 @@ pub enum AssembleMode { } #[allow(clippy::too_many_arguments)] -pub(crate) fn assemble( +pub(crate) fn assemble( key: &ST::KeyPairType, layout: PacketLayout, app_message: &[u8], header_buf: &[u8], assignment: ChunkAssignment>, mode: AssembleMode, - peer_lookup: &PL, - collector: &mut impl Collector, + collector: &mut impl Collector>>, ) -> Result<()> where ST: CertificateSignature, - PL: PeerAddrLookup> + ?Sized, { // step 1. generate the chunks let mut chunks = assignment.generate(layout); @@ -77,26 +75,23 @@ where encode_symbols(app_message, &mut chunks, layout)?; } - // step 3. lookup recipient addresses - lookup_recipient_addrs(&chunks, peer_lookup); - if mode.stream_mode() { for mut batch in owned_merkle_batches(chunks, layout) { - // step 4. sign and write headers for this merkle batch + // step 3. sign and write headers for this merkle batch let merkle_batch = MerkleBatch::from(&mut batch[..]); merkle_batch.write_header::(layout, key, header_buf)?; - // step 5. assemble udp messages - mode.assemble_udp_messages_into(collector, batch, layout); + // step 4. assemble udp messages + mode.assemble_udp_messages_into(collector, batch); } } else { - // step 4. sign and write headers for this merkle batch + // step 3. sign and write headers for this merkle batch for batch in merkle_batches(&mut chunks, layout) { batch.write_header::(layout, key, header_buf)?; } - // step 5. assemble udp messages - mode.assemble_udp_messages_into(collector, chunks, layout); + // step 4. assemble udp messages + mode.assemble_udp_messages_into(collector, chunks); } Ok(()) @@ -108,6 +103,16 @@ pub(crate) struct Chunk { payload: BytesMut, } +impl From> for UdpMessage { + fn from(chunk: Chunk) -> Self { + Self { + recipient: chunk.recipient, + stride: chunk.payload.len(), + payload: chunk.payload.freeze(), + } + } +} + impl Chunk { pub fn new(chunk_id: usize, recipient: Recipient, payload: BytesMut) -> Self { Self { @@ -123,7 +128,7 @@ impl Chunk { // // Change to Arc if we need parallel processing. #[derive(Clone, PartialEq, Eq)] -pub(crate) struct Recipient(Rc>); +pub struct Recipient(Rc>); impl std::hash::Hash for Recipient { fn hash(&self, state: &mut H) { @@ -158,6 +163,20 @@ impl PartialEq for RecipientInner { } } +impl Recipient { + // only used for testing + #[cfg(test)] + pub fn dummy(addr: Option) -> Self { + let mut bytes = format!("{:?}", addr).into_bytes(); + bytes.resize(32, 0u8); + + let pubkey = monad_crypto::NopPubKey::from_bytes(&bytes).expect("pubkey"); + let recipient = Self::new(NodeId::new(pubkey)); + recipient.0.addr.set(addr).expect("addr not set"); + recipient + } +} + impl Recipient { pub fn new(node_id: NodeId) -> Self { let node_hash = compute_hash(&node_id).0; @@ -175,11 +194,12 @@ impl Recipient { } // Expect `lookup` or `set_addr` performed earlier, otherwise panic. + #[allow(unused)] pub(super) fn get_addr(&self) -> Option { *self.0.addr.get().expect("get addr called before lookup") } - fn lookup(&self, handle: &(impl PeerAddrLookup + ?Sized)) -> &Option { + pub fn lookup(&self, handle: &(impl PeerAddrLookup + ?Sized)) -> &Option { self.0.addr.get_or_init(|| { let addr = handle.lookup(&self.0.node_id); if addr.is_none() { @@ -190,6 +210,16 @@ impl Recipient { } } +#[cfg(test)] +pub struct DummyPeerLookup; + +#[cfg(test)] +impl PeerAddrLookup for DummyPeerLookup { + fn lookup(&self, _node_id: &NodeId) -> Option { + panic!("recipient addr should be self contained") + } +} + /// Stuff to include: /// /// - 65 bytes => Signature of sender over hash(rest of message up to merkle proof, concatenated with merkle root) @@ -407,7 +437,7 @@ pub(super) struct MerkleBatch<'a, PT: PubKey> { pub(super) fn merkle_batches( all_chunks: &mut [Chunk], layout: PacketLayout, -) -> impl Iterator> { +) -> impl Iterator> { let batch_len = layout.merkle_batch_len(); debug_assert!(batch_len > 0); all_chunks.chunks_mut(batch_len).map(MerkleBatch::from) @@ -639,16 +669,6 @@ pub(crate) fn build_header( Ok(buffer.freeze()) } -fn lookup_recipient_addrs(chunks: &Vec>, handle: &PL) -where - PT: PubKey, - PL: PeerAddrLookup + ?Sized, -{ - for chunk in chunks { - chunk.recipient.lookup(handle); - } -} - impl AssembleMode { pub fn expected_chunk_order(self) -> Option { match self { @@ -667,95 +687,92 @@ impl AssembleMode { fn assemble_udp_messages_into( self, - collector: &mut impl Collector, + collector: &mut impl Collector>, chunks: Vec>, - layout: PacketLayout, ) { match self { AssembleMode::GsoFull | AssembleMode::GsoBestEffort => { - Self::assemble_gso_udp_messages_into(collector, chunks, layout); + Self::assemble_gso_udp_messages_into(collector, chunks); } AssembleMode::RoundRobin => { - Self::assemble_standalone_udp_messages_into(collector, chunks, layout); + Self::assemble_standalone_udp_messages_into(collector, chunks); } } } fn assemble_standalone_udp_messages_into( - collector: &mut impl Collector, + collector: &mut impl Collector>, chunks: Vec>, - layout: PacketLayout, ) { collector.reserve(chunks.len()); for chunk in chunks { - let Some(dest) = chunk.recipient.get_addr() else { - continue; - }; - collector.push(UdpMessage { - dest, - payload: chunk.payload.freeze(), - stride: layout.segment_len(), - }); + collector.push(chunk.into()); } } fn assemble_gso_udp_messages_into( - collector: &mut impl Collector, + collector: &mut impl Collector>, chunks: Vec>, - layout: PacketLayout, ) { - struct AggregatedChunk { - dest: SocketAddr, - payload: BytesMut, - } - - impl AggregatedChunk { - fn from_chunk(chunk: Chunk, dest: SocketAddr) -> Self { - Self { - dest, - payload: chunk.payload, - } - } - - fn into_udp_message(self, stride: usize) -> UdpMessage { - UdpMessage { - dest: self.dest, - payload: self.payload.freeze(), - stride, - } - } - } - - let stride = layout.segment_len(); let mut agg_chunk = None; for chunk in chunks { - let Some(dest) = chunk.recipient.get_addr() else { - // skip chunks with unknown recipient - continue; - }; - let Some(agg) = &mut agg_chunk else { // first chunk, start a new aggregation - agg_chunk = Some(AggregatedChunk::from_chunk(chunk, dest)); + agg_chunk = Some(AggregatedChunk::from(chunk)); continue; }; - if agg.dest == dest { - // same recipient, merge the payload BytesMut::unsplit - // is O(1) when the chunk payload are consecutive. - agg.payload.unsplit(chunk.payload); - continue; + if let Some(to_flush) = agg.aggregate(chunk) { + // different recipient, flush the previous message + collector.push(to_flush.into()); } - - // different recipient, flush the previous message - let next_agg = AggregatedChunk::from_chunk(chunk, dest); - let udp_msg = std::mem::replace(agg, next_agg).into_udp_message(stride); - collector.push(udp_msg); } if let Some(agg) = agg_chunk.take() { - collector.push(agg.into_udp_message(stride)); + collector.push(agg.into()); + } + } +} + +// used in gso grouping +struct AggregatedChunk { + recipient: Recipient, + payload: BytesMut, + stride: usize, +} + +impl AggregatedChunk { + #[must_use] + fn aggregate(&mut self, chunk: Chunk) -> Option { + if self.recipient == chunk.recipient && chunk.payload.len() == self.stride { + // same recipient, merge the payload. BytesMut::unsplit is + // O(1) when the chunk payload are consecutive. + self.payload.unsplit(chunk.payload); + return None; + } + + let new_agg = chunk.into(); + Some(std::mem::replace(self, new_agg)) + } +} + +impl From> for AggregatedChunk { + fn from(chunk: Chunk) -> Self { + Self { + recipient: chunk.recipient, + stride: chunk.payload.len(), + payload: chunk.payload, + } + } +} + +impl From> for UdpMessage { + fn from(agg_chunk: AggregatedChunk) -> Self { + UdpMessage { + recipient: agg_chunk.recipient, + stride: agg_chunk.stride, + payload: agg_chunk.payload.freeze(), } } } diff --git a/monad-raptorcast/src/packet/assigner.rs b/monad-raptorcast/src/packet/assigner.rs index 13719e843a..39f0adb6c6 100644 --- a/monad-raptorcast/src/packet/assigner.rs +++ b/monad-raptorcast/src/packet/assigner.rs @@ -334,7 +334,7 @@ pub(crate) trait ChunkAssigner { &self, num_symbols: usize, preferred_order: Option, - ) -> Result>; + ) -> Result>; } impl ChunkAssigner for Replicated { @@ -342,7 +342,7 @@ impl ChunkAssigner for Replicated { &self, num_symbols: usize, preferred_order: Option, - ) -> Result> { + ) -> Result> { if self.recipients.is_empty() { tracing::warn!("no recipients specified for chunk assigner"); return Ok(ChunkAssignment::empty()); @@ -415,7 +415,7 @@ impl Partitioned { } } - fn assign_gso(&self, num_symbols: usize) -> ChunkAssignment { + fn assign_gso(&self, num_symbols: usize) -> ChunkAssignment<'_, PT> { let num_nodes = self.weighted_nodes.len(); let mut assignment = ChunkAssignment::with_capacity(num_nodes); assignment.hint_order(ChunkOrder::GsoFriendly); @@ -439,7 +439,7 @@ impl ChunkAssigner for Partitioned { &self, num_symbols: usize, _preferred_order: Option, - ) -> Result> { + ) -> Result> { if self.weighted_nodes.is_empty() { tracing::warn!("no nodes specified for partitioned chunk assigner"); return Ok(ChunkAssignment::empty()); @@ -511,7 +511,7 @@ impl ChunkAssigner for StakeBasedWithRC { &self, num_symbols: usize, _preferred_order: Option, - ) -> Result> { + ) -> Result> { if self.validator_set.is_empty() { tracing::warn!("no nodes specified for partitioned chunk assigner"); return Ok(ChunkAssignment::empty()); @@ -625,7 +625,7 @@ mod tests { Self { slices } } - fn assign_chunks(&self) -> ChunkAssignment { + fn assign_chunks(&self) -> ChunkAssignment<'_, PT> { let mut assignment = ChunkAssignment::with_capacity(self.slices.len()); for slice in &self.slices { assignment.push(&slice.0, slice.1.clone()); diff --git a/monad-raptorcast/src/packet/builder.rs b/monad-raptorcast/src/packet/builder.rs index b6adf1e618..d576e5d849 100644 --- a/monad-raptorcast/src/packet/builder.rs +++ b/monad-raptorcast/src/packet/builder.rs @@ -26,7 +26,7 @@ use rand::Rng; use super::{ assembler::{self, build_header, AssembleMode, BroadcastType, PacketLayout}, assigner::{self, ChunkAssignment}, - BuildError, ChunkAssigner, PeerAddrLookup, UdpMessage, + BuildError, ChunkAssigner, UdpMessage, }; use crate::{ message::MAX_MESSAGE_SIZE, @@ -82,14 +82,12 @@ enum TimestampMode { RealTime, } -pub struct MessageBuilder<'key, ST, PL> +pub struct MessageBuilder<'key, ST> where ST: CertificateSignatureRecoverable, - PL: PeerAddrLookup>, { // support both owned or borrowed keys key: MaybeArc<'key, ST::KeyPairType>, - peer_lookup: PL, // required fields group_id: Option, @@ -102,15 +100,13 @@ where assemble_mode: AssembleMode, } -impl<'key, ST, PL> Clone for MessageBuilder<'key, ST, PL> +impl<'key, ST> Clone for MessageBuilder<'key, ST> where ST: CertificateSignatureRecoverable, - PL: PeerAddrLookup> + Clone, { fn clone(&self) -> Self { Self { key: self.key.clone(), - peer_lookup: self.peer_lookup.clone(), group_id: self.group_id, redundancy: self.redundancy, unix_ts_ms: self.unix_ts_ms, @@ -121,13 +117,12 @@ where } } -impl<'key, ST, PL> MessageBuilder<'key, ST, PL> +impl<'key, ST> MessageBuilder<'key, ST> where ST: CertificateSignatureRecoverable, - PL: PeerAddrLookup>, { #[allow(private_bounds)] - pub fn new(key: K, peer_lookup: PL) -> Self + pub fn new(key: K) -> Self where K: Into>, { @@ -137,7 +132,6 @@ where Self { key, - peer_lookup, // default fields redundancy: None, @@ -192,24 +186,9 @@ where } // ----- Prepare override builder ----- - pub fn prepare(&self) -> PreparedMessageBuilder<'_, 'key, ST, PL, PL> { + pub fn prepare(&self) -> PreparedMessageBuilder<'_, 'key, ST> { PreparedMessageBuilder { base: self, - peer_lookup: None, - group_id: None, - } - } - - pub fn prepare_with_peer_lookup( - &self, - peer_lookup: PL2, - ) -> PreparedMessageBuilder<'_, 'key, ST, PL, PL2> - where - PL2: PeerAddrLookup>, - { - PreparedMessageBuilder { - base: self, - peer_lookup: Some(peer_lookup), group_id: None, } } @@ -222,7 +201,7 @@ where collector: &mut C, ) -> Result<()> where - C: super::Collector, + C: super::Collector>>, { self.prepare() .build_into(app_message, build_target, collector) @@ -232,29 +211,24 @@ where &self, app_message: &[u8], build_target: &BuildTarget, - ) -> Result> { + ) -> Result>>> { self.prepare().build_vec(app_message, build_target) } } -pub struct PreparedMessageBuilder<'base, 'key, ST, PL, PL2> +pub struct PreparedMessageBuilder<'base, 'key, ST> where ST: CertificateSignatureRecoverable, - PL: PeerAddrLookup>, - PL2: PeerAddrLookup>, { - base: &'base MessageBuilder<'key, ST, PL>, + base: &'base MessageBuilder<'key, ST>, // Add extra override fields as needed - peer_lookup: Option, group_id: Option, } -impl<'base, 'key, ST, PL, PL2> PreparedMessageBuilder<'base, 'key, ST, PL, PL2> +impl<'base, 'key, ST> PreparedMessageBuilder<'base, 'key, ST> where ST: CertificateSignatureRecoverable, - PL: PeerAddrLookup>, - PL2: PeerAddrLookup>, { // ----- Setters for overrides ----- pub fn group_id(mut self, group_id: GroupId) -> Self { @@ -423,7 +397,7 @@ where collector: &mut C, ) -> Result<()> where - C: super::Collector, + C: super::Collector>>, { // figure out the layout of the packet let segment_size = self.unwrap_segment_size()?; @@ -459,18 +433,13 @@ where )?; // assemble the chunks's headers and content - let peer_lookup: &dyn PeerAddrLookup<_> = match &self.peer_lookup { - Some(pl) => pl, - None => &self.base.peer_lookup, - }; - assembler::assemble::( + assembler::assemble::( self.base.key.as_ref(), layout, app_message, &header, assignment, assemble_mode, - peer_lookup, collector, )?; @@ -481,7 +450,7 @@ where &self, app_message: &[u8], build_target: &BuildTarget, - ) -> Result> { + ) -> Result>>> { let mut packets = Vec::new(); self.build_into(app_message, build_target, &mut packets)?; Ok(packets) diff --git a/monad-raptorcast/src/packet/mod.rs b/monad-raptorcast/src/packet/mod.rs index 9ad66b772a..c857a62214 100644 --- a/monad-raptorcast/src/packet/mod.rs +++ b/monad-raptorcast/src/packet/mod.rs @@ -36,8 +36,8 @@ use crate::{ }; #[derive(Debug, Clone)] -pub struct UdpMessage { - pub dest: SocketAddr, +pub struct UdpMessage { + pub recipient: Recipient, pub payload: Bytes, pub stride: usize, } @@ -64,7 +64,7 @@ pub enum BuildError { RedundancyTooHigh, } -pub(crate) trait PeerAddrLookup { +pub trait PeerAddrLookup { fn lookup(&self, node_id: &NodeId) -> Option; } @@ -91,7 +91,7 @@ pub fn build_messages( where ST: CertificateSignatureRecoverable, { - let builder = MessageBuilder::new(key, known_addresses) + let builder = MessageBuilder::new(key) .segment_size(segment_size) .group_id(group_id) .unix_ts_ms(unix_ts_ms) @@ -103,7 +103,11 @@ where packets .into_iter() - .map(|msg| (msg.dest, msg.payload)) + .filter_map(|msg| { + msg.recipient + .lookup(known_addresses) + .map(|dest| (dest, msg.payload)) + }) .collect() } @@ -172,15 +176,15 @@ where } } -impl PeerAddrLookup for &T +impl PeerAddrLookup for F where - PT: PubKey, - T: PeerAddrLookup, + F: Fn(&NodeId) -> Option, { fn lookup(&self, node_id: &NodeId) -> Option { - (*self).lookup(node_id) + self(node_id) } } + impl PeerAddrLookup for std::sync::Arc where PT: PubKey, @@ -212,21 +216,23 @@ where // Batch assembled UdpMessages into UnicastMsgs for consumption in // dataplane, flush on buffer full and on drop. -pub struct UdpMessageBatcher +pub struct UdpMessageBatcher where F: FnMut(monad_dataplane::UnicastMsg), { + peer_lookup: PL, buffer_size: usize, buffer: monad_dataplane::UnicastMsg, sink: F, } -impl UdpMessageBatcher +impl UdpMessageBatcher where F: FnMut(monad_dataplane::UnicastMsg), { - pub fn new(buffer_size: usize, sink: F) -> Self { + pub fn new(buffer_size: usize, peer_lookup: PL, sink: F) -> Self { Self { + peer_lookup, buffer_size, buffer: monad_dataplane::UnicastMsg { msgs: Vec::with_capacity(buffer_size), @@ -252,7 +258,7 @@ where } } -impl Drop for UdpMessageBatcher +impl Drop for UdpMessageBatcher where F: FnMut(monad_dataplane::UnicastMsg), { @@ -261,11 +267,16 @@ where } } -impl Collector for UdpMessageBatcher +impl Collector> for UdpMessageBatcher where F: FnMut(monad_dataplane::UnicastMsg), + PT: PubKey, + PL: PeerAddrLookup, { - fn push(&mut self, item: UdpMessage) { + fn push(&mut self, item: UdpMessage) { + let Some(dest) = item.recipient.lookup(&self.peer_lookup) else { + return; + }; let stride = item.stride as u16; // uninitialized, set the stride @@ -285,7 +296,7 @@ where self.buffer.stride = stride; } - self.buffer.msgs.push((item.dest, item.payload)); + self.buffer.msgs.push((*dest, item.payload)); if self.buffer.msgs.len() >= self.buffer_size { self.flush(); @@ -298,22 +309,23 @@ mod tests { use std::cell::RefCell; use super::*; + use crate::packet::assembler::DummyPeerLookup; #[test] fn test_udp_message_batcher() { let collected_batches: RefCell> = Default::default(); - let dest = "127.0.0.1:3000".parse().unwrap(); + let recipient = Recipient::dummy("127.0.0.1:3000".parse().ok()); let msg_batch_1 = vec![ // 4 messages, each with stride 4 - UdpMessage { dest, payload: Bytes::from(vec![42; 4]), stride: 4 }; 4 + UdpMessage { recipient: recipient.clone(), payload: Bytes::from(vec![42; 4]), stride: 4 }; 4 ]; let msg_batch_2 = vec![ // 4 messages, each with stride 5 - UdpMessage { dest, payload: Bytes::from(vec![43; 5]), stride: 5 }; 4 + UdpMessage { recipient, payload: Bytes::from(vec![43; 5]), stride: 5 }; 4 ]; - let mut batcher = UdpMessageBatcher::new(3, |batch| { + let mut batcher = UdpMessageBatcher::new(3, DummyPeerLookup, |batch| { collected_batches.borrow_mut().push(batch); }); for msg in msg_batch_1 { diff --git a/monad-raptorcast/src/raptorcast_secondary/mod.rs b/monad-raptorcast/src/raptorcast_secondary/mod.rs index c4463882d0..474b5ab3d2 100644 --- a/monad-raptorcast/src/raptorcast_secondary/mod.rs +++ b/monad-raptorcast/src/raptorcast_secondary/mod.rs @@ -383,7 +383,7 @@ where } } - fn metrics(&self) -> ExecutorMetricsChain { + fn metrics(&self) -> ExecutorMetricsChain<'_> { match &self.role { Role::Publisher(publisher) => publisher.metrics().into(), Role::Client(client) => client.metrics().into(), diff --git a/monad-raptorcast/src/udp.rs b/monad-raptorcast/src/udp.rs index 9c48303f62..14d0b4f1f1 100644 --- a/monad-raptorcast/src/udp.rs +++ b/monad-raptorcast/src/udp.rs @@ -1205,7 +1205,7 @@ mod tests { #[case] raptorcast: bool, #[case] should_succeed: bool, ) { - let (key, validators, known_addresses) = validator_set(); + let (key, validators, _known_addresses) = validator_set(); let epoch_validators = validators.view_without(vec![&NodeId::new(key.pubkey())]); let target = if raptorcast { BuildTarget::Raptorcast(epoch_validators) @@ -1213,7 +1213,7 @@ mod tests { BuildTarget::Broadcast(epoch_validators.into()) }; let app_msg = vec![0; app_msg_len]; - let messages = MessageBuilder::new(&key, known_addresses) + let messages = MessageBuilder::new(&key) .segment_size(DEFAULT_SEGMENT_SIZE as usize) .group_id(GroupId::Primary(EPOCH)) .redundancy(Redundancy::from_u8(1)) diff --git a/monad-raptorcast/src/util.rs b/monad-raptorcast/src/util.rs index ca1bbf6fce..3a7530c810 100644 --- a/monad-raptorcast/src/util.rs +++ b/monad-raptorcast/src/util.rs @@ -130,7 +130,7 @@ impl FullNodes

{ Self { list: nodes } } - pub fn view(&self) -> FullNodesView

{ + pub fn view(&self) -> FullNodesView<'_, P> { FullNodesView(&self.list) } } @@ -443,7 +443,7 @@ where &self.round_span } - fn empty_iterator(&self) -> GroupIterator { + fn empty_iterator(&self) -> GroupIterator<'_, ST> { GroupIterator { group: self, num_consumed: usize::MAX, @@ -461,7 +461,7 @@ where &self, author_id: &NodeId>, seed: usize, - ) -> GroupIterator { + ) -> GroupIterator<'_, ST> { // Hint for the index of author_id within self.sorted_other_peers. // We want to skip it when iterating the peers for broadcasting. let author_id_ix = if let Some(root_vid) = self.validator_id { @@ -647,7 +647,7 @@ where &self, msg_group_id: GroupId, msg_author: &NodeId>, // skipped when iterating RaptorCast group - ) -> Option> { + ) -> Option> { let rebroadcast_group = match msg_group_id { GroupId::Primary(msg_epoch) => self.validator_map.get(&msg_epoch)?, GroupId::Secondary(msg_round) => { diff --git a/monad-wireauth/src/api.rs b/monad-wireauth/src/api.rs index 38a846a9db..78a2399efa 100644 --- a/monad-wireauth/src/api.rs +++ b/monad-wireauth/src/api.rs @@ -91,7 +91,7 @@ impl> API { } } - pub fn metrics(&self) -> ExecutorMetricsChain { + pub fn metrics(&self) -> ExecutorMetricsChain<'_> { ExecutorMetricsChain::default() .push(&self.metrics) .push(self.state.metrics())