Skip to content

Commit 38551c6

Browse files
committed
Remove __array__
This makes it raise an exception, since it isn't supported by the standard (if we just leave it unimplemented, then np.asarray() returns an object dtype array, which is not good). There is one issue here from the test suite, which is that this breaks the logic in asarray() for converting lists of array_api_strict 0-D arrays. I'm not yet sure what to do about that. Fixes #67.
1 parent ff126d7 commit 38551c6

File tree

2 files changed

+10
-28
lines changed

2 files changed

+10
-28
lines changed

array_api_strict/_array_object.py

+10-21
Original file line numberDiff line numberDiff line change
@@ -125,28 +125,17 @@ def __repr__(self: Array, /) -> str:
125125
mid = np.array2string(self._array, separator=', ', prefix=prefix, suffix=suffix)
126126
return prefix + mid + suffix
127127

128-
# This function is not required by the spec, but we implement it here for
129-
# convenience so that np.asarray(array_api_strict.Array) will work.
128+
# Disallow __array__, meaning calling `np.func()` on an array_api_strict
129+
# array will give an error. If we don't explicitly disallow it, NumPy
130+
# defaults to creating an object dtype array, which would lead to
131+
# confusing error messages at best and surprising bugs at worst.
132+
#
133+
# The alternative of course is to just support __array__, which is what we
134+
# used to do. But this isn't actually supported by the standard, so it can
135+
# lead to code assuming np.asarray(other_array) would always work in the
136+
# standard.
130137
def __array__(self, dtype: None | np.dtype[Any] = None, copy: None | bool = None) -> npt.NDArray[Any]:
131-
"""
132-
Warning: this method is NOT part of the array API spec. Implementers
133-
of other libraries need not include it, and users should not assume it
134-
will be present in other implementations.
135-
136-
"""
137-
# copy keyword is new in 2.0.0; for older versions don't use it
138-
# retry without that keyword.
139-
if np.__version__[0] < '2':
140-
return np.asarray(self._array, dtype=dtype)
141-
elif np.__version__.startswith('2.0.0-dev0'):
142-
# Handle dev version for which we can't know based on version
143-
# number whether or not the copy keyword is supported.
144-
try:
145-
return np.asarray(self._array, dtype=dtype, copy=copy)
146-
except TypeError:
147-
return np.asarray(self._array, dtype=dtype)
148-
else:
149-
return np.asarray(self._array, dtype=dtype, copy=copy)
138+
raise ValueError("Conversion from an array_api_strict array to a NumPy ndarray is not supported")
150139

151140
# These are various helper functions to make the array behavior match the
152141
# spec in places where it either deviates from or is more strict than

array_api_strict/tests/test_array_object.py

-7
Original file line numberDiff line numberDiff line change
@@ -342,13 +342,6 @@ def test_array_properties():
342342
assert isinstance(b.mT, Array)
343343
assert b.mT.shape == (3, 2)
344344

345-
def test___array__():
346-
a = ones((2, 3), dtype=int16)
347-
assert np.asarray(a) is a._array
348-
b = np.asarray(a, dtype=np.float64)
349-
assert np.all(np.equal(b, np.ones((2, 3), dtype=np.float64)))
350-
assert b.dtype == np.float64
351-
352345
def test_allow_newaxis():
353346
a = ones(5)
354347
indexed_a = a[None, :]

0 commit comments

Comments
 (0)