From 14d9316e4e028a0276fa0dc7ca2a845592397988 Mon Sep 17 00:00:00 2001 From: AlexWaygood Date: Thu, 26 Jan 2023 17:43:10 +0000 Subject: [PATCH 1/5] Add more overloads to the `re` stubs to help out pyright --- stdlib/re.pyi | 18 ++++++++++++++++++ test_cases/stdlib/check_re.py | 23 +++++++++++++++++++++++ 2 files changed, 41 insertions(+) create mode 100644 test_cases/stdlib/check_re.py diff --git a/stdlib/re.pyi b/stdlib/re.pyi index 4962ab8edad9..22b96ac8cd29 100644 --- a/stdlib/re.pyi +++ b/stdlib/re.pyi @@ -68,6 +68,8 @@ class Match(Generic[AnyStr]): def expand(self: Match[str], template: str) -> str: ... @overload def expand(self: Match[bytes], template: ReadableBuffer) -> bytes: ... + @overload + def expand(self, template: AnyStr) -> AnyStr: ... # group() returns "AnyStr" or "AnyStr | None", depending on the pattern. @overload def group(self, __group: Literal[0] = ...) -> AnyStr: ... @@ -117,27 +119,39 @@ class Pattern(Generic[AnyStr]): @overload def search(self: Pattern[bytes], string: ReadableBuffer, pos: int = ..., endpos: int = ...) -> Match[bytes] | None: ... @overload + def search(self, string: AnyStr, pos: int = ..., endpos: int = ...) -> Match[AnyStr] | None: ... + @overload def match(self: Pattern[str], string: str, pos: int = ..., endpos: int = ...) -> Match[str] | None: ... @overload def match(self: Pattern[bytes], string: ReadableBuffer, pos: int = ..., endpos: int = ...) -> Match[bytes] | None: ... @overload + def match(self, string: AnyStr, pos: int = ..., endpos: int = ...) -> Match[AnyStr] | None: ... + @overload def fullmatch(self: Pattern[str], string: str, pos: int = ..., endpos: int = ...) -> Match[str] | None: ... @overload def fullmatch(self: Pattern[bytes], string: ReadableBuffer, pos: int = ..., endpos: int = ...) -> Match[bytes] | None: ... @overload + def fullmatch(self, string: AnyStr, pos: int = ..., endpos: int = ...) -> Match[AnyStr] | None: ... + @overload def split(self: Pattern[str], string: str, maxsplit: int = ...) -> list[str | Any]: ... @overload def split(self: Pattern[bytes], string: ReadableBuffer, maxsplit: int = ...) -> list[bytes | Any]: ... + @overload + def split(self, string: AnyStr, maxsplit: int = ...) -> list[AnyStr | Any]: ... # return type depends on the number of groups in the pattern @overload def findall(self: Pattern[str], string: str, pos: int = ..., endpos: int = ...) -> list[Any]: ... @overload def findall(self: Pattern[bytes], string: ReadableBuffer, pos: int = ..., endpos: int = ...) -> list[Any]: ... @overload + def findall(self, string: AnyStr, pos: int = ..., endpos: int = ...) -> list[AnyStr]: ... + @overload def finditer(self: Pattern[str], string: str, pos: int = ..., endpos: int = ...) -> Iterator[Match[str]]: ... @overload def finditer(self: Pattern[bytes], string: ReadableBuffer, pos: int = ..., endpos: int = ...) -> Iterator[Match[bytes]]: ... @overload + def finditer(self, string: AnyStr, pos: int = ..., endpos: int = ...) -> Iterator[Match[AnyStr]]: ... + @overload def sub(self: Pattern[str], repl: str | Callable[[Match[str]], str], string: str, count: int = ...) -> str: ... @overload def sub( @@ -147,6 +161,8 @@ class Pattern(Generic[AnyStr]): count: int = ..., ) -> bytes: ... @overload + def sub(self, repl: AnyStr | Callable[[Match[AnyStr]], AnyStr], string: AnyStr, count: int = ...) -> AnyStr: ... + @overload def subn(self: Pattern[str], repl: str | Callable[[Match[str]], str], string: str, count: int = ...) -> tuple[str, int]: ... @overload def subn( @@ -155,6 +171,8 @@ class Pattern(Generic[AnyStr]): string: ReadableBuffer, count: int = ..., ) -> tuple[bytes, int]: ... + @overload + def subn(self, repl: AnyStr | Callable[[Match[AnyStr]], AnyStr], string: AnyStr, count: int = ...) -> tuple[AnyStr, int]: ... def __copy__(self) -> Pattern[AnyStr]: ... def __deepcopy__(self, __memo: Any) -> Pattern[AnyStr]: ... if sys.version_info >= (3, 9): diff --git a/test_cases/stdlib/check_re.py b/test_cases/stdlib/check_re.py new file mode 100644 index 000000000000..2826ee30c992 --- /dev/null +++ b/test_cases/stdlib/check_re.py @@ -0,0 +1,23 @@ +from __future__ import annotations + +import re +import typing as t +from typing_extensions import assert_type + + +def check_re_search(pattern: re.Pattern[t.AnyStr], string: t.AnyStr) -> re.Match[t.AnyStr]: + """See issue #9591""" + match = pattern.search(string) + if match is None: + raise ValueError(f"'{string!r}' does not match {pattern!r}") + return match + + +def check_no_ReadableBuffer_false_negatives() -> None: + b = bytearray(b"foo") # ReadableBuffer + + string_pattern = re.compile("foo") + string_pattern.search(b) # type: ignore + + bytes_pattern = re.compile(b"foo") + assert_type(bytes_pattern.search(b), t.Optional[t.Match[bytes]]) From 897752e09966ad801aeefbb2f1b9eaf4e3638dda Mon Sep 17 00:00:00 2001 From: AlexWaygood Date: Thu, 26 Jan 2023 17:47:36 +0000 Subject: [PATCH 2/5] Move test cases around --- test_cases/stdlib/check_re.py | 11 +++++++---- test_cases/stdlib/typing/check_pattern.py | 10 ---------- 2 files changed, 7 insertions(+), 14 deletions(-) delete mode 100644 test_cases/stdlib/typing/check_pattern.py diff --git a/test_cases/stdlib/check_re.py b/test_cases/stdlib/check_re.py index 2826ee30c992..28f7cbc00d3f 100644 --- a/test_cases/stdlib/check_re.py +++ b/test_cases/stdlib/check_re.py @@ -5,7 +5,13 @@ from typing_extensions import assert_type -def check_re_search(pattern: re.Pattern[t.AnyStr], string: t.AnyStr) -> re.Match[t.AnyStr]: +def check_search(str_pat: re.Pattern[str], bytes_pat: re.Pattern[bytes]) -> None: + assert_type(str_pat.search("x"), t.Optional[t.Match[str]]) + assert_type(bytes_pat.search(b"x"), t.Optional[t.Match[bytes]]) + assert_type(bytes_pat.search(bytearray(b"x")), t.Optional[t.Match[bytes]]) + + +def check_search_with_AnyStr(pattern: re.Pattern[t.AnyStr], string: t.AnyStr) -> re.Match[t.AnyStr]: """See issue #9591""" match = pattern.search(string) if match is None: @@ -18,6 +24,3 @@ def check_no_ReadableBuffer_false_negatives() -> None: string_pattern = re.compile("foo") string_pattern.search(b) # type: ignore - - bytes_pattern = re.compile(b"foo") - assert_type(bytes_pattern.search(b), t.Optional[t.Match[bytes]]) diff --git a/test_cases/stdlib/typing/check_pattern.py b/test_cases/stdlib/typing/check_pattern.py deleted file mode 100644 index ec5c1c4f6141..000000000000 --- a/test_cases/stdlib/typing/check_pattern.py +++ /dev/null @@ -1,10 +0,0 @@ -from __future__ import annotations - -from typing import Match, Optional, Pattern -from typing_extensions import assert_type - - -def test_search(str_pat: Pattern[str], bytes_pat: Pattern[bytes]) -> None: - assert_type(str_pat.search("x"), Optional[Match[str]]) - assert_type(bytes_pat.search(b"x"), Optional[Match[bytes]]) - assert_type(bytes_pat.search(bytearray(b"x")), Optional[Match[bytes]]) From 43c361a4530d16a14d944c94c7037881e3be36ee Mon Sep 17 00:00:00 2001 From: AlexWaygood Date: Thu, 26 Jan 2023 17:49:24 +0000 Subject: [PATCH 3/5] Many type ignores --- stdlib/re.pyi | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/stdlib/re.pyi b/stdlib/re.pyi index 22b96ac8cd29..12771441b779 100644 --- a/stdlib/re.pyi +++ b/stdlib/re.pyi @@ -67,7 +67,7 @@ class Match(Generic[AnyStr]): @overload def expand(self: Match[str], template: str) -> str: ... @overload - def expand(self: Match[bytes], template: ReadableBuffer) -> bytes: ... + def expand(self: Match[bytes], template: ReadableBuffer) -> bytes: ... # type: ignore[misc] @overload def expand(self, template: AnyStr) -> AnyStr: ... # group() returns "AnyStr" or "AnyStr | None", depending on the pattern. @@ -117,19 +117,19 @@ class Pattern(Generic[AnyStr]): @overload def search(self: Pattern[str], string: str, pos: int = ..., endpos: int = ...) -> Match[str] | None: ... @overload - def search(self: Pattern[bytes], string: ReadableBuffer, pos: int = ..., endpos: int = ...) -> Match[bytes] | None: ... + def search(self: Pattern[bytes], string: ReadableBuffer, pos: int = ..., endpos: int = ...) -> Match[bytes] | None: ... # type: ignore[misc] @overload def search(self, string: AnyStr, pos: int = ..., endpos: int = ...) -> Match[AnyStr] | None: ... @overload def match(self: Pattern[str], string: str, pos: int = ..., endpos: int = ...) -> Match[str] | None: ... @overload - def match(self: Pattern[bytes], string: ReadableBuffer, pos: int = ..., endpos: int = ...) -> Match[bytes] | None: ... + def match(self: Pattern[bytes], string: ReadableBuffer, pos: int = ..., endpos: int = ...) -> Match[bytes] | None: ... # type: ignore[misc] @overload def match(self, string: AnyStr, pos: int = ..., endpos: int = ...) -> Match[AnyStr] | None: ... @overload def fullmatch(self: Pattern[str], string: str, pos: int = ..., endpos: int = ...) -> Match[str] | None: ... @overload - def fullmatch(self: Pattern[bytes], string: ReadableBuffer, pos: int = ..., endpos: int = ...) -> Match[bytes] | None: ... + def fullmatch(self: Pattern[bytes], string: ReadableBuffer, pos: int = ..., endpos: int = ...) -> Match[bytes] | None: ... # type: ignore[misc] @overload def fullmatch(self, string: AnyStr, pos: int = ..., endpos: int = ...) -> Match[AnyStr] | None: ... @overload @@ -148,13 +148,13 @@ class Pattern(Generic[AnyStr]): @overload def finditer(self: Pattern[str], string: str, pos: int = ..., endpos: int = ...) -> Iterator[Match[str]]: ... @overload - def finditer(self: Pattern[bytes], string: ReadableBuffer, pos: int = ..., endpos: int = ...) -> Iterator[Match[bytes]]: ... + def finditer(self: Pattern[bytes], string: ReadableBuffer, pos: int = ..., endpos: int = ...) -> Iterator[Match[bytes]]: ... # type: ignore[misc] @overload def finditer(self, string: AnyStr, pos: int = ..., endpos: int = ...) -> Iterator[Match[AnyStr]]: ... @overload def sub(self: Pattern[str], repl: str | Callable[[Match[str]], str], string: str, count: int = ...) -> str: ... @overload - def sub( + def sub( # type: ignore[misc] self: Pattern[bytes], repl: ReadableBuffer | Callable[[Match[bytes]], ReadableBuffer], string: ReadableBuffer, @@ -165,7 +165,7 @@ class Pattern(Generic[AnyStr]): @overload def subn(self: Pattern[str], repl: str | Callable[[Match[str]], str], string: str, count: int = ...) -> tuple[str, int]: ... @overload - def subn( + def subn( # type: ignore[misc] self: Pattern[bytes], repl: ReadableBuffer | Callable[[Match[bytes]], ReadableBuffer], string: ReadableBuffer, From 94a72a9d547d95119f27cd18d0657a8a0a1d699e Mon Sep 17 00:00:00 2001 From: AlexWaygood Date: Thu, 26 Jan 2023 17:52:06 +0000 Subject: [PATCH 4/5] Simplify --- test_cases/stdlib/check_re.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/test_cases/stdlib/check_re.py b/test_cases/stdlib/check_re.py index 28f7cbc00d3f..398c505d5962 100644 --- a/test_cases/stdlib/check_re.py +++ b/test_cases/stdlib/check_re.py @@ -20,7 +20,4 @@ def check_search_with_AnyStr(pattern: re.Pattern[t.AnyStr], string: t.AnyStr) -> def check_no_ReadableBuffer_false_negatives() -> None: - b = bytearray(b"foo") # ReadableBuffer - - string_pattern = re.compile("foo") - string_pattern.search(b) # type: ignore + re.compile("foo").search(bytearray(b"foo")) # type: ignore From cac3d488bfaeb7755804752b1154ab7c7de7c97d Mon Sep 17 00:00:00 2001 From: AlexWaygood Date: Thu, 26 Jan 2023 18:37:04 +0000 Subject: [PATCH 5/5] PEP 688 --- test_cases/stdlib/check_re.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/test_cases/stdlib/check_re.py b/test_cases/stdlib/check_re.py index 398c505d5962..b6ab2b0d59d2 100644 --- a/test_cases/stdlib/check_re.py +++ b/test_cases/stdlib/check_re.py @@ -1,5 +1,6 @@ from __future__ import annotations +import mmap import re import typing as t from typing_extensions import assert_type @@ -9,6 +10,7 @@ def check_search(str_pat: re.Pattern[str], bytes_pat: re.Pattern[bytes]) -> None assert_type(str_pat.search("x"), t.Optional[t.Match[str]]) assert_type(bytes_pat.search(b"x"), t.Optional[t.Match[bytes]]) assert_type(bytes_pat.search(bytearray(b"x")), t.Optional[t.Match[bytes]]) + assert_type(bytes_pat.search(mmap.mmap(0, 10)), t.Optional[t.Match[bytes]]) def check_search_with_AnyStr(pattern: re.Pattern[t.AnyStr], string: t.AnyStr) -> re.Match[t.AnyStr]: @@ -21,3 +23,4 @@ def check_search_with_AnyStr(pattern: re.Pattern[t.AnyStr], string: t.AnyStr) -> def check_no_ReadableBuffer_false_negatives() -> None: re.compile("foo").search(bytearray(b"foo")) # type: ignore + re.compile("foo").search(mmap.mmap(0, 10)) # type: ignore