From 4f610f765034976efdba8f47e502b53b74355f8a Mon Sep 17 00:00:00 2001 From: Etelis <92247226+Etelis@users.noreply.github.com> Date: Tue, 9 Sep 2025 19:12:58 +0300 Subject: [PATCH 01/11] feat: add zero-point decompression support for asymmetric quantization - Fix decompress_weight method in PackedQuantizationCompressor to support unpacking zero-points - Add comprehensive tests for zero-point packing/unpacking with GROUP and CHANNEL strategies - Add end-to-end integration tests for asymmetric quantization workflow - Ensure packed tensors are contiguous for safetensors compatibility Resolves issue referenced in vllm-project/llm-compressor#1704 --- .../quantized_compressors/pack_quantized.py | 19 +- .../test_asymmetric_decompression.py | 228 ++++++++++++++++++ .../quantized_compressors/test_pack_quant.py | 91 +++++++ 3 files changed, 325 insertions(+), 13 deletions(-) create mode 100644 tests/test_compressors/quantized_compressors/test_asymmetric_decompression.py diff --git a/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py b/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py index e2ce3d24b..9e3e2aa69 100644 --- a/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py +++ b/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py @@ -134,8 +134,6 @@ def compress_weight( compressed_dict["weight_shape"] = weight_shape compressed_dict["weight_packed"] = packed_weight - # We typically don't compress zp; apart from when using the packed_compressor - # and when storing group/channel zp if not quantization_args.symmetric and quantization_args.strategy in [ QuantizationStrategy.GROUP.value, QuantizationStrategy.CHANNEL.value, @@ -143,7 +141,7 @@ def compress_weight( packed_zp = pack_to_int32( zero_point, quantization_args.num_bits, packed_dim=0 ) - compressed_dict["weight_zero_point"] = packed_zp + compressed_dict["weight_zero_point"] = packed_zp.contiguous() return compressed_dict def decompress_weight( @@ -166,20 +164,15 @@ def decompress_weight( num_bits = quantization_args.num_bits unpacked = unpack_from_int32(weight, num_bits, original_shape) - # NOTE: this will fail decompression as we don't currently handle packed zp on - # decompression if not quantization_args.symmetric and quantization_args.strategy in [ QuantizationStrategy.GROUP.value, QuantizationStrategy.CHANNEL.value, ]: - raise ValueError( - "Decompression of packed zero points is currently not supported" - ) - assert zero_point is not None - original_zp_shape = (original_shape[0], scale.shape[-1]) - zero_point = unpack_from_int32( - zero_point, num_bits, original_zp_shape, packed_dim=0 - ) + if zero_point is not None: + original_zp_shape = (original_shape[0], scale.shape[-1]) + zero_point = unpack_from_int32( + zero_point, num_bits, original_zp_shape, packed_dim=0 + ) decompressed_weight = dequantize( x_q=unpacked, scale=scale, zero_point=zero_point, g_idx=g_idx diff --git a/tests/test_compressors/quantized_compressors/test_asymmetric_decompression.py b/tests/test_compressors/quantized_compressors/test_asymmetric_decompression.py new file mode 100644 index 000000000..681523270 --- /dev/null +++ b/tests/test_compressors/quantized_compressors/test_asymmetric_decompression.py @@ -0,0 +1,228 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +End-to-end tests for asymmetric quantization with zero-point decompression. +""" + +import shutil +import tempfile +from pathlib import Path + +import pytest +import torch +from compressed_tensors import PackedQuantizationCompressor +from compressed_tensors.quantization import ( + QuantizationArgs, + QuantizationConfig, + QuantizationScheme, + QuantizationStrategy, + apply_quantization_config, +) +from compressed_tensors.quantization.lifecycle.forward import fake_quantize +from safetensors.torch import save_file +from torch.nn import Linear, Module, Sequential + + +class SimpleModel(Module): + """Simple model for testing""" + def __init__(self, input_dim=512, hidden_dim=256, output_dim=128): + super().__init__() + self.layer1 = Linear(input_dim, hidden_dim, bias=False) + self.layer2 = Linear(hidden_dim, output_dim, bias=False) + + def forward(self, x): + x = self.layer1(x) + x = torch.relu(x) + x = self.layer2(x) + return x + + +def create_asymmetric_quant_config( + num_bits=4, + strategy=QuantizationStrategy.GROUP, + group_size=128 +) -> QuantizationConfig: + """Create an asymmetric quantization config""" + config_groups = { + "group_1": QuantizationScheme( + targets=["Linear"], + weights=QuantizationArgs( + num_bits=num_bits, + strategy=strategy.value, + group_size=group_size if strategy == QuantizationStrategy.GROUP else None, + symmetric=False, + ), + ), + } + return QuantizationConfig(config_groups=config_groups) + + +@pytest.mark.parametrize( + "strategy,group_size", + [ + (QuantizationStrategy.GROUP, 128), + (QuantizationStrategy.CHANNEL, None), + ], +) +def test_end_to_end_asymmetric_quantization(strategy, group_size): + """ + Test end-to-end workflow: quantize -> compress -> save -> load -> decompress -> use + """ + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_path = Path(tmp_dir) + + model = SimpleModel() + original_weights = { + "layer1": model.layer1.weight.clone(), + "layer2": model.layer2.weight.clone(), + } + + quant_config = create_asymmetric_quant_config( + num_bits=4, + strategy=strategy, + group_size=group_size + ) + apply_quantization_config(model, quant_config) + + for name, module in model.named_modules(): + if isinstance(module, Linear): + weight = module.weight + if strategy == QuantizationStrategy.CHANNEL: + scale_shape = (weight.shape[0], 1) + else: + scale_shape = (weight.shape[0], weight.shape[1] // group_size) + + module.weight_scale = torch.nn.Parameter( + torch.rand(scale_shape) * 0.1, + requires_grad=False + ) + module.weight_zero_point = torch.nn.Parameter( + torch.randint(-8, 8, scale_shape, dtype=torch.int8), + requires_grad=False + ) + + compressor = PackedQuantizationCompressor(config=quant_config) + quantized_modules_to_scheme = { + "layer1": quant_config.config_groups["group_1"], + "layer2": quant_config.config_groups["group_1"], + } + + state_dict = model.state_dict() + compressed_state_dict = compressor.compress( + state_dict, names_to_scheme=quantized_modules_to_scheme + ) + + assert "layer1.weight_zero_point" in compressed_state_dict + assert "layer2.weight_zero_point" in compressed_state_dict + assert 
compressed_state_dict["layer1.weight_zero_point"].dtype == torch.int32 + assert compressed_state_dict["layer2.weight_zero_point"].dtype == torch.int32 + + save_file(compressed_state_dict, tmp_path / "model.safetensors") + + reconstructed_gen = compressor.decompress( + tmp_path, names_to_scheme=quantized_modules_to_scheme + ) + + reconstructed_weights = {} + for module_name, module_data in reconstructed_gen: + reconstructed_weights[module_name] = module_data + + assert "layer1" in reconstructed_weights + assert "layer2" in reconstructed_weights + assert "weight" in reconstructed_weights["layer1"] + assert "weight" in reconstructed_weights["layer2"] + + assert reconstructed_weights["layer1"]["weight"].shape == original_weights["layer1"].shape + assert reconstructed_weights["layer2"]["weight"].shape == original_weights["layer2"].shape + + new_model = SimpleModel() + new_model.layer1.weight.data = reconstructed_weights["layer1"]["weight"] + new_model.layer2.weight.data = reconstructed_weights["layer2"]["weight"] + + test_input = torch.randn(1, 512) + with torch.no_grad(): + output = new_model(test_input) + + assert output.shape == (1, 128) + assert not torch.isnan(output).any() + assert not torch.isinf(output).any() + + +@pytest.mark.parametrize("num_bits", [4, 8]) +def test_asymmetric_quantization_accuracy(num_bits): + """ + Test that asymmetric quantization with zero-point preserves accuracy better + than symmetric quantization for biased weight distributions. + """ + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_path = Path(tmp_dir) + + shape = (256, 512) + weights = torch.randn(shape) + 2.0 + + quant_config = create_asymmetric_quant_config( + num_bits=num_bits, + strategy=QuantizationStrategy.GROUP, + group_size=128 + ) + + group_size = 128 + num_groups = shape[1] // group_size + scale_shape = (shape[0], num_groups) + + scales = torch.rand(scale_shape) * 0.1 + zero_points = torch.randint(-2**(num_bits-1), 2**(num_bits-1), scale_shape, dtype=torch.int8) + + state_dict = { + "layer.weight": weights, + "layer.weight_scale": scales, + "layer.weight_zero_point": zero_points, + } + + compressor = PackedQuantizationCompressor(config=quant_config) + quantized_modules_to_scheme = {"layer": quant_config.config_groups["group_1"]} + + compressed_state_dict = compressor.compress( + state_dict.copy(), names_to_scheme=quantized_modules_to_scheme + ) + + save_file(compressed_state_dict, tmp_path / "model.safetensors") + + reconstructed_gen = compressor.decompress( + tmp_path, names_to_scheme=quantized_modules_to_scheme + ) + + reconstructed = {} + for module_name, module_data in reconstructed_gen: + reconstructed[module_name] = module_data + + assert "layer" in reconstructed + assert "weight" in reconstructed["layer"] + assert reconstructed["layer"]["weight"].shape == shape + + decompressed_weights = reconstructed["layer"]["weight"] + assert not torch.isnan(decompressed_weights).any() + assert not torch.isinf(decompressed_weights).any() + + assert decompressed_weights.abs().max() < 100 + assert decompressed_weights.abs().max() > 0.01 + + +if __name__ == "__main__": + test_end_to_end_asymmetric_quantization(QuantizationStrategy.GROUP, 128) + test_end_to_end_asymmetric_quantization(QuantizationStrategy.CHANNEL, None) + test_asymmetric_quantization_accuracy(4) + test_asymmetric_quantization_accuracy(8) + print("All tests passed!") diff --git a/tests/test_compressors/quantized_compressors/test_pack_quant.py b/tests/test_compressors/quantized_compressors/test_pack_quant.py index 00d612756..5ccd09f83 
100644 --- a/tests/test_compressors/quantized_compressors/test_pack_quant.py +++ b/tests/test_compressors/quantized_compressors/test_pack_quant.py @@ -473,3 +473,94 @@ def test_unpack_from_int32(num_bits, values, expected_tensor): unpacked_tensor = unpack_from_int32(values, num_bits, expected_tensor.shape) assert torch.equal(unpacked_tensor, unpacked_tensor) assert unpacked_tensor.dtype == unpacked_tensor.dtype + + +@pytest.mark.parametrize( + "strategy,group_size", + [ + (QuantizationStrategy.GROUP, 128), + (QuantizationStrategy.CHANNEL, None), + ], +) +def test_asymmetric_zero_point_decompression(strategy, group_size, tmp_path): + """ + Test that zero-point packing and unpacking works correctly for asymmetric quantization + with GROUP and CHANNEL strategies. + """ + shape = (512, 1024) + + if strategy == QuantizationStrategy.CHANNEL: + expected_zp_shape = (shape[0], 1) + elif strategy == QuantizationStrategy.GROUP: + num_groups = shape[1] // group_size + expected_zp_shape = (shape[0], max(num_groups, 1)) + + dense_state_dict = { + "dummy.weight": torch.randn(shape), + "dummy.weight_scale": torch.rand(expected_zp_shape).to(torch.float32), + "dummy.weight_zero_point": torch.randint(-8, 8, expected_zp_shape).to(torch.int8), + } + + quant_config = get_dummy_quant_config( + num_bits=4, + strategy=strategy.value, + symmetric=False, + group_size=group_size + ) + + compressor = PackedQuantizationCompressor(config=quant_config) + quantized_modules_to_scheme = {"dummy": quant_config.config_groups["group_1"]} + compressed_state_dict = compressor.compress( + dense_state_dict.copy(), names_to_scheme=quantized_modules_to_scheme + ) + + assert "dummy.weight_zero_point" in compressed_state_dict + assert compressed_state_dict["dummy.weight_zero_point"].dtype == torch.int32 + + save_file(compressed_state_dict, tmp_path / "model.safetensors") + + reconstructed_dense_gen = compressor.decompress( + tmp_path, names_to_scheme=quantized_modules_to_scheme + ) + reconstructed_dense = {} + for name, value in reconstructed_dense_gen: + reconstructed_dense[name] = value + + assert "dummy" in reconstructed_dense + assert "weight" in reconstructed_dense["dummy"] + + assert reconstructed_dense["dummy"]["weight"].shape == shape + + shutil.rmtree(tmp_path) + + +@pytest.mark.parametrize( + "num_bits,strategy", + [ + (4, QuantizationStrategy.GROUP), + (4, QuantizationStrategy.CHANNEL), + (8, QuantizationStrategy.GROUP), + (8, QuantizationStrategy.CHANNEL), + ], +) +def test_zero_point_pack_unpack_consistency(num_bits, strategy): + """ + Test that packing and unpacking zero-points preserves values correctly. 
+ """ + if strategy == QuantizationStrategy.GROUP: + shape = (512, 8) + group_size = 128 + else: + shape = (512, 1) + group_size = None + + max_val = (1 << (num_bits - 1)) - 1 + min_val = -(1 << (num_bits - 1)) + original_zp = torch.randint(min_val, max_val + 1, shape).to(torch.int8) + + packed_zp = pack_to_int32(original_zp, num_bits, packed_dim=0) + + unpacked_zp = unpack_from_int32(packed_zp, num_bits, shape, packed_dim=0) + + assert torch.equal(original_zp, unpacked_zp) + assert unpacked_zp.dtype == torch.int8 From bd1d083b2643a5ebbd9ab2e327a7c42ac233e020 Mon Sep 17 00:00:00 2001 From: Etelis <92247226+Etelis@users.noreply.github.com> Date: Wed, 10 Sep 2025 09:35:07 +0300 Subject: [PATCH 02/11] nit: assert zero_point exists for asymmetric strategies before unpacking --- .../quantized_compressors/pack_quantized.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py b/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py index 9e3e2aa69..07da50c7f 100644 --- a/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py +++ b/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py @@ -168,11 +168,13 @@ def decompress_weight( QuantizationStrategy.GROUP.value, QuantizationStrategy.CHANNEL.value, ]: - if zero_point is not None: - original_zp_shape = (original_shape[0], scale.shape[-1]) - zero_point = unpack_from_int32( - zero_point, num_bits, original_zp_shape, packed_dim=0 - ) + assert ( + zero_point is not None + ), "Asymmetric quantization requires zero-point values" + original_zp_shape = (original_shape[0], scale.shape[-1]) + zero_point = unpack_from_int32( + zero_point, num_bits, original_zp_shape, packed_dim=0 + ) decompressed_weight = dequantize( x_q=unpacked, scale=scale, zero_point=zero_point, g_idx=g_idx From 281f1c66fe6971ea15ce1108500cbaec9b569722 Mon Sep 17 00:00:00 2001 From: Etelis <92247226+Etelis@users.noreply.github.com> Date: Thu, 11 Sep 2025 17:16:40 +0300 Subject: [PATCH 03/11] tests: rely on apply_quantization_config to init scale/zero-point; remove manual creation --- .../test_asymmetric_decompression.py | 55 +++++++------------ 1 file changed, 19 insertions(+), 36 deletions(-) diff --git a/tests/test_compressors/quantized_compressors/test_asymmetric_decompression.py b/tests/test_compressors/quantized_compressors/test_asymmetric_decompression.py index 681523270..25347a98c 100644 --- a/tests/test_compressors/quantized_compressors/test_asymmetric_decompression.py +++ b/tests/test_compressors/quantized_compressors/test_asymmetric_decompression.py @@ -96,22 +96,7 @@ def test_end_to_end_asymmetric_quantization(strategy, group_size): ) apply_quantization_config(model, quant_config) - for name, module in model.named_modules(): - if isinstance(module, Linear): - weight = module.weight - if strategy == QuantizationStrategy.CHANNEL: - scale_shape = (weight.shape[0], 1) - else: - scale_shape = (weight.shape[0], weight.shape[1] // group_size) - - module.weight_scale = torch.nn.Parameter( - torch.rand(scale_shape) * 0.1, - requires_grad=False - ) - module.weight_zero_point = torch.nn.Parameter( - torch.randint(-8, 8, scale_shape, dtype=torch.int8), - requires_grad=False - ) + compressor = PackedQuantizationCompressor(config=quant_config) quantized_modules_to_scheme = { @@ -168,34 +153,32 @@ def test_asymmetric_quantization_accuracy(num_bits): """ with tempfile.TemporaryDirectory() as tmp_dir: tmp_path = Path(tmp_dir) - + shape 
= (256, 512) - weights = torch.randn(shape) + 2.0 - + biased_weights = torch.randn(shape) + 2.0 + quant_config = create_asymmetric_quant_config( num_bits=num_bits, strategy=QuantizationStrategy.GROUP, - group_size=128 + group_size=128, ) - - group_size = 128 - num_groups = shape[1] // group_size - scale_shape = (shape[0], num_groups) - - scales = torch.rand(scale_shape) * 0.1 - zero_points = torch.randint(-2**(num_bits-1), 2**(num_bits-1), scale_shape, dtype=torch.int8) - - state_dict = { - "layer.weight": weights, - "layer.weight_scale": scales, - "layer.weight_zero_point": zero_points, - } - + + class SingleLayer(Module): + def __init__(self): + super().__init__() + self.layer = Linear(shape[1], shape[0], bias=False) + + model = SingleLayer() + apply_quantization_config(model, quant_config) + + with torch.no_grad(): + model.layer.weight.copy_(biased_weights) + compressor = PackedQuantizationCompressor(config=quant_config) quantized_modules_to_scheme = {"layer": quant_config.config_groups["group_1"]} - + compressed_state_dict = compressor.compress( - state_dict.copy(), names_to_scheme=quantized_modules_to_scheme + model.state_dict().copy(), names_to_scheme=quantized_modules_to_scheme ) save_file(compressed_state_dict, tmp_path / "model.safetensors") From c0cbb70709df977fbcc1c00cc50529652cb1e771 Mon Sep 17 00:00:00 2001 From: Etelis <92247226+Etelis@users.noreply.github.com> Date: Thu, 11 Sep 2025 17:19:48 +0300 Subject: [PATCH 04/11] tests: rename to test_packed_asym_decompression.py --- ...ion.py => test_packed_asym_decompression.py} | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) rename tests/test_compressors/quantized_compressors/{test_asymmetric_decompression.py => test_packed_asym_decompression.py} (90%) diff --git a/tests/test_compressors/quantized_compressors/test_asymmetric_decompression.py b/tests/test_compressors/quantized_compressors/test_packed_asym_decompression.py similarity index 90% rename from tests/test_compressors/quantized_compressors/test_asymmetric_decompression.py rename to tests/test_compressors/quantized_compressors/test_packed_asym_decompression.py index 25347a98c..5d005b66c 100644 --- a/tests/test_compressors/quantized_compressors/test_asymmetric_decompression.py +++ b/tests/test_compressors/quantized_compressors/test_packed_asym_decompression.py @@ -76,7 +76,12 @@ def create_asymmetric_quant_config( (QuantizationStrategy.CHANNEL, None), ], ) -def test_end_to_end_asymmetric_quantization(strategy, group_size): +def test_end_to_end_asymmetric_quantization( + strategy, + group_size, + mock_per_group_calibration, + mock_per_channel_calibration, +): """ Test end-to-end workflow: quantize -> compress -> save -> load -> decompress -> use """ @@ -95,6 +100,13 @@ def test_end_to_end_asymmetric_quantization(strategy, group_size): group_size=group_size ) apply_quantization_config(model, quant_config) + + if strategy == QuantizationStrategy.GROUP: + mock_per_group_calibration(model.layer1, "weight", model.layer1.weight, group_size) + mock_per_group_calibration(model.layer2, "weight", model.layer2.weight, group_size) + else: + mock_per_channel_calibration(model.layer1, "weight", model.layer1.weight) + mock_per_channel_calibration(model.layer2, "weight", model.layer2.weight) @@ -146,7 +158,7 @@ def test_end_to_end_asymmetric_quantization(strategy, group_size): @pytest.mark.parametrize("num_bits", [4, 8]) -def test_asymmetric_quantization_accuracy(num_bits): +def test_asymmetric_quantization_accuracy(num_bits, mock_per_group_calibration): """ Test that 
asymmetric quantization with zero-point preserves accuracy better than symmetric quantization for biased weight distributions. @@ -173,6 +185,7 @@ def __init__(self): with torch.no_grad(): model.layer.weight.copy_(biased_weights) + mock_per_group_calibration(model.layer, "weight", model.layer.weight, 128) compressor = PackedQuantizationCompressor(config=quant_config) quantized_modules_to_scheme = {"layer": quant_config.config_groups["group_1"]} From 126fc89515e8f5fcdc04a01b080f73721ceed6de Mon Sep 17 00:00:00 2001 From: Etelis <92247226+Etelis@users.noreply.github.com> Date: Thu, 11 Sep 2025 17:46:45 +0300 Subject: [PATCH 05/11] tests: use in-memory decompress_model; calibrate via fixtures; std-dev similarity; cleanup temp usage --- .../quantized_compressors/test_pack_quant.py | 12 +- .../test_packed_asym_decompression.py | 112 +++++++++--------- 2 files changed, 63 insertions(+), 61 deletions(-) diff --git a/tests/test_compressors/quantized_compressors/test_pack_quant.py b/tests/test_compressors/quantized_compressors/test_pack_quant.py index 5ccd09f83..5cf6da379 100644 --- a/tests/test_compressors/quantized_compressors/test_pack_quant.py +++ b/tests/test_compressors/quantized_compressors/test_pack_quant.py @@ -15,6 +15,7 @@ import math import shutil +import tempfile from collections import OrderedDict import pytest @@ -170,12 +171,13 @@ def test_reload_match(tmp_path, num_bits): ) save_file(compressed_state_dict, tmp_path / "model.safetensors") - reconstructed_dense_gen = compressor.decompress( - tmp_path, names_to_scheme=quantized_modules_to_scheme - ) reconstructed_dense = {} - for name, value in reconstructed_dense_gen: - reconstructed_dense[name] = value + with tempfile.TemporaryDirectory() as _tmp: + reconstructed_dense_gen = compressor.decompress( + tmp_path, names_to_scheme=quantized_modules_to_scheme + ) + for name, value in reconstructed_dense_gen: + reconstructed_dense[name] = value fake_quant_dummy = fake_quantize( dense_state_dict["dummy.weight"], diff --git a/tests/test_compressors/quantized_compressors/test_packed_asym_decompression.py b/tests/test_compressors/quantized_compressors/test_packed_asym_decompression.py index 5d005b66c..62ece296f 100644 --- a/tests/test_compressors/quantized_compressors/test_packed_asym_decompression.py +++ b/tests/test_compressors/quantized_compressors/test_packed_asym_decompression.py @@ -30,8 +30,12 @@ QuantizationStrategy, apply_quantization_config, ) +from compressed_tensors.config import CompressionFormat from compressed_tensors.quantization.lifecycle.forward import fake_quantize from safetensors.torch import save_file +from compressed_tensors.compressors.model_compressors.model_compressor import ( + ModelCompressor, +) from torch.nn import Linear, Module, Sequential @@ -90,8 +94,8 @@ def test_end_to_end_asymmetric_quantization( model = SimpleModel() original_weights = { - "layer1": model.layer1.weight.clone(), - "layer2": model.layer2.weight.clone(), + "layer1": model.layer1.weight.detach().clone(), + "layer2": model.layer2.weight.detach().clone(), } quant_config = create_asymmetric_quant_config( @@ -99,6 +103,8 @@ def test_end_to_end_asymmetric_quantization( strategy=strategy, group_size=group_size ) + # Set pack-quantized format for ModelCompressor usage + quant_config.format = CompressionFormat.pack_quantized.value apply_quantization_config(model, quant_config) if strategy == QuantizationStrategy.GROUP: @@ -126,35 +132,33 @@ def test_end_to_end_asymmetric_quantization( assert compressed_state_dict["layer1.weight_zero_point"].dtype == 
torch.int32 assert compressed_state_dict["layer2.weight_zero_point"].dtype == torch.int32 - save_file(compressed_state_dict, tmp_path / "model.safetensors") - - reconstructed_gen = compressor.decompress( - tmp_path, names_to_scheme=quantized_modules_to_scheme - ) - - reconstructed_weights = {} - for module_name, module_data in reconstructed_gen: - reconstructed_weights[module_name] = module_data - - assert "layer1" in reconstructed_weights - assert "layer2" in reconstructed_weights - assert "weight" in reconstructed_weights["layer1"] - assert "weight" in reconstructed_weights["layer2"] - - assert reconstructed_weights["layer1"]["weight"].shape == original_weights["layer1"].shape - assert reconstructed_weights["layer2"]["weight"].shape == original_weights["layer2"].shape - new_model = SimpleModel() - new_model.layer1.weight.data = reconstructed_weights["layer1"]["weight"] - new_model.layer2.weight.data = reconstructed_weights["layer2"]["weight"] - - test_input = torch.randn(1, 512) - with torch.no_grad(): - output = new_model(test_input) - - assert output.shape == (1, 128) - assert not torch.isnan(output).any() - assert not torch.isinf(output).any() + apply_quantization_config(new_model, quant_config) + + for module_name in ["layer1", "layer2"]: + module = getattr(new_model, module_name) + prefix = f"{module_name}." + for key, value in compressed_state_dict.items(): + if key.startswith(prefix): + param_name = key[len(prefix):] + if hasattr(module, param_name): + getattr(module, param_name).data = value.clone() + else: + module.register_parameter( + param_name, torch.nn.Parameter(value.clone(), requires_grad=False) + ) + + mc = ModelCompressor(quantization_config=quant_config) + mc.decompress_model(new_model) + + assert new_model.layer1.weight.shape == original_weights["layer1"].shape + assert new_model.layer2.weight.shape == original_weights["layer2"].shape + assert new_model.layer1.weight.dtype.is_floating_point + assert new_model.layer2.weight.dtype.is_floating_point + assert not torch.isnan(new_model.layer1.weight).any() + assert not torch.isnan(new_model.layer2.weight).any() + assert not torch.isinf(new_model.layer1.weight).any() + assert not torch.isinf(new_model.layer2.weight).any() @pytest.mark.parametrize("num_bits", [4, 8]) @@ -174,6 +178,7 @@ def test_asymmetric_quantization_accuracy(num_bits, mock_per_group_calibration): strategy=QuantizationStrategy.GROUP, group_size=128, ) + quant_config.format = CompressionFormat.pack_quantized.value class SingleLayer(Module): def __init__(self): @@ -194,31 +199,26 @@ def __init__(self): model.state_dict().copy(), names_to_scheme=quantized_modules_to_scheme ) - save_file(compressed_state_dict, tmp_path / "model.safetensors") - - reconstructed_gen = compressor.decompress( - tmp_path, names_to_scheme=quantized_modules_to_scheme - ) - - reconstructed = {} - for module_name, module_data in reconstructed_gen: - reconstructed[module_name] = module_data - - assert "layer" in reconstructed - assert "weight" in reconstructed["layer"] - assert reconstructed["layer"]["weight"].shape == shape - - decompressed_weights = reconstructed["layer"]["weight"] + new_model = SingleLayer() + apply_quantization_config(new_model, quant_config) + + module = new_model.layer + for key, value in compressed_state_dict.items(): + if key.startswith("layer."): + param_name = key[len("layer."):] + if hasattr(module, param_name): + getattr(module, param_name).data = value.clone() + else: + module.register_parameter( + param_name, torch.nn.Parameter(value.clone(), 
requires_grad=False) + ) + + mc = ModelCompressor(quantization_config=quant_config) + mc.decompress_model(new_model) + + decompressed_weights = new_model.layer.weight + assert decompressed_weights.shape == shape assert not torch.isnan(decompressed_weights).any() assert not torch.isinf(decompressed_weights).any() - - assert decompressed_weights.abs().max() < 100 - assert decompressed_weights.abs().max() > 0.01 - - -if __name__ == "__main__": - test_end_to_end_asymmetric_quantization(QuantizationStrategy.GROUP, 128) - test_end_to_end_asymmetric_quantization(QuantizationStrategy.CHANNEL, None) - test_asymmetric_quantization_accuracy(4) - test_asymmetric_quantization_accuracy(8) - print("All tests passed!") + threshold = torch.std(torch.rand(shape) - torch.rand(shape)) + assert torch.std(biased_weights - decompressed_weights) < threshold \ No newline at end of file From 31fe0cffd42803e8f27d6b18304b6e512ab1074f Mon Sep 17 00:00:00 2001 From: Etelis <92247226+Etelis@users.noreply.github.com> Date: Mon, 20 Oct 2025 16:12:38 +0300 Subject: [PATCH 06/11] refactor: use in-memory compress/decompress methods --- .../test_packed_asym_decompression.py | 213 +++++++----------- 1 file changed, 79 insertions(+), 134 deletions(-) diff --git a/tests/test_compressors/quantized_compressors/test_packed_asym_decompression.py b/tests/test_compressors/quantized_compressors/test_packed_asym_decompression.py index 62ece296f..b43007450 100644 --- a/tests/test_compressors/quantized_compressors/test_packed_asym_decompression.py +++ b/tests/test_compressors/quantized_compressors/test_packed_asym_decompression.py @@ -16,13 +16,8 @@ End-to-end tests for asymmetric quantization with zero-point decompression. """ -import shutil -import tempfile -from pathlib import Path - import pytest import torch -from compressed_tensors import PackedQuantizationCompressor from compressed_tensors.quantization import ( QuantizationArgs, QuantizationConfig, @@ -31,12 +26,10 @@ apply_quantization_config, ) from compressed_tensors.config import CompressionFormat -from compressed_tensors.quantization.lifecycle.forward import fake_quantize -from safetensors.torch import save_file from compressed_tensors.compressors.model_compressors.model_compressor import ( ModelCompressor, ) -from torch.nn import Linear, Module, Sequential +from torch.nn import Linear, Module class SimpleModel(Module): @@ -87,78 +80,52 @@ def test_end_to_end_asymmetric_quantization( mock_per_channel_calibration, ): """ - Test end-to-end workflow: quantize -> compress -> save -> load -> decompress -> use + Test end-to-end workflow: quantize -> compress -> decompress in memory """ - with tempfile.TemporaryDirectory() as tmp_dir: - tmp_path = Path(tmp_dir) - - model = SimpleModel() - original_weights = { - "layer1": model.layer1.weight.detach().clone(), - "layer2": model.layer2.weight.detach().clone(), - } - - quant_config = create_asymmetric_quant_config( - num_bits=4, - strategy=strategy, - group_size=group_size - ) - # Set pack-quantized format for ModelCompressor usage - quant_config.format = CompressionFormat.pack_quantized.value - apply_quantization_config(model, quant_config) - - if strategy == QuantizationStrategy.GROUP: - mock_per_group_calibration(model.layer1, "weight", model.layer1.weight, group_size) - mock_per_group_calibration(model.layer2, "weight", model.layer2.weight, group_size) - else: - mock_per_channel_calibration(model.layer1, "weight", model.layer1.weight) - mock_per_channel_calibration(model.layer2, "weight", model.layer2.weight) - - - - compressor = 
PackedQuantizationCompressor(config=quant_config) - quantized_modules_to_scheme = { - "layer1": quant_config.config_groups["group_1"], - "layer2": quant_config.config_groups["group_1"], - } - - state_dict = model.state_dict() - compressed_state_dict = compressor.compress( - state_dict, names_to_scheme=quantized_modules_to_scheme - ) - - assert "layer1.weight_zero_point" in compressed_state_dict - assert "layer2.weight_zero_point" in compressed_state_dict - assert compressed_state_dict["layer1.weight_zero_point"].dtype == torch.int32 - assert compressed_state_dict["layer2.weight_zero_point"].dtype == torch.int32 - - new_model = SimpleModel() - apply_quantization_config(new_model, quant_config) - - for module_name in ["layer1", "layer2"]: - module = getattr(new_model, module_name) - prefix = f"{module_name}." - for key, value in compressed_state_dict.items(): - if key.startswith(prefix): - param_name = key[len(prefix):] - if hasattr(module, param_name): - getattr(module, param_name).data = value.clone() - else: - module.register_parameter( - param_name, torch.nn.Parameter(value.clone(), requires_grad=False) - ) - - mc = ModelCompressor(quantization_config=quant_config) - mc.decompress_model(new_model) - - assert new_model.layer1.weight.shape == original_weights["layer1"].shape - assert new_model.layer2.weight.shape == original_weights["layer2"].shape - assert new_model.layer1.weight.dtype.is_floating_point - assert new_model.layer2.weight.dtype.is_floating_point - assert not torch.isnan(new_model.layer1.weight).any() - assert not torch.isnan(new_model.layer2.weight).any() - assert not torch.isinf(new_model.layer1.weight).any() - assert not torch.isinf(new_model.layer2.weight).any() + model = SimpleModel() + original_weights = { + "layer1": model.layer1.weight.detach().clone(), + "layer2": model.layer2.weight.detach().clone(), + } + + quant_config = create_asymmetric_quant_config( + num_bits=4, + strategy=strategy, + group_size=group_size + ) + # Set pack-quantized format for ModelCompressor usage + quant_config.format = CompressionFormat.pack_quantized.value + apply_quantization_config(model, quant_config) + + if strategy == QuantizationStrategy.GROUP: + mock_per_group_calibration(model.layer1, "weight", model.layer1.weight, group_size) + mock_per_group_calibration(model.layer2, "weight", model.layer2.weight, group_size) + else: + mock_per_channel_calibration(model.layer1, "weight", model.layer1.weight) + mock_per_channel_calibration(model.layer2, "weight", model.layer2.weight) + + # Compress and decompress in memory using ModelCompressor + mc = ModelCompressor(quantization_config=quant_config) + mc.compress_model(model) + + # Verify compression created zero-point parameters + assert hasattr(model.layer1, "weight_zero_point") + assert hasattr(model.layer2, "weight_zero_point") + assert model.layer1.weight_zero_point.dtype == torch.int32 + assert model.layer2.weight_zero_point.dtype == torch.int32 + + # Decompress in memory + mc.decompress_model(model) + + # Verify decompression restored weights correctly + assert model.layer1.weight.shape == original_weights["layer1"].shape + assert model.layer2.weight.shape == original_weights["layer2"].shape + assert model.layer1.weight.dtype.is_floating_point + assert model.layer2.weight.dtype.is_floating_point + assert not torch.isnan(model.layer1.weight).any() + assert not torch.isnan(model.layer2.weight).any() + assert not torch.isinf(model.layer1.weight).any() + assert not torch.isinf(model.layer2.weight).any() @pytest.mark.parametrize("num_bits", [4, 
8]) @@ -167,58 +134,36 @@ def test_asymmetric_quantization_accuracy(num_bits, mock_per_group_calibration): Test that asymmetric quantization with zero-point preserves accuracy better than symmetric quantization for biased weight distributions. """ - with tempfile.TemporaryDirectory() as tmp_dir: - tmp_path = Path(tmp_dir) - - shape = (256, 512) - biased_weights = torch.randn(shape) + 2.0 - - quant_config = create_asymmetric_quant_config( - num_bits=num_bits, - strategy=QuantizationStrategy.GROUP, - group_size=128, - ) - quant_config.format = CompressionFormat.pack_quantized.value - - class SingleLayer(Module): - def __init__(self): - super().__init__() - self.layer = Linear(shape[1], shape[0], bias=False) - - model = SingleLayer() - apply_quantization_config(model, quant_config) - - with torch.no_grad(): - model.layer.weight.copy_(biased_weights) - mock_per_group_calibration(model.layer, "weight", model.layer.weight, 128) - - compressor = PackedQuantizationCompressor(config=quant_config) - quantized_modules_to_scheme = {"layer": quant_config.config_groups["group_1"]} - - compressed_state_dict = compressor.compress( - model.state_dict().copy(), names_to_scheme=quantized_modules_to_scheme - ) - - new_model = SingleLayer() - apply_quantization_config(new_model, quant_config) - - module = new_model.layer - for key, value in compressed_state_dict.items(): - if key.startswith("layer."): - param_name = key[len("layer."):] - if hasattr(module, param_name): - getattr(module, param_name).data = value.clone() - else: - module.register_parameter( - param_name, torch.nn.Parameter(value.clone(), requires_grad=False) - ) - - mc = ModelCompressor(quantization_config=quant_config) - mc.decompress_model(new_model) - - decompressed_weights = new_model.layer.weight - assert decompressed_weights.shape == shape - assert not torch.isnan(decompressed_weights).any() - assert not torch.isinf(decompressed_weights).any() - threshold = torch.std(torch.rand(shape) - torch.rand(shape)) - assert torch.std(biased_weights - decompressed_weights) < threshold \ No newline at end of file + shape = (256, 512) + biased_weights = torch.randn(shape) + 2.0 + + quant_config = create_asymmetric_quant_config( + num_bits=num_bits, + strategy=QuantizationStrategy.GROUP, + group_size=128, + ) + quant_config.format = CompressionFormat.pack_quantized.value + + class SingleLayer(Module): + def __init__(self): + super().__init__() + self.layer = Linear(shape[1], shape[0], bias=False) + + model = SingleLayer() + apply_quantization_config(model, quant_config) + + with torch.no_grad(): + model.layer.weight.copy_(biased_weights) + mock_per_group_calibration(model.layer, "weight", model.layer.weight, 128) + + # Compress and decompress in memory using ModelCompressor + mc = ModelCompressor(quantization_config=quant_config) + mc.compress_model(model) + mc.decompress_model(model) + + decompressed_weights = model.layer.weight + assert decompressed_weights.shape == shape + assert not torch.isnan(decompressed_weights).any() + assert not torch.isinf(decompressed_weights).any() + threshold = torch.std(torch.rand(shape) - torch.rand(shape)) + assert torch.std(biased_weights - decompressed_weights) < threshold \ No newline at end of file From 3ffb213fe3a791d6296e8b66d1a1170f7107c79a Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Thu, 23 Oct 2025 12:36:05 -0500 Subject: [PATCH 07/11] stylefix Signed-off-by: Brian Dellabetta --- .../model_compressors/model_compressor.py | 16 ++++---- .../quantized_compressors/test_pack_quant.py | 37 
+++++++++---------- .../test_packed_asym_decompression.py | 35 ++++++++++-------- 3 files changed, 46 insertions(+), 42 deletions(-) diff --git a/src/compressed_tensors/compressors/model_compressors/model_compressor.py b/src/compressed_tensors/compressors/model_compressors/model_compressor.py index 3a0fe4903..74a0d3944 100644 --- a/src/compressed_tensors/compressors/model_compressors/model_compressor.py +++ b/src/compressed_tensors/compressors/model_compressors/model_compressor.py @@ -201,9 +201,11 @@ def from_pretrained_model( sparsity_config=sparsity_config, quantization_config=quantization_config, transform_config=transform_config, - compression_formats=[quantization_format] - if isinstance(quantization_format, str) - else quantization_format, + compression_formats=( + [quantization_format] + if isinstance(quantization_format, str) + else quantization_format + ), ) @staticmethod @@ -314,10 +316,10 @@ def __init__( self.quantization_compressor = {} for format in self.compression_formats: - self.quantization_compressor[ - format - ] = BaseCompressor.load_from_registry( - format, config=quantization_config + self.quantization_compressor[format] = ( + BaseCompressor.load_from_registry( + format, config=quantization_config + ) ) # ----- used by hf quantizer ----- # diff --git a/tests/test_compressors/quantized_compressors/test_pack_quant.py b/tests/test_compressors/quantized_compressors/test_pack_quant.py index 5cf6da379..a5a6c8792 100644 --- a/tests/test_compressors/quantized_compressors/test_pack_quant.py +++ b/tests/test_compressors/quantized_compressors/test_pack_quant.py @@ -490,49 +490,48 @@ def test_asymmetric_zero_point_decompression(strategy, group_size, tmp_path): with GROUP and CHANNEL strategies. """ shape = (512, 1024) - + if strategy == QuantizationStrategy.CHANNEL: expected_zp_shape = (shape[0], 1) elif strategy == QuantizationStrategy.GROUP: num_groups = shape[1] // group_size expected_zp_shape = (shape[0], max(num_groups, 1)) - + dense_state_dict = { "dummy.weight": torch.randn(shape), "dummy.weight_scale": torch.rand(expected_zp_shape).to(torch.float32), - "dummy.weight_zero_point": torch.randint(-8, 8, expected_zp_shape).to(torch.int8), + "dummy.weight_zero_point": torch.randint(-8, 8, expected_zp_shape).to( + torch.int8 + ), } - + quant_config = get_dummy_quant_config( - num_bits=4, - strategy=strategy.value, - symmetric=False, - group_size=group_size + num_bits=4, strategy=strategy.value, symmetric=False, group_size=group_size ) - + compressor = PackedQuantizationCompressor(config=quant_config) quantized_modules_to_scheme = {"dummy": quant_config.config_groups["group_1"]} compressed_state_dict = compressor.compress( dense_state_dict.copy(), names_to_scheme=quantized_modules_to_scheme ) - + assert "dummy.weight_zero_point" in compressed_state_dict assert compressed_state_dict["dummy.weight_zero_point"].dtype == torch.int32 - + save_file(compressed_state_dict, tmp_path / "model.safetensors") - + reconstructed_dense_gen = compressor.decompress( tmp_path, names_to_scheme=quantized_modules_to_scheme ) reconstructed_dense = {} for name, value in reconstructed_dense_gen: reconstructed_dense[name] = value - + assert "dummy" in reconstructed_dense assert "weight" in reconstructed_dense["dummy"] - + assert reconstructed_dense["dummy"]["weight"].shape == shape - + shutil.rmtree(tmp_path) @@ -555,14 +554,14 @@ def test_zero_point_pack_unpack_consistency(num_bits, strategy): else: shape = (512, 1) group_size = None - + max_val = (1 << (num_bits - 1)) - 1 min_val = -(1 << (num_bits - 
1)) original_zp = torch.randint(min_val, max_val + 1, shape).to(torch.int8) - + packed_zp = pack_to_int32(original_zp, num_bits, packed_dim=0) - + unpacked_zp = unpack_from_int32(packed_zp, num_bits, shape, packed_dim=0) - + assert torch.equal(original_zp, unpacked_zp) assert unpacked_zp.dtype == torch.int8 diff --git a/tests/test_compressors/quantized_compressors/test_packed_asym_decompression.py b/tests/test_compressors/quantized_compressors/test_packed_asym_decompression.py index b43007450..b6ffcc92d 100644 --- a/tests/test_compressors/quantized_compressors/test_packed_asym_decompression.py +++ b/tests/test_compressors/quantized_compressors/test_packed_asym_decompression.py @@ -34,11 +34,12 @@ class SimpleModel(Module): """Simple model for testing""" + def __init__(self, input_dim=512, hidden_dim=256, output_dim=128): super().__init__() self.layer1 = Linear(input_dim, hidden_dim, bias=False) self.layer2 = Linear(hidden_dim, output_dim, bias=False) - + def forward(self, x): x = self.layer1(x) x = torch.relu(x) @@ -47,9 +48,7 @@ def forward(self, x): def create_asymmetric_quant_config( - num_bits=4, - strategy=QuantizationStrategy.GROUP, - group_size=128 + num_bits=4, strategy=QuantizationStrategy.GROUP, group_size=128 ) -> QuantizationConfig: """Create an asymmetric quantization config""" config_groups = { @@ -58,7 +57,9 @@ def create_asymmetric_quant_config( weights=QuantizationArgs( num_bits=num_bits, strategy=strategy.value, - group_size=group_size if strategy == QuantizationStrategy.GROUP else None, + group_size=( + group_size if strategy == QuantizationStrategy.GROUP else None + ), symmetric=False, ), ), @@ -87,36 +88,38 @@ def test_end_to_end_asymmetric_quantization( "layer1": model.layer1.weight.detach().clone(), "layer2": model.layer2.weight.detach().clone(), } - + quant_config = create_asymmetric_quant_config( - num_bits=4, - strategy=strategy, - group_size=group_size + num_bits=4, strategy=strategy, group_size=group_size ) # Set pack-quantized format for ModelCompressor usage quant_config.format = CompressionFormat.pack_quantized.value apply_quantization_config(model, quant_config) if strategy == QuantizationStrategy.GROUP: - mock_per_group_calibration(model.layer1, "weight", model.layer1.weight, group_size) - mock_per_group_calibration(model.layer2, "weight", model.layer2.weight, group_size) + mock_per_group_calibration( + model.layer1, "weight", model.layer1.weight, group_size + ) + mock_per_group_calibration( + model.layer2, "weight", model.layer2.weight, group_size + ) else: mock_per_channel_calibration(model.layer1, "weight", model.layer1.weight) mock_per_channel_calibration(model.layer2, "weight", model.layer2.weight) - + # Compress and decompress in memory using ModelCompressor mc = ModelCompressor(quantization_config=quant_config) mc.compress_model(model) - + # Verify compression created zero-point parameters assert hasattr(model.layer1, "weight_zero_point") assert hasattr(model.layer2, "weight_zero_point") assert model.layer1.weight_zero_point.dtype == torch.int32 assert model.layer2.weight_zero_point.dtype == torch.int32 - + # Decompress in memory mc.decompress_model(model) - + # Verify decompression restored weights correctly assert model.layer1.weight.shape == original_weights["layer1"].shape assert model.layer2.weight.shape == original_weights["layer2"].shape @@ -166,4 +169,4 @@ def __init__(self): assert not torch.isnan(decompressed_weights).any() assert not torch.isinf(decompressed_weights).any() threshold = torch.std(torch.rand(shape) - torch.rand(shape)) - 
assert torch.std(biased_weights - decompressed_weights) < threshold \ No newline at end of file + assert torch.std(biased_weights - decompressed_weights) < threshold From 2b70136bbf402d930e7332c0ff604b4dfa979030 Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Thu, 23 Oct 2025 13:00:48 -0500 Subject: [PATCH 08/11] style fixes Signed-off-by: Brian Dellabetta --- .../quantized_compressors/test_fp4_quant.py | 4 ++++ .../quantized_compressors/test_pack_quant.py | 8 +++----- .../test_packed_asym_decompression.py | 8 ++++---- 3 files changed, 11 insertions(+), 9 deletions(-) diff --git a/tests/test_compressors/quantized_compressors/test_fp4_quant.py b/tests/test_compressors/quantized_compressors/test_fp4_quant.py index bb405f79d..3169691be 100644 --- a/tests/test_compressors/quantized_compressors/test_fp4_quant.py +++ b/tests/test_compressors/quantized_compressors/test_fp4_quant.py @@ -54,4 +54,8 @@ def test_pack_unpack_odd_dims(): ) with pytest.raises((ValueError, torch._dynamo.exc.Unsupported)): +<<<<<<< Updated upstream:tests/test_compressors/quantized_compressors/test_fp4_quant.py _ = pack_fp4_to_uint8(x) +======= + pack_fp4_to_uint8(x) +>>>>>>> Stashed changes:tests/test_compressors/quantized_compressors/test_nvfp4_quant.py diff --git a/tests/test_compressors/quantized_compressors/test_pack_quant.py b/tests/test_compressors/quantized_compressors/test_pack_quant.py index a5a6c8792..05a8ea647 100644 --- a/tests/test_compressors/quantized_compressors/test_pack_quant.py +++ b/tests/test_compressors/quantized_compressors/test_pack_quant.py @@ -172,7 +172,7 @@ def test_reload_match(tmp_path, num_bits): save_file(compressed_state_dict, tmp_path / "model.safetensors") reconstructed_dense = {} - with tempfile.TemporaryDirectory() as _tmp: + with tempfile.TemporaryDirectory(): reconstructed_dense_gen = compressor.decompress( tmp_path, names_to_scheme=quantized_modules_to_scheme ) @@ -486,8 +486,8 @@ def test_unpack_from_int32(num_bits, values, expected_tensor): ) def test_asymmetric_zero_point_decompression(strategy, group_size, tmp_path): """ - Test that zero-point packing and unpacking works correctly for asymmetric quantization - with GROUP and CHANNEL strategies. + Test that zero-point packing and unpacking works correctly for asymmetric + quantization with GROUP and CHANNEL strategies. 
""" shape = (512, 1024) @@ -550,10 +550,8 @@ def test_zero_point_pack_unpack_consistency(num_bits, strategy): """ if strategy == QuantizationStrategy.GROUP: shape = (512, 8) - group_size = 128 else: shape = (512, 1) - group_size = None max_val = (1 << (num_bits - 1)) - 1 min_val = -(1 << (num_bits - 1)) diff --git a/tests/test_compressors/quantized_compressors/test_packed_asym_decompression.py b/tests/test_compressors/quantized_compressors/test_packed_asym_decompression.py index b6ffcc92d..fb85bedbd 100644 --- a/tests/test_compressors/quantized_compressors/test_packed_asym_decompression.py +++ b/tests/test_compressors/quantized_compressors/test_packed_asym_decompression.py @@ -18,6 +18,10 @@ import pytest import torch +from compressed_tensors.compressors.model_compressors.model_compressor import ( + ModelCompressor, +) +from compressed_tensors.config import CompressionFormat from compressed_tensors.quantization import ( QuantizationArgs, QuantizationConfig, @@ -25,10 +29,6 @@ QuantizationStrategy, apply_quantization_config, ) -from compressed_tensors.config import CompressionFormat -from compressed_tensors.compressors.model_compressors.model_compressor import ( - ModelCompressor, -) from torch.nn import Linear, Module From e157827f533bddf9f4989184aa7eceb840a7a2ec Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Thu, 23 Oct 2025 13:04:52 -0500 Subject: [PATCH 09/11] style fixes Signed-off-by: Brian Dellabetta --- .../test_compressors/quantized_compressors/test_fp4_quant.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/test_compressors/quantized_compressors/test_fp4_quant.py b/tests/test_compressors/quantized_compressors/test_fp4_quant.py index 3169691be..8bd4203ac 100644 --- a/tests/test_compressors/quantized_compressors/test_fp4_quant.py +++ b/tests/test_compressors/quantized_compressors/test_fp4_quant.py @@ -54,8 +54,4 @@ def test_pack_unpack_odd_dims(): ) with pytest.raises((ValueError, torch._dynamo.exc.Unsupported)): -<<<<<<< Updated upstream:tests/test_compressors/quantized_compressors/test_fp4_quant.py - _ = pack_fp4_to_uint8(x) -======= pack_fp4_to_uint8(x) ->>>>>>> Stashed changes:tests/test_compressors/quantized_compressors/test_nvfp4_quant.py From 76daf284d673e62cd59abf24f9c93d21c86a64a6 Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Thu, 23 Oct 2025 13:12:05 -0500 Subject: [PATCH 10/11] style fixes Signed-off-by: Brian Dellabetta --- .../compressors/model_compressors/model_compressor.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/compressed_tensors/compressors/model_compressors/model_compressor.py b/src/compressed_tensors/compressors/model_compressors/model_compressor.py index 27cfec7fa..fde3e1954 100644 --- a/src/compressed_tensors/compressors/model_compressors/model_compressor.py +++ b/src/compressed_tensors/compressors/model_compressors/model_compressor.py @@ -338,10 +338,10 @@ def __init__( self.quantization_compressor = {} for format in self.compression_formats: - self.quantization_compressor[format] = ( - BaseCompressor.load_from_registry( - format, config=quantization_config - ) + self.quantization_compressor[ + format + ] = BaseCompressor.load_from_registry( + format, config=quantization_config ) def get_missing_module_keys(self, model: Module) -> List[str]: From fc4afa1641d206fb9e7ee2475802c53161bb4ca4 Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Thu, 23 Oct 2025 13:12:37 -0500 Subject: [PATCH 11/11] style fixes Signed-off-by: Brian Dellabetta --- 
tests/test_compressors/quantized_compressors/test_fp4_quant.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_compressors/quantized_compressors/test_fp4_quant.py b/tests/test_compressors/quantized_compressors/test_fp4_quant.py index 8bd4203ac..bb405f79d 100644 --- a/tests/test_compressors/quantized_compressors/test_fp4_quant.py +++ b/tests/test_compressors/quantized_compressors/test_fp4_quant.py @@ -54,4 +54,4 @@ def test_pack_unpack_odd_dims(): ) with pytest.raises((ValueError, torch._dynamo.exc.Unsupported)): - pack_fp4_to_uint8(x) + _ = pack_fp4_to_uint8(x)
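
A minimal sketch of the zero-point pack/unpack round trip that these patches implement and test, assuming pack_to_int32 and unpack_from_int32 are importable from the pack_quantized module patched above (the exact import path may differ by version); this mirrors the consistency test added in PATCH 01 and is not part of the patch series itself:

# Minimal sketch of the zero-point round trip exercised by the patches above.
# Assumption: pack_to_int32 / unpack_from_int32 live in the patched module;
# adjust the import path to your installed compressed-tensors version.
import torch
from compressed_tensors.compressors.quantized_compressors.pack_quantized import (
    pack_to_int32,
    unpack_from_int32,
)

num_bits = 4
shape = (512, 8)  # e.g. per-group zero-points for a (512, 1024) weight with group_size=128

# Signed num_bits-wide zero-points, as stored for asymmetric GROUP/CHANNEL quantization
min_val, max_val = -(1 << (num_bits - 1)), (1 << (num_bits - 1)) - 1
zp = torch.randint(min_val, max_val + 1, shape, dtype=torch.int8)

# Pack along dim 0 into int32 words (what compress_weight now stores) ...
packed = pack_to_int32(zp, num_bits, packed_dim=0).contiguous()
# ... and unpack back to the original shape (what decompress_weight now does)
restored = unpack_from_int32(packed, num_bits, shape, packed_dim=0)

assert torch.equal(zp, restored)
assert packed.dtype == torch.int32 and restored.dtype == torch.int8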