Skip to content

Commit d39444a

Browse files
authored
Merge pull request #3195 from A5rocks/check-trio-running
Add `in_trio_run` and `in_trio_task`
2 parents 3b94a1a + ce97b8a commit d39444a

File tree

8 files changed

+174
-2
lines changed

8 files changed

+174
-2
lines changed

docs/source/reference-lowlevel.rst

+50
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,56 @@ Global statistics
5656
.. autoclass:: RunStatistics()
5757

5858

59+
.. _trio_contexts:
60+
61+
Checking for Trio
62+
-----------------
63+
64+
If you want to interact with an active Trio run -- perhaps you need to
65+
know the :func:`~trio.current_time` or the
66+
:func:`~trio.lowlevel.current_task` -- then Trio needs to have certain
67+
state available to it or else you will get a
68+
``RuntimeError("must be called from async context")``.
69+
This requires that you either be:
70+
71+
* indirectly inside (and on the same thread as) a call to
72+
:func:`trio.run`, for run-level information such as the
73+
:func:`~trio.current_time` or :func:`~trio.lowlevel.current_clock`;
74+
or
75+
76+
* indirectly inside a Trio task, for task-level information such as
77+
the :func:`~trio.lowlevel.current_task` or
78+
:func:`~trio.current_effective_deadline`.
79+
80+
Internally, this state is provided by thread-local variables tracking
81+
the current run and the current task. Sometimes, it's useful to know
82+
in advance whether a call will fail or to have dynamic information for
83+
safeguards against running something inside or outside Trio. To do so,
84+
call :func:`trio.lowlevel.in_trio_run` or
85+
:func:`trio.lowlevel.in_trio_task`, which will provide answers
86+
according to the following table.
87+
88+
89+
+--------------------------------------------------------+-----------------------------------+------------------------------------+
90+
| situation | :func:`trio.lowlevel.in_trio_run` | :func:`trio.lowlevel.in_trio_task` |
91+
+========================================================+===================================+====================================+
92+
| inside a Trio-flavored async function | `True` | `True` |
93+
+--------------------------------------------------------+-----------------------------------+------------------------------------+
94+
| in a thread without an active call to :func:`trio.run` | `False` | `False` |
95+
+--------------------------------------------------------+-----------------------------------+------------------------------------+
96+
| in a guest run's host loop | `True` | `False` |
97+
+--------------------------------------------------------+-----------------------------------+------------------------------------+
98+
| inside an instrument call | `True` | depends |
99+
+--------------------------------------------------------+-----------------------------------+------------------------------------+
100+
| in a thread created by :func:`trio.to_thread.run_sync` | `False` | `False` |
101+
+--------------------------------------------------------+-----------------------------------+------------------------------------+
102+
| inside an abort function | `True` | `True` |
103+
+--------------------------------------------------------+-----------------------------------+------------------------------------+
104+
105+
.. autofunction:: in_trio_run
106+
107+
.. autofunction:: in_trio_task
108+
59109
The current clock
60110
-----------------
61111

newsfragments/2757.feature.rst

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Add :func:`trio.lowlevel.in_trio_run` and :func:`trio.lowlevel.in_trio_task` and document the semantics (and differences) thereof. See :ref:`the documentation <trio_contexts>`.

src/trio/_core/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@
4545
current_task,
4646
current_time,
4747
current_trio_token,
48+
in_trio_run,
49+
in_trio_task,
4850
notify_closing,
4951
open_nursery,
5052
remove_instrument,

src/trio/_core/_run.py

