Skip to content

Commit f5166be

Browse files
committed
Add test
1 parent b7232db commit f5166be

File tree

5 files changed

+162
-1
lines changed

5 files changed

+162
-1
lines changed

test/__init__.py

+8
Original file line numberDiff line numberDiff line change
@@ -826,6 +826,14 @@ def require_sync(self, func):
826826
lambda: _IS_SYNC, "This test only works with the synchronous API", func=func
827827
)
828828

829+
def require_async(self, func):
830+
"""Run a test only if using the asynchronous API.""" # unasync: off
831+
return self._require(
832+
lambda: not _IS_SYNC,
833+
"This test only works with the asynchronous API", # unasync: off
834+
func=func,
835+
)
836+
829837
def mongos_seeds(self):
830838
return ",".join("{}:{}".format(*address) for address in self.mongoses)
831839

test/asynchronous/__init__.py

+8
Original file line numberDiff line numberDiff line change
@@ -828,6 +828,14 @@ def require_sync(self, func):
828828
lambda: _IS_SYNC, "This test only works with the synchronous API", func=func
829829
)
830830

831+
def require_async(self, func):
832+
"""Run a test only if using the asynchronous API.""" # unasync: off
833+
return self._require(
834+
lambda: not _IS_SYNC,
835+
"This test only works with the asynchronous API", # unasync: off
836+
func=func,
837+
)
838+
831839
def mongos_seeds(self):
832840
return ",".join("{}:{}".format(*address) for address in self.mongoses)
833841

test/asynchronous/test_discovery_and_monitoring.py

+69
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,15 @@
2020
import socketserver
2121
import sys
2222
import threading
23+
import time
2324
from asyncio import StreamReader, StreamWriter
2425
from pathlib import Path
2526
from test.asynchronous.helpers import ConcurrentRunner
2627

28+
from pymongo.asynchronous.pool import AsyncConnection
29+
from pymongo.operations import _Op
30+
from pymongo.server_selectors import readable_server_selector
31+
2732
sys.path[0:0] = [""]
2833

2934
from test.asynchronous import (
@@ -46,6 +51,7 @@
4651
async_barrier_wait,
4752
async_create_barrier,
4853
async_wait_until,
54+
delay,
4955
server_name_to_type,
5056
)
5157
from unittest.mock import patch
@@ -370,6 +376,69 @@ async def test_pool_unpause(self):
370376
await listener.async_wait_for_event(monitoring.ServerHeartbeatSucceededEvent, 1)
371377
await listener.async_wait_for_event(monitoring.PoolReadyEvent, 1)
372378

379+
@async_client_context.require_failCommand_appName
380+
@async_client_context.require_test_commands
381+
@async_client_context.require_async
382+
async def test_connection_close_does_not_block_other_operations(self):
383+
listener = CMAPHeartbeatListener()
384+
client = await self.async_single_client(
385+
appName="SDAMConnectionCloseTest",
386+
event_listeners=[listener],
387+
heartbeatFrequencyMS=500,
388+
minPoolSize=10,
389+
)
390+
server = await (await client._get_topology()).select_server(
391+
readable_server_selector, _Op.TEST
392+
)
393+
await async_wait_until(
394+
lambda: len(server._pool.conns) == 10,
395+
"pool initialized with 10 connections",
396+
)
397+
398+
await client.db.test.insert_one({"x": 1})
399+
close_delay = 0.05
400+
latencies = []
401+
402+
async def run_task():
403+
while True:
404+
start_time = time.monotonic()
405+
await client.db.test.find_one({})
406+
elapsed = time.monotonic() - start_time
407+
latencies.append(elapsed)
408+
if elapsed >= close_delay:
409+
break
410+
await asyncio.sleep(0.001)
411+
412+
task = ConcurrentRunner(target=run_task)
413+
await task.start()
414+
original_close = AsyncConnection.close_conn
415+
try:
416+
# Artificially delay the close operation to simulate a slow close
417+
async def mock_close(self, reason):
418+
await asyncio.sleep(close_delay)
419+
await original_close(self, reason)
420+
421+
AsyncConnection.close_conn = mock_close
422+
423+
fail_hello = {
424+
"mode": {"times": 4},
425+
"data": {
426+
"failCommands": [HelloCompat.LEGACY_CMD, "hello"],
427+
"errorCode": 91,
428+
"appName": "SDAMConnectionCloseTest",
429+
},
430+
}
431+
async with self.fail_point(fail_hello):
432+
# Wait for server heartbeat to fail
433+
await listener.async_wait_for_event(monitoring.ServerHeartbeatFailedEvent, 1)
434+
# Wait until all idle connections are closed to simulate real-world conditions
435+
await listener.async_wait_for_event(monitoring.ConnectionClosedEvent, 10)
436+
# No operation latency should not significantly exceed close_delay
437+
self.assertLessEqual(max(latencies), close_delay * 1.5)
438+
finally:
439+
AsyncConnection.close_conn = original_close
440+
await task.join()
441+
373442

