Skip to content

Commit 86a402b

Browse files
committed
Merge branch 'main' into signbit-nan
2 parents 93f9048 + 44bf2af commit 86a402b

13 files changed

+139
-10
lines changed

.github/workflows/array-api-tests-torch.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,5 @@ jobs:
1010
# Proper linalg testing will require
1111
# https://github.com/data-apis/array-api-tests/pull/101
1212
pytest-extra-args: "--disable-extension linalg"
13+
extra-env-vars: |
14+
ARRAY_API_TESTS_SKIP_DTYPES=uint16,uint32,uint64

.github/workflows/array-api-tests.yml

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,10 @@ on:
2727
skips-file-extra:
2828
required: false
2929
type: string
30-
30+
extra-env-vars:
31+
required: false
32+
type: string
33+
description: "Multiline string of environment variables to set for the test run."
3134

3235
env:
3336
PYTEST_ARGS: "--max-examples 200 -v -rxXfE --ci ${{ inputs.pytest-extra-args }} --hypothesis-disable-deadline"
@@ -54,6 +57,11 @@ jobs:
5457
uses: actions/setup-python@v5
5558
with:
5659
python-version: ${{ matrix.python-version }}
60+
- name: Set Extra Environment Variables
61+
# Set additional environment variables if provided
62+
if: inputs.extra-env-vars
63+
run: |
64+
echo "${{ inputs.extra-env-vars }}" >> $GITHUB_ENV
5765
- name: Install dependencies
5866
# NumPy 1.21 doesn't support Python 3.11. There doesn't seem to be a way
5967
# to put this in the numpy 1.21 config file.

.github/workflows/tests.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ jobs:
3333
- name: Run Tests
3434
run: |
3535
if [[ "${{ matrix.numpy-version }}" == "1.21" || "${{ matrix.numpy-version }}" == "dev" ]]; then
36-
PYTEST_EXTRA=(-k "numpy and not jax and not torch and not dask")
36+
PYTEST_EXTRA=(-k "numpy and not jax and not torch and not dask and not sparse")
3737
fi
3838
pytest -v "${PYTEST_EXTRA[@]}"
3939

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22

