Skip to content

BUG: add test that wrapping preserves view/copy semantics, fix where it doesn't #333

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jun 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 0 additions & 24 deletions array_api_compat/common/_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,27 +524,6 @@ def nonzero(x: Array, /, xp: Namespace, **kwargs: object) -> tuple[Array, ...]:
return xp.nonzero(x, **kwargs)


# ceil, floor, and trunc return integers for integer inputs


def ceil(x: Array, /, xp: Namespace, **kwargs: object) -> Array:
if xp.issubdtype(x.dtype, xp.integer):
return x
return xp.ceil(x, **kwargs)


def floor(x: Array, /, xp: Namespace, **kwargs: object) -> Array:
if xp.issubdtype(x.dtype, xp.integer):
return x
return xp.floor(x, **kwargs)


def trunc(x: Array, /, xp: Namespace, **kwargs: object) -> Array:
if xp.issubdtype(x.dtype, xp.integer):
return x
return xp.trunc(x, **kwargs)


# linear algebra functions


Expand Down Expand Up @@ -707,9 +686,6 @@ def iinfo(type_: DType | Array, /, xp: Namespace) -> Any:
"argsort",
"sort",
"nonzero",
"ceil",
"floor",
"trunc",
"matmul",
"matrix_transpose",
"tensordot",
Expand Down
24 changes: 20 additions & 4 deletions array_api_compat/cupy/_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,6 @@
argsort = get_xp(cp)(_aliases.argsort)
sort = get_xp(cp)(_aliases.sort)
nonzero = get_xp(cp)(_aliases.nonzero)
ceil = get_xp(cp)(_aliases.ceil)
floor = get_xp(cp)(_aliases.floor)
trunc = get_xp(cp)(_aliases.trunc)
matmul = get_xp(cp)(_aliases.matmul)
matrix_transpose = get_xp(cp)(_aliases.matrix_transpose)
tensordot = get_xp(cp)(_aliases.tensordot)
Expand Down Expand Up @@ -123,6 +120,25 @@ def count_nonzero(
return cp.expand_dims(result, axis)
return result

# ceil, floor, and trunc return integers for integer inputs

def ceil(x: Array, /) -> Array:
if cp.issubdtype(x.dtype, cp.integer):
return x.copy()
return cp.ceil(x)


def floor(x: Array, /) -> Array:
if cp.issubdtype(x.dtype, cp.integer):
return x.copy()
return cp.floor(x)


def trunc(x: Array, /) -> Array:
if cp.issubdtype(x.dtype, cp.integer):
return x.copy()
return cp.trunc(x)
Comment on lines +125 to +140
Copy link
Contributor

@crusaderky crusaderky Jun 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

int arguments are outside of the scope of the Array API.
So I'm not sure this should be here at all?

For numpy is a bit more nuanced - on one hand it makes sense to unify numpy 1.x behaviour to match 2.x.
On the other hand it's a array-api-compat tweak specifically to cover a change in behaviour that is already outside of the Array API.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's just backwards compat. On main, all backends try to return integers for integer inputs (via common/_aliases.py), so this PR tries to preserve that, while fixing the views/copies issue.



# take_along_axis: axis defaults to -1 but in cupy (and numpy) axis is a required arg
def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1):
Expand Down Expand Up @@ -151,6 +167,6 @@ def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1):
'atan2', 'atanh', 'bitwise_left_shift',
'bitwise_invert', 'bitwise_right_shift',
'bool', 'concat', 'count_nonzero', 'pow', 'sign',
'take_along_axis']
'ceil', 'floor', 'trunc', 'take_along_axis']

