-
Notifications
You must be signed in to change notification settings - Fork 292
Add support for Slurm arrays #2059
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 25 commits
bde2217
a0595f6
43ee179
cae17b3
acfeceb
6eaf95e
bb1e30a
2ccbd3f
1b659ea
3522809
717edac
8f2345b
ebba73e
5e58793
ad8f68a
437270f
b55ec47
672f3d2
2814eec
4dacd31
d54d3a3
5e18093
94312c4
298a632
4d5d20a
e0a8af5
62a4493
11c74cc
67e0eef
3ee13ff
14ddc3c
f024975
ebc93dc
3fc2dbc
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,4 +1,4 @@ | ||
| # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. | ||
| # Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
|
|
@@ -12,19 +12,144 @@ | |
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| import datetime | ||
| import hashlib | ||
| import json | ||
| import os | ||
| import socket | ||
| import tempfile | ||
| import uuid | ||
| from abc import ABC, abstractmethod | ||
| from dataclasses import dataclass | ||
| from pathlib import Path | ||
| from typing import TYPE_CHECKING, Any | ||
|
|
||
| from loguru import logger | ||
|
|
||
| from nemo_curator.core.utils import ignore_ray_head_node | ||
| from nemo_curator.tasks import Task | ||
| from nemo_curator.tasks.sentinels import FailedTask, NoneTask | ||
| from nemo_curator.utils.performance_utils import StageTimer | ||
|
|
||
| if TYPE_CHECKING: | ||
| from nemo_curator.stages.base import ProcessingStage | ||
|
|
||
|
|
||
# Directory in which FailedTask marker files are written; when this variable is
# unset or empty, failed tasks are not recorded.
FAILED_TASKS_DIR_ENV_VAR = "NEMO_CURATOR_FAILED_TASKS_DIR"
# Opt-in switch for Slurm array source-task filtering (accepted values: _TRUE_ENV_VALUES).
SLURM_ARRAY_ENABLED_ENV_VAR = "NEMO_CURATOR_SLURM_ARRAY_ENABLED"
# Optional override for this task's shard index; SLURM_ARRAY_TASK_ID is consulted when unset.
SLURM_ARRAY_SHARD_INDEX_ENV_VAR = "NEMO_CURATOR_SLURM_ARRAY_SHARD_INDEX"
# Optional override for the number of shards; SLURM_ARRAY_TASK_COUNT is consulted when unset.
SLURM_ARRAY_TOTAL_SHARDS_ENV_VAR = "NEMO_CURATOR_SLURM_ARRAY_TOTAL_SHARDS"
# Optional lowest shard index that can be assigned; treated as 0 when unset.
SLURM_ARRAY_MINIMUM_SHARD_INDEX_ENV_VAR = "NEMO_CURATOR_SLURM_ARRAY_MINIMUM_SHARD_INDEX"

# Values (compared after strip() and lower()) that count as "enabled".
_TRUE_ENV_VALUES = {"1", "true", "yes", "on"}
|
|
||
|
|
||
| def _get_int_env_var(env_var: str, fallback_name: str | None = None, default: int | None = None) -> int: | ||
|
sarahyurick marked this conversation as resolved.
Outdated
|
||
| env_value = os.environ.get(env_var) | ||
| if env_value is None: | ||
| if fallback_name is not None: | ||
| env_var = fallback_name | ||
| env_value = os.environ.get(env_var) | ||
|
|
||
| if env_value is None: | ||
| if default is not None: | ||
| return default | ||
|
|
||
| msg = f"Environment variable {env_var} is not set" | ||
| raise ValueError(msg) | ||
|
|
||
| try: | ||
| return int(env_value) | ||
| except ValueError as e: | ||
| msg = f"Environment variable {env_var} must contain an integer, got {env_value!r}" | ||
| raise ValueError(msg) from e | ||
|
|
||
|
|
||
@dataclass
class SlurmArrayConfig:
    """Shard assignment parameters for a single Slurm array task.

    Source tasks are mapped onto shard indices; this object says which shard
    the current process owns and how many shards exist.
    """

    # Shard index owned by this process.
    shard_index: int
    # Number of shards the source tasks are divided into.
    total_shards: int
    # Lowest shard index that can be assigned.
    minimum_shard_index: int = 0

    @classmethod
    def from_env(cls) -> "SlurmArrayConfig | None":
        """Build the config from environment variables, or return None when disabled.

        Filtering is enabled only when ``SLURM_ARRAY_ENABLED_ENV_VAR`` holds one
        of ``_TRUE_ENV_VALUES`` (case-insensitive, surrounding whitespace ignored).
        The shard index and shard count fall back to ``SLURM_ARRAY_TASK_ID`` and
        ``SLURM_ARRAY_TASK_COUNT`` when their override variables are unset.
        """
        flag = os.environ.get(SLURM_ARRAY_ENABLED_ENV_VAR, "").strip().lower()
        if flag in _TRUE_ENV_VALUES:
            shard_index = _get_int_env_var(SLURM_ARRAY_SHARD_INDEX_ENV_VAR, "SLURM_ARRAY_TASK_ID")
            total_shards = _get_int_env_var(SLURM_ARRAY_TOTAL_SHARDS_ENV_VAR, "SLURM_ARRAY_TASK_COUNT")
            minimum_shard_index = _get_int_env_var(SLURM_ARRAY_MINIMUM_SHARD_INDEX_ENV_VAR, default=0)
            return cls(
                shard_index=shard_index,
                total_shards=total_shards,
                minimum_shard_index=minimum_shard_index,
            )

        return None
