Skip to content

Commit 45b71e5

Browse files
committed
ensure_index: raise better error for traced inputs
1 parent 3243e23 commit 45b71e5

File tree

1 file changed

+2
-0
lines changed

1 file changed

+2
-0
lines changed

jax/_src/api_util.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,13 +37,15 @@
3737

3838
def _ensure_index(x: Any) -> Union[int, Tuple[int, ...]]:
3939
"""Ensure x is either an index or a tuple of indices."""
40+
x = core.concrete_or_error(None, x, "expected a static index or sequence of indices.")
4041
try:
4142
return operator.index(x)
4243
except TypeError:
4344
return tuple(map(operator.index, x))
4445

4546
def _ensure_index_tuple(x: Any) -> Tuple[int, ...]:
4647
"""Convert x to a tuple of indices."""
48+
x = core.concrete_or_error(None, x, "expected a static index or sequence of indices.")
4749
try:
4850
return (operator.index(x),)
4951
except TypeError:

0 commit comments

Comments
 (0)