Commit 8f3670e

feat(cluster): inject shard reconnection policy
Inject the shard connection backoff policy into Cluster, Session, Connection, and the host pool. Drop the pending-connections tracking logic, since the policy now handles it. Fix tests that mock Cluster, Session, Connection, or the host pool.
1 parent 323fb01 commit 8f3670e
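
To illustrate the new knob this change exposes, here is a minimal sketch of how an application could opt into it when constructing a Cluster. This is not code from the commit: the contact point is a placeholder, and the policy classes and arguments are taken from the test file added below (imported from cassandra.policies per its imports). When the argument is omitted, the commit falls back to NoDelayShardConnectionBackoffPolicy().

from cassandra.cluster import Cluster
from cassandra.policies import (ConstantReconnectionPolicy,
                                LimitedConcurrencyShardConnectionBackoffPolicy,
                                ShardConnectionBackoffScope)

# Sketch only: placeholder contact point; policy arguments mirror the new integration tests.
cluster = Cluster(
    contact_points=["127.0.0.1"],
    shard_connection_backoff_policy=LimitedConcurrencyShardConnectionBackoffPolicy(
        backoff_policy=ConstantReconnectionPolicy(0.1),  # wait 0.1 s between shard connection attempts
        max_concurrent=1,                                # at most one pending shard connection at a time
        scope=ShardConnectionBackoffScope.Cluster,       # the limit applies cluster-wide
    ),
)
session = cluster.connect()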

File tree

4 files changed: +163 -24 lines changed

cassandra/cluster.py
cassandra/pool.py
tests/integration/__init__.py
tests/integration/long/test_policies.py

cassandra/cluster.py

Lines changed: 18 additions & 2 deletions
@@ -73,7 +73,8 @@
                                  ExponentialReconnectionPolicy, HostDistance,
                                  RetryPolicy, IdentityTranslator, NoSpeculativeExecutionPlan,
                                  NoSpeculativeExecutionPolicy, DefaultLoadBalancingPolicy,
-                                 NeverRetryPolicy)
+                                 NeverRetryPolicy, ShardConnectionBackoffPolicy, NoDelayShardConnectionBackoffPolicy,
+                                 ShardConnectionScheduler)
 from cassandra.pool import (Host, _ReconnectionHandler, _HostReconnectionHandler,
                             HostConnectionPool, HostConnection,
                             NoConnectionsAvailable)
@@ -757,6 +758,11 @@ def auth_provider(self, value):
 
         self._auth_provider = value
 
+    _shard_connection_backoff_policy: ShardConnectionBackoffPolicy
+    @property
+    def shard_connection_backoff_policy(self) -> ShardConnectionBackoffPolicy:
+        return self._shard_connection_backoff_policy
+
     _load_balancing_policy = None
     @property
     def load_balancing_policy(self):
@@ -1219,7 +1225,8 @@ def __init__(self,
                  shard_aware_options=None,
                  metadata_request_timeout=None,
                  column_encryption_policy=None,
-                 application_info:Optional[ApplicationInfoBase]=None
+                 application_info: Optional[ApplicationInfoBase] = None,
+                 shard_connection_backoff_policy: Optional[ShardConnectionBackoffPolicy] = None,
                  ):
         """
         ``executor_threads`` defines the number of threads in a pool for handling asynchronous tasks such as
@@ -1325,6 +1332,13 @@ def __init__(self,
         else:
             self._load_balancing_policy = default_lbp_factory()  # set internal attribute to avoid committing to legacy config mode
 
+        if shard_connection_backoff_policy is not None:
+            if not isinstance(shard_connection_backoff_policy, ShardConnectionBackoffPolicy):
+                raise TypeError("shard_connection_backoff_policy should be an instance of class derived from ShardConnectionBackoffPolicy")
+            self._shard_connection_backoff_policy = shard_connection_backoff_policy
+        else:
+            self._shard_connection_backoff_policy = NoDelayShardConnectionBackoffPolicy()
+
         if reconnection_policy is not None:
             if isinstance(reconnection_policy, type):
                 raise TypeError("reconnection_policy should not be a class, it should be an instance of that class")
@@ -2716,6 +2730,7 @@ def default_serial_consistency_level(self, cl):
     _metrics = None
     _request_init_callbacks = None
     _graph_paging_available = False
+    shard_connection_backoff_scheduler: ShardConnectionScheduler
 
     def __init__(self, cluster, hosts, keyspace=None):
         self.cluster = cluster
@@ -2730,6 +2745,7 @@ def __init__(self, cluster, hosts, keyspace=None):
         self._protocol_version = self.cluster.protocol_version
 
         self.encoder = Encoder()
+        self.shard_connection_backoff_scheduler = cluster.shard_connection_backoff_policy.new_scheduler(self)
 
         # create connection pools in parallel
         self._initial_connect_futures = set()

cassandra/pool.py

Lines changed: 12 additions & 20 deletions
@@ -402,7 +402,6 @@ def __init__(self, host, host_distance, session):
         # this is used in conjunction with the connection streams. Not using the connection lock because the connection can be replaced in the lifetime of the pool.
         self._stream_available_condition = Condition(Lock())
         self._is_replacing = False
-        self._connecting = set()
         self._connections = {}
         self._pending_connections = []
         # A pool of additional connections which are not used but affect how Scylla
@@ -418,7 +417,6 @@ def __init__(self, host, host_distance, session):
         # and are waiting until all requests time out or complete
         # so that we can dispose of them.
         self._trash = set()
-        self._shard_connections_futures = []
         self.advanced_shardaware_block_until = 0
 
         if host_distance == HostDistance.IGNORED:
@@ -483,25 +481,25 @@ def _get_connection_for_routing_key(self, routing_key=None, keyspace=None, table
                     self.host,
                     routing_key
                 )
-            if conn.orphaned_threshold_reached and shard_id not in self._connecting:
+            if conn.orphaned_threshold_reached:
                 # The connection has met its orphaned stream ID limit
                 # and needs to be replaced. Start opening a connection
                 # to the same shard and replace when it is opened.
-                self._connecting.add(shard_id)
-                self._session.submit(self._open_connection_to_missing_shard, shard_id)
+                self._session.shard_connection_backoff_scheduler.schedule(
+                    self.host.host_id, shard_id, self._open_connection_to_missing_shard, shard_id)
                 log.debug(
-                    "Connection to shard_id=%i reached orphaned stream limit, replacing on host %s (%s/%i)",
+                    "Scheduling Connection to shard_id=%i reached orphaned stream limit, replacing on host %s (%s/%i)",
                     shard_id,
                     self.host,
                     len(self._connections.keys()),
                     self.host.sharding_info.shards_count
                 )
-            elif shard_id not in self._connecting:
+            else:
                 # rate controlled optimistic attempt to connect to a missing shard
-                self._connecting.add(shard_id)
-                self._session.submit(self._open_connection_to_missing_shard, shard_id)
+                self._session.shard_connection_backoff_scheduler.schedule(
+                    self.host.host_id, shard_id, self._open_connection_to_missing_shard, shard_id)
                 log.debug(
-                    "Trying to connect to missing shard_id=%i on host %s (%s/%i)",
+                    "Scheduling connection to missing shard_id=%i on host %s (%s/%i)",
                     shard_id,
                     self.host,
                     len(self._connections.keys()),
@@ -609,8 +607,8 @@ def _replace(self, connection):
             if connection.features.shard_id in self._connections.keys():
                 del self._connections[connection.features.shard_id]
             if self.host.sharding_info and not self._session.cluster.shard_aware_options.disable:
-                self._connecting.add(connection.features.shard_id)
-                self._session.submit(self._open_connection_to_missing_shard, connection.features.shard_id)
+                self._session.shard_connection_backoff_scheduler.schedule(
+                    self.host.host_id, connection.features.shard_id, self._open_connection_to_missing_shard, connection.features.shard_id)
         else:
             connection = self._session.cluster.connection_factory(self.host.endpoint,
                                                                   on_orphaned_stream_released=self.on_orphaned_stream_released)
@@ -635,9 +633,6 @@ def shutdown(self):
             with self._stream_available_condition:
                 self._stream_available_condition.notify_all()
 
-        for future in self._shard_connections_futures:
-            future.cancel()
-
         connections_to_close = self._connections.copy()
         pending_connections_to_close = self._pending_connections.copy()
         self._connections.clear()
@@ -843,7 +838,6 @@ def _open_connection_to_missing_shard(self, shard_id):
                 self._excess_connections.add(conn)
             if close_connection:
                 conn.close()
-        self._connecting.discard(shard_id)
 
     def _open_connections_for_all_shards(self, skip_shard_id=None):
         """
@@ -856,10 +850,8 @@ def _open_connections_for_all_shards(self, skip_shard_id=None):
             for shard_id in range(self.host.sharding_info.shards_count):
                 if skip_shard_id is not None and skip_shard_id == shard_id:
                     continue
-                future = self._session.submit(self._open_connection_to_missing_shard, shard_id)
-                if isinstance(future, Future):
-                    self._connecting.add(shard_id)
-                    self._shard_connections_futures.append(future)
+                self._session.shard_connection_backoff_scheduler.schedule(
+                    self.host.host_id, shard_id, self._open_connection_to_missing_shard, shard_id)
 
             trash_conns = None
             with self._lock:
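
Taken together, the pool now hands every shard connection attempt to the session's scheduler instead of submitting it directly and tracking it in _connecting / _shard_connections_futures. Based only on the calls visible in this diff (new_scheduler(session) on the policy, schedule(host_id, shard_id, method, *args) on the scheduler), a custom policy could look roughly like the sketch below; the exact base-class interfaces are an assumption drawn from these call sites, not code from this commit.

from cassandra.policies import ShardConnectionBackoffPolicy, ShardConnectionScheduler


class LoggingShardConnectionScheduler(ShardConnectionScheduler):
    # Illustrative scheduler: runs every attempt immediately, logging it first.
    def __init__(self, session):
        self._session = session

    def schedule(self, host_id, shard_id, method, *args):
        # host_id/shard_id identify the target; method is the pool callback,
        # e.g. HostConnection._open_connection_to_missing_shard.
        print("connecting to shard %s on host %s" % (shard_id, host_id))
        self._session.submit(method, *args)


class LoggingShardConnectionBackoffPolicy(ShardConnectionBackoffPolicy):
    def new_scheduler(self, session):
        # Called once per Session (see the cluster.py hunk above).
        return LoggingShardConnectionScheduler(session)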

tests/integration/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -396,6 +396,7 @@ def _id_and_mark(f):
                                      reason='Scylla does not support custom payloads. Cassandra requires native protocol v4.0+')
 xfail_scylla = lambda reason, *args, **kwargs: pytest.mark.xfail(SCYLLA_VERSION is not None, reason=reason, *args, **kwargs)
 incorrect_test = lambda reason='This test seems to be incorrect and should be fixed', *args, **kwargs: pytest.mark.xfail(reason=reason, *args, **kwargs)
+requires_scylla = pytest.mark.skipif(not SCYLLA_VERSION, reason='This test is designed for scylla only')
 
 pypy = unittest.skipUnless(platform.python_implementation() == "PyPy", "Test is skipped unless it's on PyPy")
 requiresmallclockgranularity = unittest.skipIf("Windows" in platform.system() or "asyncore" in EVENT_LOOP_MANAGER,
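
The new requires_scylla marker is applied like the existing skip/xfail helpers; a hypothetical usage (the test name and body are placeholders, not part of this commit):

from tests.integration import requires_scylla


@requires_scylla
def test_needs_shard_awareness():
    # Runs only when SCYLLA_VERSION is set; skipped on Cassandra runs.
    ...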

tests/integration/long/test_policies.py

Lines changed: 132 additions & 2 deletions
@@ -11,16 +11,22 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+import os
+import time
 import unittest
+from typing import Optional
 
 from cassandra import ConsistencyLevel, Unavailable
-from cassandra.cluster import ExecutionProfile, EXEC_PROFILE_DEFAULT
+from cassandra.cluster import ExecutionProfile, EXEC_PROFILE_DEFAULT, Session
+from cassandra.policies import LimitedConcurrencyShardConnectionBackoffPolicy, ShardConnectionBackoffScope, \
+    ConstantReconnectionPolicy, ShardConnectionBackoffPolicy, NoDelayShardConnectionBackoffPolicy
+from cassandra.shard_info import _ShardingInfo
 
 from tests.integration import use_cluster, get_cluster, get_node, TestCluster
 
 
 def setup_module():
+    os.environ['SCYLLA_EXT_OPTS'] = "--smp 4"
     use_cluster('test_cluster', [4])
 
@@ -65,3 +71,127 @@ def test_should_rethrow_on_unvailable_with_default_policy_if_cas(self):
         self.assertEqual(exception.consistency, ConsistencyLevel.SERIAL)
         self.assertEqual(exception.required_replicas, 2)
         self.assertEqual(exception.alive_replicas, 1)
+
+
+class ShardBackoffPolicyTests(unittest.TestCase):
+    @classmethod
+    def tearDownClass(cls):
+        cluster = get_cluster()
+        cluster.start(wait_for_binary_proto=True, wait_other_notice=True)  # make sure other nodes are restarted
+
+    def test_limited_concurrency_1_connection_per_cluster(self):
+        self._test_backoff(
+            LimitedConcurrencyShardConnectionBackoffPolicy(
+                backoff_policy=ConstantReconnectionPolicy(0.1),
+                max_concurrent=1,
+                scope=ShardConnectionBackoffScope.Cluster,
+            )
+        )
+
+    def test_limited_concurrency_2_connection_per_cluster(self):
+        self._test_backoff(
+            LimitedConcurrencyShardConnectionBackoffPolicy(
+                backoff_policy=ConstantReconnectionPolicy(0.1),
+                max_concurrent=2,
+                scope=ShardConnectionBackoffScope.Cluster,
+            )
+        )
+
+    def test_limited_concurrency_1_connection_per_host(self):
+        self._test_backoff(
+            LimitedConcurrencyShardConnectionBackoffPolicy(
+                backoff_policy=ConstantReconnectionPolicy(0.1),
+                max_concurrent=1,
+                scope=ShardConnectionBackoffScope.Host,
+            )
+        )
+
+    def test_limited_concurrency_2_connection_per_host(self):
+        self._test_backoff(
+            LimitedConcurrencyShardConnectionBackoffPolicy(
+                backoff_policy=ConstantReconnectionPolicy(0.1),
+                max_concurrent=1,
+                scope=ShardConnectionBackoffScope.Host,
+            )
+        )
+
+    def test_no_delay(self):
+        self._test_backoff(NoDelayShardConnectionBackoffPolicy())
+
+    def _test_backoff(self, shard_connection_backoff_policy: ShardConnectionBackoffPolicy):
+        backoff_policy = None
+        if isinstance(shard_connection_backoff_policy, LimitedConcurrencyShardConnectionBackoffPolicy):
+            backoff_policy = shard_connection_backoff_policy.backoff_policy
+
+        cluster = TestCluster(
+            shard_connection_backoff_policy=shard_connection_backoff_policy,
+            reconnection_policy=ConstantReconnectionPolicy(0),
+        )
+        session = cluster.connect()
+        sharding_info = get_sharding_info(session)
+
+        # even if backoff is set and there is no sharding info
+        # behavior should be the same as if there is no backoff policy
+        if not backoff_policy or not sharding_info:
+            time.sleep(2)
+            expected_connections = 1
+            if sharding_info:
+                expected_connections = sharding_info.shards_count
+            for host_id, connections_count in get_connections_per_host(session).items():
+                self.assertEqual(connections_count, expected_connections)
+            return
+
+        sleep_time = 0
+        schedule = backoff_policy.new_schedule()
+        # Calculate total time it will need to establish all connections
+        if shard_connection_backoff_policy.scope == ShardConnectionBackoffScope.Cluster:
+            for _ in session.hosts:
+                for _ in range(sharding_info.shards_count - 1):
+                    sleep_time += next(schedule)
+            sleep_time /= shard_connection_backoff_policy.max_concurrent
+        elif shard_connection_backoff_policy.scope == ShardConnectionBackoffScope.Host:
+            for _ in range(sharding_info.shards_count - 1):
+                sleep_time += next(schedule)
+            sleep_time /= shard_connection_backoff_policy.max_concurrent
+        else:
+            raise ValueError("Unknown scope {}".format(shard_connection_backoff_policy.scope))
+
+        time.sleep(sleep_time / 2)
+        self.assertFalse(
+            is_connection_filled(shard_connection_backoff_policy.scope, session, sharding_info.shards_count))
+        time.sleep(sleep_time / 2 + 1)
+        self.assertTrue(
+            is_connection_filled(shard_connection_backoff_policy.scope, session, sharding_info.shards_count))
+
+
+def is_connection_filled(scope: ShardConnectionBackoffScope, session: Session, shards_count: int) -> bool:
+    if scope == ShardConnectionBackoffScope.Cluster:
+        expected_connections = shards_count * len(session.hosts)
+        total_connections = sum(get_connections_per_host(session).values())
+        return expected_connections == total_connections
+    elif scope == ShardConnectionBackoffScope.Host:
+        expected_connections_per_host = shards_count
+        for connections_count in get_connections_per_host(session).values():
+            if connections_count < expected_connections_per_host:
+                return False
+            if connections_count == expected_connections_per_host:
+                continue
+            assert False, "Expected {} or less connections but got {}".format(expected_connections_per_host,
+                                                                              connections_count)
+        return True
+    else:
+        raise ValueError("Unknown scope {}".format(scope))
+
+
+def get_connections_per_host(session: Session) -> dict[str, int]:
+    host_connections = {}
+    for host, pool in session._pools.items():
+        host_connections[host.host_id] = len(pool._connections)
+    return host_connections
+
+
+def get_sharding_info(session: Session) -> Optional[_ShardingInfo]:
+    for host in session.hosts:
+        if host.sharding_info:
+            return host.sharding_info
+    return None
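
To make the timing arithmetic in _test_backoff concrete, here is the expected wait for one assumed configuration: a 4-node cluster started with --smp 4 (so 4 shards per host), LimitedConcurrencyShardConnectionBackoffPolicy with ConstantReconnectionPolicy(0.1), max_concurrent=1, and Cluster scope. The shard and host counts are assumptions about the CI environment, not values asserted by the test; the formula simply mirrors the test's own calculation.

hosts = 4               # use_cluster('test_cluster', [4])
shards_per_host = 4     # SCYLLA_EXT_OPTS="--smp 4"
delay = 0.1             # ConstantReconnectionPolicy(0.1)
max_concurrent = 1

# One connection per host already exists, so shards_per_host - 1 attempts remain per host;
# with Cluster scope the per-attempt delays are summed across all hosts, then divided
# by the number of concurrent connection slots.
sleep_time = hosts * (shards_per_host - 1) * delay / max_concurrent
print(sleep_time)  # 1.2 seconds until every pool is expected to be full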
