Skip to content

Commit 59daf9d

Browse files
committed
Move to Array API version 2023.12.
1 parent 38d1a67 commit 59daf9d

File tree

10 files changed

+171
-29
lines changed

10 files changed

+171
-29
lines changed

.github/workflows/ci.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ jobs:
154154
uses: actions/checkout@v4
155155
with:
156156
repository: data-apis/array-api-tests
157-
ref: '33f2d2ea2f3dd2b3ceeeb4519d55e08096184149' # Latest commit as of 2024-05-29
157+
ref: 'd295a0a66cd82a43e84c1b8d73ca198cc45e9d23' # Latest commit as of 2024-05-29
158158
submodules: 'true'
159159
path: 'array-api-tests'
160160
- name: Set up Python

ci/Numba-array-api-xfails.txt

+13-7
Original file line numberDiff line numberDiff line change
@@ -29,42 +29,48 @@ array_api_tests/test_has_names.py::test_has_names[linalg-tensordot]
2929
array_api_tests/test_has_names.py::test_has_names[linalg-trace]
3030
array_api_tests/test_has_names.py::test_has_names[linalg-vecdot]
3131
array_api_tests/test_has_names.py::test_has_names[linalg-vector_norm]
32+
array_api_tests/test_has_names.py::test_has_names[manipulation-repeat]
33+
array_api_tests/test_has_names.py::test_has_names[manipulation-tile]
3234
array_api_tests/test_has_names.py::test_has_names[set-unique_all]
3335
array_api_tests/test_has_names.py::test_has_names[set-unique_inverse]
3436
array_api_tests/test_has_names.py::test_has_names[creation-arange]
3537
array_api_tests/test_has_names.py::test_has_names[creation-from_dlpack]
3638
array_api_tests/test_has_names.py::test_has_names[creation-linspace]
3739
array_api_tests/test_has_names.py::test_has_names[creation-meshgrid]
40+
array_api_tests/test_has_names.py::test_has_names[searching-searchsorted]
3841
array_api_tests/test_has_names.py::test_has_names[sorting-argsort]
42+
array_api_tests/test_has_names.py::test_has_names[statistical-cumulative_sum]
3943
array_api_tests/test_has_names.py::test_has_names[data_type-isdtype]
4044
array_api_tests/test_has_names.py::test_has_names[array_method-__dlpack__]
4145
array_api_tests/test_has_names.py::test_has_names[array_method-__dlpack_device__]
4246
array_api_tests/test_has_names.py::test_has_names[array_method-__setitem__]
4347
array_api_tests/test_indexing_functions.py::test_take
4448
array_api_tests/test_linalg.py::test_vecdot
49+
array_api_tests/test_manipulation_functions.py::test_repeat
50+
array_api_tests/test_manipulation_functions.py::test_tile
4551
array_api_tests/test_operators_and_elementwise_functions.py::test_ceil
4652
array_api_tests/test_operators_and_elementwise_functions.py::test_trunc
4753
array_api_tests/test_searching_functions.py::test_argmax
4854
array_api_tests/test_searching_functions.py::test_argmin
55+
array_api_tests/test_searching_functions.py::test_searchsorted
4956
array_api_tests/test_set_functions.py::test_unique_all
5057
array_api_tests/test_set_functions.py::test_unique_inverse
58+
array_api_tests/test_statistical_functions.py::test_cumulative_sum
5159
array_api_tests/test_signatures.py::test_func_signature[unique_all]
5260
array_api_tests/test_signatures.py::test_func_signature[unique_inverse]
5361
array_api_tests/test_signatures.py::test_func_signature[arange]
62+
array_api_tests/test_signatures.py::test_func_signature[cumulative_sum]
5463
array_api_tests/test_signatures.py::test_func_signature[from_dlpack]
5564
array_api_tests/test_signatures.py::test_func_signature[linspace]
5665
array_api_tests/test_signatures.py::test_func_signature[meshgrid]
66+
array_api_tests/test_signatures.py::test_func_signature[repeat]
67+
array_api_tests/test_signatures.py::test_func_signature[tile]
5768
array_api_tests/test_signatures.py::test_func_signature[argsort]
69+
array_api_tests/test_signatures.py::test_func_signature[searchsorted]
5870
array_api_tests/test_signatures.py::test_func_signature[isdtype]
5971
array_api_tests/test_signatures.py::test_array_method_signature[__dlpack__]
6072
array_api_tests/test_signatures.py::test_array_method_signature[__dlpack_device__]
6173
array_api_tests/test_signatures.py::test_array_method_signature[__setitem__]
6274
array_api_tests/test_sorting_functions.py::test_argsort
6375
array_api_tests/test_sorting_functions.py::test_sort
64-
array_api_tests/test_special_cases.py::test_nan_propagation[max]
65-
array_api_tests/test_special_cases.py::test_nan_propagation[mean]
66-
array_api_tests/test_special_cases.py::test_nan_propagation[min]
67-
array_api_tests/test_special_cases.py::test_nan_propagation[prod]
68-
array_api_tests/test_special_cases.py::test_nan_propagation[std]
69-
array_api_tests/test_special_cases.py::test_nan_propagation[sum]
70-
array_api_tests/test_special_cases.py::test_nan_propagation[var]
76+
array_api_tests/test_special_cases.py::test_nan_propagation[cumulative_sum]

