Skip to content

Commit 58b760f

Browse files
authored
Fix _load_repository_from_gcs (#76)
* Fix `_load_repository_from_gcs`: the issue was the `join` call on a `pathlib.Path` object, so the `pathlib.Path` cast was missing on the `file_split[0:-1]` slice
* Add `directory` as download path prefix
* Add `tests/unit/test_vertex_ai_utils.py`
* Fix to use `target_dir` over `directory` (introduced in #b0846e1)
1 parent 3d101a0 commit 58b760f

File tree

2 files changed

+28
-3
lines changed

2 files changed

+28
-3
lines changed

src/huggingface_inference_toolkit/vertex_ai_utils.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@ def _load_repository_from_gcs(artifact_uri: str, target_dir: Union[str, Path] =
1515
from google.cloud import storage
1616

1717
logger.info(f"Loading model artifacts from {artifact_uri} to {target_dir}")
18-
target_dir = Path(target_dir)
18+
if isinstance(target_dir, str):
19+
target_dir = Path(target_dir)
1920

2021
if artifact_uri.startswith(GCS_URI_PREFIX):
2122
matches = re.match(f"{GCS_URI_PREFIX}(.*?)/(.*)", artifact_uri)
@@ -31,9 +32,9 @@ def _load_repository_from_gcs(artifact_uri: str, target_dir: Union[str, Path] =
3132
else name_without_prefix
3233
)
3334
file_split = name_without_prefix.split("/")
34-
directory = target_dir.join(file_split[0:-1])
35+
directory = target_dir / Path(*file_split[0:-1])
3536
directory.mkdir(parents=True, exist_ok=True)
3637
if name_without_prefix and not name_without_prefix.endswith("/"):
37-
blob.download_to_filename(name_without_prefix)
38+
blob.download_to_filename(target_dir / name_without_prefix)
3839

3940
return str(target_dir.absolute())

tests/unit/test_vertex_ai_utils.py

+24
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
from pathlib import Path
2+
3+
from huggingface_inference_toolkit.vertex_ai_utils import _load_repository_from_gcs
4+
5+
6+
def test__load_repository_from_gcs():
7+
"""Tests the `_load_repository_from_gcs` function against a public artifact URI.
8+
9+
References:
10+
- https://cloud.google.com/storage/docs/public-datasets/era5
11+
- https://console.cloud.google.com/storage/browser/gcp-public-data-arco-era5/raw/date-variable-static/2021/12/31/soil_type?pageState=(%22StorageObjectListTable%22:(%22f%22:%22%255B%255D%22))
12+
"""
13+
14+
public_artifact_uri = (
15+
"gs://gcp-public-data-arco-era5/raw/date-variable-static/2021/12/31/soil_type"
16+
)
17+
target_dir = Path.cwd() / "target"
18+
target_dir_path = _load_repository_from_gcs(
19+
artifact_uri=public_artifact_uri, target_dir=target_dir
20+
)
21+
22+
assert target_dir == Path(target_dir_path)
23+
assert Path(target_dir_path).exists()
24+
assert (Path(target_dir_path) / "static.nc").exists()

0 commit comments

Comments
 (0)