Skip to content

Commit 3efe28a

Browse files
gh-128002: add more thread safety tests for asyncio (#128480)
1 parent 75214f8 commit 3efe28a

File tree

1 file changed

+54
-0
lines changed

1 file changed

+54
-0
lines changed

Lib/test/test_asyncio/test_free_threading.py

+54
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,11 @@
77

88
threading_helper.requires_working_threading(module=True)
99

10+
11+
class MyException(Exception):
12+
pass
13+
14+
1015
def tearDownModule():
1116
asyncio._set_event_loop_policy(None)
1217

@@ -53,6 +58,55 @@ def runner():
5358
with threading_helper.start_threads(threads):
5459
pass
5560

61+
def test_run_coroutine_threadsafe(self) -> None:
62+
results = []
63+
64+
def in_thread(loop: asyncio.AbstractEventLoop):
65+
coro = asyncio.sleep(0.1, result=42)
66+
fut = asyncio.run_coroutine_threadsafe(coro, loop)
67+
result = fut.result()
68+
self.assertEqual(result, 42)
69+
results.append(result)
70+
71+
async def main():
72+
loop = asyncio.get_running_loop()
73+
async with asyncio.TaskGroup() as tg:
74+
for _ in range(10):
75+
tg.create_task(asyncio.to_thread(in_thread, loop))
76+
self.assertEqual(results, [42] * 10)
77+
78+
with asyncio.Runner() as r:
79+
loop = r.get_loop()
80+
loop.set_task_factory(self.factory)
81+
r.run(main())
82+
83+
def test_run_coroutine_threadsafe_exception(self) -> None:
84+
async def coro():
85+
await asyncio.sleep(0)
86+
raise MyException("test")
87+
88+
def in_thread(loop: asyncio.AbstractEventLoop):
89+
fut = asyncio.run_coroutine_threadsafe(coro(), loop)
90+
return fut.result()
91+
92+
async def main():
93+
loop = asyncio.get_running_loop()
94+
tasks = []
95+
for _ in range(10):
96+
task = loop.create_task(asyncio.to_thread(in_thread, loop))
97+
tasks.append(task)
98+
results = await asyncio.gather(*tasks, return_exceptions=True)
99+
100+
self.assertEqual(len(results), 10)
101+
for result in results:
102+
self.assertIsInstance(result, MyException)
103+
self.assertEqual(str(result), "test")
104+
105+
with asyncio.Runner() as r:
106+
loop = r.get_loop()
107+
loop.set_task_factory(self.factory)
108+
r.run(main())
109+
56110

57111
class TestPyFreeThreading(TestFreeThreading, TestCase):
58112
all_tasks = staticmethod(asyncio.tasks._py_all_tasks)

0 commit comments

Comments
 (0)