Skip to content
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
840ef5c
Rust: Add test cases for type inference in loops.
geoffw0 Jun 12, 2025
f76b562
Rust: Implement type inference for 'for' loops on arrays.
geoffw0 Jun 12, 2025
51343a5
Rust: Implement type inference for ArrayListExprs.
geoffw0 Jun 13, 2025
b89d6d3
Rust: Implement type inference for ArrayRepeatExprs.
geoffw0 Jun 13, 2025
62e3cc5
Merge branch 'main' into typeinfer
geoffw0 Jun 13, 2025
6194676
Rust: Accept consistency failures (for now).
geoffw0 Jun 13, 2025
69da4e7
Rust: Move inferArrayExprType logic into typeEquality predicate.
geoffw0 Jun 17, 2025
66d6770
Rust: If we're inferring both ways, it should really be to any element.
geoffw0 Jun 17, 2025
4292b03
Rust: Add logic for Vecs and slices.
geoffw0 Jun 17, 2025
dec0deb
Rust: Add some more test cases for type inference on Vecs.
geoffw0 Jun 17, 2025
639f85a
Merge branch 'main' into typeinfer
geoffw0 Jun 19, 2025
1622d08
Rust: Add inferArrayExprType.
geoffw0 Jun 19, 2025
f670fcb
Rust: Add a Vec test case that we actually get (explicit type).
geoffw0 Jun 19, 2025
7170e97
Rust: Update test expectations format (type=...).
geoffw0 Jun 19, 2025
d55e8b7
Rust: Add another test case for ranges.
geoffw0 Jun 19, 2025
26e7b2d
Rust: Accept path resolution consistency changes.
geoffw0 Jun 19, 2025
7a25596
Merge branch 'main' into typeinfer
geoffw0 Jun 19, 2025
bfaabab
Rust: Update more expectations.
geoffw0 Jun 23, 2025
34cd976
Rust: Run rustfmt --edition 2024 on the test.
geoffw0 Jun 23, 2025
d02a728
Update rust/ql/lib/codeql/rust/internal/TypeInference.qll
geoffw0 Jun 23, 2025
8c848ac
Rust: Effects of rustfmt on .expected.
geoffw0 Jun 23, 2025
4530e85
Rust: Repair the test annotations.
geoffw0 Jun 23, 2025
530ded1
Merge branch 'main' into typeinfer
geoffw0 Jun 23, 2025
21bea7e
Merge branch 'main' into typeinfer
geoffw0 Jun 24, 2025
96dcdf9
Rust: Change note.
geoffw0 Jun 24, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions rust/ql/lib/codeql/rust/internal/TypeInference.qll
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,16 @@ private predicate typeEquality(AstNode n1, TypePath prefix1, AstNode n2, TypePat
prefix2.isEmpty()
)
)
or
// an array list expression (`[1, 2, 3]`) has the type of the first (any) element
n1.(ArrayListExpr).getExpr(_) = n2 and
prefix1 = TypePath::singleton(TArrayTypeParameter()) and
prefix2.isEmpty()
or
// an array repeat expression (`[1; 3]`) has the type of the repeat operand
n1.(ArrayRepeatExpr).getRepeatOperand() = n2 and
prefix1 = TypePath::singleton(TArrayTypeParameter()) and
prefix2.isEmpty()
}

pragma[nomagic]
Expand Down Expand Up @@ -1124,6 +1134,27 @@ private Type inferIndexExprType(IndexExpr ie, TypePath path) {
)
}

pragma[nomagic]
private Type inferForLoopExprType(AstNode n, TypePath path) {
// type of iterable -> type of pattern (loop variable)
exists(ForExpr fe, Type iterableType, TypePath iterablePath |
n = fe.getPat() and
iterableType = inferType(fe.getIterable(), iterablePath) and
result = iterableType and
(
iterablePath.isCons(any(Vec v).getElementTypeParameter(), path)
or
iterablePath.isCons(any(ArrayTypeParameter tp), path)
or
exists(TypePath path0 |
iterablePath.isCons(any(RefTypeParameter tp), path0) and
path0.isCons(any(SliceTypeParameter tp), path)
)
// TODO: iterables (general case for containers, ranges etc)
)
)
}

