@@ -213,6 +213,9 @@ def forward(self, x):
 
 class TestQAT(TestCase):
     SEED = 123
+    DEVICE = torch.accelerator.current_accelerator() if \
+        torch.accelerator.is_available() else \
+        None
 
     def test_fake_quantize_per_channel_group(self):
         n_bit = 4
@@ -347,7 +350,7 @@ def _set_ptq_weight(
                 group_size,
             )
             q_weight = torch.ops.aten._convert_weight_to_int4pack(
-                q_weight.to("cuda"),
+                q_weight.to(self.DEVICE),
                 qat_linear.inner_k_tiles,
             )
             ptq_linear.weight = q_weight
@@ -600,13 +603,13 @@ def _assert_close_4w(self, val, ref):
         print(mean_err)
         self.assertTrue(mean_err < 0.05)
 
-    @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+    @unittest.skipIf(DEVICE == None, "skipping when GPU is not available")
     def test_qat_4w_primitives(self):
         n_bit = 4
         group_size = 32
         inner_k_tiles = 8
         scales_precision = torch.bfloat16
-        device = torch.device("cuda")
+        device = self.DEVICE
         dtype = torch.bfloat16
         torch.manual_seed(self.SEED)
         x = torch.randn(100, 256, dtype=dtype, device=device)
@@ -651,13 +654,13 @@ def test_qat_4w_primitives(self):
 
         self._assert_close_4w(qat_out, ptq_out)
 
-    @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+    @unittest.skipIf(DEVICE == None, "skipping when GPU is not available")
     def test_qat_4w_linear(self):
         from torchao.quantization.GPTQ import WeightOnlyInt4Linear
         from torchao.quantization.qat.linear import Int4WeightOnlyQATLinear
 
         group_size = 128
-        device = torch.device("cuda")
+        device = self.DEVICE
         dtype = torch.bfloat16
         torch.manual_seed(self.SEED)
         qat_linear = Int4WeightOnlyQATLinear(
@@ -692,15 +695,15 @@ def test_qat_4w_quantizer_gradients(self):
         quantizer = Int4WeightOnlyQATQuantizer(groupsize=32, inner_k_tiles=8)
         self._test_qat_quantized_gradients(quantizer)
 
-    @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+    @unittest.skipIf(DEVICE == None, "skipping when GPU is not available")
+    @unittest.skipIf(DEVICE == torch.device("xpu"), "skipped due to https://github.com/intel/torch-xpu-ops/issues/1770")
     def test_qat_4w_quantizer(self):
         from torchao.quantization.GPTQ import Int4WeightOnlyQuantizer
         from torchao.quantization.qat import Int4WeightOnlyQATQuantizer
-
         group_size = 32
         inner_k_tiles = 8
-        device = torch.device("cuda")
         dtype = torch.bfloat16
+        device = self.DEVICE
         torch.manual_seed(self.SEED)
         m = M().to(device).to(dtype)
         m2 = copy.deepcopy(m)
@@ -711,6 +714,7 @@ def test_qat_4w_quantizer(self):
         ptq_quantizer = Int4WeightOnlyQuantizer(
             groupsize=group_size,
             inner_k_tiles=inner_k_tiles,
+            device=device,
         )
         qat_model = qat_quantizer.prepare(m)
         ptq_model = ptq_quantizer.quantize(m2)
@@ -1891,12 +1895,12 @@ def _test_quantize_api_against_ptq(
         torch.manual_seed(self.SEED)
 
         if module_type == "linear":
-            m = M().to(dtype).cuda()
-            example_inputs = (m.example_inputs()[0].to(dtype).cuda(),)
+            m = M().to(dtype).to(self.DEVICE)
+            example_inputs = (m.example_inputs()[0].to(dtype).to(self.DEVICE),)
             filter_fn = lambda m, fqn: isinstance(m, torch.nn.Linear)
         elif module_type == "embedding":
-            m = M3().to(dtype).cuda()
-            example_inputs = (m.example_inputs()[0].cuda(),)
+            m = M3().to(dtype).to(self.DEVICE)
+            example_inputs = (m.example_inputs()[0].to(self.DEVICE),)
             filter_fn = lambda m, fqn: isinstance(m, torch.nn.Embedding)
         else:
             raise ValueError(f"Unknown module type {module_type}")
@@ -1971,7 +1975,7 @@ def test_quantize_api_int4(self, version: int, packing_format: Int4PackingFormat
             target_convert_sqnr=float("inf"),
         )
 
-    @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+    @unittest.skipIf(DEVICE == None, "skipping when GPU is not available")
     def test_quantize_api_int8_int4(self):
         """
         Test the following:
@@ -1984,7 +1988,7 @@ def test_quantize_api_int8_int4(self):
             target_convert_sqnr=float("inf"),
         )
 
-    @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+    @unittest.skipIf(DEVICE == None, "skipping when GPU is not available")
     @parametrize(
         "weight_dtype, weight_granularity, dtype",
         [
@@ -2009,7 +2013,7 @@ def test_quantize_api_int8_intx(self, weight_dtype, weight_granularity, dtype):
             dtype=dtype,
         )
 
-    @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+    @unittest.skipIf(DEVICE == None, "skipping when GPU is not available")
     @parametrize(
         "weight_dtype, granularity, dtype, module_type",
         [
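For context, a minimal standalone sketch (not part of this commit) of the device-agnostic pattern the diff relies on, assuming PyTorch 2.6+ where the torch.accelerator API is available; the test class and test name below are illustrative only:

# Minimal sketch of the device-selection pattern used above.
# Assumes PyTorch >= 2.6, where torch.accelerator exists; DEVICE falls back
# to None so tests can be skipped when no accelerator backend is present.
import unittest

import torch

DEVICE = (
    torch.accelerator.current_accelerator()  # e.g. device(type="cuda") or device(type="xpu")
    if torch.accelerator.is_available()
    else None
)


class DeviceAgnosticExample(unittest.TestCase):
    @unittest.skipIf(DEVICE is None, "skipping when GPU is not available")
    def test_matmul_on_accelerator(self):
        # Tensors land on whichever accelerator backend is present (CUDA, XPU, ...).
        x = torch.randn(8, 8, device=DEVICE)
        y = torch.randn(8, 8, device=DEVICE)
        self.assertEqual((x @ y).device.type, DEVICE.type)


if __name__ == "__main__":
    unittest.main()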