Skip to content

Commit 58f1b19

Browse files
committed
[ty] Implement equivalence for protocols with method members
1 parent 44f2f77 commit 58f1b19

File tree

4 files changed

+91
-16
lines changed

4 files changed

+91
-16
lines changed

crates/ty_python_semantic/resources/mdtest/protocols.md

Lines changed: 82 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1476,8 +1476,7 @@ class P1(Protocol):
14761476
class P2(Protocol):
14771477
def x(self, y: int) -> None: ...
14781478

1479-
# TODO: this should pass
1480-
static_assert(is_equivalent_to(P1, P2)) # error: [static-assert-error]
1479+
static_assert(is_equivalent_to(P1, P2))
14811480
```
14821481

14831482
As with protocols that only have non-method members, this also holds true when they appear in
@@ -1487,8 +1486,7 @@ differently ordered unions:
14871486
class A: ...
14881487
class B: ...
14891488

1490-
# TODO: this should pass
1491-
static_assert(is_equivalent_to(A | B | P1, P2 | B | A)) # error: [static-assert-error]
1489+
static_assert(is_equivalent_to(A | B | P1, P2 | B | A))
14921490
```
14931491

14941492
## Narrowing of protocols
@@ -1896,6 +1894,86 @@ if isinstance(obj, (B, A)):
18961894
reveal_type(obj) # revealed: (Unknown & B) | (Unknown & A)
18971895
```
18981896

1897+
### Protocols that use `Self`
1898+
1899+
`Self` is a `TypeVar` with an upper bound of the class in which it is defined. This means that
1900+
`Self` annotations in protocols can also be tricky to handle without infinite recursion and stack
1901+
overflows.
1902+
1903+
```toml
1904+
[environment]
1905+
python-version = "3.12"
1906+
```
1907+
1908+
```py
1909+
from typing_extensions import Protocol, Self
1910+
from ty_extensions import static_assert
1911+
1912+
class _HashObject(Protocol):
1913+
def copy(self) -> Self: ...
1914+
1915+
class Foo: ...
1916+
1917+
# Attempting to build this union caused us to overflow on an early version of
1918+
# <https://github.com/astral-sh/ruff/pull/18659>
1919+
x: Foo | _HashObject
1920+
```
1921+
1922+
Some other similar cases that caused issues in our early `Protocol` implementation:
1923+
1924+
`a.py`:
1925+
1926+
```py
1927+
from typing_extensions import Protocol, Self
1928+
1929+
class PGconn(Protocol):
1930+
def connect(self) -> Self: ...
1931+
1932+
class Connection:
1933+
pgconn: PGconn
1934+
1935+
def is_crdb(conn: PGconn) -> bool:
1936+
return isinstance(conn, Connection)
1937+
```
1938+
1939+
and:
1940+
1941+
`b.py`:
1942+
1943+
```py
1944+
from typing_extensions import Protocol
1945+
1946+
class PGconn(Protocol):
1947+
def connect[T: PGconn](self: T) -> T: ...
1948+
1949+
class Connection:
1950+
pgconn: PGconn
1951+
1952+
def f(x: PGconn):
1953+
isinstance(x, Connection)
1954+
```
1955+
1956+
### Recursive protocols used as the first argument to `cast()`
1957+
1958+
These caused issues in an early version of our `Protocol` implementation due to the fact that we use
1959+
a recursive function in our `cast()` implementation to check whether a type contains `Unknown` or
1960+
`Todo`. Recklessly recursing into a type causes stack overflows if the type is recursive:
1961+
1962+
```toml
1963+
[environment]
1964+
python-version = "3.12"
1965+
```
1966+
1967+
```py
1968+
from typing import cast, Protocol
1969+
1970+
class Iterator[T](Protocol):
1971+
def __iter__(self) -> Iterator[T]: ...
1972+
1973+
def f(value: Iterator):
1974+
cast(Iterator, value) # error: [redundant-cast]
1975+
```
1976+
18991977
## TODO
19001978

19011979
Add tests for:

crates/ty_python_semantic/src/types.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1102,7 +1102,7 @@ impl<'db> Type<'db> {
11021102
Type::Dynamic(_) => Some(CallableType::single(db, Signature::dynamic(self))),
11031103

11041104
Type::FunctionLiteral(function_literal) => {
1105-
Some(function_literal.into_callable_type(db))
1105+
Some(Type::Callable(function_literal.into_callable_type(db)))
11061106
}
11071107
Type::BoundMethod(bound_method) => Some(bound_method.into_callable_type(db)),
11081108

crates/ty_python_semantic/src/types/function.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -767,8 +767,8 @@ impl<'db> FunctionType<'db> {
767767
}
768768

769769
/// Convert the `FunctionType` into a [`Type::Callable`].
770-
pub(crate) fn into_callable_type(self, db: &'db dyn Db) -> Type<'db> {
771-
Type::Callable(CallableType::new(db, self.signature(db), false))
770+
pub(crate) fn into_callable_type(self, db: &'db dyn Db) -> CallableType<'db> {
771+
CallableType::new(db, self.signature(db), false)
772772
}
773773

774774
/// Convert the `FunctionType` into a [`Type::BoundMethod`].

crates/ty_python_semantic/src/types/protocol_class.rs

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -260,7 +260,7 @@ impl<'db> ProtocolMemberData<'db> {
260260

261261
#[derive(Debug, Copy, Clone, PartialEq, Eq, salsa::Update, Hash)]
262262
enum ProtocolMemberKind<'db> {
263-
Method(Type<'db>), // TODO: use CallableType
263+
Method(CallableType<'db>),
264264
Property(PropertyInstanceType<'db>),
265265
Other(Type<'db>),
266266
}
@@ -335,7 +335,7 @@ fn walk_protocol_member<'db, V: super::visitor::TypeVisitor<'db> + ?Sized>(
335335
visitor: &mut V,
336336
) {
337337
match member.kind {
338-
ProtocolMemberKind::Method(method) => visitor.visit_type(db, method),
338+
ProtocolMemberKind::Method(method) => visitor.visit_callable_type(db, method),
339339
ProtocolMemberKind::Property(property) => {
340340
visitor.visit_property_instance_type(db, property);
341341
}
@@ -354,7 +354,7 @@ impl<'a, 'db> ProtocolMember<'a, 'db> {
354354

355355
fn ty(&self) -> Type<'db> {
356356
match &self.kind {
357-
ProtocolMemberKind::Method(callable) => *callable,
357+
ProtocolMemberKind::Method(callable) => Type::Callable(*callable),
358358
ProtocolMemberKind::Property(property) => Type::PropertyInstance(*property),
359359
ProtocolMemberKind::Other(ty) => *ty,
360360
}
@@ -508,13 +508,10 @@ fn cached_protocol_interface<'db>(
508508
(Type::Callable(callable), BoundOnClass::Yes)
509509
if callable.is_function_like(db) =>
510510
{
511-
ProtocolMemberKind::Method(ty)
511+
ProtocolMemberKind::Method(callable)
512512
}
513-
// TODO: method members that have `FunctionLiteral` types should be upcast
514-
// to `CallableType` so that two protocols with identical method members
515-
// are recognized as equivalent.
516-
(Type::FunctionLiteral(_function), BoundOnClass::Yes) => {
517-
ProtocolMemberKind::Method(ty)
513+
(Type::FunctionLiteral(function), BoundOnClass::Yes) => {
514+
ProtocolMemberKind::Method(function.into_callable_type(db))
518515
}
519516
_ => ProtocolMemberKind::Other(ty),
520517
};

0 commit comments

Comments
 (0)