
Commit 4ecc2af

HDCharles, kylesayrs, and fynnsu authored
[AWQ] use match_modules_set and fix logic (#2070)
Depends on vllm-project/compressed-tensors#524

Summary:

- modified AWQ _set_resolved_mappings
  - get smoothing and balance layers at the same time using match_modules_set
  - (bugfix) correct logic so that if any balance layers are incompatible, that matching is skipped
  - added warnings
  - get rid of tqdm and skip counting @kylesayrs
  - added helper for module_to_name
  - remove hardcoded handling for a single balance layer by updating get_lowest_common_module to handle that
- modified SmoothQuant _resolve_mappings
  - brought into alignment with AWQ
  - this is largely a horizontal move, though there is handling for situations that would have been missed before, like:
    - multiple smooth layer matches in a single set
    - parent contexts further than 1 layer away
- updated mapping definitions to always be tuple(list[str], str), which was always the case but, unlike in AWQ, wasn't required
- removed get_lowest_common_parent
  - now we can use CT's get_lowest_common_ancestor_name, so we only need to check for ModuleList (it has a lot of bugfixes compared to the get_lowest_common_parent implementation in LLMC)
- updated test_base for AWQ and SmoothQuant
  - added a test case for _set_resolved_mappings to check that partially skipped matches are handled correctly
  - added tests for MoE matching being handled correctly
  - added test cases for get_lowest_non_module_list_ancestor
  - imported Linear and used that instead of torch.nn.Linear
- reverted test_pytorch.py for logarithmic_equalizations and smoothquant
  - The test was updated in #2084 by @rahul-tuli to ignore some modules, but in general, because of the way the new logic works, you need to ignore the whole set.
  - If you only ignore one element, the matching logic would need to determine whether there's a full set or not *somehow*, which it doesn't do. In the previous logic this was possible because it was assumed the whole set had to be siblings of the smooth_layer, but the new util tries to be more flexible and relaxes this assumption, which prevents the same approach from working. If this is a common need, perhaps we can add a util that checks for a parent context of size N.

TEST PLAN:
pytest /home/HDCharles/repos/llm-compressor/tests/llmcompressor/modifiers/awq/test_base.py
pytest /home/HDCharles/repos/llm-compressor/tests/llmcompressor/modifiers/smoothquant/test_base.py

---------

Signed-off-by: HDCharles <[email protected]>
Co-authored-by: Kyle Sayers <[email protected]>
Co-authored-by: Fynn Schmitt-Ulms <[email protected]>
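For context, a minimal sketch of the new matching flow. The toy model and regex targets are illustrative (not from this PR), and it assumes the compressed-tensors version from vllm-project/compressed-tensors#524, where each element of a matched set is a list of modules:

```python
# Illustrative sketch only: match_modules_set yields one list of modules per
# target pattern, grouped into complete sets (here, one set per Block).
import torch.nn as nn
from compressed_tensors.utils import match_modules_set
from torch.utils._pytree import tree_leaves

class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.input_layernorm = nn.LayerNorm(8)
        self.q_proj = nn.Linear(8, 8)
        self.k_proj = nn.Linear(8, 8)
        self.v_proj = nn.Linear(8, 8)

model = nn.ModuleDict({"layers": nn.ModuleList([Block(), Block()])})

# e.g. ([input_layernorm], [q_proj], [k_proj], [v_proj]) per Block:
# the first group is the smooth layer, the rest are nested balance layers
for smooth_layers, *nested_balance_layers in match_modules_set(
    model, ("re:.*input_layernorm$", "re:.*q_proj$", "re:.*k_proj$", "re:.*v_proj$")
):
    smooth_layer = smooth_layers[0]                      # exactly one expected
    balance_layers = tree_leaves(nested_balance_layers)  # [[q],[k],[v]] -> [q,k,v]
    assert len(balance_layers) == 3
```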
1 parent 0479bdf commit 4ecc2af

File tree

8 files changed: +221 -177 lines changed

src/llmcompressor/modifiers/awq/base.py

Lines changed: 99 additions & 88 deletions
@@ -1,18 +1,21 @@
 import inspect
 from itertools import product
-from typing import Literal
+from typing import Iterator, Literal

 import torch
 from compressed_tensors.quantization import disable_quantization
 from compressed_tensors.utils import (
     align_modules,
     get_execution_device,
+    get_lowest_common_ancestor_name,
+    match_modules_set,
     match_named_modules,
     update_offload_parameter,
 )
 from loguru import logger
 from pydantic import ConfigDict, PrivateAttr, model_validator
 from torch.nn import Module
+from torch.utils._pytree import tree_leaves
 from tqdm import tqdm

 from llmcompressor.core import Event, EventType, State
@@ -28,7 +31,9 @@
 from llmcompressor.pipelines.cache import IntermediatesCache
 from llmcompressor.utils.fsdp.helpers import get_fsdp_parent
 from llmcompressor.utils.helpers import calibration_forward_context
-from llmcompressor.utils.pytorch.module import get_layer_by_name
+from llmcompressor.utils.pytorch.module import (
+    get_module_to_name_dict,
+)

 __all__ = ["AWQModifier"]

@@ -319,73 +324,57 @@ def _set_resolved_mappings(self, model: Module) -> None:
         repeat for model.layer.1 and so on
         """
         resolved_mappings: list[ResolvedMapping] = []
-        for mapping_idx, mapping in enumerate(self.mappings):
-            num_skipped_mappings = 0
-
-            for smooth_name, smooth_layer in (
-                pbar := tqdm(
-                    match_named_modules(model, [mapping.smooth_layer], self.ignore)
-                )
+        module_to_name = get_module_to_name_dict(model)
+        for mapping in self.mappings:
+            for smooth_layers, *nested_balance_layers in match_modules_set(
+                model, (mapping.smooth_layer, *mapping.balance_layers), self.ignore
             ):
-                pbar.set_description(
-                    f"Resolving mapping {mapping_idx+1}/{len(self.mappings)}"
-                    f" ({num_skipped_mappings} skipped)"
+                if len(smooth_layers) > 1:
+                    raise ValueError(
+                        "AWQ needs to match a single smooth layer for each mapping but "
+                        f"got {[module_to_name.get(s) for s in smooth_layers]}"
+                        f" for mapping: {mapping}"
+                    )
+                smooth_layer = smooth_layers[0]
+                smooth_name = module_to_name.get(smooth_layer)
+
+                # [[b00, b01, b02...], [b10, b11, b12,...], ...] ↓
+                # [b00, b01, b02, ..., b10, b11, b12, ...]
+                balance_layers = tree_leaves(nested_balance_layers)
+                balance_names = [
+                    module_to_name.get(balance_layer)
+                    for balance_layer in balance_layers
+                ]
+
+                all_compatible = _check_layers_are_compatible(
+                    smooth_layer, smooth_name, balance_layers, balance_names
                 )

-                smooth_parent_name = ".".join(smooth_name.split(".")[:-1])
-                smooth_parent = get_layer_by_name(smooth_parent_name, model)
-
-                balance_layers, balance_names = [], []
-                for balance_regex in mapping.balance_layers:
-                    # find the submodules that match the activation layer
-                    for balance_suffix, balance_layer in match_named_modules(
-                        smooth_parent, [balance_regex], self.ignore
-                    ):
-                        balance_name = f"{smooth_parent_name}.{balance_suffix}"
-
-                        # exclude v_proj->o_proj mappings whose shapes are incompatible
-                        # https://github.com/mit-han-lab/llm-awq/pull/67#issuecomment-1681632777
-                        if (
-                            isinstance(smooth_layer, torch.nn.Linear)
-                            and isinstance(balance_layer, torch.nn.Linear)
-                            and balance_name.endswith(".o_proj")
-                            and (
-                                (
-                                    smooth_name.endswith(".v_proj")
-                                    and smooth_layer.out_features
-                                    != balance_layer.in_features
-                                )
-                                or (
-                                    smooth_name.endswith(".qkv_proj")
-                                    and smooth_layer.out_features
-                                    != 3 * balance_layer.in_features
-                                )
-                            )
-                        ):
-                            num_skipped_mappings += 1
-                            continue
-
-                        balance_layers.append(balance_layer)
-                        balance_names.append(balance_name)
+                # skip mapping if any of the balance layers are incompatible
+                if not all_compatible or len(balance_layers) == 0:
+                    logger.warning(
+                        f"skipping AWQ for {smooth_name} for mapping {mapping}"
+                        + (
+                            " because found incompatible balance layers"
+                            if not all_compatible
+                            else " because no balance layers were found"
+                        )
+                    )

-                if len(balance_layers) == 0:
                     continue

-                elif len(balance_layers) == 1:
-                    # for single balance layer, parent is the balance layer
-                    parent_name, parent = balance_name, balance_layer
-                else:
-                    # for multiple balance layers, find lowest common parent
-                    parent_name, parent = get_lowest_common_parent(balance_names, model)
+                ancestor_name, ancestor = get_lowest_common_ancestor_with_avoid(
+                    balance_names, model, torch.nn.ModuleList
+                )

                 resolved_mappings.append(
                     ResolvedMapping(
                         smooth_name,
                         smooth_layer,
                         balance_layers,
                         balance_names=balance_names,
-                        parent=parent,
-                        parent_name=parent_name,
+                        parent=ancestor,
+                        parent_name=ancestor_name,
                     )
                 )
         self._resolved_mappings = resolved_mappings
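
The tree_leaves flattening shown in the hunk above is easy to sanity-check in isolation; a standalone snippet (strings stand in for the matched modules):

```python
# torch's pytree treats the nested per-pattern lists as one tree and returns
# the leaves in order, matching the [[b00, b01], [b10, b11]] comment above.
from torch.utils._pytree import tree_leaves

nested_balance_layers = [["b00", "b01", "b02"], ["b10", "b11"]]
assert tree_leaves(nested_balance_layers) == ["b00", "b01", "b02", "b10", "b11"]
```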
@@ -721,6 +710,60 @@ def _assert_all_activations_consumed(self):
         raise RuntimeError("Some cached activations were not used")


+def _check_layers_are_compatible(
+    smooth_layer, smooth_name, balance_layers, balance_names
+):
+    """
+    returns True if they are all compatible
+    returns False if any smooth & balance layers are incompatible
+    """
+    for balance_layer, balance_name in zip(balance_layers, balance_names):
+        # exclude v_proj->o_proj mappings whose shapes are incompatible
+        # https://github.com/mit-han-lab/llm-awq/pull/67#issuecomment-1681632777
+        if (
+            isinstance(smooth_layer, torch.nn.Linear)
+            and isinstance(balance_layer, torch.nn.Linear)
+            and balance_name.endswith(".o_proj")
+            and (
+                (
+                    smooth_name.endswith(".v_proj")
+                    and smooth_layer.out_features != balance_layer.in_features
+                )
+                or (
+                    smooth_name.endswith(".qkv_proj")
+                    and smooth_layer.out_features != 3 * balance_layer.in_features
+                )
+            )
+        ):
+            return False
+    return True
+
+
+def get_lowest_common_ancestor_with_avoid(
+    balance_names: Iterator[str], model: Module, avoid=torch.nn.ModuleList
+):
+    """
+    Get the lowest ancestor that is not the avoided class/type.
+    see compressed_tensors.utils.get_lowest_common_ancestor_name
+    for detail on case handling.
+
+    NOTE: primarily used to exclude parents of type ModuleList, which don't play
+    nicely with hooks because their forward method is never directly
+    called for MoE models. See Qwen3MoeSparseMoeBlock for example: experts
+    are selected based on router output and their forward method is called.
+    https://github.com/huggingface/transformers/blob/v4.52.4/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py#L233
+    """
+    ancestor_name = get_lowest_common_ancestor_name(balance_names)
+
+    while True:
+        if ancestor_name == "":
+            return "", model
+        ancestor = model.get_submodule(ancestor_name)
+        if not isinstance(ancestor, avoid):
+            return ancestor_name, ancestor
+        ancestor_name = ".".join(ancestor_name.split(".")[:-1])
+
+
 def _pseudo_quantize_tensor(
     w: torch.Tensor, symmetric: bool = False, bit_width: int = 8, group_size: int = -1
 ):
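
A small usage sketch of the new helper on a toy MoE-style layout (assuming the post-commit import path; module names are illustrative): the literal lowest common ancestor of two experts is a ModuleList, so the helper walks one level further up.

```python
# The lowest common ancestor of the two expert linears is the ModuleList
# "experts", which the helper skips over, returning the enclosing block.
import torch.nn as nn
from llmcompressor.modifiers.awq.base import get_lowest_common_ancestor_with_avoid

class ToyMoeBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.experts = nn.ModuleList([nn.Linear(8, 8) for _ in range(2)])

model = nn.ModuleDict({"block": ToyMoeBlock()})
name, ancestor = get_lowest_common_ancestor_with_avoid(
    ["block.experts.0", "block.experts.1"], model
)
assert name == "block" and isinstance(ancestor, ToyMoeBlock)
```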
@@ -779,35 +822,3 @@ def _accumulate_mean(
     new_count = prev_count + num_added

     return (prev_sum + sum_added) / new_count, new_count
-
-
-def get_lowest_common_parent(names: list[str], module: Module) -> tuple[str, Module]:
-    """
-    Given a list of names, returns the lowest-scope common parent.
-
-    NOTE: function excludes parents of type ModuleList, which don't play
-    nicely with hooks because their forward method is never directly
-    called for MoE models. See Qwen3MoeSparseMoeBlock for example, experts
-    are selected based on router output and their forward method is called.
-    https://github.com/huggingface/transformers/blob/v4.52.4/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py#L233
-
-    Returns name of parent and pointer to parent module
-
-    Implementation is a small alteration of os.path.commonprefix
-    https://docs.python.org/3/library/os.path.html#os.path.commonprefix
-    """
-    s1 = min(names)
-    s2 = max(names)
-    parent_name = ""
-    for i, c in enumerate(s1):
-        if c != s2[i]:
-            parent_name = s1[:i].rstrip(".")
-            break
-
-    while True:
-        if parent_name == "":
-            return "", module
-        parent = get_layer_by_name(parent_name, module)
-        if not isinstance(parent, torch.nn.ModuleList):
-            return parent_name, parent
-        parent_name = ".".join(parent_name.split(".")[:-1])

src/llmcompressor/modifiers/smoothquant/base.py

Lines changed: 18 additions & 21 deletions
@@ -2,10 +2,11 @@
 from typing import Callable, Dict, List, Optional, Tuple, Union

 import torch
-from compressed_tensors.utils import align_module_device, match_named_modules
+from compressed_tensors.utils import align_module_device, match_modules_set
 from loguru import logger
 from pydantic import ConfigDict, Field
 from torch.nn import Module
+from torch.utils._pytree import tree_leaves

 from llmcompressor.core import Event, EventType, State
 from llmcompressor.modifiers import Modifier
@@ -14,7 +15,7 @@
     handle_mapping_resolution_errors,
 )
 from llmcompressor.utils.fsdp.helpers import get_fsdp_parent
-from llmcompressor.utils.pytorch.module import get_layer_by_name
+from llmcompressor.utils.pytorch.module import get_module_to_name_dict

 MINIMUM_SMOOTHING_SCALE = 1e-5

@@ -198,27 +199,23 @@ def _resolve_mappings(self, model: Module) -> List[SmoothQuantMapping]:
         be balanced.
         """
         resolved_mappings = []
-        for to_balance, to_smooth in self.mappings:
-            to_smooth_list = [to_smooth] if isinstance(to_smooth, str) else to_smooth
-
-            for smooth_name, smooth_layer in match_named_modules(
-                model, to_smooth_list, self.ignore
+        module_to_name = get_module_to_name_dict(model)
+        for mapping in self.mappings:
+            for *nested_balance_layers, smooth_layers in match_modules_set(
+                model, tree_leaves(mapping), self.ignore
             ):
-                # Search for balance layers within the parent scope
-                smooth_parent_name = ".".join(smooth_name.split(".")[:-1])
-                smooth_parent = get_layer_by_name(smooth_parent_name, model)
-
-                balance_layers = [
-                    balance_layer
-                    for _, balance_layer in match_named_modules(
-                        smooth_parent, to_balance, self.ignore
-                    )
-                ]
-
-                if balance_layers:
-                    resolved_mappings.append(
-                        SmoothQuantMapping(smooth_name, smooth_layer, balance_layers)
+                if len(smooth_layers) > 1:
+                    raise ValueError(
+                        "SmoothQuant must match a single smooth layer for each mapping"
+                        f" but got {[module_to_name.get(s) for s in smooth_layers]}"
+                        f" for mapping: {mapping}"
                     )
+                smooth_layer = smooth_layers[0]
+                smooth_name = module_to_name.get(smooth_layers[0])
+                balance_layers = tree_leaves(nested_balance_layers)
+                resolved_mappings.append(
+                    SmoothQuantMapping(smooth_name, smooth_layer, balance_layers)
+                )

         return resolved_mappings
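
Note the unpack order mirrors the mapping definition: SmoothQuant mappings are (balance_layers, smooth_layers), so the smooth group binds last. A standalone check with illustrative targets:

```python
# tree_leaves(mapping) flattens the LayerMap-style tuple into the target list
# for match_modules_set, with the single smooth target last; the starred
# unpack in the hunk above then binds it last. Regex targets illustrative.
from torch.utils._pytree import tree_leaves

mapping = (["re:.*q_proj", "re:.*k_proj", "re:.*v_proj"], "re:.*input_layernorm")
assert tree_leaves(mapping) == [
    "re:.*q_proj",
    "re:.*k_proj",
    "re:.*v_proj",
    "re:.*input_layernorm",
]
```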

src/llmcompressor/modifiers/smoothquant/utils.py

Lines changed: 2 additions & 2 deletions
@@ -1,6 +1,6 @@
 import functools
 from collections import namedtuple
-from typing import Dict, List, Tuple, Union
+from typing import Dict, List, Tuple

 from loguru import logger

@@ -10,7 +10,7 @@
     "DEFAULT_SMOOTHQUANT_MAPPINGS",
 ]

-LayerMapType = Tuple[Union[List[str], str], Union[List[str], str]]
+LayerMapType = Tuple[List[str], str]
 LayerMap: LayerMapType = namedtuple("LayerMap", ["balance_layers", "smooth_layers"])

 DEFAULT_SMOOTHQUANT_MAPPINGS: List[LayerMap] = [
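
Under the tightened LayerMapType, every entry pairs a list of balance-layer patterns with a single smooth-layer pattern. A sketch of an entry in the style of the defaults (regex targets illustrative, not copied from the file):

```python
# One mapping = (list of balance patterns, one smooth pattern), matching
# LayerMapType = Tuple[List[str], str]; patterns here are illustrative.
from collections import namedtuple

LayerMap = namedtuple("LayerMap", ["balance_layers", "smooth_layers"])

mlp_mapping = LayerMap(
    balance_layers=["re:.*gate_proj", "re:.*up_proj"],
    smooth_layers="re:.*post_attention_layernorm",
)
```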

src/llmcompressor/modifiers/transform/spinquant/base.py

Lines changed: 4 additions & 1 deletion
@@ -11,6 +11,7 @@
 )
 from compressed_tensors.utils import TorchDtype, get_head_dim
 from pydantic import Field, ValidationInfo, field_validator
+from torch.utils._pytree import tree_leaves
 from transformers import PreTrainedModel

 from llmcompressor.core import Event, EventType, State
@@ -205,7 +206,9 @@ def _fuse_norms(self, model: PreTrainedModel):
         for norm, *linears in match_modules_set(
             model, (mapping.norm, *mapping.linears)
         ):
-            fuse_norm_linears(norm, linears)
+            # match_modules_set returns a list of lists
+            assert len(norm) == 1
+            fuse_norm_linears(norm[0], tree_leaves(linears))

     def _create_r1_scheme(self) -> TransformScheme:
         return TransformScheme(

src/llmcompressor/utils/pytorch/module.py

Lines changed: 13 additions & 0 deletions
@@ -10,6 +10,7 @@
 import torch
 from compressed_tensors import InternalModule
 from compressed_tensors.quantization.utils import is_module_quantized
+from loguru import logger
 from torch.nn import Linear, Module, Parameter
 from torch.nn.modules.conv import _ConvNd
 from transformers import PreTrainedModel
@@ -369,3 +370,15 @@ def get_layer_by_name(layer_name: str, module: Module) -> Module:
     if not layer_name:
         return module
     return attrgetter(layer_name)(module)
+
+
+def get_module_to_name_dict(model: Module) -> dict[Module, str]:
+    module_to_name = {}
+    for name, module in model.named_modules():
+        if module in module_to_name:
+            logger.warning(
+                f"{name} and {module_to_name[module]} both "
+                "share the same module, which can result in unexpected behavior"
+            )
+        module_to_name[module] = name
+    return module_to_name
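
A short usage sketch of the new reverse (module-to-name) lookup, using a toy model (assumes the post-commit import path):

```python
# Build the module -> dotted-name map once, then look up modules directly.
import torch.nn as nn
from llmcompressor.utils.pytorch.module import get_module_to_name_dict

model = nn.Sequential(nn.Linear(4, 4), nn.ReLU())
module_to_name = get_module_to_name_dict(model)
assert module_to_name[model] == ""      # root module is registered under ""
assert module_to_name[model[0]] == "0"  # children keyed by dotted path
```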
