Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TYP: add typing information #8

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 14 additions & 12 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,20 @@ repos:
args: ["--pytest-test-first"]
- id: trailing-whitespace

- repo: https://github.com/pre-commit/pygrep-hooks
rev: "v1.10.0"
hooks:
- id: rst-backticks
- id: rst-directive-colons
- id: rst-inline-touching-normal

- repo: https://github.com/python-jsonschema/check-jsonschema
rev: 0.29.4
hooks:
- id: check-dependabot
- id: check-github-workflows
- id: check-readthedocs

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: "v0.7.3"
hooks:
Expand Down Expand Up @@ -53,18 +67,6 @@ repos:
hooks:
- id: mypy
files: src
additional_dependencies:
- pytest

- repo: https://github.com/pre-commit/pygrep-hooks
rev: v1.10.0
hooks:
- id: rst-directive-colons
# Detect mistake of rst directive not ending with double colon.
- id: rst-inline-touching-normal
# Detect mistake of inline code touching normal text in rst.
- id: text-unicode-replacement-char
# Forbid files which have a UTF-8 Unicode replacement character.

- repo: https://github.com/codespell-project/codespell
rev: "v2.3.0"
Expand Down
30 changes: 30 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ dependencies = [
"array-api-compat>=1.9.1",
"astropy>=7.0",
"numpy>=2.0",
"typing-extensions>=4.12.2",
]
dynamic = ["version"]

Expand Down Expand Up @@ -105,6 +106,30 @@ exclude_lines = [
"@overload",
]

[tool.mypy]
python_version = "3.11"
files = ["quantity"]
strict = true

disallow_incomplete_defs = true
disallow_untyped_defs = false
enable_error_code = ["ignore-without-code", "redundant-expr", "truthy-bool"]
warn_return_any = true
warn_unreachable = true
warn_unused_configs = true

[[tool.mypy.overrides]]
module = ["quantity._dev.*", "quantity.tests.*"]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just a bit of a flyby - cool to see the mypy setup! But this is out of date - test has moved now (and another reason why that is good!)

ignore_errors = true

[[tool.mypy.overrides]]
ignore_missing_imports = true
module = [
"astropy.*",
"array_api_compat.*"
]


