Skip to content

Commit d417ddb

Browse files
committed
feat(policy): add shard reconnection policies
Add abstract classes `ShardReconnectionPolicy` and `ShardReconnectionScheduler`, and two implementations: `NoDelayShardReconnectionPolicy` - a policy that preserves the old behavior of having no delay and no concurrency restriction; `NoConcurrentShardReconnectionPolicy` - a policy that limits concurrent reconnections to 1 per scope and introduces a delay between reconnections within the scope.
1 parent 3f7bcbb commit d417ddb

File tree

4 files changed

+680
-23
lines changed

4 files changed

+680
-23
lines changed

cassandra/policies.py

Lines changed: 352 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,21 +11,31 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
from __future__ import annotations
15+
1416
import random
17+
import threading
18+
import time
19+
import weakref
20+
from abc import ABC, abstractmethod
1521

1622
from collections import namedtuple
23+
from enum import Enum
1724
from functools import lru_cache
1825
from itertools import islice, cycle, groupby, repeat
1926
import logging
2027
from random import randint, shuffle
2128
from threading import Lock
2229
import socket
2330
import warnings
31+
from typing import TYPE_CHECKING, Callable, Any, List, Tuple, Iterator, Optional, Dict
2432

2533
log = logging.getLogger(__name__)
2634

2735
from cassandra import WriteType as WT
2836

37+
if TYPE_CHECKING:
38+
from cluster import Session
2939

3040
# This is done this way because WriteType was originally
3141
# defined here and in order not to break the API.
@@ -864,6 +874,348 @@ def _add_jitter(self, value):
864874
return min(max(self.base_delay, delay), self.max_delay)
865875

866876

877+
class ShardConnectionScheduler(ABC):
    """
    A base class for a scheduler for a shard connection backoff policy.
    ``ShardConnectionScheduler`` is a per-Session instance that schedules per-shard connections
    according to the ``ShardConnectionBackoffPolicy`` that instantiates it.
    """

    @abstractmethod
    def schedule(
            self,
            host_id: str,
            shard_id: int,
            method: Callable[..., None],
            *args: Any,
            **kwargs: Any) -> None:
        """
        Schedules a request to create a connection to the given host and shard according to the policy.
        The request is never executed as part of this call; it always runs in a separate thread,
        so this method is non-blocking in that regard.

        ``host_id`` - an id of the host of the shard.
        ``shard_id`` - an id of the shard.
        ``method`` - a callable that creates connection and stores it in the connection pool.
        Currently, it is `HostConnection._open_connection_to_missing_shard`.
        ``*args`` and ``**kwargs`` are passed to ``method`` when policy executes it.
        """
        # Variadic parameters are annotated with the type of each element,
        # hence ``Any`` rather than ``List[Any]`` / ``dict[Any, Any]``.
        raise NotImplementedError()
904+
905+
906+
class ShardConnectionBackoffPolicy(ABC):
    """
    Abstract base for shard connection backoff policies.
    Such policies let the user control the pace at which new connections are established.

    Calling `new_scheduler` instantiates a scheduler that behaves according to the policy.
    """

    @abstractmethod
    def new_scheduler(self, session: Session) -> ShardConnectionScheduler:
        """
        Create the scheduler for ``session``. Concrete policies must override this.
        """
        raise NotImplementedError()
917+
918+
919+
class NoDelayShardConnectionBackoffPolicy(ShardConnectionBackoffPolicy):
    """
    A shard connection backoff policy that adds no delay between attempts and imposes
    no concurrency restrictions.
    It ensures that there is at most one pending connection per (host, shard) pair:
    a connection attempt for a pair that is already scheduled is silently dropped.

    Calling `new_scheduler` instantiates a scheduler that behaves according to the policy.
    """

    def new_scheduler(self, session: Session) -> ShardConnectionScheduler:
        """
        Create a no-delay scheduler bound to ``session``.
        """
        scheduler = _NoDelayShardConnectionBackoffScheduler(session)
        return scheduler
