diff --git a/.travis.yml b/.travis.yml index d01a1465..94d5b0f7 100644 --- a/.travis.yml +++ b/.travis.yml @@ -14,12 +14,12 @@ cache: - "$HOME/.sbt" scala: - - 2.10.6 +# - 2.10.6 - 2.11.7 before_script: - "./travis/start-cluster.sh" - - curl -q -sSL http://d3kbcqa49mib13.cloudfront.net/spark-1.6.2-bin-hadoop2.6.tgz | tar -zxf - -C /opt + - curl -q -sSL http://d3kbcqa49mib13.cloudfront.net/spark-2.1.0-bin-hadoop2.7.tgz | tar -zxf - -C /opt script: - "./travis/build.sh" @@ -32,7 +32,7 @@ env: - RIAK_FLAVOR=riak-ts - RIAK_FLAVOR=riak-kv global: - - SPARK_HOME=/opt/spark-1.6.2-bin-hadoop2.6 + - SPARK_HOME=/opt/spark-2.1.0-bin-hadoop2.7 - secure: r9cf5Jdfxsr0MngcKwyvmlKvA0NQF2GnKCDekbsmnQONnW5WUHJjzD5k1LbqiQHdZqqxb72RnOpIEXfJhasteseofsGbZMU5ROAFiohl8WhxHnFb65pMWrXzKg/iqdKhcKZg5akP2fCIDwMz4xVFz0JrG3CtHotjQQ0/6HQ7uDCG2On7h9QB6pt50IarFGRBKd4wBTCkhR+6QKfE4S13mI8gGvq/R7ly7tfqY7i/vFg3z0KN0Vb/+KLwYv/b2aBTBZUli0L/tfBZUQbt9J5Ty9k6lnw+2vlFV9nhKRXMoQjuFi4VDskO1n59eu9pJZ9bhIYngdpnUlrrTpLzRbnXfqSXDufShA4PFOa0EQX/9meFdd1O54y2dnR2K1Gd80Id+NtC66eMa5FND03vGtqNrjjOW1yHkR9FxUhfCMpPWGWXbyBruk/M2Qou+pWp9F7BEvYOeVAkvA2mZ3AGriuBi1iSbJhpDH0plZsQAk2pnXn4CEbXVtpX759KPiA3YyHaMcSpC4QSdXaDqoYGyqazgZL1/MHCKuG8crCn4xe4/4ZbBNLldILtGC6KINas/YhgBctwkY7Sq7VDfxHaMCrSFBDkB0nHRAfG8H3ap5XfaLQkGArPsh0QsjVS92wcbnulIXK3Huf8HLPQTcy+aNyvgTE4O0lMuoHZUA9Tja84kok= - secure: n3+vykdc/JiFBGXgUCt3O1sfWGexlWStERslgeTCxQJXQOv7UbpH/c7tH9wnsndA6Hd1vWLk6MUXu0cH/BrI97LejoCQ+uGm1CHSlOFJ6hSEuC2wxOKGOjbH9HzFQd5EaNIa5XhFisRRjK+anPXicV1zwsda5NFpspVwGcVs+ai7pnh6ysvPOiXKU19uaFPPvv3z9oO1qo7+99fJAR4ty21VGoJmH+hgK5DJ4QLOWdaKgsdclA49Ek9Fwd1Ni+lZW98imnVhHQdtpUqScfYHZcXQTRs6E/daN34Z6cwZ0v8UrX+Y3yatRhUTBy4zE8Nq+OVi02u7kaIAw7rtb5echFfUTW4nMgBSvGnnV3t0MY66JZoEpBzxWaWvhqMdmwbmPpDPdBG/NtPvshtFE6Dle6eBzzGJdu3lpK/QRgeJrdHkDpKt5aX2lAmnGJo0EwQQ6G5U1dQ7ScIZTJndSObyGB7IJtbsc3HFLA4YU5SAdMjQxHREvg+XzXuaL/mMbyJ3ZWA5M5acH43N5UkVKMpqfggdICb6rKdnRKQZRsVtv+oiBGawD1ue5gnsTziwD2xneJ9MSOJ4PQOPfoUHs8q39n8gvsxHpnkPwxgsF7Ed0vXzN7kt5OF/nR1rJc97ZSjSn6oNMYqWfdVRjZHhEVo6EcgdV0xN++3LTVgR4SR+NSs= - secure: eigM8++nPdZRaIq4AM6VFef/X2R+SlCdvh+HICdCz2A+eBiulDNSMH5AzFwriCbUBljffDaWeqHcdmq8R4st9XcVZEk/mtIGs6j24H76ypjb0C6pH9tvze6yzFoCmc2cXa3x+NUcAwgXnM5rzmJYVk8P4ZwwxGsQ/IF7Gr4WgcFaBpg2DVYgKML4DsS+xHEOnN3WOGHg14qTDd+lu4kv99JCRTVYDtqmw3fZO658qQxzat0c2+pRwvRlGFfp2U7U0zDaAhU39CXMGi7nzBEAaWjw5ObVtCa7MQZd6Fh7l8KPV4goA4ngnKBpnJPTvsSrnm3VoP+coT9F6nU+IQCnJW+7N6R8szY1wGteWqrw8I+XA+ts1lo4A2+fnkykzO8TqhS2K5xCCp4EOu4wGRWfHKRQnvdnXKO85FtOIjxdrdsaaA+EG6v3YHvs8rg3VxAsjo1x/cHNFchac5/AUrz3QfJoKWbjybwPMl/sFm01JQMvQLGjPARSLFuUJG9OQQMCmHwxzUGuQXG11Ls4zZWHiJxUub1dhDKV94mc7jJfcyIQd9jyodcgDAB5/Sin7Q9r4Gye/z3R0i4iwx0SNKYNxG15XLFeAk0YxW+h90U3K3t4SjVYW54rtbn156CnImhVIRmRCAJ8IlJJ55+9ObTtTFsNEKS0OaXlSfdGxZOinXQ= diff --git a/CHANGELOG.md b/CHANGELOG.md index 90dbd56b..c5ac4ed5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,4 +1,7 @@ # Change Log +## [2.0](https://github.com/basho/spark-riak-connector/releases/tag/v2.0.0) + +* Spark 2.0.2 support; RiakCatalog, RiakSQLContext were removed due to changes in Spark 2.0 API ## [1.6.2](https://github.com/basho/spark-riak-connector/releases/tag/v1.6.2) Critical fix Python KV: if object values are JSON objects with list fields (empty or not) then exception happens (https://github.com/basho/spark-riak-connector/pull/219). 
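Note on the CHANGELOG entry above: with RiakCatalog and RiakSQLContext removed, Riak TS reads go through a plain SparkSession and the connector's data-source name, exactly as the updated tests later in this diff do. A minimal Scala sketch, assuming a local Riak node on 127.0.0.1:8087 and a hypothetical TS table named ts_weather_demo:

```scala
import org.apache.spark.sql.SparkSession

object TsReadSketch {
  def main(args: Array[String]): Unit = {
    // SparkSession replaces the removed RiakSQLContext as the single entry point
    val spark = SparkSession.builder()
      .appName("riak-ts-read-sketch")
      .config("spark.riak.connection.host", "127.0.0.1:8087") // assumption: local Riak node
      .getOrCreate()

    // Load a TS table through the connector's data source (same options the tests use)
    val df = spark.read
      .format("org.apache.spark.sql.riak")
      .option("spark.riakts.bindings.timestamp", "useTimestamp")
      .load("ts_weather_demo") // hypothetical table name

    df.createOrReplaceTempView("ts_weather_demo")
    spark.sql("SELECT * FROM ts_weather_demo").show()

    spark.stop()
  }
}
```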
diff --git a/Dockerfile b/Dockerfile index 356b2759..cfa44077 100644 --- a/Dockerfile +++ b/Dockerfile @@ -6,8 +6,8 @@ MAINTAINER Alexey Suprun # These options could be changed during starting container using --build-arg property with folliwing syntax: # --build-arg ARGUMENT_NAME=value ARG SBT_VERSION=0.13.12 -ARG SPARK_VERSION=1.6.1 -ARG SPARK_HADOOP_VERSION=hadoop2.6 +ARG SPARK_VERSION=2.1.0 +ARG SPARK_HADOOP_VERSION=hadoop2.7 # Set env vars ENV SBT_HOME /usr/local/sbt diff --git a/README.md b/README.md index df1e6e6f..5b594809 100644 --- a/README.md +++ b/README.md @@ -14,10 +14,13 @@ The Spark-Riak connector enables you to connect Spark applications to Riak KV an * Construct a Spark RDD using Riak KV bucket's enhanced 2i query (a.k.a. full bucket read) * Perform parallel full bucket reads from a Riak KV bucket into multiple partitions -## Compatibility +## Version Compatibility + +| Connector | Spark | Riak TS | Riak KV | +|------------|-------|---------|---------| +| 2.X | 2.X | 1.5 | 2.2.0 | +| 1.6.X | 1.6.X | 1.4 | 2.2.0 | -* Riak TS 1.3.1+ -* Apache Spark 1.6+ * Scala 2.10 and 2.11 * Java 8 diff --git a/build.sbt b/build.sbt index bf2d824a..59481780 100644 --- a/build.sbt +++ b/build.sbt @@ -83,7 +83,7 @@ lazy val sparkRiakConnector = (project in file("connector")) }.filter(_.contains("test-utils")).mkString(":") val uberJar = buildDir.relativize((assemblyOutputPath in assembly).value.toPath) - if(!scalaBinaryVersion.value.equals("2.11")) { + if(scalaBinaryVersion.value.equals("2.11")) { val rtnCode = s"connector/python/test.sh $uberJar:$cp $pyTestMark" ! streams.value.log //val rtnCode = s"docker build -t $namespace ." #&& s"docker run --rm -i -e RIAK_HOSTS=$riakHosts -e SPARK_CLASSPATH=$uberJar:$cp -v ${buildDir.toString}:/usr/src/$namespace -v ${home.toString}/.ivy2:/root/.ivy2 -v /var/run/docker.sock:/var/run/docker.sock -v /usr/bin/docker:/bin/docker -w /usr/src/$namespace $namespace ./connector/python/test.sh" ! 
streams.value.log if (rtnCode != 0) { @@ -101,7 +101,8 @@ lazy val examples = (project in file("examples")) .settings( name := s"$namespace-examples", libraryDependencies ++= Seq( - "org.apache.spark" %% "spark-streaming-kafka" % Versions.spark, + "org.apache.spark" %% "spark-streaming-kafka" % Versions.sparkStreamingKafka + exclude("org.scalatest", s"scalatest_${scalaBinaryVersion.value}"), "org.apache.kafka" %% "kafka" % Versions.kafka)) .settings(publishSettings) .dependsOn(sparkRiakConnector, sparkRiakConnectorTestUtils) @@ -122,7 +123,7 @@ lazy val sparkRiakConnectorTestUtils = (project in file("test-utils")) lazy val commonSettings = Seq( organization := "com.basho.riak", version := "1.6.3-SNAPSHOT", - scalaVersion := "2.10.6", + scalaVersion := "2.11.8", crossPaths := true, spName := s"basho/$namespace", sparkVersion := Versions.spark, @@ -142,8 +143,10 @@ lazy val commonDependencies = Seq( "com.basho.riak" % "riak-client" % Versions.riakClient exclude("io.netty", "netty-all") exclude("org.slf4j", "slf4j-api") exclude("com.fasterxml.jackson.datatype", "jackson-datatype-joda"), - "org.apache.spark" %% "spark-sql" % Versions.spark % "provided", - "org.apache.spark" %% "spark-streaming" % Versions.spark % "provided", + "org.apache.spark" %% "spark-sql" % Versions.spark % "provided" + exclude("org.scalatest", s"scalatest_${scalaBinaryVersion.value}"), + "org.apache.spark" %% "spark-streaming" % Versions.spark % "provided" + exclude("org.scalatest", s"scalatest_${scalaBinaryVersion.value}"), "com.google.guava" % "guava" % Versions.guava, "com.fasterxml.jackson.module" %% "jackson-module-scala" % Versions.jacksonModule exclude("com.google.guava", "guava") exclude("com.google.code.findbugs", "jsr305") @@ -155,7 +158,11 @@ lazy val commonDependencies = Seq( "org.powermock" % "powermock-module-junit4" % Versions.powermokc % "test", "org.powermock" % "powermock-api-mockito" % Versions.powermokc % "test", "com.novocode" % "junit-interface" % Versions.junitInterface % "test", - "com.basho.riak.test" % "riak-test-docker" % Versions.riakTestDocker % "test" + "com.basho.riak.test" % "riak-test-docker" % Versions.riakTestDocker % "test", + "com.spotify" % "docker-client" % "5.0.2" % "test" + exclude("com.fasterxml.jackson.core", "jackson-databind") + exclude("com.fasterxml.jackson.core", "jackson-annotations") + exclude("com.fasterxml.jackson.core", "jackson-core") ), // Connector will use same version of Jackson that Spark uses. No need to incorporate it into uber jar. 
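The build.sbt changes above move the connector to Scala 2.11 and keep the Spark 2.x artifacts as provided dependencies. A downstream sbt project would mirror that layout; a sketch, with the connector's coordinates inferred from the organization and namespace in this build (treat the exact artifact name and version as assumptions):

```scala
// build.sbt (sketch for an application that uses the connector)
scalaVersion := "2.11.8"

libraryDependencies ++= Seq(
  // Spark is supplied by the cluster at runtime, so mark it "provided" as the connector build does
  "org.apache.spark" %% "spark-sql"       % "2.1.0" % "provided",
  "org.apache.spark" %% "spark-streaming" % "2.1.0" % "provided",
  // assumed coordinates for the published connector artifact
  "com.basho.riak"   %% "spark-riak-connector" % "2.0.0"
)
```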
diff --git a/connector/python/tests/pyspark_tests_fixtures.py b/connector/python/tests/pyspark_tests_fixtures.py index a197d3b1..4239d3cd 100644 --- a/connector/python/tests/pyspark_tests_fixtures.py +++ b/connector/python/tests/pyspark_tests_fixtures.py @@ -2,14 +2,16 @@ import pytest import findspark findspark.init() -from pyspark import SparkContext, SparkConf, SQLContext, Row +from pyspark import SparkContext, SparkConf, Row +from pyspark.sql import SparkSession import riak, pyspark_riak @pytest.fixture(scope="session") def docker_cli(request): # Start spark context to get access to py4j gateway conf = SparkConf().setMaster("local[*]").setAppName("pytest-pyspark-py4j") - sc = SparkContext(conf=conf) + sparkSession = SparkSession.builder.config(conf).getOrCreate() + sc = sparkSession.sparkContext docker_cli = sc._gateway.jvm.com.basho.riak.test.cluster.DockerRiakCluster(1, 2) docker_cli.start() sc.stop() @@ -17,6 +19,22 @@ def docker_cli(request): request.addfinalizer(lambda: docker_cli.stop()) return docker_cli +@pytest.fixture(scope="session") +def spark_session(request): + if not os.environ.has_key('RIAK_HOSTS'): + docker_cli = request.getfuncargvalue('docker_cli') + host_and_port = get_host_and_port(docker_cli) + os.environ['RIAK_HOSTS'] = host_and_port + os.environ['USE_DOCKER'] = 'true' + # Start new spark context + conf = SparkConf().setMaster('local[*]').setAppName('pytest-pyspark-local-testing') + conf.set('spark.riak.connection.host', os.environ['RIAK_HOSTS']) + conf.set('spark.driver.memory', '4g') + conf.set('spark.executor.memory', '4g') + spark_session = SparkSession.builder.config(conf=conf).getOrCreate() + return spark_session + + @pytest.fixture(scope="session") def spark_context(request): # If RIAK_HOSTS is not set, use Docker to start a Riak node @@ -30,16 +48,12 @@ def spark_context(request): conf.set('spark.riak.connection.host', os.environ['RIAK_HOSTS']) conf.set('spark.driver.memory', '4g') conf.set('spark.executor.memory', '4g') - spark_context = SparkContext(conf=conf) + spark_context = SparkSession.builder.config(conf=conf).getOrCreate().sparkContext spark_context.setLogLevel('INFO') pyspark_riak.riak_context(spark_context) request.addfinalizer(lambda: spark_context.stop()) return spark_context -@pytest.fixture(scope="session") -def sql_context(request, spark_context): - sqlContext = SQLContext(spark_context) - return sqlContext @pytest.fixture(scope="session") def riak_client(request): diff --git a/connector/python/tests/test_pyspark_riak.py b/connector/python/tests/test_pyspark_riak.py index ea5822f5..170a195d 100644 --- a/connector/python/tests/test_pyspark_riak.py +++ b/connector/python/tests/test_pyspark_riak.py @@ -4,7 +4,8 @@ from operator import add import findspark findspark.init() -from pyspark import SparkContext, SparkConf, SQLContext, Row +from pyspark import SparkContext, SparkConf, Row +from pyspark.sql import SparkSession import os, subprocess, json, riak, time import pyspark_riak import timeout_decorator @@ -262,7 +263,7 @@ def make_ts_query(riak_ts_table_name, start, end): ###### Riak TS Test ####### -def _test_spark_df_ts_write_use_long(N, M, spark_context, riak_client, sql_context): +def _test_spark_df_ts_write_use_long(N, M, spark_context, riak_client): useLong=True start, end, riak_ts_table_name, test_df, test_rdd, test_data, riak_ts_table = make_table_with_data(N, M, useLong, spark_context, riak_client) @@ -272,7 +273,7 @@ def _test_spark_df_ts_write_use_long(N, M, spark_context, riak_client, sql_conte assert sorted(result.rows, 
key=lambda x: x[2]) == sorted(test_rdd.collect(), key=lambda x: x[2]) -def _test_spark_df_ts_write_use_timestamp(N, M, spark_context, riak_client, sql_context): +def _test_spark_df_ts_write_use_timestamp(N, M, spark_context, riak_client): useLong=False start_timestamp, end_timestamp, riak_ts_table_name, test_df, test_rdd, timestamp_data, long_data, start_long, end_long, riak_ts_table = make_table_with_data(N, M, useLong, spark_context, riak_client) @@ -282,59 +283,59 @@ result = riak_ts_table.query(query) assert sorted(result.rows, key=lambda x: x[2]) == sorted(spark_context.parallelize(long_data).collect(), key=lambda x: x[2]) -def _test_spark_df_ts_read_use_long(N, M, spark_context, riak_client, sql_context): +def _test_spark_df_ts_read_use_long(N, M, spark_session, spark_context, riak_client): useLong=True start, end, riak_ts_table_name, test_df, test_rdd, test_data, riak_ts_table = make_table_with_data(N, M, useLong, spark_context, riak_client) temp_filter = make_filter(useLong, start, end) - result = sql_context.read.format("org.apache.spark.sql.riak").option("spark.riakts.bindings.timestamp", "useLong").load(riak_ts_table_name).filter(temp_filter) + result = spark_session.read.format("org.apache.spark.sql.riak").option("spark.riakts.bindings.timestamp", "useLong").load(riak_ts_table_name).filter(temp_filter) assert sorted(result.collect(), key=lambda x: x[2]) == sorted(test_df.collect(), key=lambda x: x[2]) -def _test_spark_df_ts_read_use_long_ts_quantum(N, M, spark_context, riak_client, sql_context): +def _test_spark_df_ts_read_use_long_ts_quantum(N, M, spark_session, spark_context, riak_client): useLong=True start, end, riak_ts_table_name, test_df, test_rdd, test_data, riak_ts_table = make_table_with_data(N, M, useLong, spark_context, riak_client) temp_filter = make_filter(useLong, start, end) - result = sql_context.read.format("org.apache.spark.sql.riak") \ + result = spark_session.read.format("org.apache.spark.sql.riak") \ .option("spark.riakts.bindings.timestamp", "useLong") \ .option("spark.riak.partitioning.ts-quantum", "24h") \ .load(riak_ts_table_name).filter(temp_filter) assert sorted(result.collect(), key=lambda x: x[2]) == sorted(test_df.collect(), key=lambda x: x[2]) -def _test_spark_df_ts_read_use_timestamp(N, M, spark_context, riak_client, sql_context): +def _test_spark_df_ts_read_use_timestamp(N, M, spark_session, spark_context, riak_client): useLong=False start_timestamp, end_timestamp, riak_ts_table_name, test_df, test_rdd, timestamp_data, long_data, start_long, end_long, riak_ts_table = make_table_with_data(N, M, useLong, spark_context, riak_client) temp_filter = make_filter(useLong, unix_time_seconds(start_timestamp), unix_time_seconds(end_timestamp)) - result = sql_context.read.format("org.apache.spark.sql.riak").option("spark.riakts.bindings.timestamp", "useTimestamp").load(riak_ts_table_name).filter(temp_filter) + result = spark_session.read.format("org.apache.spark.sql.riak").option("spark.riakts.bindings.timestamp", "useTimestamp").load(riak_ts_table_name).filter(temp_filter) assert sorted(result.collect(), key=lambda x: x[2]) == sorted(test_df.collect(), key=lambda x: x[2]) -def _test_spark_df_ts_read_use_timestamp_ts_quantum(N, M, spark_context, riak_client, sql_context): +def _test_spark_df_ts_read_use_timestamp_ts_quantum(N, M, spark_session, spark_context, riak_client): useLong=False start_timestamp, end_timestamp, riak_ts_table_name, test_df, test_rdd, timestamp_data,
long_data, start_long, end_long, riak_ts_table = make_table_with_data(N, M, useLong, spark_context, riak_client) temp_filter = make_filter(useLong, unix_time_seconds(start_timestamp), unix_time_seconds(end_timestamp)) - result = sql_context.read.format("org.apache.spark.sql.riak").option("spark.riakts.bindings.timestamp", "useTimestamp").option("spark.riak.partitioning.ts-quantum", "24h").load(riak_ts_table_name).filter(temp_filter) + result = spark_session.read.format("org.apache.spark.sql.riak").option("spark.riakts.bindings.timestamp", "useTimestamp").option("spark.riak.partitioning.ts-quantum", "24h").load(riak_ts_table_name).filter(temp_filter) assert sorted(result.collect(), key=lambda x: x[2]) == sorted(test_df.collect(), key=lambda x: x[2]) -def _test_spark_df_ts_range_query_input_split_count_use_long(N, M, S,spark_context, riak_client, sql_context): +def _test_spark_df_ts_range_query_input_split_count_use_long(N, M, S, spark_session, spark_context, riak_client): useLong=True start, end, riak_ts_table_name, test_df, test_rdd, test_data, riak_ts_table = make_table_with_data(N, M, useLong, spark_context, riak_client) time.sleep(1) temp_filter = make_filter(useLong, start, end) - result = sql_context.read.format("org.apache.spark.sql.riak") \ + result = spark_session.read.format("org.apache.spark.sql.riak") \ .option("spark.riakts.bindings.timestamp", "useLong") \ .option("spark.riak.input.split.count", str(S)) \ .option("spark.riak.partitioning.ts-range-field-name", "datetime") \ @@ -344,13 +345,13 @@ def _test_spark_df_ts_range_query_input_split_count_use_long(N, M, S,spark_conte assert sorted(result.collect(), key=lambda x: x[2]) == sorted(test_df.collect(), key=lambda x: x[2]) assert result.rdd.getNumPartitions() == S -def _test_spark_df_ts_range_query_input_split_count_use_long_ts_quantum(N, M, S,spark_context, riak_client, sql_context): +def _test_spark_df_ts_range_query_input_split_count_use_long_ts_quantum(N, M, S, spark_session, spark_context, riak_client): useLong=True start, end, riak_ts_table_name, test_df, test_rdd, test_data, riak_ts_table = make_table_with_data(N, M, useLong, spark_context, riak_client) temp_filter = make_filter(useLong, start, end) - result = sql_context.read.format("org.apache.spark.sql.riak") \ + result = spark_session.read.format("org.apache.spark.sql.riak") \ .option("spark.riakts.bindings.timestamp", "useLong") \ .option("spark.riak.partitioning.ts-quantum", "24h") \ .option("spark.riak.input.split.count", str(S)) \ @@ -360,13 +361,13 @@ def _test_spark_df_ts_range_query_input_split_count_use_long_ts_quantum(N, M, S, assert sorted(result.collect(), key=lambda x: x[2]) == sorted(test_df.collect(), key=lambda x: x[2]) assert result.rdd.getNumPartitions() == S -def _test_spark_df_ts_range_query_input_split_count_use_timestamp(N, M, S,spark_context, riak_client, sql_context): +def _test_spark_df_ts_range_query_input_split_count_use_timestamp(N, M, S, spark_session, spark_context, riak_client): useLong=False start_timestamp, end_timestamp, riak_ts_table_name, test_df, test_rdd, timestamp_data, long_data, start_long, end_long, riak_ts_table = make_table_with_data(N, M, useLong, spark_context, riak_client) temp_filter = make_filter(useLong, unix_time_seconds(start_timestamp), unix_time_seconds(end_timestamp)) - result = sql_context.read.format("org.apache.spark.sql.riak") \ + result = spark_session.read.format("org.apache.spark.sql.riak") \ .option("spark.riakts.bindings.timestamp", "useTimestamp") \ .option("spark.riak.input.split.count", str(S)) \ 
.option("spark.riak.partitioning.ts-range-field-name", "datetime") \ @@ -375,12 +376,12 @@ def _test_spark_df_ts_range_query_input_split_count_use_timestamp(N, M, S,spark_ assert sorted(result.collect(), key=lambda x: x[2]) == sorted(test_df.collect(), key=lambda x: x[2]) assert result.rdd.getNumPartitions() == S -def _test_spark_df_ts_range_query_input_split_count_use_timestamp_ts_quantum(N, M, S,spark_context, riak_client, sql_context): +def _test_spark_df_ts_range_query_input_split_count_use_timestamp_ts_quantum(N, M, S, spark_session, spark_context, riak_client): useLong=False start_timestamp, end_timestamp, riak_ts_table_name, test_df, test_rdd, timestamp_data, long_data, start_long, end_long, riak_ts_table = make_table_with_data(N, M, useLong, spark_context, riak_client) temp_filter = make_filter(useLong, unix_time_seconds(start_timestamp), unix_time_seconds(end_timestamp)) - result = sql_context.read.format("org.apache.spark.sql.riak") \ + result = spark_session.read.format("org.apache.spark.sql.riak") \ .option("spark.riakts.bindings.timestamp", "useTimestamp") \ .option("spark.riak.partitioning.ts-quantum", "24h") \ .option("spark.riak.input.split.count", str(S)) \ @@ -392,7 +393,7 @@ def _test_spark_df_ts_range_query_input_split_count_use_timestamp_ts_quantum(N, ###### Riak KV Tests ###### -def _test_spark_rdd_write_kv(N, spark_context, riak_client, sql_context): +def _test_spark_rdd_write_kv(N, spark_context, riak_client): test_bucket_name = "test-bucket-"+str(randint(0,100000)) @@ -404,7 +405,7 @@ def _test_spark_rdd_write_kv(N, spark_context, riak_client, sql_context): assert sorted(source_data) == sorted(test_data) -def _test_spark_rdd_kv_read_query_all(N, spark_context, riak_client, sql_context): +def _test_spark_rdd_kv_read_query_all(N, spark_context, riak_client): test_bucket_name = "test-bucket-"+str(randint(0,100000)) @@ -416,7 +417,7 @@ def _test_spark_rdd_kv_read_query_all(N, spark_context, riak_client, sql_context assert sorted(result.collect(), key=lambda x: x[0]) == sorted(test_data, key=lambda x: x[0]) -def _test_spark_rdd_kv_read_query_bucket_keys(N, spark_context, riak_client, sql_context): +def _test_spark_rdd_kv_read_query_bucket_keys(N, spark_context, riak_client): test_bucket_name = "test-bucket-"+str(randint(0,100000)) @@ -432,7 +433,7 @@ def _test_spark_rdd_kv_read_query_bucket_keys(N, spark_context, riak_client, sql assert sorted(result.collect(), key=lambda x: x[0]) == sorted([], key=lambda x: x[0]) -def _test_spark_rdd_kv_read_query_2i_keys(N, spark_context, riak_client, sql_context): +def _test_spark_rdd_kv_read_query_2i_keys(N, spark_context, riak_client): test_bucket_name = "test-bucket-"+str(randint(0,100000)) @@ -446,7 +447,7 @@ def _test_spark_rdd_kv_read_query_2i_keys(N, spark_context, riak_client, sql_con assert sorted(result.collect(), key=lambda x: x[0]) == sorted(test_data, key=lambda x: x[0]) -def _test_spark_rdd_kv_read_query2iRange(N, spark_context, riak_client, sql_context): +def _test_spark_rdd_kv_read_query2iRange(N, spark_context, riak_client): test_bucket_name = "test-bucket-"+str(randint(0,100000)) @@ -460,7 +461,7 @@ def _test_spark_rdd_kv_read_query2iRange(N, spark_context, riak_client, sql_cont assert sorted(result.collect(), key=lambda x: x[0]) == sorted([], key=lambda x: x[0]) -def _test_spark_rdd_kv_read_partition_by_2i_range(N, spark_context, riak_client, sql_context): +def _test_spark_rdd_kv_read_partition_by_2i_range(N, spark_context, riak_client): test_bucket_name = "test-bucket-"+str(randint(0,100000)) @@ -478,7 +479,7 @@ 
def _test_spark_rdd_kv_read_partition_by_2i_range(N, spark_context, riak_client, assert result.getNumPartitions() == N -def _test_spark_rdd_kv_read_partition_by_2i_keys(N, spark_context, riak_client, sql_context): +def _test_spark_rdd_kv_read_partition_by_2i_keys(N, spark_context, riak_client): test_bucket_name = "test-bucket-"+str(randint(0,100000)) @@ -506,32 +507,32 @@ def _test_spark_rdd_kv_read_partition_by_2i_keys(N, spark_context, riak_client, ###### KV Tests ####### @pytest.mark.riakkv -def test_kv_write(spark_context, riak_client, sql_context): - _test_spark_rdd_write_kv(10, spark_context, riak_client, sql_context) +def test_kv_write(spark_context, riak_client): + _test_spark_rdd_write_kv(10, spark_context, riak_client) @pytest.mark.riakkv -def test_kv_query_all(spark_context, riak_client, sql_context): - _test_spark_rdd_kv_read_query_all(10, spark_context, riak_client, sql_context) +def test_kv_query_all(spark_context, riak_client): + _test_spark_rdd_kv_read_query_all(10, spark_context, riak_client) @pytest.mark.riakkv -def test_kv_query_bucket_keys(spark_context, riak_client, sql_context): - _test_spark_rdd_kv_read_query_bucket_keys(10, spark_context, riak_client, sql_context) +def test_kv_query_bucket_keys(spark_context, riak_client): + _test_spark_rdd_kv_read_query_bucket_keys(10, spark_context, riak_client) @pytest.mark.riakkv -def test_kv_query_2i_keys(spark_context, riak_client, sql_context): - _test_spark_rdd_kv_read_query_2i_keys(10, spark_context, riak_client, sql_context) +def test_kv_query_2i_keys(spark_context, riak_client): + _test_spark_rdd_kv_read_query_2i_keys(10, spark_context, riak_client) @pytest.mark.riakkv -def test_kv_query_2i_range(spark_context, riak_client, sql_context): - _test_spark_rdd_kv_read_query2iRange(10, spark_context, riak_client, sql_context) +def test_kv_query_2i_range(spark_context, riak_client): + _test_spark_rdd_kv_read_query2iRange(10, spark_context, riak_client) @pytest.mark.riakkv -def test_kv_query_partition_by_2i_range(spark_context, riak_client, sql_context): - _test_spark_rdd_kv_read_partition_by_2i_range(10, spark_context, riak_client, sql_context) +def test_kv_query_partition_by_2i_range(spark_context, riak_client): + _test_spark_rdd_kv_read_partition_by_2i_range(10, spark_context, riak_client) @pytest.mark.riakkv -def test_kv_query_partition_by_2i_keys(spark_context, riak_client, sql_context): - _test_spark_rdd_kv_read_partition_by_2i_keys(10, spark_context, riak_client, sql_context) +def test_kv_query_partition_by_2i_keys(spark_context, riak_client): + _test_spark_rdd_kv_read_partition_by_2i_keys(10, spark_context, riak_client) # # if object values are JSON objects with more than 4 keys exception happens @@ -607,41 +608,41 @@ def test_read_JSON_value_with_an_empty_map (spark_context, riak_client): ###### TS Tests ####### @pytest.mark.riakts -def test_ts_df_write_use_timestamp(spark_context, riak_client, sql_context): - _test_spark_df_ts_write_use_timestamp(10, 5, spark_context, riak_client, sql_context) +def test_ts_df_write_use_timestamp(spark_session, spark_context, riak_client): + _test_spark_df_ts_write_use_timestamp(10, 5, spark_context, riak_client) @pytest.mark.riakts -def test_ts_df_write_use_long(spark_context, riak_client, sql_context): - _test_spark_df_ts_write_use_long(10, 5, spark_context, riak_client, sql_context) +def test_ts_df_write_use_long(spark_context, riak_client): + _test_spark_df_ts_write_use_long(10, 5, spark_context, riak_client) @pytest.mark.riakts -def test_ts_df_read_use_timestamp(spark_context, 
riak_client, sql_context): - _test_spark_df_ts_read_use_timestamp(10, 5, spark_context, riak_client, sql_context) +def test_ts_df_read_use_timestamp(spark_session, spark_context, riak_client): + _test_spark_df_ts_read_use_timestamp(10, 5, spark_session, spark_context, riak_client) @pytest.mark.riakts -def test_ts_df_read_use_long(spark_context, riak_client, sql_context): - _test_spark_df_ts_read_use_long(10, 5, spark_context, riak_client, sql_context) +def test_ts_df_read_use_long(spark_session, spark_context, riak_client): + _test_spark_df_ts_read_use_long(10, 5, spark_session, spark_context, riak_client) @pytest.mark.riakts -def test_ts_df_read_use_timestamp_ts_quantum(spark_context, riak_client, sql_context): - _test_spark_df_ts_read_use_timestamp_ts_quantum(10, 5, spark_context, riak_client, sql_context) +def test_ts_df_read_use_timestamp_ts_quantum(spark_session, spark_context, riak_client): + _test_spark_df_ts_read_use_timestamp_ts_quantum(10, 5, spark_session, spark_context, riak_client) @pytest.mark.riakts -def test_ts_df_read_use_long_ts_quantum(spark_context, riak_client, sql_context): - _test_spark_df_ts_read_use_long_ts_quantum(10, 5, spark_context, riak_client, sql_context) +def test_ts_df_read_use_long_ts_quantum(spark_session, spark_context, riak_client): + _test_spark_df_ts_read_use_long_ts_quantum(10, 5, spark_session, spark_context, riak_client) @pytest.mark.riakts -def test_ts_df_range_query_input_split_count_use_timestamp(spark_context, riak_client, sql_context): - _test_spark_df_ts_range_query_input_split_count_use_timestamp(10, 5, 3, spark_context, riak_client, sql_context) +def test_ts_df_range_query_input_split_count_use_timestamp(spark_session, spark_context, riak_client): + _test_spark_df_ts_range_query_input_split_count_use_timestamp(10, 5, 3, spark_session, spark_context, riak_client) @pytest.mark.riakts -def test_ts_df_range_query_input_split_count_use_long(spark_context, riak_client, sql_context): - _test_spark_df_ts_range_query_input_split_count_use_long(10, 5, 3, spark_context, riak_client, sql_context) +def test_ts_df_range_query_input_split_count_use_long(spark_session, spark_context, riak_client): + _test_spark_df_ts_range_query_input_split_count_use_long(10, 5, 3, spark_session, spark_context, riak_client) @pytest.mark.riakts -def test_ts_df_range_query_input_split_count_use_timestamp_ts_quantum(spark_context, riak_client, sql_context): - _test_spark_df_ts_range_query_input_split_count_use_timestamp_ts_quantum(10, 5, 3, spark_context, riak_client, sql_context) +def test_ts_df_range_query_input_split_count_use_timestamp_ts_quantum(spark_session, spark_context, riak_client): + _test_spark_df_ts_range_query_input_split_count_use_timestamp_ts_quantum(10, 5, 3, spark_session, spark_context, riak_client) @pytest.mark.riakts -def test_ts_df_range_query_input_split_count_use_long_ts_quantum(spark_context, riak_client, sql_context): - _test_spark_df_ts_range_query_input_split_count_use_long_ts_quantum(10, 5, 3, spark_context, riak_client, sql_context) +def test_ts_df_range_query_input_split_count_use_long_ts_quantum(spark_session, spark_context, riak_client): + _test_spark_df_ts_range_query_input_split_count_use_long_ts_quantum(10, 5, 3, spark_session, spark_context, riak_client) diff --git a/connector/src/main/scala/com/basho/riak/spark/query/KVDataQueryingIterator.scala b/connector/src/main/scala/com/basho/riak/spark/query/KVDataQueryingIterator.scala index 4592f0f4..50b0ffb9 100644 --- 
a/connector/src/main/scala/com/basho/riak/spark/query/KVDataQueryingIterator.scala +++ b/connector/src/main/scala/com/basho/riak/spark/query/KVDataQueryingIterator.scala @@ -18,7 +18,7 @@ package com.basho.riak.spark.query import com.basho.riak.client.core.query.{Location, RiakObject} -import org.apache.spark.Logging +import org.apache.spark.riak.Logging class KVDataQueryingIterator[T](query: Query[T]) extends Iterator[(Location, RiakObject)] with Logging { diff --git a/connector/src/main/scala/com/basho/riak/spark/query/Query.scala b/connector/src/main/scala/com/basho/riak/spark/query/Query.scala index 190404a4..bf1d23f8 100644 --- a/connector/src/main/scala/com/basho/riak/spark/query/Query.scala +++ b/connector/src/main/scala/com/basho/riak/spark/query/Query.scala @@ -24,7 +24,7 @@ import com.basho.riak.client.core.query.{Location, RiakObject} import com.basho.riak.client.core.util.HostAndPort import com.basho.riak.spark.rdd.connector.RiakConnector import com.basho.riak.spark.rdd.{BucketDef, ReadConf} -import org.apache.spark.Logging +import org.apache.spark.riak.Logging import scala.collection.JavaConversions._ import scala.collection.mutable.ArrayBuffer diff --git a/connector/src/main/scala/com/basho/riak/spark/query/TSDataQueryingIterator.scala b/connector/src/main/scala/com/basho/riak/spark/query/TSDataQueryingIterator.scala index f9150704..893dcdea 100644 --- a/connector/src/main/scala/com/basho/riak/spark/query/TSDataQueryingIterator.scala +++ b/connector/src/main/scala/com/basho/riak/spark/query/TSDataQueryingIterator.scala @@ -18,8 +18,8 @@ package com.basho.riak.spark.query import com.basho.riak.client.core.query.timeseries.Row -import org.apache.spark.Logging import com.basho.riak.client.core.query.timeseries.ColumnDescription +import org.apache.spark.riak.Logging class TSDataQueryingIterator(query: QueryTS) extends Iterator[Row] with Logging { diff --git a/connector/src/main/scala/com/basho/riak/spark/rdd/ReadConf.scala b/connector/src/main/scala/com/basho/riak/spark/rdd/ReadConf.scala index 982338b1..7136c435 100644 --- a/connector/src/main/scala/com/basho/riak/spark/rdd/ReadConf.scala +++ b/connector/src/main/scala/com/basho/riak/spark/rdd/ReadConf.scala @@ -36,7 +36,7 @@ case class ReadConf ( /** * Used only in ranged partitioner to identify quantized field. 
* Usage example: - * sqlContext.read + * sparkSession.read * .option("spark.riak.partitioning.ts-range-field-name", "time") * Providing this property automatically turns on RangedRiakTSPartitioner */ diff --git a/connector/src/main/scala/com/basho/riak/spark/rdd/RiakRDD.scala b/connector/src/main/scala/com/basho/riak/spark/rdd/RiakRDD.scala index dafdd29b..78fcb129 100644 --- a/connector/src/main/scala/com/basho/riak/spark/rdd/RiakRDD.scala +++ b/connector/src/main/scala/com/basho/riak/spark/rdd/RiakRDD.scala @@ -25,7 +25,8 @@ import com.basho.riak.spark.rdd.partitioner._ import com.basho.riak.spark.util.{CountingIterator, DataConvertingIterator} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD -import org.apache.spark.{Logging, Partition, SparkContext, TaskContext} +import org.apache.spark.riak.Logging +import org.apache.spark.{Partition, SparkContext, TaskContext} import scala.language.existentials import scala.reflect.ClassTag diff --git a/connector/src/main/scala/com/basho/riak/spark/rdd/RiakTSRDD.scala b/connector/src/main/scala/com/basho/riak/spark/rdd/RiakTSRDD.scala index 32d713f0..ed1d9e82 100644 --- a/connector/src/main/scala/com/basho/riak/spark/rdd/RiakTSRDD.scala +++ b/connector/src/main/scala/com/basho/riak/spark/rdd/RiakTSRDD.scala @@ -17,18 +17,20 @@ */ package com.basho.riak.spark.rdd -import com.basho.riak.client.core.query.timeseries.{ Row, ColumnDescription } -import com.basho.riak.spark.query.{ TSQueryData, QueryTS } +import com.basho.riak.client.core.query.timeseries.{ColumnDescription, Row} +import com.basho.riak.spark.query.{QueryTS, TSQueryData} import com.basho.riak.spark.rdd.connector.RiakConnector -import com.basho.riak.spark.rdd.partitioner.{ RiakTSPartition, RiakTSPartitioner } -import com.basho.riak.spark.util.{ TSConversionUtil, CountingIterator, DataConvertingIterator } +import com.basho.riak.spark.rdd.partitioner.{RiakTSPartition, RiakTSPartitioner} +import com.basho.riak.spark.util.{CountingIterator, DataConvertingIterator, TSConversionUtil} import org.apache.spark.sql.types.StructType -import org.apache.spark.{ TaskContext, Partition, Logging, SparkContext } +import org.apache.spark.{Partition, SparkContext, TaskContext} import org.apache.spark.rdd.RDD + import scala.reflect.ClassTag import org.apache.spark.sql.sources.Filter -import com.basho.riak.spark.rdd.partitioner.{ SinglePartitionRiakTSPartitioner, RangedRiakTSPartitioner } +import com.basho.riak.spark.rdd.partitioner.{RangedRiakTSPartitioner, SinglePartitionRiakTSPartitioner} import com.basho.riak.spark.query.TSDataQueryingIterator +import org.apache.spark.riak.Logging /** * @author Sergey Galkin diff --git a/connector/src/main/scala/com/basho/riak/spark/rdd/connector/RiakConnector.scala b/connector/src/main/scala/com/basho/riak/spark/rdd/connector/RiakConnector.scala index 68f3374f..eeea6431 100644 --- a/connector/src/main/scala/com/basho/riak/spark/rdd/connector/RiakConnector.scala +++ b/connector/src/main/scala/com/basho/riak/spark/rdd/connector/RiakConnector.scala @@ -19,7 +19,8 @@ package com.basho.riak.spark.rdd.connector import com.basho.riak.client.core.util.HostAndPort -import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.riak.Logging +import org.apache.spark.SparkConf /** * Provides and manages [[RiakSession]]. 
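The repeated import swaps from org.apache.spark.Logging to org.apache.spark.riak.Logging in the files above and below are needed because Spark 2.x moved its Logging trait to org.apache.spark.internal and made it private[spark]; the connector therefore defines a thin org.apache.spark.riak.Logging trait (added later in this diff) inside an org.apache.spark.* package to regain access. A sketch of how a connector class keeps its existing log calls, using a hypothetical class name:

```scala
package com.basho.riak.spark.query

import org.apache.spark.riak.Logging

// Hypothetical component for illustration: only the import changes, the call sites stay the same
class ExampleQueryComponent extends Logging {
  def run(): Unit = {
    logInfo("starting query")                      // provided by org.apache.spark.internal.Logging
    logDebug("query finished, collecting results")
  }
}
```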
diff --git a/connector/src/main/scala/com/basho/riak/spark/rdd/connector/RiakConnectorConf.scala b/connector/src/main/scala/com/basho/riak/spark/rdd/connector/RiakConnectorConf.scala index 75379f74..deb3f858 100644 --- a/connector/src/main/scala/com/basho/riak/spark/rdd/connector/RiakConnectorConf.scala +++ b/connector/src/main/scala/com/basho/riak/spark/rdd/connector/RiakConnectorConf.scala @@ -19,10 +19,11 @@ package com.basho.riak.spark.rdd.connector import java.net.InetAddress import com.basho.riak.client.core.util.HostAndPort -import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.SparkConf import scala.collection.JavaConversions._ import scala.util.control.NonFatal import com.basho.riak.client.core.RiakNode +import org.apache.spark.riak.Logging /** Stores configuration of a connection to Riak. diff --git a/connector/src/main/scala/com/basho/riak/spark/rdd/connector/SessionCache.scala b/connector/src/main/scala/com/basho/riak/spark/rdd/connector/SessionCache.scala index 8e2c3c25..f117601a 100644 --- a/connector/src/main/scala/com/basho/riak/spark/rdd/connector/SessionCache.scala +++ b/connector/src/main/scala/com/basho/riak/spark/rdd/connector/SessionCache.scala @@ -25,7 +25,6 @@ import com.basho.riak.client.api.{RiakClient, RiakCommand} import com.basho.riak.client.core.{FutureOperation, RiakCluster, RiakFuture, RiakNode} import com.basho.riak.client.core.util.HostAndPort import com.google.common.cache._ -import org.apache.spark.Logging import scala.collection.JavaConverters._ import java.util.concurrent.{Executors, ScheduledThreadPoolExecutor, ThreadFactory, TimeUnit} @@ -34,6 +33,7 @@ import io.netty.bootstrap.Bootstrap import io.netty.channel.nio.NioEventLoopGroup import io.netty.channel.socket.nio.NioSocketChannel import io.netty.util.concurrent.DefaultThreadFactory +import org.apache.spark.riak.Logging import scala.collection.concurrent.TrieMap diff --git a/connector/src/main/scala/com/basho/riak/spark/rdd/partitioner/RiakCoveragePlanBasedPartitioner.scala b/connector/src/main/scala/com/basho/riak/spark/rdd/partitioner/RiakCoveragePlanBasedPartitioner.scala index d4da344e..f7324ce2 100644 --- a/connector/src/main/scala/com/basho/riak/spark/rdd/partitioner/RiakCoveragePlanBasedPartitioner.scala +++ b/connector/src/main/scala/com/basho/riak/spark/rdd/partitioner/RiakCoveragePlanBasedPartitioner.scala @@ -24,7 +24,8 @@ import com.basho.riak.spark.query.QueryData import com.basho.riak.spark.rdd.connector.RiakConnector import com.basho.riak.spark.rdd.partitioner.PartitioningUtils._ import com.basho.riak.spark.rdd.{BucketDef, ReadConf, RiakPartition} -import org.apache.spark.{Logging, Partition} +import org.apache.spark.Partition +import org.apache.spark.riak.Logging import scala.collection.JavaConversions._ import scala.util.control.Exception._ @@ -35,6 +36,9 @@ case class RiakLocalCoveragePartition[K]( primaryHost: HostAndPort, queryData: QueryData[K]) extends RiakPartition +/** + * Obtains Coverage Plan and creates a separate partition for each Coverage Entry + */ object RiakCoveragePlanBasedPartitioner extends Logging { def partitions[K](connector: RiakConnector, bucket: BucketDef, readConf: ReadConf, queryData: QueryData[K]): Array[Partition] = { diff --git a/connector/src/main/scala/com/basho/riak/spark/rdd/partitioner/RiakTSPartitioner.scala b/connector/src/main/scala/com/basho/riak/spark/rdd/partitioner/RiakTSPartitioner.scala index 98617c48..bec5c8e9 100644 --- a/connector/src/main/scala/com/basho/riak/spark/rdd/partitioner/RiakTSPartitioner.scala +++ 
b/connector/src/main/scala/com/basho/riak/spark/rdd/partitioner/RiakTSPartitioner.scala @@ -19,7 +19,7 @@ package com.basho.riak.spark.rdd.partitioner import java.sql.Timestamp -import org.apache.spark.{Logging, Partition} +import org.apache.spark.Partition import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.StructType import com.basho.riak.client.core.util.HostAndPort @@ -34,6 +34,7 @@ import scala.collection.JavaConversions._ import scala.util.control.Exception._ import com.basho.riak.client.core.query.timeseries.CoverageEntry import com.basho.riak.spark.util.DumpUtils +import org.apache.spark.riak.Logging /** * @author Sergey Galkin @@ -84,7 +85,7 @@ trait RiakTSPartitioner { } recursiveInterpolateFirst(sql, values.iterator) } - + //scalastyle:off protected def toSql(columnNames: Option[Seq[String]], tableName: String, schema: Option[StructType], whereConstraints: (String, Seq[Any])): (String, Seq[Any]) = { var values: Seq[Any] = Seq.empty[Nothing] val sql = "SELECT " + @@ -111,26 +112,32 @@ trait RiakTSPartitioner { } /** Construct Sql clause */ - protected def filterToSqlAndValue(filter: Any): (String, Any) = { + protected def filterToSqlAndValue(filter: Any): (String, Option[Any]) = { val (attribute, sqlOperator, value) = filter match { case EqualTo(a, v) => (a, "=", v) case LessThan(a, v) => (a, "<", v) case LessThanOrEqual(a, v) => (a, "<=", v) case GreaterThan(a, v) => (a, ">", v) case GreaterThanOrEqual(a, v) => (a, ">=", v) + case IsNotNull(a) => (a, "IS NOT NULL", None) + case IsNull(a) => (a, "IS NULL", None) case _ => throw new UnsupportedOperationException( - s"It's not a valid filter $filter to be pushed down, only >, <, >=, <= and = are allowed.") + s"It's not a valid filter $filter to be pushed down, only is not null, is null, >, <, >=, <= and = are allowed.") } // TODO: need to add pattern matching for values, to be sure that they are used correctly - (s"$attribute $sqlOperator ?", value) - } + value match { + case None => (s"$attribute $sqlOperator", None) + case _ => (s"$attribute $sqlOperator ?", Some(value)) + } + } + //scalastyle:on protected def whereClause(filters: Array[Filter]): (String, Seq[Any]) = { val sqlValue = filters.map(filterToSqlAndValue) val sql = sqlValue.map(_._1).mkString(" AND ") - val args = sqlValue.map(_._2) + val args = sqlValue.flatMap(_._2) // Changed to use flatMap to remove None arguments (sql, args.seq) } } diff --git a/connector/src/main/scala/com/basho/riak/spark/util/DataMapper.scala b/connector/src/main/scala/com/basho/riak/spark/util/DataMapper.scala index 8023be89..daf863d0 100644 --- a/connector/src/main/scala/com/basho/riak/spark/util/DataMapper.scala +++ b/connector/src/main/scala/com/basho/riak/spark/util/DataMapper.scala @@ -19,7 +19,7 @@ package com.basho.riak.spark.util import com.basho.riak.client.api.convert.JSONConverter import com.fasterxml.jackson.module.scala.DefaultScalaModule -import org.apache.spark.Logging +import org.apache.spark.riak.Logging trait DataMapper extends Serializable { DataMapper.ensureInitialized() diff --git a/connector/src/main/scala/com/basho/riak/spark/writer/RiakWriter.scala b/connector/src/main/scala/com/basho/riak/spark/writer/RiakWriter.scala index 12697d6e..03d95d63 100644 --- a/connector/src/main/scala/com/basho/riak/spark/writer/RiakWriter.scala +++ b/connector/src/main/scala/com/basho/riak/spark/writer/RiakWriter.scala @@ -31,8 +31,8 @@ import com.basho.riak.spark.rdd.connector.{RiakConnector, RiakSession} import com.basho.riak.spark.util.{CountingIterator, DataMapper} 
import com.basho.riak.spark.writer.ts.RowDef import com.fasterxml.jackson.module.scala.DefaultScalaModule -import org.apache.spark.riak.RiakWriterTaskCompletionListener -import org.apache.spark.{Logging, TaskContext} +import org.apache.spark.riak.Logging +import org.apache.spark.TaskContext import scala.collection.JavaConversions._ import scala.collection._ @@ -78,7 +78,6 @@ abstract class RiakWriter[T, U]( val endTime = System.currentTimeMillis() val duration = (endTime - startTime) / 1000.0 logDebug(s"Writing FINISHED in $duration seconds") - taskContext.addTaskCompletionListener(RiakWriterTaskCompletionListener(rowIterator.count)) } } diff --git a/connector/src/main/scala/org/apache/spark/riak/Logging.scala b/connector/src/main/scala/org/apache/spark/riak/Logging.scala new file mode 100644 index 00000000..c76101fa --- /dev/null +++ b/connector/src/main/scala/org/apache/spark/riak/Logging.scala @@ -0,0 +1,22 @@ +/** + * Copyright (c) 2015 Basho Technologies, Inc. + * + * This file is provided to you under the Apache License, + * Version 2.0 (the "License"); you may not use this file + * except in compliance with the License. You may obtain + * a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.spark.riak + +trait Logging extends org.apache.spark.internal.Logging { + +} diff --git a/connector/src/main/scala/org/apache/spark/riak/RiakWriterTaskCompletionListener.scala b/connector/src/main/scala/org/apache/spark/riak/RiakWriterTaskCompletionListener.scala deleted file mode 100644 index b217159e..00000000 --- a/connector/src/main/scala/org/apache/spark/riak/RiakWriterTaskCompletionListener.scala +++ /dev/null @@ -1,19 +0,0 @@ -package org.apache.spark.riak - -import org.apache.spark.TaskContext -import org.apache.spark.executor.{DataWriteMethod, OutputMetrics} -import org.apache.spark.util.TaskCompletionListener - -class RiakWriterTaskCompletionListener(recordsWritten: Long) extends TaskCompletionListener{ - - override def onTaskCompletion(context: TaskContext): Unit = { - val metrics = OutputMetrics(DataWriteMethod.Hadoop) - metrics.setRecordsWritten(recordsWritten) - context.taskMetrics().outputMetrics = Some(metrics) - } - -} - -object RiakWriterTaskCompletionListener { - def apply(recordsWritten: Long) = new RiakWriterTaskCompletionListener(recordsWritten) -} \ No newline at end of file diff --git a/connector/src/main/scala/org/apache/spark/sql/riak/RiakCatalog.scala b/connector/src/main/scala/org/apache/spark/sql/riak/RiakCatalog.scala deleted file mode 100644 index fd2c74ea..00000000 --- a/connector/src/main/scala/org/apache/spark/sql/riak/RiakCatalog.scala +++ /dev/null @@ -1,123 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2016 IBM Corp. - * - * Created by Basho Technologies for IBM - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - *******************************************************************************/ -package org.apache.spark.sql.riak - -import java.util.concurrent.ExecutionException - -import com.basho.riak.client.core.netty.RiakResponseException -import com.basho.riak.client.core.operations.FetchBucketPropsOperation -import com.basho.riak.client.core.query.Namespace -import com.basho.riak.spark.rdd.ReadConf -import com.basho.riak.spark.rdd.connector.RiakConnector -import com.basho.riak.spark.writer.WriteConf -import com.google.common.cache.{CacheBuilder, CacheLoader, LoadingCache} -import org.apache.spark.Logging -import org.apache.spark.sql.catalyst.{TableIdentifier, SimpleCatalystConf} -import org.apache.spark.sql.catalyst.analysis.Catalog -import org.apache.spark.sql.catalyst.plans.logical.{Subquery, LogicalPlan} -import org.apache.spark.sql.execution.datasources.LogicalRelation - -/** - * @author Sergey Galkin - * @since 1.2.0 - */ -private[sql] class RiakCatalog(rsc: RiakSQLContext, - riakConnector: RiakConnector, - readConf: ReadConf, - writeConf: WriteConf) extends Catalog with Logging { - private val CACHE_SIZE = 1000 - - /** A cache of Spark SQL data source tables that have been accessed. Cache is thread safe. */ - private[riak] val cachedDataSourceTables: LoadingCache[String, LogicalPlan] = { - val cacheLoader = new CacheLoader[String, LogicalPlan]() { - override def load(tableIdent: String): LogicalPlan = { - logDebug(s"Creating new cached data source for ${tableIdent.mkString(".")}") - buildRelation(tableIdent) - } - } - CacheBuilder.newBuilder().maximumSize(CACHE_SIZE).build(cacheLoader) - } - - override def refreshTable(tableIdent: TableIdentifier): Unit = { - val table = bucketIdent(tableIdent) - cachedDataSourceTables.refresh(table) - } - - override val conf: SimpleCatalystConf = SimpleCatalystConf(true) - - - override def unregisterAllTables(): Unit = { - cachedDataSourceTables.invalidateAll() - } - - override def unregisterTable(tableIdentifier: TableIdentifier): Unit = { - val tableIdent = bucketIdent(tableIdentifier) - cachedDataSourceTables.invalidate(tableIdent) - } - - override def lookupRelation(tableIdentifier: TableIdentifier, alias: Option[String]): LogicalPlan = { - val tableIdent = bucketIdent(tableIdentifier) - val tableLogicPlan = cachedDataSourceTables.get(tableIdent) - alias.map(a => Subquery(a, tableLogicPlan)).getOrElse(tableLogicPlan) - } - - override def registerTable(tableIdentifier: TableIdentifier, plan: LogicalPlan): Unit = { - val tableIdent = bucketIdent(tableIdentifier) - cachedDataSourceTables.put(tableIdent, plan) - } - - override def getTables(databaseName: Option[String]): Seq[(String, Boolean)] = { - getTablesFromRiakTS(databaseName) - } - - override def tableExists(tableIdentifier: TableIdentifier): Boolean = { - val tableIdent = bucketIdent(tableIdentifier) - val fetchProps = new FetchBucketPropsOperation.Builder(new Namespace(tableIdent, tableIdent)).build() - - riakConnector.withSessionDo(session => { - session.execute(fetchProps) - }) - - try { - fetchProps.get().getBucketProperties - true - } catch { - case ex: ExecutionException 
if ex.getCause.isInstanceOf[RiakResponseException] - && ex.getCause.getMessage.startsWith("No bucket-type named") => - false - } - } - - def getTablesFromRiakTS(databaseName: Option[String]): Seq[(String, Boolean)] = Nil - - /** Build logic plan from a RiakRelation */ - private def buildRelation(tableIdent: String): LogicalPlan = { - val relation = RiakRelation(tableIdent, rsc, None, Some(riakConnector), readConf, writeConf) - Subquery(tableIdent, LogicalRelation(relation)) - } - - /** Return a table identifier with table name, keyspace name and cluster name */ - private def bucketIdent(tableIdentifier: Seq[String]): String = { - require(tableIdentifier.size == 1) - tableIdentifier.head - } - - private def bucketIdent(tableIdentifier: TableIdentifier): String = { - tableIdentifier.table - } -} diff --git a/connector/src/main/scala/org/apache/spark/sql/riak/RiakRelation.scala b/connector/src/main/scala/org/apache/spark/sql/riak/RiakRelation.scala index 64dfd5eb..c9a0ad6d 100644 --- a/connector/src/main/scala/org/apache/spark/sql/riak/RiakRelation.scala +++ b/connector/src/main/scala/org/apache/spark/sql/riak/RiakRelation.scala @@ -18,19 +18,19 @@ package org.apache.spark.sql.riak import com.basho.riak.spark._ + import scala.reflect._ -import com.basho.riak.spark.rdd.connector.{RiakConnectorConf, RiakConnector} +import com.basho.riak.spark.rdd.connector.{RiakConnector, RiakConnectorConf} import com.basho.riak.spark.rdd.{ReadConf, RiakTSRDD} -import com.basho.riak.spark.util.TSConversionUtil import com.basho.riak.spark.writer.WriteConf import com.basho.riak.spark.writer.mapper.SqlDataMapper -import org.apache.spark.Logging import org.apache.spark.rdd.RDD -import org.apache.spark.sql.sources.{InsertableRelation, BaseRelation, Filter, PrunedFilteredScan} +import org.apache.spark.sql.sources.{BaseRelation, Filter, InsertableRelation, PrunedFilteredScan} import org.apache.spark.sql.types._ import org.apache.spark.sql._ -import scala.collection.convert.decorateAsScala._ + import com.basho.riak.spark.query.QueryBucketDef +import org.apache.spark.riak.Logging /** * Implements [[BaseRelation]]]], [[InsertableRelation]]]] and [[PrunedFilteredScan]]]] diff --git a/connector/src/main/scala/org/apache/spark/sql/riak/RiakSQLContext.scala b/connector/src/main/scala/org/apache/spark/sql/riak/RiakSQLContext.scala deleted file mode 100644 index bbd08ed8..00000000 --- a/connector/src/main/scala/org/apache/spark/sql/riak/RiakSQLContext.scala +++ /dev/null @@ -1,46 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2016 IBM Corp. - * - * Created by Basho Technologies for IBM - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- *******************************************************************************/ -package org.apache.spark.sql.riak - -import com.basho.riak.spark.rdd.ReadConf -import com.basho.riak.spark.rdd.connector.RiakConnector -import com.basho.riak.spark.writer.WriteConf -import org.apache.spark.SparkContext -import org.apache.spark.sql.{DataFrame, SQLContext} - -/** - * Allows to execute SQL queries against Riak TS. - * Predicate pushdown is supported. - * - * @author Sergey Galkin - * @since 1.2.0 - */ -class RiakSQLContext(sc: SparkContext) extends SQLContext(sc) { - - /** A catalyst metadata catalog that points to Riak. */ - @transient - override protected[sql] lazy val catalog = new RiakCatalog( - this, RiakConnector(sc.getConf), ReadConf(sc.getConf), WriteConf(sc.getConf)) - - /** Executes SQL query against Riak TS and returns DataFrame representing the result. */ - def riakTsSql(tsQuery: String): DataFrame = new DataFrame(this, super.parseSql(tsQuery)) - - /** Delegates to [[riakTsSql]] */ - override def sql(tsQuery: String): DataFrame = riakTsSql(tsQuery) - -} \ No newline at end of file diff --git a/connector/src/test/java/com/basho/riak/spark/rdd/AbstractJavaSparkTest.java b/connector/src/test/java/com/basho/riak/spark/rdd/AbstractJavaSparkTest.java index b8237690..2fab8a77 100644 --- a/connector/src/test/java/com/basho/riak/spark/rdd/AbstractJavaSparkTest.java +++ b/connector/src/test/java/com/basho/riak/spark/rdd/AbstractJavaSparkTest.java @@ -32,14 +32,7 @@ public abstract class AbstractJavaSparkTest extends AbstractRiakSparkTest { // JavaSparkContext, created per test case - protected JavaSparkContext jsc = null; - - @Override - public SparkContext createSparkContext(SparkConf conf) { - final SparkContext sc = new SparkContext(conf); - jsc = new JavaSparkContext(sc); - return sc; - } + protected JavaSparkContext jsc = new JavaSparkContext(sparkSession().sparkContext()); protected static class FuncReMapWithPartitionIdx implements Function2, Iterator>> { @Override diff --git a/connector/src/test/java/com/basho/riak/spark/rdd/timeseries/AbstractJavaTimeSeriesTest.java b/connector/src/test/java/com/basho/riak/spark/rdd/timeseries/AbstractJavaTimeSeriesTest.java index ec5c426c..b6540211 100644 --- a/connector/src/test/java/com/basho/riak/spark/rdd/timeseries/AbstractJavaTimeSeriesTest.java +++ b/connector/src/test/java/com/basho/riak/spark/rdd/timeseries/AbstractJavaTimeSeriesTest.java @@ -8,19 +8,12 @@ public abstract class AbstractJavaTimeSeriesTest extends AbstractTimeSeriesTest { // JavaSparkContext, created per test case - protected JavaSparkContext jsc = null; + protected JavaSparkContext jsc = new JavaSparkContext(sparkSession().sparkContext()); public AbstractJavaTimeSeriesTest(boolean createTestDate) { super(createTestDate); } - @Override - public SparkContext createSparkContext(SparkConf conf) { - final SparkContext sc = new SparkContext(conf); - jsc = new JavaSparkContext(sc); - return sc; - } - protected String stringify(String[] strings) { return "[" + StringUtils.join(strings, ",") + "]"; } diff --git a/connector/src/test/java/com/basho/riak/spark/rdd/timeseries/TimeSeriesJavaReadTest.java b/connector/src/test/java/com/basho/riak/spark/rdd/timeseries/TimeSeriesJavaReadTest.java index 948f8c57..e41eacfa 100644 --- a/connector/src/test/java/com/basho/riak/spark/rdd/timeseries/TimeSeriesJavaReadTest.java +++ b/connector/src/test/java/com/basho/riak/spark/rdd/timeseries/TimeSeriesJavaReadTest.java @@ -19,9 +19,8 @@ import com.basho.riak.spark.rdd.RiakTSTests; import 
org.apache.spark.api.java.JavaRDD; -import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; -import org.apache.spark.sql.SQLContext; import org.apache.spark.sql.api.java.UDF1; import org.apache.spark.sql.functions; import org.apache.spark.sql.types.DataTypes; @@ -64,17 +63,16 @@ public void readDataAsSqlRow() { @Test public void riakTSRDDToDataFrame() { - SQLContext sqlContext = new SQLContext(jsc); JavaRDD rows = javaFunctions(jsc) .riakTSTable(bucketName(), Row.class) .sql(String.format("SELECT time, user_id, temperature_k FROM %s %s", bucketName(), sqlWhereClause())) .map(r -> new TimeSeriesDataBean(r.getTimestamp(0).getTime(), r.getString(1), r.getDouble(2))); - DataFrame df = sqlContext.createDataFrame(rows, TimeSeriesDataBean.class); - df.registerTempTable("test"); + Dataset df = sparkSession().createDataFrame(rows, TimeSeriesDataBean.class); + df.createOrReplaceTempView("test"); // Explicit cast due to compilation error "Object cannot be converted to java.lang.String[]" - String[] data = (String[]) sqlContext.sql("select * from test").toJSON().collect(); + String[] data = (String[]) sparkSession().sql("select * from test").toJSON().collect(); assertEqualsUsingJSONIgnoreOrder("[" + "{time:111111, user_id:'bryce', temperature_k:305.37}," + "{time:111222, user_id:'bryce', temperature_k:300.12}," + @@ -92,17 +90,16 @@ public void riakTSRDDToDataFrameConvertTimestamp() { DataTypes.createStructField("temperature_k", DataTypes.DoubleType, true), }); - SQLContext sqlContext = new SQLContext(jsc); JavaRDD rows = javaFunctions(jsc) .riakTSTable(bucketName(), structType, Row.class) .sql(String.format("SELECT time, user_id, temperature_k FROM %s %s", bucketName(), sqlWhereClause())) .map(r -> new TimeSeriesDataBean(r.getLong(0), r.getString(1), r.getDouble(2))); - DataFrame df = sqlContext.createDataFrame(rows, TimeSeriesDataBean.class); - df.registerTempTable("test"); + Dataset df = sparkSession().createDataFrame(rows, TimeSeriesDataBean.class); + df.createOrReplaceTempView("test"); // Explicit cast due to compilation error "Object cannot be converted to java.lang.String[]" - String[] data = (String[]) sqlContext.sql("select * from test").toJSON().collect(); + String[] data = (String[]) sparkSession().sql("select * from test").toJSON().collect(); assertEqualsUsingJSONIgnoreOrder("[" + "{time:111111, user_id:'bryce', temperature_k:305.37}," + "{time:111222, user_id:'bryce', temperature_k:300.12}," + @@ -114,11 +111,9 @@ public void riakTSRDDToDataFrameConvertTimestamp() { @Test public void dataFrameGenericLoad() { - SQLContext sqlContext = new SQLContext(jsc); + sparkSession().udf().register("getMillis", (UDF1) Timestamp::getTime, DataTypes.LongType); - sqlContext.udf().register("getMillis", (UDF1) Timestamp::getTime, DataTypes.LongType); - - DataFrame df = sqlContext.read() + Dataset df = sparkSession().read() .format("org.apache.spark.sql.riak") .schema(schema()) .load(bucketName()) @@ -138,8 +133,6 @@ public void dataFrameGenericLoad() { @Test public void dataFrameReadShouldConvertTimestampToLong() { - SQLContext sqlContext = new SQLContext(jsc); - StructType structType = new StructType(new StructField[]{ DataTypes.createStructField("surrogate_key", DataTypes.LongType, true), DataTypes.createStructField("family", DataTypes.StringType, true), @@ -148,7 +141,7 @@ public void dataFrameReadShouldConvertTimestampToLong() { DataTypes.createStructField("temperature_k", DataTypes.DoubleType, true), }); - DataFrame df = sqlContext.read() + Dataset 
df = sparkSession().read() .option("spark.riak.partitioning.ts-range-field-name", "time") .format("org.apache.spark.sql.riak") .schema(structType) @@ -169,9 +162,7 @@ public void dataFrameReadShouldConvertTimestampToLong() { @Test public void dataFrameReadShouldHandleTimestampAsLong() { - SQLContext sqlContext = new SQLContext(jsc); - - DataFrame df = sqlContext.read() + Dataset df = sparkSession().read() .format("org.apache.spark.sql.riak") .option("spark.riakts.bindings.timestamp", "useLong") .option("spark.riak.partitioning.ts-range-field-name", "time") diff --git a/connector/src/test/java/com/basho/riak/spark/rdd/timeseries/TimeSeriesJavaWriteTest.java b/connector/src/test/java/com/basho/riak/spark/rdd/timeseries/TimeSeriesJavaWriteTest.java index f95606e2..08b24e45 100644 --- a/connector/src/test/java/com/basho/riak/spark/rdd/timeseries/TimeSeriesJavaWriteTest.java +++ b/connector/src/test/java/com/basho/riak/spark/rdd/timeseries/TimeSeriesJavaWriteTest.java @@ -20,10 +20,7 @@ import com.basho.riak.spark.japi.rdd.RiakTSJavaRDD; import com.basho.riak.spark.rdd.RiakTSTests; import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.Row$; -import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.*; import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.types.StructType$; import org.junit.Test; @@ -74,8 +71,6 @@ public void saveSqlRowsToRiak() { @Test public void saveDataFrameWithSchemaToRiak() { - SQLContext sqlContext = new SQLContext(jsc); - JavaRDD jsonRdd = jsc.parallelize(asList( "{\"surrogate_key\": 1, \"family\": \"f\", \"time\": 111111, \"user_id\": \"bryce\", \"temperature_k\": 305.37}", "{\"surrogate_key\": 1, \"family\": \"f\", \"time\": 111222, \"user_id\": \"bryce\", \"temperature_k\": 300.12}", @@ -84,7 +79,7 @@ public void saveDataFrameWithSchemaToRiak() { "{\"surrogate_key\": 1, \"family\": \"f\", \"time\": 111555, \"user_id\": \"ratman\", \"temperature_k\": 3502.212}" )); - DataFrame df = sqlContext.read().schema(StructType$.MODULE$.apply(asScalaBuffer(asList( + Dataset df = sparkSession().read().schema(StructType$.MODULE$.apply(asScalaBuffer(asList( DataTypes.createStructField("surrogate_key", DataTypes.IntegerType, true), DataTypes.createStructField("family", DataTypes.StringType, true), DataTypes.createStructField("time", DataTypes.LongType, true), diff --git a/connector/src/test/scala/com/basho/riak/spark/rdd/AbstractRiakSparkTest.scala b/connector/src/test/scala/com/basho/riak/spark/rdd/AbstractRiakSparkTest.scala index 01ca392f..d5d6fb86 100644 --- a/connector/src/test/scala/com/basho/riak/spark/rdd/AbstractRiakSparkTest.scala +++ b/connector/src/test/scala/com/basho/riak/spark/rdd/AbstractRiakSparkTest.scala @@ -31,6 +31,7 @@ import scala.reflect.ClassTag import com.basho.riak.spark.rdd.AbstractRiakSparkTest._ import com.basho.riak.spark.rdd.mapper.ReadValueDataMapper import org.apache.spark.SparkConf +import org.apache.spark.sql.SparkSession import org.junit.ClassRule import scala.collection.JavaConversions._ @@ -38,6 +39,7 @@ import scala.collection.JavaConversions._ abstract class AbstractRiakSparkTest extends AbstractRiakTest { // SparkContext, created per test case + protected val sparkSession: SparkSession = createSparkSession(initSparkConf()) protected var sc: SparkContext = _ protected override def riakHosts: Set[HostAndPort] = HostAndPort.hostsFromString( @@ -55,10 +57,10 @@ abstract class AbstractRiakSparkTest extends AbstractRiakTest { 
override def initialize(): Unit = { super.initialize() - sc = createSparkContext(initSparkConf()) + sc = sparkSession.sparkContext } - protected def createSparkContext(conf: SparkConf): SparkContext = new SparkContext(conf) + protected def createSparkSession(conf: SparkConf): SparkSession = SparkSession.builder().config(conf).getOrCreate() @After def destroySparkContext(): Unit = Option(sc).foreach(x => x.stop()) diff --git a/connector/src/test/scala/com/basho/riak/spark/rdd/Filters2sqlConversionTest.scala b/connector/src/test/scala/com/basho/riak/spark/rdd/Filters2sqlConversionTest.scala new file mode 100644 index 00000000..6fafc35b --- /dev/null +++ b/connector/src/test/scala/com/basho/riak/spark/rdd/Filters2sqlConversionTest.scala @@ -0,0 +1,136 @@ +package com.basho.riak.spark.rdd + +import com.basho.riak.spark.rdd.connector.RiakConnector +import com.basho.riak.spark.rdd.partitioner.RiakTSCoveragePlanBasedPartitioner +import org.apache.spark.sql.sources._ +import org.junit.{Rule, Test} +import org.junit.runner.RunWith +import org.junit.Assert._ +import org.junit.rules.ExpectedException +import org.mockito.Mock +import org.mockito.runners.MockitoJUnitRunner + +@RunWith(classOf[MockitoJUnitRunner]) +class Filters2sqlConversionTest { + + @Mock + private val rc: RiakConnector = null + + val _expectedException: ExpectedException = ExpectedException.none() + + @Rule + def expectedException: ExpectedException = _expectedException + + private val bucketName = "test" + + private val equalTo = EqualTo("field1", "value1") + private val lessThan = LessThan("field2", 2) + private val lessThanOrEqual = LessThanOrEqual("field3", 3) + private val greaterThan = GreaterThan("field4", 4) + private val greaterThanOrEqual = GreaterThanOrEqual("field5", 5) + private val isNotNull = IsNotNull("field6") + private val isNull = IsNull("field7") + + // Unsupported filters + private val equalNullSafe = EqualNullSafe("field", "value") + private val in = In("field", Array("val0","val1")) + private val and = And(equalTo, lessThan) + private val or = Or(equalTo,lessThan) + private val not = Not(equalTo) + private val stringStartsWith = StringStartsWith("field","value") + private val stringEndsWith = StringEndsWith("field","value") + private val stringContains = StringContains("field","value") + + private def verifyFilters(expectedSql: String, filters: Filter*) = { + val partitioner = new RiakTSCoveragePlanBasedPartitioner(rc, bucketName, None, None, filters.toArray, new ReadConf()) + assertEquals(expectedSql, partitioner.query) + } + + private def verifyUnsupportedFilter(filter: Filter, expectedFilter: String) = { + expectedException.expect(classOf[UnsupportedOperationException]) + expectedException.expectMessage(s"It's not a valid filter $expectedFilter " + + s"to be pushed down, only is not null, is null, >, <, >=, <= and = are allowed") + new RiakTSCoveragePlanBasedPartitioner(rc, bucketName, None, None, Array(filter), new ReadConf()) + } + + @Test + def testEqualToConversion(): Unit = { + verifyFilters("SELECT * FROM test WHERE field1 = 'value1'", equalTo) + } + + @Test + def testLessThanConversion(): Unit = { + verifyFilters("SELECT * FROM test WHERE field2 < 2", lessThan) + } + + @Test + def testLessThanOrEqualConversion(): Unit = { + verifyFilters("SELECT * FROM test WHERE field3 <= 3", lessThanOrEqual) + } + + @Test + def testGreaterThanConversion(): Unit = { + verifyFilters("SELECT * FROM test WHERE field4 > 4", greaterThan) + } + + @Test + def testGreaterThanOrEqualConversion(): Unit = { + 
verifyFilters("SELECT * FROM test WHERE field5 >= 5", greaterThanOrEqual) + } + + @Test + def testIsNotNullConversion(): Unit = { + verifyFilters("SELECT * FROM test WHERE field6 IS NOT NULL", isNotNull) + } + + @Test + def testIsNullConversion(): Unit = { + verifyFilters("SELECT * FROM test WHERE field7 IS NULL", isNull) + } + + @Test + def testMultipleFiltersConversion(): Unit = { + verifyFilters("SELECT * FROM test WHERE field1 = 'value1' AND field2 < 2 AND field6 IS NOT NULL", + equalTo,lessThan, isNotNull) + } + + @Test + def testUnsuportedFiltersEqualNullSafeConversion(): Unit = { + verifyUnsupportedFilter(equalNullSafe, "EqualNullSafe(field,value)") + } + + @Test + def testUnsuportedFiltersInConversion(): Unit = { + verifyUnsupportedFilter(in, "In(field, [val0,val1]") + } + + @Test + def testUnsuportedFiltersAndConversion(): Unit = { + verifyUnsupportedFilter(and, "And(EqualTo(field1,value1),LessThan(field2,2))") + } + + @Test + def testUnsuportedFiltersOrConversion(): Unit = { + verifyUnsupportedFilter(or, "Or(EqualTo(field1,value1),LessThan(field2,2))") + } + + @Test + def testUnsuportedFiltersNotConversion(): Unit = { + verifyUnsupportedFilter(not, "Not(EqualTo(field1,value1))") + } + + @Test + def testUnsuportedFiltersStringStartsWithConversion(): Unit = { + verifyUnsupportedFilter(stringStartsWith, "StringStartsWith(field,value)") + } + + @Test + def testUnsuportedFiltersStringEndsWithConversion(): Unit = { + verifyUnsupportedFilter(stringEndsWith, "StringEndsWith(field,value)") + } + + @Test + def testUnsuportedFiltersStringContainsConversion(): Unit = { + verifyUnsupportedFilter(stringContains, "StringContains(field,value)") + } +} \ No newline at end of file diff --git a/connector/src/test/scala/com/basho/riak/spark/rdd/SparkDataframesTest.scala b/connector/src/test/scala/com/basho/riak/spark/rdd/SparkDataframesTest.scala index 04fc8084..c706069b 100644 --- a/connector/src/test/scala/com/basho/riak/spark/rdd/SparkDataframesTest.scala +++ b/connector/src/test/scala/com/basho/riak/spark/rdd/SparkDataframesTest.scala @@ -19,7 +19,7 @@ package com.basho.riak.spark.rdd import scala.reflect.runtime.universe import org.apache.spark.sql.DataFrame -import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.SparkSession import org.junit.Assert._ import org.junit.{ Before, Test } import com.basho.riak.spark.toSparkContextFunctions @@ -44,17 +44,15 @@ class SparkDataframesTest extends AbstractRiakSparkTest { protected override def initSparkConf() = super.initSparkConf().setAppName("Dataframes Test") - var sqlContextHolder: SQLContext = _ var df: DataFrame = _ @Before def initializeDF(): Unit = { - val sqlContext = new org.apache.spark.sql.SQLContext(sc) - import sqlContext.implicits._ - sqlContextHolder = sqlContext + val spark = sparkSession + import spark.implicits._ df = sc.riakBucket[TestData](DEFAULT_NAMESPACE.getBucketNameAsString) .queryAll().toDF - df.registerTempTable("test") + df.createTempView("test") } @Test @@ -67,7 +65,7 @@ class SparkDataframesTest extends AbstractRiakSparkTest { @Test def sqlQueryTest(): Unit = { - val sqlResult = sqlContextHolder.sql("select * from test where category >= 'CategoryC'").toJSON.collect + val sqlResult = sparkSession.sql("select * from test where category >= 'CategoryC'").toJSON.collect val expected = """ [ | {id:'u4',name:'Chris',age:10,category:'CategoryC'}, @@ -78,8 +76,8 @@ class SparkDataframesTest extends AbstractRiakSparkTest { @Test def udfTest(): Unit = { - sqlContextHolder.udf.register("stringLength", (s: String) => 
s.length) - val udf = sqlContextHolder.sql("select name, stringLength(name) strLgth from test order by strLgth, name").toJSON.collect + sparkSession.udf.register("stringLength", (s: String) => s.length) + val udf = sparkSession.sql("select name, stringLength(name) strLgth from test order by strLgth, name").toJSON.collect val expected = """ [ | {name:'Ben',strLgth:3}, @@ -107,7 +105,7 @@ class SparkDataframesTest extends AbstractRiakSparkTest { @Test def sqlVsFilterTest(): Unit = { - val sql = sqlContextHolder.sql("select id, name from test where age >= 50").toJSON.collect + val sql = sparkSession.sql("select id, name from test where age >= 50").toJSON.collect val filtered = df.where(df("age") >= 50).select("id", "name").toJSON.collect assertEqualsUsingJSONIgnoreOrder(stringify(sql), stringify(filtered)) } diff --git a/connector/src/test/scala/com/basho/riak/spark/rdd/SparkJobCompletionTest.scala b/connector/src/test/scala/com/basho/riak/spark/rdd/SparkJobCompletionTest.scala index 9f5c2975..5f327e63 100644 --- a/connector/src/test/scala/com/basho/riak/spark/rdd/SparkJobCompletionTest.scala +++ b/connector/src/test/scala/com/basho/riak/spark/rdd/SparkJobCompletionTest.scala @@ -23,6 +23,7 @@ import com.basho.riak.client.core.query.Namespace import com.basho.riak.spark._ import com.basho.riak.spark.rdd.SparkJobCompletionTest._ import com.basho.riak.spark.rdd.connector.RiakConnectorConf +import org.apache.spark.sql.SparkSession import org.apache.spark.{SparkConf, SparkContext} import org.junit.Test import org.junit.Assert @@ -127,7 +128,8 @@ object SparkJobCompletionTest extends JsonFunctions { .set("spark.riak.connections.inactivity.timeout", (RiakConnectorConf.defaultInactivityTimeout * 60 * 5).toString) // 5 minutes is enough time to complete Spark job - val data = new SparkContext(sparkConf).riakBucket(ns).queryAll().collect() + val sparkSession = SparkSession.builder().config(sparkConf).getOrCreate() + val data = sparkSession.sparkContext.riakBucket(ns).queryAll().collect() // HACK: Results should be printed for further analysis in the original JVM // to indicate that Spark job was completed successfully diff --git a/connector/src/test/scala/com/basho/riak/spark/rdd/failover/AbstractFailoverOfflineTest.scala b/connector/src/test/scala/com/basho/riak/spark/rdd/failover/AbstractFailoverOfflineTest.scala index 04755027..c048d240 100644 --- a/connector/src/test/scala/com/basho/riak/spark/rdd/failover/AbstractFailoverOfflineTest.scala +++ b/connector/src/test/scala/com/basho/riak/spark/rdd/failover/AbstractFailoverOfflineTest.scala @@ -4,7 +4,9 @@ import com.basho.riak.client.core.query.Namespace import com.basho.riak.client.core.util.HostAndPort import com.basho.riak.stub.{RiakMessageHandler, RiakNodeStub} import org.apache.commons.lang3.exception.ExceptionUtils -import org.apache.spark.{Logging, SparkConf, SparkContext} +import org.apache.spark.riak.Logging +import org.apache.spark.sql.SparkSession +import org.apache.spark.{SparkConf, SparkContext} import org.hamcrest.{Description, Matchers} import org.junit.internal.matchers.ThrowableCauseMatcher import org.junit.{After, Before} @@ -43,7 +45,8 @@ abstract class AbstractFailoverOfflineTest extends Logging { @Before def setUp(): Unit = { riakNodes = initRiakNodes() - sc = new SparkContext(sparkConf) + val sparkSession = SparkSession.builder().config(sparkConf).getOrCreate() + sc = sparkSession.sparkContext } @After diff --git a/connector/src/test/scala/com/basho/riak/spark/rdd/failover/FailoverTest.scala 
b/connector/src/test/scala/com/basho/riak/spark/rdd/failover/FailoverTest.scala index 33ce5029..d08ff57b 100644 --- a/connector/src/test/scala/com/basho/riak/spark/rdd/failover/FailoverTest.scala +++ b/connector/src/test/scala/com/basho/riak/spark/rdd/failover/FailoverTest.scala @@ -18,25 +18,28 @@ class FailoverTest extends AbstractRiakSparkTest { private val NUMBER_OF_TEST_VALUES = 1000 private val STUBS_AMOUNT = 1 - private var stubNodes: Seq[RiakNodeStub] = Seq() + private var stubNodes: Seq[RiakNodeStub] = _ protected override val jsonData = Some(asStrictJSON((1 to NUMBER_OF_TEST_VALUES) .map(i => Map("key" -> s"k$i", "value" -> s"v$i", "indexes" -> Map("creationNo" -> i))))) // scalastyle:ignore // Configure Spark using proxied hosts - override protected def initSparkConf(): SparkConf = super.initSparkConf() - .set("spark.riak.connection.host", riakHosts.map { - case hp: HostAndPort if stubNodes.length < STUBS_AMOUNT => - val stub = RiakNodeStub(new ProxyMessageHandler(hp) { - override def onRespond(input: RiakMessage, output: Iterable[RiakMessage]): Unit = input.getCode match { - case MSG_CoverageReq => stubNodes.head.stop() // stop proxy node after coverage plan sent to client - case _ => super.onRespond(input, output) - } - }) - stubNodes = stubNodes :+ stub - stub.start() - case hp: HostAndPort => hp - }.map(hp => s"${hp.getHost}:${hp.getPort}").mkString(",")) + override protected def initSparkConf(): SparkConf = { + stubNodes = Seq() + super.initSparkConf() + .set("spark.riak.connection.host", riakHosts.map { + case hp: HostAndPort if stubNodes.length < STUBS_AMOUNT => + val stub = RiakNodeStub(new ProxyMessageHandler(hp) { + override def onRespond(input: RiakMessage, output: Iterable[RiakMessage]): Unit = input.getCode match { + case MSG_CoverageReq => stubNodes.head.stop() // stop proxy node after coverage plan sent to client + case _ => super.onRespond(input, output) + } + }) + stubNodes = stubNodes :+ stub + stub.start() + case hp: HostAndPort => hp + }.map(hp => s"${hp.getHost}:${hp.getPort}").mkString(",")) + } @After override def destroySparkContext(): Unit = { diff --git a/connector/src/test/scala/com/basho/riak/spark/rdd/timeseries/AbstractTimeSeriesTest.scala b/connector/src/test/scala/com/basho/riak/spark/rdd/timeseries/AbstractTimeSeriesTest.scala index de631aff..a1825767 100644 --- a/connector/src/test/scala/com/basho/riak/spark/rdd/timeseries/AbstractTimeSeriesTest.scala +++ b/connector/src/test/scala/com/basho/riak/spark/rdd/timeseries/AbstractTimeSeriesTest.scala @@ -1,5 +1,5 @@ /** - * Copyright (c) 2015 Basho Technologies, Inc. + * Copyright (c) 2015-2017 Basho Technologies, Inc. 
* * This file is provided to you under the Apache License, * Version 2.0 (the "License"); you may not use this file @@ -33,7 +33,7 @@ import com.basho.riak.client.core.query.timeseries.FullColumnDescription import com.basho.riak.client.core.query.timeseries.Row import com.basho.riak.client.core.query.timeseries.TableDefinition import com.basho.riak.spark.rdd.AbstractRiakSparkTest -import org.apache.spark.Logging +import org.apache.spark.riak.Logging import org.apache.spark.sql.types._ import org.junit.Assert._ import org.junit.Rule @@ -101,7 +101,8 @@ abstract class AbstractTimeSeriesTest(val createTestData: Boolean = true) extend new Cell(f.temperature_k)) ) - final val sqlWhereClause = s"WHERE time >= $queryFromMillis AND time <= $queryToMillis AND surrogate_key = 1 AND family = 'f'" + final val filterExpression = s"time >= $queryFromMillis AND time <= $queryToMillis AND surrogate_key = 1 AND family = 'f'" + final val sqlWhereClause = s"WHERE $filterExpression" final val sqlQuery = s"SELECT surrogate_key, family, time, user_id, temperature_k FROM $bucketName $sqlWhereClause" @@ -161,6 +162,7 @@ abstract class AbstractTimeSeriesTest(val createTestData: Boolean = true) extend session.execute(new CreateTable.Builder(tableDefinition) .withQuantum(10, TimeUnit.SECONDS) // scalastyle:ignore .build()) + case Left(ex) => throw ex } }) } diff --git a/connector/src/test/scala/com/basho/riak/spark/rdd/timeseries/SparkDataSetTest.scala b/connector/src/test/scala/com/basho/riak/spark/rdd/timeseries/SparkDataSetTest.scala new file mode 100644 index 00000000..6fa00873 --- /dev/null +++ b/connector/src/test/scala/com/basho/riak/spark/rdd/timeseries/SparkDataSetTest.scala @@ -0,0 +1,55 @@ +/** + * Copyright (c) 2015-2017 Basho Technologies, Inc. + * + * This file is provided to you under the Apache License, + * Version 2.0 (the "License"); you may not use this file + * except in compliance with the License. You may obtain + * a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package com.basho.riak.spark.rdd.timeseries + +import com.basho.riak.spark.rdd.RiakTSTests +import org.junit.Test +import org.junit.experimental.categories.Category + +/** + * @author Sergey Galkin + */ +@Category(Array(classOf[RiakTSTests])) +class SparkDataSetTest extends AbstractTimeSeriesTest { + + @Test + def genericLoadAsDataSet(): Unit = { + import sparkSession.implicits._ + + val ds = sparkSession.read + .format("org.apache.spark.sql.riak") + .option("spark.riakts.bindings.timestamp", "useLong") + .load(bucketName) + .filter(filterExpression) + .as[TimeSeriesData] + + val data: Array[TimeSeriesData] = ds.collect() + + // -- verification + assertEqualsUsingJSONIgnoreOrder( + """ + |[ + | {time:111111, user_id:'bryce', temperature_k:305.37}, + | {time:111222, user_id:'bryce', temperature_k:300.12}, + | {time:111333, user_id:'bryce', temperature_k:295.95}, + | {time:111444, user_id:'ratman', temperature_k:362.121}, + | {time:111555, user_id:'ratman', temperature_k:3502.212} + |] + """.stripMargin, data) + } +} diff --git a/connector/src/test/scala/com/basho/riak/spark/rdd/timeseries/TSConversionTest.scala b/connector/src/test/scala/com/basho/riak/spark/rdd/timeseries/TSConversionTest.scala index 52076a96..446ad04d 100644 --- a/connector/src/test/scala/com/basho/riak/spark/rdd/timeseries/TSConversionTest.scala +++ b/connector/src/test/scala/com/basho/riak/spark/rdd/timeseries/TSConversionTest.scala @@ -19,14 +19,16 @@ package com.basho.riak.spark.rdd.timeseries import java.sql.Timestamp import java.util.{Calendar, Date} + import com.basho.riak.client.core.query.timeseries.ColumnDescription.ColumnType._ -import com.basho.riak.client.core.query.timeseries.{ColumnDescription, Cell, Row} +import com.basho.riak.client.core.query.timeseries.{Cell, ColumnDescription, Row} import com.basho.riak.client.core.util.BinaryValue import com.basho.riak.spark.util.TSConversionUtil -import org.apache.spark.Logging +import org.apache.spark.riak.Logging import org.apache.spark.sql.types._ import org.junit.Assert._ import org.junit.Test + import scala.collection.JavaConversions._ import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema diff --git a/connector/src/test/scala/com/basho/riak/spark/rdd/timeseries/TimeSeriesPartitioningTest.scala b/connector/src/test/scala/com/basho/riak/spark/rdd/timeseries/TimeSeriesPartitioningTest.scala index b16cb1b9..058573cc 100644 --- a/connector/src/test/scala/com/basho/riak/spark/rdd/timeseries/TimeSeriesPartitioningTest.scala +++ b/connector/src/test/scala/com/basho/riak/spark/rdd/timeseries/TimeSeriesPartitioningTest.scala @@ -43,16 +43,9 @@ class TimeSeriesPartitioningTest extends AbstractTimeSeriesTest(createTestData = LessThan("time", to), EqualTo("user_id", "user1")) - var sqlContext: org.apache.spark.sql.SQLContext = null - - override def initialize(): Unit = { - super.initialize() - sqlContext = new org.apache.spark.sql.SQLContext(sc) - } - @Test def withOptionTest(): Unit = { - val df = sqlContext.read + val df = sparkSession.read .option("spark.riak.input.split.count", partitionsCount.toString) .option("spark.riak.partitioning.ts-range-field-name", tsRangeFieldName) .format("org.apache.spark.sql.riak") @@ -67,7 +60,7 @@ class TimeSeriesPartitioningTest extends AbstractTimeSeriesTest(createTestData = @Test def smallRangeShouldBeSinglePartitionTest(): Unit = { val (localFrom, localTo) = (new Timestamp(500L), new Timestamp(504L)) - val df = sqlContext.read + val df = sparkSession.read .option("spark.riak.input.split.count", 
partitionsCount.toString) .option("spark.riak.partitioning.ts-range-field-name", tsRangeFieldName) .format("org.apache.spark.sql.riak") @@ -83,7 +76,7 @@ class TimeSeriesPartitioningTest extends AbstractTimeSeriesTest(createTestData = def invalidRangeTest(): Unit = { expectedException.expect(classOf[IllegalArgumentException]) expectedException.expectMessage("requirement failed: Invalid range query") - val df = sqlContext.read + val df = sparkSession.read .option("spark.riak.input.split.count", partitionsCount.toString) .option("spark.riak.partitioning.ts-range-field-name", tsRangeFieldName) .option("spark.riak.partitioning.ts-quantum", "10s") @@ -97,7 +90,7 @@ class TimeSeriesPartitioningTest extends AbstractTimeSeriesTest(createTestData = @Test def withOptionTestFromToTo(): Unit = { - val df = sqlContext.read + val df = sparkSession.read .option("spark.riak.input.split.count", partitionsCount.toString) .option("spark.riak.partitioning.ts-range-field-name", tsRangeFieldName) .option("spark.riak.partitioning.ts-quantum", "10s") @@ -121,7 +114,7 @@ class TimeSeriesPartitioningTest extends AbstractTimeSeriesTest(createTestData = @Test def greaterThanToLessThanOrEqualTest(): Unit = { - val df = sqlContext.read + val df = sparkSession.read .option("spark.riak.input.split.count", partitionsCount.toString) .option("spark.riak.partitioning.ts-range-field-name", tsRangeFieldName) .format("org.apache.spark.sql.riak") @@ -140,7 +133,7 @@ class TimeSeriesPartitioningTest extends AbstractTimeSeriesTest(createTestData = @Test def lessThanToGreaterThanOrEqualTest(): Unit = { - val df = sqlContext.read + val df = sparkSession.read .option("spark.riak.input.split.count", partitionsCount.toString) .option("spark.riak.partitioning.ts-range-field-name", tsRangeFieldName) .format("org.apache.spark.sql.riak") @@ -161,7 +154,7 @@ class TimeSeriesPartitioningTest extends AbstractTimeSeriesTest(createTestData = def noLessThanTest(): Unit = { expectedException.expect(classOf[IllegalArgumentException]) expectedException.expectMessage(s"No LessThanOrEqual or LessThan filers found for tsRangeFieldName $tsRangeFieldName") - val df = sqlContext.read + val df = sparkSession.read .option("spark.riak.input.split.count", partitionsCount.toString) .option("spark.riak.partitioning.ts-range-field-name", tsRangeFieldName) .format("org.apache.spark.sql.riak") @@ -175,7 +168,7 @@ class TimeSeriesPartitioningTest extends AbstractTimeSeriesTest(createTestData = def noGreaterThanTest(): Unit = { expectedException.expect(classOf[IllegalArgumentException]) expectedException.expectMessage(s"No GreaterThanOrEqual or GreaterThan filers found for tsRangeFieldName $tsRangeFieldName") - val df = sqlContext.read + val df = sparkSession.read .option("spark.riak.input.split.count", partitionsCount.toString) .option("spark.riak.partitioning.ts-range-field-name", tsRangeFieldName) .format("org.apache.spark.sql.riak") @@ -187,7 +180,7 @@ class TimeSeriesPartitioningTest extends AbstractTimeSeriesTest(createTestData = @Test def withLessThanQuantaLimitTest(): Unit = { - val df = sqlContext.read + val df = sparkSession.read .option("spark.riak.input.split.count", partitionsCount.toString) .option("spark.riak.partitioning.ts-range-field-name", tsRangeFieldName) .option("spark.riak.partitioning.ts-quantum", "20d") @@ -205,7 +198,7 @@ class TimeSeriesPartitioningTest extends AbstractTimeSeriesTest(createTestData = @Test def withGreaterThanQuantaLimitTest(): Unit = { val (localFrom, localTo) = (new Timestamp(1000000L), new Timestamp(3000000L)) - val df 
= sqlContext.read + val df = sparkSession.read .option("spark.riak.input.split.count", partitionsCount.toString) .option("spark.riak.partitioning.ts-range-field-name", tsRangeFieldName) .option("spark.riak.partitioning.ts-quantum", "10s") @@ -224,7 +217,7 @@ class TimeSeriesPartitioningTest extends AbstractTimeSeriesTest(createTestData = def noFiltersForFieldTest(): Unit = { expectedException.expect(classOf[IllegalArgumentException]) expectedException.expectMessage(s"No filers found for tsRangeFieldName $tsRangeFieldName") - val df = sqlContext.read + val df = sparkSession.read .option("spark.riak.input.split.count", partitionsCount.toString) .option("spark.riak.partitioning.ts-range-field-name", tsRangeFieldName) .format("org.apache.spark.sql.riak") diff --git a/connector/src/test/scala/com/basho/riak/spark/rdd/timeseries/TimeSeriesReadTest.scala b/connector/src/test/scala/com/basho/riak/spark/rdd/timeseries/TimeSeriesReadTest.scala index 30610581..36f5bffa 100644 --- a/connector/src/test/scala/com/basho/riak/spark/rdd/timeseries/TimeSeriesReadTest.scala +++ b/connector/src/test/scala/com/basho/riak/spark/rdd/timeseries/TimeSeriesReadTest.scala @@ -20,9 +20,7 @@ package com.basho.riak.spark.rdd.timeseries import com.basho.riak.spark.rdd.RiakTSTests import com.basho.riak.spark.toSparkContextFunctions import org.apache.spark.SparkException -import org.apache.spark.sql.SQLContext import org.apache.spark.sql.functions.udf -import org.apache.spark.sql.riak.RiakSQLContext import org.apache.spark.sql.types._ import org.junit.Assert.assertEquals import org.junit.Test @@ -33,6 +31,7 @@ import org.junit.experimental.categories.Category */ @Category(Array(classOf[RiakTSTests])) class TimeSeriesReadTest extends AbstractTimeSeriesTest { + import sparkSession.implicits._ @Test def readDataAsSqlRow(): Unit = { @@ -57,17 +56,14 @@ class TimeSeriesReadTest extends AbstractTimeSeriesTest { // TODO: Consider possibility of moving this case to the SparkDataframesTest @Test def riakTSRDDToDataFrame(): Unit = { - val sqlContext = new org.apache.spark.sql.SQLContext(sc) - import sqlContext.implicits._ - val df = sc.riakTSTable[org.apache.spark.sql.Row](bucketName) .sql(s"SELECT time, user_id, temperature_k FROM $bucketName $sqlWhereClause") .map(r => TimeSeriesData(r.getTimestamp(0).getTime, r.getString(1), r.getDouble(2))) .toDF() - df.registerTempTable("test") + df.createTempView("test") - val data = sqlContext.sql("select * from test").toJSON.collect() + val data = sparkSession.sql("select * from test").toJSON.collect() // -- verification assertEqualsUsingJSONIgnoreOrder( @@ -84,9 +80,6 @@ class TimeSeriesReadTest extends AbstractTimeSeriesTest { @Test def riakTSRDDToDataFrameConvertTimestamp(): Unit = { - val sqlContext = new org.apache.spark.sql.SQLContext(sc) - import sqlContext.implicits._ - val structType = StructType(List( StructField(name = "time", dataType = LongType), StructField(name = "user_id", dataType = StringType), @@ -98,9 +91,9 @@ class TimeSeriesReadTest extends AbstractTimeSeriesTest { .map(r => TimeSeriesData(r.getLong(0), r.getString(1), r.getDouble(2))) .toDF() - df.registerTempTable("test") + df.createTempView("test") - val data = sqlContext.sql("select * from test").toJSON.collect() + val data = sparkSession.sql("select * from test").toJSON.collect() // -- verification assertEqualsUsingJSONIgnoreOrder( @@ -130,49 +123,13 @@ class TimeSeriesReadTest extends AbstractTimeSeriesTest { .collect() } - @Test - def sqlRangeQuery(): Unit = { - /* - * This usage scenario requires to use 
RiakSQLContext, otherwise - * RuntimeException('Table Not Found: time_series_test') will be thrown - */ - val sqlContext = new RiakSQLContext(sc) - sqlContext.udf.register("getMillis", getMillis) // transforms timestamp to not deal with timezones - val df = sqlContext.sql( - s""" - | SELECT getMillis(time) as time, user_id, temperature_k - | FROM $bucketName - | WHERE time >= CAST('$fromStr' AS TIMESTAMP) - | AND time <= CAST('$toStr' AS TIMESTAMP) - | AND surrogate_key = 1 - | AND family = 'f' - """.stripMargin) - - // -- verification - val data = df.toJSON.collect() - - assertEqualsUsingJSONIgnoreOrder( - """ - |[ - | {time: 111111, user_id:'bryce', temperature_k:305.37}, - | {time: 111222, user_id:'bryce', temperature_k:300.12}, - | {time: 111333, user_id:'bryce', temperature_k:295.95}, - | {time: 111444, user_id:'ratman', temperature_k:362.121}, - | {time: 111555, user_id:'ratman', temperature_k:3502.212} - |] - """.stripMargin, stringify(data)) - } - @Test def dataFrameGenericLoad(): Unit = { - val sqlContext = new SQLContext(sc) - sqlContext.udf.register("getMillis", getMillis) // transforms timestamp to not deal with timezones - - import sqlContext.implicits._ + sparkSession.udf.register("getMillis", getMillis) // transforms timestamp to not deal with timezones val udfGetMillis = udf(getMillis) - val df = sqlContext.read + val df = sparkSession.read .format("org.apache.spark.sql.riak") // For real usage no need to provide schema manually .schema(schema) @@ -198,10 +155,6 @@ class TimeSeriesReadTest extends AbstractTimeSeriesTest { @Test def dataFrameReadShouldConvertTimestampToLong(): Unit = { - val sqlContext = new SQLContext(sc) - - import sqlContext.implicits._ - val newSchema = StructType(List( StructField(name = "surrogate_key", dataType = LongType), StructField(name = "family", dataType = StringType), @@ -210,7 +163,7 @@ class TimeSeriesReadTest extends AbstractTimeSeriesTest { StructField(name = "temperature_k", dataType = DoubleType)) ) - val df = sqlContext.read + val df = sparkSession.read .option("spark.riak.partitioning.ts-range-field-name", "time") .format("org.apache.spark.sql.riak") .schema(newSchema) @@ -249,11 +202,7 @@ class TimeSeriesReadTest extends AbstractTimeSeriesTest { StructField(name = "unknown_field", dataType = StringType)) ) - val sqlContext = new org.apache.spark.sql.SQLContext(sc) - - import sqlContext.implicits._ - - sqlContext.read + sparkSession.read .option("spark.riak.partitioning.ts-range-field-name", "time") .format("org.apache.spark.sql.riak") .schema(structType) @@ -265,15 +214,13 @@ class TimeSeriesReadTest extends AbstractTimeSeriesTest { @Test def sqlReadSingleFieldShouldPass(): Unit = { - val sqlContext = new SQLContext(sc) - - sqlContext.read + sparkSession.read .option("spark.riak.partitioning.ts-range-field-name", "time") .format("org.apache.spark.sql.riak") .load(bucketName) - .registerTempTable("test") + .createTempView("test") - val data = sqlContext + val data = sparkSession .sql( s""" | SELECT user_id @@ -297,9 +244,6 @@ class TimeSeriesReadTest extends AbstractTimeSeriesTest { @Test def readColumnsWithoutSchema(): Unit = { - val sqlContext = new SQLContext(sc) - import sqlContext.implicits._ - val rdd = sc.riakTSTable[org.apache.spark.sql.Row](bucketName) .select("time", "user_id", "temperature_k") .where(s"time >= $queryFromMillis AND time <= $queryToMillis AND surrogate_key = 1 AND family = 'f'") @@ -320,9 +264,6 @@ class TimeSeriesReadTest extends AbstractTimeSeriesTest { @Test def readColumnsWithSchema(): Unit = { - val 
sqlContext = new SQLContext(sc) - import sqlContext.implicits._ - val structType = StructType(List( StructField(name = "time", dataType = LongType), StructField(name = "user_id", dataType = StringType), @@ -401,9 +342,6 @@ class TimeSeriesReadTest extends AbstractTimeSeriesTest { @Test def readBySchemaWithoutDefinedColumns(): Unit = { - val sqlContext = new SQLContext(sc) - import sqlContext.implicits._ - val structType = StructType(List( StructField(name = "time", dataType = LongType), StructField(name = "user_id", dataType = StringType), @@ -448,11 +386,9 @@ class TimeSeriesReadTest extends AbstractTimeSeriesTest { @Test def dataFrameReadShouldHandleTimestampAsLong(): Unit = { - val sqlContext = new SQLContext(sc) - import sqlContext.implicits._ - val df = sqlContext.read + val df = sparkSession.read .format("org.apache.spark.sql.riak") .option("spark.riakts.bindings.timestamp", "useLong") .option("spark.riak.partitioning.ts-range-field-name", "time") @@ -481,21 +417,18 @@ class TimeSeriesReadTest extends AbstractTimeSeriesTest { @Category(Array(classOf[RiakTSTests])) class TimeSeriesReadWithoutSchemaTest extends AbstractTimeSeriesTest { + import sparkSession.implicits._ @Test def riakTSRDDToDataFrame(): Unit = { - val sqlContext = new org.apache.spark.sql.SQLContext(sc) - - import sqlContext.implicits._ - val df = sc.riakTSTable[org.apache.spark.sql.Row](bucketName) .sql(s"SELECT time, user_id, temperature_k FROM $bucketName $sqlWhereClause") .map(r => TimeSeriesData(r.getTimestamp(0).getTime, r.getString(1), r.getDouble(2))) .toDF() - df.registerTempTable("test") + df.createTempView("test") - val data = sqlContext.sql("select * from test").toJSON.collect() + val data = sparkSession.sql("select * from test").toJSON.collect() // -- verification assertEqualsUsingJSONIgnoreOrder( @@ -512,14 +445,11 @@ class TimeSeriesReadWithoutSchemaTest extends AbstractTimeSeriesTest { @Test def dataFrameReadShouldHandleTimestampAsTimestamp(): Unit = { - val sqlContext = new SQLContext(sc) - sqlContext.udf.register("getMillis", getMillis) // transforms timestamp to not deal with timezones - - import sqlContext.implicits._ + sparkSession.udf.register("getMillis", getMillis) // transforms timestamp to not deal with timezones val udfGetMillis = udf(getMillis) - val df = sqlContext.read + val df = sparkSession.read .format("org.apache.spark.sql.riak") .option("spark.riakts.bindings.timestamp", "useTimestamp") .load(bucketName) diff --git a/connector/src/test/scala/com/basho/riak/spark/rdd/timeseries/TimeSeriesWriteTest.scala b/connector/src/test/scala/com/basho/riak/spark/rdd/timeseries/TimeSeriesWriteTest.scala index c8f3049c..c0887869 100644 --- a/connector/src/test/scala/com/basho/riak/spark/rdd/timeseries/TimeSeriesWriteTest.scala +++ b/connector/src/test/scala/com/basho/riak/spark/rdd/timeseries/TimeSeriesWriteTest.scala @@ -23,7 +23,8 @@ import com.basho.riak.spark.writer.WriteDataMapperFactory._ import org.apache.spark.SparkException import org.apache.spark.rdd.RDD import org.apache.spark.sql.types._ -import org.apache.spark.sql.{DataFrame, Row, SQLContext, SaveMode} +import org.apache.spark.sql._ +import org.apache.spark.sql.functions.udf import org.junit.Assert._ import org.junit.Test import org.junit.experimental.categories.Category @@ -34,6 +35,7 @@ import com.basho.riak.spark.util.TSConversionUtil */ @Category(Array(classOf[RiakTSTests])) class TimeSeriesWriteTest extends AbstractTimeSeriesTest(false) { + import sparkSession.implicits._ @Test def saveSqlRowsToRiak(): Unit = { @@ -80,8 +82,7 @@ 
class TimeSeriesWriteTest extends AbstractTimeSeriesTest(false) { @Test def saveDataFrameWithSchemaToRiak(): Unit = { - val sqlContext = new SQLContext(sc) - val sourceDF = getSourceDF(sqlContext) + val sourceDF = getSourceDF(sparkSession) sourceDF.rdd.saveToRiakTS(DEFAULT_TS_NAMESPACE.getBucketTypeAsString) // -- verification @@ -104,14 +105,9 @@ class TimeSeriesWriteTest extends AbstractTimeSeriesTest(false) { @Test def dataFrameGenericSave(): Unit = { - val sqlContext = new SQLContext(sc) - - import org.apache.spark.sql.functions.udf - import sqlContext.implicits._ - val udfGetMillis = udf(getMillis) - val sourceDF = getSourceDF(sqlContext) + val sourceDF = getSourceDF(sparkSession) sourceDF.write .format("org.apache.spark.sql.riak") @@ -119,7 +115,7 @@ class TimeSeriesWriteTest extends AbstractTimeSeriesTest(false) { .save(bucketName) // -- verification - val df = sqlContext.read + val df = sparkSession.read .format("org.apache.spark.sql.riak") .schema(schema) .load(bucketName) @@ -143,14 +139,10 @@ class TimeSeriesWriteTest extends AbstractTimeSeriesTest(false) { @Test def dataFrameWriteWithTimeFieldAsLongShouldPass(): Unit = { - val sqlContext = new SQLContext(sc) - import org.apache.spark.sql.functions.udf - import sqlContext.implicits._ - val udfGetMillis = udf(getMillis) - val sourceDF = getSourceDF(sqlContext, StructType(List( + val sourceDF = getSourceDF(sparkSession, StructType(List( StructField(name = "surrogate_key", dataType = LongType), StructField(name = "family", dataType = StringType), StructField(name = "time", dataType = LongType), @@ -164,7 +156,7 @@ class TimeSeriesWriteTest extends AbstractTimeSeriesTest(false) { .save(bucketName) // -- verification - val df = sqlContext.read + val df = sparkSession.read .format("org.apache.spark.sql.riak") .schema(schema) .load(bucketName) @@ -188,15 +180,10 @@ class TimeSeriesWriteTest extends AbstractTimeSeriesTest(false) { @Test def dataFrameWriteWithEmptyCells(): Unit = { - val sqlContext = new SQLContext(sc) - - import org.apache.spark.sql.functions.udf - import sqlContext.implicits._ - val udfGetMillis = udf(getMillis) val tsRows = Seq[org.apache.spark.sql.Row] ( - org.apache.spark.sql.Row(2L, "f", 111111L, "test", None), + org.apache.spark.sql.Row(2L, "f", 111111L, "test", null), org.apache.spark.sql.Row(2L, "f", 111222L, "test", 123.123), org.apache.spark.sql.Row(2L, "f", 111333L, "test", 345.34) ) @@ -207,14 +194,14 @@ class TimeSeriesWriteTest extends AbstractTimeSeriesTest(false) { StructField(name = "user_id", dataType = StringType), StructField(name = "temperature_k", dataType = DoubleType))) - val initialDF = getInitialDF(sqlContext, schema, tsRows) + val initialDF = getInitialDF(sparkSession, schema, tsRows) initialDF.write .format("org.apache.spark.sql.riak") .mode(SaveMode.Append) .save(bucketName) - val df = sqlContext.read + val df = sparkSession.read .format("org.apache.spark.sql.riak") .schema(schema) .load(bucketName) @@ -237,12 +224,11 @@ class TimeSeriesWriteTest extends AbstractTimeSeriesTest(false) { expectedException.expect(classOf[SparkException]) expectedException.expectMessage("Invalid data found at row index(es)") - val sqlContext = new SQLContext(sc) val tsRows = Seq[org.apache.spark.sql.Row] ( - org.apache.spark.sql.Row(2L, "f", None, "test", 123.123) + org.apache.spark.sql.Row(2L, "f", null, "test", 123.123) ) - val initialDF = getInitialDF(sqlContext, StructType(List( + val initialDF = getInitialDF(sparkSession, StructType(List( StructField(name = "surrogate_key", dataType = LongType), 
StructField(name = "family", dataType = StringType), StructField(name = "time", dataType = LongType), @@ -260,12 +246,11 @@ class TimeSeriesWriteTest extends AbstractTimeSeriesTest(false) { expectedException.expect(classOf[SparkException]) expectedException.expectMessage("Invalid data found at row index(es)") - val sqlContext = new SQLContext(sc) val tsRows = Seq[org.apache.spark.sql.Row] ( - org.apache.spark.sql.Row(None, "f", 111222L, "test", 123.123) + org.apache.spark.sql.Row(null, "f", 111222L, "test", 123.123) ) - val initialDF = getInitialDF(sqlContext, StructType(List( + val initialDF = getInitialDF(sparkSession, StructType(List( StructField(name = "surrogate_key", dataType = LongType), StructField(name = "family", dataType = StringType), StructField(name = "time", dataType = LongType), @@ -283,12 +268,11 @@ class TimeSeriesWriteTest extends AbstractTimeSeriesTest(false) { expectedException.expect(classOf[SparkException]) expectedException.expectMessage("Invalid data found at row index(es)") - val sqlContext = new SQLContext(sc) val tsRows = Seq[org.apache.spark.sql.Row] ( - org.apache.spark.sql.Row(2L, None, 111222L, "test", 123.123) + org.apache.spark.sql.Row(2L, null, 111222L, "test", 123.123) ) - val initialDF = getInitialDF(sqlContext, StructType(List( + val initialDF = getInitialDF(sparkSession, StructType(List( StructField(name = "surrogate_key", dataType = LongType), StructField(name = "family", dataType = StringType), StructField(name = "time", dataType = LongType), @@ -301,14 +285,14 @@ class TimeSeriesWriteTest extends AbstractTimeSeriesTest(false) { .save(bucketName) } - private def getSourceDF(sqlContext: SQLContext, structType:StructType = schema): DataFrame = { + private def getSourceDF(sparkSession: SparkSession, structType:StructType = schema): DataFrame = { val sparkRowsWithSchema = riakTSRows.map( r => TSConversionUtil.asSparkRow(structType, r)) - val rdd: RDD[Row] = sqlContext.sparkContext.parallelize(sparkRowsWithSchema) - sqlContext.createDataFrame(rdd, structType) + val rdd: RDD[Row] = sparkSession.sparkContext.parallelize(sparkRowsWithSchema) + sparkSession.createDataFrame(rdd, structType) } - private def getInitialDF(sqlContext: SQLContext, structType:StructType = schema, rows: Seq[Row]): DataFrame = { - val rdd: RDD[Row] = sqlContext.sparkContext.parallelize(rows) - sqlContext.createDataFrame(rdd, structType) + private def getInitialDF(sparkSession: SparkSession, structType:StructType = schema, rows: Seq[Row]): DataFrame = { + val rdd: RDD[Row] = sparkSession.sparkContext.parallelize(rows) + sparkSession.createDataFrame(rdd, structType) } } diff --git a/connector/src/test/scala/com/basho/riak/spark/streaming/SocketStreamingDataSource.scala b/connector/src/test/scala/com/basho/riak/spark/streaming/SocketStreamingDataSource.scala index d8406433..3761d7a3 100644 --- a/connector/src/test/scala/com/basho/riak/spark/streaming/SocketStreamingDataSource.scala +++ b/connector/src/test/scala/com/basho/riak/spark/streaming/SocketStreamingDataSource.scala @@ -4,7 +4,7 @@ import java.net.InetSocketAddress import java.nio.channels.{AsynchronousCloseException, AsynchronousServerSocketChannel, AsynchronousSocketChannel, CompletionHandler} import com.basho.riak.stub.SocketUtils -import org.apache.spark.Logging +import org.apache.spark.riak.Logging class SocketStreamingDataSource extends Logging { diff --git a/connector/src/test/scala/com/basho/riak/spark/streaming/SparkStreamingFixture.scala 
b/connector/src/test/scala/com/basho/riak/spark/streaming/SparkStreamingFixture.scala index 1aa6838c..3806f9cf 100644 --- a/connector/src/test/scala/com/basho/riak/spark/streaming/SparkStreamingFixture.scala +++ b/connector/src/test/scala/com/basho/riak/spark/streaming/SparkStreamingFixture.scala @@ -1,6 +1,7 @@ package com.basho.riak.spark.streaming -import org.apache.spark.{Logging, SparkContext} +import org.apache.spark.riak.Logging +import org.apache.spark.SparkContext import org.apache.spark.streaming.{Seconds, StreamingContext} import org.junit.{After, Before} diff --git a/connector/src/test/scala/org/apache/spark/sql/riak/OptionsTest.scala b/connector/src/test/scala/org/apache/spark/sql/riak/OptionsTest.scala index ef76fcfe..586bba71 100644 --- a/connector/src/test/scala/org/apache/spark/sql/riak/OptionsTest.scala +++ b/connector/src/test/scala/org/apache/spark/sql/riak/OptionsTest.scala @@ -18,22 +18,18 @@ package org.apache.spark.sql.riak import scala.collection.JavaConversions.asScalaBuffer - -import org.apache.spark.{ SparkConf, SparkContext } import org.apache.spark.sql._ -import org.apache.spark.sql.types.{ StringType, StructField, StructType } -import org.junit.{ After, Before, Test } +import org.apache.spark.sql.types.{StringType, StructField, StructType} +import org.junit.{After, Before, Test} import org.junit.Assert._ - import com.basho.riak.client.core.RiakNode import com.basho.riak.client.core.util.HostAndPort -import com.basho.riak.spark.rdd.connector.{ RiakConnector, RiakConnectorConf } +import com.basho.riak.spark.rdd.connector.{RiakConnector, RiakConnectorConf} +import org.apache.spark.SparkConf class OptionsTest { - private val source = new DefaultSource - private var sqlContext: SQLContext = _ private val initialHost = "default:1111" private val initialConnectionsMin = 111 private val initialConnectionsMax = 999 @@ -54,26 +50,20 @@ class OptionsTest { .set("spark.riak.input.fetch-size", initialFetchSize.toString) .set("spark.riak.input.split.count", initialSplitCount.toString) - private val dummySchema = StructType(List(StructField("dummy", StringType, nullable = true))) - private var df: DataFrame = _ + protected def createSparkSession(conf: SparkConf): SparkSession = SparkSession.builder().config(conf).getOrCreate() - @Before - def initializeDF(): Unit = { - val conf = initSparkConf - val sc = new SparkContext(conf) - sqlContext = new org.apache.spark.sql.SQLContext(sc) - - df = sqlContext.createDataFrame(sc.emptyRDD[Row], dummySchema) - } + private val dummySchema = StructType(List(StructField("dummy", StringType, nullable = true))) + private val sparkSession: SparkSession = createSparkSession(initSparkConf) + private var df: DataFrame = sparkSession.createDataFrame(sparkSession.sparkContext.emptyRDD[Row], dummySchema) @After def destroySparkContext(): Unit = { - Option(sqlContext).foreach(sqlc => sqlc.sparkContext.stop()) + Option(sparkSession).foreach(sqlc => sqlc.sparkContext.stop()) } @Test def noReadOptionsShouldResultInKeepingInitialProperties(): Unit = { - val rel = source.createRelation(sqlContext, + val rel = source.createRelation(sparkSession.sqlContext, Map("path" -> "path"), dummySchema).asInstanceOf[RiakRelation] val riakConnector = getConnector(rel) val riakConf = getRiakConnectorConf(riakConnector) @@ -87,7 +77,7 @@ class OptionsTest { @Test def noWriteOptionsShouldResultInKeepingInitialProperties(): Unit = { - val rel = source.createRelation(sqlContext, SaveMode.Append, + val rel = source.createRelation(sparkSession.sqlContext, SaveMode.Append, 
Map("path" -> "path"), df).asInstanceOf[RiakRelation] val riakConnector = getConnector(rel) val writeConf = rel.writeConf @@ -102,7 +92,7 @@ class OptionsTest { @Test def writeOptionsOnReadShouldNotAffectProperties(): Unit = { val newQuorum = 1 - val rel = source.createRelation(sqlContext, + val rel = source.createRelation(sparkSession.sqlContext, Map("path" -> "path", "spark.riak.write.replicas" -> newQuorum.toString), dummySchema).asInstanceOf[RiakRelation] val writeConf = rel.writeConf } @@ -111,7 +101,7 @@ class OptionsTest { def readOptionsOnWriteShouldNotAffectProperties(): Unit = { val newFetchSize = 100 val newSplitCount = 10 - val rel = source.createRelation(sqlContext, SaveMode.Append, + val rel = source.createRelation(sparkSession.sqlContext, SaveMode.Append, Map("path" -> "path", "spark.riak.input.fetch-size" -> newFetchSize.toString, "spark.riak.input.split.count" -> newSplitCount.toString), df).asInstanceOf[RiakRelation] val readConf = rel.readConf @@ -120,7 +110,7 @@ class OptionsTest { @Test def writeOptionsOnWriteShouldAffectProperties(): Unit = { val newQuorum = 1 - val rel = source.createRelation(sqlContext, SaveMode.Append, + val rel = source.createRelation(sparkSession.sqlContext, SaveMode.Append, Map("path" -> "path", "spark.riak.write.replicas" -> newQuorum.toString), df).asInstanceOf[RiakRelation] val writeConf = rel.writeConf assertEquals(newQuorum, writeConf.writeReplicas.toInt) @@ -130,7 +120,7 @@ class OptionsTest { def readOptionsOnReadShouldAffectProperties(): Unit = { val newFetchSize = 100 val newSplitCount = 10 - val rel = source.createRelation(sqlContext, Map("path" -> "path", "spark.riak.input.fetch-size" -> newFetchSize.toString, + val rel = source.createRelation(sparkSession.sqlContext, Map("path" -> "path", "spark.riak.input.fetch-size" -> newFetchSize.toString, "spark.riak.input.split.count" -> newSplitCount.toString), dummySchema).asInstanceOf[RiakRelation] val readConf = rel.readConf assertEquals(newFetchSize, readConf.fetchSize) @@ -142,7 +132,7 @@ class OptionsTest { val newHost = "newHost:9999" val newConnectionsMin = 1 val newConnectionsMax = 9 - val rel = source.createRelation(sqlContext, + val rel = source.createRelation(sparkSession.sqlContext, Map("path" -> "path", "spark.riak.connection.host" -> newHost, "spark.riak.connections.min" -> newConnectionsMin.toString, "spark.riak.connections.max" -> newConnectionsMax.toString), dummySchema).asInstanceOf[RiakRelation] @@ -156,7 +146,7 @@ class OptionsTest { @Test def riakConnectionOptionsShouldChangeOnlySpecifiedProperties(): Unit = { val newHost = "newHost:9999" - val rel = source.createRelation(sqlContext, + val rel = source.createRelation(sparkSession.sqlContext, Map("path" -> "path", "spark.riak.connection.host" -> newHost), dummySchema).asInstanceOf[RiakRelation] val riakConnector = getConnector(rel) val riakConf = getRiakConnectorConf(riakConnector) diff --git a/docs/quick-start.md b/docs/quick-start.md index 63dacbd8..e23577cd 100644 --- a/docs/quick-start.md +++ b/docs/quick-start.md @@ -185,10 +185,10 @@ import java.sql.Timestamp import com.basho.riak.spark.rdd.connector.RiakConnector ``` -Then, set up implicits for Spark sqlContext: +Then, set up implicits for Spark: ```scala -import sqlContext.implicits._ +import sparkSession.implicits._ ``` ###### Create an RDD with some timeseries data: @@ -243,7 +243,7 @@ And, finally, check that the table was successfully written into the Riak TS tab val test_query = "ts >= CAST('1980-1-1 10:00:00' AS TIMESTAMP) AND ts <= CAST('1980-1-1 10:30:00' 
AS TIMESTAMP) AND k = 1 AND family = 'f'" -val df2 = sqlContext.read.format("org.apache.spark.sql.riak").load(tableName).filter(test_query) +val df2 = sparkSession.read.format("org.apache.spark.sql.riak").load(tableName).filter(test_query) df2.show() ``` @@ -446,8 +446,7 @@ df.write \ Lets check that the write was successful by reading the TS table into a new DataFrame: ```python -sqlContext = SQLContext(sc) -df2 = sqlContext.read\ +df2 = sparkSession.read\ .format("org.apache.spark.sql.riak")\ .option("spark.riak.connection.host", hostAndPort)\ .option("spark.riakts.bindings.timestamp", "useLong")\ @@ -500,8 +499,8 @@ You should see something like this: ###### Register the DataFrame as a temp sql table and run a sql query to obtain the average of the "value" column: ```python -df2.registerTempTable("pyspark_tmp") -sqlContext.sql("select avg(value) as average_value from pyspark_tmp").show() +df2.createOrReplaceTempView("pyspark_tmp") +sparkSession.sql("select avg(value) as average_value from pyspark_tmp").show() ``` You should see something similar to this: diff --git a/docs/using-connector.md b/docs/using-connector.md index ec735e85..87716458 100644 --- a/docs/using-connector.md +++ b/docs/using-connector.md @@ -12,6 +12,7 @@ Scroll down or click below for the desired information: - [Writing Data To TS Table](./using-connector.md#writing-data-to-ts-table) - [Spark Dataframes With KV Bucket](./using-connector.md#spark-dataframes-with-kv-bucket) - [Spark Dataframes With TS Table](./using-connector.md#spark-dataframes-with-ts-table) +- [Spark DataSets With TS Table](./using-connector.md#spark-datasets-with-ts-table) - [Partitioning for KV Buckets](./using-connector.md#partitioning-for-kv-buckets) - [Working With TS Dates](./using-connector.md#working-with-ts-dates) - [Partitioning for Riak TS Table Queries](./using-connector.md#partitioning-for-riak-ts-table-queries) @@ -29,14 +30,16 @@ The following import statements should be included at the top of your Spark appl **Scala** ```scala import org.apache.spark.{SparkContext, SparkConf} +import org.apache.spark.sql.SparkSession import com.basho.riak.spark._ ``` **Python** ```python import pyspark import pyspark_riak +import pyspark_riak.sql ``` -You can control how your Spark application interacts with Riak by configuring different options for your `SparkContext` or `SQLContext`. You can set these options within the $SPARK_HOME/conf/spark-default.conf. If you don't set an option, it will be automatically set to the default values listed below. +You can control how your Spark application interacts with Riak by configuring different options for your `SparkSession`. You can set these options within the $SPARK_HOME/conf/spark-default.conf. If you don't set an option, it will be automatically set to the default values listed below. You can set the below options for the `SparkConf` object: @@ -63,8 +66,12 @@ val conf = new SparkConf() .set("spark.riak.connection.host", "127.0.0.1:8087") .set("spark.riak.connections.min", "20") .set("spark.riak.connections.max", "50") - -val sc = new SparkContext("spark://127.0.0.1:7077", "test", conf) +val sparkSession = SparkSession.builder() + .master("spark://127.0.0.1:7077") + .appName("test") + .config(conf) + .getOrCreate() +val sc = sparkSession.sparkContext ``` **Python** @@ -265,16 +272,20 @@ rdd.saveToRiakTS(output_ts_table); ## Spark Dataframes With KV Bucket -You can use Spark DataFrames on top of an RDD that was created from a KV Bucket. 
First you need to create a SQLContext from SparkContext: +You can use Spark DataFrames on top of an RDD that was created from a KV Bucket. The entry point to programming Spark with the Dataset and DataFrame API is [SparkSession](https://spark.apache.org/docs/2.0.0/api/java/org/apache/spark/sql/SparkSession.html) ```scala -val sqlContext = new org.apache.spark.sql.SQLContext(sc) +val sparkSession = SparkSession.builder() + .master(...) + .appName(...) + .config(...) + .getOrCreate() ``` Then import: ```scala -import sqlContext.implicits._ +import sparkSession.implicits._ ``` Next, you have to specify a user defined type to allow schema inference using reflection: @@ -307,26 +318,26 @@ df.groupBy("category").count Alternatively, you can register a table ```scala -df.registerTempTable("users") +df.createOrReplaceTempView("users") ``` and use Spark SQL queries over it. ```scala -sqlContext.sql("select * from users where age >= 50") +sparkSession.sql("select * from users where age >= 50") ``` Another thing you can use are user defined functions (UDFs). First, you have to register a UDF. ```scala -sqlContext.udf.register("stringLength", (s: String) => s.length) +sparkSession.udf.register("stringLength", (s: String) => s.length) ``` After that you can use it in SQL queries ```scala -sqlContext.sql("select user_id, name, stringLength(name) nameLength from users order by nameLength") +sparkSession.sql("select user_id, name, stringLength(name) nameLength from users order by nameLength") ``` When you already have a DataFrame, you can save it into Riak. To do that, make sure you have imported `com.basho.riak.spark._` so that saveToRiak() method is available. @@ -353,23 +364,22 @@ To enable DataFrames functionality, first steps are **Scala** ```scala -val sc = new SparkContext() -val sqlContext = new org.apache.spark.sql.SQLContext(sc) -import sqlContext.implicits._ +val sparkSession = SparkSession.builder().getOrCreate() +import sparkSession.implicits._ ts_table_name = "test_table" ``` **Python** ```python -sc = pyspark.SparkContext(conf=conf) -sqlContext = pyspark.SQLContext(sc) +sparkSession = SparkSession.builder.getOrCreate() +sc = sparkSession.sparkContext ts_table_name = "test_table" ``` -To read data from existing TS table `test-table` standard SQLContext means can be used by providing a special `“org.apache.spark.sql.riak”` data format and using a Riak TS range query: +To read data from an existing TS table `test-table`, the standard SparkSession can be used by providing a special `“org.apache.spark.sql.riak”` data format and using a Riak TS range query: **Scala** ```scala -val df = sqlContext.read +val df = sparkSession.read .option("spark.riak.connection.hosts","riak_host_ip:10017") .format("org.apache.spark.sql.riak") .load(ts_table_name) @@ -378,7 +388,7 @@ val df = sqlContext.read ``` **Python** ```python -df = sqlContext.read \ +df = sparkSession.read \ .option("spark.riak.connection.hosts","riak_host_ip:10017") \ .format("org.apache.spark.sql.riak") \ .load(ts_table_name) \ @@ -386,13 +396,7 @@ df = sqlContext.read \ .filter(s"time >= CAST($from AS TIMESTAMP) AND time <= CAST($to AS TIMESTAMP) AND col1= $value1") ``` -Schema may or may not be provided using `.schema()` method. If not provided, it will be inferred. Any of the Spark Connector options can be provided in `.option()` or `.options()`. 
Alternatively, `org.apache.spark.sql.riak.RiakSQLContext` can be created and then queried with range query using `sql()` method - -**Scala** -```scala -val riakSqlContext = new RiakSQLContext(sc, ts_table_name) -val alternativeDf = riakSqlContext.sql(s"SELECT time, col1 from $ts_table_name WHERE time >= CAST($from AS TIMESTAMP) AND time <= CAST($to AS TIMESTAMP) AND col1= $value1") -``` +Schema may or may not be provided using the `.schema()` method. If not provided, it will be inferred. Any of the Spark Connector options can be provided in `.option()` or `.options()`. A DataFrame, `inputDF`, that has the same schema as an existing TS table (column order and types) can be saved to Riak TS as follows: @@ -416,6 +420,31 @@ inputDF.write \ So far SaveMode.Append is the only mode available. Any of the Spark Connector options can be provided in `.option()` or `.options()`. +## Spark Datasets With TS Table +Spark Datasets, i.e. strongly typed DataFrames, can be created in much the same way as DataFrames; there are only two differences: + +* Datasets require an Encoder; built-in encoders for common Scala types and their product types are already available in the session's implicits object, and you only need to import them as follows: +```scala +import sparkSession.implicits._ +``` + +* the target data type should be provided by calling the `as[T]` method + +Here is an example of a Dataset creation: +```scala +import sparkSession.implicits._ + +case class TimeSeriesData(time: Long, user_id: String, temperature_k: Double) + +val ds = sparkSession.read + .format("org.apache.spark.sql.riak") + .option("spark.riakts.bindings.timestamp", "useLong") + .load(bucketName) + .filter(filterExpression) + .as[TimeSeriesData] +``` + +NOTE: There is no Dataset support for Python, since Spark does not support it. ## Partitioning for KV Buckets @@ -436,7 +465,10 @@ val conf = new SparkConf() .setAppName("My Spark Riak App") .set("spark.riak.input.split.count", "10") -val sc = new SparkContext(conf) +val sparkSession = SparkSession.builder() + .config(conf) + .getOrCreate() +val sc = sparkSession.sparkContext ... 
sc.riakBucket[UserTS](DEFAULT_NAMESPACE) .query2iRange(CREATION_INDEX, 100L, 200L) @@ -482,7 +514,7 @@ val schemaWithLong = StructType(List(       StructField(name = "temperature_k", dataType = DoubleType))     ) -    val df = sqlContext.read +    val df = sparkSession.read       .format("org.apache.spark.sql.riak")       .schema(newSchema)       .load(tableName) @@ -492,7 +524,7 @@ val schemaWithLong = StructType(List( You can use `spark.riakts.bindings.timestamp` and Automatic Schema Discovery with `useLong`: ```scala -val df = sqlContext.read +val df = sparkSession.read       .format("org.apache.spark.sql.riak")       .option("spark.riakts.bindings.timestamp", "useLong")       .load(tableName) @@ -503,7 +535,7 @@ In the previous example, the query times, `queryFromMillis` and `queryToMillis`, Or, you can use `spark.riakts.bindings.timestamp` and Automatic Schema Discovery with `useTimestamp`: ```scala -val df = sqlContext.read +val df = sparkSession.read       .format("org.apache.spark.sql.riak")       .option("spark.riakts.bindings.timestamp", "useTimestamp")       .load(tableName) @@ -536,7 +568,7 @@ For example: **Scala** ```scala - val df = sqlContext.read + val df = sparkSession.read .option("spark.riak.input.split.count", "5") .option("spark.riak.partitioning.ts-range-field-name", "time") .format("org.apache.spark.sql.riak") @@ -546,7 +578,7 @@ For example: ``` **Python** ```python -df = sqlContext.read \ +df = sparkSession.read \ .option("spark.riak.input.split.count", "5") \ .option("spark.riak.partitioning.ts-range-field-name", "time") \ .format("org.apache.spark.sql.riak") \ @@ -566,7 +598,7 @@ The initial range query will be split into 5 subqueries (one per each partition) An additional option spark.riak.partitioning.ts-quantum can be passed to notify the Spark-Riak Connector of the quantum size. If the automatically created subranges break the 5 quanta limitation, the initial range will be split into ~4 quantum subranges and the resulting subranges will then be grouped to form the required number of partitions. **Scala** ```scala - val df = sqlContext.read + val df = sparkSession.read .option("spark.riak.input.split.count", "5") .option("spark.riak.partitioning.ts-range-field-name", tsRangeFieldName) .option("spark.riak.partitioning.ts-quantum", "5s") @@ -577,7 +609,7 @@ An additional option spark.riak.partitioning.ts-quantum can be passed to notify ``` **Python** ```python -df = sqlContext.read \ +df = sparkSession.read \ .option("spark.riak.input.split.count", "5") \ .option("spark.riak.partitioning.ts-range-field-name", tsRangeFieldName) \ .option("spark.riak.partitioning.ts-quantum", "5s") \ @@ -616,7 +648,7 @@ Or you can set the `spark.riakts.write.bulk-size` property in the DataFrame's `. **Scala** ```scala -val df = sqlContext.write +df.write .option("spark.riakts.write.bulk-size", "500") .format("org.apache.spark.sql.riak") .mode(SaveMode.Append) @@ -624,7 +656,7 @@ val df = sqlContext.write **Python** ```python -df = sqlContext.write +df.write .option("spark.riakts.write.bulk-size", "500") .format("org.apache.spark.sql.riak") .mode(SaveMode.Append) @@ -637,7 +669,7 @@ Bulks will be written in parallel. 
The number of parallel writes for each partit ```scala val conf = new SparkConf() .set("spark.riakts.write.bulk-size", "500") - .set("spark.riak.connections.min", "50") + .set("spark.riak.connections.min", "50") ``` **Python** ```python diff --git a/examples/src/main/java/com/basho/riak/spark/examples/SimpleJavaRiakExample.java b/examples/src/main/java/com/basho/riak/spark/examples/SimpleJavaRiakExample.java index 52803983..6674b28b 100644 --- a/examples/src/main/java/com/basho/riak/spark/examples/SimpleJavaRiakExample.java +++ b/examples/src/main/java/com/basho/riak/spark/examples/SimpleJavaRiakExample.java @@ -12,6 +12,7 @@ import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SparkSession; import java.io.IOException; import java.io.Serializable; diff --git a/examples/src/main/scala/com/basho/riak/spark/examples/SimpleScalaRiakExample.scala b/examples/src/main/scala/com/basho/riak/spark/examples/SimpleScalaRiakExample.scala index 2395b00c..abe0c14c 100644 --- a/examples/src/main/scala/com/basho/riak/spark/examples/SimpleScalaRiakExample.scala +++ b/examples/src/main/scala/com/basho/riak/spark/examples/SimpleScalaRiakExample.scala @@ -19,8 +19,9 @@ package com.basho.riak.spark.examples import com.basho.riak.client.core.query.Namespace import com.basho.riak.spark.rdd.RiakFunctions -import org.apache.spark.{SparkContext, SparkConf} +import org.apache.spark.{SparkConf, SparkContext} import com.basho.riak.spark._ +import org.apache.spark.sql.SparkSession /** * Really simple demo program which calculates the number of records loaded @@ -52,7 +53,8 @@ object SimpleScalaRiakExample { println(s"Writing test data to Riak: \n $TEST_DATA") createTestData(sparkConf) - val sc = new SparkContext(sparkConf) + val sparkSession = SparkSession.builder().config(sparkConf).getOrCreate() + val sc = sparkSession.sparkContext val rdd = sc.riakBucket("test-data") .queryAll() diff --git a/examples/src/main/scala/com/basho/riak/spark/examples/SimpleScalaRiakTSExample.scala b/examples/src/main/scala/com/basho/riak/spark/examples/SimpleScalaRiakTSExample.scala index 44c1e054..2dc98410 100644 --- a/examples/src/main/scala/com/basho/riak/spark/examples/SimpleScalaRiakTSExample.scala +++ b/examples/src/main/scala/com/basho/riak/spark/examples/SimpleScalaRiakTSExample.scala @@ -25,12 +25,15 @@ import com.basho.riak.client.core.util.BinaryValue import com.basho.riak.spark.rdd.RiakFunctions import com.basho.riak.spark.toSparkContextFunctions import java.util.Calendar + import com.basho.riak.spark.rdd.RiakObjectData import com.basho.riak.client.core.operations.ts.StoreOperation + import scala.collection.JavaConversions._ import com.basho.riak.client.core.query.Namespace import com.basho.riak.spark.util.RiakObjectConversionUtil import com.basho.riak.client.core.query.indexes.LongIntIndex +import org.apache.spark.sql.SparkSession /** * Really simple demo timeseries-related features @@ -78,7 +81,8 @@ object SimpleScalaRiakTSExample { clearBucket(sparkConf) loadDemoData(sparkConf) - val sc = new SparkContext(sparkConf) + val sparkSession = SparkSession.builder().config(sparkConf).getOrCreate() + val sc = sparkSession.sparkContext val from = beginingOfQuantumMillis(testData.head.time) val to = endOfQuantumMillis(testData.last.time) diff --git a/examples/src/main/scala/com/basho/riak/spark/examples/dataframes/SimpleScalaRiakDataframesExample.scala 
b/examples/src/main/scala/com/basho/riak/spark/examples/dataframes/SimpleScalaRiakDataframesExample.scala index e8967fd7..3faa3a99 100644 --- a/examples/src/main/scala/com/basho/riak/spark/examples/dataframes/SimpleScalaRiakDataframesExample.scala +++ b/examples/src/main/scala/com/basho/riak/spark/examples/dataframes/SimpleScalaRiakDataframesExample.scala @@ -23,14 +23,16 @@ import com.basho.riak.spark._ import com.basho.riak.spark.util.RiakObjectConversionUtil import org.apache.spark.SparkConf import org.apache.spark.SparkContext + import scala.reflect.runtime.universe import scala.concurrent.Future import scala.concurrent.ExecutionContext.Implicits.global -import scala.util.{ Failure, Success } +import scala.util.{Failure, Success} import com.basho.riak.client.core.query.RiakObject import com.basho.riak.client.api.RiakClient import com.basho.riak.client.core.query.Location import com.basho.riak.spark.rdd.RiakFunctions +import org.apache.spark.sql.SparkSession /** * Example shows how Spark DataFrames can be used with Riak @@ -56,14 +58,13 @@ object SimpleScalaRiakDataframesExample { setSparkOpt(sparkConf, "spark.master", "local") setSparkOpt(sparkConf, "spark.riak.connection.host", "127.0.0.1:8087") - val sc = new SparkContext(sparkConf) + val sparkSession = SparkSession.builder().config(sparkConf).getOrCreate() + val sc = sparkSession.sparkContext // Work with clear bucket clearBucket(sparkConf) - val sqlContext = new org.apache.spark.sql.SQLContext(sc) - // To enable toDF() - import sqlContext.implicits._ + import sparkSession.implicits._ println(s" Saving data to Riak: \n ${println(testData)}") @@ -83,18 +84,18 @@ object SimpleScalaRiakDataframesExample { println(s"Dataframe from Riak query: \n ${df.show()}") - df.registerTempTable("users") + df.createTempView("users") println("count by category") df.groupBy("category").count.show println("sort by num of letters") // Register user defined function - sqlContext.udf.register("stringLength", (s: String) => s.length) - sqlContext.sql("select user_id, name, stringLength(name) nameLength from users order by nameLength").show + sparkSession.udf.register("stringLength", (s: String) => s.length) + sparkSession.sql("select user_id, name, stringLength(name) nameLength from users order by nameLength").show println("filter age >= 21") - sqlContext.sql("select * from users where age >= 21").show + sparkSession.sql("select * from users where age >= 21").show } diff --git a/examples/src/main/scala/com/basho/riak/spark/examples/dataframes/SimpleScalaRiakTSDataframesExample.scala b/examples/src/main/scala/com/basho/riak/spark/examples/dataframes/SimpleScalaRiakTSDataframesExample.scala index d5c70bef..0827230b 100644 --- a/examples/src/main/scala/com/basho/riak/spark/examples/dataframes/SimpleScalaRiakTSDataframesExample.scala +++ b/examples/src/main/scala/com/basho/riak/spark/examples/dataframes/SimpleScalaRiakTSDataframesExample.scala @@ -21,9 +21,7 @@ import java.text.SimpleDateFormat import java.util.Date import org.apache.spark.SparkConf -import org.apache.spark.SparkContext -import org.apache.spark.sql.SaveMode -import org.apache.spark.sql.riak.RiakSQLContext +import org.apache.spark.sql.{SaveMode, SparkSession} import org.apache.spark.sql.types.DoubleType import org.apache.spark.sql.types.LongType import org.apache.spark.sql.types.StringType @@ -78,15 +76,15 @@ object SimpleScalaRiakTSDataframesExample { setSparkOpt(sparkConf, "spark.master", "local") setSparkOpt(sparkConf, "spark.riak.connection.host", "127.0.0.1:8087") - val sc = new 
SparkContext(sparkConf) + val sparkSession = SparkSession.builder().config(sparkConf).getOrCreate() + val sc = sparkSession.sparkContext - val sqlContext = new org.apache.spark.sql.SQLContext(sc) - import sqlContext.implicits._ + import sparkSession.implicits._ // Load test data from json file println("---------------------------------- input data -----------------------------------") val inputRDD = sc.parallelize(testData.split("\n")) - val inputDF = sqlContext.read.json(inputRDD) + val inputDF = sparkSession.read.json(inputRDD) .withColumn("time", 'time.cast("Timestamp")) // Timestamp types are not inferred when reading from JSON and need to be cast .select("weather", "family", "time", "temperature", "humidity", "pressure") // column ordering should be the same as in schema inputDF.printSchema @@ -108,7 +106,7 @@ object SimpleScalaRiakTSDataframesExample { // Simple Riak range query with schema provided println("---------------------- Range query with provided schema -------------------------") - val withSchemaProvided = sqlContext.read + val withSchemaProvided = sparkSession.read .format("org.apache.spark.sql.riak") .schema(schemaWithTimestamp) .load(tableName) @@ -118,7 +116,7 @@ object SimpleScalaRiakTSDataframesExample { // Simple Riak range query with schema provided and automatic timestamp to long conversion println("---Range query with provided schema and automatic timestamp to long conversion ---") - val withSchemaProvidedLongTime = sqlContext.read + val withSchemaProvidedLongTime = sparkSession.read .option("spark.riak.partitioning.ts-range-field-name", "time") .format("org.apache.spark.sql.riak") .schema(schemaWithLong) @@ -129,7 +127,7 @@ object SimpleScalaRiakTSDataframesExample { // Simple Riak range query without providing schema println("-------------------- Range query with inferred schema ---------------------------") - val df = sqlContext.read + val df = sparkSession.read .option("spark.riak.partitioning.ts-range-field-name", "time") .format("org.apache.spark.sql.riak") .load(tableName) @@ -139,7 +137,7 @@ object SimpleScalaRiakTSDataframesExample { // Simple Riak range query without providing schema and with useLong option for timestamp binding println("------ Range query with inferred schema and treating timestamps as Long (in milliseconds) ---------") - val dfUseLong = sqlContext.read + val dfUseLong = sparkSession.read .option("spark.riak.partitioning.ts-range-field-name", "time") .option("spark.riakts.bindings.timestamp", "useLong") // option to treat timestamps as Longs .format("org.apache.spark.sql.riak") @@ -147,14 +145,6 @@ object SimpleScalaRiakTSDataframesExample { .filter(s"time >= $fromMillis AND time <= $toMillis AND weather = 'sunny' AND family = 'f'") dfUseLong.printSchema dfUseLong.show - - // Alternative way to read data from Riak TS - println("-------------------------- Reading with RiakSQLContext --------------------------") - val riakSqlContext = new RiakSQLContext(sc) - val alternativeDf = riakSqlContext.sql( - s"SELECT * from $tableName WHERE time >= CAST('$from' AS TIMESTAMP) AND time <= CAST('$to' AS TIMESTAMP) AND weather = 'sunny' AND family = 'f'") - alternativeDf.printSchema - alternativeDf.show } private def setSparkOpt(sparkConf: SparkConf, option: String, defaultOptVal: String): SparkConf = { diff --git a/examples/src/main/scala/com/basho/riak/spark/examples/demos/fbl/FootballDemo.scala b/examples/src/main/scala/com/basho/riak/spark/examples/demos/fbl/FootballDemo.scala index f068d2f1..901389f4 100644 --- 
a/examples/src/main/scala/com/basho/riak/spark/examples/demos/fbl/FootballDemo.scala +++ b/examples/src/main/scala/com/basho/riak/spark/examples/demos/fbl/FootballDemo.scala @@ -1,21 +1,25 @@ package com.basho.riak.spark.examples.demos.fbl import java.util.Calendar -import java.util.concurrent.{TimeUnit, Semaphore} +import java.util.concurrent.{Semaphore, TimeUnit} + import com.basho.riak.client.api.commands.kv.StoreValue import com.basho.riak.client.api.commands.kv.StoreValue.Response import com.basho.riak.client.core.query.indexes.LongIntIndex import com.basho.riak.client.core.{RiakFuture, RiakFutureListener} import com.basho.riak.spark.util.RiakObjectConversionUtil -import org.slf4j.{LoggerFactory, Logger} +import org.slf4j.{Logger, LoggerFactory} import java.util.zip.ZipInputStream -import com.basho.riak.client.core.query.{Namespace, Location} + +import com.basho.riak.client.core.query.Location import com.basho.riak.spark.rdd._ import org.apache.spark.{SparkConf, SparkContext} + import scala.collection.mutable import scala.collection.mutable.ListBuffer import scala.io.Source import com.basho.riak.spark.rdd.RiakFunctions import com.basho.riak.spark._ +import org.apache.spark.sql.SparkSession object FootballDemo { private val logger: Logger = LoggerFactory.getLogger(this.getClass) @@ -92,7 +96,8 @@ object FootballDemo { sys.exit() } - val sc = new SparkContext(sparkConf) + val sparkSession = SparkSession.builder().config(sparkConf).getOrCreate() + val sc = sparkSession.sparkContext // -- Create test data createTestData(sc) diff --git a/examples/src/main/scala/com/basho/riak/spark/examples/demos/ofac/OFACDemo.scala b/examples/src/main/scala/com/basho/riak/spark/examples/demos/ofac/OFACDemo.scala index a0d422b1..71c413df 100644 --- a/examples/src/main/scala/com/basho/riak/spark/examples/demos/ofac/OFACDemo.scala +++ b/examples/src/main/scala/com/basho/riak/spark/examples/demos/ofac/OFACDemo.scala @@ -18,17 +18,17 @@ package com.basho.riak.spark.examples.demos.ofac import com.basho.riak.client.core.query.indexes.LongIntIndex -import com.basho.riak.spark.rdd.{RiakFunctions, BucketDef} +import com.basho.riak.spark.rdd.{BucketDef, RiakFunctions} import com.basho.riak.spark.util.RiakObjectConversionUtil -import com.basho.riak.spark.writer.{WriteDataMapperFactory, WriteDataMapper} -import org.slf4j.{LoggerFactory, Logger} +import com.basho.riak.spark.writer.{WriteDataMapper, WriteDataMapperFactory} +import org.slf4j.{Logger, LoggerFactory} import scala.io.Source import scala.annotation.meta.field - import com.basho.riak.spark._ -import com.basho.riak.client.core.query.{RiakObject, Namespace} -import com.basho.riak.client.api.annotations.{RiakKey, RiakIndex} +import com.basho.riak.client.core.query.{Namespace, RiakObject} +import com.basho.riak.client.api.annotations.{RiakIndex, RiakKey} +import org.apache.spark.sql.SparkSession import org.apache.spark.{SparkConf, SparkContext} object OFACDemo { @@ -70,7 +70,8 @@ object OFACDemo { setSparkOpt(sparkConf,"spark.riak.demo.to", CFG_DEFAULT_TO.toString) // -- Create spark context - val sc = new SparkContext(sparkConf) + val sparkSession = SparkSession.builder().config(sparkConf).getOrCreate() + val sc = sparkSession.sparkContext // -- Cleanup Riak buckets before we start val rf = RiakFunctions(sparkConf) diff --git a/examples/src/main/scala/com/basho/riak/spark/examples/parquet/ScalaRiakParquetExample.scala b/examples/src/main/scala/com/basho/riak/spark/examples/parquet/ScalaRiakParquetExample.scala index 9803e83f..c674eda6 100644 --- 
a/examples/src/main/scala/com/basho/riak/spark/examples/parquet/ScalaRiakParquetExample.scala +++ b/examples/src/main/scala/com/basho/riak/spark/examples/parquet/ScalaRiakParquetExample.scala @@ -16,8 +16,8 @@ * under the License. */ package com.basho.riak.spark.examples.parquet -import org.apache.spark.sql.{SaveMode, SQLContext} -import org.apache.spark.{SparkContext, SparkConf} +import org.apache.spark.sql.{SaveMode, SparkSession} +import org.apache.spark.SparkConf /** * Simple demo which illustrates how data can be extracted from Riak TS and saved as a parquet file @@ -60,37 +60,37 @@ object ScalaRiakParquetExample { setSparkOpt(sparkConf, "spark.riak.connection.host", "127.0.0.1:8087") println(s"Test data start time: $startDate") - val sc = new SparkContext(sparkConf) - val sqlCtx = SQLContext.getOrCreate(sc) + val sparkSession = SparkSession.builder().config(sparkConf).getOrCreate() + val sc = sparkSession.sparkContext - import sqlCtx.implicits._ + + import sparkSession.implicits._ val rdd = sc.parallelize(testData) rdd.toDF().write.format("org.apache.spark.sql.riak") .mode(SaveMode.Append).save(tableName) - val df = sqlCtx.read.format("org.apache.spark.sql.riak") - .load(tableName).registerTempTable(tableName) + val df = sparkSession.read.format("org.apache.spark.sql.riak") + .load(tableName).createOrReplaceTempView(tableName) val from = (startDate / 1000).toInt val query = s"select * from $tableName where measurementDate >= CAST($from AS TIMESTAMP) " + s"AND measurementDate <= CAST(${from + 1} AS TIMESTAMP) AND site = 'MY7' AND species = 'PM10'" println(s"Query: $query") - val rows = sqlCtx.sql(query) + val rows = sparkSession.sql(query) rows.show() val schema = rows.schema rows.write.mode("overwrite").parquet(parquetFileName) println(s"Data was successfully saved to Parquet file: $parquetFileName") - val parquetFile = sqlCtx.read.parquet(parquetFileName) - parquetFile.registerTempTable("parquetFile") - val data = sqlCtx.sql("SELECT MAX(value) max_value FROM parquetFile ") + val parquetFile = sparkSession.read.parquet(parquetFileName) + parquetFile.createTempView("parquetFile") + val data = sparkSession.sql("SELECT MAX(value) max_value FROM parquetFile ") println("Maximum value retrieved from Parquet file:") data.show() - } private def setSparkOpt(sparkConf: SparkConf, option: String, defaultOptVal: String): SparkConf = { diff --git a/examples/src/main/scala/com/basho/riak/spark/examples/streaming/StreamingKVExample.scala b/examples/src/main/scala/com/basho/riak/spark/examples/streaming/StreamingKVExample.scala index a8e51bb9..965f63b0 100644 --- a/examples/src/main/scala/com/basho/riak/spark/examples/streaming/StreamingKVExample.scala +++ b/examples/src/main/scala/com/basho/riak/spark/examples/streaming/StreamingKVExample.scala @@ -6,6 +6,7 @@ import kafka.serializer.StringDecoder import com.basho.riak.spark._ import com.basho.riak.spark.streaming._ import com.basho.riak.spark.util.RiakObjectConversionUtil +import org.apache.spark.sql.SparkSession import org.apache.spark.streaming.kafka.KafkaUtils import org.apache.spark.streaming.{Durations, StreamingContext} import org.apache.spark.{SparkConf, SparkContext} @@ -27,7 +28,8 @@ object StreamingKVExample { setSparkOpt(sparkConf, "spark.riak.connection.host", "127.0.0.1:8087") setSparkOpt(sparkConf, "kafka.broker", "127.0.0.1:9092") - val sc = new SparkContext(sparkConf) + val sparkSession = SparkSession.builder().config(sparkConf).getOrCreate() + val sc = sparkSession.sparkContext val streamCtx = new StreamingContext(sc, 
Durations.seconds(15)) val kafkaProps = Map[String, String]( diff --git a/examples/src/main/scala/com/basho/riak/spark/examples/streaming/StreamingTSExample.scala b/examples/src/main/scala/com/basho/riak/spark/examples/streaming/StreamingTSExample.scala index 72ee4d91..91ec0a26 100644 --- a/examples/src/main/scala/com/basho/riak/spark/examples/streaming/StreamingTSExample.scala +++ b/examples/src/main/scala/com/basho/riak/spark/examples/streaming/StreamingTSExample.scala @@ -3,7 +3,7 @@ package com.basho.riak.spark.examples.streaming import java.util.UUID import kafka.serializer.StringDecoder -import org.apache.spark.sql.Row +import org.apache.spark.sql.{Row, SparkSession} import org.apache.spark.streaming.Durations import org.apache.spark.streaming.StreamingContext import org.apache.spark.streaming.kafka.KafkaUtils @@ -45,7 +45,8 @@ object StreamingTSExample { setSparkOpt(sparkConf, "spark.riak.connection.host", "127.0.0.1:8087") setSparkOpt(sparkConf, "kafka.broker", "127.0.0.1:9092") - val sc = new SparkContext(sparkConf) + val sparkSession = SparkSession.builder().config(sparkConf).getOrCreate() + val sc = sparkSession.sparkContext val streamCtx = new StreamingContext(sc, Durations.seconds(15)) val kafkaProps = Map[String, String]( diff --git a/project/Versions.scala b/project/Versions.scala index e7dee7e0..7fce5abe 100644 --- a/project/Versions.scala +++ b/project/Versions.scala @@ -1,12 +1,14 @@ object Versions { - val spark = "1.6.1" + val spark = "2.1.0" val guava = "14.0.1" val riakClient = "2.0.7" - val kafka = "0.8.2.2" + val kafka = "0.10.1.0" + val sparkStreamingKafka = "1.6.3" val jfree = "1.0.19" val scalaChart = "0.4.2" val powermokc = "1.6.4" - val jacksonModule = "2.4.4" +// val jacksonModule = "2.4.4" + val jacksonModule = "2.6.5" val junit = "4.11" val jsonUnit = "1.5.1" val hamrest = "1.3" diff --git a/test-utils/src/main/scala/com/basho/riak/spark/run/PiRunLong.scala b/test-utils/src/main/scala/com/basho/riak/spark/run/PiRunLong.scala index 14698dda..153b6c0e 100644 --- a/test-utils/src/main/scala/com/basho/riak/spark/run/PiRunLong.scala +++ b/test-utils/src/main/scala/com/basho/riak/spark/run/PiRunLong.scala @@ -1,6 +1,7 @@ package com.basho.riak.spark.run import org.apache.spark._ +import org.apache.spark.sql.SparkSession import org.apache.spark.rdd._ object LongJobApp { @@ -39,7 +40,8 @@ object LongJobApp { val conf = new SparkConf().setAppName(APP_NAME) .setMaster(options.getOrElse('master, SPARK_URL).asInstanceOf[String]) - val sc = new SparkContext(conf) + val sparkSession = SparkSession.builder().config(conf).getOrCreate() + val sc = sparkSession.sparkContext val count = sc.parallelize(1 to options.getOrElse('samples, NUM_SAMPLES).asInstanceOf[Int], options.getOrElse('partitions, PARTITIONS).asInstanceOf[Int])