Skip to content

Commit aafa540

Browse files
committed
feat(policy): add shard reconnection policies
Add abstract classes: `ShardReconnectionPolicy` and `ShardReconnectionScheduler` And implementations: `NoDelayShardReconnectionPolicy` - policy that represents old behavior of having no delay and no concurrency restriction. `NoConcurrentShardReconnectionPolicy` - policy that limits concurrent reconnections to 1 per scope and introduces delay between reconnections within the scope.
1 parent 3f7bcbb commit aafa540

File tree

4 files changed

+671
-23
lines changed

4 files changed

+671
-23
lines changed

cassandra/policies.py

Lines changed: 343 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,21 +11,31 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
from __future__ import annotations
15+
1416
import random
17+
import threading
18+
import time
19+
import weakref
20+
from abc import ABC, abstractmethod
1521

1622
from collections import namedtuple
23+
from enum import Enum
1724
from functools import lru_cache
1825
from itertools import islice, cycle, groupby, repeat
1926
import logging
2027
from random import randint, shuffle
2128
from threading import Lock
2229
import socket
2330
import warnings
31+
from typing import TYPE_CHECKING, Callable, Any, List, Tuple, Iterator, Optional, Dict
2432

2533
log = logging.getLogger(__name__)
2634

2735
from cassandra import WriteType as WT
2836

37+
if TYPE_CHECKING:
38+
from cluster import Session
2939

3040
# This is done this way because WriteType was originally
3141
# defined here and in order not to break the API.
@@ -864,6 +874,339 @@ def _add_jitter(self, value):
864874
return min(max(self.base_delay, delay), self.max_delay)
865875

866876

