Skip to content

Commit 8bf55be

Browse files
committed
fix diagnostic GenericConstraintMismatch
1 parent 084e30b commit 8bf55be

File tree

2 files changed

+203
-7
lines changed

2 files changed

+203
-7
lines changed

crates/emmylua_code_analysis/src/diagnostic/checker/generic/generic_constraint_mismatch.rs

Lines changed: 143 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
1-
use emmylua_parser::{LuaAst, LuaAstNode, LuaAstToken, LuaCallExpr, LuaDocTagType};
1+
use std::ops::Deref;
2+
3+
use emmylua_parser::{LuaAst, LuaAstNode, LuaAstToken, LuaCallExpr, LuaDocTagType, LuaExpr};
24
use rowan::TextRange;
35

46
use crate::diagnostic::checker::generic::infer_doc_type::infer_doc_type;
57
use crate::diagnostic::checker::param_type_check::get_call_source_type;
68
use crate::{
7-
humanize_type, DiagnosticCode, GenericTplId, LuaMemberOwner, LuaSemanticDeclId, LuaSignature,
8-
LuaStringTplType, LuaType, RenderLevel, SemanticDeclLevel, SemanticModel, TypeCheckFailReason,
9-
TypeCheckResult,
9+
humanize_type, DiagnosticCode, GenericTplId, LuaDeclExtra, LuaMemberOwner, LuaSemanticDeclId,
10+
LuaSignature, LuaStringTplType, LuaType, RenderLevel, SemanticDeclLevel, SemanticModel,
11+
TypeCheckFailReason, TypeCheckResult, TypeOps, VariadicType,
1012
};
1113

