Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
bde2217
basic slurm array file partitioning
sarahyurick Jun 9, 2026
a0595f6
add slurm array params to composite stages using filepartitioningstage
sarahyurick Jun 9, 2026
43ee179
add tutorial and tests
sarahyurick Jun 11, 2026
cae17b3
Merge branch 'main' into slurm_array
sarahyurick Jun 11, 2026
acfeceb
ruff
sarahyurick Jun 11, 2026
6eaf95e
address greptile reviews
sarahyurick Jun 11, 2026
bb1e30a
ruff
sarahyurick Jun 11, 2026
2ccbd3f
more greptile comments
sarahyurick Jun 11, 2026
1b659ea
add nonetask and failedtask sentinels
sarahyurick Jun 11, 2026
3522809
add failedtask detection and repeat
sarahyurick Jun 11, 2026
717edac
ruff
sarahyurick Jun 11, 2026
8f2345b
Merge branch 'main' into slurm_array
sarahyurick Jun 11, 2026
ebba73e
greptile comments
sarahyurick Jun 11, 2026
5e58793
Merge branch 'main' into slurm_array
sarahyurick Jun 15, 2026
ad8f68a
TextSemanticDeduplicationWorkflow revert
sarahyurick Jun 16, 2026
437270f
Merge branch 'main' into slurm_array
sarahyurick Jun 16, 2026
b55ec47
Merge branch 'main' into slurm_array
sarahyurick Jun 16, 2026
672f3d2
use SlurmArrayConfig dataclass
sarahyurick Jun 16, 2026
2814eec
use base stage adapter and source stage instead of file partitioning …
sarahyurick Jun 22, 2026
4dacd31
formatting
sarahyurick Jun 22, 2026
d54d3a3
ruff
sarahyurick Jun 22, 2026
5e18093
greptile feedback
sarahyurick Jun 22, 2026
94312c4
update tutorial
sarahyurick Jun 22, 2026
298a632
Merge branch 'main' into slurm_array
sarahyurick Jun 22, 2026
4d5d20a
Merge branch 'main' into slurm_array
sarahyurick Jun 23, 2026
e0a8af5
save bulk updates
sarahyurick Jun 24, 2026
62a4493
ruff
sarahyurick Jun 24, 2026
11c74cc
add greptile suggestion
sarahyurick Jun 24, 2026
67e0eef
ruff
sarahyurick Jun 24, 2026
3ee13ff
Merge branch 'main' into slurm_array
sarahyurick Jun 24, 2026
14ddc3c
Merge branch 'main' into slurm_array
sarahyurick Jun 25, 2026
f024975
use login node to submit retry jobs
sarahyurick Jun 25, 2026
ebc93dc
Merge branch 'main' into slurm_array
sarahyurick Jun 25, 2026
3fc2dbc
Merge branch 'main' into slurm_array
sarahyurick Jun 25, 2026
be338ea
Merge branch 'main' into slurm_array
sarahyurick Jun 29, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions nemo_curator/stages/audio/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,13 @@ class ManifestReader(CompositeStage[_EmptyTask, AudioTask]):
blocksize: Target size per partition (e.g., "100MB"). Ignored if files_per_partition is set.
file_extensions: File extensions to filter. Defaults to [".jsonl", ".json"].
storage_options: Storage options for cloud paths (S3, GCS credentials, endpoints).
enable_array_partitioning: Whether to enable array partitioning (e.g., partition files across multiple Slurm jobs).
shard_index: The index of the shard to process. Can be an integer representing the shard index or a string representing the environment variable name.
Only used if enable_array_partitioning is True. If not provided, it will be set to the value of the SLURM_ARRAY_TASK_ID environment variable.
total_shards: The total number of shards. Can be an integer representing the total number of shards or a string representing the environment variable name.
Only used if enable_array_partitioning is True. If not provided, it will be set to the value of the SLURM_ARRAY_TASK_COUNT environment variable.
minimum_shard_index: The minimum shard index to process. Can be an integer representing the minimum shard index or a string representing the environment variable name.
Only used if enable_array_partitioning is True. If not provided, it will be set to 0.
"""

manifest_path: str | list[str]
Expand All @@ -200,6 +207,10 @@ class ManifestReader(CompositeStage[_EmptyTask, AudioTask]):
blocksize: int | str | None = None
file_extensions: list[str] = field(default_factory=lambda: [".jsonl", ".json"])
storage_options: dict[str, Any] | None = None
enable_array_partitioning: bool = False
shard_index: int | str | None = None
total_shards: int | str | None = None
minimum_shard_index: int | str = 0

def __post_init__(self) -> None:
super().__init__()
Expand All @@ -215,6 +226,10 @@ def decompose(self) -> list[ProcessingStage]:
blocksize=self.blocksize,
file_extensions=self.file_extensions,
storage_options=self.storage_options,
enable_array_partitioning=self.enable_array_partitioning,
shard_index=self.shard_index,
total_shards=self.total_shards,
minimum_shard_index=self.minimum_shard_index,
),
ManifestReaderStage(),
]
Expand Down
69 changes: 68 additions & 1 deletion nemo_curator/stages/file_partitioning.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import hashlib
import os
from dataclasses import dataclass
from typing import Any

Expand All @@ -29,6 +31,24 @@
)


def _get_int_or_env_var(input_value: int | str | None, default_name: str | None = None) -> int:
if type(input_value) is int:
return input_value
elif type(input_value) is str:
if os.environ.get(input_value) is None:
msg = f"Environment variable {input_value} is not set"
raise ValueError(msg)
return int(os.environ.get(input_value))
elif default_name is not None:
if os.environ.get(default_name) is None:
msg = f"Environment variable {default_name} is not set"
raise ValueError(msg)
return int(os.environ.get(default_name))
else:
msg = f"Invalid input value: {input_value}, must be an integer or a string"
raise ValueError(msg)
Comment thread
sarahyurick marked this conversation as resolved.
Outdated


@dataclass
class FilePartitioningStage(ProcessingStage[_EmptyTask, FileGroupTask]):
"""Stage that partitions input file paths into FileGroupTasks.
Expand All @@ -55,6 +75,18 @@ class FilePartitioningStage(ProcessingStage[_EmptyTask, FileGroupTask]):
Storage options to pass to the file system.
limit: int | None = None
Maximum number of partitions to create.
enable_array_partitioning: bool = False
Whether to enable array partitioning (e.g., partition files across multiple Slurm jobs).
Intended for use with Slurm job arrays via the `sbatch --array` option.
shard_index: int | str | None = None
The index of the shard to process. Can be an integer representing the shard index or a string representing the environment variable name.
Only used if enable_array_partitioning is True. If not provided, it will be set to the value of the SLURM_ARRAY_TASK_ID environment variable.
total_shards: int | str | None = None
The total number of shards. Can be an integer representing the total number of shards or a string representing the environment variable name.
Only used if enable_array_partitioning is True. If not provided, it will be set to the value of the SLURM_ARRAY_TASK_COUNT environment variable.
minimum_shard_index: int = 0
The minimum shard index to process. Can be an integer representing the minimum shard index or a string representing the environment variable name.
Only used if enable_array_partitioning is True. If not provided, it will be set to 0.
"""

file_paths: str | list[str]
Expand All @@ -63,6 +95,10 @@ class FilePartitioningStage(ProcessingStage[_EmptyTask, FileGroupTask]):
file_extensions: list[str] | None = None
storage_options: dict[str, Any] | None = None
limit: int | None = None
enable_array_partitioning: bool = False
shard_index: int | str | None = None
total_shards: int | str | None = None
minimum_shard_index: int | str = 0
name: str = "file_partitioning"

def __post_init__(self):
Expand Down Expand Up @@ -91,6 +127,12 @@ def __post_init__(self):

self.resources = Resources(cpus=0.5)

if self.enable_array_partitioning:
self.shard_index = _get_int_or_env_var(self.shard_index, "SLURM_ARRAY_TASK_ID")
self.total_shards = _get_int_or_env_var(self.total_shards, "SLURM_ARRAY_TASK_COUNT")
self.minimum_shard_index = _get_int_or_env_var(self.minimum_shard_index)
self.name = "array_file_partitioning"

def inputs(self) -> tuple[list[str], list[str]]:
return [], []

Expand All @@ -106,7 +148,7 @@ def ray_stage_spec(self) -> dict[str, Any]:
def xenna_stage_spec(self) -> dict[str, Any]:
return {"num_workers_per_node": 1}

def process(self, _: _EmptyTask) -> list[FileGroupTask]:
def _process(self, _: _EmptyTask) -> list[FileGroupTask]:
"""Process the initial task to create file group tasks.

