-
Notifications
You must be signed in to change notification settings - Fork 293
Add support for Slurm arrays #2059
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
bde2217
a0595f6
43ee179
cae17b3
acfeceb
6eaf95e
bb1e30a
2ccbd3f
1b659ea
3522809
717edac
8f2345b
ebba73e
5e58793
ad8f68a
437270f
b55ec47
672f3d2
2814eec
4dacd31
d54d3a3
5e18093
94312c4
298a632
4d5d20a
e0a8af5
62a4493
11c74cc
67e0eef
3ee13ff
14ddc3c
f024975
ebc93dc
3fc2dbc
be338ea
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,80 @@ | ||
| # Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| import hashlib | ||
| import os | ||
| from pathlib import Path | ||
|
|
||
| from loguru import logger | ||
|
|
||
| from nemo_curator.tasks.sentinels import FailedTask | ||
| from nemo_curator.utils.atomic_io import write_json_atomically | ||
|
|
||
| FAILED_TASKS_DIR_ENV_VAR = "NEMO_CURATOR_FAILED_TASKS_DIR" | ||
| FAILED_TASK_MARKER_PATTERN = "failed_task_*.json" | ||
|
|
||
|
|
||
def _write_failed_task_marker(marker_dir: Path, stage_name: str, task: FailedTask) -> None:
    """Persist a small JSON marker identifying one failed (stage, task) pair.

    The marker file is named ``failed_task_<digest>.json``, where ``<digest>``
    is the first 16 hex characters of the SHA-256 of the stage name and the
    task ID joined by a NUL byte. The same pair therefore always maps to the
    same file name.

    Args:
        marker_dir: Directory that receives the marker file.
        stage_name: Name of the stage that produced the failed task.
        task: The failed task; only its ``task_id`` is recorded.
    """
    task_id = task.task_id
    # NUL separates the two fields inside the hashed identity.
    digest = hashlib.sha256(f"{stage_name}\0{task_id}".encode()).hexdigest()
    write_json_atomically(
        marker_dir / f"failed_task_{digest[:16]}.json",
        {"stage_name": stage_name, "task_id": task_id},
        separators=(",", ":"),
        sort_keys=True,
    )
