Skip to content

Commit 79d4f22

Browse files
committed
WIP
Signed-off-by: nstarman <[email protected]>
1 parent 61011f8 commit 79d4f22

File tree

8 files changed

+184
-50
lines changed

8 files changed

+184
-50
lines changed

.pre-commit-config.yaml

+14-10
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,20 @@ repos:
2424
args: ["--pytest-test-first"]
2525
- id: trailing-whitespace
2626

27+
- repo: https://github.com/pre-commit/pygrep-hooks
28+
rev: "v1.10.0"
29+
hooks:
30+
- id: rst-backticks
31+
- id: rst-directive-colons
32+
- id: rst-inline-touching-normal
33+
34+
- repo: https://github.com/python-jsonschema/check-jsonschema
35+
rev: 0.29.4
36+
hooks:
37+
- id: check-dependabot
38+
- id: check-github-workflows
39+
- id: check-readthedocs
40+
2741
- repo: https://github.com/astral-sh/ruff-pre-commit
2842
rev: "v0.7.3"
2943
hooks:
@@ -56,16 +70,6 @@ repos:
5670
additional_dependencies:
5771
- pytest
5872

59-
- repo: https://github.com/pre-commit/pygrep-hooks
60-
rev: v1.10.0
61-
hooks:
62-
- id: rst-directive-colons
63-
# Detect mistake of rst directive not ending with double colon.
64-
- id: rst-inline-touching-normal
65-
# Detect mistake of inline code touching normal text in rst.
66-
- id: text-unicode-replacement-char
67-
# Forbid files which have a UTF-8 Unicode replacement character.
68-
6973
- repo: https://github.com/codespell-project/codespell
7074
rev: "v2.3.0"
7175
hooks:

pyproject.toml

+30
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ dependencies = [
1111
"array-api-compat>=1.9.1",
1212
"astropy>=7.0",
1313
"numpy>=2.0",
14+
"typing-extensions>=4.12.2",
1415
]
1516
dynamic = ["version"]
1617

@@ -105,6 +106,30 @@ exclude_lines = [
105106
"@overload",
106107
]
107108

