
Commit aafeecc

Cleanup

michaelbenayoun committed Feb 9, 2024
1 parent 6687070 commit aafeecc
Showing 2 changed files with 20 additions and 6 deletions.
optimum/neuron/trainers.py (9 additions, 3 deletions)
@@ -461,10 +461,16 @@ def _save_xla(self, output_dir: Optional[str] = None):
 
     def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False):
         if not os.environ.get("NEURON_PARALLEL_COMPILE"):  # Avoid unnecessary model saving during precompilation
-            if output_dir is None:
-                output_dir = self.args.output_dir
+            with patch_neuron_cc_wrapper():
+                with hub_neuronx_cache(cache_repo_id=get_hf_hub_cache_repos()[0]):
+                    if output_dir is None:
+                        output_dir = self.args.output_dir
 
-            self._save_xla(output_dir)
+                    self._save_xla(output_dir)
+
+            if xm.get_ordinal() == 0:
+                synchronize_hub_cache(get_hf_hub_cache_repos()[0])
+            xm.rendezvous("Hub cache synchronization done")
 
         # Push to the Hub when `save_model` is called by the user.
         if self.args.push_to_hub and not _internal_call:
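The trainers.py change moves the checkpoint save inside the cache-patching context managers and then pushes the compilation cache to the Hub from a single worker. Below is a minimal sketch of the rank-0-plus-rendezvous idiom this relies on, assuming a torch_xla multi-worker run; upload_cache is a hypothetical stand-in for synchronize_hub_cache:

import torch_xla.core.xla_model as xm


def upload_cache():
    # Hypothetical stand-in for synchronize_hub_cache: a side effect that
    # must happen exactly once per job, not once per worker.
    print("uploading compilation cache to the Hub")


def save_and_sync(save_fn):
    # Every worker runs the save (save_fn may coordinate internally), but
    # only the master worker uploads, so the Hub repo is written once.
    save_fn()
    if xm.get_ordinal() == 0:
        upload_cache()
    # All workers block here until rank 0 finishes, so none of them can
    # race ahead while the upload is still in flight.
    xm.rendezvous("Hub cache synchronization done")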
optimum/neuron/utils/hub_neuronx_cache.py (11 additions, 3 deletions)
@@ -28,7 +28,7 @@
 from ..version import __version__
 from .import_utils import is_neuronx_available
 from .patching import patch_everywhere
-from .require_utils import requires_torch_neuronx
+from .require_utils import requires_torch_neuronx, requires_torch_xla
 
 
 if is_neuronx_available():
@@ -277,13 +277,17 @@ def hf_create_compile_cache(cache_url):
     patch_everywhere("create_compile_cache", create_compile_cache, "libneuronxla")
 
 
 @requires_torch_neuronx
+@requires_torch_xla
 @contextmanager
 def patch_neuron_cc_wrapper():
     """
     Patches the `neuron_cc_wrapper` file to force it use our own version of it which essentially makes sure that it
     uses our caching system.
     """
+
+    import torch_xla.core.xla_model as xm
+
     def patch(restore: bool = False):
         path = os.environ["PATH"]
         main_dir = Path(path.split(":")[0])
@@ -301,10 +305,14 @@ def patch(restore: bool = False):
         shutil.copy(src, dst)
 
     try:
-        patch()
+        if xm.get_ordinal() == 0:
+            patch()
+        xm.rendezvous("Patch neuron_cc_wrapper")
         yield
     finally:
-        patch(restore=True)
+        if xm.get_ordinal() == 0:
+            patch(restore=True)
+        xm.rendezvous("Restore neuron_cc_wrapper")
 
 
 @requires_torch_neuronx
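In hub_neuronx_cache.py the same idiom is applied to a filesystem mutation: previously every worker copied the neuron_cc_wrapper shim into place, while after this change only rank 0 writes it, with a rendezvous on entry and exit so the other workers wait for the copy before proceeding. A minimal sketch of that single-writer context manager, assuming torch_xla; apply and restore are hypothetical callables standing in for the file copies done by patch():

from contextlib import contextmanager

import torch_xla.core.xla_model as xm


@contextmanager
def single_writer(apply, restore, tag):
    # Only the master worker mutates shared on-disk state; every worker
    # synchronizes at both boundaries so they all see a consistent view.
    try:
        if xm.get_ordinal() == 0:
            apply()
        xm.rendezvous(f"{tag} applied")
        yield
    finally:
        if xm.get_ordinal() == 0:
            restore()
        xm.rendezvous(f"{tag} restored")

Inside the with block every worker sees the patched wrapper, but only one worker ever wrote it.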
