Skip to content
Open
Show file tree
Hide file tree
Changes from 26 commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
bde2217
basic slurm array file partitioning
sarahyurick Jun 9, 2026
a0595f6
add slurm array params to composite stages using filepartitioningstage
sarahyurick Jun 9, 2026
43ee179
add tutorial and tests
sarahyurick Jun 11, 2026
cae17b3
Merge branch 'main' into slurm_array
sarahyurick Jun 11, 2026
acfeceb
ruff
sarahyurick Jun 11, 2026
6eaf95e
address greptile reviews
sarahyurick Jun 11, 2026
bb1e30a
ruff
sarahyurick Jun 11, 2026
2ccbd3f
more greptile comments
sarahyurick Jun 11, 2026
1b659ea
add nonetask and failedtask sentinels
sarahyurick Jun 11, 2026
3522809
add failedtask detection and repeat
sarahyurick Jun 11, 2026
717edac
ruff
sarahyurick Jun 11, 2026
8f2345b
Merge branch 'main' into slurm_array
sarahyurick Jun 11, 2026
ebba73e
greptile comments
sarahyurick Jun 11, 2026
5e58793
Merge branch 'main' into slurm_array
sarahyurick Jun 15, 2026
ad8f68a
TextSemanticDeduplicationWorkflow revert
sarahyurick Jun 16, 2026
437270f
Merge branch 'main' into slurm_array
sarahyurick Jun 16, 2026
b55ec47
Merge branch 'main' into slurm_array
sarahyurick Jun 16, 2026
672f3d2
use SlurmArrayConfig dataclass
sarahyurick Jun 16, 2026
2814eec
use base stage adapter and source stage instead of file partitioning …
sarahyurick Jun 22, 2026
4dacd31
formatting
sarahyurick Jun 22, 2026
d54d3a3
ruff
sarahyurick Jun 22, 2026
5e18093
greptile feedback
sarahyurick Jun 22, 2026
94312c4
update tutorial
sarahyurick Jun 22, 2026
298a632
Merge branch 'main' into slurm_array
sarahyurick Jun 22, 2026
4d5d20a
Merge branch 'main' into slurm_array
sarahyurick Jun 23, 2026
e0a8af5
save bulk updates
sarahyurick Jun 24, 2026
62a4493
ruff
sarahyurick Jun 24, 2026
11c74cc
add greptile suggestion
sarahyurick Jun 24, 2026
67e0eef
ruff
sarahyurick Jun 24, 2026
3ee13ff
Merge branch 'main' into slurm_array
sarahyurick Jun 24, 2026
14ddc3c
Merge branch 'main' into slurm_array
sarahyurick Jun 25, 2026
f024975
use login node to submit retry jobs
sarahyurick Jun 25, 2026
ebc93dc
Merge branch 'main' into slurm_array
sarahyurick Jun 25, 2026
3fc2dbc
Merge branch 'main' into slurm_array
sarahyurick Jun 25, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 34 additions & 7 deletions nemo_curator/backends/base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -17,8 +17,15 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any

from nemo_curator.backends.failed_task_markers import record_failed_tasks
from nemo_curator.backends.slurm_array import (
filter_slurm_array_source_tasks,
raise_for_failed_source_tasks_with_slurm_array,
resolve_slurm_array_config,
)
from nemo_curator.core.utils import ignore_ray_head_node
from nemo_curator.tasks import Task
from nemo_curator.tasks.sentinels import FailedTask, NoneTask
from nemo_curator.utils.performance_utils import StageTimer

if TYPE_CHECKING:
Expand Down Expand Up @@ -85,9 +92,29 @@ def process_batch(self, tasks: list[Task]) -> list[Task]:
# Use the batch processing logic
results = self.stage.process_batch(tasks)

# A returned ``None`` ("filter this slot") becomes a NoneTask so every
# output is a real Task that gets a task_id. Sentinels (NoneTask /
# FailedTask) carry no identity and are stripped again before this
# method returns.
results = [NoneTask() if r is None else r for r in results]

# Guarantee every emitted task has a task_id (derived id, or uuid fallback).
results = self._post_process_task_ids(tasks, results)

# Failed tasks on the source stage are not supported when Slurm array filtering is enabled.
slurm_array = resolve_slurm_array_config(is_source_stage=getattr(self.stage, "is_source_stage", False))
failed_tasks = [r for r in results if isinstance(r, FailedTask)]
raise_for_failed_source_tasks_with_slurm_array(self.stage.name, failed_tasks, slurm_array)

