Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Automatically profile newly spawned threads #17

Merged
merged 8 commits into from
Dec 28, 2024
5 changes: 3 additions & 2 deletions perfsephone/perfetto_renderer.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def remove_pytest_related_frames(root_frame: Frame) -> Sequence[Frame]:
return [root_frame]


def render(session: Session, start_time: float) -> List[SerializableEvent]:
def render(session: Session, start_time: float, tid: int = 1) -> List[SerializableEvent]:
renderer = SpeedscopeRenderer()
root_frame = session.root_frame()
if root_frame is None:
Expand Down Expand Up @@ -103,13 +103,14 @@ def render_root_frame(root_frame: Frame) -> List[SerializableEvent]:
cat=Category("runtime"),
ts=timestamp,
args={"file": file or "", "line": str(line or 0), "name": name or ""},
tid=tid,
)
)
elif (
speedscope_event.type == SpeedscopeEventType.CLOSE
and name not in SYNTHETIC_LEAF_IDENTIFIERS
):
result.append(EndDurationEvent(ts=timestamp))
result.append(EndDurationEvent(ts=timestamp, tid=tid))
return result

for root in new_roots:
Expand Down
97 changes: 94 additions & 3 deletions perfsephone/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,25 @@

import inspect
import json
import logging
import threading
from contextlib import contextmanager
from dataclasses import asdict
from itertools import chain
from pathlib import Path
from typing import Any, Dict, Final, Generator, List, Optional, Sequence, Tuple, Union
from types import FrameType
from typing import (
Any,
Dict,
Final,
Generator,
List,
Literal,
Optional,
Sequence,
Tuple,
Union,
)

import pyinstrument
import pytest
Expand All @@ -25,6 +40,51 @@
from perfsephone.perfetto_renderer import render

PERFETTO_ARG_NAME: Final[str] = "perfetto_path"
logger = logging.getLogger(__name__)


class ThreadProfiler:
    """Trace function which profiles every thread that executes `threading.Thread.run()`.

    An instance is meant to be installed with `threading.settrace()`. A
    `pyinstrument.Profiler` is started when a thread first enters `Thread.run()`, and stopped
    once the outermost `Thread.run()` invocation on that thread returns. Stopped profilers are
    collected in `profilers`.
    """

    def __init__(self) -> None:
        # Per-thread state: `profiler` (the profiler started for the current thread) and
        # `run_stack_depth` (the number of `Thread.run()` frames currently active on it).
        self.thread_local = threading.local()
        # Stopped profilers, keyed by the `threading.get_ident()` of the profiled thread.
        # NOTE(review): the `threading` documentation states that thread identifiers may be
        # recycled once a thread exits, so a later thread may replace an earlier thread's
        # entry -- confirm whether this matters for callers.
        self.profilers: Dict[int, pyinstrument.Profiler] = {}

    def __call__(
        self,
        frame: FrameType,
        event: Literal["call", "line", "return", "exception", "opcode"],
        _args: Any,
    ) -> Any:
        """This method should only be used with `threading.settrace()`.

        Args:
            frame: The frame the trace event was generated for.
            event: The kind of trace event.
            _args: The event argument; unused.

        Returns:
            This trace function for frames executing `Thread.run()` (so that the frame keeps
            being traced until it returns), `None` for every other frame.
        """
        # Line events are never needed; only "call" & "return" are acted upon.
        frame.f_trace_lines = False

        is_frame_from_thread_run: bool = (
            frame.f_code.co_name == "run" and frame.f_code.co_filename == threading.__file__
        )
        if not is_frame_from_thread_run:
            # Frames other than `Thread.run()` do not need a local trace function.
            return None

        if event == "call":
            # If this is the first time `Thread.run()` is being called on this thread, start the
            # profiler.
            if getattr(self.thread_local, "run_stack_depth", 0) == 0:
                profiler = pyinstrument.Profiler(async_mode="disabled")
                self.thread_local.profiler = profiler
                self.thread_local.run_stack_depth = 0
                profiler.start()
            # Keep track of the number of active calls of `Thread.run()`.
            self.thread_local.run_stack_depth += 1
        elif event == "return":
            self.thread_local.run_stack_depth -= 1
            # When there are no more active invocations of `Thread.run()`, this implies that the
            # target of the thread being profiled has finished executing.
            if self.thread_local.run_stack_depth == 0:
                assert hasattr(
                    self.thread_local, "profiler"
                ), "because a profiler must have been started"
                self.thread_local.profiler.stop()
                self.profilers[threading.get_ident()] = self.thread_local.profiler

        # Always hand the trace function back for `Thread.run()` frames: a local trace function
        # which returns `None` (e.g. on an "exception" event) is documented as turning tracing
        # off for that scope, which would suppress the "return" event that stops the profiler.
        return self.__call__


