Skip to content

Commit 9f3587c

Browse files
committed
feat(policy): add shard reconnection policies
Add abstract classes `ShardReconnectionPolicy` and `ShardReconnectionScheduler`, and two implementations: `NoDelayShardReconnectionPolicy` - a policy that represents the old behavior of having no delay and no concurrency restriction; `NoConcurrentShardReconnectionPolicy` - a policy that limits concurrent reconnections to 1 per scope and introduces a delay between reconnections within the scope.
1 parent d5834c6 commit 9f3587c

File tree

2 files changed

+479
-2
lines changed

2 files changed

+479
-2
lines changed

cassandra/policies.py

Lines changed: 230 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,20 +12,28 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
import random
15+
import threading
16+
import time
17+
import weakref
18+
from abc import ABC, abstractmethod
1519

1620
from collections import namedtuple
21+
from enum import Enum
1722
from functools import lru_cache
1823
from itertools import islice, cycle, groupby, repeat
1924
import logging
2025
from random import randint, shuffle
2126
from threading import Lock
2227
import socket
2328
import warnings
29+
from typing import TYPE_CHECKING, Callable, Any, List, Tuple, Iterator, Optional
2430

2531
log = logging.getLogger(__name__)
2632

2733
from cassandra import WriteType as WT
2834

35+
if TYPE_CHECKING:
36+
from cluster import Session
2937

3038
# This is done this way because WriteType was originally
3139
# defined here and in order not to break the API.
@@ -864,6 +872,228 @@ def _add_jitter(self, value):
864872
return min(max(self.base_delay, delay), self.max_delay)
865873

866874

875+
class ShardReconnectionScheduler(ABC):
    """
    Base class for shard reconnection schedulers.

    A scheduler receives requests to run a callable on behalf of a
    (host, shard) pair; concrete subclasses decide how the call is run.
    """

    @abstractmethod
    def schedule(
            self,
            host_id: str,
            shard_id: int,
            method: Callable[..., None],
            *args: Any,
            **kwargs: Any) -> None:
        """
        Request that ``method(*args, **kwargs)`` be run for the given shard.

        :param host_id: identifier of the host the shard belongs to.
        :param shard_id: identifier of the shard on that host.
        :param method: callable to run.
        :param args: positional arguments forwarded to ``method``.
        :param kwargs: keyword arguments forwarded to ``method``.
        """
        # NOTE: *args / **kwargs annotations describe each individual
        # argument, so they are `Any`, not `List[Any]` / `dict[Any, Any]`.
        raise NotImplementedError()
885+
886+
887+
class ShardReconnectionPolicy(ABC):
    """
    Base class for shard reconnection policies.

    On `new_scheduler` instantiate a scheduler that behaves according to the policy
    """

    # `Session` is imported only under `TYPE_CHECKING`, so it must be a string
    # annotation; an unquoted name would raise NameError when this class body
    # is executed at import time.
    @abstractmethod
    def new_scheduler(self, session: "Session") -> "ShardReconnectionScheduler":
        """
        Create the scheduler that implements this policy for ``session``.

        :param session: the session the scheduler will work for.
        :return: a `ShardReconnectionScheduler` instance.
        """
        raise NotImplementedError()
897+
898+
899+
class NoDelayShardReconnectionPolicy(ShardReconnectionPolicy):
    """
    A shard reconnection policy with no delay between attempts and no concurrency restrictions.
    Ensures at most one pending reconnection per (host, shard) pair - any additional
    reconnection attempts for the same (host, shard) are silently ignored.

    On `new_scheduler` instantiate a scheduler that behaves according to the policy
    """

    # `Session` exists only under `TYPE_CHECKING`; the annotation is quoted so
    # it is not evaluated (and does not raise NameError) at import time.
    def new_scheduler(self, session: "Session") -> ShardReconnectionScheduler:
        """
        Return a new `_NoDelayShardReconnectionScheduler` bound to ``session``.
        """
        return _NoDelayShardReconnectionScheduler(session)
