Skip to content

Commit

Permalink
Formatting with black
Browse files Browse the repository at this point in the history
  • Loading branch information
marcelroed committed Mar 28, 2022
1 parent 182f25b commit 3804a8d
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 6 deletions.
2 changes: 1 addition & 1 deletion equinox/jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def f_wrapped(dynamic, static_treedef, static_leaves, filter_spec_return):

def _strip_wrapped_partial(fun):
"""Preserve the outermost wraps call's docstring or traverse to the inner function"""
if hasattr(fun, '__wrapped__'):
if hasattr(fun, "__wrapped__"):
return _strip_wrapped_partial(fun.__wrapped__)
if isinstance(fun, ft.partial):
return _strip_wrapped_partial(fun.func)
Expand Down
16 changes: 11 additions & 5 deletions tests/test_filter_jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,7 @@ def log_compiles_config():

def test_function_name_warning(log_compiles_config, caplog):
"""Test that the proper function names are used when compiling a function decorated with `filter_jit`"""

@eqx.filter_jit
def the_test_function_name(x):
return x + 1
Expand All @@ -225,7 +226,7 @@ def the_test_function_name(x):
warning_text = caplog.text

# Check that the warning message contains the function name
assert 'Finished XLA compilation of the_test_function_name in' in warning_text
assert "Finished XLA compilation of the_test_function_name in" in warning_text

# Check that it works for filter_grad also
@eqx.filter_jit
Expand All @@ -238,7 +239,7 @@ def the_test_function_name_grad(x):

warning_text = caplog.text

assert 'Finished XLA compilation of the_test_function_name_grad in' in warning_text
assert "Finished XLA compilation of the_test_function_name_grad in" in warning_text

@eqx.filter_jit
@eqx.filter_value_and_grad
Expand All @@ -250,18 +251,23 @@ def the_test_function_name_value_and_grad(x):

warning_text = caplog.text

assert 'Finished XLA compilation of the_test_function_name_value_and_grad in' in warning_text
assert (
"Finished XLA compilation of the_test_function_name_value_and_grad in"
in warning_text
)

def wrapped_fun(x, y):
return x + y

def the_test_function_name(x, y):
return wrapped_fun(x, y)

fun = eqx.filter_jit(ft.wraps(wrapped_fun)(ft.partial(the_test_function_name, jnp.array(1.0))))
fun = eqx.filter_jit(
ft.wraps(wrapped_fun)(ft.partial(the_test_function_name, jnp.array(1.0)))
)

fun(jnp.array(1.0))

warning_text = caplog.text

assert 'Finished XLA compilation of wrapped_fun in' in warning_text
assert "Finished XLA compilation of wrapped_fun in" in warning_text

0 comments on commit 3804a8d

Please sign in to comment.