Support spark connect #62

Merged: 13 commits, Apr 7, 2025
33 changes: 10 additions & 23 deletions .github/workflows/main.yml
@@ -7,28 +7,15 @@ jobs:
fail-fast: false
matrix:
PYTHON_VERSION: ["3.10"]
JOBLIB_VERSION: ["1.2.0", "1.3.0"]
PIN_MODE: [false, true]
PYSPARK_VERSION: ["3.0.3", "3.1.3", "3.2.3", "3.3.2", "3.4.0"]
include:
- PYSPARK_VERSION: "3.5.1"
PYTHON_VERSION: "3.11"
JOBLIB_VERSION: "1.3.0"
- PYSPARK_VERSION: "3.5.1"
PYTHON_VERSION: "3.11"
JOBLIB_VERSION: "1.4.2"
- PYSPARK_VERSION: "3.5.1"
PYTHON_VERSION: "3.12"
JOBLIB_VERSION: "1.3.0"
- PYSPARK_VERSION: "3.5.1"
PYTHON_VERSION: "3.12"
JOBLIB_VERSION: "1.4.2"
JOBLIB_VERSION: ["1.3.2", "1.4.2"]
PYSPARK_VERSION: ["3.4.4", "3.5.5", "4.0.0.dev2"]
SPARK_CONNECT_MODE: [false, true]
exclude:
- PYSPARK_VERSION: "3.0.3"
PIN_MODE: true
- PYSPARK_VERSION: "3.1.3"
PIN_MODE: true
name: Run test on pyspark ${{ matrix.PYSPARK_VERSION }}, pin_mode ${{ matrix.PIN_MODE }}, python ${{ matrix.PYTHON_VERSION }}, joblib ${{ matrix.JOBLIB_VERSION }}
- PYSPARK_VERSION: "3.4.4"
SPARK_CONNECT_MODE: true
- PYSPARK_VERSION: "3.5.5"
SPARK_CONNECT_MODE: true
name: Run test on pyspark ${{ matrix.PYSPARK_VERSION }}, Use Spark Connect ${{ matrix.SPARK_CONNECT_MODE }}, joblib ${{ matrix.JOBLIB_VERSION }}
steps:
- uses: actions/checkout@v3
- name: Setup python ${{ matrix.PYTHON_VERSION }}
@@ -38,10 +25,10 @@ jobs:
architecture: x64
- name: Install python packages
run: |
pip install joblib==${{ matrix.JOBLIB_VERSION }} scikit-learn>=0.23.1 pytest pylint pyspark==${{ matrix.PYSPARK_VERSION }}
pip install joblib==${{ matrix.JOBLIB_VERSION }} scikit-learn>=0.23.1 pytest pylint "pyspark[connect]==${{ matrix.PYSPARK_VERSION }}" pandas
- name: Run pylint
run: |
./run-pylint.sh
- name: Run test suites
run: |
PYSPARK_PIN_THREAD=${{ matrix.PIN_MODE }} ./run-tests.sh
TEST_SPARK_CONNECT=${{ matrix.SPARK_CONNECT_MODE }} PYSPARK_VERSION=${{ matrix.PYSPARK_VERSION }} ./run-tests.sh
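
The reworked matrix installs the `pyspark[connect]` extra and drives `run-tests.sh` through `TEST_SPARK_CONNECT`. For orientation, here is a minimal sketch of what a Spark Connect session looks like on the client side; the `sc://localhost:15002` URL and the `start-connect-server.sh` launcher are assumptions about a local test setup, not something this workflow pins down:

```python
# Minimal sketch, assuming a local Spark Connect server is already listening on
# the default port 15002 (e.g. started with sbin/start-connect-server.sh).
from pyspark.sql import SparkSession

# builder.remote() is the Spark Connect entry point (available since PySpark 3.4);
# the pyspark[connect] extra installed above pulls in its gRPC client dependencies.
spark = SparkSession.builder.remote("sc://localhost:15002").getOrCreate()

# Work runs on the Connect server; only the collected rows travel back here.
print(spark.range(3).collect())
```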
227 changes: 179 additions & 48 deletions joblibspark/backend.py
@@ -18,6 +18,7 @@
The joblib spark backend implementation.
"""
import atexit
import logging
import warnings
from multiprocessing.pool import ThreadPool
import uuid
@@ -47,6 +48,9 @@
from .utils import create_resource_profile, get_spark_session


_logger = logging.getLogger("joblibspark.backend")


def register():
"""
Register joblib spark backend.
@@ -62,6 +66,20 @@ def register():
register_parallel_backend('spark', SparkDistributedBackend)


def is_spark_connect_mode():
"""
Check if running with spark connect mode.
"""
try:
from pyspark.sql.utils import is_remote # pylint: disable=C0415
return is_remote()
except ImportError:
return False


_DEFAULT_N_JOBS_IN_SPARK_CONNECT_MODE = 64


# pylint: disable=too-many-instance-attributes
class SparkDistributedBackend(ParallelBackendBase, AutoBatchingMixin):
"""A ParallelBackend which will execute all batches on spark.
@@ -82,27 +100,40 @@ def __init__(self,
self._pool = None
self._n_jobs = None
self._spark = get_spark_session()
self._spark_context = self._spark.sparkContext
self._job_group = "joblib-spark-job-group-" + str(uuid.uuid4())
self._spark_pinned_threads_enabled = isinstance(
self._spark_context._gateway, ClientServer
)
self._spark_supports_job_cancelling = (
self._spark_pinned_threads_enabled
or hasattr(self._spark_context.parallelize([1]), "collectWithJobGroup")
)
self._is_running = False
try:
from IPython import get_ipython # pylint: disable=import-outside-toplevel
self._ipython = get_ipython()
except ImportError:
self._ipython = None

self._support_stage_scheduling = self._is_support_stage_scheduling()
self._is_spark_connect_mode = is_spark_connect_mode()
if self._is_spark_connect_mode:
if Version(pyspark.__version__).major < 4:
raise RuntimeError(
"Joblib spark does not support Spark Connect with PySpark version < 4."
)
self._support_stage_scheduling = True
self._spark_supports_job_cancelling = True
else:
self._spark_context = self._spark.sparkContext
self._spark_pinned_threads_enabled = isinstance(
self._spark_context._gateway, ClientServer
)
self._spark_supports_job_cancelling = (
self._spark_pinned_threads_enabled
or hasattr(self._spark_context.parallelize([1]), "collectWithJobGroup")
)
self._support_stage_scheduling = self._is_support_stage_scheduling()

self._resource_profile = self._create_resource_profile(num_cpus_per_spark_task,
num_gpus_per_spark_task)

def _is_support_stage_scheduling(self):
if self._is_spark_connect_mode:
return Version(pyspark.__version__).major >= 4

spark_master = self._spark_context.master
is_spark_local_mode = spark_master == "local" or spark_master.startswith("local[")
if is_spark_local_mode:
@@ -135,26 +166,53 @@ def _create_resource_profile(self,
def _cancel_all_jobs(self):
self._is_running = False
if not self._spark_supports_job_cancelling:
# Note: There's bug existing in `sparkContext.cancelJobGroup`.
# See https://issues.apache.org/jira/browse/SPARK-31549
warnings.warn("For spark version < 3, pyspark cancelling job API has bugs, "
"so we could not terminate running spark jobs correctly. "
"See https://issues.apache.org/jira/browse/SPARK-31549 for reference.")
if self._is_spark_connect_mode:
warnings.warn("Spark connect does not support job cancellation API "
"for Spark version < 3.5")
else:
# Note: There's bug existing in `sparkContext.cancelJobGroup`.
# See https://issues.apache.org/jira/browse/SPARK-31549
warnings.warn("For spark version < 3, pyspark cancelling job API has bugs, "
"so we could not terminate running spark jobs correctly. "
"See https://issues.apache.org/jira/browse/SPARK-31549 for "
"reference.")
else:
self._spark.sparkContext.cancelJobGroup(self._job_group)
if self._is_spark_connect_mode:
self._spark.interruptTag(self._job_group)
else:
self._spark.sparkContext.cancelJobGroup(self._job_group)

def effective_n_jobs(self, n_jobs):
max_num_concurrent_tasks = self._get_max_num_concurrent_tasks()
if n_jobs is None:
n_jobs = 1
elif n_jobs == -1:
# n_jobs=-1 means requesting all available workers
n_jobs = max_num_concurrent_tasks
if n_jobs > max_num_concurrent_tasks:
warnings.warn(f"User-specified n_jobs ({n_jobs}) is greater than the max number of "
f"concurrent tasks ({max_num_concurrent_tasks}) this cluster can run now."
"If dynamic allocation is enabled for the cluster, you might see more "
"executors allocated.")

if self._is_spark_connect_mode:
if n_jobs == 1:
warnings.warn(
"The maximum number of concurrently running jobs is set to 1, "
"to increase concurrency, you need to set joblib spark backend "
"'n_jobs' param to a greater number."
)

if n_jobs == -1:
n_jobs = _DEFAULT_N_JOBS_IN_SPARK_CONNECT_MODE
# pylint: disable = logging-fstring-interpolation
_logger.warning(
"Joblib sets `n_jobs` to default value "
f"{_DEFAULT_N_JOBS_IN_SPARK_CONNECT_MODE} in Spark Connect mode."
)
else:
max_num_concurrent_tasks = self._get_max_num_concurrent_tasks()
if n_jobs == -1:
# n_jobs=-1 means requesting all available workers
n_jobs = max_num_concurrent_tasks
if n_jobs > max_num_concurrent_tasks:
warnings.warn(
f"User-specified n_jobs ({n_jobs}) is greater than the max number of "
f"concurrent tasks ({max_num_concurrent_tasks}) this cluster can run now."
"If dynamic allocation is enabled for the cluster, you might see more "
"executors allocated."
)
return n_jobs

def _get_max_num_concurrent_tasks(self):
Expand Down Expand Up @@ -213,38 +271,111 @@ def run_on_worker_and_fetch_result():
raise RuntimeError('The task is canceled due to ipython command canceled.')

# TODO: handle possible spark exception here. # pylint: disable=fixme
worker_rdd = self._spark.sparkContext.parallelize([0], 1)
if self._resource_profile:
worker_rdd = worker_rdd.withResources(self._resource_profile)
def mapper_fn(_):
return cloudpickle.dumps(func())
if self._spark_supports_job_cancelling:
if self._spark_pinned_threads_enabled:
self._spark.sparkContext.setLocalProperty(
"spark.jobGroup.id",
self._job_group
)
self._spark.sparkContext.setLocalProperty(
"spark.job.description",
"joblib spark jobs"
)
rdd = worker_rdd.map(mapper_fn)
ser_res = rdd.collect()[0]
if self._is_spark_connect_mode:
spark_df = self._spark.range(1, numPartitions=1)

def mapper_fn(iterator):
import pandas as pd # pylint: disable=import-outside-toplevel
for _ in iterator: # consume input data.
pass

result = cloudpickle.dumps(func())
yield pd.DataFrame({"result": [result]})

if self._spark_supports_job_cancelling:
self._spark.addTag(self._job_group)

if self._support_stage_scheduling:
collected = spark_df.mapInPandas(
mapper_fn,
schema="result binary",
profile=self._resource_profile,
).collect()
else:
rdd = worker_rdd.map(mapper_fn)
ser_res = rdd.collectWithJobGroup(
self._job_group,
"joblib spark jobs"
)[0]
collected = spark_df.mapInPandas(
mapper_fn,
schema="result binary",
).collect()

ser_res = bytes(collected[0].result)
else:
rdd = worker_rdd.map(mapper_fn)
ser_res = rdd.collect()[0]
worker_rdd = self._spark.sparkContext.parallelize([0], 1)
if self._resource_profile:
worker_rdd = worker_rdd.withResources(self._resource_profile)

def mapper_fn(_):
return cloudpickle.dumps(func())

if self._spark_supports_job_cancelling:
if self._spark_pinned_threads_enabled:
self._spark.sparkContext.setLocalProperty(
"spark.jobGroup.id",
self._job_group
)
self._spark.sparkContext.setLocalProperty(
"spark.job.description",
"joblib spark jobs"
)
rdd = worker_rdd.map(mapper_fn)
ser_res = rdd.collect()[0]
else:
rdd = worker_rdd.map(mapper_fn)
ser_res = rdd.collectWithJobGroup(
self._job_group,
"joblib spark jobs"
)[0]
else:
rdd = worker_rdd.map(mapper_fn)
ser_res = rdd.collect()[0]

return cloudpickle.loads(ser_res)

try:
# pylint: disable=no-name-in-module,import-outside-toplevel
from pyspark import inheritable_thread_target

if Version(pyspark.__version__).major >= 4 and is_spark_connect_mode():
# pylint: disable=fixme
# TODO: remove this patch once Spark 4.0.0 is released.
# the patch is for propagating the Spark session to current thread.
def patched_inheritable_thread_target(f): # pylint: disable=invalid-name
import functools
import copy
from typing import Any

session = f
assert session is not None, "Spark Connect session must be provided."

def outer(ff: Any) -> Any: # pylint: disable=invalid-name
session_client_thread_local_attrs = [
# type: ignore[union-attr]
(attr, copy.deepcopy(value))
for (
attr,
value,
) in session.client.thread_local.__dict__.items()
]

@functools.wraps(ff)
def inner(*args: Any, **kwargs: Any) -> Any:
# Propagates the active spark session to the current thread
from pyspark.sql.connect.session import SparkSession as SCS

# pylint: disable=protected-access,no-member
SCS._set_default_and_active_session(session)

# Set thread locals in child thread.
for attr, value in session_client_thread_local_attrs:
# type: ignore[union-attr]
setattr(session.client.thread_local, attr, value)
return ff(*args, **kwargs)

return inner

return outer

inheritable_thread_target = patched_inheritable_thread_target(self._spark)

run_on_worker_and_fetch_result = \
inheritable_thread_target(run_on_worker_and_fetch_result)
except ImportError:
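
To put the backend changes in context, here is a minimal usage sketch in the style of the project README. It assumes `register_spark` is the public wrapper around `backend.register()`; the scikit-learn estimator and `n_jobs=3` are illustrative choices. The same code path works whether the active session is backed by a classic SparkContext or by Spark Connect:

```python
# Usage sketch; register_spark and the 'spark' backend name follow the existing
# joblibspark API, the estimator and n_jobs are arbitrary examples.
from joblib import parallel_backend
from sklearn import datasets, svm
from sklearn.model_selection import cross_val_score
from joblibspark import register_spark

register_spark()  # registers the SparkDistributedBackend shown in this diff

iris = datasets.load_iris()
clf = svm.SVC(kernel="linear", C=1)

# Each batch is shipped as a one-partition Spark job: an RDD job on a classic
# session, a mapInPandas job when running against Spark Connect.
with parallel_backend("spark", n_jobs=3):
    scores = cross_val_score(clf, iris.data, iris.target, cv=5)

print(scores)
```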
2 changes: 2 additions & 0 deletions requirements.txt
@@ -1,2 +1,4 @@
joblib>=0.14
packaging
pandas

1 change: 0 additions & 1 deletion run-pylint.sh
@@ -22,4 +22,3 @@ set -e

# Run pylint
python -m pylint joblibspark
