Skip to content

Commit df021c0

Browse files
committed
Fixing more tests. We can now remain on Kotlin 2.0 if we set `-Xlambdas=class`, which can be done with the Gradle plugin.
1 parent 7069a9a commit df021c0

File tree

12 files changed

+103
-52
lines changed

12 files changed

+103
-52
lines changed

buildSrc/src/main/kotlin/Versions.kt

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,7 @@ object Versions : Dsl<Versions> {
22
const val project = "2.0.0-SNAPSHOT"
33
const val kotlinSparkApiGradlePlugin = "2.0.0-SNAPSHOT"
44
const val groupID = "org.jetbrains.kotlinx.spark"
5-
// const val kotlin = "2.0.0-Beta5" // todo issues with NonSerializable lambdas
6-
const val kotlin = "1.9.23"
5+
const val kotlin = "2.0.0-Beta5"
76
const val jvmTarget = "8"
87
const val jupyterJvmTarget = "8"
98
inline val spark get() = System.getProperty("spark") as String

gradle-plugin/src/main/kotlin/org/jetbrains/kotlinx/spark/api/gradlePlugin/SparkKotlinCompilerGradlePlugin.kt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,9 @@ class SparkKotlinCompilerGradlePlugin : KotlinCompilerPluginSupportPlugin {
2020
compilerOptions {
2121
// Make sure the parameters of data classes are visible to scala
2222
javaParameters.set(true)
23+
24+
// Avoid NotSerializableException by making lambdas serializable
25+
freeCompilerArgs.add("-Xlambdas=class")
2326
}
2427
}
2528
}

