|
| 1 | +# Licensed under a 3-clause BSD style license - see LICENSE.rst |
| 2 | +from __future__ import annotations |
| 3 | + |
| 4 | +import operator |
| 5 | +from dataclasses import dataclass, replace |
| 6 | +from typing import TYPE_CHECKING |
| 7 | + |
| 8 | +import array_api_compat |
| 9 | +import astropy.units as u |
| 10 | +import numpy as np |
| 11 | +from astropy.units.quantity_helper import UFUNC_HELPERS |
| 12 | + |
| 13 | +if TYPE_CHECKING: |
| 14 | + from typing import Any |
| 15 | + |
| 16 | + |
| 17 | +DIMENSIONLESS = u.dimensionless_unscaled |
| 18 | + |
| 19 | +PYTHON_NUMBER = float | int | complex |
| 20 | + |
| 21 | + |
| 22 | +def has_array_namespace(arg): |
| 23 | + try: |
| 24 | + array_api_compat.array_namespace(arg) |
| 25 | + except TypeError: |
| 26 | + return False |
| 27 | + else: |
| 28 | + return True |
| 29 | + |
| 30 | + |
| 31 | +def get_value_and_unit(arg, default_unit=None): |
| 32 | + # HACK: interoperability with astropy Quantity. Have protocol? |
| 33 | + try: |
| 34 | + unit = arg.unit |
| 35 | + except AttributeError: |
| 36 | + return arg, default_unit |
| 37 | + else: |
| 38 | + return arg.value, unit |
| 39 | + |
| 40 | + |
| 41 | +def value_in_unit(value, unit): |
| 42 | + v_value, v_unit = get_value_and_unit(value, default_unit=DIMENSIONLESS) |
| 43 | + return v_unit.to(unit, v_value) |
| 44 | + |
| 45 | + |
| 46 | +_OP_TO_NP_FUNC = { |
| 47 | + "__add__": np.add, |
| 48 | + "__floordiv__": np.floor_divide, |
| 49 | + "__matmul__": np.matmul, |
| 50 | + "__mod__": np.mod, |
| 51 | + "__mul__": np.multiply, |
| 52 | + "__sub__": np.subtract, |
| 53 | + "__truediv__": np.true_divide, |
| 54 | +} |
| 55 | +OP_HELPERS = {op: UFUNC_HELPERS[np_func] for op, np_func in _OP_TO_NP_FUNC.items()} |
| 56 | + |
| 57 | + |
| 58 | +def _make_op(fop, mode): |
| 59 | + assert mode in "fri" |
| 60 | + op = fop if mode == "f" else "__" + mode + fop[2:] |
| 61 | + helper = OP_HELPERS[fop] |
| 62 | + op_func = getattr(operator, fop) |
| 63 | + if mode == "r": |
| 64 | + |
| 65 | + def wrapped_helper(u1, u2): |
| 66 | + return helper(op_func, u2, u1) |
| 67 | + else: |
| 68 | + |
| 69 | + def wrapped_helper(u1, u2): |
| 70 | + return helper(op_func, u1, u2) |
| 71 | + |
| 72 | + def __op__(self, other): |
| 73 | + return self._operate(other, op, wrapped_helper) |
| 74 | + |
| 75 | + return __op__ |
| 76 | + |
| 77 | + |
| 78 | +def _make_ops(op): |
| 79 | + return tuple(_make_op(op, mode) for mode in "fri") |
| 80 | + |
| 81 | + |
| 82 | +def _make_comp(comp): |
| 83 | + def __comp__(self, other): |
| 84 | + try: |
| 85 | + other = value_in_unit(other, self.unit) |
| 86 | + except Exception: |
| 87 | + return NotImplemented |
| 88 | + return getattr(self.value, comp)(other) |
| 89 | + |
| 90 | + return __comp__ |
| 91 | + |
| 92 | + |
| 93 | +def _make_deferred(attr): |
| 94 | + # Use array_api_compat getter if available (size, device), since |
| 95 | + # some array formats provide inconsistent implementations. |
| 96 | + attr_getter = getattr(array_api_compat, attr, operator.attrgetter(attr)) |
| 97 | + |
| 98 | + def deferred(self): |
| 99 | + return attr_getter(self.value) |
| 100 | + |
| 101 | + return property(deferred) |
| 102 | + |
| 103 | + |
| 104 | +def _make_same_unit_method(attr): |
| 105 | + if array_api_func := getattr(array_api_compat, attr, None): |
| 106 | + |
| 107 | + def same_unit(self, *args, **kwargs): |
| 108 | + return replace(self, value=array_api_func(self.value, *args, **kwargs)) |
| 109 | + |
| 110 | + else: |
| 111 | + |
| 112 | + def same_unit(self, *args, **kwargs): |
| 113 | + return replace(self, value=getattr(self.value, attr)(*args, **kwargs)) |
| 114 | + |
| 115 | + return same_unit |
| 116 | + |
| 117 | + |
| 118 | +def _make_same_unit_attribute(attr): |
| 119 | + attr_getter = getattr(array_api_compat, attr, operator.attrgetter(attr)) |
| 120 | + |
| 121 | + def same_unit(self): |
| 122 | + return replace(self, value=attr_getter(self.value)) |
| 123 | + |
| 124 | + return property(same_unit) |
| 125 | + |
| 126 | + |
| 127 | +def _make_defer_dimensionless(attr): |
| 128 | + def defer_dimensionless(self): |
| 129 | + try: |
| 130 | + return getattr(self.unit.to(DIMENSIONLESS, self.value), attr)() |
| 131 | + except Exception as exc: |
| 132 | + raise TypeError from exc |
| 133 | + |
| 134 | + return defer_dimensionless |
| 135 | + |
| 136 | + |
| 137 | +def _check_pow_args(exp, mod): |
| 138 | + if mod is not None: |
| 139 | + return NotImplemented |
| 140 | + |
| 141 | + if not isinstance(exp, PYTHON_NUMBER): |
| 142 | + try: |
| 143 | + exp = exp.__complex__() |
| 144 | + except Exception: |
| 145 | + try: |
| 146 | + return exp.__float__() |
| 147 | + except Exception: |
| 148 | + return NotImplemented |
| 149 | + |
| 150 | + return exp.real if exp.imag == 0 else exp |
| 151 | + |
| 152 | + |
| 153 | +@dataclass(frozen=True, eq=False) |
| 154 | +class Quantity: |
| 155 | + value: Any |
| 156 | + unit: u.UnitBase |
| 157 | + |
| 158 | + def __array_namespace__(self, *, api_version: str | None = None) -> Any: |
| 159 | + # TODO: make our own? |
| 160 | + return np |
| 161 | + |
| 162 | + def _operate(self, other, op, units_helper): |
| 163 | + if not has_array_namespace(other) and not isinstance(other, PYTHON_NUMBER): |
| 164 | + # HACK: unit should take care of this! |
| 165 | + if not isinstance(other, u.UnitBase): |
| 166 | + return NotImplemented |
| 167 | + |
| 168 | + try: |
| 169 | + unit = getattr(operator, op)(self.unit, other) |
| 170 | + except Exception: |
| 171 | + return NotImplemented |
| 172 | + else: |
| 173 | + return replace(self, unit=unit) |
| 174 | + |
| 175 | + other_value, other_unit = get_value_and_unit(other) |
| 176 | + self_value = self.value |
| 177 | + (conv0, conv1), unit = units_helper(self.unit, other_unit) |
| 178 | + if conv0 is not None: |
| 179 | + self_value = conv0(self_value) |
| 180 | + if conv1 is not None: |
| 181 | + other_value = conv1(other_value) |
| 182 | + try: |
| 183 | + value = getattr(self_value, op)(other_value) |
| 184 | + except AttributeError: |
| 185 | + return NotImplemented |
| 186 | + if value is NotImplemented: |
| 187 | + return NotImplemented |
| 188 | + return replace(self, value=value, unit=unit) |
| 189 | + |
| 190 | + # Operators (skipping ones that make no sense, like __and__); |
| 191 | + # __pow__ and __rpow__ need special treatment and are defined below. |
| 192 | + __add__, __radd__, __iadd__ = _make_ops("__add__") |
| 193 | + __floordiv__, __rfloordiv__, __ifloordiv__ = _make_ops("__floordiv__") |
| 194 | + __matmul__, __rmatmul__, __imatmul__ = _make_ops("__matmul__") |
| 195 | + __mod__, __rmod__, __imod__ = _make_ops("__mod__") |
| 196 | + __mul__, __rmul__, __imul__ = _make_ops("__mul__") |
| 197 | + __sub__, __rsub__, __isub__ = _make_ops("__sub__") |
| 198 | + __truediv__, __rtruediv__, __itruediv__ = _make_ops("__truediv__") |
| 199 | + |
| 200 | + # Comparisons |
| 201 | + __eq__ = _make_comp("__eq__") |
| 202 | + __ge__ = _make_comp("__ge__") |
| 203 | + __gt__ = _make_comp("__gt__") |
| 204 | + __le__ = _make_comp("__le__") |
| 205 | + __lt__ = _make_comp("__lt__") |
| 206 | + __ne__ = _make_comp("__ne__") |
| 207 | + |
| 208 | + # Atttributes deferred to those of .value |
| 209 | + dtype = _make_deferred("dtype") |
| 210 | + device = _make_deferred("device") |
| 211 | + ndim = _make_deferred("ndim") |
| 212 | + shape = _make_deferred("shape") |
| 213 | + size = _make_deferred("size") |
| 214 | + |
| 215 | + # Deferred to .value, yielding new Quantity with same unit. |
| 216 | + mT = _make_same_unit_attribute("mT") |
| 217 | + T = _make_same_unit_attribute("T") |
| 218 | + __abs__ = _make_same_unit_method("__abs__") |
| 219 | + __neg__ = _make_same_unit_method("__neg__") |
| 220 | + __pos__ = _make_same_unit_method("__pos__") |
| 221 | + __getitem__ = _make_same_unit_method("__getitem__") |
| 222 | + to_device = _make_same_unit_method("to_device") |
| 223 | + |
| 224 | + # Deferred to .value, after making ourselves dimensionless (if possible). |
| 225 | + __complex__ = _make_defer_dimensionless("__complex__") |
| 226 | + __float__ = _make_defer_dimensionless("__float__") |
| 227 | + __int__ = _make_defer_dimensionless("__int__") |
| 228 | + |
| 229 | + # TODO: __dlpack__, __dlpack_device__ |
| 230 | + |
| 231 | + def __pow__(self, exp, mod=None): |
| 232 | + exp = _check_pow_args(exp, mod) |
| 233 | + if exp is NotImplemented: |
| 234 | + return NotImplemented |
| 235 | + |
| 236 | + value = self.value.__pow__(exp) |
| 237 | + if value is NotImplemented: |
| 238 | + return NotImplemented |
| 239 | + return replace(self, value=value, unit=self.unit**exp) |
| 240 | + |
| 241 | + def __ipow__(self, exp, mod=None): |
| 242 | + exp = _check_pow_args(exp, mod) |
| 243 | + if exp is NotImplemented: |
| 244 | + return NotImplemented |
| 245 | + |
| 246 | + value = self.value.__ipow__(exp) |
| 247 | + if value is NotImplemented: |
| 248 | + return NotImplemented |
| 249 | + return replace(self, value=value, unit=self.unit**exp) |
| 250 | + |
| 251 | + def __setitem__(self, item, value): |
| 252 | + self.value[item] = value_in_unit(value, self.unit) |
| 253 | + |
| 254 | + __array_ufunc__ = None |
| 255 | + __array_function__ = None |
0 commit comments