diff --git a/src/quantity/_src/core.py b/src/quantity/_src/core.py index 109aaf9..bf78b8a 100644 --- a/src/quantity/_src/core.py +++ b/src/quantity/_src/core.py @@ -8,7 +8,7 @@ import array_api_compat import astropy.units as u import numpy as np -from astropy.units.quantity_helper import UFUNC_HELPERS +from astropy.units.quantity_helper import UFUNC_HELPERS, converters_and_unit from .api import QuantityArray from .utils import has_array_namespace @@ -23,6 +23,21 @@ DIMENSIONLESS = u.dimensionless_unscaled PYTHON_NUMBER = float | int | complex +TUPLE_OR_LIST = tuple | list + +NORMALIZED_FUNCTION_NAMES = { + "absolute": "abs", + "arccos": "acos", + "arccosh": "acosh", + "arcsin": "asin", + "arcsinh": "asinh", + "arctan": "atan", + "arctan2": "atan2", + "arctanh": "atanh", + "conjugate": "conj", + "true_divide": "divide", + "power": "pow", +} def get_value_and_unit( @@ -247,5 +262,115 @@ def __ipow__(self, exp, mod=None): def __setitem__(self, item, value): self.value[item] = value_in_unit(value, self.unit) - __array_ufunc__ = None + def __array_ufunc__(self, function, method, *inputs, **kwargs): + # Used for our namespace. + if (where := kwargs.get("where")) is not None and isinstance(where, Quantity): + return NotImplemented + + converters, unit = converters_and_unit(function, method, *inputs) + out = kwargs.get("out") + if out is not None: + # If pre-allocated output is used, check it is suitable. + # This also returns array view, to ensure we don't loop back. + kwargs["out"] = self._check_output( + out[0] if len(out) == 1 else out, unit, function=function + ) + + if method == "reduce" and "initial" in kwargs and unit is not None: + # Special-case for initial argument for reductions like + # np.add.reduce. This should be converted to the output unit as + # well, which is typically the same as the input unit (but can + # in principle be different: unitless for np.equal, radian + # for np.arctan2, though those are not necessarily useful!) + kwargs["initial"] = self._to_own_unit(kwargs["initial"], unit=unit) + + input_values = [get_value_and_unit(in_)[0] for in_ in inputs] + if not all( + isinstance(v, PYTHON_NUMBER) or has_array_namespace(v) + for v in (input_values[:-1] if method == "reduceat" else input_values) + ): + return NotImplemented + input_values = [ + v if conv is None else conv(v) + for (v, conv) in zip(input_values, converters, strict=True) + ] + try: + xp = self.value.__array_namespace__() + except AttributeError: + try: + xp = array_api_compat.array_namespace(self.value) + except TypeError: + xp = np + + if xp is np: + xp_func = function + else: + function_name = NORMALIZED_FUNCTION_NAMES.get(n := function.__name__, n) + xp_func = getattr(xp, function_name) + try: + result = getattr(xp_func, method)(*input_values, **kwargs) + except Exception: + # TODO: JAX supports "at" method if one passes in inplace=False. + return NotImplemented + return self._result_as_quantity(result, unit) + + @classmethod + def _check_output(cls, output, unit, function=None): + if isinstance(output, tuple): + return tuple( + cls._check_output(output_, unit_, function) + for output_, unit_ in zip(output, unit, strict=True) + ) + + if output is None: + return None + + if unit is None: + if isinstance(output, Quantity): + msg = "Cannot store non-quantity output{} in {} instance" + raise TypeError( + msg.format( + ( + f" from {function.__name__} function" + if function is not None + else "" + ), + type(output), + ) + ) + return output + + if not isinstance(output, Quantity): + if unit == DIMENSIONLESS: + return output + + msg = ( + "Cannot store output with unit '{}'{} in {} instance. " + "Use {} instance instead." + ) + raise u.UnitTypeError( + msg.format( + unit, + ( + f" from {function.__name__} function" + if function is not None + else "" + ), + type(output), + cls, + ) + ) + return output.value + + @classmethod + def _result_as_quantity(cls, result, unit): + if isinstance(result, TUPLE_OR_LIST): + # Some np.linalg functions return namedtuple, which is handy to access + # elements by name, but cannot be directly initialized with an iterator. + result_cls = getattr(result, "_make", result.__class__) + return result_cls(cls(r, u) for (r, u) in zip(result, unit, strict=True)) + + # If needed, weap the result array as a Quantity with the proper unit. + return result if unit is None else cls(result, unit) + __array_function__ = None diff --git a/tests/conftest.py b/tests/conftest.py index 94aff2f..411683d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,6 +4,7 @@ import astropy.units as u import numpy as np from astropy.utils.decorators import classproperty +from numpy.testing import assert_array_almost_equal_nulp, assert_array_equal ARRAY_NAMESPACES = [] @@ -11,6 +12,7 @@ class ANSTests: IMMUTABLE = False # default NO_SETITEM = False + NO_OUTPUTS = False def __init_subclass__(cls, **kwargs): # Add class to namespaces available for testing if the underlying @@ -47,6 +49,8 @@ def teardown_class(cls): class UsingArrayAPIStrict(MonkeyPatchUnitConversion, ANSTests): + NO_OUTPUTS = True + @classproperty(lazy=True) def xp(cls): return __import__("array_api_strict") @@ -65,7 +69,33 @@ def xp(cls): class UsingJAX(MonkeyPatchUnitConversion, ANSTests): IMMUTABLE = True NO_SETITEM = True + NO_OUTPUTS = True @classproperty(lazy=True) def xp(cls): return __import__("jax").numpy + + +def assert_quantity_equal(q1, q2, nulp=0): + assert q1.unit == q2.unit + assert q1.value.__class__ is q2.value.__class__ + if nulp: + assert_array_almost_equal_nulp(q1.value, q2.value, nulp=nulp) + else: + assert_array_equal(q1.value, q2.value) + + +class TrackingNameSpace: + """Intermediate namespace that tracks attributes that were used. + + Used to check whether we test complete sets of functions in the Array API. + """ + + def __init__(self, ns): + self.ns = ns + self.used_attrs = set() + + def __getattr__(self, attr): + if not attr.startswith("_"): + self.used_attrs.add(attr) + return getattr(self.ns, attr) diff --git a/tests/test_element_wise_functions.py b/tests/test_element_wise_functions.py new file mode 100644 index 0000000..1e39c63 --- /dev/null +++ b/tests/test_element_wise_functions.py @@ -0,0 +1,488 @@ +# Licensed under a 3-clause BSD style license - see LICENSE.rst +"""Test that element-wise functions on Quantity properly propagate units. + +This just tests the functions defined by the Array API: +https://data-apis.org/array-api/latest/API_specification/elementwise_functions.html + +Note: tests classes are combined with setups for different array types +at the very end. Hence, they do not have the usual Test prefix. +""" + +import operator + +import array_api_strict +import astropy.units as u +import numpy as np +import pytest +from numpy.testing import assert_array_equal + +from quantity import Quantity + +from .conftest import ( + ARRAY_NAMESPACES, + TrackingNameSpace, + UsingNDArray, + assert_quantity_equal, +) + +# All element-wise functions defined by the Array API. +# https://data-apis.org/array-api/latest/API_specification/elementwise_functions.html +ARRAY_API_ELEMENT_WISE_FUNCTIONS = { + "abs", "acos", "acosh", "add", "asin", "asinh", "atan", "atan2", "atanh", + "bitwise_and", "bitwise_left_shift", "bitwise_invert", "bitwise_or", + "bitwise_right_shift", "bitwise_xor", + "ceil", "clip", "conj", "copysign", "cos", "cosh", + "divide", "equal", "exp", "expm1", "floor", "floor_divide", + "greater", "greater_equal", "hypot", "imag", "isfinite", "isinf", "isnan", + "less", "less_equal", "log", "log1p", "log2", "log10", "logaddexp", + "logical_and", "logical_not", "logical_or", "logical_xor", + "maximum", "minimum", "multiply", "negative", "not_equal", "positive", + "pow", "real", "remainder", "round", "sign", "signbit", "sin", "sinh", + "square", "sqrt", "subtract", "tan", "tanh", "trunc" +} # fmt: skip + + +# Ensure we test functions from our own array namespace +# (currently, just np, but may change). +# Track which attributes where gotten so we can check our tests are complete. +qp = TrackingNameSpace(Quantity(np.array(1.0), u.one).__array_namespace__()) + + +class QuantitySetup: + @classmethod + def setup_class(cls): + super().setup_class() + cls.a1 = cls.xp.asarray(np.arange(1.0, 11.0).reshape(5, 2)) + cls.a2 = cls.xp.asarray([8.0, 10.0]) + cls.q1 = Quantity(cls.a1, u.meter) + cls.q2 = Quantity(cls.a2, u.centimeter) + + +class Arithmetic(QuantitySetup): + # Repeating QuantityOperationTests with corresponding functions. + def test_add(self): + # Take units from left object, q1 + got = qp.add(self.q1, self.q2) + exp = Quantity(self.q1.value + self.q2.value / 100.0, u.m) + assert_quantity_equal(got, exp, nulp=1) + # Take units from left object, q2 + got = qp.add(self.q2, self.q1) + exp = Quantity(self.q1.value * 100 + self.q2.value, u.cm) + assert_quantity_equal(got, exp, nulp=1) + + def test_subtract(self): + # Take units from left object, q1 + got = qp.subtract(self.q1, self.q2) + exp = Quantity(self.q1.value - self.q2.value / 100.0, u.m) + assert_quantity_equal(got, exp, nulp=1) + + # Take units from left object, q2 + got = qp.subtract(self.q2, self.q1) + exp = Quantity(self.q2.value - 100.0 * self.q1.value, u.cm) + assert_quantity_equal(got, exp, nulp=1) + + def test_multiply(self): + got = qp.multiply(self.q1, self.q2) + exp = Quantity(self.q1.value * self.q2.value, u.Unit("m cm")) + assert_quantity_equal(got, exp) + got = qp.multiply(self.q2, self.q1) + assert_quantity_equal(got, exp) + + def test_divide(self): + got = qp.divide(self.q1, self.q2) + exp = Quantity(self.q1.value / self.q2.value, u.Unit("m/cm")) + assert_quantity_equal(got, exp) + got = qp.divide(self.q2, self.q1) + exp = Quantity(self.q2.value / self.q1.value, u.Unit("cm/m")) + assert_quantity_equal(got, exp) + + def test_floor_divide(self): + got = qp.floor_divide(self.q1, self.q2) + exp = Quantity(self.q1.value // (0.01 * self.q2.value), u.one) + assert_quantity_equal(got, exp) + got = qp.floor_divide(self.q2, self.q1) + exp = Quantity(self.q2.value // (100.0 * self.q1.value), u.one) + assert_quantity_equal(got, exp) + + def test_remainder(self): + got = qp.remainder(self.q1, self.q2) + exp = Quantity(self.q1.value % (0.01 * self.q2.value), self.q1.unit) + assert_quantity_equal(got, exp) + got = qp.remainder(self.q2, self.q1) + exp = Quantity(self.q2.value % (100.0 * self.q1.value), self.q2.unit) + assert_quantity_equal(got, exp) + + def test_negative(self): + got = qp.negative(self.q1) + exp = Quantity(-self.q1.value, u.m) + assert_quantity_equal(got, exp) + + got = qp.negative(qp.negative(self.q1)) + assert_quantity_equal(got, self.q1) + + def test_positive(self): + got = qp.positive(self.q1) + assert_quantity_equal(got, self.q1) + + def test_abs(self): + got = qp.abs(self.q1) + exp = Quantity(abs(self.q1.value), u.m) + assert_quantity_equal(got, exp) + got = qp.abs(-self.q1) + exp = Quantity(abs(self.q1.value), u.m) + assert_quantity_equal(got, exp) + + def test_floor_divide_remainder_roundtrip(self): + got = qp.add( + qp.remainder(self.q1, self.q2), + qp.multiply(qp.floor_divide(self.q1, self.q2), self.q2), + ) + assert_quantity_equal(got, self.q1, nulp=1) + got = qp.add( + qp.remainder(self.q2, self.q1), + qp.multiply(qp.floor_divide(self.q2, self.q1), self.q1), + ) + assert_quantity_equal(got, self.q2, nulp=1) + + def test_dimensionless_operations(self): + q1 = Quantity(self.a1, u.m / u.km) + q2 = Quantity(self.a2, u.mm / u.km) + got = qp.add(q1, q2) + exp = Quantity(q1.value + q2.value / 1000.0, q1.unit) + assert_quantity_equal(got, exp, nulp=1) + # Test plain array. + a = self.xp.asarray(1.0) + got = qp.add(q1, a) + exp = Quantity(q1.value / 1000.0 + 1.0, u.one) + assert_quantity_equal(got, exp, nulp=1) + + def test_integer_promotion(self): + a1 = self.xp.asarray([1, 2, 3]) + try: + a1 * 0.001 + except Exception: + pytest.xfail(reason="{self.xp!r} does not support int to float promotion.") + q1 = Quantity(a1, u.m / u.km) + a2 = self.xp.asarray([4, 5, 6]) + got = qp.add(q1, a2) + exp = Quantity(q1.value / 1000.0 + a2, u.one) + assert_quantity_equal(got, exp, nulp=1) + + def test_incompatible_units(self): + """Raise when trying to add or subtract incompatible units""" + q = Quantity(21.52, unit=u.second) + with pytest.raises(u.UnitsError, match="[Cc]an only apply 'add' function"): + qp.add(self.q1, q) + + def test_non_number_type(self): + with pytest.raises(TypeError, match=r"[Uu]nsupported operand type\(s\).*"): + qp.add(self.q1, {"a": 1}) + + with pytest.raises(TypeError): + qp.add(self.q1, u.meter) + + def test_multiplication_with_unit(self): + with pytest.raises(TypeError): + qp.multiply(self.q1, u.s) + + with pytest.raises(TypeError): + qp.multiply(u.s, self.q1) + + with pytest.raises(TypeError): + qp.multiply(self.q1, u.mag(u.Jy)) + + def test_division_with_unit(self): + with pytest.raises(TypeError): + qp.divide(self.q1, u.s) + + with pytest.raises(TypeError): + qp.divide(u.s, self.q1) + + def test_floor_division_errors(self): + q2 = Quantity(self.a1, u.s) + with pytest.raises(u.UnitsError, match="[Cc]an only apply 'floor_divide'"): + qp.floor_divide(self.q1, q2) + + with pytest.raises(TypeError): + qp.floor_divide(self.q1, u.s) + + def test_dimensionless_error(self): + with pytest.raises(u.UnitsError): + qp.add(self.q1, Quantity(self.a1, unit=u.one)) + + with pytest.raises(u.UnitsError): + qp.add(self.q1, Quantity(self.a1, unit=u.one)) + + +class Powers(QuantitySetup): + def test_pow(self): + # raise quantity to a power + p2 = self.xp.asarray(2.0) + got = qp.pow(self.q1, p2) + exp = Quantity(self.a1**2, u.Unit("m^2")) + assert_quantity_equal(got, exp) + p3 = self.xp.asarray(3.0) + got = qp.pow(self.q1, p3) + exp = Quantity(self.a1**3, u.Unit("m^3")) + assert_quantity_equal(got, exp) + + def test_square(self): + got = qp.square(self.q1) + exp = Quantity(self.xp.square(self.a1), u.Unit("m^2")) + assert_quantity_equal(got, exp) + + def test_sqrt(self): + got = qp.sqrt(self.q1) + exp = Quantity(self.xp.sqrt(self.a1), u.Unit("m^(1/2)")) + assert_quantity_equal(got, exp) + + def test_hypot(self): + got = qp.hypot(self.q1, self.q2) + exp = Quantity(self.xp.hypot(self.a1, 0.01 * self.a2), self.q1.unit) + assert_quantity_equal(got, exp) + + +class ArithmeticWithNumbers(QuantitySetup): + # Separate tests, since not strictly required by the Array API, + # and hence array_api_strict doesn't pass with them. + def test_multiplication_with_number(self): + got = qp.multiply(15.0, self.q1) + exp = Quantity(15.0 * self.q1.value, u.m) + assert_quantity_equal(got, exp) + got = qp.multiply(self.q1, 15.0) + assert_quantity_equal(got, exp) + + def test_division_with_number(self): + got = qp.divide(self.q1, 10.0) + exp = Quantity(self.q1.value / 10.0, u.m) + assert_quantity_equal(got, exp) + got = qp.divide(11.0, self.q1) + exp = Quantity(11.0 / self.q1.value, u.m**-1) + assert_quantity_equal(got, exp) + + @pytest.mark.parametrize( + "exponent", + [2, 2.0, np.uint64(2), np.int32(2), np.float32(2), Quantity(2.0, u.one)], + ) + def test_quantity_as_power(self, exponent): + # raise unit to a dimensionless Quantity power + if isinstance(exponent, Quantity): + pytest.xfail(reason="cannot handle quantity exponent yet") + got = qp.pow(self.q1, exponent) + exp = Quantity(self.q1.value**2, u.m**2) + assert_quantity_equal(got, exp) + + +class Comparisons(QuantitySetup): + @classmethod + def setup_class(cls): + super().setup_class() + cls.q1_in_cm = Quantity(cls.q1.value * 100.0, u.cm) + cls.a2_in_m = cls.q2.unit.to(cls.q1.unit, cls.a2) + + @pytest.mark.parametrize( + ("func", "op", "includes_equal"), + [ + (qp.equal, operator.eq, True), + (qp.not_equal, operator.ne, False), + (qp.greater, operator.gt, False), + (qp.greater_equal, operator.ge, True), + (qp.less, operator.lt, False), + (qp.less_equal, operator.le, True), + ], + ) + def test_comparison(self, func, op, includes_equal): + got = func(self.q1, self.q1_in_cm) + assert got.shape == self.q1.shape + assert_array_equal(got, includes_equal) + got = func(self.q1, self.q2) + exp = op(self.q1.value, self.a2_in_m) + assert_array_equal(got, exp) + + def test_not_equal_to_unit(self): + unit = u.cm**3 + q = Quantity(self.xp.asarray([1.0]), unit) + with pytest.raises(TypeError): + qp.not_equal(q, unit) + + +class NumericTests: + @classmethod + def setup_class(cls): + super().setup_class() + cls.a = cls.xp.asarray([1.1, 1.9, -2.1, np.inf, -np.inf, np.nan]) + cls.q = Quantity(cls.a, u.m) + + @pytest.mark.parametrize("func", ["isfinite", "isinf", "isnan", "sign", "signbit"]) + def test_numeric_test(self, func): + qp_func = getattr(qp, func) + xp_func = getattr(self.xp, func) + got = qp_func(self.q) + exp = xp_func(self.a) + assert not isinstance(got, Quantity) + assert_array_equal(got, exp) + + +class ClipAndTransfer: + @classmethod + def setup_class(cls): + super().setup_class() + cls.a1 = cls.xp.asarray([1.1, 1.9, -2.1, np.inf, -np.inf, np.nan]) + cls.q1 = Quantity(cls.a1, u.m) + cls.a2 = cls.xp.asarray([1.2, 180.0, -200.0, -np.inf, np.nan, np.nan]) + cls.q2 = Quantity(cls.a2, u.cm) + + @pytest.mark.parametrize("func", ["ceil", "floor", "round", "trunc"]) + def test_one_arg(self, func): + qp_func = getattr(qp, func) + xp_func = getattr(self.xp, func) + if not isinstance(qp_func, np.ufunc): + pytest.xfail(reason="only numpy ufuncs are supported") + got = qp_func(self.q1) + exp = Quantity(xp_func(self.a1), self.q1.unit) + assert_quantity_equal(got, exp) + + @pytest.mark.parametrize("func", ["minimum", "maximum"]) + def test_min_max(self, func): + qp_func = getattr(qp, func) + xp_func = getattr(self.xp, func) + got = qp_func(self.q1, self.q2) + exp = Quantity(xp_func(self.a1, 0.01 * self.a2), self.q1.unit) + assert_quantity_equal(got, exp) + + def test_copysign(self): + got = qp.copysign(self.q1, self.q2) + exp = Quantity(self.xp.copysign(self.a1, self.a2), self.q1.unit) + assert_quantity_equal(got, exp) + + @pytest.mark.xfail(reason="only numpy ufuncs are supported") + def test_clip(self): + q3 = Quantity(self.xp.asarray(1.0), u.km) + got = qp.clip(self.q1, min=self.q2, max=q3) + exp = Quantity( + self.xp.clip(self.a1, min=0.01 * self.a2, max=1000.0), self.q1.unit + ) + assert_quantity_equal(got, exp) + + +class Trig: + @classmethod + def setup_class(cls): + super().setup_class() + angles = [-45.0, 0.0, 30.0, 120.0] + cls.a_deg = cls.xp.asarray(angles) + cls.a_rad = cls.xp.asarray(np.deg2rad(angles)) + cls.q_deg = Quantity(cls.a_deg, u.deg) + cls.q_rad = Quantity(cls.a_rad, u.rad) + + @pytest.mark.parametrize("func", ["sin", "cos", "tan", "sinh", "cosh", "tanh"]) + def test_trig(self, func): + qp_func = getattr(qp, func) + xp_func = getattr(self.xp, func) + got = qp_func(self.q_deg) + exp = Quantity(xp_func(self.a_rad), u.one) + assert_quantity_equal(got, exp, nulp=1) + + @pytest.mark.parametrize( + "func", ["asin", "acos", "atan", "asinh", "acosh", "atanh"] + ) + def test_inverse(self, func): + qp_func = getattr(qp, func) + xp_func = getattr(self.xp, func) + + xp_forw = getattr(self.xp, func[1:]) + a_in = xp_forw(self.a_rad) + q_in = Quantity(a_in * 100.0, u.percent) + + got = qp_func(q_in) + exp = Quantity(xp_func(a_in), u.rad) + assert_quantity_equal(got, exp, nulp=5) + + def test_atan2(self): + sina = qp.sin(self.q_deg) + cosa = qp.cos(self.q_deg) + got = qp.atan2(sina, cosa) + exp = Quantity( + self.xp.atan2(self.xp.sin(self.a_rad), self.xp.cos(self.a_rad)), u.rad + ) + assert_quantity_equal(got, exp, nulp=1) + + +class ExpAndLog: + @classmethod + def setup_class(cls): + super().setup_class() + cls.a = cls.xp.asarray([0.5, 1.0, 2.0]) + cls.q = Quantity(cls.a * 100.0, u.percent) + + @pytest.mark.parametrize("func", ["exp", "expm1", "log", "log1p", "log2", "log10"]) + def test_exp_or_log(self, func): + qp_func = getattr(qp, func) + xp_func = getattr(self.xp, func) + got = qp_func(self.q) + exp = Quantity(xp_func(self.a), u.one) + assert_quantity_equal(got, exp, nulp=1) + + def test_logaddexp(self): + q2 = Quantity(self.a, u.one) + got = qp.logaddexp(self.q, q2) + exp = Quantity(self.xp.logaddexp(self.a, self.a), u.one) + assert_quantity_equal(got, exp, nulp=1) + + +class Complex: + @classmethod + def setup_class(cls): + super().setup_class() + cls.a = cls.xp.asarray([1 + 1j, 1.0, -2.0 + 1.0j, np.inf, -np.inf * 1j]) + cls.q = Quantity(cls.a, u.m) + + @pytest.mark.parametrize("func", ["conj", "imag", "real"]) + def test_func(self, func): + qp_func = getattr(qp, func) + xp_func = getattr(self.xp, func) + if not isinstance(qp_func, np.ufunc): + pytest.xfail(reason="only numpy ufuncs are supported") + got = qp_func(self.q) + exp = Quantity(xp_func(self.a), self.q.unit) + assert_quantity_equal(got, exp) + + +# Create the actual test classes. +for base_setup in ARRAY_NAMESPACES: + for tests in ( + Arithmetic, + Powers, + ArithmeticWithNumbers, + Comparisons, + NumericTests, + ClipAndTransfer, + Trig, + ExpAndLog, + Complex, + ): + if tests is ArithmeticWithNumbers and base_setup.xp is array_api_strict: + continue + name = f"Test{tests.__name__}{base_setup.__name__}" + globals()[name] = type(name, (tests, base_setup), {}) + + +class TestUnsupported(QuantitySetup, UsingNDArray): + """Unsupported functions. No need to test with anything but numpy.""" + + @pytest.mark.parametrize( + "func", + [ + "bitwise_and", "bitwise_invert", "bitwise_or", "bitwise_xor", + "bitwise_left_shift", "bitwise_right_shift", + "logical_and", "logical_not", "logical_or", "logical_xor", + ] + ) # fmt: skip + def test_unsupported(self, func): + qp_func = getattr(qp, func) + with pytest.raises(TypeError): + qp_func(self.q1, self.q2) + + +def test_completeness(): + assert qp.used_attrs == ARRAY_API_ELEMENT_WISE_FUNCTIONS diff --git a/tests/test_operations.py b/tests/test_operations.py index 43b6158..ebfa688 100644 --- a/tests/test_operations.py +++ b/tests/test_operations.py @@ -11,17 +11,16 @@ import astropy.units as u import numpy as np import pytest -from numpy.testing import assert_array_almost_equal_nulp, assert_array_equal +from numpy.testing import assert_array_equal from quantity import Quantity -from .conftest import ARRAY_NAMESPACES, UsingArrayAPIStrict, UsingJAX - - -def assert_quantity_equal(q1, q2, nulp=0): - assert q1.unit == q2.unit - assert q1.value.__class__ is q2.value.__class__ - assert_array_almost_equal_nulp(q1.value, q2.value, nulp=nulp) +from .conftest import ( + ARRAY_NAMESPACES, + UsingArrayAPIStrict, + UsingJAX, + assert_quantity_equal, +) class QuantitySetup: @@ -74,7 +73,11 @@ def test_multiplication_with_unit(self): got = self.q1 * u.s exp = Quantity(self.q1.value, u.Unit("m s")) assert_quantity_equal(got, exp) + + def test_reverse_multiplication_with_unit(self): got = u.s * self.q1 + if isinstance(got, u.Quantity): + pytest.xfail(reason="Astropy unit took over, causing u.Quantity return.") # TODO: for array-api-strict, this is not great, since it changes it # to a regular array. But the problem really is with astropy unit. exp = Quantity(np.float64(1.0) * self.q1.value, u.Unit("m s")) @@ -100,8 +103,12 @@ def test_division_with_unit(self): got = self.q1 / u.s exp = Quantity(self.q1.value, u.Unit("m/s")) assert_quantity_equal(got, exp) + + def test_reverse_division_with_unit(self): # Divide into a unit. got = u.s / self.q1 + if isinstance(got, u.Quantity): + pytest.xfail(reason="Astropy unit took over, causing u.Quantity return.") # TODO: for array-api-strict, this is not great, since it changes it # to a regular array. But the problem really is with astropy unit. exp = Quantity(np.float64(1.0) / self.q1.value, u.Unit("s/m")) @@ -387,7 +394,14 @@ def test_array_mix2(self): got = q1 - q2 exp = Quantity(np.asarray(self.a1 - 0.01 * self.a2), self.q1.unit) assert_quantity_equal(got, exp) - got = a1 - q2 + + @pytest.mark.xfail( + reason="strict_array_api.subtract behaves differently from __sub__" + ) + def test_array_mix3(self): + a1 = np.asarray(self.a1) + got = a1 - self.q2 + exp = self.MyQuantity(np.asarray(self.a1 - 0.01 * self.a2), self.q1.unit) assert_quantity_equal(got, exp) def test_subtract(self): diff --git a/tests/test_ufunc.py b/tests/test_ufunc.py new file mode 100644 index 0000000..28abedb --- /dev/null +++ b/tests/test_ufunc.py @@ -0,0 +1,108 @@ +# Licensed under a 3-clause BSD style license - see LICENSE.rst +"""Test ufunc behaviour beyond what is required by the Array API.""" + +import astropy.units as u +import numpy as np +import pytest + +from quantity import Quantity + +from .conftest import ARRAY_NAMESPACES, assert_quantity_equal +from .test_element_wise_functions import QuantitySetup + +# Ensure we test functions from our own array namespace +# (currently, just np, but may change). +qp = Quantity(np.array(1.0), u.one).__array_namespace__() + + +class Inplace(QuantitySetup): + def try_func(self, func, *args, **kwargs): + try: + return func(*args, **kwargs) + except Exception: + if self.NO_OUTPUTS: + pytest.xfail(reason="array type does not support out argument") + + @pytest.mark.parametrize("func", ["negative", "square"]) + def test_inplace_one_arg(self, func): + qp_func = getattr(qp, func) + exp = qp_func(self.q1) + q_out = Quantity(self.xp.zeros_like(exp.value), exp.unit) + got = self.try_func(qp_func, self.q1, out=q_out) + assert got is not q_out # Quantity is immutable. + assert got.value is q_out.value + assert_quantity_equal(got, exp) + + @pytest.mark.parametrize("func", ["add", "subtract", "divide"]) + def test_inplace_two_arg(self, func): + qp_func = getattr(qp, func) + exp = qp_func(self.q1, self.q2) + q_out = Quantity(self.xp.zeros_like(exp.value), exp.unit) + got = self.try_func(qp_func, self.q1, self.q2, out=q_out) + assert got is not q_out # Quantity is immutable. + assert got.value is q_out.value + assert_quantity_equal(got, exp) + + def test_inplace_two_outputs(self): + if any(t in self.__class__.__name__ for t in ("APIStrict", "Dask")): + pytest.xfail(reason=f"{self.q1.__class__} does not have divmod") + exps = qp.divmod(self.q1, self.q2) + q_outs = tuple( + Quantity(self.xp.zeros_like(exp.value), exp.unit) for exp in exps + ) + gots = self.try_func(qp.divmod, self.q1, self.q2, out=q_outs) + for got, q_out, exp in zip(gots, q_outs, exps, strict=False): + assert got is not q_out # Quantity is immutable. + assert got.value is q_out.value + assert_quantity_equal(got, exp) + + +class Methods(QuantitySetup): + def test_reduce(self): + if not hasattr(self.xp.add, "reduce"): + pytest.xfail("array type does not support ufunc.reduce") + exp = Quantity(self.xp.add.reduce(self.a1, axis=0), self.q1.unit) + got = qp.add.reduce(self.q1, axis=0) + assert_quantity_equal(got, exp) + + def test_reduceat(self): + if not hasattr(self.xp.add, "reduceat"): + pytest.xfail("array type does not support ufunc.reduce_at") + indices = np.array((2, 3)) # JAX only takes scalar or ndarray. + exp = Quantity(self.xp.add.reduceat(self.a1, indices, axis=0), self.q1.unit) + got = qp.add.reduceat(self.q1, indices, axis=0) + assert_quantity_equal(got, exp) + + def test_accumulate(self): + if not hasattr(self.xp.add, "accumulate"): + pytest.xfail("array type does not support ufunc.accumulate") + exp = Quantity(self.xp.add.accumulate(self.a1, axis=0), self.q1.unit) + got = qp.add.accumulate(self.q1, axis=0) + assert_quantity_equal(got, exp) + + def test_at(self): + if not hasattr(self.xp.add, "at") or self.NO_OUTPUTS: + # TODO: NO_OUTPUTS is not strictly applicable; e.g., JAX supports + # np.add.at but one has to pass in inplace=False. + pytest.xfail("array type does not support ufunc.at") + values = [1.0, 2.0] + a = self.xp.asarray(values) + self.xp.add.at(a, 1, 100.0) + exp = Quantity(a, u.cm) + + got = Quantity(self.xp.asarray(values), u.cm) + qp.add.at(got, 1, Quantity(1.0, u.m)) + assert_quantity_equal(got, exp) + + +# Create the actual test classes. +for base_setup in ARRAY_NAMESPACES: + for tests in (Inplace, Methods): + name = f"Test{tests.__name__}{base_setup.__name__}" + globals()[name] = type(name, (tests, base_setup), {}) + + +def test_where_not_supported(): + q = Quantity(np.asarray([1.0, 2.0]), u.m) + with pytest.raises(TypeError): + np.add(q, q, where=q)