kotlin-spark-api/build.gradle.kts

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,6 @@ tasks.compileTestKotlin {
147147
kotlin {
148148
jvmToolchain {
149149
languageVersion = JavaLanguageVersion.of(Versions.jvmTarget)
150-
151150
}
152151
}
153152

kotlin-spark-api/src/main/kotlin/org/jetbrains/kotlinx/spark/api/Encoding.kt

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ import org.apache.spark.sql.types.UDTRegistration
4646
import org.apache.spark.sql.types.UserDefinedType
4747
import org.apache.spark.unsafe.types.CalendarInterval
4848
import scala.reflect.ClassTag
49+
import java.io.Serializable
4950
import kotlin.reflect.KClass
5051
import kotlin.reflect.KMutableProperty
5152
import kotlin.reflect.KType
@@ -122,7 +123,10 @@ fun schemaFor(kType: KType): DataType = kotlinEncoderFor<Any?>(kType).schema().u
122123
@Deprecated("Use schemaFor instead", ReplaceWith("schemaFor(kType)"))
123124
fun schema(kType: KType) = schemaFor(kType)
124125

125-
object KotlinTypeInference {
126+
object KotlinTypeInference : Serializable {
127+
128+
// https://blog.stylingandroid.com/kotlin-serializable-objects/
129+
private fun readResolve(): Any = KotlinTypeInference
126130

127131
/**
128132
* @param kClass the class for which to infer the encoder.

kotlin-spark-api/src/main/kotlin/org/jetbrains/kotlinx/spark/api/RddDouble.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ inline fun <reified T : Number> JavaRDD<T>.toJavaDoubleRDD(): JavaDoubleRDD =
2020

2121
/** Utility method to convert [JavaDoubleRDD] to [JavaRDD]<[Double]>. */
2222
@Suppress("UNCHECKED_CAST")
23-
fun JavaDoubleRDD.toDoubleRDD(): JavaRDD<Double> =
23+
inline fun JavaDoubleRDD.toDoubleRDD(): JavaRDD<Double> =
2424
JavaDoubleRDD.toRDD(this).toJavaRDD() as JavaRDD<Double>
2525

2626
/** Add up the elements in this RDD. */

kotlin-spark-api/src/main/kotlin/org/jetbrains/kotlinx/spark/api/SparkSession.kt

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ import org.apache.spark.streaming.Durations
4444
import org.apache.spark.streaming.api.java.JavaStreamingContext
4545
import org.jetbrains.kotlinx.spark.api.SparkLogLevel.ERROR
4646
import org.jetbrains.kotlinx.spark.api.tuples.*
47+
import scala.reflect.ClassTag
4748
import java.io.Serializable
4849

4950
/**
@@ -406,7 +407,7 @@ private fun getDefaultHadoopConf(): Configuration {
406407
* @return `Broadcast` object, a read-only variable cached on each machine
407408
*/
408409
inline fun <reified T> SparkSession.broadcast(value: T): Broadcast<T> = try {
409-
sparkContext.broadcast(value, kotlinEncoderFor<T>().clsTag())
410+
sparkContext.broadcast(value, ClassTag.apply(T::class.java))
410411
} catch (e: ClassNotFoundException) {
411412
JavaSparkContext(sparkContext).broadcast(value)
412413
}
@@ -426,7 +427,7 @@ inline fun <reified T> SparkSession.broadcast(value: T): Broadcast<T> = try {
426427
DeprecationLevel.WARNING
427428
)
428429
inline fun <reified T> SparkContext.broadcast(value: T): Broadcast<T> = try {
429-
broadcast(value, kotlinEncoderFor<T>().clsTag())
430+
broadcast(value, ClassTag.apply(T::class.java))
430431
} catch (e: ClassNotFoundException) {
431432
JavaSparkContext(this).broadcast(value)
432433
}

kotlin-spark-api/src/main/kotlin/org/jetbrains/kotlinx/spark/api/UserDefinedFunction.kt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,9 +69,9 @@ class TypeOfUDFParameterNotSupportedException(kClass: KClass<*>, parameterName:
6969
)
7070

7171
@JvmName("arrayColumnAsSeq")
72-
fun <DsType, T> TypedColumn<DsType, Array<T>>.asSeq(): TypedColumn<DsType, Seq<T>> = typed()
72+
inline fun <DsType, reified T> TypedColumn<DsType, Array<T>>.asSeq(): TypedColumn<DsType, Seq<T>> = typed()
7373
@JvmName("iterableColumnAsSeq")
74-
fun <DsType, T, I : Iterable<T>> TypedColumn<DsType, I>.asSeq(): TypedColumn<DsType, Seq<T>> = typed()
74+
inline fun <DsType, reified T, I : Iterable<T>> TypedColumn<DsType, I>.asSeq(): TypedColumn<DsType, Seq<T>> = typed()
7575
@JvmName("byteArrayColumnAsSeq")
7676
fun <DsType> TypedColumn<DsType, ByteArray>.asSeq(): TypedColumn<DsType, Seq<Byte>> = typed()
7777
@JvmName("charArrayColumnAsSeq")

kotlin-spark-api/src/test/kotlin/org/jetbrains/kotlinx/spark/api/EncodingTest.kt

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,9 @@ import java.time.Period
4141

4242
class EncodingTest : ShouldSpec({
4343

44+
@Sparkify
45+
data class SparkifiedPair<T, U>(val first: T, val second: U)
46+
4447
context("encoders") {
4548
withSpark(props = mapOf("spark.sql.codegen.comments" to true)) {
4649

@@ -134,8 +137,8 @@ class EncodingTest : ShouldSpec({
134137
}
135138

136139
should("be able to serialize Date") {
137-
val datePair = Date.valueOf("2020-02-10") to 5
138-
val dataset: Dataset<Pair<Date, Int>> = dsOf(datePair)
140+
val datePair = SparkifiedPair(Date.valueOf("2020-02-10"), 5)
141+
val dataset: Dataset<SparkifiedPair<Date, Int>> = dsOf(datePair)
139142
dataset.collectAsList() shouldBe listOf(datePair)
140143
}
141144

@@ -213,6 +216,8 @@ class EncodingTest : ShouldSpec({
213216

214217
context("Give proper names to columns of data classes") {
215218

219+
infix fun <A, B> A.to(other: B) = SparkifiedPair(this, other)
220+
216221
should("Be able to serialize pairs") {
217222
val pairs = listOf(
218223
1 to "1",
@@ -653,25 +658,25 @@ class EncodingTest : ShouldSpec({
653658
}
654659

655660
should("handle arrays of generics") {
656-
data class Test<Z>(val id: Long, val data: Array<Pair<Z, Int>>)
661+
data class Test<Z>(val id: Long, val data: Array<SparkifiedPair<Z, Int>>)
657662

658-
val result = listOf(Test(1, arrayOf(5.1 to 6, 6.1 to 7)))
663+
val result = listOf(Test(1, arrayOf(SparkifiedPair(5.1, 6), SparkifiedPair(6.1, 7))))
659664
.toDS()
660665
.map { it.id to it.data.firstOrNull { liEl -> liEl.first < 6 } }
661666
.map { it.second }
662667
.collectAsList()
663-
expect(result).toContain.inOrder.only.values(5.1 to 6)
668+
expect(result).toContain.inOrder.only.values(SparkifiedPair(5.1, 6))
664669
}
665670

666671
should("handle lists of generics") {
667-
data class Test<Z>(val id: Long, val data: List<Pair<Z, Int>>)
672+
data class Test<Z>(val id: Long, val data: List<SparkifiedPair<Z, Int>>)
668673

669-
val result = listOf(Test(1, listOf(5.1 to 6, 6.1 to 7)))
674+
val result = listOf(Test(1, listOf(SparkifiedPair(5.1, 6), SparkifiedPair(6.1, 7))))
670675
.toDS()
671676
.map { it.id to it.data.firstOrNull { liEl -> liEl.first < 6 } }
672677
.map { it.second }
673678
.collectAsList()
674-
expect(result).toContain.inOrder.only.values(5.1 to 6)
679+
expect(result).toContain.inOrder.only.values(SparkifiedPair(5.1, 6))
675680
}
676681

677682
should("handle boxed arrays") {

kotlin-spark-api/src/test/kotlin/org/jetbrains/kotlinx/spark/api/RddTest.kt

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,15 @@ import io.kotest.matchers.shouldBe
66
import org.apache.spark.api.java.JavaRDD
77
import org.jetbrains.kotlinx.spark.api.tuples.*
88
import scala.Tuple2
9+
import java.io.Serializable
910

10-
class RddTest : ShouldSpec({
11+
class RddTest : Serializable, ShouldSpec({
1112
context("RDD extension functions") {
1213

13-
withSpark(logLevel = SparkLogLevel.DEBUG) {
14+
withSpark(
15+
props = mapOf("spark.sql.codegen.wholeStage" to false),
16+
logLevel = SparkLogLevel.DEBUG,
17+
) {
1418

1519
context("Key/value") {
1620
should("work with spark example") {

kotlin-spark-api/src/test/kotlin/org/jetbrains/kotlinx/spark/api/StreamingTest.kt

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ import scala.Tuple2
3939
import java.io.File
4040
import java.io.Serializable
4141
import java.nio.charset.StandardCharsets
42+
import java.nio.file.Files
4243
import java.util.*
4344
import java.util.concurrent.atomic.AtomicBoolean
4445

@@ -201,10 +202,10 @@ class StreamingTest : ShouldSpec({
201202

202203
private val scalaCompatVersion = SCALA_COMPAT_VERSION
203204
private val sparkVersion = SPARK_VERSION
204-
private fun createTempDir() = File.createTempFile(
205-
System.getProperty("java.io.tmpdir"),
206-
"spark_${scalaCompatVersion}_${sparkVersion}"
207-
).apply { deleteOnExit() }
205+
private fun createTempDir() =
206+
Files.createTempDirectory("spark_${scalaCompatVersion}_${sparkVersion}")
207+
.toFile()
208+
.also { it.deleteOnExit() }
208209

209210
private fun checkpointFile(checkpointDir: String, checkpointTime: Time): Path {
210211
val klass = Class.forName("org.apache.spark.streaming.Checkpoint$")
@@ -215,7 +216,10 @@ private fun checkpointFile(checkpointDir: String, checkpointTime: Time): Path {
215216
return checkpointFileMethod.invoke(module, checkpointDir, checkpointTime) as Path
216217
}
217218

218-
private fun getCheckpointFiles(checkpointDir: String, fs: scala.Option<FileSystem>): scala.collection.immutable.Seq<Path> {
219+
private fun getCheckpointFiles(
220+
checkpointDir: String,
221+
fs: scala.Option<FileSystem>
222+
): scala.collection.immutable.Seq<Path> {
219223
val klass = Class.forName("org.apache.spark.streaming.Checkpoint$")
220224
val moduleField = klass.getField("MODULE$").also { it.isAccessible = true }
221225
val module = moduleField.get(null)
@@ -227,7 +231,11 @@ private fun getCheckpointFiles(checkpointDir: String, fs: scala.Option<FileSyste
227231
private fun createCorruptedCheckpoint(): String {
228232
val checkpointDirectory = createTempDir().absolutePath
229233
val fakeCheckpointFile = checkpointFile(checkpointDirectory, Time(1000))
230-
FileUtils.write(File(fakeCheckpointFile.toString()), "spark_corrupt_${scalaCompatVersion}_${sparkVersion}", StandardCharsets.UTF_8)
234+
FileUtils.write(
235+
File(fakeCheckpointFile.toString()),
236+
"spark_corrupt_${scalaCompatVersion}_${sparkVersion}",
237+
StandardCharsets.UTF_8
238+
)
231239
assert(getCheckpointFiles(checkpointDirectory, (null as FileSystem?).toOption()).nonEmpty())
232240
return checkpointDirectory
233241
}

0 commit comments

Comments (0)