From 6b2d968897c91bc3f96274b4679d84e9950ad908 Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Thu, 14 Dec 2023 15:55:31 +0100 Subject: [PATCH] [`Big-Modeling`] Harmonize device check to handle corner cases (#2254) * harmonize device check * make style * oops * oops again --- src/accelerate/utils/modeling.py | 29 ++++++++++++++++++++++++++++- 1 file changed, 28 insertions(+), 1 deletion(-) diff --git a/src/accelerate/utils/modeling.py b/src/accelerate/utils/modeling.py index 71cfbe166bc..3e9070b9d0e 100644 --- a/src/accelerate/utils/modeling.py +++ b/src/accelerate/utils/modeling.py @@ -47,6 +47,33 @@ logger = logging.getLogger(__name__) +def check_device_same(first_device, second_device): + """ + Utility method to check if two `torch` devices are similar. When dealing with CUDA devices, torch throws `False` + for `torch.device("cuda") == torch.device("cuda:0")` whereas they should be the same + + Args: + first_device (`torch.device`): + First device to check + second_device (`torch.device`): + Second device to check + """ + if first_device.type != second_device.type: + return False + + if first_device.type == "cuda" and first_device.index is None: + # In case the first_device is a cuda device and have + # the index attribute set to `None`, default it to `0` + first_device = torch.device("cuda", index=0) + + if second_device.type == "cuda" and second_device.index is None: + # In case the second_device is a cuda device and have + # the index attribute set to `None`, default it to `0` + second_device = torch.device("cuda", index=0) + + return first_device == second_device + + def convert_file_size_to_int(size: Union[int, str]): """ Converts a size expressed as a string with digits an unit (like `"5MB"`) to an integer (in bytes). @@ -324,7 +351,7 @@ def set_module_tensor_to_device( device = device_quantization if is_buffer: module._buffers[tensor_name] = new_value - elif value is not None or torch.device(device) != module._parameters[tensor_name].device: + elif value is not None or not check_device_same(torch.device(device), module._parameters[tensor_name].device): param_cls = type(module._parameters[tensor_name]) kwargs = module._parameters[tensor_name].__dict__ if param_cls.__name__ in ["Int8Params", "FP4Params"]: