Skip to content

Commit 4d5a96e

Browse files
authored
fix autocast (#11190)
Signed-off-by: jiqing-feng <[email protected]>
1 parent a7f07c1 commit 4d5a96e

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

tests/quantization/bnb/test_mixed_int8.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,7 @@ def test_keep_modules_in_fp32(self):
221221
self.assertTrue(module.weight.dtype == torch.int8)
222222

223223
# test if inference works.
224-
with torch.no_grad() and torch.amp.autocast("cuda", dtype=torch.float16):
224+
with torch.no_grad() and torch.autocast(model.device.type, dtype=torch.float16):
225225
input_dict_for_transformer = self.get_dummy_inputs()
226226
model_inputs = {
227227
k: v.to(device=torch_device) for k, v in input_dict_for_transformer.items() if not isinstance(v, bool)

0 commit comments

Comments
 (0)