Commit fcc7180: Adapt commands

michaelbenayoun committed Feb 15, 2024
1 parent 230458f commit fcc7180
Showing 9 changed files with 82 additions and 69 deletions.
20 changes: 17 additions & 3 deletions optimum/commands/neuron/cache.py
@@ -226,18 +226,32 @@ def parse_args(parser: "ArgumentParser"):
type=str,
help="The model_id to lookup cached versions for.",
)
parser.add_argument(
"--mode",
type=str,
choices=["training", "inference", "all"],
default="all",
help='The mode you wish to lookup compilation files for. Can be either "training", "inference" or "all"',
)
parser.add_argument("--repo_id", type=str, default=None, help="The name of the repo to use as remote cache.")

def run(self):
entries = get_hub_cached_entries(self.args.model_id, cache_repo_id=self.args.repo_id)
def _list_entries(self, mode: str):
entries = get_hub_cached_entries(self.args.model_id, mode, cache_repo_id=self.args.repo_id)
n_entries = len(entries)
output = f"\n*** {n_entries} entrie(s) found in cache for {self.args.model_id} ***\n\n"
output = f"\n*** {n_entries} entry(ies) found in cache for {self.args.model_id} for mode {mode} ***\n\n"
for entry in entries:
for key, value in entry.items():
output += f"\n{key}: {value}"
output += "\n"
print(output)

def run(self):
if self.args.mode == "all":
self._list_entries("training")
self._list_entries("inference")
else:
self._list_entries(self.args.mode)


class CustomCacheRepoCommand(BaseOptimumCLICommand):
SUBCOMMANDS = (
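For illustration only (not part of this commit), a minimal sketch of querying the cache per mode through the helper used above; the model id and cache repo below are placeholders:

from optimum.neuron.utils.hub_neuronx_cache import get_hub_cached_entries

# Placeholder model id and cache repo; substitute your own.
model_id = "gpt2"
for mode in ("training", "inference"):
    entries = get_hub_cached_entries(model_id, mode, cache_repo_id="my-org/my-neuron-cache")
    print(f"{len(entries)} cached entries for {model_id} ({mode})")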
2 changes: 1 addition & 1 deletion optimum/neuron/modeling_decoder.py
@@ -151,7 +151,7 @@ def __init__(
cache_entry = None if checkpoint_id is None else ModelCacheEntry(checkpoint_id, config)

# Export the model using the Optimum Neuron Cache
with hub_neuronx_cache(entry=cache_entry):
with hub_neuronx_cache("inference", entry=cache_entry):
available_cores = get_available_cores()
if num_cores > available_cores:
raise ValueError(
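As a sketch of the call above (assuming a host with the Neuron SDK installed; "gpt2" is a placeholder checkpoint), an inference export now opens its cache entry under the inference mode:

from transformers import AutoConfig

from optimum.neuron.utils.hub_neuronx_cache import ModelCacheEntry, hub_neuronx_cache

config = AutoConfig.from_pretrained("gpt2")  # placeholder checkpoint
entry = ModelCacheEntry("gpt2", config)
with hub_neuronx_cache("inference", entry=entry):
    pass  # the neuronx export/compilation would run here and any new graphs would be recorded under the inference registry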
76 changes: 27 additions & 49 deletions optimum/neuron/trainers.py
@@ -68,14 +68,13 @@
from .accelerate import NeuronAccelerator, NeuronDistributedType
from .distributed import Parallelizer, ParallelizersManager
from .distributed.utils import make_optimizer_constructor_lazy
from .trainer_callback import NeuronCacheCallback
from .utils import (
Patcher,
is_torch_xla_available,
patch_within_function,
)
from .utils.cache_utils import get_hf_hub_cache_repos
from .utils.hub_neuronx_cache import hub_neuronx_cache, patch_neuron_cc_wrapper, synchronize_hub_cache
from .utils.cache_utils import get_hf_hub_cache_repos, has_write_access_to_repo
from .utils.hub_neuronx_cache import patch_neuron_cc_wrapper, synchronize_hub_cache
from .utils.require_utils import requires_neuronx_distributed
from .utils.training_utils import (
TRANSFORMERS_MIN_VERSION_USE_ACCELERATE,
@@ -220,26 +219,19 @@ def create_accelerator_and_postprocess(self):
ds_plugin.deepspeed_config = ds_plugin.hf_ds_config.config
ds_plugin.hf_ds_config.trainer_config_process(self.args)

def synchronize_hub_cache(self):
repo_id = get_hf_hub_cache_repos()[0]
if xm.get_ordinal() == 0:
has_write_access = has_write_access_to_repo(repo_id)
if has_write_access:
synchronize_hub_cache(repo_id)
xm.rendezvous("Hub cache synchronization done")

def _wrap_model(self, model, training=True, dataloader=None):
return super()._wrap_model(
self.accelerator.patch_model_for_neuron(model), training=training, dataloader=dataloader
)

# TODO: make this cleaner.
def trigger_on_step_middle_for_neuron_cache_callback(self, model: "PreTrainedModel"):
for callback in self.callback_handler.callbacks:
if isinstance(callback, NeuronCacheCallback):
# kwargs might not have everything expected (like metrics) but all we need is here.
kwargs = {
"model": model,
"tokenizer": self.tokenizer,
"optimizer": self.optimizer,
"lr_scheduler": self.lr_scheduler,
"train_dataloader": self.callback_handler.train_dataloader,
"eval_dataloader": self.callback_handler.eval_dataloader,
}
callback.on_step_middle(self.args, self.state, self.control, **kwargs)

def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
if self.mp_enabled:
if self.train_dataset is None or not has_length(self.train_dataset):
@@ -275,7 +267,6 @@ def _prepare_input(self, data: Union[torch.Tensor, Any]) -> Union[torch.Tensor,

def compute_loss(self, model, inputs, return_outputs: bool = False):
self.state.last_inputs = inputs
self.trigger_on_step_middle_for_neuron_cache_callback(model)
from neuronx_distributed.pipeline import NxDPPModel

if isinstance(model, NxDPPModel):
@@ -317,7 +308,6 @@ def prediction_step(
from neuronx_distributed.pipeline import NxDPPModel

self.state.last_inputs = inputs
self.trigger_on_step_middle_for_neuron_cache_callback(model)

if isinstance(model, NxDPPModel):
if not prediction_loss_only:
@@ -462,15 +452,12 @@ def _save_xla(self, output_dir: Optional[str] = None):
def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False):
if not os.environ.get("NEURON_PARALLEL_COMPILE"): # Avoid unnecessary model saving during precompilation
with patch_neuron_cc_wrapper():
with hub_neuronx_cache(cache_repo_id=get_hf_hub_cache_repos()[0]):
if output_dir is None:
output_dir = self.args.output_dir
if output_dir is None:
output_dir = self.args.output_dir

self._save_xla(output_dir)
self._save_xla(output_dir)

if xm.get_ordinal() == 0:
synchronize_hub_cache(get_hf_hub_cache_repos()[0])
xm.rendezvous("Hub cache synchronization done")
self.synchronize_hub_cache()

# Push to the Hub when `save_model` is called by the user.
if self.args.push_to_hub and not _internal_call:
@@ -1290,16 +1277,13 @@ def train(
**kwargs,
):
with patch_neuron_cc_wrapper():
with hub_neuronx_cache(cache_repo_id=get_hf_hub_cache_repos()[0]):
result = super().train(
resume_from_checkpoint=resume_from_checkpoint,
trial=trial,
ignore_keys_for_eval=ignore_keys_for_eval,
**kwargs,
)
if xm.get_ordinal() == 0:
synchronize_hub_cache(get_hf_hub_cache_repos()[0])
xm.rendezvous("Hub cache synchronization done")
result = super().train(
resume_from_checkpoint=resume_from_checkpoint,
trial=trial,
ignore_keys_for_eval=ignore_keys_for_eval,
**kwargs,
)
self.synchronize_hub_cache()
return result

def evaluate(
@@ -1309,24 +1293,18 @@
metric_key_prefix: str = "eval",
) -> Dict[str, float]:
with patch_neuron_cc_wrapper():
with hub_neuronx_cache(cache_repo_id=get_hf_hub_cache_repos()[0]):
result = super().evaluate(
eval_dataset=eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix
)
if xm.get_ordinal() == 0:
synchronize_hub_cache(get_hf_hub_cache_repos()[0])
xm.rendezvous("Hub cache synchronization done")
result = super().evaluate(
eval_dataset=eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix
)
self.synchronize_hub_cache()
return result

def predict(
self, test_dataset: Dataset, ignore_keys: Optional[List[str]] = None, metric_key_prefix: str = "test"
) -> PredictionOutput:
with patch_neuron_cc_wrapper():
with hub_neuronx_cache(cache_repo_id=get_hf_hub_cache_repos()[0]):
result = super().predict(test_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)
if xm.get_ordinal() == 0:
synchronize_hub_cache(get_hf_hub_cache_repos()[0])
xm.rendezvous("Hub cache synchronization done")
result = super().predict(test_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)
self.synchronize_hub_cache()
return result


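A rough sketch (not part of the diff, assuming a Neuron/XLA environment) of the flow the trainer methods above now share: compile under the patched wrapper so new graphs are recorded locally, then push them to the Hub cache once per host:

import torch_xla.core.xla_model as xm

from optimum.neuron.utils.cache_utils import get_hf_hub_cache_repos, has_write_access_to_repo
from optimum.neuron.utils.hub_neuronx_cache import patch_neuron_cc_wrapper, synchronize_hub_cache

with patch_neuron_cc_wrapper():
    # Run training/evaluation steps here: any new Neuron graph is compiled through
    # optimum_neuron_cc_wrapper and recorded in the training registry of the local cache.
    pass

repo_id = get_hf_hub_cache_repos()[0]
if xm.get_ordinal() == 0 and has_write_access_to_repo(repo_id):
    synchronize_hub_cache(repo_id)
xm.rendezvous("Hub cache synchronization done")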
11 changes: 4 additions & 7 deletions optimum/neuron/utils/cache_utils.py
@@ -184,11 +184,15 @@ def has_write_access_to_repo(repo_id: str) -> bool:


def get_hf_hub_cache_repos():
# Default hub repos.
hf_hub_repos = HF_HUB_CACHE_REPOS

# Locally saved hub repo.
saved_custom_cache_repo = load_custom_cache_repo_name_from_hf_home()
if saved_custom_cache_repo is not None and saved_custom_cache_repo not in hf_hub_repos:
hf_hub_repos = [saved_custom_cache_repo] + hf_hub_repos

# Hub repo set via the environment variable CUSTOM_CACHE_REPO.
custom_cache_repo = os.environ.get("CUSTOM_CACHE_REPO", None)
if custom_cache_repo is not None and custom_cache_repo not in hf_hub_repos:
hf_hub_repos = [custom_cache_repo] + hf_hub_repos
@@ -202,13 +206,6 @@ def get_hf_hub_cache_repos():
"set -n [name]`.",
)

# TODO: this is a quick fix.
# Cache utils should not be aware of the multiprocessing side of things.
# The issue here is that `has_write_access_to_repo` actually pushes stuff to the HF Hub.
# Pushing stuff to the HF Hub should be limited to the `push_to_cache_on_hub` function,
# making it easier for higher-level abstractions using the cache utils to reason on which
# parts should only run on the master process and which parts should run on everyone.

if is_main_worker() and hf_hub_repos and not has_write_access_to_repo(hf_hub_repos[0]):
warn_once(
logger,
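For illustration (the repo name is a placeholder), the precedence implemented above puts the CUSTOM_CACHE_REPO environment variable ahead of the locally saved custom repo and the default repos:

import os

from optimum.neuron.utils.cache_utils import get_hf_hub_cache_repos

os.environ["CUSTOM_CACHE_REPO"] = "my-org/my-neuron-cache"  # placeholder repo name
print(get_hf_hub_cache_repos()[0])  # -> "my-org/my-neuron-cache"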
34 changes: 29 additions & 5 deletions optimum/neuron/utils/hub_neuronx_cache.py
@@ -18,9 +18,10 @@
import os
import shutil
from contextlib import contextmanager
from enum import Enum
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Optional
from typing import Literal, Optional, Union

from huggingface_hub import HfApi, get_token
from transformers import AutoConfig, PretrainedConfig
@@ -232,11 +233,30 @@ def hash(self):


REGISTRY_FOLDER = f"0_REGISTRY/{__version__}"
TRAINING_REGISTRY_FOLDER = f"0_TRAINING_REGISTRY/{__version__}"


class Mode(str, Enum):
TRAINING = "training"
INFERENCE = "inference"


def get_registry_folder_for_mode(mode: Union[Literal["training"], Literal["inference"], Mode]) -> str:
if isinstance(mode, str) and not isinstance(mode, Mode):
mode = Mode(mode)
if mode is Mode.TRAINING:
return TRAINING_REGISTRY_FOLDER
else:
return REGISTRY_FOLDER


@requires_torch_neuronx
@contextmanager
def hub_neuronx_cache(entry: Optional[ModelCacheEntry] = None, cache_repo_id: Optional[str] = None):
def hub_neuronx_cache(
mode: Union[Literal["training"], Literal["inference"], Mode],
entry: Optional[ModelCacheEntry] = None,
cache_repo_id: Optional[str] = None,
):
"""A context manager to activate the Hugging Face Hub proxy compiler cache.
Args:
@@ -246,6 +266,7 @@ def hub_neuronx_cache(entry: Optional[ModelCacheEntry] = None, cache_repo_id: Op
cache_repo_id (`Optional[str]`, defaults to `None`):
The id of the cache repo to use to fetch the precompiled files.
"""
registry_folder = get_registry_folder_for_mode(mode)

def hf_create_compile_cache(cache_url):
try:
@@ -264,7 +285,7 @@ def hf_create_compile_cache(cache_url):
logger.warning("Skipping cache metadata update on S3 cache.")
else:
# Create cache entry in local cache: it can be later synchronized with the hub cache
registry_path = default_cache.get_cache_dir_with_cache_key(REGISTRY_FOLDER)
registry_path = default_cache.get_cache_dir_with_cache_key(registry_folder)
model_type = entry.config["model_type"]
entry_path = f"{registry_path}/{model_type}/{entry.model_id}"
config_path = f"{entry_path}/{entry.hash}.json"
@@ -329,7 +350,9 @@ def synchronize_hub_cache(cache_repo_id: Optional[str] = None):
hub_cache_proxy.synchronize()


def get_hub_cached_entries(model_id: str, cache_repo_id: Optional[str] = None):
def get_hub_cached_entries(
model_id: str, mode: Union[Literal["training"], Literal["inference"], Mode], cache_repo_id: Optional[str] = None
):
if cache_repo_id is None:
cache_repo_id = get_hub_cache()
# Allocate a Hub API with refreshed information (required for tests altering the env)
@@ -341,7 +364,8 @@ def get_hub_cached_entries(model_id: str, cache_repo_id: Optional[str] = None):
target_entry = ModelCacheEntry(model_id, (AutoConfig.from_pretrained(model_id)))
# Extract model type: it will be used as primary key for lookup
model_type = target_entry.config["model_type"]
registry_pattern = REGISTRY_FOLDER + "/" + model_type
registry_folder = get_registry_folder_for_mode(mode)
registry_pattern = registry_folder + "/" + model_type
model_files = [path for path in repo_files if registry_pattern in path]
model_entries = []
with TemporaryDirectory() as tmpdir:
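A small sketch of the new mode handling (assuming the Neuron SDK is installed so the module imports cleanly): plain strings and Mode members resolve to the same per-mode registry folder:

from optimum.neuron.utils.hub_neuronx_cache import Mode, get_registry_folder_for_mode

# Strings are normalized to the Mode enum, so both spellings pick the same folder.
assert get_registry_folder_for_mode("training") == get_registry_folder_for_mode(Mode.TRAINING)
print(get_registry_folder_for_mode("inference"))  # e.g. "0_REGISTRY/<optimum-neuron version>"
print(get_registry_folder_for_mode("training"))   # e.g. "0_TRAINING_REGISTRY/<optimum-neuron version>"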
2 changes: 1 addition & 1 deletion optimum/neuron/utils/neuron_cc_wrapper
@@ -1,4 +1,4 @@
##!/usr/bin/env python3
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import re
import sys
2 changes: 1 addition & 1 deletion optimum/neuron/utils/optimum_neuron_cc_wrapper.py
@@ -20,7 +20,7 @@


def main():
with hub_neuronx_cache(cache_repo_id=get_hf_hub_cache_repos()[0]):
with hub_neuronx_cache("training", cache_repo_id=get_hf_hub_cache_repos()[0]):
return neuron_cc_wrapper_main()


2 changes: 1 addition & 1 deletion tests/cache/test_neuronx_cache.py
@@ -123,7 +123,7 @@ def test_decoder_cache(cache_repos):
synchronize_hub_cache(cache_repo_id=cache_repo_id)
assert_local_and_hub_cache_sync(cache_path, cache_repo_id)
# Verify we are able to fetch the cached entry for the model
model_entries = get_hub_cached_entries(model_id, cache_repo_id=cache_repo_id)
model_entries = get_hub_cached_entries(model_id, "inference", cache_repo_id=cache_repo_id)
assert len(model_entries) == 1
assert model_entries[0] == model.config.neuron
# Clear the local cache
2 changes: 1 addition & 1 deletion
@@ -34,7 +34,7 @@ def get_export_kwargs_from_env():
def is_cached(model_id, neuron_config):
# Look for cached entries for the specified model
in_cache = False
entries = get_hub_cached_entries(model_id)
entries = get_hub_cached_entries(model_id, "inference")
# Look for compatible entries
for entry in entries:
compatible = True
