diff --git a/redis/src/cluster_async/connections_container.rs b/redis/src/cluster_async/connections_container.rs index f19318955..5eae09b3c 100644 --- a/redis/src/cluster_async/connections_container.rs +++ b/redis/src/cluster_async/connections_container.rs @@ -11,7 +11,8 @@ type IdentifierType = ArcStr; #[derive(Clone, Eq, PartialEq, Debug)] pub(crate) struct ClusterNode { - pub connection: Connection, + pub user_connection: Connection, + pub management_connection: Option, pub ip: Option, } @@ -19,10 +20,36 @@ impl ClusterNode where Connection: Clone, { - pub(crate) fn new(connection: Connection, ip: Option) -> Self { - Self { connection, ip } + pub(crate) fn new( + user_connection: Connection, + management_connection: Option, + ip: Option, + ) -> Self { + Self { + user_connection, + management_connection, + ip, + } } + + pub(crate) fn get_connection(&self, conn_type: &ConnectionType) -> Connection { + match conn_type { + ConnectionType::User => self.user_connection.clone(), + ConnectionType::_Management => self + .management_connection + .clone() + .unwrap_or_else(|| self.user_connection.clone()), + } + } +} + +#[derive(Clone, Eq, PartialEq, Debug)] + +pub(crate) enum ConnectionType { + User, + _Management, } + /// This opaque type allows us to change the way that the connections are organized /// internally without refactoring the calling code. #[derive(Clone, Hash, Eq, PartialEq, Debug)] @@ -150,10 +177,12 @@ where pub(crate) fn all_node_connections( &self, ) -> impl Iterator> + '_ { - self.connection_map.iter().filter_map(|(identifier, node)| { - node.as_ref() - .map(|node| (identifier.clone(), node.connection.clone())) - }) + self.connection_map + .iter() + .filter_map(move |(identifier, node)| { + node.as_ref() + .map(|node| (identifier.clone(), node.user_connection.clone())) + }) } pub(crate) fn all_primary_connections( @@ -178,7 +207,7 @@ where pub(crate) fn connection_for_identifier(&self, identifier: &Identifier) -> Option { let node = self.connection_map.get(identifier)?.as_ref()?; - Some(node.connection.clone()) + Some(node.user_connection.clone()) } pub(crate) fn connection_for_address( @@ -204,6 +233,7 @@ where pub(crate) fn random_connections( &self, amount: usize, + conn_type: ConnectionType, ) -> impl Iterator> + '_ { self.connection_map .iter() @@ -214,7 +244,10 @@ where }) .choose_multiple(&mut rand::thread_rng(), amount) .into_iter() - .map(|(identifier, node)| (identifier.clone(), node.connection.clone())) + .map(move |(identifier, node)| { + let conn = node.get_connection(&conn_type); + (identifier.clone(), conn) + }) } pub(crate) fn replace_or_add_connection_for_address( @@ -227,7 +260,7 @@ where identifier } - pub(crate) fn remove_connection( + pub(crate) fn remove_node( &mut self, identifier: &Identifier, ) -> Option> { @@ -237,7 +270,7 @@ where pub(crate) fn len(&self) -> usize { self.connection_map .iter() - .filter(|(_, conn_option)| conn_option.is_some()) + .filter(|(_, node)| node.is_some()) .count() } @@ -253,15 +286,26 @@ mod tests { use crate::cluster_routing::{Slot, SlotAddr}; use super::*; - - fn remove_connections(container: &mut ConnectionsContainer, identifiers: &[&str]) { + impl ClusterNode + where + Connection: Clone, + { + pub(crate) fn new_only_with_user_conn(user_connection: Connection) -> Self { + Self { + user_connection, + management_connection: None, + ip: None, + } + } + } + fn remove_nodes(container: &mut ConnectionsContainer, identifiers: &[&str]) { for identifier in identifiers { - container.remove_connection(&Identifier((*identifier).into())); + container.remove_node(&Identifier((*identifier).into())); } } fn remove_all_connections(container: &mut ConnectionsContainer) { - remove_connections( + remove_nodes( container, &[ "primary1", @@ -281,9 +325,24 @@ mod tests { let found = connection.unwrap().1; expected_connections.contains(&found) } + fn create_cluster_node( + connection: usize, + use_management_connections: bool, + ) -> Option> { + Some(ClusterNode::new( + connection, + if use_management_connections { + Some(connection * 10) + } else { + None + }, + None, + )) + } fn create_container_with_strategy( stragey: ReadFromReplicaStrategy, + use_management_connections: bool, ) -> ConnectionsContainer { let slot_map = SlotMap::new( vec![ @@ -306,27 +365,27 @@ mod tests { let mut connection_map = HashMap::new(); connection_map.insert( Identifier("primary1".into()), - Some(ClusterNode::new(1, None)), + create_cluster_node(1, use_management_connections), ); connection_map.insert( Identifier("primary2".into()), - Some(ClusterNode::new(2, None)), + create_cluster_node(2, use_management_connections), ); connection_map.insert( Identifier("primary3".into()), - Some(ClusterNode::new(3, None)), + create_cluster_node(3, use_management_connections), ); connection_map.insert( Identifier("replica2-1".into()), - Some(ClusterNode::new(21, None)), + create_cluster_node(21, use_management_connections), ); connection_map.insert( Identifier("replica3-1".into()), - Some(ClusterNode::new(31, None)), + create_cluster_node(31, use_management_connections), ); connection_map.insert( Identifier("replica3-2".into()), - Some(ClusterNode::new(32, None)), + create_cluster_node(32, use_management_connections), ); ConnectionsContainer { @@ -338,7 +397,7 @@ mod tests { } fn create_container() -> ConnectionsContainer { - create_container_with_strategy(ReadFromReplicaStrategy::RoundRobin) + create_container_with_strategy(ReadFromReplicaStrategy::RoundRobin, false) } #[test] @@ -452,7 +511,7 @@ mod tests { #[test] fn get_replica_connection_for_replica_route_if_some_but_not_all_replicas_were_removed() { let mut container = create_container(); - container.remove_connection(&Identifier("replica3-2".into())); + container.remove_node(&Identifier("replica3-2".into())); assert_eq!( 31, @@ -466,7 +525,8 @@ mod tests { #[test] fn get_replica_connection_for_replica_route_if_replica_is_required_even_if_strategy_is_always_from_primary( ) { - let container = create_container_with_strategy(ReadFromReplicaStrategy::AlwaysFromPrimary); + let container = + create_container_with_strategy(ReadFromReplicaStrategy::AlwaysFromPrimary, false); assert!(one_of( container.connection_for_route(&Route::new(2001, SlotAddr::ReplicaRequired)), @@ -477,7 +537,7 @@ mod tests { #[test] fn get_primary_connection_for_replica_route_if_all_replicas_were_removed() { let mut container = create_container(); - remove_connections(&mut container, &["replica2-1", "replica3-1", "replica3-2"]); + remove_nodes(&mut container, &["replica2-1", "replica3-1", "replica3-2"]); assert_eq!( 2, @@ -530,7 +590,7 @@ mod tests { #[test] fn get_connection_by_address_returns_none_if_connection_was_removed() { let mut container = create_container(); - container.remove_connection(&Identifier("primary1".into())); + container.remove_node(&Identifier("primary1".into())); assert!(container.connection_for_address("primary1").is_none()); } @@ -539,7 +599,7 @@ mod tests { fn get_connection_by_identifier_returns_none_if_connection_was_removed() { let mut container = create_container(); let identifier = Identifier("primary1".into()); - container.remove_connection(&identifier.clone()); + container.remove_node(&identifier.clone()); assert!(container.connection_for_identifier(&identifier).is_none()); } @@ -547,8 +607,10 @@ mod tests { #[test] fn get_connection_by_address_returns_added_connection() { let mut container = create_container(); - let identifier = - container.replace_or_add_connection_for_address("foobar", ClusterNode::new(4, None)); + let identifier = container.replace_or_add_connection_for_address( + "foobar", + ClusterNode::new_only_with_user_conn(4), + ); assert_eq!(4, container.connection_for_identifier(&identifier).unwrap()); assert_eq!( @@ -561,8 +623,10 @@ mod tests { fn get_random_connections_without_repetitions() { let container = create_container(); - let random_connections: HashSet<_> = - container.random_connections(3).map(|pair| pair.1).collect(); + let random_connections: HashSet<_> = container + .random_connections(3, ConnectionType::User) + .map(|pair| pair.1) + .collect(); assert_eq!(random_connections.len(), 3); assert!(random_connections @@ -575,16 +639,25 @@ mod tests { let mut container = create_container(); remove_all_connections(&mut container); - assert_eq!(0, container.random_connections(1).count()); + assert_eq!( + 0, + container + .random_connections(1, ConnectionType::User) + .count() + ); } #[test] fn get_random_connections_returns_added_connection() { let mut container = create_container(); remove_all_connections(&mut container); - let identifier = - container.replace_or_add_connection_for_address("foobar", ClusterNode::new(4, None)); - let random_connections: Vec<_> = container.random_connections(1).collect(); + let identifier = container.replace_or_add_connection_for_address( + "foobar", + ClusterNode::new_only_with_user_conn(4), + ); + let random_connections: Vec<_> = container + .random_connections(1, ConnectionType::User) + .collect(); assert_eq!(vec![(identifier, 4)], random_connections); } @@ -593,7 +666,7 @@ mod tests { fn get_random_connections_is_bound_by_the_number_of_connections_in_the_map() { let container = create_container(); let mut random_connections: Vec<_> = container - .random_connections(1000) + .random_connections(1000, ConnectionType::User) .map(|pair| pair.1) .collect(); random_connections.sort(); @@ -602,7 +675,19 @@ mod tests { } #[test] - fn get_all_nodes() { + fn get_random_management_connections() { + let container = create_container_with_strategy(ReadFromReplicaStrategy::RoundRobin, true); + let mut random_connections: Vec<_> = container + .random_connections(1000, ConnectionType::_Management) + .map(|pair| pair.1) + .collect(); + random_connections.sort(); + + assert_eq!(random_connections, vec![10, 20, 30, 210, 310, 320]); + } + + #[test] + fn get_all_user_connections() { let container = create_container(); let mut connections: Vec<_> = container .all_node_connections() @@ -614,9 +699,12 @@ mod tests { } #[test] - fn get_all_nodes_returns_added_connection() { + fn get_all_user_connections_returns_added_connection() { let mut container = create_container(); - container.replace_or_add_connection_for_address("foobar", ClusterNode::new(4, None)); + container.replace_or_add_connection_for_address( + "foobar", + ClusterNode::new_only_with_user_conn(4), + ); let mut connections: Vec<_> = container .all_node_connections() @@ -628,9 +716,9 @@ mod tests { } #[test] - fn get_all_nodes_does_not_return_removed_connection() { + fn get_all_user_connections_does_not_return_removed_connection() { let mut container = create_container(); - container.remove_connection(&Identifier("primary1".into())); + container.remove_node(&Identifier("primary1".into())); let mut connections: Vec<_> = container .all_node_connections() @@ -657,7 +745,7 @@ mod tests { #[test] fn get_all_primaries_does_not_return_removed_connection() { let mut container = create_container(); - container.remove_connection(&Identifier("primary1".into())); + container.remove_node(&Identifier("primary1".into())); let mut connections: Vec<_> = container .all_primary_connections() @@ -674,10 +762,13 @@ mod tests { assert_eq!(container.len(), 6); - container.remove_connection(&Identifier("primary1".into())); + container.remove_node(&Identifier("primary1".into())); assert_eq!(container.len(), 5); - container.replace_or_add_connection_for_address("foobar", ClusterNode::new(4, None)); + container.replace_or_add_connection_for_address( + "foobar", + ClusterNode::new_only_with_user_conn(4), + ); assert_eq!(container.len(), 6); } @@ -688,21 +779,24 @@ mod tests { assert_eq!(container.len(), 6); - container.remove_connection(&Identifier("foobar".into())); + container.remove_node(&Identifier("foobar".into())); assert_eq!(container.len(), 6); - container.replace_or_add_connection_for_address("primary1", ClusterNode::new(4, None)); + container.replace_or_add_connection_for_address( + "primary1", + ClusterNode::new_only_with_user_conn(4), + ); assert_eq!(container.len(), 6); } #[test] - fn remove_connection_returns_connection_if_it_exists() { + fn remove_node_returns_connection_if_it_exists() { let mut container = create_container(); - let connection = container.remove_connection(&Identifier("primary1".into())); - assert_eq!(connection, Some(ClusterNode::new(1, None))); + let connection = container.remove_node(&Identifier("primary1".into())); + assert_eq!(connection, Some(ClusterNode::new_only_with_user_conn(1))); - let non_connection = container.remove_connection(&Identifier("foobar".into())); + let non_connection = container.remove_node(&Identifier("foobar".into())); assert_eq!(non_connection, None); } @@ -713,7 +807,7 @@ mod tests { let address = container.address_for_identifier(&Identifier("primary1".into())); assert_eq!(address, Some("primary1".into())); - container.remove_connection(&Identifier("primary1".into())); + container.remove_node(&Identifier("primary1".into())); let address = container.address_for_identifier(&Identifier("primary1".into())); assert_eq!(address, Some("primary1".into())); diff --git a/redis/src/cluster_async/mod.rs b/redis/src/cluster_async/mod.rs index 634cee1fd..6877a46b0 100644 --- a/redis/src/cluster_async/mod.rs +++ b/redis/src/cluster_async/mod.rs @@ -89,7 +89,7 @@ use tokio::sync::{ use tracing::{info, trace, warn}; use self::connections_container::{ - ConnectionAndIdentifier, ConnectionsMap, Identifier as ConnectionIdentifier, + ConnectionAndIdentifier, ConnectionType, ConnectionsMap, Identifier as ConnectionIdentifier, }; /// This represents an async Redis Cluster connection. It stores the @@ -650,7 +650,7 @@ where result.map(|(conn, ip)| { ( node_identifier, - ClusterNode::new(async { conn }.boxed().shared(), ip), + ClusterNode::new(async { conn }.boxed().shared(), None, ip), ) }) } @@ -697,14 +697,14 @@ where &mut *connections_container, |connections_container, identifier| async move { let addr_option = connections_container.address_for_identifier(&identifier); - let node_option = connections_container.remove_connection(&identifier); + let node_option = connections_container.remove_node(&identifier); if let Some(addr) = addr_option { let conn = Self::get_or_create_conn(&addr, node_option, cluster_params).await; if let Ok((conn, ip)) = conn { connections_container.replace_or_add_connection_for_address( addr, - ClusterNode::new(async { conn }.boxed().shared(), ip), + ClusterNode::new(async { conn }.boxed().shared(), None, ip), ); } } @@ -860,7 +860,8 @@ where // When we no longer need to support Rust versions < 1.67, remove fast_math and transition to the ilog2 function. let num_of_nodes_to_query = std::cmp::max(fast_math::log2_raw(num_of_nodes as f32) as usize, 1); - let requested_nodes = read_guard.random_connections(num_of_nodes_to_query); + let requested_nodes = + read_guard.random_connections(num_of_nodes_to_query, ConnectionType::User); let topology_join_results = futures::future::join_all(requested_nodes.map(|conn| async move { let mut conn: C = conn.1.await; @@ -889,7 +890,8 @@ where let num_of_nodes = read_guard.len(); const MAX_REQUESTED_NODES: usize = 50; let num_of_nodes_to_query = std::cmp::min(num_of_nodes, MAX_REQUESTED_NODES); - let requested_nodes = read_guard.random_connections(num_of_nodes_to_query); + let requested_nodes = + read_guard.random_connections(num_of_nodes_to_query, ConnectionType::User); let topology_join_results = futures::future::join_all(requested_nodes.map(|conn| async move { let mut conn: C = conn.1.await; @@ -953,7 +955,7 @@ where if let Ok((conn, ip)) = conn { connections.0.insert( addr.into(), - ClusterNode::new(async { conn }.boxed().shared(), ip), + ClusterNode::new(async { conn }.boxed().shared(), None, ip), ); } connections @@ -1168,6 +1170,7 @@ where addr, ClusterNode::new( async move { connection_clone.clone() }.boxed().shared(), + None, ip, ), ); @@ -1184,8 +1187,10 @@ where Some(tuple) => tuple, None => { let read_guard = core.conn_lock.read().await; - let (random_identifier, random_conn_future) = - read_guard.random_connections(1).next().unwrap(); // TODO - this can panic. handle None. + let (random_identifier, random_conn_future) = read_guard + .random_connections(1, ConnectionType::User) + .next() + .unwrap(); // TODO - this can panic. handle None. drop(read_guard); (random_identifier, random_conn_future.await) } @@ -1351,7 +1356,7 @@ where params: &ClusterParams, ) -> RedisResult<(C, Option)> { if let Some(node) = node { - let mut conn = node.connection.await; + let mut conn = node.user_connection.await; if let Some(ref ip) = node.ip { if Self::is_dns_changed(addr, ip).await { return connect_and_check(addr, params.clone(), None).await;