diff --git a/equinox/__init__.py b/equinox/__init__.py index d3045acf..99528f1b 100644 --- a/equinox/__init__.py +++ b/equinox/__init__.py @@ -15,4 +15,4 @@ from .update import apply_updates -__version__ = "0.3.0" +__version__ = "0.3.1" diff --git a/equinox/grad.py b/equinox/grad.py index fbc662ec..f256a0a1 100644 --- a/equinox/grad.py +++ b/equinox/grad.py @@ -25,6 +25,7 @@ def fun_value_and_grad(diff_x, nondiff_x, *args, **kwargs): x = combine(diff_x, nondiff_x) return fun(x, *args, **kwargs) + @ft.wraps(fun) def fun_value_and_grad_wrapper(x, *args, **kwargs): diff_x, nondiff_x = partition(x, filter_spec) return fun_value_and_grad(diff_x, nondiff_x, *args, **kwargs) @@ -81,6 +82,7 @@ def grad_func(x__y): fun, filter_spec=filter_spec, **gradkwargs ) + @ft.wraps(fun) def fun_grad(*args, **kwargs): value, grad = fun_value_and_grad(*args, **kwargs) if has_aux: diff --git a/equinox/jit.py b/equinox/jit.py index beae0806..d2c2c3de 100644 --- a/equinox/jit.py +++ b/equinox/jit.py @@ -12,8 +12,9 @@ class _Static(Module): @ft.lru_cache(maxsize=None) -def _f_wrapped_cache(**jitkwargs): +def _f_wrapped_cache(fun, **jitkwargs): @ft.partial(jax.jit, static_argnums=(1, 2, 3), **jitkwargs) + @ft.wraps(fun) def f_wrapped(dynamic, static_treedef, static_leaves, filter_spec_return): static = jax.tree_unflatten(static_treedef, static_leaves) f, args, kwargs = combine(dynamic, static) @@ -24,6 +25,15 @@ def f_wrapped(dynamic, static_treedef, static_leaves, filter_spec_return): return f_wrapped +def _strip_wrapped_partial(fun): + """Preserve the outermost wraps call's docstring or traverse to the inner function""" + if hasattr(fun, "__wrapped__"): + return _strip_wrapped_partial(fun.__wrapped__) + if isinstance(fun, ft.partial): + return _strip_wrapped_partial(fun.func) + return fun + + def filter_jit( fun, *, @@ -94,7 +104,8 @@ def fun_wrapper(*args, **kwargs): static = (static_fun,) + static_args_kwargs static_leaves, static_treedef = jax.tree_flatten(static) static_leaves = tuple(static_leaves) - dynamic_out, static_out = _f_wrapped_cache(**jitkwargs)( + inner_fun = _strip_wrapped_partial(static_fun) + dynamic_out, static_out = _f_wrapped_cache(inner_fun, **jitkwargs)( dynamic, static_treedef, static_leaves, filter_spec_return ) return combine(dynamic_out, static_out.value) diff --git a/equinox/module.py b/equinox/module.py index 4baffb14..32b27a44 100644 --- a/equinox/module.py +++ b/equinox/module.py @@ -85,7 +85,7 @@ def _not_magic(k): # Inherits from abc.ABCMeta as a convenience for a common use-case. -# It's not a feature we use ourselve. +# It's not a feature we use ourselves. class _ModuleMeta(abc.ABCMeta): def __new__(mcs, name, bases, dict_): dict_ = { diff --git a/tests/test_filter_jit.py b/tests/test_filter_jit.py index bba92475..50af9094 100644 --- a/tests/test_filter_jit.py +++ b/tests/test_filter_jit.py @@ -204,3 +204,70 @@ def __call__(self, x): eqx.filter_jit(m)(y) eqx.filter_jit(m)(y) assert num_traces == 1 + + +@pytest.fixture +def log_compiles_config(): + """Setup and teardown of jax_log_compiles flag""" + with jax.log_compiles(True): + yield + + +def test_function_name_warning(log_compiles_config, caplog): + """Test that the proper function names are used when compiling a function decorated with `filter_jit`""" + + @eqx.filter_jit + def the_test_function_name(x): + return x + 1 + + # Trigger compile to log a warning message + the_test_function_name(jnp.array(1.0)) + + warning_text = caplog.text + + # Check that the warning message contains the function name + assert "Finished XLA compilation of the_test_function_name in" in warning_text + + # Check that it works for filter_grad also + @eqx.filter_jit + @eqx.filter_grad + def the_test_function_name_grad(x): + return x + 1 + + # Trigger compile to log a warning message + the_test_function_name_grad(jnp.array(1.0)) + + warning_text = caplog.text + + assert "Finished XLA compilation of the_test_function_name_grad in" in warning_text + + @eqx.filter_jit + @eqx.filter_value_and_grad + def the_test_function_name_value_and_grad(x): + return x + 1 + + # Trigger compile to log a warning message + the_test_function_name_value_and_grad(jnp.array(1.0)) + + warning_text = caplog.text + + assert ( + "Finished XLA compilation of the_test_function_name_value_and_grad in" + in warning_text + ) + + def wrapped_fun(x, y): + return x + y + + def the_test_function_name(x, y): + return wrapped_fun(x, y) + + fun = eqx.filter_jit( + ft.wraps(wrapped_fun)(ft.partial(the_test_function_name, jnp.array(1.0))) + ) + + fun(jnp.array(1.0)) + + warning_text = caplog.text + + assert "Finished XLA compilation of wrapped_fun in" in warning_text