From 944dc519284d940860b4e38f3e6fe6346560df62 Mon Sep 17 00:00:00 2001 From: Vibhu Jawa Date: Wed, 24 Jun 2026 21:02:29 +0000 Subject: [PATCH 1/2] Add Lance reader stage Signed-off-by: Vibhu Jawa --- .../stages/text/io/reader/__init__.py | 3 +- nemo_curator/stages/text/io/reader/base.py | 88 ++++-- nemo_curator/stages/text/io/reader/jsonl.py | 6 +- nemo_curator/stages/text/io/reader/lance.py | 257 ++++++++++++++++++ nemo_curator/stages/text/io/reader/parquet.py | 4 +- nemo_curator/utils/lance.py | 16 ++ pyproject.toml | 3 + tests/L0_Unit_Test_CPU.sh | 2 +- tests/stages/text/io/reader/test_lance.py | 115 ++++++++ tests/stages/text/io/reader/test_parquet.py | 18 ++ uv.lock | 53 +++- 11 files changed, 532 insertions(+), 33 deletions(-) create mode 100644 nemo_curator/stages/text/io/reader/lance.py create mode 100644 nemo_curator/utils/lance.py create mode 100644 tests/stages/text/io/reader/test_lance.py diff --git a/nemo_curator/stages/text/io/reader/__init__.py b/nemo_curator/stages/text/io/reader/__init__.py index 973757b078..4d27af23b1 100644 --- a/nemo_curator/stages/text/io/reader/__init__.py +++ b/nemo_curator/stages/text/io/reader/__init__.py @@ -13,6 +13,7 @@ # limitations under the License. from nemo_curator.stages.text.io.reader.jsonl import JsonlReader +from nemo_curator.stages.text.io.reader.lance import LanceReader from nemo_curator.stages.text.io.reader.parquet import ParquetReader -__all__ = ["JsonlReader", "ParquetReader"] +__all__ = ["JsonlReader", "LanceReader", "ParquetReader"] diff --git a/nemo_curator/stages/text/io/reader/base.py b/nemo_curator/stages/text/io/reader/base.py index 551dc463be..ff0042c796 100644 --- a/nemo_curator/stages/text/io/reader/base.py +++ b/nemo_curator/stages/text/io/reader/base.py @@ -15,10 +15,11 @@ from __future__ import annotations from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, TypeAlias, TypeVar import numpy as np import pandas as pd +import pyarrow as pa import ray from loguru import logger @@ -28,13 +29,23 @@ from nemo_curator.backends.utils import RayStageSpecKeys from nemo_curator.stages.base import ProcessingStage from nemo_curator.tasks import DocumentBatch, FileGroupTask +from nemo_curator.tasks.tasks import Task + +ReaderTask = TypeVar("ReaderTask", bound=Task) +ReaderData: TypeAlias = pd.DataFrame | pa.Table + + +@dataclass(frozen=True) +class ReaderOutput: + data: ReaderData + metadata: dict[str, Any] | None = None @dataclass -class BaseReader(ProcessingStage[FileGroupTask, DocumentBatch]): - """Common base for tabular file readers. +class BaseReader(ProcessingStage[ReaderTask, DocumentBatch]): + """Common base for tabular readers. - Subclasses must implement the read_data method. + Subclasses must implement read_task for their input task type. """ fields: list[str] | None = None @@ -42,6 +53,7 @@ class BaseReader(ProcessingStage[FileGroupTask, DocumentBatch]): name: str = "" _generate_ids: bool = False _assign_ids: bool = False + allow_empty: bool = False def __post_init__(self) -> None: if self._generate_ids and self._assign_ids: @@ -52,7 +64,7 @@ def inputs(self) -> tuple[list[str], list[str]]: return [], [] def outputs(self) -> tuple[list[str], list[str]]: - output_fields = self.fields or [] + output_fields = list(self.fields or []) if self._generate_ids or self._assign_ids: from nemo_curator.stages.deduplication.id_generator import CURATOR_DEDUP_ID_STR @@ -72,24 +84,16 @@ def setup(self, _: WorkerMetadata | None = None) -> None: ) raise RuntimeError(msg) from None - def process(self, task: FileGroupTask) -> DocumentBatch: - # Merge read kwargs with storage options precedence: task.storage_options > self.read_kwargs - effective_read_kwargs: dict[str, Any] = {} - if self.read_kwargs: - effective_read_kwargs.update(self.read_kwargs) - - # Read the files - result = self.read_data(task.data, effective_read_kwargs, self.fields) + def process(self, task: ReaderTask) -> DocumentBatch: + output = self.read_task(task, self._effective_read_kwargs(), self.fields) + self._validate_result(task, output.data) + return self._document_batch(task, output) - # Validate the result - if ( - (result is None) - or (hasattr(result, "empty") and result.empty) - or (hasattr(result, "num_rows") and result.num_rows == 0) - ): - msg = f"No data read from files in task {task.task_id}" - raise ValueError(msg) + def _effective_read_kwargs(self) -> dict[str, Any]: + return dict(self.read_kwargs or {}) + def _document_batch(self, task: ReaderTask, output: ReaderOutput) -> DocumentBatch: + result = output.data # Apply IDs only for Pandas DataFrames if isinstance(result, pd.DataFrame): if self._generate_ids: @@ -100,16 +104,29 @@ def process(self, task: FileGroupTask) -> DocumentBatch: return DocumentBatch( dataset_name=task.dataset_name, data=result, - _metadata=task._metadata, + _metadata=self._output_metadata(task, output), + _stage_perf=task._stage_perf, ) + def _output_metadata(self, task: ReaderTask, _output: ReaderOutput) -> dict[str, Any]: + return task._metadata + + def _validate_result(self, task: ReaderTask, result: ReaderData) -> None: + if self.allow_empty: + return + if (result is None) or (isinstance(result, pd.DataFrame) and result.empty) or ( + isinstance(result, pa.Table) and result.num_rows == 0 + ): + msg = f"No data read from files in task {task.task_id}" + raise ValueError(msg) + # Subclass responsibilities ------------------------------------------------- - def read_data( + def read_task( self, - file_paths: list[str], + task: ReaderTask, read_kwargs: dict[str, Any] | None, fields: list[str] | None, - ) -> pd.DataFrame | None: # pragma: no cover - abstract + ) -> ReaderOutput: # pragma: no cover - abstract raise NotImplementedError # ID helpers ---------------------------------------------------------------- @@ -136,3 +153,24 @@ def _generate_ids_func(self, filepath: str | list[str], df: pd.DataFrame) -> pd. def ray_stage_spec(self) -> dict[str, Any]: return {RayStageSpecKeys.IS_ACTOR_STAGE: self._generate_ids or self._assign_ids} + + +@dataclass +class BaseFileReader(BaseReader[FileGroupTask]): + """Base reader for file-group readers that consume lists of paths.""" + + def read_task( + self, + task: FileGroupTask, + read_kwargs: dict[str, Any] | None, + fields: list[str] | None, + ) -> ReaderOutput: + return ReaderOutput(self.read_data(task.data, read_kwargs, fields)) + + def read_data( + self, + file_paths: list[str], + read_kwargs: dict[str, Any] | None, + fields: list[str] | None, + ) -> ReaderData: # pragma: no cover - abstract + raise NotImplementedError diff --git a/nemo_curator/stages/text/io/reader/jsonl.py b/nemo_curator/stages/text/io/reader/jsonl.py index 7ae4c81003..94c2c8ab37 100644 --- a/nemo_curator/stages/text/io/reader/jsonl.py +++ b/nemo_curator/stages/text/io/reader/jsonl.py @@ -23,11 +23,11 @@ from nemo_curator.tasks import DocumentBatch, EmptyTask from nemo_curator.utils.file_utils import FILETYPE_TO_DEFAULT_EXTENSIONS, pandas_select_columns -from .base import BaseReader +from .base import BaseFileReader @dataclass -class JsonlReaderStage(BaseReader): +class JsonlReaderStage(BaseFileReader): """ Stage that processes a group of JSONL files into a DocumentBatch. This stage accepts FileGroupTasks created by FilePartitioningStage @@ -53,7 +53,7 @@ def read_data( paths: list[str], read_kwargs: dict[str, Any] | None = None, fields: list[str] | None = None, - ) -> pd.DataFrame | None: + ) -> pd.DataFrame: """Read JSONL files using Pandas.""" # Normalize read_kwargs to a dict to avoid TypeError when None diff --git a/nemo_curator/stages/text/io/reader/lance.py b/nemo_curator/stages/text/io/reader/lance.py new file mode 100644 index 0000000000..c202a5cd2c --- /dev/null +++ b/nemo_curator/stages/text/io/reader/lance.py @@ -0,0 +1,257 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any, Literal + +import pyarrow as pa + +from nemo_curator.backends.utils import RayStageSpecKeys +from nemo_curator.stages.base import CompositeStage, ProcessingStage +from nemo_curator.tasks import DocumentBatch, EmptyTask +from nemo_curator.tasks.tasks import Task +from nemo_curator.utils.hash_utils import get_deterministic_hash +from nemo_curator.utils.lance import ( + LANCE_FRAGID_COLUMN, + LANCE_ROWADDR_COLUMN, +) + +from .base import BaseReader, ReaderOutput + + +def _read_dataset_kwargs(read_kwargs: dict[str, Any], version: int | None = None) -> dict[str, Any]: + resolved_version = version if version is not None else read_kwargs.get("version") + options = {"storage_options": read_kwargs.get("storage_options"), "version": resolved_version} + return {**dict(read_kwargs.get("dataset_options") or {}), **{k: v for k, v in options.items() if v is not None}} + + +def _scanner_kwargs(read_kwargs: dict[str, Any], fields: list[str] | None) -> dict[str, Any]: + scanner_kwargs = dict(read_kwargs.get("scanner_options") or {}) + for key, value in read_kwargs.items(): + if key in {"dataset_options", "scanner_options", "storage_options", "version"}: + continue + scanner_kwargs[key] = value + if fields is not None: + scanner_kwargs["columns"] = fields + return scanner_kwargs + + +def _requested_blob_v2_columns(dataset: object, scanner_kwargs: dict[str, Any]) -> list[str]: + requested_columns = scanner_kwargs.get("columns") + if isinstance(requested_columns, dict | list): + requested_columns = set(requested_columns) + + return [ + field.name + for field in dataset.schema # type: ignore[attr-defined] + if getattr(field.type, "extension_name", None) == "lance.blob.v2" + and (requested_columns is None or field.name in requested_columns) + ] + + +def _restore_lance_blob_v2_columns(dataset: object, table: pa.Table, blob_columns: list[str]) -> pa.Table: + import lance + + rowaddrs = [int(value) for value in table["_rowaddr"].combine_chunks().to_pylist()] + for column in blob_columns: + payloads = [ + payload + for _, payload in dataset.read_blobs(column, addresses=rowaddrs, preserve_order=True) # type: ignore[attr-defined] + ] + table = table.set_column(table.schema.get_field_index(column), column, lance.blob_array(payloads)) + return table + + +def _fragment_ids_from_row_addresses(rowaddr_column: pa.ChunkedArray) -> pa.Array: + rowaddrs = rowaddr_column.combine_chunks().cast(pa.uint64()) + return pa.array([int(value) >> 32 for value in rowaddrs.to_pylist()], type=pa.uint64()) + + +def _add_lance_metadata(table: pa.Table) -> pa.Table: + if "_rowaddr" not in table.column_names: + msg = "Lance scanner did not return _rowaddr; include_lance_metadata requires row addresses" + raise ValueError(msg) + + table = table.rename_columns([LANCE_ROWADDR_COLUMN if name == "_rowaddr" else name for name in table.column_names]) + return table.append_column(LANCE_FRAGID_COLUMN, _fragment_ids_from_row_addresses(table[LANCE_ROWADDR_COLUMN])) + + +@dataclass +class LanceReadTask(Task[list[int]]): + data: list[int] = field(default_factory=list) + + @property + def num_items(self) -> int: + return len(self.data) + + def validate(self) -> bool: + return bool(self.data) + + def get_deterministic_id(self) -> str: + lance_metadata = self._metadata.get("lance") or {} + parts = [ + str(lance_metadata.get("path", self.dataset_name)), + str(lance_metadata.get("version", "")), + *(str(fragment_id) for fragment_id in self.data), + ] + return get_deterministic_hash(parts) + + +@dataclass +class LancePartitioningStage(ProcessingStage[EmptyTask, LanceReadTask]): + path: str + fragments_per_partition: int = 32 + fragment_ids: list[int] | None = None + read_kwargs: dict[str, Any] = field(default_factory=dict) + name: str = "lance_partitioning" + + def __post_init__(self) -> None: + if self.fragments_per_partition <= 0: + msg = "fragments_per_partition must be greater than 0" + raise ValueError(msg) + + def ray_stage_spec(self) -> dict[str, Any]: + return {RayStageSpecKeys.IS_FANOUT_STAGE: True} + + def process(self, _: EmptyTask) -> list[LanceReadTask]: + import lance + + dataset = lance.dataset(self.path, **_read_dataset_kwargs(self.read_kwargs)) + available_fragments = [fragment.fragment_id for fragment in dataset.get_fragments()] + if self.fragment_ids is None: + fragment_ids = available_fragments + else: + available = set(available_fragments) + missing = sorted(set(self.fragment_ids) - available) + if missing: + msg = f"Lance dataset does not contain requested fragment ids: {missing[:10]}" + raise ValueError(msg) + fragment_ids = list(self.fragment_ids) + + tasks = [] + for start in range(0, len(fragment_ids), self.fragments_per_partition): + owned_fragments = fragment_ids[start : start + self.fragments_per_partition] + tasks.append( + LanceReadTask( + dataset_name=self.path, + data=owned_fragments, + _metadata={ + "source_files": [self.path], + "lance": { + "path": self.path, + "version": dataset.version, + "fragment_ids": owned_fragments, + }, + }, + ) + ) + return tasks + + +@dataclass +class LanceReaderStage(BaseReader[LanceReadTask]): + """Read Lance fragment groups into Arrow batches.""" + + path: str = "" + fields: list[str] | None = None + read_kwargs: dict[str, Any] = field(default_factory=dict) + include_lance_metadata: bool = True + allow_empty: bool = True + name: str = "lance_reader" + + def __post_init__(self) -> None: + super().__post_init__() + if not self.path: + msg = "path is required" + raise ValueError(msg) + + def outputs(self) -> tuple[list[str], list[str]]: + output_fields = list(self.fields or self.read_kwargs.get("columns") or []) + if self.include_lance_metadata: + output_fields.extend([LANCE_ROWADDR_COLUMN, LANCE_FRAGID_COLUMN]) + return ["data"], output_fields + + def _output_metadata(self, task: LanceReadTask, output: ReaderOutput) -> dict[str, Any]: + return output.metadata if output.metadata is not None else task._metadata + + def read_task( + self, + task: LanceReadTask, + read_kwargs: dict[str, Any] | None, + fields: list[str] | None, + ) -> ReaderOutput: + import lance + from lance.schema import schema_to_json + + read_kwargs = {} if read_kwargs is None else read_kwargs + version = task._metadata["lance"]["version"] + dataset = lance.dataset(self.path, **_read_dataset_kwargs(read_kwargs, version=version)) + fragments = [dataset.get_fragment(fragment_id) for fragment_id in task.data] + scanner_kwargs = _scanner_kwargs(read_kwargs, fields) + blob_columns = _requested_blob_v2_columns(dataset, scanner_kwargs) + if self.include_lance_metadata or blob_columns: + scanner_kwargs["with_row_address"] = True + scanner_kwargs["fragments"] = fragments + table = dataset.scanner(**scanner_kwargs).to_table() + if blob_columns: + table = _restore_lance_blob_v2_columns(dataset, table, blob_columns) + if self.include_lance_metadata: + table = _add_lance_metadata(table) + elif blob_columns and "_rowaddr" in table.column_names: + table = table.drop_columns(["_rowaddr"]) + + metadata = dict(task._metadata) + lance_metadata = dict(metadata.get("lance") or {}) + lance_metadata["schema"] = schema_to_json(dataset.schema) + metadata["lance"] = lance_metadata + return ReaderOutput(table, metadata) + + +@dataclass +class LanceReader(CompositeStage[EmptyTask, DocumentBatch]): + """Read a Lance dataset into Curator ``DocumentBatch`` objects by fragment.""" + path: str + fragments_per_partition: int = 32 + fields: list[str] | None = None + read_kwargs: dict[str, Any] | None = None + include_lance_metadata: bool = True + fragment_ids: list[int] | None = None + task_type: Literal["document"] = "document" + name: str = "lance_reader" + + def __post_init__(self) -> None: + super().__init__() + self.read_kwargs = {} if self.read_kwargs is None else dict(self.read_kwargs) + + def decompose(self) -> list[ProcessingStage]: + if self.task_type != "document": + msg = f"Converting DocumentBatch to {self.task_type} is not supported yet." + raise NotImplementedError(msg) + + return [ + LancePartitioningStage( + path=self.path, + fragments_per_partition=self.fragments_per_partition, + fragment_ids=self.fragment_ids, + read_kwargs=self.read_kwargs, + ), + LanceReaderStage( + path=self.path, + fields=self.fields, + read_kwargs=self.read_kwargs, + include_lance_metadata=self.include_lance_metadata, + ), + ] diff --git a/nemo_curator/stages/text/io/reader/parquet.py b/nemo_curator/stages/text/io/reader/parquet.py index 0654b0f62c..6b6247060a 100644 --- a/nemo_curator/stages/text/io/reader/parquet.py +++ b/nemo_curator/stages/text/io/reader/parquet.py @@ -22,11 +22,11 @@ from nemo_curator.tasks import DocumentBatch, EmptyTask from nemo_curator.utils.file_utils import FILETYPE_TO_DEFAULT_EXTENSIONS -from .base import BaseReader +from .base import BaseFileReader @dataclass -class ParquetReaderStage(BaseReader): +class ParquetReaderStage(BaseFileReader): """ Stage that processes a group of Parquet files into a DocumentBatch. This stage accepts FileGroupTasks created by FilePartitioningStage diff --git a/nemo_curator/utils/lance.py b/nemo_curator/utils/lance.py new file mode 100644 index 0000000000..0657341d93 --- /dev/null +++ b/nemo_curator/utils/lance.py @@ -0,0 +1,16 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +LANCE_ROWADDR_COLUMN = "__lance_rowaddr" +LANCE_FRAGID_COLUMN = "__lance_fragid" diff --git a/pyproject.toml b/pyproject.toml index cedc8c48bf..0b3b4adf5d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -206,6 +206,8 @@ text_cpu = [ "sentence-transformers", ] +lance = ["pylance>=7"] + text_cuda12 = [ "nemo_curator[cuda12]", "nemo_curator[deduplication_cuda12]", @@ -281,6 +283,7 @@ all = [ "nemo_curator[image_cuda12]", "nemo_curator[inference_server]", "nemo_curator[interleaved_cuda12]", + "nemo_curator[lance]", "nemo_curator[math_cuda12]", "nemo_curator[sdg_cuda12]", "nemo_curator[text_cuda12]", diff --git a/tests/L0_Unit_Test_CPU.sh b/tests/L0_Unit_Test_CPU.sh index 0a498beb78..032838e5b2 100644 --- a/tests/L0_Unit_Test_CPU.sh +++ b/tests/L0_Unit_Test_CPU.sh @@ -23,6 +23,6 @@ export UV_NO_CACHE=1 rm -rf .venv uv venv --seed --python "${PY_VERSION}" -uv sync --no-progress --link-mode copy --locked --extra audio_cpu --extra sdg_cpu --extra text_cpu --extra video_cpu --group test +uv sync --no-progress --link-mode copy --locked --extra audio_cpu --extra sdg_cpu --extra text_cpu --extra video_cpu --extra lance --group test source .venv/bin/activate coverage run -a --branch --source=nemo_curator -m pytest -v "tests/$FOLDER" -m "not gpu" diff --git a/tests/stages/text/io/reader/test_lance.py b/tests/stages/text/io/reader/test_lance.py new file mode 100644 index 0000000000..a0d76f060d --- /dev/null +++ b/tests/stages/text/io/reader/test_lance.py @@ -0,0 +1,115 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pathlib import Path + +import pyarrow as pa +import pytest + +from nemo_curator.stages.text.io.reader.base import BaseReader +from nemo_curator.stages.text.io.reader.lance import ( + LANCE_FRAGID_COLUMN, + LANCE_ROWADDR_COLUMN, + LancePartitioningStage, + LanceReader, + LanceReaderStage, +) +from nemo_curator.tasks import EmptyTask + +pytest.importorskip("lance") + + +def _write_lance_dataset(path: Path) -> None: + import lance + + table = pa.table( + { + "snapshot_id": ["CC-MAIN-2025-26", "CC-MAIN-2025-18", "CC-MAIN-2025-26", "CC-MAIN-2025-26"], + "url": ["https://a.example", "https://b.example", "https://c.example", "https://d.example"], + "text": ["alpha one", "beta two", "gamma three", "delta four"], + "content_zlib": lance.blob_array([b"html-a", b"html-b", b"html-c", b"html-d"]), + }, + schema=pa.schema( + [ + pa.field("snapshot_id", pa.string()), + pa.field("url", pa.string()), + pa.field("text", pa.string()), + lance.blob_field("content_zlib"), + ] + ), + ) + lance.write_dataset(table, str(path), mode="create", max_rows_per_file=2, max_rows_per_group=2, data_storage_version="2.2") + + +def test_lance_reader_partitions_filters_blobs_and_metadata(tmp_path: Path): + dataset_path = tmp_path / "docs.lance" + _write_lance_dataset(dataset_path) + read_tasks = LancePartitioningStage(path=str(dataset_path), fragments_per_partition=1).process(EmptyTask) + + assert issubclass(LanceReaderStage, BaseReader) + assert len(read_tasks) == 2 + assert read_tasks[0].dataset_name == str(dataset_path) + assert {fragment_id for task in read_tasks for fragment_id in task.data} == {0, 1} + assert read_tasks[0].get_deterministic_id() != read_tasks[1].get_deterministic_id() + + reader = LanceReaderStage( + path=str(dataset_path), + fields=["snapshot_id", "url", "content_zlib"], + read_kwargs={"filter": "snapshot_id = 'CC-MAIN-2025-26'", "scanner_options": {"batch_size": 2}}, + ) + batches = [batch for task in read_tasks if (batch := reader.process(task))] + + seen_fragments: set[int] = set() + for batch in batches: + table = batch.to_pyarrow() + assert "schema" in batch._metadata["lance"] + assert LANCE_ROWADDR_COLUMN in table.column_names + assert LANCE_FRAGID_COLUMN in table.column_names + assert table.schema.field("content_zlib").type.extension_name == "lance.blob.v2" + fragids = {int(value) for value in table[LANCE_FRAGID_COLUMN].combine_chunks().to_pylist()} + assert seen_fragments.isdisjoint(fragids) + seen_fragments.update(fragids) + assert seen_fragments == {0, 1} + + +def test_lance_reader_columns_empty_filters_and_fields_override(tmp_path: Path): + dataset_path = tmp_path / "docs.lance" + _write_lance_dataset(dataset_path) + task = LancePartitioningStage(path=str(dataset_path)).process(EmptyTask)[0] + + batch = LanceReaderStage(path=str(dataset_path), read_kwargs={"columns": ["url"]}, include_lance_metadata=False).process(task) + assert batch.to_pyarrow().column_names == ["url"] + + empty_batch = LanceReaderStage(path=str(dataset_path), read_kwargs={"filter": "snapshot_id = 'missing'"}).process(task) + empty_table = empty_batch.to_pyarrow() + assert empty_table.num_rows == 0 + assert LANCE_ROWADDR_COLUMN in empty_table.column_names + assert LANCE_FRAGID_COLUMN in empty_table.column_names + + _, reader_stage = LanceReader(path="example.lance", fields=["a", "b"], read_kwargs={"columns": ["ignored"]}).decompose() + assert reader_stage.fields == ["a", "b"] + assert reader_stage.include_lance_metadata is True + + +def test_lance_reader_uses_partition_version(tmp_path: Path): + import lance + + dataset_path = tmp_path / "docs.lance" + lance.write_dataset(pa.table({"text": ["old"]}), str(dataset_path), mode="create", max_rows_per_file=1) + task = LancePartitioningStage(path=str(dataset_path)).process(EmptyTask)[0] + lance.write_dataset(pa.table({"text": ["new"]}), str(dataset_path), mode="overwrite", max_rows_per_file=1) + + batch = LanceReaderStage(path=str(dataset_path), fields=["text"], include_lance_metadata=False).process(task) + + assert batch.to_pyarrow()["text"].to_pylist() == ["old"] diff --git a/tests/stages/text/io/reader/test_parquet.py b/tests/stages/text/io/reader/test_parquet.py index 4bb93e72bf..6f7cc7be55 100644 --- a/tests/stages/text/io/reader/test_parquet.py +++ b/tests/stages/text/io/reader/test_parquet.py @@ -86,6 +86,7 @@ def test_parquet_reader_stage_pandas_reads_and_concatenates(sample_parquet_files ): out = stage.process(task) assert isinstance(out, DocumentBatch) + assert out._metadata == {"source_files": sample_parquet_files[:2]} df = out.to_pandas() assert isinstance(df, pd.DataFrame) @@ -157,6 +158,23 @@ def test_parquet_reader_stage_pandas_raises_when_all_columns_missing(tmp_path: P _ = stage.process(task) +def test_parquet_reader_stage_empty_file_uses_base_reader_policy(tmp_path: Path): + f = tmp_path / "empty.parquet" + pd.DataFrame({"text": pd.Series(dtype="string"), "score": pd.Series(dtype="float64")}).to_parquet( + f, + index=False, + ) + task = _make_file_group_task([str(f)]) + + with pytest.raises(ValueError, match="No data read from files"): + ParquetReaderStage().process(task) + + out = ParquetReaderStage(allow_empty=True).process(task) + assert isinstance(out, DocumentBatch) + assert out.num_items == 0 + assert out.to_pandas().columns.tolist() == ["text", "score"] + + def test_parquet_reader_stage_pyarrow_reads_and_concatenates(tmp_path: Path): f1 = tmp_path / "a.parquet" f2 = tmp_path / "b.parquet" diff --git a/uv.lock b/uv.lock index 4c0bfd823f..2f48372ef3 100644 --- a/uv.lock +++ b/uv.lock @@ -4009,6 +4009,33 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/08/10/9f8af3e6f569685ce3af7faab51c8dd9d93b9c38eba339ca31c746119447/kubernetes-32.0.1-py2.py3-none-any.whl", hash = "sha256:35282ab8493b938b08ab5526c7ce66588232df00ef5e1dbe88a419107dc10998", size = 1988070, upload-time = "2025-02-18T21:06:31.391Z" }, ] +[[package]] +name = "lance-namespace" +version = "0.7.7" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "lance-namespace-urllib3-client" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/06/5c/9822af615fc1bd3ee1073994696c739aecde377be32435ec3303aed1bc5d/lance_namespace-0.7.7.tar.gz", hash = "sha256:d00b525f2e26993a6c61668e798bca6c808605ab8a79f29f86a1a1af92d91ae2", size = 10754, upload-time = "2026-05-20T17:32:59.45Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/11/43/186acc1156da20c351db196e2b6241b2453b16dc1b4cc8e0a626667ca471/lance_namespace-0.7.7-py3-none-any.whl", hash = "sha256:477a7ca6b5e1f673a2c9ba52f42d6e8e3ff7c27a601392a21eb90fba98d0309b", size = 12581, upload-time = "2026-05-20T17:32:57.389Z" }, +] + +[[package]] +name = "lance-namespace-urllib3-client" +version = "0.7.7" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pydantic" }, + { name = "python-dateutil" }, + { name = "typing-extensions" }, + { name = "urllib3" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/07/95/38ab81ccc1e09beeecd8ddfc61b8bc73831dc5053db1e3f9021f64a4896b/lance_namespace_urllib3_client-0.7.7.tar.gz", hash = "sha256:4d8c066628c17c6a10cf643b51a7f7ae1bfb8a614d9cc54a5af38a4ba2b4b102", size = 202930, upload-time = "2026-05-20T17:32:58.308Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/35/96/5483e48e40433b1d078183c15a92c99e59a156041b0260e7f18ee34e7c08/lance_namespace_urllib3_client-0.7.7-py3-none-any.whl", hash = "sha256:9221c3e00fd89f0c811953d94b32d2ea527765280460a174f5872dc8a74c0ed6", size = 334767, upload-time = "2026-05-20T17:32:55.883Z" }, +] + [[package]] name = "lark" version = "1.2.2" @@ -5108,6 +5135,7 @@ all = [ { name = "pycld2" }, { name = "pycuda" }, { name = "pydub" }, + { name = "pylance" }, { name = "pylibcugraph-cu12" }, { name = "pylibraft-cu12" }, { name = "pynvvideocodec", marker = "platform_machine == 'x86_64' and sys_platform != 'darwin'" }, @@ -5265,6 +5293,9 @@ interleaved-cuda12 = [ { name = "timm" }, { name = "vllm", marker = "platform_machine == 'x86_64' and sys_platform != 'darwin'" }, ] +lance = [ + { name = "pylance" }, +] math-cpu = [ { name = "beautifulsoup4" }, { name = "boto3" }, @@ -5516,6 +5547,7 @@ requires-dist = [ { name = "nemo-curator", extras = ["inference-server"], marker = "extra == 'sdg-cuda12'" }, { name = "nemo-curator", extras = ["interleaved-cpu"], marker = "extra == 'interleaved-cuda12'" }, { name = "nemo-curator", extras = ["interleaved-cuda12"], marker = "extra == 'all'" }, + { name = "nemo-curator", extras = ["lance"], marker = "extra == 'all'" }, { name = "nemo-curator", extras = ["math-cpu"], marker = "extra == 'math-cuda12'" }, { name = "nemo-curator", extras = ["math-cuda12"], marker = "extra == 'all'" }, { name = "nemo-curator", extras = ["sdg-cpu"], marker = "extra == 'sdg-cuda12'" }, @@ -5564,6 +5596,7 @@ requires-dist = [ { name = "pycld2", marker = "extra == 'text-cpu'" }, { name = "pycuda", marker = "extra == 'video-cuda12'" }, { name = "pydub", marker = "extra == 'audio-common'", specifier = ">=0.25.1" }, + { name = "pylance", marker = "extra == 'lance'", specifier = ">=7" }, { name = "pylibcugraph-cu12", marker = "extra == 'deduplication-cuda12'", specifier = "==25.10.*" }, { name = "pylibraft-cu12", marker = "extra == 'deduplication-cuda12'", specifier = "==25.10.*" }, { name = "pynvvideocodec", marker = "platform_machine == 'x86_64' and sys_platform != 'darwin' and extra == 'video-cuda12'", specifier = "==2.0.2" }, @@ -5613,7 +5646,7 @@ requires-dist = [ { name = "warcio", marker = "extra == 'text-cpu'" }, { name = "whisperx", marker = "platform_machine == 'x86_64' and sys_platform != 'darwin' and extra == 'audio-common'", specifier = ">=3.8.4" }, ] -provides-extras = ["cuda12", "vllm", "inference-server", "deduplication-cuda12", "audio-common", "audio-cpu", "audio-cuda12", "image-cpu", "image-cuda12", "translation-common", "translation-metrics", "translation-segmentation", "translation-aws", "translation-google", "translation-nmt", "translation-all", "text-cpu", "text-cuda12", "video-cpu", "video-cuda12", "math-cpu", "math-cuda12", "interleaved-cpu", "interleaved-cuda12", "sdg-cpu", "sdg-cuda12", "all"] +provides-extras = ["cuda12", "vllm", "inference-server", "deduplication-cuda12", "audio-common", "audio-cpu", "audio-cuda12", "image-cpu", "image-cuda12", "translation-common", "translation-metrics", "translation-segmentation", "translation-aws", "translation-google", "translation-nmt", "translation-all", "text-cpu", "lance", "text-cuda12", "video-cpu", "video-cuda12", "math-cpu", "math-cuda12", "interleaved-cpu", "interleaved-cuda12", "sdg-cpu", "sdg-cuda12", "all"] [package.metadata.requires-dev] build = [ @@ -8213,6 +8246,24 @@ crypto = [ { name = "cryptography" }, ] +[[package]] +name = "pylance" +version = "7.0.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "lance-namespace" }, + { name = "numpy" }, + { name = "pyarrow" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/ac/ad/2f64921bf346e7075aef24a72595db44821724a3d89a9a92dd24e79632aa/pylance-7.0.0-cp39-abi3-macosx_11_0_arm64.whl", hash = "sha256:98422021975be76e72b1572f41b8c9abb3bee5bdc9bfa5e9ce731110a65ed4d1", size = 62134146, upload-time = "2026-05-27T21:59:37.459Z" }, + { url = "https://files.pythonhosted.org/packages/73/1c/c5a01bee0160b55d9a98895cbd33091d038f0a0995b121ab72e629008d02/pylance-7.0.0-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4bec86ee5b6fbd8bfc493e653f0a1fba0303cfe5492b9b46fc25ab908edc7183", size = 65373684, upload-time = "2026-05-27T22:04:01.584Z" }, + { url = "https://files.pythonhosted.org/packages/eb/da/1fe8b8f7dbfe734d76af76acc994fc360a0d0c79a4874ef69f5a72a58fe3/pylance-7.0.0-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:881491432c53184e52f8d1db8d5f872f39a03f36fb104bec77b33d379519d8b5", size = 69458555, upload-time = "2026-05-27T22:16:50.567Z" }, + { url = "https://files.pythonhosted.org/packages/76/f0/dd505cf3fd0226ab9d94759acd713125af1d3bfacfd80bbd52e3b9f89509/pylance-7.0.0-cp39-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:18453999e7fff4f76b16d6b7882c9df0628bd142ff95e2461bd7dd5ee3fe0af3", size = 65394430, upload-time = "2026-05-27T22:05:30.923Z" }, + { url = "https://files.pythonhosted.org/packages/17/ba/2357b81034f28eb00790e258ed140289a6a887a7468ca9df6349fd186b27/pylance-7.0.0-cp39-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:04a58051d408c60fe76d41a220dcaf8fea8fb6d1aa0ca78a709b60bc3cc8d19a", size = 69473470, upload-time = "2026-05-27T22:17:18.935Z" }, + { url = "https://files.pythonhosted.org/packages/1f/ec/5c00b6303a67d787f9475141832cbdc513d674ac3dcaeef8a7b169905e65/pylance-7.0.0-cp39-abi3-win_amd64.whl", hash = "sha256:467d4864af047eaab4e1370e2f1e88e2c6f507c079874421116cb41d78bc3629", size = 74792863, upload-time = "2026-05-27T22:19:23.875Z" }, +] + [[package]] name = "pylibcudf-cu12" version = "25.10.0" From bc207532c8c720558de6c558d2e715a32ad6bbdb Mon Sep 17 00:00:00 2001 From: Vibhu Jawa Date: Wed, 24 Jun 2026 23:00:37 +0000 Subject: [PATCH 2/2] Refine Lance reader option handling Signed-off-by: Vibhu Jawa --- nemo_curator/stages/text/io/reader/lance.py | 200 +++++++++++++------- nemo_curator/utils/lance.py | 16 ++ tests/stages/text/io/reader/test_lance.py | 47 ++++- 3 files changed, 186 insertions(+), 77 deletions(-) diff --git a/nemo_curator/stages/text/io/reader/lance.py b/nemo_curator/stages/text/io/reader/lance.py index c202a5cd2c..f6bfe62e6f 100644 --- a/nemo_curator/stages/text/io/reader/lance.py +++ b/nemo_curator/stages/text/io/reader/lance.py @@ -15,9 +15,10 @@ from __future__ import annotations from dataclasses import dataclass, field -from typing import Any, Literal +from typing import TYPE_CHECKING, Any, Literal -import pyarrow as pa +if TYPE_CHECKING: + import pyarrow as pa from nemo_curator.backends.utils import RayStageSpecKeys from nemo_curator.stages.base import CompositeStage, ProcessingStage @@ -27,70 +28,23 @@ from nemo_curator.utils.lance import ( LANCE_FRAGID_COLUMN, LANCE_ROWADDR_COLUMN, + add_lance_metadata_columns, ) from .base import BaseReader, ReaderOutput -def _read_dataset_kwargs(read_kwargs: dict[str, Any], version: int | None = None) -> dict[str, Any]: - resolved_version = version if version is not None else read_kwargs.get("version") - options = {"storage_options": read_kwargs.get("storage_options"), "version": resolved_version} - return {**dict(read_kwargs.get("dataset_options") or {}), **{k: v for k, v in options.items() if v is not None}} - - -def _scanner_kwargs(read_kwargs: dict[str, Any], fields: list[str] | None) -> dict[str, Any]: - scanner_kwargs = dict(read_kwargs.get("scanner_options") or {}) - for key, value in read_kwargs.items(): - if key in {"dataset_options", "scanner_options", "storage_options", "version"}: - continue - scanner_kwargs[key] = value - if fields is not None: - scanner_kwargs["columns"] = fields - return scanner_kwargs - - -def _requested_blob_v2_columns(dataset: object, scanner_kwargs: dict[str, Any]) -> list[str]: - requested_columns = scanner_kwargs.get("columns") - if isinstance(requested_columns, dict | list): - requested_columns = set(requested_columns) - - return [ - field.name - for field in dataset.schema # type: ignore[attr-defined] - if getattr(field.type, "extension_name", None) == "lance.blob.v2" - and (requested_columns is None or field.name in requested_columns) - ] - - -def _restore_lance_blob_v2_columns(dataset: object, table: pa.Table, blob_columns: list[str]) -> pa.Table: - import lance - - rowaddrs = [int(value) for value in table["_rowaddr"].combine_chunks().to_pylist()] - for column in blob_columns: - payloads = [ - payload - for _, payload in dataset.read_blobs(column, addresses=rowaddrs, preserve_order=True) # type: ignore[attr-defined] - ] - table = table.set_column(table.schema.get_field_index(column), column, lance.blob_array(payloads)) - return table - - -def _fragment_ids_from_row_addresses(rowaddr_column: pa.ChunkedArray) -> pa.Array: - rowaddrs = rowaddr_column.combine_chunks().cast(pa.uint64()) - return pa.array([int(value) >> 32 for value in rowaddrs.to_pylist()], type=pa.uint64()) - - -def _add_lance_metadata(table: pa.Table) -> pa.Table: - if "_rowaddr" not in table.column_names: - msg = "Lance scanner did not return _rowaddr; include_lance_metadata requires row addresses" - raise ValueError(msg) +@dataclass +class LanceReadTask(Task[list[int]]): + """Task containing Lance fragment ids assigned to one read partition. - table = table.rename_columns([LANCE_ROWADDR_COLUMN if name == "_rowaddr" else name for name in table.column_names]) - return table.append_column(LANCE_FRAGID_COLUMN, _fragment_ids_from_row_addresses(table[LANCE_ROWADDR_COLUMN])) + This is created by ``LancePartitioningStage`` and consumed by + ``LanceReaderStage``. + Args: + data: Lance fragment ids to read. + """ -@dataclass -class LanceReadTask(Task[list[int]]): data: list[int] = field(default_factory=list) @property @@ -112,6 +66,18 @@ def get_deterministic_id(self) -> str: @dataclass class LancePartitioningStage(ProcessingStage[EmptyTask, LanceReadTask]): + """Stage that partitions a Lance dataset into fragment-id read tasks. + + The stage opens the dataset once, records the resolved Lance version in + each task, and emits fragment groups for ``LanceReaderStage``. + + Args: + path: Path or URI of the Lance dataset. + fragments_per_partition: Number of Lance fragments assigned to each read task. + fragment_ids: Optional explicit fragment ids to read. Defaults to all fragments. Duplicates are ignored. + read_kwargs: Keyword arguments for opening the Lance dataset. + """ + path: str fragments_per_partition: int = 32 fragment_ids: list[int] | None = None @@ -122,24 +88,36 @@ def __post_init__(self) -> None: if self.fragments_per_partition <= 0: msg = "fragments_per_partition must be greater than 0" raise ValueError(msg) + self.read_kwargs = dict(self.read_kwargs or {}) def ray_stage_spec(self) -> dict[str, Any]: return {RayStageSpecKeys.IS_FANOUT_STAGE: True} + def _dataset_kwargs(self) -> dict[str, Any]: + read_kwargs = dict(self.read_kwargs) + dataset_kwargs = dict(read_kwargs.pop("dataset_options", {}) or {}) + version = dataset_kwargs.pop("version", None) + version = read_kwargs.pop("version", version) + if version is not None: + dataset_kwargs["version"] = version + storage_options = read_kwargs.pop("storage_options", None) + if storage_options is not None: + dataset_kwargs["storage_options"] = storage_options + return dataset_kwargs + def process(self, _: EmptyTask) -> list[LanceReadTask]: import lance - dataset = lance.dataset(self.path, **_read_dataset_kwargs(self.read_kwargs)) - available_fragments = [fragment.fragment_id for fragment in dataset.get_fragments()] + dataset = lance.dataset(self.path, **self._dataset_kwargs()) + available_fragments = sorted(fragment.fragment_id for fragment in dataset.get_fragments()) if self.fragment_ids is None: fragment_ids = available_fragments else: - available = set(available_fragments) - missing = sorted(set(self.fragment_ids) - available) + fragment_ids = sorted(set(self.fragment_ids)) + missing = sorted(set(fragment_ids) - set(available_fragments)) if missing: msg = f"Lance dataset does not contain requested fragment ids: {missing[:10]}" raise ValueError(msg) - fragment_ids = list(self.fragment_ids) tasks = [] for start in range(0, len(fragment_ids), self.fragments_per_partition): @@ -163,7 +141,18 @@ def process(self, _: EmptyTask) -> list[LanceReadTask]: @dataclass class LanceReaderStage(BaseReader[LanceReadTask]): - """Read Lance fragment groups into Arrow batches.""" + """Stage that reads Lance fragment groups into ``DocumentBatch`` objects. + + This stage consumes ``LanceReadTask`` objects from ``LancePartitioningStage`` + and reads the pinned dataset version stored in each task. + + Args: + path: Path or URI of the Lance dataset. + fields: Optional columns to read. Overrides ``columns`` in ``read_kwargs``. + read_kwargs: Keyword arguments for Lance dataset and scanner construction. + include_lance_metadata: Whether to include row-address and fragment-id metadata columns. + allow_empty: Whether filtered reads may return empty tables without raising. + """ path: str = "" fields: list[str] | None = None @@ -177,9 +166,14 @@ def __post_init__(self) -> None: if not self.path: msg = "path is required" raise ValueError(msg) + self.read_kwargs = dict(self.read_kwargs or {}) def outputs(self) -> tuple[list[str], list[str]]: - output_fields = list(self.fields or self.read_kwargs.get("columns") or []) + scanner_options = self.read_kwargs.get("scanner_options") or {} + columns = self.fields if self.fields is not None else self.read_kwargs.get("columns") + if columns is None: + columns = scanner_options.get("columns") + output_fields = list(columns or []) if self.include_lance_metadata: output_fields.extend([LANCE_ROWADDR_COLUMN, LANCE_FRAGID_COLUMN]) return ["data"], output_fields @@ -187,6 +181,45 @@ def outputs(self) -> tuple[list[str], list[str]]: def _output_metadata(self, task: LanceReadTask, output: ReaderOutput) -> dict[str, Any]: return output.metadata if output.metadata is not None else task._metadata + def _restore_blob_v2_columns(self, dataset: object, table: pa.Table, blob_columns: list[str]) -> pa.Table: + import lance + + rowaddrs = [int(value) for value in table["_rowaddr"].combine_chunks().to_pylist()] + for column in blob_columns: + payloads = [ + payload + for _, payload in dataset.read_blobs(column, addresses=rowaddrs, preserve_order=True) # type: ignore[attr-defined] + ] + table = table.set_column(table.schema.get_field_index(column), column, lance.blob_array(payloads)) + return table + + def _task_version(self, task: LanceReadTask) -> int: + version = (task._metadata.get("lance") or {}).get("version") + if version is None: + msg = f"Lance read task {task.task_id} is missing a pinned Lance version" + raise ValueError(msg) + return version + + def _dataset_kwargs(self, read_kwargs: dict[str, Any], version: int) -> dict[str, Any]: + dataset_kwargs = dict(read_kwargs.pop("dataset_options", {}) or {}) + requested_version = dataset_kwargs.pop("version", None) + requested_version = read_kwargs.pop("version", requested_version) + if requested_version is not None and requested_version != version: + msg = f"Lance read version mismatch: task version={version}, requested version={requested_version}" + raise ValueError(msg) + dataset_kwargs["version"] = version + storage_options = read_kwargs.pop("storage_options", None) + if storage_options is not None: + dataset_kwargs["storage_options"] = storage_options + return dataset_kwargs + + def _scanner_kwargs(self, read_kwargs: dict[str, Any], fields: list[str] | None) -> dict[str, Any]: + scanner_kwargs = dict(read_kwargs.pop("scanner_options", {}) or {}) + scanner_kwargs.update(read_kwargs) + if fields is not None: + scanner_kwargs["columns"] = fields + return scanner_kwargs + def read_task( self, task: LanceReadTask, @@ -196,20 +229,26 @@ def read_task( import lance from lance.schema import schema_to_json - read_kwargs = {} if read_kwargs is None else read_kwargs - version = task._metadata["lance"]["version"] - dataset = lance.dataset(self.path, **_read_dataset_kwargs(read_kwargs, version=version)) + read_kwargs = dict(read_kwargs or {}) + dataset_kwargs = self._dataset_kwargs(read_kwargs, self._task_version(task)) + scanner_kwargs = self._scanner_kwargs(read_kwargs, fields) + dataset = lance.dataset(self.path, **dataset_kwargs) fragments = [dataset.get_fragment(fragment_id) for fragment_id in task.data] - scanner_kwargs = _scanner_kwargs(read_kwargs, fields) - blob_columns = _requested_blob_v2_columns(dataset, scanner_kwargs) + requested_columns = scanner_kwargs.get("columns") + blob_columns = [ + field.name + for field in dataset.schema + if getattr(field.type, "extension_name", None) == "lance.blob.v2" + and (requested_columns is None or field.name in requested_columns) + ] if self.include_lance_metadata or blob_columns: scanner_kwargs["with_row_address"] = True scanner_kwargs["fragments"] = fragments table = dataset.scanner(**scanner_kwargs).to_table() if blob_columns: - table = _restore_lance_blob_v2_columns(dataset, table, blob_columns) + table = self._restore_blob_v2_columns(dataset, table, blob_columns) if self.include_lance_metadata: - table = _add_lance_metadata(table) + table = add_lance_metadata_columns(table) elif blob_columns and "_rowaddr" in table.column_names: table = table.drop_columns(["_rowaddr"]) @@ -222,7 +261,22 @@ def read_task( @dataclass class LanceReader(CompositeStage[EmptyTask, DocumentBatch]): - """Read a Lance dataset into Curator ``DocumentBatch`` objects by fragment.""" + """Composite stage for reading Lance datasets. + + This high-level stage decomposes into: + 1. ``LancePartitioningStage`` - partitions Lance fragments into read tasks. + 2. ``LanceReaderStage`` - reads fragment groups into ``DocumentBatch`` objects. + + Args: + path: Path or URI of the Lance dataset. + fragments_per_partition: Number of Lance fragments assigned to each read task. + fields: Optional columns to read. + read_kwargs: Keyword arguments for Lance dataset and scanner construction. + include_lance_metadata: Whether to include row-address and fragment-id metadata columns. + fragment_ids: Optional explicit fragment ids to read. Defaults to all fragments. Duplicates are ignored. + task_type: Output task type. Only ``"document"`` is currently supported. + """ + path: str fragments_per_partition: int = 32 fields: list[str] | None = None diff --git a/nemo_curator/utils/lance.py b/nemo_curator/utils/lance.py index 0657341d93..3cfa248fe7 100644 --- a/nemo_curator/utils/lance.py +++ b/nemo_curator/utils/lance.py @@ -12,5 +12,21 @@ # See the License for the specific language governing permissions and # limitations under the License. +import pyarrow as pa + LANCE_ROWADDR_COLUMN = "__lance_rowaddr" LANCE_FRAGID_COLUMN = "__lance_fragid" + + +def lance_fragment_ids_from_row_addresses(rowaddr_column: pa.ChunkedArray) -> pa.Array: + rowaddrs = rowaddr_column.combine_chunks().cast(pa.uint64()) + return pa.array([int(value) >> 32 for value in rowaddrs.to_pylist()], type=pa.uint64()) + + +def add_lance_metadata_columns(table: pa.Table) -> pa.Table: + if "_rowaddr" not in table.column_names: + msg = "Lance scanner did not return _rowaddr; include_lance_metadata requires row addresses" + raise ValueError(msg) + + table = table.rename_columns([LANCE_ROWADDR_COLUMN if name == "_rowaddr" else name for name in table.column_names]) + return table.append_column(LANCE_FRAGID_COLUMN, lance_fragment_ids_from_row_addresses(table[LANCE_ROWADDR_COLUMN])) diff --git a/tests/stages/text/io/reader/test_lance.py b/tests/stages/text/io/reader/test_lance.py index a0d76f060d..54407b79b4 100644 --- a/tests/stages/text/io/reader/test_lance.py +++ b/tests/stages/text/io/reader/test_lance.py @@ -49,7 +49,9 @@ def _write_lance_dataset(path: Path) -> None: ] ), ) - lance.write_dataset(table, str(path), mode="create", max_rows_per_file=2, max_rows_per_group=2, data_storage_version="2.2") + lance.write_dataset( + table, str(path), mode="create", max_rows_per_file=2, max_rows_per_group=2, data_storage_version="2.2" + ) def test_lance_reader_partitions_filters_blobs_and_metadata(tmp_path: Path): @@ -83,21 +85,40 @@ def test_lance_reader_partitions_filters_blobs_and_metadata(tmp_path: Path): assert seen_fragments == {0, 1} +def test_lance_reader_validates_requested_fragments(tmp_path: Path): + dataset_path = tmp_path / "docs.lance" + _write_lance_dataset(dataset_path) + + tasks = LancePartitioningStage(path=str(dataset_path), fragments_per_partition=1, fragment_ids=[1, 0, 1]).process( + EmptyTask + ) + assert [task.data for task in tasks] == [[0], [1]] + + with pytest.raises(ValueError, match="requested fragment ids"): + LancePartitioningStage(path=str(dataset_path), fragment_ids=[999]).process(EmptyTask) + + def test_lance_reader_columns_empty_filters_and_fields_override(tmp_path: Path): dataset_path = tmp_path / "docs.lance" _write_lance_dataset(dataset_path) task = LancePartitioningStage(path=str(dataset_path)).process(EmptyTask)[0] - batch = LanceReaderStage(path=str(dataset_path), read_kwargs={"columns": ["url"]}, include_lance_metadata=False).process(task) + batch = LanceReaderStage( + path=str(dataset_path), read_kwargs={"columns": ["url"]}, include_lance_metadata=False + ).process(task) assert batch.to_pyarrow().column_names == ["url"] - empty_batch = LanceReaderStage(path=str(dataset_path), read_kwargs={"filter": "snapshot_id = 'missing'"}).process(task) + empty_batch = LanceReaderStage(path=str(dataset_path), read_kwargs={"filter": "snapshot_id = 'missing'"}).process( + task + ) empty_table = empty_batch.to_pyarrow() assert empty_table.num_rows == 0 assert LANCE_ROWADDR_COLUMN in empty_table.column_names assert LANCE_FRAGID_COLUMN in empty_table.column_names - _, reader_stage = LanceReader(path="example.lance", fields=["a", "b"], read_kwargs={"columns": ["ignored"]}).decompose() + _, reader_stage = LanceReader( + path="example.lance", fields=["a", "b"], read_kwargs={"columns": ["ignored"]} + ).decompose() assert reader_stage.fields == ["a", "b"] assert reader_stage.include_lance_metadata is True @@ -113,3 +134,21 @@ def test_lance_reader_uses_partition_version(tmp_path: Path): batch = LanceReaderStage(path=str(dataset_path), fields=["text"], include_lance_metadata=False).process(task) assert batch.to_pyarrow()["text"].to_pylist() == ["old"] + + +def test_lance_reader_rejects_conflicting_version(tmp_path: Path): + import lance + + dataset_path = tmp_path / "docs.lance" + lance.write_dataset(pa.table({"text": ["old"]}), str(dataset_path), mode="create", max_rows_per_file=1) + task = LancePartitioningStage(path=str(dataset_path)).process(EmptyTask)[0] + lance.write_dataset(pa.table({"text": ["new"]}), str(dataset_path), mode="overwrite", max_rows_per_file=1) + latest_version = lance.dataset(str(dataset_path)).version + + with pytest.raises(ValueError, match="version mismatch"): + LanceReaderStage( + path=str(dataset_path), + fields=["text"], + read_kwargs={"version": latest_version}, + include_lance_metadata=False, + ).process(task)