Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 2c1ab5e

Browse files
committedNov 27, 2024··
WIP
Signed-off-by: nstarman <nstarman@users.noreply.github.com>
1 parent 19c4cae commit 2c1ab5e

9 files changed

+1192
-46
lines changed
 

‎.github/workflows/ci_workflows.yml

+2-2
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@ name: CI
33
on:
44
push:
55
branches:
6-
- main
6+
- main
77
tags:
8-
- '*'
8+
- "*"
99
pull_request:
1010

1111
concurrency:

‎.pre-commit-config.yaml

+89
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
ci:
2+
autoupdate_schedule: "monthly"
3+
autoupdate_commit_msg: "chore: update pre-commit hooks"
4+
autofix_commit_msg: "style: pre-commit fixes"
5+
6+
default_stages: [pre-commit, pre-push]
7+
8+
repos:
9+
- repo: meta
10+
hooks:
11+
- id: check-useless-excludes
12+
13+
- repo: https://github.com/pre-commit/pre-commit-hooks
14+
rev: "v5.0.0"
15+
hooks:
16+
- id: check-added-large-files
17+
- id: check-case-conflict
18+
- id: check-merge-conflict
19+
- id: check-yaml
20+
- id: debug-statements
21+
- id: end-of-file-fixer
22+
- id: mixed-line-ending
23+
- id: name-tests-test
24+
args: ["--pytest-test-first"]
25+
- id: trailing-whitespace
26+
27+
- repo: https://github.com/pre-commit/pygrep-hooks
28+
rev: "v1.10.0"
29+
hooks:
30+
- id: rst-backticks
31+
- id: rst-directive-colons
32+
- id: rst-inline-touching-normal
33+
34+
- repo: https://github.com/python-jsonschema/check-jsonschema
35+
rev: 0.29.4
36+
hooks:
37+
- id: check-dependabot
38+
- id: check-github-workflows
39+
- id: check-readthedocs
40+
41+
- repo: https://github.com/astral-sh/ruff-pre-commit
42+
rev: "v0.7.3"
43+
hooks:
44+
# Run the linter
45+
- id: ruff
46+
types_or: [python, pyi, jupyter]
47+
args: ["--fix", "--show-fixes"]
48+
# Run the formatter
49+
- id: ruff-format
50+
types_or: [python, pyi, jupyter]
51+
52+
- repo: https://github.com/adamchainz/blacken-docs
53+
rev: "1.19.1"
54+
hooks:
55+
- id: blacken-docs
56+
additional_dependencies: [black==23.*]
57+
58+
- repo: https://github.com/rbubley/mirrors-prettier
59+
rev: "v3.3.3"
60+
hooks:
61+
- id: prettier
62+
types_or: [yaml, markdown, html, css, scss, javascript, json]
63+
args: [--prose-wrap=always]
64+
65+
- repo: https://github.com/pre-commit/mirrors-mypy
66+
rev: "v1.13.0"
67+
hooks:
68+
- id: mypy
69+
files: src
70+
additional_dependencies:
71+
- pytest
72+
73+
- repo: https://github.com/codespell-project/codespell
74+
rev: "v2.3.0"
75+
hooks:
76+
- id: codespell
77+
78+
- repo: https://github.com/abravalheri/validate-pyproject
79+
rev: v0.23
80+
hooks:
81+
- id: validate-pyproject
82+
83+
- repo: local
84+
hooks:
85+
- id: disallow-caps
86+
name: Disallow improper capitalization
87+
language: pygrep
88+
entry: PyBind|Numpy|Cmake|CCache|Github|PyTest
89+
exclude: .pre-commit-config.yaml

‎pyproject.toml

+30-3
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ dependencies = [
1818
"array-api-compat>=1.9.1",
1919
"astropy>=7.0",
2020
"numpy>=2.0",
21+
"typing-extensions>=4.12.2",
2122
]
2223
dynamic = ["version"]
2324

@@ -103,6 +104,30 @@ exclude_lines = [
103104
"@overload",
104105
]
105106

