Skip to content

Commit aa1be84

Browse files
Jake VanderPlasJaxonnxruntime Authors
authored andcommitted
Forward-fix for JAX API changes
In jax-ml/jax#33984, JAX will begin returning tuples rather than lists for several jax.numpy APIs. This fixes breakages associated with that change. PiperOrigin-RevId: 846323468
1 parent a1650c2 commit aa1be84

File tree

2 files changed

+10
-6
lines changed

2 files changed

+10
-6
lines changed

jaxonnxruntime/onnx_ops/scatterelements.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,8 +96,10 @@ def onnx_scatterelements(*input_args, axis, reduction):
9696
"""https://github.com/onnx/onnx/blob/v1.12.0/docs/Operators.md#ScatterElements for more details."""
9797
data, indices, updates = input_args
9898

99-
idx = jnp.meshgrid(
100-
*(jnp.arange(n) for n in data.shape), sparse=True, indexing="ij"
99+
idx = list(
100+
jnp.meshgrid(
101+
*(jnp.arange(n) for n in data.shape), sparse=True, indexing="ij"
102+
)
101103
)
102104
idx[axis] = indices
103105
out = getattr(data.at[tuple(idx)], reduction)(

jaxonnxruntime/onnx_ops/scatternd.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -141,10 +141,12 @@ def onnx_scatternd(*input_args, reduction: str):
141141
# e.g., for (r-k)=3 and z updates:
142142
# [(1,1,1,1), (1,range(x1),1,1), (1,1,range(x2),1), (1,1,1,range(x3))]
143143
# L----------L------------------L------------------L----> dim for z
144-
idx = jnp.meshgrid(
145-
*(jnp.arange(n) for n in [1] + list(data.shape[k:])),
146-
sparse=True,
147-
indexing="ij",
144+
idx = list(
145+
jnp.meshgrid(
146+
*(jnp.arange(n) for n in [1] + list(data.shape[k:])),
147+
sparse=True,
148+
indexing="ij",
149+
)
148150
)
149151
assert idx[0].ndim == (r - k) + 1
150152

0 commit comments

Comments
 (0)