From 92404fbf5fe48a01b5e2f33fa33e74b994445f76 Mon Sep 17 00:00:00 2001
From: Fanli Lin
Date: Wed, 3 Jul 2024 18:36:36 +0800
Subject: [PATCH] fix `load_state_dict` for xpu and refine xpu safetensor
 version check (#2879)

* add fix

* update warning

* no and
---
 src/accelerate/utils/modeling.py | 22 +++++++++++-----------
 tests/test_modeling_utils.py     |  7 +++++--
 2 files changed, 16 insertions(+), 13 deletions(-)

diff --git a/src/accelerate/utils/modeling.py b/src/accelerate/utils/modeling.py
index 396377bf4f2..b57b476df41 100644
--- a/src/accelerate/utils/modeling.py
+++ b/src/accelerate/utils/modeling.py
@@ -14,7 +14,6 @@
 
 import contextlib
 import gc
-import importlib
 import inspect
 import json
 import logging
@@ -26,7 +25,6 @@
 from collections import OrderedDict, defaultdict
 from typing import Dict, List, Optional, Tuple, Union
 
-import packaging
 import torch
 import torch.nn as nn
 
@@ -1456,7 +1454,15 @@ def load_state_dict(checkpoint_file, device_map=None):
         else:
             # if we only have one device we can load everything directly
             if len(set(device_map.values())) == 1:
-                return safe_load_file(checkpoint_file, device=list(device_map.values())[0])
+                device = list(device_map.values())[0]
+                target_device = device
+                if is_xpu_available():
+                    if compare_versions("safetensors", "<", "0.4.2"):
+                        raise ImportError("Safetensors version must be >= 0.4.2 for XPU. Please upgrade safetensors.")
+                    if isinstance(device, int):
+                        target_device = f"xpu:{device}"
+
+                return safe_load_file(checkpoint_file, device=target_device)
 
             devices = list(set(device_map.values()) - {"disk"})
             # cpu device should always exist as fallback option
@@ -1486,15 +1492,9 @@
             progress_bar = None
         for device in devices:
             target_device = device
-
             if is_xpu_available():
-                current_safetensors_version = packaging.version.parse(importlib.metadata.version("safetensors"))
-
-                if compare_versions(current_safetensors_version, "<", "0.4.2"):
-                    raise ModuleNotFoundError(
-                        f"You need at least safetensors 0.4.2 for Intel GPU, while you have {current_safetensors_version}"
-                    )
-
+                if compare_versions("safetensors", "<", "0.4.2"):
+                    raise ImportError("Safetensors version must be >= 0.4.2 for XPU. Please upgrade safetensors.")
                 if isinstance(device, int):
                     target_device = f"xpu:{device}"
 
diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py
index 27605cb0f50..0cb2b152d9b 100644
--- a/tests/test_modeling_utils.py
+++ b/tests/test_modeling_utils.py
@@ -734,7 +734,7 @@ def test_get_balanced_memory(self):
         max_memory = get_balanced_memory(model, max_memory={0: 0, "cpu": 100})
         assert {0: 0, "cpu": 100} == max_memory
 
-    @require_cuda
+    @require_non_cpu
     def test_load_state_dict(self):
         state_dict = {k: torch.randn(4, 5) for k in ["a", "b", "c"]}
         device_maps = [{"a": "cpu", "b": 0, "c": "disk"}, {"a": 0, "b": 0, "c": "disk"}, {"a": 0, "b": 0, "c": 0}]
@@ -748,7 +748,10 @@ def test_load_state_dict(self):
 
             for param, device in device_map.items():
                 device = device if device != "disk" else "cpu"
-                assert loaded_state_dict[param].device == torch.device(device)
+                expected_device = (
+                    torch.device(f"{torch_device}:{device}") if isinstance(device, int) else torch.device(device)
+                )
+                assert loaded_state_dict[param].device == expected_device
 
     def test_convert_file_size(self):
         result = convert_file_size_to_int("0MB")