diff --git a/perfsephone/perfetto_renderer.py b/perfsephone/perfetto_renderer.py
index 558df2d..4ce76ef 100644
--- a/perfsephone/perfetto_renderer.py
+++ b/perfsephone/perfetto_renderer.py
@@ -68,7 +68,7 @@ def remove_pytest_related_frames(root_frame: Frame) -> Sequence[Frame]:
     return [root_frame]
 
 
-def render(session: Session, start_time: float) -> List[SerializableEvent]:
+def render(session: Session, start_time: float, tid: int = 1) -> List[SerializableEvent]:
     renderer = SpeedscopeRenderer()
     root_frame = session.root_frame()
     if root_frame is None:
@@ -103,13 +103,14 @@ def render_root_frame(root_frame: Frame) -> List[SerializableEvent]:
                         cat=Category("runtime"),
                         ts=timestamp,
                         args={"file": file or "", "line": str(line or 0), "name": name or ""},
+                        tid=tid,
                     )
                 )
             elif (
                 speedscope_event.type == SpeedscopeEventType.CLOSE
                 and name not in SYNTHETIC_LEAF_IDENTIFIERS
             ):
-                result.append(EndDurationEvent(ts=timestamp))
+                result.append(EndDurationEvent(ts=timestamp, tid=tid))
         return result
 
     for root in new_roots:
diff --git a/perfsephone/plugin.py b/perfsephone/plugin.py
index 233d527..1e048a9 100644
--- a/perfsephone/plugin.py
+++ b/perfsephone/plugin.py
@@ -5,10 +5,25 @@
 
 import inspect
 import json
+import logging
+import threading
 from contextlib import contextmanager
 from dataclasses import asdict
+from itertools import chain
 from pathlib import Path
-from typing import Any, Dict, Final, Generator, List, Optional, Sequence, Tuple, Union
+from types import FrameType
+from typing import (
+    Any,
+    Dict,
+    Final,
+    Generator,
+    List,
+    Literal,
+    Optional,
+    Sequence,
+    Tuple,
+    Union,
+)
 
 import pyinstrument
 import pytest
@@ -25,6 +40,64 @@
 from perfsephone.perfetto_renderer import render
 
 PERFETTO_ARG_NAME: Final[str] = "perfetto_path"
+logger = logging.getLogger(__name__)
+
+
+class ThreadProfiler:
+    """Trace function that profiles each thread while its `Thread.run()` executes.
+
+    An instance is meant to be installed with `threading.settrace()`. A dedicated
+    `pyinstrument.Profiler` is started when a thread enters `Thread.run()` & stopped once the
+    outermost `Thread.run()` call on that thread returns.
+    """
+
+    def __init__(self) -> None:
+        # Per-thread state: the thread's profiler & its count of active `Thread.run()` calls.
+        self.thread_local = threading.local()
+        # Every profiler started by this tracer, in start order. Thread identifiers are
+        # not used as keys, as they may be recycled once a thread exits, which would let a
+        # later thread overwrite the profiler of an earlier one.
+        self.profilers: List[pyinstrument.Profiler] = []
+
+    def __call__(
+        self,
+        frame: FrameType,
+        event: Literal["call", "line", "return", "exception", "opcode"],
+        _args: Any,
+    ) -> Any:
+        """This method should only be used with `threading.settrace()`."""
+        frame.f_trace_lines = False
+
+        is_frame_from_thread_run: bool = (
+            frame.f_code.co_name == "run" and frame.f_code.co_filename == threading.__file__
+        )
+
+        # Detect when `Thread.run()` is called
+        if event == "call" and is_frame_from_thread_run:
+            # If this is the first time `Thread.run()` is being called on this thread, start the
+            # profiler.
+            if getattr(self.thread_local, "run_stack_depth", 0) == 0:
+                profiler = pyinstrument.Profiler(async_mode="disabled")
+                self.thread_local.profiler = profiler
+                self.thread_local.run_stack_depth = 0
+                profiler.start()
+                # Register the profiler as soon as it runs, so that a thread which never
+                # returns from `Thread.run()` can be detected via `Profiler.is_running`.
+                self.profilers.append(profiler)
+            # Keep track of the number of active calls of `Thread.run()`.
+            self.thread_local.run_stack_depth += 1
+            return self.__call__
+
+        # Detect when `Thread.run()` returns.
+        if event == "return" and is_frame_from_thread_run:
+            self.thread_local.run_stack_depth -= 1
+            # When there are no more active invocations of `Thread.run()`, this implies that the
+            # target of the thread being profiled has finished executing.
+            if self.thread_local.run_stack_depth == 0:
+                assert hasattr(
+                    self.thread_local, "profiler"
+                ), "because a profiler must have been started"
+                self.thread_local.profiler.stop()
 
 
 class PytestPerfettoPlugin:
