Skip to content

Commit 176a17b

Browse files
authored
[Offloading] [Bugfix] Fix disk offloading of models with explicit tensor dtypes (#46849)
* use meta tensor dtype Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> * add test Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> --------- Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent 4ab0479 commit 176a17b

3 files changed

Lines changed: 35 additions & 4 deletions

File tree

src/transformers/integrations/accelerate.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -437,7 +437,6 @@ def accelerate_disk_offload(
437437
checkpoint_files: list[str] | None,
438438
device_map: dict,
439439
sharded_metadata: dict | None,
440-
dtype: torch.dtype | None,
441440
weight_mapping=None,
442441
):
443442
"""
@@ -460,7 +459,6 @@ def accelerate_disk_offload(
460459
if is_offloaded_safetensors:
461460
meta_state_dict = model.state_dict()
462461
param_device_map = expand_device_map(device_map, meta_state_dict.keys())
463-
str_dtype = str(dtype).replace("torch.", "") if dtype is not None else "float32"
464462
if sharded_metadata is None:
465463
weight_map = dict.fromkeys(safe_open(checkpoint_files[0], framework="pt").keys(), checkpoint_files[0])
466464
else:
@@ -480,7 +478,7 @@ def accelerate_disk_offload(
480478
target_name: {
481479
"safetensors_file": weight_map[source_name],
482480
"weight_name": source_name,
483-
"dtype": str_dtype,
481+
"dtype": str(meta_state_dict[target_name].dtype).removeprefix("torch."),
484482
}
485483
for target_name, source_name in weight_renaming_map.items()
486484
# Need to check if it's in the mapping in case of unexpected keys that would result in KeyError (we skip them)

src/transformers/modeling_utils.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4419,7 +4419,6 @@ def _load_pretrained_model(
44194419
checkpoint_files,
44204420
load_config.device_map,
44214421
load_config.sharded_metadata,
4422-
load_config.dtype,
44234422
load_config.weight_mapping,
44244423
)
44254424

tests/utils/test_modeling_utils.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1160,6 +1160,40 @@ def test_from_pretrained_disk_offload_derived_to_base_model(self):
11601160
outputs2 = new_model_with_offload(inputs)
11611161
torch.testing.assert_close(outputs1[0].cpu(), outputs2[0].cpu())
11621162

1163+
@slow
1164+
@require_accelerate
1165+
@mark.accelerate_tests
1166+
@require_torch_accelerator
1167+
def test_disk_onload_dtype(self):
1168+
from accelerate.utils import align_module_device
1169+
1170+
with tempfile.TemporaryDirectory() as tmp_dir:
1171+
# load model with full disk offloading
1172+
model = AutoModelForCausalLM.from_pretrained(
1173+
"inference-optimization/DSV4-tiny-empty",
1174+
device_map="auto",
1175+
max_memory={},
1176+
offload_folder=tmp_dir,
1177+
offload_buffers=True,
1178+
)
1179+
1180+
# note that `model.dtype` and the dtype of `tid2eid` differ
1181+
offloaded = model.model.layers[0].mlp.gate.tid2eid
1182+
self.assertEqual(offloaded.dtype, torch.int64)
1183+
self.assertEqual(model.dtype, torch.bfloat16)
1184+
1185+
# the dtype used to load from disk should match
1186+
index = model.model.layers[0].mlp.gate._hf_hook.weights_map.dataset.index
1187+
weights_info = index["model.layers.0.mlp.gate.tid2eid"]
1188+
weights_info_type = getattr(torch, weights_info["dtype"])
1189+
self.assertEqual(weights_info_type, offloaded.dtype)
1190+
self.assertEqual(weights_info_type, torch.int64)
1191+
1192+
# the onloaded dtype should be the weight dtype, not the model dtype
1193+
with align_module_device(model.model.layers[0].mlp.gate):
1194+
self.assertEqual(model.model.layers[0].mlp.gate.tid2eid.dtype, offloaded.dtype)
1195+
self.assertEqual(model.model.layers[0].mlp.gate.tid2eid.dtype, torch.int64)
1196+
11631197
@slow
11641198
@require_torch
11651199
def test_from_pretrained_non_contiguous_checkpoint(self):

0 commit comments

Comments
 (0)