Commit c27f8e4

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 0fd074b commit c27f8e4

File tree: 3 files changed, +164 −81 lines changed

auto_round/autoround.py (126 additions, 63 deletions)
@@ -2279,11 +2279,11 @@ def get_act_max_hook(module, input, output):
         # # Apply the configuration to the corresponding layer in the model
         # for key in keys:
         #     setattr(m, key, low_config[key])
-
+
         # return layer_config
 
     @torch.inference_mode()
-    def check_needs_auto_gguf_mix_mse(self, block, formats, input_ids, input_others, outputs, device,cache_device):
+    def check_needs_auto_gguf_mix_mse(self, block, formats, input_ids, input_others, outputs, device, cache_device):
         ## TODO Q4_K_M does not support iters==0
         ## TODO for moe model, expert use default bits
         mse_reduction = "mean"
@@ -2293,35 +2293,35 @@ def check_needs_auto_gguf_mix_mse(self, block, formats, input_ids, input_others,
 
         target_gguf_format = None
         for format in formats:
-            if format.startswith("gguf") and 'm' in format:
+            if format.startswith("gguf") and "m" in format:
                 target_gguf_format = format
         if target_gguf_format is None:
             return
 
         ## simple verification, if the layer_config has any mixed-bits setting, we don't apply auto mix precision
-        bits = []
-        count=0
+        bits = []
+        count = 0
         quant_bits = {}
-        for n, m in block.named_modules():#[4 4 6 4 4 6 8]
+        for n, m in block.named_modules():  # [4 4 6 4 4 6 8]
             if hasattr(m, "bits"):
                 bits.append(m.bits)
-                quant_bits[m.bits]=0
+                quant_bits[m.bits] = 0
         ori_bit = min(bits)
         for b in bits:
             if b != ori_bit:
-                quant_bits[b]+=1
-        bits = set(bits) #{4,6}
+                quant_bits[b] += 1
+        bits = set(bits)  # {4,6}
         if len(bits) <= 1:
             return
         del quant_bits[min(bits)]
-
+
         layer_names = []
-
+
         for n, m in block.named_modules():
             if check_to_quantized(m):
                 layer_names.append(n)
-                count+=1
-
+                count += 1
+
         if count > 10:
             logger.info("skipping selection")
             return
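The bookkeeping in this hunk is easiest to see with numbers. Below is a small standalone trace of the counting logic under the bit pattern the inline comment suggests, `[4 4 6 4 4 6 8]`; the free-standing names are hypothetical stand-ins for the per-module state the method reads:

```python
bits_seen = [4, 4, 6, 4, 4, 6, 8]  # per-layer bit widths, as in the inline comment

quant_bits = {b: 0 for b in bits_seen}
ori_bit = min(bits_seen)  # 4, the base precision of the block
for b in bits_seen:
    if b != ori_bit:
        quant_bits[b] += 1  # count layers sitting above the base bit width

bits = set(bits_seen)  # {4, 6, 8}
del quant_bits[min(bits)]  # drop the base entry, keep only the upgraded widths
print(quant_bits)  # {6: 2, 8: 1}: two 6-bit layers and one 8-bit layer
```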
@@ -2334,17 +2334,17 @@ def check_needs_auto_gguf_mix_mse(self, block, formats, input_ids, input_others,
         # current_output = to_device(current_output, device)
         current_input_ids = [input_ids[i] for i in whole_indices]
         default_config = GGUF_CONFIG[target_gguf_format]
-        split_list = re.split(':|_',target_gguf_format)
+        split_list = re.split(":|_", target_gguf_format)
         mix_configs = {}
-
-        for k,_ in quant_bits.items():
+
+        for k, _ in quant_bits.items():
             mix_configs[k] = GGUF_CONFIG[f"gguf:q{k}_{split_list[2]}"]
-
+
         d_format = [f"gguf:q{min(bits)}_{split_list[2]}"]
         low_config = GGUF_CONFIG[f"gguf:q{min(bits)}_{split_list[2]}"]
 
         default_layer_config = low_config
-
+
         # for k in self.layer_config.keys():
         #     s = re.split('\.',k)
         #     if len(s) <2:
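For readers tracing the format strings: a quick worked example of the `re.split` call above, using a format name this commit already handles. The snippet is illustrative only:

```python
import re

split_list = re.split(":|_", "gguf:q4_k_m")
print(split_list)  # ['gguf', 'q4', 'k', 'm'], so split_list[2] == 'k'

# the derived per-bit key, e.g. for a candidate bit width k = 2:
print(f"gguf:q{2}_{split_list[2]}")  # 'gguf:q2_k'
```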
@@ -2355,73 +2355,135 @@ def check_needs_auto_gguf_mix_mse(self, block, formats, input_ids, input_others,
 
         if len(bits) == 2:
             logger.info("quantizing with a single mixed bit width")
-            self.choose_one_bit(block,mix_configs,quant_bits,default_config,default_layer_config,layer_names,current_input_ids,input_others,current_output,mse_loss,device,cache_device)
+            self.choose_one_bit(
+                block,
+                mix_configs,
+                quant_bits,
+                default_config,
+                default_layer_config,
+                layer_names,
+                current_input_ids,
+                input_others,
+                current_output,
+                mse_loss,
+                device,
+                cache_device,
+            )
         else:
             logger.info("quantizing with multiple mixed bit widths")
-            self.choose_various_bit(block,mix_configs,quant_bits,default_config,default_layer_config,layer_names,current_input_ids,input_others,current_output,mse_loss,device,cache_device)
-
+            self.choose_various_bit(
+                block,
+                mix_configs,
+                quant_bits,
+                default_config,
+                default_layer_config,
+                layer_names,
+                current_input_ids,
+                input_others,
+                current_output,
+                mse_loss,
+                device,
+                cache_device,
+            )
 
-    def choose_one_bit(self,block,mix_configs,quant_bits,default_config,default_layer_config,layer_names,current_input_ids,input_others,current_output,mse_loss,device,cache_device):
+    def choose_one_bit(
+        self,
+        block,
+        mix_configs,
+        quant_bits,
+        default_config,
+        default_layer_config,
+        layer_names,
+        current_input_ids,
+        input_others,
+        current_output,
+        mse_loss,
+        device,
+        cache_device,
+    ):
         each_loss = {}
         # bit = mix_configs.keys()[0]
-        [(_,cur_config)] = mix_configs.items()
-        [(_,num_bit)] = quant_bits.items()
+        [(_, cur_config)] = mix_configs.items()
+        [(_, num_bit)] = quant_bits.items()
         for layer_name in layer_names:
             module = get_module(block, layer_name)
             self.layer_config[module.tmp_name] = default_config
-            for key in cur_config:
-                setattr(module,key,cur_config[key])
-
-            wrapper_layer = WrapperLinear(module,enable_minmax_tuning=False,enable_round_tuning=False,enable_norm_bias_tuning=False,device=device)
+            for key in cur_config:
+                setattr(module, key, cur_config[key])
+
+            wrapper_layer = WrapperLinear(
+                module,
+                enable_minmax_tuning=False,
+                enable_round_tuning=False,
+                enable_norm_bias_tuning=False,
+                device=device,
+            )
             set_module(block, layer_name, wrapper_layer)
-            q_output = self.get_block_outputs(block, current_input_ids, input_others, self.batch_size * self.infer_bs_coeff,
-                                              device,
-                                              cache_device)
-
-            set_module(block,layer_name,wrapper_layer.orig_layer)
+            q_output = self.get_block_outputs(
+                block, current_input_ids, input_others, self.batch_size * self.infer_bs_coeff, device, cache_device
+            )
+
+            set_module(block, layer_name, wrapper_layer.orig_layer)
             module = get_module(block, layer_name)
-            for key in default_config:
-                setattr(module,key,default_config[key])
-            cur_loss=mse_loss(torch.stack(q_output).squeeze(1),current_output)
-            each_loss[layer_name] = cur_loss #record the loss of each layer
-
+            for key in default_config:
+                setattr(module, key, default_config[key])
+            cur_loss = mse_loss(torch.stack(q_output).squeeze(1), current_output)
+            each_loss[layer_name] = cur_loss  # record the loss of each layer
+
         top_n_loss = sorted(each_loss.items(), key=lambda x: x[1], reverse=False)[:num_bit]
         # breakpoint()
         # tmp_list.append(max_loss[1])
         flag = {}
-        for layer_name,_ in top_n_loss:
+        for layer_name, _ in top_n_loss:
             module = get_module(block, layer_name)
-            for key in cur_config:
-                setattr(module,key,cur_config[key])
-
+            for key in cur_config:
+                setattr(module, key, cur_config[key])
+
             self.layer_config[module.tmp_name] = cur_config
-            # continue
+            # continue
 
-
-
-    def choose_various_bit(self,block,mix_configs,quant_bits,cur_config,default_config,default_layer_config,layer_names,current_input_ids,input_others,current_output,mse_loss,device,cache_device):
+    def choose_various_bit(
+        self,
+        block,
+        mix_configs,
+        quant_bits,
+        cur_config,
+        default_config,
+        default_layer_config,
+        layer_names,
+        current_input_ids,
+        input_others,
+        current_output,
+        mse_loss,
+        device,
+        cache_device,
+    ):
         each_loss = {}
         for layer_name in layer_names:
             module = get_module(block, layer_name)
-            for key in default_config:
-                setattr(module,key,cur_config[key])
-
-            wrapper_layer = WrapperLinear(module,enable_minmax_tuning=False,enable_round_tuning=False,enable_norm_bias_tuning=False,device=device)
+            for key in default_config:
+                setattr(module, key, cur_config[key])
+
+            wrapper_layer = WrapperLinear(
+                module,
+                enable_minmax_tuning=False,
+                enable_round_tuning=False,
+                enable_norm_bias_tuning=False,
+                device=device,
+            )
             set_module(block, layer_name, wrapper_layer)
-            q_output = self.get_block_outputs(block, current_input_ids, input_others, self.batch_size * self.infer_bs_coeff,
-                                              device,
-                                              cache_device)
-            set_module(block,layer_name,wrapper_layer.orig_layer)
-
-            cur_loss=mse_loss(torch.stack(q_output).squeeze(1),current_output)
-            each_loss[layer_name] = cur_loss #record the loss of each layer
-
-        top_n_loss = sorted(each_loss.items(), key=lambda x: x[1], reverse=True)[:sum(quant_bits.values())]
+            q_output = self.get_block_outputs(
+                block, current_input_ids, input_others, self.batch_size * self.infer_bs_coeff, device, cache_device
+            )
+            set_module(block, layer_name, wrapper_layer.orig_layer)
+
+            cur_loss = mse_loss(torch.stack(q_output).squeeze(1), current_output)
+            each_loss[layer_name] = cur_loss  # record the loss of each layer
+
+        top_n_loss = sorted(each_loss.items(), key=lambda x: x[1], reverse=True)[: sum(quant_bits.values())]
         shift = 0
-        for k,_ in top_n_loss.items():
+        for k, _ in top_n_loss.items():
             self.layer_config[module.tmp_name] = cur_config
-
-
 
     def quant_block(self, block, input_ids, input_others, q_input=None, device=torch.device("cpu")):
         """Quantize the weights of a given block of the model.
@@ -2476,7 +2538,8 @@ def quant_block(self, block, input_ids, input_others, q_input=None, device=torch
             handle.remove()
 
         self.check_needs_auto_gguf_mix_mse(
-            block, self.formats, input_ids, input_others, output, device, self.cache_device)
+            block, self.formats, input_ids, input_others, output, device, self.cache_device
+        )
 
         if q_input is not None:
             if input_ids is not q_input:
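Taken together, the two helpers reformatted above follow one pattern: quantize a single layer at a time, re-run the block on calibration inputs, and rank layers by how much output MSE that layer's quantization costs. The sketch below is a minimal, self-contained rendering of that pattern; the toy `round_to_nbits` quantizer and the `rank_layers_by_mse` name are invented stand-ins for auto-round's `WrapperLinear` / `get_block_outputs` machinery, not the project's API:

```python
import torch
import torch.nn as nn


def round_to_nbits(w: torch.Tensor, bits: int) -> torch.Tensor:
    """Toy symmetric fake-quantizer standing in for the real GGUF kernels."""
    qmax = 2 ** (bits - 1) - 1
    scale = w.abs().max() / qmax
    return torch.round(w / scale).clamp(-qmax - 1, qmax) * scale


def rank_layers_by_mse(block: nn.Module, layer_names, calib_input, bits: int):
    """Quantize one layer at a time and rank layers by block-output MSE."""
    with torch.inference_mode():
        reference = block(calib_input)  # float output is the target
    mse = nn.MSELoss(reduction="mean")
    losses, modules = {}, dict(block.named_modules())
    for name in layer_names:
        layer = modules[name]
        saved = layer.weight.data.clone()
        layer.weight.data = round_to_nbits(saved, bits)  # quantize only this layer
        with torch.inference_mode():
            losses[name] = mse(block(calib_input), reference).item()
        layer.weight.data = saved  # restore before probing the next layer
    return sorted(losses.items(), key=lambda x: x[1])  # least-damaged layers first


block = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 16))
names = [n for n, m in block.named_modules() if isinstance(m, nn.Linear)]
print(rank_layers_by_mse(block, names, torch.randn(4, 16), bits=4))
```

In the commit's `choose_one_bit`, the layers with the smallest losses under the candidate config are the ones that receive it; the sketch only illustrates the probe-and-rank loop, not that selection policy.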

auto_round/export/export_to_gguf/config.py (5 additions, 5 deletions)
@@ -158,31 +158,31 @@ class ModelType(IntEnum):
 GGUF_CONFIG["gguf:q4_0"] = GGUF_INNER_CONFIG["gguf:q4_0"]
 GGUF_CONFIG["gguf:q4_0"]["mostly"] = "gguf:q4_0"
 GGUF_CONFIG["gguf:q4_1"] = GGUF_INNER_CONFIG["gguf:q4_1"]
-GGUF_CONFIG["gguf:q4_1"]["mostly"]= "gguf:q4_1"
+GGUF_CONFIG["gguf:q4_1"]["mostly"] = "gguf:q4_1"
 GGUF_CONFIG["gguf:q4_k"] = GGUF_INNER_CONFIG["gguf:q4_k"]
 GGUF_CONFIG["gguf:q5_0"] = GGUF_INNER_CONFIG["gguf:q5_0"]
 GGUF_CONFIG["gguf:q5_0"]["mostly"] = "gguf:q5_0"
 GGUF_CONFIG["gguf:q5_1"] = GGUF_INNER_CONFIG["gguf:q5_1"]
 GGUF_CONFIG["gguf:q5_1"]["mostly"] = "gguf:q5_1"
 GGUF_CONFIG["gguf:q5_k"] = GGUF_INNER_CONFIG["gguf:q5_k"]
 GGUF_CONFIG["gguf:q2_k_s"] = GGUF_INNER_CONFIG["gguf:q2_k"]
-GGUF_CONFIG["gguf:q2_k_s"]["mostly"]= "gguf:q2_k"
+GGUF_CONFIG["gguf:q2_k_s"]["mostly"] = "gguf:q2_k"
 GGUF_CONFIG["gguf:q3_k"] = GGUF_INNER_CONFIG["gguf:q3_k"]
 GGUF_CONFIG["gguf:q3_k"]["mostly"] = "gguf:q3_k"
 GGUF_CONFIG["gguf:q3_k_s"] = GGUF_INNER_CONFIG["gguf:q3_k"]
 GGUF_CONFIG["gguf:q3_k_s"]["mostly"] = "gguf:q3_k"
 GGUF_CONFIG["gguf:q3_k_m"] = GGUF_INNER_CONFIG["gguf:q3_k"]
 GGUF_CONFIG["gguf:q3_k_m"]["mostly"] = "gguf:q3_k"
 GGUF_CONFIG["gguf:q3_k_l"] = GGUF_INNER_CONFIG["gguf:q3_k"]
-GGUF_CONFIG["gguf:q3_k_l"]["mostly"]= "gguf:q3_k"
+GGUF_CONFIG["gguf:q3_k_l"]["mostly"] = "gguf:q3_k"
 GGUF_CONFIG["gguf:q4_k"] = GGUF_INNER_CONFIG["gguf:q4_k"]
-GGUF_CONFIG["gguf:q4_k"]["mostly"]= "gguf:q4_k"
+GGUF_CONFIG["gguf:q4_k"]["mostly"] = "gguf:q4_k"
 GGUF_CONFIG["gguf:q4_k_s"] = GGUF_INNER_CONFIG["gguf:q4_k"]
 GGUF_CONFIG["gguf:q4_k_s"]["mostly"] = "gguf:q4_k"
 GGUF_CONFIG["gguf:q4_k_m"] = GGUF_INNER_CONFIG["gguf:q4_k"]
 GGUF_CONFIG["gguf:q4_k_m"]["mostly"] = "gguf:q4_k"
 GGUF_CONFIG["gguf:q5_k"] = GGUF_INNER_CONFIG["gguf:q5_k"]
-GGUF_CONFIG["gguf:q5_k"]["mostly"]= "gguf:q5_k"
+GGUF_CONFIG["gguf:q5_k"]["mostly"] = "gguf:q5_k"
 GGUF_CONFIG["gguf:q5_k_s"] = GGUF_INNER_CONFIG["gguf:q5_k"]
 GGUF_CONFIG["gguf:q5_k_s"]["mostly"] = "gguf:q5_k"
 GGUF_CONFIG["gguf:q5_k_m"] = GGUF_INNER_CONFIG["gguf:q5_k"]
