Skip to content

Commit b324b3a

Browse files
Merge pull request #2392 from richardsheridan/from_thread_check_cancelled
Expand cancellation usability from native trio threads
2 parents b161fec + ab092b0 commit b324b3a

8 files changed

+441
-130
lines changed

docs/source/reference-core.rst

+19
Original file line numberDiff line numberDiff line change
@@ -1823,6 +1823,25 @@ to spawn a child thread, and then use a :ref:`memory channel
18231823

18241824
.. literalinclude:: reference-core/from-thread-example.py
18251825

1826+
.. note::
1827+
1828+
The ``from_thread.run*`` functions reuse the host task that called
1829+
:func:`trio.to_thread.run_sync` to run your provided function, as long as you're
1830+
using the default ``cancellable=False`` so Trio can be sure that the task will remain
1831+
around to perform the work. If you pass ``cancellable=True`` at the outset, or if
1832+
you provide a :class:`~trio.lowlevel.TrioToken` when calling back in to Trio, your
1833+
functions will be executed in a new system task. Therefore, the
1834+
:func:`~trio.lowlevel.current_task`, :func:`current_effective_deadline`, or other
1835+
task-tree specific values may differ depending on keyword argument values.
1836+
1837+
You can also use :func:`trio.from_thread.check_cancelled` to check for cancellation from
1838+
a thread that was spawned by :func:`trio.to_thread.run_sync`. If the call to
1839+
:func:`~trio.to_thread.run_sync` was cancelled (even if ``cancellable=False``!), then
1840+
:func:`~trio.from_thread.check_cancelled` will raise :func:`trio.Cancelled`.
1841+
It's like ``trio.from_thread.run(trio.sleep, 0)``, but much faster.
1842+
1843+
.. autofunction:: trio.from_thread.check_cancelled
1844+
18261845
Threads and task-local storage
18271846
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
18281847

newsfragments/2392.feature.rst

+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
If called from a thread spawned by `trio.to_thread.run_sync`, `trio.from_thread.run` and
2+
`trio.from_thread.run_sync` now reuse the task and cancellation status of the host task;
3+
this means that context variables and cancel scopes naturally propagate 'through'
4+
threads spawned by Trio. You can also use `trio.from_thread.check_cancelled`
5+
to efficiently check for cancellation without reentering the Trio thread.

trio/_tests/test_threads.py

+188-15
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,12 @@
1313
import pytest
1414
import sniffio
1515

16-
from trio._core import TrioToken, current_trio_token
17-
18-
from .. import CapacityLimiter, Event, _core, sleep
16+
from .. import CapacityLimiter, Event, _core, fail_after, sleep, sleep_forever
1917
from .._core._tests.test_ki import ki_self
2018
from .._core._tests.tutil import buggy_pypy_asyncgens
2119
from .._threads import (
2220
current_default_thread_limiter,
21+
from_thread_check_cancelled,
2322
from_thread_run,
2423
from_thread_run_sync,
2524
to_thread_run_sync,
@@ -645,7 +644,7 @@ async def async_fn(): # pragma: no cover
645644
def thread_fn():
646645
from_thread_run_sync(async_fn)
647646

648-
with pytest.raises(TypeError, match="expected a sync function"):
647+
with pytest.raises(TypeError, match="expected a synchronous function"):
649648
await to_thread_run_sync(thread_fn)
650649

651650

@@ -810,25 +809,32 @@ def test_from_thread_run_during_shutdown():
810809
save = []
811810
record = []
812811

813-
async def agen():
812+
async def agen(token):
814813
try:
815814
yield
816815
finally:
817-
with pytest.raises(_core.RunFinishedError), _core.CancelScope(shield=True):
818-
await to_thread_run_sync(from_thread_run, sleep, 0)
819-
record.append("ok")
820-
821-
async def main():
822-
save.append(agen())
816+
with _core.CancelScope(shield=True):
817+
try:
818+
await to_thread_run_sync(
819+
partial(from_thread_run, sleep, 0, trio_token=token)
820+
)
821+
except _core.RunFinishedError:
822+
record.append("finished")
823+
else:
824+
record.append("clean")
825+
826+
async def main(use_system_task):
827+
save.append(agen(_core.current_trio_token() if use_system_task else None))
823828
await save[-1].asend(None)
824829

825-
_core.run(main)
826-
assert record == ["ok"]
830+
_core.run(main, True) # System nursery will be closed and raise RunFinishedError
831+
_core.run(main, False) # host task will be rescheduled as normal
832+
assert record == ["finished", "clean"]
827833

828834

829835
async def test_trio_token_weak_referenceable():
830-
token = current_trio_token()
831-
assert isinstance(token, TrioToken)
836+
token = _core.current_trio_token()
837+
assert isinstance(token, _core.TrioToken)
832838
weak_reference = weakref.ref(token)
833839
assert token is weak_reference()
834840

@@ -842,3 +848,170 @@ def __bool__(self):
842848

843849
with pytest.raises(NotImplementedError):
844850
await to_thread_run_sync(int, cancellable=BadBool())
851+
852+
853+
async def test_from_thread_reuses_task():
854+
task = _core.current_task()
855+
856+
async def async_current_task():
857+
return _core.current_task()
858+
859+
assert task is await to_thread_run_sync(from_thread_run_sync, _core.current_task)
860+
assert task is await to_thread_run_sync(from_thread_run, async_current_task)
861+
862+
863+
async def test_recursive_to_thread():
864+
tid = None
865+
866+
def get_tid_then_reenter():
867+
nonlocal tid
868+
tid = threading.get_ident()
869+
return from_thread_run(to_thread_run_sync, threading.get_ident)
870+
871+
assert tid != await to_thread_run_sync(get_tid_then_reenter)
872+
873+
874+
async def test_from_thread_host_cancelled():
875+
queue = stdlib_queue.Queue()
876+
877+
def sync_check():
878+
from_thread_run_sync(cancel_scope.cancel)
879+
try:
880+
from_thread_run_sync(bool)
881+
except _core.Cancelled: # pragma: no cover
882+
queue.put(True) # sync functions don't raise Cancelled
883+
else:
884+
queue.put(False)
885+
886+
with _core.CancelScope() as cancel_scope:
887+
await to_thread_run_sync(sync_check)
888+
889+
assert not cancel_scope.cancelled_caught
890+
assert not queue.get_nowait()
891+
892+
with _core.CancelScope() as cancel_scope:
893+
await to_thread_run_sync(sync_check, cancellable=True)
894+
895+
assert cancel_scope.cancelled_caught
896+
assert not await to_thread_run_sync(partial(queue.get, timeout=1))
897+
898+
async def no_checkpoint():
899+
return True
900+
901+
def async_check():
902+
from_thread_run_sync(cancel_scope.cancel)
903+
try:
904+
assert from_thread_run(no_checkpoint)
905+
except _core.Cancelled: # pragma: no cover
906+
queue.put(True) # async functions raise Cancelled at checkpoints
907+
else:
908+
queue.put(False)
909+
910+
with _core.CancelScope() as cancel_scope:
911+
await to_thread_run_sync(async_check)
912+
913+
assert not cancel_scope.cancelled_caught
914+
assert not queue.get_nowait()
915+
916+
with _core.CancelScope() as cancel_scope:
917+
await to_thread_run_sync(async_check, cancellable=True)
918+
919+
assert cancel_scope.cancelled_caught
920+
assert not await to_thread_run_sync(partial(queue.get, timeout=1))
921+
922+
async def async_time_bomb():
923+
cancel_scope.cancel()
924+
with fail_after(10):
925+
await sleep_forever()
926+
927+
with _core.CancelScope() as cancel_scope:
928+
await to_thread_run_sync(from_thread_run, async_time_bomb)
929+
930+
assert cancel_scope.cancelled_caught
931+
932+
933+
async def test_from_thread_check_cancelled():
934+
q = stdlib_queue.Queue()
935+
936+
async def child(cancellable, scope):
937+
with scope:
938+
record.append("start")
939+
try:
940+
return await to_thread_run_sync(f, cancellable=cancellable)
941+
except _core.Cancelled:
942+
record.append("cancel")
943+
raise
944+
finally:
945+
record.append("exit")
946+
947+
def f():
948+
try:
949+
from_thread_check_cancelled()
950+
except _core.Cancelled: # pragma: no cover, test failure path
951+
q.put("Cancelled")
952+
else:
953+
q.put("Not Cancelled")
954+
ev.wait()
955+
return from_thread_check_cancelled()
956+
957+
# Base case: nothing cancelled so we shouldn't see cancels anywhere
958+
record = []
959+
ev = threading.Event()
960+
async with _core.open_nursery() as nursery:
961+
nursery.start_soon(child, False, _core.CancelScope())
962+
await wait_all_tasks_blocked()
963+
assert record[0] == "start"
964+
assert q.get(timeout=1) == "Not Cancelled"
965+
ev.set()
966+
# implicit assertion, Cancelled not raised via nursery
967+
assert record[1] == "exit"
968+
969+
# cancellable=False case: a cancel will pop out but be handled by
970+
# the appropriate cancel scope
971+
record = []
972+
ev = threading.Event()
973+
scope = _core.CancelScope() # Nursery cancel scope gives false positives
974+
async with _core.open_nursery() as nursery:
975+
nursery.start_soon(child, False, scope)
976+
await wait_all_tasks_blocked()
977+
assert record[0] == "start"
978+
assert q.get(timeout=1) == "Not Cancelled"
979+
scope.cancel()
980+
ev.set()
981+
assert scope.cancelled_caught
982+
assert "cancel" in record
983+
assert record[-1] == "exit"
984+
985+
# cancellable=True case: slightly different thread behavior needed
986+
# check thread is cancelled "soon" after abandonment
987+
def f(): # noqa: F811
988+
ev.wait()
989+
try:
990+
from_thread_check_cancelled()
991+
except _core.Cancelled:
992+
q.put("Cancelled")
993+
else: # pragma: no cover, test failure path
994+
q.put("Not Cancelled")
995+
996+
record = []
997+
ev = threading.Event()
998+
scope = _core.CancelScope()
999+
async with _core.open_nursery() as nursery:
1000+
nursery.start_soon(child, True, scope)
1001+
await wait_all_tasks_blocked()
1002+
assert record[0] == "start"
1003+
scope.cancel()
1004+
ev.set()
1005+
assert scope.cancelled_caught
1006+
assert "cancel" in record
1007+
assert record[-1] == "exit"
1008+
assert q.get(timeout=1) == "Cancelled"
1009+
1010+
1011+
async def test_from_thread_check_cancelled_raises_in_foreign_threads():
1012+
with pytest.raises(RuntimeError):
1013+
from_thread_check_cancelled()
1014+
q = stdlib_queue.Queue()
1015+
_core.start_thread_soon(from_thread_check_cancelled, lambda _: q.put(_))
1016+
with pytest.raises(RuntimeError):
1017+
q.get(timeout=1).unwrap()

trio/_tests/verify_types_darwin.json

+1-1
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
],
4141
"exportedSymbolCounts": {
4242
"withAmbiguousType": 0,
43-
"withKnownType": 630,
43+
"withKnownType": 631,
4444
"withUnknownType": 0
4545
},
4646
"ignoreUnknownTypesFromImports": true,

trio/_tests/verify_types_linux.json

+1-1
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
],
2929
"exportedSymbolCounts": {
3030
"withAmbiguousType": 0,
31-
"withKnownType": 627,
31+
"withKnownType": 628,
3232
"withUnknownType": 0
3333
},
3434
"ignoreUnknownTypesFromImports": true,

trio/_tests/verify_types_windows.json

+1-1
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@
6464
],
6565
"exportedSymbolCounts": {
6666
"withAmbiguousType": 0,
67-
"withKnownType": 630,
67+
"withKnownType": 631,
6868
"withUnknownType": 0
6969
},
7070
"ignoreUnknownTypesFromImports": true,

0 commit comments

Comments
 (0)