Commit 550c0ad

squashed/rebased
Signed-off-by: Brian Dellabetta <[email protected]>
1 parent d4354b0 commit 550c0ad

File tree

1 file changed: +47 -31 lines changed
  • src/compressed_tensors/quantization/lifecycle


src/compressed_tensors/quantization/lifecycle/apply.py

Lines changed: 47 additions & 31 deletions
@@ -38,7 +38,11 @@
     infer_quantization_status,
     is_kv_cache_quant_scheme,
 )
-from compressed_tensors.utils.helpers import deprecated, replace_module
+from compressed_tensors.utils.helpers import (
+    fix_fsdp_module_name,
+    deprecated,
+    replace_module,
+)
 from compressed_tensors.utils.match import match_named_modules, match_targets
 from compressed_tensors.utils.offload import update_parameter_data
 from compressed_tensors.utils.safetensors_load import get_safetensors_folder
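The newly imported `fix_fsdp_module_name` is not defined in this diff. A minimal sketch of what such a helper plausibly does, assuming it strips the `_fsdp_wrapped_module` segment that FSDP injects into dotted module names so matched names line up with the unwrapped model (the constant and exact replacement logic below are assumptions, not taken from this commit):

    # Hypothetical reconstruction; the real helper lives in
    # compressed_tensors.utils.helpers and may differ in detail.
    FSDP_WRAPPER_NAME = "_fsdp_wrapped_module"

    def fix_fsdp_module_name(name: str) -> str:
        # Drop the wrapper segment wherever it appears in a dotted path,
        # e.g. "model._fsdp_wrapped_module.layers.0" -> "model.layers.0"
        return name.replace(FSDP_WRAPPER_NAME + ".", "").replace(
            "." + FSDP_WRAPPER_NAME, ""
        )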
@@ -142,36 +146,48 @@ def apply_quantization_config(
         for target in scheme.targets:
             target_to_scheme[target] = scheme
 
-    if run_compressed:
-        from compressed_tensors.linear.compressed_linear import CompressedLinear
-
-    # mark appropriate layers for quantization by setting their quantization schemes
-    for name, submodule in match_named_modules(
-        model, target_to_scheme, config.ignore, warn_on_fail=True
-    ):
-        # mark modules to be quantized by adding
-        # quant scheme to the matching layers
-        matched_targets = match_targets(name, submodule, target_to_scheme)
-        scheme = _scheme_from_targets(target_to_scheme, matched_targets, name)
-        if run_compressed:
-            format = config.format
-            if format != CompressionFormat.dense.value:
-                if isinstance(submodule, torch.nn.Linear):
-                    # TODO: expand to more module types
-                    compressed_linear = CompressedLinear.from_linear(
-                        submodule,
-                        quantization_scheme=scheme,
-                        quantization_format=format,
-                    )
-                    replace_module(model, name, compressed_linear)
-
-        # target matched - add layer and scheme to target list
-        submodule.quantization_scheme = scheme
-
-        names_to_scheme[name] = submodule.quantization_scheme
-
-    # apply current quantization status across all targeted layers
-    apply_quantization_status(model, config.quantization_status)
+        # mark appropriate layers for quantization by setting their quantization schemes
+        for name, submodule in match_named_modules(
+            model, scheme.targets, config.ignore, warn_on_fail=True
+        ):
+            # potentially fix module name to remove FSDP wrapper prefix
+            name = fix_fsdp_module_name(name)
+
+            # mark modules to be quantized by adding
+            # quant scheme to the matching layers
+            scheme = _scheme_from_targets(target_to_scheme, scheme.targets, name)
+            if run_compressed:
+                format = config.format
+                if format != CompressionFormat.dense.value:
+                    if isinstance(submodule, torch.nn.Linear):
+                        from compressed_tensors.linear.compressed_linear import (
+                            CompressedLinear,
+                        )
+
+                        compressed_linear = CompressedLinear.from_linear(
+                            submodule,
+                            quantization_scheme=scheme,
+                            quantization_format=format,
+                        )
+                        replace_module(model, name, compressed_linear)
+
+            # target matched - add layer and scheme to target list
+            submodule.quantization_scheme = scheme
+
+            names_to_scheme[name] = submodule.quantization_scheme
+
+            # apply current quantization status to each targeted submodule
+            apply_quantization_status(submodule, config.quantization_status)
+
+    # TODO warn on ignore not being found, this is useful in debugging
+    # if config.ignore is not None and ignored_submodules is not None:
+    #     if set(config.ignore) - set(ignored_submodules):
+    #         _LOGGER.warning(
+    #             "Some layers that were to be ignored were "
+    #             "not found in the model: "
+    #             f"{set(config.ignore) - set(ignored_submodules)}"
+    #         )
+
     return names_to_scheme
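End to end, the reworked loop is driven by a single call to `apply_quantization_config`. A hedged usage sketch follows; the `QuantizationConfig` field names (`config_groups`, `targets`, `weights`, `ignore`) are assumptions inferred from the identifiers visible in this diff, and the model name is illustrative only:

    # Illustrative only: config field names and model choice are assumed.
    from transformers import AutoModelForCausalLM
    from compressed_tensors.quantization import QuantizationConfig
    from compressed_tensors.quantization.lifecycle.apply import (
        apply_quantization_config,
    )

    model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")

    # a minimal int8 weight-only config targeting Linear layers
    config = QuantizationConfig(
        config_groups={
            "group_0": {
                "targets": ["Linear"],
                "weights": {"num_bits": 8, "type": "int", "symmetric": True},
            }
        },
        ignore=["lm_head"],
    )

    # marks matching submodules with a quantization_scheme and, when
    # run_compressed=True and the format is not dense, swaps eligible
    # torch.nn.Linear layers for CompressedLinear
    names_to_scheme = apply_quantization_config(model, config, run_compressed=False)
    print(f"{len(names_to_scheme)} modules marked for quantization")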