+21-2
Original file line numberDiff line numberDiff line change
@@ -2283,7 +2283,7 @@ def setup_runner(
22832283
# It wouldn't be *hard* to support nested calls to run(), but I can't
22842284
# think of a single good reason for it, so let's be conservative for
22852285
# now:
2286-
if hasattr(GLOBAL_RUN_CONTEXT, "runner"):
2286+
if in_trio_run():
22872287
raise RuntimeError("Attempted to call run() from inside a run()")
22882288

22892289
if clock is None:
@@ -2832,8 +2832,9 @@ def unrolled_run(
28322832
except BaseException as exc:
28332833
raise TrioInternalError("internal error in Trio - please file a bug!") from exc
28342834
finally:
2835-
GLOBAL_RUN_CONTEXT.__dict__.clear()
28362835
runner.close()
2836+
GLOBAL_RUN_CONTEXT.__dict__.clear()
2837+
28372838
# Have to do this after runner.close() has disabled KI protection,
28382839
# because otherwise there's a race where ki_pending could get set
28392840
# after we check it.
@@ -2952,6 +2953,24 @@ async def checkpoint_if_cancelled() -> None:
29522953
task._cancel_points += 1
29532954

29542955

2956+
def in_trio_run() -> bool:
2957+
"""Check whether we are in a Trio run.
2958+
This returns `True` if and only if :func:`~trio.current_time` will succeed.
2959+
2960+
See also the discussion of differing ways of :ref:`detecting Trio <trio_contexts>`.
2961+
"""
2962+
return hasattr(GLOBAL_RUN_CONTEXT, "runner")
2963+
2964+
2965+
def in_trio_task() -> bool:
2966+
"""Check whether we are in a Trio task.
2967+
This returns `True` if and only if :func:`~trio.lowlevel.current_task` will succeed.
2968+
2969+
See also the discussion of differing ways of :ref:`detecting Trio <trio_contexts>`.
2970+
"""
2971+
return hasattr(GLOBAL_RUN_CONTEXT, "task")
2972+
2973+
29552974
if sys.platform == "win32":
29562975
from ._generated_io_windows import *
29572976
from ._io_windows import (

src/trio/_core/_tests/test_guest_mode.py

+20
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,26 @@ async def synchronize() -> None:
264264
sniffio_library.name = None
265265

266266

267+
def test_guest_mode_trio_context_detection() -> None:
268+
def check(thing: bool) -> None:
269+
assert thing
270+
271+
assert not trio.lowlevel.in_trio_run()
272+
assert not trio.lowlevel.in_trio_task()
273+
274+
async def trio_main(in_host: InHost) -> None:
275+
for _ in range(2):
276+
assert trio.lowlevel.in_trio_run()
277+
assert trio.lowlevel.in_trio_task()
278+
279+
in_host(lambda: check(trio.lowlevel.in_trio_run()))
280+
in_host(lambda: check(not trio.lowlevel.in_trio_task()))
281+
282+
trivial_guest_run(trio_main)
283+
assert not trio.lowlevel.in_trio_run()
284+
assert not trio.lowlevel.in_trio_task()
285+
286+
267287
def test_warn_set_wakeup_fd_overwrite() -> None:
268288
assert signal.set_wakeup_fd(-1) == -1
269289

src/trio/_core/_tests/test_instrumentation.py

+47
Original file line numberDiff line numberDiff line change
@@ -266,3 +266,50 @@ async def main() -> None:
266266
assert "task_exited" not in runner.instruments
267267

268268
_core.run(main)
269+
270+
271+
def test_instrument_call_trio_context() -> None:
272+
called = set()
273+
274+
class Instrument(_abc.Instrument):
275+
pass
276+
277+
hooks = {
278+
# not run in task context
279+
"after_io_wait": (True, False),
280+
"before_io_wait": (True, False),
281+
"before_run": (True, False),
282+
"after_run": (True, False),
283+
# run in task context
284+
"before_task_step": (True, True),
285+
"after_task_step": (True, True),
286+
"task_exited": (True, True),
287+
# depends
288+
"task_scheduled": (True, None),
289+
"task_spawned": (True, None),
290+
}
291+
for hook, val in hooks.items():
292+
293+
def h(
294+
self: Instrument,
295+
*args: object,
296+
hook: str = hook,
297+
val: tuple[bool, bool | None] = val,
298+
) -> None:
299+
fail_str = f"failed in {hook}"
300+
301+
assert _core.in_trio_run() == val[0], fail_str
302+
if val[1] is not None:
303+
assert _core.in_trio_task() == val[1], fail_str
304+
called.add(hook)
305+
306+
setattr(Instrument, hook, h)
307+
308+
async def main() -> None:
309+
await _core.checkpoint()
310+
311+
async with _core.open_nursery() as nursery:
312+
nursery.start_soon(_core.checkpoint)
313+
314+
_core.run(main, instruments=[Instrument()])
315+
assert called == set(hooks)

src/trio/_core/_tests/test_run.py

+31
Original file line numberDiff line numberDiff line change
@@ -2855,3 +2855,34 @@ def run(self, fn: Callable[[], object]) -> object:
28552855

28562856
with mock.patch("trio._core._run.copy_context", return_value=Context()):
28572857
assert _count_context_run_tb_frames() == 1
2858+
2859+
2860+
@restore_unraisablehook()
2861+
def test_trio_context_detection() -> None:
2862+
assert not _core.in_trio_run()
2863+
assert not _core.in_trio_task()
2864+
2865+
def inner() -> None:
2866+
assert _core.in_trio_run()
2867+
assert _core.in_trio_task()
2868+
2869+
def sync_inner() -> None:
2870+
assert not _core.in_trio_run()
2871+
assert not _core.in_trio_task()
2872+
2873+
def inner_abort(_: object) -> _core.Abort:
2874+
assert _core.in_trio_run()
2875+
assert _core.in_trio_task()
2876+
return _core.Abort.SUCCEEDED
2877+
2878+
async def main() -> None:
2879+
assert _core.in_trio_run()
2880+
assert _core.in_trio_task()
2881+
2882+
inner()
2883+
2884+
await to_thread_run_sync(sync_inner)
2885+
with _core.CancelScope(deadline=_core.current_time() - 1):
2886+
await _core.wait_task_rescheduled(inner_abort)
2887+
2888+
_core.run(main)

src/trio/lowlevel.py

+2
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@
3737
currently_ki_protected as currently_ki_protected,
3838
disable_ki_protection as disable_ki_protection,
3939
enable_ki_protection as enable_ki_protection,
40+
in_trio_run as in_trio_run,
41+
in_trio_task as in_trio_task,
4042
notify_closing as notify_closing,
4143
permanently_detach_coroutine_object as permanently_detach_coroutine_object,
4244
reattach_detached_coroutine_object as reattach_detached_coroutine_object,

0 commit comments

Comments
 (0)