diff --git a/pyproject.toml b/pyproject.toml index 53f946e..627145a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -89,6 +89,7 @@ lint.ignore = [ ] lint.per-file-ignores."docs/**/*.py" = [ "INP001" ] # No __init__.py in docs lint.per-file-ignores."src/**/stats/*.py" = [ "A001", "A004" ] # Shadows builtins like `sum` +lint.per-file-ignores."src/fast_array_utils/types.py" = [ "N806" ] # We have variables that are classes here lint.per-file-ignores."stubs/**/*.pyi" = [ "F403", "F405", "N801" ] # Stubs don’t follow name conventions lint.per-file-ignores."tests/**/test_*.py" = [ "D100", # tests need no module docstrings diff --git a/src/fast_array_utils/__init__.py b/src/fast_array_utils/__init__.py index 135e10e..9556bd1 100644 --- a/src/fast_array_utils/__init__.py +++ b/src/fast_array_utils/__init__.py @@ -3,9 +3,7 @@ from __future__ import annotations -from . import _patches, conv, stats, types +from . import conv, stats, types __all__ = ["conv", "stats", "types"] - -_patches.patch_dask() diff --git a/src/fast_array_utils/_import.py b/src/fast_array_utils/_import.py new file mode 100644 index 0000000..329d2ac --- /dev/null +++ b/src/fast_array_utils/_import.py @@ -0,0 +1,94 @@ +# SPDX-License-Identifier: MPL-2.0 +from __future__ import annotations + +from dataclasses import dataclass, field +from functools import cache +from types import UnionType +from typing import TYPE_CHECKING, Generic, ParamSpec, TypeVar, cast, overload + + +if TYPE_CHECKING: + from collections.abc import Callable + +P = ParamSpec("P") +R = TypeVar("R") + + +__all__ = ["import_by_qualname", "lazy_singledispatch"] + + +def import_by_qualname(qualname: str) -> object: + from importlib import import_module + + mod_path, obj_path = qualname.split(":") + + mod = import_module(mod_path) + + if mod_path == "dask" or mod_path.startswith("dask."): + from ._patches import patch_dask + + patch_dask() + + # get object + obj = mod + for name in obj_path.split("."): + try: + obj = getattr(obj, name) + 
except AttributeError as e: + msg = f"Could not import {'.'.join(obj_path)} from {'.'.join(mod_path)} " + raise ImportError(msg) from e + return obj + + +@dataclass +class lazy_singledispatch(Generic[P, R]): # noqa: N801 + fallback: Callable[P, R] + + _lazy: dict[tuple[str, str], Callable[..., R]] = field(init=False, default_factory=dict) + _eager: dict[type | UnionType, Callable[..., R]] = field(init=False, default_factory=dict) + + def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R: + fn = self.dispatch(type(args[0])) # type: ignore[arg-type] # https://github.com/python/mypy/issues/11470 + return fn(*args, **kwargs) + + def __hash__(self) -> int: + return hash(self.fallback) + + @cache # noqa: B019 + def dispatch(self, typ: type) -> Callable[P, R]: + for cls_reg, fn in self._eager.items(): + if issubclass(typ, cls_reg): + return fn + for (import_qualname, host_mod_name), fn in self._lazy.items(): + for cls in typ.mro(): + if cls.__module__.startswith(host_mod_name): # can be deeper + cls_reg = cast(type, import_by_qualname(import_qualname)) + if issubclass(typ, cls_reg): + return fn + return self.fallback + + @overload + def register( + self, qualname_or_type: str, /, host_mod_name: str | None = None + ) -> Callable[[Callable[..., R]], lazy_singledispatch[P, R]]: ... + @overload + def register( + self, qualname_or_type: type | UnionType, /, host_mod_name: None = None + ) -> Callable[[Callable[..., R]], lazy_singledispatch[P, R]]: ... 
+ + def register( + self, qualname_or_type: str | type | UnionType, /, host_mod_name: str | None = None + ) -> Callable[[Callable[..., R]], lazy_singledispatch[P, R]]: + def decorator(fn: Callable[..., R]) -> lazy_singledispatch[P, R]: + match qualname_or_type, host_mod_name: + case str(), _: + hmn = qualname_or_type.split(":")[0] if host_mod_name is None else host_mod_name + self._lazy[(qualname_or_type, hmn)] = fn + case type() | UnionType(), None: + self._eager[qualname_or_type] = fn + case _: + msg = f"name_or_type {qualname_or_type!r} must be a str, type, or UnionType" + raise TypeError(msg) + return self + + return decorator diff --git a/src/fast_array_utils/conv/_asarray.py b/src/fast_array_utils/conv/_asarray.py index 42f3e06..bb94c27 100644 --- a/src/fast_array_utils/conv/_asarray.py +++ b/src/fast_array_utils/conv/_asarray.py @@ -1,24 +1,26 @@ # SPDX-License-Identifier: MPL-2.0 from __future__ import annotations -from functools import singledispatch from typing import TYPE_CHECKING, Any, cast import numpy as np from numpy.typing import NDArray -from .. import types +from .._import import lazy_singledispatch +from ..types import OutOfCoreDataset if TYPE_CHECKING: from numpy.typing import ArrayLike + from .. 
@asarray.register("fast_array_utils.types:CSBase", "scipy.sparse")
def _(x: types.CSBase) -> NDArray[Any]:
    """Handle scipy sparse input by delegating to ``.scipy.to_dense``."""
    from .scipy import to_dense

    dense = to_dense(x)
    return dense


