Skip to content

Commit e8c4259

Browse files
update
1 parent 59a583b commit e8c4259

File tree

3 files changed

+56
-7
lines changed

3 files changed

+56
-7
lines changed

.gitignore

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,5 +10,5 @@ dist/
1010
CMakeUserPresets.json
1111
tmp_autoround/
1212
ut_log_dir/
13-
.history/
13+
.history
1414
*.sh

auto_round/autoround.py

Lines changed: 55 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2261,6 +2261,9 @@ def auto_mix_rtn(self, model: torch.nn.Module, inputs, block_names, q_input=None
22612261
def check_needs_auto_gguf_mix_mse(self, block, formats, input_ids, input_others, outputs, device, cache_device, mode="percent"):
22622262
## TODO Q4_K_M does not support iters==0
22632263
## TODO: for MoE models, experts use the default bits
2264+
s = block.tmp_name
2265+
s_l = re.split("\.",s)
2266+
22642267
mse_reduction = "mean"
22652268
if self.gradient_accumulate_steps != 1:
22662269
mse_reduction = "sum"
@@ -2320,9 +2323,24 @@ def check_needs_auto_gguf_mix_mse(self, block, formats, input_ids, input_others,
23202323
low_config = GGUF_CONFIG[f"gguf:q{min(bits)}_{split_list[2]}"]
23212324

23222325
default_layer_config = low_config
2326+
start_num = 10
2327+
end_num = 20
2328+
if(start_num<int(s_l[-1])<end_num):
2329+
logger.info(f"不混合{s_l[-1]}层")
2330+
for layer_name in layer_names:
2331+
module = get_module(block, layer_name)
2332+
self.layer_config[module.tmp_name] = default_layer_config
2333+
# logger.info(tmp_ori_layer[module.tmp_name])
2334+
for key in default_layer_config:
2335+
setattr(module, key, default_layer_config[key])
2336+
return
2337+
2338+
if(int(s_l[-1])>=end_num):
2339+
for kx in quant_bits.keys():
2340+
quant_bits[kx]+=1
23232341

23242342
if len(bits) == 2:
2325-
logger.info(f"量化单bit:{bits},模式为:{mode}")
2343+
logger.info(f"量化单bit:{quant_bits},模式为:{mode}")
23262344
self.choose_one_bit(
23272345
block,
23282346
mix_configs,
@@ -2339,7 +2357,7 @@ def check_needs_auto_gguf_mix_mse(self, block, formats, input_ids, input_others,
23392357
mode = mode,
23402358
)
23412359
else:
2342-
logger.info(f"量化多bit,模式为:{mode}")
2360+
logger.info(f"量化多bit:{quant_bits},模式为:{mode}")
23432361
self.choose_various_bit(
23442362
block,
23452363
mix_configs,
@@ -2373,13 +2391,15 @@ def choose_one_bit(
23732391
mode="max",
23742392
):
23752393
each_loss = {}
2394+
tmp_ori_layer = {}
23762395
# bit = mix_configs.keys()[0]
23772396
[(_, cur_config)] = mix_configs.items()
23782397
[(_, num_bit)] = quant_bits.items()
23792398
for layer_name in layer_names:
23802399
module = get_module(block, layer_name)
2400+
tmp_ori_layer[module.tmp_name] = self.layer_config[module.tmp_name]
23812401
self.layer_config[module.tmp_name] = default_config
2382-
if mode == "min" or "percent":
2402+
if mode in {"min", "percent", "marginal income ratio"}:
23832403
for key in cur_config:
23842404
setattr(module, key, cur_config[key])
23852405
elif mode == "sensitive":
@@ -2398,7 +2418,7 @@ def choose_one_bit(
23982418
block, current_input_ids, input_others, self.batch_size * self.infer_bs_coeff, device, cache_device
23992419
)
24002420

2401-
if mode == "percent":
2421+
if mode in {"percent", "marginal income ratio", "min"}:
24022422
for key in default_config:
24032423
setattr(module, key, default_config[key])
24042424
wrapper_layer = WrapperLinear(
@@ -2420,20 +2440,50 @@ def choose_one_bit(
24202440

24212441
if mode == "min" or mode == "sensitive":
24222442
cur_loss = mse_loss(torch.stack(q_output).squeeze(1), current_output)
2443+
# loss_high = mse_loss(torch.stack(q_output).squeeze(1), current_output)
2444+
# loss_low = mse_loss(torch.stack(q2_output).squeeze(1), current_output)
2445+
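# cur_loss = (loss_low - loss_high)/loss_low  # higher improvement rate -> smaller reported loss (meant to be negative and decreasing; NOTE: this form is positive when loss_high < loss_low)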
24232446
elif mode == "percent":
24242447
loss_high = mse_loss(torch.stack(q_output).squeeze(1), current_output)
24252448
loss_low = mse_loss(torch.stack(q2_output).squeeze(1), current_output)
24262449
cur_loss = (loss_high - loss_low)/loss_low #改善率越高,值为负且越小
2450+
logger.info(f"low:{loss_low}")
2451+
logger.info(f"high:{loss_high}")
2452+
elif mode == "marginal income ratio":
2453+
loss_high = mse_loss(torch.stack(q_output).squeeze(1), current_output)
2454+
loss_low = mse_loss(torch.stack(q2_output).squeeze(1), current_output)
2455+
income = (loss_high - loss_low)/loss_low #改善率越高,值为负且越小
2456+
marginal = income/((cur_config["bits"]-default_config["bits"])*sum(a.numel() for a in module.parameters())) #边际收益越高值为负且越小
2457+
cur_loss = marginal
24272458
each_loss[layer_name] = cur_loss # 把每一层的loss记录下来
24282459

24292460
top_n_loss = sorted(each_loss.items(), key=lambda x: x[1], reverse=False)[:num_bit] #reverse=False升序
2461+
24302462
# tmp_list.append(max_loss[1])
24312463
flag = {}
24322464
for layer_name, loss_item in top_n_loss:
2465+
module = get_module(block, layer_name)
2466+
2467+
if mode == "percent":
2468+
logger.info(f"增长率为:{-loss_item*100}%")
2469+
elif mode == "marginal income ratio":
2470+
logger.info(f"当前量化目标为:{module.tmp_name},边际收益为{-loss_item}")
2471+
if -loss_item < 1.12e-7:
2472+
logger.info(f"out of acceptable threshold, disadopt it")
2473+
for layer_name in layer_names:
2474+
module = get_module(block, layer_name)
2475+
self.layer_config[module.tmp_name] = tmp_ori_layer[module.tmp_name]
2476+
# logger.info(tmp_ori_layer[module.tmp_name])
2477+
for key in cur_config:
2478+
setattr(module, key, self.layer_config[module.tmp_name][key])
2479+
break
2480+
elif mode == "min":
2481+
logger.info(f"当前计算结果为:{loss_item}")
2482+
24332483
if loss_item > 0 and mode == "percent":
24342484
logger.info(f"loss = {loss_item} > 0, it seems become worse,so we skip it")
24352485
break
2436-
module = get_module(block, layer_name)
2486+
24372487
for key in cur_config:
24382488
setattr(module, key, cur_config[key])
24392489

q2_k_s_sensitive.sh

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ CUDA_VISIBLE_DEVICES=$device python -m auto_round \
44
--format gguf:q2_k_s,fake \
55
--model /models/${model_name} \
66
--output_dir /data5/shiqi/models \
7-
--eval_bs 32 \
87
--iters 200 \
98
--tasks lambada_openai,hellaswag,piqa,winogrande,truthfulqa_mc1,openbookqa,boolq,arc_easy,arc_challenge,mmlu \
109
2>&1 | tee /data5/shiqi/log/q2_k_s_${model_name}_bs32_sensitive.log

0 commit comments

Comments
 (0)