|
21 | 21 |
|
22 | 22 | import nemo_curator.backends.slurm_array as slurm_array_module |
23 | 23 | from nemo_curator.backends.slurm_array import ( |
| 24 | + SLURM_ARRAY_COMPLETION_MANIFEST_NAMESPACE, |
24 | 25 | SLURM_ARRAY_ENABLED_ENV_VAR, |
25 | 26 | SLURM_ARRAY_MINIMUM_SHARD_INDEX_ENV_VAR, |
26 | 27 | SLURM_ARRAY_SHARD_INDEX_ENV_VAR, |
@@ -205,6 +206,11 @@ def test_filtering_is_disabled_without_config(self) -> None: |
205 | 206 |
|
206 | 207 | assert filter_slurm_array_source_tasks(tasks, None, "source") == tasks |
207 | 208 |
|
| 209 | + def test_filtering_returns_empty_for_empty_input(self) -> None: |
| 210 | + slurm_array = SlurmArrayConfig(shard_index=0, total_shards=3) |
| 211 | + |
| 212 | + assert filter_slurm_array_source_tasks([], slurm_array, "source") == [] |
| 213 | + |
208 | 214 | def test_assigns_each_source_task_to_one_shard(self) -> None: |
209 | 215 | tasks = [_task(f"0_{i}") for i in range(8)] |
210 | 216 | assigned_task_ids = [] |
@@ -256,6 +262,7 @@ def test_is_driver_process_for_local_and_slurm_head(self, monkeypatch: MonkeyPat |
256 | 262 | monkeypatch.setenv("SLURM_NODEID", "1") |
257 | 263 |
|
258 | 264 | assert is_slurm_array_driver_process(use_slurm=True) is False |
| 265 | + assert is_slurm_array_driver_process(use_slurm=False) is True |
259 | 266 |
|
260 | 267 | def test_build_completion_manifest_writes_run_config_and_shard_identity(self, tmp_path: Path) -> None: |
261 | 268 | manifest = build_slurm_array_completion_manifest( |
@@ -335,6 +342,27 @@ def test_find_retries_returns_empty_plan_when_all_shards_completed(self, tmp_pat |
335 | 342 | def test_find_retries_returns_none_without_run_config(self, tmp_path: Path) -> None: |
336 | 343 | assert find_slurm_array_retries(tmp_path) is None |
337 | 344 |
|
| 345 | + def test_find_retries_rejects_zero_total_shards_in_run_config(self, tmp_path: Path) -> None: |
| 346 | + manifest_dir = tmp_path / METADATA_DIRNAME / ".slurm_array_completion" |
| 347 | + manifest_dir.mkdir(parents=True) |
| 348 | + (manifest_dir / "run.json").write_text('{"minimum_shard_index":0,"total_shards":0}\n') |
| 349 | + |
| 350 | + with pytest.raises(ValueError, match="total_shards greater than 0"): |
| 351 | + find_slurm_array_retries(tmp_path) |
| 352 | + |
| 353 | + def test_find_retries_rejects_out_of_range_completed_shard(self, tmp_path: Path) -> None: |
| 354 | + manifest_dir = tmp_path / METADATA_DIRNAME / ".slurm_array_completion" |
| 355 | + manifest_dir.mkdir(parents=True) |
| 356 | + (manifest_dir / "run.json").write_text('{"minimum_shard_index":1,"total_shards":3}\n') |
| 357 | + # shard_index=5 is outside the valid range [1, 3] |
| 358 | + filename = f"completed_{SLURM_ARRAY_COMPLETION_MANIFEST_NAMESPACE}_bad.json" |
| 359 | + (manifest_dir / filename).write_text( |
| 360 | + '{"minimum_shard_index":1,"shard_index":5,"status":"completed","total_shards":3}' |
| 361 | + ) |
| 362 | + |
| 363 | + with pytest.raises(ValueError, match="outside the original shard range"): |
| 364 | + find_slurm_array_retries(tmp_path) |
| 365 | + |
338 | 366 | def test_find_retries_rejects_negative_minimum_in_run_config(self, tmp_path: Path) -> None: |
339 | 367 | manifest_dir = tmp_path / METADATA_DIRNAME / ".slurm_array_completion" |
340 | 368 | manifest_dir.mkdir(parents=True) |
|
0 commit comments