diff --git a/mypy/stubtest.py b/mypy/stubtest.py index d2f9cfcca974..96f6aa5af96a 100644 --- a/mypy/stubtest.py +++ b/mypy/stubtest.py @@ -5,6 +5,7 @@ """ import argparse +import collections.abc import copy import enum import importlib @@ -23,7 +24,7 @@ from typing import Any, Dict, Generic, Iterator, List, Optional, Tuple, TypeVar, Union, cast import typing_extensions -from typing_extensions import Type +from typing_extensions import Type, get_origin import mypy.build import mypy.modulefinder @@ -1031,39 +1032,71 @@ def verify_typealias( stub: nodes.TypeAlias, runtime: MaybeMissing[Any], object_path: List[str] ) -> Iterator[Error]: stub_target = mypy.types.get_proper_type(stub.target) + stub_desc = f"Type alias for {stub_target}" if isinstance(runtime, Missing): - yield Error( - object_path, - "is not present at runtime", - stub, - runtime, - stub_desc=f"Type alias for: {stub_target}", - ) + yield Error(object_path, "is not present at runtime", stub, runtime, stub_desc=stub_desc) return + runtime_origin = get_origin(runtime) or runtime if isinstance(stub_target, mypy.types.Instance): - yield from verify(stub_target.type, runtime, object_path) + if not isinstance(runtime_origin, type): + yield Error( + object_path, + "is inconsistent, runtime is not a type", + stub, + runtime, + stub_desc=stub_desc, + ) + return + + stub_origin = stub_target.type + # Do our best to figure out the fullname of the runtime object... + runtime_name: object + try: + runtime_name = runtime_origin.__qualname__ + except AttributeError: + runtime_name = getattr(runtime_origin, "__name__", MISSING) + if isinstance(runtime_name, str): + runtime_module: object = getattr(runtime_origin, "__module__", MISSING) + if isinstance(runtime_module, str): + if runtime_module == "collections.abc" or ( + runtime_module == "re" and runtime_name in {"Match", "Pattern"} + ): + runtime_module = "typing" + runtime_fullname = f"{runtime_module}.{runtime_name}" + if re.fullmatch(rf"_?{re.escape(stub_origin.fullname)}", runtime_fullname): + # Okay, we're probably fine. + return + + # Okay, either we couldn't construct a fullname + # or the fullname of the stub didn't match the fullname of the runtime. + # Fallback to a full structural check of the runtime vis-a-vis the stub. + yield from verify(stub_origin, runtime_origin, object_path) return if isinstance(stub_target, mypy.types.UnionType): - if not getattr(runtime, "__origin__", None) is Union: + # complain if runtime is not a Union or UnionType + if runtime_origin is not Union and ( + not (sys.version_info >= (3, 10) and isinstance(runtime, types.UnionType)) + ): yield Error(object_path, "is not a Union", stub, runtime, stub_desc=str(stub_target)) # could check Union contents here... return if isinstance(stub_target, mypy.types.TupleType): - if tuple not in getattr(runtime, "__mro__", ()): + if tuple not in getattr(runtime_origin, "__mro__", ()): yield Error( - object_path, - "is not a subclass of tuple", - stub, - runtime, - stub_desc=str(stub_target), + object_path, "is not a subclass of tuple", stub, runtime, stub_desc=stub_desc ) # could check Tuple contents here... return + if isinstance(stub_target, mypy.types.CallableType): + if runtime_origin is not collections.abc.Callable: + yield Error( + object_path, "is not a type alias for Callable", stub, runtime, stub_desc=stub_desc + ) + # could check Callable contents here... + return if isinstance(stub_target, mypy.types.AnyType): return - yield Error( - object_path, "is not a recognised type alias", stub, runtime, stub_desc=str(stub_target) - ) + yield Error(object_path, "is not a recognised type alias", stub, runtime, stub_desc=stub_desc) # ==================== diff --git a/mypy/test/teststubtest.py b/mypy/test/teststubtest.py index ef06608a9c1b..61c46ea01b91 100644 --- a/mypy/test/teststubtest.py +++ b/mypy/test/teststubtest.py @@ -44,6 +44,7 @@ def __getitem__(self, typeargs: Any) -> object: ... Callable: _SpecialForm = ... Generic: _SpecialForm = ... Protocol: _SpecialForm = ... +Union: _SpecialForm = ... class TypeVar: def __init__(self, name, covariant: bool = ..., contravariant: bool = ...) -> None: ... @@ -61,6 +62,7 @@ def __init__(self, name: str) -> None: ... class Coroutine(Generic[_T_co, _S, _R]): ... class Iterable(Generic[_T_co]): ... class Mapping(Generic[_K, _V]): ... +class Match(Generic[_T]): ... class Sequence(Iterable[_T_co]): ... class Tuple(Sequence[_T_co]): ... def overload(func: _T) -> _T: ... @@ -703,6 +705,190 @@ class Y: ... yield Case(stub="B = str", runtime="", error="B") # ... but only if the alias isn't private yield Case(stub="_C = int", runtime="", error=None) + yield Case( + stub=""" + from typing import Tuple + D = tuple[str, str] + E = Tuple[int, int, int] + F = Tuple[str, int] + """, + runtime=""" + from typing import List, Tuple + D = Tuple[str, str] + E = Tuple[int, int, int] + F = List[str] + """, + error="F", + ) + yield Case( + stub=""" + from typing import Union + G = str | int + H = Union[str, bool] + I = str | int + """, + runtime=""" + from typing import Union + G = Union[str, int] + H = Union[str, bool] + I = str + """, + error="I", + ) + yield Case( + stub=""" + import typing + from collections.abc import Iterable + from typing import Dict + K = dict[str, str] + L = Dict[int, int] + KK = Iterable[str] + LL = typing.Iterable[str] + """, + runtime=""" + from typing import Iterable, Dict + K = Dict[str, str] + L = Dict[int, int] + KK = Iterable[str] + LL = Iterable[str] + """, + error=None, + ) + yield Case( + stub=""" + from typing import Generic, TypeVar + _T = TypeVar("_T") + class _Spam(Generic[_T]): + def foo(self) -> None: ... + IntFood = _Spam[int] + """, + runtime=""" + from typing import Generic, TypeVar + _T = TypeVar("_T") + class _Bacon(Generic[_T]): + def foo(self, arg): pass + IntFood = _Bacon[int] + """, + error="IntFood.foo", + ) + yield Case(stub="StrList = list[str]", runtime="StrList = ['foo', 'bar']", error="StrList") + yield Case( + stub=""" + import collections.abc + from typing import Callable + N = Callable[[str], bool] + O = collections.abc.Callable[[int], str] + P = Callable[[str], bool] + """, + runtime=""" + from typing import Callable + N = Callable[[str], bool] + O = Callable[[int], str] + P = int + """, + error="P", + ) + yield Case( + stub=""" + class Foo: + class Bar: ... + BarAlias = Foo.Bar + """, + runtime=""" + class Foo: + class Bar: pass + BarAlias = Foo.Bar + """, + error=None, + ) + yield Case( + stub=""" + from io import StringIO + StringIOAlias = StringIO + """, + runtime=""" + from _io import StringIO + StringIOAlias = StringIO + """, + error=None, + ) + yield Case( + stub=""" + from typing import Match + M = Match[str] + """, + runtime=""" + from typing import Match + M = Match[str] + """, + error=None, + ) + yield Case( + stub=""" + class Baz: + def fizz(self) -> None: ... + BazAlias = Baz + """, + runtime=""" + class Baz: + def fizz(self): pass + BazAlias = Baz + Baz.__name__ = Baz.__qualname__ = Baz.__module__ = "New" + """, + error=None, + ) + yield Case( + stub=""" + class FooBar: + __module__: None # type: ignore + def fizz(self) -> None: ... + FooBarAlias = FooBar + """, + runtime=""" + class FooBar: + def fizz(self): pass + FooBarAlias = FooBar + FooBar.__module__ = None + """, + error=None, + ) + if sys.version_info >= (3, 10): + yield Case( + stub=""" + import collections.abc + import re + from typing import Callable, Dict, Match, Iterable, Tuple, Union + Q = Dict[str, str] + R = dict[int, int] + S = Tuple[int, int] + T = tuple[str, str] + U = int | str + V = Union[int, str] + W = Callable[[str], bool] + Z = collections.abc.Callable[[str], bool] + QQ = Iterable[str] + RR = collections.abc.Iterable[str] + MM = Match[str] + MMM = re.Match[str] + """, + runtime=""" + from collections.abc import Callable, Iterable + from re import Match + Q = dict[str, str] + R = dict[int, int] + S = tuple[int, int] + T = tuple[str, str] + U = int | str + V = int | str + W = Callable[[str], bool] + Z = Callable[[str], bool] + QQ = Iterable[str] + RR = Iterable[str] + MM = Match[str] + MMM = Match[str] + """, + error=None, + ) @collect_cases def test_enum(self) -> Iterator[Case]: