Commit ab95e63

Author: ssjia

[ET-VK] Implement linear_q4gsw

Pull Request resolved: #14020

As title. Extend the quantized linear implementation to handle weights that are symmetrically quantized to 4 bits per group. This is in preparation for using the int8 dot product extension to support dynamically quantized inputs.

ghstack-source-id: 308092879
@exported-using-ghexport
Differential Revision: [D81800023](https://our.internmc.facebook.com/intern/diff/D81800023/)
1 parent: 2ac3676
25 files changed: +2522 −131 lines
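For context, here is a minimal sketch (not part of this commit) of what "per-group symmetric 4-bit quantization" of a weight matrix looks like; the helper name, shapes, and scale choice below are illustrative assumptions only.

```python
# Minimal sketch (not from this commit) of per-group symmetric 4-bit quantization.
# Each row of the weight is split into groups of `group_size` values; each group gets
# one scale, no zero point, and values are rounded into the signed 4-bit range [-8, 7].
import torch

def quantize_4bit_per_group(weight: torch.Tensor, group_size: int):
    # weight: (N, K) float tensor with K divisible by group_size (illustrative assumption)
    N, K = weight.shape
    grouped = weight.view(N, K // group_size, group_size)
    scales = grouped.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 7.0
    quantized = torch.clamp(torch.round(grouped / scales), -8, 7).to(torch.int8)
    return quantized.view(N, K), scales.squeeze(-1)  # scales: (N, K // group_size)

w = torch.randn(16, 64)
w_q, w_scales = quantize_4bit_per_group(w, group_size=32)
assert int(w_q.min()) >= -8 and int(w_q.max()) <= 7
```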

backends/vulkan/custom_ops_lib.py

Lines changed: 75 additions & 2 deletions
```diff
@@ -187,9 +187,9 @@ def linear_weight_int4_impl(
 lib.impl(name, linear_weight_int4_impl, "CompositeExplicitAutograd")
 linear_weight_int4_op = getattr(getattr(torch.ops, namespace), name)
 
-#################
+##################
 ## linear_qcs4w ##
-#################
+##################
 
 
 def linear_qcs4w(
@@ -234,6 +234,79 @@ def linear_qcs4w(
 lib.impl(name, linear_qcs4w, "CompositeExplicitAutograd")
 linear_qc4w_op = getattr(getattr(torch.ops, namespace), name)
 
+##################
+## linear_q4gsw ##
+##################
+
+
+def unpack_4bit_weight_tensor(
+    packed_weight_tensor: torch.Tensor, x: torch.Tensor
+) -> torch.Tensor:
+    """
+    Reverses the packing performed in quantized_linear.pack_4bit_weight_tensor
+    """
+    # Each packed byte contains two 4-bit values: high nibble and low nibble
+    K, N_half = packed_weight_tensor.shape
+    N = N_half * 2
+
+    # Unpack high and low nibbles
+    high_nibble = (packed_weight_tensor >> 4) & 0x0F
+    low_nibble = packed_weight_tensor & 0x0F
+
+    # Stack to shape (K, N)
+    unpacked = torch.empty(
+        (K, N), dtype=torch.uint8, device=packed_weight_tensor.device
+    )
+    unpacked[:, ::2] = low_nibble
+    unpacked[:, 1::2] = high_nibble
+
+    # Undo the +8 offset and convert to signed 4-bit range [-8, 7]
+    unpacked = unpacked.to(torch.int8) - 8
+
+    in_channels = x.shape[-1]
+    # Undo any padding that may have been added to input channels
+    if in_channels != unpacked.shape[-1]:
+        return unpacked[:, :in_channels]
+
+    return unpacked
+
+
+def linear_q4gsw(
+    x: torch.Tensor,
+    weights: torch.Tensor,
+    weight_scales: torch.Tensor,
+    group_size: int,
+    bias: Optional[torch.Tensor] = None,
+):
+    # Unpack the packed weights
+    weights = unpack_4bit_weight_tensor(weights, x)
+
+    # Un-transpose the weight scales
+    weight_scales = weight_scales.transpose(0, 1)
+    weight_zeros = torch.zeros_like(weight_scales, dtype=torch.int32)
+
+    weights = torch.ops.torchao.dequantize_affine(
+        weights, [1, group_size], weight_scales, weight_zeros, torch.int8, -8, 7
+    )
+
+    out = torch.nn.functional.linear(x, weights)
+    return out
+
+
+name = "linear_q4gsw"
+lib.define(
+    f"""
+    {name}(
+        Tensor self,
+        Tensor weights,
+        Tensor weight_scales,
+        int group_size,
+        Tensor? bias = None) -> Tensor
+    """
+)
+lib.impl(name, linear_q4gsw, "CompositeExplicitAutograd")
+linear_qc4w_op = getattr(getattr(torch.ops, namespace), name)
+
 ########################
 ## linear_qta8a_qga4w ##
 ########################
```
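As a quick sanity check (not part of the commit's test suite), the following standalone sketch exercises the nibble layout that `unpack_4bit_weight_tensor` above assumes: the low nibble of each packed byte holds the even column, the high nibble holds the odd column, and both are stored with a +8 offset.

```python
# Standalone sketch of the packed layout reversed by unpack_4bit_weight_tensor:
# even columns in the low nibble, odd columns in the high nibble, values offset by +8.
import torch

vals = torch.tensor([[-8, -1, 0, 7]], dtype=torch.int8)   # already in the 4-bit range
offset = (vals + 8).to(torch.uint8)                        # shift into [0, 15]
packed = (offset[:, 1::2] << 4) | offset[:, ::2]           # (1, 4) -> (1, 2) bytes

# Reverse the packing the same way the reference implementation does
low = packed & 0x0F
high = (packed >> 4) & 0x0F
unpacked = torch.empty_like(offset)
unpacked[:, ::2] = low
unpacked[:, 1::2] = high
assert torch.equal(unpacked.to(torch.int8) - 8, vals)
```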

backends/vulkan/op_registry.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -334,9 +334,10 @@ def register_int8_mm_op():
 @update_features(
     [
         exir_ops.edge.et_vk.linear_q8ta_q8csw.default,
+        exir_ops.edge.et_vk.linear_q4gsw.default,
     ]
 )
-def register_qa_qw_linear():
+def register_quantized_linear_ops():
     return OpFeatures(
         inputs_storage=utils.CONTIGUOUS_ANY,
         supports_prepacking=True,
```

backends/vulkan/patterns/quantized_linear.py

Lines changed: 49 additions & 76 deletions
```diff
@@ -191,116 +191,94 @@ def find_quantized_linear_patterns(
 ##
 
 
-def pack_4bit_weight_tensor(inp: torch.Tensor) -> torch.Tensor:
+def pack_4bit_weight_tensor(weight_tensor: torch.Tensor) -> torch.Tensor:
     """
     Given a 8-bit weight tensor containing values quantized to 4 bits, create a packed
-    weight tensor by packing 2 4-bit values in one unsigned 8-bit value.
+    weight tensor by transposing the weight tensor, then packing 2 4-bit values in one
+    8-bit value.
 
-    An input weight tensor of shape (M, K) will produce a packed weight tensor of shape
-    (M, K / 2).
-
-    The packing implemented here is the same as the packing produced by
-    backends/vulkan/_passes/int4_weight_only_quantizer.py
+    An input weight tensor of shape (N, K) will produce a packed weight tensor of shape
+    (K, N / 2).
     """
 
     # Assert we got a properly quantized tensor.
-    min, max = inp.min().item(), inp.max().item()
+    min_val, max_val = weight_tensor.min().item(), weight_tensor.max().item()
     assert (
-        max <= 7 and min >= -8
-    ), f"pack_4bit_weight_tensor: [min,max] out of [-8, 7] range, got [{min}, {max}]"
+        max_val <= 7 and min_val >= -8
+    ), f"pack_4bit_weight_tensor: [min_val,max_val] out of [-8, 7] range, got [{min_val}, {max_val}]"
 
     # Assuming we have a 2d tensor
-    if inp.ndim != 2:
-        inp = inp.squeeze()
+    if weight_tensor.ndim != 2:
+        weight_tensor = weight_tensor.squeeze()
     assert (
-        inp.ndim == 2
-    ), f"pack_4bit_weight_tensor: expecting input tensor to be 2d, got {inp.ndim}"
+        weight_tensor.ndim == 2
+    ), f"pack_4bit_weight_tensor: expecting input tensor to be 2d, got {weight_tensor.ndim}"
 
-    # pad ic
-    if inp.shape[-1] % 2 != 0:
-        inp = F.pad(input=inp, pad=(0, 1, 0, 0), mode="constant", value=0)
+    # Need to pad innermost dim to be a multiple of 8, since the minimum load granularity
+    # is int32 (4 bytes), which contains 8 4-bit values.
+    if weight_tensor.shape[-1] % 8 != 0:
+        num_pad = 8 - (weight_tensor.shape[-1] % 8)
+        weight_tensor = F.pad(input=weight_tensor, pad=(0, num_pad))
 
     # Shape after padding
-    oc, ic = inp.shape
-    assert ic % 2 == 0, "convert_to_qc4w: expecting ic to be even"
+    _, in_channels = weight_tensor.shape
+    assert in_channels % 8 == 0, "convert_to_qc4w: expecting ic to be divisible by 8"
 
-    # Adjust inp tensor for zp
-    inp = inp.to(dtype=torch.uint8) + 8
+    # Adjust weight_tensor tensor for zp
+    weight_tensor = weight_tensor.to(dtype=torch.uint8) + 8
     # Pack each 4-bit value into a single 8-bit value
-    return inp[::, ::2] << 4 | inp[::, 1::2]
-
-
-def make_combined_scales_and_zeros_tensor(
-    scales: torch.Tensor, zeros: torch.Tensor
-) -> torch.Tensor:
-    """
-    Given a scales and zeros tensor, create a combined tensor by stacking them into a
-    single tensor.
-
-    The scales and zeros tensors are expected to be 2D tensors of shape
-    (OUTPUT_CHANNELS, NUM_GROUPS). The combined tensor will have the shape
-    (NUM_GROUPS, OUTPUT_CHANNELS, 2).
-
-    This is the scales and zeros format produced by
-    backends/vulkan/_passes/int4_weight_only_quantizer.py, which in turn is the scales
-    and zeros format expected by the _weight_int4pack_mm op in ATen.
-    """
-    scales_reshaped = scales.transpose(0, 1).unsqueeze(2)
-    zeros_reshaped = zeros.transpose(0, 1).unsqueeze(2)
-
-    zeros_scaled = zeros_reshaped * scales_reshaped * -1
-    return torch.cat((scales_reshaped, zeros_scaled), dim=2)
+    return weight_tensor[::, 1::2] << 4 | weight_tensor[::, ::2]
 
 
 ##
 ## Pattern Replacement
 ##
 
 
-def make_linear_q4ga_op(
+def make_linear_q4gsw_op(
     ep: ExportedProgram,
     graph_module: torch.fx.GraphModule,
     match: QuantizedLinearMatch,
     weight_tensor: torch.Tensor,
     weight_scales_tensor: torch.Tensor,
-    weight_zeros_tensor: torch.Tensor,
 ):
-    packed_quantized_weight_tensor = pack_4bit_weight_tensor(weight_tensor)
-    utils.update_program_state_dict(
-        ep, match.weight_node.name, packed_quantized_weight_tensor
-    )
-    # Need to make sure corresponding FakeTensor has same size
-    match.weight_node.meta["val"] = match.weight_node.meta["val"][:, ::2].to(
-        torch.uint8
-    )
-
-    group_size = weight_tensor.shape[1] // weight_scales_tensor.shape[1]
-
-    combined_scales_zeros_tensor = make_combined_scales_and_zeros_tensor(
-        weight_scales_tensor, weight_zeros_tensor
+    num_groups = weight_scales_tensor.shape[-1]
+    in_channels = weight_tensor.shape[-1]
+    group_size = in_channels // num_groups
+
+    weight_tensor = pack_4bit_weight_tensor(weight_tensor)
+    # Use this function for convenience to update the state dict with the packed
+    # weight tensor. Alignment will already have been done in the above function.
+    weight_tensor = utils.align_width_and_update_state_dict(
+        ep, match.weight_node, weight_tensor, align_to=1, force_update=True
    )
 
-    combined_scales_zeros_name = f"{match.weight_node.name}_scales_zeros"
-    graph_module.register_parameter(
-        combined_scales_zeros_name, torch.nn.Parameter(combined_scales_zeros_tensor)
+    # Also transpose the weight scales tensor to shape [num_groups, N]
+    weight_scales_tensor = weight_scales_tensor.transpose(0, 1).contiguous()
+    # Align to multiple of 8 to ensure that data loads from the weight scales
+    # tensor do not go out of bounds. Each thread computes 8 output channels.
+    utils.align_width_and_update_state_dict(
+        ep,
+        match.weight_scales_node,
+        weight_scales_tensor,
+        align_to=8,
+        force_update=True,
     )
 
     with graph_module.graph.inserting_before(match.output_node):
-        combined_scales_zeros = graph_module.graph.get_attr(combined_scales_zeros_name)
-        linear_q4ga_node = graph_module.graph.create_node(
+        linear_q4gsw_node = graph_module.graph.create_node(
             "call_function",
-            exir_ops.edge.et_vk.linear_weight_int4.default,
+            exir_ops.edge.et_vk.linear_q4gsw.default,
             args=(
                 match.fp_input_node,
                 match.weight_node,
+                match.weight_scales_node,
                 group_size,
-                combined_scales_zeros,
-                1,
             ),
         )
 
-    linear_q4ga_node.meta["val"] = match.output_node.meta["val"]
-    match.output_node.replace_all_uses_with(linear_q4ga_node)
+    linear_q4gsw_node.meta["val"] = match.output_node.meta["val"]
+    match.output_node.replace_all_uses_with(linear_q4gsw_node)
 
 
 def make_linear_q8ta_q8csw_custom_op(
@@ -373,13 +351,8 @@ def replace_quantized_linear_patterns(
             and match.is_weight_pergroup_quantized()
            and utils.is_in_4bit_range(weight_tensor)
        ):
-            make_linear_q4ga_op(
-                ep,
-                graph_module,
-                match,
-                weight_tensor,
-                weight_scales_tensor,
-                weight_zeros_tensor,
+            make_linear_q4gsw_op(
+                ep, graph_module, match, weight_tensor, weight_scales_tensor
            )
        elif (
            match.is_input_static_per_tensor_quantized()
```
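The following is a small standalone illustration of the pad-to-8 and nibble-packing expressions introduced in `pack_4bit_weight_tensor` above. The weights here are random and illustrative only; the committed function additionally guards the padding with an `if` and squeezes non-2D inputs.

```python
# Sketch of the pad-to-8 and nibble-packing arithmetic used in pack_4bit_weight_tensor.
import torch
import torch.nn.functional as F

N, K = 4, 10                                     # K is not a multiple of 8
w = torch.randint(-8, 8, (N, K), dtype=torch.int8)

num_pad = (8 - K % 8) % 8                        # pad innermost dim to a multiple of 8
w_padded = F.pad(w, (0, num_pad))                # shape (4, 16), zero padded
w_offset = w_padded.to(torch.uint8) + 8          # shift into [0, 15]
packed = (w_offset[:, 1::2] << 4) | w_offset[:, ::2]
print(packed.shape)                              # torch.Size([4, 8]): two 4-bit values per byte
```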

backends/vulkan/runtime/graph/ops/glsl/linear_common.glslh

Lines changed: 15 additions & 0 deletions
```diff
@@ -29,4 +29,19 @@ int extract_8bit_from_packed_int_le(const int packed, const int i) {
   return byte;
 }
 
+// Extract a 4-bit value from a packed int (little endian)
+// It is assumed that the 4-bit value is in the range [0, 15]
+int extract_4bit_from_packed_int_le(const int packed, const int col) {
+  // Extract the 4-bit value from the 8-bit value
+  int val = packed >> (4 * col) & 0xF;
+  return val;
+}
+
+// Convenience overload for packed uint
+int extract_4bit_from_packed_uint_le(const uint packed, const int col) {
+  // Extract the 4-bit value from the 8-bit value
+  int val = int(packed >> (4 * col)) & 0xF;
+  return val;
+}
+
 #endif // LINEAR_COMMON_GLSLH
```
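For reference, a hedged Python analogue (not from the commit) of `extract_4bit_from_packed_int_le`: nibble `col` is read from a packed 32-bit word in little-endian nibble order, yielding a value in [0, 15].

```python
# Python analogue of extract_4bit_from_packed_int_le; results land in [0, 15].
def extract_4bit_le(packed: int, col: int) -> int:
    return (packed >> (4 * col)) & 0xF

packed = 0x76543210  # nibble i stores the value i
assert [extract_4bit_le(packed, c) for c in range(8)] == [0, 1, 2, 3, 4, 5, 6, 7]
```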

backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile.glslh

Lines changed: 15 additions & 0 deletions
```diff
@@ -39,6 +39,21 @@ void initialize(out FPOutTile out_tile) {
 #endif
 }
 
+void add(inout FPOutTile out_tile, const FPOutTile other_out_tile) {
+#if TILE_M > 1 && TILE_N4 == 1
+  [[unroll]] for (int m = 0; m < TILE_M; ++m) {
+    out_tile.data[m][0] += other_out_tile.data[m][0];
+  }
+
+#else
+  [[unroll]] for (int m = 0; m < TILE_M; ++m) {
+    [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) {
+      out_tile.data[m][n4] += other_out_tile.data[m][n4];
+    }
+  }
+#endif
+}
+
 #ifdef DEBUG_MODE
 
 void printFPOutTile(const FPOutTile tile) {
```

backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_fp_compute.glslh

Lines changed: 10 additions & 0 deletions
```diff
@@ -123,4 +123,14 @@ void apply_scales_and_biases(
 #endif
 }
 
+void accumulate_out_tile_with_out_tile(
+    inout FPOutTile accum,
+    const FPOutTile other) {
+  [[unroll]] for (int m = 0; m < TILE_M; ++m) {
+    [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) {
+      accum.data[m][n4] = accum.data[m][n4] + other.data[m][n4];
+    }
+  }
+}
+
 #endif // LINEAR_FP_OUTPUT_TILE_FP_COMPUTE_GLSLH
```