|
|
||
|
|
||
def _safe_filename_token(value: object) -> str:
    """Render ``value`` as text that can be embedded in a filename.

    The value is converted with ``str()``; every character that is neither
    alphanumeric nor one of ``.``, ``_``, ``-`` is replaced by ``_``.
    """
    allowed_punctuation = "._-"
    sanitized: list[str] = []
    for char in str(value):
        keep = char.isalnum() or char in allowed_punctuation
        sanitized.append(char if keep else "_")
    return "".join(sanitized)
|
|
||
|
|
||
def _fsync_directory(path: Path) -> None:
    """Open the directory at ``path`` read-only and call ``os.fsync`` on it.

    ``os.O_DIRECTORY`` is added to the open flags on platforms that define it.
    The descriptor is always closed, even when ``os.fsync`` raises.
    """
    # getattr with a 0 default leaves the flags unchanged where O_DIRECTORY is missing.
    open_flags = os.O_RDONLY | getattr(os, "O_DIRECTORY", 0)
    fd = os.open(path, open_flags)
    try:
        os.fsync(fd)
    finally:
        os.close(fd)
|
|
||
|
|
||
def _write_failed_task_marker(marker_dir: Path, stage_name: str, task: FailedTask) -> None:
    """Write one JSON marker file describing ``task`` into ``marker_dir``.

    The payload records the creation time (UTC), stage name, task id, dataset
    name, task type, hostname and pid. It is first written and fsynced to a
    temporary file inside ``marker_dir``, then moved to its final name with
    ``os.replace``, after which the directory itself is fsynced.

    Args:
        marker_dir: Directory that receives the marker; created if missing.
        stage_name: Name of the stage that produced the failed task.
        task: The failed task being recorded.

    Raises:
        Exception: Any error from writing, renaming or syncing is re-raised
            after the temporary file (if one was created) has been removed.
    """
    now = datetime.datetime.now(datetime.UTC)
    pid = os.getpid()
    record: dict[str, str | int] = {
        "created_at": now.isoformat(),
        "stage_name": stage_name,
        "task_id": task.task_id,
        "dataset_name": task.dataset_name,
        "task_type": type(task).__name__,
        "hostname": socket.gethostname(),
        "pid": pid,
    }

    marker_dir.mkdir(parents=True, exist_ok=True)
    # pid, microsecond timestamp and a random uuid are all part of the name.
    filename = "_".join(
        [
            "failed_task",
            f"stage-{_safe_filename_token(stage_name)}",
            f"task-{_safe_filename_token(task.task_id)}",
            f"pid-{pid}",
            now.strftime("%Y%m%dT%H%M%S%fZ"),
            f"{uuid.uuid4().hex}.json",
        ]
    )
    destination = marker_dir / filename

    staged: Path | None = None
    try:
        with tempfile.NamedTemporaryFile(
            mode="w",
            encoding="utf-8",
            dir=marker_dir,
            prefix=f".{filename}.",
            suffix=".tmp",
            delete=False,
        ) as handle:
            staged = Path(handle.name)
            handle.write(json.dumps(record, indent=2, sort_keys=True))
            handle.write("\n")
            # Push the content to disk before the rename makes it visible.
            handle.flush()
            os.fsync(handle.fileno())

        os.replace(staged, destination)
        _fsync_directory(marker_dir)
    except Exception:
        # Do not leave a stray temporary file behind; the unlink is a no-op
        # when the rename already happened.
        if staged is not None:
            staged.unlink(missing_ok=True)
        raise
