diff --git a/src/accelerate/utils/memory.py b/src/accelerate/utils/memory.py
index baa5377f6a5..b01a97390f6 100644
--- a/src/accelerate/utils/memory.py
+++ b/src/accelerate/utils/memory.py
@@ -103,7 +103,7 @@ def should_reduce_batch_size(exception: Exception) -> bool:
     return False
 
 
-def find_executable_batch_size(function: callable = None, starting_batch_size: int = 128):
+def find_executable_batch_size(function: callable = None, starting_batch_size: int = 128, reduce_batch_size_fn: callable = None):
     """
     A basic decorator that will try to execute `function`. If it fails from exceptions related to out-of-memory or
     CUDNN, the batch size is cut in half and passed to `function`
@@ -134,6 +134,11 @@ def find_executable_batch_size(function: callable = None, starting_batch_size: i
         return functools.partial(find_executable_batch_size, starting_batch_size=starting_batch_size)
 
     batch_size = starting_batch_size
+    if reduce_batch_size_fn is None:
+        def reduce_batch_size_fn():
+            nonlocal batch_size
+            batch_size = batch_size // 2
+            return batch_size
 
     def decorator(*args, **kwargs):
         nonlocal batch_size
@@ -154,7 +159,7 @@ def decorator(*args, **kwargs):
             except Exception as e:
                 if should_reduce_batch_size(e):
                     clear_device_cache(garbage_collection=True)
-                    batch_size //= 2
+                    batch_size = reduce_batch_size_fn()
                 else:
                     raise
 
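With this change, the decorator calls `reduce_batch_size_fn()` with no arguments and uses its return value as the next batch size to try, falling back to the existing halving behavior when no callback is given. Below is a minimal usage sketch of the new hook under those assumptions; `make_linear_reducer` and `train_one_config` are made-up names for illustration, not part of the patch or the library, and the function is passed to `find_executable_batch_size` directly rather than through the decorator-factory form.

```python
# Hypothetical usage sketch: shrink the batch size by a fixed step instead of halving it.
from accelerate.utils.memory import find_executable_batch_size


def make_linear_reducer(starting_batch_size: int, step: int = 16):
    """Return a zero-argument callable that yields the next, smaller batch size."""
    current = starting_batch_size

    def reduce_fn():
        nonlocal current
        # Reaching 0 lets the decorator raise its "no executable batch size" error.
        current = max(current - step, 0)
        return current

    return reduce_fn


def train_one_config(batch_size):
    # Real training code would go here; an OOM-style failure triggers a retry
    # with the batch size returned by the reducer.
    print(f"trying batch_size={batch_size}")


# Passing the function directly returns the retry wrapper with the custom reducer attached.
training_loop = find_executable_batch_size(
    train_one_config,
    starting_batch_size=256,
    reduce_batch_size_fn=make_linear_reducer(256, step=32),
)
training_loop()
```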