 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Tuple
+from typing import Optional, Tuple
 from weakref import ref
 
 import torch
@@ -42,7 +42,7 @@ def get_min_max(self, observed: torch.Tensor):
         return min_vals, max_vals
 
     def forward(self, observed: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
-        observed = flatten_for_quantization(observed, self.base_name, self.args)
+        observed = flatten_for_calibration(observed, self.base_name, self.args)
 
         self.min_vals, self.max_vals = self.get_min_max(observed)
 
@@ -57,26 +57,31 @@ def forward(self, observed: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
 
     def get_global_scale(self, observed: torch.Tensor):
         observed = observed.reshape((1, 1, -1))  # per tensor reshape
-        min_vals, max_vals = self.get_min_max(observed)
-        global_scale = generate_gparam(min_vals, max_vals)
+        self.min_vals, self.max_vals = self.get_min_max(observed)
+        global_scale = generate_gparam(self.min_vals, self.max_vals)
 
         return global_scale
 
 
-def flatten_for_quantization(
-    value: torch.Tensor, base_name: str, args: QuantizationArgs
+def flatten_for_calibration(
+    value: torch.Tensor,
+    base_name: str,
+    args: QuantizationArgs,
+    g_idx: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
     if base_name == "weight":
-        return flatten_weight_for_quantization(value, args)
+        return _flatten_weight(value, args, g_idx)
     elif base_name in ("input", "output"):
-        return flatten_activation_for_quantization(value, args)
+        return _flatten_activation(value, args)
     elif base_name in ("q", "k", "v"):
-        return flatten_attention_for_quantization(value, args)
+        return _flatten_attention(value, args)
     else:
         raise ValueError(f"Unknown quantization base name: {base_name}")
 
 
-def flatten_weight_for_quantization(value: torch.Tensor, args: QuantizationArgs):
+def _flatten_weight(
+    value: torch.Tensor, args: QuantizationArgs, g_idx: Optional[torch.Tensor] = None
+):
     # value.shape = (num_rows, num_cols)
 
     if args.strategy == QuantizationStrategy.TENSOR:
@@ -91,34 +96,32 @@ def flatten_weight_for_quantization(value: torch.Tensor, args: QuantizationArgs)
         return value.unsqueeze(-2).unsqueeze(0)
 
     if args.strategy in (QuantizationStrategy.GROUP, QuantizationStrategy.TENSOR_GROUP):
+        if g_idx is not None:
+            value = value.index_select(dim=1, index=torch.argsort(g_idx))
+
         # (1, num_rows, num_groups, group_size)
         return value.unflatten(-1, (-1, args.group_size)).unsqueeze(0)
 
     if args.strategy == QuantizationStrategy.BLOCK:
         # (1, num_block_rows, num_block_cols, block_width * block_height)
         block_height, block_width = args.block_structure
-        num_rows, num_cols = value.shape
-        num_block_rows = strategy_cdiv(num_rows, block_height, args.strategy)
-        num_block_cols = strategy_cdiv(num_cols, block_width, args.strategy)
+        rows, cols = value.shape
+        block_rows = strategy_cdiv(rows, block_height, args.strategy, strict=True)
+        block_cols = strategy_cdiv(cols, block_width, args.strategy, strict=True)
         return (
-            value.reshape(
-                num_block_rows,
-                block_height,
-                num_block_cols,
-                block_width,
-            )
+            value.reshape(block_rows, block_height, block_cols, block_width)
             .transpose(1, 2)
             .flatten(-2, -1)
             .unsqueeze(0)
         )
 
     if args.strategy == QuantizationStrategy.ATTN_HEAD:
-        raise ValueError("attention head quantization cannot be applied to weights")
+        raise ValueError("Attention head quantization cannot be applied to weights")
 
     assert False, f"Unknown strategy {args.strategy}"
 
 
-def flatten_activation_for_quantization(value: torch.Tensor, args: QuantizationArgs):
+def _flatten_activation(value: torch.Tensor, args: QuantizationArgs):
     # value.shape = (batch_size, seq_len, hidden_dim)
 
     if args.strategy == QuantizationStrategy.TENSOR:
@@ -128,7 +131,7 @@ def flatten_activation_for_quantization(value: torch.Tensor, args: QuantizationA
     if args.strategy == QuantizationStrategy.TOKEN:
         # (batch_size, seq_len, hidden_dim)
         # warning: token quantization uses `compute_dynamic_scales_and_zp`
-        return value.flatten(2, -1)
+        return value
 
     if args.strategy == QuantizationStrategy.CHANNEL:
         raise ValueError("Channel quantization cannot be applied to activations")
@@ -142,12 +145,12 @@ def flatten_activation_for_quantization(value: torch.Tensor, args: QuantizationA
         raise ValueError("Block quantization cannot be applied to activations")
 
     if args.strategy == QuantizationStrategy.ATTN_HEAD:
-        raise ValueError("attention head quantization cannot be applied to linear acts")
+        raise ValueError("Attention head quantization cannot be applied to activations")
 
     assert False, f"Unknown strategy {args.strategy}"
 
 
-def flatten_attention_for_quantization(value: torch.Tensor, args: QuantizationArgs):
+def _flatten_attention(value: torch.Tensor, args: QuantizationArgs):
     # value.shape = (batch_size, num_heads, seq_len, head_dim)
 
     if args.strategy == QuantizationStrategy.TENSOR:
@@ -161,7 +164,8 @@ def flatten_attention_for_quantization(value: torch.Tensor, args: QuantizationAr
         raise ValueError("Channel quantization cannot be applied to attention")
 
     if args.strategy in (QuantizationStrategy.GROUP, QuantizationStrategy.TENSOR_GROUP):
-        raise ValueError("Group quantization cannot be applied to attention")
+        # (batch_size * num_heads * seq_len, num_groups, group_size)
+        return value.flatten(0, 2).unflatten(-1, (-1, args.group_size))
 
     if args.strategy == QuantizationStrategy.BLOCK:
         raise ValueError("Block quantization cannot be applied to attention")
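For reference (not part of the diff): a minimal standalone sketch of the shape transform performed by the new attention GROUP/TENSOR_GROUP branch, with illustrative tensor sizes. It assumes the observer then reduces min/max over the trailing group dimension; `group_size` stands in for `args.group_size`.

```python
import torch

# (batch_size, num_heads, seq_len, head_dim) with illustrative sizes
value = torch.randn(2, 8, 16, 64)
group_size = 32  # stand-in for args.group_size

# Mirrors the new GROUP/TENSOR_GROUP branch of _flatten_attention:
# (batch_size * num_heads * seq_len, num_groups, group_size)
flat = value.flatten(0, 2).unflatten(-1, (-1, group_size))
assert flat.shape == (2 * 8 * 16, 64 // group_size, group_size)

# Assumed per-group reduction over the last dim (one min/max pair per group)
min_vals, max_vals = flat.amin(dim=-1), flat.amax(dim=-1)
assert min_vals.shape == max_vals.shape == (256, 2)
```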