diff --git a/nemo_curator/backends/base.py b/nemo_curator/backends/base.py index 94236f2cfc..ec1c7e591f 100644 --- a/nemo_curator/backends/base.py +++ b/nemo_curator/backends/base.py @@ -17,14 +17,23 @@ from dataclasses import dataclass from typing import TYPE_CHECKING, Any +from loguru import logger + from nemo_curator.core.utils import ignore_ray_head_node from nemo_curator.tasks import Task +from nemo_curator.tasks.sentinels import FailedTask, NoneTask from nemo_curator.utils.performance_utils import StageTimer +from nemo_curator.utils.resumability_client import _flush_deltas, _is_active, _skip_completed_sources if TYPE_CHECKING: from nemo_curator.stages.base import ProcessingStage +def _is_sentinel(task: Task) -> bool: + """A payload-less marker (NoneTask/FailedTask), stripped before the next stage.""" + return isinstance(task, (NoneTask, FailedTask)) + + @dataclass class NodeInfo: """Generic node information for setup_on_node calls across backends. @@ -85,9 +94,20 @@ def process_batch(self, tasks: list[Task]) -> list[Task]: # Use the batch processing logic results = self.stage.process_batch(tasks) - # Guarantee every emitted task has a task_id (derived id, or uuid fallback). + # Replace a returned None ("filter this slot") with a NoneTask so every + # output gets a task_id; sentinels are stripped again below. + results = [NoneTask() if r is None else r for r in results] + + # Assign every emitted task a task_id (derived, or uuid fallback). results = self._post_process_task_ids(tasks, results) + # Opt-in resumability: fire per-source deltas (no-op when no actor registered). + if _is_active(): + results = self._apply_resumability_counters(tasks, results) + + # Sentinels never propagate to the next stage. + results = [r for r in results if not _is_sentinel(r)] + # Log performance stats and add to result tasks _, stage_perf_stats = self._timer.log_stats() # Consume and attach any custom metrics recorded by the stage during this call @@ -100,41 +120,22 @@ def process_batch(self, tasks: list[Task]) -> list[Task]: return results def _post_process_task_ids(self, input_tasks: list[Task], output_tasks: list[Task | None]) -> list[Task]: - """Assign a deterministic ``task_id`` to every emitted task. - - This is the single place task ids are assigned — it runs for every - stage on every backend (all backend adapters subclass this), so it - makes no difference whether a stage defines ``process`` or overrides - ``process_batch``. ``task_id`` is the task's id path (parents + own segment); ids are - re-derived at each stage boundary so the same object passing through - N stages gets N ids. - - The input→output mapping decides each output's PARENT; whether the - stage is a source decides each output's SEGMENT (content id vs index) - — the two are independent. ``None`` outputs (Curator's "return None to - filter") are NOT removed before the length check — keeping them in - place preserves positional alignment for filter stages — and are then - dropped from the returned list. - - - single input → every output is its child (fan-out): ``parent_`` - - ``len(output) == len(input)`` → positional 1:1: each ``parent_i_``; - a ``None`` slot just means input ``i`` was filtered. - - any other (ambiguous) cardinality across a batch → a random ``uuid`` - prefixed with ``"r"`` (e.g. ``"r3f9a…"``), so ``task_id`` is never - empty even when a derived id is not possible. The ``"r"`` prefix flags - the id as non-deterministic / ancestry-not-tracked (see - ``Task.task_id`` docstring). - - ``seg`` is the output's content id (``Task.get_deterministic_id()``) - for a source stage when available, else the positional index — so a - source partition keeps a stable id across reorderings regardless of - whether the source is 1→N or N→N. - - Note: a stage that BOTH filters and fans out within a single batch - (returning a flat list rather than a per-input slot) cannot be mapped - positionally; if its length happens to equal the input length the 1:1 - assumption may misattribute parents. That combination is unsupported - until per-slot sentinels (NoneTask/FailedTask) land in a later PR. + """Assign a deterministic ``task_id`` (parent id + own segment) to every + emitted task. Runs once per stage on every backend, so ``process`` vs + ``process_batch`` makes no difference; ids are re-derived at each stage + boundary, so one object passing through N stages gets N ids. + + - single input → fan-out: each output is ``parent_`` + - ``len(output) == len(input)`` → positional 1:1: ``parent_i_``; a + ``None`` slot means input ``i`` was filtered (kept for alignment, then + dropped from the result) + - any other cardinality → a random ``"r"``-prefixed uuid (non-deterministic, + ancestry-not-tracked; see ``Task.task_id``) + + ``seg`` is the content id (``get_deterministic_id()``) for a source stage, + else the positional index. A stage that both filters and fans out in one + batch can't be mapped positionally and falls to the ``"r"`` case — return + one value (or ``None``) per input to stay positional. """ is_source = getattr(self.stage, "is_source_stage", False) @@ -168,6 +169,85 @@ def _post_process_task_ids(self, input_tasks: list[Task], output_tasks: list[Tas task.task_id = "r" + uuid.uuid4().hex return out + # Resumability (opt-in): stamp _source_id, fire per-source deltas, drop + # completed sources. task_ids are already assigned; sentinels stripped by caller. + def _apply_resumability_counters(self, input_tasks: list[Task], output_tasks: list[Task]) -> list[Task]: # noqa: C901 + # Dedup key is always an OUTPUT task_id, never the input's: the source + # already keyed its +1 on that id, and an output id is one level deeper, + # so it's unique to the (task, stage) that produced it. + stage = self.stage + if getattr(stage, "is_source_stage", False): + return self._source_counters(output_tasks) + + # No outputs to key on (filtering uses None->NoneTask, so this is degenerate): skip. + if not output_tasks: + return output_tasks + + # Pre-source: inputs have no _source_id yet; nothing to track. + if all(not t._source_id for t in input_tasks): + return output_tasks + + is_sink = stage.is_sink_stage + per_task: list[tuple[str, str, int]] = [] + real = [t for t in output_tasks if not _is_sentinel(t)] + + if len(input_tasks) == 1 and len(output_tasks) != 1: + # Fan-out (1->N): parent consumed (-1); each real child continues + # (+1, or 0 at a sink); each FailedTask keeps the source open (+1); + # NoneTask contributes 0. + parent = input_tasks[0] + n_failed = sum(1 for t in output_tasks if isinstance(t, FailedTask)) + continuing = 0 if is_sink else len(real) + delta = continuing + n_failed - 1 + # Key on output[0].task_id (not parent.task_id, which collides with the + # source's +1). Non-source children are indexed positionally, so + # output[0] is always "_0". + per_task.append((output_tasks[0].task_id, parent._source_id, delta)) + for c in real: + if not c._source_id: + c._source_id = parent._source_id + elif len(output_tasks) == len(input_tasks): + # Positional 1:1; each delta keys on the output id (r.task_id). + for parent, r in zip(input_tasks, output_tasks, strict=True): + sid = parent._source_id + if isinstance(r, NoneTask): # filtered -> consumed + per_task.append((r.task_id, sid, -1)) + continue + if isinstance(r, FailedTask): # failed -> source stays open (no sink test) + per_task.append((r.task_id, sid, 0)) + continue + per_task.append((r.task_id, sid, -1 if is_sink else 0)) # real: sink -1, else 0 + if not r._source_id: + r._source_id = sid + else: + # M->K (M!=K): can't attribute parents; skip (source stays pending -> reprocessed). + logger.warning( + f"resumability: {type(stage).__name__} produced {len(output_tasks)} outputs " + f"for {len(input_tasks)} inputs; can't attribute sources, skipping counter " + f"update for this batch." + ) + return output_tasks + + _flush_deltas(per_task) + return output_tasks + + def _source_counters(self, output_tasks: list[Task]) -> list[Task]: + """Source stage: each output is a source partition; its ``_source_id`` is + its own last id segment. Drop already-completed sources; each survivor fires ``+1``.""" + sources = [t for t in output_tasks if not _is_sentinel(t)] + for t in sources: + t._source_id = t.task_id.rsplit("_", 1)[-1] + completed = _skip_completed_sources([t._source_id for t in sources]) + per_task: list[tuple[str, str, int]] = [] + survivors: list[Task] = [] + for t in sources: + if t._source_id in completed: + continue + per_task.append((t.task_id, t._source_id, +1)) + survivors.append(t) + _flush_deltas(per_task) + return survivors + def setup_on_node(self, node_info: NodeInfo | None = None, worker_metadata: WorkerMetadata | None = None) -> None: """Setup the stage on a node. diff --git a/nemo_curator/pipeline/pipeline.py b/nemo_curator/pipeline/pipeline.py index 961ae33c6f..3d36ec7a1c 100644 --- a/nemo_curator/pipeline/pipeline.py +++ b/nemo_curator/pipeline/pipeline.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from pathlib import Path from typing import Any from loguru import logger @@ -107,8 +108,9 @@ def build(self) -> None: # 3. Source / sink defaults: at most one stage may be explicitly # marked; if none, the first stage is the source and the last is # the sink. The source flag activates content-based ids in the - # default ``process_batch``; the sink flag is used by the - # resumability layer in a follow-up PR. + # default ``process_batch``; the sink flag tells the resumability + # counters that a sink consumes its outputs (see + # ``BaseStageAdapter._apply_resumability_counters``). self._assign_source_sink_roles() def _assign_source_sink_roles(self) -> None: @@ -222,18 +224,32 @@ def describe(self) -> str: return "\n".join(lines) - def run(self, executor: BaseExecutor | None = None, initial_tasks: list[Task] | None = None) -> list[Task] | None: + def run( + self, + executor: BaseExecutor | None = None, + initial_tasks: list[Task] | None = None, + checkpoint_path: str | Path | None = None, + ) -> list[Task] | None: """Run the pipeline. Args: executor (BaseExecutor): Executor to use initial_tasks (list[Task], optional): Initial tasks to start the pipeline with. Defaults to None. + checkpoint_path (str | Path, optional): Resumability directory. When + set, completed source partitions are tracked (in a + ``.nemo_curator_metadata`` subdir) and skipped on rerun. Multiple + runs (e.g. a SLURM array) may share the directory — each writes + its own LMDB file, so there is no contention. Returns: list[Task] | None: List of tasks """ self.build() + if checkpoint_path is not None: + checkpoint_path = Path(checkpoint_path).absolute() + checkpoint_path.mkdir(parents=True, exist_ok=True) + if executor is None: from nemo_curator.backends.xenna import XennaExecutor @@ -263,4 +279,41 @@ def run(self, executor: BaseExecutor | None = None, initial_tasks: list[Task] | if initial_tasks: assign_root_task_ids(initial_tasks) - return executor.execute(self.stages, initial_tasks) + if checkpoint_path is None: + return executor.execute(self.stages, initial_tasks) + return self._run_with_resumability(executor, initial_tasks, checkpoint_path) + + def _run_with_resumability( + self, + executor: BaseExecutor, + initial_tasks: list[Task] | None, + checkpoint_path: Path, + ) -> list[Task] | None: + """Own the resumability-actor lifecycle (executors unmodified): spawn it + ``lifetime="detached"`` so it survives executor-local ``ray.shutdown()``, + run, then close. The actor never raises, so there's no error path here.""" + import ray + + from nemo_curator.utils.resumability_actor import ResumabilityActor + from nemo_curator.utils.resumability_client import ACTOR_NAME + + ray.init(ignore_reinit_error=True) + ResumabilityActor.options( # type: ignore[attr-defined] + name=ACTOR_NAME, + lifetime="detached", + get_if_exists=True, + max_pending_calls=100, + ).remote(str(checkpoint_path)) + + try: + return executor.execute(self.stages, initial_tasks) + finally: + # The executor's ray.shutdown() may have run in its own + # finally:; reconnect to clean up the detached actor. + try: + ray.init(ignore_reinit_error=True) + actor_handle = ray.get_actor(ACTOR_NAME) + ray.get(actor_handle.close.remote(), timeout=10) # type: ignore[attr-defined] + ray.kill(actor_handle) + except Exception as e: # noqa: BLE001 + logger.warning(f"resumability actor cleanup failed: {e}") diff --git a/nemo_curator/tasks/__init__.py b/nemo_curator/tasks/__init__.py index bf8d32a150..d45f3b778d 100644 --- a/nemo_curator/tasks/__init__.py +++ b/nemo_curator/tasks/__init__.py @@ -17,17 +17,19 @@ from .file_group import FileGroupTask from .image import ImageBatch, ImageObject from .interleaved import InterleavedBatch -from .sentinels import EmptyTask, SentinelTask +from .sentinels import EmptyTask, FailedTask, NoneTask, SentinelTask from .tasks import Task __all__ = [ "AudioTask", "DocumentBatch", "EmptyTask", + "FailedTask", "FileGroupTask", "ImageBatch", "ImageObject", "InterleavedBatch", + "NoneTask", "SentinelTask", "Task", ] diff --git a/nemo_curator/tasks/sentinels.py b/nemo_curator/tasks/sentinels.py index 84896dd963..bedc5e8cd6 100644 --- a/nemo_curator/tasks/sentinels.py +++ b/nemo_curator/tasks/sentinels.py @@ -11,11 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Payload-less marker tasks. - -``EmptyTask`` seeds a pipeline (the implicit root id ``"0"``). All markers -share the :class:`SentinelTask` base and carry no payload (``data is None``). -Construct one with ``EmptyTask()``. +"""Payload-less marker tasks on a shared :class:`SentinelTask` base: +``EmptyTask`` (pipeline seed, root id ``"0"``), ``NoneTask`` (filtered slot; +counter decrements), ``FailedTask`` (failed slot; counter unchanged so its +source stays pending and reruns). All carry no data, get a framework-assigned +``task_id``, and are stripped before the next stage. """ from dataclasses import dataclass, field @@ -25,8 +25,7 @@ @dataclass class SentinelTask(Task[None]): - """Base for payload-less marker tasks. Always carries no data; ``task_id`` - is framework-assigned like any other task.""" + """Base for payload-less marker tasks: no data, framework-assigned ``task_id``.""" data: None = None @@ -44,11 +43,22 @@ def validate(self) -> bool: @dataclass class EmptyTask(SentinelTask): - """Payload-less task that seeds a pipeline. Its ``task_id`` is fixed to - ``"0"`` — the implicit root every task in a run descends from, so all - ``task_id``s share the ``"0"`` prefix (source partitions become - ``"0_"``, user-provided initial tasks become ``"0_0"``, ``"0_1"``, …). - """ + """Seeds a pipeline with ``task_id="0"`` — the implicit root every task + descends from (so all ids share the ``"0"`` prefix).""" dataset_name: str = "empty" task_id: str = field(init=False, default="0") + + +@dataclass +class NoneTask(SentinelTask): + """Marks a slot as intentionally filtered (resumability counter decrements).""" + + dataset_name: str = "none" + + +@dataclass +class FailedTask(SentinelTask): + """Marks a slot as failed → retried on resume (counter does NOT decrement).""" + + dataset_name: str = "failed" diff --git a/nemo_curator/tasks/tasks.py b/nemo_curator/tasks/tasks.py index 04bfb5caf0..e1200614fb 100644 --- a/nemo_curator/tasks/tasks.py +++ b/nemo_curator/tasks/tasks.py @@ -46,6 +46,9 @@ class Task(ABC, Generic[T]): NON-deterministic (differ across runs). dataset_name: Name of the dataset this task belongs to. _stage_perf: List of stages perfs this task has passed through. + _source_id: Source (input partition) this task descends from. Stamped at + the source stage, inherited downstream; used only by the opt-in + resumability layer. Empty for pre-source tasks. """ dataset_name: str @@ -53,6 +56,7 @@ class Task(ABC, Generic[T]): _stage_perf: list[StagePerfStats] = field(default_factory=list) _metadata: dict[str, Any] = field(default_factory=dict) task_id: str = field(init=False, default="") + _source_id: str = field(init=False, default="") def __post_init__(self) -> None: """Post-initialization hook.""" diff --git a/nemo_curator/utils/resumability_actor.py b/nemo_curator/utils/resumability_actor.py new file mode 100644 index 0000000000..32be4bbb9d --- /dev/null +++ b/nemo_curator/utils/resumability_actor.py @@ -0,0 +1,191 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Per-writer LMDB owner tracking per-source completion for resumability. + +LMDB can't be safely shared by writers across hosts (its lock lives in an +mmap'd file not shared on a networked FS), so each actor writes ONLY its own +``/-.mdb`` and on startup reads the UNION of completed sources +across every ``*.mdb`` in the dir. A rerun thus skips everything any prior +writer finished — letting the tasks of a SLURM array share one checkpoint dir. + +``apply_deltas`` is fire-and-forget and never raises; see its docstring for the +dedup/rewrite/anomaly rules. +""" + +from __future__ import annotations + +import os +import socket +from pathlib import Path +from typing import TYPE_CHECKING + +import lmdb +import ray +from loguru import logger + +if TYPE_CHECKING: + from collections.abc import Iterable + + +_COMPLETED_DB = b"completed_sources" +_DEFAULT_MAP_SIZE = 1 << 30 # 1 GiB; sparse on Linux so effectively free +# Subdirectory (under the user-provided checkpoint dir) that holds the +# per-writer LMDB files. Hidden so it sits unobtrusively next to outputs. +METADATA_DIRNAME = ".nemo_curator_metadata" + + +@ray.remote(num_cpus=0, max_concurrency=1) +class ResumabilityActor: + """Per-writer counter + LMDB owner. Spawned by ``Pipeline.run`` with + ``lifetime="detached"`` and closed at end-of-run; ``apply_deltas`` is + fire-and-forget and never raises.""" + + def __init__(self, base_dir: str, map_size: int = _DEFAULT_MAP_SIZE, writer_id: str | None = None): + # Per-writer LMDB files live under /.nemo_curator_metadata/. + self._dir = Path(base_dir).absolute() / METADATA_DIRNAME + self._dir.mkdir(parents=True, exist_ok=True) + # The ONLY file this actor writes, keyed by writer id (default host-pid, + # unique across concurrent writers; a pid-recycled rerun safely reuses it). + wid = writer_id or f"{socket.gethostname()}-{os.getpid()}" + self._path = str(self._dir / f"{wid}.mdb") + self._env = lmdb.open( + self._path, + subdir=False, + lock=False, # sole writer of this file → no inter-process lock needed + max_dbs=1, + map_size=map_size, + metasync=False, + sync=True, + readahead=False, + ) + self._db = self._env.open_db(_COMPLETED_DB) + self._pending: dict[str, int] = {} + # Union of completed sources across ALL writer files in the dir. + self._completed: set[str] = self._load_completed() + # task_id -> last delta applied: same delta = dedup skip; different = rewrite. + self._applied: dict[str, int] = {} + + def _read_completed_from(self, env: lmdb.Environment) -> set[str]: + """Completed-source ids from an open LMDB env (empty if it has no completed-sources db yet).""" + try: + db = env.open_db(_COMPLETED_DB) + except lmdb.Error: + return set() + with env.begin() as txn, txn.cursor(db=db) as cur: + return {k.decode() for k, _ in cur} + + def _load_completed(self) -> set[str]: + """Union of completed sources across all writer files; unreadable files + (mid-write, or open in-process during tests) are skipped with a warning.""" + done = self._read_completed_from(self._env) # our own (possibly reused) file + for mdb in sorted(self._dir.glob("*.mdb")): + if str(mdb) == self._path: + continue + try: + env = lmdb.open(str(mdb), subdir=False, readonly=True, lock=False, max_dbs=1) + except lmdb.Error as e: + logger.warning(f"resumability: skipping unreadable checkpoint {mdb}: {e}") + continue + try: + done |= self._read_completed_from(env) + finally: + env.close() + return done + + # ------------------------------------------------------------ read + + def are_completed(self, source_ids: list[str]) -> list[bool]: + """Parallel bool list: which source_ids are complete (skip on rerun).""" + return [sid in self._completed for sid in source_ids] + + # ------------------------------------------------------------ write + + def apply_deltas(self, per_task: list[tuple[str, str, int]]) -> None: + """Apply per-task counter deltas (fire-and-forget; no ``ray.get``). + + Each tuple is ``(task_id, source_id, delta)``: + - seen ``task_id``, same delta → skip (Ray-retry idempotency). + - seen ``task_id``, different delta → rewrite ``_pending`` by ``-old+new``. + - any delta for an already-completed source → warn and un-complete it + (in-memory + LMDB) so it reprocesses next run (indicates a bug). + - else → apply; persist the source when its counter hits 0. + + Never raises. + """ + newly_done: list[str] = [] + for task_id, sid, d in per_task: + existing = self._applied.get(task_id) + if existing is not None: + if existing == d: + continue # idempotent re-fire + if sid in self._completed: + # Source already finalized but we're getting a different + # delta for one of its tasks — the source wasn't actually + # done. Un-complete it so it reruns next launch. + logger.warning( + f"resumability: task {task_id} delta changed from " + f"{existing} to {d} but source {sid!r} is already " + f"completed. Removing {sid!r} from the completed set " + f"so it will be reprocessed on the next run. Please " + f"file an issue at " + f"https://github.com/NVIDIA-NeMo/Curator if this is " + f"unexpected." + ) + self._remove_from_completed(sid) + continue + # Rewrite-on-conflict: the newest delta wins. + self._applied[task_id] = d + self._pending[sid] = self._pending.get(sid, 0) + (-existing) + d + else: + # New task id. + if sid in self._completed: + logger.warning( + f"resumability: source {sid!r} got update for new " # noqa: S608 + f"task {task_id} (delta={d}) after being completed. " + f"Removing {sid!r} from the completed set so it will " + f"be reprocessed on the next run. Please file an " + f"issue at https://github.com/NVIDIA-NeMo/Curator." + ) + self._remove_from_completed(sid) + continue + self._applied[task_id] = d + self._pending[sid] = self._pending.get(sid, 0) + d + if self._pending[sid] == 0: + newly_done.append(sid) + if newly_done: + self._persist_completed(newly_done) + for sid in newly_done: + self._completed.add(sid) + self._pending.pop(sid, None) + + def _persist_completed(self, sids: Iterable[str]) -> None: + with self._env.begin(write=True) as txn: + for sid in sids: + txn.put(sid.encode(), b"1", db=self._db, overwrite=True) + + def _remove_from_completed(self, sid: str) -> None: + """Un-complete ``sid`` (in-memory + our LMDB file) so it reruns. If a + *different* writer completed it, that entry can't be removed and may + reappear from the union next startup — acceptable for this rare path.""" + self._completed.discard(sid) + with self._env.begin(write=True) as txn: + txn.delete(sid.encode(), db=self._db) + + def close(self) -> None: + if self._env is not None: + try: + self._env.close() + except Exception as e: # noqa: BLE001 + logger.warning(f"failed to close LMDB env: {e}") + self._env = None # type: ignore[assignment] diff --git a/nemo_curator/utils/resumability_client.py b/nemo_curator/utils/resumability_client.py new file mode 100644 index 0000000000..9a4f9b73d5 --- /dev/null +++ b/nemo_curator/utils/resumability_client.py @@ -0,0 +1,57 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Worker-side helpers to talk to the resumability actor; all no-ops when no +actor is registered, so unchecked pipelines pay nothing. +""" + +from __future__ import annotations + +import ray + +# Defined here (not imported from resumability_actor) so the always-imported +# worker path doesn't pull in lmdb until resumability is actually used. +ACTOR_NAME = "nemo_curator_resumability" + + +def _actor() -> ray.actor.ActorHandle | None: + """The resumability actor handle, or None if Ray is down / no actor registered.""" + if not ray.is_initialized(): + return None + try: + return ray.get_actor(ACTOR_NAME) + except ValueError: + return None + + +def _is_active() -> bool: + """True if a resumability actor is registered in this Ray cluster.""" + return _actor() is not None + + +def _flush_deltas(per_task: list[tuple[str, str, int]]) -> None: + """Fire-and-forget per-task deltas ``(task_id, source_id, delta)``. No + ``ray.get`` — the actor never raises, so there's no error path; backpressure + is the actor's ``max_pending_calls`` cap.""" + a = _actor() + if a is not None and per_task: + a.apply_deltas.remote(per_task) # type: ignore[attr-defined] + + +def _skip_completed_sources(source_ids: list[str]) -> set[str]: + """Set of ``source_ids`` already marked complete; the source stage uses it to skip them.""" + a = _actor() + if a is None or not source_ids: + return set() + flags = ray.get(a.are_completed.remote(source_ids)) # type: ignore[attr-defined] + return {sid for sid, done in zip(source_ids, flags, strict=True) if done} diff --git a/pyproject.toml b/pyproject.toml index bd10a5337b..5e823c8f43 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -60,6 +60,7 @@ dependencies = [ "fsspec", "hydra-core", "jieba==0.42.1", + "lmdb>=1.4", "loguru", "mecab-python3", "omegaconf", diff --git a/tests/backends/test_resumability_adapter.py b/tests/backends/test_resumability_adapter.py new file mode 100644 index 0000000000..a78db75f1f --- /dev/null +++ b/tests/backends/test_resumability_adapter.py @@ -0,0 +1,242 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Unit tests for the resumability counter step (``_apply_resumability_counters``) +and the ``None``->``NoneTask`` normalization in ``BaseStageAdapter``, with the +actor RPCs mocked out. +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any +from unittest.mock import patch + +from nemo_curator.backends.base import BaseStageAdapter +from nemo_curator.stages.base import ProcessingStage +from nemo_curator.tasks import FailedTask, NoneTask, Task + + +@dataclass +class _NoopStage(ProcessingStage[Task, Task]): + name: str = "noop" + + def inputs(self) -> tuple[list[str], list[str]]: + return [], [] + + def outputs(self) -> tuple[list[str], list[str]]: + return [], [] + + def process(self, task: Task) -> Task: + return task + + +@dataclass +class _DropStage(ProcessingStage[Task, Task]): + """A non-source stage that filters every input (returns ``None``).""" + + name: str = "drop" + + def inputs(self) -> tuple[list[str], list[str]]: + return [], [] + + def outputs(self) -> tuple[list[str], list[str]]: + return [], [] + + def process(self, task: Task) -> None: + return None + + +@dataclass +class _SimpleTask(Task[list[int]]): + @property + def num_items(self) -> int: + return 0 + + def validate(self) -> bool: + return True + + +def _task(task_id: str = "", source_id: str = "") -> _SimpleTask: + t = _SimpleTask(dataset_name="d", data=[]) + t.task_id = task_id # pretend _post_process_task_ids already ran + t._source_id = source_id + return t + + +def _counters( + stage: ProcessingStage, + input_tasks: list[Task], + output_tasks: list[Any], + *, + completed: set[str] | None = None, +) -> tuple[list[Task], list[tuple[str, str, int]]]: + """Run ``_apply_resumability_counters`` with the actor RPCs patched. + Returns ``(surviving_outputs, captured_deltas)``.""" + captured: list[tuple[str, str, int]] = [] + + with ( + patch("nemo_curator.backends.base._flush_deltas", side_effect=captured.extend), + patch("nemo_curator.backends.base._skip_completed_sources", return_value=completed or set()), + ): + out = BaseStageAdapter(stage)._apply_resumability_counters(input_tasks, output_tasks) + return out, captured + + +def _process( + stage: ProcessingStage, + tasks: list[Task], + *, + completed: set[str] | None = None, +) -> tuple[list[Task], list[tuple[str, str, int]]]: + """Run the full ``process_batch`` with the resumability actor patched + active. Returns ``(surviving_outputs, captured_deltas)``.""" + captured: list[tuple[str, str, int]] = [] + with ( + patch("nemo_curator.backends.base._is_active", return_value=True), + patch("nemo_curator.backends.base._flush_deltas", side_effect=captured.extend), + patch("nemo_curator.backends.base._skip_completed_sources", return_value=completed or set()), + ): + out = BaseStageAdapter(stage).process_batch(tasks) + return out, captured + + +class TestNoneNormalization: + """A returned ``None`` is normalized to a ``NoneTask`` inside + ``process_batch``: it decrements its slot's source counter and is then + stripped, so it never reaches the next stage.""" + + def test_returned_none_decrements_and_is_stripped(self) -> None: + parent = _task("s_0", source_id="s") + out, captured = _process(_DropStage(), [parent]) + assert out == [] # the NoneTask sentinel is stripped from the output + # Keyed on the NoneTask's assigned OUTPUT id ("s_0_0"), not the parent. + assert captured == [("s_0_0", "s", -1)] # the filtered slot is consumed + + +class TestSourceStage: + def _src_stage(self) -> _NoopStage: + s = _NoopStage() + s.is_source_stage = True + return s + + def test_stamps_source_id_and_fires_plus_one(self) -> None: + empty = _task("0") # EmptyTask-like root + a, b = _task("0_aaa"), _task("0_bbb") + out, captured = _counters(self._src_stage(), [empty], [a, b]) + assert out == [a, b] + # _source_id is the task_id's last segment (its content id / index). + assert a._source_id == "aaa" + assert b._source_id == "bbb" + assert sorted(captured) == [("0_aaa", "aaa", 1), ("0_bbb", "bbb", 1)] + + def test_drops_already_completed_sources(self) -> None: + empty = _task("0") + a, b, c = _task("0_a"), _task("0_b"), _task("0_c") + out, captured = _counters(self._src_stage(), [empty], [a, b, c], completed={"b"}) + assert out == [a, c] + assert {sid for _, sid, _ in captured} == {"a", "c"} + + +class TestNonSourceStage: + def test_pre_source_is_noop(self) -> None: + # Inputs carry no _source_id yet -> nothing tracked, outputs untouched. + a = _task("0_0") + out, captured = _counters(_NoopStage(), [a], [a]) + assert out == [a] + assert captured == [] + + def test_one_to_one_nonsink_zero_delta(self) -> None: + stage = _NoopStage() + stage.is_sink_stage = False + parent = _task("s_0", source_id="s") + child = _task("s_0_0") + _out, captured = _counters(stage, [parent], [child]) + # Keyed on the OUTPUT id (child), not the parent. + assert captured == [("s_0_0", "s", 0)] + assert child._source_id == "s" # inherited + + def test_one_to_one_sink_minus_one(self) -> None: + stage = _NoopStage() + stage.is_sink_stage = True + parent = _task("s_0", source_id="s") + _out, captured = _counters(stage, [parent], [_task("s_0_0")]) + assert captured == [("s_0_0", "s", -1)] + + def test_nonetask_slot_decrements(self) -> None: + # NoneTask keys on its own assigned OUTPUT id (here "s_0_0"), so it can't + # collide with the source's +1 for the same partition ("s_0"). + parent = _task("s_0", source_id="s") + nt = NoneTask() + nt.task_id = "s_0_0" + _out, captured = _counters(_NoopStage(), [parent], [nt]) + assert captured == [("s_0_0", "s", -1)] + + def test_failedtask_slot_zero_delta(self) -> None: + # Failed fires delta 0 (keyed on its output id): the source's +1 stays, + # so the source never completes and reruns. No sink test for Failed. + parent = _task("s_0", source_id="s") + ft = FailedTask() + ft.task_id = "s_0_0" + _out, captured = _counters(_NoopStage(), [parent], [ft]) + assert captured == [("s_0_0", "s", 0)] + + def test_fanout_grows_counter(self) -> None: + stage = _NoopStage() + stage.is_sink_stage = False + parent = _task("s_0", source_id="s") + c0, c1, c2 = _task("s_0_0"), _task("s_0_1"), _task("s_0_2") + _out, captured = _counters(stage, [parent], [c0, c1, c2]) + # 1 input -> 3 real children: net +2, keyed on output[0] ("s_0_0"). + assert captured == [("s_0_0", "s", 2)] + assert all(c._source_id == "s" for c in (c0, c1, c2)) + + def test_fanout_nonsink_mixed_real_none_failed(self) -> None: + # 1 -> [real, None, Failed], non-sink: real continues (+1), Failed keeps + # the source open (+1), None contributes 0, parent consumed (-1). + stage = _NoopStage() + stage.is_sink_stage = False + parent = _task("s_0", source_id="s") + real = _task("s_0_0") + nt, ft = NoneTask(), FailedTask() + nt.task_id, ft.task_id = "s_0_1", "s_0_2" + _out, captured = _counters(stage, [parent], [real, nt, ft]) + # net = continuing(1) + failed(1) - 1 = 1, keyed on output[0]. + assert captured == [("s_0_0", "s", 1)] + + def test_fanout_sink_real_outputs_leave(self) -> None: + # 1 -> [real, real, Failed] at a SINK: real outputs leave (0), Failed + # stays (+1), parent consumed (-1) -> net 0, source stays open. + stage = _NoopStage() + stage.is_sink_stage = True + parent = _task("s_0", source_id="s") + r0, r1 = _task("s_0_0"), _task("s_0_1") + ft = FailedTask() + ft.task_id = "s_0_2" + _out, captured = _counters(stage, [parent], [r0, r1, ft]) + assert captured == [("s_0_0", "s", 0)] + + def test_empty_output_skips(self) -> None: + # A stage that emits nothing (not even a NoneTask) is degenerate: there + # is no output to key a delta on, so nothing is fired. + parent = _task("s_0", source_id="s") + out, captured = _counters(_NoopStage(), [parent], []) + assert captured == [] + assert out == [] + + def test_ambiguous_batch_skips_counters(self) -> None: + # 2 inputs -> 3 outputs: can't attribute, so no deltas are fired. + p0, p1 = _task("s_0", source_id="s"), _task("s_1", source_id="s") + out, captured = _counters(_NoopStage(), [p0, p1], [_task(), _task(), _task()]) + assert captured == [] + assert len(out) == 3 diff --git a/tests/backends/test_resumability_functional.py b/tests/backends/test_resumability_functional.py new file mode 100644 index 0000000000..6242715d2b --- /dev/null +++ b/tests/backends/test_resumability_functional.py @@ -0,0 +1,164 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""End-to-end resume loop without a Ray cluster: drives the real +``BaseStageAdapter`` against a real LMDB-backed ``ResumabilityActor`` over two +runs sharing a checkpoint dir. Exercises the full counter → actor → LMDB → skip +loop (and would fail under the old parent-id keying bug, where completed sources +never persisted). +""" + +from __future__ import annotations + +from contextlib import contextmanager +from dataclasses import dataclass +from typing import TYPE_CHECKING +from unittest.mock import patch + +from nemo_curator.backends.base import BaseStageAdapter +from nemo_curator.stages.base import ProcessingStage +from nemo_curator.tasks import EmptyTask, FailedTask, Task +from nemo_curator.utils.resumability_actor import ResumabilityActor + +if TYPE_CHECKING: + from collections.abc import Iterator + from pathlib import Path + + +def _new_actor(base_dir: Path, writer_id: str): # noqa: ANN202 (undecorated Ray actor class instance) + """A real actor instance (undecorated class — no Ray cluster needed), + writing its own ``.mdb`` and reading the union on startup.""" + cls = ResumabilityActor.__ray_metadata__.modified_class # type: ignore[attr-defined] + return cls(str(base_dir), writer_id=writer_id) + + +@dataclass +class _IntTask(Task[int]): + data: int = 0 + + @property + def num_items(self) -> int: + return 1 + + def validate(self) -> bool: + return True + + +@dataclass +class _Source(ProcessingStage[EmptyTask, _IntTask]): + name: str = "source" + n: int = 0 + is_source_stage: bool = True + + def inputs(self) -> tuple[list[str], list[str]]: + return [], [] + + def outputs(self) -> tuple[list[str], list[str]]: + return ["data"], [] + + def process(self, _: EmptyTask) -> list[_IntTask]: + return [_IntTask(data=i, dataset_name="d") for i in range(self.n)] + + +@dataclass +class _Sink(ProcessingStage[_IntTask, _IntTask]): + name: str = "sink" + is_sink_stage: bool = True + fail: tuple[int, ...] = () + none: tuple[int, ...] = () + newobj: bool = False + + def inputs(self) -> tuple[list[str], list[str]]: + return ["data"], [] + + def outputs(self) -> tuple[list[str], list[str]]: + return ["data"], [] + + def process(self, task: _IntTask) -> _IntTask | FailedTask | None: + if task.data in self.fail: + return FailedTask() + if task.data in self.none: + return None + return _IntTask(data=task.data, dataset_name="d") if self.newobj else task + + +@contextmanager +def _wired(actor) -> Iterator[None]: # noqa: ANN001 + """Point the worker-side client helpers at ``actor`` directly (no Ray).""" + + def _skip(sids: list[str]) -> set[str]: + return {s for s, done in zip(sids, actor.are_completed(sids), strict=True) if done} + + with ( + patch("nemo_curator.backends.base._is_active", return_value=True), + patch("nemo_curator.backends.base._flush_deltas", side_effect=actor.apply_deltas), + patch("nemo_curator.backends.base._skip_completed_sources", side_effect=_skip), + ): + yield + + +def _run(actor, n: int, **sink_kwargs) -> tuple[list[str], list[str]]: # noqa: ANN001 + """One source→sink pass through the real adapters. Returns + ``(source ids that survived the source stage, sink-output source ids)``.""" + with _wired(actor): + src_out = BaseStageAdapter(_Source(n=n)).process_batch([EmptyTask()]) + src_ids = sorted(t._source_id for t in src_out) + sink_out = BaseStageAdapter(_Sink(**sink_kwargs)).process_batch(src_out) + return src_ids, sorted(t._source_id for t in sink_out) + + +class TestResumeLoop: + def test_completed_skip_and_failed_reruns(self, tmp_path: Path) -> None: + # Run 1: sources 0,1,2; source 1 fails at the sink. + a1 = _new_actor(tmp_path, "w1") + src1, _ = _run(a1, n=3, fail=(1,)) + assert src1 == ["0", "1", "2"] # all emitted on the first run + assert a1.are_completed(["0", "1", "2"]) == [True, False, True] # 1 stays pending + a1.close() + + # Run 2: fresh actor reads w1's completions from the union; source 1 + # now succeeds. Only the previously-failed source reruns. + a2 = _new_actor(tmp_path, "w2") + src2, sink2 = _run(a2, n=3) + assert src2 == ["1"] # 0 and 2 skipped at the source stage + assert sink2 == ["1"] + assert a2.are_completed(["0", "1", "2"]) == [True, True, True] + a2.close() + + def test_new_object_sink_completes(self, tmp_path: Path) -> None: + # Sink returns a NEW object (not the input) — the counter must key on the + # output id, not the input's, or the source's +1 and the sink's -1 would + # collide and the source would never complete. This locks that fix. + a1 = _new_actor(tmp_path, "w1") + _run(a1, n=3, newobj=True) + assert a1.are_completed(["0", "1", "2"]) == [True, True, True] + a1.close() + + a2 = _new_actor(tmp_path, "w2") + src2, _ = _run(a2, n=3, newobj=True) + assert src2 == [] # everything already complete -> nothing reruns + a2.close() + + def test_none_filtered_source_completes(self, tmp_path: Path) -> None: + # A source filtered to None at the sink is consumed -> completes (it must + # NOT behave like Failed). It is skipped, not rerun, on the next run. + a1 = _new_actor(tmp_path, "w1") + _, sink1 = _run(a1, n=3, none=(1,)) + assert sink1 == ["0", "2"] # source 1 produced no sink output (filtered) + assert a1.are_completed(["0", "1", "2"]) == [True, True, True] # but it completed + a1.close() + + a2 = _new_actor(tmp_path, "w2") + src2, _ = _run(a2, n=3, none=(1,)) + assert src2 == [] # source 1 was completed by the filter, not left pending + a2.close() diff --git a/tests/tasks/test_sentinels.py b/tests/tasks/test_sentinels.py new file mode 100644 index 0000000000..aa4785a32f --- /dev/null +++ b/tests/tasks/test_sentinels.py @@ -0,0 +1,74 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Unit tests for the payload-less sentinel tasks: the ``SentinelTask`` base, +``EmptyTask`` (pipeline seed, ``task_id="0"``), and the ``NoneTask`` / +``FailedTask`` resumability markers (framework-assigned ``task_id``). +""" + +from __future__ import annotations + +import pytest + +from nemo_curator.tasks import EmptyTask, FailedTask, NoneTask, SentinelTask, Task + + +class TestSentinelBase: + def test_subclasses_are_tasks(self) -> None: + for obj in (SentinelTask(dataset_name="s"), EmptyTask(), NoneTask(), FailedTask()): + assert isinstance(obj, Task) + assert isinstance(obj, SentinelTask) + + def test_carry_no_data(self) -> None: + for obj in (SentinelTask(dataset_name="s"), EmptyTask(), NoneTask(), FailedTask()): + assert obj.data is None + + def test_num_items_is_zero(self) -> None: + for obj in (SentinelTask(dataset_name="s"), EmptyTask(), NoneTask(), FailedTask()): + assert obj.num_items == 0 + + def test_validate_is_true(self) -> None: + for obj in (SentinelTask(dataset_name="s"), EmptyTask(), NoneTask(), FailedTask()): + assert obj.validate() is True + + def test_rejects_payload(self) -> None: + # The base asserts ``data is None`` so a sentinel can never carry data. + with pytest.raises(AssertionError): + SentinelTask(dataset_name="s", data="oops") + + +class TestEmptyTask: + def test_is_rooted_at_zero(self) -> None: + # EmptyTask is the implicit root every task descends from. + assert EmptyTask().task_id == "0" + assert EmptyTask().dataset_name == "empty" + + def test_task_id_is_not_user_settable(self) -> None: + # ``task_id`` is init=False, so it cannot be passed positionally/kw. + with pytest.raises(TypeError): + EmptyTask(task_id="5") # type: ignore[call-arg] + + +class TestResumabilityMarkers: + def test_dataset_names(self) -> None: + assert NoneTask().dataset_name == "none" + assert FailedTask().dataset_name == "failed" + + def test_task_id_unset_until_assigned(self) -> None: + # Unlike EmptyTask, these get their id from the adapter; default empty. + assert NoneTask().task_id == "" + assert FailedTask().task_id == "" + + def test_none_and_failed_are_distinct(self) -> None: + assert not isinstance(NoneTask(), FailedTask) + assert not isinstance(FailedTask(), NoneTask) diff --git a/tests/utils/test_resumability_actor.py b/tests/utils/test_resumability_actor.py new file mode 100644 index 0000000000..a0a19158de --- /dev/null +++ b/tests/utils/test_resumability_actor.py @@ -0,0 +1,288 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Unit tests for :class:`ResumabilityActor` (counter math, dedup, +rewrite-on-conflict, LMDB persistence). Instantiates the actor class directly +(no ``@ray.remote``), so no live Ray cluster is needed. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING +from unittest.mock import patch + +from nemo_curator.utils.resumability_actor import ResumabilityActor + +if TYPE_CHECKING: + from pathlib import Path + + +def _new_actor(tmp_path: Path, writer_id: str | None = None) -> ResumabilityActor: + """Bypass ``@ray.remote`` and instantiate the actor class directly. + + Ray's ``@ray.remote`` decorator stashes the original class on + ``__ray_metadata__.modified_class``. ``tmp_path`` is the checkpoint + directory; the actor keeps its LMDB file under + ``tmp_path/.nemo_curator_metadata/``. ``writer_id`` distinguishes writers + sharing that directory (defaults to host+pid in production); pass distinct + ids to simulate concurrent runs / SLURM-array tasks. + """ + cls = ResumabilityActor.__ray_metadata__.modified_class # type: ignore[attr-defined] + return cls(str(tmp_path), writer_id=writer_id) + + +class TestApplyDeltasCounterMath: + def test_source_emit_increments_pending(self, tmp_path: Path) -> None: + actor = _new_actor(tmp_path) + actor.apply_deltas([("h0", "0", +1), ("h1", "1", +1)]) + assert actor._pending == {"0": 1, "1": 1} + assert actor._completed == set() + actor.close() + + def test_counter_reaches_zero_persists_to_lmdb(self, tmp_path: Path) -> None: + actor = _new_actor(tmp_path) + actor.apply_deltas([("h0", "0", +1)]) + actor.apply_deltas([("h_sink", "0", -1)]) + assert actor._completed == {"0"} + assert "0" not in actor._pending + actor.close() + + # Reopen the actor and confirm "0" survives in LMDB. + actor2 = _new_actor(tmp_path) + assert actor2._completed == {"0"} + actor2.close() + + def test_nonsink_real_task_is_zero_delta(self, tmp_path: Path) -> None: + actor = _new_actor(tmp_path) + actor.apply_deltas([("h0", "0", +1)]) + actor.apply_deltas([("h_passthrough", "0", 0)]) + assert actor._pending == {"0": 1} + assert actor._completed == set() + actor.close() + + def test_nonetask_decrements(self, tmp_path: Path) -> None: + actor = _new_actor(tmp_path) + actor.apply_deltas([("h0", "0", +1)]) + actor.apply_deltas([("h_filter", "0", -1)]) + assert actor._completed == {"0"} + actor.close() + + def test_fanout_grows_counter(self, tmp_path: Path) -> None: + actor = _new_actor(tmp_path) + actor.apply_deltas([("h0", "0", +1)]) + # Fan-out 1→3 emits delta = (3-1) = +2 on the parent's source. + actor.apply_deltas([("h_fanout", "0", +2)]) + assert actor._pending == {"0": 3} + actor.close() + + +class TestDedupAndRewrite: + def test_same_task_same_delta_is_idempotent(self, tmp_path: Path) -> None: + actor = _new_actor(tmp_path) + actor.apply_deltas([("h0", "0", +1)]) + actor.apply_deltas([("h_t", "0", -1)]) + # Second identical fire — should be a no-op (Ray retry idempotency). + actor.apply_deltas([("h_t", "0", -1)]) + assert actor._completed == {"0"} + actor.close() + + def test_same_task_different_delta_rewrites(self, tmp_path: Path) -> None: + """When a Ray retry fires a different delta for the same task hash, + the actor adjusts pending by (-old + new) so the latest observation + wins. Never raises.""" + actor = _new_actor(tmp_path) + actor.apply_deltas([("h0", "0", +1)]) + # First the worker says delta=0 (real Task passed through). + actor.apply_deltas([("h_t", "0", 0)]) + assert actor._pending == {"0": 1} + + # Retry says delta=-1 (NoneTask this time). Rewrite: pending += -0 + -1. + actor.apply_deltas([("h_t", "0", -1)]) + assert actor._completed == {"0"} + # And the recorded delta is updated to the new value. + assert actor._applied["h_t"] == -1 + actor.close() + + def test_rewrite_does_not_raise(self, tmp_path: Path) -> None: + """apply_deltas never raises; rewrite is silent.""" + actor = _new_actor(tmp_path) + actor.apply_deltas([("h0", "0", +1)]) + # Multiple conflicting deltas for the same task: should not raise. + actor.apply_deltas([("h_t", "0", 0)]) + actor.apply_deltas([("h_t", "0", +5)]) + actor.apply_deltas([("h_t", "0", -1)]) + # Final state reflects the last delta. + assert actor._applied["h_t"] == -1 + actor.close() + + +class TestUncompleteOnAnomaly: + def test_new_task_after_source_completed_warns_and_uncompletes(self, tmp_path: Path) -> None: + """If a delta arrives for a never-seen task on an already-completed + source, the source wasn't actually done. Un-complete it (in-memory + and in LMDB) so it reruns next launch.""" + actor = _new_actor(tmp_path) + actor.apply_deltas([("h0", "0", +1), ("h_t", "0", -1)]) + assert actor._completed == {"0"} + + with patch("nemo_curator.utils.resumability_actor.logger") as mock_logger: + actor.apply_deltas([("h_late", "0", -1)]) + mock_logger.warning.assert_called_once() + warn_msg = mock_logger.warning.call_args[0][0] + assert "Removing" in warn_msg + assert "completed set" in warn_msg + + # Source has been removed from the in-memory completed set. + assert "0" not in actor._completed + # And from LMDB — reopen and confirm. + actor.close() + actor2 = _new_actor(tmp_path) + assert "0" not in actor2._completed + actor2.close() + + def test_rewrite_attempt_after_source_completed_warns_and_uncompletes(self, tmp_path: Path) -> None: + actor = _new_actor(tmp_path) + actor.apply_deltas([("h0", "0", +1)]) + actor.apply_deltas([("h_t", "0", -1)]) + assert actor._completed == {"0"} + + # Same task tries to rewrite to a different delta after completion. + with patch("nemo_curator.utils.resumability_actor.logger") as mock_logger: + actor.apply_deltas([("h_t", "0", 0)]) + mock_logger.warning.assert_called_once() + warn_msg = mock_logger.warning.call_args[0][0] + assert "Removing" in warn_msg + + # Source has been uncompleted. + assert "0" not in actor._completed + actor.close() + actor2 = _new_actor(tmp_path) + assert "0" not in actor2._completed + actor2.close() + + def test_apply_deltas_never_raises(self, tmp_path: Path) -> None: + """The whole point of removing the error machinery — no path through + apply_deltas should raise.""" + actor = _new_actor(tmp_path) + # Throw lots of weird stuff at it. + actor.apply_deltas([("h0", "0", +1)]) + actor.apply_deltas([("h_t", "0", -1)]) # completes 0 + actor.apply_deltas([("h_t", "0", +5)]) # rewrite on completed source: warn + uncomplete + actor.apply_deltas([("h_new", "0", -1)]) # new hash, source no longer in completed + actor.apply_deltas([("h1", "1", +1), ("h_t1", "1", -5)]) # negative pending + # Reached this line without raising — pass. + actor.close() + + +class TestAreCompleted: + def test_returns_parallel_bool_list(self, tmp_path: Path) -> None: + actor = _new_actor(tmp_path) + actor.apply_deltas([("h0", "0", +1), ("h1", "1", +1)]) + actor.apply_deltas([("h_t", "0", -1)]) + assert actor.are_completed(["0", "1", "unknown"]) == [True, False, False] + actor.close() + + def test_loads_from_lmdb_on_construction(self, tmp_path: Path) -> None: + actor = _new_actor(tmp_path) + actor.apply_deltas([("h_a", "a", +1), ("h_t", "a", -1)]) + assert actor._completed == {"a"} + actor.close() + + actor2 = _new_actor(tmp_path) + assert actor2.are_completed(["a", "b"]) == [True, False] + actor2.close() + + +class TestLifecycle: + def test_close_is_idempotent(self, tmp_path: Path) -> None: + actor = _new_actor(tmp_path) + actor.close() + actor.close() # second close is a no-op + + def test_one_lmdb_write_per_completed_source(self, tmp_path: Path) -> None: + """Sanity-check the 'write only when a counter hits zero' contract: + a still-pending source is never persisted; once it completes it is. + + We verify via close/reopen rather than a concurrent second reader: + lmdb refuses to open the same env file twice in one process, and in + production a single detached actor owns each checkpoint file anyway. + """ + actor = _new_actor(tmp_path) + actor.apply_deltas([("h0", "0", +1)]) + # Source 0 is pending (counter != 0) — nothing recorded as completed. + assert actor._completed == set() + # Counter hits zero — now it's recorded. + actor.apply_deltas([("h_t", "0", -1)]) + assert actor._completed == {"0"} + actor.close() + + # A fresh actor loads exactly the one completed source from LMDB. + actor_c = _new_actor(tmp_path) + assert actor_c._completed == {"0"} + actor_c.close() + + +def test_no_lmdb_writes_for_pending_only_deltas(tmp_path: Path) -> None: + """Pending counters change in-memory only; LMDB is touched solely + when a counter hits zero.""" + actor = _new_actor(tmp_path) + # Lots of activity, but no source resolves. + actor.apply_deltas([("h0", "0", +1), ("h1", "1", +1), ("h_fanout_0", "0", +2)]) + actor.close() + + # Fresh actor: nothing persisted. + actor2 = _new_actor(tmp_path) + assert actor2._completed == set() + actor2.close() + + +class TestMultipleWriters: + """Shared metadata dir with one LMDB file per writer (the SLURM-array + model): each writer records ONLY its own completions; later writers read + the union across all writers' files on startup.""" + + def test_union_of_completed_across_writers(self, tmp_path: Path) -> None: + # Writer A finishes source "0". + a = _new_actor(tmp_path, writer_id="hostA-1") + a.apply_deltas([("hA", "0", +1), ("hA_sink", "0", -1)]) + assert a._completed == {"0"} + a.close() + + # Writer B starts later, sees A's completion in the union, finishes "1". + b = _new_actor(tmp_path, writer_id="hostB-2") + assert b.are_completed(["0", "1"]) == [True, False] + b.apply_deltas([("hB", "1", +1), ("hB_sink", "1", -1)]) + b.close() + + # A fresh writer sees the union of everything finished so far. + c = _new_actor(tmp_path, writer_id="hostC-3") + assert c.are_completed(["0", "1", "2"]) == [True, True, False] + c.close() + + # Each writer wrote its OWN file — nothing is shared. + files = sorted(p.name for p in (tmp_path / ".nemo_curator_metadata").glob("*.mdb")) + assert files == ["hostA-1.mdb", "hostB-2.mdb", "hostC-3.mdb"] + + def test_writer_does_not_write_other_writers_files(self, tmp_path: Path) -> None: + # A finishes "s"; B finishes nothing. B must not have touched A's file, + # and A's completion is still readable on its own. + a = _new_actor(tmp_path, writer_id="A") + a.apply_deltas([("h", "s", +1), ("h_sink", "s", -1)]) + a.close() + + b = _new_actor(tmp_path, writer_id="B") # finishes nothing + b.close() + + reader = _new_actor(tmp_path, writer_id="reader") + assert reader.are_completed(["s"]) == [True] + reader.close() diff --git a/tests/utils/test_resumability_client.py b/tests/utils/test_resumability_client.py new file mode 100644 index 0000000000..168a5e0933 --- /dev/null +++ b/tests/utils/test_resumability_client.py @@ -0,0 +1,96 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Unit tests for the worker-side resumability client helpers (actor lookup, +no-op when inactive, delta fire, completed-source lookup). ``ray`` is mocked, +so no live cluster is needed. +""" + +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +from nemo_curator.utils import resumability_client as rc + + +class TestActorLookup: + def test_none_when_ray_not_initialized(self) -> None: + with patch.object(rc, "ray") as ray: + ray.is_initialized.return_value = False + assert rc._actor() is None + assert rc._is_active() is False + ray.get_actor.assert_not_called() + + def test_none_when_no_actor_registered(self) -> None: + with patch.object(rc, "ray") as ray: + ray.is_initialized.return_value = True + ray.get_actor.side_effect = ValueError("no such actor") + assert rc._actor() is None + assert rc._is_active() is False + + def test_returns_handle_when_registered(self) -> None: + with patch.object(rc, "ray") as ray: + ray.is_initialized.return_value = True + handle = MagicMock() + ray.get_actor.return_value = handle + assert rc._actor() is handle + assert rc._is_active() is True + ray.get_actor.assert_called_with(rc.ACTOR_NAME) + + +class TestFlushDeltas: + def test_fires_when_active_and_nonempty(self) -> None: + with patch.object(rc, "ray") as ray: + ray.is_initialized.return_value = True + handle = MagicMock() + ray.get_actor.return_value = handle + deltas = [("t0", "s0", 1), ("t1", "s0", -1)] + rc._flush_deltas(deltas) + handle.apply_deltas.remote.assert_called_once_with(deltas) + + def test_noop_when_no_deltas(self) -> None: + with patch.object(rc, "ray") as ray: + ray.is_initialized.return_value = True + handle = MagicMock() + ray.get_actor.return_value = handle + rc._flush_deltas([]) + handle.apply_deltas.remote.assert_not_called() + + def test_noop_when_inactive(self) -> None: + with patch.object(rc, "ray") as ray: + ray.is_initialized.return_value = False + # Must not raise even though there are deltas to send. + rc._flush_deltas([("t0", "s0", 1)]) + + +class TestSkipCompletedSources: + def test_returns_completed_subset(self) -> None: + with patch.object(rc, "ray") as ray: + ray.is_initialized.return_value = True + handle = MagicMock() + ray.get_actor.return_value = handle + ray.get.return_value = [True, False, True] + assert rc._skip_completed_sources(["a", "b", "c"]) == {"a", "c"} + handle.are_completed.remote.assert_called_once_with(["a", "b", "c"]) + + def test_empty_when_inactive(self) -> None: + with patch.object(rc, "ray") as ray: + ray.is_initialized.return_value = False + assert rc._skip_completed_sources(["a"]) == set() + + def test_empty_when_no_sources(self) -> None: + with patch.object(rc, "ray") as ray: + ray.is_initialized.return_value = True + ray.get_actor.return_value = MagicMock() + assert rc._skip_completed_sources([]) == set() + ray.get.assert_not_called() diff --git a/uv.lock b/uv.lock index 7509d39c76..35200b30ed 100644 --- a/uv.lock +++ b/uv.lock @@ -4357,6 +4357,32 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a0/ef/11292bb0b85cf4c93447cab5a29f64576ed14d3ab4280e35ddd23486594a/lm_format_enforcer-0.11.3-py3-none-any.whl", hash = "sha256:cf586350875def1ae7a8fba84fcbbfc8371424b6c9d05c1fcba70aa233fbf06f", size = 45418, upload-time = "2025-08-24T19:37:46.325Z" }, ] +[[package]] +name = "lmdb" +version = "2.2.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/3d/fa/ddef3e433950e23844fd9d82fa045637cbe84140f482120bbdf6abe6be92/lmdb-2.2.1.tar.gz", hash = "sha256:b201b416f7d6cea9bd2f977277a5f51d6e52a434d6ec511a8b34990df2b1a9c5", size = 938665, upload-time = "2026-06-04T04:46:31.461Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/72/7f/0ed305faf932595d364af9a3046c044f9277273db9e1f033a66fbf2c5b77/lmdb-2.2.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:211cad947bc361cbe3c19ef6800d4e1dcb8f2f15e3e5b9bad34cc2818431d268", size = 115968, upload-time = "2026-06-04T04:45:50.068Z" }, + { url = "https://files.pythonhosted.org/packages/30/1e/712864753e331ecf2d93569a6a6d3d1f2a9dcb54feb11a2ace590e32f989/lmdb-2.2.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:090c498f57883d69420e4c6a6ec5726471e6ca35e183fe8f032165348c7d49b3", size = 114871, upload-time = "2026-06-04T04:45:51.35Z" }, + { url = "https://files.pythonhosted.org/packages/02/89/7570997080a4e778e6e066c829e722d73ebbc25c269982001b9ce8a26abf/lmdb-2.2.1-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:aa4115c7fc86ca6ee654f931ceba9e410e83f3296e64cb73125020286be54eb2", size = 326436, upload-time = "2026-06-04T04:45:52.672Z" }, + { url = "https://files.pythonhosted.org/packages/af/97/dc5716d168d652cb2f04bef856a88d51652c42a09c20d23d2e08d4b7704a/lmdb-2.2.1-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:9c145f6a67cc10c0c055cf4b9ce16274fb850c4d9690fef5428cb588f0694be1", size = 329516, upload-time = "2026-06-04T04:45:54.233Z" }, + { url = "https://files.pythonhosted.org/packages/63/74/a8701f8e74ced8ec82de63fa0ac098c9fea41e4c57121ca9724790f7ef55/lmdb-2.2.1-cp311-cp311-win_amd64.whl", hash = "sha256:7d39273c9cd561a7a084090ba33c008b668257c9202c15aa7d9f9c550f44d030", size = 113705, upload-time = "2026-06-04T04:45:55.482Z" }, + { url = "https://files.pythonhosted.org/packages/98/9a/a1304e1cdb991de6f250f5723a90558b17d4f34a0f1a7315cfa6cb301fee/lmdb-2.2.1-cp311-cp311-win_arm64.whl", hash = "sha256:2e5104ae83edf2e04e54ef9b85b07f080e982ea6c3d5c701b4bca2653ee160f1", size = 107498, upload-time = "2026-06-04T04:45:56.806Z" }, + { url = "https://files.pythonhosted.org/packages/1b/93/4796573d885dbc0dd94ed712d070c6919a019acd12754c4708ba8a47732d/lmdb-2.2.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:e6957c9346ce9e9300ca2b75625e681b9868bbaf4d257626ec96d221e8200fc4", size = 116824, upload-time = "2026-06-04T04:45:58.058Z" }, + { url = "https://files.pythonhosted.org/packages/33/20/d3e48f1af18d67e56c2f42f82a598c2586d7d47dca7c8edda4f479e108b4/lmdb-2.2.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:bd3f3ab6feed2d4ca87d9d9063d2e371c8cc6d72879d54ae160a1c32758d26c0", size = 115341, upload-time = "2026-06-04T04:45:59.352Z" }, + { url = "https://files.pythonhosted.org/packages/5e/3e/6c3d2aa3b2250220d664a3ebb137519b6c33f94e27bf62e903130fac2cb4/lmdb-2.2.1-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:9129a78af25dd1316784d689fefbd88bda6a756c82847a72b7f423bc1282dbd0", size = 333528, upload-time = "2026-06-04T04:46:00.748Z" }, + { url = "https://files.pythonhosted.org/packages/cf/72/64588fb1359b9a8d2fc6d3bfd98cd6a7f22adcd5fffa4252874529e72794/lmdb-2.2.1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:13438ad327f8bca47f1415671335eec500b653459d269556eb2cf2470cecec30", size = 338288, upload-time = "2026-06-04T04:46:02.097Z" }, + { url = "https://files.pythonhosted.org/packages/35/19/bf3466f65c7795d44b6119cd62fa505a1fd3ebb50d71bd20b823e2b1485c/lmdb-2.2.1-cp312-cp312-win_amd64.whl", hash = "sha256:e54f8705489f8b6668b648333fbd90875c06878b3226a64f3f1af58af01c3d00", size = 113598, upload-time = "2026-06-04T04:46:03.593Z" }, + { url = "https://files.pythonhosted.org/packages/a9/7f/214172bc46f67ec58ee0ec0cda3cf6b27ceeaef614be25c863b7da35f9a8/lmdb-2.2.1-cp312-cp312-win_arm64.whl", hash = "sha256:84468990d6b7f50243a1eb19e7f9fbaead93eb7de0eb854b7dacc7f893c699ea", size = 107614, upload-time = "2026-06-04T04:46:04.834Z" }, + { url = "https://files.pythonhosted.org/packages/55/ea/65df850c0f371856eb495c018b13b16da229cb072a06236021130ce6c2f7/lmdb-2.2.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:d468fa89da30515979bf35c3e5b4db0ded560f9c39449c11459559c9f85bb820", size = 117352, upload-time = "2026-06-04T04:46:06.103Z" }, + { url = "https://files.pythonhosted.org/packages/1f/88/94a079be5dc482cb9971da32a82046bdcf2124646e4d84c5b4412ccb8d78/lmdb-2.2.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:881e8cdde83d9130b9cf75faf3202c16cbdeb54da7ec58a0856e8adfff5d5c25", size = 115703, upload-time = "2026-06-04T04:46:07.42Z" }, + { url = "https://files.pythonhosted.org/packages/a3/73/e360c13279ea523d0caf2d231dd581c9fd0e4c6b49f33acde8613f0b653c/lmdb-2.2.1-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:d54bb7ef49241602599f6fee8547ba14765b896ec459dad9620940235c550ab6", size = 336991, upload-time = "2026-06-04T04:46:08.706Z" }, + { url = "https://files.pythonhosted.org/packages/9f/de/e36baf673fb218b17c0c7a8050d1aad7bd49eb7b8fcf8cf0268ddc06507e/lmdb-2.2.1-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:12b84c38d091bb283853d8af38951338bf3eb729d8e79f0381291b098c0616f6", size = 340692, upload-time = "2026-06-04T04:46:10.326Z" }, + { url = "https://files.pythonhosted.org/packages/c0/de/9e13991db388343ca59caf684e1572705d9d89bc5cc681cfa912cd3b9106/lmdb-2.2.1-cp313-cp313-win_amd64.whl", hash = "sha256:f68a203f45d7442527c9cc8cd9a7e10666e38b64a71775870bf5b54c30a15661", size = 113526, upload-time = "2026-06-04T04:46:11.73Z" }, + { url = "https://files.pythonhosted.org/packages/4b/83/2c27f9544034387badbadf577a716cf5681afd79f5fb762c2038b62af70b/lmdb-2.2.1-cp313-cp313-win_arm64.whl", hash = "sha256:6f783cd75835eb7d4676be5b0d38f68a31961f07d74126fd6424377005fb4d04", size = 107682, upload-time = "2026-06-04T04:46:12.981Z" }, +] + [[package]] name = "locket" version = "1.0.0" @@ -5106,6 +5132,7 @@ dependencies = [ { name = "fsspec" }, { name = "hydra-core" }, { name = "jieba" }, + { name = "lmdb" }, { name = "loguru" }, { name = "mecab-python3" }, { name = "omegaconf" }, @@ -5548,6 +5575,7 @@ requires-dist = [ { name = "jieba", specifier = "==0.42.1" }, { name = "justext", marker = "extra == 'text-cpu'" }, { name = "librosa", marker = "extra == 'audio-common'" }, + { name = "lmdb", specifier = ">=1.4" }, { name = "loguru" }, { name = "lxml", marker = "extra == 'text-cpu'" }, { name = "matplotlib", marker = "extra == 'interleaved-cpu'" },