diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/JoinCodegenSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/JoinCodegenSupport.scala index 75f0a359a793b..ae91615da0f4d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/JoinCodegenSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/JoinCodegenSupport.scala @@ -42,13 +42,15 @@ trait JoinCodegenSupport extends CodegenSupport with BaseJoinExec { buildRow: Option[String] = None): (String, String, Seq[ExprCode]) = { val buildSideRow = buildRow.getOrElse(ctx.freshName("buildRow")) val buildVars = genOneSideJoinVars(ctx, buildSideRow, buildPlan, setDefaultValue = false) + val streamVars2 = streamVars.map(_.copy()) val checkCondition = if (condition.isDefined) { val expr = condition.get - // evaluate the variables from build side that used by condition - val eval = evaluateRequiredVariables(buildPlan.output, buildVars, expr.references) + // evaluate the variables that are used by the condition + val eval = evaluateRequiredVariables(streamPlan.output ++ buildPlan.output, + streamVars2 ++ buildVars, expr.references) // filter the output via condition - ctx.currentVars = streamVars ++ buildVars + ctx.currentVars = streamVars2 ++ buildVars val ev = BindReferences.bindReference(expr, streamPlan.output ++ buildPlan.output).genCode(ctx) val skipRow = s"${ev.isNull} || !${ev.value}" diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 6dd34d41cf6c1..0e0ca54256075 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -1455,4 +1455,39 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan checkAnswer(result1, result2) } } + + def dupStreamSideColTest(hint: String, check: SparkPlan => Unit): Unit = { + val query = + s"""select /*+ ${hint}(r) */ * + |from testData2 l + |full outer join testData3 r + |on l.a = r.a + |and l.b < (r.b + 1) + |and l.b < (r.a + 1)""".stripMargin + val df = sql(query) + val plan = df.queryExecution.executedPlan + check(plan) + val expected = Row(1, 1, null, null) :: + Row(1, 2, null, null) :: + Row(null, null, 1, null) :: + Row(2, 1, 2, 2) :: + Row(2, 2, 2, 2) :: + Row(3, 1, null, null) :: + Row(3, 2, null, null) :: Nil + checkAnswer(df, expected) + } + + test("SPARK-43113: Full outer join with duplicate stream-side references in condition (SMJ)") { + def check(plan: SparkPlan): Unit = { + assert(collect(plan) { case _: SortMergeJoinExec => true }.size === 1) + } + dupStreamSideColTest("MERGE", check) + } + + test("SPARK-43113: Full outer join with duplicate stream-side references in condition (SHJ)") { + def check(plan: SparkPlan): Unit = { + assert(collect(plan) { case _: ShuffledHashJoinExec => true }.size === 1) + } + dupStreamSideColTest("SHUFFLE_HASH", check) + } }