Skip to content

Commit 07bdabf

Browse files
committed
incorporate suggestions @andreahlert and @cosmicBboy
Signed-off-by: Samhita Alla <aallasamhita@gmail.com>
1 parent 312f14b commit 07bdabf

11 files changed

Lines changed: 333 additions & 98 deletions

File tree

plugins/huggingface/README.md

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
1-
# Hugging Face Datasets Plugin
1+
# Hugging Face Plugin
22

3-
Native Flyte support for Hugging Face `datasets.Dataset` and
4-
`datasets.IterableDataset`.
3+
Native Flyte support for Hugging Face integrations.
54

6-
This plugin gives you two related capabilities:
5+
This plugin provides dataset support for Hugging Face `datasets.Dataset`
6+
and `datasets.IterableDataset` objects. It gives you two related capabilities:
77

88
1. Use `from_hf(...)` to reference a dataset on the Hugging Face Hub as a task
99
input default.
@@ -26,7 +26,7 @@ pip install flyteplugins-huggingface
2626
```python
2727
import datasets
2828
import flyte
29-
from flyteplugins.huggingface import from_hf
29+
from flyteplugins.huggingface.datasets import from_hf
3030

3131
env = flyte.TaskEnvironment(name="hf-example")
3232

@@ -56,7 +56,7 @@ reference used between Flyte and the plugin.
5656
`from_hf(...)` is the entry point for Hub-backed task defaults:
5757

5858
```python
59-
from flyteplugins.huggingface import from_hf
59+
from flyteplugins.huggingface.datasets import from_hf
6060

