Commit dea5eab

updates to get_lowest_common_x

Signed-off-by: HDCharles <[email protected]>

1 parent 351568d commit dea5eab

2 files changed (+84, -48 lines)
src/llmcompressor/modifiers/awq/base.py

Lines changed: 44 additions & 23 deletions
@@ -320,21 +320,26 @@ def _set_resolved_mappings(self, model: Module) -> None:
         repeat for model.layer.1 and so on
         """
         resolved_mappings: list[ResolvedMapping] = []
-        module_to_name = {module: name for name, module in model.named_modules()}
-        for mapping_idx, mapping in enumerate(self.mappings):
-            num_skipped_mappings = 0
+
+        module_to_name = {}
+        for name, module in model.named_modules():
+            if module in module_to_name:
+                logger.info(
+                    f"Warning, {name} and {module_to_name[module]} both "
+                    "share the same module, "
+                    "may have trouble resolving mappings."
+                )
+            module_to_name[module] = name
+
+
+
+        for mapping in self.mappings:
 
-            # Use match_modules_set to find coherent sets of modules
             target_patterns = (mapping.smooth_layer, *mapping.balance_layers)
 
             for smooth_layer, *balance_layers in (
-                pbar := tqdm(match_modules_set(model, target_patterns, self.ignore))
+                match_modules_set(model, target_patterns, self.ignore)
             ):
-                pbar.set_description(
-                    f"Resolving mapping {mapping_idx+1}/{len(self.mappings)}"
-                    f" ({num_skipped_mappings} skipped)"
-                )
-
                 smooth_name = module_to_name.get(smooth_layer)
                 balance_names = [
                     module_to_name.get(balance_layer)
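
For context on the new module_to_name loop: two registered names can point at the same module object (a shared or tied submodule, for instance), and a dict keyed by module keeps only one name per object. Below is a minimal sketch of that aliasing with invented module names; it passes remove_duplicate=False because named_modules() yields each module object only once by default:

from torch.nn import Linear, ModuleDict

shared = Linear(4, 4)
model = ModuleDict({"a": shared, "b": shared})  # "b" aliases "a"

module_to_name = {}
# remove_duplicate=False also yields names that alias an already-seen module
for name, module in model.named_modules(remove_duplicate=False):
    if module in module_to_name:
        print(f"{name} and {module_to_name[module]} share the same module")
    module_to_name[module] = name
# prints: b and a share the same module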
@@ -347,14 +352,18 @@ def _set_resolved_mappings(self, model: Module) -> None:
 
                 # skip mapping if any of the balance layers are incompatible
                 if not all_compatible or len(balance_layers) == 0:
-                    num_skipped_mappings += 1
+                    logger.info(
+                        f"skipping AWQ for {smooth_name} for mapping {mapping}" + (
+                            " because found incompatible balance layers"
+                            if not all_compatible else
+                            " because no balance layers were found"
+                        )
+                    )
+
                     continue
-                elif len(balance_layers) == 1:
-                    # for single balance layer, parent is the balance layer
-                    parent_name, parent = balance_names[0], balance_layers[0]
                 else:
                     # for multiple balance layers, find lowest common parent
-                    parent_name, parent = get_lowest_common_parent(balance_names, model)
+                    parent_name, parent = get_lowest_common_module(balance_names, model)
 
                 resolved_mappings.append(
                     ResolvedMapping(
@@ -788,29 +797,41 @@ def _accumulate_mean(
     return (prev_sum + sum_added) / new_count, new_count
 
 
-def get_lowest_common_parent(names: list[str], module: Module) -> tuple[str, Module]:
+def get_lowest_common_module(names: list[str], module: Module) -> tuple[str, Module]:
     """
-    Given a list of names, returns the lowest-scope common parent.
+    Given a list of names, returns the lowest-scope common module.
 
-    NOTE: function excludes parents of type ModuleList, which don't play
+    NOTE: function excludes modules of type ModuleList, which don't play
     nicely with hooks because their forward method is never directly
     called for MoE models. See Qwen3MoeSparseMoeBlock for example, experts
     are selected based on router output and their forward method is called.
     https://github.com/huggingface/transformers/blob/v4.52.4/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py#L233
 
-    Returns name of parent and pointer to parent module
+    Returns name of module and pointer to module
 
     Implementation is a small alteration of os.path.commonprefix
     https://docs.python.org/3/library/os.path.html#os.path.commonprefix
     """
-    s1 = min(names)
-    s2 = max(names)
-    parent_name = ""
+    # adding "." before and after allows for handling a lot of corner
+    # cases which were previously mishandled ([case] -> prefix -> result)
+    # case 0: single module: [.abc.] -> .abc. -> abc
+    # case 1: substring modules: [.abc., .ab.] -> .ab -> ""
+    # case 2: parent & child: [.ab., .ab.a.] -> .ab. -> ab
+    s1 = min(names) + "."
+    s2 = max(names) + "."
+
+    # 1) find longest shared prefix
+    parent_name = "."
     for i, c in enumerate(s1):
         if c != s2[i]:
-            parent_name = s1[:i].rstrip(".")
             break
+        parent_name += c
+
+    # 2) throw away module name fragment and leading dot
+    # ".keep.thro" -> "keep"
+    parent_name = parent_name[1:parent_name.rfind(".")]
 
+    # 3) return first parent that is not a module list
     while True:
         if parent_name == "":
             return "", module

tests/llmcompressor/modifiers/awq/test_base.py

Lines changed: 40 additions & 25 deletions
@@ -2,9 +2,9 @@
 import torch
 from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme
 from pydantic import ValidationError
-
+from torch.nn import Linear
 from llmcompressor.modifiers.awq import AWQMapping, AWQModifier
-from llmcompressor.modifiers.awq.base import get_lowest_common_parent
+from llmcompressor.modifiers.awq.base import get_lowest_common_module
 from llmcompressor.modifiers.factory import ModifierFactory
 
 
@@ -40,16 +40,16 @@ def test_set_resolved_mappings():
     )
     self_attn = torch.nn.ModuleDict(
         {
-            "q_proj": torch.nn.Linear(4, 4),
-            "k_proj": torch.nn.Linear(4, 4),
-            "v_proj": torch.nn.Linear(4, 4),
-            "o_proj": torch.nn.Linear(4, 4),
+            "q_proj": Linear(4, 4),
+            "k_proj": Linear(4, 4),
+            "v_proj": Linear(4, 4),
+            "o_proj": Linear(4, 4),
         }
     )
     mlp = torch.nn.ModuleDict(
         {
-            "up_proj": torch.nn.Linear(4, 10),
-            "down_proj": torch.nn.Linear(10, 4),
+            "up_proj": Linear(4, 10),
+            "down_proj": Linear(10, 4),
         }
     )
     model = torch.nn.ModuleDict(
@@ -100,11 +100,11 @@ def test_set_resolved_mappings():
         {
             "self_attn": torch.nn.ModuleDict(
                 {
-                    "q_proj": torch.nn.Linear(4, 2),
-                    "k_proj": torch.nn.Linear(4, 2),
-                    "v_proj": torch.nn.Linear(4, 2),
-                    "z_proj": torch.nn.Linear(2, 4),
-                    "o_proj": torch.nn.Linear(4, 4),
+                    "q_proj": Linear(4, 2),
+                    "k_proj": Linear(4, 2),
+                    "v_proj": Linear(4, 2),
+                    "z_proj": Linear(2, 4),
+                    "o_proj": Linear(4, 4),
                 }
             )
         }
@@ -192,15 +192,15 @@ def test_validate():
 
 
 @pytest.mark.unit
-def test_get_lowest_common_parent():
+def test_get_lowest_common_module():
     mlp = torch.nn.ModuleDict(
         {
             "experts": torch.nn.ModuleList(
                 [
                     torch.nn.ModuleDict(
                         {
-                            "gate_proj": torch.nn.Linear(4, 2),
-                            "down_proj": torch.nn.Linear(4, 2),
+                            "gate_proj": Linear(4, 2),
+                            "down_proj": Linear(4, 2),
                         }
                     )
                     for _ in range(10)
@@ -210,15 +210,15 @@ def test_get_lowest_common_parent():
     )
     self_attn = torch.nn.ModuleDict(
         {
-            "q_proj": torch.nn.Linear(4, 2),
-            "k_proj": torch.nn.Linear(4, 2),
-            "v_proj": torch.nn.Linear(4, 2),
-            "o_proj": torch.nn.Linear(4, 4),
+            "q_proj": Linear(4, 2),
+            "k_proj": Linear(4, 2),
+            "v_proj": Linear(4, 2),
+            "o_proj": Linear(4, 4),
         }
     )
     model = torch.nn.ModuleDict(
         {
-            "embed_tokens": torch.nn.Linear(4, 2),
+            "embed_tokens": Linear(4, 2),
             "decoder": torch.nn.ModuleDict(
                 {
                     "self_attn": self_attn,
@@ -228,22 +228,37 @@ def test_get_lowest_common_parent():
         }
     )
 
-    parent_name, parent = get_lowest_common_parent(
+    parent_name, parent = get_lowest_common_module(
         ["decoder.mlp.experts.1.gate_proj", "decoder.mlp.experts.4.down_proj"], model
     )
     assert parent_name == "decoder.mlp" and parent == mlp
 
-    parent_name, parent = get_lowest_common_parent(
+    parent_name, parent = get_lowest_common_module(
         ["decoder.self_attn.q_proj", "decoder.self_attn.v_proj"], model
     )
    assert parent_name == "decoder.self_attn" and parent == self_attn
 
-    parent_name, parent = get_lowest_common_parent(
+    parent_name, parent = get_lowest_common_module(
         ["decoder.mlp.experts.1.gate_proj", "decoder.self_attn.v_proj"], model
     )
     assert parent_name == "decoder" and parent == model["decoder"]
 
-    parent_name, parent = get_lowest_common_parent(
+    parent_name, parent = get_lowest_common_module(
         ["embed_tokens", "decoder.self_attn.v_proj"], model
     )
     assert parent_name == "" and parent == model
+
+    m = torch.nn.ModuleDict(
+        {
+            "abc": Linear(3, 3),
+            "ab": torch.nn.ModuleDict({"a": Linear(3, 3)}),
+            "z": Linear(3, 3),
+        }
+    )
+    parent_name, parent = get_lowest_common_module(["abc", "ab"], m)
+    assert parent_name == ""
+    parent_name, parent = get_lowest_common_module(["ab", "ab.a"], m)
+    assert parent_name == "ab"
+    parent_name, parent = get_lowest_common_module(["z"], m)
+    assert parent_name == "z"
+
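
The new assertions at the bottom pin down exactly the inputs the old prefix scan mishandled: parent_name was assigned only inside the `if c != s2[i]` branch, so a single name, or one name that is a dotted prefix of another, never hit the break and fell through to the model root. A string-level comparison of the two versions (helper names here are invented for illustration):

def old_prefix(names):
    # pre-commit logic: assignment happens only when a mismatch is found
    s1, s2 = min(names), max(names)
    parent_name = ""
    for i, c in enumerate(s1):
        if c != s2[i]:
            parent_name = s1[:i].rstrip(".")
            break
    return parent_name


def new_prefix(names):
    # post-commit logic: accumulate matching characters, then trim at the last dot
    s1, s2 = min(names) + ".", max(names) + "."
    parent_name = "."
    for i, c in enumerate(s1):
        if c != s2[i]:
            break
        parent_name += c
    return parent_name[1:parent_name.rfind(".")]


assert old_prefix(["ab", "ab.a"]) == ""    # parent/child fell through to root
assert new_prefix(["ab", "ab.a"]) == "ab"
assert old_prefix(["z"]) == ""             # single module fell through to root
assert new_prefix(["z"]) == "z"
assert new_prefix(["abc", "ab"]) == ""     # substring names: root is correct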
