Skip to content

Commit 06735d7

Browse files
authored
Fix discovered server callback not being awaited (#660)
1 parent 97fcc7f commit 06735d7

File tree

4 files changed

+9
-17
lines changed

4 files changed

+9
-17
lines changed

nats/aio/client.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1856,7 +1856,7 @@ def _process_disconnect(self) -> None:
18561856
"""
18571857
self._status = Client.DISCONNECTED
18581858

1859-
def _process_info(
1859+
async def _process_info(
18601860
self, info: Dict[str, Any], initial_connection: bool = False
18611861
) -> None:
18621862
"""
@@ -1899,7 +1899,7 @@ def _process_info(
18991899

19001900
if (not initial_connection and connect_urls
19011901
and self._discovered_server_cb):
1902-
self._discovered_server_cb()
1902+
await self._discovered_server_cb()
19031903

19041904
def _host_is_ip(self, connect_url: Optional[str]) -> bool:
19051905
if connect_url is None:
@@ -1960,7 +1960,7 @@ async def _process_connect_init(self) -> None:
19601960
if srv_info.get("auth_required", False):
19611961
self._auth_configured = True
19621962

1963-
self._process_info(srv_info, initial_connection=True)
1963+
await self._process_info(srv_info, initial_connection=True)
19641964

19651965
if "version" in self._server_info:
19661966
self._current_server.server_version = self._server_info["version"]

nats/protocol/parser.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ async def parse(self, data: bytes = b""):
161161
if info:
162162
info_line = info.groups()[0]
163163
srv_info = json.loads(info_line.decode())
164-
self.nc._process_info(srv_info)
164+
await self.nc._process_info(srv_info)
165165
del self.buf[:info.end()]
166166
continue
167167

tests/test_client.py

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1922,13 +1922,8 @@ async def test_discover_servers_on_first_connect(self):
19221922
await asyncio.sleep(1)
19231923

19241924
options = {"servers": ["nats://127.0.0.1:4223", ]}
1925-
1926-
discovered_server_cb = mock.Mock()
1927-
1928-
with mock.patch("asyncio.iscoroutinefunction", return_value=True):
1929-
await nc.connect(
1930-
**options, discovered_server_cb=discovered_server_cb
1931-
)
1925+
discovered_server_cb = mock.AsyncMock()
1926+
await nc.connect(**options, discovered_server_cb=discovered_server_cb)
19321927
self.assertTrue(nc.is_connected)
19331928
await nc.close()
19341929
self.assertTrue(nc.is_closed)
@@ -1941,11 +1936,8 @@ async def test_discover_servers_after_first_connect(self):
19411936
nc = NATS()
19421937

19431938
options = {"servers": ["nats://127.0.0.1:4223", ]}
1944-
discovered_server_cb = mock.Mock()
1945-
with mock.patch("asyncio.iscoroutinefunction", return_value=True):
1946-
await nc.connect(
1947-
**options, discovered_server_cb=discovered_server_cb
1948-
)
1939+
discovered_server_cb = mock.AsyncMock()
1940+
await nc.connect(**options, discovered_server_cb=discovered_server_cb)
19491941

19501942
# Start rest of cluster members so that we receive them
19511943
# connect_urls on the first connect.

tests/test_parser.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ async def _process_msg(self, sid, subject, reply, payload, headers=None):
3333
async def _process_err(self, err=None):
3434
pass
3535

36-
def _process_info(self, info):
36+
async def _process_info(self, info):
3737
self._server_info = info
3838

3939

0 commit comments

Comments
 (0)