Skip to content

Commit 0e01a15

Browse files
nstarmanmhvk
andauthored
✨ feat: add Quantity API (#16)
* ✨ feat: add Quantity API Adds 2 runtime-checkable protocols that define the minimum Q API and the more complete API that includes the Array API. These protocols are made public in a new module named api. Private protocols and API are defined in the private modules _src/array_api and _src/api. The former is a temporary module that we can deprecate when the https://github.com/data-apis/array-api-typing package is developed and released. Separately, ignores RUF022, so we can order __all__ Signed-off-by: nstarman <[email protected]> Co-authored-by: Marten H. van Kerkwijk <[email protected]>
1 parent 61011f8 commit 0e01a15

File tree

7 files changed

+213
-9
lines changed

7 files changed

+213
-9
lines changed

pyproject.toml

+1
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,7 @@ ignore = [
141141
"PLR09", # Too many <...>
142142
"PLR2004", # Magic value used in comparison
143143
"RET505", # Unnecessary `else`/`elif` after `return` statement
144+
"RUF022", # `__all__` is not sorted
144145
]
145146

146147
[tool.ruff.lint.per-file-ignores]

src/quantity/__init__.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,13 @@
33
Copyright (c) 2024 Astropy Developers. All rights reserved.
44
"""
55

6+
from . import api
67
from ._src import Quantity
78
from .version import version as __version__ # noqa: F401
89

9-
__all__ = ["Quantity"]
10+
__all__ = [
11+
# modules
12+
"api",
13+
# functions and classes
14+
"Quantity",
15+
]

src/quantity/_src/api.py

+96
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
"""The Quantity API. Private module."""
2+
3+
__all__ = ["Quantity", "QuantityArray", "Unit"]
4+
5+
from typing import Protocol, runtime_checkable
6+
7+
from astropy.units import UnitBase as Unit
8+
9+
from .array_api import Array
10+
11+
12+
@runtime_checkable
13+
class Quantity(Protocol):
14+
"""Minimal definition of the Quantity API.
15+
16+
At minimum a Quantity must have the following attributes:
17+
18+
- `value`: the numerical value of the quantity (adhering to the Array API)
19+
- `unit`: the unit of the quantity
20+
21+
In practice, Quantities themselves must adhere to the Array API, not just
22+
their values. This stricter requirement is described by the `QuantityArray`
23+
protocol.
24+
25+
See Also
26+
--------
27+
QuantityArray : A Quantity that adheres to the Array API
28+
29+
Examples
30+
--------
31+
>>> import numpy as np
32+
>>> import astropy.units as u
33+
>>> from quantity import Quantity
34+
>>> from quantity import api
35+
36+
>>> issubclass(Quantity, api.Quantity)
37+
True
38+
39+
>>> q = Quantity(value=np.array([1, 2, 3]), unit=u.m)
40+
>>> isinstance(q, api.Quantity)
41+
True
42+
43+
"""
44+
45+
#: The numerical value of the quantity, adhering to the Array API.
46+
value: Array
47+
48+
#: The unit of the quantity.
49+
unit: Unit
50+
51+
@classmethod
52+
def __subclasshook__(cls: type, c: type) -> bool:
53+
"""Enable the subclass check for data descriptors."""
54+
return (
55+
hasattr(c, "value") or "value" in getattr(c, "__annotations__", ())
56+
) and (hasattr(c, "unit") or "unit" in getattr(c, "__annotations__", ()))
57+
58+
59+
@runtime_checkable
60+
class QuantityArray(Quantity, Array, Protocol):
61+
"""An array-valued Quantity.
62+
63+
A QuantityArray is a Quantity that itself adheres to the Array API. This
64+
means that the QuantityArray has properties like `shape`, `dtype`, and the
65+
`__array_namespace__` method, among many other properties and methods. To
66+
understand the full requirements of the Array API, see the `Array` protocol.
67+
The `Quantity` protocol describes the minimal requirements for a Quantity,
68+
separate from the Array API. QuantityArray is the combination of these two
69+
protocols and is the most complete description of a Quantity.
70+
71+
See Also
72+
--------
73+
Quantity : The minimal Quantity API, separate from the Array API
74+
75+
Examples
76+
--------
77+
>>> import numpy as np
78+
>>> import astropy.units as u
79+
>>> from quantity import Quantity
80+
>>> from quantity import api
81+
82+
>>> issubclass(Quantity, api.QuantityArray)
83+
True
84+
85+
>>> q = Quantity(value=np.array([1, 2, 3]), unit=u.m)
86+
>>> isinstance(q, api.QuantityArray)
87+
True
88+
89+
"""
90+
91+
...
92+
93+
@classmethod
94+
def __subclasshook__(cls: type, c: type) -> bool:
95+
"""Enable the subclass check for data descriptors."""
96+
return Quantity.__subclasscheck__(c) and issubclass(c, Array)

src/quantity/_src/array_api.py

+25
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
"""Minimal definition of the Array API.
2+
3+
NOTE: this module will be deprecated when
4+
https://github.com/data-apis/array-api-typing is released.
5+
6+
"""
7+
8+
from __future__ import annotations
9+
10+
__all__ = ["HasArrayNameSpace", "Array"]
11+
12+
from typing import Any, Protocol, runtime_checkable
13+
14+
15+
class HasArrayNameSpace(Protocol):
16+
"""Minimal definition of the Array API."""
17+
18+
def __array_namespace__(self) -> Any: ...
19+
20+
21+
@runtime_checkable
22+
class Array(HasArrayNameSpace, Protocol):
23+
"""Minimal definition of the Array API."""
24+
25+
def __pow__(self, other: Any) -> Array: ...

src/quantity/_src/core.py

+10-8
Original file line numberDiff line numberDiff line change
@@ -10,25 +10,27 @@
1010
import numpy as np
1111
from astropy.units.quantity_helper import UFUNC_HELPERS
1212

13+
from .api import QuantityArray
1314
from .utils import has_array_namespace
1415

1516
if TYPE_CHECKING:
1617
from typing import Any
1718

19+
from .api import Unit
20+
from .array_api import Array
21+
1822

1923
DIMENSIONLESS = u.dimensionless_unscaled
2024

2125
PYTHON_NUMBER = float | int | complex
2226

2327

24-
def get_value_and_unit(arg, default_unit=None):
25-
# HACK: interoperability with astropy Quantity. Have protocol?
26-
try:
27-
unit = arg.unit
28-
except AttributeError:
29-
return arg, default_unit
30-
else:
31-
return arg.value, unit
28+
def get_value_and_unit(
29+
arg: QuantityArray | Array, default_unit: Unit | None = None
30+
) -> tuple[Array, Unit]:
31+
return (
32+
(arg.value, arg.unit) if isinstance(arg, QuantityArray) else (arg, default_unit)
33+
)
3234

3335

3436
def value_in_unit(value, unit):

src/quantity/api.py

+16
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
"""Quantity-2.0: the Quantity API.
2+
3+
This module provides runtime-checkable Protocol objects that define the Quantity
4+
API. In particular there are:
5+
6+
- `Quantity`: the minimal definition of a Quantity, separate from the Array API.
7+
- `QuantityArray`: a Quantity that adheres to the Array API. This is the most
8+
complete definition of a Quantity, inheriting from the Quantity API and
9+
adding the requirements for the Array API.
10+
11+
"""
12+
# Copyright (c) 2024 Astropy Developers. All rights reserved.
13+
14+
from ._src.api import Quantity, QuantityArray
15+
16+
__all__ = ["Quantity", "QuantityArray"]

tests/test_api.py

+58
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
# Licensed under a 3-clause BSD style license - see LICENSE.rst
2+
"""Test the Quantity class Array API compatibility."""
3+
4+
import astropy.units as u
5+
import numpy as np
6+
import pytest
7+
8+
from quantity import Quantity, api
9+
10+
from .conftest import ARRAY_NAMESPACES
11+
12+
13+
def test_issubclass_api():
14+
"""Test that Quantity is a subclass of api.Quantity and api.QuantityArray."""
15+
assert issubclass(Quantity, api.Quantity)
16+
assert issubclass(Quantity, api.QuantityArray)
17+
18+
19+
def test_ndarray():
20+
"""Test that ndarray does not satisfy the Quantity API."""
21+
assert not issubclass(np.ndarray, api.Quantity)
22+
assert not isinstance(np.array([1, 2, 3]), api.Quantity)
23+
24+
25+
def test_astropy_quantity():
26+
"""Test that astropy.units.Quantity works with the Quantity API."""
27+
assert issubclass(u.Quantity, api.Quantity)
28+
assert isinstance(u.Quantity(np.array([1, 2, 3]), u.m), api.Quantity)
29+
30+
31+
# ------------------------------
32+
33+
34+
class TestIsinstanceAPI:
35+
"""Check Quantities are properly recognized independent of the array type."""
36+
37+
@pytest.fixture(scope="class", params=ARRAY_NAMESPACES)
38+
def array_and_quantity(self, request):
39+
xp = request.param.xp
40+
value = xp.asarray([1.0, 2.0, 3.0])
41+
q = Quantity(value, u.m)
42+
return value, q
43+
44+
# ====================
45+
46+
def test_issubclass_api(self, array_and_quantity):
47+
v, q = array_and_quantity
48+
assert not issubclass(type(v), api.Quantity)
49+
assert not issubclass(type(v), api.QuantityArray)
50+
assert issubclass(type(q), api.Quantity)
51+
assert issubclass(type(q), api.QuantityArray)
52+
53+
def test_isinstance_api(self, array_and_quantity):
54+
v, q = array_and_quantity
55+
assert not isinstance(v, api.Quantity)
56+
assert not isinstance(v, api.QuantityArray)
57+
assert isinstance(q, api.Quantity)
58+
assert isinstance(q, api.QuantityArray)

0 commit comments

Comments
 (0)