Merged
86 changes: 82 additions & 4 deletions crates/ty_python_semantic/resources/mdtest/protocols.md
@@ -1476,8 +1476,7 @@ class P1(Protocol):
class P2(Protocol):
    def x(self, y: int) -> None: ...

# TODO: this should pass
static_assert(is_equivalent_to(P1, P2)) # error: [static-assert-error]
static_assert(is_equivalent_to(P1, P2))
```

As with protocols that only have non-method members, this also holds true when they appear in
@@ -1487,8 +1486,7 @@ differently ordered unions:
class A: ...
class B: ...

# TODO: this should pass
static_assert(is_equivalent_to(A | B | P1, P2 | B | A)) # error: [static-assert-error]
static_assert(is_equivalent_to(A | B | P1, P2 | B | A))
```

## Narrowing of protocols
@@ -1896,6 +1894,86 @@ if isinstance(obj, (B, A)):
    reveal_type(obj) # revealed: (Unknown & B) | (Unknown & A)
```

### Protocols that use `Self`

`Self` is a `TypeVar` with an upper bound of the class in which it is defined. This means that
`Self` annotations in protocols can also be tricky to handle without infinite recursion and stack
overflows.

```toml
[environment]
python-version = "3.12"
```

```py
from typing_extensions import Protocol, Self
from ty_extensions import static_assert

class _HashObject(Protocol):
    def copy(self) -> Self: ...

class Foo: ...

# Attempting to build this union caused us to overflow on an early version of
# <https://github.com/astral-sh/ruff/pull/18659>
x: Foo | _HashObject
```

Some other similar cases that caused issues in our early `Protocol` implementation:

`a.py`:

```py
from typing_extensions import Protocol, Self

class PGconn(Protocol):
    def connect(self) -> Self: ...

class Connection:
    pgconn: PGconn

def is_crdb(conn: PGconn) -> bool:
    return isinstance(conn, Connection)
```

and:

`b.py`:

```py
from typing_extensions import Protocol

class PGconn(Protocol):
    def connect[T: PGconn](self: T) -> T: ...

class Connection:
    pgconn: PGconn

def f(x: PGconn):
    isinstance(x, Connection)
```

### Recursive protocols used as the first argument to `cast()`

These caused issues in an early version of our `Protocol` implementation due to the fact that we use
a recursive function in our `cast()` implementation to check whether a type contains `Unknown` or
`Todo`. Recklessly recursing into a type causes stack overflows if the type is recursive:

```toml
[environment]
python-version = "3.12"
```

```py
from typing import cast, Protocol

class Iterator[T](Protocol):
    def __iter__(self) -> Iterator[T]: ...

def f(value: Iterator):
    cast(Iterator, value) # error: [redundant-cast]
```

## TODO

Add tests for:
@@ -300,6 +300,20 @@ static_assert(not is_equivalent_to(CallableTypeOf[f12], CallableTypeOf[f13]))
static_assert(not is_equivalent_to(CallableTypeOf[f13], CallableTypeOf[f12]))
```

### Unions containing `Callable`s

Two unions containing different `Callable` types are equivalent even if the unions are differently
ordered:

```py
from ty_extensions import CallableTypeOf, Unknown, is_equivalent_to, static_assert

def f(x): ...
def g(x: Unknown): ...

static_assert(is_equivalent_to(CallableTypeOf[f] | int | str, str | int | CallableTypeOf[g]))
```

### Unions containing `Callable`s containing unions

Differently ordered unions inside `Callable`s inside unions can still be equivalent:
6 changes: 5 additions & 1 deletion crates/ty_python_semantic/src/types.rs
@@ -1102,7 +1102,7 @@ impl<'db> Type<'db> {
Type::Dynamic(_) => Some(CallableType::single(db, Signature::dynamic(self))),

Type::FunctionLiteral(function_literal) => {
Some(function_literal.into_callable_type(db))
Some(Type::Callable(function_literal.into_callable_type(db)))
}
Type::BoundMethod(bound_method) => Some(bound_method.into_callable_type(db)),

@@ -7336,6 +7336,10 @@ impl<'db> CallableType<'db> {
///
/// See [`Type::is_equivalent_to`] for more details.
fn is_equivalent_to(self, db: &'db dyn Db, other: Self) -> bool {
if self == other {
return true;
}

self.is_function_like(db) == other.is_function_like(db)
&& self
.signatures(db)
22 changes: 21 additions & 1 deletion crates/ty_python_semantic/src/types/cyclic.rs
@AlexWaygood (Member, Author) commented on Jul 5, 2025:

On earlier versions of this PR, type-checking `DateType` never finished because we were calling `Type::normalized_impl` on the same types over and over again. Caching the results within any one call to `Type::normalized()` fixes this.
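
For readers who want to see the effect in isolation, here is a minimal Python sketch of the same idea. It is not ty's code: the `CycleDetector` class only loosely mirrors the Rust struct changed in this file, and the `depth` helper, the `computed` counter, and the toy graphs are hypothetical names made up for illustration. `seen` tracks the items on the current call chain (a hit means a cycle, so the fallback is returned), while `cache` keeps finished results so that a repeated, non-cyclic visit does not redo the work.

```python
from typing import Callable, Generic, Hashable, TypeVar

T = TypeVar("T", bound=Hashable)
R = TypeVar("R")


class CycleDetector(Generic[T, R]):
    """Illustrative stand-in for the Rust struct; not ty's actual code."""

    def __init__(self, fallback: R) -> None:
        self.seen: set[T] = set()    # items on the current call chain
        self.cache: dict[T, R] = {}  # finished results; never evicted
        self.fallback = fallback
        self.computed = 0            # how often `func` really ran (demo only)

    def visit(self, item: T, func: Callable[["CycleDetector[T, R]"], R]) -> R:
        if item in self.seen:
            # We are already visiting `item` further up the stack: a cycle.
            return self.fallback
        if item in self.cache:
            # Visited earlier in this traversal, but not a cycle.
            return self.cache[item]
        self.seen.add(item)
        self.computed += 1
        ret = func(self)
        self.seen.discard(item)
        self.cache[item] = ret
        return ret


def depth(graph: dict, node, detector: CycleDetector) -> int:
    """Hypothetical recursive operation: longest path starting at `node`."""
    return detector.visit(
        node,
        lambda d: 1 + max((depth(graph, child, d) for child in graph[node]), default=0),
    )


# Every node refers to the next node twice. Without the cache this needs
# roughly 2**30 visits; with it, each of the 31 nodes is computed once.
layered = {i: [i + 1, i + 1] for i in range(30)}
layered[30] = []
layered_detector = CycleDetector(fallback=0)
layered_depth = depth(layered, 0, layered_detector)

# A genuinely recursive structure still terminates via the fallback value.
cyclic = {"a": ["b"], "b": ["a"]}
cyclic_depth = depth(cyclic, "a", CycleDetector(fallback=0))
```

The sketch checks `seen` before `cache` and only then marks the item as in progress, whereas the Rust code inserts into `seen` first and pops on a cache hit; the observable behaviour is the same for this demo.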

@@ -1,3 +1,5 @@
use rustc_hash::FxHashMap;

use crate::FxIndexSet;
use crate::types::Type;
use std::cmp::Eq;
@@ -19,14 +21,27 @@ pub(crate) type PairVisitor<'db> = CycleDetector<(Type<'db>, Type<'db>), bool>;

#[derive(Debug)]
pub(crate) struct CycleDetector<T, R> {
/// If the type we're visiting is present in `seen`,
/// it indicates that we've hit a cycle (due to a recursive type);
/// we need to immediately short circuit the whole operation and return the fallback value.
/// That's why we pop items off the end of `seen` after we've visited them.
seen: FxIndexSet<T>,

/// Unlike `seen`, this field is a pure performance optimisation (and an essential one).
/// If the type we're trying to normalize is present in `cache`, it doesn't necessarily mean we've hit a cycle:
/// it just means that we've already visited this inner type as part of a bigger call chain we're currently in.
/// Since this cache is just a performance optimisation, it doesn't make sense to pop items off the end of the
/// cache after they've been visited (it would sort-of defeat the point of a cache if we did!)
cache: FxHashMap<T, R>,

fallback: R,
}

impl<T: Hash + Eq, R: Copy> CycleDetector<T, R> {
impl<T: Hash + Eq + Copy, R: Copy> CycleDetector<T, R> {
pub(crate) fn new(fallback: R) -> Self {
CycleDetector {
seen: FxIndexSet::default(),
cache: FxHashMap::default(),
fallback,
}
}
@@ -35,7 +50,12 @@ impl<T: Hash + Eq, R: Copy> CycleDetector<T, R> {
if !self.seen.insert(item) {
return self.fallback;
}
if let Some(ty) = self.cache.get(&item) {
self.seen.pop();
return *ty;
}
let ret = func(self);
self.cache.insert(item, ret);
self.seen.pop();
ret
}
4 changes: 2 additions & 2 deletions crates/ty_python_semantic/src/types/function.rs
@@ -767,8 +767,8 @@ impl<'db> FunctionType<'db> {
}

/// Convert the `FunctionType` into a [`Type::Callable`].
pub(crate) fn into_callable_type(self, db: &'db dyn Db) -> Type<'db> {
Type::Callable(CallableType::new(db, self.signature(db), false))
pub(crate) fn into_callable_type(self, db: &'db dyn Db) -> CallableType<'db> {
CallableType::new(db, self.signature(db), false)
}

/// Convert the `FunctionType` into a [`Type::BoundMethod`].
9 changes: 8 additions & 1 deletion crates/ty_python_semantic/src/types/instance.rs
@@ -270,7 +270,14 @@ impl<'db> ProtocolInstanceType<'db> {
///
/// TODO: consider the types of the members as well as their existence
pub(super) fn is_equivalent_to(self, db: &'db dyn Db, other: Self) -> bool {
self.normalized(db) == other.normalized(db)
if self == other {
return true;
}
let self_normalized = self.normalized(db);
if self_normalized == Type::ProtocolInstance(other) {
return true;
}
self_normalized == other.normalized(db)
}

/// Return `true` if this protocol type is disjoint from the protocol `other`.
15 changes: 6 additions & 9 deletions crates/ty_python_semantic/src/types/protocol_class.rs
@@ -260,7 +260,7 @@ impl<'db> ProtocolMemberData<'db> {

#[derive(Debug, Copy, Clone, PartialEq, Eq, salsa::Update, Hash)]
enum ProtocolMemberKind<'db> {
Method(Type<'db>), // TODO: use CallableType
Method(CallableType<'db>),
Property(PropertyInstanceType<'db>),
Other(Type<'db>),
}
@@ -335,7 +335,7 @@ fn walk_protocol_member<'db, V: super::visitor::TypeVisitor<'db> + ?Sized>(
visitor: &mut V,
) {
match member.kind {
ProtocolMemberKind::Method(method) => visitor.visit_type(db, method),
ProtocolMemberKind::Method(method) => visitor.visit_callable_type(db, method),
ProtocolMemberKind::Property(property) => {
visitor.visit_property_instance_type(db, property);
}
@@ -354,7 +354,7 @@ impl<'a, 'db> ProtocolMember<'a, 'db> {

fn ty(&self) -> Type<'db> {
match &self.kind {
ProtocolMemberKind::Method(callable) => *callable,
ProtocolMemberKind::Method(callable) => Type::Callable(*callable),
ProtocolMemberKind::Property(property) => Type::PropertyInstance(*property),
ProtocolMemberKind::Other(ty) => *ty,
}
@@ -508,13 +508,10 @@ fn cached_protocol_interface<'db>(
(Type::Callable(callable), BoundOnClass::Yes)
if callable.is_function_like(db) =>
{
ProtocolMemberKind::Method(ty)
ProtocolMemberKind::Method(callable)
}
// TODO: method members that have `FunctionLiteral` types should be upcast
// to `CallableType` so that two protocols with identical method members
// are recognized as equivalent.
(Type::FunctionLiteral(_function), BoundOnClass::Yes) => {
ProtocolMemberKind::Method(ty)
(Type::FunctionLiteral(function), BoundOnClass::Yes) => {
ProtocolMemberKind::Method(function.into_callable_type(db))
}
_ => ProtocolMemberKind::Other(ty),
};
11 changes: 8 additions & 3 deletions crates/ty_python_semantic/src/types/signatures.rs
@@ -1318,8 +1318,13 @@ impl<'db> Parameter<'db> {
form,
} = self;

// Ensure unions and intersections are ordered in the annotated type (if there is one)
let annotated_type = annotated_type.map(|ty| ty.normalized_impl(db, visitor));
// Ensure unions and intersections are ordered in the annotated type (if there is one).
// Ensure that a parameter without an annotation is treated equivalently to a parameter
// with a dynamic type as its annotation. (We must use `Any` here as all dynamic types
// normalize to `Any`.)
let annotated_type = annotated_type
.map(|ty| ty.normalized_impl(db, visitor))
.unwrap_or_else(Type::any);

// Ensure that parameter names are stripped from positional-only, variadic and keyword-variadic parameters.
// Ensure that we only record whether a parameter *has* a default
@@ -1351,7 +1356,7 @@ impl<'db> Parameter<'db> {
};

Self {
annotated_type,
annotated_type: Some(annotated_type),
kind,
form: *form,
}
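
The normalization rule in this last hunk (a parameter without an annotation is treated like one annotated with a dynamic type) has a rough runtime analogue that can be tried with the standard library. The sketch below is only an analogy and says nothing about how ty represents signatures internally; `normalized_signature` is a hypothetical helper, and `f` and `g` simply echo the functions used in the new mdtest above, with `Any` standing in for `Unknown`.

```python
import inspect
from typing import Any


def normalized_signature(func) -> inspect.Signature:
    """Return func's signature with every missing annotation replaced by Any."""
    sig = inspect.signature(func)
    params = [
        p.replace(annotation=Any) if p.annotation is inspect.Parameter.empty else p
        for p in sig.parameters.values()
    ]
    return sig.replace(parameters=params)


def f(x): ...
def g(x: Any): ...


# Compared as written, the two signatures differ in the annotation of `x`.
raw_equal = inspect.signature(f) == inspect.signature(g)

# After filling in the missing annotation, they compare equal.
normalized_equal = normalized_signature(f) == normalized_signature(g)
```

Here `raw_equal` is `False` and `normalized_equal` is `True`, which is the same shape of result that the `CallableTypeOf[f]` and `CallableTypeOf[g]` test relies on.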