33
This is a small wrapper around common array libraries that is compatible with
44
the [Array API standard](https://data-apis.org/array-api/latest/). Currently,
5-
NumPy, CuPy, PyTorch, Dask, and JAX are supported. If you want support for other array
6-
libraries, or if you encounter any issues, please [open an
5+
NumPy, CuPy, PyTorch, Dask, JAX and `sparse` are supported. If you want support
6+
for other array libraries, or if you encounter any issues, please [open an
77
issue](https://github.com/data-apis/array-api-compat/issues).
88

99
See the documentation for more details https://data-apis.org/array-api-compat/

array_api_compat/common/_helpers.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ def is_numpy_array(x):
5050
is_torch_array
5151
is_dask_array
5252
is_jax_array
53+
is_pydata_sparse
5354
"""
5455
# Avoid importing NumPy if it isn't already
5556
if 'numpy' not in sys.modules:
@@ -79,6 +80,7 @@ def is_cupy_array(x):
7980
is_torch_array
8081
is_dask_array
8182
is_jax_array
83+
is_pydata_sparse
8284
"""
8385
# Avoid importing NumPy if it isn't already
8486
if 'cupy' not in sys.modules:
@@ -105,6 +107,7 @@ def is_torch_array(x):
105107
is_cupy_array
106108
is_dask_array
107109
is_jax_array
110+
is_pydata_sparse
108111
"""
109112
# Avoid importing torch if it isn't already
110113
if 'torch' not in sys.modules:
@@ -131,6 +134,7 @@ def is_dask_array(x):
131134
is_cupy_array
132135
is_torch_array
133136
is_jax_array
137+
is_pydata_sparse
134138
"""
135139
# Avoid importing dask if it isn't already
136140
if 'dask.array' not in sys.modules:
@@ -157,6 +161,7 @@ def is_jax_array(x):
157161
is_cupy_array
158162
is_torch_array
159163
is_dask_array
164+
is_pydata_sparse
160165
"""
161166
# Avoid importing jax if it isn't already
162167
if 'jax' not in sys.modules:
@@ -166,6 +171,35 @@ def is_jax_array(x):
166171

167172
return isinstance(x, jax.Array) or _is_jax_zero_gradient_array(x)
168173

174+
175+
def is_pydata_sparse(x) -> bool:
176+
"""
177+
Return True if `x` is an array from the `sparse` package.
178+
179+
This function does not import `sparse` if it has not already been imported
180+
and is therefore cheap to use.
181+
182+
183+
See Also
184+
--------
185+
186+
array_namespace
187+
is_array_api_obj
188+
is_numpy_array
189+
is_cupy_array
190+
is_torch_array
191+
is_dask_array
192+
is_jax_array
193+
"""
194+
# Avoid importing jax if it isn't already
195+
if 'sparse' not in sys.modules:
196+
return False
197+
198+
import sparse
199+
200+
# TODO: Account for other backends.
201+
return isinstance(x, sparse.SparseArray)
202+
169203
def is_array_api_obj(x):
170204
"""
171205
Return True if `x` is an array API compatible array object.
@@ -185,6 +219,7 @@ def is_array_api_obj(x):
185219
or is_torch_array(x) \
186220
or is_dask_array(x) \
187221
or is_jax_array(x) \
222+
or is_pydata_sparse(x) \
188223
or hasattr(x, '__array_namespace__')
189224

190225
def _check_api_version(api_version):
@@ -253,6 +288,7 @@ def your_function(x, y):
253288
is_torch_array
254289
is_dask_array
255290
is_jax_array
291+
is_pydata_sparse
256292
257293
"""
258294
if use_compat not in [None, True, False]:
@@ -312,6 +348,15 @@ def your_function(x, y):
312348
# not have a wrapper submodule for it.
313349
import jax.experimental.array_api as jnp
314350
namespaces.add(jnp)
351+
elif is_pydata_sparse(x):
352+
if use_compat is True:
353+
_check_api_version(api_version)
354+
raise ValueError("`sparse` does not have an array-api-compat wrapper")
355+
else:
356+
import sparse
357+
# `sparse` is already an array namespace. We do not have a wrapper
358+
# submodule for it.
359+
namespaces.add(sparse)
315360
elif hasattr(x, '__array_namespace__'):
316361
if use_compat is True:
317362
raise ValueError("The given array does not have an array-api-compat wrapper")
@@ -406,8 +451,23 @@ def device(x: Array, /) -> Device:
406451
return x.device()
407452
else:
408453
return x.device
454+
elif is_pydata_sparse(x):
455+
# `sparse` will gain `.device`, so check for this first.
456+
x_device = getattr(x, 'device', None)
457+
if x_device is not None:
458+
return x_device
459+
# Everything but DOK has this attr.
460+
try:
461+
inner = x.data
462+
except AttributeError:
463+
return "cpu"
464+
# Return the device of the constituent array
465+
return device(inner)
409466
return x.device
410467

468+
# Prevent shadowing, used below
469+
_device = device
470+
411471
# Based on cupy.array_api.Array.to_device
412472
def _cupy_to_device(x, device, /, stream=None):
413473
import cupy as cp
@@ -523,6 +583,10 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]]
523583
# This import adds to_device to x
524584
import jax.experimental.array_api # noqa: F401
525585
return x.to_device(device, stream=stream)
586+
elif is_pydata_sparse(x) and device == _device(x):
587+
# Perform trivial check to return the same array if
588+
# device is same instead of err-ing.
589+
return x
526590
return x.to_device(device, stream=stream)
527591

528592
def size(x):
@@ -549,6 +613,7 @@ def size(x):
549613
"is_jax_array",
550614
"is_numpy_array",
551615
"is_torch_array",
616+
"is_pydata_sparse",
552617
"size",
553618
"to_device",
554619
]

array_api_compat/torch/linalg.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,22 @@ def vecdot(x1: array, x2: array, /, *, axis: int = -1, **kwargs) -> array:
6060

6161
def solve(x1: array, x2: array, /, **kwargs) -> array:
6262
x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
63+
# Torch tries to emulate NumPy 1 solve behavior by using batched 1-D solve
64+
# whenever
65+
# 1. x1.ndim - 1 == x2.ndim
66+
# 2. x1.shape[:-1] == x2.shape
67+
#
68+
# See linalg_solve_is_vector_rhs in
69+
# aten/src/ATen/native/LinearAlgebraUtils.h and
70+
# TORCH_META_FUNC(_linalg_solve_ex) in
71+
# aten/src/ATen/native/BatchLinearAlgebra.cpp in the PyTorch source code.
72+
#
73+
# The easiest way to work around this is to prepend a size 1 dimension to
74+
# x2, since x2 is already one dimension less than x1.
75+
#
76+
# See https://github.com/pytorch/pytorch/issues/52915
77+
if x2.ndim != 1 and x1.ndim - 1 == x2.ndim and x1.shape[:-1] == x2.shape:
78+
x2 = x2[None]
6379
return torch.linalg.solve(x1, x2, **kwargs)
6480

6581
# torch.trace doesn't support the offset argument and doesn't support stacking
@@ -78,7 +94,23 @@ def vector_norm(
7894
) -> array:
7995
# torch.vector_norm incorrectly treats axis=() the same as axis=None
8096
if axis == ():
81-
keepdims = True
97+
out = kwargs.get('out')
98+
if out is None:
99+
dtype = None
100+
if x.dtype == torch.complex64:
101+
dtype = torch.float32
102+
elif x.dtype == torch.complex128:
103+
dtype = torch.float64
104+
105+
out = torch.zeros_like(x, dtype=dtype)
106+
107+
# The norm of a single scalar works out to abs(x) in every case except
108+
# for ord=0, which is x != 0.
109+
if ord == 0:
110+
out[:] = (x != 0)
111+
else:
112+
out[:] = torch.abs(x)
113+
return out
82114
return torch.linalg.vector_norm(x, ord=ord, axis=axis, keepdim=keepdims, **kwargs)
83115

84116
__all__ = linalg_all + ['outer', 'matmul', 'matrix_transpose', 'tensordot',

docs/changelog.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,14 @@
44

55
## Major Changes
66

7+
- Add support for `sparse`. Note that unlike other array libraries,
8+
array-api-compat does not contain any wrappers for `sparse` functions. All
9+
`sparse` array API support is in `sparse` itself. Thus, there is no
10+
`array_api_compat.sparse` submodule, and
11+
`array_namespace(<pydata/sparse array>)` returns the `sparse` module.
12+
13+
- Added the function `is_pydata_sparse(x)`.
14+
715
- Drop support for Python 3.8.
816

917
- NumPy 2.0 is now left completely unwrapped.

docs/supported-array-libraries.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,3 +132,6 @@ For `linalg`, several methods are missing, for example:
132132
Other methods may only be partially implemented or return incorrect results at times.
133133

134134
The minimum supported Dask version is 2023.12.0.
135+
136+
## [`sparse`](https://sparse.pydata.org/en/stable/)
137+
Similar to JAX, `sparse` Array API support is contained directly in `sparse`.

requirements-dev.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,4 @@ jax[cpu]
44
numpy
55
pytest
66
torch
7+
sparse >=0.15.1

tests/_helpers.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,26 @@
11
from importlib import import_module
2+
import sys
23

34
import pytest
45

56
wrapped_libraries = ["cupy", "torch", "dask.array"]
6-
all_libraries = wrapped_libraries + ["numpy", "jax.numpy"]
7+
all_libraries = wrapped_libraries + ["numpy", "jax.numpy", "sparse"]
78
import numpy as np
89
if np.__version__[0] == '1':
910
wrapped_libraries.append("numpy")
1011

12+
# `sparse` added array API support as of Python 3.10.
13+
if sys.version_info >= (3, 10):
14+
all_libraries.append('sparse')
15+
1116
def import_(library, wrapper=False):
1217
if library == 'cupy':
1318
pytest.importorskip(library)
1419
if wrapper:
1520
if 'jax' in library:
1621
library = 'jax.experimental.array_api'
22+
elif library.startswith('sparse'):
23+
library = 'sparse'
1724
else:
1825
library = 'array_api_compat.' + library
1926

tests/test_array_namespace.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ def test_array_namespace(library, api_version, use_compat):
1919
xp = import_(library)
2020

2121
array = xp.asarray([1.0, 2.0, 3.0])
22-
if use_compat is True and library in ['array_api_strict', 'jax.numpy']:
22+
if use_compat is True and library in {'array_api_strict', 'jax.numpy', 'sparse'}:
2323
pytest.raises(ValueError, lambda: array_namespace(array, use_compat=use_compat))
2424
return
2525
namespace = array_api_compat.array_namespace(array, api_version=api_version, use_compat=use_compat)

tests/test_common.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from array_api_compat import (is_numpy_array, is_cupy_array, is_torch_array, # noqa: F401
2-
is_dask_array, is_jax_array)
2+
is_dask_array, is_jax_array, is_pydata_sparse)
33

44
from array_api_compat import is_array_api_obj, device, to_device
55

@@ -16,6 +16,7 @@
1616
'torch': 'is_torch_array',
1717
'dask.array': 'is_dask_array',
1818
'jax.numpy': 'is_jax_array',
19+
'sparse': 'is_pydata_sparse',
1920
}
2021

2122
@pytest.mark.parametrize('library', is_functions.keys())
@@ -76,6 +77,8 @@ def test_asarray_cross_library(source_library, target_library, request):
7677
if source_library == "cupy" and target_library != "cupy":
7778
# cupy explicitly disallows implicit conversions to CPU
7879
pytest.skip(reason="cupy does not support implicit conversion to CPU")
80+
elif source_library == "sparse" and target_library != "sparse":
81+
pytest.skip(reason="`sparse` does not allow implicit densification")
7982
src_lib = import_(source_library, wrapper=True)
8083
tgt_lib = import_(target_library, wrapper=True)
8184
is_tgt_type = globals()[is_functions[target_library]]

tests/test_no_dependencies.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def _test_dependency(mod):
3333

3434
# array-api-strict is an example of an array API library that isn't
3535
# wrapped by array-api-compat.
36-
if "strict" not in mod:
36+
if "strict" not in mod and mod != "sparse":
3737
is_mod_array = getattr(array_api_compat, f"is_{mod.split('.')[0]}_array")
3838
assert not is_mod_array(a)
3939
assert mod not in sys.modules
@@ -50,7 +50,7 @@ def _test_dependency(mod):
5050
# Y (except most array libraries actually do themselves depend on numpy).
5151

5252
@pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask.array",
53-
"jax.numpy", "array_api_strict"])
53+
"jax.numpy", "sparse", "array_api_strict"])
5454
def test_numpy_dependency(library):
5555
# This import is here because it imports numpy
5656
from ._helpers import import_

0 commit comments

Comments
 (0)