Skip to content

Commit b0fd069

Browse files
committed
feat(cluster): inject shard reconnection policy
Inject shard reconnection policy into cluster, session, connection and host pool Drop pending connections tracking logic, since policy does that. Fix some tests that mocks Cluster, session, connection or host pool.
1 parent 2f5ca3c commit b0fd069

File tree

4 files changed

+100
-42
lines changed

4 files changed

+100
-42
lines changed

cassandra/cluster.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from itertools import groupby, count, chain
3030
import json
3131
import logging
32+
from typing import Optional
3233
from warnings import warn
3334
from random import random
3435
import re
@@ -72,7 +73,8 @@
7273
ExponentialReconnectionPolicy, HostDistance,
7374
RetryPolicy, IdentityTranslator, NoSpeculativeExecutionPlan,
7475
NoSpeculativeExecutionPolicy, DefaultLoadBalancingPolicy,
75-
NeverRetryPolicy)
76+
NeverRetryPolicy, ShardReconnectionPolicy, NoDelayShardReconnectionPolicy,
77+
ShardReconnectionScheduler)
7678
from cassandra.pool import (Host, _ReconnectionHandler, _HostReconnectionHandler,
7779
HostConnectionPool, HostConnection,
7880
NoConnectionsAvailable)
@@ -742,6 +744,11 @@ def auth_provider(self, value):
742744

743745
self._auth_provider = value
744746

