|
1 |
| -import sys |
2 | 1 | from contextlib import contextmanager
|
| 2 | +import sys |
3 | 3 | from unittest.mock import MagicMock, patch
|
4 | 4 |
|
5 | 5 | import pytest
|
@@ -167,5 +167,101 @@ def test_symmetric_run_arg_validation(monkeypatch, cleanup_ray):
|
167 | 167 | assert "--num-cpus=4" in ray_start_calls[0][0][0]
|
168 | 168 |
|
169 | 169 |
|
| 170 | +def test_symmetric_run_multi_node(monkeypatch, cleanup_ray): |
| 171 | + """ |
| 172 | + Test symmetric_run with a simulated 3-node (1 head + 2 workers) cluster. |
| 173 | + """ |
| 174 | + from ray.scripts.symmetric_run import symmetric_run |
| 175 | + |
| 176 | + runner = CliRunner() |
| 177 | + # Non-loopback IP because of multi-node. |
| 178 | + head_ip = "10.0.0.1" |
| 179 | + head_port = "6379" |
| 180 | + address = f"{head_ip}:{head_port}" |
| 181 | + |
| 182 | + common_args = ["--address", address, "--min-nodes", "3", "--", "echo", "ok"] |
| 183 | + |
| 184 | + with patch("subprocess.run") as mock_run, patch( |
| 185 | + "ray.scripts.symmetric_run.check_ray_already_started", return_value=False |
| 186 | + ): |
| 187 | + |
| 188 | + # Make subprocess.run succeed by default. |
| 189 | + mock_run.return_value.returncode = 0 |
| 190 | + |
| 191 | + # ---- Head node ---- |
| 192 | + # If IP == resolved_gcs_host, then is_head == True. |
| 193 | + with _setup_mock_network_utils(curr_ip=head_ip, head_ip=head_ip): |
| 194 | + # The head waits for --min-nodes, so mock success. |
| 195 | + with patch( |
| 196 | + "ray.scripts.symmetric_run.check_cluster_ready", return_value=True |
| 197 | + ) as mock_ready: |
| 198 | + with patch("sys.argv", ["ray.scripts.symmetric_run", *common_args]): |
| 199 | + result_head = runner.invoke(symmetric_run, common_args) |
| 200 | + assert result_head.exit_code == 0 |
| 201 | + # Ensure the head path waited for 3 nodes. |
| 202 | + mock_ready.assert_called_once() |
| 203 | + args_called, _kwargs_called = mock_ready.call_args |
| 204 | + assert args_called[0] == 3 # nnodes |
| 205 | + |
| 206 | + # ---- Worker node 1 ---- |
| 207 | + with _setup_mock_network_utils(curr_ip=head_ip, head_ip="10.0.0.2"): |
| 208 | + with patch( |
| 209 | + "ray.scripts.symmetric_run.check_head_node_ready", return_value=True |
| 210 | + ): |
| 211 | + with patch("sys.argv", ["ray.scripts.symmetric_run", *common_args]): |
| 212 | + result_w1 = runner.invoke(symmetric_run, common_args) |
| 213 | + assert result_w1.exit_code == 0 |
| 214 | + |
| 215 | + # ---- Worker node 2 ---- |
| 216 | + with _setup_mock_network_utils(curr_ip=head_ip, head_ip="10.0.0.3"): |
| 217 | + with patch( |
| 218 | + "ray.scripts.symmetric_run.check_head_node_ready", return_value=True |
| 219 | + ): |
| 220 | + with patch("sys.argv", ["ray.scripts.symmetric_run", *common_args]): |
| 221 | + result_w2 = runner.invoke(symmetric_run, common_args) |
| 222 | + assert result_w2.exit_code == 0 |
| 223 | + |
| 224 | + calls = mock_run.call_args_list |
| 225 | + |
| 226 | + calls_str = [str(c) for c in calls] |
| 227 | + start_calls = [s for s in calls_str if "ray" in s and "start" in s] |
| 228 | + stop_calls = [s for s in calls_str if "ray" in s and "stop" in s] |
| 229 | + |
| 230 | + assert len(start_calls) == 3, f"Expected 3 ray start calls, got: {start_calls}" |
| 231 | + assert len(stop_calls) == 3, f"Expected 3 ray stop calls, got: {stop_calls}" |
| 232 | + |
| 233 | + head_starts = [s for s in start_calls if "--head" in s] |
| 234 | + worker_starts = [s for s in start_calls if "--address" in s and "--block" in s] |
| 235 | + |
| 236 | + assert ( |
| 237 | + len(head_starts) == 1 |
| 238 | + ), f"Expected exactly 1 head start, got: {head_starts}" |
| 239 | + assert ( |
| 240 | + len(worker_starts) == 2 |
| 241 | + ), f"Expected exactly 2 worker starts, got: {worker_starts}" |
| 242 | + |
| 243 | + # Validate head flags |
| 244 | + head_call = head_starts[0] |
| 245 | + assert f"--node-ip-address={head_ip}" in head_call |
| 246 | + assert f"--port={head_port}" in head_call |
| 247 | + |
| 248 | + # Validate worker flags |
| 249 | + for s in worker_starts: |
| 250 | + # Must connect to the same head address we passed on the CLI. |
| 251 | + # "ray start --address <address> --block ..." |
| 252 | + assert "--address" in s |
| 253 | + assert address in s |
| 254 | + assert "--block" in s |
| 255 | + |
| 256 | + # Validate that the entrypoint was invoked once on the head (the |
| 257 | + # `echo ok` command). |
| 258 | + non_ray_calls = [ |
| 259 | + s for s in calls_str if not ("ray" in s and ("start" in s or "stop" in s)) |
| 260 | + ] |
| 261 | + assert any( |
| 262 | + "['echo', 'ok']" in s for s in non_ray_calls |
| 263 | + ), f"Entrypoint command was not found in: {non_ray_calls}" |
| 264 | + |
| 265 | + |
170 | 266 | if __name__ == "__main__":
|
171 | 267 | sys.exit(pytest.main(["-sv", __file__]))
|
0 commit comments