Skip to content

Commit cb513a1

Browse files
committed
Document Lance public APIs
1 parent 82564f7 commit cb513a1

3 files changed

Lines changed: 11 additions & 24 deletions

File tree

nemo_curator/stages/text/io/lance_commit.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ def commit_lance_checkpoint(
3131
storage_options: dict[str, Any] | None = None,
3232
checkpoint_storage_options: dict[str, Any] | None = None,
3333
) -> int:
34+
"""Commit records written by ``LanceWriter`` and return the Lance version."""
3435
import lance
3536
from lance.schema import json_to_schema
3637
from lance_ray import LanceFragmentCommitter
@@ -70,6 +71,7 @@ def commit_lance_annotation_checkpoint(
7071
storage_options: dict[str, Any] | None = None,
7172
checkpoint_storage_options: dict[str, Any] | None = None,
7273
) -> int:
74+
"""Commit records written by ``LanceAnnotationWriter`` and return the Lance version."""
7375
import lance
7476

7577
records, committed_version = read_lance_checkpoint(
@@ -103,9 +105,5 @@ def commit_lance_annotation_checkpoint(
103105
read_version=read_version,
104106
storage_options=storage_options,
105107
).version
106-
write_lance_checkpoint_marker(
107-
commit_path,
108-
version,
109-
checkpoint_storage_options,
110-
)
108+
write_lance_checkpoint_marker(commit_path, version, checkpoint_storage_options)
111109
return version

nemo_curator/stages/text/io/reader/lance.py

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -30,17 +30,8 @@
3030

3131

3232
def _read_dataset_kwargs(read_kwargs: dict[str, Any], version: int | None = None) -> dict[str, Any]:
33-
return {
34-
**dict(read_kwargs.get("dataset_options") or {}),
35-
**{
36-
key: value
37-
for key, value in (
38-
("storage_options", read_kwargs.get("storage_options")),
39-
("version", read_kwargs.get("version", version)),
40-
)
41-
if value is not None
42-
},
43-
}
33+
options = {"storage_options": read_kwargs.get("storage_options"), "version": read_kwargs.get("version", version)}
34+
return {**dict(read_kwargs.get("dataset_options") or {}), **{k: v for k, v in options.items() if v is not None}}
4435

4536

4637
def _scanner_kwargs(read_kwargs: dict[str, Any], fields: list[str] | None) -> dict[str, Any]:
@@ -207,6 +198,7 @@ def process(self, task: LanceReadTask) -> DocumentBatch | None:
207198

208199
@dataclass
209200
class LanceReader(CompositeStage[EmptyTask, DocumentBatch]):
201+
"""Read a Lance dataset into Curator ``DocumentBatch`` objects by fragment."""
210202
path: str
211203
fragments_per_partition: int = 32
212204
fields: list[str] | None = None

nemo_curator/stages/text/io/writer/lance.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ def _schema_for_table(schema: pa.Schema, table: pa.Table) -> pa.Schema:
5656

5757
@dataclass
5858
class LanceWriter(ProcessingStage[DocumentBatch, FileGroupTask]):
59+
"""Write ``DocumentBatch`` tables to Lance fragments and checkpoint the commit."""
5960
path: str
6061
commit_path: str
6162
schema: pa.Schema | None = None
@@ -125,6 +126,7 @@ def process(self, task: DocumentBatch) -> FileGroupTask:
125126

126127
@dataclass
127128
class LanceAnnotationWriter(ProcessingStage[DocumentBatch, FileGroupTask]):
129+
"""Update existing Lance rows using metadata columns emitted by ``LanceReader``."""
128130
path: str
129131
commit_path: str
130132
schema: pa.Schema
@@ -151,6 +153,7 @@ def outputs(self) -> tuple[list[str], list[str]]:
151153
return ["data"], []
152154

153155
def prepare(self) -> int:
156+
"""Create or validate annotation columns and pin the Lance version for the run."""
154157
import lance
155158

156159
dataset = lance.dataset(self.path, storage_options=self.storage_options)
@@ -208,14 +211,8 @@ def process(self, task: DocumentBatch) -> FileGroupTask:
208211
msg = f"Lance annotation update table is missing required columns: {missing}"
209212
raise ValueError(msg)
210213
version = self._update_version()
211-
dataset = lance.dataset(
212-
self.path,
213-
**{
214-
key: value
215-
for key, value in {"storage_options": self.storage_options, "version": version}.items()
216-
if value is not None
217-
},
218-
)
214+
options = {"storage_options": self.storage_options, "version": version}
215+
dataset = lance.dataset(self.path, **{k: v for k, v in options.items() if v is not None})
219216

220217
record_paths = []
221218
fragment_ids = sorted(int(value) for value in pc.unique(table[LANCE_FRAGID_COLUMN].combine_chunks()).to_pylist())

0 commit comments

Comments
 (0)