diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala
index 19fdad97b5ffb..f9571445e74d0 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala
@@ -90,4 +90,14 @@ object Connect {
       .stringConf
       .toSequence
       .createWithDefault(Nil)
+
+  val CONNECT_JVM_STACK_TRACE_MAX_SIZE =
+    ConfigBuilder("spark.connect.jvmStacktrace.maxSize")
+      .doc("""
+          |Sets the maximum stack trace size to display when
+          |`spark.sql.pyspark.jvmStacktrace.enabled` is true.
+          |""".stripMargin)
+      .version("3.5.0")
+      .intConf
+      .createWithDefault(4096)
 }
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
index 86c36bba7a08a..f8483f963d89c 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
@@ -33,6 +33,7 @@ import io.grpc.protobuf.StatusProto
 import io.grpc.protobuf.services.ProtoReflectionService
 import io.grpc.stub.StreamObserver
 import org.apache.commons.lang3.StringUtils
+import org.apache.commons.lang3.exception.ExceptionUtils
 import org.json4s.JsonDSL._
 import org.json4s.jackson.JsonMethods.{compact, render}
 
@@ -42,7 +43,8 @@ import org.apache.spark.connect.proto
 import org.apache.spark.connect.proto.{AddArtifactsRequest, AddArtifactsResponse}
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.SparkSession
-import org.apache.spark.sql.connect.config.Connect.{CONNECT_GRPC_BINDING_PORT, CONNECT_GRPC_MAX_INBOUND_MESSAGE_SIZE}
+import org.apache.spark.sql.connect.config.Connect.{CONNECT_GRPC_BINDING_PORT, CONNECT_GRPC_MAX_INBOUND_MESSAGE_SIZE, CONNECT_JVM_STACK_TRACE_MAX_SIZE}
+import org.apache.spark.sql.internal.SQLConf.PYSPARK_JVM_STACKTRACE_ENABLED
 
 /**
  * The SparkConnectService implementation.
@@ -73,19 +75,26 @@ class SparkConnectService(debug: Boolean)
     classes.toSeq
   }
 
-  private def buildStatusFromThrowable(st: Throwable): RPCStatus = {
+  private def buildStatusFromThrowable(st: Throwable, stackTraceEnabled: Boolean): RPCStatus = {
     val message = StringUtils.abbreviate(st.getMessage, 2048)
+    val errorInfo = ErrorInfo
+      .newBuilder()
+      .setReason(st.getClass.getName)
+      .setDomain("org.apache.spark")
+      .putMetadata("classes", compact(render(allClasses(st.getClass).map(_.getName))))
+
+    lazy val stackTrace = Option(ExceptionUtils.getStackTrace(st))
+    val withStackTrace = if (stackTraceEnabled && stackTrace.nonEmpty) {
+      val maxSize = SparkEnv.get.conf.get(CONNECT_JVM_STACK_TRACE_MAX_SIZE)
+      errorInfo.putMetadata("stackTrace", StringUtils.abbreviate(stackTrace.get, maxSize))
+    } else {
+      errorInfo
+    }
+
     RPCStatus
       .newBuilder()
       .setCode(RPCCode.INTERNAL_VALUE)
-      .addDetails(
-        ProtoAny.pack(
-          ErrorInfo
-            .newBuilder()
-            .setReason(st.getClass.getName)
-            .setDomain("org.apache.spark")
-            .putMetadata("classes", compact(render(allClasses(st.getClass).map(_.getName))))
-            .build()))
+      .addDetails(ProtoAny.pack(withStackTrace.build()))
      .setMessage(if (message != null) message else "")
       .build()
   }
@@ -110,23 +119,42 @@ class SparkConnectService(debug: Boolean)
    */
   private def handleError[V](
       opType: String,
-      observer: StreamObserver[V]): PartialFunction[Throwable, Unit] = {
-    case se: SparkException if isPythonExecutionException(se) =>
-      logError(s"Error during: $opType", se)
-      observer.onError(
-        StatusProto.toStatusRuntimeException(buildStatusFromThrowable(se.getCause)))
-
-    case e: Throwable if e.isInstanceOf[SparkThrowable] || NonFatal.apply(e) =>
-      logError(s"Error during: $opType", e)
-      observer.onError(StatusProto.toStatusRuntimeException(buildStatusFromThrowable(e)))
-
-    case e: Throwable =>
-      logError(s"Error during: $opType", e)
-      observer.onError(
-        Status.UNKNOWN
-          .withCause(e)
-          .withDescription(StringUtils.abbreviate(e.getMessage, 2048))
-          .asRuntimeException())
+      observer: StreamObserver[V],
+      userId: String,
+      sessionId: String): PartialFunction[Throwable, Unit] = {
+    val session =
+      SparkConnectService
+        .getOrCreateIsolatedSession(userId, sessionId)
+        .session
+    val stackTraceEnabled =
+      try {
+        session.conf.get(PYSPARK_JVM_STACKTRACE_ENABLED.key).toBoolean
+      } catch {
+        case NonFatal(e) =>
+          logWarning(s"Failed to get Spark conf `PYSPARK_JVM_STACKTRACE_ENABLED`: $e")
+          true
+      }
+
+    {
+      case se: SparkException if isPythonExecutionException(se) =>
+        logError(s"Error during: $opType", se)
+        observer.onError(
+          StatusProto.toStatusRuntimeException(
+            buildStatusFromThrowable(se.getCause, stackTraceEnabled)))
+
+      case e: Throwable if e.isInstanceOf[SparkThrowable] || NonFatal.apply(e) =>
+        logError(s"Error during: $opType", e)
+        observer.onError(
+          StatusProto.toStatusRuntimeException(buildStatusFromThrowable(e, stackTraceEnabled)))
+
+      case e: Throwable =>
+        logError(s"Error during: $opType", e)
+        observer.onError(
+          Status.UNKNOWN
+            .withCause(e)
+            .withDescription(StringUtils.abbreviate(e.getMessage, 2048))
+            .asRuntimeException())
+    }
   }
 
   /**
@@ -144,7 +172,13 @@ class SparkConnectService(debug: Boolean)
       responseObserver: StreamObserver[proto.ExecutePlanResponse]): Unit = {
     try {
       new SparkConnectStreamHandler(responseObserver).handle(request)
-    } catch handleError("execute", observer = responseObserver)
+    } catch {
+      handleError(
+        "execute",
+        observer = responseObserver,
+        userId = request.getUserContext.getUserId,
+        sessionId = request.getSessionId)
+    }
   }
 
   /**
@@ -164,7 +198,13 @@ class SparkConnectService(debug: Boolean)
       responseObserver: StreamObserver[proto.AnalyzePlanResponse]): Unit = {
     try {
       new SparkConnectAnalyzeHandler(responseObserver).handle(request)
-    } catch handleError("analyze", observer = responseObserver)
+    } catch {
+      handleError(
+        "analyze",
+        observer = responseObserver,
+        userId = request.getUserContext.getUserId,
+        sessionId = request.getSessionId)
+    }
   }
 
   /**
@@ -179,7 +219,13 @@ class SparkConnectService(debug: Boolean)
       responseObserver: StreamObserver[proto.ConfigResponse]): Unit = {
     try {
       new SparkConnectConfigHandler(responseObserver).handle(request)
-    } catch handleError("config", observer = responseObserver)
+    } catch {
+      handleError(
+        "config",
+        observer = responseObserver,
+        userId = request.getUserContext.getUserId,
+        sessionId = request.getSessionId)
+    }
   }
 
   /**
diff --git a/python/pyspark/errors/exceptions/connect.py b/python/pyspark/errors/exceptions/connect.py
index 43fee1f0af94f..f8f234ed2eeef 100644
--- a/python/pyspark/errors/exceptions/connect.py
+++ b/python/pyspark/errors/exceptions/connect.py
@@ -49,6 +49,10 @@ def convert_exception(info: "ErrorInfo", message: str) -> SparkConnectException:
     if "classes" in info.metadata:
         classes = json.loads(info.metadata["classes"])
 
+    if "stackTrace" in info.metadata:
+        stackTrace = info.metadata["stackTrace"]
+        message += f"\n\nJVM stacktrace:\n{stackTrace}"
+
     if "org.apache.spark.sql.catalyst.parser.ParseException" in classes:
         return ParseException(message)
     # Order matters. ParseException inherits AnalysisException.
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index 9d12eb2b26e9e..94292074f788c 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -3124,27 +3124,38 @@ def test_unsupported_jvm_attribute(self):
         )
 
 
-class SparkConnectSessionTests(SparkConnectSQLTestCase):
+class SparkConnectSessionTests(ReusedConnectTestCase):
+    def setUp(self) -> None:
+        self.spark = (
+            PySparkSession.builder.config(conf=self.conf())
+            .appName(self.__class__.__name__)
+            .remote("local[4]")
+            .getOrCreate()
+        )
+
+    def tearDown(self):
+        self.spark.stop()
+
     def _check_no_active_session_error(self, e: PySparkException):
         self.check_error(exception=e, error_class="NO_ACTIVE_SESSION", message_parameters=dict())
 
     def test_stop_session(self):
-        df = self.connect.sql("select 1 as a, 2 as b")
-        catalog = self.connect.catalog
-        self.connect.stop()
+        df = self.spark.sql("select 1 as a, 2 as b")
+        catalog = self.spark.catalog
+        self.spark.stop()
 
         # _execute_and_fetch
         with self.assertRaises(SparkConnectException) as e:
-            self.connect.sql("select 1")
+            self.spark.sql("select 1")
         self._check_no_active_session_error(e.exception)
 
         with self.assertRaises(SparkConnectException) as e:
-            catalog.tableExists(self.tbl_name)
+            catalog.tableExists("table")
         self._check_no_active_session_error(e.exception)
 
         # _execute
         with self.assertRaises(SparkConnectException) as e:
-            self.connect.udf.register("test_func", lambda x: x + 1)
+            self.spark.udf.register("test_func", lambda x: x + 1)
         self._check_no_active_session_error(e.exception)
 
         # _analyze
@@ -3154,9 +3165,43 @@ def test_stop_session(self):
 
         # Config
         with self.assertRaises(SparkConnectException) as e:
-            self.connect.conf.get("some.conf")
+            self.spark.conf.get("some.conf")
         self._check_no_active_session_error(e.exception)
 
+    def test_error_stack_trace(self):
+        with self.sql_conf({"spark.sql.pyspark.jvmStacktrace.enabled": True}):
+            with self.assertRaises(AnalysisException) as e:
+                self.spark.sql("select x").collect()
+            self.assertTrue("JVM stacktrace" in e.exception.message)
+            self.assertTrue(
+                "at org.apache.spark.sql.catalyst.analysis.CheckAnalysis" in e.exception.message
+            )
+
+        with self.sql_conf({"spark.sql.pyspark.jvmStacktrace.enabled": False}):
+            with self.assertRaises(AnalysisException) as e:
+                self.spark.sql("select x").collect()
+            self.assertFalse("JVM stacktrace" in e.exception.message)
+            self.assertFalse(
+                "at org.apache.spark.sql.catalyst.analysis.CheckAnalysis" in e.exception.message
+            )
+
+        # Create a new session with a different stack trace size.
+        self.spark.stop()
+        spark = (
+            PySparkSession.builder.config(conf=self.conf())
+            .config("spark.connect.jvmStacktrace.maxSize", 128)
+            .remote("local[4]")
+            .getOrCreate()
+        )
+        spark.conf.set("spark.sql.pyspark.jvmStacktrace.enabled", "true")
+        with self.assertRaises(AnalysisException) as e:
+            spark.sql("select x").collect()
+        self.assertTrue("JVM stacktrace" in e.exception.message)
+        self.assertFalse(
+            "at org.apache.spark.sql.catalyst.analysis.CheckAnalysis" in e.exception.message
+        )
+        spark.stop()
+
 
 @unittest.skipIf(not should_test_connect, connect_requirement_message)
 class ClientTests(unittest.TestCase):
diff --git a/python/pyspark/testing/connectutils.py b/python/pyspark/testing/connectutils.py
index 662a7d1446e46..5d57ad803bc53 100644
--- a/python/pyspark/testing/connectutils.py
+++ b/python/pyspark/testing/connectutils.py
@@ -146,7 +146,11 @@ def conf(cls):
         """
         Override this in subclasses to supply a more specific conf
         """
-        return SparkConf(loadDefaults=False)
+        conf = SparkConf(loadDefaults=False)
+        # Disable JVM stack trace in Spark Connect tests to prevent the
+        # HTTP header size from exceeding the maximum allowed size.
+        conf.set("spark.sql.pyspark.jvmStacktrace.enabled", "false")
+        return conf
 
     @classmethod
     def setUpClass(cls):
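
Illustration (not part of the patch): a minimal PySpark sketch of how a client might exercise the new behavior against a local Spark Connect server, mirroring test_error_stack_trace above. The maxSize value of 512 is an arbitrary choice for this example.

# Illustration only: the new stack trace behavior from the client side,
# assuming a local Spark Connect server started via .remote("local[4]").
from pyspark.errors import AnalysisException
from pyspark.sql import SparkSession

spark = (
    SparkSession.builder
    # Server-side truncation limit for the attached stack trace (arbitrary value here).
    .config("spark.connect.jvmStacktrace.maxSize", 512)
    .remote("local[4]")
    .getOrCreate()
)
spark.conf.set("spark.sql.pyspark.jvmStacktrace.enabled", "true")

try:
    spark.sql("select x").collect()  # unresolved column -> AnalysisException
except AnalysisException as e:
    # With the patch applied, the message carries a "JVM stacktrace:" section,
    # abbreviated to at most spark.connect.jvmStacktrace.maxSize characters.
    print(e.message)

spark.stop()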