Skip to content

Commit d89ee5a

Browse files
RHOAIENG-8098 - ClusterConfiguration should support tolerations
1 parent 051ee3c commit d89ee5a

File tree

5 files changed

+62
-6
lines changed

5 files changed

+62
-6
lines changed

src/codeflare_sdk/common/utils/unit_test_support.py

+11
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import yaml
2323
from pathlib import Path
2424
from kubernetes import client
25+
from kubernetes.client import V1Toleration
2526
from unittest.mock import patch
2627

2728
parent = Path(__file__).resolve().parents[4] # project directory
@@ -427,8 +428,18 @@ def create_cluster_all_config_params(mocker, cluster_name, is_appwrapper) -> Clu
427428
head_memory_requests=12,
428429
head_memory_limits=16,
429430
head_extended_resource_requests={"nvidia.com/gpu": 1, "intel.com/gpu": 2},
431+
head_tolerations=[
432+
V1Toleration(
433+
key="key1", operator="Equal", value="value1", effect="NoSchedule"
434+
)
435+
],
430436
worker_cpu_requests=4,
431437
worker_cpu_limits=8,
438+
worker_tolerations=[
439+
V1Toleration(
440+
key="key2", operator="Equal", value="value2", effect="NoSchedule"
441+
)
442+
],
432443
num_workers=10,
433444
worker_memory_requests=12,
434445
worker_memory_limits=16,

src/codeflare_sdk/ray/cluster/build_ray_cluster.py