[tool.ruff]
exclude=[ # package template provided files.
"setup.py",
Expand Down Expand Up @@ -147,3 +172,8 @@ ignore = [
[tool.ruff.lint.per-file-ignores]
"tests/**" = ["T20"]
"noxfile.py" = ["T20"]

[dependency-groups]
typing = [
"mypy>=1.13.0",
]
17 changes: 17 additions & 0 deletions src/quantity/_array_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
"""Minimal definition of the Array API."""

from __future__ import annotations

from typing import Any, Protocol


class HasArrayNameSpace(Protocol):
"""Minimal defintion of the Array API."""

def __array_namespace__(self) -> Any: ...


class Array(HasArrayNameSpace, Protocol):
"""Minimal defintion of the Array API."""

def __pow__(self, other: Any) -> Array: ...
24 changes: 24 additions & 0 deletions src/quantity/_quantity_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
"""Minimal definition of the Quantity API."""

__all__ = ["Quantity", "ArrayQuantity", "Unit"]

from typing import Protocol, runtime_checkable

from astropy.units import UnitBase as Unit

from ._array_api import Array


@runtime_checkable
class Quantity(Protocol):
"""Minimal definition of the Quantity API."""

value: Array
unit: Unit


@runtime_checkable
class ArrayQuantity(Quantity, Array, Protocol):
"""An array-valued Quantity."""

...
115 changes: 86 additions & 29 deletions src/quantity/_src/core.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,31 @@
"""Quantity."""
# Licensed under a 3-clause BSD style license - see LICENSE.rst

from __future__ import annotations

import operator
from collections.abc import Callable
from dataclasses import dataclass, replace
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any, TypeAlias, TypeVar, Union, cast, overload

import array_api_compat
import astropy.units as u
import numpy as np
from astropy.units import UnitBase as Unit
from astropy.units.quantity_helper import UFUNC_HELPERS

from .api import QuantityArray
from .utils import has_array_namespace

if TYPE_CHECKING:
from typing import Any
from types import NotImplementedType
from typing import Any, Self

from ._array_api import Array
from ._quantity_api import ArrayQuantity, Unit


T = TypeVar("T")

from .api import Unit
from .array_api import Array
Expand All @@ -33,9 +44,9 @@ def get_value_and_unit(
)


def value_in_unit(value, unit):
def value_in_unit(value: Array, unit: Unit) -> Array:
v_value, v_unit = get_value_and_unit(value, default_unit=DIMENSIONLESS)
return v_unit.to(unit, v_value)
return cast(Array, v_unit.to(unit, v_value))


_OP_TO_NP_FUNC = {
Expand All @@ -50,7 +61,12 @@ def value_in_unit(value, unit):
OP_HELPERS = {op: UFUNC_HELPERS[np_func] for op, np_func in _OP_TO_NP_FUNC.items()}


def _make_op(fop, mode):
QuantityOpCallable: TypeAlias = Callable[
["Quantity", Any], Union["Quantity", NotImplementedType]
]


def _make_op(fop: str, mode: str) -> QuantityOpCallable:
assert mode in "fri"
op = fop if mode == "f" else "__" + mode + fop[2:]
helper = OP_HELPERS[fop]
Expand All @@ -70,27 +86,29 @@ def __op__(self, other):
return __op__


def _make_ops(op):
return tuple(_make_op(op, mode) for mode in "fri")
def _make_ops(
op: str,
) -> tuple[QuantityOpCallable, QuantityOpCallable, QuantityOpCallable]:
return (_make_op(op, "f"), _make_op(op, "r"), _make_op(op, "i"))


def _make_comp(comp):
def __comp__(self, other):
def _make_comp(comp: str) -> Callable[[Quantity, Any], Array]:
def _comp_(self: Quantity, other: Any) -> Array | NotImplementedType:
try:
other = value_in_unit(other, self.unit)
except Exception:
return NotImplemented
return getattr(self.value, comp)(other)

return __comp__
return _comp_


def _make_deferred(attr):
def _make_deferred(attr: str) -> Callable[[Quantity], property]:
# Use array_api_compat getter if available (size, device), since
# some array formats provide inconsistent implementations.
attr_getter = getattr(array_api_compat, attr, operator.attrgetter(attr))

def deferred(self):
def deferred(self: Quantity):
return attr_getter(self.value)

return property(deferred)
Expand Down Expand Up @@ -129,32 +147,61 @@ def defer_dimensionless(self):
return defer_dimensionless


def _check_pow_args(exp, mod):
if mod is not None:
return NotImplemented
# -----------------


@overload
def _parse_pow_mod(mod: None, /) -> None: ...


@overload
def _parse_pow_mod(mod: object, /) -> NotImplementedType: ...


def _parse_pow_mod(mod: T, /) -> T | NotImplementedType:
return mod if mod is None else NotImplemented # type: ignore[redundant-expr]


if not isinstance(exp, PYTHON_NUMBER):
# -----------------


@overload
def _check_pow_exp(exp: Array | PYTHON_NUMBER, /) -> PYTHON_NUMBER: ...


@overload
def _check_pow_exp(exp: object, /) -> NotImplementedType: ...


def _check_pow_exp(exp: Any, /) -> PYTHON_NUMBER | NotImplementedType:
out: PYTHON_NUMBER
if isinstance(exp, PYTHON_NUMBER):
out = exp
else:
try:
exp = exp.__complex__()
out = complex(exp)
except Exception:
try:
return exp.__float__()
return float(exp)
except Exception:
return NotImplemented

return exp.real if exp.imag == 0 else exp
return out.real if out.imag == 0 else out


@dataclass(frozen=True, eq=False)
class Quantity:
value: Any
unit: u.UnitBase
value: Array
unit: Unit

def __array_namespace__(self, *, api_version: str | None = None) -> Any:
# TODO: make our own?
del api_version
return np

def _operate(self, other, op, units_helper):
def _operate(
self, other: Any, op: Any, units_helper: Any
) -> Self | NotImplementedType:
if not has_array_namespace(other) and not isinstance(other, PYTHON_NUMBER):
# HACK: unit should take care of this!
if not isinstance(other, u.UnitBase):
Expand Down Expand Up @@ -223,9 +270,11 @@ def _operate(self, other, op, units_helper):

# TODO: __dlpack__, __dlpack_device__

def __pow__(self, exp, mod=None):
exp = _check_pow_args(exp, mod)
if exp is NotImplemented:
def __pow__(self, exp: Any, mod: Any = None) -> Self | NotImplementedType:
if (mod := _parse_pow_mod(mod)) is NotImplemented:
return NotImplemented

if (exp := _check_pow_exp(exp)) is NotImplemented:
return NotImplemented

value = self.value.__pow__(exp)
Expand All @@ -234,17 +283,25 @@ def __pow__(self, exp, mod=None):
return replace(self, value=value, unit=self.unit**exp)

def __ipow__(self, exp, mod=None):
exp = _check_pow_args(exp, mod)
if exp is NotImplemented:
if (mod := _parse_pow_mod(mod)) is NotImplemented:
return NotImplemented

if (exp := _check_pow_exp(exp)) is NotImplemented:
return NotImplemented

value = self.value.__ipow__(exp)
if value is NotImplemented:
return NotImplemented
return replace(self, value=value, unit=self.unit**exp)

def __setitem__(self, item, value):
self.value[item] = value_in_unit(value, self.unit)
def __setitem__(self, item: Any, value: Any) -> None:
"""Call the setitem method of the array for the value in the unit.

The Array API does not guarantee mutability of the underlying array,
so this method will raise an exception if the array is immutable.

"""
self.value[item] = value_in_unit(value, self.unit) # type: ignore[index]

__array_ufunc__ = None
__array_function__ = None
7 changes: 4 additions & 3 deletions src/quantity/_src/utils.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
"""Utility functions for the quantity package."""

from typing import Any, TypeGuard

import array_api_compat


def has_array_namespace(arg: object) -> bool:
def has_array_namespace(arg: Any) -> TypeGuard[Array]:
try:
array_api_compat.array_namespace(arg)
except TypeError:
return False
else:
return True
return True
Empty file added src/quantity/py.typed
Empty file.
3 changes: 3 additions & 0 deletions src/quantity/version.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
# NOTE: First try _dev.scm_version if it exists and setuptools_scm is installed
# This file is not included in wheels/tarballs, so otherwise it will
# fall back on the generated _version module.

__all__ = ['version']

version: str
try:
try:
Expand Down
Loading