@asarray.register("dask.array:Array")
def _(x: types.DaskArray) -> NDArray[Any]:
    """Compute the dask array, then dispatch ``asarray`` again on the result."""
    computed = x.compute()  # type: ignore[no-untyped-call]
    return asarray(computed)


@asarray.register(OutOfCoreDataset)
def _(x: types.OutOfCoreDataset[types.CSBase | NDArray[Any]]) -> NDArray[Any]:
    """Load the dataset with ``to_memory()``, then dispatch ``asarray`` again on the result."""
    in_memory = x.to_memory()
    return asarray(in_memory)


@asarray.register("cupy:ndarray")
def _(x: types.CupyArray) -> NDArray[Any]:
    """Return ``x.get()`` (presumably the host-side numpy copy; confirm against cupy docs)."""
    fetched = x.get()
    return cast(NDArray[Any], fetched)


@asarray.register("cupyx.scipy.sparse:spmatrix")
def _(x: types.CupySparseMatrix) -> NDArray[Any]:
    """Return ``x.toarray().get()``: convert with ``toarray()`` first, then fetch like a cupy array."""
    as_cupy_array = x.toarray()
    return cast(NDArray[Any], as_cupy_array.get())
@_sum.register("fast_array_utils.types:CSBase", "scipy.sparse")
def _(
    x: types.CSBase, /, *, axis: Literal[0, 1, None] = None, dtype: DTypeLike | None = None
) -> NDArray[Any] | np.number[Any]:
    """Sum a scipy sparse container with ``np.sum``.

    ``cs?_matrix`` inputs are first re-wrapped as the ``cs?_array`` class of
    the same format (presumably so ``np.sum`` follows array rather than
    matrix semantics; confirm against scipy docs).
    """
    import scipy.sparse as sp

    from ..types import CSMatrix

    if isinstance(x, CSMatrix):
        # pick the array class matching the input's storage format
        as_array = sp.csr_array if x.format == "csr" else sp.csc_array
        x = as_array(x)
    total = np.sum(x, axis=axis, dtype=dtype)
    return cast(NDArray[Any] | np.number[Any], total)
= {} + + +def _register(name: str) -> Callable[[Callable[[], UnionType]], Callable[[], UnionType]]: + def _decorator(fn: Callable[[], UnionType]) -> Callable[[], UnionType]: + _REGISTRY[name] = fn + return fn + + return _decorator + + +@cache +def __getattr__(name: str) -> type | UnionType: + if (source := _REGISTRY.get(name)) is None: + # A name we don’t know about + raise AttributeError(name) from None + + try: + if callable(source): + return source() + + return cast(type, import_by_qualname(source)) + except ImportError: # A name we can’t import + return type(name, (), {}) + + +# lazy exports: + + if TYPE_CHECKING: from scipy.sparse import csc_array, csc_matrix, csr_array, csr_matrix @@ -28,53 +66,53 @@ CSMatrix = csr_matrix | csc_matrix CSBase = CSMatrix | CSArray else: - try: # cs?_array isn’t available in older scipy versions - from scipy.sparse import csc_array, csr_array + # cs?_array isn’t available in older scipy versions, + # so we import them separately - CSArray = csr_array | csc_array - except ImportError: # pragma: no cover - CSArray = type("CSArray", (), {}) - - try: # cs?_matrix is available when scipy is installed + @_register("CSMatrix") + def _get_cs_matrix() -> UnionType: from scipy.sparse import csc_matrix, csr_matrix - CSMatrix = csr_matrix | csc_matrix - except ImportError: # pragma: no cover - CSMatrix = type("CSMatrix", (), {}) + return csr_matrix | csc_matrix - CSBase = CSMatrix | CSArray + @_register("CSArray") + def _get_cs_array() -> UnionType: + from scipy.sparse import csc_array, csr_array + return csr_array | csc_array -if TYPE_CHECKING or find_spec("cupy"): - from cupy import ndarray as CupyArray -else: # pragma: no cover - CupyArray = type("ndarray", (), {}) + @_register("CSBase") + def _get_cs_base() -> UnionType: + return __getattr__("CSMatrix") | __getattr__("CSArray") -if TYPE_CHECKING or find_spec("cupyx"): +if TYPE_CHECKING: + from cupy import ndarray as CupyArray from cupyx.scipy.sparse import spmatrix as CupySparseMatrix 
-else: # pragma: no cover - CupySparseMatrix = type("spmatrix", (), {}) +else: + _REGISTRY["CupyArray"] = "cupy:ndarray" + _REGISTRY["CupySparseMatrix"] = "cupyx.scipy.sparse:spmatrix" if TYPE_CHECKING: # https://github.com/dask/dask/issues/8853 from dask.array.core import Array as DaskArray -elif find_spec("dask"): - from dask.array import Array as DaskArray -else: # pragma: no cover - DaskArray = type("array", (), {}) +else: + _REGISTRY["DaskArray"] = "dask.array:Array" -if TYPE_CHECKING or find_spec("h5py"): +if TYPE_CHECKING: from h5py import Dataset as H5Dataset -else: # pragma: no cover - H5Dataset = type("Dataset", (), {}) +else: + _REGISTRY["H5Dataset"] = "h5py:Dataset" -if TYPE_CHECKING or find_spec("zarr"): +if TYPE_CHECKING: from zarr import Array as ZarrArray -else: # pragma: no cover - ZarrArray = type("Array", (), {}) +else: + _REGISTRY["ZarrArray"] = "zarr:Array" + + +# protocols: @runtime_checkable