Skip to content
Open
Show file tree
Hide file tree
Changes from 25 commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
bde2217
basic slurm array file partitioning
sarahyurick Jun 9, 2026
a0595f6
add slurm array params to composite stages using filepartitioningstage
sarahyurick Jun 9, 2026
43ee179
add tutorial and tests
sarahyurick Jun 11, 2026
cae17b3
Merge branch 'main' into slurm_array
sarahyurick Jun 11, 2026
acfeceb
ruff
sarahyurick Jun 11, 2026
6eaf95e
address greptile reviews
sarahyurick Jun 11, 2026
bb1e30a
ruff
sarahyurick Jun 11, 2026
2ccbd3f
more greptile comments
sarahyurick Jun 11, 2026
1b659ea
add nonetask and failedtask sentinels
sarahyurick Jun 11, 2026
3522809
add failedtask detection and repeat
sarahyurick Jun 11, 2026
717edac
ruff
sarahyurick Jun 11, 2026
8f2345b
Merge branch 'main' into slurm_array
sarahyurick Jun 11, 2026
ebba73e
greptile comments
sarahyurick Jun 11, 2026
5e58793
Merge branch 'main' into slurm_array
sarahyurick Jun 15, 2026
ad8f68a
TextSemanticDeduplicationWorkflow revert
sarahyurick Jun 16, 2026
437270f
Merge branch 'main' into slurm_array
sarahyurick Jun 16, 2026
b55ec47
Merge branch 'main' into slurm_array
sarahyurick Jun 16, 2026
672f3d2
use SlurmArrayConfig dataclass
sarahyurick Jun 16, 2026
2814eec
use base stage adapter and source stage instead of file partitioning …
sarahyurick Jun 22, 2026
4dacd31
formatting
sarahyurick Jun 22, 2026
d54d3a3
ruff
sarahyurick Jun 22, 2026
5e18093
greptile feedback
sarahyurick Jun 22, 2026
94312c4
update tutorial
sarahyurick Jun 22, 2026
298a632
Merge branch 'main' into slurm_array
sarahyurick Jun 22, 2026
4d5d20a
Merge branch 'main' into slurm_array
sarahyurick Jun 23, 2026
e0a8af5
save bulk updates
sarahyurick Jun 24, 2026
62a4493
ruff
sarahyurick Jun 24, 2026
11c74cc
add greptile suggestion
sarahyurick Jun 24, 2026
67e0eef
ruff
sarahyurick Jun 24, 2026
3ee13ff
Merge branch 'main' into slurm_array
sarahyurick Jun 24, 2026
14ddc3c
Merge branch 'main' into slurm_array
sarahyurick Jun 25, 2026
f024975
use login node to submit retry jobs
sarahyurick Jun 25, 2026
ebc93dc
Merge branch 'main' into slurm_array
sarahyurick Jun 25, 2026
3fc2dbc
Merge branch 'main' into slurm_array
sarahyurick Jun 25, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
224 changes: 217 additions & 7 deletions nemo_curator/backends/base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -12,19 +12,144 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import datetime
import hashlib
import json
import os
import socket
import tempfile
import uuid
from abc import ABC, abstractmethod
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Any

from loguru import logger

from nemo_curator.core.utils import ignore_ray_head_node
from nemo_curator.tasks import Task
from nemo_curator.tasks.sentinels import FailedTask, NoneTask
from nemo_curator.utils.performance_utils import StageTimer

if TYPE_CHECKING:
from nemo_curator.stages.base import ProcessingStage


FAILED_TASKS_DIR_ENV_VAR = "NEMO_CURATOR_FAILED_TASKS_DIR"
SLURM_ARRAY_ENABLED_ENV_VAR = "NEMO_CURATOR_SLURM_ARRAY_ENABLED"
SLURM_ARRAY_SHARD_INDEX_ENV_VAR = "NEMO_CURATOR_SLURM_ARRAY_SHARD_INDEX"
SLURM_ARRAY_TOTAL_SHARDS_ENV_VAR = "NEMO_CURATOR_SLURM_ARRAY_TOTAL_SHARDS"
SLURM_ARRAY_MINIMUM_SHARD_INDEX_ENV_VAR = "NEMO_CURATOR_SLURM_ARRAY_MINIMUM_SHARD_INDEX"

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need all these variables, or can they be self-inferred? I think the initial design that @praateekmahajan had in mind was we just add --array=1-100 to the slurm submit command, and everything else works OOTB. Currently, it seems like the effort from the user side is a bit more than that?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So the idea is we want to give the user full control if they need to override total shards (needed for reruns) or minimum shard index (needed to get around any Slurm array size limits). Really to enable Slurm array partitioning, the only thing explicitly needed is:

