diff --git a/src/gt4py/eve/utils.py b/src/gt4py/eve/utils.py index a0e48ae557..ce6967bd53 100644 --- a/src/gt4py/eve/utils.py +++ b/src/gt4py/eve/utils.py @@ -464,6 +464,15 @@ def inner(*args: Any, **kwargs: Any) -> Any: return _decorator(func) if func is not None else _decorator +class EqualityBy(HashableBy): + """Use a hash function as the definition of equality for the wrapped object.""" + + __hash__ = HashableBy.__hash__ + + def __eq__(self, other: Any) -> bool: + return self is other or hash(self) == hash(other) + + # TODO(egparedes): it would be more efficient to implement the caching logic # here instead of relying on `functools.lru_cache` and wrapping/unwrapping the # arguments. @@ -477,7 +486,8 @@ def lru_cache( """ Wrap :func:`functools.lru_cache` but allow customizing the cache key. - Be careful: `key(obj1) == key(obj2)` must imply `obj1 == obj2`. + Be careful, with custom `key` functions, `key(obj1) == key(obj2)` automatically + implies `obj1 == obj2`, i.e. they are considered equal. >>> @lru_cache(key=id) ... def func(x): @@ -504,8 +514,8 @@ def cached_func(*args: HashableBy, **kwargs: HashableBy) -> _T: @functools.wraps(func) def inner(*args, **kwargs): # type: ignore[no-untyped-def] # cast below restores type info return cached_func( - *(hashable_by(key, arg) for arg in args), - **{k: hashable_by(key, arg) for k, arg in kwargs.items()}, + *(EqualityBy(key, arg) for arg in args), + **{k: EqualityBy(key, arg) for k, arg in kwargs.items()}, ) inner.cache_parameters = cached_func.cache_parameters # type: ignore[attr-defined] # mypy not aware of functools.lru_cache behavior diff --git a/tests/eve_tests/unit_tests/test_utils.py b/tests/eve_tests/unit_tests/test_utils.py index ae8e938396..63ddd6d68e 100644 --- a/tests/eve_tests/unit_tests/test_utils.py +++ b/tests/eve_tests/unit_tests/test_utils.py @@ -270,6 +270,22 @@ def func(x): assert cached.cache_info().misses == 1 +def test_lru_cache_no_eq_call(): + class A: + def __hash__(self) -> int: + return 1 + + def __eq__(self, other): + raise ValueError() # this function should never be called + + @eve.utils.lru_cache(key=lambda x: hash(x)) + def func(x): + pass + + func(A()) + func(A()) + + def test_fluid_partial(): from gt4py.eve.utils import fluid_partial