Skip to content

Commit dd9a46b

Browse files
authored
(fix): allow upcasting of nans in as_shared_dtype for extension arrays (#10292)
1 parent 7d00e0a commit dd9a46b

File tree

2 files changed

+49
-5
lines changed

2 files changed

+49
-5
lines changed

xarray/core/duck_array_ops.py

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from collections.abc import Callable
1414
from functools import partial
1515
from importlib import import_module
16+
from typing import Any
1617

1718
import numpy as np
1819
import pandas as pd
@@ -27,6 +28,7 @@
2728
from xarray.compat import dask_array_compat, dask_array_ops
2829
from xarray.compat.array_api_compat import get_array_namespace
2930
from xarray.core import dtypes, nputils
31+
from xarray.core.extension_array import PandasExtensionArray
3032
from xarray.core.options import OPTIONS
3133
from xarray.core.utils import is_duck_array, is_duck_dask_array, module_available
3234
from xarray.namedarray.parallelcompat import get_chunked_array_type
@@ -143,6 +145,21 @@ def round(array):
143145
around: Callable = round
144146

145147

148+
def isna(data: Any) -> bool:
149+
"""Checks if data is literally np.nan or pd.NA.
150+
151+
Parameters
152+
----------
153+
data
154+
Any python object
155+
156+
Returns
157+
-------
158+
Whether or not the data is np.nan or pd.NA
159+
"""
160+
return data is pd.NA or data is np.nan
161+
162+
146163
def isnull(data):
147164
data = asarray(data)
148165

@@ -256,13 +273,20 @@ def as_shared_dtype(scalars_or_arrays, xp=None):
256273
extension_array_types = [
257274
x.dtype for x in scalars_or_arrays if is_extension_array_dtype(x)
258275
]
259-
if len(extension_array_types) == len(scalars_or_arrays) and all(
276+
non_nans = [x for x in scalars_or_arrays if not isna(x)]
277+
if len(extension_array_types) == len(non_nans) and all(
260278
isinstance(x, type(extension_array_types[0])) for x in extension_array_types
261279
):
262-
return scalars_or_arrays
280+
return [
281+
x
282+
if not isna(x)
283+
else PandasExtensionArray(
284+
type(non_nans[0].array)._from_sequence([x], dtype=non_nans[0].dtype)
285+
)
286+
for x in scalars_or_arrays
287+
]
263288
raise ValueError(
264-
"Cannot cast arrays to shared type, found"
265-
f" array types {[x.dtype for x in scalars_or_arrays]}"
289+
f"Cannot cast values to shared type, found values: {scalars_or_arrays}"
266290
)
267291

268292
# Avoid calling array_type("cupy") repeatidely in the any check

xarray/tests/test_dataset.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from copy import copy, deepcopy
99
from io import StringIO
1010
from textwrap import dedent
11-
from typing import Any, Literal
11+
from typing import Any, Literal, cast
1212

1313
import numpy as np
1414
import pandas as pd
@@ -1827,6 +1827,26 @@ def test_categorical_index_reindex(self) -> None:
18271827
actual = ds.reindex(cat=["foo"])["cat"].values
18281828
assert (actual == np.array(["foo"])).all()
18291829

1830+
@pytest.mark.parametrize("fill_value", [np.nan, pd.NA])
1831+
def test_extensionarray_negative_reindex(self, fill_value) -> None:
1832+
cat = pd.Categorical(
1833+
["foo", "bar", "baz"],
1834+
categories=["foo", "bar", "baz", "qux", "quux", "corge"],
1835+
)
1836+
ds = xr.Dataset(
1837+
{"cat": ("index", cat)},
1838+
coords={"index": ("index", np.arange(3))},
1839+
)
1840+
reindexed_cat = cast(
1841+
pd.api.extensions.ExtensionArray,
1842+
(
1843+
ds.reindex(index=[-1, 1, 1], fill_value=fill_value)["cat"]
1844+
.to_pandas()
1845+
.values
1846+
),
1847+
)
1848+
assert reindexed_cat.equals(pd.array([pd.NA, "bar", "bar"], dtype=cat.dtype)) # type: ignore[attr-defined]
1849+
18301850
def test_extension_array_reindex_same(self) -> None:
18311851
series = pd.Series([1, 2, pd.NA, 3], dtype=pd.Int32Dtype())
18321852
test = xr.Dataset({"test": series})

0 commit comments

Comments
 (0)