clear memory after offload (#2994)
SunMarc authored Aug 9, 2024
1 parent 79ca85c · commit 12a5bef
Showing 2 changed files with 11 additions and 2 deletions.
src/accelerate/hooks.py: 2 additions & 0 deletions
@@ -26,6 +26,7 @@
     send_to_device,
     set_module_tensor_to_device,
 )
+from .utils.memory import clear_device_cache
 from .utils.modeling import get_non_persistent_buffers
 from .utils.other import recursive_getattr

@@ -695,6 +696,7 @@ def init_hook(self, module):
     def pre_forward(self, module, *args, **kwargs):
         if self.prev_module_hook is not None:
             self.prev_module_hook.offload()
+            clear_device_cache()
         module.to(self.execution_device)
         return send_to_device(args, self.execution_device), send_to_device(kwargs, self.execution_device)

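For context, a minimal sketch of how this hook is typically exercised through accelerate's cpu_offload_with_hook helper (the two Linear layers, batch size, and "cuda" device below are illustrative, not part of the commit). With this change, the device cache is cleared right after the previous hook offloads its module and before the next module is moved onto the device:

import torch
from accelerate import cpu_offload_with_hook

model_a = torch.nn.Linear(1024, 1024)
model_b = torch.nn.Linear(1024, 1024)

# Chain the hooks so that running model_b first offloads model_a back to the CPU.
model_a, hook_a = cpu_offload_with_hook(model_a, execution_device="cuda")
model_b, hook_b = cpu_offload_with_hook(model_b, execution_device="cuda", prev_module_hook=hook_a)

x = torch.randn(8, 1024)
x = model_a(x)  # pre_forward moves model_a (and the input) to the GPU
x = model_b(x)  # hook_a.offload() runs, then clear_device_cache(), then model_b moves to the GPU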
src/accelerate/utils/memory.py: 9 additions & 2 deletions
Expand Up @@ -23,7 +23,14 @@

import torch

from .imports import is_mlu_available, is_mps_available, is_musa_available, is_npu_available, is_xpu_available
from .imports import (
is_cuda_available,
is_mlu_available,
is_mps_available,
is_musa_available,
is_npu_available,
is_xpu_available,
)


def clear_device_cache(garbage_collection=False):
@@ -44,7 +51,7 @@ def clear_device_cache(garbage_collection=False):
         torch.npu.empty_cache()
     elif is_mps_available(min_version="2.0"):
         torch.mps.empty_cache()
-    else:
+    elif is_cuda_available():
         torch.cuda.empty_cache()


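A minimal usage sketch of the updated helper (the call sites below are illustrative): with the new elif, torch.cuda.empty_cache() is only reached when CUDA is actually available, so the function is safe to call on machines without a supported accelerator:

from accelerate.utils.memory import clear_device_cache

clear_device_cache()                          # empty the cache of whichever backend is detected, if any
clear_device_cache(garbage_collection=True)   # also run gc.collect() before emptying the cache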
