
Commit 96d14b5

Start on synchronous testing
Close queues, hopefully fixed race condition
1 parent 0a93b24 commit 96d14b5

7 files changed, +192 -15 lines

dispatcher/service/pool.py (+14)

@@ -90,6 +90,7 @@ async def stop(self) -> None:
         except asyncio.TimeoutError:
             logger.error(f'Worker {self.worker_id} pid={self.process.pid} failed to send exit message in 3 seconds')
             self.status = 'error'  # can signal for result task to exit, since no longer waiting for it here
+            self.process.message_queue.close()

         await self.join()  # If worker fails to exit, this returns control without raising an exception

@@ -461,6 +462,18 @@ async def shutdown(self) -> None:
         except asyncio.CancelledError:
             logger.info('The finished task was canceled, but we are shutting down so that is alright')

+        if self.management_task:
+            logger.info('Canceling worker management task')
+            self.management_task.cancel()
+            try:
+                await asyncio.wait_for(self.management_task, timeout=self.shutdown_timeout)
+            except asyncio.TimeoutError:
+                logger.error('The scaleup task failed to shut down')
+            except asyncio.CancelledError:
+                pass  # intended
+
+        self.process_manager.shutdown()
+
         logger.info('Pool is shut down')

     def active_task_ct(self) -> int:

@@ -576,6 +589,7 @@ async def read_results_forever(self) -> None:
             async with self.workers.management_lock:
                 worker.status = 'exited'
                 worker.exit_msg_event.set()
+                worker.process.message_queue.close()

             if self.shutting_down:
                 if all(worker.inactive for worker in self.workers):
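
Both changes in this file close a worker's message_queue once nothing will read from or write to it again, which appears to be the race the commit message refers to. As background, a minimal sketch of the close/join pattern on a plain multiprocessing.Queue (the queue and message here are illustrative, not dispatcher APIs):

    import multiprocessing

    q = multiprocessing.Queue()
    q.put('finished')   # producer side
    q.close()           # promise that this process will put nothing further
    q.join_thread()     # block until the feeder thread has flushed to the pipe

Per the multiprocessing docs, close() lets the queue's background feeder thread flush and exit, and join_thread() waits for that, so no feeder thread is left writing to a pipe after the worker has gone away.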

dispatcher/service/process.py (+3)

@@ -99,6 +99,9 @@ async def read_finished(self) -> dict[str, Union[str, int]]:
         message = await self.get_event_loop().run_in_executor(None, self.finished_queue.get)
         return message

+    def shutdown(self):
+        self.finished_queue.close()
+

 class ForkServerManager(ProcessManager):
     mp_context = 'forkserver'

dispatcher/testing/__init__.py (new file, +3)

from .subprocess import adispatcher_service, dispatcher_service

__all__ = ['adispatcher_service', 'dispatcher_service']

dispatcher/testing/subprocess.py (new file, +136)

import asyncio
import contextlib
import logging
import multiprocessing
import sys
from multiprocessing.context import BaseContext
from types import ModuleType
from typing import Any, AsyncGenerator, Union

from ..config import DispatcherSettings
from ..factories import from_settings
from ..service.main import DispatcherMain

logger = logging.getLogger(__name__)


class CommunicationItems:
    """Various things used for communication between the parent process and the subprocess service

    This will be passed in the call to the subprocess.
    """

    def __init__(self, main_events: tuple[str], pool_events: tuple[str], context: Union[BaseContext, ModuleType]) -> None:
        self.q_in: multiprocessing.Queue = context.Queue()
        self.q_out: multiprocessing.Queue = context.Queue()
        self.main_events = main_events
        self.pool_events = pool_events


@contextlib.asynccontextmanager
async def adispatcher_service(config: dict) -> AsyncGenerator[DispatcherMain, Any]:
    dispatcher = None
    try:
        settings = DispatcherSettings(config)
        dispatcher = from_settings(settings=settings)  # type: ignore[arg-type]

        await dispatcher.connect_signals()
        await dispatcher.start_working()
        await dispatcher.wait_for_producers_ready()
        await dispatcher.pool.events.workers_ready.wait()

        assert dispatcher.pool.finished_count == 0  # sanity
        assert dispatcher.control_count == 0

        yield dispatcher
    finally:
        if dispatcher:
            try:
                await dispatcher.shutdown()
                await dispatcher.cancel_tasks()
            except Exception:
                logger.exception('shutdown had error')


async def asyncio_target(config: dict, comms: CommunicationItems) -> None:
    loop = asyncio.get_event_loop()
    async with adispatcher_service(config) as dispatcher:
        comms.q_out.put('ready')

        events: dict[str, asyncio.Event] = {}
        for event_name in comms.main_events:
            events[event_name] = getattr(dispatcher.events, event_name)
        for event_name in comms.pool_events:
            events[event_name] = getattr(dispatcher.pool.events, event_name)

        event_tasks: dict[str, asyncio.Task] = {}
        for event_name, event in events.items():
            event_tasks[event_name] = asyncio.create_task(event.wait(), name=f'waiting_for_{event_name}')

        new_message_task = None

        while True:
            if new_message_task is None:
                new_message_task = loop.run_in_executor(None, comms.q_in.get)

            all_tasks = list(event_tasks.values()) + [new_message_task]
            await asyncio.wait(all_tasks, return_when=asyncio.FIRST_COMPLETED)

            # Update our parent process with any events they requested from us
            for event_name, event in events.items():
                if event.is_set():
                    comms.q_out.put(event_name)
                    # await loop.run_in_executor(None, comms.q_out.put, event_name)
                    event.clear()
                    event_tasks[event_name] = asyncio.create_task(event.wait())

            # If no instructions came from the parent then work is done, continue loop
            if not new_message_task.done():
                continue

            message = new_message_task.result()
            new_message_task = None

            if message == 'stop':
                print('shutting down pool server')
                for event in events.values():
                    event.set()  # close out other tasks
                await dispatcher.shutdown()
                break
            else:
                eval(message)


def subprocess_main(config, comms):
    loop = asyncio.new_event_loop()
    try:
        loop.run_until_complete(asyncio_target(config, comms))
    except Exception:
        # The main process is very likely waiting for message of an event
        # and exceptions may not automatically halt the test, so give a value
        comms.q_out.put('error')
        raise
    finally:
        loop.close()


@contextlib.contextmanager
def dispatcher_service(config, main_events=(), pool_events=()):
    ctx = multiprocessing.get_context('spawn')
    comms = CommunicationItems(main_events=main_events, pool_events=pool_events, context=ctx)
    process = multiprocessing.Process(target=subprocess_main, args=(config, comms))
    try:
        process.start()
        ready_msg = comms.q_out.get()
        if ready_msg != 'ready':
            raise RuntimeError(f'Never got "ready" message from server, got {ready_msg}')
        yield comms
    finally:
        comms.q_in.put('stop')
        process.join(timeout=1)
        if process.is_alive():
            process.terminate()  # SIGTERM
        comms.q_in.close()
        comms.q_out.close()
        sys.stdout.flush()
        sys.stderr.flush()
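
Besides the subprocess-backed dispatcher_service, this module exposes adispatcher_service for tests that can stay inside an event loop. A minimal sketch of using it directly, assuming pytest-asyncio as in the existing integration tests and the BASIC_CONFIG dict added in tests/integration/conftest.py below; the test name is hypothetical:

    import pytest

    from dispatcher.testing import adispatcher_service
    from tests.integration.conftest import BASIC_CONFIG


    @pytest.mark.asyncio
    async def test_dispatcher_comes_up_clean():
        async with adispatcher_service(BASIC_CONFIG) as dispatcher:
            # the context manager has already waited for producers and workers_ready
            assert dispatcher.pool.finished_count == 0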

tests/conftest.py (-8)

@@ -83,14 +83,6 @@ def conn_config():
     return {'conninfo': CONNECTION_STRING}


-@pytest.fixture
-def pg_dispatcher() -> DispatcherMain:
-    # We can not reuse the connection between tests
-    config = BASIC_CONFIG.copy()
-    config['brokers']['pg_notify'].pop('async_connection_factory')
-    return DispatcherMain(config)
-
-
 @pytest.fixture
 def test_settings():
     return DispatcherSettings(BASIC_CONFIG)

tests/integration/conftest.py (new file, +32)

import pytest

from dispatcher.testing import dispatcher_service
from dispatcher.factories import get_publisher_from_settings
from dispatcher.config import DispatcherSettings

from tests.conftest import CONNECTION_STRING


BASIC_CONFIG = {
    "version": 2,
    "brokers": {
        "pg_notify": {
            "channels": ['test_channel', 'test_channel2', 'test_channel3'],
            "config": {'conninfo': CONNECTION_STRING},
            "sync_connection_factory": "dispatcher.brokers.pg_notify.connection_saver",
            "default_publish_channel": "test_channel"
        }
    }
}


@pytest.fixture
def pg_dispatcher():
    with dispatcher_service(BASIC_CONFIG, pool_events=('work_cleared',)) as comms:
        yield comms


@pytest.fixture()
def pg_broker():
    settings = DispatcherSettings(BASIC_CONFIG)
    return get_publisher_from_settings(settings=settings)

tests/integration/test_main.py (+4, -7)

@@ -23,13 +23,10 @@ async def wait_to_receive(dispatcher, ct, timeout=5.0, interval=0.05):
         raise RuntimeError(f'Failed to receive expected {ct} messages {dispatcher.pool.received_count}')


-@pytest.mark.asyncio
-async def test_run_lambda_function(apg_dispatcher, pg_message):
-    clearing_task = asyncio.create_task(apg_dispatcher.pool.events.work_cleared.wait(), name='test_lambda_clear_wait')
-    await pg_message('lambda: "This worked!"')
-    await asyncio.wait_for(clearing_task, timeout=3)
-
-    assert apg_dispatcher.pool.finished_count == 1
+def test_run_lambda_function(pg_dispatcher, pg_broker):
+    pg_broker.publish_message(message='lambda: "This worked!"')
+    message = pg_dispatcher.q_out.get()
+    assert message == 'work_cleared'


 @pytest.mark.asyncio
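
The rewritten test blocks on q_out.get() with no timeout, so a missed event would hang the test run rather than fail it. A possible variant, not part of this commit: multiprocessing.Queue.get accepts a timeout and raises queue.Empty when it expires, so the test can fail fast (the 5-second value is an arbitrary choice):

    def test_run_lambda_function(pg_dispatcher, pg_broker):
        pg_broker.publish_message(message='lambda: "This worked!"')
        assert pg_dispatcher.q_out.get(timeout=5) == 'work_cleared'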