+20-4
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
This sub-module exists primarily to be used internally by the Cluster object
1717
(in the cluster sub-module) for RayCluster/AppWrapper generation.
1818
"""
19-
from typing import Union, Tuple, Dict
19+
from typing import List, Union, Tuple, Dict
2020
from ...common import _kube_api_error_handling
2121
from ...common.kubernetes_cluster import get_api_client, config_check
2222
from kubernetes.client.exceptions import ApiException
@@ -40,6 +40,7 @@
4040
V1PodTemplateSpec,
4141
V1PodSpec,
4242
V1LocalObjectReference,
43+
V1Toleration,
4344
)
4445

4546
import yaml
@@ -139,7 +140,11 @@ def build_ray_cluster(cluster: "codeflare_sdk.ray.cluster.Cluster"):
139140
"resources": head_resources,
140141
},
141142
"template": {
142-
"spec": get_pod_spec(cluster, [get_head_container_spec(cluster)])
143+
"spec": get_pod_spec(
144+
cluster,
145+
[get_head_container_spec(cluster)],
146+
cluster.config.head_tolerations,
147+
)
143148
},
144149
},
145150
"workerGroupSpecs": [
@@ -154,7 +159,11 @@ def build_ray_cluster(cluster: "codeflare_sdk.ray.cluster.Cluster"):
154159
"resources": worker_resources,
155160
},
156161
"template": V1PodTemplateSpec(
157-
spec=get_pod_spec(cluster, [get_worker_container_spec(cluster)])
162+
spec=get_pod_spec(
163+
cluster,
164+
[get_worker_container_spec(cluster)],
165+
cluster.config.worker_tolerations,
166+
)
158167
),
159168
}
160169
],
@@ -243,14 +252,21 @@ def update_image(image) -> str:
243252
return image
244253

245254

246-
def get_pod_spec(cluster: "codeflare_sdk.ray.cluster.Cluster", containers):
255+
def get_pod_spec(
256+
cluster: "codeflare_sdk.ray.cluster.Cluster",
257+
containers: List,
258+
tolerations: List[V1Toleration],
259+
) -> V1PodSpec:
247260
"""
248261
The get_pod_spec() function generates a V1PodSpec for the head/worker containers
249262
"""
263+
250264
pod_spec = V1PodSpec(
251265
containers=containers,
252266
volumes=generate_custom_storage(cluster.config.volumes, VOLUMES),
267+
tolerations=tolerations or None,
253268
)
269+
254270
if cluster.config.image_pull_secrets != []:
255271
pod_spec.image_pull_secrets = generate_image_pull_secrets(cluster)
256272

src/codeflare_sdk/ray/cluster/config.py

+11-2
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
import warnings
2323
from dataclasses import dataclass, field, fields
2424
from typing import Dict, List, Optional, Union, get_args, get_origin
25-
from kubernetes.client import V1Volume, V1VolumeMount
25+
from kubernetes.client import V1Toleration, V1Volume, V1VolumeMount
2626

2727
dir = pathlib.Path(__file__).parent.parent.resolve()
2828

@@ -58,6 +58,8 @@ class ClusterConfiguration:
5858
The number of GPUs to allocate to the head node. (Deprecated, use head_extended_resource_requests)
5959
head_extended_resource_requests:
6060
A dictionary of extended resource requests for the head node. ex: {"nvidia.com/gpu": 1}
61+
head_tolerations:
62+
List of tolerations for head nodes.
6163
min_cpus:
6264
The minimum number of CPUs to allocate to each worker.
6365
max_cpus:
@@ -70,6 +72,8 @@ class ClusterConfiguration:
7072
The maximum amount of memory to allocate to each worker.
7173
num_gpus:
7274
The number of GPUs to allocate to each worker. (Deprecated, use worker_extended_resource_requests)
75+
worker_tolerations:
76+
List of tolerations for worker nodes.
7377
appwrapper:
7478
A boolean indicating whether to use an AppWrapper.
7579
envs:
@@ -110,6 +114,7 @@ class ClusterConfiguration:
110114
head_extended_resource_requests: Dict[str, Union[str, int]] = field(
111115
default_factory=dict
112116
)
117+
head_tolerations: Optional[List[V1Toleration]] = None
113118
worker_cpu_requests: Union[int, str] = 1
114119
worker_cpu_limits: Union[int, str] = 1
115120
min_cpus: Optional[Union[int, str]] = None # Deprecating
@@ -120,6 +125,7 @@ class ClusterConfiguration:
120125
min_memory: Optional[Union[int, str]] = None # Deprecating
121126
max_memory: Optional[Union[int, str]] = None # Deprecating
122127
num_gpus: Optional[int] = None # Deprecating
128+
worker_tolerations: Optional[List[V1Toleration]] = None
123129
appwrapper: bool = False
124130
envs: Dict[str, str] = field(default_factory=dict)
125131
image: str = ""
@@ -272,7 +278,10 @@ def check_type(value, expected_type):
272278
if origin_type is Union:
273279
return any(check_type(value, union_type) for union_type in args)
274280
if origin_type is list:
275-
return all(check_type(elem, args[0]) for elem in value)
281+
if value is not None:
282+
return all(check_type(elem, args[0]) for elem in (value or []))
283+
else:
284+
return True
276285
if origin_type is dict:
277286
return all(
278287
check_type(k, args[0]) and check_type(v, args[1])

tests/test_cluster_yamls/appwrapper/unit-test-all-params.yaml

+10
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,11 @@ spec:
9999
imagePullSecrets:
100100
- name: secret1
101101
- name: secret2
102+
tolerations:
103+
- effect: NoSchedule
104+
key: key1
105+
operator: Equal
106+
value: value1
102107
volumes:
103108
- emptyDir:
104109
sizeLimit: 500Gi
@@ -185,6 +190,11 @@ spec:
185190
imagePullSecrets:
186191
- name: secret1
187192
- name: secret2
193+
tolerations:
194+
- effect: NoSchedule
195+
key: key2
196+
operator: Equal
197+
value: value2
188198
volumes:
189199
- emptyDir:
190200
sizeLimit: 500Gi

tests/test_cluster_yamls/ray/unit-test-all-params.yaml

+10
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,11 @@ spec:
9090
imagePullSecrets:
9191
- name: secret1
9292
- name: secret2
93+
tolerations:
94+
- effect: NoSchedule
95+
key: key1
96+
operator: Equal
97+
value: value1
9398
volumes:
9499
- emptyDir:
95100
sizeLimit: 500Gi
@@ -176,6 +181,11 @@ spec:
176181
imagePullSecrets:
177182
- name: secret1
178183
- name: secret2
184+
tolerations:
185+
- effect: NoSchedule
186+
key: key2
187+
operator: Equal
188+
value: value2
179189
volumes:
180190
- emptyDir:
181191
sizeLimit: 500Gi

0 commit comments

Comments
 (0)