Skip to content

Commit bbf346c

Browse files
committed
BUG: take_along_axis: numpy requires an axis
1 parent 5e14b53 commit bbf346c

File tree

1 file changed

+10
-0
lines changed

1 file changed

+10
-0
lines changed

array_api_compat/numpy/_aliases.py

+10
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,13 @@ def count_nonzero(
140140
return result
141141

142142

143+
# "axis=-1" is an optional argument of `take_along_axis` but numpy has no default
144+
def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1):
145+
if axis is None:
146+
axis = -1
147+
return np.take_along_axis(x, indices, axis=axis)
148+
149+
143150
# These functions are completely new here. If the library already has them
144151
# (i.e., numpy 2.0), use the library version instead of our wrapper.
145152
if hasattr(np, "vecdot"):
@@ -157,6 +164,7 @@ def count_nonzero(
157164
else:
158165
unstack = get_xp(np)(_aliases.unstack)
159166

167+
160168
__all__ = [
161169
"__array_namespace_info__",
162170
"asarray",
@@ -175,10 +183,12 @@ def count_nonzero(
175183
"concat",
176184
"count_nonzero",
177185
"pow",
186+
"take_along_axis"
178187
]
179188
__all__ += _aliases.__all__
180189
_all_ignore = ["np", "get_xp"]
181190

182191

183192
def __dir__() -> list[str]:
184193
return __all__
194+

0 commit comments

Comments
 (0)