class PytestPerfettoPlugin:
Expand All @@ -44,6 +104,17 @@ def __profile(
result: List[SerializableEvent] = []
start_event = BeginDurationEvent(name=root_frame_name, cat=Category("test"), args=args)

thread_profiler = ThreadProfiler()

# We use `threading.settrace`, as opposed to `threading.setprofile`, as
# `pyinstrument.Profiler().start()` calls `threading.setprofile` under the hood, overriding
# our profiling function.
#
# `threading.settrace` & `threading.setprofile` provide a rather convoluted mechanism for
# starting a pyinstrument profiler as soon as a thread starts executing its `run()` method,
# & stopping said profiler once the `run()` method finishes.
threading.settrace(thread_profiler) # type: ignore

result.append(start_event)
profiler_async_mode = "enabled" if is_async else "disabled"
with pyinstrument.Profiler(async_mode=profiler_async_mode) as profile:
Expand All @@ -52,8 +123,28 @@ def __profile(
start_rendering_event = BeginDurationEvent(
name="[pytest-perfetto] Dumping frames", cat=Category("pytest")
)
if profile.last_session is not None:
result += render(profile.last_session, start_time=start_event.ts)

threading.settrace(None) # type: ignore

profiles_to_render = (
profile
for profile in chain([profile], thread_profiler.profilers.values())
if profile.last_session
)

for index, profiler in enumerate(profiles_to_render, start=1):
if profiler.is_running:
logger.warning(
"There exists a run-away thread which has not been joined after the end of the"
" test.The thread's profiler will be discarded."
)
elif profiler.last_session:
result += render(
session=profiler.last_session,
start_time=profiler.last_session.start_time,
tid=index,
)

end_rendering_event = EndDurationEvent()
result += [end_event, start_rendering_event, end_rendering_event]

Expand Down
39 changes: 39 additions & 0 deletions tests/test_plugin.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import json
from pathlib import Path
from typing import Final

import pytest

Expand Down Expand Up @@ -36,3 +38,40 @@ def test_hello(some_fixture) -> None:
result = pytester.runpytest_subprocess(f"--perfetto={temp_perfetto_file_path}")
result.assert_outcomes(passed=1)
assert temp_perfetto_file_path.exists()


def test_given_multiple_threads__then_multiple_distinct_tids_are_reported(
    pytester: pytest.Pytester, temp_perfetto_file_path: Path
) -> None:
    """Check that nested threads are each reported under their own `tid`.

    The generated test spawns a thread running `foo`, which in turn spawns a thread running
    `bar`. Together with the thread running `test_hello`, the trace is expected to contain
    three distinct `tid` values.
    """
    pytester.makepyfile("""
        import threading
        import time

        SLEEP_TIME_S = 0.002

        def test_hello() -> None:
            def foo() -> None:
                def bar() -> None:
                    time.sleep(SLEEP_TIME_S)
                thread = threading.Thread(target=bar)
                thread.start()
                thread.join()

            thread = threading.Thread(target=foo)
            thread.start()
            thread.join()
    """)
    pytester.runpytest_subprocess(f"--perfetto={temp_perfetto_file_path}").assert_outcomes(passed=1)
    # `read_text()` closes the file once it has been read; passing `Path.open()` straight to
    # `json.load()` would leave the handle open.
    trace_file = json.loads(temp_perfetto_file_path.read_text())
    EXPECTED_DISTINCT_TID_COUNT: Final[int] = 3

    distinct_tids = {
        event["tid"]
        for event in trace_file
        if event.get("name", "") in ["foo", "bar", "test_hello"]
    }
    assert len(distinct_tids) == EXPECTED_DISTINCT_TID_COUNT