@@ -13,12 +13,11 @@
 # limitations under the License.

 import logging
-import re
-from collections import OrderedDict, defaultdict
+from collections import OrderedDict
 from copy import deepcopy
 from typing import Dict, Iterable, List, Optional
 from typing import OrderedDict as OrderedDictType
-from typing import Set, Union
+from typing import Union

 import torch
 from compressed_tensors.config import CompressionFormat
@@ -39,7 +38,12 @@
     infer_quantization_status,
     is_kv_cache_quant_scheme,
 )
-from compressed_tensors.utils.helpers import fix_fsdp_module_name, replace_module
+from compressed_tensors.utils.helpers import (
+    fix_fsdp_module_name,
+    deprecated,
+    replace_module,
+)
+from compressed_tensors.utils.match import match_named_modules, match_targets
 from compressed_tensors.utils.offload import update_parameter_data
 from compressed_tensors.utils.safetensors_load import get_safetensors_folder
 from safetensors import safe_open
@@ -51,8 +55,6 @@
     "apply_quantization_config",
     "apply_quantization_status",
     "find_name_or_class_matches",
-    "expand_target_names",
-    "is_target",
 ]

 from compressed_tensors.quantization.utils.helpers import is_module_quantized
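
Reviewer note: the import changes above pull in the shared `match_named_modules` helper that the rewritten loop below relies on. A minimal sketch of how that helper is expected to behave, assuming it yields `(name, module)` pairs for modules that hit a target and miss the ignore list; only the `warn_on_fail=True` keyword is taken from this diff, while the toy model and the `"Linear"` / `"2"` target and ignore strings are illustrative, not from this PR:

```python
# Sketch only: exercises match_named_modules the way the rewritten loop does.
# The toy model and the "Linear" / "2" target/ignore strings are hypothetical.
import torch
from compressed_tensors.utils.match import match_named_modules

model = torch.nn.Sequential(
    torch.nn.Linear(8, 8),  # named "0"
    torch.nn.ReLU(),        # named "1"
    torch.nn.Linear(8, 4),  # named "2"
)

# Expect the Linear named "0" to match; "1" is not a Linear and "2" is ignored.
for name, submodule in match_named_modules(model, ["Linear"], ["2"], warn_on_fail=True):
    print(name, type(submodule).__name__)
```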
@@ -144,31 +146,24 @@ def apply_quantization_config(
         for target in scheme.targets:
             target_to_scheme[target] = scheme

-    if run_compressed:
-        from compressed_tensors.linear.compressed_linear import CompressedLinear
-
-    # list of submodules to ignore
-    ignored_submodules = defaultdict(list)
-    # mark appropriate layers for quantization by setting their quantization schemes
-    for name, submodule in model.named_modules():
-        # potentially fix module name to remove FSDP wrapper prefix
-        name = fix_fsdp_module_name(name)
-        if matches := find_name_or_class_matches(name, submodule, config.ignore):
-            for match in matches:
-                ignored_submodules[match].append(name)
-            continue  # layer matches ignore list, continue
-
-        targets = find_name_or_class_matches(name, submodule, target_to_scheme)
+        # mark appropriate layers for quantization by setting their quantization schemes
+        for name, submodule in match_named_modules(
+            model, scheme.targets, config.ignore, warn_on_fail=True
+        ):
+            # potentially fix module name to remove FSDP wrapper prefix
+            name = fix_fsdp_module_name(name)

-        if targets:
             # mark modules to be quantized by adding
             # quant scheme to the matching layers
-            scheme = _scheme_from_targets(target_to_scheme, targets, name)
+            scheme = _scheme_from_targets(target_to_scheme, scheme.targets, name)
             if run_compressed:
                 format = config.format
                 if format != CompressionFormat.dense.value:
                     if isinstance(submodule, torch.nn.Linear):
-                        # TODO: expand to more module types
+                        from compressed_tensors.linear.compressed_linear import (
+                            CompressedLinear,
+                        )
+
                         compressed_linear = CompressedLinear.from_linear(
                             submodule,
                             quantization_scheme=scheme,
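
For context on the `run_compressed` branch in the hunk above: a layer is only swapped for `CompressedLinear` when the config's serialized format is not dense and the module is a plain `torch.nn.Linear`. A small sketch of that gate, where `should_swap_to_compressed` and the `"pack-quantized"` format string are hypothetical stand-ins used only for illustration:

```python
# Sketch of the gating logic only; the actual swap is done by
# CompressedLinear.from_linear(...) and replace_module(...) as in the hunk above.
import torch
from compressed_tensors.config import CompressionFormat


def should_swap_to_compressed(submodule: torch.nn.Module, fmt: str) -> bool:
    # hypothetical helper mirroring the isinstance / format checks above
    return fmt != CompressionFormat.dense.value and isinstance(submodule, torch.nn.Linear)


print(should_swap_to_compressed(torch.nn.Linear(4, 4), "dense"))           # False
print(should_swap_to_compressed(torch.nn.Linear(4, 4), "pack-quantized"))  # True
print(should_swap_to_compressed(torch.nn.ReLU(), "pack-quantized"))        # False
```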
@@ -181,16 +176,18 @@ def apply_quantization_config(

             names_to_scheme[name] = submodule.quantization_scheme

-    if config.ignore is not None and ignored_submodules is not None:
-        if set(config.ignore) - set(ignored_submodules):
-            _LOGGER.warning(
-                "Some layers that were to be ignored were "
-                "not found in the model: "
-                f"{set(config.ignore) - set(ignored_submodules)}"
-            )
+            # apply current quantization status to each targeted submodule
+            apply_quantization_status(submodule, config.quantization_status)
+
+    # TODO warn on ignore not being found, this is useful in debugging
+    # if config.ignore is not None and ignored_submodules is not None:
+    #     if set(config.ignore) - set(ignored_submodules):
+    #         _LOGGER.warning(
+    #             "Some layers that were to be ignored were "
+    #             "not found in the model: "
+    #             f"{set(config.ignore) - set(ignored_submodules)}"
+    #         )

-    # apply current quantization status across all targeted layers
-    apply_quantization_status(model, config.quantization_status)
     return names_to_scheme
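
End-to-end, the rewritten function still walks each config group, attaches a `quantization_scheme` to every matched submodule, applies the config's quantization status per submodule, and returns the name-to-scheme mapping. A rough usage sketch, assuming the library's usual `QuantizationConfig` / `QuantizationScheme` / `QuantizationArgs` constructors; the tiny model, the 8-bit weight scheme, and the ignore entry are illustrative only:

```python
# Illustrative only: constructor keywords are assumed from the library's config
# classes, not from this diff.
import torch
from compressed_tensors.quantization import (
    QuantizationArgs,
    QuantizationConfig,
    QuantizationScheme,
    apply_quantization_config,
)

model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.Linear(16, 4))

config = QuantizationConfig(
    config_groups={
        "group_0": QuantizationScheme(
            targets=["Linear"],  # resolved through match_named_modules
            weights=QuantizationArgs(num_bits=8, symmetric=True),
        )
    },
    ignore=["1"],  # hypothetical: skip the second Linear by module name
)

# run_compressed=False keeps dense Linear modules; each matched submodule gets a
# quantization_scheme and the config's quantization status applied to it.
names_to_scheme = apply_quantization_config(model, config, run_compressed=False)
print(names_to_scheme)
```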