Skip to content

Commit 297640c

Browse files
committed
feat(series): #1098 __add__ and __mul__
1 parent f6dd4a2 commit 297640c

File tree

2 files changed

+178
-13
lines changed

2 files changed

+178
-13
lines changed

pandas-stubs/core/series.pyi

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ from typing import (
2525
Generic,
2626
Literal,
2727
NoReturn,
28+
TypeVar,
2829
overload,
2930
)
3031

@@ -183,6 +184,12 @@ from pandas.core.dtypes.dtypes import CategoricalDtype
183184

184185
from pandas.plotting import PlotAccessor
185186

187+
_T_INT = TypeVar("_T_INT", bound=int)
188+
_T_COMPLEX = TypeVar("_T_COMPLEX", bound=complex)
189+
_T_INT_FLOAT = TypeVar("_T_INT_FLOAT", bound=int | float)
190+
_T_FLOAT_COMPLEX = TypeVar("_T_FLOAT_COMPLEX", bound=float | complex)
191+
_T_INT_FLOAT_COMPLEX = TypeVar("_T_INT_FLOAT_COMPLEX", bound=int | float | complex)
192+
186193
class _iLocIndexerSeries(_iLocIndexer, Generic[S1]):
187194
# get item
188195
@overload
@@ -1580,6 +1587,24 @@ class Series(IndexOpsMixin[S1], NDFrame):
15801587
# just failed to generate these so I couldn't match
15811588
# them up.
15821589
@overload
1590+
def __add__(
1591+
self: Series[_T_INT], other: int | Sequence[int] | Series[int]
1592+
) -> Series[_T_INT]: ...
1593+
@overload
1594+
def __add__(
1595+
self: Series[_T_INT_FLOAT],
1596+
other: _T_FLOAT_COMPLEX | Sequence[_T_FLOAT_COMPLEX] | Series[_T_FLOAT_COMPLEX],
1597+
) -> Series[_T_FLOAT_COMPLEX]: ...
1598+
@overload
1599+
def __add__(
1600+
self: Series[_T_COMPLEX],
1601+
other: (
1602+
_T_INT_FLOAT_COMPLEX
1603+
| Sequence[_T_INT_FLOAT_COMPLEX]
1604+
| Series[_T_INT_FLOAT_COMPLEX]
1605+
),
1606+
) -> Series[_T_COMPLEX]: ...
1607+
@overload
15831608
def __add__(self, other: S1 | Self) -> Self: ...
15841609
@overload
15851610
def __add__(
@@ -1610,6 +1635,24 @@ class Series(IndexOpsMixin[S1], NDFrame):
16101635
self, other: S1 | _ListLike | Series[S1] | datetime | timedelta | date
16111636
) -> Series[_bool]: ...
16121637
@overload
1638+
def __mul__(
1639+
self: Series[_T_INT_FLOAT], other: int | Sequence[int] | Series[int]
1640+
) -> Series[_T_INT_FLOAT]: ...
1641+
@overload
1642+
def __mul__(
1643+
self: Series[_T_INT_FLOAT],
1644+
other: _T_FLOAT_COMPLEX | Sequence[_T_FLOAT_COMPLEX] | Series[_T_FLOAT_COMPLEX],
1645+
) -> Series[_T_FLOAT_COMPLEX]: ...
1646+
@overload
1647+
def __mul__(
1648+
self: Series[_T_COMPLEX],
1649+
other: (
1650+
_T_INT_FLOAT_COMPLEX
1651+
| Sequence[_T_INT_FLOAT_COMPLEX]
1652+
| Series[_T_INT_FLOAT_COMPLEX]
1653+
),
1654+
) -> Series[_T_COMPLEX]: ...
1655+
@overload
16131656
def __mul__(
16141657
self, other: timedelta | Timedelta | TimedeltaSeries | np.timedelta64
16151658
) -> TimedeltaSeries: ...
@@ -1703,6 +1746,23 @@ class Series(IndexOpsMixin[S1], NDFrame):
17031746
@property
17041747
def loc(self) -> _LocIndexerSeries[S1]: ...
17051748
# Methods
1749+
@overload
1750+
def add(
1751+
self: Series[_T_INT],
1752+
other: int | Sequence[int] | Series[int],
1753+
level: Level | None = ...,
1754+
fill_value: float | None = ...,
1755+
axis: int = ...,
1756+
) -> Series[_T_INT]: ...
1757+
@overload
1758+
def add(
1759+
self: Series[_T_INT_FLOAT],
1760+
other: _T_FLOAT_COMPLEX | Sequence[_T_FLOAT_COMPLEX] | Series[_T_FLOAT_COMPLEX],
1761+
level: Level | None = ...,
1762+
fill_value: float | None = ...,
1763+
axis: int = ...,
1764+
) -> Series[_T_FLOAT_COMPLEX]: ...
1765+
@overload
17061766
def add(
17071767
self,
17081768
other: Series[S1] | Scalar,
@@ -1890,6 +1950,14 @@ class Series(IndexOpsMixin[S1], NDFrame):
18901950
axis: AxisIndex | None = ...,
18911951
) -> Series[S1]: ...
18921952
@overload
1953+
def mul(
1954+
self: Series[int],
1955+
other: int | Sequence[int] | Series[int],
1956+
level: Level | None = ...,
1957+
fill_value: float | None = ...,
1958+
axis: AxisIndex | None = ...,
1959+
) -> Series[int]: ...
1960+
@overload
18931961
def mul(
18941962
self,
18951963
other: timedelta | Timedelta | TimedeltaSeries | np.timedelta64,

tests/test_series.py

Lines changed: 110 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -815,11 +815,8 @@ def test_types_element_wise_arithmetic() -> None:
815815
check(assert_type(s - s2, pd.Series), pd.Series, np.integer)
816816
check(assert_type(s.sub(s2, fill_value=0), "pd.Series[int]"), pd.Series, np.integer)
817817

818-
# TODO these two below should type pd.Series[int]
819-
# check(assert_type(s * s2, "pd.Series[int]"), pd.Series, np.integer )
820-
check(assert_type(s * s2, pd.Series), pd.Series, np.integer)
821-
# check(assert_type(s.mul(s2, fill_value=0), "pd.Series[int]"), pd.Series, np.integer)
822-
check(assert_type(s.mul(s2, fill_value=0), pd.Series), pd.Series, np.integer)
818+
check(assert_type(s * s2, "pd.Series[int]"), pd.Series, np.integer)
819+
check(assert_type(s.mul(s2, fill_value=0), "pd.Series[int]"), pd.Series, np.integer)
823820

824821
# TODO these two below should type pd.Series[float]
825822
# check(assert_type(s / s2, "pd.Series[float]"), pd.Series, np.float64)
@@ -839,18 +836,34 @@ def test_types_element_wise_arithmetic() -> None:
839836

840837
check(assert_type(divmod(s, s2), tuple["pd.Series[int]", "pd.Series[int]"]), tuple)
841838

839+
s2_float = s2.astype(float)
840+
check(assert_type(s + s2_float, "pd.Series[float]"), pd.Series, np.floating)
841+
check(
842+
assert_type(s.add(s2_float, fill_value=0), "pd.Series[float]"),
843+
pd.Series,
844+
np.floating,
845+
)
846+
847+
check(assert_type(s * s2_float, "pd.Series[float]"), pd.Series, np.floating)
848+
check(assert_type(s2_float * s, "pd.Series[float]"), pd.Series, np.floating)
849+
check(assert_type(s.mul(s2_float, fill_value=0), pd.Series), pd.Series, np.floating)
850+
check(assert_type(s2_float.mul(s, fill_value=0), pd.Series), pd.Series, np.floating)
851+
842852

843853
def test_types_scalar_arithmetic() -> None:
854+
# TODO: assert_type
844855
s = pd.Series([0, 1, -10])
845856

846857
check(assert_type(s + 1, "pd.Series[int]"), pd.Series, np.integer)
858+
check(assert_type(1 + s, "pd.Series[int]"), pd.Series, np.integer)
847859
check(assert_type(s.add(1, fill_value=0), "pd.Series[int]"), pd.Series, np.integer)
848860

849861
res_sub: pd.Series = s - 1
850862
res_sub2: pd.Series = s.sub(1, fill_value=0)
851863

852-
res_mul: pd.Series = s * 2
853-
res_mul2: pd.Series = s.mul(2, fill_value=0)
864+
check(assert_type(s * 2, "pd.Series[int]"), pd.Series, np.integer)
865+
check(assert_type(2 * s, pd.Series), pd.Series, np.integer)
866+
check(assert_type(s.mul(2, fill_value=0), "pd.Series[int]"), pd.Series, np.integer)
854867

855868
res_div: pd.Series = s / 2
856869
res_div2: pd.Series = s.div(2, fill_value=0)
@@ -866,13 +879,95 @@ def test_types_scalar_arithmetic() -> None:
866879
res_pow2: pd.Series = s**0.213
867880
res_pow3: pd.Series = s.pow(0.5)
868881

882+
check(assert_type(s + 1.0, "pd.Series[float]"), pd.Series, np.floating)
883+
check(assert_type(1.0 + s, pd.Series), pd.Series, np.floating)
884+
check(
885+
assert_type(s.add(1.0, fill_value=0), "pd.Series[float]"),
886+
pd.Series,
887+
np.floating,
888+
)
889+
890+
check(assert_type(s - 1.0, pd.Series), pd.Series, np.floating)
891+
check(assert_type(1.0 - s, pd.Series), pd.Series, np.floating)
892+
# check(
893+
# assert_type(s.sub(1.0, fill_value=0), pd.Series),
894+
# pd.Series,
895+
# np.floating
896+
# )
897+
898+
check(assert_type(s * 2.0, "pd.Series[float]"), pd.Series, np.floating)
899+
check(assert_type(2.0 * s, pd.Series), pd.Series, np.floating)
900+
check(assert_type(s.mul(2.0, fill_value=0), pd.Series), pd.Series, np.floating)
901+
869902

870903
def test_types_complex_arithmetic() -> None:
871-
"""Test adding complex number to pd.Series[float] GH 103."""
904+
"""Test arithmetic of complex numbers GH 103, GH 1098."""
905+
# TODO: assert_type should always be "pd.Series[complex]"
872906
c = 1 + 1j
873-
s = pd.Series([1.0, 2.0, 3.0])
874-
x = s + c
875-
y = s - c
907+
s_int = pd.Series([1, 2, 3])
908+
check(assert_type(s_int + c, "pd.Series[complex]"), pd.Series, complex)
909+
check(assert_type(s_int - c, pd.Series), pd.Series, complex)
910+
check(assert_type(s_int * c, "pd.Series[complex]"), pd.Series, complex)
911+
check(assert_type(s_int / c, pd.Series), pd.Series, complex)
912+
check(assert_type(s_int.add(c), "pd.Series[complex]"), pd.Series, complex)
913+
# check(assert_type(s_int.sub(c), pd.Series), pd.Series, complex)
914+
check(assert_type(s_int.mul(c), pd.Series), pd.Series, complex)
915+
# check(assert_type(s_int.div(c), pd.Series), pd.Series, complex)
916+
check(assert_type(c + s_int, pd.Series), pd.Series, complex)
917+
check(assert_type(c - s_int, pd.Series), pd.Series, complex)
918+
check(assert_type(c * s_int, pd.Series), pd.Series, complex)
919+
check(assert_type(c / s_int, pd.Series), pd.Series, complex)
920+
921+
s_float = s_int.astype(float)
922+
check(assert_type(s_float + c, "pd.Series[complex]"), pd.Series, complex)
923+
check(assert_type(s_float - c, pd.Series), pd.Series, complex)
924+
check(assert_type(s_float * c, "pd.Series[complex]"), pd.Series, complex)
925+
check(assert_type(s_float / c, pd.Series), pd.Series, complex)
926+
check(assert_type(s_float.add(c), "pd.Series[complex]"), pd.Series, complex)
927+
# check(assert_type(s_float.sub(c), pd.Series), pd.Series, complex)
928+
check(assert_type(s_float.mul(c), pd.Series), pd.Series, complex)
929+
# check(assert_type(s_float.div(c), pd.Series), pd.Series, complex)
930+
check(assert_type(c + s_float, pd.Series), pd.Series, complex)
931+
check(assert_type(c - s_float, pd.Series), pd.Series, complex)
932+
check(assert_type(c * s_float, pd.Series), pd.Series, complex)
933+
check(assert_type(c / s_float, pd.Series), pd.Series, complex)
934+
935+
s_comp = s_int + c
936+
check(assert_type(s_comp + c, "pd.Series[complex]"), pd.Series, complex)
937+
check(assert_type(s_comp - c, pd.Series), pd.Series, complex)
938+
check(assert_type(s_comp * c, "pd.Series[complex]"), pd.Series, complex)
939+
check(assert_type(s_comp / c, pd.Series), pd.Series, complex)
940+
check(assert_type(s_comp.add(c), "pd.Series[complex]"), pd.Series, complex)
941+
check(assert_type(s_comp.sub(c), "pd.Series[complex]"), pd.Series, complex)
942+
check(assert_type(s_comp.mul(c), pd.Series), pd.Series, complex)
943+
# check(assert_type(s_comp.div(c), pd.Series), pd.Series, complex)
944+
check(assert_type(c + s_comp, "pd.Series[complex]"), pd.Series, complex)
945+
check(assert_type(c - s_comp, pd.Series), pd.Series, complex)
946+
check(assert_type(c * s_comp, pd.Series), pd.Series, complex)
947+
check(assert_type(c / s_comp, pd.Series), pd.Series, complex)
948+
949+
check(assert_type(s_int + s_comp, "pd.Series[complex]"), pd.Series, complex)
950+
check(assert_type(s_int - s_comp, pd.Series), pd.Series, complex)
951+
check(assert_type(s_int * s_comp, "pd.Series[complex]"), pd.Series, complex)
952+
check(assert_type(s_int / s_comp, pd.Series), pd.Series, complex)
953+
check(assert_type(s_comp + s_int, "pd.Series[complex]"), pd.Series, complex)
954+
check(assert_type(s_comp - s_int, pd.Series), pd.Series, complex)
955+
check(assert_type(s_comp * s_int, "pd.Series[complex]"), pd.Series, complex)
956+
check(assert_type(s_comp / s_int, pd.Series), pd.Series, complex)
957+
check(assert_type(s_float + s_comp, "pd.Series[complex]"), pd.Series, complex)
958+
check(assert_type(s_float - s_comp, pd.Series), pd.Series, complex)
959+
check(assert_type(s_float * s_comp, "pd.Series[complex]"), pd.Series, complex)
960+
check(assert_type(s_float / s_comp, pd.Series), pd.Series, complex)
961+
check(assert_type(s_comp + s_float, "pd.Series[complex]"), pd.Series, complex)
962+
check(assert_type(s_comp - s_float, pd.Series), pd.Series, complex)
963+
check(assert_type(s_comp * s_float, "pd.Series[complex]"), pd.Series, complex)
964+
check(assert_type(s_comp / s_float, pd.Series), pd.Series, complex)
965+
966+
s2_comp = s_comp + c
967+
check(assert_type(s2_comp + s_comp, "pd.Series[complex]"), pd.Series, complex)
968+
check(assert_type(s2_comp - s_comp, pd.Series), pd.Series, complex)
969+
check(assert_type(s2_comp * s_comp, "pd.Series[complex]"), pd.Series, complex)
970+
check(assert_type(s2_comp / s_comp, pd.Series), pd.Series, complex)
876971

877972

878973
def test_types_groupby() -> None:
@@ -1636,11 +1731,11 @@ def test_series_multiindex_getitem() -> None:
16361731
def test_series_mul() -> None:
16371732
s = pd.Series([1, 2, 3])
16381733
sm = s * 4
1639-
check(assert_type(sm, pd.Series), pd.Series)
1734+
check(assert_type(sm, "pd.Series[int]"), pd.Series, np.integer)
16401735
ss = s - 4
16411736
check(assert_type(ss, pd.Series), pd.Series)
16421737
sm2 = s * s
1643-
check(assert_type(sm2, pd.Series), pd.Series)
1738+
check(assert_type(sm2, "pd.Series[int]"), pd.Series, np.integer)
16441739
sp = s + 4
16451740
check(assert_type(sp, "pd.Series[int]"), pd.Series, np.integer)
16461741

@@ -3796,6 +3891,7 @@ def test_path_div() -> None:
37963891
check(assert_type(folder / files, pd.Series), pd.Series, Path)
37973892

37983893
folders = pd.Series([folder, folder])
3894+
folders.__truediv__(Path("a.png"))
37993895
check(assert_type(folders / Path("a.png"), pd.Series), pd.Series, Path)
38003896

38013897

@@ -3892,6 +3988,7 @@ def foo(sf: pd.Series) -> None:
38923988
pass
38933989

38943990
foo(s)
3991+
s.__add__(pd.Series([1]))
38953992
check(assert_type(s + pd.Series([1]), pd.Series), pd.Series)
38963993

38973994

0 commit comments

Comments
 (0)