Skip to content

Commit aaf1fed

Browse files
committed
Rust: Implement overloaded index expression in type inference
1 parent b234d77 commit aaf1fed

File tree

5 files changed

+71
-113
lines changed

5 files changed

+71
-113
lines changed

rust/ql/lib/codeql/rust/dataflow/internal/DataFlowImpl.qll

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -959,7 +959,11 @@ private module Cached {
959959

960960
cached
961961
newtype TDataFlowCall =
962-
TCall(CallCfgNode c) { Stages::DataFlowStage::ref() } or
962+
TCall(CallCfgNode c) {
963+
Stages::DataFlowStage::ref() and
964+
// TODO: Handle index expressions as calls in data flow.
965+
not c.getCall() instanceof IndexExpr
966+
} or
963967
TSummaryCall(
964968
FlowSummaryImpl::Public::SummarizedCallable c, FlowSummaryImpl::Private::SummaryNode receiver
965969
) {

rust/ql/lib/codeql/rust/elements/internal/CallImpl.qll

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,4 +160,20 @@ module Impl {
160160
pos.asPosition() = 0 and result = super.getOperand(1)
161161
}
162162
}
163+
164+
private class IndexCall extends Call instanceof IndexExpr {
165+
override string getMethodName() { result = "index" }
166+
167+
override Trait getTrait() { result.getCanonicalPath() = "core::ops::index::Index" }
168+
169+
override predicate implicitBorrowAt(ArgumentPosition pos, boolean certain) {
170+
pos.isSelf() and certain = true
171+
}
172+
173+
override Expr getArgument(ArgumentPosition pos) {
174+
pos.isSelf() and result = super.getBase()
175+
or
176+
pos.asPosition() = 0 and result = super.getIndex()
177+
}
178+
}
163179
}

rust/ql/lib/codeql/rust/internal/TypeInference.qll