_all_ignore = ['cp', 'get_xp']
3 changes: 0 additions & 3 deletions array_api_compat/dask/array/_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,9 +134,6 @@ def arange(
matrix_transpose = get_xp(da)(_aliases.matrix_transpose)
vecdot = get_xp(da)(_aliases.vecdot)
nonzero = get_xp(da)(_aliases.nonzero)
ceil = get_xp(np)(_aliases.ceil)
floor = get_xp(np)(_aliases.floor)
trunc = get_xp(np)(_aliases.trunc)
matmul = get_xp(np)(_aliases.matmul)
tensordot = get_xp(np)(_aliases.tensordot)
sign = get_xp(np)(_aliases.sign)
Expand Down
26 changes: 23 additions & 3 deletions array_api_compat/numpy/_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,6 @@
argsort = get_xp(np)(_aliases.argsort)
sort = get_xp(np)(_aliases.sort)
nonzero = get_xp(np)(_aliases.nonzero)
ceil = get_xp(np)(_aliases.ceil)
floor = get_xp(np)(_aliases.floor)
trunc = get_xp(np)(_aliases.trunc)
matmul = get_xp(np)(_aliases.matmul)
matrix_transpose = get_xp(np)(_aliases.matrix_transpose)
tensordot = get_xp(np)(_aliases.tensordot)
Expand Down Expand Up @@ -145,6 +142,26 @@ def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1):
return np.take_along_axis(x, indices, axis=axis)


# ceil, floor, and trunc return integers for integer inputs in NumPy < 2

def ceil(x: Array, /) -> Array:
if np.__version__ < '2' and np.issubdtype(x.dtype, np.integer):
return x.copy()
return np.ceil(x)


def floor(x: Array, /) -> Array:
if np.__version__ < '2' and np.issubdtype(x.dtype, np.integer):
return x.copy()
return np.floor(x)


def trunc(x: Array, /) -> Array:
if np.__version__ < '2' and np.issubdtype(x.dtype, np.integer):
return x.copy()
return np.trunc(x)


# These functions are completely new here. If the library already has them
# (i.e., numpy 2.0), use the library version instead of our wrapper.
if hasattr(np, "vecdot"):
Expand Down Expand Up @@ -173,6 +190,9 @@ def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1):
"atan",
"atan2",
"atanh",
"ceil",
"floor",
"trunc",
"bitwise_left_shift",
"bitwise_invert",
"bitwise_right_shift",
Expand Down
64 changes: 64 additions & 0 deletions tests/test_copies_or_views.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
"""
A collection of tests to make sure that wrapped namespaces agree with the bare ones
on whether to return a view or a copy of inputs.
"""
import pytest
from ._helpers import import_, wrapped_libraries


FUNC_INPUTS = [
# func_name, arr_input, dtype, scalar_value
('abs', [1, 2], 'int8', 3),
('abs', [1, 2], 'float32', 3.),
('ceil', [1, 2], 'int8', 3),
('clip', [1, 2], 'int8', 3),
('conj', [1, 2], 'int8', 3),
('floor', [1, 2], 'int8', 3),
('imag', [1j, 2j], 'complex64', 3),
('positive', [1, 2], 'int8', 3),
('real', [1., 2.], 'float32', 3.),
('round', [1, 2], 'int8', 3),
('sign', [0, 0], 'float32', 3),
('trunc', [1, 2], 'int8', 3),
('trunc', [1, 2], 'float32', 3),
]


def ensure_unary(func, arr):
"""Make a trivial unary function from func."""
if func.__name__ == 'clip':
return lambda x: func(x, arr[0], arr[1])
return func


def is_view(func, a, value):
"""Apply `func`, mutate the output; does the input change?"""
b = func(a)
b[0] = value
return a[0] == value


@pytest.mark.parametrize('xp_name', wrapped_libraries + ['array_api_strict'])
@pytest.mark.parametrize('inputs', FUNC_INPUTS, ids=[inp[0] for inp in FUNC_INPUTS])
def test_view_or_copy(inputs, xp_name):
bare_xp = import_(xp_name, wrapper=False)
wrapped_xp = import_(xp_name, wrapper=True)

func_name, arr_input, dtype_str, value = inputs
dtype = getattr(bare_xp, dtype_str)

bare_func = getattr(bare_xp, func_name)
bare_func = ensure_unary(bare_func, arr_input)

wrapped_func = getattr(wrapped_xp, func_name)
wrapped_func = ensure_unary(wrapped_func, arr_input)

# bare namespace: mutate the output, does the input change?
a = bare_xp.asarray(arr_input, dtype=dtype)
is_view_bare = is_view(bare_func, a, value)

# wrapped namespace: mutate the output, does the input change?
a1 = wrapped_xp.asarray(arr_input, dtype=dtype)
is_view_wrapped = is_view(wrapped_func, a1, value)

assert is_view_bare == is_view_wrapped
Loading