Skip to content

Commit 218ea21

Browse files
rhkleijn and headtr1ck authored
improve typing of DataArray and Dataset reductions (#6746)
* improve typing of DataArray and Dataset reductions
* fix some mypy errors
* import Self from xarray.core.types
* fix typing errors
* fix remaining mypy error
* fix some typing issues
* add whats-new
* re-add type ignore

---------

Co-authored-by: Michael Niklas <[email protected]>
1 parent 8215911 commit 218ea21

File tree

6 files changed

+43
-40
lines changed

6 files changed

+43
-40
lines changed

doc/whats-new.rst

+4-1
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,10 @@ New Features
2626
different collections of coordinates prior to assign them to a Dataset or
2727
DataArray (:pull:`8102`) at once.
2828
By `Benoît Bovy <https://github.com/benbovy>`_.
29-
- Provide `preferred_chunks` for data read from netcdf files (:issue:`1440`, :pull:`7948`)
29+
- Provide `preferred_chunks` for data read from netcdf files (:issue:`1440`, :pull:`7948`).
30+
By `Martin Raspaud <https://github.com/mraspaud>`_.
31+
- Improved static typing of reduction methods (:pull:`6746`).
32+
By `Richard Kleijn <https://github.com/rhkleijn>`_.
3033

3134
Breaking changes
3235
~~~~~~~~~~~~~~~~

xarray/core/_aggregations.py

+29-29
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from xarray.core import duck_array_ops
1010
from xarray.core.options import OPTIONS
11-
from xarray.core.types import Dims
11+
from xarray.core.types import Dims, Self
1212
from xarray.core.utils import contains_only_chunked_or_numpy, module_available
1313

1414
if TYPE_CHECKING:
@@ -30,7 +30,7 @@ def reduce(
3030
keep_attrs: bool | None = None,
3131
keepdims: bool = False,
3232
**kwargs: Any,
33-
) -> Dataset:
33+
) -> Self:
3434
raise NotImplementedError()
3535

