Skip to content

Commit

Permalink
style
Browse files Browse the repository at this point in the history
  • Loading branch information
fxmarty committed Dec 5, 2023
1 parent 274c751 commit c4b0847
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions tests/test_modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def __init__(self):
super().__init__()
self.register_buffer("int_param", torch.randint(high=10, size=(15, 30)))
self.register_parameter("float_param", torch.nn.Parameter(torch.rand(10, 5)))

def forward(self, x):
    """Return the input shifted by a constant offset of 2."""
    shifted = x + 2
    return shifted

Expand Down Expand Up @@ -434,7 +434,7 @@ def test_load_checkpoint_in_model_two_gpu(self):
self.assertEqual(model.linear1.weight.device, torch.device(0))
self.assertEqual(model.batchnorm.weight.device, torch.device("cpu"))
self.assertEqual(model.linear2.weight.device, torch.device(1))

def test_load_checkpoint_in_model_dtype(self):
with tempfile.NamedTemporaryFile(suffix=".pt") as tmpfile:
print(tmpfile.name)
Expand All @@ -443,7 +443,9 @@ def test_load_checkpoint_in_model_dtype(self):
torch.save(model.state_dict(), "model.pt")

new_model = ModelSeveralDtypes()
load_checkpoint_in_model(new_model, "model.pt", offload_state_dict=True, dtype=torch.float16, device_map={"": "cpu"})
load_checkpoint_in_model(
new_model, "model.pt", offload_state_dict=True, dtype=torch.float16, device_map={"": "cpu"}
)

self.assertEqual(new_model.int_param.dtype, torch.int64)
self.assertEqual(new_model.float_param.dtype, torch.float16)
Expand Down

0 comments on commit c4b0847

Please sign in to comment.