Skip to content

Commit 97cf8fe

Browse files
committed
Rust: Type inference for pattern matching
1 parent 0c82b6d commit 97cf8fe

File tree

6 files changed

+262
-21
lines changed

6 files changed

+262
-21
lines changed

rust/ql/lib/codeql/rust/elements/internal/StructPatImpl.qll

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,13 +33,20 @@ module Impl {
3333
name = this.getStructPatFieldList().getAField().getFieldName()
3434
}
3535

36-
/** Gets the record field that matches the `name` pattern of this pattern. */
36+
/** Gets the struct field that matches the `name` pattern of this pattern. */
3737
pragma[nomagic]
3838
StructField getStructField(string name) {
3939
exists(PathResolution::ItemNode i | i = this.getResolvedPath(name) |
4040
result.isStructField(i, name) or
4141
result.isVariantField(i, name)
4242
)
4343
}
44+
45+
/** Gets the struct pattern for the field `name`. */
46+
pragma[nomagic]
47+
StructPatField getPatField(string name) {
48+
result = this.getStructPatFieldList().getAField() and
49+
name = result.getFieldName()
50+
}
4451
}
4552
}

rust/ql/lib/codeql/rust/internal/TypeInference.qll

Lines changed: 140 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,20 @@ private predicate typeEquality(AstNode n1, TypePath prefix1, AstNode n2, TypePat
276276
or
277277
n1 = n2.(MatchExpr).getAnArm().getExpr()
278278
or
279+
exists(LetExpr let |
280+
n1 = let.getScrutinee() and
281+
n2 = let.getPat()
282+
)
283+
or
284+
exists(MatchExpr me |
285+
n1 = me.getScrutinee() and
286+
n2 = me.getAnArm().getPat()
287+
)
288+
or
289+
n1 = n2.(OrPat).getAPat()
290+
or
291+
n1 = n2.(ParenPat).getPat()
292+
or
279293
exists(BreakExpr break |
280294
break.getExpr() = n1 and
281295
break.getTarget() = n2.(LoopExpr)
@@ -287,12 +301,18 @@ private predicate typeEquality(AstNode n1, TypePath prefix1, AstNode n2, TypePat
287301
)
288302
or
289303
n1 = n2.(MacroExpr).getMacroCall().getMacroCallExpansion()
304+
or
305+
n1 = n2.(MacroPat).getMacroCall().getMacroCallExpansion()
290306
)
291307
or
292308
n1 = n2.(RefExpr).getExpr() and
293309
prefix1.isEmpty() and
294310
prefix2 = TypePath::singleton(TRefTypeParameter())
295311
or
312+
n1 = n2.(RefPat).getPat() and
313+
prefix1.isEmpty() and
314+
prefix2 = TypePath::singleton(TRefTypeParameter())
315+
or
296316
exists(BlockExpr be |
297317
n1 = be and
298318
n2 = be.getStmtList().getTailExpr() and
@@ -478,7 +498,7 @@ private module StructExprMatchingInput implements MatchingInputSig {
478498
Type getInferredType(AccessPosition apos, TypePath path) {
479499
result = inferType(this.getNodeAt(apos), path)
480500
or
481-
// The struct type is supplied explicitly as a type qualifier, e.g.
501+
// The struct/enum type is supplied explicitly as a type qualifier, e.g.
482502
// `Foo<Bar>::Variant { ... }`.
483503
apos.isStructPos() and
484504
exists(Path p, TypeMention tm |
@@ -576,7 +596,7 @@ private module CallExprBaseMatchingInput implements MatchingInputSig {
576596
}
577597
}
578598

579-
abstract private class TupleDeclaration extends Declaration {
599+
abstract additional class TupleDeclaration extends Declaration {
580600
override Type getDeclaredType(DeclarationPosition dpos, TypePath path) {
581601
result = super.getDeclaredType(dpos, path)
582602
or
@@ -1178,6 +1198,120 @@ private Type inferIndexExprType(IndexExpr ie, TypePath path) {
11781198
)
11791199
}
11801200

1201+
/**
1202+
* A matching configuration for resolving types of struct patterns
1203+
* like `let Foo { bar } = ...`.
1204+
*/
1205+
private module StructPatMatchingInput implements MatchingInputSig {
1206+
class DeclarationPosition = StructExprMatchingInput::DeclarationPosition;
1207+
1208+
class Declaration = StructExprMatchingInput::Declaration;
1209+
1210+
class AccessPosition = DeclarationPosition;
1211+
1212+
class Access extends StructPat {
1213+
Type getTypeArgument(TypeArgumentPosition apos, TypePath path) { none() }
1214+
1215+
AstNode getNodeAt(AccessPosition apos) {
1216+
result = this.getPatField(apos.asFieldPos()).getPat()
1217+
or
1218+
result = this and
1219+
apos.isStructPos()
1220+
}
1221+
1222+
Type getInferredType(AccessPosition apos, TypePath path) {
1223+
result = inferType(this.getNodeAt(apos), path)
1224+
or
1225+
// The struct/enum type is supplied explicitly as a type qualifier, e.g.
1226+
// `let Foo<Bar>::Variant { ... } = ...`.
1227+
apos.isStructPos() and
1228+
exists(Path p, TypeMention tm |
1229+
p = this.getPath() and
1230+
if resolvePath(p) instanceof Variant then tm = p.getQualifier() else tm = p
1231+
|
1232+
result = tm.resolveTypeAt(path)
1233+
)
1234+
}
1235+
1236+
Declaration getTarget() { result = resolvePath(this.getPath()) }
1237+
}
1238+
1239+
predicate accessDeclarationPositionMatch(AccessPosition apos, DeclarationPosition dpos) {
1240+
apos = dpos
1241+
}
1242+
}
1243+
1244+
private module StructPatMatching = Matching<StructPatMatchingInput>;
1245+
1246+
/**
1247+
* Gets the type of `n` at `path`, where `n` is either a struct pattern or
1248+
* a field pattern of a struct pattern.
1249+
*/
1250+
pragma[nomagic]
1251+
private Type inferStructPatType(AstNode n, TypePath path) {
1252+
exists(StructPatMatchingInput::Access a, StructPatMatchingInput::AccessPosition apos |
1253+
n = a.getNodeAt(apos) and
1254+
result = StructPatMatching::inferAccessType(a, apos, path)
1255+
)
1256+
}
1257+
1258+
/**
1259+
* A matching configuration for resolving types of tuple struct patterns
1260+
* like `let Some(x) = ...`.
1261+
*/
1262+
private module TupleStructPatMatchingInput implements MatchingInputSig {
1263+
class DeclarationPosition = CallExprBaseMatchingInput::DeclarationPosition;
1264+
1265+
class Declaration = CallExprBaseMatchingInput::TupleDeclaration;
1266+
1267+
class AccessPosition = DeclarationPosition;
1268+
1269+
class Access extends TupleStructPat {
1270+
Type getTypeArgument(TypeArgumentPosition apos, TypePath path) { none() }
1271+
1272+
AstNode getNodeAt(AccessPosition apos) {
1273+
result = this.getField(apos.asPosition())
1274+
or
1275+
result = this and
1276+
apos.isSelf()
1277+
}
1278+
1279+
Type getInferredType(AccessPosition apos, TypePath path) {
1280+
result = inferType(this.getNodeAt(apos), path)
1281+
or
1282+
// The struct/enum type is supplied explicitly as a type qualifier, e.g.
1283+
// `let Option::<Foo>(x) = ...`.
1284+
apos.isSelf() and
1285+
exists(Path p, TypeMention tm |
1286+
p = this.getPath() and
1287+
if resolvePath(p) instanceof Variant then tm = p.getQualifier() else tm = p
1288+
|
1289+
result = tm.resolveTypeAt(path)
1290+
)
1291+
}
1292+
1293+
Declaration getTarget() { result = resolvePath(this.getPath()) }
1294+
}
1295+
1296+
predicate accessDeclarationPositionMatch(AccessPosition apos, DeclarationPosition dpos) {
1297+
apos = dpos
1298+
}
1299+
}
1300+
1301+
private module TupleStructPatMatching = Matching<TupleStructPatMatchingInput>;
1302+
1303+
/**
1304+
* Gets the type of `n` at `path`, where `n` is either a tuple struct pattern or
1305+
* a positional pattern of a tuple struct pattern.
1306+
*/
1307+
pragma[nomagic]
1308+
private Type inferTupleStructPatType(AstNode n, TypePath path) {
1309+
exists(TupleStructPatMatchingInput::Access a, TupleStructPatMatchingInput::AccessPosition apos |
1310+
n = a.getNodeAt(apos) and
1311+
result = TupleStructPatMatching::inferAccessType(a, apos, path)
1312+
)
1313+
}
1314+
11811315
final private class ForIterableExpr extends Expr {
11821316
ForIterableExpr() { this = any(ForExpr fe).getIterable() }
11831317

@@ -1836,6 +1970,10 @@ private module Cached {
18361970
result = inferForLoopExprType(n, path)
18371971
or
18381972
result = inferCastExprType(n, path)
1973+
or
1974+
result = inferStructPatType(n, path)
1975+
or
1976+
result = inferTupleStructPatType(n, path)
18391977
}
18401978
}
18411979

rust/ql/test/library-tests/type-inference/main.rs

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2362,12 +2362,12 @@ pub mod pattern_matching {
23622362
pub fn f() -> Option<()> {
23632363
let value = Some(42);
23642364
if let Some(mesg) = value {
2365-
let mesg = mesg; // $ MISSING: type=mesg:i32
2365+
let mesg = mesg; // $ type=mesg:i32
23662366
println!("{mesg}");
23672367
}
23682368
match value {
23692369
Some(mesg) => {
2370-
let mesg = mesg; // $ MISSING: type=mesg:i32
2370+
let mesg = mesg; // $ type=mesg:i32
23712371
println!("{mesg}");
23722372
}
23732373
None => (),
@@ -2383,15 +2383,15 @@ pub mod pattern_matching {
23832383
value2: false,
23842384
};
23852385
if let MyRecordStruct { value1, value2 } = my_record_struct {
2386-
let x = value1; // $ MISSING: type=x:i32
2387-
let y = value2; // $ MISSING: type=y:bool
2386+
let x = value1; // $ type=x:i32
2387+
let y = value2; // $ type=y:bool
23882388
();
23892389
}
23902390

23912391
let my_tuple_struct = MyTupleStruct(42, false);
23922392
if let MyTupleStruct(value1, value2) = my_tuple_struct {
2393-
let x = value1; // $ MISSING: type=x:i32
2394-
let y = value2; // $ MISSING: type=y:bool
2393+
let x = value1; // $ type=x:i32
2394+
let y = value2; // $ type=y:bool
23952395
();
23962396
}
23972397

@@ -2401,13 +2401,13 @@ pub mod pattern_matching {
24012401
};
24022402
match my_enum1 {
24032403
MyEnum::Variant1 { value1, value2 } => {
2404-
let x = value1; // $ MISSING: type=x:i32
2405-
let y = value2; // $ MISSING: type=y:bool
2404+
let x = value1; // $ type=x:i32
2405+
let y = value2; // $ type=y:bool
24062406
();
24072407
}
24082408
MyEnum::Variant2(value1, value2) => {
2409-
let x = value1; // $ MISSING: type=x:bool
2410-
let y = value2; // $ MISSING: type=y:i32
2409+
let x = value1; // $ type=x:bool
2410+
let y = value2; // $ type=y:i32
24112411
();
24122412
}
24132413
}

0 commit comments

Comments
 (0)