Skip to content

Commit 5849230

Browse files
AlexWaygood, pre-commit-ci[bot], and JelleZijlstra
authored
Add a check for methods that hardcode their return type, but shouldn't (#174)
Introduce Y034: detect common errors where certain methods are annotated as having a fixed return type, despite returning `self` at runtime. Such methods should be annotated using the `_typeshed.Self` TypeVar.

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jelle Zijlstra <[email protected]>
1 parent 49e85cf commit 5849230

File tree

4 files changed

+207
-10
lines changed

4 files changed

+207
-10
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ Features:
1515
* introduce Y032 (prefer `object` to `Any` for the second argument in `__eq__` and
1616
`__ne__` methods).
1717
* introduce Y033 (always use annotations in stubs, rather than type comments).
18+
* introduce Y034 (detect common errors where return types are hardcoded, but they
19+
should use `TypeVar`s instead).
1820

1921
## 22.1.0
2022

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ currently emitted:
6565
| Y031 | `TypedDict`s should use class-based syntax instead of assignment-based syntax wherever possible. (In situations where this is not possible, such as if a field is a Python keyword or an invalid identifier, this error will not be raised.)
6666
| Y032 | The second argument of an `__eq__` or `__ne__` method should usually be annotated with `object` rather than `Any`.
6767
| Y033 | Do not use type comments (e.g. `x = ... # type: int`) in stubs, even if the stub supports Python 2. Always use annotations instead (e.g. `x: int`).
68+
| Y034 | Y034 detects common errors where certain methods are annotated as having a fixed return type, despite returning `self` at runtime. Such methods should be annotated with `_typeshed.Self`.<br><br>This check looks for `__new__`, `__enter__` and `__aenter__` methods that return the class's name unparameterised. It also looks for `__iter__` methods that return `Iterator`, even if the class inherits directly from `Iterator`, and for `__aiter__` methods that return `AsyncIterator`, even if the class inherits directly from `AsyncIterator`. The check excludes methods decorated with `@overload` or `@abstractmethod`.
6869

6970
Many error codes enforce modern conventions, and some cannot yet be used in
7071
all cases:

pyi.py

Lines changed: 143 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -297,13 +297,96 @@ def _is_object(node: ast.expr, name: str, *, from_: Container[str]) -> bool:
297297
_is_Literal = partial(_is_object, name="Literal", from_=_TYPING_MODULES)
298298
_is_abstractmethod = partial(_is_object, name="abstractmethod", from_={"abc"})
299299
_is_Any = partial(_is_object, name="Any", from_={"typing"})
300+
_is_overload = partial(_is_object, name="overload", from_={"typing"})
301+
_is_final = partial(_is_object, name="final", from_=_TYPING_MODULES)
302+
303+
304+
def _is_decorated_with_final(
    node: ast.ClassDef | ast.FunctionDef | ast.AsyncFunctionDef,
) -> bool:
    """Return `True` if at least one decorator of `node` satisfies `_is_final`.

    `node` is a class or (async) function definition; only its
    `decorator_list` is inspected.
    """
    return any(_is_final(decorator) for decorator in node.decorator_list)
