Skip to content

Commit 711df81

Browse files
Add HuggingFace dataset upload/download (#310)
* Add HuggingFace dataset upload/download Fixes #309 * Changelog * Add dep
1 parent 9fbe198 commit 711df81

File tree

3 files changed

+84
-2
lines changed

3 files changed

+84
-2
lines changed

changelog_entry.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
- bump: minor
2+
changes:
3+
added:
4+
- HuggingFace upload/download functionality.

policyengine_core/data/dataset.py

Lines changed: 79 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
import requests
88
import os
99
import tempfile
10+
from huggingface_hub import HfApi, login, hf_hub_download
11+
import pkg_resources
1012

1113

1214
def atomic_write(file: Path, content: bytes) -> None:
@@ -53,6 +55,8 @@ class Dataset:
5355
"""The time period of the dataset. This is used to automatically enter the values in the correct time period if the data type is `Dataset.ARRAYS`."""
5456
url: str = None
5557
"""The URL to download the dataset from. This is used to download the dataset if it does not exist."""
58+
huggingface_url: str = None
59+
"""The HuggingFace URL to download the dataset from. This is used to download the dataset if it does not exist."""
5660

5761
# Data formats
5862
TABLES = "tables"
@@ -306,15 +310,15 @@ def store_file(self, file_path: str):
306310
raise FileNotFoundError(f"File {file_path} does not exist.")
307311
shutil.move(file_path, self.file_path)
308312

309-
def download(self, url: str = None) -> None:
313+
def download(self, url: str = None, version: str = None) -> None:
310314
"""Downloads a file to the dataset's file path.
311315
312316
Args:
313317
url (str): The url to download.
314318
"""
315319

316320
if url is None:
317-
url = self.url
321+
url = self.huggingface_url or self.url
318322

319323
if "POLICYENGINE_GITHUB_MICRODATA_AUTH_TOKEN" not in os.environ:
320324
auth_headers = {}
@@ -345,6 +349,10 @@ def download(self, url: str = None) -> None:
345349
raise ValueError(
346350
f"File {file_path} not found in release {release_tag} of {org}/{repo}."
347351
)
352+
elif url.startswith("hf://"):
353+
owner_name, model_name = url.split("/")[2:]
354+
self.download_from_huggingface(owner_name, model_name, version)
355+
return
348356
else:
349357
url = url
350358

@@ -363,6 +371,19 @@ def download(self, url: str = None) -> None:
363371

364372
atomic_write(self.file_path, response.content)
365373

374+
def upload(self, url: str = None):
375+
"""Uploads the dataset to a URL.
376+
377+
Args:
378+
url (str): The url to upload.
379+
"""
380+
if url is None:
381+
url = self.huggingface_url or self.url
382+
383+
if url.startswith("hf://"):
384+
owner_name, model_name = url.split("/")[2:]
385+
self.upload_to_huggingface(owner_name, model_name)
386+
366387
def remove(self):
367388
"""Removes the dataset from disk."""
368389
if self.exists:
@@ -414,3 +435,59 @@ def from_dataframe(dataframe: pd.DataFrame, time_period: str = None):
414435
)()
415436

416437
return dataset
438+
439+
def upload_to_huggingface(self, owner_name: str, model_name: str):
440+
"""Uploads the dataset to Hugging Face.
441+
442+
Args:
443+
owner_name (str): The owner name.
444+
model_name (str): The model name.
445+
"""
446+
token = os.environ.get(
447+
"HUGGING_FACE_TOKEN", "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"
448+
)
449+
login(token=token)
450+
api = HfApi()
451+
452+
# Add the policyengine-uk-data version and policyengine-uk version to the h5 metadata.
453+
uk_data_version = get_package_version("policyengine-uk-data")
454+
uk_version = get_package_version("policyengine-uk")
455+
with h5py.File(self.file_path, "a") as f:
456+
f.attrs["policyengine-uk-data"] = uk_data_version
457+
f.attrs["policyengine-uk"] = uk_version
458+
459+
api.upload_file(
460+
path_or_fileobj=self.file_path,
461+
path_in_repo=self.file_path.name,
462+
repo_id=f"{owner_name}/{model_name}",
463+
repo_type="model",
464+
)
465+
466+
def download_from_huggingface(
467+
self, owner_name: str, model_name: str, version: str = None
468+
):
469+
"""Downloads the dataset from Hugging Face.
470+
471+
Args:
472+
owner_name (str): The owner name.
473+
model_name (str): The model name.
474+
"""
475+
token = os.environ.get(
476+
"HUGGING_FACE_TOKEN", "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"
477+
)
478+
login(token=token)
479+
480+
hf_hub_download(
481+
repo_id=f"{owner_name}/{model_name}",
482+
repo_type="model",
483+
path=self.file_path,
484+
revision=version,
485+
)
486+
487+
488+
def get_package_version(package_name: str) -> str:
489+
"""Get the installed version of a package."""
490+
try:
491+
return pkg_resources.get_distribution(package_name).version
492+
except pkg_resources.DistributionNotFound:
493+
return "not installed"

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
"ipython>=8,<9",
2525
"pyvis>=0.3.2",
2626
"microdf_python>=0.4.3",
27+
"huggingface_hub>=0.25.1",
2728
]
2829

2930
dev_requirements = [

0 commit comments

Comments
 (0)