diff --git a/cmd/ethrex/ethrex.rs b/cmd/ethrex/ethrex.rs index 9af8f5ef63..e2482ca3f3 100644 --- a/cmd/ethrex/ethrex.rs +++ b/cmd/ethrex/ethrex.rs @@ -264,18 +264,15 @@ async fn main() { let block_producer_engine = ethrex_dev::block_producer::start_block_producer(url, authrpc_jwtsecret.into(), head_block_hash, max_tries, 1000, ethrex_core::Address::default()); tracker.spawn(block_producer_engine); } else { - let networking = ethrex_net::start_network( + ethrex_net::start_network( local_p2p_node, tracker.clone(), - udp_socket_addr, - tcp_socket_addr, bootnodes, signer, peer_table.clone(), store, ) - .into_future(); - tracker.spawn(networking); + .await.expect("Network starts"); tracker.spawn(ethrex_net::periodically_show_peer_stats(peer_table)); } } diff --git a/crates/networking/p2p/discv4/helpers.rs b/crates/networking/p2p/discv4/helpers.rs new file mode 100644 index 0000000000..b0f21e42f0 --- /dev/null +++ b/crates/networking/p2p/discv4/helpers.rs @@ -0,0 +1,29 @@ +use std::time::{Duration, SystemTime, UNIX_EPOCH}; + +pub fn get_msg_expiration_from_seconds(seconds: u64) -> u64 { + (SystemTime::now() + Duration::from_secs(seconds)) + .duration_since(UNIX_EPOCH) + .unwrap_or_default() + .as_secs() +} + +pub fn is_msg_expired(expiration: u64) -> bool { + // this cast to a signed integer is needed as the rlp decoder doesn't take into account the sign + // otherwise if a msg contains a negative expiration, it would pass since as it would wrap around the u64. + (expiration as i64) < (current_unix_time() as i64) +} + +pub fn elapsed_time_since(unix_timestamp: u64) -> u64 { + let time = SystemTime::UNIX_EPOCH + std::time::Duration::from_secs(unix_timestamp); + SystemTime::now() + .duration_since(time) + .unwrap_or_default() + .as_secs() +} + +pub fn current_unix_time() -> u64 { + SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or_default() + .as_secs() +} diff --git a/crates/networking/p2p/discv4/lookup.rs b/crates/networking/p2p/discv4/lookup.rs new file mode 100644 index 0000000000..c1a62f4418 --- /dev/null +++ b/crates/networking/p2p/discv4/lookup.rs @@ -0,0 +1,420 @@ +use super::{ + helpers::get_msg_expiration_from_seconds, + messages::{FindNodeMessage, Message}, + server::DiscoveryError, +}; +use crate::{ + kademlia::{bucket_number, MAX_NODES_PER_BUCKET}, + node_id_from_signing_key, + types::Node, + P2PContext, +}; +use ethrex_core::H512; +use k256::ecdsa::SigningKey; +use rand::rngs::OsRng; +use std::{collections::HashSet, net::SocketAddr, sync::Arc, time::Duration}; +use tokio::net::UdpSocket; +use tracing::debug; + +#[derive(Clone, Debug)] +pub struct Discv4LookupHandler { + ctx: P2PContext, + udp_socket: Arc, + interval_minutes: u64, +} + +impl Discv4LookupHandler { + pub fn new(ctx: P2PContext, udp_socket: Arc, interval_minutes: u64) -> Self { + Self { + ctx, + udp_socket, + interval_minutes, + } + } + + /// Starts a tokio scheduler that: + /// - performs random lookups to discover new nodes. + /// + /// **Random lookups** + /// + /// Random lookups work in the following manner: + /// 1. Every 30min we spawn three concurrent lookups: one closest to our pubkey + /// and three other closest to random generated pubkeys. + /// 2. Every lookup starts with the closest nodes from our table. + /// Each lookup keeps track of: + /// - Peers that have already been asked for nodes + /// - Peers that have been already seen + /// - Potential peers to query for nodes: a vector of up to 16 entries holding the closest peers to the pubkey. 
+ /// This vector is initially filled with nodes from our table. + /// 3. We send a `find_node` to the closest 3 nodes (that we have not yet asked) from the pubkey. + /// 4. We wait for the neighbors response and push or replace those that are closer to the potential peers array. + /// 5. We select three other nodes from the potential peers vector and do the same until one lookup + /// doesn't have any node to ask. + /// + /// See more https://github.com/ethereum/devp2p/blob/master/discv4.md#recursive-lookup + pub fn start(&self, initial_interval_wait_seconds: u64) { + self.ctx.tracker.spawn({ + let self_clone = self.clone(); + async move { + self_clone + .start_lookup_loop(initial_interval_wait_seconds) + .await; + } + }); + } + + async fn start_lookup_loop(&self, initial_interval_wait_seconds: u64) { + let mut interval = tokio::time::interval(Duration::from_secs(self.interval_minutes)); + tokio::time::sleep(Duration::from_secs(initial_interval_wait_seconds)).await; + + loop { + // first tick is immediate, + interval.tick().await; + + debug!("Starting lookup"); + + // lookup closest to our node_id + self.ctx.tracker.spawn({ + let self_clone = self.clone(); + async move { + self_clone + .recursive_lookup(self_clone.ctx.local_node.node_id) + .await + } + }); + + // lookup closest to 3 random keys + for _ in 0..3 { + let random_pub_key = SigningKey::random(&mut OsRng); + self.ctx.tracker.spawn({ + let self_clone = self.clone(); + async move { + self_clone + .recursive_lookup(node_id_from_signing_key(&random_pub_key)) + .await + } + }); + } + + debug!("Lookup finished"); + } + } + + async fn recursive_lookup(&self, target: H512) { + // lookups start with the closest nodes to the target from our table + let mut peers_to_ask: Vec = self.ctx.table.lock().await.get_closest_nodes(target); + // stores the peers in peers_to_ask + the peers that were in peers_to_ask but were replaced by closer targets + let mut seen_peers: HashSet = HashSet::default(); + let mut asked_peers = HashSet::default(); + + seen_peers.insert(self.ctx.local_node.node_id); + for node in &peers_to_ask { + seen_peers.insert(node.node_id); + } + + loop { + let (nodes_found, queries) = self.lookup(target, &mut asked_peers, &peers_to_ask).await; + + for node in nodes_found { + if !seen_peers.contains(&node.node_id) { + seen_peers.insert(node.node_id); + self.peers_to_ask_push(&mut peers_to_ask, target, node); + } + } + + // the lookup finishes when there are no more queries to do + // that happens when we have asked all the peers + if queries == 0 { + break; + } + } + } + + async fn lookup( + &self, + target: H512, + asked_peers: &mut HashSet, + nodes_to_ask: &Vec, + ) -> (Vec, u32) { + // send FIND_NODE as much as three times + let alpha = 3; + let mut queries = 0; + let mut nodes = vec![]; + + for node in nodes_to_ask { + if asked_peers.contains(&node.node_id) { + continue; + } + let mut locked_table = self.ctx.table.lock().await; + if let Some(peer) = locked_table.get_by_node_id_mut(node.node_id) { + // if the peer has an ongoing find_node request, don't query + if peer.find_node_request.is_none() { + let (tx, mut receiver) = tokio::sync::mpsc::unbounded_channel::>(); + peer.new_find_node_request_with_sender(tx); + + // Release the lock + drop(locked_table); + + queries += 1; + asked_peers.insert(node.node_id); + if let Ok(mut found_nodes) = self + .find_node_and_wait_for_response(*node, target, &mut receiver) + .await + { + nodes.append(&mut found_nodes); + } + + if let Some(peer) = 
self.ctx.table.lock().await.get_by_node_id_mut(node.node_id) + { + peer.find_node_request = None; + }; + } + } + + if queries == alpha { + break; + } + } + + (nodes, queries) + } + + /// Adds a node to `peers_to_ask` if there's space; otherwise, replaces the farthest node + /// from `target` if the new node is closer. + fn peers_to_ask_push(&self, peers_to_ask: &mut Vec, target: H512, node: Node) { + let distance = bucket_number(target, node.node_id); + + if peers_to_ask.len() < MAX_NODES_PER_BUCKET { + peers_to_ask.push(node); + return; + } + + // replace this node for the one whose distance to the target is the highest + let (mut idx_to_replace, mut highest_distance) = (None, 0); + + for (i, peer) in peers_to_ask.iter().enumerate() { + let current_distance = bucket_number(peer.node_id, target); + + if distance < current_distance && current_distance >= highest_distance { + highest_distance = current_distance; + idx_to_replace = Some(i); + } + } + + if let Some(idx) = idx_to_replace { + peers_to_ask[idx] = node; + } + } + + async fn find_node_and_wait_for_response( + &self, + node: Node, + target_id: H512, + request_receiver: &mut tokio::sync::mpsc::UnboundedReceiver>, + ) -> Result, DiscoveryError> { + let expiration: u64 = get_msg_expiration_from_seconds(20); + + let msg = Message::FindNode(FindNodeMessage::new(target_id, expiration)); + + let mut buf = Vec::new(); + msg.encode_with_header(&mut buf, &self.ctx.signer); + let bytes_sent = self + .udp_socket + .send_to(&buf, SocketAddr::new(node.ip, node.udp_port)) + .await + .map_err(DiscoveryError::MessageSendFailure)?; + + if bytes_sent != buf.len() { + return Err(DiscoveryError::PartialMessageSent); + } + + let mut nodes = vec![]; + loop { + // wait as much as 5 seconds for the response + match tokio::time::timeout(Duration::from_secs(5), request_receiver.recv()).await { + Ok(Some(mut found_nodes)) => { + nodes.append(&mut found_nodes); + if nodes.len() == MAX_NODES_PER_BUCKET { + return Ok(nodes); + }; + } + Ok(None) => { + return Ok(nodes); + } + Err(_) => { + // timeout expired + return Ok(nodes); + } + } + } + } +} + +#[cfg(test)] +mod tests { + use tokio::time::sleep; + + use super::*; + use crate::discv4::server::{ + tests::{ + connect_servers, fill_table_with_random_nodes, insert_random_node_on_custom_bucket, + start_discovery_server, + }, + Discv4Server, + }; + + fn lookup_handler_from_server(server: Discv4Server) -> Discv4LookupHandler { + Discv4LookupHandler::new( + server.ctx.clone(), + server.udp_socket.clone(), + server.lookup_interval_minutes, + ) + } + + #[tokio::test] + /** This test tests the lookup function, the idea is as follows: + * - We'll start two discovery servers (`a` & `b`) that will connect between each other + * - We'll insert random nodes to the server `a`` to fill its table + * - We'll forcedly run `lookup` and validate that a `find_node` request was sent + * by checking that new nodes have been inserted to the table + * + * This test for only one lookup, and not recursively. + */ + async fn discovery_server_lookup() -> Result<(), DiscoveryError> { + let mut server_a = start_discovery_server(8000, true).await?; + let mut server_b = start_discovery_server(8001, true).await?; + + fill_table_with_random_nodes(server_a.ctx.table.clone()).await; + + // because the table is filled, before making the connection, remove a node from the `b` bucket + // otherwise it won't be added. 
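+ // (`b_bucket` below is the bucket index derived from the XOR distance between the
+ // two node ids, i.e. the bucket of `a`'s table where `b` would be stored.)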
+ let b_bucket = bucket_number( + server_a.ctx.local_node.node_id, + server_b.ctx.local_node.node_id, + ); + let node_id_to_remove = server_a.ctx.table.lock().await.buckets()[b_bucket].peers[0] + .node + .node_id; + server_a + .ctx + .table + .lock() + .await + .replace_peer_on_custom_bucket(node_id_to_remove, b_bucket); + + connect_servers(&mut server_a, &mut server_b).await?; + + // now we are going to run a lookup with us as the target + let closets_peers_to_b_from_a = server_a + .ctx + .table + .lock() + .await + .get_closest_nodes(server_b.ctx.local_node.node_id); + let nodes_to_ask = server_b + .ctx + .table + .lock() + .await + .get_closest_nodes(server_b.ctx.local_node.node_id); + + let lookup_handler = lookup_handler_from_server(server_b.clone()); + lookup_handler + .lookup( + server_b.ctx.local_node.node_id, + &mut HashSet::default(), + &nodes_to_ask, + ) + .await; + + // find_node sent, allow some time for `a` to respond + sleep(Duration::from_secs(2)).await; + + // now all peers should've been inserted + for peer in closets_peers_to_b_from_a { + let table = server_b.ctx.table.lock().await; + let node = table.get_by_node_id(peer.node_id); + // sometimes nodes can send ourselves as a neighbor + // make sure we don't add it + if peer.node_id == server_b.ctx.local_node.node_id { + assert!(node.is_none()); + } else { + assert!(node.is_some()); + } + } + Ok(()) + } + + #[tokio::test] + /** This test tests the lookup function, the idea is as follows: + * - We'll start four discovery servers (`a`, `b`, `c` & `d`) + * - `a` will be connected to `b`, `b` will be connected to `c` and `c` will be connected to `d`. + * - The server `d` will have its table filled with mock nodes + * - We'll run a recursive lookup on server `a` and we expect to end with `b`, `c`, `d` and its mock nodes + */ + async fn discovery_server_recursive_lookup() -> Result<(), DiscoveryError> { + let mut server_a = start_discovery_server(8002, true).await?; + let mut server_b = start_discovery_server(8003, true).await?; + let mut server_c = start_discovery_server(8004, true).await?; + let mut server_d = start_discovery_server(8005, true).await?; + + connect_servers(&mut server_a, &mut server_b).await?; + connect_servers(&mut server_b, &mut server_c).await?; + connect_servers(&mut server_c, &mut server_d).await?; + + // now we fill the server_d table with 3 random nodes + // the reason we don't put more is because this nodes won't respond (as they don't are not real servers) + // and so we will have to wait for the timeout on each node, which will only slow down the test + for _ in 0..3 { + insert_random_node_on_custom_bucket(server_d.ctx.table.clone(), 0).await; + } + + let mut expected_peers = vec![]; + expected_peers.extend( + server_b + .ctx + .table + .lock() + .await + .get_closest_nodes(server_a.ctx.local_node.node_id), + ); + expected_peers.extend( + server_c + .ctx + .table + .lock() + .await + .get_closest_nodes(server_a.ctx.local_node.node_id), + ); + expected_peers.extend( + server_d + .ctx + .table + .lock() + .await + .get_closest_nodes(server_a.ctx.local_node.node_id), + ); + + let lookup_handler = lookup_handler_from_server(server_a.clone()); + + // we'll run a recursive lookup closest to the server itself + lookup_handler + .recursive_lookup(server_a.ctx.local_node.node_id) + .await; + + // sometimes nodes can send ourselves as a neighbor + // make sure we don't add it + for peer in expected_peers { + let table = server_a.ctx.table.lock().await; + let node = table.get_by_node_id(peer.node_id); + + if 
peer.node_id == server_a.ctx.local_node.node_id { + assert!(node.is_none()); + } else { + assert!(node.is_some()); + } + } + + Ok(()) + } +} diff --git a/crates/networking/p2p/discv4.rs b/crates/networking/p2p/discv4/messages.rs similarity index 96% rename from crates/networking/p2p/discv4.rs rename to crates/networking/p2p/discv4/messages.rs index f5dce656ce..e950881348 100644 --- a/crates/networking/p2p/discv4.rs +++ b/crates/networking/p2p/discv4/messages.rs @@ -1,5 +1,4 @@ -use std::time::{Duration, SystemTime, UNIX_EPOCH}; - +use super::helpers::current_unix_time; use crate::types::{Endpoint, Node, NodeRecord}; use bytes::BufMut; use ethrex_core::{H256, H512, H520}; @@ -12,36 +11,6 @@ use ethrex_rlp::{ use k256::ecdsa::{RecoveryId, Signature, SigningKey, VerifyingKey}; use sha3::{Digest, Keccak256}; -//todo add tests -pub fn get_expiration(seconds: u64) -> u64 { - (SystemTime::now() + Duration::from_secs(seconds)) - .duration_since(UNIX_EPOCH) - .unwrap() - .as_secs() -} - -pub fn is_expired(expiration: u64) -> bool { - // this cast to a signed integer is needed as the rlp decoder doesn't take into account the sign - // otherwise a potential negative expiration would pass since it would take 2^64. - (expiration as i64) - < SystemTime::now() - .duration_since(UNIX_EPOCH) - .unwrap() - .as_secs() as i64 -} - -pub fn time_since_in_hs(time: u64) -> u64 { - let time = SystemTime::UNIX_EPOCH + std::time::Duration::from_secs(time); - SystemTime::now().duration_since(time).unwrap().as_secs() / 3600 -} - -pub fn time_now_unix() -> u64 { - SystemTime::now() - .duration_since(UNIX_EPOCH) - .unwrap() - .as_secs() -} - #[derive(Debug, PartialEq)] pub enum PacketDecodeErr { #[allow(unused)] @@ -51,7 +20,6 @@ pub enum PacketDecodeErr { InvalidSignature, } -#[allow(unused)] #[derive(Debug)] pub struct Packet { hash: H256, @@ -126,12 +94,7 @@ impl Packet { } #[derive(Debug, Eq, PartialEq)] -// NOTE: All messages could have more fields than specified by the spec. -// Those additional fields should be ignored, and the message must be accepted. -// TODO: remove when all variants are used -#[allow(dead_code)] pub(crate) enum Message { - /// A ping message. Should be responded to with a Pong message. 
Ping(PingMessage), Pong(PongMessage), FindNode(FindNodeMessage), @@ -140,6 +103,20 @@ pub(crate) enum Message { ENRResponse(ENRResponseMessage), } +impl std::fmt::Display for Message { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let variant = match self { + Message::Ping(_) => "Ping", + Message::Pong(_) => "Pong", + Message::FindNode(_) => "FindNode", + Message::Neighbors(_) => "Neighbors", + Message::ENRRequest(_) => "ENRRequest", + Message::ENRResponse(_) => "ENRResponse", + }; + write!(f, "{}", variant) + } +} + impl Message { pub fn encode_with_header(&self, buf: &mut dyn BufMut, node_signer: &SigningKey) { let signature_size = 65_usize; @@ -319,7 +296,7 @@ impl Default for FindNodeRequest { fn default() -> Self { Self { nodes_sent: 0, - sent_at: time_now_unix(), + sent_at: current_unix_time(), tx: None, } } @@ -380,8 +357,6 @@ impl PongMessage { } } - // TODO: remove when used - #[allow(unused)] pub fn with_enr_seq(self, enr_seq: u64) -> Self { Self { enr_seq: Some(enr_seq), diff --git a/crates/networking/p2p/discv4/mod.rs b/crates/networking/p2p/discv4/mod.rs new file mode 100644 index 0000000000..51104b57d5 --- /dev/null +++ b/crates/networking/p2p/discv4/mod.rs @@ -0,0 +1,4 @@ +pub(super) mod helpers; +mod lookup; +pub(super) mod messages; +pub mod server; diff --git a/crates/networking/p2p/discv4/server.rs b/crates/networking/p2p/discv4/server.rs new file mode 100644 index 0000000000..7f4f1396ad --- /dev/null +++ b/crates/networking/p2p/discv4/server.rs @@ -0,0 +1,877 @@ +use super::{ + helpers::{ + current_unix_time, elapsed_time_since, get_msg_expiration_from_seconds, is_msg_expired, + }, + lookup::Discv4LookupHandler, + messages::{ + ENRRequestMessage, ENRResponseMessage, Message, NeighborsMessage, Packet, PingMessage, + PongMessage, + }, +}; +use crate::{ + bootnode::BootNode, + handle_peer_as_initiator, + kademlia::MAX_NODES_PER_BUCKET, + types::{Endpoint, Node, NodeRecord}, + KademliaTable, P2PContext, +}; +use ethrex_core::H256; +use k256::ecdsa::{signature::hazmat::PrehashVerifier, Signature, VerifyingKey}; +use std::{ + collections::HashSet, + net::{IpAddr, Ipv4Addr, SocketAddr}, + sync::Arc, + time::Duration, +}; +use tokio::{net::UdpSocket, sync::MutexGuard}; +use tracing::{debug, error}; + +const MAX_DISC_PACKET_SIZE: usize = 1280; +const PROOF_EXPIRATION_IN_HS: u64 = 12; + +// These interval times are arbitrary numbers, maybe we should read them from a cfg or a cli param +const REVALIDATION_INTERVAL_IN_SECONDS: u64 = 30; +const PEERS_RANDOM_LOOKUP_TIME_IN_MIN: u64 = 30; + +#[derive(Debug)] +#[allow(dead_code)] +pub enum DiscoveryError { + BindSocket(std::io::Error), + MessageSendFailure(std::io::Error), + PartialMessageSent, + MessageExpired, + InvalidMessage(String), +} + +/// Implements the discv4 protocol see: https://github.com/ethereum/devp2p/blob/master/discv4.md +#[derive(Debug, Clone)] +pub struct Discv4Server { + pub(super) ctx: P2PContext, + pub(super) udp_socket: Arc, + pub(super) revalidation_interval_seconds: u64, + pub(super) lookup_interval_minutes: u64, +} + +impl Discv4Server { + /// Initializes a Discv4 UDP socket and creates a new `Discv4Server` instance. + /// Returns an error if the socket binding fails. 
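+ ///
+ /// A minimal usage sketch (assuming a fully built `P2PContext` named `ctx` and a list of
+ /// `bootnodes`; error handling elided):
+ ///
+ /// ```ignore
+ /// let discv4 = Discv4Server::try_new(ctx).await?;
+ /// discv4.start(bootnodes).await?;
+ /// ```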
+ pub async fn try_new(ctx: P2PContext) -> Result { + let udp_socket = UdpSocket::bind(ctx.local_node.udp_addr()) + .await + .map_err(DiscoveryError::BindSocket)?; + + Ok(Self { + ctx, + udp_socket: Arc::new(udp_socket), + revalidation_interval_seconds: REVALIDATION_INTERVAL_IN_SECONDS, + lookup_interval_minutes: PEERS_RANDOM_LOOKUP_TIME_IN_MIN, + }) + } + + /// Initializes the discovery server. It: + /// - Spawns tasks to handle incoming messages and revalidate known nodes. + /// - Loads bootnodes to establish initial peer connections. + /// - Starts the lookup handler via [`Discv4LookupHandler`] to periodically search for new peers. + pub async fn start(&self, bootnodes: Vec) -> Result<(), DiscoveryError> { + let lookup_handler = Discv4LookupHandler::new( + self.ctx.clone(), + self.udp_socket.clone(), + self.lookup_interval_minutes, + ); + + self.ctx.tracker.spawn({ + let self_clone = self.clone(); + async move { self_clone.receive().await } + }); + self.ctx.tracker.spawn({ + let self_clone = self.clone(); + async move { self_clone.start_revalidation().await } + }); + self.load_bootnodes(bootnodes).await; + lookup_handler.start(10); + + Ok(()) + } + + async fn load_bootnodes(&self, bootnodes: Vec) { + for bootnode in bootnodes { + let node = Node { + ip: bootnode.socket_address.ip(), + udp_port: bootnode.socket_address.port(), + // TODO: udp port can differ from tcp port. + // see https://github.com/lambdaclass/ethrex/issues/905 + tcp_port: bootnode.socket_address.port(), + node_id: bootnode.node_id, + }; + if let Err(e) = self + .try_add_peer_and_ping(node, self.ctx.table.lock().await) + .await + { + debug!("Error while adding bootnode to table: {:?}", e); + }; + } + } + + pub async fn receive(&self) { + let mut buf = vec![0; MAX_DISC_PACKET_SIZE]; + + loop { + let (read, from) = match self.udp_socket.recv_from(&mut buf).await { + Ok(result) => result, + Err(e) => { + error!("Error receiving data from socket: {e}. 
Stopping discovery server"); + return; + } + }; + debug!("Received {read} bytes from {from}"); + + match Packet::decode(&buf[..read]) { + Err(e) => error!("Could not decode packet: {:?}", e), + Ok(packet) => { + let msg = packet.get_message(); + let msg_name = msg.to_string(); + debug!("Message: {:?} from {}", msg, packet.get_node_id()); + if let Err(e) = self.handle_message(packet, from).await { + debug!("Error while processing {} message: {:?}", msg_name, e); + }; + } + } + } + } + + async fn handle_message(&self, packet: Packet, from: SocketAddr) -> Result<(), DiscoveryError> { + match packet.get_message() { + Message::Ping(msg) => { + if is_msg_expired(msg.expiration) { + return Err(DiscoveryError::MessageExpired); + }; + + let node = Node { + ip: from.ip(), + udp_port: from.port(), + tcp_port: msg.from.tcp_port, + node_id: packet.get_node_id(), + }; + self.pong(packet.get_hash(), node).await?; + + let peer = { + let table = self.ctx.table.lock().await; + table.get_by_node_id(packet.get_node_id()).cloned() + }; + + let Some(peer) = peer else { + self.try_add_peer_and_ping(node, self.ctx.table.lock().await) + .await?; + return Ok(()); + }; + + // if peer was in the table and last ping was 12 hs ago + // we need to re ping to re-validate the endpoint proof + if elapsed_time_since(peer.last_ping) / 3600 >= PROOF_EXPIRATION_IN_HS { + self.ping(node, self.ctx.table.lock().await).await?; + } + if let Some(enr_seq) = msg.enr_seq { + if enr_seq > peer.record.seq && peer.is_proven { + debug!("Found outdated enr-seq, sending an enr_request"); + self.send_enr_request(peer.node, self.ctx.table.lock().await) + .await?; + } + } + + Ok(()) + } + Message::Pong(msg) => { + if is_msg_expired(msg.expiration) { + return Err(DiscoveryError::MessageExpired); + } + + let peer = { + let table = self.ctx.table.lock().await; + table.get_by_node_id(packet.get_node_id()).cloned() + }; + let Some(peer) = peer else { + return Err(DiscoveryError::InvalidMessage("not known node".into())); + }; + + let Some(ping_hash) = peer.last_ping_hash else { + return Err(DiscoveryError::InvalidMessage( + "node did not send a previous ping".into(), + )); + }; + if ping_hash != msg.ping_hash { + return Err(DiscoveryError::InvalidMessage( + "hash did not match the last corresponding ping".into(), + )); + } + + // all validations went well, mark as answered and start a rlpx connection + self.ctx + .table + .lock() + .await + .pong_answered(peer.node.node_id, current_unix_time()); + if let Some(enr_seq) = msg.enr_seq { + if enr_seq > peer.record.seq { + debug!("Found outdated enr-seq, send an enr_request"); + self.send_enr_request(peer.node, self.ctx.table.lock().await) + .await?; + } + } + + // We won't initiate a connection if we are already connected. + // This will typically be the case when revalidating a node. 
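+ // Otherwise a task is spawned below to open an RLPx connection to the peer,
+ // with us as the initiator (`handle_peer_as_initiator`).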
+ if peer.is_connected { + return Ok(()); + } + + let ctx = self.ctx.clone(); + self.ctx + .tracker + .spawn(async move { handle_peer_as_initiator(ctx, peer.node).await }); + Ok(()) + } + Message::FindNode(msg) => { + if is_msg_expired(msg.expiration) { + return Err(DiscoveryError::MessageExpired); + }; + let node = { + let table = self.ctx.table.lock().await; + table.get_by_node_id(packet.get_node_id()).cloned() + }; + + let Some(node) = node else { + return Err(DiscoveryError::InvalidMessage("not a known node".into())); + }; + if !node.is_proven { + return Err(DiscoveryError::InvalidMessage("node isn't proven".into())); + } + + let nodes = { + let table = self.ctx.table.lock().await; + table.get_closest_nodes(msg.target) + }; + let nodes_chunks = nodes.chunks(4); + let expiration = get_msg_expiration_from_seconds(20); + + debug!("Sending neighbors!"); + // we are sending the neighbors in 4 different messages as not to exceed the + // maximum packet size + for nodes in nodes_chunks { + let neighbors = + Message::Neighbors(NeighborsMessage::new(nodes.to_vec(), expiration)); + let mut buf = Vec::new(); + neighbors.encode_with_header(&mut buf, &self.ctx.signer); + + let bytes_sent = self + .udp_socket + .send_to(&buf, from) + .await + .map_err(DiscoveryError::MessageSendFailure)?; + + if bytes_sent != buf.len() { + return Err(DiscoveryError::PartialMessageSent); + } + } + + Ok(()) + } + Message::Neighbors(neighbors_msg) => { + if is_msg_expired(neighbors_msg.expiration) { + return Err(DiscoveryError::MessageExpired); + }; + + let mut table_lock = self.ctx.table.lock().await; + + let Some(node) = table_lock.get_by_node_id_mut(packet.get_node_id()) else { + return Err(DiscoveryError::InvalidMessage("not a known node".into())); + }; + + let Some(req) = &mut node.find_node_request else { + return Err(DiscoveryError::InvalidMessage( + "find node request not sent".into(), + )); + }; + if current_unix_time().saturating_sub(req.sent_at) >= 60 { + node.find_node_request = None; + return Err(DiscoveryError::InvalidMessage( + "find_node request expired after one minute".into(), + )); + } + + let nodes = &neighbors_msg.nodes; + let total_nodes_sent = req.nodes_sent + nodes.len(); + + if total_nodes_sent > MAX_NODES_PER_BUCKET { + node.find_node_request = None; + return Err(DiscoveryError::InvalidMessage( + "sent more than allowed nodes".into(), + )); + } + + // update the number of node_sent + // and forward the nodes sent if a channel is attached + req.nodes_sent = total_nodes_sent; + if let Some(tx) = &req.tx { + let _ = tx.send(nodes.clone()); + } + + if total_nodes_sent == MAX_NODES_PER_BUCKET { + debug!("Neighbors request has been fulfilled"); + node.find_node_request = None; + } + + // release the lock early + // as we might be a long time pinging all the new nodes + drop(table_lock); + + debug!("Storing neighbors in our table!"); + for node in nodes { + let _ = self + .try_add_peer_and_ping(*node, self.ctx.table.lock().await) + .await; + } + + Ok(()) + } + Message::ENRRequest(msg) => { + if is_msg_expired(msg.expiration) { + return Err(DiscoveryError::MessageExpired); + } + let Ok(node_record) = + NodeRecord::from_node(self.ctx.local_node, self.ctx.enr_seq, &self.ctx.signer) + else { + return Err(DiscoveryError::InvalidMessage( + "could not build local node record".into(), + )); + }; + let msg = + Message::ENRResponse(ENRResponseMessage::new(packet.get_hash(), node_record)); + let mut buf = vec![]; + msg.encode_with_header(&mut buf, &self.ctx.signer); + + let bytes_sent = self + .udp_socket + 
.send_to(&buf, from) + .await + .map_err(DiscoveryError::MessageSendFailure)?; + + if bytes_sent != buf.len() { + return Err(DiscoveryError::PartialMessageSent); + } + + Ok(()) + } + Message::ENRResponse(msg) => { + let mut table_lock = self.ctx.table.lock().await; + let peer = table_lock.get_by_node_id_mut(packet.get_node_id()); + let Some(peer) = peer else { + return Err(DiscoveryError::InvalidMessage("Peer not known".into())); + }; + + let Some(req_hash) = peer.enr_request_hash else { + return Err(DiscoveryError::InvalidMessage( + "Discarding enr-response as enr-request wasn't sent".into(), + )); + }; + if req_hash != msg.request_hash { + return Err(DiscoveryError::InvalidMessage( + "Discarding enr-response did not match enr-request hash".into(), + )); + } + peer.enr_request_hash = None; + + if msg.node_record.seq < peer.record.seq { + return Err(DiscoveryError::InvalidMessage( + "msg node record is lower than the one we have".into(), + )); + } + + let record = msg.node_record.decode_pairs(); + let Some(id) = record.id else { + return Err(DiscoveryError::InvalidMessage( + "msg node record does not have required `id` field".into(), + )); + }; + + // https://github.com/ethereum/devp2p/blob/master/enr.md#v4-identity-scheme + let signature_valid = match id.as_str() { + "v4" => { + let digest = msg.node_record.get_signature_digest(); + let Some(public_key) = record.secp256k1 else { + return Err(DiscoveryError::InvalidMessage( + "signature could not be verified because public key was not provided".into(), + )); + }; + let signature_bytes = msg.node_record.signature.as_bytes(); + let Ok(signature) = Signature::from_slice(&signature_bytes[0..64]) else { + return Err(DiscoveryError::InvalidMessage( + "signature could not be build from msg signature bytes".into(), + )); + }; + let Ok(verifying_key) = + VerifyingKey::from_sec1_bytes(public_key.as_bytes()) + else { + return Err(DiscoveryError::InvalidMessage( + "public key could no be built from msg pub key bytes".into(), + )); + }; + verifying_key.verify_prehash(&digest, &signature).is_ok() + } + _ => false, + }; + if !signature_valid { + return Err(DiscoveryError::InvalidMessage( + "Signature verification invalid".into(), + )); + } + + if let Some(ip) = record.ip { + peer.node.ip = IpAddr::from(Ipv4Addr::from_bits(ip)); + } + if let Some(tcp_port) = record.tcp_port { + peer.node.tcp_port = tcp_port; + } + if let Some(udp_port) = record.udp_port { + peer.node.udp_port = udp_port; + } + peer.record = msg.node_record.clone(); + debug!( + "Node with id {:?} record has been successfully updated", + peer.node.node_id + ); + Ok(()) + } + } + } + + /// Starts a tokio scheduler that: + /// - performs periodic revalidation of the current nodes (sends a ping to the old nodes). + /// + /// **Peer revalidation** + /// + /// Peers revalidation works in the following manner: + /// 1. Every `revalidation_interval_seconds` we ping the 3 least recently pinged peers + /// 2. In the next iteration we check if they have answered + /// - if they have: we increment the liveness field by one + /// - otherwise we decrement it by the current value / 3. + /// 3. 
If the liveness field is 0, then we delete it and insert a new one from the replacements table + /// + /// See more https://github.com/ethereum/devp2p/blob/master/discv4.md#kademlia-table + async fn start_revalidation(&self) { + let mut interval = + tokio::time::interval(Duration::from_secs(self.revalidation_interval_seconds)); + + // first tick starts immediately + interval.tick().await; + + let mut previously_pinged_peers = HashSet::new(); + loop { + interval.tick().await; + debug!("Running peer revalidation"); + + // first check that the peers we ping have responded + for node_id in previously_pinged_peers { + let mut table_lock = self.ctx.table.lock().await; + let Some(peer) = table_lock.get_by_node_id_mut(node_id) else { + continue; + }; + + if let Some(has_answered) = peer.revalidation { + if has_answered { + peer.increment_liveness(); + } else { + peer.decrement_liveness(); + } + } + + peer.revalidation = None; + + if peer.liveness == 0 { + let new_peer = table_lock.replace_peer(node_id); + if let Some(new_peer) = new_peer { + let _ = self.ping(new_peer.node, table_lock).await; + } + } + } + + // now send a ping to the least recently pinged peers + // this might be too expensive to run if our table is filled + // maybe we could just pick them randomly + let peers = self + .ctx + .table + .lock() + .await + .get_least_recently_pinged_peers(3); + previously_pinged_peers = HashSet::default(); + for peer in peers { + debug!("Pinging peer {:?} to re-validate!", peer.node.node_id); + let _ = self.ping(peer.node, self.ctx.table.lock().await).await; + previously_pinged_peers.insert(peer.node.node_id); + let mut table = self.ctx.table.lock().await; + let peer = table.get_by_node_id_mut(peer.node.node_id); + if let Some(peer) = peer { + peer.revalidation = Some(false); + } + } + + debug!("Peer revalidation finished"); + } + } + + /// Attempts to add a node to the Kademlia table and send a ping if necessary. + /// + /// - If the node is **not found** in the table and there is enough space, it will be added, + /// and a ping message will be sent to verify connectivity. + /// - If the node is **already present**, no action is taken. 
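+ ///
+ /// Callers hand over the table lock guard they already hold, e.g. (as used elsewhere
+ /// in this file):
+ ///
+ /// ```ignore
+ /// self.try_add_peer_and_ping(node, self.ctx.table.lock().await).await?;
+ /// ```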
+ async fn try_add_peer_and_ping<'a>( + &self, + node: Node, + mut table_lock: MutexGuard<'a, KademliaTable>, + ) -> Result<(), DiscoveryError> { + // sanity check to make sure we are not storing ourselves + // a case that may happen in a neighbor message for example + if node.node_id == self.ctx.local_node.node_id { + return Ok(()); + } + + if let (Some(peer), true) = table_lock.insert_node(node) { + self.ping(peer.node, table_lock).await?; + }; + Ok(()) + } + + async fn ping<'a>( + &self, + node: Node, + mut table_lock: MutexGuard<'a, KademliaTable>, + ) -> Result<(), DiscoveryError> { + let mut buf = Vec::new(); + let expiration: u64 = get_msg_expiration_from_seconds(20); + let from = Endpoint { + ip: self.ctx.local_node.ip, + udp_port: self.ctx.local_node.udp_port, + tcp_port: self.ctx.local_node.tcp_port, + }; + let to = Endpoint { + ip: node.ip, + udp_port: node.udp_port, + tcp_port: node.tcp_port, + }; + + let ping = + Message::Ping(PingMessage::new(from, to, expiration).with_enr_seq(self.ctx.enr_seq)); + ping.encode_with_header(&mut buf, &self.ctx.signer); + let bytes_sent = self + .udp_socket + .send_to(&buf, node.udp_addr()) + .await + .map_err(DiscoveryError::MessageSendFailure)?; + + if bytes_sent != buf.len() { + return Err(DiscoveryError::PartialMessageSent); + } + + let hash = H256::from_slice(&buf[0..32]); + table_lock.update_peer_ping(node.node_id, Some(hash), current_unix_time()); + + Ok(()) + } + + async fn pong(&self, ping_hash: H256, node: Node) -> Result<(), DiscoveryError> { + let mut buf = Vec::new(); + let expiration: u64 = get_msg_expiration_from_seconds(20); + let to = Endpoint { + ip: node.ip, + udp_port: node.udp_port, + tcp_port: node.tcp_port, + }; + + let pong = Message::Pong( + PongMessage::new(to, ping_hash, expiration).with_enr_seq(self.ctx.enr_seq), + ); + pong.encode_with_header(&mut buf, &self.ctx.signer); + + let bytes_sent = self + .udp_socket + .send_to(&buf, node.udp_addr()) + .await + .map_err(DiscoveryError::MessageSendFailure)?; + + if bytes_sent != buf.len() { + Err(DiscoveryError::PartialMessageSent) + } else { + Ok(()) + } + } + + async fn send_enr_request<'a>( + &self, + node: Node, + mut table_lock: MutexGuard<'a, KademliaTable>, + ) -> Result<(), DiscoveryError> { + let mut buf = Vec::new(); + let expiration: u64 = get_msg_expiration_from_seconds(20); + let enr_req = Message::ENRRequest(ENRRequestMessage::new(expiration)); + enr_req.encode_with_header(&mut buf, &self.ctx.signer); + + let bytes_sent = self + .udp_socket + .send_to(&buf, node.udp_addr()) + .await + .map_err(DiscoveryError::MessageSendFailure)?; + if bytes_sent != buf.len() { + return Err(DiscoveryError::PartialMessageSent); + } + + let hash = H256::from_slice(&buf[0..32]); + if let Some(peer) = table_lock.get_by_node_id_mut(node.node_id) { + peer.enr_request_hash = Some(hash); + }; + + Ok(()) + } +} + +#[cfg(test)] +pub(super) mod tests { + use super::*; + use crate::{ + node_id_from_signing_key, rlpx::message::Message as RLPxMessage, MAX_MESSAGES_TO_BROADCAST, + }; + use ethrex_storage::{EngineType, Store}; + use k256::ecdsa::SigningKey; + use rand::rngs::OsRng; + use std::net::{IpAddr, Ipv4Addr}; + use tokio::{sync::Mutex, time::sleep}; + + pub async fn insert_random_node_on_custom_bucket( + table: Arc>, + bucket_idx: usize, + ) { + let node_id = node_id_from_signing_key(&SigningKey::random(&mut OsRng)); + let node = Node { + ip: IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), + tcp_port: 0, + udp_port: 0, + node_id, + }; + table + .lock() + .await + 
.insert_node_on_custom_bucket(node, bucket_idx); + } + + pub async fn fill_table_with_random_nodes(table: Arc>) { + for i in 0..256 { + for _ in 0..16 { + insert_random_node_on_custom_bucket(table.clone(), i).await; + } + } + } + + pub async fn start_discovery_server( + udp_port: u16, + should_start_server: bool, + ) -> Result { + let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), udp_port); + let signer = SigningKey::random(&mut OsRng); + let node_id = node_id_from_signing_key(&signer); + let local_node = Node { + ip: addr.ip(), + node_id, + udp_port, + tcp_port: udp_port, + }; + + let storage = + Store::new("temp.db", EngineType::InMemory).expect("Failed to create test DB"); + let table = Arc::new(Mutex::new(KademliaTable::new(node_id))); + let (broadcast, _) = tokio::sync::broadcast::channel::<(tokio::task::Id, Arc)>( + MAX_MESSAGES_TO_BROADCAST, + ); + let tracker = tokio_util::task::TaskTracker::new(); + + let ctx = P2PContext { + local_node, + enr_seq: current_unix_time(), + tracker: tracker.clone(), + signer, + table, + storage, + broadcast, + }; + + let discv4 = Discv4Server::try_new(ctx).await?; + + if should_start_server { + tracker.spawn({ + let discv4 = discv4.clone(); + async move { + discv4.receive().await; + } + }); + } + + Ok(discv4) + } + + /// connects two mock servers by pinging a to b + pub async fn connect_servers( + server_a: &mut Discv4Server, + server_b: &mut Discv4Server, + ) -> Result<(), DiscoveryError> { + server_a + .try_add_peer_and_ping(server_b.ctx.local_node, server_a.ctx.table.lock().await) + .await?; + // allow some time for the server to respond + sleep(Duration::from_secs(1)).await; + Ok(()) + } + + #[tokio::test] + /** This is a end to end test on the discovery server, the idea is as follows: + * - We'll start two discovery servers (`a` & `b`) to ping between each other + * - We'll make `b` ping `a`, and validate that the connection is right + * - Then we'll wait for a revalidation where we expect everything to be the same + * - We'll do this five 5 more times + * - Then we'll stop server `a` so that it doesn't respond to re-validations + * - We expect server `b` to remove node `a` from its table after 3 re-validations + * To make this run faster, we'll change the revalidation time to be every 2secs + */ + async fn discovery_server_revalidation() -> Result<(), DiscoveryError> { + let mut server_a = start_discovery_server(7998, true).await?; + let mut server_b = start_discovery_server(7999, true).await?; + + connect_servers(&mut server_a, &mut server_b).await?; + + server_b.revalidation_interval_seconds = 2; + + // start revalidation server + server_b.ctx.tracker.spawn({ + let server_b = server_b.clone(); + async move { server_b.start_revalidation().await } + }); + + for _ in 0..5 { + sleep(Duration::from_millis(2500)).await; + // by now, b should've send a revalidation to a + let table = server_b.ctx.table.lock().await; + let node = table.get_by_node_id(server_a.ctx.local_node.node_id); + assert!(node.is_some_and(|n| n.revalidation.is_some())); + } + + // make sure that `a` has responded too all the re-validations + // we can do that by checking the liveness + { + let table = server_b.ctx.table.lock().await; + let node = table.get_by_node_id(server_a.ctx.local_node.node_id); + assert_eq!(node.map_or(0, |n| n.liveness), 6); + } + + // now, stopping server `a` is not trivial + // so we'll instead change its port, so that no one responds + { + let mut table = server_b.ctx.table.lock().await; + let node = 
table.get_by_node_id_mut(server_a.ctx.local_node.node_id); + if let Some(node) = node { + node.node.udp_port = 0 + }; + } + + // now the liveness field should start decreasing until it gets to 0 + // which should happen in 3 re-validations + for _ in 0..2 { + sleep(Duration::from_millis(2500)).await; + let table = server_b.ctx.table.lock().await; + let node = table.get_by_node_id(server_a.ctx.local_node.node_id); + assert!(node.is_some_and(|n| n.revalidation.is_some())); + } + sleep(Duration::from_millis(2500)).await; + + // finally, `a`` should not exist anymore + let table = server_b.ctx.table.lock().await; + assert!(table + .get_by_node_id(server_a.ctx.local_node.node_id) + .is_none()); + Ok(()) + } + + #[tokio::test] + /** + * This test verifies the exchange and update of ENR (Ethereum Node Record) messages. + * The test follows these steps: + * + * 1. Start two nodes. + * 2. Wait until they establish a connection. + * 3. Assert that they exchange their records and store them + * 3. Modify the ENR (node record) of one of the nodes. + * 4. Send a new ping message and check that an ENR request was triggered. + * 5. Verify that the updated node record has been correctly received and stored. + */ + async fn discovery_enr_message() -> Result<(), DiscoveryError> { + let mut server_a = start_discovery_server(8006, true).await?; + let mut server_b = start_discovery_server(8007, true).await?; + + connect_servers(&mut server_a, &mut server_b).await?; + + // wait some time for the enr request-response finishes + sleep(Duration::from_millis(2500)).await; + + let expected_record = NodeRecord::from_node( + server_b.ctx.local_node, + current_unix_time(), + &server_b.ctx.signer, + ) + .expect("Node record is created from node"); + + let server_a_peer_b = server_a + .ctx + .table + .lock() + .await + .get_by_node_id(server_b.ctx.local_node.node_id) + .cloned() + .unwrap(); + + // we only match the pairs, as the signature and seq will change + // because they are calculated with the current time + assert!(server_a_peer_b.record.decode_pairs() == expected_record.decode_pairs()); + + // Modify server_a's record of server_b with an incorrect TCP port. + // This simulates an outdated or incorrect entry in the node table. + server_a + .ctx + .table + .lock() + .await + .get_by_node_id_mut(server_b.ctx.local_node.node_id) + .unwrap() + .node + .tcp_port = 10; + + // update the enr_seq of server_b so that server_a notices it is outdated + // and sends a request to update it + server_b.ctx.enr_seq = current_unix_time(); + + // Send a ping from server_b to server_a. + // server_a should notice the enr_seq is outdated + // and trigger a enr-request to server_b to update the record. + server_b + .ping(server_a.ctx.local_node, server_b.ctx.table.lock().await) + .await?; + + // Wait for the update to propagate. + sleep(Duration::from_millis(2500)).await; + + // Verify that server_a has updated its record of server_b with the correct TCP port. 
+ let table_lock = server_a.ctx.table.lock().await; + let server_a_node_b_record = table_lock + .get_by_node_id(server_b.ctx.local_node.node_id) + .unwrap(); + + assert!(server_a_node_b_record.node.tcp_port == server_b.ctx.local_node.tcp_port); + + Ok(()) + } +} diff --git a/crates/networking/p2p/kademlia.rs b/crates/networking/p2p/kademlia.rs index 72607cacfc..a9a8e4638e 100644 --- a/crates/networking/p2p/kademlia.rs +++ b/crates/networking/p2p/kademlia.rs @@ -1,5 +1,5 @@ use crate::{ - discv4::{time_now_unix, FindNodeRequest}, + discv4::messages::FindNodeRequest, peer_channels::PeerChannels, rlpx::p2p::Capability, types::{Node, NodeRecord}, @@ -94,7 +94,7 @@ impl KademliaTable { return (None, false); } - let peer = PeerData::new(node, NodeRecord::default(), time_now_unix(), 0, false); + let peer = PeerData::new(node, NodeRecord::default(), false); if self.buckets[bucket_idx].peers.len() == MAX_NODES_PER_BUCKET { self.insert_as_replacement(&peer, bucket_idx); @@ -148,7 +148,7 @@ impl KademliaTable { nodes.iter().map(|a| a.0).collect() } - pub fn pong_answered(&mut self, node_id: H512) { + pub fn pong_answered(&mut self, node_id: H512, pong_at: u64) { let peer = self.get_by_node_id_mut(node_id); if peer.is_none() { return; @@ -156,12 +156,12 @@ impl KademliaTable { let peer = peer.unwrap(); peer.is_proven = true; - peer.last_pong = time_now_unix(); + peer.last_pong = pong_at; peer.last_ping_hash = None; peer.revalidation = peer.revalidation.and(Some(true)); } - pub fn update_peer_ping(&mut self, node_id: H512, ping_hash: Option) { + pub fn update_peer_ping(&mut self, node_id: H512, ping_hash: Option, ping_at: u64) { let peer = self.get_by_node_id_mut(node_id); if peer.is_none() { return; @@ -169,26 +169,7 @@ impl KademliaTable { let peer = peer.unwrap(); peer.last_ping_hash = ping_hash; - peer.last_ping = time_now_unix(); - } - - pub fn update_peer_enr_seq(&mut self, node_id: H512, enr_seq: u64, enr_req_hash: Option) { - let peer = self.get_by_node_id_mut(node_id); - let Some(peer) = peer else { - return; - }; - peer.record.seq = enr_seq; - peer.enr_request_hash = enr_req_hash; - } - - pub fn update_peer_ping_with_revalidation(&mut self, node_id: H512, ping_hash: Option) { - let Some(peer) = self.get_by_node_id_mut(node_id) else { - return; - }; - - peer.last_ping_hash = ping_hash; - peer.last_ping = time_now_unix(); - peer.revalidation = Some(false); + peer.last_ping = ping_at; } /// ## Returns @@ -378,18 +359,12 @@ pub struct PeerData { } impl PeerData { - pub fn new( - node: Node, - record: NodeRecord, - last_ping: u64, - last_pong: u64, - is_proven: bool, - ) -> Self { + pub fn new(node: Node, record: NodeRecord, is_proven: bool) -> Self { Self { node, record, - last_ping, - last_pong, + last_ping: 0, + last_pong: 0, is_proven, liveness: 1, last_ping_hash: None, diff --git a/crates/networking/p2p/net.rs b/crates/networking/p2p/net.rs index ceeb0ddb8d..b7d145de8a 100644 --- a/crates/networking/p2p/net.rs +++ b/crates/networking/p2p/net.rs @@ -1,39 +1,27 @@ -use std::{ - collections::HashSet, - io, - net::SocketAddr, - net::{IpAddr, Ipv4Addr}, - sync::Arc, - time::{Duration, SystemTime, UNIX_EPOCH}, -}; - use bootnode::BootNode; use discv4::{ - get_expiration, is_expired, time_now_unix, time_since_in_hs, ENRRequestMessage, - ENRResponseMessage, FindNodeMessage, Message, NeighborsMessage, Packet, PingMessage, - PongMessage, + helpers::current_unix_time, + server::{DiscoveryError, Discv4Server}, }; -use ethrex_core::{H256, H512}; +use ethrex_core::H512; use ethrex_storage::Store; use 
k256::{ - ecdsa::{signature::hazmat::PrehashVerifier, Signature, SigningKey, VerifyingKey}, + ecdsa::SigningKey, elliptic_curve::{sec1::ToEncodedPoint, PublicKey}, }; pub use kademlia::KademliaTable; -use kademlia::{bucket_number, MAX_NODES_PER_BUCKET}; -use rand::rngs::OsRng; -use rlpx::{connection::RLPxConnection, message::Message as RLPxMessage}; +use rlpx::{ + connection::{RLPxConnBroadcastSender, RLPxConnection}, + message::Message as RLPxMessage, +}; +use std::{io, net::SocketAddr, sync::Arc}; use tokio::{ - net::{TcpListener, TcpSocket, TcpStream, UdpSocket}, - sync::{ - broadcast::{self, Sender}, - Mutex, - }, - task::Id, + net::{TcpListener, TcpSocket, TcpStream}, + sync::Mutex, }; use tokio_util::task::TaskTracker; -use tracing::{debug, error, info}; -use types::{Endpoint, Node, NodeRecord}; +use tracing::{error, info}; +use types::Node; pub mod bootnode; pub(crate) mod discv4; @@ -44,8 +32,6 @@ pub(crate) mod snap; pub mod sync; pub mod types; -const MAX_DISC_PACKET_SIZE: usize = 1280; - // Totally arbitrary limit on how // many messages the connections can queue, // if we miss messages to broadcast, maybe @@ -57,866 +43,71 @@ pub fn peer_table(signer: SigningKey) -> Arc> { Arc::new(Mutex::new(KademliaTable::new(local_node_id))) } -#[derive(Clone)] +#[derive(Debug)] +pub enum NetworkError { + DiscoveryStart(DiscoveryError), +} + +#[derive(Clone, Debug)] struct P2PContext { tracker: TaskTracker, signer: SigningKey, table: Arc>, storage: Store, - broadcast: Sender<(Id, Arc)>, + broadcast: RLPxConnBroadcastSender, local_node: Node, + enr_seq: u64, } -#[allow(clippy::too_many_arguments)] pub async fn start_network( local_node: Node, tracker: TaskTracker, - udp_addr: SocketAddr, - tcp_addr: SocketAddr, bootnodes: Vec, signer: SigningKey, peer_table: Arc>, storage: Store, -) { - info!("Starting discovery service at {udp_addr}"); - info!("Listening for requests at {tcp_addr}"); +) -> Result<(), NetworkError> { let (channel_broadcast_send_end, _) = tokio::sync::broadcast::channel::<( tokio::task::Id, Arc, )>(MAX_MESSAGES_TO_BROADCAST); let context = P2PContext { + local_node, + // Note we are passing the current timestamp as the sequence number + // This is because we are not storing our local_node updates in the db + // see #1756 + enr_seq: current_unix_time(), tracker, signer, table: peer_table, storage, broadcast: channel_broadcast_send_end, - local_node, - }; - - context - .tracker - .spawn(discover_peers(context.clone(), bootnodes)); - context.tracker.spawn(serve_p2p_requests(context.clone())); -} - -#[allow(clippy::too_many_arguments)] -async fn discover_peers(context: P2PContext, bootnodes: Vec) { - let udp_addr = context.local_node.udp_addr(); - let udp_socket = match UdpSocket::bind(udp_addr).await { - Ok(socket) => Arc::new(socket), - Err(e) => { - error!("Error binding udp socket {udp_addr}: {e}. 
Stopping discover peers task"); - return; - } - }; - - context - .tracker - .spawn(discover_peers_server(context.clone(), udp_socket.clone())); - - context.tracker.spawn(peers_revalidation( - context.clone(), - udp_socket.clone(), - REVALIDATION_INTERVAL_IN_SECONDS as u64, - )); - - discovery_startup(context.clone(), udp_socket.clone(), bootnodes).await; - - // a first initial lookup runs without waiting for the interval - // so we need to allow some time to the pinged peers to ping us back and acknowledge us - tokio::time::sleep(Duration::from_secs(10)).await; - context.tracker.spawn(peers_lookup( - context.clone(), - udp_socket.clone(), - PEERS_RANDOM_LOOKUP_TIME_IN_MIN as u64 * 60, - )); -} - -#[allow(clippy::too_many_arguments)] -async fn discover_peers_server(context: P2PContext, udp_socket: Arc) { - let mut buf = vec![0; MAX_DISC_PACKET_SIZE]; - let udp_addr = context.local_node.udp_addr(); - - loop { - let (read, from) = match udp_socket.recv_from(&mut buf).await { - Ok(result) => result, - Err(e) => { - error!( - "Error receiving data from socket {udp_addr}: {e}. Stopping discovery server" - ); - return; - } - }; - debug!("Received {read} bytes from {from}"); - - match Packet::decode(&buf[..read]) { - Err(e) => error!("Could not decode packet: {:?}", e), - Ok(packet) => { - let msg = packet.get_message(); - debug!("Message: {:?} from {}", msg, packet.get_node_id()); - - match msg { - Message::Ping(msg) => { - if is_expired(msg.expiration) { - debug!("Ignoring ping as it is expired."); - continue; - }; - let node = Node { - ip: from.ip(), - udp_port: from.port(), - tcp_port: msg.from.tcp_port, - node_id: packet.get_node_id(), - }; - let ping_hash = packet.get_hash(); - pong(&udp_socket, node, ping_hash, &context.signer).await; - let peer = { - let table = context.table.lock().await; - table.get_by_node_id(packet.get_node_id()).cloned() - }; - if let Some(peer) = peer { - // send a a ping to get an endpoint proof - if time_since_in_hs(peer.last_ping) >= PROOF_EXPIRATION_IN_HS as u64 { - let hash = ping( - &udp_socket, - context.local_node, - peer.node, - &context.signer, - ) - .await; - if let Some(hash) = hash { - context - .table - .lock() - .await - .update_peer_ping(peer.node.node_id, Some(hash)); - } - } - - // if it has updated its record, send a request to update it - if let Some(enr_seq) = msg.enr_seq { - if enr_seq > peer.record.seq { - debug!("enr-seq outdated, send an enr_request"); - let req_hash = - send_enr_request(&udp_socket, from, &context.signer).await; - context.table.lock().await.update_peer_enr_seq( - peer.node.node_id, - enr_seq, - req_hash, - ); - } - } - } else { - let mut table = context.table.lock().await; - if let (Some(peer), true) = table.insert_node(node) { - // send a ping to get the endpoint proof from our end - let hash = - ping(&udp_socket, context.local_node, node, &context.signer) - .await; - table.update_peer_ping(peer.node.node_id, hash); - } - } - } - Message::Pong(msg) => { - let table = context.table.clone(); - if is_expired(msg.expiration) { - debug!("Ignoring pong as it is expired."); - continue; - } - let peer = { - let table = table.lock().await; - table.get_by_node_id(packet.get_node_id()).cloned() - }; - if let Some(peer) = peer { - if peer.last_ping_hash.is_none() { - debug!("Discarding pong as the node did not send a previous ping"); - continue; - } - - if peer - .last_ping_hash - .is_some_and(|hash| hash == msg.ping_hash) - { - table.lock().await.pong_answered(peer.node.node_id); - // if it has updated its record, send a request to 
update it - if let Some(enr_seq) = msg.enr_seq { - if enr_seq > peer.record.seq { - debug!("enr-seq outdated, send an enr_request"); - let req_hash = - send_enr_request(&udp_socket, from, &context.signer) - .await; - table.lock().await.update_peer_enr_seq( - peer.node.node_id, - enr_seq, - req_hash, - ); - } - } - - // We won't initiate a connection if we are already connected. - // This will typically be the case when revalidating a node. - if peer.is_connected { - continue; - } - - let mut msg_buf = vec![0; read - 32]; - buf[32..read].clone_into(&mut msg_buf); - let signer = context.signer.clone(); - let storage = context.storage.clone(); - let broadcaster = context.broadcast.clone(); - - context.tracker.spawn(async move { - handle_peer_as_initiator( - signer, - &msg_buf, - &peer.node, - storage, - table, - broadcaster, - ) - .await - }); - } else { - debug!( - "Discarding pong as the hash did not match the last corresponding ping" - ); - } - } else { - debug!("Discarding pong as it is not a known node"); - } - } - Message::FindNode(msg) => { - if is_expired(msg.expiration) { - debug!("Ignoring find node msg as it is expired."); - continue; - }; - let node = { - let table = context.table.lock().await; - table.get_by_node_id(packet.get_node_id()).cloned() - }; - if let Some(node) = node { - if node.is_proven { - let nodes = { - let table = context.table.lock().await; - table.get_closest_nodes(msg.target) - }; - let nodes_chunks = nodes.chunks(4); - let expiration = get_expiration(20); - debug!("Sending neighbors!"); - // we are sending the neighbors in 4 different messages as not to exceed the - // maximum packet size - for nodes in nodes_chunks { - let neighbors = discv4::Message::Neighbors( - NeighborsMessage::new(nodes.to_vec(), expiration), - ); - let mut buf = Vec::new(); - neighbors.encode_with_header(&mut buf, &context.signer); - if let Err(e) = udp_socket.send_to(&buf, from).await { - error!("Could not send Neighbors message {e}"); - } - } - } else { - debug!("Ignoring find node message as the node isn't proven!"); - } - } else { - debug!("Ignoring find node message as it is not a known node"); - } - } - Message::Neighbors(neighbors_msg) => { - if is_expired(neighbors_msg.expiration) { - debug!("Ignoring neighbor msg as it is expired."); - continue; - }; - - let mut nodes_to_insert = None; - let mut table = context.table.lock().await; - if let Some(node) = table.get_by_node_id_mut(packet.get_node_id()) { - if let Some(req) = &mut node.find_node_request { - if time_now_unix().saturating_sub(req.sent_at) >= 60 { - debug!("Ignoring neighbors message as the find_node request expires after one minute"); - node.find_node_request = None; - continue; - } - let nodes = &neighbors_msg.nodes; - let nodes_sent = req.nodes_sent + nodes.len(); - - if nodes_sent <= MAX_NODES_PER_BUCKET { - debug!("Storing neighbors in our table!"); - req.nodes_sent = nodes_sent; - nodes_to_insert = Some(nodes.clone()); - if let Some(tx) = &req.tx { - let _ = tx.send(nodes.clone()); - } - } else { - debug!("Ignoring neighbors message as the client sent more than the allowed nodes"); - } - - if nodes_sent == MAX_NODES_PER_BUCKET { - debug!("Neighbors request has been fulfilled"); - node.find_node_request = None; - } - } - } else { - debug!("Ignoring neighbor msg as it is not a known node"); - } - - if let Some(nodes) = nodes_to_insert { - for node in nodes { - if let (Some(peer), true) = table.insert_node(node) { - let ping_hash = ping( - &udp_socket, - context.local_node, - peer.node, - &context.signer, - ) - 
.await; - table.update_peer_ping(peer.node.node_id, ping_hash); - } - } - } - } - Message::ENRRequest(msg) => { - if is_expired(msg.expiration) { - debug!("Ignoring enr-request msg as it is expired."); - continue; - } - // Note we are passing the current timestamp as the sequence number - // This is because we are not storing our local_node updates in the db - let Ok(node_record) = NodeRecord::from_node( - context.local_node, - time_now_unix(), - &context.signer, - ) else { - debug!("Ignoring enr-request msg could not build local node record."); - continue; - }; - let msg = discv4::Message::ENRResponse(ENRResponseMessage::new( - packet.get_hash(), - node_record, - )); - let mut buf = vec![]; - msg.encode_with_header(&mut buf, &context.signer); - let _ = udp_socket.send_to(&buf, from).await; - } - Message::ENRResponse(msg) => { - let mut table = context.table.lock().await; - let peer = table.get_by_node_id_mut(packet.get_node_id()); - let Some(peer) = peer else { - debug!("Discarding enr-response as we don't know the peer"); - continue; - }; - - let Some(req_hash) = peer.enr_request_hash else { - debug!("Discarding enr-response as it wasn't requested"); - continue; - }; - if req_hash != msg.request_hash { - debug!("Discarding enr-response as the request hash did not match"); - continue; - } - peer.enr_request_hash = None; - - if msg.node_record.seq < peer.record.seq { - debug!( - "Discarding enr-response as the record seq is lower than the one we have" - ); - continue; - } - - let record = msg.node_record.decode_pairs(); - let Some(id) = record.id else { - debug!( - "Discarding enr-response as record does not have the `id` field" - ); - continue; - }; - - // https://github.com/ethereum/devp2p/blob/master/enr.md#v4-identity-scheme - let signature_valid = match id.as_str() { - "v4" => { - let digest = msg.node_record.get_signature_digest(); - let Some(public_key) = record.secp256k1 else { - debug!("Discarding enr-response as signature could not be verified because public key was not provided"); - continue; - }; - let signature_bytes = msg.node_record.signature.as_bytes(); - let Ok(signature) = Signature::from_slice(&signature_bytes[0..64]) - else { - debug!("Discarding enr-response as signature could not be build from msg signature bytes"); - continue; - }; - let Ok(verifying_key) = - VerifyingKey::from_sec1_bytes(public_key.as_bytes()) - else { - debug!("Discarding enr-response as public key could no be built from msg pub key bytes"); - continue; - }; - verifying_key.verify_prehash(&digest, &signature).is_ok() - } - _ => false, - }; - if !signature_valid { - debug!( - "Discarding enr-response as the signature verification was invalid" - ); - continue; - } - - if let Some(ip) = record.ip { - peer.node.ip = IpAddr::from(Ipv4Addr::from_bits(ip)); - } - if let Some(tcp_port) = record.tcp_port { - peer.node.tcp_port = tcp_port; - } - if let Some(udp_port) = record.udp_port { - peer.node.udp_port = udp_port; - } - peer.record = msg.node_record.clone(); - debug!( - "Node with id {:?} record has been successfully updated", - peer.node.node_id - ); - } - } - } - } - } -} - -// this is just an arbitrary number, maybe we should get this from some kind of cfg -/// This is a really basic startup and should be improved when we have the nodes stored in the db -/// currently, since we are not storing nodes, the only way to have startup nodes is by providing -/// an array of bootnodes. 
-async fn discovery_startup( - context: P2PContext, - udp_socket: Arc, - bootnodes: Vec, -) { - for bootnode in bootnodes { - let node = Node { - ip: bootnode.socket_address.ip(), - udp_port: bootnode.socket_address.port(), - // TODO: udp port can differ from tcp port. - // see https://github.com/lambdaclass/ethrex/issues/905 - tcp_port: bootnode.socket_address.port(), - node_id: bootnode.node_id, - }; - context.table.lock().await.insert_node(node); - let ping_hash = ping(&udp_socket, context.local_node, node, &context.signer).await; - context - .table - .lock() - .await - .update_peer_ping(bootnode.node_id, ping_hash); - } -} - -const REVALIDATION_INTERVAL_IN_SECONDS: usize = 30; // this is just an arbitrary number, maybe we should get this from some kind of cfg -const PROOF_EXPIRATION_IN_HS: usize = 12; - -/// Starts a tokio scheduler that: -/// - performs periodic revalidation of the current nodes (sends a ping to the old nodes). Currently this is configured to happen every [`REVALIDATION_INTERVAL_IN_MINUTES`] -/// -/// **Peer revalidation** -/// -/// Peers revalidation works in the following manner: -/// 1. Every `REVALIDATION_INTERVAL_IN_SECONDS` we ping the 3 least recently pinged peers -/// 2. In the next iteration we check if they have answered -/// - if they have: we increment the liveness field by one -/// - otherwise we decrement it by the current value / 3. -/// 3. If the liveness field is 0, then we delete it and insert a new one from the replacements table -/// -/// See more https://github.com/ethereum/devp2p/blob/master/discv4.md#kademlia-table -async fn peers_revalidation( - context: P2PContext, - udp_socket: Arc, - interval_time_in_seconds: u64, -) { - let mut interval = tokio::time::interval(Duration::from_secs(interval_time_in_seconds)); - // peers we have pinged in the previous iteration - let mut previously_pinged_peers: HashSet = HashSet::default(); - - // first tick starts immediately - interval.tick().await; - - loop { - interval.tick().await; - debug!("Running peer revalidation"); - - // first check that the peers we ping have responded - for node_id in previously_pinged_peers { - let mut table = context.table.lock().await; - if let Some(peer) = table.get_by_node_id_mut(node_id) { - if let Some(has_answered) = peer.revalidation { - if has_answered { - peer.increment_liveness(); - } else { - peer.decrement_liveness(); - } - } - - peer.revalidation = None; - - if peer.liveness == 0 { - let new_peer = table.replace_peer(node_id); - if let Some(new_peer) = new_peer { - let ping_hash = ping( - &udp_socket, - context.local_node, - new_peer.node, - &context.signer, - ) - .await; - table.update_peer_ping(new_peer.node.node_id, ping_hash); - } - } - } - } - - // now send a ping to the least recently pinged peers - // this might be too expensive to run if our table is filled - // maybe we could just pick them randomly - let peers = context - .table - .lock() - .await - .get_least_recently_pinged_peers(3); - previously_pinged_peers = HashSet::default(); - for peer in peers { - let ping_hash = ping(&udp_socket, context.local_node, peer.node, &context.signer).await; - let mut table = context.table.lock().await; - table.update_peer_ping_with_revalidation(peer.node.node_id, ping_hash); - previously_pinged_peers.insert(peer.node.node_id); - - debug!("Pinging peer {:?} to re-validate!", peer.node.node_id); - } - - debug!("Peer revalidation finished"); - } -} - -const PEERS_RANDOM_LOOKUP_TIME_IN_MIN: usize = 30; - -/// Starts a tokio scheduler that: -/// - performs random lookups to 
discover new nodes. Currently this is configure to run every `PEERS_RANDOM_LOOKUP_TIME_IN_MIN` -/// -/// **Random lookups** -/// -/// Random lookups work in the following manner: -/// 1. Every 30min we spawn three concurrent lookups: one closest to our pubkey -/// and three other closest to random generated pubkeys. -/// 2. Every lookup starts with the closest nodes from our table. -/// Each lookup keeps track of: -/// - Peers that have already been asked for nodes -/// - Peers that have been already seen -/// - Potential peers to query for nodes: a vector of up to 16 entries holding the closest peers to the pubkey. -/// This vector is initially filled with nodes from our table. -/// 3. We send a `find_node` to the closest 3 nodes (that we have not yet asked) from the pubkey. -/// 4. We wait for the neighbors response and pushed or replace those that are closer to the potential peers. -/// 5. We select three other nodes from the potential peers vector and do the same until one lookup -/// doesn't have any node to ask. -/// -/// See more https://github.com/ethereum/devp2p/blob/master/discv4.md#recursive-lookup -async fn peers_lookup( - context: P2PContext, - udp_socket: Arc, - interval_time_in_seconds: u64, -) { - let mut interval = tokio::time::interval(Duration::from_secs(interval_time_in_seconds)); - - loop { - // Notice that the first tick is immediate, - // so as soon as the server starts we'll do a lookup with the seeder nodes. - interval.tick().await; - - debug!("Starting lookup"); - - // lookup closest to our pub key - context.tracker.spawn(recursive_lookup( - context.clone(), - udp_socket.clone(), - context.local_node.node_id, - )); - - // lookup closest to 3 random keys - for _ in 0..3 { - let random_pub_key = &SigningKey::random(&mut OsRng); - context.tracker.spawn(recursive_lookup( - context.clone(), - udp_socket.clone(), - node_id_from_signing_key(random_pub_key), - )); - } - - debug!("Lookup finished"); - } -} - -async fn recursive_lookup(context: P2PContext, udp_socket: Arc, target: H512) { - let mut asked_peers = HashSet::default(); - // lookups start with the closest from our table - let closest_nodes = context.table.lock().await.get_closest_nodes(target); - let mut seen_peers: HashSet = HashSet::default(); - - seen_peers.insert(context.local_node.node_id); - for node in &closest_nodes { - seen_peers.insert(node.node_id); - } - - let mut peers_to_ask: Vec = closest_nodes; - - loop { - let (nodes_found, queries) = lookup( - udp_socket.clone(), - context.table.clone(), - &context.signer, - target, - &mut asked_peers, - &peers_to_ask, - ) - .await; - - // only push the peers that have not been seen - // that is those who have not been yet pushed, which also accounts for - // those peers that were in the array but have been replaced for closer peers - for node in nodes_found { - if !seen_peers.contains(&node.node_id) { - seen_peers.insert(node.node_id); - peers_to_ask_push(&mut peers_to_ask, target, node); - } - } - - // the lookup finishes when there are no more queries to do - // that happens when we have asked all the peers - if queries == 0 { - break; - } - } -} - -async fn lookup( - udp_socket: Arc, - table: Arc>, - signer: &SigningKey, - target: H512, - asked_peers: &mut HashSet, - nodes_to_ask: &Vec, -) -> (Vec, u32) { - let alpha = 3; - let mut queries = 0; - let mut nodes = vec![]; - - for node in nodes_to_ask { - if !asked_peers.contains(&node.node_id) { - let mut locked_table = table.lock().await; - if let Some(peer) = locked_table.get_by_node_id_mut(node.node_id) { 
- // if the peer has an ongoing find_node request, don't query - if peer.find_node_request.is_none() { - let (tx, mut receiver) = tokio::sync::mpsc::unbounded_channel::>(); - peer.new_find_node_request_with_sender(tx); - - // Release the lock - drop(locked_table); - - queries += 1; - asked_peers.insert(node.node_id); - let mut found_nodes = find_node_and_wait_for_response( - &udp_socket, - SocketAddr::new(node.ip, node.udp_port), - signer, - target, - &mut receiver, - ) - .await; - nodes.append(&mut found_nodes) - } - } - } - - if queries == alpha { - break; - } - } - - (nodes, queries) -} - -fn peers_to_ask_push(peers_to_ask: &mut Vec, target: H512, node: Node) { - let distance = bucket_number(target, node.node_id); - - if peers_to_ask.len() < MAX_NODES_PER_BUCKET { - peers_to_ask.push(node); - return; - } - - // replace this node for the one whose distance to the target is the highest - let (mut idx_to_replace, mut highest_distance) = (None, 0); - - for (i, peer) in peers_to_ask.iter().enumerate() { - let current_distance = bucket_number(peer.node_id, target); - - if distance < current_distance && current_distance >= highest_distance { - highest_distance = current_distance; - idx_to_replace = Some(i); - } - } - - if let Some(idx) = idx_to_replace { - peers_to_ask[idx] = node; - } -} - -/// Sends a ping to the addr -/// # Returns -/// an optional hash corresponding to the message header hash to account if the send was successful -async fn ping( - socket: &UdpSocket, - local_node: Node, - node: Node, - signer: &SigningKey, -) -> Option { - let mut buf = Vec::new(); - - let expiration: u64 = (SystemTime::now() + Duration::from_secs(20)) - .duration_since(UNIX_EPOCH) - .unwrap_or_default() - .as_secs(); - - let from = Endpoint { - ip: local_node.ip, - udp_port: local_node.udp_port, - tcp_port: local_node.tcp_port, - }; - let to = Endpoint { - ip: node.ip, - udp_port: node.udp_port, - tcp_port: node.tcp_port, }; - - let ping = - discv4::Message::Ping(PingMessage::new(from, to, expiration).with_enr_seq(time_now_unix())); - ping.encode_with_header(&mut buf, signer); - - // Send ping and log if error - match socket - .send_to(&buf, SocketAddr::new(to.ip, to.udp_port)) + let discovery = Discv4Server::try_new(context.clone()) .await - { - Ok(bytes_sent) => { - // sanity check to make sure the ping was well sent - // though idk if this is actually needed or if it might break other stuff - if bytes_sent == buf.len() { - return Some(H256::from_slice(&buf[0..32])); - } - } - Err(e) => error!("Unable to send ping: {e}"), - } - - None -} - -async fn find_node_and_wait_for_response( - socket: &UdpSocket, - to_addr: SocketAddr, - signer: &SigningKey, - target_node_id: H512, - request_receiver: &mut tokio::sync::mpsc::UnboundedReceiver>, -) -> Vec { - let expiration: u64 = (SystemTime::now() + Duration::from_secs(20)) - .duration_since(UNIX_EPOCH) - .unwrap_or_default() - .as_secs(); - - let msg: discv4::Message = - discv4::Message::FindNode(FindNodeMessage::new(target_node_id, expiration)); + .map_err(NetworkError::DiscoveryStart)?; - let mut buf = Vec::new(); - msg.encode_with_header(&mut buf, signer); - let mut nodes = vec![]; - - if socket.send_to(&buf, to_addr).await.is_err() { - return nodes; - } - - loop { - // wait as much as 5 seconds for the response - match tokio::time::timeout(Duration::from_secs(5), request_receiver.recv()).await { - Ok(Some(mut found_nodes)) => { - nodes.append(&mut found_nodes); - if nodes.len() == MAX_NODES_PER_BUCKET { - return nodes; - }; - } - Ok(None) => { - return 
nodes; - } - Err(_) => { - // timeout expired - return nodes; - } - } - } -} - -async fn pong(socket: &UdpSocket, node: Node, ping_hash: H256, signer: &SigningKey) { - let mut buf = Vec::new(); - - let expiration: u64 = (SystemTime::now() + Duration::from_secs(20)) - .duration_since(UNIX_EPOCH) - .unwrap_or_default() - .as_secs(); - - let to = Endpoint { - ip: node.ip, - udp_port: node.udp_port, - tcp_port: node.tcp_port, - }; - let pong: discv4::Message = discv4::Message::Pong( - PongMessage::new(to, ping_hash, expiration).with_enr_seq(time_now_unix()), + info!( + "Starting discovery service at {}", + context.local_node.udp_addr() ); - - pong.encode_with_header(&mut buf, signer); - - // Send pong and log if error - if let Err(e) = socket - .send_to(&buf, SocketAddr::new(node.ip, node.udp_port)) + discovery + .start(bootnodes) .await - { - error!("Unable to send pong: {e}") - } -} - -async fn send_enr_request( - socket: &UdpSocket, - to_addr: SocketAddr, - signer: &SigningKey, -) -> Option { - let mut buf = Vec::new(); + .map_err(NetworkError::DiscoveryStart)?; - let expiration: u64 = (SystemTime::now() + Duration::from_secs(20)) - .duration_since(UNIX_EPOCH) - .unwrap_or_default() - .as_secs(); - - let enr_req = discv4::Message::ENRRequest(ENRRequestMessage::new(expiration)); - - enr_req.encode_with_header(&mut buf, signer); - - let bytes_sent = socket.send_to(&buf, to_addr).await.ok()?; - if bytes_sent != buf.len() { - debug!( - "ENR request message partially sent: {} out of {} bytes.", - bytes_sent, - buf.len() - ); - return None; - } + info!( + "Listening for requests at {}", + context.local_node.tcp_addr() + ); + context.tracker.spawn(serve_p2p_requests(context.clone())); - Some(H256::from_slice(&buf[0..32])) + Ok(()) } async fn serve_p2p_requests(context: P2PContext) { - let tcp_addr = SocketAddr::new(context.local_node.ip, context.local_node.tcp_port); + let tcp_addr = context.local_node.tcp_addr(); let listener = match listener(tcp_addr) { Ok(result) => result, Err(e) => { @@ -951,14 +142,7 @@ async fn handle_peer_as_receiver(context: P2PContext, peer_addr: SocketAddr, str conn.start_peer(peer_addr, context.table).await; } -async fn handle_peer_as_initiator( - signer: SigningKey, - msg: &[u8], - node: &Node, - storage: Store, - table: Arc>, - connection_broadcast: broadcast::Sender<(tokio::task::Id, Arc)>, -) { +async fn handle_peer_as_initiator(context: P2PContext, node: Node) { let addr = SocketAddr::new(node.ip, node.tcp_port); let stream = match tcp_stream(addr).await { Ok(result) => result, @@ -970,12 +154,14 @@ async fn handle_peer_as_initiator( return; } }; - - match RLPxConnection::initiator(signer, msg, stream, storage, connection_broadcast) { - Ok(mut conn) => { - conn.start_peer(SocketAddr::new(node.ip, node.udp_port), table) - .await - } + match RLPxConnection::initiator( + context.signer, + node.node_id, + stream, + context.storage, + context.broadcast, + ) { + Ok(mut conn) => conn.start_peer(node.udp_addr(), context.table).await, Err(e) => { // TODO We should remove the peer from the table if connection failed // but currently it will make the tests fail @@ -1004,412 +190,3 @@ pub async fn periodically_show_peer_stats(peer_table: Arc>) interval.tick().await; } } - -#[cfg(test)] -mod tests { - use super::*; - use ethrex_storage::EngineType; - use kademlia::bucket_number; - use rand::rngs::OsRng; - use std::{ - collections::HashSet, - net::{IpAddr, Ipv4Addr}, - }; - use tokio::time::sleep; - - async fn insert_random_node_on_custom_bucket( - table: Arc>, - bucket_idx: 
usize, - ) { - let node_id = node_id_from_signing_key(&SigningKey::random(&mut OsRng)); - let node = Node { - ip: IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), - tcp_port: 0, - udp_port: 0, - node_id, - }; - table - .lock() - .await - .insert_node_on_custom_bucket(node, bucket_idx); - } - - async fn fill_table_with_random_nodes(table: Arc>) { - for i in 0..256 { - for _ in 0..16 { - insert_random_node_on_custom_bucket(table.clone(), i).await; - } - } - } - - struct MockServer { - pub local_node: Node, - pub addr: SocketAddr, - pub signer: SigningKey, - pub table: Arc>, - pub node_id: H512, - pub udp_socket: Arc, - } - - async fn start_mock_discovery_server( - udp_port: u16, - should_start_server: bool, - ) -> Result { - let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), udp_port); - let signer = SigningKey::random(&mut OsRng); - let udp_socket = Arc::new(UdpSocket::bind(addr).await?); - let node_id = node_id_from_signing_key(&signer); - let storage = - Store::new("temp.db", EngineType::InMemory).expect("Failed to create test DB"); - let table = Arc::new(Mutex::new(KademliaTable::new(node_id))); - let (channel_broadcast_send_end, _) = tokio::sync::broadcast::channel::<( - tokio::task::Id, - Arc, - )>(MAX_MESSAGES_TO_BROADCAST); - let local_node = Node { - ip: addr.ip(), - tcp_port: addr.port(), - udp_port: addr.port(), - node_id, - }; - let tracker = TaskTracker::new(); - - let context = P2PContext { - tracker, - signer, - table, - storage, - broadcast: channel_broadcast_send_end, - local_node, - }; - if should_start_server { - context - .tracker - .spawn(discover_peers_server(context.clone(), udp_socket.clone())); - } - - Ok(MockServer { - local_node, - addr, - signer: context.signer, - table: context.table, - node_id, - udp_socket, - }) - } - - /// connects two mock servers by pinging a to b - async fn connect_servers(server_a: &mut MockServer, server_b: &mut MockServer) { - let ping_hash = ping( - &server_a.udp_socket, - server_a.local_node, - server_b.local_node, - &server_a.signer, - ) - .await; - { - let mut table = server_a.table.lock().await; - table.insert_node(Node { - ip: server_b.local_node.ip, - udp_port: server_b.local_node.udp_port, - tcp_port: server_b.local_node.tcp_port, - node_id: server_b.node_id, - }); - table.update_peer_ping(server_b.node_id, ping_hash); - } - // allow some time for the server to respond - sleep(Duration::from_secs(1)).await; - } - - #[tokio::test] - /** This is a end to end test on the discovery server, the idea is as follows: - * - We'll start two discovery servers (`a` & `b`) to ping between each other - * - We'll make `b` ping `a`, and validate that the connection is right - * - Then we'll wait for a revalidation where we expect everything to be the same - * - We'll do this five 5 more times - * - Then we'll stop server `a` so that it doesn't respond to re-validations - * - We expect server `b` to remove node `a` from its table after 3 re-validations - * To make this run faster, we'll change the revalidation time to be every 2secs - */ - async fn discovery_server_revalidation() -> Result<(), io::Error> { - let mut server_a = start_mock_discovery_server(7998, true).await?; - let mut server_b = start_mock_discovery_server(7999, true).await?; - - connect_servers(&mut server_a, &mut server_b).await; - - let (channel_broadcast_send_end, _) = tokio::sync::broadcast::channel::<( - tokio::task::Id, - Arc, - )>(MAX_MESSAGES_TO_BROADCAST); - let storage = - Store::new("temp.db", EngineType::InMemory).expect("Failed to create test DB"); - - let 
context = P2PContext { - tracker: TaskTracker::new(), - signer: server_b.signer.clone(), - table: server_b.table.clone(), - storage, - broadcast: channel_broadcast_send_end, - local_node: server_b.local_node, - }; - - // start revalidation server - tokio::spawn(peers_revalidation(context, server_b.udp_socket.clone(), 2)); - - for _ in 0..5 { - sleep(Duration::from_millis(2500)).await; - // by now, b should've send a revalidation to a - let table = server_b.table.lock().await; - let node = table.get_by_node_id(server_a.node_id); - assert!(node.is_some_and(|n| n.revalidation.is_some())); - } - - // make sure that `a` has responded too all the re-validations - // we can do that by checking the liveness - { - let table = server_b.table.lock().await; - let node = table.get_by_node_id(server_a.node_id); - assert_eq!(node.map_or(0, |n| n.liveness), 6); - } - - // now, stopping server `a` is not trivial - // so we'll instead change its port, so that no one responds - { - let mut table = server_b.table.lock().await; - let node = table.get_by_node_id_mut(server_a.node_id); - if let Some(node) = node { - node.node.udp_port = 0 - }; - } - - // now the liveness field should start decreasing until it gets to 0 - // which should happen in 3 re-validations - for _ in 0..2 { - sleep(Duration::from_millis(2500)).await; - let table = server_b.table.lock().await; - let node = table.get_by_node_id(server_a.node_id); - assert!(node.is_some_and(|n| n.revalidation.is_some())); - } - sleep(Duration::from_millis(2500)).await; - - // finally, `a`` should not exist anymore - let table = server_b.table.lock().await; - assert!(table.get_by_node_id(server_a.node_id).is_none()); - Ok(()) - } - - #[tokio::test] - /** This test tests the lookup function, the idea is as follows: - * - We'll start two discovery servers (`a` & `b`) that will connect between each other - * - We'll insert random nodes to the server `a`` to fill its table - * - We'll forcedly run `lookup` and validate that a `find_node` request was sent - * by checking that new nodes have been inserted to the table - * - * This test for only one lookup, and not recursively. - */ - async fn discovery_server_lookup() -> Result<(), io::Error> { - let mut server_a = start_mock_discovery_server(8000, true).await?; - let mut server_b = start_mock_discovery_server(8001, true).await?; - - fill_table_with_random_nodes(server_a.table.clone()).await; - - // before making the connection, remove a node from the `b` bucket. 
Otherwise it won't be added - let b_bucket = bucket_number(server_a.node_id, server_b.node_id); - let node_id_to_remove = server_a.table.lock().await.buckets()[b_bucket].peers[0] - .node - .node_id; - server_a - .table - .lock() - .await - .replace_peer_on_custom_bucket(node_id_to_remove, b_bucket); - - connect_servers(&mut server_a, &mut server_b).await; - - // now we are going to run a lookup with us as the target - let closets_peers_to_b_from_a = server_a - .table - .lock() - .await - .get_closest_nodes(server_b.node_id); - let nodes_to_ask = server_b - .table - .lock() - .await - .get_closest_nodes(server_b.node_id); - - lookup( - server_b.udp_socket.clone(), - server_b.table.clone(), - &server_b.signer, - server_b.node_id, - &mut HashSet::default(), - &nodes_to_ask, - ) - .await; - - // find_node sent, allow some time for `a` to respond - sleep(Duration::from_secs(2)).await; - - // now all peers should've been inserted - for peer in closets_peers_to_b_from_a { - let table = server_b.table.lock().await; - assert!(table.get_by_node_id(peer.node_id).is_some()); - } - Ok(()) - } - - #[tokio::test] - /** This test tests the lookup function, the idea is as follows: - * - We'll start four discovery servers (`a`, `b`, `c` & `d`) - * - `a` will be connected to `b`, `b` will be connected to `c` and `c` will be connected to `d`. - * - The server `d` will have its table filled with mock nodes - * - We'll run a recursive lookup on server `a` and we expect to end with `b`, `c`, `d` and its mock nodes - */ - async fn discovery_server_recursive_lookup() -> Result<(), io::Error> { - let mut server_a = start_mock_discovery_server(8002, true).await?; - let mut server_b = start_mock_discovery_server(8003, true).await?; - let mut server_c = start_mock_discovery_server(8004, true).await?; - let mut server_d = start_mock_discovery_server(8005, true).await?; - - connect_servers(&mut server_a, &mut server_b).await; - connect_servers(&mut server_b, &mut server_c).await; - connect_servers(&mut server_c, &mut server_d).await; - - // now we fill the server_d table with 3 random nodes - // the reason we don't put more is because this nodes won't respond (as they don't are not real servers) - // and so we will have to wait for the timeout on each node, which will only slow down the test - for _ in 0..3 { - insert_random_node_on_custom_bucket(server_d.table.clone(), 0).await; - } - - let mut expected_peers = vec![]; - expected_peers.extend( - server_b - .table - .lock() - .await - .get_closest_nodes(server_a.node_id), - ); - expected_peers.extend( - server_c - .table - .lock() - .await - .get_closest_nodes(server_a.node_id), - ); - expected_peers.extend( - server_d - .table - .lock() - .await - .get_closest_nodes(server_a.node_id), - ); - - let (channel_broadcast_send_end, _) = tokio::sync::broadcast::channel::<( - tokio::task::Id, - Arc, - )>(MAX_MESSAGES_TO_BROADCAST); - let storage = - Store::new("temp.db", EngineType::InMemory).expect("Failed to create test DB"); - - let context = P2PContext { - tracker: TaskTracker::new(), - signer: server_a.signer.clone(), - table: server_a.table.clone(), - storage, - broadcast: channel_broadcast_send_end, - local_node: server_a.local_node, - }; - - // we'll run a recursive lookup closest to the server itself - recursive_lookup(context, server_a.udp_socket.clone(), server_a.node_id).await; - - for peer in expected_peers { - assert!(server_a - .table - .lock() - .await - .get_by_node_id(peer.node_id) - .is_some()); - } - Ok(()) - } - - #[tokio::test] - /** - * This test 
verifies the exchange and update of ENR (Ethereum Node Record) messages. - * The test follows these steps: - * - * 1. Start two nodes. - * 2. Wait until they establish a connection. - * 3. Assert that they exchange their records and store them - * 3. Modify the ENR (node record) of one of the nodes. - * 4. Send a new ping message and check that an ENR request was triggered. - * 5. Verify that the updated node record has been correctly received and stored. - */ - async fn discovery_enr_message() -> Result<(), io::Error> { - let mut server_a = start_mock_discovery_server(8006, true).await?; - let mut server_b = start_mock_discovery_server(8007, true).await?; - - connect_servers(&mut server_a, &mut server_b).await; - - // wait some time for the enr request-response finishes - sleep(Duration::from_millis(2500)).await; - - let expected_record = - NodeRecord::from_node(server_b.local_node, time_now_unix(), &server_b.signer) - .expect("Node record is created from node"); - - let server_a_peer_b = server_a - .table - .lock() - .await - .get_by_node_id(server_b.node_id) - .cloned() - .unwrap(); - - // we only match the pairs, as the signature and seq will change - // because they are calculated with the current time - assert!(server_a_peer_b.record.decode_pairs() == expected_record.decode_pairs()); - - // Modify server_a's record of server_b with an incorrect TCP port. - // This simulates an outdated or incorrect entry in the node table. - server_a - .table - .lock() - .await - .get_by_node_id_mut(server_b.node_id) - .unwrap() - .node - .tcp_port = 10; - - // Send a ping from server_b to server_a. - // server_a should notice the enr_seq is outdated - // and trigger a enr-request to server_b to update the record. - ping( - &server_b.udp_socket, - server_b.local_node, - server_a.local_node, - &server_b.signer, - ) - .await; - - // Wait for the update to propagate. - sleep(Duration::from_millis(2500)).await; - - // Verify that server_a has updated its record of server_b with the correct TCP port. 
- let tcp_port = server_a - .table - .lock() - .await - .get_by_node_id(server_b.node_id) - .unwrap() - .node - .tcp_port; - - assert!(tcp_port == server_b.addr.port()); - - Ok(()) - } -} diff --git a/crates/networking/p2p/rlpx/connection.rs b/crates/networking/p2p/rlpx/connection.rs index 7ad881242a..6158d337bd 100644 --- a/crates/networking/p2p/rlpx/connection.rs +++ b/crates/networking/p2p/rlpx/connection.rs @@ -18,7 +18,6 @@ use crate::{ process_account_range_request, process_byte_codes_request, process_storage_ranges_request, process_trie_nodes_request, }, - MAX_DISC_PACKET_SIZE, }; use super::{ @@ -28,16 +27,12 @@ use super::{ handshake::{decode_ack_message, decode_auth_message, encode_auth_message}, message as rlpx, p2p::Capability, - utils::pubkey2id, }; use ethrex_blockchain::mempool::{self}; use ethrex_core::{H256, H512}; use ethrex_storage::Store; use futures::SinkExt; -use k256::{ - ecdsa::{RecoveryId, Signature, SigningKey, VerifyingKey}, - PublicKey, SecretKey, -}; +use k256::{ecdsa::SigningKey, PublicKey, SecretKey}; use rand::random; use sha3::{Digest, Keccak256}; use tokio::{ @@ -60,6 +55,11 @@ const PERIODIC_TASKS_CHECK_INTERVAL: std::time::Duration = std::time::Duration:: pub(crate) type Aes256Ctr64BE = ctr::Ctr64BE; +pub(crate) type RLPxConnBroadcastSender = broadcast::Sender<(tokio::task::Id, Arc)>; + +// https://github.com/ethereum/go-ethereum/blob/master/p2p/peer.go#L44 +pub const P2P_MAX_MESSAGE_SIZE: usize = 2048; + enum RLPxConnectionMode { Initiator, Receiver, @@ -94,7 +94,7 @@ pub(crate) struct RLPxConnection { /// messages from other connections (sent from other peers). /// The receive end is instantiated after the handshake is completed /// under `handle_peer`. - connection_broadcast_send: broadcast::Sender<(task::Id, Arc)>, + connection_broadcast_send: RLPxConnBroadcastSender, } impl RLPxConnection { @@ -104,7 +104,7 @@ impl RLPxConnection { stream: S, mode: RLPxConnectionMode, storage: Store, - connection_broadcast: broadcast::Sender<(task::Id, Arc)>, + connection_broadcast: RLPxConnBroadcastSender, ) -> Self { Self { signer, @@ -138,23 +138,14 @@ impl RLPxConnection { pub fn initiator( signer: SigningKey, - msg: &[u8], + remote_node_id: H512, stream: S, storage: Store, connection_broadcast_send: broadcast::Sender<(task::Id, Arc)>, ) -> Result { - let digest = Keccak256::digest(msg.get(65..).ok_or(RLPxError::InvalidMessageLength())?); - let signature = &Signature::from_bytes( - msg.get(..64) - .ok_or(RLPxError::InvalidMessageLength())? - .into(), - )?; - let rid = RecoveryId::from_byte(*msg.get(64).ok_or(RLPxError::InvalidMessageLength())?) 
- .ok_or(RLPxError::InvalidRecoveryId())?; - let peer_pk = VerifyingKey::recover_from_prehash(&digest, signature, rid)?; Ok(RLPxConnection::new( signer, - pubkey2id(&peer_pk.into()), + remote_node_id, stream, RLPxConnectionMode::Initiator, storage, @@ -606,12 +597,16 @@ impl RLPxConnection { } async fn receive_handshake_msg(&mut self) -> Result, RLPxError> { - let mut buf = vec![0; MAX_DISC_PACKET_SIZE]; + let mut buf = vec![0; 2]; // Read the message's size - self.framed.get_mut().read_exact(&mut buf[..2]).await?; + self.framed.get_mut().read_exact(&mut buf).await?; let ack_data = [buf[0], buf[1]]; let msg_size = u16::from_be_bytes(ack_data) as usize; + if msg_size > P2P_MAX_MESSAGE_SIZE { + return Err(RLPxError::InvalidMessageLength()); + } + buf.resize(msg_size + 2, 0); // Read the rest of the message self.framed diff --git a/crates/networking/p2p/types.rs b/crates/networking/p2p/types.rs index 1a01d0a5ac..24eba79051 100644 --- a/crates/networking/p2p/types.rs +++ b/crates/networking/p2p/types.rs @@ -213,7 +213,7 @@ impl RLPDecode for NodeRecord { let decoder = Decoder::new(rlp)?; let (signature, decoder) = decoder.decode_field("signature")?; let (seq, decoder) = decoder.decode_field("seq")?; - let (pairs, decoder) = decode_node_record_optional_fields(vec![], decoder); + let (pairs, decoder) = decode_node_record_optional_fields(vec![], decoder)?; // all fields in pairs are optional except for id let id_pair = pairs.iter().find(|(k, _v)| k.eq("id".as_bytes())); @@ -240,14 +240,14 @@ impl RLPDecode for NodeRecord { fn decode_node_record_optional_fields( mut pairs: Vec<(Bytes, Bytes)>, decoder: Decoder, -) -> (Vec<(Bytes, Bytes)>, Decoder) { +) -> Result<(Vec<(Bytes, Bytes)>, Decoder), RLPDecodeError> { let (key, decoder): (Option, Decoder) = decoder.decode_optional_field(); if let Some(k) = key { - let (value, decoder): (Vec, Decoder) = decoder.get_encoded_item().unwrap(); + let (value, decoder): (Vec, Decoder) = decoder.get_encoded_item()?; pairs.push((k, Bytes::from(value))); decode_node_record_optional_fields(pairs, decoder) } else { - (pairs, decoder) + Ok((pairs, decoder)) } }
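
Note (not part of the patch): the lookup code removed above keeps a bounded list of candidate peers and, once the list is full, only replaces the entry farthest from the lookup target when a strictly closer node is found (see `peers_to_ask_push` in the removed hunks). The standalone sketch below illustrates that bounded-candidate behaviour under simplified assumptions: `Candidate`, `candidates_push`, and the plain `u32` XOR distance are illustrative stand-ins for the crate's `Node`, `H512`, and `bucket_number`, not its actual API.

const MAX_CANDIDATES: usize = 16; // mirrors MAX_NODES_PER_BUCKET in the patch

#[derive(Debug)]
struct Candidate {
    id: u32,
}

// Simplified XOR metric; the real code derives a bucket index from H512 node ids.
fn distance(a: u32, b: u32) -> u32 {
    a ^ b
}

/// Push `node` into `candidates`, keeping at most MAX_CANDIDATES entries.
/// When the list is full, replace the entry farthest from `target`, but only
/// if the new node is closer to the target than that entry.
fn candidates_push(candidates: &mut Vec<Candidate>, target: u32, node: Candidate) {
    if candidates.len() < MAX_CANDIDATES {
        candidates.push(node);
        return;
    }

    let new_distance = distance(target, node.id);
    let mut idx_to_replace = None;
    let mut highest_distance = 0;

    for (i, peer) in candidates.iter().enumerate() {
        let current_distance = distance(peer.id, target);
        // Candidate for replacement: farther from the target than the new node,
        // and the farthest such entry seen so far.
        if new_distance < current_distance && current_distance >= highest_distance {
            highest_distance = current_distance;
            idx_to_replace = Some(i);
        }
    }

    if let Some(idx) = idx_to_replace {
        candidates[idx] = node;
    }
}

fn main() {
    let target = 0;
    // Seed the list with 16 far-away candidates.
    let mut candidates: Vec<Candidate> = (100..116).map(|id| Candidate { id }).collect();
    // A closer node displaces the farthest current entry; the list stays bounded.
    candidates_push(&mut candidates, target, Candidate { id: 1 });
    assert!(candidates.iter().any(|c| c.id == 1));
    assert_eq!(candidates.len(), MAX_CANDIDATES);
}

Running `main` shows the closer node displacing the farthest of the seeded entries, which mirrors step 4 of the lookup description in the doc comments above. In the real implementation distances are bucket indices rather than raw XOR values, so ties between equally distant peers can resolve differently; this sketch only demonstrates the bounded-replacement idea.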