Skip to content

Commit 588d9dd

Browse files
committed
Forward-fix for JAX API changes
In jax-ml/jax#33984, JAX will begin returning tuples rather than lists for several jax.numpy APIs. This fixes breakages associated with that change.
1 parent 7edb6a4 commit 588d9dd

File tree

1 file changed

+1
-1
lines changed
  • keras/src/backend/jax

1 file changed

+1
-1
lines changed

keras/src/backend/jax/nn.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -909,7 +909,7 @@ def one_hot(x, num_classes, axis=-1, dtype=None, sparse=False):
909909
values = jnp.greater_equal(jnp.ravel(x), 0).astype(dtype)
910910
values_count = values.shape[0]
911911
indices = [jnp.arange(dim) for dim in x.shape]
912-
indices = jnp.meshgrid(*indices, indexing="ij")
912+
indices = list(jnp.meshgrid(*indices, indexing="ij"))
913913
indices.insert(axis, jnp.maximum(x, 0)) # Deal with negative indices
914914
indices = [a.reshape(values_count, 1).astype("int32") for a in indices]
915915
indices = jnp.concatenate(indices, axis=1)

0 commit comments

Comments
 (0)