@@ -568,9 +568,13 @@ def security():
568
568
return tls_only_security ()
569
569
570
570
571
- def _terminate_join (proc ):
572
- proc .terminate ()
573
- proc .join ()
571
+ def _kill_join (proc , timeout ):
572
+ proc .kill ()
573
+ proc .join (timeout )
574
+ if proc .is_alive ():
575
+ raise multiprocessing .TimeoutError (
576
+ f"Process { proc } did not shut down within { timeout } s"
577
+ )
574
578
proc .close ()
575
579
576
580
@@ -586,7 +590,7 @@ def cluster(
586
590
nanny = False ,
587
591
worker_kwargs = None ,
588
592
active_rpc_timeout = 10 ,
589
- disconnect_timeout = 20 ,
593
+ shutdown_timeout = 20 ,
590
594
scheduler_kwargs = None ,
591
595
config = None ,
592
596
):
@@ -618,7 +622,7 @@ def cluster(
618
622
)
619
623
ws .add (scheduler )
620
624
scheduler .start ()
621
- stack .callback (_terminate_join , scheduler )
625
+ stack .callback (_kill_join , scheduler , shutdown_timeout )
622
626
623
627
# Launch workers
624
628
workers_by_pid = {}
@@ -640,7 +644,7 @@ def cluster(
640
644
)
641
645
ws .add (proc )
642
646
proc .start ()
643
- stack .callback (_terminate_join , proc )
647
+ stack .callback (_kill_join , proc , shutdown_timeout )
644
648
workers_by_pid [proc .pid ] = {"proc" : proc }
645
649
646
650
saddr_or_exception = scheduler_q .get ()
@@ -656,50 +660,27 @@ def cluster(
656
660
657
661
start = time ()
658
662
try :
659
- try :
660
- security = scheduler_kwargs ["security" ]
661
- rpc_kwargs = {
662
- "connection_args" : security .get_connection_args ("client" )
663
- }
664
- except KeyError :
665
- rpc_kwargs = {}
666
-
667
- async def wait_for_workers ():
668
- async with rpc (saddr , ** rpc_kwargs ) as s :
669
- while True :
670
- nthreads = await s .ncores_running ()
671
- if len (nthreads ) == nworkers :
672
- break
673
- if time () - start > 5 :
674
- raise Exception ("Timeout on cluster creation" )
663
+ security = scheduler_kwargs ["security" ]
664
+ rpc_kwargs = {"connection_args" : security .get_connection_args ("client" )}
665
+ except KeyError :
666
+ rpc_kwargs = {}
667
+
668
+ async def wait_for_workers ():
669
+ async with rpc (saddr , ** rpc_kwargs ) as s :
670
+ while True :
671
+ nthreads = await s .ncores_running ()
672
+ if len (nthreads ) == nworkers :
673
+ break
674
+ if time () - start > 5 :
675
+ raise Exception ("Timeout on cluster creation" )
675
676
676
- _run_and_close_tornado (wait_for_workers )
677
+ _run_and_close_tornado (wait_for_workers )
677
678
678
- # avoid sending processes down to function
679
- yield {"address" : saddr }, [
680
- {"address" : w ["address" ], "proc" : weakref .ref (w ["proc" ])}
681
- for w in workers_by_pid .values ()
682
- ]
683
- finally :
684
-
685
- async def close ():
686
- logger .debug ("Closing out test cluster" )
687
- alive_workers = [
688
- w ["address" ]
689
- for w in workers_by_pid .values ()
690
- if w ["proc" ].is_alive ()
691
- ]
692
- await disconnect_all (
693
- alive_workers ,
694
- timeout = disconnect_timeout ,
695
- rpc_kwargs = rpc_kwargs ,
696
- )
697
- if scheduler .is_alive ():
698
- await disconnect (
699
- saddr , timeout = disconnect_timeout , rpc_kwargs = rpc_kwargs
700
- )
701
-
702
- _run_and_close_tornado (close )
679
+ # avoid sending processes down to function
680
+ yield {"address" : saddr }, [
681
+ {"address" : w ["address" ], "proc" : weakref .ref (w ["proc" ])}
682
+ for w in workers_by_pid .values ()
683
+ ]
703
684
try :
704
685
client = default_client ()
705
686
except ValueError :
@@ -708,26 +689,6 @@ async def close():
708
689
client .close ()
709
690
710
691
711
- async def disconnect (addr , timeout = 3 , rpc_kwargs = None ):
712
- rpc_kwargs = rpc_kwargs or {}
713
-
714
- async def do_disconnect ():
715
- async with rpc (addr , ** rpc_kwargs ) as w :
716
- # If the worker was killed hard (e.g. sigterm) during test runtime,
717
- # we do not know at this point and may not be able to connect
718
- with suppress (EnvironmentError , CommClosedError ):
719
- # Do not request a reply since comms will be closed by the
720
- # worker before a reply can be made and we will always trigger
721
- # the timeout
722
- await w .terminate (reply = False )
723
-
724
- await asyncio .wait_for (do_disconnect (), timeout = timeout )
725
-
726
-
727
- async def disconnect_all (addresses , timeout = 3 , rpc_kwargs = None ):
728
- await asyncio .gather (* (disconnect (addr , timeout , rpc_kwargs ) for addr in addresses ))
729
-
730
-
731
692
def gen_test (
732
693
timeout : float = _TEST_TIMEOUT ,
733
694
clean_kwargs : dict [str , Any ] | None = None ,
0 commit comments