Skip to content

Commit 665732f

Browse files
committed
Rust: Implement certain type information for annotation and simple calls
1 parent ddad971 commit 665732f

File tree

4 files changed

+158
-404
lines changed

4 files changed

+158
-404
lines changed

rust/ql/lib/codeql/rust/internal/TypeInference.qll

Lines changed: 151 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,14 @@ private module M2 = Make2<Input2>;
217217

218218
private import M2
219219

220-
module Consistency = M2::Consistency;
220+
module Consistency {
221+
import M2::Consistency
222+
223+
query predicate nonUniqueCertainType(AstNode n, TypePath path) {
224+
exists(CertainTypeInference::inferCertainType(n, path)) and
225+
strictcount(CertainTypeInference::inferCertainType(n, path)) > 1
226+
}
227+
}
221228

222229
/** Gets the type annotation that applies to `n`, if any. */
223230
private TypeMention getTypeAnnotation(AstNode n) {
@@ -245,6 +252,134 @@ private Type inferAnnotatedType(AstNode n, TypePath path) {
245252
result = getTypeAnnotation(n).resolveTypeAt(path)
246253
}
247254

255+
/** Module for inferring certain type information. */
256+
private module CertainTypeInference {
257+
/** Holds if the type mention does not contain any inferred types `_`. */
258+
predicate typeMentionIsComplete(TypeMention tm) {
259+
not exists(InferTypeRepr t | t.getParentNode*() = tm)
260+
}
261+
262+
/**
263+
* Holds if `ce` is a call where we can infer the type with certainty and if
264+
* `f` is the target of the call and `p` the path invoked by the call.
265+
*
266+
* Necessary conditions for this are:
267+
* - We are certain of the call target (i.e., the call target can not depend on type information).
268+
* - The declared type of the function does not contain any generics that we
269+
* need to infer.
270+
* - The call does not contain any arguments, as arguments in calls are coercion sites.
271+
*
272+
* The current requirements are made to allow for call to `new` functions such
273+
* as `Vec<Foo>::new()` but not much more.
274+
*/
275+
predicate certainCallExprTarget(CallExpr ce, Function f, Path p) {
276+
p = CallExprImpl::getFunctionPath(ce) and
277+
f = resolvePath(p) and
278+
// The function is not in a trait
279+
not any(TraitItemNode t).getAnAssocItem() = f and
280+
// The function is not in a trait implementation
281+
not any(ImplItemNode impl | impl.(Impl).hasTrait()).getAnAssocItem() = f and
282+
// The function does not have parameters.
283+
not f.getParamList().hasSelfParam() and
284+
f.getParamList().getNumberOfParams() = 0 and
285+
// The function is not async.
286+
not f.isAsync() and
287+
// For now, exclude functions in macro expansions.
288+
not ce.isInMacroExpansion() and
289+
// The function has no type parameters.
290+
not f.hasGenericParamList() and
291+
// The function does not have `impl` types among its parameters (these are type parameters).
292+
not any(ImplTraitTypeRepr itt | not itt.isInReturnPos()).getFunction() = f and
293+
(
294+
not exists(ImplItemNode impl | impl.getAnAssocItem() = f)
295+
or
296+
// If the function is in an impl then the impl block has no type
297+
// parameters or all the type parameters are given explicitly.
298+
exists(ImplItemNode impl | impl.getAnAssocItem() = f |
299+
not impl.(Impl).hasGenericParamList() or
300+
impl.(Impl).getGenericParamList().getNumberOfGenericParams() =
301+
p.getQualifier().getSegment().getGenericArgList().getNumberOfGenericArgs()
302+
)
303+
)
304+
}
305+
306+
private ImplItemNode getFunctionImpl(FunctionItemNode f) { result.getAnAssocItem() = f }
307+
308+
Type inferCertainCallExprType(CallExpr ce, TypePath path) {
309+
exists(Function f, Type ty, TypePath prefix, Path p |
310+
certainCallExprTarget(ce, f, p) and
311+
ty = f.getRetType().getTypeRepr().(TypeMention).resolveTypeAt(prefix)
312+
|
313+
if ty.(TypeParamTypeParameter).getTypeParam() = getFunctionImpl(f).getTypeParam(_)
314+
then
315+
exists(TypePath pathToTp, TypePath suffix |
316+
// For type parameters of the `impl` block we must resolve their
317+
// instantiation from the path. For instance, for `impl<A> for Foo<A>`
318+
// and the path `Foo<i64>::bar` we must resolve `A` to `i64`.
319+
ty = getFunctionImpl(f).(Impl).getSelfTy().(TypeMention).resolveTypeAt(pathToTp) and
320+
result = p.getQualifier().(TypeMention).resolveTypeAt(pathToTp.appendInverse(suffix)) and
321+
path = prefix.append(suffix)
322+
)
323+
else (
324+
result = ty and path = prefix
325+
)
326+
)
327+
}
328+
329+
predicate certainTypeEquality(AstNode n1, TypePath prefix1, AstNode n2, TypePath prefix2) {
330+
prefix1.isEmpty() and
331+
prefix2.isEmpty() and
332+
(
333+
exists(Variable v | n1 = v.getAnAccess() |
334+
n2 = v.getPat().getName() or n2 = v.getParameter().(SelfParam)
335+
)
336+
or
337+
// A `let` statement with a type annotation is a coercion site and hence
338+
// is not a certain type equality.
339+
exists(LetStmt let | not let.hasTypeRepr() |
340+
let.getPat() = n1 and
341+
let.getInitializer() = n2
342+
)
343+
)
344+
or
345+
n1 =
346+
any(IdentPat ip |
347+
n2 = ip.getName() and
348+
prefix1.isEmpty() and
349+
if ip.isRef() then prefix2 = TypePath::singleton(TRefTypeParameter()) else prefix2.isEmpty()
350+
)
351+
}
352+
353+
pragma[nomagic]
354+
private Type inferCertainTypeEquality(AstNode n, TypePath path) {
355+
exists(TypePath prefix1, AstNode n2, TypePath prefix2, TypePath suffix |
356+
result = inferCertainType(n2, prefix2.appendInverse(suffix)) and
357+
path = prefix1.append(suffix)
358+
|
359+
certainTypeEquality(n, prefix1, n2, prefix2)
360+
or
361+
certainTypeEquality(n2, prefix2, n, prefix1)
362+
)
363+
}
364+
365+
/**
366+
* Holds if `n` has complete and certain type information and if `n` has the
367+
* resulting type at `path`.
368+
*/
369+
pragma[nomagic]
370+
Type inferCertainType(AstNode n, TypePath path) {
371+
exists(TypeMention tm |
372+
tm = getTypeAnnotation(n) and
373+
typeMentionIsComplete(tm) and
374+
result = tm.resolveTypeAt(path)
375+
)
376+
or
377+
result = inferCertainCallExprType(n, path)
378+
or
379+
result = inferCertainTypeEquality(n, path)
380+
}
381+
}
382+
248383
private Type inferLogicalOperationType(AstNode n, TypePath path) {
249384
exists(Builtins::BuiltinType t, BinaryLogicalOperation be |
250385
n = [be, be.getLhs(), be.getRhs()] and
@@ -284,15 +419,11 @@ private Struct getRangeType(RangeExpr re) {
284419
* through the type equality.
285420
*/
286421
private predicate typeEquality(AstNode n1, TypePath prefix1, AstNode n2, TypePath prefix2) {
422+
CertainTypeInference::certainTypeEquality(n1, prefix1, n2, prefix2)
423+
or
287424
prefix1.isEmpty() and
288425
prefix2.isEmpty() and
289426
(
290-
exists(Variable v | n1 = v.getAnAccess() |
291-
n2 = v.getPat().getName()
292-
or
293-
n2 = v.getParameter().(SelfParam)
294-
)
295-
or
296427
exists(LetStmt let |
297428
let.getPat() = n1 and
298429
let.getInitializer() = n2
@@ -335,13 +466,6 @@ private predicate typeEquality(AstNode n1, TypePath prefix1, AstNode n2, TypePat
335466
n1 = n2.(MacroPat).getMacroCall().getMacroCallExpansion()
336467
)
337468
or
338-
n1 =
339-
any(IdentPat ip |
340-
n2 = ip.getName() and
341-
prefix1.isEmpty() and
342-
if ip.isRef() then prefix2 = TypePath::singleton(TRefTypeParameter()) else prefix2.isEmpty()
343-
)
344-
or
345469
(
346470
n1 = n2.(RefExpr).getExpr() or
347471
n1 = n2.(RefPat).getPat()
@@ -404,6 +528,9 @@ private predicate typeEquality(AstNode n1, TypePath prefix1, AstNode n2, TypePat
404528

405529
pragma[nomagic]
406530
private Type inferTypeEquality(AstNode n, TypePath path) {
531+
// Don't propagate type information into a node for which we already have
532+
// certain type information.
533+
not exists(CertainTypeInference::inferCertainType(n, _)) and
407534
exists(TypePath prefix1, AstNode n2, TypePath prefix2, TypePath suffix |
408535
result = inferType(n2, prefix2.appendInverse(suffix)) and
409536
path = prefix1.append(suffix)
@@ -814,6 +941,8 @@ private module CallExprBaseMatchingInput implements MatchingInputSig {
814941
}
815942

816943
final class Access extends Call {
944+
Access() { not CertainTypeInference::certainCallExprTarget(this, _, _) }
945+
817946
pragma[nomagic]
818947
Type getTypeArgument(TypeArgumentPosition apos, TypePath path) {
819948
exists(TypeMention arg | result = arg.resolveTypeAt(path) |
@@ -2146,6 +2275,8 @@ private module Cached {
21462275
cached
21472276
Type inferType(AstNode n, TypePath path) {
21482277
Stages::TypeInferenceStage::ref() and
2278+
result = CertainTypeInference::inferCertainType(n, path)
2279+
or
21492280
result = inferAnnotatedType(n, path)
21502281
or
21512282
result = inferLogicalOperationType(n, path)
@@ -2291,4 +2422,10 @@ private module Debug {
22912422
c = countTypePaths(n, path, t) and
22922423
c = max(countTypePaths(_, _, _))
22932424
}
2425+
2426+
Type debugInferCertainNonUniqueType(AstNode n, TypePath path) {
2427+
n = getRelevantLocatable() and
2428+
Consistency::nonUniqueCertainType(n, path) and
2429+
result = CertainTypeInference::inferCertainType(n, path)
2430+
}
22942431
}

rust/ql/lib/codeql/rust/internal/TypeInferenceConsistency.qll

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,11 @@
22
* Provides classes for recognizing type inference inconsistencies.
33
*/
44

5+
private import rust
56
private import Type
67
private import TypeMention
7-
private import TypeInference::Consistency as Consistency
8-
import TypeInference::Consistency
8+
private import TypeInference
9+
private import Consistency
910

1011
query predicate illFormedTypeMention(TypeMention tm) {
1112
Consistency::illFormedTypeMention(tm) and
@@ -27,4 +28,7 @@ int getTypeInferenceInconsistencyCounts(string type) {
2728
or
2829
type = "Ill-formed type mention" and
2930
result = count(TypeMention tm | illFormedTypeMention(tm) | tm)
31+
or
32+
type = "Non-unique certain type information" and
33+
result = count(AstNode n, TypePath path | nonUniqueCertainType(n, path) | n)
3034
}

rust/ql/test/library-tests/type-inference/main.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2352,7 +2352,7 @@ mod loops {
23522352
#[rustfmt::skip]
23532353
let _ = while a < 10 // $ target=lt type=a:i64
23542354
{
2355-
a += 1; // $ type=a:i64 target=add_assign
2355+
a += 1; // $ type=a:i64 MISSING: target=add_assign
23562356
};
23572357
}
23582358
}

0 commit comments

Comments
 (0)