@@ -2261,6 +2261,9 @@ def auto_mix_rtn(self, model: torch.nn.Module, inputs, block_names, q_input=None
2261
2261
def check_needs_auto_gguf_mix_mse (self , block , formats , input_ids , input_others , outputs , device , cache_device , mode = "percent" ):
2262
2262
## TODO Q4_K_M does not support iters==0
2263
2263
## TODO for moe model, expert use default bits
2264
+ s = block .tmp_name
2265
+ s_l = re .split ("\." ,s )
2266
+
2264
2267
mse_reduction = "mean"
2265
2268
if self .gradient_accumulate_steps != 1 :
2266
2269
mse_reduction = "sum"
@@ -2320,9 +2323,24 @@ def check_needs_auto_gguf_mix_mse(self, block, formats, input_ids, input_others,
2320
2323
low_config = GGUF_CONFIG [f"gguf:q{ min (bits )} _{ split_list [2 ]} " ]
2321
2324
2322
2325
default_layer_config = low_config
2326
+ start_num = 10
2327
+ end_num = 20
2328
+ if (start_num < int (s_l [- 1 ])< end_num ):
2329
+ logger .info (f"不混合{ s_l [- 1 ]} 层" )
2330
+ for layer_name in layer_names :
2331
+ module = get_module (block , layer_name )
2332
+ self .layer_config [module .tmp_name ] = default_layer_config
2333
+ # logger.info(tmp_ori_layer[module.tmp_name])
2334
+ for key in default_layer_config :
2335
+ setattr (module , key , default_layer_config [key ])
2336
+ return
2337
+
2338
+ if (int (s_l [- 1 ])>= end_num ):
2339
+ for kx in quant_bits .keys ():
2340
+ quant_bits [kx ]+= 1
2323
2341
2324
2342
if len (bits ) == 2 :
2325
- logger .info (f"量化单bit:{ bits } ,模式为:{ mode } " )
2343
+ logger .info (f"量化单bit:{ quant_bits } ,模式为:{ mode } " )
2326
2344
self .choose_one_bit (
2327
2345
block ,
2328
2346
mix_configs ,
@@ -2339,7 +2357,7 @@ def check_needs_auto_gguf_mix_mse(self, block, formats, input_ids, input_others,
2339
2357
mode = mode ,
2340
2358
)
2341
2359
else :
2342
- logger .info (f"量化多bit,模式为:{ mode } " )
2360
+ logger .info (f"量化多bit: { quant_bits } ,模式为:{ mode } " )
2343
2361
self .choose_various_bit (
2344
2362
block ,
2345
2363
mix_configs ,
@@ -2373,13 +2391,15 @@ def choose_one_bit(
2373
2391
mode = "max" ,
2374
2392
):
2375
2393
each_loss = {}
2394
+ tmp_ori_layer = {}
2376
2395
# bit = mix_configs.keys()[0]
2377
2396
[(_ , cur_config )] = mix_configs .items ()
2378
2397
[(_ , num_bit )] = quant_bits .items ()
2379
2398
for layer_name in layer_names :
2380
2399
module = get_module (block , layer_name )
2400
+ tmp_ori_layer [module .tmp_name ] = self .layer_config [module .tmp_name ]
2381
2401
self .layer_config [module .tmp_name ] = default_config
2382
- if mode == "min" or "percent" :
2402
+ if mode in { "min" , "percent" , "marginal income ratio" } :
2383
2403
for key in cur_config :
2384
2404
setattr (module , key , cur_config [key ])
2385
2405
elif mode == "sensitive" :
@@ -2398,7 +2418,7 @@ def choose_one_bit(
2398
2418
block , current_input_ids , input_others , self .batch_size * self .infer_bs_coeff , device , cache_device
2399
2419
)
2400
2420
2401
- if mode == "percent" :
2421
+ if mode in { "percent" , "marginal income ratio" , "min" } :
2402
2422
for key in default_config :
2403
2423
setattr (module , key , default_config [key ])
2404
2424
wrapper_layer = WrapperLinear (
@@ -2420,20 +2440,50 @@ def choose_one_bit(
2420
2440
2421
2441
if mode == "min" or mode == "sensitive" :
2422
2442
cur_loss = mse_loss (torch .stack (q_output ).squeeze (1 ), current_output )
2443
+ # loss_high = mse_loss(torch.stack(q_output).squeeze(1), current_output)
2444
+ # loss_low = mse_loss(torch.stack(q2_output).squeeze(1), current_output)
2445
+ # cur_loss = (loss_low - loss_high)/loss_low #改善率越高,表现值loss越小,值为负且越小
2423
2446
elif mode == "percent" :
2424
2447
loss_high = mse_loss (torch .stack (q_output ).squeeze (1 ), current_output )
2425
2448
loss_low = mse_loss (torch .stack (q2_output ).squeeze (1 ), current_output )
2426
2449
cur_loss = (loss_high - loss_low )/ loss_low #改善率越高,值为负且越小
2450
+ logger .info (f"low:{ loss_low } " )
2451
+ logger .info (f"high:{ loss_high } " )
2452
+ elif mode == "marginal income ratio" :
2453
+ loss_high = mse_loss (torch .stack (q_output ).squeeze (1 ), current_output )
2454
+ loss_low = mse_loss (torch .stack (q2_output ).squeeze (1 ), current_output )
2455
+ income = (loss_high - loss_low )/ loss_low #改善率越高,值为负且越小
2456
+ marginal = income / ((cur_config ["bits" ]- default_config ["bits" ])* sum (a .numel () for a in module .parameters ())) #边际收益越高值为负且越小
2457
+ cur_loss = marginal
2427
2458
each_loss [layer_name ] = cur_loss # 把每一层的loss记录下来
2428
2459
2429
2460
top_n_loss = sorted (each_loss .items (), key = lambda x : x [1 ], reverse = False )[:num_bit ] #reverse=False升序
2461
+
2430
2462
# tmp_list.append(max_loss[1])
2431
2463
flag = {}
2432
2464
for layer_name , loss_item in top_n_loss :
2465
+ module = get_module (block , layer_name )
2466
+
2467
+ if mode == "percent" :
2468
+ logger .info (f"增长率为:{ - loss_item * 100 } %" )
2469
+ elif mode == "marginal income ratio" :
2470
+ logger .info (f"当前量化目标为:{ module .tmp_name } ,边际收益为{ - loss_item } " )
2471
+ if - loss_item < 1.12e-7 :
2472
+ logger .info (f"out of acceptable threshold, disadopt it" )
2473
+ for layer_name in layer_names :
2474
+ module = get_module (block , layer_name )
2475
+ self .layer_config [module .tmp_name ] = tmp_ori_layer [module .tmp_name ]
2476
+ # logger.info(tmp_ori_layer[module.tmp_name])
2477
+ for key in cur_config :
2478
+ setattr (module , key , self .layer_config [module .tmp_name ][key ])
2479
+ break
2480
+ elif mode == "min" :
2481
+ logger .info (f"当前计算结果为:{ loss_item } " )
2482
+
2433
2483
if loss_item > 0 and mode == "percent" :
2434
2484
logger .info (f"loss = { loss_item } > 0, it seems become worse,so we skip it" )
2435
2485
break
2436
- module = get_module ( block , layer_name )
2486
+
2437
2487
for key in cur_config :
2438
2488
setattr (module , key , cur_config [key ])
2439
2489
0 commit comments