Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 126 additions & 0 deletions nemo_curator/backends/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,24 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any

from loguru import logger

from nemo_curator.core.utils import ignore_ray_head_node
from nemo_curator.tasks import Task
from nemo_curator.tasks.sentinels import FailedTask, NoneTask
from nemo_curator.utils.performance_utils import StageTimer
from nemo_curator.utils.resumability_client import _flush_deltas, _is_active, _skip_completed_sources

if TYPE_CHECKING:
from nemo_curator.stages.base import ProcessingStage


def _is_sentinel(task: Task) -> bool:
    """Return True when ``task`` is a payload-less marker task.

    The markers recognised here are ``NoneTask`` and ``FailedTask``; they are
    stripped before the next stage rather than propagated.
    """
    sentinel_types = (NoneTask, FailedTask)
    return isinstance(task, sentinel_types)


@dataclass
class NodeInfo:
"""Generic node information for setup_on_node calls across backends.
Expand Down Expand Up @@ -85,9 +95,23 @@ def process_batch(self, tasks: list[Task]) -> list[Task]:
# Use the batch processing logic
results = self.stage.process_batch(tasks)

# A returned ``None`` ("filter this slot") becomes a NoneTask so every
# output is a real Task that gets a task_id. Sentinels (NoneTask /
# FailedTask) carry no identity and are stripped again before this
# method returns.
results = [NoneTask() if r is None else r for r in results]

# Guarantee every emitted task has a task_id (derived id, or uuid fallback).
results = self._post_process_task_ids(tasks, results)

# Opt-in resumability: fire per-source counter deltas. A no-op (the
# client helpers self-disable) when no resumability actor is registered.
if _is_active():
results = self._apply_resumability_counters(tasks, results)

# Sentinels never propagate to the next stage.
results = [r for r in results if not _is_sentinel(r)]

# Log performance stats and add to result tasks
_, stage_perf_stats = self._timer.log_stats()
# Consume and attach any custom metrics recorded by the stage during this call
Expand Down Expand Up @@ -168,6 +192,108 @@ def _post_process_task_ids(self, input_tasks: list[Task], output_tasks: list[Tas
task.task_id = "r" + uuid.uuid4().hex
return out

# ------------------------------------------------------------------ #
# Resumability (opt-in). Runs only when a resumability actor is
# registered. task_ids are already assigned by _post_process_task_ids;
# this layer only stamps _source_id, fires per-source counter deltas, and
# drops already-completed sources. Sentinels are stripped by the caller.
# ------------------------------------------------------------------ #
def _apply_resumability_counters(self, input_tasks: list[Task], output_tasks: list[Task]) -> list[Task]: # noqa: C901
# Every delta's dedup key is an OUTPUT task_id, never an input's
# (``parent.task_id``). The source fires ``+1`` keyed on its output
# partition's id; that id is the *input* id of the next stage, so keying
# a downstream delta on the input would reuse the source's key and the
# actor would treat the two as one conflicting event. An output id is
# always one level deeper, so it's unique to the (task, stage) that
# produced it.
stage = self.stage
if getattr(stage, "is_source_stage", False):
return self._source_counters(output_tasks)

# No outputs at all. Filtering is expressed as None -> NoneTask (a kept
# slot), so a stage that emits nothing is degenerate; there is no output

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What do we mean by "is degenerate"? Can this ever happen? Since a filtered slot becomes a NoneTask, an empty output_tasks (`not output_tasks`) shouldn't be a valid state here, right?

# to key a delta on, so skip (like the ambiguous-cardinality case).
if not output_tasks:
return output_tasks
Comment on lines +183 to +184

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Question: Why don't we just error here ?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmmm, output_tasks can be empty, right? At the end of the pipeline or something?


# Pre-source stages: inputs carry no _source_id, so there's nothing to
# track yet. Leave outputs untouched.
if all(not t._source_id for t in input_tasks):
return output_tasks
Comment on lines +187 to +188

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Which tasks are these? What's a pre-source stage? Is it the initial_tasks?


is_sink = stage.is_sink_stage

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

General question: why would the user ever want something other than source stage being the first stage and sink stage being the last stage of the pipeline? Like if the last stage failed but the second to last stage was the sink stage, they just don't want to rerun the last stage?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For the sink stage not being the last stage, I think PDF pipelines are a good example. I have a metadata stage called stagePerfLogging at the end, which is needed because I cannot do anything similar to benchmarking, since pipeline.run never returns.

As for the source stage, I don't have an example in mind, but we don't need to force this assumption, so I would prefer to keep it relaxed. From a user's perspective, if they don't specify, the default is that the first stage is the source and the last stage is the sink.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also for source we have a case where user might provide initial task and no source stage is defined.

per_task: list[tuple[str, str, int]] = []
real = [t for t in output_tasks if not _is_sentinel(t)]

if len(input_tasks) == 1 and len(output_tasks) != 1:

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Question:

Suggested change
if len(input_tasks) == 1 and len(output_tasks) != 1:
if len(input_tasks) == 1 and len(output_tasks) > 1:

# Genuine fan-out (1 -> N, N != 1). One net delta for the parent it
# is consumed (-1); each real child continues (+1) unless this is a
# sink, where children leave the pipeline (0); each FailedTask keeps
# the source open (+1); NoneTask contributes nothing. sink and
# fan-out are independent, so the sink test applies here too.
parent = input_tasks[0]
n_failed = sum(1 for t in output_tasks if isinstance(t, FailedTask))
continuing = 0 if is_sink else len(real)
delta = continuing + n_failed - 1
# Key on output_tasks[0].task_id (NOT parent.task_id, which collides
# with the source's +1). It always ends in "_0": get_deterministic_id()
# is consulted only for source stages (which return via
# _source_counters and never reach here), so non-source children are
# indexed positionally (suffix 0, 1, ...) -> output[0] is "<parent>_0".
per_task.append((output_tasks[0].task_id, parent._source_id, delta))
for c in real:
if not c._source_id:
c._source_id = parent._source_id
Comment on lines +194 to +208

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Fan-out branch silently completes sources that have FailedTask outputs

The fan-out delta len(real) - 1 excludes FailedTask items from real, but does not compensate for the missing "pending debt" those failed slots represent. Concretely: for 1 input → [real_A, failed_B] at a non-sink stage, len(real) = 1, delta = 0, counter stays at 1; when real_A reaches the sink it fires -1 and the counter hits 0 — the source is marked complete, but failed_B was supposed to keep it pending for retry. For sink fan-out the bug is even more immediate: is_sink forces delta to -1 regardless of real, so the counter zeros out the moment the batch emits even if every output is a FailedTask.

The 1:1 positional branch handles FailedTask correctly with an explicit continue (no delta). The fan-out branch has no analogous guard. The simplest conservative fix is to bail out (no delta fired) whenever any fan-out output is a FailedTask — the source stays pending and will reprocess on resume, consistent with FailedTask's documented semantics.

elif len(output_tasks) == len(input_tasks):

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We're using the number of inputs and outputs as a proxy for whether a stage is 1:1 or a fan-in/fan-out type stage. However, in some cases (like shuffle) the number of input and output tasks might be identical purely by chance. Since the inputs are completely shuffled, I don't think we can make assumptions about resumability there.

# Positional 1:1, including filtered (NoneTask) / failed slots. Each
# delta keys on the OUTPUT id (r.task_id).
for parent, r in zip(input_tasks, output_tasks, strict=True):
sid = parent._source_id
if isinstance(r, NoneTask):
# Filtered: this slot is consumed.
per_task.append((r.task_id, sid, -1))
continue
if isinstance(r, FailedTask):
# Failed: leave the source open so it reruns (no sink test).
per_task.append((r.task_id, sid, 0))
continue
# Real: a sink consumes it (-1); otherwise it passes through (0).
per_task.append((r.task_id, sid, -1 if is_sink else 0))
if not r._source_id:
r._source_id = sid
else:
# M inputs -> K outputs (K != M): the parent of each output can't be
# determined, so the counter can't be updated correctly. Skip
# (the source counter stays pending -> reprocessed on resume).
logger.warning(
f"resumability: {type(stage).__name__} produced {len(output_tasks)} outputs "
f"for {len(input_tasks)} inputs; can't attribute sources, skipping counter "
f"update for this batch."
)
return output_tasks

_flush_deltas(per_task)
return output_tasks

def _source_counters(self, output_tasks: list[Task]) -> list[Task]:
"""Source stage: each output is a source partition. Its ``_source_id``
is its own (last) id segment — the content id or index assigned by
``_post_process_task_ids``. Already-completed sources are dropped; each
surviving source fires a ``+1``."""
sources = [t for t in output_tasks if not _is_sentinel(t)]
for t in sources:
t._source_id = t.task_id.rsplit("_", 1)[-1]

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 rsplit strips only the last segment, not the full suffix

t._source_id = t.task_id.rsplit("_", 1)[-1] takes only the last underscore-delimited token. For a source with task_id = "0_abc_def" (produced when get_deterministic_id() returns "abc_def"), this yields "def" instead of "abc_def". Two unrelated sources whose deterministic IDs differ only in a prefix (e.g. "shard1_42" and "shard2_42") would get the same _source_id = "42", and the first to complete would cause the second to be silently skipped on the next run. Since get_deterministic_id() is explicitly documented as overridable, this is a real footgun.

The correct extraction for sources is to strip just the parent prefix ("0_") using split("_", 1)[1] so that the entire content-based suffix is preserved as the source identity.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we even have multiple delimited cases here, @abhinavg4 ? Or will it be like 0_1_2 ?

@VibhuJawa VibhuJawa Jun 23, 2026

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I’m a little worried that this is relying on task_id’s string encoding too directly.

task_id is effectively an id path, but here we’re parsing it with rsplit("_", 1) and treating the last path segment as the source identity. That works for the current common source shape, but it makes the resumability logic depend on delimiter details that are not really owned by this code. It may also be fragile for cases like N→N source stages, where outputs without get_deterministic_id() can all get suffix 0, which would make multiple source partitions share _source_id == "0" (which is an adverse case, I guess?).

Minor ask: Could we centralize this behind a small task-id helper/API instead of open-coding string parsing? For example, something like:

task.source_id = TaskId.parse(task.task_id).leaf()
 # or
task_id.get_last_segment()

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Answering both of the above.

We will always have _ delimited cases. The task ID is controlled by base.py only. Slightly above this function. I'm not sure if there's a good way to make resumability independent of this delimiter. But I'm comfortable with it since we assign task ID and we always use '_'

As for the case where get_deterministic_id() returns the same value for every task: this is an adverse case, and it would break a lot of other things too (such as writers). In Curator we do not (cannot?) ensure that these ids are unique across tasks; we rely on the user to ensure this.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Minor ask: Could we centralize this behind a small task-id helper/API instead of open-coding string parsing?

Great call out yes.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think Praateek also asked this somewhere. But yeah, I can make TaskId as a class or something. Or actually I can put this inside task. So we can do task.get_source_id()

completed = _skip_completed_sources([t._source_id for t in sources])
per_task: list[tuple[str, str, int]] = []
survivors: list[Task] = []
for t in sources:
if t._source_id in completed:
continue
per_task.append((t.task_id, t._source_id, +1))
survivors.append(t)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am a little confused why the source ID here is the last part of the X_Y_Z chain? I guess maybe I am confused about how task_id versus _source_id are constructed/formatted.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is because the source stage was the 3rd stage in your example. So each task_id has one of the following formats:

0_1_2_{sid}_2_1
0_1_2_0_1_1
r{uuid}_1_2

For the {sid}, its position in the id depends purely on the index at which the source stage occurs. The id always starts with zero because the pipeline is seeded with an empty task. In most cases, the task_id will have this format:

0_{sid}_0_0_0_0: Single fanout at source and then filters.

_flush_deltas(per_task)
return survivors

def setup_on_node(self, node_info: NodeInfo | None = None, worker_metadata: WorkerMetadata | None = None) -> None:
"""Setup the stage on a node.

