|
11 | 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 | 12 | # See the License for the specific language governing permissions and
|
13 | 13 | # limitations under the License.
|
| 14 | +from __future__ import annotations |
| 15 | + |
14 | 16 | import random
|
| 17 | +import threading |
| 18 | +import time |
| 19 | +import weakref |
| 20 | +from abc import ABC, abstractmethod |
15 | 21 |
|
16 | 22 | from collections import namedtuple
|
| 23 | +from enum import Enum |
17 | 24 | from functools import lru_cache
|
18 | 25 | from itertools import islice, cycle, groupby, repeat
|
19 | 26 | import logging
|
20 | 27 | from random import randint, shuffle
|
21 | 28 | from threading import Lock
|
22 | 29 | import socket
|
23 | 30 | import warnings
|
| 31 | +from typing import TYPE_CHECKING, Callable, Any, List, Tuple, Iterator, Optional |
24 | 32 |
|
25 | 33 | log = logging.getLogger(__name__)
|
26 | 34 |
|
27 | 35 | from cassandra import WriteType as WT
|
28 | 36 |
|
| 37 | +if TYPE_CHECKING: |
| 38 | + from cluster import Session |
29 | 39 |
|
30 | 40 | # This is done this way because WriteType was originally
|
31 | 41 | # defined here and in order not to break the API.
|
@@ -864,6 +874,228 @@ def _add_jitter(self, value):
|
864 | 874 | return min(max(self.base_delay, delay), self.max_delay)
|
865 | 875 |
|
866 | 876 |
|
class ShardReconnectionScheduler(ABC):
    """
    Interface for objects that schedule shard reconnection attempts.

    Concrete schedulers decide when, and whether, the supplied callable is
    run for a given (host, shard) pair.
    """

    @abstractmethod
    def schedule(
            self,
            host_id: str,
            shard_id: int,
            method: Callable[..., None],
            *args: Any,
            **kwargs: Any) -> None:
        """
        Ask the scheduler to run ``method(*args, **kwargs)`` for a shard.

        :param host_id: identifier of the host that owns the shard
        :param shard_id: shard number on that host
        :param method: callable the scheduler is asked to run
        :param args: positional arguments forwarded to `method`
        :param kwargs: keyword arguments forwarded to `method`
        """
        # Annotations on *args / **kwargs describe each individual argument,
        # not the tuple / dict that collects them, hence plain `Any`.
        raise NotImplementedError()
| 887 | + |
| 888 | + |
class ShardReconnectionPolicy(ABC):
    """
    Abstract base for shard reconnection policies.

    A policy is a factory: `new_scheduler` is expected to return a
    :class:`ShardReconnectionScheduler` that implements the policy's behaviour.
    """

    @abstractmethod
    def new_scheduler(
            self,
            session: Session) -> ShardReconnectionScheduler:
        """
        Build a scheduler for `session`.

        Subclasses must override this; the base implementation always raises
        `NotImplementedError`.
        """
        raise NotImplementedError()
| 899 | + |
| 900 | + |
class NoDelayShardReconnectionPolicy(ShardReconnectionPolicy):
    """
    Shard reconnection policy that adds no delay between attempts and places
    no limit on how many reconnections run concurrently.

    The scheduler it creates keeps at most one pending reconnection per
    (host, shard) pair; further requests for a pair that is already pending
    are silently dropped.
    """

    def new_scheduler(
            self,
            session: Session) -> ShardReconnectionScheduler:
        """Return a fresh no-delay scheduler bound to `session`."""
        scheduler = _NoDelayShardReconnectionScheduler(session)
        return scheduler
