@@ -2231,6 +2231,22 @@ def auto_mix_rtn(self, model: torch.nn.Module, inputs, block_names, q_input=None
        if pbar is None:
            pbar = tqdm(range(0, len(block_names), nblocks))
+
+        # convert all blocks to the default quant block
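+        # wrapping each block here leaves every quantizable layer behind a wrapper whose
+        # orig_layer still carries its original GGUF config; check_needs_auto_gguf_mix_mse
+        # with block_type="wrap" reads those orig_layer attributes later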
+        for i in range(0, len(block_names), nblocks):
+            for format in self.formats:
+                if format.startswith("gguf"):
+                    target_gguf_format = format
+            if target_gguf_format is None:
+                logger.info("target_gguf_format is None, return")
+                return
+
+            cur_m_n = block_names[i]
+            cur_m = get_module(model, cur_m_n)
+            cur_m = cur_m.to(device)
+            quantized_layer_names, unquantized_layer_names = wrapper_block(
+                cur_m, self.enable_minmax_tuning, self.enable_norm_bias_tuning, device=self.device
+            )

        for i in range(0, len(block_names), nblocks):
            if i != 0:
@@ -2256,9 +2272,28 @@ def auto_mix_rtn(self, model: torch.nn.Module, inputs, block_names, q_input=None
                device=device,
            )

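+        # after the search, walk the blocks again and swap each wrapper back to its
+        # original layer so the model holds plain modules again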
+        for i in range(0, len(block_names), nblocks):
+            for format in self.formats:
+                if format.startswith("gguf"):
+                    target_gguf_format = format
+            if target_gguf_format is None:
+                logger.info("target_gguf_format is None, return")
+                return
+            layer_names = []
+            cur_m_n = block_names[i]
+            cur_m = get_module(model, cur_m_n)
+            cur_m = cur_m.to(device)
+            for n, m in cur_m.named_modules():
+                if hasattr(m, "orig_layer"):
+                    set_module(cur_m, n, m.orig_layer)
+                    logger.info(f"n:{n}")
+
+

    @torch.inference_mode()
-    def check_needs_auto_gguf_mix_mse(self, block, formats, input_ids, input_others, outputs, device, cache_device, mode="percent"):
+    def check_needs_auto_gguf_mix_mse(self, block, formats, input_ids, input_others, outputs, device, cache_device, mode="percent", block_type="wrap"):
        ## TODO Q4_K_M does not support iters==0
        ## TODO for moe model, expert use default bits
        s = block.tmp_name
@@ -2284,18 +2319,27 @@ def check_needs_auto_gguf_mix_mse(self, block, formats, input_ids, input_others,
        layer_names = []

        for n, m in block.named_modules():
-            if check_to_quantized(m):
+            if check_to_quantized(m) and not n.endswith("orig_layer"):
                layer_names.append(n)
                count += 1
                if hasattr(m, "bits"):
                    bits.append(m.bits)
                    quant_bits[m.bits] = 0
-
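+                # when the block is pre-wrapped, also record the original layer's bit width
+                # so that a mixed-bit situation is still detected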
+                if block_type == "wrap":
+                    if hasattr(m.orig_layer, "bits"):
+                        bits.append(m.orig_layer.bits)
+                        quant_bits[m.orig_layer.bits] = 0
+
        ori_bit = min(bits)
-
        for b in bits:
            if b != ori_bit:
                quant_bits[b] += 1
+
+        # if block_type == "wrap":
+        #     count = count/2
+        #     for k in quant_bits.keys():
+        #         quant_bits[k] /= 2
+
        bits = set(bits)  # e.g. {4, 6}
        if len(bits) <= 1:
            logger.info(f"len(bits) <= 1, bits: {bits}, skip mixed-bit selection")
@@ -2306,6 +2350,7 @@ def check_needs_auto_gguf_mix_mse(self, block, formats, input_ids, input_others,
            return

        nsamples = min(32, len(outputs))
+        # nsamples = 128
        whole_indices = torch.randperm(len(outputs))[:nsamples]
        ## we assume the block input and output shapes are the same
        current_output = [outputs[x] for x in whole_indices]
@@ -2323,25 +2368,25 @@ def check_needs_auto_gguf_mix_mse(self, block, formats, input_ids, input_others,
        low_config = GGUF_CONFIG[f"gguf:q{min(bits)}_{split_list[2]}"]

        default_layer_config = low_config
-        start_num = 10
-        end_num = 20
-        if start_num < int(s_l[-1]) < end_num:
-            logger.info(f"do not mix layer {s_l[-1]}")
-            for layer_name in layer_names:
-                module = get_module(block, layer_name)
-                self.layer_config[module.tmp_name] = default_layer_config
-                # logger.info(tmp_ori_layer[module.tmp_name])
-                for key in default_layer_config:
-                    setattr(module, key, default_layer_config[key])
-            return
-
-        if int(s_l[-1]) >= end_num:
-            for kx in quant_bits.keys():
-                quant_bits[kx] += 1
+        # start_num = 2
+        # end_num = 29
+        # if start_num < int(s_l[-1]) < end_num:
+        #     logger.info(f"do not mix layer {s_l[-1]}")
+        #     for layer_name in layer_names:
+        #         module = get_module(block, layer_name)
+        #         self.layer_config[module.tmp_name] = default_layer_config
+        #         # logger.info(tmp_ori_layer[module.tmp_name])
+        #         for key in default_layer_config:
+        #             setattr(module, key, default_layer_config[key])
+        #     return
+        # if int(s_l[-1]) >= end_num or int(s_l[-1]) <= start_num:
+        #     lim_size = False
+        #     for kx in quant_bits.keys():
+        #         quant_bits[kx] += 1

        if len(bits) == 2:
            logger.info(f"single extra bit width to mix: {quant_bits}, mode: {mode}")
-            self.choose_one_bit(
+            self.choose_one_bit_dp(
                block,
                mix_configs,
                quant_bits,
@@ -2354,7 +2399,8 @@ def check_needs_auto_gguf_mix_mse(self, block, formats, input_ids, input_others,
                mse_loss,
                device,
                cache_device,
-                mode=mode,
+                mode=mode,
+                block_type=block_type,
            )
        else:
            logger.info(f"multiple bit widths to mix: {quant_bits}, mode: {mode}")
@@ -2388,7 +2434,8 @@ def choose_one_bit(
        mse_loss,
        device,
        cache_device,
-        mode="max",
+        mode=None,
+        block_type=None,
    ):
        each_loss = {}
        tmp_ori_layer = {}
@@ -2397,15 +2444,18 @@ def choose_one_bit(
        [(_, num_bit)] = quant_bits.items()
        for layer_name in layer_names:
            module = get_module(block, layer_name)
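+            # if the block arrives pre-wrapped, swap the wrapper back to orig_layer first so
+            # the WrapperLinear created below wraps the raw layer rather than another wrapper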
+            if block_type == "wrap" and hasattr(module, "orig_layer"):
+                set_module(block, layer_name, module.orig_layer)
+                module = get_module(block, layer_name)
            tmp_ori_layer[module.tmp_name] = self.layer_config[module.tmp_name]
            self.layer_config[module.tmp_name] = default_config
            if mode in {"min", "percent", "marginal income ratio"}:
                for key in cur_config:
                    setattr(module, key, cur_config[key])
            elif mode == "sensitive":
                for key in default_config:
-                    setattr(module, key, default_config[key])
-
+                    setattr(module, key, default_config[key])  # the higher the sensitivity, the more the loss drops
+
            wrapper_layer = WrapperLinear(
                module,
                enable_minmax_tuning=False,
@@ -2418,7 +2468,7 @@ def choose_one_bit(
                block, current_input_ids, input_others, self.batch_size * self.infer_bs_coeff, device, cache_device
            )

-            if mode in {"percent", "marginal income ratio", "min"}:
+            if mode in {"percent", "marginal income ratio"}:
                for key in default_config:
                    setattr(module, key, default_config[key])
                wrapper_layer = WrapperLinear(
@@ -2438,24 +2488,26 @@ def choose_one_bit(
            for key in default_config:
                setattr(module, key, default_config[key])

-            if mode == "min" or mode == "sensitive":
+            if mode == "min":
                cur_loss = mse_loss(torch.stack(q_output).squeeze(1), current_output)
                # loss_high = mse_loss(torch.stack(q_output).squeeze(1), current_output)
                # loss_low = mse_loss(torch.stack(q2_output).squeeze(1), current_output)
                # cur_loss = (loss_low - loss_high) / loss_low  # the higher the improvement ratio, the smaller (more negative) the value
+            elif mode == "sensitive":
+                cur_loss = -mse_loss(torch.stack(q_output).squeeze(1), current_output)  # the higher the sensitivity, the more the loss drops, so pick the smallest -loss
            elif mode == "percent":
                loss_high = mse_loss(torch.stack(q_output).squeeze(1), current_output)
                loss_low = mse_loss(torch.stack(q2_output).squeeze(1), current_output)
                cur_loss = (loss_high - loss_low) / loss_low  # the higher the improvement ratio, the more negative (smaller) the value
-                logger.info(f"low:{loss_low}")
-                logger.info(f"high:{loss_high}")
+                # logger.info(f"low:{loss_low}")
+                # logger.info(f"high:{loss_high}")
            elif mode == "marginal income ratio":
                loss_high = mse_loss(torch.stack(q_output).squeeze(1), current_output)
                loss_low = mse_loss(torch.stack(q2_output).squeeze(1), current_output)
                income = (loss_high - loss_low) / loss_low  # the higher the improvement ratio, the more negative (smaller) the value
                marginal = income / ((cur_config["bits"] - default_config["bits"]) * sum(a.numel() for a in module.parameters()))  # the higher the marginal gain, the more negative (smaller) the value
                cur_loss = marginal
-            each_loss[layer_name] = cur_loss  # record each layer's loss
+            each_loss[layer_name] = (cur_loss, sum(p.numel() for p in module.parameters()) * cur_config["bits"])  # record each layer's loss and size

        top_n_loss = sorted(each_loss.items(), key=lambda x: x[1], reverse=False)[:num_bit]  # reverse=False: ascending order
@@ -2477,8 +2529,8 @@ def choose_one_bit(
                for key in cur_config:
                    setattr(module, key, self.layer_config[module.tmp_name][key])
                break
-            elif mode == "min":
-                logger.info(f"current result: {loss_item}")
+            elif mode == "min" or mode == "sensitive":
+                logger.info(f"layer: {layer_name}, loss: {loss_item}")

            if loss_item > 0 and mode == "percent":
                logger.info(f"loss = {loss_item} > 0, it seems to get worse, so we skip it")
@@ -2490,6 +2542,151 @@ def choose_one_bit(
            self.layer_config[module.tmp_name] = cur_config
            # continue

+    def choose_one_bit_dp(
+        self,
+        block,
+        mix_configs,
+        quant_bits,
+        default_config,
+        default_layer_config,
+        layer_names,
+        current_input_ids,
+        input_others,
+        current_output,
+        mse_loss,
+        device,
+        cache_device,
+        mode=None,
+        block_type=None,
+    ):
+        each_loss = {}
+        tmp_ori_layer = {}
+        # bit = mix_configs.keys()[0]
+        [(_, cur_config)] = mix_configs.items()
+        for layer_name in layer_names:
+            module = get_module(block, layer_name)
+
+        layer_size = 0
+        for layer_name in layer_names:
+            module = get_module(block, layer_name)
+            layer_size += (module.orig_layer.bits - default_config["bits"]) * sum(p.numel() for p in module.parameters())
+
+        layer_size = int(layer_size * 1.05 / 1e6)
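+        # layer_size is the knapsack capacity used by the DP below: the total extra bits
+        # (in millions, with ~5% slack) available if layers are promoted from the default
+        # bit width back to their original orig_layer bit width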
+        logger.info(f"layer_size: {layer_size}")
+        for layer_name in layer_names:
+            module = get_module(block, layer_name)
+            if block_type == "wrap" and hasattr(module, "orig_layer"):
+                set_module(block, layer_name, module.orig_layer)
+                module = get_module(block, layer_name)
+                # tmp_ori_layer[module.tmp_name] = self.layer_config[module.tmp_name]
+            # if block_type in {"wrap"}:
+            #     for key in cur_config:
+            #         setattr(module.orig_layer, key, cur_config[key])
+            elif block_type == "default":
+                for key in default_config:
+                    setattr(module.orig_layer, key, default_config[key])
+
+            self.layer_config[module.tmp_name] = default_config
+
+            wrapper_layer = WrapperLinear(
+                module,
+                enable_minmax_tuning=False,
+                enable_round_tuning=False,
+                enable_norm_bias_tuning=False,
+                device=device,
+            )
+            set_module(block, layer_name, wrapper_layer)
+            q_output = self.get_block_outputs(
+                block, current_input_ids, input_others, self.batch_size * self.infer_bs_coeff, device, cache_device
+            )
+
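+            # for the relative-improvement modes ("percent", "marginal income ratio"),
+            # measure the block again with the default low-bit config as the baseline q2_output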
+            if mode in {"percent", "marginal income ratio"}:
+                for key in default_config:
+                    setattr(module, key, default_config[key])
+                wrapper_layer = WrapperLinear(
+                    module,
+                    enable_minmax_tuning=False,
+                    enable_round_tuning=False,
+                    enable_norm_bias_tuning=False,
+                    device=device,
+                )
+                set_module(block, layer_name, wrapper_layer)
+                q2_output = self.get_block_outputs(
+                    block, current_input_ids, input_others, self.batch_size * self.infer_bs_coeff, device, cache_device
+                )
+
+            # restore orig_layer. TODO: there is actually a problem here: the GGUF config used to keep some layers at a higher bit width, but after this step everything is at the low bit width
+            set_module(block, layer_name, wrapper_layer.orig_layer)
+            module = get_module(block, layer_name)
+            for key in default_config:
+                setattr(module, key, default_config[key])
+
+            if mode == "max":
+                cur_loss = 1 / mse_loss(torch.stack(q_output).squeeze(1), current_output)
+            elif mode == "sensitive":
+                cur_loss = -mse_loss(torch.stack(q_output).squeeze(1), current_output)  # the higher the sensitivity, the more the loss drops, so pick the smallest -loss
+            elif mode == "percent":
+                loss_high = mse_loss(torch.stack(q_output).squeeze(1), current_output)
+                loss_low = mse_loss(torch.stack(q2_output).squeeze(1), current_output)
+                cur_loss = -(loss_high - loss_low) / loss_low  # the higher the improvement ratio, the larger (more positive) the value
+            elif mode == "marginal income ratio":
+                loss_high = mse_loss(torch.stack(q_output).squeeze(1), current_output)
+                loss_low = mse_loss(torch.stack(q2_output).squeeze(1), current_output)
+                income = (loss_high - loss_low) / loss_low  # the higher the improvement ratio, the more negative (smaller) the value
+                marginal = income / ((cur_config["bits"] - default_config["bits"]) * sum(a.numel() for a in module.parameters()))  # the higher the marginal gain, the more negative (smaller) the value
+                cur_loss = marginal
+
+            each_loss[layer_name] = (cur_loss, (cur_config["bits"] - module.bits) * sum(p.numel() for p in module.parameters()) / 1e6)  # record each layer's loss and size
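+            # each_loss maps layer -> (value, weight) for the knapsack: value is the
+            # mode-dependent score above (larger is treated as better by the DP), weight is
+            # the extra size in megabits if this layer keeps the higher-bit config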
+
+        # DP
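+        # 0/1 knapsack over the layers: f[i][V] is the best total value using the first i+1
+        # layers within capacity V; choosed_layer[i][V] tracks which layers were promoted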
+        from collections import defaultdict
+        import numpy as np
+        # create a 2-D defaultdict whose entries default to an empty list
+        choosed_layer = defaultdict(lambda: defaultdict(list))
+        viewd_layer = -1
+        logger.info(each_loss)
+        f = np.zeros((len(layer_names), layer_size))
+
+        # init
+        v, w = each_loss[layer_names[0]]
+        w = int(w)
+        if v > 0:
+            for V in range(w, layer_size - 1):
+                f[0][V] = v
+                choosed_layer[0][V].append(layer_name)
+
+        for layer_name in layer_names:
+            viewd_layer += 1
+            loss, size = each_loss[layer_name]
+            size = int(size)
+            for V in tqdm(range(layer_size)):
+                if size < V:
+                    if f[viewd_layer - 1][V] > (f[viewd_layer - 1][V - size] + loss):
+                        f[viewd_layer][V] = f[viewd_layer - 1][V]
+                        choosed_layer[viewd_layer][V] = choosed_layer[viewd_layer - 1][V]
+                    else:
+                        f[viewd_layer][V] = f[viewd_layer - 1][V - size] + loss
+                        choosed_layer[viewd_layer][V] = choosed_layer[viewd_layer - 1][V]
+                        choosed_layer[viewd_layer][V].append(layer_name)
+
+                else:
+                    f[viewd_layer][V] = f[viewd_layer - 1][V]
+                    choosed_layer[viewd_layer][V] = choosed_layer[viewd_layer - 1][V]
+
+
+        # tmp_list.append(max_loss[1])
+        logger.info(f"obtained value: {f[viewd_layer][layer_size - 1]}\nselected layers: {choosed_layer[viewd_layer][layer_size - 1]}")
+        for layer_name in choosed_layer[viewd_layer][layer_size - 1]:
+
+            module = get_module(block, layer_name)
+
+            for key in cur_config:
+                setattr(module, key, cur_config[key])
+
+            self.layer_config[module.tmp_name] = cur_config
+            # continue
+
+
    def choose_various_bit(
        self,
        block,