Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Buffer support for re #7679

Merged
merged 2 commits into from
Apr 24, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 34 additions & 36 deletions stdlib/re.pyi
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import enum
import sre_compile
import sys
from _typeshed import ReadableBuffer
from collections.abc import Callable, Iterator
from sre_constants import error as error
from typing import Any, AnyStr, overload
Expand Down Expand Up @@ -155,70 +156,67 @@ if sys.version_info < (3, 7):
# undocumented
_pattern_type: type

# Type-wise these overloads are unnecessary, they could also be modeled using
# Type-wise the compile() overloads are unnecessary, they could also be modeled using
# unions in the parameter types. However mypy has a bug regarding TypeVar
# constraints (https://github.com/python/mypy/issues/11880),
# which limits us here because AnyStr is a constrained TypeVar.

# pattern arguments do *not* accept arbitrary buffers such as bytearray,
# because the pattern must be hashable.
@overload
def compile(pattern: AnyStr, flags: _FlagsType = ...) -> Pattern[AnyStr]: ...
@overload
def compile(pattern: Pattern[AnyStr], flags: _FlagsType = ...) -> Pattern[AnyStr]: ...
@overload
def search(pattern: AnyStr, string: AnyStr, flags: _FlagsType = ...) -> Match[AnyStr] | None: ...
def search(pattern: str | Pattern[str], string: str, flags: _FlagsType = ...) -> Match[str] | None: ...
@overload
def search(pattern: Pattern[AnyStr], string: AnyStr, flags: _FlagsType = ...) -> Match[AnyStr] | None: ...
def search(pattern: bytes | Pattern[bytes], string: ReadableBuffer, flags: _FlagsType = ...) -> Match[bytes] | None: ...
@overload
def match(pattern: AnyStr, string: AnyStr, flags: _FlagsType = ...) -> Match[AnyStr] | None: ...
def match(pattern: str | Pattern[str], string: str, flags: _FlagsType = ...) -> Match[str] | None: ...
@overload
def match(pattern: Pattern[AnyStr], string: AnyStr, flags: _FlagsType = ...) -> Match[AnyStr] | None: ...
def match(pattern: bytes | Pattern[bytes], string: ReadableBuffer, flags: _FlagsType = ...) -> Match[bytes] | None: ...
@overload
def fullmatch(pattern: AnyStr, string: AnyStr, flags: _FlagsType = ...) -> Match[AnyStr] | None: ...
def fullmatch(pattern: str | Pattern[str], string: str, flags: _FlagsType = ...) -> Match[str] | None: ...
@overload
def fullmatch(pattern: Pattern[AnyStr], string: AnyStr, flags: _FlagsType = ...) -> Match[AnyStr] | None: ...
def fullmatch(pattern: bytes | Pattern[bytes], string: ReadableBuffer, flags: _FlagsType = ...) -> Match[bytes] | None: ...
@overload
def split(pattern: AnyStr, string: AnyStr, maxsplit: int = ..., flags: _FlagsType = ...) -> list[AnyStr | Any]: ...
def split(pattern: str | Pattern[str], string: str, maxsplit: int = ..., flags: _FlagsType = ...) -> list[str | Any]: ...
@overload
def split(pattern: Pattern[AnyStr], string: AnyStr, maxsplit: int = ..., flags: _FlagsType = ...) -> list[AnyStr | Any]: ...
def split(
pattern: bytes | Pattern[bytes], string: ReadableBuffer, maxsplit: int = ..., flags: _FlagsType = ...
) -> list[bytes | Any]: ...
@overload
def findall(pattern: AnyStr, string: AnyStr, flags: _FlagsType = ...) -> list[Any]: ...
def findall(pattern: str | Pattern[str], string: str, flags: _FlagsType = ...) -> list[Any]: ...
@overload
def findall(pattern: Pattern[AnyStr], string: AnyStr, flags: _FlagsType = ...) -> list[Any]: ...

