diff --git a/.github/workflows/github-actions-basic.yml b/.github/workflows/github-actions-basic.yml
index eff0b922..5a3228f7 100644
--- a/.github/workflows/github-actions-basic.yml
+++ b/.github/workflows/github-actions-basic.yml
@@ -38,9 +38,9 @@ jobs:
           - spark: "3.4.0"
             java-version: "17"
             distribution: "temurin"
-#          - spark: "4.0.0-PREVIEW2"
-#            java-version: "17"
-#            distribution: "temurin"
+          - spark: "4.0.0-preview2"
+            java-version: "17"
+            distribution: "temurin"
     env:
       SPARK_VERSION: ${{ matrix.spark }}
     steps:
@@ -50,11 +50,11 @@
       with:
         fetch-depth: 0
         repository: holdenk/spark-testing-base
-        ref: main
-    - uses: actions/setup-java@v3
+    - uses: actions/setup-java@v4
       with:
         java-version: ${{ matrix.java-version }}
         distribution: ${{ matrix.distribution }}
+        cache: sbt
     - name: Cache maven modules
       id: cache-maven
       uses: actions/cache@v4.0.0
diff --git a/README.md b/README.md
index bf5342cb..58f3db8f 100644
--- a/README.md
+++ b/README.md
@@ -97,7 +97,7 @@ While we hope you choose our library, https://github.com/juanrh/sscheck , https:
 
 ## [Release Notes](RELEASE_NOTES.md)
 
-JDK17 support exists only for Spark 3.4.0
+JDK17 support exists only for Spark 3.4.0 & Spark 4 (previews)
 
 ## Security Disclosure e-mails
diff --git a/build.sbt b/build.sbt
index 5bd66cc3..0b71cb0c 100644
--- a/build.sbt
+++ b/build.sbt
@@ -41,11 +41,17 @@ lazy val core = (project in file("core"))
       "org.apache.spark" %% "spark-sql" % sparkVersion.value,
       "org.apache.spark" %% "spark-hive" % sparkVersion.value,
       "org.apache.spark" %% "spark-catalyst" % sparkVersion.value,
-      "org.apache.spark" %% "spark-yarn" % sparkVersion.value,
       "org.apache.spark" %% "spark-mllib" % sparkVersion.value
     ) ++ commonDependencies ++ {
-      if (sparkVersion.value > "3.0.0") {
+      if (sparkVersion.value > "4.0.0") {
+        Seq(
+          "org.apache.spark" %% "spark-sql-api" % sparkVersion.value,
+          "io.netty" % "netty-all" % "4.1.96.Final",
+          "io.netty" % "netty-tcnative-classes" % "2.0.66.Final",
+          "com.github.luben" % "zstd-jni" % "1.5.5-4"
+        )
+      } else if (sparkVersion.value > "3.0.0") {
         Seq(
           "io.netty" % "netty-all" % "4.1.77.Final",
           "io.netty" % "netty-tcnative-classes" % "2.0.52.Final"
@@ -101,14 +107,24 @@ lazy val kafka_0_8 = {
 val commonSettings = Seq(
   organization := "com.holdenkarau",
   publishMavenStyle := true,
+  libraryDependencySchemes += "com.github.luben" % "zstd-jni" % "early-semver",
+  evictionErrorLevel := Level.Info,
   sparkVersion := System.getProperty("sparkVersion", "2.4.8"),
-  sparkTestingVersion := "1.5.3",
+  sparkTestingVersion := "1.6.0",
   version := sparkVersion.value + "_" + sparkTestingVersion.value,
   scalaVersion := {
-    "2.12.15"
+    if (sparkVersion.value >= "4.0.0") {
+      "2.13.13"
+    } else {
+      "2.12.15"
+    }
   },
   crossScalaVersions := {
-    if (sparkVersion.value >= "3.2.0") {
+    if (sparkVersion.value >= "4.0.0") {
+      Seq("2.13.13")
+    } else if (sparkVersion.value >= "3.5.0") {
+      Seq("2.12.15", "2.13.13")
+    } else if (sparkVersion.value >= "3.2.0") {
       Seq("2.12.15", "2.13.10")
     } else if (sparkVersion.value >= "3.0.0") {
       Seq("2.12.15")
@@ -118,9 +134,13 @@ val commonSettings = Seq(
   },
   scalacOptions ++= Seq("-deprecation", "-unchecked", "-Yrangepos"),
   javacOptions ++= {
-    Seq("-source", "1.8", "-target", "1.8")
+    if (sparkVersion.value >= "4.0.0") {
+      Seq("-source", "17", "-target", "17")
+    } else {
+      Seq("-source", "1.8", "-target", "1.8")
+    }
   },
-  javaOptions ++= Seq("-Xms5G", "-Xmx5G"),
+  javaOptions ++= Seq("-Xms8G", "-Xmx8G"),
 
   coverageHighlighting := true,
 
@@ -142,16 +162,25 @@ val commonSettings = Seq(
     "Typesafe repository" at
"https://repo.typesafe.com/typesafe/releases/", "Second Typesafe repo" at "https://repo.typesafe.com/typesafe/maven-releases/", "Mesosphere Public Repository" at "https://downloads.mesosphere.io/maven", - Resolver.sonatypeRepo("public") + Resolver.sonatypeRepo("public"), + Resolver.mavenLocal ) ) // Allow kafka (and other) utils to have version specific files val coreSources = unmanagedSourceDirectories in Compile := { - if (sparkVersion.value >= "3.0.0" && scalaVersion.value >= "2.12.0") Seq( + if (sparkVersion.value >= "4.0.0") Seq( + (sourceDirectory in Compile)(_ / "4.0/scala"), (sourceDirectory in Compile)(_ / "2.2/scala"), (sourceDirectory in Compile)(_ / "3.0/scala"), - (sourceDirectory in Compile)(_ / "2.0/scala"), (sourceDirectory in Compile)(_ / "2.0/java") + (sourceDirectory in Compile)(_ / "2.0/scala"), + (sourceDirectory in Compile)(_ / "2.0/java") + ).join.value + else if (sparkVersion.value >= "3.0.0" && scalaVersion.value >= "2.12.0") Seq( + (sourceDirectory in Compile)(_ / "2.2/scala"), + (sourceDirectory in Compile)(_ / "3.0/scala"), + (sourceDirectory in Compile)(_ / "2.0/scala"), + (sourceDirectory in Compile)(_ / "2.0/java") + ).join.value + else if (sparkVersion.value >= "3.0.0" && scalaVersion.value >= "2.12.0") Seq( + (sourceDirectory in Compile)(_ / "2.2/scala"), + (sourceDirectory in Compile)(_ / "3.0/scala"), + (sourceDirectory in Compile)(_ / "2.0/scala"), + (sourceDirectory in Compile)(_ / "2.0/java") ).join.value else if (sparkVersion.value >= "2.4.0" && scalaVersion.value >= "2.12.0") Seq( (sourceDirectory in Compile)(_ / "2.2/scala"), @@ -164,7 +199,16 @@ val coreSources = unmanagedSourceDirectories in Compile := { } val coreTestSources = unmanagedSourceDirectories in Test := { - if (sparkVersion.value >= "3.0.0" && scalaVersion.value >= "2.12.0") Seq( + if (sparkVersion.value >= "4.0.0" && scalaVersion.value >= "2.12.0") Seq( + (sourceDirectory in Test)(_ / "4.0/scala"), + (sourceDirectory in Test)(_ / "3.0/scala"), + (sourceDirectory in Test)(_ / "3.0/java"), + (sourceDirectory in Test)(_ / "2.2/scala"), + (sourceDirectory in Test)(_ / "2.0/scala"), + (sourceDirectory in Test)(_ / "2.0/java") + ).join.value + else if (sparkVersion.value >= "3.0.0" && scalaVersion.value >= "2.12.0") Seq( + (sourceDirectory in Test)(_ / "pre-4.0/scala"), (sourceDirectory in Test)(_ / "3.0/scala"), (sourceDirectory in Test)(_ / "3.0/java"), (sourceDirectory in Test)(_ / "2.2/scala"), @@ -243,6 +287,6 @@ lazy val publishSettings = Seq( } ) -lazy val noPublishSettings = +lazy val noPublishSettings = { skip in publish := true } diff --git a/core/src/main/2.0/scala/com/holdenkarau/spark/testing/DataframeGenerator.scala b/core/src/main/2.0/scala/com/holdenkarau/spark/testing/DataFrameGenerator.scala similarity index 83% rename from core/src/main/2.0/scala/com/holdenkarau/spark/testing/DataframeGenerator.scala rename to core/src/main/2.0/scala/com/holdenkarau/spark/testing/DataFrameGenerator.scala index 8567c9c6..85b2fe00 100644 --- a/core/src/main/2.0/scala/com/holdenkarau/spark/testing/DataframeGenerator.scala +++ b/core/src/main/2.0/scala/com/holdenkarau/spark/testing/DataFrameGenerator.scala @@ -7,7 +7,7 @@ import org.apache.spark.sql.types._ import org.apache.spark.sql.{DataFrame, Row, SQLContext} import org.scalacheck.{Arbitrary, Gen} -object DataframeGenerator { +object DataFrameGenerator { /** * Creates a DataFrame Generator for the given Schema. 
@@ -48,13 +48,16 @@ object DataframeGenerator {
    */
   def arbitraryDataFrameWithCustomFields(
     sqlContext: SQLContext, schema: StructType, minPartitions: Int = 1)
-    (userGenerators: ColumnGenerator*): Arbitrary[DataFrame] = {
+    (userGenerators: ColumnGeneratorBase*): Arbitrary[DataFrame] = {
+    import sqlContext._
 
     val arbitraryRDDs = RDDGenerator.genRDD(
       sqlContext.sparkContext, minPartitions)(
       getRowGenerator(schema, userGenerators))
     Arbitrary {
-      arbitraryRDDs.map(sqlContext.createDataFrame(_, schema))
+      arbitraryRDDs.map { r =>
+        sqlContext.createDataFrame(r, schema)
+      }
     }
   }
 
@@ -80,7 +83,7 @@
    * @return Gen[Row]
    */
   def getRowGenerator(
-    schema: StructType, customGenerators: Seq[ColumnGenerator]): Gen[Row] = {
+    schema: StructType, customGenerators: Seq[ColumnGeneratorBase]): Gen[Row] = {
     val generators: List[Gen[Any]] =
       createGenerators(schema.fields, customGenerators)
     val listGen: Gen[List[Any]] =
@@ -92,14 +95,14 @@ object DataframeGenerator {
 
   private def createGenerators(
     fields: Array[StructField],
-    userGenerators: Seq[ColumnGenerator]):
+    userGenerators: Seq[ColumnGeneratorBase]):
     List[Gen[Any]] = {
     val generatorMap = userGenerators.map(
       generator => (generator.columnName -> generator)).toMap
     fields.toList.map { field =>
       if (generatorMap.contains(field.name)) {
         generatorMap.get(field.name) match {
-          case Some(gen: Column) => gen.gen
+          case Some(gen: ColumnGenerator) => gen.gen
           case Some(list: ColumnList) => getGenerator(field.dataType, list.gen, nullable = field.nullable)
         }
       }
@@ -109,7 +112,7 @@ object DataframeGenerator {
 
   private def getGenerator(
     dataType: DataType,
-    generators: Seq[ColumnGenerator] = Seq(),
+    generators: Seq[ColumnGeneratorBase] = Seq(),
     nullable: Boolean = false): Gen[Any] = {
     val nonNullGen = dataType match {
       case ByteType => Arbitrary.arbitrary[Byte]
@@ -128,9 +131,21 @@ object DataframeGenerator {
           l => new Date(l/10000)
         }
       case dec: DecimalType => {
+        // With the new ANSI default we need to make sure we're passing in
+        // valid values.
         Arbitrary.arbitrary[BigDecimal]
-          .retryUntil(_.precision <= dec.precision)
+          .retryUntil { d =>
+            try {
+              val sd = new Decimal()
+              // Make sure it can be converted to the target precision/scale.
+              sd.set(d, dec.precision, dec.scale)
+              true
+            } catch {
+              case e: Exception => false
+            }
+          }
           .map(_.bigDecimal.setScale(dec.scale, RoundingMode.HALF_UP))
+          .asInstanceOf[Gen[java.math.BigDecimal]]
       }
       case arr: ArrayType => {
         val elementGenerator = getGenerator(arr.elementType, nullable = arr.containsNull)
@@ -165,11 +180,11 @@ object DataframeGenerator {
 }
 
 /**
- * Previously ColumnGenerator. Allows the user to specify a generator for a
+ * Previously Column. Allows the user to specify a generator for a
  * specific column.
  */
-class Column(val columnName: String, generator: => Gen[Any])
-  extends ColumnGenerator {
+class ColumnGenerator(val columnName: String, generator: => Gen[Any])
+  extends ColumnGeneratorBase {
   lazy val gen = generator
 }
 
 /**
  * ColumnList allows users to specify custom generators for a list of
  * columns inside a StructType column.
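 * Each generator in the list is matched to its nested field by columnName.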
 */
-class ColumnList(val columnName: String, generators: => Seq[ColumnGenerator])
-  extends ColumnGenerator {
+class ColumnList(val columnName: String, generators: => Seq[ColumnGeneratorBase])
+  extends ColumnGeneratorBase {
   lazy val gen = generators
 }
 
@@ -186,6 +201,6 @@ class ColumnList(val columnName: String, generators: => Seq[ColumnGenerator])
 * ColumnGenerator - previously Column; it is now the base class for all
 * ColumnGenerators.
 */
-abstract class ColumnGenerator extends java.io.Serializable {
+abstract class ColumnGeneratorBase extends java.io.Serializable {
   val columnName: String
 }
diff --git a/core/src/main/2.0/scala/com/holdenkarau/spark/testing/Prettify.scala b/core/src/main/2.0/scala/com/holdenkarau/spark/testing/Prettify.scala
index f1c1287a..e7e5a90e 100644
--- a/core/src/main/2.0/scala/com/holdenkarau/spark/testing/Prettify.scala
+++ b/core/src/main/2.0/scala/com/holdenkarau/spark/testing/Prettify.scala
@@ -8,7 +8,7 @@ trait Prettify {
   val maxNumberOfShownValues = 100
 
   implicit def prettyDataFrame(dataframe: DataFrame): Pretty =
-    Pretty { _ => describeDataframe(dataframe)}
+    Pretty { _ => describeDataFrame(dataframe)}
 
   implicit def prettyRDD(rdd: RDD[_]): Pretty = Pretty { _ => describeRDD(rdd)}
 
@@ -16,7 +16,7 @@ trait Prettify {
   implicit def prettyDataset(dataset: Dataset[_]): Pretty = Pretty { _ => describeDataset(dataset)}
 
-  private def describeDataframe(dataframe: DataFrame) =
+  private def describeDataFrame(dataframe: DataFrame) =
     s"""<dataframe: schema = ${dataframe.toString}, size = ${dataframe.count()},
        |values = ${dataframe.take(maxNumberOfShownValues).toList}>""".
       stripMargin.replace("\n", " ")
diff --git a/core/src/main/4.0/scala/org/apache/spark/sql/internal/evilTools.scala b/core/src/main/4.0/scala/org/apache/spark/sql/internal/evilTools.scala
new file mode 100644
index 00000000..f4a0bb3e
--- /dev/null
+++ b/core/src/main/4.0/scala/org/apache/spark/sql/internal/evilTools.scala
@@ -0,0 +1,14 @@
+package org.apache.spark.sql.internal
+
+import org.apache.spark.sql._
+import org.apache.spark.sql.internal._
+import org.apache.spark.sql.catalyst.expressions._
+
+object EvilExpressionColumnNode {
+  def getExpr(node: ColumnNode): Expression = {
+    ColumnNodeToExpressionConverter.apply(node)
+  }
+  def toColumnNode(expr: Expression): ColumnNode = {
+    ExpressionColumnNode(expr)
+  }
+}
diff --git a/core/src/test/2.0/scala/com/holdenkarau/spark/testing/MLScalaCheckTest.scala b/core/src/test/2.0/scala/com/holdenkarau/spark/testing/MLScalaCheckTest.scala
index e23166b4..7af80690 100644
--- a/core/src/test/2.0/scala/com/holdenkarau/spark/testing/MLScalaCheckTest.scala
+++ b/core/src/test/2.0/scala/com/holdenkarau/spark/testing/MLScalaCheckTest.scala
@@ -15,12 +15,12 @@ class MLScalaCheckTest extends AnyFunSuite with SharedSparkContext with Checkers
   test("vector generation") {
     val schema = StructType(List(StructField("vector", VectorType)))
     val sqlContext = SparkSession.builder.getOrCreate().sqlContext
-    val dataframeGen = DataframeGenerator.arbitraryDataFrame(sqlContext, schema)
+    val dataFrameGen = DataFrameGenerator.arbitraryDataFrame(sqlContext, schema)
 
     val property =
-      forAll(dataframeGen.arbitrary) {
-        dataframe => {
-          dataframe.schema === schema && dataframe.count >= 0
+      forAll(dataFrameGen.arbitrary) {
+        dataFrame => {
+          dataFrame.schema === schema && dataFrame.count >= 0
         }
       }
 
@@ -30,12 +30,12 @@ class MLScalaCheckTest extends AnyFunSuite with SharedSparkContext with Checkers
   test("matrix generation") {
     val schema = StructType(List(StructField("matrix", MatrixType)))
     val sqlContext = SparkSession.builder.getOrCreate().sqlContext
-    val dataframeGen = DataframeGenerator.arbitraryDataFrame(sqlContext,
-      schema)
+    val dataFrameGen = DataFrameGenerator.arbitraryDataFrame(sqlContext, schema)
 
     val property =
-      forAll(dataframeGen.arbitrary) {
-        dataframe => {
-          dataframe.schema === schema && dataframe.count >= 0
+      forAll(dataFrameGen.arbitrary) {
+        dataFrame => {
+          dataFrame.schema === schema && dataFrame.count >= 0
         }
       }
 
diff --git a/core/src/test/2.0/scala/com/holdenkarau/spark/testing/PrettifyTest.scala b/core/src/test/2.0/scala/com/holdenkarau/spark/testing/PrettifyTest.scala
index 505b1e1c..bc968078 100644
--- a/core/src/test/2.0/scala/com/holdenkarau/spark/testing/PrettifyTest.scala
+++ b/core/src/test/2.0/scala/com/holdenkarau/spark/testing/PrettifyTest.scala
@@ -16,10 +16,10 @@ class PrettifyTest extends AnyFunSuite with SharedSparkContext with Checkers wit
   test("pretty output of DataFrame's check") {
     val schema = StructType(List(StructField("name", StringType), StructField("age", IntegerType)))
     val sqlContext = SparkSession.builder.getOrCreate().sqlContext
-    val nameGenerator = new Column("name", Gen.const("Holden Hanafy"))
-    val ageGenerator = new Column("age", Gen.const(20))
+    val nameGenerator = new ColumnGenerator("name", Gen.const("Holden Hanafy"))
+    val ageGenerator = new ColumnGenerator("age", Gen.const(20))
 
-    val dataframeGen = DataframeGenerator.arbitraryDataFrameWithCustomFields(sqlContext, schema)(nameGenerator, ageGenerator)
+    val dataframeGen = DataFrameGenerator.arbitraryDataFrameWithCustomFields(sqlContext, schema)(nameGenerator, ageGenerator)
 
     val actual = runFailingCheck(dataframeGen.arbitrary)
     val expected =
diff --git a/core/src/test/2.0/scala/com/holdenkarau/spark/testing/SampleScalaCheckTest.scala b/core/src/test/2.0/scala/com/holdenkarau/spark/testing/SampleScalaCheckTest.scala
index ddbc606e..85301ffc 100644
--- a/core/src/test/2.0/scala/com/holdenkarau/spark/testing/SampleScalaCheckTest.scala
+++ b/core/src/test/2.0/scala/com/holdenkarau/spark/testing/SampleScalaCheckTest.scala
@@ -143,7 +143,7 @@ class SampleScalaCheckTest extends AnyFunSuite
   test("assert rows' types like schema type") {
     val schema = StructType(
       List(StructField("name", StringType, nullable = false), StructField("age", IntegerType, nullable = false)))
-    val rowGen: Gen[Row] = DataframeGenerator.getRowGenerator(schema)
+    val rowGen: Gen[Row] = DataFrameGenerator.getRowGenerator(schema)
     val property =
       forAll(rowGen) {
         row => row.get(0).isInstanceOf[String] && row.get(1).isInstanceOf[Int]
@@ -152,11 +152,11 @@ class SampleScalaCheckTest extends AnyFunSuite
     check(property)
   }
 
-  test("test generating Dataframes") {
+  test("test generating DataFrames") {
     val schema = StructType(
       List(StructField("name", StringType), StructField("age", IntegerType)))
     val sqlContext = SparkSession.builder.getOrCreate().sqlContext
-    val dataframeGen = DataframeGenerator.arbitraryDataFrame(sqlContext, schema)
+    val dataframeGen = DataFrameGenerator.arbitraryDataFrame(sqlContext, schema)
 
     val property =
       forAll(dataframeGen.arbitrary) {
@@ -170,9 +170,9 @@ class SampleScalaCheckTest extends AnyFunSuite
     val schema = StructType(
       List(StructField("name", StringType), StructField("age", IntegerType)))
     val sqlContext = SparkSession.builder.getOrCreate().sqlContext
-    val ageGenerator = new Column("age", Gen.choose(10, 100))
+    val ageGenerator = new ColumnGenerator("age", Gen.choose(10, 100))
     val dataframeGen =
-      DataframeGenerator.arbitraryDataFrameWithCustomFields(
+      DataFrameGenerator.arbitraryDataFrameWithCustomFields(
         sqlContext, schema)(ageGenerator)
 
     val property =
@@ -190,10 +190,10 @@ class SampleScalaCheckTest extends AnyFunSuite
     List(StructField("name", StringType), StructField("age", IntegerType)))
     val sqlContext = SparkSession.builder.getOrCreate().sqlContext
     // name should be one of Holden or Hanafy
-    val nameGenerator = new Column("name", Gen.oneOf("Holden", "Hanafy"))
-    val ageGenerator = new Column("age", Gen.choose(10, 100))
+    val nameGenerator = new ColumnGenerator("name", Gen.oneOf("Holden", "Hanafy"))
+    val ageGenerator = new ColumnGenerator("age", Gen.choose(10, 100))
     val dataframeGen =
-      DataframeGenerator.arbitraryDataFrameWithCustomFields(
+      DataFrameGenerator.arbitraryDataFrameWithCustomFields(
         sqlContext, schema)(nameGenerator, ageGenerator)
 
     val sqlExpr =
@@ -221,12 +221,12 @@ class SampleScalaCheckTest extends AnyFunSuite
     val sqlContext = SparkSession.builder.getOrCreate().sqlContext
     val userGenerator = new ColumnList("user", Seq(
       // name should be one of Holden or Hanafy
-      new Column("name", Gen.oneOf("Holden", "Hanafy")),
-      new Column("age", Gen.choose(10, 100)),
-      new ColumnList("address", Seq(new Column("zip_code", Gen.choose(100, 200))))
+      new ColumnGenerator("name", Gen.oneOf("Holden", "Hanafy")),
+      new ColumnGenerator("age", Gen.choose(10, 100)),
+      new ColumnList("address", Seq(new ColumnGenerator("zip_code", Gen.choose(100, 200))))
     ))
     val dataframeGen =
-      DataframeGenerator.arbitraryDataFrameWithCustomFields(
+      DataFrameGenerator.arbitraryDataFrameWithCustomFields(
         sqlContext, schema)(userGenerator)
 
     val sqlExpr =
       """
@@ -272,7 +272,7 @@ class SampleScalaCheckTest extends AnyFunSuite
     val sqlContext = SparkSession.builder.getOrCreate().sqlContext
 
     val dataframeGen: Arbitrary[DataFrame] =
-      DataframeGenerator.arbitraryDataFrame(sqlContext, schema)
+      DataFrameGenerator.arbitraryDataFrame(sqlContext, schema)
 
     val property =
       forAll(dataframeGen.arbitrary) {
         dataframe => dataframe.schema === schema &&
@@ -293,7 +293,7 @@ class SampleScalaCheckTest extends AnyFunSuite
     val sqlContext = SparkSession.builder.getOrCreate().sqlContext
 
     val dataframeGen: Arbitrary[DataFrame] =
-      DataframeGenerator.arbitraryDataFrame(sqlContext, schema)
+      DataFrameGenerator.arbitraryDataFrame(sqlContext, schema)
 
     val property =
       forAll(dataframeGen.arbitrary) {
         dataframe => dataframe.schema === schema &&
@@ -307,7 +307,7 @@ class SampleScalaCheckTest extends AnyFunSuite
     val schema = StructType(
       List(StructField("timeStamp", TimestampType), StructField("date", DateType)))
     val sqlContext = SparkSession.builder.getOrCreate().sqlContext
-    val dataframeGen = DataframeGenerator.arbitraryDataFrame(sqlContext, schema)
+    val dataframeGen = DataFrameGenerator.arbitraryDataFrame(sqlContext, schema)
 
     val property =
       forAll(dataframeGen.arbitrary) {
@@ -319,6 +319,24 @@ class SampleScalaCheckTest extends AnyFunSuite
     check(property)
   }
 
+  test("decimal generation mini") {
+    val schema = StructType(List(
+      StructField("bloop", DecimalType(38, 2), nullable=true)))
+
+    val sqlContext = new SQLContext(sc)
+    val dataframeGen = DataFrameGenerator.arbitraryDataFrame(sqlContext, schema)
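+
+    // This exercises the Decimal validity retry added to getGenerator: under
+    // Spark 4's ANSI default, every generated value must survive
+    // Decimal.set(d, 38, 2) before it is put into a row.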
+    val property =
+      forAll(dataframeGen.arbitrary) {
+        dataframe => {
+          dataframe.schema === schema && dataframe.count >= 0
+        }
+      }
+
+    check(property)
+  }
+
 
   test("decimal generation") {
     val schema = StructType(List(
       StructField("small", DecimalType(3, 1)),
@@ -326,7 +344,7 @@ class SampleScalaCheckTest extends AnyFunSuite
       StructField("large", DecimalType(38, 38))))
 
     val sqlContext = new SQLContext(sc)
-    val dataframeGen = DataframeGenerator.arbitraryDataFrame(sqlContext, schema)
+    val dataframeGen = DataFrameGenerator.arbitraryDataFrame(sqlContext, schema)
 
     val property =
       forAll(dataframeGen.arbitrary) {
@@ -342,7 +360,7 @@ class SampleScalaCheckTest extends AnyFunSuite
     val schema = StructType(
       List(StructField("map", MapType(LongType, IntegerType, true))))
     val sqlContext = SparkSession.builder.getOrCreate().sqlContext
-    val dataframeGen = DataframeGenerator.arbitraryDataFrame(sqlContext, schema)
+    val dataframeGen = DataFrameGenerator.arbitraryDataFrame(sqlContext, schema)
 
     val property =
       forAll(dataframeGen.arbitrary) {
@@ -377,7 +395,7 @@ class SampleScalaCheckTest extends AnyFunSuite
     val sqlContext = SparkSession.builder.getOrCreate().sqlContext
 
     val dataframeGen =
-      DataframeGenerator.arbitraryDataFrame(sqlContext, StructType(fields))
+      DataFrameGenerator.arbitraryDataFrame(sqlContext, StructType(fields))
 
     val property =
       forAll(dataframeGen.arbitrary) {
diff --git a/core/src/test/4.0/scala/com/holdenkarau/spark/testing/SampleSparkExpressionTest.scala b/core/src/test/4.0/scala/com/holdenkarau/spark/testing/SampleSparkExpressionTest.scala
new file mode 100644
index 00000000..bb8e0805
--- /dev/null
+++ b/core/src/test/4.0/scala/com/holdenkarau/spark/testing/SampleSparkExpressionTest.scala
@@ -0,0 +1,132 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.holdenkarau.spark.testing
+
+import org.apache.spark.sql.{Column, SparkSession, DataFrame}
+import org.apache.spark.sql.internal._
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.codegen._
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
+import org.apache.spark.sql.types._
+import org.scalatest.Suite
+
+/**
+ * To run this test you must set SPARK_TESTING=yes (or other non-null value).
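+ *
+ * testNonCodegen and testCodegenOnly (below) pin the evaluation mode so the
+ * interpreted and generated-code paths of an expression can be asserted on
+ * separately.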
+ */
+class SampleSparkExpressionTest extends ScalaDataFrameSuiteBase {
+  val inputList = List(
+    FakeMagic("panda"),
+    FakeMagic("coffee"))
+
+  testNonCodegen("non-codegen paths!") {
+    val session = SparkSession.builder().getOrCreate()
+    import session.implicits._
+    val input = sc.parallelize(inputList).toDF
+    val result_working = input.select(WorkingCodegenExpression.work(input("name")) + 1)
+    val result_failing = input.select(FailingCodegenExpression.fail(input("name")) + 1)
+    assert(result_working.collect()(0)(0) === 2)
+    assert(result_failing.collect()(0)(0) === 2)
+  }
+
+  testCodegenOnly("verify codegen tests are run with codegen.") {
+    import sqlContext.implicits._
+    val input = sc.parallelize(inputList).toDF
+    val result_working = input.select(WorkingCodegenExpression.work(input("name")))
+    val result_failing = input.select(FailingCodegenExpression.fail(input("name")))
+    assert(result_working.collect()(0)(0) === 1)
+    assert(result_failing.collect()(0)(0) === 3)
+  }
+}
+
+object WorkingCodegenExpression {
+  private def withExpr(expr: Expression): Column = new Column(
+    EvilExpressionColumnNode.toColumnNode(expr))
+
+  def work(col: Column): Column = withExpr {
+    WorkingCodegenExpression(EvilExpressionColumnNode.getExpr(col.node))
+  }
+}
+
+
+//tag::unary[]
+case class WorkingCodegenExpression(child: Expression) extends UnaryExpression {
+  override def prettyName = "workingCodegen"
+
+  override def nullSafeEval(input: Any): Any = {
+    if (input == null) {
+      return null
+    }
+    return 1
+  }
+
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    // Generate code that yields 1 for any non-null input.
+    val input = child.genCode(ctx)
+    val javaType = CodeGenerator.javaType(dataType)
+    val code = input.code + code"""
+      final $javaType ${ev.value} =
+        ${input.isNull} ? ${CodeGenerator.defaultValue(dataType)} : 1;
+    """
+    ev.copy(code = code, isNull = input.isNull)
+  }
+
+  override def dataType: DataType = IntegerType
+
+  // New in 3.2
+  def withNewChildInternal(newChild: Expression) = {
+    copy(child = newChild)
+  }
+}
+//end::unary[]
+
+object FailingCodegenExpression {
+  private def withExpr(expr: Expression): Column = new Column(
+    EvilExpressionColumnNode.toColumnNode(expr))
+
+  def fail(col: Column): Column = withExpr {
+    FailingCodegenExpression(EvilExpressionColumnNode.getExpr(col.node))
+  }
+}
+
+case class FailingCodegenExpression(child: Expression) extends UnaryExpression {
+  override def prettyName = "failingCodegen"
+
+  override def nullSafeEval(input: Any): Any = {
+    if (input == null) {
+      return null
+    }
+    return 1
+  }
+
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    // Generate deliberately wrong code:
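+    // the Java emitted below returns the constant 3 while nullSafeEval above
+    // returns 1, so the codegen and interpreted paths intentionally disagree;
+    // testCodegenOnly relies on this mismatch to prove codegen actually ran.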
+    val input = child.genCode(ctx)
+    val javaType = CodeGenerator.javaType(dataType)
+    val code = input.code + code"""
+      final $javaType ${ev.value} = 3;
+    """
+    ev.copy(code = code, isNull = input.isNull)
+  }
+
+  override def dataType: DataType = IntegerType
+
+  // New in 3.2
+  def withNewChildInternal(newChild: Expression) = {
+    copy(child = newChild)
+  }
+}
+
+case class FakeMagic(name: String)
diff --git a/core/src/test/3.0/scala/com/holdenkarau/spark/testing/SampleSparkExpressionTest.scala b/core/src/test/pre-4.0/scala/com/holdenkarau/spark/testing/SampleSparkExpressionTest.scala
similarity index 100%
rename from core/src/test/3.0/scala/com/holdenkarau/spark/testing/SampleSparkExpressionTest.scala
rename to core/src/test/pre-4.0/scala/com/holdenkarau/spark/testing/SampleSparkExpressionTest.scala
diff --git a/mini-cross-build.sh b/mini-cross-build.sh
index 5ab2e41c..b6eeead8 100755
--- a/mini-cross-build.sh
+++ b/mini-cross-build.sh
@@ -1,6 +1,6 @@
 #!/bin/bash
 set -ex
-for spark_version in 2.4.8 3.0.0 3.0.1 3.0.2 3.1.1 3.1.2 3.1.3 3.2.0 3.2.1 3.2.2 3.2.3 3.2.4 3.3.0 3.3.1 3.3.2 3.4.0 3.4.1 3.4.2 3.5.0 3.5.1
+for spark_version in 2.4.8 3.0.0 3.0.1 3.0.2 3.1.1 3.1.2 3.1.3 3.2.0 3.2.1 3.2.2 3.2.3 3.2.4 3.3.0 3.3.1 3.3.2 3.4.0 3.4.1 3.4.2 3.5.0 3.5.1 3.5.2 3.5.3 4.0.0-preview1 4.0.0-preview2
 do
   build_dir="/tmp/spark-testing-base-$spark_version"
   mkdir -p "${build_dir}"
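---

For reference, a minimal property sketch against the renamed generator API
(the names come from the tests above; this standalone snippet is illustrative
and not part of the patch):

    import com.holdenkarau.spark.testing.{ColumnGenerator, DataFrameGenerator}
    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.types._
    import org.scalacheck.{Gen, Prop}

    // Pin "age" to a custom generator; DataFrameGenerator derives generators
    // for the remaining columns from their Spark SQL types.
    val schema = StructType(List(
      StructField("name", StringType),
      StructField("age", IntegerType)))
    val sqlContext = SparkSession.builder.getOrCreate().sqlContext
    val ageGenerator = new ColumnGenerator("age", Gen.choose(10, 100))
    val dataframeGen = DataFrameGenerator.arbitraryDataFrameWithCustomFields(
      sqlContext, schema)(ageGenerator)

    // Property: generated frames keep the schema and respect the age bounds.
    val property = Prop.forAll(dataframeGen.arbitrary) { dataframe =>
      dataframe.schema == schema &&
        dataframe.filter("age < 10 or age > 100").count() == 0
    }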