@@ -338,10 +338,10 @@ def __init__(
 
         self.quantization_compressor = {}
         for format in self.compression_formats:
-            self.quantization_compressor[
-                format
-            ] = BaseCompressor.load_from_registry(
-                format, config=quantization_config
+            self.quantization_compressor[format] = (
+                BaseCompressor.load_from_registry(
+                    format, config=quantization_config
+                )
             )
 
     def get_missing_module_keys(self, model: Module) -> List[str]:
@@ -134,16 +134,14 @@ def compress_weight(
         compressed_dict["weight_shape"] = weight_shape
         compressed_dict["weight_packed"] = packed_weight
 
-        # We typically don't compress zp; apart from when using the packed_compressor
-        # and when storing group/channel zp
         if not quantization_args.symmetric and quantization_args.strategy in [
             QuantizationStrategy.GROUP.value,
             QuantizationStrategy.CHANNEL.value,
         ]:
             packed_zp = pack_to_int32(
                 zero_point, quantization_args.num_bits, packed_dim=0
             )
-            compressed_dict["weight_zero_point"] = packed_zp
+            compressed_dict["weight_zero_point"] = packed_zp.contiguous()
         return compressed_dict
 
     def decompress_weight(
@@ -166,16 +164,13 @@ def decompress_weight(
         num_bits = quantization_args.num_bits
         unpacked = unpack_from_int32(weight, num_bits, original_shape)
 
-        # NOTE: this will fail decompression as we don't currently handle packed zp on
-        # decompression
         if not quantization_args.symmetric and quantization_args.strategy in [
            QuantizationStrategy.GROUP.value,
            QuantizationStrategy.CHANNEL.value,
         ]:
-            raise ValueError(
-                "Decompression of packed zero points is currently not supported"
-            )
-            assert zero_point is not None
+            assert (
+                zero_point is not None
+            ), "Asymmetric quantization requires zero-point values"
             original_zp_shape = (original_shape[0], scale.shape[-1])
             zero_point = unpack_from_int32(
                 zero_point, num_bits, original_zp_shape, packed_dim=0
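To make the new compress/decompress contract concrete, here is a minimal sketch of the zero-point round trip, mirroring the calls in the hunks above. The helper import path is an assumption; adjust it to wherever pack_to_int32 and unpack_from_int32 live in your checkout.

# Minimal sketch of the asymmetric zero-point round trip (assumed import path).
import torch
from compressed_tensors.compressors.quantized_compressors.pack_quantized import (
    pack_to_int32,
    unpack_from_int32,
)

num_bits = 4
zero_point = torch.randint(-8, 8, (512, 8), dtype=torch.int8)  # (rows, groups)

# compress_weight: pack group/channel zero points along dim 0 into int32 words.
packed_zp = pack_to_int32(zero_point, num_bits, packed_dim=0).contiguous()

# decompress_weight: recover the original shape, (original_shape[0], scale.shape[-1]).
restored = unpack_from_int32(packed_zp, num_bits, zero_point.shape, packed_dim=0)
assert torch.equal(restored, zero_point)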
102 changes: 97 additions & 5 deletions tests/test_compressors/quantized_compressors/test_pack_quant.py
@@ -15,6 +15,7 @@
 
 import math
 import shutil
+import tempfile
 from collections import OrderedDict
 
 import pytest
@@ -170,12 +171,13 @@ def test_reload_match(tmp_path, num_bits):
     )
     save_file(compressed_state_dict, tmp_path / "model.safetensors")
 
-    reconstructed_dense_gen = compressor.decompress(
-        tmp_path, names_to_scheme=quantized_modules_to_scheme
-    )
     reconstructed_dense = {}
-    for name, value in reconstructed_dense_gen:
-        reconstructed_dense[name] = value
+    with tempfile.TemporaryDirectory() as _tmp:
+        reconstructed_dense_gen = compressor.decompress(
+            tmp_path, names_to_scheme=quantized_modules_to_scheme
+        )
+        for name, value in reconstructed_dense_gen:
+            reconstructed_dense[name] = value
 
     fake_quant_dummy = fake_quantize(
         dense_state_dict["dummy.weight"],
@@ -473,3 +475,93 @@ def test_unpack_from_int32(num_bits, values, expected_tensor):
unpacked_tensor = unpack_from_int32(values, num_bits, expected_tensor.shape)
    assert torch.equal(unpacked_tensor, expected_tensor)
    assert unpacked_tensor.dtype == expected_tensor.dtype


@pytest.mark.parametrize(
"strategy,group_size",
[
(QuantizationStrategy.GROUP, 128),
(QuantizationStrategy.CHANNEL, None),
],
)
def test_asymmetric_zero_point_decompression(strategy, group_size, tmp_path):
"""
Test that zero-point packing and unpacking works correctly for asymmetric quantization
with GROUP and CHANNEL strategies.
"""
shape = (512, 1024)

if strategy == QuantizationStrategy.CHANNEL:
expected_zp_shape = (shape[0], 1)
elif strategy == QuantizationStrategy.GROUP:
num_groups = shape[1] // group_size
expected_zp_shape = (shape[0], max(num_groups, 1))

dense_state_dict = {
"dummy.weight": torch.randn(shape),
"dummy.weight_scale": torch.rand(expected_zp_shape).to(torch.float32),
"dummy.weight_zero_point": torch.randint(-8, 8, expected_zp_shape).to(
torch.int8
),
}

quant_config = get_dummy_quant_config(
num_bits=4, strategy=strategy.value, symmetric=False, group_size=group_size
)

compressor = PackedQuantizationCompressor(config=quant_config)
quantized_modules_to_scheme = {"dummy": quant_config.config_groups["group_1"]}
compressed_state_dict = compressor.compress(
dense_state_dict.copy(), names_to_scheme=quantized_modules_to_scheme
)

assert "dummy.weight_zero_point" in compressed_state_dict
assert compressed_state_dict["dummy.weight_zero_point"].dtype == torch.int32

save_file(compressed_state_dict, tmp_path / "model.safetensors")

reconstructed_dense_gen = compressor.decompress(
tmp_path, names_to_scheme=quantized_modules_to_scheme
)
reconstructed_dense = {}
for name, value in reconstructed_dense_gen:
reconstructed_dense[name] = value

assert "dummy" in reconstructed_dense
assert "weight" in reconstructed_dense["dummy"]

assert reconstructed_dense["dummy"]["weight"].shape == shape

shutil.rmtree(tmp_path)


@pytest.mark.parametrize(
"num_bits,strategy",
[
(4, QuantizationStrategy.GROUP),
(4, QuantizationStrategy.CHANNEL),
(8, QuantizationStrategy.GROUP),
(8, QuantizationStrategy.CHANNEL),
],
)
def test_zero_point_pack_unpack_consistency(num_bits, strategy):
"""
Test that packing and unpacking zero-points preserves values correctly.
"""
if strategy == QuantizationStrategy.GROUP:
shape = (512, 8)
group_size = 128
else:
shape = (512, 1)
group_size = None

max_val = (1 << (num_bits - 1)) - 1
min_val = -(1 << (num_bits - 1))
original_zp = torch.randint(min_val, max_val + 1, shape).to(torch.int8)

packed_zp = pack_to_int32(original_zp, num_bits, packed_dim=0)

unpacked_zp = unpack_from_int32(packed_zp, num_bits, shape, packed_dim=0)

assert torch.equal(original_zp, unpacked_zp)
assert unpacked_zp.dtype == torch.int8
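As a rough shape walk-through of what these round-trip tests exercise, assuming pack_to_int32 packs values densely along packed_dim into 32-bit words (the packed shape is an assumption about the helper, not something the tests assert):

# Hypothetical packing arithmetic for the zero-point shapes used above.
rows = 512
for num_bits, zp_cols in [(4, 8), (8, 8), (4, 1), (8, 1)]:
    values_per_word = 32 // num_bits        # 8 values per int32 at 4-bit, 4 at 8-bit
    packed_rows = rows // values_per_word   # 64 or 128 packed rows along dim 0
    print(f"{num_bits}-bit zp ({rows}, {zp_cols}) -> packed ({packed_rows}, {zp_cols})")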
@@ -0,0 +1,172 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
End-to-end tests for asymmetric quantization with zero-point decompression.
"""

import pytest
import torch
from compressed_tensors.compressors.model_compressors.model_compressor import (
    ModelCompressor,
)
from compressed_tensors.config import CompressionFormat
from compressed_tensors.quantization import (
    QuantizationArgs,
    QuantizationConfig,
    QuantizationScheme,
    QuantizationStrategy,
    apply_quantization_config,
)
from torch.nn import Linear, Module


class SimpleModel(Module):
"""Simple model for testing"""

def __init__(self, input_dim=512, hidden_dim=256, output_dim=128):
super().__init__()
self.layer1 = Linear(input_dim, hidden_dim, bias=False)
self.layer2 = Linear(hidden_dim, output_dim, bias=False)

def forward(self, x):
x = self.layer1(x)
x = torch.relu(x)
x = self.layer2(x)
return x


def create_asymmetric_quant_config(
num_bits=4, strategy=QuantizationStrategy.GROUP, group_size=128
) -> QuantizationConfig:
"""Create an asymmetric quantization config"""
config_groups = {
"group_1": QuantizationScheme(
targets=["Linear"],
weights=QuantizationArgs(
num_bits=num_bits,
strategy=strategy.value,
group_size=(
group_size if strategy == QuantizationStrategy.GROUP else None
),
symmetric=False,
),
),
}
return QuantizationConfig(config_groups=config_groups)


@pytest.mark.parametrize(
"strategy,group_size",
[
(QuantizationStrategy.GROUP, 128),
(QuantizationStrategy.CHANNEL, None),
],
)
def test_end_to_end_asymmetric_quantization(
strategy,
group_size,
mock_per_group_calibration,
mock_per_channel_calibration,
):
"""
Test end-to-end workflow: quantize -> compress -> decompress in memory
"""
model = SimpleModel()
original_weights = {
"layer1": model.layer1.weight.detach().clone(),
"layer2": model.layer2.weight.detach().clone(),
}

quant_config = create_asymmetric_quant_config(
num_bits=4, strategy=strategy, group_size=group_size
)
# Set pack-quantized format for ModelCompressor usage
quant_config.format = CompressionFormat.pack_quantized.value
apply_quantization_config(model, quant_config)

if strategy == QuantizationStrategy.GROUP:
mock_per_group_calibration(
model.layer1, "weight", model.layer1.weight, group_size
)
mock_per_group_calibration(
model.layer2, "weight", model.layer2.weight, group_size
)
else:
mock_per_channel_calibration(model.layer1, "weight", model.layer1.weight)
mock_per_channel_calibration(model.layer2, "weight", model.layer2.weight)

# Compress and decompress in memory using ModelCompressor
mc = ModelCompressor(quantization_config=quant_config)
mc.compress_model(model)

# Verify compression created zero-point parameters
assert hasattr(model.layer1, "weight_zero_point")
assert hasattr(model.layer2, "weight_zero_point")
assert model.layer1.weight_zero_point.dtype == torch.int32
assert model.layer2.weight_zero_point.dtype == torch.int32

# Decompress in memory
mc.decompress_model(model)

# Verify decompression restored weights correctly
assert model.layer1.weight.shape == original_weights["layer1"].shape
assert model.layer2.weight.shape == original_weights["layer2"].shape
assert model.layer1.weight.dtype.is_floating_point
assert model.layer2.weight.dtype.is_floating_point
assert not torch.isnan(model.layer1.weight).any()
assert not torch.isnan(model.layer2.weight).any()
assert not torch.isinf(model.layer1.weight).any()
assert not torch.isinf(model.layer2.weight).any()


@pytest.mark.parametrize("num_bits", [4, 8])
def test_asymmetric_quantization_accuracy(num_bits, mock_per_group_calibration):
"""
Test that asymmetric quantization with zero-point preserves accuracy better
than symmetric quantization for biased weight distributions.
"""
shape = (256, 512)
biased_weights = torch.randn(shape) + 2.0

quant_config = create_asymmetric_quant_config(
num_bits=num_bits,
strategy=QuantizationStrategy.GROUP,
group_size=128,
)
quant_config.format = CompressionFormat.pack_quantized.value

class SingleLayer(Module):
def __init__(self):
super().__init__()
self.layer = Linear(shape[1], shape[0], bias=False)

model = SingleLayer()
apply_quantization_config(model, quant_config)

with torch.no_grad():
model.layer.weight.copy_(biased_weights)
mock_per_group_calibration(model.layer, "weight", model.layer.weight, 128)

# Compress and decompress in memory using ModelCompressor
mc = ModelCompressor(quantization_config=quant_config)
mc.compress_model(model)
mc.decompress_model(model)

decompressed_weights = model.layer.weight
assert decompressed_weights.shape == shape
assert not torch.isnan(decompressed_weights).any()
assert not torch.isinf(decompressed_weights).any()
threshold = torch.std(torch.rand(shape) - torch.rand(shape))
assert torch.std(biased_weights - decompressed_weights) < threshold
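As a back-of-envelope check on that final threshold (an estimate, not part of the test): the threshold is the standard deviation of the difference of two independent uniform draws, roughly 0.41, while 4-bit group quantization of weights drawn from a unit-variance distribution should leave a much smaller reconstruction error.

import math

# Rough estimate of the quantities compared in the final assert, assuming a
# ~6-sigma spread per group of 128 weights and uniform rounding error.
threshold = math.sqrt(1 / 12 + 1 / 12)   # std of U(0,1) - U(0,1), ~0.408
group_range = 6.0                        # approximate spread of N(2, 1) samples
step = group_range / (2**4 - 1)          # 4-bit asymmetric step size, ~0.40
quant_error_std = step / math.sqrt(12)   # uniform rounding error std, ~0.12
print(f"error std ~{quant_error_std:.2f} < threshold ~{threshold:.2f}")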