930+
931+
932+
class _NoDelayShardConnectionBackoffScheduler(ShardConnectionScheduler):
    """
    A scheduler for ``cassandra.policies.NoDelayShardConnectionBackoffPolicy``.
    It does not introduce any delay or concurrency restrictions.
    It only ensures that there is only one pending or scheduled connection per (host, shard) pair.
    """
    session: Session
    already_scheduled: dict[str, bool]
    lock: threading.Lock

    def __init__(self, session: Session):
        # Hold the session through a weak proxy so the scheduler does not keep it alive.
        self.session = weakref.proxy(session)
        # Maps '<host_id>-<shard_id>' to True while a request for that pair is scheduled or running.
        self.already_scheduled = {}
        # Guards ``already_scheduled``.
        self.lock = threading.Lock()

    def _execute(
            self,
            scheduled_key: str,
            method: Callable[..., None],
            *args: Any,
            **kwargs: Any) -> None:
        """
        Run ``method(*args, **kwargs)`` and release ``scheduled_key`` afterwards,
        whether ``method`` succeeded or raised.
        """
        try:
            method(*args, **kwargs)
        finally:
            with self.lock:
                self.already_scheduled[scheduled_key] = False

    def schedule(
            self,
            host_id: str,
            shard_id: int,
            method: Callable[..., None],
            *args: Any,
            **kwargs: Any) -> None:
        """
        Submit ``method(*args, **kwargs)`` to the session unless a request for the same
        (host, shard) pair is already scheduled or running, in which case the call is a no-op.
        Nothing is submitted when the session is shut down.
        """
        scheduled_key = f'{host_id}-{shard_id}'

        with self.lock:
            if self.already_scheduled.get(scheduled_key):
                return
            self.already_scheduled[scheduled_key] = True

        submitted = False
        try:
            if not self.session.is_shutdown:
                self.session.submit(self._execute, scheduled_key, method, *args, **kwargs)
                submitted = True
        finally:
            # ``_execute`` is the only other place that releases the key. If the task was
            # never handed over (session shut down, or submitting raised), release the key
            # here so it does not stay marked as scheduled forever.
            if not submitted:
                with self.lock:
                    self.already_scheduled[scheduled_key] = False
975+
976+
977+
class ShardConnectionBackoffScope(Enum):
    """
    Scope selector for ``cassandra.policies.ShardConnectionBackoffPolicy`` implementations,
    in particular ``cassandra.policies.LimitedConcurrencyShardConnectionBackoffPolicy``.

    The scope defines what the concurrency limit applies to. For instance,
    ``LimitedConcurrencyShardConnectionBackoffPolicy`` allows only N pending connections per scope,
    so with the scope set to ``Cluster`` only N connections per cluster are allowed.
    """
    # One shared limit for the whole cluster.
    Cluster = 0
    # A separate limit for every host.
    Host = 1
987+
988+
989+
class ShardConnectionBackoffSchedule(ABC):
    """
    Abstract source of delay schedules used for shard connection backoff.
    """

    @abstractmethod
    def new_schedule(self) -> Iterator[float]:
        """
        Must return a finite or infinite iterable of delays, each one a floating point
        number of seconds.
        If the iterable is finite, a new schedule is created right after it is exhausted.
        """
        raise NotImplementedError()
998+
999+
1000+
class LimitedConcurrencyShardConnectionBackoffPolicy(ShardConnectionBackoffPolicy):
    """
    A shard connection backoff policy that allows only ``max_concurrent`` concurrent connections per scope.
    Scope could be ``Host`` or ``Cluster``.
    For backoff calculation it needs a ``cassandra.policies.ShardConnectionBackoffSchedule`` or
    a ``cassandra.policies.ReconnectionPolicy``, since both share the same API.
    When there are no more scheduled connections the schedule for the backoff is reset.

    It also does not allow multiple pending or scheduled connections for the same (host, shard) pair;
    it silently drops attempts to schedule them.

    On ``new_scheduler`` instantiate a scheduler that behaves according to the policy.
    """
    scope: ShardConnectionBackoffScope
    backoff_policy: ShardConnectionBackoffSchedule | ReconnectionPolicy

    max_concurrent: int
    """
    Max concurrent connection creation requests per scope.
    """

    def __init__(
            self,
            scope: ShardConnectionBackoffScope,
            backoff_policy: ShardConnectionBackoffSchedule | ReconnectionPolicy,
            max_concurrent: int = 1,
    ):
        """
        ``scope`` - what the concurrency limit applies to.
        ``backoff_policy`` - provider of delay schedules (anything exposing ``new_schedule()``
        from the two accepted base classes).
        ``max_concurrent`` - positive integer limit of concurrent connection creation requests per scope.

        Raises ``ValueError`` when any argument has the wrong type or ``max_concurrent`` is below 1.
        """
        if not isinstance(scope, ShardConnectionBackoffScope):
            raise ValueError("scope must be a ShardConnectionBackoffScope")
        if not isinstance(backoff_policy, (ShardConnectionBackoffSchedule, ReconnectionPolicy)):
            raise ValueError("backoff_policy must be a ShardConnectionBackoffSchedule or ReconnectionPolicy")
        # The type check comes first so that non-numeric values produce the documented
        # ValueError instead of a TypeError from the comparison, and so that fractional
        # limits (e.g. 1.5, which would effectively allow 2) are rejected.
        if not isinstance(max_concurrent, int) or max_concurrent < 1:
            raise ValueError("max_concurrent must be a positive integer")
        self.scope = scope
        self.backoff_policy = backoff_policy
        self.max_concurrent = max_concurrent

    def new_scheduler(self, session: Session) -> ShardConnectionScheduler:
        """
        Create a scheduler for ``session`` configured with this policy's scope,
        backoff policy and concurrency limit.
        """
        return _LimitedConcurrencyShardConnectionScheduler(session, self.scope, self.backoff_policy, self.max_concurrent)
