@@ -18,4 +18,4 @@ package com.stratio.crossdata.connector.cassandra
object CassandraAttributeRole extends Enumeration {
type CassandraAttributeRole = Value
val PartitionKey, ClusteringKey, Indexed, NonIndexed, Function, Unknown = Value
}
}
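
As a side note, a minimal sketch (not part of this PR) of how these role values tend to be consumed when deciding filter pushdown; isNativelyFilterable is a hypothetical helper, but the accepted/rejected split mirrors checkNativeFilters further down:

import com.stratio.crossdata.connector.cassandra.CassandraAttributeRole._

// Hypothetical helper: Unknown and NonIndexed attributes block native execution,
// mirroring the checks in CassandraQueryProcessor.checkNativeFilters.
def isNativelyFilterable(role: CassandraAttributeRole): Boolean = role match {
  case PartitionKey | ClusteringKey | Indexed | Function => true
  case NonIndexed | Unknown => false
}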
@@ -15,7 +15,6 @@
*/
package com.stratio.crossdata.connector.cassandra


import com.datastax.driver.core.ResultSet
import com.datastax.spark.connector.CassandraRowMetadata
import com.stratio.common.utils.components.logger.impl.SparkLoggerComponent
@@ -31,21 +30,25 @@ import org.apache.spark.sql.sources.CatalystToCrossdataAdapter._
import org.apache.spark.sql.sources.{CatalystToCrossdataAdapter, Filter => SourceFilter}
import org.apache.spark.sql.{Row, sources}

object CassandraQueryProcessor extends SQLLikeQueryProcessorUtils with SQLLikeUDFQueryProcessorUtils {
object CassandraQueryProcessor
extends SQLLikeQueryProcessorUtils
with SQLLikeUDFQueryProcessorUtils {

val DefaultLimit = 10000
type ColumnName = String

case class CassandraQueryProcessorContext(udfs: Map[String, NativeUDF]) extends SQLLikeUDFQueryProcessorUtils.ContextWithUDFs
case class CassandraQueryProcessorContext(udfs: Map[String, NativeUDF])
extends SQLLikeUDFQueryProcessorUtils.ContextWithUDFs
override type ProcessingContext = CassandraQueryProcessorContext

case class CassandraPlan(basePlan: BaseLogicalPlan, limit: Option[Int]){
case class CassandraPlan(basePlan: BaseLogicalPlan, limit: Option[Int]) {
def projects: Seq[NamedExpression] = basePlan.projects
def filters: Array[SourceFilter] = basePlan.filters
def udfsMap: Map[Attribute, NativeUDF] = basePlan.udfsMap
}

def apply(cassandraRelation: CassandraXDSourceRelation, logicalPlan: LogicalPlan) = new CassandraQueryProcessor(cassandraRelation, logicalPlan)
def apply(cassandraRelation: CassandraXDSourceRelation, logicalPlan: LogicalPlan) =
new CassandraQueryProcessor(cassandraRelation, logicalPlan)

def buildNativeQuery(tableQN: String,
requiredColumns: Seq[String],
@@ -57,17 +60,27 @@ object CassandraQueryProcessor extends SQLLikeQueryProcessorUtils with SQLLikeUDFQueryProcessorUtils {

def filterToCQL(filter: SourceFilter): String = filter match {

case sources.EqualTo(attribute, value) => s"${expandAttribute(attribute)} = ${quoteString(value)}"
case sources.In(attribute, values) => s"${expandAttribute(attribute)} IN ${values.map(quoteString).mkString("(", ",", ")")}"
case sources.LessThan(attribute, value) => s"${expandAttribute(attribute)} < ${quoteString(value)}"
case sources.GreaterThan(attribute, value) => s"${expandAttribute(attribute)} > ${quoteString(value)}"
case sources.LessThanOrEqual(attribute, value) => s"${expandAttribute(attribute)} <= ${quoteString(value)}"
case sources.GreaterThanOrEqual(attribute, value) => s"${expandAttribute(attribute)} >= ${quoteString(value)}"
case sources.And(leftFilter, rightFilter) => s"${filterToCQL(leftFilter)} AND ${filterToCQL(rightFilter)}"
case sources.EqualTo(attribute, value) =>
s"${expandAttribute(attribute)} = ${quoteString(value)}"
case sources.In(attribute, values) =>
s"${expandAttribute(attribute)} IN ${values.map(quoteString).mkString("(", ",", ")")}"
case sources.LessThan(attribute, value) =>
s"${expandAttribute(attribute)} < ${quoteString(value)}"
case sources.GreaterThan(attribute, value) =>
s"${expandAttribute(attribute)} > ${quoteString(value)}"
case sources.LessThanOrEqual(attribute, value) =>
s"${expandAttribute(attribute)} <= ${quoteString(value)}"
case sources.GreaterThanOrEqual(attribute, value) =>
s"${expandAttribute(attribute)} >= ${quoteString(value)}"
case sources.And(leftFilter, rightFilter) =>
s"${filterToCQL(leftFilter)} AND ${filterToCQL(rightFilter)}"

}

val filter = if (filters.nonEmpty) filters.map(filterToCQL).mkString("WHERE ", " AND ", "") else ""
val filter =
if (filters.nonEmpty)
filters.map(filterToCQL).mkString("WHERE ", " AND ", "")
else ""
val columns = requiredColumns.map(expandAttribute).mkString(", ")

s"SELECT $columns FROM $tableQN $filter LIMIT $limit ALLOW FILTERING"
@@ -76,21 +89,28 @@ object CassandraQueryProcessor extends SQLLikeQueryProcessorUtils with SQLLikeUDFQueryProcessorUtils {
}

// TODO logs, doc, tests
class CassandraQueryProcessor(cassandraRelation: CassandraXDSourceRelation, logicalPlan: LogicalPlan) extends SparkLoggerComponent {
class CassandraQueryProcessor(cassandraRelation: CassandraXDSourceRelation,
logicalPlan: LogicalPlan)
extends SparkLoggerComponent {

import CassandraQueryProcessor._

def execute(): Option[Array[Row]] = {
def annotateRepeatedNames(names: Seq[String]): Seq[String] = {
val indexedNames = names zipWithIndex
val name2pos = indexedNames.groupBy(_._1).values.flatMap(_.zipWithIndex.map(x => x._1._2 -> x._2)).toMap
indexedNames map { case (name, index) => val c = name2pos(index); if (c > 0) s"$name$c" else name }
val name2pos =
indexedNames.groupBy(_._1).values.flatMap(_.zipWithIndex.map(x => x._1._2 -> x._2)).toMap
indexedNames map {
case (name, index) =>
val c = name2pos(index); if (c > 0) s"$name$c" else name
}
}
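
The de-duplication above keeps the first occurrence of each name and suffixes later repeats with their occurrence index, for example:

// annotateRepeatedNames(Seq("id", "name", "id", "id"))
//   == Seq("id", "name", "id1", "id2")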

def buildAggregationExpression(names: Expression): String = {
names match {
case Alias(child, _) => buildAggregationExpression(child)
case Count(children) => s"count(${children.map(buildAggregationExpression).mkString(",")})"
case Count(children) =>
s"count(${children.map(buildAggregationExpression).mkString(",")})"
case Literal(1, _) => "*"
}
}
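
So a COUNT(1) aggregate unwraps through its alias down to CQL's count(*); Catalyst constructor details elided for brevity:

// buildAggregationExpression(Alias(Count(Seq(Literal(1))), "cnt")(...))
//   == "count(*)"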
@@ -110,28 +130,32 @@ class CassandraQueryProcessor(cassandraRelation: CassandraXDSourceRelation, logicalPlan: LogicalPlan) extends SparkLoggerComponent {
}

val cqlQuery = buildNativeQuery(
cassandraRelation.tableDef.name,
projectsString,
cassandraPlan.filters,
cassandraPlan.limit.getOrElse(CassandraQueryProcessor.DefaultLimit),
cassandraPlan.udfsMap map { case (k, v) => k.toString -> v }
cassandraRelation.tableDef.name,
projectsString,
cassandraPlan.filters,
cassandraPlan.limit.getOrElse(CassandraQueryProcessor.DefaultLimit),
cassandraPlan.udfsMap map { case (k, v) => k.toString -> v }
)
val resultSet = cassandraRelation.connector.withSessionDo { session =>
session.execute(cqlQuery)
}
sparkResultFromCassandra(annotateRepeatedNames(cassandraPlan.projects.map(_.name)).toArray, resultSet)
sparkResultFromCassandra(
annotateRepeatedNames(cassandraPlan.projects.map(_.name)).toArray,
resultSet)
}

}
} catch {
case exc: Exception => log.warn(s"Exception executing the native query $logicalPlan", exc.getMessage); None
case exc: Exception =>
log.warn(s"Exception executing the native query $logicalPlan", exc.getMessage); None
}

}


def validatedNativePlan: Option[CassandraPlan] = {
lazy val limit: Option[Int] = logicalPlan.collectFirst { case Limit(Literal(num: Int, _), _) => num }
lazy val limit: Option[Int] = logicalPlan.collectFirst {
case Limit(Literal(num: Int, _), _) => num
}

def findBasePlan(lplan: LogicalPlan): Option[BaseLogicalPlan] = {
lplan match {
@@ -143,15 +167,20 @@ class CassandraQueryProcessor(cassandraRelation: CassandraXDSourceRelation, logicalPlan: LogicalPlan) extends SparkLoggerComponent {
findBasePlan(child)

case ExtendedPhysicalOperation(projectList, filterList, _) =>
CatalystToCrossdataAdapter.getConnectorLogicalPlan(logicalPlan, projectList, filterList) match {
case (_, ProjectReport(exprIgnored), FilterReport(filtersIgnored, _)) if filtersIgnored.nonEmpty || exprIgnored.nonEmpty =>
CatalystToCrossdataAdapter
.getConnectorLogicalPlan(logicalPlan, projectList, filterList) match {
case (_, ProjectReport(exprIgnored), FilterReport(filtersIgnored, _))
if filtersIgnored.nonEmpty || exprIgnored.nonEmpty =>
None
case (basePlan, _, _) =>
Some(basePlan)
}
}
}
findBasePlan(logicalPlan).collect{ case bp if checkNativeFilters(bp.filters, bp.udfsMap) => CassandraPlan(bp, limit)}
findBasePlan(logicalPlan).collect {
case bp if checkNativeFilters(bp.filters, bp.udfsMap) =>
CassandraPlan(bp, limit)
}
}

private[this] def checkNativeFilters(filters: Array[SourceFilter],
@@ -163,16 +192,20 @@ class CassandraQueryProcessor(cassandraRelation: CassandraXDSourceRelation, logicalPlan: LogicalPlan) extends SparkLoggerComponent {
case sources.EqualTo(attribute, _) => attributeRole(attribute, udfNames)
case sources.In(attribute, _) => attributeRole(attribute, udfNames)
case sources.LessThan(attribute, _) => attributeRole(attribute, udfNames)
case sources.GreaterThan(attribute, _) => attributeRole(attribute, udfNames)
case sources.LessThanOrEqual(attribute, _) => attributeRole(attribute, udfNames)
case sources.GreaterThanOrEqual(attribute, _) => attributeRole(attribute, udfNames)
case sources.GreaterThan(attribute, _) =>
attributeRole(attribute, udfNames)
case sources.LessThanOrEqual(attribute, _) =>
attributeRole(attribute, udfNames)
case sources.GreaterThanOrEqual(attribute, _) =>
attributeRole(attribute, udfNames)
case _ => Unknown
}

def checksClusteringKeyFilters: Boolean =
!groupedFilters.contains(ClusteringKey) || {
// if there is a CK filter then all CKs should be included. Accept any kind of filter
val clusteringColsInFilter = groupedFilters.get(ClusteringKey).get.flatMap(columnNameFromFilter)
val clusteringColsInFilter =
groupedFilters.get(ClusteringKey).get.flatMap(columnNameFromFilter)
cassandraRelation.tableDef.clusteringColumns.forall { column =>
clusteringColsInFilter.contains(column.columnName)
}
@@ -187,10 +220,10 @@ class CassandraQueryProcessor(cassandraRelation: CassandraXDSourceRelation, logicalPlan: LogicalPlan) extends SparkLoggerComponent {
}
}


def checksPartitionKeyFilters: Boolean =
!groupedFilters.contains(PartitionKey) || {
val partitionColsInFilter = groupedFilters.get(PartitionKey).get.flatMap(columnNameFromFilter)
val partitionColsInFilter =
groupedFilters.get(PartitionKey).get.flatMap(columnNameFromFilter)

// all PKs must be present
cassandraRelation.tableDef.partitionKey.forall { column =>
@@ -199,43 +232,46 @@ class CassandraQueryProcessor(cassandraRelation: CassandraXDSourceRelation, logicalPlan: LogicalPlan) extends SparkLoggerComponent {
// filters condition must be = or IN with restrictions
groupedFilters.get(PartitionKey).get.forall {
case sources.EqualTo(_, _) => true
case sources.In(colName, _) if cassandraRelation.tableDef.partitionKey.last.columnName.equals(colName) => true
case sources.In(colName, _)
if cassandraRelation.tableDef.partitionKey.last.columnName.equals(colName) =>
true
case _ => false
}
}

{
!groupedFilters.contains(Unknown) && !groupedFilters.contains(NonIndexed) &&
checksPartitionKeyFilters && checksClusteringKeyFilters && checksSecondaryIndexesFilters
checksPartitionKeyFilters && checksClusteringKeyFilters && checksSecondaryIndexesFilters
}

}
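
Taken together: a filter set runs natively only if no attribute is Unknown or NonIndexed, any partition-key restriction covers every partition-key column (equality only, except IN on the last one), any clustering-key restriction covers every clustering column, plus an analogous secondary-index check that is only partially visible in this diff. For a hypothetical table with partition key (co, id) and a non-indexed column age:

// Accepted: co = 'ES' AND id IN (1, 2)  -- all PK columns; IN only on the last
// Rejected: co = 'ES'                   -- partition key only partially restricted
// Rejected: age > 21                    -- NonIndexed attribute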

private[this] def columnNameFromFilter(sourceFilter: SourceFilter): Option[ColumnName] = sourceFilter match {
case sources.EqualTo(attribute, _) => Some(attribute)
case sources.In(attribute, _) => Some(attribute)
case sources.LessThan(attribute, _) => Some(attribute)
case sources.GreaterThan(attribute, _) => Some(attribute)
case sources.LessThanOrEqual(attribute, _) => Some(attribute)
case sources.GreaterThanOrEqual(attribute, _) => Some(attribute)
case _ => None
}
private[this] def columnNameFromFilter(sourceFilter: SourceFilter): Option[ColumnName] =
sourceFilter match {
case sources.EqualTo(attribute, _) => Some(attribute)
case sources.In(attribute, _) => Some(attribute)
case sources.LessThan(attribute, _) => Some(attribute)
case sources.GreaterThan(attribute, _) => Some(attribute)
case sources.LessThanOrEqual(attribute, _) => Some(attribute)
case sources.GreaterThanOrEqual(attribute, _) => Some(attribute)
case _ => None
}

private[this] def attributeRole(columnName: String, udfs: Set[String]): CassandraAttributeRole =
if (udfs contains columnName) Function
else cassandraRelation.tableDef.columnByName(columnName) match {
case x if x.isPartitionKeyColumn => PartitionKey
case x if x.isClusteringColumn => ClusteringKey
case x if cassandraRelation.tableDef.isIndexed(x) => Indexed
case _ => NonIndexed
}
else
cassandraRelation.tableDef.columnByName(columnName) match {
case x if x.isPartitionKeyColumn => PartitionKey
case x if x.isClusteringColumn => ClusteringKey
case x if cassandraRelation.tableDef.isIndexed(x) => Indexed
case _ => NonIndexed
}

private[this] def sparkResultFromCassandra(requiredColumns: Array[ColumnName], resultSet: ResultSet): Array[Row] = {
private[this] def sparkResultFromCassandra(requiredColumns: Array[ColumnName],
resultSet: ResultSet): Array[Row] = {
import scala.collection.JavaConversions._
val cassandraRowMetadata = CassandraRowMetadata.fromColumnNames(requiredColumns)
resultSet.all().map(CassandraSQLRow.fromJavaDriverRow(_, cassandraRowMetadata)).toArray
}

}
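
End to end, the processor acts as a best-effort native path; a sketch of the calling pattern, with cassandraRelation and logicalPlan assumed in scope:

// Some(rows) when the plan was validated and executed natively in Cassandra;
// None when the caller must fall back to regular Spark execution.
val nativeResult: Option[Array[Row]] =
  CassandraQueryProcessor(cassandraRelation, logicalPlan).execute()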

