10 changes: 5 additions & 5 deletions .github/workflows/github-actions-basic.yml
@@ -38,9 +38,9 @@ jobs:
- spark: "3.4.0"
java-version: "17"
distribution: "temurin"
# - spark: "4.0.0-PREVIEW2"
# java-version: "17"
# distribution: "temurin"
- spark: "4.0.0-preview2"
java-version: "17"
distribution: "temurin"
env:
SPARK_VERSION: ${{ matrix.spark }}
steps:
@@ -50,11 +50,11 @@
with:
fetch-depth: 0
repository: holdenk/spark-testing-base
ref: main
- uses: actions/setup-java@v3
- uses: actions/setup-java@v4
with:
java-version: ${{ matrix.java-version }}
distribution: ${{ matrix.distribution }}
cache: sbt
- name: Cache maven modules
id: cache-maven
uses: actions/[email protected]
2 changes: 1 addition & 1 deletion README.md
@@ -97,7 +97,7 @@ While we hope you choose our library, https://github.com/juanrh/sscheck , https:

## [Release Notes](RELEASE_NOTES.md)

JDK17 support exists only for Spark 3.4.0
JDK17 support exists only for Spark 3.4.0 & Spark 4 (previews)

## Security Disclosure e-mails

68 changes: 56 additions & 12 deletions build.sbt
@@ -41,11 +41,17 @@ lazy val core = (project in file("core"))
"org.apache.spark" %% "spark-sql" % sparkVersion.value,
"org.apache.spark" %% "spark-hive" % sparkVersion.value,
"org.apache.spark" %% "spark-catalyst" % sparkVersion.value,
"org.apache.spark" %% "spark-yarn" % sparkVersion.value,
"org.apache.spark" %% "spark-mllib" % sparkVersion.value
) ++ commonDependencies ++
{
if (sparkVersion.value > "3.0.0") {
if (sparkVersion.value > "4.0.0") {
Seq(
"org.apache.spark" %% "spark-sql-api" % sparkVersion.value,
"io.netty" % "netty-all" % "4.1.96.Final",
"io.netty" % "netty-tcnative-classes" % "2.0.66.Final",
"com.github.luben" % "zstd-jni" % "1.5.5-4"
)
} else if (sparkVersion.value > "3.0.0") {
Seq(
"io.netty" % "netty-all" % "4.1.77.Final",
"io.netty" % "netty-tcnative-classes" % "2.0.52.Final"
@@ -101,14 +107,24 @@ lazy val kafka_0_8 = {
val commonSettings = Seq(
organization := "com.holdenkarau",
publishMavenStyle := true,
libraryDependencySchemes += "com.github.luben" %% "zstd-jni" % "early-semver",
evictionErrorLevel := Level.Info,
sparkVersion := System.getProperty("sparkVersion", "2.4.8"),
sparkTestingVersion := "1.5.3",
sparkTestingVersion := "1.6.0",
version := sparkVersion.value + "_" + sparkTestingVersion.value,
scalaVersion := {
"2.12.15"
if (sparkVersion.value >= "4.0.0") {
"2.13.13"
} else {
"2.12.15"
}
},
crossScalaVersions := {
if (sparkVersion.value >= "3.2.0") {
if (sparkVersion.value >= "4.0.0") {
Seq("2.13.13")
} else if (sparkVersion.value >= "3.5.0") {
Seq("2.12.15", "2.13.13")
} else if (sparkVersion.value >= "3.2.0") {
Seq("2.12.15", "2.13.10")
} else if (sparkVersion.value >= "3.0.0") {
Seq("2.12.15")
@@ -118,9 +134,13 @@ val commonSettings = Seq(
},
scalacOptions ++= Seq("-deprecation", "-unchecked", "-Yrangepos"),
javacOptions ++= {
Seq("-source", "1.8", "-target", "1.8")
if (sparkVersion.value >= "4.0.0") {
Seq("-source", "17", "-target", "17")
} else {
Seq("-source", "1.8", "-target", "1.8")
}
},
javaOptions ++= Seq("-Xms5G", "-Xmx5G"),
javaOptions ++= Seq("-Xms8G", "-Xmx8G"),

coverageHighlighting := true,

@@ -142,16 +162,31 @@ val commonSettings = Seq(
"Typesafe repository" at "https://repo.typesafe.com/typesafe/releases/",
"Second Typesafe repo" at "https://repo.typesafe.com/typesafe/maven-releases/",
"Mesosphere Public Repository" at "https://downloads.mesosphere.io/maven",
Resolver.sonatypeRepo("public")
Resolver.sonatypeRepo("public"),
Resolver.mavenLocal
)
)

// Allow kafka (and other) utils to have version specific files
val coreSources = unmanagedSourceDirectories in Compile := {
if (sparkVersion.value >= "3.0.0" && scalaVersion.value >= "2.12.0") Seq(
if (sparkVersion.value >= "4.0.0") Seq(
(sourceDirectory in Compile)(_ / "4.0/scala"),
(sourceDirectory in Compile)(_ / "2.2/scala"),
(sourceDirectory in Compile)(_ / "3.0/scala"),
(sourceDirectory in Compile)(_ / "2.0/scala"), (sourceDirectory in Compile)(_ / "2.0/java")
(sourceDirectory in Compile)(_ / "2.0/scala"),
(sourceDirectory in Compile)(_ / "2.0/java")
).join.value
else if (sparkVersion.value >= "3.0.0" && scalaVersion.value >= "2.12.0") Seq(
(sourceDirectory in Compile)(_ / "2.2/scala"),
(sourceDirectory in Compile)(_ / "3.0/scala"),
(sourceDirectory in Compile)(_ / "2.0/scala"),
(sourceDirectory in Compile)(_ / "2.0/java")
).join.value
else if (sparkVersion.value >= "2.4.0" && scalaVersion.value >= "2.12.0") Seq(
(sourceDirectory in Compile)(_ / "2.2/scala"),
@@ -164,7 +199,16 @@ val coreSources = unmanagedSourceDirectories in Compile := {
}

val coreTestSources = unmanagedSourceDirectories in Test := {
if (sparkVersion.value >= "3.0.0" && scalaVersion.value >= "2.12.0") Seq(
if (sparkVersion.value >= "4.0.0" && scalaVersion.value >= "2.12.0") Seq(
(sourceDirectory in Test)(_ / "4.0/scala"),
(sourceDirectory in Test)(_ / "3.0/scala"),
(sourceDirectory in Test)(_ / "3.0/java"),
(sourceDirectory in Test)(_ / "2.2/scala"),
(sourceDirectory in Test)(_ / "2.0/scala"),
(sourceDirectory in Test)(_ / "2.0/java")
).join.value
else if (sparkVersion.value >= "3.0.0" && scalaVersion.value >= "2.12.0") Seq(
(sourceDirectory in Test)(_ / "pre-4.0/scala"),
(sourceDirectory in Test)(_ / "3.0/scala"),
(sourceDirectory in Test)(_ / "3.0/java"),
(sourceDirectory in Test)(_ / "2.2/scala"),
@@ -243,6 +287,6 @@ lazy val publishSettings = Seq(
}
)

lazy val noPublishSettings =
lazy val noPublishSettings = {
skip in publish := true
}
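The net effect of the build.sbt changes above: the Spark version arrives through the sparkVersion system property, and the toolchain settings key off it with plain string comparisons. A minimal sketch of that pattern, with the sbt invocation for a Spark 4 preview build given as an assumed example:

```scala
// Sketch of the version-gating pattern used above (not the full build.sbt).
// A Spark 4 preview build would be selected via the system property, e.g.
//   sbt -DsparkVersion=4.0.0-preview2 test   // assumed invocation
sparkVersion := System.getProperty("sparkVersion", "2.4.8")

// Lexicographically, "4.0.0-preview2" >= "4.0.0", so the previews pick up
// the Scala 2.13 and JDK17 settings.
scalaVersion := {
  if (sparkVersion.value >= "4.0.0") "2.13.13" else "2.12.15"
}
javacOptions ++= {
  if (sparkVersion.value >= "4.0.0") Seq("-source", "17", "-target", "17")
  else Seq("-source", "1.8", "-target", "1.8")
}
```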
@@ -7,7 +7,7 @@ import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, Row, SQLContext}
import org.scalacheck.{Arbitrary, Gen}

object DataframeGenerator {
object DataFrameGenerator {

/**
* Creates a DataFrame Generator for the given Schema.
@@ -48,13 +48,16 @@ object DataframeGenerator {
*/
def arbitraryDataFrameWithCustomFields(
sqlContext: SQLContext, schema: StructType, minPartitions: Int = 1)
(userGenerators: ColumnGenerator*): Arbitrary[DataFrame] = {
(userGenerators: ColumnGeneratorBase*): Arbitrary[DataFrame] = {
import sqlContext._

val arbitraryRDDs = RDDGenerator.genRDD(
sqlContext.sparkContext, minPartitions)(
getRowGenerator(schema, userGenerators))
Arbitrary {
arbitraryRDDs.map(sqlContext.createDataFrame(_, schema))
arbitraryRDDs.map { r =>
sqlContext.createDataFrame(r, schema)
}
}
}

@@ -80,7 +83,7 @@
* @return Gen[Row]
*/
def getRowGenerator(
schema: StructType, customGenerators: Seq[ColumnGenerator]): Gen[Row] = {
schema: StructType, customGenerators: Seq[ColumnGeneratorBase]): Gen[Row] = {
val generators: List[Gen[Any]] =
createGenerators(schema.fields, customGenerators)
val listGen: Gen[List[Any]] =
@@ -92,14 +95,14 @@

private def createGenerators(
fields: Array[StructField],
userGenerators: Seq[ColumnGenerator]):
userGenerators: Seq[ColumnGeneratorBase]):
List[Gen[Any]] = {
val generatorMap = userGenerators.map(
generator => (generator.columnName -> generator)).toMap
fields.toList.map { field =>
if (generatorMap.contains(field.name)) {
generatorMap.get(field.name) match {
case Some(gen: Column) => gen.gen
case Some(gen: ColumnGenerator) => gen.gen
case Some(list: ColumnList) => getGenerator(field.dataType, list.gen, nullable = field.nullable)
}
}
@@ -109,7 +112,7 @@

private def getGenerator(
dataType: DataType,
generators: Seq[ColumnGenerator] = Seq(),
generators: Seq[ColumnGeneratorBase] = Seq(),
nullable: Boolean = false): Gen[Any] = {
val nonNullGen = dataType match {
case ByteType => Arbitrary.arbitrary[Byte]
@@ -128,9 +131,21 @@
l => new Date(l/10000)
}
case dec: DecimalType => {
// With the new ANSI default we need to make sure we're passing in
// valid values.
Arbitrary.arbitrary[BigDecimal]
.retryUntil(_.precision <= dec.precision)
.retryUntil { d =>
try {
val sd = new Decimal()
// Make sure it can be converted
sd.set(d, dec.precision, dec.scale)
true
} catch {
case e: Exception => false
}
}
.map(_.bigDecimal.setScale(dec.scale, RoundingMode.HALF_UP))
.asInstanceOf[Gen[java.math.BigDecimal]]
}
case arr: ArrayType => {
val elementGenerator = getGenerator(arr.elementType, nullable = arr.containsNull)
@@ -165,27 +180,27 @@ object DataframeGenerator {
}

/**
* Previously ColumnGenerator. Allows the user to specify a generator for a
* Previously Column. Allows the user to specify a generator for a
* specific column.
*/
class Column(val columnName: String, generator: => Gen[Any])
extends ColumnGenerator {
class ColumnGenerator(val columnName: String, generator: => Gen[Any])
extends ColumnGeneratorBase {
lazy val gen = generator
}

/**
* ColumnList allows users to specify custom generators for a list of
* columns inside a StructType column.
*/
class ColumnList(val columnName: String, generators: => Seq[ColumnGenerator])
extends ColumnGenerator {
class ColumnList(val columnName: String, generators: => Seq[ColumnGeneratorBase])
extends ColumnGeneratorBase {
lazy val gen = generators
}

/**
* ColumnGenerator - previously Column; it is now the base class for all
* ColumnGenerators.
*/
abstract class ColumnGenerator extends java.io.Serializable {
abstract class ColumnGeneratorBase extends java.io.Serializable {
val columnName: String
}
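Because this file renames the public API (DataframeGenerator becomes DataFrameGenerator, Column becomes ColumnGenerator, and the old ColumnGenerator base class becomes ColumnGeneratorBase), a short usage sketch of the renamed entry points may help. It mirrors the PrettifyTest changes below; the package name and the local-master session setup are assumptions:

```scala
import com.holdenkarau.spark.testing.{ColumnGenerator, DataFrameGenerator}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types._
import org.scalacheck.Gen
import org.scalacheck.Prop.forAll

val spark = SparkSession.builder.master("local[2]").getOrCreate()
val schema = StructType(List(
  StructField("name", StringType),
  StructField("age", IntegerType)))

// ColumnGenerator here is the class previously named Column.
val nameGen = new ColumnGenerator("name", Gen.const("Holden Hanafy"))
val dfGen = DataFrameGenerator.arbitraryDataFrameWithCustomFields(
  spark.sqlContext, schema)(nameGen)

// Every generated DataFrame keeps the requested schema.
val prop = forAll(dfGen.arbitrary) { df =>
  df.schema == schema && df.count() >= 0
}
```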
@@ -8,15 +8,15 @@ trait Prettify {
val maxNumberOfShownValues = 100

implicit def prettyDataFrame(dataframe: DataFrame): Pretty =
Pretty { _ => describeDataframe(dataframe)}
Pretty { _ => describeDataFrame(dataframe)}

implicit def prettyRDD(rdd: RDD[_]): Pretty =
Pretty { _ => describeRDD(rdd)}

implicit def prettyDataset(dataset: Dataset[_]): Pretty =
Pretty { _ => describeDataset(dataset)}

private def describeDataframe(dataframe: DataFrame) =
private def describeDataFrame(dataframe: DataFrame) =
s"""<DataFrame: schema = ${dataframe.toString}, size = ${dataframe.count()},
|values = (${dataframe.take(maxNumberOfShownValues).mkString(", ")})>""".
stripMargin.replace("\n", " ")
@@ -0,0 +1,14 @@
package org.apache.spark.sql.internal

import org.apache.spark.sql._
import org.apache.spark.sql.internal._
import org.apache.spark.sql.catalyst.expressions._

object EvilExpressionColumnNode {
def getExpr(node: ColumnNode): Expression = {
ColumnNodeToExpressionConverter.apply(node)
}
def toColumnNode(expr: Expression): ColumnNode = {
ExpressionColumnNode(expr)
}
}
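This new shim lives in org.apache.spark.sql.internal, presumably so the test helpers can reach the converter between Spark 4's ColumnNode representation and Catalyst Expressions. A sketch of the round trip it enables, using a Literal as a stand-in expression:

```scala
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.internal.EvilExpressionColumnNode

// Wrap a Catalyst Expression as a ColumnNode, then recover it.
val node = EvilExpressionColumnNode.toColumnNode(Literal(42))
val expr = EvilExpressionColumnNode.getExpr(node) // back to Literal(42)
```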
@@ -15,12 +15,12 @@ class MLScalaCheckTest extends AnyFunSuite with SharedSparkContext with Checkers
test("vector generation") {
val schema = StructType(List(StructField("vector", VectorType)))
val sqlContext = SparkSession.builder.getOrCreate().sqlContext
val dataframeGen = DataframeGenerator.arbitraryDataFrame(sqlContext, schema)
val dataFrameGen = DataFrameGenerator.arbitraryDataFrame(sqlContext, schema)

val property =
forAll(dataframeGen.arbitrary) {
dataframe => {
dataframe.schema === schema && dataframe.count >= 0
forAll(dataFrameGen.arbitrary) {
dataFrame => {
dataFrame.schema === schema && dataFrame.count >= 0
}
}

@@ -30,12 +30,12 @@ class MLScalaCheckTest extends AnyFunSuite with SharedSparkContext with Checkers
test("matrix generation") {
val schema = StructType(List(StructField("matrix", MatrixType)))
val sqlContext = SparkSession.builder.getOrCreate().sqlContext
val dataframeGen = DataframeGenerator.arbitraryDataFrame(sqlContext, schema)
val dataFrameGen = DataFrameGenerator.arbitraryDataFrame(sqlContext, schema)

val property =
forAll(dataframeGen.arbitrary) {
dataframe => {
dataframe.schema === schema && dataframe.count >= 0
forAll(dataFrameGen.arbitrary) {
dataFrame => {
dataFrame.schema === schema && dataFrame.count >= 0
}
}

@@ -16,10 +16,10 @@ class PrettifyTest extends AnyFunSuite with SharedSparkContext with Checkers wit
test("pretty output of DataFrame's check") {
val schema = StructType(List(StructField("name", StringType), StructField("age", IntegerType)))
val sqlContext = SparkSession.builder.getOrCreate().sqlContext
val nameGenerator = new Column("name", Gen.const("Holden Hanafy"))
val ageGenerator = new Column("age", Gen.const(20))
val nameGenerator = new ColumnGenerator("name", Gen.const("Holden Hanafy"))
val ageGenerator = new ColumnGenerator("age", Gen.const(20))

val dataframeGen = DataframeGenerator.arbitraryDataFrameWithCustomFields(sqlContext, schema)(nameGenerator, ageGenerator)
val dataframeGen = DataFrameGenerator.arbitraryDataFrameWithCustomFields(sqlContext, schema)(nameGenerator, ageGenerator)

val actual = runFailingCheck(dataframeGen.arbitrary)
val expected =