109+
[tool.mypy]
110+
python_version = "3.11"
111+
files = ["quantity"]
112+
strict = true
113+
114+
disallow_incomplete_defs = true
115+
disallow_untyped_defs = false
116+
enable_error_code = ["ignore-without-code", "redundant-expr", "truthy-bool"]
117+
warn_return_any = true
118+
warn_unreachable = true
119+
warn_unused_configs = true
120+
121+
[[tool.mypy.overrides]]
122+
module = ["quantity._dev.*", "quantity.tests.*"]
123+
ignore_errors = true
124+
125+
[[tool.mypy.overrides]]
126+
ignore_missing_imports = true
127+
module = [
128+
"astropy.*",
129+
"array_api_compat.*"
130+
]
131+
132+
108133
[tool.ruff]
109134
exclude=[ # package template provided files.
110135
"setup.py",
@@ -146,3 +171,8 @@ ignore = [
146171
[tool.ruff.lint.per-file-ignores]
147172
"tests/**" = ["T20"]
148173
"noxfile.py" = ["T20"]
174+
175+
[dependency-groups]
176+
typing = [
177+
"mypy>=1.13.0",
178+
]

quantity/_array_api.py

+17
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
"""Minimal definition of the Array API."""
2+
3+
from __future__ import annotations
4+
5+
from typing import Any, Protocol
6+
7+
8+
class HasArrayNameSpace(Protocol):
9+
"""Minimal defintion of the Array API."""
10+
11+
def __array_namespace__(self) -> Any: ...
12+
13+
14+
class Array(HasArrayNameSpace, Protocol):
15+
"""Minimal defintion of the Array API."""
16+
17+
def __pow__(self, other: Any) -> Array: ...

quantity/_quantity_api.py

+24
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
"""Minimal definition of the Quantity API."""
2+
3+
__all__ = ["Quantity", "ArrayQuantity", "Unit"]
4+
5+
from typing import Protocol, runtime_checkable
6+
7+
from astropy.units import UnitBase as Unit
8+
9+
from ._array_api import Array
10+
11+
12+
@runtime_checkable
13+
class Quantity(Protocol):
14+
"""Minimal definition of the Quantity API."""
15+
16+
value: Array
17+
unit: Unit
18+
19+
20+
@runtime_checkable
21+
class ArrayQuantity(Quantity, Array, Protocol):
22+
"""An array-valued Quantity."""
23+
24+
...

quantity/py.typed

Whitespace-only changes.

src/quantity/_src/core.py

+92-37
Original file line numberDiff line numberDiff line change
@@ -1,39 +1,48 @@
1+
"""Quantity."""
12
# Licensed under a 3-clause BSD style license - see LICENSE.rst
3+
24
from __future__ import annotations
35

46
import operator
7+
from collections.abc import Callable
58
from dataclasses import dataclass, replace
6-
from typing import TYPE_CHECKING
9+
from typing import TYPE_CHECKING, Any, TypeAlias, TypeVar, Union, cast, overload
710

811
import array_api_compat
912
import astropy.units as u
1013
import numpy as np
14+
from astropy.units import UnitBase as Unit
1115
from astropy.units.quantity_helper import UFUNC_HELPERS
1216

1317
from .utils import has_array_namespace
1418

1519
if TYPE_CHECKING:
16-
from typing import Any
20+
from types import NotImplementedType
21+
from typing import Any, Self
22+
23+
from ._array_api import Array
24+
from ._quantity_api import ArrayQuantity, Unit
25+
26+
27+
T = TypeVar("T")
1728

1829

1930
DIMENSIONLESS = u.dimensionless_unscaled
2031

2132
PYTHON_NUMBER = float | int | complex
2233

2334

24-
def get_value_and_unit(arg, default_unit=None):
25-
# HACK: interoperability with astropy Quantity. Have protocol?
26-
try:
27-
unit = arg.unit
28-
except AttributeError:
29-
return arg, default_unit
30-
else:
31-
return arg.value, unit
35+
def get_value_and_unit(
36+
arg: ArrayQuantity | Array, default_unit: Unit | None = None
37+
) -> tuple[Array, Unit]:
38+
return (
39+
(arg.value, arg.unit) if isinstance(arg, ArrayQuantity) else (arg, default_unit)
40+
)
3241

3342

34-
def value_in_unit(value, unit):
43+
def value_in_unit(value: Array, unit: Unit) -> Array:
3544
v_value, v_unit = get_value_and_unit(value, default_unit=DIMENSIONLESS)
36-
return v_unit.to(unit, v_value)
45+
return cast(Array, v_unit.to(unit, v_value))
3746

3847

3948
_OP_TO_NP_FUNC = {
@@ -48,7 +57,12 @@ def value_in_unit(value, unit):
4857
OP_HELPERS = {op: UFUNC_HELPERS[np_func] for op, np_func in _OP_TO_NP_FUNC.items()}
4958

5059

51-
def _make_op(fop, mode):
60+
QuantityOpCallable: TypeAlias = Callable[
61+
["Quantity", Any], Union["Quantity", NotImplementedType]
62+
]
63+
64+
65+
def _make_op(fop: str, mode: str) -> QuantityOpCallable:
5266
assert mode in "fri"
5367
op = fop if mode == "f" else "__" + mode + fop[2:]
5468
helper = OP_HELPERS[fop]
@@ -68,27 +82,29 @@ def __op__(self, other):
6882
return __op__
6983

7084

71-
def _make_ops(op):
72-
return tuple(_make_op(op, mode) for mode in "fri")
85+
def _make_ops(
86+
op: str,
87+
) -> tuple[QuantityOpCallable, QuantityOpCallable, QuantityOpCallable]:
88+
return (_make_op(op, "f"), _make_op(op, "r"), _make_op(op, "i"))
7389

7490

75-
def _make_comp(comp):
76-
def __comp__(self, other):
91+
def _make_comp(comp: str) -> Callable[[Quantity, Any], Array]:
92+
def _comp_(self: Quantity, other: Any) -> Array | NotImplementedType:
7793
try:
7894
other = value_in_unit(other, self.unit)
7995
except Exception:
8096
return NotImplemented
8197
return getattr(self.value, comp)(other)
8298

83-
return __comp__
99+
return _comp_
84100

85101

86-
def _make_deferred(attr):
102+
def _make_deferred(attr: str) -> Callable[[Quantity], property]:
87103
# Use array_api_compat getter if available (size, device), since
88104
# some array formats provide inconsistent implementations.
89105
attr_getter = getattr(array_api_compat, attr, operator.attrgetter(attr))
90106

91-
def deferred(self):
107+
def deferred(self: Quantity):
92108
return attr_getter(self.value)
93109

94110
return property(deferred)
@@ -127,32 +143,61 @@ def defer_dimensionless(self):
127143
return defer_dimensionless
128144

129145

130-
def _check_pow_args(exp, mod):
131-
if mod is not None:
132-
return NotImplemented
146+
# -----------------
147+
148+
149+
@overload
150+
def _parse_pow_mod(mod: None, /) -> None: ...
151+
152+
153+
@overload
154+
def _parse_pow_mod(mod: object, /) -> NotImplementedType: ...
155+
156+
157+
def _parse_pow_mod(mod: T, /) -> T | NotImplementedType:
158+
return mod if mod is None else NotImplemented # type: ignore[redundant-expr]
159+
133160

134-
if not isinstance(exp, PYTHON_NUMBER):
161+
# -----------------
162+
163+
164+
@overload
165+
def _check_pow_exp(exp: Array | PYTHON_NUMBER, /) -> PYTHON_NUMBER: ...
166+
167+
168+
@overload
169+
def _check_pow_exp(exp: object, /) -> NotImplementedType: ...
170+
171+
172+
def _check_pow_exp(exp: Any, /) -> PYTHON_NUMBER | NotImplementedType:
173+
out: PYTHON_NUMBER
174+
if isinstance(exp, PYTHON_NUMBER):
175+
out = exp
176+
else:
135177
try:
136-
exp = exp.__complex__()
178+
out = complex(exp)
137179
except Exception:
138180
try:
139-
return exp.__float__()
181+
return float(exp)
140182
except Exception:
141183
return NotImplemented
142184

143-
return exp.real if exp.imag == 0 else exp
185+
return out.real if out.imag == 0 else out
144186

145187

146188
@dataclass(frozen=True, eq=False)
147189
class Quantity:
148-
value: Any
149-
unit: u.UnitBase
190+
value: Array
191+
unit: Unit
150192

151193
def __array_namespace__(self, *, api_version: str | None = None) -> Any:
152194
# TODO: make our own?
195+
del api_version
153196
return np
154197

155-
def _operate(self, other, op, units_helper):
198+
def _operate(
199+
self, other: Any, op: Any, units_helper: Any
200+
) -> Self | NotImplementedType:
156201
if not has_array_namespace(other) and not isinstance(other, PYTHON_NUMBER):
157202
# HACK: unit should take care of this!
158203
if not isinstance(other, u.UnitBase):
@@ -221,9 +266,11 @@ def _operate(self, other, op, units_helper):
221266

222267
# TODO: __dlpack__, __dlpack_device__
223268

224-
def __pow__(self, exp, mod=None):
225-
exp = _check_pow_args(exp, mod)
226-
if exp is NotImplemented:
269+
def __pow__(self, exp: Any, mod: Any = None) -> Self | NotImplementedType:
270+
if (mod := _parse_pow_mod(mod)) is NotImplemented:
271+
return NotImplemented
272+
273+
if (exp := _check_pow_exp(exp)) is NotImplemented:
227274
return NotImplemented
228275

229276
value = self.value.__pow__(exp)
@@ -232,17 +279,25 @@ def __pow__(self, exp, mod=None):
232279
return replace(self, value=value, unit=self.unit**exp)
233280

234281
def __ipow__(self, exp, mod=None):
235-
exp = _check_pow_args(exp, mod)
236-
if exp is NotImplemented:
282+
if (mod := _parse_pow_mod(mod)) is NotImplemented:
283+
return NotImplemented
284+
285+
if (exp := _check_pow_exp(exp)) is NotImplemented:
237286
return NotImplemented
238287

239288
value = self.value.__ipow__(exp)
240289
if value is NotImplemented:
241290
return NotImplemented
242291
return replace(self, value=value, unit=self.unit**exp)
243292

244-
def __setitem__(self, item, value):
245-
self.value[item] = value_in_unit(value, self.unit)
293+
def __setitem__(self, item: Any, value: Any) -> None:
294+
"""Call the setitem method of the array for the value in the unit.
295+
296+
The Array API does not guarantee mutability of the underlying array,
297+
so this method will raise an exception if the array is immutable.
298+
299+
"""
300+
self.value[item] = value_in_unit(value, self.unit) # type: ignore[index]
246301

247302
__array_ufunc__ = None
248303
__array_function__ = None

src/quantity/_src/utils.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
"""Utility functions for the quantity package."""
22

3+
from typing import Any, TypeGuard
4+
35
import array_api_compat
46

57

6-
def has_array_namespace(arg: object) -> bool:
8+
def has_array_namespace(arg: Any) -> TypeGuard[Array]:
79
try:
810
array_api_compat.array_namespace(arg)
911
except TypeError:
1012
return False
11-
else:
12-
return True
13+
return True

src/quantity/version.py

+3
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
# NOTE: First try _dev.scm_version if it exists and setuptools_scm is installed
22
# This file is not included in wheels/tarballs, so otherwise it will
33
# fall back on the generated _version module.
4+
5+
__all__ = ['version']
6+
47
version: str
58
try:
69
try:

0 commit comments

Comments
 (0)