From f81c3654c358eb2acfc1d294f4e43725ea822044 Mon Sep 17 00:00:00 2001 From: David Corvoysier Date: Wed, 17 Jan 2024 10:39:05 +0100 Subject: [PATCH] Add Neuronx compile cache proxy and use it for LLM decoder models (#410) * feat: add HF Hub neuronx cache proxy * feat(decoders): always use hub neuronx cache * feat(cli): add cache synchronize command * doc: add reference to new cache for NeuronModelForCausalLM * ci: add neuronx cache tests * feat(cache): only except when synchronizing * feat(cache): catch more errors * feat(cache): add warning on cache miss * review: address comments * fix(cli): avoid undefined symbol * fix(utils): avoid possible circular import * review: use existing require helper * review: address comment --- .github/workflows/test_inf2.yml | 4 + docs/source/guides/cache_system.mdx | 59 +++--- optimum/commands/neuron/cache.py | 15 ++ optimum/neuron/modeling_decoder.py | 5 +- optimum/neuron/utils/__init__.py | 1 + optimum/neuron/utils/hub_neuronx_cache.py | 227 ++++++++++++++++++++++ optimum/neuron/utils/require_utils.py | 2 +- tests/cache/test_neuronx_cache.py | 163 ++++++++++++++++ 8 files changed, 444 insertions(+), 32 deletions(-) create mode 100644 optimum/neuron/utils/hub_neuronx_cache.py create mode 100644 tests/cache/test_neuronx_cache.py diff --git a/.github/workflows/test_inf2.yml b/.github/workflows/test_inf2.yml index e6db238ca..a296128ce 100644 --- a/.github/workflows/test_inf2.yml +++ b/.github/workflows/test_inf2.yml @@ -35,6 +35,10 @@ jobs: python -m pip install -U pip python -m pip config set global.extra-index-url https://pip.repos.neuron.amazonaws.com python -m pip install .[neuronx,tests] + - name: Run cache tests + run: | + source aws_neuron_venv_pytorch/bin/activate + HF_TOKEN=${{ secrets.HF_TOKEN_OPTIMUM_NEURON_CI }} pytest -m is_inferentia_test tests/cache - name: Run CLI tests run: | source aws_neuron_venv_pytorch/bin/activate diff --git a/docs/source/guides/cache_system.mdx b/docs/source/guides/cache_system.mdx index fe1d244d2..b46edcfea 100644 --- a/docs/source/guides/cache_system.mdx +++ b/docs/source/guides/cache_system.mdx @@ -12,52 +12,52 @@ specific language governing permissions and limitations under the License. # Neuron Model Cache -The Neuron Model Cache is a remote cache for compiled Neuron models in the `neff` format. -It is integrated into the [`NeuronTrainer`] class to enable loading pretrained models from the cache instead of compiling them locally. -This can speed up the training process by about –3x. +The Neuron Model Cache is a remote cache for compiled Neuron models in the `neff` format. +It is integrated into the [`NeuronTrainer` and `NeuronModelForCausalLM] classes to enable loading pretrained models from the cache instead of compiling them locally. -The Neuron Model Cache is hosted on the [Hugging Face Hub](https://huggingface.co/aws-neuron/optimum-neuron-cache) and includes compiled files for all popular and supported pre-trained models `optimum-neuron`. +The Neuron Model Cache is hosted on the [Hugging Face Hub](https://huggingface.co/aws-neuron/optimum-neuron-cache) and includes compiled files for all popular and supported `optimum-neuron` pre-trained models. -When training a Transformers or Diffusion model with vanilla [`torch-neuronx`](https://github.com/aws-neuron/aws-neuron-samples/tree/master/torch-neuronx), the models needs to be first compiled. The compiled version is stored in a local directory, usually `/var/tmp/neuron-compile-cache`. -This means that every time you train a new model in a new environment, you need to recompile it, which takes a lot of time. +When loading a Transformers or Diffusion model, it needs to be compiled to neuron format with [`torch-neuronx`](https://github.com/aws-neuron/aws-neuron-samples/tree/master/torch-neuronx), +in order to run on Neuron platforms. +The compilation produces several compilation files stored in a local directory, usually `/var/tmp/neuron-compile-cache`. +This means that every time you train or export a model on a new host, you need to recompile it, which takes a lot of time. We created the Neuron Model Cache to solve this limitation by providing a public cache of precompiled available models and a private cache to create your private, secured, remote model cache. -The Neuron Model Cache plugs into the local cache directory of the Hugging Face Hub. During training, the [`NeuronTrainer`] will check if compilation files are available on the Hub and download them if they are found, allowing you to save both time and cost by skipping the compilation phase. - ## How the caching system works ### Hash computation -Many factors can trigger compilation among which: +Many factors can trigger compilation among which: + +- The input shapes, +- The precision of the model, full-precision or bf16, +- The version of the Neuron X compiler, +- The number of Neuron cores used. -- The model weights -- The input shapes -- The precision of the model, full-precision or bf16 -- The version of the Neuron X compiler -- The number of Neuron cores used +These parameters are used to compute a hash that uniquely identifies each compilation file. -These parameters are used to compute a hash. This hash is then used to compare local hashes for our training session against hashes stored on the Hugging Face Hub, and act accordingly (download or push). +**It is important to keep in mind that even a small change in the model configuration will trigger a recompilation.** ### How to use the Neuron model cache -The Public model cache will be used when your training script uses the [`NeuronTrainer`]. There are no additional changes needed. +The public model cache will be used when you use the [`NeuronTrainer` or `NeuronModelForCausalLM] classes. There are no additional changes needed. -### How to use a private Neuron model cache +### How to use a private Neuron model cache (trainium only) The repository for the public cache is `aws-neuron/optimum-neuron-cache`. This repository includes all precompiled files for commonly used models so that it is publicly available and free to use for everyone. But there are two limitations: -1. You will not be able to push your own compiled files on this repo +1. You will not be able to push your own compiled files on this repo 2. It is public and you might want to use a private repo for private models To alleviate that you can create your own private cache repository using the `optimum-cli` or set the environment variable `CUSTOM_CACHE_REPO`. #### Using the Optimum CLI -The Optimum CLI offers 2 subcommands for cache creation and setting: +The Optimum CLI offers 2 subcommands for cache creation and setting: -- `create`: To create a new cache repository that you can use as a private Neuron Model cache. -- `set`: To set the name of the Nueron cache repository locally, the repository needs to exists +- `create`: To create a new cache repository that you can use as a private Neuron Model cache. +- `set`: To set the name of the Neuron cache repository locally, the repository needs to exists and will be used by default by `optimum-neuron`. Create a new Neuron cache repository: @@ -115,7 +115,7 @@ The `optimum-cli neuron cache set` command is useful when working on a new insta Using the CLI is not always feasible, and not very practical for small testing. In this case, you can simply set the environment variable `CUSTOM_CACHE_REPO`. -For example, if you cache repo is called `michaelbenayoun/my_custom_cache_repo`, you just need to do: +For example, if your cache repo is called `michaelbenayoun/my_custom_cache_repo`, you just need to do: ```bash CUSTOM_CACHE_REPO="michaelbenayoun/my_custom_cache_repo" torchrun ... @@ -139,11 +139,11 @@ You have to be [logged into the Hugging Face Hub](https://huggingface.co/docs/hu

-At each the beginning of each training step, the [`NeuronTrainer`] computes a `NeuronHash` and checks the cache repo(s) (official and custom) on the Hugging Face Hub to see if there are compiled files associated to this hash. +At each the beginning of each training step, the [`NeuronTrainer`] computes a `NeuronHash` and checks the cache repo(s) (official and custom) on the Hugging Face Hub to see if there are compiled files associated to this hash. If that is the case, the files are downloaded directly to the local cache directory and no compilation is needed. Otherwise compilation is performed. -Just as for downloading compiled files, the [`NeuronTrainer`] will keep track of the newly created compilation files at each training step, and upload them to the Hugging Face Hub at save time or when training ends. This assumes that you have writing access to the cache repo, otherwise nothing will be pushed. +Just as for downloading compiled files, the [`NeuronTrainer`] will keep track of the newly created compilation files at each training step, and upload them to the Hugging Face Hub at save time or when training ends. This assumes that you have writing access to the cache repo, otherwise nothing will be pushed. ## Optimum CLI @@ -156,15 +156,16 @@ usage: optimum-cli neuron cache [-h] {create,set,add,list} ... positional arguments: {create,set,add,list} create Create a model repo on the Hugging Face Hub to store Neuron X compilation files. - set Set the name of the Neuron cache repo to use locally. - add Add a model to the cache of your choice. - list List models in a cache repo. + set Set the name of the Neuron cache repo to use locally (trainium only). + add Add a model to the cache of your choice (trainium only). + list List models in a cache repo (trainium only). + synchronize Synchronize local compiler cache with the hub cache (inferentia only). optional arguments: -h, --help show this help message and exit ``` -### Add a model to the cache +### Add a model to the cache (trainium only) It is possible to add a model compilation files to a cache repo via the `optimum-cli neuron cache add` command: @@ -178,7 +179,7 @@ usage: optimum-cli neuron cache add [-h] -m MODEL --task TASK --train_batch_size When running this command a small training session will be run and the resulting compilation files will be pushed. -Make sure that the Neuron cache repo to use is set up locally, this can be done by running the `optimum-cli neuron cache set` command. +Make sure that the Neuron cache repo to use is set up locally, this can be done by running the `optimum-cli neuron cache set` command. You also need to make sure that you are logged in to the Hugging Face Hub and that you have the writing rights for the specified cache repo, this can be done via the `huggingface-cli login` command. diff --git a/optimum/commands/neuron/cache.py b/optimum/commands/neuron/cache.py index 745ce70bb..f5de193de 100644 --- a/optimum/commands/neuron/cache.py +++ b/optimum/commands/neuron/cache.py @@ -16,6 +16,7 @@ from typing import TYPE_CHECKING +from ...neuron.utils import synchronize_hub_cache from ...neuron.utils.cache_utils import ( CACHE_REPO_NAME, HF_HOME_CACHE_REPO_FILE, @@ -208,6 +209,15 @@ def run(self): print(f"\n*** Repo id: {self.args.name} ***\n\n{result}") +class SynchronizeRepoCommand(BaseOptimumCLICommand): + @staticmethod + def parse_args(parser: "ArgumentParser"): + parser.add_argument("--repo_id", type=str, default=None, help="The name of the repo to use as remote cache.") + + def run(self): + synchronize_hub_cache(self.args.repo_id) + + class CustomCacheRepoCommand(BaseOptimumCLICommand): SUBCOMMANDS = ( CommandInfo( @@ -230,4 +240,9 @@ class CustomCacheRepoCommand(BaseOptimumCLICommand): help="List models in a cache repo.", subcommand_class=ListRepoCommand, ), + CommandInfo( + name="synchronize", + help="Synchronize the neuronx compiler cache with a hub cache repo.", + subcommand_class=SynchronizeRepoCommand, + ), ) diff --git a/optimum/neuron/modeling_decoder.py b/optimum/neuron/modeling_decoder.py index d54274a39..2af03e456 100644 --- a/optimum/neuron/modeling_decoder.py +++ b/optimum/neuron/modeling_decoder.py @@ -29,7 +29,7 @@ from ..exporters.neuron.model_configs import * # noqa: F403 from ..exporters.tasks import TasksManager from ..modeling_base import OptimizedModel -from .utils import is_transformers_neuronx_available +from .utils import hub_neuronx_cache, is_transformers_neuronx_available from .utils.version_utils import check_compiler_compatibility, get_neuronxcc_version @@ -223,7 +223,8 @@ def _from_pretrained( # Compile the Neuron model (if present compiled artifacts will be reloaded instead of compiled) neuron_cc_flags = os.environ.get("NEURON_CC_FLAGS", "") os.environ["NEURON_CC_FLAGS"] = neuron_cc_flags + " --model-type=transformer" - neuronx_model.to_neuron() + with hub_neuronx_cache(): + neuronx_model.to_neuron() os.environ["NEURON_CC_FLAGS"] = neuron_cc_flags # Try to reload the generation config (if any) diff --git a/optimum/neuron/utils/__init__.py b/optimum/neuron/utils/__init__.py index c859ba71b..15a51ee0b 100644 --- a/optimum/neuron/utils/__init__.py +++ b/optimum/neuron/utils/__init__.py @@ -24,6 +24,7 @@ ENCODER_NAME, NEURON_FILE_NAME, ) +from .hub_neuronx_cache import hub_neuronx_cache, synchronize_hub_cache from .import_utils import ( is_accelerate_available, is_neuron_available, diff --git a/optimum/neuron/utils/hub_neuronx_cache.py b/optimum/neuron/utils/hub_neuronx_cache.py new file mode 100644 index 000000000..fc566880d --- /dev/null +++ b/optimum/neuron/utils/hub_neuronx_cache.py @@ -0,0 +1,227 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +import os +from contextlib import contextmanager +from typing import Optional + +from huggingface_hub import HfApi, get_token + +from ..version import __version__ +from .import_utils import is_neuronx_available +from .patching import patch_everywhere +from .require_utils import requires_torch_neuronx + + +if is_neuronx_available(): + from libneuronxla.neuron_cc_cache import ( + CacheUrl, + CompileCache, + CompileCacheFs, + CompileCacheS3, + create_compile_cache, + ) +else: + + class CacheUrl: + pass + + class CompileCache: + pass + + class CompileCacheFs: + pass + + class CompileCacheS3: + pass + + def create_compile_cache(): + pass + + +logger = logging.getLogger(__name__) + + +class CompileCacheHfProxy(CompileCache): + """A HuggingFace Hub proxy cache implementing the CompileCache API. + + This cache first looks for compilation artifacts into the default cache, then the + specified Hugging Face cache repository. + + Args: + repo_id (`str`): + The id of the Hugging Face cache repository, in the form 'org|user/name'. + default_cache (`CompileCache`): + The default neuron compiler cache (can be either a local file or S3 cache). + endpoint (`Optional[str]`, defaults to None): + The HuggingFaceHub endpoint: only required for unit tests to switch to the staging Hub. + token (`Optional[str]`, defaults to None): + The HuggingFace token to use to fetch/push artifacts. If not specified it will correspond + to the current user token. + """ + + cache_type = "hf" + + def __init__( + self, repo_id: str, default_cache: CompileCache, endpoint: Optional[str] = None, token: Optional[str] = None + ): + # Initialize the proxy cache as expected by the parent class + super().__init__(default_cache.cache_url) + self.cache_path = default_cache.cache_path + # Initialize specific members + self.default_cache = default_cache + self.api = HfApi(endpoint=endpoint, token=token, library_name="optimum-neuron", library_version=__version__) + # Check if the HF cache id is valid + try: + if not self.api.repo_exists(repo_id): + raise ValueError(f"The {repo_id} repository does not exist or you don't have access to it.") + except Exception as e: + raise ValueError(f"Error while accessing the {repo_id} cache repository: {e}") + self.repo_id = repo_id + + def get_cache_dir(self, model_hash: str, compile_flags_str: str): + return self.default_cache.get_cache_dir(model_hash, compile_flags_str) + + def clean(self): + self.default_cache.clean() + + def clear_locks(self): + # Clear locks in the default cache only, as the Hf already deals with concurrency + self.default_cache.clear_locks() + + def get_hlos(self, failed_neff_str: str = ""): + return self.default_cache.get_hlos(failed_neff_str) + + def hlo_acquire_lock(self, h: str): + # Put a lock in the default cache only, as the Hf already deals with concurrency + return self.default_cache.hlo_acquire_lock(h) + + def hlo_release_lock(self, h: str): + # Release lock in the default cache only, as the Hf already deals with concurrency + return self.default_cache.hlo_release_lock(h) + + def remove(self, path: str): + # Only remove in the default cache + return self.default_cache.remove(path) + + def _rel_path(self, path: str): + # Remove the default cache url from the path + if path.startswith(self.default_cache.cache_path): + return path[len(self.default_cache.cache_path) :].lstrip("/") + + def exists(self, path: str): + # Always prioritize the default cache + if self.default_cache.exists(path): + return True + rel_path = self._rel_path(path) + exists = self.api.file_exists(self.repo_id, rel_path) + if not exists: + logger.warning( + f"{rel_path} not found in {self.repo_id}: the corresponding graph will be recompiled." + " This may take up to one hour for large models." + ) + return exists + + def download_file(self, filename: str, dst_path: str): + # Always prioritize the default cache for faster retrieval + if self.default_cache.exists(filename): + self.default_cache.download_file(filename, dst_path) + else: + rel_filename = self._rel_path(filename) + local_path = self.api.hf_hub_download(self.repo_id, rel_filename) + os.symlink(local_path, dst_path) + logger.info(f"Fetched cached {rel_filename} from {self.repo_id}") + + def synchronize(self): + if isinstance(self.default_cache, CompileCacheS3): + raise ValueError("Hugging Face hub compiler cache synchronization is not supported for S3.") + logger.info(f"Synchronizing {self.repo_id} Hub cache with {self.default_cache.cache_path} local cache") + self.api.upload_folder( + repo_id=self.repo_id, + folder_path=self.default_cache.cache_path, + commit_message="Synchronizing local compiler cache.", + ignore_patterns="lock", + ) + logger.info("Synchronization complete.") + + def upload_file(self, cache_path: str, src_path: str): + # Only upload to the default cache: use synchronize to populate the Hub cache + self.default_cache.upload_file(cache_path, src_path) + + def upload_string_to_file(self, cache_path: str, data: str): + # Only upload to the default cache: use synchronize to populate the Hub cache + self.default_cache.upload_string_to_file(cache_path, data) + + def download_file_to_string(self, filename: str, limit: int = None): + # Always prioritize the default cache for faster retrieval + if self.default_cache.exists(filename): + return self.default_cache.download_file_to_string(filename, limit) + rel_filename = self._rel_path(filename) + local_path = self.api.hf_hub_download(self.repo_id, rel_filename) + with open(local_path, "rb") as f: + s = f.read().decode(errors="replace") + logger.info(f"Fetched cached {rel_filename} from {self.repo_id}") + return s + + +def get_hub_cache(): + HUB_CACHE = "aws-neuron/optimum-neuron-cache" + return os.getenv("CUSTOM_CACHE_REPO", HUB_CACHE) + + +def _create_hub_compile_cache_proxy( + cache_url: Optional[CacheUrl] = None, + cache_repo_id: Optional[str] = None, +): + if cache_url is None: + cache_url = CacheUrl.get_cache_url() + if cache_repo_id is None: + cache_repo_id = get_hub_cache() + default_cache = CompileCacheS3(cache_url) if cache_url.is_s3() else CompileCacheFs(cache_url) + # Reevaluate endpoint and token (needed for tests altering the environment) + endpoint = os.getenv("HF_ENDPOINT") + token = get_token() + return CompileCacheHfProxy(cache_repo_id, default_cache, endpoint=endpoint, token=token) + + +@requires_torch_neuronx +@contextmanager +def hub_neuronx_cache(): + """A context manager to trigger the Hugging Face Hub proxy compiler cache""" + + def hf_create_compile_cache(cache_url): + try: + return _create_hub_compile_cache_proxy(cache_url) + except Exception as e: + logger.warning(f"Bypassing Hub cache because of the following error: {e}") + return create_compile_cache(cache_url) + + try: + patch_everywhere("create_compile_cache", hf_create_compile_cache, "libneuronxla") + yield + finally: + patch_everywhere("create_compile_cache", create_compile_cache, "libneuronxla") + + +@requires_torch_neuronx +def synchronize_hub_cache(cache_repo_id: Optional[str] = None): + """Synchronize the neuronx compiler cache with the optimum-neuron hub cache. + + Args: + repo_id (`Optional[str]`, default to None): + The id of the HuggingFace cache repository, in the form 'org|user/name'. + """ + hub_cache_proxy = _create_hub_compile_cache_proxy(cache_repo_id=cache_repo_id) + hub_cache_proxy.synchronize() diff --git a/optimum/neuron/utils/require_utils.py b/optimum/neuron/utils/require_utils.py index 3ec901078..f9b6eb43c 100644 --- a/optimum/neuron/utils/require_utils.py +++ b/optimum/neuron/utils/require_utils.py @@ -19,7 +19,7 @@ from transformers.utils import is_safetensors_available -from . import is_neuronx_distributed_available, is_torch_neuronx_available, is_torch_xla_available +from .import_utils import is_neuronx_distributed_available, is_torch_neuronx_available, is_torch_xla_available _AVAILABILITIES: Dict[str, Callable[[], bool]] = { diff --git a/tests/cache/test_neuronx_cache.py b/tests/cache/test_neuronx_cache.py new file mode 100644 index 000000000..fd7d2050e --- /dev/null +++ b/tests/cache/test_neuronx_cache.py @@ -0,0 +1,163 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import glob +import os +import shutil +import socket +import subprocess +from tempfile import TemporaryDirectory + +import pytest +import torch +from huggingface_hub import HfApi +from transformers.testing_utils import ENDPOINT_STAGING + +from optimum.neuron import NeuronModelForCausalLM +from optimum.neuron.utils import synchronize_hub_cache +from optimum.neuron.utils.testing_utils import is_inferentia_test, requires_neuronx +from optimum.utils.testing_utils import TOKEN + + +@pytest.fixture +def cache_repos(): + # Setup: create temporary Hub repository and local cache directory + api = HfApi(endpoint=ENDPOINT_STAGING, token=TOKEN) + user = api.whoami()["name"] + hostname = socket.gethostname() + cache_repo_id = f"{user}/{hostname}-optimum-neuron-cache" + if api.repo_exists(cache_repo_id): + api.delete_repo(cache_repo_id) + cache_repo_id = api.create_repo(cache_repo_id, private=True).repo_id + cache_dir = TemporaryDirectory() + cache_path = cache_dir.name + # Modify environment to force neuronx cache to use temporary caches + previous_env = {} + env_vars = ["NEURON_COMPILE_CACHE_URL", "CUSTOM_CACHE_REPO", "HF_ENDPOINT", "HF_TOKEN"] + for var in env_vars: + previous_env[var] = os.environ.get(var) + os.environ["NEURON_COMPILE_CACHE_URL"] = cache_path + os.environ["CUSTOM_CACHE_REPO"] = cache_repo_id + os.environ["HF_ENDPOINT"] = ENDPOINT_STAGING + os.environ["HF_TOKEN"] = TOKEN + yield (cache_path, cache_repo_id) + # Teardown + api.delete_repo(cache_repo_id) + for var in env_vars: + if previous_env[var] is None: + os.environ.pop(var) + else: + os.environ[var] = previous_env[var] + + +def export_decoder_model(model_id): + batch_size = 2 + sequence_length = 512 + num_cores = 1 + auto_cast_type = "fp32" + return NeuronModelForCausalLM.from_pretrained( + model_id, + export=True, + batch_size=batch_size, + sequence_length=sequence_length, + num_cores=num_cores, + auto_cast_type=auto_cast_type, + ) + + +def check_decoder_generation(model): + batch_size = model.config.neuron["batch_size"] + input_ids = torch.ones((batch_size, 20), dtype=torch.int64) + with torch.inference_mode(): + sample_output = model.generate(input_ids) + assert sample_output.shape[0] == batch_size + + +def get_local_cached_files(cache_path): + return glob.glob(f"{cache_path}/**/*/*.*", recursive=True) + + +def assert_local_and_hub_cache_sync(cache_path, cache_repo_id): + api = HfApi(endpoint=ENDPOINT_STAGING, token=TOKEN) + remote_files = api.list_repo_files(cache_repo_id) + local_files = get_local_cached_files(cache_path) + for file in local_files: + assert os.path.isfile(file) + path_in_repo = file[len(cache_path) :].lstrip("/") + assert path_in_repo in remote_files + + +def local_cache_size(cache_path): + return len(get_local_cached_files(cache_path)) + + +@is_inferentia_test +@requires_neuronx +def test_decoder_cache(cache_repos): + cache_path, cache_repo_id = cache_repos + # Export the model a first time to populate the local cache + model = export_decoder_model("hf-internal-testing/tiny-random-gpt2") + check_decoder_generation(model) + # Synchronize the hub cache with the local cache + synchronize_hub_cache(cache_repo_id=cache_repo_id) + assert_local_and_hub_cache_sync(cache_path, cache_repo_id) + # Clear the local cache + for root, dirs, files in os.walk(cache_path): + for f in files: + os.unlink(os.path.join(root, f)) + for d in dirs: + shutil.rmtree(os.path.join(root, d)) + assert local_cache_size(cache_path) == 0 + # Export the model again: the compilation artifacts should be fetched from the Hub + model = export_decoder_model("hf-internal-testing/tiny-random-gpt2") + check_decoder_generation(model) + # Verify the local cache directory has not been populated + assert local_cache_size(cache_path) == 0 + + +@is_inferentia_test +@requires_neuronx +@pytest.mark.parametrize( + "var, value, match", + [ + ("CUSTOM_CACHE_REPO", "foo/bar", "The foo/bar repository does not exist"), + ("HF_ENDPOINT", "https://foo.bar.baz", "Name or service not known"), + ("HF_TOKEN", "foo", "repository does not exist or you don't have access to it."), + ], + ids=["invalid_repo", "invalid_endpoint", "invalid_token"], +) +def test_decoder_cache_unavailable(cache_repos, var, value, match): + # Modify the specified environment variable to trigger an error + os.environ[var] = value + # Just exporting the model will only emit a warning + export_decoder_model("hf-internal-testing/tiny-random-gpt2") + with pytest.raises(ValueError, match=match): + # Trying to synchronize will in the contrary raise an exception + synchronize_hub_cache() + # No need to restore environment as it is already done by the cache_repos fixture + + +@is_inferentia_test +@requires_neuronx +def test_optimum_neuron_cli_cache_synchronize(cache_repos): + cache_path, cache_repo_id = cache_repos + # Export a model to populate the local cache + export_decoder_model("hf-internal-testing/tiny-random-gpt2") + # Synchronize the hub cache with the local cache + command = "optimum-cli neuron cache synchronize".split() + p = subprocess.Popen(command, stdout=subprocess.PIPE) + stdout, _ = p.communicate() + stdout = stdout.decode("utf-8") + assert p.returncode == 0 + assert_local_and_hub_cache_sync(cache_path, cache_repo_id)