diff --git a/kafka/client_async.py b/kafka/client_async.py index 2597fff61..6fe47c6f7 100644 --- a/kafka/client_async.py +++ b/kafka/client_async.py @@ -303,7 +303,7 @@ def _can_connect(self, node_id): def _conn_state_change(self, node_id, sock, conn): with self._lock: - if conn.connecting(): + if conn.state is ConnectionStates.CONNECTING: # SSL connections can enter this state 2x (second during Handshake) if node_id not in self._connecting: self._connecting.add(node_id) @@ -315,7 +315,19 @@ def _conn_state_change(self, node_id, sock, conn): if self.cluster.is_bootstrap(node_id): self._last_bootstrap = time.time() - elif conn.connected(): + elif conn.state is ConnectionStates.API_VERSIONS_SEND: + try: + self._selector.register(sock, selectors.EVENT_WRITE, conn) + except KeyError: + self._selector.modify(sock, selectors.EVENT_WRITE, conn) + + elif conn.state in (ConnectionStates.API_VERSIONS_RECV, ConnectionStates.AUTHENTICATING): + try: + self._selector.register(sock, selectors.EVENT_READ, conn) + except KeyError: + self._selector.modify(sock, selectors.EVENT_READ, conn) + + elif conn.state is ConnectionStates.CONNECTED: log.debug("Node %s connected", node_id) if node_id in self._connecting: self._connecting.remove(node_id) @@ -332,6 +344,8 @@ def _conn_state_change(self, node_id, sock, conn): if self.cluster.is_bootstrap(node_id): self._bootstrap_fails = 0 + if self._api_versions is None: + self._api_versions = conn._api_versions else: for node_id in list(self._conns.keys()): @@ -970,15 +984,14 @@ def refresh_done(val_or_error): def get_api_versions(self): """Return the ApiVersions map, if available. - Note: A call to check_version must previously have succeeded and returned - version 0.10.0 or later + Note: Only available after bootstrap; requires broker version 0.10.0 or later. Returns: a map of dict mapping {api_key : (min_version, max_version)}, or None if ApiVersion is not supported by the kafka cluster. """ return self._api_versions - def check_version(self, node_id=None, timeout=None, strict=False): + def check_version(self, node_id=None, timeout=None, **kwargs): """Attempt to guess the version of a Kafka broker. Keyword Arguments: @@ -994,50 +1007,45 @@ def check_version(self, node_id=None, timeout=None, strict=False): Raises: NodeNotReadyError (if node_id is provided) NoBrokersAvailable (if node_id is None) - UnrecognizedBrokerVersion: please file bug if seen! - AssertionError (if strict=True): please file bug if seen! """ timeout = timeout or (self.config['api_version_auto_timeout_ms'] / 1000) - self._lock.acquire() - end = time.time() + timeout - while time.time() < end: - - # It is possible that least_loaded_node falls back to bootstrap, - # which can block for an increasing backoff period - try_node = node_id or self.least_loaded_node() - if try_node is None: - self._lock.release() - raise Errors.NoBrokersAvailable() - if not self._init_connect(try_node): - if try_node == node_id: - raise Errors.NodeNotReadyError("Connection failed to %s" % node_id) - else: + with self._lock: + end = time.time() + timeout + while time.time() < end: + time_remaining = max(end - time.time(), 0) + if node_id is not None and self.connection_delay(node_id) > 0: + sleep_time = min(time_remaining, self.connection_delay(node_id) / 1000.0) + if sleep_time > 0: + time.sleep(sleep_time) continue - - conn = self._conns[try_node] - - # We will intentionally cause socket failures - # These should not trigger metadata refresh - self._refresh_on_disconnects = False - try: - remaining = end - time.time() - version = conn.check_version(timeout=remaining, strict=strict, topics=list(self.config['bootstrap_topics_filter'])) - if not self._api_versions: - self._api_versions = conn.get_api_versions() - self._lock.release() - return version - except Errors.NodeNotReadyError: - # Only raise to user if this is a node-specific request + try_node = node_id or self.least_loaded_node() + if try_node is None: + sleep_time = min(time_remaining, self.least_loaded_node_refresh_ms() / 1000.0) + if sleep_time > 0: + log.warning('No node available during check_version; sleeping %.2f secs', sleep_time) + time.sleep(sleep_time) + continue + log.debug('Attempting to check version with node %s', try_node) + if not self._init_connect(try_node): + if try_node == node_id: + raise Errors.NodeNotReadyError("Connection failed to %s" % node_id) + else: + continue + conn = self._conns[try_node] + + while conn.connecting() and time.time() < end: + timeout_ms = min((end - time.time()) * 1000, 200) + self.poll(timeout_ms=timeout_ms) + + if conn._api_version is not None: + return conn._api_version + + # Timeout + else: if node_id is not None: - self._lock.release() - raise - finally: - self._refresh_on_disconnects = True - - # Timeout - else: - self._lock.release() - raise Errors.NoBrokersAvailable() + raise Errors.NodeNotReadyError(node_id) + else: + raise Errors.NoBrokersAvailable() def api_version(self, operation, max_version=None): """Find the latest version of the protocol operation supported by both @@ -1063,7 +1071,7 @@ def api_version(self, operation, max_version=None): broker_api_versions = self._api_versions api_key = operation[0].API_KEY if broker_api_versions is None or api_key not in broker_api_versions: - raise IncompatibleBrokerVersion( + raise Errors.IncompatibleBrokerVersion( "Kafka broker does not support the '{}' Kafka protocol." .format(operation[0].__name__)) broker_min_version, broker_max_version = broker_api_versions[api_key] @@ -1071,7 +1079,7 @@ def api_version(self, operation, max_version=None): if version < broker_min_version: # max library version is less than min broker version. Currently, # no Kafka versions specify a min msg version. Maybe in the future? - raise IncompatibleBrokerVersion( + raise Errors.IncompatibleBrokerVersion( "No version of the '{}' Kafka protocol is supported by both the client and broker." .format(operation[0].__name__)) return version diff --git a/kafka/conn.py b/kafka/conn.py index 6aa20117e..fd6943171 100644 --- a/kafka/conn.py +++ b/kafka/conn.py @@ -24,8 +24,11 @@ from kafka.future import Future from kafka.metrics.stats import Avg, Count, Max, Rate from kafka.oauth.abstract import AbstractTokenProvider -from kafka.protocol.admin import SaslHandShakeRequest, DescribeAclsRequest, DescribeClientQuotasRequest +from kafka.protocol.admin import DescribeAclsRequest, DescribeClientQuotasRequest, ListGroupsRequest, SaslHandShakeRequest +from kafka.protocol.api_versions import ApiVersionsRequest +from kafka.protocol.broker_api_versions import BROKER_API_VERSIONS from kafka.protocol.commit import OffsetFetchRequest +from kafka.protocol.find_coordinator import FindCoordinatorRequest from kafka.protocol.list_offsets import ListOffsetsRequest from kafka.protocol.produce import ProduceRequest from kafka.protocol.metadata import MetadataRequest @@ -92,12 +95,13 @@ class SSLWantWriteError(Exception): class ConnectionStates(object): - DISCONNECTING = '' DISCONNECTED = '' CONNECTING = '' HANDSHAKE = '' CONNECTED = '' AUTHENTICATING = '' + API_VERSIONS_SEND = '' + API_VERSIONS_RECV = '' class BrokerConnection(object): @@ -169,7 +173,7 @@ class BrokerConnection(object): Default: None api_version_auto_timeout_ms (int): number of milliseconds to throw a timeout exception from the constructor when checking the broker - api version. Only applies if api_version is None + api version. Only applies if api_version is None. Default: 2000. selector (selectors.BaseSelector): Provide a specific selector implementation to use for I/O multiplexing. Default: selectors.DefaultSelector @@ -215,6 +219,7 @@ class BrokerConnection(object): 'ssl_password': None, 'ssl_ciphers': None, 'api_version': None, + 'api_version_auto_timeout_ms': 2000, 'selector': selectors.DefaultSelector, 'state_change_callback': lambda node_id, sock, conn: True, 'metrics': None, @@ -228,6 +233,12 @@ class BrokerConnection(object): } SECURITY_PROTOCOLS = ('PLAINTEXT', 'SSL', 'SASL_PLAINTEXT', 'SASL_SSL') SASL_MECHANISMS = ('PLAIN', 'GSSAPI', 'OAUTHBEARER', "SCRAM-SHA-256", "SCRAM-SHA-512") + VERSION_CHECKS = ( + ((0, 9), ListGroupsRequest[0]()), + ((0, 8, 2), FindCoordinatorRequest[0]('kafka-python-default-group')), + ((0, 8, 1), OffsetFetchRequest[0]('kafka-python-default-group', [])), + ((0, 8, 0), MetadataRequest[0]([])), + ) def __init__(self, host, port, afi, **configs): self.host = host @@ -236,6 +247,9 @@ def __init__(self, host, port, afi, **configs): self._sock_afi = afi self._sock_addr = None self._api_versions = None + self._api_version = None + self._check_version_idx = None + self._api_versions_idx = 2 self._throttle_time = None self.config = copy.copy(self.DEFAULT_CONFIG) @@ -301,6 +315,7 @@ def __init__(self, host, port, afi, **configs): self._ssl_context = None if self.config['ssl_context'] is not None: self._ssl_context = self.config['ssl_context'] + self._api_versions_future = None self._sasl_auth_future = None self.last_attempt = 0 self._gai = [] @@ -404,17 +419,9 @@ def connect(self): self.config['state_change_callback'](self.node_id, self._sock, self) # _wrap_ssl can alter the connection state -- disconnects on failure self._wrap_ssl() - - elif self.config['security_protocol'] == 'SASL_PLAINTEXT': - log.debug('%s: initiating SASL authentication', self) - self.state = ConnectionStates.AUTHENTICATING - self.config['state_change_callback'](self.node_id, self._sock, self) - else: - # security_protocol PLAINTEXT - log.info('%s: Connection complete.', self) - self.state = ConnectionStates.CONNECTED - self._reset_reconnect_backoff() + log.debug('%s: checking broker Api Versions', self) + self.state = ConnectionStates.API_VERSIONS_SEND self.config['state_change_callback'](self.node_id, self._sock, self) # Connection failed @@ -433,15 +440,25 @@ def connect(self): if self.state is ConnectionStates.HANDSHAKE: if self._try_handshake(): log.debug('%s: completed SSL handshake.', self) - if self.config['security_protocol'] == 'SASL_SSL': - log.debug('%s: initiating SASL authentication', self) - self.state = ConnectionStates.AUTHENTICATING - else: - log.info('%s: Connection complete.', self) - self.state = ConnectionStates.CONNECTED - self._reset_reconnect_backoff() + log.debug('%s: checking broker Api Versions', self) + self.state = ConnectionStates.API_VERSIONS_SEND self.config['state_change_callback'](self.node_id, self._sock, self) + if self.state in (ConnectionStates.API_VERSIONS_SEND, ConnectionStates.API_VERSIONS_RECV): + if self._try_api_versions_check(): + # _try_api_versions_check has side-effects: possibly disconnected on socket errors + if self.state in (ConnectionStates.API_VERSIONS_SEND, ConnectionStates.API_VERSIONS_RECV): + if self.config['security_protocol'] in ('SASL_PLAINTEXT', 'SASL_SSL'): + log.debug('%s: initiating SASL authentication', self) + self.state = ConnectionStates.AUTHENTICATING + self.config['state_change_callback'](self.node_id, self._sock, self) + else: + # security_protocol PLAINTEXT + log.info('%s: Connection complete.', self) + self.state = ConnectionStates.CONNECTED + self._reset_reconnect_backoff() + self.config['state_change_callback'](self.node_id, self._sock, self) + if self.state is ConnectionStates.AUTHENTICATING: assert self.config['security_protocol'] in ('SASL_PLAINTEXT', 'SASL_SSL') if self._try_authenticate(): @@ -522,6 +539,87 @@ def _try_handshake(self): return False + def _try_api_versions_check(self): + if self._api_versions_future is None: + if self.config['api_version'] is not None: + self._api_version = self.config['api_version'] + self._api_versions = BROKER_API_VERSIONS[self._api_version] + return True + elif self._check_version_idx is None: + request = ApiVersionsRequest[self._api_versions_idx]() + future = Future() + response = self._send(request, blocking=True, request_timeout_ms=(self.config['api_version_auto_timeout_ms'] * 0.8)) + response.add_callback(self._handle_api_versions_response, future) + response.add_errback(self._handle_api_versions_failure, future) + self._api_versions_future = future + self.state = ConnectionStates.API_VERSIONS_RECV + self.config['state_change_callback'](self.node_id, self._sock, self) + elif self._check_version_idx < len(self.VERSION_CHECKS): + version, request = self.VERSION_CHECKS[self._check_version_idx] + future = Future() + response = self._send(request, blocking=True, request_timeout_ms=(self.config['api_version_auto_timeout_ms'] * 0.8)) + response.add_callback(self._handle_check_version_response, future, version) + response.add_errback(self._handle_check_version_failure, future) + self._api_versions_future = future + self.state = ConnectionStates.API_VERSIONS_RECV + self.config['state_change_callback'](self.node_id, self._sock, self) + else: + raise 'Unable to determine broker version.' + + for r, f in self.recv(): + f.success(r) + + # A connection error during blocking send could trigger close() which will reset the future + if self._api_versions_future is None: + return False + elif self._api_versions_future.failed(): + ex = self._api_versions_future.exception + if not isinstance(ex, Errors.KafkaConnectionError): + raise ex + return self._api_versions_future.succeeded() + + def _handle_api_versions_response(self, future, response): + error_type = Errors.for_code(response.error_code) + # if error_type i UNSUPPORTED_VERSION: retry w/ latest version from response + if error_type is not Errors.NoError: + future.failure(error_type()) + if error_type is Errors.UnsupportedVersionError: + self._api_versions_idx -= 1 + if self._api_versions_idx >= 0: + self._api_versions_future = None + self.state = ConnectionStates.API_VERSIONS_SEND + self.config['state_change_callback'](self.node_id, self._sock, self) + else: + self.close(error=error_type()) + return + self._api_versions = dict([ + (api_key, (min_version, max_version)) + for api_key, min_version, max_version in response.api_versions + ]) + self._api_version = self._infer_broker_version_from_api_versions(self._api_versions) + log.info('Broker version identified as %s', '.'.join(map(str, self._api_version))) + future.success(self._api_version) + self.connect() + + def _handle_api_versions_failure(self, future, ex): + future.failure(ex) + self._check_version_idx = 0 + # after failure connection is closed, so state should already be DISCONNECTED + + def _handle_check_version_response(self, future, version, _response): + log.info('Broker version identified as %s', '.'.join(map(str, version))) + log.info('Set configuration api_version=%s to skip auto' + ' check_version requests on startup', version) + self._api_versions = BROKER_API_VERSIONS[version] + self._api_version = version + future.success(version) + self.connect() + + def _handle_check_version_failure(self, future, ex): + future.failure(ex) + self._check_version_idx += 1 + # after failure connection is closed, so state should already be DISCONNECTED + def _try_authenticate(self): assert self.config['api_version'] is None or self.config['api_version'] >= (0, 10, 0) @@ -529,7 +627,7 @@ def _try_authenticate(self): # Build a SaslHandShakeRequest message request = SaslHandShakeRequest[0](self.config['sasl_mechanism']) future = Future() - sasl_response = self._send(request) + sasl_response = self._send(request, blocking=True) sasl_response.add_callback(self._handle_sasl_handshake_response, future) sasl_response.add_errback(lambda f, e: f.failure(e), future) self._sasl_auth_future = future @@ -554,23 +652,28 @@ def _handle_sasl_handshake_response(self, future, response): return future.failure(error_type(self)) if self.config['sasl_mechanism'] not in response.enabled_mechanisms: - return future.failure( + future.failure( Errors.UnsupportedSaslMechanismError( 'Kafka broker does not support %s sasl mechanism. Enabled mechanisms are: %s' % (self.config['sasl_mechanism'], response.enabled_mechanisms))) elif self.config['sasl_mechanism'] == 'PLAIN': - return self._try_authenticate_plain(future) + self._try_authenticate_plain(future) elif self.config['sasl_mechanism'] == 'GSSAPI': - return self._try_authenticate_gssapi(future) + self._try_authenticate_gssapi(future) elif self.config['sasl_mechanism'] == 'OAUTHBEARER': - return self._try_authenticate_oauth(future) + self._try_authenticate_oauth(future) elif self.config['sasl_mechanism'].startswith("SCRAM-SHA-"): - return self._try_authenticate_scram(future) + self._try_authenticate_scram(future) else: - return future.failure( + future.failure( Errors.UnsupportedSaslMechanismError( 'kafka-python does not support SASL mechanism %s' % self.config['sasl_mechanism'])) + assert future.is_done, 'SASL future not complete after mechanism processing!' + if future.failed(): + self.close(error=future.exception) + else: + self.connect() def _send_bytes(self, data): """Send some data via non-blocking IO @@ -901,7 +1004,17 @@ def connecting(self): different states, such as SSL handshake, authorization, etc).""" return self.state in (ConnectionStates.CONNECTING, ConnectionStates.HANDSHAKE, - ConnectionStates.AUTHENTICATING) + ConnectionStates.AUTHENTICATING, + ConnectionStates.API_VERSIONS_SEND, + ConnectionStates.API_VERSIONS_RECV) + + def initializing(self): + """Returns True if socket is connected but full connection is not complete. + During this time the connection may send api requests to the broker to + check api versions and perform SASL authentication.""" + return self.state in (ConnectionStates.AUTHENTICATING, + ConnectionStates.API_VERSIONS_SEND, + ConnectionStates.API_VERSIONS_RECV) def disconnected(self): """Return True iff socket is closed""" @@ -949,6 +1062,7 @@ def close(self, error=None): return log.log(logging.ERROR if error else logging.INFO, '%s: Closing connection. %s', self, error or '') self._update_reconnect_backoff() + self._api_versions_future = None self._sasl_auth_future = None self._protocol = KafkaProtocol( client_id=self.config['client_id'], @@ -975,8 +1089,7 @@ def close(self, error=None): def _can_send_recv(self): """Return True iff socket is ready for requests / responses""" - return self.state in (ConnectionStates.AUTHENTICATING, - ConnectionStates.CONNECTED) + return self.connected() or self.initializing() def send(self, request, blocking=True, request_timeout_ms=None): """Queue request for async network send, return Future() @@ -1218,16 +1331,6 @@ def next_ifr_request_timeout_ms(self): else: return float('inf') - def _handle_api_versions_response(self, response): - error_type = Errors.for_code(response.error_code) - if error_type is not Errors.NoError: - return False - self._api_versions = dict([ - (api_key, (min_version, max_version)) - for api_key, min_version, max_version in response.api_versions - ]) - return self._api_versions - def get_api_versions(self): if self._api_versions is not None: return self._api_versions @@ -1242,6 +1345,20 @@ def _infer_broker_version_from_api_versions(self, api_versions): test_cases = [ # format (, ) # Make sure to update consumer_integration test check when adding newer versions. + # ((3, 9), FetchRequest[17]), + # ((3, 8), ProduceRequest[11]), + # ((3, 7), FetchRequest[16]), + # ((3, 6), AddPartitionsToTxnRequest[4]), + # ((3, 5), FetchRequest[15]), + # ((3, 4), StopReplicaRequest[3]), # broker-internal api... + # ((3, 3), DescribeAclsRequest[3]), + # ((3, 2), JoinGroupRequest[9]), + # ((3, 1), FetchRequest[13]), + # ((3, 0), ListOffsetsRequest[7]), + # ((2, 8), ProduceRequest[9]), + # ((2, 7), FetchRequest[12]), + # ((2, 6), ListGroupsRequest[4]), + # ((2, 5), JoinGroupRequest[7]), ((2, 6), DescribeClientQuotasRequest[0]), ((2, 5), DescribeAclsRequest[2]), ((2, 4), ProduceRequest[8]), @@ -1268,121 +1385,24 @@ def _infer_broker_version_from_api_versions(self, api_versions): # so if all else fails, choose that return (0, 10, 0) - def check_version(self, timeout=2, strict=False, topics=[]): + def check_version(self, timeout=2, **kwargs): """Attempt to guess the broker version. + Keyword Arguments: + timeout (numeric, optional): Maximum number of seconds to block attempting + to connect and check version. Default 2 + Note: This is a blocking call. Returns: version tuple, i.e. (3, 9), (2, 4), etc ... + + Raises: NodeNotReadyError on timeout """ timeout_at = time.time() + timeout - log.info('Probing node %s broker version', self.node_id) - # Monkeypatch some connection configurations to avoid timeouts - override_config = { - 'request_timeout_ms': timeout * 1000, - 'max_in_flight_requests_per_connection': 5 - } - stashed = {} - for key in override_config: - stashed[key] = self.config[key] - self.config[key] = override_config[key] - - def reset_override_configs(): - for key in stashed: - self.config[key] = stashed[key] - - # kafka kills the connection when it doesn't recognize an API request - # so we can send a test request and then follow immediately with a - # vanilla MetadataRequest. If the server did not recognize the first - # request, both will be failed with a ConnectionError that wraps - # socket.error (32, 54, or 104) - from kafka.protocol.admin import ListGroupsRequest - from kafka.protocol.api_versions import ApiVersionsRequest - from kafka.protocol.broker_api_versions import BROKER_API_VERSIONS - from kafka.protocol.commit import OffsetFetchRequest - from kafka.protocol.find_coordinator import FindCoordinatorRequest - - test_cases = [ - # All cases starting from 0.10 will be based on ApiVersionsResponse - ((0, 11), ApiVersionsRequest[1]()), - ((0, 10, 0), ApiVersionsRequest[0]()), - ((0, 9), ListGroupsRequest[0]()), - ((0, 8, 2), FindCoordinatorRequest[0]('kafka-python-default-group')), - ((0, 8, 1), OffsetFetchRequest[0]('kafka-python-default-group', [])), - ((0, 8, 0), MetadataRequest[0](topics)), - ] - - for version, request in test_cases: - if not self.connect_blocking(timeout_at - time.time()): - reset_override_configs() - raise Errors.NodeNotReadyError() - f = self.send(request) - # HACK: sleeping to wait for socket to send bytes - time.sleep(0.1) - # when broker receives an unrecognized request API - # it abruptly closes our socket. - # so we attempt to send a second request immediately - # that we believe it will definitely recognize (metadata) - # the attempt to write to a disconnected socket should - # immediately fail and allow us to infer that the prior - # request was unrecognized - mr = self.send(MetadataRequest[0](topics)) - - if not (f.is_done and mr.is_done) and self._sock is not None: - selector = self.config['selector']() - selector.register(self._sock, selectors.EVENT_READ) - while not (f.is_done and mr.is_done): - selector.select(1) - for response, future in self.recv(): - future.success(response) - selector.close() - - if f.succeeded(): - if version >= (0, 10, 0): - # Starting from 0.10 kafka broker we determine version - # by looking at ApiVersionsResponse - api_versions = self._handle_api_versions_response(f.value) - if not api_versions: - continue - version = self._infer_broker_version_from_api_versions(api_versions) - else: - if version not in BROKER_API_VERSIONS: - raise Errors.UnrecognizedBrokerVersion(version) - self._api_versions = BROKER_API_VERSIONS[version] - log.info('Broker version identified as %s', '.'.join(map(str, version))) - log.info('Set configuration api_version=%s to skip auto' - ' check_version requests on startup', version) - break - - # Only enable strict checking to verify that we understand failure - # modes. For most users, the fact that the request failed should be - # enough to rule out a particular broker version. - if strict: - # If the socket flush hack did not work (which should force the - # connection to close and fail all pending requests), then we - # get a basic Request Timeout. This is not ideal, but we'll deal - if isinstance(f.exception, Errors.RequestTimedOutError): - pass - - # 0.9 brokers do not close the socket on unrecognized api - # requests (bug...). In this case we expect to see a correlation - # id mismatch - elif (isinstance(f.exception, Errors.CorrelationIdError) and - version > (0, 9)): - pass - elif six.PY2: - assert isinstance(f.exception.args[0], socket.error) - assert f.exception.args[0].errno in (32, 54, 104) - else: - assert isinstance(f.exception.args[0], ConnectionError) - log.info("Broker is not v%s -- it did not recognize %s", - version, request.__class__.__name__) + if not self.connect_blocking(timeout_at - time.time()): + raise Errors.NodeNotReadyError() else: - reset_override_configs() - raise Errors.UnrecognizedBrokerVersion() - - reset_override_configs() - return version + return self._api_version def __str__(self): return "" % ( diff --git a/kafka/protocol/api_versions.py b/kafka/protocol/api_versions.py index 9a782928b..dc0aa588e 100644 --- a/kafka/protocol/api_versions.py +++ b/kafka/protocol/api_versions.py @@ -76,8 +76,8 @@ class ApiVersionsRequest_v1(Request): class ApiVersionsRequest_v2(Request): API_KEY = 18 API_VERSION = 2 - RESPONSE_TYPE = ApiVersionsResponse_v1 - SCHEMA = ApiVersionsRequest_v0.SCHEMA + RESPONSE_TYPE = ApiVersionsResponse_v2 + SCHEMA = ApiVersionsRequest_v1.SCHEMA ApiVersionsRequest = [ diff --git a/test/test_conn.py b/test/test_conn.py index 47f5c428e..959cbb4dc 100644 --- a/test/test_conn.py +++ b/test/test_conn.py @@ -15,6 +15,13 @@ import kafka.errors as Errors +from kafka.vendor import six + +if six.PY2: + ConnectionError = socket.error + TimeoutError = socket.error + BlockingIOError = Exception + @pytest.fixture def dns_lookup(mocker): @@ -27,13 +34,16 @@ def dns_lookup(mocker): def _socket(mocker): socket = mocker.MagicMock() socket.connect_ex.return_value = 0 + socket.send.side_effect = lambda d: len(d) + socket.recv.side_effect = BlockingIOError("mocked recv") mocker.patch('socket.socket', return_value=socket) return socket @pytest.fixture -def conn(_socket, dns_lookup): +def conn(_socket, dns_lookup, mocker): conn = BrokerConnection('localhost', 9092, socket.AF_INET) + mocker.patch.object(conn, '_try_api_versions_check', return_value=True) return conn @@ -217,12 +227,13 @@ def test_recv_disconnected(_socket, conn): conn.send(req) # Empty data on recv means the socket is disconnected + _socket.recv.side_effect = None _socket.recv.return_value = b'' # Attempt to receive should mark connection as disconnected - assert conn.connected() + assert conn.connected(), 'Not connected: %s' % conn.state conn.recv() - assert conn.disconnected() + assert conn.disconnected(), 'Not disconnected: %s' % conn.state def test_recv(_socket, conn):