| 912 | + |
| 913 | + |
class _NoDelayShardReconnectionScheduler(ShardReconnectionScheduler):
    """
    Scheduler that hands each reconnection straight to ``session.submit``.

    At most one reconnection per (host, shard) pair is pending at a time;
    requests for a pair that is already pending are ignored.
    """
    session: Session
    already_scheduled: dict[str, bool]
    lock: threading.Lock

    def __init__(self, session: Session):
        # Only a weak proxy is kept, so the scheduler does not keep the
        # session alive.
        self.session = weakref.proxy(session)
        self.already_scheduled = {}
        # Guards `already_scheduled`: the check-then-set in `schedule` must
        # be atomic, otherwise two threads could both submit the same pair.
        self.lock = threading.Lock()

    def _execute(
            self,
            scheduled_key: str,
            method: Callable[..., None],
            *args: Any,
            **kwargs: Any) -> None:
        """
        Run `method` and then mark `scheduled_key` as no longer pending.

        The flag is cleared even when `method` raises, so the pair can be
        scheduled again.
        """
        try:
            method(*args, **kwargs)
        finally:
            with self.lock:
                self.already_scheduled[scheduled_key] = False

    def schedule(
            self,
            host_id: str,
            shard_id: int,
            method: Callable[..., None],
            *args: Any,
            **kwargs: Any) -> None:
        """
        Submit ``method(*args, **kwargs)`` unless the pair is already pending.

        Nothing is submitted, and nothing is marked pending, when the session
        reports `is_shutdown`.

        :param host_id: identifier of the host that owns the shard
        :param shard_id: shard number on that host
        :param method: callable to run through ``session.submit``
        """
        scheduled_key = f'{host_id}-{shard_id}'
        with self.lock:
            if self.already_scheduled.get(scheduled_key):
                return
            # Check shutdown before marking, so a request that is never
            # submitted does not leave the pair flagged as pending.
            if self.session.is_shutdown:
                return
            self.already_scheduled[scheduled_key] = True

        # Submit outside the lock: `_execute` takes the same lock when it
        # finishes.
        try:
            self.session.submit(self._execute, scheduled_key, method, *args, **kwargs)
        except Exception:
            # The task never started, so `_execute` will not clear the flag.
            with self.lock:
                self.already_scheduled[scheduled_key] = False
            raise
| 947 | + |
| 948 | + |
class ShardReconnectionPolicyScope(Enum):
    """
    Scope over which a `ShardReconnectionPolicy` applies its restriction;
    used in particular by `NoConcurrentShardReconnectionPolicy`.
    """
    # One shared scope for the whole cluster.
    Cluster = 0
    # A separate scope for every host.
    Host = 1
| 955 | + |
| 956 | + |
class NoConcurrentShardReconnectionPolicy(ShardReconnectionPolicy):
    """
    A shard reconnection policy that allows only one pending reconnection per
    scope, where the scope is either `Host` or `Cluster`.

    Delays between reconnections are taken from a `ReconnectionPolicy`; once
    a scope has no more reconnections queued, its backoff schedule is reset.
    In every scope, a second request for a host+shard that is already
    scheduled is silently ignored.

    `new_scheduler` returns a scheduler that behaves according to this policy.
    """
    shard_reconnection_scope: ShardReconnectionPolicyScope
    reconnection_policy: ReconnectionPolicy

    def __init__(
            self,
            shard_reconnection_scope: ShardReconnectionPolicyScope,
            reconnection_policy: ReconnectionPolicy,
    ):
        """
        :param shard_reconnection_scope: a `ShardReconnectionPolicyScope` member
        :param reconnection_policy: a `ReconnectionPolicy` supplying the delays
        :raises ValueError: if either argument is not of the expected type
        """
        # Validated in declaration order so the scope error is reported first.
        expectations = (
            (shard_reconnection_scope, ShardReconnectionPolicyScope,
             "shard_reconnection_scope must be a ShardReconnectionPolicyScope"),
            (reconnection_policy, ReconnectionPolicy,
             "reconnection_policy must be a ReconnectionPolicy"),
        )
        for value, expected_type, message in expectations:
            if not isinstance(value, expected_type):
                raise ValueError(message)

        self.shard_reconnection_scope = shard_reconnection_scope
        self.reconnection_policy = reconnection_policy

    def new_scheduler(self, session: Session) -> ShardReconnectionScheduler:
        """Return a scheduler for `session` configured with this policy's scope and backoff."""
        return _NoConcurrentShardReconnectionScheduler(
            session,
            self.shard_reconnection_scope,
            self.reconnection_policy,
        )
