1
+ """Quantity."""
1
2
# Licensed under a 3-clause BSD style license - see LICENSE.rst
3
+
2
4
from __future__ import annotations
3
5
4
6
import operator
7
+ from collections .abc import Callable
5
8
from dataclasses import dataclass , replace
6
- from typing import TYPE_CHECKING
9
+ from typing import (
10
+ TYPE_CHECKING ,
11
+ Any ,
12
+ TypeAlias ,
13
+ TypeGuard ,
14
+ TypeVar ,
15
+ Union ,
16
+ cast ,
17
+ overload ,
18
+ )
7
19
8
20
import array_api_compat
9
21
import astropy .units as u
10
22
import numpy as np
23
+ from astropy .units import UnitBase as Unit
11
24
from astropy .units .quantity_helper import UFUNC_HELPERS
12
25
13
26
if TYPE_CHECKING :
14
- from typing import Any
27
+ from types import NotImplementedType
28
+ from typing import Any , Self
29
+
30
+ from ._array_api import Array
31
+ from ._quantity_api import ArrayQuantity , Unit
32
+
33
+
34
+ T = TypeVar ("T" )
15
35
16
36
17
37
DIMENSIONLESS = u .dimensionless_unscaled
18
38
19
39
PYTHON_NUMBER = float | int | complex
20
40
21
41
22
- def has_array_namespace (arg ) :
42
+ def has_array_namespace (arg : Any ) -> TypeGuard [ Array ] :
23
43
try :
24
44
array_api_compat .array_namespace (arg )
25
45
except TypeError :
26
46
return False
27
- else :
28
- return True
47
+ return True
29
48
30
49
31
- def get_value_and_unit (arg , default_unit = None ):
32
- # HACK: interoperability with astropy Quantity. Have protocol?
33
- try :
34
- unit = arg .unit
35
- except AttributeError :
36
- return arg , default_unit
37
- else :
38
- return arg .value , unit
50
+ def get_value_and_unit (
51
+ arg : ArrayQuantity | Array , default_unit : Unit | None = None
52
+ ) -> tuple [Array , Unit ]:
53
+ return (
54
+ (arg .value , arg .unit ) if isinstance (arg , ArrayQuantity ) else (arg , default_unit )
55
+ )
39
56
40
57
41
- def value_in_unit (value , unit ) :
58
+ def value_in_unit (value : Array , unit : Unit ) -> Array :
42
59
v_value , v_unit = get_value_and_unit (value , default_unit = DIMENSIONLESS )
43
- return v_unit .to (unit , v_value )
60
+ return cast ( Array , v_unit .to (unit , v_value ) )
44
61
45
62
46
63
_OP_TO_NP_FUNC = {
@@ -55,7 +72,12 @@ def value_in_unit(value, unit):
55
72
OP_HELPERS = {op : UFUNC_HELPERS [np_func ] for op , np_func in _OP_TO_NP_FUNC .items ()}
56
73
57
74
58
- def _make_op (fop , mode ):
75
+ QuantityOpCallable : TypeAlias = Callable [
76
+ ["Quantity" , Any ], Union ["Quantity" , NotImplementedType ]
77
+ ]
78
+
79
+
80
+ def _make_op (fop : str , mode : str ) -> QuantityOpCallable :
59
81
assert mode in "fri"
60
82
op = fop if mode == "f" else "__" + mode + fop [2 :]
61
83
helper = OP_HELPERS [fop ]
@@ -75,27 +97,29 @@ def __op__(self, other):
75
97
return __op__
76
98
77
99
78
- def _make_ops (op ):
79
- return tuple (_make_op (op , mode ) for mode in "fri" )
100
+ def _make_ops (
101
+ op : str ,
102
+ ) -> tuple [QuantityOpCallable , QuantityOpCallable , QuantityOpCallable ]:
103
+ return (_make_op (op , "f" ), _make_op (op , "r" ), _make_op (op , "i" ))
80
104
81
105
82
- def _make_comp (comp ) :
83
- def __comp__ (self , other ) :
106
+ def _make_comp (comp : str ) -> Callable [[ Quantity , Any ], Array ] :
107
+ def _comp_ (self : Quantity , other : Any ) -> Array | NotImplementedType :
84
108
try :
85
109
other = value_in_unit (other , self .unit )
86
110
except Exception :
87
111
return NotImplemented
88
112
return getattr (self .value , comp )(other )
89
113
90
- return __comp__
114
+ return _comp_
91
115
92
116
93
- def _make_deferred (attr ) :
117
+ def _make_deferred (attr : str ) -> Callable [[ Quantity ], property ] :
94
118
# Use array_api_compat getter if available (size, device), since
95
119
# some array formats provide inconsistent implementations.
96
120
attr_getter = getattr (array_api_compat , attr , operator .attrgetter (attr ))
97
121
98
- def deferred (self ):
122
+ def deferred (self : Quantity ):
99
123
return attr_getter (self .value )
100
124
101
125
return property (deferred )
@@ -133,33 +157,61 @@ def defer_dimensionless(self):
133
157
134
158
return defer_dimensionless
135
159
160
+ # -----------------
161
+
162
+
163
+ @overload
164
+ def _parse_pow_mod (mod : None , / ) -> None : ...
165
+
166
+
167
+ @overload
168
+ def _parse_pow_mod (mod : object , / ) -> NotImplementedType : ...
169
+
136
170
137
- def _check_pow_args (exp , mod ):
138
- if mod is not None :
139
- return NotImplemented
171
+ def _parse_pow_mod (mod : T , / ) -> T | NotImplementedType :
172
+ return mod if mod is None else NotImplemented # type: ignore[redundant-expr]
140
173
141
- if not isinstance (exp , PYTHON_NUMBER ):
174
+
175
+ # -----------------
176
+
177
+
178
+ @overload
179
+ def _check_pow_exp (exp : Array | PYTHON_NUMBER , / ) -> PYTHON_NUMBER : ...
180
+
181
+
182
+ @overload
183
+ def _check_pow_exp (exp : object , / ) -> NotImplementedType : ...
184
+
185
+
186
+ def _check_pow_exp (exp : Any , / ) -> PYTHON_NUMBER | NotImplementedType :
187
+ out : PYTHON_NUMBER
188
+ if isinstance (exp , PYTHON_NUMBER ):
189
+ out = exp
190
+ else :
142
191
try :
143
- exp = exp . __complex__ ( )
192
+ out = complex ( exp )
144
193
except Exception :
145
194
try :
146
- return exp . __float__ ( )
195
+ return float ( exp )
147
196
except Exception :
148
197
return NotImplemented
149
198
150
- return exp .real if exp .imag == 0 else exp
199
+ return out .real if out .imag == 0 else out
151
200
152
201
153
202
@dataclass (frozen = True , eq = False )
154
203
class Quantity :
155
- value : Any
156
- unit : u . UnitBase
204
+ value : Array
205
+ unit : Unit
157
206
158
207
def __array_namespace__ (self , * , api_version : str | None = None ) -> Any :
159
208
# TODO: make our own?
209
+ del api_version
160
210
return np
161
211
162
- def _operate (self , other , op , units_helper ):
212
+ def _operate (
213
+ self , other : Any , op : Any , units_helper : Any
214
+ ) -> Self | NotImplementedType :
163
215
if not has_array_namespace (other ) and not isinstance (other , PYTHON_NUMBER ):
164
216
# HACK: unit should take care of this!
165
217
if not isinstance (other , u .UnitBase ):
@@ -205,7 +257,7 @@ def _operate(self, other, op, units_helper):
205
257
__lt__ = _make_comp ("__lt__" )
206
258
__ne__ = _make_comp ("__ne__" )
207
259
208
- # Atttributes deferred to those of .value
260
+ # Attributes deferred to those of .value
209
261
dtype = _make_deferred ("dtype" )
210
262
device = _make_deferred ("device" )
211
263
ndim = _make_deferred ("ndim" )
@@ -228,9 +280,11 @@ def _operate(self, other, op, units_helper):
228
280
229
281
# TODO: __dlpack__, __dlpack_device__
230
282
231
- def __pow__ (self , exp , mod = None ):
232
- exp = _check_pow_args (exp , mod )
233
- if exp is NotImplemented :
283
+ def __pow__ (self , exp : Any , mod : Any = None ) -> Self | NotImplementedType :
284
+ if (mod := _parse_pow_mod (mod )) is NotImplemented :
285
+ return NotImplemented
286
+
287
+ if (exp := _check_pow_exp (exp )) is NotImplemented :
234
288
return NotImplemented
235
289
236
290
value = self .value .__pow__ (exp )
@@ -239,17 +293,25 @@ def __pow__(self, exp, mod=None):
239
293
return replace (self , value = value , unit = self .unit ** exp )
240
294
241
295
def __ipow__ (self , exp , mod = None ):
242
- exp = _check_pow_args (exp , mod )
243
- if exp is NotImplemented :
296
+ if (mod := _parse_pow_mod (mod )) is NotImplemented :
297
+ return NotImplemented
298
+
299
+ if (exp := _check_pow_exp (exp )) is NotImplemented :
244
300
return NotImplemented
245
301
246
302
value = self .value .__ipow__ (exp )
247
303
if value is NotImplemented :
248
304
return NotImplemented
249
305
return replace (self , value = value , unit = self .unit ** exp )
250
306
251
- def __setitem__ (self , item , value ):
252
- self .value [item ] = value_in_unit (value , self .unit )
307
+ def __setitem__ (self , item : Any , value : Any ) -> None :
308
+ """Call the setitem method of the array for the value in the unit.
309
+
310
+ The Array API does not guarantee mutability of the underlying array,
311
+ so this method will raise an exception if the array is immutable.
312
+
313
+ """
314
+ self .value [item ] = value_in_unit (value , self .unit ) # type: ignore[index]
253
315
254
316
__array_ufunc__ = None
255
317
__array_function__ = None
0 commit comments