Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion nemo_curator/stages/text/io/reader/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

from nemo_curator.stages.text.io.reader.jsonl import JsonlReader
from nemo_curator.stages.text.io.reader.lance import LanceReader
from nemo_curator.stages.text.io.reader.parquet import ParquetReader

__all__ = ["JsonlReader", "ParquetReader"]
__all__ = ["JsonlReader", "LanceReader", "ParquetReader"]
88 changes: 63 additions & 25 deletions nemo_curator/stages/text/io/reader/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,11 @@
from __future__ import annotations

from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, TypeAlias, TypeVar

import numpy as np
import pandas as pd
import pyarrow as pa
import ray
from loguru import logger

Expand All @@ -28,20 +29,31 @@
from nemo_curator.backends.utils import RayStageSpecKeys
from nemo_curator.stages.base import ProcessingStage
from nemo_curator.tasks import DocumentBatch, FileGroupTask
from nemo_curator.tasks.tasks import Task

ReaderTask = TypeVar("ReaderTask", bound=Task)
ReaderData: TypeAlias = pd.DataFrame | pa.Table


@dataclass(frozen=True)
class ReaderOutput:
data: ReaderData
metadata: dict[str, Any] | None = None


@dataclass
class BaseReader(ProcessingStage[FileGroupTask, DocumentBatch]):
"""Common base for tabular file readers.
class BaseReader(ProcessingStage[ReaderTask, DocumentBatch]):
"""Common base for tabular readers.

Subclasses must implement the read_data method.
Subclasses must implement read_task for their input task type.
"""

fields: list[str] | None = None
read_kwargs: dict[str, Any] = field(default_factory=dict)
name: str = ""
_generate_ids: bool = False
_assign_ids: bool = False
allow_empty: bool = False

def __post_init__(self) -> None:
if self._generate_ids and self._assign_ids:
Expand All @@ -52,7 +64,7 @@ def inputs(self) -> tuple[list[str], list[str]]:
return [], []

def outputs(self) -> tuple[list[str], list[str]]:
output_fields = self.fields or []
output_fields = list(self.fields or [])
if self._generate_ids or self._assign_ids:
from nemo_curator.stages.deduplication.id_generator import CURATOR_DEDUP_ID_STR

Expand All @@ -72,24 +84,16 @@ def setup(self, _: WorkerMetadata | None = None) -> None:
)
raise RuntimeError(msg) from None

def process(self, task: FileGroupTask) -> DocumentBatch:
# Merge read kwargs with storage options precedence: task.storage_options > self.read_kwargs
effective_read_kwargs: dict[str, Any] = {}
if self.read_kwargs:
effective_read_kwargs.update(self.read_kwargs)

# Read the files
result = self.read_data(task.data, effective_read_kwargs, self.fields)
def process(self, task: ReaderTask) -> DocumentBatch:
output = self.read_task(task, self._effective_read_kwargs(), self.fields)
self._validate_result(task, output.data)
return self._document_batch(task, output)

# Validate the result
if (
(result is None)
or (hasattr(result, "empty") and result.empty)
or (hasattr(result, "num_rows") and result.num_rows == 0)
):
msg = f"No data read from files in task {task.task_id}"
raise ValueError(msg)
def _effective_read_kwargs(self) -> dict[str, Any]:
return dict(self.read_kwargs or {})

def _document_batch(self, task: ReaderTask, output: ReaderOutput) -> DocumentBatch:
result = output.data
# Apply IDs only for Pandas DataFrames
if isinstance(result, pd.DataFrame):
if self._generate_ids:
Expand All @@ -100,16 +104,29 @@ def process(self, task: FileGroupTask) -> DocumentBatch:
return DocumentBatch(
dataset_name=task.dataset_name,
data=result,
_metadata=task._metadata,
_metadata=self._output_metadata(task, output),
_stage_perf=task._stage_perf,
)

def _output_metadata(self, task: ReaderTask, _output: ReaderOutput) -> dict[str, Any]:
return task._metadata

def _validate_result(self, task: ReaderTask, result: ReaderData) -> None:
if self.allow_empty:
return
if (result is None) or (isinstance(result, pd.DataFrame) and result.empty) or (
isinstance(result, pa.Table) and result.num_rows == 0
):
msg = f"No data read from files in task {task.task_id}"
raise ValueError(msg)

# Subclass responsibilities -------------------------------------------------
def read_data(
def read_task(
self,
file_paths: list[str],
task: ReaderTask,
read_kwargs: dict[str, Any] | None,
fields: list[str] | None,
) -> pd.DataFrame | None: # pragma: no cover - abstract
) -> ReaderOutput: # pragma: no cover - abstract
raise NotImplementedError

# ID helpers ----------------------------------------------------------------
Expand All @@ -136,3 +153,24 @@ def _generate_ids_func(self, filepath: str | list[str], df: pd.DataFrame) -> pd.

def ray_stage_spec(self) -> dict[str, Any]:
return {RayStageSpecKeys.IS_ACTOR_STAGE: self._generate_ids or self._assign_ids}


@dataclass
class BaseFileReader(BaseReader[FileGroupTask]):
"""Base reader for file-group readers that consume lists of paths."""

def read_task(
self,
task: FileGroupTask,
read_kwargs: dict[str, Any] | None,
fields: list[str] | None,
) -> ReaderOutput:
return ReaderOutput(self.read_data(task.data, read_kwargs, fields))

def read_data(
self,
file_paths: list[str],
read_kwargs: dict[str, Any] | None,
fields: list[str] | None,
) -> ReaderData: # pragma: no cover - abstract
raise NotImplementedError
6 changes: 3 additions & 3 deletions nemo_curator/stages/text/io/reader/jsonl.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,11 @@
from nemo_curator.tasks import DocumentBatch, EmptyTask
from nemo_curator.utils.file_utils import FILETYPE_TO_DEFAULT_EXTENSIONS, pandas_select_columns

from .base import BaseReader
from .base import BaseFileReader


@dataclass
class JsonlReaderStage(BaseReader):
class JsonlReaderStage(BaseFileReader):
"""
Stage that processes a group of JSONL files into a DocumentBatch.
This stage accepts FileGroupTasks created by FilePartitioningStage
Expand All @@ -53,7 +53,7 @@ def read_data(
paths: list[str],
read_kwargs: dict[str, Any] | None = None,
fields: list[str] | None = None,
) -> pd.DataFrame | None:
) -> pd.DataFrame:
"""Read JSONL files using Pandas."""

# Normalize read_kwargs to a dict to avoid TypeError when None
Expand Down
Loading
Loading