Skip to content

Commit 02e0c98

Browse files
Buffer support for re (#7679)
1 parent 5dad506 commit 02e0c98

File tree

3 files changed

+86
-48
lines changed

3 files changed

+86
-48
lines changed

stdlib/re.pyi

+34-36
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import enum
22
import sre_compile
33
import sys
4+
from _typeshed import ReadableBuffer
45
from collections.abc import Callable, Iterator
56
from sre_constants import error as error
67
from typing import Any, AnyStr, overload
@@ -155,70 +156,67 @@ if sys.version_info < (3, 7):
155156
# undocumented
156157
_pattern_type: type
157158

158-
# Type-wise these overloads are unnecessary, they could also be modeled using
159+
# Type-wise the compile() overloads are unnecessary; they could also be modeled using
159160
# unions in the parameter types. However, mypy has a bug regarding TypeVar
160161
# constraints (https://github.com/python/mypy/issues/11880),
161162
# which limits us here because AnyStr is a constrained TypeVar.
162163

164+
# pattern arguments do *not* accept arbitrary buffers such as bytearray,
165+
# because the pattern must be hashable (it is used as a key in re's compiled-pattern cache).
163166
@overload
164167
def compile(pattern: AnyStr, flags: _FlagsType = ...) -> Pattern[AnyStr]: ...
165168
@overload
166169
def compile(pattern: Pattern[AnyStr], flags: _FlagsType = ...) -> Pattern[AnyStr]: ...
167170
@overload
168-
def search(pattern: AnyStr, string: AnyStr, flags: _FlagsType = ...) -> Match[AnyStr] | None: ...
171+
def search(pattern: str | Pattern[str], string: str, flags: _FlagsType = ...) -> Match[str] | None: ...
169172
@overload
170-
def search(pattern: Pattern[AnyStr], string: AnyStr, flags: _FlagsType = ...) -> Match[AnyStr] | None: ...
173+
def search(pattern: bytes | Pattern[bytes], string: ReadableBuffer, flags: _FlagsType = ...) -> Match[bytes] | None: ...
171174
@overload
172-
def match(pattern: AnyStr, string: AnyStr, flags: _FlagsType = ...) -> Match[AnyStr] | None: ...
175+
def match(pattern: str | Pattern[str], string: str, flags: _FlagsType = ...) -> Match[str] | None: ...
173176
@overload
174-
def match(pattern: Pattern[AnyStr], string: AnyStr, flags: _FlagsType = ...) -> Match[AnyStr] | None: ...
177+
def match(pattern: bytes | Pattern[bytes], string: ReadableBuffer, flags: _FlagsType = ...) -> Match[bytes] | None: ...
175178
@overload
176-
def fullmatch(pattern: AnyStr, string: AnyStr, flags: _FlagsType = ...) -> Match[AnyStr] | None: ...
179+
def fullmatch(pattern: str | Pattern[str], string: str, flags: _FlagsType = ...) -> Match[str] | None: ...
177180
@overload
178-
def fullmatch(pattern: Pattern[AnyStr], string: AnyStr, flags: _FlagsType = ...) -> Match[AnyStr] | None: ...
181+
def fullmatch(pattern: bytes | Pattern[bytes], string: ReadableBuffer, flags: _FlagsType = ...) -> Match[bytes] | None: ...
179182
@overload
180-
def split(pattern: AnyStr, string: AnyStr, maxsplit: int = ..., flags: _FlagsType = ...) -> list[AnyStr | Any]: ...
183+
def split(pattern: str | Pattern[str], string: str, maxsplit: int = ..., flags: _FlagsType = ...) -> list[str | Any]: ...
181184
@overload
182-
def split(pattern: Pattern[AnyStr], string: AnyStr, maxsplit: int = ..., flags: _FlagsType = ...) -> list[AnyStr | Any]: ...
185+
def split(
186+
pattern: bytes | Pattern[bytes], string: ReadableBuffer, maxsplit: int = ..., flags: _FlagsType = ...
187+
) -> list[bytes | Any]: ...
183188
@overload
184-
def findall(pattern: AnyStr, string: AnyStr, flags: _FlagsType = ...) -> list[Any]: ...
189+
def findall(pattern: str | Pattern[str], string: str, flags: _FlagsType = ...) -> list[Any]: ...
185190
@overload
186-
def findall(pattern: Pattern[AnyStr], string: AnyStr, flags: _FlagsType = ...) -> list[Any]: ...
187-
188-
# Return an iterator yielding match objects over all non-overlapping matches
189-
# for the RE pattern in string. The string is scanned left-to-right, and
190-
# matches are returned in the order found. Empty matches are included in the
191-
# result unless they touch the beginning of another match.
192-
@overload
193-
def finditer(pattern: AnyStr, string: AnyStr, flags: _FlagsType = ...) -> Iterator[Match[AnyStr]]: ...
191+
def findall(pattern: bytes | Pattern[bytes], string: ReadableBuffer, flags: _FlagsType = ...) -> list[Any]: ...
194192
@overload
195-
def finditer(pattern: Pattern[AnyStr], string: AnyStr, flags: _FlagsType = ...) -> Iterator[Match[AnyStr]]: ...
193+
def finditer(pattern: str | Pattern[str], string: str, flags: _FlagsType = ...) -> Iterator[Match[str]]: ...
196194
@overload
197-
def sub(pattern: AnyStr, repl: AnyStr, string: AnyStr, count: int = ..., flags: _FlagsType = ...) -> AnyStr: ...
195+
def finditer(pattern: bytes | Pattern[bytes], string: ReadableBuffer, flags: _FlagsType = ...) -> Iterator[Match[bytes]]: ...
198196
@overload
199197
def sub(
200-
pattern: AnyStr, repl: Callable[[Match[AnyStr]], AnyStr], string: AnyStr, count: int = ..., flags: _FlagsType = ...
201-
) -> AnyStr: ...
202-
@overload
203-
def sub(pattern: Pattern[AnyStr], repl: AnyStr, string: AnyStr, count: int = ..., flags: _FlagsType = ...) -> AnyStr: ...
198+
pattern: str | Pattern[str], repl: str | Callable[[Match[str]], str], string: str, count: int = ..., flags: _FlagsType = ...
199+
) -> str: ...
204200
@overload
205201
def sub(
206-
pattern: Pattern[AnyStr], repl: Callable[[Match[AnyStr]], AnyStr], string: AnyStr, count: int = ..., flags: _FlagsType = ...
207-
) -> AnyStr: ...
208-
@overload
209-
def subn(pattern: AnyStr, repl: AnyStr, string: AnyStr, count: int = ..., flags: _FlagsType = ...) -> tuple[AnyStr, int]: ...
210-
@overload
211-
def subn(
212-
pattern: AnyStr, repl: Callable[[Match[AnyStr]], AnyStr], string: AnyStr, count: int = ..., flags: _FlagsType = ...
213-
) -> tuple[AnyStr, int]: ...
202+
pattern: bytes | Pattern[bytes],
203+
repl: ReadableBuffer | Callable[[Match[bytes]], ReadableBuffer],
204+
string: ReadableBuffer,
205+
count: int = ...,
206+
flags: _FlagsType = ...,
207+
) -> bytes: ...
214208
@overload
215209
def subn(
216-
pattern: Pattern[AnyStr], repl: AnyStr, string: AnyStr, count: int = ..., flags: _FlagsType = ...
217-
) -> tuple[AnyStr, int]: ...
210+
pattern: str | Pattern[str], repl: str | Callable[[Match[str]], str], string: str, count: int = ..., flags: _FlagsType = ...
211+
) -> tuple[str, int]: ...
218212
@overload
219213
def subn(
220-
pattern: Pattern[AnyStr], repl: Callable[[Match[AnyStr]], AnyStr], string: AnyStr, count: int = ..., flags: _FlagsType = ...
221-
) -> tuple[AnyStr, int]: ...
214+
pattern: bytes | Pattern[bytes],
215+
repl: ReadableBuffer | Callable[[Match[bytes]], ReadableBuffer],
216+
string: ReadableBuffer,
217+
count: int = ...,
218+
flags: _FlagsType = ...,
219+
) -> tuple[bytes, int]: ...
222220
def escape(pattern: AnyStr) -> AnyStr: ...
223221
def purge() -> None: ...
224222
def template(pattern: AnyStr | Pattern[AnyStr], flags: _FlagsType = ...) -> Pattern[AnyStr]: ...