1214
use crate::diagnostic::checker::Checker;
@@ -83,9 +85,8 @@ fn check_call_expr(
8385
.get_signature_index()
8486
.get(&signature_id)?;
8587
let mut params = signature.get_type_params();
88+
let mut arg_infos = get_arg_infos(semantic_model, &call_expr)?;
8689

87-
let arg_exprs = call_expr.get_args_list()?.get_args().collect::<Vec<_>>();
88-
let mut arg_infos = semantic_model.infer_expr_list_types(&arg_exprs, None);
8990
match (call_expr.is_colon_call(), signature.is_colon_define) {
9091
(true, true) | (false, false) => {}
9192
(false, true) => {
@@ -102,7 +103,6 @@ fn check_call_expr(
102103
);
103104
}
104105
}
105-
106106
for (i, (_, param_type)) in params.iter().enumerate() {
107107
let param_type = if let Some(param_type) = param_type {
108108
param_type
@@ -344,3 +344,139 @@ fn add_type_check_diagnostic(
344344
}
345345
}
346346
}
347+
348+
fn get_arg_infos(
349+
semantic_model: &SemanticModel,
350+
call_expr: &LuaCallExpr,
351+
) -> Option<Vec<(LuaType, TextRange)>> {
352+
let arg_exprs = call_expr.get_args_list()?.get_args().collect::<Vec<_>>();
353+
let mut arg_infos = infer_expr_list_types(semantic_model, &arg_exprs);
354+
for (arg_type, arg_expr) in arg_infos.iter_mut() {
355+
let extend_type = try_instantiate_arg_type(semantic_model, arg_type, arg_expr, 0);
356+
if let Some(extend_type) = extend_type {
357+
*arg_type = extend_type;
358+
}
359+
}
360+
361+
let arg_infos = arg_infos
362+
.into_iter()
363+
.map(|(arg_type, arg_expr)| (arg_type, arg_expr.get_range()))
364+
.collect();
365+
366+
Some(arg_infos)
367+
}
368+
369+
fn try_instantiate_arg_type(
370+
semantic_model: &SemanticModel,
371+
arg_type: &LuaType,
372+
arg_expr: &LuaExpr,
373+
depth: usize,
374+
) -> Option<LuaType> {
375+
match arg_type {
376+
LuaType::TplRef(tpl_ref) => {
377+
let node_or_token = arg_expr.syntax().clone().into();
378+
let semantic_decl =
379+
semantic_model.find_decl(node_or_token, SemanticDeclLevel::default())?;
380+
match tpl_ref.get_tpl_id() {
381+
GenericTplId::Func(tpl_id) => {
382+
if let LuaSemanticDeclId::LuaDecl(decl_id) = semantic_decl {
383+
let decl = semantic_model
384+
.get_db()
385+
.get_decl_index()
386+
.get_decl(&decl_id)?;
387+
match decl.extra {
388+
LuaDeclExtra::Param { signature_id, .. } => {
389+
let signature = semantic_model
390+
.get_db()
391+
.get_signature_index()
392+
.get(&signature_id)?;
393+
if let Some((_, param_type)) =
394+
signature.generic_params.get(tpl_id as usize)
395+
{
396+
return param_type.clone();
397+
}
398+
}
399+
_ => return None,
400+
}
401+
}
402+
None
403+
}
404+
GenericTplId::Type(tpl_id) => {
405+
if let LuaSemanticDeclId::LuaDecl(decl_id) = semantic_decl {
406+
let decl = semantic_model
407+
.get_db()
408+
.get_decl_index()
409+
.get_decl(&decl_id)?;
410+
match decl.extra {
411+
LuaDeclExtra::Param {
412+
owner_member_id, ..
413+
} => {
414+
let owner_member_id = owner_member_id?;
415+
let parent_owner = semantic_model
416+
.get_db()
417+
.get_member_index()
418+
.get_current_owner(&owner_member_id)?;
419+
match parent_owner {
420+
LuaMemberOwner::Type(type_id) => {
421+
let generic_params = semantic_model
422+
.get_db()
423+
.get_type_index()
424+
.get_generic_params(&type_id)?;
425+
return generic_params.get(tpl_id as usize)?.1.clone();
426+
}
427+
_ => return None,
428+
}
429+
}
430+
_ => return None,
431+
}
432+
}
433+
None
434+
}
435+
}
436+
}
437+
LuaType::Union(union_type) => {
438+
if depth > 1 {
439+
return None;
440+
}
441+
let mut result = LuaType::Unknown;
442+
for union_member_type in union_type.into_vec().iter() {
443+
let extend_type = try_instantiate_arg_type(
444+
semantic_model,
445+
union_member_type,
446+
arg_expr,
447+
depth + 1,
448+
)
449+
.unwrap_or(union_member_type.clone());
450+
result = TypeOps::Union.apply(semantic_model.get_db(), &result, &extend_type);
451+
}
452+
Some(result)
453+
}
454+
_ => None,
455+
}
456+
}
457+
458+
fn infer_expr_list_types(
459+
semantic_model: &SemanticModel,
460+
exprs: &[LuaExpr],
461+
) -> Vec<(LuaType, LuaExpr)> {
462+
let mut value_types = Vec::new();
463+
for expr in exprs.iter() {
464+
let expr_type = semantic_model
465+
.infer_expr(expr.clone())
466+
.unwrap_or(LuaType::Unknown);
467+
match expr_type {
468+
LuaType::Variadic(variadic) => match variadic.deref() {
469+
VariadicType::Base(base) => {
470+
value_types.push((base.clone(), expr.clone()));
471+
}
472+
VariadicType::Multi(vecs) => {
473+
for typ in vecs {
474+
value_types.push((typ.clone(), expr.clone()));
475+
}
476+
}
477+
},
478+
_ => value_types.push((expr_type.clone(), expr.clone())),
479+
}
480+
}
481+
value_types
482+
}

crates/emmylua_code_analysis/src/diagnostic/test/generic_constraint_mismatch_test.rs

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,4 +239,64 @@ mod test {
239239
"#
240240
));
241241
}
242+
243+
#[test]
244+
fn test_union_2() {
245+
let mut ws = VirtualWorkspace::new();
246+
ws.def(
247+
r#"
248+
---@generic T: table
249+
---@param obj T
250+
function add(obj)
251+
end
252+
253+
---@class GCNode
254+
"#,
255+
);
256+
assert!(ws.check_code_for(
257+
DiagnosticCode::GenericConstraintMismatch,
258+
r#"
259+
---@generic T: table
260+
---@param obj T | string
261+
---@return T?
262+
function bindGC(obj)
263+
if type(obj) == "string" then
264+
---@type GCNode
265+
obj = {}
266+
end
267+
268+
return add(obj)
269+
end
270+
"#
271+
));
272+
}
273+
274+
#[test]
275+
fn test_union_3() {
276+
let mut ws = VirtualWorkspace::new();
277+
ws.def(
278+
r#"
279+
---@generic T: table
280+
---@param obj T
281+
function add(obj)
282+
end
283+
284+
285+
"#,
286+
);
287+
assert!(ws.check_code_for(
288+
DiagnosticCode::GenericConstraintMismatch,
289+
r#"
290+
291+
---@class GCNode<T: table>
292+
GCNode = {}
293+
294+
---@param obj T
295+
---@return T?
296+
function GCNode:bindGC(obj)
297+
return add(obj)
298+
end
299+
"#
300+
));
301+
}
242302
}

0 commit comments

Comments
 (0)