747+
_shard_reconnection_policy: ShardReconnectionPolicy
748+
@property
749+
def shard_reconnection_policy(self) -> ShardReconnectionPolicy:
750+
return self._shard_reconnection_policy
751+
745752
_load_balancing_policy = None
746753
@property
747754
def load_balancing_policy(self):
@@ -1204,6 +1211,7 @@ def __init__(self,
12041211
shard_aware_options=None,
12051212
metadata_request_timeout=None,
12061213
column_encryption_policy=None,
1214+
shard_reconnection_policy: Optional[ShardReconnectionPolicy] = None,
12071215
):
12081216
"""
12091217
``executor_threads`` defines the number of threads in a pool for handling asynchronous tasks such as
@@ -1309,6 +1317,13 @@ def __init__(self,
13091317
else:
13101318
self._load_balancing_policy = default_lbp_factory() # set internal attribute to avoid committing to legacy config mode
13111319

1320+
if shard_reconnection_policy is not None:
1321+
if not isinstance(shard_reconnection_policy, ShardReconnectionPolicy):
1322+
raise TypeError("shard_reconnection_policy should be an instance of class derived from ShardReconnectionPolicy")
1323+
self._shard_reconnection_policy = shard_reconnection_policy
1324+
else:
1325+
self._shard_reconnection_policy = NoDelayShardReconnectionPolicy()
1326+
13121327
if reconnection_policy is not None:
13131328
if isinstance(reconnection_policy, type):
13141329
raise TypeError("reconnection_policy should not be a class, it should be an instance of that class")
@@ -2693,6 +2708,7 @@ def default_serial_consistency_level(self, cl):
26932708
_metrics = None
26942709
_request_init_callbacks = None
26952710
_graph_paging_available = False
2711+
shard_reconnection_scheduler: ShardReconnectionScheduler
26962712

26972713
def __init__(self, cluster, hosts, keyspace=None):
26982714
self.cluster = cluster
@@ -2707,6 +2723,7 @@ def __init__(self, cluster, hosts, keyspace=None):
27072723
self._protocol_version = self.cluster.protocol_version
27082724

27092725
self.encoder = Encoder()
2726+
self.shard_reconnection_scheduler = cluster.shard_reconnection_policy.new_scheduler(self)
27102727

27112728
# create connection pools in parallel
27122729
self._initial_connect_futures = set()
@@ -4432,6 +4449,9 @@ def shutdown(self):
44324449
self._queue.put_nowait((0, 0, None))
44334450
self.join()
44344451

4452+
def empty(self):
4453+
return len(self._scheduled_tasks) == 0 and self._queue.empty()
4454+
44354455
def schedule(self, delay, fn, *args, **kwargs):
44364456
self._insert_task(delay, (fn, args, tuple(kwargs.items())))
44374457

cassandra/pool.py

Lines changed: 12 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -402,7 +402,6 @@ def __init__(self, host, host_distance, session):
402402
# this is used in conjunction with the connection streams. Not using the connection lock because the connection can be replaced in the lifetime of the pool.
403403
self._stream_available_condition = Condition(Lock())
404404
self._is_replacing = False
405-
self._connecting = set()
406405
self._connections = {}
407406
self._pending_connections = []
408407
# A pool of additional connections which are not used but affect how Scylla
@@ -418,7 +417,6 @@ def __init__(self, host, host_distance, session):
418417
# and are waiting until all requests time out or complete
419418
# so that we can dispose of them.
420419
self._trash = set()
421-
self._shard_connections_futures = []
422420
self.advanced_shardaware_block_until = 0
423421

424422
if host_distance == HostDistance.IGNORED:
@@ -483,25 +481,25 @@ def _get_connection_for_routing_key(self, routing_key=None, keyspace=None, table
483481
self.host,
484482
routing_key
485483
)
486-
if conn.orphaned_threshold_reached and shard_id not in self._connecting:
484+
if conn.orphaned_threshold_reached:
487485
# The connection has met its orphaned stream ID limit
488486
# and needs to be replaced. Start opening a connection
489487
# to the same shard and replace when it is opened.
490-
self._connecting.add(shard_id)
491-
self._session.submit(self._open_connection_to_missing_shard, shard_id)
488+
self._session.shard_reconnection_scheduler.schedule(
489+
self.host.host_id, shard_id, self._open_connection_to_missing_shard, shard_id)
492490
log.debug(
493-
"Connection to shard_id=%i reached orphaned stream limit, replacing on host %s (%s/%i)",
491+
"Scheduling Connection to shard_id=%i reached orphaned stream limit, replacing on host %s (%s/%i)",
494492
shard_id,
495493
self.host,
496494
len(self._connections.keys()),
497495
self.host.sharding_info.shards_count
498496
)
499-
elif shard_id not in self._connecting:
497+
else:
500498
# rate controlled optimistic attempt to connect to a missing shard
501-
self._connecting.add(shard_id)
502-
self._session.submit(self._open_connection_to_missing_shard, shard_id)
499+
self._session.shard_reconnection_scheduler.schedule(
500+
self.host.host_id, shard_id, self._open_connection_to_missing_shard, shard_id)
503501
log.debug(
504-
"Trying to connect to missing shard_id=%i on host %s (%s/%i)",
502+
"Scheduling connection to missing shard_id=%i on host %s (%s/%i)",
505503
shard_id,
506504
self.host,
507505
len(self._connections.keys()),
@@ -609,8 +607,8 @@ def _replace(self, connection):
609607
if connection.features.shard_id in self._connections.keys():
610608
del self._connections[connection.features.shard_id]
611609
if self.host.sharding_info and not self._session.cluster.shard_aware_options.disable:
612-
self._connecting.add(connection.features.shard_id)
613-
self._session.submit(self._open_connection_to_missing_shard, connection.features.shard_id)
610+
self._session.shard_reconnection_scheduler.schedule(
611+
self.host.host_id, connection.features.shard_id, self._open_connection_to_missing_shard, connection.features.shard_id)
614612
else:
615613
connection = self._session.cluster.connection_factory(self.host.endpoint,
616614
on_orphaned_stream_released=self.on_orphaned_stream_released)
@@ -635,9 +633,6 @@ def shutdown(self):
635633
with self._stream_available_condition:
636634
self._stream_available_condition.notify_all()
637635

638-
for future in self._shard_connections_futures:
639-
future.cancel()
640-
641636
connections_to_close = self._connections.copy()
642637
pending_connections_to_close = self._pending_connections.copy()
643638
self._connections.clear()
@@ -843,7 +838,6 @@ def _open_connection_to_missing_shard(self, shard_id):
843838
self._excess_connections.add(conn)
844839
if close_connection:
845840
conn.close()
846-
self._connecting.discard(shard_id)
847841

848842
def _open_connections_for_all_shards(self, skip_shard_id=None):
849843
"""
@@ -856,10 +850,8 @@ def _open_connections_for_all_shards(self, skip_shard_id=None):
856850
for shard_id in range(self.host.sharding_info.shards_count):
857851
if skip_shard_id is not None and skip_shard_id == shard_id:
858852
continue
859-
future = self._session.submit(self._open_connection_to_missing_shard, shard_id)
860-
if isinstance(future, Future):
861-
self._connecting.add(shard_id)
862-
self._shard_connections_futures.append(future)
853+
self._session.shard_reconnection_scheduler.schedule(
854+
self.host.host_id, shard_id, self._open_connection_to_missing_shard, shard_id)
863855

