Releases: patrick-kidger/jaxtyping
jaxtyping v0.2.26
Features
- Added `jaxtyping.print_bindings` to manually inspect the values of each axis, whilst inside a function.
- Added support for `jaxtyping.{Int4, UInt4}`. (#174, thanks @jianlijianli!)
Bugfixes
- We no longer import JAX at all, even if it is present. This ensures compatibility when using jaxtyping+PyTorch alongside an old JAX installation. (All JAX re-exports, like `jaxtyping.Array = jax.Array`, are looked up dynamically rather than at import time.) (#178)
- We no longer raise false positives when `@jaxtyped`-ing generators (with `yield` statements). (#91, #171, thanks @knyazer!)
Internals
- Added support for beartype's pseudostandard `__instancecheck_str__` method. Instead of `isinstance(x, Float[Array, "foo"])`, one can now call `Float[Array, "foo"].__instancecheck_str__(x)`, which will return either an empty string (success) or an error message describing why the check failed (wrong shape, wrong dtype, ...). In practice this feature probably isn't very usable yet; we'll need to wait until we've done a better job of ensuring compatibility between the jaxtyping import hooks and the beartype import hooks.
Docs
- Fixes by @jeertmans (#154) and @afrozenator (#170) -- thank you!
New Contributors
- @jeertmans made their first contribution in #154
- @afrozenator made their first contribution in #170
Full Changelog: v0.2.25...v0.2.26
jaxtyping v0.2.25
This release is primarily a usability release, designed to help ensure the library is being used correctly.
- The error messages from a failed typecheck have been improved, to explicitly highlight more information about which argument was wrong. :)
- If the `jaxtyping.jaxtyped(typechecker=...)` argument is not passed, then a warning will be displayed. In practice, this will trigger:
  - if using the old double-decorator syntax (`@jaxtyped @beartype def foo(...): ...`) -- upgrade to the new `@jaxtyped(typechecker=beartype) def foo(...): ...` syntax and get better error messages! :)
  - if making the easy mistake of writing `@jaxtyped(beartype) def foo(...): ...` -- in this case it's actually the `beartype` call that is jaxtype'd, not `foo`.
- Incorrect use of jaxtyping annotations will now raise a `jaxtyping.AnnotationError` rather than a mix of `RuntimeError`s, `NameError`s, etc. For example, `isinstance(x, Float)` is not correct (you should write something like `Float[Array, "..."]` instead), and this will raise such an `AnnotationError`.
- Introduced two config flags:
  - `JAXTYPING_DISABLE=1` / `jaxtyping.config.update("jaxtyping_disable", True)`: if enabled, then all runtime type checking will be skipped.
  - `JAXTYPING_REMOVE_TYPECHECKER_STACK=1` / `jaxtyping.config.update("jaxtyping_remove_typechecker_stack", True)`: if enabled, then type-checking errors will only show the `jaxtyping.TypeCheckError`, and won't include any extra stack trace from the underlying type checker (beartype/typeguard). Some users have found that they prefer the conciseness over the extra information.
Full Changelog: v0.2.24...v0.2.25
jaxtyping v0.2.24
New features
- Error messages will now include useful shape information for debugging. (!!!) This closes the venerable #6, which is one of the oldest feature requests for jaxtyping. This is enabled by using the following syntax, instead of the old double-decorator syntax:

```python
from jaxtyping import jaxtyped
from beartype import beartype as typechecker  # or: from typeguard import typechecked as typechecker

@jaxtyped(typechecker=typechecker)  # passing as keyword argument is important
def foo(...): ...
```

and moreover this is what `install_import_hook` now does.

As an example of this in action, consider this buggy code:

```python
import jax.numpy as jnp
from jaxtyping import Array, Float, jaxtyped
from beartype import beartype

@jaxtyped(typechecker=beartype)
def f(x: Float[Array, "foo bar"], y: Float[Array, "foo"]):
    ...

f(jnp.zeros((3, 4)), jnp.zeros(5))
```

This will now produce the error message

```
jaxtyping.TypeCheckError: Type-check error whilst checking the parameters of f.
The problem arose whilst typechecking argument 'y'.
Called with arguments: {'x': f32[3,4], 'y': f32[5]}
Parameter annotations: (x: Float[Array, 'foo bar'], y: Float[Array, 'foo']).
The current values for each jaxtyping axis annotation are as follows.
foo=3
bar=4
```

Hurrah! I'm really glad to have this important quality-of-life improvement in. (#6, #138)
- Added support for the following:

```python
def make_zeros(size: int) -> Float[Array, "{size}"]:
    return jnp.zeros(size)
```

in which axis names enclosed in `{...}` are evaluated as f-strings using the values of the function's arguments. This closes the long-standing feature request #93. (#93, #140) (Heads-up @MilesCranmer!)
- Added support for declaring PyTree structures, which like array shapes must match across all arguments. For example

```python
def f(x: PyTree[int, "T"], y: PyTree[float, "T"]): ...
```

demands that `x` and `y` be PyTrees with the same `jax.tree_util.tree_structure` as each other. (#135)
- Added support for treepath-dependent sizes using `?`. This makes it possible for the value of a dimension to vary across its position within a pytree, while still being consistent with its value in other pytrees of the same structure. Such annotations look like `PyTree[Float[Array, "?foo"], "T"]`. Together with the previous point, this means that you can now declare that two pytrees must have the exact same structure and array shapes as each other: use `PyTree[Float[Array, "?*shape"], "T"]` as the annotation for both. (#136)
- Added `jaxtyping.Real`, which admits any float, signed integer, or unsigned integer. (But not bools or complexes.) (#128)
- If JAX is installed, then `jaxtyping.DTypeLike` is now available (it is just a forwarding on of `jax.typing.DTypeLike`). (#129)
Bugfixes
- Fixed no error being raised when having mismatched variadic+broadcast and variadic+nonbroadcast dimensions; see #134 for details. (#134)
- Fixed `jaxtyping.Key` not being compatible with the new-style `jax.random.key`. (As opposed to the old-style `jax.random.PRNGKey`.) (#142, #143)
- Fixed `install_import_hook(..., None)` crashing. (#145, #146)
- Variadic shapes combined with `bool`/`int`/`float`/`complex` now work correctly, e.g. `Float[float, "..."]` is now valid (and equivalent to just `float`). This is useful in particular for `Float[ArrayLike, "..."]` to work correctly (as `ArrayLike` includes `float`). (#133)
Better error messages
- The error message due to a nonexistent symbolic dimension (e.g. `def f(x: Float[Array, "dim*2"])`, which leaves `dim` unspecified) has been fixed. (#131)
- The error message due to the wrong dataclass attribute type, e.g.

```python
from dataclasses import dataclass

@dataclass
class Foo:
    attribute_name: int

Foo("strings are not integers")
```

will now correctly include the `attribute_name`. (#132)
Note that this release may result in new errors being raised, due to the inclusion of #134. If so, then the appropriate thing to do is to fix your code -- this is a correct error that jaxtyping was previously failing to raise.
Full Changelog: v0.2.23...v0.2.24
jaxtyping v0.2.23
Changes
- The import hook is now compatible with `equinox.field(converter=...)`. More precisely: the import hook no longer checks the `__init__` method of dataclasses. Instead, it checks that each attribute matches its type annotation, after `__init__` has run.
- jaxtyping now requires typeguard version `v2.*`, and explicitly disallows later versions (v3 and v4), as these are known to be buggy. (Thanks @knyazer! #124)
Crash fixes
- Now robust to some crashes induced by varying jax/numpy/tensorflow versions. (#115)
- The import hook now tolerates beartype/typeguard also being imported manually. (Thanks @knyazer! #116)
- The package now tolerates faulty IPython installs. (#117)
Full Changelog: v0.2.22...v0.2.23
jaxtyping v0.2.22
- jaxtyping now offers an IPython extension. (Thanks @knyazer! #112) This means that you can now write the following at the top of your IPython/Jupyter/Colab notebook, and have everything you write be automatically type-checked:

```ipython
import jaxtyping
%load_ext jaxtyping
%jaxtyping.typechecker beartype.beartype  # or any other runtime type checker, e.g. typeguard
```
- Forward compatibility with JAX's upcoming changes to PRNGs. `jaxtyping.PRNGKeyArray` will match against either old-style `jax.random.PRNGKey` or new-style `jax.random.key`. Meanwhile `jaxtyping.Key[Array, ...]` will match against only new-style `jax.random.key`s. (#109)
- Better error message when writing just `Float[Array]`. (#110)
- Now robust to JAX installations that aren't installed properly (e.g. not supported on current hardware). (#111)
New Contributors
Full Changelog: v0.2.21...v0.2.22
jaxtyping v0.2.21
- Fix for `__pycache__` filling up with lots of redundant entries. (#102, #103)
- Compatibility with future versions of JAX (whatever version exists ~3 months from now), in which JAX's way of detecting PRNGKeys (`jax.core.is_opaque_dtype`) will be deprecated and changed. (#98)
Full Changelog: v0.2.20...v0.2.21
jaxtyping v0.2.20
- Added `jaxtyping.PyTreeDef` type.
- Can now detect `x = jaxtyping.PyTree[foo]` via `issubclass(x, jaxtyping.PyTree)`.
- Fixed #89, in which `__builtins__` was getting added as an extra key to the memo stack.
- Renamed modules with a leading underscore to indicate that they're private.
- Bump minimum Python version to 3.9.
Full Changelog: v0.2.19...v0.2.20
jaxtyping v0.2.19
- Proper documentation! Not just markdown files on GitHub any more. Check out https://docs.kidger.site/jaxtyping.
- Added `jaxtyping.{PRNGKeyArray,Scalar,ScalarLike}`.
- Can now nest, e.g.

```python
Image = Float[Array, "channels height width"]
BatchImage = Float[Image, "batch"]
```

- Now packaging in the modern way with pyproject.toml.
- Dtypes can now match regexes (e.g. used in keys to match `^key<\w+>$`).
Full Changelog: v0.2.15...v0.2.19
jaxtyping v0.2.18
(Yanked; broke the pytest hook. Prefer v0.2.19 instead.)
Full Changelog: v0.2.15...v0.2.18
jaxtyping v0.2.17
(Yanked; had incompatibility with non-JAX installations. Prefer v0.2.19 instead.)
Full Changelog: v0.2.15...v0.2.17