
Commit 92f8757

multi-apply quantization config test
Signed-off-by: Brian Dellabetta <[email protected]>
1 parent fb6aa9a commit 92f8757

1 file changed: +102 -1 lines changed


tests/test_quantization/lifecycle/test_apply.py

Lines changed: 102 additions & 1 deletion
@@ -23,9 +23,13 @@
     DEFAULT_QUANTIZATION_METHOD,
     QuantizationConfig,
     QuantizationStatus,
+    QuantizationScheme,
+    QuantizationArgs,
+    QuantizationStrategy,
+    QuantizationType,
 )
 from compressed_tensors.quantization.lifecycle import apply_quantization_config
-from compressed_tensors.utils import match_named_modules
+from compressed_tensors.utils import match_named_modules, is_match
 from tests.testing_utils import requires_accelerate
 from transformers import AutoModelForCausalLM
 
@@ -265,3 +269,100 @@ def test_apply_quantization_config(caplog, target, should_raise_warning):
         assert len(caplog.text) > 0
     else:
         assert len(caplog.text) == 0
+
+
+def test_multi_apply_quantization_config():
+    """
+    Ensure that multiple quantization configs are applied correctly
+    If quantization config was previously applied to a module,
+    those changes should be reset for newly applied quantization config
+    """
+    model = get_tinyllama_model()
+
+    # FP8 applied to mlp and self_attn.o_proj to validate overwriting
+    qconfig1 = QuantizationConfig(
+        config_groups={
+            "group_0": QuantizationScheme(
+                targets=[
+                    r"re:.*model\.layers\.\d+\.mlp\.(down|gate|up)_proj$",
+                    r"re:.*model\.layers\.\d+\.self_attn\.o_proj$",
+                ],
+                weights=QuantizationArgs(
+                    num_bits=8,
+                    type=QuantizationType.FLOAT,
+                    strategy=QuantizationStrategy.TENSOR,
+                    symmetric=True,
+                    dynamic=False,
+                ),
+                input_activations=QuantizationArgs(
+                    num_bits=8,
+                    type=QuantizationType.FLOAT,
+                    strategy=QuantizationStrategy.TENSOR,
+                    symmetric=True,
+                    dynamic=False,
+                ),
+            )
+        },
+        ignore=["lm_head"],
+    )
+    # W4A16_ASYM applied to self_attn
+    qconfig2 = QuantizationConfig(
+        config_groups={
+            "group_0": QuantizationScheme(
+                targets=[
+                    r"re:.*model\.layers\.\d+\.self_attn\.(k|q|o|v)_proj$",
+                ],
+                weights=QuantizationArgs(
+                    num_bits=4,
+                    type=QuantizationType.INT,
+                    strategy=QuantizationStrategy.GROUP,
+                    group_size=128,
+                    symmetric=False,
+                    dynamic=False,
+                ),
+            )
+        },
+        ignore=["lm_head"],
+    )
+
+    apply_quantization_config(model, qconfig1)
+    apply_quantization_config(model, qconfig2)
+    for name, module in model.named_modules():
+        if is_match(
+            name, module, qconfig2.config_groups["group_0"].targets, qconfig2.ignore
+        ):
+            # assert W4A16_ASYM parameters are present with correct shape
+            # and FP8 parameters have been removed
+            assert not hasattr(module, "input_scale")
+            assert not hasattr(module, "input_zero_point")
+            weight_scale = getattr(module, "weight_scale", None)
+            assert (
+                weight_scale is not None
+                and weight_scale.shape[:-1] == module.weight.shape[:-1]
+                and weight_scale.shape[-1] == module.weight.shape[-1] / 128
+            )
+            weight_zero_point = getattr(module, "weight_zero_point", None)
+            assert (
+                weight_zero_point is not None
+                and weight_zero_point.shape[:-1] == module.weight.shape[:-1]
+                and weight_zero_point.shape[-1] == module.weight.shape[-1] / 128
+            )
+
+        elif is_match(
+            name, module, qconfig1.config_groups["group_0"].targets, qconfig1.ignore
+        ):
+            # assert FP8 scheme parameters are present with correct shape
+            input_scale = getattr(module, "input_scale", None)
+            assert input_scale is not None and input_scale.shape == torch.Size([1])
+            input_zero_point = getattr(module, "input_zero_point", None)
+            assert (
+                input_zero_point is not None
+                and input_zero_point.shape == torch.Size([1])
+            )
+            weight_scale = getattr(module, "weight_scale", None)
+            assert weight_scale is not None and weight_scale.shape == torch.Size([1])
+            weight_zero_point = getattr(module, "weight_zero_point", None)
+            assert (
+                weight_zero_point is not None
+                and weight_zero_point.shape == torch.Size([1])
+            )
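
For reference, the shape assertions in this test follow from the two strategies in play: a TENSOR-strategy scale/zero-point is a single element per tensor, while a GROUP-strategy scale/zero-point carries one element per group of group_size input columns. The snippet below is a minimal standalone sketch of that arithmetic, using a hypothetical 2048x2048 projection weight rather than the actual TinyLlama layer dimensions:

import torch

# Hypothetical projection weight; real TinyLlama layer shapes may differ.
weight = torch.empty(2048, 2048)
group_size = 128  # matches qconfig2's QuantizationArgs(group_size=128)

# TENSOR strategy (qconfig1): one scale/zero-point for the whole tensor.
tensor_scale_shape = torch.Size([1])

# GROUP strategy (qconfig2): one scale/zero-point per group of `group_size`
# columns, so the last weight dimension shrinks by a factor of group_size,
# which is exactly what the test asserts.
group_scale_shape = torch.Size([weight.shape[0], weight.shape[1] // group_size])

assert group_scale_shape[:-1] == weight.shape[:-1]
assert group_scale_shape[-1] == weight.shape[-1] / group_size

print(tensor_scale_shape)  # torch.Size([1])
print(group_scale_shape)   # torch.Size([2048, 16])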
