Skip to content

Commit

Permalink
Merge pull request #74 from nihohit/cluster-resp3
Browse files Browse the repository at this point in the history
Add RESP3 support to cluster connections.
  • Loading branch information
shachlanAmazon authored Dec 11, 2023
2 parents 270b999 + 7cb3aec commit 8b3fb0a
Show file tree
Hide file tree
Showing 6 changed files with 98 additions and 12 deletions.
1 change: 1 addition & 0 deletions redis/src/cluster.rs
Original file line number Diff line number Diff line change
Expand Up @@ -910,6 +910,7 @@ pub(crate) fn get_connection_info(
redis: RedisConnectionInfo {
password: cluster_params.password,
username: cluster_params.username,
use_resp3: cluster_params.use_resp3,
..Default::default()
},
})
Expand Down
12 changes: 12 additions & 0 deletions redis/src/cluster_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ struct BuilderParams {
retries_configuration: RetryParams,
connection_timeout: Option<Duration>,
topology_checks_interval: Option<Duration>,
use_resp3: bool,
}

#[derive(Clone)]
Expand Down Expand Up @@ -86,6 +87,7 @@ pub(crate) struct ClusterParams {
pub(crate) connection_timeout: Duration,
pub(crate) topology_checks_interval: Option<Duration>,
pub(crate) tls_params: Option<TlsConnParams>,
pub(crate) use_resp3: bool,
}

impl ClusterParams {
Expand All @@ -109,6 +111,7 @@ impl ClusterParams {
connection_timeout: value.connection_timeout.unwrap_or(Duration::MAX),
topology_checks_interval: value.topology_checks_interval,
tls_params,
use_resp3: value.use_resp3,
})
}
}
Expand Down Expand Up @@ -315,6 +318,15 @@ impl ClusterClientBuilder {
self
}

/// Sets whether the new ClusterClient should connect to the servers using RESP3.
///
/// If not set, the default is to use RESP3.
#[cfg(any(feature = "tls-native-tls", feature = "tls-rustls"))]
pub fn use_resp3(mut self, use_resp3: bool) -> ClusterClientBuilder {
self.builder_params.use_resp3 = use_resp3;
self
}

/// Use `build()`.
#[deprecated(since = "0.22.0", note = "Use build()")]
pub fn open(self) -> RedisResult<ClusterClient> {
Expand Down
2 changes: 2 additions & 0 deletions redis/tests/support/cluster.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ use tempfile::TempDir;

use crate::support::{build_keys_and_certs_for_tls, Module};

use super::use_resp3;
#[cfg(feature = "tls-rustls")]
use super::{build_single_client, load_certs_from_file};

Expand Down Expand Up @@ -343,6 +344,7 @@ impl TestClusterContext {
.map(RedisServer::connection_info)
.collect();
let mut builder = redis::cluster::ClusterClientBuilder::new(initial_nodes.clone());
builder = builder.use_resp3(use_resp3());

#[cfg(feature = "tls-rustls")]
if mtls_enabled {
Expand Down
2 changes: 1 addition & 1 deletion redis/tests/support/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use redis::{ClientTlsConfig, TlsCertificates};
use socket2::{Domain, Socket, Type};
use tempfile::TempDir;

fn use_resp3() -> bool {
pub fn use_resp3() -> bool {
env::var("RESP3").unwrap_or_default() == "true"
}

Expand Down
32 changes: 32 additions & 0 deletions redis/tests/test_cluster.rs
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,38 @@ fn test_cluster_multi_shard_commands() {
assert_eq!(res, vec!["bazz", "bar", "foo"]);
}

#[test]
fn test_cluster_resp3() {
if !use_resp3() {
return;
}
let cluster = TestClusterContext::new(3, 0);

let mut connection = cluster.connection();

let hello: std::collections::HashMap<String, Value> =
redis::cmd("HELLO").query(&mut connection).unwrap();
assert_eq!(hello.get("proto").unwrap(), &Value::Int(3));

let _: () = connection.hset("hash", "foo", "baz").unwrap();
let _: () = connection.hset("hash", "bar", "foobar").unwrap();
let result: Value = connection.hgetall("hash").unwrap();

assert_eq!(
result,
Value::Map(vec![
(
Value::BulkString("foo".as_bytes().to_vec()),
Value::BulkString("baz".as_bytes().to_vec())
),
(
Value::BulkString("bar".as_bytes().to_vec()),
Value::BulkString("foobar".as_bytes().to_vec())
)
])
);
}

#[test]
#[cfg(feature = "script")]
fn test_cluster_script() {
Expand Down
61 changes: 50 additions & 11 deletions redis/tests/test_cluster_async.rs
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,42 @@ fn test_async_cluster_route_info_to_nodes() {
.unwrap();
}

#[test]
fn test_cluster_resp3() {
if !use_resp3() {
return;
}
block_on_all(async move {
let cluster = TestClusterContext::new(3, 0);

let mut connection = cluster.async_connection().await;

let hello: HashMap<String, Value> = redis::cmd("HELLO")
.query_async(&mut connection)
.await
.unwrap();
assert_eq!(hello.get("proto").unwrap(), &Value::Int(3));

let _: () = connection.hset("hash", "foo", "baz").await.unwrap();
let _: () = connection.hset("hash", "bar", "foobar").await.unwrap();
let result: Value = connection.hgetall("hash").await.unwrap();

assert_eq!(
result,
Value::Map(vec![
(
Value::BulkString("foo".as_bytes().to_vec()),
Value::BulkString("baz".as_bytes().to_vec())
),
(
Value::BulkString("bar".as_bytes().to_vec()),
Value::BulkString("foobar".as_bytes().to_vec())
)
])
);
});
}

#[ignore] // TODO Handle pipe where the keys do not all go to the same node
#[test]
fn test_async_cluster_basic_pipe() {
Expand Down Expand Up @@ -1854,21 +1890,24 @@ fn test_async_cluster_round_robin_read_from_replica() {
fn get_queried_node_id_if_master(cluster_nodes_output: Value) -> Option<String> {
// Returns the node ID of the connection that was queried for CLUSTER NODES (using the 'myself' flag), if it's a master.
// Otherwise, returns None.
let get_node_id = |str: &str| {
let parts: Vec<&str> = str.split('\n').collect();
for node_entry in parts {
if node_entry.contains("myself") && node_entry.contains("master") {
let node_entry_parts: Vec<&str> = node_entry.split(' ').collect();
let node_id = node_entry_parts[0];
return Some(node_id.to_string());
}
}
None
};

match cluster_nodes_output {
Value::BulkString(val) => match from_utf8(&val) {
Ok(str_res) => {
let parts: Vec<&str> = str_res.split('\n').collect();
for node_entry in parts {
if node_entry.contains("myself") && node_entry.contains("master") {
let node_entry_parts: Vec<&str> = node_entry.split(' ').collect();
let node_id = node_entry_parts[0];
return Some(node_id.to_string());
}
}
None
}
Ok(str_res) => get_node_id(str_res),
Err(e) => panic!("failed to decode INFO response: {:?}", e),
},
Value::VerbatimString { format: _, text } => get_node_id(&text),
_ => panic!("Recieved unexpected response: {:?}", cluster_nodes_output),
}
}
Expand Down

0 comments on commit 8b3fb0a

Please sign in to comment.