Ray Job work using Kuberay Python Client #858

Open · wants to merge 4 commits into main
1,362 changes: 701 additions & 661 deletions poetry.lock

Large diffs are not rendered by default.

8 changes: 8 additions & 0 deletions pyproject.toml
@@ -29,6 +29,14 @@ cryptography = "43.0.3"
executing = "1.2.0"
pydantic = "< 2"
ipywidgets = "8.1.2"
odh-kuberay-client = {version = "0.0.0.dev40", source = "testpypi"}

[[tool.poetry.source]]
name = "pypi"

[[tool.poetry.source]]
name = "testpypi"
url = "https://test.pypi.org/simple/"

[tool.poetry.group.docs]
optional = true
1 change: 1 addition & 0 deletions src/codeflare_sdk/__init__.py
@@ -10,6 +10,7 @@
AWManager,
AppWrapperStatus,
RayJobClient,
RayJob,
)

from .common.widgets import view_clusters
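With this change, RayJob is exported from the package root alongside the existing RayJobClient. A minimal sketch of the new import path:

from codeflare_sdk import RayJobClient  # existing export, unchanged
from codeflare_sdk import RayJob        # new export added by this PR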
4 changes: 4 additions & 0 deletions src/codeflare_sdk/ray/__init__.py
@@ -4,6 +4,10 @@
RayJobClient,
)

from .rayjobs import (
RayJob,
)

from .cluster import (
Cluster,
ClusterConfiguration,
2 changes: 2 additions & 0 deletions src/codeflare_sdk/ray/cluster/build_ray_cluster.py
@@ -136,6 +136,7 @@ def build_ray_cluster(cluster: "codeflare_sdk.ray.cluster.Cluster"):
"enableIngress": False,
"rayStartParams": {
"dashboard-host": "0.0.0.0",
"dashboard-port": "8265",
"block": "true",
"num-gpus": str(head_gpu_count),
"resources": head_resources,
@@ -245,6 +246,7 @@ def get_labels(cluster: "codeflare_sdk.ray.cluster.Cluster"):
"""
labels = {
"controller-tools.k8s.io": "1.0",
"ray.io/cluster": cluster.config.name, # Enforced label always present
}
if cluster.config.labels != {}:
labels.update(cluster.config.labels)
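Because the ray.io/cluster label is now always present on the RayCluster metadata, it can serve as a stable selector. A minimal sketch of querying RayCluster resources by that label with the Kubernetes Python client (the cluster name "raytest" and namespace "default" are illustrative, not part of this diff):

from kubernetes import client, config

config.load_kube_config()  # assumes a local kubeconfig; use load_incluster_config() when running in-cluster
custom_api = client.CustomObjectsApi()

# Select RayCluster custom resources by the label that get_labels() now always sets.
rayclusters = custom_api.list_namespaced_custom_object(
    group="ray.io",
    version="v1",
    namespace="default",
    plural="rayclusters",
    label_selector="ray.io/cluster=raytest",
)
for rc in rayclusters.get("items", []):
    print(rc["metadata"]["name"])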
310 changes: 307 additions & 3 deletions src/codeflare_sdk/ray/cluster/cluster.py
@@ -20,8 +20,14 @@

from time import sleep
from typing import List, Optional, Tuple, Dict
import copy

from ray.job_submission import JobSubmissionClient
from ray.job_submission import JobSubmissionClient, JobStatus
import time
import uuid
import warnings

from ..job.job import RayJobSpec

from ...common.kubernetes_cluster.auth import (
config_check,
@@ -57,7 +63,6 @@
from kubernetes.client.rest import ApiException

from kubernetes.client.rest import ApiException
import warnings

CF_SDK_FIELD_MANAGER = "codeflare-sdk"

@@ -604,6 +609,298 @@ def _component_resources_down(
yamls = yaml.safe_load_all(self.resource_yaml)
_delete_resources(yamls, namespace, api_instance, cluster_name)

@staticmethod
def run_job_with_managed_cluster(
cluster_config: ClusterConfiguration,
job_config: RayJobSpec,
job_cr_name: Optional[str] = None,
submission_mode: str = "K8sJobMode",
shutdown_after_job_finishes: bool = True,
ttl_seconds_after_finished: Optional[int] = None,
suspend_rayjob_creation: bool = False,
wait_for_completion: bool = True,
job_timeout_seconds: Optional[int] = 3600,
job_polling_interval_seconds: int = 10,
):
"""
Manages the lifecycle of a Ray cluster and a job by creating a RayJob custom resource.
The KubeRay operator then creates and deletes the RayCluster based on the RayJob definition.

This method will:
1. Generate a RayCluster specification from the provided 'cluster_config'.
2. Construct a RayJob custom resource definition using 'job_config' and embedding the RayCluster spec.
3. Create the RayJob resource in Kubernetes.
4. Optionally, wait for the RayJob to complete or time out, monitoring its status.
5. The RayCluster lifecycle (creation and deletion) is managed by KubeRay
based on the RayJob's 'shutdownAfterJobFinishes' field.

Args:
cluster_config: Configuration for the Ray cluster to be created by RayJob.
job_config: RayJobSpec object containing job-specific details like entrypoint, runtime_env, etc.
job_cr_name: Name for the RayJob Custom Resource. If None, a unique name is generated.
submission_mode: How the job is submitted (e.g. "K8sJobMode" or "HTTPMode").
shutdown_after_job_finishes: If True, RayCluster is deleted after job finishes.
ttl_seconds_after_finished: TTL for RayJob after it's finished.
suspend_rayjob_creation: If True, creates the RayJob in a suspended state.
wait_for_completion: If True, waits for the job to finish.
job_timeout_seconds: Timeout for waiting for job completion.
job_polling_interval_seconds: Interval for polling job status.

Returns:
A dictionary containing details like RayJob CR name, reported job submission ID,
final job status, dashboard URL, and the RayCluster name.

Raises:
TimeoutError: If the job doesn't complete within the specified timeout.
ApiException: For Kubernetes API errors.
ValueError: For configuration issues.
"""
config_check()
k8s_co_api = k8s_client.CustomObjectsApi(get_api_client())
namespace = cluster_config.namespace

if not job_config.entrypoint:
raise ValueError("job_config.entrypoint must be specified.")

# Warn if Ray Job Submission Client-specific fields in RayJobSpec are set, as they are not used for the RayJob CR.
if (
job_config.entrypoint_num_cpus is not None
or job_config.entrypoint_num_gpus is not None
or job_config.entrypoint_memory is not None
):
warnings.warn(
"RayJobSpec fields 'entrypoint_num_cpus', 'entrypoint_num_gpus', 'entrypoint_memory' "
"are not directly used when creating a RayJob CR. They are primarily for the Ray Job Submission Client. "
"Resource requests for the job driver pod should be configured in the RayCluster head node spec via ClusterConfiguration.",
UserWarning,
)

# Generate rayClusterSpec from ClusterConfiguration
temp_config_for_spec = copy.deepcopy(cluster_config)
temp_config_for_spec.appwrapper = False

with warnings.catch_warnings():
warnings.simplefilter("ignore", UserWarning)
dummy_cluster_for_spec = Cluster(temp_config_for_spec)

ray_cluster_cr_dict = dummy_cluster_for_spec.resource_yaml
if (
not isinstance(ray_cluster_cr_dict, dict)
or "spec" not in ray_cluster_cr_dict
):
raise ValueError(
"Failed to generate RayCluster CR dictionary from ClusterConfiguration. "
f"Got: {type(ray_cluster_cr_dict)}"
)
ray_cluster_spec = ray_cluster_cr_dict["spec"]

# Prepare RayJob CR
actual_job_cr_name = job_cr_name or f"rayjob-{uuid.uuid4().hex[:10]}"

runtime_env_yaml_str = ""
if job_config.runtime_env:
try:
runtime_env_yaml_str = yaml.dump(job_config.runtime_env)
except yaml.YAMLError as e:
raise ValueError(
f"Invalid job_config.runtime_env, failed to dump to YAML: {e}"
)

ray_job_cr_spec = {
"entrypoint": job_config.entrypoint,
"shutdownAfterJobFinishes": shutdown_after_job_finishes,
"rayClusterSpec": ray_cluster_spec,
"submissionMode": submission_mode,
}

if runtime_env_yaml_str:
ray_job_cr_spec["runtimeEnvYAML"] = runtime_env_yaml_str
if job_config.submission_id:
ray_job_cr_spec["jobId"] = job_config.submission_id
if job_config.metadata:
ray_job_cr_spec["metadata"] = job_config.metadata
if ttl_seconds_after_finished is not None:
ray_job_cr_spec["ttlSecondsAfterFinished"] = ttl_seconds_after_finished
if suspend_rayjob_creation:
ray_job_cr_spec["suspend"] = True
if job_config.entrypoint_resources:
ray_job_cr_spec["entrypointResources"] = job_config.entrypoint_resources

ray_job_cr = {
"apiVersion": "ray.io/v1",
"kind": "RayJob",
"metadata": {
"name": actual_job_cr_name,
"namespace": namespace,
},
"spec": ray_job_cr_spec,
}

returned_job_submission_id = None
final_job_status = "UNKNOWN"
dashboard_url = None
ray_cluster_name_actual = None

try:
print(
f"Submitting RayJob '{actual_job_cr_name}' to namespace '{namespace}'..."
)
k8s_co_api.create_namespaced_custom_object(
group="ray.io",
version="v1",
namespace=namespace,
plural="rayjobs",
body=ray_job_cr,
)
print(f"RayJob '{actual_job_cr_name}' created successfully.")

if wait_for_completion:
print(f"Waiting for RayJob '{actual_job_cr_name}' to complete...")
start_time = time.time()
while True:
try:
ray_job_status_cr = (
k8s_co_api.get_namespaced_custom_object_status(
group="ray.io",
version="v1",
namespace=namespace,
plural="rayjobs",
name=actual_job_cr_name,
)
)
except ApiException as e:
if e.status == 404:
print(
f"RayJob '{actual_job_cr_name}' status not found yet, retrying..."
)
time.sleep(job_polling_interval_seconds)
continue
raise

status_field = ray_job_status_cr.get("status", {})
job_deployment_status = status_field.get(
"jobDeploymentStatus", "UNKNOWN"
)
current_job_status = status_field.get("jobStatus", "PENDING")

dashboard_url = status_field.get("dashboardURL", dashboard_url)
ray_cluster_name_actual = status_field.get(
"rayClusterName", ray_cluster_name_actual
)
returned_job_submission_id = status_field.get(
"jobId", job_config.submission_id
)

final_job_status = current_job_status
print(
f"RayJob '{actual_job_cr_name}' status: JobDeployment='{job_deployment_status}', Job='{current_job_status}'"
)

if current_job_status in ["SUCCEEDED", "FAILED", "STOPPED"]:
break

if (
job_timeout_seconds
and (time.time() - start_time) > job_timeout_seconds
):
try:
ray_job_status_cr_final = (
k8s_co_api.get_namespaced_custom_object_status(
group="ray.io",
version="v1",
namespace=namespace,
plural="rayjobs",
name=actual_job_cr_name,
)
)
status_field_final = ray_job_status_cr_final.get(
"status", {}
)
final_job_status = status_field_final.get(
"jobStatus", final_job_status
)
returned_job_submission_id = status_field_final.get(
"jobId", returned_job_submission_id
)
dashboard_url = status_field_final.get(
"dashboardURL", dashboard_url
)
ray_cluster_name_actual = status_field_final.get(
"rayClusterName", ray_cluster_name_actual
)
except Exception:
pass
raise TimeoutError(
f"RayJob '{actual_job_cr_name}' timed out after {job_timeout_seconds} seconds. Last status: {final_job_status}"
)

time.sleep(job_polling_interval_seconds)

print(
f"RayJob '{actual_job_cr_name}' finished with status: {final_job_status}"
)
else:
try:
ray_job_status_cr = k8s_co_api.get_namespaced_custom_object_status(
group="ray.io",
version="v1",
namespace=namespace,
plural="rayjobs",
name=actual_job_cr_name,
)
status_field = ray_job_status_cr.get("status", {})
final_job_status = status_field.get("jobStatus", "SUBMITTED")
returned_job_submission_id = status_field.get(
"jobId", job_config.submission_id
)
dashboard_url = status_field.get("dashboardURL", dashboard_url)
ray_cluster_name_actual = status_field.get(
"rayClusterName", ray_cluster_name_actual
)
except ApiException as e:
if e.status == 404:
final_job_status = "SUBMITTED_NOT_FOUND"
else:
print(
f"Warning: Could not fetch initial status for RayJob '{actual_job_cr_name}': {e}"
)
final_job_status = "UNKNOWN_API_ERROR"

return {
"job_cr_name": actual_job_cr_name,
"job_submission_id": returned_job_submission_id,
"job_status": final_job_status,
"dashboard_url": dashboard_url,
"ray_cluster_name": ray_cluster_name_actual,
}

except ApiException as e:
print(
f"Kubernetes API error during RayJob '{actual_job_cr_name}' management: {e.reason} (status: {e.status})"
)
final_status_on_error = "ERROR_BEFORE_SUBMISSION"
if actual_job_cr_name:
try:
ray_job_status_cr = k8s_co_api.get_namespaced_custom_object_status(
group="ray.io",
version="v1",
namespace=namespace,
plural="rayjobs",
name=actual_job_cr_name,
)
status_field = ray_job_status_cr.get("status", {})
final_status_on_error = status_field.get(
"jobStatus", "UNKNOWN_AFTER_K8S_ERROR"
)
except Exception:
final_status_on_error = "UNKNOWN_FINAL_STATUS_FETCH_FAILED"
raise
except Exception as e:
print(
f"An unexpected error occurred during managed RayJob execution for '{actual_job_cr_name}': {e}"
)
raise


def list_all_clusters(namespace: str, print_to_console: bool = True):
"""
@@ -760,14 +1057,21 @@ def get_cluster(
head_extended_resource_requests=head_extended_resources,
worker_extended_resource_requests=worker_extended_resources,
)
# 1. Prepare RayClusterSpec from ClusterConfiguration
# Create a temporary config with appwrapper=False to ensure build_ray_cluster returns RayCluster YAML
temp_cluster_config_dict = cluster_config.dict(
exclude_none=True
) # Assuming Pydantic V1 or similar .dict() method
temp_cluster_config_dict["appwrapper"] = False
temp_cluster_config_for_spec = ClusterConfiguration(**temp_cluster_config_dict)
# Ignore the warning here for the lack of a ClusterConfiguration
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore",
message="Please provide a ClusterConfiguration to initialise the Cluster object",
)
cluster = Cluster(None)
cluster.config = cluster_config
cluster.config = temp_cluster_config_for_spec

# Remove auto-generated fields like creationTimestamp, uid and etc.
remove_autogenerated_fields(resource)
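For context, a minimal usage sketch of the new Cluster.run_job_with_managed_cluster static method, based on the signature above. The names, namespace, and entrypoint are illustrative, and RayJobSpec is assumed to accept entrypoint and runtime_env as constructor arguments (its import path follows the relative import added in this diff):

from codeflare_sdk import Cluster, ClusterConfiguration
from codeflare_sdk.ray.job.job import RayJobSpec  # path per 'from ..job.job import RayJobSpec' above

# Shape of the ephemeral RayCluster that KubeRay will create for the RayJob.
cluster_config = ClusterConfiguration(
    name="rayjob-demo",
    namespace="default",
    num_workers=1,
)

# Job definition; entrypoint is the only field required by run_job_with_managed_cluster.
job_config = RayJobSpec(
    entrypoint="python -c 'import ray; ray.init(); print(ray.cluster_resources())'",
    runtime_env={"pip": ["numpy"]},
)

result = Cluster.run_job_with_managed_cluster(
    cluster_config=cluster_config,
    job_config=job_config,
    shutdown_after_job_finishes=True,  # KubeRay tears the RayCluster down after the job
    wait_for_completion=True,
    job_timeout_seconds=1800,
)
print(result["job_status"], result["dashboard_url"])

The returned dictionary carries the RayJob CR name, submission ID, final job status, dashboard URL, and RayCluster name, as documented in the Returns section of the docstring above.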