
Commit 46af54d

Integrate ShardConnectionBackoffPolicy
Add code that integrates ShardConnectionBackoffPolicy into:
1. Cluster
2. Session
3. HostConnection

The main idea is to put ShardConnectionBackoffPolicy in control of the shard connection creation process, removing the duplicate logic in HostConnection that tracked pending connection creation requests.
1 parent b0388f7 commit 46af54d
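
As a rough usage sketch (not part of the diff below), the new policy is passed to Cluster at construction time. The contact point is a placeholder, and the only policy named in this commit is the default, NoDelayShardConnectionBackoffPolicy:

from cassandra.cluster import Cluster
from cassandra.policies import NoDelayShardConnectionBackoffPolicy

# Passing the default explicitly; omitting the argument behaves the same,
# since Cluster falls back to NoDelayShardConnectionBackoffPolicy().
# A non-None value that is not a ShardConnectionBackoffPolicy instance raises TypeError.
cluster = Cluster(
    contact_points=["127.0.0.1"],  # placeholder contact point
    shard_connection_backoff_policy=NoDelayShardConnectionBackoffPolicy(),
)

# Each Session then builds its own shard_connection_backoff_scheduler
# from the policy via new_connection_scheduler(cluster.scheduler).
session = cluster.connect()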

File tree:

cassandra/cluster.py
cassandra/pool.py
tests/unit/test_host_connection_pool.py
tests/unit/test_shard_aware.py

4 files changed: +119 -51 lines changed


cassandra/cluster.py

Lines changed: 19 additions & 2 deletions
@@ -73,7 +73,8 @@
                                 ExponentialReconnectionPolicy, HostDistance,
                                 RetryPolicy, IdentityTranslator, NoSpeculativeExecutionPlan,
                                 NoSpeculativeExecutionPolicy, DefaultLoadBalancingPolicy,
-                                NeverRetryPolicy)
+                                NeverRetryPolicy, ShardConnectionBackoffPolicy, NoDelayShardConnectionBackoffPolicy,
+                                ShardConnectionScheduler)
 from cassandra.pool import (Host, _ReconnectionHandler, _HostReconnectionHandler,
                             HostConnection,
                             NoConnectionsAvailable)
@@ -754,6 +755,11 @@ def auth_provider(self, value):
 
         self._auth_provider = value
 
+    _shard_connection_backoff_policy: ShardConnectionBackoffPolicy
+    @property
+    def shard_connection_backoff_policy(self) -> ShardConnectionBackoffPolicy:
+        return self._shard_connection_backoff_policy
+
     _load_balancing_policy = None
     @property
     def load_balancing_policy(self):
@@ -1216,7 +1222,8 @@ def __init__(self,
                  shard_aware_options=None,
                  metadata_request_timeout=None,
                  column_encryption_policy=None,
-                 application_info:Optional[ApplicationInfoBase]=None
+                 application_info: Optional[ApplicationInfoBase] = None,
+                 shard_connection_backoff_policy: Optional[ShardConnectionBackoffPolicy] = None,
                  ):
         """
         ``executor_threads`` defines the number of threads in a pool for handling asynchronous tasks such as
@@ -1322,6 +1329,13 @@ def __init__(self,
         else:
             self._load_balancing_policy = default_lbp_factory() # set internal attribute to avoid committing to legacy config mode
 
+        if shard_connection_backoff_policy is not None:
+            if not isinstance(shard_connection_backoff_policy, ShardConnectionBackoffPolicy):
+                raise TypeError("shard_connection_backoff_policy should be an instance of class derived from ShardConnectionBackoffPolicy")
+            self._shard_connection_backoff_policy = shard_connection_backoff_policy
+        else:
+            self._shard_connection_backoff_policy = NoDelayShardConnectionBackoffPolicy()
+
         if reconnection_policy is not None:
             if isinstance(reconnection_policy, type):
                 raise TypeError("reconnection_policy should not be a class, it should be an instance of that class")
@@ -2659,6 +2673,7 @@ def default_serial_consistency_level(self, cl):
     _metrics = None
     _request_init_callbacks = None
     _graph_paging_available = False
+    shard_connection_backoff_scheduler: ShardConnectionScheduler
 
     def __init__(self, cluster, hosts, keyspace=None):
         self.cluster = cluster
@@ -2673,6 +2688,7 @@ def __init__(self, cluster, hosts, keyspace=None):
         self._protocol_version = self.cluster.protocol_version
 
         self.encoder = Encoder()
+        self.shard_connection_backoff_scheduler = cluster.shard_connection_backoff_policy.new_connection_scheduler(self.cluster.scheduler)
 
         # create connection pools in parallel
         self._initial_connect_futures = set()
@@ -3281,6 +3297,7 @@ def shutdown(self):
         else:
             self.is_shutdown = True
 
+        self.shard_connection_backoff_scheduler.shutdown()
         # PYTHON-673. If shutdown was called shortly after session init, avoid
         # a race by cancelling any initial connection attempts haven't started,
         # then blocking on any that have.
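
The calls introduced above (new_connection_scheduler on the policy in Session.__init__, schedule and shutdown on the scheduler) imply roughly the following contract. This is an illustrative sketch inferred from the diff, not the actual definitions in cassandra/policies.py; the docstrings and the use of abc are assumptions:

from abc import ABC, abstractmethod
from typing import Callable


class ShardConnectionScheduler(ABC):
    """Per-session object that decides when a scheduled shard connection attempt runs."""

    @abstractmethod
    def schedule(self, host_id, shard_id, method: Callable[[], None]) -> None:
        """Run `method` (a zero-argument callable that opens a connection to
        `shard_id` on host `host_id`) immediately or after a policy-defined delay."""

    @abstractmethod
    def shutdown(self) -> None:
        """Stop running scheduled attempts; Session.shutdown() calls this."""


class ShardConnectionBackoffPolicy(ABC):
    """Cluster-level factory; Session.__init__ asks it for a scheduler."""

    @abstractmethod
    def new_connection_scheduler(self, scheduler) -> ShardConnectionScheduler:
        """Build a ShardConnectionScheduler on top of the cluster's internal _Scheduler."""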

cassandra/pool.py

Lines changed: 13 additions & 21 deletions
@@ -16,7 +16,7 @@
 Connection pooling and host management.
 """
 from concurrent.futures import Future
-from functools import total_ordering
+from functools import total_ordering, partial
 import logging
 import socket
 import time
@@ -401,7 +401,6 @@ def __init__(self, host, host_distance, session):
         # this is used in conjunction with the connection streams. Not using the connection lock because the connection can be replaced in the lifetime of the pool.
         self._stream_available_condition = Condition(Lock())
         self._is_replacing = False
-        self._connecting = set()
         self._connections = {}
         self._pending_connections = []
         # A pool of additional connections which are not used but affect how Scylla
@@ -417,7 +416,6 @@ def __init__(self, host, host_distance, session):
         # and are waiting until all requests time out or complete
         # so that we can dispose of them.
         self._trash = set()
-        self._shard_connections_futures = []
         self.advanced_shardaware_block_until = 0
 
         if host_distance == HostDistance.IGNORED:
@@ -482,25 +480,25 @@ def _get_connection_for_routing_key(self, routing_key=None, keyspace=None, table
                 self.host,
                 routing_key
             )
-            if conn.orphaned_threshold_reached and shard_id not in self._connecting:
+            if conn.orphaned_threshold_reached:
                 # The connection has met its orphaned stream ID limit
                 # and needs to be replaced. Start opening a connection
                 # to the same shard and replace when it is opened.
-                self._connecting.add(shard_id)
-                self._session.submit(self._open_connection_to_missing_shard, shard_id)
+                self._session.shard_connection_backoff_scheduler.schedule(
+                    self.host.host_id, shard_id, partial(self._open_connection_to_missing_shard, shard_id))
                 log.debug(
-                    "Connection to shard_id=%i reached orphaned stream limit, replacing on host %s (%s/%i)",
+                    "Scheduling Connection to shard_id=%i reached orphaned stream limit, replacing on host %s (%s/%i)",
                     shard_id,
                     self.host,
                     len(self._connections.keys()),
                     self.host.sharding_info.shards_count
                 )
-            elif shard_id not in self._connecting:
+            else:
                 # rate controlled optimistic attempt to connect to a missing shard
-                self._connecting.add(shard_id)
-                self._session.submit(self._open_connection_to_missing_shard, shard_id)
+                self._session.shard_connection_backoff_scheduler.schedule(
+                    self.host.host_id, shard_id, partial(self._open_connection_to_missing_shard, shard_id))
                 log.debug(
-                    "Trying to connect to missing shard_id=%i on host %s (%s/%i)",
+                    "Scheduling connection to missing shard_id=%i on host %s (%s/%i)",
                     shard_id,
                     self.host,
                     len(self._connections.keys()),
@@ -610,8 +608,8 @@ def _replace(self, connection):
             if connection.features.shard_id in self._connections.keys():
                 del self._connections[connection.features.shard_id]
             if self.host.sharding_info and not self._session.cluster.shard_aware_options.disable:
-                self._connecting.add(connection.features.shard_id)
-                self._session.submit(self._open_connection_to_missing_shard, connection.features.shard_id)
+                self._session.shard_connection_backoff_scheduler.schedule(
+                    self.host.host_id, connection.features.shard_id, partial(self._open_connection_to_missing_shard, connection.features.shard_id))
             else:
                 connection = self._session.cluster.connection_factory(self.host.endpoint,
                                                                       on_orphaned_stream_released=self.on_orphaned_stream_released)
@@ -636,9 +634,6 @@ def shutdown(self):
         with self._stream_available_condition:
             self._stream_available_condition.notify_all()
 
-        for future in self._shard_connections_futures:
-            future.cancel()
-
         connections_to_close = self._connections.copy()
         pending_connections_to_close = self._pending_connections.copy()
         self._connections.clear()
@@ -848,7 +843,6 @@ def _open_connection_to_missing_shard(self, shard_id):
                     self._excess_connections.add(conn)
             if close_connection:
                 conn.close()
-        self._connecting.discard(shard_id)
 
     def _open_connections_for_all_shards(self, skip_shard_id=None):
         """
@@ -861,10 +855,8 @@ def _open_connections_for_all_shards(self, skip_shard_id=None):
         for shard_id in range(self.host.sharding_info.shards_count):
             if skip_shard_id is not None and skip_shard_id == shard_id:
                 continue
-            future = self._session.submit(self._open_connection_to_missing_shard, shard_id)
-            if isinstance(future, Future):
-                self._connecting.add(shard_id)
-                self._shard_connections_futures.append(future)
+            self._session.shard_connection_backoff_scheduler.schedule(
+                self.host.host_id, shard_id, partial(self._open_connection_to_missing_shard, shard_id))
 
         trash_conns = None
         with self._lock:
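
With the _connecting set and _shard_connections_futures list gone, per-(host, shard) bookkeeping belongs to the scheduler the pool now calls. Purely as a hypothetical illustration of that responsibility (this is not the driver's _NoDelayShardConnectionBackoffScheduler), a deduplicating scheduler could look like the sketch below; the underlying scheduler's schedule(delay, fn, *args) signature is taken from the tests further down:

import threading


class DedupingShardConnectionScheduler:  # hypothetical example, not part of this commit
    """Runs each (host_id, shard_id) connection attempt at most once at a time."""

    def __init__(self, scheduler):
        self._scheduler = scheduler   # the cluster's delayed-task scheduler (_Scheduler)
        self._pending = set()         # (host_id, shard_id) pairs currently scheduled
        self._lock = threading.Lock()
        self._is_shutdown = False

    def schedule(self, host_id, shard_id, method):
        key = (host_id, shard_id)
        with self._lock:
            if self._is_shutdown or key in self._pending:
                return                # drop duplicates, as the old _connecting set did
            self._pending.add(key)

        def run_and_clear():
            try:
                method()
            finally:
                with self._lock:
                    self._pending.discard(key)

        # No delay here; a real backoff policy would compute one.
        self._scheduler.schedule(0, run_and_clear)

    def shutdown(self):
        with self._lock:
            self._is_shutdown = True
            self._pending.clear()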

tests/unit/test_host_connection_pool.py

Lines changed: 15 additions & 4 deletions
@@ -22,15 +22,23 @@
 from threading import Thread, Event, Lock
 from unittest.mock import Mock, NonCallableMagicMock, MagicMock
 
-from cassandra.cluster import Session, ShardAwareOptions
+from cassandra.cluster import Session, ShardAwareOptions, _Scheduler
 from cassandra.connection import Connection
 from cassandra.pool import HostConnection
 from cassandra.pool import Host, NoConnectionsAvailable
-from cassandra.policies import HostDistance, SimpleConvictionPolicy
+from cassandra.policies import HostDistance, SimpleConvictionPolicy, _NoDelayShardConnectionBackoffScheduler
 
 LOGGER = logging.getLogger(__name__)
 
 
+class FakeScheduler(_Scheduler):
+    def __init__(self):
+        super(FakeScheduler, self).__init__(ThreadPoolExecutor())
+
+    def schedule(self, delay, fn, *args, **kwargs):
+        super().schedule(0, fn, *args, **kwargs)
+
+
 class _PoolTests(unittest.TestCase):
     __test__ = False
     PoolImpl = None
@@ -40,6 +48,9 @@ def make_session(self):
         session = NonCallableMagicMock(spec=Session, keyspace='foobarkeyspace')
         session.cluster.get_core_connections_per_host.return_value = 1
         session.cluster.get_max_connections_per_host.return_value = 1
+        session.shard_connection_backoff_scheduler = _NoDelayShardConnectionBackoffScheduler(FakeScheduler())
+        session.shard_connection_backoff_scheduler.schedule = Mock(wraps=session.shard_connection_backoff_scheduler.schedule)
+        session.is_shutdown = False
         return session
 
     def test_borrow_and_return(self):
@@ -173,9 +184,9 @@ def test_return_defunct_connection_on_down_host(self):
         if self.PoolImpl is HostConnection:
             # on shard aware implementation we use submit function regardless
             self.assertTrue(host.signal_connection_failure.call_args)
-            self.assertTrue(session.submit.called)
+            self.assertTrue(session.shard_connection_backoff_scheduler.schedule.called)
         else:
-            self.assertFalse(session.submit.called)
+            self.assertFalse(session.shard_connection_backoff_scheduler.schedule.called)
         self.assertTrue(session.cluster.signal_connection_failure.call_args)
         self.assertTrue(pool.is_shutdown)
 
tests/unit/test_shard_aware.py

Lines changed: 72 additions & 24 deletions
@@ -11,6 +11,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import uuid
+from unittest.mock import Mock
+
+from cassandra.policies import NoDelayShardConnectionBackoffPolicy, _NoDelayShardConnectionBackoffScheduler
 
 try:
     import unittest2 as unittest
@@ -21,7 +25,7 @@
     from mock import MagicMock
 from concurrent.futures import ThreadPoolExecutor
 
-from cassandra.cluster import ShardAwareOptions
+from cassandra.cluster import ShardAwareOptions, _Scheduler
 from cassandra.pool import HostConnection, HostDistance
 from cassandra.connection import ShardingInfo, DefaultEndPoint
 from cassandra.metadata import Murmur3Token
@@ -53,11 +57,18 @@ class OptionsHolder(object):
         self.assertEqual(shard_info.shard_id_from_token(Murmur3Token.from_key(b"e").value), 4)
         self.assertEqual(shard_info.shard_id_from_token(Murmur3Token.from_key(b"100000").value), 2)
 
-    def test_advanced_shard_aware_port(self):
+    def test_shard_aware_reconnection_policy_no_delay(self):
+        # with NoDelayShardConnectionBackoffPolicy all the connections should be created right away
+        self._test_shard_aware_reconnection_policy(4, NoDelayShardConnectionBackoffPolicy(), 4)
+
+    def _test_shard_aware_reconnection_policy(self, shard_count, shard_connection_backoff_policy, expected_connections):
         """
         Test that on given a `shard_aware_port` on the OPTIONS message (ShardInfo class)
-        the next connections would be open using this port
+        It checks that:
+        1. Next connections are opened using this port
+        2. Connection creation pace matches `shard_connection_backoff_policy`
         """
+
         class MockSession(MagicMock):
             is_shutdown = False
             keyspace = "ks1"
@@ -71,45 +82,82 @@ def __init__(self, is_ssl=False, *args, **kwargs):
                 self.cluster.ssl_options = None
                 self.cluster.shard_aware_options = ShardAwareOptions()
                 self.cluster.executor = ThreadPoolExecutor(max_workers=2)
+                self._executor_submit_original = self.cluster.executor.submit
+                self.cluster.executor.submit = self._executor_submit
+                self.cluster.scheduler = _Scheduler(self.cluster.executor)
+
+                # Collect scheduled calls and execute them right away
+                self.scheduler_calls = []
+                original_schedule = self.cluster.scheduler.schedule
+
+                def new_schedule(delay, fn, *args, **kwargs):
+                    self.scheduler_calls.append((delay, fn, args, kwargs))
+                    return original_schedule(0, fn, *args, **kwargs)
+
+                self.cluster.scheduler.schedule = Mock(side_effect=new_schedule)
                 self.cluster.signal_connection_failure = lambda *args, **kwargs: False
                 self.cluster.connection_factory = self.mock_connection_factory
                 self.connection_counter = 0
+                self.shard_connection_backoff_scheduler = shard_connection_backoff_policy.new_connection_scheduler(
+                    self.cluster.scheduler)
                 self.futures = []
 
             def submit(self, fn, *args, **kwargs):
+                if self.is_shutdown:
+                    return None
+                return self.cluster.executor.submit(fn, *args, **kwargs)
+
+            def _executor_submit(self, fn, *args, **kwargs):
                 logging.info("Scheduling %s with args: %s, kwargs: %s", fn, args, kwargs)
-                if not self.is_shutdown:
-                    f = self.cluster.executor.submit(fn, *args, **kwargs)
-                    self.futures += [f]
-                    return f
+                f = self._executor_submit_original(fn, *args, **kwargs)
+                self.futures += [f]
+                return f
 
             def mock_connection_factory(self, *args, **kwargs):
                 connection = MagicMock()
                 connection.is_shutdown = False
                 connection.is_defunct = False
                 connection.is_closed = False
                 connection.orphaned_threshold_reached = False
-                connection.endpoint = args[0]
-                sharding_info = ShardingInfo(shard_id=1, shards_count=4, partitioner="", sharding_algorithm="", sharding_ignore_msb=0, shard_aware_port=19042, shard_aware_port_ssl=19045)
-                connection.features = ProtocolFeatures(shard_id=kwargs.get('shard_id', self.connection_counter), sharding_info=sharding_info)
+                connection.endpoint = args[0]
+                sharding_info = None
+                if shard_count:
+                    sharding_info = ShardingInfo(shard_id=1, shards_count=shard_count, partitioner="",
                                                 sharding_algorithm="", sharding_ignore_msb=0, shard_aware_port=19042,
                                                 shard_aware_port_ssl=19045)
+                connection.features = ProtocolFeatures(
+                    shard_id=kwargs.get('shard_id', self.connection_counter),
+                    sharding_info=sharding_info)
                 self.connection_counter += 1
 
                 return connection
 
         host = MagicMock()
+        host.host_id = uuid.uuid4()
         host.endpoint = DefaultEndPoint("1.2.3.4")
+        session = None
+        try:
+            for port, is_ssl in [(19042, False), (19045, True)]:
+                session = MockSession(is_ssl=is_ssl)
+                pool = HostConnection(host=host, host_distance=HostDistance.REMOTE, session=session)
+                for f in session.futures:
+                    f.result()
+                assert len(pool._connections) == expected_connections
                for shard_id, connection in pool._connections.items():
                    assert connection.features.shard_id == shard_id
                    if shard_id == 0:
                        assert connection.endpoint == DefaultEndPoint("1.2.3.4")
                    else:
                        assert connection.endpoint == DefaultEndPoint("1.2.3.4", port=port)
 
-        for port, is_ssl in [(19042, False), (19045, True)]:
-            session = MockSession(is_ssl=is_ssl)
-            pool = HostConnection(host=host, host_distance=HostDistance.REMOTE, session=session)
-            for f in session.futures:
-                f.result()
-            assert len(pool._connections) == 4
-            for shard_id, connection in pool._connections.items():
-                assert connection.features.shard_id == shard_id
-                if shard_id == 0:
-                    assert connection.endpoint == DefaultEndPoint("1.2.3.4")
-                else:
-                    assert connection.endpoint == DefaultEndPoint("1.2.3.4", port=port)
-
-        session.cluster.executor.shutdown(wait=True)
+            sleep_time = 0
+            found_related_calls = 0
+            for delay, fn, args, kwargs in session.scheduler_calls:
+                if fn.__self__.__class__ is _NoDelayShardConnectionBackoffScheduler:
+                    found_related_calls += 1
+                    self.assertEqual(delay, sleep_time)
+            self.assertLessEqual(shard_count - 1, found_related_calls)
+        finally:
+            if session:
+                session.cluster.scheduler.shutdown()
+                session.cluster.executor.shutdown(wait=True)
