Skip to content

Commit fddc1f8

Browse files
committed
Add multi-node test for symmetric_run.
Signed-off-by: Ricardo Decal <[email protected]>
1 parent e829eaf commit fddc1f8

File tree

2 files changed

+102
-2
lines changed

2 files changed

+102
-2
lines changed

python/ray/scripts/symmetric_run.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
1-
"""Symmetric Run for Ray."""
1+
"""Symmetric Run for Ray.
2+
3+
This script launches a Ray cluster across all nodes and executes a specified entrypoint command.
4+
It is useful in environments where the same command is executed on every node, such as with SLURM.
5+
"""
26

37
import socket
48
import subprocess

python/ray/tests/test_symmetric_run.py

Lines changed: 97 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
import sys
21
from contextlib import contextmanager
2+
import sys
33
from unittest.mock import MagicMock, patch
44

55
import pytest
@@ -167,5 +167,101 @@ def test_symmetric_run_arg_validation(monkeypatch, cleanup_ray):
167167
assert "--num-cpus=4" in ray_start_calls[0][0][0]
168168

169169

170+
def test_symmetric_run_multi_node(monkeypatch, cleanup_ray):
    """Test symmetric_run with a simulated 3-node (1 head + 2 workers) cluster.

    The CLI is invoked three times with identical arguments while
    ``subprocess.run`` is mocked, so every command the script would spawn is
    recorded instead of executed. The recorded calls are then checked for
    exactly one head ``ray start``, two worker ``ray start`` calls, three
    ``ray stop`` calls, and the ``echo ok`` entrypoint.
    """
    from ray.scripts.symmetric_run import symmetric_run

    runner = CliRunner()
    # Non-loopback IP because of multi-node.
    head_ip = "10.0.0.1"
    head_port = "6379"
    address = f"{head_ip}:{head_port}"
    worker_ips = ("10.0.0.2", "10.0.0.3")

    common_args = ["--address", address, "--min-nodes", "3", "--", "echo", "ok"]
    patched_argv = ["ray.scripts.symmetric_run", *common_args]

    def invoke_cli(node_label):
        """Invoke the CLI once with ``sys.argv`` patched; require exit code 0."""
        with patch("sys.argv", patched_argv):
            result = runner.invoke(symmetric_run, common_args)
        # Surface the captured CLI output so a failure is diagnosable.
        assert (
            result.exit_code == 0
        ), f"{node_label} exited with {result.exit_code}: {result.output}"
        return result

    with patch("subprocess.run") as mock_run, patch(
        "ray.scripts.symmetric_run.check_ray_already_started", return_value=False
    ):
        # Make subprocess.run succeed by default.
        mock_run.return_value.returncode = 0

        # ---- Head node ----
        # If IP == resolved_gcs_host, then is_head == True.
        with _setup_mock_network_utils(curr_ip=head_ip, head_ip=head_ip):
            # The head waits for --min-nodes, so mock success.
            with patch(
                "ray.scripts.symmetric_run.check_cluster_ready", return_value=True
            ) as mock_ready:
                invoke_cli("head node")
                # Ensure the head path waited for 3 nodes.
                mock_ready.assert_called_once()
                args_called, _ = mock_ready.call_args
                assert args_called[0] == 3  # nnodes

        # ---- Worker nodes ----
        # NOTE(review): the keyword names mirror the original calls; here
        # `head_ip=` carries the worker's own address while `curr_ip=` carries
        # the head address -- confirm against `_setup_mock_network_utils`.
        for index, worker_ip in enumerate(worker_ips, start=1):
            with _setup_mock_network_utils(curr_ip=head_ip, head_ip=worker_ip):
                with patch(
                    "ray.scripts.symmetric_run.check_head_node_ready",
                    return_value=True,
                ):
                    invoke_cli(f"worker node {index}")

    # The mock keeps its recorded calls after the patch context has exited.
    calls_str = [str(c) for c in mock_run.call_args_list]
    start_calls = [s for s in calls_str if "ray" in s and "start" in s]
    stop_calls = [s for s in calls_str if "ray" in s and "stop" in s]

    assert len(start_calls) == 3, f"Expected 3 ray start calls, got: {start_calls}"
    assert len(stop_calls) == 3, f"Expected 3 ray stop calls, got: {stop_calls}"

    head_starts = [s for s in start_calls if "--head" in s]
    worker_starts = [s for s in start_calls if "--address" in s and "--block" in s]

    assert (
        len(head_starts) == 1
    ), f"Expected exactly 1 head start, got: {head_starts}"
    assert (
        len(worker_starts) == 2
    ), f"Expected exactly 2 worker starts, got: {worker_starts}"

    # Validate head flags
    head_call = head_starts[0]
    assert f"--node-ip-address={head_ip}" in head_call
    assert f"--port={head_port}" in head_call

    # Validate worker flags. "--address" and "--block" are already guaranteed
    # by the filter that built `worker_starts`, so only the address value
    # itself still needs checking: workers must connect to the same head
    # address we passed on the CLI.
    for worker_call in worker_starts:
        assert address in worker_call

    # Validate that the entrypoint (the `echo ok` command) was invoked at
    # least once.
    non_ray_calls = [
        s for s in calls_str if not ("ray" in s and ("start" in s or "stop" in s))
    ]
    assert any(
        "['echo', 'ok']" in s for s in non_ray_calls
    ), f"Entrypoint command was not found in: {non_ray_calls}"
264+
265+
170266
if __name__ == "__main__":
    # Allow running this test module directly: hand the file to pytest and
    # propagate pytest's exit status to the shell.
    pytest_status = pytest.main(["-sv", __file__])
    sys.exit(pytest_status)

0 commit comments

Comments
 (0)