Add npu support to big model inference (#2222)
* Add npu support to big model inference

* make style

* add warning when using npu

* fix typo

* replace `.to(<num>)` with `.to("npu:<num>")` when using `torch_npu`

* empty_cache

* fix
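
The change works around a `torch_npu` limitation: `torch.Tensor.to(<int num>)` is not supported (see https://github.com/Ascend/pytorch/issues/16), so integer device ids must be spelled out as `"npu:<num>"`. A minimal sketch of the normalization the commit applies in several places (`maybe_npu_device` is an illustrative name, not a helper added by this commit):

```python
from accelerate.utils import is_npu_available


def maybe_npu_device(device):
    # `torch.Tensor.to(<int num>)` is not supported by `torch_npu`
    # (https://github.com/Ascend/pytorch/issues/16), so rewrite bare
    # integer ids as explicit "npu:<num>" strings on Ascend machines.
    if is_npu_available() and isinstance(device, int):
        return f"npu:{device}"
    return device


print(maybe_npu_device(0))  # "npu:0" with torch_npu, plain 0 elsewhere
```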
ji-huazhong authored Dec 8, 2023
1 parent f86876d commit 9964f90
Showing 3 changed files with 46 additions and 21 deletions.
9 changes: 8 additions & 1 deletion src/accelerate/big_modeling.py
@@ -36,6 +36,7 @@
     find_tied_parameters,
     get_balanced_memory,
     infer_auto_device_map,
+    is_npu_available,
     is_torch_version,
     load_checkpoint_in_model,
     offload_state_dict,
@@ -428,10 +429,16 @@ def wrapper(*args, **kwargs):
             return wrapper
 
         model.to = add_warning(model.to, model)
-        model.cuda = add_warning(model.cuda, model)
+        if is_npu_available():
+            model.npu = add_warning(model.npu, model)
+        else:
+            model.cuda = add_warning(model.cuda, model)
 
     else:
         device = list(device_map.values())[0]
+        # `torch.Tensor.to(<int num>)` is not supported by `torch_npu` (see this [issue](https://github.com/Ascend/pytorch/issues/16)).
+        if is_npu_available() and isinstance(device, int):
+            device = f"npu:{device}"
         if device != "disk":
             model.to(device)
         else:
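
For context, a minimal sketch of the path this file patches, assuming an Ascend machine with `torch_npu` installed (the toy model and device map are illustrative, not from the commit). With a single distinct device, `dispatch_model` simply moves the model, now rewriting the integer id to `"npu:0"` first:

```python
import torch
from torch import nn

from accelerate import dispatch_model
from accelerate.utils import is_npu_available

model = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 4))

if is_npu_available():
    # One distinct device -> no hooks are attached; dispatch_model calls
    # model.to("npu:0") instead of passing the bare int 0 to .to().
    model = dispatch_model(model, device_map={"": 0})
    print(model(torch.ones(1, 4, device="npu:0")))
```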
53 changes: 34 additions & 19 deletions src/accelerate/utils/modeling.py
@@ -38,14 +38,12 @@
 if is_npu_available(check_device=False):
     import torch_npu  # noqa: F401
 
-
 from safetensors import safe_open
 from safetensors.torch import load_file as safe_load_file
 
 
 WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json"
 
 
 logger = logging.getLogger(__name__)
 
-
@@ -221,7 +219,7 @@ def shard_checkpoint(
     weight_map = {}
     shards = {}
     for idx, shard in enumerate(sharded_state_dicts):
-        shard_file = weights_name.replace(".bin", f"-{idx+1:05d}-of-{len(sharded_state_dicts):05d}.bin")
+        shard_file = weights_name.replace(".bin", f"-{idx + 1:05d}-of-{len(sharded_state_dicts):05d}.bin")
         shard_file = shard_file.replace(
             ".safetensors", f"-{idx + 1:05d}-of-{len(sharded_state_dicts):05d}.safetensors"
         )
@@ -307,6 +305,9 @@ def set_module_tensor_to_device(
     ):
         device_quantization = device
         device = "cpu"
+    # `torch.Tensor.to(<int num>)` is not supported by `torch_npu` (see this [issue](https://github.com/Ascend/pytorch/issues/16)).
+    if is_npu_available() and isinstance(device, int):
+        device = f"npu:{device}"
     if value is None:
         new_value = old_value.to(device)
         if dtype is not None and device in ["meta", torch.device("meta")]:
@@ -364,7 +365,10 @@ def set_module_tensor_to_device(
         if not getattr(module.weight, "quant_state", None) and device_index is not None:
             module.weight = module.weight.cuda(device_index)
     # clean pre and post forward hook
-    torch.cuda.empty_cache()
+    if is_npu_available():
+        torch.npu.empty_cache()
+    else:
+        torch.cuda.empty_cache()
 
 
 def named_module_tensors(
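
The `empty_cache` switch above is the usual backend dispatch; a standalone sketch of the same pattern (`clear_device_cache` is an illustrative name, not a helper from this commit):

```python
import torch

from accelerate.utils import is_npu_available


def clear_device_cache():
    # `torch.npu` exists only once `torch_npu` has been imported, which
    # is_npu_available() already checks for.
    if is_npu_available():
        torch.npu.empty_cache()
    elif torch.cuda.is_available():
        torch.cuda.empty_cache()
```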
@@ -671,19 +675,23 @@ def get_max_memory(max_memory: Optional[Dict[Union[int, str], Union[int, str]]]
     import psutil
 
     if max_memory is None:
-        if not (torch.cuda.is_available() or is_xpu_available()):
+        if not (torch.cuda.is_available() or is_npu_available() or is_xpu_available()):
             max_memory = {}
 
         else:
             # Make sure CUDA is initialized on each GPU to have the right memory info.
-            if not is_xpu_available():
-                for i in range(torch.cuda.device_count()):
-                    _ = torch.tensor([0], device=i)
-                max_memory = {i: torch.cuda.mem_get_info(i)[0] for i in range(torch.cuda.device_count())}
-            else:
+            if is_npu_available():
+                for i in range(torch.npu.device_count()):
+                    _ = torch.tensor(0, device=torch.device("npu", i))
+                max_memory = {i: torch.npu.mem_get_info(i)[0] for i in range(torch.npu.device_count())}
+            elif is_xpu_available():
                 for i in range(torch.xpu.device_count()):
                     _ = torch.tensor(0, device=torch.device("xpu", i))
                 max_memory = {i: torch.xpu.max_memory_allocated(i) for i in range(torch.xpu.device_count())}
+            else:
+                for i in range(torch.cuda.device_count()):
+                    _ = torch.tensor([0], device=i)
+                max_memory = {i: torch.cuda.mem_get_info(i)[0] for i in range(torch.cuda.device_count())}
         # allocate everything in the mps device as the RAM is shared
         if is_mps_available():
             max_memory["mps"] = psutil.virtual_memory().available
@@ -696,11 +704,16 @@ def get_max_memory(max_memory: Optional[Dict[Union[int, str], Union[int, str]]]
             max_memory[key] = convert_file_size_to_int(max_memory[key])
 
     # Need to sort the device by type to make sure that we allocate the gpu first.
-    # As gpu/xpu are represented by int, we need to sort them first.
+    # As gpu/npu/xpu are represented by int, we need to sort them first.
     gpu_devices = [k for k in max_memory.keys() if isinstance(k, int)]
     gpu_devices.sort()
-    # check if gpu/xgpu devices are available and if not, throw a warning
-    num_devices = torch.xpu.device_count() if is_xpu_available() else torch.cuda.device_count()
+    # check if gpu/npu/xpu devices are available and if not, throw a warning
+    if is_npu_available():
+        num_devices = torch.npu.device_count()
+    elif is_xpu_available():
+        num_devices = torch.xpu.device_count()
+    else:
+        num_devices = torch.cuda.device_count()
     for device in gpu_devices:
         if device >= num_devices or device < 0:
             logger.warning(f"Device {device} is not available, available devices are {list(range(num_devices))}")
@@ -808,9 +821,9 @@ def get_balanced_memory(
     user_not_set_max_memory = max_memory is None
     max_memory = get_max_memory(max_memory)
 
-    if not is_xpu_available():
-        num_devices = len([d for d in max_memory if torch.device(d).type == "cuda" and max_memory[d] > 0])
-    else:
+    if is_npu_available():
+        num_devices = len([d for d in max_memory if torch.device(d).type == "npu" and max_memory[d] > 0])
+    elif is_xpu_available():
         num_devices = len(
             [
                 d
@@ -822,6 +835,8 @@
                 and max_memory[d] > 0
             ]
         )
+    else:
+        num_devices = len([d for d in max_memory if torch.device(d).type == "cuda" and max_memory[d] > 0])
 
     if num_devices == 0:
         return max_memory
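
With the device count fixed, the balanced-memory path works end to end on NPUs. A hedged sketch (toy model, illustrative layout):

```python
from torch import nn

from accelerate.utils import get_balanced_memory, infer_auto_device_map

model = nn.Sequential(*[nn.Linear(256, 256) for _ in range(8)])

# On a 2-NPU host, devices of type "npu" are now counted, so the layers
# get balanced across npu:0 and npu:1; elsewhere the cuda/xpu/cpu
# behaviour is unchanged.
max_memory = get_balanced_memory(model, no_split_module_classes=[])
device_map = infer_auto_device_map(model, max_memory=max_memory)
print(device_map)
```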
@@ -1043,7 +1058,7 @@ def infer_auto_device_map(
             if verbose:
                 print(
                     f"Not enough space on {devices[current_device]} to put {name} (space available "
-                    f"{current_max_size-current_memory_used}, module size {module_size})."
+                    f"{current_max_size - current_memory_used}, module size {module_size})."
                 )
             if len(modules_children) == 0 or module.__class__.__name__ in no_split_module_classes:
                 # -> no split, we go to the next device
@@ -1106,7 +1121,7 @@ def infer_auto_device_map(
             if verbose:
                 print(
                     f"Not enough space on {devices[current_device]} to put {name} and {tied_module_names} (space "
-                    f"available {current_max_size-current_memory_used}, needed size {module_size_with_ties})."
+                    f"available {current_max_size - current_memory_used}, needed size {module_size_with_ties})."
                 )
             split_happened = False
             for tied_module_name, tied_module in zip(tied_module_names, tied_modules):
@@ -1151,7 +1166,7 @@ def infer_auto_device_map(
                 else:
                     print(
                         f"Putting {name} (size={module_size}) on {devices[current_device]} "
-                        f"(available={current_max_size-current_memory_used})."
+                        f"(available={current_max_size - current_memory_used})."
                     )
             current_memory_used += module_size
             device_map[name] = devices[current_device]
5 changes: 4 additions & 1 deletion src/accelerate/utils/operations.py
@@ -26,7 +26,7 @@
 from ..state import PartialState
 from .constants import TORCH_DISTRIBUTED_OPERATION_TYPES
 from .dataclasses import DistributedType, TensorInformation
-from .imports import is_torch_distributed_available, is_torch_version, is_tpu_available
+from .imports import is_npu_available, is_torch_distributed_available, is_torch_version, is_tpu_available
 
 
 if is_tpu_available(check_device=False):
@@ -164,6 +164,9 @@ def send_to_device(tensor, device, non_blocking=False, skip_keys=None):
             }
         )
     elif hasattr(tensor, "to"):
+        # `torch.Tensor.to(<int num>)` is not supported by `torch_npu` (see this [issue](https://github.com/Ascend/pytorch/issues/16)).
+        if is_npu_available() and isinstance(device, int):
+            device = f"npu:{device}"
         try:
             return tensor.to(device, non_blocking=non_blocking)
         except TypeError:  # .to() doesn't accept non_blocking as kwarg
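
A short usage sketch, assuming an Ascend environment: with the guard in place, callers can keep passing bare integer device ids to `send_to_device`, and the int is rewritten before `.to` is reached.

```python
import torch

from accelerate.utils import is_npu_available, send_to_device

batch = {"input_ids": torch.ones(1, 8, dtype=torch.long), "labels": torch.zeros(1, 8)}

if is_npu_available():
    batch = send_to_device(batch, 0)  # 0 is normalized to "npu:0" internally
    print(batch["input_ids"].device)  # npu:0
```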