374443
class TestServerMonitoringMode(AsyncIntegrationTest):
375444
@async_client_context.require_no_serverless

test/test_discovery_and_monitoring.py

+67
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,15 @@
2020
import socketserver
2121
import sys
2222
import threading
23+
import time
2324
from asyncio import StreamReader, StreamWriter
2425
from pathlib import Path
2526
from test.helpers import ConcurrentRunner
2627

28+
from pymongo.operations import _Op
29+
from pymongo.server_selectors import readable_server_selector
30+
from pymongo.synchronous.pool import Connection
31+
2732
sys.path[0:0] = [""]
2833

2934
from test import (
@@ -45,6 +50,7 @@
4550
assertion_context,
4651
barrier_wait,
4752
create_barrier,
53+
delay,
4854
server_name_to_type,
4955
wait_until,
5056
)
@@ -370,6 +376,67 @@ def test_pool_unpause(self):
370376
listener.wait_for_event(monitoring.ServerHeartbeatSucceededEvent, 1)
371377
listener.wait_for_event(monitoring.PoolReadyEvent, 1)
372378

379+
@client_context.require_failCommand_appName
380+
@client_context.require_test_commands
381+
@client_context.require_async
382+
def test_connection_close_does_not_block_other_operations(self):
383+
listener = CMAPHeartbeatListener()
384+
client = self.single_client(
385+
appName="SDAMConnectionCloseTest",
386+
event_listeners=[listener],
387+
heartbeatFrequencyMS=500,
388+
minPoolSize=10,
389+
)
390+
server = (client._get_topology()).select_server(readable_server_selector, _Op.TEST)
391+
wait_until(
392+
lambda: len(server._pool.conns) == 10,
393+
"pool initialized with 10 connections",
394+
)
395+
396+
client.db.test.insert_one({"x": 1})
397+
close_delay = 0.05
398+
latencies = []
399+
400+
def run_task():
401+
while True:
402+
start_time = time.monotonic()
403+
client.db.test.find_one({})
404+
elapsed = time.monotonic() - start_time
405+
latencies.append(elapsed)
406+
if elapsed >= close_delay:
407+
break
408+
time.sleep(0.001)
409+
410+
task = ConcurrentRunner(target=run_task)
411+
task.start()
412+
original_close = Connection.close_conn
413+
try:
414+
# Artificially delay the close operation to simulate a slow close
415+
def mock_close(self, reason):
416+
time.sleep(close_delay)
417+
original_close(self, reason)
418+
419+
Connection.close_conn = mock_close
420+
421+
fail_hello = {
422+
"mode": {"times": 4},
423+
"data": {
424+
"failCommands": [HelloCompat.LEGACY_CMD, "hello"],
425+
"errorCode": 91,
426+
"appName": "SDAMConnectionCloseTest",
427+
},
428+
}
429+
with self.fail_point(fail_hello):
430+
# Wait for server heartbeat to fail
431+
listener.wait_for_event(monitoring.ServerHeartbeatFailedEvent, 1)
432+
# Wait until all idle connections are closed to simulate real-world conditions
433+
listener.wait_for_event(monitoring.ConnectionClosedEvent, 10)
434+
# No operation latency should not significantly exceed close_delay
435+
self.assertLessEqual(max(latencies), close_delay * 1.5)
436+
finally:
437+
Connection.close_conn = original_close
438+
task.join()
439+
373440

374441
class TestServerMonitoringMode(IntegrationTest):
375442
@client_context.require_no_serverless

tools/synchro.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -288,7 +288,8 @@ def process_files(
288288
if file in docstring_translate_files:
289289
lines = translate_docstrings(lines)
290290
if file in sync_test_files:
291-
translate_imports(lines)
291+
lines = translate_imports(lines)
292+
lines = process_ignores(lines)
292293
f.seek(0)
293294
f.writelines(lines)
294295
f.truncate()
@@ -390,6 +391,14 @@ def translate_docstrings(lines: list[str]) -> list[str]:
390391
return [line for line in lines if line != "DOCSTRING_REMOVED"]
391392

392393

394+
def process_ignores(lines: list[str]) -> list[str]:
395+
for i in range(len(lines)):
396+
for k, v in replacements.items():
397+
if "unasync: off" in lines[i] and v in lines[i]:
398+
lines[i] = lines[i].replace(v, k)
399+
return lines
400+
401+
393402
def unasync_directory(files: list[str], src: str, dest: str, replacements: dict[str, str]) -> None:
394403
unasync_files(
395404
files,

0 commit comments

Comments
 (0)