Skip to content

Commit 4a20463

Browse files
committed
improve long running connection stability
1 parent 4302a50 commit 4a20463

File tree

1 file changed

+135
-28
lines changed

1 file changed

+135
-28
lines changed

nats/aio/client.py

Lines changed: 135 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@
8585
DEFAULT_MAX_RECONNECT_ATTEMPTS = 60
8686
DEFAULT_PING_INTERVAL = 120 # in seconds
8787
DEFAULT_MAX_OUTSTANDING_PINGS = 2
88+
DEFAULT_MAX_READ_TIMEOUTS = 3
8889
DEFAULT_MAX_PAYLOAD_SIZE = 1048576
8990
DEFAULT_MAX_FLUSHER_QUEUE_SIZE = 1024
9091
DEFAULT_FLUSH_TIMEOUT = 10 # in seconds
@@ -301,6 +302,7 @@ async def connect(
301302
max_reconnect_attempts: int = DEFAULT_MAX_RECONNECT_ATTEMPTS,
302303
ping_interval: int = DEFAULT_PING_INTERVAL,
303304
max_outstanding_pings: int = DEFAULT_MAX_OUTSTANDING_PINGS,
305+
max_read_timeouts: int = DEFAULT_MAX_READ_TIMEOUTS,
304306
dont_randomize: bool = False,
305307
flusher_queue_size: int = DEFAULT_MAX_FLUSHER_QUEUE_SIZE,
306308
no_echo: bool = False,
@@ -331,6 +333,7 @@ async def connect(
331333
:param discovered_server_cb: Callback to report when a new server joins the cluster.
332334
:param pending_size: Max size of the pending buffer for publishing commands.
333335
:param flush_timeout: Max duration to wait for a forced flush to occur.
336+
:param max_read_timeouts: Maximum number of consecutive read timeouts before considering a connection stale.
334337
335338
Connecting setting all callbacks::
336339
@@ -448,6 +451,7 @@ async def subscribe_handler(msg):
448451
self.options["reconnect_time_wait"] = reconnect_time_wait
449452
self.options["max_reconnect_attempts"] = max_reconnect_attempts
450453
self.options["ping_interval"] = ping_interval
454+
self.options["max_read_timeouts"] = max_read_timeouts
451455
self.options["max_outstanding_pings"] = max_outstanding_pings
452456
self.options["no_echo"] = no_echo
453457
self.options["user"] = user
@@ -1013,19 +1017,38 @@ async def request(
10131017
the responses.
10141018
10151019
"""
1020+
if not self.is_connected:
1021+
await self._check_connection_health()
1022+
1023+
if not self.is_connected:
1024+
if self.is_closed:
1025+
raise errors.ConnectionClosedError
1026+
elif self.is_reconnecting:
1027+
raise errors.ConnectionReconnectingError
1028+
else:
1029+
raise errors.ConnectionClosedError
1030+
10161031
if old_style:
10171032
# FIXME: Support headers in old style requests.
1018-
return await self._request_old_style(
1019-
subject, payload, timeout=timeout
1020-
)
1033+
try:
1034+
return await self._request_old_style(
1035+
subject, payload, timeout=timeout
1036+
)
1037+
except (errors.TimeoutError, asyncio.TimeoutError):
1038+
await self._check_connection_health()
1039+
raise errors.TimeoutError
10211040
else:
1022-
msg = await self._request_new_style(
1023-
subject, payload, timeout=timeout, headers=headers
1024-
)
1025-
if (msg.headers and msg.headers.get(nats.js.api.Header.STATUS)
1026-
== NO_RESPONDERS_STATUS):
1027-
raise errors.NoRespondersError
1028-
return msg
1041+
try:
1042+
msg = await self._request_new_style(
1043+
subject, payload, timeout=timeout, headers=headers
1044+
)
1045+
if msg.headers and msg.headers.get(nats.js.api.Header.STATUS
1046+
) == NO_RESPONDERS_STATUS:
1047+
raise errors.NoRespondersError
1048+
return msg
1049+
except (errors.TimeoutError, asyncio.TimeoutError):
1050+
await self._check_connection_health()
1051+
raise errors.TimeoutError
10291052

10301053
async def _request_new_style(
10311054
self,
@@ -1037,6 +1060,9 @@ async def _request_new_style(
10371060
if self.is_draining_pubs:
10381061
raise errors.ConnectionDrainingError
10391062

1063+
if not self.is_connected:
1064+
raise errors.ConnectionClosedError
1065+
10401066
if not self._resp_sub_prefix:
10411067
await self._init_request_sub()
10421068
assert self._resp_sub_prefix
@@ -1049,21 +1075,37 @@ async def _request_new_style(
10491075

10501076
# Then use the future to get the response.
10511077
future: asyncio.Future = asyncio.Future()
1052-
future.add_done_callback(
1053-
lambda f: self._resp_map.pop(token.decode(), None)
1054-
)
1055-
self._resp_map[token.decode()] = future
1078+
token_str = token.decode()
10561079

1057-
# Publish the request
1058-
await self.publish(
1059-
subject, payload, reply=inbox.decode(), headers=headers
1060-
)
1080+
def cleanup_resp_map(f):
1081+
self._resp_map.pop(token_str, None)
1082+
1083+
future.add_done_callback(cleanup_resp_map)
1084+
self._resp_map[token_str] = future
10611085

1062-
# Wait for the response or give up on timeout.
10631086
try:
1064-
return await asyncio.wait_for(future, timeout)
1065-
except asyncio.TimeoutError:
1066-
raise errors.TimeoutError
1087+
# Publish the request
1088+
await self.publish(
1089+
subject, payload, reply=inbox.decode(), headers=headers
1090+
)
1091+
1092+
if not self.is_connected:
1093+
future.cancel()
1094+
raise errors.ConnectionClosedError
1095+
1096+
try:
1097+
return await asyncio.wait_for(future, timeout)
1098+
except asyncio.TimeoutError:
1099+
cleanup_resp_map(future)
1100+
raise errors.TimeoutError
1101+
except asyncio.CancelledError:
1102+
cleanup_resp_map(future)
1103+
raise
1104+
except Exception:
1105+
if not future.done():
1106+
future.cancel()
1107+
cleanup_resp_map(future)
1108+
raise
10671109

10681110
def new_inbox(self) -> str:
10691111
"""
@@ -1399,6 +1441,35 @@ async def _process_err(self, err_msg: str) -> None:
13991441
# For now we handle similar as other clients and close.
14001442
asyncio.create_task(self._close(Client.CLOSED, do_cbs))
14011443

1444+
async def _check_connection_health(self) -> bool:
1445+
"""
1446+
Checks if the connection appears healthy, and if not, attempts reconnection.
1447+
1448+
Returns:
1449+
bool: True if connection is healthy or was successfully reconnected, False otherwise
1450+
"""
1451+
if not self.is_connected:
1452+
if self.options[
1453+
"allow_reconnect"
1454+
] and not self.is_reconnecting and not self.is_closed:
1455+
self._status = Client.RECONNECTING
1456+
self._ps.reset()
1457+
1458+
try:
1459+
if self._reconnection_task is not None and not self._reconnection_task.cancelled(
1460+
):
1461+
self._reconnection_task.cancel()
1462+
1463+
self._reconnection_task = asyncio.get_running_loop(
1464+
).create_task(self._attempt_reconnect())
1465+
1466+
await asyncio.sleep(self.options["reconnect_time_wait"])
1467+
return self.is_connected
1468+
except Exception:
1469+
return False
1470+
return False
1471+
return True
1472+
14021473
async def _process_op_err(self, e: Exception) -> None:
14031474
"""
14041475
Process errors which occurred while reading or parsing
@@ -2102,8 +2173,16 @@ async def _ping_interval(self) -> None:
21022173
await self._send_ping()
21032174
except (asyncio.CancelledError, RuntimeError, AttributeError):
21042175
break
2105-
# except asyncio.InvalidStateError:
2106-
# pass
2176+
except asyncio.InvalidStateError:
2177+
# Handle invalid state errors that can occur when connection state changes
2178+
if self.is_connected:
2179+
await self._process_op_err(ErrStaleConnection())
2180+
break
2181+
except Exception as e:
2182+
if self.is_connected:
2183+
await self._error_cb(e)
2184+
await self._process_op_err(ErrStaleConnection())
2185+
break
21072186

21082187
async def _read_loop(self) -> None:
21092188
"""
@@ -2112,6 +2191,8 @@ async def _read_loop(self) -> None:
21122191
In case of error while reading, it will stop running
21132192
and its task has to be rescheduled.
21142193
"""
2194+
read_timeout_count = 0
2195+
21152196
while True:
21162197
try:
21172198
should_bail = self.is_closed or self.is_reconnecting
@@ -2123,21 +2204,47 @@ async def _read_loop(self) -> None:
21232204
await self._process_op_err(err)
21242205
break
21252206

2126-
b = await self._transport.read(DEFAULT_BUFFER_SIZE)
2127-
await self._ps.parse(b)
2207+
# Use a timeout for reading to detect stalled connections
2208+
try:
2209+
read_future = self._transport.read(DEFAULT_BUFFER_SIZE)
2210+
b = await asyncio.wait_for(
2211+
read_future, timeout=self.options["ping_interval"]
2212+
)
2213+
read_timeout_count = 0
2214+
await self._ps.parse(b)
2215+
except asyncio.TimeoutError:
2216+
read_timeout_count += 1
2217+
if read_timeout_count >= self.options["max_read_timeouts"]:
2218+
err = ErrStaleConnection()
2219+
await self._error_cb(err)
2220+
await self._process_op_err(err)
2221+
break
2222+
continue
2223+
21282224
except errors.ProtocolError:
21292225
await self._process_op_err(errors.ProtocolError())
21302226
break
2227+
except ConnectionResetError as e:
2228+
await self._error_cb(e)
2229+
await self._process_op_err(errors.ConnectionClosedError())
2230+
break
21312231
except OSError as e:
2232+
await self._error_cb(e)
21322233
await self._process_op_err(e)
21332234
break
2235+
except asyncio.InvalidStateError:
2236+
if self.is_connected:
2237+
err = ErrStaleConnection()
2238+
await self._error_cb(err)
2239+
await self._process_op_err(err)
2240+
break
21342241
except asyncio.CancelledError:
21352242
break
21362243
except Exception as ex:
2244+
await self._error_cb(ex)
2245+
await self._process_op_err(ex)
21372246
_logger.error("nats: encountered error", exc_info=ex)
21382247
break
2139-
# except asyncio.InvalidStateError:
2140-
# pass
21412248

21422249
async def __aenter__(self) -> "Client":
21432250
"""For when NATS client is used in a context manager"""

0 commit comments

Comments
 (0)