Skip to content
Closed
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 109 additions & 0 deletions nemo_curator/stages/text/io/lance_commit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import base64
import pickle
from typing import Any

from nemo_curator.stages.text.io.lance_utils import (
read_lance_checkpoint,
write_lance_checkpoint_marker,
)


def commit_lance_checkpoint(
    path: str,
    commit_path: str,
    *,
    storage_options: dict[str, Any] | None = None,
    checkpoint_storage_options: dict[str, Any] | None = None,
) -> int:
    """Commit the fragments recorded by ``LanceWriter`` and return the Lance version.

    The function reads the ``lance_write`` checkpoint records stored under
    ``commit_path``. If the checkpoint already reports a committed version,
    that version is returned and nothing is committed again.

    Args:
        path: Lance dataset path that every checkpoint record must target.
        commit_path: Location of the checkpoint (records plus commit marker).
        storage_options: Options forwarded to Lance when committing to and
            opening the dataset at ``path``.
        checkpoint_storage_options: Options forwarded to the checkpoint
            read/write helpers for ``commit_path``.

    Returns:
        The dataset version read back after the commit, or the previously
        recorded version when the checkpoint was already committed.

    Raises:
        ValueError: If the records reference a dataset other than ``path`` or
            do not all share a single write mode.

    Note:
        Fragments are restored with ``pickle.loads``, so the checkpoint
        location must only contain data written by trusted writers.
        The commit marker is written after the dataset commit; if writing the
        marker fails, a later call will not see a committed version.
    """
    import lance
    from lance.schema import json_to_schema
    from lance_ray import LanceFragmentCommitter

    records, committed_version = read_lance_checkpoint(commit_path, "lance_write", checkpoint_storage_options)
    if committed_version is not None:
        # Already committed on an earlier run: report the recorded version.
        return committed_version

    seen_paths = set()
    for entry in records:
        seen_paths.add(entry["dataset_path"])
    if seen_paths != {path}:
        msg = f"Checkpoint records are for {sorted(seen_paths)}, not {path}"
        raise ValueError(msg)

    seen_modes = set()
    for entry in records:
        seen_modes.add(entry["mode"])
    if len(seen_modes) != 1:
        msg = f"Expected one write mode; got {sorted(seen_modes)}"
        raise ValueError(msg)
    (only_mode,) = seen_modes
    mode = str(only_mode)

    # Each record carries a base64-encoded pickled fragment and a JSON schema.
    decoded = []
    for entry in records:
        fragment = pickle.loads(base64.b64decode(entry["fragment"]))  # noqa: S301
        decoded.append((fragment, json_to_schema(entry["schema"])))
    # The committer is configured with the schema of the first record.
    first_schema = decoded[0][1]

    committer = LanceFragmentCommitter(path, schema=first_schema, mode=mode, storage_options=storage_options)
    if mode == "append":
        committer.on_write_start(first_schema)

    # The committer receives pickled (fragment, schema) pairs as a single batch.
    batch = [
        (pickle.dumps(fragment), pickle.dumps(fragment_schema))
        for fragment, fragment_schema in decoded
    ]
    committer.on_write_complete([batch])

    dataset = lance.dataset(path, storage_options=storage_options)
    version = dataset.version
    write_lance_checkpoint_marker(commit_path, version, checkpoint_storage_options)
    return version


