Commit 9dfd9ec

feat(cluster): inject shard reconnection policy
Inject the shard reconnection policy into the cluster, session, connection, and host pool. Drop the pending-connections tracking logic, since the policy now handles it. Fix tests that mock the cluster, session, connection, or host pool.
1 parent b8e5b13 commit 9dfd9ec
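
A minimal usage sketch of the new constructor argument (not part of the commit). It is based only on the names this commit introduces and on the policy construction used in tests/unit/test_shard_aware.py below; the contact point is illustrative, and the exact throttling semantics of NoConcurrentShardReconnectionPolicy are not shown in this diff.

from cassandra.cluster import Cluster
from cassandra.policies import (ConstantReconnectionPolicy,
                                NoConcurrentShardReconnectionPolicy,
                                ShardReconnectionPolicyScope)

# Policy constructed exactly as in the unit test below: cluster-wide scope,
# with a constant 1-second delay between per-shard reconnection attempts.
policy = NoConcurrentShardReconnectionPolicy(ShardReconnectionPolicyScope.Cluster,
                                             ConstantReconnectionPolicy(1))

cluster = Cluster(contact_points=["127.0.0.1"],  # illustrative contact point
                  shard_reconnection_policy=policy)
session = cluster.connect()

# Omitting shard_reconnection_policy falls back to NoDelayShardReconnectionPolicy(),
# and passing anything that is not a ShardReconnectionPolicy instance raises TypeError.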

4 files changed: 100 additions, 43 deletions

cassandra/cluster.py

Lines changed: 21 additions & 2 deletions

@@ -73,7 +73,8 @@
                                 ExponentialReconnectionPolicy, HostDistance,
                                 RetryPolicy, IdentityTranslator, NoSpeculativeExecutionPlan,
                                 NoSpeculativeExecutionPolicy, DefaultLoadBalancingPolicy,
-                                NeverRetryPolicy)
+                                NeverRetryPolicy, ShardReconnectionPolicy, NoDelayShardReconnectionPolicy,
+                                ShardReconnectionScheduler)
 from cassandra.pool import (Host, _ReconnectionHandler, _HostReconnectionHandler,
                             HostConnectionPool, HostConnection,
                             NoConnectionsAvailable)
@@ -757,6 +758,11 @@ def auth_provider(self, value):

         self._auth_provider = value

+    _shard_reconnection_policy: ShardReconnectionPolicy
+    @property
+    def shard_reconnection_policy(self) -> ShardReconnectionPolicy:
+        return self._shard_reconnection_policy
+
     _load_balancing_policy = None
     @property
     def load_balancing_policy(self):
@@ -1219,7 +1225,8 @@ def __init__(self,
                  shard_aware_options=None,
                  metadata_request_timeout=None,
                  column_encryption_policy=None,
-                 application_info:Optional[ApplicationInfoBase]=None
+                 application_info: Optional[ApplicationInfoBase] = None,
+                 shard_reconnection_policy: Optional[ShardReconnectionPolicy] = None,
                  ):
         """
         ``executor_threads`` defines the number of threads in a pool for handling asynchronous tasks such as
@@ -1325,6 +1332,13 @@ def __init__(self,
         else:
             self._load_balancing_policy = default_lbp_factory()  # set internal attribute to avoid committing to legacy config mode

+        if shard_reconnection_policy is not None:
+            if not isinstance(shard_reconnection_policy, ShardReconnectionPolicy):
+                raise TypeError("shard_reconnection_policy should be an instance of class derived from ShardReconnectionPolicy")
+            self._shard_reconnection_policy = shard_reconnection_policy
+        else:
+            self._shard_reconnection_policy = NoDelayShardReconnectionPolicy()
+
         if reconnection_policy is not None:
             if isinstance(reconnection_policy, type):
                 raise TypeError("reconnection_policy should not be a class, it should be an instance of that class")
@@ -2716,6 +2730,7 @@ def default_serial_consistency_level(self, cl):
     _metrics = None
     _request_init_callbacks = None
     _graph_paging_available = False
+    shard_reconnection_scheduler: ShardReconnectionScheduler

     def __init__(self, cluster, hosts, keyspace=None):
         self.cluster = cluster
@@ -2730,6 +2745,7 @@ def __init__(self, cluster, hosts, keyspace=None):
         self._protocol_version = self.cluster.protocol_version

         self.encoder = Encoder()
+        self.shard_reconnection_scheduler = cluster.shard_reconnection_policy.new_scheduler(self)

         # create connection pools in parallel
         self._initial_connect_futures = set()
@@ -4455,6 +4471,9 @@ def shutdown(self):
         self._queue.put_nowait((0, 0, None))
         self.join()

+    def empty(self):
+        return len(self._scheduled_tasks) == 0 and self._queue.empty()
+
     def schedule(self, delay, fn, *args, **kwargs):
         self._insert_task(delay, (fn, args, tuple(kwargs.items())))
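
In short, the cluster-level policy produces one scheduler per Session, and pools route per-shard reconnection work through it. A sketch of that wiring (not part of the commit), using only the calls visible in this diff; the mock session stands in for a real one, the explicit is_shutdown = False mirrors the updated unit tests, and the no-delay scheduler is assumed to simply forward the callback via session.submit:

from unittest.mock import MagicMock
from cassandra.policies import NoDelayShardReconnectionPolicy

session = MagicMock()
session.is_shutdown = False  # mirrors the unit-test setup; a bare Mock attribute is truthy

# Equivalent of the new line in Session.__init__; NoDelayShardReconnectionPolicy is the default.
policy = NoDelayShardReconnectionPolicy()
session.shard_reconnection_scheduler = policy.new_scheduler(session)

# HostConnection now hands reconnection work to the scheduler, keyed by host and shard,
# instead of submitting to the executor directly (signature taken from the pool.py call sites).
session.shard_reconnection_scheduler.schedule(
    "host-id", 0, lambda shard_id: None, 0)  # illustrative host_id and no-op callback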

cassandra/pool.py

Lines changed: 12 additions & 20 deletions

@@ -402,7 +402,6 @@ def __init__(self, host, host_distance, session):
         # this is used in conjunction with the connection streams. Not using the connection lock because the connection can be replaced in the lifetime of the pool.
         self._stream_available_condition = Condition(Lock())
         self._is_replacing = False
-        self._connecting = set()
         self._connections = {}
         self._pending_connections = []
         # A pool of additional connections which are not used but affect how Scylla
@@ -418,7 +417,6 @@ def __init__(self, host, host_distance, session):
         # and are waiting until all requests time out or complete
         # so that we can dispose of them.
         self._trash = set()
-        self._shard_connections_futures = []
         self.advanced_shardaware_block_until = 0

         if host_distance == HostDistance.IGNORED:
@@ -483,25 +481,25 @@ def _get_connection_for_routing_key(self, routing_key=None, keyspace=None, table
                     self.host,
                     routing_key
                 )
-            if conn.orphaned_threshold_reached and shard_id not in self._connecting:
+            if conn.orphaned_threshold_reached:
                 # The connection has met its orphaned stream ID limit
                 # and needs to be replaced. Start opening a connection
                 # to the same shard and replace when it is opened.
-                self._connecting.add(shard_id)
-                self._session.submit(self._open_connection_to_missing_shard, shard_id)
+                self._session.shard_reconnection_scheduler.schedule(
+                    self.host.host_id, shard_id, self._open_connection_to_missing_shard, shard_id)
                 log.debug(
-                    "Connection to shard_id=%i reached orphaned stream limit, replacing on host %s (%s/%i)",
+                    "Scheduling Connection to shard_id=%i reached orphaned stream limit, replacing on host %s (%s/%i)",
                     shard_id,
                     self.host,
                     len(self._connections.keys()),
                     self.host.sharding_info.shards_count
                 )
-            elif shard_id not in self._connecting:
+            else:
                 # rate controlled optimistic attempt to connect to a missing shard
-                self._connecting.add(shard_id)
-                self._session.submit(self._open_connection_to_missing_shard, shard_id)
+                self._session.shard_reconnection_scheduler.schedule(
+                    self.host.host_id, shard_id, self._open_connection_to_missing_shard, shard_id)
                 log.debug(
-                    "Trying to connect to missing shard_id=%i on host %s (%s/%i)",
+                    "Scheduling connection to missing shard_id=%i on host %s (%s/%i)",
                     shard_id,
                     self.host,
                     len(self._connections.keys()),
@@ -609,8 +607,8 @@ def _replace(self, connection):
             if connection.features.shard_id in self._connections.keys():
                 del self._connections[connection.features.shard_id]
             if self.host.sharding_info and not self._session.cluster.shard_aware_options.disable:
-                self._connecting.add(connection.features.shard_id)
-                self._session.submit(self._open_connection_to_missing_shard, connection.features.shard_id)
+                self._session.shard_reconnection_scheduler.schedule(
+                    self.host.host_id, connection.features.shard_id, self._open_connection_to_missing_shard, connection.features.shard_id)
             else:
                 connection = self._session.cluster.connection_factory(self.host.endpoint,
                                                                       on_orphaned_stream_released=self.on_orphaned_stream_released)
@@ -635,9 +633,6 @@ def shutdown(self):
             with self._stream_available_condition:
                 self._stream_available_condition.notify_all()

-        for future in self._shard_connections_futures:
-            future.cancel()
-
         connections_to_close = self._connections.copy()
         pending_connections_to_close = self._pending_connections.copy()
         self._connections.clear()
@@ -843,7 +838,6 @@ def _open_connection_to_missing_shard(self, shard_id):
                     self._excess_connections.add(conn)
                 if close_connection:
                     conn.close()
-        self._connecting.discard(shard_id)

     def _open_connections_for_all_shards(self, skip_shard_id=None):
         """
@@ -856,10 +850,8 @@ def _open_connections_for_all_shards(self, skip_shard_id=None):
         for shard_id in range(self.host.sharding_info.shards_count):
            if skip_shard_id is not None and skip_shard_id == shard_id:
                continue
-            future = self._session.submit(self._open_connection_to_missing_shard, shard_id)
-            if isinstance(future, Future):
-                self._connecting.add(shard_id)
-                self._shard_connections_futures.append(future)
+            self._session.shard_reconnection_scheduler.schedule(
+                self.host.host_id, shard_id, self._open_connection_to_missing_shard, shard_id)

         trash_conns = None
         with self._lock:
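
The removed _connecting set and _shard_connections_futures list previously de-duplicated and tracked in-flight shard connections inside the pool; that responsibility now sits behind shard_reconnection_scheduler.schedule(host_id, shard_id, fn, *args). The class below is a hypothetical, simplified illustration of how a (host_id, shard_id) key can provide the same de-duplication; it is not the driver's _NoDelayShardReconnectionScheduler or NoConcurrentShardReconnectionPolicy implementation.

import threading


class DedupShardReconnectionScheduler:
    """Illustrative only: run at most one pending reconnection per (host_id, shard_id)."""

    def __init__(self, session):
        self._session = session
        self._pending = set()
        self._lock = threading.Lock()

    def schedule(self, host_id, shard_id, fn, *args, **kwargs):
        key = (host_id, shard_id)
        with self._lock:
            if key in self._pending:   # a task for this shard is already queued; drop the duplicate
                return
            self._pending.add(key)

        def run():
            try:
                fn(*args, **kwargs)
            finally:
                with self._lock:       # allow the next reconnection attempt for this shard
                    self._pending.discard(key)

        self._session.submit(run)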

tests/unit/test_host_connection_pool.py

Lines changed: 3 additions & 1 deletion

@@ -26,7 +26,7 @@
 from cassandra.connection import Connection
 from cassandra.pool import HostConnection, HostConnectionPool
 from cassandra.pool import Host, NoConnectionsAvailable
-from cassandra.policies import HostDistance, SimpleConvictionPolicy
+from cassandra.policies import HostDistance, SimpleConvictionPolicy, _NoDelayShardReconnectionScheduler

 LOGGER = logging.getLogger(__name__)

@@ -41,6 +41,8 @@ def make_session(self):
         session.cluster.get_core_connections_per_host.return_value = 1
         session.cluster.get_max_requests_per_connection.return_value = 1
         session.cluster.get_max_connections_per_host.return_value = 1
+        session.shard_reconnection_scheduler = _NoDelayShardReconnectionScheduler(session)
+        session.is_shutdown = False
         return session

     def test_borrow_and_return(self):
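
The two extra attributes matter because the pool now reaches for session.shard_reconnection_scheduler, and because an unset attribute on a mock is itself a truthy Mock object, so any code guarded by the session's is_shutdown flag would otherwise see a truthy value. A small illustration of the Mock pitfall (a sketch using the standard-library mock; the test module itself imports from the mock package):

from unittest.mock import NonCallableMagicMock

session = NonCallableMagicMock()
assert bool(session.is_shutdown) is True    # auto-created attribute: a truthy Mock

session.is_shutdown = False                 # pin it, as the updated test fixture does
assert bool(session.is_shutdown) is False   # now `if session.is_shutdown:` behaves like a live session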

tests/unit/test_shard_aware.py

Lines changed: 64 additions & 20 deletions

@@ -11,17 +11,21 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import uuid

 try:
     import unittest2 as unittest
 except ImportError:
     import unittest  # noqa

+import time
 import logging
 from mock import MagicMock
 from concurrent.futures import ThreadPoolExecutor

-from cassandra.cluster import ShardAwareOptions
+from cassandra.cluster import ShardAwareOptions, _Scheduler
+from cassandra.policies import ConstantReconnectionPolicy, \
+    NoDelayShardReconnectionPolicy, NoConcurrentShardReconnectionPolicy, ShardReconnectionPolicyScope
 from cassandra.pool import HostConnection, HostDistance
 from cassandra.connection import ShardingInfo, DefaultEndPoint
 from cassandra.metadata import Murmur3Token
@@ -53,7 +57,15 @@ class OptionsHolder(object):
         self.assertEqual(shard_info.shard_id_from_token(Murmur3Token.from_key(b"e").value), 4)
         self.assertEqual(shard_info.shard_id_from_token(Murmur3Token.from_key(b"100000").value), 2)

-    def test_advanced_shard_aware_port(self):
+    def test_shard_aware_reconnection_policy_no_delay(self):
+        # with NoDelayReconnectionPolicy all the connections should be created right away
+        self._test_shard_aware_reconnection_policy(4, NoDelayShardReconnectionPolicy(), 4, 4)
+
+    def test_shard_aware_reconnection_policy_delay(self):
+        # with ConstantReconnectionPolicy first connection is created right away, others are delayed
+        self._test_shard_aware_reconnection_policy(4, NoConcurrentShardReconnectionPolicy(ShardReconnectionPolicyScope.Cluster, ConstantReconnectionPolicy(1)), 1, 4)
+
+    def _test_shard_aware_reconnection_policy(self, shard_count, shard_reconnection_policy, expected_count, expected_after):
         """
         Test that on given a `shard_aware_port` on the OPTIONS message (ShardInfo class)
         the next connections would be open using this port
@@ -71,17 +83,25 @@ def __init__(self, is_ssl=False, *args, **kwargs):
                 self.cluster.ssl_options = None
                 self.cluster.shard_aware_options = ShardAwareOptions()
                 self.cluster.executor = ThreadPoolExecutor(max_workers=2)
+                self._executor_submit_original = self.cluster.executor.submit
+                self.cluster.executor.submit = self._executor_submit
+                self.cluster.scheduler = _Scheduler(self.cluster.executor)
                 self.cluster.signal_connection_failure = lambda *args, **kwargs: False
                 self.cluster.connection_factory = self.mock_connection_factory
                 self.connection_counter = 0
+                self.shard_reconnection_scheduler = shard_reconnection_policy.new_scheduler(self)
                 self.futures = []

             def submit(self, fn, *args, **kwargs):
+                if self.is_shutdown:
+                    return None
+                return self.cluster.executor.submit(fn, *args, **kwargs)
+
+            def _executor_submit(self, fn, *args, **kwargs):
                 logging.info("Scheduling %s with args: %s, kwargs: %s", fn, args, kwargs)
-                if not self.is_shutdown:
-                    f = self.cluster.executor.submit(fn, *args, **kwargs)
-                    self.futures += [f]
-                    return f
+                f = self._executor_submit_original(fn, *args, **kwargs)
+                self.futures += [f]
+                return f

             def mock_connection_factory(self, *args, **kwargs):
                 connection = MagicMock()
@@ -90,26 +110,50 @@ def mock_connection_factory(self, *args, **kwargs):
                 connection.is_closed = False
                 connection.orphaned_threshold_reached = False
                 connection.endpoint = args[0]
-                sharding_info = ShardingInfo(shard_id=1, shards_count=4, partitioner="", sharding_algorithm="", sharding_ignore_msb=0, shard_aware_port=19042, shard_aware_port_ssl=19045)
+                sharding_info = ShardingInfo(shard_id=1, shards_count=shard_count, partitioner="", sharding_algorithm="", sharding_ignore_msb=0, shard_aware_port=19042, shard_aware_port_ssl=19045)
                 connection.features = ProtocolFeatures(shard_id=kwargs.get('shard_id', self.connection_counter), sharding_info=sharding_info)
                 self.connection_counter += 1

                 return connection

         host = MagicMock()
+        host.host_id = uuid.uuid4()
         host.endpoint = DefaultEndPoint("1.2.3.4")
+        session = None
+        reconnection_policy = None
+        if isinstance(shard_reconnection_policy, NoConcurrentShardReconnectionPolicy):
+            reconnection_policy = shard_reconnection_policy.reconnection_policy
+        try:
+            for port, is_ssl in [(19042, False), (19045, True)]:
+                session = MockSession(is_ssl=is_ssl)
+                pool = HostConnection(host=host, host_distance=HostDistance.REMOTE, session=session)
+                for f in session.futures:
+                    f.result()
+                assert len(pool._connections) == expected_count
+                for shard_id, connection in pool._connections.items():
+                    assert connection.features.shard_id == shard_id
+                    if shard_id == 0:
+                        assert connection.endpoint == DefaultEndPoint("1.2.3.4")
+                    else:
+                        assert connection.endpoint == DefaultEndPoint("1.2.3.4", port=port)

-        for port, is_ssl in [(19042, False), (19045, True)]:
-            session = MockSession(is_ssl=is_ssl)
-            pool = HostConnection(host=host, host_distance=HostDistance.REMOTE, session=session)
-            for f in session.futures:
-                f.result()
-            assert len(pool._connections) == 4
-            for shard_id, connection in pool._connections.items():
-                assert connection.features.shard_id == shard_id
-                if shard_id == 0:
-                    assert connection.endpoint == DefaultEndPoint("1.2.3.4")
-                else:
-                    assert connection.endpoint == DefaultEndPoint("1.2.3.4", port=port)
+                sleep_time = 0
+                if reconnection_policy:
+                    # Check that connections to shards are being established according to the policy
+                    # Calculate total time it will need to establish all connections
+                    # Sleep half of the time and check that connections are not there yet
+                    # Sleep rest of the time + 1 second and check that all connections has been established
+                    schedule = reconnection_policy.new_schedule()
+                    for _ in range(shard_count):
+                        sleep_time += next(schedule)
+                    if sleep_time > 0:
+                        time.sleep(sleep_time/2)
+                        # Check that connection are not being established quicker than expected
+                        assert len(pool._connections) < expected_after
+                        time.sleep(sleep_time/2 + 1)

-            session.cluster.executor.shutdown(wait=True)
+                assert len(pool._connections) == expected_after
+        finally:
+            if session:
+                session.cluster.scheduler.shutdown()
+                session.cluster.executor.shutdown(wait=True)
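
For the delayed case, the expected timing follows directly from the injected policy: ConstantReconnectionPolicy(1).new_schedule() yields a constant 1-second delay, so four shards add up to roughly 4 seconds of total delay; the test sleeps half of that and expects fewer than 4 connections, then sleeps the remaining half plus a 1-second margin and expects all 4. A small standalone check of that arithmetic (a sketch, separate from the test itself):

from cassandra.policies import ConstantReconnectionPolicy

shard_count = 4
schedule = ConstantReconnectionPolicy(1).new_schedule()

# Total delay budget the test derives from the policy: 1 second per shard -> 4 seconds.
sleep_time = sum(next(schedule) for _ in range(shard_count))
assert sleep_time == 4

first_wait, second_wait = sleep_time / 2, sleep_time / 2 + 1
print(first_wait, second_wait)   # 2.0 3.0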
