
Commit f62dfa3

Add shard-aware reconnection policies with support for scheduling constraints
Introduce `ShardReconnectionPolicy` and its implementations:

- `NoDelayShardReconnectionPolicy`: avoids reconnection delay and ensures at most one reconnection per host+shard.
- `NoConcurrentShardReconnectionPolicy`: limits concurrent reconnections per scope (Cluster or Host) using a backoff policy.

This feature enables finer control over shard reconnection behavior, helping prevent reconnection storms.
1 parent d5834c6 commit f62dfa3
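
As an illustrative sketch (not part of the commit itself), a cluster built against a driver that includes this change could opt into the new behavior roughly like this; the contact point and backoff values below are placeholders:

from cassandra.cluster import Cluster
from cassandra.policies import (ConstantReconnectionPolicy,
                                NoConcurrentShardReconnectionPolicy,
                                ShardReconnectionPolicyScope)

# Serialize shard reconnections per host and pace them with a constant
# 2-second backoff between attempts.
cluster = Cluster(
    contact_points=["127.0.0.1"],
    shard_reconnection_policy=NoConcurrentShardReconnectionPolicy(
        ShardReconnectionPolicyScope.Host,
        ConstantReconnectionPolicy(delay=2)),
)
session = cluster.connect()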

File tree

8 files changed (+545, -47 lines)


cassandra/cluster.py

Lines changed: 31 additions & 1 deletion
@@ -72,7 +72,8 @@
                                 ExponentialReconnectionPolicy, HostDistance,
                                 RetryPolicy, IdentityTranslator, NoSpeculativeExecutionPlan,
                                 NoSpeculativeExecutionPolicy, DefaultLoadBalancingPolicy,
-                                NeverRetryPolicy)
+                                NeverRetryPolicy, ConstantReconnectionPolicy,
+                                ShardReconnectionPolicyScope, ShardReconnectionPolicy, NoConcurrentShardReconnectionPolicy)
 from cassandra.pool import (Host, _ReconnectionHandler, _HostReconnectionHandler,
                             HostConnectionPool, HostConnection,
                             NoConnectionsAvailable)
@@ -742,6 +743,19 @@ def auth_provider(self, value):

         self._auth_provider = value

+    _shard_reconnection_policy = None
+    @property
+    def shard_reconnection_policy(self):
+        return self._shard_reconnection_policy
+
+    @shard_reconnection_policy.setter
+    def shard_reconnection_policy(self, srp):
+        if self._config_mode == _ConfigMode.PROFILES:
+            raise ValueError(
+                "Cannot set Cluster.shard_reconnection_policy while using Configuration Profiles. Set this in a profile instead.")
+        self._shard_reconnection_policy = srp
+        self._config_mode = _ConfigMode.LEGACY
+
     _load_balancing_policy = None
     @property
     def load_balancing_policy(self):
@@ -1204,6 +1218,7 @@ def __init__(self,
                 shard_aware_options=None,
                 metadata_request_timeout=None,
                 column_encryption_policy=None,
+                shard_reconnection_policy=None,
                 ):
        """
        ``executor_threads`` defines the number of threads in a pool for handling asynchronous tasks such as
@@ -1309,6 +1324,17 @@ def __init__(self,
         else:
             self._load_balancing_policy = default_lbp_factory() # set internal attribute to avoid committing to legacy config mode

+        if shard_reconnection_policy is not None:
+            if isinstance(shard_reconnection_policy, type):
+                raise TypeError("shard_reconnection_policy should not be a class, it should be an instance of that class")
+            if not isinstance(shard_reconnection_policy, ShardReconnectionPolicy):
+                raise TypeError("shard_reconnection_policy should be an instance of a class derived from ShardReconnectionPolicy")
+            self.shard_reconnection_policy = shard_reconnection_policy
+        else:
+            self._shard_reconnection_policy = NoConcurrentShardReconnectionPolicy(
+                ShardReconnectionPolicyScope.Host,
+                ConstantReconnectionPolicy(2, 0))
+
         if reconnection_policy is not None:
             if isinstance(reconnection_policy, type):
                 raise TypeError("reconnection_policy should not be a class, it should be an instance of that class")
@@ -2707,6 +2733,7 @@ def __init__(self, cluster, hosts, keyspace=None):
         self._protocol_version = self.cluster.protocol_version

         self.encoder = Encoder()
+        self.shard_reconnection_scheduler = cluster.shard_reconnection_policy.new_scheduler(self)

         # create connection pools in parallel
         self._initial_connect_futures = set()
@@ -4432,6 +4459,9 @@ def shutdown(self):
         self._queue.put_nowait((0, 0, None))
         self.join()

+    def empty(self):
+        return len(self._scheduled_tasks) == 0 and self._queue.empty()
+
     def schedule(self, delay, fn, *args, **kwargs):
         self._insert_task(delay, (fn, args, tuple(kwargs.items())))

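The new `shard_reconnection_policy` property follows the same legacy-vs-profiles rules as the other policy attributes above. A minimal, hypothetical sketch of that guard (not part of the commit):

from cassandra.cluster import Cluster, ExecutionProfile, EXEC_PROFILE_DEFAULT
from cassandra.policies import NoDelayShardReconnectionPolicy

# A cluster configured through execution profiles is in PROFILES config mode,
# so assigning the legacy attribute afterwards raises ValueError.
cluster = Cluster(execution_profiles={EXEC_PROFILE_DEFAULT: ExecutionProfile()})
try:
    cluster.shard_reconnection_policy = NoDelayShardReconnectionPolicy()
except ValueError as exc:
    print(exc)
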
cassandra/policies.py

Lines changed: 178 additions & 0 deletions
@@ -12,8 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import random
+import threading
+import time
+import weakref

 from collections import namedtuple
+from enum import Enum
 from functools import lru_cache
 from itertools import islice, cycle, groupby, repeat
 import logging
@@ -778,6 +782,14 @@ def new_schedule(self):
         raise NotImplementedError()


+class NoDelayReconnectionPolicy(ReconnectionPolicy):
+    """
+    A :class:`.ReconnectionPolicy` subclass which does not sleep.
+    """
+    def new_schedule(self):
+        return repeat(0)
+
+
 class ConstantReconnectionPolicy(ReconnectionPolicy):
     """
     A :class:`.ReconnectionPolicy` subclass which sleeps for a fixed delay
@@ -864,6 +876,172 @@ def _add_jitter(self, value):
         return min(max(self.base_delay, delay), self.max_delay)


+class _ShardReconnectionScheduler(object):
+    def schedule(self, host_id, shard_id, method, *args, **kwargs):
+        raise NotImplementedError()
+
+class ShardReconnectionPolicy(object):
+    """
+    Base class for shard reconnection policies.
+
+    `new_scheduler` returns a scheduler that behaves according to the policy.
+    """
+    def new_scheduler(self, session) -> _ShardReconnectionScheduler:
+        raise NotImplementedError()
+
+
+class NoDelayShardReconnectionPolicy(ShardReconnectionPolicy):
+    """
+    A shard reconnection policy that applies no delay.
+    It does not allow scheduling multiple reconnections for the same host+shard and silently ignores attempts to do so.
+
+    `new_scheduler` returns a scheduler that behaves according to the policy.
+    """
+    def new_scheduler(self, session) -> _ShardReconnectionScheduler:
+        return _NoDelayShardReconnectionScheduler(session)
+
+
+class _NoDelayShardReconnectionScheduler(_ShardReconnectionScheduler):
+    def __init__(self, session):
+        self.session = weakref.proxy(session)
+        self.already_scheduled = {}
+
+    def _execute(self, scheduled_key, method, *args, **kwargs):
+        try:
+            method(*args, **kwargs)
+        finally:
+            self.already_scheduled[scheduled_key] = False
+
+    def schedule(self, host_id, shard_id, method, *args, **kwargs):
+        scheduled_key = f'{host_id}-{shard_id}'
+        if self.already_scheduled.get(scheduled_key):
+            return
+
+        self.already_scheduled[scheduled_key] = True
+        if not self.session.is_shutdown:
+            self.session.submit(self._execute, scheduled_key, method, *args, **kwargs)
+
+
+class ShardReconnectionPolicyScope(Enum):
+    """
+    A scope for `ShardReconnectionPolicy`, in particular `NoConcurrentShardReconnectionPolicy`.
+    """
+    Cluster = 0
+    Host = 1
+
+
+class NoConcurrentShardReconnectionPolicy(ShardReconnectionPolicy):
+    """
+    A shard reconnection policy that allows only one pending reconnection per scope, where the scope can be `Host` or `Cluster`.
+    Backoff between reconnections is taken from the given `ReconnectionPolicy`; once there are no more reconnections to schedule, the backoff schedule is reset.
+    For all scopes it does not allow scheduling multiple reconnections for the same host+shard and silently ignores attempts to do so.
+
+    `new_scheduler` returns a scheduler that behaves according to the policy.
+    """
+    def __init__(self, shard_reconnection_scope, reconnection_policy):
+        if not isinstance(shard_reconnection_scope, ShardReconnectionPolicyScope):
+            raise ValueError("shard_reconnection_scope must be a ShardReconnectionPolicyScope")
+        if not isinstance(reconnection_policy, ReconnectionPolicy):
+            raise ValueError("reconnection_policy must be a ReconnectionPolicy")
+        self.shard_reconnection_scope = shard_reconnection_scope
+        self.reconnection_policy = reconnection_policy
+
+    def new_scheduler(self, session) -> _ShardReconnectionScheduler:
+        return _NoConcurrentShardReconnectionScheduler(session, self.shard_reconnection_scope, self.reconnection_policy)
+
+
+class _ScopeBucket(object):
+    """
+    Holds information for a shard reconnection scope, schedules and executes reconnections.
+    """
+    def __init__(self, session, reconnection_policy):
+        self._items = []
+        self.session = session
+        self.reconnection_policy = reconnection_policy
+        self.lock = threading.Lock()
+        self.running = False
+        self.schedule = self.reconnection_policy.new_schedule()
+
+    def _get_delay(self):
+        if self.schedule is None:
+            self.schedule = self.reconnection_policy.new_schedule()
+        try:
+            return next(self.schedule)
+        except StopIteration:
+            self.schedule = self.reconnection_policy.new_schedule()
+            return next(self.schedule)
+
+    def _schedule(self):
+        if self.session.is_shutdown:
+            return
+        delay = self._get_delay()
+        if delay:
+            self.session.cluster.scheduler.schedule(delay, self._run)
+        else:
+            self.session.submit(self._run)
+
+    def _run(self):
+        if self.session.is_shutdown:
+            return
+
+        with self.lock:
+            try:
+                item = self._items.pop()
+            except IndexError:
+                self.running = False
+                self.schedule = None
+                return
+
+        method, args, kwargs = item
+        try:
+            method(*args, **kwargs)
+        finally:
+            self._schedule()
+
+    def add(self, method, *args, **kwargs):
+        with self.lock:
+            self._items.append([method, args, kwargs])
+            if not self.running:
+                self.running = True
+                self._schedule()
+
+
+class _NoConcurrentShardReconnectionScheduler(_ShardReconnectionScheduler):
+    def __init__(self, session, shard_reconnection_scope, reconnection_policy):
+        self.already_scheduled = {}
+        self.scopes = {}
+        self.shard_reconnection_scope = shard_reconnection_scope
+        self.reconnection_policy = reconnection_policy
+        self.session = session
+        self.lock = threading.Lock()
+
+    def _execute(self, scheduled_key, method, *args, **kwargs):
+        try:
+            method(*args, **kwargs)
+        finally:
+            with self.lock:
+                self.already_scheduled[scheduled_key] = False
+
+    def schedule(self, host_id, shard_id, method, *args, **kwargs):
+        if self.shard_reconnection_scope == ShardReconnectionPolicyScope.Cluster:
+            scope_hash = "global-cluster-scope"
+        else:
+            scope_hash = host_id
+        scheduled_key = f'{host_id}-{shard_id}'
+
+        with self.lock:
+            if self.already_scheduled.get(scheduled_key):
+                return False
+            self.already_scheduled[scheduled_key] = True
+
+            scope_info = self.scopes.get(scope_hash, 0)
+            if not scope_info:
+                scope_info = _ScopeBucket(self.session, self.reconnection_policy)
+                self.scopes[scope_hash] = scope_info
+            scope_info.add(self._execute, scheduled_key, method, *args, **kwargs)
+            return True
+
+
 class RetryPolicy(object):
     """
     A policy that describes whether to retry, rethrow, or ignore coordinator

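The `_ScopeBucket._get_delay` logic above reuses the existing `ReconnectionPolicy.new_schedule` iterators: a schedule is consumed one delay at a time and recreated once it is exhausted or reset. A standalone sketch of that behavior (illustrative only; actual values depend on the chosen policy):

from cassandra.policies import ExponentialReconnectionPolicy

policy = ExponentialReconnectionPolicy(base_delay=1, max_delay=8)
schedule = policy.new_schedule()

def next_delay():
    # Mirrors _ScopeBucket._get_delay: rebuild the schedule when it was reset
    # (None) or exhausted, then take the next backoff value.
    global schedule
    if schedule is None:
        schedule = policy.new_schedule()
    try:
        return next(schedule)
    except StopIteration:
        schedule = policy.new_schedule()
        return next(schedule)

print([next_delay() for _ in range(4)])  # increasing delays capped at max_delay (jitter may apply)
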
cassandra/pool.py

Lines changed: 12 additions & 20 deletions
@@ -402,7 +402,6 @@ def __init__(self, host, host_distance, session):
         # this is used in conjunction with the connection streams. Not using the connection lock because the connection can be replaced in the lifetime of the pool.
         self._stream_available_condition = Condition(Lock())
         self._is_replacing = False
-        self._connecting = set()
         self._connections = {}
         self._pending_connections = []
         # A pool of additional connections which are not used but affect how Scylla
@@ -418,7 +417,6 @@ def __init__(self, host, host_distance, session):
         # and are waiting until all requests time out or complete
         # so that we can dispose of them.
         self._trash = set()
-        self._shard_connections_futures = []
         self.advanced_shardaware_block_until = 0

         if host_distance == HostDistance.IGNORED:
@@ -483,25 +481,25 @@ def _get_connection_for_routing_key(self, routing_key=None, keyspace=None, table
                     self.host,
                     routing_key
                 )
-                if conn.orphaned_threshold_reached and shard_id not in self._connecting:
+                if conn.orphaned_threshold_reached:
                     # The connection has met its orphaned stream ID limit
                     # and needs to be replaced. Start opening a connection
                     # to the same shard and replace when it is opened.
-                    self._connecting.add(shard_id)
-                    self._session.submit(self._open_connection_to_missing_shard, shard_id)
+                    self._session.shard_reconnection_scheduler.schedule(
+                        self.host.host_id, shard_id, self._open_connection_to_missing_shard, shard_id)
                     log.debug(
-                        "Connection to shard_id=%i reached orphaned stream limit, replacing on host %s (%s/%i)",
+                        "Connection to shard_id=%i reached orphaned stream limit, scheduling replacement on host %s (%s/%i)",
                         shard_id,
                         self.host,
                         len(self._connections.keys()),
                         self.host.sharding_info.shards_count
                     )
-                elif shard_id not in self._connecting:
+                else:
                     # rate controlled optimistic attempt to connect to a missing shard
-                    self._connecting.add(shard_id)
-                    self._session.submit(self._open_connection_to_missing_shard, shard_id)
+                    self._session.shard_reconnection_scheduler.schedule(
+                        self.host.host_id, shard_id, self._open_connection_to_missing_shard, shard_id)
                     log.debug(
-                        "Trying to connect to missing shard_id=%i on host %s (%s/%i)",
+                        "Scheduling connection to missing shard_id=%i on host %s (%s/%i)",
                         shard_id,
                         self.host,
                         len(self._connections.keys()),
@@ -609,8 +607,8 @@ def _replace(self, connection):
             if connection.features.shard_id in self._connections.keys():
                 del self._connections[connection.features.shard_id]
             if self.host.sharding_info and not self._session.cluster.shard_aware_options.disable:
-                self._connecting.add(connection.features.shard_id)
-                self._session.submit(self._open_connection_to_missing_shard, connection.features.shard_id)
+                self._session.shard_reconnection_scheduler.schedule(
+                    self.host.host_id, connection.features.shard_id, self._open_connection_to_missing_shard, connection.features.shard_id)
             else:
                 connection = self._session.cluster.connection_factory(self.host.endpoint,
                                                                       on_orphaned_stream_released=self.on_orphaned_stream_released)
@@ -635,9 +633,6 @@ def shutdown(self):
         with self._stream_available_condition:
             self._stream_available_condition.notify_all()

-        for future in self._shard_connections_futures:
-            future.cancel()
-
         connections_to_close = self._connections.copy()
         pending_connections_to_close = self._pending_connections.copy()
         self._connections.clear()
@@ -843,7 +838,6 @@ def _open_connection_to_missing_shard(self, shard_id):
                 self._excess_connections.add(conn)
         if close_connection:
             conn.close()
-        self._connecting.discard(shard_id)

     def _open_connections_for_all_shards(self, skip_shard_id=None):
         """
@@ -856,10 +850,8 @@ def _open_connections_for_all_shards(self, skip_shard_id=None):
         for shard_id in range(self.host.sharding_info.shards_count):
             if skip_shard_id is not None and skip_shard_id == shard_id:
                 continue
-            future = self._session.submit(self._open_connection_to_missing_shard, shard_id)
-            if isinstance(future, Future):
-                self._connecting.add(shard_id)
-                self._shard_connections_futures.append(future)
+            self._session.shard_reconnection_scheduler.schedule(
+                self.host.host_id, shard_id, self._open_connection_to_missing_shard, shard_id)

         trash_conns = None
         with self._lock:

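With the changes above, `HostConnection` routes every shard reconnection through `session.shard_reconnection_scheduler.schedule(host_id, shard_id, fn, *args)`, so a custom policy only has to hand back an object exposing that method. A rough sketch with hypothetical names (not part of the commit):

import logging

from cassandra.policies import ShardReconnectionPolicy, _ShardReconnectionScheduler

log = logging.getLogger(__name__)


class LoggingShardReconnectionScheduler(_ShardReconnectionScheduler):
    # Illustrative scheduler: log each request and run it immediately on the
    # session's executor; a real implementation would deduplicate and pace.
    def __init__(self, session):
        self.session = session

    def schedule(self, host_id, shard_id, method, *args, **kwargs):
        log.debug("reconnecting shard %s on host %s", shard_id, host_id)
        if not self.session.is_shutdown:
            self.session.submit(method, *args, **kwargs)


class LoggingShardReconnectionPolicy(ShardReconnectionPolicy):
    def new_scheduler(self, session):
        return LoggingShardReconnectionScheduler(session)
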
tests/integration/experiments/test_tablets.py

Lines changed: 5 additions & 2 deletions
@@ -3,7 +3,8 @@
 import pytest

 from cassandra.cluster import Cluster
-from cassandra.policies import ConstantReconnectionPolicy, RoundRobinPolicy, TokenAwarePolicy
+from cassandra.policies import ConstantReconnectionPolicy, RoundRobinPolicy, TokenAwarePolicy, \
+    NoDelayShardReconnectionPolicy

 from tests.integration import PROTOCOL_VERSION, use_cluster
 from tests.unit.test_host_connection_pool import LOGGER
@@ -21,7 +22,9 @@ class TestTabletsIntegration:
     def setup_class(cls):
         cls.cluster = Cluster(contact_points=["127.0.0.1", "127.0.0.2", "127.0.0.3"], protocol_version=PROTOCOL_VERSION,
                               load_balancing_policy=TokenAwarePolicy(RoundRobinPolicy()),
-                              reconnection_policy=ConstantReconnectionPolicy(1))
+                              reconnection_policy=ConstantReconnectionPolicy(1),
+                              shard_reconnection_policy=NoDelayShardReconnectionPolicy(),
+                              )
         cls.session = cls.cluster.connect()
         cls.create_ks_and_cf(cls.session)
         cls.create_data(cls.session)