# Return an iterator yielding match objects over all non-overlapping matches
# for the RE pattern in string. The string is scanned left-to-right, and
# matches are returned in the order found. Empty matches are included in the
# result unless they touch the beginning of another match.
@overload
def finditer(pattern: AnyStr, string: AnyStr, flags: _FlagsType = ...) -> Iterator[Match[AnyStr]]: ...
def findall(pattern: bytes | Pattern[bytes], string: ReadableBuffer, flags: _FlagsType = ...) -> list[Any]: ...
@overload
def finditer(pattern: Pattern[AnyStr], string: AnyStr, flags: _FlagsType = ...) -> Iterator[Match[AnyStr]]: ...
def finditer(pattern: str | Pattern[str], string: str, flags: _FlagsType = ...) -> Iterator[Match[str]]: ...
@overload
def sub(pattern: AnyStr, repl: AnyStr, string: AnyStr, count: int = ..., flags: _FlagsType = ...) -> AnyStr: ...
def finditer(pattern: bytes | Pattern[bytes], string: ReadableBuffer, flags: _FlagsType = ...) -> Iterator[Match[bytes]]: ...
@overload
def sub(
pattern: AnyStr, repl: Callable[[Match[AnyStr]], AnyStr], string: AnyStr, count: int = ..., flags: _FlagsType = ...
) -> AnyStr: ...
@overload
def sub(pattern: Pattern[AnyStr], repl: AnyStr, string: AnyStr, count: int = ..., flags: _FlagsType = ...) -> AnyStr: ...
pattern: str | Pattern[str], repl: str | Callable[[Match[str]], str], string: str, count: int = ..., flags: _FlagsType = ...
) -> str: ...
@overload
def sub(
pattern: Pattern[AnyStr], repl: Callable[[Match[AnyStr]], AnyStr], string: AnyStr, count: int = ..., flags: _FlagsType = ...
) -> AnyStr: ...
@overload
def subn(pattern: AnyStr, repl: AnyStr, string: AnyStr, count: int = ..., flags: _FlagsType = ...) -> tuple[AnyStr, int]: ...
@overload
def subn(
pattern: AnyStr, repl: Callable[[Match[AnyStr]], AnyStr], string: AnyStr, count: int = ..., flags: _FlagsType = ...
) -> tuple[AnyStr, int]: ...
pattern: bytes | Pattern[bytes],
repl: ReadableBuffer | Callable[[Match[bytes]], ReadableBuffer],
string: ReadableBuffer,
count: int = ...,
flags: _FlagsType = ...,
) -> bytes: ...
@overload
def subn(
pattern: Pattern[AnyStr], repl: AnyStr, string: AnyStr, count: int = ..., flags: _FlagsType = ...
) -> tuple[AnyStr, int]: ...
pattern: str | Pattern[str], repl: str | Callable[[Match[str]], str], string: str, count: int = ..., flags: _FlagsType = ...
) -> tuple[str, int]: ...
@overload
def subn(
pattern: Pattern[AnyStr], repl: Callable[[Match[AnyStr]], AnyStr], string: AnyStr, count: int = ..., flags: _FlagsType = ...
) -> tuple[AnyStr, int]: ...
pattern: bytes | Pattern[bytes],
repl: ReadableBuffer | Callable[[Match[bytes]], ReadableBuffer],
string: ReadableBuffer,
count: int = ...,
flags: _FlagsType = ...,
) -> tuple[bytes, int]: ...
def escape(pattern: AnyStr) -> AnyStr: ...
def purge() -> None: ...
def template(pattern: AnyStr | Pattern[AnyStr], flags: _FlagsType = ...) -> Pattern[AnyStr]: ...
56 changes: 44 additions & 12 deletions stdlib/typing.pyi
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import collections # Needed by aliases like DefaultDict, see mypy issue 2986
import sys
from _typeshed import Self as TypeshedSelf, SupportsKeysAndGetItem
from _typeshed import ReadableBuffer, Self as TypeshedSelf, SupportsKeysAndGetItem
from abc import ABCMeta, abstractmethod
from types import BuiltinFunctionType, CodeType, FrameType, FunctionType, MethodType, ModuleType, TracebackType
from typing_extensions import Literal as _Literal, ParamSpec as _ParamSpec, final as _final
Expand Down Expand Up @@ -1079,7 +1079,10 @@ class Match(Generic[AnyStr]):
# this match instance.
@property
def re(self) -> Pattern[AnyStr]: ...
def expand(self, template: AnyStr) -> AnyStr: ...
@overload
def expand(self: Match[str], template: str) -> str: ...
@overload
def expand(self: Match[bytes], template: ReadableBuffer) -> bytes: ...
# group() returns "AnyStr" or "AnyStr | None", depending on the pattern.
@overload
def group(self, __group: _Literal[0] = ...) -> AnyStr: ...
Expand Down Expand Up @@ -1124,20 +1127,49 @@ class Pattern(Generic[AnyStr]):
def groups(self) -> int: ...
@property
def pattern(self) -> AnyStr: ...
def search(self, string: AnyStr, pos: int = ..., endpos: int = ...) -> Match[AnyStr] | None: ...
def match(self, string: AnyStr, pos: int = ..., endpos: int = ...) -> Match[AnyStr] | None: ...
def fullmatch(self, string: AnyStr, pos: int = ..., endpos: int = ...) -> Match[AnyStr] | None: ...
def split(self, string: AnyStr, maxsplit: int = ...) -> list[AnyStr | Any]: ...
def findall(self, string: AnyStr, pos: int = ..., endpos: int = ...) -> list[Any]: ...
def finditer(self, string: AnyStr, pos: int = ..., endpos: int = ...) -> Iterator[Match[AnyStr]]: ...
@overload
def sub(self, repl: AnyStr, string: AnyStr, count: int = ...) -> AnyStr: ...
def search(self: Pattern[str], string: str, pos: int = ..., endpos: int = ...) -> Match[str] | None: ...
@overload
def search(self: Pattern[bytes], string: ReadableBuffer, pos: int = ..., endpos: int = ...) -> Match[bytes] | None: ...
@overload
def match(self: Pattern[str], string: str, pos: int = ..., endpos: int = ...) -> Match[str] | None: ...
@overload
def match(self: Pattern[bytes], string: ReadableBuffer, pos: int = ..., endpos: int = ...) -> Match[bytes] | None: ...
@overload
def fullmatch(self: Pattern[str], string: str, pos: int = ..., endpos: int = ...) -> Match[str] | None: ...
@overload
def fullmatch(self: Pattern[bytes], string: ReadableBuffer, pos: int = ..., endpos: int = ...) -> Match[bytes] | None: ...
@overload
def split(self: Pattern[str], string: str, maxsplit: int = ...) -> list[str | Any]: ...
@overload
def split(self: Pattern[bytes], string: ReadableBuffer, maxsplit: int = ...) -> list[bytes | Any]: ...
# return type depends on the number of groups in the pattern
@overload
def findall(self: Pattern[str], string: str, pos: int = ..., endpos: int = ...) -> list[Any]: ...
@overload
def findall(self: Pattern[bytes], string: ReadableBuffer, pos: int = ..., endpos: int = ...) -> list[Any]: ...
@overload
def finditer(self: Pattern[str], string: str, pos: int = ..., endpos: int = ...) -> Iterator[Match[str]]: ...
@overload
def finditer(self: Pattern[bytes], string: ReadableBuffer, pos: int = ..., endpos: int = ...) -> Iterator[Match[bytes]]: ...
@overload
def sub(self: Pattern[str], repl: str | Callable[[Match[str]], str], string: str, count: int = ...) -> str: ...
@overload
def sub(self, repl: Callable[[Match[AnyStr]], AnyStr], string: AnyStr, count: int = ...) -> AnyStr: ...
def sub(
self: Pattern[bytes],
repl: ReadableBuffer | Callable[[Match[bytes]], ReadableBuffer],
string: ReadableBuffer,
count: int = ...,
) -> bytes: ...
@overload
def subn(self, repl: AnyStr, string: AnyStr, count: int = ...) -> tuple[AnyStr, int]: ...
def subn(self: Pattern[str], repl: str | Callable[[Match[str]], str], string: str, count: int = ...) -> tuple[str, int]: ...
@overload
def subn(self, repl: Callable[[Match[AnyStr]], AnyStr], string: AnyStr, count: int = ...) -> tuple[AnyStr, int]: ...
def subn(
self: Pattern[bytes],
repl: ReadableBuffer | Callable[[Match[bytes]], ReadableBuffer],
string: ReadableBuffer,
count: int = ...,
) -> tuple[bytes, int]: ...
def __copy__(self) -> Pattern[AnyStr]: ...
def __deepcopy__(self, __memo: Any) -> Pattern[AnyStr]: ...
if sys.version_info >= (3, 9):
Expand Down
8 changes: 8 additions & 0 deletions test_cases/stdlib/typing/pattern.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from typing import Match, Optional, Pattern
from typing_extensions import assert_type


def test_search(str_pat: Pattern[str], bytes_pat: Pattern[bytes]) -> None:
assert_type(str_pat.search("x"), Optional[Match[str]])
assert_type(bytes_pat.search(b"x"), Optional[Match[bytes]])
assert_type(bytes_pat.search(bytearray(b"x")), Optional[Match[bytes]])