Commit f9887b4

Merge branch 'dask-fft' of github.com:lithomas1/array-api-compat into dask-fft
2 parents: ec6dcc4 + 3ad4af2

33 files changed: +1904 / -110 lines

.github/workflows/array-api-tests.yml

Lines changed: 1 addition & 0 deletions

@@ -74,6 +74,7 @@ jobs:
         if: "! ((matrix.python-version == '3.11' || matrix.python-version == '3.12') && inputs.package-name == 'numpy' && contains(inputs.package-version, '1.21'))"
         env:
           ARRAY_API_TESTS_MODULE: array_api_compat.${{ inputs.module-name || inputs.package-name }}
+          ARRAY_API_TESTS_VERSION: 2023.12
           # This enables the NEP 50 type promotion behavior (without it a lot of
           # tests fail on bad scalar type promotion behavior)
           NPY_PROMOTION_STATE: weak

.github/workflows/publish-package.yml

Lines changed: 2 additions & 2 deletions

@@ -94,7 +94,7 @@ jobs:
       if: >-
         (github.event_name == 'push' && startsWith(github.ref, 'refs/tags'))
         || (github.event_name == 'workflow_dispatch' && github.event.inputs.publish == 'true')
-      uses: pypa/gh-action-pypi-publish@v1.9.0
+      uses: pypa/gh-action-pypi-publish@v1.10.1
       with:
         repository-url: https://test.pypi.org/legacy/
         print-hash: true

@@ -107,6 +107,6 @@ jobs:

     - name: Publish distribution 📦 to PyPI
       if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags')
-      uses: pypa/gh-action-pypi-publish@v1.9.0
+      uses: pypa/gh-action-pypi-publish@v1.10.1
       with:
         print-hash: true

array_api_compat/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -17,6 +17,6 @@
 this implementation for the default when working with NumPy arrays.

 """
-__version__ = '1.8'
+__version__ = '1.9'

 from .common import * # noqa: F401, F403

array_api_compat/common/_aliases.py

Lines changed: 56 additions & 46 deletions

@@ -12,7 +12,7 @@
 from typing import NamedTuple
 import inspect

-from ._helpers import array_namespace, _check_device
+from ._helpers import array_namespace, _check_device, device, is_torch_array

 # These functions are modified from the NumPy versions.

@@ -264,6 +264,38 @@ def var(
 ) -> ndarray:
     return xp.var(x, axis=axis, ddof=correction, keepdims=keepdims, **kwargs)

+# cumulative_sum is renamed from cumsum, and adds the include_initial keyword
+# argument
+
+def cumulative_sum(
+    x: ndarray,
+    /,
+    xp,
+    *,
+    axis: Optional[int] = None,
+    dtype: Optional[Dtype] = None,
+    include_initial: bool = False,
+    **kwargs
+) -> ndarray:
+    wrapped_xp = array_namespace(x)
+
+    # TODO: The standard is not clear about what should happen when x.ndim == 0.
+    if axis is None:
+        if x.ndim > 1:
+            raise ValueError("axis must be specified in cumulative_sum for more than one dimension")
+        axis = 0
+
+    res = xp.cumsum(x, axis=axis, dtype=dtype, **kwargs)
+
+    # np.cumsum does not support include_initial
+    if include_initial:
+        initial_shape = list(x.shape)
+        initial_shape[axis] = 1
+        res = xp.concatenate(
+            [wrapped_xp.zeros(shape=initial_shape, dtype=res.dtype, device=device(res)), res],
+            axis=axis,
+        )
+    return res

 # The min and max argument names in clip are different and not optional in numpy, and type
 # promotion behavior is different.
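
A minimal usage sketch of the new cumulative_sum wrapper, assuming it is re-exported through the wrapped NumPy namespace the same way the existing aliases are:

    import array_api_compat.numpy as xp

    x = xp.asarray([1, 2, 3])
    xp.cumulative_sum(x)                        # array([1, 3, 6])
    xp.cumulative_sum(x, include_initial=True)  # array([0, 1, 3, 6]), one element longer along the axis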
@@ -281,10 +313,11 @@ def _isscalar(a):
         return isinstance(a, (int, float, type(None)))
     min_shape = () if _isscalar(min) else min.shape
     max_shape = () if _isscalar(max) else max.shape
-    result_shape = xp.broadcast_shapes(x.shape, min_shape, max_shape)

     wrapped_xp = array_namespace(x)

+    result_shape = xp.broadcast_shapes(x.shape, min_shape, max_shape)
+
     # np.clip does type promotion but the array API clip requires that the
     # output have the same dtype as x. We do this instead of just downcasting
     # the result of xp.clip() to handle some corner cases better (e.g.,

@@ -305,20 +338,26 @@ def _isscalar(a):

     # At least handle the case of Python integers correctly (see
     # https://github.com/numpy/numpy/pull/26892).
-    if type(min) is int and min <= xp.iinfo(x.dtype).min:
+    if type(min) is int and min <= wrapped_xp.iinfo(x.dtype).min:
         min = None
-    if type(max) is int and max >= xp.iinfo(x.dtype).max:
+    if type(max) is int and max >= wrapped_xp.iinfo(x.dtype).max:
         max = None

     if out is None:
-        out = wrapped_xp.asarray(xp.broadcast_to(x, result_shape), copy=True)
+        out = wrapped_xp.asarray(xp.broadcast_to(x, result_shape),
+                                 copy=True, device=device(x))
     if min is not None:
-        a = xp.broadcast_to(xp.asarray(min), result_shape)
+        if is_torch_array(x) and x.dtype == xp.float64 and _isscalar(min):
+            # Avoid loss of precision due to torch defaulting to float32
+            min = wrapped_xp.asarray(min, dtype=xp.float64)
+        a = xp.broadcast_to(wrapped_xp.asarray(min, device=device(x)), result_shape)
         ia = (out < a) | xp.isnan(a)
         # torch requires an explicit cast here
         out[ia] = wrapped_xp.astype(a[ia], out.dtype)
     if max is not None:
-        b = xp.broadcast_to(xp.asarray(max), result_shape)
+        if is_torch_array(x) and x.dtype == xp.float64 and _isscalar(max):
+            max = wrapped_xp.asarray(max, dtype=xp.float64)
+        b = xp.broadcast_to(wrapped_xp.asarray(max, device=device(x)), result_shape)
         ib = (out > b) | xp.isnan(b)
         out[ib] = wrapped_xp.astype(b[ib], out.dtype)
     # Return a scalar for 0-D
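
The float64 special case added above works around PyTorch's default dtype: a Python float bound converted without an explicit dtype is truncated to float32. A small sketch of the underlying behavior this guards against (assumes PyTorch is installed):

    import torch

    bound = 0.1234567890123456
    torch.asarray(bound).dtype                       # torch.float32 (default dtype) -> precision lost
    torch.asarray(bound, dtype=torch.float64).dtype  # torch.float64 -> full precision kept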
@@ -389,42 +428,6 @@ def nonzero(x: ndarray, /, xp, **kwargs) -> Tuple[ndarray, ...]:
         raise ValueError("nonzero() does not support zero-dimensional arrays")
     return xp.nonzero(x, **kwargs)

-# sum() and prod() should always upcast when dtype=None
-def sum(
-    x: ndarray,
-    /,
-    xp,
-    *,
-    axis: Optional[Union[int, Tuple[int, ...]]] = None,
-    dtype: Optional[Dtype] = None,
-    keepdims: bool = False,
-    **kwargs,
-) -> ndarray:
-    # `xp.sum` already upcasts integers, but not floats or complexes
-    if dtype is None:
-        if x.dtype == xp.float32:
-            dtype = xp.float64
-        elif x.dtype == xp.complex64:
-            dtype = xp.complex128
-    return xp.sum(x, axis=axis, dtype=dtype, keepdims=keepdims, **kwargs)
-
-def prod(
-    x: ndarray,
-    /,
-    xp,
-    *,
-    axis: Optional[Union[int, Tuple[int, ...]]] = None,
-    dtype: Optional[Dtype] = None,
-    keepdims: bool = False,
-    **kwargs,
-) -> ndarray:
-    if dtype is None:
-        if x.dtype == xp.float32:
-            dtype = xp.float64
-        elif x.dtype == xp.complex64:
-            dtype = xp.complex128
-    return xp.prod(x, dtype=dtype, axis=axis, keepdims=keepdims, **kwargs)
-
 # ceil, floor, and trunc return integers for integer inputs

 def ceil(x: ndarray, /, xp, **kwargs) -> ndarray:
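
With these wrappers removed, sum and prod defer to the underlying namespace's default dtype rules. A hedged sketch of the resulting behavior under NumPy:

    import array_api_compat.numpy as xp

    x = xp.asarray([1.0, 2.0], dtype=xp.float32)
    xp.sum(x).dtype   # float32 now; the removed wrapper would have upcast to float64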
@@ -521,10 +524,17 @@ def isdtype(
         # array_api_strict implementation will be very strict.
         return dtype == kind

+# unstack is a new function in the 2023.12 array API standard
+def unstack(x: ndarray, /, xp, *, axis: int = 0) -> Tuple[ndarray, ...]:
+    if x.ndim == 0:
+        raise ValueError("Input array must be at least 1-d.")
+    return tuple(xp.moveaxis(x, axis, 0))
+
 __all__ = ['arange', 'empty', 'empty_like', 'eye', 'full', 'full_like',
            'linspace', 'ones', 'ones_like', 'zeros', 'zeros_like',
            'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult',
            'unique_all', 'unique_counts', 'unique_inverse', 'unique_values',
-           'astype', 'std', 'var', 'clip', 'permute_dims', 'reshape', 'argsort',
-           'sort', 'nonzero', 'sum', 'prod', 'ceil', 'floor', 'trunc',
-           'matmul', 'matrix_transpose', 'tensordot', 'vecdot', 'isdtype']
+           'astype', 'std', 'var', 'cumulative_sum', 'clip', 'permute_dims',
+           'reshape', 'argsort', 'sort', 'nonzero', 'ceil', 'floor', 'trunc',
+           'matmul', 'matrix_transpose', 'tensordot', 'vecdot', 'isdtype',
+           'unstack']
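
A minimal sketch of unstack, again assuming it is exposed through the wrapped namespaces like the other aliases:

    import array_api_compat.numpy as xp

    m = xp.asarray([[1, 2], [3, 4]])
    xp.unstack(m, axis=0)   # (array([1, 2]), array([3, 4]))
    xp.unstack(m, axis=1)   # (array([1, 3]), array([2, 4]))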

array_api_compat/common/_helpers.py

Lines changed: 168 additions & 1 deletion

@@ -202,7 +202,6 @@ def is_jax_array(x):

     return isinstance(x, jax.Array) or _is_jax_zero_gradient_array(x)

-
 def is_pydata_sparse_array(x) -> bool:
     """
     Return True if `x` is an array from the `sparse` package.
@@ -255,6 +254,166 @@ def is_array_api_obj(x):
         or is_pydata_sparse_array(x) \
         or hasattr(x, '__array_namespace__')

+def _compat_module_name():
+    assert __name__.endswith('.common._helpers')
+    return __name__.removesuffix('.common._helpers')
+
+def is_numpy_namespace(xp) -> bool:
+    """
+    Returns True if `xp` is a NumPy namespace.
+
+    This includes both NumPy itself and the version wrapped by array-api-compat.
+
+    See Also
+    --------
+
+    array_namespace
+    is_cupy_namespace
+    is_torch_namespace
+    is_ndonnx_namespace
+    is_dask_namespace
+    is_jax_namespace
+    is_pydata_sparse_namespace
+    is_array_api_strict_namespace
+    """
+    return xp.__name__ in {'numpy', _compat_module_name() + '.numpy'}
+
+def is_cupy_namespace(xp) -> bool:
+    """
+    Returns True if `xp` is a CuPy namespace.
+
+    This includes both CuPy itself and the version wrapped by array-api-compat.
+
+    See Also
+    --------
+
+    array_namespace
+    is_numpy_namespace
+    is_torch_namespace
+    is_ndonnx_namespace
+    is_dask_namespace
+    is_jax_namespace
+    is_pydata_sparse_namespace
+    is_array_api_strict_namespace
+    """
+    return xp.__name__ in {'cupy', _compat_module_name() + '.cupy'}
+
+def is_torch_namespace(xp) -> bool:
+    """
+    Returns True if `xp` is a PyTorch namespace.
+
+    This includes both PyTorch itself and the version wrapped by array-api-compat.
+
+    See Also
+    --------
+
+    array_namespace
+    is_numpy_namespace
+    is_cupy_namespace
+    is_ndonnx_namespace
+    is_dask_namespace
+    is_jax_namespace
+    is_pydata_sparse_namespace
+    is_array_api_strict_namespace
+    """
+    return xp.__name__ in {'torch', _compat_module_name() + '.torch'}
+
+
+def is_ndonnx_namespace(xp):
+    """
+    Returns True if `xp` is an NDONNX namespace.
+
+    See Also
+    --------
+
+    array_namespace
+    is_numpy_namespace
+    is_cupy_namespace
+    is_torch_namespace
+    is_dask_namespace
+    is_jax_namespace
+    is_pydata_sparse_namespace
+    is_array_api_strict_namespace
+    """
+    return xp.__name__ == 'ndonnx'
+
+def is_dask_namespace(xp):
+    """
+    Returns True if `xp` is a Dask namespace.
+
+    This includes both ``dask.array`` itself and the version wrapped by array-api-compat.
+
+    See Also
+    --------
+
+    array_namespace
+    is_numpy_namespace
+    is_cupy_namespace
+    is_torch_namespace
+    is_ndonnx_namespace
+    is_jax_namespace
+    is_pydata_sparse_namespace
+    is_array_api_strict_namespace
+    """
+    return xp.__name__ in {'dask.array', _compat_module_name() + '.dask.array'}
+
+def is_jax_namespace(xp):
+    """
+    Returns True if `xp` is a JAX namespace.
+
+    This includes ``jax.numpy`` and ``jax.experimental.array_api`` which existed in
+    older versions of JAX.
+
+    See Also
+    --------
+
+    array_namespace
+    is_numpy_namespace
+    is_cupy_namespace
+    is_torch_namespace
+    is_ndonnx_namespace
+    is_dask_namespace
+    is_pydata_sparse_namespace
+    is_array_api_strict_namespace
+    """
+    return xp.__name__ in {'jax.numpy', 'jax.experimental.array_api'}
+
+def is_pydata_sparse_namespace(xp):
+    """
+    Returns True if `xp` is a pydata/sparse namespace.
+
+    See Also
+    --------
+
+    array_namespace
+    is_numpy_namespace
+    is_cupy_namespace
+    is_torch_namespace
+    is_ndonnx_namespace
+    is_dask_namespace
+    is_jax_namespace
+    is_array_api_strict_namespace
+    """
+    return xp.__name__ == 'sparse'
+
+def is_array_api_strict_namespace(xp):
+    """
+    Returns True if `xp` is an array-api-strict namespace.
+
+    See Also
+    --------
+
+    array_namespace
+    is_numpy_namespace
+    is_cupy_namespace
+    is_torch_namespace
+    is_ndonnx_namespace
+    is_dask_namespace
+    is_jax_namespace
+    is_pydata_sparse_namespace
+    """
+    return xp.__name__ == 'array_api_strict'
+
 def _check_api_version(api_version):
     if api_version == '2021.12':
         warnings.warn("The 2021.12 version of the array API specification was requested but the returned namespace is actually version 2022.12")
@@ -643,13 +802,21 @@ def size(x):
     "device",
     "get_namespace",
     "is_array_api_obj",
+    "is_array_api_strict_namespace",
     "is_cupy_array",
+    "is_cupy_namespace",
     "is_dask_array",
+    "is_dask_namespace",
     "is_jax_array",
+    "is_jax_namespace",
     "is_numpy_array",
+    "is_numpy_namespace",
     "is_torch_array",
+    "is_torch_namespace",
     "is_ndonnx_array",
+    "is_ndonnx_namespace",
     "is_pydata_sparse_array",
+    "is_pydata_sparse_namespace",
     "size",
     "to_device",
 ]
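
A short sketch of the new namespace predicates; they accept both the raw library namespace and its array-api-compat wrapper (assumes NumPy is installed):

    import numpy as np
    from array_api_compat import array_namespace, is_numpy_namespace, is_torch_namespace

    xp = array_namespace(np.asarray([1.0, 2.0]))
    is_numpy_namespace(np)   # True
    is_numpy_namespace(xp)   # True for the wrapped namespace as well
    is_torch_namespace(xp)   # False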

array_api_compat/common/_linalg.py

Lines changed: 0 additions & 5 deletions

@@ -147,11 +147,6 @@ def diagonal(x: ndarray, /, xp, *, offset: int = 0, **kwargs) -> ndarray:
     return xp.diagonal(x, offset=offset, axis1=-2, axis2=-1, **kwargs)

 def trace(x: ndarray, /, xp, *, offset: int = 0, dtype=None, **kwargs) -> ndarray:
-    if dtype is None:
-        if x.dtype == xp.float32:
-            dtype = xp.float64
-        elif x.dtype == xp.complex64:
-            dtype = xp.complex128
     return xp.asarray(xp.trace(x, offset=offset, dtype=dtype, axis1=-2, axis2=-1, **kwargs))

 __all__ = ['cross', 'matmul', 'outer', 'tensordot', 'EighResult',
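
Like the sum/prod change above, trace now leaves dtype selection to the underlying namespace when dtype is None. A hedged sketch of the effect under NumPy:

    import array_api_compat.numpy as xp

    x = xp.eye(3, dtype=xp.float32)
    xp.linalg.trace(x).dtype   # float32; the removed special case would have upcast to float64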
