Skip to content

Commit 4fd1658

Browse files
committed
Allow using forkserver
Run main tests with forkserver. Add tests about pid accuracy. Get tests mostly working. Run linters and close connections. Wait for workers to be ready before starting test. Wrap up linters. Update schema. Run linters. Python 3.10 compat.
1 parent 29c7ca1 commit 4fd1658

File tree

7 files changed

+122
-38
lines changed

7 files changed

+122
-38
lines changed

dispatcher/factories.py

+14-3
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import inspect
22
from copy import deepcopy
3-
from typing import Iterable, Optional, Type, get_args, get_origin
3+
from typing import Iterable, Literal, Optional, Type, get_args, get_origin
44

55
from . import producers
66
from .brokers import get_broker
@@ -10,7 +10,7 @@
1010
from .control import Control
1111
from .service.main import DispatcherMain
1212
from .service.pool import WorkerPool
13-
from .service.process import ProcessManager
13+
from .service import process
1414

1515
"""
1616
Creates objects from settings,
@@ -21,10 +21,16 @@
2121
# ---- Service objects ----
2222

2323

24+
def process_manager_from_settings(settings: LazySettings = global_settings):
25+
cls_name = settings.service.get('process_manager_cls', 'ForkServer')
26+
process_manager_cls = getattr(process, cls_name)
27+
return process_manager_cls()
28+
29+
2430
def pool_from_settings(settings: LazySettings = global_settings):
2531
kwargs = settings.service.get('pool_kwargs', {}).copy()
2632
kwargs['settings'] = settings
27-
kwargs['process_manager'] = ProcessManager() # TODO: use process_manager_cls from settings
33+
kwargs['process_manager'] = process_manager_from_settings(settings=settings)
2834
return WorkerPool(**kwargs)
2935

3036

@@ -119,6 +125,11 @@ def generate_settings_schema(settings: LazySettings = global_settings) -> dict:
119125
ret = deepcopy(settings.serialize())
120126

121127
ret['service']['pool_kwargs'] = schema_for_cls(WorkerPool)
128+
ret['service']['process_manager_kwargs'] = {}
129+
pm_classes = (process.ProcessManager, process.ForkServerManager)
130+
for pm_cls in pm_classes:
131+
ret['service']['process_manager_kwargs'].update(schema_for_cls(pm_cls))
132+
ret['service']['process_manager_cls'] = str(Literal[tuple(pm_cls.__name__ for pm_cls in pm_classes)])
122133

123134
for broker_name, broker_kwargs in settings.brokers.items():
124135
broker = get_broker(broker_name, broker_kwargs)

dispatcher/service/pool.py

+3
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ def __init__(self) -> None:
9393
self.work_cleared: asyncio.Event = asyncio.Event() # Totally quiet, no blocked or queued messages, no busy workers
9494
self.management_event: asyncio.Event = asyncio.Event() # Process spawning is backgrounded, so this is the kicker
9595
self.timeout_event: asyncio.Event = asyncio.Event() # Anything that might affect the timeout watcher task
96+
self.workers_ready: asyncio.Event = asyncio.Event() # min workers have started and sent ready message
9697

9798

9899
class WorkerPool:
@@ -402,6 +403,8 @@ async def read_results_forever(self) -> None:
402403

403404
if event == 'ready':
404405
worker.status = 'ready'
406+
if all(worker.status == 'ready' for worker in self.workers.values()):
407+
self.events.workers_ready.set()
405408
await self.drain_queue()
406409

407410
elif event == 'shutdown':

dispatcher/service/process.py

+21-5
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,19 @@
11
import asyncio
22
import multiprocessing
3+
from multiprocessing.context import BaseContext
4+
from types import ModuleType
35
from typing import Callable, Iterable, Optional, Union
46

57
from ..worker.task import work_loop
68

79

810
class ProcessProxy:
9-
def __init__(self, args: Iterable, finished_queue: multiprocessing.Queue, target: Callable = work_loop) -> None:
10-
self.message_queue: multiprocessing.Queue = multiprocessing.Queue()
11-
self._process = multiprocessing.Process(target=target, args=tuple(args) + (self.message_queue, finished_queue))
11+
def __init__(
12+
self, args: Iterable, finished_queue: multiprocessing.Queue, target: Callable = work_loop, ctx: Union[BaseContext, ModuleType] = multiprocessing
13+
) -> None:
14+
self.message_queue: multiprocessing.Queue = ctx.Queue()
15+
# This is intended use of multiprocessing context, but not available on BaseContext
16+
self._process = ctx.Process(target=target, args=tuple(args) + (self.message_queue, finished_queue)) # type: ignore
1217

1318
def start(self) -> None:
1419
self._process.start()
@@ -37,8 +42,11 @@ def terminate(self) -> None:
3742

3843

3944
class ProcessManager:
45+
mp_context = 'fork'
46+
4047
def __init__(self) -> None:
41-
self.finished_queue: multiprocessing.Queue = multiprocessing.Queue()
48+
self.ctx = multiprocessing.get_context(self.mp_context)
49+
self.finished_queue: multiprocessing.Queue = self.ctx.Queue()
4250
self._loop = None
4351

4452
def get_event_loop(self):
@@ -47,8 +55,16 @@ def get_event_loop(self):
4755
return self._loop
4856

4957
def create_process(self, args: Iterable[int | str | dict], **kwargs) -> ProcessProxy:
50-
return ProcessProxy(args, self.finished_queue, **kwargs)
58+
return ProcessProxy(args, self.finished_queue, ctx=self.ctx, **kwargs)
5159

5260
async def read_finished(self) -> dict[str, Union[str, int]]:
5361
message = await self.get_event_loop().run_in_executor(None, self.finished_queue.get)
5462
return message
63+
64+
65+
class ForkServerManager(ProcessManager):
66+
mp_context = 'forkserver'
67+
68+
def __init__(self, preload_modules: Optional[list[str]] = None):
69+
super().__init__()
70+
self.ctx.set_forkserver_preload(preload_modules if preload_modules else [])

schema.json

+5-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,11 @@
1717
"service": {
1818
"pool_kwargs": {
1919
"max_workers": "<class 'int'>"
20-
}
20+
},
21+
"process_manager_kwargs": {
22+
"preload_modules": "typing.Optional[list[str]]"
23+
},
24+
"process_manager_cls": "typing.Literal['ProcessManager', 'ForkServerManager']"
2125
},
2226
"publish": {
2327
"default_broker": "str"

tests/conftest.py

+30-5
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from dispatcher.service.main import DispatcherMain
1010
from dispatcher.control import Control
1111

12-
from dispatcher.brokers.pg_notify import Broker, create_connection, acreate_connection
12+
from dispatcher.brokers.pg_notify import Broker, acreate_connection, connection_save
1313
from dispatcher.registry import DispatcherMethodRegistry
1414
from dispatcher.config import DispatcherSettings
1515
from dispatcher.factories import from_settings, get_control_from_settings
@@ -56,6 +56,21 @@ async def aconnection_for_test():
5656
await conn.close()
5757

5858

59+
@pytest.fixture(autouse=True)
60+
def clear_connection():
61+
"""Always close connections between tests
62+
63+
Tests will do a lot of unthoughtful forking, and connections cannot
64+
be shared across processes.
65+
"""
66+
if connection_save._connection:
67+
connection_save._connection.close()
68+
connection_save._connection = None
69+
if connection_save._async_connection:
70+
connection_save._async_connection.close()
71+
connection_save._async_connection = None
72+
73+
5974
@pytest.fixture
6075
def conn_config():
6176
return {'conninfo': CONNECTION_STRING}
@@ -73,18 +88,28 @@ def pg_dispatcher() -> DispatcherMain:
7388
def test_settings():
7489
return DispatcherSettings(BASIC_CONFIG)
7590

76-
77-
@pytest_asyncio.fixture(loop_scope="function", scope="function")
78-
async def apg_dispatcher(test_settings) -> AsyncIterator[DispatcherMain]:
91+
@pytest_asyncio.fixture(
92+
loop_scope="function",
93+
scope="function",
94+
params=['ProcessManager', 'ForkServerManager'],
95+
ids=["fork", "forkserver"],
96+
)
97+
async def apg_dispatcher(request) -> AsyncIterator[DispatcherMain]:
7998
dispatcher = None
8099
try:
81-
dispatcher = from_settings(settings=test_settings)
100+
this_test_config = BASIC_CONFIG.copy()
101+
this_test_config.setdefault('service', {})
102+
this_test_config['service']['process_manager_cls'] = request.param
103+
this_settings = DispatcherSettings(this_test_config)
104+
dispatcher = from_settings(settings=this_settings)
82105

83106
await dispatcher.connect_signals()
84107
await dispatcher.start_working()
85108
await dispatcher.wait_for_producers_ready()
109+
await dispatcher.pool.events.workers_ready.wait()
86110

87111
assert dispatcher.pool.finished_count == 0 # sanity
112+
assert dispatcher.control_count == 0
88113

89114
yield dispatcher
90115
finally:

tests/integration/test_main.py

+1-17
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,6 @@ async def wait_to_receive(dispatcher, ct, timeout=5.0, interval=0.05):
2424

2525
@pytest.mark.asyncio
2626
async def test_run_lambda_function(apg_dispatcher, pg_message):
27-
assert apg_dispatcher.pool.finished_count == 0
28-
2927
clearing_task = asyncio.create_task(apg_dispatcher.pool.events.work_cleared.wait(), name='test_lambda_clear_wait')
3028
await pg_message('lambda: "This worked!"')
3129
await asyncio.wait_for(clearing_task, timeout=3)
@@ -93,7 +91,7 @@ async def test_cancel_task(apg_dispatcher, pg_message, pg_control):
9391
await pg_message(msg)
9492

9593
clearing_task = asyncio.create_task(apg_dispatcher.pool.events.work_cleared.wait())
96-
await asyncio.sleep(0.04)
94+
await asyncio.sleep(0.2)
9795
canceled_jobs = await asyncio.wait_for(pg_control.acontrol_with_reply('cancel', data={'uuid': 'foobar'}, timeout=1), timeout=5)
9896
worker_id, canceled_message = canceled_jobs[0][0]
9997
assert canceled_message['uuid'] == 'foobar'
@@ -128,8 +126,6 @@ async def test_message_with_delay(apg_dispatcher, pg_message, pg_control):
128126

129127
@pytest.mark.asyncio
130128
async def test_cancel_delayed_task(apg_dispatcher, pg_message, pg_control):
131-
assert apg_dispatcher.pool.finished_count == 0
132-
133129
# Send message to run task with a delay
134130
msg = json.dumps({'task': 'lambda: print("This task should be canceled before start")', 'uuid': 'delay_task_will_cancel', 'delay': 0.8})
135131
await pg_message(msg)
@@ -149,8 +145,6 @@ async def test_cancel_delayed_task(apg_dispatcher, pg_message, pg_control):
149145

150146
@pytest.mark.asyncio
151147
async def test_cancel_with_no_reply(apg_dispatcher, pg_message, pg_control):
152-
assert apg_dispatcher.pool.finished_count == 0
153-
154148
# Send message to run task with a delay
155149
msg = json.dumps({'task': 'lambda: print("This task should be canceled before start")', 'uuid': 'delay_task_will_cancel', 'delay': 2.0})
156150
await pg_message(msg)
@@ -167,8 +161,6 @@ async def test_cancel_with_no_reply(apg_dispatcher, pg_message, pg_control):
167161

168162
@pytest.mark.asyncio
169163
async def test_alive_check(apg_dispatcher, pg_control):
170-
assert apg_dispatcher.control_count == 0
171-
172164
alive = await asyncio.wait_for(pg_control.acontrol_with_reply('alive', timeout=1), timeout=5)
173165
assert alive == [None]
174166

@@ -177,8 +169,6 @@ async def test_alive_check(apg_dispatcher, pg_control):
177169

178170
@pytest.mark.asyncio
179171
async def test_task_discard(apg_dispatcher, pg_message):
180-
assert apg_dispatcher.pool.finished_count == 0
181-
182172
messages = [
183173
json.dumps(
184174
{'task': 'lambda: __import__("time").sleep(9)', 'on_duplicate': 'discard', 'uuid': f'dscd-{i}'}
@@ -195,8 +185,6 @@ async def test_task_discard(apg_dispatcher, pg_message):
195185

196186
@pytest.mark.asyncio
197187
async def test_task_discard_in_task_definition(apg_dispatcher, test_settings):
198-
assert apg_dispatcher.pool.finished_count == 0
199-
200188
for i in range(10):
201189
test_methods.sleep_discard.apply_async(args=[2], settings=test_settings)
202190

@@ -208,8 +196,6 @@ async def test_task_discard_in_task_definition(apg_dispatcher, test_settings):
208196

209197
@pytest.mark.asyncio
210198
async def test_tasks_in_serial(apg_dispatcher, test_settings):
211-
assert apg_dispatcher.pool.finished_count == 0
212-
213199
for i in range(10):
214200
test_methods.sleep_serial.apply_async(args=[2], settings=test_settings)
215201

@@ -221,8 +207,6 @@ async def test_tasks_in_serial(apg_dispatcher, test_settings):
221207

222208
@pytest.mark.asyncio
223209
async def test_tasks_queue_one(apg_dispatcher, test_settings):
224-
assert apg_dispatcher.pool.finished_count == 0
225-
226210
for i in range(10):
227211
test_methods.sleep_queue_one.apply_async(args=[2], settings=test_settings)
228212

tests/unit/service/test_process.py

+48-7
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
from multiprocessing import Queue
2+
import os
23

3-
from dispatcher.service.process import ProcessManager, ProcessProxy
4+
import pytest
5+
6+
from dispatcher.service.process import ProcessManager, ForkServerManager, ProcessProxy
47

58

69
def test_pass_messages_to_worker():
@@ -17,15 +20,53 @@ def work_loop(a, b, c, in_q, out_q):
1720
assert msg == 'done 1 2 3 start'
1821

1922

20-
def test_pass_messages_via_process_manager():
21-
def work_loop(var, in_q, out_q):
22-
has_read = in_q.get()
23-
out_q.put(f'done {var} {has_read}')
23+
def work_loop2(var, in_q, out_q):
24+
"""
25+
Due to the mechanics of forkserver, this cannot be defined in local variables;
26+
it has to be importable, but this _is_ importable from the test module.
27+
"""
28+
has_read = in_q.get()
29+
out_q.put(f'done {var} {has_read}')
2430

25-
process_manager = ProcessManager()
26-
process = process_manager.create_process(('value',), target=work_loop)
31+
32+
@pytest.mark.parametrize('manager_cls', [ProcessManager, ForkServerManager])
33+
def test_pass_messages_via_process_manager(manager_cls):
34+
process_manager = manager_cls()
35+
process = process_manager.create_process(('value',), target=work_loop2)
2736
process.start()
2837

2938
process.message_queue.put('msg1')
3039
msg = process_manager.finished_queue.get()
3140
assert msg == 'done value msg1'
41+
42+
43+
@pytest.mark.parametrize('manager_cls', [ProcessManager, ForkServerManager])
44+
def test_workers_have_different_pid(manager_cls):
45+
process_manager = manager_cls()
46+
processes = [process_manager.create_process((f'value{i}',), target=work_loop2) for i in range(2)]
47+
48+
for i in range(2):
49+
process = processes[i]
50+
process.start()
51+
process.message_queue.put(f'msg{i}')
52+
53+
assert processes[0].pid != processes[1].pid # title of test
54+
55+
msg1 = process_manager.finished_queue.get()
56+
msg2 = process_manager.finished_queue.get()
57+
assert set([msg1, msg2]) == set(['done value1 msg1', 'done value0 msg0'])
58+
59+
60+
61+
def return_pid(in_q, out_q):
62+
out_q.put(f'{os.getpid()}')
63+
64+
65+
@pytest.mark.parametrize('manager_cls', [ProcessManager, ForkServerManager])
66+
def test_pid_is_correct(manager_cls):
67+
process_manager = manager_cls()
68+
process = process_manager.create_process((), target=return_pid)
69+
process.start()
70+
71+
msg = process_manager.finished_queue.get()
72+
assert int(msg) == process.pid

0 commit comments

Comments
 (0)