Commit 3bad56f
[Intel GPU] Extend TestQAT module with xpu testcases
Add an xpu mode to the tests in the TestQAT module of test_qat.py
1 parent 0f05b40 commit 3bad56f
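
The commit replaces hard-coded "cuda" device references with a class-level DEVICE attribute resolved through PyTorch's device-agnostic torch.accelerator API, so the same tests can run on CUDA and Intel XPU. A minimal sketch of that selection pattern, assuming a PyTorch build recent enough to ship torch.accelerator:

import torch

# current_accelerator() returns a torch.device such as device(type="cuda")
# or device(type="xpu"); is_available() reports whether an accelerator
# backend is usable at runtime. On a CPU-only machine DEVICE stays None,
# which the skipIf decorators in the diff use to skip the GPU-only tests.
DEVICE = (
    torch.accelerator.current_accelerator()
    if torch.accelerator.is_available()
    else None
)

x = torch.randn(4, 4, device=DEVICE)  # device=None falls back to the default (CPU)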

File tree: 1 file changed (+19, -15)

test/quantization/test_qat.py (19 additions, 15 deletions)
@@ -213,6 +213,9 @@ def forward(self, x):

 class TestQAT(TestCase):
     SEED = 123
+    DEVICE = torch.accelerator.current_accelerator() if \
+        torch.accelerator.is_available() else \
+        None

     def test_fake_quantize_per_channel_group(self):
         n_bit = 4
@@ -347,7 +350,7 @@ def _set_ptq_weight(
             group_size,
         )
         q_weight = torch.ops.aten._convert_weight_to_int4pack(
-            q_weight.to("cuda"),
+            q_weight.to(self.DEVICE),
             qat_linear.inner_k_tiles,
         )
         ptq_linear.weight = q_weight
@@ -600,13 +603,13 @@ def _assert_close_4w(self, val, ref):
         print(mean_err)
         self.assertTrue(mean_err < 0.05)

-    @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+    @unittest.skipIf(DEVICE == None, "skipping when GPU is not available")
     def test_qat_4w_primitives(self):
         n_bit = 4
         group_size = 32
         inner_k_tiles = 8
         scales_precision = torch.bfloat16
-        device = torch.device("cuda")
+        device = self.DEVICE
         dtype = torch.bfloat16
         torch.manual_seed(self.SEED)
         x = torch.randn(100, 256, dtype=dtype, device=device)
@@ -651,13 +654,13 @@ def test_qat_4w_primitives(self):

         self._assert_close_4w(qat_out, ptq_out)

-    @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+    @unittest.skipIf(DEVICE == None, "skipping when GPU is not available")
     def test_qat_4w_linear(self):
         from torchao.quantization.GPTQ import WeightOnlyInt4Linear
         from torchao.quantization.qat.linear import Int4WeightOnlyQATLinear

         group_size = 128
-        device = torch.device("cuda")
+        device = self.DEVICE
         dtype = torch.bfloat16
         torch.manual_seed(self.SEED)
         qat_linear = Int4WeightOnlyQATLinear(
@@ -692,15 +695,15 @@ def test_qat_4w_quantizer_gradients(self):
         quantizer = Int4WeightOnlyQATQuantizer(groupsize=32, inner_k_tiles=8)
         self._test_qat_quantized_gradients(quantizer)

-    @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+    @unittest.skipIf(DEVICE == None, "skipping when GPU is not available")
+    @unittest.skipIf(DEVICE == torch.device("xpu"), "skipped due to https://github.com/intel/torch-xpu-ops/issues/1770")
     def test_qat_4w_quantizer(self):
         from torchao.quantization.GPTQ import Int4WeightOnlyQuantizer
         from torchao.quantization.qat import Int4WeightOnlyQATQuantizer
-
         group_size = 32
         inner_k_tiles = 8
-        device = torch.device("cuda")
         dtype = torch.bfloat16
+        device = self.DEVICE
         torch.manual_seed(self.SEED)
         m = M().to(device).to(dtype)
         m2 = copy.deepcopy(m)
@@ -711,6 +714,7 @@ def test_qat_4w_quantizer(self):
         ptq_quantizer = Int4WeightOnlyQuantizer(
             groupsize=group_size,
             inner_k_tiles=inner_k_tiles,
+            device=device
         )
         qat_model = qat_quantizer.prepare(m)
         ptq_model = ptq_quantizer.quantize(m2)
@@ -1891,12 +1895,12 @@ def _test_quantize_api_against_ptq(
         torch.manual_seed(self.SEED)

         if module_type == "linear":
-            m = M().to(dtype).cuda()
-            example_inputs = (m.example_inputs()[0].to(dtype).cuda(),)
+            m = M().to(dtype).to(self.DEVICE)
+            example_inputs = (m.example_inputs()[0].to(dtype).to(self.DEVICE),)
             filter_fn = lambda m, fqn: isinstance(m, torch.nn.Linear)
         elif module_type == "embedding":
-            m = M3().to(dtype).cuda()
-            example_inputs = (m.example_inputs()[0].cuda(),)
+            m = M3().to(dtype).to(self.DEVICE)
+            example_inputs = (m.example_inputs()[0].to(self.DEVICE),)
             filter_fn = lambda m, fqn: isinstance(m, torch.nn.Embedding)
         else:
             raise ValueError(f"Unknown module type {module_type}")
@@ -1971,7 +1975,7 @@ def test_quantize_api_int4(self, version: int, packing_format: Int4PackingFormat
             target_convert_sqnr=float("inf"),
         )

-    @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+    @unittest.skipIf(DEVICE == None, "skipping when GPU is not available")
     def test_quantize_api_int8_int4(self):
         """
         Test the following:
@@ -1984,7 +1988,7 @@ def test_quantize_api_int8_int4(self):
             target_convert_sqnr=float("inf"),
         )

-    @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+    @unittest.skipIf(DEVICE == None, "skipping when GPU is not available")
     @parametrize(
         "weight_dtype, weight_granularity, dtype",
         [
@@ -2009,7 +2013,7 @@ def test_quantize_api_int8_intx(self, weight_dtype, weight_granularity, dtype):
             dtype=dtype,
         )

-    @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+    @unittest.skipIf(DEVICE == None, "skipping when GPU is not available")
     @parametrize(
         "weight_dtype, granularity, dtype, module_type",
         [
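
Two details of the new skip logic are worth noting. First, the decorators reference DEVICE as a bare name rather than self.DEVICE, which works because skipIf arguments are evaluated while the class body is still executing. Second, the XPU skip relies on torch.device equality, which compares device type and index. An illustrative check, assuming current_accelerator() returns an index-less device as it does in current PyTorch:

import torch

# Equality compares (type, index); an index-less device only equals
# another index-less device of the same type.
assert torch.device("xpu") == torch.device("xpu")
assert torch.device("xpu") != torch.device("xpu", 0)

So the skipIf(DEVICE == torch.device("xpu"), ...) guard fires on XPU machines, where current_accelerator() yields device(type="xpu"), and stays inactive elsewhere.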
