diff --git a/CHANGELOG.md b/CHANGELOG.md
index f841de64..c00e2c0a 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -15,6 +15,8 @@ Features:
* introduce Y032 (prefer `object` to `Any` for the second argument in `__eq__` and
`__ne__` methods).
* introduce Y033 (always use annotations in stubs, rather than type comments).
+* introduce Y034 (detect common errors where return types are hardcoded but
+ should use `TypeVar`s instead).
## 22.1.0
diff --git a/README.md b/README.md
index 1e28970d..a1013383 100644
--- a/README.md
+++ b/README.md
@@ -65,6 +65,7 @@ currently emitted:
| Y031 | `TypedDict`s should use class-based syntax instead of assignment-based syntax wherever possible. (In situations where this is not possible, such as if a field is a Python keyword or an invalid identifier, this error will not be raised.)
| Y032 | The second argument of an `__eq__` or `__ne__` method should usually be annotated with `object` rather than `Any`.
| Y033 | Do not use type comments (e.g. `x = ... # type: int`) in stubs, even if the stub supports Python 2. Always use annotations instead (e.g. `x: int`).
+| Y034 | Y034 detects common errors where certain methods are annotated as having a fixed return type, despite returning `self` at runtime. Such methods should be annotated with `_typeshed.Self`.
This check looks for `__new__`, `__enter__` and `__aenter__` methods that return the class's name unparameterised. It also looks for `__iter__` methods that return `Iterator` in classes that inherit directly from `Iterator`, and for `__aiter__` methods that return `AsyncIterator` in classes that inherit directly from `AsyncIterator`. The check excludes methods decorated with `@overload` or `@abstractmethod`; the `__new__`, `__enter__` and `__aenter__` checks are also skipped for classes decorated with `@final`.
Many error codes enforce modern conventions, and some cannot yet be used in
all cases:
diff --git a/pyi.py b/pyi.py
index c2a4153c..df1ab41d 100644
--- a/pyi.py
+++ b/pyi.py
@@ -297,6 +297,84 @@ def _is_object(node: ast.expr, name: str, *, from_: Container[str]) -> bool:
_is_Literal = partial(_is_object, name="Literal", from_=_TYPING_MODULES)
_is_abstractmethod = partial(_is_object, name="abstractmethod", from_={"abc"})
_is_Any = partial(_is_object, name="Any", from_={"typing"})
+_is_overload = partial(_is_object, name="overload", from_={"typing"})
+_is_final = partial(_is_object, name="final", from_=_TYPING_MODULES)
+
+
+def _is_decorated_with_final(
+ node: ast.ClassDef | ast.FunctionDef | ast.AsyncFunctionDef,
+) -> bool:
+ return any(_is_final(decorator) for decorator in node.decorator_list)
+
+
+def _get_collections_abc_obj_id(node: ast.expr | None) -> str | None:
+ """
+ If the node is a subscripted bare name, or a subscripted attribute of a typing
+ module or of collections.abc, return the name of the subscripted object.
+ Else, return None.
+
+ >>> import ast
+ >>> node1 = ast.parse('AsyncIterator[str]').body[0].value
+ >>> node2 = ast.parse('typing.AsyncIterator[str]').body[0].value
+ >>> node3 = ast.parse('typing_extensions.AsyncIterator[str]').body[0].value
+ >>> node4 = ast.parse('collections.abc.AsyncIterator[str]').body[0].value
+ >>> node5 = ast.parse('collections.OrderedDict[str, int]').body[0].value
+ >>> _get_collections_abc_obj_id(node1)
+ 'AsyncIterator'
+ >>> _get_collections_abc_obj_id(node2)
+ 'AsyncIterator'
+ >>> _get_collections_abc_obj_id(node3)
+ 'AsyncIterator'
+ >>> _get_collections_abc_obj_id(node4)
+ 'AsyncIterator'
+ >>> _get_collections_abc_obj_id(node5) is None
+ True
+ """
+ if not isinstance(node, ast.Subscript):
+ return None
+ subscripted_obj = node.value
+ if isinstance(subscripted_obj, ast.Name):
+ return subscripted_obj.id
+ if not isinstance(subscripted_obj, ast.Attribute):
+ return None
+ obj_value, obj_attr = subscripted_obj.value, subscripted_obj.attr
+ if isinstance(obj_value, ast.Name) and obj_value.id in _TYPING_MODULES:
+ return obj_attr
+ if (
+ isinstance(obj_value, ast.Attribute)
+ and _is_name(obj_value.value, "collections")
+ and obj_value.attr == "abc"
+ ):
+ return obj_attr
+ return None
+
+
+_ITER_METHODS = frozenset({("Iterator", "__iter__"), ("AsyncIterator", "__aiter__")})
+
+
+def _has_bad_hardcoded_returns(method: ast.FunctionDef, classdef: ast.ClassDef) -> bool:
+ """Return `True` if `method` should be rewritten using `_typeshed.Self`."""
+ # Much too complex for our purposes to worry about overloaded functions or abstractmethods
+ if any(
+ _is_overload(deco) or _is_abstractmethod(deco) for deco in method.decorator_list
+ ):
+ return False
+
+ if not _non_kw_only_args_of(method.args): # weird, but theoretically possible
+ return False
+
+ method_name, returns = method.name, method.returns
+
+ if _is_name(returns, classdef.name):
+ return method_name in {"__enter__", "__new__"} and not _is_decorated_with_final(
+ classdef
+ )
+ else:
+ return_obj_name = _get_collections_abc_obj_id(returns)
+ return (return_obj_name, method_name) in _ITER_METHODS and any(
+ _get_collections_abc_obj_id(base_node) == return_obj_name
+ for base_node in classdef.bases
+ )
def _unparse_assign_node(node: ast.Assign | ast.AnnAssign) -> str:
@@ -304,6 +382,11 @@ def _unparse_assign_node(node: ast.Assign | ast.AnnAssign) -> str:
return unparse(node).replace("\n", "")
+def _unparse_func_node(node: ast.FunctionDef | ast.AsyncFunctionDef) -> str:
+ """Unparse a function node, and reformat it to fit on one line."""
+ return re.sub(r"\s+", " ", unparse(node)).strip()
+
+
def _is_list_of_str_nodes(seq: list[ast.expr | None]) -> TypeGuard[list[ast.Str]]:
return all(isinstance(item, ast.Str) for item in seq)
@@ -341,6 +424,13 @@ def _is_bad_TypedDict(node: ast.Call) -> bool:
)
+def _non_kw_only_args_of(args: ast.arguments) -> list[ast.arg]:
+ """Return a list of the positional-only and positional-or-keyword args of `args`."""
+ # pos-only args don't exist on 3.7
+ pos_only_args: list[ast.arg] = getattr(args, "posonlyargs", [])
+ return pos_only_args + args.args
+
+
@dataclass
class NestingCounter:
"""Class to help the PyiVisitor keep track of internal state"""
@@ -372,6 +462,8 @@ def __init__(self, filename: Path | None = None) -> None:
self.string_literals_allowed = NestingCounter()
self.in_function = NestingCounter()
self.in_class = NestingCounter()
+ # The class node currently being visited; None when not inside a class body
+ self.current_class_node: ast.ClassDef | None = None
def __repr__(self) -> str:
return f"{self.__class__.__name__}(filename={self.filename!r})"
@@ -801,8 +893,11 @@ def _check_platform_check(self, node: ast.Compare) -> None:
self.error(node, Y007)
def visit_ClassDef(self, node: ast.ClassDef) -> None:
+ old_class_node = self.current_class_node
+ self.current_class_node = node
with self.in_class.enabled():
self.generic_visit(node)
+ self.current_class_node = old_class_node
# empty class body should contain "..." not "pass"
if len(node.body) == 1:
@@ -825,17 +920,42 @@ def visit_ClassDef(self, node: ast.ClassDef) -> None:
):
self.error(statement, Y013)
- def _visit_method(self, node: ast.FunctionDef) -> None:
+ def _Y034_error(
+ self, node: ast.FunctionDef | ast.AsyncFunctionDef, cls_name: str
+ ) -> None:
+ method_name = node.name
+ copied_node = deepcopy(node)
+ copied_node.decorator_list.clear()
+ copied_node.returns = ast.Name(id="Self")
+ first_arg = _non_kw_only_args_of(copied_node.args)[0]
+ if method_name == "__new__":
+ first_arg.annotation = ast.Subscript(
+ value=ast.Name(id="type"), slice=ast.Name(id="Self")
+ )
+ referrer = '"__new__" methods'
+ else:
+ first_arg.annotation = ast.Name(id="Self")
+ referrer = f'"{method_name}" methods in classes like "{cls_name}"'
+ error_message = Y034.format(
+ methods=referrer,
+ method_name=f"{cls_name}.{method_name}",
+ suggested_syntax=_unparse_func_node(copied_node),
+ )
+ self.error(node, error_message)
+
+ def _visit_synchronous_method(self, node: ast.FunctionDef) -> None:
method_name = node.name
all_args = node.args
+ classdef = self.current_class_node
+ assert classdef is not None
+
+ if _has_bad_hardcoded_returns(node, classdef=classdef):
+ return self._Y034_error(node=node, cls_name=classdef.name)
if all_args.kwonlyargs:
return
- # pos-only args don't exist on 3.7
- pos_only_args: list[ast.arg] = getattr(all_args, "posonlyargs", [])
- pos_or_kwd_args = all_args.args
- non_kw_only_args = pos_only_args + pos_or_kwd_args
+ non_kw_only_args = _non_kw_only_args_of(all_args)
# Raise an error for defining __str__ or __repr__ on a class, but only if:
# 1). The method is not decorated with @abstractmethod
@@ -854,10 +974,25 @@ def _visit_method(self, node: ast.FunctionDef) -> None:
def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
if self.in_class.active:
- self._visit_method(node)
+ self._visit_synchronous_method(node)
self._visit_function(node)
def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None:
+ if self.in_class.active:
+ classdef = self.current_class_node
+ assert classdef is not None
+ if (
+ not any(
+ _is_overload(deco) or _is_abstractmethod(deco)
+ for deco in node.decorator_list
+ )
+ and node.name == "__aenter__"
+ and _is_name(node.returns, classdef.name)
+ # weird, but theoretically possible for there to be 0 non-kw-only args
+ and _non_kw_only_args_of(node.args)
+ and not _is_decorated_with_final(classdef)
+ ):
+ self._Y034_error(node=node, cls_name=classdef.name)
self._visit_function(node)
def _Y019_error(
@@ -865,10 +1000,8 @@ def _Y019_error(
) -> None:
cleaned_method = deepcopy(node)
cleaned_method.decorator_list.clear()
- new_syntax = unparse(cleaned_method)
+ new_syntax = _unparse_func_node(cleaned_method)
new_syntax = re.sub(rf"\b{typevar_name}\b", "Self", new_syntax)
- new_syntax = re.sub(r"\s+", " ", new_syntax).strip()
-
self.error(
# pass the node for the first argument to `self.error`,
# rather than the function node,
@@ -1124,3 +1257,4 @@ def parse_options(cls, optmanager, options, extra_args) -> None:
'Y032 Prefer "object" to "Any" for the second parameter in "{method_name}" methods'
)
Y033 = 'Y033 Do not use type comments in stubs (e.g. use "x: int" instead of "x = ... # type: int")'
+Y034 = 'Y034 {methods} usually return "self" at runtime. Consider using "_typeshed.Self" in "{method_name}", e.g. "{suggested_syntax}"'
diff --git a/tests/classdefs.pyi b/tests/classdefs.pyi
index 11224687..7512f776 100644
--- a/tests/classdefs.pyi
+++ b/tests/classdefs.pyi
@@ -1,29 +1,48 @@
import abc
+import collections.abc
import typing
from abc import abstractmethod
-from typing import Any
+from typing import Any, AsyncIterator, Iterator, overload
+
+import typing_extensions
+from _typeshed import Self
+from typing_extensions import final
class Bad:
+ def __new__(cls, *args: Any, **kwargs: Any) -> Bad: ... # Y034 "__new__" methods usually return "self" at runtime. Consider using "_typeshed.Self" in "Bad.__new__", e.g. "def __new__(cls: type[Self], *args: Any, **kwargs: Any) -> Self: ..."
def __repr__(self) -> str: ... # Y029 Defining __repr__ or __str__ in a stub is almost always redundant
def __str__(self) -> str: ... # Y029 Defining __repr__ or __str__ in a stub is almost always redundant
def __eq__(self, other: Any) -> bool: ... # Y032 Prefer "object" to "Any" for the second parameter in "__eq__" methods
def __ne__(self, other: typing.Any) -> typing.Any: ... # Y032 Prefer "object" to "Any" for the second parameter in "__ne__" methods
+ def __enter__(self) -> Bad: ... # Y034 "__enter__" methods in classes like "Bad" usually return "self" at runtime. Consider using "_typeshed.Self" in "Bad.__enter__", e.g. "def __enter__(self: Self) -> Self: ..."
+ async def __aenter__(self) -> Bad: ... # Y034 "__aenter__" methods in classes like "Bad" usually return "self" at runtime. Consider using "_typeshed.Self" in "Bad.__aenter__", e.g. "async def __aenter__(self: Self) -> Self: ..."
class Good:
+ def __new__(cls: type[Self], *args: Any, **kwargs: Any) -> Self: ...
@abstractmethod
def __str__(self) -> str: ...
@abc.abstractmethod
def __repr__(self) -> str: ...
def __eq__(self, other: object) -> bool: ...
def __ne__(self, obj: object) -> int: ...
+ def __enter__(self: Self) -> Self: ...
+ async def __aenter__(self: Self) -> Self: ...
class Fine:
+ @overload
+ def __new__(cls, foo: int) -> FineSubclass: ...
+ @overload
+ def __new__(cls, *args: Any, **kwargs: Any) -> Fine: ...
@abc.abstractmethod
def __str__(self) -> str: ...
@abc.abstractmethod
def __repr__(self) -> str: ...
def __eq__(self, other: Any, strange_extra_arg: list[str]) -> Any: ...
def __ne__(self, *, kw_only_other: Any) -> bool: ...
+ def __enter__(self) -> None: ...
+ async def __aenter__(self) -> bool: ...
+
+class FineSubclass(Fine): ...
class AlsoGood(str):
def __str__(self) -> AlsoGood: ...
@@ -33,6 +52,47 @@ class FineAndDandy:
def __str__(self, weird_extra_arg) -> str: ...
def __repr__(self, weird_extra_arg_with_default=...) -> str: ...
+@final
+class WillNotBeSubclassed:
+ def __new__(cls, *args: Any, **kwargs: Any) -> WillNotBeSubclassed: ...
+ def __enter__(self) -> WillNotBeSubclassed: ...
+ async def __aenter__(self) -> WillNotBeSubclassed: ...
+
+# we don't emit an error for these; out of scope for a linter
+class InvalidButPluginDoesNotCrash:
+ def __new__() -> InvalidButPluginDoesNotCrash: ...
+ def __enter__() -> InvalidButPluginDoesNotCrash: ...
+ async def __aenter__() -> InvalidButPluginDoesNotCrash: ...
+
+class BadIterator1(Iterator[int]):
+ def __iter__(self) -> Iterator[int]: ... # Y034 "__iter__" methods in classes like "BadIterator1" usually return "self" at runtime. Consider using "_typeshed.Self" in "BadIterator1.__iter__", e.g. "def __iter__(self: Self) -> Self: ..."
+
+class BadIterator2(typing.Iterator[int]):
+ def __iter__(self) -> Iterator[int]: ... # Y034 "__iter__" methods in classes like "BadIterator2" usually return "self" at runtime. Consider using "_typeshed.Self" in "BadIterator2.__iter__", e.g. "def __iter__(self: Self) -> Self: ..."
+
+class BadIterator3(typing_extensions.Iterator[int]):
+ def __iter__(self) -> collections.abc.Iterator[int]: ... # Y034 "__iter__" methods in classes like "BadIterator3" usually return "self" at runtime. Consider using "_typeshed.Self" in "BadIterator3.__iter__", e.g. "def __iter__(self: Self) -> Self: ..."
+
+class BadAsyncIterator(collections.abc.AsyncIterator[str]):
+ def __aiter__(self) -> typing.AsyncIterator[str]: ... # Y034 "__aiter__" methods in classes like "BadAsyncIterator" usually return "self" at runtime. Consider using "_typeshed.Self" in "BadAsyncIterator.__aiter__", e.g. "def __aiter__(self: Self) -> Self: ..."
+
+class Abstract(Iterator[str]):
+ @abstractmethod
+ def __iter__(self) -> Iterator[str]: ...
+ @abstractmethod
+ def __enter__(self) -> Abstract: ...
+ @abstractmethod
+ async def __aenter__(self) -> Abstract: ...
+
+class GoodIterator(Iterator[str]):
+ def __iter__(self: Self) -> Self: ...
+
+class GoodAsyncIterator(AsyncIterator[int]):
+ def __aiter__(self: Self) -> Self: ...
+
+class DoesNotInheritFromIterator:
+ def __iter__(self) -> DoesNotInheritFromIterator: ...
+
def __repr__(self) -> str: ...
def __str__(self) -> str: ...
def __eq__(self, other: Any) -> bool: ...