Skip to content

[SPARK-51843][PYTHON][ML][TESTS] Avoid per-test classic session start & stop #50640

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@
import numpy as np

from pyspark.util import is_remote_only
from pyspark.sql import SparkSession
from pyspark.testing.connectutils import should_test_connect, connect_requirement_message
from pyspark.testing.utils import have_torch, torch_requirement_message
from pyspark.testing.sqlutils import ReusedSQLTestCase

if should_test_connect:
from pyspark.ml.connect.classification import (
Expand Down Expand Up @@ -231,12 +231,10 @@ def test_save_load(self):
or torch_requirement_message
or "pyspark-connect cannot test classic Spark",
)
class ClassificationTests(ClassificationTestsMixin, unittest.TestCase):
    """Classification tests run against a classic (non-connect) Spark session.

    A fresh local session is created before each test and stopped after it.
    """

    def setUp(self) -> None:
        # Two local cores per test session.
        builder = SparkSession.builder.master("local[2]")
        self.spark = builder.getOrCreate()

    def tearDown(self) -> None:
        # Release the session created in setUp.
        self.spark.stop()
class ClassificationTests(ClassificationTestsMixin, ReusedSQLTestCase):
    """Classification tests sharing one classic Spark session per class.

    Relies on ReusedSQLTestCase's class-level session; only the master URL
    is overridden here.
    """

    @classmethod
    def master(cls):
        # Run on two local cores instead of the base class default.
        return "local[2]"


if __name__ == "__main__":
Expand Down
12 changes: 5 additions & 7 deletions python/pyspark/ml/tests/connect/test_legacy_mode_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@
import numpy as np

from pyspark.util import is_remote_only
from pyspark.sql import SparkSession
from pyspark.testing.connectutils import should_test_connect, connect_requirement_message
from pyspark.testing.utils import have_torcheval, torcheval_requirement_message
from pyspark.testing.sqlutils import ReusedSQLTestCase

if should_test_connect:
from pyspark.ml.connect.evaluation import (
Expand Down Expand Up @@ -178,12 +178,10 @@ def test_multiclass_classifier_evaluator(self):
or torcheval_requirement_message
or "pyspark-connect cannot test classic Spark",
)
class EvaluationTests(EvaluationTestsMixin, unittest.TestCase):
    """Evaluation tests run against a classic (non-connect) Spark session.

    A fresh local session is created before each test and stopped after it.
    """

    def setUp(self) -> None:
        # Two local cores per test session.
        builder = SparkSession.builder.master("local[2]")
        self.spark = builder.getOrCreate()

    def tearDown(self) -> None:
        # Release the session created in setUp.
        self.spark.stop()
class EvaluationTests(EvaluationTestsMixin, ReusedSQLTestCase):
    """Evaluation tests sharing one classic Spark session per class.

    Relies on ReusedSQLTestCase's class-level session; only the master URL
    is overridden here.
    """

    @classmethod
    def master(cls):
        # Run on two local cores instead of the base class default.
        return "local[2]"


if __name__ == "__main__":
Expand Down
12 changes: 5 additions & 7 deletions python/pyspark/ml/tests/connect/test_legacy_mode_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@
import numpy as np

from pyspark.util import is_remote_only
from pyspark.sql import SparkSession
from pyspark.testing.connectutils import should_test_connect, connect_requirement_message
from pyspark.testing.utils import have_torch, torch_requirement_message
from pyspark.testing.sqlutils import ReusedSQLTestCase

if should_test_connect:
from pyspark.ml.connect.feature import (
Expand Down Expand Up @@ -201,12 +201,10 @@ def test_array_assembler(self):
or torch_requirement_message
or "pyspark-connect cannot test classic Spark",
)
class FeatureTests(FeatureTestsMixin, unittest.TestCase):
    """Feature tests run against a classic (non-connect) Spark session.

    A fresh local session is created before each test and stopped after it.
    """

    def setUp(self) -> None:
        # Two local cores per test session.
        builder = SparkSession.builder.master("local[2]")
        self.spark = builder.getOrCreate()

    def tearDown(self) -> None:
        # Release the session created in setUp.
        self.spark.stop()
class FeatureTests(FeatureTestsMixin, ReusedSQLTestCase):
    """Feature tests sharing one classic Spark session per class.

    Relies on ReusedSQLTestCase's class-level session; only the master URL
    is overridden here.
    """

    @classmethod
    def master(cls):
        # Run on two local cores instead of the base class default.
        return "local[2]"


if __name__ == "__main__":
Expand Down
12 changes: 5 additions & 7 deletions python/pyspark/ml/tests/connect/test_legacy_mode_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@
import numpy as np

from pyspark.util import is_remote_only
from pyspark.sql import SparkSession
from pyspark.testing.connectutils import should_test_connect, connect_requirement_message
from pyspark.testing.utils import have_torch, torch_requirement_message
from pyspark.testing.sqlutils import ReusedSQLTestCase