910+
911+
912+
class _NoDelayShardReconnectionScheduler(ShardReconnectionScheduler):
    """
    Scheduler that hands every request straight to ``session.submit``.

    A request for a (host, shard) pair that already has a submitted, not yet
    finished request is silently dropped.
    """
    # Quoted: `Session` is only imported under `TYPE_CHECKING`.
    session: "Session"
    # Maps '<host_id>-<shard_id>' to True while a request for it is pending.
    already_scheduled: dict[str, bool]
    # Guards `already_scheduled`.
    _lock: threading.Lock

    def __init__(self, session: "Session"):
        # Weak proxy, so the scheduler does not hold a strong reference
        # to the session.
        self.session = weakref.proxy(session)
        self.already_scheduled = {}
        self._lock = threading.Lock()

    def _clear(self, scheduled_key: str) -> None:
        """Mark ``scheduled_key`` as no longer pending."""
        with self._lock:
            self.already_scheduled[scheduled_key] = False

    def _execute(
            self,
            scheduled_key: str,
            method: Callable[..., None],
            *args: Any,
            **kwargs: Any) -> None:
        """
        Run ``method(*args, **kwargs)`` and release the pending mark for
        ``scheduled_key`` afterwards, whether or not ``method`` raised.
        """
        try:
            method(*args, **kwargs)
        finally:
            self._clear(scheduled_key)

    def schedule(
            self,
            host_id: str,
            shard_id: int,
            method: Callable[..., None],
            *args: Any,
            **kwargs: Any) -> None:
        """
        Submit ``method(*args, **kwargs)`` to the session unless a request for
        the same (host, shard) pair is still pending or the session is shut down.

        :param host_id: identifier of the host the shard belongs to.
        :param shard_id: identifier of the shard on that host.
        :param method: callable to run.
        """
        # Check shutdown before marking the key, so that a request that is
        # never submitted cannot leave the key stuck in the pending state.
        if self.session.is_shutdown:
            return

        scheduled_key = f'{host_id}-{shard_id}'
        # Check-and-set must be atomic, otherwise two threads could both
        # submit a request for the same shard.
        with self._lock:
            if self.already_scheduled.get(scheduled_key):
                return
            self.already_scheduled[scheduled_key] = True

        # Submit outside the lock: `_execute` takes the lock again when done.
        try:
            self.session.submit(self._execute, scheduled_key, method, *args, **kwargs)
        except Exception:
            # Nothing was queued, so release the mark before propagating.
            self._clear(scheduled_key)
            raise
945+
946+
947+
class ShardReconnectionPolicyScope(Enum):
    """
    Scope used by a `ShardReconnectionPolicy` (in particular by
    `NoConcurrentShardReconnectionPolicy`) to group shard reconnections.
    """
    # A single scope shared by every host.
    Cluster = 0
    # A separate scope for each host.
    Host = 1
953+
954+
955+
class NoConcurrentShardReconnectionPolicy(ShardReconnectionPolicy):
    """
    A shard reconnection policy that allows only one pending reconnection per scope,
    where the scope is either `Host` or `Cluster`.
    The delay between reconnections is taken from a `ReconnectionPolicy`; once there
    are no more reconnections to run in a scope, its backoff schedule is reset.
    For all scopes it does not allow scheduling multiple reconnections for the same
    host+shard; such attempts are silently ignored.

    On `new_scheduler` instantiate a scheduler that behaves according to the policy
    """
    shard_reconnection_scope: ShardReconnectionPolicyScope
    reconnection_policy: ReconnectionPolicy

    def __init__(
            self,
            shard_reconnection_scope: ShardReconnectionPolicyScope,
            reconnection_policy: ReconnectionPolicy,
    ):
        """
        :param shard_reconnection_scope: scope within which reconnections are serialized.
        :param reconnection_policy: source of the delays between reconnections.
        :raises ValueError: if either argument is not of the expected type.
        """
        if not isinstance(shard_reconnection_scope, ShardReconnectionPolicyScope):
            raise ValueError("shard_reconnection_scope must be a ShardReconnectionPolicyScope")
        if not isinstance(reconnection_policy, ReconnectionPolicy):
            raise ValueError("reconnection_policy must be a ReconnectionPolicy")
        self.shard_reconnection_scope = shard_reconnection_scope
        self.reconnection_policy = reconnection_policy

    # `Session` exists only under `TYPE_CHECKING`; the annotation is quoted so
    # it is not evaluated (and does not raise NameError) at import time.
    def new_scheduler(self, session: "Session") -> ShardReconnectionScheduler:
        """
        Return a new `_NoConcurrentShardReconnectionScheduler` for ``session``
        configured with this policy's scope and reconnection policy.
        """
        return _NoConcurrentShardReconnectionScheduler(
            session, self.shard_reconnection_scope, self.reconnection_policy)
