diff --git a/array_api_strict/__init__.py b/array_api_strict/__init__.py index e6a1763..116df25 100644 --- a/array_api_strict/__init__.py +++ b/array_api_strict/__init__.py @@ -16,6 +16,8 @@ """ +from types import ModuleType + __all__ = [] # Warning: __array_api_version__ could change globally with @@ -325,12 +327,16 @@ ArrayAPIStrictFlags, ) -__all__ += ['set_array_api_strict_flags', 'get_array_api_strict_flags', 'reset_array_api_strict_flags', 'ArrayAPIStrictFlags'] +__all__ += [ + 'set_array_api_strict_flags', + 'get_array_api_strict_flags', + 'reset_array_api_strict_flags', + 'ArrayAPIStrictFlags', + '__version__', +] try: - from . import _version - __version__ = _version.__version__ - del _version + from ._version import __version__ # type: ignore[import-not-found,unused-ignore] except ImportError: __version__ = "unknown" @@ -340,7 +346,7 @@ # use __getattr__. Note that linalg and fft are dynamically added and removed # from __all__ in set_array_api_strict_flags. -def __getattr__(name): +def __getattr__(name: str) -> ModuleType: if name in ['linalg', 'fft']: if name in get_array_api_strict_flags()['enabled_extensions']: if name == 'linalg': diff --git a/array_api_strict/_array_object.py b/array_api_strict/_array_object.py index 0595594..1304d5a 100644 --- a/array_api_strict/_array_object.py +++ b/array_api_strict/_array_object.py @@ -16,62 +16,70 @@ from __future__ import annotations import operator +import sys +from collections.abc import Iterator from enum import IntEnum +from types import ModuleType +from typing import TYPE_CHECKING, Any, Final, Literal, SupportsIndex -from ._creation_functions import asarray +import numpy as np +import numpy.typing as npt + +from ._creation_functions import Undef, _undef, asarray from ._dtypes import ( - _DType, + DType, _all_dtypes, _boolean_dtypes, + _complex_floating_dtypes, + _dtype_categories, + _floating_dtypes, _integer_dtypes, _integer_or_boolean_dtypes, - _floating_dtypes, - _real_floating_dtypes, - _complex_floating_dtypes, _numeric_dtypes, - _result_type, - _dtype_categories, + _real_floating_dtypes, _real_to_complex_map, + _result_type, ) from ._flags import get_array_api_strict_flags, set_array_api_strict_flags +from ._typing import PyCapsule -from typing import TYPE_CHECKING, SupportsIndex -import types +if sys.version_info >= (3, 10): + from types import EllipsisType +elif TYPE_CHECKING: + from typing_extensions import EllipsisType +else: + EllipsisType = type(Ellipsis) -if TYPE_CHECKING: - from typing import Optional, Tuple, Union, Any - from ._typing import PyCapsule, Dtype - import numpy.typing as npt - -import numpy as np class Device: - def __init__(self, device="CPU_DEVICE"): + _device: Final[str] + __slots__ = ("_device", "__weakref__") + + def __init__(self, device: str = "CPU_DEVICE"): if device not in ("CPU_DEVICE", "device1", "device2"): raise ValueError(f"The device '{device}' is not a valid choice.") self._device = device - def __repr__(self): + def __repr__(self) -> str: return f"array_api_strict.Device('{self._device}')" - def __eq__(self, other): + def __eq__(self, other: object) -> bool: if not isinstance(other, Device): return False return self._device == other._device - def __hash__(self): + def __hash__(self) -> int: return hash(("Device", self._device)) CPU_DEVICE = Device() ALL_DEVICES = (CPU_DEVICE, Device("device1"), Device("device2")) -_default = object() - # See https://github.com/data-apis/array-api-strict/issues/67 and the comment # on __array__ below. _allow_array = True + class Array: """ n-d array object for the array API namespace. @@ -87,12 +95,16 @@ class Array: functions, such as asarray(). """ + _array: npt.NDArray[Any] + _dtype: DType + _device: Device + __slots__ = ("_array", "_dtype", "_device", "__weakref__") # Use a custom constructor instead of __init__, as manually initializing # this class is not supported API. @classmethod - def _new(cls, x, /, device): + def _new(cls, x: npt.NDArray[Any] | np.generic, /, device: Device | None) -> Array: """ This is a private method for initializing the array API Array object. @@ -107,7 +119,7 @@ def _new(cls, x, /, device): if isinstance(x, np.generic): # Convert the array scalar to a 0-D array x = np.asarray(x) - _dtype = _DType(x.dtype) + _dtype = DType(x.dtype) if _dtype not in _all_dtypes: raise TypeError( f"The array_api_strict namespace does not support the dtype '{x.dtype}'" @@ -120,7 +132,7 @@ def _new(cls, x, /, device): return obj # Prevent Array() from working - def __new__(cls, *args, **kwargs): + def __new__(cls, *args: object, **kwargs: object) -> Array: raise TypeError( "The array_api_strict Array object should not be instantiated directly. Use an array creation function, such as asarray(), instead." ) @@ -128,7 +140,7 @@ def __new__(cls, *args, **kwargs): # These functions are not required by the spec, but are implemented for # the sake of usability. - def __repr__(self: Array, /) -> str: + def __repr__(self) -> str: """ Performs the operation __repr__. """ @@ -159,7 +171,9 @@ def __repr__(self: Array, /) -> str: # This was implemented historically for compatibility, and removing it has # caused issues for some libraries (see # https://github.com/data-apis/array-api-strict/issues/67). - def __array__(self, dtype: None | np.dtype[Any] = None, copy: None | bool = None) -> npt.NDArray[Any]: + def __array__( + self, dtype: None | np.dtype[Any] = None, copy: None | bool = None + ) -> npt.NDArray[Any]: # We have to allow this to be internally enabled as there's no other # easy way to parse a list of Array objects in asarray(). if _allow_array: @@ -184,7 +198,9 @@ def __array__(self, dtype: None | np.dtype[Any] = None, copy: None | bool = None # spec in places where it either deviates from or is more strict than # NumPy behavior - def _check_allowed_dtypes(self, other: bool | int | float | Array, dtype_category: str, op: str) -> Array: + def _check_allowed_dtypes( + self, other: Array | bool | int | float | complex, dtype_category: str, op: str + ) -> Array: """ Helper function for operators to only allow specific input dtypes @@ -197,7 +213,7 @@ def _check_allowed_dtypes(self, other: bool | int | float | Array, dtype_categor if self.dtype not in _dtype_categories[dtype_category]: raise TypeError(f"Only {dtype_category} dtypes are allowed in {op}") - if isinstance(other, (int, complex, float, bool)): + if isinstance(other, (bool, int, float, complex)): other = self._promote_scalar(other) elif isinstance(other, Array): if other.dtype not in _dtype_categories[dtype_category]: @@ -225,16 +241,18 @@ def _check_allowed_dtypes(self, other: bool | int | float | Array, dtype_categor return other - def _check_device(self, other): + def _check_device(self, other: Array | bool | int | float | complex) -> None: """Check that other is on a device compatible with the current array""" - if isinstance(other, (int, complex, float, bool)): + if isinstance(other, (bool, int, float, complex)): return elif isinstance(other, Array): if self.device != other.device: raise ValueError(f"Arrays from two different devices ({self.device} and {other.device}) can not be combined.") + else: + raise TypeError(f"Expected Array | python scalar; got {type(other)}") # Helper function to match the type promotion rules in the spec - def _promote_scalar(self, scalar): + def _promote_scalar(self, scalar: bool | int | float | complex) -> Array: """ Returns a promoted version of a Python scalar appropriate for use with operations on self. @@ -291,7 +309,7 @@ def _promote_scalar(self, scalar): return Array._new(np.array(scalar, dtype=target_dtype._np_dtype), device=self.device) @staticmethod - def _normalize_two_args(x1, x2) -> Tuple[Array, Array]: + def _normalize_two_args(x1: Array, x2: Array) -> tuple[Array, Array]: """ Normalize inputs to two arg functions to fix type promotion rules @@ -327,7 +345,17 @@ def _normalize_two_args(x1, x2) -> Tuple[Array, Array]: # Note: A large fraction of allowed indices are disallowed here (see the # docstring below) - def _validate_index(self, key, op="getitem"): + def _validate_index( + self, + key: ( + int + | slice + | EllipsisType + | Array + | tuple[int | slice | EllipsisType | Array | None, ...] + ), + op: Literal["getitem", "setitem"] = "getitem", + ) -> None: """ Validate an index according to the array API. @@ -509,7 +537,7 @@ def _validate_index(self, key, op="getitem"): # Everything below this line is required by the spec. - def __abs__(self: Array, /) -> Array: + def __abs__(self) -> Array: """ Performs the operation __abs__. """ @@ -518,7 +546,7 @@ def __abs__(self: Array, /) -> Array: res = self._array.__abs__() return self.__class__._new(res, device=self.device) - def __add__(self: Array, other: Union[int, float, Array], /) -> Array: + def __add__(self, other: Array | int | float | complex, /) -> Array: """ Performs the operation __add__. """ @@ -530,7 +558,7 @@ def __add__(self: Array, other: Union[int, float, Array], /) -> Array: res = self._array.__add__(other._array) return self.__class__._new(res, device=self.device) - def __and__(self: Array, other: Union[int, bool, Array], /) -> Array: + def __and__(self, other: Array | bool | int, /) -> Array: """ Performs the operation __and__. """ @@ -542,9 +570,7 @@ def __and__(self: Array, other: Union[int, bool, Array], /) -> Array: res = self._array.__and__(other._array) return self.__class__._new(res, device=self.device) - def __array_namespace__( - self: Array, /, *, api_version: Optional[str] = None - ) -> types.ModuleType: + def __array_namespace__(self, /, *, api_version: str | None = None) -> ModuleType: """ Return the array_api_strict namespace corresponding to api_version. @@ -563,7 +589,7 @@ def __array_namespace__( import array_api_strict return array_api_strict - def __bool__(self: Array, /) -> bool: + def __bool__(self) -> bool: """ Performs the operation __bool__. """ @@ -573,7 +599,7 @@ def __bool__(self: Array, /) -> bool: res = self._array.__bool__() return res - def __complex__(self: Array, /) -> complex: + def __complex__(self) -> complex: """ Performs the operation __complex__. """ @@ -584,52 +610,52 @@ def __complex__(self: Array, /) -> complex: return res def __dlpack__( - self: Array, + self, /, *, - stream: Optional[Union[int, Any]] = None, - max_version: Optional[tuple[int, int]] = _default, - dl_device: Optional[tuple[IntEnum, int]] = _default, - copy: Optional[bool] = _default, + stream: Any = None, + max_version: tuple[int, int] | None | Undef = _undef, + dl_device: tuple[IntEnum, int] | None | Undef = _undef, + copy: bool | None | Undef = _undef, ) -> PyCapsule: """ Performs the operation __dlpack__. """ if get_array_api_strict_flags()['api_version'] < '2023.12': - if max_version is not _default: + if max_version is not _undef: raise ValueError("The max_version argument to __dlpack__ requires at least version 2023.12 of the array API") - if dl_device is not _default: + if dl_device is not _undef: raise ValueError("The device argument to __dlpack__ requires at least version 2023.12 of the array API") - if copy is not _default: + if copy is not _undef: raise ValueError("The copy argument to __dlpack__ requires at least version 2023.12 of the array API") if np.lib.NumpyVersion(np.__version__) < '2.1.0': - if max_version not in [_default, None]: + if max_version not in [_undef, None]: raise NotImplementedError("The max_version argument to __dlpack__ is not yet implemented") - if dl_device not in [_default, None]: + if dl_device not in [_undef, None]: raise NotImplementedError("The device argument to __dlpack__ is not yet implemented") - if copy not in [_default, None]: + if copy not in [_undef, None]: raise NotImplementedError("The copy argument to __dlpack__ is not yet implemented") return self._array.__dlpack__(stream=stream) else: kwargs = {'stream': stream} - if max_version is not _default: + if max_version is not _undef: kwargs['max_version'] = max_version - if dl_device is not _default: + if dl_device is not _undef: kwargs['dl_device'] = dl_device - if copy is not _default: + if copy is not _undef: kwargs['copy'] = copy return self._array.__dlpack__(**kwargs) - def __dlpack_device__(self: Array, /) -> Tuple[IntEnum, int]: + def __dlpack_device__(self) -> tuple[IntEnum, int]: """ Performs the operation __dlpack_device__. """ # Note: device support is required for this return self._array.__dlpack_device__() - def __eq__(self: Array, other: Union[int, float, bool, Array], /) -> Array: + def __eq__(self, other: Array | bool | int | float | complex, /) -> Array: # type: ignore[override] """ Performs the operation __eq__. """ @@ -643,7 +669,7 @@ def __eq__(self: Array, other: Union[int, float, bool, Array], /) -> Array: res = self._array.__eq__(other._array) return self.__class__._new(res, device=self.device) - def __float__(self: Array, /) -> float: + def __float__(self) -> float: """ Performs the operation __float__. """ @@ -655,7 +681,7 @@ def __float__(self: Array, /) -> float: res = self._array.__float__() return res - def __floordiv__(self: Array, other: Union[int, float, Array], /) -> Array: + def __floordiv__(self, other: Array | int | float, /) -> Array: """ Performs the operation __floordiv__. """ @@ -667,7 +693,7 @@ def __floordiv__(self: Array, other: Union[int, float, Array], /) -> Array: res = self._array.__floordiv__(other._array) return self.__class__._new(res, device=self.device) - def __ge__(self: Array, other: Union[int, float, Array], /) -> Array: + def __ge__(self, other: Array | int | float, /) -> Array: """ Performs the operation __ge__. """ @@ -680,14 +706,15 @@ def __ge__(self: Array, other: Union[int, float, Array], /) -> Array: return self.__class__._new(res, device=self.device) def __getitem__( - self: Array, - key: Union[ - int, - slice, - ellipsis, # noqa: F821 - Tuple[Union[int, slice, ellipsis, None], ...], # noqa: F821 - Array, - ], + self, + key: ( + int + | slice + | EllipsisType + | Array + | None + | tuple[int | slice | EllipsisType | Array | None, ...] + ), /, ) -> Array: """ @@ -696,14 +723,13 @@ def __getitem__( # XXX Does key have to be on the same device? Is there an exception for CPU_DEVICE? # Note: Only indices required by the spec are allowed. See the # docstring of _validate_index - self._validate_index(key) - if isinstance(key, Array): - # Indexing self._array with array_api_strict arrays can be erroneous - key = key._array - res = self._array.__getitem__(key) + self._validate_index(key, op="getitem") + # Indexing self._array with array_api_strict arrays can be erroneous + np_key = key._array if isinstance(key, Array) else key + res = self._array.__getitem__(np_key) return self._new(res, device=self.device) - def __gt__(self: Array, other: Union[int, float, Array], /) -> Array: + def __gt__(self, other: Array | int | float, /) -> Array: """ Performs the operation __gt__. """ @@ -715,7 +741,7 @@ def __gt__(self: Array, other: Union[int, float, Array], /) -> Array: res = self._array.__gt__(other._array) return self.__class__._new(res, device=other.device) - def __int__(self: Array, /) -> int: + def __int__(self) -> int: """ Performs the operation __int__. """ @@ -727,14 +753,14 @@ def __int__(self: Array, /) -> int: res = self._array.__int__() return res - def __index__(self: Array, /) -> int: + def __index__(self) -> int: """ Performs the operation __index__. """ res = self._array.__index__() return res - def __invert__(self: Array, /) -> Array: + def __invert__(self) -> Array: """ Performs the operation __invert__. """ @@ -743,7 +769,7 @@ def __invert__(self: Array, /) -> Array: res = self._array.__invert__() return self.__class__._new(res, device=self.device) - def __iter__(self: Array, /): + def __iter__(self) -> Iterator[Array]: """ Performs the operation __iter__. """ @@ -758,7 +784,7 @@ def __iter__(self: Array, /): # implemented, which implies iteration on 1-D arrays. return (Array._new(i, device=self.device) for i in self._array) - def __le__(self: Array, other: Union[int, float, Array], /) -> Array: + def __le__(self, other: Array | int | float, /) -> Array: """ Performs the operation __le__. """ @@ -770,7 +796,7 @@ def __le__(self: Array, other: Union[int, float, Array], /) -> Array: res = self._array.__le__(other._array) return self.__class__._new(res, device=self.device) - def __lshift__(self: Array, other: Union[int, Array], /) -> Array: + def __lshift__(self, other: Array | int, /) -> Array: """ Performs the operation __lshift__. """ @@ -782,7 +808,7 @@ def __lshift__(self: Array, other: Union[int, Array], /) -> Array: res = self._array.__lshift__(other._array) return self.__class__._new(res, device=self.device) - def __lt__(self: Array, other: Union[int, float, Array], /) -> Array: + def __lt__(self, other: Array | int | float, /) -> Array: """ Performs the operation __lt__. """ @@ -794,7 +820,7 @@ def __lt__(self: Array, other: Union[int, float, Array], /) -> Array: res = self._array.__lt__(other._array) return self.__class__._new(res, device=self.device) - def __matmul__(self: Array, other: Array, /) -> Array: + def __matmul__(self, other: Array, /) -> Array: """ Performs the operation __matmul__. """ @@ -807,7 +833,7 @@ def __matmul__(self: Array, other: Array, /) -> Array: res = self._array.__matmul__(other._array) return self.__class__._new(res, device=self.device) - def __mod__(self: Array, other: Union[int, float, Array], /) -> Array: + def __mod__(self, other: Array | int | float, /) -> Array: """ Performs the operation __mod__. """ @@ -819,7 +845,7 @@ def __mod__(self: Array, other: Union[int, float, Array], /) -> Array: res = self._array.__mod__(other._array) return self.__class__._new(res, device=self.device) - def __mul__(self: Array, other: Union[int, float, Array], /) -> Array: + def __mul__(self, other: Array | int | float | complex, /) -> Array: """ Performs the operation __mul__. """ @@ -831,7 +857,7 @@ def __mul__(self: Array, other: Union[int, float, Array], /) -> Array: res = self._array.__mul__(other._array) return self.__class__._new(res, device=self.device) - def __ne__(self: Array, other: Union[int, float, bool, Array], /) -> Array: + def __ne__(self, other: Array | bool | int | float | complex, /) -> Array: # type: ignore[override] """ Performs the operation __ne__. """ @@ -843,7 +869,7 @@ def __ne__(self: Array, other: Union[int, float, bool, Array], /) -> Array: res = self._array.__ne__(other._array) return self.__class__._new(res, device=self.device) - def __neg__(self: Array, /) -> Array: + def __neg__(self) -> Array: """ Performs the operation __neg__. """ @@ -852,7 +878,7 @@ def __neg__(self: Array, /) -> Array: res = self._array.__neg__() return self.__class__._new(res, device=self.device) - def __or__(self: Array, other: Union[int, bool, Array], /) -> Array: + def __or__(self, other: Array | bool | int, /) -> Array: """ Performs the operation __or__. """ @@ -864,7 +890,7 @@ def __or__(self: Array, other: Union[int, bool, Array], /) -> Array: res = self._array.__or__(other._array) return self.__class__._new(res, device=self.device) - def __pos__(self: Array, /) -> Array: + def __pos__(self) -> Array: """ Performs the operation __pos__. """ @@ -873,11 +899,11 @@ def __pos__(self: Array, /) -> Array: res = self._array.__pos__() return self.__class__._new(res, device=self.device) - def __pow__(self: Array, other: Union[int, float, Array], /) -> Array: + def __pow__(self, other: Array | int | float | complex, /) -> Array: """ Performs the operation __pow__. """ - from ._elementwise_functions import pow + from ._elementwise_functions import pow # type: ignore[attr-defined] self._check_device(other) other = self._check_allowed_dtypes(other, "numeric", "__pow__") @@ -887,7 +913,7 @@ def __pow__(self: Array, other: Union[int, float, Array], /) -> Array: # arrays, so we use pow() here instead. return pow(self, other) - def __rshift__(self: Array, other: Union[int, Array], /) -> Array: + def __rshift__(self, other: Array | int, /) -> Array: """ Performs the operation __rshift__. """ @@ -901,10 +927,16 @@ def __rshift__(self: Array, other: Union[int, Array], /) -> Array: def __setitem__( self, - key: Union[ - int, slice, ellipsis, Tuple[Union[int, slice, ellipsis], ...], Array # noqa: F821 - ], - value: Union[int, float, bool, Array], + # Almost same as __getitem__ key but doesn't accept None + # or integer arrays + key: ( + int + | slice + | EllipsisType + | Array + | tuple[int | slice | EllipsisType, ...] + ), + value: Array | bool | int | float | complex, /, ) -> None: """ @@ -913,12 +945,11 @@ def __setitem__( # Note: Only indices required by the spec are allowed. See the # docstring of _validate_index self._validate_index(key, op="setitem") - if isinstance(key, Array): - # Indexing self._array with array_api_strict arrays can be erroneous - key = key._array - self._array.__setitem__(key, asarray(value)._array) + # Indexing self._array with array_api_strict arrays can be erroneous + np_key = key._array if isinstance(key, Array) else key + self._array.__setitem__(np_key, asarray(value)._array) - def __sub__(self: Array, other: Union[int, float, Array], /) -> Array: + def __sub__(self, other: Array | int | float | complex, /) -> Array: """ Performs the operation __sub__. """ @@ -932,7 +963,7 @@ def __sub__(self: Array, other: Union[int, float, Array], /) -> Array: # PEP 484 requires int to be a subtype of float, but __truediv__ should # not accept int. - def __truediv__(self: Array, other: Union[float, Array], /) -> Array: + def __truediv__(self, other: Array | int | float | complex, /) -> Array: """ Performs the operation __truediv__. """ @@ -944,7 +975,7 @@ def __truediv__(self: Array, other: Union[float, Array], /) -> Array: res = self._array.__truediv__(other._array) return self.__class__._new(res, device=self.device) - def __xor__(self: Array, other: Union[int, bool, Array], /) -> Array: + def __xor__(self, other: Array | bool | int, /) -> Array: """ Performs the operation __xor__. """ @@ -956,7 +987,7 @@ def __xor__(self: Array, other: Union[int, bool, Array], /) -> Array: res = self._array.__xor__(other._array) return self.__class__._new(res, device=self.device) - def __iadd__(self: Array, other: Union[int, float, Array], /) -> Array: + def __iadd__(self, other: Array | int | float | complex, /) -> Array: """ Performs the operation __iadd__. """ @@ -967,7 +998,7 @@ def __iadd__(self: Array, other: Union[int, float, Array], /) -> Array: self._array.__iadd__(other._array) return self - def __radd__(self: Array, other: Union[int, float, Array], /) -> Array: + def __radd__(self, other: Array | int | float | complex, /) -> Array: """ Performs the operation __radd__. """ @@ -979,7 +1010,7 @@ def __radd__(self: Array, other: Union[int, float, Array], /) -> Array: res = self._array.__radd__(other._array) return self.__class__._new(res, device=self.device) - def __iand__(self: Array, other: Union[int, bool, Array], /) -> Array: + def __iand__(self, other: Array | bool | int, /) -> Array: """ Performs the operation __iand__. """ @@ -990,7 +1021,7 @@ def __iand__(self: Array, other: Union[int, bool, Array], /) -> Array: self._array.__iand__(other._array) return self - def __rand__(self: Array, other: Union[int, bool, Array], /) -> Array: + def __rand__(self, other: Array | bool | int, /) -> Array: """ Performs the operation __rand__. """ @@ -1002,7 +1033,7 @@ def __rand__(self: Array, other: Union[int, bool, Array], /) -> Array: res = self._array.__rand__(other._array) return self.__class__._new(res, device=self.device) - def __ifloordiv__(self: Array, other: Union[int, float, Array], /) -> Array: + def __ifloordiv__(self, other: Array | int | float, /) -> Array: """ Performs the operation __ifloordiv__. """ @@ -1013,7 +1044,7 @@ def __ifloordiv__(self: Array, other: Union[int, float, Array], /) -> Array: self._array.__ifloordiv__(other._array) return self - def __rfloordiv__(self: Array, other: Union[int, float, Array], /) -> Array: + def __rfloordiv__(self, other: Array | int | float, /) -> Array: """ Performs the operation __rfloordiv__. """ @@ -1025,7 +1056,7 @@ def __rfloordiv__(self: Array, other: Union[int, float, Array], /) -> Array: res = self._array.__rfloordiv__(other._array) return self.__class__._new(res, device=self.device) - def __ilshift__(self: Array, other: Union[int, Array], /) -> Array: + def __ilshift__(self, other: Array | int, /) -> Array: """ Performs the operation __ilshift__. """ @@ -1036,7 +1067,7 @@ def __ilshift__(self: Array, other: Union[int, Array], /) -> Array: self._array.__ilshift__(other._array) return self - def __rlshift__(self: Array, other: Union[int, Array], /) -> Array: + def __rlshift__(self, other: Array | int, /) -> Array: """ Performs the operation __rlshift__. """ @@ -1048,7 +1079,7 @@ def __rlshift__(self: Array, other: Union[int, Array], /) -> Array: res = self._array.__rlshift__(other._array) return self.__class__._new(res, device=self.device) - def __imatmul__(self: Array, other: Array, /) -> Array: + def __imatmul__(self, other: Array, /) -> Array: """ Performs the operation __imatmul__. """ @@ -1061,7 +1092,7 @@ def __imatmul__(self: Array, other: Array, /) -> Array: res = self._array.__imatmul__(other._array) return self.__class__._new(res, device=self.device) - def __rmatmul__(self: Array, other: Array, /) -> Array: + def __rmatmul__(self, other: Array, /) -> Array: """ Performs the operation __rmatmul__. """ @@ -1074,7 +1105,7 @@ def __rmatmul__(self: Array, other: Array, /) -> Array: res = self._array.__rmatmul__(other._array) return self.__class__._new(res, device=self.device) - def __imod__(self: Array, other: Union[int, float, Array], /) -> Array: + def __imod__(self, other: Array | int | float, /) -> Array: """ Performs the operation __imod__. """ @@ -1084,7 +1115,7 @@ def __imod__(self: Array, other: Union[int, float, Array], /) -> Array: self._array.__imod__(other._array) return self - def __rmod__(self: Array, other: Union[int, float, Array], /) -> Array: + def __rmod__(self, other: Array | int | float, /) -> Array: """ Performs the operation __rmod__. """ @@ -1096,7 +1127,7 @@ def __rmod__(self: Array, other: Union[int, float, Array], /) -> Array: res = self._array.__rmod__(other._array) return self.__class__._new(res, device=self.device) - def __imul__(self: Array, other: Union[int, float, Array], /) -> Array: + def __imul__(self, other: Array | int | float | complex, /) -> Array: """ Performs the operation __imul__. """ @@ -1106,7 +1137,7 @@ def __imul__(self: Array, other: Union[int, float, Array], /) -> Array: self._array.__imul__(other._array) return self - def __rmul__(self: Array, other: Union[int, float, Array], /) -> Array: + def __rmul__(self, other: Array | int | float | complex, /) -> Array: """ Performs the operation __rmul__. """ @@ -1118,7 +1149,7 @@ def __rmul__(self: Array, other: Union[int, float, Array], /) -> Array: res = self._array.__rmul__(other._array) return self.__class__._new(res, device=self.device) - def __ior__(self: Array, other: Union[int, bool, Array], /) -> Array: + def __ior__(self, other: Array | bool | int, /) -> Array: """ Performs the operation __ior__. """ @@ -1128,7 +1159,7 @@ def __ior__(self: Array, other: Union[int, bool, Array], /) -> Array: self._array.__ior__(other._array) return self - def __ror__(self: Array, other: Union[int, bool, Array], /) -> Array: + def __ror__(self, other: Array | bool | int, /) -> Array: """ Performs the operation __ror__. """ @@ -1140,7 +1171,7 @@ def __ror__(self: Array, other: Union[int, bool, Array], /) -> Array: res = self._array.__ror__(other._array) return self.__class__._new(res, device=self.device) - def __ipow__(self: Array, other: Union[int, float, Array], /) -> Array: + def __ipow__(self, other: Array | int | float | complex, /) -> Array: """ Performs the operation __ipow__. """ @@ -1150,11 +1181,11 @@ def __ipow__(self: Array, other: Union[int, float, Array], /) -> Array: self._array.__ipow__(other._array) return self - def __rpow__(self: Array, other: Union[int, float, Array], /) -> Array: + def __rpow__(self, other: Array | int | float | complex, /) -> Array: """ Performs the operation __rpow__. """ - from ._elementwise_functions import pow + from ._elementwise_functions import pow # type: ignore[attr-defined] other = self._check_allowed_dtypes(other, "numeric", "__rpow__") if other is NotImplemented: @@ -1163,7 +1194,7 @@ def __rpow__(self: Array, other: Union[int, float, Array], /) -> Array: # for 0-d arrays, so we use pow() here instead. return pow(other, self) - def __irshift__(self: Array, other: Union[int, Array], /) -> Array: + def __irshift__(self, other: Array | int, /) -> Array: """ Performs the operation __irshift__. """ @@ -1173,7 +1204,7 @@ def __irshift__(self: Array, other: Union[int, Array], /) -> Array: self._array.__irshift__(other._array) return self - def __rrshift__(self: Array, other: Union[int, Array], /) -> Array: + def __rrshift__(self, other: Array | int, /) -> Array: """ Performs the operation __rrshift__. """ @@ -1185,7 +1216,7 @@ def __rrshift__(self: Array, other: Union[int, Array], /) -> Array: res = self._array.__rrshift__(other._array) return self.__class__._new(res, device=self.device) - def __isub__(self: Array, other: Union[int, float, Array], /) -> Array: + def __isub__(self, other: Array | int | float | complex, /) -> Array: """ Performs the operation __isub__. """ @@ -1195,7 +1226,7 @@ def __isub__(self: Array, other: Union[int, float, Array], /) -> Array: self._array.__isub__(other._array) return self - def __rsub__(self: Array, other: Union[int, float, Array], /) -> Array: + def __rsub__(self, other: Array | int | float | complex, /) -> Array: """ Performs the operation __rsub__. """ @@ -1207,7 +1238,7 @@ def __rsub__(self: Array, other: Union[int, float, Array], /) -> Array: res = self._array.__rsub__(other._array) return self.__class__._new(res, device=self.device) - def __itruediv__(self: Array, other: Union[float, Array], /) -> Array: + def __itruediv__(self, other: Array | int | float | complex, /) -> Array: """ Performs the operation __itruediv__. """ @@ -1217,7 +1248,7 @@ def __itruediv__(self: Array, other: Union[float, Array], /) -> Array: self._array.__itruediv__(other._array) return self - def __rtruediv__(self: Array, other: Union[float, Array], /) -> Array: + def __rtruediv__(self, other: Array | int | float | complex, /) -> Array: """ Performs the operation __rtruediv__. """ @@ -1229,7 +1260,7 @@ def __rtruediv__(self: Array, other: Union[float, Array], /) -> Array: res = self._array.__rtruediv__(other._array) return self.__class__._new(res, device=self.device) - def __ixor__(self: Array, other: Union[int, bool, Array], /) -> Array: + def __ixor__(self, other: Array | bool | int, /) -> Array: """ Performs the operation __ixor__. """ @@ -1239,7 +1270,7 @@ def __ixor__(self: Array, other: Union[int, bool, Array], /) -> Array: self._array.__ixor__(other._array) return self - def __rxor__(self: Array, other: Union[int, bool, Array], /) -> Array: + def __rxor__(self, other: Array | bool | int, /) -> Array: """ Performs the operation __rxor__. """ @@ -1251,7 +1282,7 @@ def __rxor__(self: Array, other: Union[int, bool, Array], /) -> Array: res = self._array.__rxor__(other._array) return self.__class__._new(res, device=self.device) - def to_device(self: Array, device: Device, /, stream: None = None) -> Array: + def to_device(self, device: Device, /, stream: None = None) -> Array: if stream is not None: raise ValueError("The stream argument to to_device() is not supported") if device == self._device: @@ -1262,7 +1293,7 @@ def to_device(self: Array, device: Device, /, stream: None = None) -> Array: raise ValueError(f"Unsupported device {device!r}") @property - def dtype(self) -> Dtype: + def dtype(self) -> DType: """ Array API compatible wrapper for :py:meth:`np.ndarray.dtype `. @@ -1290,7 +1321,7 @@ def ndim(self) -> int: return self._array.ndim @property - def shape(self) -> Tuple[int, ...]: + def shape(self) -> tuple[int, ...]: """ Array API compatible wrapper for :py:meth:`np.ndarray.shape `. diff --git a/array_api_strict/_constants.py b/array_api_strict/_constants.py index 15ab81d..d78354b 100644 --- a/array_api_strict/_constants.py +++ b/array_api_strict/_constants.py @@ -4,4 +4,4 @@ inf = np.inf nan = np.nan pi = np.pi -newaxis = np.newaxis +newaxis: None = np.newaxis diff --git a/array_api_strict/_creation_functions.py b/array_api_strict/_creation_functions.py index 460dba9..3b80b8a 100644 --- a/array_api_strict/_creation_functions.py +++ b/array_api_strict/_creation_functions.py @@ -1,23 +1,33 @@ from __future__ import annotations +from collections.abc import Generator from contextlib import contextmanager -from typing import TYPE_CHECKING, List, Optional, Tuple, Union +from enum import Enum +from typing import TYPE_CHECKING, Literal -if TYPE_CHECKING: - from ._typing import ( - Array, - Device, - Dtype, - NestedSequence, - SupportsBufferProtocol, - ) -from ._dtypes import _DType, _all_dtypes +import numpy as np + +from ._dtypes import DType, _all_dtypes, _np_dtype from ._flags import get_array_api_strict_flags +from ._typing import NestedSequence, SupportsBufferProtocol, SupportsDLPack + +if TYPE_CHECKING: + # TODO import from typing (requires Python >=3.13) + from typing_extensions import TypeIs + + # Circular import + from ._array_object import Array, Device + + +class Undef(Enum): + UNDEF = 0 + + +_undef = Undef.UNDEF -import numpy as np @contextmanager -def allow_array(): +def allow_array() -> Generator[None]: """ Temporarily enable Array.__array__. This is needed for np.array to parse list of lists of Array objects. @@ -30,22 +40,25 @@ def allow_array(): finally: _array_object._allow_array = original_value -def _check_valid_dtype(dtype): + +def _check_valid_dtype(dtype: DType | None) -> None: # Note: Only spelling dtypes as the dtype objects is supported. if dtype not in (None,) + _all_dtypes: raise ValueError(f"dtype must be one of the supported dtypes, got {dtype!r}") -def _supports_buffer_protocol(obj): + +def _supports_buffer_protocol(obj: object) -> TypeIs[SupportsBufferProtocol]: try: - memoryview(obj) + memoryview(obj) # type: ignore[arg-type] except TypeError: return False return True -def _check_device(device): + +def _check_device(device: Device | None) -> None: # _array_object imports in this file are inside the functions to avoid # circular imports - from ._array_object import Device, ALL_DEVICES + from ._array_object import ALL_DEVICES, Device if device is not None and not isinstance(device, Device): raise ValueError(f"Unsupported device {device!r}") @@ -53,20 +66,20 @@ def _check_device(device): if device is not None and device not in ALL_DEVICES: raise ValueError(f"Unsupported device {device!r}") + def asarray( - obj: Union[ - Array, - bool, - int, - float, - NestedSequence[bool | int | float], - SupportsBufferProtocol, - ], + obj: Array + | bool + | int + | float + | complex + | NestedSequence[bool | int | float | complex] + | SupportsBufferProtocol, /, *, - dtype: Optional[Dtype] = None, - device: Optional[Device] = None, - copy: Optional[bool] = None, + dtype: DType | None = None, + device: Device | None = None, + copy: bool | None = None, ) -> Array: """ Array API compatible wrapper for :py:func:`np.asarray `. @@ -118,13 +131,13 @@ def asarray( def arange( - start: Union[int, float], + start: int | float, /, - stop: Optional[Union[int, float]] = None, - step: Union[int, float] = 1, + stop: int | float | None = None, + step: int | float = 1, *, - dtype: Optional[Dtype] = None, - device: Optional[Device] = None, + dtype: DType | None = None, + device: Device | None = None, ) -> Array: """ Array API compatible wrapper for :py:func:`np.arange `. @@ -136,16 +149,17 @@ def arange( _check_valid_dtype(dtype) _check_device(device) - if dtype is not None: - dtype = dtype._np_dtype - return Array._new(np.arange(start, stop=stop, step=step, dtype=dtype), device=device) + return Array._new( + np.arange(start, stop, step, dtype=_np_dtype(dtype)), + device=device, + ) def empty( - shape: Union[int, Tuple[int, ...]], + shape: int | tuple[int, ...], *, - dtype: Optional[Dtype] = None, - device: Optional[Device] = None, + dtype: DType | None = None, + device: Device | None = None, ) -> Array: """ Array API compatible wrapper for :py:func:`np.empty `. @@ -157,13 +171,11 @@ def empty( _check_valid_dtype(dtype) _check_device(device) - if dtype is not None: - dtype = dtype._np_dtype - return Array._new(np.empty(shape, dtype=dtype), device=device) + return Array._new(np.empty(shape, dtype=_np_dtype(dtype)), device=device) def empty_like( - x: Array, /, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None + x: Array, /, *, dtype: DType | None = None, device: Device | None = None ) -> Array: """ Array API compatible wrapper for :py:func:`np.empty_like `. @@ -177,19 +189,17 @@ def empty_like( if device is None: device = x.device - if dtype is not None: - dtype = dtype._np_dtype - return Array._new(np.empty_like(x._array, dtype=dtype), device=device) + return Array._new(np.empty_like(x._array, dtype=_np_dtype(dtype)), device=device) def eye( n_rows: int, - n_cols: Optional[int] = None, + n_cols: int | None = None, /, *, k: int = 0, - dtype: Optional[Dtype] = None, - device: Optional[Device] = None, + dtype: DType | None = None, + device: Device | None = None, ) -> Array: """ Array API compatible wrapper for :py:func:`np.eye `. @@ -201,45 +211,43 @@ def eye( _check_valid_dtype(dtype) _check_device(device) - if dtype is not None: - dtype = dtype._np_dtype - return Array._new(np.eye(n_rows, M=n_cols, k=k, dtype=dtype), device=device) - + return Array._new( + np.eye(n_rows, M=n_cols, k=k, dtype=_np_dtype(dtype)), device=device + ) -_default = object() def from_dlpack( - x: object, + x: SupportsDLPack, /, *, - device: Optional[Device] = _default, - copy: Optional[bool] = _default, + device: Device | Undef | None = _undef, + copy: bool | Undef | None = _undef, ) -> Array: from ._array_object import Array if get_array_api_strict_flags()['api_version'] < '2023.12': - if device is not _default: + if device is not _undef: raise ValueError("The device argument to from_dlpack requires at least version 2023.12 of the array API") - if copy is not _default: + if copy is not _undef: raise ValueError("The copy argument to from_dlpack requires at least version 2023.12 of the array API") # Going to wait for upstream numpy support - if device is not _default: + if device is not _undef: _check_device(device) else: device = None - if copy not in [_default, None]: + if copy not in [_undef, None]: raise NotImplementedError("The copy argument to from_dlpack is not yet implemented") return Array._new(np.from_dlpack(x), device=device) def full( - shape: Union[int, Tuple[int, ...]], - fill_value: Union[int, float], + shape: int | tuple[int, ...], + fill_value: bool | int | float | complex, *, - dtype: Optional[Dtype] = None, - device: Optional[Device] = None, + dtype: DType | None = None, + device: Device | None = None, ) -> Array: """ Array API compatible wrapper for :py:func:`np.full `. @@ -253,10 +261,8 @@ def full( if isinstance(fill_value, Array) and fill_value.ndim == 0: fill_value = fill_value._array - if dtype is not None: - dtype = dtype._np_dtype - res = np.full(shape, fill_value, dtype=dtype) - if _DType(res.dtype) not in _all_dtypes: + res = np.full(shape, fill_value, dtype=_np_dtype(dtype)) + if DType(res.dtype) not in _all_dtypes: # This will happen if the fill value is not something that NumPy # coerces to one of the acceptable dtypes. raise TypeError("Invalid input to full") @@ -266,10 +272,10 @@ def full( def full_like( x: Array, /, - fill_value: Union[int, float], + fill_value: bool | int | float | complex, *, - dtype: Optional[Dtype] = None, - device: Optional[Device] = None, + dtype: DType | None = None, + device: Device | None = None, ) -> Array: """ Array API compatible wrapper for :py:func:`np.full_like `. @@ -283,10 +289,8 @@ def full_like( if device is None: device = x.device - if dtype is not None: - dtype = dtype._np_dtype - res = np.full_like(x._array, fill_value, dtype=dtype) - if _DType(res.dtype) not in _all_dtypes: + res = np.full_like(x._array, fill_value, dtype=_np_dtype(dtype)) + if DType(res.dtype) not in _all_dtypes: # This will happen if the fill value is not something that NumPy # coerces to one of the acceptable dtypes. raise TypeError("Invalid input to full_like") @@ -294,13 +298,13 @@ def full_like( def linspace( - start: Union[int, float], - stop: Union[int, float], + start: int | float | complex, + stop: int | float | complex, /, num: int, *, - dtype: Optional[Dtype] = None, - device: Optional[Device] = None, + dtype: DType | None = None, + device: Device | None = None, endpoint: bool = True, ) -> Array: """ @@ -313,12 +317,13 @@ def linspace( _check_valid_dtype(dtype) _check_device(device) - if dtype is not None: - dtype = dtype._np_dtype - return Array._new(np.linspace(start, stop, num, dtype=dtype, endpoint=endpoint), device=device) + return Array._new( + np.linspace(start, stop, num, dtype=_np_dtype(dtype), endpoint=endpoint), + device=device, + ) -def meshgrid(*arrays: Array, indexing: str = "xy") -> List[Array]: +def meshgrid(*arrays: Array, indexing: Literal["xy", "ij"] = "xy") -> list[Array]: """ Array API compatible wrapper for :py:func:`np.meshgrid `. @@ -348,10 +353,10 @@ def meshgrid(*arrays: Array, indexing: str = "xy") -> List[Array]: def ones( - shape: Union[int, Tuple[int, ...]], + shape: int | tuple[int, ...], *, - dtype: Optional[Dtype] = None, - device: Optional[Device] = None, + dtype: DType | None = None, + device: Device | None = None, ) -> Array: """ Array API compatible wrapper for :py:func:`np.ones `. @@ -363,13 +368,11 @@ def ones( _check_valid_dtype(dtype) _check_device(device) - if dtype is not None: - dtype = dtype._np_dtype - return Array._new(np.ones(shape, dtype=dtype), device=device) + return Array._new(np.ones(shape, dtype=_np_dtype(dtype)), device=device) def ones_like( - x: Array, /, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None + x: Array, /, *, dtype: DType | None = None, device: Device | None = None ) -> Array: """ Array API compatible wrapper for :py:func:`np.ones_like `. @@ -383,9 +386,7 @@ def ones_like( if device is None: device = x.device - if dtype is not None: - dtype = dtype._np_dtype - return Array._new(np.ones_like(x._array, dtype=dtype), device=device) + return Array._new(np.ones_like(x._array, dtype=_np_dtype(dtype)), device=device) def tril(x: Array, /, *, k: int = 0) -> Array: @@ -417,10 +418,10 @@ def triu(x: Array, /, *, k: int = 0) -> Array: def zeros( - shape: Union[int, Tuple[int, ...]], + shape: int | tuple[int, ...], *, - dtype: Optional[Dtype] = None, - device: Optional[Device] = None, + dtype: DType | None = None, + device: Device | None = None, ) -> Array: """ Array API compatible wrapper for :py:func:`np.zeros `. @@ -432,13 +433,11 @@ def zeros( _check_valid_dtype(dtype) _check_device(device) - if dtype is not None: - dtype = dtype._np_dtype - return Array._new(np.zeros(shape, dtype=dtype), device=device) + return Array._new(np.zeros(shape, dtype=_np_dtype(dtype)), device=device) def zeros_like( - x: Array, /, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None + x: Array, /, *, dtype: DType | None = None, device: Device | None = None ) -> Array: """ Array API compatible wrapper for :py:func:`np.zeros_like `. @@ -452,6 +451,4 @@ def zeros_like( if device is None: device = x.device - if dtype is not None: - dtype = dtype._np_dtype - return Array._new(np.zeros_like(x._array, dtype=dtype), device=device) + return Array._new(np.zeros_like(x._array, dtype=_np_dtype(dtype)), device=device) diff --git a/array_api_strict/_data_type_functions.py b/array_api_strict/_data_type_functions.py index 1643043..7dc918d 100644 --- a/array_api_strict/_data_type_functions.py +++ b/array_api_strict/_data_type_functions.py @@ -1,38 +1,37 @@ from __future__ import annotations -from ._array_object import Array -from ._creation_functions import _check_device +from dataclasses import dataclass + +import numpy as np + +from ._array_object import Array, Device +from ._creation_functions import Undef, _check_device, _undef from ._dtypes import ( - _DType, + DType, _all_dtypes, _boolean_dtypes, - _signed_integer_dtypes, - _unsigned_integer_dtypes, - _integer_dtypes, - _real_floating_dtypes, _complex_floating_dtypes, + _integer_dtypes, _numeric_dtypes, + _real_floating_dtypes, _result_type, + _signed_integer_dtypes, + _unsigned_integer_dtypes, ) from ._flags import get_array_api_strict_flags -from dataclasses import dataclass -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from typing import List, Tuple, Union, Optional - from ._typing import Dtype, Device - -import numpy as np - -# Use to emulate the asarray(device) argument not existing in 2022.12 -_default = object() # Note: astype is a function, not an array method as in NumPy. def astype( - x: Array, dtype: Dtype, /, *, copy: bool = True, device: Optional[Device] = _default + x: Array, + dtype: DType, + /, + *, + copy: bool = True, + # _default is used to emulate the device argument not existing in 2022.12 + device: Device | Undef | None = _undef, ) -> Array: - if device is not _default: + if device is not _undef: if get_array_api_strict_flags()['api_version'] >= '2023.12': _check_device(device) else: @@ -52,7 +51,7 @@ def astype( return Array._new(x._array.astype(dtype=dtype._np_dtype, copy=copy), device=device) -def broadcast_arrays(*arrays: Array) -> List[Array]: +def broadcast_arrays(*arrays: Array) -> list[Array]: """ Array API compatible wrapper for :py:func:`np.broadcast_arrays `. @@ -65,7 +64,7 @@ def broadcast_arrays(*arrays: Array) -> List[Array]: ] -def broadcast_to(x: Array, /, shape: Tuple[int, ...]) -> Array: +def broadcast_to(x: Array, /, shape: tuple[int, ...]) -> Array: """ Array API compatible wrapper for :py:func:`np.broadcast_to `. @@ -76,7 +75,7 @@ def broadcast_to(x: Array, /, shape: Tuple[int, ...]) -> Array: return Array._new(np.broadcast_to(x._array, shape), device=x.device) -def can_cast(from_: Union[Dtype, Array], to: Dtype, /) -> bool: +def can_cast(from_: DType | Array, to: DType, /) -> bool: """ Array API compatible wrapper for :py:func:`np.can_cast `. @@ -112,7 +111,7 @@ class finfo_object: max: float min: float smallest_normal: float - dtype: Dtype + dtype: DType @dataclass @@ -120,18 +119,17 @@ class iinfo_object: bits: int max: int min: int - dtype: Dtype + dtype: DType -def finfo(type: Union[Dtype, Array], /) -> finfo_object: +def finfo(type: DType | Array, /) -> finfo_object: """ Array API compatible wrapper for :py:func:`np.finfo `. See its docstring for more information. """ - if isinstance(type, _DType): - type = type._np_dtype - fi = np.finfo(type) + np_type = type._array if isinstance(type, Array) else type._np_dtype + fi = np.finfo(np_type) # Note: The types of the float data here are float, whereas in NumPy they # are scalars of the corresponding float dtype. return finfo_object( @@ -140,35 +138,33 @@ def finfo(type: Union[Dtype, Array], /) -> finfo_object: float(fi.max), float(fi.min), float(fi.smallest_normal), - fi.dtype, + DType(fi.dtype), ) -def iinfo(type: Union[Dtype, Array], /) -> iinfo_object: +def iinfo(type: DType | Array, /) -> iinfo_object: """ Array API compatible wrapper for :py:func:`np.iinfo `. See its docstring for more information. """ - if isinstance(type, _DType): - type = type._np_dtype - ii = np.iinfo(type) - return iinfo_object(ii.bits, ii.max, ii.min, ii.dtype) + np_type = type._array if isinstance(type, Array) else type._np_dtype + ii = np.iinfo(np_type) + return iinfo_object(ii.bits, ii.max, ii.min, DType(ii.dtype)) # Note: isdtype is a new function from the 2022.12 array API specification. -def isdtype( - dtype: Dtype, kind: Union[Dtype, str, Tuple[Union[Dtype, str], ...]] -) -> bool: +def isdtype(dtype: DType, kind: DType | str | tuple[DType | str, ...]) -> bool: """ - Returns a boolean indicating whether a provided dtype is of a specified data type ``kind``. + Returns a boolean indicating whether a provided dtype is of a specified + data type ``kind``. See https://data-apis.org/array-api/latest/API_specification/generated/array_api.isdtype.html for more details """ - if not isinstance(dtype, _DType): - raise TypeError(f"'dtype' must be a dtype, not a {type(dtype)!r}") + if not isinstance(dtype, DType): + raise TypeError(f"'dtype' must be a dtype, not a {type(dtype)!r}") if isinstance(kind, tuple): # Disallow nested tuples @@ -197,7 +193,10 @@ def isdtype( else: raise TypeError(f"'kind' must be a dtype, str, or tuple of dtypes and strs, not {type(kind).__name__}") -def result_type(*arrays_and_dtypes: Union[Array, Dtype, int, float, complex, bool]) -> Dtype: + +def result_type( + *arrays_and_dtypes: DType | Array | bool | int | float | complex, +) -> DType: """ Array API compatible wrapper for :py:func:`np.result_type `. @@ -219,15 +218,15 @@ def result_type(*arrays_and_dtypes: Union[Array, Dtype, int, float, complex, boo A.append(a) # remove python scalars - A = [a for a in A if not isinstance(a, (bool, int, float, complex))] + B = [a for a in A if not isinstance(a, (bool, int, float, complex))] - if len(A) == 0: + if len(B) == 0: raise ValueError("at least one array or dtype is required") - elif len(A) == 1: - result = A[0] + elif len(B) == 1: + result = B[0] else: - t = A[0] - for t2 in A[1:]: + t = B[0] + for t2 in B[1:]: t = _result_type(t, t2) result = t diff --git a/array_api_strict/_dtypes.py b/array_api_strict/_dtypes.py index 66304dd..513650b 100644 --- a/array_api_strict/_dtypes.py +++ b/array_api_strict/_dtypes.py @@ -1,19 +1,27 @@ +from __future__ import annotations + +import builtins import warnings +from typing import Any, Final import numpy as np +import numpy.typing as npt # Note: we wrap the NumPy dtype objects in a bare class, so that none of the # additional methods and behaviors of NumPy dtype objects are exposed. -class _DType: - def __init__(self, np_dtype): - np_dtype = np.dtype(np_dtype) - self._np_dtype = np_dtype - def __repr__(self): +class DType: + _np_dtype: Final[np.dtype[Any]] + __slots__ = ("_np_dtype", "__weakref__") + + def __init__(self, np_dtype: npt.DTypeLike): + self._np_dtype = np.dtype(np_dtype) + + def __repr__(self) -> str: return f"array_api_strict.{self._np_dtype.name}" - def __eq__(self, other): + def __eq__(self, other: object) -> builtins.bool: # See https://github.com/numpy/numpy/pull/25370/files#r1423259515. # Avoid the user error of array_api_strict.float32 == numpy.float32, # which gives False. Making == error is probably too egregious, so @@ -26,12 +34,13 @@ def __eq__(self, other): a NumPy native dtype object, but you probably don't want to do this. \ array_api_strict dtype objects compare unequal to their NumPy equivalents. \ Such cross-library comparison is not supported by the standard.""", - stacklevel=2) - if not isinstance(other, _DType): + stacklevel=2, + ) + if not isinstance(other, DType): return NotImplemented return self._np_dtype == other._np_dtype - def __hash__(self): + def __hash__(self) -> int: # Note: this is not strictly required # (https://github.com/data-apis/array-api/issues/582), but makes the # dtype objects much easier to work with here and elsewhere if they @@ -39,20 +48,24 @@ def __hash__(self): return hash(self._np_dtype) -int8 = _DType("int8") -int16 = _DType("int16") -int32 = _DType("int32") -int64 = _DType("int64") -uint8 = _DType("uint8") -uint16 = _DType("uint16") -uint32 = _DType("uint32") -uint64 = _DType("uint64") -float32 = _DType("float32") -float64 = _DType("float64") -complex64 = _DType("complex64") -complex128 = _DType("complex128") +def _np_dtype(dtype: DType | None) -> np.dtype[Any] | None: + return dtype._np_dtype if dtype is not None else None + + +int8 = DType("int8") +int16 = DType("int16") +int32 = DType("int32") +int64 = DType("int64") +uint8 = DType("uint8") +uint16 = DType("uint16") +uint32 = DType("uint32") +uint64 = DType("uint64") +float32 = DType("float32") +float64 = DType("float64") +complex64 = DType("complex64") +complex128 = DType("complex128") # Note: This name is changed -bool = _DType("bool") +bool = DType("bool") _all_dtypes = ( int8, @@ -212,7 +225,7 @@ def __hash__(self): } -def _result_type(type1, type2): +def _result_type(type1: DType, type2: DType) -> DType: if (type1, type2) in _promotion_table: return _promotion_table[type1, type2] raise TypeError(f"{type1} and {type2} cannot be type promoted together") diff --git a/array_api_strict/_elementwise_functions.py b/array_api_strict/_elementwise_functions.py index c11b17c..6b52a58 100644 --- a/array_api_strict/_elementwise_functions.py +++ b/array_api_strict/_elementwise_functions.py @@ -1,51 +1,50 @@ from __future__ import annotations +import numpy as np + +from ._array_object import Array +from ._creation_functions import asarray +from ._data_type_functions import broadcast_to, iinfo from ._dtypes import ( _boolean_dtypes, - _floating_dtypes, - _real_floating_dtypes, _complex_floating_dtypes, + _dtype_categories, + _floating_dtypes, _integer_dtypes, _integer_or_boolean_dtypes, - _real_numeric_dtypes, _numeric_dtypes, + _real_floating_dtypes, + _real_numeric_dtypes, _result_type, - _dtype_categories, ) -from ._array_object import Array from ._flags import requires_api_version -from ._creation_functions import asarray -from ._data_type_functions import broadcast_to, iinfo from ._helpers import _maybe_normalize_py_scalars -from typing import Optional, Union - -import numpy as np - def _binary_ufunc_proto(x1, x2, dtype_category, func_name, np_func): """Base implementation of a binary function, `func_name`, defined for - dtypes from `dtype_category` + dtypes from `dtype_category` """ x1, x2 = _maybe_normalize_py_scalars(x1, x2, dtype_category, func_name) if x1.device != x2.device: - raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") + raise ValueError( + f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined." + ) # Call result type here just to raise on disallowed type combinations _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) return Array._new(np_func(x1._array, x2._array), device=x1.device) -_binary_docstring_template=\ -""" +_binary_docstring_template = """ Array API compatible wrapper for :py:func:`np.%s `. See its docstring for more information. """ -def create_binary_func(func_name, dtype_category, np_func): +def _create_binary_func(func_name, dtype_category, np_func): def inner(x1, x2, /) -> Array: return _binary_ufunc_proto(x1, x2, dtype_category, func_name, np_func) return inner @@ -58,7 +57,7 @@ def inner(x1, x2, /) -> Array: "real numeric": "int | float | Array", "numeric": "int | float | complex | Array", "integer": "int | Array", - "integer or boolean": "int | bool | Array", + "integer or boolean": "bool | int | Array", "boolean": "bool | Array", "real floating-point": "float | Array", "complex floating-point": "complex | Array", @@ -75,7 +74,7 @@ def inner(x1, x2, /) -> Array: "bitwise_xor": "integer or boolean", "_bitwise_left_shift": "integer", # leading underscore deliberate "_bitwise_right_shift": "integer", - # XXX: copysign: real fp or numeric? + # XXX: copysign: real fp or numeric? "copysign": "real floating-point", "divide": "floating-point", "equal": "all", @@ -105,7 +104,7 @@ def inner(x1, x2, /) -> Array: "atan2": "arctan2", "_bitwise_left_shift": "left_shift", "_bitwise_right_shift": "right_shift", - "pow": "power" + "pow": "power", } @@ -117,7 +116,7 @@ def inner(x1, x2, /) -> Array: numpy_name = _numpy_renames.get(func_name, func_name) np_func = getattr(np, numpy_name) - func = create_binary_func(func_name, dtype_category, np_func) + func = _create_binary_func(func_name, dtype_category, np_func) func.__name__ = func_name func.__doc__ = _binary_docstring_template % (numpy_name, numpy_name) @@ -153,7 +152,7 @@ def bitwise_right_shift(x1: int | Array, x2: int | Array, /) -> Array: # clean up to not pollute the namespace -del func, create_binary_func +del func, _create_binary_func def abs(x: Array, /) -> Array: @@ -271,8 +270,8 @@ def ceil(x: Array, /) -> Array: def clip( x: Array, /, - min: Optional[Union[int, float, Array]] = None, - max: Optional[Union[int, float, Array]] = None, + min: Array | int | float | None = None, + max: Array | int | float | None = None, ) -> Array: """ Array API compatible wrapper for :py:func:`np.clip `. @@ -351,6 +350,7 @@ def clip( def _isscalar(a): return isinstance(a, (int, float, type(None))) + min_shape = () if _isscalar(min) else min.shape max_shape = () if _isscalar(max) else max.shape @@ -584,6 +584,7 @@ def reciprocal(x: Array, /) -> Array: raise TypeError("Only floating-point dtypes are allowed in reciprocal") return Array._new(np.reciprocal(x._array), device=x.device) + def round(x: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.round `. diff --git a/array_api_strict/_fft.py b/array_api_strict/_fft.py index c888826..2998254 100644 --- a/array_api_strict/_fft.py +++ b/array_api_strict/_fft.py @@ -1,31 +1,29 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from collections.abc import Sequence +from typing import Literal -if TYPE_CHECKING: - from typing import Union, Optional, Literal - from ._typing import Device, Dtype as DType - from collections.abc import Sequence +import numpy as np +from ._array_object import ALL_DEVICES, Array, Device +from ._data_type_functions import astype from ._dtypes import ( + DType, + _complex_floating_dtypes, _floating_dtypes, _real_floating_dtypes, - _complex_floating_dtypes, - float32, complex64, + float32, ) -from ._array_object import Array, ALL_DEVICES -from ._data_type_functions import astype from ._flags import requires_extension -import numpy as np @requires_extension('fft') def fft( x: Array, /, *, - n: Optional[int] = None, + n: int | None = None, axis: int = -1, norm: Literal["backward", "ortho", "forward"] = "backward", ) -> Array: @@ -48,7 +46,7 @@ def ifft( x: Array, /, *, - n: Optional[int] = None, + n: int | None = None, axis: int = -1, norm: Literal["backward", "ortho", "forward"] = "backward", ) -> Array: @@ -71,8 +69,8 @@ def fftn( x: Array, /, *, - s: Sequence[int] = None, - axes: Sequence[int] = None, + s: Sequence[int] | None = None, + axes: Sequence[int] | None = None, norm: Literal["backward", "ortho", "forward"] = "backward", ) -> Array: """ @@ -94,8 +92,8 @@ def ifftn( x: Array, /, *, - s: Sequence[int] = None, - axes: Sequence[int] = None, + s: Sequence[int] | None = None, + axes: Sequence[int] | None = None, norm: Literal["backward", "ortho", "forward"] = "backward", ) -> Array: """ @@ -117,7 +115,7 @@ def rfft( x: Array, /, *, - n: Optional[int] = None, + n: int | None = None, axis: int = -1, norm: Literal["backward", "ortho", "forward"] = "backward", ) -> Array: @@ -140,7 +138,7 @@ def irfft( x: Array, /, *, - n: Optional[int] = None, + n: int | None = None, axis: int = -1, norm: Literal["backward", "ortho", "forward"] = "backward", ) -> Array: @@ -163,8 +161,8 @@ def rfftn( x: Array, /, *, - s: Sequence[int] = None, - axes: Sequence[int] = None, + s: Sequence[int] | None = None, + axes: Sequence[int] | None = None, norm: Literal["backward", "ortho", "forward"] = "backward", ) -> Array: """ @@ -186,8 +184,8 @@ def irfftn( x: Array, /, *, - s: Sequence[int] = None, - axes: Sequence[int] = None, + s: Sequence[int] | None = None, + axes: Sequence[int] | None = None, norm: Literal["backward", "ortho", "forward"] = "backward", ) -> Array: """ @@ -209,7 +207,7 @@ def hfft( x: Array, /, *, - n: Optional[int] = None, + n: int | None = None, axis: int = -1, norm: Literal["backward", "ortho", "forward"] = "backward", ) -> Array: @@ -232,7 +230,7 @@ def ihfft( x: Array, /, *, - n: Optional[int] = None, + n: int | None = None, axis: int = -1, norm: Literal["backward", "ortho", "forward"] = "backward", ) -> Array: @@ -256,8 +254,8 @@ def fftfreq( /, *, d: float = 1.0, - dtype: Optional[DType] = None, - device: Optional[Device] = None + dtype: DType | None = None, + device: Device | None = None ) -> Array: """ Array API compatible wrapper for :py:func:`np.fft.fftfreq `. @@ -280,8 +278,8 @@ def rfftfreq( /, *, d: float = 1.0, - dtype: Optional[DType] = None, - device: Optional[Device] = None + dtype: DType | None = None, + device: Device | None = None ) -> Array: """ Array API compatible wrapper for :py:func:`np.fft.rfftfreq `. @@ -299,7 +297,7 @@ def rfftfreq( return Array._new(np_result, device=device) @requires_extension('fft') -def fftshift(x: Array, /, *, axes: Union[int, Sequence[int]] = None) -> Array: +def fftshift(x: Array, /, *, axes: int | Sequence[int] | None = None) -> Array: """ Array API compatible wrapper for :py:func:`np.fft.fftshift `. @@ -310,7 +308,7 @@ def fftshift(x: Array, /, *, axes: Union[int, Sequence[int]] = None) -> Array: return Array._new(np.fft.fftshift(x._array, axes=axes), device=x.device) @requires_extension('fft') -def ifftshift(x: Array, /, *, axes: Union[int, Sequence[int]] = None) -> Array: +def ifftshift(x: Array, /, *, axes: int | Sequence[int] | None = None) -> Array: """ Array API compatible wrapper for :py:func:`np.fft.ifftshift `. diff --git a/array_api_strict/_flags.py b/array_api_strict/_flags.py index 3fce8a0..6729a4b 100644 --- a/array_api_strict/_flags.py +++ b/array_api_strict/_flags.py @@ -12,17 +12,32 @@ """ +from __future__ import annotations + import functools import os import warnings +from collections.abc import Callable +from types import TracebackType +from typing import TYPE_CHECKING, Any, Collection, TypeVar, cast import array_api_strict +if TYPE_CHECKING: + # TODO import from typing (requires Python >= 3.10) + from typing_extensions import ParamSpec + + P = ParamSpec("P") + +T = TypeVar("T") +_CallableT = TypeVar("_CallableT", bound=Callable[..., object]) + + supported_versions = ( "2021.12", "2022.12", "2023.12", - "2024.12" + "2024.12", ) draft_version = "2025.12" @@ -43,19 +58,23 @@ "fft": "2022.12", } -ENABLED_EXTENSIONS = default_extensions = ( +default_extensions: tuple[str, ...] = ( "linalg", "fft", ) +ENABLED_EXTENSIONS = default_extensions + + # Public functions + def set_array_api_strict_flags( *, - api_version=None, - boolean_indexing=None, - data_dependent_shapes=None, - enabled_extensions=None, -): + api_version: str | None = None, + boolean_indexing: bool | None = None, + data_dependent_shapes: bool | None = None, + enabled_extensions: Collection[str] | None = None, +) -> None: """ Set the array-api-strict flags to the specified values. @@ -178,7 +197,8 @@ def set_array_api_strict_flags( draft_version=draft_version, ) -def get_array_api_strict_flags(): + +def get_array_api_strict_flags() -> dict[str, Any]: """ Get the current array-api-strict flags. @@ -228,7 +248,7 @@ def get_array_api_strict_flags(): } -def reset_array_api_strict_flags(): +def reset_array_api_strict_flags() -> None: """ Reset the array-api-strict flags to their default values. @@ -300,8 +320,19 @@ class ArrayAPIStrictFlags: reset_array_api_strict_flags: Reset the flags to their default values. """ - def __init__(self, *, api_version=None, boolean_indexing=None, - data_dependent_shapes=None, enabled_extensions=None): + + kwargs: dict[str, Any] + old_flags: dict[str, Any] + __slots__ = ("kwargs", "old_flags") + + def __init__( + self, + *, + api_version: str | None = None, + boolean_indexing: bool | None = None, + data_dependent_shapes: bool | None = None, + enabled_extensions: Collection[str] | None = None, + ): self.kwargs = { "api_version": api_version, "boolean_indexing": boolean_indexing, @@ -310,12 +341,19 @@ def __init__(self, *, api_version=None, boolean_indexing=None, } self.old_flags = get_array_api_strict_flags() - def __enter__(self): + def __enter__(self) -> None: set_array_api_strict_flags(**self.kwargs) - def __exit__(self, exc_type, exc_value, traceback): + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + /, + ) -> None: set_array_api_strict_flags(**self.old_flags) + # Private functions ENVIRONMENT_VARIABLES = [ @@ -325,8 +363,9 @@ def __exit__(self, exc_type, exc_value, traceback): "ARRAY_API_STRICT_ENABLED_EXTENSIONS", ] -def set_flags_from_environment(): - kwargs = {} + +def set_flags_from_environment() -> None: + kwargs: dict[str, Any] = {} if "ARRAY_API_STRICT_API_VERSION" in os.environ: kwargs["api_version"] = os.environ["ARRAY_API_STRICT_API_VERSION"] @@ -346,35 +385,41 @@ def set_flags_from_environment(): # linalg and fft to __all__ set_array_api_strict_flags(**kwargs) + set_flags_from_environment() # Decorators -def requires_api_version(version): - def decorator(func): + +def requires_api_version(version: str) -> Callable[[_CallableT], _CallableT]: + def decorator(func: Callable[P, T]) -> Callable[P, T]: @functools.wraps(func) - def wrapper(*args, **kwargs): + def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: if version > API_VERSION: raise RuntimeError( f"The function {func.__name__} requires API version {version} or later, " f"but the current API version for array-api-strict is {API_VERSION}" ) return func(*args, **kwargs) + return wrapper - return decorator -def requires_data_dependent_shapes(func): + return cast(Callable[[_CallableT], _CallableT], decorator) + + +def requires_data_dependent_shapes(func: Callable[P, T]) -> Callable[P, T]: @functools.wraps(func) - def wrapper(*args, **kwargs): + def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: if not DATA_DEPENDENT_SHAPES: raise RuntimeError(f"The function {func.__name__} requires data-dependent shapes, but the data_dependent_shapes flag has been disabled for array-api-strict") return func(*args, **kwargs) return wrapper -def requires_extension(extension): - def decorator(func): + +def requires_extension(extension: str) -> Callable[[_CallableT], _CallableT]: + def decorator(func: Callable[P, T]) -> Callable[P, T]: @functools.wraps(func) - def wrapper(*args, **kwargs): + def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: if extension not in ENABLED_EXTENSIONS: if extension == 'linalg' \ and func.__name__ in ['matmul', 'tensordot', @@ -382,5 +427,7 @@ def wrapper(*args, **kwargs): raise RuntimeError(f"The linalg extension has been disabled for array-api-strict. However, {func.__name__} is also present in the main array_api_strict namespace and may be used from there.") raise RuntimeError(f"The function {func.__name__} requires the {extension} extension, but it has been disabled for array-api-strict") return func(*args, **kwargs) + return wrapper - return decorator + + return cast(Callable[[_CallableT], _CallableT], decorator) diff --git a/array_api_strict/_helpers.py b/array_api_strict/_helpers.py index d3fc9c9..291082e 100644 --- a/array_api_strict/_helpers.py +++ b/array_api_strict/_helpers.py @@ -1,18 +1,24 @@ -"""Private helper routines. -""" +"""Private helper routines.""" -from ._flags import get_array_api_strict_flags +from __future__ import annotations + +from ._array_object import Array from ._dtypes import _dtype_categories +from ._flags import get_array_api_strict_flags _py_scalars = (bool, int, float, complex) -def _maybe_normalize_py_scalars(x1, x2, dtype_category, func_name): - +def _maybe_normalize_py_scalars( + x1: Array | bool | int | float | complex, + x2: Array | bool | int | float | complex, + dtype_category: str, + func_name: str, +) -> tuple[Array, Array]: flags = get_array_api_strict_flags() if flags["api_version"] < "2024.12": # scalars will fail at the call site - return x1, x2 + return x1, x2 # type: ignore[return-value] _allowed_dtypes = _dtype_categories[dtype_category] @@ -34,4 +40,3 @@ def _maybe_normalize_py_scalars(x1, x2, dtype_category, func_name): raise TypeError(f"Only {dtype_category} dtypes are allowed in {func_name}(...). " f"Got {x1.dtype} and {x2.dtype}.") return x1, x2 - diff --git a/array_api_strict/_indexing_functions.py b/array_api_strict/_indexing_functions.py index d7a400e..ab25fab 100644 --- a/array_api_strict/_indexing_functions.py +++ b/array_api_strict/_indexing_functions.py @@ -1,17 +1,13 @@ from __future__ import annotations +import numpy as np + from ._array_object import Array from ._dtypes import _integer_dtypes from ._flags import requires_api_version -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from typing import Optional -import numpy as np - -def take(x: Array, indices: Array, /, *, axis: Optional[int] = None) -> Array: +def take(x: Array, indices: Array, /, *, axis: int | None = None) -> Array: """ Array API compatible wrapper for :py:func:`np.take `. @@ -27,6 +23,7 @@ def take(x: Array, indices: Array, /, *, axis: Optional[int] = None) -> Array: raise ValueError(f"Arrays from two different devices ({x.device} and {indices.device}) can not be combined.") return Array._new(np.take(x._array, indices._array, axis=axis), device=x.device) + @requires_api_version('2024.12') def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1) -> Array: """ diff --git a/array_api_strict/_info.py b/array_api_strict/_info.py index a9dbebf..81f88bc 100644 --- a/array_api_strict/_info.py +++ b/array_api_strict/_info.py @@ -1,25 +1,22 @@ from __future__ import annotations -from typing import TYPE_CHECKING - import numpy as np -if TYPE_CHECKING: - from typing import Optional, Union, Tuple, List - from ._typing import device, DefaultDataTypes, DataTypes, Capabilities - -from ._array_object import ALL_DEVICES, CPU_DEVICE +from . import _dtypes as dt +from ._array_object import ALL_DEVICES, CPU_DEVICE, Device from ._flags import get_array_api_strict_flags, requires_api_version -from ._dtypes import bool, int8, int16, int32, int64, uint8, uint16, uint32, uint64, float32, float64, complex64, complex128 +from ._typing import Capabilities, DataTypes, DefaultDataTypes + @requires_api_version('2023.12') class __array_namespace_info__: @requires_api_version('2023.12') def capabilities(self) -> Capabilities: flags = get_array_api_strict_flags() - res = {"boolean indexing": flags['boolean_indexing'], - "data-dependent shapes": flags['data_dependent_shapes'], - } + res: Capabilities = { # type: ignore[typeddict-item] + "boolean indexing": flags['boolean_indexing'], + "data-dependent shapes": flags['data_dependent_shapes'], + } if flags['api_version'] >= '2024.12': # maxdims is 32 for NumPy 1.x and 64 for NumPy 2.0. Eventually we will # drop support for NumPy 1 but for now, just compute the number @@ -36,104 +33,104 @@ def capabilities(self) -> Capabilities: return res @requires_api_version('2023.12') - def default_device(self) -> device: + def default_device(self) -> Device: return CPU_DEVICE @requires_api_version('2023.12') def default_dtypes( self, *, - device: Optional[device] = None, + device: Device | None = None, ) -> DefaultDataTypes: return { - "real floating": float64, - "complex floating": complex128, - "integral": int64, - "indexing": int64, + "real floating": dt.float64, + "complex floating": dt.complex128, + "integral": dt.int64, + "indexing": dt.int64, } @requires_api_version('2023.12') def dtypes( self, *, - device: Optional[device] = None, - kind: Optional[Union[str, Tuple[str, ...]]] = None, + device: Device | None = None, + kind: str | tuple[str, ...] | None = None, ) -> DataTypes: if kind is None: return { - "bool": bool, - "int8": int8, - "int16": int16, - "int32": int32, - "int64": int64, - "uint8": uint8, - "uint16": uint16, - "uint32": uint32, - "uint64": uint64, - "float32": float32, - "float64": float64, - "complex64": complex64, - "complex128": complex128, + "bool": dt.bool, + "int8": dt.int8, + "int16": dt.int16, + "int32": dt.int32, + "int64": dt.int64, + "uint8": dt.uint8, + "uint16": dt.uint16, + "uint32": dt.uint32, + "uint64": dt.uint64, + "float32": dt.float32, + "float64": dt.float64, + "complex64": dt.complex64, + "complex128": dt.complex128, } if kind == "bool": - return {"bool": bool} + return {"bool": dt.bool} if kind == "signed integer": return { - "int8": int8, - "int16": int16, - "int32": int32, - "int64": int64, + "int8": dt.int8, + "int16": dt.int16, + "int32": dt.int32, + "int64": dt.int64, } if kind == "unsigned integer": return { - "uint8": uint8, - "uint16": uint16, - "uint32": uint32, - "uint64": uint64, + "uint8": dt.uint8, + "uint16": dt.uint16, + "uint32": dt.uint32, + "uint64": dt.uint64, } if kind == "integral": return { - "int8": int8, - "int16": int16, - "int32": int32, - "int64": int64, - "uint8": uint8, - "uint16": uint16, - "uint32": uint32, - "uint64": uint64, + "int8": dt.int8, + "int16": dt.int16, + "int32": dt.int32, + "int64": dt.int64, + "uint8": dt.uint8, + "uint16": dt.uint16, + "uint32": dt.uint32, + "uint64": dt.uint64, } if kind == "real floating": return { - "float32": float32, - "float64": float64, + "float32": dt.float32, + "float64": dt.float64, } if kind == "complex floating": return { - "complex64": complex64, - "complex128": complex128, + "complex64": dt.complex64, + "complex128": dt.complex128, } if kind == "numeric": return { - "int8": int8, - "int16": int16, - "int32": int32, - "int64": int64, - "uint8": uint8, - "uint16": uint16, - "uint32": uint32, - "uint64": uint64, - "float32": float32, - "float64": float64, - "complex64": complex64, - "complex128": complex128, + "int8": dt.int8, + "int16": dt.int16, + "int32": dt.int32, + "int64": dt.int64, + "uint8": dt.uint8, + "uint16": dt.uint16, + "uint32": dt.uint32, + "uint64": dt.uint64, + "float32": dt.float32, + "float64": dt.float64, + "complex64": dt.complex64, + "complex128": dt.complex128, } if isinstance(kind, tuple): - res = {} + res: DataTypes = {} for k in kind: res.update(self.dtypes(kind=k)) return res raise ValueError(f"unsupported kind: {kind!r}") @requires_api_version('2023.12') - def devices(self) -> List[device]: + def devices(self) -> list[Device]: return list(ALL_DEVICES) diff --git a/array_api_strict/_linalg.py b/array_api_strict/_linalg.py index 7d379a0..27a2ddf 100644 --- a/array_api_strict/_linalg.py +++ b/array_api_strict/_linalg.py @@ -1,33 +1,25 @@ from __future__ import annotations +from collections.abc import Sequence from functools import partial +from typing import Literal, NamedTuple -from ._dtypes import ( - _floating_dtypes, - _numeric_dtypes, - float32, - complex64, - complex128, -) +import numpy as np +import numpy.linalg + +from ._array_object import Array from ._data_type_functions import finfo -from ._manipulation_functions import reshape +from ._dtypes import DType, _floating_dtypes, _numeric_dtypes, complex64, complex128 from ._elementwise_functions import conj -from ._array_object import Array -from ._flags import requires_extension, get_array_api_strict_flags +from ._flags import get_array_api_strict_flags, requires_extension +from ._manipulation_functions import reshape +from ._statistical_functions import _np_dtype_sumprod try: - from numpy._core.numeric import normalize_axis_tuple + from numpy._core.numeric import normalize_axis_tuple # type: ignore[attr-defined] except ImportError: - from numpy.core.numeric import normalize_axis_tuple + from numpy.core.numeric import normalize_axis_tuple # type: ignore[no-redef] -from typing import TYPE_CHECKING -if TYPE_CHECKING: - from ._typing import Literal, Optional, Sequence, Tuple, Union, Dtype - -from typing import NamedTuple - -import numpy.linalg -import numpy as np class EighResult(NamedTuple): eigenvalues: Array @@ -175,7 +167,13 @@ def inv(x: Array, /) -> Array: # -np.inf, 'fro', 'nuc']]], but Literal does not support floating-point # literals. @requires_extension('linalg') -def matrix_norm(x: Array, /, *, keepdims: bool = False, ord: Optional[Union[int, float, Literal['fro', 'nuc']]] = 'fro') -> Array: # noqa: F821 +def matrix_norm( + x: Array, + /, + *, + keepdims: bool = False, + ord: float | Literal["fro", "nuc"] | None = "fro", +) -> Array: # noqa: F821 """ Array API compatible wrapper for :py:func:`np.linalg.norm `. @@ -186,7 +184,10 @@ def matrix_norm(x: Array, /, *, keepdims: bool = False, ord: Optional[Union[int, if x.dtype not in _floating_dtypes: raise TypeError('Only floating-point dtypes are allowed in matrix_norm') - return Array._new(np.linalg.norm(x._array, axis=(-2, -1), keepdims=keepdims, ord=ord), device=x.device) + return Array._new( + np.linalg.norm(x._array, axis=(-2, -1), keepdims=keepdims, ord=ord), + device=x.device, + ) @requires_extension('linalg') @@ -206,7 +207,7 @@ def matrix_power(x: Array, n: int, /) -> Array: # Note: the keyword argument name rtol is different from np.linalg.matrix_rank @requires_extension('linalg') -def matrix_rank(x: Array, /, *, rtol: Optional[Union[float, Array]] = None) -> Array: +def matrix_rank(x: Array, /, *, rtol: float | Array | None = None) -> Array: """ Array API compatible wrapper for :py:func:`np.matrix_rank `. @@ -218,13 +219,12 @@ def matrix_rank(x: Array, /, *, rtol: Optional[Union[float, Array]] = None) -> A raise np.linalg.LinAlgError("1-dimensional array given. Array must be at least two-dimensional") S = np.linalg.svd(x._array, compute_uv=False) if rtol is None: - tol = S.max(axis=-1, keepdims=True) * max(x.shape[-2:]) * finfo(S.dtype).eps + tol = S.max(axis=-1, keepdims=True) * max(x.shape[-2:]) * np.finfo(S.dtype).eps else: - if isinstance(rtol, Array): - rtol = rtol._array + rtol_np = rtol._array if isinstance(rtol, Array) else np.asarray(rtol) # Note: this is different from np.linalg.matrix_rank, which does not multiply # the tolerance by the largest singular value. - tol = S.max(axis=-1, keepdims=True)*np.asarray(rtol)[..., np.newaxis] + tol = S.max(axis=-1, keepdims=True) * rtol_np[..., np.newaxis] return Array._new(np.count_nonzero(S > tol, axis=-1), device=x.device) @@ -252,7 +252,7 @@ def outer(x1: Array, x2: Array, /) -> Array: # Note: the keyword argument name rtol is different from np.linalg.pinv @requires_extension('linalg') -def pinv(x: Array, /, *, rtol: Optional[Union[float, Array]] = None) -> Array: +def pinv(x: Array, /, *, rtol: float | Array | None = None) -> Array: """ Array API compatible wrapper for :py:func:`np.linalg.pinv `. @@ -267,9 +267,8 @@ def pinv(x: Array, /, *, rtol: Optional[Union[float, Array]] = None) -> Array: # default tolerance by max(M, N). if rtol is None: rtol = max(x.shape[-2:]) * finfo(x.dtype).eps - if isinstance(rtol, Array): - rtol = rtol._array - return Array._new(np.linalg.pinv(x._array, rcond=rtol), device=x.device) + rtol_np = rtol._array if isinstance(rtol, Array) else rtol + return Array._new(np.linalg.pinv(x._array, rcond=rtol_np), device=x.device) @requires_extension('linalg') def qr(x: Array, /, *, mode: Literal['reduced', 'complete'] = 'reduced') -> QRResult: # noqa: F821 @@ -312,14 +311,14 @@ def slogdet(x: Array, /) -> SlogdetResult: # To workaround this, the below is the code from np.linalg.solve except # only calling solve1 in the exactly 1D case. -def _solve(a, b): +def _solve(a: np.ndarray, b: np.ndarray) -> np.ndarray: try: - from numpy.linalg._linalg import ( + from numpy.linalg._linalg import ( # type: ignore[attr-defined] _makearray, _assert_stacked_2d, _assert_stacked_square, _commonType, isComplexType, _raise_linalgerror_singular ) except ImportError: - from numpy.linalg.linalg import ( + from numpy.linalg.linalg import ( # type: ignore[attr-defined] _makearray, _assert_stacked_2d, _assert_stacked_square, _commonType, isComplexType, _raise_linalgerror_singular ) @@ -382,14 +381,14 @@ def svd(x: Array, /, *, full_matrices: bool = True) -> SVDResult: # Note: svdvals is not in NumPy (but it is in SciPy). It is equivalent to # np.linalg.svd(compute_uv=False). @requires_extension('linalg') -def svdvals(x: Array, /) -> Union[Array, Tuple[Array, ...]]: +def svdvals(x: Array, /) -> Array: if x.dtype not in _floating_dtypes: raise TypeError('Only floating-point dtypes are allowed in svdvals') return Array._new(np.linalg.svd(x._array, compute_uv=False), device=x.device) # Note: trace is the numpy top-level namespace, not np.linalg @requires_extension('linalg') -def trace(x: Array, /, *, offset: int = 0, dtype: Optional[Dtype] = None) -> Array: +def trace(x: Array, /, *, offset: int = 0, dtype: DType | None = None) -> Array: """ Array API compatible wrapper for :py:func:`np.trace `. @@ -398,19 +397,13 @@ def trace(x: Array, /, *, offset: int = 0, dtype: Optional[Dtype] = None) -> Arr if x.dtype not in _numeric_dtypes: raise TypeError('Only numeric dtypes are allowed in trace') - # Note: trace() works the same as sum() and prod() (see - # _statistical_functions.py) - if dtype is None: - if get_array_api_strict_flags()['api_version'] < '2023.12': - if x.dtype == float32: - dtype = np.float64 - elif x.dtype == complex64: - dtype = np.complex128 - else: - dtype = dtype._np_dtype + # Note: trace() works the same as sum() and prod() (see _statistical_functions.py) + np_dtype = _np_dtype_sumprod(x, dtype) + # Note: trace always operates on the last two axes, whereas np.trace # operates on the first two axes by default - return Array._new(np.asarray(np.trace(x._array, offset=offset, axis1=-2, axis2=-1, dtype=dtype)), device=x.device) + res = np.trace(x._array, offset=offset, axis1=-2, axis2=-1, dtype=np_dtype) + return Array._new(np.asarray(res), device=x.device) # Note: the name here is different from norm(). The array API norm is split # into matrix_norm and vector_norm(). @@ -418,7 +411,14 @@ def trace(x: Array, /, *, offset: int = 0, dtype: Optional[Dtype] = None) -> Arr # The type for ord should be Optional[Union[int, float, Literal[np.inf, # -np.inf]]] but Literal does not support floating-point literals. @requires_extension('linalg') -def vector_norm(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False, ord: Optional[Union[int, float]] = 2) -> Array: +def vector_norm( + x: Array, + /, + *, + axis: int | tuple[int, ...] | None = None, + keepdims: bool = False, + ord: int | float = 2, +) -> Array: """ Array API compatible wrapper for :py:func:`np.linalg.norm `. @@ -456,8 +456,8 @@ def vector_norm(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = No # We can't reuse np.linalg.norm(keepdims) because of the reshape hacks # above to avoid matrix norm logic. shape = list(x.shape) - _axis = normalize_axis_tuple(range(x.ndim) if axis is None else axis, x.ndim) - for i in _axis: + axis_tup = normalize_axis_tuple(range(x.ndim) if axis is None else axis, x.ndim) + for i in axis_tup: shape[i] = 1 res = reshape(res, tuple(shape)) @@ -480,7 +480,13 @@ def matmul(x1: Array, x2: Array, /) -> Array: # Note: tensordot is the numpy top-level namespace but not in np.linalg @requires_extension('linalg') -def tensordot(x1: Array, x2: Array, /, *, axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2) -> Array: +def tensordot( + x1: Array, + x2: Array, + /, + *, + axes: int | tuple[Sequence[int], Sequence[int]] = 2, +) -> Array: from ._linear_algebra_functions import tensordot return tensordot(x1, x2, axes=axes) diff --git a/array_api_strict/_linear_algebra_functions.py b/array_api_strict/_linear_algebra_functions.py index 6af2a15..d18214c 100644 --- a/array_api_strict/_linear_algebra_functions.py +++ b/array_api_strict/_linear_algebra_functions.py @@ -7,16 +7,15 @@ from __future__ import annotations -from ._dtypes import _numeric_dtypes +from collections.abc import Sequence + +import numpy as np +import numpy.linalg + from ._array_object import Array +from ._dtypes import _numeric_dtypes from ._flags import get_array_api_strict_flags -from typing import TYPE_CHECKING -if TYPE_CHECKING: - from ._typing import Sequence, Tuple, Union - -import numpy.linalg -import numpy as np # Note: matmul is the numpy top-level namespace but not in np.linalg def matmul(x1: Array, x2: Array, /) -> Array: @@ -38,7 +37,13 @@ def matmul(x1: Array, x2: Array, /) -> Array: # Note: tensordot is the numpy top-level namespace but not in np.linalg # Note: axes must be a tuple, unlike np.tensordot where it can be an array or array-like. -def tensordot(x1: Array, x2: Array, /, *, axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2) -> Array: +def tensordot( + x1: Array, + x2: Array, + /, + *, + axes: int | tuple[Sequence[int], Sequence[int]] = 2, +) -> Array: # Note: the restriction to numeric dtypes only is different from # np.tensordot. if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: diff --git a/array_api_strict/_manipulation_functions.py b/array_api_strict/_manipulation_functions.py index 63c3516..e2fd24c 100644 --- a/array_api_strict/_manipulation_functions.py +++ b/array_api_strict/_manipulation_functions.py @@ -1,21 +1,17 @@ from __future__ import annotations +import numpy as np + from ._array_object import Array from ._creation_functions import asarray from ._data_type_functions import astype, result_type from ._dtypes import _integer_dtypes, int64, uint64 -from ._flags import requires_api_version, get_array_api_strict_flags - -from typing import TYPE_CHECKING +from ._flags import get_array_api_strict_flags, requires_api_version -if TYPE_CHECKING: - from typing import List, Optional, Tuple, Union - -import numpy as np # Note: the function name is different here def concat( - arrays: Union[Tuple[Array, ...], List[Array]], /, *, axis: Optional[int] = 0 + arrays: tuple[Array, ...] | list[Array], /, *, axis: int | None = 0 ) -> Array: """ Array API compatible wrapper for :py:func:`np.concatenate `. @@ -29,8 +25,11 @@ def concat( raise ValueError("concat inputs must all be on the same device") result_device = arrays[0].device - arrays = tuple(a._array for a in arrays) - return Array._new(np.concatenate(arrays, axis=axis, dtype=dtype._np_dtype), device=result_device) + np_arrays = tuple(a._array for a in arrays) + return Array._new( + np.concatenate(np_arrays, axis=axis, dtype=dtype._np_dtype), + device=result_device, + ) def expand_dims(x: Array, /, *, axis: int) -> Array: @@ -42,7 +41,7 @@ def expand_dims(x: Array, /, *, axis: int) -> Array: return Array._new(np.expand_dims(x._array, axis), device=x.device) -def flip(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Array: +def flip(x: Array, /, *, axis: int | tuple[int, ...] | None = None) -> Array: """ Array API compatible wrapper for :py:func:`np.flip `. @@ -53,8 +52,8 @@ def flip(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> @requires_api_version('2023.12') def moveaxis( x: Array, - source: Union[int, Tuple[int, ...]], - destination: Union[int, Tuple[int, ...]], + source: int | tuple[int, ...], + destination: int | tuple[int, ...], /, ) -> Array: """ @@ -66,7 +65,7 @@ def moveaxis( # Note: The function name is different here (see also matrix_transpose). # Unlike transpose(), the axes argument is required. -def permute_dims(x: Array, /, axes: Tuple[int, ...]) -> Array: +def permute_dims(x: Array, /, axes: tuple[int, ...]) -> Array: """ Array API compatible wrapper for :py:func:`np.transpose `. @@ -77,10 +76,10 @@ def permute_dims(x: Array, /, axes: Tuple[int, ...]) -> Array: @requires_api_version('2023.12') def repeat( x: Array, - repeats: Union[int, Array], + repeats: int | Array, /, *, - axis: Optional[int] = None, + axis: int | None = None, ) -> Array: """ Array API compatible wrapper for :py:func:`np.repeat `. @@ -108,12 +107,9 @@ def repeat( repeats = astype(repeats, int64) return Array._new(np.repeat(x._array, repeats._array, axis=axis), device=x.device) + # Note: the optional argument is called 'shape', not 'newshape' -def reshape(x: Array, - /, - shape: Tuple[int, ...], - *, - copy: Optional[bool] = None) -> Array: +def reshape(x: Array, /, shape: tuple[int, ...], *, copy: bool | None = None) -> Array: """ Array API compatible wrapper for :py:func:`np.reshape `. @@ -135,9 +131,9 @@ def reshape(x: Array, def roll( x: Array, /, - shift: Union[int, Tuple[int, ...]], + shift: int | tuple[int, ...], *, - axis: Optional[Union[int, Tuple[int, ...]]] = None, + axis: int | tuple[int, ...] | None = None, ) -> Array: """ Array API compatible wrapper for :py:func:`np.roll `. @@ -147,7 +143,7 @@ def roll( return Array._new(np.roll(x._array, shift, axis=axis), device=x.device) -def squeeze(x: Array, /, axis: Union[int, Tuple[int, ...]]) -> Array: +def squeeze(x: Array, /, axis: int | tuple[int, ...]) -> Array: """ Array API compatible wrapper for :py:func:`np.squeeze `. @@ -161,7 +157,7 @@ def squeeze(x: Array, /, axis: Union[int, Tuple[int, ...]]) -> Array: return Array._new(np.squeeze(x._array, axis=axis), device=x.device) -def stack(arrays: Union[Tuple[Array, ...], List[Array]], /, *, axis: int = 0) -> Array: +def stack(arrays: tuple[Array, ...] | list[Array], /, *, axis: int = 0) -> Array: """ Array API compatible wrapper for :py:func:`np.stack `. @@ -172,12 +168,12 @@ def stack(arrays: Union[Tuple[Array, ...], List[Array]], /, *, axis: int = 0) -> if len({a.device for a in arrays}) > 1: raise ValueError("concat inputs must all be on the same device") result_device = arrays[0].device - arrays = tuple(a._array for a in arrays) - return Array._new(np.stack(arrays, axis=axis), device=result_device) + np_arrays = tuple(a._array for a in arrays) + return Array._new(np.stack(np_arrays, axis=axis), device=result_device) @requires_api_version('2023.12') -def tile(x: Array, repetitions: Tuple[int, ...], /) -> Array: +def tile(x: Array, repetitions: tuple[int, ...], /) -> Array: """ Array API compatible wrapper for :py:func:`np.tile `. @@ -190,7 +186,7 @@ def tile(x: Array, repetitions: Tuple[int, ...], /) -> Array: # Note: this function is new @requires_api_version('2023.12') -def unstack(x: Array, /, *, axis: int = 0) -> Tuple[Array, ...]: +def unstack(x: Array, /, *, axis: int = 0) -> tuple[Array, ...]: if not (-x.ndim <= axis < x.ndim): raise ValueError("axis out of range") diff --git a/array_api_strict/_searching_functions.py b/array_api_strict/_searching_functions.py index 9864132..b366ed9 100644 --- a/array_api_strict/_searching_functions.py +++ b/array_api_strict/_searching_functions.py @@ -1,18 +1,17 @@ from __future__ import annotations -from ._array_object import Array -from ._dtypes import _result_type, _real_numeric_dtypes, bool as _bool -from ._flags import requires_data_dependent_shapes, requires_api_version, get_array_api_strict_flags -from ._helpers import _maybe_normalize_py_scalars - -from typing import TYPE_CHECKING -if TYPE_CHECKING: - from typing import Literal, Optional, Tuple, Union +from typing import Literal import numpy as np +from ._array_object import Array +from ._dtypes import _real_numeric_dtypes, _result_type +from ._dtypes import bool as _bool +from ._flags import requires_api_version, requires_data_dependent_shapes +from ._helpers import _maybe_normalize_py_scalars + -def argmax(x: Array, /, *, axis: Optional[int] = None, keepdims: bool = False) -> Array: +def argmax(x: Array, /, *, axis: int | None = None, keepdims: bool = False) -> Array: """ Array API compatible wrapper for :py:func:`np.argmax `. @@ -23,7 +22,7 @@ def argmax(x: Array, /, *, axis: Optional[int] = None, keepdims: bool = False) - return Array._new(np.asarray(np.argmax(x._array, axis=axis, keepdims=keepdims)), device=x.device) -def argmin(x: Array, /, *, axis: Optional[int] = None, keepdims: bool = False) -> Array: +def argmin(x: Array, /, *, axis: int | None = None, keepdims: bool = False) -> Array: """ Array API compatible wrapper for :py:func:`np.argmin `. @@ -35,7 +34,7 @@ def argmin(x: Array, /, *, axis: Optional[int] = None, keepdims: bool = False) - @requires_data_dependent_shapes -def nonzero(x: Array, /) -> Tuple[Array, ...]: +def nonzero(x: Array, /) -> tuple[Array, ...]: """ Array API compatible wrapper for :py:func:`np.nonzero `. @@ -52,7 +51,7 @@ def count_nonzero( x: Array, /, *, - axis: Optional[Union[int, Tuple[int, ...]]] = None, + axis: int | tuple[int, ...] | None = None, keepdims: bool = False, ) -> Array: """ @@ -71,7 +70,7 @@ def searchsorted( /, *, side: Literal["left", "right"] = "left", - sorter: Optional[Array] = None, + sorter: Array | None = None, ) -> Array: """ Array API compatible wrapper for :py:func:`np.searchsorted `. @@ -84,25 +83,29 @@ def searchsorted( if x1.device != x2.device: raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") - sorter = sorter._array if sorter is not None else None + np_sorter = sorter._array if sorter is not None else None # TODO: The sort order of nans and signed zeros is implementation # dependent. Should we error/warn if they are present? # x1 must be 1-D, but NumPy already requires this. - return Array._new(np.searchsorted(x1._array, x2._array, side=side, sorter=sorter), device=x1.device) + return Array._new( + np.searchsorted(x1._array, x2._array, side=side, sorter=np_sorter), + device=x1.device, + ) + def where( condition: Array, - x1: bool | int | float | complex | Array, - x2: bool | int | float | complex | Array, / + x1: Array | bool | int | float | complex, + x2: Array | bool | int | float | complex, + /, ) -> Array: """ Array API compatible wrapper for :py:func:`np.where `. See its docstring for more information. """ - if get_array_api_strict_flags()['api_version'] > '2023.12': - x1, x2 = _maybe_normalize_py_scalars(x1, x2, "all", "where") + x1, x2 = _maybe_normalize_py_scalars(x1, x2, "all", "where") # Call result type here just to raise on disallowed type combinations _result_type(x1.dtype, x2.dtype) diff --git a/array_api_strict/_set_functions.py b/array_api_strict/_set_functions.py index 7bd5bad..e677a52 100644 --- a/array_api_strict/_set_functions.py +++ b/array_api_strict/_set_functions.py @@ -1,13 +1,12 @@ from __future__ import annotations -from ._array_object import Array - -from ._flags import requires_data_dependent_shapes - from typing import NamedTuple import numpy as np +from ._array_object import Array +from ._flags import requires_data_dependent_shapes + # Note: np.unique() is split into four functions in the array API: # unique_all, unique_counts, unique_inverse, and unique_values (this is done # to remove polymorphic return types). @@ -20,6 +19,7 @@ # Note: The functions here return a namedtuple (np.unique() returns a normal # tuple). + class UniqueAllResult(NamedTuple): values: Array indices: Array diff --git a/array_api_strict/_sorting_functions.py b/array_api_strict/_sorting_functions.py index 765bd9e..e9193f1 100644 --- a/array_api_strict/_sorting_functions.py +++ b/array_api_strict/_sorting_functions.py @@ -1,10 +1,12 @@ from __future__ import annotations -from ._array_object import Array -from ._dtypes import _real_numeric_dtypes +from typing import Literal import numpy as np +from ._array_object import Array +from ._dtypes import _real_numeric_dtypes + # Note: the descending keyword argument is new in this function def argsort( @@ -18,7 +20,7 @@ def argsort( if x.dtype not in _real_numeric_dtypes: raise TypeError("Only real numeric dtypes are allowed in argsort") # Note: this keyword argument is different, and the default is different. - kind = "stable" if stable else "quicksort" + kind: Literal["stable", "quicksort"] = "stable" if stable else "quicksort" if not descending: res = np.argsort(x._array, axis=axis, kind=kind) else: @@ -35,6 +37,7 @@ def argsort( res = max_i - res return Array._new(res, device=x.device) + # Note: the descending keyword argument is new in this function def sort( x: Array, /, *, axis: int = -1, descending: bool = False, stable: bool = True @@ -47,8 +50,7 @@ def sort( if x.dtype not in _real_numeric_dtypes: raise TypeError("Only real numeric dtypes are allowed in sort") # Note: this keyword argument is different, and the default is different. - kind = "stable" if stable else "quicksort" - res = np.sort(x._array, axis=axis, kind=kind) + res = np.sort(x._array, axis=axis, kind="stable" if stable else "quicksort") if descending: res = np.flip(res, axis=axis) return Array._new(res, device=x.device) diff --git a/array_api_strict/_statistical_functions.py b/array_api_strict/_statistical_functions.py index e41e7ef..668cd02 100644 --- a/array_api_strict/_statistical_functions.py +++ b/array_api_strict/_statistical_functions.py @@ -1,38 +1,36 @@ from __future__ import annotations +from typing import Any + +import numpy as np + +from ._array_object import Array +from ._creation_functions import ones, zeros from ._dtypes import ( - _real_floating_dtypes, - _real_numeric_dtypes, + DType, _floating_dtypes, + _np_dtype, _numeric_dtypes, + _real_floating_dtypes, + _real_numeric_dtypes, + complex64, + float32, ) -from ._array_object import Array -from ._dtypes import float32, complex64 -from ._flags import requires_api_version, get_array_api_strict_flags -from ._creation_functions import zeros, ones +from ._flags import get_array_api_strict_flags, requires_api_version from ._manipulation_functions import concat -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from typing import Optional, Tuple, Union - from ._typing import Dtype - -import numpy as np @requires_api_version('2023.12') def cumulative_sum( x: Array, /, *, - axis: Optional[int] = None, - dtype: Optional[Dtype] = None, + axis: int | None = None, + dtype: DType | None = None, include_initial: bool = False, ) -> Array: if x.dtype not in _numeric_dtypes: raise TypeError("Only numeric dtypes are allowed in cumulative_sum") - if dtype is not None: - dtype = dtype._np_dtype # TODO: The standard is not clear about what should happen when x.ndim == 0. if axis is None: @@ -44,7 +42,7 @@ def cumulative_sum( if axis < 0: axis += x.ndim x = concat([zeros(x.shape[:axis] + (1,) + x.shape[axis + 1:], dtype=x.dtype), x], axis=axis) - return Array._new(np.cumsum(x._array, axis=axis, dtype=dtype), device=x.device) + return Array._new(np.cumsum(x._array, axis=axis, dtype=_np_dtype(dtype)), device=x.device) @requires_api_version('2024.12') @@ -52,8 +50,8 @@ def cumulative_prod( x: Array, /, *, - axis: Optional[int] = None, - dtype: Optional[Dtype] = None, + axis: int | None = None, + dtype: DType | None = None, include_initial: bool = False, ) -> Array: if x.dtype not in _numeric_dtypes: @@ -61,9 +59,6 @@ def cumulative_prod( if x.ndim == 0: raise ValueError("Only ndim >= 1 arrays are allowed in cumulative_prod") - if dtype is not None: - dtype = dtype._np_dtype - if axis is None: if x.ndim > 1: raise ValueError("axis must be specified in cumulative_prod for more than one dimension") @@ -74,14 +69,14 @@ def cumulative_prod( if axis < 0: axis += x.ndim x = concat([ones(x.shape[:axis] + (1,) + x.shape[axis + 1:], dtype=x.dtype), x], axis=axis) - return Array._new(np.cumprod(x._array, axis=axis, dtype=dtype), device=x.device) + return Array._new(np.cumprod(x._array, axis=axis, dtype=_np_dtype(dtype)), device=x.device) def max( x: Array, /, *, - axis: Optional[Union[int, Tuple[int, ...]]] = None, + axis: int | tuple[int, ...] | None = None, keepdims: bool = False, ) -> Array: if x.dtype not in _real_numeric_dtypes: @@ -93,14 +88,15 @@ def mean( x: Array, /, *, - axis: Optional[Union[int, Tuple[int, ...]]] = None, + axis: int | tuple[int, ...] | None = None, keepdims: bool = False, ) -> Array: - if get_array_api_strict_flags()['api_version'] > '2023.12': - allowed_dtypes = _floating_dtypes - else: - allowed_dtypes = _real_floating_dtypes + allowed_dtypes = ( + _floating_dtypes + if get_array_api_strict_flags()['api_version'] > '2023.12' + else _real_floating_dtypes + ) if x.dtype not in allowed_dtypes: raise TypeError("Only floating-point dtypes are allowed in mean") @@ -111,7 +107,7 @@ def min( x: Array, /, *, - axis: Optional[Union[int, Tuple[int, ...]]] = None, + axis: int | tuple[int, ...] | None = None, keepdims: bool = False, ) -> Array: if x.dtype not in _real_numeric_dtypes: @@ -119,37 +115,43 @@ def min( return Array._new(np.min(x._array, axis=axis, keepdims=keepdims), device=x.device) +def _np_dtype_sumprod(x: Array, dtype: DType | None) -> np.dtype[Any] | None: + """In versions prior to 2023.12, sum() and prod() upcast for all + dtypes when dtype=None. For 2023.12, the behavior is the same as in + NumPy (only upcast for integral dtypes). + """ + if dtype is None and get_array_api_strict_flags()['api_version'] < '2023.12': + if x.dtype == float32: + return np.float64 # type: ignore[return-value] + elif x.dtype == complex64: + return np.complex128 # type: ignore[return-value] + return _np_dtype(dtype) + + def prod( x: Array, /, *, - axis: Optional[Union[int, Tuple[int, ...]]] = None, - dtype: Optional[Dtype] = None, + axis: int | tuple[int, ...] | None = None, + dtype: DType | None = None, keepdims: bool = False, ) -> Array: if x.dtype not in _numeric_dtypes: raise TypeError("Only numeric dtypes are allowed in prod") - if dtype is None: - # Note: In versions prior to 2023.12, sum() and prod() upcast for all - # dtypes when dtype=None. For 2023.12, the behavior is the same as in - # NumPy (only upcast for integral dtypes). - if get_array_api_strict_flags()['api_version'] < '2023.12': - if x.dtype == float32: - dtype = np.float64 - elif x.dtype == complex64: - dtype = np.complex128 - else: - dtype = dtype._np_dtype - return Array._new(np.prod(x._array, dtype=dtype, axis=axis, keepdims=keepdims), device=x.device) + np_dtype = _np_dtype_sumprod(x, dtype) + return Array._new( + np.prod(x._array, dtype=np_dtype, axis=axis, keepdims=keepdims), + device=x.device, + ) def std( x: Array, /, *, - axis: Optional[Union[int, Tuple[int, ...]]] = None, - correction: Union[int, float] = 0.0, + axis: int | tuple[int, ...] | None = None, + correction: int | float = 0.0, keepdims: bool = False, ) -> Array: # Note: the keyword argument correction is different here @@ -162,33 +164,26 @@ def sum( x: Array, /, *, - axis: Optional[Union[int, Tuple[int, ...]]] = None, - dtype: Optional[Dtype] = None, + axis: int | tuple[int, ...] | None = None, + dtype: DType | None = None, keepdims: bool = False, ) -> Array: if x.dtype not in _numeric_dtypes: raise TypeError("Only numeric dtypes are allowed in sum") - if dtype is None: - # Note: In versions prior to 2023.12, sum() and prod() upcast for all - # dtypes when dtype=None. For 2023.12, the behavior is the same as in - # NumPy (only upcast for integral dtypes). - if get_array_api_strict_flags()['api_version'] < '2023.12': - if x.dtype == float32: - dtype = np.float64 - elif x.dtype == complex64: - dtype = np.complex128 - else: - dtype = dtype._np_dtype - return Array._new(np.sum(x._array, axis=axis, dtype=dtype, keepdims=keepdims), device=x.device) + np_dtype = _np_dtype_sumprod(x, dtype) + return Array._new( + np.sum(x._array, axis=axis, dtype=np_dtype, keepdims=keepdims), + device=x.device, + ) def var( x: Array, /, *, - axis: Optional[Union[int, Tuple[int, ...]]] = None, - correction: Union[int, float] = 0.0, + axis: int | tuple[int, ...] | None = None, + correction: int | float = 0.0, keepdims: bool = False, ) -> Array: # Note: the keyword argument correction is different here diff --git a/array_api_strict/_typing.py b/array_api_strict/_typing.py index 94c4975..91095a8 100644 --- a/array_api_strict/_typing.py +++ b/array_api_strict/_typing.py @@ -8,41 +8,19 @@ from __future__ import annotations -__all__ = [ - "Array", - "Device", - "Dtype", - "SupportsDLPack", - "SupportsBufferProtocol", - "PyCapsule", -] - import sys +from typing import Any, Protocol, TypedDict, TypeVar -from typing import ( - Any, - TypedDict, - TypeVar, - Protocol, -) - -from ._array_object import Array, _device -from ._dtypes import _DType -from ._info import __array_namespace_info__ +from ._dtypes import DType _T_co = TypeVar("_T_co", covariant=True) + class NestedSequence(Protocol[_T_co]): def __getitem__(self, key: int, /) -> _T_co | NestedSequence[_T_co]: ... def __len__(self, /) -> int: ... -Device = _device - -Dtype = _DType - -Info = __array_namespace_info__ - if sys.version_info >= (3, 12): from collections.abc import Buffer as SupportsBufferProtocol else: @@ -50,40 +28,42 @@ def __len__(self, /) -> int: ... PyCapsule = Any + class SupportsDLPack(Protocol): def __dlpack__(self, /, *, stream: None = ...) -> PyCapsule: ... + Capabilities = TypedDict( - "Capabilities", {"boolean indexing": bool, "data-dependent shapes": bool, - "max dimensions": int} + "Capabilities", + { + "boolean indexing": bool, + "data-dependent shapes": bool, + "max dimensions": int, + }, ) DefaultDataTypes = TypedDict( "DefaultDataTypes", { - "real floating": Dtype, - "complex floating": Dtype, - "integral": Dtype, - "indexing": Dtype, + "real floating": DType, + "complex floating": DType, + "integral": DType, + "indexing": DType, }, ) -DataTypes = TypedDict( - "DataTypes", - { - "bool": Dtype, - "float32": Dtype, - "float64": Dtype, - "complex64": Dtype, - "complex128": Dtype, - "int8": Dtype, - "int16": Dtype, - "int32": Dtype, - "int64": Dtype, - "uint8": Dtype, - "uint16": Dtype, - "uint32": Dtype, - "uint64": Dtype, - }, - total=False, -) + +class DataTypes(TypedDict, total=False): + bool: DType + float32: DType + float64: DType + complex64: DType + complex128: DType + int8: DType + int16: DType + int32: DType + int64: DType + uint8: DType + uint16: DType + uint32: DType + uint64: DType diff --git a/array_api_strict/_utility_functions.py b/array_api_strict/_utility_functions.py index f75f36f..fab1025 100644 --- a/array_api_strict/_utility_functions.py +++ b/array_api_strict/_utility_functions.py @@ -1,21 +1,20 @@ from __future__ import annotations -from ._array_object import Array -from ._flags import requires_api_version -from ._dtypes import _numeric_dtypes - -from typing import TYPE_CHECKING -if TYPE_CHECKING: - from typing import Optional, Tuple, Union +from typing import Any import numpy as np +import numpy.typing as npt + +from ._array_object import Array +from ._dtypes import _numeric_dtypes +from ._flags import requires_api_version def all( x: Array, /, *, - axis: Optional[Union[int, Tuple[int, ...]]] = None, + axis: int | tuple[int, ...] | None = None, keepdims: bool = False, ) -> Array: """ @@ -30,7 +29,7 @@ def any( x: Array, /, *, - axis: Optional[Union[int, Tuple[int, ...]]] = None, + axis: int | tuple[int, ...] | None = None, keepdims: bool = False, ) -> Array: """ @@ -40,6 +39,7 @@ def any( """ return Array._new(np.asarray(np.any(x._array, axis=axis, keepdims=keepdims)), device=x.device) + @requires_api_version('2024.12') def diff( x: Array, @@ -47,8 +47,8 @@ def diff( *, axis: int = -1, n: int = 1, - prepend: Optional[Array] = None, - append: Optional[Array] = None, + prepend: Array | None = None, + append: Array | None = None, ) -> Array: if x.dtype not in _numeric_dtypes: raise TypeError("Only numeric dtypes are allowed in diff") @@ -57,7 +57,7 @@ def diff( # currently specified. # NumPy does not support prepend=None or append=None - kwargs = dict(axis=axis, n=n) + kwargs: dict[str, int | npt.NDArray[Any]] = {"axis": axis, "n": n} if prepend is not None: if prepend.device != x.device: raise ValueError(f"Arrays from two different devices ({prepend.device} and {x.device}) can not be combined.") diff --git a/array_api_strict/py.typed b/array_api_strict/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/array_api_strict/tests/test_validation.py b/array_api_strict/tests/test_validation.py index bd76ec6..5552e3a 100644 --- a/array_api_strict/tests/test_validation.py +++ b/array_api_strict/tests/test_validation.py @@ -1,11 +1,9 @@ -from typing import Callable - import pytest import array_api_strict as xp -def p(func: Callable, *args, **kwargs): +def p(func, *args, **kwargs): f_sig = ", ".join( [str(a) for a in args] + [f"{k}={v}" for k, v in kwargs.items()] ) diff --git a/pyproject.toml b/pyproject.toml index b3d2594..cf9e6dd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,3 +31,18 @@ Repository = "https://github.com/data-apis/array-api-strict" [tool.setuptools_scm] version_file = "array_api_strict/_version.py" +[tool.mypy] +disallow_incomplete_defs = true +disallow_untyped_decorators = true +disallow_untyped_defs = true +no_implicit_optional = true +show_error_codes = true +warn_redundant_casts = true +warn_unused_ignores = true +warn_unreachable = true +strict_bytes = true +local_partial_types = true + +[[tool.mypy.overrides]] +module = ["*.tests.*"] +disallow_untyped_defs = false