Skip to content

Commit 2f5ca3c

Browse files
committed
feat(policy): add shard reconnection policies
Add abstract classes `ShardReconnectionPolicy` and `ShardReconnectionScheduler`, and two implementations: `NoDelayShardReconnectionPolicy`, a policy that preserves the old behavior of having no delay and no concurrency restriction; and `NoConcurrentShardReconnectionPolicy`, a policy that limits concurrent reconnections to one per scope and introduces a delay between reconnections within the scope.
1 parent d5834c6 commit 2f5ca3c

File tree

2 files changed

+481
-2
lines changed

2 files changed

+481
-2
lines changed

cassandra/policies.py

Lines changed: 232 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,21 +11,31 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
from __future__ import annotations
15+
1416
import random
17+
import threading
18+
import time
19+
import weakref
20+
from abc import ABC, abstractmethod
1521

1622
from collections import namedtuple
23+
from enum import Enum
1724
from functools import lru_cache
1825
from itertools import islice, cycle, groupby, repeat
1926
import logging
2027
from random import randint, shuffle
2128
from threading import Lock
2229
import socket
2330
import warnings
31+
from typing import TYPE_CHECKING, Callable, Any, List, Tuple, Iterator, Optional
2432

2533
log = logging.getLogger(__name__)
2634

2735
from cassandra import WriteType as WT
2836

37+
if TYPE_CHECKING:
    # Needed only by annotations; the guard keeps it from being imported at
    # runtime. Use the absolute package path so type checkers can resolve it
    # (a bare `cluster` module is not importable from inside the package).
    from cassandra.cluster import Session
2939

3040
# This is done this way because WriteType was originally
3141
# defined here and in order not to break the API.
@@ -864,6 +874,228 @@ def _add_jitter(self, value):
864874
return min(max(self.base_delay, delay), self.max_delay)
865875

866876

877+
class ShardReconnectionScheduler(ABC):
    """
    Interface of the object that decides when a requested shard reconnection
    is actually run.

    Schedulers are created by `ShardReconnectionPolicy.new_scheduler`.
    """

    @abstractmethod
    def schedule(
            self,
            host_id: str,
            shard_id: int,
            method: Callable[..., None],
            *args: Any,
            **kwargs: Any) -> None:
        """
        Request that `method(*args, **kwargs)` be run for shard `shard_id`
        of host `host_id`.

        :param host_id: identifier of the host the shard belongs to
        :param shard_id: identifier of the shard on that host
        :param method: callable to invoke
        :param args: positional arguments passed to `method`
        :param kwargs: keyword arguments passed to `method`
        """
        raise NotImplementedError()
887+
888+
889+
class ShardReconnectionPolicy(ABC):
    """
    Abstract base of every shard reconnection policy.

    A policy is a factory: `new_scheduler` is given a session and must return
    a scheduler that behaves according to the policy.
    """

    @abstractmethod
    def new_scheduler(self, session: Session) -> ShardReconnectionScheduler:
        """
        Build the scheduler for `session`.

        Concrete policies must override this method.
        """
        raise NotImplementedError
899+
900+
901+
class NoDelayShardReconnectionPolicy(ShardReconnectionPolicy):
    """
    Shard reconnection policy that adds no delay between attempts and places
    no restriction on concurrency.

    At most one reconnection can be pending for a given (host, shard) pair;
    further requests for the same pair are silently ignored while one is
    pending.

    `new_scheduler` returns a scheduler implementing this behavior.
    """

    def new_scheduler(self, session: Session) -> ShardReconnectionScheduler:
        """Create the no-delay scheduler bound to `session`."""
        scheduler = _NoDelayShardReconnectionScheduler(session)
        return scheduler
