|
13 | 13 | from collections.abc import Callable
|
14 | 14 | from functools import partial
|
15 | 15 | from importlib import import_module
|
| 16 | +from typing import Any |
16 | 17 |
|
17 | 18 | import numpy as np
|
18 | 19 | import pandas as pd
|
|
27 | 28 | from xarray.compat import dask_array_compat, dask_array_ops
|
28 | 29 | from xarray.compat.array_api_compat import get_array_namespace
|
29 | 30 | from xarray.core import dtypes, nputils
|
| 31 | +from xarray.core.extension_array import PandasExtensionArray |
30 | 32 | from xarray.core.options import OPTIONS
|
31 | 33 | from xarray.core.utils import is_duck_array, is_duck_dask_array, module_available
|
32 | 34 | from xarray.namedarray.parallelcompat import get_chunked_array_type
|
@@ -143,6 +145,21 @@ def round(array):
|
143 | 145 | around: Callable = round
|
144 | 146 |
|
145 | 147 |
|
| 148 | +def isna(data: Any) -> bool: |
| 149 | + """Checks if data is literally np.nan or pd.NA. |
| 150 | +
|
| 151 | + Parameters |
| 152 | + ---------- |
| 153 | + data |
| 154 | + Any python object |
| 155 | +
|
| 156 | + Returns |
| 157 | + ------- |
| 158 | + Whether or not the data is np.nan or pd.NA |
| 159 | + """ |
| 160 | + return data is pd.NA or data is np.nan |
| 161 | + |
| 162 | + |
146 | 163 | def isnull(data):
|
147 | 164 | data = asarray(data)
|
148 | 165 |
|
@@ -256,13 +273,20 @@ def as_shared_dtype(scalars_or_arrays, xp=None):
|
256 | 273 | extension_array_types = [
|
257 | 274 | x.dtype for x in scalars_or_arrays if is_extension_array_dtype(x)
|
258 | 275 | ]
|
259 |
| - if len(extension_array_types) == len(scalars_or_arrays) and all( |
| 276 | + non_nans = [x for x in scalars_or_arrays if not isna(x)] |
| 277 | + if len(extension_array_types) == len(non_nans) and all( |
260 | 278 | isinstance(x, type(extension_array_types[0])) for x in extension_array_types
|
261 | 279 | ):
|
262 |
| - return scalars_or_arrays |
| 280 | + return [ |
| 281 | + x |
| 282 | + if not isna(x) |
| 283 | + else PandasExtensionArray( |
| 284 | + type(non_nans[0].array)._from_sequence([x], dtype=non_nans[0].dtype) |
| 285 | + ) |
| 286 | + for x in scalars_or_arrays |
| 287 | + ] |
263 | 288 | raise ValueError(
|
264 |
| - "Cannot cast arrays to shared type, found" |
265 |
| - f" array types {[x.dtype for x in scalars_or_arrays]}" |
| 289 | + f"Cannot cast values to shared type, found values: {scalars_or_arrays}" |
266 | 290 | )
|
267 | 291 |
|
268 | 292 | # Avoid calling array_type("cupy") repeatidely in the any check
|
|
0 commit comments