diff --git a/nats/aio/client.py b/nats/aio/client.py index be379af8..c3439bcc 100644 --- a/nats/aio/client.py +++ b/nats/aio/client.py @@ -85,6 +85,7 @@ DEFAULT_MAX_RECONNECT_ATTEMPTS = 60 DEFAULT_PING_INTERVAL = 120 # in seconds DEFAULT_MAX_OUTSTANDING_PINGS = 2 +DEFAULT_MAX_READ_TIMEOUTS = 3 DEFAULT_MAX_PAYLOAD_SIZE = 1048576 DEFAULT_MAX_FLUSHER_QUEUE_SIZE = 1024 DEFAULT_FLUSH_TIMEOUT = 10 # in seconds @@ -301,6 +302,7 @@ async def connect( max_reconnect_attempts: int = DEFAULT_MAX_RECONNECT_ATTEMPTS, ping_interval: int = DEFAULT_PING_INTERVAL, max_outstanding_pings: int = DEFAULT_MAX_OUTSTANDING_PINGS, + max_read_timeouts: int = DEFAULT_MAX_READ_TIMEOUTS, dont_randomize: bool = False, flusher_queue_size: int = DEFAULT_MAX_FLUSHER_QUEUE_SIZE, no_echo: bool = False, @@ -331,6 +333,7 @@ async def connect( :param discovered_server_cb: Callback to report when a new server joins the cluster. :param pending_size: Max size of the pending buffer for publishing commands. :param flush_timeout: Max duration to wait for a forced flush to occur. + :param max_read_timeouts: Maximum number of consecutive read timeouts before considering a connection stale. Connecting setting all callbacks:: @@ -448,6 +451,7 @@ async def subscribe_handler(msg): self.options["reconnect_time_wait"] = reconnect_time_wait self.options["max_reconnect_attempts"] = max_reconnect_attempts self.options["ping_interval"] = ping_interval + self.options["max_read_timeouts"] = max_read_timeouts self.options["max_outstanding_pings"] = max_outstanding_pings self.options["no_echo"] = no_echo self.options["user"] = user @@ -1013,19 +1017,38 @@ async def request( the responses. """ + if not self.is_connected: + await self._check_connection_health() + + if not self.is_connected: + if self.is_closed: + raise errors.ConnectionClosedError + elif self.is_reconnecting: + raise errors.ConnectionReconnectingError + else: + raise errors.ConnectionClosedError + if old_style: # FIXME: Support headers in old style requests. - return await self._request_old_style( - subject, payload, timeout=timeout - ) + try: + return await self._request_old_style( + subject, payload, timeout=timeout + ) + except (errors.TimeoutError, asyncio.TimeoutError): + await self._check_connection_health() + raise errors.TimeoutError else: - msg = await self._request_new_style( - subject, payload, timeout=timeout, headers=headers - ) - if (msg.headers and msg.headers.get(nats.js.api.Header.STATUS) - == NO_RESPONDERS_STATUS): - raise errors.NoRespondersError - return msg + try: + msg = await self._request_new_style( + subject, payload, timeout=timeout, headers=headers + ) + if msg.headers and msg.headers.get(nats.js.api.Header.STATUS + ) == NO_RESPONDERS_STATUS: + raise errors.NoRespondersError + return msg + except (errors.TimeoutError, asyncio.TimeoutError): + await self._check_connection_health() + raise errors.TimeoutError async def _request_new_style( self, @@ -1037,6 +1060,9 @@ async def _request_new_style( if self.is_draining_pubs: raise errors.ConnectionDrainingError + if not self.is_connected: + raise errors.ConnectionClosedError + if not self._resp_sub_prefix: await self._init_request_sub() assert self._resp_sub_prefix @@ -1049,21 +1075,37 @@ async def _request_new_style( # Then use the future to get the response. future: asyncio.Future = asyncio.Future() - future.add_done_callback( - lambda f: self._resp_map.pop(token.decode(), None) - ) - self._resp_map[token.decode()] = future + token_str = token.decode() - # Publish the request - await self.publish( - subject, payload, reply=inbox.decode(), headers=headers - ) + def cleanup_resp_map(f): + self._resp_map.pop(token_str, None) + + future.add_done_callback(cleanup_resp_map) + self._resp_map[token_str] = future - # Wait for the response or give up on timeout. try: - return await asyncio.wait_for(future, timeout) - except asyncio.TimeoutError: - raise errors.TimeoutError + # Publish the request + await self.publish( + subject, payload, reply=inbox.decode(), headers=headers + ) + + if not self.is_connected: + future.cancel() + raise errors.ConnectionClosedError + + try: + return await asyncio.wait_for(future, timeout) + except asyncio.TimeoutError: + cleanup_resp_map(future) + raise errors.TimeoutError + except asyncio.CancelledError: + cleanup_resp_map(future) + raise + except Exception: + if not future.done(): + future.cancel() + cleanup_resp_map(future) + raise def new_inbox(self) -> str: """ @@ -1399,6 +1441,35 @@ async def _process_err(self, err_msg: str) -> None: # For now we handle similar as other clients and close. asyncio.create_task(self._close(Client.CLOSED, do_cbs)) + async def _check_connection_health(self) -> bool: + """ + Checks if the connection appears healthy, and if not, attempts reconnection. + + Returns: + bool: True if connection is healthy or was successfully reconnected, False otherwise + """ + if not self.is_connected: + if self.options[ + "allow_reconnect" + ] and not self.is_reconnecting and not self.is_closed: + self._status = Client.RECONNECTING + self._ps.reset() + + try: + if self._reconnection_task is not None and not self._reconnection_task.cancelled( + ): + self._reconnection_task.cancel() + + self._reconnection_task = asyncio.get_running_loop( + ).create_task(self._attempt_reconnect()) + + await asyncio.sleep(self.options["reconnect_time_wait"]) + return self.is_connected + except Exception: + return False + return False + return True + async def _process_op_err(self, e: Exception) -> None: """ Process errors which occurred while reading or parsing @@ -2102,8 +2173,16 @@ async def _ping_interval(self) -> None: await self._send_ping() except (asyncio.CancelledError, RuntimeError, AttributeError): break - # except asyncio.InvalidStateError: - # pass + except asyncio.InvalidStateError: + # Handle invalid state errors that can occur when connection state changes + if self.is_connected: + await self._process_op_err(ErrStaleConnection()) + break + except Exception as e: + if self.is_connected: + await self._error_cb(e) + await self._process_op_err(ErrStaleConnection()) + break async def _read_loop(self) -> None: """ @@ -2112,6 +2191,8 @@ async def _read_loop(self) -> None: In case of error while reading, it will stop running and its task has to be rescheduled. """ + read_timeout_count = 0 + while True: try: should_bail = self.is_closed or self.is_reconnecting @@ -2123,21 +2204,47 @@ async def _read_loop(self) -> None: await self._process_op_err(err) break - b = await self._transport.read(DEFAULT_BUFFER_SIZE) - await self._ps.parse(b) + # Use a timeout for reading to detect stalled connections + try: + read_future = self._transport.read(DEFAULT_BUFFER_SIZE) + b = await asyncio.wait_for( + read_future, timeout=self.options["ping_interval"] + ) + read_timeout_count = 0 + await self._ps.parse(b) + except asyncio.TimeoutError: + read_timeout_count += 1 + if read_timeout_count >= self.options["max_read_timeouts"]: + err = ErrStaleConnection() + await self._error_cb(err) + await self._process_op_err(err) + break + continue + except errors.ProtocolError: await self._process_op_err(errors.ProtocolError()) break + except ConnectionResetError as e: + await self._error_cb(e) + await self._process_op_err(errors.ConnectionClosedError()) + break except OSError as e: + await self._error_cb(e) await self._process_op_err(e) break + except asyncio.InvalidStateError: + if self.is_connected: + err = ErrStaleConnection() + await self._error_cb(err) + await self._process_op_err(err) + break except asyncio.CancelledError: break except Exception as ex: + await self._error_cb(ex) + await self._process_op_err(ex) _logger.error("nats: encountered error", exc_info=ex) break - # except asyncio.InvalidStateError: - # pass async def __aenter__(self) -> "Client": """For when NATS client is used in a context manager"""