Skip to content

Commit 3bfdb12

Browse files
authored
Revert "[ty] Completely remove the NoReturn shortcut optimization" (#23955)
## Summary This reverts commit 85a8516.
1 parent 85a8516 commit 3bfdb12

5 files changed

Lines changed: 111 additions & 34 deletions

File tree

crates/ty_python_semantic/resources/mdtest/exhaustiveness_checking.md

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -485,7 +485,5 @@ def i[T: (int, str)](x: T) -> T:
485485
case _:
486486
assert_never(x)
487487

488-
# TODO: no error here
489-
# error: [invalid-return-type] "Return type does not match returned value: expected `T@i`, found `str | int`"
490488
return x
491489
```

crates/ty_python_semantic/src/semantic_index/builder.rs

Lines changed: 37 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,8 @@ use crate::semantic_index::expression::{Expression, ExpressionKind};
3535
use crate::semantic_index::member::MemberExprBuilder;
3636
use crate::semantic_index::place::{PlaceExpr, PlaceTableBuilder, ScopedPlaceId};
3737
use crate::semantic_index::predicate::{
38-
ClassPatternKind, PatternPredicate, PatternPredicateKind, Predicate, PredicateNode,
39-
PredicateOrLiteral, ScopedPredicateId, StarImportPlaceholderPredicate,
38+
CallableAndCallExpr, ClassPatternKind, PatternPredicate, PatternPredicateKind, Predicate,
39+
PredicateNode, PredicateOrLiteral, ScopedPredicateId, StarImportPlaceholderPredicate,
4040
};
4141
use crate::semantic_index::re_exports::exported_names;
4242
use crate::semantic_index::reachability_constraints::{
@@ -2784,29 +2784,44 @@ impl<'ast> Visitor<'ast> for SemanticIndexBuilder<'_, 'ast> {
27842784
// We also only add these inside function scopes, since considering module-level
27852785
// constraints can affect the type of imported symbols, leading to a lot more
27862786
// work in third-party code.
2787-
let is_call = match value.as_ref() {
2788-
ast::Expr::Call(_) => true,
2789-
ast::Expr::Await(ast::ExprAwait { value: inner, .. }) => inner.is_call_expr(),
2790-
_ => false,
2787+
let call_info = match value.as_ref() {
2788+
ast::Expr::Call(ast::ExprCall { func, .. }) => {
2789+
Some((func.as_ref(), value.as_ref(), false))
2790+
}
2791+
ast::Expr::Await(ast::ExprAwait { value: inner, .. }) => match inner.as_ref() {
2792+
ast::Expr::Call(ast::ExprCall { func, .. }) => {
2793+
Some((func.as_ref(), value.as_ref(), true))
2794+
}
2795+
_ => None,
2796+
},
2797+
_ => None,
27912798
};
27922799

2793-
if is_call && !self.source_type.is_stub() && self.in_function_scope() {
2794-
let call_expr = self.add_standalone_expression(value.as_ref());
2800+
if let Some((func, expr, is_await)) = call_info {
2801+
if !self.source_type.is_stub() && self.in_function_scope() {
2802+
let callable = self.add_standalone_expression(func);
2803+
let call_expr = self.add_standalone_expression(expr);
2804+
2805+
let predicate = Predicate {
2806+
node: PredicateNode::ReturnsNever(CallableAndCallExpr {
2807+
callable,
2808+
call_expr,
2809+
is_await,
2810+
}),
2811+
is_positive: false,
2812+
};
2813+
let constraint = self.record_reachability_constraint(
2814+
PredicateOrLiteral::Predicate(predicate),
2815+
);
27952816

2796-
let predicate = Predicate {
2797-
node: PredicateNode::ReturnsNever(call_expr),
2798-
is_positive: false,
2799-
};
2800-
let constraint = self
2801-
.record_reachability_constraint(PredicateOrLiteral::Predicate(predicate));
2802-
2803-
// Also gate narrowing by this constraint: if the call returns
2804-
// `Never`, any narrowing in the current branch should be
2805-
// invalidated (since this path is unreachable). This enables
2806-
// narrowing to be preserved after if-statements where one branch
2807-
// calls a `NoReturn` function like `sys.exit()`.
2808-
self.current_use_def_map_mut()
2809-
.record_narrowing_constraint_for_all_places(constraint);
2817+
// Also gate narrowing by this constraint: if the call returns
2818+
// `Never`, any narrowing in the current branch should be
2819+
// invalidated (since this path is unreachable). This enables
2820+
// narrowing to be preserved after if-statements where one branch
2821+
// calls a `NoReturn` function like `sys.exit()`.
2822+
self.current_use_def_map_mut()
2823+
.record_narrowing_constraint_for_all_places(constraint);
2824+
}
28102825
}
28112826
}
28122827
_ => {

crates/ty_python_semantic/src/semantic_index/predicate.rs

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,10 +98,20 @@ impl PredicateOrLiteral<'_> {
9898
}
9999
}
100100

101+
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, salsa::Update, get_size2::GetSize)]
102+
pub(crate) struct CallableAndCallExpr<'db> {
103+
pub(crate) callable: Expression<'db>,
104+
pub(crate) call_expr: Expression<'db>,
105+
/// Whether the call is wrapped in an `await` expression. If `true`, `call_expr` refers to the
106+
/// `await` expression rather than the call itself. This is used to detect terminal `await`s of
107+
/// async functions that return `Never`.
108+
pub(crate) is_await: bool,
109+
}
110+
101111
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, salsa::Update, get_size2::GetSize)]
102112
pub(crate) enum PredicateNode<'db> {
103113
Expression(Expression<'db>),
104-
ReturnsNever(Expression<'db>),
114+
ReturnsNever(CallableAndCallExpr<'db>),
105115
Pattern(PatternPredicate<'db>),
106116
StarImportPlaceholder(StarImportPlaceholderPredicate<'db>),
107117
}

crates/ty_python_semantic/src/semantic_index/reachability_constraints.rs

Lines changed: 58 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -205,11 +205,12 @@ use crate::rank::RankBitBox;
205205
use crate::semantic_index::place::ScopedPlaceId;
206206
use crate::semantic_index::place_table;
207207
use crate::semantic_index::predicate::{
208-
PatternPredicate, PatternPredicateKind, Predicate, PredicateNode, Predicates, ScopedPredicateId,
208+
CallableAndCallExpr, PatternPredicate, PatternPredicateKind, Predicate, PredicateNode,
209+
Predicates, ScopedPredicateId,
209210
};
210211
use crate::types::{
211-
IntersectionBuilder, KnownClass, NarrowingConstraint, Truthiness, Type, TypeContext,
212-
UnionBuilder, UnionType, infer_expression_type, infer_narrowing_constraint,
212+
CallableTypes, IntersectionBuilder, KnownClass, NarrowingConstraint, Truthiness, Type,
213+
TypeContext, UnionBuilder, UnionType, infer_expression_type, infer_narrowing_constraint,
213214
};
214215

215216
/// A ternary formula that defines under what conditions a binding is visible. (A ternary formula
@@ -1089,12 +1090,62 @@ impl ReachabilityConstraints {
10891090
.bool(db)
10901091
.negate_if(!predicate.is_positive)
10911092
}
1092-
PredicateNode::ReturnsNever(call_expr) => {
1093-
let call_expr_ty = infer_expression_type(db, call_expr, TypeContext::default());
1094-
if call_expr_ty.is_equivalent_to(db, Type::Never) {
1095-
Truthiness::AlwaysTrue
1093+
PredicateNode::ReturnsNever(CallableAndCallExpr {
1094+
callable,
1095+
call_expr,
1096+
is_await,
1097+
}) => {
1098+
// We first infer just the type of the callable. In the most likely case that the
1099+
// function is not marked with `NoReturn`, or that it always returns `NoReturn`,
1100+
// doing so allows us to avoid the more expensive work of inferring the entire call
1101+
// expression (which could involve inferring argument types to possibly run the overload
1102+
// selection algorithm).
1103+
// Avoiding this on the happy-path is important because these constraints can be
1104+
// very large in number, since we add them on all statement level function calls.
1105+
let ty = infer_expression_type(db, callable, TypeContext::default());
1106+
1107+
// Short-circuit for well known types that are known not to return `Never` when called.
1108+
// Without the short-circuit, we've seen that threads keep blocking each other
1109+
// because they all try to acquire Salsa's `CallableType` lock that ensures each type
1110+
// is only interned once. The lock is so heavily congested because there are only
1111+
// very few dynamic types, in which case Salsa's sharding the locks by value
1112+
// doesn't help much.
1113+
// See <https://github.com/astral-sh/ty/issues/968>.
1114+
if matches!(ty, Type::Dynamic(_)) {
1115+
return Truthiness::AlwaysFalse.negate_if(!predicate.is_positive);
1116+
}
1117+
1118+
let overloads_iterator = if let Some(callable) = ty
1119+
.try_upcast_to_callable(db)
1120+
.and_then(CallableTypes::exactly_one)
1121+
{
1122+
callable.signatures(db).overloads.iter()
10961123
} else {
1124+
return Truthiness::AlwaysFalse.negate_if(!predicate.is_positive);
1125+
};
1126+
1127+
let mut no_overloads_return_never = true;
1128+
let mut all_overloads_return_never = true;
1129+
let mut any_overload_is_generic = false;
1130+
1131+
for overload in overloads_iterator {
1132+
let returns_never = overload.return_ty.is_equivalent_to(db, Type::Never);
1133+
no_overloads_return_never &= !returns_never;
1134+
all_overloads_return_never &= returns_never;
1135+
any_overload_is_generic |= overload.return_ty.has_typevar(db);
1136+
}
1137+
1138+
if no_overloads_return_never && !any_overload_is_generic && !is_await {
10971139
Truthiness::AlwaysFalse
1140+
} else if all_overloads_return_never {
1141+
Truthiness::AlwaysTrue
1142+
} else {
1143+
let call_expr_ty = infer_expression_type(db, call_expr, TypeContext::default());
1144+
if call_expr_ty.is_equivalent_to(db, Type::Never) {
1145+
Truthiness::AlwaysTrue
1146+
} else {
1147+
Truthiness::AlwaysFalse
1148+
}
10981149
}
10991150
.negate_if(!predicate.is_positive)
11001151
}

crates/ty_python_semantic/src/types/narrow.rs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@ use crate::semantic_index::expression::Expression;
33
use crate::semantic_index::place::{PlaceExpr, PlaceTable, PlaceTableBuilder, ScopedPlaceId};
44
use crate::semantic_index::place_table;
55
use crate::semantic_index::predicate::{
6-
ClassPatternKind, PatternPredicate, PatternPredicateKind, Predicate, PredicateNode,
6+
CallableAndCallExpr, ClassPatternKind, PatternPredicate, PatternPredicateKind, Predicate,
7+
PredicateNode,
78
};
89
use crate::semantic_index::scope::ScopeId;
910
use crate::subscript::PyIndex;
@@ -761,7 +762,9 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
761762
match self.predicate {
762763
PredicateNode::Expression(expression) => expression.scope(self.db),
763764
PredicateNode::Pattern(pattern) => pattern.scope(self.db),
764-
PredicateNode::ReturnsNever(call_expr) => call_expr.scope(self.db),
765+
PredicateNode::ReturnsNever(CallableAndCallExpr { callable, .. }) => {
766+
callable.scope(self.db)
767+
}
765768
PredicateNode::StarImportPlaceholder(definition) => definition.scope(self.db),
766769
}
767770
}

0 commit comments

Comments
 (0)