File tree 2 files changed +12
-3
lines changed
2 files changed +12
-3
lines changed Original file line number Diff line number Diff line change @@ -1224,7 +1224,7 @@ class DataFrame(NDFrame, OpsMixin):
1224
1224
@overload
1225
1225
def apply (
1226
1226
self ,
1227
- f : Callable [..., S1 ],
1227
+ f : Callable [..., S1 | NAType ],
1228
1228
axis : AxisIndex = ...,
1229
1229
raw : _bool = ...,
1230
1230
result_type : None = ...,
@@ -1248,7 +1248,7 @@ class DataFrame(NDFrame, OpsMixin):
1248
1248
@overload
1249
1249
def apply (
1250
1250
self ,
1251
- f : Callable [..., S1 ],
1251
+ f : Callable [..., S1 | NAType ],
1252
1252
axis : Axis = ...,
1253
1253
raw : _bool = ...,
1254
1254
args : Any = ...,
@@ -1309,7 +1309,7 @@ class DataFrame(NDFrame, OpsMixin):
1309
1309
@overload
1310
1310
def apply (
1311
1311
self ,
1312
- f : Callable [..., S1 ],
1312
+ f : Callable [..., S1 | NAType ],
1313
1313
raw : _bool = ...,
1314
1314
result_type : None = ...,
1315
1315
args : Any = ...,
Original file line number Diff line number Diff line change 43
43
)
44
44
import xarray as xr
45
45
46
+ from pandas ._libs .missing import NAType
46
47
from pandas ._typing import Scalar
47
48
48
49
from tests import (
@@ -578,6 +579,9 @@ def test_types_apply() -> None:
578
579
def returns_scalar (x : pd .Series ) -> int :
579
580
return 2
580
581
582
+ def returns_scalar_na (x : pd .Series ) -> int | NAType :
583
+ return 2 if (x < 5 ).all () else pd .NA
584
+
581
585
def returns_series (x : pd .Series ) -> pd .Series :
582
586
return x ** 2
583
587
@@ -604,6 +608,11 @@ def gethead(s: pd.Series, y: int) -> pd.Series:
604
608
check (
605
609
assert_type (df .apply (returns_scalar ), "pd.Series[int]" ), pd .Series , np .integer
606
610
)
611
+ check (
612
+ assert_type (df .apply (returns_scalar_na ), "pd.Series[int]" ),
613
+ pd .Series ,
614
+ int ,
615
+ )
607
616
check (assert_type (df .apply (returns_series ), pd .DataFrame ), pd .DataFrame )
608
617
check (assert_type (df .apply (returns_listlike_of_3 ), pd .DataFrame ), pd .DataFrame )
609
618
check (assert_type (df .apply (returns_dict ), pd .Series ), pd .Series )
You can’t perform that action at this time.
0 commit comments