Skip to content

Commit

Permalink
JAX 0.4.33 changes how errors work, which broke the nice error_if mes…
Browse files Browse the repository at this point in the history
…sages. Hopefully this should be cross-version compatible.
  • Loading branch information
patrick-kidger committed Oct 1, 2024
1 parent f687b9f commit 141c18d
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 7 deletions.
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@ repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.6.4
hooks:
- id: ruff-format # formatter
types_or: [python, pyi, jupyter, toml]
- id: ruff # linter
types_or: [python, pyi, jupyter, toml]
args: [--fix]
- id: ruff-format # formatter
types_or: [python, pyi, jupyter, toml]
- repo: https://github.com/RobertCraigie/pyright-python
rev: v1.1.379
hooks:
Expand Down
37 changes: 32 additions & 5 deletions equinox/_jit.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import atexit
import functools as ft
import inspect
import logging
Expand All @@ -7,8 +8,10 @@
from typing_extensions import ParamSpec

import jax
import jax._src.dispatch
import jax._src.traceback_util as traceback_util
import jax.core
import jax.numpy as jnp
from jaxtyping import PyTree

from ._compile_utils import (
Expand Down Expand Up @@ -50,7 +53,8 @@ def fun_wrapped(dynamic_donate, dynamic_nodonate, static):
assert dummy_arg is None
out = fun(*args, **kwargs)
dynamic_out, static_out = partition(out, is_array)
return dynamic_out, Static(static_out)
marker = jnp.array(0)
return marker, dynamic_out, Static(static_out)

fun_name, fun_qualname = fun_names
fun_wrapped.__name__ = fun_name
Expand Down Expand Up @@ -99,7 +103,7 @@ def _preprocess(info, args, kwargs, return_static: bool = False):


def _postprocess(out):
dynamic_out, static_out = out
_, dynamic_out, static_out = out
return combine(dynamic_out, static_out.value)


Expand All @@ -112,6 +116,22 @@ class XlaRuntimeError(Exception):
pass


try:
wait_for_tokens = jax._src.dispatch.wait_for_tokens
except AttributeError:
pass # forward compatibility
else:
# Fix for https://github.com/patrick-kidger/diffrax/issues/506
def wait_for_tokens2():
try:
wait_for_tokens()
except XlaRuntimeError:
pass

atexit.unregister(wait_for_tokens)
atexit.register(wait_for_tokens2)


# This is the class we use to raise runtime errors from `eqx.error_if`.
class EquinoxRuntimeError(RuntimeError):
pass
Expand Down Expand Up @@ -178,7 +198,8 @@ def __wrapped__(self):
def _call(self, is_lower, args, kwargs):
__tracebackhide__ = True
# Used by our error messages when figuring out where to stop walking the stack.
if not currently_jitting():
jitting = currently_jitting()
if not jitting:
__equinox_filter_jit__ = True # noqa: F841
info = (
self._signature,
Expand Down Expand Up @@ -207,9 +228,15 @@ def _call(self, is_lower, args, kwargs):
warnings.filterwarnings(
"ignore", message="Some donated buffers were not usable*"
)
out = self._cached(dynamic_donate, dynamic_nodonate, static)
marker, _, _ = out = self._cached(
dynamic_donate, dynamic_nodonate, static
)
else:
out = self._cached(dynamic_donate, dynamic_nodonate, static)
marker, _, _ = out = self._cached(
dynamic_donate, dynamic_nodonate, static
)
if not jitting:
marker.block_until_ready()
except XlaRuntimeError as e:
# Catch Equinox's runtime errors, and re-raise them with actually useful
# information. (By default XlaRuntimeError produces a lot of terrifying
Expand Down

0 comments on commit 141c18d

Please sign in to comment.