NEMO_CURATOR_SLURM_ARRAY_ENABLED=1

and it can automatically grab the environment variables without any issues. And then the user can override with NEMO_CURATOR_SLURM_ARRAY_SHARD_INDEX, etc. as desired (but not required).


_TRUE_ENV_VALUES = {"1", "true", "yes", "on"}


def _get_int_env_var(env_var: str, fallback_name: str | None = None, default: int | None = None) -> int:
Comment thread
sarahyurick marked this conversation as resolved.
Outdated
env_value = os.environ.get(env_var)
if env_value is None:
if fallback_name is not None:
env_var = fallback_name
env_value = os.environ.get(env_var)

if env_value is None:
if default is not None:
return default

msg = f"Environment variable {env_var} is not set"
raise ValueError(msg)

try:
return int(env_value)
except ValueError as e:
msg = f"Environment variable {env_var} must contain an integer, got {env_value!r}"
raise ValueError(msg) from e


@dataclass
class SlurmArrayConfig:
Comment thread
sarahyurick marked this conversation as resolved.
Outdated
"""Configuration for assigning source tasks to one Slurm array task."""

shard_index: int
total_shards: int
minimum_shard_index: int = 0

@classmethod
def from_env(cls) -> "SlurmArrayConfig | None":
"""Return Slurm array config when source-task filtering is enabled."""
enabled = os.environ.get(SLURM_ARRAY_ENABLED_ENV_VAR, "")
if enabled.strip().lower() not in _TRUE_ENV_VALUES:
return None

return cls(
shard_index=_get_int_env_var(SLURM_ARRAY_SHARD_INDEX_ENV_VAR, "SLURM_ARRAY_TASK_ID"),
total_shards=_get_int_env_var(SLURM_ARRAY_TOTAL_SHARDS_ENV_VAR, "SLURM_ARRAY_TASK_COUNT"),
minimum_shard_index=_get_int_env_var(SLURM_ARRAY_MINIMUM_SHARD_INDEX_ENV_VAR, default=0),
)


def _safe_filename_token(value: object) -> str:
return "".join(ch if ch.isalnum() or ch in "._-" else "_" for ch in str(value))


def _fsync_directory(path: Path) -> None:
flags = os.O_RDONLY
if hasattr(os, "O_DIRECTORY"):
flags |= os.O_DIRECTORY

dir_fd = os.open(path, flags)
try:
os.fsync(dir_fd)
finally:
os.close(dir_fd)


def _write_failed_task_marker(marker_dir: Path, stage_name: str, task: FailedTask) -> None:
Comment thread
sarahyurick marked this conversation as resolved.
Outdated
created_at = datetime.datetime.now(datetime.UTC)
timestamp = created_at.strftime("%Y%m%dT%H%M%S%fZ")
payload: dict[str, str | int] = {
"created_at": created_at.isoformat(),
"stage_name": stage_name,
"task_id": task.task_id,
"dataset_name": task.dataset_name,
"task_type": type(task).__name__,
"hostname": socket.gethostname(),
"pid": os.getpid(),
}

marker_dir.mkdir(parents=True, exist_ok=True)
filename = (
"failed_task_"
f"stage-{_safe_filename_token(stage_name)}_"
f"task-{_safe_filename_token(task.task_id)}_"
f"pid-{os.getpid()}_"
f"{timestamp}_{uuid.uuid4().hex}.json"
)
final_path = marker_dir / filename

tmp_path: Path | None = None
try:
with tempfile.NamedTemporaryFile(
mode="w",
encoding="utf-8",
dir=marker_dir,
prefix=f".{filename}.",
suffix=".tmp",
delete=False,
) as tmp:
tmp_path = Path(tmp.name)
json.dump(payload, tmp, indent=2, sort_keys=True)
tmp.write("\n")
tmp.flush()
os.fsync(tmp.fileno())

os.replace(tmp_path, final_path)
_fsync_directory(marker_dir)
except Exception:
if tmp_path is not None:
tmp_path.unlink(missing_ok=True)
raise