912+
913+
914+
class _NoDelayShardReconnectionScheduler(ShardReconnectionScheduler):
    """
    Scheduler that submits every reconnection to the session right away.

    A request is dropped when another one for the same host+shard has been
    submitted and has not finished yet.
    """
    session: Session
    already_scheduled: dict[str, bool]

    def __init__(self, session: Session):
        # Hold the session through a weak proxy so that the scheduler does not
        # keep it alive.
        self.session = weakref.proxy(session)
        self.already_scheduled = {}
        # Makes the "is it pending? -> mark it pending" step atomic.
        self._lock = threading.Lock()

    def _clear(self, scheduled_key: str) -> None:
        """Mark `scheduled_key` as no longer pending."""
        with self._lock:
            self.already_scheduled[scheduled_key] = False

    def _execute(
            self,
            scheduled_key: str,
            method: Callable[..., None],
            *args: Any,
            **kwargs: Any) -> None:
        """
        Run `method(*args, **kwargs)` and release `scheduled_key` afterwards,
        whether or not `method` raised.
        """
        try:
            method(*args, **kwargs)
        finally:
            self._clear(scheduled_key)

    def schedule(
            self,
            host_id: str,
            shard_id: int,
            method: Callable[..., None],
            *args: Any,
            **kwargs: Any) -> None:
        """
        Submit `method(*args, **kwargs)` to the session unless a call for the
        same host+shard is already pending or the session is shut down.

        :param host_id: identifier of the host the shard belongs to
        :param shard_id: identifier of the shard on that host
        :param method: callable to run
        :param args: positional arguments passed to `method`
        :param kwargs: keyword arguments passed to `method`
        """
        scheduled_key = f'{host_id}-{shard_id}'
        with self._lock:
            if self.already_scheduled.get(scheduled_key):
                return
            # Mark the key only when the work is really going to be submitted;
            # otherwise nothing would ever reset the flag.
            if self.session.is_shutdown:
                return
            self.already_scheduled[scheduled_key] = True

        # Submit outside the lock: `_execute` takes the same lock when it
        # finishes.
        try:
            self.session.submit(self._execute, scheduled_key, method, *args, **kwargs)
        except Exception:
            # The task was never queued, so `_execute` will not release the key.
            self._clear(scheduled_key)
            raise
947+
948+
949+
class ShardReconnectionPolicyScope(Enum):
    """
    Scope used by a `ShardReconnectionPolicy`, in particular by
    `NoConcurrentShardReconnectionPolicy`.
    """
    # A single scope shared by all hosts.
    Cluster = 0
    # A separate scope for each host.
    Host = 1
955+
956+
957+
class NoConcurrentShardReconnectionPolicy(ShardReconnectionPolicy):
    """
    A shard reconnection policy that allows only one pending reconnection per
    scope, where the scope is either `Host` or `Cluster`.

    Delays between reconnections are taken from a `ReconnectionPolicy`; once
    there are no more reconnections to schedule, the backoff schedule is reset.

    Whatever the scope, a second reconnection for the same host+shard is not
    scheduled while one is pending; such attempts are silently ignored.

    `new_scheduler` returns a scheduler implementing this behavior.
    """
    shard_reconnection_scope: ShardReconnectionPolicyScope
    reconnection_policy: ReconnectionPolicy

    def __init__(
            self,
            shard_reconnection_scope: ShardReconnectionPolicyScope,
            reconnection_policy: ReconnectionPolicy,
    ):
        """
        :param shard_reconnection_scope: scope within which reconnections are serialized
        :param reconnection_policy: policy providing the delays between reconnections
        :raises ValueError: if either argument is not of the expected type
        """
        scope_is_valid = isinstance(shard_reconnection_scope, ShardReconnectionPolicyScope)
        if not scope_is_valid:
            raise ValueError("shard_reconnection_scope must be a ShardReconnectionPolicyScope")

        policy_is_valid = isinstance(reconnection_policy, ReconnectionPolicy)
        if not policy_is_valid:
            raise ValueError("reconnection_policy must be a ReconnectionPolicy")

        self.reconnection_policy = reconnection_policy
        self.shard_reconnection_scope = shard_reconnection_scope

    def new_scheduler(self, session: Session) -> ShardReconnectionScheduler:
        """Create a scheduler for `session` using this policy's scope and backoff."""
        scope = self.shard_reconnection_scope
        backoff = self.reconnection_policy
        return _NoConcurrentShardReconnectionScheduler(session, scope, backoff)
