|
12 | 12 | from xarray.core.types import (
|
13 | 13 | DaCompatible,
|
14 | 14 | DsCompatible,
|
| 15 | + DtCompatible, |
15 | 16 | Self,
|
16 | 17 | T_Xarray,
|
17 | 18 | VarCompatible,
|
|
23 | 24 | from xarray.core.types import T_DataArray as T_DA
|
24 | 25 |
|
25 | 26 |
|
| 27 | +class DataTreeOpsMixin: |
| 28 | + __slots__ = () |
| 29 | + |
| 30 | + def _binary_op( |
| 31 | + self, other: DtCompatible, f: Callable, reflexive: bool = False |
| 32 | + ) -> Self: |
| 33 | + raise NotImplementedError |
| 34 | + |
| 35 | + def __add__(self, other: DtCompatible) -> Self: |
| 36 | + return self._binary_op(other, operator.add) |
| 37 | + |
| 38 | + def __sub__(self, other: DtCompatible) -> Self: |
| 39 | + return self._binary_op(other, operator.sub) |
| 40 | + |
| 41 | + def __mul__(self, other: DtCompatible) -> Self: |
| 42 | + return self._binary_op(other, operator.mul) |
| 43 | + |
| 44 | + def __pow__(self, other: DtCompatible) -> Self: |
| 45 | + return self._binary_op(other, operator.pow) |
| 46 | + |
| 47 | + def __truediv__(self, other: DtCompatible) -> Self: |
| 48 | + return self._binary_op(other, operator.truediv) |
| 49 | + |
| 50 | + def __floordiv__(self, other: DtCompatible) -> Self: |
| 51 | + return self._binary_op(other, operator.floordiv) |
| 52 | + |
| 53 | + def __mod__(self, other: DtCompatible) -> Self: |
| 54 | + return self._binary_op(other, operator.mod) |
| 55 | + |
| 56 | + def __and__(self, other: DtCompatible) -> Self: |
| 57 | + return self._binary_op(other, operator.and_) |
| 58 | + |
| 59 | + def __xor__(self, other: DtCompatible) -> Self: |
| 60 | + return self._binary_op(other, operator.xor) |
| 61 | + |
| 62 | + def __or__(self, other: DtCompatible) -> Self: |
| 63 | + return self._binary_op(other, operator.or_) |
| 64 | + |
| 65 | + def __lshift__(self, other: DtCompatible) -> Self: |
| 66 | + return self._binary_op(other, operator.lshift) |
| 67 | + |
| 68 | + def __rshift__(self, other: DtCompatible) -> Self: |
| 69 | + return self._binary_op(other, operator.rshift) |
| 70 | + |
| 71 | + def __lt__(self, other: DtCompatible) -> Self: |
| 72 | + return self._binary_op(other, operator.lt) |
| 73 | + |
| 74 | + def __le__(self, other: DtCompatible) -> Self: |
| 75 | + return self._binary_op(other, operator.le) |
| 76 | + |
| 77 | + def __gt__(self, other: DtCompatible) -> Self: |
| 78 | + return self._binary_op(other, operator.gt) |
| 79 | + |
| 80 | + def __ge__(self, other: DtCompatible) -> Self: |
| 81 | + return self._binary_op(other, operator.ge) |
| 82 | + |
| 83 | + def __eq__(self, other: DtCompatible) -> Self: # type:ignore[override] |
| 84 | + return self._binary_op(other, nputils.array_eq) |
| 85 | + |
| 86 | + def __ne__(self, other: DtCompatible) -> Self: # type:ignore[override] |
| 87 | + return self._binary_op(other, nputils.array_ne) |
| 88 | + |
| 89 | + # When __eq__ is defined but __hash__ is not, then an object is unhashable, |
| 90 | + # and it should be declared as follows: |
| 91 | + __hash__: None # type:ignore[assignment] |
| 92 | + |
| 93 | + def __radd__(self, other: DtCompatible) -> Self: |
| 94 | + return self._binary_op(other, operator.add, reflexive=True) |
| 95 | + |
| 96 | + def __rsub__(self, other: DtCompatible) -> Self: |
| 97 | + return self._binary_op(other, operator.sub, reflexive=True) |
| 98 | + |
| 99 | + def __rmul__(self, other: DtCompatible) -> Self: |
| 100 | + return self._binary_op(other, operator.mul, reflexive=True) |
| 101 | + |
| 102 | + def __rpow__(self, other: DtCompatible) -> Self: |
| 103 | + return self._binary_op(other, operator.pow, reflexive=True) |
| 104 | + |
| 105 | + def __rtruediv__(self, other: DtCompatible) -> Self: |
| 106 | + return self._binary_op(other, operator.truediv, reflexive=True) |
| 107 | + |
| 108 | + def __rfloordiv__(self, other: DtCompatible) -> Self: |
| 109 | + return self._binary_op(other, operator.floordiv, reflexive=True) |
| 110 | + |
| 111 | + def __rmod__(self, other: DtCompatible) -> Self: |
| 112 | + return self._binary_op(other, operator.mod, reflexive=True) |
| 113 | + |
| 114 | + def __rand__(self, other: DtCompatible) -> Self: |
| 115 | + return self._binary_op(other, operator.and_, reflexive=True) |
| 116 | + |
| 117 | + def __rxor__(self, other: DtCompatible) -> Self: |
| 118 | + return self._binary_op(other, operator.xor, reflexive=True) |
| 119 | + |
| 120 | + def __ror__(self, other: DtCompatible) -> Self: |
| 121 | + return self._binary_op(other, operator.or_, reflexive=True) |
| 122 | + |
| 123 | + def _unary_op(self, f: Callable, *args: Any, **kwargs: Any) -> Self: |
| 124 | + raise NotImplementedError |
| 125 | + |
| 126 | + def __neg__(self) -> Self: |
| 127 | + return self._unary_op(operator.neg) |
| 128 | + |
| 129 | + def __pos__(self) -> Self: |
| 130 | + return self._unary_op(operator.pos) |
| 131 | + |
| 132 | + def __abs__(self) -> Self: |
| 133 | + return self._unary_op(operator.abs) |
| 134 | + |
| 135 | + def __invert__(self) -> Self: |
| 136 | + return self._unary_op(operator.invert) |
| 137 | + |
| 138 | + def round(self, *args: Any, **kwargs: Any) -> Self: |
| 139 | + return self._unary_op(ops.round_, *args, **kwargs) |
| 140 | + |
| 141 | + def argsort(self, *args: Any, **kwargs: Any) -> Self: |
| 142 | + return self._unary_op(ops.argsort, *args, **kwargs) |
| 143 | + |
| 144 | + def conj(self, *args: Any, **kwargs: Any) -> Self: |
| 145 | + return self._unary_op(ops.conj, *args, **kwargs) |
| 146 | + |
| 147 | + def conjugate(self, *args: Any, **kwargs: Any) -> Self: |
| 148 | + return self._unary_op(ops.conjugate, *args, **kwargs) |
| 149 | + |
| 150 | + __add__.__doc__ = operator.add.__doc__ |
| 151 | + __sub__.__doc__ = operator.sub.__doc__ |
| 152 | + __mul__.__doc__ = operator.mul.__doc__ |
| 153 | + __pow__.__doc__ = operator.pow.__doc__ |
| 154 | + __truediv__.__doc__ = operator.truediv.__doc__ |
| 155 | + __floordiv__.__doc__ = operator.floordiv.__doc__ |
| 156 | + __mod__.__doc__ = operator.mod.__doc__ |
| 157 | + __and__.__doc__ = operator.and_.__doc__ |
| 158 | + __xor__.__doc__ = operator.xor.__doc__ |
| 159 | + __or__.__doc__ = operator.or_.__doc__ |
| 160 | + __lshift__.__doc__ = operator.lshift.__doc__ |
| 161 | + __rshift__.__doc__ = operator.rshift.__doc__ |
| 162 | + __lt__.__doc__ = operator.lt.__doc__ |
| 163 | + __le__.__doc__ = operator.le.__doc__ |
| 164 | + __gt__.__doc__ = operator.gt.__doc__ |
| 165 | + __ge__.__doc__ = operator.ge.__doc__ |
| 166 | + __eq__.__doc__ = nputils.array_eq.__doc__ |
| 167 | + __ne__.__doc__ = nputils.array_ne.__doc__ |
| 168 | + __radd__.__doc__ = operator.add.__doc__ |
| 169 | + __rsub__.__doc__ = operator.sub.__doc__ |
| 170 | + __rmul__.__doc__ = operator.mul.__doc__ |
| 171 | + __rpow__.__doc__ = operator.pow.__doc__ |
| 172 | + __rtruediv__.__doc__ = operator.truediv.__doc__ |
| 173 | + __rfloordiv__.__doc__ = operator.floordiv.__doc__ |
| 174 | + __rmod__.__doc__ = operator.mod.__doc__ |
| 175 | + __rand__.__doc__ = operator.and_.__doc__ |
| 176 | + __rxor__.__doc__ = operator.xor.__doc__ |
| 177 | + __ror__.__doc__ = operator.or_.__doc__ |
| 178 | + __neg__.__doc__ = operator.neg.__doc__ |
| 179 | + __pos__.__doc__ = operator.pos.__doc__ |
| 180 | + __abs__.__doc__ = operator.abs.__doc__ |
| 181 | + __invert__.__doc__ = operator.invert.__doc__ |
| 182 | + round.__doc__ = ops.round_.__doc__ |
| 183 | + argsort.__doc__ = ops.argsort.__doc__ |
| 184 | + conj.__doc__ = ops.conj.__doc__ |
| 185 | + conjugate.__doc__ = ops.conjugate.__doc__ |
| 186 | + |
| 187 | + |
26 | 188 | class DatasetOpsMixin:
|
27 | 189 | __slots__ = ()
|
28 | 190 |
|
|
0 commit comments