sparse/__init__.py

+3-9
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from ._version import __version__, __version_tuple__ # noqa: F401
66

7-
__array_api_version__ = "2022.12"
7+
__array_api_version__ = "2023.12"
88

99

1010
class BackendType(Enum):
@@ -45,13 +45,7 @@ def get_backend_module():
4545

4646

4747
def __getattr__(attr):
48-
if attr == "numba_backend":
49-
import sparse.numba_backend as backend_module
50-
51-
return backend_module
52-
if attr == "finch_backend":
53-
import sparse.finch_backend as backend_module
54-
55-
return backend_module
48+
if attr == "numba_backend" or attr == "finch_backend":
49+
raise AttributeError
5650

5751
return getattr(Backend.get_backend_module(), attr)

sparse/numba_backend/__init__.py

+25
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import sparse.numba_backend._info as _info
2+
13
from numpy import (
24
add,
35
bitwise_and,
@@ -9,6 +11,7 @@
911
complex64,
1012
complex128,
1113
conj,
14+
copysign,
1215
cos,
1316
cosh,
1417
divide,
@@ -23,6 +26,7 @@
2326
floor_divide,
2427
greater,
2528
greater_equal,
29+
hypot,
2630
iinfo,
2731
inf,
2832
int8,
@@ -41,6 +45,8 @@
4145
logical_not,
4246
logical_or,
4347
logical_xor,
48+
maximum,
49+
minimum,
4450
multiply,
4551
nan,
4652
negative,
@@ -50,6 +56,7 @@
5056
positive,
5157
remainder,
5258
sign,
59+
signbit,
5360
sin,
5461
sinh,
5562
sqrt,
@@ -119,6 +126,7 @@
119126
std,
120127
sum,
121128
tensordot,
129+
unstack,
122130
var,
123131
vecdot,
124132
zeros,
@@ -157,10 +165,16 @@
157165
where,
158166
)
159167
from ._dok import DOK
168+
from ._info import capabilities, default_device, default_dtypes, devices, dtypes
160169
from ._io import load_npz, save_npz
161170
from ._umath import elemwise
162171
from ._utils import random
163172

