Skip to content

Commit 9040c07

Browse files
committed
CR changes: Add async method to DisconnectNotifier trait, styling and other cleanups
1 parent 73ff308 commit 9040c07

File tree

2 files changed

+69
-53
lines changed

2 files changed

+69
-53
lines changed

Diff for: redis/src/aio/mod.rs

+5
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ use std::net::SocketAddr;
1212
#[cfg(unix)]
1313
use std::path::Path;
1414
use std::pin::Pin;
15+
use std::time::Duration;
1516

1617
/// Enables the async_std compatibility
1718
#[cfg(feature = "async-std-comp")]
@@ -91,10 +92,14 @@ pub trait ConnectionLike {
9192
}
9293

9394
/// Implements ability to notify about disconnection events
95+
#[async_trait]
9496
pub trait DisconnectNotifier: Send + Sync {
9597
/// Notify about disconnect event
9698
fn notify_disconnect(&mut self);
9799

100+
/// Wait for disconnect event with timeout
101+
async fn wait_for_disconnect_with_timeout(&self, max_wait: &Duration);
102+
98103
/// Intended to be used with Box
99104
fn clone_box(&self) -> Box<dyn DisconnectNotifier>;
100105
}

Diff for: redis/src/cluster_async/mod.rs

+64-53
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,8 @@ use backoff_tokio::future::retry;
9797
use backoff_tokio::{Error as BackoffError, ExponentialBackoff};
9898
#[cfg(feature = "tokio-comp")]
9999
use tokio::{sync::Notify, time::timeout};
100+
#[cfg(feature = "tokio-comp")]
101+
use async_trait::async_trait;
100102

101103
use dispose::{Disposable, Dispose};
102104
use futures::{future::BoxFuture, prelude::*, ready};
@@ -379,20 +381,37 @@ where
379381
#[cfg(feature = "tokio-comp")]
380382
#[derive(Clone)]
381383
struct TokioDisconnectNotifier {
382-
pub disconnect_notifier: Arc<Notify>,
384+
disconnect_notifier: Arc<Notify>,
383385
}
384386

385387
#[cfg(feature = "tokio-comp")]
388+
#[async_trait]
386389
impl DisconnectNotifier for TokioDisconnectNotifier {
387390
fn notify_disconnect(&mut self) {
388391
self.disconnect_notifier.notify_one();
389392
}
390393

394+
async fn wait_for_disconnect_with_timeout(&self, max_wait: &Duration) {
395+
let _ = timeout(*max_wait, async {
396+
self.disconnect_notifier.notified().await;
397+
})
398+
.await;
399+
}
400+
391401
fn clone_box(&self) -> Box<dyn DisconnectNotifier> {
392402
Box::new(self.clone())
393403
}
394404
}
395405

406+
#[cfg(feature = "tokio-comp")]
407+
impl TokioDisconnectNotifier {
408+
fn new() -> TokioDisconnectNotifier {
409+
TokioDisconnectNotifier {
410+
disconnect_notifier: Arc::new(Notify::new()),
411+
}
412+
}
413+
}
414+
396415
type ConnectionMap<C> = connections_container::ConnectionsMap<ConnectionFuture<C>>;
397416
type ConnectionsContainer<C> =
398417
self::connections_container::ConnectionsContainer<ConnectionFuture<C>>;
@@ -406,8 +425,6 @@ pub(crate) struct InnerCore<C> {
406425
subscriptions_by_address: RwLock<HashMap<String, PubSubSubscriptionInfo>>,
407426
unassigned_subscriptions: RwLock<PubSubSubscriptionInfo>,
408427
glide_connection_options: GlideConnectionOptions,
409-
#[cfg(feature = "tokio-comp")]
410-
tokio_notify: Arc<Notify>,
411428
}
412429

