Don't connect to cluster subprocesses at shutdown #6829
```diff
@@ -568,9 +568,13 @@ def security():
     return tls_only_security()
 
 
-def _terminate_join(proc):
-    proc.terminate()
-    proc.join()
+def _kill_join(proc, timeout):
+    proc.kill()
+    proc.join(timeout)
+    if proc.is_alive():
+        raise multiprocessing.TimeoutError(
+            f"Process {proc} did not shut down within {timeout}s"
+        )
     proc.close()
```
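For context, a minimal standalone sketch (not part of the PR) of how the new helper behaves: the subprocess is hard-killed and joined with a bound, so teardown never needs to open a connection to it. The sleeping target process below is purely illustrative.

```python
import multiprocessing
import time


def _kill_join(proc, timeout):
    # Same shape as the helper added in this PR: hard-kill, bounded join,
    # and an explicit error if the process refuses to die.
    proc.kill()
    proc.join(timeout)
    if proc.is_alive():
        raise multiprocessing.TimeoutError(
            f"Process {proc} did not shut down within {timeout}s"
        )
    proc.close()


if __name__ == "__main__":
    # Illustrative subprocess that would otherwise run for an hour.
    proc = multiprocessing.Process(target=time.sleep, args=(3600,))
    proc.start()
    _kill_join(proc, timeout=5)  # returns promptly; no RPC round trip involved
```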
```diff
@@ -586,7 +590,7 @@ def cluster(
     nanny=False,
     worker_kwargs=None,
     active_rpc_timeout=10,
-    disconnect_timeout=20,
+    shutdown_timeout=20,
     scheduler_kwargs=None,
     config=None,
 ):
```

Review comment on `shutdown_timeout=20,`: This argument name was unused across the codebase, so changing it seems fine.
```diff
@@ -618,7 +622,7 @@ def cluster(
             )
             ws.add(scheduler)
             scheduler.start()
-            stack.callback(_terminate_join, scheduler)
+            stack.callback(_kill_join, scheduler, shutdown_timeout)
 
             # Launch workers
             workers_by_pid = {}
```
```diff
@@ -640,7 +644,7 @@ def cluster(
                 )
                 ws.add(proc)
                 proc.start()
-                stack.callback(_terminate_join, proc)
+                stack.callback(_kill_join, proc, shutdown_timeout)
                 workers_by_pid[proc.pid] = {"proc": proc}
 
             saddr_or_exception = scheduler_q.get()
```
```diff
@@ -656,50 +660,27 @@
 
             start = time()
             try:
-                try:
-                    security = scheduler_kwargs["security"]
-                    rpc_kwargs = {
-                        "connection_args": security.get_connection_args("client")
-                    }
-                except KeyError:
-                    rpc_kwargs = {}
-
-                async def wait_for_workers():
-                    async with rpc(saddr, **rpc_kwargs) as s:
-                        while True:
-                            nthreads = await s.ncores_running()
-                            if len(nthreads) == nworkers:
-                                break
-                            if time() - start > 5:
-                                raise Exception("Timeout on cluster creation")
-
-                _run_and_close_tornado(wait_for_workers)
-
-                # avoid sending processes down to function
-                yield {"address": saddr}, [
-                    {"address": w["address"], "proc": weakref.ref(w["proc"])}
-                    for w in workers_by_pid.values()
-                ]
-            finally:
-
-                async def close():
-                    logger.debug("Closing out test cluster")
-                    alive_workers = [
-                        w["address"]
-                        for w in workers_by_pid.values()
-                        if w["proc"].is_alive()
-                    ]
-                    await disconnect_all(
-                        alive_workers,
-                        timeout=disconnect_timeout,
-                        rpc_kwargs=rpc_kwargs,
-                    )
-                    if scheduler.is_alive():
-                        await disconnect(
-                            saddr, timeout=disconnect_timeout, rpc_kwargs=rpc_kwargs
-                        )
-
-                _run_and_close_tornado(close)
+                security = scheduler_kwargs["security"]
+                rpc_kwargs = {"connection_args": security.get_connection_args("client")}
+            except KeyError:
+                rpc_kwargs = {}
+
+            async def wait_for_workers():
+                async with rpc(saddr, **rpc_kwargs) as s:
+                    while True:
+                        nthreads = await s.ncores_running()
+                        if len(nthreads) == nworkers:
+                            break
+                        if time() - start > 5:
+                            raise Exception("Timeout on cluster creation")
+
+            _run_and_close_tornado(wait_for_workers)
+
+            # avoid sending processes down to function
+            yield {"address": saddr}, [
+                {"address": w["address"], "proc": weakref.ref(w["proc"])}
+                for w in workers_by_pid.values()
+            ]
+
     try:
         client = default_client()
     except ValueError:
```

Review comment on lines -685 to -700: The change is that I just deleted this entire `close` coroutine.
```diff
@@ -708,26 +689,6 @@ async def close():
         client.close()
 
 
-async def disconnect(addr, timeout=3, rpc_kwargs=None):
-    rpc_kwargs = rpc_kwargs or {}
-
-    async def do_disconnect():
-        async with rpc(addr, **rpc_kwargs) as w:
-            # If the worker was killed hard (e.g. sigterm) during test runtime,
-            # we do not know at this point and may not be able to connect
-            with suppress(EnvironmentError, CommClosedError):
-                # Do not request a reply since comms will be closed by the
-                # worker before a reply can be made and we will always trigger
-                # the timeout
-                await w.terminate(reply=False)
-
-    await asyncio.wait_for(do_disconnect(), timeout=timeout)
-
-
-async def disconnect_all(addresses, timeout=3, rpc_kwargs=None):
-    await asyncio.gather(*(disconnect(addr, timeout, rpc_kwargs) for addr in addresses))
-
-
 def gen_test(
     timeout: float = _TEST_TIMEOUT,
     clean_kwargs: dict[str, Any] | None = None,
```
Review comment: Maybe rely on the pytest timeout?
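A sketch of that suggestion, assuming the pytest-timeout plugin is available (the test name here is hypothetical, not something this PR adds):

```python
import pytest


# Hypothetical test relying on pytest-timeout to abort a hung shutdown,
# instead of the explicit join timeout inside _kill_join.
@pytest.mark.timeout(30)
def test_cluster_teardown():
    ...
```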
Review comment: Also maybe kill them all at the same time, then join them all.
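A rough sketch of that idea, with a hypothetical helper name (not part of the PR): signal every subprocess first so their shutdowns overlap, then join each one against a shared deadline.

```python
import multiprocessing
from time import time


def _kill_all_join_all(procs, timeout):
    # Hypothetical variant of _kill_join suggested in review: kill every
    # process before joining any, so slow shutdowns run concurrently.
    for proc in procs:
        proc.kill()
    deadline = time() + timeout
    for proc in procs:
        proc.join(max(0, deadline - time()))
        if proc.is_alive():
            raise multiprocessing.TimeoutError(
                f"Process {proc} did not shut down within {timeout}s"
            )
        proc.close()
```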