Commit daa83fc

xi-db authored and hvanhovell committed
[SPARK-53525][CONNECT][FOLLOWUP] Spark Connect ArrowBatch Result Chunking - Scala Client
### What changes were proposed in this pull request?

In the previous PR #52271 of Spark Connect ArrowBatch Result Chunking, both the server-side and PySpark client changes were implemented. In this PR, the corresponding Scala client changes are implemented, so large Arrow rows are now supported on the Scala client as well.

To reproduce the existing issue we are solving here, run this code on the Spark Connect Scala client:

```
val res = spark.sql("select repeat('a', 1024*1024*300)").collect()
println(res(0).getString(0).length)
```

It fails with a `RESOURCE_EXHAUSTED` error with the message `gRPC message exceeds maximum size 134217728: 314573320`, because the server is trying to send an ExecutePlanResponse of ~300MB to the client. With the improvement introduced by this PR, the above code runs successfully and prints the expected result.

### Why are the changes needed?

It improves Spark Connect stability when returning large rows.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

New tests.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #52496 from xi-db/arrow-batch-chuking-scala-client.

Authored-by: Xi Lyu <[email protected]>
Signed-off-by: Herman van Hovell <[email protected]>
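For context, the new client-side configuration introduced here can be exercised roughly as follows. This is a minimal sketch based on the builder methods added in `SparkConnectClient.scala` below; the user id, port, and 4 MiB chunk size are illustrative assumptions rather than values taken from the patch.

```scala
import org.apache.spark.sql.connect.client.SparkConnectClient

// Minimal sketch: chunking is enabled by default (allowArrowBatchChunking = true);
// preferredArrowChunkSize lets the client suggest a chunk size (in bytes) to the server.
val client = SparkConnectClient
  .builder()
  .userId("demo")                                 // arbitrary user id for illustration
  .port(15002)                                    // assumed local Spark Connect port
  .allowArrowBatchChunking(true)
  .preferredArrowChunkSize(Some(4 * 1024 * 1024)) // hypothetical 4 MiB preference
  .build()
// The resulting client can then be passed to the Connect SparkSession builder via
// .client(client), as the test utilities in RemoteSparkSession.scala below do.
```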
1 parent 6b88149 commit daa83fc

File tree

4 files changed: +311, -57 lines changed


sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/ClientE2ETestSuite.scala

Lines changed: 158 additions & 1 deletion
@@ -26,12 +26,14 @@ import scala.concurrent.{ExecutionContext, Future}
 import scala.concurrent.duration.{DurationInt, FiniteDuration}
 import scala.jdk.CollectionConverters._
 
+import io.grpc.{CallOptions, Channel, ClientCall, ClientInterceptor, ForwardingClientCall, ForwardingClientCallListener, MethodDescriptor}
 import org.apache.commons.io.output.TeeOutputStream
 import org.scalactic.TolerantNumerics
 import org.scalatest.PrivateMethodTester
 
 import org.apache.spark.{SparkArithmeticException, SparkException, SparkUpgradeException}
 import org.apache.spark.SparkBuildInfo.{spark_version => SPARK_VERSION}
+import org.apache.spark.connect.proto
 import org.apache.spark.internal.config.ConfigBuilder
 import org.apache.spark.sql.{functions, AnalysisException, Observation, Row, SaveMode}
 import org.apache.spark.sql.catalyst.analysis.{NamespaceAlreadyExistsException, NoSuchNamespaceException, TableAlreadyExistsException, TempTableAlreadyExistsException}
@@ -41,7 +43,7 @@ import org.apache.spark.sql.catalyst.parser.ParseException
 import org.apache.spark.sql.connect.ConnectConversions._
 import org.apache.spark.sql.connect.client.{RetryPolicy, SparkConnectClient, SparkResult}
 import org.apache.spark.sql.connect.test.{ConnectFunSuite, IntegrationTestUtils, QueryTest, RemoteSparkSession, SQLHelper}
-import org.apache.spark.sql.connect.test.SparkConnectServerUtils.port
+import org.apache.spark.sql.connect.test.SparkConnectServerUtils.{createSparkSession, port}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SqlApiConf
 import org.apache.spark.sql.types._
@@ -1848,6 +1850,161 @@ class ClientE2ETestSuite
       checkAnswer(df, Seq.empty)
     }
   }
+
+  // Helper class to capture Arrow batch chunk information from gRPC responses
+  private class ArrowBatchInterceptor extends ClientInterceptor {
+    case class BatchInfo(
+        batchIndex: Int,
+        rowCount: Long,
+        startOffset: Long,
+        chunks: Seq[ChunkInfo]) {
+      def totalChunks: Int = chunks.length
+    }
+
+    case class ChunkInfo(
+        batchIndex: Int,
+        chunkIndex: Int,
+        numChunksInBatch: Int,
+        rowCount: Long,
+        startOffset: Long,
+        dataSize: Int)
+
+    private val batches: mutable.Buffer[BatchInfo] = mutable.Buffer.empty
+    private var currentBatchIndex: Int = 0
+    private val currentBatchChunks: mutable.Buffer[ChunkInfo] = mutable.Buffer.empty
+
+    override def interceptCall[ReqT, RespT](
+        method: MethodDescriptor[ReqT, RespT],
+        callOptions: CallOptions,
+        next: Channel): ClientCall[ReqT, RespT] = {
+      new ForwardingClientCall.SimpleForwardingClientCall[ReqT, RespT](
+        next.newCall(method, callOptions)) {
+        override def start(
+            responseListener: ClientCall.Listener[RespT],
+            headers: io.grpc.Metadata): Unit = {
+          super.start(
+            new ForwardingClientCallListener.SimpleForwardingClientCallListener[RespT](
+              responseListener) {
+              override def onMessage(message: RespT): Unit = {
+                message match {
+                  case response: proto.ExecutePlanResponse if response.hasArrowBatch =>
+                    val arrowBatch = response.getArrowBatch
+                    // Track chunk information for every chunk
+                    currentBatchChunks += ChunkInfo(
+                      batchIndex = currentBatchIndex,
+                      chunkIndex = arrowBatch.getChunkIndex.toInt,
+                      numChunksInBatch = arrowBatch.getNumChunksInBatch.toInt,
+                      rowCount = arrowBatch.getRowCount,
+                      startOffset = arrowBatch.getStartOffset,
+                      dataSize = arrowBatch.getData.size())
+                    // When we receive the last chunk, create the BatchInfo
+                    if (currentBatchChunks.length == arrowBatch.getNumChunksInBatch) {
+                      batches += BatchInfo(
+                        batchIndex = currentBatchIndex,
+                        rowCount = arrowBatch.getRowCount,
+                        startOffset = arrowBatch.getStartOffset,
+                        chunks = currentBatchChunks.toList)
+                      currentBatchChunks.clear()
+                      currentBatchIndex += 1
+                    }
+                  case _ => // Not an ExecutePlanResponse with ArrowBatch, ignore
+                }
+                super.onMessage(message)
+              }
+            },
+            headers)
+        }
+      }
+    }
+
+    // Get all batch information
+    def getBatchInfos: Seq[BatchInfo] = batches.toSeq
+
+    def clear(): Unit = {
+      currentBatchIndex = 0
+      currentBatchChunks.clear()
+      batches.clear()
+    }
+  }
+
+  test("Arrow batch result chunking") {
+    // This test validates that the client can correctly reassemble chunked Arrow batches
+    // using SequenceInputStream as implemented in SparkResult.processResponses
+
+    // Two cases are tested here:
+    // (a) client preferred chunk size is set: the server should respect it
+    // (b) client preferred chunk size is not set: the server should use its own max chunk size
+    Seq((Some(1024), None), (None, Some(1024))).foreach {
+      case (preferredChunkSizeOpt, maxChunkSizeOpt) =>
+        // Create interceptor to capture chunk information
+        val arrowBatchInterceptor = new ArrowBatchInterceptor()

+        try {
+          // Set preferred chunk size if specified and add interceptor
+          preferredChunkSizeOpt match {
+            case Some(size) =>
+              spark = createSparkSession(
+                _.preferredArrowChunkSize(Some(size)).interceptor(arrowBatchInterceptor))
+            case None =>
+              spark = createSparkSession(_.interceptor(arrowBatchInterceptor))
+          }
+          // Set server max chunk size if specified
+          maxChunkSizeOpt.foreach { size =>
+            spark.conf.set("spark.connect.session.resultChunking.maxChunkSize", size.toString)
+          }
+
+          val sqlQuery =
+            "select id, CAST(id + 0.5 AS DOUBLE) as double_val from range(0, 2000, 1, 4)"
+
+          // Execute the query using withResult to access SparkResult object
+          spark.sql(sqlQuery).withResult { result =>
+            // Verify the results are correct and complete
+            assert(result.length == 2000)
+
+            // Get batch information from interceptor
+            val batchInfos = arrowBatchInterceptor.getBatchInfos
+
+            // Assert there are 4 batches (partitions) in total
+            assert(batchInfos.length == 4)
+
+            // Validate chunk information for each batch
+            val maxChunkSize = preferredChunkSizeOpt.orElse(maxChunkSizeOpt).get
+            batchInfos.foreach { batch =>
+              // In this example, the max chunk size is set to a small value,
+              // so each Arrow batch should be split into multiple chunks
+              assert(batch.totalChunks > 5)
+              assert(batch.chunks.nonEmpty)
+              assert(batch.chunks.length == batch.totalChunks)
+              batch.chunks.zipWithIndex.foreach { case (chunk, expectedIndex) =>
+                assert(chunk.chunkIndex == expectedIndex)
+                assert(chunk.numChunksInBatch == batch.totalChunks)
+                assert(chunk.rowCount == batch.rowCount)
+                assert(chunk.startOffset == batch.startOffset)
+                assert(chunk.dataSize > 0)
+                assert(chunk.dataSize <= maxChunkSize)
+              }
+            }
+
+            // Validate data integrity across the range to ensure chunking didn't corrupt anything
+            val rows = result.toArray
+            var expectedId = 0L
+            rows.foreach { row =>
+              assert(row.getLong(0) == expectedId)
+              val expectedDouble = expectedId + 0.5
+              val actualDouble = row.getDouble(1)
+              assert(math.abs(actualDouble - expectedDouble) < 0.001)
+              expectedId += 1
+            }
+          }
+        } finally {
+          // Clean up configurations
+          maxChunkSizeOpt.foreach { _ =>
+            spark.conf.unset("spark.connect.session.resultChunking.maxChunkSize")
+          }
+          arrowBatchInterceptor.clear()
+        }
+    }
+  }
 }
 
 private[sql] case class ClassData(a: String, b: Int)
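The new test's opening comment refers to chunk reassembly via `SequenceInputStream` in `SparkResult.processResponses`. Conceptually, reassembly amounts to concatenating one batch's chunk payloads in `chunkIndex` order before handing them to the Arrow reader; the following standalone sketch illustrates that idea and is not the actual `SparkResult` code.

```scala
import java.io.{ByteArrayInputStream, InputStream, SequenceInputStream}

import scala.jdk.CollectionConverters._

// Standalone sketch: given the raw payloads of one Arrow batch's chunks, already ordered
// by chunkIndex, expose them as a single InputStream that an Arrow reader could consume.
def concatenateChunks(chunksInOrder: Seq[Array[Byte]]): InputStream =
  new SequenceInputStream(
    chunksInOrder.iterator
      .map(bytes => new ByteArrayInputStream(bytes): InputStream)
      .asJavaEnumeration)
```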

sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/test/RemoteSparkSession.scala

Lines changed: 17 additions & 9 deletions
@@ -187,19 +187,27 @@ object SparkConnectServerUtils {
   }
 
   def createSparkSession(): SparkSession = {
+    createSparkSession(identity)
+  }
+
+  def createSparkSession(
+      customBuilderFunc: SparkConnectClient.Builder => SparkConnectClient.Builder)
+      : SparkSession = {
     SparkConnectServerUtils.start()
 
+    var builder = SparkConnectClient
+      .builder()
+      .userId("test")
+      .port(port)
+      .retryPolicy(
+        RetryPolicy
+          .defaultPolicy()
+          .copy(maxRetries = Some(10), maxBackoff = Some(FiniteDuration(30, "s"))))
+
+    builder = customBuilderFunc(builder)
     val spark = SparkSession
       .builder()
-      .client(
-        SparkConnectClient
-          .builder()
-          .userId("test")
-          .port(port)
-          .retryPolicy(RetryPolicy
-            .defaultPolicy()
-            .copy(maxRetries = Some(10), maxBackoff = Some(FiniteDuration(30, "s"))))
-          .build())
+      .client(builder.build())
       .create()
 
     // Execute an RPC which will get retried until the server is up.
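With this overload in place, existing callers keep the zero-argument behaviour while tests can adjust the client before the session is created. A short usage sketch (the 1 KiB chunk size mirrors the value used by the chunking test above):

```scala
// Default behaviour is unchanged: the zero-argument overload forwards `identity`.
val defaultSession = createSparkSession()

// Tests that need a tuned client pass a Builder => Builder function, for example:
val tunedSession = createSparkSession(_.preferredArrowChunkSize(Some(1024)))
```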

sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala

Lines changed: 44 additions & 1 deletion
@@ -138,6 +138,22 @@ private[sql] class SparkConnectClient(
       .setSessionId(sessionId)
       .setClientType(userAgent)
       .addAllTags(tags.get.toSeq.asJava)
+
+    // Add request option to allow result chunking.
+    if (configuration.allowArrowBatchChunking) {
+      val chunkingOptionsBuilder = proto.ResultChunkingOptions
+        .newBuilder()
+        .setAllowArrowBatchChunking(true)
+      configuration.preferredArrowChunkSize.foreach { size =>
+        chunkingOptionsBuilder.setPreferredArrowChunkSize(size)
+      }
+      request.addRequestOptions(
+        proto.ExecutePlanRequest.RequestOption
+          .newBuilder()
+          .setResultChunkingOptions(chunkingOptionsBuilder.build())
+          .build())
+    }
+
     serverSideSessionId.foreach(session => request.setClientObservedServerSideSessionId(session))
     operationId.foreach { opId =>
       require(
@@ -332,6 +348,16 @@
 
   def copy(): SparkConnectClient = configuration.toSparkConnectClient
 
+  /**
+   * Returns whether arrow batch chunking is allowed.
+   */
+  def allowArrowBatchChunking: Boolean = configuration.allowArrowBatchChunking
+
+  /**
+   * Returns the preferred arrow chunk size in bytes.
+   */
+  def preferredArrowChunkSize: Option[Int] = configuration.preferredArrowChunkSize
+
   /**
    * Add a single artifact to the client session.
    *
@@ -757,6 +783,21 @@ object SparkConnectClient {
       this
     }
 
+    def allowArrowBatchChunking(allow: Boolean): Builder = {
+      _configuration = _configuration.copy(allowArrowBatchChunking = allow)
+      this
+    }
+
+    def allowArrowBatchChunking: Boolean = _configuration.allowArrowBatchChunking
+
+    def preferredArrowChunkSize(size: Option[Int]): Builder = {
+      size.foreach(s => require(s > 0, "preferredArrowChunkSize must be positive"))
+      _configuration = _configuration.copy(preferredArrowChunkSize = size)
+      this
+    }
+
+    def preferredArrowChunkSize: Option[Int] = _configuration.preferredArrowChunkSize
+
     def build(): SparkConnectClient = _configuration.toSparkConnectClient
   }
 
@@ -801,7 +842,9 @@ object SparkConnectClient {
       interceptors: List[ClientInterceptor] = List.empty,
       sessionId: Option[String] = None,
       grpcMaxMessageSize: Int = ConnectCommon.CONNECT_GRPC_MAX_MESSAGE_SIZE,
-      grpcMaxRecursionLimit: Int = ConnectCommon.CONNECT_GRPC_MARSHALLER_RECURSION_LIMIT) {
+      grpcMaxRecursionLimit: Int = ConnectCommon.CONNECT_GRPC_MARSHALLER_RECURSION_LIMIT,
+      allowArrowBatchChunking: Boolean = true,
+      preferredArrowChunkSize: Option[Int] = None) {
 
     private def isLocal = host.equals("localhost")
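Mirroring the client code above, the request option attached to an `ExecutePlanRequest` when chunking is allowed can be built directly from the protos; the 1 MiB preferred size below is illustrative only.

```scala
import org.apache.spark.connect.proto

// Sketch: the ResultChunkingOptions the client sends when chunking is allowed and a
// preferred chunk size is configured.
val chunkingOptions = proto.ResultChunkingOptions
  .newBuilder()
  .setAllowArrowBatchChunking(true)
  .setPreferredArrowChunkSize(1024 * 1024) // hypothetical 1 MiB preference
  .build()

// Wrapped into an ExecutePlanRequest.RequestOption, which the client adds to the
// request via addRequestOptions(...).
val requestOption = proto.ExecutePlanRequest.RequestOption
  .newBuilder()
  .setResultChunkingOptions(chunkingOptions)
  .build()
```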
