Skip to content

Commit 6d780a8

Browse files
committed
Fix issue with repeat()
NumPy does not allow repeats to be uint64 because it refuses to downcast it. Technically it worked before because we implement __array__ and repeat does manually cast in that case. I'm not really sure we should be supporting __array__ actually.
1 parent 05c8b0f commit 6d780a8

File tree

1 file changed

+9
-3
lines changed

1 file changed

+9
-3
lines changed

array_api_strict/_manipulation_functions.py

+9-3
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22

33
from ._array_object import Array
44
from ._creation_functions import asarray
5-
from ._data_type_functions import result_type
6-
from ._dtypes import _integer_dtypes
5+
from ._data_type_functions import astype, result_type
6+
from ._dtypes import _integer_dtypes, int64, uint64
77
from ._flags import requires_api_version, get_array_api_strict_flags
88

99
from typing import TYPE_CHECKING
@@ -94,7 +94,13 @@ def repeat(
9494
else:
9595
raise TypeError("repeats must be an int or array")
9696

97-
return Array._new(np.repeat(x._array, repeats, axis=axis))
97+
if repeats.dtype == uint64:
98+
# NumPy does not allow uint64 because can't be cast down to x.dtype
99+
# with 'safe' casting. However, repeats values larger than 2**63 are
100+
# infeasable, and even if they are present by mistake, this will
101+
# lead to underflow and an error.
102+
repeats = astype(repeats, int64)
103+
return Array._new(np.repeat(x._array, repeats._array, axis=axis))
98104

99105
# Note: the optional argument is called 'shape', not 'newshape'
100106
def reshape(x: Array,

0 commit comments

Comments
 (0)