Skip to content

Commit

Permalink
Merge pull request #52 from marcelroed/main
Browse files Browse the repository at this point in the history
Fix function names when logging compilation progress
  • Loading branch information
patrick-kidger authored Mar 28, 2022
2 parents 3af0cde + 3804a8d commit 41900ca
Show file tree
Hide file tree
Showing 5 changed files with 84 additions and 4 deletions.
2 changes: 1 addition & 1 deletion equinox/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,4 @@
from .update import apply_updates


__version__ = "0.3.0"
__version__ = "0.3.1"
2 changes: 2 additions & 0 deletions equinox/grad.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ def fun_value_and_grad(diff_x, nondiff_x, *args, **kwargs):
x = combine(diff_x, nondiff_x)
return fun(x, *args, **kwargs)

@ft.wraps(fun)
def fun_value_and_grad_wrapper(x, *args, **kwargs):
diff_x, nondiff_x = partition(x, filter_spec)
return fun_value_and_grad(diff_x, nondiff_x, *args, **kwargs)
Expand Down Expand Up @@ -81,6 +82,7 @@ def grad_func(x__y):
fun, filter_spec=filter_spec, **gradkwargs
)

@ft.wraps(fun)
def fun_grad(*args, **kwargs):
value, grad = fun_value_and_grad(*args, **kwargs)
if has_aux:
Expand Down
15 changes: 13 additions & 2 deletions equinox/jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@ class _Static(Module):


@ft.lru_cache(maxsize=None)
def _f_wrapped_cache(**jitkwargs):
def _f_wrapped_cache(fun, **jitkwargs):
@ft.partial(jax.jit, static_argnums=(1, 2, 3), **jitkwargs)
@ft.wraps(fun)
def f_wrapped(dynamic, static_treedef, static_leaves, filter_spec_return):
static = jax.tree_unflatten(static_treedef, static_leaves)
f, args, kwargs = combine(dynamic, static)
Expand All @@ -24,6 +25,15 @@ def f_wrapped(dynamic, static_treedef, static_leaves, filter_spec_return):
return f_wrapped


def _strip_wrapped_partial(fun):
"""Preserve the outermost wraps call's docstring or traverse to the inner function"""
if hasattr(fun, "__wrapped__"):
return _strip_wrapped_partial(fun.__wrapped__)
if isinstance(fun, ft.partial):
return _strip_wrapped_partial(fun.func)
return fun


def filter_jit(
fun,
*,
Expand Down Expand Up @@ -94,7 +104,8 @@ def fun_wrapper(*args, **kwargs):
static = (static_fun,) + static_args_kwargs
static_leaves, static_treedef = jax.tree_flatten(static)
static_leaves = tuple(static_leaves)
dynamic_out, static_out = _f_wrapped_cache(**jitkwargs)(
inner_fun = _strip_wrapped_partial(static_fun)
dynamic_out, static_out = _f_wrapped_cache(inner_fun, **jitkwargs)(
dynamic, static_treedef, static_leaves, filter_spec_return
)
return combine(dynamic_out, static_out.value)
Expand Down
2 changes: 1 addition & 1 deletion equinox/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def _not_magic(k):


# Inherits from abc.ABCMeta as a convenience for a common use-case.
# It's not a feature we use ourselve.
# It's not a feature we use ourselves.
class _ModuleMeta(abc.ABCMeta):
def __new__(mcs, name, bases, dict_):
dict_ = {
Expand Down
67 changes: 67 additions & 0 deletions tests/test_filter_jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,3 +204,70 @@ def __call__(self, x):
eqx.filter_jit(m)(y)
eqx.filter_jit(m)(y)
assert num_traces == 1


@pytest.fixture
def log_compiles_config():
"""Setup and teardown of jax_log_compiles flag"""
with jax.log_compiles(True):
yield


def test_function_name_warning(log_compiles_config, caplog):
"""Test that the proper function names are used when compiling a function decorated with `filter_jit`"""

@eqx.filter_jit
def the_test_function_name(x):
return x + 1

# Trigger compile to log a warning message
the_test_function_name(jnp.array(1.0))

warning_text = caplog.text

# Check that the warning message contains the function name
assert "Finished XLA compilation of the_test_function_name in" in warning_text

# Check that it works for filter_grad also
@eqx.filter_jit
@eqx.filter_grad
def the_test_function_name_grad(x):
return x + 1

# Trigger compile to log a warning message
the_test_function_name_grad(jnp.array(1.0))

warning_text = caplog.text

assert "Finished XLA compilation of the_test_function_name_grad in" in warning_text

@eqx.filter_jit
@eqx.filter_value_and_grad
def the_test_function_name_value_and_grad(x):
return x + 1

# Trigger compile to log a warning message
the_test_function_name_value_and_grad(jnp.array(1.0))

warning_text = caplog.text

assert (
"Finished XLA compilation of the_test_function_name_value_and_grad in"
in warning_text
)

def wrapped_fun(x, y):
return x + y

def the_test_function_name(x, y):
return wrapped_fun(x, y)

fun = eqx.filter_jit(
ft.wraps(wrapped_fun)(ft.partial(the_test_function_name, jnp.array(1.0)))
)

fun(jnp.array(1.0))

warning_text = caplog.text

assert "Finished XLA compilation of wrapped_fun in" in warning_text

0 comments on commit 41900ca

Please sign in to comment.