Skip to content

Commit 07fbea4

Browse files
authored
Merge pull request #662 from EmmyLuaLs/generic
refactor generic infer
2 parents b184288 + 473f478 commit 07fbea4

File tree

13 files changed

+632
-380
lines changed

13 files changed

+632
-380
lines changed

crates/emmylua_code_analysis/src/compilation/analyzer/lua/for_range_stat.rs

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@ use emmylua_parser::{LuaAstNode, LuaAstToken, LuaExpr, LuaForRangeStat};
33
use crate::{
44
compilation::analyzer::unresolve::UnResolveIterVar, infer_expr, instantiate_doc_function,
55
tpl_pattern_match_args, DbIndex, InferFailReason, LuaDeclId, LuaInferCache,
6-
LuaOperatorMetaMethod, LuaType, LuaTypeCache, TypeOps, TypeSubstitutor, VariadicType,
6+
LuaOperatorMetaMethod, LuaType, LuaTypeCache, TplContext, TypeOps, TypeSubstitutor,
7+
VariadicType,
78
};
89

910
use super::LuaAnalyzer;
@@ -149,19 +150,19 @@ pub fn infer_for_range_iter_expr_func(
149150
return Ok(doc_function.get_variadic_ret());
150151
}
151152
let mut substitutor = TypeSubstitutor::new();
153+
let mut context = TplContext {
154+
db,
155+
cache,
156+
substitutor: &mut substitutor,
157+
root: root,
158+
};
152159
let params = doc_function
153160
.get_params()
154161
.iter()
155162
.map(|(_, opt_ty)| opt_ty.clone().unwrap_or(LuaType::Any))
156163
.collect::<Vec<_>>();
157-
tpl_pattern_match_args(
158-
db,
159-
cache,
160-
&params,
161-
&vec![status_param.clone().unwrap()],
162-
&root,
163-
&mut substitutor,
164-
)?;
164+
165+
tpl_pattern_match_args(&mut context, &params, &vec![status_param.clone().unwrap()])?;
165166

166167
let instantiate_func = if let LuaType::DocFunction(f) =
167168
instantiate_doc_function(db, &doc_function, &substitutor)
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
#[cfg(test)]
2+
mod test {
3+
use crate::VirtualWorkspace;
4+
5+
#[test]
6+
fn test_issue_586() {
7+
let mut ws = VirtualWorkspace::new_with_init_std_lib();
8+
ws.def(
9+
r#"
10+
--- @generic T
11+
--- @param cb fun(...: T...)
12+
--- @param ... T...
13+
function invoke1(cb, ...)
14+
cb(...)
15+
end
16+
17+
invoke1(
18+
function(a, b, c)
19+
_a = a
20+
_b = b
21+
_c = c
22+
end,
23+
1, "2", "3"
24+
)
25+
"#,
26+
);
27+
28+
let a_ty = ws.expr_ty("_a");
29+
let b_ty = ws.expr_ty("_b");
30+
let c_ty = ws.expr_ty("_c");
31+
32+
assert_eq!(a_ty, ws.ty("integer"));
33+
assert_eq!(b_ty, ws.ty("string"));
34+
assert_eq!(c_ty, ws.ty("string"));
35+
}
36+
37+
#[test]
38+
fn test_issue_658() {
39+
let mut ws = VirtualWorkspace::new_with_init_std_lib();
40+
ws.def(
41+
r#"
42+
--- @generic T1, T2, R
43+
--- @param fn fun(_:T1..., _:T2...): R...
44+
--- @param ... T1...
45+
--- @return fun(_:T2...): R...
46+
local function curry(fn, ...)
47+
local nargs, args = select('#', ...), { ... }
48+
return function(...)
49+
local nargs2 = select('#', ...)
50+
for i = 1, nargs2 do
51+
args[nargs + i] = select(i, ...)
52+
end
53+
return fn(unpack(args, 1, nargs + nargs2))
54+
end
55+
end
56+
57+
--- @param a string
58+
--- @param b string
59+
--- @param c table
60+
local function foo(a, b, c) end
61+
62+
bar = curry(foo, 'a')
63+
"#,
64+
);
65+
66+
let bar_ty = ws.expr_ty("bar");
67+
let expected = ws.ty("fun(b:string, c:table)");
68+
assert_eq!(bar_ty, expected);
69+
}
70+
}

crates/emmylua_code_analysis/src/compilation/test/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ mod diagnostic_disable_test;
99
mod export_test;
1010
mod flow;
1111
mod for_range_var_infer_test;
12+
mod generic_test;
1213
mod infer_str_tpl_test;
1314
mod inherit_type;
1415
mod mathlib_test;

crates/emmylua_code_analysis/src/db_index/type/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ mod test;
33
mod type_decl;
44
mod type_ops;
55
mod type_owner;
6+
mod type_visit_trait;
67
mod types;
78

89
use super::traits::LuaIndex;
@@ -14,6 +15,7 @@ pub use type_decl::{
1415
};
1516
pub use type_ops::TypeOps;
1617
pub use type_owner::{LuaTypeCache, LuaTypeOwner};
18+
pub use type_visit_trait::TypeVisitTrait;
1719
pub use types::*;
1820

1921
#[derive(Debug)]
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
use crate::LuaType;
2+
3+
pub trait TypeVisitTrait {
4+
fn visit_type<F>(&self, f: &mut F)
5+
where
6+
F: FnMut(&LuaType);
7+
}

0 commit comments

Comments
 (0)