Skip to content

Commit e1f3779

Browse files
authored
Don't connect to cluster subprocesses at shutdown (#6829)
1 parent caf5189 commit e1f3779

File tree

1 file changed

+29
-68
lines changed

1 file changed

+29
-68
lines changed

distributed/utils_test.py

Lines changed: 29 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -568,9 +568,13 @@ def security():
568568
return tls_only_security()
569569

570570

571-
def _terminate_join(proc):
572-
proc.terminate()
573-
proc.join()
571+
def _kill_join(proc, timeout):
    """Forcefully kill a cluster subprocess and reap it.

    Replaces the previous graceful ``terminate()``-and-``join()`` shutdown:
    at test teardown we no longer connect to (or wait politely for) the
    subprocess — we kill it outright and only verify it actually died.

    Parameters
    ----------
    proc:
        A ``multiprocessing.Process``-like object exposing ``kill``,
        ``join``, ``is_alive`` and ``close``.
    timeout:
        Seconds to wait for the process to exit after ``kill()``.

    Raises
    ------
    multiprocessing.TimeoutError
        If the process is still alive ``timeout`` seconds after being
        killed.
    """
    proc.kill()
    proc.join(timeout)
    if proc.is_alive():
        # join() returned because the timeout elapsed, not because the
        # process exited — fail loudly instead of hanging the test suite.
        raise multiprocessing.TimeoutError(
            f"Process {proc} did not shut down within {timeout}s"
        )
    # Release the OS resources held by the (now reaped) process handle.
    proc.close()
575579

576580

@@ -586,7 +590,7 @@ def cluster(
586590
nanny=False,
587591
worker_kwargs=None,
588592
active_rpc_timeout=10,
589-
disconnect_timeout=20,
593+
shutdown_timeout=20,
590594
scheduler_kwargs=None,
591595
config=None,
592596
):
@@ -618,7 +622,7 @@ def cluster(
618622
)
619623
ws.add(scheduler)
620624
scheduler.start()
621-
stack.callback(_terminate_join, scheduler)
625+
stack.callback(_kill_join, scheduler, shutdown_timeout)
622626

623627
# Launch workers
624628
workers_by_pid = {}
@@ -640,7 +644,7 @@ def cluster(
640644
)
641645
ws.add(proc)
642646
proc.start()
643-
stack.callback(_terminate_join, proc)
647+
stack.callback(_kill_join, proc, shutdown_timeout)
644648
workers_by_pid[proc.pid] = {"proc": proc}
645649

646650
saddr_or_exception = scheduler_q.get()
@@ -656,50 +660,27 @@ def cluster(
656660

657661
start = time()
658662
try:
659-
try:
660-
security = scheduler_kwargs["security"]
661-
rpc_kwargs = {
662-
"connection_args": security.get_connection_args("client")
663-
}
664-
except KeyError:
665-
rpc_kwargs = {}
666-
667-
async def wait_for_workers():
668-
async with rpc(saddr, **rpc_kwargs) as s:
669-
while True:
670-
nthreads = await s.ncores_running()
671-
if len(nthreads) == nworkers:
672-
break
673-
if time() - start > 5:
674-
raise Exception("Timeout on cluster creation")
663+
security = scheduler_kwargs["security"]
664+
rpc_kwargs = {"connection_args": security.get_connection_args("client")}
665+
except KeyError:
666+
rpc_kwargs = {}
667+
668+
async def wait_for_workers():
669+
async with rpc(saddr, **rpc_kwargs) as s:
670+
while True:
671+
nthreads = await s.ncores_running()
672+
if len(nthreads) == nworkers:
673+
break
674+
if time() - start > 5:
675+
raise Exception("Timeout on cluster creation")
675676

676-
_run_and_close_tornado(wait_for_workers)
677+
_run_and_close_tornado(wait_for_workers)
677678

678-
# avoid sending processes down to function
679-
yield {"address": saddr}, [
680-
{"address": w["address"], "proc": weakref.ref(w["proc"])}
681-
for w in workers_by_pid.values()
682-
]
683-
finally:
684-
685-
async def close():
686-
logger.debug("Closing out test cluster")
687-
alive_workers = [
688-
w["address"]
689-
for w in workers_by_pid.values()
690-
if w["proc"].is_alive()
691-
]
692-
await disconnect_all(
693-
alive_workers,
694-
timeout=disconnect_timeout,
695-
rpc_kwargs=rpc_kwargs,
696-
)
697-
if scheduler.is_alive():
698-
await disconnect(
699-
saddr, timeout=disconnect_timeout, rpc_kwargs=rpc_kwargs
700-
)
701-
702-
_run_and_close_tornado(close)
679+
# avoid sending processes down to function
680+
yield {"address": saddr}, [
681+
{"address": w["address"], "proc": weakref.ref(w["proc"])}
682+
for w in workers_by_pid.values()
683+
]
703684
try:
704685
client = default_client()
705686
except ValueError:
@@ -708,26 +689,6 @@ async def close():
708689
client.close()
709690

710691

711-
async def disconnect(addr, timeout=3, rpc_kwargs=None):
712-
rpc_kwargs = rpc_kwargs or {}
713-
714-
async def do_disconnect():
715-
async with rpc(addr, **rpc_kwargs) as w:
716-
# If the worker was killed hard (e.g. sigterm) during test runtime,
717-
# we do not know at this point and may not be able to connect
718-
with suppress(EnvironmentError, CommClosedError):
719-
# Do not request a reply since comms will be closed by the
720-
# worker before a reply can be made and we will always trigger
721-
# the timeout
722-
await w.terminate(reply=False)
723-
724-
await asyncio.wait_for(do_disconnect(), timeout=timeout)
725-
726-
727-
async def disconnect_all(addresses, timeout=3, rpc_kwargs=None):
728-
await asyncio.gather(*(disconnect(addr, timeout, rpc_kwargs) for addr in addresses))
729-
730-
731692
def gen_test(
732693
timeout: float = _TEST_TIMEOUT,
733694
clean_kwargs: dict[str, Any] | None = None,

0 commit comments

Comments
 (0)