Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions rust/ql/lib/codeql/rust/frameworks/stdlib/Builtins.qll
Original file line number Diff line number Diff line change
Expand Up @@ -176,15 +176,18 @@ class RefMutType extends BuiltinType {
override string getDisplayName() { result = "&mut" }
}

/** The builtin pointer type `*const T`. */
class PtrType extends BuiltinType {
PtrType() { this.getName() = "Ptr" }
/** A builtin raw pointer type `*const T` or `*mut T`. */
abstract class PtrType extends BuiltinType { }

/** The builtin raw pointer type `*const T`. */
class PtrConstType extends PtrType {
PtrConstType() { this.getName() = "PtrConst" }

override string getDisplayName() { result = "*const" }
}

/** The builtin pointer type `*mut T`. */
class PtrMutType extends BuiltinType {
/** The builtin raw pointer type `*mut T`. */
class PtrMutType extends PtrType {
PtrMutType() { this.getName() = "PtrMut" }

override string getDisplayName() { result = "*mut" }
Expand Down
28 changes: 20 additions & 8 deletions rust/ql/lib/codeql/rust/internal/Type.qll
Original file line number Diff line number Diff line change
Expand Up @@ -339,10 +339,27 @@ class NeverType extends Type, TNeverType {
override Location getLocation() { result instanceof EmptyLocation }
}

class PtrType extends StructType {
PtrType() { this.getStruct() instanceof Builtins::PtrType }
abstract class PtrType extends StructType {
override Location getLocation() { result instanceof EmptyLocation }
}

pragma[nomagic]
TypeParamTypeParameter getPtrTypeParameter() {
result = any(PtrType t).getPositionalTypeParameter(0)
}

class PtrMutType extends PtrType {
PtrMutType() { this.getStruct() instanceof Builtins::PtrMutType }

override string toString() { result = "*mut" }

override string toString() { result = "*" }
override Location getLocation() { result instanceof EmptyLocation }
}

class PtrConstType extends PtrType {
PtrConstType() { this.getStruct() instanceof Builtins::PtrConstType }

override string toString() { result = "*const" }

override Location getLocation() { result instanceof EmptyLocation }
}
Expand Down Expand Up @@ -377,11 +394,6 @@ class UnknownType extends Type, TUnknownType {
override Location getLocation() { result instanceof EmptyLocation }
}

pragma[nomagic]
TypeParamTypeParameter getPtrTypeParameter() {
result = any(PtrType t).getPositionalTypeParameter(0)
}

/** A type parameter. */
abstract class TypeParameter extends Type {
override TypeParameter getPositionalTypeParameter(int i) { none() }
Expand Down
68 changes: 51 additions & 17 deletions rust/ql/lib/codeql/rust/internal/TypeInference.qll
Original file line number Diff line number Diff line change
Expand Up @@ -431,7 +431,10 @@ module CertainTypeInference {
or
result = inferLiteralType(n, path, true)
or
result = inferRefNodeType(n) and
result = inferRefPatType(n) and
path.isEmpty()
or
result = inferRefExprType(n) and
path.isEmpty()
or
result = inferLogicalOperationType(n, path)
Expand Down Expand Up @@ -606,10 +609,14 @@ private predicate typeEquality(AstNode n1, TypePath prefix1, AstNode n2, TypePat
strictcount(Expr e | bodyReturns(n1, e)) = 1
)
or
(
n1 = n2.(RefExpr).getExpr() or
n1 = n2.(RefPat).getPat()
) and
exists(RefExpr re |
n2 = re and
n1 = re.getExpr() and
prefix1.isEmpty() and
prefix2 = TypePath::singleton(inferRefExprType(re).getPositionalTypeParameter(0))
)
or
n1 = n2.(RefPat).getPat() and
prefix1.isEmpty() and
prefix2 = TypePath::singleton(getRefTypeParameter())
or
Expand Down Expand Up @@ -709,9 +716,7 @@ private predicate lubCoercion(AstNode parent, AstNode child, TypePath prefix) {
* of `n2` at `prefix2`, but type information should only propagate from `n1` to
* `n2`.
*/
private predicate typeEqualityNonSymmetric(
AstNode n1, TypePath prefix1, AstNode n2, TypePath prefix2
) {
private predicate typeEqualityAsymmetric(AstNode n1, TypePath prefix1, AstNode n2, TypePath prefix2) {
lubCoercion(n2, n1, prefix2) and
prefix1.isEmpty()
or
Expand All @@ -723,6 +728,13 @@ private predicate typeEqualityNonSymmetric(
not lubCoercion(mid, n1, _) and
prefix1 = prefixMid.append(suffix)
)
or
// When `n2` is `*n1` propagate type information from a raw pointer type
// parameter at `n1`. The other direction is handled in
// `inferDereferencedExprPtrType`.
n1 = n2.(DerefExpr).getExpr() and
prefix1 = TypePath::singleton(getPtrTypeParameter()) and
prefix2.isEmpty()
}

pragma[nomagic]
Expand All @@ -735,7 +747,7 @@ private Type inferTypeEquality(AstNode n, TypePath path) {
or
typeEquality(n2, prefix2, n, prefix1)
or
typeEqualityNonSymmetric(n2, prefix2, n, prefix1)
typeEqualityAsymmetric(n2, prefix2, n, prefix1)
)
}

Expand Down Expand Up @@ -2952,16 +2964,21 @@ private Type inferFieldExprType(AstNode n, TypePath path) {
)
}

/** Gets the root type of the reference node `ref`. */
/** Gets the root type of the reference expression `ref`. */
pragma[nomagic]
private Type inferRefNodeType(AstNode ref) {
(
ref = any(IdentPat ip | ip.isRef()).getName()
or
ref instanceof RefExpr
private Type inferRefExprType(RefExpr ref) {
if ref.isRaw()
then
ref.isMut() and result instanceof PtrMutType
or
ref instanceof RefPat
) and
ref.isConst() and result instanceof PtrConstType
else result instanceof RefType
}

/** Gets the root type of the reference node `ref`. */
pragma[nomagic]
private Type inferRefPatType(AstNode ref) {
(ref = any(IdentPat ip | ip.isRef()).getName() or ref instanceof RefPat) and
result instanceof RefType
}

Expand Down Expand Up @@ -3145,6 +3162,21 @@ private Type inferIndexExprType(IndexExpr ie, TypePath path) {
)
}

/**
* Gets the inferred type of `n` at `path` when `n` occurs in a dereference
* expression `*n` and when `n` is known to have a raw pointer type.
*
* The other direction is handled in `typeEqualityAsymmetric`.
*/
private Type inferDereferencedExprPtrType(AstNode n, TypePath path) {
exists(DerefExpr de, PtrType type, TypePath suffix |
de.getExpr() = n and
type = inferType(de.getExpr()) and
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is equivalent, assuming a DerefExpr only has one getExpr():

Suggested change
type = inferType(de.getExpr()) and
type = inferType(n) and

result = inferType(de, suffix) and
path = TypePath::cons(type.getPositionalTypeParameter(0), suffix)
)
}

/**
* A matching configuration for resolving types of struct patterns
* like `let Foo { bar } = ...`.
Expand Down Expand Up @@ -3544,6 +3576,8 @@ private module Cached {
or
result = inferIndexExprType(n, path)
or
result = inferDereferencedExprPtrType(n, path)
or
result = inferForLoopExprType(n, path)
or
result = inferDynamicCallExprType(n, path)
Expand Down
11 changes: 8 additions & 3 deletions rust/ql/lib/codeql/rust/internal/TypeMention.qll
Original file line number Diff line number Diff line change
Expand Up @@ -556,13 +556,18 @@ class NeverTypeReprMention extends TypeMention, NeverTypeRepr {
}

class PtrTypeReprMention extends TypeMention instanceof PtrTypeRepr {
private PtrType resolveRootType() {
super.isConst() and result instanceof PtrConstType
or
super.isMut() and result instanceof PtrMutType
}

override Type resolveTypeAt(TypePath path) {
path.isEmpty() and
result instanceof PtrType
path.isEmpty() and result = this.resolveRootType()
or
exists(TypePath suffix |
result = super.getTypeRepr().(TypeMention).resolveTypeAt(suffix) and
path = TypePath::cons(getPtrTypeParameter(), suffix)
path = TypePath::cons(this.resolveRootType().getPositionalTypeParameter(0), suffix)
)
}
}
18 changes: 9 additions & 9 deletions rust/ql/test/library-tests/type-inference/raw_pointer.rs
Original file line number Diff line number Diff line change
@@ -1,47 +1,47 @@
use std::ptr::null_mut;

fn raw_pointer_const_deref(x: *const i32) -> i32 {
let _y = unsafe { *x }; // $ MISSING: type=_y:i32
let _y = unsafe { *x }; // $ type=_y:i32
0
}

fn raw_pointer_mut_deref(x: *mut bool) -> i32 {
let _y = unsafe { *x }; // $ MISSING: type=_y:bool
let _y = unsafe { *x }; // $ type=_y:bool
0
}

fn raw_const_borrow() {
let a: i64 = 10;
let x = &raw const a; // $ MISSING: type=x:TPtrConst.i64
let x = &raw const a; // $ type=x:TPtrConst.i64
unsafe {
let _y = *x; // $ type=_y:i64 SPURIOUS: target=deref
let _y = *x; // $ type=_y:i64
}
}

fn raw_mut_borrow() {
let mut a = 10i32;
let x = &raw mut a; // $ MISSING: type=x:TPtrMut.i32
let x = &raw mut a; // $ type=x:TPtrMut.i32
unsafe {
let _y = *x; // $ type=_y:i32 SPURIOUS: target=deref
let _y = *x; // $ type=_y:i32
}
}

fn raw_mut_write(cond: bool) {
let a = 10i32;
// The type of `x` must be inferred from the write below.
let ptr_written = null_mut(); // $ target=null_mut MISSING: type=ptr_written:TPtrMut.i32
let ptr_written = null_mut(); // $ target=null_mut type=ptr_written:TPtrMut.i32
if cond {
unsafe {
// NOTE: This write is undefined behavior because `x` is a null pointer.
*ptr_written = a;
let _y = *ptr_written; // $ MISSING: type=_y:i32
let _y = *ptr_written; // $ type=_y:i32
}
}
}

fn raw_type_from_deref(cond: bool) {
// The type of `x` must be inferred from the read below.
let ptr_read = null_mut(); // $ target=null_mut MISSING: type=ptr_read:TPtrMut.i64
let ptr_read = null_mut(); // $ target=null_mut type=ptr_read:TPtrMut.i64
if cond {
unsafe {
// NOTE: This read is undefined behavior because `x` is a null pointer.
Expand Down
Loading