@@ -2572,10 +2572,9 @@ def _quantize_layer(
             whole_indices = torch.randperm(nsamples)[:pick_samples]
             if gradient_accumulate_steps != 1:
                 if q_inputs is not None:
-                    current_input = [q_inputs[i] for i in whole_indices]
+                    num_elm = self._get_current_num_elm(q_inputs, whole_indices)
                 else:
-                    current_input = [inputs[i] for i in whole_indices]
-                num_elm = sum(id.numel() for id in current_input)
+                    num_elm = self._get_current_num_elm(inputs, whole_indices)
             for tmp_step in range(gradient_accumulate_steps):
                 indices = whole_indices[tmp_step * batch_size : (tmp_step + 1) * batch_size]
                 if q_inputs is not None:
@@ -2700,6 +2699,14 @@ def _get_current_q_output(
         output_q = block_forward(block, current_input_ids, current_input_others, self.amp, self.amp_dtype, device)
         return output_q
 
+    def _get_current_num_elm(
+        self,
+        input_ids: list[torch.Tensor],
+        indices: list[int],
+    ) -> int:
+        current_input_ids = [input_ids[i] for i in indices]
+        return sum(id.numel() for id in current_input_ids)
+
     def _quantize_block(
         self,
         block: torch.nn.Module,
@@ -2840,8 +2847,7 @@ def _quantize_block(
             whole_indices = torch.randperm(nsamples)[:pick_samples]
             # We assume the block input and output shape is same
             if self.gradient_accumulate_steps != 1:
-                current_input_ids = [input_ids[i] for i in whole_indices]
-                num_elm = sum(id.numel() for id in current_input_ids)
+                num_elm = self._get_current_num_elm(input_ids, whole_indices)
 
             for tmp_step in range(self.gradient_accumulate_steps):
                 indices = whole_indices[tmp_step * self.batch_size : (tmp_step + 1) * self.batch_size]
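
For context, a minimal, self-contained sketch (illustrative only, not part of the change) of what the extracted _get_current_num_elm helper computes: the total element count over the sampled calibration inputs, which the surrounding training loop appears to use to normalize the accumulated loss when gradient_accumulate_steps != 1. The standalone function name, tensor shapes, and sample count below are assumptions made for the example.

import torch

def get_current_num_elm(input_ids: list[torch.Tensor], indices: list[int]) -> int:
    # Gather the sampled inputs and count their elements, mirroring the helper in the diff.
    current_input_ids = [input_ids[i] for i in indices]
    return sum(t.numel() for t in current_input_ids)

# Hypothetical calibration cache: 8 inputs of shape (batch=2, seq_len=4, hidden=16).
inputs = [torch.randn(2, 4, 16) for _ in range(8)]
whole_indices = torch.randperm(len(inputs))[:4]
num_elm = get_current_num_elm(inputs, whole_indices.tolist())
print(num_elm)  # 4 samples * 2 * 4 * 16 = 512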