
Commit 067b273

Author: Yuxin Cui
Support Int4OpaqueTensor for AWQ (#2997)
* Support Int4OpaqueTensor for AWQ: add act_pre_scale into Int4OpaqueTensor for AWQ.
* Format codes
* Add detailed tests for act_pre_scale
* Remove debug codes
* Update codes
* Change to int4_packing_format
* Precise variable naming
* Update tests for act_pre_scale

Signed-off-by: Cui, Yuxin <[email protected]>
1 parent 58c3064 commit 067b273
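For context, a hedged sketch of the CPU calibration flow the new tests exercise. The toy model, sizes, and calibration loop below are illustrative assumptions; only the names visible in this diff (quantize_, Int4WeightOnlyConfig with int4_packing_format="opaque", AWQConfig, AWQStep.PREPARE) are taken from the change, and the later AWQ step that materializes the quantized weights is omitted because it does not appear in this excerpt.

import torch

from torchao.prototype.awq import AWQConfig, AWQStep
from torchao.quantization import Int4WeightOnlyConfig, quantize_

# Illustrative toy model (assumption; the tests use their own ToyLinearModel).
model = torch.nn.Sequential(torch.nn.Linear(512, 256, bias=False)).to(torch.bfloat16)

# On CPU, the tests select the opaque int4 packing format.
base_config = Int4WeightOnlyConfig(group_size=128, int4_packing_format="opaque")

# AWQ PREPARE step: wrap the base config, then run calibration batches,
# mirroring the "# calibrate" blocks in the diff below.
quantize_(model, AWQConfig(base_config, step=AWQStep.PREPARE))
for _ in range(8):
    model(torch.randn(2, 512, dtype=torch.bfloat16))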

File tree

4 files changed: +169 -51 lines


test/prototype/test_awq.py

Lines changed: 42 additions & 16 deletions
@@ -5,17 +5,17 @@
 # LICENSE file in the root directory of this source tree.
 import copy
 import tempfile
-import unittest
 
 import torch
+from parameterized import parameterized
 from torch.testing._internal.common_utils import (
     TestCase,
     run_tests,
 )
 
 from torchao.prototype.awq import AWQConfig, AWQStep
 from torchao.quantization import Int4WeightOnlyConfig, quantize_
-from torchao.utils import _is_fbgemm_genai_gpu_available
+from torchao.utils import _is_fbgemm_genai_gpu_available, torch_version_at_least
 
 
 class ToyLinearModel(torch.nn.Module):
@@ -42,11 +42,15 @@ def forward(self, x):
         return x
 
 
-@unittest.skipIf(not torch.cuda.is_available(), reason="CUDA not available")
-@unittest.skipIf(
-    not _is_fbgemm_genai_gpu_available(),
-    reason="need to install fbgemm_gpu_genai package",
-)
+devices = ["cpu"]
+if (
+    torch.cuda.is_available()
+    and _is_fbgemm_genai_gpu_available()
+    and torch_version_at_least("2.6.0")
+):
+    devices.append("cuda")
+
+
 class TestAWQ(TestCase):
     def test_awq_config(self):
         base_config = Int4WeightOnlyConfig()
@@ -61,8 +65,8 @@ def test_awq_config(self):
         with self.assertRaisesRegex(ValueError, "is not one of"):
             AWQConfig(base_config, step="not_supported")
 
-    def test_awq_functionality(self):
-        device = "cuda"
+    @parameterized.expand([(device,) for device in devices])
+    def test_awq_functionality(self, device):
         dataset_size = 100
         l1, l2, l3 = 512, 256, 128
         original_dtype = torch.bfloat16  # tinygemm kernel only uses bfloat16 inputs
@@ -73,7 +77,15 @@ def test_awq_functionality(self):
         m = ToyLinearModel(l1, l2, l3).eval().to(original_dtype).to(device)
 
         # baseline quantization
-        base_config = Int4WeightOnlyConfig(group_size=group_size)
+        if device == "cuda":
+            base_config = Int4WeightOnlyConfig(group_size=group_size)
+        elif device == "cpu":
+            base_config = Int4WeightOnlyConfig(
+                group_size=group_size, int4_packing_format="opaque"
+            )
+            torch.manual_seed(1234)
+        else:
+            assert False, "Unsupported device: {}".format(device)
         m_baseline = copy.deepcopy(m)
         quantize_(m_baseline, base_config)
 
@@ -104,8 +116,8 @@ def test_awq_functionality(self):
         loss_base = (ref_out - baseline_out).pow(2).mean().item()
         assert loss_awq < loss_base
 
-    def test_awq_loading(self):
-        device = "cuda"
+    @parameterized.expand([(device,) for device in devices])
+    def test_awq_loading(self, device):
         dataset_size = 100
         l1, l2, l3 = 512, 256, 128
         original_dtype = torch.bfloat16  # tinygemm kernel only uses bfloat16 inputs
@@ -123,7 +135,14 @@ def test_awq_loading(self):
         calibration_data = dataset[:n_calibration_examples]
 
         # calibrate
-        base_config = Int4WeightOnlyConfig(group_size=group_size)
+        if device == "cuda":
+            base_config = Int4WeightOnlyConfig(group_size=group_size)
+        elif device == "cpu":
+            base_config = Int4WeightOnlyConfig(
+                group_size=group_size, int4_packing_format="opaque"
+            )
+        else:
+            assert False, "Unsupported device: {}".format(device)
         quant_config = AWQConfig(base_config, step=AWQStep.PREPARE)
         quantize_(m, quant_config)
 
@@ -152,14 +171,14 @@ def test_awq_loading(self):
         assert awq_save_load_out is not None
         assert torch.allclose(awq_out, awq_save_load_out, atol=1e-2)
 
-    def test_awq_loading_vllm(self):
+    @parameterized.expand([(device,) for device in devices])
+    def test_awq_loading_vllm(self, device):
         """Simulate weight loading in vllm:
         * prepare model weight to the same format (awq weight)
         * use weight.copy_(state_dict["weight"]) to copy over the quantized weights from checkpoint
 
         There is also a slicing op that is ommitted here, overall e2e is tested in tests in vllm repo
         """
-        device = "cuda"
         dataset_size = 100
         l1, l2, l3 = 512, 256, 128
         original_dtype = torch.bfloat16  # tinygemm kernel only uses bfloat16 inputs
@@ -177,7 +196,14 @@ def test_awq_loading_vllm(self):
         calibration_data = dataset[:n_calibration_examples]
 
         # calibrate
-        base_config = Int4WeightOnlyConfig(group_size=group_size)
+        if device == "cuda":
+            base_config = Int4WeightOnlyConfig(group_size=group_size)
+        elif device == "cpu":
+            base_config = Int4WeightOnlyConfig(
+                group_size=group_size, int4_packing_format="opaque"
+            )
+        else:
+            assert False, "Unsupported device: {}".format(device)
         quant_config = AWQConfig(base_config, step=AWQStep.PREPARE)
         quantize_(m, quant_config)
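As a standalone aside, a minimal sketch of the device-parameterization pattern the updated tests adopt: parameterized.expand fans a module-level device list out into one generated test case per device. The toy test below is illustrative only and not part of the commit.

import unittest

import torch
from parameterized import parameterized

devices = ["cpu"]
if torch.cuda.is_available():
    devices.append("cuda")


class ExampleTest(unittest.TestCase):
    @parameterized.expand([(device,) for device in devices])
    def test_tensor_roundtrip(self, device):
        # Each entry in `devices` becomes its own generated test method.
        x = torch.ones(4, device=device)
        self.assertEqual(x.sum().item(), 4.0)


if __name__ == "__main__":
    unittest.main()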
test/quantization/quantize_/workflows/int4/test_int4_opaque_tensor.py

Lines changed: 26 additions & 0 deletions
@@ -19,6 +19,7 @@
     Int4WeightOnlyConfig,
     quantize_,
 )
+from torchao.quantization.quantize_.common import SupportsActivationPreScaling
 from torchao.quantization.utils import compute_error
 from torchao.utils import (
     torch_version_at_least,
@@ -76,6 +77,31 @@ def test_module_path(self, dtype):
             "<class 'torchao.quantization.Int4OpaqueTensor'>",
         )
 
+    def test_activation_prescaling(self):
+        dtype = torch.bfloat16
+        input = torch.randn(1, 128, dtype=dtype)
+        linear = torch.nn.Linear(128, 256, bias=False, dtype=dtype)
+        original_output = linear(input)
+        quantize_(linear, get_config(group_size=128))
+        qw = linear.weight
+        assert isinstance(qw, SupportsActivationPreScaling), (
+            "Expected int4 tensor supports activation prescaling"
+        )
+        assert qw.act_pre_scale is None, "Default `act_pre_scale` is None"
+        _ACT_PRE_SCALE = 2
+        manual_scaled_quantized = linear(input * _ACT_PRE_SCALE)
+        qw.act_pre_scale = _ACT_PRE_SCALE
+        auto_scaled_quantized = linear(input)
+
+        # Making sure activation pre scaling is successfully applied to the activation.
+        self.assertEqual(manual_scaled_quantized, auto_scaled_quantized)
+
+        # If pre-scaling is auto-applied, the quantization error should be low,
+        # i.e., compute_error (SQNR) is high
+        self.assertTrue(
+            compute_error(original_output * _ACT_PRE_SCALE, auto_scaled_quantized) > 20
+        )
+
 
 instantiate_parametrized_tests(TestInt4OpaqueTensor)
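For intuition, a minimal sketch (plain float tensors, no quantization, not part of the commit) of the linearity identity behind an activation pre-scale: scaling the input of a linear layer is equivalent to folding the same scale into the weight's input-channel dimension, which is what makes storing an AWQ-style scale on the weight and applying it to the activation at runtime well defined. The new test above checks the quantized analogue of this equivalence.

import torch
import torch.nn.functional as F

x = torch.randn(2, 8)
w = torch.randn(4, 8)
s = torch.rand(8) + 0.5  # per-input-channel scale

# Scaling the activation before the matmul...
out_scaled_activation = F.linear(x * s, w)
# ...matches folding the same scale into the weight's input-channel dim.
out_scaled_weight = F.linear(x, w * s)

torch.testing.assert_close(out_scaled_activation, out_scaled_weight)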