diff --git a/.github/workflows/build_and_test.yml b/.github/workflows/build_and_test.yml index 311274c9203ae..630956a9e73aa 100644 --- a/.github/workflows/build_and_test.yml +++ b/.github/workflows/build_and_test.yml @@ -249,7 +249,10 @@ jobs: # Run the tests. - name: Run tests env: ${{ fromJSON(inputs.envs) }} + shell: 'script -q -e -c "bash {0}"' run: | + # Fix for TTY related issues when launching the Ammonite REPL in tests. + export TERM=vt100 && script -qfc 'echo exit | amm -s' && rm typescript # Hive "other tests" test needs larger metaspace size based on experiment. if [[ "$MODULES_TO_TEST" == "hive" ]] && [[ "$EXCLUDED_TAGS" == "org.apache.spark.tags.SlowHiveTest" ]]; then export METASPACE_SIZE=2g; fi export SERIAL_SBT_TESTS=1 diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala index b1b779f0f08e7..e285db39e8070 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -33,7 +33,7 @@ import org.apache.spark.sql.catalog.Catalog import org.apache.spark.sql.catalyst.{JavaTypeInference, ScalaReflection} import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, RowEncoder} import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{BoxedLongEncoder, UnboundRowEncoder} -import org.apache.spark.sql.connect.client.{SparkConnectClient, SparkResult} +import org.apache.spark.sql.connect.client.{ClassFinder, SparkConnectClient, SparkResult} import org.apache.spark.sql.connect.client.util.{Cleaner, ConvertToArrow} import org.apache.spark.sql.connect.common.LiteralValueProtoConverter.toLiteralProto import org.apache.spark.sql.internal.CatalogImpl @@ -495,6 +495,14 @@ class SparkSession private[sql] ( @scala.annotation.varargs def addArtifacts(uri: URI*): Unit = client.addArtifacts(uri) + /** + * Register a [[ClassFinder]] for dynamically generated classes. + * + * @since 3.4.0 + */ + @Experimental + def registerClassFinder(finder: ClassFinder): Unit = client.registerClassFinder(finder) + /** * This resets the plan id generator so we can produce plans that are comparable. * diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/application/ConnectRepl.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/application/ConnectRepl.scala index ec31697ee59e2..53a31fed489e8 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/application/ConnectRepl.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/application/ConnectRepl.scala @@ -16,9 +16,11 @@ */ package org.apache.spark.sql.application +import java.io.{InputStream, OutputStream} +import java.util.concurrent.Semaphore + import scala.util.control.NonFatal -import ammonite.compiler.CodeClassWrapper import ammonite.util.Bind import org.apache.spark.annotation.DeveloperApi @@ -43,7 +45,14 @@ object ConnectRepl { | /_/ |""".stripMargin - def main(args: Array[String]): Unit = { + def main(args: Array[String]): Unit = doMain(args) + + private[application] def doMain( + args: Array[String], + semaphore: Option[Semaphore] = None, + inputStream: InputStream = System.in, + outputStream: OutputStream = System.out, + errorStream: OutputStream = System.err): Unit = { // Build the client. val client = try { @@ -67,22 +76,30 @@ object ConnectRepl { // Build the session. val spark = SparkSession.builder().client(client).build() + val sparkBind = new Bind("spark", spark) - // Add the proper imports. - val imports = + // Add the proper imports and register a [[ClassFinder]]. + val predefCode = """ |import org.apache.spark.sql.functions._ |import spark.implicits._ |import spark.sql + |import org.apache.spark.sql.connect.client.AmmoniteClassFinder + | + |spark.registerClassFinder(new AmmoniteClassFinder(repl.sess)) |""".stripMargin - // Please note that we make ammonite generate classes instead of objects. - // Classes tend to have superior serialization behavior when using UDFs. val main = ammonite.Main( welcomeBanner = Option(splash), - predefCode = imports, - replCodeWrapper = CodeClassWrapper, - scriptCodeWrapper = CodeClassWrapper) - main.run(new Bind("spark", spark)) + predefCode = predefCode, + inputStream = inputStream, + outputStream = outputStream, + errorStream = errorStream) + if (semaphore.nonEmpty) { + // Used for testing. + main.run(sparkBind, new Bind[Semaphore]("semaphore", semaphore.get)) + } else { + main.run(sparkBind) + } } } diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/ArtifactManager.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/ArtifactManager.scala index ead500a53e639..ef3d66c85bc3b 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/ArtifactManager.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/ArtifactManager.scala @@ -16,11 +16,13 @@ */ package org.apache.spark.sql.connect.client -import java.io.InputStream +import java.io.{ByteArrayInputStream, InputStream} import java.net.URI import java.nio.file.{Files, Path, Paths} +import java.util.concurrent.CopyOnWriteArrayList import java.util.zip.{CheckedInputStream, CRC32} +import scala.collection.JavaConverters._ import scala.collection.mutable import scala.concurrent.Promise import scala.concurrent.duration.Duration @@ -48,6 +50,12 @@ class ArtifactManager(userContext: proto.UserContext, channel: ManagedChannel) { private val CHUNK_SIZE: Int = 32 * 1024 private[this] val stub = proto.SparkConnectServiceGrpc.newStub(channel) + private[this] val classFinders = new CopyOnWriteArrayList[ClassFinder] + + /** + * Register a [[ClassFinder]] for dynamically generated classes. + */ + def registerClassFinder(finder: ClassFinder): Unit = classFinders.add(finder) /** * Add a single artifact to the session. @@ -92,10 +100,23 @@ class ArtifactManager(userContext: proto.UserContext, channel: ManagedChannel) { */ def addArtifacts(uris: Seq[URI]): Unit = addArtifacts(uris.flatMap(parseArtifacts)) + /** + * Upload all class file artifacts from the local REPL(s) to the server. + * + * The registered [[ClassFinder]]s are traversed to retrieve the class file artifacts. + */ + private[client] def uploadAllClassFileArtifacts(): Unit = { + addArtifacts(classFinders.asScala.flatMap(_.findClasses())) + } + /** * Add a number of artifacts to the session. */ private def addArtifacts(artifacts: Iterable[Artifact]): Unit = { + if (artifacts.isEmpty) { + return + } + val promise = Promise[Seq[ArtifactSummary]] val responseHandler = new StreamObserver[proto.AddArtifactsResponse] { private val summaries = mutable.Buffer.empty[ArtifactSummary] @@ -302,4 +323,13 @@ object Artifact { override def size: Long = Files.size(path) override def stream: InputStream = Files.newInputStream(path) } + + /** + * Payload stored in memory. + */ + class InMemory(bytes: Array[Byte]) extends LocalData { + override def size: Long = bytes.length + override def stream: InputStream = new ByteArrayInputStream(bytes) + } + } diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/ClassFinder.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/ClassFinder.scala new file mode 100644 index 0000000000000..0371d42f2d629 --- /dev/null +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/ClassFinder.scala @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connect.client + +import java.net.URL +import java.nio.file.{Files, LinkOption, Path, Paths} + +import scala.collection.JavaConverters._ + +import ammonite.repl.api.Session +import ammonite.runtime.SpecialClassLoader + +import org.apache.spark.sql.connect.client.Artifact.{InMemory, LocalFile} + +trait ClassFinder { + def findClasses(): Iterator[Artifact] +} + +/** + * A generic [[ClassFinder]] implementation that traverses a specific REPL output directory. + * @param _rootDir + */ +class REPLClassDirMonitor(_rootDir: String) extends ClassFinder { + private val rootDir = Paths.get(_rootDir) + require(rootDir.isAbsolute) + require(Files.isDirectory(rootDir)) + + override def findClasses(): Iterator[Artifact] = { + Files + .walk(rootDir) + // Ignore symbolic links + .filter(path => Files.isRegularFile(path, LinkOption.NOFOLLOW_LINKS) && isClass(path)) + .map[Artifact](path => toArtifact(path)) + .iterator() + .asScala + } + + private def toArtifact(path: Path): Artifact = { + // Persist the relative path of the classfile + Artifact.newClassArtifact(rootDir.relativize(path), new LocalFile(path)) + } + + private def isClass(path: Path): Boolean = path.toString.endsWith(".class") +} + +/** + * A special [[ClassFinder]] for the Ammonite REPL to handle in-memory class files. + * @param session + */ +class AmmoniteClassFinder(session: Session) extends ClassFinder { + + override def findClasses(): Iterator[Artifact] = { + session.frames.iterator.flatMap { frame => + val classloader = frame.classloader.asInstanceOf[SpecialClassLoader] + val signatures: Seq[(Either[String, URL], Long)] = classloader.classpathSignature + signatures.iterator.collect { case (Left(name), _) => + val parts = name.split('.') + parts(parts.length - 1) += ".class" + val path = Paths.get(parts.head, parts.tail: _*) + val bytes = classloader.newFileDict(name) + Artifact.newClassArtifact(path, new InMemory(bytes)) + } + } + } +} diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala index fd9ced6eb62fc..924515166d851 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala @@ -58,10 +58,13 @@ private[sql] class SparkConnectClient( * @return * A [[proto.AnalyzePlanResponse]] from the Spark Connect server. */ - def analyze(request: proto.AnalyzePlanRequest): proto.AnalyzePlanResponse = + def analyze(request: proto.AnalyzePlanRequest): proto.AnalyzePlanResponse = { + artifactManager.uploadAllClassFileArtifacts() stub.analyzePlan(request) + } def execute(plan: proto.Plan): java.util.Iterator[proto.ExecutePlanResponse] = { + artifactManager.uploadAllClassFileArtifacts() val request = proto.ExecutePlanRequest .newBuilder() .setPlan(plan) @@ -201,6 +204,11 @@ private[sql] class SparkConnectClient( */ def addArtifacts(uri: Seq[URI]): Unit = artifactManager.addArtifacts(uri) + /** + * Register a [[ClassFinder]] for dynamically generated classes. + */ + def registerClassFinder(finder: ClassFinder): Unit = artifactManager.registerClassFinder(finder) + /** * Shutdown the client's connection to the server. */ diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala index ff1d721dd58c0..77fe12568476b 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala @@ -82,17 +82,6 @@ class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper { assert(result(2) == 2) } - ignore("SPARK-42665: Ignore simple udf test until the udf is fully implemented.") { - def dummyUdf(x: Int): Int = x + 5 - val myUdf = udf(dummyUdf _) - val df = spark.range(5).select(myUdf(Column("id"))) - val result = df.collect() - assert(result.length == 5) - result.zipWithIndex.foreach { case (v, idx) => - assert(v.getInt(0) == idx + 5) - } - } - test("read and write") { val testDataPath = java.nio.file.Paths .get( diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala new file mode 100644 index 0000000000000..f0ec28a5a8792 --- /dev/null +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala @@ -0,0 +1,128 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.application + +import java.io.{PipedInputStream, PipedOutputStream} +import java.util.concurrent.{Executors, Semaphore, TimeUnit} + +import org.apache.commons.io.output.ByteArrayOutputStream + +import org.apache.spark.sql.connect.client.util.RemoteSparkSession + +class ReplE2ESuite extends RemoteSparkSession { + + private val executorService = Executors.newSingleThreadExecutor() + private val TIMEOUT_SECONDS = 10 + + private var testSuiteOut: PipedOutputStream = _ + private var ammoniteOut: ByteArrayOutputStream = _ + private var errorStream: ByteArrayOutputStream = _ + private var ammoniteIn: PipedInputStream = _ + private val semaphore: Semaphore = new Semaphore(0) + + private def getCleanString(out: ByteArrayOutputStream): String = { + // Remove ANSI colour codes + // Regex taken from https://stackoverflow.com/a/25189932 + out.toString("UTF-8").replaceAll("\u001B\\[[\\d;]*[^\\d;]", "") + } + + override def beforeAll(): Unit = { + super.beforeAll() + ammoniteOut = new ByteArrayOutputStream() + testSuiteOut = new PipedOutputStream() + // Connect the `testSuiteOut` and `ammoniteIn` pipes + ammoniteIn = new PipedInputStream(testSuiteOut) + errorStream = new ByteArrayOutputStream() + + val args = Array("--port", serverPort.toString) + val task = new Runnable { + override def run(): Unit = { + ConnectRepl.doMain( + args = args, + semaphore = Some(semaphore), + inputStream = ammoniteIn, + outputStream = ammoniteOut, + errorStream = errorStream) + } + } + + executorService.submit(task) + } + + override def afterAll(): Unit = { + executorService.shutdownNow() + super.afterAll() + } + + def runCommandsInShell(input: String): String = { + require(input.nonEmpty) + // Pad the input with a semaphore release so that we know when the execution of the provided + // input is complete. + val paddedInput = input + '\n' + "semaphore.release()\n" + testSuiteOut.write(paddedInput.getBytes) + testSuiteOut.flush() + if (!semaphore.tryAcquire(TIMEOUT_SECONDS, TimeUnit.SECONDS)) { + val failOut = getCleanString(ammoniteOut) + val errOut = getCleanString(errorStream) + val errorString = + s""" + |REPL Timed out while running command: $input + |Console output: $failOut + |Error output: $errOut + |""".stripMargin + throw new RuntimeException(errorString) + } + getCleanString(ammoniteOut) + } + + def assertContains(message: String, output: String): Unit = { + val isContain = output.contains(message) + assert(isContain, "Ammonite output did not contain '" + message + "':\n" + output) + } + + test("Simple query") { + // Run simple query to test REPL + val input = """ + |spark.sql("select 1").collect() + """.stripMargin + val output = runCommandsInShell(input) + assertContains("Array[org.apache.spark.sql.Row] = Array([1])", output) + } + + test("UDF containing 'def'") { + val input = """ + |class A(x: Int) { def get = x * 5 + 19 } + |def dummyUdf(x: Int): Int = new A(x).get + |val myUdf = udf(dummyUdf _) + |spark.range(5).select(myUdf(col("id"))).as[Int].collect() + """.stripMargin + val output = runCommandsInShell(input) + assertContains("Array[Int] = Array(19, 24, 29, 34, 39)", output) + } + + test("UDF containing lambda expression") { + val input = """ + |class A(x: Int) { def get = x * 20 + 5 } + |val dummyUdf = (x: Int) => new A(x).get + |val myUdf = udf(dummyUdf) + |spark.range(5).select(myUdf(col("id"))).as[Int].collect() + """.stripMargin + val output = runCommandsInShell(input) + assertContains("Array[Int] = Array(5, 25, 45, 65, 85)", output) + } + +} diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/ClassFinderSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/ClassFinderSuite.scala new file mode 100644 index 0000000000000..c9066615bb572 --- /dev/null +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/ClassFinderSuite.scala @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.connect.client + +import java.nio.file.Paths + +import org.apache.commons.io.FileUtils + +import org.apache.spark.sql.connect.client.util.ConnectFunSuite +import org.apache.spark.util.Utils + +class ClassFinderSuite extends ConnectFunSuite { + + private val classResourcePath = commonResourcePath.resolve("artifact-tests") + + test("REPLClassDirMonitor functionality test") { + val copyDir = Utils.createTempDir().toPath + FileUtils.copyDirectory(classResourcePath.toFile, copyDir.toFile) + val monitor = new REPLClassDirMonitor(copyDir.toAbsolutePath.toString) + + def checkClasses(monitor: REPLClassDirMonitor, additionalClasses: Seq[String] = Nil): Unit = { + val expectedClassFiles = (Seq( + "Hello.class", + "smallClassFile.class", + "smallClassFileDup.class") ++ additionalClasses).map(name => Paths.get(name)) + + val foundArtifacts = monitor.findClasses().toSeq + assert(expectedClassFiles.forall { classPath => + foundArtifacts.exists(_.path == Paths.get("classes").resolve(classPath)) + }) + } + + checkClasses(monitor) + + // Add new class file into directory + val subDir = Utils.createTempDir(copyDir.toAbsolutePath.toString) + val classToCopy = copyDir.resolve("Hello.class") + val copyLocation = subDir.toPath.resolve("HelloDup.class") + FileUtils.copyFile(classToCopy.toFile, copyLocation.toFile) + + checkClasses(monitor, Seq(s"${subDir.getName}/HelloDup.class")) + } +} diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/RemoteSparkSession.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/RemoteSparkSession.scala index c27dff51a74d1..2d8cc6d3298e7 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/RemoteSparkSession.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/RemoteSparkSession.scala @@ -129,17 +129,19 @@ object SparkConnectServerUtils { trait RemoteSparkSession extends ConnectFunSuite with BeforeAndAfterAll { import SparkConnectServerUtils._ var spark: SparkSession = _ + protected lazy val serverPort: Int = port override def beforeAll(): Unit = { super.beforeAll() SparkConnectServerUtils.start() - spark = SparkSession.builder().client(SparkConnectClient.builder().port(port).build()).build() + spark = + SparkSession.builder().client(SparkConnectClient.builder().port(serverPort).build()).build() // Retry and wait for the server to start val stop = System.nanoTime() + TimeUnit.MINUTES.toNanos(1) // ~1 min var sleepInternalMs = TimeUnit.SECONDS.toMillis(1) // 1s with * 2 backoff var success = false - val error = new RuntimeException(s"Failed to start the test server on port $port.") + val error = new RuntimeException(s"Failed to start the test server on port $serverPort.") while (!success && System.nanoTime() < stop) { try { diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/SimpleSparkConnectService.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/SimpleSparkConnectService.scala index b1376e5131a72..1b6bdd8cd9393 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/SimpleSparkConnectService.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/SimpleSparkConnectService.scala @@ -39,9 +39,9 @@ private[sql] object SimpleSparkConnectService { def main(args: Array[String]): Unit = { val conf = new SparkConf() + .set("spark.plugins", "org.apache.spark.sql.connect.SparkConnectPlugin") val sparkSession = SparkSession.builder().config(conf).getOrCreate() val sparkContext = sparkSession.sparkContext // init spark context - SparkConnectService.start() // scalastyle:off println println("Ready for client connections.") // scalastyle:on println