7
7
import requests
8
8
import os
9
9
import tempfile
10
+ from huggingface_hub import HfApi , login , hf_hub_download
11
+ import pkg_resources
10
12
11
13
12
14
def atomic_write (file : Path , content : bytes ) -> None :
@@ -53,6 +55,8 @@ class Dataset:
53
55
"""The time period of the dataset. This is used to automatically enter the values in the correct time period if the data type is `Dataset.ARRAYS`."""
54
56
url : str = None
55
57
"""The URL to download the dataset from. This is used to download the dataset if it does not exist."""
58
+ huggingface_url : str = None
59
+ """The HuggingFace URL to download the dataset from. This is used to download the dataset if it does not exist."""
56
60
57
61
# Data formats
58
62
TABLES = "tables"
@@ -306,15 +310,15 @@ def store_file(self, file_path: str):
306
310
raise FileNotFoundError (f"File { file_path } does not exist." )
307
311
shutil .move (file_path , self .file_path )
308
312
309
- def download (self , url : str = None ) -> None :
313
+ def download (self , url : str = None , version : str = None ) -> None :
310
314
"""Downloads a file to the dataset's file path.
311
315
312
316
Args:
313
317
url (str): The url to download.
314
318
"""
315
319
316
320
if url is None :
317
- url = self .url
321
+ url = self .huggingface_url or self . url
318
322
319
323
if "POLICYENGINE_GITHUB_MICRODATA_AUTH_TOKEN" not in os .environ :
320
324
auth_headers = {}
@@ -345,6 +349,10 @@ def download(self, url: str = None) -> None:
345
349
raise ValueError (
346
350
f"File { file_path } not found in release { release_tag } of { org } /{ repo } ."
347
351
)
352
+ elif url .startswith ("hf://" ):
353
+ owner_name , model_name = url .split ("/" )[2 :]
354
+ self .download_from_huggingface (owner_name , model_name , version )
355
+ return
348
356
else :
349
357
url = url
350
358
@@ -363,6 +371,19 @@ def download(self, url: str = None) -> None:
363
371
364
372
atomic_write (self .file_path , response .content )
365
373
374
+ def upload (self , url : str = None ):
375
+ """Uploads the dataset to a URL.
376
+
377
+ Args:
378
+ url (str): The url to upload.
379
+ """
380
+ if url is None :
381
+ url = self .huggingface_url or self .url
382
+
383
+ if url .startswith ("hf://" ):
384
+ owner_name , model_name = url .split ("/" )[2 :]
385
+ self .upload_to_huggingface (owner_name , model_name )
386
+
366
387
def remove (self ):
367
388
"""Removes the dataset from disk."""
368
389
if self .exists :
@@ -414,3 +435,59 @@ def from_dataframe(dataframe: pd.DataFrame, time_period: str = None):
414
435
)()
415
436
416
437
return dataset
438
+
439
+ def upload_to_huggingface (self , owner_name : str , model_name : str ):
440
+ """Uploads the dataset to Hugging Face.
441
+
442
+ Args:
443
+ owner_name (str): The owner name.
444
+ model_name (str): The model name.
445
+ """
446
+ token = os .environ .get (
447
+ "HUGGING_FACE_TOKEN" , "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"
448
+ )
449
+ login (token = token )
450
+ api = HfApi ()
451
+
452
+ # Add the policyengine-uk-data version and policyengine-uk version to the h5 metadata.
453
+ uk_data_version = get_package_version ("policyengine-uk-data" )
454
+ uk_version = get_package_version ("policyengine-uk" )
455
+ with h5py .File (self .file_path , "a" ) as f :
456
+ f .attrs ["policyengine-uk-data" ] = uk_data_version
457
+ f .attrs ["policyengine-uk" ] = uk_version
458
+
459
+ api .upload_file (
460
+ path_or_fileobj = self .file_path ,
461
+ path_in_repo = self .file_path .name ,
462
+ repo_id = f"{ owner_name } /{ model_name } " ,
463
+ repo_type = "model" ,
464
+ )
465
+
466
+ def download_from_huggingface (
467
+ self , owner_name : str , model_name : str , version : str = None
468
+ ):
469
+ """Downloads the dataset from Hugging Face.
470
+
471
+ Args:
472
+ owner_name (str): The owner name.
473
+ model_name (str): The model name.
474
+ """
475
+ token = os .environ .get (
476
+ "HUGGING_FACE_TOKEN" , "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"
477
+ )
478
+ login (token = token )
479
+
480
+ hf_hub_download (
481
+ repo_id = f"{ owner_name } /{ model_name } " ,
482
+ repo_type = "model" ,
483
+ path = self .file_path ,
484
+ revision = version ,
485
+ )
486
+
487
+
488
+ def get_package_version (package_name : str ) -> str :
489
+ """Get the installed version of a package."""
490
+ try :
491
+ return pkg_resources .get_distribution (package_name ).version
492
+ except pkg_resources .DistributionNotFound :
493
+ return "not installed"
0 commit comments