Commit 03fb664

squashed/rebased
Signed-off-by: Brian Dellabetta <[email protected]>
1 parent a2bfc03 commit 03fb664

File tree

1 file changed (+30, -33 lines)
  • src/compressed_tensors/quantization/lifecycle


src/compressed_tensors/quantization/lifecycle/apply.py

Lines changed: 30 additions & 33 deletions
@@ -13,12 +13,11 @@
 # limitations under the License.

 import logging
-import re
-from collections import OrderedDict, defaultdict
+from collections import OrderedDict
 from copy import deepcopy
 from typing import Dict, Iterable, List, Optional
 from typing import OrderedDict as OrderedDictType
-from typing import Set, Union
+from typing import Union

 import torch
 from compressed_tensors.config import CompressionFormat
@@ -39,7 +38,12 @@
     infer_quantization_status,
     is_kv_cache_quant_scheme,
 )
-from compressed_tensors.utils.helpers import fix_fsdp_module_name, replace_module
+from compressed_tensors.utils.helpers import (
+    fix_fsdp_module_name,
+    deprecated,
+    replace_module,
+)
+from compressed_tensors.utils.match import match_named_modules, match_targets
 from compressed_tensors.utils.offload import update_parameter_data
 from compressed_tensors.utils.safetensors_load import get_safetensors_folder
 from safetensors import safe_open
@@ -51,8 +55,6 @@
     "apply_quantization_config",
     "apply_quantization_status",
     "find_name_or_class_matches",
-    "expand_target_names",
-    "is_target",
 ]

 from compressed_tensors.quantization.utils.helpers import is_module_quantized
@@ -144,31 +146,24 @@ def apply_quantization_config(
         for target in scheme.targets:
             target_to_scheme[target] = scheme

-    if run_compressed:
-        from compressed_tensors.linear.compressed_linear import CompressedLinear
-
-    # list of submodules to ignore
-    ignored_submodules = defaultdict(list)
-    # mark appropriate layers for quantization by setting their quantization schemes
-    for name, submodule in model.named_modules():
-        # potentially fix module name to remove FSDP wrapper prefix
-        name = fix_fsdp_module_name(name)
-        if matches := find_name_or_class_matches(name, submodule, config.ignore):
-            for match in matches:
-                ignored_submodules[match].append(name)
-            continue  # layer matches ignore list, continue
-
-        targets = find_name_or_class_matches(name, submodule, target_to_scheme)
+    # mark appropriate layers for quantization by setting their quantization schemes
+    for name, submodule in match_named_modules(
+        model, scheme.targets, config.ignore, warn_on_fail=True
+    ):
+        # potentially fix module name to remove FSDP wrapper prefix
+        name = fix_fsdp_module_name(name)

-        if targets:
         # mark modules to be quantized by adding
         # quant scheme to the matching layers
-            scheme = _scheme_from_targets(target_to_scheme, targets, name)
+        scheme = _scheme_from_targets(target_to_scheme, scheme.targets, name)
         if run_compressed:
             format = config.format
             if format != CompressionFormat.dense.value:
                 if isinstance(submodule, torch.nn.Linear):
-                    # TODO: expand to more module types
+                    from compressed_tensors.linear.compressed_linear import (
+                        CompressedLinear,
+                    )
+
                     compressed_linear = CompressedLinear.from_linear(
                         submodule,
                         quantization_scheme=scheme,
@@ -181,16 +176,18 @@ def apply_quantization_config(

         names_to_scheme[name] = submodule.quantization_scheme

-    if config.ignore is not None and ignored_submodules is not None:
-        if set(config.ignore) - set(ignored_submodules):
-            _LOGGER.warning(
-                "Some layers that were to be ignored were "
-                "not found in the model: "
-                f"{set(config.ignore) - set(ignored_submodules)}"
-            )
+        # apply current quantization status to each targeted submodule
+        apply_quantization_status(submodule, config.quantization_status)
+
+    # TODO warn on ignore not being found, this is useful in debugging
+    # if config.ignore is not None and ignored_submodules is not None:
+    #     if set(config.ignore) - set(ignored_submodules):
+    #         _LOGGER.warning(
+    #             "Some layers that were to be ignored were "
+    #             "not found in the model: "
+    #             f"{set(config.ignore) - set(ignored_submodules)}"
+    #         )

-    # apply current quantization status across all targeted layers
-    apply_quantization_status(model, config.quantization_status)
     return names_to_scheme

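For readers skimming the diff, the substance of the change: the hand-rolled ignore handling (the ignored_submodules defaultdict, the continue on ignore matches, and the post-loop warning) is replaced by a single iterator, match_named_modules(model, scheme.targets, config.ignore, warn_on_fail=True), and apply_quantization_status is now applied per matched submodule instead of once on the whole model. The sketch below is a hypothetical stand-in for that iteration pattern, not the actual implementation in compressed_tensors.utils.match; the toy_match_named_modules name and the "re:" regex convention are illustrative assumptions.

    import re
    from typing import Iterable, Iterator, Tuple

    import torch


    def toy_match_named_modules(
        model: torch.nn.Module,
        targets: Iterable[str],
        ignore: Iterable[str] = (),
    ) -> Iterator[Tuple[str, torch.nn.Module]]:
        """Yield (name, submodule) pairs that match `targets` and are not on `ignore`."""

        def _matches(name: str, module: torch.nn.Module, patterns: Iterable[str]) -> bool:
            for pattern in patterns:
                # "re:" prefix treated as a regex over the module name (an assumption
                # about how targets are commonly written; the real matcher may differ)
                if pattern.startswith("re:") and re.search(pattern[3:], name):
                    return True
                # otherwise match the exact module name or its class name
                if pattern == name or pattern == module.__class__.__name__:
                    return True
            return False

        for name, submodule in model.named_modules():
            if _matches(name, submodule, ignore):
                continue  # ignored modules are filtered inside the iterator
            if _matches(name, submodule, targets):
                yield name, submodule


    # Example: iterate the way the refactored apply_quantization_config does,
    # with no separate ignored_submodules bookkeeping in the caller.
    model = torch.nn.Sequential(
        torch.nn.Linear(8, 8),
        torch.nn.ReLU(),
        torch.nn.Linear(8, 2),
    )
    for name, submodule in toy_match_named_modules(model, targets=["Linear"], ignore=["2"]):
        print(name, type(submodule).__name__)  # prints "0 Linear"; module "2" is skipped

Because ignored modules never leave the iterator, the caller no longer tracks which ignore entries actually matched, which is why the "ignore pattern not found" warning is left as a commented-out TODO in this commit.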