     DEFAULT_QUANTIZATION_METHOD,
     QuantizationConfig,
     QuantizationStatus,
+    QuantizationScheme,
+    QuantizationArgs,
+    QuantizationStrategy,
+    QuantizationType,
 )
 from compressed_tensors.quantization.lifecycle import apply_quantization_config
-from compressed_tensors.utils import match_named_modules
+from compressed_tensors.utils import match_named_modules, is_match
 from tests.testing_utils import requires_accelerate
 from transformers import AutoModelForCausalLM

@@ -265,3 +269,100 @@ def test_apply_quantization_config(caplog, target, should_raise_warning):
         assert len(caplog.text) > 0
     else:
         assert len(caplog.text) == 0
+
+
+def test_multi_apply_quantization_config():
+    """
+    Ensure that multiple quantization configs are applied correctly.
+    If a quantization config was previously applied to a module, those
+    changes should be reset when the new quantization config is applied.
+    """
+    model = get_tinyllama_model()
+
+    # FP8 applied to mlp and self_attn.o_proj to validate overwriting
+    qconfig1 = QuantizationConfig(
+        config_groups={
+            "group_0": QuantizationScheme(
+                targets=[
+                    r"re:.*model\.layers\.\d+\.mlp\.(down|gate|up)_proj$",
+                    r"re:.*model\.layers\.\d+\.self_attn\.o_proj$",
+                ],
+                weights=QuantizationArgs(
+                    num_bits=8,
+                    type=QuantizationType.FLOAT,
+                    strategy=QuantizationStrategy.TENSOR,
+                    symmetric=True,
+                    dynamic=False,
+                ),
+                input_activations=QuantizationArgs(
+                    num_bits=8,
+                    type=QuantizationType.FLOAT,
+                    strategy=QuantizationStrategy.TENSOR,
+                    symmetric=True,
+                    dynamic=False,
+                ),
+            )
+        },
+        ignore=["lm_head"],
+    )
+    # W4A16_ASYM applied to self_attn
+    qconfig2 = QuantizationConfig(
+        config_groups={
+            "group_0": QuantizationScheme(
+                targets=[
+                    r"re:.*model\.layers\.\d+\.self_attn\.(k|q|o|v)_proj$",
+                ],
+                weights=QuantizationArgs(
+                    num_bits=4,
+                    type=QuantizationType.INT,
+                    strategy=QuantizationStrategy.GROUP,
+                    group_size=128,
+                    symmetric=False,
+                    dynamic=False,
+                ),
+            )
+        },
+        ignore=["lm_head"],
+    )
+
+    apply_quantization_config(model, qconfig1)
+    apply_quantization_config(model, qconfig2)
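+    # qconfig2 is applied last, so the self_attn projections (including
+    # o_proj, which qconfig1 also targeted) should end up with the
+    # W4A16_ASYM scheme and the earlier FP8 parameters should be removed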
+    for name, module in model.named_modules():
+        if is_match(
+            name, module, qconfig2.config_groups["group_0"].targets, qconfig2.ignore
+        ):
+            # assert W4A16_ASYM parameters are present with correct shape
+            # and FP8 parameters have been removed
+            assert not hasattr(module, "input_scale")
+            assert not hasattr(module, "input_zero_point")
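+            # GROUP weight quantization with group_size=128 yields one scale and
+            # zero point per group of 128 input columns, so the last dim of each
+            # parameter is in_features / 128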
+            weight_scale = getattr(module, "weight_scale", None)
+            assert (
+                weight_scale is not None
+                and weight_scale.shape[:-1] == module.weight.shape[:-1]
+                and weight_scale.shape[-1] == module.weight.shape[-1] / 128
+            )
+            weight_zero_point = getattr(module, "weight_zero_point", None)
+            assert (
+                weight_zero_point is not None
+                and weight_zero_point.shape[:-1] == module.weight.shape[:-1]
+                and weight_zero_point.shape[-1] == module.weight.shape[-1] / 128
+            )
+
+        elif is_match(
+            name, module, qconfig1.config_groups["group_0"].targets, qconfig1.ignore
+        ):
+            # assert FP8 scheme parameters are present with correct shape
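+            # TENSOR strategy keeps a single scale and zero point per tensor,
+            # so every quantization parameter should have shape [1]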
+            input_scale = getattr(module, "input_scale", None)
+            assert input_scale is not None and input_scale.shape == torch.Size([1])
+            input_zero_point = getattr(module, "input_zero_point", None)
+            assert (
+                input_zero_point is not None
+                and input_zero_point.shape == torch.Size([1])
+            )
+            weight_scale = getattr(module, "weight_scale", None)
+            assert weight_scale is not None and weight_scale.shape == torch.Size([1])
+            weight_zero_point = getattr(module, "weight_zero_point", None)
+            assert (
+                weight_zero_point is not None
+                and weight_zero_point.shape == torch.Size([1])
+            )