13
13
import pytest
14
14
import sniffio
15
15
16
- from trio ._core import TrioToken , current_trio_token
17
-
18
- from .. import CapacityLimiter , Event , _core , sleep
16
+ from .. import CapacityLimiter , Event , _core , fail_after , sleep , sleep_forever
19
17
from .._core ._tests .test_ki import ki_self
20
18
from .._core ._tests .tutil import buggy_pypy_asyncgens
21
19
from .._threads import (
22
20
current_default_thread_limiter ,
21
+ from_thread_check_cancelled ,
23
22
from_thread_run ,
24
23
from_thread_run_sync ,
25
24
to_thread_run_sync ,
@@ -645,7 +644,7 @@ async def async_fn(): # pragma: no cover
645
644
def thread_fn ():
646
645
from_thread_run_sync (async_fn )
647
646
648
- with pytest .raises (TypeError , match = "expected a sync function" ):
647
+ with pytest .raises (TypeError , match = "expected a synchronous function" ):
649
648
await to_thread_run_sync (thread_fn )
650
649
651
650
@@ -810,25 +809,32 @@ def test_from_thread_run_during_shutdown():
810
809
save = []
811
810
record = []
812
811
813
- async def agen ():
812
+ async def agen (token ):
814
813
try :
815
814
yield
816
815
finally :
817
- with pytest .raises (_core .RunFinishedError ), _core .CancelScope (shield = True ):
818
- await to_thread_run_sync (from_thread_run , sleep , 0 )
819
- record .append ("ok" )
820
-
821
- async def main ():
822
- save .append (agen ())
816
+ with _core .CancelScope (shield = True ):
817
+ try :
818
+ await to_thread_run_sync (
819
+ partial (from_thread_run , sleep , 0 , trio_token = token )
820
+ )
821
+ except _core .RunFinishedError :
822
+ record .append ("finished" )
823
+ else :
824
+ record .append ("clean" )
825
+
826
+ async def main (use_system_task ):
827
+ save .append (agen (_core .current_trio_token () if use_system_task else None ))
823
828
await save [- 1 ].asend (None )
824
829
825
- _core .run (main )
826
- assert record == ["ok" ]
830
+ _core .run (main , True ) # System nursery will be closed and raise RunFinishedError
831
+ _core .run (main , False ) # host task will be rescheduled as normal
832
+ assert record == ["finished" , "clean" ]
827
833
828
834
829
835
async def test_trio_token_weak_referenceable ():
830
- token = current_trio_token ()
831
- assert isinstance (token , TrioToken )
836
+ token = _core . current_trio_token ()
837
+ assert isinstance (token , _core . TrioToken )
832
838
weak_reference = weakref .ref (token )
833
839
assert token is weak_reference ()
834
840
@@ -842,3 +848,170 @@ def __bool__(self):
842
848
843
849
with pytest .raises (NotImplementedError ):
844
850
await to_thread_run_sync (int , cancellable = BadBool ())
851
+
852
+
853
+ async def test_from_thread_reuses_task ():
854
+ task = _core .current_task ()
855
+
856
+ async def async_current_task ():
857
+ return _core .current_task ()
858
+
859
+ assert task is await to_thread_run_sync (from_thread_run_sync , _core .current_task )
860
+ assert task is await to_thread_run_sync (from_thread_run , async_current_task )
861
+
862
+
863
+ async def test_recursive_to_thread ():
864
+ tid = None
865
+
866
+ def get_tid_then_reenter ():
867
+ nonlocal tid
868
+ tid = threading .get_ident ()
869
+ return from_thread_run (to_thread_run_sync , threading .get_ident )
870
+
871
+ assert tid != await to_thread_run_sync (get_tid_then_reenter )
872
+
873
+
874
+ async def test_from_thread_host_cancelled ():
875
+ queue = stdlib_queue .Queue ()
876
+
877
+ def sync_check ():
878
+ from_thread_run_sync (cancel_scope .cancel )
879
+ try :
880
+ from_thread_run_sync (bool )
881
+ except _core .Cancelled : # pragma: no cover
882
+ queue .put (True ) # sync functions don't raise Cancelled
883
+ else :
884
+ queue .put (False )
885
+
886
+ with _core .CancelScope () as cancel_scope :
887
+ await to_thread_run_sync (sync_check )
888
+
889
+ assert not cancel_scope .cancelled_caught
890
+ assert not queue .get_nowait ()
891
+
892
+ with _core .CancelScope () as cancel_scope :
893
+ await to_thread_run_sync (sync_check , cancellable = True )
894
+
895
+ assert cancel_scope .cancelled_caught
896
+ assert not await to_thread_run_sync (partial (queue .get , timeout = 1 ))
897
+
898
+ async def no_checkpoint ():
899
+ return True
900
+
901
+ def async_check ():
902
+ from_thread_run_sync (cancel_scope .cancel )
903
+ try :
904
+ assert from_thread_run (no_checkpoint )
905
+ except _core .Cancelled : # pragma: no cover
906
+ queue .put (True ) # async functions raise Cancelled at checkpoints
907
+ else :
908
+ queue .put (False )
909
+
910
+ with _core .CancelScope () as cancel_scope :
911
+ await to_thread_run_sync (async_check )
912
+
913
+ assert not cancel_scope .cancelled_caught
914
+ assert not queue .get_nowait ()
915
+
916
+ with _core .CancelScope () as cancel_scope :
917
+ await to_thread_run_sync (async_check , cancellable = True )
918
+
919
+ assert cancel_scope .cancelled_caught
920
+ assert not await to_thread_run_sync (partial (queue .get , timeout = 1 ))
921
+
922
+ async def async_time_bomb ():
923
+ cancel_scope .cancel ()
924
+ with fail_after (10 ):
925
+ await sleep_forever ()
926
+
927
+ with _core .CancelScope () as cancel_scope :
928
+ await to_thread_run_sync (from_thread_run , async_time_bomb )
929
+
930
+ assert cancel_scope .cancelled_caught
931
+
932
+
933
+ async def test_from_thread_check_cancelled ():
934
+ q = stdlib_queue .Queue ()
935
+
936
+ async def child (cancellable , scope ):
937
+ with scope :
938
+ record .append ("start" )
939
+ try :
940
+ return await to_thread_run_sync (f , cancellable = cancellable )
941
+ except _core .Cancelled :
942
+ record .append ("cancel" )
943
+ raise
944
+ finally :
945
+ record .append ("exit" )
946
+
947
+ def f ():
948
+ try :
949
+ from_thread_check_cancelled ()
950
+ except _core .Cancelled : # pragma: no cover, test failure path
951
+ q .put ("Cancelled" )
952
+ else :
953
+ q .put ("Not Cancelled" )
954
+ ev .wait ()
955
+ return from_thread_check_cancelled ()
956
+
957
+ # Base case: nothing cancelled so we shouldn't see cancels anywhere
958
+ record = []
959
+ ev = threading .Event ()
960
+ async with _core .open_nursery () as nursery :
961
+ nursery .start_soon (child , False , _core .CancelScope ())
962
+ await wait_all_tasks_blocked ()
963
+ assert record [0 ] == "start"
964
+ assert q .get (timeout = 1 ) == "Not Cancelled"
965
+ ev .set ()
966
+ # implicit assertion, Cancelled not raised via nursery
967
+ assert record [1 ] == "exit"
968
+
969
+ # cancellable=False case: a cancel will pop out but be handled by
970
+ # the appropriate cancel scope
971
+ record = []
972
+ ev = threading .Event ()
973
+ scope = _core .CancelScope () # Nursery cancel scope gives false positives
974
+ async with _core .open_nursery () as nursery :
975
+ nursery .start_soon (child , False , scope )
976
+ await wait_all_tasks_blocked ()
977
+ assert record [0 ] == "start"
978
+ assert q .get (timeout = 1 ) == "Not Cancelled"
979
+ scope .cancel ()
980
+ ev .set ()
981
+ assert scope .cancelled_caught
982
+ assert "cancel" in record
983
+ assert record [- 1 ] == "exit"
984
+
985
+ # cancellable=True case: slightly different thread behavior needed
986
+ # check thread is cancelled "soon" after abandonment
987
+ def f (): # noqa: F811
988
+ ev .wait ()
989
+ try :
990
+ from_thread_check_cancelled ()
991
+ except _core .Cancelled :
992
+ q .put ("Cancelled" )
993
+ else : # pragma: no cover, test failure path
994
+ q .put ("Not Cancelled" )
995
+
996
+ record = []
997
+ ev = threading .Event ()
998
+ scope = _core .CancelScope ()
999
+ async with _core .open_nursery () as nursery :
1000
+ nursery .start_soon (child , True , scope )
1001
+ await wait_all_tasks_blocked ()
1002
+ assert record [0 ] == "start"
1003
+ scope .cancel ()
1004
+ ev .set ()
1005
+ assert scope .cancelled_caught
1006
+ assert "cancel" in record
1007
+ assert record [- 1 ] == "exit"
1008
+ assert q .get (timeout = 1 ) == "Cancelled"
1009
+
1010
+
1011
+ async def test_from_thread_check_cancelled_raises_in_foreign_threads ():
1012
+ with pytest .raises (RuntimeError ):
1013
+ from_thread_check_cancelled ()
1014
+ q = stdlib_queue .Queue ()
1015
+ _core .start_thread_soon (from_thread_check_cancelled , lambda _ : q .put (_ ))
1016
+ with pytest .raises (RuntimeError ):
1017
+ q .get (timeout = 1 ).unwrap ()
0 commit comments