8585DEFAULT_MAX_RECONNECT_ATTEMPTS = 60
8686DEFAULT_PING_INTERVAL = 120 # in seconds
8787DEFAULT_MAX_OUTSTANDING_PINGS = 2
88+ DEFAULT_MAX_READ_TIMEOUTS = 3
8889DEFAULT_MAX_PAYLOAD_SIZE = 1048576
8990DEFAULT_MAX_FLUSHER_QUEUE_SIZE = 1024
9091DEFAULT_FLUSH_TIMEOUT = 10 # in seconds
@@ -301,6 +302,7 @@ async def connect(
301302 max_reconnect_attempts : int = DEFAULT_MAX_RECONNECT_ATTEMPTS ,
302303 ping_interval : int = DEFAULT_PING_INTERVAL ,
303304 max_outstanding_pings : int = DEFAULT_MAX_OUTSTANDING_PINGS ,
305+ max_read_timeouts : int = DEFAULT_MAX_READ_TIMEOUTS ,
304306 dont_randomize : bool = False ,
305307 flusher_queue_size : int = DEFAULT_MAX_FLUSHER_QUEUE_SIZE ,
306308 no_echo : bool = False ,
@@ -331,6 +333,7 @@ async def connect(
331333 :param discovered_server_cb: Callback to report when a new server joins the cluster.
332334 :param pending_size: Max size of the pending buffer for publishing commands.
333335 :param flush_timeout: Max duration to wait for a forced flush to occur.
336+ :param max_read_timeouts: Maximum number of consecutive read timeouts before considering a connection stale.
334337
335338 Connecting setting all callbacks::
336339
@@ -448,6 +451,7 @@ async def subscribe_handler(msg):
448451 self .options ["reconnect_time_wait" ] = reconnect_time_wait
449452 self .options ["max_reconnect_attempts" ] = max_reconnect_attempts
450453 self .options ["ping_interval" ] = ping_interval
454+ self .options ["max_read_timeouts" ] = max_read_timeouts
451455 self .options ["max_outstanding_pings" ] = max_outstanding_pings
452456 self .options ["no_echo" ] = no_echo
453457 self .options ["user" ] = user
@@ -1013,19 +1017,38 @@ async def request(
10131017 the responses.
10141018
10151019 """
1020+ if not self .is_connected :
1021+ await self ._check_connection_health ()
1022+
1023+ if not self .is_connected :
1024+ if self .is_closed :
1025+ raise errors .ConnectionClosedError
1026+ elif self .is_reconnecting :
1027+ raise errors .ConnectionReconnectingError
1028+ else :
1029+ raise errors .ConnectionClosedError
1030+
10161031 if old_style :
10171032 # FIXME: Support headers in old style requests.
1018- return await self ._request_old_style (
1019- subject , payload , timeout = timeout
1020- )
1033+ try :
1034+ return await self ._request_old_style (
1035+ subject , payload , timeout = timeout
1036+ )
1037+ except (errors .TimeoutError , asyncio .TimeoutError ):
1038+ await self ._check_connection_health ()
1039+ raise errors .TimeoutError
10211040 else :
1022- msg = await self ._request_new_style (
1023- subject , payload , timeout = timeout , headers = headers
1024- )
1025- if (msg .headers and msg .headers .get (nats .js .api .Header .STATUS )
1026- == NO_RESPONDERS_STATUS ):
1027- raise errors .NoRespondersError
1028- return msg
1041+ try :
1042+ msg = await self ._request_new_style (
1043+ subject , payload , timeout = timeout , headers = headers
1044+ )
1045+ if msg .headers and msg .headers .get (nats .js .api .Header .STATUS
1046+ ) == NO_RESPONDERS_STATUS :
1047+ raise errors .NoRespondersError
1048+ return msg
1049+ except (errors .TimeoutError , asyncio .TimeoutError ):
1050+ await self ._check_connection_health ()
1051+ raise errors .TimeoutError
10291052
10301053 async def _request_new_style (
10311054 self ,
@@ -1037,6 +1060,9 @@ async def _request_new_style(
10371060 if self .is_draining_pubs :
10381061 raise errors .ConnectionDrainingError
10391062
1063+ if not self .is_connected :
1064+ raise errors .ConnectionClosedError
1065+
10401066 if not self ._resp_sub_prefix :
10411067 await self ._init_request_sub ()
10421068 assert self ._resp_sub_prefix
@@ -1049,21 +1075,37 @@ async def _request_new_style(
10491075
10501076 # Then use the future to get the response.
10511077 future : asyncio .Future = asyncio .Future ()
1052- future .add_done_callback (
1053- lambda f : self ._resp_map .pop (token .decode (), None )
1054- )
1055- self ._resp_map [token .decode ()] = future
1078+ token_str = token .decode ()
10561079
1057- # Publish the request
1058- await self .publish (
1059- subject , payload , reply = inbox .decode (), headers = headers
1060- )
1080+ def cleanup_resp_map (f ):
1081+ self ._resp_map .pop (token_str , None )
1082+
1083+ future .add_done_callback (cleanup_resp_map )
1084+ self ._resp_map [token_str ] = future
10611085
1062- # Wait for the response or give up on timeout.
10631086 try :
1064- return await asyncio .wait_for (future , timeout )
1065- except asyncio .TimeoutError :
1066- raise errors .TimeoutError
1087+ # Publish the request
1088+ await self .publish (
1089+ subject , payload , reply = inbox .decode (), headers = headers
1090+ )
1091+
1092+ if not self .is_connected :
1093+ future .cancel ()
1094+ raise errors .ConnectionClosedError
1095+
1096+ try :
1097+ return await asyncio .wait_for (future , timeout )
1098+ except asyncio .TimeoutError :
1099+ cleanup_resp_map (future )
1100+ raise errors .TimeoutError
1101+ except asyncio .CancelledError :
1102+ cleanup_resp_map (future )
1103+ raise
1104+ except Exception :
1105+ if not future .done ():
1106+ future .cancel ()
1107+ cleanup_resp_map (future )
1108+ raise
10671109
10681110 def new_inbox (self ) -> str :
10691111 """
@@ -1399,6 +1441,35 @@ async def _process_err(self, err_msg: str) -> None:
13991441 # For now we handle similar as other clients and close.
14001442 asyncio .create_task (self ._close (Client .CLOSED , do_cbs ))
14011443
async def _check_connection_health(self) -> bool:
    """
    Report whether the connection is usable, starting a reconnect if it is not.

    When the client is not connected, reconnecting is allowed by
    ``options["allow_reconnect"]``, and the client is neither already
    reconnecting nor closed, this sets the status to
    ``Client.RECONNECTING``, resets the protocol parser and schedules a
    fresh ``_attempt_reconnect`` task (cancelling a previous one that is
    still pending). It then waits for that task to finish, for at most
    ``options["reconnect_time_wait"]`` seconds.

    :return: ``True`` if the client is connected when this call returns,
        ``False`` otherwise (including when no reconnect may be started or
        when starting one raised an error).
    """
    if self.is_connected:
        return True

    may_reconnect = (
        self.options["allow_reconnect"]
        and not self.is_reconnecting
        and not self.is_closed
    )
    if not may_reconnect:
        return False

    self._status = Client.RECONNECTING
    self._ps.reset()

    try:
        # Only a task that is still pending needs cancelling; a finished
        # or already-cancelled one is left alone.
        previous_task = self._reconnection_task
        if previous_task is not None and not previous_task.done():
            previous_task.cancel()

        reconnect_task = asyncio.get_running_loop().create_task(
            self._attempt_reconnect()
        )
        self._reconnection_task = reconnect_task

        # Wait for the attempt to finish instead of always sleeping the
        # whole interval. asyncio.wait() neither cancels the task on
        # timeout nor re-raises its exception, so the attempt keeps
        # running in the background if it needs longer.
        await asyncio.wait(
            {reconnect_task}, timeout=self.options["reconnect_time_wait"]
        )
        return self.is_connected
    except Exception as e:
        # Keep the best-effort contract (never raise from a health check),
        # but do not hide the failure.
        _logger.error("nats: encountered error", exc_info=e)
        return False
1472+
14021473 async def _process_op_err (self , e : Exception ) -> None :
14031474 """
14041475 Process errors which occurred while reading or parsing
@@ -2102,8 +2173,16 @@ async def _ping_interval(self) -> None:
21022173 await self ._send_ping ()
21032174 except (asyncio .CancelledError , RuntimeError , AttributeError ):
21042175 break
2105- # except asyncio.InvalidStateError:
2106- # pass
2176+ except asyncio .InvalidStateError :
2177+ # Handle invalid state errors that can occur when connection state changes
2178+ if self .is_connected :
2179+ await self ._process_op_err (ErrStaleConnection ())
2180+ break
2181+ except Exception as e :
2182+ if self .is_connected :
2183+ await self ._error_cb (e )
2184+ await self ._process_op_err (ErrStaleConnection ())
2185+ break
21072186
21082187 async def _read_loop (self ) -> None :
21092188 """
@@ -2112,6 +2191,8 @@ async def _read_loop(self) -> None:
21122191 In case of error while reading, it will stop running
21132192 and its task has to be rescheduled.
21142193 """
2194+ read_timeout_count = 0
2195+
21152196 while True :
21162197 try :
21172198 should_bail = self .is_closed or self .is_reconnecting
@@ -2123,21 +2204,47 @@ async def _read_loop(self) -> None:
21232204 await self ._process_op_err (err )
21242205 break
21252206
2126- b = await self ._transport .read (DEFAULT_BUFFER_SIZE )
2127- await self ._ps .parse (b )
2207+ # Use a timeout for reading to detect stalled connections
2208+ try :
2209+ read_future = self ._transport .read (DEFAULT_BUFFER_SIZE )
2210+ b = await asyncio .wait_for (
2211+ read_future , timeout = self .options ["ping_interval" ]
2212+ )
2213+ read_timeout_count = 0
2214+ await self ._ps .parse (b )
2215+ except asyncio .TimeoutError :
2216+ read_timeout_count += 1
2217+ if read_timeout_count >= self .options ["max_read_timeouts" ]:
2218+ err = ErrStaleConnection ()
2219+ await self ._error_cb (err )
2220+ await self ._process_op_err (err )
2221+ break
2222+ continue
2223+
21282224 except errors .ProtocolError :
21292225 await self ._process_op_err (errors .ProtocolError ())
21302226 break
2227+ except ConnectionResetError as e :
2228+ await self ._error_cb (e )
2229+ await self ._process_op_err (errors .ConnectionClosedError ())
2230+ break
21312231 except OSError as e :
2232+ await self ._error_cb (e )
21322233 await self ._process_op_err (e )
21332234 break
2235+ except asyncio .InvalidStateError :
2236+ if self .is_connected :
2237+ err = ErrStaleConnection ()
2238+ await self ._error_cb (err )
2239+ await self ._process_op_err (err )
2240+ break
21342241 except asyncio .CancelledError :
21352242 break
21362243 except Exception as ex :
2244+ await self ._error_cb (ex )
2245+ await self ._process_op_err (ex )
21372246 _logger .error ("nats: encountered error" , exc_info = ex )
21382247 break
2139- # except asyncio.InvalidStateError:
2140- # pass
21412248
21422249 async def __aenter__ (self ) -> "Client" :
21432250 """For when NATS client is used in a context manager"""
0 commit comments