Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 77 additions & 32 deletions src/gt4py/next/embedded/nd_array_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@

from gt4py._core import definitions as core_defs
from gt4py.eve.extended_typing import (
Any,
ClassVar,
Final,
Iterable,
Never,
Optional,
Expand Down Expand Up @@ -102,6 +104,32 @@ def _builtin_op(*fields: common.Field | core_defs.Scalar) -> NdArrayField:
_R = TypeVar("_R", _Value, tuple[_Value, ...])


_GT4PY_FIELD_METADATA_NS: Final = "__gt4py__"


def _gt4py_field_metadata(**kwargs: Any) -> dict[str, dict[str, Any]]:
    """Create a dataclass-field metadata dict namespaced under the GT4Py key.

    Namespacing under `_GT4PY_FIELD_METADATA_NS` avoids collisions with
    metadata that other libraries may attach to the same dataclass field.
    """
    return {_GT4PY_FIELD_METADATA_NS: dict(kwargs)}


@functools.cache
def _get_pickleable_field_names(dataclass_type: type) -> tuple[str, ...]:
    """Return the names of the dataclass fields that should be pickled.

    Individual fields can opt out of serialization via dataclass
    definitions: `field(<other args>, metadata=_gt4py_field_metadata(skip_pickle=True))`.

    Raises:
        TypeError: If `dataclass_type` is not a dataclass type.
    """
    if not isinstance(dataclass_type, type) or not dataclasses.is_dataclass(dataclass_type):
        raise TypeError(f"Expected a dataclass type, got '{dataclass_type}'")

    names = []
    for field in dataclasses.fields(dataclass_type):
        # Bug fix: look up the namespace constant's value ("__gt4py__"), not the
        # literal string "_GT4PY_FIELD_METADATA_NS"; otherwise the `skip_pickle`
        # opt-out is silently ignored because the lookup always misses.
        metadata = field.metadata.get(_GT4PY_FIELD_METADATA_NS, None)
        if metadata and metadata.get("skip_pickle"):
            continue
        names.append(field.name)

    return tuple(names)


@dataclasses.dataclass(frozen=True)
class NdArrayField(
common.MutableField[common.DimsT, core_defs.ScalarT], common.FieldBuiltinFuncRegistry
Expand Down Expand Up @@ -443,6 +471,20 @@ def _slice(
assert common.is_relative_index_sequence(slice_)
return new_domain, slice_

def __getstate__(self) -> dict[str, Any]:
    """Serialize only the pickleable dataclass fields of this instance.

    Runtime-only cached attributes (e.g. `functools.cached_property` values
    living in `__dict__`) are deliberately excluded so they cannot influence
    serialization or fingerprints.
    """
    state: dict[str, Any] = {}
    for field_name in _get_pickleable_field_names(type(self)):
        state[field_name] = getattr(self, field_name)
    return state

def __setstate__(self, state: dict[str, Any]) -> None:
if hasattr(self, "__dict__"):
self.__dict__.clear()
self.__dict__.update(state)
else:
# In case the dataclass is frozen or uses slots, we need to use `object.__setattr__` for each field
for key, value in state.items():
object.__setattr__(self, key, value)

if dace:

def _dace_data_ptr(self) -> int:
Expand Down Expand Up @@ -492,9 +534,35 @@ def __post_init__(self) -> None:
self.domain.dim_index(self.codomain) is not None
)

@functools.cached_property
def _cache(self) -> dict:
    # Lazily-created, per-instance runtime cache; lives in `__dict__` only
    # after first access and is not part of the field's serialized value.
    return {}
@classmethod
def from_array(  # type: ignore[override]
    cls,
    data: npt.ArrayLike | core_defs.NDArrayObject,
    /,
    codomain: common.DimT,
    *,
    domain: common.DomainLike,
    dtype: Optional[core_defs.DTypeLike] = None,
    skip_value: Optional[core_defs.IntegralScalar] = None,
) -> NdArrayConnectivityField:
    """Construct a connectivity field from array-like `data`.

    The array is converted via the class' array namespace; the result must
    have an integral dtype and match the rank and extents of `domain`
    (size-1 axes are accepted for broadcasting).
    """
    domain = common.domain(domain)
    ns = cls.array_ns

    requested_scalar_type = None if dtype is None else core_defs.dtype(dtype).scalar_type
    target_dtype = None if requested_scalar_type is None else ns.dtype(requested_scalar_type)
    array = ns.asarray(data, dtype=target_dtype)

    if requested_scalar_type is not None:
        assert array.dtype.type == requested_scalar_type
    # Connectivities store neighbor indices, hence integral values only.
    assert issubclass(array.dtype.type, core_defs.INTEGRAL_TYPES)

    assert all(isinstance(dim, common.Dimension) for dim in domain.dims), domain
    assert array.ndim == len(domain)
    assert all(size == len(rng) or size == 1 for rng, size in zip(domain.ranges, array.shape))

    assert isinstance(codomain, common.Dimension)

    return cls(domain, array, codomain, _skip_value=skip_value)

@classmethod
def __gt_builtin_func__(cls, _: fbuiltins.BuiltInFunction) -> Never: # type: ignore[override]
Expand Down Expand Up @@ -525,35 +593,12 @@ def kind(self) -> common.ConnectivityKind:

return self._kind

@classmethod
def from_array(  # type: ignore[override]
    cls,
    data: npt.ArrayLike | core_defs.NDArrayObject,
    /,
    codomain: common.DimT,
    *,
    domain: common.DomainLike,
    dtype: Optional[core_defs.DTypeLike] = None,
    skip_value: Optional[core_defs.IntegralScalar] = None,
) -> NdArrayConnectivityField:
    """Construct a connectivity field from array-like `data`.

    The data is converted through the class' array namespace (`cls.array_ns`)
    and must end up with an integral dtype whose shape matches `domain`
    (size-1 axes are accepted).
    """
    domain = common.domain(domain)
    xp = cls.array_ns

    xp_dtype = None if dtype is None else xp.dtype(core_defs.dtype(dtype).scalar_type)
    array = xp.asarray(data, dtype=xp_dtype)

    if dtype is not None:
        assert array.dtype.type == core_defs.dtype(dtype).scalar_type

    # Connectivities store neighbor indices, hence integral values only.
    assert issubclass(array.dtype.type, core_defs.INTEGRAL_TYPES)

    assert all(isinstance(d, common.Dimension) for d in domain.dims), domain
    assert len(domain) == array.ndim
    assert all(len(r) == s or s == 1 for r, s in zip(domain.ranges, array.shape))

    assert isinstance(codomain, common.Dimension)

    return cls(domain, array, codomain, _skip_value=skip_value)
# This embedded run-time cache only speeds up repeated calls to
# `inverse_image` and `restrict`; it is not part of the connectivity field
# definition and therefore must not be serialized.
@functools.cached_property
def _cache(self) -> dict:
    # Lazily created on first access; stored in the instance `__dict__`.
    return {}

def inverse_image(self, image_range: common.UnitRange | common.NamedRange) -> common.Domain:
cache_key = hash((id(self.ndarray), self.domain, image_range))
Expand Down
100 changes: 100 additions & 0 deletions tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import math
import operator
import pickle
from typing import Callable, Iterable, Optional

import numpy as np
Expand Down Expand Up @@ -819,6 +820,61 @@ def test_setitem_wrong_domain():
field[(1, slice(None))] = value_incompatible


def test_nd_array_field_getstate_excludes_cached_properties():
    """Test that __getstate__ only serializes dataclass fields, excluding cached properties."""
    field = constructors.as_field(
        (D0, D1), np.asarray([[1.0, 2.0], [3.0, 4.0]], dtype=np.float64)
    )

    # Populate the cached-property entry in the instance __dict__.
    _ = field.__gt_buffer_info__
    assert "__gt_buffer_info__" in field.__dict__

    state = field.__getstate__()

    # State should only contain dataclass instance fields (_domain, _ndarray).
    assert "_domain" in state
    assert "_ndarray" in state
    for runtime_only in ("array_ns", "__gt_buffer_info__", "__dict__"):
        assert runtime_only not in state  # ClassVar / cached data excluded


def test_nd_array_field_setstate_restores_state():
    """Test that __setstate__ properly restores field state without cached data."""
    source = constructors.as_field((D0, D1), np.arange(6.0).reshape(2, 3))
    target = constructors.as_field((D0, D1), np.zeros((2, 3)))

    # Transfer state from one field onto another.
    target.__setstate__(source.__getstate__())

    assert target.domain == source.domain
    assert np.array_equal(target.ndarray, source.ndarray)
    assert target.dtype == source.dtype


def test_nd_array_field_pickle_roundtrip():
    """Test that NdArrayField can be pickled and unpickled correctly using getstate/setstate."""
    original = constructors.as_field((D0, D1), np.arange(12.0).reshape(3, 4))

    # Populate a cached property so we can verify it is not serialized.
    _ = original.__gt_buffer_info__
    assert "__gt_buffer_info__" in original.__dict__

    # Real roundtrip through pickle exercises __getstate__ and __setstate__.
    restored = pickle.loads(pickle.dumps(original))

    assert restored.domain == original.domain
    assert np.array_equal(restored.ndarray, original.ndarray)
    assert restored.dtype == original.dtype
    assert restored.shape == original.shape
    # Cached property must not survive the roundtrip.
    assert "__gt_buffer_info__" not in restored.__dict__


def test_nd_array_connectivity_field_buffer_info(nd_array_implementation):
import dataclasses

Expand All @@ -839,6 +895,50 @@ def test_nd_array_connectivity_field_buffer_info(nd_array_implementation):
assert buffer_info is e2v_conn.__gt_buffer_info__


def test_nd_array_connectivity_field_getstate_excludes_runtime_cache():
    """The embedded `_cache` must never appear in serialized connectivity state."""
    vertex_dim = Dimension("V")
    edge_dim = Dimension("E")

    conn = common._connectivity(
        np.asarray([2, 3, 4, 5]),
        domain=common.domain([common.named_range((edge_dim, (0, 4)))]),
        codomain=vertex_dim,
    )

    # Trigger population of the embedded runtime cache.
    _ = conn.inverse_image(UnitRange(2, 5))
    assert "_cache" in conn.__dict__

    state = conn.__getstate__()
    assert "_cache" not in state
    assert state["_codomain"] == vertex_dim
    assert state["_skip_value"] is None


def test_nd_array_connectivity_field_setstate_restores_state_without_cache():
    """Restoring connectivity state transfers fields but never the runtime cache."""
    vertex_dim = Dimension("V")
    edge_dim = Dimension("E")

    source = common._connectivity(
        np.asarray([2, 3, 4, 5]),
        domain=common.domain([common.named_range((edge_dim, (0, 4)))]),
        codomain=vertex_dim,
        skip_value=-1,
    )
    target = common._connectivity(
        np.asarray([0, 0, 0, 0]),
        domain=common.domain([common.named_range((edge_dim, (0, 4)))]),
        codomain=vertex_dim,
    )

    target.__setstate__(source.__getstate__())

    assert target.codomain == source.codomain
    assert target.skip_value == source.skip_value
    assert np.array_equal(target.ndarray, source.ndarray)
    assert "_cache" not in target.__dict__


def test_connectivity_field_inverse_image():
V = Dimension("V")
E = Dimension("E")
Expand Down
Loading