[SPARK-37829][SQL] Dataframe.joinWith outer-join should return a null value for unmatched row

### What changes were proposed in this pull request?

When doing an outer join with `joinWith` on DataFrames, unmatched rows return Row objects with null fields instead of a single null value. This is not the expected behavior, and it is a regression introduced in [this commit](apache@cd92f25).
This pull request aims to fix the regression. Note that it is not a full rollback of that commit: it does not add back the "schema" variable.

```
case class ClassData(a: String, b: Int)
val left = Seq(ClassData("a", 1), ClassData("b", 2)).toDF
val right = Seq(ClassData("x", 2), ClassData("y", 3)).toDF

left.joinWith(right, left("b") === right("b"), "left_outer").collect
```

```
Wrong results (current behavior):    Array(([a,1],[null,null]), ([b,2],[x,2]))
Correct results:                     Array(([a,1],null), ([b,2],[x,2]))
```
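
For reference, the typed Dataset form of the same query (covered by the existing "SPARK-15441: Dataset outer join" test) already yields a single `null` for the unmatched side; the snippet below is an illustrative comparison only, not part of this patch:

```
val leftDs = Seq(ClassData("a", 1), ClassData("b", 2)).toDS
val rightDs = Seq(ClassData("x", 2), ClassData("y", 3)).toDS

leftDs.joinWith(rightDs, leftDs("b") === rightDs("b"), "left_outer").collect
// Array((ClassData(a,1),null), (ClassData(b,2),ClassData(x,2)))
```

This fix makes the untyped DataFrame path consistent with that behavior.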

### Why are the changes needed?

We need to address the regression described above: it changes the behavior of the DataFrame `joinWith` API between versions 2.4.8 and 3.0.0+, and could cause data correctness issues for users who expect the old behavior when using Spark 3.0.0+.

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

Added a unit test (the same test as in the previously [closed pull request](apache#35140); credit to Clément de Groc).
Ran the sql-core and sql-catalyst submodules locally with `./build/mvn clean package -pl sql/core,sql/catalyst`.

Closes apache#40755 from kings129/encoder_bug_fix.

Authored-by: --global <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
kings129 authored and cloud-fan committed Apr 19, 2023
1 parent cac6f58 commit 74ce620
Showing 2 changed files with 58 additions and 6 deletions.
13 changes: 13 additions & 6 deletions sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala

@@ -97,22 +97,29 @@ object ExpressionEncoder {
     }
     val newSerializer = CreateStruct(serializers)
 
+    def nullSafe(input: Expression, result: Expression): Expression = {
+      If(IsNull(input), Literal.create(null, result.dataType), result)
+    }
+
     val newDeserializerInput = GetColumnByOrdinal(0, newSerializer.dataType)
-    val deserializers = encoders.zipWithIndex.map { case (enc, index) =>
+    val childrenDeserializers = encoders.zipWithIndex.map { case (enc, index) =>
       val getColExprs = enc.objDeserializer.collect { case c: GetColumnByOrdinal => c }.distinct
       assert(getColExprs.size == 1, "object deserializer should have only one " +
         s"`GetColumnByOrdinal`, but there are ${getColExprs.size}")
 
       val input = GetStructField(newDeserializerInput, index)
-      enc.objDeserializer.transformUp {
+      val childDeserializer = enc.objDeserializer.transformUp {
         case GetColumnByOrdinal(0, _) => input
       }
-    }
-    val newDeserializer = NewInstance(cls, deserializers, ObjectType(cls), propagateNull = false)
 
-    def nullSafe(input: Expression, result: Expression): Expression = {
-      If(IsNull(input), Literal.create(null, result.dataType), result)
+      if (enc.objSerializer.nullable) {
+        nullSafe(input, childDeserializer)
+      } else {
+        childDeserializer
+      }
     }
+    val newDeserializer =
+      NewInstance(cls, childrenDeserializers, ObjectType(cls), propagateNull = false)
 
     new ExpressionEncoder[Any](
       nullSafe(newSerializerInput, newSerializer),
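
The core of the change is that each per-child deserializer is now wrapped in `nullSafe` whenever the corresponding child serializer is nullable, so a null struct coming from an unmatched outer-join side decodes to `null` rather than to an object assembled from all-null fields. Below is a minimal, self-contained Scala sketch of that wrapping idea using plain functions; the names `decodeClassData`, `nullSafeDecode`, and `decodePair` are illustrative stand-ins, not Catalyst code:

```
// Illustrative sketch only: models the nullSafe wrapping outside Catalyst.
// ClassData mirrors the case class from the PR description.
case class ClassData(a: String, b: Int)

object NullSafeSketch {
  // Stand-in for a child deserializer built from a struct's fields.
  private def decodeClassData(fields: Array[Any]): ClassData =
    ClassData(fields(0).asInstanceOf[String], fields(1).asInstanceOf[Int])

  // The nullSafe idea: if the whole struct input is null (unmatched
  // outer-join side), return null instead of running the child
  // deserializer, which would otherwise yield an object of null fields.
  private def nullSafeDecode(input: Array[Any]): ClassData =
    if (input == null) null else decodeClassData(input)

  // Stand-in for the tuple encoder assembling the pair from its children.
  def decodePair(left: Array[Any], right: Array[Any]): (ClassData, ClassData) =
    (nullSafeDecode(left), nullSafeDecode(right))

  def main(args: Array[String]): Unit = {
    println(decodePair(Array("a", 1), null))           // (ClassData(a,1),null)
    println(decodePair(Array("b", 2), Array("x", 2)))  // (ClassData(b,2),ClassData(x,2))
  }
}
```

In the real encoder the wrapping is applied only when `enc.objSerializer.nullable` is true, so non-nullable children keep their original deserializer.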
45 changes: 45 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -33,6 +33,7 @@ import org.apache.spark.internal.config.MAX_RESULT_SIZE
 import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}
 import org.apache.spark.sql.catalyst.{FooClassWithEnum, FooEnum, ScroogeLikeExample}
 import org.apache.spark.sql.catalyst.encoders.{OuterScopes, RowEncoder}
+import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
 import org.apache.spark.sql.catalyst.plans.{LeftAnti, LeftSemi}
 import org.apache.spark.sql.catalyst.util.sideBySide
 import org.apache.spark.sql.execution.{LogicalRDD, RDDScanExec, SQLExecution}
@@ -2429,6 +2430,50 @@ class DatasetSuite extends QueryTest
       assert(parquetFiles.size === 10)
     }
   }
+
+  test("SPARK-37829: DataFrame outer join") {
+    // Same as "SPARK-15441: Dataset outer join" but using DataFrames instead of Datasets
+    val left = Seq(ClassData("a", 1), ClassData("b", 2)).toDF().as("left")
+    val right = Seq(ClassData("x", 2), ClassData("y", 3)).toDF().as("right")
+    val joined = left.joinWith(right, $"left.b" === $"right.b", "left")
+
+    val leftFieldSchema = StructType(
+      Seq(
+        StructField("a", StringType),
+        StructField("b", IntegerType, nullable = false)
+      )
+    )
+    val rightFieldSchema = StructType(
+      Seq(
+        StructField("a", StringType),
+        StructField("b", IntegerType, nullable = false)
+      )
+    )
+    val expectedSchema = StructType(
+      Seq(
+        StructField(
+          "_1",
+          leftFieldSchema,
+          nullable = false
+        ),
+        // This is a left join, so the right output is nullable:
+        StructField(
+          "_2",
+          rightFieldSchema
+        )
+      )
+    )
+    assert(joined.schema === expectedSchema)
+
+    val result = joined.collect().toSet
+    val expected = Set(
+      new GenericRowWithSchema(Array("a", 1), leftFieldSchema) ->
+        null,
+      new GenericRowWithSchema(Array("b", 2), leftFieldSchema) ->
+        new GenericRowWithSchema(Array("x", 2), rightFieldSchema)
+    )
+    assert(result == expected)
+  }
 }
 
 class DatasetLargeResultCollectingSuite extends QueryTest
