Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
bde2217
basic slurm array file partitioning
sarahyurick Jun 9, 2026
a0595f6
add slurm array params to composite stages using filepartitioningstage
sarahyurick Jun 9, 2026
43ee179
add tutorial and tests
sarahyurick Jun 11, 2026
cae17b3
Merge branch 'main' into slurm_array
sarahyurick Jun 11, 2026
acfeceb
ruff
sarahyurick Jun 11, 2026
6eaf95e
address greptile reviews
sarahyurick Jun 11, 2026
bb1e30a
ruff
sarahyurick Jun 11, 2026
2ccbd3f
more greptile comments
sarahyurick Jun 11, 2026
1b659ea
add nonetask and failedtask sentinels
sarahyurick Jun 11, 2026
3522809
add failedtask detection and repeat
sarahyurick Jun 11, 2026
717edac
ruff
sarahyurick Jun 11, 2026
8f2345b
Merge branch 'main' into slurm_array
sarahyurick Jun 11, 2026
ebba73e
greptile comments
sarahyurick Jun 11, 2026
5e58793
Merge branch 'main' into slurm_array
sarahyurick Jun 15, 2026
ad8f68a
TextSemanticDeduplicationWorkflow revert
sarahyurick Jun 16, 2026
437270f
Merge branch 'main' into slurm_array
sarahyurick Jun 16, 2026
b55ec47
Merge branch 'main' into slurm_array
sarahyurick Jun 16, 2026
672f3d2
use SlurmArrayConfig dataclass
sarahyurick Jun 16, 2026
2814eec
use base stage adapter and source stage instead of file partitioning …
sarahyurick Jun 22, 2026
4dacd31
formatting
sarahyurick Jun 22, 2026
d54d3a3
ruff
sarahyurick Jun 22, 2026
5e18093
greptile feedback
sarahyurick Jun 22, 2026
94312c4
update tutorial
sarahyurick Jun 22, 2026
298a632
Merge branch 'main' into slurm_array
sarahyurick Jun 22, 2026
4d5d20a
Merge branch 'main' into slurm_array
sarahyurick Jun 23, 2026
e0a8af5
save bulk updates
sarahyurick Jun 24, 2026
62a4493
ruff
sarahyurick Jun 24, 2026
11c74cc
add greptile suggestion
sarahyurick Jun 24, 2026
67e0eef
ruff
sarahyurick Jun 24, 2026
3ee13ff
Merge branch 'main' into slurm_array
sarahyurick Jun 24, 2026
14ddc3c
Merge branch 'main' into slurm_array
sarahyurick Jun 25, 2026
f024975
use login node to submit retry jobs
sarahyurick Jun 25, 2026
ebc93dc
Merge branch 'main' into slurm_array
sarahyurick Jun 25, 2026
3fc2dbc
Merge branch 'main' into slurm_array
sarahyurick Jun 25, 2026
be338ea
Merge branch 'main' into slurm_array
sarahyurick Jun 29, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 34 additions & 7 deletions nemo_curator/backends/base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -17,8 +17,15 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any

from nemo_curator.backends.failed_task_markers import record_failed_tasks
from nemo_curator.backends.slurm_array import (
filter_slurm_array_source_tasks,
raise_for_failed_source_tasks_with_slurm_array,
resolve_slurm_array_config,
)
from nemo_curator.core.utils import ignore_ray_head_node
from nemo_curator.tasks import Task
from nemo_curator.tasks.sentinels import FailedTask, NoneTask
from nemo_curator.utils.performance_utils import StageTimer

if TYPE_CHECKING:
Expand Down Expand Up @@ -85,9 +92,29 @@ def process_batch(self, tasks: list[Task]) -> list[Task]:
# Use the batch processing logic
results = self.stage.process_batch(tasks)

