From 374212ad8313f86b9d993134ecb98ad3d62281db Mon Sep 17 00:00:00 2001 From: crusaderky Date: Wed, 23 Apr 2025 18:50:31 +0100 Subject: [PATCH] [WIP] ENH: dask+cupy, dask+sparse etc. namespaces --- array_api_compat/common/_helpers.py | 25 ++++++++++--- array_api_compat/dask/array/__init__.py | 1 + array_api_compat/dask/array/_aliases.py | 13 ++++++- array_api_compat/dask/array/_meta.py | 49 +++++++++++++++++++++++++ 4 files changed, 81 insertions(+), 7 deletions(-) create mode 100644 array_api_compat/dask/array/_meta.py diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index 77175d0d..a518b126 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -397,7 +397,9 @@ def is_dask_namespace(xp: Namespace) -> bool: """ Returns True if `xp` is a Dask namespace. - This includes both ``dask.array`` itself and the version wrapped by array-api-compat. + This includes ``dask.array`` itself, the version wrapped by array-api-compat, + and the bespoke namespaces generated by + ``array_api_compat.dask.array.wrap_namespace``. See Also -------- @@ -411,7 +413,13 @@ def is_dask_namespace(xp: Namespace) -> bool: is_pydata_sparse_namespace is_array_api_strict_namespace """ - return xp.__name__ in {"dask.array", _compat_module_name() + ".dask.array"} + da_compat_name = _compat_module_name() + '.dask.array' + name = xp.__name__ + return ( + name in {'dask.array', da_compat_name} + or name.startswith(da_compat_name + '.') + and name[len(da_compat_name) + 1:] not in ("linalg", "fft") + ) def is_jax_namespace(xp: Namespace) -> bool: @@ -597,9 +605,16 @@ def your_function(x, y): elif is_dask_array(x): if _use_compat: _check_api_version(api_version) - from ..dask import array as dask_namespace - - namespaces.add(dask_namespace) + from ..dask.array import wrap_namespace + + # The meta-namespace is only used to generate the meta-array, so it + # would be useless to create a namespace such as e.g. + # array_api_compat.dask.array.array_api_compat.cupy. + # It would get worse once you vendor array-api-compat! + # So keep it clean with array_api_compat.dask.array.cupy. + mxp = array_namespace(x._meta, use_compat=False) + xp = wrap_namespace(mxp) + namespaces.add(xp) else: import dask.array as da diff --git a/array_api_compat/dask/array/__init__.py b/array_api_compat/dask/array/__init__.py index 1e47b960..fec80197 100644 --- a/array_api_compat/dask/array/__init__.py +++ b/array_api_compat/dask/array/__init__.py @@ -4,6 +4,7 @@ # These imports may overwrite names from the import * above. from ._aliases import * # noqa: F403 +from ._meta import wrap_namespace # noqa: F401 __array_api_version__: Final = "2024.12" diff --git a/array_api_compat/dask/array/_aliases.py b/array_api_compat/dask/array/_aliases.py index 9687a9cd..86704e77 100644 --- a/array_api_compat/dask/array/_aliases.py +++ b/array_api_compat/dask/array/_aliases.py @@ -152,6 +152,7 @@ def asarray( dtype: DType | None = None, device: Device | None = None, copy: py_bool | None = None, + like: Array | None = None, **kwargs: object, ) -> Array: """ @@ -168,7 +169,11 @@ def asarray( if copy is False: raise ValueError("Unable to avoid copy when changing dtype") obj = obj.astype(dtype) - return obj.copy() if copy else obj # pyright: ignore[reportAttributeAccessIssue] + if copy: + obj = obj.copy() + if like is not None: + obj = da.asarray(obj, like=like) + return obj if copy is False: raise NotImplementedError( @@ -177,7 +182,11 @@ def asarray( # copy=None to be uniform across dask < 2024.12 and >= 2024.12 # see https://github.com/dask/dask/pull/11524/ - obj = np.array(obj, dtype=dtype, copy=True) + if like is not None: + mxp = array_namespace(like) + obj = mxp.asarray(obj, dtype=dtype, copy=True) + else: + obj = np.array(obj, dtype=dtype, copy=True) return da.from_array(obj) diff --git a/array_api_compat/dask/array/_meta.py b/array_api_compat/dask/array/_meta.py new file mode 100644 index 00000000..9365ad2a --- /dev/null +++ b/array_api_compat/dask/array/_meta.py @@ -0,0 +1,49 @@ +import functools +import sys +import types + +from ...common._helpers import is_numpy_namespace +from ...common._typing import Namespace + +__all__ = ['wrap_namespace'] + + +def wrap_namespace(xp: Namespace) -> Namespace: + """Create a bespoke Dask namespace that wraps around another namespace. + + Parameters + ---------- + xp : namespace + Namespace to be wrapped by Dask + + Returns + ------- + namespace : + A module object that duplicates array_api_compat.dask.array, with the + difference that all creation functions will create an array with the same + meta namespace as the input. + """ + from .. import array as da_compat + + if is_numpy_namespace(xp): + return da_compat + + mod_name = f'{da_compat.__name__}.{xp.__name__}' + try: + return sys.modules[mod_name] + except KeyError: + pass + + mod = types.ModuleType(mod_name) + sys.modules[mod_name] = mod + + meta = xp.empty(()) + for name, v in da_compat.__dict__.items(): + if name.startswith('_'): + continue + if name in {'arange', 'asarray', 'empty', 'eye', 'from_dlpack', + 'full', 'linspace', 'ones', 'zeros'}: + v = functools.wraps(v)(functools.partial(v, like=meta)) + setattr(mod, name, v) + + return mod