Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions src/transformers/integrations/accelerate.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,7 +437,6 @@ def accelerate_disk_offload(
checkpoint_files: list[str] | None,
device_map: dict,
sharded_metadata: dict | None,
dtype: torch.dtype | None,
weight_mapping=None,
):
"""
Expand All @@ -460,7 +459,6 @@ def accelerate_disk_offload(
if is_offloaded_safetensors:
meta_state_dict = model.state_dict()
param_device_map = expand_device_map(device_map, meta_state_dict.keys())
str_dtype = str(dtype).replace("torch.", "") if dtype is not None else "float32"
if sharded_metadata is None:
weight_map = dict.fromkeys(safe_open(checkpoint_files[0], framework="pt").keys(), checkpoint_files[0])
else:
Expand All @@ -480,7 +478,7 @@ def accelerate_disk_offload(
target_name: {
"safetensors_file": weight_map[source_name],
"weight_name": source_name,
"dtype": str_dtype,
"dtype": str(meta_state_dict[target_name].dtype).removeprefix("torch."),
}
for target_name, source_name in weight_renaming_map.items()
# Need to check if it's in the mapping in case of unexpected keys that would result in KeyError (we skip them)
Expand Down
1 change: 0 additions & 1 deletion src/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4419,7 +4419,6 @@ def _load_pretrained_model(
checkpoint_files,
load_config.device_map,
load_config.sharded_metadata,
load_config.dtype,
load_config.weight_mapping,
)

Expand Down
34 changes: 34 additions & 0 deletions tests/utils/test_modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1160,6 +1160,40 @@ def test_from_pretrained_disk_offload_derived_to_base_model(self):
outputs2 = new_model_with_offload(inputs)
torch.testing.assert_close(outputs1[0].cpu(), outputs2[0].cpu())

@slow
@require_accelerate
@mark.accelerate_tests
@require_torch_accelerator
def test_disk_onload_dtype(self):
    """Check that a disk-offloaded tensor keeps its own dtype, not the model dtype.

    The test inspects `tid2eid` on the first layer's MLP gate, whose dtype
    (`torch.int64`) differs from `model.dtype` (`torch.bfloat16`), and verifies
    the dtype in three places: on the offloaded tensor, in the offload index
    entry used to read it back from disk, and on the tensor once onloaded.
    """
    from accelerate.utils import align_module_device

    with tempfile.TemporaryDirectory() as offload_dir:
        # Load the model with full disk offloading.
        model = AutoModelForCausalLM.from_pretrained(
            "inference-optimization/DSV4-tiny-empty",
            device_map="auto",
            max_memory={},
            offload_folder=offload_dir,
            offload_buffers=True,
        )
        gate = model.model.layers[0].mlp.gate

        # `model.dtype` and the dtype of `tid2eid` are deliberately different.
        offloaded = gate.tid2eid
        self.assertEqual(offloaded.dtype, torch.int64)
        self.assertEqual(model.dtype, torch.bfloat16)

        # The dtype recorded in the offload index must match the weight itself.
        offload_index = gate._hf_hook.weights_map.dataset.index
        entry = offload_index["model.layers.0.mlp.gate.tid2eid"]
        indexed_dtype = getattr(torch, entry["dtype"])
        self.assertEqual(indexed_dtype, offloaded.dtype)
        self.assertEqual(indexed_dtype, torch.int64)

        # Once onloaded, the tensor must carry the weight dtype, not the model dtype.
        with align_module_device(gate):
            self.assertEqual(gate.tid2eid.dtype, offloaded.dtype)
            self.assertEqual(gate.tid2eid.dtype, torch.int64)

@slow
@require_torch
def test_from_pretrained_non_contiguous_checkpoint(self):
Expand Down
Loading