diff --git a/backends/vulkan/custom_ops_lib.py b/backends/vulkan/custom_ops_lib.py index 3ef3a6b45ea..4312971f5f1 100644 --- a/backends/vulkan/custom_ops_lib.py +++ b/backends/vulkan/custom_ops_lib.py @@ -187,9 +187,9 @@ def linear_weight_int4_impl( lib.impl(name, linear_weight_int4_impl, "CompositeExplicitAutograd") linear_weight_int4_op = getattr(getattr(torch.ops, namespace), name) -################# +################## ## linear_qcs4w ## -################# +################## def linear_qcs4w( @@ -234,6 +234,79 @@ def linear_qcs4w( lib.impl(name, linear_qcs4w, "CompositeExplicitAutograd") linear_qc4w_op = getattr(getattr(torch.ops, namespace), name) +################## +## linear_q4gsw ## +################## + + +def unpack_4bit_weight_tensor( + packed_weight_tensor: torch.Tensor, x: torch.Tensor +) -> torch.Tensor: + """ + Reverses the packing performed in quantized_linear.pack_4bit_weight_tensor + """ + # Each packed byte contains two 4-bit values: high nibble and low nibble + K, N_half = packed_weight_tensor.shape + N = N_half * 2 + + # Unpack high and low nibbles + high_nibble = (packed_weight_tensor >> 4) & 0x0F + low_nibble = packed_weight_tensor & 0x0F + + # Stack to shape (K, N) + unpacked = torch.empty( + (K, N), dtype=torch.uint8, device=packed_weight_tensor.device + ) + unpacked[:, ::2] = low_nibble + unpacked[:, 1::2] = high_nibble + + # Undo the +8 offset and convert to signed 4-bit range [-8, 7] + unpacked = unpacked.to(torch.int8) - 8 + + in_channels = x.shape[-1] + # Undo any padding that may have been added to input channels + if in_channels != unpacked.shape[-1]: + return unpacked[:, :in_channels] + + return unpacked + + +def linear_q4gsw( + x: torch.Tensor, + weights: torch.Tensor, + weight_scales: torch.Tensor, + group_size: int, + bias: Optional[torch.Tensor] = None, +): + # Unpack the packed weights + weights = unpack_4bit_weight_tensor(weights, x) + + # Un-transpose the weight scales + weight_scales = weight_scales.transpose(0, 1) + weight_zeros = torch.zeros_like(weight_scales, dtype=torch.int32) + + weights = torch.ops.torchao.dequantize_affine( + weights, [1, group_size], weight_scales, weight_zeros, torch.int8, -8, 7 + ) + + out = torch.nn.functional.linear(x, weights) + return out + + +name = "linear_q4gsw" +lib.define( + f""" + {name}( + Tensor self, + Tensor weights, + Tensor weight_scales, + int group_size, + Tensor? 
bias = None) -> Tensor + """ +) +lib.impl(name, linear_q4gsw, "CompositeExplicitAutograd") +linear_qc4w_op = getattr(getattr(torch.ops, namespace), name) + ######################## ## linear_qta8a_qga4w ## ######################## diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py index 8fbb41ed046..1b74ef1ac65 100644 --- a/backends/vulkan/op_registry.py +++ b/backends/vulkan/op_registry.py @@ -334,9 +334,10 @@ def register_int8_mm_op(): @update_features( [ exir_ops.edge.et_vk.linear_q8ta_q8csw.default, + exir_ops.edge.et_vk.linear_q4gsw.default, ] ) -def register_qa_qw_linear(): +def register_quantized_linear_ops(): return OpFeatures( inputs_storage=utils.CONTIGUOUS_ANY, supports_prepacking=True, diff --git a/backends/vulkan/patterns/quantized_linear.py b/backends/vulkan/patterns/quantized_linear.py index 514abd78bf4..ee1c7ee2d2a 100644 --- a/backends/vulkan/patterns/quantized_linear.py +++ b/backends/vulkan/patterns/quantized_linear.py @@ -191,65 +191,43 @@ def find_quantized_linear_patterns( ## -def pack_4bit_weight_tensor(inp: torch.Tensor) -> torch.Tensor: +def pack_4bit_weight_tensor(weight_tensor: torch.Tensor) -> torch.Tensor: """ Given a 8-bit weight tensor containing values quantized to 4 bits, create a packed - weight tensor by packing 2 4-bit values in one unsigned 8-bit value. + weight tensor by transposing the weight tensor, then packing 2 4-bit values in one + 8-bit value. - An input weight tensor of shape (M, K) will produce a packed weight tensor of shape - (M, K / 2). - - The packing implemented here is the same as the packing produced by - backends/vulkan/_passes/int4_weight_only_quantizer.py + An input weight tensor of shape (N, K) will produce a packed weight tensor of shape + (K, N / 2). """ # Assert we got a properly quantized tensor. - min, max = inp.min().item(), inp.max().item() + min_val, max_val = weight_tensor.min().item(), weight_tensor.max().item() assert ( - max <= 7 and min >= -8 - ), f"pack_4bit_weight_tensor: [min,max] out of [-8, 7] range, got [{min}, {max}]" + max_val <= 7 and min_val >= -8 + ), f"pack_4bit_weight_tensor: [min_val,max_val] out of [-8, 7] range, got [{min_val}, {max_val}]" # Assuming we have a 2d tensor - if inp.ndim != 2: - inp = inp.squeeze() + if weight_tensor.ndim != 2: + weight_tensor = weight_tensor.squeeze() assert ( - inp.ndim == 2 - ), f"pack_4bit_weight_tensor: expecting input tensor to be 2d, got {inp.ndim}" + weight_tensor.ndim == 2 + ), f"pack_4bit_weight_tensor: expecting input tensor to be 2d, got {weight_tensor.ndim}" - # pad ic - if inp.shape[-1] % 2 != 0: - inp = F.pad(input=inp, pad=(0, 1, 0, 0), mode="constant", value=0) + # Need to pad innermost dim to be a multiple of 8, since the minimum load granularity + # is int32 (4 bytes), which contains 8 4-bit values. 
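+    # For example, a weight tensor with 30 input channels is padded to 32 columns, so
+    # that after packing each row occupies exactly 16 bytes (four aligned int32 loads).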
+ if weight_tensor.shape[-1] % 8 != 0: + num_pad = 8 - (weight_tensor.shape[-1] % 8) + weight_tensor = F.pad(input=weight_tensor, pad=(0, num_pad)) # Shape after padding - oc, ic = inp.shape - assert ic % 2 == 0, "convert_to_qc4w: expecting ic to be even" + _, in_channels = weight_tensor.shape + assert in_channels % 8 == 0, "convert_to_qc4w: expecting ic to be divisible by 8" - # Adjust inp tensor for zp - inp = inp.to(dtype=torch.uint8) + 8 + # Adjust weight_tensor tensor for zp + weight_tensor = weight_tensor.to(dtype=torch.uint8) + 8 # Pack each 4-bit value into a single 8-bit value - return inp[::, ::2] << 4 | inp[::, 1::2] - - -def make_combined_scales_and_zeros_tensor( - scales: torch.Tensor, zeros: torch.Tensor -) -> torch.Tensor: - """ - Given a scales and zeros tensor, create a combined tensor by stacking them into a - single tensor. - - The scales and zeros tensors are expected to be 2D tensors of shape - (OUTPUT_CHANNELS, NUM_GROUPS). The combined tensor will have the shape - (NUM_GROUPS, OUTPUT_CHANNELS, 2). - - This is the scales and zeros format produced by - backends/vulkan/_passes/int4_weight_only_quantizer.py, which in turn is the scales - and zeros format expected by the _weight_int4pack_mm op in ATen. - """ - scales_reshaped = scales.transpose(0, 1).unsqueeze(2) - zeros_reshaped = zeros.transpose(0, 1).unsqueeze(2) - - zeros_scaled = zeros_reshaped * scales_reshaped * -1 - return torch.cat((scales_reshaped, zeros_scaled), dim=2) + return weight_tensor[::, 1::2] << 4 | weight_tensor[::, ::2] ## @@ -257,50 +235,50 @@ def make_combined_scales_and_zeros_tensor( ## -def make_linear_q4ga_op( +def make_linear_q4gsw_op( ep: ExportedProgram, graph_module: torch.fx.GraphModule, match: QuantizedLinearMatch, weight_tensor: torch.Tensor, weight_scales_tensor: torch.Tensor, - weight_zeros_tensor: torch.Tensor, ): - packed_quantized_weight_tensor = pack_4bit_weight_tensor(weight_tensor) - utils.update_program_state_dict( - ep, match.weight_node.name, packed_quantized_weight_tensor - ) - # Need to make sure corresponding FakeTensor has same size - match.weight_node.meta["val"] = match.weight_node.meta["val"][:, ::2].to( - torch.uint8 - ) - - group_size = weight_tensor.shape[1] // weight_scales_tensor.shape[1] - - combined_scales_zeros_tensor = make_combined_scales_and_zeros_tensor( - weight_scales_tensor, weight_zeros_tensor + num_groups = weight_scales_tensor.shape[-1] + in_channels = weight_tensor.shape[-1] + group_size = in_channels // num_groups + + weight_tensor = pack_4bit_weight_tensor(weight_tensor) + # Use this function for convenience to update the state dict with the packed + # weight tensor. Alignment will already have been done in the above function. + weight_tensor = utils.align_width_and_update_state_dict( + ep, match.weight_node, weight_tensor, align_to=1, force_update=True ) - combined_scales_zeros_name = f"{match.weight_node.name}_scales_zeros" - graph_module.register_parameter( - combined_scales_zeros_name, torch.nn.Parameter(combined_scales_zeros_tensor) + # Also transpose the weight scales tensor to shape [num_groups, N] + weight_scales_tensor = weight_scales_tensor.transpose(0, 1).contiguous() + # Align to multiple of 8 to ensure that data loads from the weight scales + # tensor do not go out of bounds. Each thread computes 8 output channels. 
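+    # In other words, the transposed [num_groups, N] scales tensor is padded along N
+    # up to the next multiple of 8 so that loads for the last 8-wide output tile stay
+    # in bounds.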
+ utils.align_width_and_update_state_dict( + ep, + match.weight_scales_node, + weight_scales_tensor, + align_to=8, + force_update=True, ) with graph_module.graph.inserting_before(match.output_node): - combined_scales_zeros = graph_module.graph.get_attr(combined_scales_zeros_name) - linear_q4ga_node = graph_module.graph.create_node( + linear_q4gsw_node = graph_module.graph.create_node( "call_function", - exir_ops.edge.et_vk.linear_weight_int4.default, + exir_ops.edge.et_vk.linear_q4gsw.default, args=( match.fp_input_node, match.weight_node, + match.weight_scales_node, group_size, - combined_scales_zeros, - 1, ), ) - linear_q4ga_node.meta["val"] = match.output_node.meta["val"] - match.output_node.replace_all_uses_with(linear_q4ga_node) + linear_q4gsw_node.meta["val"] = match.output_node.meta["val"] + match.output_node.replace_all_uses_with(linear_q4gsw_node) def make_linear_q8ta_q8csw_custom_op( @@ -373,13 +351,8 @@ def replace_quantized_linear_patterns( and match.is_weight_pergroup_quantized() and utils.is_in_4bit_range(weight_tensor) ): - make_linear_q4ga_op( - ep, - graph_module, - match, - weight_tensor, - weight_scales_tensor, - weight_zeros_tensor, + make_linear_q4gsw_op( + ep, graph_module, match, weight_tensor, weight_scales_tensor ) elif ( match.is_input_static_per_tensor_quantized() diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_common.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_common.glslh index 90ede450ae7..da326b26e93 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/linear_common.glslh +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_common.glslh @@ -29,4 +29,19 @@ int extract_8bit_from_packed_int_le(const int packed, const int i) { return byte; } +// Extract a 4-bit value from a packed int (little endian) +// It is assumed that the 4-bit value is in the range [0, 15] +int extract_4bit_from_packed_int_le(const int packed, const int col) { + // Extract the 4-bit value from the 8-bit value + int val = packed >> (4 * col) & 0xF; + return val; +} + +// Convenience overload for packed uint +int extract_4bit_from_packed_uint_le(const uint packed, const int col) { + // Extract the 4-bit value from the 8-bit value + int val = int(packed >> (4 * col)) & 0xF; + return val; +} + #endif // LINEAR_COMMON_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile.glslh index 049f1d34caf..dd571229a9c 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile.glslh +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile.glslh @@ -39,6 +39,21 @@ void initialize(out FPOutTile out_tile) { #endif } +void add(inout FPOutTile out_tile, const FPOutTile other_out_tile) { +#if TILE_M > 1 && TILE_N4 == 1 + [[unroll]] for (int m = 0; m < TILE_M; ++m) { + out_tile.data[m][0] += other_out_tile.data[m][0]; + } + +#else + [[unroll]] for (int m = 0; m < TILE_M; ++m) { + [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) { + out_tile.data[m][n4] += other_out_tile.data[m][n4]; + } + } +#endif +} + #ifdef DEBUG_MODE void printFPOutTile(const FPOutTile tile) { diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_fp_compute.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_fp_compute.glslh index 7229da32cd3..ee50ad87f74 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_fp_compute.glslh +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_fp_compute.glslh @@ -123,4 +123,14 @@ void 
apply_scales_and_biases( #endif } +void accumulate_out_tile_with_out_tile( + inout FPOutTile accum, + const FPOutTile other) { + [[unroll]] for (int m = 0; m < TILE_M; ++m) { + [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) { + accum.data[m][n4] = accum.data[m][n4] + other.data[m][n4]; + } + } +} + #endif // LINEAR_FP_OUTPUT_TILE_FP_COMPUTE_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_fp_int4_compute.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_fp_int4_compute.glslh new file mode 100644 index 00000000000..0606759e393 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_fp_int4_compute.glslh @@ -0,0 +1,92 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +/* + * Defines functions to compute a FPOutTile using fp input and weight tiles. + */ + +#ifndef LINEAR_FP_OUTPUT_TILE_FP_INT4_COMPUTE_GLSLH +#define LINEAR_FP_OUTPUT_TILE_FP_INT4_COMPUTE_GLSLH + +#extension GL_EXT_control_flow_attributes : require + +#include "linear_common.glslh" +#include "linear_fp_input_tile.glslh" +#include "linear_fp_output_tile.glslh" +#include "linear_fp_per_out_channel_params.glslh" +#include "linear_int4_weight_tile.glslh" + +// Unpacks a int containing 4 packed 8-bit integers into a vec4 containing each +// of the 4 unpacked 8-bit integers. +VEC4_T unpack_packed_4xint4(const int int8x4, const int n4_group) { + return VEC4_T( + extract_4bit_from_packed_int_le(int8x4, n4_group + 0), + extract_4bit_from_packed_int_le(int8x4, n4_group + 2), + extract_4bit_from_packed_int_le(int8x4, n4_group + 4), + extract_4bit_from_packed_int_le(int8x4, n4_group + 6)); +} + +T extract_4bit_from_weight_block( + const ivec4 block, + const int col, + const int row) { + return T(((block[row] >> (4 * col)) & 0xF) - 8); +} + +void fp_accumulate_with_int4_weight( + inout FPOutTile accum, + FPInputTile in_tile, + Int4WeightTile w_tile, + FPPerOutChannelParams scales_tile, + FPPerOutChannelParams zeros_tile) { + // Accum tile is indexed as accum[m][n4][n4i] + // -> gives fp accumulator for output tile element at (x = n, y = m) + // Input tile is indexed as in_tile.data[m][k4] + // -> gives vec4 containing the fp inputs at index + // (k, m), (k + 1, m), (k + 2, m), (k + 3, m) + // Weight tile is indexed as w_tile.data[k4][n8][n4i] + // -> gives packed integer containing the 8x 4-bit quantized values at index + // (n, k), (n, k + 1), (n, k + 2), (n, k + 3), + // (n + 4, k), (n + 4, k + 1), (n + 4, k + 2), (n + 4, k + 3) + VEC4_T weight_texels[2]; +#if TILE_K4 == 1 && TILE_N8 == 1 + [[unroll]] for (int k = 0; k < 4; ++k) { + const int base_col_1 = mul_2(k); + const int base_col_2 = base_col_1 + 1; + weight_texels[0] = VEC4_T( + extract_4bit_from_weight_block(w_tile.data[0][0], base_col_1, 0), + extract_4bit_from_weight_block(w_tile.data[0][0], base_col_1, 1), + extract_4bit_from_weight_block(w_tile.data[0][0], base_col_1, 2), + extract_4bit_from_weight_block(w_tile.data[0][0], base_col_1, 3)); + weight_texels[1] = VEC4_T( + extract_4bit_from_weight_block(w_tile.data[0][0], base_col_2, 0), + extract_4bit_from_weight_block(w_tile.data[0][0], base_col_2, 1), + extract_4bit_from_weight_block(w_tile.data[0][0], base_col_2, 2), + extract_4bit_from_weight_block(w_tile.data[0][0], base_col_2, 3)); + + weight_texels[0] = + fma(weight_texels[0], scales_tile.data[0], zeros_tile.data[0]); 
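+    // Dequantize the upper 4 output channels the same way (w = scale * q + zero; the
+    // q4gsw shaders pass zero == 0 since the weight quantization is symmetric).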
+ weight_texels[1] = + fma(weight_texels[1], scales_tile.data[1], zeros_tile.data[1]); + + [[unroll]] for (int m = 0; m < TILE_M; ++m) { + accum.data[m][0] = fma( + VEC4_T(in_tile.data[m][0][k]), weight_texels[0], accum.data[m][0]); + accum.data[m][1] = fma( + VEC4_T(in_tile.data[m][0][k]), weight_texels[1], accum.data[m][1]); + } + } + +#else + // TODO(ssjia): Implement generic case + not implemented + +#endif +} + +#endif // LINEAR_FP_OUTPUT_TILE_FP_INT4_COMPUTE_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_fp_weight_scales_load.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_weight_scales_load.glslh index 0cba49e87c7..1286c1d082f 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/linear_fp_weight_scales_load.glslh +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_weight_scales_load.glslh @@ -15,6 +15,10 @@ VEC4_T load_weight_scale_x4(const int n4) { return t_weight_scales[n4]; } +VEC4_T load_scale_x4(const int n4, const int quant_group_idx, const int N4) { + return t_weight_scales[quant_group_idx * N4 + n4]; +} + void load_weight_scales_tile( out FPPerOutChannelParams scales, const int n4_start) { @@ -29,4 +33,20 @@ void load_weight_scales_tile( #endif } +void load_weight_scales_tile_for_group( + out FPPerOutChannelParams scales, + const int n4_start, + const int quant_group_idx, + const int N4) { +#if TILE_N4 == 1 + scales.data[0] = load_scale_x4(n4_start, quant_group_idx, N4); + +#else + [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) { + scales.data[n4] = load_scale_x4(n4_start + n4, quant_group_idx, N4); + } + +#endif +} + #endif // LINEAR_FP_WEIGHT_SCALES_LOAD_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_int4_weight_block.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_int4_weight_block.glslh new file mode 100644 index 00000000000..d813224c3aa --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_int4_weight_block.glslh @@ -0,0 +1,172 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#ifndef LINEAR_INT4_WEIGHT_BLOCK_GLSLH +#define LINEAR_INT4_WEIGHT_BLOCK_GLSLH + +/* + * This file defines utilties to perform weight prepacking of quantized int4 + * matrix multiplation weights. It also defines utilities to load source + * weight data from inputbuffer, and write out a packed weight block to output + * texture/buffer. + * + * Note: 2 4-bit values are packed into each 8-bit value in the source data. + * + * Requires: + * - t_packed_int4_weight to be defined in shader layout (output texture/buffer) + * - t_int4_weight to be defined in shader layout (input buffer) + * + * Settings: + * - USING_BUFFER to indicate if output resource is a buffer. Otherwise texture + * is assumed. + */ + +#extension GL_EXT_control_flow_attributes : require + +#include "linear_common.glslh" + +// Represents source data for 2 8Kx4N block of the weight matrix read from the +// input buffer. Each int element contains 8 packed 4-bit values along the K +// dimension. Overall the data represents 8Kx8N block. +struct Int4Weight2xBlockSourceData { + uint data[8]; +}; + +// Represents data for a packed 4Kx8N block of the weight matrix to be written +// out to output texture/buffer. An individual block was originally a 4Kx8N +// block in the original weight tensor, and then the top and bottom halves are +// concatenated along the width dim. 
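+// Concretely, for a block starting at output channel n and input position k, data[r]
+// (r = 0..3) packs, from lowest nibble to highest: (n+r, k), (n+r+4, k), (n+r, k+1),
+// (n+r+4, k+1), ..., (n+r, k+3), (n+r+4, k+3). In other words, each int interleaves
+// output channels n+r and n+r+4 across 4 consecutive K positions, and the stored
+// nibbles keep the +8 offset applied when the source weights were packed.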
+struct Int4WeightBlockPacked { + ivec4 data; +}; + +void load_block_source_data_no_checks( + out Int4Weight2xBlockSourceData src_data, + const int k8, + const int n_start, + const int ntexels_K, + const int N) { + [[unroll]] for (int n = 0; n < 8; ++n) { + src_data.data[n] = t_int4_weight[(n_start + n) * ntexels_K + k8]; + } +} + +// To be used if K - k_start < 4 +void load_block_source_data_with_checks( + out Int4Weight2xBlockSourceData src_data, + const int k8, + const int n_start, + const int ntexels_K, + const int N) { + [[unroll]] for (int n = 0; n < 8; ++n) { + if (n_start + n < N) { + src_data.data[n] = t_int4_weight[(n_start + n) * ntexels_K + k8]; + } else { + src_data.data[n] = 0x88888888; + } + } +} + +int pack_8x4bit_signed_into_int( + const int val0, + const int val1, + const int val2, + const int val3, + const int val4, + const int val5, + const int val6, + const int val7) { + return int( + ((val7 & 0xF) << 28) | ((val6 & 0xF) << 24) | ((val5 & 0xF) << 20) | + ((val4 & 0xF) << 16) | ((val3 & 0xF) << 12) | ((val2 & 0xF) << 8) | + ((val1 & 0xF) << 4) | ((val0 & 0xF))); +} + +void create_packed_blocks( + out Int4WeightBlockPacked block1, + out Int4WeightBlockPacked block2, + const Int4Weight2xBlockSourceData src_data) { + [[unroll]] for (int row = 0; row < 4; ++row) { + const int row_idx_1 = row; + const int row_idx_2 = row + 4; + block1.data[row] = pack_8x4bit_signed_into_int( + extract_4bit_from_packed_uint_le(src_data.data[row_idx_1], 0), + extract_4bit_from_packed_uint_le(src_data.data[row_idx_2], 0), + extract_4bit_from_packed_uint_le(src_data.data[row_idx_1], 1), + extract_4bit_from_packed_uint_le(src_data.data[row_idx_2], 1), + extract_4bit_from_packed_uint_le(src_data.data[row_idx_1], 2), + extract_4bit_from_packed_uint_le(src_data.data[row_idx_2], 2), + extract_4bit_from_packed_uint_le(src_data.data[row_idx_1], 3), + extract_4bit_from_packed_uint_le(src_data.data[row_idx_2], 3)); + + block2.data[row] = pack_8x4bit_signed_into_int( + extract_4bit_from_packed_uint_le(src_data.data[row_idx_1], 4), + extract_4bit_from_packed_uint_le(src_data.data[row_idx_2], 4), + extract_4bit_from_packed_uint_le(src_data.data[row_idx_1], 5), + extract_4bit_from_packed_uint_le(src_data.data[row_idx_2], 5), + extract_4bit_from_packed_uint_le(src_data.data[row_idx_1], 6), + extract_4bit_from_packed_uint_le(src_data.data[row_idx_2], 6), + extract_4bit_from_packed_uint_le(src_data.data[row_idx_1], 7), + extract_4bit_from_packed_uint_le(src_data.data[row_idx_2], 7)); + } +} + +#ifdef USING_BUFFER + +void write_packed_block( + const Int4WeightBlockPacked block, + const int k4, + const int n8, + const int nblocks_K) { + t_packed_int4_weight[n8 * nblocks_K + k4] = block.data; +} + +#else // USING_TEXTURE + +void write_packed_block( + const Int4WeightBlockPacked block, + const int k4, + const int n8, + const int nblocks_K) { + imageStore(t_packed_int4_weight, ivec2(k4, n8), block.data); +} + +#endif // USING_BUFFER + +#ifdef DEBUG_MODE + +void printInt4Weight2xBlockSourceData( + const Int4Weight2xBlockSourceData src_data) { + debugPrintfEXT("int4_weight_block_source_data: \\n"); + [[unroll]] for (int row = 0; row < 8; ++row) { + debugPrintfEXT("row %i (raw: %u): ", row, src_data.data[row]); + // Extract and print individual 4-bit values directly from packed int + [[unroll]] for (int col = 0; col < 8; ++col) { + int val_4bit = extract_4bit_from_packed_uint_le(src_data.data[row], col); + debugPrintfEXT("[%i] ", val_4bit); + } + debugPrintfEXT("\\n"); + } +} + +void printInt4WeightBlockPacked(const 
Int4WeightBlockPacked block) { + debugPrintfEXT("int4_weight_block_packed: \\n"); + // Print unpacked 4-bit values for each int in block.data + [[unroll]] for (int i = 0; i < 4; ++i) { + debugPrintfEXT("block.data[%i] 4-bit values: ", i); + [[unroll]] for (int col = 0; col < 8; ++col) { + int val_4bit = extract_4bit_from_packed_int_le(block.data[i], col); + debugPrintfEXT("[%i] ", val_4bit); + } + debugPrintfEXT("\\n"); + } +} + +#endif // DEBUG_MODE + +#endif // LINEAR_INT4_WEIGHT_BLOCK_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_int4_weight_tile.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_int4_weight_tile.glslh new file mode 100644 index 00000000000..559459f14a8 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_int4_weight_tile.glslh @@ -0,0 +1,108 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#ifndef LINEAR_INT4_WEIGHT_TILE_GLSLH +#define LINEAR_INT4_WEIGHT_TILE_GLSLH + +#include "linear_common.glslh" +#include "linear_fp_weight_tile.glslh" + +/* + * Defines the Int4WeightTile struct, which is used to represent a tile of the + * quantized int4 weight matrix of a quantized matrix multiplication operation. + * + * Settings: + * - TILE_K4: number of (groups of 4) rows in the weight tile + * - TILE_N8: number of (groups of 8) columns in the weight tile + */ + +#extension GL_EXT_control_flow_attributes : require + +struct Int4WeightTile { + ivec4 data[TILE_K4][TILE_N8]; +}; + +void unpack_int4_weight_tile( + out FPWeightTile int8_tile, + const Int4WeightTile int4_tile) { +#if TILE_K4 == 1 && TILE_N8 == 1 + for (int k = 0; k < TILE_K; ++k) { + const int col_idx_1 = 2 * k; + const int col_idx_2 = 2 * k + 1; + int8_tile.data[k][0][0] = + T(extract_4bit_from_packed_int_le(int4_tile.data[0][0][0], col_idx_1)); + int8_tile.data[k][0][1] = + T(extract_4bit_from_packed_int_le(int4_tile.data[0][0][1], col_idx_1)); + int8_tile.data[k][0][2] = + T(extract_4bit_from_packed_int_le(int4_tile.data[0][0][2], col_idx_1)); + int8_tile.data[k][0][3] = + T(extract_4bit_from_packed_int_le(int4_tile.data[0][0][3], col_idx_1)); + + // n4 = 1 + int8_tile.data[k][1][0] = + T(extract_4bit_from_packed_int_le(int4_tile.data[0][0][0], col_idx_2)); + int8_tile.data[k][1][1] = + T(extract_4bit_from_packed_int_le(int4_tile.data[0][0][1], col_idx_2)); + int8_tile.data[k][1][2] = + T(extract_4bit_from_packed_int_le(int4_tile.data[0][0][2], col_idx_2)); + int8_tile.data[k][1][3] = + T(extract_4bit_from_packed_int_le(int4_tile.data[0][0][3], col_idx_2)); + } + +#else + for (int k = 0; k < TILE_K; ++k) { + const int k4 = div_4(k); + const int k4i = mod_4(k); + for (int n8 = 0; n8 < TILE_N8; ++n8) { + const int n4 = mul_2(n8); + const int col_idx_1 = 2 * k4i; + const int col_idx_2 = 2 * k4i + 1; + int8_tile.data[k][n4][0] = T(extract_4bit_from_packed_int_le( + int4_tile.data[k4][n8][0], col_idx_1)); + int8_tile.data[k][n4][1] = T(extract_4bit_from_packed_int_le( + int4_tile.data[k4][n8][1], col_idx_1)); + int8_tile.data[k][n4][2] = T(extract_4bit_from_packed_int_le( + int4_tile.data[k4][n8][2], col_idx_1)); + int8_tile.data[k][n4][3] = T(extract_4bit_from_packed_int_le( + int4_tile.data[k4][n8][3], col_idx_1)); + + int8_tile.data[k][n4 + 1][0] = T(extract_4bit_from_packed_int_le( + int4_tile.data[k4][n8][0], col_idx_2)); + int8_tile.data[k][n4 + 1][1] = T(extract_4bit_from_packed_int_le( + 
int4_tile.data[k4][n8][1], col_idx_2)); + int8_tile.data[k][n4 + 1][2] = T(extract_4bit_from_packed_int_le( + int4_tile.data[k4][n8][2], col_idx_2)); + int8_tile.data[k][n4 + 1][3] = T(extract_4bit_from_packed_int_le( + int4_tile.data[k4][n8][3], col_idx_2)); + } + } + +#endif +} + +#ifdef DEBUG_MODE + +void printInt4WeightTile(const Int4WeightTile block) { + debugPrintfEXT("int4_weight_tile: \\n"); + // Print unpacked 4-bit values for each int in block.data + [[unroll]] for (int i = 0; i < TILE_K; ++i) { + const int k4 = div_4(i); + const int k4i = mod_4(i); + debugPrintfEXT("block.data[%i] 4-bit values: ", i); + [[unroll]] for (int col = 0; col < TILE_N; ++col) { + int val_4bit = + extract_4bit_from_packed_int_le(block.data[k4][0][k4i], col); + debugPrintfEXT("[%i] ", val_4bit); + } + debugPrintfEXT("\\n"); + } +} + +#endif // DEBUG_MODE + +#endif // LINEAR_INT4_WEIGHT_TILE_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_int4_weight_tile_load.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_int4_weight_tile_load.glslh new file mode 100644 index 00000000000..033e0082436 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_int4_weight_tile_load.glslh @@ -0,0 +1,78 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#ifndef LINEAR_INT4_WEIGHT_TILE_LOAD_GLSLH +#define LINEAR_INT4_WEIGHT_TILE_LOAD_GLSLH + +/* + * Defines functions to load a Int4WeightTile from input buffer/texture. + * + * Requires: + * - t_packed_int4_weight to be declared in the shader layout (input + * buffer/texture) + * + * Settings: + * - WEIGHT_BUFFER to indicate t_packed_int4_weight is a buffer, otherwise + * texture storage is assumed. 
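+ * - TILE_K4 / TILE_N8 (consumed via linear_int4_weight_tile.glslh) to control how
+ *   many packed 4(K) x 8(N) blocks make up the loaded tile.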
+ */ + +#extension GL_EXT_control_flow_attributes : require + +#include "linear_int4_weight_tile.glslh" + +#ifdef WEIGHT_BUFFER + +ivec4 load_int4_weight_block( + const int block_x, + const int block_y, + const int nblocks_x) { + return t_packed_int4_weight[(block_y * nblocks_x) + block_x]; +} + +#else // WEIGHT_TEXTURE + +ivec4 load_int4_weight_block( + const int block_x, + const int block_y, + const int nblocks_x) { + return texelFetch(t_packed_int4_weight, ivec2(block_x, block_y), 0); +} + +#endif // WEIGHT_BUFFER + +void load_int4_weight_tile( + out Int4WeightTile weight_tile, + const int block_x, + const int block_y, + const int nblocks_x) { +#if TILE_K4 == 1 && TILE_N8 == 1 + weight_tile.data[0][0] = load_int4_weight_block(block_x, block_y, nblocks_x); + +#elif TILE_K4 == 1 && TILE_N8 > 1 + [[unroll]] for (int x = 0; x < TILE_N8; ++x) { + weight_tile.data[0][x] = + load_int4_weight_block(block_x + x, block_y, nblocks_x); + } + +#elif TILE_K4 > 1 && TILE_N8 == 1 + [[unroll]] for (int y = 0; y < TILE_K4; ++y) { + weight_tile.data[y][0] = + load_int4_weight_block(block_x, block_y + y, nblocks_x); + } + +#else + [[unroll]] for (int y = 0; y < TILE_K4; ++y) { + [[unroll]] for (int x = 0; x < TILE_N8; ++x) { + weight_tile.data[y][x] = + load_int4_weight_block(block_x + x, block_y + y, nblocks_x); + } + } +#endif +} + +#endif // LINEAR_INT4_WEIGHT_TILE_LOAD_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_q4gsw_coop.glsl b/backends/vulkan/runtime/graph/ops/glsl/linear_q4gsw_coop.glsl new file mode 100644 index 00000000000..6f0d890a9c4 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_q4gsw_coop.glsl @@ -0,0 +1,134 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#version 450 core + +#define PRECISION ${PRECISION} +#define VEC4_T ${texel_load_type(DTYPE, IO_STORAGE)} +#define T ${texel_load_component_type(DTYPE, IO_STORAGE)} + +$if IO_STORAGE == "buffer": + #define OUTPUT_BUFFER + #define INPUT_BUFFER +$if WEIGHT_STORAGE == "buffer": + #define WEIGHT_BUFFER + +#define TILE_N8 ${TILE_N8} + +#define TILE_K4 ${TILE_K4} +#define TILE_N4 ${TILE_N8 * 2} + +#define TILE_M ${TILE_M} +#define TILE_K ${TILE_K4 * 4} +#define TILE_N ${TILE_N8 * 8} + +#define WGS ${WGS} + +${define_required_extensions(DTYPE)} + +layout(std430) buffer; + +${layout_declare_tensor(B, "w", "t_output", DTYPE, IO_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_input", DTYPE, IO_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_packed_int4_weight", "int", WEIGHT_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_weight_scales", DTYPE, "buffer", is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_bias", DTYPE, "buffer", is_scalar_array=False)} + +${layout_declare_ubo(B, "ivec4", "output_sizes")} +${layout_declare_ubo(B, "ivec4", "input_sizes")} + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +${layout_declare_spec_const(C, "int", "apply_bias", "0")} +${layout_declare_spec_const(C, "int", "K4_per_group", "0")} + +#include "common.glslh" +#include "linear_fp_input_tile_load.glslh" +#include "linear_int4_weight_tile_load.glslh" +#include "linear_fp_weight_scales_load.glslh" +#include "linear_fp_output_tile_fp_int4_compute.glslh" +#include "linear_fp_output_tile_fp_compute.glslh" +#include "linear_fp_output_tile_store.glslh" +#include "linear_fp_bias_load.glslh" + +shared FPOutTile partial_sums[WGS]; + +void main() { + const int lid = int(gl_LocalInvocationID.x); + const int n8 = int(gl_GlobalInvocationID.y); + + // The output tensor will have a shape of [n, 1, 1, 1]. Each thread computes + // 8 output elements, so each thread will write to 8 elements starting at the + // tensor index (gid.x * 8, 0, 0, 0). + const int n = mul_8(n8); + const int n4 = mul_2(n8); + const int K4 = div_up_4(input_sizes.x); + const int N4 = div_up_4(output_sizes.x); + + const int group_size = mul_4(K4_per_group); + + if (n >= output_sizes.x) { + return; + } + + FPOutTile out_tile; + initialize(out_tile); + + FPInputTile in_tile; + Int4WeightTile int4_weight_tile; + + FPPerOutChannelParams weight_scales_tile; + FPPerOutChannelParams weight_zeros_tile; + weight_zeros_tile.data[0] = VEC4_T(0.0); + weight_zeros_tile.data[1] = VEC4_T(0.0); + + // initialize the group index to a value larger than the largest possible + int cur_group_idx = input_sizes.x; + + for (int k4 = lid; k4 < div_up_4(input_sizes.x); k4 += WGS) { + const int group_idx = k4 / K4_per_group; + + // Only update the scales/zeros if the current iteration is now working on a + // new quantization group. + if (group_idx != cur_group_idx) { + load_weight_scales_tile_for_group(weight_scales_tile, n4, group_idx, N4); + cur_group_idx = group_idx; + } + + load_input_tile_no_checks(in_tile, k4, 0, K4, 1); + load_int4_weight_tile(int4_weight_tile, k4, n8, K4); + + fp_accumulate_with_int4_weight( + out_tile, + in_tile, + int4_weight_tile, + weight_scales_tile, + weight_zeros_tile); + } + + partial_sums[lid] = out_tile; + + memoryBarrierShared(); + barrier(); + + // Tree reduction to compute the overall result. 
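+  // At this point each of the WGS threads holds a partial output tile covering a
+  // strided subset of the K dimension. Each iteration folds the upper half of
+  // partial_sums into the lower half (WGS/2, WGS/4, ..., 1) until partial_sums[0]
+  // holds the full reduction, which thread 0 writes out below.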
+ for (int i = WGS / 2; i > 0; i /= 2) { + if (lid < i) { + accumulate_out_tile_with_out_tile( + partial_sums[lid], partial_sums[lid + i]); + } + memoryBarrierShared(); + barrier(); + } + + // Only the first thread will write out result + if (lid == 0) { + out_tile = partial_sums[0]; + write_output_tile_with_checks(out_tile, n4, 0, N4, 1); + } +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_q4gsw_coop.yaml b/backends/vulkan/runtime/graph/ops/glsl/linear_q4gsw_coop.yaml new file mode 100644 index 00000000000..bb5f44d4086 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_q4gsw_coop.yaml @@ -0,0 +1,28 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +linear_q4gsw_coop: + parameter_names_with_default_values: + DTYPE: float + IO_STORAGE: texture3d + WEIGHT_STORAGE: texture2d + TILE_M: 1 + TILE_K4: 1 + TILE_N8: 1 + WGS: 64 + generate_variant_forall: + DTYPE: + - VALUE: float + - VALUE: half + shader_variants: + - NAME: linear_q4gsw_coop_texture3d_texture2d + - NAME: linear_q4gsw_coop_texture3d_buffer + WEIGHT_STORAGE: buffer + - NAME: linear_q4gsw_coop_buffer_texture2d + IO_STORAGE: buffer + - NAME: linear_q4gsw_coop_buffer_buffer + IO_STORAGE: buffer + WEIGHT_STORAGE: buffer diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_q4gsw_tiled.glsl b/backends/vulkan/runtime/graph/ops/glsl/linear_q4gsw_tiled.glsl new file mode 100644 index 00000000000..0ad91643219 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_q4gsw_tiled.glsl @@ -0,0 +1,116 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#version 450 core + +#define PRECISION ${PRECISION} +#define VEC4_T ${texel_load_type(DTYPE, IO_STORAGE)} +#define T ${texel_load_component_type(DTYPE, IO_STORAGE)} + +$if IO_STORAGE == "buffer": + #define OUTPUT_BUFFER + #define INPUT_BUFFER +$if WEIGHT_STORAGE == "buffer": + #define WEIGHT_BUFFER + +#define TILE_N8 ${TILE_N8} + +#define TILE_M4 ${TILE_M4} +#define TILE_K4 ${TILE_K4} +#define TILE_N4 ${TILE_N8 * 2} + +#define TILE_M ${TILE_M4 * 4} +#define TILE_K ${TILE_K4 * 4} +#define TILE_N ${TILE_N8 * 8} + +${define_required_extensions(DTYPE)} + +layout(std430) buffer; + +#include "common.glslh" + +${layout_declare_tensor(B, "w", "t_output", DTYPE, IO_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_input", DTYPE, IO_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_packed_int4_weight", "int", WEIGHT_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_weight_scales", DTYPE, "buffer", is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_bias", DTYPE, "buffer", is_scalar_array=False)} + +${layout_declare_ubo(B, "ivec4", "output_sizes")} +${layout_declare_ubo(B, "ivec4", "input_sizes")} + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +${layout_declare_spec_const(C, "int", "apply_bias", "0")} +${layout_declare_spec_const(C, "int", "K4_per_group", "0")} + +#include "linear_fp_input_tile_load.glslh" +#include "linear_int4_weight_tile_load.glslh" +#include "linear_fp_weight_scales_load.glslh" +#include "linear_fp_bias_load.glslh" +#include "linear_fp_output_tile_fp_int4_compute.glslh" +#include "linear_fp_output_tile_fp_compute.glslh" +#include "linear_fp_output_tile_store.glslh" + +void main() { + const int out_tile_x = int(gl_GlobalInvocationID.x); + const int out_tile_y = int(gl_GlobalInvocationID.y); + + const int n = out_tile_x * TILE_N; + const int m = out_tile_y * TILE_M; + + const int n8 = div_8(n); + const int n4 = div_4(n); + const int m4 = div_4(m); + + if (n >= output_sizes.x || m >= output_sizes.y) { + return; + } + + const int M = input_sizes.y; + const int K4 = div_up_4(input_sizes.x); + const int N4 = div_up_4(output_sizes.x); // number of texels in each row + const int N8 = div_up_8(output_sizes.x); // number of texels in each row + + bool should_print = (n8 == 0) && (m4 == 0); + should_print = false; + + // VEC4_T out_texels[4][2]; + FPOutTile out_tile; + initialize(out_tile); + + FPInputTile in_tile; + Int4WeightTile int4_weight_tile; + + FPPerOutChannelParams weight_scales_tile; + FPPerOutChannelParams weight_zeros_tile; + weight_zeros_tile.data[0] = VEC4_T(0.0); + weight_zeros_tile.data[1] = VEC4_T(0.0); + + const int num_groups = K4 / K4_per_group; + + for (int group_i = 0; group_i < num_groups; ++group_i) { + // Load quantization scales and zeros for the current group + load_weight_scales_tile_for_group(weight_scales_tile, n4, group_i, N4); + + for (int k4_inner = 0; k4_inner < K4_per_group; k4_inner++) { + const int k4 = group_i * K4_per_group + k4_inner; + + load_input_tile_no_checks(in_tile, k4, m, K4, M); + load_int4_weight_tile(int4_weight_tile, k4, n8, K4); + + fp_accumulate_with_int4_weight( + out_tile, + in_tile, + int4_weight_tile, + weight_scales_tile, + weight_zeros_tile); + } + } + + write_output_tile_with_checks(out_tile, n4, m, N4, M); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_q4gsw_tiled.yaml b/backends/vulkan/runtime/graph/ops/glsl/linear_q4gsw_tiled.yaml new file mode 100644 index 00000000000..5a6bcb711bb --- /dev/null +++ 
b/backends/vulkan/runtime/graph/ops/glsl/linear_q4gsw_tiled.yaml @@ -0,0 +1,28 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +linear_q4gsw_tiled: + parameter_names_with_default_values: + DTYPE: float + IO_STORAGE: texture3d + WEIGHT_STORAGE: texture2d + TILE_M4: 1 + TILE_K4: 1 + TILE_N8: 1 + generate_variant_forall: + DTYPE: + - VALUE: float + - VALUE: half + shader_variants: + - NAME: linear_q4gsw_tiled_texture3d_texture2d + - NAME: linear_q4gsw_tiled_texture3d_buffer + WEIGHT_STORAGE: buffer + - NAME: linear_q4gsw_tiled_buffer_texture2d + IO_STORAGE: buffer + WEIGHT_STORAGE: texture2d + - NAME: linear_q4gsw_tiled_buffer_buffer + IO_STORAGE: buffer + WEIGHT_STORAGE: buffer diff --git a/backends/vulkan/runtime/graph/ops/glsl/pack_q4_linear_weight.glsl b/backends/vulkan/runtime/graph/ops/glsl/pack_q4_linear_weight.glsl new file mode 100644 index 00000000000..b9f5c994910 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/pack_q4_linear_weight.glsl @@ -0,0 +1,68 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} + +${define_active_storage_type(STORAGE)} + +layout(std430) buffer; + +${layout_declare_tensor(B, "w", "t_packed_int4_weight", "int", STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_int4_weight", "uint", "buffer")} + +layout(push_constant) uniform restrict Block { + ivec4 qmat2_sizes; + ivec2 orig_sizes; +}; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +#include "common.glslh" +#include "linear_int4_weight_block.glslh" + +void main() { + const int k8 = int(gl_GlobalInvocationID.x); + const int n8 = int(gl_GlobalInvocationID.y); + + const int K = orig_sizes.x; + const int N = orig_sizes.y; + + // Each shader invocation processes a 4x8 block of the input data. + const int K4 = div_up_4(K); + const int K8 = div_up_8(K); + const int N8 = div_up_8(N); + + // Check bounds + if (n8 >= N8 || k8 >= K8) { + return; + } + + Int4Weight2xBlockSourceData src_data; + const int n = mul_8(n8); + if (N - n >= 8) { + load_block_source_data_no_checks(src_data, k8, n, K8, N); + } else { + load_block_source_data_with_checks(src_data, k8, n, K8, N); + } + + // A 8Kx8K block of the weight matrix is loaded into memory. This will be + // split into two blocks each holding 4Kx8N worth of data. + // The first block contains data for k + (0, 1, 2, 3) i.e. the first 4 columns + // of the loaded weight block. + Int4WeightBlockPacked packed_block_1; + // The second block contains data for k + (4, 5, 6, 7) i.e. the second 4 cols + // of the loaded weight block + Int4WeightBlockPacked packed_block_2; + create_packed_blocks(packed_block_1, packed_block_2, src_data); + + const int k4 = mul_2(k8); + write_packed_block(packed_block_1, k4, n8, K4); + write_packed_block(packed_block_2, k4 + 1, n8, K4); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/pack_q4_linear_weight.yaml b/backends/vulkan/runtime/graph/ops/glsl/pack_q4_linear_weight.yaml new file mode 100644 index 00000000000..7a145ec95d7 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/pack_q4_linear_weight.yaml @@ -0,0 +1,14 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
+# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +pack_q4_linear_weight: + parameter_names_with_default_values: + STORAGE: buffer + shader_variants: + - NAME: pack_q4_linear_weight_buffer + STORAGE: buffer + - NAME: pack_q4_linear_weight_texture2d + STORAGE: texture2d diff --git a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp index d6aeb5e3dce..4831c6f2f85 100644 --- a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp @@ -19,6 +19,40 @@ namespace vkcompute { // Shader dispatch utilities // +bool is_gemv(ComputeGraph* graph, const ValueRef& fp_input) { + return graph->size_at(-2, fp_input) == 1; +} + +void resize_linear_qw_node( + ComputeGraph* graph, + const std::vector& args, + const std::vector& extra_args) { + (void)extra_args; + + ValueRef output = args.at(0).refs.at(0); + ValueRef fp_input = args.at(1).refs.at(0); + ValueRef weight_data = extra_args.at(1); + + std::vector mat1_sizes = graph->sizes_of(fp_input); + std::vector mat2_sizes = graph->sizes_of(weight_data); + + const int64_t out_cols = utils::val_at(-2, mat1_sizes); + const int64_t out_rows = utils::val_at(-2, mat2_sizes); + + std::vector new_out_sizes(3); + if (mat1_sizes.size() == 2) { + new_out_sizes.resize(2); + new_out_sizes.at(0) = out_cols; + new_out_sizes.at(1) = out_rows; + } else { + new_out_sizes.at(0) = mat1_sizes.at(0); + new_out_sizes.at(1) = out_cols; + new_out_sizes.at(2) = out_rows; + } + + graph->virtual_resize(output, new_out_sizes); +} + utils::uvec3 quantized_linear_global_wg_size( ComputeGraph* graph, const vkapi::ShaderInfo& shader, @@ -32,10 +66,23 @@ utils::uvec3 quantized_linear_global_wg_size( // width const uint32_t N = utils::val_at(-1, out_sizes); - // 1 output tile is 4x4 elements const uint32_t M4 = utils::div_up(M, 4u); const uint32_t N4 = utils::div_up(N, 4u); + // For 4-bit weights, each output tile contains 8 columns and 4 rows + if (shader.kernel_name.find("q4") != std::string::npos) { + const uint32_t N8 = utils::div_up(N, 8u); + + const bool using_coop_algorithm = + shader.kernel_name.find("_coop") != std::string::npos; + // TODO: explain + if (using_coop_algorithm) { + return {64, N8, M}; + } + return {N8, M4, 1}; + } + + // Otherwise, each output tile contains 4 columns and 4 rows return {N4, M4, 1}; } @@ -45,8 +92,15 @@ utils::uvec3 quantized_linear_local_wg_size( const utils::uvec3& global_workgroup_size, const std::vector& args, const std::vector& resize_args) { - return pick_hw_square_wg_size( - graph, shader, global_workgroup_size, args, resize_args); + const bool use_coop_algorithm = + shader.kernel_name.find("_coop") != std::string::npos; + + if (use_coop_algorithm) { + return {64, 1, 1}; + } else { + return pick_hw_square_wg_size( + graph, shader, global_workgroup_size, args, resize_args); + } } std::tuple get_quantized_input_num_blocks( @@ -80,6 +134,39 @@ utils::uvec3 quant_pack_input_global_wg_size( 1u}; } +vkapi::ShaderInfo pick_linear_qw_shader( + ComputeGraph* graph, + const std::vector& args, + const std::vector& resize_args) { + (void)resize_args; + + const ValueRef output = args.at(0).refs.at(0); + const ValueRef fp_input = args.at(1).refs.at(0); + const ValueRef packed_int_weight = args.at(1).refs.at(1); + + const bool weight_is_4bit = resize_args.at(0) != kDummyValueRef; + const bool is_gemv_case = 
is_gemv(graph, fp_input); + + std::string kernel_name = "linear_"; + if (weight_is_4bit) { + kernel_name += "q4gsw"; + } else { + kernel_name += "q8csw"; + } + + if (weight_is_4bit && is_gemv_case) { + kernel_name += "_coop"; + } else { + kernel_name += "_tiled"; + } + add_storage_type_suffix(kernel_name, graph->storage_type_of(output)); + add_storage_type_suffix( + kernel_name, graph->storage_type_of(packed_int_weight)); + add_dtype_suffix(kernel_name, graph->dtype_of(output)); + + return VK_KERNEL_FROM_STR(kernel_name); +} + // // Prepacking nodes // @@ -88,35 +175,75 @@ ValueRef prepack_quantized_linear_weight( ComputeGraph& graph, const QuantizationConfig& weight_quant_config, const ValueRef qmat2_data) { - VK_CHECK_COND(weight_quant_config.nbits == 8); + VK_CHECK_COND( + weight_quant_config.nbits == 8 || weight_quant_config.nbits == 4); std::vector qmat2_orig_sizes = graph.sizes_of(qmat2_data); const int64_t ndim = graph.dim_of(qmat2_data); - // Input size is [N, K]. K will be guaranteed to be a multiple of 4. - const int64_t K = qmat2_orig_sizes.at(ndim - 1); - const int64_t N = qmat2_orig_sizes.at(ndim - 2); + int64_t qmat2_width = qmat2_orig_sizes.at(ndim - 1); + int64_t qmat2_height = qmat2_orig_sizes.at(ndim - 2); - // Sanity check that assumption is correct - VK_CHECK_COND(K % 4 == 0); + int64_t K; + int64_t N; + if (weight_quant_config.nbits == 4) { + // For 4-bit quantization, weight source data has shape [N, K/2]. Each byte + // contains 2 * 4-bit values. + K = qmat2_width * 2; + N = qmat2_height; + } else { + // For 8-bit quantization, the weight source data has shape [N, K] + K = qmat2_width; + N = qmat2_height; + } - // The packing format packs the weight tensor into units of 4 wide x 4 high - // blocks. To figure out the size of the output tensor, determine the number - // of blocks along each dimension. - const int64_t num_blocks_K = utils::div_up(K, int64_t(4)); - const int64_t num_blocks_N = utils::div_up(N, int64_t(4)); + // Sanity check that assumptions are correct. Data loads along the innermost + // dimension must be well aligned along texel boundaries. + if (weight_quant_config.nbits == 4) { + VK_CHECK_COND(K % 8 == 0); + } else { + VK_CHECK_COND(K % 4 == 0); + } + + // The packing format packs the weight tensor into blocks of 4 columns (K) and + // 4 rows (N) + int64_t N_per_block = 4; + int64_t K_per_block = 4; + + // For 4 bit, quantization, the amount of information contained in one block + // can be doubled. Each block will contain data for 8 rows (N) instead of the + // usual 4. + if (weight_quant_config.nbits == 4) { + N_per_block = 8; + } + + // To figure out the size of the output tensor, determine the number of blocks + // along each dimension. + const int64_t num_blocks_K = utils::div_up(K, K_per_block); + const int64_t num_blocks_N = utils::div_up(N, N_per_block); // The blocks are arranged in a transposed manner, such that the transposed // weight block is indexed like packed_weights[k4][n4] - this is to allow for // optimal memory coalescing when computing GEMM. - const int64_t output_height = num_blocks_K; + int64_t output_height = num_blocks_K; // The base dtype of the packed tensor is int32 (each int32 contains 4x 8bit // values) and each block is represented as a ivec4. Therefore the width dim // of the packed tensor is multiplied by 4. 
- const int64_t output_width = num_blocks_N * 4; + int64_t output_width = num_blocks_N * 4; + + // For 4 bit quantization, The blocks are arranged without the transposition, + // such that a weight block is accessed like packed_weights[n8][k4]. This is + // an optimization targeted for LLMs, which need to compute GEMV as well as + // GEMM. This memory layout provides better performance for the co-operative + // algorithm used to compute GEMV, at the cost of slightly reducing GEMM + // performance. + if (weight_quant_config.nbits == 4) { + output_height = num_blocks_N; + output_width = num_blocks_K * 4; + } - // Store the original sizes of the tensor to pass to the shader - utils::ivec2 orig_sizes{ + // Store the original sizes of the weight data to pass to the shader + utils::ivec2 orig_sizes = { utils::safe_downcast(K), utils::safe_downcast(N)}; std::vector qmat2_sizes{output_height, output_width}; @@ -130,13 +257,23 @@ ValueRef prepack_quantized_linear_weight( ValueRef qmat2 = graph.add_tensor( qmat2_sizes, vkcompute::vkapi::kInt, storage_type, utils::kWidthPacked); - // Global workgroup size: each thread writes out two adjacent blocks - utils::uvec3 global_wg_size{ - utils::safe_downcast(num_blocks_N), - utils::safe_downcast(num_blocks_K), - 1u}; + utils::uvec3 global_wg_size; + if (weight_quant_config.nbits == 4) { + // For 4-bit quantization, each thread writes out two adjacent blocks + global_wg_size = { + utils::safe_downcast(utils::div_up(num_blocks_K, int64_t(2))), + utils::safe_downcast(num_blocks_N), + 1u}; + } else { + global_wg_size = { + utils::safe_downcast(num_blocks_N), + utils::safe_downcast(num_blocks_K), + 1u}; + } - std::string kernel_name = "pack_q8_linear_weight"; + std::string kernel_name = weight_quant_config.nbits == 4 + ? "pack_q4_linear_weight" + : "pack_q8_linear_weight"; add_storage_type_suffix(kernel_name, storage_type); graph.prepack_nodes().emplace_back(new PrepackNode( @@ -178,15 +315,12 @@ DynamicDispatchNode make_linear_qw_node( const ValueRef packed_bias, const ValueRef output) { // Only certain quantization types supported at the moment - VK_CHECK_COND(weight_quant_config.granularity == kPerChannel); + VK_CHECK_COND( + weight_quant_config.granularity == kPerChannel || + weight_quant_config.granularity == kPerGroup); VK_CHECK_COND(weight_quant_config.is_symmetric); - VK_CHECK_COND(weight_quant_config.nbits == 8); - - std::string kernel_name = "linear_q8csw_tiled"; - add_storage_type_suffix(kernel_name, graph.storage_type_of(output)); - add_storage_type_suffix(kernel_name, graph.storage_type_of(packed_weight)); - add_dtype_suffix(kernel_name, graph.dtype_of(output)); - vkapi::ShaderInfo shader = VK_KERNEL_FROM_STR(kernel_name); + VK_CHECK_COND( + weight_quant_config.nbits == 8 || weight_quant_config.nbits == 4); vkapi::ParamsBindList param_buffers = { graph.sizes_ubo(output), graph.sizes_ubo(fp_input)}; @@ -196,9 +330,18 @@ DynamicDispatchNode make_linear_qw_node( apply_bias = 0; } + int32_t K4_per_group = 0; + if (weight_quant_config.nbits == 4) { + int32_t group_size_val = graph.extract_scalar(group_size); + K4_per_group = utils::div_up(group_size_val, int32_t(4)); + } + + const ValueRef is_4bit_flag = + weight_quant_config.nbits == 4 ? 
group_size : kDummyValueRef; + return DynamicDispatchNode( graph, - VK_KERNEL_FROM_STR(kernel_name), + pick_linear_qw_shader, quantized_linear_global_wg_size, quantized_linear_local_wg_size, // Inputs and Outputs @@ -210,11 +353,11 @@ DynamicDispatchNode make_linear_qw_node( // Push Constants {}, // Specialization Constants - {apply_bias}, + {apply_bias, K4_per_group}, // Resize args - {}, + {is_4bit_flag, weight_data}, // Resizing Logic - nullptr); + resize_linear_qw_node); } DynamicDispatchNode make_quantize_and_pack_linear_input_node( @@ -546,9 +689,40 @@ void linear_q8csw(ComputeGraph& graph, const std::vector& args) { output); } +void linear_q4gsw(ComputeGraph& graph, const std::vector& args) { + int32_t idx = 0; + const ValueRef fp_input = args.at(idx++); + const ValueRef weight_data = args.at(idx++); + const ValueRef weight_scales_data = args.at(idx++); + const ValueRef group_size = args.at(idx++); + const ValueRef bias_data = args.at(idx++); + const ValueRef output = args.at(idx++); + + const int64_t group_size_val = graph.extract_scalar(group_size); + + QuantizationConfig input_quant_config(32, kNoQuantization, {}); + QuantizationConfig weight_quant_config(4, kPerGroup, {group_size_val}); + + quantized_linear_impl( + graph, + input_quant_config, + weight_quant_config, + fp_input, + kDummyValueRef, // input scale + kDummyValueRef, // input zp + weight_data, + kDummyValueRef, // weight sums + weight_scales_data, + kDummyValueRef, // weight zeros + group_size, // group size + bias_data, + output); +} + REGISTER_OPERATORS { VK_REGISTER_OP(et_vk.linear_q8ta_q8csw.default, linear_q8ta_q8csw); VK_REGISTER_OP(et_vk.linear_q8csw.default, linear_q8csw); + VK_REGISTER_OP(et_vk.linear_q4gsw.default, linear_q4gsw); } } // namespace vkcompute diff --git a/backends/vulkan/test/custom_ops/CMakeLists.txt b/backends/vulkan/test/custom_ops/CMakeLists.txt index 5ccc83c60e5..fe58055f649 100644 --- a/backends/vulkan/test/custom_ops/CMakeLists.txt +++ b/backends/vulkan/test/custom_ops/CMakeLists.txt @@ -92,5 +92,7 @@ if(TARGET vulkan_backend) # Define operator prototypes add_operator_prototype(add) add_operator_prototype(q8csw_linear) + add_operator_prototype(quantized_q4gaw_linear) + add_operator_prototype(quantized_int4_linear) add_operator_prototype(q8csw_conv2d) endif() diff --git a/backends/vulkan/test/custom_ops/q4gsw_linear.cpp b/backends/vulkan/test/custom_ops/q4gsw_linear.cpp new file mode 100644 index 00000000000..805b67c30a2 --- /dev/null +++ b/backends/vulkan/test/custom_ops/q4gsw_linear.cpp @@ -0,0 +1,373 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. 
+ +#include +#include +#include +#include +#include "utils.h" + +#include + +using namespace executorch::vulkan::prototyping; + +using namespace vkcompute; + +static constexpr int64_t kRefDimSizeLimit = 300; + +// Linear configuration struct +struct LinearConfig { + int64_t M; // Batch size / number of rows in input + int64_t K; // Input features / columns in input, rows in weight + int64_t N; // Output features / columns in weight + int64_t group_size; // Number of input channels per quantization group + bool has_bias = false; + std::string test_case_name = "placeholder"; + std::string op_name = "linear_q4gsw"; +}; + +// Helper function to unpack 4-bit values from uint8 +std::pair unpack_4bit(uint8_t packed) { + // Extract lower 4 bits and upper 4 bits + int8_t lower = packed & 0x0F; + int8_t upper = (packed >> 4) & 0x0F; + + // Subtract 8 from unpacked 4-bit values + lower -= 8; + upper -= 8; + + return std::make_pair(lower, upper); +} + +// Utility function to create a test case from a LinearConfig +TestCase create_test_case_from_config( + const LinearConfig& config, + utils::StorageType storage_type, + vkapi::ScalarType input_dtype) { + TestCase test_case; + + // Create a descriptive name for the test case + std::string storage_str = + (storage_type == utils::kTexture3D) ? "Texture3D" : "Buffer"; + std::string dtype_str = (input_dtype == vkapi::kFloat) ? "Float" : "Half"; + + std::string test_name = + config.test_case_name + "_" + storage_str + "_" + dtype_str; + test_case.set_name(test_name); + + // Set the operator name for the test case + std::string operator_name = "et_vk." + config.op_name + ".default"; + test_case.set_operator_name(operator_name); + + // Derive sizes from M, K, N + std::vector input_size = {config.M, config.K}; + // Input tensor (float/half) - [M, K] + ValueSpec input_tensor( + input_size, + input_dtype, + storage_type, + utils::kWidthPacked, + DataGenType::RANDOM); + + if (debugging()) { + print_valuespec_data(input_tensor, "input_tensor"); + } + + // For 4-bit weights, packed size is [N, K/2] since 2 weights per byte + std::vector weight_size = {config.N, config.K / 2}; + // Quantized weight tensor (uint8, packed 4-bit) - [N, K/2] + ValueSpec quantized_weight( + weight_size, + vkapi::kByte, // uint8 for packed 4-bit quantized weights + storage_type, + utils::kWidthPacked, + DataGenType::RANDINT4); + quantized_weight.set_constant(true); + quantized_weight.set_int4(true); + + if (debugging()) { + print_valuespec_data(quantized_weight, "weight_tensor"); + } + + // Weight quantization scales (float/half, per-group) + // For group symmetric quantization: [K/group_size, N] + // Each group of input features has scales for all output features + std::vector weight_scales_size = { + config.K / config.group_size, config.N}; + ValueSpec weight_scales( + weight_scales_size, + input_dtype, + storage_type, + utils::kWidthPacked, + DataGenType::RANDOM_SCALES); + weight_scales.set_constant(true); + + // Group size parameter + ValueSpec group_size_spec(static_cast(config.group_size)); + + // Bias (optional, float/half) - [N] + ValueSpec bias( + {config.N}, // Per output feature + input_dtype, + storage_type, + utils::kWidthPacked, + DataGenType::ZEROS); + bias.set_constant(true); + if (!config.has_bias) { + bias.set_none(true); + } + + // Output tensor (float/half) - [M, N] + ValueSpec output( + {config.M, config.N}, + input_dtype, + storage_type, + utils::kWidthPacked, + DataGenType::ZEROS); + + // Add all specs to test case for linear_q4gsw + 
test_case.add_input_spec(input_tensor); + test_case.add_input_spec(quantized_weight); + test_case.add_input_spec(weight_scales); + test_case.add_input_spec(group_size_spec); + test_case.add_input_spec(bias); + test_case.add_output_spec(output); + + return test_case; +} + +// Generate easy test cases for quantized linear operation (for debugging) +std::vector generate_quantized_linear_easy_cases() { + std::vector test_cases; + + // Single simple configuration for debugging + int M = 4; + int K = 32; + int N = 16; + int group_size = 8; + + LinearConfig config = { + M, // Batch size + K, // Input features + N, // Output features + group_size, // Group size + true, // has_bias + "simple", // test_case_name + }; + + // Test with both storage types and data types for completeness + std::vector storage_types = { + utils::kTexture3D, utils::kBuffer}; + std::vector float_types = {vkapi::kFloat}; + + // Generate test cases for each combination + for (const auto& storage_type : storage_types) { + for (const auto& input_dtype : float_types) { + test_cases.push_back( + create_test_case_from_config(config, storage_type, input_dtype)); + } + } + + return test_cases; +} + +// Generate test cases for quantized linear operation +std::vector generate_quantized_linear_test_cases() { + std::vector test_cases; + + std::vector configs = { + // Gemv test cases + {1, 128, 64, 32}, + {1, 256, 128, 64}, + // Gemm + {4, 64, 32, 16}, + {4, 128, 64, 32}, + {4, 256, 128, 64}, + {32, 64, 32, 16}, + {32, 128, 64, 32}, + {32, 256, 128, 64}, + // No bias tests + {32, 128, 64, 32, false}, + {32, 256, 128, 64, false}, + // Performance test cases + {1, 2048, 2048, 128}, + {128, 2048, 2048, 128}, + {256, 2048, 2048, 128}, + {1024, 2048, 2048, 128}, + }; + + // Test with different storage types and data types + std::vector storage_types = { + utils::kTexture3D, utils::kBuffer}; + + for (auto config : configs) { + std::string prefix = + (config.M < kRefDimSizeLimit && config.K < kRefDimSizeLimit && + config.N < kRefDimSizeLimit) + ? 
"correctness_" + : "performance_"; + std::string generated_test_case_name = prefix + std::to_string(config.M) + + "_" + std::to_string(config.K) + "_" + std::to_string(config.N) + "_g" + + std::to_string(config.group_size); + if (!config.has_bias) { + generated_test_case_name += "_no_bias"; + } + + config.test_case_name = generated_test_case_name; + + for (const auto& storage_type : storage_types) { + test_cases.push_back( + create_test_case_from_config(config, storage_type, vkapi::kFloat)); + } + } + + return test_cases; +} + +// Reference implementation for 4-bit group symmetric weight quantized linear +void linear_q4gsw_reference_impl(TestCase& test_case) { + int32_t idx = 0; + const ValueSpec& input_spec = test_case.inputs()[idx++]; + const ValueSpec& weight_spec = test_case.inputs()[idx++]; + const ValueSpec& weight_scales_spec = test_case.inputs()[idx++]; + const ValueSpec& group_size_spec = test_case.inputs()[idx++]; + const ValueSpec& bias_spec = test_case.inputs()[idx++]; + + // Extract output specification (mutable reference) + ValueSpec& output_spec = test_case.outputs()[0]; + + // Get tensor dimensions + auto input_sizes = input_spec.get_tensor_sizes(); // [batch_size, in_features] + auto weight_sizes = + weight_spec.get_tensor_sizes(); // [in_features, out_features/2] + auto output_sizes = + output_spec.get_tensor_sizes(); // [batch_size, out_features] + + int64_t batch_size = input_sizes[0]; + int64_t in_features = input_sizes[1]; + int64_t out_features = output_sizes[1]; + int64_t group_size = group_size_spec.get_int_value(); + + // Skip for large tensors since computation time will be extremely slow + if (batch_size > kRefDimSizeLimit || in_features > kRefDimSizeLimit || + out_features > kRefDimSizeLimit) { + throw std::invalid_argument( + "One or more dimensions (batch_size, in_features, out_features) exceed the allowed limit for reference implementation."); + } + + if (input_spec.dtype != vkapi::kFloat) { + throw std::invalid_argument("Unsupported dtype"); + } + + // Get raw data pointers + auto& input_data = input_spec.get_float_data(); + auto& weight_data = weight_spec.get_uint8_data(); + auto& weight_scales_data = weight_scales_spec.get_float_data(); + auto& bias_data = bias_spec.get_float_data(); + + // Calculate number of output elements + int64_t num_output_elements = batch_size * out_features; + + auto& ref_data = output_spec.get_ref_float_data(); + ref_data.resize(num_output_elements); + + // Perform quantized linear transformation (matrix multiplication) + for (int64_t b = 0; b < batch_size; ++b) { + for (int64_t out_f = 0; out_f < out_features; ++out_f) { + float sum = 0.0f; + + // Matrix multiplication: output[b][out_f] = sum(input[b][in_f] * + // weight[out_f][in_f]) + for (int64_t in_f = 0; in_f < in_features; ++in_f) { + // Get input value + int64_t input_idx = b * in_features + in_f; + float input_val = input_data[input_idx]; + + // Get weight value and dequantize (4-bit group symmetric quantization) + int64_t group_idx = in_f / group_size; + int64_t scales_idx = group_idx * out_features + out_f; + + // Get packed weight value - weight matrix is [N, K/2] + int64_t weight_idx = (out_f) * (in_features / 2) + (in_f / 2); + uint8_t packed_weight = weight_data[weight_idx]; + + // Unpack 4-bit weight + auto unpacked = unpack_4bit(packed_weight); + int8_t weight_4bit = (in_f % 2 == 0) ? 
unpacked.first : unpacked.second; + + // Dequantize weight using group symmetric quantization (no zero point) + float weight_scale = weight_scales_data[scales_idx]; + float dequant_weight = static_cast(weight_4bit) * weight_scale; + + sum += input_val * dequant_weight; + } + + // Add bias and store result + if (!bias_spec.is_none()) { + sum += bias_data[out_f]; + } + int64_t output_idx = b * out_features + out_f; + ref_data[output_idx] = sum; + } + } +} + +void reference_impl(TestCase& test_case) { + linear_q4gsw_reference_impl(test_case); +} + +int64_t quantized_linear_flop_calculator(const TestCase& test_case) { + // Get input and weight dimensions + const auto& input_sizes = test_case.inputs()[0].get_tensor_sizes(); + const auto& output_sizes = test_case.outputs()[0].get_tensor_sizes(); + + int64_t batch_size = input_sizes[0]; + int64_t in_features = input_sizes[1]; + int64_t out_features = output_sizes[1]; + + // Calculate FLOPs for quantized linear operation + // Each output element requires: + // - in_features multiply-accumulate operations + // - Additional operations for quantization/dequantization + int64_t output_elements = batch_size * out_features; + int64_t ops_per_output = in_features; + + // Add quantization overhead (approximate) + // - Unpack 4-bit weight: 1 op per weight element used + // - Dequantize weight: 1 op per weight element used + // - Add bias: 1 op per output element + int64_t quantization_ops = ops_per_output * 2 + 1; // Simplified estimate + + int64_t flop = output_elements * (ops_per_output + quantization_ops); + + return flop; +} + +int main(int argc, char* argv[]) { + set_debugging(false); + set_print_output(false); + set_print_latencies(false); + set_use_gpu_timestamps(true); + + print_performance_header(); + std::cout + << "4-bit Group Symmetric Weight Quantized Linear Operation Prototyping Framework" + << std::endl; + print_separator(); + + ReferenceComputeFunc ref_fn = reference_impl; + + // Execute easy test cases using the new framework with custom FLOP calculator + auto results = execute_test_cases( + generate_quantized_linear_test_cases, + quantized_linear_flop_calculator, + "QuantizedLinearQ4GSW", + 0, + 10, + ref_fn); + + return 0; +} diff --git a/backends/vulkan/test/custom_ops/quantized_int4_linear.cpp b/backends/vulkan/test/custom_ops/quantized_int4_linear.cpp new file mode 100644 index 00000000000..c125ce2d09c --- /dev/null +++ b/backends/vulkan/test/custom_ops/quantized_int4_linear.cpp @@ -0,0 +1,366 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. 
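Reviewer note: the linear_weight_int4 test that starts here passes scales and zero points as a single [K / group_size, N, 2] tensor, and its reference implementation further below sign-extends each nibble and applies (w - 8) * scale + zero. A short sketch of that per-weight dequantization and its indexing, mirroring the reference loop; dequantize_int4_weight is an illustrative helper, not part of the patch:

    #include <cstdint>
    #include <iostream>
    #include <vector>

    // scales_and_zeros is a flat row-major [K / group_size, N, 2] buffer where
    // element [g][n][0] is the scale and [g][n][1] is the zero point for group
    // g of output feature n.
    float dequantize_int4_weight(
        int8_t w4, // sign-extended 4-bit weight, in [-8, 7]
        int64_t in_f, // input-feature (K) index
        int64_t out_f, // output-feature (N) index
        int64_t out_features,
        int64_t group_size,
        const std::vector<float>& scales_and_zeros) {
      const int64_t group_idx = in_f / group_size;
      const int64_t base = group_idx * out_features * 2 + out_f * 2;
      const float scale = scales_and_zeros[base];
      const float zero = scales_and_zeros[base + 1];
      // Same affine mapping as the reference implementation in this file:
      // shift the stored value by 8, then scale and add the zero point.
      return (static_cast<float>(w4) - 8.0f) * scale + zero;
    }

    int main() {
      // Toy example: K = 4, N = 2, group_size = 4, so the buffer is [1, 2, 2].
      std::vector<float> scales_and_zeros = {0.5f, 1.0f, 0.25f, 0.0f};
      // Dequantize weight value 3 at in_f = 1, out_f = 0; prints -1.5.
      std::cout << dequantize_int4_weight(3, 1, 0, 2, 4, scales_and_zeros)
                << std::endl;
      return 0;
    }

Keeping scale and zero point adjacent per (group, output feature) pair is why the reference computes a single base offset and reads two consecutive elements.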
+ +#include +#include +#include +#include +#include "utils.h" + +#include + +using namespace executorch::vulkan::prototyping; + +using namespace vkcompute; + +// Linear configuration struct +struct LinearConfig { + int64_t M; // Batch size / number of rows in input + int64_t K; // Input features / columns in input, rows in weight + int64_t N; // Output features / columns in weight + int64_t group_size; // Number of input channels per quantization group + std::string name_suffix; + std::string shader_variant_name = "default"; +}; + +// Utility function to create a test case from a LinearConfig +TestCase create_test_case_from_config( + const LinearConfig& config, + utils::StorageType storage_type, + vkapi::ScalarType input_dtype) { + TestCase test_case; + + // Create a descriptive name for the test case + std::string storage_str = + (storage_type == utils::kTexture3D) ? "Texture3D" : "Buffer"; + std::string dtype_str = (input_dtype == vkapi::kFloat) ? "Float" : "Half"; + + std::string test_name = "QuantizedLinearInt4_" + config.name_suffix + "_" + + storage_str + "_" + dtype_str; + test_case.set_name(test_name); + + // Set the operator name for the test case + std::string operator_name = "et_vk.linear_weight_int4.default"; + test_case.set_operator_name(operator_name); + + // Derive sizes from M, K, N + std::vector input_size = {config.M, config.K}; + std::vector weight_size = { + config.N, config.K / 2}; // Packed 4-bit weights + + // Input tensor (float/half) - [M, K] + ValueSpec input_tensor( + input_size, + input_dtype, + storage_type, + utils::kWidthPacked, + DataGenType::ONES); + + if (debugging()) { + print_valuespec_data(input_tensor, "input_tensor"); + } + + // Quantized weight tensor (int8, packed 4-bit) - [N, K/2] + ValueSpec quantized_weight( + weight_size, + vkapi::kChar, // int8 for packed 4-bit quantized weights + storage_type, + utils::kWidthPacked, + DataGenType::ONES); + quantized_weight.set_constant(true); + quantized_weight.set_int4(true); + + if (debugging()) { + print_valuespec_data(quantized_weight, "weight_tensor"); + } + + // Group size parameter + ValueSpec group_size_spec(static_cast(config.group_size)); + + // Weight quantization scales and zeros (float/half, per-group) - + // [K/group_size, N, 2] + std::vector scales_and_zeros_size = { + config.K / config.group_size, config.N, 2}; + ValueSpec scales_and_zeros( + scales_and_zeros_size, + input_dtype, + storage_type, + utils::kWidthPacked, + DataGenType::ONES); + scales_and_zeros.set_constant(true); + + if (debugging()) { + print_valuespec_data(scales_and_zeros, "scales_and_zeros"); + } + + // Output tensor (float/half) - [M, N] + ValueSpec output( + {config.M, config.N}, + input_dtype, + storage_type, + utils::kWidthPacked, + DataGenType::ZEROS); + + // Add all specs to test case + test_case.add_input_spec(input_tensor); + test_case.add_input_spec(quantized_weight); + test_case.add_input_spec(group_size_spec); + test_case.add_input_spec(scales_and_zeros); + // Add dummy value for inner_k_tiles (unused but required by operator + // signature) + ValueSpec dummy_inner_k_tiles(static_cast(8)); + test_case.add_input_spec(dummy_inner_k_tiles); + + test_case.add_output_spec(output); + + return test_case; +} + +// Generate easy test cases for quantized linear operation (for debugging) +std::vector generate_quantized_linear_easy_cases() { + std::vector test_cases; + + // Single simple configuration for debugging + int M = 8; + int K = 16; + int N = 16; + int group_size = 8; + + LinearConfig config = { + M, // Batch size + 
K, // Input features + N, // Output features + group_size, // Group size + "simple", // descriptive name + "default" // shader variant name + }; + + // Test with both storage types and data types for completeness + std::vector storage_types = { + utils::kTexture3D, utils::kBuffer}; + std::vector float_types = {vkapi::kFloat}; + + // Generate test cases for each combination + for (const auto& storage_type : storage_types) { + for (const auto& input_dtype : float_types) { + test_cases.push_back( + create_test_case_from_config(config, storage_type, input_dtype)); + } + } + + return test_cases; +} + +// Generate test cases for quantized linear operation +std::vector generate_quantized_linear_test_cases() { + std::vector test_cases; + + std::vector configs = { + {8, 64, 32, 8, "correctness_8_64_32_g8"}, + {8, 128, 64, 16, "correctness_8_128_64_g16"}, + {8, 256, 128, 32, "correctness_8_256_128_g32"}, + {32, 64, 32, 8, "correctness_32_64_32_g8"}, + {32, 128, 64, 16, "correctness_32_128_64_g16"}, + {32, 256, 128, 32, "correctness_32_256_128_g32"}, + {1, 256, 128, 32, "correctness_32_256_128_g32"}, + // Performance test cases + {1, 2048, 2048, 128, "performance_128_2048_2048_g128"}, + {128, 2048, 2048, 128, "performance_128_2048_2048_g128"}, + {248, 2048, 2048, 128, "performance_128_2048_2048_g128"}, + {1024, 2048, 2048, 128, "performance_128_2048_2048_g128"}, + // {16384, 576, 128, 32, "performance_16384_576_128_g32"} + }; + + // Test with different storage types and data types + std::vector storage_types = { + utils::kTexture3D, utils::kBuffer}; + + // Generate test cases for each combination + for (const auto& config : configs) { + for (const auto& storage_type : storage_types) { + test_cases.push_back( + create_test_case_from_config(config, storage_type, vkapi::kFloat)); + } + } + + return test_cases; +} + +// Helper function to unpack 4-bit values from int8 +std::pair unpack_4bit(int8_t packed) { + // Extract lower 4 bits and upper 4 bits + int8_t lower = packed & 0x0F; + int8_t upper = (packed >> 4) & 0x0F; + + // Sign extend from 4-bit to 8-bit + if (lower & 0x08) + lower |= 0xF0; + if (upper & 0x08) + upper |= 0xF0; + + return std::make_pair(lower, upper); +} + +// Reference implementation for quantized linear operation +void quantized_linear_reference_impl(TestCase& test_case) { + static constexpr int64_t kRefDimSizeLimit = 300; + // Extract input specifications + int32_t idx = 0; + const ValueSpec& input_spec = test_case.inputs()[idx++]; + const ValueSpec& weight_spec = test_case.inputs()[idx++]; + const ValueSpec& group_size_spec = test_case.inputs()[idx++]; + const ValueSpec& scales_and_zeros_spec = test_case.inputs()[idx++]; + // Skip dummy inner_k_tiles + idx++; + + // Extract output specification (mutable reference) + ValueSpec& output_spec = test_case.outputs()[0]; + + // Get tensor dimensions + auto input_sizes = input_spec.get_tensor_sizes(); // [batch_size, in_features] + auto weight_sizes = + weight_spec.get_tensor_sizes(); // [out_features, in_features/2] + auto output_sizes = + output_spec.get_tensor_sizes(); // [batch_size, out_features] + + int64_t batch_size = input_sizes[0]; + int64_t in_features = input_sizes[1]; + int64_t out_features = output_sizes[1]; + int64_t group_size = group_size_spec.get_int_value(); + + // Skip for large tensors since computation time will be extremely slow + if (batch_size > kRefDimSizeLimit || in_features > kRefDimSizeLimit || + out_features > kRefDimSizeLimit) { + throw std::invalid_argument( + "One or more dimensions (batch_size, 
in_features, out_features) exceed the allowed limit for reference implementation."); + } + + if (input_spec.dtype != vkapi::kFloat) { + throw std::invalid_argument("Unsupported dtype"); + } + + // Get raw data pointers + auto& input_data = input_spec.get_float_data(); + auto& weight_data = weight_spec.get_int8_data(); + auto& scales_and_zeros_data = scales_and_zeros_spec.get_float_data(); + + // Calculate number of output elements + int64_t num_output_elements = batch_size * out_features; + + auto& ref_data = output_spec.get_ref_float_data(); + ref_data.resize(num_output_elements); + + // Perform quantized linear transformation (matrix multiplication) + for (int64_t b = 0; b < batch_size; ++b) { + for (int64_t out_f = 0; out_f < out_features; ++out_f) { + float sum = 0.0f; + + bool should_print = b == 0 && out_f == 0; + should_print = false; + + if (should_print) { + std::cout << "Weights seen: "; + } + + // Matrix multiplication: output[b][out_f] = sum(input[b][in_f] * + // weight[out_f][in_f]) + for (int64_t in_f = 0; in_f < in_features; ++in_f) { + // Get input value + int64_t input_idx = b * in_features + in_f; + float input_val = input_data[input_idx]; + + // Get weight value and dequantize (4-bit group affine quantization) + int64_t group_idx = in_f / group_size; + int64_t scales_and_zeros_idx = group_idx * out_features * 2 + out_f * 2; + + // Get packed weight value + int64_t weight_idx = out_f * (in_features / 2) + (in_f / 2); + int8_t packed_weight = weight_data[weight_idx]; + + // Unpack 4-bit weight + auto unpacked = unpack_4bit(packed_weight); + int8_t weight_4bit = (in_f % 2 == 0) ? unpacked.first : unpacked.second; + + // Dequantize weight using group affine quantization + float weight_scale = scales_and_zeros_data[scales_and_zeros_idx]; + float weight_zero = scales_and_zeros_data[scales_and_zeros_idx + 1]; + float dequant_weight = + (static_cast(weight_4bit) - 8.0f) * weight_scale + + weight_zero; + + if (should_print) { + std::cout << int(weight_4bit) << ", "; + } + + sum += input_val * dequant_weight; + } + + if (should_print) { + std::cout << std::endl; + } + + // Store result + int64_t output_idx = b * out_features + out_f; + ref_data[output_idx] = sum; + } + } +} + +// Custom FLOP calculator for quantized linear operation +int64_t quantized_linear_flop_calculator(const TestCase& test_case) { + if (test_case.num_inputs() < 4 || test_case.num_outputs() < 1) { + return 0; + } + + // Get input and weight dimensions + const auto& input_sizes = test_case.inputs()[0].get_tensor_sizes(); + const auto& output_sizes = test_case.outputs()[0].get_tensor_sizes(); + + int64_t batch_size = input_sizes[0]; + int64_t in_features = input_sizes[1]; + int64_t out_features = output_sizes[1]; + + // Calculate FLOPs for quantized linear operation + // Each output element requires: + // - in_features multiply-accumulate operations + // - Additional operations for quantization/dequantization + int64_t output_elements = batch_size * out_features; + int64_t ops_per_output = in_features; + + // Add quantization overhead (approximate) + // - Dequantize weight: 2 ops per weight element used (unpack + dequantize) + int64_t quantization_ops = ops_per_output * 2; // Simplified estimate + + int64_t flop = output_elements * (ops_per_output + quantization_ops); + + return flop; +} + +int main(int argc, char* argv[]) { + set_debugging(false); + set_print_output(false); + set_print_latencies(false); + set_use_gpu_timestamps(true); + + print_performance_header(); + std::cout << "Quantized 4-bit Int4 Linear 
Operation Prototyping Framework" + << std::endl; + print_separator(); + + ReferenceComputeFunc ref_fn = quantized_linear_reference_impl; + + // Execute easy test cases using the new framework with custom FLOP + // calculator + auto results = execute_test_cases( + generate_quantized_linear_test_cases, + quantized_linear_flop_calculator, + "QuantizedLinearInt4", + 0, + 10, + ref_fn); + + return 0; +} diff --git a/backends/vulkan/test/custom_ops/quantized_q4gaw_linear.cpp b/backends/vulkan/test/custom_ops/quantized_q4gaw_linear.cpp new file mode 100644 index 00000000000..084d718b502 --- /dev/null +++ b/backends/vulkan/test/custom_ops/quantized_q4gaw_linear.cpp @@ -0,0 +1,433 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include +#include +#include +#include +#include "utils.h" + +#include + +using namespace executorch::vulkan::prototyping; + +using namespace vkcompute; + +// Linear configuration struct +struct LinearConfig { + int64_t M; // Batch size / number of rows in input + int64_t K; // Input features / columns in input, rows in weight + int64_t N; // Output features / columns in weight + int64_t group_size; // Number of input channels per quantization group + std::string name_suffix; + std::string shader_variant_name = "default"; +}; + +// Utility function to create a test case from a LinearConfig +TestCase create_test_case_from_config( + const LinearConfig& config, + utils::StorageType storage_type, + vkapi::ScalarType input_dtype) { + TestCase test_case; + + // Create a descriptive name for the test case + std::string storage_str = + (storage_type == utils::kTexture3D) ? "Texture3D" : "Buffer"; + std::string dtype_str = (input_dtype == vkapi::kFloat) ? 
"Float" : "Half"; + + std::string test_name = "QuantizedLinear4GAW_" + config.name_suffix + "_" + + storage_str + "_" + dtype_str; + test_case.set_name(test_name); + + // Set the operator name for the test case + std::string operator_name = "et_vk.linear_q8ta_q4gaw."; + operator_name += config.shader_variant_name; + test_case.set_operator_name(operator_name); + + // Derive sizes from M, K, N + std::vector input_size = {config.M, config.K}; + std::vector weight_size = { + config.K, config.N / 2}; // Packed 4-bit weights + + // Input tensor (float/half) - [M, K] + ValueSpec input_tensor( + input_size, + input_dtype, + storage_type, + utils::kWidthPacked, + DataGenType::RANDINT); + + if (debugging()) { + print_valuespec_data(input_tensor, "input_tensor"); + } + + float input_scale_val = 1.0f; + ValueSpec input_scale(input_scale_val); + + int32_t input_zero_point_val = 0; + ValueSpec input_zero_point(input_zero_point_val); + + // Group size parameter + ValueSpec group_size_spec(static_cast(config.group_size)); + + // Quantized weight tensor (int8, packed 4-bit) - [K, N/2] + ValueSpec quantized_weight( + weight_size, + vkapi::kChar, // int8 for packed 4-bit quantized weights + storage_type, + utils::kWidthPacked, + DataGenType::RANDINT4); + quantized_weight.set_constant(true); + quantized_weight.set_int4(true); + + if (debugging()) { + print_valuespec_data(quantized_weight, "weight_tensor"); + } + + // Weight quantization scales (float/half, per-group) - [N, K/group_size] + std::vector weight_scales_size = { + config.N, config.K / config.group_size}; + ValueSpec weight_scales( + weight_scales_size, + input_dtype, + storage_type, + utils::kWidthPacked, + DataGenType::RANDOM_SCALES); + weight_scales.set_constant(true); + + if (debugging()) { + print_valuespec_data(weight_scales, "weight_scales"); + } + + // Weight zeros (int32, per-group) - [N, K/group_size] + ValueSpec weight_zeros( + weight_scales_size, + vkapi::kInt, // int32 for zeros + storage_type, + utils::kWidthPacked, + DataGenType::ZEROS); + weight_zeros.set_constant(true); + + ValueSpec weight_sums( + {config.N}, // Per output features + vkapi::kFloat, + storage_type, + utils::kWidthPacked, + DataGenType::ZEROS); + weight_sums.set_constant(true); + + // Compute weight_sums data based on quantized weights + int64_t in_features = config.K; + int64_t out_features = config.N; + + ValueSpec orig_OC(static_cast(config.N)); + + // Bias (optional, float/half) - [N] + ValueSpec bias( + {config.N}, // Per output feature + input_dtype, + storage_type, + utils::kWidthPacked, + DataGenType::ZEROS); + bias.set_constant(true); + + // Output tensor (float/half) - [M, N] + ValueSpec output( + {config.M, config.N}, + input_dtype, + storage_type, + utils::kWidthPacked, + DataGenType::ZEROS); + + // Add all specs to test case + test_case.add_input_spec(input_tensor); + test_case.add_input_spec(input_scale); + test_case.add_input_spec(input_zero_point); + test_case.add_input_spec(quantized_weight); + test_case.add_input_spec(weight_sums); + test_case.add_input_spec(weight_scales); + test_case.add_input_spec(weight_zeros); + test_case.add_input_spec(orig_OC); + test_case.add_input_spec(group_size_spec); + test_case.add_input_spec(bias); + + test_case.add_output_spec(output); + + return test_case; +} + +// Generate easy test cases for quantized linear operation (for debugging) +std::vector generate_quantized_linear_easy_cases() { + std::vector test_cases; + + // Single simple configuration for debugging + int M = 4; + int K = 32; + int N = 32; + int 
group_size = 8; + + LinearConfig config = { + M, // Batch size + K, // Input features + N, // Output features + group_size, // Group size + "simple", // descriptive name + "noint8" // shader variant name + }; + + // Test with both storage types and data types for completeness + std::vector storage_types = { + utils::kTexture3D, utils::kBuffer}; + std::vector float_types = {vkapi::kFloat}; + + // Generate test cases for each combination + for (const auto& storage_type : storage_types) { + for (const auto& input_dtype : float_types) { + test_cases.push_back( + create_test_case_from_config(config, storage_type, input_dtype)); + } + } + + return test_cases; +} + +// Generate test cases for quantized linear operation +std::vector generate_quantized_linear_test_cases() { + std::vector test_cases; + + std::vector configs = { + {8, 64, 32, 8, "correctness_1_64_32_g8"}, + {8, 128, 64, 16, "correctness_1_128_64_g16"}, + {8, 256, 128, 32, "correctness_1_256_128_g32"}, + {32, 64, 32, 8, "correctness_32_64_32_g8"}, + {32, 128, 64, 16, "correctness_32_128_64_g16"}, + {32, 256, 128, 32, "correctness_32_256_128_g32"}, + {1, 256, 128, 32, "todo"}, + // Performance test cases + {1, 2048, 2048, 128, "todo"}, + {128, 2048, 2048, 128, "performance_128_2048_2048_g64"}, + {248, 2048, 2048, 128, "performance_128_2048_2048_g64"}, + {1024, 2048, 2048, 128, "performance_128_2048_2048_g64"}, + // {16384, 576, 128, 32, "performance_16384_576_128_g32"} + }; + + // Test with different storage types and data types + std::vector storage_types = { + utils::kTexture3D, utils::kBuffer}; + + // Generate test cases for each combination + for (const auto& config : configs) { + for (const auto& storage_type : storage_types) { + // Test both with and without shader int8 dot product + test_cases.push_back( + create_test_case_from_config(config, storage_type, vkapi::kFloat)); + + // LinearConfig no_int_config = config; + // no_int_config.name_suffix = config.name_suffix + "_noint8"; + // no_int_config.shader_variant_name = "noint8"; + + // test_cases.push_back(create_test_case_from_config( + // no_int_config, storage_type, vkapi::kFloat)); + } + } + + return test_cases; +} + +// Helper function to unpack 4-bit values from int8 +std::pair unpack_4bit(int8_t packed) { + // Extract lower 4 bits and upper 4 bits + int8_t lower = packed & 0x0F; + int8_t upper = (packed >> 4) & 0x0F; + + // Sign extend from 4-bit to 8-bit + if (lower & 0x08) + lower |= 0xF0; + if (upper & 0x08) + upper |= 0xF0; + + return std::make_pair(lower, upper); +} + +// Reference implementation for quantized linear operation +void quantized_linear_reference_impl(TestCase& test_case) { + static constexpr int64_t kRefDimSizeLimit = 300; + // Extract input specifications + int32_t idx = 0; + const ValueSpec& input_spec = test_case.inputs()[idx++]; + const ValueSpec& input_scale_spec = test_case.inputs()[idx++]; + const ValueSpec& input_zeros_spec = test_case.inputs()[idx++]; + const ValueSpec& weight_spec = test_case.inputs()[idx++]; + const ValueSpec& weight_sums_spec = test_case.inputs()[idx++]; + (void)weight_sums_spec; + const ValueSpec& weight_scales_spec = test_case.inputs()[idx++]; + const ValueSpec& weight_zeros_spec = test_case.inputs()[idx++]; + const ValueSpec& orig_OC = test_case.inputs()[idx++]; + (void)orig_OC; + const ValueSpec& group_size_spec = test_case.inputs()[idx++]; + const ValueSpec& bias_spec = test_case.inputs()[idx++]; + + // Extract output specification (mutable reference) + ValueSpec& output_spec = test_case.outputs()[0]; + + // Get tensor 
dimensions + auto input_sizes = input_spec.get_tensor_sizes(); // [batch_size, in_features] + auto weight_sizes = + weight_spec.get_tensor_sizes(); // [in_features, out_features/2] + auto output_sizes = + output_spec.get_tensor_sizes(); // [batch_size, out_features] + + int64_t batch_size = input_sizes[0]; + int64_t in_features = input_sizes[1]; + int64_t out_features = output_sizes[1]; + int64_t group_size = group_size_spec.get_int_value(); + + // Skip for large tensors since computation time will be extremely slow + if (batch_size > kRefDimSizeLimit || in_features > kRefDimSizeLimit || + out_features > kRefDimSizeLimit) { + throw std::invalid_argument( + "One or more dimensions (batch_size, in_features, out_features) exceed the allowed limit for reference implementation."); + } + + if (input_spec.dtype != vkapi::kFloat) { + throw std::invalid_argument("Unsupported dtype"); + } + + // Get raw data pointers + auto& input_data = input_spec.get_float_data(); + const float input_scale = input_scale_spec.get_float_value(); + const int32_t input_zero_point = input_zeros_spec.get_int_value(); + + auto& weight_data = weight_spec.get_int8_data(); + auto& weight_scales_data = weight_scales_spec.get_float_data(); + auto& weight_zeros_data = weight_zeros_spec.get_int32_data(); + auto& bias_data = bias_spec.get_float_data(); + + // Calculate number of output elements + int64_t num_output_elements = batch_size * out_features; + + auto& ref_data = output_spec.get_ref_float_data(); + ref_data.resize(num_output_elements); + + // Perform quantized linear transformation (matrix multiplication) + for (int64_t b = 0; b < batch_size; ++b) { + for (int64_t out_f = 0; out_f < out_features; ++out_f) { + float sum = 0.0f; + + bool should_print = b == 0 && out_f == 0; + should_print = false; + + if (should_print) { + std::cout << "Weights seen: "; + } + + // Matrix multiplication: output[b][out_f] = sum(input[b][in_f] * + // weight[out_f][in_f]) + for (int64_t in_f = 0; in_f < in_features; ++in_f) { + // Get input value and dequantize + int64_t input_idx = b * in_features + in_f; + + float quant_input = + std::round(input_data[input_idx] / input_scale) + input_zero_point; + quant_input = std::min(std::max(quant_input, -128.0f), 127.0f); + float dequant_input = (quant_input - input_zero_point) * input_scale; + + // Get weight value and dequantize (4-bit group affine quantization) + int64_t group_idx = in_f / group_size; + int64_t scales_idx = group_idx * out_features + out_f; + + // Get packed weight value + int64_t weight_idx = in_f * (out_features / 2) + (out_f / 2); + int8_t packed_weight = weight_data[weight_idx]; + + // Unpack 4-bit weight + auto unpacked = unpack_4bit(packed_weight); + int8_t weight_4bit = + (out_f % 2 == 0) ? 
unpacked.first : unpacked.second; + + // Dequantize weight using group affine quantization + float weight_scale = weight_scales_data[scales_idx]; + int32_t weight_zero = weight_zeros_data[scales_idx]; + float dequant_weight = + (static_cast(weight_4bit) - weight_zero) * weight_scale; + + if (should_print) { + std::cout << int(weight_4bit) << ", "; + } + + sum += dequant_input * dequant_weight; + } + + if (should_print) { + std::cout << std::endl; + } + + // Add bias and store result + sum += bias_data[out_f]; + int64_t output_idx = b * out_features + out_f; + ref_data[output_idx] = sum; + } + } +} + +// Custom FLOP calculator for quantized linear operation +int64_t quantized_linear_flop_calculator(const TestCase& test_case) { + if (test_case.num_inputs() < 6 || test_case.num_outputs() < 1) { + return 0; + } + + // Get input and weight dimensions + const auto& input_sizes = test_case.inputs()[0].get_tensor_sizes(); + const auto& output_sizes = test_case.outputs()[0].get_tensor_sizes(); + + int64_t batch_size = input_sizes[0]; + int64_t in_features = input_sizes[1]; + int64_t out_features = output_sizes[1]; + + // Calculate FLOPs for quantized linear operation + // Each output element requires: + // - in_features multiply-accumulate operations + // - Additional operations for quantization/dequantization + int64_t output_elements = batch_size * out_features; + int64_t ops_per_output = in_features; + + // Add quantization overhead (approximate) + // - Dequantize input: 1 op per input element used + // - Dequantize weight: 2 ops per weight element used (unpack + dequantize) + // - Add bias: 1 op per output element + int64_t quantization_ops = ops_per_output * 2 + 1; // Simplified estimate + + int64_t flop = output_elements * (ops_per_output + quantization_ops); + + return flop; +} + +int main(int argc, char* argv[]) { + set_debugging(false); + set_print_output(false); + set_print_latencies(false); + set_use_gpu_timestamps(true); + + print_performance_header(); + std::cout + << "Quantized 4-bit Group Affine Weights Linear Operation Prototyping Framework" + << std::endl; + print_separator(); + + ReferenceComputeFunc ref_fn = quantized_linear_reference_impl; + + // Execute easy test cases using the new framework with custom FLOP + // calculator + auto results = execute_test_cases( + generate_quantized_linear_test_cases, + quantized_linear_flop_calculator, + "QuantizedLinear4GAW", + 0, + 10, + ref_fn); + + return 0; +} diff --git a/backends/vulkan/test/custom_ops/targets.bzl b/backends/vulkan/test/custom_ops/targets.bzl index 5d99f90ec5a..3162857c2d3 100644 --- a/backends/vulkan/test/custom_ops/targets.bzl +++ b/backends/vulkan/test/custom_ops/targets.bzl @@ -96,3 +96,4 @@ def define_common_targets(is_fbcode = False): define_custom_op_test_binary("q8csw_linear") define_custom_op_test_binary("q8csw_conv2d") define_custom_op_test_binary("choose_qparams_per_row") + define_custom_op_test_binary("q4gsw_linear") diff --git a/backends/vulkan/test/custom_ops/utils.cpp b/backends/vulkan/test/custom_ops/utils.cpp index ee2f6858025..37e0060b3f2 100644 --- a/backends/vulkan/test/custom_ops/utils.cpp +++ b/backends/vulkan/test/custom_ops/utils.cpp @@ -46,6 +46,8 @@ void generate_random_uint8_data( std::vector& data, uint8_t min_val = 0, uint8_t max_val = 255); +void generate_random_2xint4_data(std::vector& data); +void generate_random_2xint4_data(std::vector& data); void generate_random_int4_data( std::vector& data, int8_t min_val = -8, @@ -186,9 +188,12 @@ void ValueSpec::generate_tensor_data() { } else if 
(data_gen_type == DataGenType::RANDINT8) { generate_random_int8_data(int8_data, -128, 127); } else if (data_gen_type == DataGenType::RANDINT4) { - generate_random_int4_data(int8_data); + generate_random_2xint4_data(int8_data); } else if (data_gen_type == DataGenType::ONES) { std::fill(int8_data.begin(), int8_data.end(), 1); + } else if (data_gen_type == DataGenType::ONES_INT4) { + int8_t packed_data = (1 << 4) | 1; + std::fill(int8_data.begin(), int8_data.end(), packed_data); } else if (data_gen_type == DataGenType::ZEROS) { std::fill(int8_data.begin(), int8_data.end(), 0); } else { @@ -205,7 +210,7 @@ void ValueSpec::generate_tensor_data() { } else if (data_gen_type == DataGenType::RANDINT8) { generate_random_uint8_data(uint8_data, 0, 255); } else if (data_gen_type == DataGenType::RANDINT4) { - generate_random_uint8_data(uint8_data, 0, 15); + generate_random_2xint4_data(uint8_data); } else if (data_gen_type == DataGenType::ONES) { std::fill(uint8_data.begin(), uint8_data.end(), 1); } else if (data_gen_type == DataGenType::ZEROS) { @@ -564,6 +569,30 @@ void generate_random_int4_data( } } +void generate_random_2xint4_data(std::vector& data) { + std::mt19937 gen(get_seed()); + std::uniform_int_distribution dis(-8, 7); // Signed 4-bit range + for (auto& val : data) { + // Generate two separate 4-bit values + int8_t lower_4bits = static_cast(dis(gen)) & 0x0F; + int8_t upper_4bits = static_cast(dis(gen)) & 0x0F; + // Pack them into a single 8-bit value + val = (upper_4bits << 4) | lower_4bits; + } +} + +void generate_random_2xint4_data(std::vector& data) { + std::mt19937 gen(get_seed()); + std::uniform_int_distribution dis(0, 15); // Unsigned 4-bit range + for (auto& val : data) { + // Generate two separate 4-bit values + uint8_t lower_4bits = static_cast(dis(gen)) & 0x0F; + uint8_t upper_4bits = static_cast(dis(gen)) & 0x0F; + // Pack them into a single 8-bit value + val = (upper_4bits << 4) | lower_4bits; + } +} + void generate_zeros_data(std::vector& data) { std::fill(data.begin(), data.end(), 0.0f); } @@ -1442,19 +1471,69 @@ void print_valuespec_data( } case vkapi::kChar: { const auto& data = spec.get_int8_data(); - for (size_t i = 0; i < print_count; ++i) { - std::cout << static_cast(data[i]); - if (i < print_count - 1) - std::cout << ", "; + if (spec.is_int4()) { + // Print each 4-bit value individually + size_t element_count = 0; + for (size_t i = 0; i < data.size() && element_count < print_count; + ++i) { + // Extract lower 4 bits (signed) + int8_t lower_4bits = data[i] & 0x0F; + if (lower_4bits > 7) + lower_4bits -= 16; // Convert to signed + std::cout << static_cast(lower_4bits); + element_count++; + + if (element_count < print_count) { + std::cout << ", "; + // Extract upper 4 bits (signed) + int8_t upper_4bits = (data[i] >> 4) & 0x0F; + if (upper_4bits > 7) + upper_4bits -= 16; // Convert to signed + std::cout << static_cast(upper_4bits); + element_count++; + + if (element_count < print_count) + std::cout << ", "; + } + } + } else { + for (size_t i = 0; i < print_count; ++i) { + std::cout << static_cast(data[i]); + if (i < print_count - 1) + std::cout << ", "; + } } break; } case vkapi::kByte: { const auto& data = spec.get_uint8_data(); - for (size_t i = 0; i < print_count; ++i) { - std::cout << static_cast(data[i]); - if (i < print_count - 1) - std::cout << ", "; + if (spec.is_int4()) { + // Print each 4-bit value individually + size_t element_count = 0; + for (size_t i = 0; i < data.size() && element_count < print_count; + ++i) { + // Extract lower 4 bits + uint8_t lower_4bits = 
data[i] & 0x0F; + std::cout << static_cast(lower_4bits); + element_count++; + + if (element_count < print_count) { + std::cout << ", "; + // Extract upper 4 bits + uint8_t upper_4bits = (data[i] >> 4) & 0x0F; + std::cout << static_cast(upper_4bits); + element_count++; + + if (element_count < print_count) + std::cout << ", "; + } + } + } else { + for (size_t i = 0; i < print_count; ++i) { + std::cout << static_cast(data[i]); + if (i < print_count - 1) + std::cout << ", "; + } } break; } diff --git a/backends/vulkan/test/custom_ops/utils.h b/backends/vulkan/test/custom_ops/utils.h index 6c4e2263fc1..2440e225ef2 100644 --- a/backends/vulkan/test/custom_ops/utils.h +++ b/backends/vulkan/test/custom_ops/utils.h @@ -54,6 +54,7 @@ enum class DataGenType { RANDINT8, RANDINT4, ONES, + ONES_INT4, ZEROS }; @@ -67,6 +68,7 @@ struct ValueSpec { DataGenType data_gen_type; bool is_constant_tensor; bool is_none_flag; + bool is_int4_tensor; std::vector float_data; std::vector int32_data; @@ -92,7 +94,8 @@ struct ValueSpec { spec_type(SpecType::Tensor), data_gen_type(DataGenType::ZEROS), is_constant_tensor(false), - is_none_flag(false) { + is_none_flag(false), + is_int4_tensor(false) { generate_tensor_data(); } @@ -110,7 +113,8 @@ struct ValueSpec { spec_type(SpecType::Tensor), data_gen_type(data_gen_type), is_constant_tensor(false), - is_none_flag(false) { + is_none_flag(false), + is_int4_tensor(false) { generate_tensor_data(); } @@ -123,7 +127,8 @@ struct ValueSpec { spec_type(SpecType::Int), data_gen_type(DataGenType::FIXED), is_constant_tensor(false), - is_none_flag(false) { + is_none_flag(false), + is_int4_tensor(false) { int32_data.push_back(value); } @@ -136,7 +141,8 @@ struct ValueSpec { spec_type(SpecType::Float), data_gen_type(DataGenType::FIXED), is_constant_tensor(false), - is_none_flag(false) { + is_none_flag(false), + is_int4_tensor(false) { float_data.push_back(value); } @@ -149,7 +155,8 @@ struct ValueSpec { spec_type(SpecType::Bool), data_gen_type(DataGenType::FIXED), is_constant_tensor(false), - is_none_flag(false) { + is_none_flag(false), + is_int4_tensor(false) { int32_data.push_back(value ? 1 : 0); } @@ -163,6 +170,7 @@ struct ValueSpec { data_gen_type(DataGenType::FIXED), is_constant_tensor(false), is_none_flag(false), + is_int4_tensor(false), int32_data(values) {} // Default constructor @@ -173,7 +181,8 @@ struct ValueSpec { spec_type(SpecType::Tensor), data_gen_type(DataGenType::ZEROS), is_constant_tensor(false), - is_none_flag(false) {} + is_none_flag(false), + is_int4_tensor(false) {} int64_t numel() const; size_t nbytes() const; @@ -291,10 +300,19 @@ struct ValueSpec { bool is_none() const { return is_none_flag; } + void set_none(bool is_none) { is_none_flag = is_none; } + // Set/get int4 flag + bool is_int4() const { + return is_int4_tensor; + } + void set_int4(bool is_int4) { + is_int4_tensor = is_int4; + } + const void* get_data_ptr() const; // Correctness checking against reference data
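Reviewer note: with set_int4(true), the updated kChar printing path above walks the packed buffer nibble by nibble, emitting the low nibble before the high nibble and folding values above 7 back into the signed range. A standalone sketch of that decode, assuming plain std::vector buffers; decode_packed_int4 is an illustrative name, not part of utils.cpp:

    #include <cstdint>
    #include <iostream>
    #include <vector>

    // Expand a buffer of packed signed 4-bit values (two per byte, low nibble
    // first) into individual int8_t values, using the same "above 7, subtract
    // 16" adjustment as the int4-aware printing path.
    std::vector<int8_t> decode_packed_int4(const std::vector<int8_t>& packed) {
      std::vector<int8_t> out;
      out.reserve(packed.size() * 2);
      for (int8_t byte : packed) {
        int8_t lo = byte & 0x0F;
        int8_t hi = (byte >> 4) & 0x0F;
        if (lo > 7) lo -= 16; // map the unsigned nibble back to [-8, 7]
        if (hi > 7) hi -= 16;
        out.push_back(lo);
        out.push_back(hi);
      }
      return out;
    }

    int main() {
      // 0xF1 packs low nibble 1 and high nibble 15 (-1 once made signed).
      std::vector<int8_t> packed = {static_cast<int8_t>(0xF1)};
      for (int8_t v : decode_packed_int4(packed)) {
        std::cout << static_cast<int32_t>(v) << " ";
      }
      std::cout << std::endl; // prints: 1 -1
      return 0;
    }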