Skip to content

Switch to lazy exports #22

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 15 commits into
base: main
Choose a base branch
from
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ lint.ignore = [
]
lint.per-file-ignores."docs/**/*.py" = [ "INP001" ] # No __init__.py in docs
lint.per-file-ignores."src/**/stats/*.py" = [ "A001", "A004" ] # Shadows builtins like `sum`
lint.per-file-ignores."src/fast_array_utils/types.py" = [ "N806" ] # We have variables that are classes here
lint.per-file-ignores."stubs/**/*.pyi" = [ "F403", "F405", "N801" ] # Stubs don’t follow name conventions
lint.per-file-ignores."tests/**/test_*.py" = [
"D100", # tests need no module docstrings
Expand Down
4 changes: 1 addition & 3 deletions src/fast_array_utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,7 @@

from __future__ import annotations

from . import _patches, conv, stats, types
from . import conv, stats, types


__all__ = ["conv", "stats", "types"]

_patches.patch_dask()
94 changes: 94 additions & 0 deletions src/fast_array_utils/_import.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
# SPDX-License-Identifier: MPL-2.0
from __future__ import annotations

from dataclasses import dataclass, field
from functools import cache
from types import UnionType
from typing import TYPE_CHECKING, Generic, ParamSpec, TypeVar, cast, overload


if TYPE_CHECKING:
from collections.abc import Callable

# Type variables for `lazy_singledispatch`: `P` captures the parameters of the
# wrapped fallback function, `R` its return type.
P = ParamSpec("P")
R = TypeVar("R")


__all__ = ["import_by_qualname", "lazy_singledispatch"]


def import_by_qualname(qualname: str) -> object:
    """Import a module and return an object reachable from it.

    Parameters
    ----------
    qualname
        Reference of the form ``"module.path:object.path"``.
        The part before the colon is imported with :func:`importlib.import_module`,
        the part after it is resolved attribute by attribute (split on ``"."``).

    Returns
    -------
    The object found at the end of the attribute path.

    Raises
    ------
    ValueError
        If ``qualname`` does not contain exactly one ``":"``.
    ImportError
        If the module cannot be imported or an attribute on the path is missing.
    """
    from importlib import import_module

    try:
        mod_path, obj_path = qualname.split(":")
    except ValueError:
        msg = f"qualname {qualname!r} must have the form 'module.path:object.path'"
        raise ValueError(msg) from None

    mod = import_module(mod_path)

    # Anything imported from dask triggers the package's dask patch first.
    if mod_path == "dask" or mod_path.startswith("dask."):
        from ._patches import patch_dask

        patch_dask()

    # Walk the attribute path starting at the module.
    obj: object = mod
    for name in obj_path.split("."):
        try:
            obj = getattr(obj, name)
        except AttributeError as e:
            # NOTE(review): Codecov flagged this error path as not covered by tests.
            # `obj_path` and `mod_path` are plain strings, so they are formatted as-is.
            msg = f"Could not import {obj_path} from {mod_path}"
            raise ImportError(msg) from e
    return obj


@dataclass
class lazy_singledispatch(Generic[P, R]): # noqa: N801
fallback: Callable[P, R]

_lazy: dict[tuple[str, str], Callable[..., R]] = field(init=False, default_factory=dict)
_eager: dict[type | UnionType, Callable[..., R]] = field(init=False, default_factory=dict)

def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R:
fn = self.dispatch(type(args[0])) # type: ignore[arg-type] # https://github.com/python/mypy/issues/11470
return fn(*args, **kwargs)

def __hash__(self) -> int:
return hash(self.fallback)

@cache # noqa: B019
def dispatch(self, typ: type) -> Callable[P, R]:
for cls_reg, fn in self._eager.items():
if issubclass(typ, cls_reg):
return fn

Check warning on line 61 in src/fast_array_utils/_import.py

View check run for this annotation

Codecov / codecov/patch

src/fast_array_utils/_import.py#L61

Added line #L61 was not covered by tests
for (import_qualname, host_mod_name), fn in self._lazy.items():
for cls in typ.mro():
if cls.__module__.startswith(host_mod_name): # can be deeper
cls_reg = cast(type, import_by_qualname(import_qualname))
if issubclass(typ, cls_reg):
return fn
return self.fallback

@overload
def register(
self, qualname_or_type: str, /, host_mod_name: str | None = None
) -> Callable[[Callable[..., R]], lazy_singledispatch[P, R]]: ...
@overload
def register(
self, qualname_or_type: type | UnionType, /, host_mod_name: None = None
) -> Callable[[Callable[..., R]], lazy_singledispatch[P, R]]: ...

def register(
self, qualname_or_type: str | type | UnionType, /, host_mod_name: str | None = None
) -> Callable[[Callable[..., R]], lazy_singledispatch[P, R]]:
def decorator(fn: Callable[..., R]) -> lazy_singledispatch[P, R]:
match qualname_or_type, host_mod_name:
case str(), _:
hmn = qualname_or_type.split(":")[0] if host_mod_name is None else host_mod_name
self._lazy[(qualname_or_type, hmn)] = fn
case type() | UnionType(), None:
self._eager[qualname_or_type] = fn
case _:
msg = f"name_or_type {qualname_or_type!r} must be a str, type, or UnionType"
raise TypeError(msg)
return self

return decorator
18 changes: 10 additions & 8 deletions src/fast_array_utils/conv/_asarray.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,26 @@
# SPDX-License-Identifier: MPL-2.0
from __future__ import annotations

from functools import singledispatch
from typing import TYPE_CHECKING, Any, cast

import numpy as np
from numpy.typing import NDArray

from .. import types
from .._import import lazy_singledispatch
from ..types import OutOfCoreDataset


if TYPE_CHECKING:
from numpy.typing import ArrayLike

from .. import types


__all__ = ["asarray"]


# fallback’s arg0 type has to include types of registered functions
@singledispatch
@lazy_singledispatch
def asarray(
x: ArrayLike
| types.CSBase
Expand All @@ -44,28 +46,28 @@ def asarray(
return np.asarray(x)


@asarray.register(types.CSBase)
@asarray.register("fast_array_utils.types:CSBase", "scipy.sparse")
def _(x: types.CSBase) -> NDArray[Any]:
from .scipy import to_dense

return to_dense(x)


@asarray.register(types.DaskArray)
@asarray.register("dask.array:Array")
def _(x: types.DaskArray) -> NDArray[Any]:
return asarray(x.compute()) # type: ignore[no-untyped-call]


@asarray.register(types.OutOfCoreDataset)
@asarray.register(OutOfCoreDataset)
def _(x: types.OutOfCoreDataset[types.CSBase | NDArray[Any]]) -> NDArray[Any]:
return asarray(x.to_memory())


@asarray.register(types.CupyArray)
@asarray.register("cupy:ndarray")
def _(x: types.CupyArray) -> NDArray[Any]:
return cast(NDArray[Any], x.get())


@asarray.register(types.CupySparseMatrix)
@asarray.register("cupyx.scipy.sparse:spmatrix")
def _(x: types.CupySparseMatrix) -> NDArray[Any]:
return cast(NDArray[Any], x.toarray().get())
14 changes: 8 additions & 6 deletions src/fast_array_utils/stats/_sum.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
# SPDX-License-Identifier: MPL-2.0
from __future__ import annotations

from functools import partial, singledispatch
from functools import partial
from typing import TYPE_CHECKING, Any, cast, overload

import numpy as np
from numpy.typing import NDArray

from .. import types
from .._import import lazy_singledispatch
from .._validation import validate_axis


Expand Down Expand Up @@ -54,30 +55,31 @@ def sum(
return _sum(x, axis=axis, dtype=dtype)


@singledispatch
@lazy_singledispatch
def _sum(
x: ArrayLike | types.CSBase | types.DaskArray,
/,
*,
axis: Literal[0, 1, None] = None,
dtype: DTypeLike | None = None,
) -> NDArray[Any] | np.number[Any] | types.DaskArray:
assert not isinstance(x, types.CSBase | types.DaskArray)
return cast(NDArray[Any] | np.number[Any], np.sum(x, axis=axis, dtype=dtype))


@_sum.register(types.CSBase)
@_sum.register("fast_array_utils.types:CSBase", "scipy.sparse")
def _(
x: types.CSBase, /, *, axis: Literal[0, 1, None] = None, dtype: DTypeLike | None = None
) -> NDArray[Any] | np.number[Any]:
import scipy.sparse as sp

if isinstance(x, types.CSMatrix):
from ..types import CSMatrix

if isinstance(x, CSMatrix):
x = sp.csr_array(x) if x.format == "csr" else sp.csc_array(x)
return cast(NDArray[Any] | np.number[Any], np.sum(x, axis=axis, dtype=dtype))


@_sum.register(types.DaskArray)
@_sum.register("dask.array:Array")
def _(
x: types.DaskArray, /, *, axis: Literal[0, 1, None] = None, dtype: DTypeLike | None = None
) -> types.DaskArray:
Expand Down
100 changes: 69 additions & 31 deletions src/fast_array_utils/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,15 @@

from __future__ import annotations

from importlib.util import find_spec
from typing import TYPE_CHECKING, Generic, Protocol, TypeVar, runtime_checkable
from functools import cache
from typing import TYPE_CHECKING, Generic, Protocol, TypeVar, cast, runtime_checkable

from ._import import import_by_qualname


if TYPE_CHECKING:
from collections.abc import Callable
from types import UnionType


__all__ = [
Expand All @@ -20,61 +27,92 @@
T_co = TypeVar("T_co", covariant=True)


# scipy sparse
# registry for lazy exports:


# Maps an exported name to its source: either a "module:QualName" string
# (resolved via `import_by_qualname`) or a zero-argument factory returning the type.
_REGISTRY: dict[str, str | Callable[[], UnionType]] = {}


def _register(name: str) -> Callable[[Callable[[], UnionType]], Callable[[], UnionType]]:
    """Return a decorator that stores a factory in ``_REGISTRY`` under ``name``.

    The decorated function is returned unchanged.
    """

    def _decorator(factory: Callable[[], UnionType]) -> Callable[[], UnionType]:
        _REGISTRY.update({name: factory})
        return factory

    return _decorator


@cache
def __getattr__(name: str) -> type | UnionType:
    """Resolve a lazily exported name of this module.

    Looks ``name`` up in ``_REGISTRY`` and either calls the registered factory
    or imports the registered ``"module:QualName"``. If that raises
    :class:`ImportError`, a new empty placeholder class called ``name`` is
    returned instead. Successful results are cached.

    Raises
    ------
    AttributeError
        If ``name`` is not in ``_REGISTRY``.
    """
    source = _REGISTRY.get(name)
    if source is None:
        # A name we don’t know about
        raise AttributeError(name) from None

    try:
        resolved = source() if callable(source) else cast(type, import_by_qualname(source))
    except ImportError:  # A name we can’t import
        resolved = type(name, (), {})
    return resolved


# lazy exports:


if TYPE_CHECKING:
from scipy.sparse import csc_array, csc_matrix, csr_array, csr_matrix

CSArray = csr_array | csc_array
CSMatrix = csr_matrix | csc_matrix
CSBase = CSMatrix | CSArray
else:
try: # cs?_array isn’t available in older scipy versions
from scipy.sparse import csc_array, csr_array
# cs?_array isn’t available in older scipy versions,
# so we import them separately

CSArray = csr_array | csc_array
except ImportError: # pragma: no cover
CSArray = type("CSArray", (), {})

try: # cs?_matrix is available when scipy is installed
@_register("CSMatrix")
def _get_cs_matrix() -> UnionType:
from scipy.sparse import csc_matrix, csr_matrix

CSMatrix = csr_matrix | csc_matrix
except ImportError: # pragma: no cover
CSMatrix = type("CSMatrix", (), {})
return csr_matrix | csc_matrix

CSBase = CSMatrix | CSArray
@_register("CSArray")
def _get_cs_array() -> UnionType:
from scipy.sparse import csc_array, csr_array

return csr_array | csc_array

if TYPE_CHECKING or find_spec("cupy"):
from cupy import ndarray as CupyArray
else: # pragma: no cover
CupyArray = type("ndarray", (), {})
@_register("CSBase")
def _get_cs_base() -> UnionType:
return __getattr__("CSMatrix") | __getattr__("CSArray")


if TYPE_CHECKING or find_spec("cupyx"):
if TYPE_CHECKING:
from cupy import ndarray as CupyArray
from cupyx.scipy.sparse import spmatrix as CupySparseMatrix
else: # pragma: no cover
CupySparseMatrix = type("spmatrix", (), {})
else:
_REGISTRY["CupyArray"] = "cupy:ndarray"
_REGISTRY["CupySparseMatrix"] = "cupyx.scipy.sparse:spmatrix"


if TYPE_CHECKING: # https://github.com/dask/dask/issues/8853
from dask.array.core import Array as DaskArray
elif find_spec("dask"):
from dask.array import Array as DaskArray
else: # pragma: no cover
DaskArray = type("array", (), {})
else:
_REGISTRY["DaskArray"] = "dask.array:Array"


if TYPE_CHECKING or find_spec("h5py"):
if TYPE_CHECKING:
from h5py import Dataset as H5Dataset
else: # pragma: no cover
H5Dataset = type("Dataset", (), {})
else:
_REGISTRY["H5Dataset"] = "h5py:Dataset"


if TYPE_CHECKING or find_spec("zarr"):
if TYPE_CHECKING:
from zarr import Array as ZarrArray
else: # pragma: no cover
ZarrArray = type("Array", (), {})
else:
_REGISTRY["ZarrArray"] = "zarr:Array"


# protocols:


@runtime_checkable
Expand Down
Loading