Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Call typechecker from isolated frame #260

Merged
merged 3 commits into from
Oct 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 14 additions & 3 deletions jaxtyping/_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,15 @@ def __repr__(self):
_tb_flag = True


def _apply_typechecker(typechecker, fn):
"""Calls `typechecker(fn)` in an isolated frame, returning the result.

This avoids reference cycles that can otherwise occur if `typechecker` grabs
the calling frame's locals.
"""
return typechecker(fn)


@overload
def jaxtyped(
*,
Expand Down Expand Up @@ -422,8 +431,8 @@ def wrapped_fn(*args, **kwargs): # pyright: ignore
param_fn = _make_fn_with_signature(
name, qualname, module, param_signature, output=False
)
full_fn = typechecker(full_fn)
param_fn = typechecker(param_fn)
full_fn = _apply_typechecker(typechecker, full_fn)
param_fn = _apply_typechecker(typechecker, param_fn)

def wrapped_fn_impl(args, kwargs, bound, memos):
# First type-check just the parameters before the function is
Expand Down Expand Up @@ -790,7 +799,9 @@ def _get_problem_arg(
fn = _make_fn_with_signature(
"check_single_arg", "check_single_arg", module, new_signature, output=False
)
fn = typechecker(fn) # but no `jaxtyped`; keep the same environment.
fn = _apply_typechecker(
typechecker, fn
) # but no `jaxtyped`; keep the same environment.
try:
fn(*args, **kwargs)
except Exception as e:
Expand Down
31 changes: 26 additions & 5 deletions test/test_decorator.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import abc
import dataclasses
import sys
from typing import no_type_check

import jax.numpy as jnp
import jax.random as jr
import pytest
import typeguard

from jaxtyping import Array, Float, jaxtyped, print_bindings

Expand Down Expand Up @@ -213,10 +213,6 @@ def g(x: Float[Array, "foo bar"]):


def test_no_garbage(typecheck):
if typecheck is typeguard.typechecked:
# Currently fails due to reference cycles in typeguard.
pytest.skip()

with assert_no_garbage():

@jaxtyped(typechecker=typecheck)
Expand All @@ -236,3 +232,28 @@ class _Obj:
x: int

_Obj(x=5)


def test_no_garbage_frame_capture_typecheck():
with assert_no_garbage():
# Some typechecker implementations (e.g., typeguard 2.13.3) capture the calling
# frame's f_locals. This test checks that the calling frames in jaxtyping are
# sufficiently isolated to avoid introducing reference cycles when a
# typechecker does this.
def frame_locals_capture(fn):
locals = sys._getframe(1).f_locals

def wrapper(*args, **kwargs):
# Required to ensure wrapper holds a reference to f_locals, which is
# the scenario under test.
_ = locals
return fn(*args, **kwargs)

return wrapper

@jaxtyped(typechecker=frame_locals_capture)
@dataclasses.dataclass
class _Obj:
x: int

_Obj(x=5)
Loading