| 982 | + |
| 983 | + |
| 984 | +class _ScopeBucket: |
| 985 | + """ |
| 986 | + Holds information for a shard reconnection scope, schedules and executes reconnections. |
| 987 | + """ |
| 988 | + items: List[Tuple[Callable[..., None], Tuple[Any, ...], dict[str, Any]]] |
| 989 | + session: Session |
| 990 | + reconnection_policy: ReconnectionPolicy |
| 991 | + lock = threading.Lock |
| 992 | + schedule: Optional[Iterator[float]] |
| 993 | + |
| 994 | + running: bool = False |
| 995 | + |
| 996 | + def __init__( |
| 997 | + self, |
| 998 | + session: Session, |
| 999 | + reconnection_policy: ReconnectionPolicy, |
| 1000 | + ): |
| 1001 | + self.items = [] |
| 1002 | + self.session = session |
| 1003 | + self.reconnection_policy = reconnection_policy |
| 1004 | + self.lock = threading.Lock() |
| 1005 | + self.schedule = self.reconnection_policy.new_schedule() |
| 1006 | + |
| 1007 | + def _get_delay(self) -> float: |
| 1008 | + if self.schedule is None: |
| 1009 | + self.schedule = self.reconnection_policy.new_schedule() |
| 1010 | + try: |
| 1011 | + return next(self.schedule) |
| 1012 | + except StopIteration: |
| 1013 | + self.schedule = self.reconnection_policy.new_schedule() |
| 1014 | + return next(self.schedule) |
| 1015 | + |
| 1016 | + def _schedule(self): |
| 1017 | + if self.session.is_shutdown: |
| 1018 | + return |
| 1019 | + delay = self._get_delay() |
| 1020 | + if delay: |
| 1021 | + self.session.cluster.scheduler.schedule(delay, self._run) |
| 1022 | + else: |
| 1023 | + self.session.submit(self._run) |
| 1024 | + |
| 1025 | + def _run(self): |
| 1026 | + if self.session.is_shutdown: |
| 1027 | + return |
| 1028 | + |
| 1029 | + with self.lock: |
| 1030 | + try: |
| 1031 | + item = self.items.pop() |
| 1032 | + except IndexError: |
| 1033 | + self.running = False |
| 1034 | + self.schedule = None |
| 1035 | + return |
| 1036 | + |
| 1037 | + method, args, kwargs = item |
| 1038 | + try: |
| 1039 | + method(*args, **kwargs) |
| 1040 | + finally: |
| 1041 | + self._schedule() |
| 1042 | + |
| 1043 | + def add(self, method: Callable[..., None], *args, **kwargs): |
| 1044 | + with self.lock: |
| 1045 | + self.items.append((method, args, kwargs)) |
| 1046 | + if not self.running: |
| 1047 | + self.running = True |
| 1048 | + self._schedule() |
| 1049 | + |
| 1050 | + |
class _NoConcurrentShardReconnectionScheduler(ShardReconnectionScheduler):
    """
    Scheduler that routes every reconnection into a per-scope `_ScopeBucket`.

    With `ShardReconnectionPolicyScope.Cluster` all requests share a single
    bucket; otherwise each host id gets a bucket of its own. A host+shard that
    is already scheduled is not queued a second time.
    """
    already_scheduled: dict[str, bool]
    scopes: dict[str, _ScopeBucket]
    shard_reconnection_scope: ShardReconnectionPolicyScope
    reconnection_policy: ReconnectionPolicy
    session: Session
    lock: threading.Lock

    def __init__(
            self,
            session: Session,
            shard_reconnection_scope: ShardReconnectionPolicyScope,
            reconnection_policy: ReconnectionPolicy,
    ):
        self.session = session
        self.shard_reconnection_scope = shard_reconnection_scope
        self.reconnection_policy = reconnection_policy
        # Both dicts below are only touched while holding `lock`.
        self.lock = threading.Lock()
        self.already_scheduled = {}
        self.scopes = {}

    def _execute(self, scheduled_key: str, method: Callable[..., None], *args, **kwargs):
        """Run `method`, then clear the pending flag for `scheduled_key` even if it raised."""
        try:
            method(*args, **kwargs)
        finally:
            with self.lock:
                self.already_scheduled[scheduled_key] = False

    def schedule(self, host_id: str, shard_id: int, method: Callable[..., None], *args, **kwargs):
        """
        Queue ``method(*args, **kwargs)`` in the bucket of the matching scope.

        :param host_id: identifier of the host that owns the shard
        :param shard_id: shard number on that host
        :param method: callable to run once the bucket reaches it
        :return: False if this host+shard was already scheduled, True if the
            request was queued
        """
        cluster_wide = self.shard_reconnection_scope == ShardReconnectionPolicyScope.Cluster
        scope_hash = "global-cluster-scope" if cluster_wide else host_id
        scheduled_key = f'{host_id}-{shard_id}'

        with self.lock:
            if self.already_scheduled.get(scheduled_key):
                return False
            self.already_scheduled[scheduled_key] = True

            # Create the scope's bucket the first time the scope is seen.
            bucket = self.scopes.get(scope_hash)
            if bucket is None:
                bucket = _ScopeBucket(self.session, self.reconnection_policy)
                self.scopes[scope_hash] = bucket
            bucket.add(self._execute, scheduled_key, method, *args, **kwargs)
        return True
| 1097 | + |
| 1098 | + |
867 | 1099 | class RetryPolicy(object):
|
868 | 1100 | """
|
869 | 1101 | A policy that describes whether to retry, rethrow, or ignore coordinator
|
|
0 commit comments