Skip to content

Commit d4cdec4

Browse files
committed
Fix trailing whitespace in transforms_test.py
Removes trailing whitespace that was causing pre-commit hook failures. No functional changes - only formatting cleanup.
1 parent 6ef6246 commit d4cdec4

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

tests/nnx/transforms_test.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -429,7 +429,7 @@ def test_jit_static_args_with_shardings(self, static_argnums, static_argnames):
429429
n_devices = jax.local_device_count()
430430
devices = mesh_utils.create_device_mesh((n_devices,))
431431
mesh = jax.sharding.Mesh(devices, ('data',))
432-
432+
433433
def fn(x, scale, use_relu):
434434
y = x * scale
435435
if use_relu:
@@ -438,8 +438,8 @@ def fn(x, scale, use_relu):
438438

439439
x = jnp.linspace(-1.0, 1.0, 16, dtype=jnp.float32).reshape(4, 4)
440440
x_sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('data'))
441-
442-
f = nnx.jit(fn, in_shardings=(x_sharding, None),
441+
442+
f = nnx.jit(fn, in_shardings=(x_sharding, None),
443443
static_argnums=static_argnums, static_argnames=static_argnames)
444444
y_relu = f(x, 0.5, True)
445445
y_no_relu = f(x, 0.5, False)

0 commit comments

Comments
 (0)