[DRAFT]: Connection Reliability Improvements #675

Draft · wants to merge 1 commit into base: main

163 changes: 135 additions & 28 deletions nats/aio/client.py
@@ -85,6 +85,7 @@
DEFAULT_MAX_RECONNECT_ATTEMPTS = 60
DEFAULT_PING_INTERVAL = 120 # in seconds
DEFAULT_MAX_OUTSTANDING_PINGS = 2
DEFAULT_MAX_READ_TIMEOUTS = 3
DEFAULT_MAX_PAYLOAD_SIZE = 1048576
DEFAULT_MAX_FLUSHER_QUEUE_SIZE = 1024
DEFAULT_FLUSH_TIMEOUT = 10 # in seconds
@@ -301,6 +302,7 @@ async def connect(
max_reconnect_attempts: int = DEFAULT_MAX_RECONNECT_ATTEMPTS,
ping_interval: int = DEFAULT_PING_INTERVAL,
max_outstanding_pings: int = DEFAULT_MAX_OUTSTANDING_PINGS,
max_read_timeouts: int = DEFAULT_MAX_READ_TIMEOUTS,
dont_randomize: bool = False,
flusher_queue_size: int = DEFAULT_MAX_FLUSHER_QUEUE_SIZE,
no_echo: bool = False,
@@ -331,6 +333,7 @@ async def connect(
:param discovered_server_cb: Callback to report when a new server joins the cluster.
:param pending_size: Max size of the pending buffer for publishing commands.
:param flush_timeout: Max duration to wait for a forced flush to occur.
:param max_read_timeouts: Maximum number of consecutive read timeouts before considering a connection stale.

Connecting setting all callbacks::

@@ -448,6 +451,7 @@ async def subscribe_handler(msg):
self.options["reconnect_time_wait"] = reconnect_time_wait
self.options["max_reconnect_attempts"] = max_reconnect_attempts
self.options["ping_interval"] = ping_interval
self.options["max_read_timeouts"] = max_read_timeouts
self.options["max_outstanding_pings"] = max_outstanding_pings
self.options["no_echo"] = no_echo
self.options["user"] = user
@@ -1013,19 +1017,38 @@ async def request(
the responses.

"""
if not self.is_connected:
await self._check_connection_health()

if not self.is_connected:
if self.is_closed:
raise errors.ConnectionClosedError
elif self.is_reconnecting:
raise errors.ConnectionReconnectingError
else:
raise errors.ConnectionClosedError

if old_style:
# FIXME: Support headers in old style requests.
return await self._request_old_style(
subject, payload, timeout=timeout
)
try:
return await self._request_old_style(
subject, payload, timeout=timeout
)
except (errors.TimeoutError, asyncio.TimeoutError):
await self._check_connection_health()
raise errors.TimeoutError
else:
msg = await self._request_new_style(
subject, payload, timeout=timeout, headers=headers
)
if (msg.headers and msg.headers.get(nats.js.api.Header.STATUS)
== NO_RESPONDERS_STATUS):
raise errors.NoRespondersError
return msg
try:
msg = await self._request_new_style(
subject, payload, timeout=timeout, headers=headers
)
if msg.headers and msg.headers.get(nats.js.api.Header.STATUS
) == NO_RESPONDERS_STATUS:
raise errors.NoRespondersError
return msg
except (errors.TimeoutError, asyncio.TimeoutError):
await self._check_connection_health()
raise errors.TimeoutError
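From the caller's side, the behavior change above is invisible except that a timed-out request now triggers a connection health check inside the client before the error propagates. A rough sketch, with an illustrative subject and timeout and assuming nc is a connected client:

    from nats.errors import NoRespondersError, TimeoutError as NatsTimeoutError

    async def ask(nc):
        try:
            msg = await nc.request("service.echo", b"ping", timeout=0.5)
            return msg.data
        except NoRespondersError:
            # Nothing is subscribed on the subject.
            return None
        except NatsTimeoutError:
            # Per this draft, the client has already re-checked connection
            # health before this exception reaches the caller.
            raise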

async def _request_new_style(
self,
@@ -1037,6 +1060,9 @@ async def _request_new_style(
if self.is_draining_pubs:
raise errors.ConnectionDrainingError

if not self.is_connected:
raise errors.ConnectionClosedError

if not self._resp_sub_prefix:
await self._init_request_sub()
assert self._resp_sub_prefix
@@ -1049,21 +1075,37 @@

# Then use the future to get the response.
future: asyncio.Future = asyncio.Future()
future.add_done_callback(
lambda f: self._resp_map.pop(token.decode(), None)
)
self._resp_map[token.decode()] = future
token_str = token.decode()

# Publish the request
await self.publish(
subject, payload, reply=inbox.decode(), headers=headers
)
def cleanup_resp_map(f):
self._resp_map.pop(token_str, None)

future.add_done_callback(cleanup_resp_map)
self._resp_map[token_str] = future

# Wait for the response or give up on timeout.
try:
return await asyncio.wait_for(future, timeout)
except asyncio.TimeoutError:
raise errors.TimeoutError
# Publish the request
await self.publish(
subject, payload, reply=inbox.decode(), headers=headers
)

if not self.is_connected:
future.cancel()
raise errors.ConnectionClosedError

try:
return await asyncio.wait_for(future, timeout)
except asyncio.TimeoutError:
cleanup_resp_map(future)
raise errors.TimeoutError
except asyncio.CancelledError:
cleanup_resp_map(future)
raise
except Exception:
if not future.done():
future.cancel()
cleanup_resp_map(future)
raise
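The token/future bookkeeping above boils down to a register, await, always-clean-up pattern. The fragment below restates that pattern in isolation; resp_map, await_reply, and token are generic illustrations, not part of the client's API.

    import asyncio
    from typing import Dict

    resp_map: Dict[str, asyncio.Future] = {}

    async def await_reply(token: str, timeout: float):
        fut = asyncio.get_running_loop().create_future()
        # However the future completes, its map entry is removed exactly once.
        fut.add_done_callback(lambda f: resp_map.pop(token, None))
        resp_map[token] = fut
        try:
            return await asyncio.wait_for(fut, timeout)
        except asyncio.TimeoutError:
            # wait_for cancels the future; pop defensively, as the diff does.
            resp_map.pop(token, None)
            raise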

def new_inbox(self) -> str:
"""
@@ -1399,6 +1441,35 @@ async def _process_err(self, err_msg: str) -> None:
# For now we handle similar as other clients and close.
asyncio.create_task(self._close(Client.CLOSED, do_cbs))

async def _check_connection_health(self) -> bool:
"""
Checks if the connection appears healthy, and if not, attempts reconnection.

Returns:
bool: True if the connection is healthy or was successfully reconnected, False otherwise.
"""
if not self.is_connected:
if self.options[
"allow_reconnect"
] and not self.is_reconnecting and not self.is_closed:
self._status = Client.RECONNECTING
self._ps.reset()

try:
if self._reconnection_task is not None and not self._reconnection_task.cancelled(
):
self._reconnection_task.cancel()

self._reconnection_task = asyncio.get_running_loop(
).create_task(self._attempt_reconnect())

await asyncio.sleep(self.options["reconnect_time_wait"])
return self.is_connected
except Exception:
return False
return False
return True
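Applications never call _check_connection_health() directly; any reconnect it kicks off surfaces through the client's existing connection callbacks. A minimal sketch of wiring those up (handler names and URL are illustrative):

    import nats

    async def on_disconnect():
        print("nats: disconnected")

    async def on_reconnect():
        print("nats: reconnected")

    async def connect_with_callbacks():
        return await nats.connect(
            servers=["nats://127.0.0.1:4222"],
            disconnected_cb=on_disconnect,
            reconnected_cb=on_reconnect,
        )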

async def _process_op_err(self, e: Exception) -> None:
"""
Process errors which occurred while reading or parsing
@@ -2102,8 +2173,16 @@ async def _ping_interval(self) -> None:
await self._send_ping()
except (asyncio.CancelledError, RuntimeError, AttributeError):
break
# except asyncio.InvalidStateError:
# pass
except asyncio.InvalidStateError:
# Handle invalid state errors that can occur when connection state changes
if self.is_connected:
await self._process_op_err(ErrStaleConnection())
break
except Exception as e:
if self.is_connected:
await self._error_cb(e)
await self._process_op_err(ErrStaleConnection())
break

async def _read_loop(self) -> None:
"""
@@ -2112,6 +2191,8 @@ async def _read_loop(self) -> None:
In case of error while reading, it will stop running
and its task has to be rescheduled.
"""
read_timeout_count = 0

while True:
try:
should_bail = self.is_closed or self.is_reconnecting
@@ -2123,21 +2204,47 @@
await self._process_op_err(err)
break

b = await self._transport.read(DEFAULT_BUFFER_SIZE)
await self._ps.parse(b)
# Use a timeout for reading to detect stalled connections
try:
read_future = self._transport.read(DEFAULT_BUFFER_SIZE)
b = await asyncio.wait_for(
read_future, timeout=self.options["ping_interval"]
)
read_timeout_count = 0
await self._ps.parse(b)
except asyncio.TimeoutError:
read_timeout_count += 1
if read_timeout_count >= self.options["max_read_timeouts"]:
err = ErrStaleConnection()
await self._error_cb(err)
await self._process_op_err(err)
break
continue

except errors.ProtocolError:
await self._process_op_err(errors.ProtocolError())
break
except ConnectionResetError as e:
await self._error_cb(e)
await self._process_op_err(errors.ConnectionClosedError())
break
except OSError as e:
await self._error_cb(e)
await self._process_op_err(e)
break
except asyncio.InvalidStateError:
if self.is_connected:
err = ErrStaleConnection()
await self._error_cb(err)
await self._process_op_err(err)
break
except asyncio.CancelledError:
break
except Exception as ex:
await self._error_cb(ex)
await self._process_op_err(ex)
_logger.error("nats: encountered error", exc_info=ex)
break
# except asyncio.InvalidStateError:
# pass
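Taken with the defaults in this diff, the read-timeout counter gives a rough worst case for stale detection: with ping_interval = 120 s and max_read_timeouts = 3, a socket that yields no bytes at all (not even PONG replies) is flagged as stale after roughly 3 × 120 s = 360 s, at which point _process_op_err() drives the usual reconnect path; tightening either option shortens that window.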

async def __aenter__(self) -> "Client":
"""For when NATS client is used in a context manager"""