@@ -44,6 +117,17 @@ def __profile(
         result: List[SerializableEvent] = []
         start_event = BeginDurationEvent(name=root_frame_name, cat=Category("test"), args=args)
 
+        thread_profiler = ThreadProfiler()
+
+        # We use `threading.settrace`, as opposed to `threading.setprofile`, as
+        # `pyinstrument.Profiler().start()` calls `threading.setprofile` under the hood, overriding
+        # our profiling function.
+        #
+        # `threading.settrace` & `threading.setprofile` provide a rather convoluted mechanism of
+        # starting a pyinstrument profiler as soon as a thread starts executing its `run()` method,
+        # & stopping said profiler once the `run()` method finishes.
+        threading.settrace(thread_profiler)  # type: ignore
+
         result.append(start_event)
         profiler_async_mode = "enabled" if is_async else "disabled"
         with pyinstrument.Profiler(async_mode=profiler_async_mode) as profile:
@@ -52,8 +136,26 @@ def __profile(
         start_rendering_event = BeginDurationEvent(
             name="[pytest-perfetto] Dumping frames", cat=Category("pytest")
         )
-        if profile.last_session is not None:
-            result += render(profile.last_session, start_time=start_event.ts)
+
+        threading.settrace(None)  # type: ignore
+
+        # The main thread's profiler is rendered first, so that it keeps `tid=1`. Profilers which
+        # are still running belong to threads that have not returned from `Thread.run()` yet.
+        profiles_to_render = chain([profile], thread_profiler.profilers)
+
+        for index, profiler in enumerate(profiles_to_render, start=1):
+            if profiler.is_running:
+                logger.warning(
+                    "There exists a run-away thread which has not been joined after the end of the"
+                    " test. The thread's profiler will be discarded."
+                )
+            elif profiler.last_session:
+                result += render(
+                    session=profiler.last_session,
+                    start_time=profiler.last_session.start_time,
+                    tid=index,
+                )
+
         end_rendering_event = EndDurationEvent()
         result += [end_event, start_rendering_event, end_rendering_event]
 
diff --git a/tests/test_plugin.py b/tests/test_plugin.py
index 35822f5..57dc0f4 100644
--- a/tests/test_plugin.py
+++ b/tests/test_plugin.py
@@ -1,4 +1,6 @@
+import json
 from pathlib import Path
+from typing import Final
 
 import pytest
 
@@ -36,3 +38,40 @@ def test_hello(some_fixture) -> None:
     result = pytester.runpytest_subprocess(f"--perfetto={temp_perfetto_file_path}")
     result.assert_outcomes(passed=1)
     assert temp_perfetto_file_path.exists()
+
+
+def test_given_multiple_threads__then_multiple_distinct_tids_are_reported(
+    pytester: pytest.Pytester, temp_perfetto_file_path: Path
+) -> None:
+    pytester.makepyfile("""
+        import threading
+        import time
+
+        SLEEP_TIME_S = 0.002
+
+        def test_hello() -> None:
+            def foo() -> None:
+                def bar() -> None:
+                    time.sleep(SLEEP_TIME_S)
+                thread = threading.Thread(target=bar)
+                thread.start()
+                thread.join()
+
+            thread = threading.Thread(target=foo)
+            thread.start()
+            thread.join()
+    """)
+    pytester.runpytest_subprocess(f"--perfetto={temp_perfetto_file_path}").assert_outcomes(passed=1)
+    trace_file = json.loads(temp_perfetto_file_path.read_text())
+    EXPECTED_DISTINCT_TID_COUNT: Final[int] = 3
+
+    assert (
+        len(
+            {
+                event["tid"]
+                for event in trace_file
+                if event.get("name", "") in ["foo", "bar", "test_hello"]
+            }
+        )
+        == EXPECTED_DISTINCT_TID_COUNT
+    )