Skip to content
Open
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 92 additions & 8 deletions src/gt4py/eve/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,9 @@ def first(iterable: Iterable[T], *, default: Union[T, NothingType] = NOTHING) ->
raise error


def isinstancechecker(type_info: Union[Type, Iterable[Type]]) -> Callable[[Any], bool]:
def isinstancechecker(
type_info: Union[Type, Iterable[Type], types.UnionType],
) -> Callable[[Any], bool]:
"""Return a callable object that checks if operand is an instance of `type_info`.

Examples:
Expand All @@ -101,18 +103,22 @@ def isinstancechecker(type_info: Union[Type, Iterable[Type]]) -> Callable[[Any],
False

"""
types: Tuple[Type, ...] = tuple()
accepted_types: Tuple[Type, ...] = tuple()
if isinstance(type_info, type):
types = (type_info,)
accepted_types = (type_info,)
elif isinstance(type_info, types.UnionType):
accepted_types = type_info.__args__
elif not isinstance(type_info, tuple) and is_collection(type_info):
types = tuple(type_info)
accepted_types = tuple(type_info)
else:
types = type_info # type:ignore # it is checked at run-time
accepted_types = type_info # type:ignore # it is checked at run-time

if not isinstance(types, tuple) or not all(isinstance(t, type) for t in types):
raise ValueError(f"Invalid type(s) definition: '{types}'.")
if not isinstance(accepted_types, tuple) or not all(
isinstance(t, type) for t in accepted_types
):
raise ValueError(f"Invalid type(s) definition: '{accepted_types}'.")

return lambda obj: isinstance(obj, types)
return lambda obj: isinstance(obj, accepted_types)


def attrchecker(*names: str) -> Callable[[Any], bool]:
Expand Down Expand Up @@ -527,6 +533,84 @@ def partial(self, *args: Any, **kwargs: Any) -> fluid_partial:
return fluid_partial(self, *args, **kwargs)


class TypeMapping(collections.abc.Mapping[type, _T]):
"""
A mapping from types to values supporting complex type-based dispatching.

The mapping supports registering values for specific types, and
retrieving values based on the type key, supporting subtyping
relationship exactly in the same way as `functools.singledispatch()` works.
For example, if a value is registered for a base class, it will be returned
for instances of derived classes unless a more specific type is registered.

Examples:
>>> mapping = TypeMapping(lambda type_: f"Default for {type_}")
>>> mapping[int] = "Integer handler"
>>> mapping[int]
'Integer handler'
>>> mapping[float]
"Default for <class 'float'>"

>>> import collections
>>> mapping[tuple] = "Tuple handler"
>>> mapping[tuple]
'Tuple handler'
>>> mapping[collections.namedtuple("Point", ["x", "y"])]
'Tuple handler'
"""

def __init__(self, fallback_factory: Callable[[type], _T]) -> None:
self._fallback_factory = fallback_factory
self._dispatcher = functools.singledispatch(self._fallback_factory)

def __getitem__(self, type_: type) -> _T:
dispatched = self._dispatcher.dispatch(type_)
return (
self._fallback_factory(type_)
if dispatched is self._fallback_factory
else cast(_T, dispatched)
)

def __setitem__(self, type_: type, value: _T) -> None:
self._dispatcher.register(type_, value) # type: ignore[call-overload] # abusine singledispatch to register any value, not just callables
self.clear_cache()

def __iter__(self) -> Iterator[type]:
return iter(self._dispatcher.registry)

def __len__(self) -> int:
return len(self._dispatcher.registry)

def __contains__(self, type_: object) -> bool:
"""Check if a type is registered in the mapping (including via superclasses)."""
return self._dispatcher.dispatch(type_) is not self._fallback_factory

@overload
def register(self, type_: type, value: _T) -> _T: ...

@overload
def register(self, type_: type, value: NothingType = NOTHING) -> Callable[[_T], _T]: ...

def register(self, type_: type, value: _T | NothingType = NOTHING) -> _T | Callable[[_T], _T]:
"""Return a decorator to register a value for the given type."""

if value is not NOTHING:
assert not isinstance(value, NothingType)
self[type_] = value
return value
else:

def _decorator(value: _T) -> _T:
self[type_] = value
return value

return _decorator

def clear_cache(self) -> None:
"""Clear the type dispatching cache."""
self._dispatcher._clear_cache()


@overload
def with_fluid_partial(
func: Literal[None] = None, *args: Any, **kwargs: Any
Expand Down
8 changes: 8 additions & 0 deletions src/gt4py/next/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,12 @@
"""

# ruff: noqa: F401
from __future__ import annotations


# reexport the actual configuration manager instance as a public attribute
from ._config import Config as config_type, config # ruff: isort: skip

from .._core.definitions import CUPY_DEVICE_TYPE, Device, DeviceType, is_scalar_type
from . import common, ffront, iterator, program_processors, typing
from .common import (
Expand Down Expand Up @@ -52,6 +58,8 @@
__all__ = [
# submodules
"common",
"config",
"config_type",
"ffront",
"iterator",
"program_processors",
Expand Down
Loading
Loading