|
|
||
|
|
||
| @dataclass | ||
| class NodeInfo: | ||
| """Generic node information for setup_on_node calls across backends. | ||
|
|
@@ -85,9 +210,22 @@ def process_batch(self, tasks: list[Task]) -> list[Task]: | |
| # Use the batch processing logic | ||
| results = self.stage.process_batch(tasks) | ||
|
|
||
| # A returned ``None`` ("filter this slot") becomes a NoneTask so every | ||
| # output is a real Task that gets a task_id. Sentinels (NoneTask / | ||
| # FailedTask) carry no identity and are stripped again before this | ||
| # method returns. | ||
| results = [NoneTask() if r is None else r for r in results] | ||
|
|
||
| # Guarantee every emitted task has a task_id (derived id, or uuid fallback). | ||
| results = self._post_process_task_ids(tasks, results) | ||
|
|
||
| self._record_failed_tasks([r for r in results if isinstance(r, FailedTask)]) | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Discussed with @abhinavg4 . For now the PR keeps track of I did the environment variable and write approach because it seems more reliable than trying to handle a global Python variable, etc. And the reason it is an environment variable is so that
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ok I think a lot of the util functions are coming because of this feature, and there might be an easier way for this. Continuing in DMs.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The functions are really only:
but I can move to util scripts if that makes it easier to read. |
||
|
|
||
| # Sentinels never propagate to the next stage. | ||
| results = [r for r in results if not isinstance(r, (NoneTask, FailedTask))] | ||
|
|
||
| results = self._filter_slurm_array_source_tasks(results) | ||
|
sarahyurick marked this conversation as resolved.
Outdated
|
||
|
|
||
| # Log performance stats and add to result tasks | ||
| _, stage_perf_stats = self._timer.log_stats() | ||
| # Consume and attach any custom metrics recorded by the stage during this call | ||
|
|
@@ -99,6 +237,78 @@ def process_batch(self, tasks: list[Task]) -> list[Task]: | |
|
|
||
| return results | ||
|
|
||
def _record_failed_tasks(self, failed_tasks: list[FailedTask]) -> None:
    """Write a marker file for each failed task, on a best-effort basis.

    Nothing happens when ``FAILED_TASKS_DIR_ENV_VAR`` is unset or empty, or
    when ``failed_tasks`` is empty. A failure to write one marker is logged as
    a warning and does not stop the remaining markers from being written.
    """
    configured_dir = os.environ.get(FAILED_TASKS_DIR_ENV_VAR)
    if not (configured_dir and failed_tasks):
        return

    marker_path = Path(configured_dir)
    for failed in failed_tasks:
        try:
            _write_failed_task_marker(marker_path, self.stage.name, failed)
        except Exception as e:  # noqa: BLE001
            # Recording is diagnostic only, so it must never fail the batch.
            logger.warning(f"Failed to write FailedTask marker to {marker_path}: {e}")
|
|
||
def _filter_slurm_array_source_tasks(self, tasks: list[Task]) -> list[Task]:
    """Return the subset of ``tasks`` that belongs to this Slurm array shard.

    When no Slurm array config applies, ``tasks`` is returned unchanged.
    Otherwise each task is mapped to a shard from its ``task_id`` and kept only
    if that shard equals the configured ``shard_index``.

    Raises:
        ValueError: If any task id starts with ``"r"``, which marks an id as
            non-deterministic and therefore unusable for sharding.
    """
    slurm_array = self._resolve_slurm_array_config()
    if slurm_array is None:
        return tasks

    nondeterministic_task_ids = []
    for task in tasks:
        if task.task_id.startswith("r"):
            nondeterministic_task_ids.append(task.task_id)
    if nondeterministic_task_ids:
        msg = (
            "Slurm array source filtering requires deterministic task IDs, but stage "
            f"{self.stage.name} emitted ambiguous source task IDs: {nondeterministic_task_ids[:5]}"
        )
        raise ValueError(msg)

    assigned_tasks = []
    for task in tasks:
        if self._slurm_array_shard_for_task(task, slurm_array) == slurm_array.shard_index:
            assigned_tasks.append(task)

    msg = (
        f"Slurm array shard {slurm_array.shard_index}/{slurm_array.total_shards}: "
        f"assigned {len(assigned_tasks)} of {len(tasks)} source tasks for stage {self.stage.name}"
    )
    # Getting nothing out of a non-empty batch is worth a warning; anything else is routine.
    if assigned_tasks or not tasks:
        logger.info(msg)
    else:
        logger.warning(msg)

    return assigned_tasks
|
|
||
def _resolve_slurm_array_config(self) -> SlurmArrayConfig | None:
    """Return the Slurm array config that applies to this stage, if any.

    Stages that are not source stages always get ``None``. For source stages
    the config is read from the environment on first use and the result
    (including ``None``) is cached on the instance as ``_resolved_slurm_array``.

    Raises:
        ValueError: If the configured ``total_shards`` is not positive.
    """
    if not getattr(self.stage, "is_source_stage", False):
        return None

    # Cached from an earlier call.
    if hasattr(self, "_resolved_slurm_array"):
        return self._resolved_slurm_array

    resolved = SlurmArrayConfig.from_env()
    if resolved is not None:
        if resolved.total_shards <= 0:
            msg = f"total_shards must be greater than 0, got {resolved.total_shards}"
            raise ValueError(msg)

        lowest = resolved.minimum_shard_index
        highest = lowest + resolved.total_shards - 1
        # An out-of-range index is allowed but only warned about.
        if resolved.shard_index < lowest or resolved.shard_index > highest:
            logger.warning(
                "shard_index={} is outside the assignable shard range [{}, {}]. "
                "This task will not receive any source tasks.",
                resolved.shard_index,
                lowest,
                highest,
            )

    self._resolved_slurm_array = resolved
    return resolved
|
|
||
def _slurm_array_shard_for_task(self, task: Task, slurm_array: SlurmArrayConfig) -> int:
    """Map ``task`` to a shard index derived from its ``task_id``.

    The first 8 bytes of the SHA-256 digest of the UTF-8 encoded ``task_id``
    are read as a big-endian integer, reduced modulo ``total_shards`` and
    offset by ``minimum_shard_index``.
    """
    id_hash = hashlib.sha256(task.task_id.encode("utf-8")).digest()
    bucket = int.from_bytes(id_hash[:8], "big") % slurm_array.total_shards
    return bucket + slurm_array.minimum_shard_index
|
|
||
| def _post_process_task_ids(self, input_tasks: list[Task], output_tasks: list[Task | None]) -> list[Task]: | ||
| """Assign a deterministic ``task_id`` to every emitted task. | ||
|
|
||
|
|
@@ -109,17 +319,17 @@ def _post_process_task_ids(self, input_tasks: list[Task], output_tasks: list[Tas | |
| re-derived at each stage boundary so the same object passing through | ||
| N stages gets N ids. | ||
|
|
||
| The input→output mapping decides each output's PARENT; whether the | ||
| The input -> output mapping decides each output's PARENT; whether the | ||
| stage is a source decides each output's SEGMENT (content id vs index) | ||
| — the two are independent. ``None`` outputs (Curator's "return None to | ||
| filter") are NOT removed before the length check — keeping them in | ||
| place preserves positional alignment for filter stages — and are then | ||
| dropped from the returned list. | ||
|
|
||
| - single input → every output is its child (fan-out): ``parent_<seg>`` | ||
| - ``len(output) == len(input)`` → positional 1:1: each ``parent_i_<seg>``; | ||
| - single input -> every output is its child (fan-out): ``parent_<seg>`` | ||
| - ``len(output) == len(input)`` -> positional 1:1: each ``parent_i_<seg>``; | ||
| a ``None`` slot just means input ``i`` was filtered. | ||
| - any other (ambiguous) cardinality across a batch → a random ``uuid`` | ||
| - any other (ambiguous) cardinality across a batch -> a random ``uuid`` | ||
| prefixed with ``"r"`` (e.g. ``"r3f9a…"``), so ``task_id`` is never | ||
| empty even when a derived id is not possible. The ``"r"`` prefix flags | ||
| the id as non-deterministic / ancestry-not-tracked (see | ||
|
|
@@ -128,13 +338,13 @@ def _post_process_task_ids(self, input_tasks: list[Task], output_tasks: list[Tas | |
| ``seg`` is the output's content id (``Task.get_deterministic_id()``) | ||
| for a source stage when available, else the positional index — so a | ||
| source partition keeps a stable id across reorderings regardless of | ||
| whether the source is 1→N or N→N. | ||
| whether the source is 1->N or N->N. | ||
|
|
||
| Note: a stage that BOTH filters and fans out within a single batch | ||
| (returning a flat list rather than a per-input slot) cannot be mapped | ||
| positionally; if its length happens to equal the input length the 1:1 | ||
| assumption may misattribute parents. That combination is unsupported | ||
| until per-slot sentinels (NoneTask/FailedTask) land in a later PR. | ||
| unless the stage preserves an unambiguous input -> output mapping. | ||
| """ | ||
| is_source = getattr(self.stage, "is_source_stage", False) | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we need all of these variables, or can they be inferred automatically? I think the initial design that @praateekmahajan had in mind was that we just add `--array=1-100` to the Slurm submit command and everything else works out of the box. Currently, it seems like the user has to do a bit more than that?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So the idea is that we want to give the user full control if they need to override the total number of shards (needed for reruns) or the minimum shard index (needed to get around any Slurm array size limits). To enable Slurm array partitioning, the only thing that must be set explicitly is:
`NEMO_CURATOR_SLURM_ARRAY_ENABLED=1`
and the shard index and shard count are then picked up automatically from Slurm's `SLURM_ARRAY_TASK_ID` and `SLURM_ARRAY_TASK_COUNT`. The user can then override them with
`NEMO_CURATOR_SLURM_ARRAY_SHARD_INDEX`, etc. as desired (but not required).