diff --git a/core/src/main/resources/error/error-classes.json b/core/src/main/resources/error/error-classes.json
index 1edf625fdc308..11b280efad893 100644
--- a/core/src/main/resources/error/error-classes.json
+++ b/core/src/main/resources/error/error-classes.json
@@ -353,6 +353,11 @@
         "The <functionName> does not support ordering on type <dataType>."
       ]
     },
+    "INVALID_ROW_LEVEL_OPERATION_ASSIGNMENTS" : {
+      "message" : [
+        "<errors>"
+      ]
+    },
     "IN_SUBQUERY_DATA_TYPE_MISMATCH" : {
       "message" : [
         "The data type of one or more elements in the left hand side of an IN subquery is not compatible with the data type of the output of the subquery. Mismatched columns: [<mismatchedColumns>], left side: [<leftType>], right side: [<rightType>]."
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 402bd532ab90c..8b3efe09834d6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -320,6 +320,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
       ResolveRandomSeed ::
       ResolveBinaryArithmetic ::
       ResolveUnion ::
+      ResolveRowLevelCommandAssignments ::
       RewriteDeleteFromTable ::
       typeCoercionRules ++
       Seq(
@@ -3329,43 +3330,6 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
         } else {
           v2Write
         }
-
-      case u: UpdateTable if !u.skipSchemaResolution && u.resolved =>
-        resolveAssignments(u)
-
-      case m: MergeIntoTable if !m.skipSchemaResolution && m.resolved =>
-        resolveAssignments(m)
-    }
-
-    private def resolveAssignments(p: LogicalPlan): LogicalPlan = {
-      p.transformExpressions {
-        case assignment: Assignment =>
-          val nullHandled = if (!assignment.key.nullable && assignment.value.nullable) {
-            AssertNotNull(assignment.value)
-          } else {
-            assignment.value
-          }
-          val casted = if (assignment.key.dataType != nullHandled.dataType) {
-            val cast = Cast(nullHandled, assignment.key.dataType, ansiEnabled = true)
-            cast.setTagValue(Cast.BY_TABLE_INSERTION, ())
-            cast
-          } else {
-            nullHandled
-          }
-          val rawKeyType = assignment.key.transform {
-            case a: AttributeReference =>
-              CharVarcharUtils.getRawType(a.metadata).map(a.withDataType).getOrElse(a)
-          }.dataType
-          val finalValue = if (CharVarcharUtils.hasCharVarchar(rawKeyType)) {
-            CharVarcharUtils.stringLengthCheck(casted, rawKeyType)
-          } else {
-            casted
-          }
-          val cleanedKey = assignment.key.transform {
-            case a: AttributeReference => CharVarcharUtils.cleanAttrMetadata(a)
-          }
-          Assignment(cleanedKey, finalValue)
-      }
     }
   }

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AssignmentUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AssignmentUtils.scala
new file mode 100644
index 0000000000000..265909d3a7e7d
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AssignmentUtils.scala
@@ -0,0 +1,167 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.analysis
+
+import scala.collection.mutable
+
+import org.apache.spark.sql.catalyst.SQLConfHelper
+import org.apache.spark.sql.catalyst.expressions.{Attribute, CreateNamedStruct, Expression, GetStructField, Literal}
+import org.apache.spark.sql.catalyst.plans.logical.Assignment
+import org.apache.spark.sql.catalyst.util.CharVarcharUtils
+import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
+import org.apache.spark.sql.errors.QueryCompilationErrors
+import org.apache.spark.sql.types.{DataType, StructType}
+
+object AssignmentUtils extends SQLConfHelper with CastSupport {
+
+  /**
+   * Aligns assignments to match table columns.
+   * <p>
+   * This method processes and reorders given assignments so that each target column gets
+   * an expression it should be set to. If a column does not have a matching assignment,
+   * it will be set to its current value. For example, if one passes table attributes c1, c2
+   * and an assignment c2 = 1, this method will return c1 = c1, c2 = 1. This allows Spark to
+   * construct an updated version of a row.
+   * <p>
+   * This method also handles updates to nested columns. If there is an assignment to a
+   * particular nested field, this method will construct a new struct with one field updated,
+   * preserving other fields that have not been modified. For example, if one passes table
+   * attributes c1, c2 where c2 is a struct with fields n1 and n2 and an assignment c2.n2 = 1,
+   * this method will return c1 = c1, c2 = struct(c2.n1, 1).
+   *
+   * @param attrs table attributes
+   * @param assignments assignments to align
+   * @return aligned assignments that match table attributes
+   */
+  def alignAssignments(
+      attrs: Seq[Attribute],
+      assignments: Seq[Assignment]): Seq[Assignment] = {
+
+    val errors = new mutable.ArrayBuffer[String]()
+
+    val output = attrs.map { attr =>
+      applyAssignments(
+        col = restoreActualType(attr),
+        colExpr = attr,
+        assignments,
+        addError = err => errors += err,
+        colPath = Seq(attr.name))
+    }
+
+    if (errors.nonEmpty) {
+      throw QueryCompilationErrors.invalidRowLevelOperationAssignments(assignments, errors.toSeq)
+    }
+
+    attrs.zip(output).map { case (attr, expr) => Assignment(attr, expr) }
+  }
+
+  private def restoreActualType(attr: Attribute): Attribute = {
+    attr.withDataType(CharVarcharUtils.getRawType(attr.metadata).getOrElse(attr.dataType))
+  }
+
+  private def applyAssignments(
+      col: Attribute,
+      colExpr: Expression,
+      assignments: Seq[Assignment],
+      addError: String => Unit,
+      colPath: Seq[String]): Expression = {
+
+    val (exactAssignments, otherAssignments) = assignments.partition { assignment =>
+      assignment.key.semanticEquals(colExpr)
+    }
+
+    val fieldAssignments = otherAssignments.filter { assignment =>
+      assignment.key.exists(_.semanticEquals(colExpr))
+    }
+
+    if (exactAssignments.size > 1) {
+      val conflictingValuesStr = exactAssignments.map(_.value.sql).mkString(", ")
+      addError(s"Multiple assignments for '${colPath.quoted}': $conflictingValuesStr")
+      colExpr
+    } else if (exactAssignments.nonEmpty && fieldAssignments.nonEmpty) {
+      val conflictingAssignments = exactAssignments ++ fieldAssignments
+      val conflictingAssignmentsStr = conflictingAssignments.map(_.sql).mkString(", ")
+      addError(s"Conflicting assignments for '${colPath.quoted}': $conflictingAssignmentsStr")
+      colExpr
+    } else if (exactAssignments.isEmpty && fieldAssignments.isEmpty) {
+      TableOutputResolver.checkNullability(colExpr, col, conf, colPath)
+    } else if (exactAssignments.nonEmpty) {
+      val value = exactAssignments.head.value
+      TableOutputResolver.resolveUpdate(value, col, conf, addError, colPath)
+    } else {
+      applyFieldAssignments(col, colExpr, fieldAssignments, addError, colPath)
+    }
+  }
+
+  private def applyFieldAssignments(
+      col: Attribute,
+      colExpr: Expression,
+      assignments: Seq[Assignment],
+      addError: String => Unit,
+      colPath: Seq[String]): Expression = {
+
+    col.dataType match {
+      case structType: StructType =>
+        val fieldAttrs = structType.toAttributes
+        val fieldExprs = structType.fields.zipWithIndex.map { case (field, ordinal) =>
+          GetStructField(colExpr, ordinal, Some(field.name))
+        }
+        val updatedFieldExprs = fieldAttrs.zip(fieldExprs).map { case (fieldAttr, fieldExpr) =>
+          applyAssignments(fieldAttr, fieldExpr, assignments, addError, colPath :+ fieldAttr.name)
+        }
+        toNamedStruct(structType, updatedFieldExprs)
+
+      case otherType =>
+        addError(
+          "Updating nested fields is only supported for StructType but " +
+          s"'${colPath.quoted}' is of type $otherType")
+        colExpr
+    }
+  }
+
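+  /**
+   * Recreates a struct value from updated field expressions. For instance, struct fields
+   * (n1, n2) with updated expressions (e1, e2) are assembled as named_struct('n1', e1, 'n2', e2),
+   * i.e. the interleaved name/value sequence that CreateNamedStruct expects.
+   */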
+  private def toNamedStruct(structType: StructType, fieldExprs: Seq[Expression]): Expression = {
+    val namedStructExprs = structType.fields.zip(fieldExprs).flatMap { case (field, expr) =>
+      Seq(Literal(field.name), expr)
+    }
+    CreateNamedStruct(namedStructExprs)
+  }
+
+  /**
+   * Checks whether assignments are aligned and compatible with table columns.
+   *
+   * @param attrs table attributes
+   * @param assignments assignments to check
+   * @return true if the assignments are aligned
+   */
+  def aligned(attrs: Seq[Attribute], assignments: Seq[Assignment]): Boolean = {
+    if (attrs.size != assignments.size) {
+      return false
+    }
+
+    attrs.zip(assignments).forall { case (attr, assignment) =>
+      val attrType = CharVarcharUtils.getRawType(attr.metadata).getOrElse(attr.dataType)
+      val isMatchingAssignment = assignment.key match {
+        case key: Attribute if conf.resolver(key.name, attr.name) => true
+        case _ => false
+      }
+      isMatchingAssignment &&
+        DataType.equalsIgnoreCompatibleNullability(assignment.value.dataType, attrType) &&
+        (attr.nullable || !assignment.value.nullable)
+    }
+  }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveRowLevelCommandAssignments.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveRowLevelCommandAssignments.scala
new file mode 100644
index 0000000000000..596dc00b9176b
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveRowLevelCommandAssignments.scala
@@ -0,0 +1,103 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.analysis
+
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Cast}
+import org.apache.spark.sql.catalyst.expressions.objects.AssertNotNull
+import org.apache.spark.sql.catalyst.plans.logical.{Assignment, LogicalPlan, MergeIntoTable, UpdateTable}
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.catalyst.trees.TreePattern.COMMAND
+import org.apache.spark.sql.catalyst.util.CharVarcharUtils
+import org.apache.spark.sql.connector.catalog.SupportsRowLevelOperations
+import org.apache.spark.sql.errors.QueryCompilationErrors
+import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
+import org.apache.spark.sql.internal.SQLConf.StoreAssignmentPolicy
+
+/**
+ * A rule that resolves assignments in row-level commands.
+ *
+ * Note that this rule must be run before rewriting row-level commands into executable plans.
+ * This rule does not apply to tables that accept any schema. Such tables must inject their own
+ * rules to resolve assignments.
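+ *
+ * For example, for a table t with columns pk INT and s STRUCT<n1 INT, n2 INT>, the
+ * assignment in UPDATE t SET s.n1 = 1 is aligned (see AssignmentUtils.alignAssignments)
+ * to the equivalent of pk = pk, s = struct(1, s.n2) before the command is rewritten.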
+ */
+object ResolveRowLevelCommandAssignments extends Rule[LogicalPlan] {
+
+  override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsWithPruning(
+    _.containsPattern(COMMAND), ruleId) {
+    case u: UpdateTable if !u.skipSchemaResolution && u.resolved &&
+        supportsRowLevelOperations(u.table) && !u.aligned =>
+      validateStoreAssignmentPolicy()
+      val newTable = u.table.transform {
+        case r: DataSourceV2Relation =>
+          r.copy(output = r.output.map(CharVarcharUtils.cleanAttrMetadata))
+      }
+      val newAssignments = AssignmentUtils.alignAssignments(u.table.output, u.assignments)
+      u.copy(table = newTable, assignments = newAssignments)
+
+    case u: UpdateTable if !u.skipSchemaResolution && u.resolved && !u.aligned =>
+      resolveAssignments(u)
+
+    case m: MergeIntoTable if !m.skipSchemaResolution && m.resolved =>
+      resolveAssignments(m)
+  }
+
+  private def validateStoreAssignmentPolicy(): Unit = {
+    // SPARK-28730: LEGACY store assignment policy is disallowed in data source v2
+    if (conf.storeAssignmentPolicy == StoreAssignmentPolicy.LEGACY) {
+      throw QueryCompilationErrors.legacyStoreAssignmentPolicyError()
+    }
+  }
+
+  private def supportsRowLevelOperations(table: LogicalPlan): Boolean = {
+    EliminateSubqueryAliases(table) match {
+      case DataSourceV2Relation(_: SupportsRowLevelOperations, _, _, _, _) => true
+      case _ => false
+    }
+  }
+
+  private def resolveAssignments(p: LogicalPlan): LogicalPlan = {
+    p.transformExpressions {
+      case assignment: Assignment =>
+        val nullHandled = if (!assignment.key.nullable && assignment.value.nullable) {
+          AssertNotNull(assignment.value)
+        } else {
+          assignment.value
+        }
+        val casted = if (assignment.key.dataType != nullHandled.dataType) {
+          val cast = Cast(nullHandled, assignment.key.dataType, ansiEnabled = true)
+          cast.setTagValue(Cast.BY_TABLE_INSERTION, ())
+          cast
+        } else {
+          nullHandled
+        }
+        val rawKeyType = assignment.key.transform {
+          case a: AttributeReference =>
+            CharVarcharUtils.getRawType(a.metadata).map(a.withDataType).getOrElse(a)
+        }.dataType
+        val finalValue = if (CharVarcharUtils.hasCharVarchar(rawKeyType)) {
+          CharVarcharUtils.stringLengthCheck(casted, rawKeyType)
+        } else {
+          casted
+        }
+        val cleanedKey = assignment.key.transform {
+          case a: AttributeReference => CharVarcharUtils.cleanAttrMetadata(a)
+        }
+        Assignment(cleanedKey, finalValue)
+    }
+  }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala
index 45664d3657c06..6aa800e1d2fed 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala
@@ -70,6 +70,87 @@ object TableOutputResolver {
     }
   }
 
+  def resolveUpdate(
+      value: Expression,
+      col: Attribute,
+      conf: SQLConf,
+      addError: String => Unit,
+      colPath: Seq[String]): Expression = {
+
+    (value.dataType, col.dataType) match {
+      // no need to reorder inner fields or cast if types are already compatible
+      case (valueType, colType)
+          if DataType.equalsIgnoreCompatibleNullability(valueType, colType) =>
+        val canWriteExpr = canWrite(valueType, colType, byName = true, conf, addError, colPath)
+        if (canWriteExpr) checkNullability(value, col, conf, colPath) else value
+      case (valueType: StructType, colType: StructType) =>
+        val resolvedValue = resolveStructType(
+          value, valueType, col, colType,
+          byName = true, conf, addError, colPath)
+        resolvedValue.getOrElse(value)
+      case (valueType: ArrayType, colType: ArrayType) =>
+        val resolvedValue = resolveArrayType(
+          value, valueType, col, colType,
+          byName = true, conf, addError, colPath)
+        resolvedValue.getOrElse(value)
+      case (valueType: MapType, colType: MapType) =>
+        val resolvedValue = resolveMapType(
+          value, valueType, col, colType,
+          byName = true, conf, addError, colPath)
+        resolvedValue.getOrElse(value)
+      case _ =>
+        checkUpdate(value, col, conf, addError, colPath)
+    }
+  }
+
+  private def checkUpdate(
+      value: Expression,
+      attr: Attribute,
+      conf: SQLConf,
+      addError: String => Unit,
+      colPath: Seq[String]): Expression = {
+
+    val attrTypeHasCharVarchar = CharVarcharUtils.hasCharVarchar(attr.dataType)
+    val attrTypeWithoutCharVarchar = if (attrTypeHasCharVarchar) {
+      CharVarcharUtils.replaceCharVarcharWithString(attr.dataType)
+    } else {
+      attr.dataType
+    }
+
+    val canWriteValue = canWrite(
+      value.dataType, attrTypeWithoutCharVarchar,
+      byName = true, conf, addError, colPath)
+
+    if (canWriteValue) {
+      val nullCheckedValue = checkNullability(value, attr, conf, colPath)
+      val casted = cast(nullCheckedValue, attrTypeWithoutCharVarchar, conf, colPath.quoted)
+      val exprWithStrLenCheck = if (conf.charVarcharAsString || !attrTypeHasCharVarchar) {
+        casted
+      } else {
+        CharVarcharUtils.stringLengthCheck(casted, attr.dataType)
+      }
+      Alias(exprWithStrLenCheck, attr.name)(explicitMetadata = Some(attr.metadata))
+    } else {
+      value
+    }
+  }
+
+  private def canWrite(
+      valueType: DataType,
+      expectedType: DataType,
+      byName: Boolean,
+      conf: SQLConf,
+      addError: String => Unit,
+      colPath: Seq[String]): Boolean = {
+    conf.storeAssignmentPolicy match {
+      case StoreAssignmentPolicy.STRICT | StoreAssignmentPolicy.ANSI =>
+        DataType.canWrite(
+          valueType, expectedType, byName, conf.resolver, colPath.quoted,
+          conf.storeAssignmentPolicy, addError)
+      case _ =>
+        true
+    }
+  }
+
   private def reorderColumnsByName(
       inputCols: Seq[NamedExpression],
       expectedCols: Seq[Attribute],
@@ -170,7 +251,7 @@ object TableOutputResolver {
     }
   }
 
-  private def checkNullability(
+  private[sql] def checkNullability(
       input: Expression,
       expected: Attribute,
       conf: SQLConf,
@@ -190,7 +271,7 @@ object TableOutputResolver {
   }
 
   private def resolveStructType(
-      input: NamedExpression,
+      input: Expression,
       inputType: StructType,
       expected: Attribute,
       expectedType: StructType,
@@ -221,7 +302,7 @@ object TableOutputResolver {
   }
 
   private def resolveArrayType(
-      input: NamedExpression,
+      input: Expression,
       inputType: ArrayType,
       expected: Attribute,
       expectedType: ArrayType,
@@ -247,7 +328,7 @@ object TableOutputResolver {
   }
 
   private def resolveMapType(
-      input: NamedExpression,
+      input: Expression,
      inputType: MapType,
      expected: Attribute,
      expectedType: MapType,
@@ -332,7 +413,6 @@ object TableOutputResolver {
     } else {
       tableAttr.dataType
     }
-    val storeAssignmentPolicy = conf.storeAssignmentPolicy
     lazy val outputField = if (isCompatible(tableAttr, queryExpr)) {
       if (requiresNullChecks(queryExpr, tableAttr, conf)) {
         val assert = AssertNotNull(queryExpr, colPath)
@@ -354,18 +434,11 @@ object TableOutputResolver {
       Some(Alias(exprWithStrLenCheck, tableAttr.name)(explicitMetadata = Some(tableAttr.metadata)))
     }
 
-    storeAssignmentPolicy match {
-      case StoreAssignmentPolicy.LEGACY =>
-        outputField
+    val canWriteExpr = canWrite(
+      queryExpr.dataType, attrTypeWithoutCharVarchar,
+      byName, conf, addError, colPath)
 
-      case StoreAssignmentPolicy.STRICT | StoreAssignmentPolicy.ANSI =>
-        // run the type check first to ensure type errors are present
-        val canWrite = DataType.canWrite(
-          queryExpr.dataType, attrTypeWithoutCharVarchar, byName, conf.resolver, colPath.quoted,
-          storeAssignmentPolicy, addError)
-
-        if (canWrite) outputField else None
-    }
+    if (canWriteExpr) outputField else None
   }
 
   private def cast(
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
index dd41cefb724bb..91925cb0e7c01 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
@@ -257,6 +257,16 @@ object NonNullLiteral {
   }
 }
 
+/**
+ * Extractor for retrieving Boolean literals.
+ */
+object BooleanLiteral {
+  def unapply(a: Any): Option[Boolean] = a match {
+    case Literal(a: Boolean, BooleanType) => Some(a)
+    case _ => None
+  }
+}
+
 /**
  * Extractor for retrieving Float literals.
  */
@@ -287,6 +297,16 @@ object IntegerLiteral {
   }
 }
 
+/**
+ * Extractor for retrieving Long literals.
+ */
+object LongLiteral {
+  def unapply(a: Any): Option[Long] = a match {
+    case Literal(a: Long, LongType) => Some(a)
+    case _ => None
+  }
+}
+
 /**
  * Extractor for retrieving String literals.
  */
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala
index 6f5df8d48c485..68943c918b12f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.catalyst.plans.logical
 
 import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst.analysis.{AnalysisContext, EliminateSubqueryAliases, FieldName, NamedRelation, PartitionSpec, ResolvedIdentifier, UnresolvedException}
+import org.apache.spark.sql.catalyst.analysis.{AnalysisContext, AssignmentUtils, EliminateSubqueryAliases, FieldName, NamedRelation, PartitionSpec, ResolvedIdentifier, UnresolvedException}
 import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
 import org.apache.spark.sql.catalyst.catalog.FunctionResource
 import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, AttributeSet, Expression, MetadataAttribute, NamedExpression, Unevaluable, V2ExpressionUtils}
@@ -684,6 +684,9 @@ case class UpdateTable(
     table: LogicalPlan,
     assignments: Seq[Assignment],
     condition: Option[Expression]) extends UnaryCommand with SupportsSubquery {
+
+  lazy val aligned: Boolean = AssignmentUtils.aligned(table.output, assignments)
+
   override def child: LogicalPlan = table
   override protected def withNewChildInternal(newChild: LogicalPlan): UpdateTable =
     copy(table = newChild)
@@ -778,6 +781,7 @@ case class Assignment(key: Expression, value: Expression) extends Expression
   override def dataType: DataType = throw new UnresolvedException("nullable")
   override def left: Expression = key
   override def right: Expression = value
+  override def sql: String = s"${key.sql} = ${value.sql}"
   override protected def withNewChildrenInternal(
     newLeft: Expression, newRight: Expression): Assignment = copy(key = newLeft, value = newRight)
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala
index c1d06a457f59d..7fa048c5dc378 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala
@@ -91,6 +91,7 @@ object RuleIdCollection {
       "org.apache.spark.sql.catalyst.analysis.ResolveLambdaVariables" ::
       "org.apache.spark.sql.catalyst.analysis.ResolveLateralColumnAliasReference" ::
       "org.apache.spark.sql.catalyst.analysis.ResolveOrderByAll" ::
+      "org.apache.spark.sql.catalyst.analysis.ResolveRowLevelCommandAssignments" ::
       "org.apache.spark.sql.catalyst.analysis.ResolveTimeZone" ::
       "org.apache.spark.sql.catalyst.analysis.ResolveUnion" ::
       "org.apache.spark.sql.catalyst.analysis.ResolveWindowTime" ::
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
index 0ae17ab823a8a..e1a09a4f84389 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
@@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
 import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, AttributeSet, CreateMap, CreateStruct, Expression, GroupingID, NamedExpression, SpecifiedWindowFrame, WindowFrame, WindowFunction, WindowSpecDefinition}
 import org.apache.spark.sql.catalyst.expressions.aggregate.AnyValue
 import org.apache.spark.sql.catalyst.plans.JoinType
-import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoStatement, Join, LogicalPlan, SerdeInfo, Window}
+import org.apache.spark.sql.catalyst.plans.logical.{Assignment, InsertIntoStatement, Join, LogicalPlan, SerdeInfo, Window}
 import org.apache.spark.sql.catalyst.trees.{Origin, TreeNode}
 import org.apache.spark.sql.catalyst.util.{quoteIdentifier, FailFastMode, ParseMode, PermissiveMode}
 import org.apache.spark.sql.connector.catalog._
@@ -2084,6 +2084,17 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase {
         "errors" -> errors.mkString("\n- ")))
   }
 
+  def invalidRowLevelOperationAssignments(
+      assignments: Seq[Assignment],
+      errors: Seq[String]): Throwable = {
+
+    new AnalysisException(
+      errorClass = "DATATYPE_MISMATCH.INVALID_ROW_LEVEL_OPERATION_ASSIGNMENTS",
+      messageParameters = Map(
+        "sqlExpr" -> assignments.map(toSQLExpr).mkString(", "),
+        "errors" -> errors.mkString("\n- ", "\n- ", "")))
+  }
+
   def secondArgumentOfFunctionIsNotIntegerError(
       function: String, e: NumberFormatException): Throwable = {
     // The second argument of {function} function needs to be an integer
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlignUpdateAssignmentsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlignUpdateAssignmentsSuite.scala
new file mode 100644
index 0000000000000..a173106db99e9
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlignUpdateAssignmentsSuite.scala
@@ -0,0 +1,779 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.command
+
+import java.util.Collections
+
+import org.mockito.ArgumentMatchers.any
+import org.mockito.Mockito.{mock, when}
+import org.mockito.invocation.InvocationOnMock
+
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, Analyzer, FunctionRegistry, NoSuchTableException, ResolveSessionCatalog}
+import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
+import org.apache.spark.sql.catalyst.expressions.{ArrayTransform, AttributeReference, BooleanLiteral, Cast, CheckOverflowInTableInsert, CreateNamedStruct, EvalMode, GetStructField, IntegerLiteral, LambdaFunction, LongLiteral, MapFromArrays, StringLiteral}
+import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, StaticInvoke}
+import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
+import org.apache.spark.sql.catalyst.plans.logical.{Assignment, LogicalPlan, UpdateTable}
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.connector.catalog.{CatalogManager, CatalogNotFoundException, CatalogV2Util, Column, ColumnDefaultValue, Identifier, SupportsRowLevelOperations, TableCapability, TableCatalog}
+import org.apache.spark.sql.connector.expressions.{LiteralValue, Transform}
+import org.apache.spark.sql.execution.datasources.v2.V2SessionCatalog
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.internal.SQLConf.StoreAssignmentPolicy
+import org.apache.spark.sql.types.{BooleanType, IntegerType, StructType}
+
+class AlignUpdateAssignmentsSuite extends AnalysisTest {
+
+  private val primitiveTable = {
+    val t = mock(classOf[SupportsRowLevelOperations])
+    val schema = new StructType()
+      .add("i", "INT", nullable = false)
+      .add("l", "LONG")
+      .add("txt", "STRING")
+    when(t.columns()).thenReturn(CatalogV2Util.structTypeToV2Columns(schema))
+    when(t.partitioning()).thenReturn(Array.empty[Transform])
+    t
+  }
+
+  private val nestedStructTable = {
+    val t = mock(classOf[SupportsRowLevelOperations])
+    val schema = new StructType()
+      .add("i", "INT")
+      .add(
+        "s",
+        "STRUCT<n_i INT, n_s STRUCT<dn_i INT, dn_l LONG>>",
+        nullable = false)
+      .add("txt", "STRING")
+    when(t.columns()).thenReturn(CatalogV2Util.structTypeToV2Columns(schema))
+    when(t.partitioning()).thenReturn(Array.empty[Transform])
+    t
+  }
+
+  private val mapArrayTable = {
+    val t = mock(classOf[SupportsRowLevelOperations])
+    val schema = new StructType()
+      .add("i", "INT")
+      .add("a", "ARRAY<STRUCT<i1 INT, i2 INT>>")
+      .add("m", "MAP<STRING, STRING>")
+      .add("txt", "STRING")
+    when(t.columns()).thenReturn(CatalogV2Util.structTypeToV2Columns(schema))
+    when(t.partitioning()).thenReturn(Array.empty[Transform])
+    t
+  }
+
+  private val charVarcharTable = {
+    val t = mock(classOf[SupportsRowLevelOperations])
+    val schema = new StructType()
+      .add("c", "CHAR(5)")
+      .add(
+        "s",
+        "STRUCT<n_i INT, n_vc VARCHAR(5)>",
+        nullable = false)
+      .add(
+        "a",
+        "ARRAY<STRUCT<n_i INT, n_vc VARCHAR(5)>>",
+        nullable = false)
+      .add(
+        "mk",
+        "MAP<STRUCT<n_i INT, n_vc VARCHAR(5)>, STRING>",
+        nullable = false)
+      .add(
+        "mv",
+        "MAP<STRING, STRUCT<n_i INT, n_vc VARCHAR(5)>>",
+        nullable = false)
+    when(t.columns()).thenReturn(CatalogV2Util.structTypeToV2Columns(schema))
+    when(t.partitioning()).thenReturn(Array.empty[Transform])
+    t
+  }
+
+  private val acceptsAnySchemaTable = {
+    val t = mock(classOf[SupportsRowLevelOperations])
+    val schema = new StructType()
+      .add("i", "INT", nullable = false)
+      .add("l", "LONG")
+      .add("txt", "STRING")
+    when(t.columns()).thenReturn(CatalogV2Util.structTypeToV2Columns(schema))
+    when(t.partitioning()).thenReturn(Array.empty[Transform])
+    when(t.capabilities()).thenReturn(Collections.singleton(TableCapability.ACCEPT_ANY_SCHEMA))
+    t
+  }
+
+  private val defaultValuesTable = {
+    val t = mock(classOf[SupportsRowLevelOperations])
+    val iDefault = new ColumnDefaultValue("42", LiteralValue(42, IntegerType))
+    when(t.columns()).thenReturn(Array(
+      Column.create("b", BooleanType, true, null, null),
+      Column.create("i", IntegerType, true, null, iDefault, null)))
+    when(t.partitioning()).thenReturn(Array.empty[Transform])
+    t
+  }
+
+  private val v2Catalog = {
+    val newCatalog = mock(classOf[TableCatalog])
+    when(newCatalog.loadTable(any())).thenAnswer((invocation: InvocationOnMock) => {
+      val ident = invocation.getArgument[Identifier](0)
+      ident.name match {
+        case "primitive_table" => primitiveTable
+        case "nested_struct_table" => nestedStructTable
+        case "map_array_table" => mapArrayTable
+        case "char_varchar_table" => charVarcharTable
+        case "accepts_any_schema_table" => acceptsAnySchemaTable
+        case "default_values_table" => defaultValuesTable
+        case name => throw new NoSuchTableException(Seq(name))
+      }
+    })
+    when(newCatalog.name()).thenReturn("cat")
+    newCatalog
+  }
+
+  private val v1SessionCatalog =
+    new SessionCatalog(new InMemoryCatalog(), FunctionRegistry.builtin, new SQLConf())
+
+  private val v2SessionCatalog = new V2SessionCatalog(v1SessionCatalog)
+
+  private val catalogManager = {
+    val manager = mock(classOf[CatalogManager])
+    when(manager.catalog(any())).thenAnswer((invocation: InvocationOnMock) => {
+      invocation.getArgument[String](0) match {
+        case "testcat" => v2Catalog
+        case CatalogManager.SESSION_CATALOG_NAME => v2SessionCatalog
+        case name => throw new CatalogNotFoundException(s"No such catalog: $name")
+      }
+    })
+    when(manager.currentCatalog).thenReturn(v2Catalog)
+    when(manager.currentNamespace).thenReturn(Array.empty[String])
+    when(manager.v1SessionCatalog).thenReturn(v1SessionCatalog)
+    when(manager.v2SessionCatalog).thenReturn(v2SessionCatalog)
+    manager
+  }
+
+  test("align assignments (primitive types)") {
+    val sql1 = "UPDATE primitive_table AS t SET t.txt = 'new', t.i = 1"
+    parseAndAlignAssignments(sql1) match {
+      case Seq(
+          Assignment(i: AttributeReference, IntegerLiteral(1)),
+          Assignment(l: AttributeReference, lValue: AttributeReference),
+          Assignment(txt: AttributeReference, StringLiteral("new"))) =>
+
+        assert(i.name == "i")
+        assert(l.name == "l" && l == lValue)
+        assert(txt.name == "txt")
+
+      case assignments =>
+        fail(s"Unexpected assignments: $assignments")
+    }
+
+    val sql2 = "UPDATE primitive_table SET l = 10L"
+    parseAndAlignAssignments(sql2) match {
+      case Seq(
+          Assignment(i: AttributeReference, iValue: AttributeReference),
+          Assignment(l: AttributeReference, LongLiteral(10L)),
+          Assignment(txt: AttributeReference, txtValue: AttributeReference)) =>
+
+        assert(i.name == "i" && i == iValue)
+        assert(l.name == "l")
+        assert(txt.name == "txt" && txt == txtValue)
+
+      case assignments =>
+        fail(s"Unexpected assignments: $assignments")
+    }
+
+    val sql3 = "UPDATE primitive_table AS t SET t.txt = 'new', t.l = 10L, t.i = -1"
+    parseAndAlignAssignments(sql3) match {
+      case Seq(
+          Assignment(i: AttributeReference, IntegerLiteral(-1)),
+          Assignment(l: AttributeReference, LongLiteral(10L)),
+          Assignment(txt: AttributeReference, StringLiteral("new"))) =>
+
+        assert(i.name == "i")
+        assert(l.name == "l")
+        assert(txt.name == "txt")
+
+      case assignments =>
+        fail(s"Unexpected assignments: $assignments")
+    }
+  }
+
+  test("align assignments (structs)") {
+    val sql1 =
+      "UPDATE nested_struct_table " +
+      "SET s = named_struct('n_s', named_struct('dn_i', 1, 'dn_l', 100L), 'n_i', 1)"
+    parseAndAlignAssignments(sql1) match {
+      case Seq(
+          Assignment(i: AttributeReference, iValue: AttributeReference),
+          Assignment(s: AttributeReference, sValue: CreateNamedStruct),
+          Assignment(txt: AttributeReference, txtValue: AttributeReference)) =>
+
+        assert(i.name == "i" && i == iValue)
+
+        assert(s.name == "s")
+        sValue.children match {
+          case Seq(
+              StringLiteral("n_i"), GetStructField(_, _, Some("n_i")),
+              StringLiteral("n_s"), nsValue: CreateNamedStruct) =>
+
+            nsValue.children match {
+              case Seq(
+                  StringLiteral("dn_i"), GetStructField(_, _, Some("dn_i")),
+                  StringLiteral("dn_l"), GetStructField(_, _, Some("dn_l"))) =>
+                // OK
+
+              case nsValueChildren =>
+                fail(s"Unexpected children for 's.n_s': $nsValueChildren")
+            }
+
+          case sValueChildren =>
+            fail(s"Unexpected children for 's': $sValueChildren")
+        }
+
+        assert(txt.name == "txt" && txt == txtValue)
+
+      case assignments =>
+        fail(s"Unexpected assignments: $assignments")
+    }
+
+    val sql2 = "UPDATE nested_struct_table SET s.n_s = named_struct('dn_i', 1, 'dn_l', 1L)"
+    parseAndAlignAssignments(sql2) match {
+      case Seq(
+          Assignment(i: AttributeReference, iValue: AttributeReference),
+          Assignment(s: AttributeReference, sValue: CreateNamedStruct),
+          Assignment(txt: AttributeReference, txtValue: AttributeReference)) =>
+
+        assert(i.name == "i" && i == iValue)
+
+        assert(s.name == "s")
+        sValue.children match {
+          case Seq(
+              StringLiteral("n_i"), GetStructField(_, _, Some("n_i")),
+              StringLiteral("n_s"), nsValue: CreateNamedStruct) =>
+
+            nsValue.children match {
+              case Seq(
+                  StringLiteral("dn_i"), IntegerLiteral(1),
+                  StringLiteral("dn_l"), LongLiteral(1L)) =>
+                // OK
+
+              case nsValueChildren =>
+                fail(s"Unexpected children for 's.n_s': $nsValueChildren")
+            }
+
+          case sValueChildren =>
+            fail(s"Unexpected children for 's': $sValueChildren")
+        }
+
+        assert(txt.name == "txt" && txt == txtValue)
+
+      case assignments =>
+        fail(s"Unexpected assignments: $assignments")
+    }
+
+    val sql3 = "UPDATE nested_struct_table SET s.n_s = named_struct('dn_l', 1L, 'dn_i', 1)"
+    parseAndAlignAssignments(sql3) match {
+      case Seq(
+          Assignment(i: AttributeReference, iValue: AttributeReference),
+          Assignment(s: AttributeReference, sValue: CreateNamedStruct),
+          Assignment(txt: AttributeReference, txtValue: AttributeReference)) =>
+
+        assert(i.name == "i" && i == iValue)
+
+        assert(s.name == "s")
+        sValue.children match {
+          case Seq(
+              StringLiteral("n_i"), GetStructField(_, _, Some("n_i")),
+              StringLiteral("n_s"), nsValue: CreateNamedStruct) =>
+
+            nsValue.children match {
+              case Seq(
+                  StringLiteral("dn_i"), GetStructField(_, _, Some("dn_i")),
+                  StringLiteral("dn_l"), GetStructField(_, _, Some("dn_l"))) =>
+                // OK
+
+              case nsValueChildren =>
+                fail(s"Unexpected children for 's.n_s': $nsValueChildren")
+            }
+
+          case sValueChildren =>
+            fail(s"Unexpected children for 's': $sValueChildren")
+        }
+
+        assert(txt.name == "txt" && txt == txtValue)
+
+      case assignments =>
+        fail(s"Unexpected assignments: $assignments")
+    }
+
+    val sql4 = "UPDATE nested_struct_table SET s.n_i = 1"
+    parseAndAlignAssignments(sql4) match {
+      case Seq(
+          Assignment(i: AttributeReference, iValue: AttributeReference),
+          Assignment(s: AttributeReference, sValue: CreateNamedStruct),
+          Assignment(txt: AttributeReference, txtValue: AttributeReference)) =>
+
+        assert(i.name == "i" && i == iValue)
+
+        assert(s.name == "s")
+        sValue.children match {
+          case Seq(
+              StringLiteral("n_i"), IntegerLiteral(1),
+              StringLiteral("n_s"), GetStructField(_, _, Some("n_s"))) =>
+            // OK
+
+          case sValueChildren =>
+            fail(s"Unexpected children for 's': $sValueChildren")
+        }
+
+        assert(txt.name == "txt" && txt == txtValue)
+
+      case assignments =>
+        fail(s"Unexpected assignments: $assignments")
+    }
+  }
+
+  test("align assignments (char and varchar types)") {
+    val sql1 = "UPDATE char_varchar_table SET c = 'a'"
+    parseAndAlignAssignments(sql1) match {
+      case Seq(
+          Assignment(c: AttributeReference, cValue: StaticInvoke),
+          Assignment(s: AttributeReference, sValue: AttributeReference),
+          Assignment(a: AttributeReference, aValue: AttributeReference),
+          Assignment(mk: AttributeReference, mkValue: AttributeReference),
+          Assignment(mv: AttributeReference, mvValue: AttributeReference)) =>
+
+        assert(c.name == "c")
+        assert(cValue.arguments.length == 2)
+        assert(cValue.functionName == "charTypeWriteSideCheck")
+        assert(s.name == "s" && s == sValue)
+        assert(a.name == "a" && a == aValue)
+        assert(mk.name == "mk" && mk == mkValue)
+        assert(mv.name == "mv" && mv == mvValue)
+
+      case assignments =>
+        fail(s"Unexpected assignments: $assignments")
+    }
+
+    val sql2 = "UPDATE char_varchar_table SET s = named_struct('n_i', 1, 'n_vc', 'a')"
+    parseAndAlignAssignments(sql2) match {
+      case Seq(
+          Assignment(c: AttributeReference, cValue: AttributeReference),
+          Assignment(s: AttributeReference, sValue: CreateNamedStruct),
+          Assignment(a: AttributeReference, aValue: AttributeReference),
+          Assignment(mk: AttributeReference, mkValue: AttributeReference),
+          Assignment(mv: AttributeReference, mvValue: AttributeReference)) =>
+
+        assert(c.name == "c" && c == cValue)
+
+        assert(s.name == "s")
+        sValue.children match {
+          case Seq(
+              StringLiteral("n_i"), GetStructField(_, _, Some("n_i")),
+              StringLiteral("n_vc"), invoke: StaticInvoke) =>
+
+            assert(invoke.arguments.length == 2)
+            assert(invoke.functionName == "varcharTypeWriteSideCheck")
+
+          case sValueChildren =>
+            fail(s"Unexpected children for 's': $sValueChildren")
+        }
+
+        assert(a.name == "a" && a == aValue)
+        assert(mk.name == "mk" && mk == mkValue)
+        assert(mv.name == "mv" && mv == mvValue)
+
+      case assignments =>
+        fail(s"Unexpected assignments: $assignments")
+    }
+
+    val sql3 = "UPDATE char_varchar_table SET s.n_vc = 'a'"
+    parseAndAlignAssignments(sql3) match {
+      case Seq(
+          Assignment(c: AttributeReference, cValue: AttributeReference),
+          Assignment(s: AttributeReference, sValue: CreateNamedStruct),
+          Assignment(a: AttributeReference, aValue: AttributeReference),
+          Assignment(mk: AttributeReference, mkValue: AttributeReference),
+          Assignment(mv: AttributeReference, mvValue: AttributeReference)) =>
+
+        assert(c.name == "c" && c == cValue)
+
+        assert(s.name == "s")
+        sValue.children match {
+          case Seq(
+              StringLiteral("n_i"), GetStructField(_, _, Some("n_i")),
+              StringLiteral("n_vc"), invoke: StaticInvoke) =>
+
+            assert(invoke.arguments.length == 2)
+            assert(invoke.functionName == "varcharTypeWriteSideCheck")
+
+          case sValueChildren =>
+            fail(s"Unexpected children for 's': $sValueChildren")
+        }
+
+        assert(a.name == "a" && a == aValue)
+        assert(mk.name == "mk" && mk == mkValue)
+        assert(mv.name == "mv" && mv == mvValue)
+
+      case assignments =>
+        fail(s"Unexpected assignments: $assignments")
+    }
+
+    val sql4 = "UPDATE char_varchar_table SET s = named_struct('n_vc', 3, 'n_i', 1)"
+    parseAndAlignAssignments(sql4) match {
+      case Seq(
+          Assignment(c: AttributeReference, cValue: AttributeReference),
+          Assignment(s: AttributeReference, sValue: CreateNamedStruct),
+          Assignment(a: AttributeReference, aValue: AttributeReference),
+          Assignment(mk: AttributeReference, mkValue: AttributeReference),
+          Assignment(mv: AttributeReference, mvValue: AttributeReference)) =>
+
+        assert(c.name == "c" && c == cValue)
+
+        assert(s.name == "s")
+        sValue.children match {
+          case Seq(
+              StringLiteral("n_i"), GetStructField(_, _, Some("n_i")),
+              StringLiteral("n_vc"), invoke: StaticInvoke) =>
+
+            assert(invoke.arguments.length == 2)
+            assert(invoke.functionName == "varcharTypeWriteSideCheck")
+
+          case sValueChildren =>
+            fail(s"Unexpected children for 's': $sValueChildren")
+        }
+
+        assert(a.name == "a" && a == aValue)
+        assert(mk.name == "mk" && mk == mkValue)
+        assert(mv.name == "mv" && mv == mvValue)
+
+      case assignments =>
+        fail(s"Unexpected assignments: $assignments")
+    }
+
+    val sql5 = "UPDATE char_varchar_table SET a = array(named_struct('n_i', 1, 'n_vc', 3))"
+    parseAndAlignAssignments(sql5) match {
+      case Seq(
+          Assignment(c: AttributeReference, cValue: AttributeReference),
+          Assignment(s: AttributeReference, sValue: AttributeReference),
+          Assignment(a: AttributeReference, aValue: ArrayTransform),
+          Assignment(mk: AttributeReference, mkValue: AttributeReference),
+          Assignment(mv: AttributeReference, mvValue: AttributeReference)) =>
+
+        assert(c.name == "c" && c == cValue)
+        assert(s.name == "s" && s == sValue)
+
+        assert(a.name == "a")
+        val lambda = aValue.function.asInstanceOf[LambdaFunction]
+        lambda.function match {
+          case CreateNamedStruct(Seq(
+              StringLiteral("n_i"), GetStructField(_, _, Some("n_i")),
+              StringLiteral("n_vc"), invoke: StaticInvoke)) =>
+
+            assert(invoke.arguments.length == 2)
+            assert(invoke.functionName == "varcharTypeWriteSideCheck")
+
+          case func =>
+            fail(s"Unexpected lambda function: $func")
+        }
+
+        assert(mk.name == "mk" && mk == mkValue)
+        assert(mv.name == "mv" && mv == mvValue)
+
+      case assignments =>
+        fail(s"Unexpected assignments: $assignments")
+    }
+
+    val sql6 = "UPDATE char_varchar_table SET a = array(named_struct('n_vc', 3, 'n_i', 1))"
+    parseAndAlignAssignments(sql6) match {
+      case Seq(
+          Assignment(c: AttributeReference, cValue: AttributeReference),
+          Assignment(s: AttributeReference, sValue: AttributeReference),
+          Assignment(a: AttributeReference, aValue: ArrayTransform),
+          Assignment(mk: AttributeReference, mkValue: AttributeReference),
+          Assignment(mv: AttributeReference, mvValue: AttributeReference)) =>
+
+        assert(c.name == "c" && c == cValue)
+        assert(s.name == "s" && s == sValue)
+
+        assert(a.name == "a")
+        val lambda = aValue.function.asInstanceOf[LambdaFunction]
+        lambda.function match {
+          case CreateNamedStruct(Seq(
+              StringLiteral("n_i"), GetStructField(_, _, Some("n_i")),
+              StringLiteral("n_vc"), invoke: StaticInvoke)) =>
+
+            assert(invoke.arguments.length == 2)
+            assert(invoke.functionName == "varcharTypeWriteSideCheck")
+
+          case func =>
+            fail(s"Unexpected lambda function: $func")
+        }
+
+        assert(mk.name == "mk" && mk == mkValue)
+        assert(mv.name == "mv" && mv == mvValue)
+
+      case assignments =>
+        fail(s"Unexpected assignments: $assignments")
+    }
+
+    val sql7 = "UPDATE char_varchar_table SET mk = map(named_struct('n_vc', 'a', 'n_i', 1), 'v')"
+    parseAndAlignAssignments(sql7) match {
+      case Seq(
+          Assignment(c: AttributeReference, cValue: AttributeReference),
+          Assignment(s: AttributeReference, sValue: AttributeReference),
+          Assignment(a: AttributeReference, aValue: AttributeReference),
+          Assignment(mk: AttributeReference, mkValue: MapFromArrays),
+          Assignment(mv: AttributeReference, mvValue: AttributeReference)) =>
+
+        assert(c.name == "c" && c == cValue)
+        assert(s.name == "s" && s == sValue)
+        assert(a.name == "a" && a == aValue)
+
+        assert(mk.name == "mk")
+        val keyTransform = mkValue.left.asInstanceOf[ArrayTransform]
+        val keyLambda = keyTransform.function.asInstanceOf[LambdaFunction]
+        keyLambda.function match {
+          case CreateNamedStruct(Seq(
+              StringLiteral("n_i"), GetStructField(_, _, Some("n_i")),
+              StringLiteral("n_vc"), invoke: StaticInvoke)) =>
+
+            assert(invoke.arguments.length == 2)
+            assert(invoke.functionName == "varcharTypeWriteSideCheck")
+
+          case func =>
+            fail(s"Unexpected key lambda function: $func")
+        }
+
+        assert(mv.name == "mv" && mv == mvValue)
+
+      case assignments =>
+        fail(s"Unexpected assignments: $assignments")
+    }
+
+    val sql8 = "UPDATE char_varchar_table SET mv = map('v', named_struct('n_vc', 'a', 'n_i', 1))"
+    parseAndAlignAssignments(sql8) match {
+      case Seq(
+          Assignment(c: AttributeReference, cValue: AttributeReference),
+          Assignment(s: AttributeReference, sValue: AttributeReference),
+          Assignment(a: AttributeReference, aValue: AttributeReference),
+          Assignment(mk: AttributeReference, mkValue: AttributeReference),
+          Assignment(mv: AttributeReference, mvValue: MapFromArrays)) =>
+
+        assert(c.name == "c" && c == cValue)
+        assert(s.name == "s" && s == sValue)
+        assert(a.name == "a" && a == aValue)
+        assert(mk.name == "mk" && mk == mkValue)
+
+        assert(mv.name == "mv")
+        val valueTransform = mvValue.right.asInstanceOf[ArrayTransform]
+        val valueLambda = valueTransform.function.asInstanceOf[LambdaFunction]
+        valueLambda.function match {
+          case CreateNamedStruct(Seq(
+              StringLiteral("n_i"), GetStructField(_, _, Some("n_i")),
+              StringLiteral("n_vc"), invoke: StaticInvoke)) =>
+
+            assert(invoke.arguments.length == 2)
+            assert(invoke.functionName == "varcharTypeWriteSideCheck")
+
+          case func =>
+            fail(s"Unexpected key lambda function: $func")
+        }
+
+      case assignments =>
+        fail(s"Unexpected assignments: $assignments")
+    }
+  }
+
+  test("conflicting assignments") {
+    Seq(StoreAssignmentPolicy.ANSI, StoreAssignmentPolicy.STRICT).foreach { policy =>
+      withSQLConf(SQLConf.STORE_ASSIGNMENT_POLICY.key -> policy.toString) {
+        // two updates to a top-level column
+        assertAnalysisException(
+          "UPDATE primitive_table SET i = 1, l = 1L, i = -1",
+          "Multiple assignments for 'i': 1, -1")
+
+        // two updates to a nested column
+        assertAnalysisException(
+          "UPDATE nested_struct_table SET s.n_i = 1, s.n_s = null, s.n_i = -1",
+          "Multiple assignments for 's.n_i': 1, -1")
+
+        // conflicting updates to a nested struct and its fields
+        assertAnalysisException(
+          "UPDATE nested_struct_table " +
+          "SET s.n_s.dn_i = 1, s.n_s = named_struct('dn_i', 1, 'dn_l', 1L)",
+          "Conflicting assignments for 's.n_s'",
+          "cat.nested_struct_table.s.`n_s` = named_struct('dn_i', 1, 'dn_l', 1L)",
+          "cat.nested_struct_table.s.`n_s`.`dn_i` = 1")
+      }
+    }
+  }
+
+  test("updates to nested structs in arrays") {
+    Seq(StoreAssignmentPolicy.ANSI, StoreAssignmentPolicy.STRICT).foreach { policy =>
+      withSQLConf(SQLConf.STORE_ASSIGNMENT_POLICY.key -> policy.toString) {
+        assertAnalysisException(
+          "UPDATE map_array_table SET a.i1 = 1",
+          "Updating nested fields is only supported for StructType but 'a' is of type ArrayType")
+      }
+    }
+  }
+
+  test("ANSI mode assignments") {
+    withSQLConf(SQLConf.STORE_ASSIGNMENT_POLICY.key -> StoreAssignmentPolicy.ANSI.toString) {
+      val plan1 = parseAndResolve("UPDATE primitive_table SET i = NULL")
+      assertNullCheckExists(plan1, Seq("i"))
+
+      val plan2 = parseAndResolve("UPDATE nested_struct_table SET s.n_i = NULL")
+      assertNullCheckExists(plan2, Seq("s", "n_i"))
+
+      val plan3 = parseAndResolve("UPDATE nested_struct_table SET s.n_s.dn_i = NULL")
+      assertNullCheckExists(plan3, Seq("s", "n_s", "dn_i"))
+
+      val plan4 = parseAndResolve(
+        "UPDATE nested_struct_table SET s.n_s = named_struct('dn_i', NULL, 'dn_l', 1L)")
+      assertNullCheckExists(plan4, Seq("s", "n_s", "dn_i"))
+
+      assertAnalysisException(
+        "UPDATE nested_struct_table SET s.n_s = named_struct('dn_i', 1)",
+        "Cannot find data for output column 's.n_s.dn_l'")
+
+      // ANSI mode does NOT allow string to int casts
+      assertAnalysisException(
+        "UPDATE nested_struct_table SET s.n_s = named_struct('dn_i', 'string-value', 'dn_l', 1L)",
+        "Cannot safely cast")
+
+      // ANSI mode allows long to int casts
+      val validSql1 = "UPDATE primitive_table SET i = 1L, txt = 'new', l = 10L"
+      parseAndAlignAssignments(validSql1) match {
+        case Seq(
+            Assignment(
+              i: AttributeReference,
+              CheckOverflowInTableInsert(
+                Cast(LongLiteral(1L), IntegerType, _, EvalMode.ANSI), _)),
+            Assignment(l: AttributeReference, LongLiteral(10L)),
+            Assignment(txt: AttributeReference, StringLiteral("new"))) =>
+
+          assert(i.name == "i")
+          assert(l.name == "l")
+          assert(txt.name == "txt")
+
+        case assignments =>
+          fail(s"Unexpected assignments: $assignments")
+      }
+    }
+  }
+
+  test("strict mode assignments") {
+    withSQLConf(SQLConf.STORE_ASSIGNMENT_POLICY.key -> StoreAssignmentPolicy.STRICT.toString) {
+      val plan1 = parseAndResolve("UPDATE primitive_table SET i = CAST(NULL AS INT)")
+      assertNullCheckExists(plan1, Seq("i"))
+
+      val plan2 = parseAndResolve("UPDATE nested_struct_table SET s.n_i = CAST(NULL AS INT)")
+      assertNullCheckExists(plan2, Seq("s", "n_i"))
+
+      val plan3 = parseAndResolve("UPDATE nested_struct_table SET s.n_s.dn_i = CAST(NULL AS INT)")
+      assertNullCheckExists(plan3, Seq("s", "n_s", "dn_i"))
+
+      val plan4 = parseAndResolve(
+        """UPDATE nested_struct_table
+          |SET s.n_s = named_struct('dn_i', CAST (NULL AS INT), 'dn_l', 1L)""".stripMargin)
+      assertNullCheckExists(plan4, Seq("s", "n_s", "dn_i"))
+
+      assertAnalysisException(
+        "UPDATE nested_struct_table SET s.n_s = named_struct('dn_i', 1)",
+        "Cannot find data for output column 's.n_s.dn_l'")
+
+      // strict mode does NOT allow string to int casts
+      assertAnalysisException(
+        "UPDATE nested_struct_table SET s.n_s = named_struct('dn_i', 'string-value', 'dn_l', 1L)",
+        "Cannot safely cast")
+
+      // strict mode does not allow long to int casts
+      assertAnalysisException(
+        "UPDATE primitive_table SET i = 1L",
+        "Cannot safely cast")
+    }
+  }
+
+  test("legacy mode assignments") {
+    withSQLConf(SQLConf.STORE_ASSIGNMENT_POLICY.key -> StoreAssignmentPolicy.LEGACY.toString) {
+      assertAnalysisException(
+        "UPDATE nested_struct_table SET s.n_s = named_struct('dn_i', 1)",
+        "LEGACY store assignment policy is disallowed in Spark data source V2")
+    }
+  }
+
+  test("skip alignment for tables that accept any schema") {
+    val sql = "UPDATE accepts_any_schema_table SET txt = 'new', i = 1"
+    parseAndAlignAssignments(sql) match {
+      case Seq(
+          Assignment(txt: AttributeReference, StringLiteral("new")),
+          Assignment(i: AttributeReference, IntegerLiteral(1))) =>
+
+        assert(i.name == "i")
+        assert(txt.name == "txt")
+
+      case assignments =>
+        fail(s"Unexpected assignments: $assignments")
+    }
+  }
+
+  test("align assignments with default values") {
+    val sql = "UPDATE default_values_table SET i = DEFAULT, b = false"
+    parseAndAlignAssignments(sql) match {
+      case Seq(
+          Assignment(b: AttributeReference, BooleanLiteral(false)),
+          Assignment(i: AttributeReference, IntegerLiteral(42))) =>
+
+        assert(b.name == "b")
+        assert(i.name == "i")
+
+      case assignments =>
+        fail(s"Unexpected assignments: $assignments")
+    }
+  }
+
+  private def parseAndAlignAssignments(query: String): Seq[Assignment] = {
+    parseAndResolve(query) match {
+      case UpdateTable(_, assignments, _) => assignments
+      case plan => fail("Expected UpdateTable, but got:\n" + plan.treeString)
+    }
+  }
+
+  private def parseAndResolve(query: String): LogicalPlan = {
+    val analyzer = new Analyzer(catalogManager) {
+      override val extendedResolutionRules: Seq[Rule[LogicalPlan]] = Seq(
+        new ResolveSessionCatalog(catalogManager))
+    }
+    val analyzed = analyzer.execute(CatalystSqlParser.parsePlan(query))
+    analyzer.checkAnalysis(analyzed)
+    analyzed
+  }
+
+  private def assertAnalysisException(query: String, messages: String*): Unit = {
+    val exception = intercept[AnalysisException] {
+      parseAndResolve(query)
+    }
+    messages.foreach(message => assert(exception.message.contains(message)))
+  }
+
+  private def assertNullCheckExists(plan: LogicalPlan, colPath: Seq[String]): Unit = {
+    val asserts = plan.expressions.flatMap(e => e.collect {
+      case assert: AssertNotNull if assert.walkedTypePath == colPath => assert
+    })
+    assert(asserts.nonEmpty, s"Must have NOT NULL checks for col $colPath")
+  }
+}