Commit fb4c893

make style
ji-huazhong committed Dec 6, 2023
1 parent 86e7de9 commit fb4c893
Showing 1 changed file with 68 additions and 66 deletions.
134 changes: 68 additions & 66 deletions src/accelerate/utils/modeling.py
@@ -34,12 +34,14 @@
from .offload import load_offloaded_weight, offload_weight, save_offload_index
from .tqdm import is_tqdm_available, tqdm


if is_npu_available(check_device=False):
import torch_npu # noqa: F401

from safetensors import safe_open
from safetensors.torch import load_file as safe_load_file


WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json"

logger = logging.getLogger(__name__)
@@ -67,19 +69,19 @@ def convert_file_size_to_int(size: Union[int, str]):
if isinstance(size, int):
mem_size = size
elif size.upper().endswith("GIB"):
mem_size = int(float(size[:-3]) * (2 ** 30))
mem_size = int(float(size[:-3]) * (2**30))
elif size.upper().endswith("MIB"):
mem_size = int(float(size[:-3]) * (2 ** 20))
mem_size = int(float(size[:-3]) * (2**20))
elif size.upper().endswith("KIB"):
mem_size = int(float(size[:-3]) * (2 ** 10))
mem_size = int(float(size[:-3]) * (2**10))
elif size.upper().endswith("GB"):
int_size = int(float(size[:-2]) * (10 ** 9))
int_size = int(float(size[:-2]) * (10**9))
mem_size = int_size // 8 if size.endswith("b") else int_size
elif size.upper().endswith("MB"):
int_size = int(float(size[:-2]) * (10 ** 6))
int_size = int(float(size[:-2]) * (10**6))
mem_size = int_size // 8 if size.endswith("b") else int_size
elif size.upper().endswith("KB"):
int_size = int(float(size[:-2]) * (10 ** 3))
int_size = int(float(size[:-2]) * (10**3))
mem_size = int_size // 8 if size.endswith("b") else int_size
except ValueError:
raise ValueError(err_msg)
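
For context, a minimal usage sketch of convert_file_size_to_int as reformatted above (not part of this commit; the import path mirrors src/accelerate/utils/modeling.py and the sizes are illustrative):

from accelerate.utils.modeling import convert_file_size_to_int

# Binary suffixes use powers of two, decimal suffixes powers of ten; a trailing
# lowercase "b" is read as bits and divided by 8, and plain integers pass through.
assert convert_file_size_to_int("10GiB") == 10 * 2**30
assert convert_file_size_to_int("10GB") == 10 * 10**9
assert convert_file_size_to_int(4096) == 4096
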
@@ -150,7 +152,7 @@ def id_tensor_storage(tensor: torch.Tensor) -> Tuple[torch.device, int, int]:


def shard_checkpoint(
state_dict: Dict[str, torch.Tensor], max_shard_size: Union[int, str] = "10GB", weights_name: str = WEIGHTS_NAME
state_dict: Dict[str, torch.Tensor], max_shard_size: Union[int, str] = "10GB", weights_name: str = WEIGHTS_NAME
):
"""
Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a
@@ -232,12 +234,12 @@ def shard_checkpoint(


def set_module_tensor_to_device(
module: nn.Module,
tensor_name: str,
device: Union[int, str, torch.device],
value: Optional[torch.Tensor] = None,
dtype: Optional[Union[str, torch.dtype]] = None,
fp16_statistics: Optional[torch.HalfTensor] = None,
module: nn.Module,
tensor_name: str,
device: Union[int, str, torch.device],
value: Optional[torch.Tensor] = None,
dtype: Optional[Union[str, torch.dtype]] = None,
fp16_statistics: Optional[torch.HalfTensor] = None,
):
"""
A helper function to set a given tensor (parameter or buffer) of a module on a specific device (note that doing
@@ -296,10 +298,10 @@ def set_module_tensor_to_device(
# leave it on cpu first before moving them to cuda
# # fix the case where the device is meta, we don't want to put it on cpu because there is no data =0
if (
param is not None
and param.device.type != "cuda"
and torch.device(device).type == "cuda"
and param_cls.__name__ in ["Int8Params", "FP4Params"]
param is not None
and param.device.type != "cuda"
and torch.device(device).type == "cuda"
and param_cls.__name__ in ["Int8Params", "FP4Params"]
):
device_quantization = device
device = "cpu"
@@ -339,9 +341,9 @@ def set_module_tensor_to_device(
del fp16_statistics
# as we put the weight to meta, it doesn't have SCB attr anymore. make sure that it is not a meta weight
if (
module.__class__.__name__ == "Linear8bitLt"
and getattr(module.weight, "SCB", None) is None
and str(module.weight.device) != "meta"
module.__class__.__name__ == "Linear8bitLt"
and getattr(module.weight, "SCB", None) is None
and str(module.weight.device) != "meta"
):
# quantize only if necessary
device_index = torch.device(device).index if torch.device(device).type == "cuda" else None
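
For context, a minimal sketch of calling set_module_tensor_to_device (not part of this commit; the layer, tensor name, and device below are illustrative):

import torch
from accelerate.utils.modeling import set_module_tensor_to_device

layer = torch.nn.Linear(2, 2)
# Replace the existing "weight" parameter with a new tensor and place it on CPU in one call.
set_module_tensor_to_device(layer, "weight", "cpu", value=torch.zeros(2, 2))
print(layer.weight.device, layer.weight.shape)
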
@@ -362,7 +364,7 @@ def set_module_tensor_to_device(


def named_module_tensors(
module: nn.Module, include_buffers: bool = True, recurse: bool = False, remove_non_persistent: bool = False
module: nn.Module, include_buffers: bool = True, recurse: bool = False, remove_non_persistent: bool = False
):
"""
A helper function that gathers all the tensors (parameters + buffers) of a given module. If `include_buffers=True`
@@ -443,14 +445,14 @@ def check_tied_parameters_in_config(model: nn.Module):

if "PreTrainedModel" in [c.__name__ for c in inspect.getmro(model.__class__)]:
has_tied_word_embedding = (
hasattr(model, "config")
and getattr(model.config, "tie_word_embeddings", False)
and model.get_output_embeddings()
hasattr(model, "config")
and getattr(model.config, "tie_word_embeddings", False)
and model.get_output_embeddings()
)
has_tied_encoder_decoder = (
hasattr(model, "config")
and getattr(model.config, "is_encoder_decoder", False)
and getattr(model.config, "tie_encoder_decoder", False)
hasattr(model, "config")
and getattr(model.config, "is_encoder_decoder", False)
and getattr(model.config, "tie_encoder_decoder", False)
)
has_tied_module = any(hasattr(module, "_tie_weights") for module in model.modules())

@@ -591,9 +593,9 @@ def _get_proper_dtype(dtype: Union[str, torch.device]) -> torch.dtype:


def compute_module_sizes(
model: nn.Module,
dtype: Optional[Union[str, torch.device]] = None,
special_dtypes: Optional[Dict[str, Union[str, torch.device]]] = None,
model: nn.Module,
dtype: Optional[Union[str, torch.device]] = None,
special_dtypes: Optional[Dict[str, Union[str, torch.device]]] = None,
):
"""
Compute the size of each submodule of a given model.
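
For context, a small sketch of what compute_module_sizes returns (not part of this commit; the toy model is illustrative):

import torch
from accelerate.utils.modeling import compute_module_sizes

model = torch.nn.Linear(4, 4)  # 16 weights + 4 biases = 20 float32 parameters
sizes = compute_module_sizes(model)
print(sizes[""])        # total size of the model in bytes (20 * 4 = 80 here)
print(sizes["weight"])  # per-tensor / per-submodule entries are also reported
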
@@ -620,7 +622,7 @@ def compute_module_sizes(


def get_max_layer_size(
modules: List[Tuple[str, torch.nn.Module]], module_sizes: Dict[str, int], no_split_module_classes: List[str]
modules: List[Tuple[str, torch.nn.Module]], module_sizes: Dict[str, int], no_split_module_classes: List[str]
):
"""
Utility function that will scan a list of named modules and return the maximum size used by one full layer. The
@@ -773,12 +775,12 @@ def load_offloaded_weights(model, index, offload_folder):


def get_balanced_memory(
model: nn.Module,
max_memory: Optional[Dict[Union[int, str], Union[int, str]]] = None,
no_split_module_classes: Optional[List[str]] = None,
dtype: Optional[Union[str, torch.dtype]] = None,
special_dtypes: Optional[Dict[str, Union[str, torch.device]]] = None,
low_zero: bool = False,
model: nn.Module,
max_memory: Optional[Dict[Union[int, str], Union[int, str]]] = None,
no_split_module_classes: Optional[List[str]] = None,
dtype: Optional[Union[str, torch.dtype]] = None,
special_dtypes: Optional[Dict[str, Union[str, torch.device]]] = None,
low_zero: bool = False,
):
"""
Compute a `max_memory` dictionary for [`infer_auto_device_map`] that will balance the use of each available GPU.
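
For context, a hedged sketch of how get_balanced_memory is typically paired with infer_auto_device_map (not part of this commit; the toy model is illustrative and the output depends on the GPUs visible on the machine):

import torch
from accelerate.utils import get_balanced_memory

model = torch.nn.Sequential(*[torch.nn.Linear(1024, 1024) for _ in range(8)])
# Proposes per-device byte budgets that spread the model evenly over the available GPUs,
# suitable for passing as max_memory= to infer_auto_device_map.
max_memory = get_balanced_memory(model, dtype=torch.float16)
print(max_memory)
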
@@ -819,10 +821,10 @@ def get_balanced_memory(
d
for d in max_memory
if (
d != "cpu"
and (torch.device(d).type == "xpu" or torch.xpu.get_device_properties(d).dev_type == "gpu")
)
and max_memory[d] > 0
d != "cpu"
and (torch.device(d).type == "xpu" or torch.xpu.get_device_properties(d).dev_type == "gpu")
)
and max_memory[d] > 0
]
)
else:
@@ -912,22 +914,22 @@ def calculate_maximum_sizes(model: torch.nn.Module):
no_split_modules = []

modules_to_treat = (
list(model.named_parameters(recurse=False))
+ list(model.named_children())
+ list(model.named_buffers(recurse=False))
list(model.named_parameters(recurse=False))
+ list(model.named_children())
+ list(model.named_buffers(recurse=False))
)
largest_layer = get_max_layer_size(modules_to_treat, sizes, no_split_modules)
total_size = sizes[""]
return total_size, largest_layer


def infer_auto_device_map(
model: nn.Module,
max_memory: Optional[Dict[Union[int, str], Union[int, str]]] = None,
no_split_module_classes: Optional[List[str]] = None,
dtype: Optional[Union[str, torch.dtype]] = None,
special_dtypes: Optional[Dict[str, Union[str, torch.dtype]]] = None,
verbose: bool = False,
model: nn.Module,
max_memory: Optional[Dict[Union[int, str], Union[int, str]]] = None,
no_split_module_classes: Optional[List[str]] = None,
dtype: Optional[Union[str, torch.dtype]] = None,
special_dtypes: Optional[Dict[str, Union[str, torch.dtype]]] = None,
verbose: bool = False,
):
"""
Compute a device map for a given model giving priority to GPUs, then offload on CPU and finally offload to disk,
@@ -996,9 +998,9 @@ def infer_auto_device_map(

# Direct submodules and parameters
modules_to_treat = (
list(model.named_parameters(recurse=False))
+ list(model.named_children())
+ list(model.named_buffers(recurse=False))
list(model.named_parameters(recurse=False))
+ list(model.named_children())
+ list(model.named_buffers(recurse=False))
)
# Initialize maximum largest layer, to know which space to keep in memory
max_layer_size, max_layer_names = get_max_layer_size(modules_to_treat, module_sizes, no_split_module_classes)
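
For context, a minimal sketch of infer_auto_device_map on a toy model (not part of this commit; the memory caps are illustrative):

import torch
from accelerate import infer_auto_device_map

model = torch.nn.Sequential(torch.nn.Linear(512, 512), torch.nn.Linear(512, 512))
# With explicit caps, the map assigns each submodule to GPU 0 until the budget is
# exhausted, then falls back to "cpu" and finally "disk".
device_map = infer_auto_device_map(model, max_memory={0: "600KB", "cpu": "10GB"})
print(device_map)
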
@@ -1124,10 +1126,10 @@ def infer_auto_device_map(
tied_module_index = [i for i, (n, _) in enumerate(modules_to_treat) if n == tied_module_name][0]

modules_to_treat = (
[(name, module)]
+ modules_to_treat[:tied_module_index]
+ tied_module_children
+ modules_to_treat[tied_module_index + 1:]
[(name, module)]
+ modules_to_treat[:tied_module_index]
+ tied_module_children
+ modules_to_treat[tied_module_index + 1 :]
)
# Update the max layer size.
max_layer_size, max_layer_names = get_max_layer_size(
@@ -1316,15 +1318,15 @@ def get_state_dict_offloaded_model(model: nn.Module):


def load_checkpoint_in_model(
model: nn.Module,
checkpoint: Union[str, os.PathLike],
device_map: Optional[Dict[str, Union[int, str, torch.device]]] = None,
offload_folder: Optional[Union[str, os.PathLike]] = None,
dtype: Optional[Union[str, torch.dtype]] = None,
offload_state_dict: bool = False,
offload_buffers: bool = False,
keep_in_fp32_modules: List[str] = None,
offload_8bit_bnb: bool = False,
model: nn.Module,
checkpoint: Union[str, os.PathLike],
device_map: Optional[Dict[str, Union[int, str, torch.device]]] = None,
offload_folder: Optional[Union[str, os.PathLike]] = None,
dtype: Optional[Union[str, torch.dtype]] = None,
offload_state_dict: bool = False,
offload_buffers: bool = False,
keep_in_fp32_modules: List[str] = None,
offload_8bit_bnb: bool = False,
):
"""
Loads a (potentially sharded) checkpoint inside a model, potentially sending weights to a given device as they are
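
For context, a hedged sketch of load_checkpoint_in_model (not part of this commit; the checkpoint path and device map are placeholders):

import torch
from accelerate import load_checkpoint_in_model

model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.Linear(16, 16))
torch.save(model.state_dict(), "toy_checkpoint.bin")  # placeholder checkpoint for illustration
# Load the (possibly sharded) checkpoint back into the model, keeping every module on CPU.
load_checkpoint_in_model(model, "toy_checkpoint.bin", device_map={"": "cpu"})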