3636
def count(
@@ -39,7 +39,7 @@ def count(
3939
*,
4040
keep_attrs: bool | None = None,
4141
**kwargs: Any,
42-
) -> Dataset:
42+
) -> Self:
4343
"""
4444
Reduce this Dataset's data by applying ``count`` along some dimension(s).
4545
@@ -111,7 +111,7 @@ def all(
111111
*,
112112
keep_attrs: bool | None = None,
113113
**kwargs: Any,
114-
) -> Dataset:
114+
) -> Self:
115115
"""
116116
Reduce this Dataset's data by applying ``all`` along some dimension(s).
117117
@@ -183,7 +183,7 @@ def any(
183183
*,
184184
keep_attrs: bool | None = None,
185185
**kwargs: Any,
186-
) -> Dataset:
186+
) -> Self:
187187
"""
188188
Reduce this Dataset's data by applying ``any`` along some dimension(s).
189189
@@ -256,7 +256,7 @@ def max(
256256
skipna: bool | None = None,
257257
keep_attrs: bool | None = None,
258258
**kwargs: Any,
259-
) -> Dataset:
259+
) -> Self:
260260
"""
261261
Reduce this Dataset's data by applying ``max`` along some dimension(s).
262262
@@ -343,7 +343,7 @@ def min(
343343
skipna: bool | None = None,
344344
keep_attrs: bool | None = None,
345345
**kwargs: Any,
346-
) -> Dataset:
346+
) -> Self:
347347
"""
348348
Reduce this Dataset's data by applying ``min`` along some dimension(s).
349349
@@ -430,7 +430,7 @@ def mean(
430430
skipna: bool | None = None,
431431
keep_attrs: bool | None = None,
432432
**kwargs: Any,
433-
) -> Dataset:
433+
) -> Self:
434434
"""
435435
Reduce this Dataset's data by applying ``mean`` along some dimension(s).
436436
@@ -522,7 +522,7 @@ def prod(
522522
min_count: int | None = None,
523523
keep_attrs: bool | None = None,
524524
**kwargs: Any,
525-
) -> Dataset:
525+
) -> Self:
526526
"""
527527
Reduce this Dataset's data by applying ``prod`` along some dimension(s).
528528
@@ -629,7 +629,7 @@ def sum(
629629
min_count: int | None = None,
630630
keep_attrs: bool | None = None,
631631
**kwargs: Any,
632-
) -> Dataset:
632+
) -> Self:
633633
"""
634634
Reduce this Dataset's data by applying ``sum`` along some dimension(s).
635635
@@ -736,7 +736,7 @@ def std(
736736
ddof: int = 0,
737737
keep_attrs: bool | None = None,
738738
**kwargs: Any,
739-
) -> Dataset:
739+
) -> Self:
740740
"""
741741
Reduce this Dataset's data by applying ``std`` along some dimension(s).
742742
@@ -840,7 +840,7 @@ def var(
840840
ddof: int = 0,
841841
keep_attrs: bool | None = None,
842842
**kwargs: Any,
843-
) -> Dataset:
843+
) -> Self:
844844
"""
845845
Reduce this Dataset's data by applying ``var`` along some dimension(s).
846846
@@ -943,7 +943,7 @@ def median(
943943
skipna: bool | None = None,
944944
keep_attrs: bool | None = None,
945945
**kwargs: Any,
946-
) -> Dataset:
946+
) -> Self:
947947
"""
948948
Reduce this Dataset's data by applying ``median`` along some dimension(s).
949949
@@ -1034,7 +1034,7 @@ def cumsum(
10341034
skipna: bool | None = None,
10351035
keep_attrs: bool | None = None,
10361036
**kwargs: Any,
1037-
) -> Dataset:
1037+
) -> Self:
10381038
"""
10391039
Reduce this Dataset's data by applying ``cumsum`` along some dimension(s).
10401040
@@ -1127,7 +1127,7 @@ def cumprod(
11271127
skipna: bool | None = None,
11281128
keep_attrs: bool | None = None,
11291129
**kwargs: Any,
1130-
) -> Dataset:
1130+
) -> Self:
11311131
"""
11321132
Reduce this Dataset's data by applying ``cumprod`` along some dimension(s).
11331133
@@ -1226,7 +1226,7 @@ def reduce(
12261226
keep_attrs: bool | None = None,
12271227
keepdims: bool = False,
12281228
**kwargs: Any,
1229-
) -> DataArray:
1229+
) -> Self:
12301230
raise NotImplementedError()
12311231

12321232
def count(
@@ -1235,7 +1235,7 @@ def count(
12351235
*,
12361236
keep_attrs: bool | None = None,
12371237
**kwargs: Any,
1238-
) -> DataArray:
1238+
) -> Self:
12391239
"""
12401240
Reduce this DataArray's data by applying ``count`` along some dimension(s).
12411241
@@ -1301,7 +1301,7 @@ def all(
13011301
*,
13021302
keep_attrs: bool | None = None,
13031303
**kwargs: Any,
1304-
) -> DataArray:
1304+
) -> Self:
13051305
"""
13061306
Reduce this DataArray's data by applying ``all`` along some dimension(s).
13071307
@@ -1367,7 +1367,7 @@ def any(
13671367
*,
13681368
keep_attrs: bool | None = None,
13691369
**kwargs: Any,
1370-
) -> DataArray:
1370+
) -> Self:
13711371
"""
13721372
Reduce this DataArray's data by applying ``any`` along some dimension(s).
13731373
@@ -1434,7 +1434,7 @@ def max(
14341434
skipna: bool | None = None,
14351435
keep_attrs: bool | None = None,
14361436
**kwargs: Any,
1437-
) -> DataArray:
1437+
) -> Self:
14381438
"""
14391439
Reduce this DataArray's data by applying ``max`` along some dimension(s).
14401440
@@ -1513,7 +1513,7 @@ def min(
15131513
skipna: bool | None = None,
15141514
keep_attrs: bool | None = None,
15151515
**kwargs: Any,
1516-
) -> DataArray:
1516+
) -> Self:
15171517
"""
15181518
Reduce this DataArray's data by applying ``min`` along some dimension(s).
15191519
@@ -1592,7 +1592,7 @@ def mean(
15921592
skipna: bool | None = None,
15931593
keep_attrs: bool | None = None,
15941594
**kwargs: Any,
1595-
) -> DataArray:
1595+
) -> Self:
15961596
"""
15971597
Reduce this DataArray's data by applying ``mean`` along some dimension(s).
15981598
@@ -1676,7 +1676,7 @@ def prod(
16761676
min_count: int | None = None,
16771677
keep_attrs: bool | None = None,
16781678
**kwargs: Any,
1679-
) -> DataArray:
1679+
) -> Self:
16801680
"""
16811681
Reduce this DataArray's data by applying ``prod`` along some dimension(s).
16821682
@@ -1773,7 +1773,7 @@ def sum(
17731773
min_count: int | None = None,
17741774
keep_attrs: bool | None = None,
17751775
**kwargs: Any,
1776-
) -> DataArray:
1776+
) -> Self:
17771777
"""
17781778
Reduce this DataArray's data by applying ``sum`` along some dimension(s).
17791779
@@ -1870,7 +1870,7 @@ def std(
18701870
ddof: int = 0,
18711871
keep_attrs: bool | None = None,
18721872
**kwargs: Any,
1873-
) -> DataArray:
1873+
) -> Self:
18741874
"""
18751875
Reduce this DataArray's data by applying ``std`` along some dimension(s).
18761876
@@ -1964,7 +1964,7 @@ def var(
19641964
ddof: int = 0,
19651965
keep_attrs: bool | None = None,
19661966
**kwargs: Any,
1967-
) -> DataArray:
1967+
) -> Self:
19681968
"""
19691969
Reduce this DataArray's data by applying ``var`` along some dimension(s).
19701970
@@ -2057,7 +2057,7 @@ def median(
20572057
skipna: bool | None = None,
20582058
keep_attrs: bool | None = None,
20592059
**kwargs: Any,
2060-
) -> DataArray:
2060+
) -> Self:
20612061
"""
20622062
Reduce this DataArray's data by applying ``median`` along some dimension(s).
20632063
@@ -2140,7 +2140,7 @@ def cumsum(
21402140
skipna: bool | None = None,
21412141
keep_attrs: bool | None = None,
21422142
**kwargs: Any,
2143-
) -> DataArray:
2143+
) -> Self:
21442144
"""
21452145
Reduce this DataArray's data by applying ``cumsum`` along some dimension(s).
21462146
@@ -2229,7 +2229,7 @@ def cumprod(
22292229
skipna: bool | None = None,
22302230
keep_attrs: bool | None = None,
22312231
**kwargs: Any,
2232-
) -> DataArray:
2232+
) -> Self:
22332233
"""
22342234
Reduce this DataArray's data by applying ``cumprod`` along some dimension(s).
22352235