864856
trash_conns = None
865857
with self._lock:

tests/unit/test_host_connection_pool.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from cassandra.connection import Connection
2727
from cassandra.pool import HostConnection, HostConnectionPool
2828
from cassandra.pool import Host, NoConnectionsAvailable
29-
from cassandra.policies import HostDistance, SimpleConvictionPolicy
29+
from cassandra.policies import HostDistance, SimpleConvictionPolicy, _NoDelayShardReconnectionScheduler
3030

3131
LOGGER = logging.getLogger(__name__)
3232

@@ -41,6 +41,8 @@ def make_session(self):
4141
session.cluster.get_core_connections_per_host.return_value = 1
4242
session.cluster.get_max_requests_per_connection.return_value = 1
4343
session.cluster.get_max_connections_per_host.return_value = 1
44+
session.shard_reconnection_scheduler = _NoDelayShardReconnectionScheduler(session)
45+
session.is_shutdown = False
4446
return session
4547

4648
def test_borrow_and_return(self):

tests/unit/test_shard_aware.py

Lines changed: 64 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -11,17 +11,21 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import uuid
1415

1516
try:
1617
import unittest2 as unittest
1718
except ImportError:
1819
import unittest # noqa
1920

21+
import time
2022
import logging
2123
from mock import MagicMock
2224
from concurrent.futures import ThreadPoolExecutor
2325

24-
from cassandra.cluster import ShardAwareOptions
26+
from cassandra.cluster import ShardAwareOptions, _Scheduler
27+
from cassandra.policies import ConstantReconnectionPolicy, \
28+
NoDelayShardReconnectionPolicy, NoConcurrentShardReconnectionPolicy, ShardReconnectionPolicyScope
2529
from cassandra.pool import HostConnection, HostDistance
2630
from cassandra.connection import ShardingInfo, DefaultEndPoint
2731
from cassandra.metadata import Murmur3Token
@@ -53,7 +57,15 @@ class OptionsHolder(object):
5357
self.assertEqual(shard_info.shard_id_from_token(Murmur3Token.from_key(b"e").value), 4)
5458
self.assertEqual(shard_info.shard_id_from_token(Murmur3Token.from_key(b"100000").value), 2)
5559

