Skip to content

add is_*_namespace helper functions #178

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Sep 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
169 changes: 168 additions & 1 deletion array_api_compat/common/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,6 @@ def is_jax_array(x):

return isinstance(x, jax.Array) or _is_jax_zero_gradient_array(x)


def is_pydata_sparse_array(x) -> bool:
"""
Return True if `x` is an array from the `sparse` package.
Expand Down Expand Up @@ -255,6 +254,166 @@ def is_array_api_obj(x):
or is_pydata_sparse_array(x) \
or hasattr(x, '__array_namespace__')

def _compat_module_name():
assert __name__.endswith('.common._helpers')
return __name__.removesuffix('.common._helpers')

def is_numpy_namespace(xp) -> bool:
"""
Returns True if `xp` is a NumPy namespace.

This includes both NumPy itself and the version wrapped by array-api-compat.

See Also
--------

array_namespace
is_cupy_namespace
is_torch_namespace
is_ndonnx_namespace
is_dask_namespace
is_jax_namespace
is_pydata_sparse_namespace
is_array_api_strict_namespace
"""
return xp.__name__ in {'numpy', _compat_module_name() + '.numpy'}

def is_cupy_namespace(xp) -> bool:
"""
Returns True if `xp` is a CuPy namespace.

This includes both CuPy itself and the version wrapped by array-api-compat.

See Also
--------

array_namespace
is_numpy_namespace
is_torch_namespace
is_ndonnx_namespace
is_dask_namespace
is_jax_namespace
is_pydata_sparse_namespace
is_array_api_strict_namespace
"""
return xp.__name__ in {'cupy', _compat_module_name() + '.cupy'}

def is_torch_namespace(xp) -> bool:
"""
Returns True if `xp` is a PyTorch namespace.

This includes both PyTorch itself and the version wrapped by array-api-compat.

See Also
--------

array_namespace
is_numpy_namespace
is_cupy_namespace
is_ndonnx_namespace
is_dask_namespace
is_jax_namespace
is_pydata_sparse_namespace
is_array_api_strict_namespace
"""
return xp.__name__ in {'torch', _compat_module_name() + '.torch'}


def is_ndonnx_namespace(xp):
"""
Returns True if `xp` is an NDONNX namespace.

See Also
--------

array_namespace
is_numpy_namespace
is_cupy_namespace
is_torch_namespace
is_dask_namespace
is_jax_namespace
is_pydata_sparse_namespace
is_array_api_strict_namespace
"""
return xp.__name__ == 'ndonnx'

def is_dask_namespace(xp):
"""
Returns True if `xp` is a Dask namespace.

This includes both ``dask.array`` itself and the version wrapped by array-api-compat.

See Also
--------

array_namespace
is_numpy_namespace
is_cupy_namespace
is_torch_namespace
is_ndonnx_namespace
is_jax_namespace
is_pydata_sparse_namespace
is_array_api_strict_namespace
"""
return xp.__name__ in {'dask.array', _compat_module_name() + '.dask.array'}

def is_jax_namespace(xp):
"""
Returns True if `xp` is a JAX namespace.

This includes ``jax.numpy`` and ``jax.experimental.array_api`` which existed in
older versions of JAX.

See Also
--------

array_namespace
is_numpy_namespace
is_cupy_namespace
is_torch_namespace
is_ndonnx_namespace
is_dask_namespace
is_pydata_sparse_namespace
is_array_api_strict_namespace
"""
return xp.__name__ in {'jax.numpy', 'jax.experimental.array_api'}

def is_pydata_sparse_namespace(xp):
"""
Returns True if `xp` is a pydata/sparse namespace.

See Also
--------

array_namespace
is_numpy_namespace
is_cupy_namespace
is_torch_namespace
is_ndonnx_namespace
is_dask_namespace
is_jax_namespace
is_array_api_strict_namespace
"""
return xp.__name__ == 'sparse'

def is_array_api_strict_namespace(xp):
"""
Returns True if `xp` is an array-api-strict namespace.

See Also
--------

