diff --git a/cassandra/src/main/scala/com/stratio/crossdata/connector/cassandra/CassandraAttributeRole.scala b/cassandra/src/main/scala/com/stratio/crossdata/connector/cassandra/CassandraAttributeRole.scala index fb6eb8861..ed9537b3e 100644 --- a/cassandra/src/main/scala/com/stratio/crossdata/connector/cassandra/CassandraAttributeRole.scala +++ b/cassandra/src/main/scala/com/stratio/crossdata/connector/cassandra/CassandraAttributeRole.scala @@ -18,4 +18,4 @@ package com.stratio.crossdata.connector.cassandra object CassandraAttributeRole extends Enumeration { type CassandraAttributeRole = Value val PartitionKey, ClusteringKey, Indexed, NonIndexed, Function, Unknown = Value -} \ No newline at end of file +} diff --git a/cassandra/src/main/scala/com/stratio/crossdata/connector/cassandra/CassandraQueryProcessor.scala b/cassandra/src/main/scala/com/stratio/crossdata/connector/cassandra/CassandraQueryProcessor.scala index f205da096..b1e0855ac 100644 --- a/cassandra/src/main/scala/com/stratio/crossdata/connector/cassandra/CassandraQueryProcessor.scala +++ b/cassandra/src/main/scala/com/stratio/crossdata/connector/cassandra/CassandraQueryProcessor.scala @@ -15,7 +15,6 @@ */ package com.stratio.crossdata.connector.cassandra - import com.datastax.driver.core.ResultSet import com.datastax.spark.connector.CassandraRowMetadata import com.stratio.common.utils.components.logger.impl.SparkLoggerComponent @@ -31,21 +30,25 @@ import org.apache.spark.sql.sources.CatalystToCrossdataAdapter._ import org.apache.spark.sql.sources.{CatalystToCrossdataAdapter, Filter => SourceFilter} import org.apache.spark.sql.{Row, sources} -object CassandraQueryProcessor extends SQLLikeQueryProcessorUtils with SQLLikeUDFQueryProcessorUtils { +object CassandraQueryProcessor + extends SQLLikeQueryProcessorUtils + with SQLLikeUDFQueryProcessorUtils { val DefaultLimit = 10000 type ColumnName = String - case class CassandraQueryProcessorContext(udfs: Map[String, NativeUDF]) extends SQLLikeUDFQueryProcessorUtils.ContextWithUDFs + case class CassandraQueryProcessorContext(udfs: Map[String, NativeUDF]) + extends SQLLikeUDFQueryProcessorUtils.ContextWithUDFs override type ProcessingContext = CassandraQueryProcessorContext - case class CassandraPlan(basePlan: BaseLogicalPlan, limit: Option[Int]){ + case class CassandraPlan(basePlan: BaseLogicalPlan, limit: Option[Int]) { def projects: Seq[NamedExpression] = basePlan.projects def filters: Array[SourceFilter] = basePlan.filters def udfsMap: Map[Attribute, NativeUDF] = basePlan.udfsMap } - def apply(cassandraRelation: CassandraXDSourceRelation, logicalPlan: LogicalPlan) = new CassandraQueryProcessor(cassandraRelation, logicalPlan) + def apply(cassandraRelation: CassandraXDSourceRelation, logicalPlan: LogicalPlan) = + new CassandraQueryProcessor(cassandraRelation, logicalPlan) def buildNativeQuery(tableQN: String, requiredColumns: Seq[String], @@ -57,17 +60,27 @@ object CassandraQueryProcessor extends SQLLikeQueryProcessorUtils with SQLLikeUD def filterToCQL(filter: SourceFilter): String = filter match { - case sources.EqualTo(attribute, value) => s"${expandAttribute(attribute)} = ${quoteString(value)}" - case sources.In(attribute, values) => s"${expandAttribute(attribute)} IN ${values.map(quoteString).mkString("(", ",", ")")}" - case sources.LessThan(attribute, value) => s"${expandAttribute(attribute)} < ${quoteString(value)}" - case sources.GreaterThan(attribute, value) => s"${expandAttribute(attribute)} > ${quoteString(value)}" - case sources.LessThanOrEqual(attribute, value) => 
s"${expandAttribute(attribute)} <= ${quoteString(value)}" - case sources.GreaterThanOrEqual(attribute, value) => s"${expandAttribute(attribute)} >= ${quoteString(value)}" - case sources.And(leftFilter, rightFilter) => s"${filterToCQL(leftFilter)} AND ${filterToCQL(rightFilter)}" + case sources.EqualTo(attribute, value) => + s"${expandAttribute(attribute)} = ${quoteString(value)}" + case sources.In(attribute, values) => + s"${expandAttribute(attribute)} IN ${values.map(quoteString).mkString("(", ",", ")")}" + case sources.LessThan(attribute, value) => + s"${expandAttribute(attribute)} < ${quoteString(value)}" + case sources.GreaterThan(attribute, value) => + s"${expandAttribute(attribute)} > ${quoteString(value)}" + case sources.LessThanOrEqual(attribute, value) => + s"${expandAttribute(attribute)} <= ${quoteString(value)}" + case sources.GreaterThanOrEqual(attribute, value) => + s"${expandAttribute(attribute)} >= ${quoteString(value)}" + case sources.And(leftFilter, rightFilter) => + s"${filterToCQL(leftFilter)} AND ${filterToCQL(rightFilter)}" } - val filter = if (filters.nonEmpty) filters.map(filterToCQL).mkString("WHERE ", " AND ", "") else "" + val filter = + if (filters.nonEmpty) + filters.map(filterToCQL).mkString("WHERE ", " AND ", "") + else "" val columns = requiredColumns.map(expandAttribute).mkString(", ") s"SELECT $columns FROM $tableQN $filter LIMIT $limit ALLOW FILTERING" @@ -76,21 +89,28 @@ object CassandraQueryProcessor extends SQLLikeQueryProcessorUtils with SQLLikeUD } // TODO logs, doc, tests -class CassandraQueryProcessor(cassandraRelation: CassandraXDSourceRelation, logicalPlan: LogicalPlan) extends SparkLoggerComponent { +class CassandraQueryProcessor(cassandraRelation: CassandraXDSourceRelation, + logicalPlan: LogicalPlan) + extends SparkLoggerComponent { import CassandraQueryProcessor._ def execute(): Option[Array[Row]] = { def annotateRepeatedNames(names: Seq[String]): Seq[String] = { val indexedNames = names zipWithIndex - val name2pos = indexedNames.groupBy(_._1).values.flatMap(_.zipWithIndex.map(x => x._1._2 -> x._2)).toMap - indexedNames map { case (name, index) => val c = name2pos(index); if (c > 0) s"$name$c" else name } + val name2pos = + indexedNames.groupBy(_._1).values.flatMap(_.zipWithIndex.map(x => x._1._2 -> x._2)).toMap + indexedNames map { + case (name, index) => + val c = name2pos(index); if (c > 0) s"$name$c" else name + } } def buildAggregationExpression(names: Expression): String = { names match { case Alias(child, _) => buildAggregationExpression(child) - case Count(children) => s"count(${children.map(buildAggregationExpression).mkString(",")})" + case Count(children) => + s"count(${children.map(buildAggregationExpression).mkString(",")})" case Literal(1, _) => "*" } } @@ -110,28 +130,32 @@ class CassandraQueryProcessor(cassandraRelation: CassandraXDSourceRelation, logi } val cqlQuery = buildNativeQuery( - cassandraRelation.tableDef.name, - projectsString, - cassandraPlan.filters, - cassandraPlan.limit.getOrElse(CassandraQueryProcessor.DefaultLimit), - cassandraPlan.udfsMap map { case (k, v) => k.toString -> v } + cassandraRelation.tableDef.name, + projectsString, + cassandraPlan.filters, + cassandraPlan.limit.getOrElse(CassandraQueryProcessor.DefaultLimit), + cassandraPlan.udfsMap map { case (k, v) => k.toString -> v } ) val resultSet = cassandraRelation.connector.withSessionDo { session => session.execute(cqlQuery) } - sparkResultFromCassandra(annotateRepeatedNames(cassandraPlan.projects.map(_.name)).toArray, resultSet) + 
sparkResultFromCassandra( + annotateRepeatedNames(cassandraPlan.projects.map(_.name)).toArray, + resultSet) } } } catch { - case exc: Exception => log.warn(s"Exception executing the native query $logicalPlan", exc.getMessage); None + case exc: Exception => + log.warn(s"Exception executing the native query $logicalPlan", exc.getMessage); None } } - def validatedNativePlan: Option[CassandraPlan] = { - lazy val limit: Option[Int] = logicalPlan.collectFirst { case Limit(Literal(num: Int, _), _) => num } + lazy val limit: Option[Int] = logicalPlan.collectFirst { + case Limit(Literal(num: Int, _), _) => num + } def findBasePlan(lplan: LogicalPlan): Option[BaseLogicalPlan] = { lplan match { @@ -143,15 +167,20 @@ class CassandraQueryProcessor(cassandraRelation: CassandraXDSourceRelation, logi findBasePlan(child) case ExtendedPhysicalOperation(projectList, filterList, _) => - CatalystToCrossdataAdapter.getConnectorLogicalPlan(logicalPlan, projectList, filterList) match { - case (_, ProjectReport(exprIgnored), FilterReport(filtersIgnored, _)) if filtersIgnored.nonEmpty || exprIgnored.nonEmpty => + CatalystToCrossdataAdapter + .getConnectorLogicalPlan(logicalPlan, projectList, filterList) match { + case (_, ProjectReport(exprIgnored), FilterReport(filtersIgnored, _)) + if filtersIgnored.nonEmpty || exprIgnored.nonEmpty => None case (basePlan, _, _) => Some(basePlan) } } } - findBasePlan(logicalPlan).collect{ case bp if checkNativeFilters(bp.filters, bp.udfsMap) => CassandraPlan(bp, limit)} + findBasePlan(logicalPlan).collect { + case bp if checkNativeFilters(bp.filters, bp.udfsMap) => + CassandraPlan(bp, limit) + } } private[this] def checkNativeFilters(filters: Array[SourceFilter], @@ -163,16 +192,20 @@ class CassandraQueryProcessor(cassandraRelation: CassandraXDSourceRelation, logi case sources.EqualTo(attribute, _) => attributeRole(attribute, udfNames) case sources.In(attribute, _) => attributeRole(attribute, udfNames) case sources.LessThan(attribute, _) => attributeRole(attribute, udfNames) - case sources.GreaterThan(attribute, _) => attributeRole(attribute, udfNames) - case sources.LessThanOrEqual(attribute, _) => attributeRole(attribute, udfNames) - case sources.GreaterThanOrEqual(attribute, _) => attributeRole(attribute, udfNames) + case sources.GreaterThan(attribute, _) => + attributeRole(attribute, udfNames) + case sources.LessThanOrEqual(attribute, _) => + attributeRole(attribute, udfNames) + case sources.GreaterThanOrEqual(attribute, _) => + attributeRole(attribute, udfNames) case _ => Unknown } def checksClusteringKeyFilters: Boolean = !groupedFilters.contains(ClusteringKey) || { // if there is a CK filter then all CKs should be included. 
Accept any kind of filter - val clusteringColsInFilter = groupedFilters.get(ClusteringKey).get.flatMap(columnNameFromFilter) + val clusteringColsInFilter = + groupedFilters.get(ClusteringKey).get.flatMap(columnNameFromFilter) cassandraRelation.tableDef.clusteringColumns.forall { column => clusteringColsInFilter.contains(column.columnName) } @@ -187,10 +220,10 @@ class CassandraQueryProcessor(cassandraRelation: CassandraXDSourceRelation, logi } } - def checksPartitionKeyFilters: Boolean = !groupedFilters.contains(PartitionKey) || { - val partitionColsInFilter = groupedFilters.get(PartitionKey).get.flatMap(columnNameFromFilter) + val partitionColsInFilter = + groupedFilters.get(PartitionKey).get.flatMap(columnNameFromFilter) // all PKs must be present cassandraRelation.tableDef.partitionKey.forall { column => @@ -199,43 +232,46 @@ class CassandraQueryProcessor(cassandraRelation: CassandraXDSourceRelation, logi // filters condition must be = or IN with restrictions groupedFilters.get(PartitionKey).get.forall { case sources.EqualTo(_, _) => true - case sources.In(colName, _) if cassandraRelation.tableDef.partitionKey.last.columnName.equals(colName) => true + case sources.In(colName, _) + if cassandraRelation.tableDef.partitionKey.last.columnName.equals(colName) => + true case _ => false } } { !groupedFilters.contains(Unknown) && !groupedFilters.contains(NonIndexed) && - checksPartitionKeyFilters && checksClusteringKeyFilters && checksSecondaryIndexesFilters + checksPartitionKeyFilters && checksClusteringKeyFilters && checksSecondaryIndexesFilters } } - private[this] def columnNameFromFilter(sourceFilter: SourceFilter): Option[ColumnName] = sourceFilter match { - case sources.EqualTo(attribute, _) => Some(attribute) - case sources.In(attribute, _) => Some(attribute) - case sources.LessThan(attribute, _) => Some(attribute) - case sources.GreaterThan(attribute, _) => Some(attribute) - case sources.LessThanOrEqual(attribute, _) => Some(attribute) - case sources.GreaterThanOrEqual(attribute, _) => Some(attribute) - case _ => None - } + private[this] def columnNameFromFilter(sourceFilter: SourceFilter): Option[ColumnName] = + sourceFilter match { + case sources.EqualTo(attribute, _) => Some(attribute) + case sources.In(attribute, _) => Some(attribute) + case sources.LessThan(attribute, _) => Some(attribute) + case sources.GreaterThan(attribute, _) => Some(attribute) + case sources.LessThanOrEqual(attribute, _) => Some(attribute) + case sources.GreaterThanOrEqual(attribute, _) => Some(attribute) + case _ => None + } private[this] def attributeRole(columnName: String, udfs: Set[String]): CassandraAttributeRole = if (udfs contains columnName) Function - else cassandraRelation.tableDef.columnByName(columnName) match { - case x if x.isPartitionKeyColumn => PartitionKey - case x if x.isClusteringColumn => ClusteringKey - case x if cassandraRelation.tableDef.isIndexed(x) => Indexed - case _ => NonIndexed - } + else + cassandraRelation.tableDef.columnByName(columnName) match { + case x if x.isPartitionKeyColumn => PartitionKey + case x if x.isClusteringColumn => ClusteringKey + case x if cassandraRelation.tableDef.isIndexed(x) => Indexed + case _ => NonIndexed + } - private[this] def sparkResultFromCassandra(requiredColumns: Array[ColumnName], resultSet: ResultSet): Array[Row] = { + private[this] def sparkResultFromCassandra(requiredColumns: Array[ColumnName], + resultSet: ResultSet): Array[Row] = { import scala.collection.JavaConversions._ val cassandraRowMetadata = 
CassandraRowMetadata.fromColumnNames(requiredColumns) resultSet.all().map(CassandraSQLRow.fromJavaDriverRow(_, cassandraRowMetadata)).toArray } } - - diff --git a/cassandra/src/main/scala/com/stratio/crossdata/connector/cassandra/DefaultSource.scala b/cassandra/src/main/scala/com/stratio/crossdata/connector/cassandra/DefaultSource.scala index a4392d774..807d41fb3 100644 --- a/cassandra/src/main/scala/com/stratio/crossdata/connector/cassandra/DefaultSource.scala +++ b/cassandra/src/main/scala/com/stratio/crossdata/connector/cassandra/DefaultSource.scala @@ -17,7 +17,6 @@ package com.stratio.crossdata.connector.cassandra - import java.util.Collection import com.datastax.driver.core.{KeyspaceMetadata, TableMetadata} @@ -36,48 +35,51 @@ import org.apache.spark.sql.{DataFrame, SQLContext, SaveMode} import scala.util.Try - - /** - * Cassandra data source extends [[org.apache.spark.sql.sources.RelationProvider]], [[org.apache.spark.sql.sources.SchemaRelationProvider]] - * and [[org.apache.spark.sql.sources.CreatableRelationProvider]]. - * - * It's used internally by Spark SQL to create Relation for a table which specifies the Cassandra data source - * e.g. - * - * CREATE TEMPORARY TABLE tmpTable - * USING org.apache.spark.sql.cassandra - * OPTIONS ( - * table "table", - * keyspace "keyspace", - * cluster "test_cluster", - * pushdown "true", - * spark_cassandra_input_page_row_size "10", - * spark_cassandra_output_consistency_level "ONE", - * spark_cassandra_connection_timeout_ms "1000" - * ) - */ -class DefaultSource extends CassandraConnectorDS with TableInventory with FunctionInventory with DataSourceRegister with TableManipulation { + * Cassandra data source extends [[org.apache.spark.sql.sources.RelationProvider]], [[org.apache.spark.sql.sources.SchemaRelationProvider]] + * and [[org.apache.spark.sql.sources.CreatableRelationProvider]]. + * + * It's used internally by Spark SQL to create Relation for a table which specifies the Cassandra data source + * e.g. + * + * CREATE TEMPORARY TABLE tmpTable + * USING org.apache.spark.sql.cassandra + * OPTIONS ( + * table "table", + * keyspace "keyspace", + * cluster "test_cluster", + * pushdown "true", + * spark_cassandra_input_page_row_size "10", + * spark_cassandra_output_consistency_level "ONE", + * spark_cassandra_connection_timeout_ms "1000" + * ) + */ +class DefaultSource + extends CassandraConnectorDS + with TableInventory + with FunctionInventory + with DataSourceRegister + with TableManipulation { import CassandraConnectorDS._ override def shortName(): String = "cassandra" /** - * Creates a new relation for a cassandra table. - * The parameters map stores table level data. User can specify vale for following keys - * - * table -- table name, required - * keyspace -- keyspace name, required - * cluster -- cluster name, optional, default name is "default" - * pushdown -- true/false, optional, default is true - * Cassandra connection settings -- optional, e.g. spark_cassandra_connection_timeout_ms - * Cassandra Read Settings -- optional, e.g. spark_cassandra_input_page_row_size - * Cassandra Write settings -- optional, e.g. spark_cassandra_output_consistency_level - * - * When push_down is true, some filters are pushed down to CQL. - * - */ + * Creates a new relation for a cassandra table. + * The parameters map stores table level data. 
User can specify vale for following keys + * + * table -- table name, required + * keyspace -- keyspace name, required + * cluster -- cluster name, optional, default name is "default" + * pushdown -- true/false, optional, default is true + * Cassandra connection settings -- optional, e.g. spark_cassandra_connection_timeout_ms + * Cassandra Read Settings -- optional, e.g. spark_cassandra_input_page_row_size + * Cassandra Write settings -- optional, e.g. spark_cassandra_output_consistency_level + * + * When push_down is true, some filters are pushed down to CQL. + * + */ override def createRelation(sqlContext: SQLContext, parameters: Map[String, String]): BaseRelation = { @@ -86,9 +88,9 @@ class DefaultSource extends CassandraConnectorDS with TableInventory with Functi } /** - * Creates a new relation for a cassandra table given table, keyspace, cluster and push_down - * as parameters and explicitly pass schema [[StructType]] as a parameter - */ + * Creates a new relation for a cassandra table given table, keyspace, cluster and push_down + * as parameters and explicitly pass schema [[StructType]] as a parameter + */ override def createRelation(sqlContext: SQLContext, parameters: Map[String, String], schema: StructType): BaseRelation = { @@ -98,9 +100,9 @@ class DefaultSource extends CassandraConnectorDS with TableInventory with Functi } /** - * Creates a new relation for a cassandra table given table, keyspace, cluster, push_down and schema - * as parameters. It saves the data to the Cassandra table depends on [[SaveMode]] - */ + * Creates a new relation for a cassandra table given table, keyspace, cluster, push_down and schema + * as parameters. It saves the data to the Cassandra table depends on [[SaveMode]] + */ override def createRelation(sqlContext: SQLContext, mode: SaveMode, parameters: Map[String, String], @@ -116,8 +118,7 @@ class DefaultSource extends CassandraConnectorDS with TableInventory with Functi if (table.buildScan().isEmpty()) { table.insert(data, overwrite = false) } else { - throw new UnsupportedOperationException( - s"""'SaveMode is set to ErrorIfExists and Table + throw new UnsupportedOperationException(s"""'SaveMode is set to ErrorIfExists and Table |${tableRef.keyspace + "." + tableRef.table} already exists and contains data. |Perhaps you meant to set the DataFrame write mode to Append? |Example: df.write.format.options.mode(SaveMode.Append).save()" '""".stripMargin) @@ -131,14 +132,18 @@ class DefaultSource extends CassandraConnectorDS with TableInventory with Functi CassandraXDSourceRelation(tableRef, sqlContext, options) } - - override def nativeBuiltinFunctions: Seq[UDF] = { //TODO: Complete the built-in function inventory Seq( - UDF("now", None, StructType(Nil), StringType), - UDF("dateOf", None, StructType(StructField("date", StringType, false)::Nil), DataTypes.TimestampType), - UDF("unixTimestampOf", None, StructType(StructField("date", StringType, false)::Nil), DataTypes.LongType) + UDF("now", None, StructType(Nil), StringType), + UDF("dateOf", + None, + StructType(StructField("date", StringType, false) :: Nil), + DataTypes.TimestampType), + UDF("unixTimestampOf", + None, + StructType(StructField("date", StringType, false) :: Nil), + DataTypes.LongType) ) } @@ -148,8 +153,11 @@ class DefaultSource extends CassandraConnectorDS with TableInventory with Functi databaseName: Option[String], schema: StructType, options: Map[String, String]): Option[Table] = { - val keyspace: String = options.get(CassandraDataSourceKeyspaceNameProperty).orElse(databaseName). 
- getOrElse(throw new RuntimeException(s"$CassandraDataSourceKeyspaceNameProperty required when use CREATE EXTERNAL TABLE command")) + val keyspace: String = options + .get(CassandraDataSourceKeyspaceNameProperty) + .orElse(databaseName) + .getOrElse(throw new RuntimeException( + s"$CassandraDataSourceKeyspaceNameProperty required when use CREATE EXTERNAL TABLE command")) val table: String = options.getOrElse(CassandraDataSourceTableNameProperty, tableName) @@ -172,8 +180,7 @@ class DefaultSource extends CassandraConnectorDS with TableInventory with Functi } } - override def dropExternalTable(context: SQLContext, - options: Map[String, String]): Try[Unit] = { + override def dropExternalTable(context: SQLContext, options: Map[String, String]): Try[Unit] = { val keyspace: String = options.get(CassandraDataSourceKeyspaceNameProperty).get val table: String = options.get(CassandraDataSourceTableNameProperty).get @@ -187,53 +194,60 @@ class DefaultSource extends CassandraConnectorDS with TableInventory with Functi //-----------MetadataInventory----------------- - import collection.JavaConversions._ override def listTables(context: SQLContext, options: Map[String, String]): Seq[Table] = { if (options.contains(CassandraDataSourceTableNameProperty)) - require(options.contains(CassandraDataSourceKeyspaceNameProperty), s"$CassandraDataSourceKeyspaceNameProperty required when use $CassandraDataSourceTableNameProperty") + require( + options.contains(CassandraDataSourceKeyspaceNameProperty), + s"$CassandraDataSourceKeyspaceNameProperty required when use $CassandraDataSourceTableNameProperty") buildCassandraConnector(context, options).withSessionDo { s => - val keyspaces = options.get(CassandraDataSourceKeyspaceNameProperty).fold(s.getCluster.getMetadata.getKeyspaces){ - keySpaceName => s.getCluster.getMetadata.getKeyspace(keySpaceName) :: Nil - } + val keyspaces = options + .get(CassandraDataSourceKeyspaceNameProperty) + .fold(s.getCluster.getMetadata.getKeyspaces) { keySpaceName => + s.getCluster.getMetadata.getKeyspace(keySpaceName) :: Nil + } - val tablesIt: Iterable[Table] = for( - ksMeta: KeyspaceMetadata <- keyspaces; - tMeta: TableMetadata <- pickTables(ksMeta, options)) yield tableMeta2Table(tMeta) + val tablesIt: Iterable[Table] = for (ksMeta: KeyspaceMetadata <- keyspaces; + tMeta: TableMetadata <- pickTables(ksMeta, options)) + yield tableMeta2Table(tMeta) tablesIt.toSeq } } + private def buildCassandraConnector(context: SQLContext, + options: Map[String, String]): CassandraConnector = { - private def buildCassandraConnector(context: SQLContext, options: Map[String, String]): CassandraConnector = { - - val conParams = (CassandraDataSourceClusterNameProperty::CassandraConnectionHostProperty::Nil) map { opName => - if(!options.contains(opName)) sys.error(s"""Option "$opName" is mandatory for IMPORT CATALOG""") - else options(opName) + val conParams = (CassandraDataSourceClusterNameProperty :: CassandraConnectionHostProperty :: Nil) map { + opName => + if (!options.contains(opName)) + sys.error(s"""Option "$opName" is mandatory for IMPORT CATALOG""") + else options(opName) } val (clusterName, host) = (conParams zip conParams.tail) head val cfg: SparkConf = context.sparkContext.getConf.clone() for (prop <- CassandraConnectorDS.confProperties; - clusterLevelValue <- context.getAllConfs.get(s"$clusterName/$prop")) cfg.set(prop, clusterLevelValue) + clusterLevelValue <- context.getAllConfs.get(s"$clusterName/$prop")) + cfg.set(prop, clusterLevelValue) cfg.set("spark.cassandra.connection.host", host) 
CassandraConnector(cfg) } - private def pickTables(ksMeta: KeyspaceMetadata, options: Map[String, String]): Collection[TableMetadata] = { + private def pickTables(ksMeta: KeyspaceMetadata, + options: Map[String, String]): Collection[TableMetadata] = { options.get(CassandraDataSourceTableNameProperty).fold(ksMeta.getTables) { tableName => ksMeta.getTable(tableName) :: Nil } } /** - * @param tMeta C* Metadata for a given table - * @return A table description obtained after translate its C* meta data. - */ + * @param tMeta C* Metadata for a given table + * @return A table description obtained after translate its C* meta data. + */ private def tableMeta2Table(tMeta: TableMetadata): Table = Table(tMeta.getName, Some(tMeta.getKeyspace.getName)) @@ -241,13 +255,15 @@ class DefaultSource extends CassandraConnectorDS with TableInventory with Functi //Avoids importing system tables override def exclusionFilter(t: TableInventory.Table): Boolean = - t.database.exists( dbName => systemTableRegex.findFirstIn(dbName).isEmpty) - + t.database.exists(dbName => systemTableRegex.findFirstIn(dbName).isEmpty) - override def generateConnectorOpts(item: Table, opts: Map[String, String] = Map.empty): Map[String, String] = Map( - CassandraDataSourceTableNameProperty -> item.tableName, - CassandraDataSourceKeyspaceNameProperty -> item.database.get - ) ++ opts.filterKeys(Set(CassandraConnectionHostProperty, CassandraDataSourceClusterNameProperty).contains(_)) + override def generateConnectorOpts(item: Table, + opts: Map[String, String] = Map.empty): Map[String, String] = + Map( + CassandraDataSourceTableNameProperty -> item.tableName, + CassandraDataSourceKeyspaceNameProperty -> item.database.get + ) ++ opts.filterKeys( + Set(CassandraConnectionHostProperty, CassandraDataSourceClusterNameProperty).contains(_)) //------------MetadataInventory----------------- } @@ -255,8 +271,7 @@ class DefaultSource extends CassandraConnectorDS with TableInventory with Functi object DefaultSource { val CassandraConnectionHostProperty = "spark_cassandra_connection_host" - val CassandraDataSourcePrimaryKeyStringProperty ="primary_key_string" - val CassandraDataSourceKeyspaceReplicationStringProperty ="with_replication" + val CassandraDataSourcePrimaryKeyStringProperty = "primary_key_string" + val CassandraDataSourceKeyspaceReplicationStringProperty = "with_replication" } - diff --git a/cassandra/src/main/scala/com/stratio/crossdata/connector/cassandra/statements/CassandraUtils.scala b/cassandra/src/main/scala/com/stratio/crossdata/connector/cassandra/statements/CassandraUtils.scala index d2c2e9ee0..7450385e9 100644 --- a/cassandra/src/main/scala/com/stratio/crossdata/connector/cassandra/statements/CassandraUtils.scala +++ b/cassandra/src/main/scala/com/stratio/crossdata/connector/cassandra/statements/CassandraUtils.scala @@ -33,11 +33,11 @@ import org.apache.spark.sql.types.{TimestampType => SparkSqlTimestampType} object CassandraUtils { - /** Returns natural Cassandra type for representing data of the given Spark SQL type */ def fromSparkSqlType(dataType: SparkSqlDataType): ColumnType[_] = { - def unsupportedType() = throw new IllegalArgumentException(s"Unsupported type: $dataType") + def unsupportedType() = + throw new IllegalArgumentException(s"Unsupported type: $dataType") dataType match { case ByteType => IntType diff --git a/cassandra/src/main/scala/com/stratio/crossdata/connector/cassandra/statements/CreateKeyspaceStatement.scala 
b/cassandra/src/main/scala/com/stratio/crossdata/connector/cassandra/statements/CreateKeyspaceStatement.scala index 3ab1cc8c9..1fa2e458f 100644 --- a/cassandra/src/main/scala/com/stratio/crossdata/connector/cassandra/statements/CreateKeyspaceStatement.scala +++ b/cassandra/src/main/scala/com/stratio/crossdata/connector/cassandra/statements/CreateKeyspaceStatement.scala @@ -15,12 +15,10 @@ */ package com.stratio.crossdata.connector.cassandra.statements - import com.stratio.crossdata.connector.cassandra.DefaultSource.CassandraDataSourceKeyspaceReplicationStringProperty case class CreateKeyspaceStatement(options: Map[String, String]) { - override def toString(): String = { val cqlCommand = StringBuilder.newBuilder cqlCommand.append(s"CREATE KEYSPACE $keyspace WITH REPLICATION = $replication") @@ -32,10 +30,10 @@ case class CreateKeyspaceStatement(options: Map[String, String]) { options.get("keyspace").get } - lazy val replication: String = { - require(options.contains(CassandraDataSourceKeyspaceReplicationStringProperty), - s"$CassandraDataSourceKeyspaceReplicationStringProperty required when use CREATE EXTERNAL TABLE command") + require( + options.contains(CassandraDataSourceKeyspaceReplicationStringProperty), + s"$CassandraDataSourceKeyspaceReplicationStringProperty required when use CREATE EXTERNAL TABLE command") options.get(CassandraDataSourceKeyspaceReplicationStringProperty).get } diff --git a/cassandra/src/main/scala/com/stratio/crossdata/connector/cassandra/statements/CreateTableStatement.scala b/cassandra/src/main/scala/com/stratio/crossdata/connector/cassandra/statements/CreateTableStatement.scala index 65ae5a15a..a3ac17a27 100644 --- a/cassandra/src/main/scala/com/stratio/crossdata/connector/cassandra/statements/CreateTableStatement.scala +++ b/cassandra/src/main/scala/com/stratio/crossdata/connector/cassandra/statements/CreateTableStatement.scala @@ -21,20 +21,19 @@ import org.apache.spark.sql.types.StructType case class CreateTableStatement(tableName: String, schema: StructType, - options: Map[String, String] - ) { + options: Map[String, String]) { override def toString(): String = { - s"CREATE TABLE ${if(ifNotExists) "IF NOT EXISTS " else ""}$keyspace.$tableName (" + schema.fields.foldLeft("") { - case (prev: String, next: StructField) => - val cassandraDataType = CassandraUtils.fromSparkSqlType(next.dataType) - prev + s"${next.name} ${cassandraDataType.cqlTypeName}, " - } + s"PRIMARY KEY ($primaryKeyString))" + s"CREATE TABLE ${if (ifNotExists) "IF NOT EXISTS " else ""}$keyspace.$tableName (" + schema.fields + .foldLeft("") { + case (prev: String, next: StructField) => + val cassandraDataType = CassandraUtils.fromSparkSqlType(next.dataType) + prev + s"${next.name} ${cassandraDataType.cqlTypeName}, " + } + s"PRIMARY KEY ($primaryKeyString))" } - lazy val ifNotExists: Boolean = { options.contains("ifNotExist") } @@ -43,12 +42,11 @@ case class CreateTableStatement(tableName: String, options.get("keyspace").get } - lazy val primaryKeyString:String = { - require(options.contains(CassandraDataSourcePrimaryKeyStringProperty), - s"$CassandraDataSourcePrimaryKeyStringProperty required when use CREATE EXTERNAL TABLE command") + lazy val primaryKeyString: String = { + require( + options.contains(CassandraDataSourcePrimaryKeyStringProperty), + s"$CassandraDataSourcePrimaryKeyStringProperty required when use CREATE EXTERNAL TABLE command") options.get(CassandraDataSourcePrimaryKeyStringProperty).get } - - } diff --git 
a/cassandra/src/main/scala/org/apache/spark/sql/cassandra/CassandraXDSourceRelation.scala b/cassandra/src/main/scala/org/apache/spark/sql/cassandra/CassandraXDSourceRelation.scala index 5d604ca99..f560db65f 100644 --- a/cassandra/src/main/scala/org/apache/spark/sql/cassandra/CassandraXDSourceRelation.scala +++ b/cassandra/src/main/scala/org/apache/spark/sql/cassandra/CassandraXDSourceRelation.scala @@ -46,11 +46,11 @@ import org.apache.spark.sql.{DataFrame, Row, SQLContext, sources} import org.apache.spark.unsafe.types.UTF8String /** - * Implements [[org.apache.spark.sql.sources.BaseRelation]]]], [[org.apache.spark.sql.sources.InsertableRelation]]]] - * and [[org.apache.spark.sql.sources.PrunedFilteredScan]]]] - * It inserts data to and scans Cassandra table. If filterPushdown is true, it pushs down - * some filters to CQL - */ + * Implements [[org.apache.spark.sql.sources.BaseRelation]]]], [[org.apache.spark.sql.sources.InsertableRelation]]]] + * and [[org.apache.spark.sql.sources.PrunedFilteredScan]]]] + * It inserts data to and scans Cassandra table. If filterPushdown is true, it pushs down + * some filters to CQL + */ class CassandraXDSourceRelation(tableRef: TableRef, userSpecifiedSchema: Option[StructType], filterPushdown: Boolean, @@ -59,11 +59,12 @@ class CassandraXDSourceRelation(tableRef: TableRef, readConf: ReadConf, writeConf: WriteConf, @transient override val sqlContext: SQLContext) - extends BaseRelation - with InsertableRelation - with PrunedFilteredScan - with NativeFunctionExecutor - with NativeScan with SparkLoggerComponent { + extends BaseRelation + with InsertableRelation + with PrunedFilteredScan + with NativeFunctionExecutor + with NativeScan + with SparkLoggerComponent { // NativeScan implementation ~~ @@ -83,24 +84,32 @@ class CassandraXDSourceRelation(tableRef: TableRef, } - override def isSupported(logicalStep: LogicalPlan, wholeLogicalPlan: LogicalPlan): Boolean = logicalStep match { - case ln: LeafNode => true // TODO leafNode == LogicalRelation(xdSourceRelation) - case un: UnaryNode => un match { - case Limit(_, _) | Project(_, _) | Filter(_, _) | EvaluateNativeUDF(_, _, _) => true - case aggregatePlan: Aggregate => isAggregateSupported(aggregatePlan) - case _ => false + override def isSupported(logicalStep: LogicalPlan, wholeLogicalPlan: LogicalPlan): Boolean = + logicalStep match { + case ln: LeafNode => + true // TODO leafNode == LogicalRelation(xdSourceRelation) + case un: UnaryNode => + un match { + case Limit(_, _) | Project(_, _) | Filter(_, _) | EvaluateNativeUDF(_, _, _) => + true + case aggregatePlan: Aggregate => isAggregateSupported(aggregatePlan) + case _ => false + } + case unsupportedLogicalPlan => + log.debug(s"LogicalPlan $unsupportedLogicalPlan cannot be executed natively"); + false } - case unsupportedLogicalPlan => log.debug(s"LogicalPlan $unsupportedLogicalPlan cannot be executed natively"); false - } - def isAggregateSupported(aggregateLogicalPlan: Aggregate): Boolean = aggregateLogicalPlan match { - case Aggregate(Nil, aggregateExpressions, _) if aggregateExpressions.length == 1 => - aggregateExpressions.head match { - case Alias(Count(Literal(1, _) :: Nil), _) => false // TODO Keep it unless Cassandra implement the count efficiently - case _ => false - } - case _ => false - } + def isAggregateSupported(aggregateLogicalPlan: Aggregate): Boolean = + aggregateLogicalPlan match { + case Aggregate(Nil, aggregateExpressions, _) if aggregateExpressions.length == 1 => + aggregateExpressions.head match { + case Alias(Count(Literal(1, _) :: Nil), 
_) => + false // TODO Keep it unless Cassandra implement the count efficiently + case _ => false + } + case _ => false + } // ~~ NativeScan implementation @@ -115,7 +124,8 @@ class CassandraXDSourceRelation(tableRef: TableRef, connector.withSessionDo { val keyspace = quote(tableRef.keyspace) val table = quote(tableRef.table) - session => session.execute(s"TRUNCATE $keyspace.$table") + session => + session.execute(s"TRUNCATE $keyspace.$table") } } @@ -137,10 +147,11 @@ class CassandraXDSourceRelation(tableRef: TableRef, def buildScan(): RDD[Row] = baseRdd.asInstanceOf[RDD[Row]] - override def unhandledFilters(filters: Array[Filter]): Array[Filter] = filterPushdown match { - case true => predicatePushDown(filters).handledBySpark.toArray - case false => filters - } + override def unhandledFilters(filters: Array[Filter]): Array[Filter] = + filterPushdown match { + case true => predicatePushDown(filters).handledBySpark.toArray + case false => filters + } lazy val additionalRules: Seq[CassandraPredicateRules] = { import CassandraSourceRelation.AdditionalCassandraPushDownRulesParam @@ -148,14 +159,13 @@ class CassandraXDSourceRelation(tableRef: TableRef, /* So we can set this in testing to different values without making a new context check local property as well */ - val userClasses: Option[String] = - sc.getConf.getOption(AdditionalCassandraPushDownRulesParam.name) - .orElse(Option(sc.getLocalProperty(AdditionalCassandraPushDownRulesParam.name))) + val userClasses: Option[String] = sc.getConf + .getOption(AdditionalCassandraPushDownRulesParam.name) + .orElse(Option(sc.getLocalProperty(AdditionalCassandraPushDownRulesParam.name))) userClasses match { case Some(classes) => - classes - .trim + classes.trim .split("""\s*,\s*""") .map(ReflectionUtil.findGlobalObject[CassandraPredicateRules]) .reverse @@ -172,19 +182,18 @@ class CassandraXDSourceRelation(tableRef: TableRef, logDebug(s"Basic Rules Applied:\n$basicPushdown") /** Apply any user defined rules **/ - val finalPushdown = additionalRules.foldRight(basicPushdown)( - (rules, pushdowns) => { - val pd = rules(pushdowns, tableDef) - logDebug(s"Applied ${rules.getClass.getSimpleName} Pushdown Filters:\n$pd") - pd - } + val finalPushdown = additionalRules.foldRight(basicPushdown)( + (rules, pushdowns) => { + val pd = rules(pushdowns, tableDef) + logDebug(s"Applied ${rules.getClass.getSimpleName} Pushdown Filters:\n$pd") + pd + } ) logDebug(s"Final Pushdown filters:\n$finalPushdown") finalPushdown } - override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = { buildScan(requiredColumns, filters, Map.empty) } @@ -193,7 +202,6 @@ class CassandraXDSourceRelation(tableRef: TableRef, filters: Array[Filter], udfs: Map[String, NativeUDF]): RDD[Row] = { - val prunedRdd = maybeSelect(baseRdd, requiredColumns, udfs) logInfo(s"filters: ${filters.mkString(", ")}") val prunedFilteredRdd = { @@ -209,10 +217,12 @@ class CassandraXDSourceRelation(tableRef: TableRef, prunedFilteredRdd.asInstanceOf[RDD[Row]] } - private def resolveUDFsReferences(strId: String, udfs: Map[String, NativeUDF]): Option[FunctionCallRef] = + private def resolveUDFsReferences(strId: String, + udfs: Map[String, NativeUDF]): Option[FunctionCallRef] = udfs.get(strId) map { udf => val actualParams = udf.children.collect { - case at: AttributeReference if udfs contains at.toString => Left(resolveUDFsReferences(at.toString(), udfs).get) + case at: AttributeReference if udfs contains at.toString => + Left(resolveUDFsReferences(at.toString(), udfs).get) case at: 
AttributeReference => Left(ColumnName(at.name)) case lit: Literal => Right(lit.toString()) } @@ -223,12 +233,12 @@ class CassandraXDSourceRelation(tableRef: TableRef, private type RDDType = CassandraRDD[CassandraSQLRow] /** Transfer selection to limit to columns specified */ - private def maybeSelect( - rdd: RDDType, - requiredColumns: Array[String], - udfs: Map[String, NativeUDF] = Map.empty): RDDType = { + private def maybeSelect(rdd: RDDType, + requiredColumns: Array[String], + udfs: Map[String, NativeUDF] = Map.empty): RDDType = { if (requiredColumns.nonEmpty) { - val cols = requiredColumns.map(column => resolveUDFsReferences(column, udfs).getOrElse(column: ColumnRef)) + val cols = requiredColumns.map(column => + resolveUDFsReferences(column, udfs).getOrElse(column: ColumnRef)) rdd.select(cols: _*) } else { rdd @@ -252,7 +262,8 @@ class CassandraXDSourceRelation(tableRef: TableRef, udfs: Map[String, NativeUDF] = Map.empty): (String, Seq[Any]) = { def udfvalcmp(attribute: String, cmpOp: String, f: AttributeReference): (String, Seq[Any]) = - (s"${quote(attribute)} $cmpOp ${resolveUDFsReferences(f.toString(), udfs).get.cql}", Seq.empty) + (s"${quote(attribute)} $cmpOp ${resolveUDFsReferences(f.toString(), udfs).get.cql}", + Seq.empty) filter match { case sources.EqualTo(attribute, f: AttributeReference) if udfs contains f.toString => @@ -275,17 +286,19 @@ class CassandraXDSourceRelation(tableRef: TableRef, case sources.GreaterThan(attribute, value) => (s"${quote(attribute)} > ?", Seq(toCqlValue(attribute, value))) - case sources.GreaterThanOrEqual(attribute, f: AttributeReference) if udfs contains f.toString => + case sources.GreaterThanOrEqual(attribute, f: AttributeReference) + if udfs contains f.toString => udfvalcmp(attribute, ">=", f) case sources.GreaterThanOrEqual(attribute, value) => (s"${quote(attribute)} >= ?", Seq(toCqlValue(attribute, value))) - case sources.In(attribute, values) => - (quote(attribute) + " IN " + values.map(_ => "?").mkString("(", ", ", ")"), toCqlValues(attribute, values)) + case sources.In(attribute, values) => + (quote(attribute) + " IN " + values.map(_ => "?").mkString("(", ", ", ")"), + toCqlValues(attribute, values)) case _ => throw new UnsupportedOperationException( - s"It's not a valid filter $filter to be pushed down, only >, <, >=, <= and In are allowed.") + s"It's not a valid filter $filter to be pushed down, only >, <, >=, <= and In are allowed.") } } @@ -303,7 +316,7 @@ class CassandraXDSourceRelation(tableRef: TableRef, val columnType = tableDef.columnByName(columnName).columnType if (columnType == InetType) { InetAddress.getByName(utf8String.toString) - } else if(columnType == UUIDType) { + } else if (columnType == UUIDType) { UUID.fromString(utf8String.toString) } else { utf8String @@ -313,7 +326,8 @@ class CassandraXDSourceRelation(tableRef: TableRef, } /** Construct where clause from pushdown filters */ - private def whereClause(pushdownFilters: Seq[Any], udfs: Map[String, NativeUDF] = Map.empty): (String, Seq[Any]) = { + private def whereClause(pushdownFilters: Seq[Any], + udfs: Map[String, NativeUDF] = Map.empty): (String, Seq[Any]) = { val cqlValue = pushdownFilters.map(filterToCqlAndValue(_, udfs)) val cql = cqlValue.map(_._1).mkString(" AND ") val args = cqlValue.flatMap(_._2) @@ -322,7 +336,6 @@ class CassandraXDSourceRelation(tableRef: TableRef, } - object CassandraXDSourceRelation { import CassandraSourceRelation._ @@ -334,20 +347,16 @@ object CassandraXDSourceRelation { val sparkConf = sqlContext.sparkContext.getConf val sqlConf = 
sqlContext.getAllConfs - val conf = - consolidateConfs(sparkConf, sqlConf, tableRef, options.cassandraConfs) + val conf = consolidateConfs(sparkConf, sqlConf, tableRef, options.cassandraConfs) val tableSizeInBytesString = conf.getOption(CassandraSourceRelation.TableSizeInBytesParam.name) - val cassandraConnector = - new CassandraConnector(CassandraConnectorConf(conf)) + val cassandraConnector = new CassandraConnector(CassandraConnectorConf(conf)) val tableSizeInBytes = tableSizeInBytesString match { case Some(size) => Option(size.toLong) case None => val tokenFactory = CassandraPartitionGenerator.getTokenFactory(cassandraConnector) val dataSizeInBytes = - new DataSizeEstimates( - cassandraConnector, - tableRef.keyspace, - tableRef.table)(tokenFactory).totalDataSizeInBytes + new DataSizeEstimates(cassandraConnector, tableRef.keyspace, tableRef.table)( + tokenFactory).totalDataSizeInBytes if (dataSizeInBytes <= 0L) { None } else { @@ -357,15 +366,14 @@ object CassandraXDSourceRelation { val readConf = ReadConf.fromSparkConf(conf) val writeConf = WriteConf.fromSparkConf(conf) - new CassandraXDSourceRelation( - tableRef = tableRef, - userSpecifiedSchema = schema, - filterPushdown = options.pushdown, - tableSizeInBytes = tableSizeInBytes, - connector = cassandraConnector, - readConf = readConf, - writeConf = writeConf, - sqlContext = sqlContext) + new CassandraXDSourceRelation(tableRef = tableRef, + userSpecifiedSchema = schema, + filterPushdown = options.pushdown, + tableSizeInBytes = tableSizeInBytes, + connector = cassandraConnector, + readConf = readConf, + writeConf = writeConf, + sqlContext = sqlContext) } -} \ No newline at end of file +} diff --git a/cassandra/src/test/scala/com/stratio/crossdata/connector/cassandra/CassandraAggregationIT.scala b/cassandra/src/test/scala/com/stratio/crossdata/connector/cassandra/CassandraAggregationIT.scala index 8f74deb4a..18d1fbadf 100644 --- a/cassandra/src/test/scala/com/stratio/crossdata/connector/cassandra/CassandraAggregationIT.scala +++ b/cassandra/src/test/scala/com/stratio/crossdata/connector/cassandra/CassandraAggregationIT.scala @@ -23,7 +23,7 @@ import org.scalatest.junit.JUnitRunner @RunWith(classOf[JUnitRunner]) class CassandraAggregationIT extends CassandraWithSharedContext { - val nativeErrorMessage = "The operation cannot be executed without Spark" + val nativeErrorMessage = "The operation cannot be executed without Spark" // PRIMARY KEY id // CLUSTERING KEY age, comment // DEFAULT enrolled @@ -45,7 +45,6 @@ class CassandraAggregationIT extends CassandraWithSharedContext { } should have message nativeErrorMessage } - it should "not execute natively a (SELECT count(*) FROM _ WHERE _)" in { assumeEnvironmentIsUpAndRunning diff --git a/cassandra/src/test/scala/com/stratio/crossdata/connector/cassandra/CassandraConnectorIT.scala b/cassandra/src/test/scala/com/stratio/crossdata/connector/cassandra/CassandraConnectorIT.scala index f80dda9cb..366f65a7c 100644 --- a/cassandra/src/test/scala/com/stratio/crossdata/connector/cassandra/CassandraConnectorIT.scala +++ b/cassandra/src/test/scala/com/stratio/crossdata/connector/cassandra/CassandraConnectorIT.scala @@ -34,12 +34,11 @@ class CassandraConnectorIT extends CassandraWithSharedContext { val dataframe = sql(s"SELECT * FROM $Table ") val schema = dataframe.schema val result = dataframe.collect(executionType) - schema.fieldNames should equal (Seq("id", "age", "comment", "enrolled", "name")) + schema.fieldNames should equal(Seq("id", "age", "comment", "enrolled", "name")) result should have 
length 10 result(0) should have length 5 } - it should s"support a query with limit 0 for $executionType execution" in { assumeEnvironmentIsUpAndRunning @@ -73,25 +72,26 @@ class CassandraConnectorIT extends CassandraWithSharedContext { it should s"support a (SELECT * ... WHERE CK._1 = _ AND CK._2 = _) for $executionType execution" in { assumeEnvironmentIsUpAndRunning - val result = sql(s"SELECT * FROM $Table WHERE age = 13 AND comment = 'Comment 3' ").collect(executionType) + val result = sql(s"SELECT * FROM $Table WHERE age = 13 AND comment = 'Comment 3' ") + .collect(executionType) result should have length 1 } it should s"support a (SELECT * ... WHERE PK = _ AND CK._1 = _ AND CK._2 = _) for $executionType execution" in { assumeEnvironmentIsUpAndRunning - val result = sql(s"SELECT * FROM $Table WHERE id = 3 AND age = 13 AND comment = 'Comment 3' ").collect(executionType) + val result = + sql(s"SELECT * FROM $Table WHERE id = 3 AND age = 13 AND comment = 'Comment 3' ") + .collect(executionType) result should have length 1 } } - "Cassandra connector" should "execute natively a (SELECT * ... WHERE LUCENE_SEC_INDEX = _ )" in { assumeEnvironmentIsUpAndRunning - val result = sql( - s""" + val result = sql(s""" |SELECT * FROM $Table |WHERE name = |'{ filter : @@ -101,13 +101,12 @@ class CassandraConnectorIT extends CassandraWithSharedContext { result should have length 10 } - // NOT SUPPORTED FILTERS it should "not execute natively a (SELECT * ... WHERE LUCENE_SEC_INDEX < _ )" in { assumeEnvironmentIsUpAndRunning - the [CrossdataException] thrownBy { + the[CrossdataException] thrownBy { sql(s""" |SELECT * FROM $Table |WHERE name > @@ -145,8 +144,3 @@ class CassandraConnectorIT extends CassandraWithSharedContext { // TODO test filter on PKs (=) and CKs(any) (right -> left) } - - - - - diff --git a/cassandra/src/test/scala/com/stratio/crossdata/connector/cassandra/CassandraCreateExternalTableIT.scala b/cassandra/src/test/scala/com/stratio/crossdata/connector/cassandra/CassandraCreateExternalTableIT.scala index d2225a627..b48bb962c 100644 --- a/cassandra/src/test/scala/com/stratio/crossdata/connector/cassandra/CassandraCreateExternalTableIT.scala +++ b/cassandra/src/test/scala/com/stratio/crossdata/connector/cassandra/CassandraCreateExternalTableIT.scala @@ -21,13 +21,11 @@ import org.scalatest.junit.JUnitRunner @RunWith(classOf[JUnitRunner]) class CassandraCreateExternalTableIT extends CassandraWithSharedContext { - "The Cassandra connector" should "execute natively create a External Table" in { val tableName = "newtable" - val createTableQueryString = - s"""|CREATE EXTERNAL TABLE $tableName ( + val createTableQueryString = s"""|CREATE EXTERNAL TABLE $tableName ( |id Integer, |name String, |booleanFile boolean, @@ -53,14 +51,14 @@ class CassandraCreateExternalTableIT extends CassandraWithSharedContext { //Expectations val table = xdContext.table(tableName) table should not be null - table.schema.fieldNames should contain ("name") + table.schema.fieldNames should contain("name") // In case that the table didn't exist, then this operation would throw a InvalidQueryException val resultSet = client.get._2.execute(s"SELECT * FROM $Catalog.$tableName") import scala.collection.JavaConversions._ - resultSet.getColumnDefinitions.asList.map(cd => cd.getName) should contain ("name") + resultSet.getColumnDefinitions.asList.map(cd => cd.getName) should contain("name") } it should "execute natively create a External Table with no existing Keyspace" in { @@ -85,7 +83,7 @@ class CassandraCreateExternalTableIT 
extends CassandraWithSharedContext { val table = xdContext.table(s"newkeyspace.othertable") table should not be null table.schema.fieldNames should contain("name") - }finally { + } finally { //AFTER client.get._2.execute(s"DROP KEYSPACE newkeyspace") } @@ -105,11 +103,10 @@ class CassandraCreateExternalTableIT extends CassandraWithSharedContext { """.stripMargin.replaceAll("\n", " ") //Experimentation - the [IllegalArgumentException] thrownBy { + the[IllegalArgumentException] thrownBy { sql(createTableQueryString).collect() - } should have message "requirement failed: with_replication required when use CREATE EXTERNAL TABLE command" + } should have message "requirement failed: with_replication required when use CREATE EXTERNAL TABLE command" } - } diff --git a/cassandra/src/test/scala/com/stratio/crossdata/connector/cassandra/CassandraDropExternalTableIT.scala b/cassandra/src/test/scala/com/stratio/crossdata/connector/cassandra/CassandraDropExternalTableIT.scala index bff815ecd..540e5af32 100644 --- a/cassandra/src/test/scala/com/stratio/crossdata/connector/cassandra/CassandraDropExternalTableIT.scala +++ b/cassandra/src/test/scala/com/stratio/crossdata/connector/cassandra/CassandraDropExternalTableIT.scala @@ -26,8 +26,7 @@ class CassandraDropExternalTableIT extends CassandraWithSharedContext { protected override def beforeAll(): Unit = { super.beforeAll() - val createTableQueryString1 = - s"""|CREATE EXTERNAL TABLE $Catalog.dropTable1 ( + val createTableQueryString1 = s"""|CREATE EXTERNAL TABLE $Catalog.dropTable1 ( |id Integer, |name String, |booleanFile boolean, @@ -49,8 +48,7 @@ class CassandraDropExternalTableIT extends CassandraWithSharedContext { """.stripMargin.replaceAll("\n", " ") sql(createTableQueryString1).collect() - val createTableQueryString2 = - s"""|CREATE EXTERNAL TABLE dropTable2 ( + val createTableQueryString2 = s"""|CREATE EXTERNAL TABLE dropTable2 ( |id Integer, |name String, |booleanFile boolean, @@ -83,7 +81,7 @@ class CassandraDropExternalTableIT extends CassandraWithSharedContext { //DROP val dropExternalTableQuery = s"DROP EXTERNAL TABLE $Catalog.dropTable1" - sql(dropExternalTableQuery).collect() should be (Seq.empty) + sql(dropExternalTableQuery).collect() should be(Seq.empty) //Expectations an[Exception] shouldBe thrownBy(xdContext.table(s"$Catalog.dropTable1")) @@ -100,12 +98,12 @@ class CassandraDropExternalTableIT extends CassandraWithSharedContext { //DROP val dropExternalTableQuery = "DROP EXTERNAL TABLE dropTable2" - sql(dropExternalTableQuery).collect() should be (Seq.empty) + sql(dropExternalTableQuery).collect() should be(Seq.empty) //Expectations an[Exception] shouldBe thrownBy(xdContext.table("dropTable2")) client.get._1.getMetadata.getKeyspace(Catalog).getTable(cassandraTableName) shouldBe null } - + } diff --git a/cassandra/src/test/scala/com/stratio/crossdata/connector/cassandra/CassandraFunctionIT.scala b/cassandra/src/test/scala/com/stratio/crossdata/connector/cassandra/CassandraFunctionIT.scala index 4c79f561d..77d61aada 100644 --- a/cassandra/src/test/scala/com/stratio/crossdata/connector/cassandra/CassandraFunctionIT.scala +++ b/cassandra/src/test/scala/com/stratio/crossdata/connector/cassandra/CassandraFunctionIT.scala @@ -22,14 +22,14 @@ import org.scalatest.junit.JUnitRunner @RunWith(classOf[JUnitRunner]) class CassandraFunctionIT extends CassandraWithSharedContext { - val execTypes: List[ExecutionType] = Native::Spark::Nil + val execTypes: List[ExecutionType] = Native :: Spark :: Nil execTypes.foreach { exec => - "The Cassandra 
connector" should s"be able to ${exec.toString}ly select the built-in functions `now`, `dateOf` and `unixTimeStampOf`" in { assumeEnvironmentIsUpAndRunning - val query = s"SELECT cassandra_now() as t, cassandra_now() as a, cassandra_dateOf(cassandra_now()) as dt, cassandra_unixTimestampOf(cassandra_now()) as ut FROM $Table" + val query = + s"SELECT cassandra_now() as t, cassandra_now() as a, cassandra_dateOf(cassandra_now()) as dt, cassandra_unixTimestampOf(cassandra_now()) as ut FROM $Table" sql(query).collect(exec) should have length 10 } } @@ -41,5 +41,4 @@ class CassandraFunctionIT extends CassandraWithSharedContext { sql(query).collect(Native) should have length 10 } - } diff --git a/cassandra/src/test/scala/com/stratio/crossdata/connector/cassandra/CassandraImportTablesIT.scala b/cassandra/src/test/scala/com/stratio/crossdata/connector/cassandra/CassandraImportTablesIT.scala index 544924ed6..e549d4f35 100644 --- a/cassandra/src/test/scala/com/stratio/crossdata/connector/cassandra/CassandraImportTablesIT.scala +++ b/cassandra/src/test/scala/com/stratio/crossdata/connector/cassandra/CassandraImportTablesIT.scala @@ -28,8 +28,7 @@ class CassandraImportTablesIT extends CassandraWithSharedContext { def tableCountInHighschool: Long = xdContext.sql("SHOW TABLES").count val initialLength = tableCountInHighschool - val importQuery = - s""" + val importQuery = s""" |IMPORT TABLES |USING $SourceProvider |OPTIONS ( @@ -51,8 +50,7 @@ class CassandraImportTablesIT extends CassandraWithSharedContext { assumeEnvironmentIsUpAndRunning val (cluster, session) = createOtherTables - val importQuery = - s""" + val importQuery = s""" |IMPORT TABLES |USING $SourceProvider |OPTIONS ( @@ -78,8 +76,7 @@ class CassandraImportTablesIT extends CassandraWithSharedContext { val (cluster, session) = createOtherTables - val importQuery = - s""" + val importQuery = s""" |IMPORT TABLES |USING $SourceProvider |OPTIONS ( @@ -93,11 +90,11 @@ class CassandraImportTablesIT extends CassandraWithSharedContext { xdContext.dropAllTables() val importedTables = sql(importQuery) // imported tables shouldn't be ignored (schema is (tableName, ignored) - importedTables.collect().forall( row => row.getBoolean(1)) shouldBe false + importedTables.collect().forall(row => row.getBoolean(1)) shouldBe false // imported tables should be ignored after importing twice - sql(importQuery).collect().forall( row => row.getBoolean(1)) shouldBe true - xdContext.tableNames() should contain (s"$Catalog.$Table") - xdContext.tableNames() should not contain "NewKeyspace.NewTable" + sql(importQuery).collect().forall(row => row.getBoolean(1)) shouldBe true + xdContext.tableNames() should contain(s"$Catalog.$Table") + xdContext.tableNames() should not contain "NewKeyspace.NewTable" } finally { cleanOtherTables(cluster, session) } @@ -107,8 +104,7 @@ class CassandraImportTablesIT extends CassandraWithSharedContext { assumeEnvironmentIsUpAndRunning xdContext.dropAllTables() - val importQuery = - s""" + val importQuery = s""" |IMPORT TABLES |USING $SourceProvider |OPTIONS ( @@ -131,8 +127,7 @@ class CassandraImportTablesIT extends CassandraWithSharedContext { assumeEnvironmentIsUpAndRunning xdContext.dropAllTables() - val importQuery = - s""" + val importQuery = s""" |IMPORT TABLES |USING $SourceProvider |OPTIONS ( @@ -143,18 +138,18 @@ class CassandraImportTablesIT extends CassandraWithSharedContext { """.stripMargin //Experimentation - an [IllegalArgumentException] should be thrownBy sql(importQuery) + an[IllegalArgumentException] should be thrownBy 
sql(importQuery) } val wrongImportTablesSentences = List( - s""" + s""" |IMPORT TABLES |USING $SourceProvider |OPTIONS ( | cluster "$ClusterName" |) """.stripMargin, - s""" + s""" |IMPORT TABLES |USING $SourceProvider |OPTIONS ( @@ -170,33 +165,32 @@ class CassandraImportTablesIT extends CassandraWithSharedContext { } } + def createOtherTables(): (Cluster, Session) = { + val (cluster, session) = prepareClient.get - def createOtherTables(): (Cluster, Session) ={ - val (cluster, session) = prepareClient.get - - session.execute(s"CREATE KEYSPACE NewKeyspace WITH replication = {'class':'SimpleStrategy', 'replication_factor':1} AND durable_writes = true;") - session.execute(s"CREATE TABLE NewKeyspace.NewTable (id int, coolstuff text, PRIMARY KEY (id))") + session.execute( + s"CREATE KEYSPACE NewKeyspace WITH replication = {'class':'SimpleStrategy', 'replication_factor':1} AND durable_writes = true;") + session.execute( + s"CREATE TABLE NewKeyspace.NewTable (id int, coolstuff text, PRIMARY KEY (id))") (cluster, session) } - def cleanOtherTables(cluster:Cluster, session:Session): Unit ={ + def cleanOtherTables(cluster: Cluster, session: Session): Unit = { session.execute(s"DROP KEYSPACE NewKeyspace") session.close() cluster.close() } - it should "infer schema after import One table from a keyspace using API" in { assumeEnvironmentIsUpAndRunning xdContext.dropAllTables() - val options = Map( - "cluster" -> ClusterName, - "keyspace" -> Catalog, - "table" -> Table, - "spark_cassandra_connection_host" -> CassandraHost) + val options = Map("cluster" -> ClusterName, + "keyspace" -> Catalog, + "table" -> Table, + "spark_cassandra_connection_host" -> CassandraHost) //Experimentation xdContext.importTables(SourceProvider, options) @@ -206,9 +200,3 @@ class CassandraImportTablesIT extends CassandraWithSharedContext { xdContext.tableNames() should not contain "highschool.teachers" } } - - - - - - diff --git a/cassandra/src/test/scala/com/stratio/crossdata/connector/cassandra/CassandraInsertCollection.scala b/cassandra/src/test/scala/com/stratio/crossdata/connector/cassandra/CassandraInsertCollection.scala index 39f834f84..fbb1a4fba 100644 --- a/cassandra/src/test/scala/com/stratio/crossdata/connector/cassandra/CassandraInsertCollection.scala +++ b/cassandra/src/test/scala/com/stratio/crossdata/connector/cassandra/CassandraInsertCollection.scala @@ -24,27 +24,35 @@ trait CassandraInsertCollection extends CassandraWithSharedContext { override val Table = "studentsinserttest" - override val schema = ListMap("id" -> "int", "age" -> "int", "comment" -> "text", - "enrolled" -> "boolean", "name" -> "text", "array_test" ->"list", "map_test" -> "map", - "array_map" -> "list>>", "map_array" -> "map>>") - - override val pk = "(id)" :: "age" :: Nil + override val schema = ListMap("id" -> "int", + "age" -> "int", + "comment" -> "text", + "enrolled" -> "boolean", + "name" -> "text", + "array_test" -> "list", + "map_test" -> "map", + "array_map" -> "list>>", + "map_array" -> "map>>") + + override val pk = "(id)" :: "age" :: Nil override val testData = (for (a <- 1 to 10) yield { a :: (10 + a) :: - s"Comment $a" :: - (a % 2 == 0) :: - s"Name $a" :: - List(a.toString, (a+1).toString, (a+2).toString) :: - Map("x" -> (a + 1).toString, "y" -> (a + 2).toString) :: - List(Map("x" -> (a + 1).toString) ,Map("y" -> (a + 2).toString)) :: - Map("x" -> List((a + 1).toString, (a + 2).toString), "y" -> List((a + 2).toString)) ::Nil + s"Comment $a" :: + (a % 2 == 0) :: + s"Name $a" :: + List(a.toString, (a + 1).toString, (a + 
2).toString) :: + Map("x" -> (a + 1).toString, "y" -> (a + 2).toString) :: + List(Map("x" -> (a + 1).toString), Map("y" -> (a + 2).toString)) :: + Map("x" -> List((a + 1).toString, (a + 2).toString), + "y" -> List((a + 2).toString)) :: Nil }).toList - abstract override def sparkRegisterTableSQL: Seq[SparkTable] = super.sparkRegisterTableSQL :+ - str2sparkTableDesc(s"CREATE TEMPORARY TABLE $Table") + abstract override def sparkRegisterTableSQL: Seq[SparkTable] = + super.sparkRegisterTableSQL :+ + str2sparkTableDesc(s"CREATE TEMPORARY TABLE $Table") override def defaultOptions = super.defaultOptions + ("table" -> Table) -} \ No newline at end of file +} diff --git a/cassandra/src/test/scala/com/stratio/crossdata/connector/cassandra/CassandraInsertTableIT.scala b/cassandra/src/test/scala/com/stratio/crossdata/connector/cassandra/CassandraInsertTableIT.scala index 1377e91a7..8c8f25ae6 100644 --- a/cassandra/src/test/scala/com/stratio/crossdata/connector/cassandra/CassandraInsertTableIT.scala +++ b/cassandra/src/test/scala/com/stratio/crossdata/connector/cassandra/CassandraInsertTableIT.scala @@ -23,20 +23,30 @@ import org.scalatest.junit.JUnitRunner class CassandraInsertTableIT extends CassandraInsertCollection { it should "insert a row using INSERT INTO table VALUES in Cassandra" in { - _xdContext.sql(s"INSERT INTO $Table VALUES (20, 25, [(x -> 3)], ['proof'], 'proof description', true, (x->[1]), (a->2), 'Eve' )").collect() should be(Row(1) :: Nil) + _xdContext + .sql(s"INSERT INTO $Table VALUES (20, 25, [(x -> 3)], ['proof'], 'proof description', true, (x->[1]), (a->2), 'Eve' )") + .collect() should be(Row(1) :: Nil) //EXPECTATION val results = sql(s"select * from $Table where id=20").collect() results should have length 1 results should contain - Row(20, 25, "proof description", true, "Eve", - Seq("proof"), Map("a" -> "2"), List(Map("x" -> "1", "y" -> "1"), - Map("z" -> "1")), Map("x" -> List("1", "2"), "y" -> List("3", "4"))) + Row(20, + 25, + "proof description", + true, + "Eve", + Seq("proof"), + Map("a" -> "2"), + List(Map("x" -> "1", "y" -> "1"), Map("z" -> "1")), + Map("x" -> List("1", "2"), "y" -> List("3", "4"))) } it should "insert a row using INSERT INTO table(schema) VALUES in Cassandra" in { - _xdContext.sql(s"INSERT INTO $Table(id, age, name) VALUES (21, 25, 'Peter')").collect() should be(Row(1) :: Nil) + _xdContext + .sql(s"INSERT INTO $Table(id, age, name) VALUES (21, 25, 'Peter')") + .collect() should be(Row(1) :: Nil) //EXPECTATION val results = sql(s"select id, age, enrolled from $Table where id=21").collect() @@ -47,8 +57,7 @@ class CassandraInsertTableIT extends CassandraInsertCollection { } it should "insert multiple rows using INSERT INTO table VALUES in Cassandra" in { - val query = - s"""|INSERT INTO $Table VALUES + val query = s"""|INSERT INTO $Table VALUES |(22, 25, [(x -> 1)], [4,5], 'proof description', true, (x->[1,5]), (x -> 1), 'John' ), |(23, 1, [(x -> 7, y->8)], [1,2,3], 'other description', false, (x->[1]), (key -> value), 'James' ), |(24, 33, [(x -> 3)], [true,true], 'other fun description', false, (x->[1,9]), (z->1, a-> 2), 'July' ) @@ -57,106 +66,128 @@ class CassandraInsertTableIT extends CassandraInsertCollection { rows should be(Row(3) :: Nil) //EXPECTATION - val results = sql(s"select id,age,comment,enrolled,name,array_test,map_test,array_map,map_array from $Table where id=22 or id=23 or id=24").collect() + val results = sql( + s"select id,age,comment,enrolled,name,array_test,map_test,array_map,map_array from $Table where id=22 or id=23 or id=24") + 
.collect() results should have length 3 - results should contain allOf( - - Row(22, 25, "proof description", true, "John", Seq("4", "5"), - Map("x" -> "1"), Seq(Map("x" -> "1")), Map("x" -> Seq("1", "5"))), - - Row(23, 1, "other description", false, "James", Seq("1", "2", "3"), - Map("key" -> "value"), Seq(Map("x" -> "7", "y" -> "8")), Map("x" -> Seq("1"))), - - Row(24, 33, "other fun description", false, "July", Seq("true", "true"), - Map("z" -> "1", "a" -> "2"), Seq(Map("x" -> "3")), Map("x" -> Seq("1", "9"))) - ) + results should contain allOf ( + Row(22, + 25, + "proof description", + true, + "John", + Seq("4", "5"), + Map("x" -> "1"), + Seq(Map("x" -> "1")), + Map("x" -> Seq("1", "5"))), + Row(23, + 1, + "other description", + false, + "James", + Seq("1", "2", "3"), + Map("key" -> "value"), + Seq(Map("x" -> "7", "y" -> "8")), + Map("x" -> Seq("1"))), + Row(24, + 33, + "other fun description", + false, + "July", + Seq("true", "true"), + Map("z" -> "1", "a" -> "2"), + Seq(Map("x" -> "3")), + Map("x" -> Seq("1", "9"))) + ) } it should "insert multiple rows using INSERT INTO table(schema) VALUES in Cassandra" in { - _xdContext.sql(s"INSERT INTO $Table(id, age, name, enrolled) VALUES ( 25, 50, 'Samantha', true),( 26, 1, 'Charlie', false)").collect() should be(Row(2) :: Nil) + _xdContext + .sql(s"INSERT INTO $Table(id, age, name, enrolled) VALUES ( 25, 50, 'Samantha', true),( 26, 1, 'Charlie', false)") + .collect() should be(Row(2) :: Nil) //EXPECTATION val results = sql(s"select id, age, name, enrolled from $Table where id=25 or id=26").collect() results should have length 2 - results should contain allOf( - Row(25, 50, "Samantha", true), - Row(26, 1, "Charlie", false) - ) + results should contain allOf ( + Row(25, 50, "Samantha", true), + Row(26, 1, "Charlie", false) + ) } it should "insert rows using INSERT INTO table(schema) VALUES with Arrays in Cassandra" in { - val query = - s"""|INSERT INTO $Table (id, age, name, enrolled, array_test) VALUES + val query = s"""|INSERT INTO $Table (id, age, name, enrolled, array_test) VALUES |(27, 55, 'Jules', true, [true, false]), |(28, 12, 'Martha', false, ['test1,t', 'test2']) """.stripMargin _xdContext.sql(query).collect() should be(Row(2) :: Nil) //EXPECTATION - val results = sql(s"select id, age, name, enrolled, array_test from $Table where id=27 or id=28").collect() + val results = + sql(s"select id, age, name, enrolled, array_test from $Table where id=27 or id=28").collect() results should have length 2 - results should contain allOf( - Row(27, 55, "Jules", true, Seq("true", "false")), - Row(28, 12, "Martha", false, Seq("test1,t", "test2")) - ) + results should contain allOf ( + Row(27, 55, "Jules", true, Seq("true", "false")), + Row(28, 12, "Martha", false, Seq("test1,t", "test2")) + ) } it should "insert rows using INSERT INTO table(schema) VALUES with Map in Cassandra" in { - val query = - s"""|INSERT INTO $Table (id, age, name, enrolled, map_test) VALUES + val query = s"""|INSERT INTO $Table (id, age, name, enrolled, map_test) VALUES |( 29, 12, 'Albert', true, (x->1, y->2, z->3) ), |( 30, 20, 'Alfred', false, (xa->1, ya->2, za->3,d -> 5) ) """.stripMargin _xdContext.sql(query).collect() should be(Row(2) :: Nil) //EXPECTATION - val results = sql(s"select id, age, name, enrolled, map_test from $Table where id=29 or id=30").collect() + val results = + sql(s"select id, age, name, enrolled, map_test from $Table where id=29 or id=30").collect() results should have length 2 - results should contain allOf( - Row(29, 12, "Albert", true, Map("x" -> 
"1", "y" -> "2", "z" -> "3")), - Row(30, 20, "Alfred", false, Map("xa" -> "1", "ya" -> "2", "za" -> "3", "d" -> "5")) - ) + results should contain allOf ( + Row(29, 12, "Albert", true, Map("x" -> "1", "y" -> "2", "z" -> "3")), + Row(30, 20, "Alfred", false, Map("xa" -> "1", "ya" -> "2", "za" -> "3", "d" -> "5")) + ) } it should "insert rows using INSERT INTO table(schema) VALUES with Array of Maps in Cassandra" in { - val query = - s"""|INSERT INTO $Table (id, age,name, enrolled, array_map) VALUES + val query = s"""|INSERT INTO $Table (id, age,name, enrolled, array_map) VALUES |(31, 1, 'Nikolai', true, [(x -> 3), (z -> 1)] ), |(32, 14, 'Ludwig', false, [(x -> 1, y-> 1), (z -> 1)] ) """.stripMargin _xdContext.sql(query).collect() should be(Row(2) :: Nil) //EXPECTATION - val results = sql(s"select id, age,name, enrolled, array_map from $Table where id=31 or id=32").collect() + val results = + sql(s"select id, age,name, enrolled, array_map from $Table where id=31 or id=32").collect() results should have length 2 - results should contain allOf( - Row(31, 1, "Nikolai", true, Seq(Map("x" -> "3"), Map("z" -> "1"))), - Row(32, 14, "Ludwig", false, Seq(Map("x" -> "1", "y" -> "1"), Map("z" -> "1"))) - ) + results should contain allOf ( + Row(31, 1, "Nikolai", true, Seq(Map("x" -> "3"), Map("z" -> "1"))), + Row(32, 14, "Ludwig", false, Seq(Map("x" -> "1", "y" -> "1"), Map("z" -> "1"))) + ) } it should "insert rows using INSERT INTO table(schema) VALUES with Map of Array in Cassandra" in { - val query = - s"""|INSERT INTO $Table (id, age,name, enrolled, map_array) VALUES + val query = s"""|INSERT INTO $Table (id, age,name, enrolled, map_array) VALUES |(33, 13, 'Svletana', true, ( x->[1], y-> [3,4] ) ), |(34, 17, 'Wolfang', false, ( x->[1,2], y-> [3] ) ) """.stripMargin _xdContext.sql(query).collect() should be(Row(2) :: Nil) //EXPECTATION - val results = sql(s"select id, age,name, enrolled, map_array from $Table where id=33 or id=34").collect() + val results = + sql(s"select id, age,name, enrolled, map_array from $Table where id=33 or id=34").collect() results should have length 2 - results should contain allOf( - Row(33, 13, "Svletana", true, Map("x" -> Seq("1"), "y" -> Seq("3", "4"))), - Row(34, 17, "Wolfang", false, Map("x" -> Seq("1", "2"), "y" -> Seq("3"))) - ) + results should contain allOf ( + Row(33, 13, "Svletana", true, Map("x" -> Seq("1"), "y" -> Seq("3", "4"))), + Row(34, 17, "Wolfang", false, Map("x" -> Seq("1", "2"), "y" -> Seq("3"))) + ) } } diff --git a/cassandra/src/test/scala/com/stratio/crossdata/connector/cassandra/CassandraPKFiltersIT.scala b/cassandra/src/test/scala/com/stratio/crossdata/connector/cassandra/CassandraPKFiltersIT.scala index b275602f4..f885ff58f 100644 --- a/cassandra/src/test/scala/com/stratio/crossdata/connector/cassandra/CassandraPKFiltersIT.scala +++ b/cassandra/src/test/scala/com/stratio/crossdata/connector/cassandra/CassandraPKFiltersIT.scala @@ -39,11 +39,11 @@ class CassandraPKFiltersIT extends CassandraWithSharedContext { override val testData = List(List(FixedDate)) override val defaultOptions = Map( - "table" -> Table, - "keyspace" -> Catalog, - "cluster" -> ClusterName, - "pushdown" -> "true", - "spark_cassandra_connection_host" -> CassandraHost + "table" -> Table, + "keyspace" -> Catalog, + "cluster" -> ClusterName, + "pushdown" -> "true", + "spark_cassandra_connection_host" -> CassandraHost ) // PRIMARY KEY date @@ -54,14 +54,9 @@ class CassandraPKFiltersIT extends CassandraWithSharedContext { val optimizedPlan = dataframe.queryExecution.optimizedPlan 
val schema = dataframe.schema val result = dataframe.collect(Native) - schema.fieldNames should equal (Seq(pk(0))) + schema.fieldNames should equal(Seq(pk(0))) result should have length 1 result(0) should have length 1 } } - - - - - diff --git a/cassandra/src/test/scala/com/stratio/crossdata/connector/cassandra/CassandraQueryProcessorSpec.scala b/cassandra/src/test/scala/com/stratio/crossdata/connector/cassandra/CassandraQueryProcessorSpec.scala index ec16ff636..056b3f681 100644 --- a/cassandra/src/test/scala/com/stratio/crossdata/connector/cassandra/CassandraQueryProcessorSpec.scala +++ b/cassandra/src/test/scala/com/stratio/crossdata/connector/cassandra/CassandraQueryProcessorSpec.scala @@ -35,98 +35,138 @@ class CassandraQueryProcessorSpec extends BaseXDTest { val ValueAge = 25 val ValueAge2 = 30 val ValueId = "00123" - + val Function01 = "F#01" val Function02 = "F#02" - + val udfs = Map( - Function01 -> - NativeUDF(getFunctionName(Function01), DataTypes.StringType, AttributeReference("id", DataTypes.StringType) - ()::Nil), - Function02 -> - NativeUDF(getFunctionName(Function02), DataTypes.StringType, Literal(42)::Nil) + Function01 -> + NativeUDF(getFunctionName(Function01), + DataTypes.StringType, + AttributeReference("id", DataTypes.StringType)() :: Nil), + Function02 -> + NativeUDF(getFunctionName(Function02), DataTypes.StringType, Literal(42) :: Nil) ) protected def getFunctionName(fid: String): String = fid.split("#").head.trim "A CassandraQueryProcessor" should "build a query requiring some columns" in { - val query = CassandraQueryProcessor.buildNativeQuery(TableQN, Array(ColumnId, ColumnAge), Array(), Limit) + val query = + CassandraQueryProcessor.buildNativeQuery(TableQN, Array(ColumnId, ColumnAge), Array(), Limit) query should be(s"SELECT $ColumnId, $ColumnAge FROM $TableQN LIMIT $Limit ALLOW FILTERING") } it should "build a query with two equal filters" in { val query = CassandraQueryProcessor.buildNativeQuery( - TableQN, Array(ColumnId), Array(sources.EqualTo(ColumnAge, ValueAge), sources.EqualTo(ColumnId, ValueId)), Limit) + TableQN, + Array(ColumnId), + Array(sources.EqualTo(ColumnAge, ValueAge), sources.EqualTo(ColumnId, ValueId)), + Limit) - query should be(s"SELECT $ColumnId FROM $TableQN WHERE $ColumnAge = $ValueAge AND $ColumnId = '$ValueId' LIMIT $Limit ALLOW FILTERING") + query should be( + s"SELECT $ColumnId FROM $TableQN WHERE $ColumnAge = $ValueAge AND $ColumnId = '$ValueId' LIMIT $Limit ALLOW FILTERING") } it should "build a query with a IN clause" in { val query = CassandraQueryProcessor.buildNativeQuery( - TableQN, Array(ColumnId), Array(sources.In(ColumnAge, Array(ValueAge, ValueAge2))), Limit) + TableQN, + Array(ColumnId), + Array(sources.In(ColumnAge, Array(ValueAge, ValueAge2))), + Limit) - query should be(s"SELECT $ColumnId FROM $TableQN WHERE $ColumnAge IN ($ValueAge,$ValueAge2) LIMIT $Limit ALLOW FILTERING") + query should be( + s"SELECT $ColumnId FROM $TableQN WHERE $ColumnAge IN ($ValueAge,$ValueAge2) LIMIT $Limit ALLOW FILTERING") } it should "build a query with a IN clause and a single value" in { - val query = CassandraQueryProcessor.buildNativeQuery( - TableQN, Array(ColumnId), Array(sources.In(ColumnAge, Array(ValueAge))), Limit) - - query should be(s"SELECT $ColumnId FROM $TableQN WHERE $ColumnAge IN ($ValueAge) LIMIT $Limit ALLOW FILTERING") + val query = + CassandraQueryProcessor.buildNativeQuery(TableQN, + Array(ColumnId), + Array(sources.In(ColumnAge, Array(ValueAge))), + Limit) + + query should be( + s"SELECT $ColumnId FROM $TableQN WHERE 
$ColumnAge IN ($ValueAge) LIMIT $Limit ALLOW FILTERING") } it should "build a query with a LT clause " in { - val query = CassandraQueryProcessor.buildNativeQuery( - TableQN, Array(ColumnId), Array(sources.LessThan(ColumnAge, ValueAge)), Limit) - - query should be(s"SELECT $ColumnId FROM $TableQN WHERE $ColumnAge < $ValueAge LIMIT $Limit ALLOW FILTERING") + val query = + CassandraQueryProcessor.buildNativeQuery(TableQN, + Array(ColumnId), + Array(sources.LessThan(ColumnAge, ValueAge)), + Limit) + + query should be( + s"SELECT $ColumnId FROM $TableQN WHERE $ColumnAge < $ValueAge LIMIT $Limit ALLOW FILTERING") } it should "build a query with a LTE clause " in { - val query = CassandraQueryProcessor.buildNativeQuery( - TableQN, Array(ColumnId), Array(sources.LessThanOrEqual(ColumnAge, ValueAge)), Limit) - - query should be(s"SELECT $ColumnId FROM $TableQN WHERE $ColumnAge <= $ValueAge LIMIT $Limit ALLOW FILTERING") + val query = + CassandraQueryProcessor.buildNativeQuery(TableQN, + Array(ColumnId), + Array(sources.LessThanOrEqual(ColumnAge, ValueAge)), + Limit) + + query should be( + s"SELECT $ColumnId FROM $TableQN WHERE $ColumnAge <= $ValueAge LIMIT $Limit ALLOW FILTERING") } it should "build a query with a GT clause " in { - val query = CassandraQueryProcessor.buildNativeQuery( - TableQN, Array(ColumnId), Array(sources.GreaterThan(ColumnAge, ValueAge)), Limit) - - query should be(s"SELECT $ColumnId FROM $TableQN WHERE $ColumnAge > $ValueAge LIMIT $Limit ALLOW FILTERING") + val query = + CassandraQueryProcessor.buildNativeQuery(TableQN, + Array(ColumnId), + Array(sources.GreaterThan(ColumnAge, ValueAge)), + Limit) + + query should be( + s"SELECT $ColumnId FROM $TableQN WHERE $ColumnAge > $ValueAge LIMIT $Limit ALLOW FILTERING") } it should "build a query with a GTE clause " in { val query = CassandraQueryProcessor.buildNativeQuery( - TableQN, Array(ColumnId), Array(sources.GreaterThanOrEqual(ColumnAge, ValueAge)), Limit) + TableQN, + Array(ColumnId), + Array(sources.GreaterThanOrEqual(ColumnAge, ValueAge)), + Limit) - query should be(s"SELECT $ColumnId FROM $TableQN WHERE $ColumnAge >= $ValueAge LIMIT $Limit ALLOW FILTERING") + query should be( + s"SELECT $ColumnId FROM $TableQN WHERE $ColumnAge >= $ValueAge LIMIT $Limit ALLOW FILTERING") } it should "build a query with an AND clause " in { val query = CassandraQueryProcessor.buildNativeQuery( - TableQN, Array(ColumnId), Array(sources.And(sources.GreaterThan(ColumnAge, ValueAge), sources.LessThan(ColumnAge, ValueAge2))), Limit) - - query should be(s"SELECT $ColumnId FROM $TableQN WHERE $ColumnAge > $ValueAge AND $ColumnAge < $ValueAge2 LIMIT $Limit ALLOW FILTERING") + TableQN, + Array(ColumnId), + Array( + sources.And(sources.GreaterThan(ColumnAge, ValueAge), + sources.LessThan(ColumnAge, ValueAge2))), + Limit) + + query should be( + s"SELECT $ColumnId FROM $TableQN WHERE $ColumnAge > $ValueAge AND $ColumnAge < $ValueAge2 LIMIT $Limit ALLOW FILTERING") } it should "built a query with filters calling a pushed-down function" in { val predicate2expectationOp = List( - sources.EqualTo(Function01, ValueId) -> "=", - sources.GreaterThan(Function01, ValueId) -> ">", - sources.LessThan(Function01, ValueId) -> "<", - sources.GreaterThanOrEqual(Function01, ValueId) -> ">=", - sources.LessThanOrEqual(Function01, ValueId) -> "<=" + sources.EqualTo(Function01, ValueId) -> "=", + sources.GreaterThan(Function01, ValueId) -> ">", + sources.LessThan(Function01, ValueId) -> "<", + sources.GreaterThanOrEqual(Function01, ValueId) -> ">=", + 
sources.LessThanOrEqual(Function01, ValueId) -> "<=" ) - for((predicate, operatorStr) <- predicate2expectationOp) { + for ((predicate, operatorStr) <- predicate2expectationOp) { val query = CassandraQueryProcessor.buildNativeQuery( - TableQN, Array(ColumnId), Array(predicate), Limit, udfs + TableQN, + Array(ColumnId), + Array(predicate), + Limit, + udfs ) query should be( - s"SELECT $ColumnId FROM $TableQN WHERE ${getFunctionName(Function01)}($ColumnId) ${operatorStr} '$ValueId' LIMIT $Limit ALLOW FILTERING" + s"SELECT $ColumnId FROM $TableQN WHERE ${getFunctionName(Function01)}($ColumnId) ${operatorStr} '$ValueId' LIMIT $Limit ALLOW FILTERING" ) } } @@ -134,19 +174,23 @@ class CassandraQueryProcessorSpec extends BaseXDTest { it should "build a query selecting a pushed-down function call" in { val function2expectedSel = List( - Function01 -> s"${getFunctionName(Function01)}(id)", - Function02 -> s"${getFunctionName(Function02)}(42)" + Function01 -> s"${getFunctionName(Function01)}(id)", + Function02 -> s"${getFunctionName(Function02)}(42)" ) - for((f, sel) <- function2expectedSel) { + for ((f, sel) <- function2expectedSel) { val query = CassandraQueryProcessor.buildNativeQuery( - TableQN, Array(f.toString), Array.empty, Limit, udfs + TableQN, + Array(f.toString), + Array.empty, + Limit, + udfs ) query should be(s"SELECT $sel FROM $TableQN LIMIT $Limit ALLOW FILTERING") } } - + /* "A CassandraXDSourceRelation" should "support natively a table scan" in { diff --git a/cassandra/src/test/scala/com/stratio/crossdata/connector/cassandra/CassandraTypesIT.scala b/cassandra/src/test/scala/com/stratio/crossdata/connector/cassandra/CassandraTypesIT.scala index a1b5df776..000e9839d 100644 --- a/cassandra/src/test/scala/com/stratio/crossdata/connector/cassandra/CassandraTypesIT.scala +++ b/cassandra/src/test/scala/com/stratio/crossdata/connector/cassandra/CassandraTypesIT.scala @@ -33,13 +33,13 @@ class CassandraTypesIT extends CassandraWithSharedContext with SharedXDContextTy val session = client.get._2 val tableDDL: Seq[String] = - s"CREATE TYPE $Catalog.STRUCT (field1 INT, field2 INT)":: - s"CREATE TYPE $Catalog.STRUCT1 (structField1 VARCHAR, structField2 INT)":: - s"CREATE TYPE $Catalog.STRUCT_DATE (field1 TIMESTAMP, field2 INT)":: - s"CREATE TYPE $Catalog.STRUCT_STRUCT (field1 TIMESTAMP, field2 INT, struct1 frozen)":: - s"CREATE TYPE $Catalog.STRUCT_DATE1 (structField1 TIMESTAMP, structField2 INT)":: - s"CREATE TYPE $Catalog.STRUCT_ARRAY_STRUCT (stringfield VARCHAR, arrayfield LIST>)":: - s""" + s"CREATE TYPE $Catalog.STRUCT (field1 INT, field2 INT)" :: + s"CREATE TYPE $Catalog.STRUCT1 (structField1 VARCHAR, structField2 INT)" :: + s"CREATE TYPE $Catalog.STRUCT_DATE (field1 TIMESTAMP, field2 INT)" :: + s"CREATE TYPE $Catalog.STRUCT_STRUCT (field1 TIMESTAMP, field2 INT, struct1 frozen)" :: + s"CREATE TYPE $Catalog.STRUCT_DATE1 (structField1 TIMESTAMP, structField2 INT)" :: + s"CREATE TYPE $Catalog.STRUCT_ARRAY_STRUCT (stringfield VARCHAR, arrayfield LIST>)" :: + s""" |CREATE TABLE $Catalog.$TypesTable |( | id INT, @@ -71,12 +71,11 @@ class CassandraTypesIT extends CassandraWithSharedContext with SharedXDContextTy | mapstruct MAP>, | arraystructarraystruct LIST> |) - """.stripMargin::Nil + """.stripMargin :: Nil tableDDL.foreach(session.execute) - val dataQuery = - s"""| + val dataQuery = s"""| |INSERT INTO $Catalog.$TypesTable ( | id, int, bigint, long, string, boolean, double, float, decimalInt, decimalLong, | decimalDouble, decimalFloat, date, timestamp, tinyint, smallint, binary, @@ -98,23 
+97,26 @@ class CassandraTypesIT extends CassandraWithSharedContext with SharedXDContextTy } - override protected def typesSet: Seq[SparkSQLColDef] = super.typesSet flatMap { - case SparkSQLColDef(_, "TINYINT", _) | SparkSQLColDef(_, "SMALLINT", _) => Nil - case SparkSQLColDef(name, "DATE", typeChecker) => - SparkSQLColDef(name, "TIMESTAMP", _.isInstanceOf[java.sql.Timestamp])::Nil - case SparkSQLColDef(name, sqlClause, typeChecker) if name contains "struct" => - SparkSQLColDef(name, sqlClause.replace("DATE", "TIMESTAMP"), typeChecker)::Nil - case other => - other::Nil - } + override protected def typesSet: Seq[SparkSQLColDef] = + super.typesSet flatMap { + case SparkSQLColDef(_, "TINYINT", _) | SparkSQLColDef(_, "SMALLINT", _) => + Nil + case SparkSQLColDef(name, "DATE", typeChecker) => + SparkSQLColDef(name, "TIMESTAMP", _.isInstanceOf[java.sql.Timestamp]) :: Nil + case SparkSQLColDef(name, sqlClause, typeChecker) if name contains "struct" => + SparkSQLColDef(name, sqlClause.replace("DATE", "TIMESTAMP"), typeChecker) :: Nil + case other => + other :: Nil + } - override def sparkAdditionalKeyColumns: Seq[SparkSQLColDef] = Seq(SparkSQLColDef("id", "INT")) + override def sparkAdditionalKeyColumns: Seq[SparkSQLColDef] = + Seq(SparkSQLColDef("id", "INT")) override def dataTypesSparkOptions: Map[String, String] = Map( - "table" -> TypesTable, - "keyspace" -> Catalog, - "cluster" -> ClusterName, - "pushdown" -> "true", - "spark_cassandra_connection_host" -> CassandraHost + "table" -> TypesTable, + "keyspace" -> Catalog, + "cluster" -> ClusterName, + "pushdown" -> "true", + "spark_cassandra_connection_host" -> CassandraHost ) //Perform test diff --git a/cassandra/src/test/scala/com/stratio/crossdata/connector/cassandra/CassandraWithSharedContext.scala b/cassandra/src/test/scala/com/stratio/crossdata/connector/cassandra/CassandraWithSharedContext.scala index c3a8e41d8..ce6fe8255 100644 --- a/cassandra/src/test/scala/com/stratio/crossdata/connector/cassandra/CassandraWithSharedContext.scala +++ b/cassandra/src/test/scala/com/stratio/crossdata/connector/cassandra/CassandraWithSharedContext.scala @@ -25,35 +25,36 @@ import org.scalatest.Suite import scala.collection.immutable.ListMap import scala.util.Try -trait CassandraWithSharedContext extends SharedXDContextWithDataTest - with CassandraDefaultTestConstants - with SparkLoggerComponent { +trait CassandraWithSharedContext + extends SharedXDContextWithDataTest + with CassandraDefaultTestConstants + with SparkLoggerComponent { this: Suite => override type ClientParams = (Cluster, Session) override val provider: String = SourceProvider override def defaultOptions = Map( - "table" -> Table, - "keyspace" -> Catalog, - "cluster" -> ClusterName, - "pushdown" -> "false", // TODO replace with pushdown true when c* fix some issues - "spark_cassandra_connection_host" -> CassandraHost + "table" -> Table, + "keyspace" -> Catalog, + "cluster" -> ClusterName, + "pushdown" -> "false", // TODO replace with pushdown true when c* fix some issues + "spark_cassandra_connection_host" -> CassandraHost ) abstract override def saveTestData: Unit = { val session = client.get._2 - def stringifySchema(schema: Map[String, String]): String = schema.map(p => s"${p._1} ${p._2}").mkString(", ") + def stringifySchema(schema: Map[String, String]): String = + schema.map(p => s"${p._1} ${p._2}").mkString(", ") session.execute( - s"""CREATE KEYSPACE $Catalog WITH replication = {'class':'SimpleStrategy', 'replication_factor':1} + s"""CREATE KEYSPACE $Catalog WITH replication = 
{'class':'SimpleStrategy', 'replication_factor':1} |AND durable_writes = true;""".stripMargin.replaceAll("\n", " ")) - session.execute( - s"""CREATE TABLE $Catalog.$Table (${stringifySchema(schema)}, + session.execute(s"""CREATE TABLE $Catalog.$Table (${stringifySchema(schema)}, |PRIMARY KEY (${pk.mkString(", ")}))""".stripMargin.replaceAll("\n", " ")) - if(indexedColumn.nonEmpty){ + if (indexedColumn.nonEmpty) { session.execute(s""" |CREATE CUSTOM INDEX student_index ON $Catalog.$Table (name) |USING 'com.stratio.cassandra.lucene.Index' @@ -70,19 +71,27 @@ trait CassandraWithSharedContext extends SharedXDContextWithDataTest }*/ def insertRow(row: List[Any]): Unit = { - session.execute( - s"""INSERT INTO $Catalog.$Table(${schema.map(p => p._1).mkString(", ")}) + session.execute(s"""INSERT INTO $Catalog.$Table(${schema.map(p => p._1).mkString(", ")}) | VALUES (${parseRow(row)})""".stripMargin.replaceAll("\n", "")) } def parseRow(row: List[Any]): String = { - row map {col => parseElement(col)} mkString ", " + row map { col => + parseElement(col) + } mkString ", " } def parseElement(element: Any): String = { element match { - case map : Map[_,_] => map map { case (key,value) => s"${parseElement(key)} : ${parseElement(value)}" } mkString ("{", ", ", "}") - case list : Seq[_] => list map {listElement => parseElement(listElement)} mkString ("[", ", ", "]") + case map: Map[_, _] => + map map { + case (key, value) => + s"${parseElement(key)} : ${parseElement(value)}" + } mkString ("{", ", ", "}") + case list: Seq[_] => + list map { listElement => + parseElement(listElement) + } mkString ("[", ", ", "]") case string: String => s"'$string'" case other => other.toString } @@ -91,9 +100,8 @@ trait CassandraWithSharedContext extends SharedXDContextWithDataTest testData.foreach(insertRow(_)) //This creates a new table in the keyspace which will not be initially registered at the Spark - if(UnregisteredTable.nonEmpty){ - session.execute( - s"""CREATE TABLE $Catalog.$UnregisteredTable (${stringifySchema(schema)}, + if (UnregisteredTable.nonEmpty) { + session.execute(s"""CREATE TABLE $Catalog.$UnregisteredTable (${stringifySchema(schema)}, |PRIMARY KEY (${pk.mkString(", ")}))""".stripMargin.replaceAll("\n", " ")) } @@ -107,15 +115,18 @@ trait CassandraWithSharedContext extends SharedXDContextWithDataTest cluster.close() } - override protected def cleanTestData: Unit = client.get._2.execute(s"DROP KEYSPACE $Catalog") + override protected def cleanTestData: Unit = + client.get._2.execute(s"DROP KEYSPACE $Catalog") - override protected def prepareClient: Option[ClientParams] = Try { - val cluster = Cluster.builder().addContactPoint(CassandraHost).build() - (cluster, cluster.connect()) - } toOption + override protected def prepareClient: Option[ClientParams] = + Try { + val cluster = Cluster.builder().addContactPoint(CassandraHost).build() + (cluster, cluster.connect()) + } toOption - abstract override def sparkRegisterTableSQL: Seq[SparkTable] = super.sparkRegisterTableSQL :+ - str2sparkTableDesc(s"CREATE TEMPORARY TABLE $Table") + abstract override def sparkRegisterTableSQL: Seq[SparkTable] = + super.sparkRegisterTableSQL :+ + str2sparkTableDesc(s"CREATE TEMPORARY TABLE $Table") override val runningError: String = "Cassandra and Spark must be up and running" @@ -127,7 +138,11 @@ sealed trait CassandraDefaultTestConstants { val Table = "students" val TypesTable = "datatypestablename" val UnregisteredTable = "teachers" - val schema = ListMap("id" -> "int", "age" -> "int", "comment" -> "text", "enrolled" -> 
"boolean", "name" -> "text") + val schema = ListMap("id" -> "int", + "age" -> "int", + "comment" -> "text", + "enrolled" -> "boolean", + "name" -> "text") val pk = "(id)" :: "age" :: "comment" :: Nil val indexedColumn = "name" @@ -139,4 +154,4 @@ sealed trait CassandraDefaultTestConstants { Try(ConfigFactory.load().getStringList("cassandra.hosts")).map(_.get(0)).getOrElse("127.0.0.1") } val SourceProvider = "com.stratio.crossdata.connector.cassandra" -} \ No newline at end of file +} diff --git a/cassandra/src/test/scala/com/stratio/crossdata/connector/cassandra/statements/CreateTableStatementSpec.scala b/cassandra/src/test/scala/com/stratio/crossdata/connector/cassandra/statements/CreateTableStatementSpec.scala index 2ff91ec6b..499607a9a 100644 --- a/cassandra/src/test/scala/com/stratio/crossdata/connector/cassandra/statements/CreateTableStatementSpec.scala +++ b/cassandra/src/test/scala/com/stratio/crossdata/connector/cassandra/statements/CreateTableStatementSpec.scala @@ -24,7 +24,7 @@ import org.junit.runner.RunWith import org.scalatest.junit.JUnitRunner @RunWith(classOf[JUnitRunner]) -class CreateTableStatementSpec extends BaseXDTest { +class CreateTableStatementSpec extends BaseXDTest { val Keyspace = "testKeyspace" val Table = "testTable" @@ -33,9 +33,8 @@ class CreateTableStatementSpec extends BaseXDTest { "A CreateTableStatementSpec" should "Build a simple CreateExternalTableStatement " in { - val schema: StructType = StructType(Seq(idField, nameField)) - val options: Map[String, String] = Map("keyspace" -> Keyspace, "primary_key_string" ->"id") + val options: Map[String, String] = Map("keyspace" -> Keyspace, "primary_key_string" -> "id") val stm = new CreateTableStatement(Table, schema, options) //Experimentation @@ -48,9 +47,9 @@ class CreateTableStatementSpec extends BaseXDTest { it should "Build a CreateExternalTableStatement with a Composed PrimKey" in { - val schema: StructType = StructType(Seq(idField, nameField)) - val options: Map[String, String] = Map("keyspace" -> Keyspace, "primary_key_string" ->"id, name") + val options: Map[String, String] = + Map("keyspace" -> Keyspace, "primary_key_string" -> "id, name") val stm = new CreateTableStatement(Table, schema, options) //Experimentation @@ -58,6 +57,7 @@ class CreateTableStatementSpec extends BaseXDTest { //Expectations print(query) - query should be(s"CREATE TABLE $Keyspace.$Table (id int, name varchar, PRIMARY KEY (id, name))") + query should be( + s"CREATE TABLE $Keyspace.$Table (id int, name varchar, PRIMARY KEY (id, name))") } } diff --git a/common/src/main/scala/com/stratio/crossdata/common/messages.scala b/common/src/main/scala/com/stratio/crossdata/common/messages.scala index 0f4193114..244631fa2 100644 --- a/common/src/main/scala/com/stratio/crossdata/common/messages.scala +++ b/common/src/main/scala/com/stratio/crossdata/common/messages.scala @@ -29,61 +29,58 @@ private[crossdata] trait Command { private[crossdata] val requestId = UUID.randomUUID() } -private[crossdata] case class SQLCommand private(sql: String, - queryId: UUID = UUID.randomUUID(), - flattenResults: Boolean = false, - timeout: Option[FiniteDuration] = None - ) extends Command { +private[crossdata] case class SQLCommand private (sql: String, + queryId: UUID = UUID.randomUUID(), + flattenResults: Boolean = false, + timeout: Option[FiniteDuration] = None) + extends Command { - def this(query: String, - retrieveColNames: Boolean, - timeoutDuration: FiniteDuration - ) = this(sql = query, flattenResults = retrieveColNames, timeout = 
Option(timeoutDuration)) + def this(query: String, retrieveColNames: Boolean, timeoutDuration: FiniteDuration) = + this(sql = query, flattenResults = retrieveColNames, timeout = Option(timeoutDuration)) - def this(query: String, - retrieveColNames: Boolean - ) = this(sql = query, flattenResults = retrieveColNames, timeout = None) + def this(query: String, retrieveColNames: Boolean) = + this(sql = query, flattenResults = retrieveColNames, timeout = None) } - -case class AddJARCommand(path: String, hdfsConfig: Option[Config] = None, - timeout: Option[FiniteDuration] = None, toClassPath:Option[Boolean]= None - ) extends Command { +case class AddJARCommand(path: String, + hdfsConfig: Option[Config] = None, + timeout: Option[FiniteDuration] = None, + toClassPath: Option[Boolean] = None) + extends Command { def this( - jarpath: String, - timeout: FiniteDuration - ) = this(path = jarpath, timeout = Option(timeout)) - + jarpath: String, + timeout: FiniteDuration + ) = this(path = jarpath, timeout = Option(timeout)) def this(jarpath: String) = this(path = jarpath) def this( - jarpath: String, - hdfsConf: Config - ) = this(path = jarpath, hdfsConfig = Option(hdfsConf)) + jarpath: String, + hdfsConf: Config + ) = this(path = jarpath, hdfsConfig = Option(hdfsConf)) - def this (jarpath: String, - toClassPath: Boolean - ) = this(path = jarpath, toClassPath = Option(toClassPath)) + def this(jarpath: String, toClassPath: Boolean) = + this(path = jarpath, toClassPath = Option(toClassPath)) } -case class AddAppCommand(path: String,alias:String,clss:String, - timeout: Option[FiniteDuration] = None - ) extends Command { +case class AddAppCommand(path: String, + alias: String, + clss: String, + timeout: Option[FiniteDuration] = None) + extends Command { def this( - jarpath: String, - alias:String, - clss:String, - timeout: FiniteDuration - ) = this(path = jarpath, alias,clss, timeout = Option(timeout)) + jarpath: String, + alias: String, + clss: String, + timeout: FiniteDuration + ) = this(path = jarpath, alias, clss, timeout = Option(timeout)) def this( - jarpath: String, - alias:String, - clss:String - )= this(jarpath, alias,clss,None) - + jarpath: String, + alias: String, + clss: String + ) = this(jarpath, alias, clss, None) } case class ClusterStateCommand() extends Command @@ -100,7 +97,6 @@ private[crossdata] case class CancelQueryExecution(queryId: UUID) extends Contro private[crossdata] case class CommandEnvelope(cmd: Command, session: Session) - // Server -> Driver messages private[crossdata] trait ServerReply { def requestId: UUID @@ -110,9 +106,11 @@ private[crossdata] case class QueryCancelledReply(requestId: UUID) extends Serve private[crossdata] case class SQLReply(requestId: UUID, sqlResult: SQLResult) extends ServerReply -private[crossdata] case class ClusterStateReply(requestId: UUID, clusterState: CurrentClusterState) extends ServerReply - -private[crossdata] case class OpenSessionReply(requestId: UUID, isOpen: Boolean) extends ServerReply +private[crossdata] case class ClusterStateReply(requestId: UUID, clusterState: CurrentClusterState) + extends ServerReply -private[crossdata] case class AddHdfsFileReply(requestId: UUID, hdfsRoute: String) extends ServerReply +private[crossdata] case class OpenSessionReply(requestId: UUID, isOpen: Boolean) + extends ServerReply +private[crossdata] case class AddHdfsFileReply(requestId: UUID, hdfsRoute: String) + extends ServerReply diff --git a/common/src/main/scala/com/stratio/crossdata/common/package.scala 
b/common/src/main/scala/com/stratio/crossdata/common/package.scala index 54fa4329b..9916ab9c8 100644 --- a/common/src/main/scala/com/stratio/crossdata/common/package.scala +++ b/common/src/main/scala/com/stratio/crossdata/common/package.scala @@ -19,6 +19,7 @@ import scala.io.Source package object crossdata { - lazy val CrossdataVersion = Source.fromInputStream(getClass.getResourceAsStream("/crossdata.version")).mkString + lazy val CrossdataVersion = + Source.fromInputStream(getClass.getResourceAsStream("/crossdata.version")).mkString } diff --git a/common/src/main/scala/com/stratio/crossdata/common/result/responses.scala b/common/src/main/scala/com/stratio/crossdata/common/result/responses.scala index fbeeb3bb3..b758302c0 100644 --- a/common/src/main/scala/com/stratio/crossdata/common/result/responses.scala +++ b/common/src/main/scala/com/stratio/crossdata/common/result/responses.scala @@ -32,11 +32,13 @@ case class SQLResponse(id: UUID, sqlResult: Future[SQLResult]) extends Response } getOrElse ErrorSQLResult(s"Not found answer to query $id. Timeout was exceed.") } - def cancelCommand(): Unit = throw new RuntimeException("The query cannot be cancelled. Use sql(query).cancelCommand") + def cancelCommand(): Unit = + throw new RuntimeException("The query cannot be cancelled. Use sql(query).cancelCommand") } case class QueryCancelledResponse(id: UUID) extends Response object SQLResponse { - implicit def sqlResponseToSQLResult(response: SQLResponse): SQLResult = response.waitForResult() + implicit def sqlResponseToSQLResult(response: SQLResponse): SQLResult = + response.waitForResult() } diff --git a/common/src/main/scala/com/stratio/crossdata/common/result/results.scala b/common/src/main/scala/com/stratio/crossdata/common/result/results.scala index a9918493b..b39486ec6 100644 --- a/common/src/main/scala/com/stratio/crossdata/common/result/results.scala +++ b/common/src/main/scala/com/stratio/crossdata/common/result/results.scala @@ -19,7 +19,6 @@ import org.apache.commons.lang3.StringUtils import org.apache.spark.sql.Row import org.apache.spark.sql.types.StructType - trait Result { def hasError: Boolean } @@ -32,10 +31,10 @@ trait SQLResult extends Result { def schema: StructType /** - * NOTE: This method is based on the method org.apache.spark.sql.DataFrame#showString from Apache Spark. - * For more information, go to http://spark.apache.org. - * Compose the string representing rows for output - */ + * NOTE: This method is based on the method org.apache.spark.sql.DataFrame#showString from Apache Spark. + * For more information, go to http://spark.apache.org. + * Compose the string representing rows for output + */ def prettyResult: Array[String] = { val sb = new StringBuilder @@ -45,10 +44,10 @@ trait SQLResult extends Result { // For cells that are beyond 20 characters, replace it with the first 17 and "..." 
val rows: Seq[Seq[String]] = schema.fieldNames.toSeq +: resultSet.map { row => row.toSeq.map { - case null => "null" - case array: Array[_] => array.mkString("[", ", ", "]") - case seq: Seq[_] => seq.mkString("[", ", ", "]") - case cell => cell.toString + case null => "null" + case array: Array[_] => array.mkString("[", ", ", "]") + case seq: Seq[_] => seq.mkString("[", ", ", "]") + case cell => cell.toString }: Seq[String] } @@ -66,16 +65,18 @@ trait SQLResult extends Result { val sep: String = colWidths.map("-" * _).addString(sb, "+", "+", "+\n").toString() // column names - rows.head.zipWithIndex.map { case (cell, i) => - StringUtils.rightPad(cell, colWidths(i)) + rows.head.zipWithIndex.map { + case (cell, i) => + StringUtils.rightPad(cell, colWidths(i)) }.addString(sb, "|", "|", "|\n") sb.append(sep) // data rows.tail.map { - _.zipWithIndex.map { case (cell, i) => - StringUtils.rightPad(cell.toString, colWidths(i)) + _.zipWithIndex.map { + case (cell, i) => + StringUtils.rightPad(cell.toString, colWidths(i)) }.addString(sb, "|", "|", "|\n") } @@ -93,12 +94,7 @@ case class ErrorSQLResult(message: String, cause: Option[Throwable] = None) exte override lazy val schema = throw mkException private def mkException: Exception = - cause.map(throwable => new RuntimeException(message, throwable)).getOrElse(new RuntimeException(message)) + cause + .map(throwable => new RuntimeException(message, throwable)) + .getOrElse(new RuntimeException(message)) } - - - - - - - diff --git a/common/src/main/scala/com/stratio/crossdata/common/security/Session.scala b/common/src/main/scala/com/stratio/crossdata/common/security/Session.scala index e75615585..217e4117e 100644 --- a/common/src/main/scala/com/stratio/crossdata/common/security/Session.scala +++ b/common/src/main/scala/com/stratio/crossdata/common/security/Session.scala @@ -19,4 +19,4 @@ import java.util.UUID import akka.actor.ActorRef -private [crossdata] case class Session(id: UUID, clientRef: ActorRef) extends Serializable +private[crossdata] case class Session(id: UUID, clientRef: ActorRef) extends Serializable diff --git a/common/src/main/scala/com/stratio/crossdata/common/util/akka/keepalive/KeepAliveMaster.scala b/common/src/main/scala/com/stratio/crossdata/common/util/akka/keepalive/KeepAliveMaster.scala index 7711ad427..dbbde7a59 100644 --- a/common/src/main/scala/com/stratio/crossdata/common/util/akka/keepalive/KeepAliveMaster.scala +++ b/common/src/main/scala/com/stratio/crossdata/common/util/akka/keepalive/KeepAliveMaster.scala @@ -43,7 +43,8 @@ object KeepAliveMaster { */ case class HeartbeatLost[T](id: T) - def props[ID](client: ActorRef): Props = Props(new KeepAliveMaster[ID](client)) + def props[ID](client: ActorRef): Props = + Props(new KeepAliveMaster[ID](client)) } @@ -56,22 +57,21 @@ class KeepAliveMaster[ID](client: ActorRef) extends Actor { def receive(pending: Set[ID]): Receive = { - case HeartBeat(id: ID @ unchecked) => + case HeartBeat(id: ID @unchecked) => context.become(receive(pending - id)) - case m @ DoCheck(id: ID @ unchecked, period, continue) => + case m @ DoCheck(id: ID @unchecked, period, continue) => import context.dispatcher val missing = pending contains id - if(missing) client ! HeartbeatLost(id) + if (missing) client ! 
HeartbeatLost(id) - if(!missing || continue) { + if (!missing || continue) { context.system.scheduler.scheduleOnce(period, self, m) context.become(receive(pending + id)) } } - -} \ No newline at end of file +} diff --git a/common/src/main/scala/org/apache/spark/sql/crossdata/metadata/DataTypesUtils.scala b/common/src/main/scala/org/apache/spark/sql/crossdata/metadata/DataTypesUtils.scala index 3b00a7861..7881887f1 100644 --- a/common/src/main/scala/org/apache/spark/sql/crossdata/metadata/DataTypesUtils.scala +++ b/common/src/main/scala/org/apache/spark/sql/crossdata/metadata/DataTypesUtils.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.crossdata.metadata import org.apache.spark.sql.types.{DataType, DataTypeParser} - object DataTypesUtils { - def toDataType(stringType: String): DataType = DataTypeParser.parse(stringType) + def toDataType(stringType: String): DataType = + DataTypeParser.parse(stringType) } diff --git a/common/src/test/scala/com/stratio/crossdata/common/SQLResultSpec.scala b/common/src/test/scala/com/stratio/crossdata/common/SQLResultSpec.scala index f38ee455a..7213e2773 100644 --- a/common/src/test/scala/com/stratio/crossdata/common/SQLResultSpec.scala +++ b/common/src/test/scala/com/stratio/crossdata/common/SQLResultSpec.scala @@ -28,12 +28,12 @@ import org.scalatest.junit.JUnitRunner import org.scalatest.mock.MockitoSugar @RunWith(classOf[JUnitRunner]) -class SQLResultSpec extends BaseXDTest with MockitoSugar{ +class SQLResultSpec extends BaseXDTest with MockitoSugar { "An error result" should "have an empty result" in { val error = ErrorSQLResult("message") - error.hasError should be (true) - a [RuntimeException] should be thrownBy error.resultSet + error.hasError should be(true) + a[RuntimeException] should be thrownBy error.resultSet } "An SuccessfulQueryResult " should "have a resultSet" in { @@ -51,7 +51,7 @@ class SQLResultSpec extends BaseXDTest with MockitoSugar{ res should not be (null) res should be equals Array(row) - hasError should be (false) + hasError should be(false) } } diff --git a/common/src/test/scala/com/stratio/crossdata/common/util/akka/keepalive/KeepAliveSpec.scala b/common/src/test/scala/com/stratio/crossdata/common/util/akka/keepalive/KeepAliveSpec.scala index 0c7e703cc..89b26139a 100644 --- a/common/src/test/scala/com/stratio/crossdata/common/util/akka/keepalive/KeepAliveSpec.scala +++ b/common/src/test/scala/com/stratio/crossdata/common/util/akka/keepalive/KeepAliveSpec.scala @@ -23,17 +23,15 @@ import org.scalatest.{FlatSpecLike, Matchers} import scala.concurrent.duration._ +class KeepAliveSpec extends TestKit(ActorSystem("KeepAliveSpec")) with FlatSpecLike with Matchers { -class KeepAliveSpec extends TestKit(ActorSystem("KeepAliveSpec")) - with FlatSpecLike with Matchers { - - class MonitoredActor(override val keepAliveId: Int, override val master: ActorRef) extends LiveMan[Int] { + class MonitoredActor(override val keepAliveId: Int, override val master: ActorRef) + extends LiveMan[Int] { override val period: FiniteDuration = 100 milliseconds override def receive: Receive = PartialFunction.empty } - "A LiveMan Actor" should "periodically send HearBeat message providing its id" in { val kaId = 1 @@ -44,8 +42,6 @@ class KeepAliveSpec extends TestKit(ActorSystem("KeepAliveSpec")) system.stop(liveMan) } - - "A Master Actor" should "detect when a LiveManActor stops beating" in { val master: ActorRef = system.actorOf(KeepAliveMaster.props[Int](testActor)) @@ -74,7 +70,7 @@ class KeepAliveSpec extends TestKit(ActorSystem("KeepAliveSpec")) 
system.stop(lastActor) expectMsg(500 milliseconds, HeartbeatLost(lastId)) - + liveMen foreach { case (_, monitoredActor) => system.stop(monitoredActor) } @@ -83,5 +79,4 @@ class KeepAliveSpec extends TestKit(ActorSystem("KeepAliveSpec")) } - } diff --git a/core/src/main/scala/com/stratio/crossdata/connector/SQLLikeUDFQueryProcessorUtils.scala b/core/src/main/scala/com/stratio/crossdata/connector/SQLLikeUDFQueryProcessorUtils.scala index c1da46530..bfa982edc 100644 --- a/core/src/main/scala/com/stratio/crossdata/connector/SQLLikeUDFQueryProcessorUtils.scala +++ b/core/src/main/scala/com/stratio/crossdata/connector/SQLLikeUDFQueryProcessorUtils.scala @@ -29,16 +29,14 @@ object SQLLikeUDFQueryProcessorUtils { } } -trait SQLLikeUDFQueryProcessorUtils { - self: SQLLikeQueryProcessorUtils => - +trait SQLLikeUDFQueryProcessorUtils { self: SQLLikeQueryProcessorUtils => import SQLLikeUDFQueryProcessorUtils.ContextWithUDFs override type ProcessingContext <: ContextWithUDFs override def quoteString(in: Any)(implicit context: ProcessingContext): String = in match { - case s @ (_:String | _: Timestamp) => s"'$s'" + case s @ (_: String | _: Timestamp) => s"'$s'" case a: Attribute => expandAttribute(a.toString) case other => other.toString } @@ -46,11 +44,13 @@ trait SQLLikeUDFQueryProcessorUtils { // UDFs are string references in both filters and projects => lookup in udfsMap def expandAttribute(att: String)(implicit context: ProcessingContext): String = { implicit val udfs = context.asInstanceOf[SQLLikeUDFQueryProcessorUtils#ProcessingContext].udfs - udfs get(att) map { udf => + udfs get (att) map { udf => val actualParams = udf.children.collect { //TODO: Add type checker (maybe not here) - case at: AttributeReference if(udfs contains at.toString) => expandAttribute(at.toString) + case at: AttributeReference if (udfs contains at.toString) => + expandAttribute(at.toString) case at: AttributeReference => at.name - case lit @ Literal(_, DataTypes.StringType) => quoteString(lit.toString) + case lit @ Literal(_, DataTypes.StringType) => + quoteString(lit.toString) case lit: Literal => lit.toString } mkString "," s"${udf.name}($actualParams)" diff --git a/core/src/main/scala/com/stratio/crossdata/connector/interfaces.scala b/core/src/main/scala/com/stratio/crossdata/connector/interfaces.scala index c7ad1a9be..3d5ad4fb0 100644 --- a/core/src/main/scala/com/stratio/crossdata/connector/interfaces.scala +++ b/core/src/main/scala/com/stratio/crossdata/connector/interfaces.scala @@ -25,81 +25,84 @@ import org.apache.spark.sql.types.{DataType, StructType} import scala.util.Try - /** - * A BaseRelation that can execute the whole logical plan without running the query - * on the Spark cluster. If a specific logical plan cannot be resolved by the datasource - * a None should be returned and the process will be executed on Spark. - */ + * A BaseRelation that can execute the whole logical plan without running the query + * on the Spark cluster. If a specific logical plan cannot be resolved by the datasource + * a None should be returned and the process will be executed on Spark. + */ @DeveloperApi trait NativeScan extends PushDownable { def buildScan(optimizedLogicalPlan: LogicalPlan): Option[Array[Row]] } /** - * Interface for asking whether the datasource is able to push down an isolated logical plan. - */ + * Interface for asking whether the datasource is able to push down an isolated logical plan. + */ @DeveloperApi sealed trait PushDownable { + /** - * Checks the ability to execute a [[LogicalPlan]]. 
- * - * @param logicalStep isolated plan - * @param wholeLogicalPlan the whole DataFrame tree - * @return whether the logical step within the entire logical plan is supported - */ + * Checks the ability to execute a [[LogicalPlan]]. + * + * @param logicalStep isolated plan + * @param wholeLogicalPlan the whole DataFrame tree + * @return whether the logical step within the entire logical plan is supported + */ def isSupported(logicalStep: LogicalPlan, wholeLogicalPlan: LogicalPlan): Boolean } - sealed trait GenerateConnectorOptions { import TableInventory.Table /** - * - * @param item Table description case class instance - * @param userOpts Options provided by the parsed sentence - * @return A concrete (for a given connector) translation of the high level table description - * to a low-level option map. - */ - def generateConnectorOpts(item: Table, userOpts: Map[String, String] = Map.empty): Map[String, String] + * + * @param item Table description case class instance + * @param userOpts Options provided by the parsed sentence + * @return A concrete (for a given connector) translation of the high level table description + * to a low-level option map. + */ + def generateConnectorOpts(item: Table, + userOpts: Map[String, String] = Map.empty): Map[String, String] } + /** - * Interface including data source operations for listing and describing tables - * at a data source. - * - */ + * Interface including data source operations for listing and describing tables + * at a data source. + * + */ @DeveloperApi -trait TableInventory extends GenerateConnectorOptions{ +trait TableInventory extends GenerateConnectorOptions { import TableInventory.Table /** - * Overriding this function allows tables import filtering. e.g: Avoiding system tables. - * - * @param table Table description case class instance - * @return `true` if the table shall be imported, `false` otherwise - */ + * Overriding this function allows tables import filtering. e.g: Avoiding system tables. + * + * @param table Table description case class instance + * @return `true` if the table shall be imported, `false` otherwise + */ def exclusionFilter(table: TableInventory.Table): Boolean = true /** - * - * @param context SQLContext at which the command will be executed. - * @param options SQL Sentence user options - * @return A list of tables descriptions extracted from the datasource using a connector.0 - */ + * + * @param context SQLContext at which the command will be executed. + * @param options SQL Sentence user options + * @return A list of tables descriptions extracted from the datasource using a connector.0 + */ def listTables(context: SQLContext, options: Map[String, String]): Seq[Table] } object TableInventory { //Table description - case class Table(tableName: String, database: Option[String] = None, schema: Option[StructType] = None) + case class Table(tableName: String, + database: Option[String] = None, + schema: Option[StructType] = None) } /* Interface for providing lists and UDF discovery services */ -trait FunctionInventory extends DataSourceRegister{ +trait FunctionInventory extends DataSourceRegister { import FunctionInventory.UDF //Get builtin functions manifest @@ -108,16 +111,22 @@ trait FunctionInventory extends DataSourceRegister{ object FunctionInventory { //Native function (either built-in or user defined) description. 
- case class UDF(name: String, database: Option[String] = None, formalParameters: StructType, returnType: DataType) + case class UDF(name: String, + database: Option[String] = None, + formalParameters: StructType, + returnType: DataType) - def qualifyUDF(datasourceName: String, udfName: String) = s"${datasourceName}_$udfName" + def qualifyUDF(datasourceName: String, udfName: String) = + s"${datasourceName}_$udfName" } /** Interface for data sources which are able to execute functions (native or user defined) natively - */ + */ trait NativeFunctionExecutor { - def buildScan(requiredColumns: Array[String], filters: Array[Filter], udfs: Map[String, NativeUDF]): RDD[Row] + def buildScan(requiredColumns: Array[String], + filters: Array[Filter], + udfs: Map[String, NativeUDF]): RDD[Row] } /** @@ -125,8 +134,7 @@ trait NativeFunctionExecutor { * CREATE/DROP EXTERNAL TABLE * */ -trait TableManipulation extends GenerateConnectorOptions{ - +trait TableManipulation extends GenerateConnectorOptions { def createExternalTable(context: SQLContext, tableName: String, @@ -134,6 +142,5 @@ trait TableManipulation extends GenerateConnectorOptions{ schema: StructType, options: Map[String, String]): Option[TableInventory.Table] - def dropExternalTable(context: SQLContext, - options: Map[String, String]): Try[Unit] -} \ No newline at end of file + def dropExternalTable(context: SQLContext, options: Map[String, String]): Try[Unit] +} diff --git a/core/src/main/scala/com/stratio/crossdata/util/HdfsUtils.scala b/core/src/main/scala/com/stratio/crossdata/util/HdfsUtils.scala index 83d8ce37d..2ddb7a1cc 100644 --- a/core/src/main/scala/com/stratio/crossdata/util/HdfsUtils.scala +++ b/core/src/main/scala/com/stratio/crossdata/util/HdfsUtils.scala @@ -27,11 +27,12 @@ import scala.util.Try case class HdfsUtils(dfs: FileSystem, userName: String) { - def getFiles(path: String): Array[FileStatus] = dfs.listStatus(new Path(path)) + def getFiles(path: String): Array[FileStatus] = + dfs.listStatus(new Path(path)) def getFile(filename: String): InputStream = dfs.open(new Path(filename)) - def fileExist(fileName:String): Boolean = dfs.exists(new Path(fileName)) + def fileExist(fileName: String): Boolean = dfs.exists(new Path(fileName)) def delete(path: String): Unit = { dfs.delete(new Path(path), true) @@ -53,7 +54,7 @@ object HdfsUtils extends SLF4JLogging { private final val DefaultFSProperty = "fs.defaultFS" private final val HdfsDefaultPort = 9000 - def apply(user: String, namenode:String): HdfsUtils = { + def apply(user: String, namenode: String): HdfsUtils = { val conf = new Configuration() conf.set(DefaultFSProperty, namenode) log.debug(s"Configuring HDFS with master: ${conf.get(DefaultFSProperty)} and user: $user") @@ -62,8 +63,8 @@ object HdfsUtils extends SLF4JLogging { } def apply(config: Config): HdfsUtils = { - val namenode=config.getString("namenode") + val namenode = config.getString("namenode") val user = config.getString("user") apply(user, namenode) } -} \ No newline at end of file +} diff --git a/core/src/main/scala/com/stratio/crossdata/util/utils.scala b/core/src/main/scala/com/stratio/crossdata/util/utils.scala index c210b11c6..5f9dba911 100644 --- a/core/src/main/scala/com/stratio/crossdata/util/utils.scala +++ b/core/src/main/scala/com/stratio/crossdata/util/utils.scala @@ -17,18 +17,15 @@ package com.stratio.crossdata.util import scala.util.Try - object using { - type AutoClosable = {def close(): Unit} + type AutoClosable = { def close(): Unit } def apply[A <: AutoClosable, B](resource: A)(code: A => B): 
B = try { code(resource) - } - finally { + } finally { Try(resource.close()) } } - diff --git a/core/src/main/scala/org/apache/spark/sql/crossdata/ExecutionType.scala b/core/src/main/scala/org/apache/spark/sql/crossdata/ExecutionType.scala index 1e28de98c..a10d2dae3 100644 --- a/core/src/main/scala/org/apache/spark/sql/crossdata/ExecutionType.scala +++ b/core/src/main/scala/org/apache/spark/sql/crossdata/ExecutionType.scala @@ -15,7 +15,7 @@ */ package org.apache.spark.sql.crossdata -object ExecutionType extends Enumeration{ +object ExecutionType extends Enumeration { type ExecutionType = Value val Default, Spark, Native = Value -} \ No newline at end of file +} diff --git a/core/src/main/scala/org/apache/spark/sql/crossdata/XDContext.scala b/core/src/main/scala/org/apache/spark/sql/crossdata/XDContext.scala index 009d52550..9e28f8d69 100644 --- a/core/src/main/scala/org/apache/spark/sql/crossdata/XDContext.scala +++ b/core/src/main/scala/org/apache/spark/sql/crossdata/XDContext.scala @@ -64,8 +64,10 @@ import scala.util.{Failure, Success, Try} * @param sc A [[SparkContext]]. */ class XDContext protected (@transient val sc: SparkContext, - userConfig: Option[Config] = None, - credentials: Credentials = Credentials()) extends SQLContext(sc) with Logging { + userConfig: Option[Config] = None, + credentials: Credentials = Credentials()) + extends SQLContext(sc) + with Logging { self => def this(sc: SparkContext) = @@ -84,43 +86,43 @@ class XDContext protected (@transient val sc: SparkContext, Config should be changed by a map and implicitly converted into `Config` whenever one of its methods is called. - */ + */ xdConfig = userConfig.fold(config) { userConf => userConf.withFallback(config) } - catalogConfig = Try(xdConfig.getConfig(CoreConfig.CatalogConfigKey)).getOrElse(ConfigFactory.empty()) - - - override protected[sql] lazy val conf: SQLConf = - userConfig.map{ coreConfig => - configToSparkSQL(coreConfig, new SQLConf) - }.getOrElse(new SQLConf) + catalogConfig = + Try(xdConfig.getConfig(CoreConfig.CatalogConfigKey)).getOrElse(ConfigFactory.empty()) + override protected[sql] lazy val conf: SQLConf = userConfig.map { coreConfig => + configToSparkSQL(coreConfig, new SQLConf) + }.getOrElse(new SQLConf) @transient override protected[sql] lazy val catalog: XDCatalog = { - val catalogs: List[XDCatalogCommon] = temporaryCatalog :: externalCatalog :: streamingCatalog.toList - CatalogChain(catalogs:_*)(self) + val catalogs: List[XDCatalogCommon] = temporaryCatalog :: externalCatalog :: streamingCatalog.toList + CatalogChain(catalogs: _*)(self) } @transient protected lazy val temporaryCatalog: XDTemporaryCatalog = new HashmapCatalog(conf) @transient - protected lazy val externalCatalog: XDPersistentCatalog = CatalogUtils.externalCatalog(conf, catalogConfig) + protected lazy val externalCatalog: XDPersistentCatalog = + CatalogUtils.externalCatalog(conf, catalogConfig) @transient - protected lazy val streamingCatalog: Option[XDStreamingCatalog] = CatalogUtils.streamingCatalog(conf, xdConfig) - + protected lazy val streamingCatalog: Option[XDStreamingCatalog] = + CatalogUtils.streamingCatalog(conf, xdConfig) @transient protected[crossdata] lazy val securityManager = { import CoreConfig._ - val securityClass = Try(xdConfig.getString(SecurityClassConfigKey)).getOrElse(DefaultSecurityManager) + val securityClass = + Try(xdConfig.getString(SecurityClassConfigKey)).getOrElse(DefaultSecurityManager) val audit: java.lang.Boolean = { if (xdConfig.hasPath(SecurityAuditConfigKey)) @@ -136,62 +138,56 @@ class 
XDContext protected (@transient val sc: SparkContext, val securityManagerClass = Class.forName(securityClass) val fallbackCredentials = Credentials( - user = credentials.user.orElse(userConfig), - password = credentials.password.orElse(passwordConfig), - sessionId = credentials.sessionId.orElse(sessionIdConfig) + user = credentials.user.orElse(userConfig), + password = credentials.password.orElse(passwordConfig), + sessionId = credentials.sessionId.orElse(sessionIdConfig) ) - val constr: Constructor[_] = securityManagerClass.getConstructor(classOf[Credentials], classOf[Boolean]) + val constr: Constructor[_] = + securityManagerClass.getConstructor(classOf[Credentials], classOf[Boolean]) constr.newInstance(fallbackCredentials, audit).asInstanceOf[SecurityManager] } - @transient override protected[sql] lazy val analyzer: Analyzer = new Analyzer(catalog, functionRegistry, conf) { override val extendedResolutionRules = ResolveAggregateAlias :: ExtractPythonUDFs :: - ExtractNativeUDFs :: - PreInsertCastAndRename :: - Nil + ExtractNativeUDFs :: + PreInsertCastAndRename :: + Nil override val extendedCheckRules = Seq( - PreWriteCheck(catalog) + PreWriteCheck(catalog) ) - val preparationRules = Seq(PrepareAggregateAlias) override lazy val batches: Seq[Batch] = Seq( - Batch("Substitution", fixedPoint, - CTESubstitution, - WindowsSubstitution), - Batch("Preparation", fixedPoint, preparationRules : _*), - Batch("Resolution", fixedPoint, - WrapRelationWithGlobalIndex(catalog) :: - ResolveRelations :: - ResolveReferences :: - ResolveGroupingAnalytics :: - ResolvePivot :: - ResolveUpCast :: - ResolveSortReferences :: - ResolveGenerate :: - ResolveFunctions :: - ResolveAliases :: - ExtractWindowExpressions :: - GlobalAggregates :: - ResolveAggregateFunctions :: - DistinctAggregationRewriter(conf) :: - HiveTypeCoercion.typeCoercionRules ++ - extendedResolutionRules : _*), - Batch("Nondeterministic", Once, - PullOutNondeterministic, - ComputeCurrentTime), - Batch("UDF", Once, - HandleNullInputsForUDF), - Batch("Cleanup", fixedPoint, - CleanupAliases) + Batch("Substitution", fixedPoint, CTESubstitution, WindowsSubstitution), + Batch("Preparation", fixedPoint, preparationRules: _*), + Batch("Resolution", + fixedPoint, + WrapRelationWithGlobalIndex(catalog) :: + ResolveRelations :: + ResolveReferences :: + ResolveGroupingAnalytics :: + ResolvePivot :: + ResolveUpCast :: + ResolveSortReferences :: + ResolveGenerate :: + ResolveFunctions :: + ResolveAliases :: + ExtractWindowExpressions :: + GlobalAggregates :: + ResolveAggregateFunctions :: + DistinctAggregationRewriter(conf) :: + HiveTypeCoercion.typeCoercionRules ++ + extendedResolutionRules: _*), + Batch("Nondeterministic", Once, PullOutNondeterministic, ComputeCurrentTime), + Batch("UDF", Once, HandleNullInputsForUDF), + Batch("Cleanup", fixedPoint, CleanupAliases) ) } @@ -200,7 +196,8 @@ class XDContext protected (@transient val sc: SparkContext, @transient class XDPlanner extends sparkexecution.SparkPlanner(this) with XDStrategies { - override def strategies: Seq[Strategy] = Seq(XDDDLStrategy, ExtendedDataSourceStrategy) ++ super.strategies + override def strategies: Seq[Strategy] = + Seq(XDDDLStrategy, ExtendedDataSourceStrategy) ++ super.strategies } @transient @@ -227,11 +224,12 @@ class XDContext protected (@transient val sc: SparkContext, { //Register built-in UDFs for each provider available. 
import FunctionInventory.qualifyUDF - for {srv <- functionInventoryServices - datasourceName = srv.shortName() - udf <- srv.nativeBuiltinFunctions - } functionRegistry - .registerFunction(qualifyUDF(datasourceName, udf.name), e => NativeUDF(udf.name, udf.returnType, e)) + for { + srv <- functionInventoryServices + datasourceName = srv.shortName() + udf <- srv.nativeBuiltinFunctions + } functionRegistry.registerFunction(qualifyUDF(datasourceName, udf.name), + e => NativeUDF(udf.name, udf.returnType, e)) val gc = new GroupConcat(", ") udf.register("group_concat", gc) @@ -248,20 +246,22 @@ class XDContext protected (@transient val sc: SparkContext, */ def addJar(path: String, toClasspath: Option[Boolean] = None) = { super.addJar(path) - if ((path.toLowerCase.startsWith("hdfs://")) && (toClasspath.getOrElse(true))){ + if ((path.toLowerCase.startsWith("hdfs://")) && (toClasspath.getOrElse(true))) { val hdfsIS: InputStream = HdfsUtils(xdConfig.getConfig(CoreConfig.HdfsKey)).getFile(path) - val file: java.io.File = createFile(hdfsIS, s"${xdConfig.getConfig(CoreConfig.JarsRepo).getString("externalJars")}/${path.split("/").last}") + val file: java.io.File = createFile( + hdfsIS, + s"${xdConfig.getConfig(CoreConfig.JarsRepo).getString("externalJars")}/${path.split("/").last}") addToClasspath(file) - }else if (scala.reflect.io.File(path).exists){ - val file=new java.io.File(path) + } else if (scala.reflect.io.File(path).exists) { + val file = new java.io.File(path) addToClasspath(file) - }else{ + } else { sys.error("File doesn't exist or is not a hdfs file") } } - private def addToClasspath(file:java.io.File): Unit = { + private def addToClasspath(file: java.io.File): Unit = { if (file.exists) { val method: Method = classOf[URLClassLoader].getDeclaredMethod("addURL", classOf[URL]) method.setAccessible(true) @@ -284,11 +284,20 @@ class XDContext protected (@transient val sc: SparkContext, catalog.lookupApp(alias) } - def executeApp(appName: String, arguments: Seq[String], submitOptions: Option[Map[String, String]] = None): Seq[Row] = { + def executeApp(appName: String, + arguments: Seq[String], + submitOptions: Option[Map[String, String]] = None): Seq[Row] = { import scala.concurrent.ExecutionContext.Implicits.global - val crossdataApp = catalog.lookupApp(appName).getOrElse(sys.error(s"There is not any app called $appName")) + val crossdataApp = + catalog.lookupApp(appName).getOrElse(sys.error(s"There is not any app called $appName")) val launcherConfig = xdConfig.getConfig(CoreConfig.LauncherKey) - SparkJobLauncher.getSparkJob(launcherConfig, this.sparkContext.master, crossdataApp.appClass, arguments, crossdataApp.jar, crossdataApp.appAlias, submitOptions) match { + SparkJobLauncher.getSparkJob(launcherConfig, + this.sparkContext.master, + crossdataApp.appClass, + arguments, + crossdataApp.jar, + crossdataApp.appAlias, + submitOptions) match { case Failure(exception) => logError(exception.getMessage, exception) sys.error("Validation error: " + exception.getMessage) @@ -332,7 +341,6 @@ class XDContext protected (@transient val sc: SparkContext, def dropGlobalIndex(indexIdentifier: IndexIdentifier): Unit = catalog.dropIndex(indexIdentifier) - /** * Imports tables from a DataSource in the persistent catalog. 
* @@ -342,7 +350,6 @@ class XDContext protected (@transient val sc: SparkContext, def importTables(datasource: String, opts: Map[String, String]): Unit = ImportTablesUsingWithOptions(datasource, opts).run(this) - /** * Check if there is Connection with the catalog * @@ -351,11 +358,9 @@ class XDContext protected (@transient val sc: SparkContext, def checkCatalogConnection: Boolean = catalog.checkConnectivity - def createDataFrame(rows: Seq[Row], schema: StructType): DataFrame = DataFrame(self, LocalRelation.fromExternalRows(schema.toAttributes, rows)) - XDContext.setLastInstantiatedContext(self) } @@ -376,7 +381,6 @@ object XDContext extends CoreConfig { //This is definitely NOT right and will only work as long a single instance of XDContext exits var catalogConfig: Config = _ //This is definitely NOT right and will only work as long a single instance of XDContext exits - @transient private val INSTANTIATION_LOCK = new Object() /** @@ -391,9 +395,11 @@ object XDContext extends CoreConfig { */ def getOrCreate(sparkContext: SparkContext, userConfig: Option[Config] = None): XDContext = { INSTANTIATION_LOCK.synchronized { - Option(lastInstantiatedContext.get()).filter( - _.getClass == classOf[XDContext] - ).getOrElse(new XDContext(sparkContext, userConfig)) + Option(lastInstantiatedContext.get()) + .filter( + _.getClass == classOf[XDContext] + ) + .getOrElse(new XDContext(sparkContext, userConfig)) } lastInstantiatedContext.get() } @@ -413,4 +419,3 @@ object XDContext extends CoreConfig { } } - diff --git a/core/src/main/scala/org/apache/spark/sql/crossdata/XDDataFrame.scala b/core/src/main/scala/org/apache/spark/sql/crossdata/XDDataFrame.scala index 98a0277d2..ae74e7cc8 100644 --- a/core/src/main/scala/org/apache/spark/sql/crossdata/XDDataFrame.scala +++ b/core/src/main/scala/org/apache/spark/sql/crossdata/XDDataFrame.scala @@ -56,14 +56,14 @@ private[sql] object XDDataFrame { } /** - * Finds a [[org.apache.spark.sql.sources.BaseRelation]] mixing-in [[NativeScan]] supporting native execution. - * - * The logical plan must involve only base relation from the same datasource implementation. For example, - * if there is a join with a [[org.apache.spark.rdd.RDD]] the logical plan cannot be executed natively. - * - * @param optimizedLogicalPlan the logical plan once it has been processed by the parser, analyzer and optimizer. - * @return - */ + * Finds a [[org.apache.spark.sql.sources.BaseRelation]] mixing-in [[NativeScan]] supporting native execution. + * + * The logical plan must involve only base relation from the same datasource implementation. For example, + * if there is a join with a [[org.apache.spark.rdd.RDD]] the logical plan cannot be executed natively. + * + * @param optimizedLogicalPlan the logical plan once it has been processed by the parser, analyzer and optimizer. 
+ * @return + */ def findNativeQueryExecutor(optimizedLogicalPlan: LogicalPlan): Option[NativeScan] = { def allLeafsAreNative(leafs: Seq[LeafNode]): Boolean = { @@ -73,12 +73,16 @@ private[sql] object XDDataFrame { } } - val leafs = optimizedLogicalPlan.collect { case leafNode: LeafNode => leafNode } + val leafs = optimizedLogicalPlan.collect { + case leafNode: LeafNode => leafNode + } if (!allLeafsAreNative(leafs)) { None } else { - val nativeExecutors: Seq[NativeScan] = leafs.map { case LogicalRelation(ns: NativeScan, _) => ns } + val nativeExecutors: Seq[NativeScan] = leafs.map { + case LogicalRelation(ns: NativeScan, _) => ns + } nativeExecutors match { case seq if seq.length == 1 => @@ -86,8 +90,8 @@ private[sql] object XDDataFrame { case _ => if (nativeExecutors.sliding(2).forall { tuple => - tuple.head.getClass == tuple.head.getClass - }) { + tuple.head.getClass == tuple.head.getClass + }) { nativeExecutors.headOption } else { None @@ -99,11 +103,12 @@ private[sql] object XDDataFrame { } /** - * Extends a [[DataFrame]] to provide native access to datasources when performing Spark actions. - */ -class XDDataFrame private[sql](@transient override val sqlContext: SQLContext, - @transient override val queryExecution: QueryExecution) - extends DataFrame(sqlContext, queryExecution) with SparkLoggerComponent { + * Extends a [[DataFrame]] to provide native access to datasources when performing Spark actions. + */ +class XDDataFrame private[sql] (@transient override val sqlContext: SQLContext, + @transient override val queryExecution: QueryExecution) + extends DataFrame(sqlContext, queryExecution) + with SparkLoggerComponent { def this(sqlContext: SQLContext, logicalPlan: LogicalPlan) = { this(sqlContext, { @@ -112,21 +117,21 @@ class XDDataFrame private[sql](@transient override val sqlContext: SQLContext, qe.assertAnalyzed() // This should force analysis and throw errors if there are any } qe - } - ) + }) } /** - * @inheritdoc - */ + * @inheritdoc + */ override def collect(): Array[Row] = { sqlContext.asInstanceOf[XDContext].securityManager.authorize(logicalPlan) // If cache doesn't go through native if (sqlContext.cacheManager.lookupCachedData(this).nonEmpty) { super.collect() } else { - val nativeQueryExecutor: Option[NativeScan] = findNativeQueryExecutor(queryExecution.optimizedPlan) - if(nativeQueryExecutor.isEmpty){ + val nativeQueryExecutor: Option[NativeScan] = findNativeQueryExecutor( + queryExecution.optimizedPlan) + if (nativeQueryExecutor.isEmpty) { logInfo(s"Spark Query: ${queryExecution.simpleString}") } else { logInfo(s"Native query: ${queryExecution.simpleString}") @@ -137,16 +142,17 @@ class XDDataFrame private[sql](@transient override val sqlContext: SQLContext, def flattenedCollect(): Array[Row] = { - def flattenProjectedColumns(exp: Expression, prev: List[String] = Nil): (List[String], Boolean) = exp match { - case GetStructField(child, _, Some(fieldName)) => + def flattenProjectedColumns(exp: Expression, + prev: List[String] = Nil): (List[String], Boolean) = exp match { + case GetStructField(child, _, Some(fieldName)) => flattenProjectedColumns(child, fieldName :: prev) - case GetArrayStructFields(child, field,_,_,_)=> + case GetArrayStructFields(child, field, _, _, _) => flattenProjectedColumns(child, field.name :: prev) case AttributeReference(name, _, _, _) => (name :: prev, false) case Alias(child @ GetStructField(_, _, Some(fname)), name) if fname == name => flattenProjectedColumns(child) - case Alias(child @ GetArrayStructFields(childArray, field,_,_,_), name) => + 
case Alias(child @ GetArrayStructFields(childArray, field, _, _, _), name) => flattenProjectedColumns(child) case Alias(child, name) => List(name) -> true @@ -154,27 +160,32 @@ class XDDataFrame private[sql](@transient override val sqlContext: SQLContext, } def flatRows( - rows: Seq[Row], - firstLevelNames: Seq[(Seq[String], Boolean)] = Seq.empty - ): Seq[Row] = { + rows: Seq[Row], + firstLevelNames: Seq[(Seq[String], Boolean)] = Seq.empty + ): Seq[Row] = { - def baseName(parentName: String): String = parentName.headOption.map(_ => s"$parentName.").getOrElse("") + def baseName(parentName: String): String = + parentName.headOption.map(_ => s"$parentName.").getOrElse("") def flatRow( - row: GenericRowWithSchema, - parentsNamesAndAlias: Seq[(String, Boolean)] = Seq.empty): Array[(StructField, Any)] = { - (row.schema.fields zip row.values zipAll(parentsNamesAndAlias, null, "" -> false)) flatMap { + row: GenericRowWithSchema, + parentsNamesAndAlias: Seq[(String, Boolean)] = Seq.empty): Array[(StructField, Any)] = { + (row.schema.fields zip row.values zipAll (parentsNamesAndAlias, null, "" -> false)) flatMap { case (null, _) => Seq.empty case ((StructField(_, t, nable, mdata), vobject), (name, true)) => Seq((StructField(name, t, nable, mdata), vobject)) - case ((StructField(name, StructType(_), _, _), col: GenericRowWithSchema), (parentName, false)) => + case ((StructField(name, StructType(_), _, _), col: GenericRowWithSchema), + (parentName, false)) => flatRow(col, Seq.fill(col.schema.size)(s"${baseName(parentName)}$name" -> false)) case ((StructField(name, dtype, nullable, meta), vobject), (parentName, false)) => Seq((StructField(s"${baseName(parentName)}$name", dtype, nullable, meta), vobject)) } } - require(firstLevelNames.isEmpty ||rows.isEmpty || firstLevelNames.size == rows.headOption.map(_.length).getOrElse(0)) + require( + firstLevelNames.isEmpty || rows.isEmpty || firstLevelNames.size == rows.headOption + .map(_.length) + .getOrElse(0)) val thisLevelNames = firstLevelNames.map { case (nameseq, true) => (nameseq.headOption.getOrElse(""), true) case (nameseq, false) => (nameseq.init mkString ".", false) @@ -183,8 +194,9 @@ class XDDataFrame private[sql](@transient override val sqlContext: SQLContext, rows map { case row: GenericRowWithSchema => val newFieldsArray = flatRow(row, thisLevelNames) - val horizontallyFlattened: Row = new GenericRowWithSchema( - newFieldsArray.map(_._2), StructType(newFieldsArray.map(_._1))) + val horizontallyFlattened: Row = + new GenericRowWithSchema(newFieldsArray.map(_._2), + StructType(newFieldsArray.map(_._1))) horizontallyFlattened case row: Row => row @@ -193,116 +205,129 @@ class XDDataFrame private[sql](@transient override val sqlContext: SQLContext, def verticallyFlatRowArrays(row: GenericRowWithSchema)(limit: Int): Seq[GenericRowWithSchema] = { - def cartesian[T](ls: Seq[Seq[T]]): Seq[Seq[T]] = (ls :\ Seq(Seq.empty[T])) { - case (cur: Seq[T], prev) => for(x <- prev; y <- cur) yield y +: x - } + def cartesian[T](ls: Seq[Seq[T]]): Seq[Seq[T]] = + (ls :\ Seq(Seq.empty[T])) { + case (cur: Seq[T], prev) => for (x <- prev; y <- cur) yield y +: x + } val newSchema = StructType( - row.schema map { - case StructField(name, ArrayType(etype, _), nullable, meta) => - StructField(name, etype, true) - case other => other - } + row.schema map { + case StructField(name, ArrayType(etype, _), nullable, meta) => + StructField(name, etype, true) + case other => other + } ) val elementsWithIndex = row.values zipWithIndex val arrayColumnValues: Seq[Seq[(Int, _)]] = 
elementsWithIndex collect { - case (res: Seq[_], idx) => res map(idx -> _) + case (res: Seq[_], idx) => res map (idx -> _) } - cartesian(arrayColumnValues).take(limit) map { case replacements: Seq[(Int, _) @unchecked] => - val idx2newVal: Map[Int, Any] = replacements.toMap - val values = elementsWithIndex map { case (prevVal, idx: Int) => - idx2newVal.getOrElse(idx, prevVal) - } - new GenericRowWithSchema(values, newSchema) + cartesian(arrayColumnValues).take(limit) map { + case replacements: Seq[(Int, _) @unchecked] => + val idx2newVal: Map[Int, Any] = replacements.toMap + val values = elementsWithIndex map { + case (prevVal, idx: Int) => + idx2newVal.getOrElse(idx, prevVal) + } + new GenericRowWithSchema(values, newSchema) } } import WithTrackerFlatMapSeq._ def iterativeFlatten( - rows: Seq[Row], - firstLevelNames: Seq[(Seq[String], Boolean)] = Seq.empty - )(limit: Int = Int.MaxValue): Seq[Row] = + rows: Seq[Row], + firstLevelNames: Seq[(Seq[String], Boolean)] = Seq.empty + )(limit: Int = Int.MaxValue): Seq[Row] = flatRows(rows, firstLevelNames) withTrackerFlatMap { - case (_, Some(currentSize)) if(currentSize >= limit) => Seq() + case (_, Some(currentSize)) if (currentSize >= limit) => Seq() case (row: GenericRowWithSchema, currentSize) => row.schema collectFirst { case StructField(_, _: ArrayType, _, _) => - val newLimit = limit-currentSize.getOrElse(0) + val newLimit = limit - currentSize.getOrElse(0) iterativeFlatten(verticallyFlatRowArrays(row)(newLimit))(newLimit) } getOrElse Seq(row) case (row: Row, _) => Seq(row) } - def processProjection(plist: Seq[NamedExpression], child: LogicalPlan, limit: Int = Int.MaxValue): Array[Row] = { + def processProjection(plist: Seq[NamedExpression], + child: LogicalPlan, + limit: Int = Int.MaxValue): Array[Row] = { val fullyAnnotatedRequestedColumns = plist map (flattenProjectedColumns(_)) iterativeFlatten(collect(), fullyAnnotatedRequestedColumns)(limit) toArray } queryExecution.optimizedPlan match { - case Limit(lexp, Project(plist, child)) => processProjection(plist, child, lexp.toString().toInt) + case Limit(lexp, Project(plist, child)) => + processProjection(plist, child, lexp.toString().toInt) case Project(plist, child) => processProjection(plist, child) - case Limit(lexp, _) => iterativeFlatten(collect())(lexp.toString().toInt) toArray + case Limit(lexp, _) => + iterativeFlatten(collect())(lexp.toString().toInt) toArray case _ => iterativeFlatten(collect())() toArray } } /** - * Collect using an specific [[ExecutionType]]. Only for testing purpose so far. - * When using the Security Manager, this method has to be invoked with the parameter [[ExecutionType.Default]] - * in order to ensure that the workflow of the execution reaches the point where the authorization is called. - * - * @param executionType one of the [[ExecutionType]] - * @return the query result - */ + * Collect using an specific [[ExecutionType]]. Only for testing purpose so far. + * When using the Security Manager, this method has to be invoked with the parameter [[ExecutionType.Default]] + * in order to ensure that the workflow of the execution reaches the point where the authorization is called. 
+ * + * @param executionType one of the [[ExecutionType]] + * @return the query result + */ @DeveloperApi def collect(executionType: ExecutionType): Array[Row] = executionType match { case Default => collect() case Spark => super.collect() case Native => - val result = findNativeQueryExecutor(queryExecution.optimizedPlan).flatMap(executeNativeQuery) + val result = + findNativeQueryExecutor(queryExecution.optimizedPlan).flatMap(executeNativeQuery) result.getOrElse(throw new NativeExecutionException) } - /** - * @inheritdoc - */ - override def collectAsList(): java.util.List[Row] = java.util.Arrays.asList(collect(): _*) + * @inheritdoc + */ + override def collectAsList(): java.util.List[Row] = + java.util.Arrays.asList(collect(): _*) /** - * @inheritdoc - */ - override def limit(n: Int): DataFrame = XDDataFrame(sqlContext, Limit(Literal(n), logicalPlan)) + * @inheritdoc + */ + override def limit(n: Int): DataFrame = + XDDataFrame(sqlContext, Limit(Literal(n), logicalPlan)) /** - * @inheritdoc - */ + * @inheritdoc + */ override def count(): Long = { val aggregateExpr = Seq(Alias(Count(Literal(1)).toAggregateExpression(), "count")()) - XDDataFrame(sqlContext, Aggregate(Seq.empty, aggregateExpr, logicalPlan)).collect().head.getLong(0) + XDDataFrame(sqlContext, Aggregate(Seq.empty, aggregateExpr, logicalPlan)) + .collect() + .head + .getLong(0) } - /** - * Executes the logical plan. - * - * @param provider [[org.apache.spark.sql.sources.BaseRelation]] mixing-in [[NativeScan]] - * @return an array that contains all of [[Row]]s in this [[XDDataFrame]] - * or None if the provider cannot resolve the entire [[XDDataFrame]] natively. - */ + * Executes the logical plan. + * + * @param provider [[org.apache.spark.sql.sources.BaseRelation]] mixing-in [[NativeScan]] + * @return an array that contains all of [[Row]]s in this [[XDDataFrame]] + * or None if the provider cannot resolve the entire [[XDDataFrame]] natively. 
+ */ private[this] def executeNativeQuery(provider: NativeScan): Option[Array[Row]] = { val containsSubfields = notSupportedProject(queryExecution.optimizedPlan) - val planSupported = !containsSubfields && queryExecution.optimizedPlan.map(lp => lp).forall(provider.isSupported(_, queryExecution.optimizedPlan)) - if(planSupported) { + val planSupported = !containsSubfields && queryExecution.optimizedPlan + .map(lp => lp) + .forall(provider.isSupported(_, queryExecution.optimizedPlan)) + if (planSupported) { // TODO handle failed executions which are currently wrapped within the option, so these jobs will appear duplicated // TODO the plan should notice the native execution - withNewExecutionId{ + withNewExecutionId { provider.buildScan(queryExecution.optimizedPlan) } } else @@ -313,32 +338,40 @@ class XDDataFrame private[sql](@transient override val sqlContext: SQLContext, private[this] def notSupportedProject(optimizedLogicalPlan: LogicalPlan): Boolean = { optimizedLogicalPlan collectFirst { - case a@Project(seq, _) if seq.collectFirst { case b: GetMapValue => b }.isDefined => a - case a@Project(seq, _) if seq.collectFirst { case b: GetStructField => b }.isDefined => a - case a@Project(seq, _) if seq.collectFirst { case Alias(b: GetMapValue, _) => a }.isDefined => a - case a@Project(seq, _) if seq.collectFirst { case Alias(b: GetStructField, _) => a }.isDefined => a + case a @ Project(seq, _) if seq.collectFirst { case b: GetMapValue => b }.isDefined => + a + case a @ Project(seq, _) if seq.collectFirst { + case b: GetStructField => b + }.isDefined => + a + case a @ Project(seq, _) if seq.collectFirst { + case Alias(b: GetMapValue, _) => a + }.isDefined => + a + case a @ Project(seq, _) if seq.collectFirst { + case Alias(b: GetStructField, _) => a + }.isDefined => + a } isDefined } - - //TODO: Move to a common library private[crossdata] object WithTrackerFlatMapSeq { - implicit def seq2superflatmapseq[T](s: Seq[T]): WithTrackerFlatMapSeq[T] = new WithTrackerFlatMapSeq(s) + implicit def seq2superflatmapseq[T](s: Seq[T]): WithTrackerFlatMapSeq[T] = + new WithTrackerFlatMapSeq(s) } //TODO: Move to a common library - private[crossdata] class WithTrackerFlatMapSeq[T] private(val s: Seq[T]) - extends scala.collection.immutable.Seq[T] { + private[crossdata] class WithTrackerFlatMapSeq[T] private (val s: Seq[T]) + extends scala.collection.immutable.Seq[T] { override def length: Int = s.length override def apply(idx: Int): T = s(idx) override def iterator: Iterator[T] = s.iterator - def withTrackerFlatMap[B, That]( - f: (T, Option[Int]) => GenTraversableOnce[B] - )(implicit bf: CanBuildFrom[immutable.Seq[T], B, That]): That = { - def builder : mutable.Builder[B, That] = bf(repr) + f: (T, Option[Int]) => GenTraversableOnce[B] + )(implicit bf: CanBuildFrom[immutable.Seq[T], B, That]): That = { + def builder: mutable.Builder[B, That] = bf(repr) val b = builder val builderAsBufferLike = b match { case bufferl: BufferLike[_, _] => Some(bufferl) diff --git a/core/src/main/scala/org/apache/spark/sql/crossdata/XDSQLConf.scala b/core/src/main/scala/org/apache/spark/sql/crossdata/XDSQLConf.scala index d3aa0ea3a..909f7d417 100644 --- a/core/src/main/scala/org/apache/spark/sql/crossdata/XDSQLConf.scala +++ b/core/src/main/scala/org/apache/spark/sql/crossdata/XDSQLConf.scala @@ -21,7 +21,6 @@ trait XDSQLConf extends SQLConf { def enableCacheInvalidation(enable: Boolean): XDSQLConf } - object XDSQLConf { implicit def fromSQLConf(conf: SQLConf): XDSQLConf = new XDSQLConf { @@ -30,4 +29,3 @@ object XDSQLConf { 
override protected[spark] val settings: java.util.Map[String, String] = conf.settings } } - diff --git a/core/src/main/scala/org/apache/spark/sql/crossdata/XDSession.scala b/core/src/main/scala/org/apache/spark/sql/crossdata/XDSession.scala index 4d8a0ce64..ce7be4d09 100644 --- a/core/src/main/scala/org/apache/spark/sql/crossdata/XDSession.scala +++ b/core/src/main/scala/org/apache/spark/sql/crossdata/XDSession.scala @@ -22,25 +22,24 @@ import org.apache.spark.sql.crossdata.catalog.interfaces.XDCatalogCommon import org.apache.spark.sql.crossdata.catalog.{CatalogChain, XDCatalog} import org.apache.spark.sql.crossdata.session.{XDSessionState, XDSharedState} - -object XDSession{ +object XDSession { // TODO Spark2.0. It will be the main entryPoint, so we should add a XDSession builder to make it easier to work with. } /** - * - * [[XDSession]], as with Spark 2.0, SparkSession will be the Crossdata entry point for SQL interfaces. It wraps and - * implements [[XDContext]]. Overriding those methods & attributes which vary among sessions and keeping - * common ones in the delegated [[XDContext]]. - * - * Resource initialization is avoided through attribute initialization laziness. - */ + * + * [[XDSession]], as with Spark 2.0, SparkSession will be the Crossdata entry point for SQL interfaces. It wraps and + * implements [[XDContext]]. Overriding those methods & attributes which vary among sessions and keeping + * common ones in the delegated [[XDContext]]. + * + * Resource initialization is avoided through attribute initialization laziness. + */ class XDSession( - xdSharedState: XDSharedState, - xdSessionState: XDSessionState, - userConfig: Option[Config] = None - ) - extends XDContext(xdSharedState.sc) with Logging { + xdSharedState: XDSharedState, + xdSessionState: XDSessionState, + userConfig: Option[Config] = None +) extends XDContext(xdSharedState.sc) + with Logging { override protected[sql] lazy val catalog: XDCatalog = { val catalogs: Seq[XDCatalogCommon] = (xdSessionState.temporaryCatalogs :+ xdSharedState.externalCatalog) ++ xdSharedState.streamingCatalog.toSeq @@ -48,9 +47,9 @@ class XDSession( } - override protected[sql] lazy val conf: SQLConf = xdSessionState.sqlConf.enableCacheInvalidation(false) + override protected[sql] lazy val conf: SQLConf = + xdSessionState.sqlConf.enableCacheInvalidation(false) xdSessionState.sqlConf.enableCacheInvalidation(true) } - diff --git a/core/src/main/scala/org/apache/spark/sql/crossdata/catalog/CatalogChain.scala b/core/src/main/scala/org/apache/spark/sql/crossdata/catalog/CatalogChain.scala index 91548545e..7176ef19a 100644 --- a/core/src/main/scala/org/apache/spark/sql/crossdata/catalog/CatalogChain.scala +++ b/core/src/main/scala/org/apache/spark/sql/crossdata/catalog/CatalogChain.scala @@ -25,16 +25,21 @@ import org.apache.spark.sql.crossdata.models.{EphemeralQueryModel, EphemeralStat import scala.util.Try - object CatalogChain { def apply(catalogs: XDCatalogCommon*)(implicit xdContext: XDContext): CatalogChain = { - val temporaryCatalogs = catalogs.collect { case a: XDTemporaryCatalog => a } - val persistentCatalogs = catalogs.collect { case a: XDPersistentCatalog => a } - val streamingCatalogs = catalogs.collect { case a: XDStreamingCatalog => a } + val temporaryCatalogs = catalogs.collect { + case a: XDTemporaryCatalog => a + } + val persistentCatalogs = catalogs.collect { + case a: XDPersistentCatalog => a + } + val streamingCatalogs = catalogs.collect { + case a: XDStreamingCatalog => a + } require(streamingCatalogs.length <= 1, "Only one 
streaming catalog can be included") require( - temporaryCatalogs.headOption.orElse(persistentCatalogs.headOption).isDefined, - "At least one catalog (temporary or persistent ) must be included" + temporaryCatalogs.headOption.orElse(persistentCatalogs.headOption).isDefined, + "At least one catalog (temporary or persistent ) must be included" ) new CatalogChain(temporaryCatalogs, persistentCatalogs, streamingCatalogs.headOption) } @@ -44,10 +49,12 @@ object CatalogChain { Write through (always true for this class)-> Each write is synchronously done to all catalogs in the chain No-Write allocate (always true) -> A miss at levels 0...i-1,i isn't written to these levels when found at level i+1 */ -private[crossdata] class CatalogChain private(val temporaryCatalogs: Seq[XDTemporaryCatalog], - val persistentCatalogs: Seq[XDPersistentCatalog], - val streamingCatalogs: Option[XDStreamingCatalog] - )(implicit val xdContext: XDContext) extends XDCatalog with SparkLoggerComponent { +private[crossdata] class CatalogChain private ( + val temporaryCatalogs: Seq[XDTemporaryCatalog], + val persistentCatalogs: Seq[XDPersistentCatalog], + val streamingCatalogs: Option[XDStreamingCatalog])(implicit val xdContext: XDContext) + extends XDCatalog + with SparkLoggerComponent { import XDCatalogCommon._ @@ -55,8 +62,8 @@ private[crossdata] class CatalogChain private(val temporaryCatalogs: Seq[XDTempo private val catalogs: Seq[XDCatalogCommon] = temporaryCatalogs ++: persistentCatalogs ++: streamingCatalogs.toSeq - - private implicit def crossdataTable2tableIdentifier(xdTable: CrossdataTable): TableIdentifierNormalized = + private implicit def crossdataTable2tableIdentifier( + xdTable: CrossdataTable): TableIdentifierNormalized = xdTable.tableIdentifier private def normalize(tableIdentifier: TableIdentifier): TableIdentifierNormalized = @@ -69,7 +76,8 @@ private[crossdata] class CatalogChain private(val temporaryCatalogs: Seq[XDTempo * Apply the lookup function to each underlying catalog until a [[LogicalPlan]] is found. If the table is found in a * temporary catalog, the relation is saved into the previous temporary catalogs. */ - private def chainedLookup(lookup: XDCatalogCommon => Option[LogicalPlan], tableIdentifier: TableIdentifier): Option[LogicalPlan] = { + private def chainedLookup(lookup: XDCatalogCommon => Option[LogicalPlan], + tableIdentifier: TableIdentifier): Option[LogicalPlan] = { val (relationOpt, previousCatalogs) = takeUntilRelationFound(lookup, temporaryCatalogs) if (relationOpt.isDefined) { @@ -83,7 +91,6 @@ private[crossdata] class CatalogChain private(val temporaryCatalogs: Seq[XDTempo } - /** * Apply the lookup function to each temporary catalog until a relation [[R]] is found. Returns the list of catalogs, * until a catalog satisfy the predicate 'lookup'. 
@@ -92,8 +99,9 @@ private[crossdata] class CatalogChain private(val temporaryCatalogs: Seq[XDTempo * @param tempCatalogs a seq of temporary catalogs * @return a tuple (optionalRelation, previousNonMatchingLookupCatalogs) */ - private def takeUntilRelationFound[R](lookup: XDCatalogCommon => Option[R], tempCatalogs: Seq[XDTemporaryCatalog]): - (Option[R], Seq[XDTemporaryCatalog]) = { + private def takeUntilRelationFound[R]( + lookup: XDCatalogCommon => Option[R], + tempCatalogs: Seq[XDTemporaryCatalog]): (Option[R], Seq[XDTemporaryCatalog]) = { val (res: Option[R], idx: Int) = (tempCatalogs.view map (lookup) zipWithIndex) collectFirst { case e @ (Some(_), _) => e @@ -102,20 +110,23 @@ private[crossdata] class CatalogChain private(val temporaryCatalogs: Seq[XDTempo (res, tempCatalogs.take(idx)) } - private def persistentChainedLookup[R](lookup: XDPersistentCatalog => Option[R]): Option[R] = persistentCatalogs.view map lookup collectFirst { case Some(res) => res } /** - * TemporaryCatalog - */ - override def registerView(viewIdentifier: ViewIdentifier, logicalPlan: LogicalPlan, sql: Option[String]): Unit = + * TemporaryCatalog + */ + override def registerView(viewIdentifier: ViewIdentifier, + logicalPlan: LogicalPlan, + sql: Option[String]): Unit = temporaryCatalogs.foreach(_.saveView(normalize(viewIdentifier), logicalPlan, sql)) // TODO throw an exception if there is no temp catalogs! Review CatalogChain - override def registerTable(tableIdent: TableIdentifier, plan: LogicalPlan, crossdataTable: Option[CrossdataTable]): Unit = + override def registerTable(tableIdent: TableIdentifier, + plan: LogicalPlan, + crossdataTable: Option[CrossdataTable]): Unit = temporaryCatalogs.foreach(_.saveTable(normalize(tableIdent), plan, crossdataTable)) override def unregisterView(viewIdentifier: ViewIdentifier): Unit = @@ -127,16 +138,16 @@ private[crossdata] class CatalogChain private(val temporaryCatalogs: Seq[XDTempo override def unregisterAllTables(): Unit = temporaryCatalogs.foreach(_.dropAllTables()) - /** - * CommonCatalog - */ - + * CommonCatalog + */ private def lookupRelationOpt(tableIdent: TableIdentifier): Option[LogicalPlan] = chainedLookup(_.relation(normalize(tableIdent)), tableIdent) override def lookupRelation(tableIdent: TableIdentifier, alias: Option[String]): LogicalPlan = - lookupRelationOpt(tableIdent) map { processAlias(tableIdent, _, alias)(conf)} getOrElse { + lookupRelationOpt(tableIdent) map { + processAlias(tableIdent, _, alias)(conf) + } getOrElse { log.debug(s"Relation not found: ${tableIdent.unquotedString}") sys.error(s"Relation not found: ${tableIdent.unquotedString}") } @@ -146,39 +157,47 @@ private[crossdata] class CatalogChain private(val temporaryCatalogs: Seq[XDTempo // TODO streaming tables override def getTables(databaseName: Option[String]): Seq[(String, Boolean)] = { - def getRelations(catalogSeq: Seq[XDCatalogCommon], isTemporary: Boolean): Seq[(String, Boolean)] = { + def getRelations(catalogSeq: Seq[XDCatalogCommon], + isTemporary: Boolean): Seq[(String, Boolean)] = { catalogSeq.flatMap { cat => - cat.allRelations(databaseName.map( dbn => StringNormalized(XDCatalogCommon.normalizeIdentifier(dbn, conf)))).map(stringifyTableIdentifierNormalized(_) -> isTemporary) + cat + .allRelations(databaseName.map(dbn => + StringNormalized(XDCatalogCommon.normalizeIdentifier(dbn, conf)))) + .map(stringifyTableIdentifierNormalized(_) -> isTemporary) } } - getRelations(temporaryCatalogs, isTemporary = true) ++ getRelations(persistentCatalogs, isTemporary = false) + 
getRelations(temporaryCatalogs, isTemporary = true) ++ getRelations(persistentCatalogs, + isTemporary = false) } /** - * Check the connection to the set Catalog - */ + * Check the connection to the set Catalog + */ override def checkConnectivity: Boolean = catalogs.forall(_.isAvailable) /** - * ExternalCatalog - */ - + * ExternalCatalog + */ override def persistTable(crossdataTable: CrossdataTable, table: LogicalPlan): Unit = persistentCatalogs.foreach(_.saveTable(crossdataTable, table)) - override def persistView(viewIdentifier: ViewIdentifier, plan: LogicalPlan, sqlText: String): Unit = + override def persistView(viewIdentifier: ViewIdentifier, + plan: LogicalPlan, + sqlText: String): Unit = persistentCatalogs.foreach(_.saveView(normalize(viewIdentifier), plan, sqlText)) override def persistIndex(crossdataIndex: CrossdataIndex): Unit = if (tableMetadata(crossdataIndex.tableIdentifier.toTableIdentifier).isEmpty) { - throw new RuntimeException(s"Cannot create the index. Table ${crossdataIndex.tableIdentifier} doesn't exist or is temporary") + throw new RuntimeException( + s"Cannot create the index. Table ${crossdataIndex.tableIdentifier} doesn't exist or is temporary") } else { persistentCatalogs.foreach(_.saveIndex(crossdataIndex)) } override def dropTable(tableIdentifier: TableIdentifier): Unit = { val strTable = tableIdentifier.unquotedString - if (!tableExists(tableIdentifier)) throw new RuntimeException(s"Table $strTable can't be deleted because it doesn't exist") + if (!tableExists(tableIdentifier)) + throw new RuntimeException(s"Table $strTable can't be deleted because it doesn't exist") logInfo(s"Deleting table $strTable from catalog") indexMetadataByTableIdentifier(tableIdentifier) foreach { index => @@ -198,7 +217,8 @@ private[crossdata] class CatalogChain private(val temporaryCatalogs: Seq[XDTempo override def dropView(viewIdentifier: ViewIdentifier): Unit = { val strView = viewIdentifier.unquotedString - if (lookupRelationOpt(viewIdentifier).isEmpty) throw new RuntimeException(s"View $strView can't be deleted because it doesn't exist") + if (lookupRelationOpt(viewIdentifier).isEmpty) + throw new RuntimeException(s"View $strView can't be deleted because it doesn't exist") logInfo(s"Deleting view ${viewIdentifier.unquotedString} from catalog") temporaryCatalogs foreach (_.dropView(normalize(viewIdentifier))) persistentCatalogs foreach (_.dropView(normalize(viewIdentifier))) @@ -209,25 +229,26 @@ private[crossdata] class CatalogChain private(val temporaryCatalogs: Seq[XDTempo persistentCatalogs foreach (_.dropAllViews()) } - override def dropIndex(indexIdentifier: IndexIdentifier): Unit = { val strIndex = indexIdentifier.unquotedString - if(indexMetadata(indexIdentifier).isEmpty) throw new RuntimeException(s"Index $strIndex can't be deleted because it doesn't exist") + if (indexMetadata(indexIdentifier).isEmpty) + throw new RuntimeException(s"Index $strIndex can't be deleted because it doesn't exist") logInfo(s"Deleting index ${indexIdentifier.unquotedString} from catalog") //First remove table that holds the index - if(tableExists(indexIdentifier.asTableIdentifier)) + if (tableExists(indexIdentifier.asTableIdentifier)) dropTable(indexIdentifier.asTableIdentifier) - persistentCatalogs foreach(catalog => Try(catalog.dropIndex(indexIdentifier.normalize))) + persistentCatalogs foreach (catalog => Try(catalog.dropIndex(indexIdentifier.normalize))) } - override def indexMetadata(indexIdentifier: IndexIdentifier): Option[CrossdataIndex]= + override def indexMetadata(indexIdentifier: 
IndexIdentifier): Option[CrossdataIndex] = persistentChainedLookup(_.lookupIndex(indexIdentifier.normalize)) - override def indexMetadataByTableIdentifier(tableIdentifier: TableIdentifier):Option[CrossdataIndex]= + override def indexMetadataByTableIdentifier( + tableIdentifier: TableIdentifier): Option[CrossdataIndex] = persistentCatalogs.view map (_.lookupIndexByTableIdentifier(normalize(tableIdentifier))) collectFirst { - case Some(index) =>index + case Some(index) => index } override def dropAllIndexes(): Unit = { @@ -242,9 +263,8 @@ private[crossdata] class CatalogChain private(val temporaryCatalogs: Seq[XDTempo persistentCatalogs.foreach(_.refreshCache(normalize(tableIdent))) /** - * StreamingCatalog - */ - + * StreamingCatalog + */ // Ephemeral Table Functions override def existsEphemeralTable(tableIdentifier: String): Boolean = @@ -253,11 +273,10 @@ private[crossdata] class CatalogChain private(val temporaryCatalogs: Seq[XDTempo override def getEphemeralTable(tableIdentifier: String): Option[EphemeralTableModel] = executeWithStrCatalogOrNone(_.getEphemeralTable(tableIdentifier)) - - override def createEphemeralTable(ephemeralTable: EphemeralTableModel): Either[String, EphemeralTableModel] = + override def createEphemeralTable( + ephemeralTable: EphemeralTableModel): Either[String, EphemeralTableModel] = withStreamingCatalogDo(_.createEphemeralTable(ephemeralTable)) - override def dropEphemeralTable(tableIdentifier: String): Unit = withStreamingCatalogDo(_.dropEphemeralTable(tableIdentifier)) @@ -269,7 +288,8 @@ private[crossdata] class CatalogChain private(val temporaryCatalogs: Seq[XDTempo // Ephemeral Queries Functions - override def createEphemeralQuery(ephemeralQuery: EphemeralQueryModel): Either[String, EphemeralQueryModel] = + override def createEphemeralQuery( + ephemeralQuery: EphemeralQueryModel): Either[String, EphemeralQueryModel] = withStreamingCatalogDo(_.createEphemeralQuery(ephemeralQuery)) override def getEphemeralQuery(queryAlias: String): Option[EphemeralQueryModel] = @@ -287,10 +307,10 @@ private[crossdata] class CatalogChain private(val temporaryCatalogs: Seq[XDTempo override def dropAllEphemeralQueries(): Unit = withStreamingCatalogDo(_.dropAllEphemeralQueries()) - // Ephemeral Status Functions - override protected[crossdata] def getEphemeralStatus(tableIdentifier: String): Option[EphemeralStatusModel] = + override protected[crossdata] def getEphemeralStatus( + tableIdentifier: String): Option[EphemeralStatusModel] = executeWithStrCatalogOrNone(_.getEphemeralStatus(tableIdentifier)) override protected[crossdata] def getAllEphemeralStatuses: Seq[EphemeralStatusModel] = @@ -302,10 +322,13 @@ private[crossdata] class CatalogChain private(val temporaryCatalogs: Seq[XDTempo override protected[crossdata] def dropAllEphemeralStatus(): Unit = withStreamingCatalogDo(_.dropAllEphemeralStatus()) - override protected[crossdata] def createEphemeralStatus(tableIdentifier: String, ephemeralStatusModel: EphemeralStatusModel): EphemeralStatusModel = + override protected[crossdata] def createEphemeralStatus( + tableIdentifier: String, + ephemeralStatusModel: EphemeralStatusModel): EphemeralStatusModel = withStreamingCatalogDo(_.createEphemeralStatus(tableIdentifier, ephemeralStatusModel)) - override protected[crossdata] def updateEphemeralStatus(tableIdentifier: String, status: EphemeralStatusModel): Unit = + override protected[crossdata] def updateEphemeralStatus(tableIdentifier: String, + status: EphemeralStatusModel): Unit = 
withStreamingCatalogDo(_.updateEphemeralStatus(tableIdentifier, status)) // Utils @@ -314,10 +337,12 @@ private[crossdata] class CatalogChain private(val temporaryCatalogs: Seq[XDTempo throw new RuntimeException("There is no streaming catalog") } } - private def executeWithStrCatalogOrNone[R](streamingCatalogOperation: XDStreamingCatalog => Option[R]): Option[R] = + private def executeWithStrCatalogOrNone[R]( + streamingCatalogOperation: XDStreamingCatalog => Option[R]): Option[R] = streamingCatalogs.flatMap(streamingCatalogOperation) - private def executeWithStrCatalogOrEmptyList[R](streamingCatalogOperation: XDStreamingCatalog => Seq[R]): Seq[R] = + private def executeWithStrCatalogOrEmptyList[R]( + streamingCatalogOperation: XDStreamingCatalog => Seq[R]): Seq[R] = streamingCatalogs.toSeq.flatMap(streamingCatalogOperation) override def lookupApp(alias: String): Option[CrossdataApp] = diff --git a/core/src/main/scala/org/apache/spark/sql/crossdata/catalog/ExternalCatalogAPI.scala b/core/src/main/scala/org/apache/spark/sql/crossdata/catalog/ExternalCatalogAPI.scala index 43bf3fac9..ca787cfc3 100644 --- a/core/src/main/scala/org/apache/spark/sql/crossdata/catalog/ExternalCatalogAPI.scala +++ b/core/src/main/scala/org/apache/spark/sql/crossdata/catalog/ExternalCatalogAPI.scala @@ -20,8 +20,7 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.crossdata.catalog.XDCatalog.{CrossdataIndex, CrossdataTable, IndexIdentifier, ViewIdentifier} import org.apache.spark.sql.crossdata.catalog.interfaces.XDAppsCatalog - -private[crossdata] trait ExternalCatalogAPI extends XDAppsCatalog{ +private[crossdata] trait ExternalCatalogAPI extends XDAppsCatalog { def persistTable(crossdataTable: CrossdataTable, table: LogicalPlan): Unit def persistView(viewIdentifier: ViewIdentifier, plan: LogicalPlan, sqlText: String): Unit @@ -42,6 +41,3 @@ private[crossdata] trait ExternalCatalogAPI extends XDAppsCatalog{ def tableHasGlobalIndex(tableIdentifier: TableIdentifier): Boolean = indexMetadataByTableIdentifier(tableIdentifier).isDefined } - - - diff --git a/core/src/main/scala/org/apache/spark/sql/crossdata/catalog/StreamingCatalogAPI.scala b/core/src/main/scala/org/apache/spark/sql/crossdata/catalog/StreamingCatalogAPI.scala index 41c88b40f..17b835836 100644 --- a/core/src/main/scala/org/apache/spark/sql/crossdata/catalog/StreamingCatalogAPI.scala +++ b/core/src/main/scala/org/apache/spark/sql/crossdata/catalog/StreamingCatalogAPI.scala @@ -25,11 +25,12 @@ private[crossdata] trait StreamingCatalogAPI { */ def existsEphemeralTable(tableIdentifier: String): Boolean - def getEphemeralTable(tableIdentifier: String) : Option[EphemeralTableModel] + def getEphemeralTable(tableIdentifier: String): Option[EphemeralTableModel] - def getAllEphemeralTables : Seq[EphemeralTableModel] + def getAllEphemeralTables: Seq[EphemeralTableModel] - def createEphemeralTable(ephemeralTable: EphemeralTableModel): Either[String, EphemeralTableModel] + def createEphemeralTable( + ephemeralTable: EphemeralTableModel): Either[String, EphemeralTableModel] def dropEphemeralTable(tableIdentifier: String): Unit @@ -38,13 +39,17 @@ private[crossdata] trait StreamingCatalogAPI { /** * Ephemeral Status Functions */ - protected[crossdata] def createEphemeralStatus(tableIdentifier: String, ephemeralStatusModel: EphemeralStatusModel): EphemeralStatusModel + protected[crossdata] def createEphemeralStatus( + tableIdentifier: String, + ephemeralStatusModel: EphemeralStatusModel): EphemeralStatusModel - 
protected[crossdata] def getEphemeralStatus(tableIdentifier: String) : Option[EphemeralStatusModel] + protected[crossdata] def getEphemeralStatus( + tableIdentifier: String): Option[EphemeralStatusModel] - protected[crossdata] def getAllEphemeralStatuses : Seq[EphemeralStatusModel] + protected[crossdata] def getAllEphemeralStatuses: Seq[EphemeralStatusModel] - protected[crossdata] def updateEphemeralStatus(tableIdentifier: String, status: EphemeralStatusModel) : Unit + protected[crossdata] def updateEphemeralStatus(tableIdentifier: String, + status: EphemeralStatusModel): Unit protected[crossdata] def dropEphemeralStatus(tableIdentifier: String): Unit @@ -55,11 +60,12 @@ private[crossdata] trait StreamingCatalogAPI { */ def existsEphemeralQuery(queryAlias: String): Boolean - def getEphemeralQuery(queryAlias: String) : Option[EphemeralQueryModel] + def getEphemeralQuery(queryAlias: String): Option[EphemeralQueryModel] - def getAllEphemeralQueries : Seq[EphemeralQueryModel] + def getAllEphemeralQueries: Seq[EphemeralQueryModel] - def createEphemeralQuery(ephemeralQuery: EphemeralQueryModel): Either[String, EphemeralQueryModel] + def createEphemeralQuery( + ephemeralQuery: EphemeralQueryModel): Either[String, EphemeralQueryModel] def dropEphemeralQuery(queryAlias: String): Unit diff --git a/core/src/main/scala/org/apache/spark/sql/crossdata/catalog/XDCatalog.scala b/core/src/main/scala/org/apache/spark/sql/crossdata/catalog/XDCatalog.scala index dabe4bca1..698d0ac4b 100644 --- a/core/src/main/scala/org/apache/spark/sql/crossdata/catalog/XDCatalog.scala +++ b/core/src/main/scala/org/apache/spark/sql/crossdata/catalog/XDCatalog.scala @@ -15,7 +15,6 @@ */ package org.apache.spark.sql.crossdata.catalog - import org.apache.spark.sql.catalyst.{CatalystConf, TableIdentifier} import org.apache.spark.sql.catalyst.analysis.Catalog import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan @@ -26,10 +25,10 @@ import org.apache.spark.sql.crossdata.serializers.CrossdataSerializer import org.apache.spark.sql.types.StructType import org.json4s.jackson.Serialization._ - object XDCatalog extends CrossdataSerializer { -implicit def asXDCatalog (catalog: Catalog): XDCatalog = catalog.asInstanceOf[XDCatalog] + implicit def asXDCatalog(catalog: Catalog): XDCatalog = + catalog.asInstanceOf[XDCatalog] type ViewIdentifier = TableIdentifier type ViewIdentifierNormalized = TableIdentifierNormalized @@ -38,60 +37,66 @@ implicit def asXDCatalog (catalog: Catalog): XDCatalog = catalog.asInstanceOf[XD def quotedString: String = s"`$indexName`.`$indexType`" def unquotedString: String = s"$indexName.$indexType" override def toString: String = quotedString - def asTableIdentifier: TableIdentifier = TableIdentifier(indexType,Option(indexName)) + def asTableIdentifier: TableIdentifier = + TableIdentifier(indexType, Option(indexName)) } - case class CrossdataTable(tableIdentifier: TableIdentifierNormalized, schema: Option[StructType], - datasource: String, partitionColumn: Array[String] = Array.empty, - opts: Map[String, String] = Map.empty, crossdataVersion: String = crossdata.CrossdataVersion) - - - case class CrossdataIndex(tableIdentifier: TableIdentifierNormalized, indexIdentifier: IndexIdentifierNormalized, - indexedCols: Seq[String], pk: String, datasource: String, - opts: Map[String, String] = Map.empty, crossdataVersion: String = crossdata.CrossdataVersion) - + case class CrossdataTable(tableIdentifier: TableIdentifierNormalized, + schema: Option[StructType], + datasource: String, + partitionColumn: Array[String] = 
Array.empty, + opts: Map[String, String] = Map.empty, + crossdataVersion: String = crossdata.CrossdataVersion) + + case class CrossdataIndex(tableIdentifier: TableIdentifierNormalized, + indexIdentifier: IndexIdentifierNormalized, + indexedCols: Seq[String], + pk: String, + datasource: String, + opts: Map[String, String] = Map.empty, + crossdataVersion: String = crossdata.CrossdataVersion) case class CrossdataApp(jar: String, appAlias: String, appClass: String) - def serializeSchema(schema: StructType): String = write(schema) - def deserializeUserSpecifiedSchema(schemaJSON: String): StructType = read[StructType](schemaJSON) + def deserializeUserSpecifiedSchema(schemaJSON: String): StructType = + read[StructType](schemaJSON) - def serializePartitionColumn(partitionColumn: Array[String]): String = write(partitionColumn) + def serializePartitionColumn(partitionColumn: Array[String]): String = + write(partitionColumn) - def deserializePartitionColumn(partitionColumn: String): Array[String] = read[Array[String]](partitionColumn) + def deserializePartitionColumn(partitionColumn: String): Array[String] = + read[Array[String]](partitionColumn) - def serializeOptions(options: Map[String, String]): String = write(options) + def serializeOptions(options: Map[String, String]): String = write(options) - def deserializeOptions(optsJSON: String): Map[String, String] = read[Map[String, String]](optsJSON) + def deserializeOptions(optsJSON: String): Map[String, String] = + read[Map[String, String]](optsJSON) def serializeSeq(seq: Seq[String]): String = write(seq) def deserializeSeq(seqJSON: String): Seq[String] = read[Seq[String]](seqJSON) - } -trait XDCatalog extends Catalog -with ExternalCatalogAPI -with StreamingCatalogAPI { +trait XDCatalog extends Catalog with ExternalCatalogAPI with StreamingCatalogAPI { - def registerTable(tableIdent: TableIdentifier, plan: LogicalPlan, crossdataTable: Option[CrossdataTable]): Unit - def registerView(viewIdentifier: ViewIdentifier, logicalPlan: LogicalPlan, sql: Option[String] = None): Unit + def registerTable(tableIdent: TableIdentifier, + plan: LogicalPlan, + crossdataTable: Option[CrossdataTable]): Unit + def registerView(viewIdentifier: ViewIdentifier, + logicalPlan: LogicalPlan, + sql: Option[String] = None): Unit final def registerTable(tableIdent: TableIdentifier, plan: LogicalPlan): Unit = registerTable(tableIdent, plan, None) - def unregisterView(viewIdentifier: ViewIdentifier): Unit /** - * Check the connection to the set Catalog - */ + * Check the connection to the set Catalog + */ def checkConnectivity: Boolean } - - - diff --git a/core/src/main/scala/org/apache/spark/sql/crossdata/catalog/identifiersNormalized.scala b/core/src/main/scala/org/apache/spark/sql/crossdata/catalog/identifiersNormalized.scala index 640c6a2dd..2748b38a7 100644 --- a/core/src/main/scala/org/apache/spark/sql/crossdata/catalog/identifiersNormalized.scala +++ b/core/src/main/scala/org/apache/spark/sql/crossdata/catalog/identifiersNormalized.scala @@ -23,25 +23,30 @@ case class TableIdentifierNormalized(table: String, database: Option[String]) { override def toString: String = quotedString - def quotedString: String = database.map(db => s"`$db`.`$table`").getOrElse(s"`$table`") + def quotedString: String = + database.map(db => s"`$db`.`$table`").getOrElse(s"`$table`") - def unquotedString: String = database.map(db => s"$db.$table").getOrElse(table) + def unquotedString: String = + database.map(db => s"$db.$table").getOrElse(table) def toTableIdentifier: TableIdentifier = 
TableIdentifier(table, database) } private[sql] object TableIdentifierNormalized { - def apply(tableName: String): TableIdentifierNormalized = new TableIdentifierNormalized(tableName) + def apply(tableName: String): TableIdentifierNormalized = + new TableIdentifierNormalized(tableName) } case class IndexIdentifierNormalized(indexType: String, indexName: String) { def quotedString: String = s"`$indexName`.`$indexType`" def unquotedString: String = s"$indexName.$indexType" override def toString: String = quotedString - def toIndexIdentifier: IndexIdentifier = IndexIdentifier(indexType, indexName) - def asTableIdentifierNormalized: TableIdentifierNormalized = TableIdentifierNormalized(indexType,Option(indexName)) + def toIndexIdentifier: IndexIdentifier = + IndexIdentifier(indexType, indexName) + def asTableIdentifierNormalized: TableIdentifierNormalized = + TableIdentifierNormalized(indexType, Option(indexName)) } -case class StringNormalized(normalizedString: String){ +case class StringNormalized(normalizedString: String) { override def toString = normalizedString } diff --git a/core/src/main/scala/org/apache/spark/sql/crossdata/catalog/interfaces/catalogs.scala b/core/src/main/scala/org/apache/spark/sql/crossdata/catalog/interfaces/catalogs.scala index f2d2dcdb8..55007d7a6 100644 --- a/core/src/main/scala/org/apache/spark/sql/crossdata/catalog/interfaces/catalogs.scala +++ b/core/src/main/scala/org/apache/spark/sql/crossdata/catalog/interfaces/catalogs.scala @@ -27,8 +27,9 @@ object XDCatalogCommon { implicit class RichTableIdentifier(tableIdentifier: TableIdentifier) { def normalize(implicit conf: CatalystConf): TableIdentifierNormalized = { - val normalizedDatabase = tableIdentifier.database.map(normalizeIdentifier(_,conf)) - TableIdentifierNormalized(normalizeIdentifier(tableIdentifier.table, conf), normalizedDatabase) + val normalizedDatabase = tableIdentifier.database.map(normalizeIdentifier(_, conf)) + TableIdentifierNormalized(normalizeIdentifier(tableIdentifier.table, conf), + normalizedDatabase) } } @@ -46,7 +47,6 @@ object XDCatalogCommon { def normalizeTableIdentifier(tableIdent: TableIdentifier, conf: CatalystConf): String = stringifyTableIdentifierNormalized(tableIdent.normalize(conf)) - def normalizeIdentifier(identifier: String, conf: CatalystConf): String = if (conf.caseSensitiveAnalysis) { identifier @@ -54,7 +54,8 @@ object XDCatalogCommon { identifier.toLowerCase } - def processAlias(tableIdentifier: TableIdentifier, lPlan: LogicalPlan, alias: Option[String])(conf: CatalystConf) = { + def processAlias(tableIdentifier: TableIdentifier, lPlan: LogicalPlan, alias: Option[String])( + conf: CatalystConf) = { val tableWithQualifiers = Subquery(normalizeTableIdentifier(tableIdentifier, conf), lPlan) // If an alias was specified by the lookup, wrap the plan in a subquery so that attributes are // properly qualified with this alias. 
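The hunk above ends at the comment inside processAlias; the behaviour that comment describes can be sketched as follows (an illustrative reconstruction with an assumed helper name, not the patch's exact code), reusing the same catalyst Subquery node the hunk already constructs:

    import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Subquery}

    // Sketch: qualify attributes with the normalized table name, then wrap the
    // plan once more when the lookup supplied an alias, so attributes can also
    // be resolved through that alias.
    def aliasedPlan(normalizedName: String, lPlan: LogicalPlan, alias: Option[String]): LogicalPlan = {
      val tableWithQualifiers = Subquery(normalizedName, lPlan)
      alias.fold(tableWithQualifiers)(a => Subquery(a, tableWithQualifiers))
    }
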
@@ -66,7 +67,8 @@ sealed trait XDCatalogCommon extends SparkLoggerComponent { def catalystConf: CatalystConf - def relation(tableIdent: TableIdentifierNormalized)(implicit sqlContext: SQLContext): Option[LogicalPlan] + def relation(tableIdent: TableIdentifierNormalized)( + implicit sqlContext: SQLContext): Option[LogicalPlan] def allRelations(databaseName: Option[StringNormalized] = None): Seq[TableIdentifierNormalized] @@ -78,20 +80,17 @@ sealed trait XDCatalogCommon extends SparkLoggerComponent { throw new RuntimeException(message) } - } trait XDTemporaryCatalog extends XDCatalogCommon { - def saveTable( - tableIdentifier: TableIdentifierNormalized, - plan: LogicalPlan, - crossdataTable: Option[CrossdataTable] = None): Unit - - def saveView( - viewIdentifier: ViewIdentifierNormalized, + def saveTable(tableIdentifier: TableIdentifierNormalized, plan: LogicalPlan, - query: Option[String] = None): Unit + crossdataTable: Option[CrossdataTable] = None): Unit + + def saveView(viewIdentifier: ViewIdentifierNormalized, + plan: LogicalPlan, + query: Option[String] = None): Unit def dropTable(tableIdentifier: TableIdentifierNormalized): Unit @@ -103,14 +102,15 @@ trait XDTemporaryCatalog extends XDCatalogCommon { } - trait XDPersistentCatalog extends XDCatalogCommon { def refreshCache(tableIdent: TableIdentifierNormalized): Unit - def saveTable(crossdataTable: CrossdataTable, plan: LogicalPlan)(implicit sqlContext: SQLContext): Unit + def saveTable(crossdataTable: CrossdataTable, plan: LogicalPlan)( + implicit sqlContext: SQLContext): Unit - def saveView(tableIdentifier: ViewIdentifierNormalized, plan: LogicalPlan, sqlText: String)(implicit sqlContext: SQLContext): Unit + def saveView(tableIdentifier: ViewIdentifierNormalized, plan: LogicalPlan, sqlText: String)( + implicit sqlContext: SQLContext): Unit def saveIndex(crossdataIndex: CrossdataIndex): Unit @@ -135,7 +135,8 @@ trait XDPersistentCatalog extends XDCatalogCommon { def lookupIndex(indexIdentifier: IndexIdentifierNormalized): Option[CrossdataIndex] //TODO: Index operations to trait - def lookupIndexByTableIdentifier(tableIdentifier: TableIdentifierNormalized): Option[CrossdataIndex] + def lookupIndexByTableIdentifier( + tableIdentifier: TableIdentifierNormalized): Option[CrossdataIndex] def getApp(alias: String): Option[CrossdataApp] @@ -156,45 +157,51 @@ trait XDStreamingCatalog extends XDCatalogCommon { //TODO: TableIdentifier shouldn't be a String /** - * Ephemeral Table Functions - */ + * Ephemeral Table Functions + */ def existsEphemeralTable(tableIdentifier: String): Boolean def getEphemeralTable(tableIdentifier: String): Option[EphemeralTableModel] def getAllEphemeralTables: Seq[EphemeralTableModel] - def createEphemeralTable(ephemeralTable: EphemeralTableModel): Either[String, EphemeralTableModel] + def createEphemeralTable( + ephemeralTable: EphemeralTableModel): Either[String, EphemeralTableModel] def dropEphemeralTable(tableIdentifier: String): Unit def dropAllEphemeralTables(): Unit /** - * Ephemeral Status Functions - */ - protected[crossdata] def createEphemeralStatus(tableIdentifier: String, ephemeralStatusModel: EphemeralStatusModel): EphemeralStatusModel + * Ephemeral Status Functions + */ + protected[crossdata] def createEphemeralStatus( + tableIdentifier: String, + ephemeralStatusModel: EphemeralStatusModel): EphemeralStatusModel - protected[crossdata] def getEphemeralStatus(tableIdentifier: String): Option[EphemeralStatusModel] + protected[crossdata] def getEphemeralStatus( + tableIdentifier: String): 
Option[EphemeralStatusModel] protected[crossdata] def getAllEphemeralStatuses: Seq[EphemeralStatusModel] - protected[crossdata] def updateEphemeralStatus(tableIdentifier: String, status: EphemeralStatusModel): Unit + protected[crossdata] def updateEphemeralStatus(tableIdentifier: String, + status: EphemeralStatusModel): Unit protected[crossdata] def dropEphemeralStatus(tableIdentifier: String): Unit protected[crossdata] def dropAllEphemeralStatus(): Unit /** - * Ephemeral Queries Functions - */ + * Ephemeral Queries Functions + */ def existsEphemeralQuery(queryAlias: String): Boolean def getEphemeralQuery(queryAlias: String): Option[EphemeralQueryModel] def getAllEphemeralQueries: Seq[EphemeralQueryModel] - def createEphemeralQuery(ephemeralQuery: EphemeralQueryModel): Either[String, EphemeralQueryModel] + def createEphemeralQuery( + ephemeralQuery: EphemeralQueryModel): Either[String, EphemeralQueryModel] def dropEphemeralQuery(queryAlias: String): Unit diff --git a/core/src/main/scala/org/apache/spark/sql/crossdata/catalog/persistent/DerbyCatalog.scala b/core/src/main/scala/org/apache/spark/sql/crossdata/catalog/persistent/DerbyCatalog.scala index cc5fdabca..3222dad34 100644 --- a/core/src/main/scala/org/apache/spark/sql/crossdata/catalog/persistent/DerbyCatalog.scala +++ b/core/src/main/scala/org/apache/spark/sql/crossdata/catalog/persistent/DerbyCatalog.scala @@ -57,7 +57,6 @@ object DerbyCatalog { } - /** * Default implementation of the [[persistent.PersistentCatalogWithCache]] with persistence using * Derby. @@ -65,7 +64,7 @@ object DerbyCatalog { * @param catalystConf An implementation of the [[CatalystConf]]. */ class DerbyCatalog(override val catalystConf: CatalystConf) - extends PersistentCatalogWithCache(catalystConf) { + extends PersistentCatalogWithCache(catalystConf) { import DerbyCatalog._ import XDCatalog._ @@ -94,9 +93,7 @@ class DerbyCatalog(override val catalystConf: CatalystConf) if (!schemaExists(DB, jdbcConnection)) { executeUpdate(s"CREATE SCHEMA $DB") - - executeUpdate( - s"""|CREATE TABLE $DB.$TableWithTableMetadata ( + executeUpdate(s"""|CREATE TABLE $DB.$TableWithTableMetadata ( |$DatabaseField VARCHAR(50), |$TableNameField VARCHAR(50), |$SchemaField LONG VARCHAR, @@ -106,16 +103,14 @@ class DerbyCatalog(override val catalystConf: CatalystConf) |$CrossdataVersionField LONG VARCHAR, |PRIMARY KEY ($DatabaseField,$TableNameField))""".stripMargin) - executeUpdate( - s"""|CREATE TABLE $DB.$TableWithViewMetadata ( + executeUpdate(s"""|CREATE TABLE $DB.$TableWithViewMetadata ( |$DatabaseField VARCHAR(50), |$TableNameField VARCHAR(50), |$SqlViewField LONG VARCHAR, |$CrossdataVersionField VARCHAR(30), |PRIMARY KEY ($DatabaseField,$TableNameField))""".stripMargin) - executeUpdate( - s"""|CREATE TABLE $DB.$TableWithAppJars ( + executeUpdate(s"""|CREATE TABLE $DB.$TableWithAppJars ( |$JarPath VARCHAR(100), |$AppAlias VARCHAR(50), |$AppClass VARCHAR(100), @@ -124,8 +119,7 @@ class DerbyCatalog(override val catalystConf: CatalystConf) //Index support if (!indexTableExists(DB, jdbcConnection)) { - executeUpdate( - s"""|CREATE TABLE $DB.$TableWithIndexMetadata ( + executeUpdate(s"""|CREATE TABLE $DB.$TableWithIndexMetadata ( |$DatabaseField VARCHAR(50), |$TableNameField VARCHAR(50), |$IndexNameField VARCHAR(50), @@ -142,23 +136,24 @@ class DerbyCatalog(override val catalystConf: CatalystConf) jdbcConnection } - def executeSQLCommand(sql: String): Unit = synchronized { using(connection.createStatement()) { statement => statement.executeUpdate(sql) } } - private def 
withConnectionWithoutCommit[T](f: Connection => T): T = synchronized { - try { - connection.setAutoCommit(false) - f(connection) - } finally { - connection.setAutoCommit(true) + private def withConnectionWithoutCommit[T](f: Connection => T): T = + synchronized { + try { + connection.setAutoCommit(false) + f(connection) + } finally { + connection.setAutoCommit(true) + } } - } - private def withStatement[T](sql: String)(f: PreparedStatement => T)(implicit conn: Connection = connection): T = + private def withStatement[T](sql: String)(f: PreparedStatement => T)(implicit conn: Connection = + connection): T = synchronized { using(conn.prepareStatement(sql)) { statement => f(statement) @@ -186,8 +181,12 @@ class DerbyCatalog(override val catalystConf: CatalystConf) val version = resultSet.getString(CrossdataVersionField) Some( - CrossdataTable(TableIdentifierNormalized(table, Some(database)), Option(deserializeUserSpecifiedSchema(schemaJSON)), datasource, - deserializePartitionColumn(partitionColumn), deserializeOptions(optsJSON), version) + CrossdataTable(TableIdentifierNormalized(table, Some(database)), + Option(deserializeUserSpecifiedSchema(schemaJSON)), + datasource, + deserializePartitionColumn(partitionColumn), + deserializeOptions(optsJSON), + version) ) } } @@ -204,13 +203,12 @@ class DerbyCatalog(override val catalystConf: CatalystConf) val clss = resultSet.getString(AppClass) Some( - CrossdataApp(jar, alias, clss) + CrossdataApp(jar, alias, clss) ) } } } - override def lookupView(viewIdentifier: ViewIdentifierNormalized): Option[String] = selectMetadata(TableWithViewMetadata, viewIdentifier) { resultSet => if (!resultSet.next) @@ -221,7 +219,6 @@ class DerbyCatalog(override val catalystConf: CatalystConf) override def lookupIndex(indexIdentifier: IndexIdentifierNormalized): Option[CrossdataIndex] = selectIndex(indexIdentifier) { resultSet => - if (!resultSet.next) { None } else { @@ -237,13 +234,17 @@ class DerbyCatalog(override val catalystConf: CatalystConf) val version = resultSet.getString(CrossdataVersionField) Some( - CrossdataIndex(TableIdentifierNormalized(table, Some(database)), IndexIdentifierNormalized(indexType, indexName), - deserializeSeq(indexedCols), pk, datasource, deserializeOptions(optsJSON), version) + CrossdataIndex(TableIdentifierNormalized(table, Some(database)), + IndexIdentifierNormalized(indexType, indexName), + deserializeSeq(indexedCols), + pk, + datasource, + deserializeOptions(optsJSON), + version) ) } } - override def persistTableMetadata(crossdataTable: CrossdataTable): Unit = withConnectionWithoutCommit { implicit conn => val tableSchema = serializeSchema(crossdataTable.schema.getOrElse(schemaNotFound())) @@ -252,10 +253,8 @@ class DerbyCatalog(override val catalystConf: CatalystConf) // check if the database-table exist in the persisted catalog selectMetadata(TableWithTableMetadata, crossdataTable.tableIdentifier) { resultSet => - if (!resultSet.next()) { - withStatement( - s"""|INSERT INTO $DB.$TableWithTableMetadata ( + withStatement(s"""|INSERT INTO $DB.$TableWithTableMetadata ( | $DatabaseField, $TableNameField, $SchemaField, $DatasourceField, $PartitionColumnField, $OptionsField, $CrossdataVersionField |) VALUES (?,?,?,?,?,?,?) 
""".stripMargin) { statement2 => @@ -271,9 +270,10 @@ class DerbyCatalog(override val catalystConf: CatalystConf) } else { withStatement( - s"""|UPDATE $DB.$TableWithTableMetadata + s"""|UPDATE $DB.$TableWithTableMetadata |SET $SchemaField=?, $DatasourceField=?,$PartitionColumnField=?,$OptionsField=?,$CrossdataVersionField=? - |WHERE $DatabaseField='${crossdataTable.tableIdentifier.database.getOrElse("")}' AND $TableNameField='${crossdataTable.tableIdentifier.table}'""".stripMargin) { + |WHERE $DatabaseField='${crossdataTable.tableIdentifier.database + .getOrElse("")}' AND $TableNameField='${crossdataTable.tableIdentifier.table}'""".stripMargin) { statement2 => statement2.setString(1, tableSchema) statement2.setString(2, crossdataTable.datasource) @@ -287,16 +287,14 @@ class DerbyCatalog(override val catalystConf: CatalystConf) } } - - override def persistViewMetadata(tableIdentifier: TableIdentifierNormalized, sqlText: String): Unit = + override def persistViewMetadata(tableIdentifier: TableIdentifierNormalized, + sqlText: String): Unit = withConnectionWithoutCommit { implicit conn => selectMetadata(TableWithViewMetadata, tableIdentifier) { resultSet => if (!resultSet.next()) { - withStatement( - s"""|INSERT INTO $DB.$TableWithViewMetadata ( + withStatement(s"""|INSERT INTO $DB.$TableWithViewMetadata ( | $DatabaseField, $TableNameField, $SqlViewField, $CrossdataVersionField |) VALUES (?,?,?,?)""".stripMargin) { statement2 => - statement2.setString(1, tableIdentifier.database.getOrElse("")) statement2.setString(2, tableIdentifier.table) statement2.setString(3, sqlText) @@ -305,8 +303,9 @@ class DerbyCatalog(override val catalystConf: CatalystConf) } } else { val prepped = connection.prepareStatement( - s"""|UPDATE $DB.$TableWithViewMetadata SET $SqlViewField=? - |WHERE $DatabaseField='${tableIdentifier.database.getOrElse("")}' AND $TableNameField='${tableIdentifier.table}' + s"""|UPDATE $DB.$TableWithViewMetadata SET $SqlViewField=? + |WHERE $DatabaseField='${tableIdentifier.database + .getOrElse("")}' AND $TableNameField='${tableIdentifier.table}' """.stripMargin) prepped.setString(1, sqlText) prepped.execute() @@ -315,17 +314,14 @@ class DerbyCatalog(override val catalystConf: CatalystConf) } } - override def persistIndexMetadata(crossdataIndex: CrossdataIndex): Unit = withConnectionWithoutCommit { implicit conn => - selectMetadata(TableWithIndexMetadata, crossdataIndex.tableIdentifier) { resultSet => val serializedIndexedCols = serializeSeq(crossdataIndex.indexedCols) val serializedOptions = serializeOptions(crossdataIndex.opts) if (!resultSet.next()) { - withStatement( - s"""|INSERT INTO $DB.$TableWithIndexMetadata ( + withStatement(s"""|INSERT INTO $DB.$TableWithIndexMetadata ( | $DatabaseField, $TableNameField, $IndexNameField, $IndexTypeField, $IndexedColsField, | $PKField, $DatasourceField, $OptionsField, $CrossdataVersionField |) VALUES (?,?,?,?,?,?,?,?,?)""".stripMargin) { statement2 => @@ -342,31 +338,34 @@ class DerbyCatalog(override val catalystConf: CatalystConf) } } else { //TODO: Support change index metadata? 
- sys.error(s"A global index already exists in table ${crossdataIndex.tableIdentifier.unquotedString}") + sys.error( + s"A global index already exists in table ${crossdataIndex.tableIdentifier.unquotedString}") } } } - override def saveAppMetadata(crossdataApp: CrossdataApp): Unit = withConnectionWithoutCommit { implicit conn => - withStatement(s"SELECT * FROM $DB.$TableWithAppJars WHERE $AppAlias= ?") { statement => statement.setString(1, crossdataApp.appAlias) withResultSet(statement) { resultSet => if (!resultSet.next()) { - withStatement(s"INSERT INTO $DB.$TableWithAppJars ($JarPath, $AppAlias, $AppClass) VALUES (?,?,?)") { statement2 => - statement2.setString(1, crossdataApp.jar) - statement2.setString(2, crossdataApp.appAlias) - statement2.setString(3, crossdataApp.appClass) - statement2.execute() + withStatement( + s"INSERT INTO $DB.$TableWithAppJars ($JarPath, $AppAlias, $AppClass) VALUES (?,?,?)") { + statement2 => + statement2.setString(1, crossdataApp.jar) + statement2.setString(2, crossdataApp.appAlias) + statement2.setString(3, crossdataApp.appClass) + statement2.execute() } } else { - withStatement(s"UPDATE $DB.$TableWithAppJars SET $JarPath=?, $AppClass=? WHERE $AppAlias='${crossdataApp.appAlias}'") { statement2 => - statement2.setString(1, crossdataApp.jar) - statement2.setString(2, crossdataApp.appClass) - statement2.execute() + withStatement( + s"UPDATE $DB.$TableWithAppJars SET $JarPath=?, $AppClass=? WHERE $AppAlias='${crossdataApp.appAlias}'") { + statement2 => + statement2.setString(1, crossdataApp.jar) + statement2.setString(2, crossdataApp.appClass) + statement2.execute() } } conn.commit() @@ -376,25 +375,27 @@ class DerbyCatalog(override val catalystConf: CatalystConf) override def dropTableMetadata(tableIdentifier: TableIdentifierNormalized): Unit = executeSQLCommand( - s"DELETE FROM $DB.$TableWithTableMetadata WHERE tableName='${tableIdentifier.table}' AND db='${tableIdentifier.database.getOrElse("")}'" + s"DELETE FROM $DB.$TableWithTableMetadata WHERE tableName='${tableIdentifier.table}' AND db='${tableIdentifier.database + .getOrElse("")}'" ) override def dropViewMetadata(viewIdentifier: ViewIdentifierNormalized): Unit = executeSQLCommand( - s"DELETE FROM $DB.$TableWithViewMetadata WHERE tableName='${viewIdentifier.table}' AND db='${viewIdentifier.database.getOrElse("")}'" + s"DELETE FROM $DB.$TableWithViewMetadata WHERE tableName='${viewIdentifier.table}' AND db='${viewIdentifier.database + .getOrElse("")}'" ) override def dropIndexMetadata(indexIdentifier: IndexIdentifierNormalized): Unit = executeSQLCommand( - s"DELETE FROM $DB.$TableWithIndexMetadata WHERE $IndexTypeField='${indexIdentifier.indexType}' AND $IndexNameField='${indexIdentifier.indexName}'" + s"DELETE FROM $DB.$TableWithIndexMetadata WHERE $IndexTypeField='${indexIdentifier.indexType}' AND $IndexNameField='${indexIdentifier.indexName}'" ) override def dropIndexMetadata(tableIdentifier: TableIdentifierNormalized): Unit = executeSQLCommand( - s"DELETE FROM $DB.$TableWithIndexMetadata WHERE $TableNameField='${tableIdentifier.table}' AND $DatabaseField='${tableIdentifier.database.getOrElse("")}'" + s"DELETE FROM $DB.$TableWithIndexMetadata WHERE $TableNameField='${tableIdentifier.table}' AND $DatabaseField='${tableIdentifier.database + .getOrElse("")}'" ) - override def dropAllTablesMetadata(): Unit = executeSQLCommand(s"DELETE FROM $DB.$TableWithTableMetadata") @@ -404,66 +405,79 @@ class DerbyCatalog(override val catalystConf: CatalystConf) override def dropAllIndexesMetadata(): Unit = 
executeSQLCommand(s"DELETE FROM $DB.$TableWithIndexMetadata") - override def isAvailable: Boolean = true - override def allRelations(databaseName: Option[StringNormalized]): Seq[TableIdentifierNormalized] = synchronized { - @tailrec - def getSequenceAux(resultset: ResultSet, next: Boolean, set: Set[TableIdentifierNormalized] = Set.empty): Set[TableIdentifierNormalized] = { - if (next) { - val database = resultset.getString(DatabaseField) - val table = resultset.getString(TableNameField) - val tableId = if (database.trim.isEmpty) TableIdentifierNormalized(table) else TableIdentifierNormalized(table, Option(database)) - getSequenceAux(resultset, resultset.next(), set + tableId) - } else { - set + override def allRelations( + databaseName: Option[StringNormalized]): Seq[TableIdentifierNormalized] = + synchronized { + @tailrec + def getSequenceAux( + resultset: ResultSet, + next: Boolean, + set: Set[TableIdentifierNormalized] = Set.empty): Set[TableIdentifierNormalized] = { + if (next) { + val database = resultset.getString(DatabaseField) + val table = resultset.getString(TableNameField) + val tableId = + if (database.trim.isEmpty) TableIdentifierNormalized(table) + else TableIdentifierNormalized(table, Option(database)) + getSequenceAux(resultset, resultset.next(), set + tableId) + } else { + set + } } - } - val statement = connection.createStatement - val dbFilter = databaseName.fold("")(dbName => s"WHERE $DatabaseField ='${dbName.normalizedString}'") - val resultSet = statement.executeQuery(s"SELECT $DatabaseField, $TableNameField FROM $DB.$TableWithTableMetadata $dbFilter") + val statement = connection.createStatement + val dbFilter = + databaseName.fold("")(dbName => s"WHERE $DatabaseField ='${dbName.normalizedString}'") + val resultSet = statement.executeQuery( + s"SELECT $DatabaseField, $TableNameField FROM $DB.$TableWithTableMetadata $dbFilter") - getSequenceAux(resultSet, resultSet.next).toSeq - } + getSequenceAux(resultSet, resultSet.next).toSeq + } - private def selectMetadata[T](targetTable: String, tableIdentifier: TableIdentifierNormalized)(f: ResultSet => T): T = - withStatement(s"SELECT * FROM $DB.$targetTable WHERE $DatabaseField= ? AND $TableNameField= ?") { statement => - statement.setString(1, tableIdentifier.database.getOrElse("")) - statement.setString(2, tableIdentifier.table) + private def selectMetadata[T](targetTable: String, tableIdentifier: TableIdentifierNormalized)( + f: ResultSet => T): T = + withStatement(s"SELECT * FROM $DB.$targetTable WHERE $DatabaseField= ? AND $TableNameField= ?") { + statement => + statement.setString(1, tableIdentifier.database.getOrElse("")) + statement.setString(2, tableIdentifier.table) - withResultSet(statement) { resultSet => - f(resultSet) - } + withResultSet(statement) { resultSet => + f(resultSet) + } } - private def selectIndex[T](indexIdentifier: IndexIdentifierNormalized)(f: ResultSet => T): T = - withStatement(s"SELECT * FROM $DB.$TableWithIndexMetadata WHERE $IndexNameField= ? AND $IndexTypeField= ?") { statement => - statement.setString(1, indexIdentifier.indexName) - statement.setString(2, indexIdentifier.indexType) + withStatement( + s"SELECT * FROM $DB.$TableWithIndexMetadata WHERE $IndexNameField= ? 
AND $IndexTypeField= ?") { + statement => + statement.setString(1, indexIdentifier.indexName) + statement.setString(2, indexIdentifier.indexType) - withResultSet(statement) { resultSet => - f(resultSet) - } + withResultSet(statement) { resultSet => + f(resultSet) + } } - - private def indexTableExists(schema: String, connection: Connection): Boolean = tableSchemaExists(schema, TableWithIndexMetadata, connection) + private def indexTableExists(schema: String, connection: Connection): Boolean = + tableSchemaExists(schema, TableWithIndexMetadata, connection) private def tableSchemaExists(schema: String, table: String, connection: Connection): Boolean = - withStatement( - s"""|SELECT * FROM SYS.SYSSCHEMAS sch + withStatement(s"""|SELECT * FROM SYS.SYSSCHEMAS sch |LEFT JOIN SYS.SYSTABLES tb ON tb.schemaid = sch.schemaid - |WHERE sch.SCHEMANAME='$schema' AND tb.TABLENAME='${table.toUpperCase}'""".stripMargin) { statement => - withResultSet(statement) { resultSet => - resultSet.next() - } + |WHERE sch.SCHEMANAME='$schema' AND tb.TABLENAME='${table.toUpperCase}'""".stripMargin) { + statement => + withResultSet(statement) { resultSet => + resultSet.next() + } }(connection) - override def lookupIndexByTableIdentifier(tableIdentifier: TableIdentifierNormalized): Option[CrossdataIndex] = { + override def lookupIndexByTableIdentifier( + tableIdentifier: TableIdentifierNormalized): Option[CrossdataIndex] = { val query = - s"SELECT * FROM $DB.$TableWithIndexMetadata WHERE $TableNameField='${tableIdentifier.table}' AND $DatabaseField='${tableIdentifier.database.getOrElse("")}'" + s"SELECT * FROM $DB.$TableWithIndexMetadata WHERE $TableNameField='${tableIdentifier.table}' AND $DatabaseField='${tableIdentifier.database + .getOrElse("")}'" withStatement(query) { statement => withResultSet(statement) { resultSet => @@ -482,11 +496,16 @@ class DerbyCatalog(override val catalystConf: CatalystConf) val version = resultSet.getString(CrossdataVersionField) Some( - CrossdataIndex(TableIdentifierNormalized(table, Some(database)), IndexIdentifierNormalized(indexType, indexName), - deserializeSeq(indexedCols), pk, datasource, deserializeOptions(optsJSON), version) + CrossdataIndex(TableIdentifierNormalized(table, Some(database)), + IndexIdentifierNormalized(indexType, indexName), + deserializeSeq(indexedCols), + pk, + datasource, + deserializeOptions(optsJSON), + version) ) } } } } -} \ No newline at end of file +} diff --git a/core/src/main/scala/org/apache/spark/sql/crossdata/catalog/persistent/MySQLXDCatalog.scala b/core/src/main/scala/org/apache/spark/sql/crossdata/catalog/persistent/MySQLXDCatalog.scala index f94608747..d3be9be59 100644 --- a/core/src/main/scala/org/apache/spark/sql/crossdata/catalog/persistent/MySQLXDCatalog.scala +++ b/core/src/main/scala/org/apache/spark/sql/crossdata/catalog/persistent/MySQLXDCatalog.scala @@ -68,12 +68,11 @@ object MySQLXDCatalog { * @param catalystConf An implementation of the [[CatalystConf]]. 
*/ class MySQLXDCatalog(override val catalystConf: CatalystConf) - extends PersistentCatalogWithCache(catalystConf) { + extends PersistentCatalogWithCache(catalystConf) { import MySQLXDCatalog._ import XDCatalog._ - private val config = XDContext.catalogConfig private val db = config.getString(Database) private val tableWithTableMetadata = config.getString(TableWithTableMetadata) @@ -95,9 +94,9 @@ class MySQLXDCatalog(override val catalystConf: CatalystConf) jdbcConnection.createStatement().executeUpdate(s"CREATE SCHEMA IF NOT EXISTS $db") - - jdbcConnection.createStatement().executeUpdate( - s"""|CREATE TABLE IF NOT EXISTS $db.$tableWithTableMetadata ( + jdbcConnection + .createStatement() + .executeUpdate(s"""|CREATE TABLE IF NOT EXISTS $db.$tableWithTableMetadata ( |$DatabaseField VARCHAR(50), |$TableNameField VARCHAR(50), |$SchemaField TEXT, @@ -107,24 +106,25 @@ class MySQLXDCatalog(override val catalystConf: CatalystConf) |$CrossdataVersionField TEXT, |PRIMARY KEY ($DatabaseField,$TableNameField))""".stripMargin) - jdbcConnection.createStatement().executeUpdate( - s"""|CREATE TABLE IF NOT EXISTS $db.$tableWithViewMetadata ( + jdbcConnection + .createStatement() + .executeUpdate(s"""|CREATE TABLE IF NOT EXISTS $db.$tableWithViewMetadata ( |$DatabaseField VARCHAR(50), |$TableNameField VARCHAR(50), |$SqlViewField TEXT, |$CrossdataVersionField VARCHAR(30), |PRIMARY KEY ($DatabaseField,$TableNameField))""".stripMargin) - jdbcConnection.createStatement().executeUpdate( - s"""|CREATE TABLE $db.$tableWithAppJars ( + jdbcConnection.createStatement().executeUpdate(s"""|CREATE TABLE $db.$tableWithAppJars ( |$JarPath VARCHAR(100), |$AppAlias VARCHAR(50), |$AppClass VARCHAR(100), |PRIMARY KEY ($AppAlias))""".stripMargin) //Index support - jdbcConnection.createStatement().executeUpdate( - s"""|CREATE TABLE IF NOT EXISTS $db.$TableWithIndexMetadata ( + jdbcConnection + .createStatement() + .executeUpdate(s"""|CREATE TABLE IF NOT EXISTS $db.$TableWithIndexMetadata ( |$DatabaseField VARCHAR(50), |$TableNameField VARCHAR(50), |$IndexNameField VARCHAR(50), @@ -145,7 +145,6 @@ class MySQLXDCatalog(override val catalystConf: CatalystConf) } } - override def lookupTable(tableIdentifier: TableIdentifierNormalized): Option[CrossdataTable] = { val resultSet = selectMetadata(tableWithTableMetadata, tableIdentifier) @@ -163,19 +162,29 @@ class MySQLXDCatalog(override val catalystConf: CatalystConf) val version = resultSet.getString(CrossdataVersionField) Some( - CrossdataTable(TableIdentifierNormalized(table, Some(database)), Option(deserializeUserSpecifiedSchema(schemaJSON)), datasource, deserializePartitionColumn(partitionColumn), deserializeOptions(optsJSON), version) + CrossdataTable(TableIdentifierNormalized(table, Some(database)), + Option(deserializeUserSpecifiedSchema(schemaJSON)), + datasource, + deserializePartitionColumn(partitionColumn), + deserializeOptions(optsJSON), + version) ) } } - - override def allRelations(databaseName: Option[StringNormalized]): Seq[TableIdentifierNormalized] = { + override def allRelations( + databaseName: Option[StringNormalized]): Seq[TableIdentifierNormalized] = { @tailrec - def getSequenceAux(resultset: ResultSet, next: Boolean, set: Set[TableIdentifierNormalized] = Set.empty): Set[TableIdentifierNormalized] = { + def getSequenceAux( + resultset: ResultSet, + next: Boolean, + set: Set[TableIdentifierNormalized] = Set.empty): Set[TableIdentifierNormalized] = { if (next) { val database = resultset.getString(DatabaseField) val table = resultset.getString(TableNameField) 
- val tableId = if (database.trim.isEmpty) TableIdentifierNormalized(table) else TableIdentifierNormalized(table, Option(database)) + val tableId = + if (database.trim.isEmpty) TableIdentifierNormalized(table) + else TableIdentifierNormalized(table, Option(database)) getSequenceAux(resultset, resultset.next(), set + tableId) } else { set @@ -183,8 +192,10 @@ class MySQLXDCatalog(override val catalystConf: CatalystConf) } val statement = connection.createStatement - val dbFilter = databaseName.fold("")(dbName => s"WHERE $DatabaseField ='${dbName.normalizedString}'") - val resultSet = statement.executeQuery(s"SELECT $DatabaseField, $TableNameField FROM $db.$tableWithTableMetadata $dbFilter") + val dbFilter = + databaseName.fold("")(dbName => s"WHERE $DatabaseField ='${dbName.normalizedString}'") + val resultSet = statement.executeQuery( + s"SELECT $DatabaseField, $TableNameField FROM $db.$tableWithTableMetadata $dbFilter") getSequenceAux(resultSet, resultSet.next).toSeq } @@ -203,8 +214,7 @@ class MySQLXDCatalog(override val catalystConf: CatalystConf) if (!resultSet.isBeforeFirst) { resultSet.close() - val prepped = connection.prepareStatement( - s"""|INSERT INTO $db.$tableWithTableMetadata ( + val prepped = connection.prepareStatement(s"""|INSERT INTO $db.$tableWithTableMetadata ( | $DatabaseField, $TableNameField, $SchemaField, $DatasourceField, $PartitionColumnField, $OptionsField, $CrossdataVersionField |) VALUES (?,?,?,?,?,?,?) """.stripMargin) @@ -219,11 +229,11 @@ class MySQLXDCatalog(override val catalystConf: CatalystConf) prepped.close() } else { resultSet.close() - val prepped = - connection.prepareStatement( + val prepped = connection.prepareStatement( s"""|UPDATE $db.$tableWithTableMetadata |SET $SchemaField=?, $DatasourceField=?,$PartitionColumnField=?,$OptionsField=?,$CrossdataVersionField=? 
- |WHERE $DatabaseField='${crossdataTable.tableIdentifier.database.getOrElse("")}' AND $TableNameField='${crossdataTable.tableIdentifier.table}'; + |WHERE $DatabaseField='${crossdataTable.tableIdentifier.database + .getOrElse("")}' AND $TableNameField='${crossdataTable.tableIdentifier.table}'; """.stripMargin.replaceAll("\n", " ")) prepped.setString(1, tableSchema) @@ -240,14 +250,14 @@ class MySQLXDCatalog(override val catalystConf: CatalystConf) connection.setAutoCommit(true) } - override def dropTableMetadata(tableIdentifier: ViewIdentifierNormalized): Unit = - connection.createStatement.executeUpdate(s"DELETE FROM $db.$tableWithTableMetadata WHERE tableName='${tableIdentifier.table}' AND db='${tableIdentifier.database.getOrElse("")}'") + connection.createStatement.executeUpdate( + s"DELETE FROM $db.$tableWithTableMetadata WHERE tableName='${tableIdentifier.table}' AND db='${tableIdentifier.database + .getOrElse("")}'") override def dropAllTablesMetadata(): Unit = connection.createStatement.executeUpdate(s"TRUNCATE $db.$tableWithTableMetadata") - override def lookupView(tableIdentifier: TableIdentifierNormalized): Option[String] = { val resultSet = selectMetadata(tableWithViewMetadata, tableIdentifier) if (!resultSet.isBeforeFirst) { @@ -258,15 +268,15 @@ class MySQLXDCatalog(override val catalystConf: CatalystConf) } } - override def persistViewMetadata(tableIdentifier: TableIdentifierNormalized, sqlText: String): Unit = + override def persistViewMetadata(tableIdentifier: TableIdentifierNormalized, + sqlText: String): Unit = try { connection.setAutoCommit(false) val resultSet = selectMetadata(tableWithViewMetadata, tableIdentifier) if (!resultSet.isBeforeFirst) { resultSet.close() - val prepped = connection.prepareStatement( - s"""|INSERT INTO $db.$tableWithViewMetadata ( + val prepped = connection.prepareStatement(s"""|INSERT INTO $db.$tableWithViewMetadata ( | $DatabaseField, $TableNameField, $SqlViewField, $CrossdataVersionField |) VALUES (?,?,?,?) """.stripMargin) @@ -278,10 +288,10 @@ class MySQLXDCatalog(override val catalystConf: CatalystConf) prepped.close() } else { resultSet.close() - val prepped = - connection.prepareStatement( + val prepped = connection.prepareStatement( s"""|UPDATE $db.$tableWithViewMetadata SET $SqlViewField=? - |WHERE $DatabaseField='${tableIdentifier.database.getOrElse("")}' AND $TableNameField='${tableIdentifier.table}' + |WHERE $DatabaseField='${tableIdentifier.database + .getOrElse("")}' AND $TableNameField='${tableIdentifier.table}' """.stripMargin.replaceAll("\n", " ")) prepped.setString(1, sqlText) @@ -294,9 +304,11 @@ class MySQLXDCatalog(override val catalystConf: CatalystConf) connection.setAutoCommit(true) } - private def selectMetadata(targetTable: String, tableIdentifier: TableIdentifierNormalized): ResultSet = { + private def selectMetadata(targetTable: String, + tableIdentifier: TableIdentifierNormalized): ResultSet = { - val preparedStatement = connection.prepareStatement(s"SELECT * FROM $db.$targetTable WHERE $DatabaseField= ? AND $TableNameField= ?") + val preparedStatement = connection.prepareStatement( + s"SELECT * FROM $db.$targetTable WHERE $DatabaseField= ? 
AND $TableNameField= ?") preparedStatement.setString(1, tableIdentifier.database.getOrElse("")) preparedStatement.setString(2, tableIdentifier.table) @@ -304,30 +316,28 @@ class MySQLXDCatalog(override val catalystConf: CatalystConf) } - override def dropViewMetadata(viewIdentifier: ViewIdentifierNormalized): Unit = connection.createStatement.executeUpdate( - s"DELETE FROM $db.$tableWithViewMetadata WHERE tableName='${viewIdentifier.table}' AND db='${viewIdentifier.database.getOrElse("")}'") - + s"DELETE FROM $db.$tableWithViewMetadata WHERE tableName='${viewIdentifier.table}' AND db='${viewIdentifier.database + .getOrElse("")}'") override def dropAllViewsMetadata(): Unit = { connection.createStatement.executeUpdate(s"DELETE FROM $db.$tableWithViewMetadata") } - override def saveAppMetadata(crossdataApp: CrossdataApp): Unit = try { connection.setAutoCommit(false) - val preparedStatement = connection.prepareStatement(s"SELECT * FROM $db.$tableWithAppJars WHERE $AppAlias= ?") + val preparedStatement = + connection.prepareStatement(s"SELECT * FROM $db.$tableWithAppJars WHERE $AppAlias= ?") preparedStatement.setString(1, crossdataApp.appAlias) val resultSet = preparedStatement.executeQuery() preparedStatement.close() if (!resultSet.next()) { resultSet.close() - val prepped = connection.prepareStatement( - s"""|INSERT INTO $db.$tableWithAppJars ( + val prepped = connection.prepareStatement(s"""|INSERT INTO $db.$tableWithAppJars ( | $JarPath, $AppAlias, $AppClass |) VALUES (?,?,?) """.stripMargin) @@ -338,8 +348,8 @@ class MySQLXDCatalog(override val catalystConf: CatalystConf) prepped.close() } else { resultSet.close() - val prepped = connection.prepareStatement( - s"""|UPDATE $db.$tableWithAppJars SET $JarPath=?, $AppClass=? + val prepped = + connection.prepareStatement(s"""|UPDATE $db.$tableWithAppJars SET $JarPath=?, $AppClass=? |WHERE $AppAlias='${crossdataApp.appAlias}' """.stripMargin) prepped.setString(1, crossdataApp.jar) @@ -354,7 +364,8 @@ class MySQLXDCatalog(override val catalystConf: CatalystConf) override def getApp(alias: String): Option[CrossdataApp] = { - val preparedStatement = connection.prepareStatement(s"SELECT * FROM $db.$tableWithAppJars WHERE $AppAlias= ?") + val preparedStatement = + connection.prepareStatement(s"SELECT * FROM $db.$tableWithAppJars WHERE $AppAlias= ?") preparedStatement.setString(1, alias) val resultSet = preparedStatement.executeQuery() @@ -370,14 +381,13 @@ class MySQLXDCatalog(override val catalystConf: CatalystConf) resultSet.close() preparedStatement.close() Some( - CrossdataApp(jar, alias, clss) + CrossdataApp(jar, alias, clss) ) } } override def isAvailable: Boolean = Option(connection).isDefined - override def persistIndexMetadata(crossdataIndex: CrossdataIndex): Unit = try { connection.setAutoCommit(false) @@ -388,8 +398,7 @@ class MySQLXDCatalog(override val catalystConf: CatalystConf) val serializedOptions = serializeOptions(crossdataIndex.opts) if (!resultSet.next()) { - val prepped = connection.prepareStatement( - s"""|INSERT INTO $db.$TableWithIndexMetadata ( + val prepped = connection.prepareStatement(s"""|INSERT INTO $db.$TableWithIndexMetadata ( | $DatabaseField, $TableNameField, $IndexNameField, $IndexTypeField, $IndexedColsField, | $PKField, $DatasourceField, $OptionsField, $CrossdataVersionField |) VALUES (?,?,?,?,?,?,?,?,?) @@ -406,7 +415,8 @@ class MySQLXDCatalog(override val catalystConf: CatalystConf) prepped.execute() } else { //TODO: Support change index metadata? 
- sys.error(s"A global index already exists in table ${crossdataIndex.tableIdentifier.unquotedString}") + sys.error( + s"A global index already exists in table ${crossdataIndex.tableIdentifier.unquotedString}") } } finally { connection.setAutoCommit(true) @@ -414,7 +424,7 @@ class MySQLXDCatalog(override val catalystConf: CatalystConf) override def dropIndexMetadata(indexIdentifier: IndexIdentifierNormalized): Unit = connection.createStatement.executeUpdate( - s"DELETE FROM $db.$TableWithIndexMetadata WHERE $IndexTypeField='${indexIdentifier.indexType}' AND $IndexNameField='${indexIdentifier.indexName}'" + s"DELETE FROM $db.$TableWithIndexMetadata WHERE $IndexTypeField='${indexIdentifier.indexType}' AND $IndexNameField='${indexIdentifier.indexName}'" ) override def dropAllIndexesMetadata(): Unit = @@ -438,14 +448,20 @@ class MySQLXDCatalog(override val catalystConf: CatalystConf) val version = resultSet.getString(CrossdataVersionField) Option( - CrossdataIndex(TableIdentifierNormalized(table, Option(database)), IndexIdentifierNormalized(indexType, indexName), - deserializeSeq(indexedCols), pk, datasource, deserializeOptions(optsJSON), version) + CrossdataIndex(TableIdentifierNormalized(table, Option(database)), + IndexIdentifierNormalized(indexType, indexName), + deserializeSeq(indexedCols), + pk, + datasource, + deserializeOptions(optsJSON), + version) ) } } private def selectIndex(indexIdentifier: IndexIdentifierNormalized): ResultSet = { - val preparedStatement = connection.prepareStatement(s"SELECT * FROM $db.$TableWithIndexMetadata WHERE $IndexNameField= ? AND $IndexTypeField= ?") + val preparedStatement = connection.prepareStatement( + s"SELECT * FROM $db.$TableWithIndexMetadata WHERE $IndexNameField= ? AND $IndexTypeField= ?") preparedStatement.setString(1, indexIdentifier.indexName) preparedStatement.setString(2, indexIdentifier.indexType) preparedStatement.executeQuery() @@ -453,12 +469,15 @@ class MySQLXDCatalog(override val catalystConf: CatalystConf) override def dropIndexMetadata(tableIdentifier: TableIdentifierNormalized): Unit = connection.createStatement.executeUpdate( - s"DELETE FROM $db.$TableWithIndexMetadata WHERE $TableNameField='${tableIdentifier.table}' AND $DatabaseField='${tableIdentifier.database.getOrElse("")}'" + s"DELETE FROM $db.$TableWithIndexMetadata WHERE $TableNameField='${tableIdentifier.table}' AND $DatabaseField='${tableIdentifier.database + .getOrElse("")}'" ) - override def lookupIndexByTableIdentifier(tableIdentifier: TableIdentifierNormalized): Option[CrossdataIndex] = { + override def lookupIndexByTableIdentifier( + tableIdentifier: TableIdentifierNormalized): Option[CrossdataIndex] = { val query = - s"SELECT * FROM $db.$TableWithIndexMetadata WHERE $TableNameField='${tableIdentifier.table}' AND $DatabaseField='${tableIdentifier.database.getOrElse("")}'" + s"SELECT * FROM $db.$TableWithIndexMetadata WHERE $TableNameField='${tableIdentifier.table}' AND $DatabaseField='${tableIdentifier.database + .getOrElse("")}'" val preparedStatement = connection.prepareStatement(query) val resultSet = preparedStatement.executeQuery() if (!resultSet.next) { @@ -476,9 +495,14 @@ class MySQLXDCatalog(override val catalystConf: CatalystConf) val version = resultSet.getString(CrossdataVersionField) Option( - CrossdataIndex(TableIdentifierNormalized(table, Option(database)), IndexIdentifierNormalized(indexType, indexName), - deserializeSeq(indexedCols), pk, datasource, deserializeOptions(optsJSON), version) + CrossdataIndex(TableIdentifierNormalized(table, 
Option(database)), + IndexIdentifierNormalized(indexType, indexName), + deserializeSeq(indexedCols), + pk, + datasource, + deserializeOptions(optsJSON), + version) ) } } -} \ No newline at end of file +} diff --git a/core/src/main/scala/org/apache/spark/sql/crossdata/catalog/persistent/PersistentCatalogWithCache.scala b/core/src/main/scala/org/apache/spark/sql/crossdata/catalog/persistent/PersistentCatalogWithCache.scala index 036b82fc8..ea1dd230a 100644 --- a/core/src/main/scala/org/apache/spark/sql/crossdata/catalog/persistent/PersistentCatalogWithCache.scala +++ b/core/src/main/scala/org/apache/spark/sql/crossdata/catalog/persistent/PersistentCatalogWithCache.scala @@ -26,13 +26,13 @@ import org.apache.spark.sql.crossdata.util.CreateRelationUtil import scala.collection.mutable - /** * PersistentCatalog aims to provide a mechanism to persist the * [[org.apache.spark.sql.catalyst.analysis.Catalog]] metadata. */ -abstract class PersistentCatalogWithCache(catalystConf: CatalystConf) extends XDPersistentCatalog - with Serializable { +abstract class PersistentCatalogWithCache(catalystConf: CatalystConf) + extends XDPersistentCatalog + with Serializable { import CreateRelationUtil._ @@ -40,7 +40,8 @@ abstract class PersistentCatalogWithCache(catalystConf: CatalystConf) extends XD val viewCache: mutable.Map[TableIdentifierNormalized, LogicalPlan] = mutable.Map.empty val indexCache: mutable.Map[TableIdentifierNormalized, CrossdataIndex] = mutable.Map.empty - override final def relation(relationIdentifier: TableIdentifierNormalized)(implicit sqlContext: SQLContext): Option[LogicalPlan] = + override final def relation(relationIdentifier: TableIdentifierNormalized)( + implicit sqlContext: SQLContext): Option[LogicalPlan] = (tableCache get relationIdentifier) orElse (viewCache get relationIdentifier) orElse { logInfo(s"PersistentCatalog: Looking up table ${relationIdentifier.unquotedString}") lookupTable(relationIdentifier) map { crossdataTable => @@ -57,9 +58,12 @@ abstract class PersistentCatalogWithCache(catalystConf: CatalystConf) extends XD } } - override final def refreshCache(tableIdent: ViewIdentifierNormalized): Unit = tableCache clear + override final def refreshCache(tableIdent: ViewIdentifierNormalized): Unit = + tableCache clear - override final def saveView(viewIdentifier: ViewIdentifierNormalized, plan: LogicalPlan, sqlText: String)(implicit sqlContext:SQLContext): Unit = { + override final def saveView(viewIdentifier: ViewIdentifierNormalized, + plan: LogicalPlan, + sqlText: String)(implicit sqlContext: SQLContext): Unit = { import XDCatalogCommon._ def checkPlan(plan: LogicalPlan): Unit = { plan collect { @@ -83,7 +87,8 @@ abstract class PersistentCatalogWithCache(catalystConf: CatalystConf) extends XD } } - override final def saveTable(crossdataTable: CrossdataTable, table: LogicalPlan)(implicit sqlContext:SQLContext): Unit = { + override final def saveTable(crossdataTable: CrossdataTable, table: LogicalPlan)( + implicit sqlContext: SQLContext): Unit = { val tableIdentifier = crossdataTable.tableIdentifier if (relation(tableIdentifier)(sqlContext).isDefined) { @@ -100,7 +105,7 @@ abstract class PersistentCatalogWithCache(catalystConf: CatalystConf) extends XD val indexIdentifier = crossdataIndex.indexIdentifier - if(lookupIndex(indexIdentifier).isDefined) { + if (lookupIndex(indexIdentifier).isDefined) { logWarning(s"The index $indexIdentifier already exists") throw new UnsupportedOperationException(s"The index $indexIdentifier already exists") } else { @@ -129,9 +134,11 @@ 
abstract class PersistentCatalogWithCache(catalystConf: CatalystConf) extends XD override final def dropIndex(indexIdentifer: IndexIdentifierNormalized): Unit = { - val found: Option[(TableIdentifierNormalized, CrossdataIndex)] = indexCache find { case(key,value) => value.indexIdentifier == indexIdentifer} + val found: Option[(TableIdentifierNormalized, CrossdataIndex)] = indexCache find { + case (key, value) => value.indexIdentifier == indexIdentifer + } - if(found.isDefined) indexCache remove found.get._1 + if (found.isDefined) indexCache remove found.get._1 dropIndexMetadata(indexIdentifer) } @@ -139,7 +146,6 @@ abstract class PersistentCatalogWithCache(catalystConf: CatalystConf) extends XD override final def tableHasIndex(tableIdentifier: TableIdentifierNormalized): Boolean = indexCache.contains(tableIdentifier) - override final def dropAllViews(): Unit = { viewCache.clear dropAllViewsMetadata() @@ -155,11 +161,11 @@ abstract class PersistentCatalogWithCache(catalystConf: CatalystConf) extends XD dropAllIndexesMetadata() } - protected def schemaNotFound() = throw new RuntimeException("the schema must be non empty") + protected def schemaNotFound() = + throw new RuntimeException("the schema must be non empty") //New Methods - def lookupView(viewIdentifier: ViewIdentifierNormalized): Option[String] def persistTableMetadata(crossdataTable: CrossdataTable): Unit @@ -182,4 +188,4 @@ abstract class PersistentCatalogWithCache(catalystConf: CatalystConf) extends XD def dropAllIndexesMetadata(): Unit -} \ No newline at end of file +} diff --git a/core/src/main/scala/org/apache/spark/sql/crossdata/catalog/persistent/PostgreSQLXDCatalog.scala b/core/src/main/scala/org/apache/spark/sql/crossdata/catalog/persistent/PostgreSQLXDCatalog.scala index fa1c8c3dc..17a2a30d2 100644 --- a/core/src/main/scala/org/apache/spark/sql/crossdata/catalog/persistent/PostgreSQLXDCatalog.scala +++ b/core/src/main/scala/org/apache/spark/sql/crossdata/catalog/persistent/PostgreSQLXDCatalog.scala @@ -68,7 +68,7 @@ object PostgreSQLXDCatalog { * @param catalystConf An implementation of the [[CatalystConf]]. 
*/ class PostgreSQLXDCatalog(sqlContext: SQLContext, override val catalystConf: CatalystConf) - extends PersistentCatalogWithCache(catalystConf) { + extends PersistentCatalogWithCache(catalystConf) { import PostgreSQLXDCatalog._ import XDCatalog._ @@ -92,11 +92,10 @@ class PostgreSQLXDCatalog(sqlContext: SQLContext, override val catalystConf: Cat val jdbcConnection = DriverManager.getConnection(url, user, pass) // CREATE PERSISTENT METADATA TABLE - if(!schemaExists(db, jdbcConnection)) + if (!schemaExists(db, jdbcConnection)) jdbcConnection.createStatement().executeUpdate(s"CREATE SCHEMA $db") - jdbcConnection.createStatement().executeUpdate( - s"""|CREATE TABLE IF NOT EXISTS $db.$table ( + jdbcConnection.createStatement().executeUpdate(s"""|CREATE TABLE IF NOT EXISTS $db.$table ( |$DatabaseField VARCHAR(50), |$TableNameField VARCHAR(50), |$SchemaField TEXT, @@ -106,23 +105,26 @@ class PostgreSQLXDCatalog(sqlContext: SQLContext, override val catalystConf: Cat |$CrossdataVersionField TEXT, |PRIMARY KEY ($DatabaseField,$TableNameField))""".stripMargin) - jdbcConnection.createStatement().executeUpdate( - s"""|CREATE TABLE IF NOT EXISTS $db.$tableWithViewMetadata ( + jdbcConnection + .createStatement() + .executeUpdate(s"""|CREATE TABLE IF NOT EXISTS $db.$tableWithViewMetadata ( |$DatabaseField VARCHAR(50), |$TableNameField VARCHAR(50), |$SqlViewField TEXT, |$CrossdataVersionField VARCHAR(30), |PRIMARY KEY ($DatabaseField,$TableNameField))""".stripMargin) - jdbcConnection.createStatement().executeUpdate( - s"""|CREATE TABLE IF NOT EXISTS $db.$tableWithAppJars ( + jdbcConnection + .createStatement() + .executeUpdate(s"""|CREATE TABLE IF NOT EXISTS $db.$tableWithAppJars ( |$JarPath VARCHAR(100), |$AppAlias VARCHAR(50), |$AppClass VARCHAR(100), |PRIMARY KEY ($AppAlias))""".stripMargin) - jdbcConnection.createStatement().executeUpdate( - s"""|CREATE TABLE IF NOT EXISTS $db.$TableWithIndexMetadata ( + jdbcConnection + .createStatement() + .executeUpdate(s"""|CREATE TABLE IF NOT EXISTS $db.$TableWithIndexMetadata ( |$DatabaseField VARCHAR(50), |$TableNameField VARCHAR(50), |$IndexNameField VARCHAR(50), @@ -135,7 +137,6 @@ class PostgreSQLXDCatalog(sqlContext: SQLContext, override val catalystConf: Cat |UNIQUE ($IndexNameField, $IndexTypeField), |PRIMARY KEY ($DatabaseField,$TableNameField))""".stripMargin) - jdbcConnection } catch { case e: Exception => @@ -145,10 +146,10 @@ class PostgreSQLXDCatalog(sqlContext: SQLContext, override val catalystConf: Cat } - override def lookupTable(tableIdentifier: TableIdentifierNormalized): Option[CrossdataTable] = { - val preparedStatement = connection.prepareStatement(s"SELECT * FROM $db.$table WHERE $DatabaseField= ? AND $TableNameField= ?") + val preparedStatement = connection.prepareStatement( + s"SELECT * FROM $db.$table WHERE $DatabaseField= ? 
AND $TableNameField= ?") preparedStatement.setString(1, tableIdentifier.database.getOrElse("")) preparedStatement.setString(2, tableIdentifier.table) val resultSet = preparedStatement.executeQuery() @@ -166,19 +167,29 @@ class PostgreSQLXDCatalog(sqlContext: SQLContext, override val catalystConf: Cat val version = resultSet.getString(CrossdataVersionField) Some( - CrossdataTable(TableIdentifierNormalized(table, Some(database)), Option(deserializeUserSpecifiedSchema(schemaJSON)), datasource, deserializePartitionColumn(partitionColumn), deserializeOptions(optsJSON), version) + CrossdataTable(TableIdentifierNormalized(table, Some(database)), + Option(deserializeUserSpecifiedSchema(schemaJSON)), + datasource, + deserializePartitionColumn(partitionColumn), + deserializeOptions(optsJSON), + version) ) } } - - override def allRelations(databaseName: Option[StringNormalized]): Seq[TableIdentifierNormalized] = { + override def allRelations( + databaseName: Option[StringNormalized]): Seq[TableIdentifierNormalized] = { @tailrec - def getSequenceAux(resultset: ResultSet, next: Boolean, set: Set[TableIdentifierNormalized] = Set.empty): Set[TableIdentifierNormalized] = { + def getSequenceAux( + resultset: ResultSet, + next: Boolean, + set: Set[TableIdentifierNormalized] = Set.empty): Set[TableIdentifierNormalized] = { if (next) { val database = resultset.getString(DatabaseField) val table = resultset.getString(TableNameField) - val tableId = if (database.trim.isEmpty) TableIdentifierNormalized(table) else TableIdentifierNormalized(table, Option(database)) + val tableId = + if (database.trim.isEmpty) TableIdentifierNormalized(table) + else TableIdentifierNormalized(table, Option(database)) getSequenceAux(resultset, resultset.next(), set + tableId) } else { set @@ -186,8 +197,10 @@ class PostgreSQLXDCatalog(sqlContext: SQLContext, override val catalystConf: Cat } val statement = connection.createStatement - val dbFilter = databaseName.fold("")(dbName => s"WHERE $DatabaseField ='${dbName.normalizedString}'") - val resultSet = statement.executeQuery(s"SELECT $DatabaseField, $TableNameField FROM $db.$table $dbFilter") + val dbFilter = + databaseName.fold("")(dbName => s"WHERE $DatabaseField ='${dbName.normalizedString}'") + val resultSet = + statement.executeQuery(s"SELECT $DatabaseField, $TableNameField FROM $db.$table $dbFilter") getSequenceAux(resultSet, resultSet.next).toSeq } @@ -201,7 +214,8 @@ class PostgreSQLXDCatalog(sqlContext: SQLContext, override val catalystConf: Cat connection.setAutoCommit(false) // check if the database-table exist in the persisted catalog - val preparedStatement = connection.prepareStatement(s"SELECT * FROM $db.$table WHERE $DatabaseField= ? AND $TableNameField= ?") + val preparedStatement = connection.prepareStatement( + s"SELECT * FROM $db.$table WHERE $DatabaseField= ? AND $TableNameField= ?") preparedStatement.setString(1, crossdataTable.tableIdentifier.database.getOrElse("")) preparedStatement.setString(2, crossdataTable.tableIdentifier.table) val resultSet = preparedStatement.executeQuery() @@ -209,8 +223,7 @@ class PostgreSQLXDCatalog(sqlContext: SQLContext, override val catalystConf: Cat if (!resultSet.isBeforeFirst) { resultSet.close() - val prepped = connection.prepareStatement( - s"""|INSERT INTO $db.$table ( + val prepped = connection.prepareStatement(s"""|INSERT INTO $db.$table ( | $DatabaseField, $TableNameField, $SchemaField, $DatasourceField, $PartitionColumnField, $OptionsField, $CrossdataVersionField |) VALUES (?,?,?,?,?,?,?) 
""".stripMargin) @@ -223,12 +236,12 @@ class PostgreSQLXDCatalog(sqlContext: SQLContext, override val catalystConf: Cat prepped.setString(7, CrossdataVersion) prepped.execute() prepped.close() - } - else { + } else { resultSet.close() val prepped = connection.prepareStatement( - s"""|UPDATE $db.$table SET $SchemaField=?, $DatasourceField=?,$PartitionColumnField=?,$OptionsField=?,$CrossdataVersionField=? - |WHERE $DatabaseField='${crossdataTable.tableIdentifier.database.getOrElse("")}' AND $TableNameField='${crossdataTable.tableIdentifier.table}'; + s"""|UPDATE $db.$table SET $SchemaField=?, $DatasourceField=?,$PartitionColumnField=?,$OptionsField=?,$CrossdataVersionField=? + |WHERE $DatabaseField='${crossdataTable.tableIdentifier.database + .getOrElse("")}' AND $TableNameField='${crossdataTable.tableIdentifier.table}'; """.stripMargin.replaceAll("\n", " ")) prepped.setString(1, tableSchema) @@ -243,15 +256,18 @@ class PostgreSQLXDCatalog(sqlContext: SQLContext, override val catalystConf: Cat connection.setAutoCommit(true) } - override def dropTableMetadata(tableIdentifier: ViewIdentifierNormalized): Unit = - connection.createStatement.executeUpdate(s"DELETE FROM $db.$table WHERE tableName='${tableIdentifier.table}' AND db='${tableIdentifier.database.getOrElse("")}'") + connection.createStatement.executeUpdate( + s"DELETE FROM $db.$table WHERE tableName='${tableIdentifier.table}' AND db='${tableIdentifier.database + .getOrElse("")}'") - override def dropAllTablesMetadata(): Unit = connection.createStatement.executeUpdate(s"TRUNCATE $db.$table") + override def dropAllTablesMetadata(): Unit = + connection.createStatement.executeUpdate(s"TRUNCATE $db.$table") def schemaExists(schema: String, connection: Connection): Boolean = { val statement = connection.createStatement() - val result = statement.executeQuery(s"SELECT schema_name FROM information_schema.schemata WHERE schema_name = '$schema';") + val result = statement.executeQuery( + s"SELECT schema_name FROM information_schema.schemata WHERE schema_name = '$schema';") result.isBeforeFirst } @@ -265,15 +281,15 @@ class PostgreSQLXDCatalog(sqlContext: SQLContext, override val catalystConf: Cat } } - override def persistViewMetadata(tableIdentifier: TableIdentifierNormalized, sqlText: String): Unit = { + override def persistViewMetadata(tableIdentifier: TableIdentifierNormalized, + sqlText: String): Unit = { try { connection.setAutoCommit(false) val resultSet = selectMetadata(tableWithViewMetadata, tableIdentifier) if (!resultSet.isBeforeFirst) { resultSet.close() - val prepped = connection.prepareStatement( - s"""|INSERT INTO $db.$tableWithViewMetadata ( + val prepped = connection.prepareStatement(s"""|INSERT INTO $db.$tableWithViewMetadata ( | $DatabaseField, $TableNameField, $SqlViewField, $CrossdataVersionField |) VALUES (?,?,?,?) """.stripMargin) @@ -285,10 +301,10 @@ class PostgreSQLXDCatalog(sqlContext: SQLContext, override val catalystConf: Cat prepped.close() } else { resultSet.close() - val prepped = - connection.prepareStatement( + val prepped = connection.prepareStatement( s"""|UPDATE $db.$tableWithViewMetadata SET $SqlViewField=? 
- |WHERE $DatabaseField='${tableIdentifier.database.getOrElse("")}' AND $TableNameField='${tableIdentifier.table}' + |WHERE $DatabaseField='${tableIdentifier.database + .getOrElse("")}' AND $TableNameField='${tableIdentifier.table}' """.stripMargin.replaceAll("\n", " ")) prepped.setString(1, sqlText) @@ -302,41 +318,41 @@ class PostgreSQLXDCatalog(sqlContext: SQLContext, override val catalystConf: Cat } } - private def selectMetadata(targetTable: String, tableIdentifier: TableIdentifierNormalized): ResultSet = { + private def selectMetadata(targetTable: String, + tableIdentifier: TableIdentifierNormalized): ResultSet = { - val preparedStatement = connection.prepareStatement(s"SELECT * FROM $db.$targetTable WHERE $DatabaseField= ? AND $TableNameField= ?") + val preparedStatement = connection.prepareStatement( + s"SELECT * FROM $db.$targetTable WHERE $DatabaseField= ? AND $TableNameField= ?") preparedStatement.setString(1, tableIdentifier.database.getOrElse("")) preparedStatement.setString(2, tableIdentifier.table) preparedStatement.executeQuery() - } override def dropViewMetadata(viewIdentifier: ViewIdentifierNormalized): Unit = { connection.createStatement.executeUpdate( - s"DELETE FROM $db.$tableWithViewMetadata WHERE tableName='${viewIdentifier.table}' AND db='${viewIdentifier.database.getOrElse("")}'") + s"DELETE FROM $db.$tableWithViewMetadata WHERE tableName='${viewIdentifier.table}' AND db='${viewIdentifier.database + .getOrElse("")}'") } - override def dropAllViewsMetadata(): Unit = { connection.createStatement.executeUpdate(s"DELETE FROM $db.$tableWithViewMetadata") } - override def saveAppMetadata(crossdataApp: CrossdataApp): Unit = try { connection.setAutoCommit(false) - val preparedStatement = connection.prepareStatement(s"SELECT * FROM $db.$tableWithAppJars WHERE $AppAlias= ?") + val preparedStatement = + connection.prepareStatement(s"SELECT * FROM $db.$tableWithAppJars WHERE $AppAlias= ?") preparedStatement.setString(1, crossdataApp.appAlias) val resultSet = preparedStatement.executeQuery() preparedStatement.close() if (!resultSet.next()) { resultSet.close() - val prepped = connection.prepareStatement( - s"""|INSERT INTO $db.$tableWithAppJars ( + val prepped = connection.prepareStatement(s"""|INSERT INTO $db.$tableWithAppJars ( | $JarPath, $AppAlias, $AppClass |) VALUES (?,?,?) """.stripMargin) @@ -347,8 +363,8 @@ class PostgreSQLXDCatalog(sqlContext: SQLContext, override val catalystConf: Cat prepped.close() } else { resultSet.close() - val prepped = connection.prepareStatement( - s"""|UPDATE $db.$tableWithAppJars SET $JarPath=?, $AppClass=? + val prepped = + connection.prepareStatement(s"""|UPDATE $db.$tableWithAppJars SET $JarPath=?, $AppClass=? 
|WHERE $AppAlias='${crossdataApp.appAlias}' """.stripMargin) prepped.setString(1, crossdataApp.jar) @@ -363,7 +379,8 @@ class PostgreSQLXDCatalog(sqlContext: SQLContext, override val catalystConf: Cat override def getApp(alias: String): Option[CrossdataApp] = { - val preparedStatement = connection.prepareStatement(s"SELECT * FROM $db.$tableWithAppJars WHERE $AppAlias= ?") + val preparedStatement = + connection.prepareStatement(s"SELECT * FROM $db.$tableWithAppJars WHERE $AppAlias= ?") preparedStatement.setString(1, alias) val resultSet = preparedStatement.executeQuery() @@ -379,14 +396,13 @@ class PostgreSQLXDCatalog(sqlContext: SQLContext, override val catalystConf: Cat resultSet.close() preparedStatement.close() Some( - CrossdataApp(jar, alias, clss) + CrossdataApp(jar, alias, clss) ) } } override def isAvailable: Boolean = Option(connection).isDefined - override def persistIndexMetadata(crossdataIndex: CrossdataIndex): Unit = try { connection.setAutoCommit(false) @@ -397,8 +413,7 @@ class PostgreSQLXDCatalog(sqlContext: SQLContext, override val catalystConf: Cat val serializedOptions = serializeOptions(crossdataIndex.opts) if (!resultSet.next()) { - val prepped = connection.prepareStatement( - s"""|INSERT INTO $db.$TableWithIndexMetadata ( + val prepped = connection.prepareStatement(s"""|INSERT INTO $db.$TableWithIndexMetadata ( | $DatabaseField, $TableNameField, $IndexNameField, $IndexTypeField, $IndexedColsField, | $PKField, $DatasourceField, $OptionsField, $CrossdataVersionField |) VALUES (?,?,?,?,?,?,?,?,?) @@ -415,16 +430,16 @@ class PostgreSQLXDCatalog(sqlContext: SQLContext, override val catalystConf: Cat prepped.execute() } else { //TODO: Support change index metadata? - sys.error(s"A global index already exists in table ${crossdataIndex.tableIdentifier.unquotedString}") + sys.error( + s"A global index already exists in table ${crossdataIndex.tableIdentifier.unquotedString}") } } finally { connection.setAutoCommit(true) } - override def dropIndexMetadata(indexIdentifier: IndexIdentifierNormalized): Unit = connection.createStatement.executeUpdate( - s"DELETE FROM $db.$TableWithIndexMetadata WHERE $IndexTypeField='${indexIdentifier.indexType}' AND $IndexNameField='${indexIdentifier.indexName}'" + s"DELETE FROM $db.$TableWithIndexMetadata WHERE $IndexTypeField='${indexIdentifier.indexType}' AND $IndexNameField='${indexIdentifier.indexName}'" ) override def dropAllIndexesMetadata(): Unit = @@ -448,14 +463,20 @@ class PostgreSQLXDCatalog(sqlContext: SQLContext, override val catalystConf: Cat val version = resultSet.getString(CrossdataVersionField) Option( - CrossdataIndex(TableIdentifierNormalized(table, Option(database)), IndexIdentifierNormalized(indexType, indexName), - deserializeSeq(indexedCols), pk, datasource, deserializeOptions(optsJSON), version) + CrossdataIndex(TableIdentifierNormalized(table, Option(database)), + IndexIdentifierNormalized(indexType, indexName), + deserializeSeq(indexedCols), + pk, + datasource, + deserializeOptions(optsJSON), + version) ) } } private def selectIndex(indexIdentifier: IndexIdentifierNormalized): ResultSet = { - val preparedStatement = connection.prepareStatement(s"SELECT * FROM $db.$TableWithIndexMetadata WHERE $IndexNameField= ? AND $IndexTypeField= ?") + val preparedStatement = connection.prepareStatement( + s"SELECT * FROM $db.$TableWithIndexMetadata WHERE $IndexNameField= ? 
AND $IndexTypeField= ?") preparedStatement.setString(1, indexIdentifier.indexName) preparedStatement.setString(2, indexIdentifier.indexType) preparedStatement.executeQuery() @@ -463,12 +484,15 @@ class PostgreSQLXDCatalog(sqlContext: SQLContext, override val catalystConf: Cat override def dropIndexMetadata(tableIdentifier: TableIdentifierNormalized): Unit = connection.createStatement.executeUpdate( - s"DELETE FROM $db.$TableWithIndexMetadata WHERE $TableNameField='${tableIdentifier.table}' AND $DatabaseField='${tableIdentifier.database.getOrElse("")}'" + s"DELETE FROM $db.$TableWithIndexMetadata WHERE $TableNameField='${tableIdentifier.table}' AND $DatabaseField='${tableIdentifier.database + .getOrElse("")}'" ) - override def lookupIndexByTableIdentifier(tableIdentifier: TableIdentifierNormalized): Option[CrossdataIndex] = { + override def lookupIndexByTableIdentifier( + tableIdentifier: TableIdentifierNormalized): Option[CrossdataIndex] = { val query = - s"SELECT * FROM $db.$TableWithIndexMetadata WHERE $TableNameField='${tableIdentifier.table}' AND $DatabaseField='${tableIdentifier.database.getOrElse("")}'" + s"SELECT * FROM $db.$TableWithIndexMetadata WHERE $TableNameField='${tableIdentifier.table}' AND $DatabaseField='${tableIdentifier.database + .getOrElse("")}'" val preparedStatement = connection.prepareStatement(query) val resultSet = preparedStatement.executeQuery() if (!resultSet.next) { @@ -486,9 +510,14 @@ class PostgreSQLXDCatalog(sqlContext: SQLContext, override val catalystConf: Cat val version = resultSet.getString(CrossdataVersionField) Option( - CrossdataIndex(TableIdentifierNormalized(table, Option(database)), IndexIdentifierNormalized(indexType, indexName), - deserializeSeq(indexedCols), pk, datasource, deserializeOptions(optsJSON), version) + CrossdataIndex(TableIdentifierNormalized(table, Option(database)), + IndexIdentifierNormalized(indexType, indexName), + deserializeSeq(indexedCols), + pk, + datasource, + deserializeOptions(optsJSON), + version) ) } } -} \ No newline at end of file +} diff --git a/core/src/main/scala/org/apache/spark/sql/crossdata/catalog/persistent/ZookeeperCatalog.scala b/core/src/main/scala/org/apache/spark/sql/crossdata/catalog/persistent/ZookeeperCatalog.scala index 3d9705fde..41ccbf048 100644 --- a/core/src/main/scala/org/apache/spark/sql/crossdata/catalog/persistent/ZookeeperCatalog.scala +++ b/core/src/main/scala/org/apache/spark/sql/crossdata/catalog/persistent/ZookeeperCatalog.scala @@ -34,7 +34,7 @@ import scala.util.Try * @param catalystConf An implementation of the [[CatalystConf]]. 
*/ class ZookeeperCatalog(override val catalystConf: CatalystConf) - extends PersistentCatalogWithCache(catalystConf){ + extends PersistentCatalogWithCache(catalystConf) { import XDCatalog._ @@ -43,21 +43,22 @@ class ZookeeperCatalog(override val catalystConf: CatalystConf) @transient val appDAO = new AppTypesafeDAO(XDContext.catalogConfig) @transient val indexDAO = new IndexTypesafeDAO(XDContext.catalogConfig) - override def lookupTable(tableIdentifier: TableIdentifierNormalized): Option[CrossdataTable] = { if (tableDAO.dao.count > 0) { - val findTable = tableDAO.dao.getAll() + val findTable = tableDAO.dao + .getAll() .find(tableModel => - tableModel.name == tableIdentifier.table && tableModel.database == tableIdentifier.database) + tableModel.name == tableIdentifier.table && tableModel.database == tableIdentifier.database) findTable match { case Some(zkTable) => - Option(CrossdataTable(TableIdentifierNormalized(zkTable.name, zkTable.database), - Option(deserializeUserSpecifiedSchema(zkTable.schema)), - zkTable.dataSource, - zkTable.partitionColumns.toArray, - zkTable.options, - zkTable.version)) + Option( + CrossdataTable(TableIdentifierNormalized(zkTable.name, zkTable.database), + Option(deserializeUserSpecifiedSchema(zkTable.schema)), + zkTable.dataSource, + zkTable.partitionColumns.toArray, + zkTable.options, + zkTable.version)) case None => tableDAO.logger.warn("Table doesn't exist") None @@ -68,18 +69,13 @@ class ZookeeperCatalog(override val catalystConf: CatalystConf) } } - override def getApp(alias: String): Option[CrossdataApp] = { if (appDAO.dao.count > 0) { - val findApp = appDAO.dao.getAll() - .find(appModel => - appModel.appAlias == alias) + val findApp = appDAO.dao.getAll().find(appModel => appModel.appAlias == alias) findApp match { case Some(zkApp) => - Option(CrossdataApp(zkApp.jar, - zkApp.appAlias, - zkApp.appClass)) + Option(CrossdataApp(zkApp.jar, zkApp.appAlias, zkApp.appClass)) case None => appDAO.logger.warn("App doesn't exist") None @@ -90,14 +86,17 @@ class ZookeeperCatalog(override val catalystConf: CatalystConf) } } - - override def allRelations(databaseName: Option[StringNormalized]): Seq[TableIdentifierNormalized] = { + override def allRelations( + databaseName: Option[StringNormalized]): Seq[TableIdentifierNormalized] = { if (tableDAO.dao.count > 0) { - tableDAO.dao.getAll() + tableDAO.dao + .getAll() .flatMap(tableModel => { - databaseName.fold(Option(TableIdentifierNormalized(tableModel.name, tableModel.database))) { dbName => + databaseName.fold( + Option(TableIdentifierNormalized(tableModel.name, tableModel.database))) { dbName => tableModel.database.flatMap(dbNameModel => { - if (dbName.normalizedString == dbNameModel) Option(TableIdentifierNormalized(tableModel.name, tableModel.database)) + if (dbName.normalizedString == dbNameModel) + Option(TableIdentifierNormalized(tableModel.name, tableModel.database)) else None }) } @@ -111,36 +110,31 @@ class ZookeeperCatalog(override val catalystConf: CatalystConf) override def persistTableMetadata(crossdataTable: CrossdataTable): Unit = { val tableId = createId - tableDAO.dao.create(tableId, - TableModel(tableId, - crossdataTable.tableIdentifier.table, - serializeSchema(crossdataTable.schema.getOrElse(schemaNotFound())), - crossdataTable.datasource, - crossdataTable.tableIdentifier.database, - crossdataTable.partitionColumn, - crossdataTable.opts)) + tableDAO.dao.create( + tableId, + TableModel(tableId, + crossdataTable.tableIdentifier.table, + 
serializeSchema(crossdataTable.schema.getOrElse(schemaNotFound())), + crossdataTable.datasource, + crossdataTable.tableIdentifier.database, + crossdataTable.partitionColumn, + crossdataTable.opts)) } - override def saveAppMetadata(crossdataApp: CrossdataApp): Unit = { val appId = createId - appDAO.dao.create(appId, - AppModel( - crossdataApp.jar, - crossdataApp.appAlias, - crossdataApp.appClass)) + appDAO.dao + .create(appId, AppModel(crossdataApp.jar, crossdataApp.appAlias, crossdataApp.appClass)) } - override def dropTableMetadata(tableIdentifier: ViewIdentifierNormalized): Unit = - tableDAO.dao.getAll().filter { - tableModel => tableIdentifier.table == tableModel.name && tableIdentifier.database == tableModel.database + tableDAO.dao.getAll().filter { tableModel => + tableIdentifier.table == tableModel.name && tableIdentifier.database == tableModel.database } foreach { tableModel => tableDAO.dao.delete(tableModel.id) } - override def dropAllTablesMetadata(): Unit = { tableDAO.dao.deleteAll viewDAO.dao.getAll.foreach(view => viewDAO.dao.delete(view.id)) @@ -148,8 +142,10 @@ class ZookeeperCatalog(override val catalystConf: CatalystConf) override def lookupView(viewIdentifier: ViewIdentifierNormalized): Option[String] = { if (viewDAO.dao.count > 0) { - val findView = viewDAO.dao.getAll() - .find(viewModel => viewModel.name == viewIdentifier.table && viewModel.database == viewIdentifier.database) + val findView = viewDAO.dao + .getAll() + .find(viewModel => + viewModel.name == viewIdentifier.table && viewModel.database == viewIdentifier.database) findView match { case Some(zkView) => @@ -164,20 +160,20 @@ class ZookeeperCatalog(override val catalystConf: CatalystConf) } } - override def persistViewMetadata(tableIdentifier: TableIdentifierNormalized, sqlText: String): Unit = { + override def persistViewMetadata(tableIdentifier: TableIdentifierNormalized, + sqlText: String): Unit = { val viewId = createId - viewDAO.dao.create(viewId, ViewModel(viewId, tableIdentifier.table, tableIdentifier.database, sqlText)) + viewDAO.dao + .create(viewId, ViewModel(viewId, tableIdentifier.table, tableIdentifier.database, sqlText)) } - override def dropViewMetadata(viewIdentifier: ViewIdentifierNormalized): Unit = - viewDAO.dao.getAll().filter { - view => view.name == viewIdentifier.table && view.database == viewIdentifier.database + viewDAO.dao.getAll().filter { view => + view.name == viewIdentifier.table && view.database == viewIdentifier.database } foreach { selectedView => viewDAO.dao.delete(selectedView.id) } - override def dropAllViewsMetadata(): Unit = viewDAO.dao.deleteAll override def isAvailable: Boolean = { @@ -197,18 +193,22 @@ class ZookeeperCatalog(override val catalystConf: CatalystConf) } override def dropIndexMetadata(indexIdentifier: IndexIdentifierNormalized): Unit = - indexDAO.dao.getAll().filter( - index => index.crossdataIndex.indexIdentifier == indexIdentifier - ) foreach (selectedIndex => indexDAO.dao.delete(selectedIndex.indexId)) + indexDAO.dao + .getAll() + .filter( + index => index.crossdataIndex.indexIdentifier == indexIdentifier + ) foreach (selectedIndex => indexDAO.dao.delete(selectedIndex.indexId)) override def dropAllIndexesMetadata(): Unit = indexDAO.dao.deleteAll override def lookupIndex(indexIdentifier: IndexIdentifierNormalized): Option[CrossdataIndex] = { if (indexDAO.dao.count > 0) { - val res = indexDAO.dao.getAll().find( - _.crossdataIndex.indexIdentifier == indexIdentifier - ) map (_.crossdataIndex) + val res = indexDAO.dao + .getAll() + .find( + 
_.crossdataIndex.indexIdentifier == indexIdentifier + ) map (_.crossdataIndex) if (res.isEmpty) indexDAO.logger.warn("Index path doesn't exist") res @@ -220,15 +220,20 @@ class ZookeeperCatalog(override val catalystConf: CatalystConf) } override def dropIndexMetadata(tableIdentifier: TableIdentifierNormalized): Unit = - indexDAO.dao.getAll().filter( - index => index.crossdataIndex.tableIdentifier == tableIdentifier - ) foreach (selectedIndex => indexDAO.dao.delete(selectedIndex.indexId)) - - override def lookupIndexByTableIdentifier(tableIdentifier: TableIdentifierNormalized): Option[CrossdataIndex] = { + indexDAO.dao + .getAll() + .filter( + index => index.crossdataIndex.tableIdentifier == tableIdentifier + ) foreach (selectedIndex => indexDAO.dao.delete(selectedIndex.indexId)) + + override def lookupIndexByTableIdentifier( + tableIdentifier: TableIdentifierNormalized): Option[CrossdataIndex] = { if (indexDAO.dao.count > 0) { - val res = indexDAO.dao.getAll().find( - _.crossdataIndex.tableIdentifier == tableIdentifier - ) map (_.crossdataIndex) + val res = indexDAO.dao + .getAll() + .find( + _.crossdataIndex.tableIdentifier == tableIdentifier + ) map (_.crossdataIndex) if (res.isEmpty) indexDAO.logger.warn("Index path doesn't exist") res } else { @@ -236,4 +241,4 @@ class ZookeeperCatalog(override val catalystConf: CatalystConf) None } } -} \ No newline at end of file +} diff --git a/core/src/main/scala/org/apache/spark/sql/crossdata/catalog/streaming/ZookeeperStreamingCatalog.scala b/core/src/main/scala/org/apache/spark/sql/crossdata/catalog/streaming/ZookeeperStreamingCatalog.scala index f68b323f7..26709138a 100644 --- a/core/src/main/scala/org/apache/spark/sql/crossdata/catalog/streaming/ZookeeperStreamingCatalog.scala +++ b/core/src/main/scala/org/apache/spark/sql/crossdata/catalog/streaming/ZookeeperStreamingCatalog.scala @@ -31,18 +31,19 @@ import scala.concurrent.duration._ import scala.concurrent.{Await, Future} import scala.util.Try -class ZookeeperStreamingCatalog(val catalystConf: CatalystConf, serverConfig: Config) extends XDStreamingCatalog { +class ZookeeperStreamingCatalog(val catalystConf: CatalystConf, serverConfig: Config) + extends XDStreamingCatalog { private[spark] val streamingConfig = serverConfig.getConfig(CoreConfig.StreamingConfigKey) - private[spark] val ephemeralTableDAO = - new EphemeralTableTypesafeDAO(streamingConfig.getConfig(CoreConfig.CatalogConfigKey)) - private[spark] val ephemeralQueriesDAO = - new EphemeralQueriesTypesafeDAO(streamingConfig.getConfig(CoreConfig.CatalogConfigKey)) - private[spark] val ephemeralTableStatusDAO = - new EphemeralTableStatusTypesafeDAO(streamingConfig.getConfig(CoreConfig.CatalogConfigKey)) - - - override def relation(tableIdent: TableIdentifierNormalized)(implicit sqlContext: SQLContext): Option[LogicalPlan] = { + private[spark] val ephemeralTableDAO = new EphemeralTableTypesafeDAO( + streamingConfig.getConfig(CoreConfig.CatalogConfigKey)) + private[spark] val ephemeralQueriesDAO = new EphemeralQueriesTypesafeDAO( + streamingConfig.getConfig(CoreConfig.CatalogConfigKey)) + private[spark] val ephemeralTableStatusDAO = new EphemeralTableStatusTypesafeDAO( + streamingConfig.getConfig(CoreConfig.CatalogConfigKey)) + + override def relation(tableIdent: TableIdentifierNormalized)( + implicit sqlContext: SQLContext): Option[LogicalPlan] = { import XDCatalogCommon._ val tableIdentifier: String = stringifyTableIdentifierNormalized(tableIdent) if (futurize(existsEphemeralTable(tableIdentifier))) @@ -55,40 +56,46 @@ class 
ZookeeperStreamingCatalog(val catalystConf: CatalystConf, serverConfig: Co override def isAvailable: Boolean = true // TODO It must not return the relations until the catalog can distinguish between real/ephemeral tables - override def allRelations(databaseName: Option[StringNormalized]): Seq[TableIdentifierNormalized] = Seq.empty + override def allRelations( + databaseName: Option[StringNormalized]): Seq[TableIdentifierNormalized] = + Seq.empty private def futurize[P](operation: => P): P = Await.result(Future(operation), 5 seconds) /** - * Ephemeral Table Functions - */ + * Ephemeral Table Functions + */ override def existsEphemeralTable(tableIdentifier: String): Boolean = futurize(ephemeralTableDAO.dao.exists(tableIdentifier)) override def getEphemeralTable(tableIdentifier: String): Option[EphemeralTableModel] = futurize(ephemeralTableDAO.dao.get(tableIdentifier)) - override def createEphemeralTable(ephemeralTable: EphemeralTableModel): Either[String, EphemeralTableModel] = + override def createEphemeralTable( + ephemeralTable: EphemeralTableModel): Either[String, EphemeralTableModel] = if (!existsEphemeralTable(ephemeralTable.name)) { - createEphemeralStatus(ephemeralTable.name, EphemeralStatusModel(ephemeralTable.name, EphemeralExecutionStatus.NotStarted)) + createEphemeralStatus( + ephemeralTable.name, + EphemeralStatusModel(ephemeralTable.name, EphemeralExecutionStatus.NotStarted)) Right(ephemeralTableDAO.dao.upsert(ephemeralTable.name, ephemeralTable)) - } - else Left("Ephemeral table exists") - + } else Left("Ephemeral table exists") override def dropEphemeralTable(tableIdentifier: String): Unit = { val isRunning = ephemeralTableStatusDAO.dao.get(tableIdentifier).map { tableStatus => tableStatus.status == EphemeralExecutionStatus.Started || tableStatus.status == EphemeralExecutionStatus.Starting } getOrElse notFound(tableIdentifier) - if (isRunning) throw new RuntimeException("The ephemeral is running. The process should be stopped first using 'Stop '") + if (isRunning) + throw new RuntimeException( + "The ephemeral is running. 
The process should be stopped first using 'Stop '") ephemeralTableDAO.dao.delete(tableIdentifier) ephemeralTableStatusDAO.dao.delete(tableIdentifier) - ephemeralQueriesDAO.dao.getAll().filter(_.ephemeralTableName == tableIdentifier) foreach { query => - ephemeralQueriesDAO.dao.delete(query.alias) + ephemeralQueriesDAO.dao.getAll().filter(_.ephemeralTableName == tableIdentifier) foreach { + query => + ephemeralQueriesDAO.dao.delete(query.alias) } } @@ -104,14 +111,14 @@ class ZookeeperStreamingCatalog(val catalystConf: CatalystConf, serverConfig: Co override def getAllEphemeralTables: Seq[EphemeralTableModel] = ephemeralTableDAO.dao.getAll() - /** - * Ephemeral Queries Functions - */ + * Ephemeral Queries Functions + */ override def existsEphemeralQuery(queryAlias: String): Boolean = ephemeralQueriesDAO.dao.exists(queryAlias) - override def createEphemeralQuery(ephemeralQuery: EphemeralQueryModel): Either[String, EphemeralQueryModel] = + override def createEphemeralQuery( + ephemeralQuery: EphemeralQueryModel): Either[String, EphemeralQueryModel] = if (!existsEphemeralQuery(ephemeralQuery.alias)) Right(ephemeralQueriesDAO.dao.upsert(ephemeralQuery.alias, ephemeralQuery)) else Left("Ephemeral query exists") @@ -125,13 +132,15 @@ class ZookeeperStreamingCatalog(val catalystConf: CatalystConf, serverConfig: Co override def dropEphemeralQuery(queryAlias: String): Unit = ephemeralQueriesDAO.dao.delete(queryAlias) - override def dropAllEphemeralQueries(): Unit = ephemeralQueriesDAO.dao.deleteAll + override def dropAllEphemeralQueries(): Unit = + ephemeralQueriesDAO.dao.deleteAll /** - * Ephemeral Status Functions - */ - override def createEphemeralStatus(tableIdentifier: String, - ephemeralStatusModel: EphemeralStatusModel): EphemeralStatusModel = + * Ephemeral Status Functions + */ + override def createEphemeralStatus( + tableIdentifier: String, + ephemeralStatusModel: EphemeralStatusModel): EphemeralStatusModel = ephemeralTableStatusDAO.dao.upsert(tableIdentifier, ephemeralStatusModel) override def getEphemeralStatus(tableIdentifier: String): Option[EphemeralStatusModel] = diff --git a/core/src/main/scala/org/apache/spark/sql/crossdata/catalog/temporary/HashmapCatalog.scala b/core/src/main/scala/org/apache/spark/sql/crossdata/catalog/temporary/HashmapCatalog.scala index 501a98b95..09bcbe866 100644 --- a/core/src/main/scala/org/apache/spark/sql/crossdata/catalog/temporary/HashmapCatalog.scala +++ b/core/src/main/scala/org/apache/spark/sql/crossdata/catalog/temporary/HashmapCatalog.scala @@ -22,7 +22,8 @@ import scala.collection.mutable class HashmapCatalog(override val catalystConf: CatalystConf) extends MapCatalog(catalystConf) { - override protected def newMap: mutable.Map[String, LogicalPlan] = new mutable.HashMap[String, LogicalPlan] + override protected def newMap: mutable.Map[String, LogicalPlan] = + new mutable.HashMap[String, LogicalPlan] override def isAvailable: Boolean = true -} \ No newline at end of file +} diff --git a/core/src/main/scala/org/apache/spark/sql/crossdata/catalog/temporary/MapCatalog.scala b/core/src/main/scala/org/apache/spark/sql/crossdata/catalog/temporary/MapCatalog.scala index cc6a3433a..60f5ccaf4 100644 --- a/core/src/main/scala/org/apache/spark/sql/crossdata/catalog/temporary/MapCatalog.scala +++ b/core/src/main/scala/org/apache/spark/sql/crossdata/catalog/temporary/MapCatalog.scala @@ -33,12 +33,15 @@ abstract class MapCatalog(catalystConf: CatalystConf) extends XDTemporaryCatalog private val tables: mutable.Map[String, LogicalPlan] = newMap private val views: 
mutable.Map[String, LogicalPlan] = newMap - implicit def tableIdent2string(tident: TableIdentifierNormalized): String = XDCatalogCommon.stringifyTableIdentifierNormalized(tident) + implicit def tableIdent2string(tident: TableIdentifierNormalized): String = + XDCatalogCommon.stringifyTableIdentifierNormalized(tident) - override def relation(tableIdent: TableIdentifierNormalized)(implicit sqlContext: SQLContext): Option[LogicalPlan] = + override def relation(tableIdent: TableIdentifierNormalized)( + implicit sqlContext: SQLContext): Option[LogicalPlan] = (tables get tableIdent) orElse (views get tableIdent) - override def allRelations(databaseName: Option[StringNormalized]): Seq[TableIdentifierNormalized] = { + override def allRelations( + databaseName: Option[StringNormalized]): Seq[TableIdentifierNormalized] = { (tables ++ views).toSeq collect { case (k, _) if databaseName.map(_.normalizedString == k.split("\\.")(0)).getOrElse(true) => k.split("\\.") match { @@ -48,21 +51,19 @@ abstract class MapCatalog(catalystConf: CatalystConf) extends XDTemporaryCatalog } } - override def saveTable( - tableIdentifier: TableIdentifierNormalized, - plan: LogicalPlan, - crossdataTable: Option[CrossdataTable] = None): Unit = { + override def saveTable(tableIdentifier: TableIdentifierNormalized, + plan: LogicalPlan, + crossdataTable: Option[CrossdataTable] = None): Unit = { views get tableIdentifier foreach (_ => dropView(tableIdentifier)) - tables put(tableIdentifier, plan) + tables put (tableIdentifier, plan) } - override def saveView( - viewIdentifier: ViewIdentifierNormalized, - plan: LogicalPlan, - query: Option[String] = None): Unit = { + override def saveView(viewIdentifier: ViewIdentifierNormalized, + plan: LogicalPlan, + query: Option[String] = None): Unit = { tables get viewIdentifier foreach (_ => dropTable(viewIdentifier)) - views put(viewIdentifier, plan) + views put (viewIdentifier, plan) } override def dropView(viewIdentifier: ViewIdentifierNormalized): Unit = diff --git a/core/src/main/scala/org/apache/spark/sql/crossdata/catalog/temporary/XDTemporaryCatalogWithInvalidation.scala b/core/src/main/scala/org/apache/spark/sql/crossdata/catalog/temporary/XDTemporaryCatalogWithInvalidation.scala index a7a498a9a..14402b702 100644 --- a/core/src/main/scala/org/apache/spark/sql/crossdata/catalog/temporary/XDTemporaryCatalogWithInvalidation.scala +++ b/core/src/main/scala/org/apache/spark/sql/crossdata/catalog/temporary/XDTemporaryCatalogWithInvalidation.scala @@ -30,19 +30,20 @@ import org.apache.spark.sql.crossdata.catalog.interfaces.XDTemporaryCatalog * @param invalidator Cache invalidation implementation */ class XDTemporaryCatalogWithInvalidation( - val underlying: XDTemporaryCatalog, - invalidator: CacheInvalidator - ) extends XDTemporaryCatalog { + val underlying: XDTemporaryCatalog, + invalidator: CacheInvalidator +) extends XDTemporaryCatalog { - override def saveTable( - tableIdentifier: ViewIdentifierNormalized, - plan: LogicalPlan, - crossdataTable: Option[CrossdataTable]): Unit = { + override def saveTable(tableIdentifier: ViewIdentifierNormalized, + plan: LogicalPlan, + crossdataTable: Option[CrossdataTable]): Unit = { invalidator.invalidateCache underlying.saveTable(tableIdentifier, plan, crossdataTable) } - override def saveView(viewIdentifier: ViewIdentifierNormalized, plan: LogicalPlan, query: Option[String]): Unit = { + override def saveView(viewIdentifier: ViewIdentifierNormalized, + plan: LogicalPlan, + query: Option[String]): Unit = { invalidator.invalidateCache 
underlying.saveView(viewIdentifier, plan, query) } @@ -67,10 +68,13 @@ class XDTemporaryCatalogWithInvalidation( underlying.dropTable(tableIdentifier) } - override def relation(tableIdent: ViewIdentifierNormalized)(implicit sqlContext: SQLContext): Option[LogicalPlan] = + override def relation(tableIdent: ViewIdentifierNormalized)( + implicit sqlContext: SQLContext): Option[LogicalPlan] = underlying.relation(tableIdent) override def catalystConf: CatalystConf = underlying.catalystConf override def isAvailable: Boolean = underlying.isAvailable - override def allRelations(databaseName: Option[StringNormalized]): Seq[TableIdentifierNormalized] = underlying.allRelations(databaseName) + override def allRelations( + databaseName: Option[StringNormalized]): Seq[TableIdentifierNormalized] = + underlying.allRelations(databaseName) } diff --git a/core/src/main/scala/org/apache/spark/sql/crossdata/catalog/utils/CatalogUtils.scala b/core/src/main/scala/org/apache/spark/sql/crossdata/catalog/utils/CatalogUtils.scala index c1701cd46..494e83dfc 100644 --- a/core/src/main/scala/org/apache/spark/sql/crossdata/catalog/utils/CatalogUtils.scala +++ b/core/src/main/scala/org/apache/spark/sql/crossdata/catalog/utils/CatalogUtils.scala @@ -25,11 +25,13 @@ import org.apache.spark.sql.crossdata.config.CoreConfig object CatalogUtils extends Logging { - protected[crossdata] def externalCatalog(catalystConf: CatalystConf, config: Config): XDPersistentCatalog = { + protected[crossdata] def externalCatalog(catalystConf: CatalystConf, + config: Config): XDPersistentCatalog = { import CoreConfig.DerbyClass - val externalCatalogName = if (config.hasPath(CoreConfig.ClassConfigKey)) - config.getString(CoreConfig.ClassConfigKey) - else DerbyClass + val externalCatalogName = + if (config.hasPath(CoreConfig.ClassConfigKey)) + config.getString(CoreConfig.ClassConfigKey) + else DerbyClass val externalCatalogClass = Class.forName(externalCatalogName) val constr: Constructor[_] = externalCatalogClass.getConstructor(classOf[CatalystConf]) @@ -37,11 +39,13 @@ object CatalogUtils extends Logging { constr.newInstance(catalystConf).asInstanceOf[XDPersistentCatalog] } - protected[crossdata] def streamingCatalog(catalystConf: CatalystConf, serverConfig: Config): Option[XDStreamingCatalog] = { + protected[crossdata] def streamingCatalog(catalystConf: CatalystConf, + serverConfig: Config): Option[XDStreamingCatalog] = { if (serverConfig.hasPath(CoreConfig.StreamingCatalogClassConfigKey)) { val streamingCatalogClass = serverConfig.getString(CoreConfig.StreamingCatalogClassConfigKey) val xdStreamingCatalog = Class.forName(streamingCatalogClass) - val constr: Constructor[_] = xdStreamingCatalog.getConstructor(classOf[CatalystConf], classOf[Config]) + val constr: Constructor[_] = + xdStreamingCatalog.getConstructor(classOf[CatalystConf], classOf[Config]) Option(constr.newInstance(catalystConf, serverConfig).asInstanceOf[XDStreamingCatalog]) } else { logWarning("There is no configured streaming catalog") diff --git a/core/src/main/scala/org/apache/spark/sql/crossdata/catalyst/ExtendedUnresolvedRelation.scala b/core/src/main/scala/org/apache/spark/sql/crossdata/catalyst/ExtendedUnresolvedRelation.scala index 899d1942d..1ca04c6de 100644 --- a/core/src/main/scala/org/apache/spark/sql/crossdata/catalyst/ExtendedUnresolvedRelation.scala +++ b/core/src/main/scala/org/apache/spark/sql/crossdata/catalyst/ExtendedUnresolvedRelation.scala @@ -19,6 +19,7 @@ import org.apache.spark.sql.catalyst.TableIdentifier import 
org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, UnaryNode} -case class ExtendedUnresolvedRelation(tableIdentifier: TableIdentifier, child: LogicalPlan) extends UnaryNode { +case class ExtendedUnresolvedRelation(tableIdentifier: TableIdentifier, child: LogicalPlan) + extends UnaryNode { override def output: Seq[Attribute] = child.output } diff --git a/core/src/main/scala/org/apache/spark/sql/crossdata/catalyst/XDFunctionRegistry.scala b/core/src/main/scala/org/apache/spark/sql/crossdata/catalyst/XDFunctionRegistry.scala index c0e53c8b0..90d0df722 100644 --- a/core/src/main/scala/org/apache/spark/sql/crossdata/catalyst/XDFunctionRegistry.scala +++ b/core/src/main/scala/org/apache/spark/sql/crossdata/catalyst/XDFunctionRegistry.scala @@ -24,8 +24,10 @@ import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo} import scala.util.Try -private[crossdata] class XDFunctionRegistry(sparkFunctionRegistry: FunctionRegistry, functionInventoryServices: Seq[FunctionInventory]) - extends FunctionRegistry with SparkLoggerComponent { +private[crossdata] class XDFunctionRegistry(sparkFunctionRegistry: FunctionRegistry, + functionInventoryServices: Seq[FunctionInventory]) + extends FunctionRegistry + with SparkLoggerComponent { import FunctionInventory.qualifyUDF @@ -33,20 +35,26 @@ private[crossdata] class XDFunctionRegistry(sparkFunctionRegistry: FunctionRegis override def lookupFunction(name: String, children: Seq[Expression]): Expression = Try(sparkFunctionRegistry.lookupFunction(name, children)).getOrElse { - val datasourceCandidates: Seq[(Expression, String)] = functionInventoryServices.flatMap { fi => - Try( - (sparkFunctionRegistry.lookupFunction(qualifyUDF(fi.shortName(), name), children), fi.shortName()) - ).toOption + val datasourceCandidates: Seq[(Expression, String)] = functionInventoryServices.flatMap { + fi => + Try( + (sparkFunctionRegistry.lookupFunction(qualifyUDF(fi.shortName(), name), children), + fi.shortName()) + ).toOption } datasourceCandidates match { case Seq() => missingFunction(name) - case Seq((expression, dsname)) => logInfo(s"NativeUDF $name has been resolved to ${qualifyUDF(dsname, name)}"); expression + case Seq((expression, dsname)) => + logInfo(s"NativeUDF $name has been resolved to ${qualifyUDF(dsname, name)}"); + expression case multipleDC => duplicateFunction(name, multipleDC.map(_._2)) } } - override def registerFunction(name: String, info: ExpressionInfo, builder: FunctionBuilder): Unit = + override def registerFunction(name: String, + info: ExpressionInfo, + builder: FunctionBuilder): Unit = sparkFunctionRegistry.registerFunction(name, info, builder) override def lookupFunction(name: String): Option[ExpressionInfo] = @@ -59,5 +67,7 @@ private[crossdata] class XDFunctionRegistry(sparkFunctionRegistry: FunctionRegis throw new AnalysisException(s"Undefined function $name") private def duplicateFunction(name: String, datasources: Seq[String]) = - throw new AnalysisException(s"Unable to resolve udf $name. You must qualify it: use one of ${datasources.map(qualifyUDF(_, name).mkString(", "))}") -} \ No newline at end of file + throw new AnalysisException( + s"Unable to resolve udf $name. 
You must qualify it: use one of ${datasources.map( + qualifyUDF(_, name).mkString(", "))}") +} diff --git a/core/src/main/scala/org/apache/spark/sql/crossdata/catalyst/analysis/crossdataPlans.scala b/core/src/main/scala/org/apache/spark/sql/crossdata/catalyst/analysis/crossdataPlans.scala index e43b8c87d..92511082d 100644 --- a/core/src/main/scala/org/apache/spark/sql/crossdata/catalyst/analysis/crossdataPlans.scala +++ b/core/src/main/scala/org/apache/spark/sql/crossdata/catalyst/analysis/crossdataPlans.scala @@ -19,13 +19,16 @@ import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedEx import org.apache.spark.sql.catalyst.expressions.{Expression, Unevaluable} import org.apache.spark.sql.types.DataType -case class PostponedAttribute(unresolvedAttribute: UnresolvedAttribute) extends Expression with Unevaluable { - override def nullable: Boolean = throw new UnresolvedException(this, "nullable") +case class PostponedAttribute(unresolvedAttribute: UnresolvedAttribute) + extends Expression + with Unevaluable { + override def nullable: Boolean = + throw new UnresolvedException(this, "nullable") - override def dataType: DataType = throw new UnresolvedException(this, "dataType") + override def dataType: DataType = + throw new UnresolvedException(this, "dataType") override def children: Seq[Expression] = Seq() override lazy val resolved: Boolean = false } - diff --git a/core/src/main/scala/org/apache/spark/sql/crossdata/catalyst/analysis/crossdataRules.scala b/core/src/main/scala/org/apache/spark/sql/crossdata/catalyst/analysis/crossdataRules.scala index fba7ebb96..a9aab8744 100644 --- a/core/src/main/scala/org/apache/spark/sql/crossdata/catalyst/analysis/crossdataRules.scala +++ b/core/src/main/scala/org/apache/spark/sql/crossdata/catalyst/analysis/crossdataRules.scala @@ -32,10 +32,12 @@ object ResolveAggregateAlias extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case p: LogicalPlan if !p.childrenResolved => p - case a@Aggregate(grouping, aggregateExp, child) if child.resolved && !a.resolved && groupingExpressionsContainAlias(grouping, aggregateExp) => + case a @ Aggregate(grouping, aggregateExp, child) + if child.resolved && !a.resolved && groupingExpressionsContainAlias(grouping, + aggregateExp) => val newGrouping = grouping.map { groupExpression => groupExpression transformUp { - case PostponedAttribute(u@UnresolvedAttribute(Seq(aliasCandidate))) => + case PostponedAttribute(u @ UnresolvedAttribute(Seq(aliasCandidate))) => aggregateExp.collectFirst { case Alias(resolvedAttribute, aliasName) if aliasName == aliasCandidate => resolvedAttribute @@ -46,7 +48,9 @@ object ResolveAggregateAlias extends Rule[LogicalPlan] { } - private def groupingExpressionsContainAlias(groupingExpressions: Seq[Expression], aggregateExpressions: Seq[NamedExpression]): Boolean = { + private def groupingExpressionsContainAlias( + groupingExpressions: Seq[Expression], + aggregateExpressions: Seq[NamedExpression]): Boolean = { def aggregateExpressionsContainAliasReference(aliasCandidate: String) = aggregateExpressions.exists { case Alias(resolvedAttribute, aliasName) if aliasName == aliasCandidate => @@ -71,12 +75,16 @@ object PrepareAggregateAlias extends Rule[LogicalPlan] { override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { - case a@Aggregate(grouping, aggregateExp, child) if !child.resolved && !a.resolved && groupingExpressionsContainUnresolvedAlias(grouping, aggregateExp) => + case a @ Aggregate(grouping, aggregateExp, 
child) + if !child.resolved && !a.resolved && groupingExpressionsContainUnresolvedAlias( + grouping, + aggregateExp) => val newGrouping = grouping.map { groupExpression => groupExpression transformUp { - case u@UnresolvedAttribute(Seq(aliasCandidate)) => + case u @ UnresolvedAttribute(Seq(aliasCandidate)) => aggregateExp.collectFirst { - case UnresolvedAlias(Alias(unresolvedAttr, aliasName)) if aliasName == aliasCandidate => + case UnresolvedAlias(Alias(unresolvedAttr, aliasName)) + if aliasName == aliasCandidate => PostponedAttribute(u) }.getOrElse(u) } @@ -84,10 +92,13 @@ object PrepareAggregateAlias extends Rule[LogicalPlan] { a.copy(groupingExpressions = newGrouping) } - private def groupingExpressionsContainUnresolvedAlias(groupingExpressions: Seq[Expression], aggregateExpressions: Seq[NamedExpression]): Boolean = { + private def groupingExpressionsContainUnresolvedAlias( + groupingExpressions: Seq[Expression], + aggregateExpressions: Seq[NamedExpression]): Boolean = { def aggregateExpressionsContainAliasReference(aliasCandidate: String) = aggregateExpressions.exists { - case UnresolvedAlias(Alias(unresolvedAttribute, aliasName)) if aliasName == aliasCandidate => + case UnresolvedAlias(Alias(unresolvedAttribute, aliasName)) + if aliasName == aliasCandidate => true case _ => false @@ -113,7 +124,6 @@ case class WrapRelationWithGlobalIndex(catalog: XDCatalog) extends Rule[LogicalP } } - def planWithAvailableIndex(plan: LogicalPlan): Boolean = { //Get filters and escape projects to check if plan could be resolved using Indexes @@ -123,7 +133,7 @@ case class WrapRelationWithGlobalIndex(catalog: XDCatalog) extends Rule[LogicalP case logical.Filter(condition, child: LogicalPlan) => helper(filtersConditions :+ condition, child) - case p@logical.Project(_, child: LogicalPlan) => + case p @ logical.Project(_, child: LogicalPlan) => helper(filtersConditions, child) case u: UnresolvedRelation => @@ -143,5 +153,4 @@ case class WrapRelationWithGlobalIndex(catalog: XDCatalog) extends Rule[LogicalP helper(Seq.empty, plan) } - } diff --git a/core/src/main/scala/org/apache/spark/sql/crossdata/catalyst/execution/commands.scala b/core/src/main/scala/org/apache/spark/sql/crossdata/catalyst/execution/commands.scala index 7761959fc..6851923f8 100644 --- a/core/src/main/scala/org/apache/spark/sql/crossdata/catalyst/execution/commands.scala +++ b/core/src/main/scala/org/apache/spark/sql/crossdata/catalyst/execution/commands.scala @@ -41,52 +41,51 @@ private[crossdata] trait DoCatalogDataSourceTable extends RunnableCommand { } private[crossdata] case class PersistDataSourceTable( - protected val crossdataTable: CrossdataTable, - protected val allowExisting: Boolean - ) extends DoCatalogDataSourceTable { + protected val crossdataTable: CrossdataTable, + protected val allowExisting: Boolean +) extends DoCatalogDataSourceTable { override protected def catalogDataSourceTable(crossdataContext: XDContext): Seq[Row] = { - val tableIdentifier = crossdataTable.tableIdentifier if (crossdataContext.catalog.tableExists(tableIdentifier.toTableIdentifier) && !allowExisting) throw new AnalysisException(s"Table ${tableIdentifier.unquotedString} already exists") else - crossdataContext.catalog.persistTable(crossdataTable, createLogicalRelation(crossdataContext, crossdataTable)) + crossdataContext.catalog + .persistTable(crossdataTable, createLogicalRelation(crossdataContext, crossdataTable)) Seq.empty[Row] } - } private[crossdata] case class RegisterDataSourceTable( - protected val crossdataTable: CrossdataTable, - protected 
val allowExisting: Boolean - ) extends DoCatalogDataSourceTable { + protected val crossdataTable: CrossdataTable, + protected val allowExisting: Boolean +) extends DoCatalogDataSourceTable { override protected def catalogDataSourceTable(crossdataContext: XDContext): Seq[Row] = { val tableIdentifier = crossdataTable.tableIdentifier.toTableIdentifier crossdataContext.catalog.registerTable( - tableIdentifier, - createLogicalRelation(crossdataContext, crossdataTable), - Some(crossdataTable) + tableIdentifier, + createLogicalRelation(crossdataContext, crossdataTable), + Some(crossdataTable) ) Seq.empty[Row] } } -private[crossdata] case class PersistSelectAsTable( - tableIdent: TableIdentifier, - provider: String, - partitionColumns: Array[String], - mode: SaveMode, - options: Map[String, String], - query: LogicalPlan) extends RunnableCommand { +private[crossdata] case class PersistSelectAsTable(tableIdent: TableIdentifier, + provider: String, + partitionColumns: Array[String], + mode: SaveMode, + options: Map[String, String], + query: LogicalPlan) + extends RunnableCommand { override def run(sqlContext: SQLContext): Seq[Row] = { @@ -99,21 +98,25 @@ private[crossdata] case class PersistSelectAsTable( // Check if we need to throw an exception or just return. mode match { case SaveMode.ErrorIfExists => - throw new AnalysisException(s"Table ${tableIdent.unquotedString} already exists. " + - s"If you are using saveAsTable, you can set SaveMode to SaveMode.Append to " + - s"insert data into the table or set SaveMode to SaveMode.Overwrite to overwrite" + - s"the existing data. " + - s"Or, if you are using SQL CREATE TABLE, you need to drop ${tableIdent.unquotedString} first.") + throw new AnalysisException( + s"Table ${tableIdent.unquotedString} already exists. " + + s"If you are using saveAsTable, you can set SaveMode to SaveMode.Append to " + + s"insert data into the table or set SaveMode to SaveMode.Overwrite to overwrite" + + s"the existing data. " + + s"Or, if you are using SQL CREATE TABLE, you need to drop ${tableIdent.unquotedString} first.") case SaveMode.Ignore => // Since the table already exists and the save mode is Ignore, we will just return. Seq.empty[Row] case SaveMode.Append => // Check if the specified data source match the data source of the existing table. 
- val resolved = ResolvedDataSource( - sqlContext, Some(query.schema.asNullable), partitionColumns, provider, options) + val resolved = ResolvedDataSource(sqlContext, + Some(query.schema.asNullable), + partitionColumns, + provider, + options) val createdRelation = LogicalRelation(resolved.relation) EliminateSubQueries(sqlContext.catalog.lookupRelation(tableIdent)) match { - case l@LogicalRelation(_: InsertableRelation | _: HadoopFsRelation, _) => + case l @ LogicalRelation(_: InsertableRelation | _: HadoopFsRelation, _) => if (l.relation != createdRelation.relation) { val errorDescription = s"Cannot append to table ${tableIdent.unquotedString} because the resolved relation does not " + @@ -123,12 +126,10 @@ private[crossdata] case class PersistSelectAsTable( val errorMessage = s"""|$errorDescription |== Relations == - |${ - sideBySide( - s"== Expected Relation ==" :: l.toString :: Nil, - s"== Actual Relation ==" :: createdRelation.toString :: Nil - ).mkString("\n") - } + |${sideBySide( + s"== Expected Relation ==" :: l.toString :: Nil, + s"== Actual Relation ==" :: createdRelation.toString :: Nil + ).mkString("\n")} """.stripMargin throw new AnalysisException(errorMessage) } @@ -148,7 +149,8 @@ private[crossdata] case class PersistSelectAsTable( val data = DataFrame(crossdataContext, query) val df = existingSchema match { // If we are inserting into an existing table, just use the existing schema. - case Some(schema) => sqlContext.internalCreateDataFrame(data.queryExecution.toRdd, schema) + case Some(schema) => + sqlContext.internalCreateDataFrame(data.queryExecution.toRdd, schema) case None => data } @@ -157,12 +159,13 @@ private[crossdata] case class PersistSelectAsTable( if (createMetastoreTable) { val resolved = ResolvedDataSource(sqlContext, provider, partitionColumns, mode, options, df) import XDCatalogCommon._ - val identifier = TableIdentifier(tableIdent.table, tableIdent.database).normalize(crossdataContext.conf) - val crossdataTable = CrossdataTable(identifier, Some(resolved.relation.schema), provider, Array.empty, options) + val identifier = + TableIdentifier(tableIdent.table, tableIdent.database).normalize(crossdataContext.conf) + val crossdataTable = + CrossdataTable(identifier, Some(resolved.relation.schema), provider, Array.empty, options) crossdataContext.catalog.persistTable(crossdataTable, LogicalRelation(resolved.relation)) } - Seq.empty[Row] } @@ -172,8 +175,8 @@ private[crossdata] case class PersistSelectAsTable( val rightPadded = right ++ Seq.fill(math.max(left.size - right.size, 0))("") leftPadded.zip(rightPadded).map { - case (l, r) => (if (l == r) " " else "!") + l + (" " * ((maxLeftSize - l.size) + 3)) + r + case (l, r) => + (if (l == r) " " else "!") + l + (" " * ((maxLeftSize - l.size) + 3)) + r } } } - diff --git a/core/src/main/scala/org/apache/spark/sql/crossdata/catalyst/execution/ddl.scala b/core/src/main/scala/org/apache/spark/sql/crossdata/catalyst/execution/ddl.scala index f0e2e8f4d..b01d15f3b 100644 --- a/core/src/main/scala/org/apache/spark/sql/crossdata/catalyst/execution/ddl.scala +++ b/core/src/main/scala/org/apache/spark/sql/crossdata/catalyst/execution/ddl.scala @@ -64,22 +64,29 @@ object DDLUtils { case (value: String, _: TimestampType) => Try(Timestamp.valueOf(value)) case (seq: Seq[_], ArrayType(elementType, withNulls)) => - seqOfTryToTryOfSeq(seq map { seqValue => convertSparkDatatypeToScala(seqValue, elementType) }) + seqOfTryToTryOfSeq(seq map { seqValue => + convertSparkDatatypeToScala(seqValue, elementType) + }) case (invalidSeq, 
ArrayType(elementType, withNulls)) => Failure(new RuntimeException("Invalid array passed as argument:" + invalidSeq.toString)) case (mapParsed: Map[_, _], MapType(keyType, valueType, withNulls)) => Try( - mapParsed map { - case (key, value) => (convertSparkDatatypeToScala(key, keyType).get, convertSparkDatatypeToScala(value, valueType).get) - } + mapParsed map { + case (key, value) => + (convertSparkDatatypeToScala(key, keyType).get, + convertSparkDatatypeToScala(value, valueType).get) + } ) case (invalidMap, MapType(keyType, valueType, withNulls)) => Failure(new RuntimeException("Invalid map passed as argument:" + invalidMap.toString)) - case unparsed => Failure(new RuntimeException("Impossible to parse value as Spark DataType provided:" + unparsed.toString)) + case unparsed => + Failure( + new RuntimeException( + "Impossible to parse value as Spark DataType provided:" + unparsed.toString)) } } @@ -92,19 +99,23 @@ object DDLUtils { private def seqOfTryToTryOfSeq[T](tries: Seq[Try[T]]): Try[Seq[T]] = { Try( - tries map (_.get) + tries map (_.get) ) } } -private[crossdata] case class ImportTablesUsingWithOptions(datasource: String, opts: Map[String, String]) - extends LogicalPlan with RunnableCommand with SparkLoggerComponent { +private[crossdata] case class ImportTablesUsingWithOptions(datasource: String, + opts: Map[String, String]) + extends LogicalPlan + with RunnableCommand + with SparkLoggerComponent { // The result of IMPORT TABLE has only tableIdentifier so far. override val output: Seq[Attribute] = { val schema = StructType( - Seq(StructField("tableIdentifier", ArrayType(StringType), false), StructField("ignored", BooleanType, false)) + Seq(StructField("tableIdentifier", ArrayType(StringType), false), + StructField("ignored", BooleanType, false)) ) schema.toAttributes } @@ -113,7 +124,8 @@ private[crossdata] case class ImportTablesUsingWithOptions(datasource: String, o def tableExists(tableId: TableIdentifier): Boolean = { val doExist = sqlContext.catalog.tableExists(tableId) - if (doExist) log.warn(s"IMPORT TABLE omitted already registered table: ${tableId.unquotedString}") + if (doExist) + log.warn(s"IMPORT TABLE omitted already registered table: ${tableId.unquotedString}") doExist } @@ -133,10 +145,13 @@ private[crossdata] case class ImportTablesUsingWithOptions(datasource: String, o if (!ignoreTable) { logInfo(s"Importing table ${tableId.unquotedString}") val optionsWithTable = inventoryRelation.generateConnectorOpts(table, opts) - val identifier = TableIdentifier(table.tableName, table.database).normalize(sqlContext.conf) - val crossdataTable = CrossdataTable(identifier, table.schema, datasource, Array.empty, optionsWithTable) + val identifier = + TableIdentifier(table.tableName, table.database).normalize(sqlContext.conf) + val crossdataTable = + CrossdataTable(identifier, table.schema, datasource, Array.empty, optionsWithTable) import org.apache.spark.sql.crossdata.util.CreateRelationUtil._ - sqlContext.catalog.persistTable(crossdataTable, createLogicalRelation(sqlContext, crossdataTable)) + sqlContext.catalog + .persistTable(crossdataTable, createLogicalRelation(sqlContext, crossdataTable)) } val tableSeq = DDLUtils.tableIdentifierToSeq(tableId) Row(tableSeq, ignoreTable) @@ -154,11 +169,13 @@ private[crossdata] case class DropTable(tableIdentifier: TableIdentifier) extend } -private[crossdata] case class DropExternalTable(tableIdentifier: TableIdentifier) extends RunnableCommand { +private[crossdata] case class DropExternalTable(tableIdentifier: TableIdentifier) + extends 
RunnableCommand { override def run(sqlContext: SQLContext): Seq[Row] = { - val crossadataTable = sqlContext.catalog.tableMetadata(tableIdentifier) getOrElse (sys.error("Error dropping external table. Table doesn't exist in the catalog")) + val crossadataTable = sqlContext.catalog.tableMetadata(tableIdentifier) getOrElse (sys.error( + "Error dropping external table. Table doesn't exist in the catalog")) val provider = crossadataTable.datasource val resolved = ResolvedDataSource.lookupDataSource(provider).newInstance() @@ -169,7 +186,6 @@ private[crossdata] case class DropExternalTable(tableIdentifier: TableIdentifier throw new AnalysisException(s"Table ${tableIdentifier.unquotedString} does not exist") case tableManipulation: TableManipulation => - tableManipulation.dropExternalTable(sqlContext, crossadataTable.opts) map { result => sqlContext.catalog.dropTable(tableIdentifier) Seq.empty @@ -193,12 +209,14 @@ private[crossdata] case object DropAllTables extends RunnableCommand { } -private[crossdata] case class InsertIntoTable(tableIdentifier: TableIdentifier, parsedRows: Seq[DDLUtils.RowValues], schemaFromUser: Option[Seq[String]] = None) - extends RunnableCommand { +private[crossdata] case class InsertIntoTable(tableIdentifier: TableIdentifier, + parsedRows: Seq[DDLUtils.RowValues], + schemaFromUser: Option[Seq[String]] = None) + extends RunnableCommand { override def output: Seq[Attribute] = { val schema = StructType( - Seq(StructField("Number of insertions", IntegerType, nullable = false)) + Seq(StructField("Number of insertions", IntegerType, nullable = false)) ) schema.toAttributes } @@ -208,15 +226,16 @@ private[crossdata] case class InsertIntoTable(tableIdentifier: TableIdentifier, sqlContext.catalog.lookupRelation(tableIdentifier) match { case Subquery(_, LogicalRelation(relation: BaseRelation, _)) => - - val schema = schemaFromUser map (DDLUtils.extractSchema(_, relation.schema)) getOrElse relation.schema + val schema = schemaFromUser map (DDLUtils + .extractSchema(_, relation.schema)) getOrElse relation.schema relation match { case insertableRelation: InsertableRelation => val dataframe = convertRows(sqlContext, parsedRows, schema) - sqlContext.catalog.indexMetadataByTableIdentifier(tableIdentifier).foreach{ idxIdentifier => - indexData(sqlContext, idxIdentifier, schema) + sqlContext.catalog.indexMetadataByTableIdentifier(tableIdentifier).foreach { + idxIdentifier => + indexData(sqlContext, idxIdentifier, schema) } insertableRelation.insert(dataframe, overwrite = false) @@ -235,7 +254,6 @@ private[crossdata] case class InsertIntoTable(tableIdentifier: TableIdentifier, sys.error("The Datasource does not support INSERT INTO table VALUES command") } - case _ => sys.error("Table not found. 
Are you trying to insert values into a view/temporary table?") } @@ -248,9 +266,12 @@ private[crossdata] case class InsertIntoTable(tableIdentifier: TableIdentifier, * * @param sqlContext */ - private def indexData(sqlContext: SQLContext, crossdataIndex: CrossdataIndex, tableSchema: StructType): Unit = { + private def indexData(sqlContext: SQLContext, + crossdataIndex: CrossdataIndex, + tableSchema: StructType): Unit = { - val columnsToIndex: Seq[String] = crossdataIndex.pk +: crossdataIndex.indexedCols.filter(tableSchema.getFieldIndex(_).isDefined) + val columnsToIndex: Seq[String] = crossdataIndex.pk +: crossdataIndex.indexedCols.filter( + tableSchema.getFieldIndex(_).isDefined) val filteredParsedRows = parsedRows.map { row => columnsToIndex map { idxCol => @@ -258,14 +279,18 @@ private[crossdata] case class InsertIntoTable(tableIdentifier: TableIdentifier, } } - InsertIntoTable(crossdataIndex.indexIdentifier.asTableIdentifierNormalized.toTableIdentifier, filteredParsedRows, Some(columnsToIndex)).run(sqlContext) + InsertIntoTable(crossdataIndex.indexIdentifier.asTableIdentifierNormalized.toTableIdentifier, + filteredParsedRows, + Some(columnsToIndex)).run(sqlContext) } - private def convertRows(sqlContext: SQLContext, rows: Seq[DDLUtils.RowValues], tableSchema: StructType): DataFrame = { + private def convertRows(sqlContext: SQLContext, + rows: Seq[DDLUtils.RowValues], + tableSchema: StructType): DataFrame = { val parsedRowsConverted: Seq[Row] = parsedRows map { values => - - if (tableSchema.fields.length != values.length) sys.error("Invalid length of parameters") + if (tableSchema.fields.length != values.length) + sys.error("Invalid length of parameters") val valuesConverted = tableSchema.fields zip values map { case (schemaCol, value) => @@ -277,29 +302,30 @@ private[crossdata] case class InsertIntoTable(tableIdentifier: TableIdentifier, Row.fromSeq(valuesConverted) } - val dataframe = sqlContext.asInstanceOf[XDContext].createDataFrame(parsedRowsConverted, tableSchema) + val dataframe = + sqlContext.asInstanceOf[XDContext].createDataFrame(parsedRowsConverted, tableSchema) dataframe } } private[crossdata] object InsertIntoTable { - def apply(tableIdentifier: TableIdentifier, parsedRows: Seq[DDLUtils.RowValues]) = new InsertIntoTable(tableIdentifier, parsedRows) + def apply(tableIdentifier: TableIdentifier, parsedRows: Seq[DDLUtils.RowValues]) = + new InsertIntoTable(tableIdentifier, parsedRows) } object CreateTempView { def apply( - viewIdentifier: ViewIdentifier, - queryPlan: LogicalPlan, - sql: String - ): CreateTempView = new CreateTempView(viewIdentifier, queryPlan, Some(sql)) + viewIdentifier: ViewIdentifier, + queryPlan: LogicalPlan, + sql: String + ): CreateTempView = new CreateTempView(viewIdentifier, queryPlan, Some(sql)) } private[crossdata] case class CreateTempView( - viewIdentifier: ViewIdentifier, - queryPlan: LogicalPlan, - sql: Option[String] - ) - extends RunnableCommand { + viewIdentifier: ViewIdentifier, + queryPlan: LogicalPlan, + sql: Option[String] +) extends RunnableCommand { override def run(sqlContext: SQLContext): Seq[Row] = { sqlContext.catalog.registerView(viewIdentifier, queryPlan, sql) @@ -308,8 +334,11 @@ private[crossdata] case class CreateTempView( } -private[crossdata] case class CreateView(viewIdentifier: ViewIdentifier, queryPlan: LogicalPlan, sql: String) - extends LogicalPlan with RunnableCommand { +private[crossdata] case class CreateView(viewIdentifier: ViewIdentifier, + queryPlan: LogicalPlan, + sql: String) + extends LogicalPlan + with 
RunnableCommand { override def run(sqlContext: SQLContext): Seq[Row] = { sqlContext.catalog.persistView(viewIdentifier, queryPlan, sql) @@ -318,9 +347,9 @@ private[crossdata] case class CreateView(viewIdentifier: ViewIdentifier, queryPl } - private[crossdata] case class DropView(viewIdentifier: ViewIdentifier) - extends LogicalPlan with RunnableCommand { + extends LogicalPlan + with RunnableCommand { override def run(sqlContext: SQLContext): Seq[Row] = { sqlContext.catalog.dropView(viewIdentifier) @@ -328,8 +357,7 @@ private[crossdata] case class DropView(viewIdentifier: ViewIdentifier) } } -private[crossdata] case class AddJar(jarPath: String) - extends LogicalPlan with RunnableCommand { +private[crossdata] case class AddJar(jarPath: String) extends LogicalPlan with RunnableCommand { override def run(sqlContext: SQLContext): Seq[Row] = { if (jarPath.toLowerCase.startsWith("hdfs://") || File(jarPath).exists) { @@ -341,20 +369,19 @@ private[crossdata] case class AddJar(jarPath: String) } } - object CreateGlobalIndex { val DefaultDatabaseName = "gidx" } - private[crossdata] case class CreateGlobalIndex( - index: TableIdentifier, - tableIdent: TableIdentifier, - cols: Seq[String], - pk: String, - provider: Option[String], - options: Map[String, String] - ) extends LogicalPlan with RunnableCommand { + index: TableIdentifier, + tableIdent: TableIdentifier, + cols: Seq[String], + pk: String, + provider: Option[String], + options: Map[String, String] +) extends LogicalPlan + with RunnableCommand { import CreateGlobalIndex._ @@ -362,7 +389,8 @@ private[crossdata] case class CreateGlobalIndex( Try { val indexProvider = provider getOrElse "com.stratio.crossdata.connector.elasticsearch" - val finalIndex = IndexIdentifier(index.table, index.database getOrElse DefaultDatabaseName).normalize(sqlContext.conf) + val finalIndex = IndexIdentifier(index.table, index.database getOrElse DefaultDatabaseName) + .normalize(sqlContext.conf) val colsWithoutSchema = Seq(pk) ++ cols @@ -376,13 +404,20 @@ private[crossdata] case class CreateGlobalIndex( } //TODO: Change index name, for allowing multiple index ??? 
- CreateExternalTable(TableIdentifier(finalIndex.indexType, Option(finalIndex.indexName)), elasticSchema, indexProvider, options).run(sqlContext) - - CrossdataIndex(tableIdent.normalize(sqlContext.conf), finalIndex, cols, pk, indexProvider, options) + CreateExternalTable(TableIdentifier(finalIndex.indexType, Option(finalIndex.indexName)), + elasticSchema, + indexProvider, + options).run(sqlContext) + + CrossdataIndex(tableIdent.normalize(sqlContext.conf), + finalIndex, + cols, + pk, + indexProvider, + options) } - private def saveIndexMetadata(sqlContext: SQLContext, crossdataIndex: CrossdataIndex) = { sqlContext.catalog.persistIndex(crossdataIndex) @@ -397,73 +432,90 @@ private[crossdata] case class CreateGlobalIndex( } } - private[crossdata] case class AddApp(jarPath: String, className: String, aliasName: Option[String] = None) - extends LogicalPlan with RunnableCommand { +private[crossdata] case class AddApp(jarPath: String, + className: String, + aliasName: Option[String] = None) + extends LogicalPlan + with RunnableCommand { - override def run(sqlContext: SQLContext): Seq[Row] = { - if (File(jarPath).exists) { - sqlContext.addJar(jarPath) - } else { - sys.error("File doesn't exist") - } - sqlContext.asInstanceOf[XDContext].addApp(path = jarPath, clss = className, alias = aliasName.getOrElse(jarPath.split("/").last.split('.').head)) - Seq.empty + override def run(sqlContext: SQLContext): Seq[Row] = { + if (File(jarPath).exists) { + sqlContext.addJar(jarPath) + } else { + sys.error("File doesn't exist") } + sqlContext + .asInstanceOf[XDContext] + .addApp(path = jarPath, + clss = className, + alias = aliasName.getOrElse(jarPath.split("/").last.split('.').head)) + Seq.empty } +} - private[crossdata] case class ExecuteApp(appName: String, arguments: Seq[String], options: Option[Map[String, String]]) - extends LogicalPlan with RunnableCommand { - - override val output: Seq[Attribute] = { - val schema = StructType(Seq( - StructField("infoMessage", StringType, nullable = true) - )) - schema.toAttributes - } - - override def run(sqlContext: SQLContext): Seq[Row] = { - sqlContext.asInstanceOf[XDContext].executeApp(appName, arguments, options) - } +private[crossdata] case class ExecuteApp(appName: String, + arguments: Seq[String], + options: Option[Map[String, String]]) + extends LogicalPlan + with RunnableCommand { + override val output: Seq[Attribute] = { + val schema = StructType( + Seq( + StructField("infoMessage", StringType, nullable = true) + )) + schema.toAttributes } - case class CreateExternalTable( - tableIdent: TableIdentifier, - userSpecifiedSchema: StructType, - provider: String, - options: Map[String, String], - allowExisting: Boolean = false) extends LogicalPlan with RunnableCommand { - + override def run(sqlContext: SQLContext): Seq[Row] = { + sqlContext.asInstanceOf[XDContext].executeApp(appName, arguments, options) + } - override def run(sqlContext: SQLContext): Seq[Row] = { +} - val resolved = ResolvedDataSource.lookupDataSource(provider).newInstance() +case class CreateExternalTable(tableIdent: TableIdentifier, + userSpecifiedSchema: StructType, + provider: String, + options: Map[String, String], + allowExisting: Boolean = false) + extends LogicalPlan + with RunnableCommand { - resolved match { + override def run(sqlContext: SQLContext): Seq[Row] = { - case _ if sqlContext.catalog.tableExists(tableIdent) => - throw new AnalysisException(s"Table ${tableIdent.unquotedString} already exists") + val resolved = ResolvedDataSource.lookupDataSource(provider).newInstance() - case 
tableManipulation: TableManipulation => + resolved match { - val tableInventory = tableManipulation.createExternalTable(sqlContext, tableIdent.table, tableIdent.database, userSpecifiedSchema, options) - tableInventory.map { tableInventory => - val optionsWithTable = tableManipulation.generateConnectorOpts(tableInventory, options) - val identifier = TableIdentifier(tableIdent.table, tableIdent.database).normalize(sqlContext.conf) - val crossdataTable = CrossdataTable(identifier, Option(userSpecifiedSchema), provider, Array.empty, optionsWithTable) - import org.apache.spark.sql.crossdata.util.CreateRelationUtil._ - sqlContext.catalog.persistTable(crossdataTable, createLogicalRelation(sqlContext, crossdataTable)) - } getOrElse (throw new RuntimeException(s"External table can't be created")) + case _ if sqlContext.catalog.tableExists(tableIdent) => + throw new AnalysisException(s"Table ${tableIdent.unquotedString} already exists") - case _ => - sys.error("The Datasource does not support CREATE EXTERNAL TABLE command") - } - - Seq.empty + case tableManipulation: TableManipulation => + val tableInventory = tableManipulation.createExternalTable(sqlContext, + tableIdent.table, + tableIdent.database, + userSpecifiedSchema, + options) + tableInventory.map { tableInventory => + val optionsWithTable = tableManipulation.generateConnectorOpts(tableInventory, options) + val identifier = + TableIdentifier(tableIdent.table, tableIdent.database).normalize(sqlContext.conf) + val crossdataTable = CrossdataTable(identifier, + Option(userSpecifiedSchema), + provider, + Array.empty, + optionsWithTable) + import org.apache.spark.sql.crossdata.util.CreateRelationUtil._ + sqlContext.catalog.persistTable(crossdataTable, + createLogicalRelation(sqlContext, crossdataTable)) + } getOrElse (throw new RuntimeException(s"External table can't be created")) + case _ => + sys.error("The Datasource does not support CREATE EXTERNAL TABLE command") } - } - + Seq.empty + } +} diff --git a/core/src/main/scala/org/apache/spark/sql/crossdata/catalyst/globalindex/IndexUtils.scala b/core/src/main/scala/org/apache/spark/sql/crossdata/catalyst/globalindex/IndexUtils.scala index d1209898d..1bffada7a 100644 --- a/core/src/main/scala/org/apache/spark/sql/crossdata/catalyst/globalindex/IndexUtils.scala +++ b/core/src/main/scala/org/apache/spark/sql/crossdata/catalyst/globalindex/IndexUtils.scala @@ -32,16 +32,22 @@ object IndexUtils { def areAllAttributeIndexedInExpr(condition: Expression, indexedCols: Seq[String]): Boolean = { @tailrec - def checkIfRemainExprAreSupported(remainExpr: Seq[Expression]): Boolean = remainExpr match { - case seq if seq.isEmpty => true - case nonEmptySeq => nonEmptySeq.head match { - case predicate: Predicate if !isSupportedPredicate(predicate) => false - case UnresolvedAttribute(name) if !indexedCols.contains(name.last) => false // TODO TOFIX subdocuments can cause conflicts - case AttributeReference(name, _, _, _) if !indexedCols.contains(name) => false - case head if head.children.nonEmpty => checkIfRemainExprAreSupported(remainExpr.tail ++ remainExpr.head.children) - case _ => checkIfRemainExprAreSupported(remainExpr.tail) + def checkIfRemainExprAreSupported(remainExpr: Seq[Expression]): Boolean = + remainExpr match { + case seq if seq.isEmpty => true + case nonEmptySeq => + nonEmptySeq.head match { + case predicate: Predicate if !isSupportedPredicate(predicate) => + false + case UnresolvedAttribute(name) if !indexedCols.contains(name.last) => + false // TODO TOFIX subdocuments can cause conflicts + case 
AttributeReference(name, _, _, _) if !indexedCols.contains(name) => + false + case head if head.children.nonEmpty => + checkIfRemainExprAreSupported(remainExpr.tail ++ remainExpr.head.children) + case _ => checkIfRemainExprAreSupported(remainExpr.tail) + } } - } checkIfRemainExprAreSupported(Seq(condition)) } @@ -70,4 +76,4 @@ object IndexUtils { case _ => false } -} \ No newline at end of file +} diff --git a/core/src/main/scala/org/apache/spark/sql/crossdata/catalyst/nativeudfs.scala b/core/src/main/scala/org/apache/spark/sql/crossdata/catalyst/nativeudfs.scala index b1e16223d..1d789dbcb 100644 --- a/core/src/main/scala/org/apache/spark/sql/crossdata/catalyst/nativeudfs.scala +++ b/core/src/main/scala/org/apache/spark/sql/crossdata/catalyst/nativeudfs.scala @@ -21,18 +21,16 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.types.DataType - -case class NativeUDF(name: String, - dataType: DataType, - children: Seq[Expression]) extends Expression with Unevaluable { +case class NativeUDF(name: String, dataType: DataType, children: Seq[Expression]) + extends Expression + with Unevaluable { override def toString: String = s"NativeUDF#$name(${children.mkString(",")})" override def nullable: Boolean = true } -case class EvaluateNativeUDF(udf: NativeUDF, - child: LogicalPlan, - resultAttribute: Attribute) extends logical.UnaryNode { +case class EvaluateNativeUDF(udf: NativeUDF, child: LogicalPlan, resultAttribute: Attribute) + extends logical.UnaryNode { def output: Seq[Attribute] = child.output :+ resultAttribute @@ -49,44 +47,43 @@ object EvaluateNativeUDF { // case class NativeUDFEvaluation(udf: NativeUDF, output: Seq[Attribute], child: SparkPlan) extends SparkPlan /* -* -* Analysis rule to replace resolved NativeUDFs by their evaluations as filters LogicalPlans -* These evaluations contain the information needed to refer the UDF in the native connector -* query generator. -* -*/ + * + * Analysis rule to replace resolved NativeUDFs by their evaluations as filters LogicalPlans + * These evaluations contain the information needed to refer the UDF in the native connector + * query generator. + * + */ object ExtractNativeUDFs extends Rule[LogicalPlan] { override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case plan: EvaluateNativeUDF => plan case plan: LogicalPlan => - plan.expressions. - flatMap(_.collect {case udf: NativeUDF => udf} ). - find(_.resolved). - map { case udf => - var evaluation: EvaluateNativeUDF = null + plan.expressions.flatMap(_.collect { case udf: NativeUDF => udf }).find(_.resolved).map { + case udf => + var evaluation: EvaluateNativeUDF = null - val newChildren = plan.children flatMap { child => - // Check to make sure that the UDF can be evaluated with only the input of this child. - // Other cases are disallowed as they are ambiguous or would require a cartesian - // product. - if (udf.references.subsetOf(child.outputSet)) { - evaluation = EvaluateNativeUDF(udf, child) - evaluation::Nil - } else if (udf.references.intersect(child.outputSet).nonEmpty) { - sys.error(s"Invalid NativeUDF $udf, requires attributes from more than one child.") - } else { - child::Nil - } + val newChildren = plan.children flatMap { child => + // Check to make sure that the UDF can be evaluated with only the input of this child. + // Other cases are disallowed as they are ambiguous or would require a cartesian + // product. 
+ if (udf.references.subsetOf(child.outputSet)) { + evaluation = EvaluateNativeUDF(udf, child) + evaluation :: Nil + } else if (udf.references.intersect(child.outputSet).nonEmpty) { + sys.error(s"Invalid NativeUDF $udf, requires attributes from more than one child.") + } else { + child :: Nil } + } - assert(evaluation != null, "Unable to evaluate NativeUDF. Missing input attributes.") + assert(evaluation != null, "Unable to evaluate NativeUDF. Missing input attributes.") - logical.Project( + logical.Project( plan.output, //plan.withNewChildren(newChildren) plan.transformExpressions { - case u: NativeUDF if(u.fastEquals(udf)) => evaluation.resultAttribute + case u: NativeUDF if (u.fastEquals(udf)) => + evaluation.resultAttribute }.withNewChildren(newChildren) - ) - } getOrElse plan + ) + } getOrElse plan } -} \ No newline at end of file +} diff --git a/core/src/main/scala/org/apache/spark/sql/crossdata/catalyst/optimizer/XDOptimizer.scala b/core/src/main/scala/org/apache/spark/sql/crossdata/catalyst/optimizer/XDOptimizer.scala index 6527fb32b..c003ab379 100644 --- a/core/src/main/scala/org/apache/spark/sql/crossdata/catalyst/optimizer/XDOptimizer.scala +++ b/core/src/main/scala/org/apache/spark/sql/crossdata/catalyst/optimizer/XDOptimizer.scala @@ -38,26 +38,27 @@ case class XDOptimizer(xdContext: XDContext, conf: CatalystConf) extends Optimiz val defaultOptimizer = DefaultOptimizer(conf) - def convertStrategy(strategy: defaultOptimizer.Strategy): Strategy = strategy.maxIterations match { - case 1 => Once - case n => FixedPoint(n) - } + def convertStrategy(strategy: defaultOptimizer.Strategy): Strategy = + strategy.maxIterations match { + case 1 => Once + case n => FixedPoint(n) + } def convertBatches(batch: defaultOptimizer.Batch): Batch = Batch(batch.name, convertStrategy(batch.strategy), batch.rules: _*) override val batches: List[Batch] = - (defaultOptimizer.batches map (convertBatches(_))) ++ Seq(Batch("Global indexes phase", Once, CheckGlobalIndexInFilters(xdContext))) + (defaultOptimizer.batches map (convertBatches(_))) ++ Seq( + Batch("Global indexes phase", Once, CheckGlobalIndexInFilters(xdContext))) } - - case class CheckGlobalIndexInFilters(xdContext: XDContext) extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case FilterWithIndexLogicalPlan(filters, projects, ExtendedUnresolvedRelation(tableIdentifier, relation)) => - + case FilterWithIndexLogicalPlan(filters, + projects, + ExtendedUnresolvedRelation(tableIdentifier, relation)) => val crossdataIndex = { xdContext.catalog.indexMetadataByTableIdentifier(tableIdentifier) } getOrElse { @@ -66,7 +67,8 @@ case class CheckGlobalIndexInFilters(xdContext: XDContext) extends Rule[LogicalP //Change the filters that has indexed rows, with a Filter IN with ES results or LocalRelation if we don't have results val newFilters: Seq[LogicalPlan] = filters map { filter => - if (IndexUtils.areAllAttributeIndexedInExpr(filter.condition, crossdataIndex.indexedCols)) { + if (IndexUtils + .areAllAttributeIndexedInExpr(filter.condition, crossdataIndex.indexedCols)) { val indexLogicalPlan = buildIndexRequestLogicalPlan(filter.condition, crossdataIndex) val indexedRows = XDDataFrame(xdContext, indexLogicalPlan).collect() //TODO: Warning memory issues if (indexedRows.nonEmpty) { @@ -76,7 +78,9 @@ case class CheckGlobalIndexInFilters(xdContext: XDContext) extends Rule[LogicalP val pkSchema = DDLUtils.extractSchema(Seq(crossdataIndex.pk), lr.schema) val pkAttribute = schemaToAttribute(pkSchema).head 
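+              // Rewrite this filter as an IN predicate over the primary-key values returned by the global index lookup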
analyzeAndOptimize( - logical.Filter(In(pkAttribute, resultPksToLiterals(indexedRows, pkSchema.fields.head.dataType)), relation) + logical.Filter(In(pkAttribute, + resultPksToLiterals(indexedRows, pkSchema.fields.head.dataType)), + relation) ) } else { @@ -104,7 +108,6 @@ case class CheckGlobalIndexInFilters(xdContext: XDContext) extends Rule[LogicalP } - private def analyze(plan: LogicalPlan): LogicalPlan = { val analyzed = xdContext.analyzer.execute(plan) xdContext.analyzer.checkAnalysis(analyzed) @@ -116,38 +119,47 @@ case class CheckGlobalIndexInFilters(xdContext: XDContext) extends Rule[LogicalP } private def schemaToAttribute(schema: StructType): Seq[UnresolvedAttribute] = - schema.fields map {field => UnresolvedAttribute(field.name)} + schema.fields map { field => + UnresolvedAttribute(field.name) + } - private def resultPksToLiterals(rows: Array[Row], dataType:DataType): Seq[Literal] = + private def resultPksToLiterals(rows: Array[Row], dataType: DataType): Seq[Literal] = rows map { row => val valTransformed = row.get(0) Literal.create(valTransformed, dataType) } //TODO compound PK - private def buildIndexRequestLogicalPlan(condition: Expression, index: CrossdataIndex): LogicalPlan = { + private def buildIndexRequestLogicalPlan(condition: Expression, + index: CrossdataIndex): LogicalPlan = { - val logicalRelation = xdContext.catalog.lookupRelation(index.indexIdentifier.asTableIdentifierNormalized.toTableIdentifier) match { - case Subquery(_, logicalRelation @ LogicalRelation(_: BaseRelation, _)) => logicalRelation + val logicalRelation = xdContext.catalog.lookupRelation( + index.indexIdentifier.asTableIdentifierNormalized.toTableIdentifier) match { + case Subquery(_, logicalRelation @ LogicalRelation(_: BaseRelation, _)) => + logicalRelation } //We need to retrieve all the retrieve cols for use the filter - val pkAndColsIndexed: Seq[UnresolvedAttribute] = schemaToAttribute(DDLUtils.extractSchema(Seq(index.pk)++index.indexedCols, logicalRelation.schema)) + val pkAndColsIndexed: Seq[UnresolvedAttribute] = schemaToAttribute( + DDLUtils.extractSchema(Seq(index.pk) ++ index.indexedCols, logicalRelation.schema)) //Old attributes reference have to be updated val convertedCondition = condition transform { - case UnresolvedAttribute(name) => (pkAndColsIndexed filter (_.name == name)).head - case AttributeReference(name, _, _, _) => (pkAndColsIndexed filter (_.name == name)).head + case UnresolvedAttribute(name) => + (pkAndColsIndexed filter (_.name == name)).head + case AttributeReference(name, _, _, _) => + (pkAndColsIndexed filter (_.name == name)).head } Filter(convertedCondition, Project(pkAndColsIndexed, logicalRelation)) } def combineFiltersAndRelation(filters: Seq[LogicalPlan], relation: LogicalPlan): LogicalPlan = - filters.foldRight(relation){(filter, accum) => filter.withNewChildren(Seq(accum))} + filters.foldRight(relation) { (filter, accum) => + filter.withNewChildren(Seq(accum)) + } } - // TODO comment? 
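+/**
+  * Extractor that walks down a logical plan collecting the Filter and Project operators placed
+  * above an ExtendedUnresolvedRelation and returns them as (filters, projects, relation), so that
+  * CheckGlobalIndexInFilters can decide whether those filters are fully covered by a global index.
+  */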
object FilterWithIndexLogicalPlan { type ReturnType = (Seq[Filter], Seq[Project], ExtendedUnresolvedRelation) @@ -161,15 +173,15 @@ object FilterWithIndexLogicalPlan { @tailrec def recoverFilterAndProjects( - filters: Seq[logical.Filter], - projects: Seq[logical.Project], - current: LogicalPlan - ): Option[ReturnType] = current match { + filters: Seq[logical.Filter], + projects: Seq[logical.Project], + current: LogicalPlan + ): Option[ReturnType] = current match { - case f@logical.Filter(_, child: LogicalPlan) => + case f @ logical.Filter(_, child: LogicalPlan) => recoverFilterAndProjects(filters :+ f, projects, child) - case p@logical.Project(_, child: LogicalPlan) => + case p @ logical.Project(_, child: LogicalPlan) => recoverFilterAndProjects(filters, projects :+ p, child) case u: ExtendedUnresolvedRelation => @@ -177,4 +189,4 @@ object FilterWithIndexLogicalPlan { case _ => None } -} \ No newline at end of file +} diff --git a/core/src/main/scala/org/apache/spark/sql/crossdata/catalyst/parser/XDDdlParser.scala b/core/src/main/scala/org/apache/spark/sql/crossdata/catalyst/parser/XDDdlParser.scala index 8f0017e41..9ce87c16a 100644 --- a/core/src/main/scala/org/apache/spark/sql/crossdata/catalyst/parser/XDDdlParser.scala +++ b/core/src/main/scala/org/apache/spark/sql/crossdata/catalyst/parser/XDDdlParser.scala @@ -27,15 +27,15 @@ import org.apache.spark.sql.types._ import scala.language.implicitConversions - -class XDDdlParser(parseQuery: String => LogicalPlan, xDContext: XDContext) extends DDLParser(parseQuery) { +class XDDdlParser(parseQuery: String => LogicalPlan, xDContext: XDContext) + extends DDLParser(parseQuery) { protected val IMPORT = Keyword("IMPORT") protected val TABLES = Keyword("TABLES") protected val DROP = Keyword("DROP") protected val VIEW = Keyword("VIEW") protected val EXTERNAL = Keyword("EXTERNAL") - protected val ADD =Keyword("ADD") + protected val ADD = Keyword("ADD") protected val JAR = Keyword("JAR") protected val INSERT = Keyword("INSERT") protected val INTO = Keyword("INTO") @@ -64,7 +64,6 @@ class XDDdlParser(parseQuery: String => LogicalPlan, xDContext: XDContext) exten protected val APP = Keyword("APP") protected val EXECUTE = Keyword("EXECUTE") - override protected lazy val ddl: Parser[LogicalPlan] = createTable | describeTable | refreshTable | importStart | dropTable | dropExternalTable | @@ -72,11 +71,10 @@ class XDDdlParser(parseQuery: String => LogicalPlan, xDContext: XDContext) exten // TODO move to StreamingDdlParser protected lazy val streamingSentences: Parser[LogicalPlan] = - describeEphemeralTable | showEphemeralTables | createEphemeralTable | dropAllEphemeralQueries | dropAllEphemeralTables | + describeEphemeralTable | showEphemeralTables | createEphemeralTable | dropAllEphemeralQueries | dropAllEphemeralTables | showEphemeralStatus | showEphemeralStatuses | startProcess | stopProcess | showEphemeralQueries | addEphemeralQuery | dropEphemeralQuery | dropEphemeralTable | dropAllTables - protected lazy val importStart: Parser[LogicalPlan] = IMPORT ~> TABLES ~> (USING ~> className) ~ (OPTIONS ~> options).? 
^^ { case provider ~ ops => @@ -103,12 +101,13 @@ class XDDdlParser(parseQuery: String => LogicalPlan, xDContext: XDContext) exten protected lazy val schemaValues: Parser[Seq[String]] = "(" ~> repsep(token, ",") <~ ")" - protected lazy val tableValues: Parser[Seq[Any]] = "(" ~> repsep(mapValues | arrayValues | token, ",") <~ ")" + protected lazy val tableValues: Parser[Seq[Any]] = "(" ~> repsep(mapValues | arrayValues | token, + ",") <~ ")" protected lazy val arrayValues: Parser[Any] = ("[" ~> repsep(mapValues | token, ",") <~ "]") | ("[" ~> success(List()) <~ "]") - protected lazy val tokenMap: Parser[(Any,Any)] = { + protected lazy val tokenMap: Parser[(Any, Any)] = { (token <~ "-" <~ ">") ~ (arrayValues | token) ^^ { case key ~ value => (key, value) } @@ -116,26 +115,23 @@ class XDDdlParser(parseQuery: String => LogicalPlan, xDContext: XDContext) exten protected lazy val mapValues: Parser[Map[Any, Any]] = "(" ~> repsep(tokenMap, ",") <~ ")" ^^ { - case pairs => Map(pairs:_*) + case pairs => Map(pairs: _*) } | "(" ~> success(Map.empty[Any, Any]) <~ ")" - def token: Parser[String] = { import lexical.Token elem("token", _.isInstanceOf[Token]) ^^ (_.chars) } - protected lazy val insertIntoTable: Parser[LogicalPlan] = - INSERT ~> INTO ~> tableIdentifier ~ schemaValues.? ~ (VALUES ~> repsep(tableValues,",")) ^^ { + INSERT ~> INTO ~> tableIdentifier ~ schemaValues.? ~ (VALUES ~> repsep(tableValues, ",")) ^^ { case tableId ~ schemaValues ~ tableValues => - if(schemaValues.isDefined) + if (schemaValues.isDefined) InsertIntoTable(tableId, tableValues, schemaValues) else InsertIntoTable(tableId, tableValues) } - protected lazy val dropView: Parser[LogicalPlan] = DROP ~> VIEW ~> tableIdentifier ^^ { case tableId => @@ -170,24 +166,24 @@ class XDDdlParser(parseQuery: String => LogicalPlan, xDContext: XDContext) exten AddJar(jarPath.trim) } - -protected lazy val addApp: Parser[LogicalPlan] = - (ADD ~> APP ~> stringLit) ~ (AS ~> ident).? ~ (WITH ~> className) ^^ { - case jarPath ~ alias ~ cname => - AddApp(jarPath.toString, cname, alias) - } - + protected lazy val addApp: Parser[LogicalPlan] = + (ADD ~> APP ~> stringLit) ~ (AS ~> ident).? ~ (WITH ~> className) ^^ { + case jarPath ~ alias ~ cname => + AddApp(jarPath.toString, cname, alias) + } protected lazy val executeApp: Parser[LogicalPlan] = (EXECUTE ~> ident) ~ tableValues ~ (OPTIONS ~> options).? 
^^ { case appName ~ arguments ~ opts => - val args=arguments map {arg=> arg.toString} + val args = arguments map { arg => + arg.toString + } ExecuteApp(appName, args, opts) } - /** - * Streaming - */ + /** + * Streaming + */ protected lazy val startProcess: Parser[LogicalPlan] = { (START ~> tableIdentifier) ^^ { case table => StartProcess(table.unquotedString) @@ -201,9 +197,8 @@ protected lazy val addApp: Parser[LogicalPlan] = } /** - * Ephemeral Table Functions - */ - + * Ephemeral Table Functions + */ protected lazy val describeEphemeralTable: Parser[LogicalPlan] = { (DESCRIBE ~ EPHEMERAL ~ TABLE ~> tableIdentifier) ^^ { case tableIdent => DescribeEphemeralTable(tableIdent) @@ -238,9 +233,8 @@ protected lazy val addApp: Parser[LogicalPlan] = } /** - * Ephemeral Table Status Functions - */ - + * Ephemeral Table Status Functions + */ protected lazy val showEphemeralStatus: Parser[LogicalPlan] = { (SHOW ~ EPHEMERAL ~ STATUS ~ IN ~> tableIdentifier) ^^ { case tableIdent => ShowEphemeralStatus(tableIdent) @@ -254,11 +248,10 @@ protected lazy val addApp: Parser[LogicalPlan] = } /** - * Ephemeral Queries Functions - */ - + * Ephemeral Queries Functions + */ protected lazy val showEphemeralQueries: Parser[LogicalPlan] = { - (SHOW ~ EPHEMERAL ~ QUERIES ~> ( IN ~> ident).? ) ^^ { + (SHOW ~ EPHEMERAL ~ QUERIES ~> (IN ~> ident).?) ^^ { case queryIdent => ShowEphemeralQueries(queryIdent) } } @@ -272,15 +265,19 @@ protected lazy val addApp: Parser[LogicalPlan] = xDContext.catalog.lookupRelation(tableIdent, alias) } - val ephTables: Seq[String] = queryTables.collect{ + val ephTables: Seq[String] = queryTables.collect { case StreamingRelation(ephTableName) => ephTableName } ephTables.distinct match { case Seq(eTableName) => - AddEphemeralQuery(eTableName, streamQl, topIdent.getOrElse(UUID.randomUUID().toString), new Integer(litN)) + AddEphemeralQuery(eTableName, + streamQl, + topIdent.getOrElse(UUID.randomUUID().toString), + new Integer(litN)) case tableNames => - sys.error(s"Expected an ephemeral table within the query, but found ${tableNames.mkString(",")}") + sys.error( + s"Expected an ephemeral table within the query, but found ${tableNames.mkString(",")}") } } } @@ -292,8 +289,9 @@ protected lazy val addApp: Parser[LogicalPlan] = } protected lazy val dropAllEphemeralQueries: Parser[LogicalPlan] = { - (DROP ~ ALL ~ EPHEMERAL ~ QUERIES ~> (IN ~> tableIdentifier).? ) ^^ { - case tableIdent => DropAllEphemeralQueries(tableIdent.map(_.unquotedString)) + (DROP ~ ALL ~ EPHEMERAL ~ QUERIES ~> (IN ~> tableIdentifier).?) ^^ { + case tableIdent => + DropAllEphemeralQueries(tableIdent.map(_.unquotedString)) } } @@ -307,7 +305,11 @@ protected lazy val addApp: Parser[LogicalPlan] = val streamSql = in.source.subSequence(in.offset, indexOfWithWindow).toString.trim def streamingInfoInput(inpt: Input): Input = { - val startsWithWindow = inpt.source.subSequence(inpt.offset, inpt.source.length()).toString.trim.startsWith("WITH WINDOW") + val startsWithWindow = inpt.source + .subSequence(inpt.offset, inpt.source.length()) + .toString + .trim + .startsWith("WITH WINDOW") if (startsWithWindow) inpt else streamingInfoInput(inpt.rest) } Success(streamSql, streamingInfoInput(in)) @@ -319,7 +321,6 @@ protected lazy val addApp: Parser[LogicalPlan] = CREATE ~ GLOBAL ~ INDEX ~> tableIdentifier ~ (ON ~> tableIdentifier) ~ schemaValues ~ (WITH ~> PK ~> token) ~ (USING ~> className).? 
~ (OPTIONS ~> options) ^^ { case index ~ table ~ columns ~ pk ~ provider ~ opts => - CreateGlobalIndex(index, table, columns, pk, provider, opts) } } diff --git a/core/src/main/scala/org/apache/spark/sql/crossdata/catalyst/planning/ExtendedDataSourceStrategy.scala b/core/src/main/scala/org/apache/spark/sql/crossdata/catalyst/planning/ExtendedDataSourceStrategy.scala index 3e44557b9..84ac1f640 100644 --- a/core/src/main/scala/org/apache/spark/sql/crossdata/catalyst/planning/ExtendedDataSourceStrategy.scala +++ b/core/src/main/scala/org/apache/spark/sql/crossdata/catalyst/planning/ExtendedDataSourceStrategy.scala @@ -31,75 +31,73 @@ private[sql] object ExtendedDataSourceStrategy extends Strategy with SparkLogger def apply(plan: LogicalPlan): Seq[execution.SparkPlan] = plan match { // TODO refactor => return None instead of check the aggregation - case ExtendedPhysicalOperation(projects, filters, l @ LogicalRelation(t: NativeFunctionExecutor, _)) - if plan.collectFirst { case _: Aggregate => false} getOrElse(true) => + case ExtendedPhysicalOperation(projects, + filters, + l @ LogicalRelation(t: NativeFunctionExecutor, _)) + if plan.collectFirst { case _: Aggregate => false } getOrElse (true) => pruneFilterProjectUdfs( - plan, - l, - projects, - filters, - (requestedColumns, srcFilters, attr2udf) => - toCatalystRDD(l, requestedColumns, t.buildScan( - requestedColumns.map { - case nat: AttributeReference if attr2udf contains nat.toString => nat.toString + plan, + l, + projects, + filters, + (requestedColumns, srcFilters, attr2udf) => + toCatalystRDD(l, requestedColumns, t.buildScan(requestedColumns.map { + case nat: AttributeReference if attr2udf contains nat.toString => + nat.toString case att => att.name }.toArray, srcFilters, attr2udf)) - ):: Nil + ) :: Nil case _ => Nil } - protected def pruneFilterProjectUdfs(plan: LogicalPlan, - relation: LogicalRelation, - projects: Seq[NamedExpression], - filterPredicates: Seq[Expression], - scanBuilder: ( - Seq[Attribute], - Array[Filter], - Map[String, NativeUDF] - ) => RDD[InternalRow] - ) = { + protected def pruneFilterProjectUdfs( + plan: LogicalPlan, + relation: LogicalRelation, + projects: Seq[NamedExpression], + filterPredicates: Seq[Expression], + scanBuilder: (Seq[Attribute], Array[Filter], Map[String, NativeUDF]) => RDD[InternalRow]) = { import org.apache.spark.sql.sources.CatalystToCrossdataAdapter - val (pro, fil, att2udf) = - (CatalystToCrossdataAdapter.getConnectorLogicalPlan(plan, projects, filterPredicates): @unchecked) match { - case (_, _, FilterReport(_, udfsIgnored)) if udfsIgnored.nonEmpty => - cannotExecuteNativeUDF(udfsIgnored) - case (SimpleLogicalPlan(pro, fil, udfs, _), _, _) => (pro, fil, udfs) - } + val (pro, fil, att2udf) = (CatalystToCrossdataAdapter + .getConnectorLogicalPlan(plan, projects, filterPredicates): @unchecked) match { + case (_, _, FilterReport(_, udfsIgnored)) if udfsIgnored.nonEmpty => + cannotExecuteNativeUDF(udfsIgnored) + case (SimpleLogicalPlan(pro, fil, udfs, _), _, _) => (pro, fil, udfs) + } val projectSet = AttributeSet(pro) - val filterSet = AttributeSet(filterPredicates.flatMap( - _.references flatMap { - case nat: AttributeReference if att2udf contains nat => - CatalystToCrossdataAdapter.udfFlattenedActualParameters(nat, (x: Attribute) => x)(att2udf) :+ nat - case x => Seq(relation.attributeMap(x)) - } - )) + val filterSet = AttributeSet( + filterPredicates.flatMap( + _.references flatMap { + case nat: AttributeReference if att2udf contains nat => + 
CatalystToCrossdataAdapter.udfFlattenedActualParameters(nat, (x: Attribute) => x)( + att2udf) :+ nat + case x => Seq(relation.attributeMap(x)) + } + )) val filterCondition = filterPredicates.reduceLeftOption(expressions.And) val requestedColumns = (projectSet ++ filterSet).toSeq - val scan = execution.PhysicalRDD.createFromDataSource( - requestedColumns, - scanBuilder(requestedColumns, fil, att2udf map { case (k, v) => k.toString() -> v }), - relation.relation) + val scan = execution.PhysicalRDD + .createFromDataSource(requestedColumns, scanBuilder(requestedColumns, fil, att2udf map { + case (k, v) => k.toString() -> v + }), relation.relation) execution.Project(projects, filterCondition.map(execution.Filter(_, scan)).getOrElse(scan)) } - private def cannotExecuteNativeUDF(udfsIgnored: Seq[AttributeReference]) = - throw new AnalysisException("Some filters containing native UDFS cannot be executed on the datasource." + - " It may happen when a cast is automatically applied by Spark, so try using the same type") - + throw new AnalysisException( + "Some filters containing native UDFS cannot be executed on the datasource." + + " It may happen when a cast is automatically applied by Spark, so try using the same type") /** - * Convert RDD of Row into RDD of InternalRow with objects in catalyst types - */ - private[this] def toCatalystRDD( - relation: LogicalRelation, - output: Seq[Attribute], - rdd: RDD[Row]): RDD[InternalRow] = { + * Convert RDD of Row into RDD of InternalRow with objects in catalyst types + */ + private[this] def toCatalystRDD(relation: LogicalRelation, + output: Seq[Attribute], + rdd: RDD[Row]): RDD[InternalRow] = { if (relation.relation.needConversion) { execution.RDDConversions.rowToRowRdd(rdd, output.map(_.dataType)) } else { @@ -107,4 +105,4 @@ private[sql] object ExtendedDataSourceStrategy extends Strategy with SparkLogger } } -} \ No newline at end of file +} diff --git a/core/src/main/scala/org/apache/spark/sql/crossdata/catalyst/planning/XDStrategies.scala b/core/src/main/scala/org/apache/spark/sql/crossdata/catalyst/planning/XDStrategies.scala index 1026b8bf0..6a098cdc2 100644 --- a/core/src/main/scala/org/apache/spark/sql/crossdata/catalyst/planning/XDStrategies.scala +++ b/core/src/main/scala/org/apache/spark/sql/crossdata/catalyst/planning/XDStrategies.scala @@ -24,29 +24,40 @@ import org.apache.spark.sql.crossdata.catalyst.execution.{PersistDataSourceTable import org.apache.spark.sql.execution.datasources.{CreateTableUsing, CreateTableUsingAsSelect} import org.apache.spark.sql.execution.{ExecutedCommand, SparkPlan, SparkStrategies} -trait XDStrategies extends SparkStrategies { - self: XDContext#XDPlanner => +trait XDStrategies extends SparkStrategies { self: XDContext#XDPlanner => object XDDDLStrategy extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case CreateTableUsing(tableIdent, userSpecifiedSchema, provider, temporary, opts, allowExisting, _) => - + case CreateTableUsing(tableIdent, + userSpecifiedSchema, + provider, + temporary, + opts, + allowExisting, + _) => val crossdataTable = CrossdataTable( - TableIdentifierNormalized(tableIdent.table, tableIdent.database), - userSpecifiedSchema, - provider, - Array.empty[String], - opts + TableIdentifierNormalized(tableIdent.table, tableIdent.database), + userSpecifiedSchema, + provider, + Array.empty[String], + opts ) - val cmd = if(temporary) - RegisterDataSourceTable(crossdataTable, allowExisting) - else - PersistDataSourceTable(crossdataTable, allowExisting) + val cmd = + if 
(temporary) + RegisterDataSourceTable(crossdataTable, allowExisting) + else + PersistDataSourceTable(crossdataTable, allowExisting) ExecutedCommand(cmd) :: Nil - case CreateTableUsingAsSelect(tableIdent, provider, false, partitionCols, mode, opts, query) => + case CreateTableUsingAsSelect(tableIdent, + provider, + false, + partitionCols, + mode, + opts, + query) => val cmd = PersistSelectAsTable(tableIdent, provider, partitionCols, mode, opts, query) ExecutedCommand(cmd) :: Nil diff --git a/core/src/main/scala/org/apache/spark/sql/crossdata/catalyst/planning/patterns.scala b/core/src/main/scala/org/apache/spark/sql/crossdata/catalyst/planning/patterns.scala index a5ebf8730..a3ba0c6c7 100644 --- a/core/src/main/scala/org/apache/spark/sql/crossdata/catalyst/planning/patterns.scala +++ b/core/src/main/scala/org/apache/spark/sql/crossdata/catalyst/planning/patterns.scala @@ -28,7 +28,6 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.plans.logical.Project import org.apache.spark.sql.crossdata.catalyst.EvaluateNativeUDF - object ExtendedPhysicalOperation extends PredicateHelper { type ReturnType = (Seq[NamedExpression], Seq[Expression], LogicalPlan) @@ -37,8 +36,8 @@ object ExtendedPhysicalOperation extends PredicateHelper { Some((fields.getOrElse(child.output), filters, child)) } - def collectProjectsAndFilters(plan: LogicalPlan): - (Option[Seq[NamedExpression]], Seq[Expression], LogicalPlan, Map[Attribute, Expression]) = + def collectProjectsAndFilters(plan: LogicalPlan) + : (Option[Seq[NamedExpression]], Seq[Expression], LogicalPlan, Map[Attribute, Expression]) = plan match { case Project(fields, child) => val (_, filters, other, aliases) = collectProjectsAndFilters(child) @@ -60,9 +59,10 @@ object ExtendedPhysicalOperation extends PredicateHelper { (None, Nil, other, Map.empty) } - def collectAliases(fields: Seq[Expression]): Map[Attribute, Expression] = fields.collect { - case a @ Alias(child, _) => a.toAttribute -> child - }.toMap + def collectAliases(fields: Seq[Expression]): Map[Attribute, Expression] = + fields.collect { + case a @ Alias(child, _) => a.toAttribute -> child + }.toMap def substitute(aliases: Map[Attribute, Expression])(expr: Expression): Expression = { expr.transform { @@ -73,4 +73,4 @@ object ExtendedPhysicalOperation extends PredicateHelper { aliases.get(a).map(Alias(_, a.name)(a.exprId, a.qualifiers)).getOrElse(a) } } -} \ No newline at end of file +} diff --git a/core/src/main/scala/org/apache/spark/sql/crossdata/catalyst/streaming/StreamingRelation.scala b/core/src/main/scala/org/apache/spark/sql/crossdata/catalyst/streaming/StreamingRelation.scala index b54c4b181..cd21b5aba 100644 --- a/core/src/main/scala/org/apache/spark/sql/crossdata/catalyst/streaming/StreamingRelation.scala +++ b/core/src/main/scala/org/apache/spark/sql/crossdata/catalyst/streaming/StreamingRelation.scala @@ -21,7 +21,8 @@ import org.apache.spark.sql.catalyst.plans.logical.LeafNode case class StreamingRelation(ephemeralTableName: String) extends LeafNode { override def output: Seq[Attribute] = Seq.empty - override def productElement(n: Int): Any = throw new IndexOutOfBoundsException + override def productElement(n: Int): Any = + throw new IndexOutOfBoundsException override def productArity: Int = 0 } diff --git a/core/src/main/scala/org/apache/spark/sql/crossdata/catalyst/streaming/streamingDdl.scala b/core/src/main/scala/org/apache/spark/sql/crossdata/catalyst/streaming/streamingDdl.scala index bb0bde208..a8f8c1495 100644 --- 
a/core/src/main/scala/org/apache/spark/sql/crossdata/catalyst/streaming/streamingDdl.scala +++ b/core/src/main/scala/org/apache/spark/sql/crossdata/catalyst/streaming/streamingDdl.scala @@ -32,34 +32,36 @@ import scala.concurrent.ExecutionContext.Implicits.global import scala.util.{Failure, Success} /** - * Ephemeral Table Functions - */ - -private[crossdata] case class DescribeEphemeralTable(tableIdent: TableIdentifier) extends LogicalPlan with RunnableCommand { + * Ephemeral Table Functions + */ +private[crossdata] case class DescribeEphemeralTable(tableIdent: TableIdentifier) + extends LogicalPlan + with RunnableCommand { override val output: Seq[Attribute] = { val schema = StructType( - Seq(StructField("Ephemeral table", StringType, nullable = false)) + Seq(StructField("Ephemeral table", StringType, nullable = false)) ) schema.toAttributes } override def run(sqlContext: SQLContext): Seq[Row] = - sqlContext.catalog.getEphemeralTable(normalizeTableIdentifier(tableIdent, sqlContext.conf)) + sqlContext.catalog + .getEphemeralTable(normalizeTableIdentifier(tableIdent, sqlContext.conf)) .map(ephTable => Seq(Row(ephTable.toPrettyString))) .getOrElse(throw new RuntimeException(s"${tableIdent.unquotedString} doesn't exist")) } - private[crossdata] case object ShowEphemeralTables extends LogicalPlan with RunnableCommand { override val output: Seq[Attribute] = { - val schema = StructType(Seq( - StructField("name", StringType, nullable = false), - StructField("status", StringType, nullable = false), - StructField("atomicWindow", IntegerType, nullable = false) - )) + val schema = StructType( + Seq( + StructField("name", StringType, nullable = false), + StructField("status", StringType, nullable = false), + StructField("atomicWindow", IntegerType, nullable = false) + )) schema.toAttributes } @@ -69,29 +71,33 @@ private[crossdata] case object ShowEphemeralTables extends LogicalPlan with Runn val tables = catalog.getAllEphemeralTables val status = tables.map(etm => catalog.getEphemeralStatus(etm.name).map(_.status)) - tables zip status map { case (tableModel, tableStatus) => - Row(tableModel.name, tableStatus.mkString, tableModel.options.atomicWindow) + tables zip status map { + case (tableModel, tableStatus) => + Row(tableModel.name, tableStatus.mkString, tableModel.options.atomicWindow) } } } - private[crossdata] case class CreateEphemeralTable(tableIdent: TableIdentifier, userSchema: Option[StructType], opts: Map[String, String]) - extends LogicalPlan with RunnableCommand { + extends LogicalPlan + with RunnableCommand { override val output: Seq[Attribute] = { val schema = StructType( - Seq(StructField("Ephemeral table", StringType, nullable = false)) + Seq(StructField("Ephemeral table", StringType, nullable = false)) ) schema.toAttributes } override def run(sqlContext: SQLContext): Seq[Row] = { - val ephTable = StreamingConfig.createEphemeralTableModel(normalizeTableIdentifier(tableIdent, sqlContext.conf), opts, userSchema) + val ephTable = StreamingConfig.createEphemeralTableModel( + normalizeTableIdentifier(tableIdent, sqlContext.conf), + opts, + userSchema) sqlContext.catalog.createEphemeralTable(ephTable) match { case Right(table) => Seq(Row(table.toPrettyString)) case Left(message) => sys.error(message) @@ -100,15 +106,13 @@ private[crossdata] case class CreateEphemeralTable(tableIdent: TableIdentifier, } - - - private[crossdata] case class DropEphemeralTable(tableIdent: TableIdentifier) - extends LogicalPlan with RunnableCommand { + extends LogicalPlan + with RunnableCommand { override val 
output: Seq[Attribute] = { val schema = StructType( - Seq(StructField("Dropped table", StringType, false)) + Seq(StructField("Dropped table", StringType, false)) ) schema.toAttributes } @@ -123,39 +127,39 @@ private[crossdata] case object DropAllEphemeralTables extends LogicalPlan with R override val output: Seq[Attribute] = { val schema = StructType( - Seq(StructField("Dropped tables", StringType, false)) + Seq(StructField("Dropped tables", StringType, false)) ) schema.toAttributes } override def run(sqlContext: SQLContext): Seq[Row] = { - val catalog = sqlContext.catalog val ephTables = catalog.getAllEphemeralTables catalog.dropAllEphemeralTables ephTables.map(eTable => Row(eTable.name)) - } } /** - * Ephemeral Table Status Functions - */ - -private[crossdata] case class ShowEphemeralStatus(tableIdent: TableIdentifier) extends LogicalPlan with RunnableCommand { + * Ephemeral Table Status Functions + */ +private[crossdata] case class ShowEphemeralStatus(tableIdent: TableIdentifier) + extends LogicalPlan + with RunnableCommand { override val output: Seq[Attribute] = { val schema = StructType( - Seq(StructField(s"status", StringType, nullable = false)) + Seq(StructField(s"status", StringType, nullable = false)) ) schema.toAttributes } override def run(sqlContext: SQLContext): Seq[Row] = { - sqlContext.catalog.getEphemeralStatus(normalizeTableIdentifier(tableIdent, sqlContext.conf)) + sqlContext.catalog + .getEphemeralStatus(normalizeTableIdentifier(tableIdent, sqlContext.conf)) .map(ephStatus => Seq(Row(ephStatus.status.toString))) .getOrElse(sys.error(s"${tableIdent.unquotedString} status doesn't exist")) } @@ -164,10 +168,11 @@ private[crossdata] case class ShowEphemeralStatus(tableIdent: TableIdentifier) e private[crossdata] case object ShowAllEphemeralStatuses extends LogicalPlan with RunnableCommand { override val output: Seq[Attribute] = { - val schema = StructType(Seq( - StructField("name", StringType, nullable = false), - StructField("status", StringType, nullable = false) - )) + val schema = StructType( + Seq( + StructField("name", StringType, nullable = false), + StructField("status", StringType, nullable = false) + )) schema.toAttributes } @@ -178,18 +183,22 @@ private[crossdata] case object ShowAllEphemeralStatuses extends LogicalPlan with } -private[crossdata] case class StartProcess(tableIdentifier: String) extends LogicalPlan with RunnableCommand { +private[crossdata] case class StartProcess(tableIdentifier: String) + extends LogicalPlan + with RunnableCommand { override val output: Seq[Attribute] = { - val schema = StructType(Seq( - StructField("infoMessage", StringType, nullable = true) - )) + val schema = StructType( + Seq( + StructField("infoMessage", StringType, nullable = true) + )) schema.toAttributes } override def run(sqlContext: SQLContext): Seq[Row] = { val xdContext = sqlContext.asInstanceOf[XDContext] - val sparkJob = SparkJobLauncher.getSparkStreamingJob(xdContext, XDContext.xdConfig, tableIdentifier) + val sparkJob = + SparkJobLauncher.getSparkStreamingJob(xdContext, XDContext.xdConfig, tableIdentifier) sparkJob match { case Failure(exception) => @@ -205,7 +214,9 @@ private[crossdata] case class StartProcess(tableIdentifier: String) extends Logi } } -private[crossdata] case class StopProcess(tableIdentifier: String) extends LogicalPlan with RunnableCommand { +private[crossdata] case class StopProcess(tableIdentifier: String) + extends LogicalPlan + with RunnableCommand { override def run(sqlContext: SQLContext): Seq[Row] = { val xdContext = 
sqlContext.asInstanceOf[XDContext] @@ -215,8 +226,8 @@ private[crossdata] case class StopProcess(tableIdentifier: String) extends Logic case Some(currentStatus) => if (currentStatus == EphemeralExecutionStatus.Started || currentStatus == EphemeralExecutionStatus.Starting) { xdContext.catalog.updateEphemeralStatus( - tableIdentifier, - EphemeralStatusModel(tableIdentifier, EphemeralExecutionStatus.Stopping) + tableIdentifier, + EphemeralStatusModel(tableIdentifier, EphemeralExecutionStatus.Stopping) ) } else { sys.error(s"Cannot stop process. $tableIdentifier status is $currentStatus") @@ -229,15 +240,15 @@ private[crossdata] case class StopProcess(tableIdentifier: String) extends Logic } /** - * Ephemeral Queries Functions - */ - + * Ephemeral Queries Functions + */ private[crossdata] case class ShowEphemeralQueries(tableIdentifier: Option[String] = None) - extends LogicalPlan with RunnableCommand { + extends LogicalPlan + with RunnableCommand { override val output: Seq[Attribute] = { val schema = StructType( - Seq(StructField("Ephemeral query", StringType, nullable = false)) + Seq(StructField("Ephemeral query", StringType, nullable = false)) ) schema.toAttributes } @@ -245,31 +256,32 @@ private[crossdata] case class ShowEphemeralQueries(tableIdentifier: Option[Strin override def run(sqlContext: SQLContext): Seq[Row] = { val queries = sqlContext.catalog.getAllEphemeralQueries - val filteredQueries = tableIdentifier.map(table => - queries.filter(eqm => eqm.ephemeralTableName == table) - ).getOrElse(queries) + val filteredQueries = tableIdentifier + .map(table => queries.filter(eqm => eqm.ephemeralTableName == table)) + .getOrElse(queries) filteredQueries.map(q => Row(q.toPrettyString)) } } - private[crossdata] case class AddEphemeralQuery(ephemeralTablename: String, sql: String, alias: String, window: Int, opts: Map[String, String] = Map.empty) - extends LogicalPlan with RunnableCommand { + extends LogicalPlan + with RunnableCommand { override val output: Seq[Attribute] = { val schema = StructType( - Seq(StructField("Ephemeral query", StringType, nullable = false)) + Seq(StructField("Ephemeral query", StringType, nullable = false)) ) schema.toAttributes } override def run(sqlContext: SQLContext): Seq[Row] = { - sqlContext.catalog.createEphemeralQuery(EphemeralQueryModel(ephemeralTablename, sql, alias, window, opts)) match { + sqlContext.catalog.createEphemeralQuery( + EphemeralQueryModel(ephemeralTablename, sql, alias, window, opts)) match { case Left(errorMessage) => sys.error(errorMessage) case Right(query) => Seq(Row(query.toPrettyString)) } @@ -278,11 +290,12 @@ private[crossdata] case class AddEphemeralQuery(ephemeralTablename: String, } private[crossdata] case class DropEphemeralQuery(queryIdent: String) - extends LogicalPlan with RunnableCommand { + extends LogicalPlan + with RunnableCommand { override val output: Seq[Attribute] = { val schema = StructType( - Seq(StructField("Dropped query", StringType, nullable = false)) + Seq(StructField("Dropped query", StringType, nullable = false)) ) schema.toAttributes } @@ -293,11 +306,13 @@ private[crossdata] case class DropEphemeralQuery(queryIdent: String) } } -private[crossdata] case class DropAllEphemeralQueries(tableName: Option[String] = None) extends LogicalPlan with RunnableCommand { +private[crossdata] case class DropAllEphemeralQueries(tableName: Option[String] = None) + extends LogicalPlan + with RunnableCommand { override val output: Seq[Attribute] = { val schema = StructType( - Seq(StructField("Dropped query", StringType, 
nullable = false)) + Seq(StructField("Dropped query", StringType, nullable = false)) ) schema.toAttributes } @@ -308,8 +323,10 @@ private[crossdata] case class DropAllEphemeralQueries(tableName: Option[String] val filteredQueryAliases = { val queries = catalog.getAllEphemeralQueries - tableName.map(tname => queries.filter(eqm => eqm.ephemeralTableName == tname)) - .getOrElse(queries).map(q => q.alias) + tableName + .map(tname => queries.filter(eqm => eqm.ephemeralTableName == tname)) + .getOrElse(queries) + .map(q => q.alias) } filteredQueryAliases.map { queryAlias => @@ -319,6 +336,3 @@ private[crossdata] case class DropAllEphemeralQueries(tableName: Option[String] } } - - - diff --git a/core/src/main/scala/org/apache/spark/sql/crossdata/config/CoreConfig.scala b/core/src/main/scala/org/apache/spark/sql/crossdata/config/CoreConfig.scala index bcf4ee66f..c95384f72 100644 --- a/core/src/main/scala/org/apache/spark/sql/crossdata/config/CoreConfig.scala +++ b/core/src/main/scala/org/apache/spark/sql/crossdata/config/CoreConfig.scala @@ -24,7 +24,6 @@ import org.apache.spark.sql.SQLConf import scala.util.Try - object CoreConfig { val CoreBasicConfig = "core-reference.conf" @@ -32,14 +31,15 @@ object CoreConfig { val CoreUserConfigFile = "external.config.filename" val CoreUserConfigResource = "external.config.resource" val CatalogConfigKey = "catalog" - val LauncherKey= "launcher" + val LauncherKey = "launcher" val JarsRepo = "jars" val HdfsKey = "hdfs" val DerbyClass = "org.apache.spark.sql.crossdata.catalog.persistent.DerbyCatalog" val DefaultSecurityManager = "org.apache.spark.sql.crossdata.security.DefaultSecurityManager" val ZookeeperClass = "org.apache.spark.sql.crossdata.catalog.persistent.ZookeeperCatalog" - val ZookeeperStreamingClass = "org.apache.spark.sql.crossdata.catalog.streaming.ZookeeperStreamingCatalog" + val ZookeeperStreamingClass = + "org.apache.spark.sql.crossdata.catalog.streaming.ZookeeperStreamingCatalog" val StreamingConfigKey = "streaming" val SecurityConfigKey = "security" val SecurityManagerConfigKey = "manager" @@ -54,28 +54,28 @@ object CoreConfig { val SecurityClassConfigKey = s"$SecurityConfigKey.$SecurityManagerConfigKey.$ClassConfigKey" val SecurityAuditConfigKey = s"$SecurityConfigKey.$SecurityManagerConfigKey.$AuditConfigKey" val SecurityUserConfigKey = s"$SecurityConfigKey.$SecurityManagerConfigKey.$UserConfigKey" - val SecurityPasswordConfigKey = s"$SecurityConfigKey.$SecurityManagerConfigKey.$PasswordConfigKey" + val SecurityPasswordConfigKey = + s"$SecurityConfigKey.$SecurityManagerConfigKey.$PasswordConfigKey" val SecuritySessionConfigKey = s"$SecurityConfigKey.$SecurityManagerConfigKey.$SessionConfigKey" val SparkSqlConfigPrefix = "config.spark.sql" //WARNING!! 
XDServer is using this path to read its parameters - // WARNING: It only detects paths starting with "config.spark.sql" def configToSparkSQL(config: Config, defaultSqlConf: SQLConf = new SQLConf): SQLConf = { import scala.collection.JavaConversions._ - val sparkSQLProps: Map[String,String] = - config.entrySet() - .map(e => (e.getKey, e.getValue.unwrapped().toString)) - .toMap - .filterKeys(_.startsWith(CoreConfig.SparkSqlConfigPrefix)) - .map(e => (e._1.replace("config.", ""), e._2)) - + val sparkSQLProps: Map[String, String] = config + .entrySet() + .map(e => (e.getKey, e.getValue.unwrapped().toString)) + .toMap + .filterKeys(_.startsWith(CoreConfig.SparkSqlConfigPrefix)) + .map(e => (e._1.replace("config.", ""), e._2)) def sqlPropsToSQLConf(sparkSQLProps: Map[String, String], sqlConf: SQLConf): SQLConf = { - sparkSQLProps.foreach { case (key, value) => - sqlConf.setConfString(key, value) + sparkSQLProps.foreach { + case (key, value) => + sqlConf.setConfString(key, value) } sqlConf } @@ -128,17 +128,14 @@ trait CoreConfig extends Logging { } // System properties - val systemPropertiesConfig = - Try( + val systemPropertiesConfig = Try( ConfigFactory.parseProperties(System.getProperties).getConfig(ParentConfigName) - ).getOrElse( + ).getOrElse( ConfigFactory.parseProperties(System.getProperties) - ) + ) defaultConfig = systemPropertiesConfig.withFallback(defaultConfig) ConfigFactory.load(defaultConfig) } } - - diff --git a/core/src/main/scala/org/apache/spark/sql/crossdata/config/StreamingConfig.scala b/core/src/main/scala/org/apache/spark/sql/crossdata/config/StreamingConfig.scala index e5a09b2f5..0d8518fb0 100644 --- a/core/src/main/scala/org/apache/spark/sql/crossdata/config/StreamingConfig.scala +++ b/core/src/main/scala/org/apache/spark/sql/crossdata/config/StreamingConfig.scala @@ -15,7 +15,6 @@ */ package org.apache.spark.sql.crossdata.config - import org.apache.log4j.Logger import org.apache.spark.sql.crossdata.config.StreamingConstants._ import org.apache.spark.sql.crossdata.models._ @@ -28,25 +27,40 @@ object StreamingConfig extends CoreConfig { lazy val streamingConfig = config.getConfig(StreamingConstants.StreamingConfPath) - lazy val streamingConfigMap: Map[String, String] = - streamingConfig.entrySet().map(entry => (entry.getKey, streamingConfig.getAnyRef(entry.getKey).toString)).toMap + lazy val streamingConfigMap: Map[String, String] = streamingConfig + .entrySet() + .map(entry => (entry.getKey, streamingConfig.getAnyRef(entry.getKey).toString)) + .toMap - def createEphemeralTableModel(ident: String, opts : Map[String, String], userSchema: Option[StructType] = None) : EphemeralTableModel = { + def createEphemeralTableModel(ident: String, + opts: Map[String, String], + userSchema: Option[StructType] = None): EphemeralTableModel = { val finalOptions = getEphemeralTableOptions(ident, opts) - val connectionsModel = ConnectionHostModel(extractConnection(finalOptions, ZKConnection), extractConnection(finalOptions, KafkaConnection)) + val connectionsModel = ConnectionHostModel(extractConnection(finalOptions, ZKConnection), + extractConnection(finalOptions, KafkaConnection)) val topics = finalOptions(KafkaTopic) - .split(",").map(_.split(":")).map{ - case l if l.size == 2 => TopicModel(l(0), l(1).toInt) - }.toSeq + .split(",") + .map(_.split(":")) + .map { + case l if l.size == 2 => TopicModel(l(0), l(1).toInt) + } + .toSeq val groupId = finalOptions(KafkaGroupId) val partition = finalOptions.get(KafkaPartition) - val kafkaAdditionalOptions = finalOptions.filter{case (k, v) => 
k.startsWith(KafkaAdditionalOptionsKey)} + val kafkaAdditionalOptions = finalOptions.filter { + case (k, v) => k.startsWith(KafkaAdditionalOptionsKey) + } val storageLevel = finalOptions(ReceiverStorageLevel) - val kafkaOptions = KafkaOptionsModel(connectionsModel, topics, groupId, partition, kafkaAdditionalOptions, storageLevel) + val kafkaOptions = KafkaOptionsModel(connectionsModel, + topics, + groupId, + partition, + kafkaAdditionalOptions, + storageLevel) val minW = finalOptions(AtomicWindow).toInt val maxW = finalOptions(MaxWindow).toInt val outFormat = finalOptions(OutputFormat) match { @@ -55,10 +69,13 @@ object StreamingConfig extends CoreConfig { } val checkpointDirectory = s"${finalOptions(CheckpointDirectory)}/$ident" - val sparkOpts = finalOptions.filter{case (k, v) => k.startsWith(SparkConfPath)} + val sparkOpts = finalOptions.filter { + case (k, v) => k.startsWith(SparkConfPath) + } validateSparkConfig(sparkOpts) - val ephemeralOptions = EphemeralOptionsModel(kafkaOptions, minW, maxW, outFormat, checkpointDirectory, sparkOpts) + val ephemeralOptions = + EphemeralOptionsModel(kafkaOptions, minW, maxW, outFormat, checkpointDirectory, sparkOpts) EphemeralTableModel(ident, ephemeralOptions, userSchema) } @@ -72,10 +89,10 @@ object StreamingConfig extends CoreConfig { } } + private def getEphemeralTableOptions(ephTable: String, + opts: Map[String, String]): Map[String, String] = { - private def getEphemeralTableOptions(ephTable: String, opts : Map[String, String]): Map[String, String] = { - - listMandatoryEphemeralTableKeys.foreach{ mandatoryOption => + listMandatoryEphemeralTableKeys.foreach { mandatoryOption => if (opts.get(mandatoryOption).isEmpty) notFound(mandatoryOption) } streamingConfigMap ++ opts @@ -88,11 +105,11 @@ object StreamingConfig extends CoreConfig { } private def validateSparkConfig(config: Map[String, String]): Unit = { - config.get(SparkCoresMax).foreach{ maxCores => - if (maxCores.toInt < 2) throw new RuntimeException(s"At least 2 cores are required to launch streaming applications") + config.get(SparkCoresMax).foreach { maxCores => + if (maxCores.toInt < 2) + throw new RuntimeException( + s"At least 2 cores are required to launch streaming applications") } } } - - diff --git a/core/src/main/scala/org/apache/spark/sql/crossdata/config/StreamingConstants.scala b/core/src/main/scala/org/apache/spark/sql/crossdata/config/StreamingConstants.scala index 88821346f..4ae64f6ec 100644 --- a/core/src/main/scala/org/apache/spark/sql/crossdata/config/StreamingConstants.scala +++ b/core/src/main/scala/org/apache/spark/sql/crossdata/config/StreamingConstants.scala @@ -15,7 +15,6 @@ */ package org.apache.spark.sql.crossdata.config - object StreamingConstants { val MainClass = "com.stratio.crossdata.streaming.CrossdataStreamingApplication" @@ -23,17 +22,17 @@ object StreamingConstants { val SparkConfPath = "spark" /** - * Ephemeral table config - */ + * Ephemeral table config + */ /** - * Receiver - */ + * Receiver + */ //Connection //format "host0:consumerPort0,host1:consumerPort1,host2:consumerPort2" val ZKConnection = "receiver.zk.connection" //Connection //format "host0:producerPort0,host1:producerPort1,host2:producerPort2" val KafkaConnection = "receiver.kafka.connection" //format "topicName1:1,topicName1:2,topicName1:3" - val KafkaTopic= "receiver.kafka.topic" + val KafkaTopic = "receiver.kafka.topic" val KafkaGroupId = "receiver.kafka.groupId" val KafkaPartition = "receiver.kafka.numPartitions" //optional //would go through additionalOptions param. 
@@ -42,16 +41,16 @@ object StreamingConstants { val ReceiverStorageLevel = "receiver.storageLevel" //optional /** - * Streaming generic options - */ + * Streaming generic options + */ val AtomicWindow = "atomicWindow" val MaxWindow = "maxWindow" val OutputFormat = "outputFormat" val CheckpointDirectory = "checkpointDirectory" /** - * SparkOptions - */ + * SparkOptions + */ // One param for each element map. key = sparkOptions.x -> value = value val ZooKeeperStreamingCatalogPath = "catalog.zookeeper" val SparkHomeKey = "sparkHome" @@ -62,9 +61,6 @@ object StreamingConstants { val HdfsConf = "hdfs" // TODO define and validate mandatory options like spark.master, kafka.topic and groupid, ... - val listMandatoryEphemeralTableKeys = List( - KafkaTopic, - KafkaGroupId) - + val listMandatoryEphemeralTableKeys = List(KafkaTopic, KafkaGroupId) } diff --git a/core/src/main/scala/org/apache/spark/sql/crossdata/daos/AppDAO.scala b/core/src/main/scala/org/apache/spark/sql/crossdata/daos/AppDAO.scala index a9d010c42..d3cbcbc41 100644 --- a/core/src/main/scala/org/apache/spark/sql/crossdata/daos/AppDAO.scala +++ b/core/src/main/scala/org/apache/spark/sql/crossdata/daos/AppDAO.scala @@ -37,8 +37,11 @@ import org.apache.spark.sql.crossdata.daos.DAOConstants._ import org.apache.spark.sql.crossdata.models.{AppModel, ViewModel} import org.apache.spark.sql.crossdata.serializers.CrossdataSerializer -trait AppDAO extends GenericDAOComponent[AppModel] -with TypesafeConfigComponent with SparkLoggerComponent with CrossdataSerializer { +trait AppDAO + extends GenericDAOComponent[AppModel] + with TypesafeConfigComponent + with SparkLoggerComponent + with CrossdataSerializer { val appID = "appID" diff --git a/core/src/main/scala/org/apache/spark/sql/crossdata/daos/EphemeralQueriesDAO.scala b/core/src/main/scala/org/apache/spark/sql/crossdata/daos/EphemeralQueriesDAO.scala index 3306a2e45..2bb760f0b 100644 --- a/core/src/main/scala/org/apache/spark/sql/crossdata/daos/EphemeralQueriesDAO.scala +++ b/core/src/main/scala/org/apache/spark/sql/crossdata/daos/EphemeralQueriesDAO.scala @@ -23,11 +23,14 @@ import org.apache.spark.sql.crossdata.models.EphemeralQueryModel import org.apache.spark.sql.crossdata.serializers.CrossdataSerializer import org.json4s.Formats -trait EphemeralQueriesDAO extends GenericDAOComponent[EphemeralQueryModel] -with TypesafeConfigComponent with SparkLoggerComponent with CrossdataSerializer { +trait EphemeralQueriesDAO + extends GenericDAOComponent[EphemeralQueryModel] + with TypesafeConfigComponent + with SparkLoggerComponent + with CrossdataSerializer { private val jacksonFormats: Formats = json4sJacksonFormats override implicit val formats = jacksonFormats override val dao: DAO = new GenericDAO(Option(EphemeralQueriesPath)) -} \ No newline at end of file +} diff --git a/core/src/main/scala/org/apache/spark/sql/crossdata/daos/EphemeralQueriesMapDAO.scala b/core/src/main/scala/org/apache/spark/sql/crossdata/daos/EphemeralQueriesMapDAO.scala index 336e5a9f3..d41ecc94d 100644 --- a/core/src/main/scala/org/apache/spark/sql/crossdata/daos/EphemeralQueriesMapDAO.scala +++ b/core/src/main/scala/org/apache/spark/sql/crossdata/daos/EphemeralQueriesMapDAO.scala @@ -22,8 +22,11 @@ import org.apache.spark.sql.crossdata.daos.DAOConstants._ import org.apache.spark.sql.crossdata.models.EphemeralQueryModel import org.apache.spark.sql.crossdata.serializers.CrossdataSerializer -trait EphemeralQueriesMapDAO extends GenericDAOComponent[EphemeralQueryModel] -with MapConfigComponent with SparkLoggerComponent with 
CrossdataSerializer { +trait EphemeralQueriesMapDAO + extends GenericDAOComponent[EphemeralQueryModel] + with MapConfigComponent + with SparkLoggerComponent + with CrossdataSerializer { override implicit val formats = json4sJacksonFormats diff --git a/core/src/main/scala/org/apache/spark/sql/crossdata/daos/EphemeralTableDAO.scala b/core/src/main/scala/org/apache/spark/sql/crossdata/daos/EphemeralTableDAO.scala index 9c34b5541..e15bcfc39 100644 --- a/core/src/main/scala/org/apache/spark/sql/crossdata/daos/EphemeralTableDAO.scala +++ b/core/src/main/scala/org/apache/spark/sql/crossdata/daos/EphemeralTableDAO.scala @@ -22,8 +22,11 @@ import org.apache.spark.sql.crossdata.daos.DAOConstants._ import org.apache.spark.sql.crossdata.models.EphemeralTableModel import org.apache.spark.sql.crossdata.serializers.CrossdataSerializer -trait EphemeralTableDAO extends GenericDAOComponent[EphemeralTableModel] -with TypesafeConfigComponent with SparkLoggerComponent with CrossdataSerializer { +trait EphemeralTableDAO + extends GenericDAOComponent[EphemeralTableModel] + with TypesafeConfigComponent + with SparkLoggerComponent + with CrossdataSerializer { val ephemeralTableIdField = "EphemeralTableID" diff --git a/core/src/main/scala/org/apache/spark/sql/crossdata/daos/EphemeralTableMapDAO.scala b/core/src/main/scala/org/apache/spark/sql/crossdata/daos/EphemeralTableMapDAO.scala index 97c1ee1a0..0c5956f99 100644 --- a/core/src/main/scala/org/apache/spark/sql/crossdata/daos/EphemeralTableMapDAO.scala +++ b/core/src/main/scala/org/apache/spark/sql/crossdata/daos/EphemeralTableMapDAO.scala @@ -22,10 +22,13 @@ import org.apache.spark.sql.crossdata.daos.DAOConstants._ import org.apache.spark.sql.crossdata.models.EphemeralTableModel import org.apache.spark.sql.crossdata.serializers.CrossdataSerializer -trait EphemeralTableMapDAO extends GenericDAOComponent[EphemeralTableModel] -with MapConfigComponent with SparkLoggerComponent with CrossdataSerializer { +trait EphemeralTableMapDAO + extends GenericDAOComponent[EphemeralTableModel] + with MapConfigComponent + with SparkLoggerComponent + with CrossdataSerializer { override implicit val formats = json4sJacksonFormats override val dao: DAO = new GenericDAO(Option(EphemeralTablesPath)) -} \ No newline at end of file +} diff --git a/core/src/main/scala/org/apache/spark/sql/crossdata/daos/EphemeralTableStatusDAO.scala b/core/src/main/scala/org/apache/spark/sql/crossdata/daos/EphemeralTableStatusDAO.scala index 33d3760ca..236abf45a 100644 --- a/core/src/main/scala/org/apache/spark/sql/crossdata/daos/EphemeralTableStatusDAO.scala +++ b/core/src/main/scala/org/apache/spark/sql/crossdata/daos/EphemeralTableStatusDAO.scala @@ -22,8 +22,11 @@ import org.apache.spark.sql.crossdata.daos.DAOConstants._ import org.apache.spark.sql.crossdata.models.EphemeralStatusModel import org.apache.spark.sql.crossdata.serializers.CrossdataSerializer -trait EphemeralTableStatusDAO extends GenericDAOComponent[EphemeralStatusModel] -with TypesafeConfigComponent with SparkLoggerComponent with CrossdataSerializer { +trait EphemeralTableStatusDAO + extends GenericDAOComponent[EphemeralStatusModel] + with TypesafeConfigComponent + with SparkLoggerComponent + with CrossdataSerializer { override implicit val formats = json4sJacksonFormats diff --git a/core/src/main/scala/org/apache/spark/sql/crossdata/daos/EphemeralTableStatusMapDAO.scala b/core/src/main/scala/org/apache/spark/sql/crossdata/daos/EphemeralTableStatusMapDAO.scala index 247c005a5..946d70a96 100644 --- 
a/core/src/main/scala/org/apache/spark/sql/crossdata/daos/EphemeralTableStatusMapDAO.scala +++ b/core/src/main/scala/org/apache/spark/sql/crossdata/daos/EphemeralTableStatusMapDAO.scala @@ -22,11 +22,14 @@ import org.apache.spark.sql.crossdata.daos.DAOConstants._ import org.apache.spark.sql.crossdata.models.EphemeralStatusModel import org.apache.spark.sql.crossdata.serializers.CrossdataSerializer -trait EphemeralTableStatusMapDAO extends GenericDAOComponent[EphemeralStatusModel] -with MapConfigComponent with SparkLoggerComponent with CrossdataSerializer { +trait EphemeralTableStatusMapDAO + extends GenericDAOComponent[EphemeralStatusModel] + with MapConfigComponent + with SparkLoggerComponent + with CrossdataSerializer { -override implicit val formats = json4sJacksonFormats + override implicit val formats = json4sJacksonFormats -override val dao: DAO = new GenericDAO(Option(EphemeralTableStatusPath)) + override val dao: DAO = new GenericDAO(Option(EphemeralTableStatusPath)) -} \ No newline at end of file +} diff --git a/core/src/main/scala/org/apache/spark/sql/crossdata/daos/IndexDAO.scala b/core/src/main/scala/org/apache/spark/sql/crossdata/daos/IndexDAO.scala index e365ad2a2..5b3496870 100644 --- a/core/src/main/scala/org/apache/spark/sql/crossdata/daos/IndexDAO.scala +++ b/core/src/main/scala/org/apache/spark/sql/crossdata/daos/IndexDAO.scala @@ -37,8 +37,11 @@ import org.apache.spark.sql.crossdata.daos.DAOConstants._ import org.apache.spark.sql.crossdata.models.IndexModel import org.apache.spark.sql.crossdata.serializers.CrossdataSerializer -trait IndexDAO extends GenericDAOComponent[IndexModel] -with TypesafeConfigComponent with SparkLoggerComponent with CrossdataSerializer { +trait IndexDAO + extends GenericDAOComponent[IndexModel] + with TypesafeConfigComponent + with SparkLoggerComponent + with CrossdataSerializer { val indexID = "indexID" diff --git a/core/src/main/scala/org/apache/spark/sql/crossdata/daos/TableDAO.scala b/core/src/main/scala/org/apache/spark/sql/crossdata/daos/TableDAO.scala index 5c79c17ba..cdaec212f 100644 --- a/core/src/main/scala/org/apache/spark/sql/crossdata/daos/TableDAO.scala +++ b/core/src/main/scala/org/apache/spark/sql/crossdata/daos/TableDAO.scala @@ -28,7 +28,6 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ - package org.apache.spark.sql.crossdata.daos import com.stratio.common.utils.components.config.impl.TypesafeConfigComponent @@ -38,11 +37,14 @@ import org.apache.spark.sql.crossdata.daos.DAOConstants._ import org.apache.spark.sql.crossdata.models.TableModel import org.apache.spark.sql.crossdata.serializers.CrossdataSerializer -trait TableDAO extends GenericDAOComponent[TableModel] -with TypesafeConfigComponent with SparkLoggerComponent with CrossdataSerializer { +trait TableDAO + extends GenericDAOComponent[TableModel] + with TypesafeConfigComponent + with SparkLoggerComponent + with CrossdataSerializer { override implicit val formats = json4sJacksonFormats override val dao: DAO = new GenericDAO(Option(TablesPath)) -} \ No newline at end of file +} diff --git a/core/src/main/scala/org/apache/spark/sql/crossdata/daos/ViewDAO.scala b/core/src/main/scala/org/apache/spark/sql/crossdata/daos/ViewDAO.scala index 171a1b870..ced22fe47 100644 --- a/core/src/main/scala/org/apache/spark/sql/crossdata/daos/ViewDAO.scala +++ b/core/src/main/scala/org/apache/spark/sql/crossdata/daos/ViewDAO.scala @@ -37,8 +37,11 @@ import org.apache.spark.sql.crossdata.daos.DAOConstants._ import org.apache.spark.sql.crossdata.models.{ViewModel, EphemeralTableModel} import org.apache.spark.sql.crossdata.serializers.CrossdataSerializer -trait ViewDAO extends GenericDAOComponent[ViewModel] -with TypesafeConfigComponent with SparkLoggerComponent with CrossdataSerializer { +trait ViewDAO + extends GenericDAOComponent[ViewModel] + with TypesafeConfigComponent + with SparkLoggerComponent + with CrossdataSerializer { val viewID = "viewID" diff --git a/core/src/main/scala/org/apache/spark/sql/crossdata/daos/impl/EphemeralQueriesMapDAO.scala b/core/src/main/scala/org/apache/spark/sql/crossdata/daos/impl/EphemeralQueriesMapDAO.scala index 2570f55c2..211e603fa 100644 --- a/core/src/main/scala/org/apache/spark/sql/crossdata/daos/impl/EphemeralQueriesMapDAO.scala +++ b/core/src/main/scala/org/apache/spark/sql/crossdata/daos/impl/EphemeralQueriesMapDAO.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.crossdata.daos.impl import org.apache.spark.sql.crossdata.daos.{EphemeralQueriesMapDAO => EphQueriesMapDAO} class EphemeralQueriesMapDAO(opts: Map[String, String], subPath: Option[String] = None) - extends EphQueriesMapDAO{ + extends EphQueriesMapDAO { val memoryMap = opts override lazy val config: Config = new DummyConfig(subPath) diff --git a/core/src/main/scala/org/apache/spark/sql/crossdata/daos/impl/EphemeralTableMapDAO.scala b/core/src/main/scala/org/apache/spark/sql/crossdata/daos/impl/EphemeralTableMapDAO.scala index 4a8a3b0ee..7f1f429ed 100644 --- a/core/src/main/scala/org/apache/spark/sql/crossdata/daos/impl/EphemeralTableMapDAO.scala +++ b/core/src/main/scala/org/apache/spark/sql/crossdata/daos/impl/EphemeralTableMapDAO.scala @@ -17,8 +17,8 @@ package org.apache.spark.sql.crossdata.daos.impl import org.apache.spark.sql.crossdata.daos.{EphemeralTableMapDAO => EphTableMapDAO} -class EphemeralTableMapDAO (opts: Map[String, Any], subPath: Option[String] = None) - extends EphTableMapDAO{ +class EphemeralTableMapDAO(opts: Map[String, Any], subPath: Option[String] = None) + extends EphTableMapDAO { val memoryMap = opts override lazy val config: Config = new DummyConfig(subPath) diff --git a/core/src/main/scala/org/apache/spark/sql/crossdata/daos/impl/EphemeralTableStatusMapDAO.scala b/core/src/main/scala/org/apache/spark/sql/crossdata/daos/impl/EphemeralTableStatusMapDAO.scala index 63b3f9884..1de451a0f 100644 --- 
a/core/src/main/scala/org/apache/spark/sql/crossdata/daos/impl/EphemeralTableStatusMapDAO.scala +++ b/core/src/main/scala/org/apache/spark/sql/crossdata/daos/impl/EphemeralTableStatusMapDAO.scala @@ -15,9 +15,9 @@ */ package org.apache.spark.sql.crossdata.daos.impl -import org.apache.spark.sql.crossdata.daos.{EphemeralTableStatusMapDAO => EphTableStatusMapDAO } -class EphemeralTableStatusMapDAO (opts: Map[String, String], subPath: Option[String] = None) - extends EphTableStatusMapDAO{ +import org.apache.spark.sql.crossdata.daos.{EphemeralTableStatusMapDAO => EphTableStatusMapDAO} +class EphemeralTableStatusMapDAO(opts: Map[String, String], subPath: Option[String] = None) + extends EphTableStatusMapDAO { val memoryMap = opts override lazy val config: Config = new DummyConfig(subPath) diff --git a/core/src/main/scala/org/apache/spark/sql/crossdata/daos/impl/EphemeralTableStatusTypesafeDAO.scala b/core/src/main/scala/org/apache/spark/sql/crossdata/daos/impl/EphemeralTableStatusTypesafeDAO.scala index f865bf686..6a8d71307 100644 --- a/core/src/main/scala/org/apache/spark/sql/crossdata/daos/impl/EphemeralTableStatusTypesafeDAO.scala +++ b/core/src/main/scala/org/apache/spark/sql/crossdata/daos/impl/EphemeralTableStatusTypesafeDAO.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.crossdata.daos.impl import com.typesafe.config.Config import org.apache.spark.sql.crossdata.daos.EphemeralTableStatusDAO -class EphemeralTableStatusTypesafeDAO (configuration: Config) extends EphemeralTableStatusDAO { +class EphemeralTableStatusTypesafeDAO(configuration: Config) extends EphemeralTableStatusDAO { override val config = new TypesafeConfig(Option(configuration)) diff --git a/core/src/main/scala/org/apache/spark/sql/crossdata/daos/impl/EphemeralTableTypesafeDAO.scala b/core/src/main/scala/org/apache/spark/sql/crossdata/daos/impl/EphemeralTableTypesafeDAO.scala index 99f167f10..dd1ffedf7 100644 --- a/core/src/main/scala/org/apache/spark/sql/crossdata/daos/impl/EphemeralTableTypesafeDAO.scala +++ b/core/src/main/scala/org/apache/spark/sql/crossdata/daos/impl/EphemeralTableTypesafeDAO.scala @@ -23,4 +23,3 @@ class EphemeralTableTypesafeDAO(configuration: Config) extends EphemeralTableDAO override val config = new TypesafeConfig(Option(configuration)) } - diff --git a/core/src/main/scala/org/apache/spark/sql/crossdata/exceptions/CrossdataException.scala b/core/src/main/scala/org/apache/spark/sql/crossdata/exceptions/CrossdataException.scala index 95270679b..8d2e7c5a1 100644 --- a/core/src/main/scala/org/apache/spark/sql/crossdata/exceptions/CrossdataException.scala +++ b/core/src/main/scala/org/apache/spark/sql/crossdata/exceptions/CrossdataException.scala @@ -15,14 +15,13 @@ */ package org.apache.spark.sql.crossdata.exceptions -class CrossdataException(message: String, cause: Throwable) - extends Exception(message, cause) { +class CrossdataException(message: String, cause: Throwable) extends Exception(message, cause) { def this(message: String) = this(message, null) } /** - * Exception thrown when a Native [[org.apache.spark.sql.crossdata.ExecutionType]] fails. - */ + * Exception thrown when a Native [[org.apache.spark.sql.crossdata.ExecutionType]] fails. 
+ */ private[spark] class NativeExecutionException - extends CrossdataException("The operation cannot be executed without Spark") + extends CrossdataException("The operation cannot be executed without Spark") diff --git a/core/src/main/scala/org/apache/spark/sql/crossdata/launcher/SparkJob.scala b/core/src/main/scala/org/apache/spark/sql/crossdata/launcher/SparkJob.scala index 76889c628..d2dee8ecb 100644 --- a/core/src/main/scala/org/apache/spark/sql/crossdata/launcher/SparkJob.scala +++ b/core/src/main/scala/org/apache/spark/sql/crossdata/launcher/SparkJob.scala @@ -22,7 +22,8 @@ import scala.concurrent.{ExecutionContext, Future} import scala.io.Source import scala.util.{Failure, Success, Try} -class SparkJob(sparkLauncher: SparkLauncher)(implicit executionContext: ExecutionContext) extends SparkLoggerComponent{ +class SparkJob(sparkLauncher: SparkLauncher)(implicit executionContext: ExecutionContext) + extends SparkLoggerComponent { def submit(): Unit = Future[(Int, Process)] { diff --git a/core/src/main/scala/org/apache/spark/sql/crossdata/launcher/SparkJobLauncher.scala b/core/src/main/scala/org/apache/spark/sql/crossdata/launcher/SparkJobLauncher.scala index 11a9b05ca..ff8606dc8 100644 --- a/core/src/main/scala/org/apache/spark/sql/crossdata/launcher/SparkJobLauncher.scala +++ b/core/src/main/scala/org/apache/spark/sql/crossdata/launcher/SparkJobLauncher.scala @@ -35,34 +35,57 @@ import scala.util.{Properties, Try} object SparkJobLauncher extends SparkLoggerComponent with CrossdataSerializer { - def getSparkStreamingJob(xdContext: XDContext, crossdataConfig: Config, ephemeralTableName: String) - (implicit executionContext: ExecutionContext): Try[SparkJob] = Try { - val streamingConfig = crossdataConfig.getConfig(StreamingConfPath) - val launcherConfig=crossdataConfig.getConfig(LauncherKey) - val sparkHome = - Properties.envOrNone("SPARK_HOME").orElse(Try(launcherConfig.getString(SparkHomeKey)).toOption).getOrElse( - throw new RuntimeException("You must set the $SPARK_HOME path in configuration or environment") - ) + def getSparkStreamingJob( + xdContext: XDContext, + crossdataConfig: Config, + ephemeralTableName: String)(implicit executionContext: ExecutionContext): Try[SparkJob] = + Try { + val streamingConfig = crossdataConfig.getConfig(StreamingConfPath) + val launcherConfig = crossdataConfig.getConfig(LauncherKey) + val sparkHome = Properties + .envOrNone("SPARK_HOME") + .orElse(Try(launcherConfig.getString(SparkHomeKey)).toOption) + .getOrElse( + throw new RuntimeException( + "You must set the $SPARK_HOME path in configuration or environment") + ) + + val eTable = xdContext.catalog + .getEphemeralTable(ephemeralTableName) + .getOrElse(notFound(ephemeralTableName)) + val appName = s"${eTable.name}_${UUID.randomUUID()}" + val zkConfigEncoded: String = encode(render(streamingConfig, ZooKeeperStreamingCatalogPath)) + val catalogConfigEncoded: String = + encode(render(crossdataConfig, CoreConfig.CatalogConfigKey)) + val appArgs = Seq(eTable.name, zkConfigEncoded, catalogConfigEncoded) + val master = streamingConfig.getString(SparkMasterKey) + val jar = streamingConfig.getString(AppJarKey) + val jars = Try(streamingConfig.getStringList(ExternalJarsKey).toSeq).getOrElse(Seq.empty) + val sparkConfig: Map[String, String] = sparkConf(streamingConfig) + if (master.toLowerCase.contains("mesos")) { + val hdfsPath = getHdfsPath(crossdataConfig, jar) + getJob(sparkHome, + StreamingConstants.MainClass, + appArgs, + appName, + master, + hdfsPath, + sparkConfig, + jars)(executionContext) + + } 
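// Illustrative sketch, not part of the patch: getSparkStreamingJob above resolves SPARK_HOME
// from the environment first and then from the launcher configuration, failing fast otherwise.
// The same fallback chain with a plain Map standing in for the Typesafe Config (the Map and
// its "sparkHome" key are assumptions of this sketch, not the project's API):
import scala.util.{Properties, Try}

object SparkHomeResolutionDemo extends App {
  val launcherConfig: Map[String, String] = Map("sparkHome" -> "/opt/spark")

  val sparkHome: String = Properties
    .envOrNone("SPARK_HOME")                           // 1) environment variable
    .orElse(Try(launcherConfig("sparkHome")).toOption) // 2) configuration entry
    .getOrElse(                                        // 3) fail fast, as the patch does
      throw new RuntimeException(
        "You must set the $SPARK_HOME path in configuration or environment"))

  println(s"Using SPARK_HOME=$sparkHome")
}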
else { + getJob(sparkHome, + StreamingConstants.MainClass, + appArgs, + appName, + master, + jar, + sparkConfig, + jars)(executionContext) + } - val eTable = xdContext.catalog.getEphemeralTable(ephemeralTableName).getOrElse(notFound(ephemeralTableName)) - val appName = s"${eTable.name}_${UUID.randomUUID()}" - val zkConfigEncoded: String = encode(render(streamingConfig, ZooKeeperStreamingCatalogPath)) - val catalogConfigEncoded: String = encode(render(crossdataConfig, CoreConfig.CatalogConfigKey)) - val appArgs = Seq(eTable.name, zkConfigEncoded, catalogConfigEncoded) - val master = streamingConfig.getString(SparkMasterKey) - val jar = streamingConfig.getString(AppJarKey) - val jars = Try(streamingConfig.getStringList(ExternalJarsKey).toSeq).getOrElse(Seq.empty) - val sparkConfig: Map[String, String] = sparkConf(streamingConfig) - if (master.toLowerCase.contains("mesos")) { - val hdfsPath = getHdfsPath(crossdataConfig, jar) - getJob(sparkHome, StreamingConstants.MainClass, appArgs, appName, master, hdfsPath, sparkConfig, jars)(executionContext) - - } else { - getJob(sparkHome, StreamingConstants.MainClass, appArgs, appName, master, jar, sparkConfig, jars)(executionContext) } - } - /** * This method return the HDFS path of the streaming jar and if not exists previously it writes the jar in HDFS * @@ -86,19 +109,28 @@ object SparkJobLauncher extends SparkLoggerComponent with CrossdataSerializer { s"hdfs://$hdfsMaster/$destPath/$jarName" } - def getSparkJob(launcherConfig:Config, master: String, main: String, args: Seq[String], jar: String, appName: String, submitOptions: Option[Map[String, String]]) - (implicit executionContext: ExecutionContext): Try[SparkJob] = Try { - - val sparkHome = - Properties.envOrNone("SPARK_HOME").orElse(Try(launcherConfig.getString(SparkHomeKey)).toOption).getOrElse( - throw new RuntimeException("You must set the $SPARK_HOME path in configuration or environment") + def getSparkJob(launcherConfig: Config, + master: String, + main: String, + args: Seq[String], + jar: String, + appName: String, + submitOptions: Option[Map[String, String]])( + implicit executionContext: ExecutionContext): Try[SparkJob] = Try { + + val sparkHome = Properties + .envOrNone("SPARK_HOME") + .orElse(Try(launcherConfig.getString(SparkHomeKey)).toOption) + .getOrElse( + throw new RuntimeException( + "You must set the $SPARK_HOME path in configuration or environment") ) //due to the parser doesn't allow middle-score symbol and spark submit properties are all with that, we are using '.' instead of '-'. So now we map to '-' again val sparkConfig = submitOptions.getOrElse(Map.empty) map { case (k, v) => - val key=k.replaceAll("\\.", "-") - ("spark."+key, v) + val key = k.replaceAll("\\.", "-") + ("spark." 
+ key, v) } getJob(sparkHome, main, args, appName, master, jar, sparkConfig)(executionContext) @@ -108,15 +140,15 @@ object SparkJobLauncher extends SparkLoggerComponent with CrossdataSerializer { sparkJob.submit() } - private def getJob(sparkHome: String, - appMain: String, - appArgs: Seq[String], - appName: String, - master: String, - jar: String, - sparkConf: Map[String, String] = Map.empty, - externalJars: Seq[String] = Seq.empty - )(executionContext: ExecutionContext): SparkJob = { + private def getJob( + sparkHome: String, + appMain: String, + appArgs: Seq[String], + appName: String, + master: String, + jar: String, + sparkConf: Map[String, String] = Map.empty, + externalJars: Seq[String] = Seq.empty)(executionContext: ExecutionContext): SparkJob = { val sparkLauncher = new SparkLauncher() .setSparkHome(sparkHome) .setAppName(appName) @@ -137,15 +169,19 @@ object SparkJobLauncher extends SparkLoggerComponent with CrossdataSerializer { private def sparkConf(streamingConfig: Config): Map[String, String] = typeSafeConfigToMapString(streamingConfig, Option(SparkConfPath)) - - private def typeSafeConfigToMapString(config: Config, path: Option[String] = None): Map[String, String] = { + private def typeSafeConfigToMapString(config: Config, + path: Option[String] = None): Map[String, String] = { val conf = path.map(config.getConfig).getOrElse(config) - conf.entrySet().toSeq.map(e => - (s"${path.fold("")(_ + ".") + e.getKey}", conf.getAnyRef(e.getKey).toString) - ).toMap + conf + .entrySet() + .toSeq + .map(e => (s"${path.fold("")(_ + ".") + e.getKey}", conf.getAnyRef(e.getKey).toString)) + .toMap } - private def render(config: Config, path: String): String = config.getConfig(path).atPath(path).root.render(ConfigRenderOptions.concise) + private def render(config: Config, path: String): String = + config.getConfig(path).atPath(path).root.render(ConfigRenderOptions.concise) - private def encode(value: String): String = BaseEncoding.base64().encode(value.getBytes) -} \ No newline at end of file + private def encode(value: String): String = + BaseEncoding.base64().encode(value.getBytes) +} diff --git a/core/src/main/scala/org/apache/spark/sql/crossdata/metrics/XDMetricsSource.scala b/core/src/main/scala/org/apache/spark/sql/crossdata/metrics/XDMetricsSource.scala index e88e65388..a758febe6 100644 --- a/core/src/main/scala/org/apache/spark/sql/crossdata/metrics/XDMetricsSource.scala +++ b/core/src/main/scala/org/apache/spark/sql/crossdata/metrics/XDMetricsSource.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.crossdata.metrics import com.codahale.metrics.{Gauge, MetricRegistry} import org.apache.spark.metrics.source.Source -class XDMetricsSource extends Source{ +class XDMetricsSource extends Source { override val sourceName = "XDMetricsSource" override val metricRegistry = new MetricRegistry() diff --git a/core/src/main/scala/org/apache/spark/sql/crossdata/models/AppModel.scala b/core/src/main/scala/org/apache/spark/sql/crossdata/models/AppModel.scala index d1efdf461..c39e63fa7 100644 --- a/core/src/main/scala/org/apache/spark/sql/crossdata/models/AppModel.scala +++ b/core/src/main/scala/org/apache/spark/sql/crossdata/models/AppModel.scala @@ -15,7 +15,4 @@ */ package org.apache.spark.sql.crossdata.models - -case class AppModel(jar: String, - appAlias: String, - appClass: String) +case class AppModel(jar: String, appAlias: String, appClass: String) diff --git a/core/src/main/scala/org/apache/spark/sql/crossdata/models/ConnectionHostModel.scala 
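// Illustrative sketch, not part of the patch: getSparkJob above rewrites user-supplied submit
// options, whose keys were parsed with '.' in place of '-', into "spark."-prefixed keys with
// the dashes restored. A standalone rendering of that mapping (the sample keys are invented):
object SubmitOptionsMappingDemo extends App {
  val submitOptions: Option[Map[String, String]] =
    Some(Map("executor.cores" -> "2", "driver.memory" -> "1g"))

  val sparkConfig: Map[String, String] = submitOptions.getOrElse(Map.empty) map {
    case (k, v) =>
      val key = k.replaceAll("\\.", "-") // the parser accepted '.' where '-' was meant
      ("spark." + key, v)
  }

  println(sparkConfig) // Map(spark.executor-cores -> 2, spark.driver-memory -> 1g)
}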
b/core/src/main/scala/org/apache/spark/sql/crossdata/models/ConnectionHostModel.scala index e7f268408..ea2479acf 100644 --- a/core/src/main/scala/org/apache/spark/sql/crossdata/models/ConnectionHostModel.scala +++ b/core/src/main/scala/org/apache/spark/sql/crossdata/models/ConnectionHostModel.scala @@ -15,7 +15,8 @@ */ package org.apache.spark.sql.crossdata.models -case class ConnectionHostModel(zkConnection: Seq[ConnectionModel], kafkaConnection: Seq[ConnectionModel]){ +case class ConnectionHostModel(zkConnection: Seq[ConnectionModel], + kafkaConnection: Seq[ConnectionModel]) { - def toPrettyString : String = ModelUtils.modelToJsonString(this) -} \ No newline at end of file + def toPrettyString: String = ModelUtils.modelToJsonString(this) +} diff --git a/core/src/main/scala/org/apache/spark/sql/crossdata/models/ConnectionModel.scala b/core/src/main/scala/org/apache/spark/sql/crossdata/models/ConnectionModel.scala index 08b315c08..1b8d27604 100644 --- a/core/src/main/scala/org/apache/spark/sql/crossdata/models/ConnectionModel.scala +++ b/core/src/main/scala/org/apache/spark/sql/crossdata/models/ConnectionModel.scala @@ -15,7 +15,7 @@ */ package org.apache.spark.sql.crossdata.models -case class ConnectionModel(host: String, port: Int){ +case class ConnectionModel(host: String, port: Int) { - def toPrettyString : String = ModelUtils.modelToJsonString(this) + def toPrettyString: String = ModelUtils.modelToJsonString(this) } diff --git a/core/src/main/scala/org/apache/spark/sql/crossdata/models/EphemeralOptionsModel.scala b/core/src/main/scala/org/apache/spark/sql/crossdata/models/EphemeralOptionsModel.scala index 95fe4cac3..ab842c737 100644 --- a/core/src/main/scala/org/apache/spark/sql/crossdata/models/EphemeralOptionsModel.scala +++ b/core/src/main/scala/org/apache/spark/sql/crossdata/models/EphemeralOptionsModel.scala @@ -20,20 +20,21 @@ import org.apache.spark.sql.crossdata.models.EphemeralOptionsModel._ case class EphemeralOptionsModel(kafkaOptions: KafkaOptionsModel, atomicWindow: Int = DefaultAtomicWindow, maxWindow: Int = DefaultMaxWindow, - outputFormat: EphemeralOutputFormat.Value = EphemeralOutputFormat.ROW, + outputFormat: EphemeralOutputFormat.Value = + EphemeralOutputFormat.ROW, checkpointDirectory: String = DefaultCheckpointDirectory, sparkOptions: Map[String, String] = Map.empty) { - def toPrettyString : String = ModelUtils.modelToJsonString(this) + def toPrettyString: String = ModelUtils.modelToJsonString(this) } object EphemeralOptionsModel { /** - * Default minimum Time in Seconds for the Batch Interval in SparkStreaming. - * This parameter mark the the minimum time for the windowed queries - */ + * Default minimum Time in Seconds for the Batch Interval in SparkStreaming. 
+ * This parameter mark the the minimum time for the windowed queries + */ val DefaultAtomicWindow = 5 val DefaultMaxWindow = 60 val DefaultCheckpointDirectory = "checkpoint/crossdata" -} \ No newline at end of file +} diff --git a/core/src/main/scala/org/apache/spark/sql/crossdata/models/EphemeralOutputFormat.scala b/core/src/main/scala/org/apache/spark/sql/crossdata/models/EphemeralOutputFormat.scala index 3dcf95952..652393d13 100644 --- a/core/src/main/scala/org/apache/spark/sql/crossdata/models/EphemeralOutputFormat.scala +++ b/core/src/main/scala/org/apache/spark/sql/crossdata/models/EphemeralOutputFormat.scala @@ -20,4 +20,3 @@ object EphemeralOutputFormat extends Enumeration { type Status = Value val ROW, JSON = Value } - diff --git a/core/src/main/scala/org/apache/spark/sql/crossdata/models/EphemeralQueryModel.scala b/core/src/main/scala/org/apache/spark/sql/crossdata/models/EphemeralQueryModel.scala index 298b99e90..0525a95eb 100644 --- a/core/src/main/scala/org/apache/spark/sql/crossdata/models/EphemeralQueryModel.scala +++ b/core/src/main/scala/org/apache/spark/sql/crossdata/models/EphemeralQueryModel.scala @@ -17,9 +17,9 @@ package org.apache.spark.sql.crossdata.models case class EphemeralQueryModel(ephemeralTableName: String, sql: String, - alias : String, + alias: String, window: Int = EphemeralOptionsModel.DefaultAtomicWindow, options: Map[String, String] = Map.empty) { - def toPrettyString : String = ModelUtils.modelToJsonString(this) -} \ No newline at end of file + def toPrettyString: String = ModelUtils.modelToJsonString(this) +} diff --git a/core/src/main/scala/org/apache/spark/sql/crossdata/models/EphemeralStatusModel.scala b/core/src/main/scala/org/apache/spark/sql/crossdata/models/EphemeralStatusModel.scala index e6bb4c06b..a74703bbf 100644 --- a/core/src/main/scala/org/apache/spark/sql/crossdata/models/EphemeralStatusModel.scala +++ b/core/src/main/scala/org/apache/spark/sql/crossdata/models/EphemeralStatusModel.scala @@ -20,5 +20,5 @@ case class EphemeralStatusModel(ephemeralTableName: String, startedTime: Option[Long] = None, stoppedTime: Option[Long] = None) { - def toPrettyString : String = ModelUtils.modelToJsonString(this) + def toPrettyString: String = ModelUtils.modelToJsonString(this) } diff --git a/core/src/main/scala/org/apache/spark/sql/crossdata/models/EphemeralTableModel.scala b/core/src/main/scala/org/apache/spark/sql/crossdata/models/EphemeralTableModel.scala index db91798b7..db9b888cf 100644 --- a/core/src/main/scala/org/apache/spark/sql/crossdata/models/EphemeralTableModel.scala +++ b/core/src/main/scala/org/apache/spark/sql/crossdata/models/EphemeralTableModel.scala @@ -21,6 +21,7 @@ case class EphemeralTableModel(name: String, options: EphemeralOptionsModel, schema: Option[StructType] = None) { - def toPrettyString : String = ModelUtils.modelToJsonString(this).replaceAll("\\\\\"","\"") + def toPrettyString: String = + ModelUtils.modelToJsonString(this).replaceAll("\\\\\"", "\"") } diff --git a/core/src/main/scala/org/apache/spark/sql/crossdata/models/IndexModel.scala b/core/src/main/scala/org/apache/spark/sql/crossdata/models/IndexModel.scala index cf7d68e19..116203bb5 100644 --- a/core/src/main/scala/org/apache/spark/sql/crossdata/models/IndexModel.scala +++ b/core/src/main/scala/org/apache/spark/sql/crossdata/models/IndexModel.scala @@ -17,4 +17,4 @@ package org.apache.spark.sql.crossdata.models import org.apache.spark.sql.crossdata.catalog.XDCatalog.CrossdataIndex -case class IndexModel(indexId:String, crossdataIndex: CrossdataIndex) +case 
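// Illustrative sketch, not part of the patch: EphemeralTableModel.toPrettyString above
// post-processes the rendered JSON with replaceAll("\\\\\"", "\"") so that escaped quotes (\")
// become plain quotes. A tiny standalone demo of that regex (the sample string is invented):
object EscapedQuoteCleanupDemo extends App {
  val rendered = "{\"schema\":\"{\\\"type\\\":\\\"struct\\\"}\"}"
  val cleaned  = rendered.replaceAll("\\\\\"", "\"")

  println(rendered) // {"schema":"{\"type\":\"struct\"}"}
  println(cleaned)  // {"schema":"{"type":"struct"}"}
}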
class IndexModel(indexId: String, crossdataIndex: CrossdataIndex) diff --git a/core/src/main/scala/org/apache/spark/sql/crossdata/models/KafkaOptionsModel.scala b/core/src/main/scala/org/apache/spark/sql/crossdata/models/KafkaOptionsModel.scala index 4b0289d3b..16ff473e6 100644 --- a/core/src/main/scala/org/apache/spark/sql/crossdata/models/KafkaOptionsModel.scala +++ b/core/src/main/scala/org/apache/spark/sql/crossdata/models/KafkaOptionsModel.scala @@ -22,5 +22,5 @@ case class KafkaOptionsModel(connection: ConnectionHostModel, additionalOptions: Map[String, String] = Map.empty, storageLevel: String = "MEMORY_AND_DISK_SER") { - def toPrettyString : String = ModelUtils.modelToJsonString(this) -} \ No newline at end of file + def toPrettyString: String = ModelUtils.modelToJsonString(this) +} diff --git a/core/src/main/scala/org/apache/spark/sql/crossdata/models/ModelUtils.scala b/core/src/main/scala/org/apache/spark/sql/crossdata/models/ModelUtils.scala index 17b3bff6f..140dc9286 100644 --- a/core/src/main/scala/org/apache/spark/sql/crossdata/models/ModelUtils.scala +++ b/core/src/main/scala/org/apache/spark/sql/crossdata/models/ModelUtils.scala @@ -20,5 +20,5 @@ import org.json4s.jackson.Serialization._ object ModelUtils extends CrossdataSerializer { - def modelToJsonString[T <: AnyRef](model: T) : String = writePretty(model) + def modelToJsonString[T <: AnyRef](model: T): String = writePretty(model) } diff --git a/core/src/main/scala/org/apache/spark/sql/crossdata/models/TableModel.scala b/core/src/main/scala/org/apache/spark/sql/crossdata/models/TableModel.scala index 9c6c283ab..fda7565ff 100644 --- a/core/src/main/scala/org/apache/spark/sql/crossdata/models/TableModel.scala +++ b/core/src/main/scala/org/apache/spark/sql/crossdata/models/TableModel.scala @@ -27,7 +27,9 @@ case class TableModel(id: String, version: String = crossdata.CrossdataVersion) { def getExtendedName: String = - database.fold(name) { databaseName => s"$databaseName.$name" } + database.fold(name) { databaseName => + s"$databaseName.$name" + } - def toPrettyString : String = ModelUtils.modelToJsonString(this) + def toPrettyString: String = ModelUtils.modelToJsonString(this) } diff --git a/core/src/main/scala/org/apache/spark/sql/crossdata/models/TopicModel.scala b/core/src/main/scala/org/apache/spark/sql/crossdata/models/TopicModel.scala index da288b024..f748bd205 100644 --- a/core/src/main/scala/org/apache/spark/sql/crossdata/models/TopicModel.scala +++ b/core/src/main/scala/org/apache/spark/sql/crossdata/models/TopicModel.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.crossdata.models case class TopicModel(name: String, numPartitions: Int = TopicModel.DefaultNumPartitions) { - def toStringPretty : String = ModelUtils.modelToJsonString(this) + def toStringPretty: String = ModelUtils.modelToJsonString(this) } object TopicModel { diff --git a/core/src/main/scala/org/apache/spark/sql/crossdata/models/ViewModel.scala b/core/src/main/scala/org/apache/spark/sql/crossdata/models/ViewModel.scala index b64a46dc1..b7a53f27f 100644 --- a/core/src/main/scala/org/apache/spark/sql/crossdata/models/ViewModel.scala +++ b/core/src/main/scala/org/apache/spark/sql/crossdata/models/ViewModel.scala @@ -24,5 +24,7 @@ case class ViewModel(id: String, version: String = crossdata.CrossdataVersion) { def getExtendedName: String = - database.fold(name) { databaseName => s"$databaseName.$name" } + database.fold(name) { databaseName => + s"$databaseName.$name" + } } diff --git 
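// Illustrative sketch, not part of the patch: TableModel.getExtendedName and
// ViewModel.getExtendedName above qualify the table name only when a database is present,
// via Option.fold. The same pattern in isolation (the sample names are invented):
object ExtendedNameDemo extends App {
  def extendedName(name: String, database: Option[String]): String =
    database.fold(name) { databaseName =>
      s"$databaseName.$name"
    }

  println(extendedName("customers", None))          // customers
  println(extendedName("customers", Some("sales"))) // sales.customers
}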
a/core/src/main/scala/org/apache/spark/sql/crossdata/package.scala b/core/src/main/scala/org/apache/spark/sql/crossdata/package.scala index d81b599e9..d2df5d2c0 100644 --- a/core/src/main/scala/org/apache/spark/sql/crossdata/package.scala +++ b/core/src/main/scala/org/apache/spark/sql/crossdata/package.scala @@ -19,6 +19,7 @@ import scala.io.Source package object crossdata { - lazy val CrossdataVersion = Source.fromInputStream(getClass.getResourceAsStream("/crossdata.version")).mkString + lazy val CrossdataVersion = + Source.fromInputStream(getClass.getResourceAsStream("/crossdata.version")).mkString } diff --git a/core/src/main/scala/org/apache/spark/sql/crossdata/security/Credentials.scala b/core/src/main/scala/org/apache/spark/sql/crossdata/security/Credentials.scala index 6c8d32d46..5dc529faa 100644 --- a/core/src/main/scala/org/apache/spark/sql/crossdata/security/Credentials.scala +++ b/core/src/main/scala/org/apache/spark/sql/crossdata/security/Credentials.scala @@ -15,4 +15,6 @@ */ package org.apache.spark.sql.crossdata.security -case class Credentials(user: Option[String] = None, password: Option[String] = None, sessionId: Option[String] = None) +case class Credentials(user: Option[String] = None, + password: Option[String] = None, + sessionId: Option[String] = None) diff --git a/core/src/main/scala/org/apache/spark/sql/crossdata/security/DefaultSecurityManager.scala b/core/src/main/scala/org/apache/spark/sql/crossdata/security/DefaultSecurityManager.scala index eb33f10c4..0ff9e7957 100644 --- a/core/src/main/scala/org/apache/spark/sql/crossdata/security/DefaultSecurityManager.scala +++ b/core/src/main/scala/org/apache/spark/sql/crossdata/security/DefaultSecurityManager.scala @@ -15,11 +15,16 @@ */ package org.apache.spark.sql.crossdata.security -class DefaultSecurityManager(credentials: Credentials, audit: Boolean) extends SecurityManager(credentials, audit) { +class DefaultSecurityManager(credentials: Credentials, audit: Boolean) + extends SecurityManager(credentials, audit) { override def authorize(resource: Any): AuthorizationReply = { - val info = s"${credentials.user.fold(""){u => s"User '$u', "}}${credentials.sessionId.fold(""){s => s"SessionId, '$s'}"}}Access to: '$resource'" - if(audit) logInfo(info) + val info = s"${credentials.user.fold("") { u => + s"User '$u', " + }}${credentials.sessionId.fold("") { s => + s"SessionId, '$s'}" + }}Access to: '$resource'" + if (audit) logInfo(info) new AuthorizationReply(true, Some(info)) } diff --git a/core/src/main/scala/org/apache/spark/sql/crossdata/security/SecurityManager.scala b/core/src/main/scala/org/apache/spark/sql/crossdata/security/SecurityManager.scala index bcf08c01c..d44fc6062 100644 --- a/core/src/main/scala/org/apache/spark/sql/crossdata/security/SecurityManager.scala +++ b/core/src/main/scala/org/apache/spark/sql/crossdata/security/SecurityManager.scala @@ -17,12 +17,8 @@ package org.apache.spark.sql.crossdata.security import org.apache.spark.Logging - - abstract class SecurityManager(val credentials: Credentials, val audit: Boolean) extends Logging { def authorize(resource: Any): AuthorizationReply } - - diff --git a/core/src/main/scala/org/apache/spark/sql/crossdata/serializers/CrossdataSerializer.scala b/core/src/main/scala/org/apache/spark/sql/crossdata/serializers/CrossdataSerializer.scala index 2e9a2259e..8125f3232 100644 --- a/core/src/main/scala/org/apache/spark/sql/crossdata/serializers/CrossdataSerializer.scala +++ b/core/src/main/scala/org/apache/spark/sql/crossdata/serializers/CrossdataSerializer.scala @@ 
-19,7 +19,6 @@ import org.apache.spark.sql.crossdata.models.{EphemeralExecutionStatus, Ephemera import org.json4s._ import org.json4s.ext.EnumNameSerializer - trait CrossdataSerializer { implicit val json4sJacksonFormats: Formats = @@ -29,4 +28,3 @@ trait CrossdataSerializer { new EnumNameSerializer(EphemeralOutputFormat) } - diff --git a/core/src/main/scala/org/apache/spark/sql/crossdata/serializers/StructTypeSerializer.scala b/core/src/main/scala/org/apache/spark/sql/crossdata/serializers/StructTypeSerializer.scala index ee5db4066..edb4c8fa6 100644 --- a/core/src/main/scala/org/apache/spark/sql/crossdata/serializers/StructTypeSerializer.scala +++ b/core/src/main/scala/org/apache/spark/sql/crossdata/serializers/StructTypeSerializer.scala @@ -20,17 +20,19 @@ import org.json4s.JsonAST.{JString, JField, JObject} import org.json4s._ import org.json4s.reflect.TypeInfo -object StructTypeSerializer extends Serializer[StructType]{ +object StructTypeSerializer extends Serializer[StructType] { private val StructTypeClass = classOf[StructType] private val StructTypeId = "structType" def deserialize(implicit format: Formats): PartialFunction[(TypeInfo, JValue), StructType] = { - case (TypeInfo(StructTypeClass, _), json) => json match { - case JObject(JField(StructTypeId, JString(jsonString)) :: _) => - StructType.fromString(jsonString) - case x => throw new MappingException("Can't convert " + x + " to StructType") - } + case (TypeInfo(StructTypeClass, _), json) => + json match { + case JObject(JField(StructTypeId, JString(jsonString)) :: _) => + StructType.fromString(jsonString) + case x => + throw new MappingException("Can't convert " + x + " to StructType") + } } def serialize(implicit formats: Formats): PartialFunction[Any, JValue] = { @@ -38,4 +40,4 @@ object StructTypeSerializer extends Serializer[StructType]{ import JsonDSL._ StructTypeId -> x.json } -} \ No newline at end of file +} diff --git a/core/src/main/scala/org/apache/spark/sql/crossdata/session/XDSessionProvider.scala b/core/src/main/scala/org/apache/spark/sql/crossdata/session/XDSessionProvider.scala index 5ec8130a3..568de195c 100644 --- a/core/src/main/scala/org/apache/spark/sql/crossdata/session/XDSessionProvider.scala +++ b/core/src/main/scala/org/apache/spark/sql/crossdata/session/XDSessionProvider.scala @@ -38,13 +38,12 @@ object XDSessionProvider { // TODO It should share some of the XDContext fields. It will be possible when Spark 2.0 is released // TODO sessionProvider should be threadSafe abstract class XDSessionProvider( - @transient val sc: SparkContext, - protected val commonConfig: Option[Config] = None - ) { + @transient val sc: SparkContext, + protected val commonConfig: Option[Config] = None +) { import XDSessionProvider._ - //NOTE: DO NEVER KEEP THE RETURNED REFERENCE FOR SEVERAL USES! 
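// Illustrative sketch, not part of the patch: StructTypeSerializer above implements json4s'
// Serializer[T] by pattern matching on (TypeInfo, JValue) for reads and on the value for
// writes. A smaller serializer built the same way; it assumes json4s-jackson is on the
// classpath (as it is for this module), and the Version type and "version" field are invented:
import org.json4s._
import org.json4s.jackson.Serialization
import org.json4s.reflect.TypeInfo

final case class Version(value: String)

object VersionSerializer extends Serializer[Version] {
  private val VersionClass = classOf[Version]
  private val FieldId = "version"

  def deserialize(implicit format: Formats): PartialFunction[(TypeInfo, JValue), Version] = {
    case (TypeInfo(VersionClass, _), json) =>
      json match {
        case JObject(JField(FieldId, JString(v)) :: _) => Version(v)
        case x => throw new MappingException("Can't convert " + x + " to Version")
      }
  }

  def serialize(implicit format: Formats): PartialFunction[Any, JValue] = {
    case Version(v) =>
      import JsonDSL._
      FieldId -> v
  }
}

object VersionSerializerDemo extends App {
  implicit val formats: Formats = DefaultFormats + VersionSerializer
  val json = Serialization.write(Version("1.4.0"))
  println(json)                              // {"version":"1.4.0"}
  println(Serialization.read[Version](json)) // Version(1.4.0)
}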
def session(sessionID: SessionID): Try[XDSession] @@ -66,29 +65,32 @@ abstract class XDSessionProvider( * Session provider which store session info locally, so it can't be used when deploying several crossdata server */ class BasicSessionProvider( - @transient override val sc: SparkContext, - userConfig: Config - ) extends XDSessionProvider(sc, Option(userConfig)) with CoreConfig { + @transient override val sc: SparkContext, + userConfig: Config +) extends XDSessionProvider(sc, Option(userConfig)) + with CoreConfig { import XDSessionProvider._ override lazy val logger = Logger.getLogger(classOf[BasicSessionProvider]) - private lazy val catalogConfig = Try(config.getConfig(CoreConfig.CatalogConfigKey)).getOrElse(ConfigFactory.empty()) + private lazy val catalogConfig = + Try(config.getConfig(CoreConfig.CatalogConfigKey)).getOrElse(ConfigFactory.empty()) private lazy val sqlConf: SQLConf = configToSparkSQL(userConfig, new SQLConf) - @transient - protected lazy val externalCatalog: XDPersistentCatalog = CatalogUtils.externalCatalog(sqlConf, catalogConfig) + protected lazy val externalCatalog: XDPersistentCatalog = + CatalogUtils.externalCatalog(sqlConf, catalogConfig) @transient - protected lazy val streamingCatalog: Option[XDStreamingCatalog] = CatalogUtils.streamingCatalog(sqlConf, config) + protected lazy val streamingCatalog: Option[XDStreamingCatalog] = + CatalogUtils.streamingCatalog(sqlConf, config) private val sharedState = new XDSharedState(sc, sqlConf, externalCatalog, streamingCatalog) private val sessionIDToSQLProps: mutable.Map[SessionID, SQLConf] = mutable.Map.empty - private val sessionIDToTempCatalog: mutable.Map[SessionID, XDTemporaryCatalog] = mutable.Map.empty - + private val sessionIDToTempCatalog: mutable.Map[SessionID, XDTemporaryCatalog] = + mutable.Map.empty private val errorMessage = "A distributed context must be used to manage XDServer sessions. 
Please, use SparkSessions instead" @@ -103,17 +105,16 @@ class BasicSessionProvider( buildSession(sharedState.sqlConf, tempCatalog) } - override def closeSession(sessionID: SessionID): Try[Unit] = - { - for { - _ <- sessionIDToSQLProps.remove(sessionID) - _ <- sessionIDToTempCatalog.remove(sessionID) - } yield () - } map { - Success(_) - } getOrElse { - Failure(new RuntimeException(s"Cannot close session with sessionId=$sessionID")) - } + override def closeSession(sessionID: SessionID): Try[Unit] = { + for { + _ <- sessionIDToSQLProps.remove(sessionID) + _ <- sessionIDToTempCatalog.remove(sessionID) + } yield () + } map { + Success(_) + } getOrElse { + Failure(new RuntimeException(s"Cannot close session with sessionId=$sessionID")) + } override def session(sessionID: SessionID): Try[XDSession] = { for { diff --git a/core/src/main/scala/org/apache/spark/sql/crossdata/session/XDSharedState.scala b/core/src/main/scala/org/apache/spark/sql/crossdata/session/XDSharedState.scala index cb03174cd..41e3e97a7 100644 --- a/core/src/main/scala/org/apache/spark/sql/crossdata/session/XDSharedState.scala +++ b/core/src/main/scala/org/apache/spark/sql/crossdata/session/XDSharedState.scala @@ -23,10 +23,9 @@ import org.apache.spark.sql.crossdata.config.CoreConfig import scala.collection.JavaConversions._ - final class XDSharedState( - @transient val sc: SparkContext, - val sqlConf: SQLConf, - val externalCatalog: XDCatalogCommon, - val streamingCatalog: Option[XDStreamingCatalog] - ) + @transient val sc: SparkContext, + val sqlConf: SQLConf, + val externalCatalog: XDCatalogCommon, + val streamingCatalog: Option[XDStreamingCatalog] +) diff --git a/core/src/main/scala/org/apache/spark/sql/crossdata/user/functions/GroupConcat.scala b/core/src/main/scala/org/apache/spark/sql/crossdata/user/functions/GroupConcat.scala index 545b6f70c..be4f9786f 100644 --- a/core/src/main/scala/org/apache/spark/sql/crossdata/user/functions/GroupConcat.scala +++ b/core/src/main/scala/org/apache/spark/sql/crossdata/user/functions/GroupConcat.scala @@ -24,19 +24,20 @@ import org.apache.spark.sql.types.StringType import org.apache.spark.sql.types.DataType class GroupConcat(val separator: String) extends UserDefinedAggregateFunction { - def inputSchema: StructType = StructType(StructField("value", StringType) :: Nil) + def inputSchema: StructType = + StructType(StructField("value", StringType) :: Nil) - def bufferSchema: StructType = StructType(StructField("total", StringType) :: Nil) + def bufferSchema: StructType = + StructType(StructField("total", StringType) :: Nil) def dataType: DataType = StringType def deterministic: Boolean = true - def initialize(buffer: MutableAggregationBuffer): Unit = { - } + def initialize(buffer: MutableAggregationBuffer): Unit = {} def update(buffer: MutableAggregationBuffer, input: Row): Unit = { - if(buffer.isNullAt(0)){ + if (buffer.isNullAt(0)) { buffer(0) = input.getAs[String](0) } else { buffer(0) = buffer.getAs[String](0).concat(separator).concat(input.getAs[String](0)) @@ -44,7 +45,7 @@ class GroupConcat(val separator: String) extends UserDefinedAggregateFunction { } def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = { - if(buffer1.isNullAt(0)){ + if (buffer1.isNullAt(0)) { buffer1(0) = buffer2.getAs[String](0) } else { buffer1(0) = buffer1.getAs[String](0).concat(separator).concat(buffer2.getAs[String](0)) diff --git a/core/src/main/scala/org/apache/spark/sql/crossdata/util/CreateRelationUtil.scala 
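// Illustrative sketch, not part of the patch: closeSession above removes the session entry
// from two mutable maps inside a for-comprehension over Option and lifts the outcome into
// Try (Success only if both removals found something). The same shape in isolation (the map
// contents, value types and session id are invented):
import scala.collection.mutable
import scala.util.{Failure, Success, Try}

object CloseSessionDemo extends App {
  type SessionID = String
  val sessionIDToSQLProps: mutable.Map[SessionID, Int]       = mutable.Map("s1" -> 1)
  val sessionIDToTempCatalog: mutable.Map[SessionID, String] = mutable.Map("s1" -> "catalog")

  def closeSession(sessionID: SessionID): Try[Unit] = {
    for {
      _ <- sessionIDToSQLProps.remove(sessionID)    // Option: None if the session was unknown
      _ <- sessionIDToTempCatalog.remove(sessionID)
    } yield ()
  } map {
    Success(_)
  } getOrElse {
    Failure(new RuntimeException(s"Cannot close session with sessionId=$sessionID"))
  }

  println(closeSession("s1")) // Success(())
  println(closeSession("s1")) // Failure(java.lang.RuntimeException: Cannot close session with sessionId=s1)
}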
b/core/src/main/scala/org/apache/spark/sql/crossdata/util/CreateRelationUtil.scala index 9c2d091f4..f66c52b75 100644 --- a/core/src/main/scala/org/apache/spark/sql/crossdata/util/CreateRelationUtil.scala +++ b/core/src/main/scala/org/apache/spark/sql/crossdata/util/CreateRelationUtil.scala @@ -22,25 +22,31 @@ import XDCatalog.CrossdataTable import org.apache.spark.sql.execution.datasources.{LogicalRelation, ResolvedDataSource} import org.apache.spark.sql.sources.{HadoopFsRelationProvider, RelationProvider, SchemaRelationProvider} -object CreateRelationUtil extends SparkLoggerComponent{ +object CreateRelationUtil extends SparkLoggerComponent { - protected[crossdata] def createLogicalRelation(sqlContext: SQLContext, crossdataTable: CrossdataTable): LogicalRelation = { + protected[crossdata] def createLogicalRelation( + sqlContext: SQLContext, + crossdataTable: CrossdataTable): LogicalRelation = { /** Although table schema is inferred and persisted in XDCatalog, the schema can't be specified in some cases because *the source does not implement SchemaRelationProvider (e.g. JDBC) */ + val tableSchema = + ResolvedDataSource.lookupDataSource(crossdataTable.datasource).newInstance() match { + case _: SchemaRelationProvider | _: HadoopFsRelationProvider => + crossdataTable.schema + case _: RelationProvider => + None + case other => + val msg = s"Unexpected datasource: $other" + logError(msg) + throw new RuntimeException(msg) + } - val tableSchema = ResolvedDataSource.lookupDataSource(crossdataTable.datasource).newInstance() match { - case _: SchemaRelationProvider | _: HadoopFsRelationProvider => - crossdataTable.schema - case _: RelationProvider => - None - case other => - val msg = s"Unexpected datasource: $other" - logError(msg) - throw new RuntimeException(msg) - } - - val resolved = ResolvedDataSource(sqlContext, tableSchema, crossdataTable.partitionColumn, crossdataTable.datasource, crossdataTable.opts) + val resolved = ResolvedDataSource(sqlContext, + tableSchema, + crossdataTable.partitionColumn, + crossdataTable.datasource, + crossdataTable.opts) LogicalRelation(resolved.relation) } diff --git a/core/src/main/scala/org/apache/spark/sql/sources/CatalystToCrossdataAdapter.scala b/core/src/main/scala/org/apache/spark/sql/sources/CatalystToCrossdataAdapter.scala index cc168abc1..9330c7fc9 100644 --- a/core/src/main/scala/org/apache/spark/sql/sources/CatalystToCrossdataAdapter.scala +++ b/core/src/main/scala/org/apache/spark/sql/sources/CatalystToCrossdataAdapter.scala @@ -15,7 +15,6 @@ */ package org.apache.spark.sql.sources - import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.catalyst.CatalystTypeConverters.convertToScala import org.apache.spark.sql.catalyst.expressions @@ -34,47 +33,53 @@ import scala.collection.mutable.ListBuffer object CatalystToCrossdataAdapter { - case class FilterReport(filtersIgnored: Seq[Expression], ignoredNativeUDFReferences: Seq[AttributeReference]) + case class FilterReport(filtersIgnored: Seq[Expression], + ignoredNativeUDFReferences: Seq[AttributeReference]) case class ProjectReport(expressionsIgnored: Seq[Expression]) abstract class BaseLogicalPlan( - val projects: Seq[NamedExpression], - val filters: Array[SourceFilter], - val udfsMap: Map[Attribute, NativeUDF], - val collectionRandomAccesses: Map[Attribute, GetArrayItem] - ) + val projects: Seq[NamedExpression], + val filters: Array[SourceFilter], + val udfsMap: Map[Attribute, NativeUDF], + val collectionRandomAccesses: Map[Attribute, GetArrayItem] + ) case class 
SimpleLogicalPlan(override val projects: Seq[Attribute], override val filters: Array[SourceFilter], override val udfsMap: Map[Attribute, NativeUDF], - override val collectionRandomAccesses: Map[Attribute, GetArrayItem] - ) extends BaseLogicalPlan(projects, filters, udfsMap, collectionRandomAccesses) - - case class AggregationLogicalPlan(override val projects: Seq[NamedExpression], - groupingExpresion: Seq[Expression], - override val filters: Array[SourceFilter], - override val udfsMap: Map[Attribute, NativeUDF], - override val collectionRandomAccesses: Map[Attribute, GetArrayItem] - ) extends BaseLogicalPlan(projects, filters, udfsMap, collectionRandomAccesses) + override val collectionRandomAccesses: Map[Attribute, GetArrayItem]) + extends BaseLogicalPlan(projects, filters, udfsMap, collectionRandomAccesses) + case class AggregationLogicalPlan( + override val projects: Seq[NamedExpression], + groupingExpresion: Seq[Expression], + override val filters: Array[SourceFilter], + override val udfsMap: Map[Attribute, NativeUDF], + override val collectionRandomAccesses: Map[Attribute, GetArrayItem]) + extends BaseLogicalPlan(projects, filters, udfsMap, collectionRandomAccesses) /** - * Transforms a Catalyst Logical Plan to a Crossdata Logical Plan - * @param logicalPlan catalyst logical plan - * @return A tuple of (Crossdata BaseLogicalPlan, FilterReport) - */ - def getConnectorLogicalPlan(logicalPlan: LogicalPlan, - projects: Seq[NamedExpression], - filterPredicates: Seq[Expression]): (BaseLogicalPlan, ProjectReport, FilterReport) = { + * Transforms a Catalyst Logical Plan to a Crossdata Logical Plan + * @param logicalPlan catalyst logical plan + * @return A tuple of (Crossdata BaseLogicalPlan, FilterReport) + */ + def getConnectorLogicalPlan( + logicalPlan: LogicalPlan, + projects: Seq[NamedExpression], + filterPredicates: Seq[Expression]): (BaseLogicalPlan, ProjectReport, FilterReport) = { val relation = logicalPlan.collectFirst { case lr: LogicalRelation => lr }.get - implicit val att2udf = logicalPlan.collect { case EvaluateNativeUDF(udf, child, att) => att -> udf } toMap - implicit val att2itemAccess: Map[Attribute, GetArrayItem] =(projects ++ filterPredicates).flatMap { c => - c.collect { - case gi @ GetArrayItem(a@AttributeReference(name, ArrayType(etype, _), nullable, md), _) => - AttributeReference(name, etype, true)() -> gi - } + implicit val att2udf = logicalPlan.collect { + case EvaluateNativeUDF(udf, child, att) => att -> udf } toMap + implicit val att2itemAccess: Map[Attribute, GetArrayItem] = + (projects ++ filterPredicates).flatMap { c => + c.collect { + case gi @ GetArrayItem(a @ AttributeReference(name, ArrayType(etype, _), nullable, md), + _) => + AttributeReference(name, etype, true)() -> gi + } + } toMap val itemAccess2att: Map[GetArrayItem, Attribute] = att2itemAccess.map(_.swap) @@ -85,32 +90,35 @@ object CatalystToCrossdataAdapter { import ExpressionType._ - def extractRequestedColumns(namedExpression: Expression): Seq[(ExpressionType, Expression)] = namedExpression match { + def extractRequestedColumns(namedExpression: Expression): Seq[(ExpressionType, Expression)] = + namedExpression match { - case Alias(child, _) => - extractRequestedColumns(child) + case Alias(child, _) => + extractRequestedColumns(child) - case aRef: AttributeReference => - Seq(Requested -> aRef) + case aRef: AttributeReference => + Seq(Requested -> aRef) - case nudf: NativeUDF => - nudf.references flatMap { - case nat: AttributeReference if att2udf contains nat => - udfFlattenedActualParameters(nat, 
at => Found -> relation.attributeMap(at)) :+ (Requested -> nat) - } toSeq + case nudf: NativeUDF => + nudf.references flatMap { + case nat: AttributeReference if att2udf contains nat => + udfFlattenedActualParameters(nat, at => Found -> relation.attributeMap(at)) :+ (Requested -> nat) + } toSeq - case c: GetArrayItem if itemAccess2att contains c => - c.references.map(Found -> relation.attributeMap(_)).toSeq :+ (Requested -> itemAccess2att(c)) + case c: GetArrayItem if itemAccess2att contains c => + c.references + .map(Found -> relation.attributeMap(_)) + .toSeq :+ (Requested -> itemAccess2att(c)) - // TODO should these expressions be ignored? We are omitting expressions within structfields - case c: GetStructField => c.references flatMap { - case x => Seq(Requested -> relation.attributeMap(x)) - } toSeq - - case ignoredExpr => - Seq(Ignored -> ignoredExpr) - } + // TODO should these expressions be ignored? We are omitting expressions within structfields + case c: GetStructField => + c.references flatMap { + case x => Seq(Requested -> relation.attributeMap(x)) + } toSeq + case ignoredExpr => + Seq(Ignored -> ignoredExpr) + } val columnExpressions: Map[ExpressionType, Seq[Expression]] = projects.flatMap { extractRequestedColumns @@ -118,171 +126,201 @@ object CatalystToCrossdataAdapter { val pushedFilters = filterPredicates.map { _ transform { - case getitem: GetArrayItem if itemAccess2att contains getitem => itemAccess2att(getitem) + case getitem: GetArrayItem if itemAccess2att contains getitem => + itemAccess2att(getitem) case a: AttributeReference if att2udf contains a => a - case a: Attribute => relation.attributeMap(a) // Match original case of attributes. + case a: Attribute => + relation.attributeMap(a) // Match original case of attributes. } } val (filters, filterReport) = selectFilters(pushedFilters, att2udf.keySet, att2itemAccess) val aggregatePlan: Option[(Seq[Expression], Seq[NamedExpression])] = logicalPlan.collectFirst { - case Aggregate(groupingExpression, aggregationExpression, child) => (groupingExpression, aggregationExpression) + case Aggregate(groupingExpression, aggregationExpression, child) => + (groupingExpression, aggregationExpression) } - val baseLogicalPlan = aggregatePlan.fold[BaseLogicalPlan] { val requestedColumns: Seq[Attribute] = - columnExpressions.getOrElse(Requested, Seq.empty) collect { case a: Attribute => a } + columnExpressions.getOrElse(Requested, Seq.empty) collect { + case a: Attribute => a + } SimpleLogicalPlan(requestedColumns, filters.toArray, att2udf, att2itemAccess) - } { case (groupingExpression, selectExpression) => - AggregationLogicalPlan(selectExpression, groupingExpression, filters, att2udf, att2itemAccess) + } { + case (groupingExpression, selectExpression) => + AggregationLogicalPlan(selectExpression, + groupingExpression, + filters, + att2udf, + att2itemAccess) } val projectReport = columnExpressions.getOrElse(Ignored, Seq.empty) (baseLogicalPlan, ProjectReport(projectReport), filterReport) } def udfFlattenedActualParameters[B]( - udfAttr: AttributeReference, - f: Attribute => B - )(implicit udfs: Map[Attribute, NativeUDF]): Seq[B] = { - udfs(udfAttr).children.flatMap { case att: AttributeReference => - if(udfs contains att) udfFlattenedActualParameters(att, f) else Seq(f(att)) + udfAttr: AttributeReference, + f: Attribute => B + )(implicit udfs: Map[Attribute, NativeUDF]): Seq[B] = { + udfs(udfAttr).children.flatMap { + case att: AttributeReference => + if (udfs contains att) udfFlattenedActualParameters(att, f) + else Seq(f(att)) } } 
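// Illustrative sketch, not part of the patch: udfFlattenedActualParameters above walks the
// native-UDF graph and flattens the leaf attributes a UDF ultimately depends on, recursing
// whenever a child attribute is itself produced by another UDF. A dependency-free analogue
// over a toy call graph (all names below are invented):
object FlattenUdfParamsDemo extends App {
  type Attr = String

  // Maps the attribute produced by a UDF to the attributes that UDF consumes, playing the
  // role of the implicit Map[Attribute, NativeUDF] used in the adapter.
  def flattenedParameters[B](udfAttr: Attr, f: Attr => B)(
      implicit udfs: Map[Attr, Seq[Attr]]): Seq[B] =
    udfs(udfAttr).flatMap { att =>
      if (udfs contains att) flattenedParameters(att, f) else Seq(f(att))
    }

  implicit val callGraph: Map[Attr, Seq[Attr]] = Map(
    "outerUdf" -> Seq("innerUdf", "colA"),
    "innerUdf" -> Seq("colB", "colC")
  )

  println(flattenedParameters("outerUdf", (a: Attr) => a)) // List(colB, colC, colA)
}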
/** - * Selects Catalyst predicate [[Expression]]s which are convertible into data source [[Filter]]s, - * and convert them. - * - * @param filters catalyst filters - * @return filters which are convertible and a boolean indicating whether any filter has been ignored. - */ + * Selects Catalyst predicate [[Expression]]s which are convertible into data source [[Filter]]s, + * and convert them. + * + * @param filters catalyst filters + * @return filters which are convertible and a boolean indicating whether any filter has been ignored. + */ private[this] def selectFilters( - filters: Seq[Expression], - udfs: Set[Attribute], - att2arrayaccess: Map[Attribute, GetArrayItem] - ): (Array[SourceFilter], FilterReport) = { + filters: Seq[Expression], + udfs: Set[Attribute], + att2arrayaccess: Map[Attribute, GetArrayItem] + ): (Array[SourceFilter], FilterReport) = { val ignoredExpressions: ListBuffer[Expression] = ListBuffer.empty val ignoredNativeUDFReferences: ListBuffer[AttributeReference] = ListBuffer.empty - def attAsOperand(att: Attribute): String = att2arrayaccess.get(att).map { - case GetArrayItem(child, ordinal) => s"${att.name}[${ordinal.toString()}]" - } getOrElse(att.name) - - def translate(predicate: Expression): Option[SourceFilter] = predicate match { - case expressions.EqualTo(a: Attribute, Literal(v, t)) => - Some(sources.EqualTo(attAsOperand(a), convertToScala(v, t))) - case expressions.EqualTo(Literal(v, t), a: Attribute) => - Some(sources.EqualTo(attAsOperand(a), convertToScala(v, t))) - case expressions.EqualTo(a: AttributeReference, b: Attribute) if udfs contains a => - Some(sources.EqualTo(attAsOperand(b), a)) - case expressions.EqualTo(b: Attribute, a: AttributeReference) if udfs contains a => - Some(sources.EqualTo(attAsOperand(b), a)) - case expressions.EqualTo(Cast(a:Attribute, StringType), Literal(v, t)) => - Some(sources.EqualTo(attAsOperand(a), convertToScala(Cast(Literal(v.toString), a.dataType).eval(EmptyRow), a - .dataType))) - - /* TODO + def attAsOperand(att: Attribute): String = + att2arrayaccess.get(att).map { + case GetArrayItem(child, ordinal) => + s"${att.name}[${ordinal.toString()}]" + } getOrElse (att.name) + + def translate(predicate: Expression): Option[SourceFilter] = + predicate match { + case expressions.EqualTo(a: Attribute, Literal(v, t)) => + Some(sources.EqualTo(attAsOperand(a), convertToScala(v, t))) + case expressions.EqualTo(Literal(v, t), a: Attribute) => + Some(sources.EqualTo(attAsOperand(a), convertToScala(v, t))) + case expressions.EqualTo(a: AttributeReference, b: Attribute) if udfs contains a => + Some(sources.EqualTo(attAsOperand(b), a)) + case expressions.EqualTo(b: Attribute, a: AttributeReference) if udfs contains a => + Some(sources.EqualTo(attAsOperand(b), a)) + case expressions.EqualTo(Cast(a: Attribute, StringType), Literal(v, t)) => + Some( + sources.EqualTo(attAsOperand(a), + convertToScala(Cast(Literal(v.toString), a.dataType).eval(EmptyRow), + a.dataType))) + + /* TODO case expressions.EqualNullSafe(a: Attribute, Literal(v, t)) => Some(sources.EqualNullSafe(a.name, convertToScala(v, t))) case expressions.EqualNullSafe(Literal(v, t), a: Attribute) => Some(sources.EqualNullSafe(a.name, convertToScala(v, t))) - */ - - case expressions.GreaterThan(a: Attribute, Literal(v, t)) => - Some(sources.GreaterThan(attAsOperand(a), convertToScala(v, t))) - case expressions.GreaterThan(Literal(v, t), a: Attribute) => - Some(sources.LessThan(attAsOperand(a), convertToScala(v, t))) - case expressions.GreaterThan(b: Attribute, a: 
AttributeReference) if udfs contains a => - Some(sources.GreaterThan(attAsOperand(b), a)) - case expressions.GreaterThan(a: AttributeReference, b: Attribute) if udfs contains a => - Some(sources.LessThan(attAsOperand(b), a)) - case expressions.GreaterThan(Cast(a:Attribute, StringType), Literal(v, t)) => - Some(sources.GreaterThan(attAsOperand(a), - convertToScala(Cast(Literal(v.toString), a.dataType).eval(EmptyRow), a.dataType))) - - case expressions.LessThan(a: Attribute, Literal(v, t)) => - Some(sources.LessThan(attAsOperand(a), convertToScala(v, t))) - case expressions.LessThan(Literal(v, t), a: Attribute) => - Some(sources.GreaterThan(attAsOperand(a), convertToScala(v, t))) - case expressions.LessThan(b: Attribute, a: AttributeReference) if udfs contains a => - Some(sources.LessThan(attAsOperand(b), a)) - case expressions.LessThan(a: AttributeReference, b: Attribute) if udfs contains a => - Some(sources.GreaterThan(attAsOperand(b), a)) - case expressions.LessThan(Cast(a:Attribute, StringType), Literal(v, t)) => - Some(sources.LessThan(attAsOperand(a), - convertToScala(Cast(Literal(v.toString), a.dataType).eval(EmptyRow), a.dataType))) - - case expressions.GreaterThanOrEqual(a: Attribute, Literal(v, t)) => - Some(sources.GreaterThanOrEqual(attAsOperand(a), convertToScala(v, t))) - case expressions.GreaterThanOrEqual(Literal(v, t), a: Attribute) => - Some(sources.LessThanOrEqual(attAsOperand(a), convertToScala(v, t))) - case expressions.GreaterThanOrEqual(b: Attribute, a: AttributeReference) if udfs contains a => - Some(sources.GreaterThanOrEqual(attAsOperand(b), a)) - case expressions.GreaterThanOrEqual(a: AttributeReference, b: Attribute) if udfs contains a => - Some(sources.LessThanOrEqual(attAsOperand(b), a)) - case expressions.GreaterThanOrEqual(Cast(a:Attribute,StringType), Literal(v, t)) => - Some(sources.GreaterThanOrEqual(attAsOperand(a), - convertToScala(Cast(Literal(v.toString), a.dataType).eval(EmptyRow), a.dataType))) - - case expressions.LessThanOrEqual(a: Attribute, Literal(v, t)) => - Some(sources.LessThanOrEqual(attAsOperand(a), convertToScala(v, t))) - case expressions.LessThanOrEqual(Literal(v, t), a: Attribute) => - Some(sources.GreaterThanOrEqual(attAsOperand(a), convertToScala(v, t))) - case expressions.LessThanOrEqual(b: Attribute, a: AttributeReference) if udfs contains a => - Some(sources.LessThanOrEqual(attAsOperand(b), a)) - case expressions.LessThanOrEqual(a: AttributeReference, b: Attribute) if udfs contains a => - Some(sources.GreaterThanOrEqual(attAsOperand(b), a)) - case expressions.LessThanOrEqual(Cast(a:Attribute,StringType), Literal(v, t)) => - Some(sources.LessThanOrEqual(attAsOperand(a), - convertToScala(Cast(Literal(v.toString), a.dataType).eval(EmptyRow), a.dataType))) - - - case expressions.InSet(a: Attribute, set) => - val toScala = CatalystTypeConverters.createToScalaConverter(a.dataType) - Some(sources.In(attAsOperand(a), set.toArray.map(toScala))) - - // Because we only convert In to InSet in Optimizer when there are more than certain - // items. So it is possible we still get an In expression here that needs to be pushed - // down. 
- case expressions.In(a: Attribute, list) if !list.exists(!_.isInstanceOf[Literal]) => - val hSet = list.map(e => e.eval(EmptyRow)) - val toScala = CatalystTypeConverters.createToScalaConverter(a.dataType) - Some(sources.In(attAsOperand(a), hSet.toArray.map(toScala))) - - case expressions.IsNull(a: Attribute) => - Some(sources.IsNull(attAsOperand(a))) - case expressions.IsNotNull(a: Attribute) => - Some(sources.IsNotNull(attAsOperand(a))) - - case expressions.And(left, right) => - (translate(left) ++ translate(right)).reduceOption(sources.And) - - case expressions.Or(left, right) => - for { - leftFilter <- translate(left) - rightFilter <- translate(right) - } yield sources.Or(leftFilter, rightFilter) - - case expressions.Not(child) => - translate(child).map(sources.Not) - - case expressions.StartsWith(a: Attribute, Literal(v: UTF8String, StringType)) => - Some(sources.StringStartsWith(attAsOperand(a), v.toString)) - - case expressions.EndsWith(a: Attribute, Literal(v: UTF8String, StringType)) => - Some(sources.StringEndsWith(attAsOperand(a), v.toString)) - - case expressions.Contains(a: Attribute, Literal(v: UTF8String, StringType)) => - Some(sources.StringContains(attAsOperand(a), v.toString)) - - case expression => - ignoredExpressions += expression - ignoredNativeUDFReferences ++= expression.collect{case att: AttributeReference if udfs contains att => att} - None + */ + + case expressions.GreaterThan(a: Attribute, Literal(v, t)) => + Some(sources.GreaterThan(attAsOperand(a), convertToScala(v, t))) + case expressions.GreaterThan(Literal(v, t), a: Attribute) => + Some(sources.LessThan(attAsOperand(a), convertToScala(v, t))) + case expressions.GreaterThan(b: Attribute, a: AttributeReference) if udfs contains a => + Some(sources.GreaterThan(attAsOperand(b), a)) + case expressions.GreaterThan(a: AttributeReference, b: Attribute) if udfs contains a => + Some(sources.LessThan(attAsOperand(b), a)) + case expressions.GreaterThan(Cast(a: Attribute, StringType), Literal(v, t)) => + Some( + sources.GreaterThan( + attAsOperand(a), + convertToScala(Cast(Literal(v.toString), a.dataType).eval(EmptyRow), + a.dataType))) + + case expressions.LessThan(a: Attribute, Literal(v, t)) => + Some(sources.LessThan(attAsOperand(a), convertToScala(v, t))) + case expressions.LessThan(Literal(v, t), a: Attribute) => + Some(sources.GreaterThan(attAsOperand(a), convertToScala(v, t))) + case expressions.LessThan(b: Attribute, a: AttributeReference) if udfs contains a => + Some(sources.LessThan(attAsOperand(b), a)) + case expressions.LessThan(a: AttributeReference, b: Attribute) if udfs contains a => + Some(sources.GreaterThan(attAsOperand(b), a)) + case expressions.LessThan(Cast(a: Attribute, StringType), Literal(v, t)) => + Some( + sources.LessThan(attAsOperand(a), + convertToScala(Cast(Literal(v.toString), a.dataType).eval(EmptyRow), + a.dataType))) + + case expressions.GreaterThanOrEqual(a: Attribute, Literal(v, t)) => + Some(sources.GreaterThanOrEqual(attAsOperand(a), convertToScala(v, t))) + case expressions.GreaterThanOrEqual(Literal(v, t), a: Attribute) => + Some(sources.LessThanOrEqual(attAsOperand(a), convertToScala(v, t))) + case expressions.GreaterThanOrEqual(b: Attribute, a: AttributeReference) + if udfs contains a => + Some(sources.GreaterThanOrEqual(attAsOperand(b), a)) + case expressions.GreaterThanOrEqual(a: AttributeReference, b: Attribute) + if udfs contains a => + Some(sources.LessThanOrEqual(attAsOperand(b), a)) + case expressions.GreaterThanOrEqual(Cast(a: Attribute, StringType), Literal(v, t)) => + 
Some( + sources.GreaterThanOrEqual( + attAsOperand(a), + convertToScala(Cast(Literal(v.toString), a.dataType).eval(EmptyRow), + a.dataType))) + + case expressions.LessThanOrEqual(a: Attribute, Literal(v, t)) => + Some(sources.LessThanOrEqual(attAsOperand(a), convertToScala(v, t))) + case expressions.LessThanOrEqual(Literal(v, t), a: Attribute) => + Some(sources.GreaterThanOrEqual(attAsOperand(a), convertToScala(v, t))) + case expressions.LessThanOrEqual(b: Attribute, a: AttributeReference) if udfs contains a => + Some(sources.LessThanOrEqual(attAsOperand(b), a)) + case expressions.LessThanOrEqual(a: AttributeReference, b: Attribute) if udfs contains a => + Some(sources.GreaterThanOrEqual(attAsOperand(b), a)) + case expressions.LessThanOrEqual(Cast(a: Attribute, StringType), Literal(v, t)) => + Some( + sources.LessThanOrEqual( + attAsOperand(a), + convertToScala(Cast(Literal(v.toString), a.dataType).eval(EmptyRow), + a.dataType))) + + case expressions.InSet(a: Attribute, set) => + val toScala = CatalystTypeConverters.createToScalaConverter(a.dataType) + Some(sources.In(attAsOperand(a), set.toArray.map(toScala))) + + // Because we only convert In to InSet in Optimizer when there are more than certain + // items. So it is possible we still get an In expression here that needs to be pushed + // down. + case expressions.In(a: Attribute, list) if !list.exists(!_.isInstanceOf[Literal]) => + val hSet = list.map(e => e.eval(EmptyRow)) + val toScala = CatalystTypeConverters.createToScalaConverter(a.dataType) + Some(sources.In(attAsOperand(a), hSet.toArray.map(toScala))) + + case expressions.IsNull(a: Attribute) => + Some(sources.IsNull(attAsOperand(a))) + case expressions.IsNotNull(a: Attribute) => + Some(sources.IsNotNull(attAsOperand(a))) + + case expressions.And(left, right) => + (translate(left) ++ translate(right)).reduceOption(sources.And) + + case expressions.Or(left, right) => + for { + leftFilter <- translate(left) + rightFilter <- translate(right) + } yield sources.Or(leftFilter, rightFilter) + + case expressions.Not(child) => + translate(child).map(sources.Not) + + case expressions.StartsWith(a: Attribute, Literal(v: UTF8String, StringType)) => + Some(sources.StringStartsWith(attAsOperand(a), v.toString)) + + case expressions.EndsWith(a: Attribute, Literal(v: UTF8String, StringType)) => + Some(sources.StringEndsWith(attAsOperand(a), v.toString)) + + case expressions.Contains(a: Attribute, Literal(v: UTF8String, StringType)) => + Some(sources.StringContains(attAsOperand(a), v.toString)) + + case expression => + ignoredExpressions += expression + ignoredNativeUDFReferences ++= expression.collect { + case att: AttributeReference if udfs contains att => att + } + None - } + } val convertibleFilters = filters.flatMap(translate).toArray val filterReport = FilterReport(ignoredExpressions, ignoredNativeUDFReferences) diff --git a/core/src/test/scala/com/stratio/crossdata/test/BaseXDTest.scala b/core/src/test/scala/com/stratio/crossdata/test/BaseXDTest.scala index ad0d7fe51..d3dd80c71 100644 --- a/core/src/test/scala/com/stratio/crossdata/test/BaseXDTest.scala +++ b/core/src/test/scala/com/stratio/crossdata/test/BaseXDTest.scala @@ -20,8 +20,8 @@ import org.scalatest.time.SpanSugar._ import org.scalatest.{FlatSpec, Matchers} /** - * Base class for both unit and integration tests - */ + * Base class for both unit and integration tests + */ trait BaseXDTest extends FlatSpec with Matchers with TimeLimitedTests { val timeLimit = 2 minutes diff --git 
a/core/src/test/scala/org/apache/spark/sql/crossdata/MockBaseRelation.scala b/core/src/test/scala/org/apache/spark/sql/crossdata/MockBaseRelation.scala index e4761206c..772569f6c 100644 --- a/core/src/test/scala/org/apache/spark/sql/crossdata/MockBaseRelation.scala +++ b/core/src/test/scala/org/apache/spark/sql/crossdata/MockBaseRelation.scala @@ -23,5 +23,6 @@ class MockBaseRelation extends BaseRelation with Serializable { override def sqlContext: SQLContext = ??? - override def schema: StructType = StructType(List(StructField("id", IntegerType))) + override def schema: StructType = + StructType(List(StructField("id", IntegerType))) } diff --git a/core/src/test/scala/org/apache/spark/sql/crossdata/XDContextIT.scala b/core/src/test/scala/org/apache/spark/sql/crossdata/XDContextIT.scala index d1cbf7c47..ad43509e1 100644 --- a/core/src/test/scala/org/apache/spark/sql/crossdata/XDContextIT.scala +++ b/core/src/test/scala/org/apache/spark/sql/crossdata/XDContextIT.scala @@ -38,7 +38,9 @@ class XDContextIT extends SharedXDContextTest { "A XDContext" should "perform a collect with a collection" in { - val df: DataFrame = xdContext.createDataFrame(xdContext.sparkContext.parallelize((1 to 5).map(i => Row(s"val_$i"))), StructType(Array(StructField("id", StringType)))) + val df: DataFrame = xdContext.createDataFrame( + xdContext.sparkContext.parallelize((1 to 5).map(i => Row(s"val_$i"))), + StructType(Array(StructField("id", StringType)))) df.registerTempTable("records") val result: Array[Row] = xdContext.sql("SELECT * FROM records").collect() @@ -48,33 +50,42 @@ class XDContextIT extends SharedXDContextTest { it must "return a XDDataFrame when executing a SQL query" in { - val df: DataFrame = xdContext.createDataFrame(xdContext.sparkContext.parallelize((1 to 5).map(i => Row(s"val_$i"))), StructType(Array(StructField("id", StringType)))) + val df: DataFrame = xdContext.createDataFrame( + xdContext.sparkContext.parallelize((1 to 5).map(i => Row(s"val_$i"))), + StructType(Array(StructField("id", StringType)))) df.registerTempTable("records") val dataframe = xdContext.sql("SELECT * FROM records") dataframe shouldBe a[XDDataFrame] } - it must "plan a PersistDataSource when creating a table " in { - val dataframe = xdContext.sql(s"CREATE TABLE jsonTable USING org.apache.spark.sql.json OPTIONS (path '${Paths.get(getClass.getResource("/core-reference.conf").toURI()).toString}')") + val dataframe = + xdContext.sql(s"CREATE TABLE jsonTable USING org.apache.spark.sql.json OPTIONS (path '${Paths + .get(getClass.getResource("/core-reference.conf").toURI()) + .toString}')") val sparkPlan = dataframe.queryExecution.sparkPlan xdContext.catalog.dropTable(TableIdentifier("jsonTable", None)) - sparkPlan should matchPattern { case ExecutedCommand(_: PersistDataSourceTable) => } + sparkPlan should matchPattern { + case ExecutedCommand(_: PersistDataSourceTable) => + } } it must "plan a query with conflicted column names between two tables resolving by alias preference" in { - val t1: DataFrame = xdContext.createDataFrame(xdContext.sparkContext.parallelize((1 to 5) - .map(i => Row(s"val_$i", i))), StructType(Array(StructField("id", StringType), StructField("value", IntegerType)))) + val t1: DataFrame = xdContext.createDataFrame( + xdContext.sparkContext.parallelize((1 to 5).map(i => Row(s"val_$i", i))), + StructType(Array(StructField("id", StringType), StructField("value", IntegerType)))) t1.registerTempTable("t1") - val t2: DataFrame = xdContext.createDataFrame(xdContext.sparkContext.parallelize((4 to 8) - .map(i => 
Row(s"val_$i", i))), StructType(Array(StructField("name", StringType), StructField("value", IntegerType)))) + val t2: DataFrame = xdContext.createDataFrame( + xdContext.sparkContext.parallelize((4 to 8).map(i => Row(s"val_$i", i))), + StructType(Array(StructField("name", StringType), StructField("value", IntegerType)))) t2.registerTempTable("t2") - val dataFrame = xdContext.sql("SELECT t1.id, t2.name as name, t1.value as total FROM t1 INNER JOIN t2 ON t1.id = t2.name GROUP BY id, name, total") + val dataFrame = xdContext.sql( + "SELECT t1.id, t2.name as name, t1.value as total FROM t1 INNER JOIN t2 ON t1.id = t2.name GROUP BY id, name, total") dataFrame.show @@ -84,8 +95,9 @@ class XDContextIT extends SharedXDContextTest { it must "plan a query with aliased attributes in the group by clause" in { - val t1: DataFrame = xdContext.createDataFrame(xdContext.sparkContext.parallelize((1 to 5) - .map(i => Row(s"val_$i", i))), StructType(Array(StructField("id", StringType), StructField("value", IntegerType)))) + val t1: DataFrame = xdContext.createDataFrame( + xdContext.sparkContext.parallelize((1 to 5).map(i => Row(s"val_$i", i))), + StructType(Array(StructField("id", StringType), StructField("value", IntegerType)))) t1.registerTempTable("t3") val dataFrame = xdContext.sql("SELECT id as id, value as product FROM t3 GROUP BY id, product") @@ -103,5 +115,4 @@ class XDContextIT extends SharedXDContextTest { // // } - } diff --git a/core/src/test/scala/org/apache/spark/sql/crossdata/XDDataFrameIT.scala b/core/src/test/scala/org/apache/spark/sql/crossdata/XDDataFrameIT.scala index 0e6a28c9c..57e919e05 100644 --- a/core/src/test/scala/org/apache/spark/sql/crossdata/XDDataFrameIT.scala +++ b/core/src/test/scala/org/apache/spark/sql/crossdata/XDDataFrameIT.scala @@ -31,7 +31,10 @@ import org.scalatest.junit.JUnitRunner @RunWith(classOf[JUnitRunner]) class XDDataFrameIT extends SharedXDContextTest with Inside { - lazy val sparkRows = xdContext.createDataFrame(xdContext.sparkContext.parallelize(Seq(Row(1))), StructType(Array(StructField("id", IntegerType)))).collect() + lazy val sparkRows = xdContext + .createDataFrame(xdContext.sparkContext.parallelize(Seq(Row(1))), + StructType(Array(StructField("id", IntegerType)))) + .collect() lazy val nativeRows = Array(Row(2l)) "A XDDataFrame (select * from nativeRelation)" should "be executed natively" in { @@ -56,7 +59,8 @@ class XDDataFrameIT extends SharedXDContextTest with Inside { } "A XDDataFrame with a logical plan which is not supported natively" should "be executed on the Spark cluster" in { - val result = XDDataFrame(xdContext, LogicalRelation(mockNativeRelationUnsupportedPlan)).collect() + val result = + XDDataFrame(xdContext, LogicalRelation(mockNativeRelationUnsupportedPlan)).collect() result should have length 1 result(0) should equal(sparkRows(0)) } @@ -70,27 +74,32 @@ class XDDataFrameIT extends SharedXDContextTest with Inside { "A XDDataFrame " should "return a XDDataFrame when applying a limit" in { val dataframe = XDDataFrame(xdContext, LogicalRelation(mockNativeRelation)).limit(5) dataframe shouldBe a[XDDataFrame] - dataframe.logicalPlan should matchPattern { case Limit(Literal(5, _), _) => } + dataframe.logicalPlan should matchPattern { + case Limit(Literal(5, _), _) => + } } "A XDDataFrame " should "return a XDDataFrame when applying a count" in { XDDataFrame(xdContext, LogicalRelation(mockNativeRelation)).count() should be(2l) } - val mockNonNativeRelation = new MockBaseRelation val mockNativeRelation = new MockBaseRelation with NativeScan 
with TableScan { override def isSupported(logicalStep: LogicalPlan, fullyLogicalPlan: LogicalPlan) = true // Native execution - override def buildScan(optimizedLogicalPlan: LogicalPlan): Option[Array[Row]] = Some(nativeRows) + override def buildScan(optimizedLogicalPlan: LogicalPlan): Option[Array[Row]] = + Some(nativeRows) // Spark execution - override def buildScan(): RDD[Row] = xdContext.createDataFrame(xdContext.sparkContext.parallelize(Seq(Row(1))), StructType(Array(StructField("id", IntegerType)))).rdd + override def buildScan(): RDD[Row] = + xdContext + .createDataFrame(xdContext.sparkContext.parallelize(Seq(Row(1))), + StructType(Array(StructField("id", IntegerType)))) + .rdd } - val mockPureSparkNativeRelation = new MockBaseRelation with NativeScan with TableScan { override def isSupported(logicalStep: LogicalPlan, fullyLogicalPlan: LogicalPlan) = true @@ -98,27 +107,41 @@ class XDDataFrameIT extends SharedXDContextTest with Inside { override def buildScan(optimizedLogicalPlan: LogicalPlan): Option[Array[Row]] = None // Spark execution - override def buildScan(): RDD[Row] = xdContext.createDataFrame(xdContext.sparkContext.parallelize(Seq(Row(1))), StructType(Array(StructField("id", IntegerType)))).rdd + override def buildScan(): RDD[Row] = + xdContext + .createDataFrame(xdContext.sparkContext.parallelize(Seq(Row(1))), + StructType(Array(StructField("id", IntegerType)))) + .rdd } val mockNativeRelationWith2Rows = new MockBaseRelation with NativeScan with TableScan { override def isSupported(logicalStep: LogicalPlan, fullyLogicalPlan: LogicalPlan) = true // Native execution - override def buildScan(optimizedLogicalPlan: LogicalPlan): Option[Array[Row]] = Some(Array(nativeRows(0), nativeRows(0))) + override def buildScan(optimizedLogicalPlan: LogicalPlan): Option[Array[Row]] = + Some(Array(nativeRows(0), nativeRows(0))) // Spark execution - override def buildScan(): RDD[Row] = xdContext.createDataFrame(xdContext.sparkContext.parallelize(Seq(Row(1))), StructType(Array(StructField("id", IntegerType)))).rdd + override def buildScan(): RDD[Row] = + xdContext + .createDataFrame(xdContext.sparkContext.parallelize(Seq(Row(1))), + StructType(Array(StructField("id", IntegerType)))) + .rdd } val mockNativeRelationUnsupportedPlan = new MockBaseRelation with NativeScan with TableScan { override def isSupported(logicalStep: LogicalPlan, fullyLogicalPlan: LogicalPlan) = false // Native execution - override def buildScan(optimizedLogicalPlan: LogicalPlan): Option[Array[Row]] = Some(nativeRows) + override def buildScan(optimizedLogicalPlan: LogicalPlan): Option[Array[Row]] = + Some(nativeRows) // Spark execution - override def buildScan(): RDD[Row] = xdContext.createDataFrame(xdContext.sparkContext.parallelize(Seq(Row(1))), StructType(Array(StructField("id", IntegerType)))).rdd + override def buildScan(): RDD[Row] = + xdContext + .createDataFrame(xdContext.sparkContext.parallelize(Seq(Row(1))), + StructType(Array(StructField("id", IntegerType)))) + .rdd } } diff --git a/core/src/test/scala/org/apache/spark/sql/crossdata/XDDataFrameUtilsSpec.scala b/core/src/test/scala/org/apache/spark/sql/crossdata/XDDataFrameUtilsSpec.scala index 61d46a410..7ba4a92f8 100644 --- a/core/src/test/scala/org/apache/spark/sql/crossdata/XDDataFrameUtilsSpec.scala +++ b/core/src/test/scala/org/apache/spark/sql/crossdata/XDDataFrameUtilsSpec.scala @@ -38,7 +38,6 @@ class XDDataFrameUtilsSpec extends BaseXDTest { val mockNonNativeRelation = new MockBaseRelation - "A logical plan (select *) with native relation" should 
"return a native relation" in { val nativeRelation = findNativeQueryExecutor(LogicalRelation(mockNativeRelation)) assert(nativeRelation.nonEmpty) @@ -54,8 +53,10 @@ class XDDataFrameUtilsSpec extends BaseXDTest { val attributeReference = AttributeReference("comment", StringType)(ExprId(1), Seq("students")) val projectPlan = Project(Seq(attributeReference), LogicalRelation(mockNativeRelation)) val count = Count(Literal(1)) - val optimizedLogicalPlan = Aggregate(Seq(attributeReference), - Seq(Alias(count, "c0")(ExprId(2)), attributeReference, Alias(count, "c2")(ExprId(3))), projectPlan) + val optimizedLogicalPlan = Aggregate( + Seq(attributeReference), + Seq(Alias(count, "c0")(ExprId(2)), attributeReference, Alias(count, "c2")(ExprId(3))), + projectPlan) val nativeRelation = findNativeQueryExecutor(optimizedLogicalPlan) assert(nativeRelation.nonEmpty) @@ -63,13 +64,11 @@ class XDDataFrameUtilsSpec extends BaseXDTest { } "A logical plan (select count(a.comment) from (select comment from students) as a ) with native relation" should "return a native relation" in { - val commentAttributeReference = AttributeReference("comment", StringType)(ExprId(1), Seq("students")) + val commentAttributeReference = + AttributeReference("comment", StringType)(ExprId(1), Seq("students")) val projectPlan = Project(Seq(commentAttributeReference), LogicalRelation(mockNativeRelation)) val count = Count(commentAttributeReference) - val optimizedLogicalPlan = Aggregate( - Nil, - Seq(Alias(count, "c0")(ExprId(4))), - projectPlan) + val optimizedLogicalPlan = Aggregate(Nil, Seq(Alias(count, "c0")(ExprId(4))), projectPlan) val nativeRelation = findNativeQueryExecutor(optimizedLogicalPlan) assert(nativeRelation.nonEmpty) @@ -88,7 +87,8 @@ class XDDataFrameUtilsSpec extends BaseXDTest { "A logical plan (select * from students where id > 2 or id = 1) with native relation" should "return a native relation" in { val idAttributeReference = AttributeReference("id", IntegerType)(ExprId(0), Seq("students")) - val orCondition = Or(GreaterThan(idAttributeReference, Literal(2)), EqualTo(idAttributeReference, Literal(1))) + val orCondition = + Or(GreaterThan(idAttributeReference, Literal(2)), EqualTo(idAttributeReference, Literal(1))) val optimizedLogicalPlan = Filter(orCondition, LogicalRelation(mockNativeRelation)) val nativeRelation = findNativeQueryExecutor(optimizedLogicalPlan) @@ -99,7 +99,8 @@ class XDDataFrameUtilsSpec extends BaseXDTest { "A logical plan (select * from students order by id) with native relation" should "return a native relation" in { val idAttributeReference = AttributeReference("id", IntegerType)(ExprId(0), Seq("students")) val sortOrder = SortOrder(idAttributeReference, Ascending) - val optimizedLogicalPlan = Sort(Seq(sortOrder), global = true, LogicalRelation(mockNativeRelation)) + val optimizedLogicalPlan = + Sort(Seq(sortOrder), global = true, LogicalRelation(mockNativeRelation)) val nativeRelation = findNativeQueryExecutor(optimizedLogicalPlan) assert(nativeRelation.nonEmpty) @@ -118,7 +119,8 @@ class XDDataFrameUtilsSpec extends BaseXDTest { "A logical plan (select case id > 5 when true then 1 else 0 end from students) with native relation" should "return a native relation" in { val idAttributeReference = AttributeReference("id", IntegerType)(ExprId(0), Seq("students")) - val caseKeyWhen = CaseKeyWhen(GreaterThan(idAttributeReference, Literal(5)), Seq(Literal(true), Literal(1), Literal(0))) + val caseKeyWhen = CaseKeyWhen(GreaterThan(idAttributeReference, Literal(5)), + Seq(Literal(true), Literal(1), 
Literal(0))) val caseWhenAlias = Alias(caseKeyWhen, "c0")(ExprId(5)) val optimizedLogicalPlan = Project(Seq(caseWhenAlias), LogicalRelation(mockNativeRelation)) @@ -130,7 +132,8 @@ class XDDataFrameUtilsSpec extends BaseXDTest { "A logical plan (select * from students where id between 0 and 6 AND id in (3,5,10)) with native relation" should "return a native relation" in { val idAttributeReference = AttributeReference("id", IntegerType)(ExprId(0), Seq("students")) - val betweenFilter = And(GreaterThanOrEqual(idAttributeReference, Literal(0)), LessThanOrEqual(idAttributeReference, Literal(6))) + val betweenFilter = And(GreaterThanOrEqual(idAttributeReference, Literal(0)), + LessThanOrEqual(idAttributeReference, Literal(6))) val condition = And(betweenFilter, InSet(idAttributeReference, Set(5, 10, 3))) val optimizedLogicalPlan = Filter(condition, LogicalRelation(mockNativeRelation)) @@ -143,21 +146,23 @@ class XDDataFrameUtilsSpec extends BaseXDTest { val idAttributeReference = AttributeReference("id", IntegerType)(ExprId(0), Seq("students")) - val commentAttributeReference = AttributeReference("comment", StringType)(ExprId(1), Seq("students")) + val commentAttributeReference = + AttributeReference("comment", StringType)(ExprId(1), Seq("students")) val havingAttributeReference = AttributeReference("havingCondition", BooleanType)(ExprId(7)) val countAttributeReference = AttributeReference("c0", LongType)(ExprId(6)) - val filterPlan = Filter(GreaterThan(idAttributeReference, Literal(5)), LogicalRelation(mockNativeRelation)) + val filterPlan = + Filter(GreaterThan(idAttributeReference, Literal(5)), LogicalRelation(mockNativeRelation)) val firstProjectPlan = Project(Seq(commentAttributeReference), filterPlan) val aggregateLogicalPlan = Aggregate( - Seq(commentAttributeReference), - Seq( - Alias(GreaterThanOrEqual(Count(Literal(1)), Literal(1)), "havingCondition")(ExprId(7)), - Alias(Count(Literal(1)), "c0")(ExprId(6)) - ), - firstProjectPlan) + Seq(commentAttributeReference), + Seq( + Alias(GreaterThanOrEqual(Count(Literal(1)), Literal(1)), "havingCondition")(ExprId(7)), + Alias(Count(Literal(1)), "c0")(ExprId(6)) + ), + firstProjectPlan) val filterLogicalPlan = Filter(havingAttributeReference, aggregateLogicalPlan) @@ -166,12 +171,10 @@ class XDDataFrameUtilsSpec extends BaseXDTest { val sortOrder = SortOrder(Count(Literal(1)), Ascending) val sortLogicalPlan = Sort(Seq(sortOrder), global = true, projectLogicalPlan) - val optimizedLogicalPlan = Limit(Literal(5), sortLogicalPlan) val nativeRelation = findNativeQueryExecutor(optimizedLogicalPlan) assert(nativeRelation.nonEmpty) assert(nativeRelation.get === mockNativeRelation) } - } diff --git a/core/src/test/scala/org/apache/spark/sql/crossdata/XDFunctionRegistryIT.scala b/core/src/test/scala/org/apache/spark/sql/crossdata/XDFunctionRegistryIT.scala index 816f2515f..74d139cc6 100644 --- a/core/src/test/scala/org/apache/spark/sql/crossdata/XDFunctionRegistryIT.scala +++ b/core/src/test/scala/org/apache/spark/sql/crossdata/XDFunctionRegistryIT.scala @@ -29,7 +29,10 @@ class XDFunctionRegistryIT extends SharedXDContextTest { "XD Function registry" should "throw an analysis exception when a native udf cannot be resolved" in { try { - xdContext.sql(s"CREATE TEMPORARY TABLE jsonTable USING org.apache.spark.sql.json OPTIONS (path '${Paths.get(getClass.getResource("/core-reference.conf").toURI()).toString}')") + xdContext.sql( + s"CREATE TEMPORARY TABLE jsonTable USING org.apache.spark.sql.json OPTIONS (path '${Paths + 
.get(getClass.getResource("/core-reference.conf").toURI()) + .toString}')") val missingUDFName = "missingFunction" val thrown = the[AnalysisException] thrownBy sql(s"SELECT $missingUDFName() FROM jsonTable") diff --git a/core/src/test/scala/org/apache/spark/sql/crossdata/XDSessionIT.scala b/core/src/test/scala/org/apache/spark/sql/crossdata/XDSessionIT.scala index 9baae0489..67dc35217 100644 --- a/core/src/test/scala/org/apache/spark/sql/crossdata/XDSessionIT.scala +++ b/core/src/test/scala/org/apache/spark/sql/crossdata/XDSessionIT.scala @@ -53,13 +53,16 @@ class XDSessionIT extends BaseXDTest with BeforeAndAfterAll { val coreConfig = ConfigFactory.parseMap(sessionConfig) val sqlConf = new SQLConf - sessionConfig.foreach { case (k, v) => sqlConf.setConfString(k.stripPrefix("config."), v.toString) } + sessionConfig.foreach { + case (k, v) => + sqlConf.setConfString(k.stripPrefix("config."), v.toString) + } (coreConfig, sqlConf) } new XDSession( - new XDSharedState(_sparkContext,sqlConf, new DerbyCatalog(sqlConf), None), - new XDSessionState(sqlConf, new HashmapCatalog(sqlConf) :: Nil) + new XDSharedState(_sparkContext, sqlConf, new DerbyCatalog(sqlConf), None), + new XDSessionState(sqlConf, new HashmapCatalog(sqlConf) :: Nil) ) } @@ -83,15 +86,21 @@ class XDSessionIT extends BaseXDTest with BeforeAndAfterAll { val xdSession1 = createNewDefaultSession val xdSession2 = createNewDefaultSession - val df: DataFrame = xdSession2.createDataFrame(xdSession2.sparkContext.parallelize((1 to 5).map(i => Row(s"val_$i"))), StructType(Array(StructField("id", StringType)))) + val df: DataFrame = xdSession2.createDataFrame( + xdSession2.sparkContext.parallelize((1 to 5).map(i => Row(s"val_$i"))), + StructType(Array(StructField("id", StringType)))) df.registerTempTable(tempTableName) xdSession2.table(tempTableName).collect should not be empty - a [RuntimeException] shouldBe thrownBy{ + a[RuntimeException] shouldBe thrownBy { xdSession1.table(tempTableName).collect should not be empty } - df.write.format("json").mode(SaveMode.Overwrite).option("path", s"/tmp/$persTableName").saveAsTable(persTableName) + df.write + .format("json") + .mode(SaveMode.Overwrite) + .option("path", s"/tmp/$persTableName") + .saveAsTable(persTableName) xdSession2.table(persTableName).collect should not be empty xdSession1.table(persTableName).collect should not be empty @@ -100,12 +109,13 @@ class XDSessionIT extends BaseXDTest with BeforeAndAfterAll { } - "A XDSession" should "perform a collect with a collection" in { val xdSession = createNewDefaultSession - val df: DataFrame = xdSession.createDataFrame(xdSession.sparkContext.parallelize((1 to 5).map(i => Row(s"val_$i"))), StructType(Array(StructField("id", StringType)))) + val df: DataFrame = xdSession.createDataFrame( + xdSession.sparkContext.parallelize((1 to 5).map(i => Row(s"val_$i"))), + StructType(Array(StructField("id", StringType)))) df.registerTempTable("records") val result: Array[Row] = xdSession.sql("SELECT * FROM records").collect() @@ -117,22 +127,28 @@ class XDSessionIT extends BaseXDTest with BeforeAndAfterAll { val xdSession = createNewDefaultSession - val df: DataFrame = xdSession.createDataFrame(xdSession.sparkContext.parallelize((1 to 5).map(i => Row(s"val_$i"))), StructType(Array(StructField("id", StringType)))) + val df: DataFrame = xdSession.createDataFrame( + xdSession.sparkContext.parallelize((1 to 5).map(i => Row(s"val_$i"))), + StructType(Array(StructField("id", StringType)))) df.registerTempTable("records") val dataframe = xdSession.sql("SELECT * FROM 
records") dataframe shouldBe a[XDDataFrame] } - it must "plan a PersistDataSource when creating a table " in { val xdSession = createNewDefaultSession - val dataframe = xdSession.sql(s"CREATE TABLE jsonTable USING org.apache.spark.sql.json OPTIONS (path '${Paths.get(getClass.getResource("/core-reference.conf").toURI()).toString}')") + val dataframe = + xdSession.sql(s"CREATE TABLE jsonTable USING org.apache.spark.sql.json OPTIONS (path '${Paths + .get(getClass.getResource("/core-reference.conf").toURI()) + .toString}')") val sparkPlan = dataframe.queryExecution.sparkPlan xdSession.catalog.dropTable(TableIdentifier("jsonTable", None)) - sparkPlan should matchPattern { case ExecutedCommand(_: PersistDataSourceTable) => } + sparkPlan should matchPattern { + case ExecutedCommand(_: PersistDataSourceTable) => + } } @@ -140,15 +156,18 @@ class XDSessionIT extends BaseXDTest with BeforeAndAfterAll { val xdSession = createNewDefaultSession - val t1: DataFrame = xdSession.createDataFrame(xdSession.sparkContext.parallelize((1 to 5) - .map(i => Row(s"val_$i", i))), StructType(Array(StructField("id", StringType), StructField("value", IntegerType)))) + val t1: DataFrame = xdSession.createDataFrame( + xdSession.sparkContext.parallelize((1 to 5).map(i => Row(s"val_$i", i))), + StructType(Array(StructField("id", StringType), StructField("value", IntegerType)))) t1.registerTempTable("t1") - val t2: DataFrame = xdSession.createDataFrame(xdSession.sparkContext.parallelize((4 to 8) - .map(i => Row(s"val_$i", i))), StructType(Array(StructField("name", StringType), StructField("value", IntegerType)))) + val t2: DataFrame = xdSession.createDataFrame( + xdSession.sparkContext.parallelize((4 to 8).map(i => Row(s"val_$i", i))), + StructType(Array(StructField("name", StringType), StructField("value", IntegerType)))) t2.registerTempTable("t2") - val dataFrame = xdSession.sql("SELECT t1.id, t2.name as name, t1.value as total FROM t1 INNER JOIN t2 ON t1.id = t2.name GROUP BY id, name, total") + val dataFrame = xdSession.sql( + "SELECT t1.id, t2.name as name, t1.value as total FROM t1 INNER JOIN t2 ON t1.id = t2.name GROUP BY id, name, total") dataFrame.show @@ -160,8 +179,9 @@ class XDSessionIT extends BaseXDTest with BeforeAndAfterAll { val xdSession = createNewDefaultSession - val t1: DataFrame = xdSession.createDataFrame(xdSession.sparkContext.parallelize((1 to 5) - .map(i => Row(s"val_$i", i))), StructType(Array(StructField("id", StringType), StructField("value", IntegerType)))) + val t1: DataFrame = xdSession.createDataFrame( + xdSession.sparkContext.parallelize((1 to 5).map(i => Row(s"val_$i", i))), + StructType(Array(StructField("id", StringType), StructField("value", IntegerType)))) t1.registerTempTable("t3") val dataFrame = xdSession.sql("SELECT id as id, value as product FROM t3 GROUP BY id, product") @@ -172,9 +192,12 @@ class XDSessionIT extends BaseXDTest with BeforeAndAfterAll { override protected def beforeAll(): Unit = { _sparkContext = new SparkContext( - "local[2]", - "test-xdsession", - new SparkConf().set("spark.cores.max", "2").set("spark.sql.testkey", "true").set("spark.sql.shuffle.partitions", "3") + "local[2]", + "test-xdsession", + new SparkConf() + .set("spark.cores.max", "2") + .set("spark.sql.testkey", "true") + .set("spark.sql.shuffle.partitions", "3") ) } @@ -183,11 +206,11 @@ class XDSessionIT extends BaseXDTest with BeforeAndAfterAll { } private def createNewDefaultSession: XDSession = { - val sqlConf = new SQLConf - new XDSession( + val sqlConf = new SQLConf + new XDSession( new 
XDSharedState(_sparkContext, sqlConf, new DerbyCatalog(sqlConf), None), new XDSessionState(sqlConf, new HashmapCatalog(sqlConf) :: Nil) - ) - } + ) + } } diff --git a/core/src/test/scala/org/apache/spark/sql/crossdata/catalog/CatalogChainIT.scala b/core/src/test/scala/org/apache/spark/sql/crossdata/catalog/CatalogChainIT.scala index 13d328177..b86404d71 100644 --- a/core/src/test/scala/org/apache/spark/sql/crossdata/catalog/CatalogChainIT.scala +++ b/core/src/test/scala/org/apache/spark/sql/crossdata/catalog/CatalogChainIT.scala @@ -40,7 +40,9 @@ class CatalogChainIT extends SharedXDContextTest { val firstFallbackCatalog = new HashmapCatalog(xdContext.conf) val secondfallbackCatalog = new HashmapCatalog(xdContext.conf) - val catalogChain = CatalogChain(prioritaryHashMapCatalog, firstFallbackCatalog, secondfallbackCatalog)(_xdContext) + val catalogChain = + CatalogChain(prioritaryHashMapCatalog, firstFallbackCatalog, secondfallbackCatalog)( + _xdContext) val localRelation: LocalRelation = { val attributes = AttributeReference("mystring", StringType)() :: Nil @@ -50,15 +52,15 @@ class CatalogChainIT extends SharedXDContextTest { secondfallbackCatalog.saveTable(TableNormalized, localRelation) - secondfallbackCatalog.relation(TableNormalized) should contain (localRelation) + secondfallbackCatalog.relation(TableNormalized) should contain(localRelation) prioritaryHashMapCatalog.relation(TableNormalized) shouldBe None firstFallbackCatalog.relation(TableNormalized) shouldBe None // Once we lookup the relation, it should be stored in prioritary and firstFallback catalogs catalogChain.lookupRelation(TableId) - prioritaryHashMapCatalog.relation(TableNormalized) should contain (localRelation) - firstFallbackCatalog.relation(TableNormalized) should contain (localRelation) + prioritaryHashMapCatalog.relation(TableNormalized) should contain(localRelation) + firstFallbackCatalog.relation(TableNormalized) should contain(localRelation) } } diff --git a/core/src/test/scala/org/apache/spark/sql/crossdata/catalog/CatalogConstants.scala b/core/src/test/scala/org/apache/spark/sql/crossdata/catalog/CatalogConstants.scala index 04505fbab..9f4c2b7c4 100644 --- a/core/src/test/scala/org/apache/spark/sql/crossdata/catalog/CatalogConstants.scala +++ b/core/src/test/scala/org/apache/spark/sql/crossdata/catalog/CatalogConstants.scala @@ -34,16 +34,30 @@ trait CatalogConstants { val SubField2 = StructField(SubField2Name, StringType, nullable = true) val arrayField = StructField(SubField2Name, ArrayType(StringType), nullable = true) val arrayFieldIntegers = StructField(SubField2Name, ArrayType(IntegerType), nullable = true) - val arrayFieldWithSubDocs = StructField(FieldWitStrangeChars, ArrayType(StructType(Seq(Field1, Field2)))) + val arrayFieldWithSubDocs = + StructField(FieldWitStrangeChars, ArrayType(StructType(Seq(Field1, Field2)))) val SourceDatasource = "org.apache.spark.sql.json" val Fields = Seq[StructField](Field1, Field2) val SubFields = Seq(SubField, SubField2) val Columns = StructType(Fields) - val ColumnsWithSubColumns = StructType(Seq(StructField(Field1Name, StringType, nullable = true), StructField(FieldWithSubcolumnsName, StructType(SubFields), nullable = true)) ) - val ColumnsWithArrayString = StructType(Seq(StructField(Field1Name, StringType, nullable = true), StructField(FieldWithSubcolumnsName, StructType(SubFields), nullable = true), arrayField) ) - val ColumnsWithArrayInteger = StructType(Seq(StructField(Field1Name, StringType, nullable = true), StructField(FieldWithSubcolumnsName, 
StructType(SubFields), nullable = true), arrayFieldIntegers) ) - val ColumnsWithArrayWithSubdocuments = StructType(Seq(StructField(Field1Name, StringType, nullable = true), StructField(FieldWithSubcolumnsName, StructType(SubFields), nullable = true), arrayFieldWithSubDocs) ) - val ColumnsWithMapWithArrayWithSubdocuments = StructType(Seq(StructField(Field1Name, MapType(ColumnsWithSubColumns, ColumnsWithArrayWithSubdocuments)))) + val ColumnsWithSubColumns = StructType( + Seq(StructField(Field1Name, StringType, nullable = true), + StructField(FieldWithSubcolumnsName, StructType(SubFields), nullable = true))) + val ColumnsWithArrayString = StructType( + Seq(StructField(Field1Name, StringType, nullable = true), + StructField(FieldWithSubcolumnsName, StructType(SubFields), nullable = true), + arrayField)) + val ColumnsWithArrayInteger = StructType( + Seq(StructField(Field1Name, StringType, nullable = true), + StructField(FieldWithSubcolumnsName, StructType(SubFields), nullable = true), + arrayFieldIntegers)) + val ColumnsWithArrayWithSubdocuments = StructType( + Seq(StructField(Field1Name, StringType, nullable = true), + StructField(FieldWithSubcolumnsName, StructType(SubFields), nullable = true), + arrayFieldWithSubDocs)) + val ColumnsWithMapWithArrayWithSubdocuments = StructType( + Seq(StructField(Field1Name, + MapType(ColumnsWithSubColumns, ColumnsWithArrayWithSubdocuments)))) val OptsJSON = Map("path" -> "/fake_path") val sqlView = s"select $Field1Name from $Database.$TableName" } diff --git a/core/src/test/scala/org/apache/spark/sql/crossdata/catalog/InsensitiveCatalogIT.scala b/core/src/test/scala/org/apache/spark/sql/crossdata/catalog/InsensitiveCatalogIT.scala index 041117aa1..aa09bacc0 100644 --- a/core/src/test/scala/org/apache/spark/sql/crossdata/catalog/InsensitiveCatalogIT.scala +++ b/core/src/test/scala/org/apache/spark/sql/crossdata/catalog/InsensitiveCatalogIT.scala @@ -26,13 +26,13 @@ import org.apache.spark.sql.execution.datasources.LogicalRelation import org.junit.runner.RunWith import org.scalatest.junit.JUnitRunner - @RunWith(classOf[JUnitRunner]) class InsensitiveCatalogIT extends DerbyCatalogIT { - override val coreConfig : Option[Config] = - Some(ConfigFactory.empty().withValue(s"config.${SQLConf.CASE_SENSITIVE.key}", ConfigValueFactory.fromAnyRef(false))) - + override val coreConfig: Option[Config] = Some( + ConfigFactory + .empty() + .withValue(s"config.${SQLConf.CASE_SENSITIVE.key}", ConfigValueFactory.fromAnyRef(false))) it should s"persist a table and retrieve it changing some letters to upper case in $catalogName" in { @@ -40,7 +40,11 @@ class InsensitiveCatalogIT extends DerbyCatalogIT { import XDCatalogCommon._ val tableIdentifier = TableIdentifier(tableNameOriginal, Some(Database)) val tableNormalized = tableIdentifier.normalize - val crossdataTable = CrossdataTable(tableNormalized, Some(Columns), SourceDatasource, Array[String](Field1Name), OptsJSON) + val crossdataTable = CrossdataTable(tableNormalized, + Some(Columns), + SourceDatasource, + Array[String](Field1Name), + OptsJSON) xdContext.catalog.persistTable(crossdataTable, OneRowRelation) xdContext.catalog.tableExists(tableIdentifier) shouldBe true diff --git a/core/src/test/scala/org/apache/spark/sql/crossdata/catalog/XDCatalogCommonSpec.scala b/core/src/test/scala/org/apache/spark/sql/crossdata/catalog/XDCatalogCommonSpec.scala index cb448ca22..53bcf5c84 100644 --- a/core/src/test/scala/org/apache/spark/sql/crossdata/catalog/XDCatalogCommonSpec.scala +++ 
b/core/src/test/scala/org/apache/spark/sql/crossdata/catalog/XDCatalogCommonSpec.scala @@ -21,7 +21,6 @@ import org.apache.spark.sql.crossdata.catalog.interfaces.XDCatalogCommon import org.junit.runner.RunWith import org.scalatest.junit.JUnitRunner - @RunWith(classOf[JUnitRunner]) class XDCatalogCommonSpec extends BaseXDTest { @@ -31,9 +30,13 @@ class XDCatalogCommonSpec extends BaseXDTest { val tableIdentifier = TableIdentifier(tableName, Some(dbName)) import XDCatalogCommon._ - tableIdentifier.normalize(new SimpleCatalystConf(true)) shouldBe TableIdentifierNormalized(tableName, Some(dbName)) + tableIdentifier.normalize(new SimpleCatalystConf(true)) shouldBe TableIdentifierNormalized( + tableName, + Some(dbName)) - tableIdentifier.normalize(new SimpleCatalystConf(false)) shouldBe TableIdentifierNormalized(tableName.toLowerCase, Some(dbName.toLowerCase)) + tableIdentifier.normalize(new SimpleCatalystConf(false)) shouldBe TableIdentifierNormalized( + tableName.toLowerCase, + Some(dbName.toLowerCase)) } } diff --git a/core/src/test/scala/org/apache/spark/sql/crossdata/catalog/persistent/DerbyCatalogIT.scala b/core/src/test/scala/org/apache/spark/sql/crossdata/catalog/persistent/DerbyCatalogIT.scala index 3eecf504f..ae20b1413 100644 --- a/core/src/test/scala/org/apache/spark/sql/crossdata/catalog/persistent/DerbyCatalogIT.scala +++ b/core/src/test/scala/org/apache/spark/sql/crossdata/catalog/persistent/DerbyCatalogIT.scala @@ -20,8 +20,7 @@ import org.apache.spark.sql.crossdata.test.SharedXDContextTest import org.junit.runner.RunWith import org.scalatest.junit.JUnitRunner - @RunWith(classOf[JUnitRunner]) class DerbyCatalogIT extends { val catalogName = "Derby" -} with SharedXDContextTest with CatalogConstants with GenericCatalogTests \ No newline at end of file +} with SharedXDContextTest with CatalogConstants with GenericCatalogTests diff --git a/core/src/test/scala/org/apache/spark/sql/crossdata/catalog/persistent/GenericCatalogTests.scala b/core/src/test/scala/org/apache/spark/sql/crossdata/catalog/persistent/GenericCatalogTests.scala index 825c8deab..776b841b2 100644 --- a/core/src/test/scala/org/apache/spark/sql/crossdata/catalog/persistent/GenericCatalogTests.scala +++ b/core/src/test/scala/org/apache/spark/sql/crossdata/catalog/persistent/GenericCatalogTests.scala @@ -35,7 +35,11 @@ trait GenericCatalogTests extends SharedXDContextTest with CatalogConstants { implicit lazy val conf: CatalystConf = xdContext.catalog.conf implicit def catalogToPersistenceWithCache(catalog: XDCatalog): PersistentCatalogWithCache = { - catalog.asInstanceOf[CatalogChain].persistentCatalogs.head.asInstanceOf[PersistentCatalogWithCache] + catalog + .asInstanceOf[CatalogChain] + .persistentCatalogs + .head + .asInstanceOf[PersistentCatalogWithCache] } implicit def catalogToTemporaryCatalog(catalog: XDCatalog): MapCatalog = { @@ -47,7 +51,11 @@ trait GenericCatalogTests extends SharedXDContextTest with CatalogConstants { val columns = StructType(fields) val opts = Map("path" -> "/fake_path") val tableIdentifier = Seq(TableName) - val crossdataTable = CrossdataTable(TableIdentifier(TableName, None).normalize, Some(Columns), SourceDatasource, Array.empty, opts) + val crossdataTable = CrossdataTable(TableIdentifier(TableName, None).normalize, + Some(Columns), + SourceDatasource, + Array.empty, + opts) xdContext.catalog.persistTableMetadata(crossdataTable) val dataframe = xdContext.sql(s"SELECT * FROM $TableName") @@ -57,7 +65,11 @@ trait GenericCatalogTests extends SharedXDContextTest with CatalogConstants { it 
should s"persist a table with catalog and partitionColumns in $catalogName" in { val tableIdentifier = TableIdentifier(TableName, Some(Database)) - val crossdataTable = CrossdataTable(TableIdentifier(TableName, Some(Database)).normalize, Some(Columns), SourceDatasource, Array(Field1Name), OptsJSON) + val crossdataTable = CrossdataTable(TableIdentifier(TableName, Some(Database)).normalize, + Some(Columns), + SourceDatasource, + Array(Field1Name), + OptsJSON) xdContext.catalog.persistTableMetadata(crossdataTable) xdContext.catalog.tableExists(tableIdentifier) shouldBe true @@ -75,7 +87,6 @@ trait GenericCatalogTests extends SharedXDContextTest with CatalogConstants { xdContext.catalog.tableExists(viewIdentifier) shouldBe false } - it should s"not drop view that not exists " in { a[RuntimeException] shouldBe thrownBy { val viewIdentifier = TableIdentifier(ViewName, Option(Database)) @@ -87,7 +98,11 @@ trait GenericCatalogTests extends SharedXDContextTest with CatalogConstants { it should s"persist a table with catalog and partitionColumns with multiple subdocuments as schema in $catalogName" in { xdContext.catalog.dropAllTables() val tableIdentifier = TableIdentifier(TableName, Some(Database)) - val crossdataTable = CrossdataTable(TableIdentifier(TableName, Some(Database)).normalize, Some(ColumnsWithSubColumns), SourceDatasource, Array.empty, OptsJSON) + val crossdataTable = CrossdataTable(TableIdentifier(TableName, Some(Database)).normalize, + Some(ColumnsWithSubColumns), + SourceDatasource, + Array.empty, + OptsJSON) xdContext.catalog.persistTableMetadata(crossdataTable) xdContext.catalog.unregisterTable(tableIdentifier) @@ -101,7 +116,11 @@ trait GenericCatalogTests extends SharedXDContextTest with CatalogConstants { it should s"persist a table with catalog and partitionColumns with arrays as schema in $catalogName" in { xdContext.catalog.dropAllTables() val tableIdentifier = TableIdentifier(TableName, Some(Database)) - val crossdataTable = CrossdataTable(TableIdentifier(TableName, Some(Database)).normalize, Some(ColumnsWithArrayString), SourceDatasource, Array.empty, OptsJSON) + val crossdataTable = CrossdataTable(TableIdentifier(TableName, Some(Database)).normalize, + Some(ColumnsWithArrayString), + SourceDatasource, + Array.empty, + OptsJSON) xdContext.catalog.persistTableMetadata(crossdataTable) xdContext.catalog.unregisterTable(tableIdentifier) @@ -113,7 +132,11 @@ trait GenericCatalogTests extends SharedXDContextTest with CatalogConstants { it should s"persist a table with catalog and partitionColumns with array of integers as schema in $catalogName" in { xdContext.catalog.dropAllTables() val tableIdentifier = TableIdentifier(TableName, Some(Database)) - val crossdataTable = CrossdataTable(TableIdentifier(TableName, Some(Database)).normalize, Some(ColumnsWithArrayInteger), SourceDatasource, Array.empty, OptsJSON) + val crossdataTable = CrossdataTable(TableIdentifier(TableName, Some(Database)).normalize, + Some(ColumnsWithArrayInteger), + SourceDatasource, + Array.empty, + OptsJSON) xdContext.catalog.persistTableMetadata(crossdataTable) xdContext.catalog.unregisterTable(tableIdentifier) @@ -127,7 +150,11 @@ trait GenericCatalogTests extends SharedXDContextTest with CatalogConstants { s"characters in Field names as schema in $catalogName" in { xdContext.catalog.dropAllTables() val tableIdentifier = TableIdentifier(TableName, Some(Database)) - val crossdataTable = CrossdataTable(TableIdentifier(TableName, Some(Database)).normalize, Some(ColumnsWithArrayWithSubdocuments), 
SourceDatasource, Array.empty, OptsJSON) + val crossdataTable = CrossdataTable(TableIdentifier(TableName, Some(Database)).normalize, + Some(ColumnsWithArrayWithSubdocuments), + SourceDatasource, + Array.empty, + OptsJSON) xdContext.catalog.persistTableMetadata(crossdataTable) xdContext.catalog.unregisterTable(tableIdentifier) @@ -142,20 +169,33 @@ trait GenericCatalogTests extends SharedXDContextTest with CatalogConstants { s"characters in Field names as schema in $catalogName" in { xdContext.catalog.dropAllTables() val tableIdentifier = TableIdentifier(TableName, Some(Database)) - val crossdataTable = CrossdataTable(TableIdentifier(TableName, Some(Database)).normalize, Some(ColumnsWithMapWithArrayWithSubdocuments), SourceDatasource, Array.empty, OptsJSON) + val crossdataTable = CrossdataTable(TableIdentifier(TableName, Some(Database)).normalize, + Some(ColumnsWithMapWithArrayWithSubdocuments), + SourceDatasource, + Array.empty, + OptsJSON) xdContext.catalog.persistTableMetadata(crossdataTable) xdContext.catalog.unregisterTable(tableIdentifier) val schemaDF = xdContext.sql(s"DESCRIBE $Database.$TableName") schemaDF.count() should be(1) val df = xdContext.sql(s"SELECT `$Field1Name` FROM $Database.$TableName") df shouldBe a[XDDataFrame] - df.schema.apply(0).dataType shouldBe (MapType(ColumnsWithSubColumns, ColumnsWithArrayWithSubdocuments)) + df.schema.apply(0).dataType shouldBe (MapType(ColumnsWithSubColumns, + ColumnsWithArrayWithSubdocuments)) } it should "returns list of tables" in { xdContext.catalog.dropAllTables() - val crossdataTable1 = CrossdataTable(TableIdentifier(TableName, Some(Database)).normalize, Some(Columns), SourceDatasource, Array(Field1Name), OptsJSON) - val crossdataTable2 = CrossdataTable(TableIdentifier(TableName, None).normalize, Some(Columns), SourceDatasource, Array(Field1Name), OptsJSON) + val crossdataTable1 = CrossdataTable(TableIdentifier(TableName, Some(Database)).normalize, + Some(Columns), + SourceDatasource, + Array(Field1Name), + OptsJSON) + val crossdataTable2 = CrossdataTable(TableIdentifier(TableName, None).normalize, + Some(Columns), + SourceDatasource, + Array(Field1Name), + OptsJSON) xdContext.catalog.persistTableMetadata(crossdataTable1) xdContext.catalog.persistTableMetadata(crossdataTable2) @@ -169,8 +209,16 @@ trait GenericCatalogTests extends SharedXDContextTest with CatalogConstants { it should "drop tables" in { xdContext.catalog.dropAllTables() - val crossdataTable1 = CrossdataTable(TableIdentifier(TableName, Some(Database)).normalize, Some(Columns), SourceDatasource, Array(Field1Name), OptsJSON) - val crossdataTable2 = CrossdataTable(TableIdentifier(TableName, None).normalize, Some(Columns), SourceDatasource, Array(Field1Name), OptsJSON) + val crossdataTable1 = CrossdataTable(TableIdentifier(TableName, Some(Database)).normalize, + Some(Columns), + SourceDatasource, + Array(Field1Name), + OptsJSON) + val crossdataTable2 = CrossdataTable(TableIdentifier(TableName, None).normalize, + Some(Columns), + SourceDatasource, + Array(Field1Name), + OptsJSON) val tableIdentifier1 = TableIdentifier(TableName, Some(Database)) val tableIdentifier2 = TableIdentifier(TableName) xdContext.catalog.persistTableMetadata(crossdataTable1) @@ -191,14 +239,18 @@ trait GenericCatalogTests extends SharedXDContextTest with CatalogConstants { val tableIdentifier = TableIdentifier(TableName, Some(Database)) - a[RuntimeException] shouldBe thrownBy{ + a[RuntimeException] shouldBe thrownBy { xdContext.catalog.dropTable(tableIdentifier) } } it should "check if tables map is 
correct with databaseName" in { xdContext.catalog.dropAllTables() - val crossdataTable1 = CrossdataTable(TableIdentifier(TableName, Some(Database)).normalize, Some(Columns), SourceDatasource, Array(Field1Name), OptsJSON) + val crossdataTable1 = CrossdataTable(TableIdentifier(TableName, Some(Database)).normalize, + Some(Columns), + SourceDatasource, + Array(Field1Name), + OptsJSON) val tableIdentifier2 = TableIdentifier(TableName) xdContext.catalog.persistTableMetadata(crossdataTable1) @@ -212,7 +264,11 @@ trait GenericCatalogTests extends SharedXDContextTest with CatalogConstants { it should "check if tables map is correct without databaseName " in { xdContext.catalog.dropAllTables() - val crossdataTable1 = CrossdataTable(TableIdentifier(TableName, Some(Database)).normalize, Some(Columns), SourceDatasource, Array(Field1Name), OptsJSON) + val crossdataTable1 = CrossdataTable(TableIdentifier(TableName, Some(Database)).normalize, + Some(Columns), + SourceDatasource, + Array(Field1Name), + OptsJSON) val tableIdentifier2 = TableIdentifier(TableName) xdContext.catalog.persistTableMetadata(crossdataTable1) @@ -226,16 +282,20 @@ trait GenericCatalogTests extends SharedXDContextTest with CatalogConstants { it should "check if persisted tables are marked as not temporary" in { xdContext.catalog.dropAllTables() - val crossdataTable1 = CrossdataTable(TableIdentifier(TableName, Some(Database)).normalize, Some(Columns), SourceDatasource, Array(Field1Name), OptsJSON) + val crossdataTable1 = CrossdataTable(TableIdentifier(TableName, Some(Database)).normalize, + Some(Columns), + SourceDatasource, + Array(Field1Name), + OptsJSON) val tableIdentifier2 = TableIdentifier(TableName) xdContext.catalog.persistTableMetadata(crossdataTable1) xdContext.catalog.registerTable(tableIdentifier2, LogicalRelation(new MockBaseRelation)) val tables = xdContext.catalog.getTables(None).toMap - if(xdContext.conf.caseSensitiveAnalysis){ + if (xdContext.conf.caseSensitiveAnalysis) { tables(s"$Database.$TableName") shouldBe false tables(TableName) shouldBe true - } else{ + } else { tables(s"${Database.toLowerCase}.${TableName.toLowerCase}") shouldBe false tables(TableName.toLowerCase) shouldBe true } @@ -244,18 +304,21 @@ trait GenericCatalogTests extends SharedXDContextTest with CatalogConstants { it should "describe a table persisted and non persisted with subcolumns" in { xdContext.catalog.dropAllTables() - val crossdataTable = CrossdataTable(TableIdentifier(TableName, Some(Database)).normalize, Some(ColumnsWithSubColumns), SourceDatasource, Array.empty, OptsJSON) + val crossdataTable = CrossdataTable(TableIdentifier(TableName, Some(Database)).normalize, + Some(ColumnsWithSubColumns), + SourceDatasource, + Array.empty, + OptsJSON) xdContext.catalog.persistTableMetadata(crossdataTable) xdContext.sql(s"DESCRIBE $Database.$TableName").count() should not be 0 } - - it should s"persist an App in catalog "in { + it should s"persist an App in catalog " in { val crossdataApp = CrossdataApp("hdfs://url/myjar.jar", "myApp", "com.stratio.app.main") xdContext.catalog.persistAppMetadata(crossdataApp) - val res:CrossdataApp=xdContext.catalog.lookupApp("myApp").get + val res: CrossdataApp = xdContext.catalog.lookupApp("myApp").get res shouldBe a[CrossdataApp] res.jar shouldBe "hdfs://url/myjar.jar" res.appAlias shouldBe "myApp" @@ -270,9 +333,11 @@ trait GenericCatalogTests extends SharedXDContextTest with CatalogConstants { val dataSource = "mongo" val opts = Map[String, String]() val version = "1.5.0" - val crossdataIndex = 
CrossdataIndex(tableIdentifier, indexIdentifier, indexedCols, pk, dataSource, opts, version) + val crossdataIndex = + CrossdataIndex(tableIdentifier, indexIdentifier, indexedCols, pk, dataSource, opts, version) - val crossdataTable = CrossdataTable(tableIdentifier, Some(Columns), SourceDatasource, Array(Field1Name), OptsJSON) + val crossdataTable = + CrossdataTable(tableIdentifier, Some(Columns), SourceDatasource, Array(Field1Name), OptsJSON) xdContext.catalog.persistTableMetadata(crossdataTable) xdContext.catalog.persistIndex(crossdataIndex) @@ -291,9 +356,11 @@ trait GenericCatalogTests extends SharedXDContextTest with CatalogConstants { val dataSource = "mongo" val opts = Map[String, String]() val version = "1.5.0" - val crossdataIndex = CrossdataIndex(tableIdentifier, indexIdentifier, indexedCols, pk, dataSource, opts, version) + val crossdataIndex = + CrossdataIndex(tableIdentifier, indexIdentifier, indexedCols, pk, dataSource, opts, version) - val crossdataTable = CrossdataTable(tableIdentifier, Some(Columns), SourceDatasource, Array(Field1Name), OptsJSON) + val crossdataTable = + CrossdataTable(tableIdentifier, Some(Columns), SourceDatasource, Array(Field1Name), OptsJSON) xdContext.catalog.persistTableMetadata(crossdataTable) xdContext.catalog.persistIndex(crossdataIndex) @@ -311,9 +378,11 @@ trait GenericCatalogTests extends SharedXDContextTest with CatalogConstants { val dataSource = "mongo" val opts = Map[String, String]() val version = "1.5.0" - val crossdataIndex = CrossdataIndex(tableIdentifier, indexIdentifier, indexedCols, pk, dataSource, opts, version) + val crossdataIndex = + CrossdataIndex(tableIdentifier, indexIdentifier, indexedCols, pk, dataSource, opts, version) - val crossdataTable = CrossdataTable(tableIdentifier, Some(Columns), SourceDatasource, Array(Field1Name), OptsJSON) + val crossdataTable = + CrossdataTable(tableIdentifier, Some(Columns), SourceDatasource, Array(Field1Name), OptsJSON) xdContext.catalog.persistTableMetadata(crossdataTable) xdContext.catalog.persistIndex(crossdataIndex) @@ -331,9 +400,11 @@ trait GenericCatalogTests extends SharedXDContextTest with CatalogConstants { val dataSource = "mongo" val opts = Map[String, String]() val version = "1.5.0" - val crossdataIndex = CrossdataIndex(tableIdentifier, indexIdentifier, indexedCols, pk, dataSource, opts, version) + val crossdataIndex = + CrossdataIndex(tableIdentifier, indexIdentifier, indexedCols, pk, dataSource, opts, version) - val crossdataTable = CrossdataTable(tableIdentifier, Some(Columns), SourceDatasource, Array(Field1Name), OptsJSON) + val crossdataTable = + CrossdataTable(tableIdentifier, Some(Columns), SourceDatasource, Array(Field1Name), OptsJSON) xdContext.catalog.persistTableMetadata(crossdataTable) xdContext.catalog.persistIndex(crossdataIndex) @@ -351,8 +422,11 @@ trait GenericCatalogTests extends SharedXDContextTest with CatalogConstants { val table = CrossdataTable(tableIdentifier, None, "com.stratio.crossdata.connector.mongodb") - val index = CrossdataIndex(tableIdentifier, indexIdentifier, - Seq(), "pk", "com.stratio.crossdata.connector.elasticsearch") + val index = CrossdataIndex(tableIdentifier, + indexIdentifier, + Seq(), + "pk", + "com.stratio.crossdata.connector.elasticsearch") xdContext.catalog.persistTable(table, LocalRelation()) xdContext.catalog.persistIndex(index) @@ -370,11 +444,15 @@ trait GenericCatalogTests extends SharedXDContextTest with CatalogConstants { val table = CrossdataTable(tableIdentifier, None, "com.stratio.crossdata.connector.mongodb") - val index = 
CrossdataIndex(tableIdentifier, indexIdentifier, - Seq(), "pk", "com.stratio.crossdata.connector.elasticsearch") + val index = CrossdataIndex(tableIdentifier, + indexIdentifier, + Seq(), + "pk", + "com.stratio.crossdata.connector.elasticsearch") val tableGenerated = CrossdataTable(indexIdentifier.asTableIdentifierNormalized, - None, "com.stratio.crossdata.connector.elasticsearch") + None, + "com.stratio.crossdata.connector.elasticsearch") xdContext.catalog.persistTable(table, LocalRelation()) xdContext.catalog.persistIndex(index) diff --git a/core/src/test/scala/org/apache/spark/sql/crossdata/catalog/persistent/ZookeeperCatalogIT.scala b/core/src/test/scala/org/apache/spark/sql/crossdata/catalog/persistent/ZookeeperCatalogIT.scala index d111699cb..8429e560e 100644 --- a/core/src/test/scala/org/apache/spark/sql/crossdata/catalog/persistent/ZookeeperCatalogIT.scala +++ b/core/src/test/scala/org/apache/spark/sql/crossdata/catalog/persistent/ZookeeperCatalogIT.scala @@ -27,15 +27,20 @@ import scala.util.Try @RunWith(classOf[JUnitRunner]) class ZookeeperCatalogIT extends { val catalogName = "Zookeeper" -} with SharedXDContextTest with CatalogConstants with GenericCatalogTests with ZookeeperDefaultTestConstants{ +} with SharedXDContextTest with CatalogConstants with GenericCatalogTests +with ZookeeperDefaultTestConstants { - override val coreConfig : Option[Config] = { - val zkResourceConfig = - Try(ConfigFactory.load("zookeeper-catalog.conf").getConfig(CoreConfig.ParentConfigName)).toOption + override val coreConfig: Option[Config] = { + val zkResourceConfig = Try( + ConfigFactory + .load("zookeeper-catalog.conf") + .getConfig(CoreConfig.ParentConfigName)).toOption - ZookeeperConnection.fold(zkResourceConfig) {connectionString => - zkResourceConfig.flatMap(resourceConfig => - Option(resourceConfig.withValue(ZookeeperConnectionKey, ConfigValueFactory.fromAnyRef(connectionString)))) + ZookeeperConnection.fold(zkResourceConfig) { connectionString => + zkResourceConfig.flatMap( + resourceConfig => + Option(resourceConfig.withValue(ZookeeperConnectionKey, + ConfigValueFactory.fromAnyRef(connectionString)))) } } @@ -43,6 +48,6 @@ class ZookeeperCatalogIT extends { sealed trait ZookeeperDefaultTestConstants { val ZookeeperConnectionKey = "catalog.zookeeper.connectionString" - val ZookeeperConnection: Option[String] = - Try(ConfigFactory.load().getString(ZookeeperConnectionKey)).toOption -} \ No newline at end of file + val ZookeeperConnection: Option[String] = Try( + ConfigFactory.load().getString(ZookeeperConnectionKey)).toOption +} diff --git a/core/src/test/scala/org/apache/spark/sql/crossdata/catalog/streaming/ZookeeperStreamingCatalogIT.scala b/core/src/test/scala/org/apache/spark/sql/crossdata/catalog/streaming/ZookeeperStreamingCatalogIT.scala index c8b21212a..3ee2c455e 100644 --- a/core/src/test/scala/org/apache/spark/sql/crossdata/catalog/streaming/ZookeeperStreamingCatalogIT.scala +++ b/core/src/test/scala/org/apache/spark/sql/crossdata/catalog/streaming/ZookeeperStreamingCatalogIT.scala @@ -26,19 +26,23 @@ import org.scalatest.junit.JUnitRunner import scala.util.Try @RunWith(classOf[JUnitRunner]) -class ZookeeperStreamingCatalogIT extends SharedXDContextTest with CatalogConstants with ZookeeperStreamingDefaultTestConstants { +class ZookeeperStreamingCatalogIT + extends SharedXDContextTest + with CatalogConstants + with ZookeeperStreamingDefaultTestConstants { override val coreConfig: Option[Config] = { - val zkResourceConfig = - 
Try(ConfigFactory.load("core-reference.conf").getConfig(CoreConfig.ParentConfigName)).toOption + val zkResourceConfig = Try( + ConfigFactory.load("core-reference.conf").getConfig(CoreConfig.ParentConfigName)).toOption ZookeeperConnection.fold(zkResourceConfig) { connectionString => - zkResourceConfig.flatMap(resourceConfig => - Option(resourceConfig.withValue(ZookeeperStreamingConnectionKey, ConfigValueFactory.fromAnyRef(connectionString)))) + zkResourceConfig.flatMap( + resourceConfig => + Option(resourceConfig.withValue(ZookeeperStreamingConnectionKey, + ConfigValueFactory.fromAnyRef(connectionString)))) } } - s"ZookeeperStreamingCatalogSpec" should "persist ephemeral tables" in { val streamCatalog = xdContext.catalog @@ -85,7 +89,7 @@ class ZookeeperStreamingCatalogIT extends SharedXDContextTest with CatalogConsta val streamCatalog = xdContext.catalog - an [Exception] should be thrownBy streamCatalog.dropEphemeralTable("stronker") + an[Exception] should be thrownBy streamCatalog.dropEphemeralTable("stronker") } it should "not fail when droppingAll ephemeral tables even though the catalog is empty" in { @@ -103,28 +107,30 @@ class ZookeeperStreamingCatalogIT extends SharedXDContextTest with CatalogConsta streamCatalog.createEphemeralTable(EphemeralTable) streamCatalog.updateEphemeralStatus( - EphemeralTableName, - EphemeralStatusModel(EphemeralTableName, EphemeralExecutionStatus.Started) + EphemeralTableName, + EphemeralStatusModel(EphemeralTableName, EphemeralExecutionStatus.Started) ) - the [Exception] thrownBy { + the[Exception] thrownBy { streamCatalog.dropEphemeralTable(EphemeralTableName) } should have message "The ephemeral is running. The process should be stopped first using 'Stop '" streamCatalog.updateEphemeralStatus( - EphemeralTableName, - EphemeralStatusModel(EphemeralTableName, EphemeralExecutionStatus.Stopped) + EphemeralTableName, + EphemeralStatusModel(EphemeralTableName, EphemeralExecutionStatus.Stopped) ) streamCatalog.dropEphemeralTable(EphemeralTableName) } - it should "manage ephemeral table status" in { val streamCatalog = xdContext.catalog streamCatalog.createEphemeralTable(EphemeralTable) streamCatalog.getEphemeralStatus(EphemeralTableName).isDefined shouldBe true - streamCatalog.getEphemeralStatus(EphemeralTableName).get.status shouldBe EphemeralExecutionStatus.NotStarted + streamCatalog + .getEphemeralStatus(EphemeralTableName) + .get + .status shouldBe EphemeralExecutionStatus.NotStarted streamCatalog.dropEphemeralTable(EphemeralTableName) streamCatalog.getEphemeralStatus(EphemeralTableName).isDefined shouldBe false @@ -144,10 +150,9 @@ class ZookeeperStreamingCatalogIT extends SharedXDContextTest with CatalogConsta } - /** - * Stop the underlying [[org.apache.spark.SparkContext]], if any. - */ + * Stop the underlying [[org.apache.spark.SparkContext]], if any. 
+ */ protected override def afterAll(): Unit = { xdContext.catalog.dropAllEphemeralTables() super.afterAll() @@ -157,22 +162,24 @@ class ZookeeperStreamingCatalogIT extends SharedXDContextTest with CatalogConsta sealed trait ZookeeperStreamingDefaultTestConstants { val ZookeeperStreamingConnectionKey = "streaming.catalog.zookeeper.connectionString" - val ZookeeperConnection: Option[String] = - Try(ConfigFactory.load().getString(ZookeeperStreamingConnectionKey)).toOption + val ZookeeperConnection: Option[String] = Try( + ConfigFactory.load().getString(ZookeeperStreamingConnectionKey)).toOption // Ephemeral table val EphemeralTableName = "epheTable" val KafkaOptions = KafkaOptionsModel( - ConnectionHostModel(Seq(ConnectionModel("zkHost", 2020)), Seq(ConnectionModel("kafkaHost", 2125))), - Seq(TopicModel("topic", 1)), - "groupId", None, - Map("key" -> "value"), - "MEMORY_AND_DISK" ) - val EphemeralTableOptions = EphemeralOptionsModel(KafkaOptions,5) + ConnectionHostModel(Seq(ConnectionModel("zkHost", 2020)), + Seq(ConnectionModel("kafkaHost", 2125))), + Seq(TopicModel("topic", 1)), + "groupId", + None, + Map("key" -> "value"), + "MEMORY_AND_DISK") + val EphemeralTableOptions = EphemeralOptionsModel(KafkaOptions, 5) val EphemeralTable = EphemeralTableModel(EphemeralTableName, EphemeralTableOptions) //Queries val QueryAlias = "qalias" val Sql = "select * from epheTable" val EphemeralQuery = EphemeralQueryModel(EphemeralTableName, Sql, QueryAlias, 5, Map.empty) -} \ No newline at end of file +} diff --git a/core/src/test/scala/org/apache/spark/sql/crossdata/catalog/temporary/XDTemporaryCatalogTests.scala b/core/src/test/scala/org/apache/spark/sql/crossdata/catalog/temporary/XDTemporaryCatalogTests.scala index 28a838eea..8dcdf53d4 100644 --- a/core/src/test/scala/org/apache/spark/sql/crossdata/catalog/temporary/XDTemporaryCatalogTests.scala +++ b/core/src/test/scala/org/apache/spark/sql/crossdata/catalog/temporary/XDTemporaryCatalogTests.scala @@ -41,7 +41,8 @@ trait XDTemporaryCatalogTests extends SharedXDContextTest with CatalogConstants val columns = StructType(fields) val opts = Map("path" -> "/fake_path") val tableIdentifier = TableIdentifier(TableName).normalize - val crossdataTable = CrossdataTable(tableIdentifier, Some(Columns), SourceDatasource, Array.empty, opts) + val crossdataTable = + CrossdataTable(tableIdentifier, Some(Columns), SourceDatasource, Array.empty, opts) temporaryCatalog.relation(tableIdentifier) shouldBe empty @@ -52,7 +53,8 @@ trait XDTemporaryCatalogTests extends SharedXDContextTest with CatalogConstants it should s"register a table with catalog and partitionColumns in $catalogName" in { val tableIdentifier = TableIdentifier(TableName, Some(Database)).normalize - val crossdataTable = CrossdataTable(tableIdentifier, Some(Columns), SourceDatasource, Array(Field1Name), OptsJSON) + val crossdataTable = + CrossdataTable(tableIdentifier, Some(Columns), SourceDatasource, Array(Field1Name), OptsJSON) temporaryCatalog.saveTable(tableIdentifier, null, Some(crossdataTable)) @@ -60,30 +62,41 @@ trait XDTemporaryCatalogTests extends SharedXDContextTest with CatalogConstants } - it should s"register a table with catalog and partitionColumns with multiple subdocuments as schema in $catalogName" in { temporaryCatalog.dropAllTables() val tableIdentifier = TableIdentifier(TableName, Some(Database)).normalize - val crossdataTable = CrossdataTable(tableIdentifier, Some(ColumnsWithSubColumns), SourceDatasource, Array.empty, OptsJSON) + val crossdataTable = CrossdataTable(tableIdentifier, 
+ Some(ColumnsWithSubColumns), + SourceDatasource, + Array.empty, + OptsJSON) temporaryCatalog.saveTable(tableIdentifier, null, Some(crossdataTable)) temporaryCatalog.relation(tableIdentifier) shouldBe defined } - it should "returns list of tables" in { temporaryCatalog.dropAllTables() val tableIdentifier1 = TableIdentifier(TableName, Some(Database)).normalize val tableIdentifier2 = TableIdentifier(TableName, None).normalize - val crossdataTable1 = CrossdataTable(tableIdentifier1, Some(Columns), SourceDatasource, Array(Field1Name), OptsJSON) - val crossdataTable2 = CrossdataTable(tableIdentifier2, Some(Columns), SourceDatasource, Array(Field1Name), OptsJSON) + val crossdataTable1 = CrossdataTable(tableIdentifier1, + Some(Columns), + SourceDatasource, + Array(Field1Name), + OptsJSON) + val crossdataTable2 = CrossdataTable(tableIdentifier2, + Some(Columns), + SourceDatasource, + Array(Field1Name), + OptsJSON) temporaryCatalog.saveTable(tableIdentifier1, null, Some(crossdataTable1)) temporaryCatalog.saveTable(tableIdentifier2, null, Some(crossdataTable2)) - val tables = temporaryCatalog.allRelations(Some(StringNormalized(normalizeIdentifier(Database, conf)))) + val tables = + temporaryCatalog.allRelations(Some(StringNormalized(normalizeIdentifier(Database, conf)))) tables should have length 1 val tables2 = temporaryCatalog.allRelations() @@ -108,7 +121,6 @@ trait XDTemporaryCatalogTests extends SharedXDContextTest with CatalogConstants temporaryCatalog.relation(viewIdentifier) shouldBe empty } - protected override def beforeAll(): Unit = { super.beforeAll() implicitContext = _xdContext diff --git a/core/src/test/scala/org/apache/spark/sql/crossdata/catalyst/analysis/AnalysisTest.scala b/core/src/test/scala/org/apache/spark/sql/crossdata/catalyst/analysis/AnalysisTest.scala index 30d876f2f..403cc5f17 100644 --- a/core/src/test/scala/org/apache/spark/sql/crossdata/catalyst/analysis/AnalysisTest.scala +++ b/core/src/test/scala/org/apache/spark/sql/crossdata/catalyst/analysis/AnalysisTest.scala @@ -31,11 +31,10 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util.sideBySide import org.apache.spark.sql.types.StringType -trait AnalysisTest extends BaseXDTest{ +trait AnalysisTest extends BaseXDTest { - val testRelation = LocalRelation( - AttributeReference("col1", StringType)(), - AttributeReference("col2", StringType)()) + val testRelation = LocalRelation(AttributeReference("col1", StringType)(), + AttributeReference("col2", StringType)()) val caseSensitiveAnalyzer = { val caseSensitiveConf = new SimpleCatalystConf(true) @@ -47,8 +46,7 @@ trait AnalysisTest extends BaseXDTest{ } } - protected def checkAnalysis(inputPlan: LogicalPlan, - expectedPlan: LogicalPlan): Unit = { + protected def checkAnalysis(inputPlan: LogicalPlan, expectedPlan: LogicalPlan): Unit = { val analyzer = caseSensitiveAnalyzer val actualPlan = analyzer.execute(inputPlan) analyzer.checkAnalysis(actualPlan) @@ -68,8 +66,7 @@ trait AnalysisTest extends BaseXDTest{ val normalized1 = normalizeExprIds(plan1) val normalized2 = normalizeExprIds(plan2) if (normalized1 != normalized2) { - fail( - s""" + fail(s""" |== FAIL: Plans do not match === |${sideBySide(normalized1.treeString, normalized2.treeString).mkString("\n")} """.stripMargin) diff --git a/core/src/test/scala/org/apache/spark/sql/crossdata/catalyst/analysis/CrossdataAnalysisSpec.scala b/core/src/test/scala/org/apache/spark/sql/crossdata/catalyst/analysis/CrossdataAnalysisSpec.scala index a5220d15c..8a51abb0a 100644 --- 
a/core/src/test/scala/org/apache/spark/sql/crossdata/catalyst/analysis/CrossdataAnalysisSpec.scala +++ b/core/src/test/scala/org/apache/spark/sql/crossdata/catalyst/analysis/CrossdataAnalysisSpec.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.crossdata.catalyst.analysis import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ -class CrossdataAnalysisSpec extends AnalysisTest{ +class CrossdataAnalysisSpec extends AnalysisTest { "CrossdataAggregateAlias rule" should "resolve alias references within the group by clause" in { val col1 = testRelation.output(0) diff --git a/core/src/test/scala/org/apache/spark/sql/crossdata/config/StreamingConfigSpec.scala b/core/src/test/scala/org/apache/spark/sql/crossdata/config/StreamingConfigSpec.scala index 9db4502f1..34e7f4e11 100644 --- a/core/src/test/scala/org/apache/spark/sql/crossdata/config/StreamingConfigSpec.scala +++ b/core/src/test/scala/org/apache/spark/sql/crossdata/config/StreamingConfigSpec.scala @@ -25,26 +25,26 @@ class StreamingConfigSpec extends BaseXDTest { val EphemeralTableName = "ephtable" val KafkaGroupId = "xd1" - val KafkaTopic= "ephtable" + val KafkaTopic = "ephtable" val KafkaNumPartitions = 1 val EmptyTableOptions: Map[String, String] = Map.empty val MandatoryTableOptions: Map[String, String] = Map( - "receiver.kafka.topic" -> s"$KafkaTopic:$KafkaNumPartitions", - "receiver.kafka.groupId" -> KafkaGroupId + "receiver.kafka.topic" -> s"$KafkaTopic:$KafkaNumPartitions", + "receiver.kafka.groupId" -> KafkaGroupId ) val CompleteTableOptions: Map[String, String] = Map( - "receiver.storageLevel" -> "MEMORY_AND_DISK", - "atomicWindow" -> "10", - "outputFormat" -> "JSON", - "spark.cores.max" -> "3" - ) ++ MandatoryTableOptions - + "receiver.storageLevel" -> "MEMORY_AND_DISK", + "atomicWindow" -> "10", + "outputFormat" -> "JSON", + "spark.cores.max" -> "3" + ) ++ MandatoryTableOptions it should "add default options to the ephemeral table" in { - val ephTable = StreamingConfig.createEphemeralTableModel(EphemeralTableName, MandatoryTableOptions) + val ephTable = + StreamingConfig.createEphemeralTableModel(EphemeralTableName, MandatoryTableOptions) val options = ephTable.options options.atomicWindow shouldBe 5 @@ -52,12 +52,12 @@ class StreamingConfigSpec extends BaseXDTest { options.maxWindow shouldBe 10 options.outputFormat shouldBe EphemeralOutputFormat.ROW - options.sparkOptions should contain ("spark.cores.max", "2") - options.sparkOptions should contain ("spark.stopGracefully", "true") + options.sparkOptions should contain("spark.cores.max", "2") + options.sparkOptions should contain("spark.stopGracefully", "true") options.kafkaOptions.connection shouldBe ConnectionHostModel( - Seq(ConnectionModel("localhost", 2181)), - Seq(ConnectionModel("localhost", 9092))) + Seq(ConnectionModel("localhost", 2181)), + Seq(ConnectionModel("localhost", 9092))) options.kafkaOptions.storageLevel shouldBe "MEMORY_AND_DISK_SER" // table options @@ -67,19 +67,21 @@ class StreamingConfigSpec extends BaseXDTest { } it should "override default options" in { - val ephTable = StreamingConfig.createEphemeralTableModel(EphemeralTableName, CompleteTableOptions) + val ephTable = + StreamingConfig.createEphemeralTableModel(EphemeralTableName, CompleteTableOptions) val options = ephTable.options options.atomicWindow shouldBe 10 options.outputFormat shouldBe EphemeralOutputFormat.JSON - options.sparkOptions should contain ("spark.cores.max", "3") + options.sparkOptions should contain("spark.cores.max", "3") 
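// ---------------------------------------------------------------------------
// Illustrative sketch (editor's note, not part of this patch): the assertions
// around this point exercise StreamingConfig.createEphemeralTableModel, which
// takes a table name plus a String -> String options map, fills in defaults
// (window of 5, ROW output, spark.cores.max = 2) and rejects spark.cores.max
// below 2. A minimal, hedged usage sketch; the table name "sensor_readings"
// is hypothetical, and the import assumes StreamingConfig lives in
// org.apache.spark.sql.crossdata.config, as this spec's path suggests.
// ---------------------------------------------------------------------------
object StreamingConfigSketch extends App {
  import org.apache.spark.sql.crossdata.config.StreamingConfig // assumed package

  val sketchOptions = Map(
    "receiver.kafka.topic"   -> "sensor_readings:1", // mandatory, "topic:numPartitions"
    "receiver.kafka.groupId" -> "xd-group",          // mandatory Kafka consumer group
    "atomicWindow"           -> "10",                // overrides the default window of 5
    "outputFormat"           -> "JSON",              // overrides the default ROW format
    "spark.cores.max"        -> "3"                  // must stay >= 2 or creation throws
  )

  val sketchTable = StreamingConfig.createEphemeralTableModel("sensor_readings", sketchOptions)

  // The overridden values come back on the model, as the assertions above check.
  println(sketchTable.options.atomicWindow) // 10
  println(sketchTable.options.outputFormat) // EphemeralOutputFormat.JSON
}
// ---------------------------------------------------------------------------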
options.kafkaOptions.storageLevel shouldBe "MEMORY_AND_DISK" } it should "fail if spark.cores.max is less than 2" in { val wrongOptions = CompleteTableOptions + (StreamingConstants.SparkCoresMax -> "1") - an [Exception] should be thrownBy StreamingConfig.createEphemeralTableModel(EphemeralTableName, wrongOptions) + an[Exception] should be thrownBy StreamingConfig.createEphemeralTableModel(EphemeralTableName, + wrongOptions) } } diff --git a/core/src/test/scala/org/apache/spark/sql/crossdata/execution/datasources/DdlSpec.scala b/core/src/test/scala/org/apache/spark/sql/crossdata/execution/datasources/DdlSpec.scala index 8edffd522..dd6cab502 100644 --- a/core/src/test/scala/org/apache/spark/sql/crossdata/execution/datasources/DdlSpec.scala +++ b/core/src/test/scala/org/apache/spark/sql/crossdata/execution/datasources/DdlSpec.scala @@ -28,113 +28,128 @@ import org.scalatest.mock.MockitoSugar import scala.util.Success @RunWith(classOf[JUnitRunner]) -class DdlSpec extends BaseXDTest with MockitoSugar{ +class DdlSpec extends BaseXDTest with MockitoSugar { + "Ddl" should "successfully convert from ByteType to Byte" in { - "Ddl" should "successfully convert from ByteType to Byte" in { - - DDLUtils.convertSparkDatatypeToScala("4", ByteType) shouldBe Success(4 : Byte) + DDLUtils.convertSparkDatatypeToScala("4", ByteType) shouldBe Success(4: Byte) } - "Ddl" should "successfully convert from ShortType to Short" in { + "Ddl" should "successfully convert from ShortType to Short" in { - DDLUtils.convertSparkDatatypeToScala("6", ShortType) shouldBe Success(6 : Short) + DDLUtils.convertSparkDatatypeToScala("6", ShortType) shouldBe Success(6: Short) } - "Ddl" should "successfully convert from IntegerType to Integer" in { + "Ddl" should "successfully convert from IntegerType to Integer" in { - DDLUtils.convertSparkDatatypeToScala("25", IntegerType) shouldBe Success(25 : Int) + DDLUtils.convertSparkDatatypeToScala("25", IntegerType) shouldBe Success(25: Int) } - "Ddl" should "successfully convert from LongType to Long" in { + "Ddl" should "successfully convert from LongType to Long" in { - DDLUtils.convertSparkDatatypeToScala("-127", LongType) shouldBe Success(-127 : Long) + DDLUtils.convertSparkDatatypeToScala("-127", LongType) shouldBe Success(-127: Long) } - "Ddl" should "successfully convert from FloatType to Float" in { + "Ddl" should "successfully convert from FloatType to Float" in { - DDLUtils.convertSparkDatatypeToScala("-1.01", FloatType) shouldBe Success(-1.01f : Float) + DDLUtils.convertSparkDatatypeToScala("-1.01", FloatType) shouldBe Success(-1.01f: Float) } - "Ddl" should "successfully convert from DoubleType to Double" in { + "Ddl" should "successfully convert from DoubleType to Double" in { - DDLUtils.convertSparkDatatypeToScala("3.75", DoubleType) shouldBe Success(3.75 : Double) + DDLUtils.convertSparkDatatypeToScala("3.75", DoubleType) shouldBe Success(3.75: Double) } - "Ddl" should "successfully convert from DecimalType to BigDecimal" in { + "Ddl" should "successfully convert from DecimalType to BigDecimal" in { - DDLUtils.convertSparkDatatypeToScala("-106.75", DecimalType.SYSTEM_DEFAULT) shouldBe Success(BigDecimal(-106.75)) + DDLUtils.convertSparkDatatypeToScala("-106.75", DecimalType.SYSTEM_DEFAULT) shouldBe Success( + BigDecimal(-106.75)) } - "Ddl" should "successfully convert from StringType to String" in { + "Ddl" should "successfully convert from StringType to String" in { DDLUtils.convertSparkDatatypeToScala("abcde", StringType) shouldBe Success("abcde") } - "Ddl" should 
"successfully convert from BooleanType to Boolean" in { + "Ddl" should "successfully convert from BooleanType to Boolean" in { - DDLUtils.convertSparkDatatypeToScala("false", BooleanType) shouldBe Success(false : Boolean) + DDLUtils.convertSparkDatatypeToScala("false", BooleanType) shouldBe Success(false: Boolean) } - "Ddl" should "successfully convert from DateType to Date" in { + "Ddl" should "successfully convert from DateType to Date" in { - DDLUtils.convertSparkDatatypeToScala("2015-01-01", DateType) shouldBe Success(Date.valueOf("2015-01-01")) + DDLUtils.convertSparkDatatypeToScala("2015-01-01", DateType) shouldBe Success( + Date.valueOf("2015-01-01")) } - "Ddl" should "successfully convert from TimestampType to Timestamp" in { + "Ddl" should "successfully convert from TimestampType to Timestamp" in { - DDLUtils.convertSparkDatatypeToScala("1988-08-11 11:12:13", TimestampType) shouldBe Success(Timestamp.valueOf("1988-08-11 11:12:13")) + DDLUtils.convertSparkDatatypeToScala("1988-08-11 11:12:13", TimestampType) shouldBe Success( + Timestamp.valueOf("1988-08-11 11:12:13")) } - "Ddl" should "successfully convert from ArrayType to Array" in { + "Ddl" should "successfully convert from ArrayType to Array" in { - DDLUtils.convertSparkDatatypeToScala(List("1","2","3"), ArrayType(IntegerType)) shouldBe Success(Seq(1,2,3)) + DDLUtils + .convertSparkDatatypeToScala(List("1", "2", "3"), ArrayType(IntegerType)) shouldBe Success( + Seq(1, 2, 3)) - DDLUtils.convertSparkDatatypeToScala(List("1","2","3"), ArrayType(StringType)) shouldBe Success(Seq("1","2","3")) + DDLUtils + .convertSparkDatatypeToScala(List("1", "2", "3"), ArrayType(StringType)) shouldBe Success( + Seq("1", "2", "3")) - DDLUtils.convertSparkDatatypeToScala(List("proof one", "proof, two","proof three"), - ArrayType(StringType)) shouldBe Success(Seq("proof one", "proof, two","proof three")) + DDLUtils.convertSparkDatatypeToScala(List("proof one", "proof, two", "proof three"), + ArrayType(StringType)) shouldBe Success( + Seq("proof one", "proof, two", "proof three")) - DDLUtils.convertSparkDatatypeToScala(List("true"), - ArrayType(BooleanType)) shouldBe Success(Seq(true)) + DDLUtils.convertSparkDatatypeToScala(List("true"), ArrayType(BooleanType)) shouldBe Success( + Seq(true)) } - "Ddl" should "successfully convert from MapType to Map" in { + "Ddl" should "successfully convert from MapType to Map" in { - DDLUtils.convertSparkDatatypeToScala(Map("x"->"1","y"->"2"), MapType(StringType,IntegerType)) shouldBe Success(Map(("x",1),("y",2))) + DDLUtils.convertSparkDatatypeToScala(Map("x" -> "1", "y" -> "2"), + MapType(StringType, IntegerType)) shouldBe Success( + Map(("x", 1), ("y", 2))) - DDLUtils.convertSparkDatatypeToScala(Map("x1" -> "proof,comma","x2" -> "proof2"), - MapType(StringType,StringType)) shouldBe Success(Map("x1" -> "proof,comma","x2" -> "proof2")) + DDLUtils.convertSparkDatatypeToScala(Map("x1" -> "proof,comma", "x2" -> "proof2"), + MapType(StringType, StringType)) shouldBe Success( + Map("x1" -> "proof,comma", "x2" -> "proof2")) DDLUtils.convertSparkDatatypeToScala(Map("1" -> "true", "2" -> "false", "3" -> "true"), - MapType(IntegerType,BooleanType)) shouldBe Success(Map(1 -> true, 2-> false, 3 -> true)) + MapType(IntegerType, BooleanType)) shouldBe Success( + Map(1 -> true, 2 -> false, 3 -> true)) } - "Ddl" should "successfully convert from MapType with ArrayType and viceversa" in { - - DDLUtils.convertSparkDatatypeToScala(List(Map("x" -> "1", "y" -> "2"),Map("z"-> "3")), - ArrayType(MapType(StringType,IntegerType))) 
shouldBe Success(Seq( Map("x" -> 1, "y"-> 2 ), Map( "z" -> 3))) - - DDLUtils.convertSparkDatatypeToScala(Map("x" -> List("3","4"), "y" -> List("5","6")), - MapType(StringType, ArrayType(IntegerType))) shouldBe Success(Map( "x" ->Seq(3,4), "y" -> Seq(5,6) )) + "Ddl" should "successfully convert from MapType with ArrayType and viceversa" in { - DDLUtils.convertSparkDatatypeToScala(Map("true" -> List("3","4"), "false" -> List("5","6")), - MapType(BooleanType, ArrayType(IntegerType))) shouldBe Success(Map( true ->Seq(3,4), false -> Seq(5,6) )) + DDLUtils.convertSparkDatatypeToScala( + List(Map("x" -> "1", "y" -> "2"), Map("z" -> "3")), + ArrayType(MapType(StringType, IntegerType))) shouldBe Success( + Seq(Map("x" -> 1, "y" -> 2), Map("z" -> 3))) + DDLUtils.convertSparkDatatypeToScala( + Map("x" -> List("3", "4"), "y" -> List("5", "6")), + MapType(StringType, ArrayType(IntegerType))) shouldBe Success( + Map("x" -> Seq(3, 4), "y" -> Seq(5, 6))) + DDLUtils.convertSparkDatatypeToScala( + Map("true" -> List("3", "4"), "false" -> List("5", "6")), + MapType(BooleanType, ArrayType(IntegerType))) shouldBe Success( + Map(true -> Seq(3, 4), false -> Seq(5, 6))) } diff --git a/core/src/test/scala/org/apache/spark/sql/crossdata/execution/datasources/DropTableIT.scala b/core/src/test/scala/org/apache/spark/sql/crossdata/execution/datasources/DropTableIT.scala index 738605b67..606b65aed 100644 --- a/core/src/test/scala/org/apache/spark/sql/crossdata/execution/datasources/DropTableIT.scala +++ b/core/src/test/scala/org/apache/spark/sql/crossdata/execution/datasources/DropTableIT.scala @@ -37,21 +37,33 @@ class DropTableIT extends SharedXDContextTest { private val Schema = StructType(Seq(StructField("col", StringType))) implicit def catalogToPersistenceWithCache(catalog: XDCatalog): PersistentCatalogWithCache = { - catalog.asInstanceOf[CatalogChain].persistentCatalogs.head.asInstanceOf[PersistentCatalogWithCache] + catalog + .asInstanceOf[CatalogChain] + .persistentCatalogs + .head + .asInstanceOf[PersistentCatalogWithCache] } implicit lazy val conf: CatalystConf = xdContext.catalog.conf "DropTable command" should "remove a table from Crossdata catalog" in { - _xdContext.catalog.persistTableMetadata(CrossdataTable(TableIdentifier(TableName, None).normalize, Some(Schema), DatasourceName, opts = Map("path" -> "fakepath"))) + _xdContext.catalog.persistTableMetadata( + CrossdataTable(TableIdentifier(TableName, None).normalize, + Some(Schema), + DatasourceName, + opts = Map("path" -> "fakepath"))) _xdContext.catalog.tableExists(TableIdentifier(TableName)) shouldBe true sql(s"DROP TABLE $TableName") _xdContext.catalog.tableExists(TableIdentifier(TableName)) shouldBe false } it should "remove a qualified table from Crossdata catalog" in { - _xdContext.catalog.persistTableMetadata(CrossdataTable(TableIdentifier(TableName, Some(DatabaseName)).normalize, Some(Schema), DatasourceName, opts = Map("path" -> "fakepath"))) + _xdContext.catalog.persistTableMetadata( + CrossdataTable(TableIdentifier(TableName, Some(DatabaseName)).normalize, + Some(Schema), + DatasourceName, + opts = Map("path" -> "fakepath"))) _xdContext.catalog.tableExists(TableIdentifier(TableName, Some(DatabaseName))) shouldBe true sql(s"DROP TABLE $DatabaseName.$TableName") _xdContext.catalog.tableExists(TableIdentifier(TableName, Some(DatabaseName))) shouldBe false diff --git a/core/src/test/scala/org/apache/spark/sql/crossdata/execution/datasources/StreamingDdlIT.scala 
b/core/src/test/scala/org/apache/spark/sql/crossdata/execution/datasources/StreamingDdlIT.scala index ae757a60e..94c41b41d 100644 --- a/core/src/test/scala/org/apache/spark/sql/crossdata/execution/datasources/StreamingDdlIT.scala +++ b/core/src/test/scala/org/apache/spark/sql/crossdata/execution/datasources/StreamingDdlIT.scala @@ -42,8 +42,8 @@ class StreamingDdlIT extends SharedXDContextTest with StreamingDDLTestConstants{ } /** - * Parser tests - */ + * Parser tests + */ "StreamingDDLParser" should "parse an add ephemeral query" in { val logicalPlan = xdContext.ddlParser.parse(s"ADD $Sql WITH WINDOW $Window SEC AS $QueryName") logicalPlan shouldBe AddEphemeralQuery(EphemeralTableName, Sql, QueryName, Window) @@ -79,8 +79,8 @@ class StreamingDdlIT extends SharedXDContextTest with StreamingDDLTestConstants{ } /** - * Command tests - */ + * Command tests + */ "StreamingCommand" should "allow to describe ephemeral tables" in { val ephTableIdentifier = TableIdentifier(EphemeralTableName, Some("db")) @@ -383,11 +383,11 @@ class StreamingDdlIT extends SharedXDContextTest with StreamingDDLTestConstants{ } /** - * Stop the underlying [[org.apache.spark.SparkContext]], if any. - */ + * Stop the underlying [[org.apache.spark.SparkContext]], if any. + */ protected override def afterAll(): Unit = { xdContext.streamingCatalog.foreach(_.dropAllEphemeralTables()) super.afterAll() } } -*/ \ No newline at end of file + */ diff --git a/core/src/test/scala/org/apache/spark/sql/crossdata/execution/datasources/StreamingDdlParserSpec.scala b/core/src/test/scala/org/apache/spark/sql/crossdata/execution/datasources/StreamingDdlParserSpec.scala index 6688f3123..ba57172df 100644 --- a/core/src/test/scala/org/apache/spark/sql/crossdata/execution/datasources/StreamingDdlParserSpec.scala +++ b/core/src/test/scala/org/apache/spark/sql/crossdata/execution/datasources/StreamingDdlParserSpec.scala @@ -30,25 +30,27 @@ import org.scalatest.mock.MockitoSugar import scala.util.Try - @RunWith(classOf[JUnitRunner]) class StreamingDdlParserSpec extends BaseXDTest with StreamingDDLTestConstants with MockitoSugar { val xdContext = mock[XDContext] val parser = new XDDdlParser(_ => null, xdContext) - // EPHEMERAL TABLE "StreamingDDLParser" should "parse a create ephemeral table" in { - val logicalPlan = - parser.parse(s"CREATE EPHEMERAL TABLE $EphemeralTableName (id STRING) OPTIONS (kafka.options.opKey 'value')") - logicalPlan shouldBe CreateEphemeralTable(EphemeralTableIdentifier, Some(EphemeralTableSchema), Map("kafka.options.opKey" -> "value")) + val logicalPlan = parser.parse( + s"CREATE EPHEMERAL TABLE $EphemeralTableName (id STRING) OPTIONS (kafka.options.opKey 'value')") + logicalPlan shouldBe CreateEphemeralTable(EphemeralTableIdentifier, + Some(EphemeralTableSchema), + Map("kafka.options.opKey" -> "value")) } it should "parse a create ephemeral table without schema" in { - val logicalPlan = - parser.parse(s"CREATE EPHEMERAL TABLE $EphemeralTableName OPTIONS (kafka.options.opKey 'value')") - logicalPlan shouldBe CreateEphemeralTable(EphemeralTableIdentifier, None, Map("kafka.options.opKey" -> "value")) + val logicalPlan = parser.parse( + s"CREATE EPHEMERAL TABLE $EphemeralTableName OPTIONS (kafka.options.opKey 'value')") + logicalPlan shouldBe CreateEphemeralTable(EphemeralTableIdentifier, + None, + Map("kafka.options.opKey" -> "value")) } it should "parse a describe ephemeral table" in { @@ -71,7 +73,6 @@ class StreamingDdlParserSpec extends BaseXDTest with StreamingDDLTestConstants w logicalPlan shouldBe 
DropAllEphemeralTables } - // STATUS it should "parse a show ephemeral status" in { val logicalPlan = parser.parse(s"SHOW EPHEMERAL STATUS IN $EphemeralTableName") @@ -99,14 +100,13 @@ class StreamingDdlParserSpec extends BaseXDTest with StreamingDDLTestConstants w logicalPlan shouldBe ShowEphemeralQueries(None) } - it should "parse a show ephemeral queries with a specific table" in { val logicalPlan = parser.parse(s"SHOW EPHEMERAL QUERIES IN $EphemeralTableName") logicalPlan shouldBe ShowEphemeralQueries(Some(EphemeralTableIdentifier.unquotedString)) } it should "fail parsing an add query statement without window" in { - an [Exception] should be thrownBy parser.parse(s"ADD $Sql AS topic") + an[Exception] should be thrownBy parser.parse(s"ADD $Sql AS topic") } it should "parse a drop ephemeral query" in { @@ -115,7 +115,6 @@ class StreamingDdlParserSpec extends BaseXDTest with StreamingDDLTestConstants w } - it should "parse a drop all ephemeral queries" in { val logicalPlan = parser.parse(s"DROP ALL EPHEMERAL QUERIES") DropAllEphemeralQueries() @@ -128,7 +127,6 @@ class StreamingDdlParserSpec extends BaseXDTest with StreamingDDLTestConstants w } - } trait StreamingDDLTestConstants { @@ -144,15 +142,16 @@ trait StreamingDDLTestConstants { val KafkaTopic = "ephtable" val KafkaNumPartitions = 1 Map( - "receiver.kafka.topic" -> s"$KafkaTopic:$KafkaNumPartitions", - "receiver.kafka.groupId" -> KafkaGroupId + "receiver.kafka.topic" -> s"$KafkaTopic:$KafkaNumPartitions", + "receiver.kafka.groupId" -> KafkaGroupId ) } - val EphemeralTable = StreamingConfig.createEphemeralTableModel(EphemeralTableName, MandatoryTableOptions) + val EphemeralTable = + StreamingConfig.createEphemeralTableModel(EphemeralTableName, MandatoryTableOptions) val EphemeralQuery = EphemeralQueryModel(EphemeralTableName, Sql, QueryName, Window) val ZookeeperStreamingConnectionKey = "streaming.catalog.zookeeper.connectionString" - val ZookeeperConnection: Option[String] = - Try(ConfigFactory.load().getString(ZookeeperStreamingConnectionKey)).toOption + val ZookeeperConnection: Option[String] = Try( + ConfigFactory.load().getString(ZookeeperStreamingConnectionKey)).toOption } diff --git a/core/src/test/scala/org/apache/spark/sql/crossdata/execution/datasources/ViewsIT.scala b/core/src/test/scala/org/apache/spark/sql/crossdata/execution/datasources/ViewsIT.scala index a3933ed3b..c18474de6 100644 --- a/core/src/test/scala/org/apache/spark/sql/crossdata/execution/datasources/ViewsIT.scala +++ b/core/src/test/scala/org/apache/spark/sql/crossdata/execution/datasources/ViewsIT.scala @@ -28,7 +28,7 @@ class ViewsIT extends SharedXDContextTest { val sqlContext = _xdContext import sqlContext.implicits._ - val df = sqlContext.sparkContext.parallelize(1 to 5).toDF + val df = sqlContext.sparkContext.parallelize(1 to 5).toDF df.registerTempTable("person") sql("CREATE TEMPORARY VIEW vn AS SELECT * FROM person WHERE _1 < 3") @@ -51,6 +51,4 @@ class ViewsIT extends SharedXDContextTest { } - - } diff --git a/core/src/test/scala/org/apache/spark/sql/crossdata/execution/datasources/XDDdlParserSpec.scala b/core/src/test/scala/org/apache/spark/sql/crossdata/execution/datasources/XDDdlParserSpec.scala index 736346296..802c42677 100644 --- a/core/src/test/scala/org/apache/spark/sql/crossdata/execution/datasources/XDDdlParserSpec.scala +++ b/core/src/test/scala/org/apache/spark/sql/crossdata/execution/datasources/XDDdlParserSpec.scala @@ -27,16 +27,15 @@ import org.scalatest.junit.JUnitRunner import org.scalatest.mock.MockitoSugar 
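// ---------------------------------------------------------------------------
// Illustrative sketch (editor's note, not part of this patch): the spec that
// follows builds an XDDdlParser over a mocked XDContext and checks the logical
// plan produced for each DDL sentence. A minimal sketch of that pattern,
// assuming it sits inside a spec like the one below (BaseXDTest with
// MockitoSugar brings `mock` and the ScalaTest matchers into scope); both
// expectations are copied from assertions in this file.
// ---------------------------------------------------------------------------
// val sketchParser = new XDDdlParser(_ => null, mock[XDContext])
//
// // A qualified DROP TABLE is parsed into a DropTable runnable command.
// sketchParser.parse("DROP TABLE dbId.tableId") shouldBe
//   DropTable(TableIdentifier("tableId", Some("dbId")))
//
// // IMPORT TABLES without OPTIONS yields an empty options map.
// sketchParser.parse("IMPORT TABLES USING org.apache.dumypackage.dummyclass") shouldBe
//   ImportTablesUsingWithOptions("org.apache.dumypackage.dummyclass", Map.empty)
// ---------------------------------------------------------------------------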
@RunWith(classOf[JUnitRunner]) -class XDDdlParserSpec extends BaseXDTest with MockitoSugar{ +class XDDdlParserSpec extends BaseXDTest with MockitoSugar { val xdContext = mock[XDContext] val parser = new XDDdlParser(_ => null, xdContext) - "A XDDlParser" should """successfully parse an "IMPORT TABLES" sentence into + "A XDDlParser" should """successfully parse an "IMPORT TABLES" sentence into |a ImportTablesUsingWithOptions RunnableCommand """.stripMargin in { - val sentence = - """IMPORT TABLES + val sentence = """IMPORT TABLES | USING org.apache.dumypackage.dummyclass | OPTIONS ( | addr "dummyaddr", @@ -45,11 +44,11 @@ class XDDdlParserSpec extends BaseXDTest with MockitoSugar{ | """.stripMargin parser.parse(sentence) shouldBe ImportTablesUsingWithOptions( - "org.apache.dumypackage.dummyclass", - Map( - "addr" -> "dummyaddr", - "database" -> "dummydb" - ) + "org.apache.dumypackage.dummyclass", + Map( + "addr" -> "dummyaddr", + "database" -> "dummydb" + ) ) } @@ -57,32 +56,34 @@ class XDDdlParserSpec extends BaseXDTest with MockitoSugar{ it should "generate an ImportCatalogUsingWithOptions with empty options map when they haven't been provided" in { val sentence = "IMPORT TABLES USING org.apache.dumypackage.dummyclass" - parser.parse(sentence) shouldBe ImportTablesUsingWithOptions("org.apache.dumypackage.dummyclass", Map.empty) + parser.parse(sentence) shouldBe ImportTablesUsingWithOptions( + "org.apache.dumypackage.dummyclass", + Map.empty) } //Sentences and expected values for SparkSQL core DDL - val rightSentences = List[(String, PartialFunction[Any, Unit])] ( - ("""CREATE TEMPORARY TABLE words + val rightSentences = List[(String, PartialFunction[Any, Unit])]( + ("""CREATE TEMPORARY TABLE words |USING org.apache.spark.sql.cassandra |OPTIONS ( | table "words", | keyspace "test", | cluster "Test Cluster", | pushdown "true" - |)""".stripMargin, { case _: CreateTableUsing => () } ), - ("REFRESH TABLE ddb.dummyTable", { case _: RefreshTable => () } ), - ("DESCRIBE ddb.dummyTable", { case _: DescribeCommand => (); } ) + |)""".stripMargin, { case _: CreateTableUsing => () }), + ("REFRESH TABLE ddb.dummyTable", { case _: RefreshTable => () }), + ("DESCRIBE ddb.dummyTable", { case _: DescribeCommand => (); }) ) - for((sentence, expect) <- rightSentences) + for ((sentence, expect) <- rightSentences) it should s"keep parsing SparkSQL core DDL sentences: $sentence" in { expect.lift(parser.parse(sentence)) should not be None } //Malformed sentences and their expectations - val wrongSentences = List[String] ( - "IMPORT TABLES", - """IMPORT TABLES + val wrongSentences = List[String]( + "IMPORT TABLES", + """IMPORT TABLES | OPTIONS ( | addr "dummyaddr", | database "dummydb" @@ -91,41 +92,38 @@ class XDDdlParserSpec extends BaseXDTest with MockitoSugar{ ) wrongSentences foreach { sentence => it should s"fail when parsing wrong sentences: $sentence" in { - an [Exception] should be thrownBy parser.parse(sentence) + an[Exception] should be thrownBy parser.parse(sentence) } } - it should "successfully parse a DROP TABLE into a DropTable RunnableCommand" in { val sentence = "DROP TABLE tableId" - parser.parse(sentence) shouldBe DropTable( TableIdentifier("tableId", None)) + parser.parse(sentence) shouldBe DropTable(TableIdentifier("tableId", None)) } it should "successfully parse a DROP TABLE with a qualified table name into a DropTable RunnableCommand" in { val sentence = "DROP TABLE dbId.tableId" - parser.parse(sentence) shouldBe DropTable( TableIdentifier("tableId", Some("dbId"))) + parser.parse(sentence) 
shouldBe DropTable(TableIdentifier("tableId", Some("dbId"))) } it should "successfully parse a DROP EXTERNAL TABLE into a DropExternalTable RunnableCommand" in { val sentence = "DROP EXTERNAL TABLE tableId" - parser.parse(sentence) shouldBe DropExternalTable( TableIdentifier("tableId", None)) + parser.parse(sentence) shouldBe DropExternalTable(TableIdentifier("tableId", None)) } it should "successfully parse a DROP EXTERNAL TABLE with a qualified table name into a DropExternalTable RunnableCommand" in { val sentence = "DROP EXTERNAL TABLE dbId.tableId" - parser.parse(sentence) shouldBe DropExternalTable( TableIdentifier("tableId", Some("dbId"))) + parser.parse(sentence) shouldBe DropExternalTable(TableIdentifier("tableId", Some("dbId"))) } - - it should "successfully parse a INSERT TABLE with qualified table name and VALUES provided into InsertTable RunnableCommand" in { val sentence = """INSERT INTO tableId VALUES ( 12, 12.01, 'proof', true)""" @@ -136,73 +134,114 @@ class XDDdlParserSpec extends BaseXDTest with MockitoSugar{ it should "successfully parse a INSERT TABLE with qualified table name and more than one VALUES provided into InsertTable RunnableCommand" in { - val sentence = """INSERT INTO tableId VALUES ( 12, 12.01, 'proof', true), ( 2, 1.01, 'pof', true), ( 256, 0.01, 'pr', false)""" + val sentence = + """INSERT INTO tableId VALUES ( 12, 12.01, 'proof', true), ( 2, 1.01, 'pof', true), ( 256, 0.01, 'pr', false)""" parser.parse(sentence) shouldBe InsertIntoTable(TableIdentifier("tableId"), - List(List("12", "12.01", "proof", "true"),List("2", "1.01", "pof", "true"),List("256", "0.01", "pr", "false"))) + List(List("12", "12.01", "proof", "true"), + List("2", "1.01", "pof", "true"), + List("256", "0.01", "pr", "false"))) } it should "successfully parse a INSERT TABLE with qualified table name, schema and VALUES provided into InsertTable RunnableCommand" in { - val sentence = """INSERT INTO tableId(Column1, Column2, Column3, Column4) VALUES ( 256, 0.01, 'pr', false)""" + val sentence = + """INSERT INTO tableId(Column1, Column2, Column3, Column4) VALUES ( 256, 0.01, 'pr', false)""" parser.parse(sentence) shouldBe - InsertIntoTable(TableIdentifier("tableId"), List(List("256", "0.01", "pr", "false")), Some(List("Column1", "Column2", "Column3", "Column4"))) + InsertIntoTable(TableIdentifier("tableId"), + List(List("256", "0.01", "pr", "false")), + Some(List("Column1", "Column2", "Column3", "Column4"))) } it should "successfully parse a INSERT TABLE with qualified table name, schema and more than one VALUES provided into InsertTable RunnableCommand" in { - val sentence = """INSERT INTO tableId(Column1, Column2, Column3, Column4) VALUES ( 12, 12.01, 'proof', true), ( 2, 1.01, 'pof', true), ( 256, 0.01, 'pr', false)""" + val sentence = + """INSERT INTO tableId(Column1, Column2, Column3, Column4) VALUES ( 12, 12.01, 'proof', true), ( 2, 1.01, 'pof', true), ( 256, 0.01, 'pr', false)""" parser.parse(sentence) shouldBe InsertIntoTable(TableIdentifier("tableId"), - List(List("12", "12.01", "proof", "true"),List("2", "1.01", "pof", "true"),List("256", "0.01", "pr", "false")), - Some(List("Column1", "Column2", "Column3", "Column4"))) + List(List("12", "12.01", "proof", "true"), + List("2", "1.01", "pof", "true"), + List("256", "0.01", "pr", "false")), + Some(List("Column1", "Column2", "Column3", "Column4"))) } it should "successfully parse a INSERT TABLE using arrays provided in VALUES" in { - val sentence = """INSERT INTO tableId VALUES ( [1,2], 12, 12.01, 'proof', [false,true], true, ["proof 
array", "proof2"])""" + val sentence = + """INSERT INTO tableId VALUES ( [1,2], 12, 12.01, 'proof', [false,true], true, ["proof array", "proof2"])""" parser.parse(sentence) shouldBe InsertIntoTable(TableIdentifier("tableId"), - List(List(List("1","2"),"12", "12.01", "proof", List("false","true"), "true", List("proof array","proof2")))) + List( + List(List("1", "2"), + "12", + "12.01", + "proof", + List("false", "true"), + "true", + List("proof array", "proof2")))) } it should "successfully parse a INSERT TABLE using arrays with Strings with comma provided in VALUES" in { - val sentence = """INSERT INTO tableId VALUES ( [1,2], 12, 12.01, 'proof', [false,true], true, ["proof, array", "proof2"])""" + val sentence = + """INSERT INTO tableId VALUES ( [1,2], 12, 12.01, 'proof', [false,true], true, ["proof, array", "proof2"])""" parser.parse(sentence) shouldBe InsertIntoTable(TableIdentifier("tableId"), - List(List(List("1","2"),"12", "12.01", "proof", List("false","true"), "true", List("proof, array","proof2")))) + List( + List(List("1", "2"), + "12", + "12.01", + "proof", + List("false", "true"), + "true", + List("proof, array", "proof2")))) } it should "successfully parse a INSERT TABLE using maps provided in VALUES" in { - val sentence = """INSERT INTO tableId VALUES ( (x -> 1, y -> 2), 12, 12.01, 'proof', (x1 -> false, x2 -> true), true)""" + val sentence = + """INSERT INTO tableId VALUES ( (x -> 1, y -> 2), 12, 12.01, 'proof', (x1 -> false, x2 -> true), true)""" parser.parse(sentence) shouldBe InsertIntoTable(TableIdentifier("tableId"), - List(List(Map("x"->"1","y"->"2"),"12", "12.01", "proof", Map("x1"->"false","x2"->"true"), "true"))) + List( + List(Map("x" -> "1", "y" -> "2"), + "12", + "12.01", + "proof", + Map("x1" -> "false", "x2" -> "true"), + "true"))) } it should "successfully parse a INSERT TABLE using maps with strings provided in VALUES" in { - val sentence = """INSERT INTO tableId VALUES ( (x -> 1, y -> 2, z -> 3), 12, 12.01, 'proof', (x1 -> "proof,comma", x2 -> "proof2"), true)""" + val sentence = + """INSERT INTO tableId VALUES ( (x -> 1, y -> 2, z -> 3), 12, 12.01, 'proof', (x1 -> "proof,comma", x2 -> "proof2"), true)""" parser.parse(sentence) shouldBe InsertIntoTable(TableIdentifier("tableId"), - List(List(Map("x"->"1","y"->"2","z"->"3"),"12", "12.01", "proof", Map("x1"->"proof,comma","x2"->"proof2"), "true"))) + List( + List(Map("x" -> "1", "y" -> "2", "z" -> "3"), + "12", + "12.01", + "proof", + Map("x1" -> "proof,comma", "x2" -> "proof2"), + "true"))) } it should "successfully parse a INSERT TABLE using maps with arrays and viceversa provided in VALUES" in { - val sentence = """INSERT INTO tableId VALUES ( [(x->1, y->2), (z->3)], (x -> [3,4], y -> [5,6]) )""" + val sentence = + """INSERT INTO tableId VALUES ( [(x->1, y->2), (z->3)], (x -> [3,4], y -> [5,6]) )""" parser.parse(sentence) shouldBe InsertIntoTable(TableIdentifier("tableId"), - List(List( List(Map("x"->"1","y"->"2"), Map("z"->"3")), Map("x" -> List("3","4"), "y" -> List("5","6")) ))) + List(List(List(Map("x" -> "1", "y" -> "2"), Map("z" -> "3")), + Map("x" -> List("3", "4"), "y" -> List("5", "6"))))) } @@ -218,11 +257,11 @@ class XDDdlParserSpec extends BaseXDTest with MockitoSugar{ val sentence = "CREATE VIEW vn AS SELECT * FROM tn" val logicalPlan = parser.parse(sentence) - logicalPlan shouldBe a [CreateView] + logicalPlan shouldBe a[CreateView] logicalPlan match { case CreateView(tableIdent, lPlan, sqlView) => tableIdent shouldBe TableIdentifier("vn") - sqlView.trim shouldBe "SELECT * FROM tn" + 
sqlView.trim shouldBe "SELECT * FROM tn" } } @@ -232,7 +271,7 @@ class XDDdlParserSpec extends BaseXDTest with MockitoSugar{ val sourceSentence = "SELECT * FROM tn" val sentence = s"CREATE TEMPORARY VIEW vn AS $sourceSentence" val logicalPlan = parser.parse(sentence) - logicalPlan shouldBe a [CreateTempView] + logicalPlan shouldBe a[CreateTempView] logicalPlan match { case CreateTempView(tableIdent, lPlan, sql) => tableIdent shouldBe TableIdentifier("vn") @@ -257,8 +296,7 @@ class XDDdlParserSpec extends BaseXDTest with MockitoSugar{ } it should "successfully parse a CREATE GLOBAL INDEX into a CreateGlobalIndex RunnableCommand" in { - val sentence = - """|CREATE GLOBAL INDEX myIndex + val sentence = """|CREATE GLOBAL INDEX myIndex |ON myDb.myTable(col1, col2) |WITH PK pk1 |USING com.stratio.crossdata.connector.elasticsearch @@ -267,18 +305,16 @@ class XDDdlParserSpec extends BaseXDTest with MockitoSugar{ | opt2 "opt2val" |)""".stripMargin parser.parse(sentence) shouldBe - CreateGlobalIndex( - TableIdentifier("myIndex"), + CreateGlobalIndex(TableIdentifier("myIndex"), TableIdentifier("myTable", Some("myDb")), - Seq("col1","col2"), + Seq("col1", "col2"), "pk1", Option("com.stratio.crossdata.connector.elasticsearch"), Map("opt1" -> "opt1val", "opt2" -> "opt2val")) } it should "successfully parse a CREATE GLOBAL INDEX without USING into a CreateGlobalIndex RunnableCommand" in { - val sentence = - """|CREATE GLOBAL INDEX myIndex + val sentence = """|CREATE GLOBAL INDEX myIndex |ON myDb.myTable(col1, col2) |WITH PK pk2 |OPTIONS ( @@ -286,19 +322,16 @@ class XDDdlParserSpec extends BaseXDTest with MockitoSugar{ | opt2 "opt2val" |)""".stripMargin parser.parse(sentence) shouldBe - CreateGlobalIndex( - TableIdentifier("myIndex"), - TableIdentifier("myTable", Some("myDb")), - Seq("col1","col2"), - "pk2", - None, - Map("opt1" -> "opt1val", "opt2" -> "opt2val")) + CreateGlobalIndex(TableIdentifier("myIndex"), + TableIdentifier("myTable", Some("myDb")), + Seq("col1", "col2"), + "pk2", + None, + Map("opt1" -> "opt1val", "opt2" -> "opt2val")) } - it should "successfully parse a CREATE GLOBAL INDEX without USING without dbName into a CreateGlobalIndex RunnableCommand" in { - val sentence = - """|CREATE GLOBAL INDEX myDbIndex.myIndex + val sentence = """|CREATE GLOBAL INDEX myDbIndex.myIndex |ON myTable(col1, col2) |WITH PK pk |OPTIONS ( @@ -306,13 +339,12 @@ class XDDdlParserSpec extends BaseXDTest with MockitoSugar{ | opt2 "opt2val" |)""".stripMargin parser.parse(sentence) shouldBe - CreateGlobalIndex( - TableIdentifier("myIndex",Option("myDbIndex")), - TableIdentifier("myTable"), - Seq("col1","col2"), - "pk", - None, - Map("opt1" -> "opt1val", "opt2" -> "opt2val")) + CreateGlobalIndex(TableIdentifier("myIndex", Option("myDbIndex")), + TableIdentifier("myTable"), + Seq("col1", "col2"), + "pk", + None, + Map("opt1" -> "opt1val", "opt2" -> "opt2val")) } } diff --git a/core/src/test/scala/org/apache/spark/sql/crossdata/execution/udaf/UdafsIT.scala b/core/src/test/scala/org/apache/spark/sql/crossdata/execution/udaf/UdafsIT.scala index 3684100d2..226043017 100644 --- a/core/src/test/scala/org/apache/spark/sql/crossdata/execution/udaf/UdafsIT.scala +++ b/core/src/test/scala/org/apache/spark/sql/crossdata/execution/udaf/UdafsIT.scala @@ -24,9 +24,6 @@ import org.apache.spark.sql.types.StructType import org.junit.runner.RunWith import org.scalatest.junit.JUnitRunner - - - @RunWith(classOf[JUnitRunner]) class UdafsIT extends SharedXDContextTest { @@ -40,7 +37,10 @@ class UdafsIT extends SharedXDContextTest { val 
schema = StructType(Seq(StructField("name", StringType), StructField("age", IntegerType))) - val df = _xdContext.createDataFrame(_xdContext.sc.parallelize(Seq(Row("Torcuato", 27), Row("Rosalinda", 34), Row("Arthur", 41))), schema) + val df = _xdContext.createDataFrame( + _xdContext.sc.parallelize( + Seq(Row("Torcuato", 27), Row("Rosalinda", 34), Row("Arthur", 41))), + schema) df.registerTempTable("udafs_test_gc") diff --git a/core/src/test/scala/org/apache/spark/sql/crossdata/metrics/XDMetricsSourceSpec.scala b/core/src/test/scala/org/apache/spark/sql/crossdata/metrics/XDMetricsSourceSpec.scala index 43846db32..2ab60baf8 100644 --- a/core/src/test/scala/org/apache/spark/sql/crossdata/metrics/XDMetricsSourceSpec.scala +++ b/core/src/test/scala/org/apache/spark/sql/crossdata/metrics/XDMetricsSourceSpec.scala @@ -22,7 +22,6 @@ import org.scalatest.junit.JUnitRunner @RunWith(classOf[JUnitRunner]) class XDMetricsSourceSpec extends BaseXDTest { - "A logical plan (select *) with native relation" should "return a native relation" in { val xdms = new XDMetricsSource() @@ -31,8 +30,7 @@ class XDMetricsSourceSpec extends BaseXDTest { xdms.registerGauge("testName") //Expectations - xdms.metricRegistry.getGauges.keySet() should contain ("metricName.testName") + xdms.metricRegistry.getGauges.keySet() should contain("metricName.testName") } - } diff --git a/core/src/test/scala/org/apache/spark/sql/crossdata/security/DummySecurityManager.scala b/core/src/test/scala/org/apache/spark/sql/crossdata/security/DummySecurityManager.scala index b135d629d..5b943ea3c 100644 --- a/core/src/test/scala/org/apache/spark/sql/crossdata/security/DummySecurityManager.scala +++ b/core/src/test/scala/org/apache/spark/sql/crossdata/security/DummySecurityManager.scala @@ -19,12 +19,13 @@ object DummySecurityManager { val UniqueReply = "Authorized" } -class DummySecurityManager(credentials: Credentials, audit: Boolean) extends SecurityManager(credentials, audit) { +class DummySecurityManager(credentials: Credentials, audit: Boolean) + extends SecurityManager(credentials, audit) { import org.apache.spark.sql.crossdata.security.DummySecurityManager._ override def authorize(resource: Any): AuthorizationReply = { - if(audit) logInfo(s"DUMMY SECURITY MANAGER: $UniqueReply") + if (audit) logInfo(s"DUMMY SECURITY MANAGER: $UniqueReply") new AuthorizationReply(true, Some(UniqueReply)) } diff --git a/core/src/test/scala/org/apache/spark/sql/crossdata/security/SecurityManagerIT.scala b/core/src/test/scala/org/apache/spark/sql/crossdata/security/SecurityManagerIT.scala index 54cd42a4b..6e244534b 100644 --- a/core/src/test/scala/org/apache/spark/sql/crossdata/security/SecurityManagerIT.scala +++ b/core/src/test/scala/org/apache/spark/sql/crossdata/security/SecurityManagerIT.scala @@ -33,5 +33,4 @@ class SecurityManagerIT extends SharedXDContextTest { assert(securityManager.credentials.sessionId === Some("1234")) } - } diff --git a/core/src/test/scala/org/apache/spark/sql/crossdata/session/BasicSessionProviderSpec.scala b/core/src/test/scala/org/apache/spark/sql/crossdata/session/BasicSessionProviderSpec.scala index 5a02090d5..41ace1dd0 100644 --- a/core/src/test/scala/org/apache/spark/sql/crossdata/session/BasicSessionProviderSpec.scala +++ b/core/src/test/scala/org/apache/spark/sql/crossdata/session/BasicSessionProviderSpec.scala @@ -40,14 +40,15 @@ class BasicSessionProviderSpec extends SharedXDContextTest { implicit lazy val conf: CatalystConf = xdContext.catalog.conf - "BasicSessionProvider" should "provides new sessions whose 
properties are initialized properly" in { - val basicSessionProvider = new BasicSessionProvider(xdContext.sc, ConfigFactory.parseString(SparkSqlConfigString)) + val basicSessionProvider = + new BasicSessionProvider(xdContext.sc, ConfigFactory.parseString(SparkSqlConfigString)) val session = createNewSession(basicSessionProvider) - session.conf.settings should contain(Entry("spark.sql.inMemoryColumnarStorage.batchSize", "5000")) + session.conf.settings should contain( + Entry("spark.sql.inMemoryColumnarStorage.batchSize", "5000")) val tempCatalogs = tempCatalogsFromSession(session) @@ -56,7 +57,6 @@ class BasicSessionProviderSpec extends SharedXDContextTest { } - it should "provides a common persistent catalog and isolated catalogs" in { // TODO we should share the persistentCatalog @@ -80,7 +80,6 @@ class BasicSessionProviderSpec extends SharedXDContextTest { } - it should "allow to lookup an existing session" in { val basicSessionProvider = new BasicSessionProvider(xdContext.sc, ConfigFactory.empty()) @@ -89,7 +88,10 @@ class BasicSessionProviderSpec extends SharedXDContextTest { val session = createNewSession(basicSessionProvider, sessionId) - session.catalog.registerTable(tableIdent, LocalRelation(), Some(CrossdataTable(tableIdent.normalize, None, "fakedatasource"))) + session.catalog.registerTable( + tableIdent, + LocalRelation(), + Some(CrossdataTable(tableIdent.normalize, None, "fakedatasource"))) basicSessionProvider.session(sessionId) should matchPattern { case Success(s: XDSession) if Try(s.catalog.lookupRelation(tableIdent)).isSuccess => @@ -97,7 +99,6 @@ class BasicSessionProviderSpec extends SharedXDContextTest { } - it should "fail when trying to lookup a non-existing session" in { val basicSessionProvider = new BasicSessionProvider(xdContext.sc, ConfigFactory.empty()) @@ -106,8 +107,6 @@ class BasicSessionProviderSpec extends SharedXDContextTest { } - - it should "remove the session metadata when closing an open session" in { val basicSessionProvider = new BasicSessionProvider(xdContext.sc, ConfigFactory.empty()) val sessionId = UUID.randomUUID() @@ -130,7 +129,6 @@ class BasicSessionProviderSpec extends SharedXDContextTest { } - private def tempCatalogsFromSession(session: XDSession): Seq[XDTemporaryCatalog] = { session.catalog shouldBe a[CatalogChain] session.catalog.asInstanceOf[CatalogChain].temporaryCatalogs @@ -141,10 +139,11 @@ class BasicSessionProviderSpec extends SharedXDContextTest { session.catalog.asInstanceOf[CatalogChain].persistentCatalogs } - private def createNewSession(sessionProvider: XDSessionProvider, uuid: UUID = UUID.randomUUID()): XDSession = { + private def createNewSession(sessionProvider: XDSessionProvider, + uuid: UUID = UUID.randomUUID()): XDSession = { val optSession = sessionProvider.newSession(uuid).toOption optSession shouldBe defined optSession.get } -} \ No newline at end of file +} diff --git a/core/src/test/scala/org/apache/spark/sql/crossdata/test/SharedXDContextTest.scala b/core/src/test/scala/org/apache/spark/sql/crossdata/test/SharedXDContextTest.scala index 6f862e709..67be7e42a 100644 --- a/core/src/test/scala/org/apache/spark/sql/crossdata/test/SharedXDContextTest.scala +++ b/core/src/test/scala/org/apache/spark/sql/crossdata/test/SharedXDContextTest.scala @@ -25,31 +25,30 @@ import org.apache.spark.sql.{ColumnName, DataFrame} import scala.language.implicitConversions /** - * Helper trait for SQL test suites where all tests share a single [[TestXDContext]]. 
- */ + * Helper trait for SQL test suites where all tests share a single [[TestXDContext]]. + */ trait SharedXDContextTest extends XDTestUtils { /** - * The [[TestXDContext]] to use for all tests in this suite. - * - * By default, the underlying [[org.apache.spark.SparkContext]] will be run in local - * mode with the default test configurations. - */ + * The [[TestXDContext]] to use for all tests in this suite. + * + * By default, the underlying [[org.apache.spark.SparkContext]] will be run in local + * mode with the default test configurations. + */ private var _ctx: org.apache.spark.sql.crossdata.test.TestXDContext = null val coreConfig: Option[Config] = None /** - * The [[TestXDContext]] to use for all tests in this suite. - */ - + * The [[TestXDContext]] to use for all tests in this suite. + */ protected def xdContext: TestXDContext = _ctx protected override def _xdContext: XDContext = _ctx /** - * Initialize the [[TestXDContext]]. - */ + * Initialize the [[TestXDContext]]. + */ protected override def beforeAll(): Unit = { if (_ctx == null) { _ctx = coreConfig.fold(new TestXDContext()) { cConfig => @@ -61,8 +60,8 @@ trait SharedXDContextTest extends XDTestUtils { } /** - * Stop the underlying [[org.apache.spark.SparkContext]], if any. - */ + * Stop the underlying [[org.apache.spark.SparkContext]], if any. + */ protected override def afterAll(): Unit = { try { if (_ctx != null) { @@ -75,10 +74,10 @@ trait SharedXDContextTest extends XDTestUtils { } /** - * Converts $"col name" into an Column. - * - * @since 1.3.0 - */ + * Converts $"col name" into an Column. + * + * @since 1.3.0 + */ // This must be duplicated here to preserve binary compatibility with Spark < 1.5. implicit class StringToColumn(val sc: StringContext) { def $(args: Any*): ColumnName = { @@ -86,6 +85,7 @@ trait SharedXDContextTest extends XDTestUtils { } } - implicit def dataFrameToXDFrame(dataFrame: DataFrame): XDDataFrame = new XDDataFrame(dataFrame.sqlContext, dataFrame.queryExecution.logical) + implicit def dataFrameToXDFrame(dataFrame: DataFrame): XDDataFrame = + new XDDataFrame(dataFrame.sqlContext, dataFrame.queryExecution.logical) } diff --git a/core/src/test/scala/org/apache/spark/sql/crossdata/test/SharedXDContextTypesTest.scala b/core/src/test/scala/org/apache/spark/sql/crossdata/test/SharedXDContextTypesTest.scala index 85e6db94a..967b57cda 100644 --- a/core/src/test/scala/org/apache/spark/sql/crossdata/test/SharedXDContextTypesTest.scala +++ b/core/src/test/scala/org/apache/spark/sql/crossdata/test/SharedXDContextTypesTest.scala @@ -33,37 +33,36 @@ trait SharedXDContextTypesTest extends SharedXDContextWithDataTest { //Template steps: Override them - val emptyTypesSetError: String /* Error message to be shown when the types test data have not - * been properly inserted in the data source */ - def saveTypesData: Int // Entry point for saving types examples into the data source - def sparkAdditionalKeyColumns: Seq[SparkSQLColDef] = Seq() /* There are data sources which require their tables to have a - * primary key. This entry point allows specifying primary keys - * columns. - * NOTE that these `SparkSQLColdDef`s shouldn't have type checker - * since the column type does not form part of the test. - * e.g: - * override def sparkAdditionalKeyColumns( - * "k", - * "INT PRIMARY KEY" - * ) - */ - def dataTypesSparkOptions: Map[String, String] /* Especial SparkSQL options for type tables, it is equivalent to - * `defaultOptions` but will only apply in the registration of - * the types test table. 
- */ - + val emptyTypesSetError: String /* Error message to be shown when the types test data have not + * been properly inserted in the data source */ + def saveTypesData: Int // Entry point for saving types examples into the data source + def sparkAdditionalKeyColumns: Seq[SparkSQLColDef] = + Seq() /* There are data sources which require their tables to have a + * primary key. This entry point allows specifying primary keys + * columns. + * NOTE that these `SparkSQLColdDef`s shouldn't have type checker + * since the column type does not form part of the test. + * e.g: + * override def sparkAdditionalKeyColumns( + * "k", + * "INT PRIMARY KEY" + * ) + */ + def dataTypesSparkOptions: Map[String, String] /* Especial SparkSQL options for type tables, it is equivalent to + * `defaultOptions` but will only apply in the registration of + * the types test table. + */ //Template: This is the template implementation and shouldn't be modified in any specific test def doTypesTest(datasourceName: String): Unit = { - for(executionType <- ExecutionType.Spark::ExecutionType.Native::Nil) + for (executionType <- ExecutionType.Spark :: ExecutionType.Native :: Nil) datasourceName should s"provide the right types for $executionType execution" in { assumeEnvironmentIsUpAndRunning - val dframe = sql("SELECT " + typesSet.map(_.colname).mkString(", ") + s" FROM $dataTypesTableName") - for( - (tpe, i) <- typesSet zipWithIndex; - typeCheck <- tpe.typeCheck - ) typeCheck(dframe.collect(executionType).head(i)) + val dframe = + sql("SELECT " + typesSet.map(_.colname).mkString(", ") + s" FROM $dataTypesTableName") + for ((tpe, i) <- typesSet zipWithIndex; + typeCheck <- tpe.typeCheck) typeCheck(dframe.collect(executionType).head(i)) } //Multi-level column flat test @@ -119,42 +118,57 @@ trait SharedXDContextTypesTest extends SharedXDContextWithDataTest { } protected def typesSet: Seq[SparkSQLColDef] = Seq( - SparkSQLColDef("int", "INT", _ shouldBe a[java.lang.Integer]), - SparkSQLColDef("bigint", "BIGINT", _ shouldBe a[java.lang.Long]), - SparkSQLColDef("long", "LONG", _ shouldBe a[java.lang.Long]), - SparkSQLColDef("string", "STRING", _ shouldBe a[java.lang.String]), - SparkSQLColDef("boolean", "BOOLEAN", _ shouldBe a[java.lang.Boolean]), - SparkSQLColDef("double", "DOUBLE", _ shouldBe a[java.lang.Double]), - SparkSQLColDef("float", "FLOAT", _ shouldBe a[java.lang.Float]), - SparkSQLColDef("decimalint", "DECIMAL", _ shouldBe a[java.math.BigDecimal]), - SparkSQLColDef("decimallong", "DECIMAL", _ shouldBe a[java.math.BigDecimal]), - SparkSQLColDef("decimaldouble", "DECIMAL", _ shouldBe a[java.math.BigDecimal]), - SparkSQLColDef("decimalfloat", "DECIMAL", _ shouldBe a[java.math.BigDecimal]), - SparkSQLColDef("date", "DATE", _ shouldBe a[java.sql.Date]), - SparkSQLColDef("timestamp", "TIMESTAMP", _ shouldBe a[java.sql.Timestamp]), - SparkSQLColDef("tinyint", "TINYINT", _ shouldBe a[java.lang.Byte]), - SparkSQLColDef("smallint", "SMALLINT", _ shouldBe a[java.lang.Short]), - SparkSQLColDef("binary", "BINARY", _.asInstanceOf[Array[Byte]]), - SparkSQLColDef("arrayint", "ARRAY", _ shouldBe a[Seq[_]]), - SparkSQLColDef("arraystring", "ARRAY", _ shouldBe a[Seq[_]]), - SparkSQLColDef("mapintint", "MAP", _ shouldBe a[Map[_, _]]), - SparkSQLColDef("mapstringint", "MAP", _ shouldBe a[Map[_, _]]), - SparkSQLColDef("mapstringstring", "MAP", _ shouldBe a[Map[_, _]]), - SparkSQLColDef("struct", "STRUCT", _ shouldBe a[Row]), - SparkSQLColDef("arraystruct", "ARRAY>", _ shouldBe a[Seq[_]]), - SparkSQLColDef("arraystructwithdate", "ARRAY>", 
_ shouldBe a[Seq[_]]), - SparkSQLColDef("structofstruct", "STRUCT>", _ shouldBe a[Row]), - SparkSQLColDef("mapstruct", "MAP>", _ shouldBe a[Map[_,_]]), - SparkSQLColDef( - "arraystructarraystruct", - "ARRAY>>>", - { res => - res shouldBe a[Seq[_]] - res.asInstanceOf[Seq[_]].head shouldBe a[Row] - res.asInstanceOf[Seq[_]].head.asInstanceOf[Row].get(1) shouldBe a[Seq[_]] - res.asInstanceOf[Seq[_]].head.asInstanceOf[Row].get(1).asInstanceOf[Seq[_]].head shouldBe a[Row] - } - ) + SparkSQLColDef("int", "INT", _ shouldBe a[java.lang.Integer]), + SparkSQLColDef("bigint", "BIGINT", _ shouldBe a[java.lang.Long]), + SparkSQLColDef("long", "LONG", _ shouldBe a[java.lang.Long]), + SparkSQLColDef("string", "STRING", _ shouldBe a[java.lang.String]), + SparkSQLColDef("boolean", "BOOLEAN", _ shouldBe a[java.lang.Boolean]), + SparkSQLColDef("double", "DOUBLE", _ shouldBe a[java.lang.Double]), + SparkSQLColDef("float", "FLOAT", _ shouldBe a[java.lang.Float]), + SparkSQLColDef("decimalint", "DECIMAL", _ shouldBe a[java.math.BigDecimal]), + SparkSQLColDef("decimallong", "DECIMAL", _ shouldBe a[java.math.BigDecimal]), + SparkSQLColDef("decimaldouble", "DECIMAL", _ shouldBe a[java.math.BigDecimal]), + SparkSQLColDef("decimalfloat", "DECIMAL", _ shouldBe a[java.math.BigDecimal]), + SparkSQLColDef("date", "DATE", _ shouldBe a[java.sql.Date]), + SparkSQLColDef("timestamp", "TIMESTAMP", _ shouldBe a[java.sql.Timestamp]), + SparkSQLColDef("tinyint", "TINYINT", _ shouldBe a[java.lang.Byte]), + SparkSQLColDef("smallint", "SMALLINT", _ shouldBe a[java.lang.Short]), + SparkSQLColDef("binary", "BINARY", _.asInstanceOf[Array[Byte]]), + SparkSQLColDef("arrayint", "ARRAY", _ shouldBe a[Seq[_]]), + SparkSQLColDef("arraystring", "ARRAY", _ shouldBe a[Seq[_]]), + SparkSQLColDef("mapintint", "MAP", _ shouldBe a[Map[_, _]]), + SparkSQLColDef("mapstringint", "MAP", _ shouldBe a[Map[_, _]]), + SparkSQLColDef("mapstringstring", "MAP", _ shouldBe a[Map[_, _]]), + SparkSQLColDef("struct", "STRUCT", _ shouldBe a[Row]), + SparkSQLColDef("arraystruct", + "ARRAY>", + _ shouldBe a[Seq[_]]), + SparkSQLColDef("arraystructwithdate", + "ARRAY>", + _ shouldBe a[Seq[_]]), + SparkSQLColDef( + "structofstruct", + "STRUCT>", + _ shouldBe a[Row]), + SparkSQLColDef("mapstruct", + "MAP>", + _ shouldBe a[Map[_, _]]), + SparkSQLColDef( + "arraystructarraystruct", + "ARRAY>>>", { + res => + res shouldBe a[Seq[_]] + res.asInstanceOf[Seq[_]].head shouldBe a[Row] + res.asInstanceOf[Seq[_]].head.asInstanceOf[Row].get(1) shouldBe a[Seq[_]] + res + .asInstanceOf[Seq[_]] + .head + .asInstanceOf[Row] + .get(1) + .asInstanceOf[Seq[_]] + .head shouldBe a[Row] + } + ) ) override def sparkRegisterTableSQL: Seq[SparkTable] = super.sparkRegisterTableSQL :+ { @@ -168,9 +182,11 @@ trait SharedXDContextTypesTest extends SharedXDContextWithDataTest { object SharedXDContextTypesTest { val dataTypesTableName = "typesCheckTable" - case class SparkSQLColDef(colname: String, sqlType: String, typeCheck: Option[Any => Unit] = None) + case class SparkSQLColDef(colname: String, + sqlType: String, + typeCheck: Option[Any => Unit] = None) object SparkSQLColDef { def apply(colname: String, sqlType: String, typeCheck: Any => Unit): SparkSQLColDef = - SparkSQLColDef(colname, sqlType, Some(typeCheck)) + SparkSQLColDef(colname, sqlType, Some(typeCheck)) } } diff --git a/core/src/test/scala/org/apache/spark/sql/crossdata/test/SharedXDContextWithDataTest.scala b/core/src/test/scala/org/apache/spark/sql/crossdata/test/SharedXDContextWithDataTest.scala index 91674b566..7ea3732a9 100644 --- 
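For orientation, the SharedXDContextTypesTest hunks above shape a reusable template whose abstract members each connector test suite fills in. Below is a minimal sketch of such a suite; every class name, option value and the client type are hypothetical, it assumes the trait lives in the org.apache.spark.sql.crossdata.test package alongside SharedXDContextWithDataTest, and it assumes no abstract members beyond those visible in this patch.

    import org.apache.spark.sql.crossdata.test.SharedXDContextTypesTest
    import org.apache.spark.sql.crossdata.test.SharedXDContextTypesTest.SparkSQLColDef

    // Hypothetical connector suite wiring the template's extension points
    class FooDataSourceTypesIT extends SharedXDContextTypesTest {

      type ClientParams = String                                  // e.g. a native session handle

      override val provider = "com.example.foo"                   // hypothetical datasource provider
      override val runningError = "Foo instance is not reachable"
      override val emptyTypesSetError = "Types test data could not be inserted into Foo"

      override def defaultOptions = Map("table" -> "testTable")
      override def dataTypesSparkOptions = Map("table" -> SharedXDContextTypesTest.dataTypesTableName)

      // Only needed by sources requiring a primary key; no type checker, as the comment above notes
      override def sparkAdditionalKeyColumns = Seq(SparkSQLColDef("k", "INT PRIMARY KEY"))

      override protected def prepareClient = Some("native-client-handle")
      override protected def terminateClient: Unit = ()
      override protected def cleanTestData: Unit = ()
      override def saveTypesData: Int = 1                         // rows written to the data source

      doTypesTest("The Foo connector")                            // generates the per-execution-type tests
    }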
a/core/src/test/scala/org/apache/spark/sql/crossdata/test/SharedXDContextWithDataTest.scala +++ b/core/src/test/scala/org/apache/spark/sql/crossdata/test/SharedXDContextWithDataTest.scala @@ -25,41 +25,42 @@ trait SharedXDContextWithDataTest extends SharedXDContextTest with SparkLoggerCo //Template settings: Override them - type ClientParams /* Abstract type which should be overridden in order to specify the type of - * the native client used in the test to insert test data. - */ - - - val runningError: String /* Error message shown when a test is running without a propper - * environment being set - */ - - val provider: String // Datasource class name (fully specified) - def defaultOptions: Map[String, String] = Map.empty // Spark options used to register the test table in the catalog - - def sparkRegisterTableSQL: Seq[SparkTable] = Nil /* Spark CREATE sentence. Without OPTIONS or USING parts since - * they'll be generated from `provider` and `defaultOptions` - * attributes. - * e.g: override def sparkRegisterTableSQL: Seq[SparkTable] = - * Seq("CREATE TABLE T", "CREATE TEMPORARY TABLE S") - */ - - lazy val assumeEnvironmentIsUpAndRunning = - if (!isEnvironmentReady) { - fail(runningError) - } - + type ClientParams /* Abstract type which should be overridden in order to specify the type of + * the native client used in the test to insert test data. + */ + + val runningError: String /* Error message shown when a test is running without a propper + * environment being set + */ + + val provider: String // Datasource class name (fully specified) + def defaultOptions: Map[String, String] = + Map.empty // Spark options used to register the test table in the catalog + + def sparkRegisterTableSQL: Seq[SparkTable] = + Nil /* Spark CREATE sentence. Without OPTIONS or USING parts since + * they'll be generated from `provider` and `defaultOptions` + * attributes. 
+ * e.g: override def sparkRegisterTableSQL: Seq[SparkTable] = + * Seq("CREATE TABLE T", "CREATE TEMPORARY TABLE S") + */ + + lazy val assumeEnvironmentIsUpAndRunning = if (!isEnvironmentReady) { + fail(runningError) + } - protected def prepareClient: Option[ClientParams] // Native client initialization - protected def terminateClient: Unit // Native client finalization - protected def saveTestData: Unit = () // Creation and insertion of test data examples - protected def cleanTestData: Unit /* Erases test data from the data source after the test has - * finished - */ + protected def prepareClient: Option[ClientParams] // Native client initialization + protected def terminateClient: Unit // Native client finalization + protected def saveTestData: Unit = + () // Creation and insertion of test data examples + protected def cleanTestData: Unit /* Erases test data from the data source after the test has + * finished + */ //Template: This is the template implementation and shouldn't be modified in any specific test - implicit def str2sparkTableDesc(query: String): SparkTable = SparkTable(query, defaultOptions) + implicit def str2sparkTableDesc(query: String): SparkTable = + SparkTable(query, defaultOptions) var client: Option[ClientParams] = None var isEnvironmentReady = false @@ -70,11 +71,14 @@ trait SharedXDContextWithDataTest extends SharedXDContextTest with SparkLoggerCo isEnvironmentReady = Try { client = prepareClient saveTestData - sparkRegisterTableSQL.foreach { case SparkTable(s, opts) => sql(Sentence(s, provider, opts).toString) } + sparkRegisterTableSQL.foreach { + case SparkTable(s, opts) => sql(Sentence(s, provider, opts).toString) + } client.isDefined - } recover { case e: Throwable => - logError(e.getMessage, e) - false + } recover { + case e: Throwable => + logError(e.getMessage, e) + false } get } @@ -102,4 +106,4 @@ object SharedXDContextWithDataTest { case class SparkTable(sql: String, options: Map[String, String]) -} \ No newline at end of file +} diff --git a/core/src/test/scala/org/apache/spark/sql/crossdata/test/TestXDContext.scala b/core/src/test/scala/org/apache/spark/sql/crossdata/test/TestXDContext.scala index 051886d7b..6367401cb 100644 --- a/core/src/test/scala/org/apache/spark/sql/crossdata/test/TestXDContext.scala +++ b/core/src/test/scala/org/apache/spark/sql/crossdata/test/TestXDContext.scala @@ -23,10 +23,12 @@ import org.apache.spark.sql.{SQLConf, SQLContext} import org.apache.spark.sql.crossdata.XDContext import org.apache.spark.{SparkConf, SparkContext} +object TestXDContext { -object TestXDContext{ - - val DefaultTestSparkConf: SparkConf = new SparkConf().set("spark.cores.max", "2").set("spark.sql.testkey", "true").set("spark.sql.shuffle.partitions", "3") + val DefaultTestSparkConf: SparkConf = new SparkConf() + .set("spark.cores.max", "2") + .set("spark.sql.testkey", "true") + .set("spark.sql.shuffle.partitions", "3") } @@ -35,23 +37,25 @@ import TestXDContext._ /** * A special [[SQLContext]] prepared for testing. 
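The SharedXDContextWithDataTest hunks above turn each sparkRegisterTableSQL entry into a SparkTable through the str2sparkTableDesc implicit and register it with USING/OPTIONS taken from provider and defaultOptions. A short fragment, with hypothetical provider and option values, of how the overrides fit together:

    // Fragment inside a suite extending SharedXDContextWithDataTest (illustrative values)
    override val provider = "com.example.foo"
    override def defaultOptions = Map("keyspace" -> "highschool", "table" -> "students")
    override def sparkRegisterTableSQL = Seq("CREATE TEMPORARY TABLE students")
    // At setup time each entry is completed via str2sparkTableDesc and executed as something
    // roughly equivalent to:
    //   CREATE TEMPORARY TABLE students USING com.example.foo OPTIONS (keyspace 'highschool', table 'students')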
*/ -private[sql] class TestXDContext private(sc: SparkContext, coreConfig: Config) - extends XDContext(sc, coreConfig) { +private[sql] class TestXDContext private (sc: SparkContext, coreConfig: Config) + extends XDContext(sc, coreConfig) { def this() { this(new SparkContext( - "local[2]", - "test-xd-context", - DefaultTestSparkConf - ), ConfigFactory.empty()) + "local[2]", + "test-xd-context", + DefaultTestSparkConf + ), + ConfigFactory.empty()) } def this(catalogConfig: Config) { this(new SparkContext( - "local[2]", - "test-xd-context", - DefaultTestSparkConf - ), catalogConfig) + "local[2]", + "test-xd-context", + DefaultTestSparkConf + ), + catalogConfig) } } diff --git a/core/src/test/scala/org/apache/spark/sql/crossdata/test/XDImplicits.scala b/core/src/test/scala/org/apache/spark/sql/crossdata/test/XDImplicits.scala index 31c88a26c..518d383cf 100644 --- a/core/src/test/scala/org/apache/spark/sql/crossdata/test/XDImplicits.scala +++ b/core/src/test/scala/org/apache/spark/sql/crossdata/test/XDImplicits.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.crossdata.test - import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.SpecificMutableRow @@ -31,31 +30,31 @@ import scala.language.implicitConversions import scala.reflect.runtime.universe.TypeTag /** - * A collection of implicit methods for converting common Scala objects into [[org.apache.spark.sql.crossdata.XDDataFrame]]s. - */ + * A collection of implicit methods for converting common Scala objects into [[org.apache.spark.sql.crossdata.XDDataFrame]]s. + */ private[sql] abstract class XDImplicits { protected def _xdContext: XDContext /** - * An implicit conversion that turns a Scala `Symbol` into a Column. - * @since 1.3.0 - */ + * An implicit conversion that turns a Scala `Symbol` into a Column. + * @since 1.3.0 + */ implicit def symbolToColumn(s: Symbol): ColumnName = new ColumnName(s.name) /** - * Creates a DataFrame from an RDD of Product (e.g. case classes, tuples). - * @since 1.3.0 - */ - implicit def rddToDataFrameHolder[A <: Product : TypeTag](rdd: RDD[A]): DataFrameHolder = { + * Creates a DataFrame from an RDD of Product (e.g. case classes, tuples). + * @since 1.3.0 + */ + implicit def rddToDataFrameHolder[A <: Product: TypeTag](rdd: RDD[A]): DataFrameHolder = { DataFrameHolder(_xdContext.createDataFrame(rdd)) } /** - * Creates a DataFrame from a local Seq of Product. - * @since 1.3.0 - */ - implicit def localSeqToDataFrameHolder[A <: Product : TypeTag](data: Seq[A]): DataFrameHolder = { + * Creates a DataFrame from a local Seq of Product. + * @since 1.3.0 + */ + implicit def localSeqToDataFrameHolder[A <: Product: TypeTag](data: Seq[A]): DataFrameHolder = { DataFrameHolder(_xdContext.createDataFrame(data)) } @@ -64,9 +63,9 @@ private[sql] abstract class XDImplicits { // because of [[DoubleRDDFunctions]]. /** - * Creates a single column DataFrame from an RDD[Int]. - * @since 1.3.0 - */ + * Creates a single column DataFrame from an RDD[Int]. + * @since 1.3.0 + */ implicit def intRddToDataFrameHolder(data: RDD[Int]): DataFrameHolder = { val dataType = IntegerType val rows = data.mapPartitions { iter => @@ -77,13 +76,13 @@ private[sql] abstract class XDImplicits { } } DataFrameHolder( - _xdContext.internalCreateDataFrame(rows, StructType(StructField("_1", dataType) :: Nil))) + _xdContext.internalCreateDataFrame(rows, StructType(StructField("_1", dataType) :: Nil))) } /** - * Creates a single column DataFrame from an RDD[Long]. 
- * @since 1.3.0 - */ + * Creates a single column DataFrame from an RDD[Long]. + * @since 1.3.0 + */ implicit def longRddToDataFrameHolder(data: RDD[Long]): DataFrameHolder = { val dataType = LongType val rows = data.mapPartitions { iter => @@ -94,13 +93,13 @@ private[sql] abstract class XDImplicits { } } DataFrameHolder( - _xdContext.internalCreateDataFrame(rows, StructType(StructField("_1", dataType) :: Nil))) + _xdContext.internalCreateDataFrame(rows, StructType(StructField("_1", dataType) :: Nil))) } /** - * Creates a single column DataFrame from an RDD[String]. - * @since 1.3.0 - */ + * Creates a single column DataFrame from an RDD[String]. + * @since 1.3.0 + */ implicit def stringRddToDataFrameHolder(data: RDD[String]): DataFrameHolder = { val dataType = StringType val rows = data.mapPartitions { iter => @@ -111,8 +110,6 @@ private[sql] abstract class XDImplicits { } } DataFrameHolder( - _xdContext.internalCreateDataFrame(rows, StructType(StructField("_1", dataType) :: Nil))) + _xdContext.internalCreateDataFrame(rows, StructType(StructField("_1", dataType) :: Nil))) } } - - diff --git a/core/src/test/scala/org/apache/spark/sql/crossdata/test/XDTestUtils.scala b/core/src/test/scala/org/apache/spark/sql/crossdata/test/XDTestUtils.scala index 824aaaa37..33e6198cc 100644 --- a/core/src/test/scala/org/apache/spark/sql/crossdata/test/XDTestUtils.scala +++ b/core/src/test/scala/org/apache/spark/sql/crossdata/test/XDTestUtils.scala @@ -32,56 +32,52 @@ import scala.language.implicitConversions import scala.util.Try /** - * Helper trait that should be extended by all SQL test suites. - * - * This allows subclasses to plugin a custom [[SQLContext]]. It comes with test data - * prepared in advance as well as all implicit conversions used extensively by dataframes. - * To use implicit methods, import `testImplicits._` instead of through the [[SQLContext]]. - * - * Subclasses should *not* create [[SQLContext]]s in the test suite constructor, which is - * prone to leaving multiple overlapping [[org.apache.spark.SparkContext]]s in the same JVM. - */ -trait XDTestUtils - extends BaseXDTest - with BeforeAndAfterAll { - self => + * Helper trait that should be extended by all SQL test suites. + * + * This allows subclasses to plugin a custom [[SQLContext]]. It comes with test data + * prepared in advance as well as all implicit conversions used extensively by dataframes. + * To use implicit methods, import `testImplicits._` instead of through the [[SQLContext]]. + * + * Subclasses should *not* create [[SQLContext]]s in the test suite constructor, which is + * prone to leaving multiple overlapping [[org.apache.spark.SparkContext]]s in the same JVM. + */ +trait XDTestUtils extends BaseXDTest with BeforeAndAfterAll { self => protected def _xdContext: XDContext // Shorthand for running a query using our SQLContext protected lazy val sql = _xdContext.sql _ - /** - * The Hadoop configuration used by the active [[SQLContext]]. - */ + * The Hadoop configuration used by the active [[SQLContext]]. + */ protected def configuration: Configuration = { _xdContext.sparkContext.hadoopConfiguration } /** - * A helper object for importing SQL implicits. - * - * Note that the alternative of importing `sqlContext.implicits._` is not possible here. - * This is because we create the [[SQLContext]] immediately before the first test is run, - * but the implicits import is needed in the constructor. - */ + * A helper object for importing SQL implicits. 
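The XDImplicits conversions above mirror Spark's test implicits but are backed by the XDContext; inside a suite that mixes in XDTestUtils they are reached through the testImplicits object. A small usage sketch, with illustrative data and column names:

    // Fragment inside a suite mixing in XDTestUtils
    import testImplicits._

    val people = Seq(("Alice", 30), ("Bob", 25)).toDF("name", "age")   // localSeqToDataFrameHolder
    assert(people.where('age > 26).count() == 1)                       // 'age resolved via symbolToColumn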
+ * + * Note that the alternative of importing `sqlContext.implicits._` is not possible here. + * This is because we create the [[SQLContext]] immediately before the first test is run, + * but the implicits import is needed in the constructor. + */ protected object testImplicits extends XDImplicits { protected override def _xdContext: XDContext = self._xdContext } - /** - * Sets all SQL configurations specified in `pairs`, calls `f`, and then restore all SQL - * configurations. - * - * @todo Probably this method should be moved to a more general place - */ + * Sets all SQL configurations specified in `pairs`, calls `f`, and then restore all SQL + * configurations. + * + * @todo Probably this method should be moved to a more general place + */ protected def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = { val (keys, values) = pairs.unzip val currentValues = keys.map(key => Try(_xdContext.conf.getConfString(key)).toOption) (keys, values).zipped.foreach(_xdContext.conf.setConfString) - try f finally { + try f + finally { keys.zip(currentValues).foreach { case (key, Some(value)) => _xdContext.conf.setConfString(key, value) case (key, None) => _xdContext.conf.unsetConf(key) @@ -90,34 +86,35 @@ trait XDTestUtils } /** - * Generates a temporary path without creating the actual file/directory, then pass it to `f`. If - * a file/directory is created there by `f`, it will be delete after `f` returns. - * - * @todo Probably this method should be moved to a more general place - */ + * Generates a temporary path without creating the actual file/directory, then pass it to `f`. If + * a file/directory is created there by `f`, it will be delete after `f` returns. + * + * @todo Probably this method should be moved to a more general place + */ protected def withTempPath(f: File => Unit): Unit = { val path = Utils.createTempDir() path.delete() - try f(path) finally Utils.deleteRecursively(path) + try f(path) + finally Utils.deleteRecursively(path) } /** - * Creates a temporary directory, which is then passed to `f` and will be deleted after `f` - * returns. - * - * @todo Probably this method should be moved to a more general place - */ + * Creates a temporary directory, which is then passed to `f` and will be deleted after `f` + * returns. + * + * @todo Probably this method should be moved to a more general place + */ protected def withTempDir(f: File => Unit): Unit = { val dir = Utils.createTempDir().getCanonicalFile - try f(dir) finally Utils.deleteRecursively(dir) + try f(dir) + finally Utils.deleteRecursively(dir) } - /** - * Turn a logical plan into a [[DataFrame]]. This should be removed once we have an easier - * way to construct [[XDDataFrame]] directly out of local data without relying on implicits. - */ + * Turn a logical plan into a [[DataFrame]]. This should be removed once we have an easier + * way to construct [[XDDataFrame]] directly out of local data without relying on implicits. 
+ */ protected implicit def logicalPlanToSparkQuery(plan: LogicalPlan): DataFrame = { XDDataFrame(_xdContext, plan) } -} \ No newline at end of file +} diff --git a/driver/src/main/scala/com/stratio/crossdata/driver/Driver.scala b/driver/src/main/scala/com/stratio/crossdata/driver/Driver.scala index 2645a73a8..8999c5bed 100644 --- a/driver/src/main/scala/com/stratio/crossdata/driver/Driver.scala +++ b/driver/src/main/scala/com/stratio/crossdata/driver/Driver.scala @@ -46,7 +46,6 @@ import scala.util.Try * ======================================================================================= */ - object Driver { /** @@ -82,13 +81,15 @@ object Driver { * WARNING! It should be called once all active sessions have been closed. After the shutdown, new session cannot be created. */ def shutdown(): Unit = { - if (!system.isTerminated) system.shutdown() + if (!system.isTerminated) system.shutdown() } - private[crossdata] def newSession(driverConf: DriverConf, authentication: Authentication): Driver = { + private[crossdata] def newSession(driverConf: DriverConf, + authentication: Authentication): Driver = { val driver = new Driver(driverConf, authentication) val isConnected = driver.openSession().getOrElse { - throw new RuntimeException(s"Cannot establish connection to XDServer: timed out after ${Driver.InitializationTimeout}") + throw new RuntimeException( + s"Cannot establish connection to XDServer: timed out after ${Driver.InitializationTimeout}") } if (!isConnected) { throw new RuntimeException(s"The server has rejected the open session request") @@ -96,7 +97,8 @@ object Driver { driver } - private[driver] def generateDefaultAuth = new Authentication("crossdata", "stratio") + private[driver] def generateDefaultAuth = + new Authentication("crossdata", "stratio") Runtime.getRuntime.addShutdownHook(new Thread(new Runnable { def run() { @@ -106,8 +108,8 @@ object Driver { } -class Driver private(private[crossdata] val driverConf: DriverConf, - auth: Authentication = Driver.generateDefaultAuth) { +class Driver private (private[crossdata] val driverConf: DriverConf, + auth: Authentication = Driver.generateDefaultAuth) { import Driver._ @@ -115,7 +117,6 @@ class Driver private(private[crossdata] val driverConf: DriverConf, private lazy val logger = LoggerFactory.getLogger(classOf[Driver]) - private lazy val clusterClientActor = { if (logger.isDebugEnabled) { @@ -126,7 +127,8 @@ class Driver private(private[crossdata] val driverConf: DriverConf, val initialContacts = contactPoints.map(system.actorSelection).toSet logger.debug("Initial contacts: " + initialContacts) - val remoteClientName: String = ServerClusterClientParameters.RemoteClientName + UUID.randomUUID() + val remoteClientName: String = ServerClusterClientParameters.RemoteClientName + UUID + .randomUUID() val actor = system.actorOf(ClusterClient.props(initialContacts), remoteClientName) logger.debug(s"Cluster client actor created with name: $remoteClientName") @@ -139,11 +141,11 @@ class Driver private(private[crossdata] val driverConf: DriverConf, } private val sessionBeaconProps = SessionBeaconActor.props( - driverSession.id, - 5 seconds, /* This ins't configurable since it's simpler for the user + driverSession.id, + 5 seconds, /* This ins't configurable since it's simpler for the user to play just with alert period time at server side. 
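The withSQLConf and withTempDir helpers reformatted above save and restore state around a block. A brief sketch of how a test might use them; the conf key and file name are illustrative, and it assumes the surrounding suite mixes in XDTestUtils:

    // Fragment inside a suite mixing in XDTestUtils
    withSQLConf("spark.sql.shuffle.partitions" -> "1") {
      // queries in this block see the temporary value; the previous one is restored on exit
      sql("SHOW TABLES").collect()
    }

    withTempDir { dir =>
      // `dir` exists only for the duration of the block and is deleted recursively afterwards
      new java.io.File(dir, "scratch.txt").createNewFile()
    }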
*/ - clusterClientActor, - ServerClusterClientParameters.ClientMonitorPath + clusterClientActor, + ServerClusterClientParameters.ClientMonitorPath ) private var sessionBeacon: Option[ActorRef] = None @@ -167,8 +169,8 @@ class Driver private(private[crossdata] val driverConf: DriverConf, //TODO remove this part when servers broadcast bus was realized //Preparse query to know if it is an special command sent from the shell or other driver user that is not a query val addJarPattern = """(\s*add)(\s+jar\s+)(.*)""".r - val addAppWithAliasPattern ="""(\s*add)(\s+app\s+)(.*)(\s+as\s+)(.*)(\s+with\s+)(.*)""".r - val addAppPattern ="""(\s*add)(\s+app\s+)(.*)(\s+with\s+)(.*)""".r + val addAppWithAliasPattern = """(\s*add)(\s+app\s+)(.*)(\s+as\s+)(.*)(\s+with\s+)(.*)""".r + val addAppPattern = """(\s*add)(\s+app\s+)(.*)(\s+with\s+)(.*)""".r query match { case addJarPattern(add, jar, path) => @@ -188,7 +190,8 @@ class Driver private(private[crossdata] val driverConf: DriverConf, val sqlCommand = new SQLCommand(query, retrieveColNames = driverConf.getFlattenTables) val futureReply = askCommand(securitizeCommand(sqlCommand)).map { case SQLReply(_, sqlResult) => sqlResult - case other => throw new RuntimeException(s"SQLReply expected. Received: $other") + case other => + throw new RuntimeException(s"SQLReply expected. Received: $other") } new SQLResponse(sqlCommand.requestId, futureReply) { // TODO cancel sync => 5 secs @@ -198,28 +201,28 @@ class Driver private(private[crossdata] val driverConf: DriverConf, } } - /** * Add Jar to the XD Context * * @param path The path of the JAR * @return A SQLResponse with the id and the result set. */ - def addJar(path: String, toClassPath:Option[Boolean]=None): SQLResponse = { - val addJarCommand = AddJARCommand(path,toClassPath=toClassPath) + def addJar(path: String, toClassPath: Option[Boolean] = None): SQLResponse = { + val addJarCommand = AddJARCommand(path, toClassPath = toClassPath) if (File(path).exists) { val futureReply = askCommand(securitizeCommand(addJarCommand)).map { case SQLReply(_, sqlResult) => sqlResult - case other => throw new RuntimeException(s"SQLReply expected. Received: $other") + case other => + throw new RuntimeException(s"SQLReply expected. Received: $other") } new SQLResponse(addJarCommand.requestId, futureReply) } else { - new SQLResponse(addJarCommand.requestId, Future(ErrorSQLResult("File doesn't exist"))) + new SQLResponse(addJarCommand.requestId, Future(ErrorSQLResult("File doesn't exist"))) } } def addAppCommand(path: String, clss: String, alias: Option[String] = None): SQLResponse = { - val result = addJar(path,Option(false)).waitForResult() + val result = addJar(path, Option(false)).waitForResult() val hdfsPath = result.resultSet(0).getString(0) addApp(hdfsPath, clss, alias.getOrElse(path)) } @@ -234,26 +237,29 @@ class Driver private(private[crossdata] val driverConf: DriverConf, val addAppCommand = AddAppCommand(path, alias, clss) val futureReply = askCommand(securitizeCommand(addAppCommand)).map { case SQLReply(_, sqlResult) => sqlResult - case other => throw new RuntimeException(s"SQLReply expected. Received: $other") + case other => + throw new RuntimeException(s"SQLReply expected. 
Received: $other") } new SQLResponse(addAppCommand.requestId, futureReply) } - def importTables(dataSourceProvider: String, options: Map[String, String]): SQLResponse = sql( - s"""|IMPORT TABLES + s"""|IMPORT TABLES |USING $dataSourceProvider |${mkOptionsStatement(options)} """.stripMargin ) - // TODO schema -> StructType insteadOf String // schema -> e.g "( name STRING, age INT )" - def createTable(name: String, dataSourceProvider: String, schema: Option[String], options: Map[String, String], isTemporary: Boolean = false): SQLResponse = + def createTable(name: String, + dataSourceProvider: String, + schema: Option[String], + options: Map[String, String], + isTemporary: Boolean = false): SQLResponse = sql( - s"""|CREATE ${if (isTemporary) "TEMPORARY" else ""} TABLE $name + s"""|CREATE ${if (isTemporary) "TEMPORARY" else ""} TABLE $name |USING $dataSourceProvider |${schema.getOrElse("")} |${mkOptionsStatement(options)} @@ -262,10 +268,11 @@ class Driver private(private[crossdata] val driverConf: DriverConf, def dropTable(name: String, isTemporary: Boolean = false): SQLResponse = { - if (isTemporary) throw new UnsupportedOperationException("Drop temporary table is not supported yet") + if (isTemporary) + throw new UnsupportedOperationException("Drop temporary table is not supported yet") sql( - s"""|DROP ${if (isTemporary) "TEMPORARY" else ""} + s"""|DROP ${if (isTemporary) "TEMPORARY" else ""} |TABLE $name """.stripMargin ) @@ -273,7 +280,7 @@ class Driver private(private[crossdata] val driverConf: DriverConf, def dropAllTables(): SQLResponse = { sql( - s"""|DROP ALL TABLES""".stripMargin + s"""|DROP ALL TABLES""".stripMargin ) } @@ -282,14 +289,12 @@ class Driver private(private[crossdata] val driverConf: DriverConf, options.headOption.fold("")(_ => s" OPTIONS ( $opt ) ") } - private def askCommand(commandEnvelope: CommandEnvelope): Future[ServerReply] = { val promise = Promise[ServerReply]() - proxyActor !(commandEnvelope, promise) + proxyActor ! 
(commandEnvelope, promise) promise.future } - /** * Gets a list of tables from a database or all if the database is None * @@ -321,19 +326,21 @@ class Driver private(private[crossdata] val driverConf: DriverConf, */ def describeTable(database: Option[String], tableName: String): Seq[FieldMetadata] = { - def extractNameDataType: Row => (String, String) = row => (row.getString(0), row.getString(1)) + def extractNameDataType: Row => (String, String) = + row => (row.getString(0), row.getString(1)) import SQLResponse._ val sqlResult: SQLResult = sql(s"DESCRIBE ${database.map(_ + ".").getOrElse("")}$tableName") sqlResult match { case SuccessfulSQLResult(result, _) => - result.map(extractNameDataType) flatMap { case (name, dataType) => - if (!driverConf.getFlattenTables) { - FieldMetadata(name, DataTypesUtils.toDataType(dataType)) :: Nil - } else { - getFlattenedFields(name, DataTypesUtils.toDataType(dataType)) - } + result.map(extractNameDataType) flatMap { + case (name, dataType) => + if (!driverConf.getFlattenTables) { + FieldMetadata(name, DataTypesUtils.toDataType(dataType)) :: Nil + } else { + getFlattenedFields(name, DataTypesUtils.toDataType(dataType)) + } } toSeq case other => @@ -341,8 +348,8 @@ class Driver private(private[crossdata] val driverConf: DriverConf, } } - def show(query: String) = sql(query).waitForResult().prettyResult.foreach(println) - + def show(query: String) = + sql(query).waitForResult().prettyResult.foreach(println) /** * Gets the server/cluster state @@ -352,7 +359,7 @@ class Driver private(private[crossdata] val driverConf: DriverConf, */ def clusterState(): Future[CurrentClusterState] = { val promise = Promise[ServerReply]() - proxyActor !(securitizeCommand(ClusterStateCommand()), promise) + proxyActor ! (securitizeCommand(ClusterStateCommand()), promise) promise.future.mapTo[ClusterStateReply].map(_.clusterState) } @@ -380,7 +387,6 @@ class Driver private(private[crossdata] val driverConf: DriverConf, def isClusterAlive(atMost: Duration = 3 seconds): Boolean = Try(Await.result(serversUp(), atMost)).map(_.nonEmpty).getOrElse(false) - private def openSession(): Try[Boolean] = { import Driver._ @@ -390,7 +396,7 @@ class Driver private(private[crossdata] val driverConf: DriverConf, Await.result(promise.future.mapTo[OpenSessionReply].map(_.isOpen), InitializationTimeout) } - if(res.isSuccess) + if (res.isSuccess) sessionBeacon = Some(system.actorOf(sessionBeaconProps)) res @@ -408,15 +414,16 @@ class Driver private(private[crossdata] val driverConf: DriverConf, private def securitizeCommand(command: Command): CommandEnvelope = new CommandEnvelope(command, driverSession) - - private def getFlattenedFields(fieldName: String, dataType: DataType): Seq[FieldMetadata] = dataType match { - case structType: StructType => - structType.flatMap(field => getFlattenedFields(s"$fieldName.${field.name}", field.dataType)) - case ArrayType(etype, _) => - getFlattenedFields(fieldName, etype) - case _ => - FieldMetadata(fieldName, dataType) :: Nil - } + private def getFlattenedFields(fieldName: String, dataType: DataType): Seq[FieldMetadata] = + dataType match { + case structType: StructType => + structType.flatMap(field => + getFlattenedFields(s"$fieldName.${field.name}", field.dataType)) + case ArrayType(etype, _) => + getFlattenedFields(fieldName, etype) + case _ => + FieldMetadata(fieldName, dataType) :: Nil + } private def handleCommandError(result: SQLResult) = result match { case ErrorSQLResult(message, Some(cause)) => diff --git 
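Driver.describeTable above flattens nested schemas when config.flatten-tables is enabled. The following standalone restatement of that recursion (not the private method itself, only a sketch of the same rule) shows what the flattening produces for a hypothetical nested column:

    import org.apache.spark.sql.types._

    // Mirrors the getFlattenedFields logic shown above, for illustration only
    def flatten(name: String, dt: DataType): Seq[(String, DataType)] = dt match {
      case st: StructType  => st.flatMap(f => flatten(s"$name.${f.name}", f.dataType))
      case ArrayType(e, _) => flatten(name, e)          // arrays are traversed through their element type
      case other           => Seq(name -> other)
    }

    val address = StructType(Seq(
      StructField("city", StringType),
      StructField("zips", ArrayType(IntegerType))))

    flatten("address", address)
    // -> Seq(("address.city", StringType), ("address.zips", IntegerType))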
a/driver/src/main/scala/com/stratio/crossdata/driver/JavaDriver.scala b/driver/src/main/scala/com/stratio/crossdata/driver/JavaDriver.scala index d4b430ce7..b459b4d07 100644 --- a/driver/src/main/scala/com/stratio/crossdata/driver/JavaDriver.scala +++ b/driver/src/main/scala/com/stratio/crossdata/driver/JavaDriver.scala @@ -24,9 +24,7 @@ import org.slf4j.LoggerFactory import scala.collection.JavaConversions._ import scala.concurrent.duration.Duration - -class JavaDriver private(driverConf: DriverConf, - auth: Authentication) { +class JavaDriver private (driverConf: DriverConf, auth: Authentication) { def this(driverConf: DriverConf) = this(driverConf, Driver.generateDefaultAuth) @@ -45,66 +43,68 @@ class JavaDriver private(driverConf: DriverConf, def this(seedNodes: java.util.List[String]) = this(seedNodes, new DriverConf) - private lazy val logger = LoggerFactory.getLogger(classOf[JavaDriver]) private val scalaDriver = Driver.newSession(driverConf, auth) /** - * Sync execution with defaults: timeout 10 sec, nr-retries 2 - */ + * Sync execution with defaults: timeout 10 sec, nr-retries 2 + */ def sql(sqlText: String): SQLResult = scalaDriver.sql(sqlText).waitForResult() def sql(sqlText: String, timeoutDuration: Duration): SQLResult = scalaDriver.sql(sqlText).waitForResult(timeoutDuration) - def importTables(dataSourceProvider: String, options: java.util.Map[String, String]): SQLResult = scalaDriver.importTables(dataSourceProvider, options.toMap) - def createTable(name: String, dataSourceProvider: String, schema: Option[String], options: java.util.Map[String, String], isTemporary: Boolean): SQLResult = - scalaDriver.createTable(name, dataSourceProvider, schema, options.toMap, isTemporary).waitForResult() + def createTable(name: String, + dataSourceProvider: String, + schema: Option[String], + options: java.util.Map[String, String], + isTemporary: Boolean): SQLResult = + scalaDriver + .createTable(name, dataSourceProvider, schema, options.toMap, isTemporary) + .waitForResult() def dropTable(name: String, isTemporary: Boolean = false): SQLResult = scalaDriver.dropTable(name, isTemporary) - def listTables(): java.util.List[JavaTableName] = - scalaDriver.listTables(None).map { case (table, database) => new JavaTableName(table, database.getOrElse("")) } - + scalaDriver.listTables(None).map { + case (table, database) => + new JavaTableName(table, database.getOrElse("")) + } def listTables(database: String): java.util.List[JavaTableName] = - scalaDriver.listTables(Some(database)).map { case (table, database) => new JavaTableName(table, database.getOrElse("")) } - + scalaDriver.listTables(Some(database)).map { + case (table, database) => + new JavaTableName(table, database.getOrElse("")) + } def describeTable(database: String, tableName: String): java.util.List[FieldMetadata] = scalaDriver.describeTable(Some(database), tableName) - def describeTable(tableName: String): java.util.List[FieldMetadata] = scalaDriver.describeTable(None, tableName) - def show(sqlText: String): Unit = scalaDriver.show(sqlText) /** - * Indicates if the cluster is alive or not - * - * @since 1.3 - * @return whether at least one member of the cluster is alive or not - */ + * Indicates if the cluster is alive or not + * + * @since 1.3 + * @return whether at least one member of the cluster is alive or not + */ def isClusterAlive(): Boolean = scalaDriver.isClusterAlive() def closeSession(): Unit = scalaDriver.closeSession() - - def addJar(path:String): Unit = + def addJar(path: String): Unit = scalaDriver.addJar(path) - } - 
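Taken together, the JavaDriver surface reformatted above is the synchronous entry point for client code. A minimal usage sketch follows; the contact point, database and table names are illustrative, and it assumes a Crossdata server is reachable:

    import com.stratio.crossdata.driver.JavaDriver
    import com.stratio.crossdata.driver.config.DriverConf

    val conf   = new DriverConf().setClusterContactPoint("127.0.0.1:13420")
    val driver = new JavaDriver(conf)
    try {
      if (driver.isClusterAlive()) {
        driver.sql("SHOW TABLES")                          // synchronous execution, default timeout
        driver.show("SELECT * FROM highschool.students")   // pretty-prints the result set
      }
    } finally {
      driver.closeSession()                                // always release the session
    }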
diff --git a/driver/src/main/scala/com/stratio/crossdata/driver/actor/ProxyActor.scala b/driver/src/main/scala/com/stratio/crossdata/driver/actor/ProxyActor.scala index 319595a71..a3c9d4e0c 100644 --- a/driver/src/main/scala/com/stratio/crossdata/driver/actor/ProxyActor.scala +++ b/driver/src/main/scala/com/stratio/crossdata/driver/actor/ProxyActor.scala @@ -15,7 +15,6 @@ */ package com.stratio.crossdata.driver.actor - import java.util.UUID import akka.actor.{Actor, ActorRef, Props} @@ -61,12 +60,12 @@ class ProxyActor(clusterClientActor: ActorRef, driver: Driver) extends Actor { self ! any } - // Previous step to process the message where promise is stored. def storePromise(promisesByIds: PromisesByIds): Receive = { case (message: CommandEnvelope, promise: Promise[ServerReply @unchecked]) => logger.debug("Sending message to the Crossdata cluster") - context.become(start(promisesByIds.copy(promisesByIds.promises + (message.cmd.requestId -> promise)))) + context.become( + start(promisesByIds.copy(promisesByIds.promises + (message.cmd.requestId -> promise)))) self ! message } @@ -74,43 +73,50 @@ class ProxyActor(clusterClientActor: ActorRef, driver: Driver) extends Actor { def sendToServer(promisesByIds: PromisesByIds): Receive = { case secureSQLCommand @ CommandEnvelope(sqlCommand: SQLCommand, _) => - logger.info(s"Sending query: ${sqlCommand.sql} with requestID=${sqlCommand.requestId} & queryID=${sqlCommand.queryId}") - clusterClientActor ! ClusterClient.Send(ServerClusterClientParameters.ServerPath, secureSQLCommand, localAffinity = false) + logger.info( + s"Sending query: ${sqlCommand.sql} with requestID=${sqlCommand.requestId} & queryID=${sqlCommand.queryId}") + clusterClientActor ! ClusterClient + .Send(ServerClusterClientParameters.ServerPath, secureSQLCommand, localAffinity = false) - case secureSQLCommand @ CommandEnvelope(addJARCommand @ AddJARCommand(path, _, _, _), session) => + case secureSQLCommand @ CommandEnvelope(addJARCommand @ AddJARCommand(path, _, _, _), + session) => import context.dispatcher val shipmentResponse: Future[SQLReply] = sendJarToServers(addJARCommand, path, session) shipmentResponse pipeTo sender case secureSQLCommand @ CommandEnvelope(clusterStateCommand: ClusterStateCommand, _) => logger.debug(s"Send cluster state with requestID=${clusterStateCommand.requestId}") - clusterClientActor ! ClusterClient.Send(ServerClusterClientParameters.ServerPath, secureSQLCommand, localAffinity = false) + clusterClientActor ! ClusterClient + .Send(ServerClusterClientParameters.ServerPath, secureSQLCommand, localAffinity = false) case secureSQLCommand @ CommandEnvelope(aCmd @ AddAppCommand(path, alias, clss, _), _) => - clusterClientActor ! ClusterClient.Send(ServerClusterClientParameters.ServerPath,secureSQLCommand, localAffinity=false) + clusterClientActor ! ClusterClient + .Send(ServerClusterClientParameters.ServerPath, secureSQLCommand, localAffinity = false) case secureSQLCommand @ CommandEnvelope(_: OpenSessionCommand | _: CloseSessionCommand, _) => - clusterClientActor ! ClusterClient.Send(ServerClusterClientParameters.ServerPath, secureSQLCommand, localAffinity = true) + clusterClientActor ! ClusterClient + .Send(ServerClusterClientParameters.ServerPath, secureSQLCommand, localAffinity = true) case sqlCommand: SQLCommand => - logger.warn(s"Command message not securitized: ${sqlCommand.sql}. Message won't be sent to the Crossdata cluster") + logger.warn( + s"Command message not securitized: ${sqlCommand.sql}. 
Message won't be sent to the Crossdata cluster") } - - def sendJarToServers(command: Command, path: String, session:Session): Future[SQLReply] = { + def sendJarToServers(command: Command, path: String, session: Session): Future[SQLReply] = { import scala.concurrent.ExecutionContext.Implicits.global httpClient.sendJarToHTTPServer(path, session) map { response => SQLReply( - command.requestId, - SuccessfulSQLResult(Array(Row(response)), StructType(StructField("filepath", StringType) :: Nil)) + command.requestId, + SuccessfulSQLResult(Array(Row(response)), + StructType(StructField("filepath", StringType) :: Nil)) ) } recover { case failureCause => val msg = s"Error trying to send JAR through HTTP: ${failureCause.getMessage}" logger.error(msg) SQLReply( - command.requestId, - ErrorSQLResult(msg) + command.requestId, + ErrorSQLResult(msg) ) } } @@ -118,7 +124,8 @@ class ProxyActor(clusterClientActor: ActorRef, driver: Driver) extends Actor { // Message received from a Crossdata Server. def receiveFromServer(promisesByIds: PromisesByIds): Receive = { case reply: ServerReply => - logger.info(s"Sever reply received from Crossdata Server: $sender with ID=${reply.requestId}") + logger.info( + s"Sever reply received from Crossdata Server: $sender with ID=${reply.requestId}") promisesByIds.promises.get(reply.requestId) match { case Some(p) => context.become(start(promisesByIds.copy(promisesByIds.promises - reply.requestId))) @@ -145,11 +152,10 @@ class ProxyActor(clusterClientActor: ActorRef, driver: Driver) extends Actor { def start(promisesByIds: PromisesByIds): Receive = { storePromise(promisesByIds) orElse - sendToServer(promisesByIds) orElse - receiveFromServer(promisesByIds) orElse { + sendToServer(promisesByIds) orElse + receiveFromServer(promisesByIds) orElse { case any => logger.warn(s"Unknown message: $any. Message won't be sent to the Crossdata cluster") } } } - diff --git a/driver/src/main/scala/com/stratio/crossdata/driver/actor/SessionBeaconActor.scala b/driver/src/main/scala/com/stratio/crossdata/driver/actor/SessionBeaconActor.scala index 5ccf08215..4327ac55a 100644 --- a/driver/src/main/scala/com/stratio/crossdata/driver/actor/SessionBeaconActor.scala +++ b/driver/src/main/scala/com/stratio/crossdata/driver/actor/SessionBeaconActor.scala @@ -26,11 +26,10 @@ import scala.concurrent.duration.FiniteDuration object SessionBeaconActor { - def props( - sessionId: UUID, - period: FiniteDuration, - clusterClientActor: ActorRef, - clusterPath: String): Props = + def props(sessionId: UUID, + period: FiniteDuration, + clusterClientActor: ActorRef, + clusterPath: String): Props = Props(new SessionBeaconActor(sessionId, period, clusterClientActor, clusterPath)) } @@ -39,11 +38,12 @@ object SessionBeaconActor { * This actor is used by the driver provide the cluster with proof of life for the current session. * Check [[LiveMan]] for more details. 
*/ -class SessionBeaconActor private ( - override val keepAliveId: UUID, - override val period: FiniteDuration, - clusterClientActor: ActorRef, - clusterPath: String) extends Actor with LiveMan[UUID] { +class SessionBeaconActor private (override val keepAliveId: UUID, + override val period: FiniteDuration, + clusterClientActor: ActorRef, + clusterPath: String) + extends Actor + with LiveMan[UUID] { override def receive: Receive = PartialFunction.empty override val master: ActorRef = clusterClientActor diff --git a/driver/src/main/scala/com/stratio/crossdata/driver/config/DriverConf.scala b/driver/src/main/scala/com/stratio/crossdata/driver/config/DriverConf.scala index e3d2945ce..7834a7da9 100644 --- a/driver/src/main/scala/com/stratio/crossdata/driver/config/DriverConf.scala +++ b/driver/src/main/scala/com/stratio/crossdata/driver/config/DriverConf.scala @@ -33,16 +33,16 @@ class DriverConf extends Logging { private val userSettings = new ConcurrentHashMap[String, ConfigValue]() - private[crossdata] lazy val finalSettings: Config = - userSettings.foldLeft(typesafeConf) { case (prevConfig, keyValue) => + private[crossdata] lazy val finalSettings: Config = userSettings.foldLeft(typesafeConf) { + case (prevConfig, keyValue) => prevConfig.withValue(keyValue._1, keyValue._2) - } + } /** - * Adds a generic key-value - * akka => e.g akka.loglevel = "INFO" - * driverConfig => e.g config.cluster.actor = "my-server-actor" - */ + * Adds a generic key-value + * akka => e.g akka.loglevel = "INFO" + * driverConfig => e.g config.cluster.actor = "my-server-actor" + */ def set(key: String, value: ConfigValue): DriverConf = { userSettings.put(key, value) this @@ -54,16 +54,16 @@ class DriverConf extends Logging { } /** - * @param hostAndPort e.g 127.0.0.1:13420 - */ + * @param hostAndPort e.g 127.0.0.1:13420 + */ def setClusterContactPoint(hostAndPort: String*): DriverConf = { userSettings.put(DriverConfigHosts, ConfigValueFactory.fromIterable(hostAndPort)) this } /** - * @param hostAndPort e.g 127.0.0.1:13420 - */ + * @param hostAndPort e.g 127.0.0.1:13420 + */ def setClusterContactPoint(hostAndPort: java.util.List[String]): DriverConf = { userSettings.put(DriverConfigHosts, ConfigValueFactory.fromIterable(hostAndPort)) this @@ -75,7 +75,8 @@ class DriverConf extends Logging { } def setTunnelTimeout(seconds: Int): DriverConf = { - userSettings.put(AkkaClusterRecepcionistTunnelTimeout, ConfigValueFactory.fromAnyRef(seconds * 1000)) + userSettings + .put(AkkaClusterRecepcionistTunnelTimeout, ConfigValueFactory.fromAnyRef(seconds * 1000)) this } @@ -90,7 +91,7 @@ class DriverConf extends Logging { private[crossdata] def getClusterContactPoint: List[String] = { val hosts = finalSettings.getStringList(DriverConfigHosts).toList val clusterName = finalSettings.getString(DriverClusterName) - val ssl= Try(finalSettings.getBoolean(SSLEnabled)).getOrElse(false) + val ssl = Try(finalSettings.getBoolean(SSLEnabled)).getOrElse(false) if (ssl) hosts map (host => s"akka.ssl.tcp://$clusterName@$host$ActorsPath") else @@ -109,7 +110,6 @@ class DriverConf extends Logging { private[crossdata] def getFlattenTables: Boolean = finalSettings.getBoolean(DriverFlattenTables) - private val typesafeConf: Config = { val defaultConfig = ConfigFactory.load(DriverConfigDefault).getConfig(ParentConfigName) @@ -158,21 +158,19 @@ class DriverConf extends Logging { } // System properties - val systemPropertiesConfig = - Try( + val systemPropertiesConfig = Try( ConfigFactory.parseProperties(System.getProperties).getConfig(ParentConfigName) - 
).getOrElse( + ).getOrElse( ConfigFactory.parseProperties(System.getProperties) - ) + ) val finalConfigWithSystemProperties = systemPropertiesConfig.withFallback(finalConfig) val finalConfigWithEnvVars = { if (finalConfigWithSystemProperties.hasPath("config.cluster.servers")) { val serverNodes = finalConfigWithSystemProperties.getString("config.cluster.servers") - defaultConfig.withValue( - DriverConfigHosts, - ConfigValueFactory.fromIterable(serverNodes.split(",").toList)) + defaultConfig.withValue(DriverConfigHosts, + ConfigValueFactory.fromIterable(serverNodes.split(",").toList)) } else { finalConfigWithSystemProperties } @@ -185,7 +183,6 @@ class DriverConf extends Logging { } - object DriverConf { val ActorsPath = "/user/receptionist" val DriverConfigDefault = "driver-reference.conf" @@ -197,5 +194,6 @@ object DriverConf { val DriverFlattenTables = "config.flatten-tables" val DriverClusterName = "config.cluster.name" val SSLEnabled = "akka.remote.netty.ssl.enable-ssl" - val AkkaClusterRecepcionistTunnelTimeout = "akka.contrib.cluster.receptionist.response-tunnel-receive-timeout" -} \ No newline at end of file + val AkkaClusterRecepcionistTunnelTimeout = + "akka.contrib.cluster.receptionist.response-tunnel-receive-timeout" +} diff --git a/driver/src/main/scala/com/stratio/crossdata/driver/metadata/JavaTableName.scala b/driver/src/main/scala/com/stratio/crossdata/driver/metadata/JavaTableName.scala index c81dd306d..f9bf16183 100644 --- a/driver/src/main/scala/com/stratio/crossdata/driver/metadata/JavaTableName.scala +++ b/driver/src/main/scala/com/stratio/crossdata/driver/metadata/JavaTableName.scala @@ -16,8 +16,8 @@ package com.stratio.crossdata.driver.metadata /** - * database can be empty ("") - */ + * database can be empty ("") + */ class JavaTableName(val tableName: java.lang.String, val database: java.lang.String) { override def equals(other: Any): Boolean = other match { diff --git a/driver/src/main/scala/com/stratio/crossdata/driver/querybuilder/ProjectedSelect.scala b/driver/src/main/scala/com/stratio/crossdata/driver/querybuilder/ProjectedSelect.scala index 905c5db7c..ed5641751 100644 --- a/driver/src/main/scala/com/stratio/crossdata/driver/querybuilder/ProjectedSelect.scala +++ b/driver/src/main/scala/com/stratio/crossdata/driver/querybuilder/ProjectedSelect.scala @@ -19,7 +19,8 @@ import com.stratio.crossdata.driver.querybuilder.dslentities.XDQLStatement class ProjectedSelect(selection: Expression*)(implicit context: String => String = x => x) { - def from(relation: Relation): SimpleRunnableQuery = new SimpleRunnableQuery(selection, relation, context) + def from(relation: Relation): SimpleRunnableQuery = + new SimpleRunnableQuery(selection, relation, context) def from(relations: Relation*): SimpleRunnableQuery = { val rel = relations.reduce((a: Relation, b: Relation) => a.join(b)) diff --git a/driver/src/main/scala/com/stratio/crossdata/driver/querybuilder/RunnableQuery.scala b/driver/src/main/scala/com/stratio/crossdata/driver/querybuilder/RunnableQuery.scala index 30f11542f..01827c9c6 100644 --- a/driver/src/main/scala/com/stratio/crossdata/driver/querybuilder/RunnableQuery.scala +++ b/driver/src/main/scala/com/stratio/crossdata/driver/querybuilder/RunnableQuery.scala @@ -20,25 +20,28 @@ import com.stratio.crossdata.driver.querybuilder.dslentities.{And, CombinationIn object RunnableQuery { implicit class RunnableQueryAsExpression(runnableQuery: RunnableQuery) extends Expression { - override private[querybuilder] def toXDQL: String = s"( ${runnableQuery.toXDQL})" + 
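DriverConf, reshaped above, layers programmatic settings on top of the configuration resolved from driver-reference.conf and system properties. A short sketch of the builder-style setters; host addresses and values are illustrative:

    import com.stratio.crossdata.driver.config.DriverConf
    import com.typesafe.config.ConfigValueFactory

    val conf = new DriverConf()
      .setClusterContactPoint("10.0.0.1:13420", "10.0.0.2:13420")
      .setTunnelTimeout(30)                                          // stored internally in milliseconds
      .set("akka.loglevel", ConfigValueFactory.fromAnyRef("INFO"))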
override private[querybuilder] def toXDQL: String = + s"( ${runnableQuery.toXDQL})" } implicit class RunnableQueryAsRelation(runnableQuery: RunnableQuery) extends Relation { - override private[querybuilder] def toXDQL: String = s"( ${runnableQuery.toXDQL})" + override private[querybuilder] def toXDQL: String = + s"( ${runnableQuery.toXDQL})" } } -abstract class RunnableQuery protected(protected val context: String => String, - protected val projections: Seq[Expression], - protected val relation: Relation, - protected val filters: Option[Predicate] = None, - protected val groupingExpressions: Seq[Expression] = Seq.empty, - protected val havingExpressions: Option[Predicate] = None, - protected val ordering: Option[SortCriteria] = None, - protected val limit: Option[Int] = None, - protected val composition: Option[CombinationInfo] = None - ) extends Combinable { +abstract class RunnableQuery protected (protected val context: String => String, + protected val projections: Seq[Expression], + protected val relation: Relation, + protected val filters: Option[Predicate] = None, + protected val groupingExpressions: Seq[Expression] = + Seq.empty, + protected val havingExpressions: Option[Predicate] = None, + protected val ordering: Option[SortCriteria] = None, + protected val limit: Option[Int] = None, + protected val composition: Option[CombinationInfo] = None) + extends Combinable { def where(condition: String): this.type = where(XDQLStatement(condition)) diff --git a/driver/src/main/scala/com/stratio/crossdata/driver/querybuilder/dslentities/XDQLStatement.scala b/driver/src/main/scala/com/stratio/crossdata/driver/querybuilder/dslentities/XDQLStatement.scala index 7319e6d7d..62b9e228e 100644 --- a/driver/src/main/scala/com/stratio/crossdata/driver/querybuilder/dslentities/XDQLStatement.scala +++ b/driver/src/main/scala/com/stratio/crossdata/driver/querybuilder/dslentities/XDQLStatement.scala @@ -17,6 +17,6 @@ package com.stratio.crossdata.driver.querybuilder.dslentities import com.stratio.crossdata.driver.querybuilder.{Relation, Predicate} -case class XDQLStatement(queryStr: String) extends Predicate with Relation{ +case class XDQLStatement(queryStr: String) extends Predicate with Relation { override private[querybuilder] def toXDQL: String = queryStr } diff --git a/driver/src/main/scala/com/stratio/crossdata/driver/querybuilder/dslentities/combination.scala b/driver/src/main/scala/com/stratio/crossdata/driver/querybuilder/dslentities/combination.scala index aa08601c9..acb31678e 100644 --- a/driver/src/main/scala/com/stratio/crossdata/driver/querybuilder/dslentities/combination.scala +++ b/driver/src/main/scala/com/stratio/crossdata/driver/querybuilder/dslentities/combination.scala @@ -18,7 +18,6 @@ package com.stratio.crossdata.driver.querybuilder.dslentities import com.stratio.crossdata.driver.querybuilder.CrossdataSQLStatement import com.stratio.crossdata.driver.querybuilder.RunnableQuery - object CombineType extends Enumeration { type CombineType = Value val UnionAll = Value("UNION ALL") @@ -27,9 +26,10 @@ object CombineType extends Enumeration { val UnionDistinct = Value("UNION DISTINCT") } - import com.stratio.crossdata.driver.querybuilder.dslentities.CombineType.CombineType -case class CombinationInfo(combineType: CombineType, runnableQuery: RunnableQuery) extends CrossdataSQLStatement { - override private[querybuilder] def toXDQL: String = s" ${combineType.toString} ${runnableQuery.toXDQL}" -} \ No newline at end of file +case class CombinationInfo(combineType: CombineType, runnableQuery: 
RunnableQuery) + extends CrossdataSQLStatement { + override private[querybuilder] def toXDQL: String = + s" ${combineType.toString} ${runnableQuery.toXDQL}" +} diff --git a/driver/src/main/scala/com/stratio/crossdata/driver/querybuilder/dslentities/expressions.scala b/driver/src/main/scala/com/stratio/crossdata/driver/querybuilder/dslentities/expressions.scala index fea8c9874..06691748f 100644 --- a/driver/src/main/scala/com/stratio/crossdata/driver/querybuilder/dslentities/expressions.scala +++ b/driver/src/main/scala/com/stratio/crossdata/driver/querybuilder/dslentities/expressions.scala @@ -17,7 +17,6 @@ package com.stratio.crossdata.driver.querybuilder.dslentities import com.stratio.crossdata.driver.querybuilder.{BinaryExpression, Expression, Predicate, UnaryExpression} - case class AsteriskExpression() extends Expression { override private[querybuilder] def toXDQL: String = "*" } @@ -89,7 +88,8 @@ case class Remainder(left: Expression, right: Expression) extends BinaryExpressi //Select expressions case class Distinct(expr: Expression*) extends Expression { - override private[querybuilder] def toXDQL: String = s" DISTINCT ${expr.map(_.toXDQL) mkString ","}" + override private[querybuilder] def toXDQL: String = + s" DISTINCT ${expr.map(_.toXDQL) mkString ","}" } case class Sum(expr: Expression) extends Expression { @@ -97,7 +97,8 @@ case class Sum(expr: Expression) extends Expression { } case class SumDistinct(expr: Expression) extends Expression { - override private[querybuilder] def toXDQL: String = s" sum( DISTINCT ${expr.toXDQL})" + override private[querybuilder] def toXDQL: String = + s" sum( DISTINCT ${expr.toXDQL})" } case class Count(expr: Expression) extends Expression { @@ -105,11 +106,13 @@ case class Count(expr: Expression) extends Expression { } case class CountDistinct(expr: Expression*) extends Expression { - override private[querybuilder] def toXDQL: String = s" count( DISTINCT ${expr.map(_.toXDQL) mkString ","})" + override private[querybuilder] def toXDQL: String = + s" count( DISTINCT ${expr.map(_.toXDQL) mkString ","})" } case class ApproxCountDistinct(expr: Expression, rsd: Double) extends Expression { - override private[querybuilder] def toXDQL: String = s" APPROXIMATE ($rsd) count ( DISTINCT ${expr.toXDQL})" + override private[querybuilder] def toXDQL: String = + s" APPROXIMATE ($rsd) count ( DISTINCT ${expr.toXDQL})" } case class Avg(expr: Expression) extends Expression { @@ -126,4 +129,4 @@ case class Max(expr: Expression) extends Expression { case class Abs(expr: Expression) extends Expression { override private[querybuilder] def toXDQL: String = s" abs(${expr.toXDQL})" -} \ No newline at end of file +} diff --git a/driver/src/main/scala/com/stratio/crossdata/driver/querybuilder/dslentities/identifiers.scala b/driver/src/main/scala/com/stratio/crossdata/driver/querybuilder/dslentities/identifiers.scala index 4f2075421..275e24e69 100644 --- a/driver/src/main/scala/com/stratio/crossdata/driver/querybuilder/dslentities/identifiers.scala +++ b/driver/src/main/scala/com/stratio/crossdata/driver/querybuilder/dslentities/identifiers.scala @@ -20,12 +20,14 @@ import com.stratio.crossdata.driver.querybuilder.{CrossdataSQLStatement, Express sealed trait Identifier extends Expression with Relation /** - * Identifier for tables and columns - */ + * Identifier for tables and columns + */ case class EntityIdentifier(id: String) extends Identifier { override private[querybuilder] def toXDQL: String = id } -case class AliasIdentifier(underlyingEntity: CrossdataSQLStatement, 
alias: String) extends Identifier { - override private[querybuilder] def toXDQL: String = s" ${underlyingEntity.toXDQL} AS $alias" -} \ No newline at end of file +case class AliasIdentifier(underlyingEntity: CrossdataSQLStatement, alias: String) + extends Identifier { + override private[querybuilder] def toXDQL: String = + s" ${underlyingEntity.toXDQL} AS $alias" +} diff --git a/driver/src/main/scala/com/stratio/crossdata/driver/querybuilder/dslentities/joins.scala b/driver/src/main/scala/com/stratio/crossdata/driver/querybuilder/dslentities/joins.scala index 8e075e0cc..4dcc767c6 100644 --- a/driver/src/main/scala/com/stratio/crossdata/driver/querybuilder/dslentities/joins.scala +++ b/driver/src/main/scala/com/stratio/crossdata/driver/querybuilder/dslentities/joins.scala @@ -35,7 +35,8 @@ import JoinType._ case class Join(private val left: Relation, private val right: Relation, private val joinType: JoinType, - private val condition: Option[Expression] = None) extends Relation { + private val condition: Option[Expression] = None) + extends Relation { def on(condition: String): Relation = on(XDQLStatement(condition)) @@ -43,5 +44,7 @@ case class Join(private val left: Relation, Join(left, right, joinType, Some(condition)) override private[querybuilder] def toXDQL: String = - s"${left.toXDQL} $joinType ${right.toXDQL}" + condition.map(c => s" ON ${c.toXDQL}").getOrElse("") -} \ No newline at end of file + s"${left.toXDQL} $joinType ${right.toXDQL}" + condition + .map(c => s" ON ${c.toXDQL}") + .getOrElse("") +} diff --git a/driver/src/main/scala/com/stratio/crossdata/driver/querybuilder/dslentities/predicates.scala b/driver/src/main/scala/com/stratio/crossdata/driver/querybuilder/dslentities/predicates.scala index c1b162979..bb3baea2c 100644 --- a/driver/src/main/scala/com/stratio/crossdata/driver/querybuilder/dslentities/predicates.scala +++ b/driver/src/main/scala/com/stratio/crossdata/driver/querybuilder/dslentities/predicates.scala @@ -18,8 +18,7 @@ package com.stratio.crossdata.driver.querybuilder.dslentities import com.stratio.crossdata.driver.querybuilder.{BinaryExpression, Expression, Predicate} // Logical predicates -case class And(left: Expression, right: Expression) extends BinaryExpression -with Predicate { +case class And(left: Expression, right: Expression) extends BinaryExpression with Predicate { override val tokenStr = "AND" @@ -31,8 +30,7 @@ with Predicate { } } -case class Or(left: Expression, right: Expression) extends BinaryExpression -with Predicate { +case class Or(left: Expression, right: Expression) extends BinaryExpression with Predicate { override val tokenStr = "OR" @@ -51,18 +49,15 @@ private[dslentities] trait EqualityCheckers extends BinaryExpression { } // Comparison predicates -case class Equal(left: Expression, right: Expression) extends EqualityCheckers -with Predicate { +case class Equal(left: Expression, right: Expression) extends EqualityCheckers with Predicate { override val tokenStr: String = "=" } -case class Different(left: Expression, right: Expression) extends EqualityCheckers -with Predicate { +case class Different(left: Expression, right: Expression) extends EqualityCheckers with Predicate { override val tokenStr: String = "<>" } -case class LessThan(left: Expression, right: Expression) extends BinaryExpression -with Predicate { +case class LessThan(left: Expression, right: Expression) extends BinaryExpression with Predicate { override val tokenStr: String = "<" @@ -72,8 +67,9 @@ with Predicate { } } -case class LessThanOrEqual(left: Expression, 
right: Expression) extends BinaryExpression -with Predicate { +case class LessThanOrEqual(left: Expression, right: Expression) + extends BinaryExpression + with Predicate { override val tokenStr: String = "<=" @@ -83,8 +79,9 @@ with Predicate { } } -case class GreaterThan(left: Expression, right: Expression) extends BinaryExpression //TODO: Review -with Predicate { +case class GreaterThan(left: Expression, right: Expression) + extends BinaryExpression //TODO: Review + with Predicate { override val tokenStr: String = ">" @@ -94,8 +91,9 @@ with Predicate { } } -case class GreaterThanOrEqual(left: Expression, right: Expression) extends BinaryExpression -with Predicate { +case class GreaterThanOrEqual(left: Expression, right: Expression) + extends BinaryExpression + with Predicate { override val tokenStr: String = ">=" @@ -106,17 +104,20 @@ with Predicate { } case class IsNull(expr: Expression) extends Predicate { - override private[querybuilder] def toXDQL: String = s" ${expr.toXDQL} IS NULL" + override private[querybuilder] def toXDQL: String = + s" ${expr.toXDQL} IS NULL" } case class IsNotNull(expr: Expression) extends Predicate { - override private[querybuilder] def toXDQL: String = s" ${expr.toXDQL} IS NOT NULL" + override private[querybuilder] def toXDQL: String = + s" ${expr.toXDQL} IS NOT NULL" } case class In(left: Expression, right: Expression*) extends Expression with Predicate { - override private[querybuilder] def toXDQL: String = s" ${left.toXDQL} IN ${right map (_.toXDQL) mkString("(", ",", ")")}" + override private[querybuilder] def toXDQL: String = + s" ${left.toXDQL} IN ${right map (_.toXDQL) mkString ("(", ",", ")")}" } case class Like(left: Expression, right: Expression) extends BinaryExpression with Predicate { override val tokenStr = "LIKE" -} \ No newline at end of file +} diff --git a/driver/src/main/scala/com/stratio/crossdata/driver/querybuilder/dslentities/sort.scala b/driver/src/main/scala/com/stratio/crossdata/driver/querybuilder/dslentities/sort.scala index 620a58a76..68ba483f9 100644 --- a/driver/src/main/scala/com/stratio/crossdata/driver/querybuilder/dslentities/sort.scala +++ b/driver/src/main/scala/com/stratio/crossdata/driver/querybuilder/dslentities/sort.scala @@ -35,14 +35,14 @@ object SortOrder { } - -class SortOrder private(val expression: Expression, - val direction: Option[SortDirection] = None) extends Expression { - override private[querybuilder] def toXDQL: String = s"${expression.toXDQL} ${direction.getOrElse("")}" +class SortOrder private (val expression: Expression, val direction: Option[SortDirection] = None) + extends Expression { + override private[querybuilder] def toXDQL: String = + s"${expression.toXDQL} ${direction.getOrElse("")}" } - -case class SortCriteria(global: Boolean, expressions: Seq[SortOrder]) extends CrossdataSQLStatement { +case class SortCriteria(global: Boolean, expressions: Seq[SortOrder]) + extends CrossdataSQLStatement { override private[querybuilder] def toXDQL: String = (if (global) "ORDER" else "SORT") + " BY " + expressions.map(_.toXDQL).mkString(", ") } diff --git a/driver/src/main/scala/com/stratio/crossdata/driver/querybuilder/inserts.scala b/driver/src/main/scala/com/stratio/crossdata/driver/querybuilder/inserts.scala index 9badff176..c7a21fb08 100644 --- a/driver/src/main/scala/com/stratio/crossdata/driver/querybuilder/inserts.scala +++ b/driver/src/main/scala/com/stratio/crossdata/driver/querybuilder/inserts.scala @@ -25,15 +25,17 @@ object InsertMode extends Enumeration { } class Insert { - def into(rel: 
Relation): ConfiguredInsert = new ConfiguredInsert(rel, InsertMode.INTO) - def overwrite(rel: Relation): ConfiguredInsert = new ConfiguredInsert(rel, InsertMode.OVERWRITE) + def into(rel: Relation): ConfiguredInsert = + new ConfiguredInsert(rel, InsertMode.INTO) + def overwrite(rel: Relation): ConfiguredInsert = + new ConfiguredInsert(rel, InsertMode.OVERWRITE) - private[Insert] class ConfiguredInsert(val target: Relation, mode: InsertMode.InsertMode) extends InitialSelectPhrases - { + private[Insert] class ConfiguredInsert(val target: Relation, mode: InsertMode.InsertMode) + extends InitialSelectPhrases { override protected def selectImp(projections: Seq[Expression]): ProjectedSelect = - new ProjectedSelect(projections:_*)(qStr => s"INSERT $mode ${target.toXDQL} $qStr") + new ProjectedSelect(projections: _*)(qStr => s"INSERT $mode ${target.toXDQL} $qStr") } private[Insert] class RunnableInsert -} \ No newline at end of file +} diff --git a/driver/src/main/scala/com/stratio/crossdata/driver/querybuilder/interfaces.scala b/driver/src/main/scala/com/stratio/crossdata/driver/querybuilder/interfaces.scala index d71b616bc..ab8c72d00 100644 --- a/driver/src/main/scala/com/stratio/crossdata/driver/querybuilder/interfaces.scala +++ b/driver/src/main/scala/com/stratio/crossdata/driver/querybuilder/interfaces.scala @@ -25,7 +25,8 @@ trait CrossdataSQLStatement { } object Expression { - implicit def exp2sortorder(exp: Expression): SortOrder = SortOrder(exp, Ascending) + implicit def exp2sortorder(exp: Expression): SortOrder = + SortOrder(exp, Ascending) } trait Expression extends CrossdataSQLStatement { @@ -68,17 +69,14 @@ trait Expression extends CrossdataSQLStatement { def desc: SortOrder = SortOrder(this, Descending) - def in(list: Expression*): Predicate = In(this, list: _*) def like(other: Expression): Predicate = Like(this, other) - def isNull: Predicate = IsNull(this) def isNotNull: Predicate = IsNotNull(this) - def as(alias: String): Identifier = AliasIdentifier(this, alias) def as(alias: Symbol): Identifier = AliasIdentifier(this, alias.name) @@ -87,23 +85,25 @@ trait Expression extends CrossdataSQLStatement { trait Predicate extends Expression - trait Relation extends CrossdataSQLStatement { def join(other: Relation): Join = innerJoin(other) def innerJoin(other: Relation): Join = Join(this, other, JoinType.Inner) - def leftOuterJoin(other: Relation): Join = Join(this, other, JoinType.LeftOuter) + def leftOuterJoin(other: Relation): Join = + Join(this, other, JoinType.LeftOuter) - def rightOuterJoin(other: Relation): Join = Join(this, other, JoinType.RightOuter) + def rightOuterJoin(other: Relation): Join = + Join(this, other, JoinType.RightOuter) - def fullOuterJoin(other: Relation): Join = Join(this, other, JoinType.FullOuter) + def fullOuterJoin(other: Relation): Join = + Join(this, other, JoinType.FullOuter) - def leftSemiJoin(other: Relation): Join = Join(this, other, JoinType.LeftSemi) + def leftSemiJoin(other: Relation): Join = + Join(this, other, JoinType.LeftSemi) } - trait UnaryExpression extends Expression { val child: Expression @@ -127,4 +127,4 @@ trait BinaryExpression extends Expression { override private[querybuilder] def toXDQL: String = Seq(left, right) map childExpansion mkString s" $tokenStr " -} \ No newline at end of file +} diff --git a/driver/src/main/scala/com/stratio/crossdata/driver/querybuilder/package.scala b/driver/src/main/scala/com/stratio/crossdata/driver/querybuilder/package.scala index 295bdc030..c7a2acfef 100644 --- 
a/driver/src/main/scala/com/stratio/crossdata/driver/querybuilder/package.scala +++ b/driver/src/main/scala/com/stratio/crossdata/driver/querybuilder/package.scala @@ -49,24 +49,30 @@ trait Literals { implicit def double2Literal(d: Double): Literal = Literal(d) implicit def string2Literal(s: String): Literal = Literal(s) implicit def date2Literal(d: Date): Literal = Literal(d) - implicit def bigDecimal2Literal(d: BigDecimal): Literal = Literal(d.underlying()) - implicit def bigDecimal2Literal(d: java.math.BigDecimal): Literal = Literal(d) + implicit def bigDecimal2Literal(d: BigDecimal): Literal = + Literal(d.underlying()) + implicit def bigDecimal2Literal(d: java.math.BigDecimal): Literal = + Literal(d) implicit def timestamp2Literal(t: Timestamp): Literal = Literal(t) implicit def binary2Literal(a: Array[Byte]): Literal = Literal(a) } trait Identifiers { - implicit def symbol2Identifier(s: Symbol): Identifier = EntityIdentifier(s.name) + implicit def symbol2Identifier(s: Symbol): Identifier = + EntityIdentifier(s.name) } trait InitialSelectPhrases { - def select(projections: Expression*): ProjectedSelect = selectImp(projections) + def select(projections: Expression*): ProjectedSelect = + selectImp(projections) - def select(projections: String): ProjectedSelect = selectImp(XDQLStatement(projections)::Nil) + def select(projections: String): ProjectedSelect = + selectImp(XDQLStatement(projections) :: Nil) - def selectAll: ProjectedSelect = selectImp(AsteriskExpression()::Nil) + def selectAll: ProjectedSelect = selectImp(AsteriskExpression() :: Nil) - protected def selectImp(projections: Seq[Expression]): ProjectedSelect = new ProjectedSelect(projections:_*)(x => x) + protected def selectImp(projections: Seq[Expression]): ProjectedSelect = + new ProjectedSelect(projections: _*)(x => x) } @@ -85,7 +91,8 @@ trait ExpressionOperators { def countDistinct(e: Expression*): Expression = CountDistinct(e: _*) - def approxCountDistinct(e: Expression, rsd: Double): Expression = ApproxCountDistinct(e, rsd) + def approxCountDistinct(e: Expression, rsd: Double): Expression = + ApproxCountDistinct(e, rsd) def avg(e: Expression): Expression = Avg(e) @@ -98,7 +105,9 @@ trait ExpressionOperators { def all: Expression = AsteriskExpression() } -package object querybuilder extends InitialSelectPhrases with InitialInsertPhrases - with Literals - with Identifiers - with ExpressionOperators \ No newline at end of file +package object querybuilder + extends InitialSelectPhrases + with InitialInsertPhrases + with Literals + with Identifiers + with ExpressionOperators diff --git a/driver/src/main/scala/com/stratio/crossdata/driver/querybuilder/states.scala b/driver/src/main/scala/com/stratio/crossdata/driver/querybuilder/states.scala index 63a360c7d..0df47b024 100644 --- a/driver/src/main/scala/com/stratio/crossdata/driver/querybuilder/states.scala +++ b/driver/src/main/scala/com/stratio/crossdata/driver/querybuilder/states.scala @@ -17,15 +17,14 @@ package com.stratio.crossdata.driver.querybuilder import com.stratio.crossdata.driver.querybuilder.dslentities.{CombinationInfo, SortCriteria, XDQLStatement} - -class SimpleRunnableQuery private(context: String => String, - projections: Seq[Expression], - relation: Relation, - filters: Option[Predicate] = None) - extends RunnableQuery(context, projections, relation, filters) - with Sortable - with Limitable - with Groupable { +class SimpleRunnableQuery private (context: String => String, + projections: Seq[Expression], + relation: Relation, + filters: Option[Predicate] = 
None) + extends RunnableQuery(context, projections, relation, filters) + with Sortable + with Limitable + with Groupable { def this(projections: Seq[Expression], relation: Relation, context: String => String) = this(context, projections, relation, None) @@ -33,7 +32,8 @@ class SimpleRunnableQuery private(context: String => String, // It has to be abstract (simple runnable query has transitions) and concrete override def where(condition: Predicate): this.type = //Not knew alternatives to `asInstanceOf`: http://stackoverflow.com/a/791157/1893995 - new SimpleRunnableQuery(context, projections, relation, Some(combinePredicates(condition))).asInstanceOf[this.type] + new SimpleRunnableQuery(context, projections, relation, Some(combinePredicates(condition))) + .asInstanceOf[this.type] } @@ -42,27 +42,26 @@ class GroupedQuery(context: String => String, relation: Relation, filters: Option[Predicate] = None, groupingExpressions: Seq[Expression]) - extends RunnableQuery(context, projections, relation, filters, groupingExpressions) - with Sortable - with Limitable { + extends RunnableQuery(context, projections, relation, filters, groupingExpressions) + with Sortable + with Limitable { - def having(expression: String): HavingQuery = having(XDQLStatement(expression)) + def having(expression: String): HavingQuery = + having(XDQLStatement(expression)) def having(expression: Predicate): HavingQuery = new HavingQuery(context, projections, relation, filters, groupingExpressions, expression) - override def where(condition: Predicate): this.type = //Not knew alternatices to `asInstanceOf`: http://stackoverflow.com/a/791157/1893995 new GroupedQuery( - context, - projections, - relation, - Some(combinePredicates(condition)), - groupingExpressions + context, + projections, + relation, + Some(combinePredicates(condition)), + groupingExpressions ).asInstanceOf[this.type] - } class HavingQuery(context: String => String, @@ -71,20 +70,23 @@ class HavingQuery(context: String => String, filters: Option[Predicate] = None, groupingExpressions: Seq[Expression], havingExpressions: Predicate) - extends RunnableQuery( - context, projections, relation, - filters, groupingExpressions, Some(havingExpressions)) - with Sortable - with Limitable { + extends RunnableQuery(context, + projections, + relation, + filters, + groupingExpressions, + Some(havingExpressions)) + with Sortable + with Limitable { override def where(condition: Predicate): this.type = new HavingQuery( - context, - projections, - relation, - Some(combinePredicates(condition)), - groupingExpressions, - havingExpressions + context, + projections, + relation, + Some(combinePredicates(condition)), + groupingExpressions, + havingExpressions ).asInstanceOf[this.type] } @@ -96,20 +98,26 @@ class SortedQuery(context: String => String, groupingExpressions: Seq[Expression] = Seq.empty, havingExpressions: Option[Predicate] = None, ordering: SortCriteria) - extends RunnableQuery( - context, projections, relation, - filters, groupingExpressions, havingExpressions, Some(ordering) - ) with Limitable { + extends RunnableQuery( + context, + projections, + relation, + filters, + groupingExpressions, + havingExpressions, + Some(ordering) + ) + with Limitable { override def where(condition: Predicate): this.type = new SortedQuery( - context, - projections, - relation, - Some(combinePredicates(condition)), - groupingExpressions, - havingExpressions, - ordering + context, + projections, + relation, + Some(combinePredicates(condition)), + groupingExpressions, + havingExpressions, + 
ordering ).asInstanceOf[this.type] } @@ -122,20 +130,25 @@ class LimitedQuery(context: String => String, havingExpressions: Option[Predicate] = None, ordering: Option[SortCriteria], limit: Int) - extends RunnableQuery( - context, projections, relation, - filters, groupingExpressions, havingExpressions, ordering, Some(limit)) { + extends RunnableQuery(context, + projections, + relation, + filters, + groupingExpressions, + havingExpressions, + ordering, + Some(limit)) { override def where(condition: Predicate): this.type = new LimitedQuery( - context, - projections, - relation, - Some(combinePredicates(condition)), - groupingExpressions, - havingExpressions, - ordering, - limit + context, + projections, + relation, + Some(combinePredicates(condition)), + groupingExpressions, + havingExpressions, + ordering, + limit ).asInstanceOf[this.type] } @@ -149,11 +162,20 @@ class CombinedQuery(context: String => String, ordering: Option[SortCriteria], limit: Option[Int], combinationInfo: CombinationInfo) - extends RunnableQuery( - context, projections, relation, - filters, groupingExpressions, havingExpressions, ordering, limit, Some(combinationInfo) - ) with Combinable { + extends RunnableQuery( + context, + projections, + relation, + filters, + groupingExpressions, + havingExpressions, + ordering, + limit, + Some(combinationInfo) + ) + with Combinable { + + def where(condition: Predicate): this.type = + throw new Error("Predicates cannot by applied to combined queries") - def where(condition: Predicate): this.type = throw new Error("Predicates cannot by applied to combined queries") - -} \ No newline at end of file +} diff --git a/driver/src/main/scala/com/stratio/crossdata/driver/querybuilder/transitions.scala b/driver/src/main/scala/com/stratio/crossdata/driver/querybuilder/transitions.scala index a8e53c315..5e3b1f3ff 100644 --- a/driver/src/main/scala/com/stratio/crossdata/driver/querybuilder/transitions.scala +++ b/driver/src/main/scala/com/stratio/crossdata/driver/querybuilder/transitions.scala @@ -24,7 +24,8 @@ import com.stratio.crossdata.driver.querybuilder.dslentities.CombinationInfo trait Groupable { this: RunnableQuery => - def groupBy(groupingExpressions: String): GroupedQuery = groupBy(XDQLStatement(groupingExpressions)) + def groupBy(groupingExpressions: String): GroupedQuery = + groupBy(XDQLStatement(groupingExpressions)) def groupBy(groupingExpressions: Expression*): GroupedQuery = new GroupedQuery(context, projections, relation, filters, groupingExpressions) @@ -38,23 +39,25 @@ trait Sortable { def sortBy(ordering: Symbol): SortedQuery = sortBy(SortOrder(ordering)) - def sortBy(ordering: SortOrder*): SortedQuery = orderOrSortBy(global = false, ordering) + def sortBy(ordering: SortOrder*): SortedQuery = + orderOrSortBy(global = false, ordering) def orderBy(ordering: Symbol): SortedQuery = orderBy(SortOrder(ordering)) def orderBy(ordering: String): SortedQuery = orderBy(SortOrder(ordering)) - def orderBy(ordering: SortOrder*): SortedQuery = orderOrSortBy(global = true, ordering) + def orderBy(ordering: SortOrder*): SortedQuery = + orderOrSortBy(global = true, ordering) private def orderOrSortBy(global: Boolean, ordering: Seq[SortOrder]): SortedQuery = new SortedQuery( - context, - projections, - relation, - filters, - groupingExpressions, - havingExpressions, - SortCriteria(global, ordering) + context, + projections, + relation, + filters, + groupingExpressions, + havingExpressions, + SortCriteria(global, ordering) ) } @@ -62,15 +65,15 @@ trait Sortable { trait Limitable { this: 
RunnableQuery => - def limit(value: Int): LimitedQuery = new LimitedQuery( - context, - projections, - relation, - filters, - groupingExpressions, - havingExpressions, - ordering, - value) + def limit(value: Int): LimitedQuery = + new LimitedQuery(context, + projections, + relation, + filters, + groupingExpressions, + havingExpressions, + ordering, + value) } trait Combinable extends CrossdataSQLStatement { @@ -83,16 +86,15 @@ trait Combinable extends CrossdataSQLStatement { import CombineType.CombineType def unionAll(newQuery: RunnableQuery): CombinedQuery = - generateCombinedQuery { - computeCombinationInfo(newQuery, UnionAll, query => query.unionAll(newQuery)) - } + generateCombinedQuery { + computeCombinationInfo(newQuery, UnionAll, query => query.unionAll(newQuery)) + } def unionDistinct(newQuery: RunnableQuery): CombinedQuery = generateCombinedQuery { computeCombinationInfo(newQuery, UnionDistinct, query => query.unionDistinct(newQuery)) } - def intersect(newQuery: RunnableQuery): CombinedQuery = generateCombinedQuery { computeCombinationInfo(newQuery, Intersect, query => query.intersect(newQuery)) @@ -103,42 +105,41 @@ trait Combinable extends CrossdataSQLStatement { computeCombinationInfo(newQuery, Except, query => query.except(newQuery)) } - /** - * Computes the new combination info after receiving a new query. - * - * See example below: - * q1 UNION ALL q2 UNION DISTINCT q3 - * 1º) "q1 UNION ALL q2" generates a combination info for q1 (UNION ALL, q2) - * 2º) "q1_q2 UNION DISTINCT q3" generates: - * q1 should have a combination info (UNION ALL, q2 UNION DISTINCT q3) - * q2 should have a combination info (UNION DISTINCT, q3) - * - * @param newQuery incoming query - * @param newCombineType the incoming [[CombineType]] - * @param childCombination function to generate a combined query from a runnable query. - * It will be applied if the query is already a combined query - * @return the new combination info - */ + * Computes the new combination info after receiving a new query. + * + * See example below: + * q1 UNION ALL q2 UNION DISTINCT q3 + * 1º) "q1 UNION ALL q2" generates a combination info for q1 (UNION ALL, q2) + * 2º) "q1_q2 UNION DISTINCT q3" generates: + * q1 should have a combination info (UNION ALL, q2 UNION DISTINCT q3) + * q2 should have a combination info (UNION DISTINCT, q3) + * + * @param newQuery incoming query + * @param newCombineType the incoming [[CombineType]] + * @param childCombination function to generate a combined query from a runnable query. 
+ * It will be applied if the query is already a combined query + * @return the new combination info + */ private def computeCombinationInfo( //TODO: Simplify this - newQuery: RunnableQuery, - newCombineType: CombineType, - childCombination: RunnableQuery => CombinedQuery): CombinationInfo = + newQuery: RunnableQuery, + newCombineType: CombineType, + childCombination: RunnableQuery => CombinedQuery): CombinationInfo = composition map { - case CombinationInfo(combType, previous) => CombinationInfo(combType, childCombination(previous)) + case CombinationInfo(combType, previous) => + CombinationInfo(combType, childCombination(previous)) } getOrElse { CombinationInfo(newCombineType, newQuery) } private def generateCombinedQuery(combinationInfo: CombinationInfo): CombinedQuery = - new CombinedQuery( - context, - projections, - relation, - filters, - groupingExpressions, - havingExpressions, - ordering, - limit, - combinationInfo) + new CombinedQuery(context, + projections, + relation, + filters, + groupingExpressions, + havingExpressions, + ordering, + limit, + combinationInfo) } diff --git a/driver/src/main/scala/com/stratio/crossdata/driver/session/SessionManager.scala b/driver/src/main/scala/com/stratio/crossdata/driver/session/SessionManager.scala index 59c01dfb2..31d62279e 100644 --- a/driver/src/main/scala/com/stratio/crossdata/driver/session/SessionManager.scala +++ b/driver/src/main/scala/com/stratio/crossdata/driver/session/SessionManager.scala @@ -28,4 +28,3 @@ object SessionManager { } case class Authentication(user: String, password: String) - diff --git a/driver/src/main/scala/com/stratio/crossdata/driver/shell/BasicShell.scala b/driver/src/main/scala/com/stratio/crossdata/driver/shell/BasicShell.scala index b96e6160a..fe5620f09 100644 --- a/driver/src/main/scala/com/stratio/crossdata/driver/shell/BasicShell.scala +++ b/driver/src/main/scala/com/stratio/crossdata/driver/shell/BasicShell.scala @@ -42,7 +42,6 @@ object BasicShell extends App { createHistoryDirectory(HistoryPath) - private def getLine(reader: ConsoleReader): Option[String] = Try(reader.readLine).recoverWith { case uie: UserInterruptException => @@ -50,14 +49,12 @@ object BasicShell extends App { Failure(uie) } toOption - private def checkEnd(line: Option[String]): Boolean = line.isEmpty || { val trimmedLine = line.get trimmedLine.equalsIgnoreCase("exit") || trimmedLine.equalsIgnoreCase("quit") } - private def close(console: ConsoleReader): Unit = { logger.info("Saving history...") val pw = new PrintWriter(PersistentHistory) @@ -68,7 +65,7 @@ object BasicShell extends App { } def loadHistory(console: ConsoleReader): Unit = { - if(PersistentHistory.exists){ + if (PersistentHistory.exists) { logger.info("Loading history...") console.setHistory(new FileHistory(PersistentHistory)) } else { @@ -85,7 +82,6 @@ object BasicShell extends App { loadHistory(console) } - initialize(console) private def runConsole(console: ConsoleReader): Unit = { @@ -93,7 +89,8 @@ object BasicShell extends App { console.println() console.println("+-----------------+-------------------------+---------------------------+") - console.println(s"| CROSSDATA ${crossdata.CrossdataVersion} | Powered by Apache Spark | Easy access to big things |") + console.println( + s"| CROSSDATA ${crossdata.CrossdataVersion} | Powered by Apache Spark | Easy access to big things |") console.println("+-----------------+-------------------------+---------------------------+") console.println() console.flush @@ -125,7 +122,7 @@ object BasicShell extends App { 
runConsole(console) - sys addShutdownHook{ + sys addShutdownHook { close(console) } diff --git a/driver/src/main/scala/com/stratio/crossdata/driver/util/HttpClient.scala b/driver/src/main/scala/com/stratio/crossdata/driver/util/HttpClient.scala index f764b6e11..a95b14418 100644 --- a/driver/src/main/scala/com/stratio/crossdata/driver/util/HttpClient.scala +++ b/driver/src/main/scala/com/stratio/crossdata/driver/util/HttpClient.scala @@ -53,27 +53,25 @@ class HttpClient(ctx: HttpClientContext) { val serverHttp = config.getCrossdataServerHttp val sessionUUID = session.id - for ( - request <- createRequest(s"http://$serverHttp/upload/$sessionUUID", new File(path)); - response <- http.singleRequest(request) map { - case res@HttpResponse(code, _, _, _) if code != StatusCodes.OK => - throw new RuntimeException(s"Request failed, response code: $code") - case other => other - }; - strictEntity <- response.entity.toStrict(5 seconds) - ) yield strictEntity.data.decodeString("UTF-8") + for (request <- createRequest(s"http://$serverHttp/upload/$sessionUUID", new File(path)); + response <- http.singleRequest(request) map { + case res @ HttpResponse(code, _, _, _) if code != StatusCodes.OK => + throw new RuntimeException(s"Request failed, response code: $code") + case other => other + }; + strictEntity <- response.entity.toStrict(5 seconds)) + yield strictEntity.data.decodeString("UTF-8") } private def createEntity(file: File): Future[RequestEntity] = { require(file.exists()) val fileIO = FileIO.fromFile(file) - val formData = - Multipart.FormData( + val formData = Multipart.FormData( Source.single( - Multipart.FormData.BodyPart( - "fileChunk", - HttpEntity(ContentTypes.`application/octet-stream`, file.length(), fileIO), - Map("filename" -> file.getName)))) + Multipart.FormData.BodyPart( + "fileChunk", + HttpEntity(ContentTypes.`application/octet-stream`, file.length(), fileIO), + Map("filename" -> file.getName)))) Marshal(formData).to[RequestEntity] } diff --git a/elasticsearch/src/main/scala/com/stratio/crossdata/connector/elasticsearch/DefaultSource.scala b/elasticsearch/src/main/scala/com/stratio/crossdata/connector/elasticsearch/DefaultSource.scala index 46766887d..5cfd2a862 100644 --- a/elasticsearch/src/main/scala/com/stratio/crossdata/connector/elasticsearch/DefaultSource.scala +++ b/elasticsearch/src/main/scala/com/stratio/crossdata/connector/elasticsearch/DefaultSource.scala @@ -17,7 +17,7 @@ *under the License. 
** *Modifications and adaptations - Copyright (C) 2015 Stratio (http://stratio.com) -*/ + */ package com.stratio.crossdata.connector.elasticsearch import com.sksamuel.elastic4s.ElasticDsl._ @@ -38,7 +38,6 @@ import org.elasticsearch.spark.sql.ElasticsearchXDRelation import scala.util.Try - object DefaultSource { val DataSourcePushDown: String = "es.internal.spark.sql.pushdown" val DataSourcePushDownStrict: String = "es.internal.spark.sql.pushdown.strict" @@ -48,14 +47,16 @@ object DefaultSource { } /** - * This class is used by Spark to create a new [[ElasticsearchXDRelation]] - */ -class DefaultSource extends RelationProvider with SchemaRelationProvider - with CreatableRelationProvider - with TableInventory - with DataSourceRegister - with TableManipulation - with SparkLoggerComponent { + * This class is used by Spark to create a new [[ElasticsearchXDRelation]] + */ +class DefaultSource + extends RelationProvider + with SchemaRelationProvider + with CreatableRelationProvider + with TableInventory + with DataSourceRegister + with TableManipulation + with SparkLoggerComponent { import DefaultSource._ @@ -63,15 +64,20 @@ class DefaultSource extends RelationProvider with SchemaRelationProvider override def shortName(): String = "elasticsearch" - override def createRelation(@transient sqlContext: SQLContext, parameters: Map[String, String]): BaseRelation = { + override def createRelation(@transient sqlContext: SQLContext, + parameters: Map[String, String]): BaseRelation = { new ElasticsearchXDRelation(params(parameters), sqlContext) } - override def createRelation(@transient sqlContext: SQLContext, parameters: Map[String, String], schema: StructType): BaseRelation = { + override def createRelation(@transient sqlContext: SQLContext, + parameters: Map[String, String], + schema: StructType): BaseRelation = { new ElasticsearchXDRelation(params(parameters), sqlContext, Some(schema)) } - override def createRelation(@transient sqlContext: SQLContext, mode: SaveMode, parameters: Map[String, String], + override def createRelation(@transient sqlContext: SQLContext, + mode: SaveMode, + parameters: Map[String, String], data: DataFrame): BaseRelation = { val relation = new ElasticsearchXDRelation(params(parameters), sqlContext, Some(data.schema)) @@ -82,7 +88,9 @@ class DefaultSource extends RelationProvider with SchemaRelationProvider relation.insert(data, overwrite = true) case ErrorIfExists => if (relation.isEmpty()) relation.insert(data, overwrite = false) - else throw new EsHadoopIllegalStateException(s"Index ${relation.cfg.getResourceWrite} already exists") + else + throw new EsHadoopIllegalStateException( + s"Index ${relation.cfg.getResourceWrite} already exists") case Ignore => if (relation.isEmpty()) { relation.insert(data, overwrite = false) @@ -92,43 +100,51 @@ class DefaultSource extends RelationProvider with SchemaRelationProvider } /** - * Validates the input parameters, defined in https://www.elastic.co/guide/en/elasticsearch/hadoop/current/configuration.html + * Validates the input parameters, defined in https://www.elastic.co/guide/en/elasticsearch/hadoop/current/configuration.html * * @param parameters a Map with the configurations parameters - * @return the validated map. - */ + * @return the validated map. + */ private def params(parameters: Map[String, String]) = { // '.' 
seems to be problematic when specifying the options - val params = parameters.map { case (k, v) => (k.replace('_', '.'), v) }.map { case (k, v) => - if (k.startsWith("es.")) (k, v) - else if (k == "path") (ConfigurationOptions.ES_RESOURCE, v) - else if (k == "pushdown") (DataSourcePushDown, v) - else if (k == "strict") (DataSourcePushDownStrict, v) - else ("es." + k, v) + val params = parameters.map { case (k, v) => (k.replace('_', '.'), v) }.map { + case (k, v) => + if (k.startsWith("es.")) (k, v) + else if (k == "path") (ConfigurationOptions.ES_RESOURCE, v) + else if (k == "pushdown") (DataSourcePushDown, v) + else if (k == "strict") (DataSourcePushDownStrict, v) + else ("es." + k, v) } // validate path - if (params.get(ConfigurationOptions.ES_RESOURCE_READ).orElse(params.get(ConfigurationOptions.ES_RESOURCE)).isEmpty) - throw new EsHadoopIllegalArgumentException("resource must be specified for Elasticsearch resources.") + if (params + .get(ConfigurationOptions.ES_RESOURCE_READ) + .orElse(params.get(ConfigurationOptions.ES_RESOURCE)) + .isEmpty) + throw new EsHadoopIllegalArgumentException( + "resource must be specified for Elasticsearch resources.") params } /** - * @inheritdoc - */ - override def generateConnectorOpts(item: Table, userOpts: Map[String, String]): Map[String, String] = Map( - ES_RESOURCE -> s"${item.database.get}/${item.tableName}" - ) ++ userOpts + * @inheritdoc + */ + override def generateConnectorOpts(item: Table, + userOpts: Map[String, String]): Map[String, String] = + Map( + ES_RESOURCE -> s"${item.database.get}/${item.tableName}" + ) ++ userOpts /** - * @inheritdoc - */ + * @inheritdoc + */ override def listTables(context: SQLContext, options: Map[String, String]): Seq[Table] = { Seq(ElasticCluster).foreach { opName => - if (!options.contains(opName)) sys.error( s"""Option "$opName" is mandatory for IMPORT TABLES""") + if (!options.contains(opName)) + sys.error(s"""Option "$opName" is mandatory for IMPORT TABLES""") } ElasticSearchConnectionUtils.listTypes(params(options)) @@ -140,8 +156,11 @@ class DefaultSource extends RelationProvider with SchemaRelationProvider schema: StructType, options: Map[String, String]): Option[Table] = { - val (index, typeName) = ElasticSearchConnectionUtils.extractIndexAndType(options).orElse(databaseName.map((_, tableName))). 
- getOrElse(throw new RuntimeException(s"$ES_RESOURCE is required when running CREATE EXTERNAL TABLE")) + val (index, typeName) = ElasticSearchConnectionUtils + .extractIndexAndType(options) + .orElse(databaseName.map((_, tableName))) + .getOrElse(throw new RuntimeException( + s"$ES_RESOURCE is required when running CREATE EXTERNAL TABLE")) // TODO specified mapping is not the same that the resulting mapping inferred once some data is indexed val elasticSchema = schema.map { field => @@ -160,7 +179,7 @@ class DefaultSource extends RelationProvider with SchemaRelationProvider try { ElasticSearchConnectionUtils.withClientDo(options) { client => - if(!client.execute(indexExists(index)).await.isExists){ + if (!client.execute(indexExists(index)).await.isExists) { client.execute { createIndex(index).mappings() }.await @@ -182,21 +201,21 @@ class DefaultSource extends RelationProvider with SchemaRelationProvider } } - override def dropExternalTable(context: SQLContext, - options: Map[String, String]): Try[Unit] = { + override def dropExternalTable(context: SQLContext, options: Map[String, String]): Try[Unit] = { - if(ElasticSearchConnectionUtils.numberOfTypes(options) == 1) { + if (ElasticSearchConnectionUtils.numberOfTypes(options) == 1) { val (index, _) = ElasticSearchConnectionUtils.extractIndexAndType(options).get Try { - ElasticSearchConnectionUtils.withClientDo(options){ client => + ElasticSearchConnectionUtils.withClientDo(options) { client => client.execute { deleteIndex(index) }.await } } } else { - sys.error("Cannot remove table from ElasticSearch if more than one table is persisted in the same index. Please remove it natively") + sys.error( + "Cannot remove table from ElasticSearch if more than one table is persisted in the same index. Please remove it natively") } } diff --git a/elasticsearch/src/main/scala/com/stratio/crossdata/connector/elasticsearch/ElasticSearchConnectionUtils.scala b/elasticsearch/src/main/scala/com/stratio/crossdata/connector/elasticsearch/ElasticSearchConnectionUtils.scala index 409728b32..e590d6eea 100644 --- a/elasticsearch/src/main/scala/com/stratio/crossdata/connector/elasticsearch/ElasticSearchConnectionUtils.scala +++ b/elasticsearch/src/main/scala/com/stratio/crossdata/connector/elasticsearch/ElasticSearchConnectionUtils.scala @@ -36,7 +36,6 @@ object ElasticSearchConnectionUtils { } } - private def buildClient(parameters: Map[String, String]): ElasticClient = { val host: String = parameters.getOrElse(ES_NODES, ES_NODES_DEFAULT) // TODO support for multiple host, no documentation found with expected format. 
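// ---------------------------------------------------------------------------
// Editor's illustrative sketch (not part of this patch). The hunks above
// resolve the target Elasticsearch index and type either from the
// "es.resource" option (expected as "index/type") or, failing that, from the
// Crossdata database/table pair. The standalone snippet below mirrors that
// resolution logic so it can be tried in a REPL; the names
// ResolveIndexAndTypeSketch, resolveIndexAndType and EsResource are
// illustrative only and do not exist in the connector.
object ResolveIndexAndTypeSketch {

  val EsResource = "es.resource" // same key the connector reads (ConfigurationOptions.ES_RESOURCE)

  // Mirrors extractIndexAndType: split "index/type" and require both parts.
  def extractIndexAndType(options: Map[String, String]): Option[(String, String)] =
    options.get(EsResource).map { indexType =>
      val parts = indexType.split("/")
      require(parts.length == 2, s"$EsResource option has an invalid format")
      (parts(0), parts(1))
    }

  // Mirrors the fallback used by createExternalTable: the option wins,
  // otherwise the (databaseName, tableName) pair is used.
  def resolveIndexAndType(options: Map[String, String],
                          databaseName: Option[String],
                          tableName: String): (String, String) =
    extractIndexAndType(options)
      .orElse(databaseName.map((_, tableName)))
      .getOrElse(throw new RuntimeException(
        s"$EsResource is required when running CREATE EXTERNAL TABLE"))

  def main(args: Array[String]): Unit = {
    // Taken from the option when present...
    println(resolveIndexAndType(Map(EsResource -> "highschool/students"), None, "students"))
    // ...and derived from (database, table) when it is not.
    println(resolveIndexAndType(Map.empty, Some("highschool"), "students"))
  }
}
// ---------------------------------------------------------------------------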
@@ -50,9 +49,9 @@ object ElasticSearchConnectionUtils { } def extractIndexAndType(options: Map[String, String]): Option[(String, String)] = { - options.get(ES_RESOURCE).map{ indexType => + options.get(ES_RESOURCE).map { indexType => val indexTypeArray = indexType.split("/") - require(indexTypeArray.size==2, s"$ES_RESOURCE option has an invalid format") + require(indexTypeArray.size == 2, s"$ES_RESOURCE option has an invalid format") (indexTypeArray(0), indexTypeArray(1)) } } @@ -61,10 +60,10 @@ object ElasticSearchConnectionUtils { val adminClient = buildClient(options).admin.indices() - val indexType: Option[(String, String)] = extractIndexAndType(options) + val indexType: Option[(String, String)] = extractIndexAndType(options) val index = indexType.map(_._1).orElse(options.get(ElasticIndex)) - index.fold(listAllIndexTypes(adminClient)){indexName => + index.fold(listAllIndexTypes(adminClient)) { indexName => listIndexTypes(adminClient, indexName, indexType.map(_._2)) } @@ -73,9 +72,12 @@ object ElasticSearchConnectionUtils { import collection.JavaConversions._ private def listAllIndexTypes(adminClient: IndicesAdminClient): Seq[Table] = { - val mappings: ImmutableOpenMap[String, ImmutableOpenMap[String, MappingMetaData]] = adminClient.prepareGetIndex().get().mappings - mappings.keys().flatMap { index => - getIndexDetails(index.value, mappings.get(index.value)) + val mappings: ImmutableOpenMap[String, ImmutableOpenMap[String, MappingMetaData]] = + adminClient.prepareGetIndex().get().mappings + mappings + .keys() + .flatMap { index => + getIndexDetails(index.value, mappings.get(index.value)) } toSeq } @@ -83,30 +85,41 @@ object ElasticSearchConnectionUtils { def numberOfTypes(options: Map[String, String]): Int = { val adminClient = buildClient(options).admin.indices() - val indexType: Option[(String, String)] = extractIndexAndType(options) - val index = indexType.map(_._1).orElse(options.get(ElasticIndex)) getOrElse sys.error("Index not found") + val indexType: Option[(String, String)] = extractIndexAndType(options) + val index = indexType.map(_._1).orElse(options.get(ElasticIndex)) getOrElse sys.error( + "Index not found") adminClient.prepareGetIndex().addIndices(index).get().mappings().get(index).size() } - private def listIndexTypes(adminClient: IndicesAdminClient, indexName: String, typeName: Option[String] = None): Seq[Table] = { + private def listIndexTypes(adminClient: IndicesAdminClient, + indexName: String, + typeName: Option[String] = None): Seq[Table] = { val elasticBuilder = adminClient.prepareGetIndex().addIndices(indexName) val elasticBuilderWithTypes = typeName.fold(elasticBuilder)(elasticBuilder.addTypes(_)) - val mappings: ImmutableOpenMap[String, ImmutableOpenMap[String, MappingMetaData]] = elasticBuilderWithTypes.get().mappings + val mappings: ImmutableOpenMap[String, ImmutableOpenMap[String, MappingMetaData]] = + elasticBuilderWithTypes.get().mappings getIndexDetails(indexName, mappings.get(indexName)) } - - private def getIndexDetails(indexName:String, indexData: ImmutableOpenMap[String, MappingMetaData]): Seq[Table] ={ - indexData.keys().map(typeES => new Table(typeES.value, Some(indexName), Some(buildStructType(indexData.get(typeES.value))))).toSeq + private def getIndexDetails(indexName: String, + indexData: ImmutableOpenMap[String, MappingMetaData]): Seq[Table] = { + indexData + .keys() + .map( + typeES => + new Table(typeES.value, + Some(indexName), + Some(buildStructType(indexData.get(typeES.value))))) + .toSeq } - private def convertType(typeName:String): DataType 
= { + private def convertType(typeName: String): DataType = { typeName match { - case "string"=> StringType + case "string" => StringType case "integer" => IntegerType case "date" => DateType case "boolean" => BooleanType @@ -114,17 +127,24 @@ object ElasticSearchConnectionUtils { case "long" => LongType case "float" => FloatType case "null" => NullType - case _ => throw new RuntimeException (s"The type $typeName isn't supported yet in Elasticsearch connector.") + case _ => + throw new RuntimeException( + s"The type $typeName isn't supported yet in Elasticsearch connector.") } } - private def buildStructType(mapping: MappingMetaData): StructType ={ + private def buildStructType(mapping: MappingMetaData): StructType = { - val esFields = mapping.sourceAsMap().get("properties").asInstanceOf[java.util.LinkedHashMap[String,java.util.LinkedHashMap[String, String]]].toMap + val esFields = mapping + .sourceAsMap() + .get("properties") + .asInstanceOf[java.util.LinkedHashMap[String, java.util.LinkedHashMap[String, String]]] + .toMap val fields: Seq[StructField] = esFields.map { - case (colName, propertyValueMap) => StructField(colName, convertType(propertyValueMap.get("type")), false) + case (colName, propertyValueMap) => + StructField(colName, convertType(propertyValueMap.get("type")), false) }(collection.breakOut) StructType(fields) diff --git a/elasticsearch/src/main/scala/com/stratio/crossdata/connector/elasticsearch/ElasticSearchQueryProcessor.scala b/elasticsearch/src/main/scala/com/stratio/crossdata/connector/elasticsearch/ElasticSearchQueryProcessor.scala index 9f846630d..83e03fa36 100644 --- a/elasticsearch/src/main/scala/com/stratio/crossdata/connector/elasticsearch/ElasticSearchQueryProcessor.scala +++ b/elasticsearch/src/main/scala/com/stratio/crossdata/connector/elasticsearch/ElasticSearchQueryProcessor.scala @@ -32,37 +32,46 @@ import scala.util.{Failure, Try} object ElasticSearchQueryProcessor { - def apply(logicalPlan: LogicalPlan, parameters: Map[String, String], schemaProvided: Option[StructType] = None) - = new ElasticSearchQueryProcessor(logicalPlan, parameters, schemaProvided) + def apply(logicalPlan: LogicalPlan, + parameters: Map[String, String], + schemaProvided: Option[StructType] = None) = + new ElasticSearchQueryProcessor(logicalPlan, parameters, schemaProvided) } /** - * Process the logicalPlan to generate the query results - * - * @param logicalPlan [[LogicalPlan]]] to be executed - * @param parameters ElasticSearch Configuration Parameters - * @param schemaProvided Spark used defined schema - */ -class ElasticSearchQueryProcessor(val logicalPlan: LogicalPlan, val parameters: Map[String, String], - val schemaProvided: Option[StructType] = None) extends SparkLoggerComponent { + * Process the logicalPlan to generate the query results + * + * @param logicalPlan [[LogicalPlan]]] to be executed + * @param parameters ElasticSearch Configuration Parameters + * @param schemaProvided Spark used defined schema + */ +class ElasticSearchQueryProcessor(val logicalPlan: LogicalPlan, + val parameters: Map[String, String], + val schemaProvided: Option[StructType] = None) + extends SparkLoggerComponent { type Limit = Option[Int] /** - * Executes the [[LogicalPlan]]] and query the ElasticSearch database + * Executes the [[LogicalPlan]]] and query the ElasticSearch database * * @return the query result - */ + */ def execute(): Option[Array[Row]] = { - def tryRows(requiredColumns: Seq[Attribute], finalQuery: SearchDefinition, esClient: ElasticClient): Try[Array[Row]] = { + def 
tryRows(requiredColumns: Seq[Attribute], + finalQuery: SearchDefinition, + esClient: ElasticClient): Try[Array[Row]] = { val rows: Try[Array[Row]] = Try { val resp: SearchResponse = esClient.execute(finalQuery).await.original if (resp.getShardFailures.length > 0) { - val errors = resp.getShardFailures map { failure => failure.reason() } - throw new RuntimeException(errors mkString("Errors from ES:", ";\n", "")) + val errors = resp.getShardFailures map { failure => + failure.reason() + } + throw new RuntimeException(errors mkString ("Errors from ES:", ";\n", "")) } else { - ElasticSearchRowConverter.asRows(schemaProvided.get, resp.getHits.getHits, requiredColumns) + ElasticSearchRowConverter + .asRows(schemaProvided.get, resp.getHits.getHits, requiredColumns) } } rows @@ -88,39 +97,50 @@ class ElasticSearchQueryProcessor(val logicalPlan: LogicalPlan, val parameters: result.toOption } - - def buildNativeQuery(requiredColumns: Seq[Attribute], filters: Array[SourceFilter], query: SearchDefinition): SearchDefinition = { + def buildNativeQuery(requiredColumns: Seq[Attribute], + filters: Array[SourceFilter], + query: SearchDefinition): SearchDefinition = { val queryWithFilters = buildFilters(filters, query) selectFields(requiredColumns, queryWithFilters) } - private def buildFilters(sFilters: Array[SourceFilter], query: SearchDefinition): SearchDefinition = { + private def buildFilters(sFilters: Array[SourceFilter], + query: SearchDefinition): SearchDefinition = { val matchers = sFilters.collect { - case sources.StringContains(attribute, value) => termQuery(attribute, value.toLowerCase) - case sources.StringStartsWith(attribute, value) => prefixQuery(attribute, value.toLowerCase) + case sources.StringContains(attribute, value) => + termQuery(attribute, value.toLowerCase) + case sources.StringStartsWith(attribute, value) => + prefixQuery(attribute, value.toLowerCase) } import scala.collection.JavaConversions._ val searchFilters = sFilters.collect { case sources.EqualTo(attribute, value) => termQuery(attribute, value) - case sources.GreaterThan(attribute, value) => rangeQuery(attribute).from(value).includeLower(false) - case sources.GreaterThanOrEqual(attribute, value) => rangeQuery(attribute).gte(value.toString) - case sources.LessThan(attribute, value) => rangeQuery(attribute).to(value).includeUpper(false) - case sources.LessThanOrEqual(attribute, value) => rangeQuery(attribute).lte(value.toString) - case sources.In(attribute, value) => termsQuery(attribute, value.map(_.asInstanceOf[AnyRef]): _*) + case sources.GreaterThan(attribute, value) => + rangeQuery(attribute).from(value).includeLower(false) + case sources.GreaterThanOrEqual(attribute, value) => + rangeQuery(attribute).gte(value.toString) + case sources.LessThan(attribute, value) => + rangeQuery(attribute).to(value).includeUpper(false) + case sources.LessThanOrEqual(attribute, value) => + rangeQuery(attribute).lte(value.toString) + case sources.In(attribute, value) => + termsQuery(attribute, value.map(_.asInstanceOf[AnyRef]): _*) case sources.IsNotNull(attribute) => existsQuery(attribute) case sources.IsNull(attribute) => must(not(existsQuery(attribute))) } val matchQuery = query bool must(matchers) - val finalQuery = if (searchFilters.isEmpty) - matchQuery - else matchQuery postFilter bool { - must(searchFilters) - } + val finalQuery = + if (searchFilters.isEmpty) + matchQuery + else + matchQuery postFilter bool { + must(searchFilters) + } log.debug("LogicalPlan transformed to the Elasticsearch query:" + finalQuery.toString()) finalQuery 
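// ---------------------------------------------------------------------------
// Editor's illustrative sketch (not part of this patch). buildFilters above
// translates Spark data-source filters into elastic4s builders (termQuery,
// prefixQuery, rangeQuery, termsQuery, existsQuery). The helper below only
// *describes* which builder each Spark filter would map to, mirroring the
// pattern match in this hunk; EsFilterMappingSketch and describe are
// illustrative names, and only Spark's public org.apache.spark.sql.sources
// filter classes are assumed.
import org.apache.spark.sql.sources._

object EsFilterMappingSketch {

  def describe(filter: Filter): String = filter match {
    case StringContains(attr, v)     => s"termQuery($attr, ${v.toLowerCase})"
    case StringStartsWith(attr, v)   => s"prefixQuery($attr, ${v.toLowerCase})"
    case EqualTo(attr, v)            => s"termQuery($attr, $v)"
    case GreaterThan(attr, v)        => s"rangeQuery($attr).from($v).includeLower(false)"
    case GreaterThanOrEqual(attr, v) => s"rangeQuery($attr).gte($v)"
    case LessThan(attr, v)           => s"rangeQuery($attr).to($v).includeUpper(false)"
    case LessThanOrEqual(attr, v)    => s"rangeQuery($attr).lte($v)"
    case In(attr, vs)                => s"termsQuery($attr, ${vs.mkString(", ")})"
    case IsNotNull(attr)             => s"existsQuery($attr)"
    case IsNull(attr)                => s"must(not(existsQuery($attr)))"
    case other                       => s"not handled by buildFilters: $other"
  }

  def main(args: Array[String]): Unit =
    Seq(EqualTo("name", "Name 1"), GreaterThanOrEqual("age", 15), IsNotNull("enrolled"))
      .map(describe)
      .foreach(println)
}
// ---------------------------------------------------------------------------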
@@ -128,13 +148,14 @@ class ElasticSearchQueryProcessor(val logicalPlan: LogicalPlan, val parameters: } private def selectFields(fields: Seq[Attribute], query: SearchDefinition): SearchDefinition = { - val stringFields: Seq[String] = fields.map(_.name) - query.fields(stringFields.toList: _*) + val stringFields: Seq[String] = fields.map(_.name) + query.fields(stringFields.toList: _*) } - def validatedNativePlan: Option[(BaseLogicalPlan, Limit)] = { - lazy val limit: Option[Int] = logicalPlan.collectFirst { case Limit(Literal(num: Int, _), _) => num } + lazy val limit: Option[Int] = logicalPlan.collectFirst { + case Limit(Literal(num: Int, _), _) => num + } def findProjectsFilters(lplan: LogicalPlan): Option[BaseLogicalPlan] = { lplan match { @@ -143,8 +164,10 @@ class ElasticSearchQueryProcessor(val logicalPlan: LogicalPlan, val parameters: findProjectsFilters(child) case PhysicalOperation(projectList, filterList, _) => - CatalystToCrossdataAdapter.getConnectorLogicalPlan(logicalPlan, projectList, filterList) match { - case (_, ProjectReport(exprIgnored), FilterReport(filtersIgnored, _)) if filtersIgnored.nonEmpty || exprIgnored.nonEmpty => + CatalystToCrossdataAdapter + .getConnectorLogicalPlan(logicalPlan, projectList, filterList) match { + case (_, ProjectReport(exprIgnored), FilterReport(filtersIgnored, _)) + if filtersIgnored.nonEmpty || exprIgnored.nonEmpty => None case (basePlan, _, _) => Some(basePlan) @@ -152,23 +175,26 @@ class ElasticSearchQueryProcessor(val logicalPlan: LogicalPlan, val parameters: } } - findProjectsFilters(logicalPlan).collect{ case bp if checkNativeFilters(bp.filters) => (bp, limit) } + findProjectsFilters(logicalPlan).collect { + case bp if checkNativeFilters(bp.filters) => (bp, limit) + } } - private[this] def checkNativeFilters(filters: Array[SourceFilter]): Boolean = filters.forall { - case _: sources.EqualTo => true - case _: sources.In => true - case _: sources.LessThan => true - case _: sources.GreaterThan => true - case _: sources.LessThanOrEqual => true - case _: sources.GreaterThanOrEqual => true - case _: sources.IsNull => true - case _: sources.IsNotNull => true - case _: sources.StringStartsWith => true - case _: sources.StringContains => true - case sources.And(left, right) => checkNativeFilters(Array(left, right)) - // TODO add more filters (Not?) - case _ => false + private[this] def checkNativeFilters(filters: Array[SourceFilter]): Boolean = + filters.forall { + case _: sources.EqualTo => true + case _: sources.In => true + case _: sources.LessThan => true + case _: sources.GreaterThan => true + case _: sources.LessThanOrEqual => true + case _: sources.GreaterThanOrEqual => true + case _: sources.IsNull => true + case _: sources.IsNotNull => true + case _: sources.StringStartsWith => true + case _: sources.StringContains => true + case sources.And(left, right) => checkNativeFilters(Array(left, right)) + // TODO add more filters (Not?) 
+ case _ => false - } + } } diff --git a/elasticsearch/src/main/scala/com/stratio/crossdata/connector/elasticsearch/ElasticSearchRowConverter.scala b/elasticsearch/src/main/scala/com/stratio/crossdata/connector/elasticsearch/ElasticSearchRowConverter.scala index 56a993db5..cb9a002a9 100644 --- a/elasticsearch/src/main/scala/com/stratio/crossdata/connector/elasticsearch/ElasticSearchRowConverter.scala +++ b/elasticsearch/src/main/scala/com/stratio/crossdata/connector/elasticsearch/ElasticSearchRowConverter.scala @@ -39,8 +39,9 @@ import org.joda.time.DateTime object ElasticSearchRowConverter { - - def asRows(schema: StructType, array: Array[SearchHit], requiredFields: Seq[Attribute]): Array[Row] = { + def asRows(schema: StructType, + array: Array[SearchHit], + requiredFields: Seq[Attribute]): Array[Row] = { import scala.collection.JavaConverters._ val schemaMap = schema.map(field => field.name -> field.dataType).toMap @@ -49,51 +50,48 @@ object ElasticSearchRowConverter { } } - def hitAsRow( - hitFields: Map[String, SearchHitField], - schemaMap: Map[String, DataType], - requiredFields: Seq[String]): Row = { - val values: Seq[Any] = requiredFields.map { - name => - hitFields.get(name).flatMap(v => Option(v)).map( - toSQL(_, schemaMap(name))).orNull + def hitAsRow(hitFields: Map[String, SearchHitField], + schemaMap: Map[String, DataType], + requiredFields: Seq[String]): Row = { + val values: Seq[Any] = requiredFields.map { name => + hitFields.get(name).flatMap(v => Option(v)).map(toSQL(_, schemaMap(name))).orNull } Row.fromSeq(values) } def toSQL(value: SearchHitField, dataType: DataType): Any = { - Option(value).map { case value => - //Assure value is mapped to schema constrained type. - enforceCorrectType(value.getValue, dataType) + Option(value).map { + case value => + //Assure value is mapped to schema constrained type. 
+ enforceCorrectType(value.getValue, dataType) }.orNull } - protected def enforceCorrectType(value: Any, desiredType: DataType): Any = { - // TODO check if value==null - Option(desiredType).map { - case StringType => value.toString - case _ if value == "" => null // guard the non string type - case IntegerType => toInt(value) - case LongType => toLong(value) - case DoubleType => toDouble(value) - case DecimalType() => toDecimal(value) - case BooleanType => value.asInstanceOf[Boolean] - case TimestampType => toTimestamp(value) - case NullType => null - case DateType => toDate(value) - case _ => - sys.error(s"Unsupported datatype conversion [${value.getClass}},$desiredType]") - value - }.orNull + // TODO check if value==null + Option(desiredType).map { + case StringType => value.toString + case _ if value == "" => null // guard the non string type + case IntegerType => toInt(value) + case LongType => toLong(value) + case DoubleType => toDouble(value) + case DecimalType() => toDecimal(value) + case BooleanType => value.asInstanceOf[Boolean] + case TimestampType => toTimestamp(value) + case NullType => null + case DateType => toDate(value) + case _ => + sys.error(s"Unsupported datatype conversion [${value.getClass}},$desiredType]") + value + }.orNull } private def toInt(value: Any): Int = { import scala.language.reflectiveCalls value match { case value: String => value.toInt - case _ => value.asInstanceOf[ {def toInt: Int}].toInt + case _ => value.asInstanceOf[{ def toInt: Int }].toInt } } @@ -116,7 +114,8 @@ object ElasticSearchRowConverter { value match { case value: java.lang.Integer => Decimal(value) case value: java.lang.Long => Decimal(value) - case value: java.math.BigInteger => Decimal(new java.math.BigDecimal(value)) + case value: java.math.BigInteger => + Decimal(new java.math.BigDecimal(value)) case value: java.lang.Double => Decimal(value) case value: java.math.BigDecimal => Decimal(value) } @@ -124,12 +123,13 @@ object ElasticSearchRowConverter { private def toTimestamp(value: Any): Timestamp = { value match { - case value : String => + case value: String => val dateFormat = new SimpleDateFormat("yyyy-MM-dd'T'hh:mm:ss.SSS") val parsedDate = dateFormat.parse(value) new java.sql.Timestamp(parsedDate.getTime) case value: java.util.Date => new Timestamp(value.getTime) - case _ => sys.error(s"Unsupported datatype conversion [${value.getClass}},Timestamp]") + case _ => + sys.error(s"Unsupported datatype conversion [${value.getClass}},Timestamp]") } } diff --git a/elasticsearch/src/main/scala/org/elasticsearch/spark/sql/ElasticsearchXDRelation.scala b/elasticsearch/src/main/scala/org/elasticsearch/spark/sql/ElasticsearchXDRelation.scala index b5c219575..9034a36db 100644 --- a/elasticsearch/src/main/scala/org/elasticsearch/spark/sql/ElasticsearchXDRelation.scala +++ b/elasticsearch/src/main/scala/org/elasticsearch/spark/sql/ElasticsearchXDRelation.scala @@ -25,45 +25,51 @@ import org.apache.spark.sql.catalyst.plans.logical.{Filter, Limit} import org.apache.spark.sql.types.StructType import org.apache.spark.sql.{Row, SQLContext} - /** - * ElasticSearchXDRelation inherits from ElasticsearchRelation - * and adds the NativeScan support to make Native Queries from the XDContext - * - * @param parameters Configuration form ElasticSearch - * @param sqlContext Spark SQL Context - * @param userSchema Spark User Defined Schema - */ -class ElasticsearchXDRelation(parameters: Map[String, String], sqlContext: SQLContext, userSchema: Option[StructType] = None) - extends 
ElasticsearchRelation(parameters, sqlContext, userSchema) with NativeScan with Logging { + * ElasticSearchXDRelation inherits from ElasticsearchRelation + * and adds the NativeScan support to make Native Queries from the XDContext + * + * @param parameters Configuration form ElasticSearch + * @param sqlContext Spark SQL Context + * @param userSchema Spark User Defined Schema + */ +class ElasticsearchXDRelation(parameters: Map[String, String], + sqlContext: SQLContext, + userSchema: Option[StructType] = None) + extends ElasticsearchRelation(parameters, sqlContext, userSchema) + with NativeScan + with Logging { /** - * Build and Execute a NativeScan for the [[LogicalPlan]] provided. - * @param optimizedLogicalPlan the [[LogicalPlan]] to be executed - * @return a list of Spark [[Row]] with the [[LogicalPlan]] execution result. - */ + * Build and Execute a NativeScan for the [[LogicalPlan]] provided. + * @param optimizedLogicalPlan the [[LogicalPlan]] to be executed + * @return a list of Spark [[Row]] with the [[LogicalPlan]] execution result. + */ override def buildScan(optimizedLogicalPlan: LogicalPlan): Option[Array[Row]] = { logDebug(s"Processing ${optimizedLogicalPlan.toString()}") val queryExecutor = ElasticSearchQueryProcessor(optimizedLogicalPlan, parameters, userSchema) queryExecutor.execute() } - /** - * Checks the ability to execute a [[LogicalPlan]]. - * - * @param logicalStep isolated plan - * @param wholeLogicalPlan the whole DataFrame tree - * @return whether the logical step within the entire logical plan is supported - */ - override def isSupported(logicalStep: LogicalPlan, wholeLogicalPlan: LogicalPlan): Boolean = logicalStep match { - case ln: LeafNode => true // TODO leafNode == LogicalRelation(xdSourceRelation) - case un: UnaryNode => un match { - case Project(_, _) | Filter(_, _) => true - case Limit(_, _)=> false //TODO add support to others - case _ => false + * Checks the ability to execute a [[LogicalPlan]]. 
+ * + * @param logicalStep isolated plan + * @param wholeLogicalPlan the whole DataFrame tree + * @return whether the logical step within the entire logical plan is supported + */ + override def isSupported(logicalStep: LogicalPlan, wholeLogicalPlan: LogicalPlan): Boolean = + logicalStep match { + case ln: LeafNode => + true // TODO leafNode == LogicalRelation(xdSourceRelation) + case un: UnaryNode => + un match { + case Project(_, _) | Filter(_, _) => true + case Limit(_, _) => false //TODO add support to others + case _ => false + } + case unsupportedLogicalPlan => + false //TODO log.debug(s"LogicalPlan $unsupportedLogicalPlan cannot be executed natively"); } - case unsupportedLogicalPlan => false //TODO log.debug(s"LogicalPlan $unsupportedLogicalPlan cannot be executed natively"); - } } diff --git a/elasticsearch/src/test/scala/com/stratio/crossdata/connector/elasticsearch/DefaultSourceESSpec.scala b/elasticsearch/src/test/scala/com/stratio/crossdata/connector/elasticsearch/DefaultSourceESSpec.scala index 70a20f19b..ad34078dc 100644 --- a/elasticsearch/src/test/scala/com/stratio/crossdata/connector/elasticsearch/DefaultSourceESSpec.scala +++ b/elasticsearch/src/test/scala/com/stratio/crossdata/connector/elasticsearch/DefaultSourceESSpec.scala @@ -20,8 +20,7 @@ import com.stratio.crossdata.test.BaseXDTest import org.apache.spark.sql.SQLContext import org.apache.spark.sql.types.StructType import org.elasticsearch.hadoop.cfg.ConfigurationOptions -import org.elasticsearch.hadoop.cfg.ConfigurationOptions._ -; +import org.elasticsearch.hadoop.cfg.ConfigurationOptions._; import org.scalatest.junit.JUnitRunner import org.scalatest.mock.MockitoSugar @@ -31,12 +30,13 @@ import org.mockito.Mockito._ @RunWith(classOf[JUnitRunner]) class DefaultSourceESSpec extends BaseXDTest with MockitoSugar { - "A DefaultSource " should "build a ElasticSearchXDRelation without schema" in { //Fixture val defaultDatasource = new DefaultSource() val sqlContext = mock[SQLContext] - val parameters = Map[String, String] {ConfigurationOptions.ES_RESOURCE -> "index/type"} + val parameters = Map[String, String] { + ConfigurationOptions.ES_RESOURCE -> "index/type" + } //Experimentation val result = defaultDatasource.createRelation(sqlContext, parameters) @@ -50,7 +50,9 @@ class DefaultSourceESSpec extends BaseXDTest with MockitoSugar { val defaultDatasource = new DefaultSource() val sqlContext = mock[SQLContext] val schema = mock[StructType] - val parameters = Map[String, String] {ConfigurationOptions.ES_RESOURCE -> "index/type"} + val parameters = Map[String, String] { + ConfigurationOptions.ES_RESOURCE -> "index/type" + } //Experimentation val result = defaultDatasource.createRelation(sqlContext, parameters, schema) @@ -68,12 +70,12 @@ class DefaultSourceESSpec extends BaseXDTest with MockitoSugar { val userOpts: Map[String, String] = Map(ES_HOST -> "localhost") //Experimentation - val result:Map[String, String] = defaultDatasource.generateConnectorOpts(item, userOpts) + val result: Map[String, String] = defaultDatasource.generateConnectorOpts(item, userOpts) //Expectations result should not be null - result.get(ES_RESOURCE).get should be ("index/type") - result.get(ES_HOST).get should be ("localhost") + result.get(ES_RESOURCE).get should be("index/type") + result.get(ES_HOST).get should be("localhost") } } diff --git a/elasticsearch/src/test/scala/com/stratio/crossdata/connector/elasticsearch/ElasticInsertCollection.scala 
b/elasticsearch/src/test/scala/com/stratio/crossdata/connector/elasticsearch/ElasticInsertCollection.scala index 3e0259643..23d6b527b 100644 --- a/elasticsearch/src/test/scala/com/stratio/crossdata/connector/elasticsearch/ElasticInsertCollection.scala +++ b/elasticsearch/src/test/scala/com/stratio/crossdata/connector/elasticsearch/ElasticInsertCollection.scala @@ -31,62 +31,61 @@ import org.scalatest.Suite import scala.util.Try - trait ElasticInsertCollection extends ElasticWithSharedContext { override protected def saveTestData: Unit = for (a <- 1 to 10) { client.get.execute { - index into Index / Type fields( - "id" -> a, - "age" -> (10 + a), - "description" -> s"A ${a}description about the Name$a", - "enrolled" -> (if (a % 2 == 0) true else null), - "name" -> s"Name $a", - "birthday" -> DateTime.parse((1980 + a) + "-01-01T10:00:00-00:00").toDate, - "salary" -> a * 1000.5, - "ageInMilis" -> DateTime.parse((1980 + a) + "-01-01T10:00:00-00:00").getMillis, - "array_test" -> List(a, a+1, a+2), - "map_test" -> Map("x" -> a, "y" -> (a+1)), - "array_map" -> Seq( Map("x" -> a), Map("y" -> (a+1)) ), - "map_array" -> Map("x" -> Seq(1,2), "y" -> Seq(2,3)) - ) + index into Index / Type fields ( + "id" -> a, + "age" -> (10 + a), + "description" -> s"A ${a}description about the Name$a", + "enrolled" -> (if (a % 2 == 0) true else null), + "name" -> s"Name $a", + "birthday" -> DateTime.parse((1980 + a) + "-01-01T10:00:00-00:00").toDate, + "salary" -> a * 1000.5, + "ageInMilis" -> DateTime.parse((1980 + a) + "-01-01T10:00:00-00:00").getMillis, + "array_test" -> List(a, a + 1, a + 2), + "map_test" -> Map("x" -> a, "y" -> (a + 1)), + "array_map" -> Seq(Map("x" -> a), Map("y" -> (a + 1))), + "map_array" -> Map("x" -> Seq(1, 2), "y" -> Seq(2, 3)) + ) }.await client.get.execute { flush index Index }.await } - override def sparkRegisterTableSQL: Seq[SparkTable] = super.sparkRegisterTableSQL :+ - str2sparkTableDesc(s"""|CREATE TEMPORARY TABLE $Type (id INT, age INT, description STRING, enrolled BOOLEAN, + override def sparkRegisterTableSQL: Seq[SparkTable] = + super.sparkRegisterTableSQL :+ + str2sparkTableDesc( + s"""|CREATE TEMPORARY TABLE $Type (id INT, age INT, description STRING, enrolled BOOLEAN, |name STRING, optionalField BOOLEAN, birthday DATE, salary DOUBLE, ageInMilis LONG, |array_test ARRAY<STRING>, map_test MAP<STRING, STRING>, |array_map ARRAY<MAP<STRING, STRING>>, map_array MAP<STRING, ARRAY<STRING>>)""".stripMargin) - - override def typeMapping(): MappingDefinition ={ - Type as( - "id" typed IntegerType, - "age" typed IntegerType, - "description" typed StringType, - "enrolled" typed BooleanType, - "name" typed StringType index NotAnalyzed, - "birthday" typed DateType, - "salary" typed DoubleType, - "ageInMilis" typed LongType, - "array_test" typed StringType, - "map_test" typed ObjectType - ) + override def typeMapping(): MappingDefinition = { + Type as ( + "id" typed IntegerType, + "age" typed IntegerType, + "description" typed StringType, + "enrolled" typed BooleanType, + "name" typed StringType index NotAnalyzed, + "birthday" typed DateType, + "salary" typed DoubleType, + "ageInMilis" typed LongType, + "array_test" typed StringType, + "map_test" typed ObjectType + ) } override val Type = s"students_test_insert" override val defaultOptions = Map( - "resource" -> s"$Index/$Type", - "es.nodes" -> s"$ElasticHost", - "es.port" -> s"$ElasticRestPort", - "es.nativePort" ->
s"$ElasticNativePort", + "es.cluster" -> s"$ElasticClusterName" ) - -} \ No newline at end of file +} diff --git a/elasticsearch/src/test/scala/com/stratio/crossdata/connector/elasticsearch/ElasticInsertTableIT.scala b/elasticsearch/src/test/scala/com/stratio/crossdata/connector/elasticsearch/ElasticInsertTableIT.scala index 00f639bb5..9ed055948 100644 --- a/elasticsearch/src/test/scala/com/stratio/crossdata/connector/elasticsearch/ElasticInsertTableIT.scala +++ b/elasticsearch/src/test/scala/com/stratio/crossdata/connector/elasticsearch/ElasticInsertTableIT.scala @@ -22,27 +22,39 @@ import org.apache.spark.sql.crossdata.ExecutionType // TODO ignore while elastic fix ... or when native support select map/array values class ElasticInsertTableIT extends ElasticInsertCollection { - it should "insert a row using INSERT INTO table VALUES in ElasticSearch" ignore { - val query = s"""|INSERT INTO $Type VALUES (20, 25, 'proof description', true, 'Eve', false, '2015-01-01' , + val query = + s"""|INSERT INTO $Type VALUES (20, 25, 'proof description', true, 'Eve', false, '2015-01-01' , |1200.00, 1463646640046, ['proof'], (a->2), [ (x -> 1, y-> 1), (z -> 1) ], ( x->[1,2], y-> [3,4] ))""".stripMargin - _xdContext.sql(query).collect() should be (Row(1)::Nil) + _xdContext.sql(query).collect() should be(Row(1) :: Nil) //EXPECTATION - val results = sql(s"select id, age, description, enrolled, name, optionalField, birthday, salary, ageInMilis, array_map, map_array from $Type where id=20").collect(ExecutionType.Spark) + val results = sql( + s"select id, age, description, enrolled, name, optionalField, birthday, salary, ageInMilis, array_map, map_array from $Type where id=20") + .collect(ExecutionType.Spark) results should have length 1 results should contain - Row(20, 25, "proof description", true, "Eve", - false, "2015-01-01", 1200.00, "1463646640046".toLong, - Seq("proof"), Map("a" -> "2"), - Seq(Map("x" -> "1", "y" -> "1"), Map("z" -> "1")), - Map("x" -> Seq("1", "2"), "y" -> Seq("3", "4"))) + Row(20, + 25, + "proof description", + true, + "Eve", + false, + "2015-01-01", + 1200.00, + "1463646640046".toLong, + Seq("proof"), + Map("a" -> "2"), + Seq(Map("x" -> "1", "y" -> "1"), Map("z" -> "1")), + Map("x" -> Seq("1", "2"), "y" -> Seq("3", "4"))) } it should "insert a row using INSERT INTO table(schema) VALUES in ElasticSearch" ignore { - _xdContext.sql(s"INSERT INTO $Type(age, name, enrolled) VALUES ( 25, 'Peter', true)").collect() should be (Row(1)::Nil) + _xdContext + .sql(s"INSERT INTO $Type(age, name, enrolled) VALUES ( 25, 'Peter', true)") + .collect() should be(Row(1) :: Nil) } it should "insert multiple rows using INSERT INTO table VALUES in ElasticSearch" ignore { @@ -52,20 +64,21 @@ class ElasticInsertTableIT extends ElasticInsertCollection { |(23, 33, 'other fun description', false, 'July', false, '2015-01-08' , 1400.00, 1463046640046, [true,true], (z->1, a-> 2), [ (za -> 12) ], ( x->[1,2] )) """.stripMargin val rows: Array[Row] = _xdContext.sql(query).collect() - rows should be (Row(3)::Nil) + rows should be(Row(3) :: Nil) } it should "insert multiple rows using INSERT INTO table(schema) VALUES in ElasticSearch" ignore { - _xdContext.sql(s"INSERT INTO $Type (age, name, enrolled) VALUES ( 50, 'Samantha', true),( 1, 'Charlie', false)").collect() should be (Row(2)::Nil) + _xdContext + .sql(s"INSERT INTO $Type (age, name, enrolled) VALUES ( 50, 'Samantha', true),( 1, 'Charlie', false)") + .collect() should be(Row(2) :: Nil) } - it should "insert rows using INSERT INTO table(schema) VALUES with Arrays 
in ElasticSearch" ignore { val query = s"""|INSERT INTO $Type (age, name, enrolled, array_test) VALUES |( 55, 'Jules', true, [true, false]), |( 12, 'Martha', false, ['test1,t', 'test2']) """.stripMargin - _xdContext.sql(query).collect() should be (Row(2)::Nil) + _xdContext.sql(query).collect() should be(Row(2) :: Nil) } it should "insert rows using INSERT INTO table(schema) VALUES with Map in ElasticSearch" ignore { @@ -73,7 +86,7 @@ class ElasticInsertTableIT extends ElasticInsertCollection { |( 12, 'Albert', true, (x->1, y->2, z->3) ), |( 20, 'Alfred', false, (xa->1, ya->2, za->3,d -> 5) ) """.stripMargin - _xdContext.sql(query).collect() should be (Row(2)::Nil) + _xdContext.sql(query).collect() should be(Row(2) :: Nil) } it should "insert rows using INSERT INTO table(schema) VALUES with Array of Maps in ElasticSearch" ignore { @@ -81,7 +94,7 @@ class ElasticInsertTableIT extends ElasticInsertCollection { |( 1, 'Nikolai', true, [(x -> 3), (z -> 1)] ), |( 14, 'Ludwig', false, [(x -> 1, y-> 1), (z -> 1)] ) """.stripMargin - _xdContext.sql(query).collect() should be (Row(2)::Nil) + _xdContext.sql(query).collect() should be(Row(2) :: Nil) } it should "insert rows using INSERT INTO table(schema) VALUES with Map of Array in ElasticSearch" ignore { @@ -89,7 +102,7 @@ class ElasticInsertTableIT extends ElasticInsertCollection { |( 13, 'Svletiana', true, ( x->[1], y-> [3,4] ) ), |( 17, 'Wolfang', false, ( x->[1,2], y-> [3] ) ) """.stripMargin - _xdContext.sql(query).collect() should be (Row(2)::Nil) + _xdContext.sql(query).collect() should be(Row(2) :: Nil) } } diff --git a/elasticsearch/src/test/scala/com/stratio/crossdata/connector/elasticsearch/ElasticSearchConnectionUtilsIT.scala b/elasticsearch/src/test/scala/com/stratio/crossdata/connector/elasticsearch/ElasticSearchConnectionUtilsIT.scala index afcffc185..e6ef1bc26 100644 --- a/elasticsearch/src/test/scala/com/stratio/crossdata/connector/elasticsearch/ElasticSearchConnectionUtilsIT.scala +++ b/elasticsearch/src/test/scala/com/stratio/crossdata/connector/elasticsearch/ElasticSearchConnectionUtilsIT.scala @@ -19,36 +19,36 @@ import org.junit.runner.RunWith import org.scalatest.junit.JUnitRunner @RunWith(classOf[JUnitRunner]) -class ElasticSearchConnectionUtilsIT extends ElasticWithSharedContext with ElasticSearchDefaultConstants { - +class ElasticSearchConnectionUtilsIT + extends ElasticWithSharedContext + with ElasticSearchDefaultConstants { + "ElasticSearchConnectionUtils" should "build a native ES Connection" in { assumeEnvironmentIsUpAndRunning val options: Map[String, String] = Map( - "es.nodes" -> s"$ElasticHost", - "es.port" -> s"$ElasticRestPort", - "es.nativePort" -> s"$ElasticNativePort", - "es.cluster" -> s"$ElasticClusterName" + "es.nodes" -> s"$ElasticHost", + "es.port" -> s"$ElasticRestPort", + "es.nativePort" -> s"$ElasticNativePort", + "es.cluster" -> s"$ElasticClusterName" ) //Experimentation - ElasticSearchConnectionUtils.withClientDo(options){ client => - + ElasticSearchConnectionUtils.withClientDo(options) { client => //Expectations client should not be (null) } } - it should "list ElasticSearch Tables in One Index" in { assumeEnvironmentIsUpAndRunning val options: Map[String, String] = Map( - "es.nodes" -> s"$ElasticHost", - "es.port" -> s"$ElasticRestPort", - "es.nativePort" -> s"$ElasticNativePort", - "es.cluster" -> s"$ElasticClusterName", - "es.index" -> s"$Index" + "es.nodes" -> s"$ElasticHost", + "es.port" -> s"$ElasticRestPort", + "es.nativePort" -> s"$ElasticNativePort", + "es.cluster" -> s"$ElasticClusterName", 
+ "es.index" -> s"$Index" ) //Experimentation @@ -56,8 +56,8 @@ class ElasticSearchConnectionUtilsIT extends ElasticWithSharedContext with Elast //Expectations types should not be (null) - types.size should be (1) - types(0).schema.get.size should be (8) + types.size should be(1) + types(0).schema.get.size should be(8) } @@ -65,14 +65,14 @@ class ElasticSearchConnectionUtilsIT extends ElasticWithSharedContext with Elast assumeEnvironmentIsUpAndRunning val options: Map[String, String] = Map( - "es.nodes" -> s"$ElasticHost", - "es.port" -> s"$ElasticRestPort", - "es.nativePort" -> s"$ElasticNativePort", - "es.cluster" -> s"$ElasticClusterName" + "es.nodes" -> s"$ElasticHost", + "es.port" -> s"$ElasticRestPort", + "es.nativePort" -> s"$ElasticNativePort", + "es.cluster" -> s"$ElasticClusterName" ) - ElasticSearchConnectionUtils.withClientDo(options){ client => - createIndex(client,"index_test", typeMapping()) + ElasticSearchConnectionUtils.withClientDo(options) { client => + createIndex(client, "index_test", typeMapping()) try { //Experimentation val types = ElasticSearchConnectionUtils.listTypes(options) @@ -87,20 +87,19 @@ class ElasticSearchConnectionUtilsIT extends ElasticWithSharedContext with Elast } - it should "list tables on an empty index" in { assumeEnvironmentIsUpAndRunning val options: Map[String, String] = Map( - "es.nodes" -> s"$ElasticHost", - "es.port" -> s"$ElasticRestPort", - "es.nativePort" -> s"$ElasticNativePort", - "es.cluster" -> s"$ElasticClusterName", - "es.index" -> "empty_index" + "es.nodes" -> s"$ElasticHost", + "es.port" -> s"$ElasticRestPort", + "es.nativePort" -> s"$ElasticNativePort", + "es.cluster" -> s"$ElasticClusterName", + "es.index" -> "empty_index" ) - ElasticSearchConnectionUtils.withClientDo(options){ client => - createIndex(client,"empty_index", null) + ElasticSearchConnectionUtils.withClientDo(options) { client => + createIndex(client, "empty_index", null) try { //Experimentation @@ -108,7 +107,7 @@ class ElasticSearchConnectionUtilsIT extends ElasticWithSharedContext with Elast //Expectations types should not be null - types.size should be (0) + types.size should be(0) } finally { cleanTestData(client, "empty_index") } diff --git a/elasticsearch/src/test/scala/com/stratio/crossdata/connector/elasticsearch/ElasticSearchCreateExternalTableIT.scala b/elasticsearch/src/test/scala/com/stratio/crossdata/connector/elasticsearch/ElasticSearchCreateExternalTableIT.scala index e1d0c4ba1..e9176404f 100644 --- a/elasticsearch/src/test/scala/com/stratio/crossdata/connector/elasticsearch/ElasticSearchCreateExternalTableIT.scala +++ b/elasticsearch/src/test/scala/com/stratio/crossdata/connector/elasticsearch/ElasticSearchCreateExternalTableIT.scala @@ -41,7 +41,7 @@ class ElasticSearchCreateExternalTableIT extends ElasticWithSharedContext { //Expectations val table = xdContext.table(s"$Index.newtable") table should not be null - table.schema.fieldNames should contain ("title") + table.schema.fieldNames should contain("title") client.get.admin.indices.prepareTypesExists(Index).setTypes(Type).get.isExists shouldBe true } @@ -59,14 +59,18 @@ class ElasticSearchCreateExternalTableIT extends ElasticWithSharedContext { |) """.stripMargin.replaceAll("\n", " ") - sql(createTableQUeryString).collect() + sql(createTableQUeryString).collect() //Expectations val table = xdContext.table(s"$Index.newtable2") table should not be null - table.schema.fieldNames should contain ("city") + table.schema.fieldNames should contain("city") - 
client.get.admin.indices.prepareTypesExists(Index).setTypes("newtable2").get.isExists shouldBe true + client.get.admin.indices + .prepareTypesExists(Index) + .setTypes("newtable2") + .get + .isExists shouldBe true } diff --git a/elasticsearch/src/test/scala/com/stratio/crossdata/connector/elasticsearch/ElasticSearchDropExternalTableIT.scala b/elasticsearch/src/test/scala/com/stratio/crossdata/connector/elasticsearch/ElasticSearchDropExternalTableIT.scala index bc645105c..9f0b096f8 100644 --- a/elasticsearch/src/test/scala/com/stratio/crossdata/connector/elasticsearch/ElasticSearchDropExternalTableIT.scala +++ b/elasticsearch/src/test/scala/com/stratio/crossdata/connector/elasticsearch/ElasticSearchDropExternalTableIT.scala @@ -47,8 +47,7 @@ class ElasticSearchDropExternalTableIT extends ElasticWithSharedContext { """.stripMargin.replaceAll("\n", " ") sql(createTableQueryString1).collect() - val createTableQueryString2 = - s"""|CREATE EXTERNAL TABLE testDrop2 (id Integer, name String) + val createTableQueryString2 = s"""|CREATE EXTERNAL TABLE testDrop2 (id Integer, name String) |USING $SourceProvider |OPTIONS ( |es.resource '$Index2/drop_table_example', @@ -60,8 +59,7 @@ class ElasticSearchDropExternalTableIT extends ElasticWithSharedContext { """.stripMargin.replaceAll("\n", " ") sql(createTableQueryString2).collect() - val createTableQueryString3 = - s"""|CREATE EXTERNAL TABLE testDrop3 (id Integer, name String) + val createTableQueryString3 = s"""|CREATE EXTERNAL TABLE testDrop3 (id Integer, name String) |USING $SourceProvider |OPTIONS ( |es.resource '$Index3/drop_table_example', @@ -73,8 +71,7 @@ class ElasticSearchDropExternalTableIT extends ElasticWithSharedContext { """.stripMargin.replaceAll("\n", " ") sql(createTableQueryString3).collect() - val createTableQueryString4 = - s"""|CREATE EXTERNAL TABLE testDrop4 (id Integer, name String) + val createTableQueryString4 = s"""|CREATE EXTERNAL TABLE testDrop4 (id Integer, name String) |USING $SourceProvider |OPTIONS ( |es.resource '$Index3/drop_table_example2', @@ -97,7 +94,7 @@ class ElasticSearchDropExternalTableIT extends ElasticWithSharedContext { //DROP val dropExternalTableQuery = s"DROP EXTERNAL TABLE $Index1.$mappingName" - sql(dropExternalTableQuery).collect() should be (Seq.empty) + sql(dropExternalTableQuery).collect() should be(Seq.empty) //Expectations an[Exception] shouldBe thrownBy(xdContext.table(s"$Index1.testDrop1")) @@ -113,7 +110,7 @@ class ElasticSearchDropExternalTableIT extends ElasticWithSharedContext { //DROP val dropExternalTableQuery = "DROP EXTERNAL TABLE testDrop2" - sql(dropExternalTableQuery).collect() should be (Seq.empty) + sql(dropExternalTableQuery).collect() should be(Seq.empty) //Expectations an[Exception] shouldBe thrownBy(xdContext.table("testDrop2")) diff --git a/elasticsearch/src/test/scala/com/stratio/crossdata/connector/elasticsearch/ElasticSearchImportTablesIT.scala b/elasticsearch/src/test/scala/com/stratio/crossdata/connector/elasticsearch/ElasticSearchImportTablesIT.scala index f0b650dd5..3e2667563 100644 --- a/elasticsearch/src/test/scala/com/stratio/crossdata/connector/elasticsearch/ElasticSearchImportTablesIT.scala +++ b/elasticsearch/src/test/scala/com/stratio/crossdata/connector/elasticsearch/ElasticSearchImportTablesIT.scala @@ -19,7 +19,6 @@ import com.sksamuel.elastic4s.ElasticDsl._ class ElasticSearchImportTablesIT extends ElasticWithSharedContext { - // IMPORT OPERATIONS it should "import all tables from a keyspace" in { @@ -28,8 +27,7 @@ class ElasticSearchImportTablesIT extends 
ElasticWithSharedContext { val initialLength = tableCountInHighschool xdContext.dropAllTables() - val importQuery = - s""" + val importQuery = s""" |IMPORT TABLES |USING $SourceProvider |OPTIONS ( @@ -44,15 +42,14 @@ class ElasticSearchImportTablesIT extends ElasticWithSharedContext { sql(importQuery) //Expectations - tableCountInHighschool should be (1) - xdContext.tableNames() should contain (s"$Index.$Type") + tableCountInHighschool should be(1) + xdContext.tableNames() should contain(s"$Index.$Type") } it should "infer schema after import all tables from an Index" in { assumeEnvironmentIsUpAndRunning xdContext.dropAllTables() - val importQuery = - s""" + val importQuery = s""" |IMPORT TABLES |USING $SourceProvider |OPTIONS ( @@ -68,7 +65,7 @@ class ElasticSearchImportTablesIT extends ElasticWithSharedContext { sql(importQuery) //Expectations - xdContext.tableNames() should contain (s"$Index.$Type") + xdContext.tableNames() should contain(s"$Index.$Type") xdContext.table(s"$Index.$Type").schema should have length 8 } @@ -76,13 +73,14 @@ class ElasticSearchImportTablesIT extends ElasticWithSharedContext { assumeEnvironmentIsUpAndRunning xdContext.dropAllTables() - ElasticSearchConnectionUtils.withClientDo(connectionOptions){ client => - client.execute { index into Index -> "NewMapping" fields { - "name" -> "luis" - }} + ElasticSearchConnectionUtils.withClientDo(connectionOptions) { client => + client.execute { + index into Index -> "NewMapping" fields { + "name" -> "luis" + } + } - val importQuery = - s""" + val importQuery = s""" |IMPORT TABLES |USING $SourceProvider |OPTIONS ( @@ -98,7 +96,7 @@ class ElasticSearchImportTablesIT extends ElasticWithSharedContext { sql(importQuery) //Expectations - xdContext.tableNames() should contain (s"$Index.$Type") + xdContext.tableNames() should contain(s"$Index.$Type") xdContext.tableNames() should not contain s"$Index.NewMapping" } } @@ -107,13 +105,14 @@ class ElasticSearchImportTablesIT extends ElasticWithSharedContext { assumeEnvironmentIsUpAndRunning xdContext.dropAllTables() - ElasticSearchConnectionUtils.withClientDo(connectionOptions){ client => - client.execute { index into Index -> "NewMapping" fields { - "name" -> "luis" - }} + ElasticSearchConnectionUtils.withClientDo(connectionOptions) { client => + client.execute { + index into Index -> "NewMapping" fields { + "name" -> "luis" + } + } - val importQuery = - s""" + val importQuery = s""" |IMPORT TABLES |USING $SourceProvider |OPTIONS ( @@ -126,7 +125,7 @@ class ElasticSearchImportTablesIT extends ElasticWithSharedContext { """.stripMargin //Experimentation - an [IllegalArgumentException] should be thrownBy sql(importQuery) + an[IllegalArgumentException] should be thrownBy sql(importQuery) } } @@ -135,11 +134,10 @@ class ElasticSearchImportTablesIT extends ElasticWithSharedContext { assumeEnvironmentIsUpAndRunning xdContext.dropAllTables() - ElasticSearchConnectionUtils.withClientDo(connectionOptions){ client => - createIndex(client,"index_test", typeMapping()) + ElasticSearchConnectionUtils.withClientDo(connectionOptions) { client => + createIndex(client, "index_test", typeMapping()) try { - val importQuery = - s""" + val importQuery = s""" |IMPORT TABLES |USING $SourceProvider |OPTIONS ( @@ -167,13 +165,14 @@ class ElasticSearchImportTablesIT extends ElasticWithSharedContext { assumeEnvironmentIsUpAndRunning xdContext.dropAllTables() - ElasticSearchConnectionUtils.withClientDo(connectionOptions){ client => - client.execute { index into Index -> "NewMapping" fields { - "name" -> "luis" 
- }} + ElasticSearchConnectionUtils.withClientDo(connectionOptions) { client => + client.execute { + index into Index -> "NewMapping" fields { + "name" -> "luis" + } + } - val importQuery = - s""" + val importQuery = s""" |IMPORT TABLES |USING $SourceProvider |OPTIONS ( @@ -185,15 +184,15 @@ class ElasticSearchImportTablesIT extends ElasticWithSharedContext { """.stripMargin //Experimentation - an [RuntimeException] should be thrownBy sql(importQuery) + an[RuntimeException] should be thrownBy sql(importQuery) } } lazy val connectionOptions: Map[String, String] = Map( - "es.nodes" -> s"$ElasticHost", - "es.port" -> s"$ElasticRestPort", - "es.nativePort" -> s"$ElasticNativePort", - "es.cluster" -> s"$ElasticClusterName" + "es.nodes" -> s"$ElasticHost", + "es.port" -> s"$ElasticRestPort", + "es.nativePort" -> s"$ElasticNativePort", + "es.cluster" -> s"$ElasticClusterName" ) } diff --git a/elasticsearch/src/test/scala/com/stratio/crossdata/connector/elasticsearch/ElasticSearchQueryProcessorSpec.scala b/elasticsearch/src/test/scala/com/stratio/crossdata/connector/elasticsearch/ElasticSearchQueryProcessorSpec.scala index c5b152a9c..8153f09f1 100644 --- a/elasticsearch/src/test/scala/com/stratio/crossdata/connector/elasticsearch/ElasticSearchQueryProcessorSpec.scala +++ b/elasticsearch/src/test/scala/com/stratio/crossdata/connector/elasticsearch/ElasticSearchQueryProcessorSpec.scala @@ -28,7 +28,6 @@ import org.scalatest.mock.MockitoSugar @RunWith(classOf[JUnitRunner]) class ElasticSearchQueryProcessorSpec extends BaseXDTest with MockitoSugar { - "A ElasticSearchQueryProcessor " should "build a Match All query in ES" in { //Fixture @@ -48,10 +47,10 @@ class ElasticSearchQueryProcessorSpec extends BaseXDTest with MockitoSugar { //Expectations result should not be null - result.toString().replace("\n", "").replace(" ", "") should be("{\"query\":{\"bool\":{}},\"fields\":\"title\"}") + result.toString().replace("\n", "").replace(" ", "") should be( + "{\"query\":{\"bool\":{}},\"fields\":\"title\"}") } - it should "build a Simple Filter query in ES" in { //Fixture @@ -71,10 +70,10 @@ class ElasticSearchQueryProcessorSpec extends BaseXDTest with MockitoSugar { //Expectations result should not be null - result.toString().replace("\n", "").replace(" ", "") should be("{\"query\":{\"bool\":{}},\"post_filter\":{\"bool\":{\"must\":{\"term\":{\"year\":1990}}}},\"fields\":\"title\"}") + result.toString().replace("\n", "").replace(" ", "") should be( + "{\"query\":{\"bool\":{}},\"post_filter\":{\"bool\":{\"must\":{\"term\":{\"year\":1990}}}},\"fields\":\"title\"}") } - it should "build a AND Query" in { //Fixture @@ -94,6 +93,7 @@ class ElasticSearchQueryProcessorSpec extends BaseXDTest with MockitoSugar { //Expectations result should not be null - result.toString().replace("\n", "").replace(" ", "") should be("{\"query\":{\"bool\":{}},\"post_filter\":{\"bool\":{\"must\":[{\"term\":{\"year\":1990}},{\"term\":{\"Name\":\"Lord\"}}]}},\"fields\":\"title\"}") + result.toString().replace("\n", "").replace(" ", "") should be( + "{\"query\":{\"bool\":{}},\"post_filter\":{\"bool\":{\"must\":[{\"term\":{\"year\":1990}},{\"term\":{\"Name\":\"Lord\"}}]}},\"fields\":\"title\"}") } } diff --git a/elasticsearch/src/test/scala/com/stratio/crossdata/connector/elasticsearch/ElasticSearchTypesIT.scala b/elasticsearch/src/test/scala/com/stratio/crossdata/connector/elasticsearch/ElasticSearchTypesIT.scala index be77ef1e9..7969ff5be 100644 --- 
a/elasticsearch/src/test/scala/com/stratio/crossdata/connector/elasticsearch/ElasticSearchTypesIT.scala +++ b/elasticsearch/src/test/scala/com/stratio/crossdata/connector/elasticsearch/ElasticSearchTypesIT.scala @@ -25,7 +25,6 @@ import org.scalatest.junit.JUnitRunner @RunWith(classOf[JUnitRunner]) class ElasticSearchTypesIT extends ElasticWithSharedContext { - // "id" typed IntegerType, // "age" typed IntegerType, // "description" typed StringType, @@ -42,24 +41,24 @@ class ElasticSearchTypesIT extends ElasticWithSharedContext { val result = dataframe.collect(Native) //Expectations - result(0).get(0).isInstanceOf[Integer] should be (true) - result(0).get(1).isInstanceOf[Integer] should be (true) - result(0).get(2).isInstanceOf[String] should be (true) - result(0).get(3).isInstanceOf[Boolean] should be (true) - result(0).get(4).isInstanceOf[String] should be (true) + result(0).get(0).isInstanceOf[Integer] should be(true) + result(0).get(1).isInstanceOf[Integer] should be(true) + result(0).get(2).isInstanceOf[String] should be(true) + result(0).get(3).isInstanceOf[Boolean] should be(true) + result(0).get(4).isInstanceOf[String] should be(true) - result(0).get(6).isInstanceOf[Date] should be (true) - result(0).get(7).isInstanceOf[Double] should be (true) - result(0).get(8).isInstanceOf[Long] should be (true) + result(0).get(6).isInstanceOf[Date] should be(true) + result(0).get(7).isInstanceOf[Double] should be(true) + result(0).get(8).isInstanceOf[Long] should be(true) - result(0).getInt(0) should be (2) - result(0).getInt(1) should be (12) - result(0).getString(2) should be ("A 2description about the Name2") - result(0).getBoolean(3) should be (true) - result(0).getString(4) should be ("Name 2") + result(0).getInt(0) should be(2) + result(0).getInt(1) should be(12) + result(0).getString(2) should be("A 2description about the Name2") + result(0).getBoolean(3) should be(true) + result(0).getString(4) should be("Name 2") - result(0).getDate(6) should be (DateTime.parse("1982-01-01T10:00:00-00:00").toDate) - result(0).getDouble(7) should be (2001.0) - result(0).getLong(8) should be (DateTime.parse("1982-01-01T10:00:00-00:00").getMillis) + result(0).getDate(6) should be(DateTime.parse("1982-01-01T10:00:00-00:00").toDate) + result(0).getDouble(7) should be(2001.0) + result(0).getLong(8) should be(DateTime.parse("1982-01-01T10:00:00-00:00").getMillis) } } diff --git a/elasticsearch/src/test/scala/com/stratio/crossdata/connector/elasticsearch/ElasticWithSharedContext.scala b/elasticsearch/src/test/scala/com/stratio/crossdata/connector/elasticsearch/ElasticWithSharedContext.scala index 4f113bacd..a8d4465a4 100644 --- a/elasticsearch/src/test/scala/com/stratio/crossdata/connector/elasticsearch/ElasticWithSharedContext.scala +++ b/elasticsearch/src/test/scala/com/stratio/crossdata/connector/elasticsearch/ElasticWithSharedContext.scala @@ -15,7 +15,6 @@ */ package com.stratio.crossdata.connector.elasticsearch - import java.util.UUID import com.sksamuel.elastic4s.{ElasticClient, ElasticsearchClientUri} @@ -32,32 +31,33 @@ import org.scalatest.Suite import scala.util.Try - -trait ElasticWithSharedContext extends SharedXDContextWithDataTest with ElasticSearchDefaultConstants with SparkLoggerComponent { +trait ElasticWithSharedContext + extends SharedXDContextWithDataTest + with ElasticSearchDefaultConstants + with SparkLoggerComponent { this: Suite => override type ClientParams = ElasticClient override val provider: String = SourceProvider override val defaultOptions = Map( - "resource" -> 
s"$Index/$Type", - "es.nodes" -> s"$ElasticHost", - "es.port" -> s"$ElasticRestPort", - "es.nativePort" -> s"$ElasticNativePort", - "es.cluster" -> s"$ElasticClusterName" + "resource" -> s"$Index/$Type", + "es.nodes" -> s"$ElasticHost", + "es.port" -> s"$ElasticRestPort", + "es.nativePort" -> s"$ElasticNativePort", + "es.cluster" -> s"$ElasticClusterName" ) override protected def saveTestData: Unit = for (a <- 1 to 10) { client.get.execute { - index into Index / Type fields( - "id" -> a, - "age" -> (10 + a), - "description" -> s"A ${a}description about the Name$a", - "enrolled" -> (if (a % 2 == 0) true else null), - "name" -> s"Name $a", - "birthday" -> DateTime.parse((1980 + a) + "-01-01T10:00:00-00:00").toDate, - "salary" -> a * 1000.5, - "ageInMilis" -> DateTime.parse((1980 + a) + "-01-01T10:00:00-00:00").getMillis) + index into Index / Type fields ("id" -> a, + "age" -> (10 + a), + "description" -> s"A ${a}description about the Name$a", + "enrolled" -> (if (a % 2 == 0) true else null), + "name" -> s"Name $a", + "birthday" -> DateTime.parse((1980 + a) + "-01-01T10:00:00-00:00").toDate, + "salary" -> a * 1000.5, + "ageInMilis" -> DateTime.parse((1980 + a) + "-01-01T10:00:00-00:00").getMillis) }.await client.get.execute { flush index Index @@ -69,39 +69,45 @@ trait ElasticWithSharedContext extends SharedXDContextWithDataTest with ElasticS override protected def cleanTestData: Unit = cleanTestData(client.get, Index) //Template steps: Override them - override protected def prepareClient: Option[ClientParams] = Try { - logInfo(s"Connection to elastic search, ElasticHost: $ElasticHost, ElasticNativePort:$ElasticNativePort, ElasticClusterName $ElasticClusterName") - val settings = Settings.settingsBuilder().put("cluster.name", ElasticClusterName).build() - val uri = ElasticsearchClientUri(s"elasticsearch://$ElasticHost:$ElasticNativePort") - val elasticClient = ElasticClient.transport(settings, uri) - createIndex(elasticClient, Index, typeMapping()) - elasticClient - } toOption - - override def sparkRegisterTableSQL: Seq[SparkTable] = super.sparkRegisterTableSQL :+ - str2sparkTableDesc(s"CREATE TEMPORARY TABLE $Type (id INT, age INT, description STRING, enrolled BOOLEAN, name STRING, optionalField BOOLEAN, birthday DATE, salary DOUBLE, ageInMilis LONG)") + override protected def prepareClient: Option[ClientParams] = + Try { + logInfo( + s"Connection to elastic search, ElasticHost: $ElasticHost, ElasticNativePort:$ElasticNativePort, ElasticClusterName $ElasticClusterName") + val settings = Settings.settingsBuilder().put("cluster.name", ElasticClusterName).build() + val uri = ElasticsearchClientUri(s"elasticsearch://$ElasticHost:$ElasticNativePort") + val elasticClient = ElasticClient.transport(settings, uri) + createIndex(elasticClient, Index, typeMapping()) + elasticClient + } toOption + + override def sparkRegisterTableSQL: Seq[SparkTable] = + super.sparkRegisterTableSQL :+ + str2sparkTableDesc( + s"CREATE TEMPORARY TABLE $Type (id INT, age INT, description STRING, enrolled BOOLEAN, name STRING, optionalField BOOLEAN, birthday DATE, salary DOUBLE, ageInMilis LONG)") override val runningError: String = "ElasticSearch and Spark must be up and running" - def createIndex(elasticClient: ElasticClient, indexName:String, mappings:MappingDefinition): Unit ={ + def createIndex(elasticClient: ElasticClient, + indexName: String, + mappings: MappingDefinition): Unit = { val command = Option(mappings).fold(create index indexName)(create index indexName mappings _) - elasticClient.execute {command}.await + 
elasticClient.execute { command }.await } - def typeMapping(): MappingDefinition ={ + def typeMapping(): MappingDefinition = { Type fields ( - "id" typed IntegerType, - "age" typed IntegerType, - "description" typed StringType, - "enrolled" typed BooleanType, - "name" typed StringType index NotAnalyzed, - "birthday" typed DateType, - "salary" typed DoubleType, - "ageInMilis" typed LongType - ) + "id" typed IntegerType, + "age" typed IntegerType, + "description" typed StringType, + "enrolled" typed BooleanType, + "name" typed StringType index NotAnalyzed, + "birthday" typed DateType, + "salary" typed DoubleType, + "ageInMilis" typed LongType + ) } - def cleanTestData(elasticClient: ElasticClient, indexName:String): Unit = { + def cleanTestData(elasticClient: ElasticClient, indexName: String): Unit = { elasticClient.execute { deleteIndex(indexName) } @@ -109,15 +115,16 @@ trait ElasticWithSharedContext extends SharedXDContextWithDataTest with ElasticS } - trait ElasticSearchDefaultConstants { private lazy val config = ConfigFactory.load() val Index = s"highschool${UUID.randomUUID.toString.replaceAll("-", "")}" val Type = s"students${UUID.randomUUID.toString.replaceAll("-", "")}" - val ElasticHost: String = Try(config.getStringList("elasticsearch.hosts")).map(_.get(0)).getOrElse("127.0.0.1") + val ElasticHost: String = + Try(config.getStringList("elasticsearch.hosts")).map(_.get(0)).getOrElse("127.0.0.1") val ElasticRestPort = 9200 val ElasticNativePort = 9300 val SourceProvider = "com.stratio.crossdata.connector.elasticsearch" - val ElasticClusterName: String = Try(config.getString("elasticsearch.cluster")).getOrElse("esCluster") + val ElasticClusterName: String = + Try(config.getString("elasticsearch.cluster")).getOrElse("esCluster") -} \ No newline at end of file +} diff --git a/elasticsearch/src/test/scala/com/stratio/crossdata/connector/elasticsearch/ElasticsearchConnectorIT.scala b/elasticsearch/src/test/scala/com/stratio/crossdata/connector/elasticsearch/ElasticsearchConnectorIT.scala index 8edc8a770..28ec63226 100644 --- a/elasticsearch/src/test/scala/com/stratio/crossdata/connector/elasticsearch/ElasticsearchConnectorIT.scala +++ b/elasticsearch/src/test/scala/com/stratio/crossdata/connector/elasticsearch/ElasticsearchConnectorIT.scala @@ -34,11 +34,19 @@ class ElasticsearchConnectorIT extends ElasticWithSharedContext { val result = dataframe.collect(Native) result should have length 10 - schema.fieldNames should equal (Seq("id", "age", "description", "enrolled", "name", "optionalField", "birthday", "salary", "ageInMilis")) + schema.fieldNames should equal( + Seq("id", + "age", + "description", + "enrolled", + "name", + "optionalField", + "birthday", + "salary", + "ageInMilis")) result.head.toSeq(4).toString should fullyMatch regex "Name [0-9]+" } - it should "select with projections" in { assumeEnvironmentIsUpAndRunning @@ -50,7 +58,7 @@ class ElasticsearchConnectorIT extends ElasticWithSharedContext { val result = dataframe.collect(Native) result should have length 10 - schema.fieldNames should equal (Seq("name", "age")) + schema.fieldNames should equal(Seq("name", "age")) result.head.toSeq(0).toString should fullyMatch regex "Name [0-9]+" } @@ -63,7 +71,7 @@ class ElasticsearchConnectorIT extends ElasticWithSharedContext { //Expectations val result = dataframe.collect(Native) result should have length 1 - result(0).get(0) should be (1) + result(0).get(0) should be(1) } it should "select with simple filter and projection" in { @@ -75,11 +83,10 @@ class ElasticsearchConnectorIT 
extends ElasticWithSharedContext { //Expectations val result = dataframe.collect(Native) result should have length 1 - result(0).get(0) should be ("Name 2") - result(0).get(1) should be (12) + result(0).get(0) should be("Name 2") + result(0).get(1) should be(12) } - it should "select with LT filter" in { assumeEnvironmentIsUpAndRunning @@ -89,7 +96,7 @@ class ElasticsearchConnectorIT extends ElasticWithSharedContext { //Expectations val result = dataframe.collect(Native) result should have length 4 - result(0).get(0).toString.toInt should be <5 + result(0).get(0).toString.toInt should be < 5 } it should "select with LTE filter" in { @@ -125,10 +132,9 @@ class ElasticsearchConnectorIT extends ElasticWithSharedContext { //Expectations val result = dataframe.collect(Native) result should have length 5 - result(0).get(0).toString.toInt should be > 5 + result(0).get(0).toString.toInt should be > 5 } - it should "select with IN filter" in { assumeEnvironmentIsUpAndRunning @@ -138,11 +144,10 @@ class ElasticsearchConnectorIT extends ElasticWithSharedContext { //Expectations val result = dataframe.collect(Native) result should have length 2 - result(0).get(0).toString.toInt should (be (3) or be (4)) + result(0).get(0).toString.toInt should (be(3) or be(4)) result.head.toSeq(1).toString should fullyMatch regex "Name [3,4]+" } - it should "select with Null filter" in { assumeEnvironmentIsUpAndRunning @@ -152,7 +157,7 @@ class ElasticsearchConnectorIT extends ElasticWithSharedContext { //Expectations val result = dataframe.collect(Native) result should have length 5 - result(0).get(0).toString.toInt % 2 should not be 0 + result(0).get(0).toString.toInt % 2 should not be 0 result.head.toSeq(1).toString should fullyMatch regex "Name [1,3,5,7,9]+" } @@ -165,7 +170,7 @@ class ElasticsearchConnectorIT extends ElasticWithSharedContext { //Expectations val result = dataframe.collect(Native) result should have length 5 - result(0).get(0).toString.toInt % 2 should be (0) + result(0).get(0).toString.toInt % 2 should be(0) result.head.toSeq(1).toString should fullyMatch regex "Name [2,4,6,8,10]+" } @@ -178,8 +183,8 @@ class ElasticsearchConnectorIT extends ElasticWithSharedContext { //Expectations val result = dataframe.collect(Native) result should have length 1 - result(0).get(0) should be ("Name 3") - result(0).get(1) should be (13) + result(0).get(0) should be("Name 3") + result(0).get(1) should be(13) } it should "select with Equals String" in { @@ -191,8 +196,8 @@ class ElasticsearchConnectorIT extends ElasticWithSharedContext { //Expectations val result = dataframe.collect(Native) result should have length 1 - result(0).get(0) should be ("Name 3") - result(0).get(1) should be (13) + result(0).get(0) should be("Name 3") + result(0).get(1) should be(13) } it should "select with Like String" in { @@ -204,8 +209,8 @@ class ElasticsearchConnectorIT extends ElasticWithSharedContext { //Expectations val result = dataframe.collect(Native) result should have length 1 - result(0).get(0) should be ("Name 3") - result(0).get(1) should be (13) + result(0).get(0) should be("Name 3") + result(0).get(1) should be(13) } it should "select with Like %String%" in { @@ -217,8 +222,8 @@ class ElasticsearchConnectorIT extends ElasticWithSharedContext { //Expectations val result = dataframe.collect(Native) result should have length 1 - result(0).get(0) should be ("Name 3") - result(0).get(1) should be (13) + result(0).get(0) should be("Name 3") + result(0).get(1) should be(13) } it should "select with Like String%" in { @@ 
-230,10 +235,9 @@ class ElasticsearchConnectorIT extends ElasticWithSharedContext { //Expectations val result = dataframe.collect(Native) result should have length 1 - result(0).get(0) should be ("Name 4") + result(0).get(0) should be("Name 4") } - it should "test retrieve a date value" in { assumeEnvironmentIsUpAndRunning @@ -245,7 +249,7 @@ class ElasticsearchConnectorIT extends ElasticWithSharedContext { val result = dataframe.collect(Native) result should have length 1 - result(0).getDate(1) should be (DateTime.parse((1981)+"-01-01T10:00:00-00:00").toDate) + result(0).getDate(1) should be(DateTime.parse((1981) + "-01-01T10:00:00-00:00").toDate) } //TODO add support for dates in query? @@ -253,13 +257,13 @@ class ElasticsearchConnectorIT extends ElasticWithSharedContext { assumeEnvironmentIsUpAndRunning //Experimentation - val dataframe = sql(s"SELECT name, age FROM $Type where birthday = '1984-01-01T10:00:00-00:00'") + val dataframe = + sql(s"SELECT name, age FROM $Type where birthday = '1984-01-01T10:00:00-00:00'") //Expectations val result = dataframe.collect(Native) result should have length 1 - result(0).get(0) should be ("Name 4") - result(0).get(1) should be (14) + result(0).get(0) should be("Name 4") + result(0).get(1) should be(14) } } - diff --git a/elasticsearch/src/test/scala/org/elasticsearch/spark/sql/ElasticSearchXDRelationSpec.scala b/elasticsearch/src/test/scala/org/elasticsearch/spark/sql/ElasticSearchXDRelationSpec.scala index 5dada0079..9f6757581 100644 --- a/elasticsearch/src/test/scala/org/elasticsearch/spark/sql/ElasticSearchXDRelationSpec.scala +++ b/elasticsearch/src/test/scala/org/elasticsearch/spark/sql/ElasticSearchXDRelationSpec.scala @@ -26,7 +26,6 @@ import org.scalatest.mock.MockitoSugar @RunWith(classOf[JUnitRunner]) class ElasticSearchXDRelationSpec extends BaseXDTest with MockitoSugar { - "An ElasticSearchXDRelation " should "support Limit Node" in { //Fixture val logicalStep = mock[Filter] @@ -38,7 +37,7 @@ class ElasticSearchXDRelationSpec extends BaseXDTest with MockitoSugar { val result = esRelation.isSupported(logicalStep, wholeLogicalPlan) //Expectations - result should be (true) + result should be(true) } it should "support LeafNode Node" in { @@ -52,7 +51,7 @@ class ElasticSearchXDRelationSpec extends BaseXDTest with MockitoSugar { val result = esRelation.isSupported(logicalStep, wholeLogicalPlan) //Expectations - result should be (true) + result should be(true) } it should "support Project Node" in { @@ -66,7 +65,7 @@ class ElasticSearchXDRelationSpec extends BaseXDTest with MockitoSugar { val result = esRelation.isSupported(logicalStep, wholeLogicalPlan) //Expectations - result should be (true) + result should be(true) } def buildElasticSearchRelation(): ElasticsearchXDRelation = { diff --git a/examples/src/main/scala/com/stratio/crossdata/examples/cassandra/CassandraExample.scala b/examples/src/main/scala/com/stratio/crossdata/examples/cassandra/CassandraExample.scala index 5f1833844..cb6b335ff 100644 --- a/examples/src/main/scala/com/stratio/crossdata/examples/cassandra/CassandraExample.scala +++ b/examples/src/main/scala/com/stratio/crossdata/examples/cassandra/CassandraExample.scala @@ -15,7 +15,6 @@ */ package com.stratio.crossdata.examples.cassandra - import org.apache.spark.sql.crossdata.XDContext import org.apache.spark.{SparkConf, SparkContext} @@ -24,9 +23,7 @@ object CassandraExample extends App with CassandraDefaultConstants { val (cluster, session) = prepareEnvironment() withCrossdataContext { xdContext => - - xdContext.sql( - 
s"""|CREATE TEMPORARY TABLE $Table + xdContext.sql(s"""|CREATE TEMPORARY TABLE $Table |USING $SourceProvider |OPTIONS ( |table '$Table', @@ -39,11 +36,15 @@ object CassandraExample extends App with CassandraDefaultConstants { // Native queries xdContext.sql(s"SELECT comment as b FROM $Table WHERE id = 1").show(5) - xdContext.sql(s"SELECT comment as b FROM $Table WHERE id IN(1,2,3,4,5,6,7,8,9,10) limit 2").show(5) + xdContext + .sql(s"SELECT comment as b FROM $Table WHERE id IN(1,2,3,4,5,6,7,8,9,10) limit 2") + .show(5) xdContext.sql(s"SELECT * FROM $Table ").show(5) // Spark queries - xdContext.sql(s"SELECT comment as b FROM $Table WHERE comment = 'Comment 5' AND id = 5").show(5) + xdContext + .sql(s"SELECT comment as b FROM $Table WHERE comment = 'Comment 5' AND id = 5") + .show(5) } @@ -51,9 +52,7 @@ object CassandraExample extends App with CassandraDefaultConstants { private def withCrossdataContext(commands: XDContext => Unit) = { - val sparkConf = new SparkConf(). - setAppName("CassandraExample"). - setMaster("local[4]") + val sparkConf = new SparkConf().setAppName("CassandraExample").setMaster("local[4]") val sc = new SparkContext(sparkConf) try { @@ -64,8 +63,4 @@ object CassandraExample extends App with CassandraDefaultConstants { } } - - - } - diff --git a/examples/src/main/scala/com/stratio/crossdata/examples/cassandra/package.scala b/examples/src/main/scala/com/stratio/crossdata/examples/cassandra/package.scala index 8ef96594e..ed3ead54a 100644 --- a/examples/src/main/scala/com/stratio/crossdata/examples/cassandra/package.scala +++ b/examples/src/main/scala/com/stratio/crossdata/examples/cassandra/package.scala @@ -25,7 +25,7 @@ package object cassandra { val Catalog = "highschool" val Table = "students" val CassandraHost = "127.0.0.1" - val SourceProvider = "com.stratio.crossdata.connector.cassandra"// Cassandra provider => org.apache.spark.sql.cassandra + val SourceProvider = "com.stratio.crossdata.connector.cassandra" // Cassandra provider => org.apache.spark.sql.cassandra } def prepareEnvironment(): (Cluster, Session) = { @@ -46,12 +46,15 @@ package object cassandra { private def buildTable(session: Session): Unit = { - session.execute(s"CREATE KEYSPACE $Catalog WITH replication = {'class':'SimpleStrategy', 'replication_factor':1} AND durable_writes = true;") - session.execute(s"CREATE TABLE $Catalog.$Table (id int PRIMARY KEY, age int, comment text, enrolled boolean, name text)") + session.execute( + s"CREATE KEYSPACE $Catalog WITH replication = {'class':'SimpleStrategy', 'replication_factor':1} AND durable_writes = true;") + session.execute( + s"CREATE TABLE $Catalog.$Table (id int PRIMARY KEY, age int, comment text, enrolled boolean, name text)") for (a <- 1 to 10) { - session.execute("INSERT INTO " + Catalog + "." + Table + " (id, age, comment, enrolled, name) VALUES " + - "(" + a + ", " + (10 + a) + ", 'Comment " + a + "', " + (a % 2 == 0) + ", 'Name " + a + "')") + session.execute( + "INSERT INTO " + Catalog + "." 
+ Table + " (id, age, comment, enrolled, name) VALUES " + + "(" + a + ", " + (10 + a) + ", 'Comment " + a + "', " + (a % 2 == 0) + ", 'Name " + a + "')") } } diff --git a/examples/src/main/scala/com/stratio/crossdata/examples/driver/CassandraExample.scala b/examples/src/main/scala/com/stratio/crossdata/examples/driver/CassandraExample.scala index fe6bc8f45..3b34ba33f 100644 --- a/examples/src/main/scala/com/stratio/crossdata/examples/driver/CassandraExample.scala +++ b/examples/src/main/scala/com/stratio/crossdata/examples/driver/CassandraExample.scala @@ -21,9 +21,8 @@ import com.stratio.crossdata.driver.config.DriverConf import com.stratio.crossdata.examples.cassandra._ /** - * Driver example - Cassandra - */ - + * Driver example - Cassandra + */ sealed trait DefaultConstants { val ClusterName = "Test Cluster" val Catalog = "highschool" @@ -32,8 +31,8 @@ sealed trait DefaultConstants { val SourceProvider = "cassandra" // Cassandra provider => org.apache.spark.sql.cassandra val CassandraOptions = Map( - "cluster" -> ClusterName, - "spark_cassandra_connection_host" -> CassandraHost + "cluster" -> ClusterName, + "spark_cassandra_connection_host" -> CassandraHost ) } @@ -43,10 +42,10 @@ object DriverExample extends App with DefaultConstants { var driver: Option[Driver] = None - val driverConf = new DriverConf(). - setFlattenTables(false). - setTunnelTimeout(30). - setClusterContactPoint("127.0.0.1:13420", "127.0.0.1:13425") + val driverConf = new DriverConf() + .setFlattenTables(false) + .setTunnelTimeout(30) + .setClusterContactPoint("127.0.0.1:13420", "127.0.0.1:13425") try { diff --git a/examples/src/main/scala/com/stratio/crossdata/examples/driver/StreamingSqlExample.scala b/examples/src/main/scala/com/stratio/crossdata/examples/driver/StreamingSqlExample.scala index 1ed31d423..bee640788 100644 --- a/examples/src/main/scala/com/stratio/crossdata/examples/driver/StreamingSqlExample.scala +++ b/examples/src/main/scala/com/stratio/crossdata/examples/driver/StreamingSqlExample.scala @@ -19,18 +19,19 @@ import com.stratio.crossdata.driver.Driver import com.stratio.crossdata.driver.config.DriverConf import com.stratio.crossdata.examples.cassandra._ - /** - * Driver example - Join Kafka and Cassandra - Output to Kafka - */ -object StreamingSqlExample extends App with CassandraDefaultConstants with StreamingDefaultConstants{ + * Driver example - Join Kafka and Cassandra - Output to Kafka + */ +object StreamingSqlExample + extends App + with CassandraDefaultConstants + with StreamingDefaultConstants { val (cluster, session) = prepareEnvironment() val driver = Driver.newSession(new DriverConf().setClusterContactPoint("127.0.0.1:13420")) - val importQuery = - s"""|IMPORT TABLES + val importQuery = s"""|IMPORT TABLES |USING $SourceProvider |OPTIONS ( | cluster "$ClusterName", @@ -38,15 +39,14 @@ object StreamingSqlExample extends App with CassandraDefaultConstants with Strea |) """.stripMargin - val createEphemeralTable = - s"""|CREATE EPHEMERAL TABLE $EphemeralTableName + val createEphemeralTable = s"""|CREATE EPHEMERAL TABLE $EphemeralTableName |OPTIONS ( | receiver.kafka.topic '$InputTopic:$NumPartitionsToConsume', | receiver.kafka.groupId 'xd1' |) """.stripMargin - try{ + try { // Imports tables from Cassandra cluster driver.sql(importQuery).waitForResult() @@ -55,7 +55,6 @@ object StreamingSqlExample extends App with CassandraDefaultConstants with Strea // Adds a streaming query. 
It will be executed when the streaming process is running driver.sql(s"SELECT count(*) FROM $EphemeralTableName WITH WINDOW 5 SECS AS outputTopic") - // Starts the streaming process associated to the ephemeral table driver.sql(s"START $EphemeralTableName") @@ -64,14 +63,12 @@ object StreamingSqlExample extends App with CassandraDefaultConstants with Strea // Example: kafka-console-producer.sh --broker-list localhost:9092 --topic // Input events format: {"id": 1, "msg": "Hello world", "city": "Tolomango"} - // WARNING: Then, you could start a Kafka consumer in order to read the processed data from queryAlias/outputTopic // Example: kafka-console-consumer.sh --zookeeper localhost:2181 --topic // Later, we can add a query to join batch and streaming sources, which output will be other Kafka topic // NOTE: In order to produce results, you should add ids matching the id's range of Cassandra table (1 to 10) - driver.sql( - s""" + driver.sql(s""" |SELECT name FROM $EphemeralTableName INNER JOIN $Catalog.$Table |ON $EphemeralTableName.id = $Table.id |WITH WINDOW 10 SECS AS joinTopic @@ -91,11 +88,9 @@ object StreamingSqlExample extends App with CassandraDefaultConstants with Strea cleanEnvironment(cluster, session) } - } - -trait StreamingDefaultConstants{ +trait StreamingDefaultConstants { val EphemeralTableName = "t" val InputTopic = "ephtable" val NumPartitionsToConsume = "1" diff --git a/examples/src/main/scala/com/stratio/crossdata/examples/elasticsearch/ElasticsearchExample.scala b/examples/src/main/scala/com/stratio/crossdata/examples/elasticsearch/ElasticsearchExample.scala index 523ed1c45..797b1498d 100644 --- a/examples/src/main/scala/com/stratio/crossdata/examples/elasticsearch/ElasticsearchExample.scala +++ b/examples/src/main/scala/com/stratio/crossdata/examples/elasticsearch/ElasticsearchExample.scala @@ -37,9 +37,7 @@ object ElasticsearchExample extends App with ElasticsearchDefaultConstants { val client = prepareEnvironment() withCrossdataContext { xdContext => - - xdContext.sql( - s"""|CREATE TEMPORARY TABLE $Type + xdContext.sql(s"""|CREATE TEMPORARY TABLE $Type |(id INT, age INT, description STRING, enrolled BOOLEAN, name STRING) |USING $SourceProvider |OPTIONS ( @@ -58,7 +56,9 @@ object ElasticsearchExample extends App with ElasticsearchDefaultConstants { // Spark xdContext.sql(s"SELECT name as b FROM $Type WHERE age > 1 limit 7").show(5) xdContext.sql(s"SELECT description as b FROM $Type WHERE description = 'Comment 4'").show(5) - xdContext.sql(s"SELECT description as b FROM $Type WHERE description = 'Comment 2' AND id = 2").show(5) + xdContext + .sql(s"SELECT description as b FROM $Type WHERE description = 'Comment 2' AND id = 2") + .show(5) } @@ -66,9 +66,7 @@ object ElasticsearchExample extends App with ElasticsearchDefaultConstants { private def withCrossdataContext(commands: XDContext => Unit) = { - val sparkConf = new SparkConf(). - setAppName("ElasticsearchExample"). 
- setMaster("local[4]") + val sparkConf = new SparkConf().setAppName("ElasticsearchExample").setMaster("local[4]") val sc = new SparkContext(sparkConf) try { @@ -99,25 +97,25 @@ object ElasticsearchExample extends App with ElasticsearchDefaultConstants { private def buildTable(client: ElasticClient): Unit = { client.execute { create index s"$Index" mappings ( - mapping(s"$Type") fields( - "id" typed IntegerType, - "age" typed IntegerType, - "description" typed StringType, - "enrolled" typed BooleanType, - "name" typed StringType + mapping(s"$Type") fields ( + "id" typed IntegerType, + "age" typed IntegerType, + "description" typed StringType, + "enrolled" typed BooleanType, + "name" typed StringType ) - ) + ) }.await for (a <- 1 to 10) { client.execute { - index into s"$Index" / s"$Type" fields( - "id" -> a, - "age" -> (10 + a), - "description" -> s"Comment $a", - "enrolled" -> (a % 2 == 0), - "name" -> s"Name $a" - ) + index into s"$Index" / s"$Type" fields ( + "id" -> a, + "age" -> (10 + a), + "description" -> s"Comment $a", + "enrolled" -> (a % 2 == 0), + "name" -> s"Name $a" + ) }.await } @@ -134,4 +132,3 @@ object ElasticsearchExample extends App with ElasticsearchDefaultConstants { } } - diff --git a/examples/src/main/scala/com/stratio/crossdata/examples/mongodb/MongoDescribeExample.scala b/examples/src/main/scala/com/stratio/crossdata/examples/mongodb/MongoDescribeExample.scala index 44d53204d..e049c04b8 100644 --- a/examples/src/main/scala/com/stratio/crossdata/examples/mongodb/MongoDescribeExample.scala +++ b/examples/src/main/scala/com/stratio/crossdata/examples/mongodb/MongoDescribeExample.scala @@ -25,9 +25,7 @@ object MongoDescribeExample extends App with MongoDefaultConstants { xdContext.sql(s"DESCRIBE highschool.studentsTestDataTypes").show() } private def withCrossdataContext(commands: XDContext => Unit) = { - val sparkConf = new SparkConf(). - setAppName("MongoExample"). 
- setMaster("local[4]") + val sparkConf = new SparkConf().setAppName("MongoExample").setMaster("local[4]") val sc = new SparkContext(sparkConf) try { val xdContext = new XDContext(sc) @@ -37,7 +35,7 @@ object MongoDescribeExample extends App with MongoDefaultConstants { } } def prepareEnvironment(): MongoClient = { - val mongoClient = MongoClient(MongoHost,MongoPort) + val mongoClient = MongoClient(MongoHost, MongoPort) mongoClient } -} \ No newline at end of file +} diff --git a/examples/src/main/scala/com/stratio/crossdata/examples/mongodb/MongoExample.scala b/examples/src/main/scala/com/stratio/crossdata/examples/mongodb/MongoExample.scala index 56b02b02e..964069f52 100644 --- a/examples/src/main/scala/com/stratio/crossdata/examples/mongodb/MongoExample.scala +++ b/examples/src/main/scala/com/stratio/crossdata/examples/mongodb/MongoExample.scala @@ -25,9 +25,7 @@ object MongoExample extends App with MongoDefaultConstants { val mongoClient = prepareEnvironment() withCrossdataContext { xdContext => - - xdContext.sql( - s"""|CREATE TEMPORARY TABLE $Collection + xdContext.sql(s"""|CREATE TEMPORARY TABLE $Collection |(id STRING, age INT, description STRING, enrolled BOOLEAN, name STRING) |USING $MongoConnector |OPTIONS ( @@ -46,11 +44,10 @@ object MongoExample extends App with MongoDefaultConstants { //xdContext.sql(s"SELECT id, age FROM $Collection WHERE id LIKE '1%'").show(5) xdContext.sql(s"SELECT id, name FROM $Collection WHERE name LIKE '%ame%'").show(5) - //Spark xdContext.sql(s"SELECT count(*), avg(age) FROM $Collection GROUP BY enrolled").show(5) - /* TODO CREATE TABLE AS SELECT EXAMPLE + /* TODO CREATE TABLE AS SELECT EXAMPLE xdContext.sql( s"""|CREATE TABLE newTable |USING $SourceProvider @@ -67,9 +64,7 @@ object MongoExample extends App with MongoDefaultConstants { private def withCrossdataContext(commands: XDContext => Unit) = { - val sparkConf = new SparkConf(). - setAppName("MongoExample"). 
- setMaster("local[4]") + val sparkConf = new SparkConf().setAppName("MongoExample").setMaster("local[4]") val sc = new SparkContext(sparkConf) try { @@ -82,7 +77,7 @@ object MongoExample extends App with MongoDefaultConstants { } def prepareEnvironment(): MongoClient = { - val mongoClient = MongoClient(MongoHost,MongoPort) + val mongoClient = MongoClient(MongoHost, MongoPort) populateTable(mongoClient) mongoClient } @@ -92,18 +87,16 @@ object MongoExample extends App with MongoDefaultConstants { mongoClient.close() } - private def populateTable(client: MongoClient): Unit = { val collection = client(Database)(Collection) for (a <- 1 to 10) { - collection.insert{ + collection.insert { MongoDBObject("id" -> a.toString, - "age" -> (10+a), + "age" -> (10 + a), "description" -> s"description $a", - "enrolled" -> (a % 2 == 0 ), - "name" -> s"Name $a" - ) + "enrolled" -> (a % 2 == 0), + "name" -> s"Name $a") } } } @@ -113,5 +106,4 @@ object MongoExample extends App with MongoDefaultConstants { collection.dropCollection() } - -} \ No newline at end of file +} diff --git a/mongodb/src/main/scala/com/stratio/crossdata/connector/mongodb/DefaultSource.scala b/mongodb/src/main/scala/com/stratio/crossdata/connector/mongodb/DefaultSource.scala index c306ac80b..db99f2836 100644 --- a/mongodb/src/main/scala/com/stratio/crossdata/connector/mongodb/DefaultSource.scala +++ b/mongodb/src/main/scala/com/stratio/crossdata/connector/mongodb/DefaultSource.scala @@ -32,93 +32,96 @@ import org.apache.spark.sql.{DataFrame, SQLContext, SaveMode} import scala.util.{Failure, Try} /** - * Allows creation of MongoDB based tables using - * the syntax CREATE TEMPORARY TABLE ... USING com.stratio.deep.mongodb. - * Required options are detailed in [[com.stratio.datasource.mongodb.config.MongodbConfig]] - */ -class DefaultSource extends ProviderDS with TableInventory with DataSourceRegister with TableManipulation{ + * Allows creation of MongoDB based tables using + * the syntax CREATE TEMPORARY TABLE ... USING com.stratio.deep.mongodb. 
+ * Required options are detailed in [[com.stratio.datasource.mongodb.config.MongodbConfig]] + */ +class DefaultSource + extends ProviderDS + with TableInventory + with DataSourceRegister + with TableManipulation { import MongodbConfig._ /** * if the collection is capped */ - val MongoCollectionPropertyCapped:String= "capped" + val MongoCollectionPropertyCapped: String = "capped" /** * collection size */ - val MongoCollectionPropertySize:String= "size" + val MongoCollectionPropertySize: String = "size" /** * max number of documents */ - val MongoCollectionPropertyMax:String= "max" + val MongoCollectionPropertyMax: String = "max" override def shortName(): String = "mongodb" - override def createRelation( - sqlContext: SQLContext, - parameters: Map[String, String]): BaseRelation = { + override def createRelation(sqlContext: SQLContext, + parameters: Map[String, String]): BaseRelation = { - MongodbXDRelation( - MongodbConfigBuilder(parseParameters(parameters)) - .build())(sqlContext) + MongodbXDRelation(MongodbConfigBuilder(parseParameters(parameters)).build())(sqlContext) } - override def createRelation( - sqlContext: SQLContext, - parameters: Map[String, String], - schema: StructType): BaseRelation = { + override def createRelation(sqlContext: SQLContext, + parameters: Map[String, String], + schema: StructType): BaseRelation = { - MongodbXDRelation( - MongodbConfigBuilder(parseParameters(parameters)) - .build(),Some(schema))(sqlContext) + MongodbXDRelation(MongodbConfigBuilder(parseParameters(parameters)).build(), Some(schema))( + sqlContext) } - override def createRelation( - sqlContext: SQLContext, - mode: SaveMode, - parameters: Map[String, String], - data: DataFrame): BaseRelation = { - - val mongodbRelation = MongodbXDRelation( - MongodbConfigBuilder(parseParameters(parameters)) - .build())(sqlContext) - - mode match{ - case Append => mongodbRelation.insert(data, overwrite = false) - case Overwrite => mongodbRelation.insert(data, overwrite = true) - case ErrorIfExists => if(mongodbRelation.isEmptyCollection) mongodbRelation.insert(data, overwrite = false) - else throw new UnsupportedOperationException("Writing in a non-empty collection.") - case Ignore => if(mongodbRelation.isEmptyCollection) mongodbRelation.insert(data, overwrite = false) + override def createRelation(sqlContext: SQLContext, + mode: SaveMode, + parameters: Map[String, String], + data: DataFrame): BaseRelation = { + + val mongodbRelation = + MongodbXDRelation(MongodbConfigBuilder(parseParameters(parameters)).build())(sqlContext) + + mode match { + case Append => mongodbRelation.insert(data, overwrite = false) + case Overwrite => mongodbRelation.insert(data, overwrite = true) + case ErrorIfExists => + if (mongodbRelation.isEmptyCollection) + mongodbRelation.insert(data, overwrite = false) + else + throw new UnsupportedOperationException("Writing in a non-empty collection.") + case Ignore => + if (mongodbRelation.isEmptyCollection) + mongodbRelation.insert(data, overwrite = false) } mongodbRelation } - /** - * @inheritdoc - */ - override def generateConnectorOpts(item: Table, userOpts: Map[String, String]): Map[String, String] = Map( - Database -> item.database.get, - Collection -> item.tableName - ) ++ userOpts + * @inheritdoc + */ + override def generateConnectorOpts(item: Table, + userOpts: Map[String, String]): Map[String, String] = + Map( + Database -> item.database.get, + Collection -> item.tableName + ) ++ userOpts /** - * @inheritdoc - */ + * @inheritdoc + */ override def listTables(context: SQLContext, options: 
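The SaveMode handling above is reached through Spark's standard DataFrame writer. The sketch below is not part of this changeset; it only illustrates how the Append/Overwrite/ErrorIfExists/Ignore branches get triggered, using placeholder host, database and collection values and the lowercase option keys that appear elsewhere in this diff.

import org.apache.spark.sql.{DataFrame, SaveMode}

// Hypothetical helper, for illustration only. ErrorIfExists and Ignore write
// solely when the target collection is empty, mirroring the match above.
def appendStudents(df: DataFrame): Unit =
  df.write
    .format("mongodb") // "mongodb" is the shortName() declared above; the provider class name can be used instead
    .mode(SaveMode.Append)
    .options(Map(
      "host" -> "localhost:27017", // placeholder values
      "database" -> "highschool",
      "collection" -> "students"))
    .save()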
Map[String, String]): Seq[Table] = { Seq(Host).foreach { opName => - if (!options.contains(opName)) sys.error( s"""Option "$opName" is mandatory for IMPORT TABLES""") + if (!options.contains(opName)) + sys.error(s"""Option "$opName" is mandatory for IMPORT TABLES""") } MongodbConnection.withClientDo(parseParametersWithoutValidation(options)) { mongoClient => - def extractAllDatabases: Seq[MongoDB] = mongoClient.getDatabaseNames().map(mongoClient.getDB) @@ -128,7 +131,7 @@ class DefaultSource extends ProviderDS with TableInventory with DataSourceRegist val tablesIt: Iterable[Table] = for { database: MongoDB <- extractAllDatabases collection: DBCollection <- extractAllCollections(database) - if options.get(Database).forall( _ == collection.getDB.getName) + if options.get(Database).forall(_ == collection.getDB.getName) if options.get(Collection).forall(_ == collection.getName) } yield { collectionToTable(context, options, database.getName, collection.getName) @@ -141,7 +144,10 @@ class DefaultSource extends ProviderDS with TableInventory with DataSourceRegist override def exclusionFilter(t: TableInventory.Table): Boolean = !t.tableName.startsWith("""system.""") && !t.database.get.equals("local") - private def collectionToTable(context: SQLContext, options: Map[String, String], database: String, collection: String): Table = { + private def collectionToTable(context: SQLContext, + options: Map[String, String], + database: String, + collection: String): Table = { val collectionConfig = MongodbConfigBuilder() .apply(parseParameters(options + (Database -> database) + (Collection -> collection))) @@ -155,16 +161,22 @@ class DefaultSource extends ProviderDS with TableInventory with DataSourceRegist schema: StructType, options: Map[String, String]): Option[Table] = { - val database: String = options.get(Database).orElse(databaseName). 
- getOrElse(throw new RuntimeException(s"$Database required when use CREATE EXTERNAL TABLE command")) + val database: String = options + .get(Database) + .orElse(databaseName) + .getOrElse( + throw new RuntimeException(s"$Database required when use CREATE EXTERNAL TABLE command")) val collection: String = options.getOrElse(Collection, tableName) val mongoOptions = DBObject() options.map { - case (MongoCollectionPropertyCapped, value) => mongoOptions.put(MongoCollectionPropertyCapped, value) - case (MongoCollectionPropertySize, value) => mongoOptions.put(MongoCollectionPropertySize, value.toInt) - case (MongoCollectionPropertyMax, value) => mongoOptions.put(MongoCollectionPropertyMax, value.toInt) + case (MongoCollectionPropertyCapped, value) => + mongoOptions.put(MongoCollectionPropertyCapped, value) + case (MongoCollectionPropertySize, value) => + mongoOptions.put(MongoCollectionPropertySize, value.toInt) + case (MongoCollectionPropertyMax, value) => + mongoOptions.put(MongoCollectionPropertyMax, value.toInt) case _ => } @@ -180,56 +192,66 @@ class DefaultSource extends ProviderDS with TableInventory with DataSourceRegist } } - override def dropExternalTable(context: SQLContext, - options: Map[String, String]): Try[Unit] = { + override def dropExternalTable(context: SQLContext, options: Map[String, String]): Try[Unit] = { val tupleDbColl = for { db <- options.get(Database) coll <- options.get(Collection) } yield (db, coll) tupleDbColl.fold[Try[Unit]]( - ifEmpty = Failure(throw new RuntimeException(s"Required options not found ${Set(Database, Collection) -- options.keys}")) - ) { case (dbName, collName) => - Try { - MongodbConnection.withClientDo(parseParametersWithoutValidation(options)) { mongoClient => - mongoClient.getDB(dbName).getCollection(collName).drop() + ifEmpty = Failure( + throw new RuntimeException( + s"Required options not found ${Set(Database, Collection) -- options.keys}")) + ) { + case (dbName, collName) => + Try { + MongodbConnection.withClientDo(parseParametersWithoutValidation(options)) { + mongoClient => + mongoClient.getDB(dbName).getCollection(collName).drop() + } } - } } } - // TODO refactor datasource -> avoid duplicated method - def parseParametersWithoutValidation(parameters : Map[String,String]): Config = { + def parseParametersWithoutValidation(parameters: Map[String, String]): Config = { // required properties /** We will assume hosts are provided like 'host:port,host2:port2,...' */ - val properties: Map[String, Any] = parameters.updated(Host, parameters.getOrElse(Host, notFound[String](Host)).split(",").toList) + val properties: Map[String, Any] = parameters + .updated(Host, parameters.getOrElse(Host, notFound[String](Host)).split(",").toList) //optional parseable properties - val optionalProperties: List[String] = List(Credentials,SSLOptions, UpdateFields) + val optionalProperties: List[String] = List(Credentials, SSLOptions, UpdateFields) + + val finalProperties = (properties /: optionalProperties) { - val finalProperties = (properties /: optionalProperties){ /** We will assume credentials are provided like 'user,database,password;user,database,password;...' 
*/ - case (properties,Credentials) => - parameters.get(Credentials).map{ credentialInput => - val credentials = credentialInput.split(";").map(_.split(",")).toList - .map(credentials => MongodbCredentials(credentials(0), credentials(1), credentials(2).toCharArray)) + case (properties, Credentials) => + parameters.get(Credentials).map { credentialInput => + val credentials = credentialInput + .split(";") + .map(_.split(",")) + .toList + .map(credentials => + MongodbCredentials(credentials(0), credentials(1), credentials(2).toCharArray)) properties + (Credentials -> credentials) } getOrElse properties /** We will assume ssloptions are provided like '/path/keystorefile,keystorepassword,/path/truststorefile,truststorepassword' */ - case (properties,SSLOptions) => - parameters.get(SSLOptions).map{ ssloptionsInput => - + case (properties, SSLOptions) => + parameters.get(SSLOptions).map { ssloptionsInput => val ssloption = ssloptionsInput.split(",") - val ssloptions = MongodbSSLOptions(Some(ssloption(0)), Some(ssloption(1)), ssloption(2), Some(ssloption(3))) + val ssloptions = MongodbSSLOptions(Some(ssloption(0)), + Some(ssloption(1)), + ssloption(2), + Some(ssloption(3))) properties + (SSLOptions -> ssloptions) } getOrElse properties /** We will assume fields are provided like 'user,database,password...' */ case (properties, UpdateFields) => - parameters.get(UpdateFields).map{ updateInputs => + parameters.get(UpdateFields).map { updateInputs => val updateFields = updateInputs.split(",") properties + (UpdateFields -> updateFields) } getOrElse properties @@ -250,4 +272,4 @@ class DefaultSource extends ProviderDS with TableInventory with DataSourceRegist def apply(props: Map[Property, Any]) = MongodbConnectorConfigBuilder(props) } -} \ No newline at end of file +} diff --git a/mongodb/src/main/scala/com/stratio/crossdata/connector/mongodb/MongoQueryProcessor.scala b/mongodb/src/main/scala/com/stratio/crossdata/connector/mongodb/MongoQueryProcessor.scala index e6dff5515..be619b5a2 100644 --- a/mongodb/src/main/scala/com/stratio/crossdata/connector/mongodb/MongoQueryProcessor.scala +++ b/mongodb/src/main/scala/com/stratio/crossdata/connector/mongodb/MongoQueryProcessor.scala @@ -39,27 +39,28 @@ object MongoQueryProcessor { type ColumnName = String type Limit = Option[Int] - case class MongoPlan(basePlan: BaseLogicalPlan, limit: Limit){ + case class MongoPlan(basePlan: BaseLogicalPlan, limit: Limit) { def projects: Seq[NamedExpression] = basePlan.projects def filters: Array[SourceFilter] = basePlan.filters } - def apply(logicalPlan: LogicalPlan, config: Config, schemaProvided: Option[StructType] = None) = new MongoQueryProcessor(logicalPlan, config, schemaProvided) + def apply(logicalPlan: LogicalPlan, config: Config, schemaProvided: Option[StructType] = None) = + new MongoQueryProcessor(logicalPlan, config, schemaProvided) def buildNativeQuery( - requiredColums: Seq[ColumnName], - filters: Array[SourceFilter], - config: Config, - name2randomAccess: Map[String, GetArrayItem] = Map.empty - ): (DBObject, DBObject) = { + requiredColums: Seq[ColumnName], + filters: Array[SourceFilter], + config: Config, + name2randomAccess: Map[String, GetArrayItem] = Map.empty + ): (DBObject, DBObject) = { (filtersToDBObject(filters, name2randomAccess)(config), selectFields(requiredColums)) } def filtersToDBObject( - sFilters: Array[SourceFilter], - name2randomAccess: Map[String, GetArrayItem], - parentFilterIsNot: Boolean = false - )(implicit config: Config): DBObject = { + sFilters: Array[SourceFilter], + 
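The credential string convention documented above ('user,database,password;user,database,password;...') can be illustrated without pulling in the datasource classes; the case class below is only a stand-in for MongodbCredentials, introduced for this sketch.

// Illustration only: mirrors the split(";") / split(",") parsing done in
// parseParametersWithoutValidation above.
case class CredentialSketch(user: String, database: String, password: Array[Char])

def parseCredentialsSketch(input: String): List[CredentialSketch] =
  input.split(";").map(_.split(",")).toList.map { parts =>
    CredentialSketch(parts(0), parts(1), parts(2).toCharArray)
  }

// parseCredentialsSketch("alice,admin,secret;bob,test,hunter2") returns two entries.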
name2randomAccess: Map[String, GetArrayItem], + parentFilterIsNot: Boolean = false + )(implicit config: Config): DBObject = { def attstr2left(att: String): String = name2randomAccess.get(att).map { @@ -76,9 +77,13 @@ object MongoQueryProcessor { case sources.GreaterThan(attribute, value) => queryBuilder.put(attstr2left(attribute)).greaterThan(correctIdValue(attribute, value)) case sources.GreaterThanOrEqual(attribute, value) => - queryBuilder.put(attstr2left(attribute)).greaterThanEquals(correctIdValue(attribute, value)) + queryBuilder + .put(attstr2left(attribute)) + .greaterThanEquals(correctIdValue(attribute, value)) case sources.In(attribute, values) => - queryBuilder.put(attstr2left(attribute)).in(values.map(value => correctIdValue(attribute, value))) + queryBuilder + .put(attstr2left(attribute)) + .in(values.map(value => correctIdValue(attribute, value))) case sources.LessThan(attribute, value) => queryBuilder.put(attstr2left(attribute)).lessThan(correctIdValue(attribute, value)) case sources.LessThanOrEqual(attribute, value) => @@ -89,10 +94,10 @@ object MongoQueryProcessor { queryBuilder.put(attstr2left(attribute)).notEquals(null) case sources.And(leftFilter, rightFilter) if !parentFilterIsNot => queryBuilder.and(filtersToDBObject(Array(leftFilter), name2randomAccess), - filtersToDBObject(Array(rightFilter),name2randomAccess)) + filtersToDBObject(Array(rightFilter), name2randomAccess)) case sources.Or(leftFilter, rightFilter) if !parentFilterIsNot => - queryBuilder.or(filtersToDBObject(Array(leftFilter),name2randomAccess), - filtersToDBObject(Array(rightFilter), name2randomAccess)) + queryBuilder.or(filtersToDBObject(Array(leftFilter), name2randomAccess), + filtersToDBObject(Array(rightFilter), name2randomAccess)) case sources.StringStartsWith(attribute, value) if !parentFilterIsNot => queryBuilder.put(attstr2left(attribute)).regex(Pattern.compile("^" + value + ".*$")) case sources.StringEndsWith(attribute, value) if !parentFilterIsNot => @@ -106,50 +111,52 @@ object MongoQueryProcessor { queryBuilder.get } - /** - * Check if the field is "_id" and if the user wants to filter by this field as an ObjectId - * - * @param attribute Name of the file - * @param value Value for the attribute - * @return The value in the correct data type - */ - private def correctIdValue(attribute: String, value: Any)(implicit config: Config) : Any = { - - val idAsObjectId: Boolean = config.getOrElse[String](MongodbConfig.IdAsObjectId, MongodbConfig.DefaultIdAsObjectId).equalsIgnoreCase("true") - - attribute match { - case "_id" if idAsObjectId => new ObjectId(value.toString) - case _ => value - } + /** + * Check if the field is "_id" and if the user wants to filter by this field as an ObjectId + * + * @param attribute Name of the file + * @param value Value for the attribute + * @return The value in the correct data type + */ + private def correctIdValue(attribute: String, value: Any)(implicit config: Config): Any = { + + val idAsObjectId: Boolean = config + .getOrElse[String](MongodbConfig.IdAsObjectId, MongodbConfig.DefaultIdAsObjectId) + .equalsIgnoreCase("true") + + attribute match { + case "_id" if idAsObjectId => new ObjectId(value.toString) + case _ => value } + } - /** - * - * Prepared DBObject used to specify required fields in mongodb 'find' - * @param fields Required fields - * @return A mongodb object that represents required fields. 
- */ - private def selectFields(fields: Seq[ColumnName]): DBObject = - { - MongoDBObject( - fields.toList.filterNot(_ == "_id").map(_ -> 1) ::: { - List("_id" -> fields.find(_ == "_id").fold(0)(_ => 1)) - }) - /* + /** + * + * Prepared DBObject used to specify required fields in mongodb 'find' + * @param fields Required fields + * @return A mongodb object that represents required fields. + */ + private def selectFields(fields: Seq[ColumnName]): DBObject = { + MongoDBObject(fields.toList.filterNot(_ == "_id").map(_ -> 1) ::: { + List("_id" -> fields.find(_ == "_id").fold(0)(_ => 1)) + }) + /* For random accesses to array columns elements, a performance improvement is doable by querying MongoDB in a way that would only select a size-1 slice of the accessed array thanks to the "$slice" operator. However this operator can only be used once for each column in a projection which implies that several accesses (e.g: SELECT arraystring[0] as first, arraystring[3] as fourth FROM MONGO_T) would require to implement an smart "$slice" use selecting the minimum slice containing all requested elements. That requires way too much effort when the performance boost is taken into consideration. - */ - } + */ + } } - // TODO logs, doc, tests -class MongoQueryProcessor(logicalPlan: LogicalPlan, config: Config, schemaProvided: Option[StructType] = None) extends SparkLoggerComponent { +class MongoQueryProcessor(logicalPlan: LogicalPlan, + config: Config, + schemaProvided: Option[StructType] = None) + extends SparkLoggerComponent { import MongoQueryProcessor._ @@ -160,27 +167,33 @@ class MongoQueryProcessor(logicalPlan: LogicalPlan, config: Config, schemaProvid None } else { try { - validatedNativePlan.map { case MongoPlan(bs: SimpleLogicalPlan, limit) => - if (limit.exists(_ == 0)) { - Array.empty[Row] - } else { - val name2randomAccess = bs.collectionRandomAccesses.map { - case (k, v) => s"${k.name}[${v.right}]" -> v + validatedNativePlan.map { + case MongoPlan(bs: SimpleLogicalPlan, limit) => + if (limit.exists(_ == 0)) { + Array.empty[Row] + } else { + val name2randomAccess = bs.collectionRandomAccesses.map { + case (k, v) => s"${k.name}[${v.right}]" -> v + } + val (mongoFilters, mongoRequiredColumns) = buildNativeQuery( + bs.projects.map(_.name), + bs.filters, + config, + name2randomAccess + ) + val resultSet = MongodbConnection.withCollectionDo(config) { collection => + logDebug( + s"Executing native query: filters => $mongoFilters projects => $mongoRequiredColumns") + val cursor = collection.find(mongoFilters, mongoRequiredColumns) + val result = cursor.limit(limit.getOrElse(DefaultLimit)).toArray[DBObject] + cursor.close() + result + } + sparkResultFromMongodb(bs.projects, + bs.collectionRandomAccesses, + schemaProvided.get, + resultSet) } - val (mongoFilters, mongoRequiredColumns) = buildNativeQuery( - bs.projects.map(_.name), bs.filters, - config, - name2randomAccess - ) - val resultSet = MongodbConnection.withCollectionDo(config) { collection => - logDebug(s"Executing native query: filters => $mongoFilters projects => $mongoRequiredColumns") - val cursor = collection.find(mongoFilters, mongoRequiredColumns) - val result = cursor.limit(limit.getOrElse(DefaultLimit)).toArray[DBObject] - cursor.close() - result - } - sparkResultFromMongodb(bs.projects, bs.collectionRandomAccesses, schemaProvided.get, resultSet) - } } } catch { case exc: Exception => @@ -190,66 +203,71 @@ class MongoQueryProcessor(logicalPlan: LogicalPlan, config: Config, schemaProvid } + def validatedNativePlan: Option[_] = { // TODO + 
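The projection that selectFields above hands to MongoDBObject can be traced with a dependency-free sketch of the generated key/value pairs: every requested column maps to 1, and "_id" is explicitly switched on or off.

// Illustration only: same pair-building logic as selectFields, without casbah.
def projectionPairs(fields: Seq[String]): List[(String, Int)] =
  fields.toList.filterNot(_ == "_id").map(_ -> 1) :::
    List("_id" -> fields.find(_ == "_id").fold(0)(_ => 1))

// projectionPairs(Seq("name", "age")) == List("name" -> 1, "age" -> 1, "_id" -> 0)
// projectionPairs(Seq("_id", "name")) == List("name" -> 1, "_id" -> 1)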
lazy val limit: Option[Int] = logicalPlan.collectFirst { + case LogicalLimit(Literal(num: Int, _), _) => num + } - def validatedNativePlan: Option[_] = {// TODO - lazy val limit: Option[Int] = logicalPlan.collectFirst { case LogicalLimit(Literal(num: Int, _), _) => num } - - def findBasePlan(lplan: LogicalPlan): Option[BaseLogicalPlan] = lplan match { - - case LogicalLimit(_, child) => - findBasePlan(child) + def findBasePlan(lplan: LogicalPlan): Option[BaseLogicalPlan] = + lplan match { + + case LogicalLimit(_, child) => + findBasePlan(child) + + case PhysicalOperation(projectList, filterList, _) => + CatalystToCrossdataAdapter + .getConnectorLogicalPlan(logicalPlan, projectList, filterList) match { + case (_, ProjectReport(exprIgnored), FilterReport(filtersIgnored, _)) + if filtersIgnored.nonEmpty || exprIgnored.nonEmpty => + None + case (basePlan: SimpleLogicalPlan, _, _) => + Some(basePlan) + case _ => ??? // TODO + } - case PhysicalOperation(projectList, filterList, _) => - CatalystToCrossdataAdapter.getConnectorLogicalPlan(logicalPlan, projectList, filterList) match { - case (_, ProjectReport(exprIgnored), FilterReport(filtersIgnored, _)) if filtersIgnored.nonEmpty || exprIgnored.nonEmpty => - None - case (basePlan: SimpleLogicalPlan, _, _) => - Some(basePlan) - case _ => ??? // TODO - } + } + findBasePlan(logicalPlan).collect { + case bp if checkNativeFilters(bp.filters) => MongoPlan(bp, limit) } - - findBasePlan(logicalPlan).collect{ case bp if checkNativeFilters(bp.filters) => MongoPlan(bp, limit) } } + private[this] def checkNativeFilters(filters: Seq[SourceFilter]): Boolean = + filters.forall { + case _: sources.EqualTo => true + case _: sources.In => true + case _: sources.LessThan => true + case _: sources.GreaterThan => true + case _: sources.LessThanOrEqual => true + case _: sources.GreaterThanOrEqual => true + case _: sources.IsNull => true + case _: sources.IsNotNull => true + case _: sources.StringStartsWith => true + case _: sources.StringEndsWith => true + case _: sources.StringContains => true + case sources.And(left, right) => checkNativeFilters(Array(left, right)) + case sources.Or(left, right) => checkNativeFilters(Array(left, right)) + case sources.Not(filter) => checkNativeFilters(Array(filter)) + // TODO add more filters + case _ => false - private[this] def checkNativeFilters(filters: Seq[SourceFilter]): Boolean = filters.forall { - case _: sources.EqualTo => true - case _: sources.In => true - case _: sources.LessThan => true - case _: sources.GreaterThan => true - case _: sources.LessThanOrEqual => true - case _: sources.GreaterThanOrEqual => true - case _: sources.IsNull => true - case _: sources.IsNotNull => true - case _: sources.StringStartsWith => true - case _: sources.StringEndsWith => true - case _: sources.StringContains => true - case sources.And(left, right) => checkNativeFilters(Array(left, right)) - case sources.Or(left, right) => checkNativeFilters(Array(left, right)) - case sources.Not(filter) => checkNativeFilters(Array(filter)) - // TODO add more filters - case _ => false - - } + } private[this] def sparkResultFromMongodb( - requiredColumns: Seq[Attribute], - indexAccesses: Map[Attribute, GetArrayItem], - schema: StructType, - resultSet: Array[DBObject] - ): Array[Row] = { + requiredColumns: Seq[Attribute], + indexAccesses: Map[Attribute, GetArrayItem], + schema: StructType, + resultSet: Array[DBObject] + ): Array[Row] = { asRow( - pruneSchema( - schema, - requiredColumns.map(r => r.name -> 
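checkNativeFilters above whitelists the Spark source filters the connector can push down and recurses into And/Or/Not; any other predicate rejects the whole native plan. Hypothetical inputs for that check, assuming Spark's public filter ADT in the version this project builds against:

object FilterExamples {
  import org.apache.spark.sql.sources

  // A compound AND of supported leaves passes the whitelist; EqualNullSafe is
  // not listed above, so a plan containing it falls back to Spark execution.
  val pushable    = sources.And(sources.EqualTo("id", 3), sources.GreaterThan("age", 12))
  val notPushable = sources.EqualNullSafe("id", 3)
}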
indexAccesses.get(r).map(_.right.toString().toInt)).toArray - ), - resultSet + pruneSchema( + schema, + requiredColumns + .map(r => r.name -> indexAccesses.get(r).map(_.right.toString().toInt)) + .toArray + ), + resultSet ) } - } - - diff --git a/mongodb/src/main/scala/com/stratio/crossdata/connector/mongodb/MongodbXDRelation.scala b/mongodb/src/main/scala/com/stratio/crossdata/connector/mongodb/MongodbXDRelation.scala index a98f156fc..4ab9dbfbf 100644 --- a/mongodb/src/main/scala/com/stratio/crossdata/connector/mongodb/MongodbXDRelation.scala +++ b/mongodb/src/main/scala/com/stratio/crossdata/connector/mongodb/MongodbXDRelation.scala @@ -24,21 +24,21 @@ import org.apache.spark.sql.types.StructType import org.apache.spark.sql.{Row, SQLContext} /** - * A MongoDB baseRelation that can eliminate unneeded columns - * and filter using selected predicates before producing - * an RDD containing all matching tuples as Row objects. - * @param config A Deep configuration with needed properties for MongoDB - * @param schemaProvided The optionally provided schema. If not provided, - * it will be inferred from the whole field projection - * of the specified table in Spark SQL statement using - * a sample ratio (as JSON Data Source does). - * @param sqlContext An existing Spark SQL context. - */ -case class MongodbXDRelation(config: Config, - schemaProvided: Option[StructType] = None)( - @transient sqlContext: SQLContext) - extends MongodbRelation(config, schemaProvided)(sqlContext) with NativeScan with SparkLoggerComponent{ - + * A MongoDB baseRelation that can eliminate unneeded columns + * and filter using selected predicates before producing + * an RDD containing all matching tuples as Row objects. + * @param config A Deep configuration with needed properties for MongoDB + * @param schemaProvided The optionally provided schema. If not provided, + * it will be inferred from the whole field projection + * of the specified table in Spark SQL statement using + * a sample ratio (as JSON Data Source does). + * @param sqlContext An existing Spark SQL context. 
+ */ +case class MongodbXDRelation(config: Config, schemaProvided: Option[StructType] = None)( + @transient sqlContext: SQLContext) + extends MongodbRelation(config, schemaProvided)(sqlContext) + with NativeScan + with SparkLoggerComponent { override def buildScan(optimizedLogicalPlan: LogicalPlan): Option[Array[Row]] = { logDebug(s"Processing ${optimizedLogicalPlan.toString()}") @@ -46,15 +46,19 @@ case class MongodbXDRelation(config: Config, queryExecutor.execute() } - override def isSupported(logicalStep: LogicalPlan, wholeLogicalPlan: LogicalPlan): Boolean = logicalStep match { - case ln: LeafNode => true // TODO leafNode == LogicalRelation(xdSourceRelation) - case un: UnaryNode => un match { - case Limit(_, _) | Project(_, _) | Filter(_, _) => true - case _ => false + override def isSupported(logicalStep: LogicalPlan, wholeLogicalPlan: LogicalPlan): Boolean = + logicalStep match { + case ln: LeafNode => + true // TODO leafNode == LogicalRelation(xdSourceRelation) + case un: UnaryNode => + un match { + case Limit(_, _) | Project(_, _) | Filter(_, _) => true + case _ => false + } + case unsupportedLogicalPlan => + logDebug(s"LogicalPlan $unsupportedLogicalPlan cannot be executed natively"); + false } - case unsupportedLogicalPlan =>logDebug(s"LogicalPlan $unsupportedLogicalPlan cannot be executed natively"); false - } - -} \ No newline at end of file +} diff --git a/mongodb/src/main/scala/com/stratio/datasource/mongodb/MongodbConnection.scala b/mongodb/src/main/scala/com/stratio/datasource/mongodb/MongodbConnection.scala index ecbcb584e..25dc43b23 100644 --- a/mongodb/src/main/scala/com/stratio/datasource/mongodb/MongodbConnection.scala +++ b/mongodb/src/main/scala/com/stratio/datasource/mongodb/MongodbConnection.scala @@ -40,7 +40,7 @@ object MongodbConnection { } private def openClient(config: Config): MongoClient = - MongodbClientFactory.getClient(config.hosts, config.credentials, config.sslOptions, config.clientOptions) + MongodbClientFactory + .getClient(config.hosts, config.credentials, config.sslOptions, config.clientOptions) - -} \ No newline at end of file +} diff --git a/mongodb/src/test/scala/com/stratio/crossdata/connector/mongodb/MongoAggregationIT.scala b/mongodb/src/test/scala/com/stratio/crossdata/connector/mongodb/MongoAggregationIT.scala index 7f31aac29..c36b987b9 100644 --- a/mongodb/src/test/scala/com/stratio/crossdata/connector/mongodb/MongoAggregationIT.scala +++ b/mongodb/src/test/scala/com/stratio/crossdata/connector/mongodb/MongoAggregationIT.scala @@ -35,11 +35,12 @@ class MongoAggregationIT extends MongoWithSharedContext { ignore should "execute natively a (SELECT max(col), min(col), avg(col), sum(col), first(col), last(col) FROM _)" in { assumeEnvironmentIsUpAndRunning - val dataframe = sql(s"SELECT max(age), min(age), avg(age), sum(age), first(age), last(age) FROM $Collection") + val dataframe = + sql(s"SELECT max(age), min(age), avg(age), sum(age), first(age), last(age) FROM $Collection") val result = dataframe.collect(Native) result should have length 1 result(0) should have length 6 - result(0).toSeq should be (Seq(20,10,15,84,1,2)) + result(0).toSeq should be(Seq(20, 10, 15, 84, 1, 2)) // TODO update comparation taking into account real values } @@ -55,14 +56,16 @@ class MongoAggregationIT extends MongoWithSharedContext { ignore should "execute natively a (SELECT col, count(*) FROM _ GROUP BY col)" in { assumeEnvironmentIsUpAndRunning - val result = sql(s"SELECT enrolled, count(*) FROM $Collection GROUP BY enrolled").collect(Native) + val result = + 
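isSupported above only admits leaf nodes plus Limit, Project and Filter, so a projection with a WHERE clause and a LIMIT can stay native, while anything else, such as the ORDER BY exercised in MongoConnectorIT below, is delegated to Spark. A hypothetical sketch, with runSql standing in for xdContext.sql from the examples earlier in this diff and 'students' standing in for a registered table:

import org.apache.spark.sql.DataFrame

// Illustration only: which query shapes the relation can attempt natively.
def nativeVsSpark(runSql: String => DataFrame): Unit = {
  runSql("SELECT name FROM students WHERE id = 3 LIMIT 5").collect() // Project + Filter + Limit: native-capable
  runSql("SELECT name FROM students ORDER BY name").collect()        // Sort is not accepted: Spark fallback
}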
sql(s"SELECT enrolled, count(*) FROM $Collection GROUP BY enrolled").collect(Native) ??? } ignore should "execute natively a (SELECT col1, col2, count(*) FROM _ GROUP BY col1, col2)" in { assumeEnvironmentIsUpAndRunning - val result = sql(s"SELECT id, enrolled, count(*) FROM $Collection GROUP BY id, enrolled").collect(Native) + val result = + sql(s"SELECT id, enrolled, count(*) FROM $Collection GROUP BY id, enrolled").collect(Native) ??? } @@ -76,7 +79,8 @@ class MongoAggregationIT extends MongoWithSharedContext { ignore should "execute natively a (SELECT count(*) FROM _ GROUP BY _ WHERE filterCondition)" in { assumeEnvironmentIsUpAndRunning - val result = sql(s"SELECT count(*) FROM $Collection GROUP BY enrolled WHERE id > 5").collect(Native) + val result = + sql(s"SELECT count(*) FROM $Collection GROUP BY enrolled WHERE id > 5").collect(Native) ??? } @@ -84,7 +88,8 @@ class MongoAggregationIT extends MongoWithSharedContext { ignore should "execute natively a (SELECT count(DISTINCT col) FROM _ GROUP BY col WHERE filterCondition)" in { assumeEnvironmentIsUpAndRunning - val result = sql(s"SELECT count(DISTINCT age) FROM $Collection GROUP BY enrolled WHERE id > 5").collect(Native) + val result = sql(s"SELECT count(DISTINCT age) FROM $Collection GROUP BY enrolled WHERE id > 5") + .collect(Native) ??? } @@ -92,7 +97,8 @@ class MongoAggregationIT extends MongoWithSharedContext { ignore should "execute natively a (SELECT name, sum(age) FROM _ GROUP BY _ HAVING _)" in { assumeEnvironmentIsUpAndRunning - val result = sql(s"SELECT name, sum(age) FROM $Collection GROUP BY name HAVING sum(age) > 25").collect(Native) + val result = sql(s"SELECT name, sum(age) FROM $Collection GROUP BY name HAVING sum(age) > 25") + .collect(Native) ??? } diff --git a/mongodb/src/test/scala/com/stratio/crossdata/connector/mongodb/MongoConnectorIT.scala b/mongodb/src/test/scala/com/stratio/crossdata/connector/mongodb/MongoConnectorIT.scala index 2213a5313..23f65c76f 100644 --- a/mongodb/src/test/scala/com/stratio/crossdata/connector/mongodb/MongoConnectorIT.scala +++ b/mongodb/src/test/scala/com/stratio/crossdata/connector/mongodb/MongoConnectorIT.scala @@ -33,19 +33,19 @@ class MongoConnectorIT extends MongoWithSharedContext { val schema = dataframe.schema val result = dataframe.collect(Native) result should have length 10 - schema.fieldNames should equal (Seq("id", "age", "description", "enrolled", "name", "optionalField")) - result.head.toSeq should equal (Seq(1, 11, "description1", false, "Name 1", null)) + schema.fieldNames should equal( + Seq("id", "age", "description", "enrolled", "name", "optionalField")) + result.head.toSeq should equal(Seq(1, 11, "description1", false, "Name 1", null)) } - it should "return the columns in the requested order" in { assumeEnvironmentIsUpAndRunning val dataframe = sql(s"SELECT name, id FROM $Collection ") val schema = dataframe.schema val result = dataframe.collect(Native) result should have length 10 - schema.fieldNames should equal (Seq("name", "id")) - result.head.toSeq should equal (Seq( "Name 1", 1)) + schema.fieldNames should equal(Seq("name", "id")) + result.head.toSeq should equal(Seq("Name 1", 1)) } it should "execute natively a simple project" in { @@ -123,7 +123,9 @@ class MongoConnectorIT extends MongoWithSharedContext { it should "execute natively an AND filter" in { assumeEnvironmentIsUpAndRunning - val result = sql(s"SELECT * FROM $Collection WHERE id = 3 AND age = 13 AND description = 'description3' ").collect(Native) + val result = + sql(s"SELECT * FROM $Collection 
WHERE id = 3 AND age = 13 AND description = 'description3' ") + .collect(Native) result should have length 1 } @@ -135,7 +137,9 @@ class MongoConnectorIT extends MongoWithSharedContext { it should "execute natively an AND filter in a complex query (nested filters)" in { assumeEnvironmentIsUpAndRunning - val result = sql(s"SELECT * FROM $Collection WHERE id = 3 OR (age = 14 AND description = 'description4') ").collect(Native) + val result = sql( + s"SELECT * FROM $Collection WHERE id = 3 OR (age = 14 AND description = 'description4') ") + .collect(Native) result should have length 2 } @@ -145,7 +149,6 @@ class MongoConnectorIT extends MongoWithSharedContext { result should have length 0 } - // NOT SUPPORTED => JOIN it should "not execute natively a (SELECT * ... ORDER BY _ )" in { assumeEnvironmentIsUpAndRunning @@ -155,19 +158,16 @@ class MongoConnectorIT extends MongoWithSharedContext { } should have message "The operation cannot be executed without Spark" } - - // IMPORT OPERATIONS it should "import all user collections" in { assumeEnvironmentIsUpAndRunning //This crates a new collection in the database which will not be initially registered at the Spark - val client = MongoClient(MongoHost, MongoPort)(Database)(UnregisteredCollection).insert(MongoDBObject("id" -> 1)) + val client = MongoClient(MongoHost, MongoPort)(Database)(UnregisteredCollection) + .insert(MongoDBObject("id" -> 1)) - val result = - sql( - s""" + val result = sql(s""" |IMPORT TABLES |USING $SourceProvider |OPTIONS ( @@ -176,8 +176,9 @@ class MongoConnectorIT extends MongoWithSharedContext { |) """.stripMargin) - val imported = result.collect().exists{ row => - val tableFound = row.getSeq(row.fieldIndex("tableIdentifier")) == Seq(Database, UnregisteredCollection) + val imported = result.collect().exists { row => + val tableFound = row.getSeq(row.fieldIndex("tableIdentifier")) == Seq(Database, + UnregisteredCollection) val isIgnored = row.getBoolean(row.fieldIndex("ignored")) tableFound && !isIgnored } @@ -192,8 +193,3 @@ class MongoConnectorIT extends MongoWithSharedContext { } } - - - - - diff --git a/mongodb/src/test/scala/com/stratio/crossdata/connector/mongodb/MongoCreateExternalTableIT.scala b/mongodb/src/test/scala/com/stratio/crossdata/connector/mongodb/MongoCreateExternalTableIT.scala index e47110b9b..948754601 100644 --- a/mongodb/src/test/scala/com/stratio/crossdata/connector/mongodb/MongoCreateExternalTableIT.scala +++ b/mongodb/src/test/scala/com/stratio/crossdata/connector/mongodb/MongoCreateExternalTableIT.scala @@ -38,10 +38,9 @@ class MongoCreateExternalTableIT extends MongoWithSharedContext { //Expectations val table = xdContext.table(s"$Database.newtable") table should not be null - table.schema.fieldNames should contain ("name") + table.schema.fieldNames should contain("name") } - it should "execute a CREATE EXTERNAL TABLE with options" in { val createTableQUeryString = @@ -61,8 +60,8 @@ class MongoCreateExternalTableIT extends MongoWithSharedContext { //Expectations val table = xdContext.table(s"$Database.cappedTable") table should not be null - table.schema.fieldNames should contain ("name") - this.client.get.getDB(Database).getCollection("cappedTable").isCapped should be (true) + table.schema.fieldNames should contain("name") + this.client.get.getDB(Database).getCollection("cappedTable").isCapped should be(true) } it should "execute a CREATE EXTERNAL TABLE with a different tableName" in { @@ -83,8 +82,8 @@ class MongoCreateExternalTableIT extends MongoWithSharedContext { //Expectations val table = 
xdContext.table("other") table should not be null - table.schema.fieldNames should contain ("name") - this.client.get.getDB(Database).getCollection("cTable").isCapped should be (true) + table.schema.fieldNames should contain("name") + this.client.get.getDB(Database).getCollection("cTable").isCapped should be(true) } it should "execute a CREATE EXTERNAL TABLE without specific db and table options options" in { @@ -103,8 +102,8 @@ class MongoCreateExternalTableIT extends MongoWithSharedContext { //Expectations val table = xdContext.table("dbase.tbase") table should not be null - table.schema.fieldNames should contain ("name") - this.client.get.getDB("dbase").getCollection("tbase").isCapped should be (true) + table.schema.fieldNames should contain("name") + this.client.get.getDB("dbase").getCollection("tbase").isCapped should be(true) this.client.get.dropDatabase("dbase") } diff --git a/mongodb/src/test/scala/com/stratio/crossdata/connector/mongodb/MongoDataTypesCollection.scala b/mongodb/src/test/scala/com/stratio/crossdata/connector/mongodb/MongoDataTypesCollection.scala index 0cc08556f..02eb577e2 100644 --- a/mongodb/src/test/scala/com/stratio/crossdata/connector/mongodb/MongoDataTypesCollection.scala +++ b/mongodb/src/test/scala/com/stratio/crossdata/connector/mongodb/MongoDataTypesCollection.scala @@ -25,13 +25,12 @@ import scala.util.Try trait MongoDataTypesCollection extends MongoWithSharedContext with SharedXDContextTypesTest { - override val emptyTypesSetError: String = "Type test entries should have been already inserted" override def dataTypesSparkOptions: Map[String, String] = Map( - "host" -> s"$MongoHost:$MongoPort", - "database" -> s"$Database", - "collection" -> s"$DataTypesCollection" + "host" -> s"$MongoHost:$MongoPort", + "database" -> s"$Database", + "collection" -> s"$DataTypesCollection" ) override def saveTypesData: Int = { @@ -44,39 +43,40 @@ trait MongoDataTypesCollection extends MongoWithSharedContext with SharedXDConte baseDate.set(Calendar.MILLISECOND, a) dataTypesCollection.insert { MongoDBObject( - "int" -> (2000 + a), - "bigint" -> (200000 + a).toLong, - "long" -> (200000 + a).toLong, - "string" -> s"String $a", - "boolean" -> true, - "double" -> (9.0 + (a.toDouble / 10)), - "float" -> float, - "decimalint" -> decimalInt, - "decimallong" -> decimalLong, - "decimaldouble" -> decimalDouble, - "decimalfloat" -> decimalFloat, - "date" -> new java.sql.Date(baseDate.getTimeInMillis), - "timestamp" -> new java.sql.Timestamp(baseDate.getTimeInMillis), - "tinyint" -> tinyint, - "smallint" -> smallint, - "binary" -> binary, - "arrayint" -> arrayint, - "arraystring" -> arraystring, - "mapintint" -> mapintint, - "mapstringint" -> mapstringint, - "mapstringstring" -> mapstringstring, - "struct" -> struct, - "arraystruct" -> arraystruct, - "arraystructwithdate" -> arraystructwithdate, - "structofstruct" -> structofstruct, - "mapstruct" -> mapstruct, - "arraystructarraystruct" -> arraystructarraystruct + "int" -> (2000 + a), + "bigint" -> (200000 + a).toLong, + "long" -> (200000 + a).toLong, + "string" -> s"String $a", + "boolean" -> true, + "double" -> (9.0 + (a.toDouble / 10)), + "float" -> float, + "decimalint" -> decimalInt, + "decimallong" -> decimalLong, + "decimaldouble" -> decimalDouble, + "decimalfloat" -> decimalFloat, + "date" -> new java.sql.Date(baseDate.getTimeInMillis), + "timestamp" -> new java.sql.Timestamp(baseDate.getTimeInMillis), + "tinyint" -> tinyint, + "smallint" -> smallint, + "binary" -> binary, + "arrayint" -> arrayint, + "arraystring" -> 
arraystring, + "mapintint" -> mapintint, + "mapstringint" -> mapstringint, + "mapstringstring" -> mapstringstring, + "struct" -> struct, + "arraystruct" -> arraystruct, + "arraystructwithdate" -> arraystructwithdate, + "structofstruct" -> structofstruct, + "mapstruct" -> mapstruct, + "arraystructarraystruct" -> arraystructarraystruct ) } } }.map(_ => 1).getOrElse(0) } - override def sparkRegisterTableSQL: Seq[SparkTable] = super.sparkRegisterTableSQL + override def sparkRegisterTableSQL: Seq[SparkTable] = + super.sparkRegisterTableSQL } diff --git a/mongodb/src/test/scala/com/stratio/crossdata/connector/mongodb/MongoDotsNotationIT.scala b/mongodb/src/test/scala/com/stratio/crossdata/connector/mongodb/MongoDotsNotationIT.scala index a12f9a083..3265c06ba 100644 --- a/mongodb/src/test/scala/com/stratio/crossdata/connector/mongodb/MongoDotsNotationIT.scala +++ b/mongodb/src/test/scala/com/stratio/crossdata/connector/mongodb/MongoDotsNotationIT.scala @@ -25,46 +25,53 @@ import org.scalatest.junit.JUnitRunner @RunWith(classOf[JUnitRunner]) class MongoDotsNotationIT extends MongoDataTypesCollection { - it should "supports Projection with DOT notation using Spark" in { assumeEnvironmentIsUpAndRunning - val sparkRow = sql(s"SELECT bigint, structofstruct.field1 FROM ${SharedXDContextTypesTest.dataTypesTableName}").collect(ExecutionType.Spark) + val sparkRow = sql( + s"SELECT bigint, structofstruct.field1 FROM ${SharedXDContextTypesTest.dataTypesTableName}") + .collect(ExecutionType.Spark) - sparkRow.head.schema.size should be (2) - sparkRow.head.schema.head.isInstanceOf[StructField] should be (true) + sparkRow.head.schema.size should be(2) + sparkRow.head.schema.head.isInstanceOf[StructField] should be(true) } it should "supports Projection with DOT notation with no ExecutionType defined" in { assumeEnvironmentIsUpAndRunning - val sparkRow = sql(s"SELECT bigint, structofstruct.field1 FROM ${SharedXDContextTypesTest.dataTypesTableName}").collect() + val sparkRow = sql( + s"SELECT bigint, structofstruct.field1 FROM ${SharedXDContextTypesTest.dataTypesTableName}") + .collect() - sparkRow.head.schema.size should be (2) - sparkRow.head.schema.head.isInstanceOf[StructField] should be (true) + sparkRow.head.schema.size should be(2) + sparkRow.head.schema.head.isInstanceOf[StructField] should be(true) } it should "Does not supports Projection with DOT notation in Native" in { assumeEnvironmentIsUpAndRunning - val df = sql(s"SELECT bigint, structofstruct.field1 FROM ${SharedXDContextTypesTest.dataTypesTableName}") + val df = sql( + s"SELECT bigint, structofstruct.field1 FROM ${SharedXDContextTypesTest.dataTypesTableName}") - an [CrossdataException] should be thrownBy df.collect(ExecutionType.Native) + an[CrossdataException] should be thrownBy df.collect(ExecutionType.Native) } it should "supports Filters with DOT notation with no ExecutionType defined" in { assumeEnvironmentIsUpAndRunning - val sparkRow = sql(s"SELECT int FROM ${SharedXDContextTypesTest.dataTypesTableName} WHERE struct.field2=3").collect() + val sparkRow = + sql(s"SELECT int FROM ${SharedXDContextTypesTest.dataTypesTableName} WHERE struct.field2=3") + .collect() - sparkRow.length should be (10) + sparkRow.length should be(10) } it should "Does not supports Filters with DOT notation in Native" in { assumeEnvironmentIsUpAndRunning - val df = sql(s"SELECT int FROM ${SharedXDContextTypesTest.dataTypesTableName} WHERE struct.field2=3") + val df = + sql(s"SELECT int FROM ${SharedXDContextTypesTest.dataTypesTableName} WHERE struct.field2=3") - an 
[CrossdataException] should be thrownBy df.collect(ExecutionType.Native) + an[CrossdataException] should be thrownBy df.collect(ExecutionType.Native) } } diff --git a/mongodb/src/test/scala/com/stratio/crossdata/connector/mongodb/MongoDropExternalTableIT.scala b/mongodb/src/test/scala/com/stratio/crossdata/connector/mongodb/MongoDropExternalTableIT.scala index 8554716f4..ff10285c0 100644 --- a/mongodb/src/test/scala/com/stratio/crossdata/connector/mongodb/MongoDropExternalTableIT.scala +++ b/mongodb/src/test/scala/com/stratio/crossdata/connector/mongodb/MongoDropExternalTableIT.scala @@ -20,7 +20,6 @@ import org.scalatest.junit.JUnitRunner import scala.collection.Seq - @RunWith(classOf[JUnitRunner]) class MongoDropExternalTableIT extends MongoWithSharedContext { @@ -28,8 +27,7 @@ class MongoDropExternalTableIT extends MongoWithSharedContext { super.beforeAll() //Create test tables - val createTable1 = - s"""|CREATE EXTERNAL TABLE $Database.drop1 (id Integer, name String) + val createTable1 = s"""|CREATE EXTERNAL TABLE $Database.drop1 (id Integer, name String) USING $SourceProvider |OPTIONS ( |host '$MongoHost:$MongoPort', @@ -37,10 +35,9 @@ class MongoDropExternalTableIT extends MongoWithSharedContext { |collection 'drop1' |) """.stripMargin.replaceAll("\n", " ") - sql(createTable1).collect() + sql(createTable1).collect() - val createTable2 = - s"""|CREATE EXTERNAL TABLE drop2 (id Integer, name String) + val createTable2 = s"""|CREATE EXTERNAL TABLE drop2 (id Integer, name String) USING $SourceProvider |OPTIONS ( |host '$MongoHost:$MongoPort', @@ -51,7 +48,6 @@ class MongoDropExternalTableIT extends MongoWithSharedContext { sql(createTable2).collect() } - "The Mongo connector" should "execute a DROP EXTERNAL TABLE" in { //Precondition @@ -61,11 +57,11 @@ class MongoDropExternalTableIT extends MongoWithSharedContext { //DROP val dropExternalTableQuery = s"DROP EXTERNAL TABLE $Database.drop1" - sql(dropExternalTableQuery).collect() should be (Seq.empty) + sql(dropExternalTableQuery).collect() should be(Seq.empty) //Expectations an[Exception] shouldBe thrownBy(xdContext.table(s"$Database.drop1")) - this.client.get.getDB(Database).collectionExists(mongoTableName) should be (false) + this.client.get.getDB(Database).collectionExists(mongoTableName) should be(false) } @@ -78,11 +74,11 @@ class MongoDropExternalTableIT extends MongoWithSharedContext { //DROP val dropExternalTableQuery = "DROP EXTERNAL TABLE drop2" - sql(dropExternalTableQuery).collect() should be (Seq.empty) + sql(dropExternalTableQuery).collect() should be(Seq.empty) //Expectations an[Exception] shouldBe thrownBy(xdContext.table("drop2")) - this.client.get.getDB(Database).collectionExists(mongoTableName) should be (false) + this.client.get.getDB(Database).collectionExists(mongoTableName) should be(false) } diff --git a/mongodb/src/test/scala/com/stratio/crossdata/connector/mongodb/MongoFilterIT.scala b/mongodb/src/test/scala/com/stratio/crossdata/connector/mongodb/MongoFilterIT.scala index a2bb13c09..131b66e50 100644 --- a/mongodb/src/test/scala/com/stratio/crossdata/connector/mongodb/MongoFilterIT.scala +++ b/mongodb/src/test/scala/com/stratio/crossdata/connector/mongodb/MongoFilterIT.scala @@ -27,76 +27,94 @@ class MongoFilterIT extends MongoDataTypesCollection { "MongoConnector" should "supports NOT BETWEEN by spark" in { assumeEnvironmentIsUpAndRunning - val sparkRow = sql(s"SELECT id FROM $Collection WHERE id NOT BETWEEN 2 AND 10").collect(ExecutionType.Spark).head + val sparkRow = sql(s"SELECT id FROM $Collection WHERE id 
NOT BETWEEN 2 AND 10") + .collect(ExecutionType.Spark) + .head val result = Row(1) - sparkRow should be (result) + sparkRow should be(result) } it should "supports equals AND NOT IN by spark" in { assumeEnvironmentIsUpAndRunning - val sparkRow = sql(s"SELECT id FROM $Collection WHERE id = 6 AND id NOT IN (2,3,4,5)").collect(ExecutionType.Spark).head + val sparkRow = sql(s"SELECT id FROM $Collection WHERE id = 6 AND id NOT IN (2,3,4,5)") + .collect(ExecutionType.Spark) + .head val result = Row(6) - sparkRow should be (result) + sparkRow should be(result) } it should "supports NOT LIKE by spark" in { assumeEnvironmentIsUpAndRunning - val sparkRow = sql(s"SELECT description FROM $Collection WHERE description NOT LIKE 'description1'").collect(ExecutionType.Spark) + val sparkRow = + sql(s"SELECT description FROM $Collection WHERE description NOT LIKE 'description1'") + .collect(ExecutionType.Spark) val result = (2 to 10).map(n => Row(s"description$n")).toArray - sparkRow should be (result) + sparkRow should be(result) } it should "supports filter DATE greater than" in { assumeEnvironmentIsUpAndRunning - val sparkRow = sql(s"SELECT date FROM ${SharedXDContextTypesTest.dataTypesTableName} WHERE date > '1970'").collect(ExecutionType.Spark) - sparkRow.length should be (10) + val sparkRow = + sql(s"SELECT date FROM ${SharedXDContextTypesTest.dataTypesTableName} WHERE date > '1970'") + .collect(ExecutionType.Spark) + sparkRow.length should be(10) } it should "supports filter DATE equals to" in { assumeEnvironmentIsUpAndRunning - val sparkRow = sql(s"SELECT date FROM ${SharedXDContextTypesTest.dataTypesTableName} WHERE date = '1970-01-02'").collect(ExecutionType.Spark) - sparkRow.length should be (1) + val sparkRow = sql( + s"SELECT date FROM ${SharedXDContextTypesTest.dataTypesTableName} WHERE date = '1970-01-02'") + .collect(ExecutionType.Spark) + sparkRow.length should be(1) } it should "supports filter DATE BETWEEN two dates" in { assumeEnvironmentIsUpAndRunning - val sparkRow = sql(s"SELECT date FROM ${SharedXDContextTypesTest.dataTypesTableName} WHERE date BETWEEN '1970' AND '1971'").collect(ExecutionType.Spark) - sparkRow.length should be (10) + val sparkRow = sql( + s"SELECT date FROM ${SharedXDContextTypesTest.dataTypesTableName} WHERE date BETWEEN '1970' AND '1971'") + .collect(ExecutionType.Spark) + sparkRow.length should be(10) } it should "supports filter DATE NOT BETWEEN two dates" in { assumeEnvironmentIsUpAndRunning - val sparkRow = sql(s"SELECT date FROM ${SharedXDContextTypesTest.dataTypesTableName} WHERE date NOT BETWEEN '1970-01-02' AND '1971'").collect(ExecutionType.Spark) - sparkRow.length should be (1) + val sparkRow = sql( + s"SELECT date FROM ${SharedXDContextTypesTest.dataTypesTableName} WHERE date NOT BETWEEN '1970-01-02' AND '1971'") + .collect(ExecutionType.Spark) + sparkRow.length should be(1) } it should "supports filter TIMESTAMP greater than" in { assumeEnvironmentIsUpAndRunning - val sparkRow = sql(s"SELECT timestamp FROM ${SharedXDContextTypesTest.dataTypesTableName} WHERE timestamp > '1970'").collect(ExecutionType.Spark) - sparkRow.length should be (10) + val sparkRow = sql( + s"SELECT timestamp FROM ${SharedXDContextTypesTest.dataTypesTableName} WHERE timestamp > '1970'") + .collect(ExecutionType.Spark) + sparkRow.length should be(10) } it should "supports filter TIMESTAMP equals to" in { assumeEnvironmentIsUpAndRunning - val sparkRow = sql(s"SELECT timestamp FROM ${SharedXDContextTypesTest.dataTypesTableName} WHERE timestamp = '1970-01-02 
00:0:0.002'").collect(ExecutionType.Native) - sparkRow.head(0) should be (java.sql.Timestamp.valueOf("1970-01-02 00:00:00.002")) + val sparkRow = sql( + s"SELECT timestamp FROM ${SharedXDContextTypesTest.dataTypesTableName} WHERE timestamp = '1970-01-02 00:0:0.002'") + .collect(ExecutionType.Native) + sparkRow.head(0) should be(java.sql.Timestamp.valueOf("1970-01-02 00:00:00.002")) } @@ -104,56 +122,66 @@ class MongoFilterIT extends MongoDataTypesCollection { it should "supports filter TIMESTAMP less or equals to" in { assumeEnvironmentIsUpAndRunning - val sparkRow = sql(s"SELECT timestamp FROM ${SharedXDContextTypesTest.dataTypesTableName} WHERE timestamp <= '1970-01-01 00:00:00.001'").collect(ExecutionType.Native) + val sparkRow = sql( + s"SELECT timestamp FROM ${SharedXDContextTypesTest.dataTypesTableName} WHERE timestamp <= '1970-01-01 00:00:00.001'") + .collect(ExecutionType.Native) sparkRow.size should be > 0 - sparkRow.head(0) should be (java.sql.Timestamp.valueOf("1970-01-01 00:00:00.001")) + sparkRow.head(0) should be(java.sql.Timestamp.valueOf("1970-01-01 00:00:00.001")) } - it should "supports filter TIMESTAMP BETWEEN two times" in { assumeEnvironmentIsUpAndRunning - val sparkRow = sql(s"SELECT timestamp FROM ${SharedXDContextTypesTest.dataTypesTableName} WHERE timestamp BETWEEN '1970' AND '1971'").collect(ExecutionType.Spark) - sparkRow.length should be (10) + val sparkRow = sql( + s"SELECT timestamp FROM ${SharedXDContextTypesTest.dataTypesTableName} WHERE timestamp BETWEEN '1970' AND '1971'") + .collect(ExecutionType.Spark) + sparkRow.length should be(10) } - it should "supports Native filter DATE LESS OR EQUALS THAN " in { assumeEnvironmentIsUpAndRunning - val sparkRow = sql(s"SELECT date FROM ${SharedXDContextTypesTest.dataTypesTableName} WHERE timestamp <= '1970-01-02 00:0:0.002'").collect(ExecutionType.Native) - sparkRow.length should be (2) + val sparkRow = sql( + s"SELECT date FROM ${SharedXDContextTypesTest.dataTypesTableName} WHERE timestamp <= '1970-01-02 00:0:0.002'") + .collect(ExecutionType.Native) + sparkRow.length should be(2) } it should "supports Native filter DATE GREATER OR EQUALS THAN " in { assumeEnvironmentIsUpAndRunning - val sparkRow = sql(s"SELECT date FROM ${SharedXDContextTypesTest.dataTypesTableName} WHERE date >= '1970-01-02'").collect(ExecutionType.Native) - sparkRow.length should be (9) + val sparkRow = sql( + s"SELECT date FROM ${SharedXDContextTypesTest.dataTypesTableName} WHERE date >= '1970-01-02'") + .collect(ExecutionType.Native) + sparkRow.length should be(9) } it should "supports Native filter DATE LESS THAN " in { assumeEnvironmentIsUpAndRunning - val sparkRow = sql(s"SELECT date FROM ${SharedXDContextTypesTest.dataTypesTableName} WHERE date < '1970-01-02'").collect(ExecutionType.Native) - sparkRow.length should be (1) + val sparkRow = sql( + s"SELECT date FROM ${SharedXDContextTypesTest.dataTypesTableName} WHERE date < '1970-01-02'") + .collect(ExecutionType.Native) + sparkRow.length should be(1) } it should "supports Native filter DATE GREATER THAN " in { assumeEnvironmentIsUpAndRunning - val sparkRow = sql(s"SELECT timestamp FROM ${SharedXDContextTypesTest.dataTypesTableName} WHERE timestamp > '1970-01-02'").collect(ExecutionType.Native) - sparkRow.length should be (9) + val sparkRow = sql( + s"SELECT timestamp FROM ${SharedXDContextTypesTest.dataTypesTableName} WHERE timestamp > '1970-01-02'") + .collect(ExecutionType.Native) + sparkRow.length should be(9) } - it should "execute Spark UDFs by Spark" in { 
assumeEnvironmentIsUpAndRunning - val sparkRow = sql(s"SELECT substring(name,0,2) FROM $Collection LIMIT 2").collect(ExecutionType.Default) - sparkRow.length should be (2) + val sparkRow = + sql(s"SELECT substring(name,0,2) FROM $Collection LIMIT 2").collect(ExecutionType.Default) + sparkRow.length should be(2) sparkRow(0).getString(0) should have length 2 sparkRow(1).getString(0) should have length 2 } diff --git a/mongodb/src/test/scala/com/stratio/crossdata/connector/mongodb/MongoImportTablesIT.scala b/mongodb/src/test/scala/com/stratio/crossdata/connector/mongodb/MongoImportTablesIT.scala index 262cad38e..d171b6316 100644 --- a/mongodb/src/test/scala/com/stratio/crossdata/connector/mongodb/MongoImportTablesIT.scala +++ b/mongodb/src/test/scala/com/stratio/crossdata/connector/mongodb/MongoImportTablesIT.scala @@ -24,14 +24,12 @@ import org.scalatest.junit.JUnitRunner class MongoImportTablesIT extends MongoDataTypesCollection { /**All tables imported after dropAllTables won't be temporary**/ - "MongoConnector" should "import all tables from MongoDB" in { assumeEnvironmentIsUpAndRunning xdContext.dropAllTables() - val importQuery = - s""" + val importQuery = s""" |IMPORT TABLES |USING $SourceProvider |OPTIONS ( @@ -42,7 +40,8 @@ class MongoImportTablesIT extends MongoDataTypesCollection { val importedTables = sql(importQuery).collect().map(_.getSeq(0)) - importedTables should contain allOf (Seq("highschool",Collection), Seq("highschool",DataTypesCollection)) + importedTables should contain allOf (Seq("highschool", Collection), Seq("highschool", + DataTypesCollection)) } it should "import tables from a MongoDB database" in { @@ -50,8 +49,7 @@ class MongoImportTablesIT extends MongoDataTypesCollection { xdContext.dropAllTables() - val importQuery = - s""" + val importQuery = s""" |IMPORT TABLES |USING $SourceProvider |OPTIONS ( @@ -62,7 +60,9 @@ class MongoImportTablesIT extends MongoDataTypesCollection { """.stripMargin sql(importQuery) - sql("SHOW TABLES").collect() should contain allOf (Row(s"highschool.$Collection", false),Row(s"highschool.$DataTypesCollection", false)) + sql("SHOW TABLES").collect() should contain allOf (Row(s"highschool.$Collection", false), Row( + s"highschool.$DataTypesCollection", + false)) } @@ -71,8 +71,7 @@ class MongoImportTablesIT extends MongoDataTypesCollection { xdContext.dropAllTables() - val importQuery = - s""" + val importQuery = s""" |IMPORT TABLES |USING $SourceProvider |OPTIONS ( @@ -85,7 +84,7 @@ class MongoImportTablesIT extends MongoDataTypesCollection { val importedTables = sql(importQuery).collect().map(_.getSeq(0)) - importedTables should contain only Seq("highschool",Collection) + importedTables should contain only Seq("highschool", Collection) } it should "import table from a collection with incorrect database" in { @@ -93,8 +92,7 @@ class MongoImportTablesIT extends MongoDataTypesCollection { xdContext.dropAllTables() val wrongCollection = "wrongCollection" - val importQuery = - s""" + val importQuery = s""" |IMPORT TABLES |USING $SourceProvider |OPTIONS ( @@ -113,8 +111,7 @@ class MongoImportTablesIT extends MongoDataTypesCollection { xdContext.dropAllTables() - val importQuery = - s""" + val importQuery = s""" |IMPORT TABLES |USING $SourceProvider |OPTIONS ( @@ -125,7 +122,7 @@ class MongoImportTablesIT extends MongoDataTypesCollection { """.stripMargin sql(importQuery) - sql("SHOW TABLES").collect() should contain (Row(s"highschool.$Collection", false)) + sql("SHOW TABLES").collect() should contain(Row(s"highschool.$Collection", false)) } 
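The IMPORT TABLES tests above rely on the filtering rule in DefaultSource.listTables earlier in this diff: only 'host' is mandatory, while an optional 'database' or 'collection' narrows the import and a missing option matches everything. A dependency-free sketch of that predicate, using the lowercase option keys from the OPTIONS clauses above:

// Illustration only: Option.forall on a missing option is true, which is why
// omitting 'database'/'collection' imports every user collection.
def keepCollection(options: Map[String, String], db: String, coll: String): Boolean =
  options.get("database").forall(_ == db) && options.get("collection").forall(_ == coll)

// keepCollection(Map("host" -> "localhost:27017"), "highschool", "students") == true
// keepCollection(Map("database" -> "highschool"), "highschool", "students")  == true
// keepCollection(Map("database" -> "highschool"), "other", "students")       == false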
diff --git a/mongodb/src/test/scala/com/stratio/crossdata/connector/mongodb/MongoInsertCollection.scala b/mongodb/src/test/scala/com/stratio/crossdata/connector/mongodb/MongoInsertCollection.scala index e64b6c452..e8f279829 100644 --- a/mongodb/src/test/scala/com/stratio/crossdata/connector/mongodb/MongoInsertCollection.scala +++ b/mongodb/src/test/scala/com/stratio/crossdata/connector/mongodb/MongoInsertCollection.scala @@ -33,28 +33,31 @@ trait MongoInsertCollection extends MongoWithSharedContext { for (a <- 1 to 10) { collection.insert { MongoDBObject("id" -> a, - "age" -> (10 + a), - "description" -> s"description$a", - "enrolled" -> (a % 2 == 0), - "name" -> s"Name $a", - "array_test" -> Seq(a toString, a+1 toString, a+2 toString), - "map_test" -> Map(("x",a),("y",a+2),("c",a+3)), - "array_map" -> Seq( Map("x" -> a), Map("y" -> (a+1)) ), - "map_array" -> Map("x" -> Seq(1,2), "y" -> Seq(2,3)) - ) + "age" -> (10 + a), + "description" -> s"description$a", + "enrolled" -> (a % 2 == 0), + "name" -> s"Name $a", + "array_test" -> Seq(a toString, a + 1 toString, a + 2 toString), + "map_test" -> Map(("x", a), ("y", a + 2), ("c", a + 3)), + "array_map" -> Seq(Map("x" -> a), Map("y" -> (a + 1))), + "map_array" -> Map("x" -> Seq(1, 2), "y" -> Seq(2, 3))) } - collection.update(QueryBuilder.start("id").greaterThan(4).get, MongoDBObject(("$set", MongoDBObject(("optionalField", true)))), multi = true) + collection.update(QueryBuilder.start("id").greaterThan(4).get, + MongoDBObject(("$set", MongoDBObject(("optionalField", true)))), + multi = true) } } - override def sparkRegisterTableSQL: Seq[SparkTable] = super.sparkRegisterTableSQL :+ - str2sparkTableDesc(s"""|CREATE TEMPORARY TABLE $Collection (id BIGINT, age INT, description STRING, enrolled BOOLEAN, + override def sparkRegisterTableSQL: Seq[SparkTable] = + super.sparkRegisterTableSQL :+ + str2sparkTableDesc( + s"""|CREATE TEMPORARY TABLE $Collection (id BIGINT, age INT, description STRING, enrolled BOOLEAN, |name STRING, optionalField BOOLEAN, array_test ARRAY, map_test MAP, |array_map ARRAY>, map_array MAP>)""".stripMargin) - override val Collection = "studentsInsertTest" - override def defaultOptions = super.defaultOptions + ("collection" -> s"$Collection") + override def defaultOptions = + super.defaultOptions + ("collection" -> s"$Collection") } diff --git a/mongodb/src/test/scala/com/stratio/crossdata/connector/mongodb/MongoInsertTableIT.scala b/mongodb/src/test/scala/com/stratio/crossdata/connector/mongodb/MongoInsertTableIT.scala index 05d605bd3..b24b7e672 100644 --- a/mongodb/src/test/scala/com/stratio/crossdata/connector/mongodb/MongoInsertTableIT.scala +++ b/mongodb/src/test/scala/com/stratio/crossdata/connector/mongodb/MongoInsertTableIT.scala @@ -21,7 +21,6 @@ import org.apache.spark.sql.Row import org.junit.runner.RunWith import org.scalatest.junit.JUnitRunner - @RunWith(classOf[JUnitRunner]) class MongoInsertTableIT extends MongoInsertCollection { @@ -31,25 +30,34 @@ class MongoInsertTableIT extends MongoInsertCollection { s"""|INSERT INTO $Collection VALUES (20, 25, 'proof description', true, 'Eve', false, |['proof'], (a->2), [ (x -> 1, y-> 1), (z -> 1) ], ( x->[1,2], y-> [3,4] ) )""".stripMargin - _xdContext.sql(query).collect() should be (Row(1)::Nil) + _xdContext.sql(query).collect() should be(Row(1) :: Nil) //EXPECTATION val results = sql(s"select * from $Collection where id=20").collect() results should have length 1 results should contain - Row(20, 25, "proof description", true, "Eve", - false, Seq("proof"), Map("a" -> "2"), 
List(Map("x" -> "1", "y" -> "1"), - Map("z" -> "1")), Map("x" -> List("1", "2"), "y" -> List("3", "4"))) - + Row(20, + 25, + "proof description", + true, + "Eve", + false, + Seq("proof"), + Map("a" -> "2"), + List(Map("x" -> "1", "y" -> "1"), Map("z" -> "1")), + Map("x" -> List("1", "2"), "y" -> List("3", "4"))) } it should "insert a row using INSERT INTO table(schema) VALUES in MongoDb" in { - _xdContext.sql(s"INSERT INTO $Collection(age,name, enrolled) VALUES ( 25, 'Peter', true)").collect() should be (Row(1)::Nil) + _xdContext + .sql(s"INSERT INTO $Collection(age,name, enrolled) VALUES ( 25, 'Peter', true)") + .collect() should be(Row(1) :: Nil) //EXPECTATION - val results = sql(s"select age, name, enrolled from $Collection where age=25 and name='Peter'").collect() + val results = + sql(s"select age, name, enrolled from $Collection where age=25 and name='Peter'").collect() results should have length 1 results should contain @@ -64,37 +72,62 @@ class MongoInsertTableIT extends MongoInsertCollection { |(23, 33, 'other fun description', false, 'July', false, [true,true], (z->1, a-> 2), [ (za -> 12) ], ( x->[1,2] ) ) """.stripMargin val rows: Array[Row] = _xdContext.sql(query).collect() - rows should be (Row(3)::Nil) + rows should be(Row(3) :: Nil) //EXPECTATION val results = sql(s"select * from $Collection where id=21 or id=22 or id=23").collect() results should have length 3 - results should contain allOf( - - Row(21, 25, "proof description", true, "John", false, Seq("4", "5"), - Map("x" -> "1"), Seq(Map("z" -> "1")), Map("x" -> Seq("1", "2"))), - - Row(22, 1, "other description", false, "James", true, Seq("1", "2", "3"), - Map("key" -> "value"), Seq(Map("a" -> "1")), Map("x" -> Seq("1", "a"))), - - Row(23, 33, "other fun description", false, "July", false, Seq("true", "true"), - Map("z" -> "1", "a" -> "2"), Seq(Map("za" -> "12")), Map("x" -> Seq("1", "2"))) - ) + results should contain allOf ( + Row(21, + 25, + "proof description", + true, + "John", + false, + Seq("4", "5"), + Map("x" -> "1"), + Seq(Map("z" -> "1")), + Map("x" -> Seq("1", "2"))), + Row(22, + 1, + "other description", + false, + "James", + true, + Seq("1", "2", "3"), + Map("key" -> "value"), + Seq(Map("a" -> "1")), + Map("x" -> Seq("1", "a"))), + Row(23, + 33, + "other fun description", + false, + "July", + false, + Seq("true", "true"), + Map("z" -> "1", "a" -> "2"), + Seq(Map("za" -> "12")), + Map("x" -> Seq("1", "2"))) + ) } it should "insert multiple rows using INSERT INTO table(schema) VALUES in MongoDb" in { - _xdContext.sql(s"INSERT INTO $Collection(age,name, enrolled) VALUES ( 50, 'Samantha', true),( 1, 'Charlie', false)").collect() should be (Row(2)::Nil) + _xdContext + .sql(s"INSERT INTO $Collection(age,name, enrolled) VALUES ( 50, 'Samantha', true),( 1, 'Charlie', false)") + .collect() should be(Row(2) :: Nil) //EXPECTATION - val results = sql(s"select age, enrolled, name from $Collection where (age=50 and name='Samantha') or (age=1 and name='Charlie')").collect() + val results = sql( + s"select age, enrolled, name from $Collection where (age=50 and name='Samantha') or (age=1 and name='Charlie')") + .collect() results should have length 2 - results should contain allOf( - Row(50, true, "Samantha"), - Row(1, false, "Charlie") - ) + results should contain allOf ( + Row(50, true, "Samantha"), + Row(1, false, "Charlie") + ) } it should "insert rows using INSERT INTO table(schema) VALUES with Arrays in MongoDb" in { @@ -102,16 +135,18 @@ class MongoInsertTableIT extends MongoInsertCollection { |( 55, 'Jules', true, 
[true, false]), |( 12, 'Martha', false, ['test1,t', 'test2']) """.stripMargin - _xdContext.sql(query).collect() should be (Row(2)::Nil) + _xdContext.sql(query).collect() should be(Row(2) :: Nil) //EXPECTATION - val results = sql(s"select age, name, enrolled, array_test from $Collection where (age=55 and name='Jules') or (age=12 and name='Martha')").collect() + val results = sql( + s"select age, name, enrolled, array_test from $Collection where (age=55 and name='Jules') or (age=12 and name='Martha')") + .collect() results should have length 2 - results should contain allOf( - Row(55, "Jules", true, Seq("true", "false")), - Row(12, "Martha", false, Seq("test1,t", "test2")) - ) + results should contain allOf ( + Row(55, "Jules", true, Seq("true", "false")), + Row(12, "Martha", false, Seq("test1,t", "test2")) + ) } it should "insert rows using INSERT INTO table(schema) VALUES with Map in MongoDb" in { @@ -119,16 +154,18 @@ class MongoInsertTableIT extends MongoInsertCollection { |( 12, 'Albert', true, (x->1, y->2, z->3) ), |( 20, 'Alfred', false, (xa->1, ya->2, za->3,d -> 5) ) """.stripMargin - _xdContext.sql(query).collect() should be (Row(2)::Nil) + _xdContext.sql(query).collect() should be(Row(2) :: Nil) //EXPECTATION - val results = sql(s"select age, name, enrolled, map_test from $Collection where (age=12 and name='Albert') or (age=20 and name='Alfred')").collect() + val results = sql( + s"select age, name, enrolled, map_test from $Collection where (age=12 and name='Albert') or (age=20 and name='Alfred')") + .collect() results should have length 2 - results should contain allOf( - Row(12, "Albert", true, Map("x" -> "1", "y" -> "2", "z" -> "3")), - Row(20, "Alfred", false, Map("xa" -> "1", "ya" -> "2", "za" -> "3", "d" -> "5")) - ) + results should contain allOf ( + Row(12, "Albert", true, Map("x" -> "1", "y" -> "2", "z" -> "3")), + Row(20, "Alfred", false, Map("xa" -> "1", "ya" -> "2", "za" -> "3", "d" -> "5")) + ) } it should "insert rows using INSERT INTO table(schema) VALUES with Array of Maps in MongoDb" in { @@ -136,16 +173,18 @@ class MongoInsertTableIT extends MongoInsertCollection { |( 1, 'Nikolai', true, [(x -> 3), (z -> 1)] ), |( 14, 'Ludwig', false, [(x -> 1, y-> 1), (z -> 1)] ) """.stripMargin - _xdContext.sql(query).collect() should be (Row(2)::Nil) + _xdContext.sql(query).collect() should be(Row(2) :: Nil) //EXPECTATION - val results = sql(s"select age, name, enrolled, array_map from $Collection where (age=1 and name='Nikolai') or (age=14 and name='Ludwig')").collect() + val results = sql( + s"select age, name, enrolled, array_map from $Collection where (age=1 and name='Nikolai') or (age=14 and name='Ludwig')") + .collect() results should have length 2 - results should contain allOf( - Row(1, "Nikolai", true, Seq(Map("x" -> "3"), Map("z" -> "1"))), - Row(14, "Ludwig", false, Seq(Map("x" -> "1", "y" -> "1"), Map("z" -> "1"))) - ) + results should contain allOf ( + Row(1, "Nikolai", true, Seq(Map("x" -> "3"), Map("z" -> "1"))), + Row(14, "Ludwig", false, Seq(Map("x" -> "1", "y" -> "1"), Map("z" -> "1"))) + ) } it should "insert rows using INSERT INTO table(schema) VALUES with Map of Array in MongoDb" in { @@ -153,16 +192,18 @@ class MongoInsertTableIT extends MongoInsertCollection { |( 13, 'Svletiana', true, ( x->[1], y-> [3,4] ) ), |( 17, 'Wolfang', false, ( x->[1,2], y-> [3] ) ) """.stripMargin - _xdContext.sql(query).collect() should be (Row(2)::Nil) + _xdContext.sql(query).collect() should be(Row(2) :: Nil) //EXPECTATION - val results = sql(s"select age, name, enrolled, 
map_array from $Collection where (age=13 and name='Svletiana') or (age=17 and name='Wolfang')").collect() + val results = sql( + s"select age, name, enrolled, map_array from $Collection where (age=13 and name='Svletiana') or (age=17 and name='Wolfang')") + .collect() results should have length 2 - results should contain allOf( - Row(13, "Svletiana", true, Map("x" -> Seq("1"), "y" -> Seq("3", "4"))), - Row(17, "Wolfang", false, Map("x" -> Seq("1", "2"), "y" -> Seq("3"))) - ) + results should contain allOf ( + Row(13, "Svletiana", true, Map("x" -> Seq("1"), "y" -> Seq("3", "4"))), + Row(17, "Wolfang", false, Map("x" -> Seq("1", "2"), "y" -> Seq("3"))) + ) } } diff --git a/mongodb/src/test/scala/com/stratio/crossdata/connector/mongodb/MongoQueryProcessorSpec.scala b/mongodb/src/test/scala/com/stratio/crossdata/connector/mongodb/MongoQueryProcessorSpec.scala index 93786cf96..2efc84eb9 100644 --- a/mongodb/src/test/scala/com/stratio/crossdata/connector/mongodb/MongoQueryProcessorSpec.scala +++ b/mongodb/src/test/scala/com/stratio/crossdata/connector/mongodb/MongoQueryProcessorSpec.scala @@ -45,42 +45,47 @@ class MongoQueryProcessorSpec extends BaseXDTest { .build() "A MongoQueryProcessor" should "build a query requiring some columns" in { - val (filters, requiredColumns) = MongoQueryProcessor.buildNativeQuery(Array(ColumnId, ColumnAge), Array(), config) + val (filters, requiredColumns) = + MongoQueryProcessor.buildNativeQuery(Array(ColumnId, ColumnAge), Array(), config) val columnsSet = requiredColumns.keySet filters.keySet should have size 0 columnsSet should have size 3 columnsSet should contain allOf (ColumnId, ColumnAge) - requiredColumns.get(ColumnId) should be (1) - requiredColumns.get(ColumnAge) should be (1) - requiredColumns.get(ObjectId) should be (0) + requiredColumns.get(ColumnId) should be(1) + requiredColumns.get(ColumnAge) should be(1) + requiredColumns.get(ObjectId) should be(0) } it should "build a query with two equal filters" in { - val (filters, requiredColumns) = MongoQueryProcessor.buildNativeQuery(Array(ColumnId), Array(EqualTo(ColumnAge, ValueAge), EqualTo(ColumnId, ValueId)), config) + val (filters, requiredColumns) = MongoQueryProcessor.buildNativeQuery( + Array(ColumnId), + Array(EqualTo(ColumnAge, ValueAge), EqualTo(ColumnId, ValueId)), + config) val filterSet = filters.keySet requiredColumns.keySet should contain(ColumnId) - requiredColumns.get(ColumnId) should be (1) - requiredColumns.get(ObjectId) should be (0) + requiredColumns.get(ColumnId) should be(1) + requiredColumns.get(ObjectId) should be(0) filterSet should have size 2 - filters.get(ColumnId) should be (ValueId.toString) - filters.get(ColumnAge) should be (ValueAge) + filters.get(ColumnId) should be(ValueId.toString) + filters.get(ColumnAge) should be(ValueAge) } it should "build a query with an IN clause" in { - val (filters, requiredColumns) = MongoQueryProcessor.buildNativeQuery(Array(ColumnId), Array(In(ColumnAge, Array(ValueAge, ValueAge2))), config) + val (filters, requiredColumns) = MongoQueryProcessor + .buildNativeQuery(Array(ColumnId), Array(In(ColumnAge, Array(ValueAge, ValueAge2))), config) val filterSet = filters.keySet - requiredColumns.keySet should contain (ColumnId) - requiredColumns.get(ColumnId) should be (1) - requiredColumns.get(ObjectId) should be (0) + requiredColumns.keySet should contain(ColumnId) + requiredColumns.get(ColumnId) should be(1) + requiredColumns.get(ObjectId) should be(0) filterSet should have size 1 - filters.get(ColumnAge) shouldBe a [DBObject] + 
filters.get(ColumnAge) shouldBe a[DBObject] val inListValues = filters.get(ColumnAge).asInstanceOf[DBObject].get(QueryOperators.IN) inListValues should matchPattern { case _: Array[_] => } @@ -90,99 +95,115 @@ class MongoQueryProcessorSpec extends BaseXDTest { } it should "build a query with a LT clause" in { - val (filters, requiredColumns) = MongoQueryProcessor.buildNativeQuery(Array(ColumnId), Array(LessThan(ColumnAge, ValueAge)), config) + val (filters, requiredColumns) = MongoQueryProcessor + .buildNativeQuery(Array(ColumnId), Array(LessThan(ColumnAge, ValueAge)), config) val filterSet = filters.keySet requiredColumns.keySet should contain(ColumnId) - requiredColumns.get(ColumnId) should be (1) - requiredColumns.get(ObjectId) should be (0) + requiredColumns.get(ColumnId) should be(1) + requiredColumns.get(ObjectId) should be(0) filterSet should have size 1 - filters.get(ColumnAge) shouldBe a [DBObject] + filters.get(ColumnAge) shouldBe a[DBObject] filters.get(ColumnAge).asInstanceOf[DBObject].get(QueryOperators.LT) shouldBe (ValueAge) } it should "build a query with a LTE clause" in { - val (filters, requiredColumns) = MongoQueryProcessor.buildNativeQuery(Array(ColumnId), Array(LessThanOrEqual(ColumnAge, ValueAge)), config) + val (filters, requiredColumns) = MongoQueryProcessor + .buildNativeQuery(Array(ColumnId), Array(LessThanOrEqual(ColumnAge, ValueAge)), config) val filterSet = filters.keySet requiredColumns.keySet should contain(ColumnId) - requiredColumns.get(ColumnId) should be (1) - requiredColumns.get(ObjectId) should be (0) + requiredColumns.get(ColumnId) should be(1) + requiredColumns.get(ObjectId) should be(0) filterSet should have size 1 - filters.get(ColumnAge) shouldBe a [DBObject] + filters.get(ColumnAge) shouldBe a[DBObject] filters.get(ColumnAge).asInstanceOf[DBObject].get(QueryOperators.LTE) shouldBe (ValueAge) } - it should "build a query with a GTE clause" in { - val (filters, requiredColumns) = MongoQueryProcessor.buildNativeQuery(Array(ColumnId), Array(GreaterThanOrEqual(ColumnAge, ValueAge)), config) + val (filters, requiredColumns) = MongoQueryProcessor + .buildNativeQuery(Array(ColumnId), Array(GreaterThanOrEqual(ColumnAge, ValueAge)), config) val filterSet = filters.keySet requiredColumns.keySet should contain(ColumnId) - requiredColumns.get(ColumnId) should be (1) - requiredColumns.get(ObjectId) should be (0) + requiredColumns.get(ColumnId) should be(1) + requiredColumns.get(ObjectId) should be(0) filterSet should have size 1 - filters.get(ColumnAge) shouldBe a [DBObject] + filters.get(ColumnAge) shouldBe a[DBObject] filters.get(ColumnAge).asInstanceOf[DBObject].get(QueryOperators.GTE) shouldBe (ValueAge) } it should "build a query with an IS NOT NULL clause" in { - val (filters, requiredColumns) = MongoQueryProcessor.buildNativeQuery(Array(ColumnId), Array(IsNotNull(ColumnAge)), config) + val (filters, requiredColumns) = + MongoQueryProcessor.buildNativeQuery(Array(ColumnId), Array(IsNotNull(ColumnAge)), config) val filterSet = filters.keySet requiredColumns.keySet should contain(ColumnId) - requiredColumns.get(ColumnId) should be (1) - requiredColumns.get(ObjectId) should be (0) + requiredColumns.get(ColumnId) should be(1) + requiredColumns.get(ObjectId) should be(0) filterSet should have size 1 - filters.get(ColumnAge) shouldBe a [DBObject] + filters.get(ColumnAge) shouldBe a[DBObject] filters.get(ColumnAge).asInstanceOf[DBObject].get(QueryOperators.NE) shouldBe (null) } it should "build a query with an AND(v1 > x <= v2)" in { - val (filters, 
requiredColumns) = MongoQueryProcessor.buildNativeQuery(Array(ColumnId), Array(And(GreaterThan(ColumnAge, ValueAge), LessThanOrEqual(ColumnAge, ValueAge2))), config) + val (filters, requiredColumns) = MongoQueryProcessor.buildNativeQuery( + Array(ColumnId), + Array(And(GreaterThan(ColumnAge, ValueAge), LessThanOrEqual(ColumnAge, ValueAge2))), + config) val filterSet = filters.keySet - requiredColumns.keySet should contain (ColumnId) - requiredColumns.get(ColumnId) should be (1) - requiredColumns.get(ObjectId) should be (0) + requiredColumns.keySet should contain(ColumnId) + requiredColumns.get(ColumnId) should be(1) + requiredColumns.get(ObjectId) should be(0) filterSet should have size 1 - filters.get(QueryOperators.AND) shouldBe a [util.ArrayList[_]] + filters.get(QueryOperators.AND) shouldBe a[util.ArrayList[_]] val subfilters = filters.get(QueryOperators.AND).asInstanceOf[util.ArrayList[DBObject]] //filter GT - subfilters.get(0).get(ColumnAge) shouldBe a [DBObject] - subfilters.get(0).get(ColumnAge).asInstanceOf[DBObject].get(QueryOperators.GT) shouldBe (ValueAge) + subfilters.get(0).get(ColumnAge) shouldBe a[DBObject] + subfilters + .get(0) + .get(ColumnAge) + .asInstanceOf[DBObject] + .get(QueryOperators.GT) shouldBe (ValueAge) //filter LTE - subfilters.get(1).get(ColumnAge) shouldBe a [DBObject] - subfilters.get(1).get(ColumnAge).asInstanceOf[DBObject].get(QueryOperators.LTE) shouldBe (ValueAge2) + subfilters.get(1).get(ColumnAge) shouldBe a[DBObject] + subfilters + .get(1) + .get(ColumnAge) + .asInstanceOf[DBObject] + .get(QueryOperators.LTE) shouldBe (ValueAge2) } it should "build a query with a REGEX clause " in { - val (filters, requiredColumns) = MongoQueryProcessor.buildNativeQuery(Array(ColumnId), Array(StringContains(ColumnId, ValueId.toString)), config) + val (filters, requiredColumns) = MongoQueryProcessor + .buildNativeQuery(Array(ColumnId), Array(StringContains(ColumnId, ValueId.toString)), config) val filterSet = filters.keySet - requiredColumns.keySet should contain (ColumnId) - requiredColumns.get(ColumnId) should be (1) - requiredColumns.get(ObjectId) should be (0) + requiredColumns.keySet should contain(ColumnId) + requiredColumns.get(ColumnId) should be(1) + requiredColumns.get(ObjectId) should be(0) filterSet should have size 1 - filters.get(ColumnId) shouldBe a [Pattern] + filters.get(ColumnId) shouldBe a[Pattern] - filters.get(ColumnId).asInstanceOf[Pattern].pattern should be (Pattern.compile(s".*${ValueId.toString}.*").pattern) + filters.get(ColumnId).asInstanceOf[Pattern].pattern should be( + Pattern.compile(s".*${ValueId.toString}.*").pattern) } } diff --git a/mongodb/src/test/scala/com/stratio/crossdata/connector/mongodb/MongoWithSharedContext.scala b/mongodb/src/test/scala/com/stratio/crossdata/connector/mongodb/MongoWithSharedContext.scala index f061857f3..eb2c12699 100644 --- a/mongodb/src/test/scala/com/stratio/crossdata/connector/mongodb/MongoWithSharedContext.scala +++ b/mongodb/src/test/scala/com/stratio/crossdata/connector/mongodb/MongoWithSharedContext.scala @@ -15,7 +15,6 @@ */ package com.stratio.crossdata.connector.mongodb - import com.mongodb.{BasicDBObject, QueryBuilder} import com.mongodb.casbah.MongoClient import com.mongodb.casbah.commons.MongoDBObject @@ -27,15 +26,18 @@ import org.scalatest.Suite import scala.util.Try -trait MongoWithSharedContext extends SharedXDContextWithDataTest with MongoDefaultConstants with SparkLoggerComponent { +trait MongoWithSharedContext + extends SharedXDContextWithDataTest + with MongoDefaultConstants + with 
SparkLoggerComponent { this: Suite => override type ClientParams = MongoClient override val provider: String = SourceProvider override def defaultOptions = Map( - "host" -> s"$MongoHost:$MongoPort", - "database" -> s"$Database", - "collection" -> s"$Collection" + "host" -> s"$MongoHost:$MongoPort", + "database" -> s"$Database", + "collection" -> s"$Collection" ) override protected def saveTestData: Unit = { @@ -45,13 +47,14 @@ trait MongoWithSharedContext extends SharedXDContextWithDataTest with MongoDefau for (a <- 1 to 10) { collection.insert { MongoDBObject("id" -> a, - "age" -> (10 + a), - "description" -> s"description$a", - "enrolled" -> (a % 2 == 0), - "name" -> s"Name $a" - ) + "age" -> (10 + a), + "description" -> s"description$a", + "enrolled" -> (a % 2 == 0), + "name" -> s"Name $a") } - collection.update(QueryBuilder.start("id").greaterThan(4).get, MongoDBObject(("$set", MongoDBObject(("optionalField", true)))), multi = true) + collection.update(QueryBuilder.start("id").greaterThan(4).get, + MongoDBObject(("$set", MongoDBObject(("optionalField", true)))), + multi = true) } } @@ -63,12 +66,15 @@ trait MongoWithSharedContext extends SharedXDContextWithDataTest with MongoDefau client(Database).dropDatabase() } - override protected def prepareClient: Option[ClientParams] = Try { - MongoClient(MongoHost, MongoPort) - } toOption + override protected def prepareClient: Option[ClientParams] = + Try { + MongoClient(MongoHost, MongoPort) + } toOption - override def sparkRegisterTableSQL: Seq[SparkTable] = super.sparkRegisterTableSQL :+ - str2sparkTableDesc(s"CREATE TEMPORARY TABLE $Collection (id BIGINT, age INT, description STRING, enrolled BOOLEAN, name STRING, optionalField BOOLEAN)") + override def sparkRegisterTableSQL: Seq[SparkTable] = + super.sparkRegisterTableSQL :+ + str2sparkTableDesc( + s"CREATE TEMPORARY TABLE $Collection (id BIGINT, age INT, description STRING, enrolled BOOLEAN, name STRING, optionalField BOOLEAN)") override val runningError: String = "MongoDB and Spark must be up and running" @@ -98,7 +104,7 @@ sealed trait MongoDefaultConstants { // Numeric types val float = 1.5f - val tinyint= 127 // Mongo store it like Byte + val tinyint = 127 // Mongo store it like Byte val smallint = 32767 val byte = Byte.MaxValue @@ -106,32 +112,36 @@ sealed trait MongoDefaultConstants { val date = new java.sql.Date(100000000) // Arrays - val arrayint = Seq(1,2,3) - val arraystring = Seq("a","b","c") + val arrayint = Seq(1, 2, 3) + val arraystring = Seq("a", "b", "c") val arraystruct = Seq(MongoDBObject("field1" -> 1, "field2" -> 2)) - val arraystructwithdate = Seq(MongoDBObject("field1" -> date ,"field2" -> 3)) + val arraystructwithdate = Seq(MongoDBObject("field1" -> date, "field2" -> 3)) // Map - val mapintint = new BasicDBObject("1",1).append("2",2) - val mapstringint = new BasicDBObject("1",1).append("2",2) - val mapstringstring = new BasicDBObject("1","1").append("2","2") - val mapstruct = new BasicDBObject("mapstruct", MongoDBObject("structField1" -> date ,"structField2" -> 3)) + val mapintint = new BasicDBObject("1", 1).append("2", 2) + val mapstringint = new BasicDBObject("1", 1).append("2", 2) + val mapstringstring = new BasicDBObject("1", "1").append("2", "2") + val mapstruct = + new BasicDBObject("mapstruct", MongoDBObject("structField1" -> date, "structField2" -> 3)) // Struct - val struct = MongoDBObject("field1" -> 2 ,"field2" -> 3) - val structofstruct = MongoDBObject("field1" -> date ,"field2" -> 3, "struct1" -> MongoDBObject("structField1"-> "structfield1", 
"structField2" -> 2)) + val struct = MongoDBObject("field1" -> 2, "field2" -> 3) + val structofstruct = MongoDBObject( + "field1" -> date, + "field2" -> 3, + "struct1" -> MongoDBObject("structField1" -> "structfield1", "structField2" -> 2)) // Complex compositions val arraystructarraystruct = Seq( - MongoDBObject( - "stringfield" -> "aa", - "arrayfield" -> Seq(MongoDBObject("field1" -> 1, "field2" -> 2), MongoDBObject("field1" -> -1, "field2" -> -2)) - ), - MongoDBObject( - "stringfield" -> "bb", - "arrayfield" -> Seq(MongoDBObject("field1" -> 11, "field2" -> 22) + MongoDBObject( + "stringfield" -> "aa", + "arrayfield" -> Seq(MongoDBObject("field1" -> 1, "field2" -> 2), + MongoDBObject("field1" -> -1, "field2" -> -2)) + ), + MongoDBObject( + "stringfield" -> "bb", + "arrayfield" -> Seq(MongoDBObject("field1" -> 11, "field2" -> 22)) ) - ) ) -} \ No newline at end of file +} diff --git a/mongodb/src/test/scala/com/stratio/crossdata/connector/mongodb/MongodbDataTypesIT.scala b/mongodb/src/test/scala/com/stratio/crossdata/connector/mongodb/MongodbDataTypesIT.scala index 4a4951fe7..3fd0d97d5 100644 --- a/mongodb/src/test/scala/com/stratio/crossdata/connector/mongodb/MongodbDataTypesIT.scala +++ b/mongodb/src/test/scala/com/stratio/crossdata/connector/mongodb/MongodbDataTypesIT.scala @@ -20,7 +20,7 @@ import org.junit.runner.RunWith import org.scalatest.junit.JUnitRunner @RunWith(classOf[JUnitRunner]) -class MongodbDataTypesIT extends MongoDataTypesCollection{ +class MongodbDataTypesIT extends MongoDataTypesCollection { override val emptyTypesSetError: String = "Type test entries should have been already inserted" @@ -29,21 +29,21 @@ class MongodbDataTypesIT extends MongoDataTypesCollection{ it should "be able to natively select array elements using their index" in { assumeEnvironmentIsUpAndRunning - val df = sql(s"SELECT arraystring, arraystring[2], arraystring[-1], arrayint[0] FROM typesCheckTable") + val df = + sql(s"SELECT arraystring, arraystring[2], arraystring[-1], arrayint[0] FROM typesCheckTable") val firstRow = df.collect(ExecutionType.Native).head - firstRow(0) shouldBe a[Seq[_]] // Whole `arraystring` column - firstRow(1) shouldBe a[String] // Access to a single element within a string array + firstRow(0) shouldBe a[Seq[_]] // Whole `arraystring` column + firstRow(1) shouldBe a[String] // Access to a single element within a string array Option(firstRow(2)) shouldBe None // Access to an out-of-bounds index - firstRow(3) shouldBe a[Integer] // Access to a single element within an int array + firstRow(3) shouldBe a[Integer] // Access to a single element within an int array } it should "to natively filter by array column indexed elements" in { assumeEnvironmentIsUpAndRunning - val query = - """|SELECT arraystring, arraystring[2], arraystring[-1], arrayint[0] + val query = """|SELECT arraystring, arraystring[2], arraystring[-1], arrayint[0] | FROM typesCheckTable | WHERE (arrayint[0] = 1 OR arrayint[1] = 1) AND arrayint[2] = 3 """.stripMargin.replace("\n", "") @@ -55,5 +55,4 @@ class MongodbDataTypesIT extends MongoDataTypesCollection{ } - } diff --git a/pom.xml b/pom.xml index 64946b4b4..607d73d3e 100644 --- a/pom.xml +++ b/pom.xml @@ -126,6 +126,8 @@ 2.11.8 2.10 ${scala_2.10.version} + + 0.2.11 @@ -395,6 +397,16 @@ + + com.devsmobile + mvn-scalafmt + ${scalafmt.version} + + --maxColumn 100 -i -f . + --test --maxColumn 100 -f . 
+ + + com.stratio.mojo scala-cross-build-maven-plugin diff --git a/server/src/main/scala/com/stratio/crossdata/kryo/CrossdataRegistrator.scala b/server/src/main/scala/com/stratio/crossdata/kryo/CrossdataRegistrator.scala index ac961ad5e..2d1918b99 100644 --- a/server/src/main/scala/com/stratio/crossdata/kryo/CrossdataRegistrator.scala +++ b/server/src/main/scala/com/stratio/crossdata/kryo/CrossdataRegistrator.scala @@ -21,7 +21,7 @@ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema import org.apache.spark.sql.types._ -class CrossdataRegistrator extends KryoRegistrator{ +class CrossdataRegistrator extends KryoRegistrator { override def registerClasses(kryo: Kryo): Unit = { kryo.register(Nil.getClass) kryo.register(StringType.getClass) diff --git a/server/src/main/scala/com/stratio/crossdata/server/CrossdataApplication.scala b/server/src/main/scala/com/stratio/crossdata/server/CrossdataApplication.scala index 88748bf33..3aecad4be 100644 --- a/server/src/main/scala/com/stratio/crossdata/server/CrossdataApplication.scala +++ b/server/src/main/scala/com/stratio/crossdata/server/CrossdataApplication.scala @@ -28,9 +28,9 @@ object CrossdataApplication extends App { crossdataServer.destroy() /** - * This method make a command loop. - * @return nothing. - * */ + * This method make a command loop. + * @return nothing. + * */ @tailrec private def commandLoop(): Unit = { Console.readLine() match { @@ -40,4 +40,4 @@ object CrossdataApplication extends App { commandLoop() } -} \ No newline at end of file +} diff --git a/server/src/main/scala/com/stratio/crossdata/server/CrossdataHttpServer.scala b/server/src/main/scala/com/stratio/crossdata/server/CrossdataHttpServer.scala index b6616ad85..f37b33073 100644 --- a/server/src/main/scala/com/stratio/crossdata/server/CrossdataHttpServer.scala +++ b/server/src/main/scala/com/stratio/crossdata/server/CrossdataHttpServer.scala @@ -52,35 +52,38 @@ class CrossdataHttpServer(config: Config, serverActor: ActorRef, implicit val sy entity(as[Multipart.FormData]) { formData => // collect all parts of the multipart as it arrives into a map var path = "" - val allPartsF: Future[Map[String, Any]] = formData.parts.mapAsync[(String, Any)](1) { + val allPartsF: Future[Map[String, Any]] = formData.parts + .mapAsync[(String, Any)](1) { - case part: BodyPart if part.name == "fileChunk" => - // stream into a file as the chunks of it arrives and return a future file to where it got stored - val file = new java.io.File(s"/tmp/${part.filename.getOrElse("uploadFile")}") - path = file.getAbsolutePath - logger.info("Uploading file...") - // TODO map is not used - part.entity.dataBytes.runWith(FileIO.toFile(file)).map(_ => part.name -> file) - - }.runFold(Map.empty[String, Any])((map, tuple) => map + tuple) + case part: BodyPart if part.name == "fileChunk" => + // stream into a file as the chunks of it arrives and return a future file to where it got stored + val file = new java.io.File(s"/tmp/${part.filename.getOrElse("uploadFile")}") + path = file.getAbsolutePath + logger.info("Uploading file...") + // TODO map is not used + part.entity.dataBytes.runWith(FileIO.toFile(file)).map(_ => part.name -> file) + } + .runFold(Map.empty[String, Any])((map, tuple) => map + tuple) // when processing have finished create a response for the user onSuccess(allPartsF) { allParts => - logger.info("Recieved file") complete { val hdfsConfig = XDContext.xdConfig.getConfig("hdfs") val hdfsPath = writeJarToHdfs(hdfsConfig, path) val session = 
Session(sessionUUID, null) - allParts.values.toSeq.foreach{ + allParts.values.toSeq.foreach { case file: File => file.delete logger.info("Tmp file deleted") case _ => logger.error("Problem deleting the temporary file.") } //Send a broadcast message to all servers - mediator ! Publish(AddJarTopic, CommandEnvelope(AddJARCommand(hdfsPath, hdfsConfig = Option(hdfsConfig)), session)) + mediator ! Publish(AddJarTopic, + CommandEnvelope(AddJARCommand(hdfsPath, + hdfsConfig = Option(hdfsConfig)), + session)) hdfsPath } } diff --git a/server/src/main/scala/com/stratio/crossdata/server/CrossdataServer.scala b/server/src/main/scala/com/stratio/crossdata/server/CrossdataServer.scala index 7d58acf6f..c08e811b5 100644 --- a/server/src/main/scala/com/stratio/crossdata/server/CrossdataServer.scala +++ b/server/src/main/scala/com/stratio/crossdata/server/CrossdataServer.scala @@ -36,7 +36,6 @@ import org.apache.spark.{SparkConf, SparkContext} import scala.collection.JavaConversions._ import scala.concurrent.Future - class CrossdataServer extends Daemon with ServerConfig { override lazy val logger = Logger.getLogger(classOf[CrossdataServer]) @@ -49,7 +48,8 @@ class CrossdataServer extends Daemon with ServerConfig { override def start(): Unit = { - val sparkParams = config.entrySet() + val sparkParams = config + .entrySet() .map(e => (e.getKey, e.getValue.unwrapped().toString)) .toMap .filterKeys(_.startsWith("config.spark")) @@ -57,7 +57,8 @@ class CrossdataServer extends Daemon with ServerConfig { val metricsPath = Option(sparkParams.get("spark.metrics.conf")) - val filteredSparkParams = metricsPath.fold(sparkParams)(m => checkMetricsFile(sparkParams, m.get)) + val filteredSparkParams = + metricsPath.fold(sparkParams)(m => checkMetricsFile(sparkParams, m.get)) val sparkContext = new SparkContext(new SparkConf().setAll(filteredSparkParams)) @@ -68,26 +69,25 @@ class CrossdataServer extends Daemon with ServerConfig { new BasicSessionProvider(sparkContext, config) } - val sessionProvider = sessionProviderOpt.getOrElse(throw new RuntimeException("Crossdata Server cannot be started because there is no session provider")) - + val sessionProvider = sessionProviderOpt.getOrElse( + throw new RuntimeException( + "Crossdata Server cannot be started because there is no session provider")) system = Some(ActorSystem(clusterName, config)) - system.fold(throw new RuntimeException("Actor system cannot be started")) { actorSystem => - val resizer = DefaultResizer(lowerBound = minServerActorInstances, upperBound = maxServerActorInstances) + val resizer = + DefaultResizer(lowerBound = minServerActorInstances, upperBound = maxServerActorInstances) val serverActor = actorSystem.actorOf( - RoundRobinPool(minServerActorInstances, Some(resizer)).props( - Props( - classOf[ServerActor], - Cluster(actorSystem), - sessionProvider)), - actorName) + RoundRobinPool(minServerActorInstances, Some(resizer)) + .props(Props(classOf[ServerActor], Cluster(actorSystem), sessionProvider)), + actorName) val clientMonitor = actorSystem.actorOf(KeepAliveMaster.props(serverActor), "client-monitor") ClusterReceptionistExtension(actorSystem).registerService(clientMonitor) - val resourceManagerActor = actorSystem.actorOf(ResourceManagerActor.props(Cluster(actorSystem), sessionProvider)) + val resourceManagerActor = + actorSystem.actorOf(ResourceManagerActor.props(Cluster(actorSystem), sessionProvider)) ClusterReceptionistExtension(actorSystem).registerService(serverActor) ClusterReceptionistExtension(actorSystem).registerService(resourceManagerActor) @@ 
-104,7 +104,7 @@ class CrossdataServer extends Daemon with ServerConfig { def checkMetricsFile(params: Map[String, String], metricsPath: String): Map[String, String] = { val metricsFile = new File(metricsPath) - if(!metricsFile.exists){ + if (!metricsFile.exists) { logger.warn(s"Metrics configuration file not found: ${metricsFile.getPath}") params - "spark.metrics.conf" } else { diff --git a/server/src/main/scala/com/stratio/crossdata/server/actors/JobActor.scala b/server/src/main/scala/com/stratio/crossdata/server/actors/JobActor.scala index dd1d937da..23cc4b6a2 100644 --- a/server/src/main/scala/com/stratio/crossdata/server/actors/JobActor.scala +++ b/server/src/main/scala/com/stratio/crossdata/server/actors/JobActor.scala @@ -31,16 +31,15 @@ import scala.concurrent.duration.FiniteDuration import scala.concurrent.{ExecutionContext, ExecutionException} import scala.util.{Failure, Success} - object JobActor { trait JobStatus object JobStatus { case object Idle extends JobStatus - case object Running extends JobStatus + case object Running extends JobStatus case object Completed extends JobStatus case object Cancelled extends JobStatus - case class Failed(reason: Throwable) extends JobStatus + case class Failed(reason: Throwable) extends JobStatus } trait JobEvent @@ -66,7 +65,6 @@ object JobActor { case class Task(command: SQLCommand, requester: ActorRef, timeout: Option[FiniteDuration]) - /** * The [[JobActor]] state is directly given by the running task which can be: None (Idle st) or a Running, Completed, * Cancelled or Failed task. @@ -74,30 +72,35 @@ object JobActor { */ case class State(runningTask: Option[Cancellable[SQLReply]]) { import JobStatus._ - def getStatus: JobStatus = runningTask map { task => - task.future.value map { - case Success(_) => Completed - case Failure(_: CancellationException) => Cancelled - case Failure(err) => Failed(err) - } getOrElse Running - } getOrElse Idle + def getStatus: JobStatus = + runningTask map { task => + task.future.value map { + case Success(_) => Completed + case Failure(_: CancellationException) => Cancelled + case Failure(err) => Failed(err) + } getOrElse Running + } getOrElse Idle } - def props(xdSession: XDSession, command: SQLCommand, requester: ActorRef, timeout: Option[FiniteDuration]): Props = + def props(xdSession: XDSession, + command: SQLCommand, + requester: ActorRef, + timeout: Option[FiniteDuration]): Props = Props(new JobActor(xdSession, Task(command, requester, timeout))) /** * Executor class which runs each command in a brand new thread each time */ - class ProlificExecutor extends Executor { override def execute(command: Runnable): Unit = new Thread(command) start } + class ProlificExecutor extends Executor { + override def execute(command: Runnable): Unit = new Thread(command) start + } } class JobActor( - val xdContext: XDContext, - val task: Task - ) extends Actor { - + val xdContext: XDContext, + val task: Task +) extends Actor { import JobActor.JobStatus._ import JobActor.State @@ -108,12 +111,10 @@ class JobActor( override def receive: Receive = receive(State(None)) - private def receive(st: State): Receive = { // Commands case StartJob if st.getStatus == Idle => - logger.debug(s"Starting Job under ${context.parent.path}") import context.dispatcher @@ -123,12 +124,13 @@ class JobActor( case Success(queryRes) => requester ! queryRes self ! JobCompleted - case Failure(_: CancellationException) => self ! JobCompleted // Job cancellation - case Failure(e: ExecutionException) => self ! 
JobFailed(e.getCause) // Spark exception + case Failure(_: CancellationException) => + self ! JobCompleted // Job cancellation + case Failure(e: ExecutionException) => + self ! JobFailed(e.getCause) // Spark exception case Failure(reason) => self ! JobFailed(reason) // Job failure } - val isRunning = runningTask.future.value.isEmpty timeout.filter(_ => isRunning).foreach { @@ -138,7 +140,7 @@ class JobActor( context.become(receive(st.copy(runningTask = Some(runningTask)))) case CancelJob => - st.runningTask.foreach{ tsk => + st.runningTask.foreach { tsk => logger.debug(s"Cancelling ${self.path}'s task ") tsk.cancel() } @@ -151,7 +153,8 @@ class JobActor( case event @ JobFailed(e) if sender == self => logger.debug(s"Task failed at ${self.path}") context.parent ! event - requester ! SQLReply(command.requestId, ErrorSQLResult(e.getMessage, Some(new Exception(e.getMessage)))) + requester ! SQLReply(command.requestId, + ErrorSQLResult(e.getMessage, Some(new Exception(e.getMessage)))) throw e //Let It Crash: It'll be managed by its supervisor case JobCompleted if sender == self => logger.debug(s"Completed or cancelled ${self.path} task") @@ -164,9 +167,11 @@ class JobActor( Cancellable { val df = xdContext.sql(command.sql) - val rows = if (command.flattenResults) - df.asInstanceOf[XDDataFrame].flattenedCollect() //TODO: Replace this cast by an implicit conversion - else df.collect() + val rows = + if (command.flattenResults) + df.asInstanceOf[XDDataFrame] + .flattenedCollect() //TODO: Replace this cast by an implicit conversion + else df.collect() SQLReply(command.requestId, SuccessfulSQLResult(rows, df.schema)) } diff --git a/server/src/main/scala/com/stratio/crossdata/server/actors/ResourceManagerActor.scala b/server/src/main/scala/com/stratio/crossdata/server/actors/ResourceManagerActor.scala index 6de37ab6f..94b8bf77a 100644 --- a/server/src/main/scala/com/stratio/crossdata/server/actors/ResourceManagerActor.scala +++ b/server/src/main/scala/com/stratio/crossdata/server/actors/ResourceManagerActor.scala @@ -15,7 +15,6 @@ */ package com.stratio.crossdata.server.actors - import akka.actor.{Actor, ActorRef, Props} import akka.cluster.Cluster import akka.contrib.pattern.DistributedPubSubExtension @@ -41,7 +40,9 @@ object ResourceManagerActor { } -class ResourceManagerActor(cluster: Cluster, sessionProvider: XDSessionProvider) extends Actor with ServerConfig { +class ResourceManagerActor(cluster: Cluster, sessionProvider: XDSessionProvider) + extends Actor + with ServerConfig { import ResourceManagerActor._ @@ -75,8 +76,9 @@ class ResourceManagerActor(cluster: Cluster, sessionProvider: XDSessionProvider) // Commands reception: Checks whether the command can be run at this Server passing it to the execution method if so def AddJarMessages(st: State): Receive = { - case CommandEnvelope(addJarCommand: AddJARCommand, session@Session(id, requester)) => - logger.debug(s"Add JAR received ${addJarCommand.requestId}: ${addJarCommand.path}. Actor ${self.path.toStringWithoutAddress}") + case CommandEnvelope(addJarCommand: AddJARCommand, session @ Session(id, requester)) => + logger.debug( + s"Add JAR received ${addJarCommand.requestId}: ${addJarCommand.path}. 
Actor ${self.path.toStringWithoutAddress}") logger.debug(s"Session identifier $session") //TODO Maybe include job controller if it is necessary as in sql command if (addJarCommand.path.toLowerCase.startsWith("hdfs://")) { @@ -85,18 +87,20 @@ class ResourceManagerActor(cluster: Cluster, sessionProvider: XDSessionProvider) xdSession.addJar(addJarCommand.path) case Failure(error) => logger.warn(s"Received message with an unknown sessionId $id", error) - sender ! ErrorSQLResult(s"Unable to recover the session ${session.id}. Cause: ${error.getMessage}") + sender ! ErrorSQLResult( + s"Unable to recover the session ${session.id}. Cause: ${error.getMessage}") } // TODO addJar should not affect other sessions - sender ! SQLReply(addJarCommand.requestId, SuccessfulSQLResult(Array.empty, new StructType())) + sender ! SQLReply(addJarCommand.requestId, + SuccessfulSQLResult(Array.empty, new StructType())) } else { - sender ! SQLReply(addJarCommand.requestId, ErrorSQLResult("File doesn't exist or is not a hdfs file", Some(new Exception("File doesn't exist or is not a hdfs file")))) + sender ! SQLReply( + addJarCommand.requestId, + ErrorSQLResult("File doesn't exist or is not a hdfs file", + Some(new Exception("File doesn't exist or is not a hdfs file")))) } case _ => } - - - } diff --git a/server/src/main/scala/com/stratio/crossdata/server/actors/ServerActor.scala b/server/src/main/scala/com/stratio/crossdata/server/actors/ServerActor.scala index 386e04f19..fdda7b814 100644 --- a/server/src/main/scala/com/stratio/crossdata/server/actors/ServerActor.scala +++ b/server/src/main/scala/com/stratio/crossdata/server/actors/ServerActor.scala @@ -36,7 +36,6 @@ import org.apache.spark.sql.types.StructType import scala.concurrent.duration._ import scala.util.{Failure, Success} - object ServerActor { val ManagementTopic: String = "jobsManagement" @@ -61,7 +60,8 @@ object ServerActor { // TODO it should only accept messages from known sessions class ServerActor(cluster: Cluster, sessionProvider: XDSessionProvider) - extends Actor with ServerConfig { + extends Actor + with ServerConfig { import ServerActor.ManagementMessages._ import ServerActor._ @@ -100,35 +100,46 @@ class ServerActor(cluster: Cluster, sessionProvider: XDSessionProvider) * @param cmd * @param st */ - private def executeAccepted(cmd: CommandEnvelope)(st: State): Unit = cmd match { - case CommandEnvelope(sqlCommand@SQLCommand(query, queryId, withColnames, timeout), session@Session(id, requester)) => - logger.debug(s"Query received $queryId: $query. Actor ${self.path.toStringWithoutAddress}") - logger.debug(s"Session identifier $session") - - sessionProvider.session(id) match { - case Success(xdSession) => - val jobActor = context.actorOf(JobActor.props(xdSession, sqlCommand, sender(), timeout)) - jobActor ! StartJob - context.become( - ready(st.copy(jobsById = st.jobsById + (JobId(requester, id, sqlCommand.queryId) -> jobActor))) - ) - - case Failure(error) => - logger.warn(s"Received message with an unknown sessionId $id", error) - sender ! ErrorSQLResult(s"Unable to recover the session ${session.id}. Cause: ${error.getMessage}") - } - - - case CommandEnvelope(addAppCommand@AddAppCommand(path, alias, clss, _), session@Session(id, requester)) => - if ( sessionProvider.session(id).map(_.addApp(path, clss, alias)).getOrElse(None).isDefined)// TODO improve addJar sessionManagement - sender ! SQLReply(addAppCommand.requestId, SuccessfulSQLResult(Array.empty, new StructType())) - else - sender ! 
SQLReply(addAppCommand.requestId, ErrorSQLResult("App can't be stored in the catalog")) - - case CommandEnvelope(cc@CancelQueryExecution(queryId), session@Session(id, requester)) => - st.jobsById.get(JobId(requester, id, queryId)).get ! CancelJob - } - + private def executeAccepted(cmd: CommandEnvelope)(st: State): Unit = + cmd match { + case CommandEnvelope(sqlCommand @ SQLCommand(query, queryId, withColnames, timeout), + session @ Session(id, requester)) => + logger.debug(s"Query received $queryId: $query. Actor ${self.path.toStringWithoutAddress}") + logger.debug(s"Session identifier $session") + + sessionProvider.session(id) match { + case Success(xdSession) => + val jobActor = + context.actorOf(JobActor.props(xdSession, sqlCommand, sender(), timeout)) + jobActor ! StartJob + context.become( + ready( + st.copy( + jobsById = st.jobsById + (JobId(requester, id, sqlCommand.queryId) -> jobActor))) + ) + + case Failure(error) => + logger.warn(s"Received message with an unknown sessionId $id", error) + sender ! ErrorSQLResult( + s"Unable to recover the session ${session.id}. Cause: ${error.getMessage}") + } + + case CommandEnvelope(addAppCommand @ AddAppCommand(path, alias, clss, _), + session @ Session(id, requester)) => + if (sessionProvider + .session(id) + .map(_.addApp(path, clss, alias)) + .getOrElse(None) + .isDefined) // TODO improve addJar sessionManagement + sender ! SQLReply(addAppCommand.requestId, + SuccessfulSQLResult(Array.empty, new StructType())) + else + sender ! SQLReply(addAppCommand.requestId, + ErrorSQLResult("App can't be stored in the catalog")) + + case CommandEnvelope(cc @ CancelQueryExecution(queryId), session @ Session(id, requester)) => + st.jobsById.get(JobId(requester, id, queryId)).get ! CancelJob + } // Receive functions: @@ -139,7 +150,7 @@ class ServerActor(cluster: Cluster, sessionProvider: XDSessionProvider) case DelegateCommand(cmd, broadcaster) if broadcaster != self => cmd match { // Inner pattern matching for future delegated command validations - case sc@CommandEnvelope(CancelQueryExecution(queryId), Session(sid, requester)) => + case sc @ CommandEnvelope(CancelQueryExecution(queryId), Session(sid, requester)) => st.jobsById.get(JobId(requester, sid, queryId)) foreach (_ => executeAccepted(sc)(st)) /* If it doesn't validate it won't be re-broadcast since the source server already distributed it to all servers through the topic. */ @@ -149,16 +160,16 @@ class ServerActor(cluster: Cluster, sessionProvider: XDSessionProvider) // Commands reception: Checks whether the command can be run at this Server passing it to the execution method if so def commandMessagesRec(st: State): Receive = { - case sc@CommandEnvelope(_: SQLCommand, _) => + case sc @ CommandEnvelope(_: SQLCommand, _) => executeAccepted(sc)(st) - case sc@CommandEnvelope(_: AddJARCommand, _) => + case sc @ CommandEnvelope(_: AddJARCommand, _) => executeAccepted(sc)(st) - case sc@CommandEnvelope(_: AddAppCommand, _) => + case sc @ CommandEnvelope(_: AddAppCommand, _) => executeAccepted(sc)(st) - case sc@CommandEnvelope(cc: ControlCommand, session@Session(id, requester)) => + case sc @ CommandEnvelope(cc: ControlCommand, session @ Session(id, requester)) => st.jobsById.get(JobId(requester, id, cc.requestId)) map { _ => executeAccepted(sc)(st) // Command validated to be executed by this server. } getOrElse { @@ -166,10 +177,10 @@ class ServerActor(cluster: Cluster, sessionProvider: XDSessionProvider) mediator ! 
Publish(ManagementTopic, DelegateCommand(sc, self)) } - case sc@CommandEnvelope(_: ClusterStateCommand, session) => + case sc @ CommandEnvelope(_: ClusterStateCommand, session) => sender ! ClusterStateReply(sc.cmd.requestId, cluster.state) - case sc@CommandEnvelope(_: OpenSessionCommand, session) => + case sc @ CommandEnvelope(_: OpenSessionCommand, session) => val open = sessionProvider.newSession(session.id) match { case Success(_) => logger.debug(s"new session with sessionID=${session.id} has been created") @@ -180,12 +191,12 @@ class ServerActor(cluster: Cluster, sessionProvider: XDSessionProvider) } sender ! OpenSessionReply(sc.cmd.requestId, isOpen = open) + context.actorSelection("/user/client-monitor") ! DoCheck(session.id, + expectedClientHeartbeatPeriod) - context.actorSelection("/user/client-monitor") ! DoCheck(session.id, expectedClientHeartbeatPeriod) - - case sc@CommandEnvelope(_: CloseSessionCommand, session) => + case sc @ CommandEnvelope(_: CloseSessionCommand, session) => closeSessionTerminatingJobs(session.id)(st) - /* Note that the client monitoring isn't explicitly stopped. It'll after the first miss + /* Note that the client monitoring isn't explicitly stopped. It'll after the first miss is detected, right after the driver has ended its session. */ } @@ -216,8 +227,9 @@ class ServerActor(cluster: Cluster, sessionProvider: XDSessionProvider) broadcastRequestsRec(st) orElse commandMessagesRec(st) orElse eventsRec(st) orElse - clientMonitoringEvents(st) orElse { case any => - logger.warn(s"Something is going wrong! Unknown message: $any") + clientMonitoringEvents(st) orElse { + case any => + logger.warn(s"Something is going wrong! Unknown message: $any") } private def closeSessionTerminatingJobs(sessionId: UUID)(st: State): Unit = { @@ -243,9 +255,10 @@ class ServerActor(cluster: Cluster, sessionProvider: XDSessionProvider) } //TODO: Use number of tries and timeout configuration parameters - override def supervisorStrategy: SupervisorStrategy = OneForOneStrategy(retryNoAttempts, retryCountWindow) { - case _ => Restart //Crashed job gets restarted (or not, depending on `retryNoAttempts` and `retryCountWindow`) - } + override def supervisorStrategy: SupervisorStrategy = + OneForOneStrategy(retryNoAttempts, retryCountWindow) { + case _ => + Restart //Crashed job gets restarted (or not, depending on `retryNoAttempts` and `retryCountWindow`) + } } - diff --git a/server/src/main/scala/com/stratio/crossdata/server/config/NumberActorConfig.scala b/server/src/main/scala/com/stratio/crossdata/server/config/NumberActorConfig.scala index c0ba9a4bf..f45e94c33 100644 --- a/server/src/main/scala/com/stratio/crossdata/server/config/NumberActorConfig.scala +++ b/server/src/main/scala/com/stratio/crossdata/server/config/NumberActorConfig.scala @@ -17,7 +17,6 @@ package com.stratio.crossdata.server.config import com.typesafe.config.Config - object NumberActorConfig { val DefaultServerExecutorInstances = 5 val ServerActorInstancesMin = "config.akka.number.server-actor-min" @@ -29,8 +28,12 @@ trait NumberActorConfig { import NumberActorConfig.ServerActorInstancesMin import NumberActorConfig.ServerActorInstancesMax import NumberActorConfig.DefaultServerExecutorInstances - lazy val minServerActorInstances: Int = Option(config.getString(ServerActorInstancesMin)).map(_.toInt).getOrElse(DefaultServerExecutorInstances) - lazy val maxServerActorInstances: Int = Option(config.getString(ServerActorInstancesMax)).map(_.toInt).getOrElse(minServerActorInstances*2) + lazy val minServerActorInstances: Int 
= Option(config.getString(ServerActorInstancesMin)) + .map(_.toInt) + .getOrElse(DefaultServerExecutorInstances) + lazy val maxServerActorInstances: Int = Option(config.getString(ServerActorInstancesMax)) + .map(_.toInt) + .getOrElse(minServerActorInstances * 2) def config: Config } diff --git a/server/src/main/scala/com/stratio/crossdata/server/config/ServerConfig.scala b/server/src/main/scala/com/stratio/crossdata/server/config/ServerConfig.scala index 0aa90127a..9c2b5725a 100644 --- a/server/src/main/scala/com/stratio/crossdata/server/config/ServerConfig.scala +++ b/server/src/main/scala/com/stratio/crossdata/server/config/ServerConfig.scala @@ -31,7 +31,6 @@ object ServerConfig { val ServerBasicConfig = "server-reference.conf" val ParentConfigName = "crossdata-server" - val SparkSqlConfigPrefix = CoreConfig.SparkSqlConfigPrefix val ClientExpectedHeartbeatPeriod = "config.client.ExpectedHeartbeatPeriod" @@ -68,33 +67,37 @@ trait ServerConfig extends NumberActorConfig { lazy val clusterName = config.getString(ServerConfig.ServerClusterNameKey) lazy val actorName = config.getString(ServerConfig.ServerActorNameKey) - lazy val retryNoAttempts: Int = Try(config.getInt(ServerConfig.ServerRetryMaxAttempts)).getOrElse(0) + lazy val retryNoAttempts: Int = + Try(config.getInt(ServerConfig.ServerRetryMaxAttempts)).getOrElse(0) lazy val retryCountWindow: Duration = Try( - config.getDuration(ServerConfig.ServerRetryCountWindow, TimeUnit.MILLISECONDS) - ) map (Duration(_, TimeUnit.MILLISECONDS)) getOrElse (Duration.Inf) + config.getDuration(ServerConfig.ServerRetryCountWindow, TimeUnit.MILLISECONDS) + ) map (Duration(_, TimeUnit.MILLISECONDS)) getOrElse (Duration.Inf) lazy val completedJobTTL: Duration = extractDurationField(ServerConfig.FinishedJobTTL) - lazy val expectedClientHeartbeatPeriod: FiniteDuration = - extractDurationField(ServerConfig.ClientExpectedHeartbeatPeriod) match { - case d: FiniteDuration => - Seq(11 seconds, d) max // Alarm period need to be at least twice the hear beat period (5 seconds) - case _ => 2 minute // Default value - } + lazy val expectedClientHeartbeatPeriod: FiniteDuration = extractDurationField( + ServerConfig.ClientExpectedHeartbeatPeriod) match { + case d: FiniteDuration => + Seq(11 seconds, d) max // Alarm period need to be at least twice the hear beat period (5 seconds) + case _ => 2 minute // Default value + } lazy val isHazelcastEnabled = config.getBoolean(ServerConfig.IsHazelcastProviderEnabledProperty) override val config: Config = { - var defaultConfig = ConfigFactory.load(ServerConfig.ServerBasicConfig).getConfig(ServerConfig.ParentConfigName) + var defaultConfig = + ConfigFactory.load(ServerConfig.ServerBasicConfig).getConfig(ServerConfig.ParentConfigName) val envConfigFile = Option(System.getProperties.getProperty(ServerConfig.ServerUserConfigFile)) - val configFile = envConfigFile.getOrElse(defaultConfig.getString(ServerConfig.ServerUserConfigFile)) + val configFile = + envConfigFile.getOrElse(defaultConfig.getString(ServerConfig.ServerUserConfigFile)) val configResource = defaultConfig.getString(ServerConfig.ServerUserConfigResource) if (configResource != "") { val resource = ServerConfig.getClass.getClassLoader.getResource(configResource) if (resource != null) { - val userConfig = ConfigFactory.parseResources(configResource).getConfig(ServerConfig.ParentConfigName) + val userConfig = + ConfigFactory.parseResources(configResource).getConfig(ServerConfig.ParentConfigName) defaultConfig = userConfig.withFallback(defaultConfig) logger.info("User resource 
(" + configResource + ") found in resources") } else { @@ -122,21 +125,21 @@ trait ServerConfig extends NumberActorConfig { } // System properties - val systemPropertiesConfig = - Try( - ConfigFactory.parseProperties(System.getProperties).getConfig(ServerConfig.ParentConfigName) - ).getOrElse( + val systemPropertiesConfig = Try( + ConfigFactory + .parseProperties(System.getProperties) + .getConfig(ServerConfig.ParentConfigName) + ).getOrElse( ConfigFactory.parseProperties(System.getProperties) - ) + ) defaultConfig = systemPropertiesConfig.withFallback(defaultConfig) val finalConfig = { if (defaultConfig.hasPath("akka.cluster.server-nodes")) { val serverNodes = defaultConfig.getString("akka.cluster.server-nodes") - defaultConfig.withValue( - "akka.cluster.seed-nodes", - ConfigValueFactory.fromIterable(serverNodes.split(",").toList)) + defaultConfig.withValue("akka.cluster.seed-nodes", + ConfigValueFactory.fromIterable(serverNodes.split(",").toList)) } else { defaultConfig } @@ -145,9 +148,9 @@ trait ServerConfig extends NumberActorConfig { ConfigFactory.load(finalConfig) } - private def extractDurationField(key: String): Duration = Try( - config.getDuration(key, TimeUnit.MILLISECONDS) - ) map (FiniteDuration(_, TimeUnit.MILLISECONDS)) getOrElse (Duration.Inf) + private def extractDurationField(key: String): Duration = + Try( + config.getDuration(key, TimeUnit.MILLISECONDS) + ) map (FiniteDuration(_, TimeUnit.MILLISECONDS)) getOrElse (Duration.Inf) } - diff --git a/server/src/main/scala/org/apache/spark/sql/crossdata/HazelcastSQLConf.scala b/server/src/main/scala/org/apache/spark/sql/crossdata/HazelcastSQLConf.scala index 57d9c4913..c9f96d45b 100644 --- a/server/src/main/scala/org/apache/spark/sql/crossdata/HazelcastSQLConf.scala +++ b/server/src/main/scala/org/apache/spark/sql/crossdata/HazelcastSQLConf.scala @@ -20,21 +20,24 @@ import java.util.Map.Entry import com.hazelcast.core.IMap import com.stratio.crossdata.util.CacheInvalidator -class HazelcastSQLConf(hazelcastMap: IMap[String, String], cacheInvalidator: CacheInvalidator) extends XDSQLConf { +class HazelcastSQLConf(hazelcastMap: IMap[String, String], cacheInvalidator: CacheInvalidator) + extends XDSQLConf { import HazelcastSQLConf._ private var enabledInvalidation = true - private val invalidator: () => CacheInvalidator = - () => if(enabledInvalidation) cacheInvalidator else disabledInvalidator + private val invalidator: () => CacheInvalidator = () => + if (enabledInvalidation) cacheInvalidator else disabledInvalidator def invalidateLocalCache: Unit = localMap.clear - private val localMap = java.util.Collections.synchronizedMap(new java.util.HashMap[String, String]()) + private val localMap = + java.util.Collections.synchronizedMap(new java.util.HashMap[String, String]()) override protected[spark] val settings = { - new ChainedJavaMapWithWriteInvalidation[String, String](Seq(localMap, hazelcastMap), invalidator) + new ChainedJavaMapWithWriteInvalidation[String, String](Seq(localMap, hazelcastMap), + invalidator) } override def enableCacheInvalidation(enable: Boolean): XDSQLConf = { @@ -55,26 +58,28 @@ object HazelcastSQLConf { lazy val nullval = pNull } - class ChainedJavaMapWithWriteInvalidation[K,V]( - private val delegatedMaps: Seq[java.util.Map[K,V]], - private val invalidator: () => CacheInvalidator - ) - - extends java.util.Map[K,V] { + class ChainedJavaMapWithWriteInvalidation[K, V]( + private val delegatedMaps: Seq[java.util.Map[K, V]], + private val invalidator: () => CacheInvalidator + ) extends java.util.Map[K, V] { 
require(!delegatedMaps.isEmpty) import scala.collection.JavaConversions._ - override def values(): java.util.Collection[V] = (Set.empty[V] /: delegatedMaps) { - case (values, delegatedMap) => values ++ delegatedMap.values - } + override def values(): java.util.Collection[V] = + (Set.empty[V] /: delegatedMaps) { + case (values, delegatedMap) => values ++ delegatedMap.values + } - override def get(key: scala.Any): V = delegatedMaps.view.map(_.get(key)).find(_ != null).getOrElse { - NullBuilder[V]().nullval - } + override def get(key: scala.Any): V = + delegatedMaps.view.map(_.get(key)).find(_ != null).getOrElse { + NullBuilder[V]().nullval + } + + override def entrySet(): java.util.Set[Entry[K, V]] = + delegatedMaps.last.entrySet() - override def entrySet(): java.util.Set[Entry[K, V]] = delegatedMaps.last.entrySet() /** * Note that this implementation assumes each level is contained by the next one, being the last one * continent of every single preceding one. Otherwise, it'd be better to do something like: @@ -83,20 +88,18 @@ object HazelcastSQLConf { * case (values, delegatedMap) => values ++ delegatedMap.entrySet() * } */ - - override def put(key: K, value: V): V = { invalidator().invalidateCache (Option.empty[V] /: delegatedMaps) { case (prev, delegatedMap) => val newRes = delegatedMap.put(key, value) prev orElse Option(newRes) - } getOrElse(NullBuilder[V]().nullval) + } getOrElse (NullBuilder[V]().nullval) } override def clear(): Unit = { invalidator().invalidateCache - delegatedMaps foreach(_.clear) + delegatedMaps foreach (_.clear) } override def size(): Int = delegatedMaps.maxBy(_.size).size @@ -106,9 +109,11 @@ object HazelcastSQLConf { delegatedMaps.map(_.remove(key)).head } - override def containsKey(key: scala.Any): Boolean = delegatedMaps.view exists (_.containsKey(key)) + override def containsKey(key: scala.Any): Boolean = + delegatedMaps.view exists (_.containsKey(key)) - override def containsValue(value: scala.Any): Boolean = delegatedMaps.view exists (_.containsValue(value)) + override def containsValue(value: scala.Any): Boolean = + delegatedMaps.view exists (_.containsValue(value)) override def isEmpty: Boolean = delegatedMaps forall (_.isEmpty) @@ -120,13 +125,13 @@ object HazelcastSQLConf { override def keySet(): java.util.Set[K] = delegatedMaps.last.keySet() /** - * Note that this implementation assumes each level is contained by the next one, being the last one - * continent of every single preceding one. Otherwise, it'd be better to do something like: - * - * (Set.empty[K] /: delegatedMaps) { - * case (keys, delegatedMap) => keys ++ delegatedMap.keySet() - * } - */ + * Note that this implementation assumes each level is contained by the next one, being the last one + * continent of every single preceding one. 
Otherwise, it'd be better to do something like: + * + * (Set.empty[K] /: delegatedMaps) { + * case (keys, delegatedMap) => keys ++ delegatedMap.keySet() + * } + */ } diff --git a/server/src/main/scala/org/apache/spark/sql/crossdata/catalog/temporary/HazelcastCatalog.scala b/server/src/main/scala/org/apache/spark/sql/crossdata/catalog/temporary/HazelcastCatalog.scala index f4912fbff..91665a50f 100644 --- a/server/src/main/scala/org/apache/spark/sql/crossdata/catalog/temporary/HazelcastCatalog.scala +++ b/server/src/main/scala/org/apache/spark/sql/crossdata/catalog/temporary/HazelcastCatalog.scala @@ -24,45 +24,50 @@ import org.apache.spark.sql.crossdata.catalog.XDCatalog.{CrossdataTable, ViewIde import org.apache.spark.sql.crossdata.catalog.interfaces.{XDCatalogCommon, XDTemporaryCatalog} import org.apache.spark.sql.crossdata.util.CreateRelationUtil - class HazelcastCatalog( - private val tables: IMap[TableIdentifierNormalized, CrossdataTable], - private val views: IMap[TableIdentifierNormalized, String] - )(implicit val catalystConf: CatalystConf) extends XDTemporaryCatalog with Serializable { - - - override def relation(tableIdent: TableIdentifierNormalized)(implicit sqlContext: SQLContext): Option[LogicalPlan] = - { - Option(tables.get(tableIdent)) map (CreateRelationUtil.createLogicalRelation(sqlContext, _)) - } orElse { - Option(views.get(tableIdent)) map (sqlContext.sql(_).logicalPlan) - } - + private val tables: IMap[TableIdentifierNormalized, CrossdataTable], + private val views: IMap[TableIdentifierNormalized, String] +)(implicit val catalystConf: CatalystConf) + extends XDTemporaryCatalog + with Serializable { + + override def relation(tableIdent: TableIdentifierNormalized)( + implicit sqlContext: SQLContext): Option[LogicalPlan] = { + Option(tables.get(tableIdent)) map (CreateRelationUtil.createLogicalRelation(sqlContext, _)) + } orElse { + Option(views.get(tableIdent)) map (sqlContext.sql(_).logicalPlan) + } - override def allRelations(databaseName: Option[StringNormalized]): Seq[TableIdentifierNormalized] = { + override def allRelations( + databaseName: Option[StringNormalized]): Seq[TableIdentifierNormalized] = { import scala.collection.JavaConversions._ val tableIdentSeq = (tables ++ views).keys.toSeq databaseName.map { dbName => tableIdentSeq.filter { - case TableIdentifierNormalized(_, Some(dIdent)) => dIdent == dbName.normalizedString + case TableIdentifierNormalized(_, Some(dIdent)) => + dIdent == dbName.normalizedString case other => false } }.getOrElse(tableIdentSeq) } - override def saveTable(tableIdentifier: TableIdentifierNormalized, plan: LogicalPlan, crossdataTable: Option[CrossdataTable]): Unit = { + override def saveTable(tableIdentifier: TableIdentifierNormalized, + plan: LogicalPlan, + crossdataTable: Option[CrossdataTable]): Unit = { require(crossdataTable.isDefined, requireSerializablePlanMessage("CrossdataTable")) // TODO add create/drop if not exists => fail if exists instead of override the table Option(views get tableIdentifier) foreach (_ => dropView(tableIdentifier)) - tables set(tableIdentifier, crossdataTable.get) + tables set (tableIdentifier, crossdataTable.get) } - override def saveView(viewIdentifier: ViewIdentifierNormalized, plan: LogicalPlan, query: Option[String]): Unit = { + override def saveView(viewIdentifier: ViewIdentifierNormalized, + plan: LogicalPlan, + query: Option[String]): Unit = { require(query.isDefined, requireSerializablePlanMessage("query")) Option(tables get viewIdentifier) foreach (_ => dropTable(viewIdentifier)) - views 
set(viewIdentifier, query.get) + views set (viewIdentifier, query.get) } override def dropTable(tableIdentifier: TableIdentifierNormalized): Unit = @@ -71,13 +76,13 @@ class HazelcastCatalog( override def dropView(viewIdentifier: ViewIdentifierNormalized): Unit = views remove viewIdentifier - override def dropAllViews(): Unit = views clear() + override def dropAllViews(): Unit = views clear () - override def dropAllTables(): Unit = tables clear() + override def dropAllTables(): Unit = tables clear () override def isAvailable: Boolean = true private def requireSerializablePlanMessage(parameter: String) = s"Parameter $parameter is required. A LogicalPlan cannot be stored in Hazelcast" -} \ No newline at end of file +} diff --git a/server/src/main/scala/org/apache/spark/sql/crossdata/session/HazelcastCacheInvalidator.scala b/server/src/main/scala/org/apache/spark/sql/crossdata/session/HazelcastCacheInvalidator.scala index d022bb885..032cb1f05 100644 --- a/server/src/main/scala/org/apache/spark/sql/crossdata/session/HazelcastCacheInvalidator.scala +++ b/server/src/main/scala/org/apache/spark/sql/crossdata/session/HazelcastCacheInvalidator.scala @@ -22,18 +22,17 @@ import org.apache.spark.sql.crossdata.session.XDSessionProvider.SessionID object HazelcastCacheInvalidator { - trait CacheInvalidationEvent extends Serializable case class ResourceInvalidation(sessionId: SessionID) extends CacheInvalidationEvent - case object ResourceInvalidationForAllSessions extends CacheInvalidationEvent + case object ResourceInvalidationForAllSessions extends CacheInvalidationEvent } class HazelcastCacheInvalidator( - sessionID: SessionID, - topic: ITopic[CacheInvalidationEvent] - ) extends CacheInvalidator { + sessionID: SessionID, + topic: ITopic[CacheInvalidationEvent] +) extends CacheInvalidator { private val invalidateSessionEvent = ResourceInvalidation(sessionID) diff --git a/server/src/main/scala/org/apache/spark/sql/crossdata/session/hazelcastResourceManagers.scala b/server/src/main/scala/org/apache/spark/sql/crossdata/session/hazelcastResourceManagers.scala index 8d136222f..e60d8b86e 100644 --- a/server/src/main/scala/org/apache/spark/sql/crossdata/session/hazelcastResourceManagers.scala +++ b/server/src/main/scala/org/apache/spark/sql/crossdata/session/hazelcastResourceManagers.scala @@ -32,11 +32,9 @@ import org.apache.spark.sql.crossdata.{HazelcastSQLConf, XDSQLConf} import scala.collection.mutable import scala.util.{Success, Try} - - - -trait HazelcastSessionResourceManager[V] extends MessageListener[CacheInvalidationEvent] - with SessionResourceManager[V] { +trait HazelcastSessionResourceManager[V] + extends MessageListener[CacheInvalidationEvent] + with SessionResourceManager[V] { protected val topicName: String protected val hInstance: HazelcastInstance @@ -49,7 +47,7 @@ trait HazelcastSessionResourceManager[V] extends MessageListener[CacheInvalidati override def onMessage(message: Message[CacheInvalidationEvent]): Unit = Option(message.getMessageObject).filterNot( - _ => message.getPublishingMember equals hInstance.getCluster.getLocalMember + _ => message.getPublishingMember equals hInstance.getCluster.getLocalMember ) foreach { case ResourceInvalidation(sessionId) => invalidateLocalCaches(sessionId) case ResourceInvalidationForAllSessions => invalidateAllLocalCaches @@ -68,15 +66,16 @@ trait HazelcastSessionResourceManager[V] extends MessageListener[CacheInvalidati sessionID.map(ResourceInvalidation(_)) getOrElse ResourceInvalidationForAllSessions } - protected def publishInvalidation(sessionID: 
SessionID): Unit = publishInvalidation(Some(sessionID)) + protected def publishInvalidation(sessionID: SessionID): Unit = + publishInvalidation(Some(sessionID)) } class HazelcastSessionCatalogManager( - override protected val hInstance: HazelcastInstance, - catalystConf: CatalystConf, - sessionInvalidator: Option[SessionID] => Option[CacheInvalidator] = (_ => None) - ) extends HazelcastSessionResourceManager[Seq[XDTemporaryCatalog]] { + override protected val hInstance: HazelcastInstance, + catalystConf: CatalystConf, + sessionInvalidator: Option[SessionID] => Option[CacheInvalidator] = (_ => None) +) extends HazelcastSessionResourceManager[Seq[XDTemporaryCatalog]] { import HazelcastSessionProvider._ @@ -87,14 +86,17 @@ class HazelcastSessionCatalogManager( invalidationTopic - private val sessionIDToMapCatalog: mutable.Map[SessionID, XDTemporaryCatalogWithInvalidation] = mutable.Map.empty - private val sessionIDToTableViewID: IMap[SessionID, (TableMapUUID, ViewMapUUID)] = hInstance.getMap(HazelcastCatalogMapId) + private val sessionIDToMapCatalog: mutable.Map[SessionID, XDTemporaryCatalogWithInvalidation] = + mutable.Map.empty + private val sessionIDToTableViewID: IMap[SessionID, (TableMapUUID, ViewMapUUID)] = + hInstance.getMap(HazelcastCatalogMapId) // Returns the seq of XDTempCatalog for the new session - //NOTE: THIS METHOD SHOULD NEVER BE CALLED TWICE WITH THE SAME ID. IT SHOULDN'T HAPPEN BUT SOME PROTECTION IS STILL TODO - override def newResource(key: SessionID, from: Option[Seq[XDTemporaryCatalog]] = None): Seq[XDTemporaryCatalog] = { + override def newResource( + key: SessionID, + from: Option[Seq[XDTemporaryCatalog]] = None): Seq[XDTemporaryCatalog] = { // AddMapCatalog for local/cache interaction val localCatalog = addNewMapCatalog(key) @@ -112,8 +114,11 @@ class HazelcastSessionCatalogManager( override def getResource(key: SessionID): Try[Seq[XDTemporaryCatalog]] = for { (tableUUID, viewUUID) <- checkNotNull(sessionIDToTableViewID.get(key)) - hazelcastTables <- checkNotNull(hInstance.getMap[TableIdentifierNormalized, CrossdataTable](tableUUID.toString)) - hazelcastViews <- checkNotNull(hInstance.getMap[ViewIdentifierNormalized, String](viewUUID.toString)) + hazelcastTables <- checkNotNull( + hInstance.getMap[TableIdentifierNormalized, CrossdataTable]( + tableUUID.toString)) + hazelcastViews <- checkNotNull( + hInstance.getMap[ViewIdentifierNormalized, String](viewUUID.toString)) } yield { val hazelcastCatalog = new HazelcastCatalog(hazelcastTables, hazelcastViews)(catalystConf) val mapCatalog = sessionIDToMapCatalog.getOrElse(key, addNewMapCatalog(key)) // local catalog could not exist @@ -126,8 +131,8 @@ class HazelcastSessionCatalogManager( hazelcastTables <- checkNotNull(hInstance.getMap(tableUUID.toString)) hazelcastViews <- checkNotNull(hInstance.getMap(viewUUID.toString)) } yield { - hazelcastViews clear() - hazelcastTables clear() + hazelcastViews clear () + hazelcastTables clear () sessionIDToTableViewID remove key sessionIDToMapCatalog remove key publishInvalidation(key) @@ -135,9 +140,10 @@ class HazelcastSessionCatalogManager( override def clearAllSessionsResources(): Unit = { import scala.collection.JavaConversions._ - sessionIDToTableViewID.values().foreach { case (tableUUID, viewUUID) => - hInstance.getMap(tableUUID.toString).clear() - hInstance.getMap(viewUUID.toString).clear() + sessionIDToTableViewID.values().foreach { + case (tableUUID, viewUUID) => + hInstance.getMap(tableUUID.toString).clear() + hInstance.getMap(viewUUID.toString).clear() } 
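// --- Editorial sketch, not part of the patch: a minimal stand-alone illustration of the
// --- per-session map layout used by HazelcastSessionCatalogManager in the surrounding hunk,
// --- assuming only a single local Hazelcast member. Names (SessionMapsSketch, registry,
// --- newSession, tablesOf) are invented for the example.
object SessionMapsSketch extends App {
  import java.util.UUID
  import com.hazelcast.core.{Hazelcast, IMap}

  val hz = Hazelcast.newHazelcastInstance()

  // One shared registry: sessionId -> (tableMapId, viewMapId), mirroring sessionIDToTableViewID.
  val registry: IMap[UUID, (String, String)] = hz.getMap("session-registry")

  def newSession(sessionId: UUID): Unit =
    registry.put(sessionId, (UUID.randomUUID().toString, UUID.randomUUID().toString))

  // Any peer can rebuild the session's distributed table map from the registered id.
  def tablesOf(sessionId: UUID): Option[IMap[String, String]] =
    Option(registry.get(sessionId)).map { case (tableMapId, _) => hz.getMap[String, String](tableMapId) }

  val s = UUID.randomUUID()
  newSession(s)
  tablesOf(s).foreach(_.put("db.table", "serialized table metadata would go here"))
  println(tablesOf(s).map(_.get("db.table")))
  hz.shutdown()
}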
sessionIDToMapCatalog.clear() sessionIDToTableViewID.clear() @@ -146,8 +152,8 @@ class HazelcastSessionCatalogManager( private def addNewMapCatalog(sessionID: SessionID): XDTemporaryCatalogWithInvalidation = { val localCatalog = new XDTemporaryCatalogWithInvalidation( - new HashmapCatalog(catalystConf), - resourceInvalidator(sessionID) + new HashmapCatalog(catalystConf), + resourceInvalidator(sessionID) ) sessionIDToMapCatalog.put(sessionID, localCatalog) @@ -160,16 +166,16 @@ class HazelcastSessionCatalogManager( } override def invalidateAllLocalCaches: Unit = { - sessionIDToMapCatalog clear() + sessionIDToMapCatalog clear () sessionInvalidator(None).foreach(_.invalidateCache) } } class HazelcastSessionConfigManager( - override protected val hInstance: HazelcastInstance, - sessionInvalidator: Option[SessionID] => Option[CacheInvalidator] = (_ => None) - ) extends HazelcastSessionResourceManager[SQLConf] { + override protected val hInstance: HazelcastInstance, + sessionInvalidator: Option[SessionID] => Option[CacheInvalidator] = (_ => None) +) extends HazelcastSessionResourceManager[SQLConf] { import HazelcastSessionProvider._ @@ -195,35 +201,32 @@ class HazelcastSessionConfigManager( conf } - override def getResource(key: SessionID): Try[XDSQLConf] = sessionId2Config.get(key).map(Success(_)) getOrElse { - for ( - configId <- checkNotNull(sessionId2ConfigMapId.get(key)); - configMap <- checkNotNull(hInstance.getMap[String, String](configId.toString)) - ) yield { - val conf = new HazelcastSQLConf(configMap, resourceInvalidator(key)) - sessionId2Config += key -> conf - conf + override def getResource(key: SessionID): Try[XDSQLConf] = + sessionId2Config.get(key).map(Success(_)) getOrElse { + for (configId <- checkNotNull(sessionId2ConfigMapId.get(key)); + configMap <- checkNotNull(hInstance.getMap[String, String](configId.toString))) yield { + val conf = new HazelcastSQLConf(configMap, resourceInvalidator(key)) + sessionId2Config += key -> conf + conf + } } - } override def deleteSessionResource(key: SessionID): Try[Unit] = { sessionId2Config.remove(key) - for ( - configId <- checkNotNull(sessionId2ConfigMapId.get(key)); - configMap <- checkNotNull(hInstance.getMap[String, String](configId.toString)) - ) yield { - configMap clear() - sessionId2ConfigMapId remove key - publishInvalidation(key) - } + for (configId <- checkNotNull(sessionId2ConfigMapId.get(key)); + configMap <- checkNotNull(hInstance.getMap[String, String](configId.toString))) + yield { + configMap clear () + sessionId2ConfigMapId remove key + publishInvalidation(key) + } } - override def clearAllSessionsResources(): Unit = { import scala.collection.JavaConversions._ - sessionId2Config clear() + sessionId2Config clear () sessionId2ConfigMapId.values foreach (configId => hInstance.getMap(configId.toString) clear) - sessionId2ConfigMapId clear() + sessionId2ConfigMapId clear () publishInvalidation() } @@ -237,4 +240,4 @@ class HazelcastSessionConfigManager( sessionInvalidator(Some(key)).foreach(_.invalidateCache) } -} \ No newline at end of file +} diff --git a/server/src/test/scala/com/stratio/crossdata/server/CrossdataAvro.scala b/server/src/test/scala/com/stratio/crossdata/server/CrossdataAvro.scala index 403d42f7d..7745e9987 100644 --- a/server/src/test/scala/com/stratio/crossdata/server/CrossdataAvro.scala +++ b/server/src/test/scala/com/stratio/crossdata/server/CrossdataAvro.scala @@ -30,7 +30,7 @@ class CrossdataAvro extends SharedXDContextTest with ServerConfig { "Crossdata" should "execute avro queries" in { - try{ + try 
{ sql(s"CREATE TABLE test USING com.databricks.spark.avro OPTIONS (path '${Paths.get(getClass.getResource("/test.avro").toURI()).toString}')") val result = sql("SELECT * FROM test").collect() result should have length 3 @@ -41,4 +41,4 @@ class CrossdataAvro extends SharedXDContextTest with ServerConfig { } } -} \ No newline at end of file +} diff --git a/server/src/test/scala/com/stratio/crossdata/server/CrossdataCSV.scala b/server/src/test/scala/com/stratio/crossdata/server/CrossdataCSV.scala index 7e0fe0702..8914bb8d8 100644 --- a/server/src/test/scala/com/stratio/crossdata/server/CrossdataCSV.scala +++ b/server/src/test/scala/com/stratio/crossdata/server/CrossdataCSV.scala @@ -30,7 +30,7 @@ class CrossdataCSV extends SharedXDContextTest with ServerConfig { "Crossdata" should "execute csv queries" in { - try{ + try { sql(s"CREATE TABLE cars USING com.databricks.spark.csv OPTIONS (path '${Paths.get(getClass.getResource("/cars.csv").toURI()).toString}', header 'true')") val result = sql("SELECT * FROM cars").collect() diff --git a/server/src/test/scala/org/apache/spark/sql/crossdata/HazelcastSQLConfSpec.scala b/server/src/test/scala/org/apache/spark/sql/crossdata/HazelcastSQLConfSpec.scala index ff314bef0..202976fa8 100644 --- a/server/src/test/scala/org/apache/spark/sql/crossdata/HazelcastSQLConfSpec.scala +++ b/server/src/test/scala/org/apache/spark/sql/crossdata/HazelcastSQLConfSpec.scala @@ -29,8 +29,9 @@ import org.scalatest.{BeforeAndAfterAll, Matchers, WordSpecLike} object HazelcastSQLConfSpec { - class ProbedHazelcastSessionConfigManager(hInstance: HazelcastInstance)(implicit monitorActor: ActorRef) - extends HazelcastSessionConfigManager(hInstance) { + class ProbedHazelcastSessionConfigManager(hInstance: HazelcastInstance)( + implicit monitorActor: ActorRef) + extends HazelcastSessionConfigManager(hInstance) { override def invalidateLocalCaches(key: SessionID): Unit = { super.invalidateLocalCaches(key) @@ -50,12 +51,12 @@ object HazelcastSQLConfSpec { } -class HazelcastSQLConfSpec extends TestKit(ActorSystem("HZSessionConfigTest")) - with WordSpecLike - with BeforeAndAfterAll - with ImplicitSender - with Matchers { - +class HazelcastSQLConfSpec + extends TestKit(ActorSystem("HZSessionConfigTest")) + with WordSpecLike + with BeforeAndAfterAll + with ImplicitSender + with Matchers { // Test description @@ -98,26 +99,22 @@ class HazelcastSQLConfSpec extends TestKit(ActorSystem("HZSessionConfigTest")) val sqlConf: SQLConf = configManager.getResource(sessionID).get expectNoMsg() - sqlConf.setConfString("spark.sql.parquet.filterPushdown","false") + sqlConf.setConfString("spark.sql.parquet.filterPushdown", "false") val sqlConfAtB: SQLConf = probedConfigManager.getResource(sessionID).get sqlConfAtB.getConfString("spark.sql.parquet.filterPushdown") shouldBe "false" } - } - - - } - // Test plumbing // TODO: Extract common class providing this kind of tests plumbing - private def createHazelcastInstance: HazelcastInstance = Hazelcast.newHazelcastInstance() + private def createHazelcastInstance: HazelcastInstance = + Hazelcast.newHazelcastInstance() var configManager: HazelcastSessionConfigManager = _ var probedConfigManager: ProbedHazelcastSessionConfigManager = _ @@ -133,5 +130,4 @@ class HazelcastSQLConfSpec extends TestKit(ActorSystem("HZSessionConfigTest")) override protected def afterAll(): Unit = Hazelcast shutdownAll - } diff --git a/server/src/test/scala/org/apache/spark/sql/crossdata/catalog/temporary/HazelcastCatalogSpec.scala 
b/server/src/test/scala/org/apache/spark/sql/crossdata/catalog/temporary/HazelcastCatalogSpec.scala index 46ccc501f..63628a65f 100644 --- a/server/src/test/scala/org/apache/spark/sql/crossdata/catalog/temporary/HazelcastCatalogSpec.scala +++ b/server/src/test/scala/org/apache/spark/sql/crossdata/catalog/temporary/HazelcastCatalogSpec.scala @@ -32,9 +32,10 @@ class HazelcastCatalogSpec extends { override lazy val temporaryCatalog: XDTemporaryCatalog = { val hInstance = Hazelcast.newHazelcastInstance - val tables = hInstance.getMap[TableIdentifierNormalized, CrossdataTable](UUID.randomUUID().toString) + val tables = + hInstance.getMap[TableIdentifierNormalized, CrossdataTable](UUID.randomUUID().toString) val views = hInstance.getMap[TableIdentifierNormalized, String](UUID.randomUUID().toString) new HazelcastCatalog(tables, views)(xdContext.conf) } -} \ No newline at end of file +} diff --git a/server/src/test/scala/org/apache/spark/sql/crossdata/session/HazelcastSessionCatalogManagerSpec.scala b/server/src/test/scala/org/apache/spark/sql/crossdata/session/HazelcastSessionCatalogManagerSpec.scala index ff9d7caa9..dcfbfa8c3 100644 --- a/server/src/test/scala/org/apache/spark/sql/crossdata/session/HazelcastSessionCatalogManagerSpec.scala +++ b/server/src/test/scala/org/apache/spark/sql/crossdata/session/HazelcastSessionCatalogManagerSpec.scala @@ -27,11 +27,11 @@ import org.apache.spark.sql.crossdata.catalog.interfaces.XDTemporaryCatalog import org.apache.spark.sql.crossdata.session.HazelcastSessionCatalogManagerSpec.{InvalidatedSession, ProbedHazelcastSessionCatalogManager} import org.scalatest.{BeforeAndAfterAll, WordSpecLike} - object HazelcastSessionCatalogManagerSpec { - private class ProbedHazelcastSessionCatalogManager(hInstance: HazelcastInstance)(implicit monitorActor: ActorRef) - extends HazelcastSessionCatalogManager(hInstance, EmptyConf) { + private class ProbedHazelcastSessionCatalogManager(hInstance: HazelcastInstance)( + implicit monitorActor: ActorRef) + extends HazelcastSessionCatalogManager(hInstance, EmptyConf) { override def invalidateLocalCaches(key: SessionID): Unit = { super.invalidateLocalCaches(key) @@ -50,10 +50,11 @@ object HazelcastSessionCatalogManagerSpec { } -class HazelcastSessionCatalogManagerSpec extends TestKit(ActorSystem("HZSessionCatalogTest")) - with WordSpecLike - with BeforeAndAfterAll - with ImplicitSender { +class HazelcastSessionCatalogManagerSpec + extends TestKit(ActorSystem("HZSessionCatalogTest")) + with WordSpecLike + with BeforeAndAfterAll + with ImplicitSender { // Test description @@ -86,14 +87,14 @@ class HazelcastSessionCatalogManagerSpec extends TestKit(ActorSystem("HZSessionC } - } } // Test plumbing - private def createHazelcastInstance: HazelcastInstance = Hazelcast.newHazelcastInstance(new Config()) + private def createHazelcastInstance: HazelcastInstance = + Hazelcast.newHazelcastInstance(new Config()) var catalogManager: HazelcastSessionCatalogManager = _ var probedCatalogManager: HazelcastSessionCatalogManager = _ diff --git a/server/src/test/scala/org/apache/spark/sql/crossdata/session/HazelcastSessionProviderSpec.scala b/server/src/test/scala/org/apache/spark/sql/crossdata/session/HazelcastSessionProviderSpec.scala index adf35b160..a758f7a6d 100644 --- a/server/src/test/scala/org/apache/spark/sql/crossdata/session/HazelcastSessionProviderSpec.scala +++ b/server/src/test/scala/org/apache/spark/sql/crossdata/session/HazelcastSessionProviderSpec.scala @@ -40,30 +40,33 @@ class HazelcastSessionProviderSpec extends SharedXDContextTest { 
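// --- Editorial sketch, not part of the patch: shows how a Typesafe Config string such as the
// --- SparkSqlConfigString used by HazelcastSessionProviderSpec immediately below can be
// --- flattened into the Map[String, String] shape the session provider works with
// --- (same entrySet/getAnyRef idea as typeSafeConfigToMapString later in this diff).
// --- The helper name toStringMap is invented for the example.
object ConfigFlattenSketch extends App {
  import com.typesafe.config.{Config, ConfigFactory}
  import scala.collection.JavaConverters._

  def toStringMap(config: Config): Map[String, String] =
    config.entrySet().asScala.map(e => e.getKey -> config.getAnyRef(e.getKey).toString).toMap

  val conf = ConfigFactory.parseString("config.spark.sql.inMemoryColumnarStorage.batchSize=5000")
  println(toStringMap(conf))
  // Map(config.spark.sql.inMemoryColumnarStorage.batchSize -> 5000)
}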
val SparkSqlConfigString = "config.spark.sql.inMemoryColumnarStorage.batchSize=5000" + "HazelcastSessionProvider" should "provides new sessions whose properties are initialized properly" in { -"HazelcastSessionProvider" should "provides new sessions whose properties are initialized properly" in { - - val hazelcastSessionProvider = new HazelcastSessionProvider(xdContext.sc, ConfigFactory.parseString(SparkSqlConfigString)) + val hazelcastSessionProvider = + new HazelcastSessionProvider(xdContext.sc, ConfigFactory.parseString(SparkSqlConfigString)) val session = createNewSession(hazelcastSessionProvider) - session.conf.settings should contain(Entry("spark.sql.inMemoryColumnarStorage.batchSize", "5000")) + session.conf.settings should contain( + Entry("spark.sql.inMemoryColumnarStorage.batchSize", "5000")) val tempCatalogs = tempCatalogsFromSession(session) tempCatalogs should have length 2 tempCatalogs.head shouldBe a[XDTemporaryCatalogWithInvalidation] - tempCatalogs.head.asInstanceOf[XDTemporaryCatalogWithInvalidation].underlying shouldBe a[HashmapCatalog] + tempCatalogs.head + .asInstanceOf[XDTemporaryCatalogWithInvalidation] + .underlying shouldBe a[HashmapCatalog] tempCatalogs(1) shouldBe a[HazelcastCatalog] hazelcastSessionProvider.close() } - it should "provides a common persistent catalog and isolated catalogs" in { // TODO we should share the persistentCatalog - val hazelcastSessionProvider = new HazelcastSessionProvider(xdContext.sc, ConfigFactory.empty()) + val hazelcastSessionProvider = + new HazelcastSessionProvider(xdContext.sc, ConfigFactory.empty()) val (sessionTempCatalogs, sessionPersCatalogs) = { val session = createNewSession(hazelcastSessionProvider) @@ -87,7 +90,8 @@ class HazelcastSessionProviderSpec extends SharedXDContextTest { it should "allow to lookup an existing session" in { - val hazelcastSessionProvider = new HazelcastSessionProvider(xdContext.sc, ConfigFactory.empty()) + val hazelcastSessionProvider = + new HazelcastSessionProvider(xdContext.sc, ConfigFactory.empty()) val sessionId = UUID.randomUUID() val tableIdent = TableIdentifier("tab") @@ -95,7 +99,10 @@ class HazelcastSessionProviderSpec extends SharedXDContextTest { import org.apache.spark.sql.crossdata.catalog.interfaces.XDCatalogCommon._ - session.catalog.registerTable(tableIdent, LocalRelation(), Some(CrossdataTable(tableIdent.normalize(xdContext.catalog.conf), None, "fakedatasource"))) + session.catalog.registerTable( + tableIdent, + LocalRelation(), + Some(CrossdataTable(tableIdent.normalize(xdContext.catalog.conf), None, "fakedatasource"))) hazelcastSessionProvider.session(sessionId) should matchPattern { case Success(s: XDSession) if Try(s.catalog.lookupRelation(tableIdent)).isSuccess => @@ -104,20 +111,19 @@ class HazelcastSessionProviderSpec extends SharedXDContextTest { hazelcastSessionProvider.close() } - it should "fail when trying to lookup a non-existing session" in { - val hazelcastSessionProvider = new HazelcastSessionProvider(xdContext.sc, ConfigFactory.empty()) + val hazelcastSessionProvider = + new HazelcastSessionProvider(xdContext.sc, ConfigFactory.empty()) hazelcastSessionProvider.session(UUID.randomUUID()).isFailure shouldBe true hazelcastSessionProvider.close() } - - it should "remove the session metadata when closing an open session" in { - val hazelcastSessionProvider = new HazelcastSessionProvider(xdContext.sc, ConfigFactory.empty()) + val hazelcastSessionProvider = + new HazelcastSessionProvider(xdContext.sc, ConfigFactory.empty()) val sessionId = UUID.randomUUID() val 
session = hazelcastSessionProvider.newSession(sessionId) @@ -131,31 +137,32 @@ class HazelcastSessionProviderSpec extends SharedXDContextTest { it should "fail when trying to close a non-existing session" in { - val hazelcastSessionProvider = new HazelcastSessionProvider(xdContext.sc, ConfigFactory.empty()) + val hazelcastSessionProvider = + new HazelcastSessionProvider(xdContext.sc, ConfigFactory.empty()) val session = hazelcastSessionProvider.newSession(UUID.randomUUID()) hazelcastSessionProvider.closeSession(UUID.randomUUID()).isFailure shouldBe true - + hazelcastSessionProvider.close() } - - it should "close the hazelcast instance when closing" in { - val hazelcastSessionProvider = new HazelcastSessionProvider(xdContext.sc, ConfigFactory.empty()) + val hazelcastSessionProvider = + new HazelcastSessionProvider(xdContext.sc, ConfigFactory.empty()) val sessionID = UUID.randomUUID() hazelcastSessionProvider.newSession(sessionID) hazelcastSessionProvider.close() - a [RuntimeException] shouldBe thrownBy (hazelcastSessionProvider.session(sessionID)) - + a[RuntimeException] shouldBe thrownBy(hazelcastSessionProvider.session(sessionID)) + } it should "provide the same cached instance when it hasn't been invalidated" in { - val hazelcastSessionProvider = new HazelcastSessionProvider(xdContext.sc, ConfigFactory.empty()) + val hazelcastSessionProvider = + new HazelcastSessionProvider(xdContext.sc, ConfigFactory.empty()) val sessionID = UUID.randomUUID() val refA = hazelcastSessionProvider.newSession(sessionID).get @@ -167,21 +174,23 @@ class HazelcastSessionProviderSpec extends SharedXDContextTest { } testInvalidation("provide a new session instance after its invalidation by a SQLConf change")( - // This changes a setting value using a second hazelcast peer - _.setConf("spark.sql.parquet.filterPushdown", "false") + // This changes a setting value using a second hazelcast peer + _.setConf("spark.sql.parquet.filterPushdown", "false") ) testInvalidation("provide a new session instance after its invalidation by a Catalog change")( - // This changes a setting value using a second hazelcast peer - _.catalog.unregisterTable(TableIdentifier("DUMMY_TABLE")) + // This changes a setting value using a second hazelcast peer + _.catalog.unregisterTable(TableIdentifier("DUMMY_TABLE")) ) def testInvalidation(testDescription: String)(invalidationAction: XDSession => Unit) = it should testDescription in { // Two hazelcast peers shall be created - val hazelcastSessionProviderA = new HazelcastSessionProvider(xdContext.sc, ConfigFactory.empty()) - val hazelcastSessionProviderB = new HazelcastSessionProvider(xdContext.sc, ConfigFactory.empty()) + val hazelcastSessionProviderA = + new HazelcastSessionProvider(xdContext.sc, ConfigFactory.empty()) + val hazelcastSessionProviderB = + new HazelcastSessionProvider(xdContext.sc, ConfigFactory.empty()) val sessionID = UUID.randomUUID() @@ -214,10 +223,11 @@ class HazelcastSessionProviderSpec extends SharedXDContextTest { session.catalog.asInstanceOf[CatalogChain].persistentCatalogs } - private def createNewSession(hazelcastSessionProvider: HazelcastSessionProvider, uuid: UUID = UUID.randomUUID()): XDSession = { + private def createNewSession(hazelcastSessionProvider: HazelcastSessionProvider, + uuid: UUID = UUID.randomUUID()): XDSession = { val optSession = hazelcastSessionProvider.newSession(uuid).toOption optSession shouldBe defined optSession.get } -} \ No newline at end of file +} diff --git 
a/streaming/src/main/scala/com/stratio/crossdata/streaming/CrossdataStreaming.scala b/streaming/src/main/scala/com/stratio/crossdata/streaming/CrossdataStreaming.scala index b9a844dfc..1c8b8c140 100644 --- a/streaming/src/main/scala/com/stratio/crossdata/streaming/CrossdataStreaming.scala +++ b/streaming/src/main/scala/com/stratio/crossdata/streaming/CrossdataStreaming.scala @@ -28,7 +28,7 @@ import scala.util.Try class CrossdataStreaming(ephemeralTableName: String, streamingCatalogConfig: Map[String, String], crossdataCatalogConfiguration: Map[String, String]) - extends EphemeralTableMapDAO { + extends EphemeralTableMapDAO { private val zookeeperCatalogConfig = streamingCatalogConfig.collect { case (key, value) if key.startsWith(ZooKeeperStreamingCatalogPath) => @@ -39,26 +39,25 @@ class CrossdataStreaming(ephemeralTableName: String, def init(): Try[Any] = { Try { - val ephemeralTable = dao.get(ephemeralTableName) + val ephemeralTable = dao + .get(ephemeralTableName) .getOrElse(throw new IllegalStateException("Ephemeral table not found")) val sparkConfig = configToSparkConf(ephemeralTable) - val ssc = StreamingContext.getOrCreate(ephemeralTable.options.checkpointDirectory, - () => { - CrossdataStreamingHelper.createContext(ephemeralTable, - sparkConfig, - zookeeperCatalogConfig, - crossdataCatalogConfiguration - ) - }) + val ssc = StreamingContext.getOrCreate(ephemeralTable.options.checkpointDirectory, () => { + CrossdataStreamingHelper.createContext(ephemeralTable, + sparkConfig, + zookeeperCatalogConfig, + crossdataCatalogConfiguration) + }) CrossdataStatusHelper.initStatusActor(ssc, zookeeperCatalogConfig, ephemeralTable.name) logger.info(s"Started Ephemeral Table: $ephemeralTableName") CrossdataStatusHelper.setEphemeralStatus( - EphemeralExecutionStatus.Started, - zookeeperCatalogConfig, - ephemeralTableName + EphemeralExecutionStatus.Started, + zookeeperCatalogConfig, + ephemeralTableName ) ssc.start() @@ -70,9 +69,9 @@ class CrossdataStreaming(ephemeralTableName: String, new SparkConf().setAll(setPrefixSpark(ephemeralTable.options.sparkOptions)) private[streaming] def setPrefixSpark(sparkConfig: Map[String, String]): Map[String, String] = - sparkConfig.map { case entry@(key, value) => - if (key.startsWith(SparkPrefixName)) entry - else (s"$SparkPrefixName.$key", value) + sparkConfig.map { + case entry @ (key, value) => + if (key.startsWith(SparkPrefixName)) entry + else (s"$SparkPrefixName.$key", value) } } - diff --git a/streaming/src/main/scala/com/stratio/crossdata/streaming/CrossdataStreamingApplication.scala b/streaming/src/main/scala/com/stratio/crossdata/streaming/CrossdataStreamingApplication.scala index ecdeee3f5..c241c4138 100644 --- a/streaming/src/main/scala/com/stratio/crossdata/streaming/CrossdataStreamingApplication.scala +++ b/streaming/src/main/scala/com/stratio/crossdata/streaming/CrossdataStreamingApplication.scala @@ -36,15 +36,17 @@ object CrossdataStreamingApplication extends SparkLoggerComponent with Crossdata Try { val ephemeralTableName = args(EphemeralTableNameIndex) - val zookConfigurationRendered = new String(BaseEncoding.base64().decode(args(StreamingCatalogConfigurationIndex))) + val zookConfigurationRendered = + new String(BaseEncoding.base64().decode(args(StreamingCatalogConfigurationIndex))) val zookeeperConf = parseConf(zookConfigurationRendered).getOrElse { - val message = s"Error parsing zookeeper argument -> $zookConfigurationRendered" - logger.error(message) - throw new IllegalArgumentException(message) - } + val message = s"Error parsing 
zookeeper argument -> $zookConfigurationRendered" + logger.error(message) + throw new IllegalArgumentException(message) + } - val xdCatalogConfRendered = new String(BaseEncoding.base64().decode(args(CrossdataCatalogIndex))) + val xdCatalogConfRendered = + new String(BaseEncoding.base64().decode(args(CrossdataCatalogIndex))) val xdCatalogConf = parseConf(xdCatalogConfRendered).getOrElse { val message = s"Error parsing XDCatalog argument -> $xdCatalogConfRendered" @@ -52,14 +54,18 @@ object CrossdataStreamingApplication extends SparkLoggerComponent with Crossdata throw new IllegalArgumentException(message) } - val crossdataStreaming = new CrossdataStreaming(ephemeralTableName, typeSafeConfigToMapString(zookeeperConf), typeSafeConfigToMapString(xdCatalogConf)) + val crossdataStreaming = new CrossdataStreaming(ephemeralTableName, + typeSafeConfigToMapString(zookeeperConf), + typeSafeConfigToMapString(xdCatalogConf)) crossdataStreaming.init() match { case Success(_) => logger.info(s"Ephemeral Table Finished correctly: $ephemeralTableName") CrossdataStatusHelper.close() case Failure(exception) => logger.error(exception.getMessage, exception) - CrossdataStatusHelper.setEphemeralStatus(EphemeralExecutionStatus.Error, typeSafeConfigToMapString(zookeeperConf), ephemeralTableName) + CrossdataStatusHelper.setEphemeralStatus(EphemeralExecutionStatus.Error, + typeSafeConfigToMapString(zookeeperConf), + ephemeralTableName) CrossdataStatusHelper.close() sys.exit(-1) } @@ -81,13 +87,15 @@ object CrossdataStreamingApplication extends SparkLoggerComponent with Crossdata ConfigFactory.parseString(renderedConfig) } - private def typeSafeConfigToMapString(config: Config, path: Option[String]= None): Map[String, String] = { + private def typeSafeConfigToMapString(config: Config, + path: Option[String] = None): Map[String, String] = { import scala.collection.JavaConversions._ val conf = path.map(config.getConfig).getOrElse(config) - conf.entrySet().toSeq.map( e => - (s"${path.fold("")(_+".")+ e.getKey}", conf.getAnyRef(e.getKey).toString) - ).toMap + conf + .entrySet() + .toSeq + .map(e => (s"${path.fold("")(_ + ".") + e.getKey}", conf.getAnyRef(e.getKey).toString)) + .toMap } - } diff --git a/streaming/src/main/scala/com/stratio/crossdata/streaming/actors/EphemeralQueryActor.scala b/streaming/src/main/scala/com/stratio/crossdata/streaming/actors/EphemeralQueryActor.scala index 8cb3671eb..d55ac53a9 100644 --- a/streaming/src/main/scala/com/stratio/crossdata/streaming/actors/EphemeralQueryActor.scala +++ b/streaming/src/main/scala/com/stratio/crossdata/streaming/actors/EphemeralQueryActor.scala @@ -23,8 +23,9 @@ import org.apache.spark.sql.crossdata.models.EphemeralQueryModel import scala.util.Try -class EphemeralQueryActor(zookeeperConfiguration: Map[String, String]) extends Actor -with EphemeralQueriesMapDAO { +class EphemeralQueryActor(zookeeperConfiguration: Map[String, String]) + extends Actor + with EphemeralQueriesMapDAO { val memoryMap = Map(ZookeeperPrefixName -> zookeeperConfiguration) var streamingQueries: List[EphemeralQueryModel] = dao.getAll() @@ -57,7 +58,7 @@ object EphemeralQueryActor { case object AddListener - case class ListenerResponse(added : Boolean) + case class ListenerResponse(added: Boolean) case class EphemeralQueriesResponse(streamingQueries: Seq[EphemeralQueryModel]) diff --git a/streaming/src/main/scala/com/stratio/crossdata/streaming/actors/EphemeralStatusActor.scala b/streaming/src/main/scala/com/stratio/crossdata/streaming/actors/EphemeralStatusActor.scala index 
7c8774a3a..75b4da582 100644 --- a/streaming/src/main/scala/com/stratio/crossdata/streaming/actors/EphemeralStatusActor.scala +++ b/streaming/src/main/scala/com/stratio/crossdata/streaming/actors/EphemeralStatusActor.scala @@ -28,8 +28,9 @@ import scala.concurrent.duration._ class EphemeralStatusActor(streamingContext: StreamingContext, zookeeperConfiguration: Map[String, String], - ephemeralTableName: String) extends Actor -with EphemeralTableStatusMapDAO { + ephemeralTableName: String) + extends Actor + with EphemeralTableStatusMapDAO { val ephemeralTMDao = new EphemeralTableMapDAO(Map(ZookeeperPrefixName -> zookeeperConfiguration)) val memoryMap = Map(ZookeeperPrefixName -> zookeeperConfiguration) @@ -60,9 +61,13 @@ with EphemeralTableStatusMapDAO { case Some(tableModel) => tableModel.options.atomicWindow * 1000 case None => EphemeralOptionsModel.DefaultAtomicWindow * 1000 } - cancellableCheckStatus = Option(cancellableCheckStatus.fold( - context.system.scheduler.schedule(delayMs milliseconds, delayMs milliseconds, self, CheckStatus)(context.dispatcher) - )(identity)) + cancellableCheckStatus = Option( + cancellableCheckStatus.fold( + context.system.scheduler.schedule(delayMs milliseconds, + delayMs milliseconds, + self, + CheckStatus)(context.dispatcher) + )(identity)) } override def postStop(): Unit = { @@ -75,7 +80,7 @@ with EphemeralTableStatusMapDAO { } private[streaming] def doCheckStatus(): Unit = - // TODO check if the status can be read from ephemeralStatus insteadof getRepository...; the listener should work + // TODO check if the status can be read from ephemeralStatus insteadof getRepository...; the listener should work getRepositoryStatusTable.foreach { statusModel => if (statusModel.status == EphemeralExecutionStatus.Stopping) { // TODO add an actor containing status and query actor in order to exit gracefully @@ -88,15 +93,22 @@ with EphemeralTableStatusMapDAO { private[streaming] def doSetStatus(newStatus: EphemeralExecutionStatus.Value): Unit = { - val startTime = if (newStatus == EphemeralExecutionStatus.Started) Option(DateTime.now.getMillis) else None - val stopTime = if (newStatus == EphemeralExecutionStatus.Stopped) Option(DateTime.now.getMillis) else None + val startTime = + if (newStatus == EphemeralExecutionStatus.Started) + Option(DateTime.now.getMillis) + else None + val stopTime = + if (newStatus == EphemeralExecutionStatus.Stopped) + Option(DateTime.now.getMillis) + else None val resultStatus = ephemeralStatus.fold( - dao.create(ephemeralTableName, EphemeralStatusModel(ephemeralTableName, newStatus, startTime, stopTime)) + dao.create(ephemeralTableName, + EphemeralStatusModel(ephemeralTableName, newStatus, startTime, stopTime)) ) { ephStatus => val newStatusModel = ephStatus.copy( - status = newStatus, - stoppedTime = stopTime, - startedTime = startTime.orElse(ephStatus.startedTime) + status = newStatus, + stoppedTime = stopTime, + startedTime = startTime.orElse(ephStatus.startedTime) ) dao.upsert(ephemeralTableName, newStatusModel) } @@ -108,18 +120,20 @@ with EphemeralTableStatusMapDAO { private[streaming] def doAddListener(): Unit = { repository.addListener[EphemeralStatusModel]( - dao.entity, - ephemeralTableName, - (newEphemeralStatus: EphemeralStatusModel, _) => ephemeralStatus = Option(newEphemeralStatus) + dao.entity, + ephemeralTableName, + (newEphemeralStatus: EphemeralStatusModel, + _) => ephemeralStatus = Option(newEphemeralStatus) ) sender ! 
ListenerResponse(true) } - private[streaming] def getRepositoryStatusTable: Option[EphemeralStatusModel] = dao.get(ephemeralTableName) + private[streaming] def getRepositoryStatusTable: Option[EphemeralStatusModel] = + dao.get(ephemeralTableName) - private[streaming] def getStatusFromTable(ephemeralTable: Option[EphemeralStatusModel]) - : EphemeralExecutionStatus.Value = { + private[streaming] def getStatusFromTable( + ephemeralTable: Option[EphemeralStatusModel]): EphemeralExecutionStatus.Value = { ephemeralTable.fold(EphemeralExecutionStatus.NotStarted) { tableStatus => tableStatus.status } @@ -139,7 +153,8 @@ with EphemeralTableStatusMapDAO { } } - private[streaming] def doGetStreamingStatus(): Unit = sender ! StreamingStatusResponse(streamingContext.getState()) + private[streaming] def doGetStreamingStatus(): Unit = + sender ! StreamingStatusResponse(streamingContext.getState()) } object EphemeralStatusActor { @@ -160,4 +175,4 @@ object EphemeralStatusActor { case class StreamingStatusResponse(status: StreamingContextState) -} \ No newline at end of file +} diff --git a/streaming/src/main/scala/com/stratio/crossdata/streaming/constants/ApplicationConstants.scala b/streaming/src/main/scala/com/stratio/crossdata/streaming/constants/ApplicationConstants.scala index bf2682ad2..55c680a07 100644 --- a/streaming/src/main/scala/com/stratio/crossdata/streaming/constants/ApplicationConstants.scala +++ b/streaming/src/main/scala/com/stratio/crossdata/streaming/constants/ApplicationConstants.scala @@ -27,11 +27,11 @@ object ApplicationConstants { val StopGracefully = true val DefaultZookeeperConfiguration = Map( - "connectionString" -> "127.0.0.1:2181", - "connectionTimeout" -> 1500, - "sessionTimeout" -> 60000, - "retryAttempts" -> 6, - "retryInterval" -> 10000 + "connectionString" -> "127.0.0.1:2181", + "connectionTimeout" -> 1500, + "sessionTimeout" -> 60000, + "retryAttempts" -> 6, + "retryInterval" -> 10000 ) } diff --git a/streaming/src/main/scala/com/stratio/crossdata/streaming/constants/KafkaConstants.scala b/streaming/src/main/scala/com/stratio/crossdata/streaming/constants/KafkaConstants.scala index 5fc4c369a..d9e5670e9 100644 --- a/streaming/src/main/scala/com/stratio/crossdata/streaming/constants/KafkaConstants.scala +++ b/streaming/src/main/scala/com/stratio/crossdata/streaming/constants/KafkaConstants.scala @@ -18,8 +18,8 @@ package com.stratio.crossdata.streaming.constants object KafkaConstants { /** - * Default parameters - */ + * Default parameters + */ val DefaultPartition = 1 val DefaultConsumerPort = "2181" val DefaultProducerPort = "9092" @@ -27,14 +27,14 @@ object KafkaConstants { val DefaultSerializer = "kafka.serializer.StringEncoder" /** - * Kafka Spark consumer keys - */ + * Kafka Spark consumer keys + */ val ZookeeperConnectionKey = "zookeeper.connect" val GroupIdKey = "group.id" /** - * Kafka native producer keys - */ + * Kafka native producer keys + */ val SerializerKey = "serializer.class" val BrokerListKey = "metadata.broker.list" val PartitionKey = "partition" @@ -46,11 +46,10 @@ object KafkaConstants { val MaxRetriesKey = "maxRetries" val ClientIdKey = "clientId" - val producerProperties = Map( - RequiredAckKey -> "request.required.acks", - CompressionCodecKey -> "compression.codec", - ProducerTypeKey -> "producer.type", - BatchSizeKey -> "batch.num.messages", - MaxRetriesKey -> "message.send.max.retries", - ClientIdKey -> "client.id") + val producerProperties = Map(RequiredAckKey -> "request.required.acks", + CompressionCodecKey -> "compression.codec", + 
ProducerTypeKey -> "producer.type", + BatchSizeKey -> "batch.num.messages", + MaxRetriesKey -> "message.send.max.retries", + ClientIdKey -> "client.id") } diff --git a/streaming/src/main/scala/com/stratio/crossdata/streaming/helpers/CrossdataStatusHelper.scala b/streaming/src/main/scala/com/stratio/crossdata/streaming/helpers/CrossdataStatusHelper.scala index 7217dc945..08c439057 100644 --- a/streaming/src/main/scala/com/stratio/crossdata/streaming/helpers/CrossdataStatusHelper.scala +++ b/streaming/src/main/scala/com/stratio/crossdata/streaming/helpers/CrossdataStatusHelper.scala @@ -42,20 +42,23 @@ object CrossdataStatusHelper extends SparkLoggerComponent { def initStatusActor(streamingContext: StreamingContext, zookeeperConfiguration: Map[String, String], ephemeralTableName: String): Option[ActorRef] = { - if (ephemeralStatusActor.isEmpty) { - Try( + if (ephemeralStatusActor.isEmpty) { + Try( actorSystem.actorOf( - Props(new EphemeralStatusActor(streamingContext, zookeeperConfiguration, ephemeralTableName)), - EphemeralStatusActorName + Props( + new EphemeralStatusActor(streamingContext, + zookeeperConfiguration, + ephemeralTableName)), + EphemeralStatusActorName ) - ) match { - case Success(actorRef) => - ephemeralStatusActor = Option(actorRef) - actorRef ! AddListener - case Failure(e) => - logger.error("Error creating streaming status actor with listener: ", e) - } + ) match { + case Success(actorRef) => + ephemeralStatusActor = Option(actorRef) + actorRef ! AddListener + case Failure(e) => + logger.error("Error creating streaming status actor with listener: ", e) } + } ephemeralStatusActor } @@ -68,7 +71,8 @@ object CrossdataStatusHelper extends SparkLoggerComponent { val futureResult = queryActorRef ? GetQueries Await.result(futureResult, timeout.duration) match { case EphemeralQueriesResponse(queries) => - queries.filter(streamingQueryModel => streamingQueryModel.ephemeralTableName == ephemeralTableName) + queries.filter(streamingQueryModel => + streamingQueryModel.ephemeralTableName == ephemeralTableName) case _ => Seq.empty } @@ -77,28 +81,29 @@ object CrossdataStatusHelper extends SparkLoggerComponent { def setEphemeralStatus(status: EphemeralExecutionStatus.Value, zookeeperConfiguration: Map[String, String], - ephemeralTableName: String): Unit = { + ephemeralTableName: String): Unit = { ephemeralStatusActor.foreach { statusActorRef => statusActorRef ! SetStatus(status) } } - def close(): Unit = { ephemeralQueryActor.foreach(_ ! PoisonPill) ephemeralStatusActor.foreach(_ ! 
PoisonPill) - if(!actorSystem.isTerminated) { + if (!actorSystem.isTerminated) { actorSystem.shutdown() actorSystem.awaitTermination(5 seconds) } } - private[streaming] def createEphemeralQueryActor(zookeeperConfiguration: Map[String, String]): Option[ActorRef] = { + private[streaming] def createEphemeralQueryActor( + zookeeperConfiguration: Map[String, String]): Option[ActorRef] = { synchronized { if (ephemeralQueryActor.isEmpty) { - Try ( - actorSystem.actorOf(Props(new EphemeralQueryActor(zookeeperConfiguration)),EphemeralQueryActorName) + Try( + actorSystem.actorOf(Props(new EphemeralQueryActor(zookeeperConfiguration)), + EphemeralQueryActorName) ) match { case Success(actorRef) => ephemeralQueryActor = Option(actorRef) diff --git a/streaming/src/main/scala/com/stratio/crossdata/streaming/helpers/CrossdataStreamingHelper.scala b/streaming/src/main/scala/com/stratio/crossdata/streaming/helpers/CrossdataStreamingHelper.scala index 89008f848..84dd38467 100644 --- a/streaming/src/main/scala/com/stratio/crossdata/streaming/helpers/CrossdataStreamingHelper.scala +++ b/streaming/src/main/scala/com/stratio/crossdata/streaming/helpers/CrossdataStreamingHelper.scala @@ -49,7 +49,8 @@ object CrossdataStreamingHelper extends SparkLoggerComponent { val kafkaOptions = ephemeralTable.options.kafkaOptions val kafkaInput = new KafkaInput(kafkaOptions) - val kafkaDStream = toWindowDStream(kafkaInput.createStream(streamingContext), ephemeralTable.options) + val kafkaDStream = + toWindowDStream(kafkaInput.createStream(streamingContext), ephemeralTable.options) // DStream.foreachRDD is a method that is executed in the Spark Driver if and when an output action is not called // over a RDD. Thus, the value countdowns can be used inside. @@ -57,22 +58,28 @@ object CrossdataStreamingHelper extends SparkLoggerComponent { // http://spark.apache.org/docs/latest/streaming-programming-guide.html#design-patterns-for-using-foreachrdd kafkaDStream.foreachRDD { rdd => if (rdd.take(1).length > 0) { - val ephemeralQueries = CrossdataStatusHelper.queriesFromEphemeralTable(zookeeperConf, ephemeralTable.name) + val ephemeralQueries = + CrossdataStatusHelper.queriesFromEphemeralTable(zookeeperConf, ephemeralTable.name) if (ephemeralQueries.nonEmpty) { - ephemeralQueries.foreach( ephemeralQuery => { + ephemeralQueries.foreach(ephemeralQuery => { val alias = ephemeralQuery.alias - if(!countdowns.contains(alias)){ + if (!countdowns.contains(alias)) { countdowns.put(alias, (ephemeralQuery.window / sparkStreamingWindow)) } - countdowns.put(alias, countdowns.get(alias).getOrElse(0)-1) + countdowns.put(alias, countdowns.get(alias).getOrElse(0) - 1) logDebug(s"Countdowns: ${countdowns.mkString(", ")}") countdowns.get(alias) foreach { case 0 => { countdowns.put(alias, (ephemeralQuery.window / sparkStreamingWindow)) logInfo(s"Executing streaming query $alias") - executeQuery(rdd, ephemeralQuery, ephemeralTable, kafkaOptions, zookeeperConf, crossdataCatalogConf) + executeQuery(rdd, + ephemeralQuery, + ephemeralTable, + kafkaOptions, + zookeeperConf, + crossdataCatalogConf) } case countdown => logDebug(s"Current countdown for $alias: $countdown") @@ -111,13 +118,15 @@ object CrossdataStreamingHelper extends SparkLoggerComponent { val topic = ephemeralQuery.alias ephemeralTable.options.outputFormat match { - case EphemeralOutputFormat.JSON => saveToKafkaInJSONFormat(dataFrame, topic, kafkaOptionsMerged) - case _ => saveToKafkaInRowFormat(dataFrame, topic, kafkaOptionsMerged) + case EphemeralOutputFormat.JSON => + 
saveToKafkaInJSONFormat(dataFrame, topic, kafkaOptionsMerged) + case _ => + saveToKafkaInRowFormat(dataFrame, topic, kafkaOptionsMerged) } } match { case Failure(throwable) => logger.warn( - s"""|There are problems executing the ephemeral query: $query + s"""|There are problems executing the ephemeral query: $query |with Schema: ${df.printSchema()} |and the first row is: ${df.show(1)} |Exception message: ${throwable.getMessage} @@ -133,30 +142,40 @@ object CrossdataStreamingHelper extends SparkLoggerComponent { private[streaming] def mergeKafkaOptions(ephemeralQuery: EphemeralQueryModel, kafkaOptions: KafkaOptionsModel): KafkaOptionsModel = { kafkaOptions.copy( - partitionOutput = ephemeralQuery.options.get(PartitionKey).orElse(kafkaOptions.partitionOutput), - additionalOptions = kafkaOptions.additionalOptions ++ ephemeralQuery.options) + partitionOutput = + ephemeralQuery.options.get(PartitionKey).orElse(kafkaOptions.partitionOutput), + additionalOptions = kafkaOptions.additionalOptions ++ ephemeralQuery.options) } private[streaming] def filterRddWithWindow(rdd: RDD[(Long, String)], window: Int): RDD[String] = - rdd.flatMap { case (time, row) => - if (time > DateTime.now.getMillis - window * 1000) Option(row) - else None + rdd.flatMap { + case (time, row) => + if (time > DateTime.now.getMillis - window * 1000) Option(row) + else None } - private[streaming] def toWindowDStream(inputStream: DStream[(String, String)], - ephemeralOptions: EphemeralOptionsModel): DStream[(Long, String)] = - // TODO window per query? + private[streaming] def toWindowDStream( + inputStream: DStream[(String, String)], + ephemeralOptions: EphemeralOptionsModel): DStream[(Long, String)] = + // TODO window per query? inputStream.mapPartitions { iterator => val dateTime = DateTime.now.getMillis iterator.map { case (_, kafkaEvent) => (dateTime, kafkaEvent) } }.window(Seconds(ephemeralOptions.maxWindow), Seconds(ephemeralOptions.atomicWindow)) - private[streaming] def saveToKafkaInJSONFormat(dataFrame: DataFrame, topic: String, kafkaOptions: KafkaOptionsModel): Unit = + private[streaming] def saveToKafkaInJSONFormat(dataFrame: DataFrame, + topic: String, + kafkaOptions: KafkaOptionsModel): Unit = dataFrame.toJSON.foreachPartition(values => - values.foreach(value => KafkaProducer.put(topic, value, kafkaOptions, kafkaOptions.partitionOutput))) - - private[streaming] def saveToKafkaInRowFormat(dataFrame: DataFrame, topic: String, kafkaOptions: KafkaOptionsModel): Unit = - dataFrame.rdd.foreachPartition(values => - values.foreach(value => - KafkaProducer.put(topic, value.mkString(","), kafkaOptions, kafkaOptions.partitionOutput))) + values.foreach(value => + KafkaProducer.put(topic, value, kafkaOptions, kafkaOptions.partitionOutput))) + + private[streaming] def saveToKafkaInRowFormat(dataFrame: DataFrame, + topic: String, + kafkaOptions: KafkaOptionsModel): Unit = + dataFrame.rdd.foreachPartition( + values => + values.foreach(value => + KafkaProducer + .put(topic, value.mkString(","), kafkaOptions, kafkaOptions.partitionOutput))) } diff --git a/streaming/src/main/scala/com/stratio/crossdata/streaming/kafka/KafkaInput.scala b/streaming/src/main/scala/com/stratio/crossdata/streaming/kafka/KafkaInput.scala index 7aff0f3a3..175ae2fc8 100644 --- a/streaming/src/main/scala/com/stratio/crossdata/streaming/kafka/KafkaInput.scala +++ b/streaming/src/main/scala/com/stratio/crossdata/streaming/kafka/KafkaInput.scala @@ -32,28 +32,30 @@ class KafkaInput(options: KafkaOptionsModel) { val groupId = Map(getGroupId) 
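// --- Editorial sketch, not part of the patch: a stand-alone version of the connection-string
// --- logic that KafkaInput.getConnection (reformatted just below) applies to its ZooKeeper
// --- endpoints; the Endpoint case class and the default host/port are simplified stand-ins.
object ConnectionStringSketch extends App {
  final case class Endpoint(host: String, port: Int)

  def connectionString(endpoints: Seq[Endpoint],
                       defaultHost: String = "127.0.0.1",
                       defaultPort: Int = 2181): String = {
    val chain = endpoints.map(e => s"${e.host}:${e.port}").mkString(",")
    if (chain.isEmpty) s"$defaultHost:$defaultPort" else chain // fall back to the default endpoint
  }

  println(connectionString(Nil))                                               // 127.0.0.1:2181
  println(connectionString(Seq(Endpoint("zk1", 2181), Endpoint("zk2", 2181)))) // zk1:2181,zk2:2181
}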
KafkaUtils.createStream[String, String, StringDecoder, StringDecoder]( - ssc, - connection ++ groupId ++ kafkaParams, - getTopics, - storageLevel(options.storageLevel)) + ssc, + connection ++ groupId ++ kafkaParams, + getTopics, + storageLevel(options.storageLevel)) } - private[streaming] def getConnection : (String, String) = { + private[streaming] def getConnection: (String, String) = { - val connectionChain = ( - for(zkConnection <- options.connection.zkConnection) yield (s"${zkConnection.host}:${zkConnection.port}") - ).mkString(",") + val connectionChain = (for (zkConnection <- options.connection.zkConnection) + yield ( s"${zkConnection.host}:${zkConnection.port}")).mkString(",") - (ZookeeperConnectionKey, if(connectionChain.isEmpty) s"$DefaultHost:$DefaultConsumerPort" else connectionChain) + (ZookeeperConnectionKey, + if (connectionChain.isEmpty) s"$DefaultHost:$DefaultConsumerPort" + else connectionChain) } - private[streaming] def getGroupId : (String, String) = (GroupIdKey, options.groupId) + private[streaming] def getGroupId: (String, String) = + (GroupIdKey, options.groupId) - private[streaming] def getTopics : Map[String, Int] = { + private[streaming] def getTopics: Map[String, Int] = { if (options.topics.isEmpty) { throw new IllegalStateException(s"Invalid configuration, topics must be declared.") } else { - options.topics.map(topicModel => (topicModel.name, topicModel.numPartitions)).toMap + options.topics.map(topicModel => (topicModel.name, topicModel.numPartitions)).toMap } } @@ -61,4 +63,4 @@ class KafkaInput(options: KafkaOptionsModel) { StorageLevel.fromString(sparkStorageLevel) } -} \ No newline at end of file +} diff --git a/streaming/src/main/scala/com/stratio/crossdata/streaming/kafka/KafkaProducer.scala b/streaming/src/main/scala/com/stratio/crossdata/streaming/kafka/KafkaProducer.scala index bc2062edc..618f4d1fb 100644 --- a/streaming/src/main/scala/com/stratio/crossdata/streaming/kafka/KafkaProducer.scala +++ b/streaming/src/main/scala/com/stratio/crossdata/streaming/kafka/KafkaProducer.scala @@ -39,14 +39,15 @@ object KafkaProducer { } private[streaming] def kafkaMessage(topic: String, - message: String, - partition: Option[String]): KeyedMessage[String, String] = { + message: String, + partition: Option[String]): KeyedMessage[String, String] = { partition.fold(new KeyedMessage[String, String](topic, message)) { key => new KeyedMessage[String, String](topic, key, message) } } - private[streaming] def sendMessage(message: KeyedMessage[String, String], options: KafkaOptionsModel): Unit = { + private[streaming] def sendMessage(message: KeyedMessage[String, String], + options: KafkaOptionsModel): Unit = { getProducer(options).send(message) } @@ -57,7 +58,8 @@ object KafkaProducer { private[streaming] def getKey(connection: ConnectionHostModel): String = s"ConnectionHostModel([${connection.zkConnection.map(_.toString).mkString(",")}],[${connection.kafkaConnection.map(_.toString).mkString(",")}])" - private[streaming] def getInstance(key: String, options: KafkaOptionsModel): Producer[String, String] = + private[streaming] def getInstance(key: String, + options: KafkaOptionsModel): Producer[String, String] = producers.getOrElse(key, { val producer = createProducer(options) producers.put(key, producer) @@ -69,8 +71,9 @@ object KafkaProducer { properties.put(BrokerListKey, getBrokerList(options.connection)) properties.put(SerializerKey, DefaultSerializer) - options.additionalOptions.foreach { case (key, value) => - producerProperties.get(key).foreach(kafkaKey => 
properties.put(kafkaKey, value)) + options.additionalOptions.foreach { + case (key, value) => + producerProperties.get(key).foreach(kafkaKey => properties.put(kafkaKey, value)) } val producerConfig = new ProducerConfig(properties) @@ -78,20 +81,20 @@ object KafkaProducer { } private[streaming] def getBrokerList(connection: ConnectionHostModel, - defaultHost: String = DefaultHost, - defaultPort: String = DefaultProducerPort): String = { + defaultHost: String = DefaultHost, + defaultPort: String = DefaultProducerPort): String = { - val connectionStr = ( - for (kafkaConnection <- connection.kafkaConnection) yield (s"${kafkaConnection.host}:${kafkaConnection.port}") - ).mkString(",") + val connectionStr = (for (kafkaConnection <- connection.kafkaConnection) + yield ( s"${kafkaConnection.host}:${kafkaConnection.port}")).mkString(",") if (connectionStr.isEmpty) s"$defaultHost:$defaultPort" else connectionStr } private[streaming] def deleteProducers(): Unit = { - producers.foreach { case (key, producer) => - producer.close() - producers.remove(key) + producers.foreach { + case (key, producer) => + producer.close() + producers.remove(key) } } diff --git a/streaming/src/test/scala/com/stratio/crossdata/streaming/CrossdataStreamingApplicationSpec.scala b/streaming/src/test/scala/com/stratio/crossdata/streaming/CrossdataStreamingApplicationSpec.scala index 9a6a271b5..b57e9a4dd 100644 --- a/streaming/src/test/scala/com/stratio/crossdata/streaming/CrossdataStreamingApplicationSpec.scala +++ b/streaming/src/test/scala/com/stratio/crossdata/streaming/CrossdataStreamingApplicationSpec.scala @@ -35,10 +35,10 @@ class CrossdataStreamingApplicationSpec extends BaseStreamingXDTest with CommonV "CrossdataStreamingApplication" should "parse correctly the zookeeper argument" in { - val result = CrossdataStreamingApplication.parseMapArguments("""{"connectionString":"localhost:2181"}""") + val result = + CrossdataStreamingApplication.parseMapArguments("""{"connectionString":"localhost:2181"}""") val expected = Try(Map("connectionString" -> "localhost:2181")) result should be(expected) } } - diff --git a/streaming/src/test/scala/com/stratio/crossdata/streaming/CrossdataStreamingSpec.scala b/streaming/src/test/scala/com/stratio/crossdata/streaming/CrossdataStreamingSpec.scala index 66353d0ac..f423573a0 100644 --- a/streaming/src/test/scala/com/stratio/crossdata/streaming/CrossdataStreamingSpec.scala +++ b/streaming/src/test/scala/com/stratio/crossdata/streaming/CrossdataStreamingSpec.scala @@ -28,15 +28,18 @@ class CrossdataStreamingSpec extends BaseStreamingXDTest with CommonValues { implicit val timeout: Timeout = Timeout(15.seconds) "CrossdataStreaming" should "return a empty Sparkconf according to the table options" in { - val XDStreaming = new CrossdataStreaming(TableName, Map.empty[String, String], Map.empty[String, String]) - val configuration = XDStreaming.configToSparkConf(ephemeralTableModelWithoutSparkOptions).getAll + val XDStreaming = + new CrossdataStreaming(TableName, Map.empty[String, String], Map.empty[String, String]) + val configuration = + XDStreaming.configToSparkConf(ephemeralTableModelWithoutSparkOptions).getAll val expected = Array.empty[(String, String)] configuration should be(expected) } "CrossdataStreaming" should "return Sparkconf according to the table options" in { - val XDStreaming = new CrossdataStreaming(TableName, Map.empty[String, String], Map.empty[String, String]) + val XDStreaming = + new CrossdataStreaming(TableName, Map.empty[String, String], Map.empty[String, String]) val 
configuration = XDStreaming.configToSparkConf(ephemeralTableModelWithSparkOptions).getAll val expected = Array(("spark.defaultParallelism", "50")) @@ -44,8 +47,10 @@ class CrossdataStreamingSpec extends BaseStreamingXDTest with CommonValues { } "CrossdataStreaming" should "return Sparkconf according to the table options with prefix" in { - val XDStreaming = new CrossdataStreaming(TableName, Map.empty[String, String], Map.empty[String, String]) - val configuration = XDStreaming.configToSparkConf(ephemeralTableModelWithSparkOptionsPrefix).getAll + val XDStreaming = + new CrossdataStreaming(TableName, Map.empty[String, String], Map.empty[String, String]) + val configuration = + XDStreaming.configToSparkConf(ephemeralTableModelWithSparkOptionsPrefix).getAll val expected = Array(("spark.defaultParallelism", "50")) configuration should be(expected) diff --git a/streaming/src/test/scala/com/stratio/crossdata/streaming/actors/EphemeralQueryActorIT.scala b/streaming/src/test/scala/com/stratio/crossdata/streaming/actors/EphemeralQueryActorIT.scala index df66773de..7c57074f3 100644 --- a/streaming/src/test/scala/com/stratio/crossdata/streaming/actors/EphemeralQueryActorIT.scala +++ b/streaming/src/test/scala/com/stratio/crossdata/streaming/actors/EphemeralQueryActorIT.scala @@ -28,12 +28,13 @@ import org.scalatest.{BeforeAndAfterAll, WordSpecLike} import org.scalatest.time.SpanSugar._ @RunWith(classOf[JUnitRunner]) -class EphemeralQueryActorIT(_system: ActorSystem) extends TestKit(_system) -with DefaultTimeout -with ImplicitSender -with WordSpecLike -with BeforeAndAfterAll -with TimeLimitedTests { +class EphemeralQueryActorIT(_system: ActorSystem) + extends TestKit(_system) + with DefaultTimeout + with ImplicitSender + with WordSpecLike + with BeforeAndAfterAll + with TimeLimitedTests { def this() = this(ActorSystem("EphemeralQueryActor")) @@ -55,7 +56,8 @@ with TimeLimitedTests { "EphemeralQueryActor" should { "set up with zookeeper configuration without any error" in { - _system.actorOf(Props(new EphemeralQueryActor(Map("connectionString" -> zookeeperConnection)))) + _system.actorOf( + Props(new EphemeralQueryActor(Map("connectionString" -> zookeeperConnection)))) } } @@ -63,8 +65,8 @@ with TimeLimitedTests { "AddListener the first message" in new CommonValues { - val ephemeralQueryActor = - _system.actorOf(Props(new EphemeralQueryActor(Map("connectionString" -> zookeeperConnection)))) + val ephemeralQueryActor = _system.actorOf( + Props(new EphemeralQueryActor(Map("connectionString" -> zookeeperConnection)))) ephemeralQueryActor ! EphemeralQueryActor.AddListener @@ -73,8 +75,8 @@ with TimeLimitedTests { "AddListener the not be the first message" in new CommonValues { - val ephemeralQueryActor = - _system.actorOf(Props(new EphemeralQueryActor(Map("connectionString" -> zookeeperConnection)))) + val ephemeralQueryActor = _system.actorOf( + Props(new EphemeralQueryActor(Map("connectionString" -> zookeeperConnection)))) ephemeralQueryActor ! EphemeralQueryActor.GetQueries @@ -83,8 +85,8 @@ with TimeLimitedTests { "GetQueries is the second message" in new CommonValues { - val ephemeralQueryActor = - _system.actorOf(Props(new EphemeralQueryActor(Map("connectionString" -> zookeeperConnection)))) + val ephemeralQueryActor = _system.actorOf( + Props(new EphemeralQueryActor(Map("connectionString" -> zookeeperConnection)))) ephemeralQueryActor ! 
EphemeralQueryActor.AddListener expectMsg(new ListenerResponse(true)) diff --git a/streaming/src/test/scala/com/stratio/crossdata/streaming/actors/EphemeralStatusActorIT.scala b/streaming/src/test/scala/com/stratio/crossdata/streaming/actors/EphemeralStatusActorIT.scala index 87a455496..a61288897 100644 --- a/streaming/src/test/scala/com/stratio/crossdata/streaming/actors/EphemeralStatusActorIT.scala +++ b/streaming/src/test/scala/com/stratio/crossdata/streaming/actors/EphemeralStatusActorIT.scala @@ -31,15 +31,16 @@ import org.scalatest.junit.JUnitRunner import org.scalatest.time.SpanSugar._ @RunWith(classOf[JUnitRunner]) -class EphemeralStatusActorIT(_system: ActorSystem) extends TestKit(_system) - with DefaultTimeout - with ImplicitSender - with WordSpecLike - with BeforeAndAfterAll - with CommonValues - with BeforeAndAfter - with ShouldMatchers - with TimeLimitedTests { +class EphemeralStatusActorIT(_system: ActorSystem) + extends TestKit(_system) + with DefaultTimeout + with ImplicitSender + with WordSpecLike + with BeforeAndAfterAll + with CommonValues + with BeforeAndAfter + with ShouldMatchers + with TimeLimitedTests { def this() = this(ActorSystem("EphemeralStatusActor")) @@ -78,8 +79,10 @@ class EphemeralStatusActorIT(_system: ActorSystem) extends TestKit(_system) "EphemeralStatusActor" should { "set up with zookeeper configuration and StreamingContext without any error" in { - _system.actorOf(Props(new EphemeralStatusActor(ssc, - Map("connectionString" -> zookeeperConnection), TableName))) + _system.actorOf( + Props(new EphemeralStatusActor(ssc, + Map("connectionString" -> zookeeperConnection), + TableName))) } } @@ -87,9 +90,10 @@ class EphemeralStatusActorIT(_system: ActorSystem) extends TestKit(_system) "AddListener the first message" in new CommonValues { - val ephemeralStatusActor = - _system.actorOf(Props(new EphemeralStatusActor(ssc, - Map("connectionString" -> zookeeperConnection), TableName))) + val ephemeralStatusActor = _system.actorOf( + Props(new EphemeralStatusActor(ssc, + Map("connectionString" -> zookeeperConnection), + TableName))) ephemeralStatusActor ! EphemeralStatusActor.AddListener @@ -98,9 +102,10 @@ class EphemeralStatusActorIT(_system: ActorSystem) extends TestKit(_system) "AddListener is the two messages" in new CommonValues { - val ephemeralStatusActor = - _system.actorOf(Props(new EphemeralStatusActor(ssc, - Map("connectionString" -> zookeeperConnection), TableName))) + val ephemeralStatusActor = _system.actorOf( + Props(new EphemeralStatusActor(ssc, + Map("connectionString" -> zookeeperConnection), + TableName))) ephemeralStatusActor ! EphemeralStatusActor.AddListener expectMsg(new ListenerResponse(true)) @@ -111,9 +116,10 @@ class EphemeralStatusActorIT(_system: ActorSystem) extends TestKit(_system) "GetStatus return the status" in new CommonValues { - val ephemeralStatusActor = - _system.actorOf(Props(new EphemeralStatusActor(ssc, - Map("connectionString" -> zookeeperConnection), TableName))) + val ephemeralStatusActor = _system.actorOf( + Props(new EphemeralStatusActor(ssc, + Map("connectionString" -> zookeeperConnection), + TableName))) ephemeralStatusActor ! 
EphemeralStatusActor.GetStatus expectMsg(new StatusResponse(EphemeralExecutionStatus.NotStarted)) @@ -121,9 +127,10 @@ class EphemeralStatusActorIT(_system: ActorSystem) extends TestKit(_system) "CheckStatus shoud make nothing" in new CommonValues { - val ephemeralStatusActor = - _system.actorOf(Props(new EphemeralStatusActor(ssc, - Map("connectionString" -> zookeeperConnection), TableName))) + val ephemeralStatusActor = _system.actorOf( + Props(new EphemeralStatusActor(ssc, + Map("connectionString" -> zookeeperConnection), + TableName))) ephemeralStatusActor ! EphemeralStatusActor.CheckStatus @@ -132,9 +139,10 @@ class EphemeralStatusActorIT(_system: ActorSystem) extends TestKit(_system) "SetStatus shoud change the status" in new CommonValues { - val ephemeralStatusActor = - _system.actorOf(Props(new EphemeralStatusActor(ssc, - Map("connectionString" -> zookeeperConnection), TableName))) + val ephemeralStatusActor = _system.actorOf( + Props(new EphemeralStatusActor(ssc, + Map("connectionString" -> zookeeperConnection), + TableName))) ephemeralStatusActor ! EphemeralStatusActor.GetStatus expectMsg(new StatusResponse(EphemeralExecutionStatus.NotStarted)) @@ -148,9 +156,10 @@ class EphemeralStatusActorIT(_system: ActorSystem) extends TestKit(_system) "GetStreamingStatus shoud return the correct streaming status" in new CommonValues { - val ephemeralStatusActor = - _system.actorOf(Props(new EphemeralStatusActor(ssc, - Map("connectionString" -> zookeeperConnection), TableName))) + val ephemeralStatusActor = _system.actorOf( + Props(new EphemeralStatusActor(ssc, + Map("connectionString" -> zookeeperConnection), + TableName))) ephemeralStatusActor ! EphemeralStatusActor.GetStreamingStatus expectMsg(StreamingStatusResponse(StreamingContextState.INITIALIZED)) @@ -167,9 +176,10 @@ class EphemeralStatusActorIT(_system: ActorSystem) extends TestKit(_system) "CheckStatus shoud make StreamingContext stop when status is Stopping without Listener" in new CommonValues { - val ephemeralStatusActor = - _system.actorOf(Props(new EphemeralStatusActor(ssc, - Map("connectionString" -> zookeeperConnection), TableName))) + val ephemeralStatusActor = _system.actorOf( + Props(new EphemeralStatusActor(ssc, + Map("connectionString" -> zookeeperConnection), + TableName))) ephemeralStatusActor ! 
EphemeralStatusActor.SetStatus(EphemeralExecutionStatus.Started) expectMsg(new StatusResponse(EphemeralExecutionStatus.Started)) diff --git a/streaming/src/test/scala/com/stratio/crossdata/streaming/helpers/CrossdataStatusHelperIT.scala b/streaming/src/test/scala/com/stratio/crossdata/streaming/helpers/CrossdataStatusHelperIT.scala index e80bc2b28..1a10bef8a 100644 --- a/streaming/src/test/scala/com/stratio/crossdata/streaming/helpers/CrossdataStatusHelperIT.scala +++ b/streaming/src/test/scala/com/stratio/crossdata/streaming/helpers/CrossdataStatusHelperIT.scala @@ -24,7 +24,6 @@ import org.apache.spark.{SparkContext, SparkConf} import org.junit.runner.RunWith import org.scalatest.junit.JUnitRunner - @RunWith(classOf[JUnitRunner]) class CrossdataStatusHelperIT extends BaseStreamingXDTest with CommonValues { @@ -47,7 +46,8 @@ class CrossdataStatusHelperIT extends BaseStreamingXDTest with CommonValues { val sparkConf = new SparkConf().setMaster("local[2]").setAppName(this.getClass.getSimpleName) val sc = SparkContext.getOrCreate(sparkConf) val ssc = new StreamingContext(sc, Milliseconds(1000)) - val result = CrossdataStatusHelper.initStatusActor(ssc, Map("connectionString" -> zookeeperConnection), TableName) + val result = CrossdataStatusHelper + .initStatusActor(ssc, Map("connectionString" -> zookeeperConnection), TableName) val expected = true result.isDefined should be(expected) @@ -58,7 +58,8 @@ class CrossdataStatusHelperIT extends BaseStreamingXDTest with CommonValues { "CrossdataStatusHelperIT" should "create a QueryActor without errors" in { - val result = CrossdataStatusHelper.createEphemeralQueryActor(Map("connectionString" -> zookeeperConnection)) + val result = CrossdataStatusHelper.createEphemeralQueryActor( + Map("connectionString" -> zookeeperConnection)) val expected = true result.isDefined should be(expected) @@ -66,8 +67,8 @@ class CrossdataStatusHelperIT extends BaseStreamingXDTest with CommonValues { "CrossdataStatusHelperIT" should "create a QueryActor and return the queries" in { - val result = - CrossdataStatusHelper.queriesFromEphemeralTable(Map("connectionString" -> zookeeperConnection), TableName) + val result = CrossdataStatusHelper + .queriesFromEphemeralTable(Map("connectionString" -> zookeeperConnection), TableName) val expected = Seq.empty[EphemeralQueryModel] result should be(expected) diff --git a/streaming/src/test/scala/com/stratio/crossdata/streaming/kafka/KafkaProducerSpec.scala b/streaming/src/test/scala/com/stratio/crossdata/streaming/kafka/KafkaProducerSpec.scala index c0ede121a..357d2af85 100644 --- a/streaming/src/test/scala/com/stratio/crossdata/streaming/kafka/KafkaProducerSpec.scala +++ b/streaming/src/test/scala/com/stratio/crossdata/streaming/kafka/KafkaProducerSpec.scala @@ -29,7 +29,8 @@ class KafkaProducerSpec extends BaseStreamingXDTest with CommonValues { "KafkaProducer" should "return a correct key" in { val result = KafkaProducer.getKey(connectionHostModel) - val expected = """ConnectionHostModel([ConnectionModel(localhost,2181)],[ConnectionModel(localhost,9042)])""" + val expected = + """ConnectionHostModel([ConnectionModel(localhost,2181)],[ConnectionModel(localhost,9042)])""" result should be(expected) } @@ -68,21 +69,24 @@ class KafkaProducerSpec extends BaseStreamingXDTest with CommonValues { } "KafkaProducer" should "return additional params" in { - val result = KafkaProducer.getProducer(kafkaStreamModel).config.props.containsKey("batch.num.messages") + val result = + 
KafkaProducer.getProducer(kafkaStreamModel).config.props.containsKey("batch.num.messages") val expected = true result should be(expected) } "KafkaProducer" should "return a correct additional param" in { - val result = KafkaProducer.getProducer(kafkaStreamModel).config.props.getString("batch.num.messages") + val result = + KafkaProducer.getProducer(kafkaStreamModel).config.props.getString("batch.num.messages") val expected = "100" result should be(expected) } "KafkaProducer" should "return empty params" in { - val result = KafkaProducer.getProducer(kafkaOptionsModel).config.props.containsKey("batch.num.messages") + val result = + KafkaProducer.getProducer(kafkaOptionsModel).config.props.containsKey("batch.num.messages") val expected = false result should be(expected) diff --git a/streaming/src/test/scala/com/stratio/crossdata/streaming/test/BaseSparkStreamingXDTest.scala b/streaming/src/test/scala/com/stratio/crossdata/streaming/test/BaseSparkStreamingXDTest.scala index 0952bc4bd..4fed96636 100644 --- a/streaming/src/test/scala/com/stratio/crossdata/streaming/test/BaseSparkStreamingXDTest.scala +++ b/streaming/src/test/scala/com/stratio/crossdata/streaming/test/BaseSparkStreamingXDTest.scala @@ -19,13 +19,14 @@ import org.scalatest._ import org.scalatest.concurrent.{Eventually, TimeLimitedTests} import org.scalatest.time.SpanSugar._ -trait BaseSparkStreamingXDTest extends FunSuite -with Matchers -with ShouldMatchers -with BeforeAndAfterAll -with BeforeAndAfter -with Eventually -with TimeLimitedTests { +trait BaseSparkStreamingXDTest + extends FunSuite + with Matchers + with ShouldMatchers + with BeforeAndAfterAll + with BeforeAndAfter + with Eventually + with TimeLimitedTests { val timeLimit = 2 minutes } diff --git a/streaming/src/test/scala/com/stratio/crossdata/streaming/test/BaseStreamingXDTest.scala b/streaming/src/test/scala/com/stratio/crossdata/streaming/test/BaseStreamingXDTest.scala index 177298fae..8fa35188f 100644 --- a/streaming/src/test/scala/com/stratio/crossdata/streaming/test/BaseStreamingXDTest.scala +++ b/streaming/src/test/scala/com/stratio/crossdata/streaming/test/BaseStreamingXDTest.scala @@ -19,9 +19,9 @@ import com.stratio.crossdata.test.BaseXDTest import org.scalatest._ import org.scalatest.mock.MockitoSugar -trait BaseStreamingXDTest extends BaseXDTest -with ShouldMatchers -with BeforeAndAfterAll -with BeforeAndAfter -with MockitoSugar - +trait BaseStreamingXDTest + extends BaseXDTest + with ShouldMatchers + with BeforeAndAfterAll + with BeforeAndAfter + with MockitoSugar diff --git a/streaming/src/test/scala/com/stratio/crossdata/streaming/test/CommonValues.scala b/streaming/src/test/scala/com/stratio/crossdata/streaming/test/CommonValues.scala index 66adf0371..5675776c6 100644 --- a/streaming/src/test/scala/com/stratio/crossdata/streaming/test/CommonValues.scala +++ b/streaming/src/test/scala/com/stratio/crossdata/streaming/test/CommonValues.scala @@ -27,11 +27,11 @@ import org.apache.spark.sql.crossdata.models._ import scala.language.postfixOps import scala.util.{Failure, Random, Success, Try} -trait CommonValues extends SparkLoggerComponent{ +trait CommonValues extends SparkLoggerComponent { /** - * Kafka Options - */ + * Kafka Options + */ val ConsumerHost = "localhost" val ProducerHost = "localhost" val HostStream = "127.0.0.1" @@ -47,39 +47,35 @@ trait CommonValues extends SparkLoggerComponent{ val StorageLevel = "MEMORY_ONLY_SER" val StorageStreamLevel = "MEMORY_ONLY" val connectionHostModel = ConnectionHostModel( - Seq(ConnectionModel(ConsumerHost, 
ConsumerPort.toInt)), - Seq(ConnectionModel(ProducerHost, ProducerPort.toInt))) + Seq(ConnectionModel(ConsumerHost, ConsumerPort.toInt)), + Seq(ConnectionModel(ProducerHost, ProducerPort.toInt))) val topicModel = TopicModel(TopicTest) val kafkaOptionsModel = KafkaOptionsModel(connectionHostModel, - Seq(topicModel), - GroupId, - PartitionOutputEmpty, - additionalOptionsEmpty, - StorageLevel - ) + Seq(topicModel), + GroupId, + PartitionOutputEmpty, + additionalOptionsEmpty, + StorageLevel) val kafkaOptionsModelEmptyConnection = KafkaOptionsModel(ConnectionHostModel(Seq(), Seq()), - Seq(topicModel), - GroupId, - PartitionOutputEmpty, - additionalOptionsEmpty, - StorageLevel - ) + Seq(topicModel), + GroupId, + PartitionOutputEmpty, + additionalOptionsEmpty, + StorageLevel) val kafkaOptionsModelEmptyTopics = KafkaOptionsModel(connectionHostModel, - Seq(), - s"$GroupId-${Random.nextInt(10000)}", - PartitionOutputEmpty, - additionalOptionsEmpty, - StorageLevel - ) + Seq(), + s"$GroupId-${Random.nextInt(10000)}", + PartitionOutputEmpty, + additionalOptionsEmpty, + StorageLevel) val kafkaStreamModel = KafkaOptionsModel(connectionHostModel, - Seq(topicModel), - GroupId, - PartitionOutputEmpty, - additionalOptionsStream, - StorageStreamLevel - ) + Seq(topicModel), + GroupId, + PartitionOutputEmpty, + additionalOptionsStream, + StorageStreamLevel) val zookeeperConfEmpty = Map.empty[String, String] val zookeeperConfError = Map("a" -> "c", "a.b" -> "c") @@ -88,48 +84,53 @@ trait CommonValues extends SparkLoggerComponent{ val AliasName = "alias" val Sql = s"select * from $TableName" val queryModel = EphemeralQueryModel(TableName, Sql, AliasName) - val queryOptionsModel = EphemeralQueryModel(TableName, Sql, AliasName, 5, Map("option" -> "value")) + val queryOptionsModel = + EphemeralQueryModel(TableName, Sql, AliasName, 5, Map("option" -> "value")) val ephemeralOptionsEmptySparkOptions = EphemeralOptionsModel( - kafkaOptionsModel, - EphemeralOptionsModel.DefaultAtomicWindow, - EphemeralOptionsModel.DefaultMaxWindow, - EphemeralOutputFormat.ROW, - s"${EphemeralOptionsModel.DefaultCheckpointDirectory}/$TableName", - Map.empty + kafkaOptionsModel, + EphemeralOptionsModel.DefaultAtomicWindow, + EphemeralOptionsModel.DefaultMaxWindow, + EphemeralOutputFormat.ROW, + s"${EphemeralOptionsModel.DefaultCheckpointDirectory}/$TableName", + Map.empty ) val ephemeralOptionsWithSparkOptions = EphemeralOptionsModel( - kafkaOptionsModel, - EphemeralOptionsModel.DefaultAtomicWindow, - EphemeralOptionsModel.DefaultMaxWindow, - EphemeralOutputFormat.ROW, - s"${EphemeralOptionsModel.DefaultCheckpointDirectory}/$TableName", - Map("spark.defaultParallelism" -> "50") + kafkaOptionsModel, + EphemeralOptionsModel.DefaultAtomicWindow, + EphemeralOptionsModel.DefaultMaxWindow, + EphemeralOutputFormat.ROW, + s"${EphemeralOptionsModel.DefaultCheckpointDirectory}/$TableName", + Map("spark.defaultParallelism" -> "50") ) val ephemeralOptionsWithSparkOptionsPrefix = EphemeralOptionsModel( - kafkaOptionsModel, - EphemeralOptionsModel.DefaultAtomicWindow, - EphemeralOptionsModel.DefaultMaxWindow, - EphemeralOutputFormat.ROW, - s"${EphemeralOptionsModel.DefaultCheckpointDirectory}/$TableName", - Map("defaultParallelism" -> "50") + kafkaOptionsModel, + EphemeralOptionsModel.DefaultAtomicWindow, + EphemeralOptionsModel.DefaultMaxWindow, + EphemeralOutputFormat.ROW, + s"${EphemeralOptionsModel.DefaultCheckpointDirectory}/$TableName", + Map("defaultParallelism" -> "50") ) val ephemeralOptionsStreamKafka = EphemeralOptionsModel( - 
kafkaStreamModel, - EphemeralOptionsModel.DefaultAtomicWindow, - EphemeralOptionsModel.DefaultMaxWindow, - EphemeralOutputFormat.ROW, - s"${EphemeralOptionsModel.DefaultCheckpointDirectory}/$TableName", - Map.empty + kafkaStreamModel, + EphemeralOptionsModel.DefaultAtomicWindow, + EphemeralOptionsModel.DefaultMaxWindow, + EphemeralOutputFormat.ROW, + s"${EphemeralOptionsModel.DefaultCheckpointDirectory}/$TableName", + Map.empty ) - val ephemeralTableModelWithoutSparkOptions = EphemeralTableModel(TableName, ephemeralOptionsEmptySparkOptions) - val ephemeralTableModelStreamKafkaOptions = EphemeralTableModel(TableName, ephemeralOptionsStreamKafka) - val ephemeralTableModelWithSparkOptions = EphemeralTableModel(TableName, ephemeralOptionsWithSparkOptions) - val ephemeralTableModelWithSparkOptionsPrefix = EphemeralTableModel(TableName, ephemeralOptionsWithSparkOptionsPrefix) + val ephemeralTableModelWithoutSparkOptions = + EphemeralTableModel(TableName, ephemeralOptionsEmptySparkOptions) + val ephemeralTableModelStreamKafkaOptions = + EphemeralTableModel(TableName, ephemeralOptionsStreamKafka) + val ephemeralTableModelWithSparkOptions = + EphemeralTableModel(TableName, ephemeralOptionsWithSparkOptions) + val ephemeralTableModelWithSparkOptionsPrefix = + EphemeralTableModel(TableName, ephemeralOptionsWithSparkOptionsPrefix) /** - * Select query - */ + * Select query + */ val TableNameSelect = "tabletestselect" val TopicTestSelect = "topictestselect" val AliasNameSelect = "aliasselect" @@ -137,27 +138,27 @@ trait CommonValues extends SparkLoggerComponent{ val querySelectModel = EphemeralQueryModel(TableNameSelect, SqlSelect, AliasNameSelect) val topicModelSelect = TopicModel(TopicTestSelect) val kafkaStreamModelSelect = KafkaOptionsModel(connectionHostModel, - Seq(topicModelSelect), - GroupId, - PartitionOutputEmpty, - additionalOptionsStream, - StorageStreamLevel - ) - val checkpointDirectorySelect = s"${EphemeralOptionsModel.DefaultCheckpointDirectory}/$TableNameSelect" + Seq(topicModelSelect), + GroupId, + PartitionOutputEmpty, + additionalOptionsStream, + StorageStreamLevel) + val checkpointDirectorySelect = + s"${EphemeralOptionsModel.DefaultCheckpointDirectory}/$TableNameSelect" val ephemeralOptionsStreamKafkaSelect = EphemeralOptionsModel( - kafkaStreamModelSelect, - EphemeralOptionsModel.DefaultAtomicWindow, - EphemeralOptionsModel.DefaultMaxWindow, - EphemeralOutputFormat.ROW, - s"${EphemeralOptionsModel.DefaultCheckpointDirectory}/$TableNameSelect", - Map.empty + kafkaStreamModelSelect, + EphemeralOptionsModel.DefaultAtomicWindow, + EphemeralOptionsModel.DefaultMaxWindow, + EphemeralOutputFormat.ROW, + s"${EphemeralOptionsModel.DefaultCheckpointDirectory}/$TableNameSelect", + Map.empty ) val ephemeralTableModelStreamKafkaOptionsSelect = EphemeralTableModel(TableNameSelect, ephemeralOptionsStreamKafkaSelect) /** - * Projected query - */ + * Projected query + */ val TableNameProject = "tabletestproject" val TopicTestProject = "topicTestproject" val AliasNameProject = "aliasproject" @@ -165,20 +166,20 @@ trait CommonValues extends SparkLoggerComponent{ val queryProjectedModel = EphemeralQueryModel(TableNameProject, SqlProjected, AliasNameProject) val topicModelProject = TopicModel(TopicTestProject) val kafkaStreamModelProject = KafkaOptionsModel(connectionHostModel, - Seq(topicModelProject), - GroupId, - PartitionOutputEmpty, - additionalOptionsStream, - StorageStreamLevel - ) - val checkpointDirectoryProject = s"${EphemeralOptionsModel.DefaultCheckpointDirectory}/$TableNameProject" + 
Seq(topicModelProject), + GroupId, + PartitionOutputEmpty, + additionalOptionsStream, + StorageStreamLevel) + val checkpointDirectoryProject = + s"${EphemeralOptionsModel.DefaultCheckpointDirectory}/$TableNameProject" val ephemeralOptionsStreamKafkaProject = EphemeralOptionsModel( - kafkaStreamModelProject, - EphemeralOptionsModel.DefaultAtomicWindow, - EphemeralOptionsModel.DefaultMaxWindow, - EphemeralOutputFormat.ROW, - s"${EphemeralOptionsModel.DefaultCheckpointDirectory}/$TableNameProject", - Map.empty + kafkaStreamModelProject, + EphemeralOptionsModel.DefaultAtomicWindow, + EphemeralOptionsModel.DefaultMaxWindow, + EphemeralOutputFormat.ROW, + s"${EphemeralOptionsModel.DefaultCheckpointDirectory}/$TableNameProject", + Map.empty ) val ephemeralTableModelStreamKafkaOptionsProject = EphemeralTableModel(TableNameProject, ephemeralOptionsStreamKafkaProject) @@ -186,11 +187,13 @@ trait CommonValues extends SparkLoggerComponent{ def parseZookeeperCatalogConfig(zookeeperConf: Map[String, String]): Map[String, String] = { Map(CatalogClassConfigKey -> ZookeeperClass) ++ Map(StreamingCatalogClassConfigKey -> ZookeeperStreamingClass) ++ - zookeeperConf.map { case (key, value) => - s"$CatalogConfigKey.$ZookeeperPrefixName.$key" -> value + zookeeperConf.map { + case (key, value) => + s"$CatalogConfigKey.$ZookeeperPrefixName.$key" -> value } ++ - zookeeperConf.map { case (key, value) => - s"$StreamingConfigKey.$CatalogConfigKey.$ZookeeperPrefixName.$key" -> value + zookeeperConf.map { + case (key, value) => + s"$StreamingConfigKey.$CatalogConfigKey.$ZookeeperPrefixName.$key" -> value } } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/kafka/CrossdataStreamingHelperProjectIT.scala b/streaming/src/test/scala/org/apache/spark/streaming/kafka/CrossdataStreamingHelperProjectIT.scala index f39ce2ede..3f93a8e2d 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/kafka/CrossdataStreamingHelperProjectIT.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/kafka/CrossdataStreamingHelperProjectIT.scala @@ -37,7 +37,8 @@ import scala.language.postfixOps @RunWith(classOf[JUnitRunner]) class CrossdataStreamingHelperProjectIT extends BaseSparkStreamingXDTest with CommonValues { - private val sparkConf = new SparkConf().setMaster("local[2]").setAppName(this.getClass.getSimpleName) + private val sparkConf = + new SparkConf().setMaster("local[2]").setAppName(this.getClass.getSimpleName) private var sc: SparkContext = _ private var kafkaTestUtils: KafkaTestUtils = _ private var zookeeperConf: Map[String, String] = _ @@ -56,7 +57,9 @@ class CrossdataStreamingHelperProjectIT extends BaseSparkStreamingXDTest with Co zookeeperConf = Map("connectionString" -> kafkaTestUtils.zkAddress) catalogConf = parseZookeeperCatalogConfig(zookeeperConf) xDContext = XDContext.getOrCreate(sc, parseCatalogConfig(catalogConf)) - zookeeperStreamingCatalog = new ZookeeperStreamingCatalog(new SimpleCatalystConf(true), XDContext.xdConfig) //TODO Replace XDContext.xdConfig when refactoring CoreConfig + zookeeperStreamingCatalog = new ZookeeperStreamingCatalog( + new SimpleCatalystConf(true), + XDContext.xdConfig) //TODO Replace XDContext.xdConfig when refactoring CoreConfig } if (consumer == null) { @@ -97,19 +100,19 @@ class CrossdataStreamingHelperProjectIT extends BaseSparkStreamingXDTest with Co val producerPortKafka = kafkaTestUtils.brokerAddress.split(":").last val kafkaStreamModelZk = kafkaStreamModelProject.copy( - connection = connectionHostModel.copy( - zkConnection = 
Seq(ConnectionModel(consumerHostZK, consumerPortZK)), - kafkaConnection = Seq(ConnectionModel(producerHostKafka, producerPortKafka.toInt)))) + connection = connectionHostModel.copy( + zkConnection = Seq(ConnectionModel(consumerHostZK, consumerPortZK)), + kafkaConnection = Seq(ConnectionModel(producerHostKafka, producerPortKafka.toInt)))) val ephemeralTableKafka = ephemeralTableModelStreamKafkaOptionsProject.copy( - options = ephemeralOptionsStreamKafkaProject.copy(kafkaOptions = kafkaStreamModelZk - )) + options = ephemeralOptionsStreamKafkaProject.copy(kafkaOptions = kafkaStreamModelZk)) zookeeperStreamingCatalog.createEphemeralQuery(queryProjectedModel) zookeeperStreamingCatalog.createEphemeralTable(ephemeralTableKafka) zookeeperStreamingCatalog.getEphemeralTable(TableNameProject) match { case Some(ephemeralTable) => - ssc = CrossdataStreamingHelper.createContext(ephemeralTable, sparkConf, zookeeperConf, catalogConf) + ssc = CrossdataStreamingHelper + .createContext(ephemeralTable, sparkConf, zookeeperConf, catalogConf) val valuesToSent = Array("""{"name": "a"}""", """{"name": "c"}""") kafkaTestUtils.createTopic(TopicTestProject) kafkaTestUtils.sendMessages(TopicTestProject, valuesToSent) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/kafka/CrossdataStreamingHelperSelectIT.scala b/streaming/src/test/scala/org/apache/spark/streaming/kafka/CrossdataStreamingHelperSelectIT.scala index d2ae70453..ae740a268 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/kafka/CrossdataStreamingHelperSelectIT.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/kafka/CrossdataStreamingHelperSelectIT.scala @@ -33,29 +33,31 @@ import scala.language.postfixOps @RunWith(classOf[JUnitRunner]) class CrossdataStreamingHelperSelectIT extends BaseSparkStreamingXDTest with CommonValues { - private val sparkConf = new SparkConf().setMaster("local[2]").setAppName(this.getClass.getSimpleName) - private var sc : SparkContext = _ + private val sparkConf = + new SparkConf().setMaster("local[2]").setAppName(this.getClass.getSimpleName) + private var sc: SparkContext = _ private var kafkaTestUtils: KafkaTestUtils = _ private var zookeeperConf: Map[String, String] = _ private var catalogConf: Map[String, String] = _ - private var xDContext : XDContext = _ - private var zookeeperStreamingCatalog : ZookeeperStreamingCatalog = _ + private var xDContext: XDContext = _ + private var zookeeperStreamingCatalog: ZookeeperStreamingCatalog = _ private var consumer: ConsumerConnector = _ private var ssc: StreamingContext = _ override def beforeAll { sc = SparkContext.getOrCreate(sparkConf) - if (kafkaTestUtils == null){ + if (kafkaTestUtils == null) { kafkaTestUtils = new KafkaTestUtils kafkaTestUtils.setup() zookeeperConf = Map("connectionString" -> kafkaTestUtils.zkAddress) catalogConf = parseZookeeperCatalogConfig(zookeeperConf) xDContext = XDContext.getOrCreate(sc, parseCatalogConfig(catalogConf)) - zookeeperStreamingCatalog = new ZookeeperStreamingCatalog(new SimpleCatalystConf(true), XDContext.xdConfig) + zookeeperStreamingCatalog = + new ZookeeperStreamingCatalog(new SimpleCatalystConf(true), XDContext.xdConfig) } - if (consumer == null){ + if (consumer == null) { val props = new Properties() props.put("zookeeper.connect", kafkaTestUtils.zkAddress) props.put("group.id", GroupId) @@ -71,7 +73,7 @@ class CrossdataStreamingHelperSelectIT extends BaseSparkStreamingXDTest with Com consumer = null } - if(ssc != null){ + if (ssc != null) { ssc.stop(stopSparkContext = true, stopGracefully = 
false) ssc.awaitTerminationOrTimeout(6000) ssc = null @@ -83,7 +85,7 @@ class CrossdataStreamingHelperSelectIT extends BaseSparkStreamingXDTest with Com } } -/* + /* test("Crossdata streaming must save into the kafka output the sql results") { deletePath(checkpointDirectorySelect) @@ -131,5 +133,5 @@ class CrossdataStreamingHelperSelectIT extends BaseSparkStreamingXDTest with Com case None => throw new Exception("Ephemeral table not created") } } -*/ + */ } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamIT.scala b/streaming/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamIT.scala index 6eb9daeae..b073a1a4b 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamIT.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamIT.scala @@ -33,7 +33,7 @@ class KafkaStreamIT extends BaseSparkStreamingXDTest with CommonValues { val sparkConf = new SparkConf().setMaster("local[2]").setAppName(this.getClass.getSimpleName) val sc = SparkContext.getOrCreate(sparkConf) var ssc: StreamingContext = _ - val kafkaTestUtils: KafkaTestUtils = new KafkaTestUtils + val kafkaTestUtils: KafkaTestUtils = new KafkaTestUtils kafkaTestUtils.setup() after { @@ -44,7 +44,7 @@ class KafkaStreamIT extends BaseSparkStreamingXDTest with CommonValues { } } - override def afterAll : Unit = { + override def afterAll: Unit = { kafkaTestUtils.teardown() } @@ -61,9 +61,9 @@ class KafkaStreamIT extends BaseSparkStreamingXDTest with CommonValues { val producerPortKafka = kafkaTestUtils.brokerAddress.split(":").last val kafkaStreamModelZk = kafkaStreamModel.copy( - connection = connectionHostModel.copy( - zkConnection = Seq(ConnectionModel(consumerHostZK, consumerPortZK)), - kafkaConnection = Seq(ConnectionModel(producerHostKafka, producerPortKafka.toInt)))) + connection = connectionHostModel.copy( + zkConnection = Seq(ConnectionModel(consumerHostZK, consumerPortZK)), + kafkaConnection = Seq(ConnectionModel(producerHostKafka, producerPortKafka.toInt)))) val input = new KafkaInput(kafkaStreamModelZk) val stream = input.createStream(ssc) @@ -71,9 +71,10 @@ class KafkaStreamIT extends BaseSparkStreamingXDTest with CommonValues { stream.map(_._2).countByValue().foreachRDD { rdd => val ret = rdd.collect() - ret.toMap.foreach { case (key, value) => - val count = result.getOrElseUpdate(key, 0) + value - result.put(key, count) + ret.toMap.foreach { + case (key, value) => + val count = result.getOrElseUpdate(key, 0) + value + result.put(key, count) } } @@ -97,9 +98,9 @@ class KafkaStreamIT extends BaseSparkStreamingXDTest with CommonValues { val producerPortKafka = kafkaTestUtils.brokerAddress.split(":").last val kafkaStreamModelZk = kafkaStreamModelProject.copy( - connection = connectionHostModel.copy( - zkConnection = Seq(ConnectionModel(consumerHostZK, consumerPortZK)), - kafkaConnection = Seq(ConnectionModel(producerHostKafka, producerPortKafka.toInt)))) + connection = connectionHostModel.copy( + zkConnection = Seq(ConnectionModel(consumerHostZK, consumerPortZK)), + kafkaConnection = Seq(ConnectionModel(producerHostKafka, producerPortKafka.toInt)))) val input = new KafkaInput(kafkaStreamModelZk) val stream = input.createStream(ssc) diff --git a/testsIT/src/test/scala/com/stratio/crossdata/driver/DriverDdlIT.scala b/testsIT/src/test/scala/com/stratio/crossdata/driver/DriverDdlIT.scala index 09228eab3..64c665e16 100644 --- a/testsIT/src/test/scala/com/stratio/crossdata/driver/DriverDdlIT.scala +++ 
b/testsIT/src/test/scala/com/stratio/crossdata/driver/DriverDdlIT.scala @@ -33,24 +33,28 @@ class DriverDdlIT extends MongoWithSharedContext { withDriverDo { driver => val mongoImportOptions = Map( - MongodbConfig.Host -> s"$MongoHost:$MongoPort", - MongodbConfig.Database -> Database, - MongodbConfig.Collection -> Collection + MongodbConfig.Host -> s"$MongoHost:$MongoPort", + MongodbConfig.Database -> Database, + MongodbConfig.Collection -> Collection ) - driver.importTables("mongodb", mongoImportOptions).resultSet.head.getSeq(0) shouldBe Seq(Database, Collection) + driver.importTables("mongodb", mongoImportOptions).resultSet.head.getSeq(0) shouldBe Seq( + Database, + Collection) } } it should "allow to create tables" in { withDriverDo { driver => - - val crtTableResult = driver.createTable( - name = "crtTable", - dataSourceProvider = "org.apache.spark.sql.json", - schema = None, - options = Map("path" -> Paths.get(getClass.getResource("/tabletest.json").toURI).toString), - isTemporary = true).resultSet + val crtTableResult = driver + .createTable( + name = "crtTable", + dataSourceProvider = "org.apache.spark.sql.json", + schema = None, + options = + Map("path" -> Paths.get(getClass.getResource("/tabletest.json").toURI).toString), + isTemporary = true) + .resultSet driver.listTables() should contain("crtTable", None) } @@ -58,28 +62,34 @@ class DriverDdlIT extends MongoWithSharedContext { it should "allow to drop tables" in { withDriverDo { driver => - - driver.sql( - s"CREATE TEMPORARY TABLE jsonTable3 USING org.apache.spark.sql.json OPTIONS (path '${Paths.get(getClass.getResource("/tabletest.json").toURI).toString}')" - ).waitForResult() + driver + .sql( + s"CREATE TEMPORARY TABLE jsonTable3 USING org.apache.spark.sql.json OPTIONS (path '${Paths + .get(getClass.getResource("/tabletest.json").toURI) + .toString}')" + ) + .waitForResult() driver.dropTable("jsonTable3").waitForResult() - driver.listTables() should not contain("jsonTable3", None) + driver.listTables() should not contain ("jsonTable3", None) } } it should "allow to drop all tables" in { withDriverDo { driver => - - driver.sql( - s"CREATE TEMPORARY TABLE jsonTable3 USING org.apache.spark.sql.json OPTIONS (path '${Paths.get(getClass.getResource("/tabletest.json").toURI).toString}')" - ).waitForResult() + driver + .sql( + s"CREATE TEMPORARY TABLE jsonTable3 USING org.apache.spark.sql.json OPTIONS (path '${Paths + .get(getClass.getResource("/tabletest.json").toURI) + .toString}')" + ) + .waitForResult() driver.dropAllTables().waitForResult() - driver.listTables() should not contain("jsonTable3", None) + driver.listTables() should not contain ("jsonTable3", None) } } -} \ No newline at end of file +} diff --git a/testsIT/src/test/scala/com/stratio/crossdata/driver/DriverIT.scala b/testsIT/src/test/scala/com/stratio/crossdata/driver/DriverIT.scala index 03c630eb9..0a1692567 100644 --- a/testsIT/src/test/scala/com/stratio/crossdata/driver/DriverIT.scala +++ b/testsIT/src/test/scala/com/stratio/crossdata/driver/DriverIT.scala @@ -34,20 +34,26 @@ class DriverIT extends EndToEndTest { assumeCrossdataUpAndRunning() withDriverDo { driver => - val result = driver.sql("select select").waitForResult(10 seconds) result shouldBe an[ErrorSQLResult] result.asInstanceOf[ErrorSQLResult].cause.isDefined shouldBe (true) result.asInstanceOf[ErrorSQLResult].cause.get shouldBe a[Exception] - result.asInstanceOf[ErrorSQLResult].cause.get.getMessage should include regex "cannot resolve .*" + result + .asInstanceOf[ErrorSQLResult] + .cause + .get + 
.getMessage should include regex "cannot resolve .*" } } it should "return a SuccessfulQueryResult when executing a select *" in { assumeCrossdataUpAndRunning() withDriverDo { driver => - - driver.sql(s"CREATE TEMPORARY TABLE jsonTable USING org.apache.spark.sql.json OPTIONS (path '${Paths.get(getClass.getResource("/tabletest.json").toURI).toString}')").waitForResult() + driver + .sql(s"CREATE TEMPORARY TABLE jsonTable USING org.apache.spark.sql.json OPTIONS (path '${Paths + .get(getClass.getResource("/tabletest.json").toURI) + .toString}')") + .waitForResult() val result = driver.sql("SELECT * FROM jsonTable").waitForResult() result shouldBe an[SuccessfulSQLResult] @@ -61,31 +67,30 @@ class DriverIT extends EndToEndTest { it should "get a list of tables" in { assumeCrossdataUpAndRunning withDriverDo { driver => - - driver.sql( - s"CREATE TABLE db.jsonTable2 USING org.apache.spark.sql.json OPTIONS (path '${Paths.get(getClass.getResource("/tabletest.json").toURI).toString}')" - ).waitForResult() - - driver.sql( - s"CREATE TABLE jsonTable2 USING org.apache.spark.sql.json OPTIONS (path '${Paths.get(getClass.getResource("/tabletest.json").toURI).toString}')" - ).waitForResult() - - driver.listTables() should contain allOf(("jsonTable2", Some("db")), ("jsonTable2", None)) + driver + .sql( + s"CREATE TABLE db.jsonTable2 USING org.apache.spark.sql.json OPTIONS (path '${Paths.get(getClass.getResource("/tabletest.json").toURI).toString}')" + ) + .waitForResult() + + driver + .sql( + s"CREATE TABLE jsonTable2 USING org.apache.spark.sql.json OPTIONS (path '${Paths.get(getClass.getResource("/tabletest.json").toURI).toString}')" + ) + .waitForResult() + + driver.listTables() should contain allOf (("jsonTable2", Some("db")), ("jsonTable2", None)) } } - - it should "indicates that the cluster is alive when there is a server up" in { withDriverDo { driver => - driver.isClusterAlive(6 seconds) shouldBe true } } it should "return the addresses of servers up and running" in { withDriverDo { driver => - val addresses = Await.result(driver.serversUp(), 6 seconds) addresses should have length 1 @@ -95,7 +100,6 @@ class DriverIT extends EndToEndTest { it should "return the current cluster state" in { withDriverDo { driver => - val clusterState = Await.result(driver.clusterState(), 6 seconds) clusterState.getLeader.host shouldBe Some("127.0.0.1") @@ -128,7 +132,9 @@ class DriverIT extends EndToEndTest { val filePath = getClass.getResource("/TestAddApp.jar").getPath withDriverDo { driver => - val result = driver.addAppCommand(filePath, "com.stratio.addApp.AddAppTest.main", Some("testApp")).waitForResult() + val result = driver + .addAppCommand(filePath, "com.stratio.addApp.AddAppTest.main", Some("testApp")) + .waitForResult() driver.sql("EXECUTE testApp(rain,bow)").waitForResult() result.hasError should equal(false) } @@ -141,10 +147,14 @@ class DriverIT extends EndToEndTest { val filePath = getClass.getResource("/TestAddApp.jar").getPath withDriverDo { driver => - val addAppResult = driver.addAppCommand(filePath, "com.stratio.addApp.AddAppTest.main", Some("testApp")).waitForResult() + val addAppResult = driver + .addAppCommand(filePath, "com.stratio.addApp.AddAppTest.main", Some("testApp")) + .waitForResult() addAppResult.hasError should equal(false) - val executeResult = driver.sql("""EXECUTE testApp(rain,bow2) OPTIONS (executor.memory '20G')""").waitForResult() + val executeResult = driver + .sql("""EXECUTE testApp(rain,bow2) OPTIONS (executor.memory '20G')""") + .waitForResult() executeResult.hasError should 
equal(false) executeResult.resultSet.length should equal(1) @@ -153,7 +163,6 @@ class DriverIT extends EndToEndTest { } } - it should "allow running multiple drivers per JVM" in { val driverTable = "drvtable" @@ -164,12 +173,24 @@ class DriverIT extends EndToEndTest { driver shouldNot be theSameInstanceAs anotherDriver driver.listTables().size shouldBe anotherDriver.listTables().size - driver.sql(s"CREATE TEMPORARY TABLE $driverTable USING org.apache.spark.sql.json OPTIONS (path '${Paths.get(getClass.getResource("/tabletest.json").toURI).toString}')").waitForResult() + driver + .sql( + s"CREATE TEMPORARY TABLE $driverTable USING org.apache.spark.sql.json OPTIONS (path '${Paths + .get(getClass.getResource("/tabletest.json").toURI) + .toString}')") + .waitForResult() driver.sql(s"SELECT * FROM $driverTable").waitForResult().resultSet should not be empty anotherDriver.sql(s"SELECT * FROM $driverTable").waitForResult().hasError shouldBe true - anotherDriver.sql(s"CREATE TEMPORARY TABLE $anotherDriverTable USING org.apache.spark.sql.json OPTIONS (path '${Paths.get(getClass.getResource("/tabletest.json").toURI).toString}')").waitForResult() - anotherDriver.sql(s"SELECT * FROM $anotherDriverTable").waitForResult().resultSet should not be empty + anotherDriver + .sql(s"CREATE TEMPORARY TABLE $anotherDriverTable USING org.apache.spark.sql.json OPTIONS (path '${Paths + .get(getClass.getResource("/tabletest.json").toURI) + .toString}')") + .waitForResult() + anotherDriver + .sql(s"SELECT * FROM $anotherDriverTable") + .waitForResult() + .resultSet should not be empty driver.sql(s"SELECT * FROM $anotherDriverTable").waitForResult().hasError shouldBe true } @@ -179,4 +200,4 @@ class DriverIT extends EndToEndTest { } } -} \ No newline at end of file +} diff --git a/testsIT/src/test/scala/com/stratio/crossdata/driver/DriverStandaloneIT.scala b/testsIT/src/test/scala/com/stratio/crossdata/driver/DriverStandaloneIT.scala index 938b77b0c..8949fd903 100644 --- a/testsIT/src/test/scala/com/stratio/crossdata/driver/DriverStandaloneIT.scala +++ b/testsIT/src/test/scala/com/stratio/crossdata/driver/DriverStandaloneIT.scala @@ -21,7 +21,7 @@ class DriverStandaloneIT extends BaseXDTest { "Crossdata driver" should "fail with a timeout when there is no server" in { - the [RuntimeException] thrownBy { + the[RuntimeException] thrownBy { Driver.newSession() } should have message s"Cannot establish connection to XDServer: timed out after ${Driver.InitializationTimeout}" diff --git a/testsIT/src/test/scala/com/stratio/crossdata/driver/EndToEndTest.scala b/testsIT/src/test/scala/com/stratio/crossdata/driver/EndToEndTest.scala index 7de74e9e1..a51313ced 100644 --- a/testsIT/src/test/scala/com/stratio/crossdata/driver/EndToEndTest.scala +++ b/testsIT/src/test/scala/com/stratio/crossdata/driver/EndToEndTest.scala @@ -44,7 +44,6 @@ trait EndToEndTest extends BaseXDTest with BeforeAndAfterAll { crossdataServer.foreach(_.destroy()) } - override protected def beforeAll(): Unit = { init() } @@ -56,4 +55,4 @@ trait EndToEndTest extends BaseXDTest with BeforeAndAfterAll { def assumeCrossdataUpAndRunning() = { assume(crossdataServer.isDefined, "Crossdata server is not up and running") } -} \ No newline at end of file +} diff --git a/testsIT/src/test/scala/com/stratio/crossdata/driver/FlattenedTablesIT.scala b/testsIT/src/test/scala/com/stratio/crossdata/driver/FlattenedTablesIT.scala index 7248c4c4d..4a18bf545 100644 --- a/testsIT/src/test/scala/com/stratio/crossdata/driver/FlattenedTablesIT.scala +++ 
b/testsIT/src/test/scala/com/stratio/crossdata/driver/FlattenedTablesIT.scala @@ -25,13 +25,13 @@ import org.scalatest.junit.JUnitRunner @RunWith(classOf[JUnitRunner]) class FlattenedTablesIT extends MongoWithSharedContext { - implicit val configWithFlattening: Option[DriverConf] = Some(new DriverConf().setFlattenTables(true)) + implicit val configWithFlattening: Option[DriverConf] = Some( + new DriverConf().setFlattenTables(true)) "The Driver" should " List table's description with nested and array fields flattened" in { assumeCrossdataUpAndRunning withDriverDo { flattenedDriver => - //Experimentation val result: Seq[FieldMetadata] = flattenedDriver.describeTable(Some(Database), Collection) @@ -48,29 +48,32 @@ class FlattenedTablesIT extends MongoWithSharedContext { assumeCrossdataUpAndRunning withDriverDo { flattenedDriver => - //Experimentation val result: Seq[FieldMetadata] = flattenedDriver.describeTable(Some(Database), Collection) //Expectations - val addressType = StructType(Seq(StructField("street", StringType), StructField("city", StringType), StructField("zip", IntegerType))) - val detailAccount = StructType(Seq(StructField("bank", StringType), StructField("office", IntegerType))) - val accountType = StructType(Seq(StructField("number", IntegerType), StructField("details", detailAccount))) + val addressType = StructType( + Seq(StructField("street", StringType), + StructField("city", StringType), + StructField("zip", IntegerType))) + val detailAccount = + StructType(Seq(StructField("bank", StringType), StructField("office", IntegerType))) + val accountType = + StructType(Seq(StructField("number", IntegerType), StructField("details", detailAccount))) result should contain(new FieldMetadata("address", addressType)) result should contain(new FieldMetadata("account", accountType)) - } (Some(new DriverConf().setFlattenTables(false))) + }(Some(new DriverConf().setFlattenTables(false))) } - it should " Query with Flattened Fields" in { assumeCrossdataUpAndRunning withDriverDo { flattenedDriver => - //Experimentation - val result = flattenedDriver.sql(s"SELECT address.street from $Database.$Collection").resultSet + val result = + flattenedDriver.sql(s"SELECT address.street from $Database.$Collection").resultSet //Expectations result.head.toSeq(0).toString should fullyMatch regex "[0-9]+th Avenue" @@ -82,9 +85,10 @@ class FlattenedTablesIT extends MongoWithSharedContext { assumeCrossdataUpAndRunning withDriverDo { flattenedDriver => - //Experimentation - val result = flattenedDriver.sql(s"SELECT description FROM $Database.$Collection WHERE address.street = '5th Avenue'").resultSet + val result = flattenedDriver + .sql(s"SELECT description FROM $Database.$Collection WHERE address.street = '5th Avenue'") + .resultSet //Expectations result.head.toSeq(0).toString should be equals "description5" @@ -93,5 +97,3 @@ class FlattenedTablesIT extends MongoWithSharedContext { } } - - diff --git a/testsIT/src/test/scala/com/stratio/crossdata/driver/JavaDriverIT.scala b/testsIT/src/test/scala/com/stratio/crossdata/driver/JavaDriverIT.scala index 3e4c579be..09f851baa 100644 --- a/testsIT/src/test/scala/com/stratio/crossdata/driver/JavaDriverIT.scala +++ b/testsIT/src/test/scala/com/stratio/crossdata/driver/JavaDriverIT.scala @@ -24,22 +24,23 @@ import org.junit.runner.RunWith import org.scalatest.junit.JUnitRunner @RunWith(classOf[JUnitRunner]) -class JavaDriverIT extends EndToEndTest{ - +class JavaDriverIT extends EndToEndTest { "JavaDriver (with default options)" should "get a list of tables" in { 
assumeCrossdataUpAndRunning() withJavaDriverDo { javaDriver => - javaDriver.sql( - s"CREATE TABLE db.jsonTable3 USING org.apache.spark.sql.json OPTIONS (path '${Paths.get(getClass.getResource("/tabletest.json").toURI()).toString}')" + s"CREATE TABLE db.jsonTable3 USING org.apache.spark.sql.json OPTIONS (path '${Paths.get(getClass.getResource("/tabletest.json").toURI()).toString}')" ) javaDriver.sql( - s"CREATE TABLE jsonTable3 USING org.apache.spark.sql.json OPTIONS (path '${Paths.get(getClass.getResource("/tabletest.json").toURI()).toString}')" + s"CREATE TABLE jsonTable3 USING org.apache.spark.sql.json OPTIONS (path '${Paths.get(getClass.getResource("/tabletest.json").toURI()).toString}')" ) - javaDriver.listTables() should contain allOf(new JavaTableName("jsonTable3", "db"), new JavaTableName("jsonTable3", "")) + javaDriver + .listTables() should contain allOf (new JavaTableName("jsonTable3", "db"), new JavaTableName( + "jsonTable3", + "")) } } @@ -48,17 +49,18 @@ class JavaDriverIT extends EndToEndTest{ assumeCrossdataUpAndRunning() withJavaDriverDo { javaDriver => - javaDriver.sql( - s"CREATE TABLE db.jsonTable3 USING org.apache.spark.sql.json OPTIONS (path '${Paths.get(getClass.getResource("/tabletest.json").toURI()).toString}')" + s"CREATE TABLE db.jsonTable3 USING org.apache.spark.sql.json OPTIONS (path '${Paths.get(getClass.getResource("/tabletest.json").toURI()).toString}')" ) javaDriver.sql( - s"CREATE TABLE jsonTable3 USING org.apache.spark.sql.json OPTIONS (path '${Paths.get(getClass.getResource("/tabletest.json").toURI()).toString}')" + s"CREATE TABLE jsonTable3 USING org.apache.spark.sql.json OPTIONS (path '${Paths.get(getClass.getResource("/tabletest.json").toURI()).toString}')" ) - javaDriver.listTables() should contain allOf(new JavaTableName("jsonTable3", "db"), new JavaTableName("jsonTable3", "")) - } (Some(new DriverConf().setFlattenTables(true))) - + javaDriver + .listTables() should contain allOf (new JavaTableName("jsonTable3", "db"), new JavaTableName( + "jsonTable3", + "")) + }(Some(new DriverConf().setFlattenTables(true))) } } diff --git a/testsIT/src/test/scala/com/stratio/crossdata/driver/MongoWithSharedContext.scala b/testsIT/src/test/scala/com/stratio/crossdata/driver/MongoWithSharedContext.scala index ca0c0db75..9fd2c81e7 100644 --- a/testsIT/src/test/scala/com/stratio/crossdata/driver/MongoWithSharedContext.scala +++ b/testsIT/src/test/scala/com/stratio/crossdata/driver/MongoWithSharedContext.scala @@ -41,20 +41,24 @@ class MongoWithSharedContext extends BaseXDTest with MongoConstants with BeforeA val collection = client(Database)(Collection) for (a <- 1 to 10) { collection.insert { - MongoDBObject("id" -> a, - "age" -> (10 + a), - "description" -> s"description$a", - "enrolled" -> (a % 2 == 0), - "name" -> s"Name $a", - "address" -> MongoDBObject("street" -> s"${a}th Avenue", "city" -> s"City $a", "zip" -> (28000+a)), - "account" -> MongoDBObject("number" -> (11235813*a), "details" -> MongoDBObject("bank" -> "Mercantil", "office" -> (12357+a))), - "grades" -> Seq(MongoDBObject("FP" -> Seq(7.0, 8.0)), MongoDBObject("REACTIVEARCHS" -> Seq(9.0))) - ) + MongoDBObject( + "id" -> a, + "age" -> (10 + a), + "description" -> s"description$a", + "enrolled" -> (a % 2 == 0), + "name" -> s"Name $a", + "address" -> MongoDBObject("street" -> s"${a}th Avenue", + "city" -> s"City $a", + "zip" -> (28000 + a)), + "account" -> MongoDBObject("number" -> (11235813 * a), + "details" -> MongoDBObject("bank" -> "Mercantil", + "office" -> (12357 + a))), + "grades" -> 
Seq(MongoDBObject("FP" -> Seq(7.0, 8.0)), + MongoDBObject("REACTIVEARCHS" -> Seq(9.0)))) } } } - protected def cleanTestData: Unit = { val client = this.client @@ -63,8 +67,6 @@ class MongoWithSharedContext extends BaseXDTest with MongoConstants with BeforeA } - - def init() = { crossdataServer = Some(new CrossdataServer) crossdataServer.foreach(_.init(null)) @@ -80,13 +82,11 @@ class MongoWithSharedContext extends BaseXDTest with MongoConstants with BeforeA crossdataServer.foreach(_.destroy()) } - override protected def beforeAll(): Unit = { init() saveTestData - val importQuery = - s""" + val importQuery = s""" |IMPORT TABLES |USING $SourceProvider |OPTIONS ( @@ -97,7 +97,8 @@ class MongoWithSharedContext extends BaseXDTest with MongoConstants with BeforeA |) """.stripMargin - crossdataServer.foreach(_.sessionProviderOpt.foreach(_.session(SessionID).get.sql(importQuery))) + crossdataServer.foreach( + _.sessionProviderOpt.foreach(_.session(SessionID).get.sql(importQuery))) } override protected def afterAll(): Unit = { @@ -122,4 +123,4 @@ sealed trait MongoConstants { val MongoPort = 27017 val SourceProvider = "com.stratio.crossdata.connector.mongodb" -} \ No newline at end of file +} diff --git a/testsIT/src/test/scala/com/stratio/crossdata/driver/config/DriverConfSpec.scala b/testsIT/src/test/scala/com/stratio/crossdata/driver/config/DriverConfSpec.scala index 09799b285..5be46573b 100644 --- a/testsIT/src/test/scala/com/stratio/crossdata/driver/config/DriverConfSpec.scala +++ b/testsIT/src/test/scala/com/stratio/crossdata/driver/config/DriverConfSpec.scala @@ -21,8 +21,7 @@ import org.junit.runner.RunWith import org.scalatest.junit.JUnitRunner @RunWith(classOf[JUnitRunner]) -class DriverConfSpec extends BaseXDTest{ - +class DriverConfSpec extends BaseXDTest { "DriverConf" should "load default config" in { val conf = new DriverConf() @@ -39,7 +38,10 @@ class DriverConfSpec extends BaseXDTest{ } it should "allow to set common properties" in { - val conf = new DriverConf().setTunnelTimeout(10).setClusterContactPoint("1.1.1.1:1000", "2.2.2.2:2000").setFlattenTables(true) + val conf = new DriverConf() + .setTunnelTimeout(10) + .setClusterContactPoint("1.1.1.1:1000", "2.2.2.2:2000") + .setFlattenTables(true) conf.getFlattenTables shouldBe true conf.getClusterContactPoint should have length 2 diff --git a/testsIT/src/test/scala/com/stratio/crossdata/driver/globalindex/CreateGlobalIndexIT.scala b/testsIT/src/test/scala/com/stratio/crossdata/driver/globalindex/CreateGlobalIndexIT.scala index cd779c98e..7d4fa137d 100644 --- a/testsIT/src/test/scala/com/stratio/crossdata/driver/globalindex/CreateGlobalIndexIT.scala +++ b/testsIT/src/test/scala/com/stratio/crossdata/driver/globalindex/CreateGlobalIndexIT.scala @@ -50,20 +50,16 @@ class CreateGlobalIndexIT extends MongoAndElasticWithSharedContext { mongoClient(mongoTestDatabase).dropDatabase() - elasticClient.execute{ + elasticClient.execute { deleteIndex(defaultIndexES) }.await - super.afterAll() } - "Create global index" should "create an associated index in Elasticsearch" in { - - val sentence = - s"""|CREATE GLOBAL INDEX $indexName + val sentence = s"""|CREATE GLOBAL INDEX $indexName |ON globalIndexDb.proofGlobalIndex (other) |WITH PK id |USING com.stratio.crossdata.connector.elasticsearch @@ -76,7 +72,7 @@ class CreateGlobalIndexIT extends MongoAndElasticWithSharedContext { sql(sentence) - val typeExistResponse = elasticClient.execute{ + val typeExistResponse = elasticClient.execute { typesExist(indexName).in(defaultIndexES) }.await @@ -86,7 
+82,6 @@ class CreateGlobalIndexIT extends MongoAndElasticWithSharedContext { } - it should "fail if the target table is temporary" in { val tempTableId = "tempTable" @@ -103,8 +98,7 @@ class CreateGlobalIndexIT extends MongoAndElasticWithSharedContext { sql(createTable1) - val sentence = - s"""|CREATE GLOBAL INDEX fail_index + val sentence = s"""|CREATE GLOBAL INDEX fail_index |ON $tempTableId (other) |WITH PK id |USING com.stratio.crossdata.connector.elasticsearch @@ -115,9 +109,9 @@ class CreateGlobalIndexIT extends MongoAndElasticWithSharedContext { | es.cluster '$ElasticClusterName' |)""".stripMargin - the [RuntimeException] thrownBy { + the[RuntimeException] thrownBy { sql(sentence) } should have message s"Cannot create the index. Table `$tempTableId` doesn't exist or is temporary" } -} \ No newline at end of file +} diff --git a/testsIT/src/test/scala/com/stratio/crossdata/driver/globalindex/InsertGlobalIndexIT.scala b/testsIT/src/test/scala/com/stratio/crossdata/driver/globalindex/InsertGlobalIndexIT.scala index e8ea80c8e..9d0d3d17a 100644 --- a/testsIT/src/test/scala/com/stratio/crossdata/driver/globalindex/InsertGlobalIndexIT.scala +++ b/testsIT/src/test/scala/com/stratio/crossdata/driver/globalindex/InsertGlobalIndexIT.scala @@ -51,19 +51,16 @@ class InsertGlobalIndexIT extends MongoAndElasticWithSharedContext { mongoClient(mongoDatabase).dropDatabase() - elasticClient.execute{ + elasticClient.execute { deleteIndex(defaultIndexES) }.await - super.afterAll() } - "Insertion" should "insert indexed columns into elasticsearch index" in { - val sentence = - s"""|CREATE GLOBAL INDEX $indexName + val sentence = s"""|CREATE GLOBAL INDEX $indexName |ON $mongoTableIdentifier (other) |WITH PK id |USING com.stratio.crossdata.connector.elasticsearch @@ -77,23 +74,26 @@ class InsertGlobalIndexIT extends MongoAndElasticWithSharedContext { sql(sentence) - val typeExistResponse = elasticClient.execute{ + val typeExistResponse = elasticClient.execute { typesExist(indexName).in(defaultIndexES) }.await typeExistResponse.isExists shouldBe true - sql(s"INSERT INTO $mongoTableIdentifier VALUES ( 50, 'Samantha', 'Fox', 4),( 1, 'Charlie', 'Green', 5)") - elasticClient.execute{ + elasticClient.execute { flushIndex(defaultIndexES) }.await mongoClient(mongoDatabase)(mongoCollection).count() shouldBe 2 - elasticClient.execute(search in defaultIndexES / indexName).await.getHits.totalHits() shouldBe 2 + elasticClient + .execute(search in defaultIndexES / indexName) + .await + .getHits + .totalHits() shouldBe 2 } -} \ No newline at end of file +} diff --git a/testsIT/src/test/scala/com/stratio/crossdata/driver/globalindex/MongoAndElasticWithSharedContext.scala b/testsIT/src/test/scala/com/stratio/crossdata/driver/globalindex/MongoAndElasticWithSharedContext.scala index 762699963..df49248ff 100644 --- a/testsIT/src/test/scala/com/stratio/crossdata/driver/globalindex/MongoAndElasticWithSharedContext.scala +++ b/testsIT/src/test/scala/com/stratio/crossdata/driver/globalindex/MongoAndElasticWithSharedContext.scala @@ -25,14 +25,20 @@ import org.scalatest.BeforeAndAfterAll import scala.util.Try -class MongoAndElasticWithSharedContext extends SharedXDContextTest with Constants with BeforeAndAfterAll with SparkLoggerComponent { +class MongoAndElasticWithSharedContext + extends SharedXDContextTest + with Constants + with BeforeAndAfterAll + with SparkLoggerComponent { lazy val mongoClient: MongoClient = MongoClient(MongoHost, MongoPort) lazy val elasticClient: ElasticClient = Try { - logInfo(s"Connection to elastic 
search, ElasticHost: $ElasticHost, ElasticNativePort:$ElasticNativePort, ElasticClusterName $ElasticClusterName") + logInfo( + s"Connection to elastic search, ElasticHost: $ElasticHost, ElasticNativePort:$ElasticNativePort, ElasticClusterName $ElasticClusterName") val settings = Settings.settingsBuilder().put("cluster.name", ElasticClusterName).build() - val elasticClient = ElasticClient.transport(settings, ElasticsearchClientUri(ElasticHost, ElasticNativePort)) + val elasticClient = + ElasticClient.transport(settings, ElasticsearchClientUri(ElasticHost, ElasticNativePort)) elasticClient } get @@ -42,7 +48,6 @@ class MongoAndElasticWithSharedContext extends SharedXDContextTest with Constant super.afterAll() } - } sealed trait Constants { @@ -58,11 +63,12 @@ sealed trait Constants { val MongoSourceProvider = "com.stratio.crossdata.connector.mongodb" //Elastic - val ElasticHost: String = Try(config.getStringList("elasticsearch.hosts")).map(_.get(0)).getOrElse("127.0.0.1") + val ElasticHost: String = + Try(config.getStringList("elasticsearch.hosts")).map(_.get(0)).getOrElse("127.0.0.1") val ElasticRestPort = 9200 val ElasticNativePort = 9300 val ElasticSourceProvider = "com.stratio.crossdata.connector.elasticsearch" - val ElasticClusterName: String = Try(config.getString("elasticsearch.cluster")).getOrElse("esCluster") - + val ElasticClusterName: String = + Try(config.getString("elasticsearch.cluster")).getOrElse("esCluster") -} \ No newline at end of file +} diff --git a/testsIT/src/test/scala/com/stratio/crossdata/driver/globalindex/MongoCreateGlobalIndexIT.scala b/testsIT/src/test/scala/com/stratio/crossdata/driver/globalindex/MongoCreateGlobalIndexIT.scala index 230fc38e0..89d5f893d 100644 --- a/testsIT/src/test/scala/com/stratio/crossdata/driver/globalindex/MongoCreateGlobalIndexIT.scala +++ b/testsIT/src/test/scala/com/stratio/crossdata/driver/globalindex/MongoCreateGlobalIndexIT.scala @@ -21,14 +21,12 @@ import org.apache.spark.sql.Row import org.junit.runner.RunWith import org.scalatest.junit.JUnitRunner - @RunWith(classOf[JUnitRunner]) class MongoCreateGlobalIndexIT extends MongoAndElasticWithSharedContext { val mongoTestDatabase = "globalIndexDb" val defaultIndexES = "gidx" - protected override def beforeAll(): Unit = { super.beforeAll() @@ -45,15 +43,22 @@ class MongoCreateGlobalIndexIT extends MongoAndElasticWithSharedContext { sql(createTable1) mongoClient(mongoTestDatabase)("proofGlobalIndex").insert( - MongoDBObject("id" -> 11, "name" -> "prueba", "comments" -> "one comment", "other" -> 12, "another" -> 12) + MongoDBObject("id" -> 11, + "name" -> "prueba", + "comments" -> "one comment", + "other" -> 12, + "another" -> 12) ) mongoClient(mongoTestDatabase)("proofGlobalIndex").insert( - MongoDBObject("id" -> 13, "name" -> "prueba2", "comments" -> "one comment fail", "other" -> 5, "another" -> 12) + MongoDBObject("id" -> 13, + "name" -> "prueba2", + "comments" -> "one comment fail", + "other" -> 5, + "another" -> 12) ) - val sentence = - s"""|CREATE GLOBAL INDEX myIndex + val sentence = s"""|CREATE GLOBAL INDEX myIndex |ON globalIndexDb.proofGlobalIndex (other, another) |WITH PK id |USING com.stratio.crossdata.connector.elasticsearch @@ -68,17 +73,15 @@ class MongoCreateGlobalIndexIT extends MongoAndElasticWithSharedContext { sql(sentence) elasticClient.execute { - index into "gidx" / "myIndex" fields( - "id" -> 11, - "another" -> 12, - "other"-> 12) + index into "gidx" / "myIndex" fields ("id" -> 11, + "another" -> 12, + "other" -> 12) }.await elasticClient.execute { - index into 
"gidx" / "myIndex" fields( - "id" -> 13, - "another" -> 12, - "other"-> 5) + index into "gidx" / "myIndex" fields ("id" -> 13, + "another" -> 12, + "other" -> 5) }.await elasticClient.execute { @@ -91,7 +94,7 @@ class MongoCreateGlobalIndexIT extends MongoAndElasticWithSharedContext { mongoClient(mongoTestDatabase).dropDatabase() - elasticClient.execute{ + elasticClient.execute { deleteIndex(defaultIndexES) }.await @@ -128,7 +131,12 @@ class MongoCreateGlobalIndexIT extends MongoAndElasticWithSharedContext { } it should "execute a select col where indexedFilter equals to using multiple projects via DDL" in { - val result = xdContext.table("globalIndexDb.proofGlobalIndex").select("name", "other").where($"other" equalTo 5).select("name").collect() + val result = xdContext + .table("globalIndexDb.proofGlobalIndex") + .select("name", "other") + .where($"other" equalTo 5) + .select("name") + .collect() result should have length 1 result shouldBe Array(Row("prueba2")) } @@ -141,17 +149,20 @@ class MongoCreateGlobalIndexIT extends MongoAndElasticWithSharedContext { } it should "support filters mixed with indexedCols" in { - val result = sql(s"select name from globalIndexDb.proofGlobalIndex WHERE other > 10 AND name LIKE '%prueba%'").collect() + val result = sql( + s"select name from globalIndexDb.proofGlobalIndex WHERE other > 10 AND name LIKE '%prueba%'") + .collect() result should have length 1 result shouldBe Array(Row("prueba")) } it should "support filters using equals in two indexed columns" in { - val result = sql(s"select name from globalIndexDb.proofGlobalIndex WHERE other = another").collect() + val result = + sql(s"select name from globalIndexDb.proofGlobalIndex WHERE other = another").collect() result should have length 1 result shouldBe Array(Row("prueba")) } -} \ No newline at end of file +} diff --git a/testsIT/src/test/scala/com/stratio/crossdata/driver/ignore/DriverIT.scala b/testsIT/src/test/scala/com/stratio/crossdata/driver/ignore/DriverIT.scala index fe40c6654..10fbd68a2 100644 --- a/testsIT/src/test/scala/com/stratio/crossdata/driver/ignore/DriverIT.scala +++ b/testsIT/src/test/scala/com/stratio/crossdata/driver/ignore/DriverIT.scala @@ -29,11 +29,14 @@ import scala.util.Random @RunWith(classOf[JUnitRunner]) class DriverIT extends BaseXDTest { - it should "be able to execute a query involving a temporary table in any server" ignore {// TODO it is ignored until a crossdata-server container can be launched + it should "be able to execute a query involving a temporary table in any server" ignore { // TODO it is ignored until a crossdata-server container can be launched withDriverDo { driver => - - driver.sql(s"CREATE TEMPORARY TABLE jsonTable USING org.apache.spark.sql.json OPTIONS (path '${Paths.get(getClass.getResource("/tabletest.json").toURI).toString}')").waitForResult() + driver + .sql(s"CREATE TEMPORARY TABLE jsonTable USING org.apache.spark.sql.json OPTIONS (path '${Paths + .get(getClass.getResource("/tabletest.json").toURI) + .toString}')") + .waitForResult() for (_ <- 1 to 3) { // It assumes that the driver has a round robin policy @@ -50,9 +53,18 @@ class DriverIT extends BaseXDTest { withDriverDo { driver => withDriverDo { anotherDriver => - - driver.sql(s"CREATE TEMPORARY TABLE $randomTable USING org.apache.spark.sql.json OPTIONS (path '${Paths.get(getClass.getResource("/tabletest.json").toURI).toString}')").waitForResult() - anotherDriver.sql(s"CREATE TEMPORARY TABLE $randomTable USING org.apache.spark.sql.json OPTIONS (path 
'${Paths.get(getClass.getResource("/tabletest.json").toURI).toString}')").waitForResult() + driver + .sql( + s"CREATE TEMPORARY TABLE $randomTable USING org.apache.spark.sql.json OPTIONS (path '${Paths + .get(getClass.getResource("/tabletest.json").toURI) + .toString}')") + .waitForResult() + anotherDriver + .sql( + s"CREATE TEMPORARY TABLE $randomTable USING org.apache.spark.sql.json OPTIONS (path '${Paths + .get(getClass.getResource("/tabletest.json").toURI) + .toString}')") + .waitForResult() driver.sql(s"SET spark.sql.shuffle.partitions=400").waitForResult() anotherDriver.sql(s"SET spark.sql.shuffle.partitions=400").waitForResult() @@ -61,10 +73,13 @@ class DriverIT extends BaseXDTest { for (_ <- 1 to 3) { // It assumes that the driver has a round robin policy - val result = driver.sql(s"SELECT title, count(*) FROM $randomTable GROUP BY title").waitForResult() + val result = + driver.sql(s"SELECT title, count(*) FROM $randomTable GROUP BY title").waitForResult() validateResult(result) - val result2 = anotherDriver.sql(s"SELECT title, count(*) FROM $randomTable GROUP BY title").waitForResult() + val result2 = anotherDriver + .sql(s"SELECT title, count(*) FROM $randomTable GROUP BY title") + .waitForResult() validateResult(result2) } @@ -79,4 +94,4 @@ class DriverIT extends BaseXDTest { rows should have length 2 rows(0) should have length 2 } -} \ No newline at end of file +} diff --git a/testsIT/src/test/scala/com/stratio/crossdata/driver/querybuilder/QueryBuilderSpec.scala b/testsIT/src/test/scala/com/stratio/crossdata/driver/querybuilder/QueryBuilderSpec.scala index a1162fb98..3cf1df18b 100644 --- a/testsIT/src/test/scala/com/stratio/crossdata/driver/querybuilder/QueryBuilderSpec.scala +++ b/testsIT/src/test/scala/com/stratio/crossdata/driver/querybuilder/QueryBuilderSpec.scala @@ -26,7 +26,6 @@ import org.scalatest.junit.JUnitRunner @RunWith(classOf[JUnitRunner]) class QueryBuilderSpec extends BaseXDTest { - "The Query Builder" should "be able to build a completed query using strings" in { val query = select("col, '1', max(col)") from "table inner join table2 on a = b" where "a = b" groupBy "col" having "a = b" orderBy "col ASC" limit 5 @@ -44,7 +43,6 @@ class QueryBuilderSpec extends BaseXDTest { compareAfterFormatting(query, expected) } - it should "be able to add a where clause on a limited query" in { val query = selectAll from 'table limit 1 where 'a < 5 @@ -71,7 +69,6 @@ class QueryBuilderSpec extends BaseXDTest { compareAfterFormatting(query, expected) } - it should "be able to join several queries" in { val query = (selectAll from 'table) unionAll (selectAll from 'table2) unionAll (selectAll from 'table3) @@ -154,7 +151,6 @@ class QueryBuilderSpec extends BaseXDTest { val query_1 = select('c).from('table) val query = selectAll from query_1 - val expected = """ | SELECT * FROM | ( SELECT c FROM table ) @@ -286,7 +282,7 @@ class QueryBuilderSpec extends BaseXDTest { it should "be able to maintain user associations" in { - val query = select (('a + 13) * ('hola + 2) + 5) from 'test + val query = select(('a + 13) * ('hola + 2) + 5) from 'test val expected = """ | SELECT ((a + 13) * (hola + 2)) + 5 @@ -296,10 +292,9 @@ class QueryBuilderSpec extends BaseXDTest { compareAfterFormatting(query, expected) } - it should "be able to support aliases" in { - val query = select ('a as 'alias) from ('test as 'talias, (selectAll from 'table) as 'qalias) + val query = select('a as 'alias) from ('test as 'talias, (selectAll from 'table) as 'qalias) val expected = """ | SELECT a AS alias @@ 
-309,16 +304,15 @@ class QueryBuilderSpec extends BaseXDTest { compareAfterFormatting(query, expected) } - /* This test is here as documentation. Actually, its testing Scala since a mathematical precedence order is guaranteed by Scala's method names precedence table. Check "Programming in Scala: A comprehensive step-by-step guide", M.Ordersky, Section "5.8 - Operator precedence and associativity". - */ + */ it should "make use of Scala's method names precedence rules" in { - val query = select ('a, 'c - 'd * 'a) from 'test + val query = select('a, 'c - 'd * 'a) from 'test val expected = "SELECT a, c - (d * a) FROM test" @@ -326,7 +320,7 @@ class QueryBuilderSpec extends BaseXDTest { } it should "keep operator precedence provided by the user through the use of parenthesis" in { - val query = select ('a, 'b * ( 'c - 'd )) from 'test + val query = select('a, 'b * ('c - 'd)) from 'test val expected = "SELECT a, b * (c - d) FROM test" @@ -335,14 +329,14 @@ class QueryBuilderSpec extends BaseXDTest { it should "generate correct queries using arithmetic operators" in { - val arithmeticExpressions = ('a + 'b)::('c - 'd)::('e * 'f)::('g / 'h)::('i % 'j)::Nil - val baseQuery = select (arithmeticExpressions:_*) from 'test + val arithmeticExpressions = ('a + 'b) :: ('c - 'd) :: ('e * 'f) :: ('g / 'h) :: ('i % 'j) :: Nil + val baseQuery = select(arithmeticExpressions: _*) from 'test - val query = (baseQuery /: arithmeticExpressions) { - (q, op) => q.where(op === 'ref) + val query = (baseQuery /: arithmeticExpressions) { (q, op) => + q.where(op === 'ref) } - val expectedExpressions = "a + b"::"c - d"::"e * f"::"g / h"::"i % j"::Nil + val expectedExpressions = "a + b" :: "c - d" :: "e * f" :: "g / h" :: "i % j" :: Nil val expected = s""" |SELECT ${expectedExpressions mkString ", "} |FROM test @@ -358,20 +352,28 @@ class QueryBuilderSpec extends BaseXDTest { val selQueryStr = "SELECT a FROM sourceTable" Seq( - (insert into 'test select 'a from 'sourceTable, s"INSERT INTO test $selQueryStr"), - (insert overwrite 'test select 'a from 'sourceTable, s"INSERT OVERWRITE test $selQueryStr") - ) foreach { case (query, expected) => - compareAfterFormatting(query, expected) + (insert into 'test select 'a from 'sourceTable, s"INSERT INTO test $selQueryStr"), + (insert overwrite 'test select 'a from 'sourceTable, s"INSERT OVERWRITE test $selQueryStr") + ) foreach { + case (query, expected) => + compareAfterFormatting(query, expected) } } it should "be able to support common functions in the select expression" in { val query = select( - distinct('col), countDistinct('col), sumDistinct('col), - count(querybuilder.all), approxCountDistinct('col, 0.95), - avg('col), min('col), max('col), sum('col), abs('col) - ) from 'table + distinct('col), + countDistinct('col), + sumDistinct('col), + count(querybuilder.all), + approxCountDistinct('col, 0.95), + avg('col), + min('col), + max('col), + sum('col), + abs('col) + ) from 'table val expected = """ | SELECT DISTINCT col, count( DISTINCT col), sum( DISTINCT col), @@ -383,21 +385,18 @@ class QueryBuilderSpec extends BaseXDTest { compareAfterFormatting(query, expected) } - it should "be able to allow different order selections" in { - val queryAsc = selectAll from 'table orderBy('col asc) - val queryDesc = selectAll from 'table sortBy('col desc) + val queryAsc = selectAll from 'table orderBy ('col asc) + val queryDesc = selectAll from 'table sortBy ('col desc) - val expectedAsc = - """ + val expectedAsc = """ | SELECT * | FROM table | ORDER BY col ASC """ - val expectedDesc = - 
""" + val expectedDesc = """ | SELECT * | FROM table | SORT BY col DESC @@ -408,10 +407,9 @@ class QueryBuilderSpec extends BaseXDTest { } - it should "be able to support comparison predicates" in { - val query = selectAll from 'table where( !('a < 5 && 'a <= 5 && 'a > 5 && 'a >=5 && 'a === 5 && 'a <> 5 || false)) + val query = selectAll from 'table where (!('a < 5 && 'a <= 5 && 'a > 5 && 'a >= 5 && 'a === 5 && 'a <> 5 || false)) val expected = """ @@ -426,10 +424,9 @@ class QueryBuilderSpec extends BaseXDTest { it should "be able to support common predicates" in { - val query = selectAll from 'table where ( ('a in (2,3,4)) && ('b like "%R") && ('b isNull) && ('b isNotNull)) + val query = selectAll from 'table where (('a in (2, 3, 4)) && ('b like "%R") && ('b isNull) && ('b isNotNull)) - val expected = - """ + val expected = """ | SELECT * | FROM table | WHERE ( a IN (2,3,4)) AND (b LIKE '%R') AND ( b IS NULL) AND ( b IS NOT NULL) @@ -439,12 +436,12 @@ class QueryBuilderSpec extends BaseXDTest { } - it should "be able to support SparkSQL types" in { - val timestampVal = new Timestamp(new GregorianCalendar(1970,0,1,0,0,0).getTimeInMillis) + val timestampVal = new Timestamp(new GregorianCalendar(1970, 0, 1, 0, 0, 0).getTimeInMillis) - val query = selectAll from 'table where ( ('a <> "string") && ('a <> 5f) && ('a <> true) && ('a <> timestampVal) && ('a <> new java.math.BigDecimal(1))) + val query = selectAll from 'table where (('a <> "string") && ('a <> 5f) && ('a <> true) && ('a <> timestampVal) && ('a <> new java.math.BigDecimal( + 1))) val expected = """ @@ -461,13 +458,7 @@ class QueryBuilderSpec extends BaseXDTest { formatOutput(query.build) should be(formatOutput(expected)) } - def formatOutput(query: String): String = query.stripMargin.replaceAll(System.lineSeparator(), " ").trim.replaceAll(" +", " ") - - - - - } diff --git a/testsIT/src/test/scala/com/stratio/crossdata/driver/test/Utils.scala b/testsIT/src/test/scala/com/stratio/crossdata/driver/test/Utils.scala index 26dd75053..8f46e4ec5 100644 --- a/testsIT/src/test/scala/com/stratio/crossdata/driver/test/Utils.scala +++ b/testsIT/src/test/scala/com/stratio/crossdata/driver/test/Utils.scala @@ -19,7 +19,7 @@ import com.stratio.crossdata.driver.{Driver, JavaDriver} import com.stratio.crossdata.driver.config.DriverConf import com.stratio.crossdata.test.BaseXDTest -object Utils extends BaseXDTest{ +object Utils extends BaseXDTest { def withDriverDo(block: Driver => Unit)(implicit optConfig: Option[DriverConf] = None): Unit = { @@ -31,9 +31,11 @@ object Utils extends BaseXDTest{ } } - def withJavaDriverDo(block: JavaDriver => Unit)(implicit optConfig: Option[DriverConf] = None): Unit = { + def withJavaDriverDo(block: JavaDriver => Unit)( + implicit optConfig: Option[DriverConf] = None): Unit = { - val driver = optConfig.map(driverConf => new JavaDriver(driverConf)).getOrElse(new JavaDriver()) + val driver = + optConfig.map(driverConf => new JavaDriver(driverConf)).getOrElse(new JavaDriver()) try { block(driver) } finally {