forked from apache/spark
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[SPARK-42151][SQL] Align UPDATE assignments with table attributes
### What changes were proposed in this pull request? This PR adds a rule to align UPDATE assignments with table attributes. ### Why are the changes needed? These changes are needed so that we can rewrite UPDATE statements into executable plans for tables that support row-level operations. In particular, our row-level mutation framework assumes Spark is responsible for building an updated version of each affected row and that row is passed back to the data source. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? This PR comes with tests. Closes apache#40308 from aokolnychyi/spark-42151-v2. Authored-by: aokolnychyi <[email protected]> Signed-off-by: Wenchen Fan <[email protected]>
- Loading branch information
1 parent
66392c4
commit 1c057f5
Showing
10 changed files
with
1,182 additions
and
55 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
167 changes: 167 additions & 0 deletions
167
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AssignmentUtils.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,167 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package org.apache.spark.sql.catalyst.analysis | ||
|
||
import scala.collection.mutable | ||
|
||
import org.apache.spark.sql.catalyst.SQLConfHelper | ||
import org.apache.spark.sql.catalyst.expressions.{Attribute, CreateNamedStruct, Expression, GetStructField, Literal} | ||
import org.apache.spark.sql.catalyst.plans.logical.Assignment | ||
import org.apache.spark.sql.catalyst.util.CharVarcharUtils | ||
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ | ||
import org.apache.spark.sql.errors.QueryCompilationErrors | ||
import org.apache.spark.sql.types.{DataType, StructType} | ||
|
||
object AssignmentUtils extends SQLConfHelper with CastSupport { | ||
|
||
/** | ||
* Aligns assignments to match table columns. | ||
* <p> | ||
* This method processes and reorders given assignments so that each target column gets | ||
* an expression it should be set to. If a column does not have a matching assignment, | ||
* it will be set to its current value. For example, if one passes table attributes c1, c2 | ||
* and an assignment c2 = 1, this method will return c1 = c1, c2 = 1. This allows Spark to | ||
* construct an updated version of a row. | ||
* <p> | ||
* This method also handles updates to nested columns. If there is an assignment to a particular | ||
* nested field, this method will construct a new struct with one field updated preserving other | ||
* fields that have not been modified. For example, if one passes table attributes c1, c2 | ||
* where c2 is a struct with fields n1 and n2 and an assignment c2.n2 = 1, this method will | ||
* return c1 = c1, c2 = struct(c2.n1, 1). | ||
* | ||
* @param attrs table attributes | ||
* @param assignments assignments to align | ||
* @return aligned assignments that match table attributes | ||
*/ | ||
def alignAssignments( | ||
attrs: Seq[Attribute], | ||
assignments: Seq[Assignment]): Seq[Assignment] = { | ||
|
||
val errors = new mutable.ArrayBuffer[String]() | ||
|
||
val output = attrs.map { attr => | ||
applyAssignments( | ||
col = restoreActualType(attr), | ||
colExpr = attr, | ||
assignments, | ||
addError = err => errors += err, | ||
colPath = Seq(attr.name)) | ||
} | ||
|
||
if (errors.nonEmpty) { | ||
throw QueryCompilationErrors.invalidRowLevelOperationAssignments(assignments, errors.toSeq) | ||
} | ||
|
||
attrs.zip(output).map { case (attr, expr) => Assignment(attr, expr) } | ||
} | ||
|
||
private def restoreActualType(attr: Attribute): Attribute = { | ||
attr.withDataType(CharVarcharUtils.getRawType(attr.metadata).getOrElse(attr.dataType)) | ||
} | ||
|
||
private def applyAssignments( | ||
col: Attribute, | ||
colExpr: Expression, | ||
assignments: Seq[Assignment], | ||
addError: String => Unit, | ||
colPath: Seq[String]): Expression = { | ||
|
||
val (exactAssignments, otherAssignments) = assignments.partition { assignment => | ||
assignment.key.semanticEquals(colExpr) | ||
} | ||
|
||
val fieldAssignments = otherAssignments.filter { assignment => | ||
assignment.key.exists(_.semanticEquals(colExpr)) | ||
} | ||
|
||
if (exactAssignments.size > 1) { | ||
val conflictingValuesStr = exactAssignments.map(_.value.sql).mkString(", ") | ||
addError(s"Multiple assignments for '${colPath.quoted}': $conflictingValuesStr") | ||
colExpr | ||
} else if (exactAssignments.nonEmpty && fieldAssignments.nonEmpty) { | ||
val conflictingAssignments = exactAssignments ++ fieldAssignments | ||
val conflictingAssignmentsStr = conflictingAssignments.map(_.sql).mkString(", ") | ||
addError(s"Conflicting assignments for '${colPath.quoted}': $conflictingAssignmentsStr") | ||
colExpr | ||
} else if (exactAssignments.isEmpty && fieldAssignments.isEmpty) { | ||
TableOutputResolver.checkNullability(colExpr, col, conf, colPath) | ||
} else if (exactAssignments.nonEmpty) { | ||
val value = exactAssignments.head.value | ||
TableOutputResolver.resolveUpdate(value, col, conf, addError, colPath) | ||
} else { | ||
applyFieldAssignments(col, colExpr, fieldAssignments, addError, colPath) | ||
} | ||
} | ||
|
||
private def applyFieldAssignments( | ||
col: Attribute, | ||
colExpr: Expression, | ||
assignments: Seq[Assignment], | ||
addError: String => Unit, | ||
colPath: Seq[String]): Expression = { | ||
|
||
col.dataType match { | ||
case structType: StructType => | ||
val fieldAttrs = structType.toAttributes | ||
val fieldExprs = structType.fields.zipWithIndex.map { case (field, ordinal) => | ||
GetStructField(colExpr, ordinal, Some(field.name)) | ||
} | ||
val updatedFieldExprs = fieldAttrs.zip(fieldExprs).map { case (fieldAttr, fieldExpr) => | ||
applyAssignments(fieldAttr, fieldExpr, assignments, addError, colPath :+ fieldAttr.name) | ||
} | ||
toNamedStruct(structType, updatedFieldExprs) | ||
|
||
case otherType => | ||
addError( | ||
"Updating nested fields is only supported for StructType but " + | ||
s"'${colPath.quoted}' is of type $otherType") | ||
colExpr | ||
} | ||
} | ||
|
||
private def toNamedStruct(structType: StructType, fieldExprs: Seq[Expression]): Expression = { | ||
val namedStructExprs = structType.fields.zip(fieldExprs).flatMap { case (field, expr) => | ||
Seq(Literal(field.name), expr) | ||
} | ||
CreateNamedStruct(namedStructExprs) | ||
} | ||
|
||
/** | ||
* Checks whether assignments are aligned and compatible with table columns. | ||
* | ||
* @param attrs table attributes | ||
* @param assignments assignments to check | ||
* @return true if the assignments are aligned | ||
*/ | ||
def aligned(attrs: Seq[Attribute], assignments: Seq[Assignment]): Boolean = { | ||
if (attrs.size != assignments.size) { | ||
return false | ||
} | ||
|
||
attrs.zip(assignments).forall { case (attr, assignment) => | ||
val attrType = CharVarcharUtils.getRawType(attr.metadata).getOrElse(attr.dataType) | ||
val isMatchingAssignment = assignment.key match { | ||
case key: Attribute if conf.resolver(key.name, attr.name) => true | ||
case _ => false | ||
} | ||
isMatchingAssignment && | ||
DataType.equalsIgnoreCompatibleNullability(assignment.value.dataType, attrType) && | ||
(attr.nullable || !assignment.value.nullable) | ||
} | ||
} | ||
} |
103 changes: 103 additions & 0 deletions
103
...main/scala/org/apache/spark/sql/catalyst/analysis/ResolveRowLevelCommandAssignments.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,103 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package org.apache.spark.sql.catalyst.analysis | ||
|
||
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Cast} | ||
import org.apache.spark.sql.catalyst.expressions.objects.AssertNotNull | ||
import org.apache.spark.sql.catalyst.plans.logical.{Assignment, LogicalPlan, MergeIntoTable, UpdateTable} | ||
import org.apache.spark.sql.catalyst.rules.Rule | ||
import org.apache.spark.sql.catalyst.trees.TreePattern.COMMAND | ||
import org.apache.spark.sql.catalyst.util.CharVarcharUtils | ||
import org.apache.spark.sql.connector.catalog.SupportsRowLevelOperations | ||
import org.apache.spark.sql.errors.QueryCompilationErrors | ||
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation | ||
import org.apache.spark.sql.internal.SQLConf.StoreAssignmentPolicy | ||
|
||
/** | ||
* A rule that resolves assignments in row-level commands. | ||
* | ||
* Note that this rule must be run before rewriting row-level commands into executable plans. | ||
* This rule does not apply to tables that accept any schema. Such tables must inject their own | ||
* rules to resolve assignments. | ||
*/ | ||
object ResolveRowLevelCommandAssignments extends Rule[LogicalPlan] { | ||
|
||
override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsWithPruning( | ||
_.containsPattern(COMMAND), ruleId) { | ||
case u: UpdateTable if !u.skipSchemaResolution && u.resolved && | ||
supportsRowLevelOperations(u.table) && !u.aligned => | ||
validateStoreAssignmentPolicy() | ||
val newTable = u.table.transform { | ||
case r: DataSourceV2Relation => | ||
r.copy(output = r.output.map(CharVarcharUtils.cleanAttrMetadata)) | ||
} | ||
val newAssignments = AssignmentUtils.alignAssignments(u.table.output, u.assignments) | ||
u.copy(table = newTable, assignments = newAssignments) | ||
|
||
case u: UpdateTable if !u.skipSchemaResolution && u.resolved && !u.aligned => | ||
resolveAssignments(u) | ||
|
||
case m: MergeIntoTable if !m.skipSchemaResolution && m.resolved => | ||
resolveAssignments(m) | ||
} | ||
|
||
private def validateStoreAssignmentPolicy(): Unit = { | ||
// SPARK-28730: LEGACY store assignment policy is disallowed in data source v2 | ||
if (conf.storeAssignmentPolicy == StoreAssignmentPolicy.LEGACY) { | ||
throw QueryCompilationErrors.legacyStoreAssignmentPolicyError() | ||
} | ||
} | ||
|
||
private def supportsRowLevelOperations(table: LogicalPlan): Boolean = { | ||
EliminateSubqueryAliases(table) match { | ||
case DataSourceV2Relation(_: SupportsRowLevelOperations, _, _, _, _) => true | ||
case _ => false | ||
} | ||
} | ||
|
||
private def resolveAssignments(p: LogicalPlan): LogicalPlan = { | ||
p.transformExpressions { | ||
case assignment: Assignment => | ||
val nullHandled = if (!assignment.key.nullable && assignment.value.nullable) { | ||
AssertNotNull(assignment.value) | ||
} else { | ||
assignment.value | ||
} | ||
val casted = if (assignment.key.dataType != nullHandled.dataType) { | ||
val cast = Cast(nullHandled, assignment.key.dataType, ansiEnabled = true) | ||
cast.setTagValue(Cast.BY_TABLE_INSERTION, ()) | ||
cast | ||
} else { | ||
nullHandled | ||
} | ||
val rawKeyType = assignment.key.transform { | ||
case a: AttributeReference => | ||
CharVarcharUtils.getRawType(a.metadata).map(a.withDataType).getOrElse(a) | ||
}.dataType | ||
val finalValue = if (CharVarcharUtils.hasCharVarchar(rawKeyType)) { | ||
CharVarcharUtils.stringLengthCheck(casted, rawKeyType) | ||
} else { | ||
casted | ||
} | ||
val cleanedKey = assignment.key.transform { | ||
case a: AttributeReference => CharVarcharUtils.cleanAttrMetadata(a) | ||
} | ||
Assignment(cleanedKey, finalValue) | ||
} | ||
} | ||
} |
Oops, something went wrong.