Skip to content

Commit 97ec434

Browse files
authored
Reimplement Datatree typed ops (#9619)
* test unary op
* implement and generate unary ops
* test for unary op with inherited coordinates
* re-enable arithmetic tests
* implementation for binary ops
* test ds * dt commutativity
* ensure other types defer to DataTree, thus fixing #9365
* test for inplace binary op
* pseudocode implementation of inplace binary op, and xfail test
* remove some unneeded type: ignore comments
* return type should be DataTree
* type datatree ops as accepting dataset-compatible types too
* use same type hinting hack as Dataset does for __eq__ not being same as Mapping
* ignore return type
* add some methods to api docs
* don't try to import DataTree.astype in API docs
* test to check that single-node trees aren't broadcast
* return NotImplemented
* remove pseudocode for inplace binary ops
* map_over_subtree -> map_over_datasets
1 parent 33ead65 commit 97ec434

9 files changed

+316
-22
lines changed

doc/api.rst

+9-9
Original file line numberDiff line numberDiff line change
@@ -849,20 +849,20 @@ Aggregate data in all nodes in the subtree simultaneously.
849849
DataTree.cumsum
850850
DataTree.cumprod
851851

852-
.. ndarray methods
853-
.. ---------------
852+
ndarray methods
853+
---------------
854854

855-
.. Methods copied from :py:class:`numpy.ndarray` objects, here applying to the data in all nodes in the subtree.
855+
Methods copied from :py:class:`numpy.ndarray` objects, here applying to the data in all nodes in the subtree.
856856

857-
.. .. autosummary::
858-
.. :toctree: generated/
857+
.. autosummary::
858+
:toctree: generated/
859859

860-
.. DataTree.argsort
860+
DataTree.argsort
861+
DataTree.conj
862+
DataTree.conjugate
863+
DataTree.round
861864
.. DataTree.astype
862865
.. DataTree.clip
863-
.. DataTree.conj
864-
.. DataTree.conjugate
865-
.. DataTree.round
866866
.. DataTree.rank
867867
868868
.. Reshaping and reorganising

xarray/core/_typed_ops.py

+162
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from xarray.core.types import (
1313
DaCompatible,
1414
DsCompatible,
15+
DtCompatible,
1516
Self,
1617
T_Xarray,
1718
VarCompatible,
@@ -23,6 +24,167 @@
2324
from xarray.core.types import T_DataArray as T_DA
2425

2526

27+
class DataTreeOpsMixin:
28+
__slots__ = ()
29+
30+
def _binary_op(
31+
self, other: DtCompatible, f: Callable, reflexive: bool = False
32+
) -> Self:
33+
raise NotImplementedError
34+
35+
def __add__(self, other: DtCompatible) -> Self:
36+
return self._binary_op(other, operator.add)
37+
38+
def __sub__(self, other: DtCompatible) -> Self:
39+
return self._binary_op(other, operator.sub)
40+
41+
def __mul__(self, other: DtCompatible) -> Self:
42+
return self._binary_op(other, operator.mul)
43+
44+
def __pow__(self, other: DtCompatible) -> Self:
45+
return self._binary_op(other, operator.pow)
46+
47+
def __truediv__(self, other: DtCompatible) -> Self:
48+
return self._binary_op(other, operator.truediv)
49+
50+
def __floordiv__(self, other: DtCompatible) -> Self:
51+
return self._binary_op(other, operator.floordiv)
52+
53+
def __mod__(self, other: DtCompatible) -> Self:
54+
return self._binary_op(other, operator.mod)
55+
56+
def __and__(self, other: DtCompatible) -> Self:
57+
return self._binary_op(other, operator.and_)
58+
59+
def __xor__(self, other: DtCompatible) -> Self:
60+
return self._binary_op(other, operator.xor)
61+
62+
def __or__(self, other: DtCompatible) -> Self:
63+
return self._binary_op(other, operator.or_)
64+
65+
def __lshift__(self, other: DtCompatible) -> Self:
66+
return self._binary_op(other, operator.lshift)
67+
68+
def __rshift__(self, other: DtCompatible) -> Self:
69+
return self._binary_op(other, operator.rshift)
70+
71+
def __lt__(self, other: DtCompatible) -> Self:
72+
return self._binary_op(other, operator.lt)
73+
74+
def __le__(self, other: DtCompatible) -> Self:
75+
return self._binary_op(other, operator.le)
76+
77+
def __gt__(self, other: DtCompatible) -> Self:
78+
return self._binary_op(other, operator.gt)
79+
80+
def __ge__(self, other: DtCompatible) -> Self:
81+
return self._binary_op(other, operator.ge)
82+
83+
def __eq__(self, other: DtCompatible) -> Self: # type:ignore[override]
84+
return self._binary_op(other, nputils.array_eq)
85+
86+
def __ne__(self, other: DtCompatible) -> Self: # type:ignore[override]
87+
return self._binary_op(other, nputils.array_ne)
88+
89+
# When __eq__ is defined but __hash__ is not, then an object is unhashable,
90+
# and it should be declared as follows:
91+
__hash__: None # type:ignore[assignment]
92+
93+
def __radd__(self, other: DtCompatible) -> Self:
94+
return self._binary_op(other, operator.add, reflexive=True)
95+
96+
def __rsub__(self, other: DtCompatible) -> Self:
97+
return self._binary_op(other, operator.sub, reflexive=True)
98+
99+
def __rmul__(self, other: DtCompatible) -> Self:
100+
return self._binary_op(other, operator.mul, reflexive=True)
101+
102+
def __rpow__(self, other: DtCompatible) -> Self:
103+
return self._binary_op(other, operator.pow, reflexive=True)
104+
105+
def __rtruediv__(self, other: DtCompatible) -> Self:
106+
return self._binary_op(other, operator.truediv, reflexive=True)
107+
108+
def __rfloordiv__(self, other: DtCompatible) -> Self:
109+
return self._binary_op(other, operator.floordiv, reflexive=True)
110+
111+
def __rmod__(self, other: DtCompatible) -> Self:
112+
return self._binary_op(other, operator.mod, reflexive=True)
113+
114+
def __rand__(self, other: DtCompatible) -> Self:
115+
return self._binary_op(other, operator.and_, reflexive=True)
116+
117+
def __rxor__(self, other: DtCompatible) -> Self:
118+
return self._binary_op(other, operator.xor, reflexive=True)
119+
120+
def __ror__(self, other: DtCompatible) -> Self:
121+
return self._binary_op(other, operator.or_, reflexive=True)
122+
123+
def _unary_op(self, f: Callable, *args: Any, **kwargs: Any) -> Self:
124+
raise NotImplementedError
125+
126+
def __neg__(self) -> Self:
127+
return self._unary_op(operator.neg)
128+
129+
def __pos__(self) -> Self:
130+
return self._unary_op(operator.pos)
131+
132+
def __abs__(self) -> Self:
133+
return self._unary_op(operator.abs)
134+
135+
def __invert__(self) -> Self:
136+
return self._unary_op(operator.invert)
137+
138+
def round(self, *args: Any, **kwargs: Any) -> Self:
139+
return self._unary_op(ops.round_, *args, **kwargs)
140+
141+
def argsort(self, *args: Any, **kwargs: Any) -> Self:
142+
return self._unary_op(ops.argsort, *args, **kwargs)
143+
144+
def conj(self, *args: Any, **kwargs: Any) -> Self:
145+
return self._unary_op(ops.conj, *args, **kwargs)
146+
147+
def conjugate(self, *args: Any, **kwargs: Any) -> Self:
148+
return self._unary_op(ops.conjugate, *args, **kwargs)
149+
150+
__add__.__doc__ = operator.add.__doc__
151+
__sub__.__doc__ = operator.sub.__doc__
152+
__mul__.__doc__ = operator.mul.__doc__
153+
__pow__.__doc__ = operator.pow.__doc__
154+
__truediv__.__doc__ = operator.truediv.__doc__
155+
__floordiv__.__doc__ = operator.floordiv.__doc__
156+
__mod__.__doc__ = operator.mod.__doc__
157+
__and__.__doc__ = operator.and_.__doc__
158+
__xor__.__doc__ = operator.xor.__doc__
159+
__or__.__doc__ = operator.or_.__doc__
160+
__lshift__.__doc__ = operator.lshift.__doc__
161+
__rshift__.__doc__ = operator.rshift.__doc__
162+
__lt__.__doc__ = operator.lt.__doc__
163+
__le__.__doc__ = operator.le.__doc__
164+
__gt__.__doc__ = operator.gt.__doc__
165+
__ge__.__doc__ = operator.ge.__doc__
166+
__eq__.__doc__ = nputils.array_eq.__doc__
167+
__ne__.__doc__ = nputils.array_ne.__doc__
168+
__radd__.__doc__ = operator.add.__doc__
169+
__rsub__.__doc__ = operator.sub.__doc__
170+
__rmul__.__doc__ = operator.mul.__doc__
171+
__rpow__.__doc__ = operator.pow.__doc__
172+
__rtruediv__.__doc__ = operator.truediv.__doc__
173+
__rfloordiv__.__doc__ = operator.floordiv.__doc__
174+
__rmod__.__doc__ = operator.mod.__doc__
175+
__rand__.__doc__ = operator.and_.__doc__
176+
__rxor__.__doc__ = operator.xor.__doc__
177+
__ror__.__doc__ = operator.or_.__doc__
178+
__neg__.__doc__ = operator.neg.__doc__
179+
__pos__.__doc__ = operator.pos.__doc__
180+
__abs__.__doc__ = operator.abs.__doc__
181+
__invert__.__doc__ = operator.invert.__doc__
182+
round.__doc__ = ops.round_.__doc__
183+
argsort.__doc__ = ops.argsort.__doc__
184+
conj.__doc__ = ops.conj.__doc__
185+
conjugate.__doc__ = ops.conjugate.__doc__
186+
187+
26188
class DatasetOpsMixin:
27189
__slots__ = ()
28190

xarray/core/dataarray.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -4765,9 +4765,10 @@ def _unary_op(self, f: Callable, *args, **kwargs) -> Self:
47654765
def _binary_op(
47664766
self, other: DaCompatible, f: Callable, reflexive: bool = False
47674767
) -> Self:
4768+
from xarray.core.datatree import DataTree
47684769
from xarray.core.groupby import GroupBy
47694770

4770-
if isinstance(other, Dataset | GroupBy):
4771+
if isinstance(other, DataTree | Dataset | GroupBy):
47714772
return NotImplemented
47724773
if isinstance(other, DataArray):
47734774
align_type = OPTIONS["arithmetic_join"]

xarray/core/dataset.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -7784,9 +7784,10 @@ def _unary_op(self, f, *args, **kwargs) -> Self:
77847784

77857785
def _binary_op(self, other, f, reflexive=False, join=None) -> Dataset:
77867786
from xarray.core.dataarray import DataArray
7787+
from xarray.core.datatree import DataTree
77877788
from xarray.core.groupby import GroupBy
77887789

7789-
if isinstance(other, GroupBy):
7790+
if isinstance(other, DataTree | GroupBy):
77907791
return NotImplemented
77917792
align_type = OPTIONS["arithmetic_join"] if join is None else join
77927793
if isinstance(other, DataArray | Dataset):

xarray/core/datatree.py

+40
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import functools
34
import itertools
45
import textwrap
56
from collections import ChainMap
@@ -15,6 +16,7 @@
1516

1617
from xarray.core import utils
1718
from xarray.core._aggregations import DataTreeAggregations
19+
from xarray.core._typed_ops import DataTreeOpsMixin
1820
from xarray.core.alignment import align
1921
from xarray.core.common import TreeAttrAccessMixin
2022
from xarray.core.coordinates import Coordinates, DataTreeCoordinates
@@ -60,6 +62,7 @@
6062
from xarray.core.merge import CoercibleMapping, CoercibleValue
6163
from xarray.core.types import (
6264
Dims,
65+
DtCompatible,
6366
ErrorOptions,
6467
ErrorOptionsWithWarn,
6568
NetcdfWriteModes,
@@ -403,6 +406,7 @@ def map( # type: ignore[override]
403406
class DataTree(
404407
NamedNode["DataTree"],
405408
DataTreeAggregations,
409+
DataTreeOpsMixin,
406410
TreeAttrAccessMixin,
407411
Mapping[str, "DataArray | DataTree"],
408412
):
@@ -1486,6 +1490,42 @@ def groups(self):
14861490
"""Return all groups in the tree, given as a tuple of path-like strings."""
14871491
return tuple(node.path for node in self.subtree)
14881492

1493+
def _unary_op(self, f, *args, **kwargs) -> DataTree:
1494+
# TODO do we need to do any additional work to avoid duplication etc.? (Similar to aggregations)
1495+
return self.map_over_datasets(f, *args, **kwargs) # type: ignore[return-value]
1496+
1497+
def _binary_op(self, other, f, reflexive=False, join=None) -> DataTree:
1498+
from xarray.core.dataset import Dataset
1499+
from xarray.core.groupby import GroupBy
1500+
1501+
if isinstance(other, GroupBy):
1502+
return NotImplemented
1503+
1504+
ds_binop = functools.partial(
1505+
Dataset._binary_op,
1506+
f=f,
1507+
reflexive=reflexive,
1508+
join=join,
1509+
)
1510+
return map_over_datasets(ds_binop)(self, other)
1511+
1512+
def _inplace_binary_op(self, other, f) -> Self:
1513+
from xarray.core.groupby import GroupBy
1514+
1515+
if isinstance(other, GroupBy):
1516+
raise TypeError(
1517+
"in-place operations between a DataTree and "
1518+
"a grouped object are not permitted"
1519+
)
1520+
1521+
# TODO see GH issue #9629 for required implementation
1522+
raise NotImplementedError()
1523+
1524+
# TODO: dirty workaround for mypy 1.5 error with inherited DataTreeOpsMixin vs. Mapping
1525+
# related to https://github.com/python/mypy/issues/9319?
1526+
def __eq__(self, other: DtCompatible) -> Self: # type: ignore[override]
1527+
return super().__eq__(other)
1528+
14891529
def to_netcdf(
14901530
self,
14911531
filepath,

xarray/core/types.py

+2
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
from xarray.core.coordinates import Coordinates
4242
from xarray.core.dataarray import DataArray
4343
from xarray.core.dataset import Dataset
44+
from xarray.core.datatree import DataTree
4445
from xarray.core.indexes import Index, Indexes
4546
from xarray.core.utils import Frozen
4647
from xarray.core.variable import IndexVariable, Variable
@@ -194,6 +195,7 @@ def copy(
194195
VarCompatible = Union["Variable", "ScalarOrArray"]
195196
DaCompatible = Union["DataArray", "VarCompatible"]
196197
DsCompatible = Union["Dataset", "DaCompatible"]
198+
DtCompatible = Union["DataTree", "DsCompatible"]
197199
GroupByCompatible = Union["Dataset", "DataArray"]
198200

199201
# Don't change to Hashable | Collection[Hashable]

xarray/core/variable.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -2323,7 +2323,7 @@ def _unary_op(self, f, *args, **kwargs):
23232323
return result
23242324

23252325
def _binary_op(self, other, f, reflexive=False):
2326-
if isinstance(other, xr.DataArray | xr.Dataset):
2326+
if isinstance(other, xr.DataTree | xr.DataArray | xr.Dataset):
23272327
return NotImplemented
23282328
if reflexive and issubclass(type(self), type(other)):
23292329
other_data, self_data, dims = _broadcast_compat_data(other, self)

0 commit comments

Comments
 (0)