@@ -191,116 +191,94 @@ def find_quantized_linear_patterns(
 ##
 
 
-def pack_4bit_weight_tensor(inp: torch.Tensor) -> torch.Tensor:
+def pack_4bit_weight_tensor(weight_tensor: torch.Tensor) -> torch.Tensor:
     """
     Given a 8-bit weight tensor containing values quantized to 4 bits, create a packed
-    weight tensor by packing 2 4-bit values in one unsigned 8-bit value.
+    weight tensor by transposing the weight tensor, then packing 2 4-bit values in one
+    8-bit value.
 
-    An input weight tensor of shape (M, K) will produce a packed weight tensor of shape
-    (M, K / 2).
-
-    The packing implemented here is the same as the packing produced by
-    backends/vulkan/_passes/int4_weight_only_quantizer.py
+    An input weight tensor of shape (N, K) will produce a packed weight tensor of shape
+    (K, N / 2).
     """
 
     # Assert we got a properly quantized tensor.
-    min, max = inp.min().item(), inp.max().item()
+    min_val, max_val = weight_tensor.min().item(), weight_tensor.max().item()
     assert (
-        max <= 7 and min >= -8
-    ), f"pack_4bit_weight_tensor: [min,max] out of [-8, 7] range, got [{min}, {max}]"
+        max_val <= 7 and min_val >= -8
+    ), f"pack_4bit_weight_tensor: [min_val,max_val] out of [-8, 7] range, got [{min_val}, {max_val}]"
 
     # Assuming we have a 2d tensor
-    if inp.ndim != 2:
-        inp = inp.squeeze()
+    if weight_tensor.ndim != 2:
+        weight_tensor = weight_tensor.squeeze()
     assert (
-        inp.ndim == 2
-    ), f"pack_4bit_weight_tensor: expecting input tensor to be 2d, got {inp.ndim}"
+        weight_tensor.ndim == 2
+    ), f"pack_4bit_weight_tensor: expecting input tensor to be 2d, got {weight_tensor.ndim}"
 
-    # pad ic
-    if inp.shape[-1] % 2 != 0:
-        inp = F.pad(input=inp, pad=(0, 1, 0, 0), mode="constant", value=0)
+    # Need to pad innermost dim to be a multiple of 8, since the minimum load granularity
+    # is int32 (4 bytes), which contains 8 4-bit values.
+    if weight_tensor.shape[-1] % 8 != 0:
+        num_pad = 8 - (weight_tensor.shape[-1] % 8)
+        weight_tensor = F.pad(input=weight_tensor, pad=(0, num_pad))
 
     # Shape after padding
-    oc, ic = inp.shape
-    assert ic % 2 == 0, "convert_to_qc4w: expecting ic to be even"
+    _, in_channels = weight_tensor.shape
+    assert in_channels % 8 == 0, "convert_to_qc4w: expecting ic to be divisible by 8"
 
-    # Adjust inp tensor for zp
-    inp = inp.to(dtype=torch.uint8) + 8
+    # Adjust weight_tensor tensor for zp
+    weight_tensor = weight_tensor.to(dtype=torch.uint8) + 8
 
     # Pack each 4-bit value into a single 8-bit value
-    return inp[::, ::2] << 4 | inp[::, 1::2]
-
-
-def make_combined_scales_and_zeros_tensor(
-    scales: torch.Tensor, zeros: torch.Tensor
-) -> torch.Tensor:
-    """
-    Given a scales and zeros tensor, create a combined tensor by stacking them into a
-    single tensor.
-
-    The scales and zeros tensors are expected to be 2D tensors of shape
-    (OUTPUT_CHANNELS, NUM_GROUPS). The combined tensor will have the shape
-    (NUM_GROUPS, OUTPUT_CHANNELS, 2).
-
-    This is the scales and zeros format produced by
-    backends/vulkan/_passes/int4_weight_only_quantizer.py, which in turn is the scales
-    and zeros format expected by the _weight_int4pack_mm op in ATen.
-    """
-    scales_reshaped = scales.transpose(0, 1).unsqueeze(2)
-    zeros_reshaped = zeros.transpose(0, 1).unsqueeze(2)
-
-    zeros_scaled = zeros_reshaped * scales_reshaped * -1
-    return torch.cat((scales_reshaped, zeros_scaled), dim=2)
+    return weight_tensor[::, 1::2] << 4 | weight_tensor[::, ::2]
 
 
 ##
 ## Pattern Replacement
 ##
 
 
-def make_linear_q4ga_op(
+def make_linear_q4gsw_op(
     ep: ExportedProgram,
     graph_module: torch.fx.GraphModule,
     match: QuantizedLinearMatch,
     weight_tensor: torch.Tensor,
     weight_scales_tensor: torch.Tensor,
-    weight_zeros_tensor: torch.Tensor,
 ):
-    packed_quantized_weight_tensor = pack_4bit_weight_tensor(weight_tensor)
-    utils.update_program_state_dict(
-        ep, match.weight_node.name, packed_quantized_weight_tensor
-    )
-    # Need to make sure corresponding FakeTensor has same size
-    match.weight_node.meta["val"] = match.weight_node.meta["val"][:, ::2].to(
-        torch.uint8
-    )
-
-    group_size = weight_tensor.shape[1] // weight_scales_tensor.shape[1]
-
-    combined_scales_zeros_tensor = make_combined_scales_and_zeros_tensor(
-        weight_scales_tensor, weight_zeros_tensor
+    num_groups = weight_scales_tensor.shape[-1]
+    in_channels = weight_tensor.shape[-1]
+    group_size = in_channels // num_groups
+
+    weight_tensor = pack_4bit_weight_tensor(weight_tensor)
+    # Use this function for convenience to update the state dict with the packed
+    # weight tensor. Alignment will already have been done in the above function.
+    weight_tensor = utils.align_width_and_update_state_dict(
+        ep, match.weight_node, weight_tensor, align_to=1, force_update=True
     )
 
-    combined_scales_zeros_name = f"{match.weight_node.name}_scales_zeros"
-    graph_module.register_parameter(
-        combined_scales_zeros_name, torch.nn.Parameter(combined_scales_zeros_tensor)
+    # Also transpose the weight scales tensor to shape [num_groups, N]
+    weight_scales_tensor = weight_scales_tensor.transpose(0, 1).contiguous()
+    # Align to multiple of 8 to ensure that data loads from the weight scales
+    # tensor do not go out of bounds. Each thread computes 8 output channels.
+    utils.align_width_and_update_state_dict(
+        ep,
+        match.weight_scales_node,
+        weight_scales_tensor,
+        align_to=8,
+        force_update=True,
     )
 
     with graph_module.graph.inserting_before(match.output_node):
-        combined_scales_zeros = graph_module.graph.get_attr(combined_scales_zeros_name)
-        linear_q4ga_node = graph_module.graph.create_node(
+        linear_q4gsw_node = graph_module.graph.create_node(
             "call_function",
-            exir_ops.edge.et_vk.linear_weight_int4.default,
+            exir_ops.edge.et_vk.linear_q4gsw.default,
             args=(
                 match.fp_input_node,
                 match.weight_node,
+                match.weight_scales_node,
                 group_size,
-                combined_scales_zeros,
-                1,
             ),
         )
 
-        linear_q4ga_node.meta["val"] = match.output_node.meta["val"]
-        match.output_node.replace_all_uses_with(linear_q4ga_node)
+        linear_q4gsw_node.meta["val"] = match.output_node.meta["val"]
+        match.output_node.replace_all_uses_with(linear_q4gsw_node)
 
 
 def make_linear_q8ta_q8csw_custom_op(
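For reference, below is a minimal, self-contained sketch of the nibble packing that the new `pack_4bit_weight_tensor` performs: pad the innermost dim to a multiple of 8, shift by the zero point of 8, then place odd columns in the high nibble and even columns in the low nibble. The helper name `pack_4bit_reference` and the sample values are illustrative only, not part of the pass.

```python
import torch
import torch.nn.functional as F


def pack_4bit_reference(weight_tensor: torch.Tensor) -> torch.Tensor:
    # Pad the innermost dim to a multiple of 8, mirroring the pass above.
    if weight_tensor.shape[-1] % 8 != 0:
        num_pad = 8 - (weight_tensor.shape[-1] % 8)
        weight_tensor = F.pad(weight_tensor, (0, num_pad))
    # Shift the signed [-8, 7] range to unsigned [0, 15] (zero point of 8).
    weight_tensor = weight_tensor.to(dtype=torch.uint8) + 8
    # Odd columns land in the high nibble, even columns in the low nibble.
    return weight_tensor[:, 1::2] << 4 | weight_tensor[:, ::2]


# Two rows of eight values already quantized to the 4-bit range.
w = torch.tensor(
    [[-8, -1, 0, 7, 3, -4, 2, 5],
     [1, 2, 3, 4, 5, 6, 7, -8]],
    dtype=torch.int8,
)
packed = pack_4bit_reference(w)
print(packed.shape)  # torch.Size([2, 4]): each byte holds two 4-bit values
# Row 0, byte 0: high nibble = -1 + 8 = 7, low nibble = -8 + 8 = 0 -> 0x70.
print(hex(packed[0, 0].item()))  # 0x70
```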
@@ -373,13 +351,8 @@ def replace_quantized_linear_patterns(
             and match.is_weight_pergroup_quantized()
             and utils.is_in_4bit_range(weight_tensor)
         ):
-            make_linear_q4ga_op(
-                ep,
-                graph_module,
-                match,
-                weight_tensor,
-                weight_scales_tensor,
-                weight_zeros_tensor,
+            make_linear_q4gsw_op(
+                ep, graph_module, match, weight_tensor, weight_scales_tensor
             )
         elif (
             match.is_input_static_per_tensor_quantized()
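As a rough shape-level illustration of what `make_linear_q4gsw_op` now sets up, the sketch below uses hypothetical sizes (N=10 output channels, K=64 input channels, group size 16). The real pass performs the padding and state-dict update through `utils.align_width_and_update_state_dict`, which is not reproduced here.

```python
import torch
import torch.nn.functional as F

# Hypothetical sizes: N output channels, K input channels, one scale per
# (output channel, group) pair.
N, K, group_size = 10, 64, 16
num_groups = K // group_size  # 4

weight = torch.randint(-8, 8, (N, K), dtype=torch.int8)
weight_scales = torch.rand(N, num_groups)  # shape [N, num_groups]

# group_size is recovered from the tensor shapes the same way the pass does it.
assert weight.shape[-1] // weight_scales.shape[-1] == group_size

# Transpose the scales to [num_groups, N], then pad N up to a multiple of 8 so
# a shader thread that reads 8 output channels at a time stays in bounds.
scales_t = weight_scales.transpose(0, 1).contiguous()  # [num_groups, N]
aligned_n = (N + 7) // 8 * 8  # 16
scales_t = F.pad(scales_t, (0, aligned_n - N))  # [num_groups, 16]
print(scales_t.shape)  # torch.Size([4, 16])
```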