[SPARK-42945][CONNECT] Support PYSPARK_JVM_STACKTRACE_ENABLED in Spark Connect

### What changes were proposed in this pull request?

This PR adds support for `spark.sql.pyspark.jvmStacktrace.enabled` in Spark Connect so that the JVM stack trace can optionally be shown in error messages.

It also adds a new Spark Connect config, `spark.connect.jvmStacktrace.maxSize` (default: 4096), to cap the stack trace size. This prevents the HTTP header size from exceeding the maximum allowed size.
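
For context, a minimal sketch (not part of this PR) of the user-facing behavior this enables. It assumes `spark` is a Spark Connect session; the failing query mirrors the new unit test below.

```python
from pyspark.errors import AnalysisException

# Ask the server to attach the JVM stack trace to error metadata.
spark.conf.set("spark.sql.pyspark.jvmStacktrace.enabled", "true")

try:
    # `x` is an unresolved column, so analysis fails on the server.
    spark.sql("select x").collect()
except AnalysisException as e:
    # With the config enabled, the client appends the server-side stack trace
    # to the exception message under a "JVM stacktrace:" section.
    print("JVM stacktrace" in str(e))  # expected: True
```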

### Why are the changes needed?

To support an existing config that works with legacy PySpark.

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

New unit test.

Closes apache#40575 from allisonwang-db/spark-42945-stack-trace.

Authored-by: allisonwang-db <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
allisonwang-db authored and HyukjinKwon committed Apr 19, 2023
1 parent c291564 commit b9400c7
Showing 5 changed files with 148 additions and 39 deletions.
@@ -90,4 +90,14 @@ object Connect {
.stringConf
.toSequence
.createWithDefault(Nil)

val CONNECT_JVM_STACK_TRACE_MAX_SIZE =
ConfigBuilder("spark.connect.jvmStacktrace.maxSize")
.doc("""
|Sets the maximum stack trace size to display when
|`spark.sql.pyspark.jvmStacktrace.enabled` is true.
|""".stripMargin)
.version("3.5.0")
.intConf
.createWithDefault(4096)
}
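
The new entry above is read from `SparkEnv.get.conf` on the driver, so it is effectively a startup config for the Connect server. A hedged sketch (not part of this commit) of supplying it when launching a local Spark Connect session from PySpark, mirroring the unit test added below; the value 2048 is an arbitrary example.

```python
from pyspark.sql import SparkSession

spark = (
    SparkSession.builder
    # Truncate any JVM stack trace attached to error metadata at 2048
    # characters (the config defaults to 4096).
    .config("spark.connect.jvmStacktrace.maxSize", "2048")
    # Start an in-process Spark Connect server backed by local[4].
    .remote("local[4]")
    .getOrCreate()
)

# The stack trace itself is gated by a separate runtime SQL conf.
spark.conf.set("spark.sql.pyspark.jvmStacktrace.enabled", "true")
```
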
@@ -33,6 +33,7 @@ import io.grpc.protobuf.StatusProto
import io.grpc.protobuf.services.ProtoReflectionService
import io.grpc.stub.StreamObserver
import org.apache.commons.lang3.StringUtils
import org.apache.commons.lang3.exception.ExceptionUtils
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods.{compact, render}

@@ -42,7 +43,8 @@ import org.apache.spark.connect.proto
import org.apache.spark.connect.proto.{AddArtifactsRequest, AddArtifactsResponse}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.connect.config.Connect.{CONNECT_GRPC_BINDING_PORT, CONNECT_GRPC_MAX_INBOUND_MESSAGE_SIZE}
import org.apache.spark.sql.connect.config.Connect.{CONNECT_GRPC_BINDING_PORT, CONNECT_GRPC_MAX_INBOUND_MESSAGE_SIZE, CONNECT_JVM_STACK_TRACE_MAX_SIZE}
import org.apache.spark.sql.internal.SQLConf.PYSPARK_JVM_STACKTRACE_ENABLED

/**
* The SparkConnectService implementation.
@@ -73,19 +75,26 @@ class SparkConnectService(debug: Boolean)
classes.toSeq
}

private def buildStatusFromThrowable(st: Throwable): RPCStatus = {
private def buildStatusFromThrowable(st: Throwable, stackTraceEnabled: Boolean): RPCStatus = {
val message = StringUtils.abbreviate(st.getMessage, 2048)
val errorInfo = ErrorInfo
.newBuilder()
.setReason(st.getClass.getName)
.setDomain("org.apache.spark")
.putMetadata("classes", compact(render(allClasses(st.getClass).map(_.getName))))

lazy val stackTrace = Option(ExceptionUtils.getStackTrace(st))
val withStackTrace = if (stackTraceEnabled && stackTrace.nonEmpty) {
val maxSize = SparkEnv.get.conf.get(CONNECT_JVM_STACK_TRACE_MAX_SIZE)
errorInfo.putMetadata("stackTrace", StringUtils.abbreviate(stackTrace.get, maxSize))
} else {
errorInfo
}

RPCStatus
.newBuilder()
.setCode(RPCCode.INTERNAL_VALUE)
.addDetails(
ProtoAny.pack(
ErrorInfo
.newBuilder()
.setReason(st.getClass.getName)
.setDomain("org.apache.spark")
.putMetadata("classes", compact(render(allClasses(st.getClass).map(_.getName))))
.build()))
.addDetails(ProtoAny.pack(withStackTrace.build()))
.setMessage(if (message != null) message else "")
.build()
}
@@ -110,23 +119,42 @@
*/
private def handleError[V](
opType: String,
observer: StreamObserver[V]): PartialFunction[Throwable, Unit] = {
case se: SparkException if isPythonExecutionException(se) =>
logError(s"Error during: $opType", se)
observer.onError(
StatusProto.toStatusRuntimeException(buildStatusFromThrowable(se.getCause)))

case e: Throwable if e.isInstanceOf[SparkThrowable] || NonFatal.apply(e) =>
logError(s"Error during: $opType", e)
observer.onError(StatusProto.toStatusRuntimeException(buildStatusFromThrowable(e)))

case e: Throwable =>
logError(s"Error during: $opType", e)
observer.onError(
Status.UNKNOWN
.withCause(e)
.withDescription(StringUtils.abbreviate(e.getMessage, 2048))
.asRuntimeException())
observer: StreamObserver[V],
userId: String,
sessionId: String): PartialFunction[Throwable, Unit] = {
val session =
SparkConnectService
.getOrCreateIsolatedSession(userId, sessionId)
.session
val stackTraceEnabled =
try {
session.conf.get(PYSPARK_JVM_STACKTRACE_ENABLED.key).toBoolean
} catch {
case NonFatal(e) =>
logWarning(s"Failed to get Spark conf `PYSPARK_JVM_STACKTRACE_ENABLED`: $e")
true
}

{
case se: SparkException if isPythonExecutionException(se) =>
logError(s"Error during: $opType", se)
observer.onError(
StatusProto.toStatusRuntimeException(
buildStatusFromThrowable(se.getCause, stackTraceEnabled)))

case e: Throwable if e.isInstanceOf[SparkThrowable] || NonFatal.apply(e) =>
logError(s"Error during: $opType", e)
observer.onError(
StatusProto.toStatusRuntimeException(buildStatusFromThrowable(e, stackTraceEnabled)))

case e: Throwable =>
logError(s"Error during: $opType", e)
observer.onError(
Status.UNKNOWN
.withCause(e)
.withDescription(StringUtils.abbreviate(e.getMessage, 2048))
.asRuntimeException())
}
}

/**
@@ -144,7 +172,13 @@
responseObserver: StreamObserver[proto.ExecutePlanResponse]): Unit = {
try {
new SparkConnectStreamHandler(responseObserver).handle(request)
} catch handleError("execute", observer = responseObserver)
} catch {
handleError(
"execute",
observer = responseObserver,
userId = request.getUserContext.getUserId,
sessionId = request.getSessionId)
}
}

/**
@@ -164,7 +198,13 @@
responseObserver: StreamObserver[proto.AnalyzePlanResponse]): Unit = {
try {
new SparkConnectAnalyzeHandler(responseObserver).handle(request)
} catch handleError("analyze", observer = responseObserver)
} catch {
handleError(
"analyze",
observer = responseObserver,
userId = request.getUserContext.getUserId,
sessionId = request.getSessionId)
}
}

/**
@@ -179,7 +219,13 @@
responseObserver: StreamObserver[proto.ConfigResponse]): Unit = {
try {
new SparkConnectConfigHandler(responseObserver).handle(request)
} catch handleError("config", observer = responseObserver)
} catch {
handleError(
"config",
observer = responseObserver,
userId = request.getUserContext.getUserId,
sessionId = request.getSessionId)
}
}

/**
4 changes: 4 additions & 0 deletions python/pyspark/errors/exceptions/connect.py
@@ -49,6 +49,10 @@ def convert_exception(info: "ErrorInfo", message: str) -> SparkConnectException:
if "classes" in info.metadata:
classes = json.loads(info.metadata["classes"])

if "stackTrace" in info.metadata:
stackTrace = info.metadata["stackTrace"]
message += f"\n\nJVM stacktrace:\n{stackTrace}"

if "org.apache.spark.sql.catalyst.parser.ParseException" in classes:
return ParseException(message)
# Order matters. ParseException inherits AnalysisException.
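
To make the client-side handling above concrete, here is a self-contained illustration of the same logic with a plain dict standing in for the protobuf `ErrorInfo` metadata; the class name, message, and stack trace text are made up for the example.

```python
metadata = {
    "classes": '["org.apache.spark.sql.AnalysisException"]',
    # Present only when spark.sql.pyspark.jvmStacktrace.enabled is true on the
    # server, truncated to spark.connect.jvmStacktrace.maxSize characters.
    "stackTrace": (
        "org.apache.spark.sql.AnalysisException: ...\n"
        "\tat org.apache.spark.sql.catalyst.analysis.CheckAnalysis..."
    ),
}
message = "[UNRESOLVED_COLUMN] A column or function parameter with name `x` cannot be resolved."

if "stackTrace" in metadata:
    stackTrace = metadata["stackTrace"]
    message += f"\n\nJVM stacktrace:\n{stackTrace}"

print(message)
```
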
61 changes: 53 additions & 8 deletions python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -3124,27 +3124,38 @@ def test_unsupported_jvm_attribute(self):
)


class SparkConnectSessionTests(SparkConnectSQLTestCase):
class SparkConnectSessionTests(ReusedConnectTestCase):
def setUp(self) -> None:
self.spark = (
PySparkSession.builder.config(conf=self.conf())
.appName(self.__class__.__name__)
.remote("local[4]")
.getOrCreate()
)

def tearDown(self):
self.spark.stop()

def _check_no_active_session_error(self, e: PySparkException):
self.check_error(exception=e, error_class="NO_ACTIVE_SESSION", message_parameters=dict())

def test_stop_session(self):
df = self.connect.sql("select 1 as a, 2 as b")
catalog = self.connect.catalog
self.connect.stop()
df = self.spark.sql("select 1 as a, 2 as b")
catalog = self.spark.catalog
self.spark.stop()

# _execute_and_fetch
with self.assertRaises(SparkConnectException) as e:
self.connect.sql("select 1")
self.spark.sql("select 1")
self._check_no_active_session_error(e.exception)

with self.assertRaises(SparkConnectException) as e:
catalog.tableExists(self.tbl_name)
catalog.tableExists("table")
self._check_no_active_session_error(e.exception)

# _execute
with self.assertRaises(SparkConnectException) as e:
self.connect.udf.register("test_func", lambda x: x + 1)
self.spark.udf.register("test_func", lambda x: x + 1)
self._check_no_active_session_error(e.exception)

# _analyze
@@ -3154,9 +3165,43 @@ def test_stop_session(self):

# Config
with self.assertRaises(SparkConnectException) as e:
self.connect.conf.get("some.conf")
self.spark.conf.get("some.conf")
self._check_no_active_session_error(e.exception)

def test_error_stack_trace(self):
with self.sql_conf({"spark.sql.pyspark.jvmStacktrace.enabled": True}):
with self.assertRaises(AnalysisException) as e:
self.spark.sql("select x").collect()
self.assertTrue("JVM stacktrace" in e.exception.message)
self.assertTrue(
"at org.apache.spark.sql.catalyst.analysis.CheckAnalysis" in e.exception.message
)

with self.sql_conf({"spark.sql.pyspark.jvmStacktrace.enabled": False}):
with self.assertRaises(AnalysisException) as e:
self.spark.sql("select x").collect()
self.assertFalse("JVM stacktrace" in e.exception.message)
self.assertFalse(
"at org.apache.spark.sql.catalyst.analysis.CheckAnalysis" in e.exception.message
)

# Create a new session with a different stack trace size.
self.spark.stop()
spark = (
PySparkSession.builder.config(conf=self.conf())
.config("spark.connect.jvmStacktrace.maxSize", 128)
.remote("local[4]")
.getOrCreate()
)
spark.conf.set("spark.sql.pyspark.jvmStacktrace.enabled", "true")
with self.assertRaises(AnalysisException) as e:
spark.sql("select x").collect()
self.assertTrue("JVM stacktrace" in e.exception.message)
self.assertFalse(
"at org.apache.spark.sql.catalyst.analysis.CheckAnalysis" in e.exception.message
)
spark.stop()


@unittest.skipIf(not should_test_connect, connect_requirement_message)
class ClientTests(unittest.TestCase):
6 changes: 5 additions & 1 deletion python/pyspark/testing/connectutils.py
@@ -146,7 +146,11 @@ def conf(cls):
"""
Override this in subclasses to supply a more specific conf
"""
return SparkConf(loadDefaults=False)
conf = SparkConf(loadDefaults=False)
# Disable JVM stack trace in Spark Connect tests to prevent the
# HTTP header size from exceeding the maximum allowed size.
conf.set("spark.sql.pyspark.jvmStacktrace.enabled", "false")
return conf

@classmethod
def setUpClass(cls):