def commit_lance_annotation_checkpoint(
    path: str,
    commit_path: str,
    *,
    storage_options: dict[str, Any] | None = None,
    checkpoint_storage_options: dict[str, Any] | None = None,
) -> int:
    """Commit the updates recorded by ``LanceAnnotationWriter`` and return the Lance version.

    The function reads the ``lance_annotation_update`` checkpoint records
    stored under ``commit_path``. If the checkpoint already reports a
    committed version, that version is returned and nothing is committed.

    Args:
        path: Lance dataset path that every checkpoint record must target.
        commit_path: Location of the checkpoint (records plus commit marker).
        storage_options: Options forwarded to ``lance.LanceDataset.commit``.
        checkpoint_storage_options: Options forwarded to the checkpoint
            read/write helpers for ``commit_path``.

    Returns:
        The version of the dataset returned by the commit, or the previously
        recorded version when the checkpoint was already committed.

    Raises:
        ValueError: If the records reference a dataset other than ``path``,
            were produced against more than one dataset version, or contain
            more than one record for the same fragment id.

    Note:
        Updated fragments are restored with ``pickle.loads``, so the
        checkpoint location must only contain data written by trusted writers.
    """
    import lance

    records, committed_version = read_lance_checkpoint(
        commit_path, "lance_annotation_update", checkpoint_storage_options
    )
    if committed_version is not None:
        # Already committed on an earlier run: report the recorded version.
        return committed_version

    seen_paths = {entry["dataset_path"] for entry in records}
    if seen_paths != {path}:
        msg = f"Checkpoint records are for {sorted(seen_paths)}, not {path}"
        raise ValueError(msg)

    seen_versions = {int(entry["dataset_version"]) for entry in records}
    if len(seen_versions) != 1:
        msg = f"Expected one dataset version; got {sorted(seen_versions)}"
        raise ValueError(msg)
    (read_version,) = seen_versions

    # Index by fragment id; a collision means two records touched one fragment.
    per_fragment = {}
    for entry in records:
        per_fragment[int(entry["fragment_id"])] = entry
    if len(per_fragment) != len(records):
        msg = "Ensure each Lance fragment is updated by at most one writer task."
        raise ValueError(msg)

    updated_fragments = []
    for entry in per_fragment.values():
        updated_fragments.append(pickle.loads(base64.b64decode(entry["updated_fragment"])))  # noqa: S301

    # Union of the modified fields across all records, in sorted order.
    touched_fields = set()
    for entry in per_fragment.values():
        touched_fields.update(entry["fields_modified"])
    fields_modified = sorted(touched_fields)

    operation = lance.LanceOperation.Update(updated_fragments=updated_fragments, fields_modified=fields_modified)
    committed_dataset = lance.LanceDataset.commit(
        path,
        operation,
        read_version=read_version,
        storage_options=storage_options,
    )
    version = committed_dataset.version
    write_lance_checkpoint_marker(commit_path, version, checkpoint_storage_options)
    return version
89 changes: 89 additions & 0 deletions nemo_curator/stages/text/io/lance_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import posixpath
from typing import Any

from fsspec.core import url_to_fs

from nemo_curator.utils.hash_utils import get_deterministic_hash

LANCE_ROWADDR_COLUMN = "__lance_rowaddr"
LANCE_FRAGID_COLUMN = "__lance_fragid"
_COMMITTED_MARKER = "_COMMITTED"
_RECORDS_DIR = "records"


def lance_checkpoint_record_id(kind: str, *parts: object) -> str:
    """Build a deterministic checkpoint record id of the form ``<kind>-<hash>``.

    Args:
        kind: Record kind; used as the id prefix and, when no usable parts are
            given, as the only hash input.
        *parts: Values identifying the record. ``None`` and the empty string
            are skipped; every other part is converted with ``str``.

    Returns:
        ``kind`` joined by ``-`` to the deterministic hash of the kept parts.
    """
    # Filter with identity/equality instead of set membership: set membership
    # hashes each part and raises TypeError for unhashable values such as
    # lists or dicts, which the ``object`` annotation otherwise allows.
    values = [str(part) for part in parts if part is not None and part != ""]
    return f"{kind}-{get_deterministic_hash(values or [kind])}"


def _checkpoint_fs_path(commit_path: str, storage_options: dict[str, Any] | None = None) -> tuple[object, str]:
    """Resolve ``commit_path`` with ``url_to_fs`` and return its result.

    ``storage_options`` is expanded into keyword arguments for ``url_to_fs``;
    ``None`` (or an empty mapping) means no extra keyword arguments.
    """
    fs_kwargs = storage_options if storage_options else {}
    return url_to_fs(commit_path, **fs_kwargs)


def _checkpoint_path(fs_path: str, *parts: str) -> str:
return posixpath.join(fs_path.rstrip("/"), *parts)


