From d2afe4e26aba72359619f7008f6401fe741a07a8 Mon Sep 17 00:00:00 2001
From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com>
Date: Fri, 25 Oct 2024 12:05:00 +0200
Subject: [PATCH] Hopefully fixing deserialization issues once and for all.

---
 jaxtyping/_array_types.py | 31 +++++++--------
 test/requirements.txt     |  1 +
 test/test_array.py        | 80 ++++++++++++++++++++++++---------------
 test/test_equals.py       | 35 -----------------
 4 files changed, 65 insertions(+), 82 deletions(-)
 delete mode 100644 test/test_equals.py

diff --git a/jaxtyping/_array_types.py b/jaxtyping/_array_types.py
index 51f33da..b6d4c53 100644
--- a/jaxtyping/_array_types.py
+++ b/jaxtyping/_array_types.py
@@ -320,26 +320,23 @@ def _check_shape(
 @ft.lru_cache(maxsize=None)
 def _make_metaclass(base_metaclass):
     class MetaAbstractArray(_MetaAbstractArray, base_metaclass):
-        def _get_props(cls):
-            props_tuple = (
-                cls.index_variadic,
-                cls.dims,
-                cls.array_type,
-                cls.dtypes,
-                cls.dim_str,
-            )
-            return props_tuple
-
+        # We have to use identity-based eq/hash behaviour. The reason for this is that
+        # when deserializing using cloudpickle (very common, it seems), that cloudpickle
+        # will actually attempt to put a partially constructed class in a dictionary.
+        # So if we start accessing `cls.index_variadic` and the like here, then that
+        # explodes.
+        # See
+        # https://github.com/patrick-kidger/jaxtyping/issues/198
+        # https://github.com/patrick-kidger/jaxtyping/issues/261
+        #
+        # This does mean that if you want to compare two array annotations for equality
+        # (e.g. this happens in jaxtyping's tests as part of checking correctness) then
+        # a custom equality function must be used -- we can't put it here.
         def __eq__(cls, other):
-            if type(cls) is not type(other):
-                return False
-
-            return cls._get_props() == other._get_props()
+            return cls is other
 
         def __hash__(cls):
-            # Does not use `_get_props` as these attributes don't necessarily exist
-            # during depickling. See #198.
-            return 0
+            return id(cls)
 
     return MetaAbstractArray
 
diff --git a/test/requirements.txt b/test/requirements.txt
index fcd8451..44fc63b 100644
--- a/test/requirements.txt
+++ b/test/requirements.txt
@@ -3,6 +3,7 @@ cloudpickle
 equinox
 IPython
 jax
+numpy<2
 pytest
 pytest-asyncio
 tensorflow
diff --git a/test/test_array.py b/test/test_array.py
index 4600d54..99d7d45 100644
--- a/test/test_array.py
+++ b/test/test_array.py
@@ -33,6 +33,7 @@
     torch = None
 
 from jaxtyping import (
+    AbstractArray,
     AbstractDtype,
     AnnotationError,
     Array,
@@ -528,6 +529,15 @@ class A:
         A(3, jnp.zeros(4))
 
 
+def _to_set(x) -> set[tuple]:
+    return {
+        (xi.index_variadic, xi.dims, xi.array_type, xi.dtypes, xi.dim_str)
+        if issubclass(xi, AbstractArray)
+        else xi
+        for xi in x
+    }
+
+
 def test_arraylike(typecheck, getkey):
     floatlike1 = Float32[ArrayLike, ""]
     floatlike2 = Float[ArrayLike, ""]
@@ -536,41 +546,51 @@ def test_arraylike(typecheck, getkey):
     assert get_origin(floatlike1) is Union
     assert get_origin(floatlike2) is Union
     assert get_origin(floatlike3) is Union
-    assert set(get_args(floatlike1)) == {
-        Float32[Array, ""],
-        Float32[np.ndarray, ""],
-        Float32[np.number, ""],
-        float,
-    }
-    assert set(get_args(floatlike2)) == {
-        Float[Array, ""],
-        Float[np.ndarray, ""],
-        Float[np.number, ""],
-        float,
-    }
-    assert set(get_args(floatlike3)) == {
-        Float32[Array, "4"],
-        Float32[np.ndarray, "4"],
-    }
+    assert _to_set(get_args(floatlike1)) == _to_set(
+        [
+            Float32[Array, ""],
+            Float32[np.ndarray, ""],
+            Float32[np.number, ""],
+            float,
+        ]
+    )
+    assert _to_set(get_args(floatlike2)) == _to_set(
+        [
+            Float[Array, ""],
+            Float[np.ndarray, ""],
+            Float[np.number, ""],
+            float,
+        ]
+    )
+    assert _to_set(get_args(floatlike3)) == _to_set(
+        [
+            Float32[Array, "4"],
+            Float32[np.ndarray, "4"],
+        ]
+    )
 
     shaped1 = Shaped[ArrayLike, ""]
     shaped2 = Shaped[ArrayLike, "4"]
     assert get_origin(shaped1) is Union
     assert get_origin(shaped2) is Union
-    assert set(get_args(shaped1)) == {
-        Shaped[Array, ""],
-        Shaped[np.ndarray, ""],
-        Shaped[np.bool_, ""],
-        Shaped[np.number, ""],
-        bool,
-        int,
-        float,
-        complex,
-    }
-    assert set(get_args(shaped2)) == {
-        Shaped[Array, "4"],
-        Shaped[np.ndarray, "4"],
-    }
+    assert _to_set(get_args(shaped1)) == _to_set(
+        [
+            Shaped[Array, ""],
+            Shaped[np.ndarray, ""],
+            Shaped[np.bool_, ""],
+            Shaped[np.number, ""],
+            bool,
+            int,
+            float,
+            complex,
+        ]
+    )
+    assert _to_set(get_args(shaped2)) == _to_set(
+        [
+            Shaped[Array, "4"],
+            Shaped[np.ndarray, "4"],
+        ]
+    )
 
 
 def test_subclass():
diff --git a/test/test_equals.py b/test/test_equals.py
deleted file mode 100644
index b648ff6..0000000
--- a/test/test_equals.py
+++ /dev/null
@@ -1,35 +0,0 @@
-from typing import Tuple, Union
-
-import pytest
-
-from jaxtyping import (
-    Array,
-    Float,
-    Float32,
-    Integer,
-    PRNGKeyArray,
-    PyTree,
-    Shaped,
-)
-
-
-@pytest.mark.parametrize(
-    "make_fn",
-    [
-        lambda: Float[Array, "4"],
-        lambda: Float32[Array, ""],
-        lambda: Integer[Array, "1 2 3"],
-        lambda: Shaped[PRNGKeyArray, "2"],
-        lambda: Float[float, "#*shape"],
-        lambda: PyTree[int],
-        lambda: PyTree[Float[Array, ""]],
-        lambda: PyTree[Float32[Array, "*m b c"]],
-        lambda: PyTree[PyTree[Float32[Array, "1 2 b *"]]],
-        lambda: PyTree[Union[str, Float32[Array, "1"]]],
-        lambda: PyTree[
-            Tuple[int, float, Float[Array, ""], PyTree[Union[Float[Array, ""], float]]]
-        ],
-    ],
-)
-def test_equals(make_fn):
-    assert make_fn() == make_fn()