diff --git a/stdlib/re.pyi b/stdlib/re.pyi index ff2a55fb4e61..2f4f3a3a0ed4 100644 --- a/stdlib/re.pyi +++ b/stdlib/re.pyi @@ -1,6 +1,7 @@ import enum import sre_compile import sys +from _typeshed import ReadableBuffer from collections.abc import Callable, Iterator from sre_constants import error as error from typing import Any, AnyStr, overload @@ -155,70 +156,67 @@ if sys.version_info < (3, 7): # undocumented _pattern_type: type -# Type-wise these overloads are unnecessary, they could also be modeled using +# Type-wise the compile() overloads are unnecessary, they could also be modeled using # unions in the parameter types. However mypy has a bug regarding TypeVar # constraints (https://github.com/python/mypy/issues/11880), # which limits us here because AnyStr is a constrained TypeVar. +# pattern arguments do *not* accept arbitrary buffers such as bytearray, +# because the pattern must be hashable. @overload def compile(pattern: AnyStr, flags: _FlagsType = ...) -> Pattern[AnyStr]: ... @overload def compile(pattern: Pattern[AnyStr], flags: _FlagsType = ...) -> Pattern[AnyStr]: ... @overload -def search(pattern: AnyStr, string: AnyStr, flags: _FlagsType = ...) -> Match[AnyStr] | None: ... +def search(pattern: str | Pattern[str], string: str, flags: _FlagsType = ...) -> Match[str] | None: ... @overload -def search(pattern: Pattern[AnyStr], string: AnyStr, flags: _FlagsType = ...) -> Match[AnyStr] | None: ... +def search(pattern: bytes | Pattern[bytes], string: ReadableBuffer, flags: _FlagsType = ...) -> Match[bytes] | None: ... @overload -def match(pattern: AnyStr, string: AnyStr, flags: _FlagsType = ...) -> Match[AnyStr] | None: ... +def match(pattern: str | Pattern[str], string: str, flags: _FlagsType = ...) -> Match[str] | None: ... @overload -def match(pattern: Pattern[AnyStr], string: AnyStr, flags: _FlagsType = ...) -> Match[AnyStr] | None: ... +def match(pattern: bytes | Pattern[bytes], string: ReadableBuffer, flags: _FlagsType = ...) -> Match[bytes] | None: ... @overload -def fullmatch(pattern: AnyStr, string: AnyStr, flags: _FlagsType = ...) -> Match[AnyStr] | None: ... +def fullmatch(pattern: str | Pattern[str], string: str, flags: _FlagsType = ...) -> Match[str] | None: ... @overload -def fullmatch(pattern: Pattern[AnyStr], string: AnyStr, flags: _FlagsType = ...) -> Match[AnyStr] | None: ... +def fullmatch(pattern: bytes | Pattern[bytes], string: ReadableBuffer, flags: _FlagsType = ...) -> Match[bytes] | None: ... @overload -def split(pattern: AnyStr, string: AnyStr, maxsplit: int = ..., flags: _FlagsType = ...) -> list[AnyStr | Any]: ... +def split(pattern: str | Pattern[str], string: str, maxsplit: int = ..., flags: _FlagsType = ...) -> list[str | Any]: ... @overload -def split(pattern: Pattern[AnyStr], string: AnyStr, maxsplit: int = ..., flags: _FlagsType = ...) -> list[AnyStr | Any]: ... +def split( + pattern: bytes | Pattern[bytes], string: ReadableBuffer, maxsplit: int = ..., flags: _FlagsType = ... +) -> list[bytes | Any]: ... @overload -def findall(pattern: AnyStr, string: AnyStr, flags: _FlagsType = ...) -> list[Any]: ... +def findall(pattern: str | Pattern[str], string: str, flags: _FlagsType = ...) -> list[Any]: ... @overload -def findall(pattern: Pattern[AnyStr], string: AnyStr, flags: _FlagsType = ...) -> list[Any]: ... - -# Return an iterator yielding match objects over all non-overlapping matches -# for the RE pattern in string. The string is scanned left-to-right, and -# matches are returned in the order found. Empty matches are included in the -# result unless they touch the beginning of another match. -@overload -def finditer(pattern: AnyStr, string: AnyStr, flags: _FlagsType = ...) -> Iterator[Match[AnyStr]]: ... +def findall(pattern: bytes | Pattern[bytes], string: ReadableBuffer, flags: _FlagsType = ...) -> list[Any]: ... @overload -def finditer(pattern: Pattern[AnyStr], string: AnyStr, flags: _FlagsType = ...) -> Iterator[Match[AnyStr]]: ... +def finditer(pattern: str | Pattern[str], string: str, flags: _FlagsType = ...) -> Iterator[Match[str]]: ... @overload -def sub(pattern: AnyStr, repl: AnyStr, string: AnyStr, count: int = ..., flags: _FlagsType = ...) -> AnyStr: ... +def finditer(pattern: bytes | Pattern[bytes], string: ReadableBuffer, flags: _FlagsType = ...) -> Iterator[Match[bytes]]: ... @overload def sub( - pattern: AnyStr, repl: Callable[[Match[AnyStr]], AnyStr], string: AnyStr, count: int = ..., flags: _FlagsType = ... -) -> AnyStr: ... -@overload -def sub(pattern: Pattern[AnyStr], repl: AnyStr, string: AnyStr, count: int = ..., flags: _FlagsType = ...) -> AnyStr: ... + pattern: str | Pattern[str], repl: str | Callable[[Match[str]], str], string: str, count: int = ..., flags: _FlagsType = ... +) -> str: ... @overload def sub( - pattern: Pattern[AnyStr], repl: Callable[[Match[AnyStr]], AnyStr], string: AnyStr, count: int = ..., flags: _FlagsType = ... -) -> AnyStr: ... -@overload -def subn(pattern: AnyStr, repl: AnyStr, string: AnyStr, count: int = ..., flags: _FlagsType = ...) -> tuple[AnyStr, int]: ... -@overload -def subn( - pattern: AnyStr, repl: Callable[[Match[AnyStr]], AnyStr], string: AnyStr, count: int = ..., flags: _FlagsType = ... -) -> tuple[AnyStr, int]: ... + pattern: bytes | Pattern[bytes], + repl: ReadableBuffer | Callable[[Match[bytes]], ReadableBuffer], + string: ReadableBuffer, + count: int = ..., + flags: _FlagsType = ..., +) -> bytes: ... @overload def subn( - pattern: Pattern[AnyStr], repl: AnyStr, string: AnyStr, count: int = ..., flags: _FlagsType = ... -) -> tuple[AnyStr, int]: ... + pattern: str | Pattern[str], repl: str | Callable[[Match[str]], str], string: str, count: int = ..., flags: _FlagsType = ... +) -> tuple[str, int]: ... @overload def subn( - pattern: Pattern[AnyStr], repl: Callable[[Match[AnyStr]], AnyStr], string: AnyStr, count: int = ..., flags: _FlagsType = ... -) -> tuple[AnyStr, int]: ... + pattern: bytes | Pattern[bytes], + repl: ReadableBuffer | Callable[[Match[bytes]], ReadableBuffer], + string: ReadableBuffer, + count: int = ..., + flags: _FlagsType = ..., +) -> tuple[bytes, int]: ... def escape(pattern: AnyStr) -> AnyStr: ... def purge() -> None: ... def template(pattern: AnyStr | Pattern[AnyStr], flags: _FlagsType = ...) -> Pattern[AnyStr]: ... diff --git a/stdlib/typing.pyi b/stdlib/typing.pyi index 28b588d79c9b..aaf5536e87d2 100644 --- a/stdlib/typing.pyi +++ b/stdlib/typing.pyi @@ -1,6 +1,6 @@ import collections # Needed by aliases like DefaultDict, see mypy issue 2986 import sys -from _typeshed import Self as TypeshedSelf, SupportsKeysAndGetItem +from _typeshed import ReadableBuffer, Self as TypeshedSelf, SupportsKeysAndGetItem from abc import ABCMeta, abstractmethod from types import BuiltinFunctionType, CodeType, FrameType, FunctionType, MethodType, ModuleType, TracebackType from typing_extensions import Literal as _Literal, ParamSpec as _ParamSpec, final as _final @@ -1079,7 +1079,10 @@ class Match(Generic[AnyStr]): # this match instance. @property def re(self) -> Pattern[AnyStr]: ... - def expand(self, template: AnyStr) -> AnyStr: ... + @overload + def expand(self: Match[str], template: str) -> str: ... + @overload + def expand(self: Match[bytes], template: ReadableBuffer) -> bytes: ... # group() returns "AnyStr" or "AnyStr | None", depending on the pattern. @overload def group(self, __group: _Literal[0] = ...) -> AnyStr: ... @@ -1124,20 +1127,49 @@ class Pattern(Generic[AnyStr]): def groups(self) -> int: ... @property def pattern(self) -> AnyStr: ... - def search(self, string: AnyStr, pos: int = ..., endpos: int = ...) -> Match[AnyStr] | None: ... - def match(self, string: AnyStr, pos: int = ..., endpos: int = ...) -> Match[AnyStr] | None: ... - def fullmatch(self, string: AnyStr, pos: int = ..., endpos: int = ...) -> Match[AnyStr] | None: ... - def split(self, string: AnyStr, maxsplit: int = ...) -> list[AnyStr | Any]: ... - def findall(self, string: AnyStr, pos: int = ..., endpos: int = ...) -> list[Any]: ... - def finditer(self, string: AnyStr, pos: int = ..., endpos: int = ...) -> Iterator[Match[AnyStr]]: ... @overload - def sub(self, repl: AnyStr, string: AnyStr, count: int = ...) -> AnyStr: ... + def search(self: Pattern[str], string: str, pos: int = ..., endpos: int = ...) -> Match[str] | None: ... + @overload + def search(self: Pattern[bytes], string: ReadableBuffer, pos: int = ..., endpos: int = ...) -> Match[bytes] | None: ... + @overload + def match(self: Pattern[str], string: str, pos: int = ..., endpos: int = ...) -> Match[str] | None: ... + @overload + def match(self: Pattern[bytes], string: ReadableBuffer, pos: int = ..., endpos: int = ...) -> Match[bytes] | None: ... + @overload + def fullmatch(self: Pattern[str], string: str, pos: int = ..., endpos: int = ...) -> Match[str] | None: ... + @overload + def fullmatch(self: Pattern[bytes], string: ReadableBuffer, pos: int = ..., endpos: int = ...) -> Match[bytes] | None: ... + @overload + def split(self: Pattern[str], string: str, maxsplit: int = ...) -> list[str | Any]: ... + @overload + def split(self: Pattern[bytes], string: ReadableBuffer, maxsplit: int = ...) -> list[bytes | Any]: ... + # return type depends on the number of groups in the pattern + @overload + def findall(self: Pattern[str], string: str, pos: int = ..., endpos: int = ...) -> list[Any]: ... + @overload + def findall(self: Pattern[bytes], string: ReadableBuffer, pos: int = ..., endpos: int = ...) -> list[Any]: ... + @overload + def finditer(self: Pattern[str], string: str, pos: int = ..., endpos: int = ...) -> Iterator[Match[str]]: ... + @overload + def finditer(self: Pattern[bytes], string: ReadableBuffer, pos: int = ..., endpos: int = ...) -> Iterator[Match[bytes]]: ... + @overload + def sub(self: Pattern[str], repl: str | Callable[[Match[str]], str], string: str, count: int = ...) -> str: ... @overload - def sub(self, repl: Callable[[Match[AnyStr]], AnyStr], string: AnyStr, count: int = ...) -> AnyStr: ... + def sub( + self: Pattern[bytes], + repl: ReadableBuffer | Callable[[Match[bytes]], ReadableBuffer], + string: ReadableBuffer, + count: int = ..., + ) -> bytes: ... @overload - def subn(self, repl: AnyStr, string: AnyStr, count: int = ...) -> tuple[AnyStr, int]: ... + def subn(self: Pattern[str], repl: str | Callable[[Match[str]], str], string: str, count: int = ...) -> tuple[str, int]: ... @overload - def subn(self, repl: Callable[[Match[AnyStr]], AnyStr], string: AnyStr, count: int = ...) -> tuple[AnyStr, int]: ... + def subn( + self: Pattern[bytes], + repl: ReadableBuffer | Callable[[Match[bytes]], ReadableBuffer], + string: ReadableBuffer, + count: int = ..., + ) -> tuple[bytes, int]: ... def __copy__(self) -> Pattern[AnyStr]: ... def __deepcopy__(self, __memo: Any) -> Pattern[AnyStr]: ... if sys.version_info >= (3, 9): diff --git a/test_cases/stdlib/typing/pattern.py b/test_cases/stdlib/typing/pattern.py new file mode 100644 index 000000000000..69978c4aa710 --- /dev/null +++ b/test_cases/stdlib/typing/pattern.py @@ -0,0 +1,8 @@ +from typing import Match, Optional, Pattern +from typing_extensions import assert_type + + +def test_search(str_pat: Pattern[str], bytes_pat: Pattern[bytes]) -> None: + assert_type(str_pat.search("x"), Optional[Match[str]]) + assert_type(bytes_pat.search(b"x"), Optional[Match[bytes]]) + assert_type(bytes_pat.search(bytearray(b"x")), Optional[Match[bytes]])