Skip to content

Commit 7e6aee4

Browse files
authored
GroupBy[Series].count() return type should be Series[int] (#966)
* GroupBy[Series].count() return type should be Series[int] * Use np.integer instead of np.int_ * Update pyright requirement '>=1.1.369' -> '>=1.1.374'
1 parent 458ecb4 commit 7e6aee4

File tree

4 files changed

+13
-2
lines changed

4 files changed

+13
-2
lines changed

pandas-stubs/core/groupby/groupby.pyi

+4-1
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,10 @@ class GroupBy(BaseGroupBy[NDFrameT]):
176176
@overload
177177
def all(self: GroupBy[DataFrame], skipna: bool = ...) -> DataFrame: ...
178178
@final
179-
def count(self) -> NDFrameT: ...
179+
@overload
180+
def count(self: GroupBy[Series]) -> Series[int]: ...
181+
@overload
182+
def count(self: GroupBy[DataFrame]) -> DataFrame: ...
180183
@final
181184
def mean(
182185
self,

pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ mypy = "1.10.1"
4040
pandas = "2.2.2"
4141
pyarrow = ">=10.0.1"
4242
pytest = ">=7.1.2"
43-
pyright = ">=1.1.369"
43+
pyright = ">= 1.1.374"
4444
poethepoet = ">=0.16.5"
4545
loguru = ">=0.6.0"
4646
typing-extensions = ">=4.4.0"

tests/test_frame.py

+1
Original file line numberDiff line numberDiff line change
@@ -1025,6 +1025,7 @@ def test_types_groupby_methods() -> None:
10251025
check(assert_type(df.groupby("col1").sum(), pd.DataFrame), pd.DataFrame)
10261026
check(assert_type(df.groupby("col1").prod(), pd.DataFrame), pd.DataFrame)
10271027
check(assert_type(df.groupby("col1").sample(), pd.DataFrame), pd.DataFrame)
1028+
check(assert_type(df.groupby("col1").count(), pd.DataFrame), pd.DataFrame)
10281029
check(
10291030
assert_type(df.groupby("col1").value_counts(normalize=False), "pd.Series[int]"),
10301031
pd.Series,

tests/test_series.py

+7
Original file line numberDiff line numberDiff line change
@@ -676,6 +676,13 @@ def test_types_groupby_methods() -> None:
676676
check(assert_type(s.groupby(level=0).idxmax(), pd.Series), pd.Series)
677677
check(assert_type(s.groupby(level=0).idxmin(), pd.Series), pd.Series)
678678

679+
s2 = pd.Series(["w", "x", "y", "z"], index=[3, 4, 3, 4], dtype=str)
680+
check(
681+
assert_type(s2.groupby(level=0).count(), "pd.Series[int]"),
682+
pd.Series,
683+
np.integer,
684+
)
685+
679686

680687
def test_groupby_result() -> None:
681688
# GH 142

0 commit comments

Comments
 (0)