982+
983+
984+
class _ScopeBucket:
985+
"""
986+
Holds information for a shard reconnection scope, schedules and executes reconnections.
987+
"""
988+
items: List[Tuple[Callable[..., None], Tuple[Any, ...], dict[str, Any]]]
989+
session: Session
990+
reconnection_policy: ReconnectionPolicy
991+
lock = threading.Lock
992+
schedule: Optional[Iterator[float]]
993+
994+
running: bool = False
995+
996+
def __init__(
997+
self,
998+
session: Session,
999+
reconnection_policy: ReconnectionPolicy,
1000+
):
1001+
self.items = []
1002+
self.session = session
1003+
self.reconnection_policy = reconnection_policy
1004+
self.lock = threading.Lock()
1005+
self.schedule = self.reconnection_policy.new_schedule()
1006+
1007+
def _get_delay(self) -> float:
1008+
if self.schedule is None:
1009+
self.schedule = self.reconnection_policy.new_schedule()
1010+
try:
1011+
return next(self.schedule)
1012+
except StopIteration:
1013+
self.schedule = self.reconnection_policy.new_schedule()
1014+
return next(self.schedule)
1015+
1016+
def _schedule(self):
1017+
if self.session.is_shutdown:
1018+
return
1019+
delay = self._get_delay()
1020+
if delay:
1021+
self.session.cluster.scheduler.schedule(delay, self._run)
1022+
else:
1023+
self.session.submit(self._run)
1024+
1025+
def _run(self):
1026+
if self.session.is_shutdown:
1027+
return
1028+
1029+
with self.lock:
1030+
try:
1031+
item = self.items.pop()
1032+
except IndexError:
1033+
self.running = False
1034+
self.schedule = None
1035+
return
1036+
1037+
method, args, kwargs = item
1038+
try:
1039+
method(*args, **kwargs)
1040+
finally:
1041+
self._schedule()
1042+
1043+
def add(self, method: Callable[..., None], *args, **kwargs):
1044+
with self.lock:
1045+
self.items.append((method, args, kwargs))
1046+
if not self.running:
1047+
self.running = True
1048+
self._schedule()
1049+
1050+
1051+
class _NoConcurrentShardReconnectionScheduler(ShardReconnectionScheduler):
    """
    Scheduler that routes every reconnection into a `_ScopeBucket`: one bucket
    for the whole cluster or one per host, depending on the configured scope.

    A request is refused while another one for the same host+shard is pending.
    """
    already_scheduled: dict[str, bool]
    scopes: dict[str, _ScopeBucket]
    shard_reconnection_scope: ShardReconnectionPolicyScope
    reconnection_policy: ReconnectionPolicy
    session: Session
    lock: threading.Lock

    def __init__(
            self,
            session: Session,
            shard_reconnection_scope: ShardReconnectionPolicyScope,
            reconnection_policy: ReconnectionPolicy,
    ):
        self.session = session
        self.shard_reconnection_scope = shard_reconnection_scope
        self.reconnection_policy = reconnection_policy
        self.lock = threading.Lock()
        self.scopes = {}
        self.already_scheduled = {}

    def _execute(self, scheduled_key: str, method: Callable[..., None], *args, **kwargs):
        """Run `method` and release `scheduled_key` afterwards, even on error."""
        try:
            method(*args, **kwargs)
        finally:
            with self.lock:
                self.already_scheduled[scheduled_key] = False

    def schedule(self, host_id: str, shard_id: int, method: Callable[..., None], *args, **kwargs):
        """
        Queue `method(*args, **kwargs)` in the bucket of the matching scope.

        Returns False when a call for the same host+shard is already pending,
        True once the call has been queued.
        """
        cluster_wide = self.shard_reconnection_scope == ShardReconnectionPolicyScope.Cluster
        scope_hash = "global-cluster-scope" if cluster_wide else host_id
        scheduled_key = f'{host_id}-{shard_id}'

        with self.lock:
            if self.already_scheduled.get(scheduled_key):
                return False
            self.already_scheduled[scheduled_key] = True

            # Buckets are created lazily, the first time a scope is used.
            try:
                bucket = self.scopes[scope_hash]
            except KeyError:
                bucket = _ScopeBucket(self.session, self.reconnection_policy)
                self.scopes[scope_hash] = bucket
            bucket.add(self._execute, scheduled_key, method, *args, **kwargs)
        return True
1097+
1098+
8671099
class RetryPolicy(object):
8681100
"""
8691101
A policy that describes whether to retry, rethrow, or ignore coordinator

0 commit comments

Comments
 (0)