From 22aefab05a1504046989b2229529ee1091161a6b Mon Sep 17 00:00:00 2001 From: ssjia Date: Sun, 7 Sep 2025 11:39:50 -0700 Subject: [PATCH] [ET-VK] Quantized Int8 Linear Pull Request resolved: https://github.com/pytorch/executorch/pull/13816 Title says it all! This PR adds implementations for int8 linear layers. Convolution is implemented in a later step, computing convolution as matrix multiplication via the im2col procedure. For both linear and convolution, two versions are implemented: 1. `q8ta_q8csw` variant, which quantizes the input tensor and then performs integer accumulation via the int8 dot product extension. 2. `q8csw` variant, which dequantizes the weight tensor in-shader and performs floating point accumulation. The second variant provides an alternative path for executing quantized models when the target GPU does not support the int8 dot product extension. These new ops are tested via the custom op testing + benchmarking framework introduced in the previous diff. ghstack-source-id: 308092878 @exported-using-ghexport Differential Revision: [D81323424](https://our.internmc.facebook.com/intern/diff/D81323424/) --- .github/workflows/pull.yml | 2 + backends/vulkan/runtime/graph/ComputeGraph.h | 4 + .../runtime/graph/ops/glsl/common.glslh | 51 ++ .../graph/ops/glsl/linear_common.glslh | 32 + .../graph/ops/glsl/linear_fp_bias_load.glslh | 30 + .../graph/ops/glsl/linear_fp_input_tile.glslh | 45 ++ .../ops/glsl/linear_fp_input_tile_load.glslh | 91 +++ .../ops/glsl/linear_fp_output_tile.glslh | 61 ++ .../linear_fp_output_tile_fp_compute.glslh | 126 ++++ ...inear_fp_output_tile_fp_int8_compute.glslh | 68 +++ ...ear_fp_output_tile_int8_int8_compute.glslh | 179 ++++++ .../glsl/linear_fp_output_tile_store.glslh | 114 ++++ .../linear_fp_per_out_channel_params.glslh | 43 ++ .../glsl/linear_fp_weight_scales_load.glslh | 32 + .../ops/glsl/linear_fp_weight_tile.glslh | 103 ++++ .../ops/glsl/linear_int8_input_block.glslh | 84 +++ .../ops/glsl/linear_int8_input_tile.glslh | 62 ++ .../glsl/linear_int8_input_tile_load.glslh | 75 +++ .../ops/glsl/linear_int8_weight_block.glslh | 99 ++++ .../ops/glsl/linear_int8_weight_tile.glslh | 60 ++ .../glsl/linear_int8_weight_tile_load.glslh | 78 +++ .../linear_int_per_out_channel_params.glslh | 44 ++ .../glsl/linear_int_weight_sums_load.glslh | 32 + .../graph/ops/glsl/linear_q8csw_tiled.glsl | 113 ++++ .../graph/ops/glsl/linear_q8csw_tiled.yaml | 28 + .../ops/glsl/linear_q8ta_q8csw_tiled.glsl | 132 +++++ .../ops/glsl/linear_q8ta_q8csw_tiled.yaml | 24 + .../graph/ops/glsl/pack_q8_linear_weight.glsl | 60 ++ .../graph/ops/glsl/pack_q8_linear_weight.yaml | 14 + .../glsl/quantize_and_pack_linear_input.glsl | 79 +++ .../glsl/quantize_and_pack_linear_input.yaml | 24 + .../graph/ops/impl/QuantizedLinear.cpp | 554 ++++++++++++++++++ .../runtime/graph/ops/impl/QuantizedLinear.h | 29 + .../graph/ops/impl/utils/QuantizationConfig.h | 58 ++ .../vulkan/test/custom_ops/CMakeLists.txt | 1 + .../vulkan/test/custom_ops/q8csw_linear.cpp | 479 +++++++++++++++ backends/vulkan/test/custom_ops/targets.bzl | 5 +- backends/vulkan/test/custom_ops/utils.cpp | 30 +- backends/vulkan/test/custom_ops/utils.h | 42 +- 39 files changed, 3166 insertions(+), 21 deletions(-) create mode 100644 backends/vulkan/runtime/graph/ops/glsl/common.glslh create mode 100644 backends/vulkan/runtime/graph/ops/glsl/linear_common.glslh create mode 100644 backends/vulkan/runtime/graph/ops/glsl/linear_fp_bias_load.glslh create mode 100644 backends/vulkan/runtime/graph/ops/glsl/linear_fp_input_tile.glslh create mode 
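For reference, the math implemented by the two variants can be sketched on the CPU as follows. This is an illustrative sketch only, not code from this patch: it assumes a row-major [M, K] input, a row-major [N, K] weight, per-tensor activation quantization (scale and zero point), and symmetric per-output-channel weight quantization, and all function and variable names are hypothetical.

```cpp
#include <cstdint>
#include <vector>

// q8ta_q8csw path: int8 x int8 integer accumulation, then a single fused
// dequantization step. Folding the activation zero point out via precomputed
// per-channel weight sums gives:
//   out[m][n] = in_scale * w_scale[n] *
//               (sum_k q_in[m][k] * q_w[n][k] - in_zp * sum_k q_w[n][k]) + bias[n]
std::vector<float> linear_q8ta_q8csw_ref(
    const std::vector<int8_t>& q_in, // [M, K] quantized input
    const std::vector<int8_t>& q_w, // [N, K] quantized weight
    const std::vector<int32_t>& w_sums, // [N] precomputed sum_k q_w[n][k]
    const std::vector<float>& w_scales, // [N]
    const std::vector<float>& bias, // [N]
    float in_scale,
    int32_t in_zp,
    int M,
    int K,
    int N) {
  std::vector<float> out(M * N);
  for (int m = 0; m < M; ++m) {
    for (int n = 0; n < N; ++n) {
      int32_t accum = 0;
      for (int k = 0; k < K; ++k) {
        accum += int32_t(q_in[m * K + k]) * int32_t(q_w[n * K + k]);
      }
      out[m * N + n] =
          in_scale * w_scales[n] * float(accum - in_zp * w_sums[n]) + bias[n];
    }
  }
  return out;
}

// q8csw path: weights are dequantized on the fly and accumulated in fp, with
// the per-channel scale and bias applied after accumulation.
std::vector<float> linear_q8csw_ref(
    const std::vector<float>& in, // [M, K] fp input
    const std::vector<int8_t>& q_w, // [N, K] quantized weight
    const std::vector<float>& w_scales,
    const std::vector<float>& bias,
    int M,
    int K,
    int N) {
  std::vector<float> out(M * N);
  for (int m = 0; m < M; ++m) {
    for (int n = 0; n < N; ++n) {
      float accum = 0.0f;
      for (int k = 0; k < K; ++k) {
        accum += in[m * K + k] * float(q_w[n * K + k]);
      }
      out[m * N + n] = accum * w_scales[n] + bias[n];
    }
  }
  return out;
}
```

The precomputed weight sums are what allow the q8ta_q8csw shader to remove the activation zero point contribution from the int32 accumulator in a single step (see accumulate_out_tile_with_int_accum in the diff below).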
100644 backends/vulkan/runtime/graph/ops/glsl/linear_fp_input_tile_load.glslh create mode 100644 backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile.glslh create mode 100644 backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_fp_compute.glslh create mode 100644 backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_fp_int8_compute.glslh create mode 100644 backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_int8_int8_compute.glslh create mode 100644 backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_store.glslh create mode 100644 backends/vulkan/runtime/graph/ops/glsl/linear_fp_per_out_channel_params.glslh create mode 100644 backends/vulkan/runtime/graph/ops/glsl/linear_fp_weight_scales_load.glslh create mode 100644 backends/vulkan/runtime/graph/ops/glsl/linear_fp_weight_tile.glslh create mode 100644 backends/vulkan/runtime/graph/ops/glsl/linear_int8_input_block.glslh create mode 100644 backends/vulkan/runtime/graph/ops/glsl/linear_int8_input_tile.glslh create mode 100644 backends/vulkan/runtime/graph/ops/glsl/linear_int8_input_tile_load.glslh create mode 100644 backends/vulkan/runtime/graph/ops/glsl/linear_int8_weight_block.glslh create mode 100644 backends/vulkan/runtime/graph/ops/glsl/linear_int8_weight_tile.glslh create mode 100644 backends/vulkan/runtime/graph/ops/glsl/linear_int8_weight_tile_load.glslh create mode 100644 backends/vulkan/runtime/graph/ops/glsl/linear_int_per_out_channel_params.glslh create mode 100644 backends/vulkan/runtime/graph/ops/glsl/linear_int_weight_sums_load.glslh create mode 100644 backends/vulkan/runtime/graph/ops/glsl/linear_q8csw_tiled.glsl create mode 100644 backends/vulkan/runtime/graph/ops/glsl/linear_q8csw_tiled.yaml create mode 100644 backends/vulkan/runtime/graph/ops/glsl/linear_q8ta_q8csw_tiled.glsl create mode 100644 backends/vulkan/runtime/graph/ops/glsl/linear_q8ta_q8csw_tiled.yaml create mode 100644 backends/vulkan/runtime/graph/ops/glsl/pack_q8_linear_weight.glsl create mode 100644 backends/vulkan/runtime/graph/ops/glsl/pack_q8_linear_weight.yaml create mode 100644 backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_linear_input.glsl create mode 100644 backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_linear_input.yaml create mode 100644 backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp create mode 100644 backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.h create mode 100644 backends/vulkan/runtime/graph/ops/impl/utils/QuantizationConfig.h create mode 100644 backends/vulkan/test/custom_ops/q8csw_linear.cpp diff --git a/.github/workflows/pull.yml b/.github/workflows/pull.yml index 379d47716c9..084cfe17a4d 100644 --- a/.github/workflows/pull.yml +++ b/.github/workflows/pull.yml @@ -929,7 +929,9 @@ jobs: CMAKE_ARGS="-DEXECUTORCH_BUILD_VULKAN=ON" \ .ci/scripts/setup-linux.sh --build-tool "cmake" + # Custom operator tests PYTHON_EXECUTABLE=python bash backends/vulkan/test/custom_ops/build_and_run.sh add + ./cmake-out/backends/vulkan/test/custom_ops/q8csw_linear nxp-build-test: name: nxp-build-test diff --git a/backends/vulkan/runtime/graph/ComputeGraph.h b/backends/vulkan/runtime/graph/ComputeGraph.h index 78fb79e65e8..4e9e2d36e1e 100644 --- a/backends/vulkan/runtime/graph/ComputeGraph.h +++ b/backends/vulkan/runtime/graph/ComputeGraph.h @@ -308,6 +308,10 @@ class ComputeGraph final { return idx == kDummyValueRef ? 
true : values_.at(idx).isNone(); } + inline bool val_is_not_none(const ValueRef idx) { + return !val_is_none(idx); + } + inline TypeTag get_val_type(const ValueRef idx) { return values_.at(idx).type(); } diff --git a/backends/vulkan/runtime/graph/ops/glsl/common.glslh b/backends/vulkan/runtime/graph/ops/glsl/common.glslh new file mode 100644 index 00000000000..732b7006c2c --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/common.glslh @@ -0,0 +1,51 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#ifndef COMMON_GLSLH +#define COMMON_GLSLH + +#define mul_2(x) ((x) << 1) +#define mul_4(x) ((x) << 2) +#define mul_8(x) ((x) << 3) + +#define div_2(x) ((x) >> 1) +#define div_4(x) ((x) >> 2) +#define div_8(x) ((x) >> 3) + +#define div_up_2(x) (((x) + 1) >> 1) +#define div_up_4(x) (((x) + 3) >> 2) +#define div_up_8(x) (((x) + 7) >> 3) + +#define align_up_2(x) ((x + 1) & -2) +#define align_up_4(x) ((x + 3) & -4) +#define align_up_8(x) ((x + 7) & -8) + +#define mod_2(x) ((x) & 1) +#define mod_4(x) ((x) & 3) +#define mod_8(x) ((x) & 7) + +struct TensorIndex4D { + ivec4 data; +}; + +#ifdef DEBUG_MODE + +#extension GL_EXT_debug_printf : require + +void printTensorIndex4D(const TensorIndex4D index) { + debugPrintfEXT( + "tensor_idx: %d, %d, %d, %d\\n", + index.data.x, + index.data.y, + index.data.z, + index.data.w); +} + +#endif // DEBUG_MODE + +#endif // COMMON_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_common.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_common.glslh new file mode 100644 index 00000000000..90ede450ae7 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_common.glslh @@ -0,0 +1,32 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +/* + * Defines common functions and structs to be used across matrix multiplication + * operators. + */ + +#ifndef LINEAR_COMMON_GLSLH +#define LINEAR_COMMON_GLSLH + +#include "common.glslh" + +int sign_extend_8bit(const int val) { + if ((val & 0x80) != 0) { + return val | (~0xFF); + } + return val; +} + +int extract_8bit_from_packed_int_le(const int packed, const int i) { + // account for little endian + int byte = sign_extend_8bit(packed >> (8 * i) & 0xFF); + return byte; +} + +#endif // LINEAR_COMMON_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_fp_bias_load.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_bias_load.glslh new file mode 100644 index 00000000000..f3d32be8b3d --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_bias_load.glslh @@ -0,0 +1,30 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#ifndef LINEAR_FP_BIAS_LOAD_GLSLH +#define LINEAR_FP_BIAS_LOAD_GLSLH + +#include "linear_fp_per_out_channel_params.glslh" + +VEC4_T load_bias_x4(const int n4) { + return t_bias[n4]; +} + +void load_bias_tile(out FPPerOutChannelParams bias, const int n4_start) { +#if TILE_N4 == 1 + bias.data[0] = load_bias_x4(n4_start); + +#else + [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) { + bias.data[n4] = load_bias_x4(n4_start + n4); + } + +#endif +} + +#endif // LINEAR_FP_BIAS_LOAD_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_fp_input_tile.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_input_tile.glslh new file mode 100644 index 00000000000..68eee57a132 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_input_tile.glslh @@ -0,0 +1,45 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#ifndef LINEAR_FP_INPUT_TILE_GLSLH +#define LINEAR_FP_INPUT_TILE_GLSLH + +/* + * Defines the FPInputTile struct, which is used to represent a tile of the + * input matrix of a matrix multiplication operation. + * + * Settings: + * - TILE_M: number of rows in the tile + * - TILE_K4: number of (groups of 4) columns in the tile + */ + +#extension GL_EXT_control_flow_attributes : require + +struct FPInputTile { + VEC4_T data[TILE_M][TILE_K4]; +}; + +#ifdef DEBUG_MODE + +void printFPInputTile(const FPInputTile in_tile) { + debugPrintfEXT("input_tile: \\n"); + [[unroll]] for (int m = 0; m < TILE_M; ++m) { + [[unroll]] for (int k4 = 0; k4 < TILE_K4; ++k4) { + debugPrintfEXT( + " %f, %f, %f, %f, \\n", + in_tile.data[m][k4].x, + in_tile.data[m][k4].y, + in_tile.data[m][k4].z, + in_tile.data[m][k4].w); + } + } +} + +#endif // DEBUG_MODE + +#endif // LINEAR_FP_INPUT_TILE_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_fp_input_tile_load.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_input_tile_load.glslh new file mode 100644 index 00000000000..6697003935f --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_input_tile_load.glslh @@ -0,0 +1,91 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +/* + * Defines functions to load a FPInputTile from input buffer/texture. + * + * Requires: + * - t_input to be declared in the shader layout (input buffer/texture) + * + * Settings: + * - INPUT_BUFFER to indicate input resource is a buffer, otherwise texture is + * assumed. 
+ */ + +#ifndef LINEAR_FP_INPUT_TILE_LOAD_GLSLH +#define LINEAR_FP_INPUT_TILE_LOAD_GLSLH + +#extension GL_EXT_control_flow_attributes : require + +#include "linear_fp_input_tile.glslh" + +#ifdef INPUT_BUFFER + +VEC4_T load_input_x4(const int k4, const int m, const int ntexels_k) { + return t_input[(m * ntexels_k) + k4]; +} + +#else + +VEC4_T load_input_x4(const int k4, const int m, const int ntexels_k) { + return texelFetch(t_input, ivec3(k4, m, 0), 0); +} + +#endif // INPUT_BUFFER + +// To be used if (M - m_start >= TILE_M) || (K4 - k4_start >= TILE_K4) +void load_input_tile_no_checks( + out FPInputTile in_tile, + const int k4_start, + const int m_start, + const int K4, + const int M) { +#if TILE_K4 == 1 + [[unroll]] for (int m = 0; m < TILE_M; ++m) { + in_tile.data[m][0] = load_input_x4(k4_start, m_start + m, K4); + } + +#else + [[unroll]] for (int m = 0; m < TILE_M; ++m) { + [[unroll]] for (int k4 = 0; k4 < TILE_K4; ++k4) { + in_tile.data[m][k4] = load_input_x4(k4_start + k4, m_start + m, K4); + } + } +#endif +} + +// To be used if near tensor boundaries +void load_input_tile_with_checks( + out FPInputTile in_tile, + const int k4_start, + const int m_start, + const int K4, + const int M) { +#if TILE_K4 == 1 + [[unroll]] for (int m = 0; m < TILE_M; ++m) { + if (m_start + m < M) { + in_tile.data[m][0] = load_input_x4(k4_start, m_start + m, K4); + } else { + in_tile.data[m][0] = VEC4_T(0.0); + } + } + +#else + [[unroll]] for (int m = 0; m < TILE_M; ++m) { + [[unroll]] for (int k4 = 0; k4 < TILE_K4; ++k4) { + if (m_start + m < M && k4_start + k4 < K4) { + in_tile.data[m][k4] = load_input_x4(k4_start + k4, m_start + m, K4); + } else { + in_tile.data[m][k4] = VEC4_T(0.0); + } + } + } +#endif +} + +#endif // LINEAR_FP_INPUT_TILE_LOAD_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile.glslh new file mode 100644 index 00000000000..049f1d34caf --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile.glslh @@ -0,0 +1,61 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +/* + * Defines the FPOutTile struct, which is used to represent a tile of the output + * matrix of a matrix multiplication operation. 
+ * + * Settings: + * - TILE_M: number of rows in the output tile + * - TILE_N4: number of (groups of 4) columns in the output tile + */ + +#ifndef LINEAR_FP_OUTPUT_TILE_GLSLH +#define LINEAR_FP_OUTPUT_TILE_GLSLH + +#extension GL_EXT_control_flow_attributes : require + +struct FPOutTile { + VEC4_T data[TILE_M][TILE_N4]; +}; + +void initialize(out FPOutTile out_tile) { +#if TILE_N4 == 1 + [[unroll]] for (int m = 0; m < TILE_M; ++m) { + out_tile.data[m][0] = VEC4_T(0); + } + +#else + [[unroll]] for (int m = 0; m < TILE_M; ++m) { + [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) { + out_tile.data[m][n4] = VEC4_T(0); + } + } +#endif +} + +#ifdef DEBUG_MODE + +void printFPOutTile(const FPOutTile tile) { + debugPrintfEXT("output_tile: \\n"); + [[unroll]] for (int m = 0; m < TILE_M; ++m) { + [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) { + debugPrintfEXT( + " %f, %f, %f, %f,", + tile.data[m][n4].x, + tile.data[m][n4].y, + tile.data[m][n4].z, + tile.data[m][n4].w); + } + debugPrintfEXT("\\n"); + } +} + +#endif // DEBUG_MODE + +#endif // LINEAR_FP_OUTPUT_TILE_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_fp_compute.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_fp_compute.glslh new file mode 100644 index 00000000000..7229da32cd3 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_fp_compute.glslh @@ -0,0 +1,126 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +/* + * Defines functions to compute a FPOutTile using fp input and weight tiles. + */ + +#ifndef LINEAR_FP_OUTPUT_TILE_FP_COMPUTE_GLSLH +#define LINEAR_FP_OUTPUT_TILE_FP_COMPUTE_GLSLH + +#extension GL_EXT_control_flow_attributes : require + +#include "linear_common.glslh" +#include "linear_fp_input_tile.glslh" +#include "linear_fp_output_tile.glslh" +#include "linear_fp_per_out_channel_params.glslh" +#include "linear_fp_weight_tile.glslh" + +/* + * Accumulates floating point input tile and floating point weight tile into + * floating point output tile. 
+ */ +void fp_accumulate_with_fp_weight( + inout FPOutTile accum, + FPInputTile in_tile, + FPWeightTile w_tile) { +#if TILE_N4 == 1 && TILE_K4 == 1 + [[unroll]] for (int m = 0; m < TILE_M; ++m) { + accum.data[m][0] = + fma(VEC4_T(in_tile.data[m][0][0]), + w_tile.data[mul_4(0)][0], + accum.data[m][0]); + + accum.data[m][0] = + fma(VEC4_T(in_tile.data[m][0][1]), + w_tile.data[mul_4(0) + 1][0], + accum.data[m][0]); + + accum.data[m][0] = + fma(VEC4_T(in_tile.data[m][0][2]), + w_tile.data[mul_4(0) + 2][0], + accum.data[m][0]); + + accum.data[m][0] = + fma(VEC4_T(in_tile.data[m][0][3]), + w_tile.data[mul_4(0) + 3][0], + accum.data[m][0]); + } + +#else + [[unroll]] for (int m = 0; m < TILE_M; ++m) { + [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) { + const int n = mul_4(n4); + [[unroll]] for (int k4 = 0; k4 < TILE_K4; ++k4) { + accum.data[m][n4] = + fma(VEC4_T(in_tile.data[m][k4][0]), + w_tile.data[mul_4(k4)][n4], + accum.data[m][n4]); + + accum.data[m][n4] = + fma(VEC4_T(in_tile.data[m][k4][1]), + w_tile.data[mul_4(k4) + 1][n4], + accum.data[m][n4]); + + accum.data[m][n4] = + fma(VEC4_T(in_tile.data[m][k4][2]), + w_tile.data[mul_4(k4) + 2][n4], + accum.data[m][n4]); + + accum.data[m][n4] = + fma(VEC4_T(in_tile.data[m][k4][3]), + w_tile.data[mul_4(k4) + 3][n4], + accum.data[m][n4]); + } + } + } + +#endif +} + +/* + * Applies per output channel weight scales to the output tile. + */ +void apply_scales(inout FPOutTile tile, const FPPerOutChannelParams scales) { +#if TILE_M > 1 && TILE_N4 == 1 + [[unroll]] for (int m = 0; m < TILE_M; ++m) { + tile.data[m][0] = tile.data[m][0] * scales.data[0]; + } + +#else + [[unroll]] for (int m = 0; m < TILE_M; ++m) { + [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) { + tile.data[m][n4] = tile.data[m][n4] * scales.data[n4]; + } + } +#endif +} + +/* + * Applies per output channel weight scales and per output channel biases to the + * output tile. + */ +void apply_scales_and_biases( + inout FPOutTile tile, + const FPPerOutChannelParams scales, + const FPPerOutChannelParams bias) { +#if TILE_M > 1 && TILE_N4 == 1 + [[unroll]] for (int m = 0; m < TILE_M; ++m) { + tile.data[m][0] = tile.data[m][0] * scales.data[0] + bias.data[0]; + } + +#else + [[unroll]] for (int m = 0; m < TILE_M; ++m) { + [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) { + tile.data[m][n4] = tile.data[m][n4] * scales.data[n4] + bias.data[n4]; + } + } +#endif +} + +#endif // LINEAR_FP_OUTPUT_TILE_FP_COMPUTE_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_fp_int8_compute.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_fp_int8_compute.glslh new file mode 100644 index 00000000000..b2ab64a1573 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_fp_int8_compute.glslh @@ -0,0 +1,68 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +/* + * Defines functions to compute a FPOutTile using fp input and weight tiles. 
+ */ + +#ifndef LINEAR_FP_OUTPUT_TILE_FP_INT8_COMPUTE_GLSLH +#define LINEAR_FP_OUTPUT_TILE_FP_INT8_COMPUTE_GLSLH + +#extension GL_EXT_control_flow_attributes : require + +#include "linear_common.glslh" +#include "linear_fp_input_tile.glslh" +#include "linear_fp_output_tile.glslh" +#include "linear_int8_weight_tile.glslh" + +// Unpacks a int containing 4 packed 8-bit integers into a vec4 containing each +// of the 4 unpacked 8-bit integers. +VEC4_T unpack_packed_4xint8(int int8x4) { + return VEC4_T( + extract_8bit_from_packed_int_le(int8x4, 0), + extract_8bit_from_packed_int_le(int8x4, 1), + extract_8bit_from_packed_int_le(int8x4, 2), + extract_8bit_from_packed_int_le(int8x4, 3)); +} + +void fp_accumulate_with_int8_weight( + inout FPOutTile accum, + FPInputTile in_tile, + Int8WeightTile w_tile) { + // Accum tile is indexed as accum[m][n4][n4i] + // -> gives fp accumulator for output tile element at (x = n, y = m) + // Input tile is indexed as in_tile.data[m][k4] + // -> gives vec4 containing the fp inputs at index + // (k, m), (k + 1, m), (k + 2, m), (k + 3, m) + // Weight tile is indexed as w_tile.data[k4][n4][n4i] + // -> gives packed integer containing the 4x 8-bit quantized values at index + // (n, k), (n, k + 1), (n, k + 2), (n, k + 3) + VEC4_T weight_texel; +#if TILE_K4 == 1 && TILE_N4 == 1 + [[unroll]] for (int k = 0; k < 4; ++k) { + // Unpack one column of weights + weight_texel = VEC4_T( + extract_8bit_from_packed_int_le(w_tile.data[0][0][0], k), + extract_8bit_from_packed_int_le(w_tile.data[0][0][1], k), + extract_8bit_from_packed_int_le(w_tile.data[0][0][2], k), + extract_8bit_from_packed_int_le(w_tile.data[0][0][3], k)); + + for (int m = 0; m < TILE_M; ++m) { + accum.data[m][0] = + fma(VEC4_T(in_tile.data[m][0][k]), weight_texel, accum.data[m][0]); + } + } + +#else + // TODO(ssjia): implement the general case + not implemented + +#endif +} + +#endif // LINEAR_FP_OUTPUT_TILE_FP_INT8_COMPUTE_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_int8_int8_compute.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_int8_int8_compute.glslh new file mode 100644 index 00000000000..b04074eba75 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_int8_int8_compute.glslh @@ -0,0 +1,179 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +/* + * Defines functions to compute a FPOutTile using int8 input and weight tiles. + * + * Settings: + * - TILE_M: The number of rows in the output tile. + * - TILE_N4: The number of (groups of 4) columns in the output tile. + */ + +#ifndef LINEAR_FP_OUTPUT_TILE_INT8_INT8_COMPUTE_GLSLH +#define LINEAR_FP_OUTPUT_TILE_INT8_INT8_COMPUTE_GLSLH + +#extension GL_EXT_control_flow_attributes : require +#extension GL_EXT_integer_dot_product : require + +#include "linear_common.glslh" +#include "linear_fp_output_tile.glslh" +#include "linear_fp_per_out_channel_params.glslh" +#include "linear_int8_input_tile.glslh" +#include "linear_int8_weight_tile.glslh" +#include "linear_int_per_out_channel_params.glslh" + +// Stores integer accumulators for an output tile. 
+struct Int32Accum { + ivec4 data[TILE_M][TILE_N4]; +}; + +// Initialize values to 0 +void initialize(out Int32Accum out_accum) { +#if TILE_N4 == 1 + [[unroll]] for (int y = 0; y < TILE_M; ++y) { + out_accum.data[y][0] = ivec4(0); + } + +#else + [[unroll]] for (int y = 0; y < TILE_M; ++y) { + [[unroll]] for (int x4 = 0; x4 < TILE_K4; ++x4) { + out_accum.data[y][x4] = ivec4(0); + } + } +#endif +} + +// Accumulate int8 input and weight tiles into integer accumulator tile +void int_accumulate_with_int8_weight( + inout Int32Accum accum, + Int8InputTile in_tile, + Int8WeightTile w_tile) { + // Accum tile is indexed as accum[m][n4][n4i] + // -> gives integer accumulator for output tile element at (x = n, y = m) + // Input tile is indexed as in_tile.data[m4][k4][m4i] + // -> gives packed integer containing the 4x 8-bit quantized values at index + // (k, m), (k + 1, m), (k + 2, m), (k + 3, m) + // Weight tile is indexed as w_tile.data[k4][n4][n4i] + // -> gives packed integer containing the 4x 8-bit quantized values at index + // (n, k), (n, k + 1), (n, k + 2), (n, k + 3) +#if TILE_M4 == 1 && TILE_K4 == 1 && TILE_N4 == 1 + [[unroll]] for (int m = 0; m < TILE_M; ++m) { + // n = 0 + accum.data[m][0][0] = dotPacked4x8AccSatEXT( + in_tile.data[0][0][m], w_tile.data[0][0][0], accum.data[m][0][0]); + // n = 1 + accum.data[m][0][1] = dotPacked4x8AccSatEXT( + in_tile.data[0][0][m], w_tile.data[0][0][1], accum.data[m][0][1]); + // n = 2 + accum.data[m][0][2] = dotPacked4x8AccSatEXT( + in_tile.data[0][0][m], w_tile.data[0][0][2], accum.data[m][0][2]); + // n = 3 + accum.data[m][0][3] = dotPacked4x8AccSatEXT( + in_tile.data[0][0][m], w_tile.data[0][0][3], accum.data[m][0][3]); + } + +#else + [[unroll]] for (int m = 0; m < TILE_M; ++m) { + const int m4 = div_4(m); + const int m4i = mod_4(m); + [[unroll]] for (int n = 0; n < TILE_N; ++n) { + const int n4 = div_4(n); + const int n4i = mod_4(n); + [[unroll]] for (int k4 = 0; k4 < TILE_K4; ++k4) { + accum.data[m][n4][n4i] = dotPacked4x8AccSatEXT( + in_tile.data[m4][k4][m4i], + w_tile.data[k4][n4][n4i], + accum.data[m][n4][n4i]); + } + } + } + +#endif +} + +/* + * Computes final weight matrix output tile using: + * - int8 accumulator tile + * - per output channel weight sums + * - per output channel scales + */ +void accumulate_out_tile_with_int_accum( + inout FPOutTile out_tile, + const Int32Accum accum, + const float input_q_scale, + const int input_q_zp, + const IntPerOutChannelParams weight_sums, + const FPPerOutChannelParams weight_scales) { + ivec4 input_zp_vec = ivec4(-input_q_zp); +#if TILE_N4 == 1 + [[unroll]] for (int m = 0; m < TILE_M; ++m) { + // Unfortunately fma doesn't work with ivec4. Prefer to preserve integer + // format for as long as possible to avoid precision loss. 
+ ivec4 accum_adjusted = + input_zp_vec * weight_sums.data[0] + accum.data[m][0]; + out_tile.data[m][0] = + fma(VEC4_T(accum_adjusted), + input_q_scale * weight_scales.data[0], + out_tile.data[m][0]); + } + +#else + // TODO(ssjia): Implement the general case + not implemented + +#endif +} + +void accumulate_out_tile_with_int_accum( + inout FPOutTile out_tile, + const Int32Accum accum, + const float input_q_scale, + const int input_q_zp, + const IntPerOutChannelParams weight_sums, + const FPPerOutChannelParams weight_scales, + const FPPerOutChannelParams bias) { + ivec4 input_zp_vec = ivec4(-input_q_zp); +#if TILE_N4 == 1 + [[unroll]] for (int m = 0; m < TILE_M; ++m) { + // Apply scale and zero points to the int accumulator + ivec4 accum_adjusted = + input_zp_vec * weight_sums.data[0] + accum.data[m][0]; + out_tile.data[m][0] = + fma(VEC4_T(accum_adjusted), + input_q_scale * weight_scales.data[0], + out_tile.data[m][0]); + out_tile.data[m][0] += bias.data[0]; + } + +#else + // TODO(ssjia): Implement the general case + not implemented + +#endif +} + +#ifdef DEBUG_MODE + +void printInt32Accum(const Int32Accum tile) { + debugPrintfEXT("int accum: \\n"); + [[unroll]] for (int m = 0; m < TILE_M; ++m) { + [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) { + debugPrintfEXT( + " %d, %d, %d, %d,", + tile.data[m][n4].x, + tile.data[m][n4].y, + tile.data[m][n4].z, + tile.data[m][n4].w); + } + debugPrintfEXT("\\n"); + } +} + +#endif + +#endif // LINEAR_FP_OUTPUT_TILE_INT8_INT8_COMPUTE_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_store.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_store.glslh new file mode 100644 index 00000000000..a4019204cc3 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_store.glslh @@ -0,0 +1,114 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +/* + * Defines functions store a FpOutTile to output buffer/texture. + * + * Requires: + * - t_output to be declared in the shader layout + * + * Settings: + * - OUTPUT_BUFFER to indicate t_output is a vec4 buffer, otherwise texture + * storage is assumed. 
+ */ + +#ifndef LINEAR_FP_OUTPUT_TILE_STORE_GLSLH +#define LINEAR_FP_OUTPUT_TILE_STORE_GLSLH + +#extension GL_EXT_control_flow_attributes : require + +#include "linear_fp_output_tile.glslh" + +#ifdef OUTPUT_BUFFER + +void write_output_x4( + const VEC4_T out_texel, + const int n4, + const int m, + const int N4) { + t_output[m * N4 + n4] = out_texel; +} + +#else + +void write_output_x4( + const VEC4_T out_texel, + const int n4, + const int m, + const int N4) { + imageStore(t_output, ivec3(n4, m, 0), out_texel); +} + +#endif // OUTPUT_BUFFER + +void write_output_tile( + const FPOutTile out_tile, + const int n4_start, + const int m_start, + const int N4) { +#if TILE_K4 == 1 + [[unroll]] for (int m = 0; m < TILE_M; ++m) { + write_output_x4(out_tile.data[m][0], n4_start, m_start + m, N4); + } + +#else + [[unroll]] for (int m = 0; m < TILE_M; ++m) { + [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) { + write_output_x4(out_tile.data[m][n4], n4_start + n4, m_start + m, N4); + } + } +#endif +} + +// To be used if M - m >= TILE_M && N4 - n4 >= TILE_N4 +void write_output_tile_no_checks( + const FPOutTile out_tile, + const int n4_start, + const int m_start, + const int N4, + const int M) { +#if TILE_N4 == 1 + [[unroll]] for (int m = 0; m < TILE_M; ++m) { + write_output_x4(out_tile.data[m][0], n4_start, m_start + m, N4); + } + +#else + [[unroll]] for (int m = 0; m < TILE_M; ++m) { + [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) { + write_output_x4(out_tile.data[m][n4], n4_start + n4, m_start + m, N4); + } + } +#endif +} + +// To be used if close to tensor boundaries +void write_output_tile_with_checks( + const FPOutTile out_tile, + const int n4_start, + const int m_start, + const int N4, + const int M) { +#if TILE_N4 == 1 + [[unroll]] for (int m = 0; m < TILE_M; ++m) { + if (m_start + m < M) { + write_output_x4(out_tile.data[m][0], n4_start, m_start + m, N4); + } + } + +#else + [[unroll]] for (int m = 0; m < TILE_M; ++m) { + [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) { + if (m_start + m < M && n4_start + n4 < N4) { + write_output_x4(out_tile.data[m][n4], n4_start + n4, m_start + m, N4); + } + } + } +#endif +} + +#endif // LINEAR_FP_OUTPUT_TILE_STORE_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_fp_per_out_channel_params.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_per_out_channel_params.glslh new file mode 100644 index 00000000000..72b22988414 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_per_out_channel_params.glslh @@ -0,0 +1,43 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +/* + * Defines common functions and structs to be used across matrix multiplication + * operators. + */ + +#ifndef LINEAR_FP_PER_OUT_CHANNEL_PARAMS_GLSLH +#define LINEAR_FP_PER_OUT_CHANNEL_PARAMS_GLSLH + +#include "common.glslh" + +#extension GL_EXT_control_flow_attributes : require + +// Represents floating point parameter tensors where each element is associated +// with an output channel, such as weight scales, biases, etc. 
+struct FPPerOutChannelParams { + VEC4_T data[TILE_N4]; +}; + +#ifdef DEBUG_MODE + +void printFPPerOutChannelParams(const FPPerOutChannelParams params) { + debugPrintfEXT("per_out_channel_params: \\n"); + [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) { + debugPrintfEXT( + " %f, %f, %f, %f, \\n", + params.data[n4].x, + params.data[n4].y, + params.data[n4].z, + params.data[n4].w); + } +} + +#endif // DEBUG_MODE + +#endif // LINEAR_FP_PER_OUT_CHANNEL_PARAMS_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_fp_weight_scales_load.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_weight_scales_load.glslh new file mode 100644 index 00000000000..0cba49e87c7 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_weight_scales_load.glslh @@ -0,0 +1,32 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#ifndef LINEAR_FP_WEIGHT_SCALES_LOAD_GLSLH +#define LINEAR_FP_WEIGHT_SCALES_LOAD_GLSLH + +#include "linear_fp_per_out_channel_params.glslh" + +VEC4_T load_weight_scale_x4(const int n4) { + return t_weight_scales[n4]; +} + +void load_weight_scales_tile( + out FPPerOutChannelParams scales, + const int n4_start) { +#if TILE_N4 == 1 + scales.data[0] = load_weight_scale_x4(n4_start); + +#else + [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) { + scales.data[n4] = load_weight_scale_x4(n4_start + n4); + } + +#endif +} + +#endif // LINEAR_FP_WEIGHT_SCALES_LOAD_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_fp_weight_tile.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_weight_tile.glslh new file mode 100644 index 00000000000..f44bbbc1565 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_weight_tile.glslh @@ -0,0 +1,103 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +/* + * Defines the FPWeightTile struct, which is used to represent a fp tile of a + * weight matrix in matrix multiplication. 
+ * + * Settings: + * - TILE_K: number of rows in the weight tile + * - TILE_N4: number of (groups of 4) columns in the weight tile + */ + +#ifndef LINEAR_FP_WEIGHT_TILE_GLSLH +#define LINEAR_FP_WEIGHT_TILE_GLSLH + +#extension GL_EXT_control_flow_attributes : require + +#include "common.glslh" + +struct FPWeightTile { + VEC4_T data[TILE_K][TILE_N4]; +}; + +#ifdef LINEAR_INT8_WEIGHT_TILE_GLSLH + +int sign_extend(const int val) { + if ((val & 0x80) != 0) { + return val | (~0xFF); + } + return val; +} + +T extract_8bit_value(const Int8WeightTile w_tile, const int k, const int n) { +#if TILE_K4 == 1 && TILE_N4 == 1 + const int k4i = k; + const int n4i = n; + ivec4 block = w_tile.data[0][0]; + +#else + const int k4 = div_4(k); + const int k4i = mod_4(k); + + const int n4 = div_4(n); + const int n4i = mod_4(n); + + ivec4 block = w_tile.data[k4][n4]; +#endif + + int col = block[n4i]; + int val = (col >> (k4i * 8)) & 0xFF; + + return T(sign_extend(val)); +} + +void unpack(out FPWeightTile fp_w_tile, const Int8WeightTile w_tile) { +#if TILE_K > 1 && TILE_N4 == 1 + [[unroll]] for (int k = 0; k < TILE_K; ++k) { + fp_w_tile.data[k][0][0] = extract_8bit_value(w_tile, k, 0); + fp_w_tile.data[k][0][1] = extract_8bit_value(w_tile, k, 1); + fp_w_tile.data[k][0][2] = extract_8bit_value(w_tile, k, 2); + fp_w_tile.data[k][0][3] = extract_8bit_value(w_tile, k, 3); + } + +#else + [[unroll]] for (int k = 0; k < TILE_K; ++k) { + [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) { + const int n = mul_4(n4); + fp_w_tile.data[k][n4][0] = extract_8bit_value(w_tile, k, n); + fp_w_tile.data[k][n4][1] = extract_8bit_value(w_tile, k, n + 1); + fp_w_tile.data[k][n4][2] = extract_8bit_value(w_tile, k, n + 2); + fp_w_tile.data[k][n4][3] = extract_8bit_value(w_tile, k, n + 3); + } + } +#endif +} + +#endif // LINEAR_INT8_WEIGHT_TILE_GLSLH + +#ifdef DEBUG_MODE + +void printFPWeightTile(const FPWeightTile tile) { + debugPrintfEXT("weight_tile: \\n"); + [[unroll]] for (int k = 0; k < TILE_K; ++k) { + [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) { + debugPrintfEXT( + " %f, %f, %f, %f, ", + tile.data[k][n4].x, + tile.data[k][n4].y, + tile.data[k][n4].z, + tile.data[k][n4].w); + } + debugPrintfEXT("\\n"); + } +} + +#endif // DEBUG_MODE + +#endif // LINEAR_FP_WEIGHT_TILE_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_int8_input_block.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_int8_input_block.glslh new file mode 100644 index 00000000000..9535de21f7b --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_int8_input_block.glslh @@ -0,0 +1,84 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +/* + * This file defines utilities to perform int8 quantization and block packing of + * matrix multiplication inputs. It also defines utilities to store packed block + * data to an output buffer or texture. + * + * Requires: + * - t_packed_int8_input to be defined in shader layout (output buffer/texture) + * + * Settings: + * - OUTPUT_BUFFER to indicate if output resource is a buffer. Otherwise texture + * is assumed.
+ */ + +#ifndef LINEAR_INT8_INPUT_BLOCK_GLSLH +#define LINEAR_INT8_INPUT_BLOCK_GLSLH + +#define TILE_M 4 +#define TILE_K4 1 + +#include "linear_fp_input_tile.glslh" + +struct Int8InputBlock { + ivec4 data; +}; + +ivec4 quantize( + const VEC4_T val, + const float q_inv_scale, + const int q_zero_point) { + vec4 quantized = round(vec4(val) * q_inv_scale) + q_zero_point; + // hard-code 8 bit quantization range + return clamp(ivec4(quantized), -128, 127); +} + +int pack_into_int32(const ivec4 quant_vals) { + int packed = ((quant_vals[0] & 0xFF) << 0) | ((quant_vals[1] & 0xFF) << 8) | + ((quant_vals[2] & 0xFF) << 16) | ((quant_vals[3] & 0xFF) << 24); + + return packed; +} + +void quantize_and_pack( + out Int8InputBlock packed, + const FPInputTile in_block, + const float q_inv_scale, + const int q_zero_point) { + for (int row = 0; row < 4; ++row) { + ivec4 quantized_inputs = + quantize(in_block.data[row][0], q_inv_scale, q_zero_point); + packed.data[row] = pack_into_int32(quantized_inputs); + } +} + +#ifdef OUTPUT_BUFFER + +void write_block( + const Int8InputBlock block, + const int block_x, + const int block_y, + const int nblocks_x) { + t_packed_int8_input[block_y * nblocks_x + block_x] = block.data; +} + +#else // OUTPUT_TEXTURE + +void write_block( + const Int8InputBlock block, + const int block_x, + const int block_y, + const int nblocks_x) { + imageStore(t_packed_int8_input, ivec3(block_x, block_y, 0), block.data); +} + +#endif // OUTPUT_BUFFER + +#endif // LINEAR_INT8_INPUT_BLOCK_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_int8_input_tile.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_int8_input_tile.glslh new file mode 100644 index 00000000000..89a7e1b3f89 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_int8_input_tile.glslh @@ -0,0 +1,62 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +/* + * Defines the Int8InputTile struct, which is used to represent a tile of the + * quantized int8 input matrix of a quantized matrix multiplication operation. 
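The byte-level packing used by quantize_and_pack above, together with the sign-extending extraction in linear_common.glslh, amounts to storing four int8 values little-endian in one 32-bit word. A small illustrative C++ sketch follows; the function names here are hypothetical and not part of the patch:

```cpp
#include <cassert>
#include <cstdint>

// Pack four int8 values into one 32-bit word, lowest byte first
// (little-endian), mirroring pack_into_int32 in the shader header above.
int32_t pack_4xint8(int8_t v0, int8_t v1, int8_t v2, int8_t v3) {
  uint32_t packed = (uint32_t(uint8_t(v0)) << 0) |
      (uint32_t(uint8_t(v1)) << 8) | (uint32_t(uint8_t(v2)) << 16) |
      (uint32_t(uint8_t(v3)) << 24);
  return int32_t(packed);
}

// Extract byte i (0 = least significant) and sign-extend it back to int8,
// mirroring extract_8bit_from_packed_int_le in linear_common.glslh.
int8_t extract_int8(int32_t packed, int i) {
  return int8_t(uint8_t((packed >> (8 * i)) & 0xFF));
}

int main() {
  const int32_t packed = pack_4xint8(-128, -1, 0, 127);
  assert(extract_int8(packed, 0) == -128);
  assert(extract_int8(packed, 1) == -1);
  assert(extract_int8(packed, 2) == 0);
  assert(extract_int8(packed, 3) == 127);
  return 0;
}
```

Keeping four quantized values per 32-bit word is also what lets the q8ta_q8csw shader hand the packed operands directly to dotPacked4x8AccSatEXT.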
+ * + * Settings: + * - TILE_M4: number of (groups of 4) rows in the tile + * - TILE_K4: number of (groups of 4) columns in the tile + */ + +#ifndef LINEAR_INT8_INPUT_TILE_GLSLH +#define LINEAR_INT8_INPUT_TILE_GLSLH + +#extension GL_EXT_control_flow_attributes : require + +struct Int8InputTile { + ivec4 data[TILE_M4][TILE_K4]; +}; + +#ifdef DEBUG_MODE + +#include "linear_common.glslh" + +void printInt8InputTile(const Int8InputTile tile) { + debugPrintfEXT( + "Int8InputTile [TILE_M4=%d][TILE_K4=%d]:\\n", TILE_M4, TILE_K4); + + [[unroll]] for (int m4 = 0; m4 < TILE_M4; ++m4) { + [[unroll]] for (int k4 = 0; k4 < TILE_K4; ++k4) { + debugPrintfEXT(" tile[%d][%d] (ivec4): ", m4, k4); + + // Each ivec4 contains 4 packed integers, each integer contains 4 8-bit + // values + [[unroll]] for (int vec_idx = 0; vec_idx < 4; ++vec_idx) { + int packed_int = tile.data[m4][k4][vec_idx]; + debugPrintfEXT("packed_int[%d]=%d -> [", vec_idx, packed_int); + + // Extract 4 8-bit values from this packed integer + [[unroll]] for (int byte_idx = 0; byte_idx < 4; ++byte_idx) { + int val = extract_8bit_from_packed_int_le(packed_int, byte_idx); + if (byte_idx < 3) { + debugPrintfEXT("%d, ", val); + } else { + debugPrintfEXT("%d] ", val); + } + } + } + debugPrintfEXT("\\n"); + } + } +} + +#endif // DEBUG_MODE + +#endif // LINEAR_INT8_INPUT_TILE_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_int8_input_tile_load.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_int8_input_tile_load.glslh new file mode 100644 index 00000000000..c79badab6c6 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_int8_input_tile_load.glslh @@ -0,0 +1,75 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +/* + * Defines functions to load a Int8InputTile from input buffer/texture. + * + * Requires: + * - t_packed_int8_input to be declared in the shader layout + * + * Settings: + * - PACKED_INT8_INPUT_BUFFER to indicate resource is a buffer, otherwise + * texture storage is assumed. 
+ */ + +#ifndef LINEAR_INT8_INPUT_TILE_LOAD_GLSLH +#define LINEAR_INT8_INPUT_TILE_LOAD_GLSLH + +#extension GL_EXT_control_flow_attributes : require + +#include "linear_int8_input_tile.glslh" + +#ifdef PACKED_INT8_INPUT_BUFFER + +ivec4 load_int8_input_block( + const int block_x, + const int block_y, + const int nblocks_x) { + return t_packed_int8_input[(block_y * nblocks_x) + block_x]; +} + +#else + +ivec4 load_int8_input_block( + const int block_x, + const int block_y, + const int nblocks_x) { + return texelFetch(t_packed_int8_input, ivec3(block_x, block_y, 0), 0); +} + +#endif // PACKED_INT8_INPUT_BUFFER + +void load_int8_input_tile( + out Int8InputTile in_tile, + const int block_x, + const int block_y, + const int nblocks_x) { +#if TILE_M4 == 1 && TILE_K4 == 1 + in_tile.data[0][0] = load_int8_input_block(block_x, block_y, nblocks_x); + +#elif TILE_M4 == 1 && TILE_K4 > 1 + [[unroll]] for (int x = 0; x < TILE_K4; ++x) { + in_tile.data[0][x] = load_int8_input_block(block_x + x, block_y, nblocks_x); + } + +#elif TILE_M4 > 1 && TILE_K4 == 1 + [[unroll]] for (int y = 0; y < TILE_M4; ++y) { + in_tile.data[y][0] = load_int8_input_block(block_x, block_y + y, nblocks_x); + } + +#else + [[unroll]] for (int y = 0; y < TILE_M4; ++y) { + [[unroll]] for (int x = 0; x < TILE_K4; ++x) { + in_tile.data[y][x] = + load_int8_input_block(block_x + x, block_y + y, nblocks_x); + } + } +#endif +} + +#endif // LINEAR_INT8_INPUT_TILE_LOAD_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_int8_weight_block.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_int8_weight_block.glslh new file mode 100644 index 00000000000..6e98caea49e --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_int8_weight_block.glslh @@ -0,0 +1,99 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#ifndef LINEAR_INT8_WEIGHT_BLOCK_GLSLH +#define LINEAR_INT8_WEIGHT_BLOCK_GLSLH + +/* + * This file defines utilities to perform weight prepacking of quantized int8 + * matrix multiplication weights. It also defines utilities to load source + * weight data from the input buffer, and write out a packed weight block to output + * texture/buffer. + * + * Requires: + * - t_packed_int8_weight to be defined in shader layout (output texture/buffer) + * - t_int8_weight to be defined in shader layout (input buffer) + * + * Settings: + * - USING_BUFFER to indicate if output resource is a buffer. Otherwise texture + * is assumed. + */ + +#extension GL_EXT_control_flow_attributes : require + +#include "linear_common.glslh" + +// Represents data for a 4x4 block of the weight matrix read from the input +// buffer.
+struct Int8WeightBlock { + ivec4 data; +}; + +void load_block_data_no_checks( + out Int8WeightBlock block, + const int k4, + const int n_start, + const int ntexels_K, + const int N) { + [[unroll]] for (int n = 0; n < 4; ++n) { + block.data[n] = t_int8_weight[(n_start + n) * ntexels_K + k4]; + } +} + +void load_block_data_with_checks( + out Int8WeightBlock block, + const int k4, + const int n_start, + const int ntexels_K, + const int N) { + [[unroll]] for (int n = 0; n < 4; ++n) { + if (n_start + n < N) { + block.data[n] = t_int8_weight[(n_start + n) * ntexels_K + k4]; + } else { + block.data[n] = 0; + } + } +} + +#ifdef USING_BUFFER + +void write_weight_block( + const Int8WeightBlock block, + const int n4, + const int k4, + const int ntexels_N) { + t_packed_int8_weight[k4 * ntexels_N + n4] = block.data; +} + +#else // USING_TEXTURE + +void write_weight_block( + const Int8WeightBlock block, + const int n4, + const int k4, + const int ntexels_N) { + imageStore(t_packed_int8_weight, ivec2(n4, k4), block.data); +} + +#endif // USING_BUFFER + +#ifdef DEBUG_MODE + +void printInt8WeightBlock(const Int8WeightBlockPacked block) { + debugPrintfEXT("int8_weight_block_packed: \\n"); + debugPrintfEXT( + "%i %i %i %i \\n", + block.data[0], + block.data[1], + block.data[2], + block.data[3]); +} + +#endif // DEBUG_MODE + +#endif // LINEAR_INT8_WEIGHT_BLOCK_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_int8_weight_tile.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_int8_weight_tile.glslh new file mode 100644 index 00000000000..f312db543db --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_int8_weight_tile.glslh @@ -0,0 +1,60 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#ifndef LINEAR_INT8_WEIGHT_TILE_GLSLH +#define LINEAR_INT8_WEIGHT_TILE_GLSLH + +/* + * Defines the Int8WeightTile struct, which is used to represent a tile of the + * quantized int8 weight matrix of a quantized matrix multiplication operation. 
+ * + * Settings: + * - TILE_K4: number of (groups of 4) rows in the weight tile + * - TILE_N4: number of (groups of 4) columns in the weight tile + */ + +#extension GL_EXT_control_flow_attributes : require + +struct Int8WeightTile { + ivec4 data[TILE_K4][TILE_N4]; +}; + +#ifdef DEBUG_MODE + +void printInt8WeightTile(const Int8WeightTile tile) { + debugPrintfEXT( + "Int8WeightTile [TILE_K4=%d][TILE_N4=%d]:\\n", TILE_K4, TILE_N4); + + [[unroll]] for (int m4 = 0; m4 < TILE_M4; ++m4) { + [[unroll]] for (int k4 = 0; k4 < TILE_K4; ++k4) { + debugPrintfEXT(" tile[%d][%d] (ivec4): ", m4, k4); + + // Each ivec4 contains 4 packed integers, each integer contains 4 8-bit + // values + [[unroll]] for (int vec_idx = 0; vec_idx < 4; ++vec_idx) { + int packed_int = tile.data[m4][k4][vec_idx]; + debugPrintfEXT("packed_int[%d]=%d -> [", vec_idx, packed_int); + + // Extract 4 8-bit values from this packed integer + [[unroll]] for (int byte_idx = 0; byte_idx < 4; ++byte_idx) { + int val = extract_8bit_from_packed_int_le(packed_int, byte_idx); + if (byte_idx < 3) { + debugPrintfEXT("%d, ", val); + } else { + debugPrintfEXT("%d] ", val); + } + } + } + debugPrintfEXT("\\n"); + } + } +} + +#endif // DEBUG_MODE + +#endif // LINEAR_INT8_WEIGHT_TILE_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_int8_weight_tile_load.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_int8_weight_tile_load.glslh new file mode 100644 index 00000000000..fe16d3469b3 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_int8_weight_tile_load.glslh @@ -0,0 +1,78 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#ifndef LINEAR_INT8_WEIGHT_TILE_LOAD_GLSLH +#define LINEAR_INT8_WEIGHT_TILE_LOAD_GLSLH + +/* + * Defines functions to load a Int8WeightTile from input buffer/texture. + * + * Requires: + * - t_packed_int8_weight to be declared in the shader layout (input + * buffer/texture) + * + * Settings: + * - WEIGHT_BUFFER to indicate t_packed_int8_weight is a buffer, otherwise + * texture storage is assumed. 
+ */ + +#extension GL_EXT_control_flow_attributes : require + +#include "linear_int8_weight_tile.glslh" + +#ifdef WEIGHT_BUFFER + +ivec4 load_int8_weight_block( + const int block_x, + const int block_y, + const int nblocks_x) { + return t_packed_int8_weight[(block_y * nblocks_x) + block_x]; +} + +#else // WEIGHT_TEXTURE + +ivec4 load_int8_weight_block( + const int block_x, + const int block_y, + const int nblocks_x) { + return texelFetch(t_packed_int8_weight, ivec2(block_x, block_y), 0); +} + +#endif // WEIGHT_BUFFER + +void load_int8_weight_tile( + out Int8WeightTile weight_tile, + const int block_x, + const int block_y, + const int nblocks_x) { +#if TILE_K4 == 1 && TILE_N4 == 1 + weight_tile.data[0][0] = load_int8_weight_block(block_x, block_y, nblocks_x); + +#elif TILE_K4 == 1 && TILE_N4 > 1 + [[unroll]] for (int x = 0; x < TILE_N4; ++x) { + weight_tile.data[0][x] = + load_int8_weight_block(block_x + x, block_y, nblocks_x); + } + +#elif TILE_K4 > 1 && TILE_N4 == 1 + [[unroll]] for (int y = 0; y < TILE_K4; ++y) { + weight_tile.data[y][0] = + load_int8_weight_block(block_x, block_y + y, nblocks_x); + } + +#else + [[unroll]] for (int y = 0; y < TILE_K4; ++y) { + [[unroll]] for (int x = 0; x < TILE_N4; ++x) { + weight_tile.data[y][x] = + load_int8_weight_block(block_x + x, block_y + y, nblocks_x); + } + } +#endif +} + +#endif // LINEAR_INT8_WEIGHT_TILE_LOAD_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_int_per_out_channel_params.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_int_per_out_channel_params.glslh new file mode 100644 index 00000000000..ca29fd52780 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_int_per_out_channel_params.glslh @@ -0,0 +1,44 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +/* + * Defines common functions and structs to be used across matrix multiplication + * operators. + */ + +#ifndef LINEAR_INT_PER_OUT_CHANNEL_PARAMS_GLSLH +#define LINEAR_INT_PER_OUT_CHANNEL_PARAMS_GLSLH + +#include "common.glslh" + +#extension GL_EXT_control_flow_attributes : require + +// Represents integer parameter tensors where each element is associated +// with an output channel, such as weight sums. +struct IntPerOutChannelParams { + ivec4 data[TILE_N4]; +}; + +#ifdef DEBUG_MODE + +void printIntPerOutChannelParams(const IntPerOutChannelParams params) { + debugPrintfEXT("per_out_channel_params: \\n"); + [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) { + debugPrintfEXT( + " %d, %d, %d, %d, ", + params.data[n4].x, + params.data[n4].y, + params.data[n4].z, + params.data[n4].w); + } + debugPrintfEXT("\\n"); +} + +#endif // DEBUG_MODE + +#endif // LINEAR_INT_PER_OUT_CHANNEL_PARAMS_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_int_weight_sums_load.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_int_weight_sums_load.glslh new file mode 100644 index 00000000000..1a17f99ea4e --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_int_weight_sums_load.glslh @@ -0,0 +1,32 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree.
+ */ + +#ifndef LINEAR_FP_WEIGHT_SUMS_LOAD_GLSLH +#define LINEAR_FP_WEIGHT_SUMS_LOAD_GLSLH + +#include "linear_int_per_out_channel_params.glslh" + +ivec4 load_weight_sum_x4(const int n4) { + return ivec4(t_weight_sums[n4]); +} + +void load_weight_sums_tile( + out IntPerOutChannelParams sums, + const int n4_start) { +#if TILE_N4 == 1 + sums.data[0] = load_weight_sum_x4(n4_start); + +#else + [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) { + sums.data[n4] = load_weight_sum_x4(n4_start + n4); + } + +#endif +} + +#endif // LINEAR_FP_WEIGHT_SUMS_LOAD_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_q8csw_tiled.glsl b/backends/vulkan/runtime/graph/ops/glsl/linear_q8csw_tiled.glsl new file mode 100644 index 00000000000..b6d932f0015 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_q8csw_tiled.glsl @@ -0,0 +1,113 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} +#define VEC4_T ${texel_load_type(DTYPE, IO_STORAGE)} +#define T ${texel_load_component_type(DTYPE, IO_STORAGE)} + +$if IO_STORAGE == "buffer": + #define OUTPUT_BUFFER + #define INPUT_BUFFER +$if WEIGHT_STORAGE == "buffer": + #define WEIGHT_BUFFER + +#define TILE_M4 ${TILE_M4} +#define TILE_K4 ${TILE_K4} +#define TILE_N4 ${TILE_N4} + +#define TILE_M ${TILE_M4 * 4} +#define TILE_K ${TILE_K4 * 4} +#define TILE_N ${TILE_N4 * 4} + +${define_required_extensions(DTYPE)} + +layout(std430) buffer; + +${layout_declare_tensor(B, "w", "t_output", DTYPE, IO_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_input", DTYPE, IO_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_packed_int8_weight", "int", WEIGHT_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_weight_scales", DTYPE, "buffer", is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_bias", DTYPE, "buffer", is_scalar_array=False)} + +${layout_declare_ubo(B, "ivec4", "output_sizes")} +${layout_declare_ubo(B, "ivec4", "input_sizes")} + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +${layout_declare_spec_const(C, "uint", "apply_bias", "0")} + +#include "linear_fp_input_tile_load.glslh" +#include "linear_int8_weight_tile_load.glslh" +#include "linear_fp_weight_tile.glslh" +#include "linear_fp_output_tile_fp_compute.glslh" +#include "linear_fp_output_tile_fp_int8_compute.glslh" +#include "linear_fp_output_tile_store.glslh" +#include "linear_fp_weight_scales_load.glslh" +#include "linear_fp_bias_load.glslh" + +void main() { + // Each thread writes out a 4 wide x 4 high tile of output values + const int out_tile_x = int(gl_GlobalInvocationID.x); + const int out_tile_y = int(gl_GlobalInvocationID.y); + + const int n = out_tile_x * TILE_N; + const int m = out_tile_y * TILE_M; + + const int n4 = div_4(n); + const int m4 = div_4(m); + + if (n >= output_sizes.x || m >= output_sizes.y) { + return; + } + + const int M = input_sizes.y; + const int K4 = div_up_4(input_sizes.x); + const int N4 = div_up_4(output_sizes.x); + + FPOutTile out_tile; + initialize(out_tile); + + FPInputTile in_tile; + Int8WeightTile int8_weight_tile; + + const bool dont_check_bounds = (M - m) >= TILE_M; + if (dont_check_bounds) { + for (int k4 = 0; k4 < K4; k4 += TILE_K4) { + load_input_tile_no_checks(in_tile, k4, m, K4, M); + load_int8_weight_tile(int8_weight_tile, n4, k4, N4); 
+ fp_accumulate_with_int8_weight(out_tile, in_tile, int8_weight_tile); + } + } else { + for (int k4 = 0; k4 < K4; k4 += TILE_K4) { + load_input_tile_with_checks(in_tile, k4, m, K4, M); + load_int8_weight_tile(int8_weight_tile, n4, k4, N4); + fp_accumulate_with_int8_weight(out_tile, in_tile, int8_weight_tile); + } + } + + FPPerOutChannelParams weight_scales_tile; + load_weight_scales_tile(weight_scales_tile, n4); + + if (apply_bias > 0) { + FPPerOutChannelParams bias_tile; + load_bias_tile(bias_tile, n4); + + apply_scales_and_biases(out_tile, weight_scales_tile, bias_tile); + } + else { + apply_scales(out_tile, weight_scales_tile); + } + + if (dont_check_bounds) { + write_output_tile_no_checks(out_tile, n4, m, N4, M); + } else { + write_output_tile_with_checks(out_tile, n4, m, N4, M); + } +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_q8csw_tiled.yaml b/backends/vulkan/runtime/graph/ops/glsl/linear_q8csw_tiled.yaml new file mode 100644 index 00000000000..242c4471b3d --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_q8csw_tiled.yaml @@ -0,0 +1,28 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +linear_q8csw_tiled: + parameter_names_with_default_values: + DTYPE: float + IO_STORAGE: texture3d + WEIGHT_STORAGE: texture2d + TILE_M4: 1 + TILE_N4: 1 + TILE_K4: 1 + generate_variant_forall: + DTYPE: + - VALUE: float + - VALUE: half + shader_variants: + - NAME: linear_q8csw_tiled_texture3d_texture2d + - NAME: linear_q8csw_tiled_texture3d_buffer + WEIGHT_STORAGE: buffer + - NAME: linear_q8csw_tiled_buffer_texture2d + IO_STORAGE: buffer + WEIGHT_STORAGE: texture2d + - NAME: linear_q8csw_tiled_buffer_buffer + IO_STORAGE: buffer + WEIGHT_STORAGE: buffer diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_q8ta_q8csw_tiled.glsl b/backends/vulkan/runtime/graph/ops/glsl/linear_q8ta_q8csw_tiled.glsl new file mode 100644 index 00000000000..9f7e00e3317 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_q8ta_q8csw_tiled.glsl @@ -0,0 +1,132 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#version 450 core + +#define PRECISION ${PRECISION} +#define VEC4_T ${texel_load_type(DTYPE, OUTPUT_STORAGE)} +#define T int + +$if OUTPUT_STORAGE == "buffer": + #define OUTPUT_BUFFER +$if PACKED_INT8_INPUT_STORAGE == "buffer": + #define PACKED_INT8_INPUT_BUFFER +$if WEIGHT_STORAGE == "buffer": + #define WEIGHT_BUFFER + +#define TILE_M4 ${TILE_M4} +#define TILE_K4 ${TILE_K4} +#define TILE_N4 ${TILE_N4} + +#define TILE_M ${TILE_M4 * 4} +#define TILE_K ${TILE_K4 * 4} +#define TILE_N ${TILE_N4 * 4} + +${define_required_extensions(DTYPE)} + +layout(std430) buffer; + +${layout_declare_tensor(B, "w", "t_output", DTYPE, OUTPUT_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_packed_int8_input", "int", PACKED_INT8_INPUT_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_packed_int8_weight", "int", WEIGHT_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_weight_sums", "int", "buffer", is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_weight_scales", DTYPE, "buffer", is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_bias", DTYPE, "buffer", is_scalar_array=False)} + +${layout_declare_spec_const(C, "int", "apply_bias", "0")} + +${layout_declare_ubo(B, "ivec4", "output_sizes")} +${layout_declare_ubo(B, "ivec4", "input_sizes")} + +layout(push_constant) uniform restrict Block { + float input_scale; + int input_zp; +}; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +#include "linear_int8_input_tile_load.glslh" +#include "linear_int8_weight_tile_load.glslh" +#include "linear_fp_output_tile_int8_int8_compute.glslh" +#include "linear_fp_output_tile_store.glslh" +#include "linear_fp_weight_scales_load.glslh" +#include "linear_int_weight_sums_load.glslh" +#include "linear_fp_bias_load.glslh" + +void main() { + // Each thread writes out a 4 wide x 4 high tile of output values + const int out_tile_x = int(gl_GlobalInvocationID.x); + const int out_tile_y = int(gl_GlobalInvocationID.y); + + const int n = out_tile_x * TILE_N; + const int m = out_tile_y * TILE_M; + + const int n4 = div_4(n); + const int m4 = div_4(m); + + if (n >= output_sizes.x || m >= output_sizes.y) { + return; + } + + const int M = output_sizes.y; + const int K4 = div_up_4(input_sizes.x); + const int N4 = div_up_4(output_sizes.x); + + Int32Accum out_accum; + initialize(out_accum); + + Int8InputTile int8_in_tile; + Int8WeightTile int8_weight_tile; + + // No checks are needed since packed input and weight are structured in units + // of 4x4 blocks. 
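// For reference, the math implemented by this shader is:
//   acc[m][n] = sum_k q_in[m][k] * q_w[n][k]                  (int32 accumulation)
//   out[m][n] = input_scale * weight_scale[n]
//                 * (acc[m][n] - input_zp * weight_sum[n])    (+ bias[n] if present)
// where weight_sum[n] = sum_k q_w[n][k] is precomputed on the host and read
// from t_weight_sums, which lets the input zero point correction be applied
// after the accumulation loop instead of inside it.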
+ for (int k4 = 0; k4 < K4; k4 += TILE_K4) { + load_int8_input_tile(int8_in_tile, k4, m4, K4); + load_int8_weight_tile(int8_weight_tile, n4, k4, N4); + + int_accumulate_with_int8_weight(out_accum, int8_in_tile, int8_weight_tile); + } + + FPPerOutChannelParams weight_scales_tile; + load_weight_scales_tile(weight_scales_tile, n4); + + IntPerOutChannelParams weight_sums_tile; + load_weight_sums_tile(weight_sums_tile, n4); + + FPOutTile out_tile; + initialize(out_tile); + + if (apply_bias > 0) { + FPPerOutChannelParams bias_tile; + load_bias_tile(bias_tile, n4); + + accumulate_out_tile_with_int_accum( + out_tile, + out_accum, + input_scale, + input_zp, + weight_sums_tile, + weight_scales_tile, + bias_tile); + } + else { + accumulate_out_tile_with_int_accum( + out_tile, + out_accum, + input_scale, + input_zp, + weight_sums_tile, + weight_scales_tile); + } + + if (M - m >= TILE_M) { + write_output_tile_no_checks(out_tile, n4, m, N4, M); + } else { + write_output_tile_with_checks(out_tile, n4, m, N4, M); + } +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_q8ta_q8csw_tiled.yaml b/backends/vulkan/runtime/graph/ops/glsl/linear_q8ta_q8csw_tiled.yaml new file mode 100644 index 00000000000..aa1de3077fc --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_q8ta_q8csw_tiled.yaml @@ -0,0 +1,24 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +linear_q8ta_q8csw_tiled: + parameter_names_with_default_values: + DTYPE: float + OUTPUT_STORAGE: texture3d + PACKED_INT8_INPUT_STORAGE: buffer + WEIGHT_STORAGE: texture2d + TILE_M4: 1 + TILE_N4: 1 + TILE_K4: 1 + generate_variant_forall: + DTYPE: + - VALUE: float + shader_variants: + - NAME: linear_q8ta_q8csw_tiled_texture3d_buffer_texture2d + - NAME: linear_q8ta_q8csw_tiled_buffer_buffer_texture2d + OUTPUT_STORAGE: buffer + PACKED_INT8_INPUT_STORAGE: buffer + WEIGHT_STORAGE: texture2d diff --git a/backends/vulkan/runtime/graph/ops/glsl/pack_q8_linear_weight.glsl b/backends/vulkan/runtime/graph/ops/glsl/pack_q8_linear_weight.glsl new file mode 100644 index 00000000000..f2c74b67283 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/pack_q8_linear_weight.glsl @@ -0,0 +1,60 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} + +${define_active_storage_type(STORAGE)} + +layout(std430) buffer; + +${layout_declare_tensor(B, "w", "t_packed_int8_weight", "int", STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_int8_weight", "int", "buffer")} + +layout(push_constant) uniform restrict Block { + ivec4 qmat2_sizes; + ivec2 orig_sizes; +}; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +#include "common.glslh" +#include "linear_int8_weight_block.glslh" + +void main() { + // The size of the source weight tensor is [W=K, H=N]. Each shader invocation + // processes a 4x4 block. The thread position corresponds to the block index. 
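// Concretely, the block at (n4, k4) covers weight rows [n4 * 4, n4 * 4 + 4)
// along N and columns [k4 * 4, k4 * 4 + 4) along K, and it is written to the
// transposed location [k4][n4] of the packed tensor so that, for a fixed k4,
// threads working on adjacent n4 values read adjacent memory in the tiled GEMM.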
+ int n4 = int(gl_GlobalInvocationID.x); + int k4 = int(gl_GlobalInvocationID.y); + + const int K = orig_sizes.x; + const int N = orig_sizes.y; + + // Determine the total number of blocks and check bounds + const int N4 = div_up_4(N); + const int K4 = div_up_4(K); + if (n4 >= N4 || k4 >= K4) { + return; + } + + // Each block is represented as an ivec4. Each int corresponds to a row i.e. + // N dim of the weight tensor and contains data for 4 columns i.e. K dim. + Int8WeightBlock block; + const int n = mul_4(n4); + if (N - n >= 4) { + load_block_data_no_checks(block, k4, n, K4, N); + } else { + load_block_data_with_checks(block, k4, n, K4, N); + } + + // The weight blocks are stored in a transposed manner, such that weight blocks + // are indexed like packed_weight[k4][n4]. This is to optimize memory + // coalescing when computing tiled GEMM. + write_weight_block(block, n4, k4, N4); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/pack_q8_linear_weight.yaml b/backends/vulkan/runtime/graph/ops/glsl/pack_q8_linear_weight.yaml new file mode 100644 index 00000000000..13e6d43b2c5 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/pack_q8_linear_weight.yaml @@ -0,0 +1,14 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +pack_q8_linear_weight: + parameter_names_with_default_values: + STORAGE: buffer + shader_variants: + - NAME: pack_q8_linear_weight_buffer + STORAGE: buffer + - NAME: pack_q8_linear_weight_texture2d + STORAGE: texture2d diff --git a/backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_linear_input.glsl b/backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_linear_input.glsl new file mode 100644 index 00000000000..6ba9343f10d --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_linear_input.glsl @@ -0,0 +1,79 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} +#define VEC4_T ${texel_load_type(DTYPE, INPUT_STORAGE)} +#define T ${texel_load_component_type(DTYPE, INPUT_STORAGE)} + +$if OUTPUT_STORAGE == "buffer": + #define OUTPUT_BUFFER +$if INPUT_STORAGE == "buffer": + #define INPUT_BUFFER + +${define_required_extensions(DTYPE)} + +layout(std430) buffer; + +#include "common.glslh" + +${layout_declare_tensor(B, "w", "t_packed_int8_input", "int", OUTPUT_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_input", DTYPE, INPUT_STORAGE, is_scalar_array=False)} + +$if GRANULARITY == "per_channel": + ${layout_declare_tensor(B, "r", "t_scale", DTYPE, "buffer")} + +${layout_declare_ubo(B, "ivec4", "input_sizes")} + +layout(push_constant) uniform restrict Block { + float inv_scale; + int zp; +}; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +#include "linear_int8_input_block.glslh" +#include "linear_fp_input_tile_load.glslh" + +void main() { + // Each input block contains 4x4 int8 quantized values, which are packed into + // an ivec4. k4 and m4 represent the "block index" of the current block being + // processed.
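For illustration, below is a host-side C++ sketch of the same per-tensor quantize-and-pack scheme. The helper name is hypothetical (it is not part of this patch); the quantization formula mirrors the reference implementation in q8csw_linear.cpp (round, add zero point, clamp to [-128, 127]), while the byte order inside each packed int32, the per-block element order, and the zero padding of out-of-range rows are assumptions, with the authoritative layout defined in linear_int8_input_block.glslh.

    #include <algorithm>
    #include <cmath>
    #include <cstdint>
    #include <vector>

    // Hypothetical reference for quantize_and_pack_linear_input (per-tensor).
    // Produces an [M4, K4 * 4] int32 tensor where each (m4, k4) block is one
    // ivec4, matching the packed input tensor allocated in QuantizedLinear.cpp.
    std::vector<int32_t> quantize_and_pack_input_reference(
        const std::vector<float>& input, // [M, K] row-major, K % 4 == 0
        int64_t M,
        int64_t K,
        float scale,
        int32_t zp) {
      const float inv_scale = 1.0f / scale;
      const int64_t M4 = (M + 3) / 4; // number of 4-row blocks along M
      const int64_t K4 = K / 4;       // number of 4-column blocks along K
      std::vector<int32_t> packed(M4 * K4 * 4, 0);
      for (int64_t m4 = 0; m4 < M4; ++m4) {
        for (int64_t k4 = 0; k4 < K4; ++k4) {
          for (int64_t r = 0; r < 4; ++r) { // row (M dim) within the 4x4 block
            const int64_t m = m4 * 4 + r;
            uint32_t word = 0;
            for (int64_t c = 0; c < 4; ++c) { // column (K dim) within the block
              const int64_t k = k4 * 4 + c;
              // q = clamp(round(x * inv_scale) + zp, -128, 127); rows past M are
              // zero padded (assumption).
              float q = (m < M) ? std::round(input[m * K + k] * inv_scale) + zp : 0.0f;
              q = std::min(std::max(q, -128.0f), 127.0f);
              // Assumed little-endian packing: byte c holds K offset c.
              word |= (static_cast<uint32_t>(static_cast<int8_t>(q)) & 0xFFu) << (8 * c);
            }
            packed[(m4 * K4 + k4) * 4 + r] = static_cast<int32_t>(word);
          }
        }
      }
      return packed;
    }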
+ int k4 = int(gl_GlobalInvocationID.x); + int m4 = int(gl_GlobalInvocationID.y); + + const int K = input_sizes.x; + const int M = input_sizes.y; + + // K4 and M4 represent the number of blocks in each dimension. + const int K4 = div_up_4(K); + const int M4 = div_up_4(M); + + if (k4 >= K4 || m4 >= M4) { + return; + } + + // row of the input tensor to start loading from. Note the input tensor is + // interpreted as a t + const int m = mul_4(m4); + + const bool dont_check_bounds = (M - m) >= 4; + + FPInputTile in_tile; + if (dont_check_bounds) { + load_input_tile_no_checks(in_tile, k4, m, K4, M); + } else { + load_input_tile_with_checks(in_tile, k4, m, K4, M); + } + + Int8InputBlock packed_block; + quantize_and_pack(packed_block, in_tile, inv_scale, zp); + + write_block(packed_block, k4, m4, K4); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_linear_input.yaml b/backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_linear_input.yaml new file mode 100644 index 00000000000..37721db1ba8 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_linear_input.yaml @@ -0,0 +1,24 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +quantize_and_pack_linear_input: + parameter_names_with_default_values: + DTYPE: float + OUTPUT_STORAGE: texture3d + INPUT_STORAGE: texture3d + STORAGE: texture3d + GRANULARITY: per_tensor + generate_variant_forall: + DTYPE: + - VALUE: half + - VALUE: float + shader_variants: + - NAME: quantize_and_pack_linear_input_per_tensor_texture3d_texture3d + - NAME: quantize_and_pack_linear_input_per_tensor_buffer_texture3d + OUTPUT_STORAGE: buffer + - NAME: quantize_and_pack_linear_input_per_tensor_buffer_buffer + OUTPUT_STORAGE: buffer + INPUT_STORAGE: buffer diff --git a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp new file mode 100644 index 00000000000..d6aeb5e3dce --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp @@ -0,0 +1,554 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include + +#include +#include +#include +#include + +namespace vkcompute { + +// +// Shader dispatch utilities +// + +utils::uvec3 quantized_linear_global_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const std::vector& args, + const std::vector& resize_args) { + const ValueRef out = args.at(0).refs.at(0); + + std::vector out_sizes = graph->sizes_of(out); + // height + const uint32_t M = utils::val_at(-2, out_sizes); + // width + const uint32_t N = utils::val_at(-1, out_sizes); + + // 1 output tile is 4x4 elements + const uint32_t M4 = utils::div_up(M, 4u); + const uint32_t N4 = utils::div_up(N, 4u); + + return {N4, M4, 1}; +} + +utils::uvec3 quantized_linear_local_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const utils::uvec3& global_workgroup_size, + const std::vector& args, + const std::vector& resize_args) { + return pick_hw_square_wg_size( + graph, shader, global_workgroup_size, args, resize_args); +} + +std::tuple get_quantized_input_num_blocks( + ComputeGraph& graph, + const ValueRef input) { + std::vector input_sizes = graph.sizes_of(input); + const int64_t ndim = graph.dim_of(input); + + const int64_t M = input_sizes.at(ndim - 2); + const int64_t K = input_sizes.at(ndim - 1); + + const int64_t num_blocks_M = utils::div_up(M, int64_t(4)); + const int64_t num_blocks_K = utils::div_up(K, int64_t(4)); + + return std::make_tuple(num_blocks_M, num_blocks_K); +} + +utils::uvec3 quant_pack_input_global_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const std::vector& args, + const std::vector& resize_args) { + const ValueRef input = args.at(1).refs.at(0); + int64_t num_blocks_M, num_blocks_K; + std::tie(num_blocks_M, num_blocks_K) = + get_quantized_input_num_blocks(*graph, input); + + return { + utils::safe_downcast(num_blocks_K), + utils::safe_downcast(num_blocks_M), + 1u}; +} + +// +// Prepacking nodes +// + +ValueRef prepack_quantized_linear_weight( + ComputeGraph& graph, + const QuantizationConfig& weight_quant_config, + const ValueRef qmat2_data) { + VK_CHECK_COND(weight_quant_config.nbits == 8); + + std::vector qmat2_orig_sizes = graph.sizes_of(qmat2_data); + const int64_t ndim = graph.dim_of(qmat2_data); + + // Input size is [N, K]. K will be guaranteed to be a multiple of 4. + const int64_t K = qmat2_orig_sizes.at(ndim - 1); + const int64_t N = qmat2_orig_sizes.at(ndim - 2); + + // Sanity check that assumption is correct + VK_CHECK_COND(K % 4 == 0); + + // The packing format packs the weight tensor into units of 4 wide x 4 high + // blocks. To figure out the size of the output tensor, determine the number + // of blocks along each dimension. + const int64_t num_blocks_K = utils::div_up(K, int64_t(4)); + const int64_t num_blocks_N = utils::div_up(N, int64_t(4)); + + // The blocks are arranged in a transposed manner, such that the transposed + // weight block is indexed like packed_weights[k4][n4] - this is to allow for + // optimal memory coalescing when computing GEMM. + const int64_t output_height = num_blocks_K; + // The base dtype of the packed tensor is int32 (each int32 contains 4x 8bit + // values) and each block is represented as a ivec4. Therefore the width dim + // of the packed tensor is multiplied by 4. 
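To make the layout described in the comments above concrete, here is a hedged host-side C++ sketch of the packed weight tensor produced by prepack_quantized_linear_weight and pack_q8_linear_weight. The helper is hypothetical (not part of the patch); the byte order inside each int32 and the zero padding of out-of-range rows are assumptions, and the authoritative layout is defined by linear_int8_weight_block.glslh and the packing shader.

    #include <cstdint>
    #include <vector>

    // Hypothetical reference for pack_q8_linear_weight: packs an [N, K] int8
    // weight into a [K4, N4 * 4] int32 tensor. Blocks are stored transposed
    // (indexed as packed[k4][n4]); within a block, int r corresponds to weight
    // row n4 * 4 + r and its four bytes hold K offsets k4 * 4 .. k4 * 4 + 3.
    std::vector<int32_t> pack_q8_weight_reference(
        const std::vector<int8_t>& weight, // [N, K] row-major, K % 4 == 0
        int64_t N,
        int64_t K) {
      const int64_t N4 = (N + 3) / 4;
      const int64_t K4 = K / 4;
      std::vector<int32_t> packed(K4 * N4 * 4, 0);
      for (int64_t k4 = 0; k4 < K4; ++k4) {
        for (int64_t n4 = 0; n4 < N4; ++n4) {
          for (int64_t r = 0; r < 4; ++r) { // r indexes the N dim within the block
            const int64_t n = n4 * 4 + r;
            uint32_t word = 0;
            for (int64_t c = 0; c < 4; ++c) { // c indexes the K dim within the block
              const int64_t k = k4 * 4 + c;
              // Rows past N are zero padded (assumption).
              const uint32_t q = (n < N)
                  ? static_cast<uint32_t>(static_cast<uint8_t>(weight[n * K + k]))
                  : 0u;
              word |= q << (8 * c); // assumed little-endian byte order
            }
            // Transposed block layout: blocks are indexed as packed[k4][n4].
            packed[(k4 * N4 + n4) * 4 + r] = static_cast<int32_t>(word);
          }
        }
      }
      return packed;
    }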
+ const int64_t output_width = num_blocks_N * 4; + + // Store the original sizes of the tensor to pass to the shader + utils::ivec2 orig_sizes{ + utils::safe_downcast(K), utils::safe_downcast(N)}; + + std::vector qmat2_sizes{output_height, output_width}; + + utils::StorageType storage_type = utils::kTexture2D; + uint32_t max_extent = graph.context()->adapter_ptr()->max_texture2d_dim(); + if (output_width > max_extent * 4 || output_height > max_extent) { + storage_type = utils::kBuffer; + } + + ValueRef qmat2 = graph.add_tensor( + qmat2_sizes, vkcompute::vkapi::kInt, storage_type, utils::kWidthPacked); + + // Global workgroup size: each thread writes out two adjacent blocks + utils::uvec3 global_wg_size{ + utils::safe_downcast(num_blocks_N), + utils::safe_downcast(num_blocks_K), + 1u}; + + std::string kernel_name = "pack_q8_linear_weight"; + add_storage_type_suffix(kernel_name, storage_type); + + graph.prepack_nodes().emplace_back(new PrepackNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + global_wg_size, + graph.create_local_wg_size(global_wg_size), + // Inputs and Outputs + qmat2_data, + qmat2, + // UBOs + {}, + // Specialization Constants + {}, + // Push Constants + {graph.sizes_pc_of(qmat2), + PushConstantDataInfo(&orig_sizes, sizeof(utils::ivec2))})); + + return qmat2; +} + +// +// Dispatch nodes +// + +/* + * Shader dispatch for linear with quantized weight but fp activations. + */ +DynamicDispatchNode make_linear_qw_node( + ComputeGraph& graph, + const QuantizationConfig& weight_quant_config, + const ValueRef fp_input, + const ValueRef weight_data, + const ValueRef packed_weight, + const ValueRef packed_weight_scales, + const ValueRef packed_weight_zeros, + const ValueRef group_size, + const ValueRef bias_data, + const ValueRef packed_bias, + const ValueRef output) { + // Only certain quantization types supported at the moment + VK_CHECK_COND(weight_quant_config.granularity == kPerChannel); + VK_CHECK_COND(weight_quant_config.is_symmetric); + VK_CHECK_COND(weight_quant_config.nbits == 8); + + std::string kernel_name = "linear_q8csw_tiled"; + add_storage_type_suffix(kernel_name, graph.storage_type_of(output)); + add_storage_type_suffix(kernel_name, graph.storage_type_of(packed_weight)); + add_dtype_suffix(kernel_name, graph.dtype_of(output)); + vkapi::ShaderInfo shader = VK_KERNEL_FROM_STR(kernel_name); + + vkapi::ParamsBindList param_buffers = { + graph.sizes_ubo(output), graph.sizes_ubo(fp_input)}; + + uint32_t apply_bias = 1; + if (graph.val_is_none(bias_data)) { + apply_bias = 0; + } + + return DynamicDispatchNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + quantized_linear_global_wg_size, + quantized_linear_local_wg_size, + // Inputs and Outputs + {{output, vkapi::kWrite}, + {{fp_input, packed_weight, packed_weight_scales, packed_bias}, + vkapi::kRead}}, + // Shader params buffers + param_buffers, + // Push Constants + {}, + // Specialization Constants + {apply_bias}, + // Resize args + {}, + // Resizing Logic + nullptr); +} + +DynamicDispatchNode make_quantize_and_pack_linear_input_node( + ComputeGraph& graph, + const QuantizationConfig& input_quant_config, + const ValueRef fp_input, + const ValueRef packed_input_scale, + const ValueRef packed_input_zp, + const ValueRef input_scale_data, + const ValueRef input_zp_data, + const ValueRef packed_int_input, + const ValueRef group_size) { + // Only certain quantization types supported at the moment + VK_CHECK_COND(input_quant_config.granularity == kPerTensor); + + int64_t num_blocks_M, num_blocks_K; + std::tie(num_blocks_M, 
num_blocks_K) = + get_quantized_input_num_blocks(graph, fp_input); + + float inv_scale = 1.0f / graph.extract_scalar(input_scale_data); + int32_t zp = graph.extract_scalar(input_zp_data); + + std::string shader_name = "quantize_and_pack_linear_input_per_tensor"; + add_storage_type_suffix(shader_name, graph.storage_type_of(packed_int_input)); + add_storage_type_suffix(shader_name, graph.storage_type_of(fp_input)); + add_dtype_suffix(shader_name, graph.dtype_of(fp_input)); + + vkapi::ParamsBindList param_buffers = {graph.sizes_ubo(fp_input)}; + + std::vector push_constants = { + PushConstantDataInfo(&inv_scale, sizeof(inv_scale)), + PushConstantDataInfo(&zp, sizeof(zp)), + }; + + return DynamicDispatchNode( + graph, + VK_KERNEL_FROM_STR(shader_name), + quant_pack_input_global_wg_size, + default_pick_local_wg_size, + // Inputs and Outputs + {{packed_int_input, vkapi::kWrite}, {fp_input, vkapi::kRead}}, + // Shader params buffers + param_buffers, + // Push Constants + push_constants, + // Specialization Constants + {}, + // Resize args + {}); +} + +DynamicDispatchNode make_linear_qa_qw_node( + ComputeGraph& graph, + const QuantizationConfig& input_quant_config, + const QuantizationConfig& weight_quant_config, + const ValueRef fp_input, + const ValueRef packed_int_input, + const ValueRef packed_input_scale, + const ValueRef packed_input_zp, + const ValueRef input_scale_data, + const ValueRef input_zp_data, + const ValueRef weight_data, + const ValueRef packed_weight, + const ValueRef packed_weight_sums, + const ValueRef packed_weight_scales, + const ValueRef group_size, + const ValueRef bias_data, + const ValueRef packed_bias, + const ValueRef output) { + VK_CHECK_COND(input_quant_config.granularity == kPerTensor); + VK_CHECK_COND(input_quant_config.nbits == 8); + VK_CHECK_COND(weight_quant_config.granularity == kPerChannel); + VK_CHECK_COND(weight_quant_config.is_symmetric); + VK_CHECK_COND(weight_quant_config.nbits == 8); + + float scale = graph.extract_scalar(input_scale_data); + int32_t zp = graph.extract_scalar(input_zp_data); + + // Get shader for quantized linear + std::string kernel_name = "linear_q8ta_q8csw_tiled"; + add_storage_type_suffix(kernel_name, graph.storage_type_of(output)); + add_storage_type_suffix(kernel_name, graph.storage_type_of(packed_int_input)); + add_storage_type_suffix(kernel_name, graph.storage_type_of(packed_weight)); + add_dtype_suffix(kernel_name, graph.dtype_of(output)); + vkapi::ShaderInfo shader = VK_KERNEL_FROM_STR(kernel_name); + + vkapi::ParamsBindList param_buffers = { + graph.sizes_ubo(output), graph.sizes_ubo(packed_int_input)}; + + std::vector push_constants = { + PushConstantDataInfo(&scale, sizeof(scale)), + PushConstantDataInfo(&zp, sizeof(zp)), + }; + + uint32_t apply_bias = 1; + if (graph.val_is_none(bias_data)) { + apply_bias = 0; + } + + // Add the compute node + return DynamicDispatchNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + quantized_linear_global_wg_size, + quantized_linear_local_wg_size, + // Inputs and Outputs + {{output, vkapi::kWrite}, + {{packed_int_input, + packed_weight, + packed_weight_sums, + packed_weight_scales, + packed_bias}, + vkapi::kRead}}, + // Shader params buffers + param_buffers, + // Push Constants + push_constants, + // Specialization Constants + {apply_bias}, + // Resize args + {fp_input}, + // Resizing Logic + nullptr); +} + +// +// High level operator impl +// + +void quantized_linear_impl( + ComputeGraph& graph, + const QuantizationConfig& input_quant_config, + const QuantizationConfig& 
weight_quant_config, + const ValueRef fp_input, + const ValueRef input_scale, + const ValueRef input_zp, + const ValueRef weight_data, + const ValueRef weight_sums_data, + const ValueRef weight_scales_data, + const ValueRef weight_zeros_data, + const ValueRef group_size, + const ValueRef bias_data, + const ValueRef output) { + std::vector input_sizes = graph.sizes_of(fp_input); + std::vector weight_sizes = graph.sizes_of(weight_data); + + const int64_t K = utils::val_at(-1, input_sizes); + // K (input channels) must be a multiple of 4 to ensure that reading a group + // of 4 input channels from the input tensor will be aligned on a texel + // boundary. + VK_CHECK_COND(K % 4 == 0); + + // Prepack weight data + + const ValueRef packed_weight = + prepack_quantized_linear_weight(graph, weight_quant_config, weight_data); + const ValueRef packed_weight_scales = prepack_standard( + graph, weight_scales_data, utils::kBuffer, utils::kWidthPacked); + // Weight affine quant not supported at the moment + const ValueRef packed_weight_zeros = kDummyValueRef; + + // Prepack bias data + + // Create a dummy tensor to fill the binding slot of the bias tensor if it is + // not provided. This helps simplify dispatch logic and makes it so that + // fewer shader variants need to be generated. + TmpTensor dummy_bias( + &graph, {}, graph.dtype_of(output), utils::kBuffer, utils::kWidthPacked); + + ValueRef packed_bias = dummy_bias.vref; + if (graph.val_is_not_none(bias_data)) { + packed_bias = + prepack_standard(graph, bias_data, utils::kBuffer, utils::kWidthPacked); + } + + // Use weight only quantized linear if at least one is true: + // 1. Device does not support int8 dot product + // 2. Input is not quantized + if (!graph.can_use_int8_dot_product() || + input_quant_config.granularity == kNoQuantization) { + DynamicDispatchNode linear_qw_node(make_linear_qw_node( + graph, + weight_quant_config, + fp_input, + weight_data, + packed_weight, + packed_weight_scales, + packed_weight_zeros, + group_size, + bias_data, + packed_bias, + output)); + + graph.execute_nodes().emplace_back(new DynamicDispatchNode(linear_qw_node)); + return; + } else { + // Otherwise, use input and weight quantized linear computed with integer + // accumulation + + // Input scale/zero point only used for activation & weight quantized linear + ValueRef packed_input_scale = input_scale; + ValueRef packed_input_zp = input_zp; + if (graph.val_is_tref(input_scale)) { + VK_CHECK_COND(graph.val_is_tref(packed_input_zp)); + packed_input_scale = prepack_standard( + graph, input_scale, utils::kBuffer, utils::kWidthPacked); + packed_input_zp = prepack_standard( + graph, input_zp, utils::kBuffer, utils::kWidthPacked); + } + + // Pre-computed per quant group weight sums are needed for int accumulation, + // but not for weight only + const ValueRef packed_weight_sums = prepack_standard( + graph, weight_sums_data, utils::kBuffer, utils::kWidthPacked); + + // Allocate temporary tensor to store quantized and packed input + + int64_t num_blocks_M, num_blocks_K; + std::tie(num_blocks_M, num_blocks_K) = + get_quantized_input_num_blocks(graph, fp_input); + + const int64_t int_input_height = num_blocks_M; + const int64_t int_input_width = num_blocks_K * 4; + + TmpTensor packed_int_input( + &graph, + {int_input_height, int_input_width}, + vkapi::kInt, + utils::kBuffer, + utils::kWidthPacked); + + DynamicDispatchNode quantize_and_pack_linear_node( + make_quantize_and_pack_linear_input_node( + graph, + input_quant_config, + fp_input, + packed_input_scale, +
packed_input_zp, + input_scale, + input_zp, + packed_int_input, + group_size)); + + graph.execute_nodes().emplace_back( + new DynamicDispatchNode(quantize_and_pack_linear_node)); + + DynamicDispatchNode linear_qa_qw_node(make_linear_qa_qw_node( + graph, + input_quant_config, + weight_quant_config, + fp_input, + packed_int_input, + packed_input_scale, + packed_input_zp, + input_scale, + input_zp, + weight_data, + packed_weight, + packed_weight_sums, + packed_weight_scales, + group_size, + bias_data, + packed_bias, + output)); + + graph.execute_nodes().emplace_back( + new DynamicDispatchNode(linear_qa_qw_node)); + } +} + +void linear_q8ta_q8csw(ComputeGraph& graph, const std::vector& args) { + int32_t idx = 0; + const ValueRef fp_input = args.at(idx++); + const ValueRef input_scale = args.at(idx++); + const ValueRef input_zp = args.at(idx++); + const ValueRef weight_data = args.at(idx++); + const ValueRef weight_sums_data = args.at(idx++); + const ValueRef weight_scales_data = args.at(idx++); + const ValueRef bias_data = args.at(idx++); + const ValueRef output = args.at(idx++); + + const int64_t K = graph.size_at(-1, fp_input); + + QuantizationConfig input_quant_config(8, kPerTensor, {}, false); + QuantizationConfig weight_quant_config(8, kPerChannel, {K}); + + quantized_linear_impl( + graph, + input_quant_config, + weight_quant_config, + fp_input, + input_scale, + input_zp, + weight_data, + weight_sums_data, + weight_scales_data, + kDummyValueRef, // weight_zeros_data + kDummyValueRef, // group_size + bias_data, + output); +} + +void linear_q8csw(ComputeGraph& graph, const std::vector& args) { + int32_t idx = 0; + const ValueRef fp_input = args.at(idx++); + const ValueRef weight_data = args.at(idx++); + const ValueRef weight_scales_data = args.at(idx++); + const ValueRef bias_data = args.at(idx++); + const ValueRef output = args.at(idx++); + + const int64_t K = graph.size_at(-1, fp_input); + + QuantizationConfig input_quant_config(32, kNoQuantization, {}); + QuantizationConfig weight_quant_config(8, kPerChannel, {K}); + + quantized_linear_impl( + graph, + input_quant_config, + weight_quant_config, + fp_input, + kDummyValueRef, // input scale + kDummyValueRef, // input zp + weight_data, + kDummyValueRef, // weight sums + weight_scales_data, + kDummyValueRef, // weight zeros + kDummyValueRef, // group size + bias_data, + output); +} + +REGISTER_OPERATORS { + VK_REGISTER_OP(et_vk.linear_q8ta_q8csw.default, linear_q8ta_q8csw); + VK_REGISTER_OP(et_vk.linear_q8csw.default, linear_q8csw); +} + +} // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.h b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.h new file mode 100644 index 00000000000..7b62c98390d --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.h @@ -0,0 +1,29 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#pragma once + +#include +#include +#include + +namespace vkcompute { + +utils::uvec3 quantized_linear_local_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const utils::uvec3& global_workgroup_size, + const std::vector& args, + const std::vector& resize_args); + +ValueRef prepack_quantized_linear_weight( + ComputeGraph& graph, + const QuantizationConfig& weight_quant_config, + const ValueRef qmat2_data); + +} // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/utils/QuantizationConfig.h b/backends/vulkan/runtime/graph/ops/impl/utils/QuantizationConfig.h new file mode 100644 index 00000000000..4bc8c7c3bfc --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/impl/utils/QuantizationConfig.h @@ -0,0 +1,58 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +namespace vkcompute { + +enum class QuantizationGranularity { + PerChannel, + PerTensor, + PerGroup, + NoQuantization, +}; + +static constexpr QuantizationGranularity kPerChannel = + QuantizationGranularity::PerChannel; +static constexpr QuantizationGranularity kPerTensor = + QuantizationGranularity::PerTensor; +static constexpr QuantizationGranularity kPerGroup = + QuantizationGranularity::PerGroup; +static constexpr QuantizationGranularity kNoQuantization = + QuantizationGranularity::NoQuantization; + +struct QuantizationConfig { + int nbits; + QuantizationGranularity granularity; + std::vector granularity_sizes; + bool is_symmetric; + bool is_dynamic; + + QuantizationConfig() + : nbits(8), + granularity(kPerTensor), + granularity_sizes(), + is_symmetric(true), + is_dynamic(false) {} + + QuantizationConfig( + int nbits_, + QuantizationGranularity granularity_, + const std::vector& granularity_sizes_, + bool is_symmetric_ = true, + bool is_dynamic_ = false) + : nbits(nbits_), + granularity(granularity_), + granularity_sizes(granularity_sizes_), + is_symmetric(is_symmetric_), + is_dynamic(is_dynamic_) {} +}; + +} // namespace vkcompute diff --git a/backends/vulkan/test/custom_ops/CMakeLists.txt b/backends/vulkan/test/custom_ops/CMakeLists.txt index f44db22c17e..6944fe59385 100644 --- a/backends/vulkan/test/custom_ops/CMakeLists.txt +++ b/backends/vulkan/test/custom_ops/CMakeLists.txt @@ -91,4 +91,5 @@ if(TARGET vulkan_backend) # Define operator prototypes add_operator_prototype(add) + add_operator_prototype(q8csw_linear) endif() diff --git a/backends/vulkan/test/custom_ops/q8csw_linear.cpp b/backends/vulkan/test/custom_ops/q8csw_linear.cpp new file mode 100644 index 00000000000..23973426fcc --- /dev/null +++ b/backends/vulkan/test/custom_ops/q8csw_linear.cpp @@ -0,0 +1,479 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. 
+ +#include +#include +#include +#include +#include "utils.h" + +#include + +using namespace executorch::vulkan::prototyping; + +using namespace vkcompute; + +static constexpr int64_t kRefDimSizeLimit = 300; + +// Linear configuration struct +struct LinearConfig { + int64_t M; // Batch size / number of rows in input + int64_t K; // Input features / columns in input, rows in weight + int64_t N; // Output features / columns in weight + bool has_bias = true; + std::string test_case_name = "placeholder"; + std::string op_name = "linear_q8ta_q8csw"; +}; + +// Utility function to create a test case from a LinearConfig +TestCase create_test_case_from_config( + const LinearConfig& config, + utils::StorageType storage_type, + vkapi::ScalarType input_dtype) { + TestCase test_case; + + // Create a descriptive name for the test case + std::string storage_str = + (storage_type == utils::kTexture3D) ? "Texture3D" : "Buffer"; + std::string dtype_str = (input_dtype == vkapi::kFloat) ? "Float" : "Half"; + + std::string test_name = + config.test_case_name + "_" + storage_str + "_" + dtype_str; + test_case.set_name(test_name); + + // Set the operator name for the test case + std::string operator_name = "et_vk." + config.op_name + ".default"; + test_case.set_operator_name(operator_name); + + // Derive sizes from M, K, N + std::vector input_size = {config.M, config.K}; + std::vector weight_size = {config.N, config.K}; + + // Input tensor (float/half) - [M, K] + ValueSpec input_tensor( + input_size, + input_dtype, + storage_type, + utils::kWidthPacked, + DataGenType::RANDOM); + + if (debugging()) { + print_valuespec_data(input_tensor, "input_tensor"); + } + + float input_scale_val = 0.008f; + ValueSpec input_scale(input_scale_val); + + int32_t input_zero_point_val = -2; + ValueSpec input_zero_point(input_zero_point_val); + + // Quantized weight tensor (int8) - [K, N] + ValueSpec quantized_weight( + weight_size, + vkapi::kChar, // int8 for quantized weights + storage_type, + utils::kWidthPacked, + DataGenType::RANDINT8); + quantized_weight.set_constant(true); + + if (debugging()) { + print_valuespec_data(quantized_weight, "weight_tensor"); + } + + // Weight quantization scales (float/half, per-channel) + ValueSpec weight_scales( + {config.N}, // Per output feature + input_dtype, + storage_type, + utils::kWidthPacked, + DataGenType::RANDOM_SCALES); + weight_scales.set_constant(true); + + ValueSpec weight_sums( + {config.N}, // Per output features + vkapi::kInt, + storage_type, + utils::kWidthPacked, + DataGenType::ZEROS); + weight_sums.set_constant(true); + + // Compute weight_sums data based on quantized weights + int64_t in_features = config.K; + int64_t out_features = config.N; + compute_weight_sums(weight_sums, quantized_weight, out_features, in_features); + + // Bias (optional, float/half) - [N] + ValueSpec bias( + {config.N}, // Per output feature + input_dtype, + storage_type, + utils::kWidthPacked, + DataGenType::RANDOM); + bias.set_constant(true); + if (!config.has_bias) { + bias.set_none(true); + } + + // Output tensor (float/half) - [M, N] + ValueSpec output( + {config.M, config.N}, + input_dtype, + storage_type, + utils::kWidthPacked, + DataGenType::ZEROS); + + // Add all specs to test case + if (config.op_name.find("q8ta") != std::string::npos) { + test_case.add_input_spec(input_tensor); + test_case.add_input_spec(input_scale); + test_case.add_input_spec(input_zero_point); + test_case.add_input_spec(quantized_weight); + test_case.add_input_spec(weight_sums); + test_case.add_input_spec(weight_scales); 
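// Note: the input specs for the q8ta variant are added in the same order that
// linear_q8ta_q8csw() unpacks its arguments in QuantizedLinear.cpp (input,
// input_scale, input_zp, weight, weight_sums, weight_scales, bias), with the
// output spec added last.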
+ test_case.add_input_spec(bias); + test_case.add_output_spec(output); + } else { + test_case.add_input_spec(input_tensor); + test_case.add_input_spec(quantized_weight); + test_case.add_input_spec(weight_scales); + test_case.add_input_spec(bias); + test_case.add_output_spec(output); + } + + return test_case; +} + +// Generate easy test cases for quantized linear operation (for debugging) +std::vector generate_quantized_linear_easy_cases() { + std::vector test_cases; + + // Single simple configuration for debugging + int M = 4; + int K = 4; + int N = 4; + + LinearConfig config = { + M, // Batch size + K, // Input features + N, // Output features + true, // has_bias + "simple", // test_case_name + }; + + // Test with both storage types and data types for completeness + std::vector storage_types = { + utils::kTexture3D, utils::kBuffer}; + std::vector float_types = {vkapi::kFloat}; + + // Generate test cases for each combination + for (const auto& storage_type : storage_types) { + for (const auto& input_dtype : float_types) { + test_cases.push_back( + create_test_case_from_config(config, storage_type, input_dtype)); + } + } + + return test_cases; +} + +// Generate test cases for quantized linear operation +std::vector generate_quantized_linear_test_cases() { + std::vector test_cases; + + std::vector configs = { + {4, 64, 32}, + {4, 128, 64}, + {4, 256, 128}, + {32, 64, 32}, + {32, 128, 64}, + {32, 256, 128}, + // No bias tests + {32, 128, 64, false}, + {32, 256, 128, false}, + {256, 2048, 2048}, + {512, 2048, 2048}, + {1024, 2048, 2048}, + }; + + // Test with different storage types and data types + std::vector storage_types = { + utils::kTexture3D, utils::kBuffer}; + + for (auto config : configs) { + std::string prefix = + (config.M < kRefDimSizeLimit && config.K < kRefDimSizeLimit && + config.N < kRefDimSizeLimit) + ? 
"correctness_" + : "performance_"; + std::string generated_test_case_name = prefix + std::to_string(config.M) + + "_" + std::to_string(config.K) + "_" + std::to_string(config.N); + if (!config.has_bias) { + generated_test_case_name += "_no_bias"; + } + + config.test_case_name = generated_test_case_name; + + for (const auto& storage_type : storage_types) { + if (vkcompute::api::context() + ->adapter_ptr() + ->supports_int8_dot_product()) { + // Test both activation+weight quantized and weight only quantized + test_cases.push_back( + create_test_case_from_config(config, storage_type, vkapi::kFloat)); + } + + LinearConfig wo_quant_config = config; + wo_quant_config.op_name = "linear_q8csw"; + test_cases.push_back(create_test_case_from_config( + wo_quant_config, storage_type, vkapi::kFloat)); + } + } + + return test_cases; +} + +// Reference implementation for weight only quantized linear +void linear_q8csw_reference_impl(TestCase& test_case) { + int32_t idx = 0; + const ValueSpec& input_spec = test_case.inputs()[idx++]; + const ValueSpec& weight_spec = test_case.inputs()[idx++]; + const ValueSpec& weight_scales_spec = test_case.inputs()[idx++]; + const ValueSpec& bias_spec = test_case.inputs()[idx++]; + + // Extract output specification (mutable reference) + ValueSpec& output_spec = test_case.outputs()[0]; + + // Get tensor dimensions + auto input_sizes = input_spec.get_tensor_sizes(); // [batch_size, in_features] + auto weight_sizes = + weight_spec.get_tensor_sizes(); // [out_features, in_features] + auto output_sizes = + output_spec.get_tensor_sizes(); // [batch_size, out_features] + + int64_t batch_size = input_sizes[0]; + int64_t in_features = input_sizes[1]; + int64_t out_features = weight_sizes[0]; + + // Skip for large tensors since computation time will be extremely slow + if (batch_size > kRefDimSizeLimit || in_features > kRefDimSizeLimit || + out_features > kRefDimSizeLimit) { + throw std::invalid_argument( + "One or more dimensions (batch_size, in_features, out_features) exceed the allowed limit for reference implementation."); + } + + if (input_spec.dtype != vkapi::kFloat) { + throw std::invalid_argument("Unsupported dtype"); + } + + // Get raw data pointers + auto& input_data = input_spec.get_float_data(); + + auto& weight_data = weight_spec.get_int8_data(); + auto& weight_scales_data = weight_scales_spec.get_float_data(); + auto& bias_data = bias_spec.get_float_data(); + + // Calculate number of output elements + int64_t num_output_elements = batch_size * out_features; + + auto& ref_data = output_spec.get_ref_float_data(); + ref_data.resize(num_output_elements); + + // Perform quantized linear transformation (matrix multiplication) + for (int64_t b = 0; b < batch_size; ++b) { + for (int64_t out_f = 0; out_f < out_features; ++out_f) { + float sum = 0.0f; + + // Matrix multiplication: output[b][out_f] = sum(input[b][in_f] * + // weight[out_f][in_f]) + for (int64_t in_f = 0; in_f < in_features; ++in_f) { + // Get input value and dequantize + int64_t input_idx = b * in_features + in_f; + float input_val = input_data[input_idx]; + + // Get weight value and dequantize + int64_t weight_idx = out_f * in_features + in_f; + float dequant_weight = (static_cast(weight_data[weight_idx])) * + weight_scales_data[out_f]; + + sum += input_val * dequant_weight; + } + + // Add bias and store result + if (!bias_spec.is_none()) { + sum += bias_data[out_f]; + } + int64_t output_idx = b * out_features + out_f; + ref_data[output_idx] = sum; + } + } +} + +void linear_q8ta_q8csw_reference_impl(TestCase& 
test_case) { + // Extract input specifications + int32_t idx = 0; + const ValueSpec& input_spec = test_case.inputs()[idx++]; + const ValueSpec& input_scale_spec = test_case.inputs()[idx++]; + const ValueSpec& input_zeros_spec = test_case.inputs()[idx++]; + const ValueSpec& weight_spec = test_case.inputs()[idx++]; + const ValueSpec& weight_sums_spec = test_case.inputs()[idx++]; + (void)weight_sums_spec; + const ValueSpec& weight_scales_spec = test_case.inputs()[idx++]; + const ValueSpec& bias_spec = test_case.inputs()[idx++]; + + // Extract output specification (mutable reference) + ValueSpec& output_spec = test_case.outputs()[0]; + + // Get tensor dimensions + auto input_sizes = input_spec.get_tensor_sizes(); // [batch_size, in_features] + auto weight_sizes = + weight_spec.get_tensor_sizes(); // [out_features, in_features] + auto output_sizes = + output_spec.get_tensor_sizes(); // [batch_size, out_features] + + int64_t batch_size = input_sizes[0]; + int64_t in_features = input_sizes[1]; + int64_t out_features = weight_sizes[0]; + + // Skip for large tensors since computation time will be extremely slow + if (batch_size > kRefDimSizeLimit || in_features > kRefDimSizeLimit || + out_features > kRefDimSizeLimit) { + throw std::invalid_argument( + "One or more dimensions (batch_size, in_features, out_features) exceed the allowed limit for reference implementation."); + } + + if (input_spec.dtype != vkapi::kFloat) { + throw std::invalid_argument("Unsupported dtype"); + } + + // Get raw data pointers + auto& input_data = input_spec.get_float_data(); + const float input_scale = input_scale_spec.get_float_value(); + const int32_t input_zero_point = input_zeros_spec.get_int_value(); + + auto& weight_data = weight_spec.get_int8_data(); + auto& weight_scales_data = weight_scales_spec.get_float_data(); + auto& bias_data = bias_spec.get_float_data(); + + // Calculate number of output elements + int64_t num_output_elements = batch_size * out_features; + + auto& ref_data = output_spec.get_ref_float_data(); + ref_data.resize(num_output_elements); + + // Perform quantized linear transformation (matrix multiplication) with + // integer accumulation + for (int64_t b = 0; b < batch_size; ++b) { + for (int64_t out_f = 0; out_f < out_features; ++out_f) { + int32_t int_sum = 0; + int32_t weight_sum = 0; // Track weight sum on the fly + + // Matrix multiplication with integer accumulation: + // int_sum = sum(quantized_input[b][in_f] * quantized_weight[out_f][in_f]) + for (int64_t in_f = 0; in_f < in_features; ++in_f) { + // Get input value and quantize to int8 + int64_t input_idx = b * in_features + in_f; + + float quant_input_f = + std::round(input_data[input_idx] / input_scale) + input_zero_point; + quant_input_f = std::min(std::max(quant_input_f, -128.0f), 127.0f); + int8_t quantized_input = static_cast(quant_input_f); + + // Get quantized weight (already int8) + int64_t weight_idx = out_f * in_features + in_f; + int8_t quantized_weight = weight_data[weight_idx]; + + // Integer multiplication and accumulation + int_sum += static_cast(quantized_input) * + static_cast(quantized_weight); + + // Track weight sum for this output channel on the fly + weight_sum += static_cast(quantized_weight); + } + + // Convert accumulated integer result to float and apply scales + // Final result = (int_sum - zero_point_correction) * input_scale * + // weight_scale + bias zero_point_correction = input_zero_point * + // sum_of_weights_for_this_output_channel + int32_t zero_point_correction = input_zero_point * weight_sum; + int32_t 
accum_adjusted = int_sum - zero_point_correction; + + float float_result = + accum_adjusted * input_scale * weight_scales_data[out_f]; + + // Add bias and store result + if (!bias_spec.is_none()) { + float_result += bias_data[out_f]; + } + int64_t output_idx = b * out_features + out_f; + ref_data[output_idx] = float_result; + } + } +} + +void reference_impl(TestCase& test_case) { + if (test_case.operator_name().find("q8ta") != std::string::npos) { + linear_q8ta_q8csw_reference_impl(test_case); + } else { + linear_q8csw_reference_impl(test_case); + } +} + +int64_t quantized_linear_flop_calculator(const TestCase& test_case) { + int input_idx = 0; + int weight_idx = 1; + if (test_case.operator_name().find("q8ta") != std::string::npos) { + input_idx = 0; + weight_idx = 3; + } + + // Get input and weight dimensions + const auto& input_sizes = test_case.inputs()[input_idx].get_tensor_sizes(); + const auto& weight_sizes = test_case.inputs()[weight_idx].get_tensor_sizes(); + + int64_t batch_size = input_sizes[0]; + int64_t in_features = input_sizes[1]; + int64_t out_features = weight_sizes[0]; + + // Calculate FLOPs for quantized linear operation + // Each output element requires: + // - in_features multiply-accumulate operations + // - Additional operations for quantization/dequantization + int64_t output_elements = batch_size * out_features; + int64_t ops_per_output = in_features; + + // Add quantization overhead (approximate) + // - Dequantize input: 1 op per input element used + // - Dequantize weight: 1 op per weight element used + // - Add bias: 1 op per output element + int64_t quantization_ops = ops_per_output + 1; // Simplified estimate + + int64_t flop = output_elements * (ops_per_output + quantization_ops); + + return flop; +} + +int main(int argc, char* argv[]) { + set_debugging(false); + set_print_output(false); + set_print_latencies(false); + set_use_gpu_timestamps(true); + + print_performance_header(); + std::cout << "Quantized Linear Operation Prototyping Framework" << std::endl; + print_separator(); + + ReferenceComputeFunc ref_fn = reference_impl; + + auto results = execute_test_cases( + generate_quantized_linear_test_cases, + quantized_linear_flop_calculator, + "QuantizedLinear", + 0, + 10, + ref_fn); + + return 0; +} diff --git a/backends/vulkan/test/custom_ops/targets.bzl b/backends/vulkan/test/custom_ops/targets.bzl index 2ddf49834e1..68bdc9e6fbd 100644 --- a/backends/vulkan/test/custom_ops/targets.bzl +++ b/backends/vulkan/test/custom_ops/targets.bzl @@ -12,7 +12,6 @@ def define_custom_op_test_binary(custom_op_name, extra_deps = [], src_file = Non ":operator_implementations", ":custom_ops_shaderlib", "//executorch/backends/vulkan:vulkan_graph_runtime", - "//executorch/runtime/core/exec_aten:lib", runtime.external_dep_location("libtorch"), ] + extra_deps @@ -68,8 +67,6 @@ def define_common_targets(is_fbcode = False): platforms = get_platforms(), deps = [ "//executorch/backends/vulkan:vulkan_graph_runtime", - "//executorch/runtime/core/exec_aten:lib", - runtime.external_dep_location("libtorch"), ], visibility = [ "//executorch/backends/vulkan/test/custom_ops/...", @@ -86,7 +83,6 @@ def define_common_targets(is_fbcode = False): platforms = get_platforms(), deps = [ "//executorch/backends/vulkan:vulkan_graph_runtime", - "//executorch/runtime/core/exec_aten:lib", ":custom_ops_shaderlib", ], visibility = [ @@ -97,3 +93,4 @@ def define_common_targets(is_fbcode = False): ) define_custom_op_test_binary("add") + define_custom_op_test_binary("q8csw_linear") diff --git 
a/backends/vulkan/test/custom_ops/utils.cpp b/backends/vulkan/test/custom_ops/utils.cpp index 235a6bd293e..ee2f6858025 100644 --- a/backends/vulkan/test/custom_ops/utils.cpp +++ b/backends/vulkan/test/custom_ops/utils.cpp @@ -694,7 +694,8 @@ void BenchmarkResult::print_summary( int case_number, const std::string& size_info, float total_gflops) const { - static constexpr int KERNEL_NAME_WIDTH = 140; + static constexpr int OPERATOR_NAME_WIDTH = 50; + static constexpr int KERNEL_NAME_WIDTH = 70; static constexpr int SIZE_INFO_WIDTH = 20; static constexpr int TIMING_WIDTH = 20; static constexpr int GFLOPS_WIDTH = 20; @@ -713,8 +714,10 @@ void BenchmarkResult::print_summary( break; } - std::cout << std::left << std::setw(KERNEL_NAME_WIDTH) << get_kernel_name() - << std::right << " " << std::setw(SIZE_INFO_WIDTH) << size_info + std::cout << std::left << std::setw(OPERATOR_NAME_WIDTH) + << get_operator_name() << " " << std::left + << std::setw(KERNEL_NAME_WIDTH) << get_kernel_name() << std::right + << " " << std::setw(SIZE_INFO_WIDTH) << size_info << std::setw(TIMING_WIDTH) << std::fixed << std::setprecision(3) << get_avg_time_us() << " μs " << std::setw(GFLOPS_WIDTH) << std::fixed << std::setprecision(3) << total_gflops << " GFLOP/s " @@ -999,7 +1002,9 @@ ComputeGraph setup_compute_graph(TestCase& test_case, std::string op_name) { for (size_t i = 0; i < test_case.num_inputs(); ++i) { const ValueSpec& input_spec = test_case.inputs()[i]; - if (input_spec.is_float()) { + if (input_spec.is_none()) { + input_values.push_back(graph.add_none()); + } else if (input_spec.is_float()) { ValueRef input_value = graph.add_scalar(static_cast(input_spec.get_float_value())); input_values.push_back(input_value); @@ -1246,9 +1251,11 @@ TestResult execute_test_cases( bool shader_not_supported = false; try { result = execute_test_case(test_case, warmup_runs, benchmark_runs); + result.set_operator_name(test_case.operator_name()); } catch (const vkcompute::vkapi::ShaderNotSupportedError& e) { result = BenchmarkResult( - test_case.name().empty() ? "unnamed_test_case" : test_case.name()); + test_case.name().empty() ? 
"unnamed_test_case" : test_case.name(), + test_case.operator_name()); shader_not_supported = true; } @@ -1606,20 +1613,21 @@ void compute_weight_sums( const ValueSpec& quantized_weight, int64_t out_features, int64_t elements_per_output_feature) { - auto& weight_sums_data = weight_sums.get_float_data(); + auto& weight_sums_data = weight_sums.get_int32_data(); auto& quantized_weight_data = quantized_weight.get_int8_data(); weight_sums_data.resize(out_features); // For each output feature, compute the sum of quantized weights for (int64_t out_f = 0; out_f < out_features; ++out_f) { - float sum = 0.0f; + int32_t sum = 0; for (int64_t elem = 0; elem < elements_per_output_feature; ++elem) { // Weight indexing depends on the layout: - // For linear: [in_features, out_features] -> elem * out_features + out_f - // For conv2d: [C_in * K_h * K_w, C_out] -> elem * out_features + out_f - int64_t weight_idx = elem * out_features + out_f; - sum += static_cast(quantized_weight_data[weight_idx]); + // For linear: [out_features, in_features] -> out_f * + // elements_per_output_feature + elem For conv2d: [C_out, C_in * K_h * + // K_w] -> out_f * elements_per_output_feature + elem + int64_t weight_idx = out_f * elements_per_output_feature + elem; + sum += static_cast(quantized_weight_data[weight_idx]); } weight_sums_data[out_f] = sum; } diff --git a/backends/vulkan/test/custom_ops/utils.h b/backends/vulkan/test/custom_ops/utils.h index 5ca05dc824f..6c4e2263fc1 100644 --- a/backends/vulkan/test/custom_ops/utils.h +++ b/backends/vulkan/test/custom_ops/utils.h @@ -66,6 +66,7 @@ struct ValueSpec { SpecType spec_type; DataGenType data_gen_type; bool is_constant_tensor; + bool is_none_flag; std::vector float_data; std::vector int32_data; @@ -90,7 +91,8 @@ struct ValueSpec { storage_type(storage_type), spec_type(SpecType::Tensor), data_gen_type(DataGenType::ZEROS), - is_constant_tensor(false) { + is_constant_tensor(false), + is_none_flag(false) { generate_tensor_data(); } @@ -107,7 +109,8 @@ struct ValueSpec { storage_type(storage_type), spec_type(SpecType::Tensor), data_gen_type(data_gen_type), - is_constant_tensor(false) { + is_constant_tensor(false), + is_none_flag(false) { generate_tensor_data(); } @@ -119,7 +122,8 @@ struct ValueSpec { storage_type(utils::kTexture3D), spec_type(SpecType::Int), data_gen_type(DataGenType::FIXED), - is_constant_tensor(false) { + is_constant_tensor(false), + is_none_flag(false) { int32_data.push_back(value); } @@ -131,7 +135,8 @@ struct ValueSpec { storage_type(utils::kTexture3D), spec_type(SpecType::Float), data_gen_type(DataGenType::FIXED), - is_constant_tensor(false) { + is_constant_tensor(false), + is_none_flag(false) { float_data.push_back(value); } @@ -143,7 +148,8 @@ struct ValueSpec { storage_type(utils::kTexture3D), spec_type(SpecType::Bool), data_gen_type(DataGenType::FIXED), - is_constant_tensor(false) { + is_constant_tensor(false), + is_none_flag(false) { int32_data.push_back(value ? 
1 : 0); } @@ -156,6 +162,7 @@ struct ValueSpec { spec_type(SpecType::IntList), data_gen_type(DataGenType::FIXED), is_constant_tensor(false), + is_none_flag(false), int32_data(values) {} // Default constructor @@ -165,7 +172,8 @@ struct ValueSpec { storage_type(utils::kTexture3D), spec_type(SpecType::Tensor), data_gen_type(DataGenType::ZEROS), - is_constant_tensor(false) {} + is_constant_tensor(false), + is_none_flag(false) {} int64_t numel() const; size_t nbytes() const; @@ -279,6 +287,14 @@ struct ValueSpec { is_constant_tensor = is_constant; } + // Set/get none flag + bool is_none() const { + return is_none_flag; + } + void set_none(bool is_none) { + is_none_flag = is_none; + } + const void* get_data_ptr() const; // Correctness checking against reference data @@ -401,6 +417,13 @@ class BenchmarkResult { BenchmarkResult(const std::string& name) : kernel_name(name), correctness_status_(CorrectnessStatus::SKIPPED) {} + BenchmarkResult( + const std::string& kernel_name, + const std::string& operator_name) + : kernel_name(kernel_name), + operator_name(operator_name), + correctness_status_(CorrectnessStatus::SKIPPED) {} + // Add timing for a single iteration void add_iter_timing(float time_us); @@ -408,6 +431,9 @@ class BenchmarkResult { const std::string& get_kernel_name() const { return kernel_name; } + const std::string& get_operator_name() const { + return operator_name; + } float get_avg_time_us() const; size_t get_num_iterations() const { return iter_timings.size(); @@ -423,6 +449,9 @@ class BenchmarkResult { void set_kernel_name(const std::string& name) { kernel_name = name; } + void set_operator_name(const std::string& name) { + operator_name = name; + } void set_correctness_status(CorrectnessStatus status) { correctness_status_ = status; } @@ -445,6 +474,7 @@ class BenchmarkResult { private: std::string kernel_name; + std::string operator_name; std::vector iter_timings; // Individual iteration timings in microseconds CorrectnessStatus correctness_status_;