@dataclass
class NodeInfo:
"""Generic node information for setup_on_node calls across backends.
Expand Down Expand Up @@ -85,9 +210,22 @@ def process_batch(self, tasks: list[Task]) -> list[Task]:
# Use the batch processing logic
results = self.stage.process_batch(tasks)

# A returned ``None`` ("filter this slot") becomes a NoneTask so every
# output is a real Task that gets a task_id. Sentinels (NoneTask /
# FailedTask) carry no identity and are stripped again before this
# method returns.
results = [NoneTask() if r is None else r for r in results]

# Guarantee every emitted task has a task_id (derived id, or uuid fallback).
results = self._post_process_task_ids(tasks, results)

self._record_failed_tasks([r for r in results if isinstance(r, FailedTask)])

@sarahyurick sarahyurick Jun 12, 2026

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Discussed with @abhinavg4 . For now the PR keeps track of FailedTask instances by looking for a user-set FAILED_TASKS_DIR_ENV_VAR = "NEMO_CURATOR_FAILED_TASKS_DIR" and writing a JSON file per failed task in the specified directory.

I did the environment variable and write approach because it seems more reliable than trying to handle a global Python variable, etc. And the reason it is an environment variable is so that BaseStageAdapter does not have to propagate an additional parameter for every single stage (which I think would involve having to update the executors as well?). Open to other suggestions.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok I think a lot of the util functions are coming because of this feature, and there might be an easier way for this. Continuing in DMs.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The functions are really only:

  1. Write info about failed tasks
  2. Use task IDs to filter by Slurm array index

but I can move to util scripts if that makes it easier to read.


# Sentinels never propagate to the next stage.
results = [r for r in results if not isinstance(r, (NoneTask, FailedTask))]

results = self._filter_slurm_array_source_tasks(results)
Comment thread
sarahyurick marked this conversation as resolved.
Outdated

# Log performance stats and add to result tasks
_, stage_perf_stats = self._timer.log_stats()
# Consume and attach any custom metrics recorded by the stage during this call
Expand All @@ -99,6 +237,78 @@ def process_batch(self, tasks: list[Task]) -> list[Task]:

return results

def _record_failed_tasks(self, failed_tasks: list[FailedTask]) -> None:
marker_dir = os.environ.get(FAILED_TASKS_DIR_ENV_VAR)
if not marker_dir or not failed_tasks:
return

marker_path = Path(marker_dir)
for task in failed_tasks:
try:
_write_failed_task_marker(marker_path, self.stage.name, task)
except Exception as e: # noqa: BLE001
logger.warning(f"Failed to write FailedTask marker to {marker_path}: {e}")

def _filter_slurm_array_source_tasks(self, tasks: list[Task]) -> list[Task]:
"""Keep only source tasks assigned to this Slurm array shard."""
slurm_array = self._resolve_slurm_array_config()
if slurm_array is None:
return tasks

nondeterministic_task_ids = [task.task_id for task in tasks if task.task_id.startswith("r")]
if nondeterministic_task_ids:
msg = (
"Slurm array source filtering requires deterministic task IDs, but stage "
f"{self.stage.name} emitted ambiguous source task IDs: {nondeterministic_task_ids[:5]}"
)
raise ValueError(msg)

assigned_tasks = [
task
for task in tasks
if self._slurm_array_shard_for_task(task, slurm_array) == slurm_array.shard_index
]

msg = (
f"Slurm array shard {slurm_array.shard_index}/{slurm_array.total_shards}: "
f"assigned {len(assigned_tasks)} of {len(tasks)} source tasks for stage {self.stage.name}"
)
if len(assigned_tasks) == 0 and len(tasks) > 0:
logger.warning(msg)
else:
logger.info(msg)

return assigned_tasks

def _resolve_slurm_array_config(self) -> SlurmArrayConfig | None:
if not getattr(self.stage, "is_source_stage", False):
return None

if not hasattr(self, "_resolved_slurm_array"):
resolved = SlurmArrayConfig.from_env()
if resolved is not None:
if resolved.total_shards <= 0:
msg = f"total_shards must be greater than 0, got {resolved.total_shards}"
raise ValueError(msg)

min_assignable_shard_index = resolved.minimum_shard_index
max_assignable_shard_index = resolved.minimum_shard_index + resolved.total_shards - 1
if not min_assignable_shard_index <= resolved.shard_index <= max_assignable_shard_index:
logger.warning(
"shard_index={} is outside the assignable shard range [{}, {}]. "
"This task will not receive any source tasks.",
resolved.shard_index,
min_assignable_shard_index,
max_assignable_shard_index,
)
self._resolved_slurm_array = resolved

return self._resolved_slurm_array

def _slurm_array_shard_for_task(self, task: Task, slurm_array: SlurmArrayConfig) -> int:
digest = hashlib.sha256(task.task_id.encode("utf-8")).hexdigest()
return int(digest[:16], 16) % slurm_array.total_shards + slurm_array.minimum_shard_index

def _post_process_task_ids(self, input_tasks: list[Task], output_tasks: list[Task | None]) -> list[Task]:
"""Assign a deterministic ``task_id`` to every emitted task.

