diff --git a/distributed/tests/test_failed_workers.py b/distributed/tests/test_failed_workers.py index fc7b202a961..a928edbd0fa 100644 --- a/distributed/tests/test_failed_workers.py +++ b/distributed/tests/test_failed_workers.py @@ -167,9 +167,9 @@ def test_worker_doesnt_await_task_completion(loop): future = c.submit(sleep, 100) sleep(0.1) start = time() - c.restart() + c.restart(timeout="5s", wait_for_workers=False) stop = time() - assert stop - start < 20 + assert stop - start < 10 @gen_cluster(Worker=Nanny, timeout=60) diff --git a/distributed/utils_test.py b/distributed/utils_test.py index 23f8817ab10..e519ef18f78 100644 --- a/distributed/utils_test.py +++ b/distributed/utils_test.py @@ -152,7 +152,7 @@ def loop_in_thread(cleanup): loop_started = concurrent.futures.Future() with concurrent.futures.ThreadPoolExecutor( 1, thread_name_prefix="test IOLoop" - ) as tpe: + ) as tpe, config_for_cluster_tests(): async def run(): io_loop = IOLoop.current() @@ -672,7 +672,7 @@ def cluster( ws = weakref.WeakSet() enable_proctitle_on_children() - with check_process_leak(check=True), check_instances(), _reconfigure(): + with check_process_leak(check=True), check_instances(), config_for_cluster_tests(): if nanny: _run_worker = run_nanny else: @@ -835,6 +835,7 @@ async def async_fn_outer(async_fn, /, *args, **kwargs): def _(func): @functools.wraps(func) + @config_for_cluster_tests() @clean(**clean_kwargs) def test_func(*args, **kwargs): if not iscoroutinefunction(func): @@ -1038,6 +1039,7 @@ def _(func): raise RuntimeError("gen_cluster only works for coroutine functions.") @functools.wraps(func) + @config_for_cluster_tests(**{"distributed.comm.timeouts.connect": "5s"}) @clean(**clean_kwargs) def test_func(*outer_args, **kwargs): async def async_fn(): @@ -1880,16 +1882,17 @@ def check_instances(): @contextmanager -def _reconfigure(): +def config_for_cluster_tests(**extra_config): + "Set recommended config values for tests that create or interact with clusters." reset_config() with dask.config.set( { "local_directory": tempfile.gettempdir(), - "distributed.comm.timeouts.connect": "5s", "distributed.admin.tick.interval": "500 ms", "distributed.worker.profile.enabled": False, - } + }, + **extra_config, ): # Restore default logging levels # XXX use pytest hooks/fixtures instead? @@ -1905,8 +1908,7 @@ def clean(threads=True, instances=True, processes=True): with check_thread_leak() if threads else nullcontext(): with check_process_leak(check=processes): with check_instances() if instances else nullcontext(): - with _reconfigure(): - yield + yield @pytest.fixture