xarray/core/computation.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1402,14 +1402,14 @@ def _cov_corr(
14021402
) / (valid_count)
14031403

14041404
if method == "cov":
1405-
return cov # type: ignore[return-value]
1405+
return cov
14061406

14071407
else:
14081408
# compute std + corr
14091409
da_a_std = da_a.std(dim=dim)
14101410
da_b_std = da_b.std(dim=dim)
14111411
corr = cov / (da_a_std * da_b_std)
1412-
return corr # type: ignore[return-value]
1412+
return corr
14131413

14141414

14151415
def cross(

xarray/core/rolling.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -475,9 +475,8 @@ def reduce(
475475
obj, rolling_dim, keep_attrs=keep_attrs, fill_value=fillna
476476
)
477477

478-
result = windows.reduce(
479-
func, dim=list(rolling_dim.values()), keep_attrs=keep_attrs, **kwargs
480-
)
478+
dim = list(rolling_dim.values())
479+
result = windows.reduce(func, dim=dim, keep_attrs=keep_attrs, **kwargs)
481480

482481
# Find valid windows based on count.
483482
counts = self._counts(keep_attrs=False)
@@ -494,14 +493,15 @@ def _counts(self, keep_attrs: bool | None) -> DataArray:
494493
# array is faster to be reduced than object array.
495494
# The use of skipna==False is also faster since it does not need to
496495
# copy the strided array.
496+
dim = list(rolling_dim.values())
497497
counts = (
498498
self.obj.notnull(keep_attrs=keep_attrs)
499499
.rolling(
500500
{d: w for d, w in zip(self.dim, self.window)},
501501
center={d: self.center[i] for i, d in enumerate(self.dim)},
502502
)
503503
.construct(rolling_dim, fill_value=False, keep_attrs=keep_attrs)
504-
.sum(dim=list(rolling_dim.values()), skipna=False, keep_attrs=keep_attrs)
504+
.sum(dim=dim, skipna=False, keep_attrs=keep_attrs)
505505
)
506506
return counts
507507

xarray/tests/test_groupby.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -728,7 +728,7 @@ def test_groupby_dataset_iter() -> None:
728728
def test_groupby_dataset_errors() -> None:
729729
data = create_test_data()
730730
with pytest.raises(TypeError, match=r"`group` must be"):
731-
data.groupby(np.arange(10)) # type: ignore
731+
data.groupby(np.arange(10)) # type: ignore[arg-type,unused-ignore]
732732
with pytest.raises(ValueError, match=r"length does not match"):
733733
data.groupby(data["dim1"][:3])
734734
with pytest.raises(TypeError, match=r"`group` must be"):

xarray/util/generate_aggregations.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
2828
from xarray.core import duck_array_ops
2929
from xarray.core.options import OPTIONS
30-
from xarray.core.types import Dims
30+
from xarray.core.types import Dims, Self
3131
from xarray.core.utils import contains_only_chunked_or_numpy, module_available
3232
3333
if TYPE_CHECKING:
@@ -50,7 +50,7 @@ def reduce(
5050
keep_attrs: bool | None = None,
5151
keepdims: bool = False,
5252
**kwargs: Any,
53-
) -> {obj}:
53+
) -> Self:
5454
raise NotImplementedError()"""
5555

5656
GROUPBY_PREAMBLE = """
@@ -108,7 +108,7 @@ def {method}(
108108
*,{extra_kwargs}
109109
keep_attrs: bool | None = None,
110110
**kwargs: Any,
111-
) -> {obj}:
111+
) -> Self:
112112
"""
113113
Reduce this {obj}'s data by applying ``{method}`` along some dimension(s).
114114

0 commit comments

Comments
 (0)