308+
309+
310+
def _get_collections_abc_obj_id(node: ast.expr | None) -> str | None:
    """
    If the node represents a subscripted object from collections.abc or typing,
    return the name of the object.
    Else, return None.

    Note that a subscripted *bare* name (e.g. ``AsyncIterator[str]``) is
    returned as-is: no attempt is made to check where that name was imported
    from. Qualified names are only accepted when the qualifier is one of
    `_TYPING_MODULES` or is exactly ``collections.abc``.

    >>> import ast
    >>> node1 = ast.parse('AsyncIterator[str]').body[0].value
    >>> node2 = ast.parse('typing.AsyncIterator[str]').body[0].value
    >>> node3 = ast.parse('typing_extensions.AsyncIterator[str]').body[0].value
    >>> node4 = ast.parse('collections.abc.AsyncIterator[str]').body[0].value
    >>> node5 = ast.parse('collections.OrderedDict[str, int]').body[0].value
    >>> _get_collections_abc_obj_id(node1)
    'AsyncIterator'
    >>> _get_collections_abc_obj_id(node2)
    'AsyncIterator'
    >>> _get_collections_abc_obj_id(node3)
    'AsyncIterator'
    >>> _get_collections_abc_obj_id(node4)
    'AsyncIterator'
    >>> _get_collections_abc_obj_id(node5) is None
    True
    """
    # Anything that is not `X[...]` (including `None`) cannot match.
    if not isinstance(node, ast.Subscript):
        return None
    subscripted_obj = node.value
    # Unqualified name, e.g. `Iterator[int]`.
    if isinstance(subscripted_obj, ast.Name):
        return subscripted_obj.id
    if not isinstance(subscripted_obj, ast.Attribute):
        return None
    obj_value, obj_attr = subscripted_obj.value, subscripted_obj.attr
    # Single-level qualifier, e.g. `typing.Iterator[int]`.
    if isinstance(obj_value, ast.Name) and obj_value.id in _TYPING_MODULES:
        return obj_attr
    # Two-level qualifier: only `collections.abc.<name>[...]` is accepted.
    if (
        isinstance(obj_value, ast.Attribute)
        and _is_name(obj_value.value, "collections")
        and obj_value.attr == "abc"
    ):
        return obj_attr
    return None
350+
351+
352+
# (return-type name, method name) pairs checked for classes that subclass
# the same-named object.
_ITER_METHODS = frozenset({("Iterator", "__iter__"), ("AsyncIterator", "__aiter__")})


def _has_bad_hardcoded_returns(method: ast.FunctionDef, classdef: ast.ClassDef) -> bool:
    """Return `True` if `method` should be rewritten using `_typeshed.Self`.

    `method` is a synchronous method defined in the body of `classdef`.
    The result is `False` when the method is decorated with something that
    `_is_overload` or `_is_abstractmethod` recognises, or when it has no
    positional parameters at all. Otherwise the result is `True` when either:

    * the return annotation is the name of `classdef`, the method is
      `__enter__` or `__new__`, and `classdef` is not decorated with `final`; or
    * the (return annotation, method name) pair is in `_ITER_METHODS` and
      one of the bases of `classdef` is a subscripted object of the same name
      as the return annotation.
    """
    # Much too complex for our purposes to worry about overloaded functions or abstractmethods
    if any(
        _is_overload(deco) or _is_abstractmethod(deco) for deco in method.decorator_list
    ):
        return False

    if not _non_kw_only_args_of(method.args):  # weird, but theoretically possible
        return False

    method_name, returns = method.name, method.returns

    if _is_name(returns, classdef.name):
        return method_name in {"__enter__", "__new__"} and not _is_decorated_with_final(
            classdef
        )

    return_obj_name = _get_collections_abc_obj_id(returns)
    return (return_obj_name, method_name) in _ITER_METHODS and any(
        _get_collections_abc_obj_id(base_node) == return_obj_name
        for base_node in classdef.bases
    )
300378

301379

302380
def _unparse_assign_node(node: ast.Assign | ast.AnnAssign) -> str:
    """Unparse an Assign node, and remove any newlines in it.

    Returns the source text produced by `unparse(node)` with every newline
    character deleted (other whitespace is left untouched).
    """
    return unparse(node).replace("\n", "")
305383

306384

385+
def _unparse_func_node(node: ast.FunctionDef | ast.AsyncFunctionDef) -> str:
    """Unparse a function node, and reformat it to fit on one line.

    Every run of whitespace (including newlines and indentation) in the
    unparsed source is collapsed to a single space, and leading/trailing
    whitespace is stripped.
    """
    return re.sub(r"\s+", " ", unparse(node)).strip()
388+
389+
307390
def _is_list_of_str_nodes(seq: list[ast.expr | None]) -> TypeGuard[list[ast.Str]]:
308391
return all(isinstance(item, ast.Str) for item in seq)
309392

@@ -341,6 +424,13 @@ def _is_bad_TypedDict(node: ast.Call) -> bool:
341424
)
342425

343426

427+
def _non_kw_only_args_of(args: ast.arguments) -> list[ast.arg]:
    """Return a list containing the pos-only args and pos-or-kwd args of `args`.

    Positional-only parameters come first, followed by the
    positional-or-keyword parameters. `*args`, keyword-only parameters and
    `**kwargs` are not included. The result is a new list.
    """
    # pos-only args don't exist on 3.7
    pos_only_args: list[ast.arg] = getattr(args, "posonlyargs", [])
    return pos_only_args + args.args
