diff --git a/xarray/core/extension_array.py b/xarray/core/extension_array.py
index 096a427e425..f1631a5ea9e 100644
--- a/xarray/core/extension_array.py
+++ b/xarray/core/extension_array.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import copy
 from collections.abc import Callable, Sequence
 from dataclasses import dataclass
 from typing import Any, Generic, cast
@@ -148,4 +149,14 @@ def __getattr__(self, attr: str) -> Any:
         # Thus, if we didn't have `super().__getattribute__("array")` this method would call `self.array` (i.e., `getattr(self, "array")`) again while looking for `__setstate__`
         # (which is apparently the first thing sought in copy.copy from the under-construction copied object),
         # which would cause a recursion error since `array` is not present on the object when it is being constructed during `__{deep}copy__`.
+        # Even though `__copy__` and `__deepcopy__` are now defined below because of `test_extension_array_copy_arrow_type` (cause unknown),
+        # we leave this here as it is more robust than `self.array`.
         return getattr(super().__getattribute__("array"), attr)
+
+    def __copy__(self) -> PandasExtensionArray[T_ExtensionArray]:
+        return PandasExtensionArray(copy.copy(self.array))
+
+    def __deepcopy__(
+        self, memo: dict[int, Any] | None = None
+    ) -> PandasExtensionArray[T_ExtensionArray]:
+        return PandasExtensionArray(copy.deepcopy(self.array, memo=memo))
diff --git a/xarray/tests/test_duck_array_ops.py b/xarray/tests/test_duck_array_ops.py
index 4052d414f63..eaafe2d4536 100644
--- a/xarray/tests/test_duck_array_ops.py
+++ b/xarray/tests/test_duck_array_ops.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import copy
 import datetime as dt
 import pickle
 import warnings
@@ -200,6 +201,16 @@ def test_extension_array_pyarrow_concatenate(self, arrow1, arrow2):
         assert concatenated[2].array[0]["x"] == 3
         assert concatenated[3].array[0]["y"]
 
+    @requires_pyarrow
+    def test_extension_array_copy_arrow_type(self):
+        arr = pd.array([pd.NA, 1, 2], dtype="int64[pyarrow]")
+        # Relying on the `__getattr__` of `PandasExtensionArray` to do the deep copy
+        # recursively only fails for `int64[pyarrow]` and similar types, so this
+        # test ensures that copying still works there.
+        assert isinstance(
+            copy.deepcopy(PandasExtensionArray(arr), memo=None).array, type(arr)
+        )
+
     def test___getitem__extension_duck_array(self, categorical1):
         extension_duck_array = PandasExtensionArray(categorical1)
         assert (extension_duck_array[0:2] == categorical1[0:2]).all()