Expand Down
64 changes: 62 additions & 2 deletions nemo_curator/pipeline/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from pathlib import Path
from typing import Any

from loguru import logger
Expand Down Expand Up @@ -222,18 +223,35 @@ def describe(self) -> str:

return "\n".join(lines)

def run(self, executor: BaseExecutor | None = None, initial_tasks: list[Task] | None = None) -> list[Task] | None:
def run(
self,
executor: BaseExecutor | None = None,
initial_tasks: list[Task] | None = None,
checkpoint_path: str | Path | None = None,
) -> list[Task] | None:
"""Run the pipeline.

Args:
executor (BaseExecutor): Executor to use
initial_tasks (list[Task], optional): Initial tasks to start the pipeline with. Defaults to None.
checkpoint_path (str | Path, optional): Directory used for
resumability. When set, completed source partitions are tracked
across runs and skipped on rerun; the tracking state lives in a
``.nemo_curator_metadata`` subdirectory. Multiple independent
runs (e.g. the tasks of a SLURM array) may point at the same
directory — each writes its own LMDB file, so there is no
shared-file contention. The actor lifecycle is owned by this
method; executors are not modified.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit ai slop too long a substring..


Returns:
list[Task] | None: List of tasks
"""
self.build()

if checkpoint_path is not None:
checkpoint_path = Path(checkpoint_path).absolute()
checkpoint_path.mkdir(parents=True, exist_ok=True)

