From 6c1fa819ae51be76b6e7e29d5ec69a376253c70b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcel=20R=C3=B8d?= Date: Mon, 28 Mar 2022 02:42:02 +0100 Subject: [PATCH 1/9] Show the correct function/method name when compiling with jax_log_compiles=True --- equinox/grad.py | 1 + equinox/jit.py | 9 +++++++-- equinox/module.py | 2 +- tests/test_filter_jit.py | 36 ++++++++++++++++++++++++++++++++++++ 4 files changed, 45 insertions(+), 3 deletions(-) diff --git a/equinox/grad.py b/equinox/grad.py index fbc662ec..282c3ea1 100644 --- a/equinox/grad.py +++ b/equinox/grad.py @@ -81,6 +81,7 @@ def grad_func(x__y): fun, filter_spec=filter_spec, **gradkwargs ) + @ft.wraps(fun) def fun_grad(*args, **kwargs): value, grad = fun_value_and_grad(*args, **kwargs) if has_aux: diff --git a/equinox/jit.py b/equinox/jit.py index beae0806..31a52cc0 100644 --- a/equinox/jit.py +++ b/equinox/jit.py @@ -12,8 +12,9 @@ class _Static(Module): @ft.lru_cache(maxsize=None) -def _f_wrapped_cache(**jitkwargs): +def _f_wrapped_cache(fun, **jitkwargs): @ft.partial(jax.jit, static_argnums=(1, 2, 3), **jitkwargs) + @ft.wraps(fun) def f_wrapped(dynamic, static_treedef, static_leaves, filter_spec_return): static = jax.tree_unflatten(static_treedef, static_leaves) f, args, kwargs = combine(dynamic, static) @@ -94,7 +95,11 @@ def fun_wrapper(*args, **kwargs): static = (static_fun,) + static_args_kwargs static_leaves, static_treedef = jax.tree_flatten(static) static_leaves = tuple(static_leaves) - dynamic_out, static_out = _f_wrapped_cache(**jitkwargs)( + if isinstance(fun, ft.partial): + inner_fun = fun.func + else: + inner_fun = fun + dynamic_out, static_out = _f_wrapped_cache(inner_fun, **jitkwargs)( dynamic, static_treedef, static_leaves, filter_spec_return ) return combine(dynamic_out, static_out.value) diff --git a/equinox/module.py b/equinox/module.py index 4baffb14..32b27a44 100644 --- a/equinox/module.py +++ b/equinox/module.py @@ -85,7 +85,7 @@ def _not_magic(k): # Inherits from abc.ABCMeta as a convenience for a common use-case. -# It's not a feature we use ourselve. +# It's not a feature we use ourselves. class _ModuleMeta(abc.ABCMeta): def __new__(mcs, name, bases, dict_): dict_ = { diff --git a/tests/test_filter_jit.py b/tests/test_filter_jit.py index bba92475..8ec5817f 100644 --- a/tests/test_filter_jit.py +++ b/tests/test_filter_jit.py @@ -204,3 +204,39 @@ def __call__(self, x): eqx.filter_jit(m)(y) eqx.filter_jit(m)(y) assert num_traces == 1 + + +@pytest.fixture +def log_compiles_config(): + """Setup and teardown of jax_log_compiles flag""" + jax.config.update('jax_log_compiles', True) + yield + jax.config.update('jax_log_compiles', False) + + +def test_function_name_warning(log_compiles_config, caplog): + """Test that the proper function names are used when compiling a function decorated with `filter_jit`""" + @eqx.filter_jit + def the_test_function_name(x): + return x + 1 + + # Trigger compile to log a warning message + the_test_function_name(jnp.array(1.0)) + + warning_text = caplog.text + + # Check that the warning message contains the function name + assert 'Finished XLA compilation of the_test_function_name in' in warning_text + + # Check that it works for filter_grad also + @eqx.filter_jit + @eqx.filter_grad + def the_test_function_name_grad(x): + return x + 1 + + # Trigger compile to log a warning message + the_test_function_name_grad(jnp.array(1.0)) + + warning_text = caplog.text + + assert 'Finished XLA compilation of the_test_function_name_grad in' in warning_text From ef949d7643904bf843b68ad5df76881e6c0a0be4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcel=20R=C3=B8d?= Date: Mon, 28 Mar 2022 02:49:37 +0100 Subject: [PATCH 2/9] Fix and test wrapping for filter_value_and_grad --- equinox/grad.py | 1 + tests/test_filter_jit.py | 13 +++++++++++++ 2 files changed, 14 insertions(+) diff --git a/equinox/grad.py b/equinox/grad.py index 282c3ea1..f256a0a1 100644 --- a/equinox/grad.py +++ b/equinox/grad.py @@ -25,6 +25,7 @@ def fun_value_and_grad(diff_x, nondiff_x, *args, **kwargs): x = combine(diff_x, nondiff_x) return fun(x, *args, **kwargs) + @ft.wraps(fun) def fun_value_and_grad_wrapper(x, *args, **kwargs): diff_x, nondiff_x = partition(x, filter_spec) return fun_value_and_grad(diff_x, nondiff_x, *args, **kwargs) diff --git a/tests/test_filter_jit.py b/tests/test_filter_jit.py index 8ec5817f..2c241498 100644 --- a/tests/test_filter_jit.py +++ b/tests/test_filter_jit.py @@ -240,3 +240,16 @@ def the_test_function_name_grad(x): warning_text = caplog.text assert 'Finished XLA compilation of the_test_function_name_grad in' in warning_text + + + @eqx.filter_jit + @eqx.filter_value_and_grad + def the_test_function_name_value_and_grad(x): + return x + 1 + + # Trigger compile to log a warning message + the_test_function_name_value_and_grad(jnp.array(1.0)) + + warning_text = caplog.text + + assert 'Finished XLA compilation of the_test_function_name_value_and_grad in' in warning_text From 9ff6682cb704ac932a5f13fa7be6cf67cc1c76d9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcel=20R=C3=B8d?= Date: Mon, 28 Mar 2022 03:08:20 +0100 Subject: [PATCH 3/9] Formatting with flake8 --- tests/test_filter_jit.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_filter_jit.py b/tests/test_filter_jit.py index 2c241498..034e792c 100644 --- a/tests/test_filter_jit.py +++ b/tests/test_filter_jit.py @@ -241,7 +241,6 @@ def the_test_function_name_grad(x): assert 'Finished XLA compilation of the_test_function_name_grad in' in warning_text - @eqx.filter_jit @eqx.filter_value_and_grad def the_test_function_name_value_and_grad(x): From b7bcddbabeaa5d7e6e0ae8a60daf179ad11a1652 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcel=20R=C3=B8d?= Date: Mon, 28 Mar 2022 10:03:18 +0100 Subject: [PATCH 4/9] Use `static_fun` for wrapping --- equinox/jit.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/equinox/jit.py b/equinox/jit.py index 31a52cc0..8387b23b 100644 --- a/equinox/jit.py +++ b/equinox/jit.py @@ -96,9 +96,9 @@ def fun_wrapper(*args, **kwargs): static_leaves, static_treedef = jax.tree_flatten(static) static_leaves = tuple(static_leaves) if isinstance(fun, ft.partial): - inner_fun = fun.func + inner_fun = static_fun.func else: - inner_fun = fun + inner_fun = static_fun dynamic_out, static_out = _f_wrapped_cache(inner_fun, **jitkwargs)( dynamic, static_treedef, static_leaves, filter_spec_return ) From f3eff2309ee00fc2697f18c192a470e4125d419c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcel=20R=C3=B8d?= Date: Mon, 28 Mar 2022 10:03:44 +0100 Subject: [PATCH 5/9] Use context manager for flag state --- tests/test_filter_jit.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tests/test_filter_jit.py b/tests/test_filter_jit.py index 034e792c..c0226e85 100644 --- a/tests/test_filter_jit.py +++ b/tests/test_filter_jit.py @@ -209,9 +209,8 @@ def __call__(self, x): @pytest.fixture def log_compiles_config(): """Setup and teardown of jax_log_compiles flag""" - jax.config.update('jax_log_compiles', True) - yield - jax.config.update('jax_log_compiles', False) + with jax.log_compiles(True): + yield def test_function_name_warning(log_compiles_config, caplog): From ffd56c811e710baa9880fb43da95d85e598a7072 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcel=20R=C3=B8d?= Date: Mon, 28 Mar 2022 10:18:40 +0100 Subject: [PATCH 6/9] Also use static_fun for isinstance --- equinox/jit.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/equinox/jit.py b/equinox/jit.py index 8387b23b..08796a5d 100644 --- a/equinox/jit.py +++ b/equinox/jit.py @@ -95,7 +95,7 @@ def fun_wrapper(*args, **kwargs): static = (static_fun,) + static_args_kwargs static_leaves, static_treedef = jax.tree_flatten(static) static_leaves = tuple(static_leaves) - if isinstance(fun, ft.partial): + if isinstance(static_fun, ft.partial): inner_fun = static_fun.func else: inner_fun = static_fun From 955a7a4a9d09315a4aac25b320d7e6a3a5118845 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcel=20R=C3=B8d?= Date: Mon, 28 Mar 2022 10:34:37 +0100 Subject: [PATCH 7/9] Fix edge-case when wrapping a partial --- equinox/jit.py | 14 ++++++++++---- tests/test_filter_jit.py | 14 ++++++++++++++ 2 files changed, 24 insertions(+), 4 deletions(-) diff --git a/equinox/jit.py b/equinox/jit.py index 08796a5d..9067890d 100644 --- a/equinox/jit.py +++ b/equinox/jit.py @@ -25,6 +25,15 @@ def f_wrapped(dynamic, static_treedef, static_leaves, filter_spec_return): return f_wrapped +def _strip_wrapped_partial(fun): + """Preserve the outermost wraps call's docstring or traverse to the inner function""" + if hasattr(fun, '__wrapped__'): + return _strip_wrapped_partial(fun.__wrapped__) + if isinstance(fun, ft.partial): + return _strip_wrapped_partial(fun.func) + return fun + + def filter_jit( fun, *, @@ -95,10 +104,7 @@ def fun_wrapper(*args, **kwargs): static = (static_fun,) + static_args_kwargs static_leaves, static_treedef = jax.tree_flatten(static) static_leaves = tuple(static_leaves) - if isinstance(static_fun, ft.partial): - inner_fun = static_fun.func - else: - inner_fun = static_fun + inner_fun = _strip_wrapped_partial(static_fun) dynamic_out, static_out = _f_wrapped_cache(inner_fun, **jitkwargs)( dynamic, static_treedef, static_leaves, filter_spec_return ) diff --git a/tests/test_filter_jit.py b/tests/test_filter_jit.py index c0226e85..48019ab2 100644 --- a/tests/test_filter_jit.py +++ b/tests/test_filter_jit.py @@ -251,3 +251,17 @@ def the_test_function_name_value_and_grad(x): warning_text = caplog.text assert 'Finished XLA compilation of the_test_function_name_value_and_grad in' in warning_text + + def wrapped_fun(x, y): + return x + y + + def the_test_function_name(x, y): + return wrapped_fun(x, y) + + fun = eqx.filter_jit(ft.wraps(wrapped_fun)(ft.partial(the_test_function_name, jnp.array(1.0)))) + + fun(jnp.array(1.0)) + + warning_text = caplog.text + + assert 'Finished XLA compilation of wrapped_fun in' in warning_text From 182f25bc5f54872cfe0bb9c90876d2a61c847ba6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcel=20R=C3=B8d?= Date: Mon, 28 Mar 2022 10:57:53 +0100 Subject: [PATCH 8/9] Bump version number --- equinox/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/equinox/__init__.py b/equinox/__init__.py index d3045acf..99528f1b 100644 --- a/equinox/__init__.py +++ b/equinox/__init__.py @@ -15,4 +15,4 @@ from .update import apply_updates -__version__ = "0.3.0" +__version__ = "0.3.1" From 3804a8d60217bde685bee0a893a7bd55b1e63c26 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcel=20R=C3=B8d?= Date: Mon, 28 Mar 2022 11:08:35 +0100 Subject: [PATCH 9/9] Formatting with black --- equinox/jit.py | 2 +- tests/test_filter_jit.py | 16 +++++++++++----- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/equinox/jit.py b/equinox/jit.py index 9067890d..d2c2c3de 100644 --- a/equinox/jit.py +++ b/equinox/jit.py @@ -27,7 +27,7 @@ def f_wrapped(dynamic, static_treedef, static_leaves, filter_spec_return): def _strip_wrapped_partial(fun): """Preserve the outermost wraps call's docstring or traverse to the inner function""" - if hasattr(fun, '__wrapped__'): + if hasattr(fun, "__wrapped__"): return _strip_wrapped_partial(fun.__wrapped__) if isinstance(fun, ft.partial): return _strip_wrapped_partial(fun.func) diff --git a/tests/test_filter_jit.py b/tests/test_filter_jit.py index 48019ab2..50af9094 100644 --- a/tests/test_filter_jit.py +++ b/tests/test_filter_jit.py @@ -215,6 +215,7 @@ def log_compiles_config(): def test_function_name_warning(log_compiles_config, caplog): """Test that the proper function names are used when compiling a function decorated with `filter_jit`""" + @eqx.filter_jit def the_test_function_name(x): return x + 1 @@ -225,7 +226,7 @@ def the_test_function_name(x): warning_text = caplog.text # Check that the warning message contains the function name - assert 'Finished XLA compilation of the_test_function_name in' in warning_text + assert "Finished XLA compilation of the_test_function_name in" in warning_text # Check that it works for filter_grad also @eqx.filter_jit @@ -238,7 +239,7 @@ def the_test_function_name_grad(x): warning_text = caplog.text - assert 'Finished XLA compilation of the_test_function_name_grad in' in warning_text + assert "Finished XLA compilation of the_test_function_name_grad in" in warning_text @eqx.filter_jit @eqx.filter_value_and_grad @@ -250,7 +251,10 @@ def the_test_function_name_value_and_grad(x): warning_text = caplog.text - assert 'Finished XLA compilation of the_test_function_name_value_and_grad in' in warning_text + assert ( + "Finished XLA compilation of the_test_function_name_value_and_grad in" + in warning_text + ) def wrapped_fun(x, y): return x + y @@ -258,10 +262,12 @@ def wrapped_fun(x, y): def the_test_function_name(x, y): return wrapped_fun(x, y) - fun = eqx.filter_jit(ft.wraps(wrapped_fun)(ft.partial(the_test_function_name, jnp.array(1.0)))) + fun = eqx.filter_jit( + ft.wraps(wrapped_fun)(ft.partial(the_test_function_name, jnp.array(1.0))) + ) fun(jnp.array(1.0)) warning_text = caplog.text - assert 'Finished XLA compilation of wrapped_fun in' in warning_text + assert "Finished XLA compilation of wrapped_fun in" in warning_text