Expand All @@ -109,17 +319,17 @@ def _post_process_task_ids(self, input_tasks: list[Task], output_tasks: list[Tas
re-derived at each stage boundary so the same object passing through
N stages gets N ids.

The inputoutput mapping decides each output's PARENT; whether the
The input -> output mapping decides each output's PARENT; whether the
stage is a source decides each output's SEGMENT (content id vs index)
— the two are independent. ``None`` outputs (Curator's "return None to
filter") are NOT removed before the length check — keeping them in
place preserves positional alignment for filter stages — and are then
dropped from the returned list.

- single input every output is its child (fan-out): ``parent_<seg>``
- ``len(output) == len(input)`` positional 1:1: each ``parent_i_<seg>``;
- single input -> every output is its child (fan-out): ``parent_<seg>``
- ``len(output) == len(input)`` -> positional 1:1: each ``parent_i_<seg>``;
a ``None`` slot just means input ``i`` was filtered.
- any other (ambiguous) cardinality across a batch a random ``uuid``
- any other (ambiguous) cardinality across a batch -> a random ``uuid``
prefixed with ``"r"`` (e.g. ``"r3f9a…"``), so ``task_id`` is never
empty even when a derived id is not possible. The ``"r"`` prefix flags
the id as non-deterministic / ancestry-not-tracked (see
Expand All @@ -128,13 +338,13 @@ def _post_process_task_ids(self, input_tasks: list[Task], output_tasks: list[Tas
``seg`` is the output's content id (``Task.get_deterministic_id()``)
for a source stage when available, else the positional index — so a
source partition keeps a stable id across reorderings regardless of
whether the source is 1N or NN.
whether the source is 1->N or N->N.

Note: a stage that BOTH filters and fans out within a single batch
(returning a flat list rather than a per-input slot) cannot be mapped
positionally; if its length happens to equal the input length the 1:1
assumption may misattribute parents. That combination is unsupported
until per-slot sentinels (NoneTask/FailedTask) land in a later PR.
unless the stage preserves an unambiguous input -> output mapping.
"""
is_source = getattr(self.stage, "is_source_stage", False)

Expand Down
29 changes: 26 additions & 3 deletions nemo_curator/tasks/sentinels.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,18 @@
# limitations under the License.
"""Payload-less marker tasks.

``EmptyTask`` seeds a pipeline (the implicit root id ``"0"``). All markers
share the :class:`SentinelTask` base and carry no payload (``data is None``).
Construct one with ``EmptyTask()``.
``EmptyTask`` seeds a pipeline (the implicit root id ``"0"``). The resumability
layer adds two more markers on the same :class:`SentinelTask` base:

- ``NoneTask`` — this slot was intentionally filtered. The resumability counter
treats it as a consumed branch (decrements). The adapter auto-wraps a
returned ``None`` as a ``NoneTask``.
- ``FailedTask`` — this slot failed and should be retried on resume. The counter
is NOT decremented, so its source stays pending and reruns.

All carry no payload (``data is None``) and get their ``task_id`` assigned by
the executor adapter; sentinels are stripped before the next stage. Construct
with ``EmptyTask()`` / ``NoneTask()`` / ``FailedTask()``.
"""

from dataclasses import dataclass, field
Expand Down Expand Up @@ -52,3 +61,17 @@ class EmptyTask(SentinelTask):

dataset_name: str = "empty"
task_id: str = field(init=False, default="0")


@dataclass
class NoneTask(SentinelTask):
"""Marks a slot as intentionally filtered (resumability counter decrements)."""

dataset_name: str = "none"


@dataclass
class FailedTask(SentinelTask):
"""Marks a slot as failed → retried on resume (counter does NOT decrement)."""

dataset_name: str = "failed"
Loading
Loading