Skip to content

Commit 5a21bea

Browse files
committed
just the api parts
1 parent 9090aea commit 5a21bea

11 files changed

Lines changed: 580 additions & 185 deletions

File tree

crates/ty_python_semantic/src/types.rs

Lines changed: 329 additions & 83 deletions
Large diffs are not rendered by default.

crates/ty_python_semantic/src/types/call/bind.rs

Lines changed: 58 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,9 @@ use crate::types::enums::is_enum_class;
2626
use crate::types::function::{
2727
DataclassTransformerParams, FunctionDecorators, FunctionType, KnownFunction, OverloadLiteral,
2828
};
29-
use crate::types::generics::{Specialization, SpecializationBuilder, SpecializationError};
29+
use crate::types::generics::{
30+
InferableTypeVars, Specialization, SpecializationBuilder, SpecializationError,
31+
};
3032
use crate::types::signatures::{Parameter, ParameterForm, ParameterKind, Parameters};
3133
use crate::types::tuple::{TupleLength, TupleType};
3234
use crate::types::{
@@ -597,7 +599,8 @@ impl<'db> Bindings<'db> {
597599
Type::FunctionLiteral(function_type) => match function_type.known(db) {
598600
Some(KnownFunction::IsEquivalentTo) => {
599601
if let [Some(ty_a), Some(ty_b)] = overload.parameter_types() {
600-
let constraints = ty_a.when_equivalent_to(db, *ty_b);
602+
let constraints =
603+
ty_a.when_equivalent_to(db, *ty_b, InferableTypeVars::None);
601604
let tracked = TrackedConstraintSet::new(db, constraints);
602605
overload.set_return_type(Type::KnownInstance(
603606
KnownInstanceType::ConstraintSet(tracked),
@@ -607,7 +610,8 @@ impl<'db> Bindings<'db> {
607610

608611
Some(KnownFunction::IsSubtypeOf) => {
609612
if let [Some(ty_a), Some(ty_b)] = overload.parameter_types() {
610-
let constraints = ty_a.when_subtype_of(db, *ty_b);
613+
let constraints =
614+
ty_a.when_subtype_of(db, *ty_b, InferableTypeVars::None);
611615
let tracked = TrackedConstraintSet::new(db, constraints);
612616
overload.set_return_type(Type::KnownInstance(
613617
KnownInstanceType::ConstraintSet(tracked),
@@ -617,7 +621,8 @@ impl<'db> Bindings<'db> {
617621

618622
Some(KnownFunction::IsAssignableTo) => {
619623
if let [Some(ty_a), Some(ty_b)] = overload.parameter_types() {
620-
let constraints = ty_a.when_assignable_to(db, *ty_b);
624+
let constraints =
625+
ty_a.when_assignable_to(db, *ty_b, InferableTypeVars::None);
621626
let tracked = TrackedConstraintSet::new(db, constraints);
622627
overload.set_return_type(Type::KnownInstance(
623628
KnownInstanceType::ConstraintSet(tracked),
@@ -627,7 +632,8 @@ impl<'db> Bindings<'db> {
627632

628633
Some(KnownFunction::IsDisjointFrom) => {
629634
if let [Some(ty_a), Some(ty_b)] = overload.parameter_types() {
630-
let constraints = ty_a.when_disjoint_from(db, *ty_b);
635+
let constraints =
636+
ty_a.when_disjoint_from(db, *ty_b, InferableTypeVars::None);
631637
let tracked = TrackedConstraintSet::new(db, constraints);
632638
overload.set_return_type(Type::KnownInstance(
633639
KnownInstanceType::ConstraintSet(tracked),
@@ -1407,7 +1413,10 @@ impl<'db> CallableBinding<'db> {
14071413
let parameter_type = overload.signature.parameters()[*parameter_index]
14081414
.annotated_type()
14091415
.unwrap_or(Type::unknown());
1410-
if argument_type.is_assignable_to(db, parameter_type) {
1416+
if argument_type
1417+
.when_assignable_to(db, parameter_type, overload.inferable_typevars)
1418+
.is_always_satisfied()
1419+
{
14111420
is_argument_assignable_to_any_overload = true;
14121421
break 'overload;
14131422
}
@@ -1633,7 +1642,14 @@ impl<'db> CallableBinding<'db> {
16331642
.unwrap_or(Type::unknown());
16341643
let first_parameter_type = &mut first_parameter_types[parameter_index];
16351644
if let Some(first_parameter_type) = first_parameter_type {
1636-
if !first_parameter_type.is_equivalent_to(db, current_parameter_type) {
1645+
if !first_parameter_type
1646+
.when_equivalent_to(
1647+
db,
1648+
current_parameter_type,
1649+
overload.inferable_typevars,
1650+
)
1651+
.is_always_satisfied()
1652+
{
16371653
participating_parameter_indexes.insert(parameter_index);
16381654
}
16391655
} else {
@@ -1750,7 +1766,12 @@ impl<'db> CallableBinding<'db> {
17501766
matching_overloads.all(|(_, overload)| {
17511767
overload
17521768
.return_type()
1753-
.is_equivalent_to(db, first_overload_return_type)
1769+
.when_equivalent_to(
1770+
db,
1771+
first_overload_return_type,
1772+
overload.inferable_typevars,
1773+
)
1774+
.is_always_satisfied()
17541775
})
17551776
} else {
17561777
// No matching overload
@@ -2461,6 +2482,7 @@ struct ArgumentTypeChecker<'a, 'db> {
24612482
call_expression_tcx: &'a TypeContext<'db>,
24622483
errors: &'a mut Vec<BindingError<'db>>,
24632484

2485+
inferable_typevars: InferableTypeVars<'db, 'db>,
24642486
specialization: Option<Specialization<'db>>,
24652487
}
24662488

@@ -2482,6 +2504,7 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
24822504
parameter_tys,
24832505
call_expression_tcx,
24842506
errors,
2507+
inferable_typevars: InferableTypeVars::None,
24852508
specialization: None,
24862509
}
24872510
}
@@ -2514,11 +2537,12 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
25142537
}
25152538

25162539
fn infer_specialization(&mut self) {
2517-
if self.signature.generic_context.is_none() {
2540+
let Some(generic_context) = self.signature.generic_context else {
25182541
return;
2519-
}
2542+
};
25202543

2521-
let mut builder = SpecializationBuilder::new(self.db);
2544+
// TODO: Use the list of inferable typevars from the generic context of the callable.
2545+
let mut builder = SpecializationBuilder::new(self.db, self.inferable_typevars);
25222546

25232547
// Note that we infer the annotated type _before_ the arguments if this call is part of
25242548
// an annotated assignment, to closer match the order of any unions written in the type
@@ -2563,10 +2587,7 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
25632587
}
25642588
}
25652589

2566-
self.specialization = self
2567-
.signature
2568-
.generic_context
2569-
.map(|gc| builder.build(gc, *self.call_expression_tcx));
2590+
self.specialization = Some(builder.build(generic_context, *self.call_expression_tcx));
25702591
}
25712592

25722593
fn check_argument_type(
@@ -2590,7 +2611,7 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
25902611
// constraint set that we get from this assignability check, instead of inferring and
25912612
// building them in an earlier separate step.
25922613
if argument_type
2593-
.when_assignable_to(self.db, expected_ty)
2614+
.when_assignable_to(self.db, expected_ty, self.inferable_typevars)
25942615
.is_never_satisfied()
25952616
{
25962617
let positional = matches!(argument, Argument::Positional | Argument::Synthetic)
@@ -2719,7 +2740,14 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
27192740
return;
27202741
};
27212742

2722-
if !key_type.is_assignable_to(self.db, KnownClass::Str.to_instance(self.db)) {
2743+
if !key_type
2744+
.when_assignable_to(
2745+
self.db,
2746+
KnownClass::Str.to_instance(self.db),
2747+
self.inferable_typevars,
2748+
)
2749+
.is_always_satisfied()
2750+
{
27232751
self.errors.push(BindingError::InvalidKeyType {
27242752
argument_index: adjusted_argument_index,
27252753
provided_ty: key_type,
@@ -2754,8 +2782,8 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
27542782
}
27552783
}
27562784

2757-
fn finish(self) -> Option<Specialization<'db>> {
2758-
self.specialization
2785+
fn finish(self) -> (InferableTypeVars<'db, 'db>, Option<Specialization<'db>>) {
2786+
(self.inferable_typevars, self.specialization)
27592787
}
27602788
}
27612789

@@ -2819,6 +2847,9 @@ pub(crate) struct Binding<'db> {
28192847
/// Return type of the call.
28202848
return_ty: Type<'db>,
28212849

2850+
/// The inferable typevars in this signature.
2851+
inferable_typevars: InferableTypeVars<'db, 'db>,
2852+
28222853
/// The specialization that was inferred from the argument types, if the callable is generic.
28232854
specialization: Option<Specialization<'db>>,
28242855

@@ -2845,6 +2876,7 @@ impl<'db> Binding<'db> {
28452876
callable_type: signature_type,
28462877
signature_type,
28472878
return_ty: Type::unknown(),
2879+
inferable_typevars: InferableTypeVars::None,
28482880
specialization: None,
28492881
argument_matches: Box::from([]),
28502882
variadic_argument_matched_to_variadic_parameter: false,
@@ -2916,7 +2948,7 @@ impl<'db> Binding<'db> {
29162948
checker.infer_specialization();
29172949

29182950
checker.check_argument_types();
2919-
self.specialization = checker.finish();
2951+
(self.inferable_typevars, self.specialization) = checker.finish();
29202952
if let Some(specialization) = self.specialization {
29212953
self.return_ty = self.return_ty.apply_specialization(db, specialization);
29222954
}
@@ -3010,6 +3042,7 @@ impl<'db> Binding<'db> {
30103042
fn snapshot(&self) -> BindingSnapshot<'db> {
30113043
BindingSnapshot {
30123044
return_ty: self.return_ty,
3045+
inferable_typevars: self.inferable_typevars,
30133046
specialization: self.specialization,
30143047
argument_matches: self.argument_matches.clone(),
30153048
parameter_tys: self.parameter_tys.clone(),
@@ -3020,13 +3053,15 @@ impl<'db> Binding<'db> {
30203053
fn restore(&mut self, snapshot: BindingSnapshot<'db>) {
30213054
let BindingSnapshot {
30223055
return_ty,
3056+
inferable_typevars,
30233057
specialization,
30243058
argument_matches,
30253059
parameter_tys,
30263060
errors,
30273061
} = snapshot;
30283062

30293063
self.return_ty = return_ty;
3064+
self.inferable_typevars = inferable_typevars;
30303065
self.specialization = specialization;
30313066
self.argument_matches = argument_matches;
30323067
self.parameter_tys = parameter_tys;
@@ -3046,6 +3081,7 @@ impl<'db> Binding<'db> {
30463081
/// Resets the state of this binding to its initial state.
30473082
fn reset(&mut self) {
30483083
self.return_ty = Type::unknown();
3084+
self.inferable_typevars = InferableTypeVars::None;
30493085
self.specialization = None;
30503086
self.argument_matches = Box::from([]);
30513087
self.parameter_tys = Box::from([]);
@@ -3056,6 +3092,7 @@ impl<'db> Binding<'db> {
30563092
#[derive(Clone, Debug)]
30573093
struct BindingSnapshot<'db> {
30583094
return_ty: Type<'db>,
3095+
inferable_typevars: InferableTypeVars<'db, 'db>,
30593096
specialization: Option<Specialization<'db>>,
30603097
argument_matches: Box<[MatchedArgument<'db>]>,
30613098
parameter_tys: Box<[Option<Type<'db>>]>,
@@ -3095,6 +3132,7 @@ impl<'db> CallableBindingSnapshot<'db> {
30953132

30963133
// ... and update the snapshot with the current state of the binding.
30973134
snapshot.return_ty = binding.return_ty;
3135+
snapshot.inferable_typevars = binding.inferable_typevars;
30983136
snapshot.specialization = binding.specialization;
30993137
snapshot
31003138
.argument_matches

crates/ty_python_semantic/src/types/class.rs

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ use crate::types::diagnostic::INVALID_TYPE_ALIAS_TYPE;
2222
use crate::types::enums::enum_metadata;
2323
use crate::types::function::{DataclassTransformerParams, KnownFunction};
2424
use crate::types::generics::{
25-
GenericContext, Specialization, walk_generic_context, walk_specialization,
25+
GenericContext, InferableTypeVars, Specialization, walk_generic_context, walk_specialization,
2626
};
2727
use crate::types::infer::nearest_enclosing_class;
2828
use crate::types::member::{Member, class_member};
@@ -540,17 +540,20 @@ impl<'db> ClassType<'db> {
540540

541541
/// Return `true` if `other` is present in this class's MRO.
542542
pub(super) fn is_subclass_of(self, db: &'db dyn Db, other: ClassType<'db>) -> bool {
543-
self.when_subclass_of(db, other).is_always_satisfied()
543+
self.when_subclass_of(db, other, InferableTypeVars::None)
544+
.is_always_satisfied()
544545
}
545546

546547
pub(super) fn when_subclass_of(
547548
self,
548549
db: &'db dyn Db,
549550
other: ClassType<'db>,
551+
inferable: InferableTypeVars<'_, 'db>,
550552
) -> ConstraintSet<'db> {
551553
self.has_relation_to_impl(
552554
db,
553555
other,
556+
inferable,
554557
TypeRelation::Subtyping,
555558
&HasRelationToVisitor::default(),
556559
&IsDisjointVisitor::default(),
@@ -561,6 +564,7 @@ impl<'db> ClassType<'db> {
561564
self,
562565
db: &'db dyn Db,
563566
other: Self,
567+
inferable: InferableTypeVars<'_, 'db>,
564568
relation: TypeRelation,
565569
relation_visitor: &HasRelationToVisitor<'db>,
566570
disjointness_visitor: &IsDisjointVisitor<'db>,
@@ -586,6 +590,7 @@ impl<'db> ClassType<'db> {
586590
base.specialization(db).has_relation_to_impl(
587591
db,
588592
other.specialization(db),
593+
inferable,
589594
relation,
590595
relation_visitor,
591596
disjointness_visitor,
@@ -610,6 +615,7 @@ impl<'db> ClassType<'db> {
610615
self,
611616
db: &'db dyn Db,
612617
other: ClassType<'db>,
618+
inferable: InferableTypeVars<'_, 'db>,
613619
visitor: &IsEquivalentVisitor<'db>,
614620
) -> ConstraintSet<'db> {
615621
if self == other {
@@ -628,6 +634,7 @@ impl<'db> ClassType<'db> {
628634
this.specialization(db).is_equivalent_to_impl(
629635
db,
630636
other.specialization(db),
637+
inferable,
631638
visitor,
632639
)
633640
})

0 commit comments

Comments
 (0)