# LICENSE file in the root directory of this source tree.
import copy
import tempfile
- import unittest

import torch
+ from parameterized import parameterized
from torch.testing._internal.common_utils import (
    TestCase,
    run_tests,
)

from torchao.prototype.awq import AWQConfig, AWQStep
from torchao.quantization import Int4WeightOnlyConfig, quantize_
- from torchao.utils import _is_fbgemm_genai_gpu_available
+ from torchao.utils import _is_fbgemm_genai_gpu_available, torch_version_at_least


class ToyLinearModel(torch.nn.Module):
@@ -42,11 +42,15 @@ def forward(self, x):
        return x


- @unittest.skipIf(not torch.cuda.is_available(), reason="CUDA not available")
- @unittest.skipIf(
-     not _is_fbgemm_genai_gpu_available(),
-     reason="need to install fbgemm_gpu_genai package",
- )
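+ # always test on CPU; add CUDA only when fbgemm_genai int4 kernels and torch >= 2.6.0 are available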
+ devices = ["cpu"]
+ if (
+     torch.cuda.is_available()
+     and _is_fbgemm_genai_gpu_available()
+     and torch_version_at_least("2.6.0")
+ ):
+     devices.append("cuda")
+
+
class TestAWQ(TestCase):
    def test_awq_config(self):
        base_config = Int4WeightOnlyConfig()
@@ -61,8 +65,8 @@ def test_awq_config(self):
        with self.assertRaisesRegex(ValueError, "is not one of"):
            AWQConfig(base_config, step="not_supported")

-    def test_awq_functionality(self):
-        device = "cuda"
+    @parameterized.expand([(device,) for device in devices])
+    def test_awq_functionality(self, device):
        dataset_size = 100
        l1, l2, l3 = 512, 256, 128
        original_dtype = torch.bfloat16  # tinygemm kernel only uses bfloat16 inputs
@@ -73,7 +77,15 @@ def test_awq_functionality(self):
        m = ToyLinearModel(l1, l2, l3).eval().to(original_dtype).to(device)

        # baseline quantization
-        base_config = Int4WeightOnlyConfig(group_size=group_size)
+        if device == "cuda":
+            base_config = Int4WeightOnlyConfig(group_size=group_size)
+        elif device == "cpu":
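+            # the CPU path uses the "opaque" int4 packing format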
+            base_config = Int4WeightOnlyConfig(
+                group_size=group_size, int4_packing_format="opaque"
+            )
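+            # seed the RNG for deterministic results on CPU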
+            torch.manual_seed(1234)
+        else:
+            assert False, "Unsupported device: {}".format(device)
        m_baseline = copy.deepcopy(m)
        quantize_(m_baseline, base_config)

@@ -104,8 +116,8 @@ def test_awq_functionality(self):
        loss_base = (ref_out - baseline_out).pow(2).mean().item()
        assert loss_awq < loss_base

-    def test_awq_loading(self):
-        device = "cuda"
+    @parameterized.expand([(device,) for device in devices])
+    def test_awq_loading(self, device):
        dataset_size = 100
        l1, l2, l3 = 512, 256, 128
        original_dtype = torch.bfloat16  # tinygemm kernel only uses bfloat16 inputs
@@ -123,7 +135,14 @@ def test_awq_loading(self):
        calibration_data = dataset[:n_calibration_examples]

        # calibrate
-        base_config = Int4WeightOnlyConfig(group_size=group_size)
+        if device == "cuda":
+            base_config = Int4WeightOnlyConfig(group_size=group_size)
+        elif device == "cpu":
+            base_config = Int4WeightOnlyConfig(
+                group_size=group_size, int4_packing_format="opaque"
+            )
+        else:
+            assert False, "Unsupported device: {}".format(device)
        quant_config = AWQConfig(base_config, step=AWQStep.PREPARE)
        quantize_(m, quant_config)

@@ -152,14 +171,14 @@ def test_awq_loading(self):
        assert awq_save_load_out is not None
        assert torch.allclose(awq_out, awq_save_load_out, atol=1e-2)

-    def test_awq_loading_vllm(self):
+    @parameterized.expand([(device,) for device in devices])
+    def test_awq_loading_vllm(self, device):
        """Simulate weight loading in vllm:
        * prepare the model weight in the same format (awq weight)
        * use weight.copy_(state_dict["weight"]) to copy over the quantized weights from the checkpoint

        There is also a slicing op that is omitted here; the overall e2e flow is tested in the vllm repo
        """
-        device = "cuda"
        dataset_size = 100
        l1, l2, l3 = 512, 256, 128
        original_dtype = torch.bfloat16  # tinygemm kernel only uses bfloat16 inputs
@@ -177,7 +196,14 @@ def test_awq_loading_vllm(self):
        calibration_data = dataset[:n_calibration_examples]

        # calibrate
-        base_config = Int4WeightOnlyConfig(group_size=group_size)
+        if device == "cuda":
+            base_config = Int4WeightOnlyConfig(group_size=group_size)
+        elif device == "cpu":
+            base_config = Int4WeightOnlyConfig(
+                group_size=group_size, int4_packing_format="opaque"
+            )
+        else:
+            assert False, "Unsupported device: {}".format(device)
        quant_config = AWQConfig(base_config, step=AWQStep.PREPARE)
        quantize_(m, quant_config)
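
For context, a minimal sketch of the prepare/calibrate flow these tests exercise, using only the API imported in this diff. Here m and calibration_data stand in for the test fixtures, and the final AWQStep.CONVERT step is an assumption (only AWQStep.PREPARE appears above):

    from torchao.prototype.awq import AWQConfig, AWQStep
    from torchao.quantization import Int4WeightOnlyConfig, quantize_

    base_config = Int4WeightOnlyConfig(group_size=128)
    # prepare: set the model up so calibration can record activation statistics
    quantize_(m, AWQConfig(base_config, step=AWQStep.PREPARE))
    for example in calibration_data:
        m(example)  # calibration forward passes
    # assumed follow-up step: apply the AWQ scales and quantize the weights
    quantize_(m, AWQConfig(base_config, step=AWQStep.CONVERT))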