From 12a5befdd661c47e36a94297d9a93caee4ff0c21 Mon Sep 17 00:00:00 2001 From: Marc Sun <57196510+SunMarc@users.noreply.github.com> Date: Fri, 9 Aug 2024 09:36:33 +0200 Subject: [PATCH] clear memory after offload (#2994) --- src/accelerate/hooks.py | 2 ++ src/accelerate/utils/memory.py | 11 +++++++++-- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/src/accelerate/hooks.py b/src/accelerate/hooks.py index 9a5cd912d8c..caadd8729ee 100644 --- a/src/accelerate/hooks.py +++ b/src/accelerate/hooks.py @@ -26,6 +26,7 @@ send_to_device, set_module_tensor_to_device, ) +from .utils.memory import clear_device_cache from .utils.modeling import get_non_persistent_buffers from .utils.other import recursive_getattr @@ -695,6 +696,7 @@ def init_hook(self, module): def pre_forward(self, module, *args, **kwargs): if self.prev_module_hook is not None: self.prev_module_hook.offload() + clear_device_cache() module.to(self.execution_device) return send_to_device(args, self.execution_device), send_to_device(kwargs, self.execution_device) diff --git a/src/accelerate/utils/memory.py b/src/accelerate/utils/memory.py index 974a803273f..baa5377f6a5 100644 --- a/src/accelerate/utils/memory.py +++ b/src/accelerate/utils/memory.py @@ -23,7 +23,14 @@ import torch -from .imports import is_mlu_available, is_mps_available, is_musa_available, is_npu_available, is_xpu_available +from .imports import ( + is_cuda_available, + is_mlu_available, + is_mps_available, + is_musa_available, + is_npu_available, + is_xpu_available, +) def clear_device_cache(garbage_collection=False): @@ -44,7 +51,7 @@ def clear_device_cache(garbage_collection=False): torch.npu.empty_cache() elif is_mps_available(min_version="2.0"): torch.mps.empty_cache() - else: + elif is_cuda_available(): torch.cuda.empty_cache()