stdlib/typing.pyi

+44-12
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import collections # Needed by aliases like DefaultDict, see mypy issue 2986
22
import sys
3-
from _typeshed import Self as TypeshedSelf, SupportsKeysAndGetItem
3+
from _typeshed import ReadableBuffer, Self as TypeshedSelf, SupportsKeysAndGetItem
44
from abc import ABCMeta, abstractmethod
55
from types import BuiltinFunctionType, CodeType, FrameType, FunctionType, MethodType, ModuleType, TracebackType
66
from typing_extensions import Literal as _Literal, ParamSpec as _ParamSpec, final as _final
@@ -1079,7 +1079,10 @@ class Match(Generic[AnyStr]):
10791079
# this match instance.
10801080
@property
10811081
def re(self) -> Pattern[AnyStr]: ...
1082-
def expand(self, template: AnyStr) -> AnyStr: ...
1082+
@overload
1083+
def expand(self: Match[str], template: str) -> str: ...
1084+
@overload
1085+
def expand(self: Match[bytes], template: ReadableBuffer) -> bytes: ...
10831086
# group() returns "AnyStr" or "AnyStr | None", depending on the pattern.
10841087
@overload
10851088
def group(self, __group: _Literal[0] = ...) -> AnyStr: ...
@@ -1124,20 +1127,49 @@ class Pattern(Generic[AnyStr]):
11241127
def groups(self) -> int: ...
11251128
@property
11261129
def pattern(self) -> AnyStr: ...
1127-
def search(self, string: AnyStr, pos: int = ..., endpos: int = ...) -> Match[AnyStr] | None: ...
1128-
def match(self, string: AnyStr, pos: int = ..., endpos: int = ...) -> Match[AnyStr] | None: ...
1129-
def fullmatch(self, string: AnyStr, pos: int = ..., endpos: int = ...) -> Match[AnyStr] | None: ...
1130-
def split(self, string: AnyStr, maxsplit: int = ...) -> list[AnyStr | Any]: ...
1131-
def findall(self, string: AnyStr, pos: int = ..., endpos: int = ...) -> list[Any]: ...
1132-
def finditer(self, string: AnyStr, pos: int = ..., endpos: int = ...) -> Iterator[Match[AnyStr]]: ...
11331130
@overload
1134-
def sub(self, repl: AnyStr, string: AnyStr, count: int = ...) -> AnyStr: ...
1131+
def search(self: Pattern[str], string: str, pos: int = ..., endpos: int = ...) -> Match[str] | None: ...
1132+
@overload
1133+
def search(self: Pattern[bytes], string: ReadableBuffer, pos: int = ..., endpos: int = ...) -> Match[bytes] | None: ...
1134+
@overload
1135+
def match(self: Pattern[str], string: str, pos: int = ..., endpos: int = ...) -> Match[str] | None: ...
1136+
@overload
1137+
def match(self: Pattern[bytes], string: ReadableBuffer, pos: int = ..., endpos: int = ...) -> Match[bytes] | None: ...
1138+
@overload
1139+
def fullmatch(self: Pattern[str], string: str, pos: int = ..., endpos: int = ...) -> Match[str] | None: ...
1140+
@overload
1141+
def fullmatch(self: Pattern[bytes], string: ReadableBuffer, pos: int = ..., endpos: int = ...) -> Match[bytes] | None: ...
1142+
@overload
1143+
def split(self: Pattern[str], string: str, maxsplit: int = ...) -> list[str | Any]: ...
1144+
@overload
1145+
def split(self: Pattern[bytes], string: ReadableBuffer, maxsplit: int = ...) -> list[bytes | Any]: ...
1146+
# return type depends on the number of groups in the pattern
1147+
@overload
1148+
def findall(self: Pattern[str], string: str, pos: int = ..., endpos: int = ...) -> list[Any]: ...
1149+
@overload
1150+
def findall(self: Pattern[bytes], string: ReadableBuffer, pos: int = ..., endpos: int = ...) -> list[Any]: ...
1151+
@overload
1152+
def finditer(self: Pattern[str], string: str, pos: int = ..., endpos: int = ...) -> Iterator[Match[str]]: ...
1153+
@overload
1154+
def finditer(self: Pattern[bytes], string: ReadableBuffer, pos: int = ..., endpos: int = ...) -> Iterator[Match[bytes]]: ...
1155+
@overload
1156+
def sub(self: Pattern[str], repl: str | Callable[[Match[str]], str], string: str, count: int = ...) -> str: ...
11351157
@overload
1136-
def sub(self, repl: Callable[[Match[AnyStr]], AnyStr], string: AnyStr, count: int = ...) -> AnyStr: ...
1158+
def sub(
1159+
self: Pattern[bytes],
1160+
repl: ReadableBuffer | Callable[[Match[bytes]], ReadableBuffer],
1161+
string: ReadableBuffer,
1162+
count: int = ...,
1163+
) -> bytes: ...
11371164
@overload
1138-
def subn(self, repl: AnyStr, string: AnyStr, count: int = ...) -> tuple[AnyStr, int]: ...
1165+
def subn(self: Pattern[str], repl: str | Callable[[Match[str]], str], string: str, count: int = ...) -> tuple[str, int]: ...
11391166
@overload
1140-
def subn(self, repl: Callable[[Match[AnyStr]], AnyStr], string: AnyStr, count: int = ...) -> tuple[AnyStr, int]: ...
1167+
def subn(
1168+
self: Pattern[bytes],
1169+
repl: ReadableBuffer | Callable[[Match[bytes]], ReadableBuffer],
1170+
string: ReadableBuffer,
1171+
count: int = ...,
1172+
) -> tuple[bytes, int]: ...
11411173
def __copy__(self) -> Pattern[AnyStr]: ...
11421174
def __deepcopy__(self, __memo: Any) -> Pattern[AnyStr]: ...
11431175
if sys.version_info >= (3, 9):

test_cases/stdlib/typing/pattern.py

+8
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
from typing import Match, Optional, Pattern
2+
from typing_extensions import assert_type
3+
4+
5+
def test_search(str_pat: Pattern[str], bytes_pat: Pattern[bytes]) -> None:
6+
assert_type(str_pat.search("x"), Optional[Match[str]])
7+
assert_type(bytes_pat.search(b"x"), Optional[Match[bytes]])
8+
assert_type(bytes_pat.search(bytearray(b"x")), Optional[Match[bytes]])

0 commit comments

Comments
 (0)