1
+ """Quantity."""
1
2
# Licensed under a 3-clause BSD style license - see LICENSE.rst
3
+
2
4
from __future__ import annotations
3
5
4
6
import operator
7
+ from collections .abc import Callable
5
8
from dataclasses import dataclass , replace
6
- from typing import TYPE_CHECKING
9
+ from typing import TYPE_CHECKING , Any , TypeAlias , TypeVar , Union , cast , overload
7
10
8
11
import array_api_compat
9
12
import astropy .units as u
10
13
import numpy as np
14
+ from astropy .units import UnitBase as Unit
11
15
from astropy .units .quantity_helper import UFUNC_HELPERS
12
16
13
17
from .utils import has_array_namespace
14
18
15
19
if TYPE_CHECKING :
16
- from typing import Any
20
+ from types import NotImplementedType
21
+ from typing import Any , Self
22
+
23
+ from ._array_api import Array
24
+ from ._quantity_api import ArrayQuantity , Unit
25
+
26
+
27
+ T = TypeVar ("T" )
17
28
18
29
19
30
DIMENSIONLESS = u .dimensionless_unscaled
20
31
21
32
PYTHON_NUMBER = float | int | complex
22
33
23
34
24
- def get_value_and_unit (arg , default_unit = None ):
25
- # HACK: interoperability with astropy Quantity. Have protocol?
26
- try :
27
- unit = arg .unit
28
- except AttributeError :
29
- return arg , default_unit
30
- else :
31
- return arg .value , unit
35
+ def get_value_and_unit (
36
+ arg : ArrayQuantity | Array , default_unit : Unit | None = None
37
+ ) -> tuple [Array , Unit ]:
38
+ return (
39
+ (arg .value , arg .unit ) if isinstance (arg , ArrayQuantity ) else (arg , default_unit )
40
+ )
32
41
33
42
34
- def value_in_unit (value , unit ) :
43
+ def value_in_unit (value : Array , unit : Unit ) -> Array :
35
44
v_value , v_unit = get_value_and_unit (value , default_unit = DIMENSIONLESS )
36
- return v_unit .to (unit , v_value )
45
+ return cast ( Array , v_unit .to (unit , v_value ) )
37
46
38
47
39
48
_OP_TO_NP_FUNC = {
@@ -48,7 +57,12 @@ def value_in_unit(value, unit):
48
57
OP_HELPERS = {op : UFUNC_HELPERS [np_func ] for op , np_func in _OP_TO_NP_FUNC .items ()}
49
58
50
59
51
- def _make_op (fop , mode ):
60
+ QuantityOpCallable : TypeAlias = Callable [
61
+ ["Quantity" , Any ], Union ["Quantity" , NotImplementedType ]
62
+ ]
63
+
64
+
65
+ def _make_op (fop : str , mode : str ) -> QuantityOpCallable :
52
66
assert mode in "fri"
53
67
op = fop if mode == "f" else "__" + mode + fop [2 :]
54
68
helper = OP_HELPERS [fop ]
@@ -68,27 +82,29 @@ def __op__(self, other):
68
82
return __op__
69
83
70
84
71
- def _make_ops (op ):
72
- return tuple (_make_op (op , mode ) for mode in "fri" )
85
+ def _make_ops (
86
+ op : str ,
87
+ ) -> tuple [QuantityOpCallable , QuantityOpCallable , QuantityOpCallable ]:
88
+ return (_make_op (op , "f" ), _make_op (op , "r" ), _make_op (op , "i" ))
73
89
74
90
75
- def _make_comp (comp ) :
76
- def __comp__ (self , other ) :
91
+ def _make_comp (comp : str ) -> Callable [[ Quantity , Any ], Array ] :
92
+ def _comp_ (self : Quantity , other : Any ) -> Array | NotImplementedType :
77
93
try :
78
94
other = value_in_unit (other , self .unit )
79
95
except Exception :
80
96
return NotImplemented
81
97
return getattr (self .value , comp )(other )
82
98
83
- return __comp__
99
+ return _comp_
84
100
85
101
86
- def _make_deferred (attr ) :
102
+ def _make_deferred (attr : str ) -> Callable [[ Quantity ], property ] :
87
103
# Use array_api_compat getter if available (size, device), since
88
104
# some array formats provide inconsistent implementations.
89
105
attr_getter = getattr (array_api_compat , attr , operator .attrgetter (attr ))
90
106
91
- def deferred (self ):
107
+ def deferred (self : Quantity ):
92
108
return attr_getter (self .value )
93
109
94
110
return property (deferred )
@@ -127,32 +143,61 @@ def defer_dimensionless(self):
127
143
return defer_dimensionless
128
144
129
145
130
- def _check_pow_args (exp , mod ):
131
- if mod is not None :
132
- return NotImplemented
146
+ # -----------------
147
+
148
+
149
+ @overload
150
+ def _parse_pow_mod (mod : None , / ) -> None : ...
151
+
152
+
153
+ @overload
154
+ def _parse_pow_mod (mod : object , / ) -> NotImplementedType : ...
155
+
156
+
157
+ def _parse_pow_mod (mod : T , / ) -> T | NotImplementedType :
158
+ return mod if mod is None else NotImplemented # type: ignore[redundant-expr]
159
+
133
160
134
- if not isinstance (exp , PYTHON_NUMBER ):
161
+ # -----------------
162
+
163
+
164
+ @overload
165
+ def _check_pow_exp (exp : Array | PYTHON_NUMBER , / ) -> PYTHON_NUMBER : ...
166
+
167
+
168
+ @overload
169
+ def _check_pow_exp (exp : object , / ) -> NotImplementedType : ...
170
+
171
+
172
+ def _check_pow_exp (exp : Any , / ) -> PYTHON_NUMBER | NotImplementedType :
173
+ out : PYTHON_NUMBER
174
+ if isinstance (exp , PYTHON_NUMBER ):
175
+ out = exp
176
+ else :
135
177
try :
136
- exp = exp . __complex__ ( )
178
+ out = complex ( exp )
137
179
except Exception :
138
180
try :
139
- return exp . __float__ ( )
181
+ return float ( exp )
140
182
except Exception :
141
183
return NotImplemented
142
184
143
- return exp .real if exp .imag == 0 else exp
185
+ return out .real if out .imag == 0 else out
144
186
145
187
146
188
@dataclass (frozen = True , eq = False )
147
189
class Quantity :
148
- value : Any
149
- unit : u . UnitBase
190
+ value : Array
191
+ unit : Unit
150
192
151
193
def __array_namespace__ (self , * , api_version : str | None = None ) -> Any :
152
194
# TODO: make our own?
195
+ del api_version
153
196
return np
154
197
155
- def _operate (self , other , op , units_helper ):
198
+ def _operate (
199
+ self , other : Any , op : Any , units_helper : Any
200
+ ) -> Self | NotImplementedType :
156
201
if not has_array_namespace (other ) and not isinstance (other , PYTHON_NUMBER ):
157
202
# HACK: unit should take care of this!
158
203
if not isinstance (other , u .UnitBase ):
@@ -221,9 +266,11 @@ def _operate(self, other, op, units_helper):
221
266
222
267
# TODO: __dlpack__, __dlpack_device__
223
268
224
- def __pow__ (self , exp , mod = None ):
225
- exp = _check_pow_args (exp , mod )
226
- if exp is NotImplemented :
269
+ def __pow__ (self , exp : Any , mod : Any = None ) -> Self | NotImplementedType :
270
+ if (mod := _parse_pow_mod (mod )) is NotImplemented :
271
+ return NotImplemented
272
+
273
+ if (exp := _check_pow_exp (exp )) is NotImplemented :
227
274
return NotImplemented
228
275
229
276
value = self .value .__pow__ (exp )
@@ -232,17 +279,25 @@ def __pow__(self, exp, mod=None):
232
279
return replace (self , value = value , unit = self .unit ** exp )
233
280
234
281
def __ipow__ (self , exp , mod = None ):
235
- exp = _check_pow_args (exp , mod )
236
- if exp is NotImplemented :
282
+ if (mod := _parse_pow_mod (mod )) is NotImplemented :
283
+ return NotImplemented
284
+
285
+ if (exp := _check_pow_exp (exp )) is NotImplemented :
237
286
return NotImplemented
238
287
239
288
value = self .value .__ipow__ (exp )
240
289
if value is NotImplemented :
241
290
return NotImplemented
242
291
return replace (self , value = value , unit = self .unit ** exp )
243
292
244
- def __setitem__ (self , item , value ):
245
- self .value [item ] = value_in_unit (value , self .unit )
293
+ def __setitem__ (self , item : Any , value : Any ) -> None :
294
+ """Call the setitem method of the array for the value in the unit.
295
+
296
+ The Array API does not guarantee mutability of the underlying array,
297
+ so this method will raise an exception if the array is immutable.
298
+
299
+ """
300
+ self .value [item ] = value_in_unit (value , self .unit ) # type: ignore[index]
246
301
247
302
__array_ufunc__ = None
248
303
__array_function__ = None
0 commit comments