Skip to content

Commit

Permalink
Automatically profile newly spawned threads (#17)
Browse files Browse the repository at this point in the history
* Proof of Concept support for multithreading

* Allow the tid to be specified with the `render` method

* Store the ThreadProfiler run stack depth and profiler in thread local storage

* Remove useless `ThreadProfiler` in `pytest_pyfunc_call`

* Emit a warning log when a test ends with an active thread

* Test that multiple threads are actually recorded

* Do not recalculate if the trace invocation was from `threading.Thread.run`

* Document ThreadProfiler.__call__
  • Loading branch information
lucamuscat authored Dec 28, 2024
1 parent e663c8f commit 2c7c134
Show file tree
Hide file tree
Showing 3 changed files with 136 additions and 5 deletions.
5 changes: 3 additions & 2 deletions perfsephone/perfetto_renderer.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def remove_pytest_related_frames(root_frame: Frame) -> Sequence[Frame]:
return [root_frame]


def render(session: Session, start_time: float) -> List[SerializableEvent]:
def render(session: Session, start_time: float, tid: int = 1) -> List[SerializableEvent]:
renderer = SpeedscopeRenderer()
root_frame = session.root_frame()
if root_frame is None:
Expand Down Expand Up @@ -103,13 +103,14 @@ def render_root_frame(root_frame: Frame) -> List[SerializableEvent]:
cat=Category("runtime"),
ts=timestamp,
args={"file": file or "", "line": str(line or 0), "name": name or ""},
tid=tid,
)
)
elif (
speedscope_event.type == SpeedscopeEventType.CLOSE
and name not in SYNTHETIC_LEAF_IDENTIFIERS
):
result.append(EndDurationEvent(ts=timestamp))
result.append(EndDurationEvent(ts=timestamp, tid=tid))
return result

for root in new_roots:
Expand Down
97 changes: 94 additions & 3 deletions perfsephone/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,25 @@

import inspect
import json
import logging
import threading
from contextlib import contextmanager
from dataclasses import asdict
from itertools import chain
from pathlib import Path
from typing import Any, Dict, Final, Generator, List, Optional, Sequence, Tuple, Union
from types import FrameType
from typing import (
Any,
Dict,
Final,
Generator,
List,
Literal,
Optional,
Sequence,
Tuple,
Union,
)

import pyinstrument
import pytest
Expand All @@ -25,6 +40,51 @@
from perfsephone.perfetto_renderer import render

PERFETTO_ARG_NAME: Final[str] = "perfetto_path"
logger = logging.getLogger(__name__)


class ThreadProfiler:
    """Trace function that profiles each thread from inside `threading.Thread.run()`.

    An instance is meant to be installed with `threading.settrace()`. A
    `pyinstrument.Profiler` is started when a thread enters `Thread.run()` and stopped when the
    outermost `Thread.run()` invocation on that thread returns.

    Stopped profilers are collected in `profilers`, keyed by `threading.get_ident()`. Thread
    identifiers may be recycled once a thread exits, so when an identifier is seen again the
    previously stored profiler is moved to a unique negative key instead of being overwritten.
    Iterating over `profilers.values()` therefore yields every recorded profiler.
    """

    def __init__(self) -> None:
        # Per-thread state: the active profiler and the number of active `Thread.run()` frames.
        self.thread_local = threading.local()
        self.profilers: Dict[int, pyinstrument.Profiler] = {}
        # Serialises the read-modify-write on `profilers`, which happens on the profiled threads.
        self._profilers_lock = threading.Lock()

    def __call__(
        self,
        frame: FrameType,
        event: Literal["call", "line", "return", "exception", "opcode"],
        _args: Any,
    ) -> Any:
        """This method should only be used with `threading.settrace()`.

        Returns itself as the local trace function for `Thread.run()` frames, so that the
        matching "return" event is delivered, and `None` in every other case.
        """
        frame.f_trace_lines = False

        is_frame_from_thread_run: bool = (
            frame.f_code.co_name == "run" and frame.f_code.co_filename == threading.__file__
        )
        if not is_frame_from_thread_run:
            return None

        # Detect when `Thread.run()` is called.
        if event == "call":
            # If this is the first time `Thread.run()` is being called on this thread, start the
            # profiler.
            if getattr(self.thread_local, "run_stack_depth", 0) == 0:
                profiler = pyinstrument.Profiler(async_mode="disabled")
                self.thread_local.profiler = profiler
                self.thread_local.run_stack_depth = 0
                profiler.start()
            # Keep track of the number of active calls of `Thread.run()`.
            self.thread_local.run_stack_depth += 1
            return self.__call__

        # Detect when `Thread.run()` returns.
        if event == "return":
            self.thread_local.run_stack_depth -= 1
            # When there are no more active invocations of `Thread.run()`, this implies that the
            # target of the thread being profiled has finished executing.
            if self.thread_local.run_stack_depth == 0:
                assert hasattr(
                    self.thread_local, "profiler"
                ), "because a profiler must have been started"
                finished_profiler = self.thread_local.profiler
                finished_profiler.stop()

                thread_id = threading.get_ident()
                with self._profilers_lock:
                    displaced = self.profilers.get(thread_id)
                    if displaced is not None:
                        # The identifier was already used (recycled ident, or `run()` invoked
                        # again on the same thread). Keep the older profile under a synthetic
                        # key: thread identifiers are nonzero and the dictionary only grows, so
                        # `-len(...)` never collides with an existing key.
                        self.profilers[-len(self.profilers)] = displaced
                    self.profilers[thread_id] = finished_profiler

        return None


