Skip to content

Commit 522a608

Browse files
authored
Merge pull request #190 from asmeurer/sign-fix
Add a wrapper for sign for NumPy-likes
2 parents 5affae5 + 8dee4d6 commit 522a608

File tree

5 files changed

+24
-10
lines changed

5 files changed

+24
-10
lines changed

array_api_compat/common/_aliases.py

+17-2
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from typing import NamedTuple
1313
import inspect
1414

15-
from ._helpers import array_namespace, _check_device, device, is_torch_array
15+
from ._helpers import array_namespace, _check_device, device, is_torch_array, is_cupy_namespace
1616

1717
# These functions are modified from the NumPy versions.
1818

@@ -530,11 +530,26 @@ def unstack(x: ndarray, /, xp, *, axis: int = 0) -> Tuple[ndarray, ...]:
530530
raise ValueError("Input array must be at least 1-d.")
531531
return tuple(xp.moveaxis(x, axis, 0))
532532

533+
# numpy 1.26 does not use the standard definition for sign on complex numbers
534+
535+
def sign(x: ndarray, /, xp, **kwargs) -> ndarray:
536+
if isdtype(x.dtype, 'complex floating', xp=xp):
537+
out = (x/xp.abs(x, **kwargs))[...]
538+
# sign(0) = 0 but the above formula would give nan
539+
out[x == 0+0j] = 0+0j
540+
else:
541+
out = xp.sign(x, **kwargs)
542+
# CuPy sign() does not propagate nans. See
543+
# https://github.com/data-apis/array-api-compat/issues/136
544+
if is_cupy_namespace(xp) and isdtype(x.dtype, 'real floating', xp=xp):
545+
out[xp.isnan(x)] = xp.nan
546+
return out[()]
547+
533548
__all__ = ['arange', 'empty', 'empty_like', 'eye', 'full', 'full_like',
534549
'linspace', 'ones', 'ones_like', 'zeros', 'zeros_like',
535550
'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult',
536551
'unique_all', 'unique_counts', 'unique_inverse', 'unique_values',
537552
'astype', 'std', 'var', 'cumulative_sum', 'clip', 'permute_dims',
538553
'reshape', 'argsort', 'sort', 'nonzero', 'ceil', 'floor', 'trunc',
539554
'matmul', 'matrix_transpose', 'tensordot', 'vecdot', 'isdtype',
540-
'unstack']
555+
'unstack', 'sign']

array_api_compat/cupy/_aliases.py

+1-7
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@
6262
matmul = get_xp(cp)(_aliases.matmul)
6363
matrix_transpose = get_xp(cp)(_aliases.matrix_transpose)
6464
tensordot = get_xp(cp)(_aliases.tensordot)
65+
sign = get_xp(cp)(_aliases.sign)
6566

6667
_copy_default = object()
6768

@@ -109,13 +110,6 @@ def asarray(
109110

110111
return cp.array(obj, dtype=dtype, **kwargs)
111112

112-
def sign(x: ndarray, /) -> ndarray:
113-
# CuPy sign() does not propagate nans. See
114-
# https://github.com/data-apis/array-api-compat/issues/136
115-
out = cp.sign(x)
116-
out[cp.isnan(x)] = cp.nan
117-
return out
118-
119113
# These functions are completely new here. If the library already has them
120114
# (i.e., numpy 2.0), use the library version instead of our wrapper.
121115
if hasattr(cp, 'vecdot'):

array_api_compat/dask/array/_aliases.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def _dask_arange(
104104
trunc = get_xp(np)(_aliases.trunc)
105105
matmul = get_xp(np)(_aliases.matmul)
106106
tensordot = get_xp(np)(_aliases.tensordot)
107-
107+
sign = get_xp(np)(_aliases.sign)
108108

109109
# asarray also adds the copy keyword, which is not present in numpy 1.0.
110110
def asarray(

array_api_compat/numpy/_aliases.py

+1
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@
6262
matmul = get_xp(np)(_aliases.matmul)
6363
matrix_transpose = get_xp(np)(_aliases.matrix_transpose)
6464
tensordot = get_xp(np)(_aliases.tensordot)
65+
sign = get_xp(np)(_aliases.sign)
6566

6667
def _supports_buffer_protocol(obj):
6768
try:

torch-xfails.txt

+4
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,10 @@ array_api_tests/test_operators_and_elementwise_functions.py::test_pow[__pow__(x1
5656
array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[__imod__(x1, x2)]
5757
array_api_tests/test_operators_and_elementwise_functions.py::test_subtract[__sub__(x1, x2)]
5858

59+
# inverse trig functions are too inaccurate on CPU
60+
array_api_tests/test_operators_and_elementwise_functions.py::test_acos
61+
array_api_tests/test_operators_and_elementwise_functions.py::test_atan
62+
array_api_tests/test_operators_and_elementwise_functions.py::test_asin
5963

6064
# overflow near float max
6165
array_api_tests/test_operators_and_elementwise_functions.py::test_log1p

0 commit comments

Comments
 (0)