|
11 | 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 | 12 | # See the License for the specific language governing permissions and
|
13 | 13 | # limitations under the License.
|
14 |
| - |
| 14 | +import os |
| 15 | +import time |
15 | 16 | import unittest
|
| 17 | +from typing import Optional |
16 | 18 |
|
17 | 19 | from cassandra import ConsistencyLevel, Unavailable
|
18 |
| -from cassandra.cluster import ExecutionProfile, EXEC_PROFILE_DEFAULT |
| 20 | +from cassandra.cluster import ExecutionProfile, EXEC_PROFILE_DEFAULT, Session |
| 21 | +from cassandra.policies import LimitedConcurrencyShardConnectionBackoffPolicy, ShardConnectionBackoffScope, \ |
| 22 | + ConstantReconnectionPolicy, ShardConnectionBackoffPolicy, NoDelayShardConnectionBackoffPolicy |
| 23 | +from cassandra.shard_info import _ShardingInfo |
19 | 24 |
|
20 | 25 | from tests.integration import use_cluster, get_cluster, get_node, TestCluster
|
21 | 26 |
|
22 | 27 |
|
23 | 28 | def setup_module():
|
| 29 | + os.environ['SCYLLA_EXT_OPTS'] = "--smp 4" |
24 | 30 | use_cluster('test_cluster', [4])
|
25 | 31 |
|
26 | 32 |
|
@@ -65,3 +71,127 @@ def test_should_rethrow_on_unvailable_with_default_policy_if_cas(self):
|
65 | 71 | self.assertEqual(exception.consistency, ConsistencyLevel.SERIAL)
|
66 | 72 | self.assertEqual(exception.required_replicas, 2)
|
67 | 73 | self.assertEqual(exception.alive_replicas, 1)
|
| 74 | + |
| 75 | + |
| 76 | +class ShardBackoffPolicyTests(unittest.TestCase): |
| 77 | + @classmethod |
| 78 | + def tearDownClass(cls): |
| 79 | + cluster = get_cluster() |
| 80 | + cluster.start(wait_for_binary_proto=True, wait_other_notice=True) # make sure other nodes are restarted |
| 81 | + |
| 82 | + def test_limited_concurrency_1_connection_per_cluster(self): |
| 83 | + self._test_backoff( |
| 84 | + LimitedConcurrencyShardConnectionBackoffPolicy( |
| 85 | + backoff_policy=ConstantReconnectionPolicy(0.1), |
| 86 | + max_concurrent=1, |
| 87 | + scope=ShardConnectionBackoffScope.Cluster, |
| 88 | + ) |
| 89 | + ) |
| 90 | + |
| 91 | + def test_limited_concurrency_2_connection_per_cluster(self): |
| 92 | + self._test_backoff( |
| 93 | + LimitedConcurrencyShardConnectionBackoffPolicy( |
| 94 | + backoff_policy=ConstantReconnectionPolicy(0.1), |
| 95 | + max_concurrent=2, |
| 96 | + scope=ShardConnectionBackoffScope.Cluster, |
| 97 | + ) |
| 98 | + ) |
| 99 | + |
| 100 | + def test_limited_concurrency_1_connection_per_host(self): |
| 101 | + self._test_backoff( |
| 102 | + LimitedConcurrencyShardConnectionBackoffPolicy( |
| 103 | + backoff_policy=ConstantReconnectionPolicy(0.1), |
| 104 | + max_concurrent=1, |
| 105 | + scope=ShardConnectionBackoffScope.Host, |
| 106 | + ) |
| 107 | + ) |
| 108 | + |
| 109 | + def test_limited_concurrency_2_connection_per_host(self): |
| 110 | + self._test_backoff( |
| 111 | + LimitedConcurrencyShardConnectionBackoffPolicy( |
| 112 | + backoff_policy=ConstantReconnectionPolicy(0.1), |
| 113 | + max_concurrent=1, |
| 114 | + scope=ShardConnectionBackoffScope.Host, |
| 115 | + ) |
| 116 | + ) |
| 117 | + |
| 118 | + def test_no_delay(self): |
| 119 | + self._test_backoff(NoDelayShardConnectionBackoffPolicy()) |
| 120 | + |
| 121 | + def _test_backoff(self, shard_connection_backoff_policy: ShardConnectionBackoffPolicy): |
| 122 | + backoff_policy = None |
| 123 | + if isinstance(shard_connection_backoff_policy, LimitedConcurrencyShardConnectionBackoffPolicy): |
| 124 | + backoff_policy = shard_connection_backoff_policy.backoff_policy |
| 125 | + |
| 126 | + cluster = TestCluster( |
| 127 | + shard_connection_backoff_policy=shard_connection_backoff_policy, |
| 128 | + reconnection_policy=ConstantReconnectionPolicy(0), |
| 129 | + ) |
| 130 | + session = cluster.connect() |
| 131 | + sharding_info = get_sharding_info(session) |
| 132 | + |
| 133 | + # even if backoff is set and there is no sharding info |
| 134 | + # behavior should be the same as if there is no backoff policy |
| 135 | + if not backoff_policy or not sharding_info: |
| 136 | + time.sleep(2) |
| 137 | + expected_connections = 1 |
| 138 | + if sharding_info: |
| 139 | + expected_connections = sharding_info.shards_count |
| 140 | + for host_id, connections_count in get_connections_per_host(session).items(): |
| 141 | + self.assertEqual(connections_count, expected_connections) |
| 142 | + return |
| 143 | + |
| 144 | + sleep_time = 0 |
| 145 | + schedule = backoff_policy.new_schedule() |
| 146 | + # Calculate total time it will need to establish all connections |
| 147 | + if shard_connection_backoff_policy.scope == ShardConnectionBackoffScope.Cluster: |
| 148 | + for _ in session.hosts: |
| 149 | + for _ in range(sharding_info.shards_count - 1): |
| 150 | + sleep_time += next(schedule) |
| 151 | + sleep_time /= shard_connection_backoff_policy.max_concurrent |
| 152 | + elif shard_connection_backoff_policy.scope == ShardConnectionBackoffScope.Host: |
| 153 | + for _ in range(sharding_info.shards_count - 1): |
| 154 | + sleep_time += next(schedule) |
| 155 | + sleep_time /= shard_connection_backoff_policy.max_concurrent |
| 156 | + else: |
| 157 | + raise ValueError("Unknown scope {}".format(shard_connection_backoff_policy.scope)) |
| 158 | + |
| 159 | + time.sleep(sleep_time / 2) |
| 160 | + self.assertFalse( |
| 161 | + is_connection_filled(shard_connection_backoff_policy.scope, session, sharding_info.shards_count)) |
| 162 | + time.sleep(sleep_time / 2 + 1) |
| 163 | + self.assertTrue( |
| 164 | + is_connection_filled(shard_connection_backoff_policy.scope, session, sharding_info.shards_count)) |
| 165 | + |
| 166 | + |
| 167 | +def is_connection_filled(scope: ShardConnectionBackoffScope, session: Session, shards_count: int) -> bool: |
| 168 | + if scope == ShardConnectionBackoffScope.Cluster: |
| 169 | + expected_connections = shards_count * len(session.hosts) |
| 170 | + total_connections = sum(get_connections_per_host(session).values()) |
| 171 | + return expected_connections == total_connections |
| 172 | + elif scope == ShardConnectionBackoffScope.Host: |
| 173 | + expected_connections_per_host = shards_count |
| 174 | + for connections_count in get_connections_per_host(session).values(): |
| 175 | + if connections_count < expected_connections_per_host: |
| 176 | + return False |
| 177 | + if connections_count == expected_connections_per_host: |
| 178 | + continue |
| 179 | + assert False, "Expected {} or less connections but got {}".format(expected_connections_per_host, |
| 180 | + connections_count) |
| 181 | + return True |
| 182 | + else: |
| 183 | + raise ValueError("Unknown scope {}".format(scope)) |
| 184 | + |
| 185 | + |
| 186 | +def get_connections_per_host(session: Session) -> dict[str, int]: |
| 187 | + host_connections = {} |
| 188 | + for host, pool in session._pools.items(): |
| 189 | + host_connections[host.host_id] = len(pool._connections) |
| 190 | + return host_connections |
| 191 | + |
| 192 | + |
| 193 | +def get_sharding_info(session: Session) -> Optional[_ShardingInfo]: |
| 194 | + for host in session.hosts: |
| 195 | + if host.sharding_info: |
| 196 | + return host.sharding_info |
| 197 | + return None |
0 commit comments