diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkStateServer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkStateServer.scala index 4edeae132b47a..1db56921e2c65 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkStateServer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkStateServer.scala @@ -17,8 +17,8 @@ package org.apache.spark.sql.execution.python.streaming -import java.io.{BufferedInputStream, BufferedOutputStream, DataInputStream, DataOutputStream, EOFException} -import java.nio.channels.{Channels, ServerSocketChannel} +import java.io.{BufferedInputStream, BufferedOutputStream, DataInputStream, DataOutputStream, EOFException, InterruptedIOException} +import java.nio.channels.{Channels, ClosedByInterruptException, ServerSocketChannel} import java.time.Duration import scala.collection.mutable @@ -181,7 +181,7 @@ class TransformWithStateInPySparkStateServer( logWarning(log"No more data to read from the socket") statefulProcessorHandle.setHandleState(StatefulProcessorHandleState.CLOSED) return - case _: InterruptedException => + case _: InterruptedException | _: InterruptedIOException | _: ClosedByInterruptException => logInfo(log"Thread interrupted, shutting down state server") Thread.currentThread().interrupt() statefulProcessorHandle.setHandleState(StatefulProcessorHandleState.CLOSED)