6161
from_hf(
6262
repo: str,
@@ -142,7 +142,7 @@ dataset.
142142
Without `cache_root`, a Hub source is materialized into a generated path for the
143143
current execution only.
144144

145-
With `cache_root`, the plugin uses a shared artifact registry so later runs can
145+
With `cache_root`, the plugin uses a shared cache registry so later runs can
146146
skip the Hub download entirely:
147147

148148
```python
@@ -158,7 +158,7 @@ async def train_cached(
158158
return len(ds)
159159
```
160160

161-
The shared registry layout is:
161+
The shared cache layout is:
162162

163163
```text
164164
{cache_root}/huggingface/datasets/
@@ -177,6 +177,10 @@ The cache key is derived from:
177177
This means the cache is stable across runs as long as the underlying converted
178178
Parquet source does not change.
179179

180+
The canonical artifact location is always
181+
`{cache_root}/huggingface/datasets/blobs/{source-cache-key}/...`. The registry
182+
record under `by-key/` is metadata for that cache key.
183+
180184
## What the plugin logs
181185

182186
When `LOG_LEVEL` is `INFO` or lower, the plugin logs whether it is:

plugins/huggingface/examples/hf_dataset_workflow.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from flyte._image import PythonWheels
2323
from flyte.io import DataFrame
2424

25-
from flyteplugins.huggingface import from_hf
25+
from flyteplugins.huggingface.datasets import from_hf
2626

2727
env = flyte.TaskEnvironment(
2828
name="hf-dataset-example",

plugins/huggingface/pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
[project]
22
name = "flyteplugins-huggingface"
33
dynamic = ["version"]
4-
description = "Hugging Face Datasets plugin for flyte"
4+
description = "Hugging Face Plugin for Flyte"
55
readme = "README.md"
66
authors = [{ name = "André Ahlert", email = "andre@aex.partners" }, { name = "Samhita Alla", email = "samhita@union.ai" }]
77
requires-python = ">=3.10"
@@ -13,7 +13,7 @@ dependencies = [
1313
]
1414

1515
[project.entry-points."flyte.plugins.types"]
16-
huggingface = "flyteplugins.huggingface:register_huggingface_dataset_transformers"
16+
huggingface = "flyteplugins.huggingface.datasets:register_huggingface_dataset_transformers"
1717

1818
[build-system]
1919
requires = ["setuptools", "setuptools_scm"]
Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +0,0 @@
1-
from .dataset import (
2-
HFSource,
3-
from_hf,
4-
register_huggingface_dataset_transformers,
5-
)
6-
7-
__all__ = ["HFSource", "from_hf"]
8-
9-
10-
register_huggingface_dataset_transformers()

plugins/huggingface/src/flyteplugins/huggingface/dataset/__init__.py renamed to plugins/huggingface/src/flyteplugins/huggingface/datasets/__init__.py

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,6 @@
11
import functools
22

33
from ._source import HFSource, from_hf
4-
from ._transformers import (
5-
HFToHuggingFaceDatasetDecodingHandler,
6-
HFToHuggingFaceIterableDatasetDecodingHandler,
7-
HuggingFaceDatasetToParquetEncodingHandler,
8-
HuggingFaceIterableDatasetToParquetEncodingHandler,
9-
ParquetToHuggingFaceDatasetDecodingHandler,
10-
ParquetToHuggingFaceIterableDatasetDecodingHandler,
11-
)
124

135
__all__ = ["HFSource", "from_hf"]
146

@@ -18,6 +10,15 @@ def register_huggingface_dataset_transformers():
1810
"""Register Hugging Face Dataset encoders and decoders."""
1911
from flyte.io.extend import DataFrameTransformerEngine
2012

13+
from ._transformers import (
14+
HFToHuggingFaceDatasetDecodingHandler,
15+
HFToHuggingFaceIterableDatasetDecodingHandler,
16+
HuggingFaceDatasetToParquetEncodingHandler,
17+
HuggingFaceIterableDatasetToParquetEncodingHandler,
18+
ParquetToHuggingFaceDatasetDecodingHandler,
19+
ParquetToHuggingFaceIterableDatasetDecodingHandler,
20+
)
21+
2122
DataFrameTransformerEngine.register(HuggingFaceDatasetToParquetEncodingHandler(), default_format_for_type=True)
2223
DataFrameTransformerEngine.register(ParquetToHuggingFaceDatasetDecodingHandler(), default_format_for_type=True)
2324
DataFrameTransformerEngine.register(HFToHuggingFaceDatasetDecodingHandler())
@@ -30,6 +31,3 @@ def register_huggingface_dataset_transformers():
3031
default_format_for_type=True,
3132
)
3233
DataFrameTransformerEngine.register(HFToHuggingFaceIterableDatasetDecodingHandler())
33-
34-
35-
register_huggingface_dataset_transformers()

plugins/huggingface/src/flyteplugins/huggingface/dataset/_io.py renamed to plugins/huggingface/src/flyteplugins/huggingface/datasets/_io.py

Lines changed: 9 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ async def list_parquet_files(uri: str, filesystem) -> list[str]:
146146
raw = [f"{proto}{f}" for f in raw]
147147
return raw
148148
except Exception as e:
149-
logger.debug(f"Unable to list parquet files under {uri}: {e}")
149+
logger.warning(f"Unable to list parquet files under {uri}: {e}")
150150
return [join_uri_path(uri, f"{0:05}.parquet")]
151151

152152

@@ -307,9 +307,7 @@ def hf_cache_manifest(
307307
}
308308

309309

310-
async def read_cache_manifest(
311-
remote_path: str,
312-
) -> dict[str, typing.Any] | None:
310+
async def read_cache_manifest(remote_path: str) -> dict[str, typing.Any] | None:
313311
path = manifest_path(remote_path)
314312
try:
315313
if not await storage_path_exists(path):
@@ -348,13 +346,9 @@ async def read_registry_record(
348346
async def write_registry_record(
349347
source: HFSource,
350348
cache_key: str,
351-
artifact_uri: str,
352349
manifest: dict[str, typing.Any],
353350
) -> None:
354-
record = {
355-
**manifest,
356-
"artifact_uri": artifact_uri,
357-
}
351+
record = dict(manifest)
358352
data = json.dumps(record, sort_keys=True, indent=2).encode("utf-8")
359353
await storage_write_bytes(get_hf_registry_record_path(source, cache_key), data)
360354

@@ -481,32 +475,25 @@ async def ensure_hf_cached(source: HFSource) -> str:
481475
f"under {source.cache_root}"
482476
)
483477
registry_record = await read_registry_record(source, cache_key)
484-
remote_path = (
485-
registry_record.get("artifact_uri", default_remote_path) if registry_record is not None else default_remote_path
486-
)
487-
488-
if await read_cache_manifest(remote_path) == expected_manifest:
478+
if await read_cache_manifest(default_remote_path) == expected_manifest:
489479
if registry_record is None:
490480
await write_registry_record(
491481
source,
492482
cache_key,
493-
remote_path,
494483
expected_manifest,
495484
)
496-
logger.info(f"Using cached Hugging Face dataset at {remote_path}")
497-
return remote_path
485+
logger.info(f"Using cached Hugging Face dataset at {default_remote_path}")
486+
return default_remote_path
498487

499488
logger.info(
500489
f"Materializing Hugging Face dataset {source.repo} "
501490
f"({_source_log_description(source)}) "
502-
f"to remote cache artifact {remote_path}"
491+
f"to remote cache artifact {default_remote_path}"
503492
)
504-
remote_path = default_remote_path
505-
await stream_hf_to_remote(source, remote_path, shards, expected_manifest)
493+
await stream_hf_to_remote(source, default_remote_path, shards, expected_manifest)
506494
await write_registry_record(
507495
source,
508496
cache_key,
509-
remote_path,
510497
expected_manifest,
511498
)
512-
return remote_path
499+
return default_remote_path

plugins/huggingface/src/flyteplugins/huggingface/dataset/_source.py renamed to plugins/huggingface/src/flyteplugins/huggingface/datasets/_source.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,29 @@ class HFSource:
2121
revision: str | None = None
2222
cache_root: str | None = None
2323

24+
def __post_init__(self) -> None:
25+
self.repo = self._normalize_required_field("repo", self.repo)
26+
self.name = self._normalize_optional_field("name", self.name)
27+
self.split = self._normalize_optional_field("split", self.split)
28+
self.revision = self._normalize_optional_field("revision", self.revision)
29+
self.cache_root = self._normalize_optional_field("cache_root", self.cache_root)
30+
31+
@staticmethod
32+
def _normalize_required_field(field_name: str, value: str) -> str:
33+
normalized = value.strip()
34+
if not normalized:
35+
raise ValueError(f"HFSource {field_name} must not be empty")
36+
return normalized
37+
38+
@staticmethod
39+
def _normalize_optional_field(field_name: str, value: str | None) -> str | None:
40+
if value is None:
41+
return None
42+
normalized = value.strip()
43+
if not normalized:
44+
raise ValueError(f"HFSource {field_name} must not be blank")
45+
return normalized
46+
2447
def to_hf_uri(self) -> str:
2548
uri = f"hf://{self.repo}"
2649
params = {}

0 commit comments

Comments
 (0)