neuron parallel compile
michaelbenayoun committed Feb 15, 2024
1 parent 17c6f38 commit 11b2797
Showing 3 changed files with 64 additions and 17 deletions.
optimum/neuron/trainers.py (74 changes: 59 additions & 15 deletions)
@@ -73,8 +73,14 @@
     is_torch_xla_available,
     patch_within_function,
 )
-from .utils.cache_utils import get_hf_hub_cache_repos, has_write_access_to_repo
-from .utils.hub_neuronx_cache import patch_neuron_cc_wrapper, synchronize_hub_cache
+from .utils.cache_utils import (
+    get_hf_hub_cache_repos,
+    get_model_name_or_path,
+    get_neuronxcc_version,
+    get_num_neuron_cores_used,
+    has_write_access_to_repo,
+)
+from .utils.hub_neuronx_cache import ModelCacheEntry, hub_neuronx_cache, patch_neuron_cc_wrapper, synchronize_hub_cache
 from .utils.require_utils import requires_neuronx_distributed
 from .utils.training_utils import (
     TRANSFORMERS_MIN_VERSION_USE_ACCELERATE,
@@ -172,6 +178,24 @@ def __init__(self, *args, **kwargs):
 
         set_neuron_cc_optlevel_for_model(self.model, optlevel=self.args.neuron_cc_optlevel)
 
+        # Model cache entry management.
+        model_name_or_path_for_cache_entry = get_model_name_or_path(self.model.config)
+        model_config_for_cache_entry = copy.deepcopy(self.model.config)
+        use_bf16 = os.environ.get("XLA_USE_BF16", False) or os.environ.get("XLA_DOWNCAST_BF16", False)
+        precision = "bfloat16" if use_bf16 else "float32"
+        neuron_config_for_cache_entry = {
+            "model_class": self.model.__class__.__name__,
+            "precision": precision,
+            "num_neuron_cores_per_node": get_num_neuron_cores_used(),
+            "compiler_version": get_neuronxcc_version(),
+            "tensor_parallel_size": self.args.tensor_parallel_size,
+            "pipeline_parallel_size": self.args.pipeline_parallel_size,
+        }
+        self.model_cache_entry: Optional[ModelCacheEntry] = None
+        if model_name_or_path_for_cache_entry is not None:
+            model_config_for_cache_entry.neuron = neuron_config_for_cache_entry
+            self.model_cache_entry = ModelCacheEntry(model_name_or_path_for_cache_entry, model_config_for_cache_entry)
+
     @property
     def mp_enabled(self):
         return self.accelerator.distributed_type is NeuronDistributedType.MODEL_PARALLELISM
@@ -265,6 +289,18 @@ def _prepare_input(self, data: Union[torch.Tensor, Any]) -> Union[torch.Tensor, Any]:
             return data
         return super()._prepare_input(data)
 
+    def _update_input_specs_in_model_cache_entry(self, input_specs: Dict[str, Any]):
+        if self.model_cache_entry is None:
+            return
+        self.model_cache_entry.config["neuron"]["training"] = self.model.training
+        self.model_cache_entry.config["neuron"]["input_specs"] = input_specs
+
+    def _prepare_inputs(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> Dict[str, Union[torch.Tensor, Any]]:
+        inputs = super()._prepare_inputs(inputs)
+        input_specs_for_cache_entry = {k: v.shape if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}
+        self._update_input_specs_in_model_cache_entry(input_specs_for_cache_entry)
+        return inputs
+
     def compute_loss(self, model, inputs, return_outputs: bool = False):
         self.state.last_inputs = inputs
         from neuronx_distributed.pipeline import NxDPPModel
@@ -452,10 +488,15 @@ def _save_xla(self, output_dir: Optional[str] = None):
     def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False):
         if not os.environ.get("NEURON_PARALLEL_COMPILE"):  # Avoid unnecessary model saving during precompilation
             with patch_neuron_cc_wrapper():
-                if output_dir is None:
-                    output_dir = self.args.output_dir
+                if self.model_cache_entry is not None and "input_specs" not in self.model_cache_entry.config["neuron"]:
+                    model_cache_entry = None
+                else:
+                    model_cache_entry = self.model_cache_entry
+                with hub_neuronx_cache("training", entry=model_cache_entry):
+                    if output_dir is None:
+                        output_dir = self.args.output_dir
 
-                self._save_xla(output_dir)
+                    self._save_xla(output_dir)
 
             self.synchronize_hub_cache()
 
@@ -1277,12 +1318,13 @@ def train(
         **kwargs,
     ):
         with patch_neuron_cc_wrapper():
-            result = super().train(
-                resume_from_checkpoint=resume_from_checkpoint,
-                trial=trial,
-                ignore_keys_for_eval=ignore_keys_for_eval,
-                **kwargs,
-            )
+            with hub_neuronx_cache("training", entry=self.model_cache_entry):
+                result = super().train(
+                    resume_from_checkpoint=resume_from_checkpoint,
+                    trial=trial,
+                    ignore_keys_for_eval=ignore_keys_for_eval,
+                    **kwargs,
+                )
         self.synchronize_hub_cache()
         return result
 
@@ -1293,17 +1335,19 @@ def evaluate(
         metric_key_prefix: str = "eval",
     ) -> Dict[str, float]:
         with patch_neuron_cc_wrapper():
-            result = super().evaluate(
-                eval_dataset=eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix
-            )
+            with hub_neuronx_cache("training", entry=self.model_cache_entry):
+                result = super().evaluate(
+                    eval_dataset=eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix
+                )
         self.synchronize_hub_cache()
         return result
 
     def predict(
         self, test_dataset: Dataset, ignore_keys: Optional[List[str]] = None, metric_key_prefix: str = "test"
     ) -> PredictionOutput:
         with patch_neuron_cc_wrapper():
-            result = super().predict(test_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)
+            with hub_neuronx_cache("training", entry=self.model_cache_entry):
+                result = super().predict(test_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)
         self.synchronize_hub_cache()
         return result
 
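Note: taken together, the trainers.py changes build a ModelCacheEntry from the model config at trainer construction time and wrap every compilation-triggering call (train, evaluate, predict, save_model) in hub_neuronx_cache("training", entry=...), followed by synchronize_hub_cache(). The sketch below illustrates that same pattern outside of the trainer; it assumes a Neuron environment with an optimum-neuron build containing this commit, and the checkpoint id and the contents of the "neuron" dict are illustrative placeholders rather than part of the diff.

import copy
import os

from transformers import AutoModelForSequenceClassification

from optimum.neuron.utils.cache_utils import get_model_name_or_path
from optimum.neuron.utils.hub_neuronx_cache import ModelCacheEntry, hub_neuronx_cache, synchronize_hub_cache

# Placeholder checkpoint; any Hub model id with a loadable config would do.
model = AutoModelForSequenceClassification.from_pretrained("prajjwal1/bert-tiny")

# Mirror what NeuronTrainer.__init__ does in this commit: copy the config and
# attach a "neuron" dict describing the run before building the cache entry.
config = copy.deepcopy(model.config)
config.neuron = {
    "model_class": model.__class__.__name__,
    "precision": "bfloat16" if os.environ.get("XLA_USE_BF16") else "float32",
}
name_or_path = get_model_name_or_path(model.config)
entry = ModelCacheEntry(name_or_path, config) if name_or_path is not None else None

# Compilation-triggering work runs inside the cache context; the local compile
# cache is pushed afterwards (a no-op when NEURON_PARALLEL_COMPILE=1).
with hub_neuronx_cache("training", entry=entry):
    pass  # training / evaluation / prediction would go here
synchronize_hub_cache()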
optimum/neuron/utils/cache_utils.py (4 changes: 2 additions & 2 deletions)
@@ -308,7 +308,7 @@ def remove_ip_adress_from_path(path: Path) -> Path:
     return Path().joinpath(*(re.sub(_IP_PATTERN, "", part) for part in path.parts))
 
 
-def _get_model_name_or_path(config: "PretrainedConfig") -> Optional[str]:
+def get_model_name_or_path(config: "PretrainedConfig") -> Optional[str]:
     attribute_names_to_try = ["_model_name_or_path", "_name_or_path"]
     model_name_or_path = None
     for name in attribute_names_to_try:
@@ -664,7 +664,7 @@ def __post_init__(self, model: "PreTrainedModel"):
 
         # Checking whether the model is private or not.
         is_private = None
-        model_name_or_path = _get_model_name_or_path(model.config)
+        model_name_or_path = get_model_name_or_path(model.config)
         if model_name_or_path is None:
             is_private = True
         elif Path(model_name_or_path).exists():
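Note: making get_model_name_or_path public (it was previously the private _get_model_name_or_path) is what lets trainers.py import it when building the cache entry. A minimal sketch of what the helper returns, assuming a config loaded with from_pretrained, which records the checkpoint id under _name_or_path:

from transformers import AutoConfig

from optimum.neuron.utils.cache_utils import get_model_name_or_path

# "bert-base-uncased" is only an example checkpoint id.
config = AutoConfig.from_pretrained("bert-base-uncased")
print(get_model_name_or_path(config))  # expected: "bert-base-uncased", read from config._name_or_path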
optimum/neuron/utils/hub_neuronx_cache.py (3 changes: 3 additions & 0 deletions)
@@ -346,6 +346,9 @@ def synchronize_hub_cache(cache_repo_id: Optional[str] = None):
         repo_id (`Optional[str]`, default to None):
             The id of the HuggingFace cache repository, in the form 'org|user/name'.
     """
+    # Not pushing anything if neuron parallel compile.
+    if os.environ.get("NEURON_PARALLEL_COMPILE") == "1":
+        return
     hub_cache_proxy = _create_hub_compile_cache_proxy(cache_repo_id=cache_repo_id)
     hub_cache_proxy.synchronize()
 
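Note: this early return mirrors the NEURON_PARALLEL_COMPILE guard already present in save_model above: when a script runs under neuron_parallel_compile, which sets NEURON_PARALLEL_COMPILE in the environment, only compilation graphs are being collected, so nothing should be pushed to the Hub cache. A small sketch of the resulting behavior, assuming optimum-neuron with this commit and its Neuron dependencies are installed; the repository id is a placeholder:

import os

from optimum.neuron.utils.hub_neuronx_cache import synchronize_hub_cache

# Simulate a precompilation run: with the variable set, synchronize_hub_cache()
# returns before building the compile-cache proxy, so nothing is pushed.
os.environ["NEURON_PARALLEL_COMPILE"] = "1"
synchronize_hub_cache()

# In a regular training run the variable is unset and the call synchronizes the
# local compile cache with the Hub cache repository.
os.environ.pop("NEURON_PARALLEL_COMPILE", None)
# synchronize_hub_cache(cache_repo_id="my-org/neuron-cache")  # placeholder repo id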
