Skip to content

Commit d35c8aa

Browse files
committed
minor cleaning
Signed-off-by: Sarah Yurick <sarahyurick@gmail.com>
1 parent 03421ea commit d35c8aa

11 files changed

Lines changed: 97 additions & 31 deletions

File tree

nemo_curator/backends/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ def process_batch(self, tasks: list[Task]) -> list[Task]:
110110
raise ValueError(msg)
111111

112112
# Record failed tasks for later inspection or retry bookkeeping.
113-
record_failed_tasks(self.stage.name, failed_tasks)
113+
record_failed_tasks(failed_tasks)
114114

115115
# Sentinels never propagate to the next stage.
116116
results = [r for r in results if not isinstance(r, (NoneTask, FailedTask))]

nemo_curator/backends/failed_task_markers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def configure_slurm_array_failed_task_manifest_dir(checkpoint_path: str | Path,
6262
return _configure_failed_task_manifest_dir(manifest_dir)
6363

6464

65-
def record_failed_tasks(_stage_name: str, failed_tasks: list[FailedTask]) -> None:
65+
def record_failed_tasks(failed_tasks: list[FailedTask]) -> None:
6666
"""Write one attempt-scoped manifest after any FailedTask is detected."""
6767
if not failed_tasks:
6868
return

nemo_curator/backends/slurm_array.py

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -39,23 +39,26 @@
3939

4040
def _get_int_env_var(env_var: str, fallback_name: str | None = None, default: int | None = None) -> int:
4141
"""Read an integer env var, with optional fallback/default."""
42+
resolved_var = env_var
4243
env_value = os.environ.get(env_var)
43-
if env_value is None:
44-
if fallback_name is not None:
45-
env_var = fallback_name
46-
env_value = os.environ.get(env_var)
44+
if env_value is None and fallback_name is not None:
45+
resolved_var = fallback_name
46+
env_value = os.environ.get(fallback_name)
4747

48-
if env_value is None:
49-
if default is not None:
50-
return default
48+
if env_value is None:
49+
if default is not None:
50+
return default
5151

52+
if fallback_name is not None:
53+
msg = f"Environment variable {env_var} (or {fallback_name}) is not set"
54+
else:
5255
msg = f"Environment variable {env_var} is not set"
53-
raise ValueError(msg)
56+
raise ValueError(msg)
5457

5558
try:
5659
return int(env_value)
5760
except ValueError as e:
58-
msg = f"Environment variable {env_var} must contain an integer, got {env_value!r}"
61+
msg = f"Environment variable {resolved_var} must contain an integer, got {env_value!r}"
5962
raise ValueError(msg) from e
6063

6164

@@ -195,7 +198,13 @@ def filter_slurm_array_source_tasks(
195198

196199

197200
def is_slurm_array_driver_process(use_slurm: bool) -> bool:
198-
"""Return true for the process that owns retry metadata."""
201+
"""Return true for the process that owns retry metadata.
202+
203+
When ``use_slurm`` is False (local or single-node) every process is the
204+
driver. When ``use_slurm`` is True, only the Slurm head node
205+
(``SLURM_NODEID == 0``) is the driver; if the variable is absent (e.g.
206+
when the process is launched outside of a Slurm step) the process is treated as the head.
207+
"""
199208
return not use_slurm or os.environ.get("SLURM_NODEID", "0") == "0"
200209

201210

nemo_curator/utils/atomic_io.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,14 @@ def fsync_directory(path: Path) -> None:
3232
os.close(dir_fd)
3333

3434

35+
def _unlink_best_effort(path: Path) -> None:
36+
"""Remove a temporary file without masking the primary result."""
37+
try:
38+
path.unlink(missing_ok=True)
39+
except OSError:
40+
pass
41+
42+
3543
def _write_json_temp_file(
3644
path: Path,
3745
payload: Any, # noqa: ANN401
@@ -57,10 +65,10 @@ def _write_json_temp_file(
5765
tmp.write("\n")
5866
tmp.flush()
5967
os.fsync(tmp.fileno())
60-
return Path(tmp.name)
68+
return tmp_path
6169
except Exception:
6270
if tmp_path is not None:
63-
tmp_path.unlink(missing_ok=True)
71+
_unlink_best_effort(tmp_path)
6472
raise
6573

6674

@@ -97,7 +105,7 @@ def write_json_atomically(
97105
os.replace(tmp_path, path)
98106
_fsync_directory_best_effort(path.parent)
99107
except Exception:
100-
tmp_path.unlink(missing_ok=True)
108+
_unlink_best_effort(tmp_path)
101109
raise
102110

103111

@@ -120,9 +128,12 @@ def write_json_atomically_if_absent(
120128
try:
121129
os.link(tmp_path, path)
122130
except FileExistsError:
131+
_unlink_best_effort(tmp_path)
123132
return False
124-
finally:
125-
tmp_path.unlink(missing_ok=True)
133+
except Exception:
134+
_unlink_best_effort(tmp_path)
135+
raise
126136

137+
_unlink_best_effort(tmp_path)
127138
_fsync_directory_best_effort(path.parent)
128139
return True

tests/backends/test_base_stage_adapter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def filter_tasks(
104104
calls["filter_tasks"] += 1
105105
return tasks
106106

107-
def record_failed_tasks(_stage_name: str, _failed_tasks: list[FailedTask]) -> None:
107+
def record_failed_tasks(_failed_tasks: list[FailedTask]) -> None:
108108
calls["record_failed_tasks"] += 1
109109

110110
monkeypatch.setattr(base_module, "resolve_slurm_array_config", resolve_config)

tests/backends/test_failed_task_markers.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def test_record_failed_tasks_writes_single_manifest(
8383
manifest_dir = tmp_path / "failed-tasks"
8484
monkeypatch.setenv(FAILED_TASKS_DIR_ENV_VAR, str(manifest_dir))
8585

86-
record_failed_tasks("failed", [_failed_task("0_7_0"), _failed_task("0_8_0")])
86+
record_failed_tasks([_failed_task("0_7_0"), _failed_task("0_8_0")])
8787

8888
manifest_files = list(manifest_dir.glob("*.json"))
8989
assert manifest_files == [manifest_dir / FAILED_TASK_MANIFEST_FILENAME]
@@ -95,11 +95,11 @@ def test_additional_failed_tasks_leave_existing_manifest_unchanged(
9595
) -> None:
9696
manifest_dir = tmp_path / "failed-tasks"
9797
monkeypatch.setenv(FAILED_TASKS_DIR_ENV_VAR, str(manifest_dir))
98-
record_failed_tasks("stage-a", [_failed_task("0_7_0")])
98+
record_failed_tasks([_failed_task("0_7_0")])
9999
manifest_file = manifest_dir / FAILED_TASK_MANIFEST_FILENAME
100100
original_manifest = manifest_file.read_text()
101101

102-
record_failed_tasks("stage-b", [_failed_task("0_8_0")])
102+
record_failed_tasks([_failed_task("0_8_0")])
103103

104104
assert list(manifest_dir.glob("*.json")) == [manifest_file]
105105
assert manifest_file.read_text() == original_manifest
@@ -110,7 +110,7 @@ def test_record_failed_tasks_without_configured_attempt_is_noop(
110110
monkeypatch.delenv(FAILED_TASKS_DIR_ENV_VAR, raising=False)
111111
monkeypatch.chdir(tmp_path)
112112

113-
record_failed_tasks("failed", [_failed_task()])
113+
record_failed_tasks([_failed_task()])
114114

115115
assert not (tmp_path / ".nemo_curator_metadata").exists()
116116

@@ -120,7 +120,7 @@ def test_record_failed_tasks_does_not_write_manifest_for_empty_list(
120120
manifest_dir = tmp_path / "failed-tasks"
121121
monkeypatch.setenv(FAILED_TASKS_DIR_ENV_VAR, str(manifest_dir))
122122

123-
record_failed_tasks("failed", [])
123+
record_failed_tasks([])
124124

125125
assert not manifest_dir.exists()
126126
assert not failed_task_manifest_exists()
@@ -137,7 +137,7 @@ def fail_write(*_args: object, **_kwargs: object) -> None:
137137
monkeypatch.setattr(failed_task_markers_module, "write_json_atomically", fail_write)
138138

139139
with pytest.raises(OSError, match="storage unavailable"):
140-
record_failed_tasks("failed", [_failed_task()])
140+
record_failed_tasks([_failed_task()])
141141

142142
def test_failed_task_manifest_exists_accepts_explicit_directory(self, tmp_path: Path) -> None:
143143
manifest_dir = tmp_path / "failed-tasks"

tests/backends/test_slurm_array.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
import nemo_curator.backends.slurm_array as slurm_array_module
2323
from nemo_curator.backends.slurm_array import (
24+
SLURM_ARRAY_COMPLETION_MANIFEST_NAMESPACE,
2425
SLURM_ARRAY_ENABLED_ENV_VAR,
2526
SLURM_ARRAY_MINIMUM_SHARD_INDEX_ENV_VAR,
2627
SLURM_ARRAY_SHARD_INDEX_ENV_VAR,
@@ -205,6 +206,11 @@ def test_filtering_is_disabled_without_config(self) -> None:
205206

206207
assert filter_slurm_array_source_tasks(tasks, None, "source") == tasks
207208

209+
def test_filtering_returns_empty_for_empty_input(self) -> None:
210+
slurm_array = SlurmArrayConfig(shard_index=0, total_shards=3)
211+
212+
assert filter_slurm_array_source_tasks([], slurm_array, "source") == []
213+
208214
def test_assigns_each_source_task_to_one_shard(self) -> None:
209215
tasks = [_task(f"0_{i}") for i in range(8)]
210216
assigned_task_ids = []
@@ -256,6 +262,7 @@ def test_is_driver_process_for_local_and_slurm_head(self, monkeypatch: MonkeyPat
256262
monkeypatch.setenv("SLURM_NODEID", "1")
257263

258264
assert is_slurm_array_driver_process(use_slurm=True) is False
265+
assert is_slurm_array_driver_process(use_slurm=False) is True
259266

260267
def test_build_completion_manifest_writes_run_config_and_shard_identity(self, tmp_path: Path) -> None:
261268
manifest = build_slurm_array_completion_manifest(
@@ -335,6 +342,27 @@ def test_find_retries_returns_empty_plan_when_all_shards_completed(self, tmp_pat
335342
def test_find_retries_returns_none_without_run_config(self, tmp_path: Path) -> None:
336343
assert find_slurm_array_retries(tmp_path) is None
337344

345+
def test_find_retries_rejects_zero_total_shards_in_run_config(self, tmp_path: Path) -> None:
346+
manifest_dir = tmp_path / METADATA_DIRNAME / ".slurm_array_completion"
347+
manifest_dir.mkdir(parents=True)
348+
(manifest_dir / "run.json").write_text('{"minimum_shard_index":0,"total_shards":0}\n')
349+
350+
with pytest.raises(ValueError, match="total_shards greater than 0"):
351+
find_slurm_array_retries(tmp_path)
352+
353+
def test_find_retries_rejects_out_of_range_completed_shard(self, tmp_path: Path) -> None:
354+
manifest_dir = tmp_path / METADATA_DIRNAME / ".slurm_array_completion"
355+
manifest_dir.mkdir(parents=True)
356+
(manifest_dir / "run.json").write_text('{"minimum_shard_index":1,"total_shards":3}\n')
357+
# shard_index=5 is outside the valid range [1, 3]
358+
filename = f"completed_{SLURM_ARRAY_COMPLETION_MANIFEST_NAMESPACE}_bad.json"
359+
(manifest_dir / filename).write_text(
360+
'{"minimum_shard_index":1,"shard_index":5,"status":"completed","total_shards":3}'
361+
)
362+
363+
with pytest.raises(ValueError, match="outside the original shard range"):
364+
find_slurm_array_retries(tmp_path)
365+
338366
def test_find_retries_rejects_negative_minimum_in_run_config(self, tmp_path: Path) -> None:
339367
manifest_dir = tmp_path / METADATA_DIRNAME / ".slurm_array_completion"
340368
manifest_dir.mkdir(parents=True)

tests/utils/test_atomic_io.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,3 +61,19 @@ def test_write_json_atomically_if_absent_does_not_replace_existing_file(self, tm
6161

6262
assert json.loads(output_path.read_text()) == {"writer": 1}
6363
assert not list(output_path.parent.glob(f".{output_path.name}.*.tmp"))
64+
65+
def test_write_json_atomically_if_absent_does_not_fail_after_commit_when_cleanup_fails(
66+
self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch
67+
) -> None:
68+
output_path = tmp_path / "payload.json"
69+
original_unlink = Path.unlink
70+
71+
def fail_temp_cleanup(path: Path, *, missing_ok: bool = False) -> None:
72+
if path.suffix == ".tmp":
73+
raise OSError("cleanup unavailable")
74+
original_unlink(path, missing_ok=missing_ok)
75+
76+
monkeypatch.setattr(Path, "unlink", fail_temp_cleanup)
77+
78+
assert write_json_atomically_if_absent(output_path, {"writer": 1}) is True
79+
assert json.loads(output_path.read_text()) == {"writer": 1}

tutorials/slurm/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -342,7 +342,7 @@ python tutorials/slurm/retry_array.py \
342342
--format fields
343343
```
344344

345-
For example, this output means that shards `1`, `2`, `5` through `10`, and `99` should be retried with offset `0`, using the original logical shard range of `0` through `99`:
345+
Without `--max-array-size`, the output contains at most one line, because all missing logical shards are emitted as a single submission. For example, this output means that shards `1`, `2`, `5` through `10`, and `99` should be retried with offset `0`, using the original logical shard range of `0` through `99`:
346346

347347
```text
348348
1-2,5-10,99 0 0 100

tutorials/slurm/array_pipeline.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,9 @@ def main() -> None:
157157

158158
pipeline.run()
159159

160+
# Check unconditionally: the user may have set NEMO_CURATOR_FAILED_TASKS_DIR
161+
# manually even without --checkpoint-path, and the function safely returns
162+
# False when no directory is configured.
160163
if failed_task_manifest_exists():
161164
logger.warning(
162165
"Pipeline completed without raising, but a FailedTask manifest exists. "

0 commit comments

Comments
 (0)