diff --git a/distributed/utils_test.py b/distributed/utils_test.py
index c7af72d6f3b..fc0671e074c 100644
--- a/distributed/utils_test.py
+++ b/distributed/utils_test.py
@@ -568,9 +568,13 @@ def security():
     return tls_only_security()
 
 
-def _terminate_join(proc):
-    proc.terminate()
-    proc.join()
+def _kill_join(proc, timeout):
+    proc.kill()
+    proc.join(timeout)
+    if proc.is_alive():
+        raise multiprocessing.TimeoutError(
+            f"Process {proc} did not shut down within {timeout}s"
+        )
     proc.close()
 
 
@@ -586,7 +590,7 @@ def cluster(
     nanny=False,
     worker_kwargs=None,
     active_rpc_timeout=10,
-    disconnect_timeout=20,
+    shutdown_timeout=20,
     scheduler_kwargs=None,
     config=None,
 ):
@@ -618,7 +622,7 @@ def cluster(
             )
             ws.add(scheduler)
             scheduler.start()
-            stack.callback(_terminate_join, scheduler)
+            stack.callback(_kill_join, scheduler, shutdown_timeout)
 
             # Launch workers
             workers_by_pid = {}
@@ -640,7 +644,7 @@ def cluster(
                 )
                 ws.add(proc)
                 proc.start()
-                stack.callback(_terminate_join, proc)
+                stack.callback(_kill_join, proc, shutdown_timeout)
                 workers_by_pid[proc.pid] = {"proc": proc}
 
             saddr_or_exception = scheduler_q.get()
@@ -656,50 +660,27 @@ def cluster(
             start = time()
             try:
-                try:
-                    security = scheduler_kwargs["security"]
-                    rpc_kwargs = {
-                        "connection_args": security.get_connection_args("client")
-                    }
-                except KeyError:
-                    rpc_kwargs = {}
-
-                async def wait_for_workers():
-                    async with rpc(saddr, **rpc_kwargs) as s:
-                        while True:
-                            nthreads = await s.ncores_running()
-                            if len(nthreads) == nworkers:
-                                break
-                            if time() - start > 5:
-                                raise Exception("Timeout on cluster creation")
+                security = scheduler_kwargs["security"]
+                rpc_kwargs = {"connection_args": security.get_connection_args("client")}
+            except KeyError:
+                rpc_kwargs = {}
+
+            async def wait_for_workers():
+                async with rpc(saddr, **rpc_kwargs) as s:
+                    while True:
+                        nthreads = await s.ncores_running()
+                        if len(nthreads) == nworkers:
+                            break
+                        if time() - start > 5:
+                            raise Exception("Timeout on cluster creation")
 
-                _run_and_close_tornado(wait_for_workers)
+            _run_and_close_tornado(wait_for_workers)
 
-                # avoid sending processes down to function
-                yield {"address": saddr}, [
-                    {"address": w["address"], "proc": weakref.ref(w["proc"])}
-                    for w in workers_by_pid.values()
-                ]
-            finally:
-
-                async def close():
-                    logger.debug("Closing out test cluster")
-                    alive_workers = [
-                        w["address"]
-                        for w in workers_by_pid.values()
-                        if w["proc"].is_alive()
-                    ]
-                    await disconnect_all(
-                        alive_workers,
-                        timeout=disconnect_timeout,
-                        rpc_kwargs=rpc_kwargs,
-                    )
-                    if scheduler.is_alive():
-                        await disconnect(
-                            saddr, timeout=disconnect_timeout, rpc_kwargs=rpc_kwargs
-                        )
-
-                _run_and_close_tornado(close)
+            # avoid sending processes down to function
+            yield {"address": saddr}, [
+                {"address": w["address"], "proc": weakref.ref(w["proc"])}
+                for w in workers_by_pid.values()
+            ]
 
             try:
                 client = default_client()
             except ValueError:
@@ -708,26 +689,6 @@ async def close():
                 client.close()
 
 
-async def disconnect(addr, timeout=3, rpc_kwargs=None):
-    rpc_kwargs = rpc_kwargs or {}
-
-    async def do_disconnect():
-        async with rpc(addr, **rpc_kwargs) as w:
-            # If the worker was killed hard (e.g. sigterm) during test runtime,
-            # we do not know at this point and may not be able to connect
-            with suppress(EnvironmentError, CommClosedError):
-                # Do not request a reply since comms will be closed by the
-                # worker before a reply can be made and we will always trigger
-                # the timeout
-                await w.terminate(reply=False)
-
-    await asyncio.wait_for(do_disconnect(), timeout=timeout)
-
-
-async def disconnect_all(addresses, timeout=3, rpc_kwargs=None):
-    await asyncio.gather(*(disconnect(addr, timeout, rpc_kwargs) for addr in addresses))
-
-
 def gen_test(
     timeout: float = _TEST_TIMEOUT,
     clean_kwargs: dict[str, Any] | None = None,
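
For context, the teardown pattern the diff introduces is: register _kill_join as an ExitStack callback for every subprocess, so on unwind each process is killed unconditionally and the join is bounded by shutdown_timeout, instead of attempting an RPC-based disconnect that may hang. Below is a minimal standalone sketch of that pattern; sleep_forever is a throwaway stand-in target, not the real scheduler/worker entry points from distributed.

# Sketch only: sleep_forever is a placeholder target, not part of distributed.
import contextlib
import multiprocessing
import time


def _kill_join(proc, timeout):
    # Same helper as in the diff: hard-kill, then wait a bounded time for exit.
    proc.kill()
    proc.join(timeout)
    if proc.is_alive():
        raise multiprocessing.TimeoutError(
            f"Process {proc} did not shut down within {timeout}s"
        )
    proc.close()


def sleep_forever():
    while True:
        time.sleep(1)


if __name__ == "__main__":
    with contextlib.ExitStack() as stack:
        proc = multiprocessing.Process(target=sleep_forever, daemon=True)
        proc.start()
        # Callbacks run in reverse order when the stack unwinds, even if the
        # body raises, so every registered process is killed and joined.
        stack.callback(_kill_join, proc, 20)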