diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml
index 3d7dd99..49da6ed 100644
--- a/.github/workflows/main.yml
+++ b/.github/workflows/main.yml
@@ -7,28 +7,15 @@ jobs:
       fail-fast: false
       matrix:
         PYTHON_VERSION: ["3.10"]
-        JOBLIB_VERSION: ["1.2.0", "1.3.0"]
-        PIN_MODE: [false, true]
-        PYSPARK_VERSION: ["3.0.3", "3.1.3", "3.2.3", "3.3.2", "3.4.0"]
-        include:
-          - PYSPARK_VERSION: "3.5.1"
-            PYTHON_VERSION: "3.11"
-            JOBLIB_VERSION: "1.3.0"
-          - PYSPARK_VERSION: "3.5.1"
-            PYTHON_VERSION: "3.11"
-            JOBLIB_VERSION: "1.4.2"
-          - PYSPARK_VERSION: "3.5.1"
-            PYTHON_VERSION: "3.12"
-            JOBLIB_VERSION: "1.3.0"
-          - PYSPARK_VERSION: "3.5.1"
-            PYTHON_VERSION: "3.12"
-            JOBLIB_VERSION: "1.4.2"
+        JOBLIB_VERSION: ["1.3.2", "1.4.2"]
+        PYSPARK_VERSION: ["3.4.4", "3.5.5", "4.0.0.dev2"]
+        SPARK_CONNECT_MODE: [false, true]
         exclude:
-          - PYSPARK_VERSION: "3.0.3"
-            PIN_MODE: true
-          - PYSPARK_VERSION: "3.1.3"
-            PIN_MODE: true
-    name: Run test on pyspark ${{ matrix.PYSPARK_VERSION }}, pin_mode ${{ matrix.PIN_MODE }}, python ${{ matrix.PYTHON_VERSION }}, joblib ${{ matrix.JOBLIB_VERSION }}
+          - PYSPARK_VERSION: "3.4.4"
+            SPARK_CONNECT_MODE: true
+          - PYSPARK_VERSION: "3.5.5"
+            SPARK_CONNECT_MODE: true
+    name: Run test on pyspark ${{ matrix.PYSPARK_VERSION }}, Use Spark Connect ${{ matrix.SPARK_CONNECT_MODE }}, joblib ${{ matrix.JOBLIB_VERSION }}
     steps:
     - uses: actions/checkout@v3
     - name: Setup python ${{ matrix.PYTHON_VERSION }}
@@ -38,10 +25,10 @@ jobs:
         architecture: x64
     - name: Install python packages
       run: |
-        pip install joblib==${{ matrix.JOBLIB_VERSION }} scikit-learn>=0.23.1 pytest pylint pyspark==${{ matrix.PYSPARK_VERSION }}
+        pip install joblib==${{ matrix.JOBLIB_VERSION }} scikit-learn>=0.23.1 pytest pylint "pyspark[connect]==${{ matrix.PYSPARK_VERSION }}" pandas
    - name: Run pylint
      run: |
        ./run-pylint.sh
    - name: Run test suites
      run: |
-        PYSPARK_PIN_THREAD=${{ matrix.PIN_MODE }} ./run-tests.sh
+        TEST_SPARK_CONNECT=${{ matrix.SPARK_CONNECT_MODE }} PYSPARK_VERSION=${{ matrix.PYSPARK_VERSION }} ./run-tests.sh
diff --git a/joblibspark/backend.py b/joblibspark/backend.py
index a2a3728..c3ea2ec 100644
--- a/joblibspark/backend.py
+++ b/joblibspark/backend.py
@@ -18,6 +18,7 @@
 The joblib spark backend implementation.
 """
 import atexit
+import logging
 import warnings
 from multiprocessing.pool import ThreadPool
 import uuid
@@ -47,6 +48,9 @@
 from .utils import create_resource_profile, get_spark_session
 
 
+_logger = logging.getLogger("joblibspark.backend")
+
+
 def register():
     """
     Register joblib spark backend.
@@ -62,6 +66,20 @@ def register():
     register_parallel_backend('spark', SparkDistributedBackend)
 
 
+def is_spark_connect_mode():
+    """
+    Check if running with spark connect mode.
+    """
+    try:
+        from pyspark.sql.utils import is_remote  # pylint: disable=C0415
+        return is_remote()
+    except ImportError:
+        return False
+
+
+_DEFAULT_N_JOBS_IN_SPARK_CONNECT_MODE = 64
+
+
 # pylint: disable=too-many-instance-attributes
 class SparkDistributedBackend(ParallelBackendBase, AutoBatchingMixin):
     """A ParallelBackend which will execute all batches on spark.
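The new `is_spark_connect_mode()` helper keys the whole feature: it defers to `pyspark.sql.utils.is_remote()` when that import is available and treats older PySpark builds as classic mode. Below is a minimal sketch of that check, assuming `pyspark[connect]` 3.5+ is installed; the `.remote("local")` URL (which starts a throwaway local Spark Connect server) is only for illustration.

```python
# Sketch only: see how is_remote() distinguishes a Spark Connect session.
from pyspark.sql import SparkSession

try:
    from pyspark.sql.utils import is_remote  # same import the backend attempts
except ImportError:  # older PySpark without Spark Connect support
    def is_remote():
        return False

spark = SparkSession.builder.remote("local").getOrCreate()  # local Connect server
print(is_remote())  # expected to be True for a remote (Spark Connect) session
spark.stop()
```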
@@ -82,15 +100,7 @@ def __init__(self,
         self._pool = None
         self._n_jobs = None
         self._spark = get_spark_session()
-        self._spark_context = self._spark.sparkContext
         self._job_group = "joblib-spark-job-group-" + str(uuid.uuid4())
-        self._spark_pinned_threads_enabled = isinstance(
-            self._spark_context._gateway, ClientServer
-        )
-        self._spark_supports_job_cancelling = (
-            self._spark_pinned_threads_enabled
-            or hasattr(self._spark_context.parallelize([1]), "collectWithJobGroup")
-        )
         self._is_running = False
         try:
             from IPython import get_ipython  # pylint: disable=import-outside-toplevel
@@ -98,11 +108,32 @@ def __init__(self,
         except ImportError:
             self._ipython = None
 
-        self._support_stage_scheduling = self._is_support_stage_scheduling()
+        self._is_spark_connect_mode = is_spark_connect_mode()
+        if self._is_spark_connect_mode:
+            if Version(pyspark.__version__).major < 4:
+                raise RuntimeError(
+                    "Joblib spark does not support Spark Connect with PySpark version < 4."
+                )
+            self._support_stage_scheduling = True
+            self._spark_supports_job_cancelling = True
+        else:
+            self._spark_context = self._spark.sparkContext
+            self._spark_pinned_threads_enabled = isinstance(
+                self._spark_context._gateway, ClientServer
+            )
+            self._spark_supports_job_cancelling = (
+                self._spark_pinned_threads_enabled
+                or hasattr(self._spark_context.parallelize([1]), "collectWithJobGroup")
+            )
+            self._support_stage_scheduling = self._is_support_stage_scheduling()
+
         self._resource_profile = self._create_resource_profile(num_cpus_per_spark_task,
                                                                 num_gpus_per_spark_task)
 
     def _is_support_stage_scheduling(self):
+        if self._is_spark_connect_mode:
+            return Version(pyspark.__version__).major >= 4
+
         spark_master = self._spark_context.master
         is_spark_local_mode = spark_master == "local" or spark_master.startswith("local[")
         if is_spark_local_mode:
@@ -135,26 +166,53 @@ def _create_resource_profile(self,
     def _cancel_all_jobs(self):
         self._is_running = False
         if not self._spark_supports_job_cancelling:
-            # Note: There's bug existing in `sparkContext.cancelJobGroup`.
-            # See https://issues.apache.org/jira/browse/SPARK-31549
-            warnings.warn("For spark version < 3, pyspark cancelling job API has bugs, "
-                          "so we could not terminate running spark jobs correctly. "
-                          "See https://issues.apache.org/jira/browse/SPARK-31549 for reference.")
+            if self._is_spark_connect_mode:
+                warnings.warn("Spark Connect does not support the job cancellation API "
+                              "for Spark versions < 3.5.")
+            else:
+                # Note: There's bug existing in `sparkContext.cancelJobGroup`.
+                # See https://issues.apache.org/jira/browse/SPARK-31549
+                warnings.warn("For spark version < 3, pyspark cancelling job API has bugs, "
+                              "so we could not terminate running spark jobs correctly. "
+                              "See https://issues.apache.org/jira/browse/SPARK-31549 for "
+                              "reference.")
         else:
-            self._spark.sparkContext.cancelJobGroup(self._job_group)
+            if self._is_spark_connect_mode:
+                self._spark.interruptTag(self._job_group)
+            else:
+                self._spark.sparkContext.cancelJobGroup(self._job_group)
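There are now two cancellation paths: classic mode keeps using job groups plus `sparkContext.cancelJobGroup`, while Spark Connect uses session tags (`addTag` / `interruptTag`). The sketch below shows the classic-mode flow only and is not part of the patch; the group name, the long `time.sleep`, and the broad exception handling are all illustrative.

```python
# Sketch only (classic mode): cancel a running Spark job by job group,
# mirroring what _cancel_all_jobs does above.
import threading
import time
from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[2]").getOrCreate()
job_group = "joblib-spark-job-group-demo"

def submit_slow_job():
    # Jobs submitted from this thread are tagged with the job group.
    spark.sparkContext.setLocalProperty("spark.jobGroup.id", job_group)
    try:
        spark.sparkContext.parallelize(range(4), 4).map(
            lambda _: time.sleep(600)).collect()
    except Exception as exc:  # cancellation surfaces as a job-failure error
        print("job cancelled:", type(exc).__name__)

worker = threading.Thread(target=submit_slow_job)
worker.start()
time.sleep(5)
spark.sparkContext.cancelJobGroup(job_group)  # classic cancellation API
# Spark Connect equivalent (as used in the patch): spark.addTag(tag) in the
# submitting thread, then spark.interruptTag(tag) from the canceller.
worker.join()
spark.stop()
```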
- "If dynamic allocation is enabled for the cluster, you might see more " - "executors allocated.") + + if self._is_spark_connect_mode: + if n_jobs == 1: + warnings.warn( + "The maximum number of concurrently running jobs is set to 1, " + "to increase concurrency, you need to set joblib spark backend " + "'n_jobs' param to a greater number." + ) + + if n_jobs == -1: + n_jobs = _DEFAULT_N_JOBS_IN_SPARK_CONNECT_MODE + # pylint: disable = logging-fstring-interpolation + _logger.warning( + "Joblib sets `n_jobs` to default value " + f"{_DEFAULT_N_JOBS_IN_SPARK_CONNECT_MODE} in Spark Connect mode." + ) + else: + max_num_concurrent_tasks = self._get_max_num_concurrent_tasks() + if n_jobs == -1: + # n_jobs=-1 means requesting all available workers + n_jobs = max_num_concurrent_tasks + if n_jobs > max_num_concurrent_tasks: + warnings.warn( + f"User-specified n_jobs ({n_jobs}) is greater than the max number of " + f"concurrent tasks ({max_num_concurrent_tasks}) this cluster can run now." + "If dynamic allocation is enabled for the cluster, you might see more " + "executors allocated." + ) return n_jobs def _get_max_num_concurrent_tasks(self): @@ -213,38 +271,111 @@ def run_on_worker_and_fetch_result(): raise RuntimeError('The task is canceled due to ipython command canceled.') # TODO: handle possible spark exception here. # pylint: disable=fixme - worker_rdd = self._spark.sparkContext.parallelize([0], 1) - if self._resource_profile: - worker_rdd = worker_rdd.withResources(self._resource_profile) - def mapper_fn(_): - return cloudpickle.dumps(func()) - if self._spark_supports_job_cancelling: - if self._spark_pinned_threads_enabled: - self._spark.sparkContext.setLocalProperty( - "spark.jobGroup.id", - self._job_group - ) - self._spark.sparkContext.setLocalProperty( - "spark.job.description", - "joblib spark jobs" - ) - rdd = worker_rdd.map(mapper_fn) - ser_res = rdd.collect()[0] + if self._is_spark_connect_mode: + spark_df = self._spark.range(1, numPartitions=1) + + def mapper_fn(iterator): + import pandas as pd # pylint: disable=import-outside-toplevel + for _ in iterator: # consume input data. 
@@ -213,38 +271,111 @@ def run_on_worker_and_fetch_result():
                     raise RuntimeError('The task is canceled due to ipython command canceled.')
             # TODO: handle possible spark exception here. # pylint: disable=fixme
-            worker_rdd = self._spark.sparkContext.parallelize([0], 1)
-            if self._resource_profile:
-                worker_rdd = worker_rdd.withResources(self._resource_profile)
-            def mapper_fn(_):
-                return cloudpickle.dumps(func())
-            if self._spark_supports_job_cancelling:
-                if self._spark_pinned_threads_enabled:
-                    self._spark.sparkContext.setLocalProperty(
-                        "spark.jobGroup.id",
-                        self._job_group
-                    )
-                    self._spark.sparkContext.setLocalProperty(
-                        "spark.job.description",
-                        "joblib spark jobs"
-                    )
-                    rdd = worker_rdd.map(mapper_fn)
-                    ser_res = rdd.collect()[0]
+            if self._is_spark_connect_mode:
+                spark_df = self._spark.range(1, numPartitions=1)
+
+                def mapper_fn(iterator):
+                    import pandas as pd  # pylint: disable=import-outside-toplevel
+                    for _ in iterator:  # consume input data.
+                        pass
+
+                    result = cloudpickle.dumps(func())
+                    yield pd.DataFrame({"result": [result]})
+
+                if self._spark_supports_job_cancelling:
+                    self._spark.addTag(self._job_group)
+
+                if self._support_stage_scheduling:
+                    collected = spark_df.mapInPandas(
+                        mapper_fn,
+                        schema="result binary",
+                        profile=self._resource_profile,
+                    ).collect()
                 else:
-                    rdd = worker_rdd.map(mapper_fn)
-                    ser_res = rdd.collectWithJobGroup(
-                        self._job_group,
-                        "joblib spark jobs"
-                    )[0]
+                    collected = spark_df.mapInPandas(
+                        mapper_fn,
+                        schema="result binary",
+                    ).collect()
+
+                ser_res = bytes(collected[0].result)
             else:
-                rdd = worker_rdd.map(mapper_fn)
-                ser_res = rdd.collect()[0]
+                worker_rdd = self._spark.sparkContext.parallelize([0], 1)
+                if self._resource_profile:
+                    worker_rdd = worker_rdd.withResources(self._resource_profile)
+
+                def mapper_fn(_):
+                    return cloudpickle.dumps(func())
+
+                if self._spark_supports_job_cancelling:
+                    if self._spark_pinned_threads_enabled:
+                        self._spark.sparkContext.setLocalProperty(
+                            "spark.jobGroup.id",
+                            self._job_group
+                        )
+                        self._spark.sparkContext.setLocalProperty(
+                            "spark.job.description",
+                            "joblib spark jobs"
+                        )
+                        rdd = worker_rdd.map(mapper_fn)
+                        ser_res = rdd.collect()[0]
+                    else:
+                        rdd = worker_rdd.map(mapper_fn)
+                        ser_res = rdd.collectWithJobGroup(
+                            self._job_group,
+                            "joblib spark jobs"
+                        )[0]
+                else:
+                    rdd = worker_rdd.map(mapper_fn)
+                    ser_res = rdd.collect()[0]
             return cloudpickle.loads(ser_res)
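Spark Connect has no RDD API, so each joblib batch is now shipped as a one-row `mapInPandas` job: the mapper drains its dummy input, runs the batch, and returns the cloudpickled result in a single binary cell. A standalone sketch of the same round trip (it also runs against a classic session, assuming pandas and pyarrow are installed); `func` here just stands in for a joblib batch.

```python
# Sketch only: run a Python closure in one Spark task and fetch the pickled
# result back through a single-row, single-column binary DataFrame.
import cloudpickle
import pandas as pd
from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[2]").getOrCreate()

def func():                      # stands in for a joblib batch
    return sum(range(10))

def mapper_fn(iterator):
    for _ in iterator:           # drain the single dummy input row
        pass
    yield pd.DataFrame({"result": [cloudpickle.dumps(func())]})

row = (spark.range(1, numPartitions=1)
            .mapInPandas(mapper_fn, schema="result binary")
            .collect()[0])
print(cloudpickle.loads(bytes(row.result)))  # 45
spark.stop()
```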
 
         try:
             # pylint: disable=no-name-in-module,import-outside-toplevel
             from pyspark import inheritable_thread_target
+
+            if Version(pyspark.__version__).major >= 4 and is_spark_connect_mode():
+                # pylint: disable=fixme
+                # TODO: remove this patch once Spark 4.0.0 is released.
+                # The patch propagates the Spark session to the current thread.
+                def patched_inheritable_thread_target(f):  # pylint: disable=invalid-name
+                    import functools
+                    import copy
+                    from typing import Any
+
+                    session = f
+                    assert session is not None, "Spark Connect session must be provided."
+
+                    def outer(ff: Any) -> Any:  # pylint: disable=invalid-name
+                        session_client_thread_local_attrs = [
+                            # type: ignore[union-attr]
+                            (attr, copy.deepcopy(value))
+                            for (
+                                attr,
+                                value,
+                            ) in session.client.thread_local.__dict__.items()
+                        ]
+
+                        @functools.wraps(ff)
+                        def inner(*args: Any, **kwargs: Any) -> Any:
+                            # Propagates the active spark session to the current thread.
+                            from pyspark.sql.connect.session import SparkSession as SCS
+
+                            # pylint: disable=protected-access,no-member
+                            SCS._set_default_and_active_session(session)
+
+                            # Set thread locals in child thread.
+                            for attr, value in session_client_thread_local_attrs:
+                                # type: ignore[union-attr]
+                                setattr(session.client.thread_local, attr, value)
+                            return ff(*args, **kwargs)
+
+                        return inner
+
+                    return outer
+
+                inheritable_thread_target = patched_inheritable_thread_target(self._spark)
+
             run_on_worker_and_fetch_result = \
                 inheritable_thread_target(run_on_worker_and_fetch_result)
         except ImportError:
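The monkey patch above exists because joblib dispatches each batch from a `ThreadPool` worker thread, and neither the active Spark Connect session nor its client-side thread locals cross Python thread boundaries on their own; the released Spark 4.0 `inheritable_thread_target` accepts a session for exactly this, and the patch backfills that behaviour for the pre-release builds used in CI. The classic-mode behaviour the wrapper provides can be sketched as follows; this is not part of the patch, and the result depends on pinned-thread mode (the PySpark default since 3.2).

```python
# Sketch only: local properties such as the job group do not propagate to new
# Python threads unless the target is wrapped with inheritable_thread_target.
from threading import Thread
from pyspark import inheritable_thread_target
from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[2]").getOrCreate()
spark.sparkContext.setLocalProperty("spark.jobGroup.id", "demo-group")

@inheritable_thread_target
def read_job_group():
    # With the wrapper, the child thread should see "demo-group"; without it,
    # this is typically None under pinned-thread mode.
    return spark.sparkContext.getLocalProperty("spark.jobGroup.id")

t = Thread(target=lambda: print("child sees:", read_job_group()))
t.start()
t.join()
spark.stop()
```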
diff --git a/requirements.txt b/requirements.txt
index 8a41183..bdc1ad1 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,2 +1,4 @@
 joblib>=0.14
 packaging
+pandas
+
diff --git a/run-pylint.sh b/run-pylint.sh
index 91276d3..60c6ce5 100755
--- a/run-pylint.sh
+++ b/run-pylint.sh
@@ -22,4 +22,3 @@ set -e
 
 # Run pylint
 python -m pylint joblibspark
-
diff --git a/test/test_backend.py b/test/test_backend.py
index 1332e40..56fe575 100644
--- a/test/test_backend.py
+++ b/test/test_backend.py
@@ -8,14 +8,37 @@
 from pyspark.sql import SparkSession
 
 from joblibspark.backend import SparkDistributedBackend
+import joblibspark.backend
+
+joblibspark.backend._DEFAULT_N_JOBS_IN_SPARK_CONNECT_MODE = 8
+
+
+spark_version = os.environ["PYSPARK_VERSION"]
+is_spark_connect_mode = os.environ["TEST_SPARK_CONNECT"].lower() == "true"
+
+
+if Version(spark_version).major >= 4:
+    spark_connect_jar = ""
+else:
+    spark_connect_jar = f"org.apache.spark:spark-connect_2.12:{spark_version}"
 
 
 class TestLocalSparkCluster(unittest.TestCase):
     @classmethod
     def setup_class(cls):
-        cls.spark = (
-            SparkSession.builder.master("local[*]").getOrCreate()
-        )
+        if is_spark_connect_mode:
+            cls.spark = (
+                SparkSession.builder.config(
+                    "spark.jars.packages", spark_connect_jar
+                )
+                .remote("local-cluster[1, 2, 1024]")
+                .appName("Test")
+                .getOrCreate()
+            )
+        else:
+            cls.spark = (
+                SparkSession.builder.master("local-cluster[1, 2, 1024]").getOrCreate()
+            )
 
     @classmethod
     def teardown_class(cls):
@@ -23,22 +46,23 @@ def teardown_class(cls):
 
     def test_effective_n_jobs(self):
         backend = SparkDistributedBackend()
-        max_num_concurrent_tasks = 8
-        backend._get_max_num_concurrent_tasks = MagicMock(return_value=max_num_concurrent_tasks)
 
         assert backend.effective_n_jobs(n_jobs=None) == 1
-        assert backend.effective_n_jobs(n_jobs=-1) == 8
         assert backend.effective_n_jobs(n_jobs=4) == 4
 
-        with warnings.catch_warnings(record=True) as w:
-            warnings.simplefilter("always")
-            assert backend.effective_n_jobs(n_jobs=16) == 16
-            assert len(w) == 1
-
-    def test_resource_profile_supported(self):
-        backend = SparkDistributedBackend()
-        # The test fixture uses a local (standalone) Spark instance, which doesn't support stage-level scheduling.
-        assert not backend._support_stage_scheduling
+        if is_spark_connect_mode:
+            assert (
+                backend.effective_n_jobs(n_jobs=-1) ==
+                joblibspark.backend._DEFAULT_N_JOBS_IN_SPARK_CONNECT_MODE
+            )
+        else:
+            max_num_concurrent_tasks = 8
+            backend._get_max_num_concurrent_tasks = MagicMock(return_value=max_num_concurrent_tasks)
+            assert backend.effective_n_jobs(n_jobs=-1) == 8
+
+            with warnings.catch_warnings(record=True) as w:
+                warnings.simplefilter("always")
+                assert backend.effective_n_jobs(n_jobs=16) == 16
+                assert len(w) == 1
 
 
 class TestBasicSparkCluster(unittest.TestCase):
diff --git a/test/test_spark.py b/test/test_spark.py
index 9c78257..e871b45 100644
--- a/test/test_spark.py
+++ b/test/test_spark.py
@@ -14,6 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import logging
 from time import sleep
 import pytest
 import os
@@ -27,6 +28,7 @@
 from joblib.parallel import Parallel, delayed, parallel_backend
 
 from joblibspark import register_spark
+import joblibspark.backend
 
 from sklearn.utils import parallel_backend
 from sklearn.model_selection import cross_val_score
@@ -36,6 +38,21 @@
 from pyspark.sql import SparkSession
 import pyspark
 
+_logger = logging.getLogger("Test")
+_logger.setLevel(logging.INFO)
+
+joblibspark.backend._DEFAULT_N_JOBS_IN_SPARK_CONNECT_MODE = 2
+
+
+spark_version = os.environ["PYSPARK_VERSION"]
+
+is_spark_connect_mode = os.environ["TEST_SPARK_CONNECT"].lower() == "true"
+
+if Version(spark_version).major >= 4:
+    spark_connect_jar = ""
+else:
+    spark_connect_jar = f"org.apache.spark:spark-connect_2.12:{spark_version}"
+
 register_spark()
 
 
@@ -44,13 +61,28 @@ class TestSparkCluster(unittest.TestCase):
     @classmethod
     def setup_class(cls):
-        cls.spark = (
-            SparkSession.builder.master("local-cluster[1, 2, 1024]")
+        spark_session_builder = (
+            SparkSession.builder
             .config("spark.task.cpus", "1")
             .config("spark.task.maxFailures", "1")
-            .getOrCreate()
         )
+        if is_spark_connect_mode:
+            _logger.info("Test with spark connect mode.")
+            cls.spark = (
+                spark_session_builder.config(
+                    "spark.jars.packages", spark_connect_jar
+                )
+                .remote("local-cluster[1, 2, 1024]")  # Adjust the remote address if necessary
+                .appName("Test")
+                .getOrCreate()
+            )
+        else:
+            cls.spark = (
+                spark_session_builder.master("local-cluster[1, 2, 1024]")
+                .getOrCreate()
+            )
+
     @classmethod
     def teardown_class(cls):
         cls.spark.stop()
@@ -65,8 +97,8 @@ def slow_raise_value_error(condition, duration=0.05):
             raise ValueError("condition evaluated to True")
 
         with parallel_backend('spark') as (ba, _):
-            seq = Parallel(n_jobs=5)(delayed(inc)(i) for i in range(10))
-            assert seq == [inc(i) for i in range(10)]
+            seq = Parallel(n_jobs=2)(delayed(inc)(i) for i in range(2))
+            assert seq == [inc(i) for i in range(2)]
 
             with pytest.raises(BaseException):
                 Parallel(n_jobs=5)(delayed(slow_raise_value_error)(i == 3)
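The GPU suite below drives stage-level scheduling. For reference, this is roughly the task-level resource profile the backend builds from `num_cpus_per_spark_task` / `num_gpus_per_spark_task`; it mirrors `create_resource_profile` in `joblibspark.utils`, which is outside this diff, so treat the sketch as an approximation.

```python
# Sketch only: a task-level ResourceProfile requesting 1 CPU and 1 GPU per task.
from pyspark.resource.profile import ResourceProfileBuilder
from pyspark.resource.requests import TaskResourceRequests

task_reqs = TaskResourceRequests().cpus(1).resource("gpu", 1)
profile = ResourceProfileBuilder().require(task_reqs).build  # .build is a property

# Classic mode attaches it with worker_rdd.withResources(profile); Spark
# Connect mode passes it as mapInPandas(..., profile=profile), as shown above.
```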
@@ -117,8 +149,12 @@ def test_fn(x):
             assert len(os.listdir(tmp_dir)) == 0
 
 
-@unittest.skipIf(Version(pyspark.__version__).release < (3, 4, 0),
-                 "Resource group is only supported since spark 3.4.0")
+@unittest.skipIf(
+    (not is_spark_connect_mode and Version(pyspark.__version__).release < (3, 4, 0)) or
+    (is_spark_connect_mode and Version(pyspark.__version__).major < 4),
+    "Resource group is only supported since Spark 3.4.0 for legacy Spark mode or "
+    "since Spark 4 for Spark Connect mode."
+)
 class TestGPUSparkCluster(unittest.TestCase):
     @classmethod
     def setup_class(cls):
@@ -126,20 +162,35 @@ def setup_class(cls):
             os.path.dirname(os.path.abspath(__file__)), "discover_2_gpu.sh"
         )
-        cls.spark = (
-            SparkSession.builder.master("local-cluster[1, 2, 1024]")
+        spark_session_builder = (
+            SparkSession.builder
             .config("spark.task.cpus", "1")
             .config("spark.task.resource.gpu.amount", "1")
             .config("spark.executor.cores", "2")
             .config("spark.worker.resource.gpu.amount", "2")
             .config("spark.executor.resource.gpu.amount", "2")
             .config("spark.task.maxFailures", "1")
             .config(
                 "spark.worker.resource.gpu.discoveryScript", gpu_discovery_script_path
             )
-            .getOrCreate()
         )
+        if is_spark_connect_mode:
+            _logger.info("Test with spark connect mode.")
+            cls.spark = (
+                spark_session_builder.config(
+                    "spark.jars.packages", spark_connect_jar
+                )
+                .remote("local-cluster[1, 2, 1024]")  # Adjust the remote address if necessary
+                .appName("Test")
+                .getOrCreate()
+            )
+        else:
+            cls.spark = (
+                spark_session_builder.master("local-cluster[1, 2, 1024]")
+                .getOrCreate()
+            )
+
     @classmethod
     def teardown_class(cls):
         cls.spark.stop()
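End to end, what these suites exercise is the standard joblib-spark flow, which is unchanged for users: register the backend once, then run joblib (or scikit-learn) work inside a `parallel_backend("spark")` block against either a classic session or a Spark Connect one. The snippet below mirrors the project README and assumes scikit-learn is installed.

```python
# Sketch only: distribute scikit-learn cross-validation over Spark.
from sklearn import datasets, svm
from sklearn.model_selection import cross_val_score
from sklearn.utils import parallel_backend
from joblibspark import register_spark

register_spark()  # make the "spark" joblib backend available

iris = datasets.load_iris()
clf = svm.SVC(kernel="linear", C=1)
with parallel_backend("spark", n_jobs=3):
    scores = cross_val_score(clf, iris.data, iris.target, cv=5)
print(scores)
```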