This stage expects a simple Task with file paths information
Expand Down Expand Up @@ -189,6 +231,31 @@ def process(self, _: _EmptyTask) -> list[FileGroupTask]:
logger.info(f"Created {len(tasks)} file groups from {len(files)} files")
return tasks

def _process_array(self, task: _EmptyTask) -> list[FileGroupTask]:
all_tasks = self._process(task)
assigned_tasks = []

for ft in all_tasks:
source_files = list(ft._metadata.get("source_files") or ft.data)
# Hash the source files to get a unique identifier for the partition
digest = hashlib.sha256("|".join(sorted(source_files)).encode("utf-8")).hexdigest()
# Assign the partition to the shard
assigned = int(digest[:16], 16) % self.total_shards
# Add the minimum shard index to the assigned shard index
assigned += self.minimum_shard_index
# Add the partition to the assigned tasks
if assigned == self.shard_index:
assigned_tasks.append(ft)

logger.info(f"Shard {self.shard_index}/{self.total_shards}: assigned {len(assigned_tasks)} of {len(all_tasks)} partitions")
return assigned_tasks

def process(self, task: _EmptyTask) -> list[FileGroupTask]:
if self.enable_array_partitioning:
return self._process_array(task)
else:
return self._process(task)

def _get_file_list_with_sizes(self, sort_by_size: bool = True) -> list[tuple[str, int]]:
"""
Get the list of files to process.
Expand Down
16 changes: 16 additions & 0 deletions nemo_curator/stages/interleaved/io/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ class InterleavedWebdatasetReader(CompositeStage[_EmptyTask, InterleavedBatch]):
blocksize: int | str | None = None
max_batch_bytes: int | None = None
read_kwargs: dict[str, Any] = field(default_factory=dict)
enable_array_partitioning: bool = False
shard_index: int | str | None = None
total_shards: int | str | None = None
minimum_shard_index: int | str = 0
materialize_on_read: bool = False
file_extensions: list[str] = field(default_factory=lambda: list(DEFAULT_WEBDATASET_EXTENSIONS))
json_extensions: list[str] = field(default_factory=lambda: list(DEFAULT_JSON_EXTENSIONS))
Expand All @@ -68,6 +72,10 @@ def decompose(self) -> list:
blocksize=self.blocksize,
file_extensions=self.file_extensions,
storage_options=self.storage_options,
enable_array_partitioning=self.enable_array_partitioning,
shard_index=self.shard_index,
total_shards=self.total_shards,
minimum_shard_index=self.minimum_shard_index,
),
InterleavedWebdatasetReaderStage(
read_kwargs=self.read_kwargs,
Expand Down Expand Up @@ -96,6 +104,10 @@ class InterleavedParquetReader(CompositeStage[_EmptyTask, InterleavedBatch]):
fields: tuple[str, ...] | None = None
max_batch_bytes: int | None = None
read_kwargs: dict[str, Any] = field(default_factory=dict)
enable_array_partitioning: bool = False
shard_index: int | str | None = None
total_shards: int | str | None = None
minimum_shard_index: int | str = 0
schema: pa.Schema | None = None
schema_overrides: dict[str, pa.DataType] | None = None
file_extensions: list[str] = field(default_factory=lambda: [".parquet"])
Expand All @@ -113,6 +125,10 @@ def decompose(self) -> list:
blocksize=self.blocksize,
file_extensions=self.file_extensions,
storage_options=self.storage_options,
enable_array_partitioning=self.enable_array_partitioning,
shard_index=self.shard_index,
total_shards=self.total_shards,
minimum_shard_index=self.minimum_shard_index,
),
InterleavedParquetReaderStage(
read_kwargs=self.read_kwargs,
Expand Down
2 changes: 1 addition & 1 deletion nemo_curator/stages/text/deduplication/removal_workflow.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
19 changes: 19 additions & 0 deletions nemo_curator/stages/text/deduplication/semantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,10 @@ class TextSemanticDeduplicationWorkflow:
embedding_vllm_init_kwargs: dict[str, Any] | None = None
hf_token: str | None = None
model_cache_dir: str | None = None
enable_array_partitioning: bool = False
Comment thread
sarahyurick marked this conversation as resolved.
Outdated
shard_index: int | str | None = None
total_shards: int | str | None = None
minimum_shard_index: int | str = 0
# Semantic deduplication parameters
n_clusters: int = 100
id_field: str = CURATOR_DEDUP_ID_STR
Expand Down Expand Up @@ -132,6 +136,13 @@ class TextSemanticDeduplicationWorkflow:
embedding_vllm_init_kwargs: Additional kwargs passed to vLLM's LLM initializer
hf_token: HuggingFace token for private models
model_cache_dir: Directory to cache model weights
enable_array_partitioning: Whether to enable array partitioning (e.g., partition files across multiple Slurm jobs).
shard_index: The index of the shard to process. Can be an integer representing the shard index or a string representing the environment variable name.
Only used if enable_array_partitioning is True. If not provided, it will be set to the value of the SLURM_ARRAY_TASK_ID environment variable.
total_shards: The total number of shards. Can be an integer representing the total number of shards or a string representing the environment variable name.
Only used if enable_array_partitioning is True. If not provided, it will be set to the value of the SLURM_ARRAY_TASK_COUNT environment variable.
minimum_shard_index: The minimum shard index to process. Can be an integer representing the minimum shard index or a string representing the environment variable name.
Only used if enable_array_partitioning is True. If not provided, it will be set to 0.

# Semantic deduplication parameters
n_clusters: Number of clusters for K-means
Expand Down Expand Up @@ -252,6 +263,10 @@ def _run_embedding_generation(self, executor: BaseExecutor) -> list[Task]:
file_extensions=self.input_file_extensions,
_generate_ids=self.use_id_generator,
read_kwargs=self.read_kwargs,
enable_array_partitioning=self.enable_array_partitioning,
shard_index=self.shard_index,
total_shards=self.total_shards,
minimum_shard_index=self.minimum_shard_index,
)
elif self.input_filetype == "parquet":
reader = ParquetReader(
Expand All @@ -266,6 +281,10 @@ def _run_embedding_generation(self, executor: BaseExecutor) -> list[Task]:
file_extensions=self.input_file_extensions,
read_kwargs=self.read_kwargs,
_generate_ids=self.use_id_generator,
enable_array_partitioning=self.enable_array_partitioning,
shard_index=self.shard_index,
total_shards=self.total_shards,
minimum_shard_index=self.minimum_shard_index,
)
else:
msg = f"Input filetype {self.input_filetype} not supported yet"
Expand Down
8 changes: 8 additions & 0 deletions nemo_curator/stages/text/io/reader/jsonl.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,10 @@ class JsonlReader(CompositeStage[_EmptyTask, DocumentBatch]):
blocksize: int | str | None = None
fields: list[str] | None = None # If specified, only read these columns
read_kwargs: dict[str, Any] | None = None
enable_array_partitioning: bool = False
shard_index: int | str | None = None
total_shards: int | str | None = None
minimum_shard_index: int | str = 0
task_type: Literal["document", "image", "video", "audio"] = "document"
file_extensions: list[str] = field(default_factory=lambda: FILETYPE_TO_DEFAULT_EXTENSIONS["jsonl"])
_generate_ids: bool = False
Expand Down Expand Up @@ -121,6 +125,10 @@ def decompose(self) -> list[JsonlReaderStage]:
storage_options=self.read_kwargs.get("storage_options", None)
if self.read_kwargs is not None
else None,
enable_array_partitioning=self.enable_array_partitioning,
shard_index=self.shard_index,
total_shards=self.total_shards,
minimum_shard_index=self.minimum_shard_index,
),
JsonlReaderStage(
fields=self.fields,
Expand Down
8 changes: 8 additions & 0 deletions nemo_curator/stages/text/io/reader/parquet.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,10 @@ class ParquetReader(CompositeStage[_EmptyTask, DocumentBatch]):
blocksize: int | str | None = None
fields: list[str] | None = None # If specified, only read these columns
read_kwargs: dict[str, Any] | None = None
enable_array_partitioning: bool = False
shard_index: int | str | None = None
total_shards: int | str | None = None
minimum_shard_index: int | str = 0
file_extensions: list[str] = field(default_factory=lambda: FILETYPE_TO_DEFAULT_EXTENSIONS["parquet"])
task_type: Literal["document", "image", "video", "audio"] = "document"
_generate_ids: bool = False
Expand All @@ -105,6 +109,10 @@ def decompose(self) -> list[ParquetReaderStage]:
blocksize=self.blocksize,
file_extensions=self.file_extensions,
storage_options=self.read_kwargs.get("storage_options", {}) if self.read_kwargs is not None else None,
enable_array_partitioning=self.enable_array_partitioning,
shard_index=self.shard_index,
total_shards=self.total_shards,
minimum_shard_index=self.minimum_shard_index,
),
# Second stage: process file groups into document batches
ParquetReaderStage(
Expand Down
20 changes: 19 additions & 1 deletion nemo_curator/stages/video/io/video_reader.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -244,11 +244,22 @@ class VideoReader(CompositeStage[_EmptyTask, VideoTask]):
input_video_path: Path to the directory containing video files
video_limit: Maximum number of videos to process (None for unlimited)
verbose: Whether to enable verbose logging during download/processing
enable_array_partitioning: Whether to enable array partitioning (e.g., partition files across multiple Slurm jobs).
shard_index: The index of the shard to process. Can be an integer representing the shard index or a string representing the environment variable name.
Only used if enable_array_partitioning is True. If not provided, it will be set to the value of the SLURM_ARRAY_TASK_ID environment variable.
total_shards: The total number of shards. Can be an integer representing the total number of shards or a string representing the environment variable name.
Only used if enable_array_partitioning is True. If not provided, it will be set to the value of the SLURM_ARRAY_TASK_COUNT environment variable.
minimum_shard_index: The minimum shard index to process. Can be an integer representing the minimum shard index or a string representing the environment variable name.
Only used if enable_array_partitioning is True. If not provided, it will be set to 0.
"""

