[SPARK-43165][SQL] Move canWrite to DataTypeUtils
### What changes were proposed in this pull request?

Move canWrite to DataTypeUtils.

### Why are the changes needed?

canWrite accesses SQLConf, so we move it out of DataType to keep DataType a simpler public API.
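
Illustrative, not part of the commit: a minimal sketch of the SQLConf coupling. The same write/read pair is rejected under the STRICT store-assignment policy but accepted under ANSI; the object name and the long-to-int pair below are invented for the example, and the call mirrors the test suite change further down.

```scala
import scala.collection.mutable

import org.apache.spark.sql.catalyst.analysis
import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.internal.SQLConf.StoreAssignmentPolicy
import org.apache.spark.sql.types.{IntegerType, LongType}

object StoreAssignmentPolicyExample {
  def main(args: Array[String]): Unit = {
    val errors = new mutable.ArrayBuffer[String]()

    // long -> int is not a safe up-cast, so the STRICT policy rejects it.
    val strictOk = DataTypeUtils.canWrite(
      LongType, IntegerType, byName = true, analysis.caseSensitiveResolution,
      "col", StoreAssignmentPolicy.STRICT, errMsg => errors += errMsg)

    // ANSI store assignment allows numeric narrowing (overflow is checked at runtime).
    val ansiOk = DataTypeUtils.canWrite(
      LongType, IntegerType, byName = true, analysis.caseSensitiveResolution,
      "col", StoreAssignmentPolicy.ANSI, errMsg => errors += errMsg)

    println(s"strict = $strictOk, ansi = $ansiOk") // expected: strict = false, ansi = true
    errors.foreach(println)
  }
}
```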

### Does this PR introduce _any_ user-facing change?

NO

### How was this patch tested?

UT

Closes apache#40825 from amaliujia/catalyst_datatype_refactor_7.

Authored-by: Rui Wang <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
amaliujia authored and cloud-fan committed Apr 19, 2023
1 parent c1a02e7 commit 56f6af7
Showing 5 changed files with 147 additions and 142 deletions.
2 changes: 2 additions & 0 deletions project/MimaExcludes.scala
@@ -36,6 +36,8 @@ object MimaExcludes {

// Exclude rules for 3.5.x from 3.4.0
lazy val v35excludes = defaultExcludes ++ Seq(
// [SPARK-43165][SQL] Move canWrite to DataTypeUtils
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.types.DataType.canWrite")
)

// Default exclude rules
@@ -143,7 +143,7 @@ object TableOutputResolver {
colPath: Seq[String]): Boolean = {
conf.storeAssignmentPolicy match {
case StoreAssignmentPolicy.STRICT | StoreAssignmentPolicy.ANSI =>
DataType.canWrite(
DataTypeUtils.canWrite(
valueType, expectedType, byName, conf.resolver, colPath.quoted,
conf.storeAssignmentPolicy, addError)
case _ =>
@@ -16,8 +16,12 @@
*/
package org.apache.spark.sql.catalyst.types

import org.apache.spark.sql.catalyst.analysis.Resolver
import org.apache.spark.sql.catalyst.expressions.Cast
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructType}
import org.apache.spark.sql.internal.SQLConf.StoreAssignmentPolicy
import org.apache.spark.sql.internal.SQLConf.StoreAssignmentPolicy.{ANSI, STRICT}
import org.apache.spark.sql.types.{ArrayType, AtomicType, DataType, MapType, NullType, StructType}

object DataTypeUtils {
/**
@@ -73,5 +77,140 @@ object DataTypeUtils {
case (fromDataType, toDataType) => fromDataType == toDataType
}
}

private val SparkGeneratedName = """col\d+""".r
private def isSparkGeneratedName(name: String): Boolean = name match {
case SparkGeneratedName(_*) => true
case _ => false
}

/**
* Returns true if the write data type can be read using the read data type.
*
* The write type is compatible with the read type if:
* - Both types are arrays, the array element types are compatible, and element nullability is
* compatible (read allows nulls or write does not contain nulls).
* - Both types are maps and the map key and value types are compatible, and value nullability
* is compatible (read allows nulls or write does not contain nulls).
* - Both types are structs and have the same number of fields. The type and nullability of each
* field from read/write is compatible. If byName is true, the name of each field from
* read/write needs to be the same.
* - Both types are atomic and the write type can be safely cast to the read type.
*
* Extra fields in write-side structs are not allowed to avoid accidentally writing data that
* the read schema will not read, and to ensure map key equality is not changed when data is read.
*
* @param write a write-side data type to validate against the read type
* @param read a read-side data type
* @return true if data written with the write type can be read using the read type
*/
def canWrite(
write: DataType,
read: DataType,
byName: Boolean,
resolver: Resolver,
context: String,
storeAssignmentPolicy: StoreAssignmentPolicy.Value,
addError: String => Unit): Boolean = {
(write, read) match {
case (wArr: ArrayType, rArr: ArrayType) =>
// run compatibility check first to produce all error messages
val typesCompatible = canWrite(
wArr.elementType, rArr.elementType, byName, resolver, context + ".element",
storeAssignmentPolicy, addError)

if (wArr.containsNull && !rArr.containsNull) {
addError(s"Cannot write nullable elements to array of non-nulls: '$context'")
false
} else {
typesCompatible
}

case (wMap: MapType, rMap: MapType) =>
// map keys cannot include data fields not in the read schema without changing equality when
// read. map keys can be missing fields as long as they are nullable in the read schema.

// run compatibility check first to produce all error messages
val keyCompatible = canWrite(
wMap.keyType, rMap.keyType, byName, resolver, context + ".key",
storeAssignmentPolicy, addError)
val valueCompatible = canWrite(
wMap.valueType, rMap.valueType, byName, resolver, context + ".value",
storeAssignmentPolicy, addError)

if (wMap.valueContainsNull && !rMap.valueContainsNull) {
addError(s"Cannot write nullable values to map of non-nulls: '$context'")
false
} else {
keyCompatible && valueCompatible
}

case (StructType(writeFields), StructType(readFields)) =>
var fieldCompatible = true
readFields.zip(writeFields).zipWithIndex.foreach {
case ((rField, wField), i) =>
val nameMatch = resolver(wField.name, rField.name) || isSparkGeneratedName(wField.name)
val fieldContext = s"$context.${rField.name}"
val typesCompatible = canWrite(
wField.dataType, rField.dataType, byName, resolver, fieldContext,
storeAssignmentPolicy, addError)

if (byName && !nameMatch) {
addError(s"Struct '$context' $i-th field name does not match " +
s"(may be out of order): expected '${rField.name}', found '${wField.name}'")
fieldCompatible = false
} else if (!rField.nullable && wField.nullable) {
addError(s"Cannot write nullable values to non-null field: '$fieldContext'")
fieldCompatible = false
} else if (!typesCompatible) {
// errors are added in the recursive call to canWrite above
fieldCompatible = false
}
}

if (readFields.size > writeFields.size) {
val missingFieldsStr = readFields.takeRight(readFields.size - writeFields.size)
.map(f => s"'${f.name}'").mkString(", ")
if (missingFieldsStr.nonEmpty) {
addError(s"Struct '$context' missing fields: $missingFieldsStr")
fieldCompatible = false
}

} else if (writeFields.size > readFields.size) {
val extraFieldsStr = writeFields.takeRight(writeFields.size - readFields.size)
.map(f => s"'${f.name}'").mkString(", ")
addError(s"Cannot write extra fields to struct '$context': $extraFieldsStr")
fieldCompatible = false
}

fieldCompatible

case (w: AtomicType, r: AtomicType) if storeAssignmentPolicy == STRICT =>
if (!Cast.canUpCast(w, r)) {
addError(s"Cannot safely cast '$context': ${w.catalogString} to ${r.catalogString}")
false
} else {
true
}

case (_: NullType, _) if storeAssignmentPolicy == ANSI => true

case (w: AtomicType, r: AtomicType) if storeAssignmentPolicy == ANSI =>
if (!Cast.canANSIStoreAssign(w, r)) {
addError(s"Cannot safely cast '$context': ${w.catalogString} to ${r.catalogString}")
false
} else {
true
}

case (w, r) if DataTypeUtils.sameType(w, r) && !w.isInstanceOf[NullType] =>
true

case (w, r) =>
addError(s"Cannot write '$context': " +
s"${w.catalogString} is incompatible with ${r.catalogString}")
false
}
}
}
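
To make the compatibility rules documented in the canWrite scaladoc above concrete, here is a minimal, illustrative sketch of calling the relocated DataTypeUtils.canWrite directly. The object name, the schemas, and the "table" context string are invented for the example; the call shape mirrors the test suite change at the end of this diff.

```scala
import scala.collection.mutable

import org.apache.spark.sql.catalyst.analysis
import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.internal.SQLConf.StoreAssignmentPolicy
import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructType}

object CanWriteStructExample {
  def main(args: Array[String]): Unit = {
    // Write side: (id: int, name: string); the read side additionally requires 'ts'.
    val writeType = new StructType()
      .add("id", IntegerType, nullable = false)
      .add("name", StringType, nullable = false)
    val readType = new StructType()
      .add("id", LongType, nullable = false)   // int -> long is a safe up-cast
      .add("name", StringType, nullable = false)
      .add("ts", LongType, nullable = false)   // missing on the write side

    val errors = new mutable.ArrayBuffer[String]()
    val compatible = DataTypeUtils.canWrite(
      writeType, readType, byName = true, analysis.caseSensitiveResolution,
      "table", StoreAssignmentPolicy.STRICT, errMsg => errors += errMsg)

    // Expected: compatible == false, with one error about the missing 'ts' field.
    println(s"compatible = $compatible")
    errors.foreach(println)
  }
}
```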

139 changes: 1 addition & 138 deletions sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
@@ -30,14 +30,12 @@ import org.json4s.jackson.JsonMethods._
import org.apache.spark.SparkThrowable
import org.apache.spark.annotation.Stable
import org.apache.spark.sql.catalyst.analysis.Resolver
import org.apache.spark.sql.catalyst.expressions.{Cast, Expression}
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.catalyst.util.DataTypeJsonUtils.{DataTypeJsonDeserializer, DataTypeJsonSerializer}
import org.apache.spark.sql.catalyst.util.StringUtils.StringConcat
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.internal.SQLConf.StoreAssignmentPolicy
import org.apache.spark.sql.internal.SQLConf.StoreAssignmentPolicy.{ANSI, STRICT}
import org.apache.spark.sql.types.DayTimeIntervalType._
import org.apache.spark.sql.types.YearMonthIntervalType._
import org.apache.spark.util.Utils
@@ -402,139 +400,4 @@ object DataType {
case _ => true
}
}

private val SparkGeneratedName = """col\d+""".r
private def isSparkGeneratedName(name: String): Boolean = name match {
case SparkGeneratedName(_*) => true
case _ => false
}

/**
* Returns true if the write data type can be read using the read data type.
*
* The write type is compatible with the read type if:
* - Both types are arrays, the array element types are compatible, and element nullability is
* compatible (read allows nulls or write does not contain nulls).
* - Both types are maps and the map key and value types are compatible, and value nullability
* is compatible (read allows nulls or write does not contain nulls).
* - Both types are structs and have the same number of fields. The type and nullability of each
* field from read/write is compatible. If byName is true, the name of each field from
* read/write needs to be the same.
* - Both types are atomic and the write type can be safely cast to the read type.
*
* Extra fields in write-side structs are not allowed to avoid accidentally writing data that
* the read schema will not read, and to ensure map key equality is not changed when data is read.
*
* @param write a write-side data type to validate against the read type
* @param read a read-side data type
* @return true if data written with the write type can be read using the read type
*/
def canWrite(
write: DataType,
read: DataType,
byName: Boolean,
resolver: Resolver,
context: String,
storeAssignmentPolicy: StoreAssignmentPolicy.Value,
addError: String => Unit): Boolean = {
(write, read) match {
case (wArr: ArrayType, rArr: ArrayType) =>
// run compatibility check first to produce all error messages
val typesCompatible = canWrite(
wArr.elementType, rArr.elementType, byName, resolver, context + ".element",
storeAssignmentPolicy, addError)

if (wArr.containsNull && !rArr.containsNull) {
addError(s"Cannot write nullable elements to array of non-nulls: '$context'")
false
} else {
typesCompatible
}

case (wMap: MapType, rMap: MapType) =>
// map keys cannot include data fields not in the read schema without changing equality when
// read. map keys can be missing fields as long as they are nullable in the read schema.

// run compatibility check first to produce all error messages
val keyCompatible = canWrite(
wMap.keyType, rMap.keyType, byName, resolver, context + ".key",
storeAssignmentPolicy, addError)
val valueCompatible = canWrite(
wMap.valueType, rMap.valueType, byName, resolver, context + ".value",
storeAssignmentPolicy, addError)

if (wMap.valueContainsNull && !rMap.valueContainsNull) {
addError(s"Cannot write nullable values to map of non-nulls: '$context'")
false
} else {
keyCompatible && valueCompatible
}

case (StructType(writeFields), StructType(readFields)) =>
var fieldCompatible = true
readFields.zip(writeFields).zipWithIndex.foreach {
case ((rField, wField), i) =>
val nameMatch = resolver(wField.name, rField.name) || isSparkGeneratedName(wField.name)
val fieldContext = s"$context.${rField.name}"
val typesCompatible = canWrite(
wField.dataType, rField.dataType, byName, resolver, fieldContext,
storeAssignmentPolicy, addError)

if (byName && !nameMatch) {
addError(s"Struct '$context' $i-th field name does not match " +
s"(may be out of order): expected '${rField.name}', found '${wField.name}'")
fieldCompatible = false
} else if (!rField.nullable && wField.nullable) {
addError(s"Cannot write nullable values to non-null field: '$fieldContext'")
fieldCompatible = false
} else if (!typesCompatible) {
// errors are added in the recursive call to canWrite above
fieldCompatible = false
}
}

if (readFields.size > writeFields.size) {
val missingFieldsStr = readFields.takeRight(readFields.size - writeFields.size)
.map(f => s"'${f.name}'").mkString(", ")
if (missingFieldsStr.nonEmpty) {
addError(s"Struct '$context' missing fields: $missingFieldsStr")
fieldCompatible = false
}

} else if (writeFields.size > readFields.size) {
val extraFieldsStr = writeFields.takeRight(writeFields.size - readFields.size)
.map(f => s"'${f.name}'").mkString(", ")
addError(s"Cannot write extra fields to struct '$context': $extraFieldsStr")
fieldCompatible = false
}

fieldCompatible

case (w: AtomicType, r: AtomicType) if storeAssignmentPolicy == STRICT =>
if (!Cast.canUpCast(w, r)) {
addError(s"Cannot safely cast '$context': ${w.catalogString} to ${r.catalogString}")
false
} else {
true
}

case (_: NullType, _) if storeAssignmentPolicy == ANSI => true

case (w: AtomicType, r: AtomicType) if storeAssignmentPolicy == ANSI =>
if (!Cast.canANSIStoreAssign(w, r)) {
addError(s"Cannot safely cast '$context': ${w.catalogString} to ${r.catalogString}")
false
} else {
true
}

case (w, r) if DataTypeUtils.sameType(w, r) && !w.isInstanceOf[NullType] =>
true

case (w, r) =>
addError(s"Cannot write '$context': " +
s"${w.catalogString} is incompatible with ${r.catalogString}")
false
}
}
}
@@ -22,6 +22,7 @@ import scala.collection.mutable
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.analysis
import org.apache.spark.sql.catalyst.expressions.Cast
import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.StoreAssignmentPolicy

@@ -482,7 +483,7 @@ abstract class DataTypeWriteCompatibilityBaseSuite extends SparkFunSuite {
desc: String,
byName: Boolean = true): Unit = {
assert(
DataType.canWrite(writeType, readType, byName, analysis.caseSensitiveResolution, name,
DataTypeUtils.canWrite(writeType, readType, byName, analysis.caseSensitiveResolution, name,
storeAssignmentPolicy,
errMsg => fail(s"Should not produce errors but was called with: $errMsg")), desc)
}
@@ -508,7 +509,7 @@
(checkErrors: Seq[String] => Unit): Unit = {
val errs = new mutable.ArrayBuffer[String]()
assert(
DataType.canWrite(writeType, readType, byName, analysis.caseSensitiveResolution, name,
DataTypeUtils.canWrite(writeType, readType, byName, analysis.caseSensitiveResolution, name,
storeAssignmentPolicy, errMsg => errs += errMsg) === false, desc)
assert(errs.size === numErrs, s"Should produce $numErrs error messages")
checkErrors(errs.toSeq)
