diff --git a/pyproject.toml b/pyproject.toml index 9803536..25b8616 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -141,6 +141,7 @@ ignore = [ "PLR09", # Too many <...> "PLR2004", # Magic value used in comparison "RET505", # Unnecessary `else`/`elif` after `return` statement + "RUF022", # `__all__` is not sorted ] [tool.ruff.lint.per-file-ignores] diff --git a/src/quantity/__init__.py b/src/quantity/__init__.py index d230993..a8aadb2 100644 --- a/src/quantity/__init__.py +++ b/src/quantity/__init__.py @@ -3,7 +3,13 @@ Copyright (c) 2024 Astropy Developers. All rights reserved. """ +from . import api from ._src import Quantity from .version import version as __version__ # noqa: F401 -__all__ = ["Quantity"] +__all__ = [ + # modules + "api", + # functions and classes + "Quantity", +] diff --git a/src/quantity/_src/api.py b/src/quantity/_src/api.py new file mode 100644 index 0000000..fec7476 --- /dev/null +++ b/src/quantity/_src/api.py @@ -0,0 +1,96 @@ +"""The Quantity API. Private module.""" + +__all__ = ["Quantity", "QuantityArray", "Unit"] + +from typing import Protocol, runtime_checkable + +from astropy.units import UnitBase as Unit + +from .array_api import Array + + +@runtime_checkable +class Quantity(Protocol): + """Minimal definition of the Quantity API. + + At minimum a Quantity must have the following attributes: + + - `value`: the numerical value of the quantity (adhering to the Array API) + - `unit`: the unit of the quantity + + In practice, Quantities themselves must adhere to the Array API, not just + their values. This stricter requirement is described by the `QuantityArray` + protocol. + + See Also + -------- + QuantityArray : A Quantity that adheres to the Array API + + Examples + -------- + >>> import numpy as np + >>> import astropy.units as u + >>> from quantity import Quantity + >>> from quantity import api + + >>> issubclass(Quantity, api.Quantity) + True + + >>> q = Quantity(value=np.array([1, 2, 3]), unit=u.m) + >>> isinstance(q, api.Quantity) + True + + """ + + #: The numerical value of the quantity, adhering to the Array API. + value: Array + + #: The unit of the quantity. + unit: Unit + + @classmethod + def __subclasshook__(cls: type, c: type) -> bool: + """Enable the subclass check for data descriptors.""" + return ( + hasattr(c, "value") or "value" in getattr(c, "__annotations__", ()) + ) and (hasattr(c, "unit") or "unit" in getattr(c, "__annotations__", ())) + + +@runtime_checkable +class QuantityArray(Quantity, Array, Protocol): + """An array-valued Quantity. + + A QuantityArray is a Quantity that itself adheres to the Array API. This + means that the QuantityArray has properties like `shape`, `dtype`, and the + `__array_namespace__` method, among many other properties and methods. To + understand the full requirements of the Array API, see the `Array` protocol. + The `Quantity` protocol describes the minimal requirements for a Quantity, + separate from the Array API. QuantityArray is the combination of these two + protocols and is the most complete description of a Quantity. + + See Also + -------- + Quantity : The minimal Quantity API, separate from the Array API + + Examples + -------- + >>> import numpy as np + >>> import astropy.units as u + >>> from quantity import Quantity + >>> from quantity import api + + >>> issubclass(Quantity, api.QuantityArray) + True + + >>> q = Quantity(value=np.array([1, 2, 3]), unit=u.m) + >>> isinstance(q, api.QuantityArray) + True + + """ + + ... + + @classmethod + def __subclasshook__(cls: type, c: type) -> bool: + """Enable the subclass check for data descriptors.""" + return Quantity.__subclasscheck__(c) and issubclass(c, Array) diff --git a/src/quantity/_src/array_api.py b/src/quantity/_src/array_api.py new file mode 100644 index 0000000..556ab7c --- /dev/null +++ b/src/quantity/_src/array_api.py @@ -0,0 +1,25 @@ +"""Minimal definition of the Array API. + +NOTE: this module will be deprecated when +https://github.com/data-apis/array-api-typing is released. + +""" + +from __future__ import annotations + +__all__ = ["HasArrayNameSpace", "Array"] + +from typing import Any, Protocol, runtime_checkable + + +class HasArrayNameSpace(Protocol): + """Minimal definition of the Array API.""" + + def __array_namespace__(self) -> Any: ... + + +@runtime_checkable +class Array(HasArrayNameSpace, Protocol): + """Minimal definition of the Array API.""" + + def __pow__(self, other: Any) -> Array: ... diff --git a/src/quantity/_src/core.py b/src/quantity/_src/core.py index adfb568..6f194a8 100644 --- a/src/quantity/_src/core.py +++ b/src/quantity/_src/core.py @@ -10,25 +10,27 @@ import numpy as np from astropy.units.quantity_helper import UFUNC_HELPERS +from .api import QuantityArray from .utils import has_array_namespace if TYPE_CHECKING: from typing import Any + from .api import Unit + from .array_api import Array + DIMENSIONLESS = u.dimensionless_unscaled PYTHON_NUMBER = float | int | complex -def get_value_and_unit(arg, default_unit=None): - # HACK: interoperability with astropy Quantity. Have protocol? - try: - unit = arg.unit - except AttributeError: - return arg, default_unit - else: - return arg.value, unit +def get_value_and_unit( + arg: QuantityArray | Array, default_unit: Unit | None = None +) -> tuple[Array, Unit]: + return ( + (arg.value, arg.unit) if isinstance(arg, QuantityArray) else (arg, default_unit) + ) def value_in_unit(value, unit): diff --git a/src/quantity/api.py b/src/quantity/api.py new file mode 100644 index 0000000..0ef4d41 --- /dev/null +++ b/src/quantity/api.py @@ -0,0 +1,16 @@ +"""Quantity-2.0: the Quantity API. + +This module provides runtime-checkable Protocol objects that define the Quantity +API. In particular there are: + +- `Quantity`: the minimal definition of a Quantity, separate from the Array API. +- `QuantityArray`: a Quantity that adheres to the Array API. This is the most + complete definition of a Quantity, inheriting from the Quantity API and + adding the requirements for the Array API. + +""" +# Copyright (c) 2024 Astropy Developers. All rights reserved. + +from ._src.api import Quantity, QuantityArray + +__all__ = ["Quantity", "QuantityArray"] diff --git a/tests/test_api.py b/tests/test_api.py new file mode 100644 index 0000000..4e26fab --- /dev/null +++ b/tests/test_api.py @@ -0,0 +1,58 @@ +# Licensed under a 3-clause BSD style license - see LICENSE.rst +"""Test the Quantity class Array API compatibility.""" + +import astropy.units as u +import numpy as np +import pytest + +from quantity import Quantity, api + +from .conftest import ARRAY_NAMESPACES + + +def test_issubclass_api(): + """Test that Quantity is a subclass of api.Quantity and api.QuantityArray.""" + assert issubclass(Quantity, api.Quantity) + assert issubclass(Quantity, api.QuantityArray) + + +def test_ndarray(): + """Test that ndarray does not satisfy the Quantity API.""" + assert not issubclass(np.ndarray, api.Quantity) + assert not isinstance(np.array([1, 2, 3]), api.Quantity) + + +def test_astropy_quantity(): + """Test that astropy.units.Quantity works with the Quantity API.""" + assert issubclass(u.Quantity, api.Quantity) + assert isinstance(u.Quantity(np.array([1, 2, 3]), u.m), api.Quantity) + + +# ------------------------------ + + +class TestIsinstanceAPI: + """Check Quantities are properly recognized independent of the array type.""" + + @pytest.fixture(scope="class", params=ARRAY_NAMESPACES) + def array_and_quantity(self, request): + xp = request.param.xp + value = xp.asarray([1.0, 2.0, 3.0]) + q = Quantity(value, u.m) + return value, q + + # ==================== + + def test_issubclass_api(self, array_and_quantity): + v, q = array_and_quantity + assert not issubclass(type(v), api.Quantity) + assert not issubclass(type(v), api.QuantityArray) + assert issubclass(type(q), api.Quantity) + assert issubclass(type(q), api.QuantityArray) + + def test_isinstance_api(self, array_and_quantity): + v, q = array_and_quantity + assert not isinstance(v, api.Quantity) + assert not isinstance(v, api.QuantityArray) + assert isinstance(q, api.Quantity) + assert isinstance(q, api.QuantityArray)