2 changes: 1 addition & 1 deletion cassandra/cluster.py
@@ -4384,7 +4384,7 @@ def _query(self, host, message=None, cb=None):
connection = None
try:
# TODO get connectTimeout from cluster settings
connection, request_id = pool.borrow_connection(timeout=2.0)
connection, request_id = pool.borrow_connection(timeout=2.0, routing_key=self.query.routing_key if self.query else None)
self._connection = connection
result_meta = self.prepared_statement.result_metadata if self.prepared_statement else []

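Note: the new routing_key argument threaded through _query() above comes from the statement being executed. A minimal sketch of where that value originates (not part of this diff; assumes a keyspace ks with a table t whose partition key is pk):

    from cassandra.cluster import Cluster

    cluster = Cluster(['127.0.0.1'])
    session = cluster.connect('ks')
    prepared = session.prepare("SELECT * FROM t WHERE pk = ?")
    bound = prepared.bind((42,))
    # The serialized partition key; this is what reaches borrow_connection().
    print(bound.routing_key)

The pool can then use that key to pick the connection for the Scylla shard owning the row (see cassandra/pool.py below).
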
38 changes: 38 additions & 0 deletions cassandra/connection.py
@@ -44,6 +44,7 @@
RegisterMessage, ReviseRequestMessage)
from cassandra.util import OrderedDict

MIN_LONG = -(2 ** 63)

log = logging.getLogger(__name__)

@@ -599,6 +600,39 @@ def int_from_buf_item(i):
else:
int_from_buf_item = ord

class ShardingInfo(object):

def __init__(self, shard_id, shards_count, partitioner, sharding_algorithm, sharding_ignore_msb):
self.shards_count = int(shards_count)
self.partitioner = partitioner
self.sharding_algorithm = sharding_algorithm
self.sharding_ignore_msb = int(sharding_ignore_msb)

@staticmethod
def parse_sharding_info(message):
shard_id = message.options.get('SCYLLA_SHARD', [''])[0] or None
shards_count = message.options.get('SCYLLA_NR_SHARDS', [''])[0] or None
partitioner = message.options.get('SCYLLA_PARTITIONER', [''])[0] or None
sharding_algorithm = message.options.get('SCYLLA_SHARDING_ALGORITHM', [''])[0] or None
sharding_ignore_msb = message.options.get('SCYLLA_SHARDING_IGNORE_MSB', [''])[0] or None

if not (shard_id or shards_count or partitioner == "org.apache.cassandra.dht.Murmur3Partitioner" or
sharding_algorithm == "biased-token-round-robin" or sharding_ignore_msb):
return 0, None

return int(shard_id), ShardingInfo(shard_id, shards_count, partitioner, sharding_algorithm, sharding_ignore_msb)

def shard_id(self, t):
token = t.value
token += MIN_LONG
token <<= self.sharding_ignore_msb
tokLo = token & 0xffffffff
tokHi = (token >> 32) & 0xffffffff
mul1 = tokLo * self.shards_count
mul2 = tokHi * self.shards_count
_sum = (mul1 >> 32) + mul2
output = _sum >> 32

Review comment: Hmmm... What is this about anyway? ;)

return output

class Connection(object):

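For context on the arithmetic in ShardingInfo.shard_id() (and the review question above): Scylla's "biased-token-round-robin" algorithm maps a Murmur3 token onto a shard by shifting the signed token into an unsigned 64-bit range, discarding the configured number of most-significant bits, and taking the high 64 bits of the product with shards_count. A minimal equivalent sketch using Python's arbitrary-precision integers (illustrative only; the diff's 32-bit split avoids a 128-bit intermediate, and adding MIN_LONG as the diff does is congruent to adding 2**63 modulo 2**64):

    def shard_for_token(token, shards_count, ignore_msb):
        # Bias the signed 64-bit token into [0, 2**64) and drop the ignored MSBs.
        biased = ((token + 2**63) << ignore_msb) & 0xFFFFFFFFFFFFFFFF
        # Multiply-and-shift: floor(biased * shards_count / 2**64).
        return (biased * shards_count) >> 64

    # The smallest Murmur3 token always lands on shard 0.
    assert shard_for_token(-2**63, shards_count=8, ignore_msb=12) == 0
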
@@ -666,6 +700,9 @@ class Connection(object):
_check_hostname = False
_product_type = None

shard_id = 0
sharding_info = None

def __init__(self, host='127.0.0.1', port=9042, authenticator=None,
ssl_options=None, sockopts=None, compression=True,
cql_version=None, protocol_version=ProtocolVersion.MAX_SUPPORTED, is_control_connection=False,
@@ -1126,6 +1163,7 @@ def _send_options_message(self):

@defunct_on_error
def _handle_options_response(self, options_response):
self.shard_id, self.sharding_info = ShardingInfo.parse_sharding_info(options_response)
if self.is_defunct:
return

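Scylla advertises its shard topology as extra keys in the SUPPORTED message that answers the OPTIONS request, which is why the parsing is hooked into _handle_options_response() above. A minimal sketch of what that exchange yields (illustrative values; any object with an `options` mapping of lists works, mirroring the driver's message type, and it assumes the ShardingInfo class added in this diff is importable):

    from cassandra.connection import ShardingInfo

    class FakeSupportedMessage(object):
        options = {
            'SCYLLA_SHARD': ['1'],
            'SCYLLA_NR_SHARDS': ['8'],
            'SCYLLA_PARTITIONER': ['org.apache.cassandra.dht.Murmur3Partitioner'],
            'SCYLLA_SHARDING_ALGORITHM': ['biased-token-round-robin'],
            'SCYLLA_SHARDING_IGNORE_MSB': ['12'],
        }

    shard_id, info = ShardingInfo.parse_sharding_info(FakeSupportedMessage())
    print(shard_id, info.shards_count)   # -> 1 8

Connections to non-Scylla hosts simply leave shard_id at 0 and sharding_info as None.
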
88 changes: 65 additions & 23 deletions cassandra/pool.py
@@ -20,6 +20,7 @@
import logging
import socket
import time
import random
from threading import Lock, RLock, Condition
import weakref
try:
@@ -123,6 +124,8 @@ class Host(object):

_currently_handling_node_up = False

sharding_info = None

def __init__(self, endpoint, conviction_policy_factory, datacenter=None, rack=None, host_id=None):
if endpoint is None:
raise ValueError("endpoint may not be None")
@@ -339,7 +342,6 @@ class HostConnection(object):
shutdown_on_error = False

_session = None
_connection = None
_lock = None
_keyspace = None

@@ -351,6 +353,7 @@ def __init__(self, host, host_distance, session):
# this is used in conjunction with the connection streams. Not using the connection lock because the connection can be replaced in the lifetime of the pool.
self._stream_available_condition = Condition(self._lock)
self._is_replacing = False
self._connections = dict()

if host_distance == HostDistance.IGNORED:
log.debug("Not opening connection to ignored host %s", self.host)
@@ -360,18 +363,45 @@ def __init__(self, host, host_distance, session):
return

log.debug("Initializing connection for host %s", self.host)
self._connection = session.cluster.connection_factory(host.endpoint)
first_connection = session.cluster.connection_factory(host.endpoint)
Review comment (Collaborator): Suggested change
-        first_connection = session.cluster.connection_factory(host.endpoint)
+        first_connection = session.cluster.connection_factory(self.host.endpoint)

log.debug("first connection created for shard_id=%i", first_connection.shard_id)
self._connections[first_connection.shard_id] = first_connection
self._keyspace = session.keyspace

if self._keyspace:
self._connection.set_keyspace_blocking(self._keyspace)
first_connection.set_keyspace_blocking(self._keyspace)

if first_connection.sharding_info:
self.host.sharding_info = weakref.proxy(first_connection.sharding_info)
for _ in range(first_connection.sharding_info.shards_count * 2):
conn = self._session.cluster.connection_factory(self.host.endpoint)
if conn.shard_id not in self._connections.keys():
log.debug("new connection created for shard_id=%i", conn.shard_id)
self._connections[conn.shard_id] = conn
if self._keyspace:
self._connections[conn.shard_id].set_keyspace_blocking(self._keyspace)

if len(self._connections.keys()) == first_connection.sharding_info.shards_count:
break
if not len(self._connections.keys()) == first_connection.sharding_info.shards_count:
raise NoConnectionsAvailable("not enough shard connections opened")

log.debug("Finished initializing connection for host %s", self.host)

def borrow_connection(self, timeout):
def borrow_connection(self, timeout, routing_key=None):
if self.is_shutdown:
raise ConnectionException(
"Pool for %s is shutdown" % (self.host,), self.host)

conn = self._connection
shard_id = 0
if self.host.sharding_info:
if routing_key:
t = self._session.cluster.metadata.token_map.token_class.from_key(routing_key)
shard_id = self.host.sharding_info.shard_id(t)
else:
shard_id = random.randint(0, self.host.sharding_info.shards_count - 1)

conn = self._connections.get(shard_id)
if not conn:
raise NoConnectionsAvailable()

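Taken together, the pool changes above mean a HostConnection to a Scylla node keeps one connection per shard: the constructor opens extra connections (up to 2 * shards_count attempts) until every shard is covered, and borrow_connection() hands out the connection of the shard owning the routing key, falling back to a random shard when no key is available. A minimal sketch of the selection step (not driver API; assumes the Murmur3 partitioner, which is what Murmur3Token.from_key hashes with):

    from cassandra.metadata import Murmur3Token

    def pick_shard_connection(pool, routing_key):
        info = pool.host.sharding_info
        if info is None or routing_key is None:
            return None  # caller falls back to any/random connection
        token = Murmur3Token.from_key(routing_key)
        return pool._connections.get(info.shard_id(token))
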
@@ -416,7 +446,7 @@ def return_connection(self, connection):
if is_down:
self.shutdown()
else:
self._connection = None
del self._connections[connection.shard_id]
with self._lock:
if self._is_replacing:
return
@@ -433,7 +463,7 @@ def _replace(self, connection):
conn = self._session.cluster.connection_factory(self.host.endpoint)
if self._keyspace:
conn.set_keyspace_blocking(self._keyspace)
self._connection = conn
self._connections[connection.shard_id] = conn
except Exception:
log.warning("Failed reconnecting %s. Retrying." % (self.host.endpoint,))
self._session.submit(self._replace, connection)
@@ -450,36 +480,48 @@ def shutdown(self):
self.is_shutdown = True
self._stream_available_condition.notify_all()

if self._connection:
self._connection.close()
self._connection = None
if self._connections:
for c in self._connections.values():
c.close()
self._connections = dict()

def _set_keyspace_for_all_conns(self, keyspace, callback):
if self.is_shutdown or not self._connection:
"""
Asynchronously sets the keyspace for all connections. When all
connections have been set, `callback` will be called with two
arguments: this pool, and a list of any errors that occurred.
"""
remaining_callbacks = set(self._connections.values())
errors = []

if not remaining_callbacks:
callback(self, errors)
return

def connection_finished_setting_keyspace(conn, error):
self.return_connection(conn)
errors = [] if not error else [error]
callback(self, errors)
remaining_callbacks.remove(conn)
if error:
errors.append(error)

if not remaining_callbacks:
callback(self, errors)

self._keyspace = keyspace
self._connection.set_keyspace_async(keyspace, connection_finished_setting_keyspace)
for conn in self._connections.values():
conn.set_keyspace_async(keyspace, connection_finished_setting_keyspace)

def get_connections(self):
c = self._connection
return [c] if c else []
c = self._connections
return list(self._connections.values()) if c else []

def get_state(self):
connection = self._connection
open_count = 1 if connection and not (connection.is_closed or connection.is_defunct) else 0
in_flights = [connection.in_flight] if connection else []
return {'shutdown': self.is_shutdown, 'open_count': open_count, 'in_flights': in_flights}
in_flights = [c.in_flight for c in self._connections.values()]
return {'shutdown': self.is_shutdown, 'open_count': self.open_count, 'in_flights': in_flights}

@property
def open_count(self):
connection = self._connection
return 1 if connection and not (connection.is_closed or connection.is_defunct) else 0
return sum([1 if c and not (c.is_closed or c.is_defunct) else 0 for c in self._connections.values()])

_MAX_SIMULTANEOUS_CREATION = 1
_MIN_TRASH_INTERVAL = 10
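The reworked _set_keyspace_for_all_conns() applies a fan-out/fan-in callback pattern: one set_keyspace_async() per pooled connection, with the final callback fired only after every per-connection callback has reported in. A minimal, generic sketch of the same pattern (not driver code):

    def for_each_async(items, start_async, done):
        remaining = set(items)
        errors = []

        if not remaining:
            done(errors)
            return

        def one_finished(item, error):
            # Called once per item by the async operation's completion callback.
            remaining.remove(item)
            if error:
                errors.append(error)
            if not remaining:
                done(errors)

        for item in items:
            start_async(item, one_finished)
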
@@ -522,7 +564,7 @@ def __init__(self, host, host_distance, session):
self.open_count = core_conns
log.debug("Finished initializing new connection pool for host %s", self.host)

def borrow_connection(self, timeout):
def borrow_connection(self, timeout, routing_key=None):
if self.is_shutdown:
raise ConnectionException(
"Pool for %s is shutdown" % (self.host,), self.host)
42 changes: 30 additions & 12 deletions tests/integration/__init__.py
@@ -44,6 +44,7 @@
try:
from ccmlib.dse_cluster import DseCluster
from ccmlib.cluster import Cluster as CCMCluster
from ccmlib.scylla_cluster import ScyllaCluster as CCMSCyllaCluster
from ccmlib.cluster_factory import ClusterFactory as CCMClusterFactory
from ccmlib import common
except ImportError as e:
@@ -174,8 +175,12 @@ def _get_dse_version_from_cass(cass_version):
try:
cassandra_version = Version(cv_string) # env var is set to test-dse for DDAC
except:
# fallback to MAPPED_CASSANDRA_VERSION
cassandra_version = Version(mcv_string)
try:
# fallback to MAPPED_CASSANDRA_VERSION
cassandra_version = Version(mcv_string)
except:
cassandra_version = Version('3.11.4')
cv_string = '3.11.4'

Review comment: cv_string is not initialized outside this branch - looks wrong.


CASSANDRA_VERSION = Version(mcv_string) if mcv_string else cassandra_version
CCM_VERSION = mcv_string if mcv_string else cv_string
@@ -184,20 +189,27 @@ def _get_dse_version_from_cass(cass_version):
CASSANDRA_DIR = os.getenv('CASSANDRA_DIR', None)

CCM_KWARGS = {}
IS_SCYLLA = False
if DSE_VERSION:
log.info('Using DSE version: %s', DSE_VERSION)
if not CASSANDRA_DIR:
CCM_KWARGS['version'] = DSE_VERSION
if DSE_CRED:
log.info("Using DSE credentials file located at {0}".format(DSE_CRED))
CCM_KWARGS['dse_credentials_file'] = DSE_CRED

elif CASSANDRA_DIR:
log.info("Using Cassandra dir: %s", CASSANDRA_DIR)
CCM_KWARGS['install_dir'] = CASSANDRA_DIR
else:
elif os.environ.get('CASSANDRA_VERSION'):
log.info('Using Cassandra version: %s', CCM_VERSION)
CCM_KWARGS['version'] = CCM_VERSION

elif os.getenv('INSTALL_DIRECTORY'):
CCM_KWARGS['install_dir'] = os.path.join(os.getenv('INSTALL_DIRECTORY'))
IS_SCYLLA = True
elif os.getenv('SCYLLA_VERSION'):
CCM_KWARGS['cassandra_version'] = os.path.join(os.getenv('SCYLLA_VERSION'))
IS_SCYLLA = True

#This changes the default contact_point parameter in Cluster
def set_default_cass_ip():
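With the branches above, the integration harness now picks its backend from environment variables. A rough sketch of which variable drives which branch (hypothetical values, shown only for illustration):

    import os

    # Run against a local Scylla build tree (sets IS_SCYLLA, install_dir):
    os.environ['INSTALL_DIRECTORY'] = '/path/to/scylla'   # hypothetical path
    # ...or against a packaged Scylla version via ccm (sets IS_SCYLLA, cassandra_version):
    os.environ['SCYLLA_VERSION'] = 'release:3.1'           # hypothetical version string
    # ...or keep testing Cassandra as before:
    os.environ['CASSANDRA_VERSION'] = '3.11.4'
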
@@ -447,10 +459,10 @@ def is_current_cluster(cluster_name, node_counts, workloads):
if CCM_CLUSTER and CCM_CLUSTER.name == cluster_name:
if [len(list(nodes)) for dc, nodes in
groupby(CCM_CLUSTER.nodelist(), lambda n: n.data_center)] == node_counts:
for node in CCM_CLUSTER.nodelist():
if set(node.workloads) != set(workloads):
print("node workloads don't match creating new cluster")
return False
#for node in CCM_CLUSTER.nodelist():
# if set(node.workloads) != set(workloads):
# print("node workloads don't match creating new cluster")
# return False
return True
return False

@@ -559,8 +571,14 @@ def use_cluster(cluster_name, nodes, ipformat=None, start=True, workloads=None,

CCM_CLUSTER.set_dse_configuration_options(dse_options)
else:
CCM_CLUSTER = CCMCluster(path, cluster_name, **ccm_options)
if IS_SCYLLA:
CCM_CLUSTER = CCMSCyllaCluster(path, cluster_name, **ccm_options)
else:
CCM_CLUSTER = CCMCluster(path, cluster_name, **ccm_options)
CCM_CLUSTER.set_configuration_options({'start_native_transport': True})
if IS_SCYLLA:
CCM_CLUSTER.set_configuration_options({'experimental': True})

if Version(cassandra_version) >= Version('2.2'):
CCM_CLUSTER.set_configuration_options({'enable_user_defined_functions': True})
if Version(cassandra_version) >= Version('3.0'):
@@ -574,9 +592,9 @@ def use_cluster(cluster_name, nodes, ipformat=None, start=True, workloads=None,

# This will enable the Mirroring query handler which will echo our custom payload k,v pairs back

if 'graph' not in workloads:
if PROTOCOL_VERSION >= 4:
jvm_args = [" -Dcassandra.custom_query_handler_class=org.apache.cassandra.cql3.CustomPayloadMirroringQueryHandler"]
#if 'graph' not in workloads:
# if PROTOCOL_VERSION >= 4:
# jvm_args = [" -Dcassandra.custom_query_handler_class=org.apache.cassandra.cql3.CustomPayloadMirroringQueryHandler"]
if len(workloads) > 0:
for node in CCM_CLUSTER.nodes.values():
node.set_workloads(workloads)
4 changes: 2 additions & 2 deletions tests/integration/standard/test_connection.py
@@ -172,8 +172,8 @@ def fetch_connections(self, host, cluster):
if conn._connections is not None and len(conn._connections) > 0:
connections.append(conn._connections)
else:
if conn._connection is not None:
connections.append(conn._connection)
if conn._connections and len(conn._connections.values()) > 0:
connections.append(conn._connections.values())
return connections

def wait_for_connections(self, host, cluster):