diff --git a/optimum/commands/neuron/cache.py b/optimum/commands/neuron/cache.py index c8decf47f..29a945a86 100644 --- a/optimum/commands/neuron/cache.py +++ b/optimum/commands/neuron/cache.py @@ -226,18 +226,32 @@ def parse_args(parser: "ArgumentParser"): type=str, help="The model_id to lookup cached versions for.", ) + parser.add_argument( + "--mode", + type=str, + choices=["training", "inference", "all"], + default="all", + help='The mode to look up compilation files for. Can be either "training", "inference" or "all".', + ) parser.add_argument("--repo_id", type=str, default=None, help="The name of the repo to use as remote cache.") - def run(self): - entries = get_hub_cached_entries(self.args.model_id, cache_repo_id=self.args.repo_id) + def _list_entries(self, mode: str): + entries = get_hub_cached_entries(self.args.model_id, mode, cache_repo_id=self.args.repo_id) n_entries = len(entries) - output = f"\n*** {n_entries} entrie(s) found in cache for {self.args.model_id} ***\n\n" + output = f"\n*** {n_entries} entry(ies) found in cache for {self.args.model_id} for {mode} ***\n\n" for entry in entries: for key, value in entry.items(): output += f"\n{key}: {value}" output += "\n" print(output) + def run(self): + if self.args.mode == "all": + self._list_entries("training") + self._list_entries("inference") + else: + self._list_entries(self.args.mode) + class CustomCacheRepoCommand(BaseOptimumCLICommand): SUBCOMMANDS = ( diff --git a/optimum/neuron/modeling_decoder.py b/optimum/neuron/modeling_decoder.py index 501f5f321..fdf6fbaa7 100644 --- a/optimum/neuron/modeling_decoder.py +++ b/optimum/neuron/modeling_decoder.py @@ -151,7 +151,7 @@ def __init__( cache_entry = None if checkpoint_id is None else ModelCacheEntry(checkpoint_id, config) # Export the model using the Optimum Neuron Cache - with hub_neuronx_cache(entry=cache_entry): + with hub_neuronx_cache("inference", entry=cache_entry): available_cores = get_available_cores() if num_cores > available_cores: raise ValueError( diff --git a/optimum/neuron/trainers.py b/optimum/neuron/trainers.py index e99fbe309..a83e54602 100755 --- a/optimum/neuron/trainers.py +++ b/optimum/neuron/trainers.py @@ -68,14 +68,13 @@ from .accelerate import NeuronAccelerator, NeuronDistributedType from .distributed import Parallelizer, ParallelizersManager from .distributed.utils import make_optimizer_constructor_lazy -from .trainer_callback import NeuronCacheCallback from .utils import ( Patcher, is_torch_xla_available, patch_within_function, ) -from .utils.cache_utils import get_hf_hub_cache_repos -from .utils.hub_neuronx_cache import hub_neuronx_cache, patch_neuron_cc_wrapper, synchronize_hub_cache +from .utils.cache_utils import get_hf_hub_cache_repos, has_write_access_to_repo +from .utils.hub_neuronx_cache import patch_neuron_cc_wrapper, synchronize_hub_cache from .utils.require_utils import requires_neuronx_distributed from .utils.training_utils import ( TRANSFORMERS_MIN_VERSION_USE_ACCELERATE, @@ -220,26 +219,19 @@ def create_accelerator_and_postprocess(self): ds_plugin.deepspeed_config = ds_plugin.hf_ds_config.config ds_plugin.hf_ds_config.trainer_config_process(self.args) + def synchronize_hub_cache(self): + repo_id = get_hf_hub_cache_repos()[0] + if xm.get_ordinal() == 0: + has_write_access = has_write_access_to_repo(repo_id) + if has_write_access: + synchronize_hub_cache(repo_id) + xm.rendezvous("Hub cache synchronization done") + def _wrap_model(self, model, training=True, dataloader=None): return super()._wrap_model(
self.accelerator.patch_model_for_neuron(model), training=training, dataloader=dataloader ) - # TODO: make this cleaner. - def trigger_on_step_middle_for_neuron_cache_callback(self, model: "PreTrainedModel"): - for callback in self.callback_handler.callbacks: - if isinstance(callback, NeuronCacheCallback): - # kwargs might not have everything expected (like metrics) but all we need is here. - kwargs = { - "model": model, - "tokenizer": self.tokenizer, - "optimizer": self.optimizer, - "lr_scheduler": self.lr_scheduler, - "train_dataloader": self.callback_handler.train_dataloader, - "eval_dataloader": self.callback_handler.eval_dataloader, - } - callback.on_step_middle(self.args, self.state, self.control, **kwargs) - def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]: if self.mp_enabled: if self.train_dataset is None or not has_length(self.train_dataset): @@ -275,7 +267,6 @@ def _prepare_input(self, data: Union[torch.Tensor, Any]) -> Union[torch.Tensor, def compute_loss(self, model, inputs, return_outputs: bool = False): self.state.last_inputs = inputs - self.trigger_on_step_middle_for_neuron_cache_callback(model) from neuronx_distributed.pipeline import NxDPPModel if isinstance(model, NxDPPModel): @@ -317,7 +308,6 @@ def prediction_step( from neuronx_distributed.pipeline import NxDPPModel self.state.last_inputs = inputs - self.trigger_on_step_middle_for_neuron_cache_callback(model) if isinstance(model, NxDPPModel): if not prediction_loss_only: @@ -462,15 +452,12 @@ def _save_xla(self, output_dir: Optional[str] = None): def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False): if not os.environ.get("NEURON_PARALLEL_COMPILE"): # Avoid unnecessary model saving during precompilation with patch_neuron_cc_wrapper(): - with hub_neuronx_cache(cache_repo_id=get_hf_hub_cache_repos()[0]): - if output_dir is None: - output_dir = self.args.output_dir + if output_dir is None: + output_dir = self.args.output_dir - self._save_xla(output_dir) + self._save_xla(output_dir) - if xm.get_ordinal() == 0: - synchronize_hub_cache(get_hf_hub_cache_repos()[0]) - xm.rendezvous("Hub cache synchronization done") + self.synchronize_hub_cache() # Push to the Hub when `save_model` is called by the user. 
if self.args.push_to_hub and not _internal_call: @@ -1290,16 +1277,13 @@ def train( **kwargs, ): with patch_neuron_cc_wrapper(): - with hub_neuronx_cache(cache_repo_id=get_hf_hub_cache_repos()[0]): - result = super().train( - resume_from_checkpoint=resume_from_checkpoint, - trial=trial, - ignore_keys_for_eval=ignore_keys_for_eval, - **kwargs, - ) - if xm.get_ordinal() == 0: - synchronize_hub_cache(get_hf_hub_cache_repos()[0]) - xm.rendezvous("Hub cache synchronization done") + result = super().train( + resume_from_checkpoint=resume_from_checkpoint, + trial=trial, + ignore_keys_for_eval=ignore_keys_for_eval, + **kwargs, + ) + self.synchronize_hub_cache() return result def evaluate( @@ -1309,24 +1293,18 @@ def evaluate( metric_key_prefix: str = "eval", ) -> Dict[str, float]: with patch_neuron_cc_wrapper(): - with hub_neuronx_cache(cache_repo_id=get_hf_hub_cache_repos()[0]): - result = super().evaluate( - eval_dataset=eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix - ) - if xm.get_ordinal() == 0: - synchronize_hub_cache(get_hf_hub_cache_repos()[0]) - xm.rendezvous("Hub cache synchronization done") + result = super().evaluate( + eval_dataset=eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix + ) + self.synchronize_hub_cache() return result def predict( self, test_dataset: Dataset, ignore_keys: Optional[List[str]] = None, metric_key_prefix: str = "test" ) -> PredictionOutput: with patch_neuron_cc_wrapper(): - with hub_neuronx_cache(cache_repo_id=get_hf_hub_cache_repos()[0]): - result = super().predict(test_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix) - if xm.get_ordinal() == 0: - synchronize_hub_cache(get_hf_hub_cache_repos()[0]) - xm.rendezvous("Hub cache synchronization done") + result = super().predict(test_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix) + self.synchronize_hub_cache() return result diff --git a/optimum/neuron/utils/cache_utils.py b/optimum/neuron/utils/cache_utils.py index d68aa4642..7e78877f4 100644 --- a/optimum/neuron/utils/cache_utils.py +++ b/optimum/neuron/utils/cache_utils.py @@ -184,11 +184,15 @@ def has_write_access_to_repo(repo_id: str) -> bool: def get_hf_hub_cache_repos(): + # Default hub repos. hf_hub_repos = HF_HUB_CACHE_REPOS + + # Locally saved hub repo. saved_custom_cache_repo = load_custom_cache_repo_name_from_hf_home() if saved_custom_cache_repo is not None and saved_custom_cache_repo not in hf_hub_repos: hf_hub_repos = [saved_custom_cache_repo] + hf_hub_repos + # Hub repo set via the environment variable CUSTOM_CACHE_REPO. custom_cache_repo = os.environ.get("CUSTOM_CACHE_REPO", None) if custom_cache_repo is not None and custom_cache_repo not in hf_hub_repos: hf_hub_repos = [custom_cache_repo] + hf_hub_repos @@ -202,13 +206,7 @@ "set -n [name]`.", ) - # TODO: this is a quick fix. - # Cache utils should not be aware of the multiprocessing side of things. - # The issue here is that `has_write_access_to_repo` actually pushes stuff to the HF Hub. - # Pushing stuff to the HF Hub should be limited to the `push_to_cache_on_hub` function, - # making it easier for higher-level abstractions using the cache utils to reason on which - # parts should only run on the master process and which parts should run on everyone.
- if is_main_worker() and hf_hub_repos and not has_write_access_to_repo(hf_hub_repos[0]): + if hf_hub_repos and not has_write_access_to_repo(hf_hub_repos[0]): warn_once( logger, diff --git a/optimum/neuron/utils/hub_neuronx_cache.py b/optimum/neuron/utils/hub_neuronx_cache.py index c50bd0183..213afe903 100644 --- a/optimum/neuron/utils/hub_neuronx_cache.py +++ b/optimum/neuron/utils/hub_neuronx_cache.py @@ -18,9 +18,10 @@ import os import shutil from contextlib import contextmanager +from enum import Enum from pathlib import Path from tempfile import TemporaryDirectory -from typing import Optional +from typing import Literal, Optional, Union from huggingface_hub import HfApi, get_token from transformers import AutoConfig, PretrainedConfig @@ -232,11 +233,30 @@ def hash(self): REGISTRY_FOLDER = f"0_REGISTRY/{__version__}" +TRAINING_REGISTRY_FOLDER = f"0_TRAINING_REGISTRY/{__version__}" + + +class Mode(str, Enum): + TRAINING = "training" + INFERENCE = "inference" + + +def get_registry_folder_for_mode(mode: Union[Literal["training"], Literal["inference"], Mode]) -> str: + if isinstance(mode, str) and not isinstance(mode, Mode): + mode = Mode(mode) + if mode is Mode.TRAINING: + return TRAINING_REGISTRY_FOLDER + else: + return REGISTRY_FOLDER @requires_torch_neuronx @contextmanager -def hub_neuronx_cache(entry: Optional[ModelCacheEntry] = None, cache_repo_id: Optional[str] = None): +def hub_neuronx_cache( + mode: Union[Literal["training"], Literal["inference"], Mode], + entry: Optional[ModelCacheEntry] = None, + cache_repo_id: Optional[str] = None, +): """A context manager to activate the Hugging Face Hub proxy compiler cache. Args: @@ -246,6 +266,9 @@ def hub_neuronx_cache(entry: Optional[ModelCacheEntry] = None, cache_repo_id: Op cache_repo_id (`Optional[str]`, defaults to `None`): The id of the cache repo to use to fetch the precompiled files.
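+ mode (`Union[Literal["training"], Literal["inference"], Mode]`): + The mode in which the cache is used, either "training" or "inference".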
""" + registry_folder = get_registry_folder_for_mode(mode) def hf_create_compile_cache(cache_url): try: @@ -264,7 +285,7 @@ def hf_create_compile_cache(cache_url): logger.warning("Skipping cache metadata update on S3 cache.") else: # Create cache entry in local cache: it can be later synchronized with the hub cache - registry_path = default_cache.get_cache_dir_with_cache_key(REGISTRY_FOLDER) + registry_path = default_cache.get_cache_dir_with_cache_key(registry_folder) model_type = entry.config["model_type"] entry_path = f"{registry_path}/{model_type}/{entry.model_id}" config_path = f"{entry_path}/{entry.hash}.json" @@ -329,7 +350,9 @@ def synchronize_hub_cache(cache_repo_id: Optional[str] = None): hub_cache_proxy.synchronize() -def get_hub_cached_entries(model_id: str, cache_repo_id: Optional[str] = None): +def get_hub_cached_entries( + model_id: str, mode: Union[Literal["training"], Literal["inference"], Mode], cache_repo_id: Optional[str] = None +): if cache_repo_id is None: cache_repo_id = get_hub_cache() # Allocate a Hub API with refreshed information (required for tests altering the env) @@ -341,7 +364,8 @@ def get_hub_cached_entries(model_id: str, cache_repo_id: Optional[str] = None): target_entry = ModelCacheEntry(model_id, (AutoConfig.from_pretrained(model_id))) # Extract model type: it will be used as primary key for lookup model_type = target_entry.config["model_type"] - registry_pattern = REGISTRY_FOLDER + "/" + model_type + registry_folder = get_registry_folder_for_mode(mode) + registry_pattern = registry_folder + "/" + model_type model_files = [path for path in repo_files if registry_pattern in path] model_entries = [] with TemporaryDirectory() as tmpdir: diff --git a/optimum/neuron/utils/neuron_cc_wrapper b/optimum/neuron/utils/neuron_cc_wrapper index b59389a9f..ecd05bedc 100755 --- a/optimum/neuron/utils/neuron_cc_wrapper +++ b/optimum/neuron/utils/neuron_cc_wrapper @@ -1,4 +1,4 @@ -##!/usr/bin/env python3 +#!/usr/bin/env python3 # -*- coding: utf-8 -*- import re import sys diff --git a/optimum/neuron/utils/optimum_neuron_cc_wrapper.py b/optimum/neuron/utils/optimum_neuron_cc_wrapper.py index d0acc6ef3..5685396e1 100644 --- a/optimum/neuron/utils/optimum_neuron_cc_wrapper.py +++ b/optimum/neuron/utils/optimum_neuron_cc_wrapper.py @@ -20,7 +20,7 @@ def main(): - with hub_neuronx_cache(cache_repo_id=get_hf_hub_cache_repos()[0]): + with hub_neuronx_cache("training", cache_repo_id=get_hf_hub_cache_repos()[0]): return neuron_cc_wrapper_main() diff --git a/tests/cache/test_neuronx_cache.py b/tests/cache/test_neuronx_cache.py index 0d3a5656a..17334c6b2 100644 --- a/tests/cache/test_neuronx_cache.py +++ b/tests/cache/test_neuronx_cache.py @@ -123,7 +123,7 @@ def test_decoder_cache(cache_repos): synchronize_hub_cache(cache_repo_id=cache_repo_id) assert_local_and_hub_cache_sync(cache_path, cache_repo_id) # Verify we are able to fetch the cached entry for the model - model_entries = get_hub_cached_entries(model_id, cache_repo_id=cache_repo_id) + model_entries = get_hub_cached_entries("inference", model_id, cache_repo_id=cache_repo_id) assert len(model_entries) == 1 assert model_entries[0] == model.config.neuron # Clear the local cache diff --git a/text-generation-inference/server/text_generation_server/model.py b/text-generation-inference/server/text_generation_server/model.py index c759fa1b4..5882c3284 100644 --- a/text-generation-inference/server/text_generation_server/model.py +++ b/text-generation-inference/server/text_generation_server/model.py @@ -34,7 +34,7 @@ def 
get_export_kwargs_from_env(): def is_cached(model_id, neuron_config): # Look for cached entries for the specified model in_cache = False - entries = get_hub_cached_entries(model_id) + entries = get_hub_cached_entries(model_id, "inference") # Look for compatible entries for entry in entries: compatible = True
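For reference, a minimal usage sketch of the reworked cache API after this diff. The repo id `my-org/neuron-cache` and model id `gpt2` are placeholders, and the snippet assumes it runs on a Neuron host where compilation actually happens inside the context manager.

```python
from optimum.neuron.utils.hub_neuronx_cache import (
    get_hub_cached_entries,
    hub_neuronx_cache,
    synchronize_hub_cache,
)

# Artifacts compiled inside the context manager are registered under the
# registry folder matching the requested mode:
#   "training"  -> 0_TRAINING_REGISTRY/<version>
#   "inference" -> 0_REGISTRY/<version>
with hub_neuronx_cache("training", cache_repo_id="my-org/neuron-cache"):
    ...  # run a compilation step here

# Push the local cache (including its registry entries) to the Hub repo.
synchronize_hub_cache(cache_repo_id="my-org/neuron-cache")

# Later, look up the cached entries for a model: the lookup only scans the
# registry folder of the requested mode.
entries = get_hub_cached_entries("gpt2", "training", cache_repo_id="my-org/neuron-cache")
print(f"{len(entries)} training entry(ies) found")
```

The same switch is exposed on the CLI through the new `--mode` flag of the cache lookup command (e.g. `optimum-cli neuron cache lookup gpt2 --mode training`, assuming the existing lookup entry point), which calls `_list_entries` once per requested mode, or twice with the default `--mode all`.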