413430
pub(crate) type Core<C> = Arc<InnerCore<C>>;
@@ -990,27 +1007,24 @@ where
9901007
cluster_params: ClusterParams,
9911008
push_sender: Option<mpsc::UnboundedSender<PushInfo>>,
9921009
) -> RedisResult<Disposable<Self>> {
993-
#[cfg(feature = "tokio-comp")]
994-
let tokio_notify = Arc::new(Notify::new());
995-
9961010
let disconnect_notifier = {
9971011
#[cfg(feature = "tokio-comp")]
9981012
{
999-
Some::<Box<dyn DisconnectNotifier>>(Box::new(TokioDisconnectNotifier {
1000-
disconnect_notifier: tokio_notify.clone(),
1001-
}))
1013+
Some::<Box<dyn DisconnectNotifier>>(Box::new(TokioDisconnectNotifier::new()))
10021014
}
10031015
#[cfg(not(feature = "tokio-comp"))]
10041016
None
10051017
};
10061018

1019+
let glide_connection_options = GlideConnectionOptions {
1020+
push_sender,
1021+
disconnect_notifier,
1022+
};
1023+
10071024
let connections = Self::create_initial_connections(
10081025
initial_nodes,
10091026
&cluster_params,
1010-
GlideConnectionOptions {
1011-
push_sender: push_sender.clone(),
1012-
disconnect_notifier: disconnect_notifier.clone(),
1013-
},
1027+
glide_connection_options.clone(),
10141028
)
10151029
.await?;
10161030

@@ -1035,12 +1049,7 @@ where
10351049
},
10361050
),
10371051
subscriptions_by_address: RwLock::new(Default::default()),
1038-
glide_connection_options: GlideConnectionOptions {
1039-
push_sender: push_sender.clone(),
1040-
disconnect_notifier: disconnect_notifier.clone(),
1041-
},
1042-
#[cfg(feature = "tokio-comp")]
1043-
tokio_notify,
1052+
glide_connection_options,
10441053
});
10451054
let mut connection = ClusterConnInner {
10461055
inner,
@@ -1227,40 +1236,40 @@ where
12271236
// In addition, the validation is done by peeking at the state of the underlying transport w/o overhead of additional commands to server.
12281237
async fn validate_all_user_connections(inner: Arc<InnerCore<C>>) {
12291238
let mut all_valid_conns = HashMap::new();
1230-
let mut all_nodes_with_slots = HashSet::new();
12311239
// prep connections and clean out these w/o assigned slots, as we might have established connections to unwanted hosts
1232-
{
1233-
let mut nodes_to_delete = Vec::new();
1234-
let connections_container = inner.conn_lock.read().await;
1235-
1236-
connections_container
1237-
.slot_map
1238-
.addresses_for_all_nodes()
1239-
.iter()
1240-
.for_each(|addr| {
1241-
all_nodes_with_slots.insert(String::from(*addr));
1242-
});
1240+
let mut nodes_to_delete = Vec::new();
1241+
let connections_container = inner.conn_lock.read().await;
12431242

1244-
connections_container
1245-
.all_node_connections()
1246-
.for_each(|(addr, con)| {
1247-
if all_nodes_with_slots.contains(&addr) {
1248-
all_valid_conns.insert(addr.clone(), con.clone());
1249-
} else {
1250-
nodes_to_delete.push(addr.clone());
1251-
}
1252-
});
1243+
let all_nodes_with_slots: HashSet<String> = connections_container
1244+
.slot_map
1245+
.addresses_for_all_nodes()
1246+
.iter()
1247+
.map(|addr| String::from(*addr))
1248+
.collect();
1249+
1250+
connections_container
1251+
.all_node_connections()
1252+
.for_each(|(addr, con)| {
1253+
if all_nodes_with_slots.contains(&addr) {
1254+
all_valid_conns.insert(addr.clone(), con.clone());
1255+
} else {
1256+
nodes_to_delete.push(addr.clone());
1257+
}
1258+
});
12531259

1254-
for addr in &nodes_to_delete {
1255-
connections_container.remove_node(addr);
1256-
}
1260+
for addr in &nodes_to_delete {
1261+
connections_container.remove_node(addr);
12571262
}
12581263

1264+
drop(connections_container);
1265+
12591266
// identify nodes with closed connection
12601267
let mut addrs_to_refresh = Vec::new();
12611268
for (addr, con_fut) in &all_valid_conns {
12621269
let con = con_fut.clone().await;
1270+
// connection object might be present despite the transport being closed
12631271
if con.is_closed() {
1272+
// transport is closed, need to refresh
12641273
addrs_to_refresh.push(addr.clone());
12651274
}
12661275
}
@@ -1289,7 +1298,7 @@ where
12891298
inner: Arc<InnerCore<C>>,
12901299
addresses: Vec<String>,
12911300
conn_type: RefreshConnectionType,
1292-
try_existing_node: bool,
1301+
check_existing_conn: bool,
12931302
) {
12941303
info!("Started refreshing connections to {:?}", addresses);
12951304
let connections_container = inner.conn_lock.read().await;
@@ -1301,10 +1310,10 @@ where
13011310
.fold(
13021311
&*connections_container,
13031312
|connections_container, address| async move {
1304-
let node_option = if try_existing_node {
1313+
let node_option = if check_existing_conn {
13051314
connections_container.remove_node(&address)
13061315
} else {
1307-
Option::None
1316+
None
13081317
};
13091318

13101319
// override subscriptions for this connection
@@ -1541,13 +1550,15 @@ where
15411550

15421551
async fn connections_validation_task(inner: Arc<InnerCore<C>>, interval_duration: Duration) {
15431552
loop {
1544-
#[cfg(feature = "tokio-comp")]
1545-
let _ = timeout(interval_duration, async {
1546-
inner.tokio_notify.notified().await;
1547-
})
1548-
.await;
1549-
#[cfg(not(feature = "tokio-comp"))]
1550-
let _ = boxed_sleep(interval_duration).await;
1553+
if let Some(disconnect_notifier) =
1554+
inner.glide_connection_options.disconnect_notifier.clone()
1555+
{
1556+
disconnect_notifier
1557+
.wait_for_disconnect_with_timeout(&interval_duration)
1558+
.await;
1559+
} else {
1560+
let _ = boxed_sleep(interval_duration).await;
1561+
}
15511562

15521563
Self::validate_all_user_connections(inner.clone()).await;
15531564
}

0 commit comments

Comments
 (0)