Skip to content

Commit 8ba8a73

Browse files
add dp to auto choose the bit while keepping size not too large
1 parent e8c4259 commit 8ba8a73

File tree

1 file changed

+228
-31
lines changed

1 file changed

+228
-31
lines changed

auto_round/autoround.py

Lines changed: 228 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -2231,6 +2231,22 @@ def auto_mix_rtn(self, model: torch.nn.Module, inputs, block_names, q_input=None
22312231

22322232
if pbar is None:
22332233
pbar = tqdm(range(0, len(block_names), nblocks))
2234+
2235+
#convert all block to default quant block
2236+
for i in range(0, len(block_names), nblocks):
2237+
for format in self.formats:
2238+
if format.startswith("gguf"):
2239+
target_gguf_format = format
2240+
if target_gguf_format is None:
2241+
logger.info("target_gguf_format is None, return")
2242+
return
2243+
2244+
cur_m_n = block_names[i]
2245+
cur_m = get_module(model, cur_m_n)
2246+
cur_m = cur_m.to(device)
2247+
quantized_layer_names, unquantized_layer_names = wrapper_block(
2248+
cur_m, self.enable_minmax_tuning, self.enable_norm_bias_tuning, device=self.device
2249+
)
22342250

22352251
for i in range(0, len(block_names), nblocks):
22362252
if i != 0:
@@ -2256,9 +2272,28 @@ def auto_mix_rtn(self, model: torch.nn.Module, inputs, block_names, q_input=None
22562272
device=device,
22572273
)
22582274

2275+
for i in range(0, len(block_names), nblocks):
2276+
for format in self.formats:
2277+
if format.startswith("gguf"):
2278+
target_gguf_format = format
2279+
if target_gguf_format is None:
2280+
logger.info("target_gguf_format is None, return")
2281+
return
2282+
layer_names = []
2283+
cur_m_n = block_names[i]
2284+
cur_m = get_module(model, cur_m_n)
2285+
cur_m = cur_m.to(device)
2286+
for n,m in cur_m.named_modules():
2287+
if hasattr(m,"orig_layer"):
2288+
set_module(m,n,m.orig_layer)
2289+
logger.info(f"n:{n}")
2290+
2291+
2292+
2293+
22592294

22602295
@torch.inference_mode()
2261-
def check_needs_auto_gguf_mix_mse(self, block, formats, input_ids, input_others, outputs, device, cache_device, mode="percent"):
2296+
def check_needs_auto_gguf_mix_mse(self, block, formats, input_ids, input_others, outputs, device, cache_device, mode="percent", block_type="wrap"):
22622297
## TODO Q4_K_M does not support iters==0
22632298
## TODO for moe model, expert use default bits
22642299
s = block.tmp_name
@@ -2284,18 +2319,27 @@ def check_needs_auto_gguf_mix_mse(self, block, formats, input_ids, input_others,
22842319
layer_names = []
22852320

22862321
for n, m in block.named_modules():
2287-
if check_to_quantized(m):
2322+
if check_to_quantized(m) and not n.endswith("orig_layer"):
22882323
layer_names.append(n)
22892324
count += 1
22902325
if hasattr(m, "bits"):
22912326
bits.append(m.bits)
22922327
quant_bits[m.bits] = 0
2293-
2328+
if block_type == "wrap":
2329+
if hasattr(m.orig_layer, "bits"):
2330+
bits.append(m.orig_layer.bits)
2331+
quant_bits[m.orig_layer.bits] = 0
2332+
22942333
ori_bit = min(bits)
2295-
22962334
for b in bits:
22972335
if b != ori_bit:
22982336
quant_bits[b] += 1
2337+
2338+
# if block_type == "wrap":
2339+
# count = count/2
2340+
# for k in quant_bits.keys():
2341+
# quant_bits[k] /= 2
2342+
22992343
bits = set(bits) # {4,6}
23002344
if len(bits) <= 1:
23012345
logger.info(f"len<=1,bits为:{bits}不进行选择")
@@ -2306,6 +2350,7 @@ def check_needs_auto_gguf_mix_mse(self, block, formats, input_ids, input_others,
23062350
return
23072351

23082352
nsamples = min(32, len(outputs))
2353+
# nsamples = 128
23092354
whole_indices = torch.randperm(len(outputs))[:nsamples]
23102355
##we assume the block input and output shape are same
23112356
current_output = [outputs[x] for x in whole_indices]
@@ -2323,25 +2368,25 @@ def check_needs_auto_gguf_mix_mse(self, block, formats, input_ids, input_others,
23232368
low_config = GGUF_CONFIG[f"gguf:q{min(bits)}_{split_list[2]}"]
23242369

23252370
default_layer_config = low_config
2326-
start_num = 10
2327-
end_num = 20
2328-
if(start_num<int(s_l[-1])<end_num):
2329-
logger.info(f"不混合{s_l[-1]}层")
2330-
for layer_name in layer_names:
2331-
module = get_module(block, layer_name)
2332-
self.layer_config[module.tmp_name] = default_layer_config
2333-
# logger.info(tmp_ori_layer[module.tmp_name])
2334-
for key in default_layer_config:
2335-
setattr(module, key, default_layer_config[key])
2336-
return
2337-
2338-
if(int(s_l[-1])>=end_num):
2339-
for kx in quant_bits.keys():
2340-
quant_bits[kx]+=1
2371+
# start_num = 2
2372+
# end_num = 29
2373+
# if(start_num<int(s_l[-1])<end_num):
2374+
# logger.info(f"不混合{s_l[-1]}层")
2375+
# for layer_name in layer_names:
2376+
# module = get_module(block, layer_name)
2377+
# self.layer_config[module.tmp_name] = default_layer_config
2378+
# # logger.info(tmp_ori_layer[module.tmp_name])
2379+
# for key in default_layer_config:
2380+
# setattr(module, key, default_layer_config[key])
2381+
# return
2382+
# if(int(s_l[-1])>=end_num or int(s_l[-1])<=start_num):
2383+
# lim_size = False
2384+
# for kx in quant_bits.keys():
2385+
# quant_bits[kx]+=1
23412386

23422387
if len(bits) == 2:
23432388
logger.info(f"量化单bit:{quant_bits},模式为:{mode}")
2344-
self.choose_one_bit(
2389+
self.choose_one_bit_dp(
23452390
block,
23462391
mix_configs,
23472392
quant_bits,
@@ -2354,7 +2399,8 @@ def check_needs_auto_gguf_mix_mse(self, block, formats, input_ids, input_others,
23542399
mse_loss,
23552400
device,
23562401
cache_device,
2357-
mode = mode,
2402+
mode=mode,
2403+
block_type=block_type,
23582404
)
23592405
else:
23602406
logger.info(f"量化多bit:{quant_bits},模式为:{mode}")
@@ -2388,7 +2434,8 @@ def choose_one_bit(
23882434
mse_loss,
23892435
device,
23902436
cache_device,
2391-
mode="max",
2437+
mode=None,
2438+
block_type=None,
23922439
):
23932440
each_loss = {}
23942441
tmp_ori_layer = {}
@@ -2397,15 +2444,18 @@ def choose_one_bit(
23972444
[(_, num_bit)] = quant_bits.items()
23982445
for layer_name in layer_names:
23992446
module = get_module(block, layer_name)
2447+
if block_type == "wrap" and hasattr(module,"orig_layer"):
2448+
set_module(block,layer_name,module.orig_layer)
2449+
module = get_module(block, layer_name)
24002450
tmp_ori_layer[module.tmp_name] = self.layer_config[module.tmp_name]
24012451
self.layer_config[module.tmp_name] = default_config
24022452
if mode in {"min", "percent", "marginal income ratio"}:
24032453
for key in cur_config:
24042454
setattr(module, key, cur_config[key])
24052455
elif mode == "sensitive":
24062456
for key in default_config:
2407-
setattr(module, key, default_config[key])
2408-
2457+
setattr(module, key, default_config[key]) #敏感度越大,loss掉的越多,
2458+
24092459
wrapper_layer = WrapperLinear(
24102460
module,
24112461
enable_minmax_tuning=False,
@@ -2418,7 +2468,7 @@ def choose_one_bit(
24182468
block, current_input_ids, input_others, self.batch_size * self.infer_bs_coeff, device, cache_device
24192469
)
24202470

2421-
if mode in {"percent", "marginal income ratio", "min"}:
2471+
if mode in {"percent", "marginal income ratio"}:
24222472
for key in default_config:
24232473
setattr(module, key, default_config[key])
24242474
wrapper_layer = WrapperLinear(
@@ -2438,24 +2488,26 @@ def choose_one_bit(
24382488
for key in default_config:
24392489
setattr(module, key, default_config[key])
24402490

2441-
if mode == "min" or mode == "sensitive":
2491+
if mode == "min":
24422492
cur_loss = mse_loss(torch.stack(q_output).squeeze(1), current_output)
24432493
# loss_high = mse_loss(torch.stack(q_output).squeeze(1), current_output)
24442494
# loss_low = mse_loss(torch.stack(q2_output).squeeze(1), current_output)
24452495
# cur_loss = (loss_low - loss_high)/loss_low #改善率越高,表现值loss越小,值为负且越小
2496+
elif mode == "sensitive":
2497+
cur_loss = -mse_loss(torch.stack(q_output).squeeze(1), current_output) #敏感度越大,loss掉的越多,所以要选择-loss小的
24462498
elif mode == "percent":
24472499
loss_high = mse_loss(torch.stack(q_output).squeeze(1), current_output)
24482500
loss_low = mse_loss(torch.stack(q2_output).squeeze(1), current_output)
24492501
cur_loss = (loss_high - loss_low)/loss_low #改善率越高,值为负且越小
2450-
logger.info(f"low:{loss_low}")
2451-
logger.info(f"high:{loss_high}")
2502+
# logger.info(f"low:{loss_low}")
2503+
# logger.info(f"high:{loss_high}")
24522504
elif mode == "marginal income ratio":
24532505
loss_high = mse_loss(torch.stack(q_output).squeeze(1), current_output)
24542506
loss_low = mse_loss(torch.stack(q2_output).squeeze(1), current_output)
24552507
income = (loss_high - loss_low)/loss_low #改善率越高,值为负且越小
24562508
marginal = income/((cur_config["bits"]-default_config["bits"])*sum(a.numel() for a in module.parameters())) #边际收益越高值为负且越小
24572509
cur_loss = marginal
2458-
each_loss[layer_name] = cur_loss # 把每一层的loss记录下来
2510+
each_loss[layer_name] = (cur_loss, module.numel*cur_config["bits"]) # 把每一层的loss以及size记录下来
24592511

24602512
top_n_loss = sorted(each_loss.items(), key=lambda x: x[1], reverse=False)[:num_bit] #reverse=False升序
24612513

@@ -2477,8 +2529,8 @@ def choose_one_bit(
24772529
for key in cur_config:
24782530
setattr(module, key, self.layer_config[module.tmp_name][key])
24792531
break
2480-
elif mode == "min":
2481-
logger.info(f"当前计算结果为{loss_item}")
2532+
elif mode == "min" or "sensitive":
2533+
logger.info(f"层:{layer_name},loss{loss_item}")
24822534

24832535
if loss_item > 0 and mode == "percent":
24842536
logger.info(f"loss = {loss_item} > 0, it seems become worse,so we skip it")
@@ -2490,6 +2542,151 @@ def choose_one_bit(
24902542
self.layer_config[module.tmp_name] = cur_config
24912543
# continue
24922544

2545+
def choose_one_bit_dp(
2546+
self,
2547+
block,
2548+
mix_configs,
2549+
quant_bits,
2550+
default_config,
2551+
default_layer_config,
2552+
layer_names,
2553+
current_input_ids,
2554+
input_others,
2555+
current_output,
2556+
mse_loss,
2557+
device,
2558+
cache_device,
2559+
mode=None,
2560+
block_type=None,
2561+
):
2562+
each_loss = {}
2563+
tmp_ori_layer = {}
2564+
# bit = mix_configs.keys()[0]
2565+
[(_, cur_config)] = mix_configs.items()
2566+
for layer_name in layer_names:
2567+
module = get_module(block, layer_name)
2568+
2569+
layer_size = 0
2570+
for layer_name in layer_names:
2571+
module = get_module(block, layer_name)
2572+
layer_size += (module.orig_layer.bits-default_config["bits"])*sum(p.numel() for p in module.parameters())
2573+
2574+
layer_size = int(layer_size*1.05/1e6)
2575+
logger.info(f"layer_size为:{layer_size}")
2576+
for layer_name in layer_names:
2577+
module = get_module(block, layer_name)
2578+
if block_type == "wrap" and hasattr(module,"orig_layer"):
2579+
set_module(block,layer_name,module.orig_layer)
2580+
module = get_module(block, layer_name)
2581+
# tmp_ori_layer[module.tmp_name] = self.layer_config[module.tmp_name]
2582+
# if block_type in {"wrap"}:
2583+
# for key in cur_config:
2584+
# setattr(module.orig_layer, key, cur_config[key])
2585+
elif block_type == "default":
2586+
for key in default_config:
2587+
setattr(module.orig_layer, key, default_config[key])
2588+
2589+
self.layer_config[module.tmp_name] = default_config
2590+
2591+
wrapper_layer = WrapperLinear(
2592+
module,
2593+
enable_minmax_tuning=False,
2594+
enable_round_tuning=False,
2595+
enable_norm_bias_tuning=False,
2596+
device=device,
2597+
)
2598+
set_module(block, layer_name, wrapper_layer)
2599+
q_output = self.get_block_outputs(
2600+
block, current_input_ids, input_others, self.batch_size * self.infer_bs_coeff, device, cache_device
2601+
)
2602+
2603+
if mode in {"percent", "marginal income ratio"}:
2604+
for key in default_config:
2605+
setattr(module, key, default_config[key])
2606+
wrapper_layer = WrapperLinear(
2607+
module,
2608+
enable_minmax_tuning=False,
2609+
enable_round_tuning=False,
2610+
enable_norm_bias_tuning=False,
2611+
device=device,
2612+
)
2613+
set_module(block, layer_name, wrapper_layer)
2614+
q2_output = self.get_block_outputs(
2615+
block, current_input_ids, input_others, self.batch_size * self.infer_bs_coeff, device, cache_device
2616+
)
2617+
2618+
#还原为orig_layer TODO:这里其实有个问题,之前是GGUF的存在高bit而这样操作过后全是低bit了
2619+
set_module(block, layer_name, wrapper_layer.orig_layer)
2620+
module = get_module(block, layer_name)
2621+
for key in default_config:
2622+
setattr(module, key, default_config[key])
2623+
2624+
if mode == "max":
2625+
cur_loss = 1/mse_loss(torch.stack(q_output).squeeze(1), current_output)
2626+
elif mode == "sensitive":
2627+
cur_loss = -mse_loss(torch.stack(q_output).squeeze(1), current_output) #敏感度越大,loss掉的越多,所以要选择-loss小的
2628+
elif mode == "percent":
2629+
loss_high = mse_loss(torch.stack(q_output).squeeze(1), current_output)
2630+
loss_low = mse_loss(torch.stack(q2_output).squeeze(1), current_output)
2631+
cur_loss = -(loss_high - loss_low)/loss_low #改善率越高,值为正且越大
2632+
elif mode == "marginal income ratio":
2633+
loss_high = mse_loss(torch.stack(q_output).squeeze(1), current_output)
2634+
loss_low = mse_loss(torch.stack(q2_output).squeeze(1), current_output)
2635+
income = (loss_high - loss_low)/loss_low #改善率越高,值为负且越小
2636+
marginal = income/((cur_config["bits"]-default_config["bits"])*sum(a.numel() for a in module.parameters())) #边际收益越高值为负且越小
2637+
cur_loss = marginal
2638+
2639+
each_loss[layer_name] = (cur_loss, (cur_config["bits"]-module.bits)*sum(p.numel() for p in module.parameters())/1e6) # 把每一层的loss以及size记录下来
2640+
2641+
#DP
2642+
from collections import defaultdict
2643+
import numpy as np
2644+
# 创建二维 defaultdict,每个元素默认是一个 list
2645+
choosed_layer = defaultdict(lambda: defaultdict(list))
2646+
viewd_layer = -1
2647+
logger.info(each_loss)
2648+
f = np.zeros((len(layer_names),layer_size))
2649+
2650+
#init
2651+
v , w = each_loss[layer_names[0]]
2652+
w = int(w)
2653+
if v>0:
2654+
for V in range(w, layer_size-1):
2655+
f[0][V] = v
2656+
choosed_layer[0][V].append(layer_name)
2657+
2658+
for layer_name in layer_names:
2659+
viewd_layer+=1
2660+
loss, size = each_loss[layer_name]
2661+
size = int(size)
2662+
for V in tqdm(range(layer_size)):
2663+
if size < V:
2664+
if f[viewd_layer-1][V]>(f[viewd_layer-1][V-size]+loss):
2665+
f[viewd_layer][V] = f[viewd_layer-1][V]
2666+
choosed_layer[viewd_layer][V] = choosed_layer[viewd_layer-1][V]
2667+
else:
2668+
f[viewd_layer][V] = f[viewd_layer-1][V-size]+loss
2669+
choosed_layer[viewd_layer][V] = choosed_layer[viewd_layer-1][V]
2670+
choosed_layer[viewd_layer][V].append(layer_name)
2671+
2672+
else:
2673+
f[viewd_layer][V] = f[viewd_layer-1][V]
2674+
choosed_layer[viewd_layer][V] = choosed_layer[viewd_layer-1][V]
2675+
2676+
2677+
# tmp_list.append(max_loss[1])
2678+
logger.info(f"得到的value:{f[viewd_layer][layer_size-1]}\n选择的layer:{choosed_layer[viewd_layer][layer_size-1]}")
2679+
for layer_name in choosed_layer[viewd_layer][layer_size-1]:
2680+
2681+
module = get_module(block, layer_name)
2682+
2683+
for key in cur_config:
2684+
setattr(module, key, cur_config[key])
2685+
2686+
self.layer_config[module.tmp_name] = cur_config
2687+
# continue
2688+
2689+
24932690
def choose_various_bit(
24942691
self,
24952692
block,

0 commit comments

Comments
 (0)