Integrate new cache system for training #472

Merged: 20 commits, Feb 16, 2024
Changes from 14 commits
1 change: 1 addition & 0 deletions MANIFEST.in
@@ -14,3 +14,4 @@

include README.md
include LICENSE
include optimum/neuron/utils/neuron_cc_wrapper
77 changes: 19 additions & 58 deletions optimum/commands/neuron/cache.py
@@ -21,8 +21,6 @@
CACHE_REPO_NAME,
HF_HOME_CACHE_REPO_FILE,
create_custom_cache_repo,
list_in_registry,
load_custom_cache_repo_name_from_hf_home,
set_custom_cache_repo_name_in_hf_home,
)
from ...neuron.utils.runner import ExampleRunner
@@ -163,52 +161,6 @@ def run(self):
)


class ListRepoCommand(BaseOptimumCLICommand):
@staticmethod
def parse_args(parser: "ArgumentParser"):
parser.add_argument(
"name",
type=str,
nargs="?",
default=None,
help="The name of the repo to list. Will use the locally saved cache repo if left unspecified.",
)
parser.add_argument(
"-m",
"--model",
type=str,
default=None,
help="The model name or path of the model to consider. If left unspecified, will list all available models.",
)
parser.add_argument(
"-v",
"--version",
type=str,
default=None,
help=(
"The version of the Neuron X Compiler to consider. Will list all available versions if left "
"unspecified."
),
)

def run(self):
if self.args.name is None:
custom_cache_repo_name = load_custom_cache_repo_name_from_hf_home()
if custom_cache_repo_name is None:
raise ValueError("No custom cache repo was set locally so you need to specify a cache repo name.")
self.args.name = custom_cache_repo_name

entries = list_in_registry(
self.args.name, model_name_or_path_or_hash=self.args.model, neuron_compiler_version=self.args.version
)
if not entries:
entries = ["Nothing was found."]
line = "\n" + "=" * 50 + "\n"
result = line.join(entries)

print(f"\n*** Repo id: {self.args.name} ***\n\n{result}")


class SynchronizeRepoCommand(BaseOptimumCLICommand):
@staticmethod
def parse_args(parser: "ArgumentParser"):
@@ -226,18 +178,32 @@ def parse_args(parser: "ArgumentParser"):
type=str,
help="The model_id to lookup cached versions for.",
)
parser.add_argument(
"--mode",
type=str,
choices=["training", "inference", "all"],
default="all",
help='The mode you wish to look up compilation files for. Can be either "training", "inference" or "all".',
)
parser.add_argument("--repo_id", type=str, default=None, help="The name of the repo to use as remote cache.")

def run(self):
entries = get_hub_cached_entries(self.args.model_id, cache_repo_id=self.args.repo_id)
def _list_entries(self, mode: str):
entries = get_hub_cached_entries(self.args.model_id, mode, cache_repo_id=self.args.repo_id)
n_entries = len(entries)
output = f"\n*** {n_entries} entrie(s) found in cache for {self.args.model_id} ***\n\n"
output = f"\n*** {n_entries} entry(ies) found in cache for {self.args.model_id} for {mode} ***\n\n"
for entry in entries:
for key, value in entry.items():
output += f"\n{key}: {value}"
output += "\n"
print(output)

def run(self):
if self.args.mode == "all":
self._list_entries("training")
self._list_entries("inference")
else:
self._list_entries(self.args.mode)


class CustomCacheRepoCommand(BaseOptimumCLICommand):
SUBCOMMANDS = (
@@ -256,19 +222,14 @@ class CustomCacheRepoCommand(BaseOptimumCLICommand):
help="Add a model to the cache of your choice (trainium only).",
subcommand_class=AddToCacheRepoCommand,
),
CommandInfo(
name="list",
help="List models in a cache repo (trainium only).",
subcommand_class=ListRepoCommand,
),
CommandInfo(
name="synchronize",
help="Synchronize the neuronx compiler cache with a hub cache repo (inferentia only).",
help="Synchronize the neuronx compiler cache with a hub cache repo.",
subcommand_class=SynchronizeRepoCommand,
),
CommandInfo(
name="lookup",
help="Lookup the neuronx compiler hub cache for the specified model id (inferentia only).",
help="Lookup the neuronx compiler hub cache for the specified model id.",
subcommand_class=LookupRepoCommand,
),
)
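
For reference, a minimal sketch of what the updated lookup flow does for --mode all, written against the signatures visible in this diff. The import path for get_hub_cached_entries is an assumption (the hunk calls it but does not show its import), and "gpt2" is only a placeholder model id.

# Sketch only: list hub cache entries for both modes, as "lookup --mode all" does above.
from optimum.neuron.utils.hub_neuronx_cache import get_hub_cached_entries  # assumed import path

model_id = "gpt2"  # placeholder model id
for mode in ("training", "inference"):
    # Passing cache_repo_id=None mirrors the default of the --repo_id option.
    entries = get_hub_cached_entries(model_id, mode, cache_repo_id=None)
    print(f"*** {len(entries)} cached entry(ies) for {model_id} ({mode}) ***")
    for entry in entries:
        for key, value in entry.items():
            print(f"{key}: {value}")
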
2 changes: 1 addition & 1 deletion optimum/neuron/modeling_decoder.py
@@ -151,7 +151,7 @@ def __init__(
cache_entry = None if checkpoint_id is None else ModelCacheEntry(checkpoint_id, config)

# Export the model using the Optimum Neuron Cache
with hub_neuronx_cache(entry=cache_entry):
with hub_neuronx_cache("inference", entry=cache_entry):
available_cores = get_available_cores()
if num_cores > available_cores:
raise ValueError(
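
A standalone sketch of the changed call above: hub_neuronx_cache now takes an explicit mode, so inference exports are registered separately from training artifacts. This is illustrative only and assumes a Neuron environment; the checkpoint id is a placeholder and the compilation body is elided.

# Sketch only: scope compilation work to the "inference" mode of the hub cache.
from transformers import AutoConfig
from optimum.neuron.utils.hub_neuronx_cache import ModelCacheEntry, hub_neuronx_cache

checkpoint_id = "gpt2"  # placeholder checkpoint id
config = AutoConfig.from_pretrained(checkpoint_id)
cache_entry = ModelCacheEntry(checkpoint_id, config)
with hub_neuronx_cache("inference", entry=cache_entry):
    pass  # neuronx compilations performed here are looked up in, and added to, the cache
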
164 changes: 100 additions & 64 deletions optimum/neuron/trainers.py
@@ -23,14 +23,13 @@
import sys
import time
import warnings
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
from accelerate import __version__ as accelerate_version
from packaging import version
from torch.utils.data import Dataset
from transformers import PreTrainedModel, Seq2SeqTrainer, Trainer, TrainingArguments
from transformers.debug_utils import DebugOption, DebugUnderflowOverflow
from transformers.integrations import hp_params
@@ -56,6 +55,7 @@
EvalLoopOutput,
EvalPrediction,
HPSearchBackend,
PredictionOutput,
TrainOutput,
denumpify_detensorize,
has_length,
@@ -68,13 +68,19 @@
from .accelerate import NeuronAccelerator, NeuronDistributedType
from .distributed import Parallelizer, ParallelizersManager
from .distributed.utils import make_optimizer_constructor_lazy
from .trainer_callback import NeuronCacheCallback
from .utils import (
Patcher,
is_torch_xla_available,
patch_within_function,
)
from .utils.cache_utils import get_neuron_cache_path, set_neuron_cache_path
from .utils.cache_utils import (
get_hf_hub_cache_repos,
get_model_name_or_path,
get_neuronxcc_version,
get_num_neuron_cores_used,
has_write_access_to_repo,
)
from .utils.hub_neuronx_cache import ModelCacheEntry, hub_neuronx_cache, patch_neuron_cc_wrapper, synchronize_hub_cache
from .utils.require_utils import requires_neuronx_distributed
from .utils.training_utils import (
TRANSFORMERS_MIN_VERSION_USE_ACCELERATE,
@@ -113,38 +119,11 @@
if KEEP_HF_HUB_PROGRESS_BARS is None:
os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"

# Used for torch.distributed.
_ORIGINAL_NEURON_CACHE_PATH: Optional[Path] = None
_TMP_NEURON_CACHE_DIR: Optional[TemporaryDirectory] = None
_TMP_NEURON_CACHE_PATH: Optional[Path] = None
_TCP_STORE_ADDRESS = "127.0.0.1"
_TCP_STORE_PORT = 5000


if os.environ.get("TORCHELASTIC_RUN_ID"):
import torch_xla.distributed.xla_backend as xbn

if not isinstance(torch.distributed.group.WORLD, xbn.ProcessGroupXla):
_ORIGINAL_NEURON_CACHE_PATH = get_neuron_cache_path()

# _ORIGINAL_NEURON_CACHE_PATH is `None` when the `--no-cache` flag is set.
if _ORIGINAL_NEURON_CACHE_PATH is not None:
if is_precompilation():
# During precompilation, we make sure to set the cache path to the defined compile cache path by the
# user. If nothing is specified, it is set to the default compile cache used by the Neuron compiler:
# /var/tmp/neuron-compile-cache
set_neuron_cache_path(_ORIGINAL_NEURON_CACHE_PATH)
else:
if os.environ["RANK"] == "0":
_TMP_NEURON_CACHE_DIR = NeuronCacheCallback.create_temporary_neuron_cache(get_neuron_cache_path())
store = torch.distributed.TCPStore(_TCP_STORE_ADDRESS, _TCP_STORE_PORT, is_master=True)
store.set("tmp_neuron_cache_path", _TMP_NEURON_CACHE_DIR.name)
_TMP_NEURON_CACHE_PATH = Path(_TMP_NEURON_CACHE_DIR.name)
else:
store = torch.distributed.TCPStore(_TCP_STORE_ADDRESS, _TCP_STORE_PORT, is_master=False)
_TMP_NEURON_CACHE_PATH = Path(store.get("tmp_neuron_cache_path").decode("utf-8"))
set_neuron_cache_path(_TMP_NEURON_CACHE_PATH)

torch.distributed.init_process_group(backend="xla")
if not isinstance(torch.distributed.group.WORLD, xbn.ProcessGroupXla):
raise AssertionError("Failed to initialize torch.distributed process group using XLA backend.")
@@ -194,24 +173,29 @@ def __init__(self, *args, **kwargs):
if self.args.local_rank <= 0:
logger.setLevel(logging.INFO)

push = self.args.local_rank <= 0 and not is_precompilation() and not self.args.skip_cache_push
fetch = self.args.local_rank <= 0 or self.args.mp_plugin.should_parallelize

callback = NeuronCacheCallback(
tmp_neuron_cache=_TMP_NEURON_CACHE_PATH,
original_neuron_cache_path=_ORIGINAL_NEURON_CACHE_PATH,
fetch=fetch,
push=push,
wait_for_everyone_on_fetch=True,
wait_for_everyone_on_push=True,
)
self.add_callback(callback)

# Make the model Neuron-compatible for generation.
patch_generation_mixin_to_neuron_generation_mixin(self.model)

set_neuron_cc_optlevel_for_model(self.model, optlevel=self.args.neuron_cc_optlevel)

# Model cache entry management.
model_name_or_path_for_cache_entry = get_model_name_or_path(self.model.config)
model_config_for_cache_entry = copy.deepcopy(self.model.config)
use_bf16 = os.environ.get("XLA_USE_BF16", False) or os.environ.get("XLA_DOWNCAST_BF16", False)
precision = "bfloat16" if use_bf16 else "float32"
neuron_config_for_cache_entry = {
"model_class": self.model.__class__.__name__,
"precision": precision,
"num_neuron_cores_per_node": get_num_neuron_cores_used(),
"compiler_version": get_neuronxcc_version(),
"tensor_parallel_size": self.args.tensor_parallel_size,
"pipeline_parallel_size": self.args.pipeline_parallel_size,
}
self.model_cache_entry: Optional[ModelCacheEntry] = None
if model_name_or_path_for_cache_entry is not None:
model_config_for_cache_entry.neuron = neuron_config_for_cache_entry
self.model_cache_entry = ModelCacheEntry(model_name_or_path_for_cache_entry, model_config_for_cache_entry)

@property
def mp_enabled(self):
return self.accelerator.distributed_type is NeuronDistributedType.MODEL_PARALLELISM
@@ -259,26 +243,19 @@ def create_accelerator_and_postprocess(self):
ds_plugin.deepspeed_config = ds_plugin.hf_ds_config.config
ds_plugin.hf_ds_config.trainer_config_process(self.args)

def synchronize_hub_cache(self):
repo_id = get_hf_hub_cache_repos()[0]
if xm.get_ordinal() == 0:
has_write_access = has_write_access_to_repo(repo_id)
if has_write_access:
synchronize_hub_cache(repo_id)
xm.rendezvous("Hub cache synchronization done")

def _wrap_model(self, model, training=True, dataloader=None):
return super()._wrap_model(
self.accelerator.patch_model_for_neuron(model), training=training, dataloader=dataloader
)

# TODO: make this cleaner.
def trigger_on_step_middle_for_neuron_cache_callback(self, model: "PreTrainedModel"):
for callback in self.callback_handler.callbacks:
if isinstance(callback, NeuronCacheCallback):
# kwargs might not have everything expected (like metrics) but all we need is here.
kwargs = {
"model": model,
"tokenizer": self.tokenizer,
"optimizer": self.optimizer,
"lr_scheduler": self.lr_scheduler,
"train_dataloader": self.callback_handler.train_dataloader,
"eval_dataloader": self.callback_handler.eval_dataloader,
}
callback.on_step_middle(self.args, self.state, self.control, **kwargs)

def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
if self.mp_enabled:
if self.train_dataset is None or not has_length(self.train_dataset):
@@ -312,9 +289,20 @@ def _prepare_input(self, data: Union[torch.Tensor, Any]) -> Union[torch.Tensor,
return data
return super()._prepare_input(data)

def _update_input_specs_in_model_cache_entry(self, input_specs: Dict[str, Any]):
if self.model_cache_entry is None:
return
self.model_cache_entry.config["neuron"]["training"] = self.model.training
self.model_cache_entry.config["neuron"]["input_specs"] = input_specs

def _prepare_inputs(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> Dict[str, Union[torch.Tensor, Any]]:
inputs = super()._prepare_inputs(inputs)
input_specs_for_cache_entry = {k: v.shape if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}
self._update_input_specs_in_model_cache_entry(input_specs_for_cache_entry)
return inputs

def compute_loss(self, model, inputs, return_outputs: bool = False):
self.state.last_inputs = inputs
self.trigger_on_step_middle_for_neuron_cache_callback(model)
from neuronx_distributed.pipeline import NxDPPModel

if isinstance(model, NxDPPModel):
@@ -356,7 +344,6 @@ def prediction_step(
from neuronx_distributed.pipeline import NxDPPModel

self.state.last_inputs = inputs
self.trigger_on_step_middle_for_neuron_cache_callback(model)

if isinstance(model, NxDPPModel):
if not prediction_loss_only:
@@ -500,10 +487,18 @@ def _save_xla(self, output_dir: Optional[str] = None):

def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False):
if not os.environ.get("NEURON_PARALLEL_COMPILE"): # Avoid unnecessary model saving during precompilation
if output_dir is None:
output_dir = self.args.output_dir
with patch_neuron_cc_wrapper():
if self.model_cache_entry is not None and "input_specs" not in self.model_cache_entry.config["neuron"]:
model_cache_entry = None
else:
model_cache_entry = self.model_cache_entry
with hub_neuronx_cache("training", entry=model_cache_entry):
if output_dir is None:
output_dir = self.args.output_dir

self._save_xla(output_dir)

self._save_xla(output_dir)
self.synchronize_hub_cache()

# Push to the Hub when `save_model` is called by the user.
if self.args.push_to_hub and not _internal_call:
@@ -1315,6 +1310,47 @@ def evaluation_loop(

return EvalLoopOutput(predictions=all_preds, label_ids=all_labels, metrics=metrics, num_samples=num_samples)

def train(
self,
resume_from_checkpoint: Optional[Union[str, bool]] = None,
trial=None, # No type-annotation for this one because it is related to the optuna package.
ignore_keys_for_eval: Optional[List[str]] = None,
**kwargs,
):
with patch_neuron_cc_wrapper():
with hub_neuronx_cache("training", entry=self.model_cache_entry):
result = super().train(
resume_from_checkpoint=resume_from_checkpoint,
trial=trial,
ignore_keys_for_eval=ignore_keys_for_eval,
**kwargs,
)
self.synchronize_hub_cache()
return result

def evaluate(
self,
eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
ignore_keys: Optional[List[str]] = None,
metric_key_prefix: str = "eval",
) -> Dict[str, float]:
with patch_neuron_cc_wrapper():
with hub_neuronx_cache("training", entry=self.model_cache_entry):
result = super().evaluate(
eval_dataset=eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix
)
self.synchronize_hub_cache()
return result

def predict(
self, test_dataset: Dataset, ignore_keys: Optional[List[str]] = None, metric_key_prefix: str = "test"
) -> PredictionOutput:
with patch_neuron_cc_wrapper():
with hub_neuronx_cache("training", entry=self.model_cache_entry):
result = super().predict(test_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)
self.synchronize_hub_cache()
return result
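
train, evaluate and predict above all wrap the parent implementation in the same caching pattern. Below is a minimal standalone sketch of that pattern using the names imported at the top of this file; the helper run_with_training_cache itself is hypothetical, and it omits the xm.get_ordinal()/xm.rendezvous coordination that the trainer's synchronize_hub_cache method performs across workers.

from optimum.neuron.utils.cache_utils import get_hf_hub_cache_repos, has_write_access_to_repo
from optimum.neuron.utils.hub_neuronx_cache import (
    hub_neuronx_cache,
    patch_neuron_cc_wrapper,
    synchronize_hub_cache,
)

def run_with_training_cache(fn, model_cache_entry=None):
    # Hypothetical helper: route neuron_cc calls through the patched wrapper,
    # scope compiled graphs to the "training" mode of the hub cache, then push
    # any new artifacts to the first configured cache repo if we can write to it.
    with patch_neuron_cc_wrapper():
        with hub_neuronx_cache("training", entry=model_cache_entry):
            result = fn()
    repo_id = get_hf_hub_cache_repos()[0]
    if has_write_access_to_repo(repo_id):
        synchronize_hub_cache(repo_id)
    return result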


class NeuronTrainer(AugmentTrainerForNeuronMixin, Trainer):
"""