@@ -38,7 +38,11 @@
     infer_quantization_status,
     is_kv_cache_quant_scheme,
 )
-from compressed_tensors.utils.helpers import deprecated, replace_module
+from compressed_tensors.utils.helpers import (
+    fix_fsdp_module_name,
+    deprecated,
+    replace_module,
+)
 from compressed_tensors.utils.match import match_named_modules, match_targets
 from compressed_tensors.utils.offload import update_parameter_data
 from compressed_tensors.utils.safetensors_load import get_safetensors_folder
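The newly imported `fix_fsdp_module_name` matters because PyTorch FSDP nests wrapped submodules under a `_fsdp_wrapped_module` segment, so names returned by `named_modules()` on a wrapped model would otherwise fail to match config targets. A minimal sketch of what such a helper does, assuming simple string stripping (the actual implementation in `compressed_tensors.utils.helpers` may differ):

```python
# Sketch of an FSDP name-normalization helper; assumed behavior only, the
# real fix_fsdp_module_name in compressed_tensors.utils.helpers may differ.
FSDP_WRAPPER_NAME = "_fsdp_wrapped_module"  # name segment added by torch FSDP


def fix_fsdp_module_name(name: str) -> str:
    # drop the wrapper segment wherever it appears in the dotted module path
    return name.replace(FSDP_WRAPPER_NAME + ".", "").replace(
        "." + FSDP_WRAPPER_NAME, ""
    )


# "layers.0._fsdp_wrapped_module.self_attn.q_proj" -> "layers.0.self_attn.q_proj"
print(fix_fsdp_module_name("layers.0._fsdp_wrapped_module.self_attn.q_proj"))
```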
@@ -142,36 +146,48 @@ def apply_quantization_config(
         for target in scheme.targets:
             target_to_scheme[target] = scheme

-    if run_compressed:
-        from compressed_tensors.linear.compressed_linear import CompressedLinear
-
-    # mark appropriate layers for quantization by setting their quantization schemes
-    for name, submodule in match_named_modules(
-        model, target_to_scheme, config.ignore, warn_on_fail=True
-    ):
-        # mark modules to be quantized by adding
-        # quant scheme to the matching layers
-        matched_targets = match_targets(name, submodule, target_to_scheme)
-        scheme = _scheme_from_targets(target_to_scheme, matched_targets, name)
-        if run_compressed:
-            format = config.format
-            if format != CompressionFormat.dense.value:
-                if isinstance(submodule, torch.nn.Linear):
-                    # TODO: expand to more module types
-                    compressed_linear = CompressedLinear.from_linear(
-                        submodule,
-                        quantization_scheme=scheme,
-                        quantization_format=format,
-                    )
-                    replace_module(model, name, compressed_linear)
-
-        # target matched - add layer and scheme to target list
-        submodule.quantization_scheme = scheme
-
-        names_to_scheme[name] = submodule.quantization_scheme
-
-    # apply current quantization status across all targeted layers
-    apply_quantization_status(model, config.quantization_status)
+        # mark appropriate layers for quantization by setting their quantization schemes
+        for name, submodule in match_named_modules(
+            model, scheme.targets, config.ignore, warn_on_fail=True
+        ):
+            # potentially fix module name to remove FSDP wrapper prefix
+            name = fix_fsdp_module_name(name)
+
+            # mark modules to be quantized by adding
+            # quant scheme to the matching layers
+            scheme = _scheme_from_targets(target_to_scheme, scheme.targets, name)
+            if run_compressed:
+                format = config.format
+                if format != CompressionFormat.dense.value:
+                    if isinstance(submodule, torch.nn.Linear):
+                        from compressed_tensors.linear.compressed_linear import (
+                            CompressedLinear,
+                        )
+
+                        compressed_linear = CompressedLinear.from_linear(
+                            submodule,
+                            quantization_scheme=scheme,
+                            quantization_format=format,
+                        )
+                        replace_module(model, name, compressed_linear)
+
+            # target matched - add layer and scheme to target list
+            submodule.quantization_scheme = scheme
+
+            names_to_scheme[name] = submodule.quantization_scheme
+
+            # apply current quantization status to each targeted submodule
+            apply_quantization_status(submodule, config.quantization_status)
+
+    # TODO warn on ignore not being found, this is useful in debugging
+    # if config.ignore is not None and ignored_submodules is not None:
+    #     if set(config.ignore) - set(ignored_submodules):
+    #         _LOGGER.warning(
+    #             "Some layers that were to be ignored were "
+    #             "not found in the model: "
+    #             f"{set(config.ignore) - set(ignored_submodules)}"
+    #         )
+
     return names_to_scheme
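For orientation, here is a hedged usage sketch of the entry point this hunk rewrites. It assumes the pydantic-based `QuantizationConfig` from compressed-tensors, and the config values (targets, bit width, status) are illustrative placeholders, not taken from this PR; with `run_compressed=False` and a dense format, the loop only attaches a `quantization_scheme` to each matched submodule and applies the configured status per module:

```python
# Hypothetical usage sketch; the config values below are illustrative
# placeholders and assume compressed-tensors' pydantic QuantizationConfig.
import torch
from compressed_tensors.quantization import (
    QuantizationConfig,
    apply_quantization_config,
)

model = torch.nn.Sequential(torch.nn.Linear(64, 64))  # toy stand-in for a model

config = QuantizationConfig(
    config_groups={
        "group_0": {
            "targets": ["Linear"],  # match submodules by class name
            "weights": {"num_bits": 8, "type": "int", "symmetric": True},
        }
    },
    format="dense",  # dense format skips the CompressedLinear swap
    quantization_status="initialized",
)

# run_compressed=False leaves nn.Linear modules in place; only schemes attach
names_to_scheme = apply_quantization_config(model, config, run_compressed=False)
print(names_to_scheme)  # e.g. {"0": QuantizationScheme(...)}
```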