
Commit 09a4353

Revert "[SPARK-42945][CONNECT] Support PYSPARK_JVM_STACKTRACE_ENABLED in Spark Connect"

This reverts commit b9400c7.

HyukjinKwon committed Apr 20, 2023
1 parent 87ccfc2 commit 09a4353
Showing 5 changed files with 39 additions and 148 deletions.
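
For context, the reverted change made AnalysisException messages raised over Spark Connect include the server-side JVM stack trace when spark.sql.pyspark.jvmStacktrace.enabled was set. A minimal client-side sketch of that behavior, based on the test_error_stack_trace test deleted below (an active Spark Connect session named spark is assumed):

    from pyspark.errors import AnalysisException

    # Assumes `spark` is an active Spark Connect session, e.g.
    # SparkSession.builder.remote("local[4]").getOrCreate().
    spark.conf.set("spark.sql.pyspark.jvmStacktrace.enabled", "true")
    try:
        spark.sql("select x").collect()  # unresolved column -> AnalysisException
    except AnalysisException as e:
        # Before this revert the message also carried a "JVM stacktrace:" section;
        # after it, only the regular error message is returned.
        print("JVM stacktrace" in e.message)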
@@ -90,14 +90,4 @@ object Connect {
.stringConf
.toSequence
.createWithDefault(Nil)

val CONNECT_JVM_STACK_TRACE_MAX_SIZE =
ConfigBuilder("spark.connect.jvmStacktrace.maxSize")
.doc("""
|Sets the maximum stack trace size to display when
|`spark.sql.pyspark.jvmStacktrace.enabled` is true.
|""".stripMargin)
.version("3.5.0")
.intConf
.createWithDefault(4096)
}
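
The config removed above capped how much of the JVM stack trace the server attached to error metadata. As a rough sketch of how a client could have sized it when building a session (mirroring the deleted test_error_stack_trace test further down; the remote URL and the 128-character cap are illustrative values only):

    from pyspark.sql import SparkSession

    # Illustrative values only: mirrors the test deleted below in test_connect_basic.py.
    spark = (
        SparkSession.builder
        .config("spark.connect.jvmStacktrace.maxSize", 128)  # cap the attached stack trace
        .remote("local[4]")
        .getOrCreate()
    )
    spark.conf.set("spark.sql.pyspark.jvmStacktrace.enabled", "true")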
@@ -33,7 +33,6 @@ import io.grpc.protobuf.StatusProto
import io.grpc.protobuf.services.ProtoReflectionService
import io.grpc.stub.StreamObserver
import org.apache.commons.lang3.StringUtils
import org.apache.commons.lang3.exception.ExceptionUtils
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods.{compact, render}

@@ -43,8 +42,7 @@ import org.apache.spark.connect.proto
import org.apache.spark.connect.proto.{AddArtifactsRequest, AddArtifactsResponse}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.connect.config.Connect.{CONNECT_GRPC_BINDING_PORT, CONNECT_GRPC_MAX_INBOUND_MESSAGE_SIZE, CONNECT_JVM_STACK_TRACE_MAX_SIZE}
import org.apache.spark.sql.internal.SQLConf.PYSPARK_JVM_STACKTRACE_ENABLED
import org.apache.spark.sql.connect.config.Connect.{CONNECT_GRPC_BINDING_PORT, CONNECT_GRPC_MAX_INBOUND_MESSAGE_SIZE}

/**
* The SparkConnectService implementation.
@@ -75,26 +73,19 @@ class SparkConnectService(debug: Boolean)
classes.toSeq
}

private def buildStatusFromThrowable(st: Throwable, stackTraceEnabled: Boolean): RPCStatus = {
private def buildStatusFromThrowable(st: Throwable): RPCStatus = {
val message = StringUtils.abbreviate(st.getMessage, 2048)
val errorInfo = ErrorInfo
.newBuilder()
.setReason(st.getClass.getName)
.setDomain("org.apache.spark")
.putMetadata("classes", compact(render(allClasses(st.getClass).map(_.getName))))

lazy val stackTrace = Option(ExceptionUtils.getStackTrace(st))
val withStackTrace = if (stackTraceEnabled && stackTrace.nonEmpty) {
val maxSize = SparkEnv.get.conf.get(CONNECT_JVM_STACK_TRACE_MAX_SIZE)
errorInfo.putMetadata("stackTrace", StringUtils.abbreviate(stackTrace.get, maxSize))
} else {
errorInfo
}

RPCStatus
.newBuilder()
.setCode(RPCCode.INTERNAL_VALUE)
.addDetails(ProtoAny.pack(withStackTrace.build()))
.addDetails(
ProtoAny.pack(
ErrorInfo
.newBuilder()
.setReason(st.getClass.getName)
.setDomain("org.apache.spark")
.putMetadata("classes", compact(render(allClasses(st.getClass).map(_.getName))))
.build()))
.setMessage(if (message != null) message else "")
.build()
}
@@ -119,42 +110,23 @@ class SparkConnectService(debug: Boolean)
*/
private def handleError[V](
opType: String,
observer: StreamObserver[V],
userId: String,
sessionId: String): PartialFunction[Throwable, Unit] = {
val session =
SparkConnectService
.getOrCreateIsolatedSession(userId, sessionId)
.session
val stackTraceEnabled =
try {
session.conf.get(PYSPARK_JVM_STACKTRACE_ENABLED.key).toBoolean
} catch {
case NonFatal(e) =>
logWarning(s"Failed to get Spark conf `PYSPARK_JVM_STACKTRACE_ENABLED`: $e")
true
}

{
case se: SparkException if isPythonExecutionException(se) =>
logError(s"Error during: $opType", se)
observer.onError(
StatusProto.toStatusRuntimeException(
buildStatusFromThrowable(se.getCause, stackTraceEnabled)))

case e: Throwable if e.isInstanceOf[SparkThrowable] || NonFatal.apply(e) =>
logError(s"Error during: $opType", e)
observer.onError(
StatusProto.toStatusRuntimeException(buildStatusFromThrowable(e, stackTraceEnabled)))

case e: Throwable =>
logError(s"Error during: $opType", e)
observer.onError(
Status.UNKNOWN
.withCause(e)
.withDescription(StringUtils.abbreviate(e.getMessage, 2048))
.asRuntimeException())
}
observer: StreamObserver[V]): PartialFunction[Throwable, Unit] = {
case se: SparkException if isPythonExecutionException(se) =>
logError(s"Error during: $opType", se)
observer.onError(
StatusProto.toStatusRuntimeException(buildStatusFromThrowable(se.getCause)))

case e: Throwable if e.isInstanceOf[SparkThrowable] || NonFatal.apply(e) =>
logError(s"Error during: $opType", e)
observer.onError(StatusProto.toStatusRuntimeException(buildStatusFromThrowable(e)))

case e: Throwable =>
logError(s"Error during: $opType", e)
observer.onError(
Status.UNKNOWN
.withCause(e)
.withDescription(StringUtils.abbreviate(e.getMessage, 2048))
.asRuntimeException())
}

/**
@@ -172,13 +144,7 @@ class SparkConnectService(debug: Boolean)
responseObserver: StreamObserver[proto.ExecutePlanResponse]): Unit = {
try {
new SparkConnectStreamHandler(responseObserver).handle(request)
} catch {
handleError(
"execute",
observer = responseObserver,
userId = request.getUserContext.getUserId,
sessionId = request.getSessionId)
}
} catch handleError("execute", observer = responseObserver)
}

/**
@@ -198,13 +164,7 @@ class SparkConnectService(debug: Boolean)
responseObserver: StreamObserver[proto.AnalyzePlanResponse]): Unit = {
try {
new SparkConnectAnalyzeHandler(responseObserver).handle(request)
} catch {
handleError(
"analyze",
observer = responseObserver,
userId = request.getUserContext.getUserId,
sessionId = request.getSessionId)
}
} catch handleError("analyze", observer = responseObserver)
}

/**
@@ -219,13 +179,7 @@ class SparkConnectService(debug: Boolean)
responseObserver: StreamObserver[proto.ConfigResponse]): Unit = {
try {
new SparkConnectConfigHandler(responseObserver).handle(request)
} catch {
handleError(
"config",
observer = responseObserver,
userId = request.getUserContext.getUserId,
sessionId = request.getSessionId)
}
} catch handleError("config", observer = responseObserver)
}

/**
4 changes: 0 additions & 4 deletions python/pyspark/errors/exceptions/connect.py
@@ -49,10 +49,6 @@ def convert_exception(info: "ErrorInfo", message: str) -> SparkConnectException:
if "classes" in info.metadata:
classes = json.loads(info.metadata["classes"])

if "stackTrace" in info.metadata:
stackTrace = info.metadata["stackTrace"]
message += f"\n\nJVM stacktrace:\n{stackTrace}"

if "org.apache.spark.sql.catalyst.parser.ParseException" in classes:
return ParseException(message)
# Order matters. ParseException inherits AnalysisException.
61 changes: 8 additions & 53 deletions python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -3124,38 +3124,27 @@ def test_unsupported_jvm_attribute(self):
)


class SparkConnectSessionTests(ReusedConnectTestCase):
def setUp(self) -> None:
self.spark = (
PySparkSession.builder.config(conf=self.conf())
.appName(self.__class__.__name__)
.remote("local[4]")
.getOrCreate()
)

def tearDown(self):
self.spark.stop()

class SparkConnectSessionTests(SparkConnectSQLTestCase):
def _check_no_active_session_error(self, e: PySparkException):
self.check_error(exception=e, error_class="NO_ACTIVE_SESSION", message_parameters=dict())

def test_stop_session(self):
df = self.spark.sql("select 1 as a, 2 as b")
catalog = self.spark.catalog
self.spark.stop()
df = self.connect.sql("select 1 as a, 2 as b")
catalog = self.connect.catalog
self.connect.stop()

# _execute_and_fetch
with self.assertRaises(SparkConnectException) as e:
self.spark.sql("select 1")
self.connect.sql("select 1")
self._check_no_active_session_error(e.exception)

with self.assertRaises(SparkConnectException) as e:
catalog.tableExists("table")
catalog.tableExists(self.tbl_name)
self._check_no_active_session_error(e.exception)

# _execute
with self.assertRaises(SparkConnectException) as e:
self.spark.udf.register("test_func", lambda x: x + 1)
self.connect.udf.register("test_func", lambda x: x + 1)
self._check_no_active_session_error(e.exception)

# _analyze
@@ -3165,43 +3154,9 @@ def test_stop_session(self):

# Config
with self.assertRaises(SparkConnectException) as e:
self.spark.conf.get("some.conf")
self.connect.conf.get("some.conf")
self._check_no_active_session_error(e.exception)

def test_error_stack_trace(self):
with self.sql_conf({"spark.sql.pyspark.jvmStacktrace.enabled": True}):
with self.assertRaises(AnalysisException) as e:
self.spark.sql("select x").collect()
self.assertTrue("JVM stacktrace" in e.exception.message)
self.assertTrue(
"at org.apache.spark.sql.catalyst.analysis.CheckAnalysis" in e.exception.message
)

with self.sql_conf({"spark.sql.pyspark.jvmStacktrace.enabled": False}):
with self.assertRaises(AnalysisException) as e:
self.spark.sql("select x").collect()
self.assertFalse("JVM stacktrace" in e.exception.message)
self.assertFalse(
"at org.apache.spark.sql.catalyst.analysis.CheckAnalysis" in e.exception.message
)

# Create a new session with a different stack trace size.
self.spark.stop()
spark = (
PySparkSession.builder.config(conf=self.conf())
.config("spark.connect.jvmStacktrace.maxSize", 128)
.remote("local[4]")
.getOrCreate()
)
spark.conf.set("spark.sql.pyspark.jvmStacktrace.enabled", "true")
with self.assertRaises(AnalysisException) as e:
spark.sql("select x").collect()
self.assertTrue("JVM stacktrace" in e.exception.message)
self.assertFalse(
"at org.apache.spark.sql.catalyst.analysis.CheckAnalysis" in e.exception.message
)
spark.stop()


@unittest.skipIf(not should_test_connect, connect_requirement_message)
class ClientTests(unittest.TestCase):
6 changes: 1 addition & 5 deletions python/pyspark/testing/connectutils.py
@@ -146,11 +146,7 @@ def conf(cls):
"""
Override this in subclasses to supply a more specific conf
"""
conf = SparkConf(loadDefaults=False)
# Disable JVM stack trace in Spark Connect tests to prevent the
# HTTP header size from exceeding the maximum allowed size.
conf.set("spark.sql.pyspark.jvmStacktrace.enabled", "false")
return conf
return SparkConf(loadDefaults=False)

@classmethod
def setUpClass(cls):
