|
12 | 12 | from typing import NamedTuple
|
13 | 13 | import inspect
|
14 | 14 |
|
15 |
| -from ._helpers import array_namespace, _check_device, device, is_torch_array |
| 15 | +from ._helpers import array_namespace, _check_device, device, is_torch_array, is_cupy_namespace |
16 | 16 |
|
17 | 17 | # These functions are modified from the NumPy versions.
|
18 | 18 |
|
@@ -530,11 +530,26 @@ def unstack(x: ndarray, /, xp, *, axis: int = 0) -> Tuple[ndarray, ...]:
|
530 | 530 | raise ValueError("Input array must be at least 1-d.")
|
531 | 531 | return tuple(xp.moveaxis(x, axis, 0))
|
532 | 532 |
|
| 533 | +# numpy 1.26 does not use the standard definition for sign on complex numbers |
| 534 | + |
| 535 | +def sign(x: ndarray, /, xp, **kwargs) -> ndarray: |
| 536 | + if isdtype(x.dtype, 'complex floating', xp=xp): |
| 537 | + out = (x/xp.abs(x, **kwargs))[...] |
| 538 | + # sign(0) = 0 but the above formula would give nan |
| 539 | + out[x == 0+0j] = 0+0j |
| 540 | + else: |
| 541 | + out = xp.sign(x, **kwargs) |
| 542 | + # CuPy sign() does not propagate nans. See |
| 543 | + # https://github.com/data-apis/array-api-compat/issues/136 |
| 544 | + if is_cupy_namespace(xp) and isdtype(x.dtype, 'real floating', xp=xp): |
| 545 | + out[xp.isnan(x)] = xp.nan |
| 546 | + return out[()] |
| 547 | + |
533 | 548 | __all__ = ['arange', 'empty', 'empty_like', 'eye', 'full', 'full_like',
|
534 | 549 | 'linspace', 'ones', 'ones_like', 'zeros', 'zeros_like',
|
535 | 550 | 'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult',
|
536 | 551 | 'unique_all', 'unique_counts', 'unique_inverse', 'unique_values',
|
537 | 552 | 'astype', 'std', 'var', 'cumulative_sum', 'clip', 'permute_dims',
|
538 | 553 | 'reshape', 'argsort', 'sort', 'nonzero', 'ceil', 'floor', 'trunc',
|
539 | 554 | 'matmul', 'matrix_transpose', 'tensordot', 'vecdot', 'isdtype',
|
540 |
| - 'unstack'] |
| 555 | + 'unstack', 'sign'] |
0 commit comments