@@ -5,7 +5,7 @@ import com.datastax.spark.connector.util.Logging
55import org .apache .spark .sql .{SparkSession , Strategy }
66import org .apache .spark .sql .cassandra .{AlwaysOff , AlwaysOn , Automatic , CassandraSourceRelation }
77import org .apache .spark .sql .cassandra .CassandraSourceRelation ._
8- import org .apache .spark .sql .catalyst .expressions .{Alias , AttributeReference , ExprId , Expression , NamedExpression }
8+ import org .apache .spark .sql .catalyst .expressions .{Alias , Attribute , AttributeReference , ExprId , Expression , NamedExpression }
99import org .apache .spark .sql .catalyst .planning .{ExtractEquiJoinKeys , PhysicalOperation }
1010import org .apache .spark .sql .catalyst .plans .logical ._
1111import org .apache .spark .sql .catalyst .plans ._
@@ -59,7 +59,7 @@ case class CassandraDirectJoinStrategy(spark: SparkSession) extends Strategy wit
5959 cassandraScanExec
6060 )
6161
62- val newPlan = reorderPlan(dataSourceOptimizedPlan, directJoin) :: Nil
62+ val newPlan = reorderPlan(dataSourceOptimizedPlan, directJoin, plan.output ) :: Nil
6363 val newOutput = (newPlan.head.outputSet, newPlan.head.output.map(_.name))
6464 val oldOutput = (plan.outputSet, plan.output.map(_.name))
6565 val noMissingOutput = oldOutput._1.subsetOf(newPlan.head.outputSet)
@@ -232,7 +232,10 @@ object CassandraDirectJoinStrategy extends Logging {
232232 *
233233 * This should only be called on optimized Physical Plans
234234 */
235- def reorderPlan (plan : SparkPlan , directJoin : CassandraDirectJoinExec ): SparkPlan = {
235+ def reorderPlan (
236+ plan : SparkPlan ,
237+ directJoin : CassandraDirectJoinExec ,
238+ originalOutput : Seq [Attribute ]): SparkPlan = {
236239 val reordered = plan match {
237240 // This may be the only node in the Plan
238241 case BatchScanExec (_, _ : CassandraScan , _) => directJoin
@@ -252,19 +255,25 @@ object CassandraDirectJoinStrategy extends Logging {
252255 */
253256 reordered.transform {
254257 case ProjectExec (projectList, child) =>
258+ val attrMap = directJoin.output.map {
259+ case attr => attr.exprId -> attr
260+ }.toMap
261+
255262 val aliases = projectList.collect {
256- case a @ Alias (child : AttributeReference , _) => (child.toAttribute.exprId, a)
263+ case a @ Alias (child, _) =>
264+ val newAliasChild = child.transform {
265+ case attr : Attribute => attrMap.getOrElse(attr.exprId, attr)
266+ }
267+ (a.exprId, a.withNewChildren(newAliasChild :: Nil ).asInstanceOf [Alias ])
257268 }.toMap
258269
259- val aliasedOutput = directJoin.output.map {
260- case attr if aliases.contains(attr.exprId) =>
261- val oldAlias = aliases(attr.exprId)
262- oldAlias.copy(child = attr)(oldAlias.exprId, oldAlias.qualifier,
263- oldAlias.explicitMetadata, oldAlias.nonInheritableMetadataKeys)
270+ // The original output of Join
271+ val reorderedOutput = originalOutput.map {
272+ case attr if aliases.contains(attr.exprId) => aliases(attr.exprId)
264273 case other => other
265274 }
266275
267- ProjectExec (aliasedOutput , child)
276+ ProjectExec (reorderedOutput , child)
268277 }
269278 }
270279
@@ -310,13 +319,21 @@ object CassandraDirectJoinStrategy extends Logging {
310319 case _ => false
311320 }
312321
322+ def getAlias (expr : NamedExpression ): (String , ExprId ) = expr match {
323+ case a @ Alias (child : AttributeReference , _) => child.name -> a.exprId
324+ case a @ Alias (child, _) =>
325+ val attrs = child.collect {
326+ case attr : AttributeReference => attr
327+ }
328+ assert(attrs.length == 1 )
329+ attrs(0 ).name -> attrs(0 ).exprId
330+ case attributeReference : AttributeReference => attributeReference.name -> attributeReference.exprId
331+ }
332+
313333 /**
314334 * Map Source Cassandra Column Names to ExpressionIds referring to them
315335 */
316- def aliasMap (aliases : Seq [NamedExpression ]): Map [String , ExprId ] = aliases.map {
317- case a @ Alias (child : AttributeReference , _) => child.name -> a.exprId
318- case attributeReference : AttributeReference => attributeReference.name -> attributeReference.exprId
319- }.toMap
336+ def aliasMap (aliases : Seq [NamedExpression ]): Map [String , ExprId ] = aliases.map(getAlias).toMap
320337
321338 /**
322339 * Checks whether a logical plan contains only Filters, Aliases
0 commit comments