Skip to content

Commit

Permalink
[SPARK-43190][SQL] ListQuery.childOutput should be consistent with ch…
Browse files Browse the repository at this point in the history
…ild output

### What changes were proposed in this pull request?

Update `ListQuery` to only store the number of columns of the original plan, instead of directly storing the original plan output attributes.

### Why are the changes needed?

Storing the plan output attributes is troublesome as we have to maintain them and keep them in sync with the plan. For example, `DeduplicateRelations` may change the plan output, and today we do not update `ListQuery.childOutputs` to keep it in sync.

`ListQuery.childOutputs` was added by apache#18968. It is only used to track the original plan output attributes, because subquery de-correlation may add more columns to the plan. We can achieve the same thing by storing only the number of columns of the original plan.

### Does this PR introduce _any_ user-facing change?

No, there is no user-facing bug exposed.

### How was this patch tested?

A new plan test in `AnalysisSuite`.

Closes apache#40851 from cloud-fan/list_query.

Authored-by: Wenchen Fan <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
  • Loading branch information
cloud-fan committed Apr 20, 2023
1 parent 09a4353 commit 9e17731
Show file tree
Hide file tree
Showing 10 changed files with 45 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2425,7 +2425,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
case InSubquery(values, l @ ListQuery(_, _, exprId, _, _, _))
if values.forall(_.resolved) && !l.resolved =>
val expr = resolveSubQuery(l, outer)((plan, exprs) => {
ListQuery(plan, exprs, exprId, plan.output)
ListQuery(plan, exprs, exprId, plan.output.length)
})
InSubquery(values, expr.asInstanceOf[ListQuery])
case s @ LateralSubquery(sub, _, exprId, _, _) if !sub.resolved =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -361,11 +361,11 @@ abstract class TypeCoercionBase {

// Handle type casting required between value expression and subquery output
// in IN subquery.
case i @ InSubquery(lhs, ListQuery(sub, children, exprId, _, conditions, _))
if !i.resolved && lhs.length == sub.output.length =>
case i @ InSubquery(lhs, l: ListQuery)
if !i.resolved && lhs.length == l.plan.output.length =>
// LHS is the value expressions of IN subquery.
// RHS is the subquery output.
val rhs = sub.output
val rhs = l.plan.output

val commonTypes = lhs.zip(rhs).flatMap { case (l, r) =>
findWiderTypeForTwo(l.dataType, r.dataType)
Expand All @@ -383,8 +383,7 @@ abstract class TypeCoercionBase {
case (e, _) => e
}

val newSub = Project(castedRhs, sub)
InSubquery(newLhs, ListQuery(newSub, children, exprId, newSub.output, conditions))
InSubquery(newLhs, l.withNewPlan(Project(castedRhs, l.plan)))
} else {
i
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -367,12 +367,12 @@ case class InSubquery(values: Seq[Expression], query: ListQuery)
final override val nodePatterns: Seq[TreePattern] = Seq(IN_SUBQUERY)

override def checkInputDataTypes(): TypeCheckResult = {
if (values.length != query.childOutputs.length) {
if (values.length != query.numCols) {
DataTypeMismatch(
errorSubClass = "IN_SUBQUERY_LENGTH_MISMATCH",
messageParameters = Map(
"leftLength" -> values.length.toString,
"rightLength" -> query.childOutputs.length.toString,
"rightLength" -> query.numCols.toString,
"leftColumns" -> values.map(toSQLExpr(_)).mkString(", "),
"rightColumns" -> query.childOutputs.map(toSQLExpr(_)).mkString(", ")
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -354,16 +354,19 @@ case class ListQuery(
plan: LogicalPlan,
outerAttrs: Seq[Expression] = Seq.empty,
exprId: ExprId = NamedExpression.newExprId,
childOutputs: Seq[Attribute] = Seq.empty,
// The plan of list query may have more columns after de-correlation, and we need to track the
// number of the columns of the original plan, to report the data type properly.
numCols: Int = -1,
joinCond: Seq[Expression] = Seq.empty,
hint: Option[HintInfo] = None)
extends SubqueryExpression(plan, outerAttrs, exprId, joinCond, hint) with Unevaluable {
override def dataType: DataType = if (childOutputs.length > 1) {
def childOutputs: Seq[Attribute] = plan.output.take(numCols)
override def dataType: DataType = if (numCols > 1) {
childOutputs.toStructType
} else {
childOutputs.head.dataType
plan.output.head.dataType
}
override lazy val resolved: Boolean = childrenResolved && plan.resolved && childOutputs.nonEmpty
override lazy val resolved: Boolean = childrenResolved && plan.resolved && numCols != -1
override def nullable: Boolean = false
override def withNewPlan(plan: LogicalPlan): ListQuery = copy(plan = plan)
override def withNewHint(hint: Option[HintInfo]): ListQuery = copy(hint = hint)
Expand All @@ -373,7 +376,7 @@ case class ListQuery(
plan.canonicalized,
outerAttrs.map(_.canonicalized),
ExprId(0),
childOutputs.map(_.canonicalized.asInstanceOf[Attribute]),
numCols,
joinCond.map(_.canonicalized))
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with J
return filterApplicationSidePlan
}
val filter = InSubquery(Seq(mayWrapWithHash(filterApplicationSideExp)),
ListQuery(aggregate, childOutputs = aggregate.output))
ListQuery(aggregate, numCols = aggregate.output.length))
Filter(filter, filterApplicationSidePlan)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -346,10 +346,10 @@ object PullupCorrelatedPredicates extends Rule[LogicalPlan] with PredicateHelper
case Exists(sub, children, exprId, conditions, hint) if children.nonEmpty =>
val (newPlan, newCond) = pullOutCorrelatedPredicates(sub, plan)
Exists(newPlan, children, exprId, getJoinCondition(newCond, conditions), hint)
case ListQuery(sub, children, exprId, childOutputs, conditions, hint) if children.nonEmpty =>
case ListQuery(sub, children, exprId, numCols, conditions, hint) if children.nonEmpty =>
val (newPlan, newCond) = pullOutCorrelatedPredicates(sub, plan)
val joinCond = getJoinCondition(newCond, conditions)
ListQuery(newPlan, children, exprId, childOutputs, joinCond, hint)
ListQuery(newPlan, children, exprId, numCols, joinCond, hint)
case LateralSubquery(sub, children, exprId, conditions, hint) if children.nonEmpty =>
val (newPlan, newCond) = decorrelate(sub, plan, handleCountBug = true)
LateralSubquery(newPlan, children, exprId, getJoinCondition(newCond, conditions), hint)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1500,4 +1500,28 @@ class AnalysisSuite extends AnalysisTest with Matchers {
assert(refs.map(_.output).distinct.length == 3)
}
}

test("SPARK-43190: ListQuery.childOutput should be consistent with child output") {
val listQuery1 = ListQuery(testRelation2.select($"a"))
val listQuery2 = ListQuery(testRelation2.select($"b"))
val plan = testRelation3.where($"f".in(listQuery1) && $"f".in(listQuery2)).analyze
val resolvedCondition = plan.expressions.head
val finalPlan = testRelation2.join(testRelation3).where(resolvedCondition).analyze
val resolvedListQueries = finalPlan.expressions.flatMap(_.collect {
case l: ListQuery => l
})
assert(resolvedListQueries.length == 2)

def collectLocalRelations(plan: LogicalPlan): Seq[LocalRelation] = plan.collect {
case l: LocalRelation => l
}
val localRelations = resolvedListQueries.flatMap(l => collectLocalRelations(l.plan))
assert(localRelations.length == 2)
// DeduplicateRelations should deduplicate plans in subquery expressions as well.
assert(localRelations.head.output != localRelations.last.output)

resolvedListQueries.foreach { l =>
assert(l.childOutputs == l.plan.output)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -78,10 +78,7 @@ trait PlanTestBase extends PredicateHelper with SQLHelper with SQLConfHelper { s
case e: Exists =>
e.copy(plan = normalizeExprIds(e.plan), exprId = ExprId(0))
case l: ListQuery =>
l.copy(
plan = normalizeExprIds(l.plan),
exprId = ExprId(0),
childOutputs = l.childOutputs.map(_.withExprId(ExprId(0))))
l.copy(plan = normalizeExprIds(l.plan), exprId = ExprId(0))
case a: AttributeReference =>
AttributeReference(a.name, a.dataType, a.nullable)(exprId = ExprId(0))
case OuterReference(a: AttributeReference) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ case class PlanDynamicPruningFilters(sparkSession: SparkSession) extends Rule[Sp
val alias = Alias(buildKeys(broadcastKeyIndex), buildKeys(broadcastKeyIndex).toString)()
val aggregate = Aggregate(Seq(alias), Seq(alias), buildPlan)
DynamicPruningExpression(expressions.InSubquery(
Seq(value), ListQuery(aggregate, childOutputs = aggregate.output)))
Seq(value), ListQuery(aggregate, numCols = aggregate.output.length)))
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,6 @@ class RowLevelOperationRuntimeGroupFiltering(optimizeSubqueries: Rule[LogicalPla

val buildQuery = Aggregate(buildKeys, buildKeys, matchingRowsPlan)
DynamicPruningExpression(
InSubquery(pruningKeys, ListQuery(buildQuery, childOutputs = buildQuery.output)))
InSubquery(pruningKeys, ListQuery(buildQuery, numCols = buildQuery.output.length)))
}
}

0 comments on commit 9e17731

Please sign in to comment.