array_namespace
is_numpy_namespace
is_cupy_namespace
is_torch_namespace
is_ndonnx_namespace
is_dask_namespace
is_jax_namespace
is_pydata_sparse_namespace
"""
return xp.__name__ == 'array_api_strict'

def _check_api_version(api_version):
if api_version == '2021.12':
warnings.warn("The 2021.12 version of the array API specification was requested but the returned namespace is actually version 2022.12")
Expand Down Expand Up @@ -643,13 +802,21 @@ def size(x):
"device",
"get_namespace",
"is_array_api_obj",
"is_array_api_strict_namespace",
"is_cupy_array",
"is_cupy_namespace",
"is_dask_array",
"is_dask_namespace",
"is_jax_array",
"is_jax_namespace",
"is_numpy_array",
"is_numpy_namespace",
"is_torch_array",
"is_torch_namespace",
"is_ndonnx_array",
"is_ndonnx_namespace",
"is_pydata_sparse_array",
"is_pydata_sparse_namespace",
"size",
"to_device",
]
Expand Down
10 changes: 9 additions & 1 deletion docs/helper-functions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ instead, which would be wrapped.
Inspection Helpers
------------------

These convenience functions can be used to test if an array comes from a
These convenience functions can be used to test if an array or namespace comes from a
specific library without importing that library if it hasn't been imported
yet.

Expand All @@ -51,3 +51,11 @@ yet.
.. autofunction:: is_jax_array
.. autofunction:: is_pydata_sparse_array
.. autofunction:: is_ndonnx_array
.. autofunction:: is_numpy_namespace
.. autofunction:: is_cupy_namespace
.. autofunction:: is_torch_namespace
.. autofunction:: is_dask_namespace
.. autofunction:: is_jax_namespace
.. autofunction:: is_pydata_sparse_namespace
.. autofunction:: is_ndonnx_namespace
.. autofunction:: is_array_api_strict_namespace
44 changes: 34 additions & 10 deletions tests/test_common.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
from array_api_compat import (is_numpy_array, is_cupy_array, is_torch_array, # noqa: F401
is_dask_array, is_jax_array, is_pydata_sparse_array)
from array_api_compat import ( # noqa: F401
is_numpy_array, is_cupy_array, is_torch_array,
is_dask_array, is_jax_array, is_pydata_sparse_array,
is_numpy_namespace, is_cupy_namespace, is_torch_namespace,
is_dask_namespace, is_jax_namespace, is_pydata_sparse_namespace,
)

from array_api_compat import is_array_api_obj, device, to_device

Expand All @@ -10,7 +14,7 @@
import array
from numpy.testing import assert_allclose

is_functions = {
is_array_functions = {
'numpy': 'is_numpy_array',
'cupy': 'is_cupy_array',
'torch': 'is_torch_array',
Expand All @@ -19,18 +23,38 @@
'sparse': 'is_pydata_sparse_array',
}

@pytest.mark.parametrize('library', is_functions.keys())
@pytest.mark.parametrize('func', is_functions.values())
is_namespace_functions = {
'numpy': 'is_numpy_namespace',
'cupy': 'is_cupy_namespace',
'torch': 'is_torch_namespace',
'dask.array': 'is_dask_namespace',
'jax.numpy': 'is_jax_namespace',
'sparse': 'is_pydata_sparse_namespace',
}


@pytest.mark.parametrize('library', is_array_functions.keys())
@pytest.mark.parametrize('func', is_array_functions.values())
def test_is_xp_array(library, func):
lib = import_(library)
is_func = globals()[func]

x = lib.asarray([1, 2, 3])

assert is_func(x) == (func == is_functions[library])
assert is_func(x) == (func == is_array_functions[library])

assert is_array_api_obj(x)


@pytest.mark.parametrize('library', is_namespace_functions.keys())
@pytest.mark.parametrize('func', is_namespace_functions.values())
def test_is_xp_namespace(library, func):
lib = import_(library)
is_func = globals()[func]

assert is_func(lib) == (func == is_namespace_functions[library])


@pytest.mark.parametrize("library", all_libraries)
def test_device(library):
xp = import_(library, wrapper=True)
Expand Down Expand Up @@ -64,8 +88,8 @@ def test_to_device_host(library):
assert_allclose(x, expected)


@pytest.mark.parametrize("target_library", is_functions.keys())
@pytest.mark.parametrize("source_library", is_functions.keys())
@pytest.mark.parametrize("target_library", is_array_functions.keys())
@pytest.mark.parametrize("source_library", is_array_functions.keys())
def test_asarray_cross_library(source_library, target_library, request):
if source_library == "dask.array" and target_library == "torch":
# Allow rest of test to execute instead of immediately xfailing
Expand All @@ -81,7 +105,7 @@ def test_asarray_cross_library(source_library, target_library, request):
pytest.skip(reason="`sparse` does not allow implicit densification")
src_lib = import_(source_library, wrapper=True)
tgt_lib = import_(target_library, wrapper=True)
is_tgt_type = globals()[is_functions[target_library]]
is_tgt_type = globals()[is_array_functions[target_library]]

a = src_lib.asarray([1, 2, 3])
b = tgt_lib.asarray(a)
Expand All @@ -96,7 +120,7 @@ def test_asarray_copy(library):
# should be able to delete this.
xp = import_(library, wrapper=True)
asarray = xp.asarray
is_lib_func = globals()[is_functions[library]]
is_lib_func = globals()[is_array_functions[library]]
all = xp.all if library != 'dask.array' else lambda x: xp.all(x).compute()

if library == 'numpy' and xp.__version__[0] < '2' and not hasattr(xp, '_CopyMode') :
Expand Down
1 change: 1 addition & 0 deletions tests/test_vendoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def test_vendoring_torch():

uses_torch._test_torch()


def test_vendoring_dask():
from vendor_test import uses_dask
uses_dask._test_dask()
9 changes: 8 additions & 1 deletion vendor_test/uses_cupy.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
# Basic test that vendoring works

from .vendored._compat import cupy as cp_compat
from .vendored._compat import (
cupy as cp_compat,
is_cupy_array,
is_cupy_namespace,
)

import cupy as cp

Expand All @@ -16,3 +20,6 @@ def _test_cupy():
assert isinstance(res, cp.ndarray)

cp.testing.assert_allclose(res, [1., 2., 9.])

assert is_cupy_array(res)
assert is_cupy_namespace(cp) and is_cupy_namespace(cp_compat)
4 changes: 4 additions & 0 deletions vendor_test/uses_dask.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Basic test that vendoring works

from .vendored._compat.dask import array as dask_compat
from .vendored._compat import is_dask_array, is_dask_namespace

import dask.array as da
import numpy as np
Expand All @@ -17,3 +18,6 @@ def _test_dask():
assert isinstance(res, da.Array)

np.testing.assert_allclose(res, [1., 2., 9.])

assert is_dask_array(res)
assert is_dask_namespace(da) and is_dask_namespace(dask_compat)
10 changes: 9 additions & 1 deletion vendor_test/uses_numpy.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
# Basic test that vendoring works

from .vendored._compat import numpy as np_compat
from .vendored._compat import (
is_numpy_array,
is_numpy_namespace,
numpy as np_compat,
)


import numpy as np

Expand All @@ -16,3 +21,6 @@ def _test_numpy():
assert isinstance(res, np.ndarray)

np.testing.assert_allclose(res, [1., 2., 9.])

assert is_numpy_array(res)
assert is_numpy_namespace(np) and is_numpy_namespace(np_compat)
10 changes: 9 additions & 1 deletion vendor_test/uses_torch.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
# Basic test that vendoring works

from .vendored._compat import torch as torch_compat
from .vendored._compat import (
is_torch_array,
is_torch_namespace,
torch as torch_compat,
)

import torch

Expand All @@ -20,3 +24,7 @@ def _test_torch():
assert isinstance(res, torch.Tensor)

torch.testing.assert_allclose(res, [[1., 2., 3.]])

assert is_torch_array(res)
assert is_torch_namespace(torch) and is_torch_namespace(torch_compat)

Loading