Skip to content

Commit

Permalink
[SPARK-42151][SQL] Align UPDATE assignments with table attributes
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?

This PR adds a rule to align UPDATE assignments with table attributes.

### Why are the changes needed?

These changes are needed so that we can rewrite UPDATE statements into executable plans for tables that support row-level operations. In particular, our row-level mutation framework assumes that Spark is responsible for building an updated version of each affected row, and that this updated row is then passed back to the data source.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

This PR comes with tests.

Closes apache#40308 from aokolnychyi/spark-42151-v2.

Authored-by: aokolnychyi <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
  • Loading branch information
aokolnychyi authored and cloud-fan committed Apr 18, 2023
1 parent 66392c4 commit 1c057f5
Show file tree
Hide file tree
Showing 10 changed files with 1,182 additions and 55 deletions.
5 changes: 5 additions & 0 deletions core/src/main/resources/error/error-classes.json
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,11 @@
"The <functionName> does not support ordering on type <dataType>."
]
},
"INVALID_ROW_LEVEL_OPERATION_ASSIGNMENTS" : {
"message" : [
"<errors>"
]
},
"IN_SUBQUERY_DATA_TYPE_MISMATCH" : {
"message" : [
"The data type of one or more elements in the left hand side of an IN subquery is not compatible with the data type of the output of the subquery. Mismatched columns: [<mismatchedColumns>], left side: [<leftType>], right side: [<rightType>]."
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
ResolveRandomSeed ::
ResolveBinaryArithmetic ::
ResolveUnion ::
ResolveRowLevelCommandAssignments ::
RewriteDeleteFromTable ::
typeCoercionRules ++
Seq(
Expand Down Expand Up @@ -3329,43 +3330,6 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
} else {
v2Write
}

case u: UpdateTable if !u.skipSchemaResolution && u.resolved =>
resolveAssignments(u)

case m: MergeIntoTable if !m.skipSchemaResolution && m.resolved =>
resolveAssignments(m)
}

/**
 * Rewrites every [[Assignment]] expression found in the given plan.
 *
 * For each assignment the value is, in this order:
 *  - wrapped in `AssertNotNull` when the key is not nullable but the value is,
 *  - cast to the key's data type (ANSI cast tagged with `Cast.BY_TABLE_INSERTION`) when the
 *    types differ,
 *  - passed through `CharVarcharUtils.stringLengthCheck` when the key's raw type (restored from
 *    attribute metadata) contains a char/varchar type.
 *
 * The key is rebuilt with `CharVarcharUtils.cleanAttrMetadata` applied to each attribute
 * reference it contains.
 *
 * @param p plan whose assignment expressions should be resolved
 * @return the plan with all assignments rewritten
 */
private def resolveAssignments(p: LogicalPlan): LogicalPlan = {
  p.transformExpressions {
    case assignment: Assignment =>
      val target = assignment.key
      val source = assignment.value

      // A null check is only needed when a nullable value flows into a non-nullable key.
      val nonNullSource =
        if (target.nullable || !source.nullable) source else AssertNotNull(source)

      val typedSource =
        if (target.dataType == nonNullSource.dataType) {
          nonNullSource
        } else {
          val cast = Cast(nonNullSource, target.dataType, ansiEnabled = true)
          cast.setTagValue(Cast.BY_TABLE_INSERTION, ())
          cast
        }

      // Restore the raw type recorded in attribute metadata, when there is one.
      val rawTargetType = target.transform {
        case attr: AttributeReference =>
          CharVarcharUtils.getRawType(attr.metadata) match {
            case Some(rawType) => attr.withDataType(rawType)
            case None => attr
          }
      }.dataType

      val checkedSource =
        if (CharVarcharUtils.hasCharVarchar(rawTargetType)) {
          CharVarcharUtils.stringLengthCheck(typedSource, rawTargetType)
        } else {
          typedSource
        }

      val cleanedTarget = target.transform {
        case attr: AttributeReference => CharVarcharUtils.cleanAttrMetadata(attr)
      }

      Assignment(cleanedTarget, checkedSource)
  }
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.analysis

import scala.collection.mutable

import org.apache.spark.sql.catalyst.SQLConfHelper
import org.apache.spark.sql.catalyst.expressions.{Attribute, CreateNamedStruct, Expression, GetStructField, Literal}
import org.apache.spark.sql.catalyst.plans.logical.Assignment
import org.apache.spark.sql.catalyst.util.CharVarcharUtils
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.types.{DataType, StructType}

object AssignmentUtils extends SQLConfHelper with CastSupport {

  /**
   * Reorders and completes assignments so that there is exactly one per table attribute.
   * <p>
   * The result follows the order of `attrs`. An attribute without a matching assignment is
   * assigned its own current value: given table attributes c1, c2 and the single assignment
   * c2 = 1, the result is c1 = c1, c2 = 1. This lets Spark build the updated version of a row.
   * <p>
   * Assignments to nested fields are folded into a rebuilt struct in which only the assigned
   * fields change: if c2 is a struct with fields n1 and n2, the assignment c2.n2 = 1 yields
   * c1 = c1, c2 = struct(c2.n1, 1).
   * <p>
   * Problems are collected across all attributes first and then reported together through
   * `QueryCompilationErrors.invalidRowLevelOperationAssignments`.
   *
   * @param attrs table attributes
   * @param assignments assignments to align
   * @return one assignment per table attribute, in the order of `attrs`
   */
  def alignAssignments(
      attrs: Seq[Attribute],
      assignments: Seq[Assignment]): Seq[Assignment] = {

    val problems = mutable.ArrayBuffer.empty[String]
    val recordProblem: String => Unit = message => problems += message

    val alignedValues = attrs.map { attr =>
      applyAssignments(restoreActualType(attr), attr, assignments, recordProblem, Seq(attr.name))
    }

    // Report only after every attribute has been processed so that all problems are included.
    if (problems.nonEmpty) {
      throw QueryCompilationErrors.invalidRowLevelOperationAssignments(assignments, problems.toSeq)
    }

    attrs.zip(alignedValues).map { case (target, newValue) => Assignment(target, newValue) }
  }

  /**
   * Returns the attribute with the raw type stored in its metadata, if any; otherwise the
   * attribute keeps its current data type.
   */
  private def restoreActualType(attr: Attribute): Attribute = {
    val actualType = CharVarcharUtils.getRawType(attr.metadata).getOrElse(attr.dataType)
    attr.withDataType(actualType)
  }

  /**
   * Computes the expression a single column (or nested field) should be set to.
   *
   * Assignments whose key is semantically equal to `colExpr` target the column itself; the
   * remaining assignments whose key contains `colExpr` target something nested inside it.
   *
   * @param col attribute describing the column or field being assigned
   * @param colExpr expression that reads the current value of the column or field
   * @param assignments candidate assignments
   * @param addError callback used to report problems instead of failing immediately
   * @param colPath name parts leading to the column or field, used in error messages
   * @return the new value expression, or `colExpr` itself when a problem was reported
   */
  private def applyAssignments(
      col: Attribute,
      colExpr: Expression,
      assignments: Seq[Assignment],
      addError: String => Unit,
      colPath: Seq[String]): Expression = {

    val (exactMatches, remaining) = assignments.partition { candidate =>
      candidate.key.semanticEquals(colExpr)
    }
    val nestedMatches = remaining.filter { candidate =>
      candidate.key.exists(part => part.semanticEquals(colExpr))
    }

    exactMatches match {
      // Nothing refers to this column: keep the current value, subject to the nullability check.
      case Seq() if nestedMatches.isEmpty =>
        TableOutputResolver.checkNullability(colExpr, col, conf, colPath)

      // Only nested fields are assigned: descend into the column.
      case Seq() =>
        applyFieldAssignments(col, colExpr, nestedMatches, addError, colPath)

      // Exactly one assignment to the column and nothing nested.
      case Seq(only) if nestedMatches.isEmpty =>
        TableOutputResolver.resolveUpdate(only.value, col, conf, addError, colPath)

      // The column is assigned as a whole and through one of its nested fields.
      case Seq(_) =>
        val conflictsStr = (exactMatches ++ nestedMatches).map(_.sql).mkString(", ")
        addError(s"Conflicting assignments for '${colPath.quoted}': $conflictsStr")
        colExpr

      // More than one assignment to the very same column.
      case _ =>
        val valuesStr = exactMatches.map(_.value.sql).mkString(", ")
        addError(s"Multiple assignments for '${colPath.quoted}': $valuesStr")
        colExpr
    }
  }

  /**
   * Applies assignments that target fields nested inside the given column.
   *
   * A struct column is rebuilt field by field, with each field resolved through
   * `applyAssignments`. Any other data type is reported as an error and the column keeps its
   * current value.
   */
  private def applyFieldAssignments(
      col: Attribute,
      colExpr: Expression,
      assignments: Seq[Assignment],
      addError: String => Unit,
      colPath: Seq[String]): Expression = {

    col.dataType match {
      case struct: StructType =>
        val rebuiltFields = struct.toAttributes.zip(struct.fields).zipWithIndex.map {
          case ((fieldAttr, field), ordinal) =>
            val currentValue = GetStructField(colExpr, ordinal, Some(field.name))
            val fieldPath = colPath :+ fieldAttr.name
            applyAssignments(fieldAttr, currentValue, assignments, addError, fieldPath)
        }
        toNamedStruct(struct, rebuiltFields)

      case unsupported =>
        addError(
          "Updating nested fields is only supported for StructType but " +
          s"'${colPath.quoted}' is of type $unsupported")
        colExpr
    }
  }

  /**
   * Builds a `CreateNamedStruct` whose children alternate between a field name literal and
   * the expression for that field, following the field order of `structType`.
   */
  private def toNamedStruct(structType: StructType, fieldExprs: Seq[Expression]): Expression = {
    val children = structType.fields.zip(fieldExprs).flatMap { case (field, value) =>
      Literal(field.name) :: value :: Nil
    }
    CreateNamedStruct(children)
  }

  /**
   * Checks whether assignments are aligned and compatible with table columns.
   * <p>
   * Assignments are aligned when there is one per attribute and, position by position, the key
   * is an attribute whose name resolves to the table attribute's name, the value type equals
   * the attribute's raw type ignoring compatible nullability, and a nullable value is never
   * paired with a non-nullable attribute.
   *
   * @param attrs table attributes
   * @param assignments assignments to check
   * @return true if the assignments are aligned
   */
  def aligned(attrs: Seq[Attribute], assignments: Seq[Assignment]): Boolean = {

    def isAligned(attr: Attribute, assignment: Assignment): Boolean = {
      val expectedType = CharVarcharUtils.getRawType(attr.metadata).getOrElse(attr.dataType)
      val keyMatchesAttr = assignment.key match {
        case key: Attribute => conf.resolver(key.name, attr.name)
        case _ => false
      }
      // Defined as methods so they are evaluated only if the preceding checks pass.
      def valueTypeMatches: Boolean =
        DataType.equalsIgnoreCompatibleNullability(assignment.value.dataType, expectedType)
      def nullabilityAllowed: Boolean = attr.nullable || !assignment.value.nullable

      keyMatchesAttr && valueTypeMatches && nullabilityAllowed
    }

    attrs.size == assignments.size &&
      attrs.zip(assignments).forall { case (attr, assignment) => isAligned(attr, assignment) }
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Cast}
import org.apache.spark.sql.catalyst.expressions.objects.AssertNotNull
import org.apache.spark.sql.catalyst.plans.logical.{Assignment, LogicalPlan, MergeIntoTable, UpdateTable}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.TreePattern.COMMAND
import org.apache.spark.sql.catalyst.util.CharVarcharUtils
import org.apache.spark.sql.connector.catalog.SupportsRowLevelOperations
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
import org.apache.spark.sql.internal.SQLConf.StoreAssignmentPolicy

/**
* A rule that resolves assignments in row-level commands.
*
* Note that this rule must be run before rewriting row-level commands into executable plans.
* This rule does not apply to tables that accept any schema. Such tables must inject their own
* rules to resolve assignments.
*/
object ResolveRowLevelCommandAssignments extends Rule[LogicalPlan] {

  /**
   * Resolves assignments of UPDATE and MERGE commands found in the plan.
   *
   * An UPDATE that is resolved, not yet aligned and does not skip schema resolution is either
   * aligned with the table attributes (when the target supports row-level operations) or has
   * its assignments resolved in place. A MERGE that is resolved and does not skip schema
   * resolution always has its assignments resolved in place.
   */
  override def apply(plan: LogicalPlan): LogicalPlan = {
    plan.resolveOperatorsWithPruning(_.containsPattern(COMMAND), ruleId) {
      case update: UpdateTable
          if !update.skipSchemaResolution && update.resolved && !update.aligned =>
        if (supportsRowLevelOperations(update.table)) {
          alignUpdate(update)
        } else {
          resolveAssignments(update)
        }

      case merge: MergeIntoTable if !merge.skipSchemaResolution && merge.resolved =>
        resolveAssignments(merge)
    }
  }

  /**
   * Aligns the assignments of an UPDATE with the attributes of its target table and strips
   * char/varchar metadata from the output of every `DataSourceV2Relation` in the target.
   *
   * The store assignment policy is validated before any rewriting takes place. Alignment uses
   * the original table output, i.e. the attributes before their metadata is cleaned.
   */
  private def alignUpdate(update: UpdateTable): UpdateTable = {
    validateStoreAssignmentPolicy()
    val cleanedTable = update.table.transform {
      case relation: DataSourceV2Relation =>
        val cleanedOutput = relation.output.map(attr => CharVarcharUtils.cleanAttrMetadata(attr))
        relation.copy(output = cleanedOutput)
    }
    val alignedAssignments =
      AssignmentUtils.alignAssignments(update.table.output, update.assignments)
    update.copy(table = cleanedTable, assignments = alignedAssignments)
  }

  /** Fails when the configured store assignment policy is LEGACY. */
  private def validateStoreAssignmentPolicy(): Unit = {
    // SPARK-28730: LEGACY store assignment policy is disallowed in data source v2
    conf.storeAssignmentPolicy match {
      case StoreAssignmentPolicy.LEGACY =>
        throw QueryCompilationErrors.legacyStoreAssignmentPolicyError()
      case _ => ()
    }
  }

  /**
   * Returns true when the plan, after eliminating subquery aliases, is a `DataSourceV2Relation`
   * whose table implements `SupportsRowLevelOperations`.
   */
  private def supportsRowLevelOperations(table: LogicalPlan): Boolean = {
    PartialFunction.cond(EliminateSubqueryAliases(table)) {
      case DataSourceV2Relation(_: SupportsRowLevelOperations, _, _, _, _) => true
    }
  }

  /**
   * Rewrites every [[Assignment]] expression found in the given plan.
   *
   * For each assignment the value is, in this order:
   *  - wrapped in `AssertNotNull` when the key is not nullable but the value is,
   *  - cast to the key's data type (ANSI cast tagged with `Cast.BY_TABLE_INSERTION`) when the
   *    types differ,
   *  - passed through `CharVarcharUtils.stringLengthCheck` when the key's raw type (restored
   *    from attribute metadata) contains a char/varchar type.
   *
   * The key is rebuilt with `CharVarcharUtils.cleanAttrMetadata` applied to each attribute
   * reference it contains.
   */
  private def resolveAssignments(p: LogicalPlan): LogicalPlan = {
    p.transformExpressions {
      case assignment: Assignment =>
        val target = assignment.key
        val source = assignment.value

        // A null check is only needed when a nullable value flows into a non-nullable key.
        val nonNullSource =
          if (target.nullable || !source.nullable) source else AssertNotNull(source)

        val typedSource =
          if (target.dataType == nonNullSource.dataType) {
            nonNullSource
          } else {
            val cast = Cast(nonNullSource, target.dataType, ansiEnabled = true)
            cast.setTagValue(Cast.BY_TABLE_INSERTION, ())
            cast
          }

        // Restore the raw type recorded in attribute metadata, when there is one.
        val rawTargetType = target.transform {
          case attr: AttributeReference =>
            CharVarcharUtils.getRawType(attr.metadata) match {
              case Some(rawType) => attr.withDataType(rawType)
              case None => attr
            }
        }.dataType

        val checkedSource =
          if (CharVarcharUtils.hasCharVarchar(rawTargetType)) {
            CharVarcharUtils.stringLengthCheck(typedSource, rawTargetType)
          } else {
            typedSource
          }

        val cleanedTarget = target.transform {
          case attr: AttributeReference => CharVarcharUtils.cleanAttrMetadata(attr)
        }

        Assignment(cleanedTarget, checkedSource)
    }
  }
}
Loading

0 comments on commit 1c057f5

Please sign in to comment.