Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 51 additions & 67 deletions rust/ql/lib/codeql/rust/internal/TypeInference.qll
Original file line number Diff line number Diff line change
Expand Up @@ -699,7 +699,7 @@ private module CallExprBaseMatchingInput implements MatchingInputSig {
}

Declaration getTarget() {
result = inferMethodCallTarget(this) // mutual recursion; resolving method calls requires resolving types and vice versa
result = resolveMethodCallTarget(this) // mutual recursion; resolving method calls requires resolving types and vice versa
or
result = CallExprImpl::getResolvedFunction(this)
}
Expand Down Expand Up @@ -1178,14 +1178,14 @@ private predicate methodCandidateTrait(Type type, Trait trait, string name, int
methodCandidate(type, name, arity, impl)
}

private module IsInstantiationOfInput implements IsInstantiationOfInputSig<MethodCall> {
pragma[nomagic]
private predicate isMethodCall(MethodCall mc, Type rootType, string name, int arity) {
rootType = mc.getTypeAt(TypePath::nil()) and
name = mc.getMethodName() and
arity = mc.getNumberOfArguments()
}
pragma[nomagic]
private predicate isMethodCall(MethodCall mc, Type rootType, string name, int arity) {
rootType = mc.getTypeAt(TypePath::nil()) and
name = mc.getMethodName() and
arity = mc.getNumberOfArguments()
}

private module IsInstantiationOfInput implements IsInstantiationOfInputSig<MethodCall> {
pragma[nomagic]
predicate potentialInstantiationOf(MethodCall mc, TypeAbstraction impl, TypeMention constraint) {
exists(Type rootType, string name, int arity |
Expand Down Expand Up @@ -1334,17 +1334,46 @@ private predicate methodResolutionDependsOnArgument(
)
}

/**
* Holds if the method call `mc` has no inherent target, i.e., it does not
* resolve to a method in an `impl` block for the type of the receiver.
*/
pragma[nomagic]
private predicate methodCallHasNoInherentTarget(MethodCall mc) {
exists(Type rootType, string name, int arity |
isMethodCall(mc, rootType, name, arity) and
forall(Impl impl |
methodCandidate(rootType, name, arity, impl) and
not impl.hasTrait()
|
IsInstantiationOf<MethodCall, IsInstantiationOfInput>::isNotInstantiationOf(mc, impl, _)
)
)
}

pragma[nomagic]
private predicate methodCallHasImplCandidate(MethodCall mc, Impl impl) {
IsInstantiationOf<MethodCall, IsInstantiationOfInput>::isInstantiationOf(mc, impl, _) and
if impl.hasTrait() and not exists(mc.getTrait())
then
// inherent methods take precedence over trait methods, so only allow
// trait methods when there are no matching inherent methods
methodCallHasNoInherentTarget(mc)
else any()
}

/** Gets a method from an `impl` block that matches the method call `mc`. */
pragma[nomagic]
private Function getMethodFromImpl(MethodCall mc) {
exists(Impl impl |
IsInstantiationOf<MethodCall, IsInstantiationOfInput>::isInstantiationOf(mc, impl, _) and
result = getMethodSuccessor(impl, mc.getMethodName())
exists(Impl impl, string name |
methodCallHasImplCandidate(mc, impl) and
name = mc.getMethodName() and
result = getMethodSuccessor(impl, name)
|
not methodResolutionDependsOnArgument(impl, _, _, _, _, _) and
result = getMethodSuccessor(impl, mc.getMethodName())
not methodResolutionDependsOnArgument(impl, _, _, _, _, _)
or
exists(int pos, TypePath path, Type type |
methodResolutionDependsOnArgument(impl, mc.getMethodName(), result, pos, path, type) and
methodResolutionDependsOnArgument(impl, name, result, pos, path, type) and
inferType(mc.getPositionalArgument(pos), path) = type
)
)
Expand All @@ -1356,22 +1385,6 @@ private Function getTraitMethod(ImplTraitReturnType trait, string name) {
result = getMethodSuccessor(trait.getImplTraitTypeRepr(), name)
}

/**
* Gets a method that the method call `mc` resolves to based on type inference,
* if any.
*/
private Function inferMethodCallTarget(MethodCall mc) {
// The method comes from an `impl` block targeting the type of the receiver.
result = getMethodFromImpl(mc)
or
// The type of the receiver is a type parameter and the method comes from a
// trait bound on the type parameter.
result = getTypeParameterMethod(mc.getTypeAt(TypePath::nil()), mc.getMethodName())
or
// The type of the receiver is an `impl Trait` type.
result = getTraitMethod(mc.getTypeAt(TypePath::nil()), mc.getMethodName())
}

cached
private module Cached {
private import codeql.rust.internal.CachedStages
Expand Down Expand Up @@ -1400,47 +1413,18 @@ private module Cached {
)
}

private predicate isInherentImplFunction(Function f) {
f = any(Impl impl | not impl.hasTrait()).(ImplItemNode).getAnAssocItem()
}

private predicate isTraitImplFunction(Function f) {
f = any(Impl impl | impl.hasTrait()).(ImplItemNode).getAnAssocItem()
}

private Function resolveMethodCallTargetFrom(MethodCall mc, boolean fromSource) {
result = inferMethodCallTarget(mc) and
(if result.fromSource() then fromSource = true else fromSource = false) and
(
// prioritize inherent implementation methods first
isInherentImplFunction(result)
or
not isInherentImplFunction(inferMethodCallTarget(mc)) and
(
// then trait implementation methods
isTraitImplFunction(result)
or
not isTraitImplFunction(inferMethodCallTarget(mc)) and
(
// then trait methods with default implementations
result.hasBody()
or
// and finally trait methods without default implementations
not inferMethodCallTarget(mc).hasBody()
)
)
)
}

/** Gets a method that the method call `mc` resolves to, if any. */
cached
Function resolveMethodCallTarget(MethodCall mc) {
// Functions in source code also gets extracted as library code, due to
// this duplication we prioritize functions from source code.
result = resolveMethodCallTargetFrom(mc, true)
// The method comes from an `impl` block targeting the type of the receiver.
result = getMethodFromImpl(mc)
or
// The type of the receiver is a type parameter and the method comes from a
// trait bound on the type parameter.
result = getTypeParameterMethod(mc.getTypeAt(TypePath::nil()), mc.getMethodName())
or
not exists(resolveMethodCallTargetFrom(mc, true)) and
result = resolveMethodCallTargetFrom(mc, false)
// The type of the receiver is an `impl Trait` type.
result = getTraitMethod(mc.getTypeAt(TypePath::nil()), mc.getMethodName())
}

pragma[inline]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@ multipleCallTargets
| test.rs:168:26:168:111 | ...::_print(...) |
| test.rs:178:30:178:68 | ...::_print(...) |
| test.rs:187:26:187:105 | ...::_print(...) |
| test.rs:228:22:228:72 | ... .read_to_string(...) |
| test.rs:482:22:482:50 | file.read_to_end(...) |
| test.rs:488:22:488:53 | file.read_to_string(...) |
| test.rs:609:18:609:38 | ...::_print(...) |
| test.rs:614:18:614:45 | ...::_print(...) |
| test.rs:618:25:618:49 | address.to_socket_addrs() |
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
multipleCallTargets
| dereference.rs:61:15:61:24 | e1.deref() |
| main.rs:1963:13:1963:31 | ...::from(...) |
| main.rs:1964:13:1964:31 | ...::from(...) |
| main.rs:1965:13:1965:31 | ...::from(...) |
| main.rs:1970:13:1970:31 | ...::from(...) |
| main.rs:1971:13:1971:31 | ...::from(...) |
| main.rs:1972:13:1972:31 | ...::from(...) |
| main.rs:2006:21:2006:43 | ...::from(...) |
| main.rs:2032:13:2032:31 | ...::from(...) |
| main.rs:2033:13:2033:31 | ...::from(...) |
| main.rs:2034:13:2034:31 | ...::from(...) |
| main.rs:2040:13:2040:31 | ...::from(...) |
| main.rs:2041:13:2041:31 | ...::from(...) |
| main.rs:2042:13:2042:31 | ...::from(...) |
| main.rs:2078:21:2078:43 | ...::from(...) |
Loading