From bbbebeba1caceaf1d4466f5de724a6241a5124a9 Mon Sep 17 00:00:00 2001 From: Jingya HUANG <44135271+JingyaHuang@users.noreply.github.com> Date: Wed, 17 Apr 2024 16:12:31 +0200 Subject: [PATCH] Allow download subfolder for caching models with subfolder (#566) * allow download subfolder * improve code clarity --- optimum/neuron/utils/hub_neuronx_cache.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/optimum/neuron/utils/hub_neuronx_cache.py b/optimum/neuron/utils/hub_neuronx_cache.py index 4ea89f490..08360c312 100644 --- a/optimum/neuron/utils/hub_neuronx_cache.py +++ b/optimum/neuron/utils/hub_neuronx_cache.py @@ -25,6 +25,7 @@ from typing import Any, Dict, List, Literal, Optional, Union from huggingface_hub import HfApi, get_token +from huggingface_hub.hf_api import RepoFile from transformers import AutoConfig, PretrainedConfig from ..version import __version__ @@ -177,9 +178,10 @@ def download_folder(self, folder_path: str, dst_path: str): # cached locally return True else: + # cached remotely rel_folder_path = self._rel_path(folder_path) try: - folder_info = list(self.api.list_repo_tree(self.repo_id, rel_folder_path)) + folder_info = list(self.api.list_repo_tree(self.repo_id, rel_folder_path, recursive=True)) folder_exists = len(folder_info) > 1 except Exception as e: logger.info(f"{rel_folder_path} not found in {self.repo_id}: {e} \nThe model will be recompiled.") @@ -187,14 +189,13 @@ def download_folder(self, folder_path: str, dst_path: str): if folder_exists: try: - # cached remotely for repo_content in folder_info: - # TODO: this works for `RepoFile` but not `RepoFolder` - local_path = self.api.hf_hub_download(self.repo_id, repo_content.path) - filename = Path(local_path).name - dst_path = Path(dst_path) - dst_path.mkdir(parents=True, exist_ok=True) - os.symlink(local_path, dst_path / filename) + if isinstance(repo_content, RepoFile): + local_path = self.api.hf_hub_download(self.repo_id, repo_content.path) + new_dst_path = Path(dst_path) / repo_content.path.split(Path(dst_path).name + "/")[-1] + new_dst_path.parent.mkdir(parents=True, exist_ok=True) + os.symlink(local_path, new_dst_path) + logger.info(f"Fetched cached {rel_folder_path} from {self.repo_id}") except Exception as e: logger.warning(