def write_lance_checkpoint_record(
    commit_path: str,
    record: dict[str, Any],
    record_id: str,
    storage_options: dict[str, Any] | None = None,
) -> str:
    """Write one checkpoint record as JSON under the checkpoint's records directory.

    Args:
        commit_path: Location of the checkpoint.
        record: JSON-serializable payload; written with sorted keys and a
            trailing newline.
        record_id: Name of the record; the file is ``<record_id>.json``.
        storage_options: Options used to resolve the filesystem for
            ``commit_path``.

    Returns:
        The record path with the filesystem protocol restored.

    Raises:
        TypeError: If ``record`` is not JSON-serializable (raised by
            ``json.dumps`` before any file is opened).
    """
    # Serialize before touching the filesystem: if the record cannot be
    # encoded, no empty or truncated ``*.json`` file is left behind for
    # ``read_lance_checkpoint`` to fail on later.
    payload = json.dumps(record, sort_keys=True) + "\n"

    fs, fs_path = _checkpoint_fs_path(commit_path, storage_options)
    records_dir = _checkpoint_path(fs_path, _RECORDS_DIR)
    fs.makedirs(records_dir, exist_ok=True)
    record_path = _checkpoint_path(records_dir, f"{record_id}.json")
    with fs.open(record_path, "w") as stream:
        stream.write(payload)
    return fs.unstrip_protocol(record_path)


def read_lance_checkpoint(
    commit_path: str,
    kind: str,
    storage_options: dict[str, Any] | None = None,
) -> tuple[list[dict[str, Any]], int | None]:
    """Load the state of a checkpoint.

    Args:
        commit_path: Location of the checkpoint.
        kind: Only records whose ``"kind"`` field equals this value are kept.
        storage_options: Options used to resolve the filesystem for
            ``commit_path``.

    Returns:
        ``([], version)`` when the commit marker exists, where ``version`` is
        the integer stored in the marker; otherwise ``(records, None)`` with
        the matching records in sorted file-path order.

    Raises:
        ValueError: If no commit marker exists and no record of ``kind`` is
            found.
    """
    fs, root = _checkpoint_fs_path(commit_path, storage_options)

    # A commit marker short-circuits everything: records are not read at all.
    marker_path = _checkpoint_path(root, _COMMITTED_MARKER)
    if fs.exists(marker_path):
        with fs.open(marker_path) as stream:
            marker = json.loads(stream.read())
            return [], int(marker["version"])

    pattern = _checkpoint_path(root, _RECORDS_DIR, "*.json")
    matching = []
    for record_path in sorted(fs.glob(pattern)):
        with fs.open(record_path) as stream:
            candidate = json.loads(stream.read())
        if candidate.get("kind") != kind:
            # Records of other kinds may share the directory; skip them.
            continue
        matching.append(candidate)

    if matching:
        return matching, None
    msg = f"No {kind} checkpoint records found under {commit_path}"
    raise ValueError(msg)


def write_lance_checkpoint_marker(
    commit_path: str,
    version: int,
    storage_options: dict[str, Any] | None = None,
) -> None:
    """Write the commit marker recording ``version`` for the checkpoint.

    The marker is a JSON object ``{"version": <version>}`` (indented, with a
    trailing newline) stored directly under ``commit_path``. Its parent
    directory is created first if needed.

    Args:
        commit_path: Location of the checkpoint.
        version: Dataset version to record in the marker.
        storage_options: Options used to resolve the filesystem for
            ``commit_path``.
    """
    fs, root = _checkpoint_fs_path(commit_path, storage_options)
    marker_path = _checkpoint_path(root, _COMMITTED_MARKER)
    marker_dir = posixpath.dirname(marker_path)
    fs.makedirs(marker_dir, exist_ok=True)
    with fs.open(marker_path, "w") as stream:
        body = json.dumps({"version": version}, sort_keys=True, indent=2)
        stream.write(body + "\n")
3 changes: 2 additions & 1 deletion nemo_curator/stages/text/io/reader/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

from nemo_curator.stages.text.io.reader.jsonl import JsonlReader
from nemo_curator.stages.text.io.reader.lance import LanceReader
from nemo_curator.stages.text.io.reader.parquet import ParquetReader

__all__ = ["JsonlReader", "ParquetReader"]
__all__ = ["JsonlReader", "LanceReader", "ParquetReader"]
Loading
Loading