From c1a02e7e304934a7e671c0ae2f5a1fcedd6806c0 Mon Sep 17 00:00:00 2001 From: Jack Chen Date: Wed, 19 Apr 2023 09:42:06 +0800 Subject: [PATCH] [SPARK-43098][SQL] Fix correctness COUNT bug when scalar subquery has group by clause ### What changes were proposed in this pull request? Fix a correctness bug for scalar subqueries with COUNT and a GROUP BY clause, for example: ``` create view t1(c1, c2) as values (0, 1), (1, 2); create view t2(c1, c2) as values (0, 2), (0, 3); select c1, c2, (select count(*) from t2 where t1.c1 = t2.c1 group by c1) from t1; -- Correct answer: [(0, 1, 2), (1, 2, null)] -- Actual output before this fix (incorrect): +---+---+------------------+ |c1 |c2 |scalarsubquery(c1)| +---+---+------------------+ |0 |1 |2 | |1 |2 |0 | +---+---+------------------+ ``` This is due to a bug in our "COUNT bug" handling for scalar subqueries. For a subquery with a COUNT aggregate but no GROUP BY clause, 0 is the correct output on empty inputs, and we use the COUNT bug handling to construct the plan that yields 0 when there were no matched rows. When there is a GROUP BY clause, however, NULL is the correct output (i.e. there is no COUNT bug), yet we still incorrectly use the COUNT bug handling and therefore incorrectly output 0. Instead, we need to apply the COUNT bug handling only when the scalar subquery had no GROUP BY clause. To fix this, we need to track whether the scalar subquery has a GROUP BY, i.e. a non-empty groupingExpressions for the Aggregate node. This needs to be checked before subquery decorrelation, because decorrelation adds the correlated outer refs to the group-by list, so afterwards the group-by is always non-empty. We save it in a boolean in the ScalarSubquery node until later when we rewrite the subquery into a join in constructLeftJoins. This is a long-standing bug. This bug affected both the current DecorrelateInnerQuery framework and the old code (with spark.sql.optimizer.decorrelateInnerQuery.enabled = false), and this PR fixes both. ### Why are the changes needed? Fix a correctness bug. 
### Does this PR introduce _any_ user-facing change? Yes, fix incorrect query results. ### How was this patch tested? Add SQL tests and unit tests. (Note that there were 2 existing unit tests for queries of this shape, which had the incorrect results as golden results.) Closes #40811 from jchen5/count-bug. Authored-by: Jack Chen Signed-off-by: Wenchen Fan --- .../sql/catalyst/analysis/Analyzer.scala | 2 +- .../sql/catalyst/analysis/CheckAnalysis.scala | 2 +- .../sql/catalyst/expressions/subquery.scala | 9 +- .../sql/catalyst/optimizer/subquery.scala | 35 ++- .../apache/spark/sql/internal/SQLConf.scala | 9 + .../adaptive/InsertAdaptiveSparkPlan.scala | 2 +- .../adaptive/PlanAdaptiveSubqueries.scala | 2 +- .../scalar-subquery-count-bug.sql.out | 228 ++++++++++++++++++ .../scalar-subquery-count-bug.sql | 58 +++++ .../scalar-subquery-count-bug.sql.out | 207 ++++++++++++++++ .../org/apache/spark/sql/SubquerySuite.scala | 20 +- 11 files changed, 559 insertions(+), 15 deletions(-) create mode 100644 sql/core/src/test/resources/sql-tests/analyzer-results/subquery/scalar-subquery/scalar-subquery-count-bug.sql.out create mode 100644 sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/scalar-subquery-count-bug.sql create mode 100644 sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-count-bug.sql.out diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 8b3efe09834d6..62555c9a99cc3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -2418,7 +2418,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor */ private def resolveSubQueries(plan: LogicalPlan, outer: LogicalPlan): LogicalPlan = { 
plan.transformAllExpressionsWithPruning(_.containsPattern(PLAN_EXPRESSION), ruleId) { - case s @ ScalarSubquery(sub, _, exprId, _, _) if !sub.resolved => + case s @ ScalarSubquery(sub, _, exprId, _, _, _) if !sub.resolved => resolveSubQuery(s, outer)(ScalarSubquery(_, _, exprId)) case e @ Exists(sub, _, exprId, _, _) if !sub.resolved => resolveSubQuery(e, outer)(Exists(_, _, exprId)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index c8466460e5823..49f0d438d0a23 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -935,7 +935,7 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB checkOuterReference(plan, expr) expr match { - case ScalarSubquery(query, outerAttrs, _, _, _) => + case ScalarSubquery(query, outerAttrs, _, _, _, _) => // Scalar subquery must return one column as output. if (query.output.size != 1) { expr.failAnalysis( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala index e7384dac2d53e..228bb4805c85f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala @@ -254,13 +254,20 @@ object SubExprUtils extends PredicateHelper { * scalar subquery during planning. * * Note: `exprId` is used to have a unique name in explain string output. + * + * `mayHaveCountBug` is whether it's possible for the subquery to evaluate to non-null on + * empty input (zero tuples). 
It is false if the subquery has a GROUP BY clause, because in that + * case the subquery yields no row at all on empty input to the GROUP BY, which evaluates to NULL. + * It is set to true/false in PullupCorrelatedPredicates; before that, its value is None. + * See constructLeftJoins in RewriteCorrelatedScalarSubquery for more details. */ case class ScalarSubquery( plan: LogicalPlan, outerAttrs: Seq[Expression] = Seq.empty, exprId: ExprId = NamedExpression.newExprId, joinCond: Seq[Expression] = Seq.empty, - hint: Option[HintInfo] = None) + hint: Option[HintInfo] = None, + mayHaveCountBug: Option[Boolean] = None) extends SubqueryExpression(plan, outerAttrs, exprId, joinCond, hint) with Unevaluable { override def dataType: DataType = { assert(plan.schema.fields.nonEmpty, "Scalar subquery should have only one column") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala index faafeecc316a6..83ff5e3973910 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.ScalarSubquery._ import org.apache.spark.sql.catalyst.expressions.SubExprUtils._ import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.catalyst.optimizer.RewriteCorrelatedScalarSubquery.splitSubquery import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ @@ -32,6 +33,7 @@ import org.apache.spark.sql.catalyst.trees.TreePattern.{EXISTS_SUBQUERY, IN_SUBQ import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ +import 
org.apache.spark.util.Utils /* * This file defines optimization rules related to subqueries. @@ -325,9 +327,22 @@ object PullupCorrelatedPredicates extends Rule[LogicalPlan] with PredicateHelper } plan.transformExpressionsWithPruning(_.containsPattern(PLAN_EXPRESSION)) { - case ScalarSubquery(sub, children, exprId, conditions, hint) if children.nonEmpty => + case ScalarSubquery(sub, children, exprId, conditions, hint, mayHaveCountBugOld) + if children.nonEmpty => val (newPlan, newCond) = decorrelate(sub, plan) - ScalarSubquery(newPlan, children, exprId, getJoinCondition(newCond, conditions), hint) + val mayHaveCountBug = if (mayHaveCountBugOld.isEmpty) { + // Check whether the pre-rewrite subquery had empty groupingExpressions. If yes, it may + // be subject to the COUNT bug. If it has non-empty groupingExpressions, there is + // no COUNT bug. + val (topPart, havingNode, aggNode) = splitSubquery(sub) + (aggNode.isDefined && aggNode.get.groupingExpressions.isEmpty) + } else { + // For idempotency, we must save this variable the first time this rule is run, because + // decorrelation introduces a GROUP BY if one wasn't already present. + mayHaveCountBugOld.get + } + ScalarSubquery(newPlan, children, exprId, getJoinCondition(newCond, conditions), + hint, Some(mayHaveCountBug)) case Exists(sub, children, exprId, conditions, hint) if children.nonEmpty => val (newPlan, newCond) = pullOutCorrelatedPredicates(sub, plan) Exists(newPlan, children, exprId, getJoinCondition(newCond, conditions), hint) @@ -519,7 +534,7 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] with AliasHelpe * (optional second part) and the Aggregate below the HAVING CLAUSE (optional third part). * When the third part is empty, it means the subquery is a non-aggregated single-row subquery. 
*/ - private def splitSubquery( + def splitSubquery( plan: LogicalPlan): (Seq[LogicalPlan], Option[Filter], Option[Aggregate]) = { val topPart = ArrayBuffer.empty[LogicalPlan] var bottomPart: LogicalPlan = plan @@ -569,7 +584,7 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] with AliasHelpe subqueries: ArrayBuffer[ScalarSubquery]): (LogicalPlan, AttributeMap[Attribute]) = { val subqueryAttrMapping = ArrayBuffer[(Attribute, Attribute)]() val newChild = subqueries.foldLeft(child) { - case (currentChild, ScalarSubquery(sub, _, _, conditions, subHint)) => + case (currentChild, ScalarSubquery(sub, _, _, conditions, subHint, mayHaveCountBug)) => val query = DecorrelateInnerQuery.rewriteDomainJoins(currentChild, sub, conditions) val origOutput = query.output.head // The subquery appears on the right side of the join, hence add its hint to the right @@ -581,8 +596,16 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] with AliasHelpe currentChild.output :+ origOutput, Join(currentChild, query, LeftOuter, conditions.reduceOption(And), joinHint)) + if (Utils.isTesting) { + assert(mayHaveCountBug.isDefined) + } if (resultWithZeroTups.isEmpty) { - // CASE 1: Subquery guaranteed not to have the COUNT bug + // CASE 1: Subquery guaranteed not to have the COUNT bug because it evaluates to NULL + // with zero tuples. 
+ planWithoutCountBug + } else if (!mayHaveCountBug.getOrElse(true) && + !conf.getConf(SQLConf.DECORRELATE_SUBQUERY_LEGACY_INCORRECT_COUNT_HANDLING_ENABLED)) { + // Subquery guaranteed not to have the COUNT bug because it had non-empty GROUP BY clause planWithoutCountBug } else { val (topPart, havingNode, aggNode) = splitSubquery(query) @@ -800,7 +823,7 @@ object OptimizeOneRowRelationSubquery extends Rule[LogicalPlan] { case p: LogicalPlan => p.transformExpressionsUpWithPruning( _.containsPattern(SCALAR_SUBQUERY)) { - case s @ ScalarSubquery(OneRowSubquery(p @ Project(_, _: OneRowRelation)), _, _, _, _) + case s @ ScalarSubquery(OneRowSubquery(p @ Project(_, _: OneRowRelation)), _, _, _, _, _) if !hasCorrelatedSubquery(s.plan) && s.joinCond.isEmpty => assert(p.projectList.size == 1) stripOuterReferences(p.projectList).head diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 4986dc3661c06..e5023498513b0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -3251,6 +3251,15 @@ object SQLConf { .booleanConf .createWithDefault(true) + val DECORRELATE_SUBQUERY_LEGACY_INCORRECT_COUNT_HANDLING_ENABLED = + buildConf("spark.sql.optimizer.decorrelateSubqueryLegacyIncorrectCountHandling.enabled") + .internal() + .doc("If enabled, revert to legacy incorrect behavior for certain subqueries with COUNT or " + + "similar aggregates: see SPARK-43098.") + .version("3.5.0") + .booleanConf + .createWithDefault(false) + val OPTIMIZE_ONE_ROW_RELATION_SUBQUERY = buildConf("spark.sql.optimizer.optimizeOneRowRelationSubquery") .internal() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/InsertAdaptiveSparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/InsertAdaptiveSparkPlan.scala index 
cfc39617a0513..947a7314142fe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/InsertAdaptiveSparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/InsertAdaptiveSparkPlan.scala @@ -133,7 +133,7 @@ case class InsertAdaptiveSparkPlan( return subqueryMap.toMap } plan.foreach(_.expressions.filter(_.containsPattern(PLAN_EXPRESSION)).foreach(_.foreach { - case expressions.ScalarSubquery(p, _, exprId, _, _) + case expressions.ScalarSubquery(p, _, exprId, _, _, _) if !subqueryMap.contains(exprId.id) => val executedPlan = compileSubquery(p) verifyAdaptivePlan(executedPlan, p) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveSubqueries.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveSubqueries.scala index f88d2ffd541d3..c3f4274058350 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveSubqueries.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveSubqueries.scala @@ -31,7 +31,7 @@ case class PlanAdaptiveSubqueries( def apply(plan: SparkPlan): SparkPlan = { plan.transformAllExpressionsWithPruning( _.containsAnyPattern(SCALAR_SUBQUERY, IN_SUBQUERY, DYNAMIC_PRUNING_SUBQUERY)) { - case expressions.ScalarSubquery(_, _, exprId, _, _) => + case expressions.ScalarSubquery(_, _, exprId, _, _, _) => execution.ScalarSubquery(subqueryMap(exprId.id), exprId) case expressions.InSubquery(values, ListQuery(_, _, exprId, _, _, _)) => val expr = if (values.length == 1) { diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/subquery/scalar-subquery/scalar-subquery-count-bug.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/subquery/scalar-subquery/scalar-subquery-count-bug.sql.out new file mode 100644 index 0000000000000..908bc579d7b3f --- /dev/null +++ 
b/sql/core/src/test/resources/sql-tests/analyzer-results/subquery/scalar-subquery/scalar-subquery-count-bug.sql.out @@ -0,0 +1,228 @@ +-- Automatically generated by SQLQueryTestSuite +-- !query +create temp view l (a, b) +as values + (1, 2.0), + (1, 2.0), + (2, 1.0), + (2, 1.0), + (3, 3.0), + (null, null), + (null, 5.0), + (6, null) +-- !query analysis +CreateViewCommand `l`, [(a,None), (b,None)], values + (1, 2.0), + (1, 2.0), + (2, 1.0), + (2, 1.0), + (3, 3.0), + (null, null), + (null, 5.0), + (6, null), false, false, LocalTempView, true + +- LocalRelation [col1#x, col2#x] + + +-- !query +create temp view r (c, d) +as values + (2, 3.0), + (2, 3.0), + (3, 2.0), + (4, 1.0), + (null, null), + (null, 5.0), + (6, null) +-- !query analysis +CreateViewCommand `r`, [(c,None), (d,None)], values + (2, 3.0), + (2, 3.0), + (3, 2.0), + (4, 1.0), + (null, null), + (null, 5.0), + (6, null), false, false, LocalTempView, true + +- LocalRelation [col1#x, col2#x] + + +-- !query +select *, (select count(*) from r where l.a = r.c) from l +-- !query analysis +Project [a#x, b#x, scalar-subquery#x [a#x] AS scalarsubquery(a)#xL] +: +- Aggregate [count(1) AS count(1)#xL] +: +- Filter (outer(a#x) = c#x) +: +- SubqueryAlias r +: +- View (`r`, [c#x,d#x]) +: +- Project [cast(col1#x as int) AS c#x, cast(col2#x as decimal(2,1)) AS d#x] +: +- LocalRelation [col1#x, col2#x] ++- SubqueryAlias l + +- View (`l`, [a#x,b#x]) + +- Project [cast(col1#x as int) AS a#x, cast(col2#x as decimal(2,1)) AS b#x] + +- LocalRelation [col1#x, col2#x] + + +-- !query +select *, (select count(*) from r where l.a = r.c group by c) from l +-- !query analysis +Project [a#x, b#x, scalar-subquery#x [a#x] AS scalarsubquery(a)#xL] +: +- Aggregate [c#x], [count(1) AS count(1)#xL] +: +- Filter (outer(a#x) = c#x) +: +- SubqueryAlias r +: +- View (`r`, [c#x,d#x]) +: +- Project [cast(col1#x as int) AS c#x, cast(col2#x as decimal(2,1)) AS d#x] +: +- LocalRelation [col1#x, col2#x] ++- SubqueryAlias l + +- View (`l`, [a#x,b#x]) + 
+- Project [cast(col1#x as int) AS a#x, cast(col2#x as decimal(2,1)) AS b#x] + +- LocalRelation [col1#x, col2#x] + + +-- !query +select *, (select count(*) from r where l.a = r.c group by 'constant') from l +-- !query analysis +Project [a#x, b#x, scalar-subquery#x [a#x] AS scalarsubquery(a)#xL] +: +- Aggregate [constant], [count(1) AS count(1)#xL] +: +- Filter (outer(a#x) = c#x) +: +- SubqueryAlias r +: +- View (`r`, [c#x,d#x]) +: +- Project [cast(col1#x as int) AS c#x, cast(col2#x as decimal(2,1)) AS d#x] +: +- LocalRelation [col1#x, col2#x] ++- SubqueryAlias l + +- View (`l`, [a#x,b#x]) + +- Project [cast(col1#x as int) AS a#x, cast(col2#x as decimal(2,1)) AS b#x] + +- LocalRelation [col1#x, col2#x] + + +-- !query +select *, ( + select (count(*)) is null + from r + where l.a = r.c) +from l +-- !query analysis +Project [a#x, b#x, scalar-subquery#x [a#x] AS scalarsubquery(a)#x] +: +- Aggregate [isnull(count(1)) AS (count(1) IS NULL)#x] +: +- Filter (outer(a#x) = c#x) +: +- SubqueryAlias r +: +- View (`r`, [c#x,d#x]) +: +- Project [cast(col1#x as int) AS c#x, cast(col2#x as decimal(2,1)) AS d#x] +: +- LocalRelation [col1#x, col2#x] ++- SubqueryAlias l + +- View (`l`, [a#x,b#x]) + +- Project [cast(col1#x as int) AS a#x, cast(col2#x as decimal(2,1)) AS b#x] + +- LocalRelation [col1#x, col2#x] + + +-- !query +select *, ( + select (count(*)) is null + from r + where l.a = r.c + group by r.c) +from l +-- !query analysis +Project [a#x, b#x, scalar-subquery#x [a#x] AS scalarsubquery(a)#x] +: +- Aggregate [c#x], [isnull(count(1)) AS (count(1) IS NULL)#x] +: +- Filter (outer(a#x) = c#x) +: +- SubqueryAlias r +: +- View (`r`, [c#x,d#x]) +: +- Project [cast(col1#x as int) AS c#x, cast(col2#x as decimal(2,1)) AS d#x] +: +- LocalRelation [col1#x, col2#x] ++- SubqueryAlias l + +- View (`l`, [a#x,b#x]) + +- Project [cast(col1#x as int) AS a#x, cast(col2#x as decimal(2,1)) AS b#x] + +- LocalRelation [col1#x, col2#x] + + +-- !query +select *, (select count(*) from r where l.a = r.c 
having count(*) <= 1) from l +-- !query analysis +Project [a#x, b#x, scalar-subquery#x [a#x] AS scalarsubquery(a)#xL] +: +- Filter (count(1)#xL <= cast(1 as bigint)) +: +- Aggregate [count(1) AS count(1)#xL] +: +- Filter (outer(a#x) = c#x) +: +- SubqueryAlias r +: +- View (`r`, [c#x,d#x]) +: +- Project [cast(col1#x as int) AS c#x, cast(col2#x as decimal(2,1)) AS d#x] +: +- LocalRelation [col1#x, col2#x] ++- SubqueryAlias l + +- View (`l`, [a#x,b#x]) + +- Project [cast(col1#x as int) AS a#x, cast(col2#x as decimal(2,1)) AS b#x] + +- LocalRelation [col1#x, col2#x] + + +-- !query +select *, (select count(*) from r where l.a = r.c having count(*) >= 2) from l +-- !query analysis +Project [a#x, b#x, scalar-subquery#x [a#x] AS scalarsubquery(a)#xL] +: +- Filter (count(1)#xL >= cast(2 as bigint)) +: +- Aggregate [count(1) AS count(1)#xL] +: +- Filter (outer(a#x) = c#x) +: +- SubqueryAlias r +: +- View (`r`, [c#x,d#x]) +: +- Project [cast(col1#x as int) AS c#x, cast(col2#x as decimal(2,1)) AS d#x] +: +- LocalRelation [col1#x, col2#x] ++- SubqueryAlias l + +- View (`l`, [a#x,b#x]) + +- Project [cast(col1#x as int) AS a#x, cast(col2#x as decimal(2,1)) AS b#x] + +- LocalRelation [col1#x, col2#x] + + +-- !query +set spark.sql.optimizer.decorrelateSubqueryLegacyIncorrectCountHandling.enabled = true +-- !query analysis +SetCommand (spark.sql.optimizer.decorrelateSubqueryLegacyIncorrectCountHandling.enabled,Some(true)) + + +-- !query +select *, (select count(*) from r where l.a = r.c) from l +-- !query analysis +Project [a#x, b#x, scalar-subquery#x [a#x] AS scalarsubquery(a)#xL] +: +- Aggregate [count(1) AS count(1)#xL] +: +- Filter (outer(a#x) = c#x) +: +- SubqueryAlias r +: +- View (`r`, [c#x,d#x]) +: +- Project [cast(col1#x as int) AS c#x, cast(col2#x as decimal(2,1)) AS d#x] +: +- LocalRelation [col1#x, col2#x] ++- SubqueryAlias l + +- View (`l`, [a#x,b#x]) + +- Project [cast(col1#x as int) AS a#x, cast(col2#x as decimal(2,1)) AS b#x] + +- LocalRelation [col1#x, col2#x] + + 
+-- !query +select *, (select count(*) from r where l.a = r.c group by c) from l +-- !query analysis +Project [a#x, b#x, scalar-subquery#x [a#x] AS scalarsubquery(a)#xL] +: +- Aggregate [c#x], [count(1) AS count(1)#xL] +: +- Filter (outer(a#x) = c#x) +: +- SubqueryAlias r +: +- View (`r`, [c#x,d#x]) +: +- Project [cast(col1#x as int) AS c#x, cast(col2#x as decimal(2,1)) AS d#x] +: +- LocalRelation [col1#x, col2#x] ++- SubqueryAlias l + +- View (`l`, [a#x,b#x]) + +- Project [cast(col1#x as int) AS a#x, cast(col2#x as decimal(2,1)) AS b#x] + +- LocalRelation [col1#x, col2#x] + + +-- !query +select *, (select count(*) from r where l.a = r.c group by 'constant') from l +-- !query analysis +Project [a#x, b#x, scalar-subquery#x [a#x] AS scalarsubquery(a)#xL] +: +- Aggregate [constant], [count(1) AS count(1)#xL] +: +- Filter (outer(a#x) = c#x) +: +- SubqueryAlias r +: +- View (`r`, [c#x,d#x]) +: +- Project [cast(col1#x as int) AS c#x, cast(col2#x as decimal(2,1)) AS d#x] +: +- LocalRelation [col1#x, col2#x] ++- SubqueryAlias l + +- View (`l`, [a#x,b#x]) + +- Project [cast(col1#x as int) AS a#x, cast(col2#x as decimal(2,1)) AS b#x] + +- LocalRelation [col1#x, col2#x] + + +-- !query +reset spark.sql.optimizer.decorrelateSubqueryLegacyIncorrectCountHandling.enabled +-- !query analysis +ResetCommand spark.sql.optimizer.decorrelateSubqueryLegacyIncorrectCountHandling.enabled diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/scalar-subquery-count-bug.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/scalar-subquery-count-bug.sql new file mode 100644 index 0000000000000..0ca2f07b301de --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/scalar-subquery-count-bug.sql @@ -0,0 +1,58 @@ +create temp view l (a, b) +as values + (1, 2.0), + (1, 2.0), + (2, 1.0), + (2, 1.0), + (3, 3.0), + (null, null), + (null, 5.0), + (6, null); + +create temp view r (c, d) +as values + (2, 3.0), + (2, 3.0), 
+ (3, 2.0), + (4, 1.0), + (null, null), + (null, 5.0), + (6, null); + +-- count bug, empty groups should evaluate to 0 +select *, (select count(*) from r where l.a = r.c) from l; + +-- no count bug, empty groups should evaluate to null +select *, (select count(*) from r where l.a = r.c group by c) from l; +select *, (select count(*) from r where l.a = r.c group by 'constant') from l; + +-- count bug, empty groups should evaluate to false - but this case is wrong due to bug SPARK-43156 +select *, ( + select (count(*)) is null + from r + where l.a = r.c) +from l; + +-- no count bug, empty groups should evaluate to null +select *, ( + select (count(*)) is null + from r + where l.a = r.c + group by r.c) +from l; + +-- Empty groups should evaluate to 0, and groups filtered by HAVING should evaluate to NULL +select *, (select count(*) from r where l.a = r.c having count(*) <= 1) from l; + +-- Empty groups are filtered by HAVING and should evaluate to null +select *, (select count(*) from r where l.a = r.c having count(*) >= 2) from l; + + +set spark.sql.optimizer.decorrelateSubqueryLegacyIncorrectCountHandling.enabled = true; + +-- With legacy behavior flag set, both cases evaluate to 0 +select *, (select count(*) from r where l.a = r.c) from l; +select *, (select count(*) from r where l.a = r.c group by c) from l; +select *, (select count(*) from r where l.a = r.c group by 'constant') from l; + +reset spark.sql.optimizer.decorrelateSubqueryLegacyIncorrectCountHandling.enabled; diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-count-bug.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-count-bug.sql.out new file mode 100644 index 0000000000000..3012b67cf8ce8 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-count-bug.sql.out @@ -0,0 +1,207 @@ +-- Automatically generated by SQLQueryTestSuite +-- !query +create temp view 
l (a, b) +as values + (1, 2.0), + (1, 2.0), + (2, 1.0), + (2, 1.0), + (3, 3.0), + (null, null), + (null, 5.0), + (6, null) +-- !query schema +struct<> +-- !query output + + + +-- !query +create temp view r (c, d) +as values + (2, 3.0), + (2, 3.0), + (3, 2.0), + (4, 1.0), + (null, null), + (null, 5.0), + (6, null) +-- !query schema +struct<> +-- !query output + + + +-- !query +select *, (select count(*) from r where l.a = r.c) from l +-- !query schema +struct +-- !query output +1 2.0 0 +1 2.0 0 +2 1.0 2 +2 1.0 2 +3 3.0 1 +6 NULL 1 +NULL 5.0 0 +NULL NULL 0 + + +-- !query +select *, (select count(*) from r where l.a = r.c group by c) from l +-- !query schema +struct +-- !query output +1 2.0 NULL +1 2.0 NULL +2 1.0 2 +2 1.0 2 +3 3.0 1 +6 NULL 1 +NULL 5.0 NULL +NULL NULL NULL + + +-- !query +select *, (select count(*) from r where l.a = r.c group by 'constant') from l +-- !query schema +struct +-- !query output +1 2.0 NULL +1 2.0 NULL +2 1.0 2 +2 1.0 2 +3 3.0 1 +6 NULL 1 +NULL 5.0 NULL +NULL NULL NULL + + +-- !query +select *, ( + select (count(*)) is null + from r + where l.a = r.c) +from l +-- !query schema +struct +-- !query output +1 2.0 NULL +1 2.0 NULL +2 1.0 false +2 1.0 false +3 3.0 false +6 NULL false +NULL 5.0 NULL +NULL NULL NULL + + +-- !query +select *, ( + select (count(*)) is null + from r + where l.a = r.c + group by r.c) +from l +-- !query schema +struct +-- !query output +1 2.0 NULL +1 2.0 NULL +2 1.0 false +2 1.0 false +3 3.0 false +6 NULL false +NULL 5.0 NULL +NULL NULL NULL + + +-- !query +select *, (select count(*) from r where l.a = r.c having count(*) <= 1) from l +-- !query schema +struct +-- !query output +1 2.0 0 +1 2.0 0 +2 1.0 NULL +2 1.0 NULL +3 3.0 1 +6 NULL 1 +NULL 5.0 0 +NULL NULL 0 + + +-- !query +select *, (select count(*) from r where l.a = r.c having count(*) >= 2) from l +-- !query schema +struct +-- !query output +1 2.0 NULL +1 2.0 NULL +2 1.0 2 +2 1.0 2 +3 3.0 NULL +6 NULL NULL +NULL 5.0 NULL +NULL NULL NULL + + +-- !query +set 
spark.sql.optimizer.decorrelateSubqueryLegacyIncorrectCountHandling.enabled = true +-- !query schema +struct +-- !query output +spark.sql.optimizer.decorrelateSubqueryLegacyIncorrectCountHandling.enabled true + + +-- !query +select *, (select count(*) from r where l.a = r.c) from l +-- !query schema +struct +-- !query output +1 2.0 0 +1 2.0 0 +2 1.0 2 +2 1.0 2 +3 3.0 1 +6 NULL 1 +NULL 5.0 0 +NULL NULL 0 + + +-- !query +select *, (select count(*) from r where l.a = r.c group by c) from l +-- !query schema +struct +-- !query output +1 2.0 0 +1 2.0 0 +2 1.0 2 +2 1.0 2 +3 3.0 1 +6 NULL 1 +NULL 5.0 0 +NULL NULL 0 + + +-- !query +select *, (select count(*) from r where l.a = r.c group by 'constant') from l +-- !query schema +struct +-- !query output +1 2.0 0 +1 2.0 0 +2 1.0 2 +2 1.0 2 +3 3.0 1 +6 NULL 1 +NULL 5.0 0 +NULL NULL 0 + + +-- !query +reset spark.sql.optimizer.decorrelateSubqueryLegacyIncorrectCountHandling.enabled +-- !query schema +struct<> +-- !query output + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala index 9246d360becbc..32d913ca3b425 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala @@ -663,14 +663,26 @@ class SubquerySuite extends QueryTest checkAnswer( sql( """ - |select l.b, (select (r.c + count(*)) is null + |select l.b, (select (min(r.c) + count(*)) is null |from r - |where l.a = r.c group by r.c) from l + |where l.a = r.c) from l """.stripMargin), Row(1.0, false) :: Row(1.0, false) :: Row(2.0, true) :: Row(2.0, true) :: Row(3.0, false) :: Row(5.0, true) :: Row(null, false) :: Row(null, true) :: Nil) } + test("SPARK-43098: no COUNT bug with group-by") { + checkAnswer( + sql( + """ + |select l.b, (select (r.c + count(*)) is null + |from r + |where l.a = r.c group by r.c) from l + """.stripMargin), + Row(1.0, false) :: Row(1.0, false) :: Row(2.0, null) :: 
Row(2.0, null) :: + Row(3.0, false) :: Row(5.0, null) :: Row(null, false) :: Row(null, null) :: Nil) + } + test("SPARK-16804: Correlated subqueries containing LIMIT - 1") { withTempView("onerow") { Seq(1).toDF("c1").createOrReplaceTempView("onerow") @@ -1844,8 +1856,8 @@ class SubquerySuite extends QueryTest | ) |FROM l """.stripMargin), - Row(1.0, false) :: Row(1.0, false) :: Row(2.0, true) :: Row(2.0, true) :: - Row(3.0, false) :: Row(5.0, true) :: Row(null, false) :: Row(null, true) :: Nil) + Row(1.0, false) :: Row(1.0, false) :: Row(2.0, null) :: Row(2.0, null) :: + Row(3.0, false) :: Row(5.0, null) :: Row(null, false) :: Row(null, null) :: Nil) } test("SPARK-28441: COUNT bug with non-foldable expression") {