Skip to content

Commit

Permalink
Half compatibility with typeguard v4.
Browse files Browse the repository at this point in the history
  • Loading branch information
patrick-kidger committed Nov 19, 2024
1 parent 4307e19 commit 520daae
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 27 deletions.
2 changes: 1 addition & 1 deletion docs/api/runtime-type-checking.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ Runtime type checking **synergises beautifully with `jax.jit`!** All shape check

There are two approaches: either use [`jaxtyping.jaxtyped`][] to typecheck a single function, or [`jaxtyping.install_import_hook`][] to typecheck a whole codebase.

In either case, the actual business of checking types is performed with the help of a runtime type-checking library. The two most popular are [beartype](https://github.com/beartype/beartype) and [typeguard](https://github.com/agronholm/typeguard). (If using typeguard, then specifically the version `2.*` series should be used. Later versions -- `3` and `4` -- have some known issues.)
In either case, the actual business of checking types is performed with the help of a runtime type-checking library. The two most popular are [beartype](https://github.com/beartype/beartype) and [typeguard](https://github.com/agronholm/typeguard).

!!! warning

Expand Down
27 changes: 27 additions & 0 deletions jaxtyping/_array_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import enum
import functools as ft
import importlib.metadata
import importlib.util
import re
import sys
Expand Down Expand Up @@ -738,6 +739,28 @@ def __init_subclass__(cls, **kwargs):
_complex128 = "complex128"


# Workaround a longstanding bug in typeguard v4, by monkeypatching their internals.
# https://stackoverflow.com/questions/79201839/hello-world-for-jaxtyping/79205145#79205145
# https://github.com/patrick-kidger/jaxtyping/issues/80
# https://github.com/agronholm/typeguard/issues/353
# This is as robust as I can make it to future changes in typeguard, I think.
typeguard_v4_compat = False
try:
typeguard_distribution = importlib.metadata.distribution("typeguard")
except importlib.metadata.PackageNotFoundError:
pass
else:
if typeguard_distribution.version.split(".", 1)[0] == "4":
if importlib.util.find_spec("typeguard._transformer") is not None:
import typeguard._transformer

if hasattr(typeguard._transformer, "annotated_names"):
annotated_names = typeguard._transformer.annotated_names
if type(annotated_names) is tuple:
if all(type(x) is str for x in annotated_names):
typeguard_v4_compat = True


def _make_dtype(_dtypes, name):
class _Cls(AbstractDtype):
dtypes = _dtypes
Expand All @@ -748,6 +771,10 @@ class _Cls(AbstractDtype):
_Cls.__module__ = "builtins"
else:
_Cls.__module__ = "jaxtyping"
if typeguard_v4_compat:
typeguard._transformer.annotated_names = (
typeguard._transformer.annotated_names + (f"jaxtyping.{name}",)
)
return _Cls


Expand Down
24 changes: 19 additions & 5 deletions jaxtyping/_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,11 +430,13 @@ def wrapped_fn(*args, **kwargs): # pyright: ignore
module = getattr(fn, "__module__", "<generated_by_jaxtyping>")

# Use the same name so that typeguard warnings look correct.
# Set the line number so that typeguard v4 finds us.
lineno = getattr(getattr(fn, "__code__", None), "co_firstlineno", 1)
full_fn, output_name = _make_fn_with_signature(
name, qualname, module, full_signature, output=True
name, qualname, module, full_signature, output=True, lineno=lineno
)
param_fn = _make_fn_with_signature(
name, qualname, module, param_signature, output=False
name, qualname, module, param_signature, output=False, lineno=lineno
)
full_fn = _apply_typechecker(typechecker, full_fn)
param_fn = _apply_typechecker(typechecker, param_fn)
Expand Down Expand Up @@ -616,13 +618,19 @@ def _check_dataclass_annotations(self, typechecker):
self.__class__.__module__,
signature,
output=False,
lineno=1,
)
f = jaxtyped(f, typechecker=typechecker)
f(self, **values)


def _make_fn_with_signature(
name: str, qualname: str, module: str, signature: inspect.Signature, output: bool
name: str,
qualname: str,
module: str,
signature: inspect.Signature,
output: bool,
lineno: int,
):
"""Dynamically creates a function `fn` with name `name` and signature `signature`.
Expand Down Expand Up @@ -740,7 +748,8 @@ def _make_fn_with_signature(
else:
retstr = f"-> {name_to_annotation['return']}"

fnstr = f"def {name}({argstr}){retstr}:\n {outstr}"
newlines = "\n" * (lineno - 1)
fnstr = f"{newlines}def {name}({argstr}){retstr}:\n {outstr}"
exec(fnstr, scope)
fn = scope[name]
del scope[name] # Avoids introducing a reference cycle.
Expand Down Expand Up @@ -802,7 +811,12 @@ def _get_problem_arg(
assert keep_annotation is not sentinel
new_signature = inspect.Signature(new_parameters)
fn = _make_fn_with_signature(
"check_single_arg", "check_single_arg", module, new_signature, output=False
"check_single_arg",
"check_single_arg",
module,
new_signature,
output=False,
lineno=1,
)
fn = _apply_typechecker(
typechecker, fn
Expand Down
2 changes: 1 addition & 1 deletion test/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@ numpy<2
pytest
pytest-asyncio
tensorflow
typeguard<3
typeguard
20 changes: 0 additions & 20 deletions test/test_import_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,26 +32,6 @@
_here = pathlib.Path(__file__).parent


try:
typeguard_version = importlib.metadata.version("typeguard")
except Exception as e:
raise ImportError("Could not find typeguard version") from e
else:
try:
major, _, _ = typeguard_version.split(".")
major = int(major)
except Exception as e:
raise ImportError(
f"Unexpected typeguard version {typeguard_version}; not formatted as "
"`major.minor.patch`"
) from e
if major != 2:
raise ImportError(
"jaxtyping's tests required typeguard version 2. (Versions 3 and 4 are both "
"known to have bugs.)"
)


assert not hasattr(jaxtyping, "_test_import_hook_counter")
jaxtyping._test_import_hook_counter = 0

Expand Down

0 comments on commit 520daae

Please sign in to comment.