# Record failed tasks for later inspection or retry bookkeeping.
record_failed_tasks(self.stage.name, failed_tasks)

# Sentinels never propagate to the next stage.
results = [r for r in results if not isinstance(r, (NoneTask, FailedTask))]

# Filter tasks based on the Slurm array configuration.
results = filter_slurm_array_source_tasks(results, slurm_array, self.stage.name)

# Log performance stats and add to result tasks
_, stage_perf_stats = self._timer.log_stats()
# Consume and attach any custom metrics recorded by the stage during this call
Expand All @@ -109,17 +136,17 @@ def _post_process_task_ids(self, input_tasks: list[Task], output_tasks: list[Tas
re-derived at each stage boundary so the same object passing through
N stages gets N ids.

The inputoutput mapping decides each output's PARENT; whether the
The input -> output mapping decides each output's PARENT; whether the
stage is a source decides each output's SEGMENT (content id vs index)
— the two are independent. ``None`` outputs (Curator's "return None to
filter") are NOT removed before the length check — keeping them in
place preserves positional alignment for filter stages — and are then
dropped from the returned list.

- single input every output is its child (fan-out): ``parent_<seg>``
- ``len(output) == len(input)`` positional 1:1: each ``parent_i_<seg>``;
- single input -> every output is its child (fan-out): ``parent_<seg>``
- ``len(output) == len(input)`` -> positional 1:1: each ``parent_i_<seg>``;
a ``None`` slot just means input ``i`` was filtered.
- any other (ambiguous) cardinality across a batch a random ``uuid``
- any other (ambiguous) cardinality across a batch -> a random ``uuid``
prefixed with ``"r"`` (e.g. ``"r3f9a…"``), so ``task_id`` is never
empty even when a derived id is not possible. The ``"r"`` prefix flags
the id as non-deterministic / ancestry-not-tracked (see
Expand All @@ -128,13 +155,13 @@ def _post_process_task_ids(self, input_tasks: list[Task], output_tasks: list[Tas
``seg`` is the output's content id (``Task.get_deterministic_id()``)
for a source stage when available, else the positional index — so a
source partition keeps a stable id across reorderings regardless of
whether the source is 1N or NN.
whether the source is 1->N or N->N.

Note: a stage that BOTH filters and fans out within a single batch
(returning a flat list rather than a per-input slot) cannot be mapped
positionally; if its length happens to equal the input length the 1:1
assumption may misattribute parents. That combination is unsupported
until per-slot sentinels (NoneTask/FailedTask) land in a later PR.
unless the stage preserves an unambiguous input -> output mapping.
"""
is_source = getattr(self.stage, "is_source_stage", False)

Expand Down
81 changes: 81 additions & 0 deletions nemo_curator/backends/failed_task_markers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import hashlib
import os
from pathlib import Path

from loguru import logger

from nemo_curator.tasks.sentinels import FailedTask
from nemo_curator.utils.atomic_io import write_json_atomically

Check failure on line 22 in nemo_curator/backends/failed_task_markers.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (I001)

nemo_curator/backends/failed_task_markers.py:15:1: I001 Import block is un-sorted or un-formatted

Check failure on line 22 in nemo_curator/backends/failed_task_markers.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (I001)

nemo_curator/backends/failed_task_markers.py:15:1: I001 Import block is un-sorted or un-formatted


FAILED_TASKS_DIR_ENV_VAR = "NEMO_CURATOR_FAILED_TASKS_DIR"
FAILED_TASK_MARKER_PATTERN = "failed_task_*.json"


def _write_failed_task_marker(marker_dir: Path, stage_name: str, task: FailedTask) -> None:
"""Write one compact marker for a failed stage/task pair."""
payload = {
"stage_name": stage_name,
"task_id": task.task_id,
}

marker_identity = f"{stage_name}\0{task.task_id}".encode("utf-8")

Check failure on line 36 in nemo_curator/backends/failed_task_markers.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (UP012)

nemo_curator/backends/failed_task_markers.py:36:23: UP012 Unnecessary UTF-8 `encoding` argument to `encode`

Check failure on line 36 in nemo_curator/backends/failed_task_markers.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (UP012)

nemo_curator/backends/failed_task_markers.py:36:23: UP012 Unnecessary UTF-8 `encoding` argument to `encode`
marker_digest = hashlib.sha256(marker_identity).hexdigest()[:16]
filename = f"failed_task_{marker_digest}.json"
final_path = marker_dir / filename
write_json_atomically(final_path, payload, separators=(",", ":"), sort_keys=True)


def record_failed_tasks(stage_name: str, failed_tasks: list[FailedTask]) -> None:
"""Record FailedTask markers when ``NEMO_CURATOR_FAILED_TASKS_DIR`` is set."""
marker_dir = os.environ.get(FAILED_TASKS_DIR_ENV_VAR)
if not marker_dir or not failed_tasks:
return

marker_path = Path(marker_dir)
for task in failed_tasks:
try:
_write_failed_task_marker(marker_path, stage_name, task)
except Exception as e: # noqa: BLE001
logger.warning(f"Failed to write FailedTask marker to {marker_path}: {e}")


def summarize_failed_task_markers(
marker_dir: str | Path | None = None,
) -> dict[str, object]:
"""Count FailedTask markers in ``marker_dir`` or the configured env dir."""
resolved_marker_dir = marker_dir if marker_dir is not None else os.environ.get(FAILED_TASKS_DIR_ENV_VAR)
if not resolved_marker_dir:
return {
"failed_task_marker_count": 0,
}

marker_path = Path(resolved_marker_dir).absolute()
if not marker_path.exists():
return {
"failed_task_marker_count": 0,
}

marker_count = 0
for path in marker_path.glob(FAILED_TASK_MARKER_PATTERN):
if not path.is_file():
continue
marker_count += 1

return {
"failed_task_marker_count": marker_count,
}
198 changes: 198 additions & 0 deletions nemo_curator/backends/slurm_array.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,198 @@
# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import hashlib
import os
from dataclasses import dataclass

from loguru import logger

from nemo_curator.tasks import Task
from nemo_curator.tasks.sentinels import FailedTask
from nemo_curator.utils.retry_manifest import RetryManifest

Check failure on line 23 in nemo_curator/backends/slurm_array.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (I001)

nemo_curator/backends/slurm_array.py:15:1: I001 Import block is un-sorted or un-formatted

Check failure on line 23 in nemo_curator/backends/slurm_array.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (I001)

nemo_curator/backends/slurm_array.py:15:1: I001 Import block is un-sorted or un-formatted


SLURM_ARRAY_ENABLED_ENV_VAR = "NEMO_CURATOR_SLURM_ARRAY_ENABLED"
SLURM_ARRAY_SHARD_INDEX_ENV_VAR = "NEMO_CURATOR_SLURM_ARRAY_SHARD_INDEX"
SLURM_ARRAY_TOTAL_SHARDS_ENV_VAR = "NEMO_CURATOR_SLURM_ARRAY_TOTAL_SHARDS"
SLURM_ARRAY_MINIMUM_SHARD_INDEX_ENV_VAR = "NEMO_CURATOR_SLURM_ARRAY_MINIMUM_SHARD_INDEX"
SLURM_ARRAY_RETRY_MANIFEST_NAMESPACE = "slurm_array"
SLURM_ARRAY_RETRY_DIRNAME = ".slurm_array_retry"

_TRUE_ENV_VALUES = {"1", "true", "yes", "on"}


def _get_int_env_var(env_var: str, fallback_name: str | None = None, default: int | None = None) -> int:
"""Read an integer env var, with optional fallback/default."""
env_value = os.environ.get(env_var)
if env_value is None:
if fallback_name is not None:
env_var = fallback_name
env_value = os.environ.get(env_var)

if env_value is None:
if default is not None:
return default

msg = f"Environment variable {env_var} is not set"
raise ValueError(msg)

try:
return int(env_value)
except ValueError as e:
msg = f"Environment variable {env_var} must contain an integer, got {env_value!r}"
raise ValueError(msg) from e


@dataclass
class SlurmArrayConfig:
"""Source-task sharding settings for one Slurm array task."""

shard_index: int
total_shards: int
minimum_shard_index: int = 0

@classmethod
def from_env(cls) -> "SlurmArrayConfig | None":
"""Build config from Curator env vars, falling back to Slurm env vars."""
enabled = os.environ.get(SLURM_ARRAY_ENABLED_ENV_VAR, "")
if enabled.strip().lower() not in _TRUE_ENV_VALUES:
return None

return cls(
shard_index=_get_int_env_var(SLURM_ARRAY_SHARD_INDEX_ENV_VAR, "SLURM_ARRAY_TASK_ID"),
total_shards=_get_int_env_var(SLURM_ARRAY_TOTAL_SHARDS_ENV_VAR, "SLURM_ARRAY_TASK_COUNT"),
minimum_shard_index=_get_int_env_var(SLURM_ARRAY_MINIMUM_SHARD_INDEX_ENV_VAR, default=0),
)


def configure_slurm_array_source_filtering(
shard_index: int,
total_shards: int,
minimum_shard_index: int,
) -> None:
"""Set env vars consumed by source-stage filtering."""
os.environ[SLURM_ARRAY_ENABLED_ENV_VAR] = "1"
os.environ[SLURM_ARRAY_SHARD_INDEX_ENV_VAR] = str(shard_index)
os.environ[SLURM_ARRAY_TOTAL_SHARDS_ENV_VAR] = str(total_shards)
os.environ[SLURM_ARRAY_MINIMUM_SHARD_INDEX_ENV_VAR] = str(minimum_shard_index)


def resolve_slurm_array_config(is_source_stage: bool) -> SlurmArrayConfig | None:
"""Resolve filtering config for source stages."""
if not is_source_stage:
return None

resolved = SlurmArrayConfig.from_env()
if resolved is None:
return None

if resolved.total_shards <= 0:
msg = f"total_shards must be greater than 0, got {resolved.total_shards}"
raise ValueError(msg)

min_assignable_shard_index = resolved.minimum_shard_index
max_assignable_shard_index = resolved.minimum_shard_index + resolved.total_shards - 1
if not min_assignable_shard_index <= resolved.shard_index <= max_assignable_shard_index:
logger.warning(
"shard_index={} is outside the assignable shard range [{}, {}]. "
"This task will not receive any source tasks.",
resolved.shard_index,
min_assignable_shard_index,
max_assignable_shard_index,
)

return resolved


def slurm_array_shard_for_task(task: Task, slurm_array: SlurmArrayConfig) -> int:
"""Assign a task to a shard by hashing its deterministic task ID."""
digest = hashlib.sha256(task.task_id.encode("utf-8")).hexdigest()
return int(digest[:16], 16) % slurm_array.total_shards + slurm_array.minimum_shard_index


def filter_slurm_array_source_tasks(
tasks: list[Task],
slurm_array: SlurmArrayConfig | None,
stage_name: str,
) -> list[Task]:
"""Keep only source tasks assigned to the active Slurm array shard."""
if slurm_array is None:
return tasks

nondeterministic_task_ids = [task.task_id for task in tasks if task.task_id.startswith("r")]
if nondeterministic_task_ids:
msg = (
"Slurm array source filtering requires deterministic task IDs, but stage "
f"{stage_name} emitted ambiguous source task IDs: {nondeterministic_task_ids[:5]}"
)
raise ValueError(msg)

assigned_tasks = [
task for task in tasks if slurm_array_shard_for_task(task, slurm_array) == slurm_array.shard_index
]

msg = (
f"Slurm array shard {slurm_array.shard_index}/{slurm_array.total_shards}: "
f"assigned {len(assigned_tasks)} of {len(tasks)} source tasks for stage {stage_name}"
)
if len(assigned_tasks) == 0 and len(tasks) > 0:
logger.warning(msg)
else:
logger.info(msg)

return assigned_tasks


def raise_for_failed_source_tasks_with_slurm_array(
stage_name: str,
failed_tasks: list[FailedTask],
slurm_array: SlurmArrayConfig | None,
) -> None:
"""Reject source-stage FailedTasks, which cannot be retried by shard."""
if failed_tasks and slurm_array is not None:
msg = (
f"Source stage {stage_name} emitted FailedTask while Slurm array filtering is enabled. "
"This is not supported because the failed source task cannot be assigned to a retry shard "
"reliably. Raise an exception from the source stage instead."
)
raise ValueError(msg)


def is_slurm_array_driver_process(use_slurm: bool) -> bool:
"""Return true for the process that owns retry metadata."""
return not use_slurm or os.environ.get("SLURM_NODEID", "0") == "0"


def build_slurm_array_retry_manifest(
checkpoint_path: str | None,
shard_index: int,
total_shards: int,
minimum_shard_index: int,
) -> RetryManifest | None:
"""Create a retry manifest for one Slurm array shard."""
if checkpoint_path is None:
return None

return RetryManifest(
checkpoint_path=checkpoint_path,
namespace=SLURM_ARRAY_RETRY_MANIFEST_NAMESPACE,
retry_dirname=SLURM_ARRAY_RETRY_DIRNAME,
identity={
"minimum_shard_index": minimum_shard_index,
"shard_index": shard_index,
"total_shards": total_shards,
},
flatten_identity=True,
)
Loading
Loading