diff --git a/nemo_curator/stages/base.py b/nemo_curator/stages/base.py index 81dd00a71c..f3d9bb53fb 100644 --- a/nemo_curator/stages/base.py +++ b/nemo_curator/stages/base.py @@ -19,7 +19,7 @@ import time from abc import ABC, ABCMeta, abstractmethod from inspect import isabstract -from typing import TYPE_CHECKING, Any, ClassVar, Generic, TypeVar, final +from typing import TYPE_CHECKING, Any, ClassVar, Generic, TypeVar, final, get_origin, get_type_hints from loguru import logger @@ -328,8 +328,28 @@ def ray_stage_spec(self) -> dict[str, Any]: Returns (dict[str, Any]): Dictionary containing Ray-specific configuration """ + if self._process_returns_list(): + return {"is_fanout_stage": True} return {} + @classmethod + def _process_returns_list(cls) -> bool: + """Return whether the stage's process annotation can return a list.""" + try: + return_annotation = get_type_hints(cls.process).get("return") + except (NameError, TypeError, AttributeError): + return_annotation = cls.process.__annotations__.get("return") + + return cls._annotation_includes_list(return_annotation) + + @staticmethod + def _annotation_includes_list(annotation: object) -> bool: + """Return whether an annotation is or includes a list type.""" + if isinstance(annotation, str): + return annotation.strip().startswith(("list[", "List[", "typing.List[")) + + return get_origin(annotation) is list + # --- Custom per-stage metrics helpers --- def _log_metrics(self, metrics: dict[str, float]) -> None: """Record custom metrics for this stage (e.g., sub-stage timings).""" diff --git a/nemo_curator/stages/deduplication/semantic/pairwise_io.py b/nemo_curator/stages/deduplication/semantic/pairwise_io.py index d951fd572a..73245e72b3 100644 --- a/nemo_curator/stages/deduplication/semantic/pairwise_io.py +++ b/nemo_curator/stages/deduplication/semantic/pairwise_io.py @@ -17,7 +17,6 @@ from loguru import logger from nemo_curator.backends.base import WorkerMetadata -from nemo_curator.backends.utils import RayStageSpecKeys from nemo_curator.stages.base import ProcessingStage from nemo_curator.stages.resources import Resources from nemo_curator.tasks import EmptyTask, FileGroupTask @@ -65,12 +64,6 @@ def setup(self, _: WorkerMetadata | None = None) -> None: self.fs = get_fs(self.input_path, storage_options=self.storage_options) self.path_normalizer = self.fs.unstrip_protocol if is_remote_url(self.input_path) else (lambda x: x) - def ray_stage_spec(self) -> dict[str, Any]: - """Ray stage specification for this stage.""" - return { - RayStageSpecKeys.IS_FANOUT_STAGE: True, - } - def xenna_stage_spec(self) -> dict[str, Any]: return { "num_workers_per_node": 1, diff --git a/nemo_curator/stages/file_partitioning.py b/nemo_curator/stages/file_partitioning.py index 75d6906501..2efcebb95a 100644 --- a/nemo_curator/stages/file_partitioning.py +++ b/nemo_curator/stages/file_partitioning.py @@ -17,7 +17,6 @@ from loguru import logger -from nemo_curator.backends.utils import RayStageSpecKeys from nemo_curator.stages.base import ProcessingStage from nemo_curator.stages.resources import Resources from nemo_curator.tasks import EmptyTask, FileGroupTask @@ -97,12 +96,6 @@ def inputs(self) -> tuple[list[str], list[str]]: def outputs(self) -> tuple[list[str], list[str]]: return [], [] - def ray_stage_spec(self) -> dict[str, Any]: - """Ray stage specification for this stage.""" - return { - RayStageSpecKeys.IS_FANOUT_STAGE: True, - } - def xenna_stage_spec(self) -> dict[str, Any]: return {"num_workers_per_node": 1} diff --git a/nemo_curator/stages/text/download/base/url_generation.py b/nemo_curator/stages/text/download/base/url_generation.py index 96ab2e27a9..84f3c8b2f7 100644 --- a/nemo_curator/stages/text/download/base/url_generation.py +++ b/nemo_curator/stages/text/download/base/url_generation.py @@ -76,11 +76,6 @@ def process(self, task: EmptyTask) -> list[FileGroupTask]: for i, url in enumerate(urls) ] - def ray_stage_spec(self) -> dict[str, Any]: - return { - "is_fanout_stage": True, - } - def xenna_stage_spec(self) -> dict[str, Any]: return { "num_workers_per_node": 1, diff --git a/nemo_curator/stages/video/clipping/clip_extraction_stages.py b/nemo_curator/stages/video/clipping/clip_extraction_stages.py index 28d20ea9ef..94c9473830 100644 --- a/nemo_curator/stages/video/clipping/clip_extraction_stages.py +++ b/nemo_curator/stages/video/clipping/clip_extraction_stages.py @@ -17,12 +17,10 @@ import subprocess import uuid from dataclasses import dataclass -from typing import Any from loguru import logger from nemo_curator.backends.base import WorkerMetadata -from nemo_curator.backends.utils import RayStageSpecKeys from nemo_curator.stages.base import ProcessingStage from nemo_curator.stages.resources import Resources from nemo_curator.tasks.video import Clip, Video, VideoTask @@ -90,19 +88,13 @@ def inputs(self) -> tuple[list[str], list[str]]: def outputs(self) -> tuple[list[str], list[str]]: return ["data"], [] - def ray_stage_spec(self) -> dict[str, Any]: - """Ray stage specification for this stage.""" - return { - RayStageSpecKeys.IS_FANOUT_STAGE: True, - } - - def process(self, task: VideoTask) -> VideoTask: + def process(self, task: VideoTask) -> list[VideoTask]: video = task.data if not video.clips: logger.warning(f"No clips to transcode for {video.input_video}. Skipping...") video.source_bytes = None - return task + return [task] with make_pipeline_temporary_dir(sub_dir="transcode") as tmp_dir: # write video to file diff --git a/tests/stages/audio/datasets/test_fleurs_create_initial_manifest.py b/tests/stages/audio/datasets/test_fleurs_create_initial_manifest.py index 313b551c1e..704f6b67f0 100644 --- a/tests/stages/audio/datasets/test_fleurs_create_initial_manifest.py +++ b/tests/stages/audio/datasets/test_fleurs_create_initial_manifest.py @@ -244,6 +244,13 @@ def test_prepare_fleurs_stage_dataset_does_not_recopy(tmp_path: Path) -> None: assert (output_path / "hy_am" / "train.tsv").is_file() +def test_create_initial_manifest_stage_is_ray_fanout_stage(tmp_path: Path) -> None: + stage_cls, _ = _import_stage_module() + stage = stage_cls(lang="hy_am", split="dev", raw_data_dir=tmp_path.as_posix()) + + assert stage.ray_stage_spec() == {"is_fanout_stage": True} + + def test_process_transcript_parses_tsv(tmp_path: Path) -> None: stage_cls, _ = _import_stage_module() lang = "hy_am" diff --git a/tests/stages/common/test_base.py b/tests/stages/common/test_base.py index dc7c3b0c42..47b672e1ea 100644 --- a/tests/stages/common/test_base.py +++ b/tests/stages/common/test_base.py @@ -51,6 +51,60 @@ def outputs(self) -> tuple[list[str], list[str]]: return [], [] +class FanoutProcessingStage(ProcessingStage[MockTask, MockTask]): + """ProcessingStage that returns multiple tasks.""" + + name = "FanoutProcessingStage" + + def process(self, task: MockTask) -> list[MockTask]: + return [task] + + +class MaybeFanoutProcessingStage(ProcessingStage[MockTask, MockTask]): + """ProcessingStage that may return multiple tasks.""" + + name = "MaybeFanoutProcessingStage" + + def process(self, task: MockTask) -> MockTask | list[MockTask]: + return task + + +class ExplicitMaybeFanoutProcessingStage(MaybeFanoutProcessingStage): + """Maybe-fanout stage that opts into Ray fanout explicitly.""" + + name = "ExplicitMaybeFanoutProcessingStage" + + def ray_stage_spec(self) -> dict[str, bool]: + return {"is_fanout_stage": True} + + +class StringAnnotatedFanoutProcessingStage(ProcessingStage[MockTask, MockTask]): + """ProcessingStage with a string return annotation.""" + + name = "StringAnnotatedFanoutProcessingStage" + + def process(self, task: MockTask) -> list[MockTask]: + return [task] + + +StringAnnotatedFanoutProcessingStage.process.__annotations__["return"] = "list[MissingTask]" + + +class AttributeErrorAnnotatedFanoutProcessingStage(ProcessingStage[MockTask, MockTask]): + """ProcessingStage with a string annotation that raises AttributeError in get_type_hints.""" + + name = "AttributeErrorAnnotatedFanoutProcessingStage" + + def process(self, task: MockTask) -> list[MockTask]: + return [task] + + +missing_annotation_namespace = object() +AttributeErrorAnnotatedFanoutProcessingStage.process.__annotations__["return"] = ( + "list[missing_annotation_namespace.MissingTask]" +) + + class TestProcessingStageWith: """Test the with_ method for ProcessingStage.""" @@ -268,6 +322,40 @@ def process(self, task: MockTask) -> MockTask: assert stage_with_custom2.resources == Resources(cpus=7.0) +class TestProcessingStageRaySpec: + """Test Ray stage spec defaults.""" + + def test_default_ray_stage_spec_empty_for_single_task_stage(self): + stage = ConcreteProcessingStage() + + assert stage.ray_stage_spec() == {} + + def test_ray_stage_spec_detects_fanout_stage(self): + stage = FanoutProcessingStage() + + assert stage.ray_stage_spec() == {"is_fanout_stage": True} + + def test_ray_stage_spec_does_not_infer_optional_fanout_stage(self): + stage = MaybeFanoutProcessingStage() + + assert stage.ray_stage_spec() == {} + + def test_ray_stage_spec_allows_optional_fanout_stage_to_opt_in(self): + stage = ExplicitMaybeFanoutProcessingStage() + + assert stage.ray_stage_spec() == {"is_fanout_stage": True} + + def test_ray_stage_spec_detects_string_annotated_fanout_stage(self): + stage = StringAnnotatedFanoutProcessingStage() + + assert stage.ray_stage_spec() == {"is_fanout_stage": True} + + def test_ray_stage_spec_falls_back_after_attribute_error(self): + stage = AttributeErrorAnnotatedFanoutProcessingStage() + + assert stage.ray_stage_spec() == {"is_fanout_stage": True} + + class TestProcessingStageOverriddenProperties: """Test that ProcessingStage raises an error if a derived class overrides the _name, _resources, or _batch_size property.""" diff --git a/tests/stages/deduplication/semantic/test_pairwise_io.py b/tests/stages/deduplication/semantic/test_pairwise_io.py index e5229fe786..d28bd020fc 100644 --- a/tests/stages/deduplication/semantic/test_pairwise_io.py +++ b/tests/stages/deduplication/semantic/test_pairwise_io.py @@ -45,6 +45,11 @@ def test_setup(self): assert stage.path_normalizer is not None assert stage.path_normalizer("/test/path") == "/test/path" + def test_ray_stage_spec_is_fanout_stage(self): + stage = ClusterWiseFilePartitioningStage("/test/path") + + assert stage.ray_stage_spec() == {"is_fanout_stage": True} + def test_process_finds_all_centroid_files(self, tmp_path: Path): """Test that process method finds all files in centroid directories.""" diff --git a/tests/stages/video/clipping/test_clip_transcoding_stage.py b/tests/stages/video/clipping/test_clip_transcoding_stage.py index d6b8c23fac..325b60a390 100644 --- a/tests/stages/video/clipping/test_clip_transcoding_stage.py +++ b/tests/stages/video/clipping/test_clip_transcoding_stage.py @@ -130,7 +130,10 @@ def test_process_no_clips(self) -> None: # Should return early and log warning mock_logger.warning.assert_called_once() assert "No clips to transcode" in mock_logger.warning.call_args[0][0] - assert result.data.source_bytes is None + assert isinstance(result, list) + assert len(result) == 1 + assert result[0] is self.mock_task + assert result[0].data.source_bytes is None @patch("nemo_curator.stages.video.clipping.clip_extraction_stages.make_pipeline_temporary_dir") @patch("nemo_curator.stages.video.clipping.clip_extraction_stages.grouping.split_by_chunk_size")