432+
433+
344434
@dataclass
345435
class NestingCounter:
346436
"""Class to help the PyiVisitor keep track of internal state"""
@@ -372,6 +462,8 @@ def __init__(self, filename: Path | None = None) -> None:
372462
self.string_literals_allowed = NestingCounter()
373463
self.in_function = NestingCounter()
374464
self.in_class = NestingCounter()
465+
# This is only relevant for visiting classes
466+
self.current_class_node: ast.ClassDef | None = None
375467

376468
def __repr__(self) -> str:
    """Return `ClassName(filename=<repr of self.filename>)`."""
    return f"{self.__class__.__name__}(filename={self.filename!r})"
@@ -801,8 +893,11 @@ def _check_platform_check(self, node: ast.Compare) -> None:
801893
self.error(node, Y007)
802894

803895
def visit_ClassDef(self, node: ast.ClassDef) -> None:
896+
old_class_node = self.current_class_node
897+
self.current_class_node = node
804898
with self.in_class.enabled():
805899
self.generic_visit(node)
900+
self.current_class_node = old_class_node
806901

807902
# empty class body should contain "..." not "pass"
808903
if len(node.body) == 1:
@@ -825,17 +920,42 @@ def visit_ClassDef(self, node: ast.ClassDef) -> None:
825920
):
826921
self.error(statement, Y013)
827922

828-
def _visit_method(self, node: ast.FunctionDef) -> None:
923+
def _Y034_error(
    self, node: ast.FunctionDef | ast.AsyncFunctionDef, cls_name: str
) -> None:
    """Report a Y034 error on `node`, with a suggested rewritten signature.

    The suggestion is built from a deep copy of `node` (so the tree being
    visited is not modified): decorators are dropped, the return annotation
    becomes `Self`, and the first positional parameter is annotated
    `type[Self]` for `__new__` or `Self` for any other method.

    `node` must have at least one positional parameter; indexing the first
    one raises `IndexError` otherwise, so callers are expected to have
    checked this beforehand.
    """
    method_name = node.name
    copied_node = deepcopy(node)
    copied_node.decorator_list.clear()
    copied_node.returns = ast.Name(id="Self")
    first_arg = _non_kw_only_args_of(copied_node.args)[0]
    if method_name == "__new__":
        first_arg.annotation = ast.Subscript(
            value=ast.Name(id="type"), slice=ast.Name(id="Self")
        )
        referrer = '"__new__" methods'
    else:
        first_arg.annotation = ast.Name(id="Self")
        referrer = f'"{method_name}" methods in classes like "{cls_name}"'
    error_message = Y034.format(
        methods=referrer,
        method_name=f"{cls_name}.{method_name}",
        suggested_syntax=_unparse_func_node(copied_node),
    )
    self.error(node, error_message)
945+
946+
def _visit_synchronous_method(self, node: ast.FunctionDef) -> None:
829947
method_name = node.name
830948
all_args = node.args
949+
classdef = self.current_class_node
950+
assert classdef is not None
951+
952+
if _has_bad_hardcoded_returns(node, classdef=classdef):
953+
return self._Y034_error(node=node, cls_name=classdef.name)
831954

832955
if all_args.kwonlyargs:
833956
return
834957

835-
# pos-only args don't exist on 3.7
836-
pos_only_args: list[ast.arg] = getattr(all_args, "posonlyargs", [])
837-
pos_or_kwd_args = all_args.args
838-
non_kw_only_args = pos_only_args + pos_or_kwd_args
958+
non_kw_only_args = _non_kw_only_args_of(all_args)
839959

840960
# Raise an error for defining __str__ or __repr__ on a class, but only if:
841961
# 1). The method is not decorated with @abstractmethod
@@ -854,21 +974,34 @@ def _visit_method(self, node: ast.FunctionDef) -> None:
854974

855975
def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
    """Visit a synchronous function definition.

    When the visitor is currently inside a class body, the method-specific
    checks run first; the checks shared by all functions always run.
    """
    if self.in_class.active:
        self._visit_synchronous_method(node)
    self._visit_function(node)
