[SPARK-43098][SQL] Fix correctness COUNT bug when scalar subquery has group by clause

### What changes were proposed in this pull request?
Fix a correctness bug for scalar subqueries with COUNT and a GROUP BY clause, for example:
```
create view t1(c1, c2) as values (0, 1), (1, 2);
create view t2(c1, c2) as values (0, 2), (0, 3);

select c1, c2, (select count(*) from t2 where t1.c1 = t2.c1 group by c1) from t1;

-- Correct answer: [(0, 1, 2), (1, 2, null)]
-- Incorrect result returned before this fix:
+---+---+------------------+
|c1 |c2 |scalarsubquery(c1)|
+---+---+------------------+
|0  |1  |2                 |
|1  |2  |0                 |
+---+---+------------------+
```

This is due to a bug in our "COUNT bug" handling for scalar subqueries. For a subquery with a COUNT aggregate but no GROUP BY clause, 0 is the correct output on empty input, and we use the COUNT bug handling to construct a plan that yields 0 when there are no matching rows.

But when there is a GROUP BY clause, NULL is the correct output (i.e. there is no COUNT bug), yet we still applied the COUNT bug handling and therefore incorrectly output 0. Instead, we must apply the COUNT bug handling only when the scalar subquery has no GROUP BY clause.
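
To make the semantics concrete, here is a small illustration using the t2 view from the example above (42 is an arbitrary value that matches no row):
```
-- Without GROUP BY, the aggregate always produces exactly one row, so COUNT yields 0:
select count(*) from t2 where c1 = 42;             -- returns 0
-- With GROUP BY, there are no groups, so no rows are produced at all;
-- a scalar subquery wrapping this query therefore evaluates to NULL:
select count(*) from t2 where c1 = 42 group by c1; -- returns no rows
```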

To fix this, we need to track whether the scalar subquery has a GROUP BY, i.e. non-empty groupingExpressions on the Aggregate node. This needs to be checked before subquery decorrelation, because decorrelation adds the correlated outer references to the group-by list, after which the group-by is always non-empty. We save the result in a boolean on the ScalarSubquery node until we rewrite the subquery into a join in constructLeftJoins.
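
For intuition, the COUNT bug handling rewrites the correlated scalar subquery into a left outer join and then patches the NULLs produced by unmatched rows back to 0. A simplified, hypothetical sketch of the rewritten query (not the literal plan Spark builds):
```
select t1.c1, t1.c2,
       -- The COUNT-bug patch: turn NULL from unmatched rows into 0.
       -- Correct when the subquery had no GROUP BY; wrong when it had one.
       coalesce(cnt.cnt, 0)
from t1 left outer join
     (select c1, count(*) as cnt from t2 group by c1) cnt
on t1.c1 = cnt.c1;
```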

This is a long-standing bug: it affects both the current DecorrelateInnerQuery framework and the legacy code path (with spark.sql.optimizer.decorrelateInnerQuery.enabled = false), and this PR fixes both.

### Why are the changes needed?
Fix a correctness bug.

### Does this PR introduce _any_ user-facing change?
Yes, it fixes incorrect query results.
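
Users who need the old (incorrect) behavior can temporarily restore it via the internal escape-hatch flag added by this PR (see the SQLConf change below):
```
set spark.sql.optimizer.decorrelateSubqueryLegacyIncorrectCountHandling.enabled = true;
```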

### How was this patch tested?
Added SQL tests and unit tests. (Note that two existing unit tests covered queries of this shape and had the incorrect results recorded as golden results.)

Closes apache#40811 from jchen5/count-bug.

Authored-by: Jack Chen <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
jchen5 authored and cloud-fan committed Apr 19, 2023
1 parent 780aeec commit c1a02e7
Showing 11 changed files with 559 additions and 15 deletions.
@@ -2418,7 +2418,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
*/
private def resolveSubQueries(plan: LogicalPlan, outer: LogicalPlan): LogicalPlan = {
plan.transformAllExpressionsWithPruning(_.containsPattern(PLAN_EXPRESSION), ruleId) {
- case s @ ScalarSubquery(sub, _, exprId, _, _) if !sub.resolved =>
+ case s @ ScalarSubquery(sub, _, exprId, _, _, _) if !sub.resolved =>
resolveSubQuery(s, outer)(ScalarSubquery(_, _, exprId))
case e @ Exists(sub, _, exprId, _, _) if !sub.resolved =>
resolveSubQuery(e, outer)(Exists(_, _, exprId))
@@ -935,7 +935,7 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB
checkOuterReference(plan, expr)

expr match {
- case ScalarSubquery(query, outerAttrs, _, _, _) =>
+ case ScalarSubquery(query, outerAttrs, _, _, _, _) =>
// Scalar subquery must return one column as output.
if (query.output.size != 1) {
expr.failAnalysis(
@@ -254,13 +254,20 @@ object SubExprUtils extends PredicateHelper {
* scalar subquery during planning.
*
* Note: `exprId` is used to have a unique name in explain string output.
+ *
+ * `mayHaveCountBug` is whether it's possible for the subquery to evaluate to non-null on
+ * empty input (zero tuples). It is false if the subquery has a GROUP BY clause: in that case
+ * the subquery produces no rows at all on empty input, and a scalar subquery over no rows
+ * evaluates to NULL. It is set to true/false in PullupCorrelatedPredicates; before that rule
+ * runs, its value is None. See constructLeftJoins in RewriteCorrelatedScalarSubquery for more
+ * details.
*/
case class ScalarSubquery(
plan: LogicalPlan,
outerAttrs: Seq[Expression] = Seq.empty,
exprId: ExprId = NamedExpression.newExprId,
joinCond: Seq[Expression] = Seq.empty,
- hint: Option[HintInfo] = None)
+ hint: Option[HintInfo] = None,
+ mayHaveCountBug: Option[Boolean] = None)
extends SubqueryExpression(plan, outerAttrs, exprId, joinCond, hint) with Unevaluable {
override def dataType: DataType = {
assert(plan.schema.fields.nonEmpty, "Scalar subquery should have only one column")
@@ -25,13 +25,15 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.ScalarSubquery._
import org.apache.spark.sql.catalyst.expressions.SubExprUtils._
import org.apache.spark.sql.catalyst.expressions.aggregate._
+ import org.apache.spark.sql.catalyst.optimizer.RewriteCorrelatedScalarSubquery.splitSubquery
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.catalyst.trees.TreePattern.{EXISTS_SUBQUERY, IN_SUBQUERY, LATERAL_JOIN, LIST_SUBQUERY, PLAN_EXPRESSION, SCALAR_SUBQUERY}
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
+ import org.apache.spark.util.Utils

/*
* This file defines optimization rules related to subqueries.
@@ -325,9 +327,22 @@ object PullupCorrelatedPredicates extends Rule[LogicalPlan] with PredicateHelper
}

plan.transformExpressionsWithPruning(_.containsPattern(PLAN_EXPRESSION)) {
- case ScalarSubquery(sub, children, exprId, conditions, hint) if children.nonEmpty =>
+ case ScalarSubquery(sub, children, exprId, conditions, hint, mayHaveCountBugOld)
+ if children.nonEmpty =>
val (newPlan, newCond) = decorrelate(sub, plan)
- ScalarSubquery(newPlan, children, exprId, getJoinCondition(newCond, conditions), hint)
+ val mayHaveCountBug = if (mayHaveCountBugOld.isEmpty) {
+ // Check whether the pre-rewrite subquery had empty groupingExpressions. If yes, it may
+ // be subject to the COUNT bug. If it has non-empty groupingExpressions, there is
+ // no COUNT bug.
+ val (topPart, havingNode, aggNode) = splitSubquery(sub)
+ (aggNode.isDefined && aggNode.get.groupingExpressions.isEmpty)
+ } else {
+ // For idempotency, we must save this variable the first time this rule is run, because
+ // decorrelation introduces a GROUP BY if one wasn't already present.
+ mayHaveCountBugOld.get
+ }
+ ScalarSubquery(newPlan, children, exprId, getJoinCondition(newCond, conditions),
+ hint, Some(mayHaveCountBug))
case Exists(sub, children, exprId, conditions, hint) if children.nonEmpty =>
val (newPlan, newCond) = pullOutCorrelatedPredicates(sub, plan)
Exists(newPlan, children, exprId, getJoinCondition(newCond, conditions), hint)
@@ -519,7 +534,7 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] with AliasHelpe
* (optional second part) and the Aggregate below the HAVING CLAUSE (optional third part).
* When the third part is empty, it means the subquery is a non-aggregated single-row subquery.
*/
- private def splitSubquery(
+ def splitSubquery(
plan: LogicalPlan): (Seq[LogicalPlan], Option[Filter], Option[Aggregate]) = {
val topPart = ArrayBuffer.empty[LogicalPlan]
var bottomPart: LogicalPlan = plan
@@ -569,7 +584,7 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] with AliasHelpe
subqueries: ArrayBuffer[ScalarSubquery]): (LogicalPlan, AttributeMap[Attribute]) = {
val subqueryAttrMapping = ArrayBuffer[(Attribute, Attribute)]()
val newChild = subqueries.foldLeft(child) {
- case (currentChild, ScalarSubquery(sub, _, _, conditions, subHint)) =>
+ case (currentChild, ScalarSubquery(sub, _, _, conditions, subHint, mayHaveCountBug)) =>
val query = DecorrelateInnerQuery.rewriteDomainJoins(currentChild, sub, conditions)
val origOutput = query.output.head
// The subquery appears on the right side of the join, hence add its hint to the right
@@ -581,8 +596,16 @@
currentChild.output :+ origOutput,
Join(currentChild, query, LeftOuter, conditions.reduceOption(And), joinHint))

+ if (Utils.isTesting) {
+ assert(mayHaveCountBug.isDefined)
+ }
if (resultWithZeroTups.isEmpty) {
- // CASE 1: Subquery guaranteed not to have the COUNT bug
+ // CASE 1: Subquery guaranteed not to have the COUNT bug because it evaluates to NULL
+ // with zero tuples.
planWithoutCountBug
+ } else if (!mayHaveCountBug.getOrElse(true) &&
+ !conf.getConf(SQLConf.DECORRELATE_SUBQUERY_LEGACY_INCORRECT_COUNT_HANDLING_ENABLED)) {
+ // Subquery guaranteed not to have the COUNT bug because it had a non-empty GROUP BY clause
+ planWithoutCountBug
} else {
val (topPart, havingNode, aggNode) = splitSubquery(query)
@@ -800,7 +823,7 @@ object OptimizeOneRowRelationSubquery extends Rule[LogicalPlan] {

case p: LogicalPlan => p.transformExpressionsUpWithPruning(
_.containsPattern(SCALAR_SUBQUERY)) {
- case s @ ScalarSubquery(OneRowSubquery(p @ Project(_, _: OneRowRelation)), _, _, _, _)
+ case s @ ScalarSubquery(OneRowSubquery(p @ Project(_, _: OneRowRelation)), _, _, _, _, _)
if !hasCorrelatedSubquery(s.plan) && s.joinCond.isEmpty =>
assert(p.projectList.size == 1)
stripOuterReferences(p.projectList).head
@@ -3251,6 +3251,15 @@ object SQLConf {
.booleanConf
.createWithDefault(true)

+ val DECORRELATE_SUBQUERY_LEGACY_INCORRECT_COUNT_HANDLING_ENABLED =
+ buildConf("spark.sql.optimizer.decorrelateSubqueryLegacyIncorrectCountHandling.enabled")
+ .internal()
+ .doc("If enabled, revert to legacy incorrect behavior for certain subqueries with COUNT or " +
+ "similar aggregates: see SPARK-43098.")
+ .version("3.5.0")
+ .booleanConf
+ .createWithDefault(false)

val OPTIMIZE_ONE_ROW_RELATION_SUBQUERY =
buildConf("spark.sql.optimizer.optimizeOneRowRelationSubquery")
.internal()
@@ -133,7 +133,7 @@ case class InsertAdaptiveSparkPlan(
return subqueryMap.toMap
}
plan.foreach(_.expressions.filter(_.containsPattern(PLAN_EXPRESSION)).foreach(_.foreach {
- case expressions.ScalarSubquery(p, _, exprId, _, _)
+ case expressions.ScalarSubquery(p, _, exprId, _, _, _)
if !subqueryMap.contains(exprId.id) =>
val executedPlan = compileSubquery(p)
verifyAdaptivePlan(executedPlan, p)
@@ -31,7 +31,7 @@ case class PlanAdaptiveSubqueries(
def apply(plan: SparkPlan): SparkPlan = {
plan.transformAllExpressionsWithPruning(
_.containsAnyPattern(SCALAR_SUBQUERY, IN_SUBQUERY, DYNAMIC_PRUNING_SUBQUERY)) {
- case expressions.ScalarSubquery(_, _, exprId, _, _) =>
+ case expressions.ScalarSubquery(_, _, exprId, _, _, _) =>
execution.ScalarSubquery(subqueryMap(exprId.id), exprId)
case expressions.InSubquery(values, ListQuery(_, _, exprId, _, _, _)) =>
val expr = if (values.length == 1) {
Expand Down
@@ -0,0 +1,228 @@
-- Automatically generated by SQLQueryTestSuite
-- !query
create temp view l (a, b)
as values
(1, 2.0),
(1, 2.0),
(2, 1.0),
(2, 1.0),
(3, 3.0),
(null, null),
(null, 5.0),
(6, null)
-- !query analysis
CreateViewCommand `l`, [(a,None), (b,None)], values
(1, 2.0),
(1, 2.0),
(2, 1.0),
(2, 1.0),
(3, 3.0),
(null, null),
(null, 5.0),
(6, null), false, false, LocalTempView, true
+- LocalRelation [col1#x, col2#x]


-- !query
create temp view r (c, d)
as values
(2, 3.0),
(2, 3.0),
(3, 2.0),
(4, 1.0),
(null, null),
(null, 5.0),
(6, null)
-- !query analysis
CreateViewCommand `r`, [(c,None), (d,None)], values
(2, 3.0),
(2, 3.0),
(3, 2.0),
(4, 1.0),
(null, null),
(null, 5.0),
(6, null), false, false, LocalTempView, true
+- LocalRelation [col1#x, col2#x]


-- !query
select *, (select count(*) from r where l.a = r.c) from l
-- !query analysis
Project [a#x, b#x, scalar-subquery#x [a#x] AS scalarsubquery(a)#xL]
: +- Aggregate [count(1) AS count(1)#xL]
: +- Filter (outer(a#x) = c#x)
: +- SubqueryAlias r
: +- View (`r`, [c#x,d#x])
: +- Project [cast(col1#x as int) AS c#x, cast(col2#x as decimal(2,1)) AS d#x]
: +- LocalRelation [col1#x, col2#x]
+- SubqueryAlias l
+- View (`l`, [a#x,b#x])
+- Project [cast(col1#x as int) AS a#x, cast(col2#x as decimal(2,1)) AS b#x]
+- LocalRelation [col1#x, col2#x]


-- !query
select *, (select count(*) from r where l.a = r.c group by c) from l
-- !query analysis
Project [a#x, b#x, scalar-subquery#x [a#x] AS scalarsubquery(a)#xL]
: +- Aggregate [c#x], [count(1) AS count(1)#xL]
: +- Filter (outer(a#x) = c#x)
: +- SubqueryAlias r
: +- View (`r`, [c#x,d#x])
: +- Project [cast(col1#x as int) AS c#x, cast(col2#x as decimal(2,1)) AS d#x]
: +- LocalRelation [col1#x, col2#x]
+- SubqueryAlias l
+- View (`l`, [a#x,b#x])
+- Project [cast(col1#x as int) AS a#x, cast(col2#x as decimal(2,1)) AS b#x]
+- LocalRelation [col1#x, col2#x]


-- !query
select *, (select count(*) from r where l.a = r.c group by 'constant') from l
-- !query analysis
Project [a#x, b#x, scalar-subquery#x [a#x] AS scalarsubquery(a)#xL]
: +- Aggregate [constant], [count(1) AS count(1)#xL]
: +- Filter (outer(a#x) = c#x)
: +- SubqueryAlias r
: +- View (`r`, [c#x,d#x])
: +- Project [cast(col1#x as int) AS c#x, cast(col2#x as decimal(2,1)) AS d#x]
: +- LocalRelation [col1#x, col2#x]
+- SubqueryAlias l
+- View (`l`, [a#x,b#x])
+- Project [cast(col1#x as int) AS a#x, cast(col2#x as decimal(2,1)) AS b#x]
+- LocalRelation [col1#x, col2#x]


-- !query
select *, (
select (count(*)) is null
from r
where l.a = r.c)
from l
-- !query analysis
Project [a#x, b#x, scalar-subquery#x [a#x] AS scalarsubquery(a)#x]
: +- Aggregate [isnull(count(1)) AS (count(1) IS NULL)#x]
: +- Filter (outer(a#x) = c#x)
: +- SubqueryAlias r
: +- View (`r`, [c#x,d#x])
: +- Project [cast(col1#x as int) AS c#x, cast(col2#x as decimal(2,1)) AS d#x]
: +- LocalRelation [col1#x, col2#x]
+- SubqueryAlias l
+- View (`l`, [a#x,b#x])
+- Project [cast(col1#x as int) AS a#x, cast(col2#x as decimal(2,1)) AS b#x]
+- LocalRelation [col1#x, col2#x]


-- !query
select *, (
select (count(*)) is null
from r
where l.a = r.c
group by r.c)
from l
-- !query analysis
Project [a#x, b#x, scalar-subquery#x [a#x] AS scalarsubquery(a)#x]
: +- Aggregate [c#x], [isnull(count(1)) AS (count(1) IS NULL)#x]
: +- Filter (outer(a#x) = c#x)
: +- SubqueryAlias r
: +- View (`r`, [c#x,d#x])
: +- Project [cast(col1#x as int) AS c#x, cast(col2#x as decimal(2,1)) AS d#x]
: +- LocalRelation [col1#x, col2#x]
+- SubqueryAlias l
+- View (`l`, [a#x,b#x])
+- Project [cast(col1#x as int) AS a#x, cast(col2#x as decimal(2,1)) AS b#x]
+- LocalRelation [col1#x, col2#x]


-- !query
select *, (select count(*) from r where l.a = r.c having count(*) <= 1) from l
-- !query analysis
Project [a#x, b#x, scalar-subquery#x [a#x] AS scalarsubquery(a)#xL]
: +- Filter (count(1)#xL <= cast(1 as bigint))
: +- Aggregate [count(1) AS count(1)#xL]
: +- Filter (outer(a#x) = c#x)
: +- SubqueryAlias r
: +- View (`r`, [c#x,d#x])
: +- Project [cast(col1#x as int) AS c#x, cast(col2#x as decimal(2,1)) AS d#x]
: +- LocalRelation [col1#x, col2#x]
+- SubqueryAlias l
+- View (`l`, [a#x,b#x])
+- Project [cast(col1#x as int) AS a#x, cast(col2#x as decimal(2,1)) AS b#x]
+- LocalRelation [col1#x, col2#x]


-- !query
select *, (select count(*) from r where l.a = r.c having count(*) >= 2) from l
-- !query analysis
Project [a#x, b#x, scalar-subquery#x [a#x] AS scalarsubquery(a)#xL]
: +- Filter (count(1)#xL >= cast(2 as bigint))
: +- Aggregate [count(1) AS count(1)#xL]
: +- Filter (outer(a#x) = c#x)
: +- SubqueryAlias r
: +- View (`r`, [c#x,d#x])
: +- Project [cast(col1#x as int) AS c#x, cast(col2#x as decimal(2,1)) AS d#x]
: +- LocalRelation [col1#x, col2#x]
+- SubqueryAlias l
+- View (`l`, [a#x,b#x])
+- Project [cast(col1#x as int) AS a#x, cast(col2#x as decimal(2,1)) AS b#x]
+- LocalRelation [col1#x, col2#x]


-- !query
set spark.sql.optimizer.decorrelateSubqueryLegacyIncorrectCountHandling.enabled = true
-- !query analysis
SetCommand (spark.sql.optimizer.decorrelateSubqueryLegacyIncorrectCountHandling.enabled,Some(true))


-- !query
select *, (select count(*) from r where l.a = r.c) from l
-- !query analysis
Project [a#x, b#x, scalar-subquery#x [a#x] AS scalarsubquery(a)#xL]
: +- Aggregate [count(1) AS count(1)#xL]
: +- Filter (outer(a#x) = c#x)
: +- SubqueryAlias r
: +- View (`r`, [c#x,d#x])
: +- Project [cast(col1#x as int) AS c#x, cast(col2#x as decimal(2,1)) AS d#x]
: +- LocalRelation [col1#x, col2#x]
+- SubqueryAlias l
+- View (`l`, [a#x,b#x])
+- Project [cast(col1#x as int) AS a#x, cast(col2#x as decimal(2,1)) AS b#x]
+- LocalRelation [col1#x, col2#x]


-- !query
select *, (select count(*) from r where l.a = r.c group by c) from l
-- !query analysis
Project [a#x, b#x, scalar-subquery#x [a#x] AS scalarsubquery(a)#xL]
: +- Aggregate [c#x], [count(1) AS count(1)#xL]
: +- Filter (outer(a#x) = c#x)
: +- SubqueryAlias r
: +- View (`r`, [c#x,d#x])
: +- Project [cast(col1#x as int) AS c#x, cast(col2#x as decimal(2,1)) AS d#x]
: +- LocalRelation [col1#x, col2#x]
+- SubqueryAlias l
+- View (`l`, [a#x,b#x])
+- Project [cast(col1#x as int) AS a#x, cast(col2#x as decimal(2,1)) AS b#x]
+- LocalRelation [col1#x, col2#x]


-- !query
select *, (select count(*) from r where l.a = r.c group by 'constant') from l
-- !query analysis
Project [a#x, b#x, scalar-subquery#x [a#x] AS scalarsubquery(a)#xL]
: +- Aggregate [constant], [count(1) AS count(1)#xL]
: +- Filter (outer(a#x) = c#x)
: +- SubqueryAlias r
: +- View (`r`, [c#x,d#x])
: +- Project [cast(col1#x as int) AS c#x, cast(col2#x as decimal(2,1)) AS d#x]
: +- LocalRelation [col1#x, col2#x]
+- SubqueryAlias l
+- View (`l`, [a#x,b#x])
+- Project [cast(col1#x as int) AS a#x, cast(col2#x as decimal(2,1)) AS b#x]
+- LocalRelation [col1#x, col2#x]


-- !query
reset spark.sql.optimizer.decorrelateSubqueryLegacyIncorrectCountHandling.enabled
-- !query analysis
ResetCommand spark.sql.optimizer.decorrelateSubqueryLegacyIncorrectCountHandling.enabled
