Skip to content

Commit

Permalink
[Big-Modeling] Harmonize device check to handle corner cases (#2254)
Browse files Browse the repository at this point in the history
* harmonize device check

* make style

* oops

* oops again
  • Loading branch information
younesbelkada authored Dec 14, 2023
1 parent ad3a5bc commit 6b2d968
Showing 1 changed file with 28 additions and 1 deletion.
29 changes: 28 additions & 1 deletion src/accelerate/utils/modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,33 @@
logger = logging.getLogger(__name__)


def check_device_same(first_device, second_device):
    """
    Utility method to check if two `torch` devices refer to the same device. `torch` evaluates
    `torch.device("cuda") == torch.device("cuda:0")` to `False`, whereas both should be treated as the same device.
    The same normalization is applied to every non-CPU device type, not only to CUDA.

    Args:
        first_device (`torch.device`):
            First device to check
        second_device (`torch.device`):
            Second device to check

    Returns:
        `bool`: `True` if both devices have the same type and, once a missing index has been defaulted to `0` for
        non-CPU devices, the same index. `False` otherwise.
    """
    if first_device.type != second_device.type:
        return False

    if first_device.type == "cpu":
        # CPU devices keep torch's native comparison, exactly as before.
        return first_device == second_device

    # For accelerator devices an index of `None` is defaulted to `0`, so that e.g. "cuda" and "cuda:0"
    # compare equal.
    # NOTE(review): this assumes an un-indexed device means index 0; torch resolves it to the *current*
    # device at runtime, which may differ if the current device was changed - confirm against callers.
    first_index = 0 if first_device.index is None else first_device.index
    second_index = 0 if second_device.index is None else second_device.index

    return first_index == second_index


def convert_file_size_to_int(size: Union[int, str]):
"""
Converts a size expressed as a string with digits an unit (like `"5MB"`) to an integer (in bytes).
Expand Down Expand Up @@ -324,7 +351,7 @@ def set_module_tensor_to_device(
device = device_quantization
if is_buffer:
module._buffers[tensor_name] = new_value
elif value is not None or torch.device(device) != module._parameters[tensor_name].device:
elif value is not None or not check_device_same(torch.device(device), module._parameters[tensor_name].device):
param_cls = type(module._parameters[tensor_name])
kwargs = module._parameters[tensor_name].__dict__
if param_cls.__name__ in ["Int8Params", "FP4Params"]:
Expand Down

0 comments on commit 6b2d968

Please sign in to comment.