859979

860980
def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None:
    """Visit an async function definition.

    Inside a class body, a Y034 error is reported for an `__aenter__`
    method whose return annotation is the enclosing class's own name, unless
    the method is decorated with `overload`/`abstractmethod`, has no
    positional parameters, or the class is decorated with `final`.
    The checks shared by all functions always run afterwards.
    """
    if self.in_class.active:
        classdef = self.current_class_node
        assert classdef is not None
        if (
            not any(
                _is_overload(deco) or _is_abstractmethod(deco)
                for deco in node.decorator_list
            )
            and node.name == "__aenter__"
            and _is_name(node.returns, classdef.name)
            # weird, but theoretically possible for there to be 0 non-kw-only args
            and _non_kw_only_args_of(node.args)
            and not _is_decorated_with_final(classdef)
        ):
            self._Y034_error(node=node, cls_name=classdef.name)
    self._visit_function(node)
862997

863998
def _Y019_error(
864999
self, node: ast.FunctionDef | ast.AsyncFunctionDef, typevar_name: str
8651000
) -> None:
8661001
cleaned_method = deepcopy(node)
8671002
cleaned_method.decorator_list.clear()
868-
new_syntax = unparse(cleaned_method)
1003+
new_syntax = _unparse_func_node(cleaned_method)
8691004
new_syntax = re.sub(rf"\b{typevar_name}\b", "Self", new_syntax)
870-
new_syntax = re.sub(r"\s+", " ", new_syntax).strip()
871-
8721005
self.error(
8731006
# pass the node for the first argument to `self.error`,
8741007
# rather than the function node,
@@ -1124,3 +1257,4 @@ def parse_options(cls, optmanager, options, extra_args) -> None:
11241257
'Y032 Prefer "object" to "Any" for the second parameter in "{method_name}" methods'
11251258
)
11261259
Y033 = 'Y033 Do not use type comments in stubs (e.g. use "x: int" instead of "x = ... # type: int")'
1260+
Y034 = 'Y034 {methods} usually return "self" at runtime. Consider using "_typeshed.Self" in "{method_name}", e.g. "{suggested_syntax}"'

tests/classdefs.pyi

Lines changed: 61 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,48 @@
11
import abc
2+
import collections.abc
23
import typing
34
from abc import abstractmethod
4-
from typing import Any
5+
from typing import Any, AsyncIterator, Iterator, overload
6+
7+
import typing_extensions
8+
from _typeshed import Self
9+
from typing_extensions import final
510

611
class Bad:
12+
def __new__(cls, *args: Any, **kwargs: Any) -> Bad: ... # Y034 "__new__" methods usually return "self" at runtime. Consider using "_typeshed.Self" in "Bad.__new__", e.g. "def __new__(cls: type[Self], *args: Any, **kwargs: Any) -> Self: ..."
713
def __repr__(self) -> str: ... # Y029 Defining __repr__ or __str__ in a stub is almost always redundant
814
def __str__(self) -> str: ... # Y029 Defining __repr__ or __str__ in a stub is almost always redundant
915
def __eq__(self, other: Any) -> bool: ... # Y032 Prefer "object" to "Any" for the second parameter in "__eq__" methods
1016
def __ne__(self, other: typing.Any) -> typing.Any: ... # Y032 Prefer "object" to "Any" for the second parameter in "__ne__" methods
17+
def __enter__(self) -> Bad: ... # Y034 "__enter__" methods in classes like "Bad" usually return "self" at runtime. Consider using "_typeshed.Self" in "Bad.__enter__", e.g. "def __enter__(self: Self) -> Self: ..."
18+
async def __aenter__(self) -> Bad: ... # Y034 "__aenter__" methods in classes like "Bad" usually return "self" at runtime. Consider using "_typeshed.Self" in "Bad.__aenter__", e.g. "async def __aenter__(self: Self) -> Self: ..."
1119

1220
class Good:
21+
def __new__(cls: type[Self], *args: Any, **kwargs: Any) -> Self: ...
1322
@abstractmethod
1423
def __str__(self) -> str: ...
1524
@abc.abstractmethod
1625
def __repr__(self) -> str: ...
1726
def __eq__(self, other: object) -> bool: ...
1827
def __ne__(self, obj: object) -> int: ...
28+
def __enter__(self: Self) -> Self: ...
29+
async def __aenter__(self: Self) -> Self: ...
1930

