[SPARK-42657][CONNECT] Support to find and transfer client-side REPL classfiles to server as artifacts

### What changes were proposed in this pull request?

This PR introduces the concept of a `ClassFinder` that scrapes the REPL output (either file-based or in-memory) for generated class files. The `ClassFinder` is registered during REPL initialization and is used to upload the generated class files as artifacts to the Spark Connect server.
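
As a rough usage sketch (based on the `REPLClassDirMonitor` and `registerClassFinder` APIs added by this PR; the output directory below is only an illustrative placeholder, and must be an existing absolute path):
```
import org.apache.spark.sql.connect.client.{ClassFinder, REPLClassDirMonitor}

// Watch a directory where a REPL writes its generated class files
// (hypothetical path; REPLClassDirMonitor requires an absolute, existing directory).
val finder: ClassFinder = new REPLClassDirMonitor("/tmp/repl-classes")

// Register the finder on the session. Before each analyze/execute RPC, the
// client traverses all registered finders and uploads the class files they
// return as artifacts.
spark.registerClassFinder(finder)
```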

### Why are the changes needed?

To run UDFs that are defined in the client-side REPL, we need a mechanism that can find the local REPL classfiles and then use the mechanism from https://issues.apache.org/jira/browse/SPARK-42653 to transfer them to the server as artifacts.

### Does this PR introduce _any_ user-facing change?

Yes, users can now run UDFs in the default (Ammonite) REPL with Spark Connect.

Input (in REPL):
```
class A(x: Int) { def get = x * 5 + 19 }
def dummyUdf(x: Int): Int = new A(x).get
val myUdf = udf(dummyUdf _)
spark.range(5).select(myUdf(col("id"))).as[Int].collect()
```

Output:
```
Array[Int] = Array(19, 24, 29, 34, 39)
```
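
Under the hood, the Ammonite REPL shipped with Spark Connect does this wiring automatically; a condensed excerpt of the predef code added in `ConnectRepl` (see the diff below) is:
```
// Registered from the REPL predef: exposes Ammonite's in-memory class files
// to the client so UDF classes are uploaded before plans are executed.
import org.apache.spark.sql.connect.client.AmmoniteClassFinder

spark.registerClassFinder(new AmmoniteClassFinder(repl.sess))
```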

### How was this patch tested?

Unit tests + E2E tests.

Closes apache#40675 from vicennial/SPARK-42657.

Lead-authored-by: vicennial <[email protected]>
Co-authored-by: Venkata Sai Akhil Gudesa <[email protected]>
Signed-off-by: Herman van Hovell <[email protected]>
2 people authored and hvanhovell committed Apr 17, 2023
1 parent 7a5b6c8 commit 3941369
Showing 11 changed files with 349 additions and 27 deletions.
3 changes: 3 additions & 0 deletions .github/workflows/build_and_test.yml
@@ -249,7 +249,10 @@ jobs:
# Run the tests.
- name: Run tests
env: ${{ fromJSON(inputs.envs) }}
shell: 'script -q -e -c "bash {0}"'
run: |
# Fix for TTY related issues when launching the Ammonite REPL in tests.
export TERM=vt100 && script -qfc 'echo exit | amm -s' && rm typescript
# Hive "other tests" test needs larger metaspace size based on experiment.
if [[ "$MODULES_TO_TEST" == "hive" ]] && [[ "$EXCLUDED_TAGS" == "org.apache.spark.tags.SlowHiveTest" ]]; then export METASPACE_SIZE=2g; fi
export SERIAL_SBT_TESTS=1
@@ -33,7 +33,7 @@ import org.apache.spark.sql.catalog.Catalog
import org.apache.spark.sql.catalyst.{JavaTypeInference, ScalaReflection}
import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, RowEncoder}
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{BoxedLongEncoder, UnboundRowEncoder}
import org.apache.spark.sql.connect.client.{SparkConnectClient, SparkResult}
import org.apache.spark.sql.connect.client.{ClassFinder, SparkConnectClient, SparkResult}
import org.apache.spark.sql.connect.client.util.{Cleaner, ConvertToArrow}
import org.apache.spark.sql.connect.common.LiteralValueProtoConverter.toLiteralProto
import org.apache.spark.sql.internal.CatalogImpl
@@ -495,6 +495,14 @@ class SparkSession private[sql] (
@scala.annotation.varargs
def addArtifacts(uri: URI*): Unit = client.addArtifacts(uri)

/**
* Register a [[ClassFinder]] for dynamically generated classes.
*
* @since 3.4.0
*/
@Experimental
def registerClassFinder(finder: ClassFinder): Unit = client.registerClassFinder(finder)

/**
* This resets the plan id generator so we can produce plans that are comparable.
*
@@ -16,9 +16,11 @@
*/
package org.apache.spark.sql.application

import java.io.{InputStream, OutputStream}
import java.util.concurrent.Semaphore

import scala.util.control.NonFatal

import ammonite.compiler.CodeClassWrapper
import ammonite.util.Bind

import org.apache.spark.annotation.DeveloperApi
@@ -43,7 +45,14 @@ object ConnectRepl {
| /_/
|""".stripMargin

def main(args: Array[String]): Unit = {
def main(args: Array[String]): Unit = doMain(args)

private[application] def doMain(
args: Array[String],
semaphore: Option[Semaphore] = None,
inputStream: InputStream = System.in,
outputStream: OutputStream = System.out,
errorStream: OutputStream = System.err): Unit = {
// Build the client.
val client =
try {
@@ -67,22 +76,30 @@ object ConnectRepl {

// Build the session.
val spark = SparkSession.builder().client(client).build()
val sparkBind = new Bind("spark", spark)

// Add the proper imports.
val imports =
// Add the proper imports and register a [[ClassFinder]].
val predefCode =
"""
|import org.apache.spark.sql.functions._
|import spark.implicits._
|import spark.sql
|import org.apache.spark.sql.connect.client.AmmoniteClassFinder
|
|spark.registerClassFinder(new AmmoniteClassFinder(repl.sess))
|""".stripMargin

// Please note that we make ammonite generate classes instead of objects.
// Classes tend to have superior serialization behavior when using UDFs.
val main = ammonite.Main(
welcomeBanner = Option(splash),
predefCode = imports,
replCodeWrapper = CodeClassWrapper,
scriptCodeWrapper = CodeClassWrapper)
main.run(new Bind("spark", spark))
predefCode = predefCode,
inputStream = inputStream,
outputStream = outputStream,
errorStream = errorStream)
if (semaphore.nonEmpty) {
// Used for testing.
main.run(sparkBind, new Bind[Semaphore]("semaphore", semaphore.get))
} else {
main.run(sparkBind)
}
}
}
@@ -16,11 +16,13 @@
*/
package org.apache.spark.sql.connect.client

import java.io.InputStream
import java.io.{ByteArrayInputStream, InputStream}
import java.net.URI
import java.nio.file.{Files, Path, Paths}
import java.util.concurrent.CopyOnWriteArrayList
import java.util.zip.{CheckedInputStream, CRC32}

import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.concurrent.Promise
import scala.concurrent.duration.Duration
@@ -48,6 +50,12 @@ class ArtifactManager(userContext: proto.UserContext, channel: ManagedChannel) {
private val CHUNK_SIZE: Int = 32 * 1024

private[this] val stub = proto.SparkConnectServiceGrpc.newStub(channel)
private[this] val classFinders = new CopyOnWriteArrayList[ClassFinder]

/**
* Register a [[ClassFinder]] for dynamically generated classes.
*/
def registerClassFinder(finder: ClassFinder): Unit = classFinders.add(finder)

/**
* Add a single artifact to the session.
@@ -92,10 +100,23 @@ class ArtifactManager(userContext: proto.UserContext, channel: ManagedChannel) {
*/
def addArtifacts(uris: Seq[URI]): Unit = addArtifacts(uris.flatMap(parseArtifacts))

/**
* Upload all class file artifacts from the local REPL(s) to the server.
*
* The registered [[ClassFinder]]s are traversed to retrieve the class file artifacts.
*/
private[client] def uploadAllClassFileArtifacts(): Unit = {
addArtifacts(classFinders.asScala.flatMap(_.findClasses()))
}

/**
* Add a number of artifacts to the session.
*/
private def addArtifacts(artifacts: Iterable[Artifact]): Unit = {
if (artifacts.isEmpty) {
return
}

val promise = Promise[Seq[ArtifactSummary]]
val responseHandler = new StreamObserver[proto.AddArtifactsResponse] {
private val summaries = mutable.Buffer.empty[ArtifactSummary]
@@ -302,4 +323,13 @@ object Artifact {
override def size: Long = Files.size(path)
override def stream: InputStream = Files.newInputStream(path)
}

/**
* Payload stored in memory.
*/
class InMemory(bytes: Array[Byte]) extends LocalData {
override def size: Long = bytes.length
override def stream: InputStream = new ByteArrayInputStream(bytes)
}

}
@@ -0,0 +1,80 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.connect.client

import java.net.URL
import java.nio.file.{Files, LinkOption, Path, Paths}

import scala.collection.JavaConverters._

import ammonite.repl.api.Session
import ammonite.runtime.SpecialClassLoader

import org.apache.spark.sql.connect.client.Artifact.{InMemory, LocalFile}

trait ClassFinder {
def findClasses(): Iterator[Artifact]
}

/**
* A generic [[ClassFinder]] implementation that traverses a specific REPL output directory.
* @param _rootDir
*/
class REPLClassDirMonitor(_rootDir: String) extends ClassFinder {
private val rootDir = Paths.get(_rootDir)
require(rootDir.isAbsolute)
require(Files.isDirectory(rootDir))

override def findClasses(): Iterator[Artifact] = {
Files
.walk(rootDir)
// Ignore symbolic links
.filter(path => Files.isRegularFile(path, LinkOption.NOFOLLOW_LINKS) && isClass(path))
.map[Artifact](path => toArtifact(path))
.iterator()
.asScala
}

private def toArtifact(path: Path): Artifact = {
// Persist the relative path of the classfile
Artifact.newClassArtifact(rootDir.relativize(path), new LocalFile(path))
}

private def isClass(path: Path): Boolean = path.toString.endsWith(".class")
}

/**
* A special [[ClassFinder]] for the Ammonite REPL to handle in-memory class files.
* @param session
*/
class AmmoniteClassFinder(session: Session) extends ClassFinder {

override def findClasses(): Iterator[Artifact] = {
session.frames.iterator.flatMap { frame =>
val classloader = frame.classloader.asInstanceOf[SpecialClassLoader]
val signatures: Seq[(Either[String, URL], Long)] = classloader.classpathSignature
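      // Left entries are class names generated in memory by the REPL; their bytes
      // are fetched from the classloader instead of being read from disk.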
signatures.iterator.collect { case (Left(name), _) =>
val parts = name.split('.')
parts(parts.length - 1) += ".class"
val path = Paths.get(parts.head, parts.tail: _*)
val bytes = classloader.newFileDict(name)
Artifact.newClassArtifact(path, new InMemory(bytes))
}
}
}
}
@@ -58,10 +58,13 @@ private[sql] class SparkConnectClient(
* @return
* A [[proto.AnalyzePlanResponse]] from the Spark Connect server.
*/
def analyze(request: proto.AnalyzePlanRequest): proto.AnalyzePlanResponse =
def analyze(request: proto.AnalyzePlanRequest): proto.AnalyzePlanResponse = {
artifactManager.uploadAllClassFileArtifacts()
stub.analyzePlan(request)
}

def execute(plan: proto.Plan): java.util.Iterator[proto.ExecutePlanResponse] = {
artifactManager.uploadAllClassFileArtifacts()
val request = proto.ExecutePlanRequest
.newBuilder()
.setPlan(plan)
@@ -201,6 +204,11 @@
*/
def addArtifacts(uri: Seq[URI]): Unit = artifactManager.addArtifacts(uri)

/**
* Register a [[ClassFinder]] for dynamically generated classes.
*/
def registerClassFinder(finder: ClassFinder): Unit = artifactManager.registerClassFinder(finder)

/**
* Shutdown the client's connection to the server.
*/
@@ -82,17 +82,6 @@ class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper {
assert(result(2) == 2)
}

ignore("SPARK-42665: Ignore simple udf test until the udf is fully implemented.") {
def dummyUdf(x: Int): Int = x + 5
val myUdf = udf(dummyUdf _)
val df = spark.range(5).select(myUdf(Column("id")))
val result = df.collect()
assert(result.length == 5)
result.zipWithIndex.foreach { case (v, idx) =>
assert(v.getInt(0) == idx + 5)
}
}

test("read and write") {
val testDataPath = java.nio.file.Paths
.get(