diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 8b3efe09834d6..62555c9a99cc3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -2418,7 +2418,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
    */
   private def resolveSubQueries(plan: LogicalPlan, outer: LogicalPlan): LogicalPlan = {
     plan.transformAllExpressionsWithPruning(_.containsPattern(PLAN_EXPRESSION), ruleId) {
-      case s @ ScalarSubquery(sub, _, exprId, _, _) if !sub.resolved =>
+      case s @ ScalarSubquery(sub, _, exprId, _, _, _) if !sub.resolved =>
         resolveSubQuery(s, outer)(ScalarSubquery(_, _, exprId))
       case e @ Exists(sub, _, exprId, _, _) if !sub.resolved =>
         resolveSubQuery(e, outer)(Exists(_, _, exprId))
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index c8466460e5823..49f0d438d0a23 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -935,7 +935,7 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB
         checkOuterReference(plan, expr)
 
         expr match {
-          case ScalarSubquery(query, outerAttrs, _, _, _) =>
+          case ScalarSubquery(query, outerAttrs, _, _, _, _) =>
            // Scalar subquery must return one column as output.
            if (query.output.size != 1) {
              expr.failAnalysis(
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
index e7384dac2d53e..228bb4805c85f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
@@ -254,13 +254,20 @@ object SubExprUtils extends PredicateHelper {
 * scalar subquery during planning.
 *
 * Note: `exprId` is used to have a unique name in explain string output.
+ *
+ * `mayHaveCountBug` records whether it is possible for the subquery to evaluate to non-null
+ * on empty input (zero tuples). It is false if the subquery has a GROUP BY clause, because
+ * in that case an empty input yields no groups at all, so the subquery evaluates to NULL.
+ * PullupCorrelatedPredicates sets it to true or false; before that rule runs, its value is
+ * None. See constructLeftJoins in RewriteCorrelatedScalarSubquery for more details.
 */
 case class ScalarSubquery(
     plan: LogicalPlan,
     outerAttrs: Seq[Expression] = Seq.empty,
     exprId: ExprId = NamedExpression.newExprId,
     joinCond: Seq[Expression] = Seq.empty,
-    hint: Option[HintInfo] = None)
+    hint: Option[HintInfo] = None,
+    mayHaveCountBug: Option[Boolean] = None)
   extends SubqueryExpression(plan, outerAttrs, exprId, joinCond, hint) with Unevaluable {
   override def dataType: DataType = {
     assert(plan.schema.fields.nonEmpty, "Scalar subquery should have only one column")
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
index faafeecc316a6..83ff5e3973910 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
@@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.ScalarSubquery._
 import org.apache.spark.sql.catalyst.expressions.SubExprUtils._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
+import org.apache.spark.sql.catalyst.optimizer.RewriteCorrelatedScalarSubquery.splitSubquery
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules._
@@ -32,6 +33,7 @@ import org.apache.spark.sql.catalyst.trees.TreePattern.{EXISTS_SUBQUERY, IN_SUBQ
 import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
+import org.apache.spark.util.Utils
 
 /*
  * This file defines optimization rules related to subqueries.
@@ -325,9 +327,22 @@ object PullupCorrelatedPredicates extends Rule[LogicalPlan] with PredicateHelper
     }
 
     plan.transformExpressionsWithPruning(_.containsPattern(PLAN_EXPRESSION)) {
-      case ScalarSubquery(sub, children, exprId, conditions, hint) if children.nonEmpty =>
+      case ScalarSubquery(sub, children, exprId, conditions, hint, mayHaveCountBugOld)
+          if children.nonEmpty =>
         val (newPlan, newCond) = decorrelate(sub, plan)
-        ScalarSubquery(newPlan, children, exprId, getJoinCondition(newCond, conditions), hint)
+        val mayHaveCountBug = if (mayHaveCountBugOld.isEmpty) {
+          // Check whether the pre-rewrite subquery had empty groupingExpressions. If it did,
+          // it may be subject to the COUNT bug; if its groupingExpressions were non-empty,
+          // there is no COUNT bug.
+          val (topPart, havingNode, aggNode) = splitSubquery(sub)
+          (aggNode.isDefined && aggNode.get.groupingExpressions.isEmpty)
+        } else {
+          // For idempotency, we must save this variable the first time this rule is run,
+          // because decorrelation introduces a GROUP BY if one wasn't already present.
+          mayHaveCountBugOld.get
+        }
+        ScalarSubquery(newPlan, children, exprId, getJoinCondition(newCond, conditions),
+          hint, Some(mayHaveCountBug))
       case Exists(sub, children, exprId, conditions, hint) if children.nonEmpty =>
         val (newPlan, newCond) = pullOutCorrelatedPredicates(sub, plan)
         Exists(newPlan, children, exprId, getJoinCondition(newCond, conditions), hint)
@@ -519,7 +534,7 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] with AliasHelpe
   * (optional second part) and the Aggregate below the HAVING CLAUSE (optional third part).
   * When the third part is empty, it means the subquery is a non-aggregated single-row subquery.
   */
-  private def splitSubquery(
+  def splitSubquery(
       plan: LogicalPlan): (Seq[LogicalPlan], Option[Filter], Option[Aggregate]) = {
     val topPart = ArrayBuffer.empty[LogicalPlan]
     var bottomPart: LogicalPlan = plan
@@ -569,7 +584,7 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] with AliasHelpe
       subqueries: ArrayBuffer[ScalarSubquery]): (LogicalPlan, AttributeMap[Attribute]) = {
     val subqueryAttrMapping = ArrayBuffer[(Attribute, Attribute)]()
     val newChild = subqueries.foldLeft(child) {
-      case (currentChild, ScalarSubquery(sub, _, _, conditions, subHint)) =>
+      case (currentChild, ScalarSubquery(sub, _, _, conditions, subHint, mayHaveCountBug)) =>
         val query = DecorrelateInnerQuery.rewriteDomainJoins(currentChild, sub, conditions)
         val origOutput = query.output.head
         // The subquery appears on the right side of the join, hence add its hint to the right
@@ -581,8 +596,16 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] with AliasHelpe
             currentChild.output :+ origOutput,
             Join(currentChild, query, LeftOuter, conditions.reduceOption(And), joinHint))
 
+        if (Utils.isTesting) {
+          assert(mayHaveCountBug.isDefined)
+        }
         if (resultWithZeroTups.isEmpty) {
-          // CASE 1: Subquery guaranteed not to have the COUNT bug
+          // CASE 1: Subquery guaranteed not to have the COUNT bug because it evaluates to NULL
+          // with zero tuples.
+          planWithoutCountBug
+        } else if (!mayHaveCountBug.getOrElse(true) &&
+            !conf.getConf(SQLConf.DECORRELATE_SUBQUERY_LEGACY_INCORRECT_COUNT_HANDLING_ENABLED)) {
+          // Subquery guaranteed not to have the COUNT bug because it had a non-empty GROUP BY clause
           planWithoutCountBug
         } else {
           val (topPart, havingNode, aggNode) = splitSubquery(query)
@@ -800,7 +823,7 @@ object OptimizeOneRowRelationSubquery extends Rule[LogicalPlan] {
     case p: LogicalPlan => p.transformExpressionsUpWithPruning(
       _.containsPattern(SCALAR_SUBQUERY)) {
-      case s @ ScalarSubquery(OneRowSubquery(p @ Project(_, _: OneRowRelation)), _, _, _, _)
+      case s @ ScalarSubquery(OneRowSubquery(p @ Project(_, _: OneRowRelation)), _, _, _, _, _)
           if !hasCorrelatedSubquery(s.plan) && s.joinCond.isEmpty =>
         assert(p.projectList.size == 1)
         stripOuterReferences(p.projectList).head
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 4986dc3661c06..e5023498513b0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -3251,6 +3251,15 @@ object SQLConf {
     .booleanConf
     .createWithDefault(true)
 
+  val DECORRELATE_SUBQUERY_LEGACY_INCORRECT_COUNT_HANDLING_ENABLED =
+    buildConf("spark.sql.optimizer.decorrelateSubqueryLegacyIncorrectCountHandling.enabled")
+      .internal()
+      .doc("If enabled, revert to legacy incorrect behavior for certain subqueries with COUNT or " +
+        "similar aggregates: see SPARK-43098.")
+      .version("3.5.0")
+      .booleanConf
+      .createWithDefault(false)
+
   val OPTIMIZE_ONE_ROW_RELATION_SUBQUERY =
     buildConf("spark.sql.optimizer.optimizeOneRowRelationSubquery")
       .internal()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/InsertAdaptiveSparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/InsertAdaptiveSparkPlan.scala
index cfc39617a0513..947a7314142fe 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/InsertAdaptiveSparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/InsertAdaptiveSparkPlan.scala
@@ -133,7 +133,7 @@ case class InsertAdaptiveSparkPlan(
       return subqueryMap.toMap
     }
     plan.foreach(_.expressions.filter(_.containsPattern(PLAN_EXPRESSION)).foreach(_.foreach {
-      case expressions.ScalarSubquery(p, _, exprId, _, _)
+      case expressions.ScalarSubquery(p, _, exprId, _, _, _)
          if !subqueryMap.contains(exprId.id) =>
        val executedPlan = compileSubquery(p)
        verifyAdaptivePlan(executedPlan, p)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveSubqueries.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveSubqueries.scala
index f88d2ffd541d3..c3f4274058350 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveSubqueries.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveSubqueries.scala
@@ -31,7 +31,7 @@ case class PlanAdaptiveSubqueries(
   def apply(plan: SparkPlan): SparkPlan = {
     plan.transformAllExpressionsWithPruning(
       _.containsAnyPattern(SCALAR_SUBQUERY, IN_SUBQUERY, DYNAMIC_PRUNING_SUBQUERY)) {
-      case expressions.ScalarSubquery(_, _, exprId, _, _) =>
+      case expressions.ScalarSubquery(_, _, exprId, _, _, _) =>
        execution.ScalarSubquery(subqueryMap(exprId.id), exprId)
      case expressions.InSubquery(values, ListQuery(_, _, exprId, _, _, _)) =>
        val expr = if (values.length == 1) {
diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/subquery/scalar-subquery/scalar-subquery-count-bug.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/subquery/scalar-subquery/scalar-subquery-count-bug.sql.out
new file mode 100644
index 0000000000000..908bc579d7b3f
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/analyzer-results/subquery/scalar-subquery/scalar-subquery-count-bug.sql.out
@@ -0,0 +1,228 @@
+-- Automatically generated by SQLQueryTestSuite
+-- !query
+create temp view l (a, b)
+as values
+  (1, 2.0),
+  (1, 2.0),
+  (2, 1.0),
+  (2, 1.0),
+  (3, 3.0),
+  (null, null),
+  (null, 5.0),
+  (6, null)
+-- !query analysis
+CreateViewCommand `l`, [(a,None), (b,None)], values
+  (1, 2.0),
+  (1, 2.0),
+  (2, 1.0),
+  (2, 1.0),
+  (3, 3.0),
+  (null, null),
+  (null, 5.0),
+  (6, null), false, false, LocalTempView, true
+   +- LocalRelation [col1#x, col2#x]
+
+
+-- !query
+create temp view r (c, d)
+as values
+  (2, 3.0),
+  (2, 3.0),
+  (3, 2.0),
+  (4, 1.0),
+  (null, null),
+  (null, 5.0),
+  (6, null)
+-- !query analysis
+CreateViewCommand `r`, [(c,None), (d,None)], values
+  (2, 3.0),
+  (2, 3.0),
+  (3, 2.0),
+  (4, 1.0),
+  (null, null),
+  (null, 5.0),
+  (6, null), false, false, LocalTempView, true
+   +- LocalRelation [col1#x, col2#x]
+
+
+-- !query
+select *, (select count(*) from r where l.a = r.c) from l
+-- !query analysis
+Project [a#x, b#x, scalar-subquery#x [a#x] AS scalarsubquery(a)#xL]
+:  +- Aggregate [count(1) AS count(1)#xL]
+:     +- Filter (outer(a#x) = c#x)
+:        +- SubqueryAlias r
+:           +- View (`r`, [c#x,d#x])
+:              +- Project [cast(col1#x as int) AS c#x, cast(col2#x as decimal(2,1)) AS d#x]
+:                 +- LocalRelation [col1#x, col2#x]
++- SubqueryAlias l
+   +- View (`l`, [a#x,b#x])
+      +- Project [cast(col1#x as int) AS a#x, cast(col2#x as decimal(2,1)) AS b#x]
+         +- LocalRelation [col1#x, col2#x]
+
+
+-- !query
+select *, (select count(*) from r where l.a = r.c group by c) from l
+-- !query analysis
+Project [a#x, b#x, scalar-subquery#x [a#x] AS scalarsubquery(a)#xL]
+:  +- Aggregate [c#x], [count(1) AS count(1)#xL]
+:     +- Filter (outer(a#x) = c#x)
+:        +- SubqueryAlias r
+:           +- View (`r`, [c#x,d#x])
+:              +- Project [cast(col1#x as int) AS c#x, cast(col2#x as decimal(2,1)) AS d#x]
+:                 +- LocalRelation [col1#x, col2#x]
++- SubqueryAlias l
+   +- View (`l`, [a#x,b#x])
+      +- Project [cast(col1#x as int) AS a#x, cast(col2#x as decimal(2,1)) AS b#x]
+         +- LocalRelation [col1#x, col2#x]
+
+
+-- !query
+select *, (select count(*) from r where l.a = r.c group by 'constant') from l
+-- !query analysis
+Project [a#x, b#x, scalar-subquery#x [a#x] AS scalarsubquery(a)#xL]
+:  +- Aggregate [constant], [count(1) AS count(1)#xL]
+:     +- Filter (outer(a#x) = c#x)
+:        +- SubqueryAlias r
+:           +- View (`r`, [c#x,d#x])
+:              +- Project [cast(col1#x as int) AS c#x, cast(col2#x as decimal(2,1)) AS d#x]
+:                 +- LocalRelation [col1#x, col2#x]
++- SubqueryAlias l
+   +- View (`l`, [a#x,b#x])
+      +- Project [cast(col1#x as int) AS a#x, cast(col2#x as decimal(2,1)) AS b#x]
+         +- LocalRelation [col1#x, col2#x]
+
+
+-- !query
+select *, (
+  select (count(*)) is null
+  from r
+  where l.a = r.c)
+from l
+-- !query analysis
+Project [a#x, b#x, scalar-subquery#x [a#x] AS scalarsubquery(a)#x]
+:  +- Aggregate [isnull(count(1)) AS (count(1) IS NULL)#x]
+:     +- Filter (outer(a#x) = c#x)
+:        +- SubqueryAlias r
+:           +- View (`r`, [c#x,d#x])
+:              +- Project [cast(col1#x as int) AS c#x, cast(col2#x as decimal(2,1)) AS d#x]
+:                 +- LocalRelation [col1#x, col2#x]
++- SubqueryAlias l
+   +- View (`l`, [a#x,b#x])
+      +- Project [cast(col1#x as int) AS a#x, cast(col2#x as decimal(2,1)) AS b#x]
+         +- LocalRelation [col1#x, col2#x]
+
+
+-- !query
+select *, (
+  select (count(*)) is null
+  from r
+  where l.a = r.c
+  group by r.c)
+from l
+-- !query analysis
+Project [a#x, b#x, scalar-subquery#x [a#x] AS scalarsubquery(a)#x]
+:  +- Aggregate [c#x], [isnull(count(1)) AS (count(1) IS NULL)#x]
+:     +- Filter (outer(a#x) = c#x)
+:        +- SubqueryAlias r
+:           +- View (`r`, [c#x,d#x])
+:              +- Project [cast(col1#x as int) AS c#x, cast(col2#x as decimal(2,1)) AS d#x]
+:                 +- LocalRelation [col1#x, col2#x]
++- SubqueryAlias l
+   +- View (`l`, [a#x,b#x])
+      +- Project [cast(col1#x as int) AS a#x, cast(col2#x as decimal(2,1)) AS b#x]
+         +- LocalRelation [col1#x, col2#x]
+
+
+-- !query
+select *, (select count(*) from r where l.a = r.c having count(*) <= 1) from l
+-- !query analysis
+Project [a#x, b#x, scalar-subquery#x [a#x] AS scalarsubquery(a)#xL]
+:  +- Filter (count(1)#xL <= cast(1 as bigint))
+:     +- Aggregate [count(1) AS count(1)#xL]
+:        +- Filter (outer(a#x) = c#x)
+:           +- SubqueryAlias r
+:              +- View (`r`, [c#x,d#x])
+:                 +- Project [cast(col1#x as int) AS c#x, cast(col2#x as decimal(2,1)) AS d#x]
+:                    +- LocalRelation [col1#x, col2#x]
++- SubqueryAlias l
+   +- View (`l`, [a#x,b#x])
+      +- Project [cast(col1#x as int) AS a#x, cast(col2#x as decimal(2,1)) AS b#x]
+         +- LocalRelation [col1#x, col2#x]
+
+
+-- !query
+select *, (select count(*) from r where l.a = r.c having count(*) >= 2) from l
+-- !query analysis
+Project [a#x, b#x, scalar-subquery#x [a#x] AS scalarsubquery(a)#xL]
+:  +- Filter (count(1)#xL >= cast(2 as bigint))
+:     +- Aggregate [count(1) AS count(1)#xL]
+:        +- Filter (outer(a#x) = c#x)
+:           +- SubqueryAlias r
+:              +- View (`r`, [c#x,d#x])
+:                 +- Project [cast(col1#x as int) AS c#x, cast(col2#x as decimal(2,1)) AS d#x]
+:                    +- LocalRelation [col1#x, col2#x]
++- SubqueryAlias l
+   +- View (`l`, [a#x,b#x])
+      +- Project [cast(col1#x as int) AS a#x, cast(col2#x as decimal(2,1)) AS b#x]
+         +- LocalRelation [col1#x, col2#x]
+
+
+-- !query
+set spark.sql.optimizer.decorrelateSubqueryLegacyIncorrectCountHandling.enabled = true
+-- !query analysis
+SetCommand (spark.sql.optimizer.decorrelateSubqueryLegacyIncorrectCountHandling.enabled,Some(true))
+
+
+-- !query
+select *, (select count(*) from r where l.a = r.c) from l
+-- !query analysis
+Project [a#x, b#x, scalar-subquery#x [a#x] AS scalarsubquery(a)#xL]
+:  +- Aggregate [count(1) AS count(1)#xL]
+:     +- Filter (outer(a#x) = c#x)
+:        +- SubqueryAlias r
+:           +- View (`r`, [c#x,d#x])
+:              +- Project [cast(col1#x as int) AS c#x, cast(col2#x as decimal(2,1)) AS d#x]
+:                 +- LocalRelation [col1#x, col2#x]
++- SubqueryAlias l
+   +- View (`l`, [a#x,b#x])
+      +- Project [cast(col1#x as int) AS a#x, cast(col2#x as decimal(2,1)) AS b#x]
+         +- LocalRelation [col1#x, col2#x]
+
+
+-- !query
+select *, (select count(*) from r where l.a = r.c group by c) from l
+-- !query analysis
+Project [a#x, b#x, scalar-subquery#x [a#x] AS scalarsubquery(a)#xL]
+:  +- Aggregate [c#x], [count(1) AS count(1)#xL]
+:     +- Filter (outer(a#x) = c#x)
+:        +- SubqueryAlias r
+:           +- View (`r`, [c#x,d#x])
+:              +- Project [cast(col1#x as int) AS c#x, cast(col2#x as decimal(2,1)) AS d#x]
+:                 +- LocalRelation [col1#x, col2#x]
++- SubqueryAlias l
+   +- View (`l`, [a#x,b#x])
+      +- Project [cast(col1#x as int) AS a#x, cast(col2#x as decimal(2,1)) AS b#x]
+         +- LocalRelation [col1#x, col2#x]
+
+
+-- !query
+select *, (select count(*) from r where l.a = r.c group by 'constant') from l
+-- !query analysis
+Project [a#x, b#x, scalar-subquery#x [a#x] AS scalarsubquery(a)#xL]
+:  +- Aggregate [constant], [count(1) AS count(1)#xL]
+:     +- Filter (outer(a#x) = c#x)
+:        +- SubqueryAlias r
+:           +- View (`r`, [c#x,d#x])
+:              +- Project [cast(col1#x as int) AS c#x, cast(col2#x as decimal(2,1)) AS d#x]
+:                 +- LocalRelation [col1#x, col2#x]
++- SubqueryAlias l
+   +- View (`l`, [a#x,b#x])
+      +- Project [cast(col1#x as int) AS a#x, cast(col2#x as decimal(2,1)) AS b#x]
+         +- LocalRelation [col1#x, col2#x]
+
+
+-- !query
+reset spark.sql.optimizer.decorrelateSubqueryLegacyIncorrectCountHandling.enabled
+-- !query analysis
+ResetCommand spark.sql.optimizer.decorrelateSubqueryLegacyIncorrectCountHandling.enabled
diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/scalar-subquery-count-bug.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/scalar-subquery-count-bug.sql
new file mode 100644
index 0000000000000..0ca2f07b301de
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/scalar-subquery-count-bug.sql
@@ -0,0 +1,58 @@
+create temp view l (a, b)
+as values
+  (1, 2.0),
+  (1, 2.0),
+  (2, 1.0),
+  (2, 1.0),
+  (3, 3.0),
+  (null, null),
+  (null, 5.0),
+  (6, null);
+
+create temp view r (c, d)
+as values
+  (2, 3.0),
+  (2, 3.0),
+  (3, 2.0),
+  (4, 1.0),
+  (null, null),
+  (null, 5.0),
+  (6, null);
+
+-- count bug, empty groups should evaluate to 0
+select *, (select count(*) from r where l.a = r.c) from l;
+
+-- no count bug, empty groups should evaluate to null
+select *, (select count(*) from r where l.a = r.c group by c) from l;
+select *, (select count(*) from r where l.a = r.c group by 'constant') from l;
+
+-- count bug, empty groups should evaluate to false, but this case is currently wrong due to bug SPARK-43156
+select *, (
+  select (count(*)) is null
+  from r
+  where l.a = r.c)
+from l;
+
+-- no count bug, empty groups should evaluate to null
+select *, (
+  select (count(*)) is null
+  from r
+  where l.a = r.c
+  group by r.c)
+from l;
+
+-- Empty groups should evaluate to 0, and groups filtered out by HAVING should evaluate to NULL
+select *, (select count(*) from r where l.a = r.c having count(*) <= 1) from l;
+
+-- Empty groups are filtered out by HAVING and should evaluate to null
+select *, (select count(*) from r where l.a = r.c having count(*) >= 2) from l;
+
+
+set spark.sql.optimizer.decorrelateSubqueryLegacyIncorrectCountHandling.enabled = true;
+
+-- With the legacy behavior flag set, these queries all evaluate to 0 for empty groups
+select *, (select count(*) from r where l.a = r.c) from l;
+select *, (select count(*) from r where l.a = r.c group by c) from l;
+select *, (select count(*) from r where l.a = r.c group by 'constant') from l;
+
+reset spark.sql.optimizer.decorrelateSubqueryLegacyIncorrectCountHandling.enabled;
diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-count-bug.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-count-bug.sql.out
new file mode 100644
index 0000000000000..3012b67cf8ce8
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-count-bug.sql.out
@@ -0,0 +1,207 @@
+-- Automatically generated by SQLQueryTestSuite
+-- !query
+create temp view l (a, b)
+as values
+  (1, 2.0),
+  (1, 2.0),
+  (2, 1.0),
+  (2, 1.0),
+  (3, 3.0),
+  (null, null),
+  (null, 5.0),
+  (6, null)
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+create temp view r (c, d)
+as values
+  (2, 3.0),
+  (2, 3.0),
+  (3, 2.0),
+  (4, 1.0),
+  (null, null),
+  (null, 5.0),
+  (6, null)
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+select *, (select count(*) from r where l.a = r.c) from l
+-- !query schema
+struct<a:int,b:decimal(2,1),scalarsubquery(a):bigint>
+-- !query output
+1	2.0	0
+1	2.0	0
+2	1.0	2
+2	1.0	2
+3	3.0	1
+6	NULL	1
+NULL	5.0	0
+NULL	NULL	0
+
+
+-- !query
+select *, (select count(*) from r where l.a = r.c group by c) from l
+-- !query schema
+struct<a:int,b:decimal(2,1),scalarsubquery(a):bigint>
+-- !query output
+1	2.0	NULL
+1	2.0	NULL
+2	1.0	2
+2	1.0	2
+3	3.0	1
+6	NULL	1
+NULL	5.0	NULL
+NULL	NULL	NULL
+
+
+-- !query
+select *, (select count(*) from r where l.a = r.c group by 'constant') from l
+-- !query schema
+struct<a:int,b:decimal(2,1),scalarsubquery(a):bigint>
+-- !query output
+1	2.0	NULL
+1	2.0	NULL
+2	1.0	2
+2	1.0	2
+3	3.0	1
+6	NULL	1
+NULL	5.0	NULL
+NULL	NULL	NULL
+
+
+-- !query
+select *, (
+  select (count(*)) is null
+  from r
+  where l.a = r.c)
+from l
+-- !query schema
+struct<a:int,b:decimal(2,1),scalarsubquery(a):boolean>
+-- !query output
+1	2.0	NULL
+1	2.0	NULL
+2	1.0	false
+2	1.0	false
+3	3.0	false
+6	NULL	false
+NULL	5.0	NULL
+NULL	NULL	NULL
+
+
+-- !query
+select *, (
+  select (count(*)) is null
+  from r
+  where l.a = r.c
+  group by r.c)
+from l
+-- !query schema
+struct<a:int,b:decimal(2,1),scalarsubquery(a):boolean>
+-- !query output
+1	2.0	NULL
+1	2.0	NULL
+2	1.0	false
+2	1.0	false
+3	3.0	false
+6	NULL	false
+NULL	5.0	NULL
+NULL	NULL	NULL
+
+
+-- !query
+select *, (select count(*) from r where l.a = r.c having count(*) <= 1) from l
+-- !query schema
+struct<a:int,b:decimal(2,1),scalarsubquery(a):bigint>
+-- !query output
+1	2.0	0
+1	2.0	0
+2	1.0	NULL
+2	1.0	NULL
+3	3.0	1
+6	NULL	1
+NULL	5.0	0
+NULL	NULL	0
+
+
+-- !query
+select *, (select count(*) from r where l.a = r.c having count(*) >= 2) from l
+-- !query schema
+struct<a:int,b:decimal(2,1),scalarsubquery(a):bigint>
+-- !query output
+1	2.0	NULL
+1	2.0	NULL
+2	1.0	2
+2	1.0	2
+3	3.0	NULL
+6	NULL	NULL
+NULL	5.0	NULL
+NULL	NULL	NULL
+
+
+-- !query
+set spark.sql.optimizer.decorrelateSubqueryLegacyIncorrectCountHandling.enabled = true
+-- !query schema
+struct<key:string,value:string>
+-- !query output
+spark.sql.optimizer.decorrelateSubqueryLegacyIncorrectCountHandling.enabled	true
+
+
+-- !query
+select *, (select count(*) from r where l.a = r.c) from l
+-- !query schema
+struct<a:int,b:decimal(2,1),scalarsubquery(a):bigint>
+-- !query output
+1	2.0	0
+1	2.0	0
+2	1.0	2
+2	1.0	2
+3	3.0	1
+6	NULL	1
+NULL	5.0	0
+NULL	NULL	0
+
+
+-- !query
+select *, (select count(*) from r where l.a = r.c group by c) from l
+-- !query schema
+struct<a:int,b:decimal(2,1),scalarsubquery(a):bigint>
+-- !query output
+1	2.0	0
+1	2.0	0
+2	1.0	2
+2	1.0	2
+3	3.0	1
+6	NULL	1
+NULL	5.0	0
+NULL	NULL	0
+
+
+-- !query
+select *, (select count(*) from r where l.a = r.c group by 'constant') from l
+-- !query schema
+struct<a:int,b:decimal(2,1),scalarsubquery(a):bigint>
+-- !query output
+1	2.0	0
+1	2.0	0
+2	1.0	2
+2	1.0	2
+3	3.0	1
+6	NULL	1
+NULL	5.0	0
+NULL	NULL	0
+
+
+-- !query
+reset spark.sql.optimizer.decorrelateSubqueryLegacyIncorrectCountHandling.enabled
+-- !query schema
+struct<>
+-- !query output
+
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
index 9246d360becbc..32d913ca3b425 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
@@ -663,14 +663,26 @@ class SubquerySuite extends QueryTest
     checkAnswer(
       sql(
         """
-          |select l.b, (select (r.c + count(*)) is null
+          |select l.b, (select (min(r.c) + count(*)) is null
           |from r
-          |where l.a = r.c group by r.c) from l
+          |where l.a = r.c) from l
         """.stripMargin),
       Row(1.0, false) :: Row(1.0, false) :: Row(2.0, true) :: Row(2.0, true) ::
         Row(3.0, false) :: Row(5.0, true) :: Row(null, false) :: Row(null, true) :: Nil)
   }
 
+  test("SPARK-43098: no COUNT bug with group-by") {
+    checkAnswer(
+      sql(
+        """
+          |select l.b, (select (r.c + count(*)) is null
+          |from r
+          |where l.a = r.c group by r.c) from l
+        """.stripMargin),
+      Row(1.0, false) :: Row(1.0, false) :: Row(2.0, null) :: Row(2.0, null) ::
+        Row(3.0, false) :: Row(5.0, null) :: Row(null, false) :: Row(null, null) :: Nil)
+  }
+
   test("SPARK-16804: Correlated subqueries containing LIMIT - 1") {
     withTempView("onerow") {
       Seq(1).toDF("c1").createOrReplaceTempView("onerow")
@@ -1844,8 +1856,8 @@ class SubquerySuite extends QueryTest
           |  )
           |FROM l
         """.stripMargin),
-      Row(1.0, false) :: Row(1.0, false) :: Row(2.0, true) :: Row(2.0, true) ::
-        Row(3.0, false) :: Row(5.0, true) :: Row(null, false) :: Row(null, true) :: Nil)
+      Row(1.0, false) :: Row(1.0, false) :: Row(2.0, null) :: Row(2.0, null) ::
+        Row(3.0, false) :: Row(5.0, null) :: Row(null, false) :: Row(null, null) :: Nil)
   }
 
   test("SPARK-28441: COUNT bug with non-foldable expression") {
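
Reviewer note: below is a minimal, self-contained Scala sketch (not part of the patch) that exercises the two behaviors this change distinguishes. The data mirrors the `l`/`r` test views above; the object name and local SparkSession setup are illustrative assumptions rather than Spark test code.

import org.apache.spark.sql.SparkSession

object CountBugDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("SPARK-43098-demo")
      .getOrCreate()
    import spark.implicits._

    // Small versions of the l/r views used in the new SQL tests.
    Seq((Some(1), Some(2.0)), (Some(2), Some(1.0)), (Some(3), Some(3.0)), (None, None))
      .toDF("a", "b").createOrReplaceTempView("l")
    Seq((Some(2), Some(3.0)), (Some(3), Some(2.0)), (Some(4), Some(1.0)))
      .toDF("c", "d").createOrReplaceTempView("r")

    // COUNT-bug case: no GROUP BY in the subquery, so an outer row with no
    // matches must see count(*) = 0, never NULL.
    spark.sql("select a, (select count(*) from r where l.a = r.c) as cnt from l").show()

    // No COUNT bug: the subquery has a GROUP BY, so empty input produces no
    // groups and the scalar subquery is NULL. Before this patch the optimizer
    // applied the count-bug rewrite here too and incorrectly returned 0.
    spark.sql("select a, (select count(*) from r where l.a = r.c group by c) as cnt from l").show()

    spark.stop()
  }
}

Running the second query with spark.sql.optimizer.decorrelateSubqueryLegacyIncorrectCountHandling.enabled = true restores the old (incorrect) 0 results, matching the legacy section of the new golden files.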