56-
def test_advanced_shard_aware_port(self):
60+
def test_shard_aware_reconnection_policy_no_delay(self):
61+
# with NoDelayReconnectionPolicy all the connections should be created right away
62+
self._test_shard_aware_reconnection_policy(4, NoDelayShardReconnectionPolicy(), 4, 4)
63+
64+
def test_shard_aware_reconnection_policy_delay(self):
65+
# with ConstantReconnectionPolicy first connection is created right away, others are delayed
66+
self._test_shard_aware_reconnection_policy(4, NoConcurrentShardReconnectionPolicy(ShardReconnectionPolicyScope.Cluster, ConstantReconnectionPolicy(1)), 1, 4)
67+
68+
def _test_shard_aware_reconnection_policy(self, shard_count, shard_reconnection_policy, expected_count, expected_after):
5769
"""
5870
Test that on given a `shard_aware_port` on the OPTIONS message (ShardInfo class)
5971
the next connections would be open using this port
@@ -71,17 +83,25 @@ def __init__(self, is_ssl=False, *args, **kwargs):
7183
self.cluster.ssl_options = None
7284
self.cluster.shard_aware_options = ShardAwareOptions()
7385
self.cluster.executor = ThreadPoolExecutor(max_workers=2)
86+
self._executor_submit_original = self.cluster.executor.submit
87+
self.cluster.executor.submit = self._executor_submit
88+
self.cluster.scheduler = _Scheduler(self.cluster.executor)
7489
self.cluster.signal_connection_failure = lambda *args, **kwargs: False
7590
self.cluster.connection_factory = self.mock_connection_factory
7691
self.connection_counter = 0
92+
self.shard_reconnection_scheduler = shard_reconnection_policy.new_scheduler(self)
7793
self.futures = []
7894

7995
def submit(self, fn, *args, **kwargs):
96+
if self.is_shutdown:
97+
return None
98+
return self.cluster.executor.submit(fn, *args, **kwargs)
99+
100+
def _executor_submit(self, fn, *args, **kwargs):
80101
logging.info("Scheduling %s with args: %s, kwargs: %s", fn, args, kwargs)
81-
if not self.is_shutdown:
82-
f = self.cluster.executor.submit(fn, *args, **kwargs)
83-
self.futures += [f]
84-
return f
102+
f = self._executor_submit_original(fn, *args, **kwargs)
103+
self.futures += [f]
104+
return f
85105

86106
def mock_connection_factory(self, *args, **kwargs):
87107
connection = MagicMock()
@@ -90,26 +110,50 @@ def mock_connection_factory(self, *args, **kwargs):
90110
connection.is_closed = False
91111
connection.orphaned_threshold_reached = False
92112
connection.endpoint = args[0]
93-
sharding_info = ShardingInfo(shard_id=1, shards_count=4, partitioner="", sharding_algorithm="", sharding_ignore_msb=0, shard_aware_port=19042, shard_aware_port_ssl=19045)
113+
sharding_info = ShardingInfo(shard_id=1, shards_count=shard_count, partitioner="", sharding_algorithm="", sharding_ignore_msb=0, shard_aware_port=19042, shard_aware_port_ssl=19045)
94114
connection.features = ProtocolFeatures(shard_id=kwargs.get('shard_id', self.connection_counter), sharding_info=sharding_info)
95115
self.connection_counter += 1
96116

97117
return connection
98118

99119
host = MagicMock()
120+
host.host_id = uuid.uuid4()
100121
host.endpoint = DefaultEndPoint("1.2.3.4")
122+
session = None
123+
reconnection_policy = None
124+
if isinstance(shard_reconnection_policy, NoConcurrentShardReconnectionPolicy):
125+
reconnection_policy = shard_reconnection_policy.reconnection_policy
126+
try:
127+
for port, is_ssl in [(19042, False), (19045, True)]:
128+
session = MockSession(is_ssl=is_ssl)
129+
pool = HostConnection(host=host, host_distance=HostDistance.REMOTE, session=session)
130+
for f in session.futures:
131+
f.result()
132+
assert len(pool._connections) == expected_count
133+
for shard_id, connection in pool._connections.items():
134+
assert connection.features.shard_id == shard_id
135+
if shard_id == 0:
136+
assert connection.endpoint == DefaultEndPoint("1.2.3.4")
137+
else:
138+
assert connection.endpoint == DefaultEndPoint("1.2.3.4", port=port)
101139

102-
for port, is_ssl in [(19042, False), (19045, True)]:
103-
session = MockSession(is_ssl=is_ssl)
104-
pool = HostConnection(host=host, host_distance=HostDistance.REMOTE, session=session)
105-
for f in session.futures:
106-
f.result()
107-
assert len(pool._connections) == 4
108-
for shard_id, connection in pool._connections.items():
109-
assert connection.features.shard_id == shard_id
110-
if shard_id == 0:
111-
assert connection.endpoint == DefaultEndPoint("1.2.3.4")
112-
else:
113-
assert connection.endpoint == DefaultEndPoint("1.2.3.4", port=port)
140+
sleep_time = 0
141+
if reconnection_policy:
142+
# Check that connections to shards are being established according to the policy
143+
# Calculate total time it will need to establish all connections
144+
# Sleep half of the time and check that connections are not there yet
145+
# Sleep rest of the time + 1 second and check that all connections has been established
146+
schedule = reconnection_policy.new_schedule()
147+
for _ in range(shard_count):
148+
sleep_time += next(schedule)
149+
if sleep_time > 0:
150+
time.sleep(sleep_time/2)
151+
# Check that connection are not being established quicker than expected
152+
assert len(pool._connections) < expected_after
153+
time.sleep(sleep_time/2 + 1)
114154

115-
session.cluster.executor.shutdown(wait=True)
155+
assert len(pool._connections) == expected_after
156+
finally:
157+
if session:
158+
session.cluster.scheduler.shutdown()
159+
session.cluster.executor.shutdown(wait=True)

0 commit comments

Comments
 (0)