diff --git a/redis-test/src/lib.rs b/redis-test/src/lib.rs index fb21e13bf9..cafe8a347b 100644 --- a/redis-test/src/lib.rs +++ b/redis-test/src/lib.rs @@ -288,6 +288,10 @@ impl AioConnectionLike for MockRedisConnection { fn get_db(&self) -> i64 { 0 } + + fn is_closed(&self) -> bool { + false + } } #[cfg(test)] diff --git a/redis/examples/async-await.rs b/redis/examples/async-await.rs index 36b8182a89..b52776a460 100644 --- a/redis/examples/async-await.rs +++ b/redis/examples/async-await.rs @@ -4,7 +4,7 @@ use redis::AsyncCommands; #[tokio::main] async fn main() -> redis::RedisResult<()> { let client = redis::Client::open("redis://127.0.0.1/").unwrap(); - let mut con = client.get_multiplexed_async_connection(None).await?; + let mut con = client.get_multiplexed_async_connection(None, None).await?; con.set("key1", b"foo").await?; diff --git a/redis/examples/async-connection-loss.rs b/redis/examples/async-connection-loss.rs index 4c2d54d082..90af361f2a 100644 --- a/redis/examples/async-connection-loss.rs +++ b/redis/examples/async-connection-loss.rs @@ -80,7 +80,9 @@ async fn main() -> RedisResult<()> { let client = redis::Client::open("redis://127.0.0.1/").unwrap(); match mode { - Mode::Default => run_multi(client.get_multiplexed_tokio_connection(None).await?).await?, + Mode::Default => { + run_multi(client.get_multiplexed_tokio_connection(None, None).await?).await? + } Mode::Reconnect => run_multi(client.get_connection_manager().await?).await?, #[allow(deprecated)] Mode::Deprecated => run_single(client.get_async_connection(None).await?).await?, diff --git a/redis/examples/async-multiplexed.rs b/redis/examples/async-multiplexed.rs index b057b759ca..9c8c73235c 100644 --- a/redis/examples/async-multiplexed.rs +++ b/redis/examples/async-multiplexed.rs @@ -34,7 +34,10 @@ async fn test_cmd(con: &MultiplexedConnection, i: i32) -> RedisResult<()> { async fn main() { let client = redis::Client::open("redis://127.0.0.1/").unwrap(); - let con = client.get_multiplexed_tokio_connection(None).await.unwrap(); + let con = client + .get_multiplexed_tokio_connection(None, None) + .await + .unwrap(); let cmds = (0..100).map(|i| test_cmd(&con, i)); let result = future::try_join_all(cmds).await.unwrap(); diff --git a/redis/examples/async-pub-sub.rs b/redis/examples/async-pub-sub.rs index 3dbb7e0f9f..15634e2b00 100644 --- a/redis/examples/async-pub-sub.rs +++ b/redis/examples/async-pub-sub.rs @@ -5,7 +5,7 @@ use redis::AsyncCommands; #[tokio::main] async fn main() -> redis::RedisResult<()> { let client = redis::Client::open("redis://127.0.0.1/").unwrap(); - let mut publish_conn = client.get_multiplexed_async_connection(None).await?; + let mut publish_conn = client.get_multiplexed_async_connection(None, None).await?; let mut pubsub_conn = client.get_async_pubsub().await?; pubsub_conn.subscribe("wavephone").await?; diff --git a/redis/examples/async-scan.rs b/redis/examples/async-scan.rs index 55e33d0eaf..6f55ac933f 100644 --- a/redis/examples/async-scan.rs +++ b/redis/examples/async-scan.rs @@ -5,7 +5,7 @@ use redis::{AsyncCommands, AsyncIter}; #[tokio::main] async fn main() -> redis::RedisResult<()> { let client = redis::Client::open("redis://127.0.0.1/").unwrap(); - let mut con = client.get_multiplexed_async_connection(None).await?; + let mut con = client.get_multiplexed_async_connection(None, None).await?; con.set("async-key1", b"foo").await?; con.set("async-key2", b"foo").await?;
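All of the example and doc-site changes above are the same mechanical update: every multiplexed-connection constructor now takes a second optional argument for a disconnect notifier. A minimal sketch of the updated API from a caller's perspective, with both optional hooks left unset (the `demo` function name is illustrative, not part of this patch):

```rust
use redis::AsyncCommands;

async fn demo() -> redis::RedisResult<()> {
    let client = redis::Client::open("redis://127.0.0.1/")?;
    // First `None` is the optional push_sender, second `None` is the new
    // optional disconnect notifier; passing `None` keeps the old behavior.
    let mut con = client.get_multiplexed_async_connection(None, None).await?;
    con.set("key1", b"foo").await?;
    let result: Vec<u8> = con.get("key1").await?;
    assert_eq!(result, b"foo");
    Ok(())
}
```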
diff --git a/redis/src/aio/connection.rs b/redis/src/aio/connection.rs index d78ef0850a..6b1f6e657a 100644 --- a/redis/src/aio/connection.rs +++ b/redis/src/aio/connection.rs @@ -305,6 +305,11 @@ where fn get_db(&self) -> i64 { self.db } + + fn is_closed(&self) -> bool { + // always false for AsyncRead + AsyncWrite (can't do better) + false + } } /// Represents a `PubSub` connection. diff --git a/redis/src/aio/connection_manager.rs b/redis/src/aio/connection_manager.rs index 0070d97736..741086d766 100644 --- a/redis/src/aio/connection_manager.rs +++ b/redis/src/aio/connection_manager.rs @@ -196,6 +196,7 @@ impl ConnectionManager { response_timeout, connection_timeout, None, + None, ) }) .await @@ -301,4 +302,9 @@ impl ConnectionLike for ConnectionManager { fn get_db(&self) -> i64 { self.client.connection_info().redis.db } + + fn is_closed(&self) -> bool { + // always returns false due to automatic reconnect + false + } } diff --git a/redis/src/aio/mod.rs b/redis/src/aio/mod.rs index 04ebe960fa..021550a8a6 100644 --- a/redis/src/aio/mod.rs +++ b/redis/src/aio/mod.rs @@ -85,6 +85,24 @@ pub trait ConnectionLike { /// also might be incorrect if the connection like object is not /// actually connected. fn get_db(&self) -> i64; + + /// Returns `true` if the connection is closed + fn is_closed(&self) -> bool; +} + +/// Implements the ability to notify about disconnection events +pub trait DisconnectNotifier: Send + Sync { + /// Notify about a disconnect event + fn notify_disconnect(&mut self); + + /// Intended to be used with Box + fn clone_box(&self) -> Box; +} + +impl Clone for Box { + fn clone(&self) -> Box { + self.clone_box() + } } // Initial setup for every connection.
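The trait is object-safe, and `clone_box` lets boxed notifiers be cloned along with the connections that hold them. A minimal sketch of an implementor, mirroring the `TokioDisconnectNotifier` added to `cluster_async` later in this patch (the struct name here is illustrative, and a tokio runtime is assumed):

```rust
use std::sync::Arc;

use redis::aio::DisconnectNotifier;
use tokio::sync::Notify;

#[derive(Clone)]
struct NotifyOnDisconnect {
    notify: Arc<Notify>,
}

impl DisconnectNotifier for NotifyOnDisconnect {
    fn notify_disconnect(&mut self) {
        // Wake whichever task is waiting on the shared Notify handle.
        self.notify.notify_one();
    }

    fn clone_box(&self) -> Box<dyn DisconnectNotifier> {
        Box::new(self.clone())
    }
}
```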
diff --git a/redis/src/aio/multiplexed_connection.rs b/redis/src/aio/multiplexed_connection.rs index 64e1ed7f2d..3e63afceb3 100644 --- a/redis/src/aio/multiplexed_connection.rs +++ b/redis/src/aio/multiplexed_connection.rs @@ -1,5 +1,6 @@ use super::{ConnectionLike, Runtime}; use crate::aio::setup_connection; +use crate::aio::DisconnectNotifier; use crate::cmd::Cmd; #[cfg(any(feature = "tokio-comp", feature = "async-std-comp"))] use crate::parser::ValueCodec; @@ -23,6 +24,7 @@ use std::fmt; use std::fmt::Debug; use std::io; use std::pin::Pin; +use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; use std::task::{self, Poll}; use std::time::Duration; @@ -77,6 +79,7 @@ struct Pipeline { sender: mpsc::Sender>, push_manager: Arc>, + is_stream_closed: Arc, } impl Clone for Pipeline { @@ -84,6 +87,7 @@ Pipeline { sender: self.sender.clone(), push_manager: self.push_manager.clone(), + is_stream_closed: self.is_stream_closed.clone(), } } } @@ -104,6 +108,8 @@ pin_project! { in_flight: VecDeque, error: Option, push_manager: Arc>, + disconnect_notifier: Option>, + is_stream_closed: Arc, } } @@ -111,7 +117,12 @@ impl PipelineSink where T: Stream> + 'static, { - fn new(sink_stream: T, push_manager: Arc>) -> Self + fn new( + sink_stream: T, + push_manager: Arc>, + disconnect_notifier: Option>, + is_stream_closed: Arc, + ) -> Self where T: Sink + Stream> + 'static, { @@ -120,6 +131,8 @@ where in_flight: VecDeque::new(), error: None, push_manager, + disconnect_notifier, + is_stream_closed, } } @@ -130,7 +143,15 @@ where Some(result) => result, // The redis response stream is not going to produce any more items so we `Err` // to break out of the `forward` combinator and stop handling requests - None => return Poll::Ready(Err(())), + None => { + // This is the right place to notify about a passive TCP disconnect: + // in other places we cannot distinguish between active destruction of the MultiplexedConnection and a passive disconnect + if let Some(disconnect_notifier) = self.as_mut().project().disconnect_notifier { + disconnect_notifier.notify_disconnect(); + } + self.is_stream_closed.store(true, Ordering::Relaxed); + return Poll::Ready(Err(())); + } }; self.as_mut().send_result(item); } @@ -296,7 +317,10 @@ impl Pipeline where SinkItem: Send + 'static, { - fn new(sink_stream: T) -> (Self, impl Future) + fn new( + sink_stream: T, + disconnect_notifier: Option>, + ) -> (Self, impl Future) where T: Sink + Stream> + 'static, T: Send + 'static, { @@ -308,7 +332,13 @@ where let (sender, mut receiver) = mpsc::channel(BUFFER_SIZE); let push_manager: Arc> = Arc::new(ArcSwap::new(Arc::new(PushManager::default()))); - let sink = PipelineSink::new::(sink_stream, push_manager.clone()); + let is_stream_closed = Arc::new(AtomicBool::new(false)); + let sink = PipelineSink::new::( + sink_stream, + push_manager.clone(), + disconnect_notifier, + is_stream_closed.clone(), + ); let f = stream::poll_fn(move |cx| receiver.poll_recv(cx)) .map(Ok) .forward(sink) @@ -317,6 +347,7 @@ Pipeline { sender, push_manager, + is_stream_closed, }, f, ) @@ -363,6 +394,10 @@ where async fn set_push_manager(&mut self, push_manager: PushManager) { self.push_manager.store(Arc::new(push_manager)); } + + pub fn is_closed(&self) -> bool { + self.is_stream_closed.load(Ordering::Relaxed) + } } /// A connection object which can be cloned, allowing requests to be sent concurrently @@ -392,6 +427,7 @@ impl MultiplexedConnection { connection_info: &ConnectionInfo, stream: C, push_sender: Option>, + disconnect_notifier: Option>, ) -> RedisResult<(Self, impl Future)> where C: Unpin + AsyncRead + AsyncWrite + Send + 'static, { @@ -401,6 +437,7 @@ stream, std::time::Duration::MAX, push_sender, + disconnect_notifier, ) .await } @@ -412,6 +449,7 @@ stream: C, response_timeout: std::time::Duration, push_sender: Option>, + disconnect_notifier: Option>, ) -> RedisResult<(Self, impl Future)> where C: Unpin + AsyncRead + AsyncWrite + Send + 'static, { @@ -429,7 +467,7 @@ let codec = ValueCodec::default() .framed(stream) .and_then(|msg| async move { msg }); - let (mut pipeline, driver) = Pipeline::new(codec); + let (mut pipeline, driver) = Pipeline::new(codec, disconnect_notifier); let driver = boxed(driver); let pm = PushManager::default(); if let Some(sender) = push_sender { @@ -560,6 +598,10 @@ impl ConnectionLike for MultiplexedConnection { fn get_db(&self) -> i64 { self.db } + + fn is_closed(&self) -> bool { +
self.pipeline.is_closed() + } } impl MultiplexedConnection { /// Subscribes to a new channel. diff --git a/redis/src/client.rs b/redis/src/client.rs index 7ace000890..534c186d91 100644 --- a/redis/src/client.rs +++ b/redis/src/client.rs @@ -1,5 +1,8 @@ use std::time::Duration; +#[cfg(feature = "aio")] +use crate::aio::DisconnectNotifier; + use crate::{ connection::{connect, Connection, ConnectionInfo, ConnectionLike, IntoConnectionInfo}, push_manager::PushInfo, @@ -147,11 +150,13 @@ impl Client { pub async fn get_multiplexed_async_connection( &self, push_sender: Option>, + disconnect_notifier: Option>, ) -> RedisResult { self.get_multiplexed_async_connection_with_timeouts( std::time::Duration::MAX, std::time::Duration::MAX, push_sender, + disconnect_notifier, ) .await } @@ -167,6 +172,7 @@ impl Client { response_timeout: std::time::Duration, connection_timeout: std::time::Duration, push_sender: Option>, + disconnect_notifier: Option>, ) -> RedisResult { let result = match Runtime::locate() { #[cfg(feature = "tokio-comp")] @@ -177,6 +183,7 @@ impl Client { response_timeout, None, push_sender, + disconnect_notifier, ), ) .await @@ -189,6 +196,7 @@ impl Client { response_timeout, None, push_sender, + disconnect_notifier, ), ) .await @@ -213,6 +221,7 @@ impl Client { pub async fn get_multiplexed_async_connection_and_ip( &self, push_sender: Option>, + disconnect_notifier: Option>, ) -> RedisResult<(crate::aio::MultiplexedConnection, Option)> { match Runtime::locate() { #[cfg(feature = "tokio-comp")] @@ -221,6 +230,7 @@ impl Client { Duration::MAX, None, push_sender, + disconnect_notifier, ) .await } @@ -230,6 +240,7 @@ impl Client { Duration::MAX, None, push_sender, + disconnect_notifier, ) .await } @@ -247,6 +258,7 @@ impl Client { response_timeout: std::time::Duration, connection_timeout: std::time::Duration, push_sender: Option>, + disconnect_notifier: Option>, ) -> RedisResult { let result = Runtime::locate() .timeout( @@ -255,6 +267,7 @@ impl Client { response_timeout, None, push_sender, + disconnect_notifier, ), ) .await; @@ -275,11 +288,13 @@ impl Client { pub async fn get_multiplexed_tokio_connection( &self, push_sender: Option>, + disconnect_notifier: Option>, ) -> RedisResult { self.get_multiplexed_tokio_connection_with_response_timeouts( std::time::Duration::MAX, std::time::Duration::MAX, push_sender, + disconnect_notifier, ) .await } @@ -295,6 +310,7 @@ impl Client { response_timeout: std::time::Duration, connection_timeout: std::time::Duration, push_sender: Option>, + disconnect_notifier: Option>, ) -> RedisResult { let result = Runtime::locate() .timeout( @@ -303,6 +319,7 @@ impl Client { response_timeout, None, push_sender, + disconnect_notifier, ), ) .await; @@ -323,11 +340,13 @@ impl Client { pub async fn get_multiplexed_async_std_connection( &self, push_sender: Option>, + disconnect_notifier: Option>, ) -> RedisResult { self.get_multiplexed_async_std_connection_with_timeouts( std::time::Duration::MAX, std::time::Duration::MAX, push_sender, + disconnect_notifier, ) .await } @@ -344,6 +363,7 @@ impl Client { &self, response_timeout: std::time::Duration, push_sender: Option>, + disconnect_notifier: Option>, ) -> RedisResult<( crate::aio::MultiplexedConnection, impl std::future::Future, @@ -352,6 +372,7 @@ impl Client { response_timeout, None, push_sender, + disconnect_notifier, ) .await .map(|(conn, driver, _ip)| (conn, driver)) @@ -367,6 +388,7 @@ impl Client { pub async fn create_multiplexed_tokio_connection( &self, push_sender: Option>, + disconnect_notifier: Option>, ) -> 
RedisResult<( crate::aio::MultiplexedConnection, impl std::future::Future, @@ -374,6 +396,7 @@ impl Client { self.create_multiplexed_tokio_connection_with_response_timeout( std::time::Duration::MAX, push_sender, + disconnect_notifier, ) .await .map(|conn_res| (conn_res.0, conn_res.1)) @@ -391,6 +414,7 @@ impl Client { &self, response_timeout: std::time::Duration, push_sender: Option>, + disconnect_notifier: Option>, ) -> RedisResult<( crate::aio::MultiplexedConnection, impl std::future::Future, @@ -399,6 +423,7 @@ impl Client { response_timeout, None, push_sender, + disconnect_notifier, ) .await .map(|(conn, driver, _ip)| (conn, driver)) @@ -414,6 +439,7 @@ impl Client { pub async fn create_multiplexed_async_std_connection( &self, push_sender: Option>, + disconnect_notifier: Option>, ) -> RedisResult<( crate::aio::MultiplexedConnection, impl std::future::Future, @@ -421,6 +447,7 @@ impl Client { self.create_multiplexed_async_std_connection_with_response_timeout( std::time::Duration::MAX, push_sender, + disconnect_notifier, ) .await } @@ -624,6 +651,7 @@ impl Client { response_timeout: std::time::Duration, socket_addr: Option, push_sender: Option>, + disconnect_notifier: Option>, ) -> RedisResult<(crate::aio::MultiplexedConnection, Option)> where T: crate::aio::RedisRuntime, @@ -633,6 +661,7 @@ impl Client { response_timeout, socket_addr, push_sender, + disconnect_notifier, ) .await?; T::spawn(driver); @@ -644,6 +673,7 @@ impl Client { response_timeout: std::time::Duration, socket_addr: Option, push_sender: Option>, + disconnect_notifier: Option>, ) -> RedisResult<( crate::aio::MultiplexedConnection, impl std::future::Future, @@ -658,6 +688,7 @@ impl Client { con, response_timeout, push_sender, + disconnect_notifier, ) .await .map(|res| (res.0, res.1, ip)) diff --git a/redis/src/cluster_async/connections_logic.rs b/redis/src/cluster_async/connections_logic.rs index 96d9965c34..dc3fd82d03 100644 --- a/redis/src/cluster_async/connections_logic.rs +++ b/redis/src/cluster_async/connections_logic.rs @@ -5,7 +5,7 @@ use super::{ Connect, }; use crate::{ - aio::{ConnectionLike, Runtime}, + aio::{ConnectionLike, DisconnectNotifier, Runtime}, cluster::get_connection_info, cluster_client::ClusterParams, push_manager::PushInfo, @@ -57,6 +57,7 @@ pub(crate) async fn get_or_create_conn( params: &ClusterParams, conn_type: RefreshConnectionType, push_sender: Option>, + disconnect_notifier: Option>, ) -> RedisResult> where C: ConnectionLike + Send + Clone + Sync + Connect + 'static, @@ -73,14 +74,23 @@ where conn_type, Some(node), push_sender, + disconnect_notifier, ) .await .get_node(), } } else { - connect_and_check(addr, params.clone(), None, conn_type, None, push_sender) - .await - .get_node() + connect_and_check( + addr, + params.clone(), + None, + conn_type, + None, + push_sender, + disconnect_notifier, + ) + .await + .get_node() } } @@ -102,6 +112,7 @@ pub(crate) async fn connect_and_check_all_connections( params: ClusterParams, socket_addr: Option, push_sender: Option>, + disconnect_notifier: Option>, ) -> ConnectAndCheckResult where C: ConnectionLike + Connect + Send + Sync + 'static + Clone, @@ -113,8 +124,16 @@ where socket_addr, push_sender.clone(), false, + disconnect_notifier.clone(), + ), + create_connection( + addr, + params.clone(), + socket_addr, + push_sender, + true, + disconnect_notifier, ), - create_connection(addr, params.clone(), socket_addr, push_sender, true), ) .await { @@ -160,11 +179,21 @@ async fn connect_and_check_only_management_conn( params: ClusterParams, socket_addr: 
Option, prev_node: AsyncClusterNode, + disconnect_notifier: Option>, ) -> ConnectAndCheckResult where C: ConnectionLike + Connect + Send + Sync + 'static + Clone, { - match create_connection::(addr, params.clone(), socket_addr, None, true).await { + match create_connection::( + addr, + params.clone(), + socket_addr, + None, + true, + disconnect_notifier, + ) + .await + { Err(conn_err) => failed_management_connection(addr, prev_node.user_connection, conn_err), Ok(mut connection) => { @@ -241,6 +270,7 @@ pub async fn connect_and_check( conn_type: RefreshConnectionType, node: Option>, push_sender: Option>, + disconnect_notifier: Option>, ) -> ConnectAndCheckResult where C: ConnectionLike + Connect + Send + Sync + 'static + Clone, { @@ -252,6 +282,7 @@ where params.clone(), socket_addr, push_sender, + disconnect_notifier, ) .await { @@ -265,15 +296,36 @@ where // Refreshing only the management connection requires the node to exist alongside a user connection. Otherwise, refresh all connections. match node { Some(node) => { - connect_and_check_only_management_conn(addr, params, socket_addr, node).await + connect_and_check_only_management_conn( + addr, + params, + socket_addr, + node, + disconnect_notifier, + ) + .await } None => { - connect_and_check_all_connections(addr, params, socket_addr, push_sender).await + connect_and_check_all_connections( + addr, + params, + socket_addr, + push_sender, + disconnect_notifier, + ) + .await } } } RefreshConnectionType::AllConnections => { - connect_and_check_all_connections(addr, params, socket_addr, push_sender).await + connect_and_check_all_connections( + addr, + params, + socket_addr, + push_sender, + disconnect_notifier, + ) + .await } } } @@ -283,12 +335,20 @@ async fn create_and_setup_user_connection( params: ClusterParams, socket_addr: Option, push_sender: Option>, + disconnect_notifier: Option>, ) -> RedisResult> where C: ConnectionLike + Connect + Send + 'static, { - let mut connection: ConnectionWithIp = - create_connection(node, params.clone(), socket_addr, push_sender, false).await?; + let mut connection: ConnectionWithIp = create_connection( + node, + params.clone(), + socket_addr, + push_sender, + false, + disconnect_notifier, + ) + .await?; setup_user_connection(&mut connection.conn, params).await?; Ok(connection) } @@ -328,6 +388,7 @@ async fn create_connection( socket_addr: Option, push_sender: Option>, is_management: bool, + disconnect_notifier: Option>, ) -> RedisResult> where C: ConnectionLike + Connect + Send + 'static, { @@ -339,12 +400,18 @@ where params.pubsub_subscriptions = None; } let info = get_connection_info(node, params)?; + // management connections do not require push notifications or disconnect notifications C::connect( info, response_timeout, connection_timeout, socket_addr, if !is_management { push_sender } else { None }, + if !is_management { + disconnect_notifier + } else { + None + }, ) .await .map(|conn| conn.into()) diff --git a/redis/src/cluster_async/mod.rs b/redis/src/cluster_async/mod.rs index cf977dd2a5..19cd27c84c 100644 --- a/redis/src/cluster_async/mod.rs +++ b/redis/src/cluster_async/mod.rs @@ -58,7 +58,7 @@ use std::{ use tokio::task::JoinHandle; use crate::{ - aio::{get_socket_addrs, ConnectionLike, MultiplexedConnection, Runtime}, + aio::{get_socket_addrs, ConnectionLike, DisconnectNotifier, MultiplexedConnection, Runtime}, cluster::slot_cmd, cluster_async::connections_logic::{ get_host_and_port_from_addr, get_or_create_conn, ConnectionFuture, RefreshConnectionType, }, @@ -91,6 +91,8 @@ use
backoff_std_async::{Error as BackoffError, ExponentialBackoff}; use backoff_tokio::future::retry; #[cfg(feature = "tokio-comp")] use backoff_tokio::{Error as BackoffError, ExponentialBackoff}; +#[cfg(feature = "tokio-comp")] +use tokio::{sync::Notify, time::timeout}; use dispose::{Disposable, Dispose}; use futures::{future::BoxFuture, prelude::*, ready}; @@ -370,6 +372,23 @@ where } } +#[cfg(feature = "tokio-comp")] +#[derive(Clone)] +struct TokioDisconnectNotifier { + pub disconnect_notifier: Arc, +} + +#[cfg(feature = "tokio-comp")] +impl DisconnectNotifier for TokioDisconnectNotifier { + fn notify_disconnect(&mut self) { + self.disconnect_notifier.notify_one(); + } + + fn clone_box(&self) -> Box { + Box::new(self.clone()) + } +} + type ConnectionMap = connections_container::ConnectionsMap>; type ConnectionsContainer = self::connections_container::ConnectionsContainer>; @@ -383,6 +402,9 @@ pub(crate) struct InnerCore { push_sender: Option>, subscriptions_by_address: RwLock>, unassigned_subscriptions: RwLock, + disconnect_notifier: Option>, + #[cfg(feature = "tokio-comp")] + tokio_notify: Arc, } pub(crate) type Core = Arc>; @@ -461,6 +483,8 @@ pub(crate) struct ClusterConnInner { refresh_error: Option, // Handler of the periodic check task. periodic_checks_handler: Option>, + // Handler of fast connection validation task + connections_validation_handler: Option>, } impl Dispose for ClusterConnInner { @@ -471,6 +495,12 @@ impl Dispose for ClusterConnInner { #[cfg(feature = "tokio-comp")] handle.abort() } + if let Some(handle) = self.connections_validation_handler { + #[cfg(all(not(feature = "tokio-comp"), feature = "async-std-comp"))] + block_on(handle.cancel()); + #[cfg(feature = "tokio-comp")] + handle.abort() + } } } @@ -957,9 +987,27 @@ where cluster_params: ClusterParams, push_sender: Option>, ) -> RedisResult> { - let connections = - Self::create_initial_connections(initial_nodes, &cluster_params, push_sender.clone()) - .await?; + #[cfg(feature = "tokio-comp")] + let tokio_notify = Arc::new(Notify::new()); + + let disconnect_notifier = { + #[cfg(feature = "tokio-comp")] + { + Some::>(Box::new(TokioDisconnectNotifier { + disconnect_notifier: tokio_notify.clone(), + })) + } + #[cfg(not(feature = "tokio-comp"))] + None + }; + + let connections = Self::create_initial_connections( + initial_nodes, + &cluster_params, + push_sender.clone(), + disconnect_notifier.clone(), + ) + .await?; let topology_checks_interval = cluster_params.topology_checks_interval; let slots_refresh_rate_limiter = cluster_params.slots_refresh_rate_limit; @@ -983,6 +1031,9 @@ where }, ), subscriptions_by_address: RwLock::new(Default::default()), + disconnect_notifier: disconnect_notifier.clone(), + #[cfg(feature = "tokio-comp")] + tokio_notify, }); let mut connection = ClusterConnInner { inner, @@ -990,6 +1041,7 @@ where refresh_error: None, state: ConnectionState::PollComplete, periodic_checks_handler: None, + connections_validation_handler: None, }; Self::refresh_slots_and_subscriptions_with_retries( connection.inner.clone(), @@ -1010,6 +1062,22 @@ where } } + let connections_validation_interval = cluster_params.connections_validation_interval; + if let Some(duration) = connections_validation_interval { + let connections_validation_handler = + ClusterConnInner::connections_validation_task(connection.inner.clone(), duration); + #[cfg(feature = "tokio-comp")] + { + connection.connections_validation_handler = + Some(tokio::spawn(connections_validation_handler)); + } + #[cfg(all(not(feature = "tokio-comp"), feature = 
"async-std-comp"))] + { + connection.connections_validation_handler = + Some(spawn(connections_validation_handler)); + } + } + Ok(Disposable::new(connection)) } @@ -1058,6 +1126,7 @@ where initial_nodes: &[ConnectionInfo], params: &ClusterParams, push_sender: Option>, + disconnect_notifier: Option>, ) -> RedisResult> { let initial_nodes: Vec<(String, Option)> = Self::try_to_expand_initial_nodes(initial_nodes).await; @@ -1067,6 +1136,7 @@ where let push_sender = push_sender.clone(); // set subscriptions to none, they will be applied upon the topology discovery params.pubsub_subscriptions = None; + let disconnect_notifier = disconnect_notifier.clone(); async move { let result = connect_and_check( @@ -1076,6 +1146,7 @@ where RefreshConnectionType::AllConnections, None, push_sender, + disconnect_notifier, ) .await .get_node(); @@ -1122,6 +1193,7 @@ where &inner.initial_nodes, &inner.cluster_params, None, + inner.disconnect_notifier.clone(), ) .await { @@ -1145,22 +1217,92 @@ where } } + // Validate all existing user connections and try to reconnect if nessesary. + // In addition, as a safety measure, drop nodes that do not have any assigned slots. + // This function serves as a cheap alternative to slot_refresh() and thus can be used much more frequently. + // The function does not discover the topology from the cluster and assumes the cached topology is valid. + // In addition, the validation is done by peeking at the state of the underlying transport w/o overhead of additional commands to server. + async fn validate_all_user_connections(inner: Arc>) { + let mut all_valid_conns = HashMap::new(); + let mut all_nodes_with_slots = HashSet::new(); + // prep connections and clean out these w/o assigned slots, as we might have established connections to unwanted hosts + { + let mut nodes_to_delete = Vec::new(); + let connections_container = inner.conn_lock.read().await; + + connections_container + .slot_map + .addresses_for_all_nodes() + .iter() + .for_each(|addr| { + all_nodes_with_slots.insert(String::from(*addr)); + }); + + connections_container + .all_node_connections() + .for_each(|(addr, con)| { + if all_nodes_with_slots.contains(&addr) { + all_valid_conns.insert(addr.clone(), con.clone()); + } else { + nodes_to_delete.push(addr.clone()); + } + }); + + for addr in nodes_to_delete.iter() { + connections_container.remove_node(addr); + } + } + + // identify nodes with closed connection + let mut addrs_to_refresh = Vec::new(); + for addr_and_fut in all_valid_conns.iter() { + let con = addr_and_fut.1.clone().await; + if con.is_closed() { + addrs_to_refresh.push(addr_and_fut.0.clone()); + } + } + + // identify missing nodes + all_nodes_with_slots.iter().for_each(|addr| { + if !all_valid_conns.contains_key(addr) { + addrs_to_refresh.push(addr.clone()); + } + }); + + if !addrs_to_refresh.is_empty() { + // dont try existing nodes since we know a. it does not exist. b. 
async fn refresh_connections( inner: Arc>, addresses: Vec, conn_type: RefreshConnectionType, + try_existing_node: bool, ) { info!("Started refreshing connections to {:?}", addresses); let connections_container = inner.conn_lock.read().await; let cluster_params = &inner.cluster_params; let subscriptions_by_address = &inner.subscriptions_by_address; let push_sender = &inner.push_sender; + let disconnect_notifier = &inner.disconnect_notifier; stream::iter(addresses.into_iter()) .fold( &*connections_container, |connections_container, address| async move { - let node_option = connections_container.remove_node(&address); + let node_option = if try_existing_node { + connections_container.remove_node(&address) + } else { + Option::None + }; // override subscriptions for this connection let mut cluster_params = cluster_params.clone(); @@ -1173,6 +1315,7 @@ where &cluster_params, conn_type, push_sender.clone(), + disconnect_notifier.clone(), ) .await; match node { @@ -1394,6 +1537,20 @@ where } } + async fn connections_validation_task(inner: Arc>, interval_duration: Duration) { + loop { + #[cfg(feature = "tokio-comp")] + let _ = timeout(interval_duration, async { + inner.tokio_notify.notified().await; + }) + .await; + #[cfg(not(feature = "tokio-comp"))] + let _ = boxed_sleep(interval_duration).await; + + Self::validate_all_user_connections(inner.clone()).await; + } + } + async fn refresh_pubsub_subscriptions(inner: Arc>) { if inner.cluster_params.protocol != crate::types::ProtocolVersion::RESP3 { return; } @@ -1471,17 +1628,12 @@ where drop(subs_by_address_guard); if !addrs_to_refresh.is_empty() { - let conns_read_guard = inner.conn_lock.read().await; - // have to remove or otherwise the refresh_connection wont trigger node recreation - for addr_to_refresh in addrs_to_refresh.iter() { - conns_read_guard.remove_node(addr_to_refresh); - } - drop(conns_read_guard); // immediately trigger connection reestablishment Self::refresh_connections( inner.clone(), addrs_to_refresh.into_iter().collect(), RefreshConnectionType::AllConnections, + false, ) .await; } @@ -1517,6 +1669,7 @@ where inner, failed_connections, RefreshConnectionType::OnlyManagementConnection, + true, ) .await; } @@ -1616,6 +1769,7 @@ where &cluster_params, RefreshConnectionType::AllConnections, inner.push_sender.clone(), + inner.disconnect_notifier.clone(), ) .await; if let Ok(node) = node { @@ -1911,6 +2065,7 @@ where RefreshConnectionType::AllConnections, None, core.push_sender.clone(), + core.disconnect_notifier.clone(), ) .await .get_node() @@ -2221,6 +2376,7 @@ where self.inner.clone(), addresses, RefreshConnectionType::OnlyUserConnection, + true, ), ))); } @@ -2329,7 +2485,12 @@ where fn get_db(&self) -> i64 { 0 } + + fn is_closed(&self) -> bool { + false + } } + /// Implements the process of connecting to a Redis server /// and obtaining a connection handle.
pub trait Connect: Sized { @@ -2342,6 +2503,7 @@ pub trait Connect: Sized { connection_timeout: Duration, socket_addr: Option, push_sender: Option>, + disconnect_notifier: Option>, ) -> RedisFuture<'a, (Self, Option)> where T: IntoConnectionInfo + Send + 'a; @@ -2354,6 +2516,7 @@ impl Connect for MultiplexedConnection { connection_timeout: Duration, socket_addr: Option, push_sender: Option>, + disconnect_notifier: Option>, ) -> RedisFuture<'a, (MultiplexedConnection, Option)> where T: IntoConnectionInfo + Send + 'a, { @@ -2371,6 +2534,7 @@ impl Connect for MultiplexedConnection { response_timeout, socket_addr, push_sender, + disconnect_notifier, ), ) .await? @@ -2382,6 +2546,7 @@ impl Connect for MultiplexedConnection { response_timeout, socket_addr, push_sender, + disconnect_notifier, )) .await? } diff --git a/redis/src/cluster_client.rs b/redis/src/cluster_client.rs index 7c47631798..5815bede1e 100644 --- a/redis/src/cluster_client.rs +++ b/redis/src/cluster_client.rs @@ -42,6 +42,8 @@ struct BuilderParams { #[cfg(feature = "cluster-async")] topology_checks_interval: Option, #[cfg(feature = "cluster-async")] + connections_validation_interval: Option, + #[cfg(feature = "cluster-async")] slots_refresh_rate_limit: SlotsRefreshRateLimit, client_name: Option, response_timeout: Option, @@ -138,6 +140,8 @@ pub struct ClusterParams { pub(crate) topology_checks_interval: Option, #[cfg(feature = "cluster-async")] pub(crate) slots_refresh_rate_limit: SlotsRefreshRateLimit, + #[cfg(feature = "cluster-async")] + pub(crate) connections_validation_interval: Option, pub(crate) tls_params: Option, pub(crate) client_name: Option, pub(crate) connection_timeout: Duration, @@ -169,6 +173,8 @@ impl ClusterParams { topology_checks_interval: value.topology_checks_interval, #[cfg(feature = "cluster-async")] slots_refresh_rate_limit: value.slots_refresh_rate_limit, + #[cfg(feature = "cluster-async")] + connections_validation_interval: value.connections_validation_interval, tls_params, client_name: value.client_name, response_timeout: value.response_timeout.unwrap_or(Duration::MAX), @@ -393,6 +399,16 @@ impl ClusterClientBuilder { self } + /// Enables periodic connection checks for this client. + /// If enabled, the connections to the cluster nodes will be validated periodically, per the configured interval. + /// In addition, for the tokio runtime, passive disconnections can be detected instantly, + /// triggering reestablishment w/o waiting for the next periodic check. + #[cfg(feature = "cluster-async")] + pub fn periodic_connections_checks(mut self, interval: Duration) -> ClusterClientBuilder { + self.builder_params.connections_validation_interval = Some(interval); + self + } +
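A short sketch of how the new builder option might be wired up, assuming an illustrative 3-second interval (the function name is hypothetical):

```rust
use std::time::Duration;

use redis::cluster::ClusterClient;

fn build_checked_client(nodes: Vec<String>) -> redis::RedisResult<ClusterClient> {
    ClusterClient::builder(nodes)
        // Validate the cached connections every 3 seconds. With tokio-comp,
        // a passive disconnect additionally wakes the validation task early.
        .periodic_connections_checks(Duration::from_secs(3))
        .build()
}
```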
/// Sets the rate limit for slot refresh operations in the cluster. /// /// This method configures the interval duration between consecutive slot diff --git a/redis/src/sentinel.rs b/redis/src/sentinel.rs index 2e30ec02d5..8b853f643b 100644 --- a/redis/src/sentinel.rs +++ b/redis/src/sentinel.rs @@ -301,7 +301,7 @@ fn find_valid_master( #[cfg(feature = "aio")] async fn async_check_role(connection_info: &ConnectionInfo, target_role: &str) -> bool { if let Ok(client) = Client::open(connection_info.clone()) { - if let Ok(mut conn) = client.get_multiplexed_async_connection(None).await { + if let Ok(mut conn) = client.get_multiplexed_async_connection(None, None).await { let result: RedisResult> = crate::cmd("ROLE").query_async(&mut conn).await; return check_role_result(&result, target_role); } @@ -366,7 +366,7 @@ async fn async_reconnect( ) -> RedisResult<()> { let sentinel_client = Client::open(connection_info.clone())?; let new_connection = sentinel_client - .get_multiplexed_async_connection(None) + .get_multiplexed_async_connection(None, None) .await?; connection.replace(new_connection); Ok(()) @@ -768,6 +768,6 @@ impl SentinelClient { #[cfg(any(feature = "tokio-comp", feature = "async-std-comp"))] pub async fn get_async_connection(&mut self) -> RedisResult { let client = self.async_get_client().await?; - client.get_multiplexed_async_connection(None).await + client.get_multiplexed_async_connection(None, None).await } } diff --git a/redis/tests/support/mock_cluster.rs b/redis/tests/support/mock_cluster.rs index b9f27710b7..0c7212e29e 100644 --- a/redis/tests/support/mock_cluster.rs +++ b/redis/tests/support/mock_cluster.rs @@ -29,6 +29,9 @@ use futures::future; #[cfg(feature = "cluster-async")] use tokio::runtime::Runtime; +#[cfg(feature = "aio")] +use redis::aio::DisconnectNotifier; + type Handler = Arc Result<(), RedisResult> + Send + Sync>; pub struct MockConnectionBehavior { @@ -135,10 +138,12 @@ impl cluster_async::Connect for MockConnection { _connection_timeout: Duration, _socket_addr: Option, _push_sender: Option>, + _disconnect_notifier: Option>, ) -> RedisFuture<'a, (Self, Option)> where T: IntoConnectionInfo + Send + 'a, { + let _ = _disconnect_notifier; let info = info.into_connection_info().unwrap(); let (name, port) = match &info.addr { @@ -369,6 +374,10 @@ impl aio::ConnectionLike for MockConnection { fn get_db(&self) -> i64 { 0 } + + fn is_closed(&self) -> bool { + false + } } impl redis::ConnectionLike for MockConnection { diff --git a/redis/tests/support/mod.rs b/redis/tests/support/mod.rs index 24f786c2e3..96ce71e6a4 100644 --- a/redis/tests/support/mod.rs +++ b/redis/tests/support/mod.rs @@ -501,7 +501,9 @@ impl TestContext { #[cfg(feature = "aio")] pub async fn async_connection(&self) -> redis::RedisResult { - self.client.get_multiplexed_async_connection(None).await + self.client + .get_multiplexed_async_connection(None, None) + .await } #[cfg(feature = "aio")] @@ -513,7 +515,9 @@ impl TestContext { pub async fn async_connection_async_std( &self, ) -> redis::RedisResult { - self.client.get_multiplexed_async_std_connection(None).await + self.client + .get_multiplexed_async_std_connection(None, None) + .await } pub fn stop_server(&mut self) { @@ -531,14 +535,18 @@ impl TestContext { pub async fn multiplexed_async_connection_tokio( &self, ) -> redis::RedisResult { - self.client.get_multiplexed_tokio_connection(None).await + self.client + .get_multiplexed_tokio_connection(None, None) + .await } #[cfg(feature = "async-std-comp")] pub async fn multiplexed_async_connection_async_std( &self, ) -> redis::RedisResult { -
self.client.get_multiplexed_async_std_connection(None).await + self.client + .get_multiplexed_async_std_connection(None, None) + .await } pub fn get_version(&self) -> Version { diff --git a/redis/tests/test_async.rs b/redis/tests/test_async.rs index c0fc7fe3e1..f7c892a264 100644 --- a/redis/tests/test_async.rs +++ b/redis/tests/test_async.rs @@ -100,7 +100,7 @@ mod basic_async { fn dont_panic_on_closed_multiplexed_connection() { let ctx = TestContext::new(); let client = ctx.client.clone(); - let connect = client.get_multiplexed_async_connection(None); + let connect = client.get_multiplexed_async_connection(None, None); drop(ctx); block_on_all(async move { @@ -584,7 +584,7 @@ mod basic_async { let client = redis::Client::open(coninfo).unwrap(); let err = client - .get_multiplexed_tokio_connection(None) + .get_multiplexed_tokio_connection(None, None) .await .err() .unwrap(); @@ -916,7 +916,7 @@ mod basic_async { let millisecond = std::time::Duration::from_millis(1); let mut retries = 0; loop { - match client.get_multiplexed_async_connection(None).await { + match client.get_multiplexed_async_connection(None, None).await { Err(err) => { if err.is_connection_refusal() { tokio::time::sleep(millisecond).await; @@ -986,7 +986,7 @@ mod basic_async { let client = build_single_client(ctx.server.connection_info(), &ctx.server.tls_paths, true) .unwrap(); - let connect = client.get_multiplexed_async_connection(None); + let connect = client.get_multiplexed_async_connection(None, None); block_on_all(connect.and_then(|mut con| async move { redis::cmd("SET") .arg("key1") @@ -1007,7 +1007,7 @@ mod basic_async { let client = build_single_client(ctx.server.connection_info(), &ctx.server.tls_paths, false) .unwrap(); - let connect = client.get_multiplexed_async_connection(None); + let connect = client.get_multiplexed_async_connection(None, None); let result = block_on_all(connect.and_then(|mut con| async move { redis::cmd("SET") .arg("key1") diff --git a/redis/tests/test_async_async_std.rs b/redis/tests/test_async_async_std.rs index aabe58320b..ae2ae8443f 100644 --- a/redis/tests/test_async_async_std.rs +++ b/redis/tests/test_async_async_std.rs @@ -61,7 +61,7 @@ fn test_args_async_std() { fn dont_panic_on_closed_multiplexed_connection() { let ctx = TestContext::new(); let client = ctx.client.clone(); - let connect = client.get_multiplexed_async_std_connection(None); + let connect = client.get_multiplexed_async_std_connection(None, None); drop(ctx); block_on_all_using_async_std(async move { diff --git a/redis/tests/test_async_cluster_connections_logic.rs b/redis/tests/test_async_cluster_connections_logic.rs index 2a5bab6aec..07e41a6993 100644 --- a/redis/tests/test_async_cluster_connections_logic.rs +++ b/redis/tests/test_async_cluster_connections_logic.rs @@ -73,6 +73,7 @@ mod test_connect_and_check { RefreshConnectionType::AllConnections, None, None, + None, ) .await; let node = assert_full_success(result); @@ -109,6 +110,7 @@ mod test_connect_and_check { RefreshConnectionType::AllConnections, None, None, + None, ) .await; let (node, _) = assert_partial_result(result); @@ -127,6 +129,7 @@ mod test_connect_and_check { RefreshConnectionType::AllConnections, None, None, + None, ) .await; let (node, _) = assert_partial_result(result); @@ -160,6 +163,7 @@ mod test_connect_and_check { RefreshConnectionType::AllConnections, None, None, + None, ) .await; let node = assert_full_success(result); @@ -197,6 +201,7 @@ mod test_connect_and_check { RefreshConnectionType::AllConnections, None, None, + None, ) .await; let err 
= result.get_error().unwrap(); @@ -248,6 +253,7 @@ mod test_connect_and_check { RefreshConnectionType::OnlyManagementConnection, Some(node), None, + None, ) .await; let node = assert_full_success(result); @@ -295,6 +301,7 @@ mod test_connect_and_check { RefreshConnectionType::OnlyManagementConnection, Some(node), None, + None, ) .await; let (node, _) = assert_partial_result(result); @@ -357,6 +364,7 @@ mod test_connect_and_check { RefreshConnectionType::OnlyUserConnection, Some(node), None, + None, ) .await; let node = assert_full_success(result); diff --git a/redis/tests/test_cluster_async.rs b/redis/tests/test_cluster_async.rs index 7d1249c3e3..8c1d0d7e01 100644 --- a/redis/tests/test_cluster_async.rs +++ b/redis/tests/test_cluster_async.rs @@ -21,7 +21,7 @@ mod cluster_async { use std::ops::Add; use redis::{ - aio::{ConnectionLike, MultiplexedConnection}, + aio::{ConnectionLike, DisconnectNotifier, MultiplexedConnection}, cluster::ClusterClient, cluster_async::{testing::MANAGEMENT_CONN_NAME, ClusterConnection, Connect}, cluster_routing::{ @@ -44,6 +44,60 @@ mod cluster_async { )) } + fn validate_subscriptions( + pubsub_subs: &PubSubSubscriptionInfo, + notifications_rx: &mut mpsc::UnboundedReceiver, + allow_disconnects: bool, + ) { + let mut subscribe_cnt = + if let Some(exact_subs) = pubsub_subs.get(&PubSubSubscriptionKind::Exact) { + exact_subs.len() + } else { + 0 + }; + + let mut psubscribe_cnt = + if let Some(pattern_subs) = pubsub_subs.get(&PubSubSubscriptionKind::Pattern) { + pattern_subs.len() + } else { + 0 + }; + + let mut ssubscribe_cnt = + if let Some(sharded_subs) = pubsub_subs.get(&PubSubSubscriptionKind::Sharded) { + sharded_subs.len() + } else { + 0 + }; + + for _ in 0..(subscribe_cnt + psubscribe_cnt + ssubscribe_cnt) { + let result = notifications_rx.try_recv(); + assert!(result.is_ok()); + let PushInfo { kind, data: _ } = result.unwrap(); + assert!( + kind == PushKind::Subscribe + || kind == PushKind::PSubscribe + || kind == PushKind::SSubscribe + || if allow_disconnects { + kind == PushKind::Disconnection + } else { + false + } + ); + if kind == PushKind::Subscribe { + subscribe_cnt -= 1; + } else if kind == PushKind::PSubscribe { + psubscribe_cnt -= 1; + } else if kind == PushKind::SSubscribe { + ssubscribe_cnt -= 1; + } + } + + assert!(subscribe_cnt == 0); + assert!(psubscribe_cnt == 0); + assert!(ssubscribe_cnt == 0); + } + #[test] fn test_async_cluster_basic_cmd() { let cluster = TestClusterContext::new(3, 0); @@ -382,7 +436,7 @@ mod cluster_async { .unwrap_or_else(|e| panic!("Failed to connect to '{addr}': {e}")); let mut conn = client - .get_multiplexed_async_connection(None) + .get_multiplexed_async_connection(None, None) .await .unwrap_or_else(|e| panic!("Failed to get connection: {e}")); @@ -482,6 +536,7 @@ mod cluster_async { connection_timeout: std::time::Duration, socket_addr: Option, push_sender: Option>, + disconnect_notifier: Option>, ) -> RedisFuture<'a, (Self, Option)> where T: IntoConnectionInfo + Send + 'a, @@ -493,6 +548,7 @@ mod cluster_async { connection_timeout, socket_addr, push_sender, + disconnect_notifier, ) .await?; Ok((ErrorConnection { inner }, None)) @@ -521,6 +577,10 @@ mod cluster_async { fn get_db(&self) -> i64 { self.inner.get_db() } + + fn is_closed(&self) -> bool { + true + } } #[test] @@ -2683,546 +2743,522 @@ mod cluster_async { } #[test] - fn test_async_cluster_restore_resp3_pubsub_state_after_complete_server_disconnect() { - // let cluster = TestClusterContext::new_with_cluster_client_builder( - // 3, - // 0, - // |builder| 
builder.retries(3).use_protocol(ProtocolVersion::RESP3), - // //|builder| builder.retries(3), - // false, - // ); - - // block_on_all(async move { - // let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::(); - // let mut connection = cluster.async_connection(Some(tx.clone())).await; - // // assuming the implementation of TestCluster assigns the slots monotonicaly incerasing with the nodes - // let route_0 = redis::cluster_routing::Route::new(0, redis::cluster_routing::SlotAddr::Master); - // let node_0_route = redis::cluster_routing::SingleNodeRoutingInfo::SpecificNode(route_0); - // let route_2 = redis::cluster_routing::Route::new(16 * 1024 - 1, redis::cluster_routing::SlotAddr::Master); - // let node_2_route = redis::cluster_routing::SingleNodeRoutingInfo::SpecificNode(route_2); - - // let result = connection - // .route_command(&redis::Cmd::new().arg("SUBSCRIBE").arg("test_channel"), RoutingInfo::SingleNode(node_0_route.clone())) - // //.route_command(&redis::Cmd::new().arg("SUBSCRIBE").arg("test_channel"), RoutingInfo::SingleNode(SingleNodeRoutingInfo::Random)) - // .await; - - // assert_eq!( - // result, - // Ok(Value::Push { - // kind: PushKind::Subscribe, - // data: vec![Value::BulkString("test_channel".into()), Value::Int(1)], - // }) - // ); - - // // pull out all the subscribe notification, this push notification is due to the previous subscribe command - // let result = rx.recv().await; - // assert!(result.is_some()); - // let PushInfo { kind, data } = result.unwrap(); - // assert_eq!( - // (kind, data), - // ( - // PushKind::Subscribe, - // vec![ - // Value::BulkString("test_channel".as_bytes().to_vec()), - // Value::Int(1), - // ] - // ) - // ); - - // // ensure subscription, routing on the same node, expected return Int(1) - // let result = connection - // .route_command(&redis::Cmd::new().arg("PUBLISH").arg("test_channel").arg("test_message_from_node_0"), RoutingInfo::SingleNode(node_0_route.clone())) - // .await; - // assert_eq!( - // result, - // Ok(Value::Int(1)) - // ); - - // // ensure subscription, routing on different node, expected return Int(0) - // let result = connection - // .route_command(&redis::Cmd::new().arg("PUBLISH").arg("test_channel").arg("test_message_from_node_2"), RoutingInfo::SingleNode(node_2_route.clone())) - // .await; - // assert_eq!( - // result, - // Ok(Value::Int(0)) - // ); - - // for i in vec![0, 2] { - // let result = rx.recv().await; - // assert!(result.is_some()); - // let PushInfo { kind, data } = result.unwrap(); - // println!("^^^^^^^^^ '{:?} -> {:?}'", kind, data); - // assert_eq!( - // (kind, data), - // ( - // PushKind::Message, - // vec![ - // Value::BulkString("test_channel".into()), - // Value::BulkString(format!("test_message_from_node_{}", i).into()), - // ] - // ) - // ); - // } - - // // drop and recreate cluster and connections - // drop(cluster); - // println!("*********** DROPPED **********"); - - // let cluster = TestClusterContext::new_with_cluster_client_builder( - // 3, - // 0, - // |builder| builder.retries(3).use_protocol(ProtocolVersion::RESP3), - // //|builder| builder.retries(3), - // false, - // ); - - // let result = connection - // .route_command(&redis::Cmd::new().arg("PUBLISH").arg("test_channel").arg("test_message_from_node_0"), RoutingInfo::SingleNode(node_0_route.clone())) - // .await; - // assert_eq!( - // result, - // Ok(Value::Int(1)) - // ); - - // //sleep(futures_time::time::Duration::from_secs(15)).await; - // //return Ok(()); - - // let cluster = 
TestClusterContext::new_with_cluster_client_builder( - // 3, - // 0, - // |builder| builder.retries(3).use_protocol(ProtocolVersion::RESP3), - // //|builder| builder.retries(3), - // false, - // ); - - // // ensure subscription state restore - // let result = connection - // .route_command(&redis::Cmd::new().arg("PUBLISH").arg("test_channel").arg("test_message_from_node_0"), RoutingInfo::SingleNode(node_0_route.clone())) - // .await; - // assert_eq!( - // result, - // Ok(Value::Int(1)) - // ); - - // // non-subscribed channel - // let result = connection - // .route_command(&redis::Cmd::new().arg("PUBLISH").arg("test_channel_1").arg("should_not_receive"), RoutingInfo::SingleNode(node_0_route.clone())) - // .await; - // assert_eq!( - // result, - // Ok(Value::Int(0)) - // ); - - // // ensure subscription, routing on different node, expected return Int(0) - // let result = connection - // .route_command(&redis::Cmd::new().arg("PUBLISH").arg("test_channel").arg("test_message_from_node_2"), RoutingInfo::SingleNode(node_2_route.clone())) - // .await; - // assert_eq!( - // result, - // Ok(Value::Int(0)) - // ); - - // // should produce an arbitrary number of 'disconnected' notifications - 1 for the intitial try after the drop and an unknown? amout during reconnecting procedure - // // Notifications become available ONLY after we try to send the commands, since push manager does not register TCP disconnect on a idle socket - // // Remove the any amount of 'disconnected' notifications - // sleep(futures_time::time::Duration::from_secs(1)).await; - // //let mut result = rx.recv().await; - // let mut result = rx.try_recv(); - // assert!(result.is_ok()); - // //assert!(result.is_some()); - // loop { - // let kind = result.clone().unwrap().kind; - // if kind != PushKind::Disconnection && kind != PushKind::Subscribe { - // break; - // } - // // result = rx.recv().await; - // // assert!(result.is_some()); - // result = rx.try_recv(); - // assert!(result.is_ok()); - // } - - // // ensure messages test_message_from_node_0 and test_message_from_node_2 - // let mut msg_from_0 = false; - // let mut msg_from_2 = false; - // while !msg_from_0 && !msg_from_2 { - // let mut result = rx.recv().await; - // assert!(result.is_some()); - // let PushInfo { kind, data } = result.unwrap(); - - // assert!(kind == PushKind::Disconnection || kind == PushKind::Subscribe || kind == PushKind::Message); - // if kind == PushKind::Disconnection || kind == PushKind::Subscribe { - // // ignore - // continue; - // } - - // if data == vec![ - // Value::BulkString("test_channel".into()), - // Value::BulkString("test_message_from_node_0".into())] { - // assert!(!msg_from_0); - // msg_from_0 = true; - // } - // else if data == vec![ - // Value::BulkString("test_channel".into()), - // Value::BulkString("test_message_from_node_2".into())] { - // assert!(!msg_from_2); - // msg_from_2 = true; - // } - // else { - // assert!(false, "Unexpected message received"); - // } - // } - - // // let mut msg_from_0 = false; - // // let mut msg_from_2 = false; - // // while !msg_from_2 { - // // let mut result = rx.recv().await; - // // assert!(result.is_some()); - // // let PushInfo { kind, data } = result.unwrap(); - - // // assert!(kind == PushKind::Disconnection || kind == PushKind::Subscribe || kind == PushKind::Message); - // // if kind == PushKind::Disconnection || kind == PushKind::Subscribe { - // // // ignore - // // continue; - // // } - - // // if data == vec![ - // // Value::BulkString("test_channel".into()), - // // 
Value::BulkString("test_message_from_node_2".into())] { - // // assert!(!msg_from_2); - // // msg_from_2 = true; - // // } - // // else { - // // assert!(false, "Unexpected message received"); - // // } - // // } - - // Ok(()) - // }) - // .unwrap(); + fn test_async_cluster_test_fast_reconnect() { + // Note: the 3 second connection check interval differentiates notification-driven reconnects from periodic ones + let cluster = TestClusterContext::new_with_cluster_client_builder( + 3, + 0, + |builder| { + builder + .retries(0) + .periodic_connections_checks(Duration::from_secs(3)) + }, + false, + ); + + // For tokio-comp, do 3 consecutive disconnects and ensure reconnects succeed in less than 100ms, + // which is more than enough for local connections even with TLS. + // More than 1 run is done to ensure it is the fast reconnect notification that triggers the reconnect + // and not the periodic interval. + // For other async implementations, only the periodic connection check is available; hence, + // do 1 run, sleeping for the periodic connection check interval, allowing it to reestablish connections + block_on_all(async move { + let mut disconnecting_con = cluster.async_connection(None).await; + let mut monitoring_con = cluster.async_connection(None).await; + + #[cfg(feature = "tokio-comp")] + let tries = 0..3; + #[cfg(not(feature = "tokio-comp"))] + let tries = 0..1; + + for _ in tries { + // get the connection id + let mut cmd = redis::cmd("CLIENT"); + cmd.arg("ID"); + let res = disconnecting_con + .route_command( + &cmd, + RoutingInfo::SingleNode(SingleNodeRoutingInfo::SpecificNode(Route::new( + 0, + SlotAddr::Master, + ))), + ) + .await; + assert!(res.is_ok()); + let res = res.unwrap(); + let id = { + match res { + Value::Int(id) => id, + _ => { + panic!("Wrong return value for CLIENT ID command: {:?}", res); + } + } + }; + + // ask the server to kill the connection + let mut cmd = redis::cmd("CLIENT"); + cmd.arg("KILL").arg("ID").arg(id).arg("SKIPME").arg("NO"); + let res = disconnecting_con + .route_command( + &cmd, + RoutingInfo::SingleNode(SingleNodeRoutingInfo::SpecificNode(Route::new( + 0, + SlotAddr::Master, + ))), + ) + .await; + // assert the server has closed the connection + assert_eq!(res, Ok(Value::Int(1))); + + #[cfg(feature = "tokio-comp")] + // ensure the reconnect happened in less than 100ms + sleep(futures_time::time::Duration::from_millis(100)).await; + + #[cfg(not(feature = "tokio-comp"))] + // no fast notification is available, wait for 1 periodic check + overhead + sleep(futures_time::time::Duration::from_secs(3 + 1)).await; + + let mut cmd = redis::cmd("CLIENT"); + cmd.arg("LIST").arg("TYPE").arg("NORMAL"); + let res = monitoring_con + .route_command( + &cmd, + RoutingInfo::SingleNode(SingleNodeRoutingInfo::SpecificNode(Route::new( + 0, + SlotAddr::Master, + ))), + ) + .await; + assert!(res.is_ok()); + let res = res.unwrap(); + let client_list: String = { + match res { + // RESP2 + Value::BulkString(client_info) => { + // ensure 4 connections - 2 for each client; it's safe to unwrap here + String::from_utf8(client_info).unwrap() + } + // RESP3 + Value::VerbatimString { format: _, text } => text, + _ => { + panic!("Wrong return type for CLIENT LIST command: {:?}", res); + } + } + }; + assert_eq!(client_list.chars().filter(|&x| x == '\n').count(), 4); + } + Ok(()) + }) + .unwrap(); + } + + #[test] + fn test_async_cluster_restore_resp3_pubsub_state_passive_disconnect() { + let redis_ver = std::env::var("REDIS_VERSION").unwrap_or_default(); + let use_sharded = redis_ver.starts_with("7."); + + let mut
client_subscriptions = PubSubSubscriptionInfo::from([( + PubSubSubscriptionKind::Exact, + HashSet::from([PubSubChannelOrPattern::from("test_channel".as_bytes())]), + )]); + + if use_sharded { + client_subscriptions.insert( + PubSubSubscriptionKind::Sharded, + HashSet::from([PubSubChannelOrPattern::from("test_channel_?".as_bytes())]), + ); + } + + // Note: topology change detection is not activated, since no topology change is expected + let cluster = TestClusterContext::new_with_cluster_client_builder( + 3, + 0, + |builder| { + builder + .retries(3) + .use_protocol(ProtocolVersion::RESP3) + .pubsub_subscriptions(client_subscriptions.clone()) + .periodic_connections_checks(Duration::from_secs(1)) + }, + false, + ); + + block_on_all(async move { + let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::(); + let mut _listening_con = cluster.async_connection(Some(tx.clone())).await; + // Note: the publishing connection has the same pubsub config + let mut publishing_con = cluster.async_connection(None).await; + + // short sleep to allow the server to push subscription notifications + sleep(futures_time::time::Duration::from_secs(1)).await; + + // validate subscriptions + validate_subscriptions(&client_subscriptions, &mut rx, false); + + // validate PUBLISH + let result = cmd("PUBLISH") + .arg("test_channel") + .arg("test_message") + .query_async(&mut publishing_con) + .await; + assert_eq!( + result, + Ok(Value::Int(2)) // 2 connections with the same pubsub config + ); + + sleep(futures_time::time::Duration::from_secs(1)).await; + let result = rx.try_recv(); + assert!(result.is_ok()); + let PushInfo { kind, data } = result.unwrap(); + assert_eq!( + (kind, data), + ( + PushKind::Message, + vec![ + Value::BulkString("test_channel".into()), + Value::BulkString("test_message".into()), + ] + ) + ); + + if use_sharded { + // validate SPUBLISH + let result = cmd("SPUBLISH") + .arg("test_channel_?") + .arg("test_message") + .query_async(&mut publishing_con) + .await; + assert_eq!( + result, + Ok(Value::Int(2)) // 2 connections with the same pubsub config + ); + + sleep(futures_time::time::Duration::from_secs(1)).await; + let result = rx.try_recv(); + assert!(result.is_ok()); + let PushInfo { kind, data } = result.unwrap(); + assert_eq!( + (kind, data), + ( + PushKind::SMessage, + vec![ + Value::BulkString("test_channel_?".into()), + Value::BulkString("test_message".into()), + ] + ) + ); + } + + // simulate a passive disconnect + drop(cluster); + + // recreate the cluster; the assumption is that the cluster is built with exactly the same params (ports, slots map...)
+ let _cluster = + TestClusterContext::new_with_cluster_client_builder(3, 0, |builder| builder, false); + + // sleep for 1 periodic_connections_checks + overhead + sleep(futures_time::time::Duration::from_secs(1 + 1)).await; + + // new subscription notifications due to resubscriptions + validate_subscriptions(&client_subscriptions, &mut rx, true); + + // validate PUBLISH + let result = cmd("PUBLISH") + .arg("test_channel") + .arg("test_message") + .query_async(&mut publishing_con) + .await; + assert_eq!( + result, + Ok(Value::Int(2)) // 2 connections with the same pubsub config + ); + + sleep(futures_time::time::Duration::from_secs(1)).await; + let result = rx.try_recv(); + assert!(result.is_ok()); + let PushInfo { kind, data } = result.unwrap(); + assert_eq!( + (kind, data), + ( + PushKind::Message, + vec![ + Value::BulkString("test_channel".into()), + Value::BulkString("test_message".into()), + ] + ) + ); + + if use_sharded { + // validate SPUBLISH + let result = cmd("SPUBLISH") + .arg("test_channel_?") + .arg("test_message") + .query_async(&mut publishing_con) + .await; + assert_eq!( + result, + Ok(Value::Int(2)) // 2 connections with the same pubsub config + ); + + sleep(futures_time::time::Duration::from_secs(1)).await; + let result = rx.try_recv(); + assert!(result.is_ok()); + let PushInfo { kind, data } = result.unwrap(); + assert_eq!( + (kind, data), + ( + PushKind::SMessage, + vec![ + Value::BulkString("test_channel_?".into()), + Value::BulkString("test_message".into()), + ] + ) + ); + } + + Ok(()) + }) + .unwrap(); } #[test] - fn test_async_cluster_restore_resp3_pubsub_state_after_scale_in() { - - // let client_subscriptions = PubSubSubscriptionInfo::from( - // [ - // (PubSubSubscriptionKind::Exact, HashSet::from( - // [ - // // test_channel_? 
-        //                 // (assuming slots allocation is monotonicaly increasing starting from node 0)
-        //                 PubSubChannelOrPattern::from(b"test_channel_?")
-        //             ])
-        //         )
-        //     ]
-        // );
-
-        // let cluster = TestClusterContext::new_with_cluster_client_builder(
-        //     6,
-        //     0,
-        //     |builder| builder
-        //         .retries(3)
-        //         .use_protocol(ProtocolVersion::RESP3)
-        //         .pubsub_subscriptions(client_subscriptions.clone()),
-        //     false,
-        // );
-
-        // block_on_all(async move {
-        //     let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::<PushInfo>();
-        //     let mut connection = cluster.async_connection(Some(tx.clone())).await;
-
-        //     // short sleep to allow the server to push subscription notification
-        //     sleep(futures_time::time::Duration::from_secs(1)).await;
-        //     let result = rx.try_recv();
-        //     assert!(result.is_ok());
-        //     let PushInfo { kind, data } = result.unwrap();
-        //     assert_eq!(
-        //         (kind, data),
-        //         (
-        //             PushKind::Subscribe,
-        //             vec![
-        //                 Value::BulkString("test_channel_?".into()),
-        //                 Value::Int(1),
-        //             ]
-        //         )
-        //     );
-
-        //     let slot_14212 = get_slot(b"test_channel_?");
-        //     assert_eq!(slot_14212, 14212);
-        //     let slot_14212_route = redis::cluster_routing::Route::new(slot_14212, redis::cluster_routing::SlotAddr::Master);
-        //     let node_5_route = redis::cluster_routing::SingleNodeRoutingInfo::SpecificNode(slot_14212_route);
-
-        //     let result = connection
-        //         //.route_command(&redis::Cmd::new().arg("PUBLISH").arg("test_channel_?").arg("test_msg"), RoutingInfo::SingleNode(node_5_route.clone()))
-        //         .route_command(&redis::Cmd::new().arg("PING"), RoutingInfo::SingleNode(node_5_route.clone()))
-        //         .await;
-        //     // let slot_0_route = redis::cluster_routing::Route::new(0, redis::cluster_routing::SlotAddr::Master);
-        //     // let node_0_route = redis::cluster_routing::SingleNodeRoutingInfo::SpecificNode(slot_0_route);
-
-        //     let result = cmd("PUBLISH")
-        //         .arg("test_channel_?")
-        //         .arg("test_message")
-        //         .query_async(&mut connection)
-        //         .await;
-        //     assert_eq!(
-        //         result,
-        //         Ok(Value::Int(1))
-        //     );
-
-        //     sleep(futures_time::time::Duration::from_secs(1)).await;
-        //     let result = rx.try_recv();
-        //     assert!(result.is_ok());
-        //     let PushInfo { kind, data } = result.unwrap();
-        //     assert_eq!(
-        //         (kind, data),
-        //         (
-        //             PushKind::Message,
-        //             vec![
-        //                 Value::BulkString("test_channel_?".into()),
-        //                 Value::BulkString(format!("test_message").into()),
-        //             ]
-        //         )
-        //     );
-
-        //     // simulate scale in
-        //     drop(cluster);
-        //     println!("*********** DROPPED **********");
-        //     let cluster = TestClusterContext::new_with_cluster_client_builder(
-        //         3,
-        //         0,
-        //         |builder| builder
-        //             .retries(6)
-        //             .use_protocol(ProtocolVersion::RESP3)
-        //             .pubsub_subscriptions(client_subscriptions.clone()),
-        //         false,
-        //     );
-
-        //     sleep(futures_time::time::Duration::from_secs(3)).await;
-
-        //     //ensure subscription notification due to resubscription
-        //     // let result = cmd("PUBLISH")
-        //     //     .arg("test_channel_?")
-        //     //     .arg("test_message")
-        //     //     .query_async(&mut connection)
-        //     //     .await;
-        //     let result = connection
-        //         //.route_command(&redis::Cmd::new().arg("PUBLISH").arg("test_channel_?").arg("test_msg"), RoutingInfo::SingleNode(node_5_route.clone()))
-        //         .route_command(&redis::Cmd::new().arg("PING"), RoutingInfo::SingleNode(node_5_route.clone()))
-        //         .await;
-        //     // assert_eq!(
-        //     //     result,
-        //     //     Ok(Value::Int(1))
-        //     // );
-
-        //     let slot_14212 = get_slot(b"test_channel_?");
-        //     assert_eq!(slot_14212, 14212);
-        //     let slot_14212_route = redis::cluster_routing::Route::new(slot_14212, redis::cluster_routing::SlotAddr::Master);
-        //     let node_2_route = redis::cluster_routing::SingleNodeRoutingInfo::SpecificNode(slot_14212_route);
-        //     let result = connection
-        //         .route_command(&redis::Cmd::new().arg("PUBLISH").arg("test_channel_?").arg("test_message"), RoutingInfo::SingleNode(node_2_route.clone()))
-        //         .await;
-        //     assert_eq!(
-        //         result,
-        //         Ok(Value::Int(1))
-        //     );
-
-        //     sleep(futures_time::time::Duration::from_secs(1)).await;
-        //     let result = rx.try_recv();
-        //     assert!(result.is_ok());
-        //     let PushInfo { kind, data } = result.unwrap();
-        //     assert_eq!(
-        //         (kind, data),
-        //         (
-        //             PushKind::Subscribe,
-        //             vec![
-        //                 Value::BulkString("test_channel_?".into()),
-        //                 Value::Int(1),
-        //             ]
-        //         )
-        //     );
-
-        //     let result = rx.try_recv();
-        //     assert!(result.is_ok());
-        //     let PushInfo { kind, data } = result.unwrap();
-        //     assert_eq!(
-        //         (kind, data),
-        //         (
-        //             PushKind::Disconnection,
-        //             vec![],
-        //         )
-        //     );
-
-        //     return Ok(());
-
-        //     // Subscribe on the slot 14212, this slot will reside on the last node in both 3 and 6 nodes cluster,
-        //     // When the cluster is recreated with 3 nodes, this slot will reside on different network address.
-        //     // Assuming the implementation of TestCluster assigns the slots monotonicaly incerasing with the nodes
-        //     let slot_14212 = get_slot(b"test_channel_?");
-        //     assert_eq!(slot_14212, 14212);
-        //     let slot_14212_route = redis::cluster_routing::Route::new(slot_14212, redis::cluster_routing::SlotAddr::Master);
-        //     let node_5_route = redis::cluster_routing::SingleNodeRoutingInfo::SpecificNode(slot_14212_route);
-
-        //     let slot_0_route = redis::cluster_routing::Route::new(0, redis::cluster_routing::SlotAddr::Master);
-        //     let node_0_route = redis::cluster_routing::SingleNodeRoutingInfo::SpecificNode(slot_0_route);
-
-        //     let result = connection
-        //         .route_command(&redis::Cmd::new().arg("SUBSCRIBE").arg("test_channel_?"), RoutingInfo::SingleNode(node_5_route.clone()))
-        //         //.route_command(&redis::Cmd::new().arg("SUBSCRIBE").arg("test_channel"), RoutingInfo::SingleNode(SingleNodeRoutingInfo::Random))
-        //         .await;
-
-        //     assert_eq!(
-        //         result,
-        //         Ok(Value::Push {
-        //             kind: PushKind::Subscribe,
-        //             data: vec![Value::BulkString("test_channel_?".into()), Value::Int(1)],
-        //         })
-        //     );
-
-        //     // pull out all the subscribe notification, this push notification is due to the previous subscribe command
-        //     let result = rx.recv().await;
-        //     assert!(result.is_some());
-        //     let PushInfo { kind, data } = result.unwrap();
-        //     assert_eq!(
-        //         (kind, data),
-        //         (
-        //             PushKind::Subscribe,
-        //             vec![
-        //                 Value::BulkString("test_channel_?".as_bytes().to_vec()),
-        //                 Value::Int(1),
-        //             ]
-        //         )
-        //     );
-
-        //     // ensure subscription, routing on the last node, expected return Int(1)
-        //     let result = connection
-        //         .route_command(&redis::Cmd::new().arg("PUBLISH").arg("test_channel_?").arg("test_message_from_node_5"), RoutingInfo::SingleNode(node_5_route.clone()))
-        //         .await;
-        //     assert_eq!(
-        //         result,
-        //         Ok(Value::Int(1))
-        //     );
-
-        //     // ensure subscription, routing on the first node, expected return Int(0)
-        //     let result = connection
-        //         .route_command(&redis::Cmd::new().arg("PUBLISH").arg("test_channel_?").arg("test_message_from_node_0"), RoutingInfo::SingleNode(node_0_route.clone()))
-        //         .await;
-        //     assert_eq!(
-        //         result,
-        //         Ok(Value::Int(0))
-        //     );
-
-        //     for i in vec![5, 0] {
-        //         let result = rx.recv().await;
-        //         assert!(result.is_some());
-        //         let PushInfo { kind, data } = result.unwrap();
-        //         println!("^^^^^^^^^ '{:?} -> {:?}'", kind, data);
-        //         assert_eq!(
-        //             (kind, data),
-        //             (
-        //                 PushKind::Message,
-        //                 vec![
-        //                     Value::BulkString("test_channel_?".into()),
-        //                     Value::BulkString(format!("test_message_from_node_{}", i).into()),
-        //                 ]
-        //             )
-        //         );
-        //     }
-
-        //     // drop and recreate cluster and connections
-        //     drop(cluster);
-        //     println!("*********** DROPPED **********");
-
-        //     let cluster = TestClusterContext::new_with_cluster_client_builder(
-        //         3,
-        //         0,
-        //         |builder| builder.retries(3).use_protocol(ProtocolVersion::RESP3),
-        //         //|builder| builder.retries(3),
-        //         false,
-        //     );
-
-        //     // ensure subscription state restore
-        //     let node_2_route = redis::cluster_routing::SingleNodeRoutingInfo::SpecificNode(slot_14212_route);
-        //     let result = connection
-        //         .route_command(&redis::Cmd::new().arg("PUBLISH").arg("test_channel_?").arg("test_message_from_node_2"), RoutingInfo::SingleNode(node_2_route.clone()))
-        //         .await;
-        //     assert_eq!(
-        //         result,
-        //         Ok(Value::Int(1))
-        //     );
-
-        //     // non-subscribed channel
-        //     let result = connection
-        //         .route_command(&redis::Cmd::new().arg("PUBLISH").arg("test_channel_1").arg("should_not_receive"), RoutingInfo::SingleNode(node_0_route.clone()))
-        //         .await;
-        //     assert_eq!(
-        //         result,
-        //         Ok(Value::Int(0))
-        //     );
-
-        //     // ensure subscription, routing on different node, expected return Int(0)
-        //     let result = connection
-        //         .route_command(&redis::Cmd::new().arg("PUBLISH").arg("test_channel_?").arg("test_message_from_node_2"), RoutingInfo::SingleNode(node_0_route.clone()))
-        //         .await;
-        //     assert_eq!(
-        //         result,
-        //         Ok(Value::Int(0))
-        //     );
-
-        //     // should produce an arbitrary number of 'disconnected' notifications - 1 for the intitial try after the drop and an unknown? amout during reconnecting procedure
-        //     // Notifications become available ONLY after we try to send the commands, since push manager does not register TCP disconnect on a idle socket
-        //     // Remove the any amount of 'disconnected' notifications
-        //     sleep(futures_time::time::Duration::from_secs(1)).await;
-        //     //let mut result = rx.recv().await;
-        //     let mut result = rx.try_recv();
-        //     assert!(result.is_ok());
-        //     //assert!(result.is_some());
-        //     loop {
-        //         let kind = result.clone().unwrap().kind;
-        //         if kind != PushKind::Disconnection && kind != PushKind::Subscribe {
-        //             break;
-        //         }
-        //         // result = rx.recv().await;
-        //         // assert!(result.is_some());
-        //         result = rx.try_recv();
-        //         assert!(result.is_ok());
-        //     }
-
-        //     // ensure messages test_message_from_node_0 and test_message_from_node_2
-        //     let mut msg_from_0 = false;
-        //     let mut msg_from_2 = false;
-        //     while !msg_from_0 && !msg_from_2 {
-        //         let mut result = rx.recv().await;
-        //         assert!(result.is_some());
-        //         let PushInfo { kind, data } = result.unwrap();
-
-        //         assert!(kind == PushKind::Disconnection || kind == PushKind::Subscribe || kind == PushKind::Message);
-        //         if kind == PushKind::Disconnection || kind == PushKind::Subscribe {
-        //             // ignore
-        //             continue;
-        //         }
-
-        //         if data == vec![
-        //             Value::BulkString("test_channel".into()),
-        //             Value::BulkString("test_message_from_node_0".into())] {
-        //             assert!(!msg_from_0);
-        //             msg_from_0 = true;
-        //         }
-        //         else if data == vec![
-        //             Value::BulkString("test_channel".into()),
-        //             Value::BulkString("test_message_from_node_2".into())] {
-        //             assert!(!msg_from_2);
-        //             msg_from_2 = true;
-        //         }
-        //         else {
-        //             assert!(false, "Unexpected message received");
-        //         }
-        //     }
-
-        //     Ok(())
-        // })
-        // .unwrap();
+    fn test_async_cluster_restore_resp3_pubsub_state_after_scale_out() {
+        let redis_ver = std::env::var("REDIS_VERSION").unwrap_or_default();
+        let use_sharded = redis_ver.starts_with("7.");
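+        // sharded pub/sub (SSUBSCRIBE / SPUBLISH) is only available on Redis 7.0 and newer,
+        // hence the version gate above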
+
+        let mut client_subscriptions = PubSubSubscriptionInfo::from([
+            // test_channel_? is used as it maps to slot 14212, which resides on the last node in both the 3 and 6 node configs
+            // (assuming slot allocation is monotonically increasing starting from node 0)
+            (
+                PubSubSubscriptionKind::Exact,
+                HashSet::from([PubSubChannelOrPattern::from("test_channel_?".as_bytes())]),
+            ),
+        ]);
+
+        if use_sharded {
+            client_subscriptions.insert(
+                PubSubSubscriptionKind::Sharded,
+                HashSet::from([PubSubChannelOrPattern::from("test_channel_?".as_bytes())]),
+            );
+        }
+
+        let slot_14212 = get_slot(b"test_channel_?");
+        assert_eq!(slot_14212, 14212);
+
+        let cluster = TestClusterContext::new_with_cluster_client_builder(
+            3,
+            0,
+            |builder| {
+                builder
+                    .retries(3)
+                    .use_protocol(ProtocolVersion::RESP3)
+                    .pubsub_subscriptions(client_subscriptions.clone())
+                    // the periodic connection check is required to detect the disconnect from the last node
+                    .periodic_connections_checks(Duration::from_secs(1))
+                    // the periodic topology check is required to detect the topology change
+                    .periodic_topology_checks(Duration::from_secs(1))
+                    .slots_refresh_rate_limit(Duration::from_secs(0), 0)
+            },
+            false,
+        );
+
+        block_on_all(async move {
+            let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::<PushInfo>();
+            let mut _listening_con = cluster.async_connection(Some(tx.clone())).await;
+            // Note: the publishing connection has the same pubsub config
+            let mut publishing_con = cluster.async_connection(None).await;
+
+            // short sleep to allow the server to push subscription notification
+            sleep(futures_time::time::Duration::from_secs(1)).await;
+
+            // validate subscriptions
+            validate_subscriptions(&client_subscriptions, &mut rx, false);
+
+            // validate PUBLISH
+            let result = cmd("PUBLISH")
+                .arg("test_channel_?")
+                .arg("test_message")
+                .query_async(&mut publishing_con)
+                .await;
+            assert_eq!(
+                result,
+                Ok(Value::Int(2)) // 2 connections with the same pubsub config
+            );
+
+            sleep(futures_time::time::Duration::from_secs(1)).await;
+            let result = rx.try_recv();
+            assert!(result.is_ok());
+            let PushInfo { kind, data } = result.unwrap();
+            assert_eq!(
+                (kind, data),
+                (
+                    PushKind::Message,
+                    vec![
+                        Value::BulkString("test_channel_?".into()),
+                        Value::BulkString("test_message".into()),
+                    ]
+                )
+            );
+
+            if use_sharded {
+                // validate SPUBLISH
+                let result = cmd("SPUBLISH")
+                    .arg("test_channel_?")
+                    .arg("test_message")
+                    .query_async(&mut publishing_con)
+                    .await;
+                assert_eq!(
+                    result,
+                    Ok(Value::Int(2)) // 2 connections with the same pubsub config
+                );
+
+                sleep(futures_time::time::Duration::from_secs(1)).await;
+                let result = rx.try_recv();
+                assert!(result.is_ok());
+                let PushInfo { kind, data } = result.unwrap();
+                assert_eq!(
+                    (kind, data),
+                    (
+                        PushKind::SMessage,
+                        vec![
+                            Value::BulkString("test_channel_?".into()),
+                            Value::BulkString("test_message".into()),
+                        ]
+                    )
+                );
+            }
+
+            // drop and recreate a cluster with more nodes
+            drop(cluster);
+
+            // recreate the cluster; the assumption is that it is built with exactly the same params (ports, slots map...)
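+            // (scaling out to 6 nodes should move slot 14212 onto a node that was not part of the
+            // original topology, forcing the client to rediscover the topology and resubscribe there)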
+            let cluster =
+                TestClusterContext::new_with_cluster_client_builder(6, 0, |builder| builder, false);
+
+            // assume slot 14212 will reside on the last node
+            let last_server_port = {
+                let addr = cluster.cluster.servers.last().unwrap().addr.clone();
+                match addr {
+                    redis::ConnectionAddr::TcpTls {
+                        host: _,
+                        port,
+                        insecure: _,
+                        tls_params: _,
+                    } => port,
+                    redis::ConnectionAddr::Tcp(_, port) => port,
+                    _ => {
+                        panic!("Wrong server address type: {:?}", addr);
+                    }
+                }
+            };
+
+            // wait for new topology discovery
+            loop {
+                let mut cmd = redis::cmd("INFO");
+                cmd.arg("SERVER");
+                let res = publishing_con
+                    .route_command(
+                        &cmd,
+                        RoutingInfo::SingleNode(SingleNodeRoutingInfo::SpecificNode(Route::new(
+                            slot_14212,
+                            SlotAddr::Master,
+                        ))),
+                    )
+                    .await;
+                assert!(res.is_ok());
+                let res = res.unwrap();
+                match res {
+                    Value::VerbatimString { format: _, text } => {
+                        if text.contains(format!("tcp_port:{}", last_server_port).as_str()) {
+                            // new topology rediscovered
+                            break;
+                        }
+                    }
+                    _ => {
+                        panic!("Wrong return type for INFO SERVER command: {:?}", res);
+                    }
+                }
+                sleep(futures_time::time::Duration::from_secs(1)).await;
+            }
+
+            // sleep for one cycle of topology refresh
+            sleep(futures_time::time::Duration::from_secs(1)).await;
+
+            // validate PUBLISH
+            let result = redis::cmd("PUBLISH")
+                .arg("test_channel_?")
+                .arg("test_message")
+                .query_async(&mut publishing_con)
+                .await;
+            assert_eq!(
+                result,
+                Ok(Value::Int(2)) // 2 connections with the same pubsub config
+            );
+
+            // allow the message to propagate
+            sleep(futures_time::time::Duration::from_secs(1)).await;
+
+            loop {
+                let result = rx.try_recv();
+                assert!(result.is_ok());
+                let PushInfo { kind, data } = result.unwrap();
+                // ignore disconnection and subscription notifications caused by the resubscriptions
+                if kind == PushKind::Message {
+                    assert_eq!(
+                        data,
+                        vec![
+                            Value::BulkString("test_channel_?".into()),
+                            Value::BulkString("test_message".into()),
+                        ]
+                    );
+                    break;
+                }
+            }
+
+            if use_sharded {
+                // validate SPUBLISH
+                let result = cmd("SPUBLISH")
+                    .arg("test_channel_?")
+                    .arg("test_message")
+                    .query_async(&mut publishing_con)
+                    .await;
+                assert_eq!(
+                    result,
+                    Ok(Value::Int(2)) // 2 connections with the same pubsub config
+                );
+
+                // allow the message to propagate
+                sleep(futures_time::time::Duration::from_secs(1)).await;
+
+                let result = rx.try_recv();
+                assert!(result.is_ok());
+                let PushInfo { kind, data } = result.unwrap();
+                assert_eq!(
+                    (kind, data),
+                    (
+                        PushKind::SMessage,
+                        vec![
+                            Value::BulkString("test_channel_?".into()),
+                            Value::BulkString("test_message".into()),
+                        ]
+                    )
+                );
+            }
+
+            drop(publishing_con);
+            drop(_listening_con);
+
+            Ok(())
+        })
+        .unwrap();
+
+        block_on_all(async move {
+            sleep(futures_time::time::Duration::from_secs(10)).await;
+            Ok(())
+        })
+        .unwrap();
     }
 
-    //#[allow(unreachable_code)]
     #[test]
     fn test_async_cluster_resp3_pubsub() {
         let redis_ver = std::env::var("REDIS_VERSION").unwrap_or_default();
@@ -3268,39 +3304,10 @@ mod cluster_async {
         // short sleep to allow the server to push subscription notification
         sleep(futures_time::time::Duration::from_secs(1)).await;
 
-            let mut subscribe_cnt = client_subscriptions[&PubSubSubscriptionKind::Exact].len();
-            let mut psubscribe_cnt = client_subscriptions[&PubSubSubscriptionKind::Pattern].len();
-            let mut ssubscribe_cnt = 0;
-            if let Some(sharded_shubs) = client_subscriptions.get(&PubSubSubscriptionKind::Sharded)
-            {
-                ssubscribe_cnt += sharded_shubs.len()
-            }
-            for _ in 0..(subscribe_cnt + psubscribe_cnt + ssubscribe_cnt) {
-                let result = rx.try_recv();
-                assert!(result.is_ok());
-                let PushInfo { kind, data: _ } = result.unwrap();
-                assert!(
-                    kind == PushKind::Subscribe
-                        || kind == PushKind::PSubscribe
-                        || kind == PushKind::SSubscribe
-                );
-                if kind == PushKind::Subscribe {
-                    subscribe_cnt -= 1;
-                } else if kind == PushKind::PSubscribe {
-                    psubscribe_cnt -= 1;
-                } else {
-                    ssubscribe_cnt -= 1;
-                }
-            }
-
-            assert!(subscribe_cnt == 0);
-            assert!(psubscribe_cnt == 0);
-            assert!(ssubscribe_cnt == 0);
+            validate_subscriptions(&client_subscriptions, &mut rx, false);
 
             let slot_14212 = get_slot(b"test_channel_?");
             assert_eq!(slot_14212, 14212);
-            //let slot_14212_route = redis::cluster_routing::Route::new(slot_14212, redis::cluster_routing::SlotAddr::Master);
-            //let node_5_route = redis::cluster_routing::SingleNodeRoutingInfo::SpecificNode(slot_14212_route);
 
             let slot_0_route =
                 redis::cluster_routing::Route::new(0, redis::cluster_routing::SlotAddr::Master);
diff --git a/redis/tests/test_sentinel.rs b/redis/tests/test_sentinel.rs
index 53ff86e485..0782c8b6d8 100644
--- a/redis/tests/test_sentinel.rs
+++ b/redis/tests/test_sentinel.rs
@@ -283,7 +283,7 @@ pub mod async_tests {
             .await
             .unwrap();
         let mut replica_con = replica_client
-            .get_multiplexed_async_connection(None)
+            .get_multiplexed_async_connection(None, None)
             .await
             .unwrap();
 
@@ -316,7 +316,7 @@ pub mod async_tests {
             .await
             .unwrap();
         let mut replica_con = replica_client
-            .get_multiplexed_async_connection(None)
+            .get_multiplexed_async_connection(None, None)
             .await
             .unwrap();
 
@@ -338,12 +338,14 @@ pub mod async_tests {
         let master_client = sentinel
             .async_master_for(master_name, Some(&node_conn_info))
             .await?;
-        let mut master_con = master_client.get_multiplexed_async_connection(None).await?;
+        let mut master_con = master_client
+            .get_multiplexed_async_connection(None, None)
+            .await?;
 
         let mut replica_con = sentinel
             .async_replica_for(master_name, Some(&node_conn_info))
             .await?
-            .get_multiplexed_async_connection(None)
+            .get_multiplexed_async_connection(None, None)
             .await?;
 
         async_assert_is_connection_to_master(&mut master_con).await;
@@ -367,7 +369,9 @@ pub mod async_tests {
         let master_client = sentinel
             .async_master_for(master_name, Some(&node_conn_info))
             .await?;
-        let mut master_con = master_client.get_multiplexed_async_connection(None).await?;
+        let mut master_con = master_client
+            .get_multiplexed_async_connection(None, None)
+            .await?;
 
         async_assert_is_connection_to_master(&mut master_con).await;
@@ -408,7 +412,9 @@ pub mod async_tests {
         let master_client = sentinel
             .async_master_for(master_name, Some(&node_conn_info))
             .await?;
-        let mut master_con = master_client.get_multiplexed_async_connection(None).await?;
+        let mut master_con = master_client
+            .get_multiplexed_async_connection(None, None)
+            .await?;
 
         async_assert_is_connection_to_master(&mut master_con).await;