diff --git a/java/ql/lib/semmle/code/java/security/CsrfUnprotectedRequestTypeQuery.qll b/java/ql/lib/semmle/code/java/security/CsrfUnprotectedRequestTypeQuery.qll index 42d6db246c0e..80d100d3d9e1 100644 --- a/java/ql/lib/semmle/code/java/security/CsrfUnprotectedRequestTypeQuery.qll +++ b/java/ql/lib/semmle/code/java/security/CsrfUnprotectedRequestTypeQuery.qll @@ -237,12 +237,35 @@ private predicate sink(CallPathNode sinkMethodCall) { ) } +private predicate fwdFlow(CallPathNode n) { + source(n) + or + exists(CallPathNode mid | fwdFlow(mid) and CallGraph::edges(mid, n)) +} + +private predicate revFlow(CallPathNode n) { + fwdFlow(n) and + ( + sink(n) + or + exists(CallPathNode mid | revFlow(mid) and CallGraph::edges(n, mid)) + ) +} + +/** + * Holds if `pred` has a successor node `succ` and this edge is in an + * `unprotectedStateChange` path. + */ +predicate relevantEdge(CallPathNode pred, CallPathNode succ) { + CallGraph::edges(pred, succ) and revFlow(pred) and revFlow(succ) +} + /** * Holds if `sourceMethod` is an unprotected request handler that reaches a * `sinkMethodCall` that updates a database. */ private predicate unprotectedDatabaseUpdate(CallPathNode sourceMethod, CallPathNode sinkMethodCall) = - doublyBoundedFastTC(CallGraph::edges/2, source/1, sink/1)(sourceMethod, sinkMethodCall) + doublyBoundedFastTC(relevantEdge/2, source/1, sink/1)(sourceMethod, sinkMethodCall) /** * Holds if `sourceMethod` is an unprotected request handler that appears to diff --git a/java/ql/src/Security/CWE/CWE-352/CsrfUnprotectedRequestType.ql b/java/ql/src/Security/CWE/CWE-352/CsrfUnprotectedRequestType.ql index e338cb84c005..cf5c0b385ccf 100644 --- a/java/ql/src/Security/CWE/CWE-352/CsrfUnprotectedRequestType.ql +++ b/java/ql/src/Security/CWE/CWE-352/CsrfUnprotectedRequestType.ql @@ -15,7 +15,7 @@ import java import semmle.code.java.security.CsrfUnprotectedRequestTypeQuery -query predicate edges(CallPathNode pred, CallPathNode succ) { CallGraph::edges(pred, succ) } +query predicate edges(CallPathNode pred, CallPathNode succ) { relevantEdge(pred, succ) } from CallPathNode source, CallPathNode sink where unprotectedStateChange(source, sink)