diff --git a/CHANGELOG.md b/CHANGELOG.md index f841de64..c00e2c0a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,8 @@ Features: * introduce Y032 (prefer `object` to `Any` for the second argument in `__eq__` and `__ne__` methods). * introduce Y033 (always use annotations in stubs, rather than type comments). +* introduce Y034 (detect common errors where return types are hardcoded, but they + should use `TypeVar`s instead). ## 22.1.0 diff --git a/README.md b/README.md index 1e28970d..a1013383 100644 --- a/README.md +++ b/README.md @@ -65,6 +65,7 @@ currently emitted: | Y031 | `TypedDict`s should use class-based syntax instead of assignment-based syntax wherever possible. (In situations where this is not possible, such as if a field is a Python keyword or an invalid identifier, this error will not be raised.) | Y032 | The second argument of an `__eq__` or `__ne__` method should usually be annotated with `object` rather than `Any`. | Y033 | Do not use type comments (e.g. `x = ... # type: int`) in stubs, even if the stub supports Python 2. Always use annotations instead (e.g. `x: int`). +| Y034 | Y034 detects common errors where certain methods are annotated as having a fixed return type, despite returning `self` at runtime. Such methods should be annotated with `_typeshed.Self`.

This check looks for `__new__`, `__enter__` and `__aenter__` methods that return the class's name unparameterised. It also looks for `__iter__` methods that return `Iterator`, even if the class inherits directly from `Iterator`, and for `__aiter__` methods that return `AsyncIterator`, even if the class inherits directly from `AsyncIterator`. The check excludes methods decorated with `@overload` or `@abstractmethod`. Many error codes enforce modern conventions, and some cannot yet be used in all cases: diff --git a/pyi.py b/pyi.py index c2a4153c..df1ab41d 100644 --- a/pyi.py +++ b/pyi.py @@ -297,6 +297,84 @@ def _is_object(node: ast.expr, name: str, *, from_: Container[str]) -> bool: _is_Literal = partial(_is_object, name="Literal", from_=_TYPING_MODULES) _is_abstractmethod = partial(_is_object, name="abstractmethod", from_={"abc"}) _is_Any = partial(_is_object, name="Any", from_={"typing"}) +_is_overload = partial(_is_object, name="overload", from_={"typing"}) +_is_final = partial(_is_object, name="final", from_=_TYPING_MODULES) + + +def _is_decorated_with_final( + node: ast.ClassDef | ast.FunctionDef | ast.AsyncFunctionDef, +) -> bool: + return any(_is_final(decorator) for decorator in node.decorator_list) + + +def _get_collections_abc_obj_id(node: ast.expr | None) -> str | None: + """ + If the node represents a subscripted object from collections.abc or typing, + return the name of the object. + Else, return None. + + >>> import ast + >>> node1 = ast.parse('AsyncIterator[str]').body[0].value + >>> node2 = ast.parse('typing.AsyncIterator[str]').body[0].value + >>> node3 = ast.parse('typing_extensions.AsyncIterator[str]').body[0].value + >>> node4 = ast.parse('collections.abc.AsyncIterator[str]').body[0].value + >>> node5 = ast.parse('collections.OrderedDict[str, int]').body[0].value + >>> _get_collections_abc_obj_id(node1) + 'AsyncIterator' + >>> _get_collections_abc_obj_id(node2) + 'AsyncIterator' + >>> _get_collections_abc_obj_id(node3) + 'AsyncIterator' + >>> _get_collections_abc_obj_id(node4) + 'AsyncIterator' + >>> _get_collections_abc_obj_id(node5) is None + True + """ + if not isinstance(node, ast.Subscript): + return None + subscripted_obj = node.value + if isinstance(subscripted_obj, ast.Name): + return subscripted_obj.id + if not isinstance(subscripted_obj, ast.Attribute): + return None + obj_value, obj_attr = subscripted_obj.value, subscripted_obj.attr + if isinstance(obj_value, ast.Name) and obj_value.id in _TYPING_MODULES: + return obj_attr + if ( + isinstance(obj_value, ast.Attribute) + and _is_name(obj_value.value, "collections") + and obj_value.attr == "abc" + ): + return obj_attr + return None + + +_ITER_METHODS = frozenset({("Iterator", "__iter__"), ("AsyncIterator", "__aiter__")}) + + +def _has_bad_hardcoded_returns(method: ast.FunctionDef, classdef: ast.ClassDef) -> bool: + """Return `True` if `function` should be rewritten using `_typeshed.Self`.""" + # Much too complex for our purposes to worry about overloaded functions or abstractmethods + if any( + _is_overload(deco) or _is_abstractmethod(deco) for deco in method.decorator_list + ): + return False + + if not _non_kw_only_args_of(method.args): # weird, but theoretically possible + return False + + method_name, returns = method.name, method.returns + + if _is_name(returns, classdef.name): + return method_name in {"__enter__", "__new__"} and not _is_decorated_with_final( + classdef + ) + else: + return_obj_name = _get_collections_abc_obj_id(returns) + return (return_obj_name, method_name) in _ITER_METHODS and any( + _get_collections_abc_obj_id(base_node) == return_obj_name + for base_node in classdef.bases + ) def _unparse_assign_node(node: ast.Assign | ast.AnnAssign) -> str: @@ -304,6 +382,11 @@ def _unparse_assign_node(node: ast.Assign | ast.AnnAssign) -> str: return unparse(node).replace("\n", "") +def _unparse_func_node(node: ast.FunctionDef | ast.AsyncFunctionDef) -> str: + """Unparse a function node, and reformat it to fit on one line.""" + return re.sub(r"\s+", " ", unparse(node)).strip() + + def _is_list_of_str_nodes(seq: list[ast.expr | None]) -> TypeGuard[list[ast.Str]]: return all(isinstance(item, ast.Str) for item in seq) @@ -341,6 +424,13 @@ def _is_bad_TypedDict(node: ast.Call) -> bool: ) +def _non_kw_only_args_of(args: ast.arguments) -> list[ast.arg]: + """Return a list containing the pos-only args and pos-or-kwd args of `args`""" + # pos-only args don't exist on 3.7 + pos_only_args: list[ast.arg] = getattr(args, "posonlyargs", []) + return pos_only_args + args.args + + @dataclass class NestingCounter: """Class to help the PyiVisitor keep track of internal state""" @@ -372,6 +462,8 @@ def __init__(self, filename: Path | None = None) -> None: self.string_literals_allowed = NestingCounter() self.in_function = NestingCounter() self.in_class = NestingCounter() + # This is only relevant for visiting classes + self.current_class_node: ast.ClassDef | None = None def __repr__(self) -> str: return f"{self.__class__.__name__}(filename={self.filename!r})" @@ -801,8 +893,11 @@ def _check_platform_check(self, node: ast.Compare) -> None: self.error(node, Y007) def visit_ClassDef(self, node: ast.ClassDef) -> None: + old_class_node = self.current_class_node + self.current_class_node = node with self.in_class.enabled(): self.generic_visit(node) + self.current_class_node = old_class_node # empty class body should contain "..." not "pass" if len(node.body) == 1: @@ -825,17 +920,42 @@ def visit_ClassDef(self, node: ast.ClassDef) -> None: ): self.error(statement, Y013) - def _visit_method(self, node: ast.FunctionDef) -> None: + def _Y034_error( + self, node: ast.FunctionDef | ast.AsyncFunctionDef, cls_name: str + ) -> None: + method_name = node.name + copied_node = deepcopy(node) + copied_node.decorator_list.clear() + copied_node.returns = ast.Name(id="Self") + first_arg = _non_kw_only_args_of(copied_node.args)[0] + if method_name == "__new__": + first_arg.annotation = ast.Subscript( + value=ast.Name(id="type"), slice=ast.Name(id="Self") + ) + referrer = '"__new__" methods' + else: + first_arg.annotation = ast.Name(id="Self") + referrer = f'"{method_name}" methods in classes like "{cls_name}"' + error_message = Y034.format( + methods=referrer, + method_name=f"{cls_name}.{method_name}", + suggested_syntax=_unparse_func_node(copied_node), + ) + self.error(node, error_message) + + def _visit_synchronous_method(self, node: ast.FunctionDef) -> None: method_name = node.name all_args = node.args + classdef = self.current_class_node + assert classdef is not None + + if _has_bad_hardcoded_returns(node, classdef=classdef): + return self._Y034_error(node=node, cls_name=classdef.name) if all_args.kwonlyargs: return - # pos-only args don't exist on 3.7 - pos_only_args: list[ast.arg] = getattr(all_args, "posonlyargs", []) - pos_or_kwd_args = all_args.args - non_kw_only_args = pos_only_args + pos_or_kwd_args + non_kw_only_args = _non_kw_only_args_of(all_args) # Raise an error for defining __str__ or __repr__ on a class, but only if: # 1). The method is not decorated with @abstractmethod @@ -854,10 +974,25 @@ def _visit_method(self, node: ast.FunctionDef) -> None: def visit_FunctionDef(self, node: ast.FunctionDef) -> None: if self.in_class.active: - self._visit_method(node) + self._visit_synchronous_method(node) self._visit_function(node) def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None: + if self.in_class.active: + classdef = self.current_class_node + assert classdef is not None + if ( + not any( + _is_overload(deco) or _is_abstractmethod(deco) + for deco in node.decorator_list + ) + and node.name == "__aenter__" + and _is_name(node.returns, classdef.name) + # weird, but theoretically possible for there to be 0 non-kw-only args + and _non_kw_only_args_of(node.args) + and not _is_decorated_with_final(classdef) + ): + self._Y034_error(node=node, cls_name=classdef.name) self._visit_function(node) def _Y019_error( @@ -865,10 +1000,8 @@ def _Y019_error( ) -> None: cleaned_method = deepcopy(node) cleaned_method.decorator_list.clear() - new_syntax = unparse(cleaned_method) + new_syntax = _unparse_func_node(cleaned_method) new_syntax = re.sub(rf"\b{typevar_name}\b", "Self", new_syntax) - new_syntax = re.sub(r"\s+", " ", new_syntax).strip() - self.error( # pass the node for the first argument to `self.error`, # rather than the function node, @@ -1124,3 +1257,4 @@ def parse_options(cls, optmanager, options, extra_args) -> None: 'Y032 Prefer "object" to "Any" for the second parameter in "{method_name}" methods' ) Y033 = 'Y033 Do not use type comments in stubs (e.g. use "x: int" instead of "x = ... # type: int")' +Y034 = 'Y034 {methods} usually return "self" at runtime. Consider using "_typeshed.Self" in "{method_name}", e.g. "{suggested_syntax}"' diff --git a/tests/classdefs.pyi b/tests/classdefs.pyi index 11224687..7512f776 100644 --- a/tests/classdefs.pyi +++ b/tests/classdefs.pyi @@ -1,29 +1,48 @@ import abc +import collections.abc import typing from abc import abstractmethod -from typing import Any +from typing import Any, AsyncIterator, Iterator, overload + +import typing_extensions +from _typeshed import Self +from typing_extensions import final class Bad: + def __new__(cls, *args: Any, **kwargs: Any) -> Bad: ... # Y034 "__new__" methods usually return "self" at runtime. Consider using "_typeshed.Self" in "Bad.__new__", e.g. "def __new__(cls: type[Self], *args: Any, **kwargs: Any) -> Self: ..." def __repr__(self) -> str: ... # Y029 Defining __repr__ or __str__ in a stub is almost always redundant def __str__(self) -> str: ... # Y029 Defining __repr__ or __str__ in a stub is almost always redundant def __eq__(self, other: Any) -> bool: ... # Y032 Prefer "object" to "Any" for the second parameter in "__eq__" methods def __ne__(self, other: typing.Any) -> typing.Any: ... # Y032 Prefer "object" to "Any" for the second parameter in "__ne__" methods + def __enter__(self) -> Bad: ... # Y034 "__enter__" methods in classes like "Bad" usually return "self" at runtime. Consider using "_typeshed.Self" in "Bad.__enter__", e.g. "def __enter__(self: Self) -> Self: ..." + async def __aenter__(self) -> Bad: ... # Y034 "__aenter__" methods in classes like "Bad" usually return "self" at runtime. Consider using "_typeshed.Self" in "Bad.__aenter__", e.g. "async def __aenter__(self: Self) -> Self: ..." class Good: + def __new__(cls: type[Self], *args: Any, **kwargs: Any) -> Self: ... @abstractmethod def __str__(self) -> str: ... @abc.abstractmethod def __repr__(self) -> str: ... def __eq__(self, other: object) -> bool: ... def __ne__(self, obj: object) -> int: ... + def __enter__(self: Self) -> Self: ... + async def __aenter__(self: Self) -> Self: ... class Fine: + @overload + def __new__(cls, foo: int) -> FineSubclass: ... + @overload + def __new__(cls, *args: Any, **kwargs: Any) -> Fine: ... @abc.abstractmethod def __str__(self) -> str: ... @abc.abstractmethod def __repr__(self) -> str: ... def __eq__(self, other: Any, strange_extra_arg: list[str]) -> Any: ... def __ne__(self, *, kw_only_other: Any) -> bool: ... + def __enter__(self) -> None: ... + async def __aenter__(self) -> bool: ... + +class FineSubclass(Fine): ... class AlsoGood(str): def __str__(self) -> AlsoGood: ... @@ -33,6 +52,47 @@ class FineAndDandy: def __str__(self, weird_extra_arg) -> str: ... def __repr__(self, weird_extra_arg_with_default=...) -> str: ... +@final +class WillNotBeSubclassed: + def __new__(cls, *args: Any, **kwargs: Any) -> WillNotBeSubclassed: ... + def __enter__(self) -> WillNotBeSubclassed: ... + async def __aenter__(self) -> WillNotBeSubclassed: ... + +# we don't emit an error for these; out of scope for a linter +class InvalidButPluginDoesNotCrash: + def __new__() -> InvalidButPluginDoesNotCrash: ... + def __enter__() -> InvalidButPluginDoesNotCrash: ... + async def __aenter__() -> InvalidButPluginDoesNotCrash: ... + +class BadIterator1(Iterator[int]): + def __iter__(self) -> Iterator[int]: ... # Y034 "__iter__" methods in classes like "BadIterator1" usually return "self" at runtime. Consider using "_typeshed.Self" in "BadIterator1.__iter__", e.g. "def __iter__(self: Self) -> Self: ..." + +class BadIterator2(typing.Iterator[int]): + def __iter__(self) -> Iterator[int]: ... # Y034 "__iter__" methods in classes like "BadIterator2" usually return "self" at runtime. Consider using "_typeshed.Self" in "BadIterator2.__iter__", e.g. "def __iter__(self: Self) -> Self: ..." + +class BadIterator3(typing_extensions.Iterator[int]): + def __iter__(self) -> collections.abc.Iterator[int]: ... # Y034 "__iter__" methods in classes like "BadIterator3" usually return "self" at runtime. Consider using "_typeshed.Self" in "BadIterator3.__iter__", e.g. "def __iter__(self: Self) -> Self: ..." + +class BadAsyncIterator(collections.abc.AsyncIterator[str]): + def __aiter__(self) -> typing.AsyncIterator[str]: ... # Y034 "__aiter__" methods in classes like "BadAsyncIterator" usually return "self" at runtime. Consider using "_typeshed.Self" in "BadAsyncIterator.__aiter__", e.g. "def __aiter__(self: Self) -> Self: ..." + +class Abstract(Iterator[str]): + @abstractmethod + def __iter__(self) -> Iterator[str]: ... + @abstractmethod + def __enter__(self) -> Abstract: ... + @abstractmethod + async def __aenter__(self) -> Abstract: ... + +class GoodIterator(Iterator[str]): + def __iter__(self: Self) -> Self: ... + +class GoodAsyncIterator(AsyncIterator[int]): + def __aiter__(self: Self) -> Self: ... + +class DoesNotInheritFromIterator: + def __iter__(self) -> DoesNotInheritFromIterator: ... + def __repr__(self) -> str: ... def __str__(self) -> str: ... def __eq__(self, other: Any) -> bool: ...