diff --git a/properties/test_pandas_roundtrip.py b/properties/test_pandas_roundtrip.py index 8fc32e75cbd..f7f8e1beac0 100644 --- a/properties/test_pandas_roundtrip.py +++ b/properties/test_pandas_roundtrip.py @@ -134,10 +134,21 @@ def test_roundtrip_pandas_dataframe_datetime(df) -> None: xr.testing.assert_identical(dataset, roundtripped.to_xarray()) -def test_roundtrip_1d_pandas_extension_array() -> None: - df = pd.DataFrame({"cat": pd.Categorical(["a", "b", "c"])}) - arr = xr.Dataset.from_dataframe(df)["cat"] +@pytest.mark.parametrize( + "extension_array", + [ + pd.Categorical(["a", "b", "c"]), + pd.array([1, 2, 3], dtype="int64"), # NOTE(review): lowercase "int64" presumably yields a NumPy-backed array, not nullable "Int64"; confirm this case is intended + pd.array(["a", "b", "c"], dtype="string"), + pd.arrays.IntervalArray( + [pd.Interval(0, 1), pd.Interval(1, 5), pd.Interval(2, 6)] + ), + ], +) +def test_roundtrip_1d_pandas_extension_array(extension_array) -> None: + df = pd.DataFrame({"arr": extension_array}) + arr = xr.Dataset.from_dataframe(df)["arr"] roundtripped = arr.to_pandas() - assert (df["cat"] == roundtripped).all() - assert df["cat"].dtype == roundtripped.dtype + assert (df["arr"] == roundtripped).all() + assert df["arr"].dtype == roundtripped.dtype xr.testing.assert_identical(arr, roundtripped.to_xarray()) diff --git a/xarray/core/dtypes.py b/xarray/core/dtypes.py index c959a7f2536..132e6553412 100644 --- a/xarray/core/dtypes.py +++ b/xarray/core/dtypes.py @@ -4,6 +4,7 @@ from typing import Any import numpy as np +import pandas as pd from pandas.api.types import is_extension_array_dtype from xarray.compat import array_api_compat, npcompat @@ -63,7 +64,9 @@ def maybe_promote(dtype: np.dtype) -> tuple[np.dtype, Any]: # N.B.
these casting rules should match pandas dtype_: np.typing.DTypeLike fill_value: Any - if HAS_STRING_DTYPE and np.issubdtype(dtype, np.dtypes.StringDType()): + if pd.api.types.is_extension_array_dtype(dtype): + return dtype, pd.NA + elif HAS_STRING_DTYPE and np.issubdtype(dtype, np.dtypes.StringDType()): # for now, we always promote string dtypes to object for consistency with existing behavior # TODO: refactor this once we have a better way to handle numpy vlen-string dtypes dtype_ = object diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index 48dc3b7627a..9d7eaf055eb 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -273,18 +273,35 @@ def as_shared_dtype(scalars_or_arrays, xp=None): extension_array_types = [ x.dtype for x in scalars_or_arrays if is_extension_array_dtype(x) ] - non_nans = [x for x in scalars_or_arrays if not isna(x)] - if len(extension_array_types) == len(non_nans) and all( + non_nans_or_scalar = [ + x for x in scalars_or_arrays if not (isna(x) or np.isscalar(x)) + ] + if len(extension_array_types) == len(non_nans_or_scalar) and all( isinstance(x, type(extension_array_types[0])) for x in extension_array_types ): - return [ + # Get the extension array class of the first element, guaranteed to be the same + # as the others thanks to the above check.
+ extension_array_class = type( + non_nans_or_scalar[0].array + if isinstance(non_nans_or_scalar[0], PandasExtensionArray) + else non_nans_or_scalar[0] + ) + # Cast scalars/nans to extension array class + arrays_with_nan_to_sequence = [ x - if not isna(x) - else PandasExtensionArray( - type(non_nans[0].array)._from_sequence([x], dtype=non_nans[0].dtype) + if not (isna(x) or np.isscalar(x)) + else extension_array_class._from_sequence( + [x], dtype=non_nans_or_scalar[0].dtype ) for x in scalars_or_arrays ] + # Wrap the output if necessary + return [ + PandasExtensionArray(x) + if not isinstance(x, PandasExtensionArray) + else x + for x in arrays_with_nan_to_sequence + ] raise ValueError( f"Cannot cast values to shared type, found values: {scalars_or_arrays}" ) @@ -407,7 +424,6 @@ def where(condition, x, y): condition = asarray(condition, dtype=dtype, xp=xp) else: condition = astype(condition, dtype=dtype, xp=xp) - return xp.where(condition, *as_shared_dtype([x, y], xp=xp)) diff --git a/xarray/core/extension_array.py b/xarray/core/extension_array.py index f1631a5ea9e..584cf17c559 100644 --- a/xarray/core/extension_array.py +++ b/xarray/core/extension_array.py @@ -52,6 +52,17 @@ def __extension_duck_array__concatenate( return type(arrays[0])._concat_same_type(arrays) # type: ignore[attr-defined] +@implements(np.reshape) +def __extension_duck_array__reshape( + arr: T_ExtensionArray, shape: tuple +) -> T_ExtensionArray: + if (len(shape) == 1 and shape[0] == len(arr)) or shape == (-1,): # rank is checked first so an empty shape reaches the error below instead of an IndexError + return arr + raise NotImplementedError( + f"Cannot reshape 1d-only pandas extension array to: {shape}" + ) + + @implements(np.where) def __extension_duck_array__where( condition: np.ndarray, x: T_ExtensionArray, y: T_ExtensionArray @@ -101,16 +112,20 @@ def replace_duck_with_extension_array(args) -> list: args = tuple(replace_duck_with_extension_array(args)) if func not in HANDLED_EXTENSION_ARRAY_FUNCTIONS: - return func(*args, **kwargs) + raise KeyError("Function not registered for
pandas extension arrays.") res = HANDLED_EXTENSION_ARRAY_FUNCTIONS[func](*args, **kwargs) if is_extension_array_dtype(res): - return type(self)[type(res)](res) + return PandasExtensionArray(res) return res def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): return ufunc(*inputs, **kwargs) def __getitem__(self, key) -> PandasExtensionArray[T_ExtensionArray]: + if ( + isinstance(key, tuple) and len(key) == 1 + ): # pyarrow type arrays can't handle single-element tuples + key = key[0] item = self.array[key] if is_extension_array_dtype(item): return PandasExtensionArray(item) diff --git a/xarray/core/formatting.py b/xarray/core/formatting.py index 7aa333ffb2e..86fb147d382 100644 --- a/xarray/core/formatting.py +++ b/xarray/core/formatting.py @@ -19,6 +19,7 @@ from xarray.core.datatree_render import RenderDataTree from xarray.core.duck_array_ops import array_all, array_any, array_equiv, astype, ravel +from xarray.core.extension_array import PandasExtensionArray from xarray.core.indexing import MemoryCachedArray from xarray.core.options import OPTIONS, _get_boolean_with_default from xarray.core.treenode import group_subtrees @@ -176,6 +177,8 @@ def format_timedelta(t, timedelta_format=None): def format_item(x, timedelta_format=None, quote_strings=True): """Returns a succinct summary of an object as a string""" + if isinstance(x, PandasExtensionArray): + return f"{x.array[0]}" if isinstance(x, np.datetime64 | datetime): return format_timestamp(x) if isinstance(x, np.timedelta64 | timedelta): @@ -194,7 +197,9 @@ def format_items(x): """Returns a succinct summaries of all items in a sequence as strings""" x = to_duck_array(x) timedelta_format = "datetime" - if np.issubdtype(x.dtype, np.timedelta64): + if not isinstance(x, PandasExtensionArray) and np.issubdtype( + x.dtype, np.timedelta64 + ): x = astype(x, dtype="timedelta64[ns]") day_part = x[~pd.isnull(x)].astype("timedelta64[D]").astype("timedelta64[ns]") time_needed = x[~pd.isnull(x)] != day_part diff --git
a/xarray/tests/__init__.py b/xarray/tests/__init__.py index b33192393f7..fe76df75fa0 100644 --- a/xarray/tests/__init__.py +++ b/xarray/tests/__init__.py @@ -363,6 +363,14 @@ def create_test_data( ) ), ) + if has_pyarrow: # var5 is an int64[pyarrow] variable, added only when has_pyarrow is true + obj["var5"] = ( + "dim1", + pd.array( + rs.integers(1, 10, size=dim_sizes[0]).tolist(), + dtype="int64[pyarrow]", + ), + ) if dim_sizes == _DEFAULT_TEST_DIM_SIZES: numbers_values = np.array([0, 1, 2, 0, 0, 1, 1, 2, 2, 3], dtype="int64") else: diff --git a/xarray/tests/test_concat.py b/xarray/tests/test_concat.py index 49c6490d819..ed5aac4fe99 100644 --- a/xarray/tests/test_concat.py +++ b/xarray/tests/test_concat.py @@ -21,6 +21,7 @@ assert_equal, assert_identical, requires_dask, + requires_pyarrow, ) from xarray.tests.test_dataset import create_test_data @@ -154,19 +155,20 @@ def test_concat_missing_var() -> None: assert_identical(actual, expected) -def test_concat_categorical() -> None: +@pytest.mark.parametrize("var", ["var4", pytest.param("var5", marks=requires_pyarrow)]) +def test_concat_extension_array(var) -> None: data1 = create_test_data(use_extension_array=True) data2 = create_test_data(use_extension_array=True) concatenated = concat([data1, data2], dim="dim1") - assert ( - concatenated["var4"] - == type(data2["var4"].variable.data)._concat_same_type( + assert pd.Series( + concatenated[var] + == type(data2[var].variable.data)._concat_same_type( [ - data1["var4"].variable.data, - data2["var4"].variable.data, + data1[var].variable.data, + data2[var].variable.data, ] ) - ).all() + ).all() # need to wrap in series because pyarrow bool does not support `all` def test_concat_missing_multiple_consecutive_var() -> None: diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index e7acdcdd4f3..da93a028349 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -53,6 +53,7 @@ assert_no_warnings, has_dask, has_dask_ge_2025_1_0, + has_pyarrow, raise_if_dask_computes, requires_bottleneck,
requires_cupy, @@ -61,6 +62,7 @@ requires_iris, requires_numexpr, requires_pint, + requires_pyarrow, requires_scipy, requires_sparse, source_ndarray, @@ -3075,6 +3077,58 @@ def test_propagate_attrs(self, func) -> None: with set_options(keep_attrs=True): assert func(da).attrs == da.attrs + @pytest.mark.parametrize( + "fill_value,extension_array", + [ + pytest.param("a", pd.Categorical([pd.NA, "a", "b"]), id="categorical"), + ] + + ([ + pytest.param( + 0, + pd.array([pd.NA, 1, 1], dtype="int64[pyarrow]"), + id="int64[pyarrow]", + ) + ] + if has_pyarrow + else []), # parentheses keep the categorical case when pyarrow is missing + ) + def test_fillna_extension_array(self, fill_value, extension_array) -> None: + srs: pd.Series = pd.Series(index=np.array([1, 2, 3]), data=extension_array) + da = srs.to_xarray() + filled = da.fillna(fill_value) + assert filled.dtype == srs.dtype + assert (filled.values == np.array([fill_value, *(srs.values[1:])])).all() + + @requires_pyarrow + def test_fillna_extension_array_bad_val(self) -> None: + srs: pd.Series = pd.Series( + index=np.array([1, 2, 3]), + data=pd.array([pd.NA, 1, 1], dtype="int64[pyarrow]"), + ) + da = srs.to_xarray() + with pytest.raises(ValueError): + da.fillna("a") + + @pytest.mark.parametrize( + "extension_array", + [ + pytest.param(pd.Categorical([pd.NA, "a", "b"]), id="categorical"), + ] + + ([ + pytest.param( + pd.array([pd.NA, 1, 1], dtype="int64[pyarrow]"), id="int64[pyarrow]" + ) + ] + if has_pyarrow + else []), # parentheses keep the categorical case when pyarrow is missing + ) + def test_dropna_extension_array(self, extension_array) -> None: + srs: pd.Series = pd.Series(index=np.array([1, 2, 3]), data=extension_array) + da = srs.to_xarray() + dropped = da.dropna("index") + assert dropped.dtype == srs.dtype + assert (dropped.values == srs.values[1:]).all() + def test_fillna(self) -> None: a = DataArray([np.nan, 1, np.nan, 3], coords={"x": range(4)}, dims="x") actual = a.fillna(-1) @@ -3637,7 +3691,7 @@ def test_series_categorical_index(self) -> None: s = pd.Series(np.arange(5), index=pd.CategoricalIndex(list("aabbc"))) arr = DataArray(s) -
assert "'a'" in repr(arr) # should not error + assert "a a b b c" in repr(arr) # should not error @pytest.mark.parametrize("use_dask", [True, False]) @pytest.mark.parametrize("data", ["list", "array", True]) diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index aba7760bf4d..bd6b4478d00 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -58,6 +58,7 @@ create_test_data, has_cftime, has_dask, + has_pyarrow, raise_if_dask_computes, requires_bottleneck, requires_cftime, @@ -65,6 +66,7 @@ requires_dask, requires_numexpr, requires_pint, + requires_pyarrow, requires_scipy, requires_sparse, source_ndarray, @@ -297,12 +299,15 @@ def test_repr(self) -> None: var1 (dim1, dim2) float64 576B -0.9891 -0.3678 1.288 ... -0.2116 0.364 var2 (dim1, dim2) float64 576B 0.953 1.52 1.704 ... 0.1347 -0.6423 var3 (dim3, dim1) float64 640B 0.4107 0.9941 0.1665 ... 0.716 1.555 - var4 (dim1) category 32B 'b' 'c' 'b' 'a' 'c' 'a' 'c' 'a' + var4 (dim1) category 32B b c b a c a c a{} Attributes: - foo: bar""".format( - data["dim3"].dtype, - "ns", - ) + foo: bar""" + ).format( + data["dim3"].dtype, + "ns", + "\n var5 (dim1) int64[pyarrow] 64B 5 9 7 2 6 2 8 1" + if has_pyarrow + else "", ) actual = "\n".join(x.rstrip() for x in repr(data).split("\n")) @@ -1827,28 +1832,48 @@ def test_categorical_index_reindex(self) -> None: actual = ds.reindex(cat=["foo"])["cat"].values assert (actual == np.array(["foo"])).all() - @pytest.mark.parametrize("fill_value", [np.nan, pd.NA]) - def test_extensionarray_negative_reindex(self, fill_value) -> None: - cat = pd.Categorical( - ["foo", "bar", "baz"], - categories=["foo", "bar", "baz", "qux", "quux", "corge"], - ) + @pytest.mark.parametrize("fill_value", [np.nan, pd.NA, None]) + @pytest.mark.parametrize( + "extension_array", + [ + pytest.param( + pd.Categorical( + ["foo", "bar", "baz"], + categories=["foo", "bar", "baz", "qux"], + ), + id="categorical", + ), + ] + + ([ + pytest.param( + pd.array([1, 1, None],
dtype="int64[pyarrow]"), id="int64[pyarrow]" + ) + ] + if has_pyarrow + else []), # parentheses keep the categorical case when pyarrow is missing + ) + def test_extensionarray_negative_reindex(self, fill_value, extension_array) -> None: ds = xr.Dataset( - {"cat": ("index", cat)}, + {"arr": ("index", extension_array)}, coords={"index": ("index", np.arange(3))}, ) + kwargs = {} + if fill_value is not None: + kwargs["fill_value"] = fill_value reindexed_cat = cast( pd.api.extensions.ExtensionArray, - ( - ds.reindex(index=[-1, 1, 1], fill_value=fill_value)["cat"] - .to_pandas() - .values - ), + (ds.reindex(index=[-1, 1, 1], **kwargs)["arr"].to_pandas().values), + ) + assert reindexed_cat.equals( # type: ignore[attr-defined] + pd.array( + [pd.NA, extension_array[1], extension_array[1]], + dtype=extension_array.dtype, + ) ) - assert reindexed_cat.equals(pd.array([pd.NA, "bar", "bar"], dtype=cat.dtype)) # type: ignore[attr-defined] + @requires_pyarrow def test_extension_array_reindex_same(self) -> None: - series = pd.Series([1, 2, pd.NA, 3], dtype=pd.Int32Dtype()) + series = pd.Series([1, 2, pd.NA, 3], dtype="int32[pyarrow]") test = xr.Dataset({"test": series}) res = test.reindex(dim_0=series.index) align(res, test, join="exact") @@ -5531,6 +5556,51 @@ def test_dropna(self) -> None: with pytest.raises(TypeError, match=r"must specify how or thresh"): ds.dropna("a", how=None) # type: ignore[arg-type] + @pytest.mark.parametrize( + "fill_value,extension_array", + [ + pytest.param("a", pd.Categorical([pd.NA, "a", "b"]), id="category"), + ] + + ([ + pytest.param( + 0, + pd.array([pd.NA, 1, 1], dtype="int64[pyarrow]"), + id="int64[pyarrow]", + ) + ] + if has_pyarrow + else []), # parentheses keep the categorical case when pyarrow is missing + ) + def test_fillna_extension_array(self, fill_value, extension_array) -> None: + srs = pd.DataFrame({"data": extension_array}, index=np.array([1, 2, 3])) + ds = srs.to_xarray() + filled = ds.fillna(fill_value) + assert filled["data"].dtype == extension_array.dtype + assert ( + filled["data"].values + == np.array([fill_value, *srs["data"].values[1:]], dtype="object") +
).all() + + @pytest.mark.parametrize( + "extension_array", + [ + pytest.param(pd.Categorical([pd.NA, "a", "b"]), id="category"), + ] + + ([ + pytest.param( + pd.array([pd.NA, 1, 1], dtype="int64[pyarrow]"), id="int64[pyarrow]" + ) + ] + if has_pyarrow + else []), # parentheses keep the categorical case when pyarrow is missing + ) + def test_dropna_extension_array(self, extension_array) -> None: + srs = pd.DataFrame({"data": extension_array}, index=np.array([1, 2, 3])) + ds = srs.to_xarray() + dropped = ds.dropna("index") + assert dropped["data"].dtype == extension_array.dtype + assert (dropped["data"].values == srs["data"].values[1:]).all() + def test_fillna(self) -> None: ds = Dataset({"a": ("x", [np.nan, 1, np.nan, 3])}, {"x": [0, 1, 2, 3]})