[SPARK-43153][CONNECT] Skip Spark execution when the dataframe is local
### What changes were proposed in this pull request?

Skips Spark execution when the dataframe is local.
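
As a rough, self-contained sketch of the check this change relies on (using a regular local `SparkSession` rather than the Connect server internals; the `LocalPlanCheck` object name is illustrative), a DataFrame built from local data is planned as a `LocalTableScanExec`, whose rows are already materialized on the driver:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.LocalTableScanExec

object LocalPlanCheck {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("local-plan-check").getOrCreate()
    import spark.implicits._

    // A DataFrame built from local data plans to LocalTableScanExec: its rows
    // are already on the driver, so no Spark job needs to be scheduled.
    val df = Seq((1, "a"), (2, "b")).toDF("id", "value")

    df.queryExecution.executedPlan match {
      case LocalTableScanExec(_, rows) => println(s"local plan, ${rows.size} rows already available")
      case other                       => println(s"distributed plan: ${other.nodeName}")
    }

    spark.stop()
  }
}
```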

### Why are the changes needed?

When the DataFrame built in Spark Connect is backed by a local relation (its executed plan is a `LocalTableScanExec`), the rows are already available on the driver, so Spark execution can be skipped and the rows converted to Arrow batches directly.
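
For example (a hedged illustration, assuming a Spark 3.5+ Connect Scala client on the classpath and a server at the default `sc://localhost:15002` endpoint), a query over inline `VALUES` resolves to a `LocalRelation` and is planned as a `LocalTableScanExec`, so collecting it does not submit a Spark job after this change:

```scala
import org.apache.spark.sql.SparkSession

object ConnectLocalRelationExample {
  def main(args: Array[String]): Unit = {
    // Spark Connect Scala client session (the server address is an assumption).
    val spark = SparkSession.builder().remote("sc://localhost:15002").getOrCreate()

    // The inline VALUES table becomes a LocalRelation on the server, so the
    // stream handler can convert it to Arrow batches without running a job.
    val df = spark.sql("SELECT * FROM VALUES (1, 'a'), (2, 'b') AS t(id, value)")
    df.collect().foreach(println)
  }
}
```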

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Existing tests.

Closes apache#40806 from ueshin/issues/SPARK-43153/handler.

Authored-by: Takuya UESHIN <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
ueshin authored and zhengruifeng committed Apr 18, 2023
1 parent fc75dab commit 66392c4
Showing 1 changed file with 92 additions and 88 deletions.
@@ -37,7 +37,7 @@ import org.apache.spark.sql.connect.common.LiteralValueProtoConverter.toLiteralP
 import org.apache.spark.sql.connect.config.Connect.CONNECT_GRPC_ARROW_MAX_BATCH_SIZE
 import org.apache.spark.sql.connect.planner.SparkConnectPlanner
 import org.apache.spark.sql.connect.service.SparkConnectStreamHandler.processAsArrowBatches
-import org.apache.spark.sql.execution.{SparkPlan, SQLExecution}
+import org.apache.spark.sql.execution.{LocalTableScanExec, SparkPlan, SQLExecution}
 import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AdaptiveSparkPlanHelper, QueryStageExec}
 import org.apache.spark.sql.execution.arrow.ArrowConverters
 import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructField, StructType, UserDefinedType}
@@ -161,100 +161,104 @@ object SparkConnectStreamHandler {
     // Conservatively sets it 70% because the size is not accurate but estimated.
     val maxBatchSize = (SparkEnv.get.conf.get(CONNECT_GRPC_ARROW_MAX_BATCH_SIZE) * 0.7).toLong
 
-    SQLExecution.withNewExecutionId(dataframe.queryExecution, Some("collectArrow")) {
-      val rows = dataframe.queryExecution.executedPlan.execute()
-      val numPartitions = rows.getNumPartitions
-      var numSent = 0
-
-      if (numPartitions > 0) {
-        type Batch = (Array[Byte], Long)
-
-        val batches = rows.mapPartitionsInternal(
-          SparkConnectStreamHandler
-            .rowToArrowConverter(schema, maxRecordsPerBatch, maxBatchSize, timeZoneId))
-
-        val signal = new Object
-        val partitions = new Array[Array[Batch]](numPartitions)
-        var error: Option[Throwable] = None
-
-        // This callback is executed by the DAGScheduler thread.
-        // After fetching a partition, it inserts the partition into the Map, and then
-        // wakes up the main thread.
-        val resultHandler = (partitionId: Int, partition: Array[Batch]) => {
-          signal.synchronized {
-            partitions(partitionId) = partition
-            signal.notify()
-          }
-          ()
-        }
-
-        val future = spark.sparkContext.submitJob(
-          rdd = batches,
-          processPartition = (iter: Iterator[Batch]) => iter.toArray,
-          partitions = Seq.range(0, numPartitions),
-          resultHandler = resultHandler,
-          resultFunc = () => ())
-
-        // Collect errors and propagate them to the main thread.
-        future.onComplete { result =>
-          result.failed.foreach { throwable =>
-            signal.synchronized {
-              error = Some(throwable)
-              signal.notify()
-            }
-          }
-        }(ThreadUtils.sameThread)
-
-        // The main thread will wait until 0-th partition is available,
-        // then send it to client and wait for the next partition.
-        // Different from the implementation of [[Dataset#collectAsArrowToPython]], it sends
-        // the arrow batches in main thread to avoid DAGScheduler thread been blocked for
-        // tasks not related to scheduling. This is particularly important if there are
-        // multiple users or clients running code at the same time.
-        var currentPartitionId = 0
-        while (currentPartitionId < numPartitions) {
-          val partition = signal.synchronized {
-            var part = partitions(currentPartitionId)
-            while (part == null && error.isEmpty) {
-              signal.wait()
-              part = partitions(currentPartitionId)
-            }
-            partitions(currentPartitionId) = null
-
-            error.foreach { case other =>
-              throw other
-            }
-            part
-          }
-
-          partition.foreach { case (bytes, count) =>
-            val response = proto.ExecutePlanResponse.newBuilder().setSessionId(sessionId)
-            val batch = proto.ExecutePlanResponse.ArrowBatch
-              .newBuilder()
-              .setRowCount(count)
-              .setData(ByteString.copyFrom(bytes))
-              .build()
-            response.setArrowBatch(batch)
-            responseObserver.onNext(response.build())
-            numSent += 1
-          }
-
-          currentPartitionId += 1
-        }
-      }
-
-      // Make sure at least 1 batch will be sent.
-      if (numSent == 0) {
-        val bytes = ArrowConverters.createEmptyArrowBatch(schema, timeZoneId)
-        val response = proto.ExecutePlanResponse.newBuilder().setSessionId(sessionId)
-        val batch = proto.ExecutePlanResponse.ArrowBatch
-          .newBuilder()
-          .setRowCount(0L)
-          .setData(ByteString.copyFrom(bytes))
-          .build()
-        response.setArrowBatch(batch)
-        responseObserver.onNext(response.build())
-      }
-    }
+    var numSent = 0
+    def sendBatch(bytes: Array[Byte], count: Long): Unit = {
+      val response = proto.ExecutePlanResponse.newBuilder().setSessionId(sessionId)
+      val batch = proto.ExecutePlanResponse.ArrowBatch
+        .newBuilder()
+        .setRowCount(count)
+        .setData(ByteString.copyFrom(bytes))
+        .build()
+      response.setArrowBatch(batch)
+      responseObserver.onNext(response.build())
+      numSent += 1
+    }
+
+    val rowToArrowConverter = SparkConnectStreamHandler
+      .rowToArrowConverter(schema, maxRecordsPerBatch, maxBatchSize, timeZoneId)
+
+    dataframe.queryExecution.executedPlan match {
+      case LocalTableScanExec(_, rows) =>
+        rowToArrowConverter(rows.iterator).foreach { case (bytes, count) =>
+          sendBatch(bytes, count)
+        }
+      case _ =>
+        SQLExecution.withNewExecutionId(dataframe.queryExecution, Some("collectArrow")) {
+          val rows = dataframe.queryExecution.executedPlan.execute()
+          val numPartitions = rows.getNumPartitions
+
+          if (numPartitions > 0) {
+            type Batch = (Array[Byte], Long)
+
+            val batches = rows.mapPartitionsInternal(rowToArrowConverter)
+
+            val signal = new Object
+            val partitions = new Array[Array[Batch]](numPartitions)
+            var error: Option[Throwable] = None
+
+            // This callback is executed by the DAGScheduler thread.
+            // After fetching a partition, it inserts the partition into the Map, and then
+            // wakes up the main thread.
+            val resultHandler = (partitionId: Int, partition: Array[Batch]) => {
+              signal.synchronized {
+                partitions(partitionId) = partition
+                signal.notify()
+              }
+              ()
+            }
+
+            val future = spark.sparkContext.submitJob(
+              rdd = batches,
+              processPartition = (iter: Iterator[Batch]) => iter.toArray,
+              partitions = Seq.range(0, numPartitions),
+              resultHandler = resultHandler,
+              resultFunc = () => ())
+
+            // Collect errors and propagate them to the main thread.
+            future.onComplete { result =>
+              result.failed.foreach { throwable =>
+                signal.synchronized {
+                  error = Some(throwable)
+                  signal.notify()
+                }
+              }
+            }(ThreadUtils.sameThread)
+
+            // The main thread will wait until 0-th partition is available,
+            // then send it to client and wait for the next partition.
+            // Different from the implementation of [[Dataset#collectAsArrowToPython]], it sends
+            // the arrow batches in main thread to avoid DAGScheduler thread been blocked for
+            // tasks not related to scheduling. This is particularly important if there are
+            // multiple users or clients running code at the same time.
+            var currentPartitionId = 0
+            while (currentPartitionId < numPartitions) {
+              val partition = signal.synchronized {
+                var part = partitions(currentPartitionId)
+                while (part == null && error.isEmpty) {
+                  signal.wait()
+                  part = partitions(currentPartitionId)
+                }
+                partitions(currentPartitionId) = null
+
+                error.foreach { other =>
+                  throw other
+                }
+                part
+              }
+
+              partition.foreach { case (bytes, count) =>
+                sendBatch(bytes, count)
+              }
+
+              currentPartitionId += 1
+            }
+          }
+        }
+    }
+
+    // Make sure at least 1 batch will be sent.
+    if (numSent == 0) {
+      sendBatch(ArrowConverters.createEmptyArrowBatch(schema, timeZoneId), 0L)
+    }
   }
 
