Skip to content

Commit

Permalink
[tests] fix bug in torch_device (#2909)
Browse files Browse the repository at this point in the history
  • Loading branch information
faaany authored Jul 4, 2024
1 parent 947f64e commit 167cb5e
Showing 1 changed file with 1 addition and 4 deletions.
5 changes: 1 addition & 4 deletions tests/test_modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -748,10 +748,7 @@ def test_load_state_dict(self):

for param, device in device_map.items():
device = device if device != "disk" else "cpu"
expected_device = (
torch.device(f"{torch_device}:{device}") if isinstance(device, int) else torch.device(device)
)
assert loaded_state_dict[param].device == expected_device
assert loaded_state_dict[param].device == torch.device(device)

def test_convert_file_size(self):
result = convert_file_size_to_int("0MB")
Expand Down

0 comments on commit 167cb5e

Please sign in to comment.