# A returned ``None`` ("filter this slot") becomes a NoneTask so every
# output is a real Task that gets a task_id. Sentinels (NoneTask /
# FailedTask) carry no identity and are stripped again before this
# method returns.
results = [NoneTask() if r is None else r for r in results]

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We might have a merge conflict on lin 99, 113. But I think the intention is the same. So one of can resolve it whoever merges later.


# Guarantee every emitted task has a task_id (derived id, or uuid fallback).
results = self._post_process_task_ids(tasks, results)

# Failed tasks on the source stage are not supported when Slurm array filtering is enabled.
slurm_array = resolve_slurm_array_config(is_source_stage=getattr(self.stage, "is_source_stage", False))
failed_tasks = [r for r in results if isinstance(r, FailedTask)]
raise_for_failed_source_tasks_with_slurm_array(self.stage.name, failed_tasks, slurm_array)

# Record failed tasks for later inspection or retry bookkeeping.
record_failed_tasks(self.stage.name, failed_tasks)

# Sentinels never propagate to the next stage.
results = [r for r in results if not isinstance(r, (NoneTask, FailedTask))]

# Filter tasks based on the Slurm array configuration.
results = filter_slurm_array_source_tasks(results, slurm_array, self.stage.name)

# Log performance stats and add to result tasks
_, stage_perf_stats = self._timer.log_stats()
# Consume and attach any custom metrics recorded by the stage during this call
Expand All @@ -109,17 +136,17 @@ def _post_process_task_ids(self, input_tasks: list[Task], output_tasks: list[Tas
re-derived at each stage boundary so the same object passing through
N stages gets N ids.

The inputoutput mapping decides each output's PARENT; whether the
The input -> output mapping decides each output's PARENT; whether the
stage is a source decides each output's SEGMENT (content id vs index)
— the two are independent. ``None`` outputs (Curator's "return None to
filter") are NOT removed before the length check — keeping them in
place preserves positional alignment for filter stages — and are then
dropped from the returned list.

- single input every output is its child (fan-out): ``parent_<seg>``
- ``len(output) == len(input)`` positional 1:1: each ``parent_i_<seg>``;
- single input -> every output is its child (fan-out): ``parent_<seg>``
- ``len(output) == len(input)`` -> positional 1:1: each ``parent_i_<seg>``;
a ``None`` slot just means input ``i`` was filtered.
- any other (ambiguous) cardinality across a batch a random ``uuid``
- any other (ambiguous) cardinality across a batch -> a random ``uuid``
prefixed with ``"r"`` (e.g. ``"r3f9a…"``), so ``task_id`` is never
empty even when a derived id is not possible. The ``"r"`` prefix flags
the id as non-deterministic / ancestry-not-tracked (see
Expand All @@ -128,13 +155,13 @@ def _post_process_task_ids(self, input_tasks: list[Task], output_tasks: list[Tas
``seg`` is the output's content id (``Task.get_deterministic_id()``)
for a source stage when available, else the positional index — so a
source partition keeps a stable id across reorderings regardless of
whether the source is 1N or NN.
whether the source is 1->N or N->N.

Note: a stage that BOTH filters and fans out within a single batch
(returning a flat list rather than a per-input slot) cannot be mapped
positionally; if its length happens to equal the input length the 1:1
assumption may misattribute parents. That combination is unsupported
until per-slot sentinels (NoneTask/FailedTask) land in a later PR.
unless the stage preserves an unambiguous input -> output mapping.
"""
is_source = getattr(self.stage, "is_source_stage", False)

Expand Down
80 changes: 80 additions & 0 deletions nemo_curator/backends/failed_task_markers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import hashlib
import os
from pathlib import Path

from loguru import logger

from nemo_curator.tasks.sentinels import FailedTask
from nemo_curator.utils.atomic_io import write_json_atomically

FAILED_TASKS_DIR_ENV_VAR = "NEMO_CURATOR_FAILED_TASKS_DIR"
FAILED_TASK_MARKER_PATTERN = "failed_task_*.json"


def _write_failed_task_marker(marker_dir: Path, stage_name: str, task: FailedTask) -> None:
"""Write one compact marker for a failed stage/task pair."""
payload = {
"stage_name": stage_name,
"task_id": task.task_id,
}

marker_identity = f"{stage_name}\0{task.task_id}".encode()
marker_digest = hashlib.sha256(marker_identity).hexdigest()[:16]
filename = f"failed_task_{marker_digest}.json"
final_path = marker_dir / filename
write_json_atomically(final_path, payload, separators=(",", ":"), sort_keys=True)


def record_failed_tasks(stage_name: str, failed_tasks: list[FailedTask]) -> None:
"""Record FailedTask markers when ``NEMO_CURATOR_FAILED_TASKS_DIR`` is set."""
marker_dir = os.environ.get(FAILED_TASKS_DIR_ENV_VAR)
if not marker_dir or not failed_tasks:
return

marker_path = Path(marker_dir)
for task in failed_tasks:

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Still not sure why we need it for every task in failed task. This makes writing difficult and reading diffiult?

For each Slurm array idx, would it not help to just have a single marker saying this idx failed and that's all? Also, can we move "if failed task" to the base.py code for readability?

Just if condition and call this function

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah I can see your point.

For me it seems easy enough to check if there is anything in the failed tasks directory to know whether or not to rerun.

I thought it could be useful to understand how many failed tasks there were (1 vs 1 million) to maybe inform whether or not you even want to rerun it. (Since the files are tiny I am not worried about disk space.)

And if each manifest saves which stage produced the failed task, then you know which stages you need to potentially update or edit before rerunning.

But I am open to other approaches too! Let me know what you think.

try:
_write_failed_task_marker(marker_path, stage_name, task)
except Exception as e: # noqa: BLE001
logger.warning(f"Failed to write FailedTask marker to {marker_path}: {e}")


def summarize_failed_task_markers(
marker_dir: str | Path | None = None,
) -> dict[str, object]:
"""Count FailedTask markers in ``marker_dir`` or the configured env dir."""
resolved_marker_dir = marker_dir if marker_dir is not None else os.environ.get(FAILED_TASKS_DIR_ENV_VAR)
if not resolved_marker_dir:
return {
"failed_task_marker_count": 0,
}

marker_path = Path(resolved_marker_dir).absolute()
if not marker_path.exists():
return {
"failed_task_marker_count": 0,
}

marker_count = 0
for path in marker_path.glob(FAILED_TASK_MARKER_PATTERN):
if not path.is_file():
continue
marker_count += 1

return {
"failed_task_marker_count": marker_count,
}
197 changes: 197 additions & 0 deletions nemo_curator/backends/slurm_array.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import hashlib
import os
from dataclasses import dataclass

from loguru import logger

from nemo_curator.tasks import Task
from nemo_curator.tasks.sentinels import FailedTask
from nemo_curator.utils.retry_manifest import RetryManifest

SLURM_ARRAY_ENABLED_ENV_VAR = "NEMO_CURATOR_SLURM_ARRAY_ENABLED"
SLURM_ARRAY_SHARD_INDEX_ENV_VAR = "NEMO_CURATOR_SLURM_ARRAY_SHARD_INDEX"
SLURM_ARRAY_TOTAL_SHARDS_ENV_VAR = "NEMO_CURATOR_SLURM_ARRAY_TOTAL_SHARDS"
SLURM_ARRAY_MINIMUM_SHARD_INDEX_ENV_VAR = "NEMO_CURATOR_SLURM_ARRAY_MINIMUM_SHARD_INDEX"
SLURM_ARRAY_RETRY_MANIFEST_NAMESPACE = "slurm_array"
SLURM_ARRAY_RETRY_DIRNAME = ".slurm_array_retry"

_TRUE_ENV_VALUES = {"1", "true", "yes", "on"}


def _get_int_env_var(env_var: str, fallback_name: str | None = None, default: int | None = None) -> int:
"""Read an integer env var, with optional fallback/default."""
env_value = os.environ.get(env_var)
if env_value is None:
if fallback_name is not None:
env_var = fallback_name
env_value = os.environ.get(env_var)

if env_value is None:
if default is not None:
return default

msg = f"Environment variable {env_var} is not set"
raise ValueError(msg)

try:
return int(env_value)
except ValueError as e:
msg = f"Environment variable {env_var} must contain an integer, got {env_value!r}"
raise ValueError(msg) from e


@dataclass
class SlurmArrayConfig:
"""Source-task sharding settings for one Slurm array task."""

shard_index: int
total_shards: int
minimum_shard_index: int = 0

@classmethod
def from_env(cls) -> "SlurmArrayConfig | None":
"""Build config from Curator env vars, falling back to Slurm env vars."""
enabled = os.environ.get(SLURM_ARRAY_ENABLED_ENV_VAR, "")

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why are we gating on this. This would mean usre needs to enable this variable ?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, see here too: #2059 (comment)

IMO we need a way to explicitly opt into Slurm array partitioning, and this variable is how users can do it.

if enabled.strip().lower() not in _TRUE_ENV_VALUES:
return None

return cls(
shard_index=_get_int_env_var(SLURM_ARRAY_SHARD_INDEX_ENV_VAR, "SLURM_ARRAY_TASK_ID"),
total_shards=_get_int_env_var(SLURM_ARRAY_TOTAL_SHARDS_ENV_VAR, "SLURM_ARRAY_TASK_COUNT"),
minimum_shard_index=_get_int_env_var(SLURM_ARRAY_MINIMUM_SHARD_INDEX_ENV_VAR, default=0),
)


def configure_slurm_array_source_filtering(
shard_index: int,
total_shards: int,
minimum_shard_index: int,
) -> None:
"""Set env vars consumed by source-stage filtering."""
os.environ[SLURM_ARRAY_ENABLED_ENV_VAR] = "1"
os.environ[SLURM_ARRAY_SHARD_INDEX_ENV_VAR] = str(shard_index)
os.environ[SLURM_ARRAY_TOTAL_SHARDS_ENV_VAR] = str(total_shards)
os.environ[SLURM_ARRAY_MINIMUM_SHARD_INDEX_ENV_VAR] = str(minimum_shard_index)


def resolve_slurm_array_config(is_source_stage: bool) -> SlurmArrayConfig | None:
"""Resolve filtering config for source stages."""
if not is_source_stage:
return None

resolved = SlurmArrayConfig.from_env()
if resolved is None:
return None

if resolved.total_shards <= 0:
msg = f"total_shards must be greater than 0, got {resolved.total_shards}"
raise ValueError(msg)

min_assignable_shard_index = resolved.minimum_shard_index
max_assignable_shard_index = resolved.minimum_shard_index + resolved.total_shards - 1
if not min_assignable_shard_index <= resolved.shard_index <= max_assignable_shard_index:
logger.warning(
"shard_index={} is outside the assignable shard range [{}, {}]. "
"This task will not receive any source tasks.",
resolved.shard_index,
min_assignable_shard_index,
max_assignable_shard_index,
)

return resolved


def slurm_array_shard_for_task(task: Task, slurm_array: SlurmArrayConfig) -> int:
"""Assign a task to a shard by hashing its deterministic task ID."""
digest = hashlib.sha256(task.task_id.encode("utf-8")).hexdigest()
return int(digest[:16], 16) % slurm_array.total_shards + slurm_array.minimum_shard_index


def filter_slurm_array_source_tasks(
tasks: list[Task],
slurm_array: SlurmArrayConfig | None,
stage_name: str,
) -> list[Task]:
"""Keep only source tasks assigned to the active Slurm array shard."""
if slurm_array is None:

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we move this into a single if in base.py? Currently, I feel the code is very less readable since if I'm not using Slurm arrays. I'm coming to a ton of functions and just returning.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This if statement is just acting as an additional barrier to ensure we are only using this function if Slurm array partitioning has been enabled. I can add more comments to base.py if you think it can help?

I mainly want to avoid having base.py reason about Slurm array partitioning at all in favor of delegating the work to these functions here, so that base.py does not get clobbered up with a bunch of Slurm-specific things.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oops misread the comment I think. Sure I can move the if statement only to base.py, to make it more readable.

return tasks

nondeterministic_task_ids = [task.task_id for task in tasks if task.task_id.startswith("r")]
if nondeterministic_task_ids:
msg = (
"Slurm array source filtering requires deterministic task IDs, but stage "
f"{stage_name} emitted ambiguous source task IDs: {nondeterministic_task_ids[:5]}"
)
raise ValueError(msg)

assigned_tasks = [
task for task in tasks if slurm_array_shard_for_task(task, slurm_array) == slurm_array.shard_index
]

msg = (
f"Slurm array shard {slurm_array.shard_index}/{slurm_array.total_shards}: "
f"assigned {len(assigned_tasks)} of {len(tasks)} source tasks for stage {stage_name}"
)
if len(assigned_tasks) == 0 and len(tasks) > 0:
logger.warning(msg)
else:
logger.info(msg)

return assigned_tasks


def raise_for_failed_source_tasks_with_slurm_array(
stage_name: str,
failed_tasks: list[FailedTask],
slurm_array: SlurmArrayConfig | None,
) -> None:
"""Reject source-stage FailedTasks, which cannot be retried by shard."""
if failed_tasks and slurm_array is not None:

@abhinavg4 abhinavg4 Jun 30, 2026

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need this? Can't this just mean we ignore this task? Like, if we have

[task_1, task_2, failed, task_4] and sluirm array of 2. Then can't we do task_1 and task_2 first and task_4 for another?

Also, why not treat None the same way then? I think for resumability, also we need think about this. For now, if a task is failed at the source stage, I assume it needs to rerun.

I'm ok with having a contract that source stages cannot emit failed or none tasks. but it should be global then and not just for Slurm array.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added this because I am worried about there being edge cases when a source stage produces failed tasks, specifically with Slurm array partitioning enabled.

Like with the FilePartitioningStage, the list of file names is used to generate the deterministic IDs, but if there is a failed task then won't that affect how a deterministic ID is generated and thus how it is bucketed into a Slurm array index? So would this mess up which Slurm index it is assigned to upon retry?

I haven't thought very thoroughly about the above example so maybe it is not a big concern. But since Slurm array partitioning is supposed to work with any source stage, I wanted to add this restriction until we could be more confident here.

Wdyt?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you want me to make it more general and crash if the source stage emits any failed tasks (regardless of if we are doing Slurm array partitioning), then I can do that.

msg = (
f"Source stage {stage_name} emitted FailedTask while Slurm array filtering is enabled. "
"This is not supported because the failed source task cannot be assigned to a retry shard "
"reliably. Raise an exception from the source stage instead."
)
raise ValueError(msg)


def is_slurm_array_driver_process(use_slurm: bool) -> bool:
"""Return true for the process that owns retry metadata."""
return not use_slurm or os.environ.get("SLURM_NODEID", "0") == "0"


def build_slurm_array_retry_manifest(
checkpoint_path: str | None,
shard_index: int,
total_shards: int,
minimum_shard_index: int,
) -> RetryManifest | None:
"""Create a retry manifest for one Slurm array shard."""
if checkpoint_path is None:
return None

return RetryManifest(
checkpoint_path=checkpoint_path,
namespace=SLURM_ARRAY_RETRY_MANIFEST_NAMESPACE,
retry_dirname=SLURM_ARRAY_RETRY_DIRNAME,
identity={
"minimum_shard_index": minimum_shard_index,
"shard_index": shard_index,
"total_shards": total_shards,
},
flatten_identity=True,
)
Loading
Loading