877+
class ShardConnectionScheduler(ABC):
878+
"""
879+
A base class for a scheduler for a shard connection backoff policy.
880+
``ShardConnectionScheduler`` is a per Session instance that implements logic
881+
described by ``ShardConnectionBackoffPolicy`` that instantiates it
882+
"""
883+
884+
@abstractmethod
885+
def schedule(
886+
self,
887+
host_id: str,
888+
shard_id: int,
889+
method: Callable[..., None],
890+
*args: List[Any],
891+
**kwargs: dict[Any, Any]) -> None:
892+
"""
893+
Schedule a create connection request for given host and shard according to policy or executes it right away.
894+
It is non-blocking call, even if policy executes it right away, it is being executed in a separate thread.
895+
896+
``host_id`` - an id of the host of the shard
897+
``shard_id`` - an id of the shard
898+
``method`` - a callable that creates connection and stores it in the connection pool.
899+
Currently, it is `HostConnection._open_connection_to_missing_shard`
900+
``*args`` and ``**kwargs`` are passed to ``method`` when policy executes it
901+
"""
902+
raise NotImplementedError()
903+
904+
905+
class ShardConnectionBackoffPolicy(ABC):
906+
"""
907+
Base class for shard connection backoff policies.
908+
These policies allow user to control pace of establishing new connections to shards
909+
910+
On `new_scheduler` instantiate a scheduler that behaves according to the policy
911+
"""
912+
913+
@abstractmethod
914+
def new_scheduler(self, session: Session) -> ShardConnectionScheduler:
915+
raise NotImplementedError()
916+
917+
918+
class NoDelayShardConnectionBackoffPolicy(ShardConnectionBackoffPolicy):
919+
"""
920+
A shard connection backoff policy with no delay between attempts and no concurrency restrictions.
921+
Ensures that at most one pending connection per (host, shard) pair.
922+
If connection attempts for the same (host, shard) it is silently dropped.
923+
924+
On `new_scheduler` instantiate a scheduler that behaves according to the policy
925+
"""
926+
927+
def new_scheduler(self, session: Session) -> ShardConnectionScheduler:
928+
return _NoDelayShardConnectionBackoffScheduler(session)
929+
930+
931+
class _NoDelayShardConnectionBackoffScheduler(ShardConnectionScheduler):
932+
"""
933+
A shard connection backoff scheduler for ``cassandra.policies.NoDelayShardConnectionBackoffPolicy``.
934+
It does not introduce any delay or concurrency restrictions.
935+
It only ensures that there is only one pending or scheduled connection per host+shard.
936+
"""
937+
session: Session
938+
already_scheduled: dict[str, bool]
939+
lock: threading.Lock
940+
941+
def __init__(self, session: Session):
942+
self.session = weakref.proxy(session)
943+
self.already_scheduled = {}
944+
self.lock = threading.Lock()
945+
946+
def _execute(
947+
self,
948+
scheduled_key: str,
949+
method: Callable[..., None],
950+
*args: List[Any],
951+
**kwargs: dict[Any, Any]) -> None:
952+
try:
953+
method(*args, **kwargs)
954+
finally:
955+
with self.lock:
956+
self.already_scheduled[scheduled_key] = False
957+
958+
def schedule(
959+
self,
960+
host_id: str,
961+
shard_id: int,
962+
method: Callable[..., None],
963+
*args: List[Any],
964+
**kwargs: dict[Any, Any]) -> None:
965+
scheduled_key = f'{host_id}-{shard_id}'
966+
967+
with self.lock:
968+
if self.already_scheduled.get(scheduled_key):
969+
return
970+
self.already_scheduled[scheduled_key] = True
971+
972+
if not self.session.is_shutdown:
973+
self.session.submit(self._execute, scheduled_key, method, *args, **kwargs)
974+
975+
976+
class ShardConnectionBackoffScope(Enum):
977+
"""
978+
A scope for `ShardConnectionBackoffPolicy`, in particular ``LimitedConcurrencyShardConnectionBackoffPolicy``
979+
980+
Scope defines concurrency limitation scope, for instance :
981+
``LimitedConcurrencyShardConnectionBackoffPolicy`` - allows only one pending connection per scope, if you set it to Cluster,
982+
only one connection per cluster will be allowed
983+
"""
984+
Cluster = 0
985+
Host = 1
986+
987+
988+
class ShardConnectionBackoffSchedule(ABC):
989+
@abstractmethod
990+
def new_schedule(self) -> Iterator[float]:
991+
"""
992+
This should return a finite or infinite iterable of delays (each as a
993+
floating point number of seconds).
994+
Note that if the iterable is finite, schedule will be recreated right after iterable is exhausted.
995+
"""
996+
raise NotImplementedError()
997+
998+
999+
class LimitedConcurrencyShardConnectionBackoffPolicy(ShardConnectionBackoffPolicy):
1000+
"""
1001+
A shard connection backoff policy that allows only ``max_concurrent`` concurrent connection per scope.
1002+
Scope could be ``Host``or ``Cluster``
1003+
For backoff calculation ir needs ``ShardConnectionBackoffSchedule`` or ``ReconnectionPolicy``, since both share same API.
1004+
When there is no more scheduled connections schedule of the backoff is reset.
1005+
1006+
it also does not allow multiple pending or scheduled connections for same host+shard,
1007+
it silently drops attempts to schedule it.
1008+
1009+
On ``new_scheduler`` instantiate a scheduler that behaves according to the policy
1010+
"""
1011+
scope: ShardConnectionBackoffScope
1012+
backoff_policy: ShardConnectionBackoffSchedule | ReconnectionPolicy
1013+
1014+
max_concurrent: int
1015+
"""
1016+
Max concurrent connection creation requests per scope.
1017+
"""
1018+
1019+
def __init__(
1020+
self,
1021+
scope: ShardConnectionBackoffScope,
1022+
backoff_policy: ShardConnectionBackoffSchedule | ReconnectionPolicy,
1023+
max_concurrent: int = 1,
1024+
):
1025+
if not isinstance(scope, ShardConnectionBackoffScope):
1026+
raise ValueError("scope must be a ShardConnectionBackoffScope")
1027+
if not isinstance(backoff_policy, (ShardConnectionBackoffSchedule, ReconnectionPolicy)):
1028+
raise ValueError("backoff_policy must be a ShardConnectionBackoffSchedule or ReconnectionPolicy")
1029+
if max_concurrent < 1:
1030+
raise ValueError("max_concurrent must be a positive integer")
1031+
self.scope = scope
1032+
self.backoff_policy = backoff_policy
1033+
self.max_concurrent = max_concurrent
1034+
1035+
def new_scheduler(self, session: Session) -> ShardConnectionScheduler:
1036+
return _LimitedConcurrencyShardConnectionScheduler(session, self.scope, self.backoff_policy, self.max_concurrent)
1037+
1038+
1039+
class CreateConnectionCallback:
1040+
method: Callable[..., None]
1041+
args: Tuple[Any, ...]
1042+
kwargs: Dict[str, Any]
1043+
1044+
def __init__(self, method: Callable[..., None], *args, **kwargs) -> None:
1045+
self.method = method
1046+
self.args = args
1047+
self.kwargs = kwargs
1048+
1049+
1050+
class _ScopeBucket:
1051+
"""
1052+
Holds information for a shard reconnection scope, schedules and executes reconnections.
1053+
"""
1054+
session: Session
1055+
backoff_policy: ShardConnectionBackoffSchedule
1056+
lock: threading.Lock
1057+
1058+
schedule: Iterator[float]
1059+
"""
1060+
Current schedule generated by ``backoff_policy``
1061+
"""
1062+
1063+
max_concurrent: int
1064+
"""
1065+
Max concurrent connection creation requests in the scope.
1066+
"""
1067+
1068+
currently_pending: int
1069+
"""
1070+
Currently pending connections
1071+
"""
1072+
1073+
items: List[CreateConnectionCallback]
1074+
"""
1075+
Scheduled create connections requests
1076+
"""
1077+
1078+
def __init__(
1079+
self,
1080+
session: Session,
1081+
backoff_policy: ShardConnectionBackoffSchedule,
1082+
max_concurrent: int,
1083+
):
1084+
self.items = []
1085+
self.session = session
1086+
self.backoff_policy = backoff_policy
1087+
self.lock = threading.Lock()
1088+
self.schedule = self.backoff_policy.new_schedule()
1089+
self.max_concurrent = max_concurrent
1090+
self.currently_pending = 0
1091+
1092+
def _get_delay(self) -> float:
1093+
try:
1094+
return next(self.schedule)
1095+
except StopIteration:
1096+
# A bit of trickery to avoid having lock around self.schedule
1097+
schedule = self.backoff_policy.new_schedule()
1098+
delay = next(schedule)
1099+
self.schedule = schedule
1100+
return delay
1101+
1102+
def _schedule(self):
1103+
if self.session.is_shutdown:
1104+
return
1105+
delay = self._get_delay()
1106+
if delay:
1107+
self.session.cluster.scheduler.schedule(delay, self._run)
1108+
else:
1109+
self.session.submit(self._run)
1110+
1111+
def _run(self):
1112+
if self.session.is_shutdown:
1113+
return
1114+
1115+
with self.lock:
1116+
try:
1117+
cb = self.items.pop()
1118+
except IndexError:
1119+
# Just in case
1120+
if self.currently_pending > 0:
1121+
self.currently_pending -= 1
1122+
# When items are exhausted reset schedule to ensure that new items going to get another schedule
1123+
# It is important for exponential policy
1124+
self.schedule = self.backoff_policy.new_schedule()
1125+
return
1126+
1127+
try:
1128+
cb.method(*cb.args, **cb.kwargs)
1129+
finally:
1130+
self._schedule()
1131+
1132+
def schedule_new_connection(self, cb: CreateConnectionCallback):
1133+
with self.lock:
1134+
self.items.append(cb)
1135+
if self.currently_pending < self.max_concurrent:
1136+
self.currently_pending += 1
1137+
self._schedule()
1138+
1139+
1140+
class _LimitedConcurrencyShardConnectionScheduler(ShardConnectionScheduler):
1141+
"""
1142+
Dict of host+shard flags, flag is true if there is connection creation request scheduled or currently running for host+shard
1143+
"""
1144+
already_scheduled: dict[str, bool]
1145+
1146+
scopes: dict[str, _ScopeBucket]
1147+
"""
1148+
Scopes storage
1149+
"""
1150+
1151+
scope: ShardConnectionBackoffScope
1152+
"""
1153+
Scope type
1154+
"""
1155+
1156+
backoff_policy: ShardConnectionBackoffSchedule
1157+
session: Session
1158+
lock: threading.Lock
1159+
1160+
max_concurrent: int
1161+
"""
1162+
Max concurrent connection creation requests per scope.
1163+
"""
1164+
1165+
def __init__(
1166+
self,
1167+
session: Session,
1168+
scope: ShardConnectionBackoffScope,
1169+
backoff_policy: ShardConnectionBackoffSchedule,
1170+
max_concurrent: int,
1171+
):
1172+
self.already_scheduled = {}
1173+
self.scopes = {}
1174+
self.scope = scope
1175+
self.backoff_policy = backoff_policy
1176+
self.max_concurrent = max_concurrent
1177+
self.session = session
1178+
self.lock = threading.Lock()
1179+
1180+
def _execute(self, scheduled_key: str, method: Callable[..., None], *args, **kwargs):
1181+
try:
1182+
method(*args, **kwargs)
1183+
finally:
1184+
with self.lock:
1185+
self.already_scheduled[scheduled_key] = False
1186+
1187+
def schedule(self, host_id: str, shard_id: int, method: Callable[..., None], *args, **kwargs):
1188+
if self.scope == ShardConnectionBackoffScope.Cluster:
1189+
scope_hash = "global-cluster-scope"
1190+
elif self.scope == ShardConnectionBackoffScope.Host:
1191+
scope_hash = host_id
1192+
else:
1193+
raise ValueError("scope must be Cluster or Host")
1194+
1195+
scheduled_key = f'{host_id}-{shard_id}'
1196+
1197+
with self.lock:
1198+
if self.already_scheduled.get(scheduled_key):
1199+
return False
1200+
self.already_scheduled[scheduled_key] = True
1201+
1202+
scope_info = self.scopes.get(scope_hash)
1203+
if not scope_info:
1204+
scope_info = _ScopeBucket(self.session, self.backoff_policy, self.max_concurrent)
1205+
self.scopes[scope_hash] = scope_info
1206+
scope_info.schedule_new_connection(CreateConnectionCallback(self._execute, scheduled_key, method, *args, **kwargs))
1207+
return True
1208+
1209+
8671210
class RetryPolicy(object):
8681211
"""
8691212
A policy that describes whether to retry, rethrow, or ignore coordinator

tests/unit/test_host_connection_pool.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from cassandra.connection import Connection
2727
from cassandra.pool import HostConnection, HostConnectionPool
2828
from cassandra.pool import Host, NoConnectionsAvailable
29-
from cassandra.policies import HostDistance, SimpleConvictionPolicy
29+
from cassandra.policies import HostDistance, SimpleConvictionPolicy, _NoDelayShardConnectionBackoffScheduler
3030

3131
LOGGER = logging.getLogger(__name__)
3232

@@ -41,6 +41,8 @@ def make_session(self):
4141
session.cluster.get_core_connections_per_host.return_value = 1
4242
session.cluster.get_max_requests_per_connection.return_value = 1
4343
session.cluster.get_max_connections_per_host.return_value = 1
44+
session.shard_connection_backoff_scheduler = _NoDelayShardConnectionBackoffScheduler(session)
45+
session.is_shutdown = False
4446
return session
4547

4648
def test_borrow_and_return(self):

0 commit comments

Comments
 (0)