class PytestPerfettoPlugin:
Expand All @@ -44,6 +104,17 @@ def __profile(
result: List[SerializableEvent] = []
start_event = BeginDurationEvent(name=root_frame_name, cat=Category("test"), args=args)

thread_profiler = ThreadProfiler()

# We use `threading.settrace`, as opposed to `threading.setprofile`, as
# `pyinstrument.Profiler().start()` calls `threading.setprofile` under the hood, overriding
# our profiling function.
#
# `threading.settrace` & `threading.setprofile` provide a rather convoluted mechanism of
# starting a pyinstrument profiler as soon as a thread starts executing its `run()` method,
# & stopping said profiler once the `run()` method finishes.
threading.settrace(thread_profiler) # type: ignore

result.append(start_event)
profiler_async_mode = "enabled" if is_async else "disabled"
with pyinstrument.Profiler(async_mode=profiler_async_mode) as profile:
Expand All @@ -52,8 +123,28 @@ def __profile(
start_rendering_event = BeginDurationEvent(
name="[pytest-perfetto] Dumping frames", cat=Category("pytest")
)
if profile.last_session is not None:
result += render(profile.last_session, start_time=start_event.ts)

threading.settrace(None) # type: ignore

profiles_to_render = (
profile
for profile in chain([profile], thread_profiler.profilers.values())
if profile.last_session
)

for index, profiler in enumerate(profiles_to_render, start=1):
if profiler.is_running:
logger.warning(
"There exists a run-away thread which has not been joined after the end of the"
" test.The thread's profiler will be discarded."
)
elif profiler.last_session:
result += render(
session=profiler.last_session,
start_time=profiler.last_session.start_time,
tid=index,
)

end_rendering_event = EndDurationEvent()
result += [end_event, start_rendering_event, end_rendering_event]

Expand Down
39 changes: 39 additions & 0 deletions tests/test_plugin.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import json
from pathlib import Path
from typing import Final

import pytest

Expand Down Expand Up @@ -36,3 +38,40 @@ def test_hello(some_fixture) -> None:
result = pytester.runpytest_subprocess(f"--perfetto={temp_perfetto_file_path}")
result.assert_outcomes(passed=1)
assert temp_perfetto_file_path.exists()


def test_given_multiple_threads__then_multiple_distinct_tids_are_reported(
    pytester: pytest.Pytester, temp_perfetto_file_path: Path
) -> None:
    """Check that frames recorded on different threads are reported under different tids.

    The generated test starts a thread running `foo`, which in turn starts a thread running
    `bar`. Together with `test_hello` itself, three distinct tids are expected in the trace.
    """
    pytester.makepyfile("""
        import threading
        import time
        SLEEP_TIME_S = 0.002
        def test_hello() -> None:
            def foo() -> None:
                def bar() -> None:
                    time.sleep(SLEEP_TIME_S)
                thread = threading.Thread(target=bar)
                thread.start()
                thread.join()
            thread = threading.Thread(target=foo)
            thread.start()
            thread.join()
    """)
    pytester.runpytest_subprocess(f"--perfetto={temp_perfetto_file_path}").assert_outcomes(passed=1)
    # Read the trace inside a context manager so the file handle is closed deterministically.
    with temp_perfetto_file_path.open("r") as trace_fp:
        trace_file = json.load(trace_fp)
    EXPECTED_DISTINCT_TID_COUNT: Final[int] = 3

    # Only events named after the three functions of interest contribute a tid.
    reported_tids = {
        event["tid"]
        for event in trace_file
        if event.get("name", "") in ["foo", "bar", "test_hello"]
    }
    assert len(reported_tids) == EXPECTED_DISTINCT_TID_COUNT

0 comments on commit 2c7c134

Please sign in to comment.