980+
981+
982+
class _ScopeBucket:
    """
    Holds information for a shard reconnection scope, schedules and executes reconnections.

    Queued requests are executed one at a time, oldest first; after each one
    the next run is scheduled using the next delay from the reconnection
    policy's schedule.
    """
    # Pending requests as (method, args, kwargs), oldest first.
    items: List[Tuple[Callable[..., None], Tuple[Any, ...], dict[str, Any]]]
    # Quoted: `Session` is only imported under `TYPE_CHECKING`.
    session: "Session"
    reconnection_policy: "ReconnectionPolicy"
    # Guards `items`, `running` and the reset of `schedule`.
    lock: threading.Lock
    # Current delay iterator; None means "start a new schedule on next use".
    schedule: Optional[Iterator[float]]

    # True while a run is scheduled or in progress for this bucket.
    running: bool = False

    def __init__(
            self,
            session: "Session",
            reconnection_policy: "ReconnectionPolicy",
    ):
        self.items = []
        self.session = session
        self.reconnection_policy = reconnection_policy
        self.lock = threading.Lock()
        self.running = False
        self.schedule = self.reconnection_policy.new_schedule()

    def _get_delay(self) -> float:
        """
        Return the next delay, starting a new schedule when there is none or
        the current one is exhausted.
        """
        if self.schedule is None:
            self.schedule = self.reconnection_policy.new_schedule()
        try:
            return next(self.schedule)
        except StopIteration:
            self.schedule = self.reconnection_policy.new_schedule()
            # A policy whose fresh schedule is empty yields no delay at all;
            # treat that as "no delay" instead of leaking StopIteration.
            return next(self.schedule, 0)

    def _schedule(self):
        """Arrange the next `_run`, delayed if the policy returns a delay."""
        if self.session.is_shutdown:
            return
        delay = self._get_delay()
        if delay:
            self.session.cluster.scheduler.schedule(delay, self._run)
        else:
            self.session.submit(self._run)

    def _run(self):
        """
        Execute the oldest queued request, then schedule the next run.

        When the queue is empty the bucket is marked idle and its schedule is
        dropped, so the next `add` starts from a fresh schedule.
        """
        if self.session.is_shutdown:
            return

        with self.lock:
            if not self.items:
                self.running = False
                self.schedule = None
                return
            # FIFO: serve requests in arrival order so that the oldest one
            # is not kept waiting behind newer arrivals.
            method, args, kwargs = self.items.pop(0)

        # The request runs outside the lock so `add` is never blocked by it.
        try:
            method(*args, **kwargs)
        finally:
            self._schedule()

    def add(self, method: Callable[..., None], *args, **kwargs):
        """
        Queue ``method(*args, **kwargs)`` and start the run loop if it is idle.
        """
        with self.lock:
            self.items.append((method, args, kwargs))
            if not self.running:
                self.running = True
                self._schedule()
1047+
1048+
1049+
class _NoConcurrentShardReconnectionScheduler(ShardReconnectionScheduler):
    """
    Scheduler that queues reconnections into one `_ScopeBucket` per scope.

    With the `Cluster` scope all requests share a single bucket; otherwise
    there is one bucket per ``host_id``. A request for a (host, shard) pair
    that is already pending is dropped.
    """
    # Maps '<host_id>-<shard_id>' to True while a request for it is pending.
    already_scheduled: dict[str, bool]
    # Maps scope key to the bucket that serializes requests of that scope.
    scopes: dict[str, _ScopeBucket]
    shard_reconnection_scope: ShardReconnectionPolicyScope
    reconnection_policy: ReconnectionPolicy
    # Quoted: `Session` is only imported under `TYPE_CHECKING`.
    session: "Session"
    # Guards `already_scheduled` and `scopes`.
    lock: threading.Lock

    def __init__(
            self,
            session: "Session",
            shard_reconnection_scope: ShardReconnectionPolicyScope,
            reconnection_policy: ReconnectionPolicy,
    ):
        self.already_scheduled = {}
        self.scopes = {}
        self.shard_reconnection_scope = shard_reconnection_scope
        self.reconnection_policy = reconnection_policy
        self.session = session
        self.lock = threading.Lock()

    def _execute(self, scheduled_key: str, method: Callable[..., None], *args, **kwargs):
        """
        Run ``method(*args, **kwargs)`` and release the pending mark for
        ``scheduled_key`` afterwards, whether or not ``method`` raised.
        """
        try:
            method(*args, **kwargs)
        finally:
            with self.lock:
                self.already_scheduled[scheduled_key] = False

    def schedule(self, host_id: str, shard_id: int, method: Callable[..., None], *args, **kwargs):
        """
        Queue ``method(*args, **kwargs)`` in the bucket of the matching scope.

        :param host_id: identifier of the host the shard belongs to.
        :param shard_id: identifier of the shard on that host.
        :param method: callable to run.
        :return: False if a request for the same (host, shard) pair is already
            pending and this one was dropped, True if it was queued.
        """
        if self.shard_reconnection_scope == ShardReconnectionPolicyScope.Cluster:
            scope_hash = "global-cluster-scope"
        else:
            scope_hash = host_id
        scheduled_key = f'{host_id}-{shard_id}'

        with self.lock:
            if self.already_scheduled.get(scheduled_key):
                return False
            self.already_scheduled[scheduled_key] = True

            scope_info = self.scopes.get(scope_hash)
            if scope_info is None:
                scope_info = _ScopeBucket(self.session, self.reconnection_policy)
                self.scopes[scope_hash] = scope_info
            try:
                scope_info.add(self._execute, scheduled_key, method, *args, **kwargs)
            except Exception:
                # Enqueuing failed: release the mark (the lock is already
                # held here) before propagating the error.
                self.already_scheduled[scheduled_key] = False
                raise
            return True
1095+
1096+
8671097
class RetryPolicy(object):
8681098
"""
8691099
A policy that describes whether to retry, rethrow, or ignore coordinator

0 commit comments

Comments
 (0)