@@ -5,7 +5,7 @@ import numpy as np
5
5
import numpy .typing as npt
6
6
import optype as op
7
7
import optype .numpy as onp
8
- from numpy ._typing import _DTypeLike
8
+ from numpy ._typing import _ArrayLike , _DTypeLike , _NestedSequence
9
9
from scipy ._typing import AnyShape
10
10
11
11
__all__ = ["chirp" , "gausspulse" , "sawtooth" , "square" , "sweep_poly" , "unit_impulse" ]
@@ -18,6 +18,16 @@ _ArrayLikeFloat: TypeAlias = onp.ToFloat | onp.ToFloatND
18
18
_Array_f8 : TypeAlias = onp .ArrayND [np .float64 ]
19
19
_GaussPulseTime : TypeAlias = _ArrayLikeFloat | Literal ["cutoff" ]
20
20
21
+ # Type vars to annotate `chirp`
22
+ _NBT1 = TypeVar ("_NBT1" , bound = npt .NBitBase )
23
+ _NBT2 = TypeVar ("_NBT2" , bound = npt .NBitBase )
24
+ _NBT3 = TypeVar ("_NBT3" , bound = npt .NBitBase )
25
+ _NBT4 = TypeVar ("_NBT4" , bound = npt .NBitBase )
26
+ _NBT5 = TypeVar ("_NBT5" , bound = npt .NBitBase )
27
+ _ChirpTime : TypeAlias = _ArrayLike [np .floating [_NBT1 ] | np .integer [_NBT1 ]]
28
+ _ChirpScalar : TypeAlias = float | np .floating [_NBT1 ] | np .integer [_NBT1 ]
29
+ _ChirpMethod : TypeAlias = Literal ["linear" , "quadratic" , "logarithmic" , "hyperbolic" ]
30
+
21
31
def sawtooth (t : _ArrayLikeFloat , width : _ArrayLikeFloat = 1 ) -> _Array_f8 : ...
22
32
def square (t : _ArrayLikeFloat , duty : _ArrayLikeFloat = 0.5 ) -> _Array_f8 : ...
23
33
@@ -96,16 +106,29 @@ def gausspulse(
96
106
retenv : _Truthy ,
97
107
) -> tuple [_Array_f8 , _Array_f8 , _Array_f8 ]: ...
98
108
99
- # float16 -> float16, float32 -> float32, ... -> float64
109
+ #
110
+ @overload # Static type checking for float values
111
+ def chirp (
112
+ t : _ChirpTime [_NBT1 ],
113
+ f0 : _ChirpScalar [_NBT2 ],
114
+ t1 : _ChirpScalar [_NBT3 ],
115
+ f1 : _ChirpScalar [_NBT4 ],
116
+ method : _ChirpMethod = "linear" ,
117
+ phi : _ChirpScalar [_NBT5 ] = 0 ,
118
+ vertex_zero : op .CanBool = True ,
119
+ ) -> onp .ArrayND [np .floating [_NBT1 | _NBT2 | _NBT3 | _NBT4 | _NBT5 ]]: ...
120
+ @overload # Other dtypes default to np.float64
100
121
def chirp (
101
- t : onp .ToFloatND ,
122
+ t : onp .ToFloatND | _NestedSequence [ float ] ,
102
123
f0 : onp .ToFloat ,
103
124
t1 : onp .ToFloat ,
104
125
f1 : onp .ToFloat ,
105
- method : Literal [ "linear" , "quadratic" , "logarithmic" , "hyperbolic" ] = "linear" ,
126
+ method : _ChirpMethod = "linear" ,
106
127
phi : onp .ToFloat = 0 ,
107
128
vertex_zero : op .CanBool = True ,
108
- ) -> npt .NDArray [np .float16 | np .float32 | np .float64 ]: ...
129
+ ) -> _Array_f8 : ...
130
+
131
+ #
109
132
def sweep_poly (
110
133
t : _ArrayLikeFloat ,
111
134
poly : onp .ToFloatND | np .poly1d ,
0 commit comments