From 11b27973f61ca899baa642742d57f10e6541ab47 Mon Sep 17 00:00:00 2001
From: Michael Benayoun
Date: Thu, 15 Feb 2024 15:35:07 +0100
Subject: [PATCH] neuron parallel compile

---
 optimum/neuron/trainers.py                | 74 ++++++++++++++++++-----
 optimum/neuron/utils/cache_utils.py       |  4 +-
 optimum/neuron/utils/hub_neuronx_cache.py |  3 +
 3 files changed, 64 insertions(+), 17 deletions(-)

diff --git a/optimum/neuron/trainers.py b/optimum/neuron/trainers.py
index a83e54602..894473126 100755
--- a/optimum/neuron/trainers.py
+++ b/optimum/neuron/trainers.py
@@ -73,8 +73,14 @@
     is_torch_xla_available,
     patch_within_function,
 )
-from .utils.cache_utils import get_hf_hub_cache_repos, has_write_access_to_repo
-from .utils.hub_neuronx_cache import patch_neuron_cc_wrapper, synchronize_hub_cache
+from .utils.cache_utils import (
+    get_hf_hub_cache_repos,
+    get_model_name_or_path,
+    get_neuronxcc_version,
+    get_num_neuron_cores_used,
+    has_write_access_to_repo,
+)
+from .utils.hub_neuronx_cache import ModelCacheEntry, hub_neuronx_cache, patch_neuron_cc_wrapper, synchronize_hub_cache
 from .utils.require_utils import requires_neuronx_distributed
 from .utils.training_utils import (
     TRANSFORMERS_MIN_VERSION_USE_ACCELERATE,
@@ -172,6 +178,24 @@ def __init__(self, *args, **kwargs):

         set_neuron_cc_optlevel_for_model(self.model, optlevel=self.args.neuron_cc_optlevel)

+        # Model cache entry management.
+        model_name_or_path_for_cache_entry = get_model_name_or_path(self.model.config)
+        model_config_for_cache_entry = copy.deepcopy(self.model.config)
+        use_bf16 = os.environ.get("XLA_USE_BF16", False) or os.environ.get("XLA_DOWNCAST_BF16", False)
+        precision = "bfloat16" if use_bf16 else "float32"
+        neuron_config_for_cache_entry = {
+            "model_class": self.model.__class__.__name__,
+            "precision": precision,
+            "num_neuron_cores_per_node": get_num_neuron_cores_used(),
+            "compiler_version": get_neuronxcc_version(),
+            "tensor_parallel_size": self.args.tensor_parallel_size,
+            "pipeline_parallel_size": self.args.pipeline_parallel_size,
+        }
+        self.model_cache_entry: Optional[ModelCacheEntry] = None
+        if model_name_or_path_for_cache_entry is not None:
+            model_config_for_cache_entry.neuron = neuron_config_for_cache_entry
+            self.model_cache_entry = ModelCacheEntry(model_name_or_path_for_cache_entry, model_config_for_cache_entry)
+
     @property
     def mp_enabled(self):
         return self.accelerator.distributed_type is NeuronDistributedType.MODEL_PARALLELISM
@@ -265,6 +289,18 @@ def _prepare_input(self, data: Union[torch.Tensor, Any]) -> Union[torch.Tensor,
             return data
         return super()._prepare_input(data)

+    def _update_input_specs_in_model_cache_entry(self, input_specs: Dict[str, Any]):
+        if self.model_cache_entry is None:
+            return
+        self.model_cache_entry.config["neuron"]["training"] = self.model.training
+        self.model_cache_entry.config["neuron"]["input_specs"] = input_specs
+
+    def _prepare_inputs(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> Dict[str, Union[torch.Tensor, Any]]:
+        inputs = super()._prepare_inputs(inputs)
+        input_specs_for_cache_entry = {k: v.shape if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}
+        self._update_input_specs_in_model_cache_entry(input_specs_for_cache_entry)
+        return inputs
+
     def compute_loss(self, model, inputs, return_outputs: bool = False):
         self.state.last_inputs = inputs
         from neuronx_distributed.pipeline import NxDPPModel
@@ -452,10 +488,15 @@ def _save_xla(self, output_dir: Optional[str] = None):

     def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False):
         if not os.environ.get("NEURON_PARALLEL_COMPILE"):  # Avoid unnecessary model saving during precompilation
             with patch_neuron_cc_wrapper():
-                if output_dir is None:
-                    output_dir = self.args.output_dir
+                if self.model_cache_entry is not None and "input_specs" not in self.model_cache_entry.config["neuron"]:
+                    model_cache_entry = None
+                else:
+                    model_cache_entry = self.model_cache_entry
+                with hub_neuronx_cache("training", entry=model_cache_entry):
+                    if output_dir is None:
+                        output_dir = self.args.output_dir

-                self._save_xla(output_dir)
+                    self._save_xla(output_dir)

             self.synchronize_hub_cache()
@@ -1277,12 +1318,13 @@ def train(
         **kwargs,
     ):
         with patch_neuron_cc_wrapper():
-            result = super().train(
-                resume_from_checkpoint=resume_from_checkpoint,
-                trial=trial,
-                ignore_keys_for_eval=ignore_keys_for_eval,
-                **kwargs,
-            )
+            with hub_neuronx_cache("training", entry=self.model_cache_entry):
+                result = super().train(
+                    resume_from_checkpoint=resume_from_checkpoint,
+                    trial=trial,
+                    ignore_keys_for_eval=ignore_keys_for_eval,
+                    **kwargs,
+                )
         self.synchronize_hub_cache()
         return result
@@ -1293,9 +1335,10 @@ def evaluate(
         metric_key_prefix: str = "eval",
     ) -> Dict[str, float]:
         with patch_neuron_cc_wrapper():
-            result = super().evaluate(
-                eval_dataset=eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix
-            )
+            with hub_neuronx_cache("training", entry=self.model_cache_entry):
+                result = super().evaluate(
+                    eval_dataset=eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix
+                )
         self.synchronize_hub_cache()
         return result
@@ -1303,7 +1346,8 @@ def predict(
         self, test_dataset: Dataset, ignore_keys: Optional[List[str]] = None, metric_key_prefix: str = "test"
     ) -> PredictionOutput:
         with patch_neuron_cc_wrapper():
-            result = super().predict(test_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)
+            with hub_neuronx_cache("training", entry=self.model_cache_entry):
+                result = super().predict(test_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)
         self.synchronize_hub_cache()
         return result
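Note: the trainers.py changes above all apply one wrapping pattern — each compile-triggering entry point (train, evaluate, predict, save_model) runs under patch_neuron_cc_wrapper() and a hub_neuronx_cache("training", entry=...) scope, then calls synchronize_hub_cache() to push newly compiled artifacts. A minimal sketch of that pattern, assuming optimum.neuron.utils.hub_neuronx_cache is importable as laid out in this patch; run_with_hub_cache is a hypothetical helper for illustration, not part of the diff:

    from optimum.neuron.utils.hub_neuronx_cache import (
        hub_neuronx_cache,
        patch_neuron_cc_wrapper,
        synchronize_hub_cache,
    )

    def run_with_hub_cache(fn, model_cache_entry=None, **kwargs):
        # Route neuron_cc invocations through the patched wrapper so compiled
        # graphs are looked up in, and recorded to, the compilation cache.
        with patch_neuron_cc_wrapper():
            # Register the model-level cache entry for the duration of the call.
            with hub_neuronx_cache("training", entry=model_cache_entry):
                result = fn(**kwargs)
        # Push any newly compiled artifacts to the Hub; this becomes a no-op
        # under neuron_parallel_compile thanks to the guard added in
        # hub_neuronx_cache.py below.
        synchronize_hub_cache()
        return result

save_model additionally drops the entry (passes entry=None) when no input_specs were recorded yet, since an entry without input specs would be incomplete.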
diff --git a/optimum/neuron/utils/cache_utils.py b/optimum/neuron/utils/cache_utils.py
index 7e78877f4..75dfbddb8 100644
--- a/optimum/neuron/utils/cache_utils.py
+++ b/optimum/neuron/utils/cache_utils.py
@@ -308,7 +308,7 @@ def remove_ip_adress_from_path(path: Path) -> Path:
     return Path().joinpath(*(re.sub(_IP_PATTERN, "", part) for part in path.parts))


-def _get_model_name_or_path(config: "PretrainedConfig") -> Optional[str]:
+def get_model_name_or_path(config: "PretrainedConfig") -> Optional[str]:
     attribute_names_to_try = ["_model_name_or_path", "_name_or_path"]
     model_name_or_path = None
     for name in attribute_names_to_try:
@@ -664,7 +664,7 @@ def __post_init__(self, model: "PreTrainedModel"):

         # Checking whether the model is private or not.
         is_private = None
-        model_name_or_path = _get_model_name_or_path(model.config)
+        model_name_or_path = get_model_name_or_path(model.config)
         if model_name_or_path is None:
             is_private = True
         elif Path(model_name_or_path).exists():
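Note: the only change here is making _get_model_name_or_path public so the trainer can build its ModelCacheEntry from it. A hedged usage sketch — "gpt2" is just an example checkpoint, and the comment reflects the attribute lookup visible in the hunk above:

    from transformers import AutoConfig
    from optimum.neuron.utils.cache_utils import get_model_name_or_path

    config = AutoConfig.from_pretrained("gpt2")
    # Tries "_model_name_or_path" then "_name_or_path" on the config and
    # returns the first value found, or None when neither attribute is set.
    name = get_model_name_or_path(config)  # -> "gpt2", via config._name_or_path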
diff --git a/optimum/neuron/utils/hub_neuronx_cache.py b/optimum/neuron/utils/hub_neuronx_cache.py
index 213afe903..468b94544 100644
--- a/optimum/neuron/utils/hub_neuronx_cache.py
+++ b/optimum/neuron/utils/hub_neuronx_cache.py
@@ -346,6 +346,9 @@ def synchronize_hub_cache(cache_repo_id: Optional[str] = None):
         repo_id (`Optional[str]`, default to None):
             The id of the HuggingFace cache repository, in the form 'org|user/name'.
     """
+    # Not pushing anything if neuron parallel compile.
+    if os.environ.get("NEURON_PARALLEL_COMPILE") == "1":
+        return
     hub_cache_proxy = _create_hub_compile_cache_proxy(cache_repo_id=cache_repo_id)
     hub_cache_proxy.synchronize()
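Note: AWS's neuron_parallel_compile tool runs the training script in graph-extraction mode and exposes NEURON_PARALLEL_COMPILE=1 to the processes it launches, so with this guard a precompilation pass never pushes half-built artifacts to the Hub cache repo. An illustrative check, assuming the import path from this patch:

    import os
    from optimum.neuron.utils.hub_neuronx_cache import synchronize_hub_cache

    # Simulate running under neuron_parallel_compile.
    os.environ["NEURON_PARALLEL_COMPILE"] = "1"
    synchronize_hub_cache()  # returns immediately; no Hub repo is touched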