Lines changed: 41 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -772,45 +772,48 @@ private Type inferCallExprBaseType(AstNode n, TypePath path) {
772772
n = a.getNodeAt(apos) and
773773
result = CallExprBaseMatching::inferAccessType(a, apos, path0)
774774
|
775-
(
775+
if
776776
apos.isBorrowed(true)
777777
or
778-
// The desugaring of the unary `*e` is `*Deref::deref(&e)`. To handle the
779-
// deref expression after the call we must strip a `&` from the type at
780-
// the return position.
781-
apos.isReturn() and a instanceof DerefExpr
782-
) and
783-
path0.isCons(TRefTypeParameter(), path)
784-
or
785-
apos.isBorrowed(false) and
786-
exists(Type argType | argType = inferType(n) |
787-
if argType = TRefType()
778+
// The desugaring of the unary `*e` is `*Deref::deref(&e)` and the
779+
// desugaring of `a[b]` is `*Index::index(&a, b)`. To handle the deref
780+
// expression after the call we must strip a `&` from the type at the
781+
// return position.
782+
apos.isReturn() and
783+
(a instanceof DerefExpr or a instanceof IndexExpr)
784+
then path0.isCons(TRefTypeParameter(), path)
785+
else
786+
if apos.isBorrowed(false)
788787
then
789-
path = path0 and
790-
path0.isCons(TRefTypeParameter(), _)
791-
or
792-
// adjust for implicit deref
793-
not path0.isCons(TRefTypeParameter(), _) and
794-
not (path0.isEmpty() and result = TRefType()) and
795-
path = TypePath::cons(TRefTypeParameter(), path0)
796-
else (
797-
not (
798-
argType.(StructType).asItemNode() instanceof StringStruct and
799-
result.(StructType).asItemNode() instanceof Builtins::Str
800-
) and
801-
(
802-
not path0.isCons(TRefTypeParameter(), _) and
803-
not (path0.isEmpty() and result = TRefType()) and
804-
path = path0
805-
or
806-
// adjust for implicit borrow
807-
path0.isCons(TRefTypeParameter(), path)
788+
exists(Type argType | argType = inferType(n) |
789+
if argType = TRefType()
790+
then
791+
path = path0 and
792+
path0.isCons(TRefTypeParameter(), _)
793+
or
794+
// adjust for implicit deref
795+
not path0.isCons(TRefTypeParameter(), _) and
796+
not (path0.isEmpty() and result = TRefType()) and
797+
path = TypePath::cons(TRefTypeParameter(), path0)
798+
else (
799+
not (
800+
argType.(StructType).asItemNode() instanceof StringStruct and
801+
result.(StructType).asItemNode() instanceof Builtins::Str
802+
) and
803+
(
804+
not path0.isCons(TRefTypeParameter(), _) and
805+
not (path0.isEmpty() and result = TRefType()) and
806+
path = path0
807+
or
808+
// adjust for implicit borrow
809+
path0.isCons(TRefTypeParameter(), path)
810+
)
811+
)
808812
)
813+
else (
814+
not apos.isBorrowed(_) and
815+
path = path0
809816
)
810-
)
811-
or
812-
not apos.isBorrowed(_) and
813-
path = path0
814817
)
815818
}
816819

@@ -1116,8 +1119,8 @@ private class Vec extends Struct {
11161119
*/
11171120
pragma[nomagic]
11181121
private Type inferIndexExprType(IndexExpr ie, TypePath path) {
1119-
// TODO: Should be implemented as method resolution, using the special
1120-
// `std::ops::Index` trait.
1122+
// TODO: Method resolution to the `std::ops::Index` trait can handle the
1123+
// `Index` instances for slices and arrays.
11211124
exists(TypePath exprPath, Builtins::BuiltinType t |
11221125
TStruct(t) = inferType(ie.getIndex()) and
11231126
(
@@ -1129,8 +1132,6 @@ private Type inferIndexExprType(IndexExpr ie, TypePath path) {
11291132
) and
11301133
result = inferType(ie.getBase(), exprPath)
11311134
|
1132-
exprPath.isCons(any(Vec v).getElementTypeParameter(), path)
1133-
or
11341135
exprPath.isCons(any(ArrayTypeParameter tp), path)
11351136
or
11361137
exists(TypePath path0 |
@@ -1601,8 +1602,8 @@ private module Debug {
16011602
private Locatable getRelevantLocatable() {
16021603
exists(string filepath, int startline, int startcolumn, int endline, int endcolumn |
16031604
result.getLocation().hasLocationInfo(filepath, startline, startcolumn, endline, endcolumn) and
1604-
filepath.matches("%/sqlx.rs") and
1605-
startline = [56 .. 60]
1605+
filepath.matches("%/main.rs") and
1606+
startline = [1854 .. 1880]
16061607
)
16071608
}
16081609

rust/ql/test/library-tests/type-inference/main.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1856,21 +1856,21 @@ mod indexers {
18561856

18571857
// MyVec::index
18581858
fn index(&self, index: usize) -> &Self::Output {
1859-
&self.data[index] // $ fieldof=MyVec
1859+
&self.data[index] // $ fieldof=MyVec method=index
18601860
}
18611861
}
18621862

18631863
fn analyze_slice(slice: &[S]) {
1864-
let x = slice[0].foo(); // $ method=foo type=x:S
1864+
let x = slice[0].foo(); // $ method=foo type=x:S method=index
18651865
}
18661866

18671867
pub fn f() {
18681868
let mut vec = MyVec::new(); // $ type=vec:T.S
18691869
vec.push(S); // $ method=push
1870-
vec[0].foo(); // $ MISSING: method=foo -- type inference does not support the `Index` trait yet
1870+
vec[0].foo(); // $ method=MyVec::index method=foo
18711871

18721872
let xs: [S; 1] = [S];
1873-
let x = xs[0].foo(); // $ method=foo type=x:S
1873+
let x = xs[0].foo(); // $ method=foo type=x:S method=index
18741874

18751875
analyze_slice(&xs);
18761876
}

0 commit comments

Comments
 (0)