Skip to content

ENH: pad: add delegation #72

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 25 commits into from
Jan 15, 2025
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/api-reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
expand_dims
kron
nunique
pad
setdiff1d
sinc
```
933 changes: 493 additions & 440 deletions pixi.lock

Large diffs are not rendered by default.

6 changes: 5 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,11 @@ tests-backends = ["py310", "tests", "backends", "cuda-backends"]
minversion = "6.0"
addopts = ["-ra", "--showlocals", "--strict-markers", "--strict-config"]
xfail_strict = true
filterwarnings = ["error"]
filterwarnings = [
"error",
# TODO: when Python 3.10 is dropped, use `enum.member` in `_delegation.py`
"ignore:functools.partial will be a method descriptor:FutureWarning",
]
log_cli_level = "INFO"
testpaths = ["tests"]
markers = ["skip_xp_backend(library, *, reason=None): Skip test for a specific backend"]
Expand Down
4 changes: 2 additions & 2 deletions src/array_api_extra/__init__.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
"""Extra array functions built on top of the array API standard."""

from ._funcs import (
from ._delegation import pad
from ._lib._funcs import (
at,
atleast_nd,
cov,
create_diagonal,
expand_dims,
kron,
nunique,
pad,
setdiff1d,
sinc,
)
Expand Down
125 changes: 125 additions & 0 deletions src/array_api_extra/_delegation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
"""Delegation to existing implementations for Public API Functions."""

import functools
from enum import Enum
from types import ModuleType
from typing import final

from ._lib import _funcs
from ._lib._utils._compat import (
array_namespace,
is_cupy_namespace,
is_jax_namespace,
is_numpy_namespace,
is_torch_namespace,
)
from ._lib._utils._typing import Array

__all__ = ["pad"]


@final
class IsNamespace(Enum):
"""Enum to access is_namespace functions as the backend."""

# TODO: when Python 3.10 is dropped, use `enum.member`
# https://stackoverflow.com/a/74302109
CUPY = functools.partial(is_cupy_namespace)
JAX = functools.partial(is_jax_namespace)
NUMPY = functools.partial(is_numpy_namespace)
TORCH = functools.partial(is_torch_namespace)

def __call__(self, xp: ModuleType) -> bool:
"""
Call the is_namespace function.

Parameters
----------
xp : array_namespace
Array namespace to check.

Returns
-------
bool
``True`` if xp matches the namespace, ``False`` otherwise.
"""
return self.value(xp)


CUPY = IsNamespace.CUPY
JAX = IsNamespace.JAX
NUMPY = IsNamespace.NUMPY
TORCH = IsNamespace.TORCH


def _delegate(xp: ModuleType, *backends: IsNamespace) -> bool:
"""
Check whether `xp` is one of the `backends` to delegate to.

Parameters
----------
xp : array_namespace
Array namespace to check.
*backends : IsNamespace
Arbitrarily many backends (from the ``IsNamespace`` enum) to check.

Returns
-------
bool
``True`` if `xp` matches one of the `backends`, ``False`` otherwise.
"""
return any(is_namespace(xp) for is_namespace in backends)


def pad(
x: Array,
pad_width: int | tuple[int, int] | list[tuple[int, int]],
mode: str = "constant",
*,
constant_values: bool | int | float | complex = 0,
xp: ModuleType | None = None,
) -> Array:
"""
Pad the input array.

Parameters
----------
x : array
Input array.
pad_width : int or tuple of ints or list of pairs of ints
Pad the input array with this many elements from each side.
If a list of tuples, ``[(before_0, after_0), ... (before_N, after_N)]``,
each pair applies to the corresponding axis of ``x``.
A single tuple, ``(before, after)``, is equivalent to a list of ``x.ndim``
copies of this tuple.
mode : str, optional
Only "constant" mode is currently supported, which pads with
the value passed to `constant_values`.
constant_values : python scalar, optional
Use this value to pad the input. Default is zero.
xp : array_namespace, optional
The standard-compatible namespace for `x`. Default: infer.

Returns
-------
array
The input array,
padded with ``pad_width`` elements equal to ``constant_values``.
"""
xp = array_namespace(x) if xp is None else xp

if mode != "constant":
msg = "Only `'constant'` mode is currently supported"
raise NotImplementedError(msg)

# https://github.com/pytorch/pytorch/blob/cf76c05b4dc629ac989d1fb8e789d4fac04a095a/torch/_numpy/_funcs_impl.py#L2045-L2056
if _delegate(xp, TORCH):
pad_width = xp.asarray(pad_width)
pad_width = xp.broadcast_to(pad_width, (x.ndim, 2))
pad_width = xp.flip(pad_width, axis=(0,)).flatten()
return xp.nn.functional.pad(x, tuple(pad_width), value=constant_values) # type: ignore[arg-type] # pyright: ignore[reportArgumentType]

if _delegate(xp, NUMPY, JAX, CUPY):
return xp.pad(x, pad_width, mode, constant_values=constant_values)

return _funcs.pad(x, pad_width, constant_values=constant_values, xp=xp)
2 changes: 1 addition & 1 deletion src/array_api_extra/_lib/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
"""Modules housing private functions."""
"""Internals of array-api-extra."""
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Public API Functions."""
"""Array-agnostic implementations for the public API."""

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Except.... it isn't agnostic, see for example the special paths in at and nunique

Copy link
Member Author

@lucascolley lucascolley Jan 15, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, I would like to split the file structure so that functions which make use of special paths are separate from array-agnostic implementations. I'll save that for a follow-up.

# https://github.com/scikit-learn/scikit-learn/pull/27910#issuecomment-2568023972
from __future__ import annotations
Expand All @@ -11,13 +11,9 @@
from types import ModuleType
from typing import ClassVar, cast

from ._lib import _compat, _utils
from ._lib._compat import (
array_namespace,
is_jax_array,
is_writeable_array,
)
from ._lib._typing import Array, Index
from ._utils import _compat, _helpers
from ._utils._compat import array_namespace, is_jax_array, is_writeable_array
from ._utils._typing import Array, Index

__all__ = [
"at",
Expand Down Expand Up @@ -151,7 +147,7 @@ def cov(m: Array, /, *, xp: ModuleType | None = None) -> Array:
m = atleast_nd(m, ndim=2, xp=xp)
m = xp.astype(m, dtype)

avg = _utils.mean(m, axis=1, xp=xp)
avg = _helpers.mean(m, axis=1, xp=xp)
fact = m.shape[1] - 1

if fact <= 0:
Expand Down Expand Up @@ -467,7 +463,7 @@ def setdiff1d(
else:
x1 = xp.unique_values(x1)
x2 = xp.unique_values(x2)
return x1[_utils.in1d(x1, x2, assume_unique=True, invert=True, xp=xp)]
return x1[_helpers.in1d(x1, x2, assume_unique=True, invert=True, xp=xp)]


def sinc(x: Array, /, *, xp: ModuleType | None = None) -> Array:
Expand Down Expand Up @@ -562,54 +558,18 @@ def sinc(x: Array, /, *, xp: ModuleType | None = None) -> Array:
def pad(
x: Array,
pad_width: int | tuple[int, int] | list[tuple[int, int]],
mode: str = "constant",
*,
xp: ModuleType | None = None,
constant_values: bool | int | float | complex = 0,
) -> Array:
"""
Pad the input array.

Parameters
----------
x : array
Input array.
pad_width : int or tuple of ints or list of pairs of ints
Pad the input array with this many elements from each side.
If a list of tuples, ``[(before_0, after_0), ... (before_N, after_N)]``,
each pair applies to the corresponding axis of ``x``.
A single tuple, ``(before, after)``, is equivalent to a list of ``x.ndim``
copies of this tuple.
mode : str, optional
Only "constant" mode is currently supported, which pads with
the value passed to `constant_values`.
xp : array_namespace, optional
The standard-compatible namespace for `x`. Default: infer.
constant_values : python scalar, optional
Use this value to pad the input. Default is zero.

Returns
-------
array
The input array,
padded with ``pad_width`` elements equal to ``constant_values``.
"""
if mode != "constant":
msg = "Only `'constant'` mode is currently supported"
raise NotImplementedError(msg)

value = constant_values

xp: ModuleType,
) -> Array: # numpydoc ignore=PR01,RT01
"""See docstring in `array_api_extra._delegation.py`."""
# make pad_width a list of length-2 tuples of ints
x_ndim = cast(int, x.ndim)
if isinstance(pad_width, int):
pad_width = [(pad_width, pad_width)] * x_ndim
if isinstance(pad_width, tuple):
pad_width = [pad_width] * x_ndim

if xp is None:
xp = array_namespace(x)

# https://github.com/python/typeshed/issues/13376
slices: list[slice] = [] # type: ignore[no-any-explicit]
newshape: list[int] = []
Expand All @@ -633,7 +593,7 @@ def pad(

padded = xp.full(
tuple(newshape),
fill_value=value,
fill_value=constant_values,
dtype=x.dtype,
device=_compat.device(x),
)
Expand Down
6 changes: 4 additions & 2 deletions src/array_api_extra/_lib/_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,15 @@
Note that this is private API; don't expect it to be stable.
"""

from ._compat import (
from types import ModuleType

from ._utils._compat import (
array_namespace,
is_cupy_namespace,
is_pydata_sparse_namespace,
is_torch_namespace,
)
from ._typing import Array, ModuleType
from ._utils._typing import Array

__all__ = ["xp_assert_close", "xp_assert_equal"]

Expand Down
1 change: 1 addition & 0 deletions src/array_api_extra/_lib/_utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Modules housing private utility functions."""
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@
# `array-api-compat` to override the import location

try:
from ..._array_api_compat_vendor import ( # pyright: ignore[reportMissingImports]
from ...._array_api_compat_vendor import ( # pyright: ignore[reportMissingImports]
array_namespace,
device,
is_cupy_namespace,
is_jax_array,
is_jax_namespace,
is_numpy_namespace,
is_pydata_sparse_namespace,
is_torch_namespace,
is_writeable_array,
Expand All @@ -21,6 +22,7 @@
is_cupy_namespace,
is_jax_array,
is_jax_namespace,
is_numpy_namespace,
is_pydata_sparse_namespace,
is_torch_namespace,
is_writeable_array,
Expand All @@ -33,6 +35,7 @@
"is_cupy_namespace",
"is_jax_array",
"is_jax_namespace",
"is_numpy_namespace",
"is_pydata_sparse_namespace",
"is_torch_namespace",
"is_writeable_array",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,11 @@ def array_namespace(
use_compat: bool | None = None,
) -> ArrayModule: ...
def device(x: Array, /) -> Device: ...
def is_cupy_namespace(x: object, /) -> bool: ...
def is_cupy_namespace(xp: ModuleType, /) -> bool: ...
def is_jax_namespace(xp: ModuleType, /) -> bool: ...
def is_numpy_namespace(xp: ModuleType, /) -> bool: ...
def is_torch_namespace(xp: ModuleType, /) -> bool: ...
def is_jax_array(x: object, /) -> bool: ...
def is_jax_namespace(x: object, /) -> bool: ...
def is_pydata_sparse_namespace(x: object, /) -> bool: ...
def is_torch_namespace(x: object, /) -> bool: ...
def is_pydata_sparse_namespace(xp: ModuleType, /) -> bool: ...
def is_writeable_array(x: object, /) -> bool: ...
def size(x: Array, /) -> int | None: ...
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
"""Utility functions used by `array_api_extra/_funcs.py`."""
"""Helper functions used by `array_api_extra/_funcs.py`."""

# https://github.com/scikit-learn/scikit-learn/pull/27910#issuecomment-2568023972
from __future__ import annotations

from types import ModuleType

from . import _compat
from ._typing import Array, ModuleType
from ._typing import Array

__all__ = ["in1d", "mean"]

Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
"""Static typing helpers."""

from types import ModuleType
from typing import Any

# To be changed to a Protocol later (see data-apis/array-api#589)
Array = Any # type: ignore[no-any-explicit]
Device = Any # type: ignore[no-any-explicit]
Index = Any # type: ignore[no-any-explicit]

__all__ = ["Array", "Device", "Index", "ModuleType"]
__all__ = ["Array", "Device", "Index"]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I liked this before. If anything, we should rename it to ArrayModuleType.
I would hope that eventually a new library array_api_types defines what functions exactly an array api compatible module must declare.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah, this should be fixed eventually by array-api-typing. In the meantime, feel free to submit a PR changing use of ModuleType to an ArrayNamespace alias.

7 changes: 4 additions & 3 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
"""Pytest fixtures."""

from enum import Enum
from types import ModuleType
from typing import cast

import pytest

from array_api_extra._lib._compat import array_namespace
from array_api_extra._lib._compat import device as get_device
from array_api_extra._lib._typing import Device, ModuleType
from array_api_extra._lib._utils._compat import array_namespace
from array_api_extra._lib._utils._compat import device as get_device
from array_api_extra._lib._utils._typing import Device


class Library(Enum):
Expand Down
Loading
Loading