if executor is None:
from nemo_curator.backends.xenna import XennaExecutor

Expand Down Expand Up @@ -263,4 +281,46 @@ def run(self, executor: BaseExecutor | None = None, initial_tasks: list[Task] |
if initial_tasks:
assign_root_task_ids(initial_tasks)

return executor.execute(self.stages, initial_tasks)
if checkpoint_path is None:
return executor.execute(self.stages, initial_tasks)
return self._run_with_resumability(executor, initial_tasks, checkpoint_path)

def _run_with_resumability(
self,
executor: BaseExecutor,
initial_tasks: list[Task] | None,
checkpoint_path: Path,
) -> list[Task] | None:
"""Owns the full resumability-actor lifecycle. Per-backend executors
are not modified — the actor is spawned ``lifetime="detached"`` so
it survives executor-local ``ray.shutdown()`` calls.

The actor never raises (see ``ResumabilityActor.apply_deltas``), so
there's no watchdog and no error propagation path here — just spawn,
run, close.
"""
import ray

from nemo_curator.utils.resumability_actor import ResumabilityActor
from nemo_curator.utils.resumability_client import ACTOR_NAME

ray.init(ignore_reinit_error=True)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this needed? Shouldn't there have already been a RayClient().start() before pipeline.run() was called?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We do not enforce that right now, right? If the user forgets to include it, our pipelines still run. Ideally, I can add a check in pipeline.run asking the user to start either a Ray client with RayClient.start() or a Slurm client with SlurmClient.start(). Would you prefer that? I would personally prefer that, to be honest.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah good point, it might be nice to enforce it in pipeline.run().

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can't do ray.init() here. The ray.init() must live inside the executor.run() since each executor might have its own env variables that it would like to propagate.

IOW, the following doesn't work since X won't be propagated.

ray.init()
    ray.init(env_vars=X) 
    ray.shutdown()
ray.shutdown()

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so you might have to do

with ray.init():
    start_actor()
executor.execute()

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1 on this. I don't think we should ray.init here because any subsequent ray.init will be a no-op that doesn't pass env vars and if the client hasn't been started can lead to weird behavior w.r.t the actor getting killed by the shutdowns if any.

ResumabilityActor.options( # type: ignore[attr-defined]
name=ACTOR_NAME,
lifetime="detached",
get_if_exists=True,
max_pending_calls=100,

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what's the rationale behind setting this value?

).remote(str(checkpoint_path))
Comment on lines +300 to +306

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Silent actor init failure when LMDB setup fails

The ActorHandle returned by .remote() is discarded. If ResumabilityActor.__init__ raises — for example because the LMDB file cannot be opened (bad permissions, disk full, path is read-only) — the exception is stored in the returned ObjectRef and never surfaced. The actor is placed in a DEAD state and removed from the Ray name registry, so _actor() subsequently returns None, _is_active() returns False, and all checkpointing silently does nothing for the entire run. The user passed checkpoint_path expecting resumability to be active, but gets no error and no indication it isn't working.

A lightweight fix is to call a trivially-cheap method on the handle and ray.get it immediately after construction; this surfaces any __init__ exception synchronously before the pipeline starts.

Suggested change
ray.init(ignore_reinit_error=True)
ResumabilityActor.options( # type: ignore[attr-defined]
name=ACTOR_NAME,
lifetime="detached",
get_if_exists=True,
max_pending_calls=100,
).remote(str(checkpoint_path))
ray.init(ignore_reinit_error=True)
actor_handle = ResumabilityActor.options( # type: ignore[attr-defined]
name=ACTOR_NAME,
lifetime="detached",
get_if_exists=True,
max_pending_calls=100,
).remote(str(checkpoint_path))
# Verify the actor started successfully; surfaces any __init__ exception
# (e.g. LMDB open failure) before the pipeline begins so the user is not
# left believing checkpointing is active when it silently isn't.
ray.get(actor_handle.are_completed.remote([]), timeout=30) # type: ignore[attr-defined]

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed here, in fact we got bit by this (for a different reason)...

If during a retry checkpoint_path contains a lot of data, and the constructor of ResumabilityActor is loading it up, the process is also async.

The Ray-recommended way is to have a def wait() inside your actor and then do ray.get(actor_handle.wait()) explicitly before moving on to the next line of code.

Again see

ray.get(actor_handle.wait.remote())


try:
return executor.execute(self.stages, initial_tasks)
finally:
# The executor's ray.shutdown() may have run in its own
# finally:; reconnect to clean up the detached actor.
try:
ray.init(ignore_reinit_error=True)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as above, do with ray.init() and then do this...

actor_handle = ray.get_actor(ACTOR_NAME)
ray.get(actor_handle.close.remote(), timeout=10) # type: ignore[attr-defined]
ray.kill(actor_handle)
except Exception as e: # noqa: BLE001
logger.warning(f"resumability actor cleanup failed: {e}")
Comment on lines +310 to +319

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Fire-and-forget deltas from the last batch can be discarded before close() drains them

executor.execute() returns when all task-processing workers are done, but the workers' fire-and-forget apply_deltas.remote() calls are async Ray messages that may still be in transit. Ray does not guarantee cross-actor ordering: the driver's close.remote() message can arrive at the actor before the workers' final apply_deltas messages, because message ordering is only guaranteed per sender-receiver pair.

With max_concurrency=1 the actor processes one call at a time, so once close() runs it sets self._env = None. Any apply_deltas messages that arrive after that and happen to call _persist_completed will hit an AttributeError on None.begin() inside the actor. ray.kill immediately afterward discards all remaining queued messages, so completions from the pipeline's final batch may not be persisted to LMDB.

The safest fix is to enqueue a no-op drain call after executor.execute() returns and before close(), ensuring all prior fire-and-forget messages are ahead of it in the mailbox.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree here!

Comment on lines +313 to +319

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 ray.kill skipped when close() times out — stale actor not cleaned up

ray.get(close.remote(), timeout=10) and ray.kill share the same try block. If close() exceeds the 10-second deadline (e.g., LMDB flushing a large write), GetTimeoutError is caught by except Exception, and ray.kill is never called. The actor remains alive with lifetime="detached".

On a subsequent pipeline.run(checkpoint_path=<different_path>) against the same Ray cluster, get_if_exists=True silently returns the stale actor — which was initialised with the old path. All checkpointing in the new run writes to the wrong LMDB location and reads a stale completed-sources set, so resumability silently does the wrong thing without any error.

The fix is to unconditionally call ray.kill even if close() didn't succeed, by nesting the close attempt inside its own try/except before the kill.

4 changes: 3 additions & 1 deletion nemo_curator/tasks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,19 @@
from .file_group import FileGroupTask
from .image import ImageBatch, ImageObject
from .interleaved import InterleavedBatch
from .sentinels import EmptyTask, SentinelTask
from .sentinels import EmptyTask, FailedTask, NoneTask, SentinelTask
from .tasks import Task

__all__ = [
"AudioTask",
"DocumentBatch",
"EmptyTask",
"FailedTask",
"FileGroupTask",
"ImageBatch",
"ImageObject",
"InterleavedBatch",
"NoneTask",
"SentinelTask",
"Task",
]
29 changes: 26 additions & 3 deletions nemo_curator/tasks/sentinels.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,18 @@
# limitations under the License.
"""Payload-less marker tasks.

``EmptyTask`` seeds a pipeline (the implicit root id ``"0"``). All markers
share the :class:`SentinelTask` base and carry no payload (``data is None``).
Construct one with ``EmptyTask()``.
``EmptyTask`` seeds a pipeline (the implicit root id ``"0"``). The resumability
layer adds two more markers on the same :class:`SentinelTask` base:

- ``NoneTask`` — this slot was intentionally filtered. The resumability counter

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
- ``NoneTask`` this slot was intentionally filtered. The resumability counter
- ``NoneTask`` - this task was intentionally filtered. The resumability counter

treats it as a consumed branch (decrements). The adapter auto-wraps a
returned ``None`` as a ``NoneTask``.
- ``FailedTask`` — this slot failed and should be retried on resume. The counter

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
- ``FailedTask``this slot failed and should be retried on resume. The counter
- ``FailedTask``this slot failed and should be retried on resume. The resumability counter

is NOT decremented, so its source stays pending and reruns.

All carry no payload (``data is None``) and get their ``task_id`` assigned by
the executor adapter; sentinels are stripped before the next stage. Construct
with ``EmptyTask()`` / ``NoneTask()`` / ``FailedTask()``.
"""

from dataclasses import dataclass, field
Expand Down Expand Up @@ -52,3 +61,17 @@ class EmptyTask(SentinelTask):

dataset_name: str = "empty"
task_id: str = field(init=False, default="0")


@dataclass
class NoneTask(SentinelTask):
"""Marks a slot as intentionally filtered.

The resumability counter treats the slot as a consumed branch (it
decrements). The executor adapter wraps a ``None`` returned by a stage in a
``NoneTask`` and assigns its ``task_id``; like every sentinel it carries no
payload and is stripped before the next stage.
"""

dataset_name: str = "none"


@dataclass
class FailedTask(SentinelTask):
Comment thread
sarahyurick marked this conversation as resolved.
"""Marks a slot as failed → retried on resume (counter does NOT decrement)."""

dataset_name: str = "failed"
5 changes: 5 additions & 0 deletions nemo_curator/tasks/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,18 @@ class Task(ABC, Generic[T]):
NON-deterministic (differ across runs).
dataset_name: Name of the dataset this task belongs to.
_stage_perf: List of stages perfs this task has passed through.
_source_id: Identifier of the source (input partition) this task

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit-ish but the docstring for Task is out of order and does not include data or _metadata.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will fix if there are other updates

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit : let's fix this

descends from. Stamped at the source stage and inherited
downstream; used only by the (opt-in) resumability layer to
track which sources have completed. Empty for pre-source tasks.
"""

dataset_name: str
data: T
_stage_perf: list[StagePerfStats] = field(default_factory=list)
_metadata: dict[str, Any] = field(default_factory=dict)
task_id: str = field(init=False, default="")
_source_id: str = field(init=False, default="")

def __post_init__(self) -> None:
"""Post-initialization hook."""
Expand Down
Loading
Loading