diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index c19d6015..25726931 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -2,11 +2,11 @@ repos: - repo: https://github.com/astral-sh/ruff-pre-commit rev: v0.6.4 hooks: + - id: ruff-format # formatter + types_or: [python, pyi, jupyter, toml] - id: ruff # linter types_or: [python, pyi, jupyter, toml] args: [--fix] - - id: ruff-format # formatter - types_or: [python, pyi, jupyter, toml] - repo: https://github.com/RobertCraigie/pyright-python rev: v1.1.379 hooks: diff --git a/equinox/_jit.py b/equinox/_jit.py index c011b6a9..9be74b85 100644 --- a/equinox/_jit.py +++ b/equinox/_jit.py @@ -1,3 +1,4 @@ +import atexit import functools as ft import inspect import logging @@ -7,8 +8,10 @@ from typing_extensions import ParamSpec import jax +import jax._src.dispatch import jax._src.traceback_util as traceback_util import jax.core +import jax.numpy as jnp from jaxtyping import PyTree from ._compile_utils import ( @@ -50,7 +53,8 @@ def fun_wrapped(dynamic_donate, dynamic_nodonate, static): assert dummy_arg is None out = fun(*args, **kwargs) dynamic_out, static_out = partition(out, is_array) - return dynamic_out, Static(static_out) + marker = jnp.array(0) + return marker, dynamic_out, Static(static_out) fun_name, fun_qualname = fun_names fun_wrapped.__name__ = fun_name @@ -99,7 +103,7 @@ def _preprocess(info, args, kwargs, return_static: bool = False): def _postprocess(out): - dynamic_out, static_out = out + _, dynamic_out, static_out = out return combine(dynamic_out, static_out.value) @@ -112,6 +116,22 @@ class XlaRuntimeError(Exception): pass +try: + wait_for_tokens = jax._src.dispatch.wait_for_tokens +except AttributeError: + pass # forward compatibility +else: + # Fix for https://github.com/patrick-kidger/diffrax/issues/506 + def wait_for_tokens2(): + try: + wait_for_tokens() + except XlaRuntimeError: + pass + + atexit.unregister(wait_for_tokens) + atexit.register(wait_for_tokens2) + + # This is the class we use to raise runtime errors from `eqx.error_if`. class EquinoxRuntimeError(RuntimeError): pass @@ -178,7 +198,8 @@ def __wrapped__(self): def _call(self, is_lower, args, kwargs): __tracebackhide__ = True # Used by our error messages when figuring out where to stop walking the stack. - if not currently_jitting(): + jitting = currently_jitting() + if not jitting: __equinox_filter_jit__ = True # noqa: F841 info = ( self._signature, @@ -207,9 +228,15 @@ def _call(self, is_lower, args, kwargs): warnings.filterwarnings( "ignore", message="Some donated buffers were not usable*" ) - out = self._cached(dynamic_donate, dynamic_nodonate, static) + marker, _, _ = out = self._cached( + dynamic_donate, dynamic_nodonate, static + ) else: - out = self._cached(dynamic_donate, dynamic_nodonate, static) + marker, _, _ = out = self._cached( + dynamic_donate, dynamic_nodonate, static + ) + if not jitting: + marker.block_until_ready() except XlaRuntimeError as e: # Catch Equinox's runtime errors, and re-raise them with actually useful # information. (By default XlaRuntimeError produces a lot of terrifying