1039+
1040+
1041+
class CreateConnectionCallback:
    """
    A deferred call: bundles a callable with the positional and keyword arguments
    it should later be invoked with.
    """
    method: Callable[..., None]
    args: Tuple[Any, ...]
    kwargs: Dict[str, Any]

    def __init__(self, method: Callable[..., None], *args, **kwargs) -> None:
        """
        Remember ``method`` together with ``args`` and ``kwargs``; nothing is called here.
        """
        self.method, self.args, self.kwargs = method, args, kwargs
1050+
1051+
1052+
class _ScopeBucket:
1053+
"""
1054+
Holds information for a shard connection backoff policy scope, schedules and executes requests to create connection.
1055+
"""
1056+
session: Session
1057+
backoff_policy: ShardConnectionBackoffSchedule
1058+
lock: threading.Lock
1059+
1060+
schedule: Iterator[float]
1061+
"""
1062+
An iterable of delays in seconds generated by ``backoff_policy``.
1063+
"""
1064+
1065+
max_concurrent: int
1066+
"""
1067+
Max concurrent connection creation requests in the scope.
1068+
"""
1069+
1070+
currently_pending: int
1071+
"""
1072+
Currently pending connections.
1073+
"""
1074+
1075+
items: List[CreateConnectionCallback]
1076+
"""
1077+
Scheduled create connections requests.
1078+
"""
1079+
1080+
def __init__(
1081+
self,
1082+
session: Session,
1083+
backoff_policy: ShardConnectionBackoffSchedule,
1084+
max_concurrent: int,
1085+
):
1086+
self.items = []
1087+
self.session = session
1088+
self.backoff_policy = backoff_policy
1089+
self.lock = threading.Lock()
1090+
self.schedule = self.backoff_policy.new_schedule()
1091+
self.max_concurrent = max_concurrent
1092+
self.currently_pending = 0
1093+
1094+
def _get_delay(self) -> float:
1095+
try:
1096+
return next(self.schedule)
1097+
except StopIteration:
1098+
# A bit of trickery to avoid having lock around self.schedule
1099+
schedule = self.backoff_policy.new_schedule()
1100+
delay = next(schedule)
1101+
self.schedule = schedule
1102+
return delay
1103+
1104+
def _schedule(self):
1105+
if self.session.is_shutdown:
1106+
return
1107+
delay = self._get_delay()
1108+
if delay:
1109+
self.session.cluster.scheduler.schedule(delay, self._run)
1110+
else:
1111+
self.session.submit(self._run)
1112+
1113+
def _run(self):
1114+
if self.session.is_shutdown:
1115+
return
1116+
1117+
with self.lock:
1118+
try:
1119+
cb = self.items.pop()
1120+
except IndexError:
1121+
# Just in case
1122+
if self.currently_pending > 0:
1123+
self.currently_pending -= 1
1124+
# When items are exhausted reset schedule to ensure that new items going to get another schedule
1125+
# It is important for exponential policy
1126+
self.schedule = self.backoff_policy.new_schedule()
1127+
return
1128+
1129+
try:
1130+
cb.method(*cb.args, **cb.kwargs)
1131+
finally:
1132+
self._schedule()
1133+
1134+
def schedule_new_connection(self, cb: CreateConnectionCallback):
1135+
with self.lock:
1136+
self.items.append(cb)
1137+
if self.currently_pending < self.max_concurrent:
1138+
self.currently_pending += 1
1139+
self._schedule()
1140+
1141+
1142+
class _LimitedConcurrencyShardConnectionScheduler(ShardConnectionScheduler):
    """
    A scheduler for ``cassandra.policies.LimitedConcurrencyShardConnectionBackoffPolicy``.

    Limits concurrency for connection creation requests per scope to ``max_concurrent``.
    """

    already_scheduled: dict[str, bool]
    """
    Dict of (host, shard) flags, flag is true if there is connection creation request scheduled or
    currently running for given host and shard.
    """

    scopes: dict[str, _ScopeBucket]
    """
    Scopes storage, key is a scope key, value is an instance that holds scope data.
    """

    scope: ShardConnectionBackoffScope
    """
    Scope type.
    """

    backoff_policy: ShardConnectionBackoffSchedule
    session: Session
    lock: threading.Lock

    max_concurrent: int
    """
    Max concurrent connection creation requests per scope.
    """

    def __init__(
            self,
            session: Session,
            scope: ShardConnectionBackoffScope,
            backoff_policy: ShardConnectionBackoffSchedule,
            max_concurrent: int,
    ):
        self.already_scheduled = {}
        self.scopes = {}
        self.scope = scope
        self.backoff_policy = backoff_policy
        self.max_concurrent = max_concurrent
        self.session = session
        # Guards ``already_scheduled`` and ``scopes``.
        self.lock = threading.Lock()

    def _execute(self, scheduled_key: str, method: Callable[..., None], *args, **kwargs):
        """
        Run ``method(*args, **kwargs)`` and release ``scheduled_key`` afterwards,
        whether ``method`` succeeded or raised.
        """
        try:
            method(*args, **kwargs)
        finally:
            with self.lock:
                self.already_scheduled[scheduled_key] = False

    def schedule(self, host_id: str, shard_id: int, method: Callable[..., None], *args, **kwargs):
        """
        Queue ``method(*args, **kwargs)`` in the bucket of the scope that ``host_id`` belongs to.

        Returns ``True`` when the request was queued and ``False`` when a request for the same
        (host, shard) pair is already scheduled or running.
        Raises ``ValueError`` if the configured scope is neither ``Cluster`` nor ``Host``.
        """
        if self.scope == ShardConnectionBackoffScope.Cluster:
            scope_hash = "global-cluster-scope"
        elif self.scope == ShardConnectionBackoffScope.Host:
            scope_hash = host_id
        else:
            raise ValueError("scope must be Cluster or Host")

        scheduled_key = f'{host_id}-{shard_id}'

        with self.lock:
            if self.already_scheduled.get(scheduled_key):
                return False
            self.already_scheduled[scheduled_key] = True

            # Look up / create the bucket under the same lock: otherwise two threads could each
            # create a bucket for one scope, and requests would run through both of them,
            # exceeding ``max_concurrent`` for that scope.
            scope_info = self.scopes.get(scope_hash)
            if scope_info is None:
                scope_info = _ScopeBucket(self.session, self.backoff_policy, self.max_concurrent)
                self.scopes[scope_hash] = scope_info

        # The bucket has its own lock, so the scheduler lock is not held while queuing.
        scope_info.schedule_new_connection(CreateConnectionCallback(self._execute, scheduled_key, method, *args, **kwargs))
        return True
1217+
1218+
8671219
class RetryPolicy(object):
8681220
"""
8691221
A policy that describes whether to retry, rethrow, or ignore coordinator

tests/unit/test_host_connection_pool.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from cassandra.connection import Connection
2727
from cassandra.pool import HostConnection, HostConnectionPool
2828
from cassandra.pool import Host, NoConnectionsAvailable
29-
from cassandra.policies import HostDistance, SimpleConvictionPolicy
29+
from cassandra.policies import HostDistance, SimpleConvictionPolicy, _NoDelayShardConnectionBackoffScheduler
3030

3131
LOGGER = logging.getLogger(__name__)
3232

@@ -41,6 +41,8 @@ def make_session(self):
4141
session.cluster.get_core_connections_per_host.return_value = 1
4242
session.cluster.get_max_requests_per_connection.return_value = 1
4343
session.cluster.get_max_connections_per_host.return_value = 1
44+
session.shard_connection_backoff_scheduler = _NoDelayShardConnectionBackoffScheduler(session)
45+
session.is_shutdown = False
4446
return session
4547

4648
def test_borrow_and_return(self):

0 commit comments

Comments
 (0)