[Intel XPU] Enable test/quantization UTs on XPU #3249
base: main
Changes from all commits
a736c41
7c96ad4
dde6d27
9d1cc1f
e54bda3
83afd19
463cf5a
9c8e66b
7b5d2c4
@@ -18,13 +18,16 @@
 from torchao._models.llama.tokenizer import get_tokenizer
 from torchao.quantization import Int4WeightOnlyConfig, quantize_
 from torchao.quantization.utils import compute_error
+from torchao.utils import auto_detect_device

 torch.manual_seed(0)

+_DEVICE = auto_detect_device()


 class TestGPTQ(TestCase):
     @unittest.skip("skipping until we get checkpoints for gpt-fast")
-    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
Review comment: just change this to
Reply: Done
+    @unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
     def test_gptq_quantizer_int4_weight_only(self):
         from torchao._models._eval import (
             LMEvalInputRecorder,

@@ -33,7 +36,6 @@ def test_gptq_quantizer_int4_weight_only(self):
         from torchao.quantization.GPTQ import Int4WeightOnlyGPTQQuantizer

         precision = torch.bfloat16
-        device = "cuda"
         checkpoint_path = Path(
             "../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"
         )

@@ -80,19 +82,19 @@ def test_gptq_quantizer_int4_weight_only(self):
         )
         model.setup_caches(max_batch_size=1, max_seq_length=calibration_seq_length)

-        model = quantizer.quantize(model, *inputs).cuda()
+        model = quantizer.quantize(model, *inputs).to(_DEVICE)

         model.reset_caches()
-        with torch.device("cuda"):
+        with torch.device(_DEVICE):
             model.setup_caches(max_batch_size=1, max_seq_length=model.config.block_size)

         limit = 1
         result = TransformerEvalWrapper(
-            model.cuda(),
+            model.to(_DEVICE),
             tokenizer,
             model.config.block_size,
             prepare_inputs_for_model,
-            device,
+            _DEVICE,
         ).run_eval(
             ["wikitext"],
             limit,

@@ -104,7 +106,7 @@ def test_gptq_quantizer_int4_weight_only(self):


 class TestMultiTensorFlow(TestCase):
-    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
Review comment: we don't want to expand the test to CPU, I think.
Reply: Done.
(A sketch of the GPU-only skip condition follows the diff.)
+    @unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
     def test_multitensor_add_tensors(self):
         from torchao.quantization.GPTQ import MultiTensor

@@ -116,7 +118,7 @@ def test_multitensor_add_tensors(self):
         self.assertTrue(torch.equal(mt.values[0], tensor1))
         self.assertTrue(torch.equal(mt.values[1], tensor2))

-    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
     def test_multitensor_pad_unpad(self):
         from torchao.quantization.GPTQ import MultiTensor

@@ -127,7 +129,7 @@ def test_multitensor_pad_unpad(self):
         mt.unpad()
         self.assertEqual(mt.count, 1)

-    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
     def test_multitensor_inplace_operation(self):
         from torchao.quantization.GPTQ import MultiTensor

@@ -138,7 +140,7 @@ def test_multitensor_inplace_operation(self):


 class TestMultiTensorInputRecorder(TestCase):
-    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
     def test_multitensor_input_recorder(self):
         from torchao.quantization.GPTQ import MultiTensor, MultiTensorInputRecorder

@@ -159,7 +161,7 @@ def test_multitensor_input_recorder(self):
         self.assertTrue(isinstance(MT_input[2][2], MultiTensor))
         self.assertEqual(MT_input[3], torch.float)

-    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
     def test_gptq_with_input_recorder(self):
         from torchao.quantization.GPTQ import (
             Int4WeightOnlyGPTQQuantizer,

@@ -170,7 +172,7 @@ def test_gptq_with_input_recorder(self):

         config = ModelArgs(n_layer=2)

-        with torch.device("cuda"):
+        with torch.device(_DEVICE):
             model = Transformer(config)
             model.setup_caches(max_batch_size=2, max_seq_length=100)
             idx = torch.randint(1, 10000, (10, 2, 50)).to(torch.int32)

@@ -191,7 +193,14 @@ def test_gptq_with_input_recorder(self):

         args = input_recorder.get_recorded_inputs()

-        quantizer = Int4WeightOnlyGPTQQuantizer()
+        if _DEVICE.type == "xpu":
+            from torchao.dtypes import Int4XPULayout
+
+            quantizer = Int4WeightOnlyGPTQQuantizer(
+                device=torch.device("xpu"), layout=Int4XPULayout()
+            )
+        else:
+            quantizer = Int4WeightOnlyGPTQQuantizer()

         quantizer.quantize(model, *args)
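Note on the skip condition: the decorator change above swaps `torch.cuda.is_available()` for `torch.accelerator.is_available()`, which in recent PyTorch (2.6+) is True when a non-CPU accelerator such as CUDA or an Intel XPU is present and False on CPU-only machines, so the tests stay GPU-only as requested in the review. Below is a minimal sketch of that gate, assuming only the public `torch.accelerator` API; the class and test names are illustrative, not from the PR.

```python
import unittest

import torch


class AcceleratorGateExample(unittest.TestCase):
    """Illustrative only; mirrors the skip pattern used in the PR."""

    @unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
    def test_runs_only_with_cuda_or_xpu(self):
        # current_accelerator() reports the backend PyTorch will use,
        # e.g. torch.device("cuda") or torch.device("xpu"); CPU is never
        # reported as an accelerator, so this test is skipped on CPU-only hosts.
        dev = torch.accelerator.current_accelerator()
        self.assertIsNotNone(dev)
        self.assertNotEqual(dev.type, "cpu")


if __name__ == "__main__":
    unittest.main()
```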
Review comment: auto_detect_device seems to be changing what we want to test; I think previously we only wanted to test on CUDA. Can you preserve this?
Reply: We have refined the auto_detect_device function, and CPU will not be included.
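The refined helper itself is not shown in this diff. A minimal sketch consistent with the reply above (CUDA and XPU are detected, CPU is not treated as an accelerator) might look like the following; the actual `torchao.utils.auto_detect_device` may differ.

```python
import torch


def auto_detect_device() -> torch.device:
    """Hypothetical sketch, not the real torchao.utils.auto_detect_device.

    Prefers CUDA, then Intel XPU. The CPU return at the end only keeps the
    module-level `_DEVICE = auto_detect_device()` importable on machines
    without a GPU; the `skipIf(not torch.accelerator.is_available(), ...)`
    decorators prevent the tests from actually running there.
    """
    if torch.cuda.is_available():
        return torch.device("cuda")
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return torch.device("xpu")
    return torch.device("cpu")
```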