173+
174+
def __array_namespace_info__():
175+
return _info
176+
177+
164178
__all__ = [
165179
"COO",
166180
"DOK",
@@ -196,19 +210,25 @@
196210
"broadcast_arrays",
197211
"broadcast_to",
198212
"can_cast",
213+
"capabilities",
199214
"ceil",
200215
"clip",
201216
"complex128",
202217
"complex64",
203218
"concat",
204219
"concatenate",
205220
"conj",
221+
"copysign",
206222
"cos",
207223
"cosh",
224+
"default_device",
225+
"default_dtypes",
226+
"devices",
208227
"diagonal",
209228
"diagonalize",
210229
"divide",
211230
"dot",
231+
"dtypes",
212232
"e",
213233
"einsum",
214234
"elemwise",
@@ -230,6 +250,7 @@
230250
"full_like",
231251
"greater",
232252
"greater_equal",
253+
"hypot",
233254
"iinfo",
234255
"imag",
235256
"inf",
@@ -258,8 +279,10 @@
258279
"matmul",
259280
"matrix_transpose",
260281
"max",
282+
"maximum",
261283
"mean",
262284
"min",
285+
"minimum",
263286
"moveaxis",
264287
"multiply",
265288
"nan",
@@ -291,6 +314,7 @@
291314
"round",
292315
"save_npz",
293316
"sign",
317+
"signbit",
294318
"sin",
295319
"sinh",
296320
"sort",
@@ -314,6 +338,7 @@
314338
"uint8",
315339
"unique_counts",
316340
"unique_values",
341+
"unstack",
317342
"var",
318343
"vecdot",
319344
"where",

sparse/numba_backend/_common.py

+14-2
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ def tensordot(a, b, axes=2, *, return_type=None):
155155
bs = b.shape
156156
ndb = b.ndim
157157
equal = True
158-
if nda == 0 or ndb == 0:
158+
if not (builtins.all(-nda <= ax < nda for ax in axes_a) and builtins.all(-ndb <= ax < ndb for ax in axes_b)):
159159
pos = int(nda != 0)
160160
raise ValueError(f"Input {pos} operand does not have enough dimensions")
161161
if na != nb:
@@ -2146,10 +2146,22 @@ def reshape(x, /, shape, *, copy=None):
21462146
return x.reshape(shape=shape)
21472147

21482148

2149-
def astype(x, dtype, /, *, copy=True):
2149+
@_check_device
2150+
def astype(x, dtype, /, *, copy=True, device=None):
21502151
return x.astype(dtype, copy=copy)
21512152

21522153

2154+
def unstack(x, /, *, axis=0):
2155+
axis = normalize_axis(axis, x.ndim)
2156+
out = []
2157+
2158+
for i in range(x.shape[axis]):
2159+
idx = (slice(None),) * axis + (i,)
2160+
out.append(x[idx])
2161+
2162+
return tuple(out)
2163+
2164+
21532165
@_support_numpy
21542166
def squeeze(x, /, axis=None):
21552167
"""Remove singleton dimensions from array.

sparse/numba_backend/_coo/common.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -1011,7 +1011,7 @@ def _diagonal_idx(coordlist, axis1, axis2, offset):
10111011
return np.array([i for i in range(len(coordlist[axis1])) if coordlist[axis1][i] + offset == coordlist[axis2][i]])
10121012

10131013

1014-
def clip(a, a_min=None, a_max=None, out=None):
1014+
def clip(a, min=None, max=None, out=None):
10151015
"""
10161016
Clip (limit) the values in the array.
10171017
@@ -1042,19 +1042,19 @@ def clip(a, a_min=None, a_max=None, out=None):
10421042
--------
10431043
>>> import sparse
10441044
>>> x = sparse.COO.from_numpy([0, 0, 0, 1, 2, 3])
1045-
>>> sparse.clip(x, a_min=1).todense() # doctest: +NORMALIZE_WHITESPACE
1045+
>>> sparse.clip(x, min=1).todense() # doctest: +NORMALIZE_WHITESPACE
10461046
array([1, 1, 1, 1, 2, 3])
1047-
>>> sparse.clip(x, a_max=1).todense() # doctest: +NORMALIZE_WHITESPACE
1047+
>>> sparse.clip(x, max=1).todense() # doctest: +NORMALIZE_WHITESPACE
10481048
array([0, 0, 0, 1, 1, 1])
1049-
>>> sparse.clip(x, a_min=1, a_max=2).todense() # doctest: +NORMALIZE_WHITESPACE
1049+
>>> sparse.clip(x, min=1, max=2).todense() # doctest: +NORMALIZE_WHITESPACE
10501050
array([1, 1, 1, 1, 2, 2])
10511051
10521052
See Also
10531053
--------
10541054
numpy.clip : Equivalent NumPy function
10551055
"""
10561056
a = asCOO(a, name="clip")
1057-
return a.clip(a_min, a_max)
1057+
return a.clip(min, max)
10581058

10591059

10601060
def expand_dims(x, /, *, axis=0):

sparse/numba_backend/_info.py

+95
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
import numpy as np
2+
3+
from ._common import _check_device
4+
5+
__all__ = [
6+
"capabilities",
7+
"default_device",
8+
"default_dtypes",
9+
"devices",
10+
"dtypes",
11+
]
12+
13+
_CAPABILITIES = {
14+
"boolean indexing": True,
15+
"data-dependent shapes": True,
16+
}
17+
18+
_DEFAULT_DTYPES = {
19+
"cpu": {
20+
"real floating": np.dtype(np.float64),
21+
"complex floating": np.dtype(np.complex128),
22+
"integral": np.dtype(np.int64),
23+
"indexing": np.dtype(np.int64),
24+
}
25+
}
26+
27+
28+
def _get_dtypes_with_prefix(prefix: str):
29+
out = set()
30+
for a in np.__all__:
31+
if not a.startswith(prefix):
32+
continue
33+
try:
34+
dt = np.dtype(getattr(np, a))
35+
out.add(dt)
36+
except (ValueError, TypeError, AttributeError):
37+
pass
38+
return sorted(out)
39+
40+
41+
_DTYPES = {
42+
"cpu": {
43+
"bool": [np.bool_],
44+
"signed integer": _get_dtypes_with_prefix("int"),
45+
"unsigned integer": _get_dtypes_with_prefix("uint"),
46+
"real floating": _get_dtypes_with_prefix("float"),
47+
"complex floating": _get_dtypes_with_prefix("complex"),
48+
}
49+
}
50+
51+
for _dtdict in _DTYPES.values():
52+
_dtdict["integral"] = _dtdict["signed integer"] + _dtdict["unsigned integer"]
53+
_dtdict["numeric"] = _dtdict["integral"] + _dtdict["real floating"] + _dtdict["complex floating"]
54+
55+
del _dtdict
56+
57+
58+
def capabilities():
59+
return _CAPABILITIES
60+
61+
62+
def default_device():
63+
return "cpu"
64+
65+
66+
@_check_device
67+
def default_dtypes(*, device=None):
68+
if device is None:
69+
device = default_device()
70+
return _DEFAULT_DTYPES[device]
71+
72+
73+
def devices():
74+
return ["cpu"]
75+
76+
77+
@_check_device
78+
def dtypes(*, device=None, kind=None):
79+
if device is None:
80+
device = default_device()
81+
82+
device_dtypes = _DTYPES[device]
83+
84+
if kind is None:
85+
return device_dtypes
86+
87+
if isinstance(kind, str):
88+
return device_dtypes[kind]
89+
90+
out = {}
91+
92+
for k in kind:
93+
out[k] = device_dtypes[k]
94+
95+
return out

sparse/numba_backend/_sparse_array.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -607,10 +607,12 @@ def clip(self, min=None, max=None, out=None):
607607
sparse.clip : For full documentation and more details.
608608
numpy.clip : Equivalent NumPy function.
609609
"""
610-
if min is None and max is None:
611-
raise ValueError("One of max or min must be given.")
612610
if out is not None and not isinstance(out, tuple):
613611
out = (out,)
612+
if min is None and max is None:
613+
if out is not None:
614+
return self.__array_ufunc__(np.identity, "__call__", self, out=out)
615+
return self
614616
return self.__array_ufunc__(np.clip, "__call__", self, a_min=min, a_max=max, out=out)
615617

616618
def astype(self, dtype, casting="unsafe", copy=True):

0 commit comments

Comments
 (0)