58 changes: 42 additions & 16 deletions test/prototype/test_awq.py
@@ -5,17 +5,17 @@
# LICENSE file in the root directory of this source tree.
import copy
import tempfile
import unittest

import torch
from parameterized import parameterized
from torch.testing._internal.common_utils import (
TestCase,
run_tests,
)

from torchao.prototype.awq import AWQConfig, AWQStep
from torchao.quantization import Int4WeightOnlyConfig, quantize_
from torchao.utils import _is_fbgemm_genai_gpu_available
from torchao.utils import _is_fbgemm_genai_gpu_available, torch_version_at_least


class ToyLinearModel(torch.nn.Module):
@@ -42,11 +42,15 @@ def forward(self, x):
return x


@unittest.skipIf(not torch.cuda.is_available(), reason="CUDA not available")
@unittest.skipIf(
not _is_fbgemm_genai_gpu_available(),
reason="need to install fbgemm_gpu_genai package",
)
devices = ["cpu"]
if (
torch.cuda.is_available()
and _is_fbgemm_genai_gpu_available()
and torch_version_at_least("2.6.0")
):
devices.append("cuda")


class TestAWQ(TestCase):
def test_awq_config(self):
base_config = Int4WeightOnlyConfig()
@@ -61,8 +65,8 @@ def test_awq_config(self):
with self.assertRaisesRegex(ValueError, "is not one of"):
AWQConfig(base_config, step="not_supported")

def test_awq_functionality(self):
device = "cuda"
@parameterized.expand([(device,) for device in devices])
def test_awq_functionality(self, device):
dataset_size = 100
l1, l2, l3 = 512, 256, 128
original_dtype = torch.bfloat16 # tinygemm kernel only uses bfloat16 inputs
@@ -73,7 +77,15 @@
m = ToyLinearModel(l1, l2, l3).eval().to(original_dtype).to(device)

# baseline quantization
base_config = Int4WeightOnlyConfig(group_size=group_size)
if device == "cuda":
base_config = Int4WeightOnlyConfig(group_size=group_size)
elif device == "cpu":
base_config = Int4WeightOnlyConfig(
group_size=group_size, int4_packing_format="opaque"
)
torch.manual_seed(1234)
else:
assert False, "Unsupported device: {}".format(device)
m_baseline = copy.deepcopy(m)
quantize_(m_baseline, base_config)

@@ -104,8 +116,8 @@
loss_base = (ref_out - baseline_out).pow(2).mean().item()
assert loss_awq < loss_base

def test_awq_loading(self):
device = "cuda"
@parameterized.expand([(device,) for device in devices])
def test_awq_loading(self, device):
dataset_size = 100
l1, l2, l3 = 512, 256, 128
original_dtype = torch.bfloat16 # tinygemm kernel only uses bfloat16 inputs
@@ -123,7 +135,14 @@ def test_awq_loading(self):
calibration_data = dataset[:n_calibration_examples]

# calibrate
base_config = Int4WeightOnlyConfig(group_size=group_size)
if device == "cuda":
base_config = Int4WeightOnlyConfig(group_size=group_size)
elif device == "cpu":
base_config = Int4WeightOnlyConfig(
group_size=group_size, int4_packing_format="opaque"
)
else:
assert False, "Unsupported device: {}".format(device)
quant_config = AWQConfig(base_config, step=AWQStep.PREPARE)
quantize_(m, quant_config)

@@ -152,14 +171,14 @@ def test_awq_loading(self):
assert awq_save_load_out is not None
assert torch.allclose(awq_out, awq_save_load_out, atol=1e-2)

def test_awq_loading_vllm(self):
@parameterized.expand([(device,) for device in devices])
def test_awq_loading_vllm(self, device):
"""Simulate weight loading in vllm:
* prepare model weight to the same format (awq weight)
* use weight.copy_(state_dict["weight"]) to copy over the quantized weights from checkpoint

There is also a slicing op that is ommitted here, overall e2e is tested in tests in vllm repo
"""
device = "cuda"
dataset_size = 100
l1, l2, l3 = 512, 256, 128
original_dtype = torch.bfloat16 # tinygemm kernel only uses bfloat16 inputs
@@ -177,7 +196,14 @@ def test_awq_loading_vllm(self):
calibration_data = dataset[:n_calibration_examples]

# calibrate
base_config = Int4WeightOnlyConfig(group_size=group_size)
if device == "cuda":
base_config = Int4WeightOnlyConfig(group_size=group_size)
elif device == "cpu":
base_config = Int4WeightOnlyConfig(
group_size=group_size, int4_packing_format="opaque"
)
else:
assert False, "Unsupported device: {}".format(device)
quant_config = AWQConfig(base_config, step=AWQStep.PREPARE)
quantize_(m, quant_config)

@@ -19,6 +19,7 @@
Int4WeightOnlyConfig,
quantize_,
)
from torchao.quantization.quantize_.common import SupportsActivationPreScaling
from torchao.quantization.utils import compute_error
from torchao.utils import (
torch_version_at_least,
@@ -76,6 +77,31 @@ def test_module_path(self, dtype):
"<class 'torchao.quantization.Int4OpaqueTensor'>",
)

def test_activation_prescaling(self):
dtype = torch.bfloat16
input = torch.randn(1, 128, dtype=dtype)
linear = torch.nn.Linear(128, 256, bias=False, dtype=dtype)
original_output = linear(input)
quantize_(linear, get_config(group_size=128))
qw = linear.weight
assert isinstance(qw, SupportsActivationPreScaling), (
"Expected int4 tensor supports activation prescaling"
)
assert qw.act_pre_scale is None, "Default `act_pre_scale` is None"
_ACT_PRE_SCALE = 2
manual_scaled_quantized = linear(input * _ACT_PRE_SCALE)
qw.act_pre_scale = _ACT_PRE_SCALE
auto_scaled_quantized = linear(input)

# Make sure activation pre-scaling is successfully applied to the activation.
self.assertEqual(manual_scaled_quantized, auto_scaled_quantized)

# If pre-scaling is auto-applied, the quantization error should be low,
# i.e., compute_error (SQNR) is high
self.assertTrue(
compute_error(original_output * _ACT_PRE_SCALE, auto_scaled_quantized) > 20
)


instantiate_parametrized_tests(TestInt4OpaqueTensor)
