diff --git a/redis/src/cluster.rs b/redis/src/cluster.rs index 6f236b179..9c269f83b 100644 --- a/redis/src/cluster.rs +++ b/redis/src/cluster.rs @@ -910,6 +910,7 @@ pub(crate) fn get_connection_info( redis: RedisConnectionInfo { password: cluster_params.password, username: cluster_params.username, + use_resp3: cluster_params.use_resp3, ..Default::default() }, }) diff --git a/redis/src/cluster_client.rs b/redis/src/cluster_client.rs index e2c76e8ce..98341fb43 100644 --- a/redis/src/cluster_client.rs +++ b/redis/src/cluster_client.rs @@ -33,6 +33,7 @@ struct BuilderParams { retries_configuration: RetryParams, connection_timeout: Option, topology_checks_interval: Option, + use_resp3: bool, } #[derive(Clone)] @@ -86,6 +87,7 @@ pub(crate) struct ClusterParams { pub(crate) connection_timeout: Duration, pub(crate) topology_checks_interval: Option, pub(crate) tls_params: Option, + pub(crate) use_resp3: bool, } impl ClusterParams { @@ -109,6 +111,7 @@ impl ClusterParams { connection_timeout: value.connection_timeout.unwrap_or(Duration::MAX), topology_checks_interval: value.topology_checks_interval, tls_params, + use_resp3: value.use_resp3, }) } } @@ -315,6 +318,15 @@ impl ClusterClientBuilder { self } + /// Sets whether the new ClusterClient should connect to the servers using RESP3. + /// + /// If not set, the default is to use RESP3. + #[cfg(any(feature = "tls-native-tls", feature = "tls-rustls"))] + pub fn use_resp3(mut self, use_resp3: bool) -> ClusterClientBuilder { + self.builder_params.use_resp3 = use_resp3; + self + } + /// Use `build()`. #[deprecated(since = "0.22.0", note = "Use build()")] pub fn open(self) -> RedisResult { diff --git a/redis/tests/support/cluster.rs b/redis/tests/support/cluster.rs index edb63825d..ab76a8afd 100644 --- a/redis/tests/support/cluster.rs +++ b/redis/tests/support/cluster.rs @@ -16,6 +16,7 @@ use tempfile::TempDir; use crate::support::{build_keys_and_certs_for_tls, Module}; +use super::use_resp3; #[cfg(feature = "tls-rustls")] use super::{build_single_client, load_certs_from_file}; @@ -343,6 +344,7 @@ impl TestClusterContext { .map(RedisServer::connection_info) .collect(); let mut builder = redis::cluster::ClusterClientBuilder::new(initial_nodes.clone()); + builder = builder.use_resp3(use_resp3()); #[cfg(feature = "tls-rustls")] if mtls_enabled { diff --git a/redis/tests/support/mod.rs b/redis/tests/support/mod.rs index 9c2fdb5fb..446320721 100644 --- a/redis/tests/support/mod.rs +++ b/redis/tests/support/mod.rs @@ -20,7 +20,7 @@ use redis::{ClientTlsConfig, TlsCertificates}; use socket2::{Domain, Socket, Type}; use tempfile::TempDir; -fn use_resp3() -> bool { +pub fn use_resp3() -> bool { env::var("RESP3").unwrap_or_default() == "true" } diff --git a/redis/tests/test_cluster.rs b/redis/tests/test_cluster.rs index a02dad909..f221b183a 100644 --- a/redis/tests/test_cluster.rs +++ b/redis/tests/test_cluster.rs @@ -136,6 +136,38 @@ fn test_cluster_multi_shard_commands() { assert_eq!(res, vec!["bazz", "bar", "foo"]); } +#[test] +fn test_cluster_resp3() { + if !use_resp3() { + return; + } + let cluster = TestClusterContext::new(3, 0); + + let mut connection = cluster.connection(); + + let hello: std::collections::HashMap = + redis::cmd("HELLO").query(&mut connection).unwrap(); + assert_eq!(hello.get("proto").unwrap(), &Value::Int(3)); + + let _: () = connection.hset("hash", "foo", "baz").unwrap(); + let _: () = connection.hset("hash", "bar", "foobar").unwrap(); + let result: Value = connection.hgetall("hash").unwrap(); + + assert_eq!( + result, + Value::Map(vec![ + ( + Value::BulkString("foo".as_bytes().to_vec()), + Value::BulkString("baz".as_bytes().to_vec()) + ), + ( + Value::BulkString("bar".as_bytes().to_vec()), + Value::BulkString("foobar".as_bytes().to_vec()) + ) + ]) + ); +} + #[test] #[cfg(feature = "script")] fn test_cluster_script() { diff --git a/redis/tests/test_cluster_async.rs b/redis/tests/test_cluster_async.rs index 8c6882fd9..982ed2220 100644 --- a/redis/tests/test_cluster_async.rs +++ b/redis/tests/test_cluster_async.rs @@ -203,6 +203,42 @@ fn test_async_cluster_route_info_to_nodes() { .unwrap(); } +#[test] +fn test_cluster_resp3() { + if !use_resp3() { + return; + } + block_on_all(async move { + let cluster = TestClusterContext::new(3, 0); + + let mut connection = cluster.async_connection().await; + + let hello: HashMap = redis::cmd("HELLO") + .query_async(&mut connection) + .await + .unwrap(); + assert_eq!(hello.get("proto").unwrap(), &Value::Int(3)); + + let _: () = connection.hset("hash", "foo", "baz").await.unwrap(); + let _: () = connection.hset("hash", "bar", "foobar").await.unwrap(); + let result: Value = connection.hgetall("hash").await.unwrap(); + + assert_eq!( + result, + Value::Map(vec![ + ( + Value::BulkString("foo".as_bytes().to_vec()), + Value::BulkString("baz".as_bytes().to_vec()) + ), + ( + Value::BulkString("bar".as_bytes().to_vec()), + Value::BulkString("foobar".as_bytes().to_vec()) + ) + ]) + ); + }); +} + #[ignore] // TODO Handle pipe where the keys do not all go to the same node #[test] fn test_async_cluster_basic_pipe() { @@ -1854,21 +1890,24 @@ fn test_async_cluster_round_robin_read_from_replica() { fn get_queried_node_id_if_master(cluster_nodes_output: Value) -> Option { // Returns the node ID of the connection that was queried for CLUSTER NODES (using the 'myself' flag), if it's a master. // Otherwise, returns None. + let get_node_id = |str: &str| { + let parts: Vec<&str> = str.split('\n').collect(); + for node_entry in parts { + if node_entry.contains("myself") && node_entry.contains("master") { + let node_entry_parts: Vec<&str> = node_entry.split(' ').collect(); + let node_id = node_entry_parts[0]; + return Some(node_id.to_string()); + } + } + None + }; + match cluster_nodes_output { Value::BulkString(val) => match from_utf8(&val) { - Ok(str_res) => { - let parts: Vec<&str> = str_res.split('\n').collect(); - for node_entry in parts { - if node_entry.contains("myself") && node_entry.contains("master") { - let node_entry_parts: Vec<&str> = node_entry.split(' ').collect(); - let node_id = node_entry_parts[0]; - return Some(node_id.to_string()); - } - } - None - } + Ok(str_res) => get_node_id(str_res), Err(e) => panic!("failed to decode INFO response: {:?}", e), }, + Value::VerbatimString { format: _, text } => get_node_id(&text), _ => panic!("Recieved unexpected response: {:?}", cluster_nodes_output), } }