Skip to content

Commit 647b1c3

Browse files
committed
signal: Type annotate chirp so that the output dtype can be inferred
In `_waveforms.pyi`
1 parent 7b975ec commit 647b1c3

File tree

1 file changed

+28
-5
lines changed

1 file changed

+28
-5
lines changed

scipy-stubs/signal/_waveforms.pyi

+28-5
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ import numpy as np
55
import numpy.typing as npt
66
import optype as op
77
import optype.numpy as onp
8-
from numpy._typing import _DTypeLike
8+
from numpy._typing import _ArrayLike, _DTypeLike, _NestedSequence
99
from scipy._typing import AnyShape
1010

1111
__all__ = ["chirp", "gausspulse", "sawtooth", "square", "sweep_poly", "unit_impulse"]
@@ -18,6 +18,16 @@ _ArrayLikeFloat: TypeAlias = onp.ToFloat | onp.ToFloatND
1818
_Array_f8: TypeAlias = onp.ArrayND[np.float64]
1919
_GaussPulseTime: TypeAlias = _ArrayLikeFloat | Literal["cutoff"]
2020

21+
# Type vars to annotate `chirp`
22+
_NBT1 = TypeVar("_NBT1", bound=npt.NBitBase)
23+
_NBT2 = TypeVar("_NBT2", bound=npt.NBitBase)
24+
_NBT3 = TypeVar("_NBT3", bound=npt.NBitBase)
25+
_NBT4 = TypeVar("_NBT4", bound=npt.NBitBase)
26+
_NBT5 = TypeVar("_NBT5", bound=npt.NBitBase)
27+
_ChirpTime: TypeAlias = _ArrayLike[np.floating[_NBT1] | np.integer[_NBT1]]
28+
_ChirpScalar: TypeAlias = float | np.floating[_NBT1] | np.integer[_NBT1]
29+
_ChirpMethod: TypeAlias = Literal["linear", "quadratic", "logarithmic", "hyperbolic"]
30+
2131
def sawtooth(t: _ArrayLikeFloat, width: _ArrayLikeFloat = 1) -> _Array_f8: ...
2232
def square(t: _ArrayLikeFloat, duty: _ArrayLikeFloat = 0.5) -> _Array_f8: ...
2333

@@ -96,16 +106,29 @@ def gausspulse(
96106
retenv: _Truthy,
97107
) -> tuple[_Array_f8, _Array_f8, _Array_f8]: ...
98108

99-
# float16 -> float16, float32 -> float32, ... -> float64
109+
#
110+
@overload # Static type checking for float values
111+
def chirp(
112+
t: _ChirpTime[_NBT1],
113+
f0: _ChirpScalar[_NBT2],
114+
t1: _ChirpScalar[_NBT3],
115+
f1: _ChirpScalar[_NBT4],
116+
method: _ChirpMethod = "linear",
117+
phi: _ChirpScalar[_NBT5] = 0,
118+
vertex_zero: op.CanBool = True,
119+
) -> onp.ArrayND[np.floating[_NBT1 | _NBT2 | _NBT3 | _NBT4 | _NBT5]]: ...
120+
@overload # Other dtypes default to np.float64
100121
def chirp(
101-
t: onp.ToFloatND,
122+
t: onp.ToFloatND | _NestedSequence[float],
102123
f0: onp.ToFloat,
103124
t1: onp.ToFloat,
104125
f1: onp.ToFloat,
105-
method: Literal["linear", "quadratic", "logarithmic", "hyperbolic"] = "linear",
126+
method: _ChirpMethod = "linear",
106127
phi: onp.ToFloat = 0,
107128
vertex_zero: op.CanBool = True,
108-
) -> npt.NDArray[np.float16 | np.float32 | np.float64]: ...
129+
) -> _Array_f8: ...
130+
131+
#
109132
def sweep_poly(
110133
t: _ArrayLikeFloat,
111134
poly: onp.ToFloatND | np.poly1d,

0 commit comments

Comments
 (0)