Skip to content

Commit 05c7888

Browse files
authored
Array API fixes for astype (#7847)
* array API fixes for astype
* whatsnew
1 parent 97a2032 commit 05c7888

File tree

3 files changed

+49
-33
lines changed

3 files changed

+49
-33
lines changed

doc/whats-new.rst

+3
Original file line number | Diff line number | Diff line change
@@ -46,6 +46,9 @@ Documentation
4646
Internal Changes
4747
~~~~~~~~~~~~~~~~
4848

49+
- Minor improvements to support for the Python `array API standard <https://data-apis.org/array-api/latest/>`_,
50+
internally using the function ``xp.astype()`` instead of the method ``arr.astype()``, as the latter is not in the standard.
51+
(:pull:`7847`) By `Tom Nicholas <https://github.com/TomNicholas>`_.
4952

5053
.. _whats-new.2023.05.0:
5154

xarray/core/accessor_str.py

+43-30
Original file line number | Diff line number | Diff line change
@@ -51,6 +51,7 @@
5151

5252
import numpy as np
5353

54+
from xarray.core import duck_array_ops
5455
from xarray.core.computation import apply_ufunc
5556
from xarray.core.types import T_DataArray
5657

@@ -2085,13 +2086,16 @@ def _get_res_multi(val, pat):
20852086
else:
20862087
# dtype MUST be object or strings can be truncated
20872088
# See: https://github.com/numpy/numpy/issues/8352
2088-
return self._apply(
2089-
func=_get_res_multi,
2090-
func_args=(pat,),
2091-
dtype=np.object_,
2092-
output_core_dims=[[dim]],
2093-
output_sizes={dim: maxgroups},
2094-
).astype(self._obj.dtype.kind)
2089+
return duck_array_ops.astype(
2090+
self._apply(
2091+
func=_get_res_multi,
2092+
func_args=(pat,),
2093+
dtype=np.object_,
2094+
output_core_dims=[[dim]],
2095+
output_sizes={dim: maxgroups},
2096+
),
2097+
self._obj.dtype.kind,
2098+
)
20952099

20962100
def extractall(
20972101
self,
@@ -2258,15 +2262,18 @@ def _get_res(val, ipat, imaxcount=maxcount, dtype=self._obj.dtype):
22582262

22592263
return res
22602264

2261-
return self._apply(
2262-
# dtype MUST be object or strings can be truncated
2263-
# See: https://github.com/numpy/numpy/issues/8352
2264-
func=_get_res,
2265-
func_args=(pat,),
2266-
dtype=np.object_,
2267-
output_core_dims=[[group_dim, match_dim]],
2268-
output_sizes={group_dim: maxgroups, match_dim: maxcount},
2269-
).astype(self._obj.dtype.kind)
2265+
return duck_array_ops.astype(
2266+
self._apply(
2267+
# dtype MUST be object or strings can be truncated
2268+
# See: https://github.com/numpy/numpy/issues/8352
2269+
func=_get_res,
2270+
func_args=(pat,),
2271+
dtype=np.object_,
2272+
output_core_dims=[[group_dim, match_dim]],
2273+
output_sizes={group_dim: maxgroups, match_dim: maxcount},
2274+
),
2275+
self._obj.dtype.kind,
2276+
)
22702277

22712278
def findall(
22722279
self,
@@ -2385,13 +2392,16 @@ def _partitioner(
23852392

23862393
# dtype MUST be object or strings can be truncated
23872394
# See: https://github.com/numpy/numpy/issues/8352
2388-
return self._apply(
2389-
func=arrfunc,
2390-
func_args=(sep,),
2391-
dtype=np.object_,
2392-
output_core_dims=[[dim]],
2393-
output_sizes={dim: 3},
2394-
).astype(self._obj.dtype.kind)
2395+
return duck_array_ops.astype(
2396+
self._apply(
2397+
func=arrfunc,
2398+
func_args=(sep,),
2399+
dtype=np.object_,
2400+
output_core_dims=[[dim]],
2401+
output_sizes={dim: 3},
2402+
),
2403+
self._obj.dtype.kind,
2404+
)
23952405

23962406
def partition(
23972407
self,
@@ -2510,13 +2520,16 @@ def _dosplit(mystr, sep, maxsplit=maxsplit, dtype=self._obj.dtype):
25102520

25112521
# dtype MUST be object or strings can be truncated
25122522
# See: https://github.com/numpy/numpy/issues/8352
2513-
return self._apply(
2514-
func=_dosplit,
2515-
func_args=(sep,),
2516-
dtype=np.object_,
2517-
output_core_dims=[[dim]],
2518-
output_sizes={dim: maxsplit},
2519-
).astype(self._obj.dtype.kind)
2523+
return duck_array_ops.astype(
2524+
self._apply(
2525+
func=_dosplit,
2526+
func_args=(sep,),
2527+
dtype=np.object_,
2528+
output_core_dims=[[dim]],
2529+
output_sizes={dim: maxsplit},
2530+
),
2531+
self._obj.dtype.kind,
2532+
)
25202533

25212534
def split(
25222535
self,

xarray/core/variable.py

+3-3
Original file line number | Diff line number | Diff line change
@@ -1421,7 +1421,7 @@ def _shift_one_dim(self, dim, count, fill_value=dtypes.NA):
14211421
pads = [(0, 0) if d != dim else dim_pad for d in self.dims]
14221422

14231423
data = np.pad(
1424-
trimmed_data.astype(dtype),
1424+
duck_array_ops.astype(trimmed_data, dtype),
14251425
pads,
14261426
mode="constant",
14271427
constant_values=fill_value,
@@ -1570,7 +1570,7 @@ def pad(
15701570
pad_option_kwargs["reflect_type"] = reflect_type
15711571

15721572
array = np.pad(
1573-
self.data.astype(dtype, copy=False),
1573+
duck_array_ops.astype(self.data, dtype, copy=False),
15741574
pad_width_by_index,
15751575
mode=mode,
15761576
**pad_option_kwargs,
@@ -2438,7 +2438,7 @@ def rolling_window(
24382438
"""
24392439
if fill_value is dtypes.NA: # np.nan is passed
24402440
dtype, fill_value = dtypes.maybe_promote(self.dtype)
2441-
var = self.astype(dtype, copy=False)
2441+
var = duck_array_ops.astype(self, dtype, copy=False)
24422442
else:
24432443
dtype = self.dtype
24442444
var = self

0 commit comments

Comments
 (0)