input_video_path: str
video_limit: int | None = None
verbose: bool = False
enable_array_partitioning: bool = False
shard_index: int | str | None = None
total_shards: int | str | None = None
minimum_shard_index: int | str = 0

def __post_init__(self):
"""Initialize the parent CompositeStage after dataclass initialization."""
Expand Down Expand Up @@ -276,6 +287,9 @@ def decompose(self) -> list[ProcessingStage]:
List of processing stages: [FilePartitioningStage, VideoReaderStage]
"""
if is_remote_url(self.input_video_path):
if self.enable_array_partitioning:
msg = "enable_array_partitioning is not supported for ClientPartitioningStage"
raise NotImplementedError(msg)
reader_stage = ClientPartitioningStage(
file_paths=self.input_video_path,
files_per_partition=1,
Expand All @@ -288,6 +302,10 @@ def decompose(self) -> list[ProcessingStage]:
files_per_partition=1,
file_extensions=[".mp4", ".mov", ".avi", ".mkv", ".webm"],
limit=self.video_limit,
enable_array_partitioning=self.enable_array_partitioning,
shard_index=self.shard_index,
total_shards=self.total_shards,
minimum_shard_index=self.minimum_shard_index,
)

download_stage = VideoReaderStage(input_path=self.input_video_path, verbose=self.verbose)
Expand Down
Loading