Skip to content

Commit 9f4165a

Browse files
authored
Support fixtures and pytest.mark.parametrize with gen_cluster (#4958)
Support fixtures and `pytest.mark.parametrize` with `gen_cluster` (#4958)
1 parent ac35e0f commit 9f4165a

File tree

2 files changed

+69
-2
lines changed

2 files changed

+69
-2
lines changed

distributed/tests/test_utils_test.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import asyncio
2+
import pathlib
23
import socket
34
import threading
45
from contextlib import contextmanager
@@ -45,6 +46,47 @@ async def test_gen_cluster(c, s, a, b):
4546
assert await c.submit(lambda: 123) == 123
4647

4748

49+
@gen_cluster(client=True)
50+
async def test_gen_cluster_pytest_fixture(c, s, a, b, tmp_path):
51+
assert isinstance(tmp_path, pathlib.Path)
52+
assert isinstance(c, Client)
53+
assert isinstance(s, Scheduler)
54+
for w in [a, b]:
55+
assert isinstance(w, Worker)
56+
57+
58+
@pytest.mark.parametrize("foo", [True])
59+
@gen_cluster(client=True)
60+
async def test_gen_cluster_parametrized(c, s, a, b, foo):
61+
assert foo is True
62+
assert isinstance(c, Client)
63+
assert isinstance(s, Scheduler)
64+
for w in [a, b]:
65+
assert isinstance(w, Worker)
66+
67+
68+
@pytest.mark.parametrize("foo", [True])
69+
@pytest.mark.parametrize("bar", ["a", "b"])
70+
@gen_cluster(client=True)
71+
async def test_gen_cluster_multi_parametrized(c, s, a, b, foo, bar):
72+
assert foo is True
73+
assert bar in ("a", "b")
74+
assert isinstance(c, Client)
75+
assert isinstance(s, Scheduler)
76+
for w in [a, b]:
77+
assert isinstance(w, Worker)
78+
79+
80+
@pytest.mark.parametrize("foo", [True])
81+
@gen_cluster(client=True)
82+
async def test_gen_cluster_parametrized_variadic_workers(c, s, *workers, foo):
83+
assert foo is True
84+
assert isinstance(c, Client)
85+
assert isinstance(s, Scheduler)
86+
for w in workers:
87+
assert isinstance(w, Worker)
88+
89+
4890
@gen_cluster(
4991
client=True,
5092
Worker=Nanny,

distributed/utils_test.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import copy
44
import functools
55
import gc
6+
import inspect
67
import io
78
import itertools
89
import logging
@@ -861,6 +862,15 @@ def gen_cluster(
861862
async def test_foo(scheduler, worker1, worker2):
862863
await ... # use tornado coroutines
863864
865+
@pytest.mark.parametrize("param", [1, 2, 3])
866+
@gen_cluster()
867+
async def test_foo(scheduler, worker1, worker2, param):
868+
await ... # use tornado coroutines
869+
870+
@gen_cluster()
871+
async def test_foo(scheduler, worker1, worker2, pytest_fixture_a, pytest_fixture_b):
872+
await ... # use tornado coroutines
873+
864874
See also:
865875
start
866876
end
@@ -877,7 +887,7 @@ def _(func):
877887
if not iscoroutinefunction(func):
878888
func = gen.coroutine(func)
879889

880-
def test_func():
890+
def test_func(*outer_args, **kwargs):
881891
result = None
882892
workers = []
883893
with clean(timeout=active_rpc_timeout, **clean_kwargs) as loop:
@@ -919,7 +929,7 @@ async def coro():
919929
)
920930
args = [c] + args
921931
try:
922-
future = func(*args)
932+
future = func(*args, *outer_args, **kwargs)
923933
if timeout:
924934
future = asyncio.wait_for(future, timeout)
925935
result = await future
@@ -979,6 +989,21 @@ def get_unclosed():
979989

980990
return result
981991

992+
# Patch the signature so pytest can inject fixtures
993+
orig_sig = inspect.signature(func)
994+
args = [None] * (1 + len(nthreads)) # scheduler, *workers
995+
if client:
996+
args.insert(0, None)
997+
998+
bound = orig_sig.bind_partial(*args)
999+
test_func.__signature__ = orig_sig.replace(
1000+
parameters=[
1001+
p
1002+
for name, p in orig_sig.parameters.items()
1003+
if name not in bound.arguments
1004+
]
1005+
)
1006+
9821007
return test_func
9831008

9841009
return _

0 commit comments

Comments
 (0)