From d2afe4e26aba72359619f7008f6401fe741a07a8 Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Fri, 25 Oct 2024 12:05:00 +0200 Subject: [PATCH] Hopefully fixing deserialization issues once and for all. --- jaxtyping/_array_types.py | 31 +++++++-------- test/requirements.txt | 1 + test/test_array.py | 80 ++++++++++++++++++++++++--------------- test/test_equals.py | 35 ----------------- 4 files changed, 65 insertions(+), 82 deletions(-) delete mode 100644 test/test_equals.py diff --git a/jaxtyping/_array_types.py b/jaxtyping/_array_types.py index 51f33da..b6d4c53 100644 --- a/jaxtyping/_array_types.py +++ b/jaxtyping/_array_types.py @@ -320,26 +320,23 @@ def _check_shape( @ft.lru_cache(maxsize=None) def _make_metaclass(base_metaclass): class MetaAbstractArray(_MetaAbstractArray, base_metaclass): - def _get_props(cls): - props_tuple = ( - cls.index_variadic, - cls.dims, - cls.array_type, - cls.dtypes, - cls.dim_str, - ) - return props_tuple - + # We have to use identity-based eq/hash behaviour. The reason for this is that + # when deserializing using cloudpickle (very common, it seems), that cloudpickle + # will actually attempt to put a partially constructed class in a dictionary. + # So if we start accessing `cls.index_variadic` and the like here, then that + # explodes. + # See + # https://github.com/patrick-kidger/jaxtyping/issues/198 + # https://github.com/patrick-kidger/jaxtyping/issues/261 + # + # This does mean that if you want to compare two array annotations for equality + # (e.g. this happens in jaxtyping's tests as part of checking correctness) then + # a custom equality function must be used -- we can't put it here. def __eq__(cls, other): - if type(cls) is not type(other): - return False - - return cls._get_props() == other._get_props() + return cls is other def __hash__(cls): - # Does not use `_get_props` as these attributes don't necessarily exist - # during depickling. See #198. - return 0 + return id(cls) return MetaAbstractArray diff --git a/test/requirements.txt b/test/requirements.txt index fcd8451..44fc63b 100644 --- a/test/requirements.txt +++ b/test/requirements.txt @@ -3,6 +3,7 @@ cloudpickle equinox IPython jax +numpy<2 pytest pytest-asyncio tensorflow diff --git a/test/test_array.py b/test/test_array.py index 4600d54..99d7d45 100644 --- a/test/test_array.py +++ b/test/test_array.py @@ -33,6 +33,7 @@ torch = None from jaxtyping import ( + AbstractArray, AbstractDtype, AnnotationError, Array, @@ -528,6 +529,15 @@ class A: A(3, jnp.zeros(4)) +def _to_set(x) -> set[tuple]: + return { + (xi.index_variadic, xi.dims, xi.array_type, xi.dtypes, xi.dim_str) + if issubclass(xi, AbstractArray) + else xi + for xi in x + } + + def test_arraylike(typecheck, getkey): floatlike1 = Float32[ArrayLike, ""] floatlike2 = Float[ArrayLike, ""] @@ -536,41 +546,51 @@ def test_arraylike(typecheck, getkey): assert get_origin(floatlike1) is Union assert get_origin(floatlike2) is Union assert get_origin(floatlike3) is Union - assert set(get_args(floatlike1)) == { - Float32[Array, ""], - Float32[np.ndarray, ""], - Float32[np.number, ""], - float, - } - assert set(get_args(floatlike2)) == { - Float[Array, ""], - Float[np.ndarray, ""], - Float[np.number, ""], - float, - } - assert set(get_args(floatlike3)) == { - Float32[Array, "4"], - Float32[np.ndarray, "4"], - } + assert _to_set(get_args(floatlike1)) == _to_set( + [ + Float32[Array, ""], + Float32[np.ndarray, ""], + Float32[np.number, ""], + float, + ] + ) + assert _to_set(get_args(floatlike2)) == _to_set( + [ + Float[Array, ""], + Float[np.ndarray, ""], + Float[np.number, ""], + float, + ] + ) + assert _to_set(get_args(floatlike3)) == _to_set( + [ + Float32[Array, "4"], + Float32[np.ndarray, "4"], + ] + ) shaped1 = Shaped[ArrayLike, ""] shaped2 = Shaped[ArrayLike, "4"] assert get_origin(shaped1) is Union assert get_origin(shaped2) is Union - assert set(get_args(shaped1)) == { - Shaped[Array, ""], - Shaped[np.ndarray, ""], - Shaped[np.bool_, ""], - Shaped[np.number, ""], - bool, - int, - float, - complex, - } - assert set(get_args(shaped2)) == { - Shaped[Array, "4"], - Shaped[np.ndarray, "4"], - } + assert _to_set(get_args(shaped1)) == _to_set( + [ + Shaped[Array, ""], + Shaped[np.ndarray, ""], + Shaped[np.bool_, ""], + Shaped[np.number, ""], + bool, + int, + float, + complex, + ] + ) + assert _to_set(get_args(shaped2)) == _to_set( + [ + Shaped[Array, "4"], + Shaped[np.ndarray, "4"], + ] + ) def test_subclass(): diff --git a/test/test_equals.py b/test/test_equals.py deleted file mode 100644 index b648ff6..0000000 --- a/test/test_equals.py +++ /dev/null @@ -1,35 +0,0 @@ -from typing import Tuple, Union - -import pytest - -from jaxtyping import ( - Array, - Float, - Float32, - Integer, - PRNGKeyArray, - PyTree, - Shaped, -) - - -@pytest.mark.parametrize( - "make_fn", - [ - lambda: Float[Array, "4"], - lambda: Float32[Array, ""], - lambda: Integer[Array, "1 2 3"], - lambda: Shaped[PRNGKeyArray, "2"], - lambda: Float[float, "#*shape"], - lambda: PyTree[int], - lambda: PyTree[Float[Array, ""]], - lambda: PyTree[Float32[Array, "*m b c"]], - lambda: PyTree[PyTree[Float32[Array, "1 2 b *"]]], - lambda: PyTree[Union[str, Float32[Array, "1"]]], - lambda: PyTree[ - Tuple[int, float, Float[Array, ""], PyTree[Union[Float[Array, ""], float]]] - ], - ], -) -def test_equals(make_fn): - assert make_fn() == make_fn()