if should_test_connect:
from pyspark.ml.connect.feature import StandardScaler
Expand Down Expand Up @@ -175,12 +175,10 @@ def test_pipeline_copy():
or torch_requirement_message
or "pyspark-connect cannot test classic Spark",
)
class PipelineTests(PipelineTestsMixin, unittest.TestCase):
    """Pipeline tests run against a classic (non-connect) Spark session.

    A fresh local session is created before each test and stopped after it.
    """

    def setUp(self) -> None:
        # Two local cores per test session.
        builder = SparkSession.builder.master("local[2]")
        self.spark = builder.getOrCreate()

    def tearDown(self) -> None:
        # Release the session created in setUp.
        self.spark.stop()
class PipelineTests(PipelineTestsMixin, ReusedSQLTestCase):
    """Pipeline tests sharing one classic Spark session per class.

    Relies on ReusedSQLTestCase's class-level session; only the master URL
    is overridden here.
    """

    @classmethod
    def master(cls):
        # Run on two local cores instead of the base class default.
        return "local[2]"


if __name__ == "__main__":
Expand Down
12 changes: 5 additions & 7 deletions python/pyspark/ml/tests/connect/test_legacy_mode_summarizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
import numpy as np

from pyspark.util import is_remote_only
from pyspark.sql import SparkSession
from pyspark.testing.connectutils import should_test_connect, connect_requirement_message
from pyspark.testing.sqlutils import ReusedSQLTestCase

if should_test_connect:
from pyspark.ml.connect.summarizer import summarize_dataframe
Expand Down Expand Up @@ -67,12 +67,10 @@ def assert_dict_allclose(dict1, dict2):
not should_test_connect or is_remote_only(),
connect_requirement_message or "pyspark-connect cannot test classic Spark",
)
class SummarizerTests(SummarizerTestsMixin, unittest.TestCase):
    """Summarizer tests run against a classic (non-connect) Spark session.

    A fresh local session is created before each test and stopped after it.
    """

    def setUp(self) -> None:
        # Two local cores per test session.
        builder = SparkSession.builder.master("local[2]")
        self.spark = builder.getOrCreate()

    def tearDown(self) -> None:
        # Release the session created in setUp.
        self.spark.stop()
class SummarizerTests(SummarizerTestsMixin, ReusedSQLTestCase):
    """Summarizer tests sharing one classic Spark session per class.

    Relies on ReusedSQLTestCase's class-level session; only the master URL
    is overridden here.
    """

    @classmethod
    def master(cls):
        # Run on two local cores instead of the base class default.
        return "local[2]"


if __name__ == "__main__":
Expand Down
12 changes: 5 additions & 7 deletions python/pyspark/ml/tests/connect/test_legacy_mode_tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
from pyspark.util import is_remote_only
from pyspark.ml.param import Param, Params
from pyspark.ml.tuning import ParamGridBuilder
from pyspark.sql import SparkSession
from pyspark.sql.functions import rand
from pyspark.testing.connectutils import should_test_connect, connect_requirement_message
from pyspark.testing.utils import (
Expand All @@ -36,6 +35,7 @@
have_torcheval,
torcheval_requirement_message,
)
from pyspark.testing.sqlutils import ReusedSQLTestCase

if should_test_connect:
import pandas as pd
Expand Down Expand Up @@ -294,12 +294,10 @@ def test_crossvalidator_with_fold_col(self):
or torcheval_requirement_message
or "pyspark-connect cannot test classic Spark",
)
class CrossValidatorTests(CrossValidatorTestsMixin, unittest.TestCase):
    """CrossValidator tests run against a classic (non-connect) Spark session.

    A fresh local session is created before each test and stopped after it.
    """

    def setUp(self) -> None:
        # Two local cores per test session.
        builder = SparkSession.builder.master("local[2]")
        self.spark = builder.getOrCreate()

    def tearDown(self) -> None:
        # Release the session created in setUp.
        self.spark.stop()
class CrossValidatorTests(CrossValidatorTestsMixin, ReusedSQLTestCase):
    """CrossValidator tests sharing one classic Spark session per class.

    Relies on ReusedSQLTestCase's class-level session; only the master URL
    is overridden here.
    """

    @classmethod
    def master(cls):
        # Run on two local cores instead of the base class default.
        return "local[2]"


if __name__ == "__main__":
Expand Down
6 changes: 5 additions & 1 deletion python/pyspark/testing/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,11 @@ def conf(cls):
def setUpClass(cls):
    """Create the SparkContext shared by every test in this class.

    The diff residue left both the old hard-coded-master call and the new
    parameterized call in this span; only one SparkContext may exist, so
    keep the parameterized form. The master URL comes from the overridable
    ``master()`` hook so subclasses can choose their local parallelism.
    """
    from pyspark import SparkContext

    cls.sc = SparkContext(cls.master(), cls.__name__, conf=cls.conf())

@classmethod
def master(cls):
    """Return the Spark master URL used by ``setUpClass``.

    Subclasses override this to change the local parallelism.
    """
    return "local[4]"

@classmethod
def tearDownClass(cls):
Expand Down