Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion nemo_curator/core/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ def _parse_slurm_nodelist(nodelist: str) -> list[str]:
nodes.append(f"{prefix}{str(n).zfill(width)}")
else:
nodes.append(f"{prefix}{part}")
return nodes if nodes else [nodelist]
return nodes or [nodelist]


# --------------------------------------------------------------------------- #
Expand Down Expand Up @@ -519,6 +519,11 @@ def _run_as_worker(self, head_ip: str) -> int:
if self.num_cpus is not None:
cmd.extend(["--num-cpus", str(self.num_cpus)])

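# Xenna manages GPU assignment itself, so
# RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES must be set before the Ray
# worker process starts; otherwise Ray may mask CUDA_VISIBLE_DEVICES before
# Xenna can assign GPUs on worker nodes. setdefault keeps a value that is
# already set in the environment.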
os.environ.setdefault("RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES", "1")

logger.info(f"Ray worker starting: {' '.join(cmd)}")
result = subprocess.run(cmd, check=False) # noqa: S603
logger.info(f"Ray worker exited with code {result.returncode}")
Expand Down
4 changes: 4 additions & 0 deletions nemo_curator/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,10 @@ def init_cluster( # noqa: PLR0913

# We set some env vars for Xenna here. This is only used for Xenna clusters.
os.environ["XENNA_RAY_METRICS_PORT"] = str(ray_metrics_port)
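# Xenna manages GPU assignment itself, so
# RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES must be set before the Ray
# cluster is started; otherwise Ray workers may mask CUDA_VISIBLE_DEVICES
# before Xenna can assign GPUs. setdefault keeps a value that is already set
# in the environment.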
os.environ.setdefault("RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES", "1")

# Opt into Ray Serve's HAProxy ingress when both binaries resolve. Ray Serve
# uses socat to drive HAProxy's admin socket — without it, the controller's
Expand Down
108 changes: 108 additions & 0 deletions tests/core/test_ray_cluster_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import TYPE_CHECKING

from nemo_curator.core import client as core_client
from nemo_curator.core import utils as core_utils

if TYPE_CHECKING:
from pathlib import Path

import pytest


def _mock_init_cluster_dependencies(monkeypatch: pytest.MonkeyPatch) -> list[list[str]]:
    """Patch the collaborators of ``core_utils`` so no real process is spawned.

    The following attributes are replaced for the duration of the test:

    * ``core_utils.ray.util.register_serializer`` becomes a no-op.
    * ``core_utils.get_free_port`` returns the port it was asked for.
    * ``core_utils.shutil.which`` always returns ``None``.
    * ``core_utils.subprocess.Popen`` becomes a fake that only records its command.

    Args:
        monkeypatch: The pytest fixture used to install (and later undo) the patches.

    Returns:
        A list that receives every command passed to the patched ``Popen``,
        in call order.
    """
    recorded_cmds: list[list[str]] = []

    class FakePopen:
        """Stand-in for ``subprocess.Popen`` that stores the command and does nothing else."""

        def __init__(self, cmd: list[str], **_kwargs: object) -> None:
            recorded_cmds.append(cmd)

    # (owner object, attribute name, replacement), applied in this order.
    replacements = (
        (core_utils.ray.util, "register_serializer", lambda *_args, **_kwargs: None),
        (core_utils, "get_free_port", lambda port, **_kwargs: port),
        (core_utils.shutil, "which", lambda _cmd: None),
        (core_utils.subprocess, "Popen", FakePopen),
    )
    for owner, attr_name, stub in replacements:
        monkeypatch.setattr(owner, attr_name, stub)

    return recorded_cmds


def test_init_cluster_sets_xenna_gpu_env_before_ray_start(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
popen_calls = _mock_init_cluster_dependencies(monkeypatch)
monkeypatch.delenv("RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES", raising=False)

core_utils.init_cluster(
ray_port=6379,
ray_temp_dir=str(tmp_path),
ray_dashboard_port=8265,
ray_metrics_port=8080,
ray_client_server_port=10001,
ray_dashboard_host="127.0.0.1",
)

Comment on lines +45 to +54

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Missing required ray_dashboard_host argument

init_cluster declares ray_dashboard_host: str as a required positional parameter (no default value), but both test invocations omit it entirely. This causes an immediate TypeError: init_cluster() missing 1 required positional argument: 'ray_dashboard_host', meaning neither test can pass even when ray is installed. The same omission appears in the second test at lines 62–69.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks — addressed in 421e8c5. I added the missing ray_dashboard_host argument to both init_cluster test calls, and also covered/fixed the worker-node path by setting RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES before SlurmRayClient._run_as_worker launches ray start.

Validation:

  • uv run ruff check nemo_curator/core/client.py tests/core/test_ray_cluster_utils.py — passed
  • uv run pytest tests/core/test_ray_cluster_utils.py — could not be run locally: NeMo Curator raises on non-Linux hosts (this machine is darwin), and spoofing the platform as Linux then fails when Ray/psutil try to import their Linux-native extensions on macOS.

assert popen_calls
assert "--head" in popen_calls[0]
assert "--block" in popen_calls[0]
assert core_utils.os.environ["RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES"] == "1"


def test_init_cluster_preserves_existing_xenna_gpu_env(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
    """An already-set ``RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES`` must survive ``init_cluster``."""
    env_key = "RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES"
    preset_value = "0"

    _mock_init_cluster_dependencies(monkeypatch)
    # Simulate a value that was present in the environment before the call.
    monkeypatch.setenv(env_key, preset_value)

    cluster_kwargs = {
        "ray_port": 6379,
        "ray_temp_dir": str(tmp_path),
        "ray_dashboard_port": 8265,
        "ray_metrics_port": 8080,
        "ray_client_server_port": 10001,
        "ray_dashboard_host": "127.0.0.1",
    }
    core_utils.init_cluster(**cluster_kwargs)

    # The value must be unchanged after the call.
    assert core_utils.os.environ[env_key] == preset_value


def test_slurm_worker_sets_xenna_gpu_env_before_ray_start(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
    """``_run_as_worker`` must have the env var set to "1" by the time it calls ``subprocess.run``."""
    env_key = "RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES"
    launched: list[list[str]] = []

    class CompletedProcess:
        """Minimal result object exposing only the ``returncode`` attribute."""

        returncode = 0

    def fake_run(cmd: list[str], **_kwargs: object) -> object:
        launched.append(cmd)
        # Checked at launch time: the variable has to be in place before the
        # (patched) subprocess call, not merely after the method returns.
        assert core_client.os.environ[env_key] == "1"
        return CompletedProcess()

    monkeypatch.delenv(env_key, raising=False)
    monkeypatch.setattr(core_client, "_find_ray_binary", lambda: "ray")
    monkeypatch.setattr(core_client.subprocess, "run", fake_run)

    worker_client = core_client.SlurmRayClient(ray_port=6379, ray_temp_dir=str(tmp_path))
    exit_code = worker_client._run_as_worker("10.0.0.1")

    assert exit_code == 0
    assert launched
    expected_prefix = ["ray", "start", "--address", "10.0.0.1:6379"]
    assert launched[0][:4] == expected_prefix


def test_slurm_worker_preserves_existing_xenna_gpu_env(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
    """An already-set ``RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES`` must survive ``_run_as_worker``."""
    env_key = "RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES"
    preset_value = "0"

    class CompletedProcess:
        """Minimal result object exposing only the ``returncode`` attribute."""

        returncode = 0

    def fake_run(*_args: object, **_kwargs: object) -> object:
        return CompletedProcess()

    # Simulate a value that was present in the environment before the call.
    monkeypatch.setenv(env_key, preset_value)
    monkeypatch.setattr(core_client, "_find_ray_binary", lambda: "ray")
    monkeypatch.setattr(core_client.subprocess, "run", fake_run)

    worker_client = core_client.SlurmRayClient(ray_port=6379, ray_temp_dir=str(tmp_path))
    exit_code = worker_client._run_as_worker("10.0.0.1")

    assert exit_code == 0
    assert core_client.os.environ[env_key] == preset_value
4 changes: 2 additions & 2 deletions tutorials/quickstart.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@

from nemo_curator.backends.base import NodeInfo, WorkerMetadata
from nemo_curator.backends.xenna import XennaExecutor
from nemo_curator.core.client import RayClient
from nemo_curator.core.client import SlurmRayClient
from nemo_curator.pipeline import Pipeline
from nemo_curator.stages.base import ProcessingStage
from nemo_curator.stages.resources import Resources
Expand Down Expand Up @@ -217,7 +217,7 @@ def process(self, task: SampleTask) -> SampleTask:
def main() -> None:
"""Main function to run the pipeline."""
# Create pipeline
ray_client = RayClient()
ray_client = SlurmRayClient()
ray_client.start()
pipeline = Pipeline(name="sentiment_analysis", description="Analyze sentiment of sample sentences")

Expand Down