|
|
||
|
|
||
| def record_failed_tasks(stage_name: str, failed_tasks: list[FailedTask]) -> None: | ||
| """Record FailedTask markers when ``NEMO_CURATOR_FAILED_TASKS_DIR`` is set.""" | ||
| marker_dir = os.environ.get(FAILED_TASKS_DIR_ENV_VAR) | ||
| if not marker_dir or not failed_tasks: | ||
| return | ||
|
|
||
| marker_path = Path(marker_dir) | ||
| for task in failed_tasks: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Still not sure why we need a marker for every task in the failed tasks. Doesn't this make both writing and reading difficult? For each Slurm array idx, would it not help to just have a single marker saying this idx failed, and that's all? Also, can we move the "if failed task" check to the base.py code for readability? Just an if condition there, and then a call to this function.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah I can see your point. For me it seems easy enough to check if there is anything in the failed tasks directory to know whether or not to rerun. I thought it could be useful to understand how many failed tasks there were (1 vs 1 million) to maybe inform whether or not you even want to rerun it. (Since the files are tiny I am not worried about disk space.) And if each manifest saves which stage produced the failed task, then you know which stages you need to potentially update or edit before rerunning. But I am open to other approaches too! Let me know what you think. |
||
| try: | ||
| _write_failed_task_marker(marker_path, stage_name, task) | ||
| except Exception as e: # noqa: BLE001 | ||
| logger.warning(f"Failed to write FailedTask marker to {marker_path}: {e}") | ||
|
|
||
|
|
||
| def summarize_failed_task_markers( | ||
| marker_dir: str | Path | None = None, | ||
| ) -> dict[str, object]: | ||
| """Count FailedTask markers in ``marker_dir`` or the configured env dir.""" | ||
| resolved_marker_dir = marker_dir if marker_dir is not None else os.environ.get(FAILED_TASKS_DIR_ENV_VAR) | ||
| if not resolved_marker_dir: | ||
| return { | ||
| "failed_task_marker_count": 0, | ||
| } | ||
|
|
||
| marker_path = Path(resolved_marker_dir).absolute() | ||
| if not marker_path.exists(): | ||
| return { | ||
| "failed_task_marker_count": 0, | ||
| } | ||
|
|
||
| marker_count = 0 | ||
| for path in marker_path.glob(FAILED_TASK_MARKER_PATTERN): | ||
| if not path.is_file(): | ||
| continue | ||
| marker_count += 1 | ||
|
|
||
| return { | ||
| "failed_task_marker_count": marker_count, | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,197 @@ | ||
| # Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| import hashlib | ||
| import os | ||
| from dataclasses import dataclass | ||
|
|
||
| from loguru import logger | ||
|
|
||
| from nemo_curator.tasks import Task | ||
| from nemo_curator.tasks.sentinels import FailedTask | ||
| from nemo_curator.utils.retry_manifest import RetryManifest | ||
|
|
||
| SLURM_ARRAY_ENABLED_ENV_VAR = "NEMO_CURATOR_SLURM_ARRAY_ENABLED" | ||
| SLURM_ARRAY_SHARD_INDEX_ENV_VAR = "NEMO_CURATOR_SLURM_ARRAY_SHARD_INDEX" | ||
| SLURM_ARRAY_TOTAL_SHARDS_ENV_VAR = "NEMO_CURATOR_SLURM_ARRAY_TOTAL_SHARDS" | ||
| SLURM_ARRAY_MINIMUM_SHARD_INDEX_ENV_VAR = "NEMO_CURATOR_SLURM_ARRAY_MINIMUM_SHARD_INDEX" | ||
| SLURM_ARRAY_RETRY_MANIFEST_NAMESPACE = "slurm_array" | ||
| SLURM_ARRAY_RETRY_DIRNAME = ".slurm_array_retry" | ||
|
|
||
| _TRUE_ENV_VALUES = {"1", "true", "yes", "on"} | ||
|
|
||
|
|
||
| def _get_int_env_var(env_var: str, fallback_name: str | None = None, default: int | None = None) -> int: | ||
| """Read an integer env var, with optional fallback/default.""" | ||
| env_value = os.environ.get(env_var) | ||
| if env_value is None: | ||
| if fallback_name is not None: | ||
| env_var = fallback_name | ||
| env_value = os.environ.get(env_var) | ||
|
|
||
| if env_value is None: | ||
| if default is not None: | ||
| return default | ||
|
|
||
| msg = f"Environment variable {env_var} is not set" | ||
| raise ValueError(msg) | ||
|
|
||
| try: | ||
| return int(env_value) | ||
| except ValueError as e: | ||
| msg = f"Environment variable {env_var} must contain an integer, got {env_value!r}" | ||
| raise ValueError(msg) from e | ||
|
|
||
|
|
||
| @dataclass | ||
| class SlurmArrayConfig: | ||
| """Source-task sharding settings for one Slurm array task.""" | ||
|
|
||
| shard_index: int | ||
| total_shards: int | ||
| minimum_shard_index: int = 0 | ||
|
|
||
| @classmethod | ||
| def from_env(cls) -> "SlurmArrayConfig | None": | ||
| """Build config from Curator env vars, falling back to Slurm env vars.""" | ||
| enabled = os.environ.get(SLURM_ARRAY_ENABLED_ENV_VAR, "") | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why are we gating on this? This would mean the user needs to enable this variable?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, see here too: #2059 (comment) IMO we need a way to explicitly opt into Slurm array partitioning, and this variable is how users can do it. |
||
| if enabled.strip().lower() not in _TRUE_ENV_VALUES: | ||
| return None | ||
|
|
||
| return cls( | ||
| shard_index=_get_int_env_var(SLURM_ARRAY_SHARD_INDEX_ENV_VAR, "SLURM_ARRAY_TASK_ID"), | ||
| total_shards=_get_int_env_var(SLURM_ARRAY_TOTAL_SHARDS_ENV_VAR, "SLURM_ARRAY_TASK_COUNT"), | ||
| minimum_shard_index=_get_int_env_var(SLURM_ARRAY_MINIMUM_SHARD_INDEX_ENV_VAR, default=0), | ||
| ) | ||
|
|
||
|
|
||
def configure_slurm_array_source_filtering(
    shard_index: int,
    total_shards: int,
    minimum_shard_index: int,
) -> None:
    """Export the Slurm array sharding settings as environment variables.

    Sets the enabled flag to ``"1"`` and stores the three integers as strings
    in this process's environment.

    Args:
        shard_index: Index of the shard handled by this array task.
        total_shards: Number of shards the source tasks are split into.
        minimum_shard_index: Lowest shard index that can be assigned.
    """
    os.environ.update(
        {
            SLURM_ARRAY_ENABLED_ENV_VAR: "1",
            SLURM_ARRAY_SHARD_INDEX_ENV_VAR: str(shard_index),
            SLURM_ARRAY_TOTAL_SHARDS_ENV_VAR: str(total_shards),
            SLURM_ARRAY_MINIMUM_SHARD_INDEX_ENV_VAR: str(minimum_shard_index),
        }
    )
|
|
||
|
|
||
def resolve_slurm_array_config(is_source_stage: bool) -> SlurmArrayConfig | None:
    """Return the Slurm array config for a source stage, validating it first.

    Args:
        is_source_stage: Whether the calling stage is a source stage. For any
            other stage the environment is not consulted.

    Returns:
        The config built by ``SlurmArrayConfig.from_env()``, or ``None`` when
        the stage is not a source stage or ``from_env()`` returns ``None``.

    Raises:
        ValueError: If the resolved ``total_shards`` is not positive.
    """
    config = SlurmArrayConfig.from_env() if is_source_stage else None
    if config is None:
        return None

    if config.total_shards <= 0:
        msg = f"total_shards must be greater than 0, got {config.total_shards}"
        raise ValueError(msg)

    lowest = config.minimum_shard_index
    highest = lowest + config.total_shards - 1
    # An out-of-range shard index is only reported, not rejected.
    if config.shard_index < lowest or config.shard_index > highest:
        logger.warning(
            "shard_index={} is outside the assignable shard range [{}, {}]. "
            "This task will not receive any source tasks.",
            config.shard_index,
            lowest,
            highest,
        )

    return config
|
|
||
|
|
||
def slurm_array_shard_for_task(task: Task, slurm_array: SlurmArrayConfig) -> int:
    """Map a task to a shard index by hashing its task ID.

    The first 8 bytes of the SHA-256 digest of ``task.task_id`` are read as a
    big-endian integer, reduced modulo ``total_shards``, and offset by
    ``minimum_shard_index``.

    Args:
        task: Task whose ``task_id`` is hashed.
        slurm_array: Supplies ``total_shards`` and ``minimum_shard_index``.

    Returns:
        A shard index between ``minimum_shard_index`` and
        ``minimum_shard_index + total_shards - 1``, inclusive.
    """
    digest_bytes = hashlib.sha256(task.task_id.encode("utf-8")).digest()
    # The leading 8 bytes are the same value as the first 16 hex characters.
    bucket = int.from_bytes(digest_bytes[:8], "big") % slurm_array.total_shards
    return bucket + slurm_array.minimum_shard_index
|
|
||
|
|
||
| def filter_slurm_array_source_tasks( | ||
| tasks: list[Task], | ||
| slurm_array: SlurmArrayConfig | None, | ||
| stage_name: str, | ||
| ) -> list[Task]: | ||
| """Keep only source tasks assigned to the active Slurm array shard.""" | ||
| if slurm_array is None: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Can we move this into a single if in base.py? Currently, I feel the code is much less readable: if I'm not using Slurm arrays, I still go through a ton of functions that just return.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This if statement is just acting as an additional barrier to ensure we are only using this function if Slurm array partitioning has been enabled. I can add more comments to base.py if you think it can help? I mainly want to avoid having base.py reason about Slurm array partitioning at all in favor of delegating the work to these functions here, so that base.py does not get clobbered up with a bunch of Slurm-specific things.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Oops misread the comment I think. Sure I can move the if statement only to base.py, to make it more readable. |
||
| return tasks | ||
|
|
||
| nondeterministic_task_ids = [task.task_id for task in tasks if task.task_id.startswith("r")] | ||
| if nondeterministic_task_ids: | ||
| msg = ( | ||
| "Slurm array source filtering requires deterministic task IDs, but stage " | ||
| f"{stage_name} emitted ambiguous source task IDs: {nondeterministic_task_ids[:5]}" | ||
| ) | ||
| raise ValueError(msg) | ||
|
|
||
| assigned_tasks = [ | ||
| task for task in tasks if slurm_array_shard_for_task(task, slurm_array) == slurm_array.shard_index | ||
| ] | ||
|
|
||
| msg = ( | ||
| f"Slurm array shard {slurm_array.shard_index}/{slurm_array.total_shards}: " | ||
| f"assigned {len(assigned_tasks)} of {len(tasks)} source tasks for stage {stage_name}" | ||
| ) | ||
| if len(assigned_tasks) == 0 and len(tasks) > 0: | ||
| logger.warning(msg) | ||
| else: | ||
| logger.info(msg) | ||
|
|
||
| return assigned_tasks | ||
|
|
||
|
|
||
| def raise_for_failed_source_tasks_with_slurm_array( | ||
| stage_name: str, | ||
| failed_tasks: list[FailedTask], | ||
| slurm_array: SlurmArrayConfig | None, | ||
| ) -> None: | ||
| """Reject source-stage FailedTasks, which cannot be retried by shard.""" | ||
| if failed_tasks and slurm_array is not None: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why do we need this? Can't this just mean we ignore this task? For example, if we have [task_1, task_2, failed, task_4] and a Slurm array of 2, can't we assign task_1 and task_2 to one array task and task_4 to another? Also, why not treat None the same way then? I think we also need to think about this for resumability. For now, if a task fails at the source stage, I assume it needs to be rerun. I'm OK with having a contract that source stages cannot emit failed or None tasks, but it should then be global and not just for Slurm arrays.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I added this because I am worried about there being edge cases when a source stage produces failed tasks, specifically with Slurm array partitioning enabled. Like with the FilePartitioningStage, the list of file names is used to generate the deterministic IDs, but if there is a failed task then won't that affect how a deterministic ID is generated and thus how it is bucketed into a Slurm array index? So would this mess up which Slurm index it is assigned to upon retry? I haven't thought very thoroughly about the above example so maybe it is not a big concern. But since Slurm array partitioning is supposed to work with any source stage, I wanted to add this restriction until we could be more confident here. Wdyt?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If you want me to make it more general and crash if the source stage emits any failed tasks (regardless of if we are doing Slurm array partitioning), then I can do that. |
||
| msg = ( | ||
| f"Source stage {stage_name} emitted FailedTask while Slurm array filtering is enabled. " | ||
| "This is not supported because the failed source task cannot be assigned to a retry shard " | ||
| "reliably. Raise an exception from the source stage instead." | ||
| ) | ||
| raise ValueError(msg) | ||
|
|
||
|
|
||
def is_slurm_array_driver_process(use_slurm: bool) -> bool:
    """Report whether this process is the one that owns retry metadata.

    Args:
        use_slurm: Whether the run is launched under Slurm.

    Returns:
        ``True`` when ``use_slurm`` is false, or when the ``SLURM_NODEID``
        environment variable is ``"0"`` (an unset variable counts as ``"0"``).
    """
    if not use_slurm:
        return True
    return os.environ.get("SLURM_NODEID", "0") == "0"
|
|
||
|
|
||
def build_slurm_array_retry_manifest(
    checkpoint_path: str | None,
    shard_index: int,
    total_shards: int,
    minimum_shard_index: int,
) -> RetryManifest | None:
    """Build the retry manifest for one Slurm array shard.

    Args:
        checkpoint_path: Checkpoint location handed to ``RetryManifest``.
        shard_index: Index of the shard handled by this array task.
        total_shards: Number of shards the source tasks are split into.
        minimum_shard_index: Lowest shard index that can be assigned.

    Returns:
        A ``RetryManifest`` whose identity holds the three shard settings, or
        ``None`` when ``checkpoint_path`` is ``None``.
    """
    if checkpoint_path is not None:
        shard_identity = {
            "minimum_shard_index": minimum_shard_index,
            "shard_index": shard_index,
            "total_shards": total_shards,
        }
        return RetryManifest(
            checkpoint_path=checkpoint_path,
            namespace=SLURM_ARRAY_RETRY_MANIFEST_NAMESPACE,
            retry_dirname=SLURM_ARRAY_RETRY_DIRNAME,
            identity=shard_identity,
            flatten_identity=True,
        )

    # No checkpoint path was given, so no manifest is built.
    return None
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We might have a merge conflict on lines 99 and 113, but I think the intention is the same, so whichever of us merges later can resolve it.