Skip to content

Commit 19c4cae

Browse files
authored
Merge pull request #2 from mhvk/quantity-basics
Quantity basics * SETUP: tell ruff to check long lines too. * ENH: Quantity with the basic Array API methods * ENH: Support basics for dask and strict array api version * ENH: Also test with JAX * MAINT: set minimum versions for dependencies
2 parents d03c22e + 6385b76 commit 19c4cae

7 files changed

+874
-5
lines changed

pyproject.toml

+13-4
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,24 @@ license = { file = "licenses/LICENSE.rst" }
1414
authors = [
1515
{ name = "The Astropy Developers", email = "[email protected]" }
1616
]
17-
dependencies = []
17+
dependencies = [
18+
"array-api-compat>=1.9.1",
19+
"astropy>=7.0",
20+
"numpy>=2.0",
21+
]
1822
dynamic = ["version"]
1923

2024
[project.optional-dependencies]
25+
# Include other array types than numpy/array-api-strict. Mostly for tests.
26+
all = [
27+
"dask>=2024.11.2",
28+
"jax>=0.4.35",
29+
]
2130
test = [
2231
"pytest",
2332
"pytest-doctestplus",
24-
"pytest-cov"
33+
"pytest-cov",
34+
"array-api-strict",
2535
]
2636
docs = [
2737
"sphinx",
@@ -106,8 +116,7 @@ extend-select = [
106116
"ARG", # flake8-unused-arguments
107117
"B", # flake8-bugbear
108118
"C4", # flake8-comprehensions
109-
"EM", # flake8-errmsg
110-
"EXE", # flake8-executable
119+
"E", # errors
111120
"G", # flake8-logging-format
112121
"I", # isort
113122
"ICN", # flake8-import-conventions

quantity/__init__.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,12 @@
1+
"""
2+
Copyright (c) 2024 Astropy Developers. All rights reserved.
3+
4+
quantity-2.0: Prototyping the next generation Quantity
5+
"""
6+
17
from __future__ import annotations
28

9+
from .core import Quantity
310
from .version import version as __version__ # noqa: F401
411

5-
__all__ = []
12+
__all__ = ["Quantity"]

quantity/core.py

+255
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,255 @@
1+
# Licensed under a 3-clause BSD style license - see LICENSE.rst
2+
from __future__ import annotations
3+
4+
import operator
5+
from dataclasses import dataclass, replace
6+
from typing import TYPE_CHECKING
7+
8+
import array_api_compat
9+
import astropy.units as u
10+
import numpy as np
11+
from astropy.units.quantity_helper import UFUNC_HELPERS
12+
13+
if TYPE_CHECKING:
14+
from typing import Any
15+
16+
17+
DIMENSIONLESS = u.dimensionless_unscaled
18+
19+
PYTHON_NUMBER = float | int | complex
20+
21+
22+
def has_array_namespace(arg):
23+
try:
24+
array_api_compat.array_namespace(arg)
25+
except TypeError:
26+
return False
27+
else:
28+
return True
29+
30+
31+
def get_value_and_unit(arg, default_unit=None):
32+
# HACK: interoperability with astropy Quantity. Have protocol?
33+
try:
34+
unit = arg.unit
35+
except AttributeError:
36+
return arg, default_unit
37+
else:
38+
return arg.value, unit
39+
40+
41+
def value_in_unit(value, unit):
42+
v_value, v_unit = get_value_and_unit(value, default_unit=DIMENSIONLESS)
43+
return v_unit.to(unit, v_value)
44+
45+
46+
_OP_TO_NP_FUNC = {
47+
"__add__": np.add,
48+
"__floordiv__": np.floor_divide,
49+
"__matmul__": np.matmul,
50+
"__mod__": np.mod,
51+
"__mul__": np.multiply,
52+
"__sub__": np.subtract,
53+
"__truediv__": np.true_divide,
54+
}
55+
OP_HELPERS = {op: UFUNC_HELPERS[np_func] for op, np_func in _OP_TO_NP_FUNC.items()}
56+
57+
58+
def _make_op(fop, mode):
59+
assert mode in "fri"
60+
op = fop if mode == "f" else "__" + mode + fop[2:]
61+
helper = OP_HELPERS[fop]
62+
op_func = getattr(operator, fop)
63+
if mode == "r":
64+
65+
def wrapped_helper(u1, u2):
66+
return helper(op_func, u2, u1)
67+
else:
68+
69+
def wrapped_helper(u1, u2):
70+
return helper(op_func, u1, u2)
71+
72+
def __op__(self, other):
73+
return self._operate(other, op, wrapped_helper)
74+
75+
return __op__
76+
77+
78+
def _make_ops(op):
79+
return tuple(_make_op(op, mode) for mode in "fri")
80+
81+
82+
def _make_comp(comp):
83+
def __comp__(self, other):
84+
try:
85+
other = value_in_unit(other, self.unit)
86+
except Exception:
87+
return NotImplemented
88+
return getattr(self.value, comp)(other)
89+
90+
return __comp__
91+
92+
93+
def _make_deferred(attr):
94+
# Use array_api_compat getter if available (size, device), since
95+
# some array formats provide inconsistent implementations.
96+
attr_getter = getattr(array_api_compat, attr, operator.attrgetter(attr))
97+
98+
def deferred(self):
99+
return attr_getter(self.value)
100+
101+
return property(deferred)
102+
103+
104+
def _make_same_unit_method(attr):
105+
if array_api_func := getattr(array_api_compat, attr, None):
106+
107+
def same_unit(self, *args, **kwargs):
108+
return replace(self, value=array_api_func(self.value, *args, **kwargs))
109+
110+
else:
111+
112+
def same_unit(self, *args, **kwargs):
113+
return replace(self, value=getattr(self.value, attr)(*args, **kwargs))
114+
115+
return same_unit
116+
117+
118+
def _make_same_unit_attribute(attr):
119+
attr_getter = getattr(array_api_compat, attr, operator.attrgetter(attr))
120+
121+
def same_unit(self):
122+
return replace(self, value=attr_getter(self.value))
123+
124+
return property(same_unit)
125+
126+
127+
def _make_defer_dimensionless(attr):
128+
def defer_dimensionless(self):
129+
try:
130+
return getattr(self.unit.to(DIMENSIONLESS, self.value), attr)()
131+
except Exception as exc:
132+
raise TypeError from exc
133+
134+
return defer_dimensionless
135+
136+
137+
def _check_pow_args(exp, mod):
138+
if mod is not None:
139+
return NotImplemented
140+
141+
if not isinstance(exp, PYTHON_NUMBER):
142+
try:
143+
exp = exp.__complex__()
144+
except Exception:
145+
try:
146+
return exp.__float__()
147+
except Exception:
148+
return NotImplemented
149+
150+
return exp.real if exp.imag == 0 else exp
151+
152+
153+
@dataclass(frozen=True, eq=False)
154+
class Quantity:
155+
value: Any
156+
unit: u.UnitBase
157+
158+
def __array_namespace__(self, *, api_version: str | None = None) -> Any:
159+
# TODO: make our own?
160+
return np
161+
162+
def _operate(self, other, op, units_helper):
163+
if not has_array_namespace(other) and not isinstance(other, PYTHON_NUMBER):
164+
# HACK: unit should take care of this!
165+
if not isinstance(other, u.UnitBase):
166+
return NotImplemented
167+
168+
try:
169+
unit = getattr(operator, op)(self.unit, other)
170+
except Exception:
171+
return NotImplemented
172+
else:
173+
return replace(self, unit=unit)
174+
175+
other_value, other_unit = get_value_and_unit(other)
176+
self_value = self.value
177+
(conv0, conv1), unit = units_helper(self.unit, other_unit)
178+
if conv0 is not None:
179+
self_value = conv0(self_value)
180+
if conv1 is not None:
181+
other_value = conv1(other_value)
182+
try:
183+
value = getattr(self_value, op)(other_value)
184+
except AttributeError:
185+
return NotImplemented
186+
if value is NotImplemented:
187+
return NotImplemented
188+
return replace(self, value=value, unit=unit)
189+
190+
# Operators (skipping ones that make no sense, like __and__);
191+
# __pow__ and __rpow__ need special treatment and are defined below.
192+
__add__, __radd__, __iadd__ = _make_ops("__add__")
193+
__floordiv__, __rfloordiv__, __ifloordiv__ = _make_ops("__floordiv__")
194+
__matmul__, __rmatmul__, __imatmul__ = _make_ops("__matmul__")
195+
__mod__, __rmod__, __imod__ = _make_ops("__mod__")
196+
__mul__, __rmul__, __imul__ = _make_ops("__mul__")
197+
__sub__, __rsub__, __isub__ = _make_ops("__sub__")
198+
__truediv__, __rtruediv__, __itruediv__ = _make_ops("__truediv__")
199+
200+
# Comparisons
201+
__eq__ = _make_comp("__eq__")
202+
__ge__ = _make_comp("__ge__")
203+
__gt__ = _make_comp("__gt__")
204+
__le__ = _make_comp("__le__")
205+
__lt__ = _make_comp("__lt__")
206+
__ne__ = _make_comp("__ne__")
207+
208+
# Atttributes deferred to those of .value
209+
dtype = _make_deferred("dtype")
210+
device = _make_deferred("device")
211+
ndim = _make_deferred("ndim")
212+
shape = _make_deferred("shape")
213+
size = _make_deferred("size")
214+
215+
# Deferred to .value, yielding new Quantity with same unit.
216+
mT = _make_same_unit_attribute("mT")
217+
T = _make_same_unit_attribute("T")
218+
__abs__ = _make_same_unit_method("__abs__")
219+
__neg__ = _make_same_unit_method("__neg__")
220+
__pos__ = _make_same_unit_method("__pos__")
221+
__getitem__ = _make_same_unit_method("__getitem__")
222+
to_device = _make_same_unit_method("to_device")
223+
224+
# Deferred to .value, after making ourselves dimensionless (if possible).
225+
__complex__ = _make_defer_dimensionless("__complex__")
226+
__float__ = _make_defer_dimensionless("__float__")
227+
__int__ = _make_defer_dimensionless("__int__")
228+
229+
# TODO: __dlpack__, __dlpack_device__
230+
231+
def __pow__(self, exp, mod=None):
232+
exp = _check_pow_args(exp, mod)
233+
if exp is NotImplemented:
234+
return NotImplemented
235+
236+
value = self.value.__pow__(exp)
237+
if value is NotImplemented:
238+
return NotImplemented
239+
return replace(self, value=value, unit=self.unit**exp)
240+
241+
def __ipow__(self, exp, mod=None):
242+
exp = _check_pow_args(exp, mod)
243+
if exp is NotImplemented:
244+
return NotImplemented
245+
246+
value = self.value.__ipow__(exp)
247+
if value is NotImplemented:
248+
return NotImplemented
249+
return replace(self, value=value, unit=self.unit**exp)
250+
251+
def __setitem__(self, item, value):
252+
self.value[item] = value_in_unit(value, self.unit)
253+
254+
__array_ufunc__ = None
255+
__array_function__ = None

quantity/tests/conftest.py

+72
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
# Licensed under a 3-clause BSD style license - see LICENSE.rst
2+
from __future__ import annotations
3+
4+
import array_api_compat
5+
import astropy.units as u
6+
import numpy as np
7+
from astropy.utils.decorators import classproperty
8+
9+
ARRAY_NAMESPACES = []
10+
11+
12+
class ANSTests:
13+
IMMUTABLE = False # default
14+
NO_SETITEM = False
15+
16+
def __init_subclass__(cls, **kwargs):
17+
# Add class to namespaces available for testing if the underlying
18+
# array class is available.
19+
if not cls.__name__.startswith("Test"):
20+
try:
21+
cls.xp # noqa: B018
22+
except ImportError:
23+
pass
24+
else:
25+
ARRAY_NAMESPACES.append(cls)
26+
27+
@classmethod
28+
def setup_class(cls):
29+
cls.ARRAY_CLASS = type(cls.xp.ones((1,)))
30+
31+
32+
class UsingNDArray(ANSTests):
33+
xp = np
34+
35+
36+
class MonkeyPatchUnitConversion:
37+
@classmethod
38+
def setup_class(cls):
39+
super().setup_class()
40+
# TODO: update astropy so this monkeypatch is not necessary!
41+
# Enable non-coercing unit conversion on all astropy versions.
42+
cls._old_condition_arg = u.core._condition_arg
43+
u.core._condition_arg = lambda x: x
44+
45+
@classmethod
46+
def teardown_class(cls):
47+
u.core._condition_arg = cls._old_condition_arg
48+
49+
50+
class UsingArrayAPIStrict(MonkeyPatchUnitConversion, ANSTests):
51+
@classproperty(lazy=True)
52+
def xp(cls):
53+
return __import__("array_api_strict")
54+
55+
56+
class UsingDask(MonkeyPatchUnitConversion, ANSTests):
57+
IMMUTABLE = True
58+
59+
@classproperty(lazy=True)
60+
def xp(cls):
61+
import dask.array as da
62+
63+
return array_api_compat.array_namespace(da.array([1.0]))
64+
65+
66+
class UsingJAX(MonkeyPatchUnitConversion, ANSTests):
67+
IMMUTABLE = True
68+
NO_SETITEM = True
69+
70+
@classproperty(lazy=True)
71+
def xp(cls):
72+
return __import__("jax").numpy

0 commit comments

Comments
 (0)