107+
[tool.mypy]
108+
python_version = "3.11"
109+
files = ["quantity"]
110+
strict = true
111+
112+
disallow_incomplete_defs = true
113+
disallow_untyped_defs = false
114+
enable_error_code = ["ignore-without-code", "redundant-expr", "truthy-bool"]
115+
warn_return_any = true
116+
warn_unreachable = true
117+
warn_unused_configs = true
118+
119+
[[tool.mypy.overrides]]
120+
module = ["quantity._dev.*", "quantity.tests.*"]
121+
ignore_errors = true
122+
123+
[[tool.mypy.overrides]]
124+
ignore_missing_imports = true
125+
module = [
126+
"astropy.*",
127+
"array_api_compat.*"
128+
]
129+
130+
106131
[tool.ruff]
107132
exclude=[ # package template provided files.
108133
"setup.py",
@@ -140,10 +165,12 @@ ignore = [
140165
"PLR2004", # Magic value used in comparison
141166
"RET505", # Unnecessary `else`/`elif` after `return` statement
142167
]
143-
isort.required-imports = ["from __future__ import annotations"]
144-
# Uncomment if using a _compat.typing backport
145-
# typing-modules = ["quantity_2_0._compat.typing"]
146168

147169
[tool.ruff.lint.per-file-ignores]
148170
"tests/**" = ["T20"]
149171
"noxfile.py" = ["T20"]
172+
173+
[dependency-groups]
174+
typing = [
175+
"mypy>=1.13.0",
176+
]

‎quantity/_array_api.py

+17
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
"""Minimal definition of the Array API."""
2+
3+
from __future__ import annotations
4+
5+
from typing import Any, Protocol
6+
7+
8+
class HasArrayNameSpace(Protocol):
9+
"""Minimal defintion of the Array API."""
10+
11+
def __array_namespace__(self) -> Any: ...
12+
13+
14+
class Array(HasArrayNameSpace, Protocol):
15+
"""Minimal defintion of the Array API."""
16+
17+
def __pow__(self, other: Any) -> Array: ...

‎quantity/_quantity_api.py

+24
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
"""Minimal definition of the Quantity API."""
2+
3+
__all__ = ["Quantity", "ArrayQuantity", "Unit"]
4+
5+
from typing import Protocol, runtime_checkable
6+
7+
from astropy.units import UnitBase as Unit
8+
9+
from ._array_api import Array
10+
11+
12+
@runtime_checkable
13+
class Quantity(Protocol):
14+
"""Minimal definition of the Quantity API."""
15+
16+
value: Array
17+
unit: Unit
18+
19+
20+
@runtime_checkable
21+
class ArrayQuantity(Quantity, Array, Protocol):
22+
"""An array-valued Quantity."""
23+
24+
...

‎quantity/core.py

+103-41
Original file line numberDiff line numberDiff line change
@@ -1,46 +1,63 @@
1+
"""Quantity."""
12
# Licensed under a 3-clause BSD style license - see LICENSE.rst
3+
24
from __future__ import annotations
35

46
import operator
7+
from collections.abc import Callable
58
from dataclasses import dataclass, replace
6-
from typing import TYPE_CHECKING
9+
from typing import (
10+
TYPE_CHECKING,
11+
Any,
12+
TypeAlias,
13+
TypeGuard,
14+
TypeVar,
15+
Union,
16+
cast,
17+
overload,
18+
)
719

820
import array_api_compat
921
import astropy.units as u
1022
import numpy as np
23+
from astropy.units import UnitBase as Unit
1124
from astropy.units.quantity_helper import UFUNC_HELPERS
1225

1326
if TYPE_CHECKING:
14-
from typing import Any
27+
from types import NotImplementedType
28+
from typing import Any, Self
29+
30+
from ._array_api import Array
31+
from ._quantity_api import ArrayQuantity, Unit
32+
33+
34+
T = TypeVar("T")
1535

1636

1737
DIMENSIONLESS = u.dimensionless_unscaled
1838

1939
PYTHON_NUMBER = float | int | complex
2040

2141

22-
def has_array_namespace(arg):
42+
def has_array_namespace(arg: Any) -> TypeGuard[Array]:
2343
try:
2444
array_api_compat.array_namespace(arg)
2545
except TypeError:
2646
return False
27-
else:
28-
return True
47+
return True
2948

3049

31-
def get_value_and_unit(arg, default_unit=None):
32-
# HACK: interoperability with astropy Quantity. Have protocol?
33-
try:
34-
unit = arg.unit
35-
except AttributeError:
36-
return arg, default_unit
37-
else:
38-
return arg.value, unit
50+
def get_value_and_unit(
51+
arg: ArrayQuantity | Array, default_unit: Unit | None = None
52+
) -> tuple[Array, Unit]:
53+
return (
54+
(arg.value, arg.unit) if isinstance(arg, ArrayQuantity) else (arg, default_unit)
55+
)
3956

4057

41-
def value_in_unit(value, unit):
58+
def value_in_unit(value: Array, unit: Unit) -> Array:
4259
v_value, v_unit = get_value_and_unit(value, default_unit=DIMENSIONLESS)
43-
return v_unit.to(unit, v_value)
60+
return cast(Array, v_unit.to(unit, v_value))
4461

4562

4663
_OP_TO_NP_FUNC = {
@@ -55,7 +72,12 @@ def value_in_unit(value, unit):
5572
OP_HELPERS = {op: UFUNC_HELPERS[np_func] for op, np_func in _OP_TO_NP_FUNC.items()}
5673

5774

58-
def _make_op(fop, mode):
75+
QuantityOpCallable: TypeAlias = Callable[
76+
["Quantity", Any], Union["Quantity", NotImplementedType]
77+
]
78+
79+
80+
def _make_op(fop: str, mode: str) -> QuantityOpCallable:
5981
assert mode in "fri"
6082
op = fop if mode == "f" else "__" + mode + fop[2:]
6183
helper = OP_HELPERS[fop]
@@ -75,27 +97,29 @@ def __op__(self, other):
7597
return __op__
7698

7799

78-
def _make_ops(op):
79-
return tuple(_make_op(op, mode) for mode in "fri")
100+
def _make_ops(
101+
op: str,
102+
) -> tuple[QuantityOpCallable, QuantityOpCallable, QuantityOpCallable]:
103+
return (_make_op(op, "f"), _make_op(op, "r"), _make_op(op, "i"))
80104

81105

82-
def _make_comp(comp):
83-
def __comp__(self, other):
106+
def _make_comp(comp: str) -> Callable[[Quantity, Any], Array]:
107+
def _comp_(self: Quantity, other: Any) -> Array | NotImplementedType:
84108
try:
85109
other = value_in_unit(other, self.unit)
86110
except Exception:
87111
return NotImplemented
88112
return getattr(self.value, comp)(other)
89113

90-
return __comp__
114+
return _comp_
91115

92116

93-
def _make_deferred(attr):
117+
def _make_deferred(attr: str) -> Callable[[Quantity], property]:
94118
# Use array_api_compat getter if available (size, device), since
95119
# some array formats provide inconsistent implementations.
96120
attr_getter = getattr(array_api_compat, attr, operator.attrgetter(attr))
97121

98-
def deferred(self):
122+
def deferred(self: Quantity):
99123
return attr_getter(self.value)
100124

101125
return property(deferred)
@@ -133,33 +157,61 @@ def defer_dimensionless(self):
133157

134158
return defer_dimensionless
135159

160+
# -----------------
161+
162+
163+
@overload
164+
def _parse_pow_mod(mod: None, /) -> None: ...
165+
166+
167+
@overload
168+
def _parse_pow_mod(mod: object, /) -> NotImplementedType: ...
169+
136170

137-
def _check_pow_args(exp, mod):
138-
if mod is not None:
139-
return NotImplemented
171+
def _parse_pow_mod(mod: T, /) -> T | NotImplementedType:
172+
return mod if mod is None else NotImplemented # type: ignore[redundant-expr]
140173

141-
if not isinstance(exp, PYTHON_NUMBER):
174+
175+
# -----------------
176+
177+
178+
@overload
179+
def _check_pow_exp(exp: Array | PYTHON_NUMBER, /) -> PYTHON_NUMBER: ...
180+
181+
182+
@overload
183+
def _check_pow_exp(exp: object, /) -> NotImplementedType: ...
184+
185+
186+
def _check_pow_exp(exp: Any, /) -> PYTHON_NUMBER | NotImplementedType:
187+
out: PYTHON_NUMBER
188+
if isinstance(exp, PYTHON_NUMBER):
189+
out = exp
190+
else:
142191
try:
143-
exp = exp.__complex__()
192+
out = complex(exp)
144193
except Exception:
145194
try:
146-
return exp.__float__()
195+
return float(exp)
147196
except Exception:
148197
return NotImplemented
149198

150-
return exp.real if exp.imag == 0 else exp
199+
return out.real if out.imag == 0 else out
151200

152201

153202
@dataclass(frozen=True, eq=False)
154203
class Quantity:
155-
value: Any
156-
unit: u.UnitBase
204+
value: Array
205+
unit: Unit
157206

158207
def __array_namespace__(self, *, api_version: str | None = None) -> Any:
159208
# TODO: make our own?
209+
del api_version
160210
return np
161211

162-
def _operate(self, other, op, units_helper):
212+
def _operate(
213+
self, other: Any, op: Any, units_helper: Any
214+
) -> Self | NotImplementedType:
163215
if not has_array_namespace(other) and not isinstance(other, PYTHON_NUMBER):
164216
# HACK: unit should take care of this!
165217
if not isinstance(other, u.UnitBase):
@@ -205,7 +257,7 @@ def _operate(self, other, op, units_helper):
205257
__lt__ = _make_comp("__lt__")
206258
__ne__ = _make_comp("__ne__")
207259

208-
# Atttributes deferred to those of .value
260+
# Attributes deferred to those of .value
209261
dtype = _make_deferred("dtype")
210262
device = _make_deferred("device")
211263
ndim = _make_deferred("ndim")
@@ -228,9 +280,11 @@ def _operate(self, other, op, units_helper):
228280

229281
# TODO: __dlpack__, __dlpack_device__
230282

231-
def __pow__(self, exp, mod=None):
232-
exp = _check_pow_args(exp, mod)
233-
if exp is NotImplemented:
283+
def __pow__(self, exp: Any, mod: Any = None) -> Self | NotImplementedType:
284+
if (mod := _parse_pow_mod(mod)) is NotImplemented:
285+
return NotImplemented
286+
287+
if (exp := _check_pow_exp(exp)) is NotImplemented:
234288
return NotImplemented
235289

236290
value = self.value.__pow__(exp)
@@ -239,17 +293,25 @@ def __pow__(self, exp, mod=None):
239293
return replace(self, value=value, unit=self.unit**exp)
240294

241295
def __ipow__(self, exp, mod=None):
242-
exp = _check_pow_args(exp, mod)
243-
if exp is NotImplemented:
296+
if (mod := _parse_pow_mod(mod)) is NotImplemented:
297+
return NotImplemented
298+
299+
if (exp := _check_pow_exp(exp)) is NotImplemented:
244300
return NotImplemented
245301

246302
value = self.value.__ipow__(exp)
247303
if value is NotImplemented:
248304
return NotImplemented
249305
return replace(self, value=value, unit=self.unit**exp)
250306

251-
def __setitem__(self, item, value):
252-
self.value[item] = value_in_unit(value, self.unit)
307+
def __setitem__(self, item: Any, value: Any) -> None:
308+
"""Call the setitem method of the array for the value in the unit.
309+
310+
The Array API does not guarantee mutability of the underlying array,
311+
so this method will raise an exception if the array is immutable.
312+
313+
"""
314+
self.value[item] = value_in_unit(value, self.unit) # type: ignore[index]
253315

254316
__array_ufunc__ = None
255317
__array_function__ = None

‎quantity/py.typed

Whitespace-only changes.

‎quantity/version.py

+3
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
# NOTE: First try _dev.scm_version if it exists and setuptools_scm is installed
22
# This file is not included in wheels/tarballs, so otherwise it will
33
# fall back on the generated _version module.
4+
5+
__all__ = ['version']
6+
47
try:
58
try:
69
from ._dev.scm_version import version

‎uv.lock

+924
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)
Please sign in to comment.