2031
class Fine:
32+
@overload
33+
def __new__(cls, foo: int) -> FineSubclass: ...
34+
@overload
35+
def __new__(cls, *args: Any, **kwargs: Any) -> Fine: ...
2136
@abc.abstractmethod
2237
def __str__(self) -> str: ...
2338
@abc.abstractmethod
2439
def __repr__(self) -> str: ...
2540
def __eq__(self, other: Any, strange_extra_arg: list[str]) -> Any: ...
2641
def __ne__(self, *, kw_only_other: Any) -> bool: ...
42+
def __enter__(self) -> None: ...
43+
async def __aenter__(self) -> bool: ...
44+
45+
class FineSubclass(Fine): ...
2746

2847
class AlsoGood(str):
2948
def __str__(self) -> AlsoGood: ...
@@ -33,6 +52,47 @@ class FineAndDandy:
3352
def __str__(self, weird_extra_arg) -> str: ...
3453
def __repr__(self, weird_extra_arg_with_default=...) -> str: ...
3554

55+
@final
56+
class WillNotBeSubclassed:
57+
def __new__(cls, *args: Any, **kwargs: Any) -> WillNotBeSubclassed: ...
58+
def __enter__(self) -> WillNotBeSubclassed: ...
59+
async def __aenter__(self) -> WillNotBeSubclassed: ...
60+
61+
# we don't emit an error for these; out of scope for a linter
62+
class InvalidButPluginDoesNotCrash:
63+
def __new__() -> InvalidButPluginDoesNotCrash: ...
64+
def __enter__() -> InvalidButPluginDoesNotCrash: ...
65+
async def __aenter__() -> InvalidButPluginDoesNotCrash: ...
66+
67+
class BadIterator1(Iterator[int]):
68+
def __iter__(self) -> Iterator[int]: ... # Y034 "__iter__" methods in classes like "BadIterator1" usually return "self" at runtime. Consider using "_typeshed.Self" in "BadIterator1.__iter__", e.g. "def __iter__(self: Self) -> Self: ..."
69+
70+
class BadIterator2(typing.Iterator[int]):
71+
def __iter__(self) -> Iterator[int]: ... # Y034 "__iter__" methods in classes like "BadIterator2" usually return "self" at runtime. Consider using "_typeshed.Self" in "BadIterator2.__iter__", e.g. "def __iter__(self: Self) -> Self: ..."
72+
73+
class BadIterator3(typing_extensions.Iterator[int]):
74+
def __iter__(self) -> collections.abc.Iterator[int]: ... # Y034 "__iter__" methods in classes like "BadIterator3" usually return "self" at runtime. Consider using "_typeshed.Self" in "BadIterator3.__iter__", e.g. "def __iter__(self: Self) -> Self: ..."
75+
76+
class BadAsyncIterator(collections.abc.AsyncIterator[str]):
77+
def __aiter__(self) -> typing.AsyncIterator[str]: ... # Y034 "__aiter__" methods in classes like "BadAsyncIterator" usually return "self" at runtime. Consider using "_typeshed.Self" in "BadAsyncIterator.__aiter__", e.g. "def __aiter__(self: Self) -> Self: ..."
78+
79+
class Abstract(Iterator[str]):
80+
@abstractmethod
81+
def __iter__(self) -> Iterator[str]: ...
82+
@abstractmethod
83+
def __enter__(self) -> Abstract: ...
84+
@abstractmethod
85+
async def __aenter__(self) -> Abstract: ...
86+
87+
class GoodIterator(Iterator[str]):
88+
def __iter__(self: Self) -> Self: ...
89+
90+
class GoodAsyncIterator(AsyncIterator[int]):
91+
def __aiter__(self: Self) -> Self: ...
92+
93+
class DoesNotInheritFromIterator:
94+
def __iter__(self) -> DoesNotInheritFromIterator: ...
95+
3696
def __repr__(self) -> str: ...
3797
def __str__(self) -> str: ...
3898
def __eq__(self, other: Any) -> bool: ...

0 commit comments

Comments
 (0)