diff --git a/src/accelerate/utils/modeling.py b/src/accelerate/utils/modeling.py index 2815ef2dc64..557640414c8 100644 --- a/src/accelerate/utils/modeling.py +++ b/src/accelerate/utils/modeling.py @@ -34,12 +34,14 @@ from .offload import load_offloaded_weight, offload_weight, save_offload_index from .tqdm import is_tqdm_available, tqdm + if is_npu_available(check_device=False): import torch_npu # noqa: F401 from safetensors import safe_open from safetensors.torch import load_file as safe_load_file + WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json" logger = logging.getLogger(__name__) @@ -67,19 +69,19 @@ def convert_file_size_to_int(size: Union[int, str]): if isinstance(size, int): mem_size = size elif size.upper().endswith("GIB"): - mem_size = int(float(size[:-3]) * (2 ** 30)) + mem_size = int(float(size[:-3]) * (2**30)) elif size.upper().endswith("MIB"): - mem_size = int(float(size[:-3]) * (2 ** 20)) + mem_size = int(float(size[:-3]) * (2**20)) elif size.upper().endswith("KIB"): - mem_size = int(float(size[:-3]) * (2 ** 10)) + mem_size = int(float(size[:-3]) * (2**10)) elif size.upper().endswith("GB"): - int_size = int(float(size[:-2]) * (10 ** 9)) + int_size = int(float(size[:-2]) * (10**9)) mem_size = int_size // 8 if size.endswith("b") else int_size elif size.upper().endswith("MB"): - int_size = int(float(size[:-2]) * (10 ** 6)) + int_size = int(float(size[:-2]) * (10**6)) mem_size = int_size // 8 if size.endswith("b") else int_size elif size.upper().endswith("KB"): - int_size = int(float(size[:-2]) * (10 ** 3)) + int_size = int(float(size[:-2]) * (10**3)) mem_size = int_size // 8 if size.endswith("b") else int_size except ValueError: raise ValueError(err_msg) @@ -150,7 +152,7 @@ def id_tensor_storage(tensor: torch.Tensor) -> Tuple[torch.device, int, int]: def shard_checkpoint( - state_dict: Dict[str, torch.Tensor], max_shard_size: Union[int, str] = "10GB", weights_name: str = WEIGHTS_NAME + state_dict: Dict[str, torch.Tensor], max_shard_size: Union[int, str] = "10GB", weights_name: str = WEIGHTS_NAME ): """ Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a @@ -232,12 +234,12 @@ def shard_checkpoint( def set_module_tensor_to_device( - module: nn.Module, - tensor_name: str, - device: Union[int, str, torch.device], - value: Optional[torch.Tensor] = None, - dtype: Optional[Union[str, torch.dtype]] = None, - fp16_statistics: Optional[torch.HalfTensor] = None, + module: nn.Module, + tensor_name: str, + device: Union[int, str, torch.device], + value: Optional[torch.Tensor] = None, + dtype: Optional[Union[str, torch.dtype]] = None, + fp16_statistics: Optional[torch.HalfTensor] = None, ): """ A helper function to set a given tensor (parameter of buffer) of a module on a specific device (note that doing @@ -296,10 +298,10 @@ def set_module_tensor_to_device( # leave it on cpu first before moving them to cuda # # fix the case where the device is meta, we don't want to put it on cpu because there is no data =0 if ( - param is not None - and param.device.type != "cuda" - and torch.device(device).type == "cuda" - and param_cls.__name__ in ["Int8Params", "FP4Params"] + param is not None + and param.device.type != "cuda" + and torch.device(device).type == "cuda" + and param_cls.__name__ in ["Int8Params", "FP4Params"] ): device_quantization = device device = "cpu" @@ -339,9 +341,9 @@ def set_module_tensor_to_device( del fp16_statistics # as we put the weight to meta, it doesn't have SCB attr anymore. make sure that it is not a meta weight if ( - module.__class__.__name__ == "Linear8bitLt" - and getattr(module.weight, "SCB", None) is None - and str(module.weight.device) != "meta" + module.__class__.__name__ == "Linear8bitLt" + and getattr(module.weight, "SCB", None) is None + and str(module.weight.device) != "meta" ): # quantize only if necessary device_index = torch.device(device).index if torch.device(device).type == "cuda" else None @@ -362,7 +364,7 @@ def set_module_tensor_to_device( def named_module_tensors( - module: nn.Module, include_buffers: bool = True, recurse: bool = False, remove_non_persistent: bool = False + module: nn.Module, include_buffers: bool = True, recurse: bool = False, remove_non_persistent: bool = False ): """ A helper function that gathers all the tensors (parameters + buffers) of a given module. If `include_buffers=True` @@ -443,14 +445,14 @@ def check_tied_parameters_in_config(model: nn.Module): if "PreTrainedModel" in [c.__name__ for c in inspect.getmro(model.__class__)]: has_tied_word_embedding = ( - hasattr(model, "config") - and getattr(model.config, "tie_word_embeddings", False) - and model.get_output_embeddings() + hasattr(model, "config") + and getattr(model.config, "tie_word_embeddings", False) + and model.get_output_embeddings() ) has_tied_encoder_decoder = ( - hasattr(model, "config") - and getattr(model.config, "is_encoder_decoder", False) - and getattr(model.config, "tie_encoder_decoder", False) + hasattr(model, "config") + and getattr(model.config, "is_encoder_decoder", False) + and getattr(model.config, "tie_encoder_decoder", False) ) has_tied_module = any(hasattr(module, "_tie_weights") for module in model.modules()) @@ -591,9 +593,9 @@ def _get_proper_dtype(dtype: Union[str, torch.device]) -> torch.dtype: def compute_module_sizes( - model: nn.Module, - dtype: Optional[Union[str, torch.device]] = None, - special_dtypes: Optional[Dict[str, Union[str, torch.device]]] = None, + model: nn.Module, + dtype: Optional[Union[str, torch.device]] = None, + special_dtypes: Optional[Dict[str, Union[str, torch.device]]] = None, ): """ Compute the size of each submodule of a given model. @@ -620,7 +622,7 @@ def compute_module_sizes( def get_max_layer_size( - modules: List[Tuple[str, torch.nn.Module]], module_sizes: Dict[str, int], no_split_module_classes: List[str] + modules: List[Tuple[str, torch.nn.Module]], module_sizes: Dict[str, int], no_split_module_classes: List[str] ): """ Utility function that will scan a list of named modules and return the maximum size used by one full layer. The @@ -773,12 +775,12 @@ def load_offloaded_weights(model, index, offload_folder): def get_balanced_memory( - model: nn.Module, - max_memory: Optional[Dict[Union[int, str], Union[int, str]]] = None, - no_split_module_classes: Optional[List[str]] = None, - dtype: Optional[Union[str, torch.dtype]] = None, - special_dtypes: Optional[Dict[str, Union[str, torch.device]]] = None, - low_zero: bool = False, + model: nn.Module, + max_memory: Optional[Dict[Union[int, str], Union[int, str]]] = None, + no_split_module_classes: Optional[List[str]] = None, + dtype: Optional[Union[str, torch.dtype]] = None, + special_dtypes: Optional[Dict[str, Union[str, torch.device]]] = None, + low_zero: bool = False, ): """ Compute a `max_memory` dictionary for [`infer_auto_device_map`] that will balance the use of each available GPU. @@ -819,10 +821,10 @@ def get_balanced_memory( d for d in max_memory if ( - d != "cpu" - and (torch.device(d).type == "xpu" or torch.xpu.get_device_properties(d).dev_type == "gpu") - ) - and max_memory[d] > 0 + d != "cpu" + and (torch.device(d).type == "xpu" or torch.xpu.get_device_properties(d).dev_type == "gpu") + ) + and max_memory[d] > 0 ] ) else: @@ -912,9 +914,9 @@ def calculate_maximum_sizes(model: torch.nn.Module): no_split_modules = [] modules_to_treat = ( - list(model.named_parameters(recurse=False)) - + list(model.named_children()) - + list(model.named_buffers(recurse=False)) + list(model.named_parameters(recurse=False)) + + list(model.named_children()) + + list(model.named_buffers(recurse=False)) ) largest_layer = get_max_layer_size(modules_to_treat, sizes, no_split_modules) total_size = sizes[""] @@ -922,12 +924,12 @@ def calculate_maximum_sizes(model: torch.nn.Module): def infer_auto_device_map( - model: nn.Module, - max_memory: Optional[Dict[Union[int, str], Union[int, str]]] = None, - no_split_module_classes: Optional[List[str]] = None, - dtype: Optional[Union[str, torch.dtype]] = None, - special_dtypes: Optional[Dict[str, Union[str, torch.dtype]]] = None, - verbose: bool = False, + model: nn.Module, + max_memory: Optional[Dict[Union[int, str], Union[int, str]]] = None, + no_split_module_classes: Optional[List[str]] = None, + dtype: Optional[Union[str, torch.dtype]] = None, + special_dtypes: Optional[Dict[str, Union[str, torch.dtype]]] = None, + verbose: bool = False, ): """ Compute a device map for a given model giving priority to GPUs, then offload on CPU and finally offload to disk, @@ -996,9 +998,9 @@ def infer_auto_device_map( # Direct submodules and parameters modules_to_treat = ( - list(model.named_parameters(recurse=False)) - + list(model.named_children()) - + list(model.named_buffers(recurse=False)) + list(model.named_parameters(recurse=False)) + + list(model.named_children()) + + list(model.named_buffers(recurse=False)) ) # Initialize maximum largest layer, to know which space to keep in memory max_layer_size, max_layer_names = get_max_layer_size(modules_to_treat, module_sizes, no_split_module_classes) @@ -1124,10 +1126,10 @@ def infer_auto_device_map( tied_module_index = [i for i, (n, _) in enumerate(modules_to_treat) if n == tied_module_name][0] modules_to_treat = ( - [(name, module)] - + modules_to_treat[:tied_module_index] - + tied_module_children - + modules_to_treat[tied_module_index + 1:] + [(name, module)] + + modules_to_treat[:tied_module_index] + + tied_module_children + + modules_to_treat[tied_module_index + 1 :] ) # Update the max layer size. max_layer_size, max_layer_names = get_max_layer_size( @@ -1316,15 +1318,15 @@ def get_state_dict_offloaded_model(model: nn.Module): def load_checkpoint_in_model( - model: nn.Module, - checkpoint: Union[str, os.PathLike], - device_map: Optional[Dict[str, Union[int, str, torch.device]]] = None, - offload_folder: Optional[Union[str, os.PathLike]] = None, - dtype: Optional[Union[str, torch.dtype]] = None, - offload_state_dict: bool = False, - offload_buffers: bool = False, - keep_in_fp32_modules: List[str] = None, - offload_8bit_bnb: bool = False, + model: nn.Module, + checkpoint: Union[str, os.PathLike], + device_map: Optional[Dict[str, Union[int, str, torch.device]]] = None, + offload_folder: Optional[Union[str, os.PathLike]] = None, + dtype: Optional[Union[str, torch.dtype]] = None, + offload_state_dict: bool = False, + offload_buffers: bool = False, + keep_in_fp32_modules: List[str] = None, + offload_8bit_bnb: bool = False, ): """ Loads a (potentially sharded) checkpoint inside a model, potentially sending weights to a given device as they are