diff --git a/src/accelerate/big_modeling.py b/src/accelerate/big_modeling.py
index fc0dbb32f63..8022ab2caad 100644
--- a/src/accelerate/big_modeling.py
+++ b/src/accelerate/big_modeling.py
@@ -36,6 +36,7 @@
     find_tied_parameters,
     get_balanced_memory,
     infer_auto_device_map,
+    is_npu_available,
     is_torch_version,
     load_checkpoint_in_model,
     offload_state_dict,
@@ -428,10 +429,16 @@ def wrapper(*args, **kwargs):
             return wrapper

         model.to = add_warning(model.to, model)
-        model.cuda = add_warning(model.cuda, model)
+        if is_npu_available():
+            model.npu = add_warning(model.npu, model)
+        else:
+            model.cuda = add_warning(model.cuda, model)

     else:
         device = list(device_map.values())[0]
+        # `torch.Tensor.to()` is not supported by `torch_npu` (see this [issue](https://github.com/Ascend/pytorch/issues/16)).
+        if is_npu_available() and isinstance(device, int):
+            device = f"npu:{device}"
         if device != "disk":
             model.to(device)
         else:
diff --git a/src/accelerate/utils/modeling.py b/src/accelerate/utils/modeling.py
index 38185cda4d4..71cfbe166bc 100644
--- a/src/accelerate/utils/modeling.py
+++ b/src/accelerate/utils/modeling.py
@@ -38,14 +38,12 @@
 if is_npu_available(check_device=False):
     import torch_npu  # noqa: F401

-from safetensors import safe_open
 from safetensors.torch import load_file as safe_load_file


 WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json"

-
 logger = logging.getLogger(__name__)
@@ -221,7 +219,7 @@ def shard_checkpoint(
     weight_map = {}
     shards = {}
     for idx, shard in enumerate(sharded_state_dicts):
-        shard_file = weights_name.replace(".bin", f"-{idx+1:05d}-of-{len(sharded_state_dicts):05d}.bin")
+        shard_file = weights_name.replace(".bin", f"-{idx + 1:05d}-of-{len(sharded_state_dicts):05d}.bin")
         shard_file = shard_file.replace(
             ".safetensors", f"-{idx + 1:05d}-of-{len(sharded_state_dicts):05d}.safetensors"
         )
@@ -307,6 +305,9 @@ def set_module_tensor_to_device(
     ):
         device_quantization = device
         device = "cpu"
+    # `torch.Tensor.to()` is not supported by `torch_npu` (see this [issue](https://github.com/Ascend/pytorch/issues/16)).
+    if is_npu_available() and isinstance(device, int):
+        device = f"npu:{device}"
     if value is None:
         new_value = old_value.to(device)
         if dtype is not None and device in ["meta", torch.device("meta")]:
@@ -364,7 +365,10 @@ def set_module_tensor_to_device(
         if not getattr(module.weight, "quant_state", None) and device_index is not None:
             module.weight = module.weight.cuda(device_index)
     # clean pre and post foward hook
-    torch.cuda.empty_cache()
+    if is_npu_available():
+        torch.npu.empty_cache()
+    else:
+        torch.cuda.empty_cache()


 def named_module_tensors(
@@ -671,19 +675,23 @@ def get_max_memory(max_memory: Optional[Dict[Union[int, str], Union[int, str]]]
     import psutil

     if max_memory is None:
-        if not (torch.cuda.is_available() or is_xpu_available()):
+        if not (torch.cuda.is_available() or is_npu_available() or is_xpu_available()):
             max_memory = {}
         else:
             # Make sure CUDA is initialized on each GPU to have the right memory info.
-            if not is_xpu_available():
-                for i in range(torch.cuda.device_count()):
-                    _ = torch.tensor([0], device=i)
-                max_memory = {i: torch.cuda.mem_get_info(i)[0] for i in range(torch.cuda.device_count())}
-            else:
+            if is_npu_available():
+                for i in range(torch.npu.device_count()):
+                    _ = torch.tensor(0, device=torch.device("npu", i))
+                max_memory = {i: torch.npu.mem_get_info(i)[0] for i in range(torch.npu.device_count())}
+            elif is_xpu_available():
                 for i in range(torch.xpu.device_count()):
                     _ = torch.tensor(0, device=torch.device("xpu", i))
                 max_memory = {i: torch.xpu.max_memory_allocated(i) for i in range(torch.xpu.device_count())}
+            else:
+                for i in range(torch.cuda.device_count()):
+                    _ = torch.tensor([0], device=i)
+                max_memory = {i: torch.cuda.mem_get_info(i)[0] for i in range(torch.cuda.device_count())}

         # allocate everything in the mps device as the RAM is shared
         if is_mps_available():
             max_memory["mps"] = psutil.virtual_memory().available
@@ -696,11 +704,16 @@ def get_max_memory(max_memory: Optional[Dict[Union[int, str], Union[int, str]]]
             max_memory[key] = convert_file_size_to_int(max_memory[key])

     # Need to sort the device by type to make sure that we allocate the gpu first.
-    # As gpu/xpu are represented by int, we need to sort them first.
+    # As gpu/npu/xpu are represented by int, we need to sort them first.
     gpu_devices = [k for k in max_memory.keys() if isinstance(k, int)]
     gpu_devices.sort()
-    # check if gpu/xgpu devices are available and if not, throw a warning
-    num_devices = torch.xpu.device_count() if is_xpu_available() else torch.cuda.device_count()
+    # check if gpu/npu/xpu devices are available and if not, throw a warning
+    if is_npu_available():
+        num_devices = torch.npu.device_count()
+    elif is_xpu_available():
+        num_devices = torch.xpu.device_count()
+    else:
+        num_devices = torch.cuda.device_count()
     for device in gpu_devices:
         if device >= num_devices or device < 0:
             logger.warning(f"Device {device} is not available, available devices are {list(range(num_devices))}")
@@ -808,9 +821,9 @@ def get_balanced_memory(
     user_not_set_max_memory = max_memory is None
     max_memory = get_max_memory(max_memory)

-    if not is_xpu_available():
-        num_devices = len([d for d in max_memory if torch.device(d).type == "cuda" and max_memory[d] > 0])
-    else:
+    if is_npu_available():
+        num_devices = len([d for d in max_memory if torch.device(d).type == "npu" and max_memory[d] > 0])
+    elif is_xpu_available():
         num_devices = len(
             [
                 d
@@ -822,6 +835,8 @@
                 and max_memory[d] > 0
             ]
         )
+    else:
+        num_devices = len([d for d in max_memory if torch.device(d).type == "cuda" and max_memory[d] > 0])

     if num_devices == 0:
         return max_memory
@@ -1043,7 +1058,7 @@ def infer_auto_device_map(
             if verbose:
                 print(
                     f"Not enough space on {devices[current_device]} to put {name} (space available "
-                    f"{current_max_size-current_memory_used}, module size {module_size})."
+                    f"{current_max_size - current_memory_used}, module size {module_size})."
                 )
             if len(modules_children) == 0 or module.__class__.__name__ in no_split_module_classes:
                 # -> no split, we go to the next device
@@ -1106,7 +1121,7 @@ def infer_auto_device_map(
                 if verbose:
                     print(
                         f"Not enough space on {devices[current_device]} to put {name} and {tied_module_names} (space "
-                        f"available {current_max_size-current_memory_used}, needed size {module_size_with_ties})."
+                        f"available {current_max_size - current_memory_used}, needed size {module_size_with_ties})."
                     )
                 split_happened = False
                 for tied_module_name, tied_module in zip(tied_module_names, tied_modules):
@@ -1151,7 +1166,7 @@ def infer_auto_device_map(
                 else:
                     print(
                         f"Putting {name} (size={module_size}) on {devices[current_device]} "
-                        f"(available={current_max_size-current_memory_used})."
+                        f"(available={current_max_size - current_memory_used})."
                     )
             current_memory_used += module_size
             device_map[name] = devices[current_device]
diff --git a/src/accelerate/utils/operations.py b/src/accelerate/utils/operations.py
index 0330a9c6078..c8556a63e48 100644
--- a/src/accelerate/utils/operations.py
+++ b/src/accelerate/utils/operations.py
@@ -26,7 +26,7 @@
 from ..state import PartialState
 from .constants import TORCH_DISTRIBUTED_OPERATION_TYPES
 from .dataclasses import DistributedType, TensorInformation
-from .imports import is_torch_distributed_available, is_torch_version, is_tpu_available
+from .imports import is_npu_available, is_torch_distributed_available, is_torch_version, is_tpu_available


 if is_tpu_available(check_device=False):
@@ -164,6 +164,9 @@ def send_to_device(tensor, device, non_blocking=False, skip_keys=None):
             }
         )
     elif hasattr(tensor, "to"):
+        # `torch.Tensor.to()` is not supported by `torch_npu` (see this [issue](https://github.com/Ascend/pytorch/issues/16)).
+        if is_npu_available() and isinstance(device, int):
+            device = f"npu:{device}"
         try:
             return tensor.to(device, non_blocking=non_blocking)
         except TypeError:  # .to() doesn't accept non_blocking as kwarg
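
The `npu:{device}` rewrites above (in `big_modeling.py`, `set_module_tensor_to_device`, and `send_to_device`) all apply the same workaround: `torch_npu` does not accept a bare integer device id in `Tensor.to()`, so integers are rewritten as explicit `"npu:<index>"` strings whenever an Ascend NPU is detected. A minimal sketch of the pattern, assuming an `accelerate` build that exports `is_npu_available`; `normalize_device` is a hypothetical helper used only for illustration, not part of this patch:

```python
import torch

from accelerate.utils import is_npu_available


def normalize_device(device):
    # torch_npu rejects a bare int in Tensor.to(), so map integer device ids
    # to explicit "npu:<index>" strings on Ascend machines (hypothetical helper).
    if is_npu_available() and isinstance(device, int):
        return f"npu:{device}"
    return device


x = torch.ones(2, 2)
device = normalize_device(0)  # "npu:0" on an NPU host, plain 0 elsewhere
if is_npu_available() or torch.cuda.is_available():
    x = x.to(device)
```

On CUDA or CPU-only machines the check is a no-op, so the existing integer-device behaviour is unchanged.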
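
`get_max_memory`, `get_balanced_memory`, and the device-availability warning now all pick the accelerator backend in the same priority order: NPU, then XPU, then CUDA. A rough sketch of that ordering, again assuming `is_npu_available` and `is_xpu_available` are importable from `accelerate.utils`; `visible_devices` is an illustrative helper only:

```python
import torch

from accelerate.utils import is_npu_available, is_xpu_available


def visible_devices():
    # Same priority order the patch uses: Ascend NPU first, then Intel XPU, then CUDA.
    # torch.npu / torch.xpu only exist when torch_npu / intel_extension_for_pytorch are
    # installed, which is what the availability checks verify before these branches run.
    if is_npu_available():
        return "npu", torch.npu.device_count()
    if is_xpu_available():
        return "xpu", torch.xpu.device_count()
    return "cuda", torch.cuda.device_count()


backend, count = visible_devices()
print(f"{count} {backend} device(s) visible to accelerate")
```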