diff --git a/src/accelerate/utils/modeling.py b/src/accelerate/utils/modeling.py
index eb3fc5634a9..38185cda4d4 100644
--- a/src/accelerate/utils/modeling.py
+++ b/src/accelerate/utils/modeling.py
@@ -21,7 +21,7 @@
 import re
 import shutil
 import tempfile
-from collections import defaultdict
+from collections import OrderedDict, defaultdict
 from typing import Dict, List, Optional, Tuple, Union
 
 import torch
@@ -923,6 +923,7 @@ def infer_auto_device_map(
     dtype: Optional[Union[str, torch.dtype]] = None,
     special_dtypes: Optional[Dict[str, Union[str, torch.dtype]]] = None,
     verbose: bool = False,
+    clean_result: bool = True,
 ):
     """
     Compute a device map for a given model giving priority to GPUs, then offload on CPU and finally offload to disk,
@@ -956,6 +957,8 @@
             all weights).
         verbose (`bool`, *optional*, defaults to `False`):
             Whether or not to provide debugging statements as the function builds the device_map.
+        clean_result (`bool`, *optional*, defaults to `True`):
+            Clean the resulting device_map by grouping all submodules that go on the same device together.
     """
     # Get default / clean up max_memory
     max_memory = get_max_memory(max_memory)
@@ -985,7 +988,7 @@
             "The model weights are not tied. Please use the `tie_weights` method before using the `infer_auto_device` function."
         )
 
-    device_map = {}
+    device_map = OrderedDict()
     current_device = 0
     current_memory_used = 0
 
@@ -1153,7 +1156,9 @@
             current_memory_used += module_size
             device_map[name] = devices[current_device]
 
-    return clean_device_map(device_map)
+    if clean_result:
+        device_map = clean_device_map(device_map)
+    return device_map
 
 
 def check_device_map(model: nn.Module, device_map: Dict[str, Union[int, str, torch.device]]):
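
Usage note (not part of the patch): a minimal sketch of how the new `clean_result` flag changes the map returned by `infer_auto_device_map`, assuming this patch is applied. The toy model and the "1GB" CPU budget are arbitrary placeholders.

# Illustrative sketch only; the model and memory budget are placeholders.
from torch import nn
from accelerate import infer_auto_device_map

model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8))

# Default behaviour (clean_result=True): submodules assigned to the same device
# are grouped together by clean_device_map.
grouped = infer_auto_device_map(model, max_memory={"cpu": "1GB"})

# clean_result=False skips clean_device_map and keeps one entry per treated
# submodule, in assignment order (hence the switch to an OrderedDict).
raw = infer_auto_device_map(model, max_memory={"cpu": "1GB"}, clean_result=False)

print(grouped)  # expected: a single entry for the root module, e.g. '' -> 'cpu'
print(raw)      # expected: one entry per submodule, e.g. '0' -> 'cpu', '1' -> 'cpu'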