diff --git a/pandas-stubs/core/indexes/base.pyi b/pandas-stubs/core/indexes/base.pyi index a28ec8c1b..2708e2247 100644 --- a/pandas-stubs/core/indexes/base.pyi +++ b/pandas-stubs/core/indexes/base.pyi @@ -13,6 +13,7 @@ from typing import ( Any, ClassVar, Literal, + TypeAlias, final, overload, ) @@ -263,7 +264,16 @@ class Index(IndexOpsMixin[S1]): @property def str( self, - ) -> StringMethods[Self, MultiIndex, np_ndarray_bool, Index[list[str]]]: ... + ) -> StringMethods[ + Self, + MultiIndex, + np_ndarray_bool, + Index[list[str]], + Index[int], + Index[bytes], + Index[str], + Index[type[object]], + ]: ... def is_(self, other) -> bool: ... def __len__(self) -> int: ... def __array__(self, dtype=...) -> np.ndarray: ... @@ -455,6 +465,8 @@ class Index(IndexOpsMixin[S1]): ), ) -> Self: ... +UnknownIndex: TypeAlias = Index[Any] + def ensure_index_from_sequences( sequences: Sequence[Sequence[Dtype]], names: list[str] = ... ) -> Index: ... diff --git a/pandas-stubs/core/series.pyi b/pandas-stubs/core/series.pyi index 072eb7240..526cac8b0 100644 --- a/pandas-stubs/core/series.pyi +++ b/pandas-stubs/core/series.pyi @@ -1179,7 +1179,16 @@ class Series(IndexOpsMixin[S1], NDFrame): @property def str( self, - ) -> StringMethods[Series, DataFrame, Series[bool], Series[list[str]]]: ... + ) -> StringMethods[ + Self, + DataFrame, + Series[bool], + Series[list[str]], + Series[int], + Series[bytes], + Series[str], + Series[type[object]], + ]: ... @property def dt(self) -> CombinedDatetimelikeProperties: ... @property @@ -2318,3 +2327,5 @@ class IntervalSeries(Series[Interval[_OrderableT]], Generic[_OrderableT]): @property def array(self) -> IntervalArray: ... def diff(self, periods: int = ...) -> Never: ... + +UnknownSeries: TypeAlias = Series[Any] diff --git a/pandas-stubs/core/strings.pyi b/pandas-stubs/core/strings.pyi index c12851705..4d215e82a 100644 --- a/pandas-stubs/core/strings.pyi +++ b/pandas-stubs/core/strings.pyi @@ -1,3 +1,4 @@ +# pyright: strict from collections.abc import ( Callable, Sequence, @@ -12,6 +13,7 @@ from typing import ( ) import numpy as np +import numpy.typing as npt import pandas as pd from pandas import ( DataFrame, @@ -21,23 +23,36 @@ from pandas import ( ) from pandas.core.base import NoNewAttributesMixin +from pandas._libs.tslibs.nattype import NaTType from pandas._typing import ( JoinHow, + Scalar, T, np_ndarray_bool, ) -# The _TS type is what is used for the result of str.split with expand=True -_TS = TypeVar("_TS", bound=DataFrame | MultiIndex) -# The _TS2 type is what is used for the result of str.split with expand=False -_TS2 = TypeVar("_TS2", bound=Series[list[str]] | Index[list[str]]) -# The _TM type is what is used for the result of str.match -_TM = TypeVar("_TM", bound=Series[bool] | np_ndarray_bool) +# Used for the result of str.split with expand=True +_T_EXPANDING = TypeVar("_T_EXPANDING", bound=DataFrame | MultiIndex) +# Used for the result of str.split with expand=False +_T_LIST_STR = TypeVar("_T_LIST_STR", bound=Series[list[str]] | Index[list[str]]) +# Used for the result of str.match +_T_BOOL = TypeVar("_T_BOOL", bound=Series[bool] | np_ndarray_bool) +# Used for the result of str.index / str.find +_T_INT = TypeVar("_T_INT", bound=Series[int] | Index[int]) +# Used for the result of str.encode +_T_BYTES = TypeVar("_T_BYTES", bound=Series[bytes] | Index[bytes]) +# Used for the result of str.decode +_T_STR = TypeVar("_T_STR", bound=Series[str] | Index[str]) +# Used for the result of str.partition +_T_OBJECT = TypeVar("_T_OBJECT", bound=Series[type[object]] | Index[type[object]]) -class StringMethods(NoNewAttributesMixin, Generic[T, _TS, _TM, _TS2]): +class StringMethods( + NoNewAttributesMixin, + Generic[T, _T_EXPANDING, _T_BOOL, _T_LIST_STR, _T_INT, _T_BYTES, _T_STR, _T_OBJECT], +): def __init__(self, data: T) -> None: ... - def __getitem__(self, key: slice | int) -> T: ... - def __iter__(self) -> T: ... + def __getitem__(self, key: slice | int) -> _T_STR: ... + def __iter__(self) -> _T_STR: ... @overload def cat( self, @@ -58,15 +73,17 @@ class StringMethods(NoNewAttributesMixin, Generic[T, _TS, _TM, _TS2]): @overload def cat( self, - others: Series | pd.Index | pd.DataFrame | np.ndarray | list[Any], + others: ( + Series[str] | Index[str] | pd.DataFrame | npt.NDArray[np.str_] | list[str] + ), sep: str = ..., na_rep: str | None = ..., join: JoinHow = ..., - ) -> T: ... + ) -> _T_STR: ... @overload def split( self, pat: str = ..., *, n: int = ..., expand: Literal[True], regex: bool = ... - ) -> _TS: ... + ) -> _T_EXPANDING: ... @overload def split( self, @@ -75,77 +92,79 @@ class StringMethods(NoNewAttributesMixin, Generic[T, _TS, _TM, _TS2]): n: int = ..., expand: Literal[False] = ..., regex: bool = ..., - ) -> _TS2: ... + ) -> _T_LIST_STR: ... @overload - def rsplit(self, pat: str = ..., *, n: int = ..., expand: Literal[True]) -> _TS: ... + def rsplit( + self, pat: str = ..., *, n: int = ..., expand: Literal[True] + ) -> _T_EXPANDING: ... @overload def rsplit( self, pat: str = ..., *, n: int = ..., expand: Literal[False] = ... - ) -> _TS2: ... + ) -> _T_LIST_STR: ... @overload - def partition(self, sep: str = ...) -> pd.DataFrame: ... + def partition(self, sep: str = ...) -> _T_EXPANDING: ... @overload - def partition(self, *, expand: Literal[True]) -> pd.DataFrame: ... + def partition(self, *, expand: Literal[True]) -> _T_EXPANDING: ... @overload - def partition(self, sep: str, expand: Literal[True]) -> pd.DataFrame: ... + def partition(self, sep: str, expand: Literal[True]) -> _T_EXPANDING: ... @overload - def partition(self, sep: str, expand: Literal[False]) -> T: ... + def partition(self, sep: str, expand: Literal[False]) -> _T_OBJECT: ... @overload - def partition(self, *, expand: Literal[False]) -> T: ... + def partition(self, *, expand: Literal[False]) -> _T_OBJECT: ... @overload - def rpartition(self, sep: str = ...) -> pd.DataFrame: ... + def rpartition(self, sep: str = ...) -> _T_EXPANDING: ... @overload - def rpartition(self, *, expand: Literal[True]) -> pd.DataFrame: ... + def rpartition(self, *, expand: Literal[True]) -> _T_EXPANDING: ... @overload - def rpartition(self, sep: str, expand: Literal[True]) -> pd.DataFrame: ... + def rpartition(self, sep: str, expand: Literal[True]) -> _T_EXPANDING: ... @overload - def rpartition(self, sep: str, expand: Literal[False]) -> T: ... + def rpartition(self, sep: str, expand: Literal[False]) -> _T_OBJECT: ... @overload - def rpartition(self, *, expand: Literal[False]) -> T: ... - def get(self, i: int) -> T: ... - def join(self, sep: str) -> T: ... + def rpartition(self, *, expand: Literal[False]) -> _T_OBJECT: ... + def get(self, i: int) -> _T_STR: ... + def join(self, sep: str) -> _T_STR: ... def contains( self, - pat: str | re.Pattern, + pat: str | re.Pattern[str], case: bool = ..., flags: int = ..., - na=..., + na: Scalar | NaTType | None = ..., regex: bool = ..., - ) -> Series[bool]: ... + ) -> _T_BOOL: ... def match( self, pat: str, case: bool = ..., flags: int = ..., na: Any = ... - ) -> _TM: ... + ) -> _T_BOOL: ... def replace( self, pat: str, - repl: str | Callable[[re.Match], str], + repl: str | Callable[[re.Match[str]], str], n: int = ..., case: bool | None = ..., flags: int = ..., regex: bool = ..., - ) -> T: ... - def repeat(self, repeats: int | Sequence[int]) -> T: ... + ) -> _T_STR: ... + def repeat(self, repeats: int | Sequence[int]) -> _T_STR: ... def pad( self, width: int, side: Literal["left", "right", "both"] = ..., fillchar: str = ..., - ) -> T: ... - def center(self, width: int, fillchar: str = ...) -> T: ... - def ljust(self, width: int, fillchar: str = ...) -> T: ... - def rjust(self, width: int, fillchar: str = ...) -> T: ... - def zfill(self, width: int) -> T: ... + ) -> _T_STR: ... + def center(self, width: int, fillchar: str = ...) -> _T_STR: ... + def ljust(self, width: int, fillchar: str = ...) -> _T_STR: ... + def rjust(self, width: int, fillchar: str = ...) -> _T_STR: ... + def zfill(self, width: int) -> _T_STR: ... def slice( self, start: int | None = ..., stop: int | None = ..., step: int | None = ... ) -> T: ... def slice_replace( self, start: int | None = ..., stop: int | None = ..., repl: str | None = ... - ) -> T: ... - def decode(self, encoding: str, errors: str = ...) -> T: ... - def encode(self, encoding: str, errors: str = ...) -> T: ... - def strip(self, to_strip: str | None = ...) -> T: ... - def lstrip(self, to_strip: str | None = ...) -> T: ... - def rstrip(self, to_strip: str | None = ...) -> T: ... + ) -> _T_STR: ... + def decode(self, encoding: str, errors: str = ...) -> _T_STR: ... + def encode(self, encoding: str, errors: str = ...) -> _T_BYTES: ... + def strip(self, to_strip: str | None = ...) -> _T_STR: ... + def lstrip(self, to_strip: str | None = ...) -> _T_STR: ... + def rstrip(self, to_strip: str | None = ...) -> _T_STR: ... def wrap( self, width: int, @@ -154,45 +173,47 @@ class StringMethods(NoNewAttributesMixin, Generic[T, _TS, _TM, _TS2]): drop_whitespace: bool | None = ..., break_long_words: bool | None = ..., break_on_hyphens: bool | None = ..., - ) -> T: ... - def get_dummies(self, sep: str = ...) -> pd.DataFrame: ... - def translate(self, table: dict[int, int | str | None] | None) -> T: ... - def count(self, pat: str, flags: int = ...) -> Series[int]: ... - def startswith(self, pat: str | tuple[str, ...], na: Any = ...) -> Series[bool]: ... - def endswith(self, pat: str | tuple[str, ...], na: Any = ...) -> Series[bool]: ... - def findall(self, pat: str, flags: int = ...) -> Series: ... + ) -> _T_STR: ... + def get_dummies(self, sep: str = ...) -> _T_EXPANDING: ... + def translate(self, table: dict[int, int | str | None] | None) -> _T_STR: ... + def count(self, pat: str, flags: int = ...) -> _T_INT: ... + def startswith(self, pat: str | tuple[str, ...], na: Any = ...) -> _T_BOOL: ... + def endswith(self, pat: str | tuple[str, ...], na: Any = ...) -> _T_BOOL: ... + def findall(self, pat: str, flags: int = ...) -> _T_LIST_STR: ... @overload def extract( self, pat: str, flags: int = ..., *, expand: Literal[True] = ... ) -> pd.DataFrame: ... @overload - def extract(self, pat: str, flags: int, expand: Literal[False]) -> T: ... + def extract(self, pat: str, flags: int, expand: Literal[False]) -> _T_OBJECT: ... @overload - def extract(self, pat: str, flags: int = ..., *, expand: Literal[False]) -> T: ... + def extract( + self, pat: str, flags: int = ..., *, expand: Literal[False] + ) -> _T_OBJECT: ... def extractall(self, pat: str, flags: int = ...) -> pd.DataFrame: ... - def find(self, sub: str, start: int = ..., end: int | None = ...) -> T: ... - def rfind(self, sub: str, start: int = ..., end: int | None = ...) -> T: ... - def normalize(self, form: Literal["NFC", "NFKC", "NFD", "NFKD"]) -> T: ... - def index(self, sub: str, start: int = ..., end: int | None = ...) -> T: ... - def rindex(self, sub: str, start: int = ..., end: int | None = ...) -> T: ... - def len(self) -> Series[int]: ... - def lower(self) -> T: ... - def upper(self) -> T: ... - def title(self) -> T: ... - def capitalize(self) -> T: ... - def swapcase(self) -> T: ... - def casefold(self) -> T: ... - def isalnum(self) -> Series[bool]: ... - def isalpha(self) -> Series[bool]: ... - def isdigit(self) -> Series[bool]: ... - def isspace(self) -> Series[bool]: ... - def islower(self) -> Series[bool]: ... - def isupper(self) -> Series[bool]: ... - def istitle(self) -> Series[bool]: ... - def isnumeric(self) -> Series[bool]: ... - def isdecimal(self) -> Series[bool]: ... + def find(self, sub: str, start: int = ..., end: int | None = ...) -> _T_INT: ... + def rfind(self, sub: str, start: int = ..., end: int | None = ...) -> _T_INT: ... + def normalize(self, form: Literal["NFC", "NFKC", "NFD", "NFKD"]) -> _T_STR: ... + def index(self, sub: str, start: int = ..., end: int | None = ...) -> _T_INT: ... + def rindex(self, sub: str, start: int = ..., end: int | None = ...) -> _T_INT: ... + def len(self) -> _T_INT: ... + def lower(self) -> _T_STR: ... + def upper(self) -> _T_STR: ... + def title(self) -> _T_STR: ... + def capitalize(self) -> _T_STR: ... + def swapcase(self) -> _T_STR: ... + def casefold(self) -> _T_STR: ... + def isalnum(self) -> _T_BOOL: ... + def isalpha(self) -> _T_BOOL: ... + def isdigit(self) -> _T_BOOL: ... + def isspace(self) -> _T_BOOL: ... + def islower(self) -> _T_BOOL: ... + def isupper(self) -> _T_BOOL: ... + def istitle(self) -> _T_BOOL: ... + def isnumeric(self) -> _T_BOOL: ... + def isdecimal(self) -> _T_BOOL: ... def fullmatch( self, pat: str, case: bool = ..., flags: int = ..., na: Any = ... - ) -> Series[bool]: ... - def removeprefix(self, prefix: str) -> T: ... - def removesuffix(self, suffix: str) -> T: ... + ) -> _T_BOOL: ... + def removeprefix(self, prefix: str) -> _T_STR: ... + def removesuffix(self, suffix: str) -> _T_STR: ... diff --git a/tests/test_series.py b/tests/test_series.py index 9bce2bd10..61e1939a7 100644 --- a/tests/test_series.py +++ b/tests/test_series.py @@ -88,6 +88,7 @@ ) from pandas._typing import np_ndarray_int # noqa: F401 + # Tests will use numpy 2.1 in python 3.10 or later # From Numpy 2.1 __init__.pyi _DTypeKind: TypeAlias = Literal[ @@ -1602,148 +1603,6 @@ def test_categorical_codes(): assert_type(cat.codes, "np_ndarray_int") -def test_string_accessors(): - s = pd.Series( - ["applep", "bananap", "Cherryp", "DATEp", "eGGpLANTp", "123p", "23.45p"] - ) - s2 = pd.Series([["apple", "banana"], ["cherry", "date"], [1, "eggplant"]]) - s3 = pd.Series(["a1", "b2", "c3"]) - check(assert_type(s.str.capitalize(), pd.Series), pd.Series) - check(assert_type(s.str.casefold(), pd.Series), pd.Series) - check(assert_type(s.str.cat(sep="X"), str), str) - check(assert_type(s.str.center(10), pd.Series), pd.Series) - check(assert_type(s.str.contains("a"), "pd.Series[bool]"), pd.Series, np.bool_) - check( - assert_type(s.str.contains(re.compile(r"a")), "pd.Series[bool]"), - pd.Series, - np.bool_, - ) - check(assert_type(s.str.count("pp"), "pd.Series[int]"), pd.Series, np.integer) - check(assert_type(s.str.decode("utf-8"), pd.Series), pd.Series) - check(assert_type(s.str.encode("latin-1"), pd.Series), pd.Series) - check(assert_type(s.str.endswith("e"), "pd.Series[bool]"), pd.Series, np.bool_) - check( - assert_type(s.str.endswith(("e", "f")), "pd.Series[bool]"), pd.Series, np.bool_ - ) - check(assert_type(s3.str.extract(r"([ab])?(\d)"), pd.DataFrame), pd.DataFrame) - check(assert_type(s3.str.extractall(r"([ab])?(\d)"), pd.DataFrame), pd.DataFrame) - check(assert_type(s.str.find("p"), pd.Series), pd.Series) - check(assert_type(s.str.findall("pp"), pd.Series), pd.Series) - check(assert_type(s.str.fullmatch("apple"), "pd.Series[bool]"), pd.Series, np.bool_) - check(assert_type(s.str.get(2), pd.Series), pd.Series) - check(assert_type(s.str.get_dummies(), pd.DataFrame), pd.DataFrame) - check(assert_type(s.str.index("p"), pd.Series), pd.Series) - check(assert_type(s.str.isalnum(), "pd.Series[bool]"), pd.Series, np.bool_) - check(assert_type(s.str.isalpha(), "pd.Series[bool]"), pd.Series, np.bool_) - check(assert_type(s.str.isdecimal(), "pd.Series[bool]"), pd.Series, np.bool_) - check(assert_type(s.str.isdigit(), "pd.Series[bool]"), pd.Series, np.bool_) - check(assert_type(s.str.isnumeric(), "pd.Series[bool]"), pd.Series, np.bool_) - check(assert_type(s.str.islower(), "pd.Series[bool]"), pd.Series, np.bool_) - check(assert_type(s.str.isspace(), "pd.Series[bool]"), pd.Series, np.bool_) - check(assert_type(s.str.istitle(), "pd.Series[bool]"), pd.Series, np.bool_) - check(assert_type(s.str.isupper(), "pd.Series[bool]"), pd.Series, np.bool_) - check(assert_type(s2.str.join("-"), pd.Series), pd.Series) - check(assert_type(s.str.len(), "pd.Series[int]"), pd.Series, np.integer) - check(assert_type(s.str.ljust(80), pd.Series), pd.Series) - check(assert_type(s.str.lower(), pd.Series), pd.Series) - check(assert_type(s.str.lstrip("a"), pd.Series), pd.Series) - check(assert_type(s.str.match("pp"), "pd.Series[bool]"), pd.Series, np.bool_) - check(assert_type(s.str.normalize("NFD"), pd.Series), pd.Series) - check(assert_type(s.str.pad(80, "right"), pd.Series), pd.Series) - check(assert_type(s.str.partition("p"), pd.DataFrame), pd.DataFrame) - check(assert_type(s.str.removeprefix("a"), pd.Series), pd.Series) - check(assert_type(s.str.removesuffix("e"), pd.Series), pd.Series) - check(assert_type(s.str.repeat(2), pd.Series), pd.Series) - check(assert_type(s.str.replace("a", "X"), pd.Series), pd.Series) - check(assert_type(s.str.rfind("e"), pd.Series), pd.Series) - check(assert_type(s.str.rindex("p"), pd.Series), pd.Series) - check(assert_type(s.str.rjust(80), pd.Series), pd.Series) - check(assert_type(s.str.rpartition("p"), pd.DataFrame), pd.DataFrame) - check(assert_type(s.str.rsplit("a"), "pd.Series[list[str]]"), pd.Series, list) - check(assert_type(s.str.rsplit("a", expand=True), pd.DataFrame), pd.DataFrame) - check( - assert_type(s.str.rsplit("a", expand=False), "pd.Series[list[str]]"), - pd.Series, - list, - ) - check(assert_type(s.str.rstrip(), pd.Series), pd.Series) - check(assert_type(s.str.slice(0, 4, 2), pd.Series), pd.Series) - check(assert_type(s.str.slice_replace(0, 2, "XX"), pd.Series), pd.Series) - check(assert_type(s.str.split("a"), "pd.Series[list[str]]"), pd.Series, list) - # GH 194 - check(assert_type(s.str.split("a", expand=True), pd.DataFrame), pd.DataFrame) - check( - assert_type(s.str.split("a", expand=False), "pd.Series[list[str]]"), - pd.Series, - list, - ) - check(assert_type(s.str.startswith("a"), "pd.Series[bool]"), pd.Series, np.bool_) - check( - assert_type(s.str.startswith(("a", "b")), "pd.Series[bool]"), - pd.Series, - np.bool_, - ) - check(assert_type(s.str.strip(), pd.Series), pd.Series) - check(assert_type(s.str.swapcase(), pd.Series), pd.Series) - check(assert_type(s.str.title(), pd.Series), pd.Series) - check(assert_type(s.str.translate(None), pd.Series), pd.Series) - check(assert_type(s.str.upper(), pd.Series), pd.Series) - check(assert_type(s.str.wrap(80), pd.Series), pd.Series) - check(assert_type(s.str.zfill(10), pd.Series), pd.Series) - - -def test_series_overloads_cat(): - s = pd.Series( - ["applep", "bananap", "Cherryp", "DATEp", "eGGpLANTp", "123p", "23.45p"] - ) - check(assert_type(s.str.cat(sep=";"), str), str) - check(assert_type(s.str.cat(None, sep=";"), str), str) - check( - assert_type(s.str.cat(["A", "B", "C", "D", "E", "F", "G"], sep=";"), pd.Series), - pd.Series, - ) - - -def test_series_overloads_partition(): - s = pd.Series( - [ - "ap;pl;ep", - "ban;an;ap", - "Che;rr;yp", - "DA;TEp", - "eGGp;LANT;p", - "12;3p", - "23.45p", - ] - ) - check(assert_type(s.str.partition(sep=";"), pd.DataFrame), pd.DataFrame) - check( - assert_type(s.str.partition(sep=";", expand=True), pd.DataFrame), pd.DataFrame - ) - check(assert_type(s.str.partition(sep=";", expand=False), pd.Series), pd.Series) - - check(assert_type(s.str.rpartition(sep=";"), pd.DataFrame), pd.DataFrame) - check( - assert_type(s.str.rpartition(sep=";", expand=True), pd.DataFrame), pd.DataFrame - ) - check(assert_type(s.str.rpartition(sep=";", expand=False), pd.Series), pd.Series) - - -def test_series_overloads_extract(): - s = pd.Series( - ["appl;ep", "ban;anap", "Cherr;yp", "DATEp", "eGGp;LANTp", "12;3p", "23.45p"] - ) - check(assert_type(s.str.extract(r"[ab](\d)"), pd.DataFrame), pd.DataFrame) - check( - assert_type(s.str.extract(r"[ab](\d)", expand=True), pd.DataFrame), pd.DataFrame - ) - check(assert_type(s.str.extract(r"[ab](\d)", expand=False), pd.Series), pd.Series) - check( - assert_type(s.str.extract(r"[ab](\d)", re.IGNORECASE, False), pd.Series), - pd.Series, - ) - - def test_relops() -> None: # GH 175 s: str = "abc" diff --git a/tests/test_string_accessors.py b/tests/test_string_accessors.py new file mode 100644 index 000000000..4e83160f8 --- /dev/null +++ b/tests/test_string_accessors.py @@ -0,0 +1,415 @@ +import functools +import re + +import numpy as np +import numpy.typing as npt +import pandas as pd +from typing_extensions import assert_type + +from tests import check + +# Separately define here so pytest works +np_ndarray_bool = npt.NDArray[np.bool_] + + +DATA = ["applep", "bananap", "Cherryp", "DATEp", "eGGpLANTp", "123p", "23.45p"] +DATA_BYTES = [b"applep", b"bananap"] + + +def test_string_accessors_type_preserving_series() -> None: + s_str = pd.Series(DATA) + s_bytes = pd.Series(DATA_BYTES) + check(assert_type(s_str.str.slice(0, 4, 2), "pd.Series[str]"), pd.Series, str) + check(assert_type(s_bytes.str.slice(0, 4, 2), "pd.Series[bytes]"), pd.Series, bytes) + + +def test_string_accessors_type_preserving_index() -> None: + idx_str = pd.Index(DATA) + idx_bytes = pd.Index(DATA_BYTES) + check(assert_type(idx_str.str.slice(0, 4, 2), "pd.Index[str]"), pd.Index, str) + check(assert_type(idx_bytes.str.slice(0, 4, 2), "pd.Index[bytes]"), pd.Index, bytes) + + +def test_string_accessors_boolean_series(): + s = pd.Series(DATA) + _check = functools.partial(check, klass=pd.Series, dtype=np.bool_) + _check(assert_type(s.str.startswith("a"), "pd.Series[bool]")) + _check( + assert_type(s.str.startswith(("a", "b")), "pd.Series[bool]"), + ) + _check( + assert_type(s.str.contains("a"), "pd.Series[bool]"), + ) + _check( + assert_type(s.str.contains(re.compile(r"a")), "pd.Series[bool]"), + ) + _check(assert_type(s.str.endswith("e"), "pd.Series[bool]")) + _check(assert_type(s.str.endswith(("e", "f")), "pd.Series[bool]")) + _check(assert_type(s.str.fullmatch("apple"), "pd.Series[bool]")) + _check(assert_type(s.str.isalnum(), "pd.Series[bool]")) + _check(assert_type(s.str.isalpha(), "pd.Series[bool]")) + _check(assert_type(s.str.isdecimal(), "pd.Series[bool]")) + _check(assert_type(s.str.isdigit(), "pd.Series[bool]")) + _check(assert_type(s.str.isnumeric(), "pd.Series[bool]")) + _check(assert_type(s.str.islower(), "pd.Series[bool]")) + _check(assert_type(s.str.isspace(), "pd.Series[bool]")) + _check(assert_type(s.str.istitle(), "pd.Series[bool]")) + _check(assert_type(s.str.isupper(), "pd.Series[bool]")) + _check(assert_type(s.str.match("pp"), "pd.Series[bool]")) + + +def test_string_accessors_boolean_index(): + idx = pd.Index(DATA) + _check = functools.partial(check, klass=np.ndarray, dtype=np.bool_) + _check(assert_type(idx.str.startswith("a"), np_ndarray_bool)) + _check( + assert_type(idx.str.startswith(("a", "b")), np_ndarray_bool), + ) + _check( + assert_type(idx.str.contains("a"), np_ndarray_bool), + ) + _check( + assert_type(idx.str.contains(re.compile(r"a")), np_ndarray_bool), + ) + _check(assert_type(idx.str.endswith("e"), np_ndarray_bool)) + _check(assert_type(idx.str.endswith(("e", "f")), np_ndarray_bool)) + _check(assert_type(idx.str.fullmatch("apple"), np_ndarray_bool)) + _check(assert_type(idx.str.isalnum(), np_ndarray_bool)) + _check(assert_type(idx.str.isalpha(), np_ndarray_bool)) + _check(assert_type(idx.str.isdecimal(), np_ndarray_bool)) + _check(assert_type(idx.str.isdigit(), np_ndarray_bool)) + _check(assert_type(idx.str.isnumeric(), np_ndarray_bool)) + _check(assert_type(idx.str.islower(), np_ndarray_bool)) + _check(assert_type(idx.str.isspace(), np_ndarray_bool)) + _check(assert_type(idx.str.istitle(), np_ndarray_bool)) + _check(assert_type(idx.str.isupper(), np_ndarray_bool)) + _check(assert_type(idx.str.match("pp"), np_ndarray_bool)) + + +def test_string_accessors_integer_series(): + s = pd.Series(DATA) + _check = functools.partial(check, klass=pd.Series, dtype=np.integer) + _check(assert_type(s.str.find("p"), "pd.Series[int]")) + _check(assert_type(s.str.index("p"), "pd.Series[int]")) + _check(assert_type(s.str.rfind("e"), "pd.Series[int]")) + _check(assert_type(s.str.rindex("p"), "pd.Series[int]")) + _check(assert_type(s.str.count("pp"), "pd.Series[int]")) + _check(assert_type(s.str.len(), "pd.Series[int]")) + + +def test_string_accessors_integer_index(): + idx = pd.Index(DATA) + _check = functools.partial(check, klass=pd.Index, dtype=np.integer) + _check(assert_type(idx.str.find("p"), "pd.Index[int]")) + _check(assert_type(idx.str.index("p"), "pd.Index[int]")) + _check(assert_type(idx.str.rfind("e"), "pd.Index[int]")) + _check(assert_type(idx.str.rindex("p"), "pd.Index[int]")) + _check(assert_type(idx.str.count("pp"), "pd.Index[int]")) + _check(assert_type(idx.str.len(), "pd.Index[int]")) + + +def test_string_accessors_string_series(): + s = pd.Series(DATA) + _check = functools.partial(check, klass=pd.Series, dtype=str) + _check(assert_type(s.str.capitalize(), "pd.Series[str]")) + _check(assert_type(s.str.casefold(), "pd.Series[str]")) + check(assert_type(s.str.cat(sep="X"), str), str) + _check(assert_type(s.str.center(10), "pd.Series[str]")) + _check(assert_type(s.str.get(2), "pd.Series[str]")) + _check(assert_type(s.str.ljust(80), "pd.Series[str]")) + _check(assert_type(s.str.lower(), "pd.Series[str]")) + _check(assert_type(s.str.lstrip("a"), "pd.Series[str]")) + _check(assert_type(s.str.normalize("NFD"), "pd.Series[str]")) + _check(assert_type(s.str.pad(80, "right"), "pd.Series[str]")) + _check(assert_type(s.str.removeprefix("a"), "pd.Series[str]")) + _check(assert_type(s.str.removesuffix("e"), "pd.Series[str]")) + _check(assert_type(s.str.repeat(2), "pd.Series[str]")) + _check(assert_type(s.str.replace("a", "X"), "pd.Series[str]")) + _check(assert_type(s.str.rjust(80), "pd.Series[str]")) + _check(assert_type(s.str.rstrip(), "pd.Series[str]")) + _check(assert_type(s.str.slice_replace(0, 2, "XX"), "pd.Series[str]")) + _check(assert_type(s.str.strip(), "pd.Series[str]")) + _check(assert_type(s.str.swapcase(), "pd.Series[str]")) + _check(assert_type(s.str.title(), "pd.Series[str]")) + _check( + assert_type(s.str.translate({241: "n"}), "pd.Series[str]"), + ) + _check(assert_type(s.str.upper(), "pd.Series[str]")) + _check(assert_type(s.str.wrap(80), "pd.Series[str]")) + _check(assert_type(s.str.zfill(10), "pd.Series[str]")) + s_bytes = pd.Series([b"a1", b"b2", b"c3"]) + _check(assert_type(s_bytes.str.decode("utf-8"), "pd.Series[str]")) + s_list = pd.Series([["apple", "banana"], ["cherry", "date"], ["one", "eggplant"]]) + _check(assert_type(s_list.str.join("-"), "pd.Series[str]")) + + +def test_string_accessors_string_index(): + idx = pd.Index(DATA) + _check = functools.partial(check, klass=pd.Index, dtype=str) + _check(assert_type(idx.str.capitalize(), "pd.Index[str]")) + _check(assert_type(idx.str.casefold(), "pd.Index[str]")) + check(assert_type(idx.str.cat(sep="X"), str), str) + _check(assert_type(idx.str.center(10), "pd.Index[str]")) + _check(assert_type(idx.str.get(2), "pd.Index[str]")) + _check(assert_type(idx.str.ljust(80), "pd.Index[str]")) + _check(assert_type(idx.str.lower(), "pd.Index[str]")) + _check(assert_type(idx.str.lstrip("a"), "pd.Index[str]")) + _check(assert_type(idx.str.normalize("NFD"), "pd.Index[str]")) + _check(assert_type(idx.str.pad(80, "right"), "pd.Index[str]")) + _check(assert_type(idx.str.removeprefix("a"), "pd.Index[str]")) + _check(assert_type(idx.str.removesuffix("e"), "pd.Index[str]")) + _check(assert_type(idx.str.repeat(2), "pd.Index[str]")) + _check(assert_type(idx.str.replace("a", "X"), "pd.Index[str]")) + _check(assert_type(idx.str.rjust(80), "pd.Index[str]")) + _check(assert_type(idx.str.rstrip(), "pd.Index[str]")) + _check(assert_type(idx.str.slice_replace(0, 2, "XX"), "pd.Index[str]")) + _check(assert_type(idx.str.strip(), "pd.Index[str]")) + _check(assert_type(idx.str.swapcase(), "pd.Index[str]")) + _check(assert_type(idx.str.title(), "pd.Index[str]")) + _check( + assert_type(idx.str.translate({241: "n"}), "pd.Index[str]"), + ) + _check(assert_type(idx.str.upper(), "pd.Index[str]")) + _check(assert_type(idx.str.wrap(80), "pd.Index[str]")) + _check(assert_type(idx.str.zfill(10), "pd.Index[str]")) + idx_bytes = pd.Index([b"a1", b"b2", b"c3"]) + _check(assert_type(idx_bytes.str.decode("utf-8"), "pd.Index[str]")) + idx_list = pd.Index([["apple", "banana"], ["cherry", "date"], ["one", "eggplant"]]) + _check(assert_type(idx_list.str.join("-"), "pd.Index[str]")) + + +def test_string_accessors_bytes_series(): + s = pd.Series(["a1", "b2", "c3"]) + check(assert_type(s.str.encode("latin-1"), "pd.Series[bytes]"), pd.Series, bytes) + + +def test_string_accessors_bytes_index(): + s = pd.Index(["a1", "b2", "c3"]) + check(assert_type(s.str.encode("latin-1"), "pd.Index[bytes]"), pd.Index, bytes) + + +def test_string_accessors_list_series(): + s = pd.Series(DATA) + _check = functools.partial(check, klass=pd.Series, dtype=list) + _check(assert_type(s.str.findall("pp"), "pd.Series[list[str]]")) + _check(assert_type(s.str.split("a"), "pd.Series[list[str]]")) + # GH 194 + _check(assert_type(s.str.split("a", expand=False), "pd.Series[list[str]]")) + _check(assert_type(s.str.rsplit("a"), "pd.Series[list[str]]")) + _check(assert_type(s.str.rsplit("a", expand=False), "pd.Series[list[str]]")) + + +def test_string_accessors_list_index(): + idx = pd.Index(DATA) + _check = functools.partial(check, klass=pd.Index, dtype=list) + _check(assert_type(idx.str.findall("pp"), "pd.Index[list[str]]")) + _check(assert_type(idx.str.split("a"), "pd.Index[list[str]]")) + # GH 194 + _check(assert_type(idx.str.split("a", expand=False), "pd.Index[list[str]]")) + _check(assert_type(idx.str.rsplit("a"), "pd.Index[list[str]]")) + _check(assert_type(idx.str.rsplit("a", expand=False), "pd.Index[list[str]]")) + + +def test_string_accessors_expanding_series(): + s = pd.Series(["a1", "b2", "c3"]) + _check = functools.partial(check, klass=pd.DataFrame) + _check(assert_type(s.str.extract(r"([ab])?(\d)"), pd.DataFrame)) + _check(assert_type(s.str.extractall(r"([ab])?(\d)"), pd.DataFrame)) + _check(assert_type(s.str.get_dummies(), pd.DataFrame)) + _check(assert_type(s.str.partition("p"), pd.DataFrame)) + _check(assert_type(s.str.rpartition("p"), pd.DataFrame)) + _check(assert_type(s.str.rsplit("a", expand=True), pd.DataFrame)) + _check(assert_type(s.str.split("a", expand=True), pd.DataFrame)) + + +def test_string_accessors_expanding_index(): + idx = pd.Index(["a1", "b2", "c3"]) + _check = functools.partial(check, klass=pd.MultiIndex) + _check(assert_type(idx.str.get_dummies(), pd.MultiIndex)) + _check(assert_type(idx.str.partition("p"), pd.MultiIndex)) + _check(assert_type(idx.str.rpartition("p"), pd.MultiIndex)) + _check(assert_type(idx.str.rsplit("a", expand=True), pd.MultiIndex)) + _check(assert_type(idx.str.split("a", expand=True), pd.MultiIndex)) + + # These ones are the odd ones out? + check(assert_type(idx.str.extractall(r"([ab])?(\d)"), pd.DataFrame), pd.DataFrame) + check(assert_type(idx.str.extract(r"([ab])?(\d)"), pd.DataFrame), pd.DataFrame) + + +def test_series_overloads_partition(): + s = pd.Series( + [ + "ap;pl;ep", + "ban;an;ap", + "Che;rr;yp", + "DA;TEp", + "eGGp;LANT;p", + "12;3p", + "23.45p", + ] + ) + check(assert_type(s.str.partition(sep=";"), pd.DataFrame), pd.DataFrame) + check( + assert_type(s.str.partition(sep=";", expand=True), pd.DataFrame), pd.DataFrame + ) + check( + assert_type(s.str.partition(sep=";", expand=False), "pd.Series[type[object]]"), + pd.Series, + object, + ) + + check(assert_type(s.str.rpartition(sep=";"), pd.DataFrame), pd.DataFrame) + check( + assert_type(s.str.rpartition(sep=";", expand=True), pd.DataFrame), pd.DataFrame + ) + check( + assert_type(s.str.rpartition(sep=";", expand=False), "pd.Series[type[object]]"), + pd.Series, + object, + ) + + +def test_index_overloads_partition(): + idx = pd.Index( + [ + "ap;pl;ep", + "ban;an;ap", + "Che;rr;yp", + "DA;TEp", + "eGGp;LANT;p", + "12;3p", + "23.45p", + ] + ) + check(assert_type(idx.str.partition(sep=";"), pd.MultiIndex), pd.MultiIndex) + check( + assert_type(idx.str.partition(sep=";", expand=True), pd.MultiIndex), + pd.MultiIndex, + ) + check( + assert_type(idx.str.partition(sep=";", expand=False), "pd.Index[type[object]]"), + pd.Index, + object, + ) + + check(assert_type(idx.str.rpartition(sep=";"), pd.MultiIndex), pd.MultiIndex) + check( + assert_type(idx.str.rpartition(sep=";", expand=True), pd.MultiIndex), + pd.MultiIndex, + ) + check( + assert_type( + idx.str.rpartition(sep=";", expand=False), "pd.Index[type[object]]" + ), + pd.Index, + object, + ) + + +def test_series_overloads_cat(): + s = pd.Series(DATA) + check(assert_type(s.str.cat(sep=";"), str), str) + check(assert_type(s.str.cat(None, sep=";"), str), str) + check( + assert_type( + s.str.cat(["A", "B", "C", "D", "E", "F", "G"], sep=";"), + "pd.Series[str]", + ), + pd.Series, + str, + ) + check( + assert_type( + s.str.cat(pd.Series(["A", "B", "C", "D", "E", "F", "G"]), sep=";"), + "pd.Series[str]", + ), + pd.Series, + str, + ) + unknown_s = pd.DataFrame({"a": list("abcdefg")})["a"] + check(assert_type(s.str.cat(unknown_s, sep=";"), "pd.Series[str]"), pd.Series, str) + check(assert_type(unknown_s.str.cat(s, sep=";"), "pd.Series[str]"), pd.Series, str) + check( + assert_type(unknown_s.str.cat(unknown_s, sep=";"), "pd.Series[str]"), + pd.Series, + str, + ) + + +def test_index_overloads_cat(): + idx = pd.Index(DATA) + check(assert_type(idx.str.cat(sep=";"), str), str) + check(assert_type(idx.str.cat(None, sep=";"), str), str) + check( + assert_type( + idx.str.cat(["A", "B", "C", "D", "E", "F", "G"], sep=";"), + "pd.Index[str]", + ), + pd.Index, + str, + ) + check( + assert_type( + idx.str.cat(pd.Index(["A", "B", "C", "D", "E", "F", "G"]), sep=";"), + "pd.Index[str]", + ), + pd.Index, + str, + ) + unknown_idx = pd.DataFrame({"a": list("abcdefg")}).set_index("a").index + check( + assert_type(idx.str.cat(unknown_idx, sep=";"), "pd.Index[str]"), pd.Index, str + ) + check( + assert_type(unknown_idx.str.cat(idx, sep=";"), "pd.Index[str]"), pd.Index, str + ) + check( + assert_type(unknown_idx.str.cat(unknown_idx, sep=";"), "pd.Index[str]"), + pd.Index, + str, + ) + + +def test_series_overloads_extract(): + s = pd.Series(DATA) + check(assert_type(s.str.extract(r"[ab](\d)"), pd.DataFrame), pd.DataFrame) + check( + assert_type(s.str.extract(r"[ab](\d)", expand=True), pd.DataFrame), pd.DataFrame + ) + check( + assert_type( + s.str.extract(r"[ab](\d)", expand=False), "pd.Series[type[object]]" + ), + pd.Series, + object, + ) + check( + assert_type( + s.str.extract(r"[ab](\d)", re.IGNORECASE, False), "pd.Series[type[object]]" + ), + pd.Series, + object, + ) + + +def test_index_overloads_extract(): + idx = pd.Index(DATA) + check(assert_type(idx.str.extract(r"[ab](\d)"), pd.DataFrame), pd.DataFrame) + check( + assert_type(idx.str.extract(r"[ab](\d)", expand=True), pd.DataFrame), + pd.DataFrame, + ) + check( + assert_type( + idx.str.extract(r"[ab](\d)", expand=False), "pd.Index[type[object]]" + ), + pd.Index, + object, + ) + check( + assert_type( + idx.str.extract(r"[ab](\d)", re.IGNORECASE, False), "pd.Index[type[object]]" + ), + pd.Index, + object, + )