Skip to content

Rust: Type inference for pattern matching #20020

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jul 11, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
---
category: minorAnalysis
---
* Type inference has been extended to support pattern matching.
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,20 @@ module Impl {
name = this.getStructPatFieldList().getAField().getFieldName()
}

/** Gets the record field that matches the `name` pattern of this pattern. */
/** Gets the struct field that matches the `name` pattern of this pattern. */
pragma[nomagic]
StructField getStructField(string name) {
exists(PathResolution::ItemNode i | i = this.getResolvedPath(name) |
result.isStructField(i, name) or
result.isVariantField(i, name)
)
}

/** Gets the struct pattern for the field `name`. */
pragma[nomagic]
StructPatField getPatField(string name) {
result = this.getStructPatFieldList().getAField() and
name = result.getFieldName()
}
}
}
166 changes: 153 additions & 13 deletions rust/ql/lib/codeql/rust/internal/TypeInference.qll
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ private predicate typeEquality(AstNode n1, TypePath prefix1, AstNode n2, TypePat
prefix2.isEmpty() and
(
exists(Variable v | n1 = v.getAnAccess() |
n2 = v.getPat()
n2 = v.getPat().getName()
or
n2 = v.getParameter().(SelfParam)
)
Expand All @@ -276,6 +276,22 @@ private predicate typeEquality(AstNode n1, TypePath prefix1, AstNode n2, TypePat
or
n1 = n2.(MatchExpr).getAnArm().getExpr()
or
exists(LetExpr let |
n1 = let.getScrutinee() and
n2 = let.getPat()
)
or
exists(MatchExpr me |
n1 = me.getScrutinee() and
n2 = me.getAnArm().getPat()
)
or
n1 = n2.(OrPat).getAPat()
or
n1 = n2.(ParenPat).getPat()
or
n1 = n2.(LiteralPat).getLiteral()
or
exists(BreakExpr break |
break.getExpr() = n1 and
break.getTarget() = n2.(LoopExpr)
Expand All @@ -287,9 +303,21 @@ private predicate typeEquality(AstNode n1, TypePath prefix1, AstNode n2, TypePat
)
or
n1 = n2.(MacroExpr).getMacroCall().getMacroCallExpansion()
or
n1 = n2.(MacroPat).getMacroCall().getMacroCallExpansion()
)
or
n1 = n2.(RefExpr).getExpr() and
n1 =
any(IdentPat ip |
n2 = ip.getName() and
prefix1.isEmpty() and
if ip.isRef() then prefix2 = TypePath::singleton(TRefTypeParameter()) else prefix2.isEmpty()
)
or
(
n1 = n2.(RefExpr).getExpr() or
n1 = n2.(RefPat).getPat()
) and
prefix1.isEmpty() and
prefix2 = TypePath::singleton(TRefTypeParameter())
or
Expand Down Expand Up @@ -478,15 +506,10 @@ private module StructExprMatchingInput implements MatchingInputSig {
Type getInferredType(AccessPosition apos, TypePath path) {
result = inferType(this.getNodeAt(apos), path)
or
// The struct type is supplied explicitly as a type qualifier, e.g.
// The struct/enum type is supplied explicitly as a type qualifier, e.g.
// `Foo<Bar>::Variant { ... }`.
apos.isStructPos() and
exists(Path p, TypeMention tm |
p = this.getPath() and
if resolvePath(p) instanceof Variant then tm = p.getQualifier() else tm = p
|
result = tm.resolveTypeAt(path)
)
result = this.getPath().(TypeMention).resolveTypeAt(path)
}

Declaration getTarget() { result = resolvePath(this.getPath()) }
Expand Down Expand Up @@ -576,7 +599,7 @@ private module CallExprBaseMatchingInput implements MatchingInputSig {
}
}

abstract private class TupleDeclaration extends Declaration {
abstract additional class TupleDeclaration extends Declaration {
override Type getDeclaredType(DeclarationPosition dpos, TypePath path) {
result = super.getDeclaredType(dpos, path)
or
Expand Down Expand Up @@ -1032,9 +1055,18 @@ private Type inferFieldExprType(AstNode n, TypePath path) {
)
}

/** Gets the root type of the reference expression `re`. */
/** Gets the root type of the reference node `ref`. */
pragma[nomagic]
private Type inferRefExprType(RefExpr re) { exists(re) and result = TRefType() }
private Type inferRefNodeType(AstNode ref) {
(
ref = any(IdentPat ip | ip.isRef()).getName()
or
ref instanceof RefExpr
or
ref instanceof RefPat
) and
result = TRefType()
}

pragma[nomagic]
private Type inferTryExprType(TryExpr te, TypePath path) {
Expand Down Expand Up @@ -1178,6 +1210,110 @@ private Type inferIndexExprType(IndexExpr ie, TypePath path) {
)
}

/**
* A matching configuration for resolving types of struct patterns
* like `let Foo { bar } = ...`.
*/
private module StructPatMatchingInput implements MatchingInputSig {
class DeclarationPosition = StructExprMatchingInput::DeclarationPosition;

class Declaration = StructExprMatchingInput::Declaration;

class AccessPosition = DeclarationPosition;

class Access extends StructPat {
Type getTypeArgument(TypeArgumentPosition apos, TypePath path) { none() }
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this none()?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since #19847, getTypeArgument is only used for type arguments supplied to functions, not to their defining type, i.e. it is used for foo::<i32>(...) but not for Bar::<i32>::foo(...).

Since explicit type arguments in pattern matching corresponds to the latter, it is none() here, and instead handled in getInferredType below. I'll add some tests for this.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Adding test, of course, revealed some corner cases that were not handled, I'm gonna do that on this PR as well.


AstNode getNodeAt(AccessPosition apos) {
result = this.getPatField(apos.asFieldPos()).getPat()
or
result = this and
apos.isStructPos()
}

Type getInferredType(AccessPosition apos, TypePath path) {
result = inferType(this.getNodeAt(apos), path)
or
// The struct/enum type is supplied explicitly as a type qualifier, e.g.
// `let Foo<Bar>::Variant { ... } = ...`.
apos.isStructPos() and
result = this.getPath().(TypeMention).resolveTypeAt(path)
}

Declaration getTarget() { result = resolvePath(this.getPath()) }
}

predicate accessDeclarationPositionMatch(AccessPosition apos, DeclarationPosition dpos) {
apos = dpos
}
}

private module StructPatMatching = Matching<StructPatMatchingInput>;

/**
* Gets the type of `n` at `path`, where `n` is either a struct pattern or
* a field pattern of a struct pattern.
*/
pragma[nomagic]
private Type inferStructPatType(AstNode n, TypePath path) {
exists(StructPatMatchingInput::Access a, StructPatMatchingInput::AccessPosition apos |
n = a.getNodeAt(apos) and
result = StructPatMatching::inferAccessType(a, apos, path)
)
}

/**
* A matching configuration for resolving types of tuple struct patterns
* like `let Some(x) = ...`.
*/
private module TupleStructPatMatchingInput implements MatchingInputSig {
class DeclarationPosition = CallExprBaseMatchingInput::DeclarationPosition;

class Declaration = CallExprBaseMatchingInput::TupleDeclaration;

class AccessPosition = DeclarationPosition;

class Access extends TupleStructPat {
Type getTypeArgument(TypeArgumentPosition apos, TypePath path) { none() }

AstNode getNodeAt(AccessPosition apos) {
result = this.getField(apos.asPosition())
or
result = this and
apos.isSelf()
}

Type getInferredType(AccessPosition apos, TypePath path) {
result = inferType(this.getNodeAt(apos), path)
or
// The struct/enum type is supplied explicitly as a type qualifier, e.g.
// `let Option::<Foo>::Some(x) = ...`.
apos.isSelf() and
result = this.getPath().(TypeMention).resolveTypeAt(path)
}

Declaration getTarget() { result = resolvePath(this.getPath()) }
}

predicate accessDeclarationPositionMatch(AccessPosition apos, DeclarationPosition dpos) {
apos = dpos
}
}

private module TupleStructPatMatching = Matching<TupleStructPatMatchingInput>;

/**
* Gets the type of `n` at `path`, where `n` is either a tuple struct pattern or
* a positional pattern of a tuple struct pattern.
*/
pragma[nomagic]
private Type inferTupleStructPatType(AstNode n, TypePath path) {
exists(TupleStructPatMatchingInput::Access a, TupleStructPatMatchingInput::AccessPosition apos |
n = a.getNodeAt(apos) and
result = TupleStructPatMatching::inferAccessType(a, apos, path)
)
}

final private class ForIterableExpr extends Expr {
ForIterableExpr() { this = any(ForExpr fe).getIterable() }

Expand Down Expand Up @@ -1813,7 +1949,7 @@ private module Cached {
or
result = inferFieldExprType(n, path)
or
result = inferRefExprType(n) and
result = inferRefNodeType(n) and
path.isEmpty()
or
result = inferTryExprType(n, path)
Expand All @@ -1836,6 +1972,10 @@ private module Cached {
result = inferForLoopExprType(n, path)
or
result = inferCastExprType(n, path)
or
result = inferStructPatType(n, path)
or
result = inferTupleStructPatType(n, path)
}
}

Expand Down
12 changes: 10 additions & 2 deletions rust/ql/lib/codeql/rust/internal/TypeMention.qll
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,13 @@ class SliceTypeReprMention extends TypeMention instanceof SliceTypeRepr {
class PathTypeMention extends TypeMention, Path {
TypeItemNode resolved;

PathTypeMention() { resolved = resolvePath(this) }
PathTypeMention() {
resolved = resolvePath(this)
or
resolved = resolvePath(this).(Variant).getEnum()
}

ItemNode getResolved() { result = resolved }
TypeItemNode getResolved() { result = resolved }

pragma[nomagic]
private TypeAlias getResolvedTraitAlias(string name) {
Expand Down Expand Up @@ -99,6 +103,10 @@ class PathTypeMention extends TypeMention, Path {
this = node.getASelfPath() and
result = node.(ImplItemNode).getSelfPath().getSegment().getGenericArgList().getTypeArg(i)
)
or
// `Option::<i32>::Some` is valid in addition to `Option::Some::<i32>`
resolvePath(this) instanceof Variant and
result = this.getQualifier().getSegment().getGenericArgList().getTypeArg(i)
}

private TypeMention getPositionalTypeArgument(int i) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,5 +79,11 @@ multipleCallTargets
| test_futures_io.rs:93:26:93:63 | pinned.poll_read(...) |
| test_futures_io.rs:116:22:116:50 | pinned.poll_fill_buf(...) |
| test_futures_io.rs:145:26:145:49 | ...::with_capacity(...) |
| web_frameworks.rs:13:14:13:22 | a.as_str() |
| web_frameworks.rs:13:14:13:23 | a.as_str() |
| web_frameworks.rs:14:14:14:24 | a.as_bytes() |
| web_frameworks.rs:14:14:14:25 | a.as_bytes() |
| web_frameworks.rs:101:14:101:23 | a.as_str() |
| web_frameworks.rs:102:14:102:25 | a.as_bytes() |
| web_frameworks.rs:158:14:158:23 | a.as_str() |
| web_frameworks.rs:159:14:159:25 | a.as_bytes() |
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ mod poem_test {
#[handler]
fn my_poem_handler_1(Path(a): Path<String>, // $ Alert[rust/summary/taint-sources]
) -> String {
sink(a.as_str()); // $ MISSING: hasTaintFlow -- no type inference for patterns
sink(a.as_bytes()); // $ MISSING: hasTaintFlow -- no type inference for patterns
sink(a.as_str()); // $ hasTaintFlow
sink(a.as_bytes()); // $ hasTaintFlow
sink(a); // $ hasTaintFlow

"".to_string()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,6 @@ multipleCallTargets
| test.rs:302:7:302:48 | ... .as_str() |
| test.rs:303:7:303:35 | ... .as_str() |
| test.rs:304:7:304:35 | ... .as_str() |
| test.rs:313:8:313:19 | num.as_str() |
| test.rs:324:8:324:19 | num.as_str() |
| test.rs:343:7:343:39 | ... .as_str() |
Loading