diff --git a/plugins/huggingface/README.md b/plugins/huggingface/README.md
new file mode 100644
index 000000000..941e55cc4
--- /dev/null
+++ b/plugins/huggingface/README.md
@@ -0,0 +1,68 @@
+# Hugging Face Datasets Plugin
+
+Native support for Hugging Face Datasets in Flyte: prefetch datasets from the Hub into remote storage and pass `datasets.Dataset` objects between tasks with automatic Parquet serialization.
+
+## Installation
+
+```bash
+pip install flyteplugins-huggingface
+```
+
+## Prefetch from the Hugging Face Hub
+
+Stream a dataset from the Hub directly to Flyte's remote storage:
+
+```python
+import flyte
+from flyteplugins.huggingface import hf_dataset
+
+flyte.init(endpoint="my-flyte-endpoint")
+
+run = hf_dataset(repo="stanfordnlp/imdb", split="train")
+run.wait()
+data_dir = run.outputs()[0]  # flyte.io.Dir of Parquet files
+```
+
+## Type transformer
+
+Pass `datasets.Dataset` between tasks with automatic serialization:
+
+```python
+import datasets
+import flyte
+
+env = flyte.TaskEnvironment(
+    name="hf-example",
+    image=flyte.Image.from_debian_base().with_pip_packages(
+        "flyteplugins-huggingface",
+    ),
+)
+
+
+@env.task
+async def create_dataset() -> datasets.Dataset:
+    return datasets.Dataset.from_dict({
+        "text": ["hello", "world", "foo"],
+        "label": [0, 1, 0],
+    })
+
+
+@env.task
+async def filter_positive(ds: datasets.Dataset) -> datasets.Dataset:
+    return ds.filter(lambda x: x["label"] == 1)
+```
+
+## Column filtering
+
+Use type annotations to load only specific columns:
+
+```python
+from typing import Annotated
+from collections import OrderedDict
+
+@env.task
+async def load_text_only(
+    ds: Annotated[datasets.Dataset, OrderedDict(text=str)],
+) -> list[str]:
+    return ds["text"]
+```
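+
+## Consuming a prefetched `Dir`
+
+A minimal sketch of loading the prefetched Parquet files inside a task,
+reusing `env` from above. This mirrors the `Dir.walk()` / `File.download()`
+pattern in `examples/hf_dataset_workflow.py`:
+
+```python
+import pyarrow as pa
+import pyarrow.parquet as pq
+
+@env.task
+async def load_from_dir(data_dir: flyte.io.Dir) -> datasets.Dataset:
+    tables = []
+    async for file in data_dir.walk():
+        if file.path.endswith(".parquet"):
+            tables.append(pq.read_table(await file.download()))
+    return datasets.Dataset(pa.concat_tables(tables))
+```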
diff --git a/plugins/huggingface/examples/hf_dataset_workflow.py b/plugins/huggingface/examples/hf_dataset_workflow.py
new file mode 100644
index 000000000..7fb30e9b6
--- /dev/null
+++ b/plugins/huggingface/examples/hf_dataset_workflow.py
@@ -0,0 +1,76 @@
+"""
+Example: Hugging Face Datasets with Flyte.
+
+This example demonstrates:
+- Prefetching a dataset from the Hugging Face Hub to remote storage
+- Loading a prefetched Dir into a datasets.Dataset inside a task
+- Passing datasets.Dataset between tasks via the type transformer
+- Creating and returning new datasets from tasks
+"""
+
+import datasets
+import flyte
+import pyarrow as pa
+import pyarrow.parquet as pq
+
+from flyteplugins.huggingface import hf_dataset
+
+env = flyte.TaskEnvironment(
+    name="hf-dataset-example",
+    image=flyte.Image.from_debian_base(name="hf-dataset-example").with_pip_packages(
+        "flyteplugins-huggingface",
+    ),
+)
+
+
+@env.task
+async def load_from_dir(data_dir: flyte.io.Dir) -> datasets.Dataset:
+    """Load parquet files from a prefetched Dir into a datasets.Dataset."""
+    tables = []
+    async for file in data_dir.walk():
+        if file.path.endswith(".parquet"):
+            local = await file.download()
+            tables.append(pq.read_table(local))
+    return datasets.Dataset(pa.concat_tables(tables))
+
+
+@env.task
+async def tokenize(ds: datasets.Dataset) -> datasets.Dataset:
+    """Tokenization stand-in: add a word-count column."""
+    word_counts = [len(text.split()) for text in ds["text"]]
+    return ds.add_column("word_count", word_counts)
+
+
+@env.task
+async def filter_long(ds: datasets.Dataset) -> datasets.Dataset:
+    """Keep only rows with more than 100 words."""
+    return ds.filter(lambda row: row["word_count"] > 100)
+
+
+@env.task
+async def summary(ds: datasets.Dataset) -> str:
+    return f"{len(ds)} rows, columns: {ds.column_names}"
+
+
+if __name__ == "__main__":
+    flyte.init()
+
+    # 1. Prefetch the dataset from the Hugging Face Hub to remote storage
+    run = hf_dataset(repo="stanfordnlp/imdb", split="train")
+    run.wait()
+    data_dir = run.outputs()[0]
+
+    # 2. Load into a datasets.Dataset inside a task
+    run = flyte.with_runcontext("local").run(load_from_dir, data_dir)
+    ds = run.outputs()[0]
+
+    # 3. Pass datasets.Dataset between tasks via the type transformer
+    run = flyte.with_runcontext("local").run(tokenize, ds)
+    tokenized = run.outputs()[0]
+
+    run = flyte.with_runcontext("local").run(filter_long, tokenized)
+    filtered = run.outputs()[0]
+
+    run = flyte.with_runcontext("local").run(summary, filtered)
+    print(run.outputs()[0])
diff --git a/plugins/huggingface/pyproject.toml b/plugins/huggingface/pyproject.toml
new file mode 100644
index 000000000..21c1db06b
--- /dev/null
+++ b/plugins/huggingface/pyproject.toml
@@ -0,0 +1,85 @@
+[project]
+name = "flyteplugins-huggingface"
+dynamic = ["version"]
+description = "Hugging Face Datasets plugin for Flyte"
+readme = "README.md"
+authors = [{ name = "Flyte Contributors", email = "admin@flyte.org" }]
+requires-python = ">=3.10"
+dependencies = [
+    "datasets>=2.14.5",
+    "huggingface-hub>=0.27.0",
+    "pyarrow",
+    "flyte",
+]
+
+[project.entry-points."flyte.plugins.types"]
+huggingface = "flyteplugins.huggingface.df_transformer:register_huggingface_df_transformers"
+
+[build-system]
+requires = ["setuptools", "setuptools_scm"]
+build-backend = "setuptools.build_meta"
+
+[dependency-groups]
+dev = [
+    "pytest>=8.3.5",
+    "pytest-asyncio>=0.26.0",
+    "pandas",
+]
+
+[tool.setuptools]
+include-package-data = true
+license-files = ["licenses/*.txt", "LICENSE"]
+
+[tool.setuptools.packages.find]
+where = ["src"]
+include = ["flyteplugins*"]
+
+[tool.setuptools_scm]
+root = "../../"
+
+[tool.pytest.ini_options]
+norecursedirs = []
+log_cli = true
+log_cli_level = 20
+markers = []
+asyncio_mode = "auto"
+asyncio_default_fixture_loop_scope = "function"
+
+[tool.coverage.run]
+branch = true
+
+[tool.ruff]
+line-length = 120
+
+[tool.ruff.lint]
+select = [
+    "E",
+    "W",
+    "F",
+    "I",
+    "PLW",
+    "YTT",
+    "ASYNC",
+    "C4",
+    "T10",
+    "EXE",
+    "ISC",
+    "LOG",
+    "PIE",
+    "Q",
+    "RSE",
+    "FLY",
+    "PGH",
+    "PLC",
+    "PLE",
+    "FURB",
+    "RUF",
+]
+ignore = ["PGH003", "PLC0415", "ASYNC240"]
+
+[tool.ruff.lint.per-file-ignores]
+"examples/*" = ["E402"]
+
+[tool.uv.sources]
+flyte = { path = "../../", editable = true }
diff --git a/plugins/huggingface/src/flyteplugins/huggingface/__init__.py b/plugins/huggingface/src/flyteplugins/huggingface/__init__.py
new file mode 100644
index 000000000..e007ca339
--- /dev/null
+++ b/plugins/huggingface/src/flyteplugins/huggingface/__init__.py
@@ -0,0 +1,6 @@
+from ._prefetch import HuggingFaceDatasetInfo, hf_dataset
+
+__all__ = [
+    "HuggingFaceDatasetInfo",
+    "hf_dataset",
+]
diff --git a/plugins/huggingface/src/flyteplugins/huggingface/_prefetch.py b/plugins/huggingface/src/flyteplugins/huggingface/_prefetch.py
new file mode 100644
index 000000000..2dbc60752
--- /dev/null
+++ b/plugins/huggingface/src/flyteplugins/huggingface/_prefetch.py
@@ -0,0 +1,244 @@
+from __future__ import annotations
+
+import os
+import re
+import tempfile
+import typing
+from typing import TYPE_CHECKING
+
+from flyte._logging import logger
+from flyte._resources import Resources
+from flyte._task_environment import TaskEnvironment
+from flyte.io import Dir
+from pydantic import BaseModel
+
+if TYPE_CHECKING:
+    from flyte.remote import Run
+
+
+HF_IMAGE_PACKAGES = [
+    "huggingface-hub>=0.27.0",
+    "hf-transfer>=0.1.8",
+]
+
+
+class HuggingFaceDatasetInfo(BaseModel):
+    repo: str
+    name: str | None = None
+    split: str | None = None
+
+
+def _validate_input_name(value: str | None) -> None:
+    if value is not None and not re.match(r"^[a-zA-Z0-9_.-]+$", value):
+        raise ValueError(f"'{value}' must only contain alphanumeric characters, underscores, hyphens, and dots")
+
+
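+# Streaming helper: copies each auto-converted parquet file from the Hub to
+# remote storage in 64 MiB chunks, without materializing the dataset locally.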
+def _stream_dataset_to_remote(
+    repo_id: str,
+    config_name: str | None,
+    split: str | None,
+    token: str | None,
+    remote_dir_path: str,
+) -> str:
+    import flyte.storage as storage
+    import huggingface_hub
+
+    hfs = huggingface_hub.HfFileSystem(token=token)
+    fs = storage.get_underlying_filesystem(path=remote_dir_path)
+
+    # HF Hub auto-converts datasets to parquet under refs/convert/parquet.
+    # This ref is managed by HF and always contains the latest conversion.
+    # Structure: datasets/{repo}/{config}/{split}/0000.parquet
+    config = config_name or "default"
+    base_path = f"datasets/{repo_id}/{config}"
+
+    if split:
+        split_paths = [(split, f"{base_path}/{split}")]
+    else:
+        try:
+            entries = hfs.ls(base_path, revision="refs/convert/parquet", detail=True)
+            split_paths = [(e["name"].split("/")[-1], e["name"]) for e in entries if e["type"] == "directory"]
+        except FileNotFoundError:
+            split_paths = [("data", base_path)]
+
+    files_streamed = 0
+    chunk_size = 64 * 1024 * 1024
+
+    for split_name, search_path in split_paths:
+        try:
+            entries = hfs.ls(search_path, revision="refs/convert/parquet", detail=True)
+        except FileNotFoundError:
+            logger.warning(f"Path not found: {search_path}")
+            continue
+
+        parquet_files = [e for e in entries if e["type"] == "file" and e["name"].endswith(".parquet")]
+
+        for file_info in parquet_files:
+            file_name = file_info["name"].split("/")[-1]
+            if split:
+                remote_file_path = f"{remote_dir_path}/{file_name}"
+            else:
+                remote_file_path = f"{remote_dir_path}/{split_name}/{file_name}"
+                fs.mkdirs(f"{remote_dir_path}/{split_name}", exist_ok=True)
+
+            logger.info(f"Streaming {split_name}/{file_name}...")
+
+            with hfs.open(file_info["name"], "rb", revision="refs/convert/parquet") as src:
+                with fs.open(remote_file_path, "wb") as dst:
+                    while True:
+                        chunk = src.read(chunk_size)
+                        if not chunk:
+                            break
+                        dst.write(chunk)
+
+            files_streamed += 1
+
+    if files_streamed == 0:
+        raise FileNotFoundError(
+            f"No parquet files found for {repo_id} (config={config}, split={split}). "
+            f"The dataset may not have been auto-converted to parquet yet."
+        )
+
+    logger.info(f"Streamed {files_streamed} parquet files to {remote_dir_path}")
+    return remote_dir_path
+
+
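+# Fallback helper: hf_hub_download writes real files under local_dir; the
+# parquet files are then moved into output_dir, preserving relative layout.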
+ ) + + logger.info(f"Streamed {files_streamed} parquet files to {remote_dir_path}") + return remote_dir_path + + +def _download_dataset_to_local( + repo_id: str, + config_name: str | None, + split: str | None, + token: str | None, + local_dir: str, + output_dir: str, +) -> str: + import huggingface_hub + + config = config_name or "default" + base_pattern = f"{config}/" + if split: + base_pattern = f"{config}/{split}/" + + files = huggingface_hub.list_repo_files(repo_id, repo_type="dataset", revision="refs/convert/parquet") + parquet_files = [f for f in files if f.startswith(base_pattern) and f.endswith(".parquet")] + + if not parquet_files: + raise FileNotFoundError(f"No parquet files found for {repo_id} (config={config}, split={split}).") + + for pf in parquet_files: + huggingface_hub.hf_hub_download( + repo_id=repo_id, + filename=pf, + repo_type="dataset", + revision="refs/convert/parquet", + local_dir=local_dir, + token=token, + ) + + # Move parquet files to output_dir, preserving relative structure + for root, _dirs, filenames in os.walk(local_dir): + for fname in filenames: + if fname.endswith(".parquet"): + src = os.path.join(root, fname) + rel = os.path.relpath(src, local_dir) + dst = os.path.join(output_dir, rel) + os.makedirs(os.path.dirname(dst), exist_ok=True) + os.rename(src, dst) + + return output_dir + + +# NOTE: the info argument is a json string instead of a HuggingFaceDatasetInfo +# object because the type engine cannot handle nested pydantic or dataclass +# objects when run in interactive mode. +def store_hf_dataset_task(info: str, raw_data_path: str | None = None) -> Dir: + import flyte + + _info = HuggingFaceDatasetInfo.model_validate_json(info) + token = os.environ.get("HF_TOKEN") + + if token is None: + logger.warning("HF_TOKEN not set, using anonymous access. Private datasets will fail.") + + artifact_name = _info.repo.split("/")[-1].replace(".", "-") + if _info.split: + artifact_name = f"{artifact_name}-{_info.split}" + + try: + logger.info("Attempting direct streaming to remote storage...") + + if raw_data_path is not None: + remote_path = raw_data_path + else: + remote_path = flyte.ctx().raw_data_path.get_random_remote_path(artifact_name) + + _stream_dataset_to_remote(_info.repo, _info.name, _info.split, token, remote_path) + result_dir = Dir.from_existing_remote(remote_path) + logger.info(f"Streaming completed to {remote_path}") + + except (OSError, FileNotFoundError) as e: + logger.error(f"Direct streaming failed: {e}") + logger.info("Falling back to snapshot download...") + + with tempfile.TemporaryDirectory() as local_dir, tempfile.TemporaryDirectory() as output_dir: + _download_dataset_to_local(_info.repo, _info.name, _info.split, token, local_dir, output_dir) + result_dir = Dir.from_local_sync(output_dir, remote_destination=raw_data_path) + + logger.info(f"Dataset stored at {result_dir.path}") + return result_dir + + +def hf_dataset( + repo: str, + *, + name: str | None = None, + split: str | None = None, + raw_data_path: str | None = None, + hf_token_key: str = "HF_TOKEN", + resources: Resources = Resources(cpu="2", memory="8Gi", disk="50Gi"), + force: int = 0, +) -> Run: + """Prefetch a HuggingFace dataset to remote storage. + + Streams parquet files from HuggingFace Hub directly to Flyte's remote storage, + returning a Dir that downstream tasks can consume. Always uses the latest + auto-converted parquet from HuggingFace Hub (refs/convert/parquet). + + :param repo: HuggingFace dataset repo ID (e.g., 'stanfordnlp/imdb'). 
+    :param name: Dataset configuration name (default: 'default').
+    :param split: Dataset split (e.g., 'train', 'test'). None fetches all splits.
+    :param raw_data_path: Override the remote storage path.
+    :param hf_token_key: Secret key for the HF token. Default: 'HF_TOKEN'.
+    :param resources: Resources for the prefetch task.
+    :param force: Bump this value to bypass the run cache and force a re-prefetch.
+    :return: A Run object. Call .wait(), then .outputs() to get the Dir.
+    """
+    import flyte
+    from flyte import Secret
+    from flyte.remote import Run
+
+    _validate_input_name(name)
+    _validate_input_name(split)
+
+    info = HuggingFaceDatasetInfo(
+        repo=repo,
+        name=name,
+        split=split,
+    )
+
+    image = (
+        flyte.Image.from_debian_base(name="prefetch-hf-dataset-image")
+        .with_pip_packages(*HF_IMAGE_PACKAGES)
+        .with_env_vars({"HF_HUB_ENABLE_HF_TRANSFER": "1"})
+    )
+
+    env = TaskEnvironment(
+        name="prefetch-hf-dataset",
+        image=image,
+        resources=resources,
+        secrets=[Secret(key=hf_token_key, as_env_var="HF_TOKEN")],
+    )
+    task = env.task()(store_hf_dataset_task)
+    run = flyte.with_runcontext(interactive_mode=True, disable_run_cache=force > 0).run(
+        task, info.model_dump_json(), raw_data_path
+    )
+    return typing.cast(Run, run)
diff --git a/plugins/huggingface/src/flyteplugins/huggingface/df_transformer.py b/plugins/huggingface/src/flyteplugins/huggingface/df_transformer.py
new file mode 100644
index 000000000..75d806052
--- /dev/null
+++ b/plugins/huggingface/src/flyteplugins/huggingface/df_transformer.py
@@ -0,0 +1,92 @@
+import functools
+import os
+import typing
+from pathlib import Path
+
+import flyte.storage as storage
+import pyarrow.parquet as pq
+from flyte.extend import lazy_module
+from flyte.io import PARQUET, DataFrame
+from flyte.io.extend import (
+    DataFrameDecoder,
+    DataFrameEncoder,
+    DataFrameTransformerEngine,
+)
+from flyteidl2.core import literals_pb2, types_pb2
+
+if typing.TYPE_CHECKING:
+    import datasets
+else:
+    datasets = lazy_module("datasets")
+
+
+class HuggingFaceDatasetToParquetEncodingHandler(DataFrameEncoder):
+    def __init__(self):
+        super().__init__(datasets.Dataset, None, PARQUET)
+
+    async def encode(
+        self,
+        dataframe: DataFrame,
+        structured_dataset_type: types_pb2.StructuredDatasetType,
+    ) -> literals_pb2.StructuredDataset:
+        if not dataframe.uri:
+            from flyte._context import internal_ctx
+
+            ctx = internal_ctx()
+            uri = str(ctx.raw_data.get_random_remote_path())
+        else:
+            uri = typing.cast(str, dataframe.uri)
+
+        if not storage.is_remote(uri):
+            Path(uri).mkdir(parents=True, exist_ok=True)
+
+        path = os.path.join(uri, f"{0:05}.parquet")
+        df = typing.cast(datasets.Dataset, dataframe.val)
+
+        filesystem = storage.get_underlying_filesystem(path=path)
+        table = df.data.table
+        writer = pq.ParquetWriter(path, table.schema, filesystem=filesystem)
+        try:
+            for batch in table.to_batches(max_chunksize=10_000):
+                writer.write_batch(batch)
+        finally:
+            writer.close()
+
+        structured_dataset_type.format = PARQUET
+        return literals_pb2.StructuredDataset(
+            uri=uri, metadata=literals_pb2.StructuredDatasetMetadata(structured_dataset_type=structured_dataset_type)
+        )
+
+
+class ParquetToHuggingFaceDatasetDecodingHandler(DataFrameDecoder):
+    def __init__(self):
+        super().__init__(datasets.Dataset, None, PARQUET)
+
+    async def decode(
+        self,
+        flyte_value: literals_pb2.StructuredDataset,
+        current_task_metadata: literals_pb2.StructuredDatasetMetadata,
+    ) -> "datasets.Dataset":
+        uri = flyte_value.uri
+        columns = None
+        if current_task_metadata.structured_dataset_type and current_task_metadata.structured_dataset_type.columns:
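+            # Column annotations on the consuming task arrive here as
+            # structured_dataset_type.columns; forwarding them to
+            # pq.read_table() projects the read to just those columns.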
+            columns = [c.name for c in current_task_metadata.structured_dataset_type.columns]
+
+        parquet_path = os.path.join(uri, f"{0:05}.parquet")
+        filesystem = storage.get_underlying_filesystem(path=parquet_path)
+        table = pq.read_table(parquet_path, columns=columns, filesystem=filesystem)
+        return datasets.Dataset(table)
+
+
+@functools.lru_cache(maxsize=None)
+def register_huggingface_df_transformers():
+    """Register Hugging Face Dataset encoders and decoders with the DataFrameTransformerEngine.
+
+    This function is called automatically via the flyte.plugins.types entry point
+    when flyte.init() is called with load_plugin_type_transformers=True (the default).
+    """
+    DataFrameTransformerEngine.register(HuggingFaceDatasetToParquetEncodingHandler(), default_format_for_type=True)
+    DataFrameTransformerEngine.register(ParquetToHuggingFaceDatasetDecodingHandler(), default_format_for_type=True)
+
+
+register_huggingface_df_transformers()
diff --git a/plugins/huggingface/tests/__init__.py b/plugins/huggingface/tests/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/plugins/huggingface/tests/conftest.py b/plugins/huggingface/tests/conftest.py
new file mode 100644
index 000000000..9a7861f0b
--- /dev/null
+++ b/plugins/huggingface/tests/conftest.py
@@ -0,0 +1,53 @@
+import os
+from pathlib import Path
+from unittest.mock import patch
+
+import pytest
+import pytest_asyncio
+from flyte._cache.local_cache import LocalTaskCache
+from flyte._context import RawDataPath, internal_ctx
+from flyte._persistence._db import LocalDB
+from flyte.models import SerializationContext
+
+
+@pytest.fixture
+def ctx_with_test_raw_data_path():
+    """Pytest fixture to set a RawDataPath in the internal_ctx."""
+    raw_data_path = RawDataPath.from_local_folder()
+    ctx = internal_ctx()
+    new_context = ctx.new_raw_data_path(raw_data_path=raw_data_path)
+    with new_context as ctx:
+        yield ctx
+
+
+@pytest.fixture
+def dummy_serialization_context():
+    yield SerializationContext(
+        code_bundle=None,
+        version="abc123",
+        input_path="s3://bucket/test/run/inputs.pb",
+        output_path="s3://bucket/outputs/0/jfkljfa/0",
+        root_dir=Path.cwd(),
+    )
+
+
+@pytest_asyncio.fixture(autouse=True)
+async def isolate_local_cache(tmp_path):
+    """
+    Global fixture to isolate LocalTaskCache for each test.
+    Uses a temporary directory to avoid polluting the local development cache.
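+    Also resets LocalDB's cached connection state so each test re-initializes
+    the database at the patched temporary path.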
+ """ + with patch.object(LocalDB, "_get_db_path", staticmethod(lambda: str(tmp_path / "test_cache.db"))): + LocalDB._initialized = False + LocalDB._conn = None + LocalDB._conn_sync = None + yield + await LocalTaskCache.close() + + +@pytest.fixture(autouse=True) +def patch_os_exit(monkeypatch): + def mock_exit(code): + raise SystemExit(code) + + monkeypatch.setattr(os, "_exit", mock_exit) diff --git a/plugins/huggingface/tests/test_df_transformer.py b/plugins/huggingface/tests/test_df_transformer.py new file mode 100644 index 000000000..a82ebfc39 --- /dev/null +++ b/plugins/huggingface/tests/test_df_transformer.py @@ -0,0 +1,237 @@ +import typing +from collections import OrderedDict + +import flyte +import pytest +from flyte.io._dataframe import DataFrame +from flyte.io._dataframe.dataframe import PARQUET, DataFrameTransformerEngine +from flyte.types import TypeEngine + +# Import huggingface handlers to register them +import flyteplugins.huggingface.df_transformer # noqa: F401 +from flyteplugins.huggingface.df_transformer import ( + HuggingFaceDatasetToParquetEncodingHandler, + ParquetToHuggingFaceDatasetDecodingHandler, +) + +datasets = pytest.importorskip("datasets") +pd = pytest.importorskip("pandas") + +# Sample data for testing +TEST_DATA = {"name": ["Alice", "Bob", "Charlie"], "age": [25, 30, 35], "city": ["NYC", "SF", "LA"]} + + +@pytest.fixture +def sample_dataset(): + """Create a sample HuggingFace Dataset for testing.""" + return datasets.Dataset.from_pandas(pd.DataFrame(TEST_DATA)) + + +# ============================================================================ +# Type recognition tests +# ============================================================================ + + +def test_types_huggingface_dataset(): + """Test that HuggingFace Dataset type is recognized.""" + pt = datasets.Dataset + lt = TypeEngine.to_literal_type(pt) + assert lt.structured_dataset_type is not None + assert lt.structured_dataset_type.format == "" + assert lt.structured_dataset_type.columns == [] + + +def test_types_dataset_with_columns(): + """Test that HuggingFace Dataset with column annotations is recognized.""" + my_cols = OrderedDict(name=str, age=int, city=str) + pt = typing.Annotated[datasets.Dataset, my_cols] + lt = TypeEngine.to_literal_type(pt) + assert lt.structured_dataset_type is not None + assert len(lt.structured_dataset_type.columns) == 3 + assert lt.structured_dataset_type.columns[0].name == "name" + assert lt.structured_dataset_type.columns[1].name == "age" + assert lt.structured_dataset_type.columns[2].name == "city" + + +def test_types_dataset_with_format(): + """Test that HuggingFace Dataset with format annotation is recognized.""" + pt = typing.Annotated[datasets.Dataset, PARQUET] + lt = TypeEngine.to_literal_type(pt) + assert lt.structured_dataset_type is not None + assert lt.structured_dataset_type.format == PARQUET + + +def test_types_dataset_with_columns_and_format(): + """Test that HuggingFace Dataset with both columns and format is recognized.""" + my_cols = OrderedDict(name=str, age=int) + pt = typing.Annotated[datasets.Dataset, my_cols, PARQUET] + lt = TypeEngine.to_literal_type(pt) + assert lt.structured_dataset_type is not None + assert len(lt.structured_dataset_type.columns) == 2 + assert lt.structured_dataset_type.format == PARQUET + + +# ============================================================================ +# Handler registration tests +# ============================================================================ + + +def test_retrieving_encoder(): + """Test 
+def test_retrieving_encoder():
+    """Test that encoders can be retrieved for HuggingFace Dataset."""
+    assert DataFrameTransformerEngine.get_encoder(datasets.Dataset, "file", PARQUET) is not None
+    assert DataFrameTransformerEngine.get_encoder(
+        datasets.Dataset, "file", ""
+    ) is DataFrameTransformerEngine.get_encoder(datasets.Dataset, "file", PARQUET)
+
+
+def test_decoder_registered():
+    """Test that decoder can be retrieved for HuggingFace Dataset."""
+    assert DataFrameTransformerEngine.get_decoder(datasets.Dataset, "file", PARQUET) is not None
+
+
+def test_handler_properties():
+    """Test that handler properties are correctly set."""
+    encoder = HuggingFaceDatasetToParquetEncodingHandler()
+    assert encoder.python_type is datasets.Dataset
+    assert encoder.protocol is None
+    assert encoder.supported_format == PARQUET
+
+    decoder = ParquetToHuggingFaceDatasetDecodingHandler()
+    assert decoder.python_type is datasets.Dataset
+    assert decoder.protocol is None
+    assert decoder.supported_format == PARQUET
+
+
+# ============================================================================
+# Encode/decode roundtrip tests
+# ============================================================================
+
+
+@pytest.mark.asyncio
+async def test_to_literal_dataset(ctx_with_test_raw_data_path, sample_dataset):
+    """Test encoding HuggingFace Dataset to literal."""
+    lt = TypeEngine.to_literal_type(datasets.Dataset)
+    fdt = DataFrameTransformerEngine()
+
+    lit = await fdt.to_literal(sample_dataset, python_type=datasets.Dataset, expected=lt)
+    assert lit.scalar.structured_dataset.metadata.structured_dataset_type.format == PARQUET
+    assert lit.scalar.structured_dataset.uri is not None
+
+    restored = await fdt.to_python_value(lit, expected_python_type=datasets.Dataset)
+    assert isinstance(restored, datasets.Dataset)
+    assert len(restored) == len(sample_dataset)
+    assert set(restored.column_names) == set(sample_dataset.column_names)
+
+
+@pytest.mark.asyncio
+async def test_dataset_roundtrip(ctx_with_test_raw_data_path, sample_dataset):
+    """Test roundtrip encoding/decoding of HuggingFace Dataset."""
+    fdt = DataFrameTransformerEngine()
+    lt = TypeEngine.to_literal_type(datasets.Dataset)
+
+    lit = await fdt.to_literal(sample_dataset, python_type=datasets.Dataset, expected=lt)
+    restored = await fdt.to_python_value(lit, expected_python_type=datasets.Dataset)
+
+    assert len(restored) == len(sample_dataset)
+    assert restored.column_names == sample_dataset.column_names
+    for col in sample_dataset.column_names:
+        assert restored[col] == sample_dataset[col]
+
+
+@pytest.mark.asyncio
+async def test_dataset_through_flyte_dataframe(ctx_with_test_raw_data_path, sample_dataset):
+    """Test using HuggingFace Dataset through Flyte DataFrame wrapper."""
+    fdt = DataFrameTransformerEngine()
+    lt = TypeEngine.to_literal_type(datasets.Dataset)
+
+    fdf = DataFrame.from_df(val=sample_dataset)
+
+    lit = await fdt.to_literal(fdf, python_type=datasets.Dataset, expected=lt)
+    assert lit.scalar.structured_dataset.metadata.structured_dataset_type.format == PARQUET
+
+    restored = await fdt.to_python_value(lit, expected_python_type=datasets.Dataset)
+    assert isinstance(restored, datasets.Dataset)
+    assert len(restored) == len(sample_dataset)
+
+
+@pytest.mark.asyncio
+async def test_raw_dataset_io(ctx_with_test_raw_data_path, sample_dataset):
+    """Test using raw HuggingFace Dataset as task input/output."""
+    flyte.init()
+    env = flyte.TaskEnvironment(name="test-hf-dataset")
+
+    @env.task
+    async def process_dataset(ds: datasets.Dataset) -> datasets.Dataset:
+        return ds.select(range(2))
+
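+    # Local-mode execution still round-trips inputs and outputs through the
+    # type engine, so this exercises the Parquet encode/decode path end to end.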
+    run = flyte.with_runcontext("local").run(process_dataset, sample_dataset)
+    result = run.outputs()[0]
+    assert isinstance(result, datasets.Dataset)
+    assert len(result) == 2
+
+
+# ============================================================================
+# Column subsetting tests
+# ============================================================================
+
+
+@pytest.mark.asyncio
+async def test_dataset_column_subsetting(ctx_with_test_raw_data_path, sample_dataset):
+    """Test that decoding with column annotations subsets the data."""
+    fdt = DataFrameTransformerEngine()
+    lt = TypeEngine.to_literal_type(datasets.Dataset)
+
+    lit = await fdt.to_literal(sample_dataset, python_type=datasets.Dataset, expected=lt)
+
+    my_cols = OrderedDict(name=str, age=int)
+    annotated_type = typing.Annotated[datasets.Dataset, my_cols]
+    restored = await fdt.to_python_value(lit, expected_python_type=annotated_type)
+
+    assert isinstance(restored, datasets.Dataset)
+    assert set(restored.column_names) == {"name", "age"}
+
+
+# ============================================================================
+# Data type tests
+# ============================================================================
+
+
+@pytest.mark.asyncio
+async def test_dataset_with_various_types(ctx_with_test_raw_data_path):
+    """Test roundtrip with various data types."""
+    df = pd.DataFrame(
+        {
+            "int_col": [1, 2, 3],
+            "float_col": [1.1, 2.2, 3.3],
+            "str_col": ["a", "b", "c"],
+            "bool_col": [True, False, True],
+        }
+    )
+    ds = datasets.Dataset.from_pandas(df)
+
+    fdt = DataFrameTransformerEngine()
+    lt = TypeEngine.to_literal_type(datasets.Dataset)
+
+    lit = await fdt.to_literal(ds, python_type=datasets.Dataset, expected=lt)
+    restored = await fdt.to_python_value(lit, expected_python_type=datasets.Dataset)
+
+    assert len(restored) == len(ds)
+    assert restored["int_col"] == ds["int_col"]
+    assert restored["str_col"] == ds["str_col"]
+    assert restored["bool_col"] == ds["bool_col"]
+
+
+@pytest.mark.asyncio
+async def test_empty_dataset(ctx_with_test_raw_data_path):
+    """Test roundtrip with empty Dataset."""
+    empty_ds = datasets.Dataset.from_pandas(pd.DataFrame({"name": [], "age": []}))
+
+    fdt = DataFrameTransformerEngine()
+    lt = TypeEngine.to_literal_type(datasets.Dataset)
+
+    lit = await fdt.to_literal(empty_ds, python_type=datasets.Dataset, expected=lt)
+    restored = await fdt.to_python_value(lit, expected_python_type=datasets.Dataset)
+
+    assert isinstance(restored, datasets.Dataset)
+    assert len(restored) == 0
+    assert set(restored.column_names) == {"name", "age"}
diff --git a/plugins/huggingface/tests/test_prefetch.py b/plugins/huggingface/tests/test_prefetch.py
new file mode 100644
index 000000000..2f88f343e
--- /dev/null
+++ b/plugins/huggingface/tests/test_prefetch.py
@@ -0,0 +1,252 @@
+import os
+import tempfile
+from io import BytesIO
+from unittest.mock import MagicMock, patch
+
+import pyarrow as pa
+import pyarrow.parquet as pq
+import pytest
+
+from flyteplugins.huggingface._prefetch import (
+    HuggingFaceDatasetInfo,
+    _download_dataset_to_local,
+    _stream_dataset_to_remote,
+    _validate_input_name,
+    store_hf_dataset_task,
+)
+
+# ============================================================================
+# Input validation tests
+# ============================================================================
+
+
+def test_validate_input_name_valid():
+    _validate_input_name("my-dataset_v1")
+    _validate_input_name("IMDB")
+    _validate_input_name("v1.0")
+    _validate_input_name(None)
+
+
+def test_validate_input_name_invalid():
+    with pytest.raises(ValueError, match="must only contain"):
+        _validate_input_name("my dataset")
+    with pytest.raises(ValueError, match="must only contain"):
+        _validate_input_name("data/set")
+    with pytest.raises(ValueError, match="must only contain"):
+        _validate_input_name("../etc")
+
+
+# ============================================================================
+# HuggingFaceDatasetInfo tests
+# ============================================================================
+
+
+def test_dataset_info_serialization():
+    info = HuggingFaceDatasetInfo(repo="stanfordnlp/imdb", split="train")
+    dumped = info.model_dump_json()
+    restored = HuggingFaceDatasetInfo.model_validate_json(dumped)
+    assert restored.repo == "stanfordnlp/imdb"
+    assert restored.split == "train"
+    assert restored.name is None
+
+
+def test_dataset_info_defaults():
+    info = HuggingFaceDatasetInfo(repo="squad")
+    assert info.name is None
+    assert info.split is None
+
+
+# ============================================================================
+# Streaming tests
+# ============================================================================
+
+
+def _make_mock_hub(parquet_entries, split="train"):
+    """Helper: create a mock huggingface_hub module with given parquet entries."""
+    mock_hfs = MagicMock()
+
+    def ls_side_effect(path, revision=None, detail=True):
+        if path.endswith(split):
+            return parquet_entries
+        return []
+
+    mock_hfs.ls.side_effect = ls_side_effect
+    mock_hfs.open.side_effect = lambda name, mode="rb", revision=None: BytesIO(b"fake-parquet-content")
+
+    mock_hub = MagicMock()
+    mock_hub.HfFileSystem.return_value = mock_hfs
+    return mock_hub
+
+
+def test_stream_dataset_no_parquet_files():
+    mock_hfs = MagicMock()
+    mock_hfs.ls.return_value = []
+
+    mock_hub = MagicMock()
+    mock_hub.HfFileSystem.return_value = mock_hfs
+
+    mock_fs = MagicMock()
+
+    with (
+        patch.dict("sys.modules", {"huggingface_hub": mock_hub}),
+        patch("flyte.storage.get_underlying_filesystem", return_value=mock_fs),
+    ):
+        with pytest.raises(FileNotFoundError, match="No parquet files found"):
+            _stream_dataset_to_remote("fake/dataset", None, "train", None, "s3://bucket/output")
+
+
+def test_stream_dataset_single_split():
+    entries = [
+        {"type": "file", "name": "datasets/org/ds/default/train/0000.parquet"},
+        {"type": "file", "name": "datasets/org/ds/default/train/0001.parquet"},
+    ]
+    mock_hub = _make_mock_hub(entries, split="train")
+    mock_fs = MagicMock()
+    mock_fs.open.return_value.__enter__ = MagicMock(return_value=BytesIO())
+    mock_fs.open.return_value.__exit__ = MagicMock(return_value=False)
+
+    with (
+        patch.dict("sys.modules", {"huggingface_hub": mock_hub}),
+        patch("flyte.storage.get_underlying_filesystem", return_value=mock_fs),
+    ):
+        result = _stream_dataset_to_remote("org/ds", None, "train", None, "s3://bucket/out")
+        assert result == "s3://bucket/out"
+        write_calls = [str(c) for c in mock_fs.open.call_args_list]
+        assert any("0000.parquet" in c for c in write_calls)
+        assert any("0001.parquet" in c for c in write_calls)
+
+
+def test_stream_dataset_multi_split_preserves_split_dirs():
+    """When split=None, parquet files from different splits should go into separate subdirs."""
+    mock_hfs = MagicMock()
+
+    def ls_side_effect(path, revision=None, detail=True):
+        if path == "datasets/org/ds/default":
+            return [
+                {"type": "directory", "name": "datasets/org/ds/default/train"},
+                {"type": "directory", "name": "datasets/org/ds/default/test"},
+            ]
+        elif path.endswith("/train"):
+            return [{"type": "file", "name": "datasets/org/ds/default/train/0000.parquet"}]
[{"type": "file", "name": "datasets/org/ds/default/train/0000.parquet"}] + elif path.endswith("/test"): + return [{"type": "file", "name": "datasets/org/ds/default/test/0000.parquet"}] + return [] + + mock_hfs.ls.side_effect = ls_side_effect + mock_hfs.open.side_effect = lambda name, mode="rb", revision=None: BytesIO(b"data") + + mock_hub = MagicMock() + mock_hub.HfFileSystem.return_value = mock_hfs + + mock_fs = MagicMock() + mock_fs.open.return_value.__enter__ = MagicMock(return_value=BytesIO()) + mock_fs.open.return_value.__exit__ = MagicMock(return_value=False) + + with ( + patch.dict("sys.modules", {"huggingface_hub": mock_hub}), + patch("flyte.storage.get_underlying_filesystem", return_value=mock_fs), + ): + result = _stream_dataset_to_remote("org/ds", None, None, None, "s3://bucket/out") + assert result == "s3://bucket/out" + open_paths = [str(c) for c in mock_fs.open.call_args_list] + assert any("train/0000.parquet" in p for p in open_paths) + assert any("test/0000.parquet" in p for p in open_paths) + mkdirs_calls = [str(c) for c in mock_fs.mkdirs.call_args_list] + assert any("train" in c for c in mkdirs_calls) + assert any("test" in c for c in mkdirs_calls) + + +# ============================================================================ +# Download fallback tests +# ============================================================================ + + +def test_download_dataset_to_local_no_files(): + mock_hub = MagicMock() + mock_hub.list_repo_files.return_value = ["other/file.txt"] + + with patch.dict("sys.modules", {"huggingface_hub": mock_hub}): + with tempfile.TemporaryDirectory() as local_dir, tempfile.TemporaryDirectory() as output_dir: + with pytest.raises(FileNotFoundError, match="No parquet files found"): + _download_dataset_to_local("fake/ds", None, "train", None, local_dir, output_dir) + + +def test_download_dataset_preserves_structure(): + mock_hub = MagicMock() + mock_hub.list_repo_files.return_value = [ + "default/train/0000.parquet", + "default/train/0001.parquet", + ] + + def fake_download(repo_id, filename, repo_type, revision, local_dir, token): + dest = os.path.join(local_dir, filename) + os.makedirs(os.path.dirname(dest), exist_ok=True) + table = pa.table({"col": [1]}) + pq.write_table(table, dest) + + mock_hub.hf_hub_download.side_effect = fake_download + + with patch.dict("sys.modules", {"huggingface_hub": mock_hub}): + with tempfile.TemporaryDirectory() as local_dir, tempfile.TemporaryDirectory() as output_dir: + result = _download_dataset_to_local("org/ds", None, "train", None, local_dir, output_dir) + assert result == output_dir + # Files should preserve relative path from local_dir + assert os.path.exists(os.path.join(output_dir, "default", "train", "0000.parquet")) + assert os.path.exists(os.path.join(output_dir, "default", "train", "0001.parquet")) + + +def test_download_dataset_multi_split_no_collision(): + """Multi-split download should not overwrite files with the same basename.""" + mock_hub = MagicMock() + mock_hub.list_repo_files.return_value = [ + "default/train/0000.parquet", + "default/test/0000.parquet", + ] + + def fake_download(repo_id, filename, repo_type, revision, local_dir, token): + dest = os.path.join(local_dir, filename) + os.makedirs(os.path.dirname(dest), exist_ok=True) + table = pa.table({"col": [1]}) + pq.write_table(table, dest) + + mock_hub.hf_hub_download.side_effect = fake_download + + with patch.dict("sys.modules", {"huggingface_hub": mock_hub}): + with tempfile.TemporaryDirectory() as local_dir, tempfile.TemporaryDirectory() as 
+            _download_dataset_to_local("org/ds", None, None, None, local_dir, output_dir)
+            # Both files should exist in separate directories
+            assert os.path.exists(os.path.join(output_dir, "default", "train", "0000.parquet"))
+            assert os.path.exists(os.path.join(output_dir, "default", "test", "0000.parquet"))
+
+
+# ============================================================================
+# store_hf_dataset_task tests
+# ============================================================================
+
+
+def test_store_hf_dataset_task_warns_no_token():
+    info = HuggingFaceDatasetInfo(repo="org/ds", split="train")
+
+    mock_hub = MagicMock()
+    mock_hfs = MagicMock()
+    mock_hfs.ls.return_value = [{"type": "file", "name": "datasets/org/ds/default/train/0000.parquet"}]
+    mock_hfs.open.return_value = BytesIO(b"data")
+    mock_hub.HfFileSystem.return_value = mock_hfs
+
+    mock_fs = MagicMock()
+    mock_fs.open.return_value.__enter__ = MagicMock(return_value=BytesIO())
+    mock_fs.open.return_value.__exit__ = MagicMock(return_value=False)
+
+    mock_ctx = MagicMock()
+    mock_ctx.raw_data_path.get_random_remote_path.return_value = "/tmp/test-output"
+
+    with (
+        patch.dict("sys.modules", {"huggingface_hub": mock_hub}),
+        patch.dict(os.environ, {}, clear=True),
+        patch("flyte.storage.get_underlying_filesystem", return_value=mock_fs),
+        patch("flyte.ctx", return_value=mock_ctx),
+        patch("flyte.io.Dir.from_existing_remote") as mock_dir,
+    ):
+        mock_dir.return_value = MagicMock(path="/tmp/test-output")
+        result = store_hf_dataset_task(info.model_dump_json())
+        assert result.path == "/tmp/test-output"
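+
+
+def test_store_hf_dataset_task_falls_back_to_download():
+    """Sketch of the fallback path: when direct streaming raises OSError, the
+    task should fall back to snapshot download. This patches the module's
+    private helpers (_stream_dataset_to_remote, _download_dataset_to_local)
+    and the module-level Dir in place; a hypothetical test, not part of the
+    original suite."""
+    info = HuggingFaceDatasetInfo(repo="org/ds", split="train")
+
+    mock_ctx = MagicMock()
+    mock_ctx.raw_data_path.get_random_remote_path.return_value = "/tmp/test-output"
+
+    with (
+        patch(
+            "flyteplugins.huggingface._prefetch._stream_dataset_to_remote",
+            side_effect=OSError("stream failed"),
+        ),
+        patch("flyteplugins.huggingface._prefetch._download_dataset_to_local") as mock_download,
+        patch("flyteplugins.huggingface._prefetch.Dir") as mock_dir_cls,
+        patch("flyte.ctx", return_value=mock_ctx),
+        patch.dict(os.environ, {"HF_TOKEN": "token"}),
+    ):
+        mock_dir_cls.from_local_sync.return_value = MagicMock(path="/tmp/fallback")
+        result = store_hf_dataset_task(info.model_dump_json())
+        # The download helper should be invoked exactly once, and the returned
+        # Dir should come from the local-sync fallback rather than streaming.
+        mock_download.assert_called_once()
+        assert result.path == "/tmp/fallback"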