final class MethodCall extends Call {
MethodCall() {
exists(this.getReceiver()) and
Expand Down Expand Up @@ -1438,6 +1469,8 @@ private module Cached {
result = inferAwaitExprType(n, path)
or
result = inferIndexExprType(n, path)
or
result = inferForLoopExprType(n, path)
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
multiplePathResolutions
| main.rs:1851:25:1851:36 | ...::from | file://:0:0:0:0 | fn from |
| main.rs:1851:25:1851:36 | ...::from | file://:0:0:0:0 | fn from |
| main.rs:1851:25:1851:36 | ...::from | file://:0:0:0:0 | fn from |
| main.rs:1851:25:1851:36 | ...::from | file://:0:0:0:0 | fn from |
| main.rs:1851:25:1851:36 | ...::from | file://:0:0:0:0 | fn from |
| main.rs:1851:25:1851:36 | ...::from | file://:0:0:0:0 | fn from |
| main.rs:1851:46:1851:57 | ...::from | file://:0:0:0:0 | fn from |
| main.rs:1851:46:1851:57 | ...::from | file://:0:0:0:0 | fn from |
| main.rs:1851:46:1851:57 | ...::from | file://:0:0:0:0 | fn from |
| main.rs:1851:46:1851:57 | ...::from | file://:0:0:0:0 | fn from |
| main.rs:1851:46:1851:57 | ...::from | file://:0:0:0:0 | fn from |
| main.rs:1851:46:1851:57 | ...::from | file://:0:0:0:0 | fn from |
| main.rs:1851:67:1851:78 | ...::from | file://:0:0:0:0 | fn from |
| main.rs:1851:67:1851:78 | ...::from | file://:0:0:0:0 | fn from |
| main.rs:1851:67:1851:78 | ...::from | file://:0:0:0:0 | fn from |
| main.rs:1851:67:1851:78 | ...::from | file://:0:0:0:0 | fn from |
| main.rs:1851:67:1851:78 | ...::from | file://:0:0:0:0 | fn from |
| main.rs:1851:67:1851:78 | ...::from | file://:0:0:0:0 | fn from |
| main.rs:1854:26:1854:37 | ...::from | file://:0:0:0:0 | fn from |
| main.rs:1854:26:1854:37 | ...::from | file://:0:0:0:0 | fn from |
| main.rs:1854:26:1854:37 | ...::from | file://:0:0:0:0 | fn from |
| main.rs:1854:26:1854:37 | ...::from | file://:0:0:0:0 | fn from |
| main.rs:1854:26:1854:37 | ...::from | file://:0:0:0:0 | fn from |
| main.rs:1854:26:1854:37 | ...::from | file://:0:0:0:0 | fn from |
| main.rs:1854:47:1854:58 | ...::from | file://:0:0:0:0 | fn from |
| main.rs:1854:47:1854:58 | ...::from | file://:0:0:0:0 | fn from |
| main.rs:1854:47:1854:58 | ...::from | file://:0:0:0:0 | fn from |
| main.rs:1854:47:1854:58 | ...::from | file://:0:0:0:0 | fn from |
| main.rs:1854:47:1854:58 | ...::from | file://:0:0:0:0 | fn from |
| main.rs:1854:47:1854:58 | ...::from | file://:0:0:0:0 | fn from |
| main.rs:1854:68:1854:79 | ...::from | file://:0:0:0:0 | fn from |
| main.rs:1854:68:1854:79 | ...::from | file://:0:0:0:0 | fn from |
| main.rs:1854:68:1854:79 | ...::from | file://:0:0:0:0 | fn from |
| main.rs:1854:68:1854:79 | ...::from | file://:0:0:0:0 | fn from |
| main.rs:1854:68:1854:79 | ...::from | file://:0:0:0:0 | fn from |
| main.rs:1854:68:1854:79 | ...::from | file://:0:0:0:0 | fn from |
99 changes: 99 additions & 0 deletions rust/ql/test/library-tests/type-inference/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1810,6 +1810,104 @@ mod indexers {
}
}

mod loops {
struct MyCallable {
}

impl MyCallable {
fn new() -> Self {
MyCallable {}
}

fn call(&self) -> i64 {
1
}
}

pub fn f() {
// for loops with arrays

for i in [1, 2, 3] { } // $ type=i:i32
for i in [1, 2, 3].map(|x| x + 1) { } // $ MISSING: type=i:i32
for i in [1, 2, 3].into_iter() { } // $ MISSING: type=i:i32

let vals1 = [1u8, 2, 3]; // $ MISSING: type=vals1:[u8; 3]
for u in vals1 { } // $ type=u:u8

let vals2 = [1u16; 3]; // $ MISSING: type=vals2:[u16; 3]
for u in vals2 { } // $ type=u:u16

let vals3: [u32; 3] = [1, 2, 3]; // $ MISSING: type=vals3:[u32; 3]
for u in vals3 { } // $ type=u:u32

let vals4: [u64; 3] = [1; 3]; // $ MISSING: type=vals4:[u64; 3]
for u in vals4 { } // $ type=u:u64

let mut strings1 = ["foo", "bar", "baz"]; // $ MISSING: type=strings1:[&str; 3]
for s in &strings1 { } // $ MISSING: type=s:&str
for s in &mut strings1 { } // $ MISSING: type=s:&str
for s in strings1 { } // $ type=s:str

let strings2 = [String::from("foo"), String::from("bar"), String::from("baz")]; // $ MISSING: type=strings2:[String; 3]
for s in strings2 { } // $ type=s:String

let strings3 = &[String::from("foo"), String::from("bar"), String::from("baz")]; // $ MISSING: type=strings3:&[String; 3]
for s in strings3 { } // $ MISSING: type=s:String

let callables = [MyCallable::new(), MyCallable::new(), MyCallable::new()]; // $ MISSING: type=callables:[MyCallable; 3]
for c in callables { // $ type=c:MyCallable
let result = c.call(); // $ type=result:i64 method=call
}

// for loops with ranges

for i in 0..10 { } // $ MISSING: type=i:i32
for u in [0u8 .. 10] { } // $ MISSING: type=u:u8

let range1 = std::ops::Range { start: 0u16, end: 10u16 }; // $ MISSING: type=range:std::ops::Range<u16>
for u in range1 { } // $ MISSING: type=i:u16

// for loops with containers

let vals3 = vec![1, 2, 3]; // $ MISSING: type=vals3:Vec<i32>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not how to make expected type annotations for constructed types, instead you should use the format type=<element>:<path>.<type> (see

else value = element + ":" + path.toString() + "." + t.toString()
). So in this case it should instead by type=vals3:T.i32.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, I think I was relying on the test runner to start suggesting correct syntax, which it doesn't really do for optional results. I've updated the expectations now as best I can - there may still be add mistakes.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This one still needs to be updated.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fair point. 😆

Updated.

for i in vals3 { } // $ MISSING: type=i:i32

let vals4 = [1u16, 2, 3].to_vec(); // $ MISSING: type=vals4:Vec<u16>
for u in vals4 { } // $ MISSING: type=u:u16

let vals5 = Vec::from([1u32, 2, 3]); // $ MISSING: type=vals5:Vec<u32>
for u in vals5 { } // $ MISSING: type=u:u32

let vals6 : Vec<&u64> = [1u64, 2, 3].iter().collect(); // $ MISSING: type=vals6:Vec<&u64>
for u in vals6 { } // $ MISSING: type=u:&u64

let mut vals7 = Vec::new(); // $ MISSING: type=vals7:Vec<u8>
vals7.push(1u8); // $ method=push
for u in vals7 { } // $ MISSING: type=u:u8

let matrix1 = vec![vec![1, 2], vec![3, 4]]; // $ MISSING: type=vals5:Vec<Vec<i32>>
for row in matrix1 { // $ MISSING: type=row:Vec<i32>
for cell in row { // $ MISSING: type=cell:i32
}
}

let mut map1 = std::collections::HashMap::new(); // $ MISSING: type=map1:std::collections::HashMap<_, _>
map1.insert(1, Box::new("one")); // $ method=insert
map1.insert(2, Box::new("two")); // $ method=insert
for key in map1.keys() { } // $ method=keys MISSING: type=key:i32
for value in map1.values() { } // $ method=values MISSING: type=value:Box<&str>
for (key, value) in map1.iter() { } // $ method=iter MISSING: type=key:i32 type=value:Box<&str>
for (key, value) in &map1 { } // $ MISSING: type=key:i32 type=value:Box<&str>

// while loops

let mut a: i64 = 0; // $ type=a:i64
while a < 10 { // $ method=lt MISSING: type=a:i64m
a += 1; // $ type=a:i64 method=add_assign
}
}
}

fn main() {
field_access::f();
method_impl::f();
Expand All @@ -1832,4 +1930,5 @@ fn main() {
async_::f();
impl_trait::f();
indexers::f();
loops::f();
}
Loading