Skip to content

Commit 3278292

Browse files
committed
signal+tests: Add all possible overloads for gausspulse.
`t` can be either an array or a scalar or "cutoff". Add corresponding tests
1 parent 647b1c3 commit 3278292

File tree

2 files changed

+275
-65
lines changed

2 files changed

+275
-65
lines changed

scipy-stubs/signal/_waveforms.pyi

+136-50
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@ _Truthy: TypeAlias = Literal[1, True]
1616
_Falsy: TypeAlias = Literal[0, False]
1717
_ArrayLikeFloat: TypeAlias = onp.ToFloat | onp.ToFloatND
1818
_Array_f8: TypeAlias = onp.ArrayND[np.float64]
19-
_GaussPulseTime: TypeAlias = _ArrayLikeFloat | Literal["cutoff"]
2019

2120
# Type vars to annotate `chirp`
2221
_NBT1 = TypeVar("_NBT1", bound=npt.NBitBase)
@@ -32,9 +31,127 @@ def sawtooth(t: _ArrayLikeFloat, width: _ArrayLikeFloat = 1) -> _Array_f8: ...
3231
def square(t: _ArrayLikeFloat, duty: _ArrayLikeFloat = 0.5) -> _Array_f8: ...
3332

3433
#
34+
@overload # Static type checking for float values
35+
def chirp(
36+
t: _ChirpTime[_NBT1],
37+
f0: _ChirpScalar[_NBT2],
38+
t1: _ChirpScalar[_NBT3],
39+
f1: _ChirpScalar[_NBT4],
40+
method: _ChirpMethod = "linear",
41+
phi: _ChirpScalar[_NBT5] = 0,
42+
vertex_zero: op.CanBool = True,
43+
) -> onp.ArrayND[np.floating[_NBT1 | _NBT2 | _NBT3 | _NBT4 | _NBT5]]: ...
44+
@overload # Other dtypes default to np.float64
45+
def chirp(
46+
t: onp.ToFloatND | _NestedSequence[float],
47+
f0: onp.ToFloat,
48+
t1: onp.ToFloat,
49+
f1: onp.ToFloat,
50+
method: _ChirpMethod = "linear",
51+
phi: onp.ToFloat = 0,
52+
vertex_zero: op.CanBool = True,
53+
) -> _Array_f8: ...
54+
55+
#
56+
def sweep_poly(
57+
t: _ArrayLikeFloat,
58+
poly: onp.ToFloatND | np.poly1d,
59+
phi: onp.ToFloat = 0,
60+
) -> _Array_f8: ...
61+
62+
#
63+
@overload # dtype is not given
64+
def unit_impulse(
65+
shape: AnyShape,
66+
idx: op.CanIndex | Iterable[op.CanIndex] | Literal["mid"] | None = None,
67+
dtype: type[float] = ...,
68+
) -> _Array_f8: ...
69+
@overload # dtype is given
70+
def unit_impulse(
71+
shape: AnyShape,
72+
idx: op.CanIndex | Iterable[op.CanIndex] | Literal["mid"] | None,
73+
dtype: _DTypeLike[_SCT],
74+
) -> npt.NDArray[_SCT]: ...
75+
76+
# Overloads for gausspulse when `t` is scalar
77+
@overload # retquad: False = ..., retenv: False = ...
78+
def gausspulse(
79+
t: onp.ToFloat,
80+
fc: onp.ToFloat = 1000,
81+
bw: onp.ToFloat = 0.5,
82+
bwr: onp.ToFloat = -6,
83+
tpr: onp.ToFloat = -60,
84+
retquad: _Falsy = False,
85+
retenv: _Falsy = False,
86+
) -> np.float64: ...
87+
@overload # retquad: False = ..., retenv: True (keyword)
88+
def gausspulse(
89+
t: onp.ToFloat,
90+
fc: onp.ToFloat = 1000,
91+
bw: onp.ToFloat = 0.5,
92+
bwr: onp.ToFloat = -6,
93+
tpr: onp.ToFloat = -60,
94+
retquad: _Falsy = False,
95+
*,
96+
retenv: _Truthy,
97+
) -> tuple[np.float64, np.float64]: ...
98+
@overload # retquad: False (positional), retenv: False (positional)
99+
def gausspulse(
100+
t: onp.ToFloat,
101+
fc: onp.ToFloat,
102+
bw: onp.ToFloat,
103+
bwr: onp.ToFloat,
104+
tpr: onp.ToFloat,
105+
retquad: _Falsy,
106+
retenv: _Truthy,
107+
) -> tuple[np.float64, np.float64]: ...
108+
@overload # retquad: True (positional), retenv: False = ...
109+
def gausspulse(
110+
t: onp.ToFloat,
111+
fc: onp.ToFloat,
112+
bw: onp.ToFloat,
113+
bwr: onp.ToFloat,
114+
tpr: onp.ToFloat,
115+
retquad: _Truthy,
116+
retenv: _Falsy = False,
117+
) -> tuple[np.float64, np.float64]: ...
118+
@overload # retquad: True (keyword), retenv: False = ...
119+
def gausspulse(
120+
t: onp.ToFloat,
121+
fc: onp.ToFloat = 1000,
122+
bw: onp.ToFloat = 0.5,
123+
bwr: onp.ToFloat = -6,
124+
tpr: onp.ToFloat = -60,
125+
*,
126+
retquad: _Truthy,
127+
retenv: _Falsy = False,
128+
) -> tuple[np.float64, np.float64]: ...
129+
@overload # retquad: True (positional), retenv: True (positional/keyword)
130+
def gausspulse(
131+
t: onp.ToFloat,
132+
fc: onp.ToFloat,
133+
bw: onp.ToFloat,
134+
bwr: onp.ToFloat,
135+
tpr: onp.ToFloat,
136+
retquad: _Truthy,
137+
retenv: _Truthy,
138+
) -> tuple[np.float64, np.float64, np.float64]: ...
139+
@overload # retquad: True (keyword), retenv: True
140+
def gausspulse(
141+
t: onp.ToFloat,
142+
fc: onp.ToFloat = 1000,
143+
bw: onp.ToFloat = 0.5,
144+
bwr: onp.ToFloat = -6,
145+
tpr: onp.ToFloat = -60,
146+
*,
147+
retquad: _Truthy,
148+
retenv: _Truthy,
149+
) -> tuple[np.float64, np.float64, np.float64]: ...
150+
151+
# Overloads for `gausspulse` when `t` is a non-scalar array like
35152
@overload # retquad: False = ..., retenv: False = ...
36153
def gausspulse(
37-
t: _GaussPulseTime,
154+
t: onp.ToFloatND,
38155
fc: onp.ToFloat = 1000,
39156
bw: onp.ToFloat = 0.5,
40157
bwr: onp.ToFloat = -6,
@@ -44,7 +161,7 @@ def gausspulse(
44161
) -> _Array_f8: ...
45162
@overload # retquad: False = ..., retenv: True (keyword)
46163
def gausspulse(
47-
t: _GaussPulseTime,
164+
t: onp.ToFloatND,
48165
fc: onp.ToFloat = 1000,
49166
bw: onp.ToFloat = 0.5,
50167
bwr: onp.ToFloat = -6,
@@ -55,7 +172,7 @@ def gausspulse(
55172
) -> tuple[_Array_f8, _Array_f8]: ...
56173
@overload # retquad: False (positional), retenv: False (positional)
57174
def gausspulse(
58-
t: _GaussPulseTime,
175+
t: onp.ToFloatND,
59176
fc: onp.ToFloat,
60177
bw: onp.ToFloat,
61178
bwr: onp.ToFloat,
@@ -65,7 +182,7 @@ def gausspulse(
65182
) -> tuple[_Array_f8, _Array_f8]: ...
66183
@overload # retquad: True (positional), retenv: False = ...
67184
def gausspulse(
68-
t: _GaussPulseTime,
185+
t: onp.ToFloatND,
69186
fc: onp.ToFloat,
70187
bw: onp.ToFloat,
71188
bwr: onp.ToFloat,
@@ -75,7 +192,7 @@ def gausspulse(
75192
) -> tuple[_Array_f8, _Array_f8]: ...
76193
@overload # retquad: True (keyword), retenv: False = ...
77194
def gausspulse(
78-
t: _GaussPulseTime,
195+
t: onp.ToFloatND,
79196
fc: onp.ToFloat = 1000,
80197
bw: onp.ToFloat = 0.5,
81198
bwr: onp.ToFloat = -6,
@@ -86,7 +203,7 @@ def gausspulse(
86203
) -> tuple[_Array_f8, _Array_f8]: ...
87204
@overload # retquad: True (positional), retenv: True (positional/keyword)
88205
def gausspulse(
89-
t: _GaussPulseTime,
206+
t: onp.ToFloatND,
90207
fc: onp.ToFloat,
91208
bw: onp.ToFloat,
92209
bwr: onp.ToFloat,
@@ -96,7 +213,7 @@ def gausspulse(
96213
) -> tuple[_Array_f8, _Array_f8, _Array_f8]: ...
97214
@overload # retquad: True (keyword), retenv: True
98215
def gausspulse(
99-
t: _GaussPulseTime,
216+
t: onp.ToFloatND,
100217
fc: onp.ToFloat = 1000,
101218
bw: onp.ToFloat = 0.5,
102219
bwr: onp.ToFloat = -6,
@@ -106,45 +223,14 @@ def gausspulse(
106223
retenv: _Truthy,
107224
) -> tuple[_Array_f8, _Array_f8, _Array_f8]: ...
108225

109-
#
110-
@overload # Static type checking for float values
111-
def chirp(
112-
t: _ChirpTime[_NBT1],
113-
f0: _ChirpScalar[_NBT2],
114-
t1: _ChirpScalar[_NBT3],
115-
f1: _ChirpScalar[_NBT4],
116-
method: _ChirpMethod = "linear",
117-
phi: _ChirpScalar[_NBT5] = 0,
118-
vertex_zero: op.CanBool = True,
119-
) -> onp.ArrayND[np.floating[_NBT1 | _NBT2 | _NBT3 | _NBT4 | _NBT5]]: ...
120-
@overload # Other dtypes default to np.float64
121-
def chirp(
122-
t: onp.ToFloatND | _NestedSequence[float],
123-
f0: onp.ToFloat,
124-
t1: onp.ToFloat,
125-
f1: onp.ToFloat,
126-
method: _ChirpMethod = "linear",
127-
phi: onp.ToFloat = 0,
128-
vertex_zero: op.CanBool = True,
129-
) -> _Array_f8: ...
130-
131-
#
132-
def sweep_poly(
133-
t: _ArrayLikeFloat,
134-
poly: onp.ToFloatND | np.poly1d,
135-
phi: onp.ToFloat = 0,
136-
) -> _Array_f8: ...
137-
138-
#
139-
@overload # dtype is not given
140-
def unit_impulse(
141-
shape: AnyShape,
142-
idx: op.CanIndex | Iterable[op.CanIndex] | Literal["mid"] | None = None,
143-
dtype: type[float] = ...,
144-
) -> _Array_f8: ...
145-
@overload # dtype is given
146-
def unit_impulse(
147-
shape: AnyShape,
148-
idx: op.CanIndex | Iterable[op.CanIndex] | Literal["mid"] | None,
149-
dtype: _DTypeLike[_SCT],
150-
) -> npt.NDArray[_SCT]: ...
226+
# Overloads for gausspulse when `t` is `"cutoff"`
227+
@overload # retquad: False = ..., retenv: False = ...
228+
def gausspulse(
229+
t: Literal["cutoff"],
230+
fc: onp.ToFloat = 1000,
231+
bw: onp.ToFloat = 0.5,
232+
bwr: onp.ToFloat = -6,
233+
tpr: onp.ToFloat = -60,
234+
retquad: op.CanBool = False,
235+
retenv: op.CanBool = False,
236+
) -> np.float64: ...

0 commit comments

Comments
 (0)