Add npu support to big model inference
ji-huazhong committed Dec 6, 2023
1 parent 0482548 commit 86e7de9
Showing 1 changed file with 93 additions and 86 deletions.
179 changes: 93 additions & 86 deletions src/accelerate/utils/modeling.py
@@ -34,18 +34,14 @@
from .offload import load_offloaded_weight, offload_weight, save_offload_index
from .tqdm import is_tqdm_available, tqdm


if is_npu_available(check_device=False):
import torch_npu # noqa: F401


from safetensors import safe_open
from safetensors.torch import load_file as safe_load_file


WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json"


logger = logging.getLogger(__name__)
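
The guarded import above is what the rest of this diff builds on: importing torch_npu registers the torch.npu namespace (device_count, mem_get_info, ...) that the helpers below query. A minimal sketch of the pattern; default_device is an illustrative helper, not part of accelerate:

import torch
from accelerate.utils import is_npu_available

if is_npu_available(check_device=False):
    import torch_npu  # noqa: F401  # side effect: registers the torch.npu.* namespace

def default_device() -> torch.device:
    # Illustrative helper: prefer NPU, then CUDA, then CPU.
    if is_npu_available():
        return torch.device("npu")
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")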


@@ -71,19 +67,19 @@ def convert_file_size_to_int(size: Union[int, str]):
if isinstance(size, int):
mem_size = size
elif size.upper().endswith("GIB"):
mem_size = int(float(size[:-3]) * (2**30))
mem_size = int(float(size[:-3]) * (2 ** 30))
elif size.upper().endswith("MIB"):
mem_size = int(float(size[:-3]) * (2**20))
mem_size = int(float(size[:-3]) * (2 ** 20))
elif size.upper().endswith("KIB"):
mem_size = int(float(size[:-3]) * (2**10))
mem_size = int(float(size[:-3]) * (2 ** 10))
elif size.upper().endswith("GB"):
int_size = int(float(size[:-2]) * (10**9))
int_size = int(float(size[:-2]) * (10 ** 9))
mem_size = int_size // 8 if size.endswith("b") else int_size
elif size.upper().endswith("MB"):
int_size = int(float(size[:-2]) * (10**6))
int_size = int(float(size[:-2]) * (10 ** 6))
mem_size = int_size // 8 if size.endswith("b") else int_size
elif size.upper().endswith("KB"):
int_size = int(float(size[:-2]) * (10**3))
int_size = int(float(size[:-2]) * (10 ** 3))
mem_size = int_size // 8 if size.endswith("b") else int_size
except ValueError:
raise ValueError(err_msg)
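
For reference, the size strings accepted above and the byte values they produce (a quick check, importing the function from this module):

from accelerate.utils.modeling import convert_file_size_to_int

assert convert_file_size_to_int("10GB") == 10 * 10**9   # decimal units
assert convert_file_size_to_int("4GiB") == 4 * 2**30    # binary units
assert convert_file_size_to_int(2048) == 2048           # plain ints pass through unchanged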
@@ -154,7 +150,7 @@ def id_tensor_storage(tensor: torch.Tensor) -> Tuple[torch.device, int, int]:


def shard_checkpoint(
state_dict: Dict[str, torch.Tensor], max_shard_size: Union[int, str] = "10GB", weights_name: str = WEIGHTS_NAME
state_dict: Dict[str, torch.Tensor], max_shard_size: Union[int, str] = "10GB", weights_name: str = WEIGHTS_NAME
):
"""
Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a
@@ -221,7 +217,7 @@ def shard_checkpoint(
weight_map = {}
shards = {}
for idx, shard in enumerate(sharded_state_dicts):
shard_file = weights_name.replace(".bin", f"-{idx+1:05d}-of-{len(sharded_state_dicts):05d}.bin")
shard_file = weights_name.replace(".bin", f"-{idx + 1:05d}-of-{len(sharded_state_dicts):05d}.bin")
shard_file = shard_file.replace(
".safetensors", f"-{idx + 1:05d}-of-{len(sharded_state_dicts):05d}.safetensors"
)
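
The naming scheme above yields the familiar sharded filenames; a small illustration of the replace call:

weights_name = "pytorch_model.bin"
num_shards = 3
print([weights_name.replace(".bin", f"-{idx + 1:05d}-of-{num_shards:05d}.bin") for idx in range(num_shards)])
# ['pytorch_model-00001-of-00003.bin', 'pytorch_model-00002-of-00003.bin', 'pytorch_model-00003-of-00003.bin']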
@@ -236,12 +232,12 @@


def set_module_tensor_to_device(
module: nn.Module,
tensor_name: str,
device: Union[int, str, torch.device],
value: Optional[torch.Tensor] = None,
dtype: Optional[Union[str, torch.dtype]] = None,
fp16_statistics: Optional[torch.HalfTensor] = None,
module: nn.Module,
tensor_name: str,
device: Union[int, str, torch.device],
value: Optional[torch.Tensor] = None,
dtype: Optional[Union[str, torch.dtype]] = None,
fp16_statistics: Optional[torch.HalfTensor] = None,
):
"""
A helper function to set a given tensor (parameter or buffer) of a module on a specific device (note that doing
@@ -300,10 +296,10 @@ def set_module_tensor_to_device(
# leave it on cpu first before moving them to cuda
# # fix the case where the device is meta, we don't want to put it on cpu because there is no data =0
if (
param is not None
and param.device.type != "cuda"
and torch.device(device).type == "cuda"
and param_cls.__name__ in ["Int8Params", "FP4Params"]
param is not None
and param.device.type != "cuda"
and torch.device(device).type == "cuda"
and param_cls.__name__ in ["Int8Params", "FP4Params"]
):
device_quantization = device
device = "cpu"
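
Outside of the bitsandbytes special case above, a minimal usage sketch of set_module_tensor_to_device: materialize a meta-initialized parameter on a concrete device (cpu here so the sketch runs anywhere; the bias is left on meta for brevity):

import torch
import torch.nn as nn
from accelerate import init_empty_weights
from accelerate.utils import set_module_tensor_to_device

with init_empty_weights():
    layer = nn.Linear(16, 16)  # parameters are created on the meta device, no memory allocated

set_module_tensor_to_device(layer, "weight", "cpu", value=torch.randn(16, 16))
print(layer.weight.device)  # cpu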
@@ -343,9 +339,9 @@ def set_module_tensor_to_device(
del fp16_statistics
# as we put the weight to meta, it doesn't have SCB attr anymore. make sure that it is not a meta weight
if (
module.__class__.__name__ == "Linear8bitLt"
and getattr(module.weight, "SCB", None) is None
and str(module.weight.device) != "meta"
module.__class__.__name__ == "Linear8bitLt"
and getattr(module.weight, "SCB", None) is None
and str(module.weight.device) != "meta"
):
# quantize only if necessary
device_index = torch.device(device).index if torch.device(device).type == "cuda" else None
@@ -366,7 +362,7 @@


def named_module_tensors(
module: nn.Module, include_buffers: bool = True, recurse: bool = False, remove_non_persistent: bool = False
module: nn.Module, include_buffers: bool = True, recurse: bool = False, remove_non_persistent: bool = False
):
"""
A helper function that gathers all the tensors (parameters + buffers) of a given module. If `include_buffers=True`
@@ -447,14 +443,14 @@ def check_tied_parameters_in_config(model: nn.Module):

if "PreTrainedModel" in [c.__name__ for c in inspect.getmro(model.__class__)]:
has_tied_word_embedding = (
hasattr(model, "config")
and getattr(model.config, "tie_word_embeddings", False)
and model.get_output_embeddings()
hasattr(model, "config")
and getattr(model.config, "tie_word_embeddings", False)
and model.get_output_embeddings()
)
has_tied_encoder_decoder = (
hasattr(model, "config")
and getattr(model.config, "is_encoder_decoder", False)
and getattr(model.config, "tie_encoder_decoder", False)
hasattr(model, "config")
and getattr(model.config, "is_encoder_decoder", False)
and getattr(model.config, "tie_encoder_decoder", False)
)
has_tied_module = any(hasattr(module, "_tie_weights") for module in model.modules())
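
The check above looks for weight tying (shared storage between, for example, the input and output embeddings) so that tied tensors can be kept on the same device. A hypothetical minimal model showing the pattern it detects:

import torch.nn as nn

class TinyLM(nn.Module):  # hypothetical model, not part of accelerate
    def __init__(self, vocab_size=100, dim=16):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, dim)
        self.head = nn.Linear(dim, vocab_size, bias=False)
        self.head.weight = self.embed.weight  # tied: one storage, two names

model = TinyLM()
assert model.head.weight.data_ptr() == model.embed.weight.data_ptr()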

@@ -595,9 +591,9 @@ def _get_proper_dtype(dtype: Union[str, torch.device]) -> torch.dtype:


def compute_module_sizes(
model: nn.Module,
dtype: Optional[Union[str, torch.device]] = None,
special_dtypes: Optional[Dict[str, Union[str, torch.device]]] = None,
model: nn.Module,
dtype: Optional[Union[str, torch.device]] = None,
special_dtypes: Optional[Dict[str, Union[str, torch.device]]] = None,
):
"""
Compute the size of each submodule of a given model.
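
A quick sketch of what compute_module_sizes returns: a dict from dotted module/parameter names (plus the empty string for the whole model) to sizes in bytes:

import torch.nn as nn
from accelerate.utils.modeling import compute_module_sizes

model = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 2))
sizes = compute_module_sizes(model)
print(sizes[""])   # total size of the model in bytes
print(sizes["0"])  # size of the first Linear (weight + bias)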
@@ -624,7 +620,7 @@


def get_max_layer_size(
modules: List[Tuple[str, torch.nn.Module]], module_sizes: Dict[str, int], no_split_module_classes: List[str]
modules: List[Tuple[str, torch.nn.Module]], module_sizes: Dict[str, int], no_split_module_classes: List[str]
):
"""
Utility function that will scan a list of named modules and return the maximum size used by one full layer. The
@@ -669,19 +665,23 @@ def get_max_memory(max_memory: Optional[Dict[Union[int, str], Union[int, str]]]
import psutil

if max_memory is None:
if not (torch.cuda.is_available() or is_xpu_available()):
if not (torch.cuda.is_available() or is_npu_available() or is_xpu_available()):
max_memory = {}

else:
# Make sure CUDA is initialized on each GPU to have the right memory info.
if not is_xpu_available():
for i in range(torch.cuda.device_count()):
_ = torch.tensor([0], device=i)
max_memory = {i: torch.cuda.mem_get_info(i)[0] for i in range(torch.cuda.device_count())}
else:
if is_npu_available():
for i in range(torch.npu.device_count()):
_ = torch.tensor(0, device=torch.device("npu", i))
max_memory = {i: torch.npu.mem_get_info(i)[0] for i in range(torch.npu.device_count())}
elif is_xpu_available():
for i in range(torch.xpu.device_count()):
_ = torch.tensor(0, device=torch.device("xpu", i))
max_memory = {i: torch.xpu.max_memory_allocated(i) for i in range(torch.xpu.device_count())}
else:
for i in range(torch.cuda.device_count()):
_ = torch.tensor([0], device=i)
max_memory = {i: torch.cuda.mem_get_info(i)[0] for i in range(torch.cuda.device_count())}
# allocate everything in the mps device as the RAM is shared
if is_mps_available():
max_memory["mps"] = psutil.virtual_memory().available
@@ -694,11 +694,16 @@ def get_max_memory(max_memory: Optional[Dict[Union[int, str], Union[int, str]]]
max_memory[key] = convert_file_size_to_int(max_memory[key])

# Need to sort the device by type to make sure that we allocate the gpu first.
# As gpu/xpu are represented by int, we need to sort them first.
# As gpu/npu/xpu are represented by int, we need to sort them first.
gpu_devices = [k for k in max_memory.keys() if isinstance(k, int)]
gpu_devices.sort()
# check if gpu/xgpu devices are available and if not, throw a warning
num_devices = torch.xpu.device_count() if is_xpu_available() else torch.cuda.device_count()
# check if gpu/npu/xpu devices are available and if not, throw a warning
if is_npu_available():
num_devices = torch.npu.device_count()
elif is_xpu_available():
num_devices = torch.xpu.device_count()
else:
num_devices = torch.cuda.device_count()
for device in gpu_devices:
if device >= num_devices or device < 0:
logger.warning(f"Device {device} is not available, available devices are {list(range(num_devices))}")
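
Callers can also pass max_memory explicitly; integer keys address accelerators (cuda, npu or xpu alike) and string values go through convert_file_size_to_int. Indices that don't exist on the host only trigger the warning above:

from accelerate.utils import get_max_memory

max_memory = get_max_memory({0: "6GiB", 1: "6GiB", "cpu": "30GiB"})
print(max_memory)  # {0: 6442450944, 1: 6442450944, 'cpu': 32212254720}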
@@ -768,12 +773,12 @@ def load_offloaded_weights(model, index, offload_folder):


def get_balanced_memory(
model: nn.Module,
max_memory: Optional[Dict[Union[int, str], Union[int, str]]] = None,
no_split_module_classes: Optional[List[str]] = None,
dtype: Optional[Union[str, torch.dtype]] = None,
special_dtypes: Optional[Dict[str, Union[str, torch.device]]] = None,
low_zero: bool = False,
model: nn.Module,
max_memory: Optional[Dict[Union[int, str], Union[int, str]]] = None,
no_split_module_classes: Optional[List[str]] = None,
dtype: Optional[Union[str, torch.dtype]] = None,
special_dtypes: Optional[Dict[str, Union[str, torch.device]]] = None,
low_zero: bool = False,
):
"""
Compute a `max_memory` dictionary for [`infer_auto_device_map`] that will balance the use of each available GPU.
@@ -806,20 +811,22 @@ def get_balanced_memory(
user_not_set_max_memory = max_memory is None
max_memory = get_max_memory(max_memory)

if not is_xpu_available():
num_devices = len([d for d in max_memory if torch.device(d).type == "cuda" and max_memory[d] > 0])
else:
if is_npu_available():
num_devices = len([d for d in max_memory if torch.device(d).type == "npu" and max_memory[d] > 0])
elif is_xpu_available():
num_devices = len(
[
d
for d in max_memory
if (
d != "cpu"
and (torch.device(d).type == "xpu" or torch.xpu.get_device_properties(d).dev_type == "gpu")
)
and max_memory[d] > 0
d != "cpu"
and (torch.device(d).type == "xpu" or torch.xpu.get_device_properties(d).dev_type == "gpu")
)
and max_memory[d] > 0
]
)
else:
num_devices = len([d for d in max_memory if torch.device(d).type == "cuda" and max_memory[d] > 0])

if num_devices == 0:
return max_memory
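
With the device counting above, NPUs participate in memory balancing exactly like GPUs. A minimal sketch (the toy model is illustrative; on a host without accelerators the call simply returns the CPU budget):

import torch.nn as nn
from accelerate import infer_auto_device_map
from accelerate.utils import get_balanced_memory

model = nn.Sequential(*[nn.Linear(1024, 1024) for _ in range(8)])
max_memory = get_balanced_memory(model)  # counts npu/xpu/cuda devices as shown above
device_map = infer_auto_device_map(model, max_memory=max_memory)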
@@ -905,22 +912,22 @@ def calculate_maximum_sizes(model: torch.nn.Module):
no_split_modules = []

modules_to_treat = (
list(model.named_parameters(recurse=False))
+ list(model.named_children())
+ list(model.named_buffers(recurse=False))
list(model.named_parameters(recurse=False))
+ list(model.named_children())
+ list(model.named_buffers(recurse=False))
)
largest_layer = get_max_layer_size(modules_to_treat, sizes, no_split_modules)
total_size = sizes[""]
return total_size, largest_layer


def infer_auto_device_map(
model: nn.Module,
max_memory: Optional[Dict[Union[int, str], Union[int, str]]] = None,
no_split_module_classes: Optional[List[str]] = None,
dtype: Optional[Union[str, torch.dtype]] = None,
special_dtypes: Optional[Dict[str, Union[str, torch.dtype]]] = None,
verbose: bool = False,
model: nn.Module,
max_memory: Optional[Dict[Union[int, str], Union[int, str]]] = None,
no_split_module_classes: Optional[List[str]] = None,
dtype: Optional[Union[str, torch.dtype]] = None,
special_dtypes: Optional[Dict[str, Union[str, torch.dtype]]] = None,
verbose: bool = False,
):
"""
Compute a device map for a given model giving priority to GPUs, then offload on CPU and finally offload to disk,
@@ -989,9 +996,9 @@ def infer_auto_device_map(

# Direct submodules and parameters
modules_to_treat = (
list(model.named_parameters(recurse=False))
+ list(model.named_children())
+ list(model.named_buffers(recurse=False))
list(model.named_parameters(recurse=False))
+ list(model.named_children())
+ list(model.named_buffers(recurse=False))
)
# Initialize maximum largest layer, to know which space to keep in memory
max_layer_size, max_layer_names = get_max_layer_size(modules_to_treat, module_sizes, no_split_module_classes)
@@ -1038,7 +1045,7 @@ def infer_auto_device_map(
if verbose:
print(
f"Not enough space on {devices[current_device]} to put {name} (space available "
f"{current_max_size-current_memory_used}, module size {module_size})."
f"{current_max_size - current_memory_used}, module size {module_size})."
)
if len(modules_children) == 0 or module.__class__.__name__ in no_split_module_classes:
# -> no split, we go to the next device
Expand Down Expand Up @@ -1101,7 +1108,7 @@ def infer_auto_device_map(
if verbose:
print(
f"Not enough space on {devices[current_device]} to put {name} and {tied_module_names} (space "
f"available {current_max_size-current_memory_used}, needed size {module_size_with_ties})."
f"available {current_max_size - current_memory_used}, needed size {module_size_with_ties})."
)
split_happened = False
for tied_module_name, tied_module in zip(tied_module_names, tied_modules):
@@ -1117,10 +1124,10 @@ def infer_auto_device_map(
tied_module_index = [i for i, (n, _) in enumerate(modules_to_treat) if n == tied_module_name][0]

modules_to_treat = (
[(name, module)]
+ modules_to_treat[:tied_module_index]
+ tied_module_children
+ modules_to_treat[tied_module_index + 1 :]
[(name, module)]
+ modules_to_treat[:tied_module_index]
+ tied_module_children
+ modules_to_treat[tied_module_index + 1:]
)
# Update the max layer size.
max_layer_size, max_layer_names = get_max_layer_size(
@@ -1146,7 +1153,7 @@ def infer_auto_device_map(
else:
print(
f"Putting {name} (size={module_size}) on {devices[current_device]} "
f"(available={current_max_size-current_memory_used})."
f"(available={current_max_size - current_memory_used})."
)
current_memory_used += module_size
device_map[name] = devices[current_device]
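
Passing verbose=True surfaces the placement messages shown above while the map is built; a toy run with a CPU-only budget:

import torch.nn as nn
from accelerate import infer_auto_device_map

tiny = nn.Sequential(nn.Linear(32, 32), nn.Linear(32, 32))
device_map = infer_auto_device_map(tiny, max_memory={"cpu": "1MB"}, verbose=True)
# prints lines like "Putting 0 (size=...) on cpu (available=...)" as each module is placed
print(device_map)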
@@ -1309,15 +1316,15 @@ def get_state_dict_offloaded_model(model: nn.Module):


def load_checkpoint_in_model(
model: nn.Module,
checkpoint: Union[str, os.PathLike],
device_map: Optional[Dict[str, Union[int, str, torch.device]]] = None,
offload_folder: Optional[Union[str, os.PathLike]] = None,
dtype: Optional[Union[str, torch.dtype]] = None,
offload_state_dict: bool = False,
offload_buffers: bool = False,
keep_in_fp32_modules: List[str] = None,
offload_8bit_bnb: bool = False,
model: nn.Module,
checkpoint: Union[str, os.PathLike],
device_map: Optional[Dict[str, Union[int, str, torch.device]]] = None,
offload_folder: Optional[Union[str, os.PathLike]] = None,
dtype: Optional[Union[str, torch.dtype]] = None,
offload_state_dict: bool = False,
offload_buffers: bool = False,
keep_in_fp32_modules: List[str] = None,
offload_8bit_bnb: bool = False,
):
"""
Loads a (potentially sharded) checkpoint inside a model, potentially sending weights to a given device as they are

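Putting the pieces together, a minimal end-to-end sketch of big-model loading with this change: initialize the model empty, infer a device map (integer entries refer to whatever accelerator backend is present, NPU included), then stream the checkpoint in. The checkpoint path and toy model are illustrative:

import torch
import torch.nn as nn
from accelerate import infer_auto_device_map, init_empty_weights
from accelerate.utils import load_checkpoint_in_model

reference = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 64))
torch.save(reference.state_dict(), "checkpoint.bin")  # illustrative checkpoint on disk

with init_empty_weights():
    model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 64))

device_map = infer_auto_device_map(model)
load_checkpoint_in_model(model, "checkpoint.bin", device_map=device_map)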