@@ -2279,11 +2279,11 @@ def get_act_max_hook(module, input, output):
    #     # Apply the configuration to the corresponding layer in the model
    #     for key in keys:
    #         setattr(m, key, low_config[key])
-
+
    #     return layer_config

    @torch.inference_mode()
-    def check_needs_auto_gguf_mix_mse(self, block, formats, input_ids, input_others, outputs, device,cache_device):
+    def check_needs_auto_gguf_mix_mse(self, block, formats, input_ids, input_others, outputs, device, cache_device):
        ## TODO Q4_K_M does not support iters==0
        ## TODO for MoE models, experts use the default bits
        mse_reduction = "mean"
@@ -2293,35 +2293,35 @@ def check_needs_auto_gguf_mix_mse(self, block, formats, input_ids, input_others,

        target_gguf_format = None
        for format in formats:
-            if format.startswith("gguf") and 'm' in format:
+            if format.startswith("gguf") and "m" in format:
                target_gguf_format = format
        if target_gguf_format is None:
            return

        ## simple verification: if the layer_config already has a mixed-bits setting, we don't apply auto mixed precision
-        bits = []
-        count = 0
+        bits = []
+        count = 0
        quant_bits = {}
-        for n, m in block.named_modules():# [4 4 6 4 4 6 8]
+        for n, m in block.named_modules():  # [4 4 6 4 4 6 8]
            if hasattr(m, "bits"):
                bits.append(m.bits)
-                quant_bits[m.bits]= 0
+                quant_bits[m.bits] = 0
        ori_bit = min(bits)
        for b in bits:
            if b != ori_bit:
-                quant_bits[b]+= 1
-        bits = set(bits) # {4,6}
+                quant_bits[b] += 1
+        bits = set(bits)  # {4, 6}
        if len(bits) <= 1:
            return
        del quant_bits[min(bits)]
-
+
        layer_names = []
-
+
        for n, m in block.named_modules():
            if check_to_quantized(m):
                layer_names.append(n)
-                count += 1
-
+                count += 1
+
        if count > 10:
            logger.info("Skip selection")
            return
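For reference, the hunk above only triggers mixed-bit handling when the block's layers carry more than one `bits` value. A minimal standalone sketch of that counting step, with hypothetical helper names that are not part of this patch:

```python
# Sketch of the bit-counting logic above (illustrative only): count how many
# layers sit above the minimum bit width and bail out when nothing is mixed.
from collections import Counter
from typing import Dict, List, Optional


def count_mixed_bits(layer_bits: List[int]) -> Optional[Dict[int, int]]:
    if not layer_bits:
        return None
    counts = Counter(layer_bits)  # e.g. [4, 4, 6, 4, 4, 6, 8] -> {4: 4, 6: 2, 8: 1}
    if len(counts) <= 1:  # a single bit width means there is nothing to mix
        return None
    del counts[min(counts)]  # keep only the bit widths above the base one
    return dict(counts)


print(count_mixed_bits([4, 4, 6, 4, 4, 6, 8]))  # {6: 2, 8: 1}
print(count_mixed_bits([4, 4, 4]))  # None
```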
@@ -2334,17 +2334,17 @@ def check_needs_auto_gguf_mix_mse(self, block, formats, input_ids, input_others,
        # current_output = to_device(current_output, device)
        current_input_ids = [input_ids[i] for i in whole_indices]
        default_config = GGUF_CONFIG[target_gguf_format]
-        split_list = re.split(':|_', target_gguf_format)
+        split_list = re.split(":|_", target_gguf_format)
        mix_configs = {}
-
-        for k,_ in quant_bits.items():
+
+        for k, _ in quant_bits.items():
            mix_configs[k] = GGUF_CONFIG[f"gguf:q{k}_{split_list[2]}"]
-
+
        d_format = [f"gguf:q{min(bits)}_{split_list[2]}"]
        low_config = GGUF_CONFIG[f"gguf:q{min(bits)}_{split_list[2]}"]

        default_layer_config = low_config
-
+
        # for k in self.layer_config.keys():
        #     s = re.split('\.', k)
        #     if len(s) < 2:
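As a quick sanity check on the `re.split(":|_", ...)` call above: splitting the target GGUF format name on `:` or `_` exposes the suffix piece used to build the per-bit config keys. The format string below is an assumed example value, not taken from the patch:

```python
# Quick check of the format-name parsing above (example value assumed).
import re

target_gguf_format = "gguf:q4_k_m"
split_list = re.split(":|_", target_gguf_format)
print(split_list)  # ['gguf', 'q4', 'k', 'm']
# split_list[2] supplies the suffix used when looking up per-bit configs:
print(f"gguf:q{6}_{split_list[2]}")  # 'gguf:q6_k'
```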
@@ -2355,73 +2355,135 @@ def check_needs_auto_gguf_mix_mse(self, block, formats, input_ids, input_others,

        if len(bits) == 2:
            logger.info("Quantize with a single mixed bit width")
-            self.choose_one_bit(block,mix_configs,quant_bits,default_config,default_layer_config,layer_names,current_input_ids,input_others,current_output,mse_loss,device,cache_device)
+            self.choose_one_bit(
+                block,
+                mix_configs,
+                quant_bits,
+                default_config,
+                default_layer_config,
+                layer_names,
+                current_input_ids,
+                input_others,
+                current_output,
+                mse_loss,
+                device,
+                cache_device,
+            )
        else:
            logger.info("Quantize with multiple mixed bit widths")
-            self.choose_various_bit(block,mix_configs,quant_bits,default_config,default_layer_config,layer_names,current_input_ids,input_others,current_output,mse_loss,device,cache_device)
-
+            self.choose_various_bit(
+                block,
+                mix_configs,
+                quant_bits,
+                default_config,
+                default_layer_config,
+                layer_names,
+                current_input_ids,
+                input_others,
+                current_output,
+                mse_loss,
+                device,
+                cache_device,
+            )

-    def choose_one_bit(self,block,mix_configs,quant_bits,default_config,default_layer_config,layer_names,current_input_ids,input_others,current_output,mse_loss,device,cache_device):
+    def choose_one_bit(
+        self,
+        block,
+        mix_configs,
+        quant_bits,
+        default_config,
+        default_layer_config,
+        layer_names,
+        current_input_ids,
+        input_others,
+        current_output,
+        mse_loss,
+        device,
+        cache_device,
+    ):
        each_loss = {}
        # bit = mix_configs.keys()[0]
-        [(_,cur_config)] = mix_configs.items()
-        [(_,num_bit)] = quant_bits.items()
+        [(_, cur_config)] = mix_configs.items()
+        [(_, num_bit)] = quant_bits.items()
        for layer_name in layer_names:
            module = get_module(block, layer_name)
            self.layer_config[module.tmp_name] = default_config
-            for key in cur_config:
-                setattr(module,key,cur_config[key])
-
-            wrapper_layer = WrapperLinear(module,enable_minmax_tuning=False,enable_round_tuning=False,enable_norm_bias_tuning=False,device=device)
+            for key in cur_config:
+                setattr(module, key, cur_config[key])
+
+            wrapper_layer = WrapperLinear(
+                module,
+                enable_minmax_tuning=False,
+                enable_round_tuning=False,
+                enable_norm_bias_tuning=False,
+                device=device,
+            )
            set_module(block, layer_name, wrapper_layer)
-            q_output = self.get_block_outputs(block, current_input_ids, input_others, self.batch_size * self.infer_bs_coeff,
-                                              device,
-                                              cache_device)
-
-            set_module(block,layer_name,wrapper_layer.orig_layer)
+            q_output = self.get_block_outputs(
+                block, current_input_ids, input_others, self.batch_size * self.infer_bs_coeff, device, cache_device
+            )
+
+            set_module(block, layer_name, wrapper_layer.orig_layer)
            module = get_module(block, layer_name)
-            for key in default_config:
-                setattr(module,key,default_config[key])
-            cur_loss = mse_loss(torch.stack(q_output).squeeze(1),current_output)
-            each_loss[layer_name] = cur_loss # record the loss of each layer
-
+            for key in default_config:
+                setattr(module, key, default_config[key])
+            cur_loss = mse_loss(torch.stack(q_output).squeeze(1), current_output)
+            each_loss[layer_name] = cur_loss  # record the loss of each layer
+
        top_n_loss = sorted(each_loss.items(), key=lambda x: x[1], reverse=False)[:num_bit]
        # breakpoint()
        # tmp_list.append(max_loss[1])
        flag = {}
-        for layer_name,_ in top_n_loss:
+        for layer_name, _ in top_n_loss:
            module = get_module(block, layer_name)
-            for key in cur_config:
-                setattr(module,key,cur_config[key])
-
+            for key in cur_config:
+                setattr(module, key, cur_config[key])
+
            self.layer_config[module.tmp_name] = cur_config
-            # continue
+            # continue

-
-
-    def choose_various_bit(self,block,mix_configs,quant_bits,cur_config,default_config,default_layer_config,layer_names,current_input_ids,input_others,current_output,mse_loss,device,cache_device):
+    def choose_various_bit(
+        self,
+        block,
+        mix_configs,
+        quant_bits,
+        cur_config,
+        default_config,
+        default_layer_config,
+        layer_names,
+        current_input_ids,
+        input_others,
+        current_output,
+        mse_loss,
+        device,
+        cache_device,
+    ):
        each_loss = {}
        for layer_name in layer_names:
            module = get_module(block, layer_name)
-            for key in default_config:
-                setattr(module,key,cur_config[key])
-
-            wrapper_layer = WrapperLinear(module,enable_minmax_tuning=False,enable_round_tuning=False,enable_norm_bias_tuning=False,device=device)
+            for key in default_config:
+                setattr(module, key, cur_config[key])
+
+            wrapper_layer = WrapperLinear(
+                module,
+                enable_minmax_tuning=False,
+                enable_round_tuning=False,
+                enable_norm_bias_tuning=False,
+                device=device,
+            )
            set_module(block, layer_name, wrapper_layer)
-            q_output = self.get_block_outputs(block, current_input_ids, input_others, self.batch_size * self.infer_bs_coeff,
-                                              device,
-                                              cache_device)
-            set_module(block,layer_name,wrapper_layer.orig_layer)
-
-            cur_loss = mse_loss(torch.stack(q_output).squeeze(1),current_output)
-            each_loss[layer_name] = cur_loss # record the loss of each layer
-
-        top_n_loss = sorted(each_loss.items(), key=lambda x: x[1], reverse=True)[:sum(quant_bits.values())]
+            q_output = self.get_block_outputs(
+                block, current_input_ids, input_others, self.batch_size * self.infer_bs_coeff, device, cache_device
+            )
+            set_module(block, layer_name, wrapper_layer.orig_layer)
+
+            cur_loss = mse_loss(torch.stack(q_output).squeeze(1), current_output)
+            each_loss[layer_name] = cur_loss  # record the loss of each layer
+
+        top_n_loss = sorted(each_loss.items(), key=lambda x: x[1], reverse=True)[: sum(quant_bits.values())]
        shift = 0
-        for k,_ in top_n_loss.items():
+        for k, _ in top_n_loss:
            self.layer_config[module.tmp_name] = cur_config
-
-


    def quant_block(self, block, input_ids, input_others, q_input=None, device=torch.device("cpu")):
        """Quantize the weights of a given block of the model.
@@ -2476,7 +2538,8 @@ def quant_block(self, block, input_ids, input_others, q_input=None, device=torch
            handle.remove()

        self.check_needs_auto_gguf_mix_mse(
-            block, self.formats, input_ids, input_others, output, device, self.cache_device)
+            block, self.formats, input_ids, input_others, output, device, self.cache_device
+        )

        if q_input is not None:
            if input_ids is not q_input:
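Stepping back, `choose_one_bit` above quantizes one candidate layer at a time with the higher-bit config, measures the block-output MSE against the reference output, and then takes the first `num_bit` entries of the ascending sort (`reverse=False`). A simplified standalone sketch of that selection loop, using hypothetical callback names rather than the patch's actual API:

```python
# Simplified sketch of the per-layer MSE ranking used by choose_one_bit.
# run_block_with_layer_promoted is a hypothetical callback returning the block
# output with only `name` switched to the higher-bit config.
import torch


def select_layers_to_promote(layer_names, run_block_with_layer_promoted, reference_output, num_to_promote):
    mse = torch.nn.MSELoss(reduction="mean")
    losses = {}
    for name in layer_names:
        q_output = run_block_with_layer_promoted(name)
        losses[name] = mse(q_output, reference_output).item()
    # ascending loss, mirroring reverse=False in choose_one_bit
    ranked = sorted(losses.items(), key=lambda kv: kv[1])
    return [name for name, _ in ranked[:num_to_promote]]
```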