diff --git a/.github/workflows/pull.yml b/.github/workflows/pull.yml index 379d47716c9..084cfe17a4d 100644 --- a/.github/workflows/pull.yml +++ b/.github/workflows/pull.yml @@ -929,7 +929,9 @@ jobs: CMAKE_ARGS="-DEXECUTORCH_BUILD_VULKAN=ON" \ .ci/scripts/setup-linux.sh --build-tool "cmake" + # Custom operator tests PYTHON_EXECUTABLE=python bash backends/vulkan/test/custom_ops/build_and_run.sh add + ./cmake-out/backends/vulkan/test/custom_ops/q8csw_linear nxp-build-test: name: nxp-build-test diff --git a/backends/vulkan/runtime/graph/ComputeGraph.h b/backends/vulkan/runtime/graph/ComputeGraph.h index 78fb79e65e8..4e9e2d36e1e 100644 --- a/backends/vulkan/runtime/graph/ComputeGraph.h +++ b/backends/vulkan/runtime/graph/ComputeGraph.h @@ -308,6 +308,10 @@ class ComputeGraph final { return idx == kDummyValueRef ? true : values_.at(idx).isNone(); } + inline bool val_is_not_none(const ValueRef idx) { + return !val_is_none(idx); + } + inline TypeTag get_val_type(const ValueRef idx) { return values_.at(idx).type(); } diff --git a/backends/vulkan/runtime/graph/ops/glsl/common.glslh b/backends/vulkan/runtime/graph/ops/glsl/common.glslh new file mode 100644 index 00000000000..732b7006c2c --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/common.glslh @@ -0,0 +1,51 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#ifndef COMMON_GLSLH +#define COMMON_GLSLH + +#define mul_2(x) ((x) << 1) +#define mul_4(x) ((x) << 2) +#define mul_8(x) ((x) << 3) + +#define div_2(x) ((x) >> 1) +#define div_4(x) ((x) >> 2) +#define div_8(x) ((x) >> 3) + +#define div_up_2(x) (((x) + 1) >> 1) +#define div_up_4(x) (((x) + 3) >> 2) +#define div_up_8(x) (((x) + 7) >> 3) + +#define align_up_2(x) ((x + 1) & -2) +#define align_up_4(x) ((x + 3) & -4) +#define align_up_8(x) ((x + 7) & -8) + +#define mod_2(x) ((x) & 1) +#define mod_4(x) ((x) & 3) +#define mod_8(x) ((x) & 7) + +struct TensorIndex4D { + ivec4 data; +}; + +#ifdef DEBUG_MODE + +#extension GL_EXT_debug_printf : require + +void printTensorIndex4D(const TensorIndex4D index) { + debugPrintfEXT( + "tensor_idx: %d, %d, %d, %d\\n", + index.data.x, + index.data.y, + index.data.z, + index.data.w); +} + +#endif // DEBUG_MODE + +#endif // COMMON_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_common.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_common.glslh new file mode 100644 index 00000000000..90ede450ae7 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_common.glslh @@ -0,0 +1,32 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +/* + * Defines common functions and structs to be used across matrix multiplication + * operators. 
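+ *
+ * The helpers below assume 4x 8-bit values packed into one 32-bit int in
+ * little endian order; e.g. for packed = 0x807F0201,
+ * extract_8bit_from_packed_int_le(packed, 0) == 1 and
+ * extract_8bit_from_packed_int_le(packed, 3) == -128 (sign extended).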
+ */ + +#ifndef LINEAR_COMMON_GLSLH +#define LINEAR_COMMON_GLSLH + +#include "common.glslh" + +int sign_extend_8bit(const int val) { + if ((val & 0x80) != 0) { + return val | (~0xFF); + } + return val; +} + +int extract_8bit_from_packed_int_le(const int packed, const int i) { + // account for little endian + int byte = sign_extend_8bit(packed >> (8 * i) & 0xFF); + return byte; +} + +#endif // LINEAR_COMMON_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_fp_bias_load.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_bias_load.glslh new file mode 100644 index 00000000000..f3d32be8b3d --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_bias_load.glslh @@ -0,0 +1,30 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#ifndef LINEAR_FP_BIAS_LOAD_GLSLH +#define LINEAR_FP_BIAS_LOAD_GLSLH + +#include "linear_fp_per_out_channel_params.glslh" + +VEC4_T load_bias_x4(const int n4) { + return t_bias[n4]; +} + +void load_bias_tile(out FPPerOutChannelParams bias, const int n4_start) { +#if TILE_N4 == 1 + bias.data[0] = load_bias_x4(n4_start); + +#else + [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) { + bias.data[n4] = load_bias_x4(n4_start + n4); + } + +#endif +} + +#endif // LINEAR_FP_BIAS_LOAD_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_fp_input_tile.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_input_tile.glslh new file mode 100644 index 00000000000..68eee57a132 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_input_tile.glslh @@ -0,0 +1,45 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#ifndef LINEAR_FP_INPUT_TILE_GLSLH +#define LINEAR_FP_INPUT_TILE_GLSLH + +/* + * Defines the FPInputTile struct, which is used to represent a tile of the + * input matrix of a matrix multiplication operation. + * + * Settings: + * - TILE_M: number of rows in the tile + * - TILE_K4: number of (groups of 4) columns in the tile + */ + +#extension GL_EXT_control_flow_attributes : require + +struct FPInputTile { + VEC4_T data[TILE_M][TILE_K4]; +}; + +#ifdef DEBUG_MODE + +void printFPInputTile(const FPInputTile in_tile) { + debugPrintfEXT("input_tile: \\n"); + [[unroll]] for (int m = 0; m < TILE_M; ++m) { + [[unroll]] for (int k4 = 0; k4 < TILE_K4; ++k4) { + debugPrintfEXT( + " %f, %f, %f, %f, \\n", + in_tile.data[m][k4].x, + in_tile.data[m][k4].y, + in_tile.data[m][k4].z, + in_tile.data[m][k4].w); + } + } +} + +#endif // DEBUG_MODE + +#endif // LINEAR_FP_INPUT_TILE_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_fp_input_tile_load.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_input_tile_load.glslh new file mode 100644 index 00000000000..6697003935f --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_input_tile_load.glslh @@ -0,0 +1,91 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +/* + * Defines functions to load a FPInputTile from input buffer/texture. 
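+ *
+ * Each texel/vec4 holds 4 consecutive elements along the K (reduction) dim, so
+ * input element (m, k) is read from texel (x = k / 4, y = m); for buffer
+ * storage this corresponds to t_input[(m * ntexels_k) + k4].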
+ * + * Requires: + * - t_input to be declared in the shader layout (input buffer/texture) + * + * Settings: + * - INPUT_BUFFER to indicate input resource is a buffer, otherwise texture is + * assumed. + */ + +#ifndef LINEAR_FP_INPUT_TILE_LOAD_GLSLH +#define LINEAR_FP_INPUT_TILE_LOAD_GLSLH + +#extension GL_EXT_control_flow_attributes : require + +#include "linear_fp_input_tile.glslh" + +#ifdef INPUT_BUFFER + +VEC4_T load_input_x4(const int k4, const int m, const int ntexels_k) { + return t_input[(m * ntexels_k) + k4]; +} + +#else + +VEC4_T load_input_x4(const int k4, const int m, const int ntexels_k) { + return texelFetch(t_input, ivec3(k4, m, 0), 0); +} + +#endif // INPUT_BUFFER + +// To be used if (M - m_start >= TILE_M) || (K4 - k4_start >= TILE_K4) +void load_input_tile_no_checks( + out FPInputTile in_tile, + const int k4_start, + const int m_start, + const int K4, + const int M) { +#if TILE_K4 == 1 + [[unroll]] for (int m = 0; m < TILE_M; ++m) { + in_tile.data[m][0] = load_input_x4(k4_start, m_start + m, K4); + } + +#else + [[unroll]] for (int m = 0; m < TILE_M; ++m) { + [[unroll]] for (int k4 = 0; k4 < TILE_K4; ++k4) { + in_tile.data[m][k4] = load_input_x4(k4_start + k4, m_start + m, K4); + } + } +#endif +} + +// To be used if near tensor boundaries +void load_input_tile_with_checks( + out FPInputTile in_tile, + const int k4_start, + const int m_start, + const int K4, + const int M) { +#if TILE_K4 == 1 + [[unroll]] for (int m = 0; m < TILE_M; ++m) { + if (m_start + m < M) { + in_tile.data[m][0] = load_input_x4(k4_start, m_start + m, K4); + } else { + in_tile.data[m][0] = VEC4_T(0.0); + } + } + +#else + [[unroll]] for (int m = 0; m < TILE_M; ++m) { + [[unroll]] for (int k4 = 0; k4 < TILE_K4; ++k4) { + if (m_start + m < M && k4_start + k4 < K4) { + in_tile.data[m][k4] = load_input_x4(k4_start + k4, m_start + m, K4); + } else { + in_tile.data[m][k4] = VEC4_T(0.0); + } + } + } +#endif +} + +#endif // LINEAR_FP_INPUT_TILE_LOAD_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile.glslh new file mode 100644 index 00000000000..049f1d34caf --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile.glslh @@ -0,0 +1,61 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +/* + * Defines the FPOutTile struct, which is used to represent a tile of the output + * matrix of a matrix multiplication operation. 
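+ *
+ * data[m][n4] is a VEC4_T containing the output elements at row m and columns
+ * 4 * n4 .. 4 * n4 + 3 of the tile.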
+ * + * Settings: + * - TILE_M: number of rows in the output tile + * - TILE_N4: number of (groups of 4) columns in the output tile + */ + +#ifndef LINEAR_FP_OUTPUT_TILE_GLSLH +#define LINEAR_FP_OUTPUT_TILE_GLSLH + +#extension GL_EXT_control_flow_attributes : require + +struct FPOutTile { + VEC4_T data[TILE_M][TILE_N4]; +}; + +void initialize(out FPOutTile out_tile) { +#if TILE_N4 == 1 + [[unroll]] for (int m = 0; m < TILE_M; ++m) { + out_tile.data[m][0] = VEC4_T(0); + } + +#else + [[unroll]] for (int m = 0; m < TILE_M; ++m) { + [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) { + out_tile.data[m][n4] = VEC4_T(0); + } + } +#endif +} + +#ifdef DEBUG_MODE + +void printFPOutTile(const FPOutTile tile) { + debugPrintfEXT("output_tile: \\n"); + [[unroll]] for (int m = 0; m < TILE_M; ++m) { + [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) { + debugPrintfEXT( + " %f, %f, %f, %f,", + tile.data[m][n4].x, + tile.data[m][n4].y, + tile.data[m][n4].z, + tile.data[m][n4].w); + } + debugPrintfEXT("\\n"); + } +} + +#endif // DEBUG_MODE + +#endif // LINEAR_FP_OUTPUT_TILE_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_fp_compute.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_fp_compute.glslh new file mode 100644 index 00000000000..7229da32cd3 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_fp_compute.glslh @@ -0,0 +1,126 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +/* + * Defines functions to compute a FPOutTile using fp input and weight tiles. + */ + +#ifndef LINEAR_FP_OUTPUT_TILE_FP_COMPUTE_GLSLH +#define LINEAR_FP_OUTPUT_TILE_FP_COMPUTE_GLSLH + +#extension GL_EXT_control_flow_attributes : require + +#include "linear_common.glslh" +#include "linear_fp_input_tile.glslh" +#include "linear_fp_output_tile.glslh" +#include "linear_fp_per_out_channel_params.glslh" +#include "linear_fp_weight_tile.glslh" + +/* + * Accumulates floating point input tile and floating point weight tile into + * floating point output tile. 
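+ *
+ * Computes accum[m][n4] += sum over k of in[m][k] * w[k][n4]: each scalar
+ * input element is broadcast across a vec4 and multiplied with fma against
+ * w_tile.data[k][n4], which holds the weights for 4 output channels.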
+ */ +void fp_accumulate_with_fp_weight( + inout FPOutTile accum, + FPInputTile in_tile, + FPWeightTile w_tile) { +#if TILE_N4 == 1 && TILE_K4 == 1 + [[unroll]] for (int m = 0; m < TILE_M; ++m) { + accum.data[m][0] = + fma(VEC4_T(in_tile.data[m][0][0]), + w_tile.data[mul_4(0)][0], + accum.data[m][0]); + + accum.data[m][0] = + fma(VEC4_T(in_tile.data[m][0][1]), + w_tile.data[mul_4(0) + 1][0], + accum.data[m][0]); + + accum.data[m][0] = + fma(VEC4_T(in_tile.data[m][0][2]), + w_tile.data[mul_4(0) + 2][0], + accum.data[m][0]); + + accum.data[m][0] = + fma(VEC4_T(in_tile.data[m][0][3]), + w_tile.data[mul_4(0) + 3][0], + accum.data[m][0]); + } + +#else + [[unroll]] for (int m = 0; m < TILE_M; ++m) { + [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) { + const int n = mul_4(n4); + [[unroll]] for (int k4 = 0; k4 < TILE_K4; ++k4) { + accum.data[m][n4] = + fma(VEC4_T(in_tile.data[m][k4][0]), + w_tile.data[mul_4(k4)][n4], + accum.data[m][n4]); + + accum.data[m][n4] = + fma(VEC4_T(in_tile.data[m][k4][1]), + w_tile.data[mul_4(k4) + 1][n4], + accum.data[m][n4]); + + accum.data[m][n4] = + fma(VEC4_T(in_tile.data[m][k4][2]), + w_tile.data[mul_4(k4) + 2][n4], + accum.data[m][n4]); + + accum.data[m][n4] = + fma(VEC4_T(in_tile.data[m][k4][3]), + w_tile.data[mul_4(k4) + 3][n4], + accum.data[m][n4]); + } + } + } + +#endif +} + +/* + * Applies per output channel weight scales to the output tile. + */ +void apply_scales(inout FPOutTile tile, const FPPerOutChannelParams scales) { +#if TILE_M > 1 && TILE_N4 == 1 + [[unroll]] for (int m = 0; m < TILE_M; ++m) { + tile.data[m][0] = tile.data[m][0] * scales.data[0]; + } + +#else + [[unroll]] for (int m = 0; m < TILE_M; ++m) { + [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) { + tile.data[m][n4] = tile.data[m][n4] * scales.data[n4]; + } + } +#endif +} + +/* + * Applies per output channel weight scales and per output channel biases to the + * output tile. + */ +void apply_scales_and_biases( + inout FPOutTile tile, + const FPPerOutChannelParams scales, + const FPPerOutChannelParams bias) { +#if TILE_M > 1 && TILE_N4 == 1 + [[unroll]] for (int m = 0; m < TILE_M; ++m) { + tile.data[m][0] = tile.data[m][0] * scales.data[0] + bias.data[0]; + } + +#else + [[unroll]] for (int m = 0; m < TILE_M; ++m) { + [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) { + tile.data[m][n4] = tile.data[m][n4] * scales.data[n4] + bias.data[n4]; + } + } +#endif +} + +#endif // LINEAR_FP_OUTPUT_TILE_FP_COMPUTE_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_fp_int8_compute.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_fp_int8_compute.glslh new file mode 100644 index 00000000000..b2ab64a1573 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_fp_int8_compute.glslh @@ -0,0 +1,68 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +/* + * Defines functions to compute a FPOutTile using fp input and weight tiles. 
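+ * Here the weight tile is an Int8WeightTile; each packed 8-bit value is
+ * unpacked to floating point on the fly before being accumulated with fma.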
+ */ + +#ifndef LINEAR_FP_OUTPUT_TILE_FP_INT8_COMPUTE_GLSLH +#define LINEAR_FP_OUTPUT_TILE_FP_INT8_COMPUTE_GLSLH + +#extension GL_EXT_control_flow_attributes : require + +#include "linear_common.glslh" +#include "linear_fp_input_tile.glslh" +#include "linear_fp_output_tile.glslh" +#include "linear_int8_weight_tile.glslh" + +// Unpacks a int containing 4 packed 8-bit integers into a vec4 containing each +// of the 4 unpacked 8-bit integers. +VEC4_T unpack_packed_4xint8(int int8x4) { + return VEC4_T( + extract_8bit_from_packed_int_le(int8x4, 0), + extract_8bit_from_packed_int_le(int8x4, 1), + extract_8bit_from_packed_int_le(int8x4, 2), + extract_8bit_from_packed_int_le(int8x4, 3)); +} + +void fp_accumulate_with_int8_weight( + inout FPOutTile accum, + FPInputTile in_tile, + Int8WeightTile w_tile) { + // Accum tile is indexed as accum[m][n4][n4i] + // -> gives fp accumulator for output tile element at (x = n, y = m) + // Input tile is indexed as in_tile.data[m][k4] + // -> gives vec4 containing the fp inputs at index + // (k, m), (k + 1, m), (k + 2, m), (k + 3, m) + // Weight tile is indexed as w_tile.data[k4][n4][n4i] + // -> gives packed integer containing the 4x 8-bit quantized values at index + // (n, k), (n, k + 1), (n, k + 2), (n, k + 3) + VEC4_T weight_texel; +#if TILE_K4 == 1 && TILE_N4 == 1 + [[unroll]] for (int k = 0; k < 4; ++k) { + // Unpack one column of weights + weight_texel = VEC4_T( + extract_8bit_from_packed_int_le(w_tile.data[0][0][0], k), + extract_8bit_from_packed_int_le(w_tile.data[0][0][1], k), + extract_8bit_from_packed_int_le(w_tile.data[0][0][2], k), + extract_8bit_from_packed_int_le(w_tile.data[0][0][3], k)); + + for (int m = 0; m < TILE_M; ++m) { + accum.data[m][0] = + fma(VEC4_T(in_tile.data[m][0][k]), weight_texel, accum.data[m][0]); + } + } + +#else + // TODO(ssjia): implement the general case + not implemented + +#endif +} + +#endif // LINEAR_FP_OUTPUT_TILE_FP_INT8_COMPUTE_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_int8_int8_compute.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_int8_int8_compute.glslh new file mode 100644 index 00000000000..b04074eba75 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_int8_int8_compute.glslh @@ -0,0 +1,179 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +/* + * Defines functions to compute a FPOutTile using int8 input and weight tiles. + * + * Settings: + * - TILE_M: The number of rows in the output tile. + * - TILE_N4: The number of (groups of 4) columns in the output tile. + */ + +#ifndef LINEAR_FP_OUTPUT_TILE_INT8_INT8_COMPUTE_GLSLH +#define LINEAR_FP_OUTPUT_TILE_INT8_INT8_COMPUTE_GLSLH + +#extension GL_EXT_control_flow_attributes : require +#extension GL_EXT_integer_dot_product : require + +#include "linear_common.glslh" +#include "linear_fp_output_tile.glslh" +#include "linear_fp_per_out_channel_params.glslh" +#include "linear_int8_input_tile.glslh" +#include "linear_int8_weight_tile.glslh" +#include "linear_int_per_out_channel_params.glslh" + +// Stores integer accumulators for an output tile. 
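+// data[m][n4][n4i] holds the running int32 dot product between input row m and
+// weight row n = 4 * n4 + n4i, accumulated via dotPacked4x8AccSatEXT.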
+struct Int32Accum { + ivec4 data[TILE_M][TILE_N4]; +}; + +// Initialize values to 0 +void initialize(out Int32Accum out_accum) { +#if TILE_N4 == 1 + [[unroll]] for (int y = 0; y < TILE_M; ++y) { + out_accum.data[y][0] = ivec4(0); + } + +#else + [[unroll]] for (int y = 0; y < TILE_M; ++y) { + [[unroll]] for (int x4 = 0; x4 < TILE_K4; ++x4) { + out_accum.data[y][x4] = ivec4(0); + } + } +#endif +} + +// Accumulate int8 input and weight tiles into integer accumulator tile +void int_accumulate_with_int8_weight( + inout Int32Accum accum, + Int8InputTile in_tile, + Int8WeightTile w_tile) { + // Accum tile is indexed as accum[m][n4][n4i] + // -> gives integer accumulator for output tile element at (x = n, y = m) + // Input tile is indexed as in_tile.data[m4][k4][m4i] + // -> gives packed integer containing the 4x 8-bit quantized values at index + // (k, m), (k + 1, m), (k + 2, m), (k + 3, m) + // Weight tile is indexed as w_tile.data[k4][n4][n4i] + // -> gives packed integer containing the 4x 8-bit quantized values at index + // (n, k), (n, k + 1), (n, k + 2), (n, k + 3) +#if TILE_M4 == 1 && TILE_K4 == 1 && TILE_N4 == 1 + [[unroll]] for (int m = 0; m < TILE_M; ++m) { + // n = 0 + accum.data[m][0][0] = dotPacked4x8AccSatEXT( + in_tile.data[0][0][m], w_tile.data[0][0][0], accum.data[m][0][0]); + // n = 1 + accum.data[m][0][1] = dotPacked4x8AccSatEXT( + in_tile.data[0][0][m], w_tile.data[0][0][1], accum.data[m][0][1]); + // n = 2 + accum.data[m][0][2] = dotPacked4x8AccSatEXT( + in_tile.data[0][0][m], w_tile.data[0][0][2], accum.data[m][0][2]); + // n = 3 + accum.data[m][0][3] = dotPacked4x8AccSatEXT( + in_tile.data[0][0][m], w_tile.data[0][0][3], accum.data[m][0][3]); + } + +#else + [[unroll]] for (int m = 0; m < TILE_M; ++m) { + const int m4 = div_4(m); + const int m4i = mod_4(m); + [[unroll]] for (int n = 0; n < TILE_N; ++n) { + const int n4 = div_4(n); + const int n4i = mod_4(n); + [[unroll]] for (int k4 = 0; k4 < TILE_K4; ++k4) { + accum.data[m][n4][n4i] = dotPacked4x8AccSatEXT( + in_tile.data[m4][k4][m4i], + w_tile.data[k4][n4][n4i], + accum.data[m][n4][n4i]); + } + } + } + +#endif +} + +/* + * Computes final weight matrix output tile using: + * - int8 accumulator tile + * - per output channel weight sums + * - per output channel scales + */ +void accumulate_out_tile_with_int_accum( + inout FPOutTile out_tile, + const Int32Accum accum, + const float input_q_scale, + const int input_q_zp, + const IntPerOutChannelParams weight_sums, + const FPPerOutChannelParams weight_scales) { + ivec4 input_zp_vec = ivec4(-input_q_zp); +#if TILE_N4 == 1 + [[unroll]] for (int m = 0; m < TILE_M; ++m) { + // Unfortunately fma doesn't work with ivec4. Prefer to preserve integer + // format for as long as possible to avoid precision loss. 
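+    // accum = sum_k q_in * q_w and weight_sums = sum_k q_w, so
+    // (accum - input_q_zp * weight_sums) * (input_q_scale * weight_scale)
+    // equals sum_k (q_in - zp) * input_scale * (q_w * weight_scale).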
+ ivec4 accum_adjusted = + input_zp_vec * weight_sums.data[0] + accum.data[m][0]; + out_tile.data[m][0] = + fma(VEC4_T(accum_adjusted), + input_q_scale * weight_scales.data[0], + out_tile.data[m][0]); + } + +#else + // TODO(ssjia): Implement the general case + not implemented + +#endif +} + +void accumulate_out_tile_with_int_accum( + inout FPOutTile out_tile, + const Int32Accum accum, + const float input_q_scale, + const int input_q_zp, + const IntPerOutChannelParams weight_sums, + const FPPerOutChannelParams weight_scales, + const FPPerOutChannelParams bias) { + ivec4 input_zp_vec = ivec4(-input_q_zp); +#if TILE_N4 == 1 + [[unroll]] for (int m = 0; m < TILE_M; ++m) { + // Apply scale and zero points to the int accumulator + ivec4 accum_adjusted = + input_zp_vec * weight_sums.data[0] + accum.data[m][0]; + out_tile.data[m][0] = + fma(VEC4_T(accum_adjusted), + input_q_scale * weight_scales.data[0], + out_tile.data[m][0]); + out_tile.data[m][0] += bias.data[0]; + } + +#else + // TODO(ssjia): Implement the general case + not implemented + +#endif +} + +#ifdef DEBUG_MODE + +void printInt32Accum(const Int32Accum tile) { + debugPrintfEXT("int accum: \\n"); + [[unroll]] for (int m = 0; m < TILE_M; ++m) { + [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) { + debugPrintfEXT( + " %d, %d, %d, %d,", + tile.data[m][n4].x, + tile.data[m][n4].y, + tile.data[m][n4].z, + tile.data[m][n4].w); + } + debugPrintfEXT("\\n"); + } +} + +#endif + +#endif // LINEAR_FP_OUTPUT_TILE_INT8_INT8_COMPUTE_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_store.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_store.glslh new file mode 100644 index 00000000000..a4019204cc3 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_store.glslh @@ -0,0 +1,114 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +/* + * Defines functions store a FpOutTile to output buffer/texture. + * + * Requires: + * - t_output to be declared in the shader layout + * + * Settings: + * - OUTPUT_BUFFER to indicate t_output is a vec4 buffer, otherwise texture + * storage is assumed. 
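+ *
+ * Output element (m, n) is written to texel (x = n / 4, y = m); for buffer
+ * storage this corresponds to t_output[m * N4 + n4].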
+ */ + +#ifndef LINEAR_FP_OUTPUT_TILE_STORE_GLSLH +#define LINEAR_FP_OUTPUT_TILE_STORE_GLSLH + +#extension GL_EXT_control_flow_attributes : require + +#include "linear_fp_output_tile.glslh" + +#ifdef OUTPUT_BUFFER + +void write_output_x4( + const VEC4_T out_texel, + const int n4, + const int m, + const int N4) { + t_output[m * N4 + n4] = out_texel; +} + +#else + +void write_output_x4( + const VEC4_T out_texel, + const int n4, + const int m, + const int N4) { + imageStore(t_output, ivec3(n4, m, 0), out_texel); +} + +#endif // OUTPUT_BUFFER + +void write_output_tile( + const FPOutTile out_tile, + const int n4_start, + const int m_start, + const int N4) { +#if TILE_K4 == 1 + [[unroll]] for (int m = 0; m < TILE_M; ++m) { + write_output_x4(out_tile.data[m][0], n4_start, m_start + m, N4); + } + +#else + [[unroll]] for (int m = 0; m < TILE_M; ++m) { + [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) { + write_output_x4(out_tile.data[m][n4], n4_start + n4, m_start + m, N4); + } + } +#endif +} + +// To be used if M - m >= TILE_M && N4 - n4 >= TILE_N4 +void write_output_tile_no_checks( + const FPOutTile out_tile, + const int n4_start, + const int m_start, + const int N4, + const int M) { +#if TILE_N4 == 1 + [[unroll]] for (int m = 0; m < TILE_M; ++m) { + write_output_x4(out_tile.data[m][0], n4_start, m_start + m, N4); + } + +#else + [[unroll]] for (int m = 0; m < TILE_M; ++m) { + [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) { + write_output_x4(out_tile.data[m][n4], n4_start + n4, m_start + m, N4); + } + } +#endif +} + +// To be used if close to tensor boundaries +void write_output_tile_with_checks( + const FPOutTile out_tile, + const int n4_start, + const int m_start, + const int N4, + const int M) { +#if TILE_N4 == 1 + [[unroll]] for (int m = 0; m < TILE_M; ++m) { + if (m_start + m < M) { + write_output_x4(out_tile.data[m][0], n4_start, m_start + m, N4); + } + } + +#else + [[unroll]] for (int m = 0; m < TILE_M; ++m) { + [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) { + if (m_start + m < M && n4_start + n4 < N4) { + write_output_x4(out_tile.data[m][n4], n4_start + n4, m_start + m, N4); + } + } + } +#endif +} + +#endif // LINEAR_FP_OUTPUT_TILE_STORE_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_fp_per_out_channel_params.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_per_out_channel_params.glslh new file mode 100644 index 00000000000..72b22988414 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_per_out_channel_params.glslh @@ -0,0 +1,43 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +/* + * Defines common functions and structs to be used across matrix multiplication + * operators. + */ + +#ifndef LINEAR_FP_PER_OUT_CHANNEL_PARAMS_GLSLH +#define LINEAR_FP_PER_OUT_CHANNEL_PARAMS_GLSLH + +#include "common.glslh" + +#extension GL_EXT_control_flow_attributes : require + +// Represents floating point parameter tensors where each element is associated +// with an output channel, such as weight scales, biases, etc. 
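+// data[n4] holds the parameter values for output channels 4 * n4 .. 4 * n4 + 3
+// covered by the current output tile.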
+struct FPPerOutChannelParams { + VEC4_T data[TILE_N4]; +}; + +#ifdef DEBUG_MODE + +void printFPPerOutChannelParams(const FPPerOutChannelParams params) { + debugPrintfEXT("per_out_channel_params: \\n"); + [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) { + debugPrintfEXT( + " %f, %f, %f, %f, \\n", + params.data[n4].x, + params.data[n4].y, + params.data[n4].z, + params.data[n4].w); + } +} + +#endif // DEBUG_MODE + +#endif // LINEAR_FP_PER_OUT_CHANNEL_PARAMS_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_fp_weight_scales_load.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_weight_scales_load.glslh new file mode 100644 index 00000000000..0cba49e87c7 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_weight_scales_load.glslh @@ -0,0 +1,32 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#ifndef LINEAR_FP_WEIGHT_SCALES_LOAD_GLSLH +#define LINEAR_FP_WEIGHT_SCALES_LOAD_GLSLH + +#include "linear_fp_per_out_channel_params.glslh" + +VEC4_T load_weight_scale_x4(const int n4) { + return t_weight_scales[n4]; +} + +void load_weight_scales_tile( + out FPPerOutChannelParams scales, + const int n4_start) { +#if TILE_N4 == 1 + scales.data[0] = load_weight_scale_x4(n4_start); + +#else + [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) { + scales.data[n4] = load_weight_scale_x4(n4_start + n4); + } + +#endif +} + +#endif // LINEAR_FP_WEIGHT_SCALES_LOAD_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_fp_weight_tile.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_weight_tile.glslh new file mode 100644 index 00000000000..f44bbbc1565 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_weight_tile.glslh @@ -0,0 +1,103 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +/* + * Defines the FPWeightTile struct, which is used to represent a fp tile of a + * weight matrix in matrix multiplication. 
+ * + * Settings: + * - TILE_K: number of rows in the output tile + * - TILE_N4: number of (groups of 4) columns in the output tile + */ + +#ifndef LINEAR_FP_WEIGHT_TILE_GLSLH +#define LINEAR_FP_WEIGHT_TILE_GLSLH + +#extension GL_EXT_control_flow_attributes : require + +#include "common.glslh" + +struct FPWeightTile { + VEC4_T data[TILE_K][TILE_N4]; +}; + +#ifdef LINEAR_INT8_WEIGHT_TILE_GLSLH + +int sign_extend(const int val) { + if ((val & 0x80) != 0) { + return val | (~0xFF); + } + return val; +} + +T extract_8bit_value(const Int8WeightTile w_tile, const int k, const int n) { +#if TILE_K4 == 1 && TILE_N4 == 1 + const int k4i = k; + const int n4i = n; + ivec4 block = w_tile.data[0][0]; + +#else + const int k4 = div_4(k); + const int k4i = mod_4(k); + + const int n4 = div_4(n); + const int n4i = mod_4(n); + + ivec4 block = w_tile.data[k4][n4]; +#endif + + int col = block[n4i]; + int val = (col >> (k4i * 8)) & 0xFF; + + return T(sign_extend(val)); +} + +void unpack(out FPWeightTile fp_w_tile, const Int8WeightTile w_tile) { +#if TILE_K > 1 && TILE_N4 == 1 + [[unroll]] for (int k = 0; k < TILE_K; ++k) { + fp_w_tile.data[k][0][0] = extract_8bit_value(w_tile, k, 0); + fp_w_tile.data[k][0][1] = extract_8bit_value(w_tile, k, 1); + fp_w_tile.data[k][0][2] = extract_8bit_value(w_tile, k, 2); + fp_w_tile.data[k][0][3] = extract_8bit_value(w_tile, k, 3); + } + +#else + [[unroll]] for (int k = 0; k < TILE_M; ++k) { + [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) { + const int n = mul_4(n4); + fp_w_tile.data[k][n4][0] = extract_8bit_value(w_tile, k, n); + fp_w_tile.data[k][n4][1] = extract_8bit_value(w_tile, k, n + 1); + fp_w_tile.data[k][n4][2] = extract_8bit_value(w_tile, k, n + 2); + fp_w_tile.data[k][n4][3] = extract_8bit_value(w_tile, k, n + 3); + } + } +#endif +} + +#endif // LINEAR_INT8_WEIGHT_TILE_GLSLH + +#ifdef DEBUG_MODE + +void printFPWeightTile(const FPWeightTile tile) { + debugPrintfEXT("weight_tile: \\n"); + [[unroll]] for (int k = 0; k < TILE_K; ++k) { + [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) { + debugPrintfEXT( + " %f, %f, %f, %f, ", + tile.data[k][n4].x, + tile.data[k][n4].y, + tile.data[k][n4].z, + tile.data[k][n4].w); + } + debugPrintfEXT("\\n"); + } +} + +#endif // DEBUG_MODE + +#endif // LINEAR_FP_WEIGHT_TILE_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_int8_input_block.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_int8_input_block.glslh new file mode 100644 index 00000000000..9535de21f7b --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_int8_input_block.glslh @@ -0,0 +1,84 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +/* + * This file defines utilties to perform int8 quantization and block packing of + * matrix multiplation inputs. It also defines utilities to store packed block + * data to an output buffer or texture. + * + * Requires: + * - t_packed_int8_input to be defined in shader layout (output buffer/texture) + * + * Settings: + * - OUTPUT_BUFFER to indicate if output resource is a buffer. Otherwise texture + * is assumed. 
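+ *
+ * Each block covers a 4 (M) x 4 (K) region of the input; data[row] packs the 4
+ * quantized values of one row into a single int, lowest byte first.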
+ */ + +#ifndef LINEAR_INT8_INPUT_BLOCK_GLSLH +#define LINEAR_INT8_INPUT_BLOCK_GLSLH + +#define TILE_M 4 +#define TILE_K4 1 + +#include "linear_fp_input_tile.glslh" + +struct Int8InputBlock { + ivec4 data; +}; + +ivec4 quantize( + const VEC4_T val, + const float q_inv_scale, + const int q_zero_point) { + vec4 quantized = round(vec4(val) * q_inv_scale) + q_zero_point; + // hard-code 8 bit quantization range + return clamp(ivec4(quantized), -128, 127); +} + +int pack_into_int32(const ivec4 quant_vals) { + int packed = ((quant_vals[0] & 0xFF) << 0) | ((quant_vals[1] & 0xFF) << 8) | + ((quant_vals[2] & 0xFF) << 16) | ((quant_vals[3] & 0xFF) << 24); + + return packed; +} + +void quantize_and_pack( + out Int8InputBlock packed, + const FPInputTile in_block, + const float q_inv_scale, + const int q_zero_point) { + for (int row = 0; row < 4; ++row) { + ivec4 quantized_inputs = + quantize(in_block.data[row][0], q_inv_scale, q_zero_point); + packed.data[row] = pack_into_int32(quantized_inputs); + } +} + +#ifdef OUTPUT_BUFFER + +void write_block( + const Int8InputBlock block, + const int block_x, + const int block_y, + const int nblocks_x) { + t_packed_int8_input[block_y * nblocks_x + block_x] = block.data; +} + +#else // OUTPUT_TEXTURE + +void write_block( + const Int8InputBlock block, + const int block_x, + const int block_y, + const int nblocks_x) { + imageStore(t_packed_int8_input, ivec3(block_x, block_y, 0), block.data); +} + +#endif // OUTPUT_BUFFER + +#endif // LINEAR_INT8_INPUT_BLOCK_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_int8_input_tile.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_int8_input_tile.glslh new file mode 100644 index 00000000000..89a7e1b3f89 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_int8_input_tile.glslh @@ -0,0 +1,62 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +/* + * Defines the Int8InputTile struct, which is used to represent a tile of the + * quantized int8 input matrix of a quantized matrix multiplication operation. 
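+ *
+ * data[m4][k4] is one packed 4x4 block; component m4i packs the 4 consecutive
+ * K values of row m = 4 * m4 + m4i into a single int.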
+ * + * Settings: + * - TILE_M4: number of (groups of 4) rows in the tile + * - TILE_K4: number of (groups of 4) columns in the tile + */ + +#ifndef LINEAR_INT8_INPUT_TILE_GLSLH +#define LINEAR_INT8_INPUT_TILE_GLSLH + +#extension GL_EXT_control_flow_attributes : require + +struct Int8InputTile { + ivec4 data[TILE_M4][TILE_K4]; +}; + +#ifdef DEBUG_MODE + +#include "linear_common.glslh" + +void printInt8InputTile(const Int8InputTile tile) { + debugPrintfEXT( + "Int8InputTile [TILE_M4=%d][TILE_K4=%d]:\\n", TILE_M4, TILE_K4); + + [[unroll]] for (int m4 = 0; m4 < TILE_M4; ++m4) { + [[unroll]] for (int k4 = 0; k4 < TILE_K4; ++k4) { + debugPrintfEXT(" tile[%d][%d] (ivec4): ", m4, k4); + + // Each ivec4 contains 4 packed integers, each integer contains 4 8-bit + // values + [[unroll]] for (int vec_idx = 0; vec_idx < 4; ++vec_idx) { + int packed_int = tile.data[m4][k4][vec_idx]; + debugPrintfEXT("packed_int[%d]=%d -> [", vec_idx, packed_int); + + // Extract 4 8-bit values from this packed integer + [[unroll]] for (int byte_idx = 0; byte_idx < 4; ++byte_idx) { + int val = extract_8bit_from_packed_int_le(packed_int, byte_idx); + if (byte_idx < 3) { + debugPrintfEXT("%d, ", val); + } else { + debugPrintfEXT("%d] ", val); + } + } + } + debugPrintfEXT("\\n"); + } + } +} + +#endif // DEBUG_MODE + +#endif // LINEAR_INT8_INPUT_TILE_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_int8_input_tile_load.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_int8_input_tile_load.glslh new file mode 100644 index 00000000000..c79badab6c6 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_int8_input_tile_load.glslh @@ -0,0 +1,75 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +/* + * Defines functions to load a Int8InputTile from input buffer/texture. + * + * Requires: + * - t_packed_int8_input to be declared in the shader layout + * + * Settings: + * - PACKED_INT8_INPUT_BUFFER to indicate resource is a buffer, otherwise + * texture storage is assumed. 
+ */ + +#ifndef LINEAR_INT8_INPUT_TILE_LOAD_GLSLH +#define LINEAR_INT8_INPUT_TILE_LOAD_GLSLH + +#extension GL_EXT_control_flow_attributes : require + +#include "linear_int8_input_tile.glslh" + +#ifdef PACKED_INT8_INPUT_BUFFER + +ivec4 load_int8_input_block( + const int block_x, + const int block_y, + const int nblocks_x) { + return t_packed_int8_input[(block_y * nblocks_x) + block_x]; +} + +#else + +ivec4 load_int8_input_block( + const int block_x, + const int block_y, + const int nblocks_x) { + return texelFetch(t_packed_int8_input, ivec3(block_x, block_y, 0), 0); +} + +#endif // PACKED_INT8_INPUT_BUFFER + +void load_int8_input_tile( + out Int8InputTile in_tile, + const int block_x, + const int block_y, + const int nblocks_x) { +#if TILE_M4 == 1 && TILE_K4 == 1 + in_tile.data[0][0] = load_int8_input_block(block_x, block_y, nblocks_x); + +#elif TILE_M4 == 1 && TILE_K4 > 1 + [[unroll]] for (int x = 0; x < TILE_K4; ++x) { + in_tile.data[0][x] = load_int8_input_block(block_x + x, block_y, nblocks_x); + } + +#elif TILE_M4 > 1 && TILE_K4 == 1 + [[unroll]] for (int y = 0; y < TILE_M4; ++y) { + in_tile.data[y][0] = load_int8_input_block(block_x, block_y + y, nblocks_x); + } + +#else + [[unroll]] for (int y = 0; y < TILE_M4; ++y) { + [[unroll]] for (int x = 0; x < TILE_K4; ++x) { + in_tile.data[y][x] = + load_int8_input_block(block_x + x, block_y + y, nblocks_x); + } + } +#endif +} + +#endif // LINEAR_INT8_INPUT_TILE_LOAD_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_int8_weight_block.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_int8_weight_block.glslh new file mode 100644 index 00000000000..6e98caea49e --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_int8_weight_block.glslh @@ -0,0 +1,99 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#ifndef LINEAR_INT8_WEIGHT_BLOCK_GLSLH +#define LINEAR_INT8_WEIGHT_BLOCK_GLSLH + +/* + * This file defines utilties to perform weight prepacking of quantized int8 + * matrix multiplation weights. It also defines utilities to load source + * weight data from inputbuffer, and write out a packed weight block to output + * texture/buffer. + * + * Requires: + * - t_packed_int8_weight to be defined in shader layout (output texture/buffer) + * - t_int8_weight to be defined in shader layout (input buffer) + * + * Settings: + * - USING_BUFFER to indicate if output resource is a buffer. Otherwise texture + * is assumed. + */ + +#extension GL_EXT_control_flow_attributes : require + +#include "linear_common.glslh" + +// Represents data for a 4x4 block of the weight matrix read from the input +// buffer. 
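+// data[n] is the packed int loaded for output channel (n_start + n); each int
+// already packs 4 consecutive K values of the source weight buffer.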
+struct Int8WeightBlock { + ivec4 data; +}; + +void load_block_data_no_checks( + out Int8WeightBlock block, + const int k4, + const int n_start, + const int ntexels_K, + const int N) { + [[unroll]] for (int n = 0; n < 4; ++n) { + block.data[n] = t_int8_weight[(n_start + n) * ntexels_K + k4]; + } +} + +void load_block_data_with_checks( + out Int8WeightBlock block, + const int k4, + const int n_start, + const int ntexels_K, + const int N) { + [[unroll]] for (int n = 0; n < 4; ++n) { + if (n_start + n < N) { + block.data[n] = t_int8_weight[(n_start + n) * ntexels_K + k4]; + } else { + block.data[n] = 0; + } + } +} + +#ifdef USING_BUFFER + +void write_weight_block( + const Int8WeightBlock block, + const int n4, + const int k4, + const int ntexels_N) { + t_packed_int8_weight[k4 * ntexels_N + n4] = block.data; +} + +#else // USING_TEXTURE + +void write_weight_block( + const Int8WeightBlock block, + const int n4, + const int k4, + const int ntexels_N) { + imageStore(t_packed_int8_weight, ivec2(n4, k4), block.data); +} + +#endif // USING_BUFFER + +#ifdef DEBUG_MODE + +void printInt8WeightBlock(const Int8WeightBlockPacked block) { + debugPrintfEXT("int8_weight_block_packed: \\n"); + debugPrintfEXT( + "%i %i %i %i \\n", + block.data[0], + block.data[1], + block.data[2], + block.data[3]); +} + +#endif // DEBUG_MODE + +#endif // LINEAR_INT8_WEIGHT_BLOCK_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_int8_weight_tile.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_int8_weight_tile.glslh new file mode 100644 index 00000000000..f312db543db --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_int8_weight_tile.glslh @@ -0,0 +1,60 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#ifndef LINEAR_INT8_WEIGHT_TILE_GLSLH +#define LINEAR_INT8_WEIGHT_TILE_GLSLH + +/* + * Defines the Int8WeightTile struct, which is used to represent a tile of the + * quantized int8 weight matrix of a quantized matrix multiplication operation. 
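+ *
+ * data[k4][n4][n4i] is the packed int holding the 4 quantized values for
+ * output channel n = 4 * n4 + n4i at k = 4 * k4 .. 4 * k4 + 3.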
+ * + * Settings: + * - TILE_K4: number of (groups of 4) rows in the weight tile + * - TILE_N4: number of (groups of 4) columns in the weight tile + */ + +#extension GL_EXT_control_flow_attributes : require + +struct Int8WeightTile { + ivec4 data[TILE_K4][TILE_N4]; +}; + +#ifdef DEBUG_MODE + +void printInt8WeightTile(const Int8WeightTile tile) { + debugPrintfEXT( + "Int8WeightTile [TILE_K4=%d][TILE_N4=%d]:\\n", TILE_K4, TILE_N4); + + [[unroll]] for (int m4 = 0; m4 < TILE_M4; ++m4) { + [[unroll]] for (int k4 = 0; k4 < TILE_K4; ++k4) { + debugPrintfEXT(" tile[%d][%d] (ivec4): ", m4, k4); + + // Each ivec4 contains 4 packed integers, each integer contains 4 8-bit + // values + [[unroll]] for (int vec_idx = 0; vec_idx < 4; ++vec_idx) { + int packed_int = tile.data[m4][k4][vec_idx]; + debugPrintfEXT("packed_int[%d]=%d -> [", vec_idx, packed_int); + + // Extract 4 8-bit values from this packed integer + [[unroll]] for (int byte_idx = 0; byte_idx < 4; ++byte_idx) { + int val = extract_8bit_from_packed_int_le(packed_int, byte_idx); + if (byte_idx < 3) { + debugPrintfEXT("%d, ", val); + } else { + debugPrintfEXT("%d] ", val); + } + } + } + debugPrintfEXT("\\n"); + } + } +} + +#endif // DEBUG_MODE + +#endif // LINEAR_INT8_WEIGHT_TILE_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_int8_weight_tile_load.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_int8_weight_tile_load.glslh new file mode 100644 index 00000000000..fe16d3469b3 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_int8_weight_tile_load.glslh @@ -0,0 +1,78 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#ifndef LINEAR_INT8_WEIGHT_TILE_LOAD_GLSLH +#define LINEAR_INT8_WEIGHT_TILE_LOAD_GLSLH + +/* + * Defines functions to load a Int8WeightTile from input buffer/texture. + * + * Requires: + * - t_packed_int8_weight to be declared in the shader layout (input + * buffer/texture) + * + * Settings: + * - WEIGHT_BUFFER to indicate t_packed_int8_weight is a buffer, otherwise + * texture storage is assumed. 
+ */ + +#extension GL_EXT_control_flow_attributes : require + +#include "linear_int8_weight_tile.glslh" + +#ifdef WEIGHT_BUFFER + +ivec4 load_int8_weight_block( + const int block_x, + const int block_y, + const int nblocks_x) { + return t_packed_int8_weight[(block_y * nblocks_x) + block_x]; +} + +#else // WEIGHT_TEXTURE + +ivec4 load_int8_weight_block( + const int block_x, + const int block_y, + const int nblocks_x) { + return texelFetch(t_packed_int8_weight, ivec2(block_x, block_y), 0); +} + +#endif // WEIGHT_BUFFER + +void load_int8_weight_tile( + out Int8WeightTile weight_tile, + const int block_x, + const int block_y, + const int nblocks_x) { +#if TILE_K4 == 1 && TILE_N4 == 1 + weight_tile.data[0][0] = load_int8_weight_block(block_x, block_y, nblocks_x); + +#elif TILE_K4 == 1 && TILE_N4 > 1 + [[unroll]] for (int x = 0; x < TILE_N4; ++x) { + weight_tile.data[0][x] = + load_int8_weight_block(block_x + x, block_y, nblocks_x); + } + +#elif TILE_K4 > 1 && TILE_N4 == 1 + [[unroll]] for (int y = 0; y < TILE_M4; ++y) { + weight_tile.data[y][0] = + load_int8_weight_block(block_x, block_y + y, nblocks_x); + } + +#else + [[unroll]] for (int y = 0; y < TILE_K4; ++y) { + [[unroll]] for (int x = 0; x < TILE_N4; ++x) { + weight_tile.data[y][x] = + load_int8_weight_block(block_x + x, block_y + y, nblocks_x); + } + } +#endif +} + +#endif // LINEAR_INT8_WEIGHT_TILE_LOAD_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_int_per_out_channel_params.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_int_per_out_channel_params.glslh new file mode 100644 index 00000000000..ca29fd52780 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_int_per_out_channel_params.glslh @@ -0,0 +1,44 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +/* + * Defines common functions and structs to be used across matrix multiplication + * operators. + */ + +#ifndef LINEAR_INT_PER_OUT_CHANNEL_PARAMS_GLSLH +#define LINEAR_INT_PER_OUT_CHANNEL_PARAMS_GLSLH + +#include "common.glslh" + +#extension GL_EXT_control_flow_attributes : require + +// Represents floating point parameter tensors where each element is associated +// with an output channel, such as weight scales, biases, etc. +struct IntPerOutChannelParams { + ivec4 data[TILE_N4]; +}; + +#ifdef DEBUG_MODE + +void printIntPerOutChannelParams(const IntPerOutChannelParams params) { + debugPrintfEXT("per_out_channel_params: \\n"); + [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) { + debugPrintfEXT( + " %d, %d, %d, %d, ", + params.data[n4].x, + params.data[n4].y, + params.data[n4].z, + params.data[n4].w); + } + debugPrintfEXT("\\n"); +} + +#endif // DEBUG_MODE + +#endif // LINEAR_INT_PER_OUT_CHANNEL_PARAMS_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_int_weight_sums_load.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_int_weight_sums_load.glslh new file mode 100644 index 00000000000..1a17f99ea4e --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_int_weight_sums_load.glslh @@ -0,0 +1,32 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#ifndef LINEAR_FP_WEIGHT_SUMS_LOAD_GLSLH +#define LINEAR_FP_WEIGHT_SUMS_LOAD_GLSLH + +#include "linear_int_per_out_channel_params.glslh" + +ivec4 load_weight_sum_x4(const int n4) { + return ivec4(t_weight_sums[n4]); +} + +void load_weight_sums_tile( + out IntPerOutChannelParams sums, + const int n4_start) { +#if TILE_N4 == 1 + sums.data[0] = load_weight_sum_x4(n4_start); + +#else + [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) { + sums.data[n4] = load_weight_sum_x4(n4_start + n4); + } + +#endif +} + +#endif // LINEAR_FP_WEIGHT_SUMS_LOAD_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_q8csw_tiled.glsl b/backends/vulkan/runtime/graph/ops/glsl/linear_q8csw_tiled.glsl new file mode 100644 index 00000000000..b6d932f0015 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_q8csw_tiled.glsl @@ -0,0 +1,113 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} +#define VEC4_T ${texel_load_type(DTYPE, IO_STORAGE)} +#define T ${texel_load_component_type(DTYPE, IO_STORAGE)} + +$if IO_STORAGE == "buffer": + #define OUTPUT_BUFFER + #define INPUT_BUFFER +$if WEIGHT_STORAGE == "buffer": + #define WEIGHT_BUFFER + +#define TILE_M4 ${TILE_M4} +#define TILE_K4 ${TILE_K4} +#define TILE_N4 ${TILE_N4} + +#define TILE_M ${TILE_M4 * 4} +#define TILE_K ${TILE_K4 * 4} +#define TILE_N ${TILE_N4 * 4} + +${define_required_extensions(DTYPE)} + +layout(std430) buffer; + +${layout_declare_tensor(B, "w", "t_output", DTYPE, IO_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_input", DTYPE, IO_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_packed_int8_weight", "int", WEIGHT_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_weight_scales", DTYPE, "buffer", is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_bias", DTYPE, "buffer", is_scalar_array=False)} + +${layout_declare_ubo(B, "ivec4", "output_sizes")} +${layout_declare_ubo(B, "ivec4", "input_sizes")} + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +${layout_declare_spec_const(C, "uint", "apply_bias", "0")} + +#include "linear_fp_input_tile_load.glslh" +#include "linear_int8_weight_tile_load.glslh" +#include "linear_fp_weight_tile.glslh" +#include "linear_fp_output_tile_fp_compute.glslh" +#include "linear_fp_output_tile_fp_int8_compute.glslh" +#include "linear_fp_output_tile_store.glslh" +#include "linear_fp_weight_scales_load.glslh" +#include "linear_fp_bias_load.glslh" + +void main() { + // Each thread writes out a 4 wide x 4 high tile of output values + const int out_tile_x = int(gl_GlobalInvocationID.x); + const int out_tile_y = int(gl_GlobalInvocationID.y); + + const int n = out_tile_x * TILE_N; + const int m = out_tile_y * TILE_M; + + const int n4 = div_4(n); + const int m4 = div_4(m); + + if (n >= output_sizes.x || m >= output_sizes.y) { + return; + } + + const int M = input_sizes.y; + const int K4 = div_up_4(input_sizes.x); + const int N4 = div_up_4(output_sizes.x); + + FPOutTile out_tile; + initialize(out_tile); + + FPInputTile in_tile; + Int8WeightTile int8_weight_tile; + + const bool dont_check_bounds = (M - m) >= TILE_M; + if (dont_check_bounds) { + for (int k4 = 0; k4 < K4; k4 += TILE_K4) { + load_input_tile_no_checks(in_tile, k4, m, K4, M); + load_int8_weight_tile(int8_weight_tile, n4, k4, N4); 
+ fp_accumulate_with_int8_weight(out_tile, in_tile, int8_weight_tile); + } + } else { + for (int k4 = 0; k4 < K4; k4 += TILE_K4) { + load_input_tile_with_checks(in_tile, k4, m, K4, M); + load_int8_weight_tile(int8_weight_tile, n4, k4, N4); + fp_accumulate_with_int8_weight(out_tile, in_tile, int8_weight_tile); + } + } + + FPPerOutChannelParams weight_scales_tile; + load_weight_scales_tile(weight_scales_tile, n4); + + if (apply_bias > 0) { + FPPerOutChannelParams bias_tile; + load_bias_tile(bias_tile, n4); + + apply_scales_and_biases(out_tile, weight_scales_tile, bias_tile); + } + else { + apply_scales(out_tile, weight_scales_tile); + } + + if (dont_check_bounds) { + write_output_tile_no_checks(out_tile, n4, m, N4, M); + } else { + write_output_tile_with_checks(out_tile, n4, m, N4, M); + } +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_q8csw_tiled.yaml b/backends/vulkan/runtime/graph/ops/glsl/linear_q8csw_tiled.yaml new file mode 100644 index 00000000000..242c4471b3d --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_q8csw_tiled.yaml @@ -0,0 +1,28 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +linear_q8csw_tiled: + parameter_names_with_default_values: + DTYPE: float + IO_STORAGE: texture3d + WEIGHT_STORAGE: texture2d + TILE_M4: 1 + TILE_N4: 1 + TILE_K4: 1 + generate_variant_forall: + DTYPE: + - VALUE: float + - VALUE: half + shader_variants: + - NAME: linear_q8csw_tiled_texture3d_texture2d + - NAME: linear_q8csw_tiled_texture3d_buffer + WEIGHT_STORAGE: buffer + - NAME: linear_q8csw_tiled_buffer_texture2d + IO_STORAGE: buffer + WEIGHT_STORAGE: texture2d + - NAME: linear_q8csw_tiled_buffer_buffer + IO_STORAGE: buffer + WEIGHT_STORAGE: buffer diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_q8ta_q8csw_tiled.glsl b/backends/vulkan/runtime/graph/ops/glsl/linear_q8ta_q8csw_tiled.glsl new file mode 100644 index 00000000000..9f7e00e3317 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_q8ta_q8csw_tiled.glsl @@ -0,0 +1,132 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#version 450 core + +#define PRECISION ${PRECISION} +#define VEC4_T ${texel_load_type(DTYPE, OUTPUT_STORAGE)} +#define T int + +$if OUTPUT_STORAGE == "buffer": + #define OUTPUT_BUFFER +$if PACKED_INT8_INPUT_STORAGE == "buffer": + #define PACKED_INT8_INPUT_BUFFER +$if WEIGHT_STORAGE == "buffer": + #define WEIGHT_BUFFER + +#define TILE_M4 ${TILE_M4} +#define TILE_K4 ${TILE_K4} +#define TILE_N4 ${TILE_N4} + +#define TILE_M ${TILE_M4 * 4} +#define TILE_K ${TILE_K4 * 4} +#define TILE_N ${TILE_N4 * 4} + +${define_required_extensions(DTYPE)} + +layout(std430) buffer; + +${layout_declare_tensor(B, "w", "t_output", DTYPE, OUTPUT_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_packed_int8_input", "int", PACKED_INT8_INPUT_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_packed_int8_weight", "int", WEIGHT_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_weight_sums", "int", "buffer", is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_weight_scales", DTYPE, "buffer", is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_bias", DTYPE, "buffer", is_scalar_array=False)} + +${layout_declare_spec_const(C, "int", "apply_bias", "0")} + +${layout_declare_ubo(B, "ivec4", "output_sizes")} +${layout_declare_ubo(B, "ivec4", "input_sizes")} + +layout(push_constant) uniform restrict Block { + float input_scale; + int input_zp; +}; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +#include "linear_int8_input_tile_load.glslh" +#include "linear_int8_weight_tile_load.glslh" +#include "linear_fp_output_tile_int8_int8_compute.glslh" +#include "linear_fp_output_tile_store.glslh" +#include "linear_fp_weight_scales_load.glslh" +#include "linear_int_weight_sums_load.glslh" +#include "linear_fp_bias_load.glslh" + +void main() { + // Each thread writes out a 4 wide x 4 high tile of output values + const int out_tile_x = int(gl_GlobalInvocationID.x); + const int out_tile_y = int(gl_GlobalInvocationID.y); + + const int n = out_tile_x * TILE_N; + const int m = out_tile_y * TILE_M; + + const int n4 = div_4(n); + const int m4 = div_4(m); + + if (n >= output_sizes.x || m >= output_sizes.y) { + return; + } + + const int M = output_sizes.y; + const int K4 = div_up_4(input_sizes.x); + const int N4 = div_up_4(output_sizes.x); + + Int32Accum out_accum; + initialize(out_accum); + + Int8InputTile int8_in_tile; + Int8WeightTile int8_weight_tile; + + // No checks are needed since packed input and weight are structured in units + // of 4x4 blocks. 
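+  // Accumulate in int32 using packed dot products first; the input scale/zero
+  // point, weight scales and optional bias are applied once after the K loop.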
+ for (int k4 = 0; k4 < K4; k4 += TILE_K4) { + load_int8_input_tile(int8_in_tile, k4, m4, K4); + load_int8_weight_tile(int8_weight_tile, n4, k4, N4); + + int_accumulate_with_int8_weight(out_accum, int8_in_tile, int8_weight_tile); + } + + FPPerOutChannelParams weight_scales_tile; + load_weight_scales_tile(weight_scales_tile, n4); + + IntPerOutChannelParams weight_sums_tile; + load_weight_sums_tile(weight_sums_tile, n4); + + FPOutTile out_tile; + initialize(out_tile); + + if (apply_bias > 0) { + FPPerOutChannelParams bias_tile; + load_bias_tile(bias_tile, n4); + + accumulate_out_tile_with_int_accum( + out_tile, + out_accum, + input_scale, + input_zp, + weight_sums_tile, + weight_scales_tile, + bias_tile); + } + else { + accumulate_out_tile_with_int_accum( + out_tile, + out_accum, + input_scale, + input_zp, + weight_sums_tile, + weight_scales_tile); + } + + if (M - m >= TILE_M) { + write_output_tile_no_checks(out_tile, n4, m, N4, M); + } else { + write_output_tile_with_checks(out_tile, n4, m, N4, M); + } +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_q8ta_q8csw_tiled.yaml b/backends/vulkan/runtime/graph/ops/glsl/linear_q8ta_q8csw_tiled.yaml new file mode 100644 index 00000000000..aa1de3077fc --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_q8ta_q8csw_tiled.yaml @@ -0,0 +1,24 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +linear_q8ta_q8csw_tiled: + parameter_names_with_default_values: + DTYPE: float + OUTPUT_STORAGE: texture3d + PACKED_INT8_INPUT_STORAGE: buffer + WEIGHT_STORAGE: texture2d + TILE_M4: 1 + TILE_N4: 1 + TILE_K4: 1 + generate_variant_forall: + DTYPE: + - VALUE: float + shader_variants: + - NAME: linear_q8ta_q8csw_tiled_texture3d_buffer_texture2d + - NAME: linear_q8ta_q8csw_tiled_buffer_buffer_texture2d + OUTPUT_STORAGE: buffer + PACKED_INT8_INPUT_STORAGE: buffer + WEIGHT_STORAGE: texture2d diff --git a/backends/vulkan/runtime/graph/ops/glsl/pack_q8_linear_weight.glsl b/backends/vulkan/runtime/graph/ops/glsl/pack_q8_linear_weight.glsl new file mode 100644 index 00000000000..f2c74b67283 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/pack_q8_linear_weight.glsl @@ -0,0 +1,60 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} + +${define_active_storage_type(STORAGE)} + +layout(std430) buffer; + +${layout_declare_tensor(B, "w", "t_packed_int8_weight", "int", STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_int8_weight", "int", "buffer")} + +layout(push_constant) uniform restrict Block { + ivec4 qmat2_sizes; + ivec2 orig_sizes; +}; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +#include "common.glslh" +#include "linear_int8_weight_block.glslh" + +void main() { + // The size of the source weight tensor is [W=K, H=N]. Each shader invocation + // processes a 4x4 block. The thread position corresponds to the block index. 
+ int n4 = int(gl_GlobalInvocationID.x); + int k4 = int(gl_GlobalInvocationID.y); + + const int K = orig_sizes.x; + const int N = orig_sizes.y; + + // Determine the total number of blocks and check bounds + const int N4 = div_up_4(N); + const int K4 = div_up_4(K); + if (n4 >= N4 || k4 >= K4) { + return; + } + + // Each block is represented as an ivec4. Each int corresponds to a row i.e. + // N dim of the weight tensor and contains data for 4 columns i.e. K dim. + Int8WeightBlock block; + const int n = mul_4(n4); + if (N - n >= 4) { + load_block_data_no_checks(block, k4, n, K4, N); + } else { + load_block_data_with_checks(block, k4, n, K4, N); + } + + // The weight blocks are stored in a transposed manner, such that blocks + // are indexed as packed_weight[k4][n4]. This is to optimize memory + // coalescing when computing tiled GEMM. + write_weight_block(block, n4, k4, N4); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/pack_q8_linear_weight.yaml b/backends/vulkan/runtime/graph/ops/glsl/pack_q8_linear_weight.yaml new file mode 100644 index 00000000000..13e6d43b2c5 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/pack_q8_linear_weight.yaml @@ -0,0 +1,14 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +pack_q8_linear_weight: + parameter_names_with_default_values: + STORAGE: buffer + shader_variants: + - NAME: pack_q8_linear_weight_buffer + STORAGE: buffer + - NAME: pack_q8_linear_weight_texture2d + STORAGE: texture2d diff --git a/backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_linear_input.glsl b/backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_linear_input.glsl new file mode 100644 index 00000000000..6ba9343f10d --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_linear_input.glsl @@ -0,0 +1,79 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} +#define VEC4_T ${texel_load_type(DTYPE, INPUT_STORAGE)} +#define T ${texel_load_component_type(DTYPE, INPUT_STORAGE)} + +$if OUTPUT_STORAGE == "buffer": + #define OUTPUT_BUFFER +$if INPUT_STORAGE == "buffer": + #define INPUT_BUFFER + +${define_required_extensions(DTYPE)} + +layout(std430) buffer; + +#include "common.glslh" + +${layout_declare_tensor(B, "w", "t_packed_int8_input", "int", OUTPUT_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_input", DTYPE, INPUT_STORAGE, is_scalar_array=False)} + +$if GRANULARITY == "per_channel": + ${layout_declare_tensor(B, "r", "t_scale", DTYPE, "buffer")} + +${layout_declare_ubo(B, "ivec4", "input_sizes")} + +layout(push_constant) uniform restrict Block { + float inv_scale; + int zp; +}; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +#include "linear_int8_input_block.glslh" +#include "linear_fp_input_tile_load.glslh" + +void main() { + // Each input block contains 4x4 int8 quantized values, which are packed into + // an ivec4. k4 and m4 represent the "block index" of the current block being + // processed.
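+  // The per-tensor quantization applied below is, conceptually,
+  //   q = clamp(round(x * inv_scale) + zp, -128, 127)
+  // with inv_scale = 1 / input_scale supplied as a push constant; each group
+  // of 4 int8 values along K is then packed into one int of the ivec4.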
+ int k4 = int(gl_GlobalInvocationID.x); + int m4 = int(gl_GlobalInvocationID.y); + + const int K = input_sizes.x; + const int M = input_sizes.y; + + // K4 and M4 represent the number of blocks in each dimension. + const int K4 = div_up_4(K); + const int M4 = div_up_4(M); + + if (k4 >= K4 || m4 >= M4) { + return; + } + + // row of the input tensor to start loading from. Note the input tensor is + // interpreted as a t + const int m = mul_4(m4); + + const bool dont_check_bounds = (M - m) >= 4; + + FPInputTile in_tile; + if (dont_check_bounds) { + load_input_tile_no_checks(in_tile, k4, m, K4, M); + } else { + load_input_tile_with_checks(in_tile, k4, m, K4, M); + } + + Int8InputBlock packed_block; + quantize_and_pack(packed_block, in_tile, inv_scale, zp); + + write_block(packed_block, k4, m4, K4); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_linear_input.yaml b/backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_linear_input.yaml new file mode 100644 index 00000000000..37721db1ba8 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_linear_input.yaml @@ -0,0 +1,24 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +quantize_and_pack_linear_input: + parameter_names_with_default_values: + DTYPE: float + OUTPUT_STORAGE: texture3d + INPUT_STORAGE: texture3d + STORAGE: texture3d + GRANULARITY: per_tensor + generate_variant_forall: + DTYPE: + - VALUE: half + - VALUE: float + shader_variants: + - NAME: quantize_and_pack_linear_input_per_tensor_texture3d_texture3d + - NAME: quantize_and_pack_linear_input_per_tensor_buffer_texture3d + OUTPUT_STORAGE: buffer + - NAME: quantize_and_pack_linear_input_per_tensor_buffer_buffer + OUTPUT_STORAGE: buffer + INPUT_STORAGE: buffer diff --git a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp new file mode 100644 index 00000000000..d6aeb5e3dce --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp @@ -0,0 +1,554 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include + +#include +#include +#include +#include + +namespace vkcompute { + +// +// Shader dispatch utilities +// + +utils::uvec3 quantized_linear_global_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const std::vector& args, + const std::vector& resize_args) { + const ValueRef out = args.at(0).refs.at(0); + + std::vector out_sizes = graph->sizes_of(out); + // height + const uint32_t M = utils::val_at(-2, out_sizes); + // width + const uint32_t N = utils::val_at(-1, out_sizes); + + // 1 output tile is 4x4 elements + const uint32_t M4 = utils::div_up(M, 4u); + const uint32_t N4 = utils::div_up(N, 4u); + + return {N4, M4, 1}; +} + +utils::uvec3 quantized_linear_local_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const utils::uvec3& global_workgroup_size, + const std::vector& args, + const std::vector& resize_args) { + return pick_hw_square_wg_size( + graph, shader, global_workgroup_size, args, resize_args); +} + +std::tuple get_quantized_input_num_blocks( + ComputeGraph& graph, + const ValueRef input) { + std::vector input_sizes = graph.sizes_of(input); + const int64_t ndim = graph.dim_of(input); + + const int64_t M = input_sizes.at(ndim - 2); + const int64_t K = input_sizes.at(ndim - 1); + + const int64_t num_blocks_M = utils::div_up(M, int64_t(4)); + const int64_t num_blocks_K = utils::div_up(K, int64_t(4)); + + return std::make_tuple(num_blocks_M, num_blocks_K); +} + +utils::uvec3 quant_pack_input_global_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const std::vector& args, + const std::vector& resize_args) { + const ValueRef input = args.at(1).refs.at(0); + int64_t num_blocks_M, num_blocks_K; + std::tie(num_blocks_M, num_blocks_K) = + get_quantized_input_num_blocks(*graph, input); + + return { + utils::safe_downcast(num_blocks_K), + utils::safe_downcast(num_blocks_M), + 1u}; +} + +// +// Prepacking nodes +// + +ValueRef prepack_quantized_linear_weight( + ComputeGraph& graph, + const QuantizationConfig& weight_quant_config, + const ValueRef qmat2_data) { + VK_CHECK_COND(weight_quant_config.nbits == 8); + + std::vector qmat2_orig_sizes = graph.sizes_of(qmat2_data); + const int64_t ndim = graph.dim_of(qmat2_data); + + // Input size is [N, K]. K will be guaranteed to be a multiple of 4. + const int64_t K = qmat2_orig_sizes.at(ndim - 1); + const int64_t N = qmat2_orig_sizes.at(ndim - 2); + + // Sanity check that assumption is correct + VK_CHECK_COND(K % 4 == 0); + + // The packing format packs the weight tensor into units of 4 wide x 4 high + // blocks. To figure out the size of the output tensor, determine the number + // of blocks along each dimension. + const int64_t num_blocks_K = utils::div_up(K, int64_t(4)); + const int64_t num_blocks_N = utils::div_up(N, int64_t(4)); + + // The blocks are arranged in a transposed manner, such that the transposed + // weight block is indexed like packed_weights[k4][n4] - this is to allow for + // optimal memory coalescing when computing GEMM. + const int64_t output_height = num_blocks_K; + // The base dtype of the packed tensor is int32 (each int32 contains 4x 8bit + // values) and each block is represented as a ivec4. Therefore the width dim + // of the packed tensor is multiplied by 4. 
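+ // For example, an [N = 32, K = 64] int8 weight is packed into a
+ // [height = 16, width = 32] int32 tensor: 16 = K / 4 block rows, and each of
+ // the 8 = N / 4 blocks in a row occupies 4 int32 values (one ivec4 block).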
+ const int64_t output_width = num_blocks_N * 4; + + // Store the original sizes of the tensor to pass to the shader + utils::ivec2 orig_sizes{ + utils::safe_downcast(K), utils::safe_downcast(N)}; + + std::vector qmat2_sizes{output_height, output_width}; + + utils::StorageType storage_type = utils::kTexture2D; + uint32_t max_extent = graph.context()->adapter_ptr()->max_texture2d_dim(); + if (output_width > max_extent * 4 || output_height > max_extent) { + storage_type = utils::kBuffer; + } + + ValueRef qmat2 = graph.add_tensor( + qmat2_sizes, vkcompute::vkapi::kInt, storage_type, utils::kWidthPacked); + + // Global workgroup size: each thread writes out two adjacent blocks + utils::uvec3 global_wg_size{ + utils::safe_downcast(num_blocks_N), + utils::safe_downcast(num_blocks_K), + 1u}; + + std::string kernel_name = "pack_q8_linear_weight"; + add_storage_type_suffix(kernel_name, storage_type); + + graph.prepack_nodes().emplace_back(new PrepackNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + global_wg_size, + graph.create_local_wg_size(global_wg_size), + // Inputs and Outputs + qmat2_data, + qmat2, + // UBOs + {}, + // Specialization Constants + {}, + // Push Constants + {graph.sizes_pc_of(qmat2), + PushConstantDataInfo(&orig_sizes, sizeof(utils::ivec2))})); + + return qmat2; +} + +// +// Dispatch nodes +// + +/* + * Shader dispatch for linear with quantized weight but fp activations. + */ +DynamicDispatchNode make_linear_qw_node( + ComputeGraph& graph, + const QuantizationConfig& weight_quant_config, + const ValueRef fp_input, + const ValueRef weight_data, + const ValueRef packed_weight, + const ValueRef packed_weight_scales, + const ValueRef packed_weight_zeros, + const ValueRef group_size, + const ValueRef bias_data, + const ValueRef packed_bias, + const ValueRef output) { + // Only certain quantization types supported at the moment + VK_CHECK_COND(weight_quant_config.granularity == kPerChannel); + VK_CHECK_COND(weight_quant_config.is_symmetric); + VK_CHECK_COND(weight_quant_config.nbits == 8); + + std::string kernel_name = "linear_q8csw_tiled"; + add_storage_type_suffix(kernel_name, graph.storage_type_of(output)); + add_storage_type_suffix(kernel_name, graph.storage_type_of(packed_weight)); + add_dtype_suffix(kernel_name, graph.dtype_of(output)); + vkapi::ShaderInfo shader = VK_KERNEL_FROM_STR(kernel_name); + + vkapi::ParamsBindList param_buffers = { + graph.sizes_ubo(output), graph.sizes_ubo(fp_input)}; + + uint32_t apply_bias = 1; + if (graph.val_is_none(bias_data)) { + apply_bias = 0; + } + + return DynamicDispatchNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + quantized_linear_global_wg_size, + quantized_linear_local_wg_size, + // Inputs and Outputs + {{output, vkapi::kWrite}, + {{fp_input, packed_weight, packed_weight_scales, packed_bias}, + vkapi::kRead}}, + // Shader params buffers + param_buffers, + // Push Constants + {}, + // Specialization Constants + {apply_bias}, + // Resize args + {}, + // Resizing Logic + nullptr); +} + +DynamicDispatchNode make_quantize_and_pack_linear_input_node( + ComputeGraph& graph, + const QuantizationConfig& input_quant_config, + const ValueRef fp_input, + const ValueRef packed_input_scale, + const ValueRef packed_input_zp, + const ValueRef input_scale_data, + const ValueRef input_zp_data, + const ValueRef packed_int_input, + const ValueRef group_size) { + // Only certain quantization types supported at the moment + VK_CHECK_COND(input_quant_config.granularity == kPerTensor); + + int64_t num_blocks_M, num_blocks_K; + std::tie(num_blocks_M, 
num_blocks_K) = + get_quantized_input_num_blocks(graph, fp_input); + + float inv_scale = 1.0f / graph.extract_scalar(input_scale_data); + int32_t zp = graph.extract_scalar(input_zp_data); + + std::string shader_name = "quantize_and_pack_linear_input_per_tensor"; + add_storage_type_suffix(shader_name, graph.storage_type_of(packed_int_input)); + add_storage_type_suffix(shader_name, graph.storage_type_of(fp_input)); + add_dtype_suffix(shader_name, graph.dtype_of(fp_input)); + + vkapi::ParamsBindList param_buffers = {graph.sizes_ubo(fp_input)}; + + std::vector push_constants = { + PushConstantDataInfo(&inv_scale, sizeof(inv_scale)), + PushConstantDataInfo(&zp, sizeof(zp)), + }; + + return DynamicDispatchNode( + graph, + VK_KERNEL_FROM_STR(shader_name), + quant_pack_input_global_wg_size, + default_pick_local_wg_size, + // Inputs and Outputs + {{packed_int_input, vkapi::kWrite}, {fp_input, vkapi::kRead}}, + // Shader params buffers + param_buffers, + // Push Constants + push_constants, + // Specialization Constants + {}, + // Resize args + {}); +} + +DynamicDispatchNode make_linear_qa_qw_node( + ComputeGraph& graph, + const QuantizationConfig& input_quant_config, + const QuantizationConfig& weight_quant_config, + const ValueRef fp_input, + const ValueRef packed_int_input, + const ValueRef packed_input_scale, + const ValueRef packed_input_zp, + const ValueRef input_scale_data, + const ValueRef input_zp_data, + const ValueRef weight_data, + const ValueRef packed_weight, + const ValueRef packed_weight_sums, + const ValueRef packed_weight_scales, + const ValueRef group_size, + const ValueRef bias_data, + const ValueRef packed_bias, + const ValueRef output) { + VK_CHECK_COND(input_quant_config.granularity == kPerTensor); + VK_CHECK_COND(input_quant_config.nbits == 8); + VK_CHECK_COND(weight_quant_config.granularity == kPerChannel); + VK_CHECK_COND(weight_quant_config.is_symmetric); + VK_CHECK_COND(weight_quant_config.nbits == 8); + + float scale = graph.extract_scalar(input_scale_data); + int32_t zp = graph.extract_scalar(input_zp_data); + + // Get shader for quantized linear + std::string kernel_name = "linear_q8ta_q8csw_tiled"; + add_storage_type_suffix(kernel_name, graph.storage_type_of(output)); + add_storage_type_suffix(kernel_name, graph.storage_type_of(packed_int_input)); + add_storage_type_suffix(kernel_name, graph.storage_type_of(packed_weight)); + add_dtype_suffix(kernel_name, graph.dtype_of(output)); + vkapi::ShaderInfo shader = VK_KERNEL_FROM_STR(kernel_name); + + vkapi::ParamsBindList param_buffers = { + graph.sizes_ubo(output), graph.sizes_ubo(packed_int_input)}; + + std::vector push_constants = { + PushConstantDataInfo(&scale, sizeof(scale)), + PushConstantDataInfo(&zp, sizeof(zp)), + }; + + uint32_t apply_bias = 1; + if (graph.val_is_none(bias_data)) { + apply_bias = 0; + } + + // Add the compute node + return DynamicDispatchNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + quantized_linear_global_wg_size, + quantized_linear_local_wg_size, + // Inputs and Outputs + {{output, vkapi::kWrite}, + {{packed_int_input, + packed_weight, + packed_weight_sums, + packed_weight_scales, + packed_bias}, + vkapi::kRead}}, + // Shader params buffers + param_buffers, + // Push Constants + push_constants, + // Specialization Constants + {apply_bias}, + // Resize args + {fp_input}, + // Resizing Logic + nullptr); +} + +// +// High level operator impl +// + +void quantized_linear_impl( + ComputeGraph& graph, + const QuantizationConfig& input_quant_config, + const QuantizationConfig& 
weight_quant_config, + const ValueRef fp_input, + const ValueRef input_scale, + const ValueRef input_zp, + const ValueRef weight_data, + const ValueRef weight_sums_data, + const ValueRef weight_scales_data, + const ValueRef weight_zeros_data, + const ValueRef group_size, + const ValueRef bias_data, + const ValueRef output) { + std::vector input_sizes = graph.sizes_of(fp_input); + std::vector weight_sizes = graph.sizes_of(weight_data); + + const int64_t K = utils::val_at(-1, input_sizes); + // K (input channels) must be a multiple of 4 to ensure that reading a group + // of 4 input channels from the input tensor will be aligned on a texel + // boundary. + VK_CHECK_COND(K % 4 == 0); + + // Prepack weight data + + const ValueRef packed_weight = + prepack_quantized_linear_weight(graph, weight_quant_config, weight_data); + const ValueRef packed_weight_scales = prepack_standard( + graph, weight_scales_data, utils::kBuffer, utils::kWidthPacked); + // Weight affine quant not supported at the moment + const ValueRef packed_weight_zeros = kDummyValueRef; + + // Prepack bias data + + // Create a dummy tensor to fill the binding slot of the bias tensor if it is + // not provided. This helps simplify dispatch logic and makes it so that + // fewer shdaer variants need to be generated. + TmpTensor dummy_bias( + &graph, {}, graph.dtype_of(output), utils::kBuffer, utils::kWidthPacked); + + ValueRef packed_bias = dummy_bias.vref; + if (graph.val_is_not_none(bias_data)) { + packed_bias = + prepack_standard(graph, bias_data, utils::kBuffer, utils::kWidthPacked); + } + + // Use weight only quantized linear if at least one is true: + // 1. Device does not support int8 dot product + // 2. Input is not quantized + if (!graph.can_use_int8_dot_product() || + input_quant_config.granularity == kNoQuantization) { + DynamicDispatchNode linear_qw_node(make_linear_qw_node( + graph, + weight_quant_config, + fp_input, + weight_data, + packed_weight, + packed_weight_scales, + packed_weight_zeros, + group_size, + bias_data, + packed_bias, + output)); + + graph.execute_nodes().emplace_back(new DynamicDispatchNode(linear_qw_node)); + return; + } else { + // Otherwise, use input and weight quantized linear computed with integer + // accumulation + + // Input scale/zero point only used for activation & weight quantized linear + ValueRef packed_input_scale = input_scale; + ValueRef packed_input_zp = input_zp; + if (graph.val_is_tref(input_scale)) { + VK_CHECK_COND(graph.val_is_tref(packed_input_zp)); + packed_input_scale = prepack_standard( + graph, input_scale, utils::kBuffer, utils::kWidthPacked); + packed_input_zp = prepack_standard( + graph, input_zp, utils::kBuffer, utils::kWidthPacked); + } + + // Pre-computed per quant group weight sums are needed for int accumulation, + // but not for weight only + const ValueRef packed_weight_sums = prepack_standard( + graph, weight_sums_data, utils::kBuffer, utils::kWidthPacked); + + // Allocate temporary tensor to store quantized and packed input + + int64_t num_blocks_M, num_blocks_K; + std::tie(num_blocks_M, num_blocks_K) = + get_quantized_input_num_blocks(graph, fp_input); + + const int64_t int_input_height = num_blocks_M; + const int64_t int_input_width = num_blocks_K * 4; + + TmpTensor packed_int_input( + &graph, + {int_input_height, int_input_width}, + vkapi::kInt, + utils::kBuffer, + utils::kWidthPacked); + + DynamicDispatchNode quantize_and_pack_linear_node( + make_quantize_and_pack_linear_input_node( + graph, + input_quant_config, + fp_input, + packed_input_scale, + 
packed_input_zp, + input_scale, + input_zp, + packed_int_input, + group_size)); + + graph.execute_nodes().emplace_back( + new DynamicDispatchNode(quantize_and_pack_linear_node)); + + DynamicDispatchNode linear_qa_qw_node(make_linear_qa_qw_node( + graph, + input_quant_config, + weight_quant_config, + fp_input, + packed_int_input, + packed_input_scale, + packed_input_zp, + input_scale, + input_zp, + weight_data, + packed_weight, + packed_weight_sums, + packed_weight_scales, + group_size, + bias_data, + packed_bias, + output)); + + graph.execute_nodes().emplace_back( + new DynamicDispatchNode(linear_qa_qw_node)); + } +} + +void linear_q8ta_q8csw(ComputeGraph& graph, const std::vector& args) { + int32_t idx = 0; + const ValueRef fp_input = args.at(idx++); + const ValueRef input_scale = args.at(idx++); + const ValueRef input_zp = args.at(idx++); + const ValueRef weight_data = args.at(idx++); + const ValueRef weight_sums_data = args.at(idx++); + const ValueRef weight_scales_data = args.at(idx++); + const ValueRef bias_data = args.at(idx++); + const ValueRef output = args.at(idx++); + + const int64_t K = graph.size_at(-1, fp_input); + + QuantizationConfig input_quant_config(8, kPerTensor, {}, false); + QuantizationConfig weight_quant_config(8, kPerChannel, {K}); + + quantized_linear_impl( + graph, + input_quant_config, + weight_quant_config, + fp_input, + input_scale, + input_zp, + weight_data, + weight_sums_data, + weight_scales_data, + kDummyValueRef, // weight_zeros_data + kDummyValueRef, // group_size + bias_data, + output); +} + +void linear_q8csw(ComputeGraph& graph, const std::vector& args) { + int32_t idx = 0; + const ValueRef fp_input = args.at(idx++); + const ValueRef weight_data = args.at(idx++); + const ValueRef weight_scales_data = args.at(idx++); + const ValueRef bias_data = args.at(idx++); + const ValueRef output = args.at(idx++); + + const int64_t K = graph.size_at(-1, fp_input); + + QuantizationConfig input_quant_config(32, kNoQuantization, {}); + QuantizationConfig weight_quant_config(8, kPerChannel, {K}); + + quantized_linear_impl( + graph, + input_quant_config, + weight_quant_config, + fp_input, + kDummyValueRef, // input scale + kDummyValueRef, // input zp + weight_data, + kDummyValueRef, // weight sums + weight_scales_data, + kDummyValueRef, // weight zeros + kDummyValueRef, // group size + bias_data, + output); +} + +REGISTER_OPERATORS { + VK_REGISTER_OP(et_vk.linear_q8ta_q8csw.default, linear_q8ta_q8csw); + VK_REGISTER_OP(et_vk.linear_q8csw.default, linear_q8csw); +} + +} // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.h b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.h new file mode 100644 index 00000000000..7b62c98390d --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.h @@ -0,0 +1,29 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#pragma once + +#include +#include +#include + +namespace vkcompute { + +utils::uvec3 quantized_linear_local_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const utils::uvec3& global_workgroup_size, + const std::vector& args, + const std::vector& resize_args); + +ValueRef prepack_quantized_linear_weight( + ComputeGraph& graph, + const QuantizationConfig& weight_quant_config, + const ValueRef qmat2_data); + +} // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/utils/QuantizationConfig.h b/backends/vulkan/runtime/graph/ops/impl/utils/QuantizationConfig.h new file mode 100644 index 00000000000..4bc8c7c3bfc --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/impl/utils/QuantizationConfig.h @@ -0,0 +1,58 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +namespace vkcompute { + +enum class QuantizationGranularity { + PerChannel, + PerTensor, + PerGroup, + NoQuantization, +}; + +static constexpr QuantizationGranularity kPerChannel = + QuantizationGranularity::PerChannel; +static constexpr QuantizationGranularity kPerTensor = + QuantizationGranularity::PerTensor; +static constexpr QuantizationGranularity kPerGroup = + QuantizationGranularity::PerGroup; +static constexpr QuantizationGranularity kNoQuantization = + QuantizationGranularity::NoQuantization; + +struct QuantizationConfig { + int nbits; + QuantizationGranularity granularity; + std::vector granularity_sizes; + bool is_symmetric; + bool is_dynamic; + + QuantizationConfig() + : nbits(8), + granularity(kPerTensor), + granularity_sizes(), + is_symmetric(true), + is_dynamic(false) {} + + QuantizationConfig( + int nbits_, + QuantizationGranularity granularity_, + const std::vector& granularity_sizes_, + bool is_symmetric_ = true, + bool is_dynamic_ = false) + : nbits(nbits_), + granularity(granularity_), + granularity_sizes(granularity_sizes_), + is_symmetric(is_symmetric_), + is_dynamic(is_dynamic_) {} +}; + +} // namespace vkcompute diff --git a/backends/vulkan/test/custom_ops/CMakeLists.txt b/backends/vulkan/test/custom_ops/CMakeLists.txt index f44db22c17e..6944fe59385 100644 --- a/backends/vulkan/test/custom_ops/CMakeLists.txt +++ b/backends/vulkan/test/custom_ops/CMakeLists.txt @@ -91,4 +91,5 @@ if(TARGET vulkan_backend) # Define operator prototypes add_operator_prototype(add) + add_operator_prototype(q8csw_linear) endif() diff --git a/backends/vulkan/test/custom_ops/q8csw_linear.cpp b/backends/vulkan/test/custom_ops/q8csw_linear.cpp new file mode 100644 index 00000000000..23973426fcc --- /dev/null +++ b/backends/vulkan/test/custom_ops/q8csw_linear.cpp @@ -0,0 +1,479 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. 
+ +#include +#include +#include +#include +#include "utils.h" + +#include + +using namespace executorch::vulkan::prototyping; + +using namespace vkcompute; + +static constexpr int64_t kRefDimSizeLimit = 300; + +// Linear configuration struct +struct LinearConfig { + int64_t M; // Batch size / number of rows in input + int64_t K; // Input features / columns in input, columns in weight + int64_t N; // Output features / rows in weight + bool has_bias = true; + std::string test_case_name = "placeholder"; + std::string op_name = "linear_q8ta_q8csw"; +}; + +// Utility function to create a test case from a LinearConfig +TestCase create_test_case_from_config( + const LinearConfig& config, + utils::StorageType storage_type, + vkapi::ScalarType input_dtype) { + TestCase test_case; + + // Create a descriptive name for the test case + std::string storage_str = + (storage_type == utils::kTexture3D) ? "Texture3D" : "Buffer"; + std::string dtype_str = (input_dtype == vkapi::kFloat) ? "Float" : "Half"; + + std::string test_name = + config.test_case_name + "_" + storage_str + "_" + dtype_str; + test_case.set_name(test_name); + + // Set the operator name for the test case + std::string operator_name = "et_vk." + config.op_name + ".default"; + test_case.set_operator_name(operator_name); + + // Derive sizes from M, K, N + std::vector input_size = {config.M, config.K}; + std::vector weight_size = {config.N, config.K}; + + // Input tensor (float/half) - [M, K] + ValueSpec input_tensor( + input_size, + input_dtype, + storage_type, + utils::kWidthPacked, + DataGenType::RANDOM); + + if (debugging()) { + print_valuespec_data(input_tensor, "input_tensor"); + } + + float input_scale_val = 0.008f; + ValueSpec input_scale(input_scale_val); + + int32_t input_zero_point_val = -2; + ValueSpec input_zero_point(input_zero_point_val); + + // Quantized weight tensor (int8) - [N, K] + ValueSpec quantized_weight( + weight_size, + vkapi::kChar, // int8 for quantized weights + storage_type, + utils::kWidthPacked, + DataGenType::RANDINT8); + quantized_weight.set_constant(true); + + if (debugging()) { + print_valuespec_data(quantized_weight, "weight_tensor"); + } + + // Weight quantization scales (float/half, per-channel) + ValueSpec weight_scales( + {config.N}, // Per output feature + input_dtype, + storage_type, + utils::kWidthPacked, + DataGenType::RANDOM_SCALES); + weight_scales.set_constant(true); + + ValueSpec weight_sums( + {config.N}, // Per output feature + vkapi::kInt, + storage_type, + utils::kWidthPacked, + DataGenType::ZEROS); + weight_sums.set_constant(true); + + // Compute weight_sums data based on quantized weights + int64_t in_features = config.K; + int64_t out_features = config.N; + compute_weight_sums(weight_sums, quantized_weight, out_features, in_features); + + // Bias (optional, float/half) - [N] + ValueSpec bias( + {config.N}, // Per output feature + input_dtype, + storage_type, + utils::kWidthPacked, + DataGenType::RANDOM); + bias.set_constant(true); + if (!config.has_bias) { + bias.set_none(true); + } + + // Output tensor (float/half) - [M, N] + ValueSpec output( + {config.M, config.N}, + input_dtype, + storage_type, + utils::kWidthPacked, + DataGenType::ZEROS); + + // Add all specs to test case + if (config.op_name.find("q8ta") != std::string::npos) { + test_case.add_input_spec(input_tensor); + test_case.add_input_spec(input_scale); + test_case.add_input_spec(input_zero_point); + test_case.add_input_spec(quantized_weight); + test_case.add_input_spec(weight_sums); + test_case.add_input_spec(weight_scales);
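+ // The order of these specs must match the argument order consumed by
+ // linear_q8ta_q8csw in QuantizedLinear.cpp: input, input scale, input zero
+ // point, weight, weight_sums, weight_scales, bias, then the output spec.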
+ test_case.add_input_spec(bias); + test_case.add_output_spec(output); + } else { + test_case.add_input_spec(input_tensor); + test_case.add_input_spec(quantized_weight); + test_case.add_input_spec(weight_scales); + test_case.add_input_spec(bias); + test_case.add_output_spec(output); + } + + return test_case; +} + +// Generate easy test cases for quantized linear operation (for debugging) +std::vector generate_quantized_linear_easy_cases() { + std::vector test_cases; + + // Single simple configuration for debugging + int M = 4; + int K = 4; + int N = 4; + + LinearConfig config = { + M, // Batch size + K, // Input features + N, // Output features + true, // has_bias + "simple", // test_case_name + }; + + // Test with both storage types and data types for completeness + std::vector storage_types = { + utils::kTexture3D, utils::kBuffer}; + std::vector float_types = {vkapi::kFloat}; + + // Generate test cases for each combination + for (const auto& storage_type : storage_types) { + for (const auto& input_dtype : float_types) { + test_cases.push_back( + create_test_case_from_config(config, storage_type, input_dtype)); + } + } + + return test_cases; +} + +// Generate test cases for quantized linear operation +std::vector generate_quantized_linear_test_cases() { + std::vector test_cases; + + std::vector configs = { + {4, 64, 32}, + {4, 128, 64}, + {4, 256, 128}, + {32, 64, 32}, + {32, 128, 64}, + {32, 256, 128}, + // No bias tests + {32, 128, 64, false}, + {32, 256, 128, false}, + {256, 2048, 2048}, + {512, 2048, 2048}, + {1024, 2048, 2048}, + }; + + // Test with different storage types and data types + std::vector storage_types = { + utils::kTexture3D, utils::kBuffer}; + + for (auto config : configs) { + std::string prefix = + (config.M < kRefDimSizeLimit && config.K < kRefDimSizeLimit && + config.N < kRefDimSizeLimit) + ? 
"correctness_" + : "performance_"; + std::string generated_test_case_name = prefix + std::to_string(config.M) + + "_" + std::to_string(config.K) + "_" + std::to_string(config.N); + if (!config.has_bias) { + generated_test_case_name += "_no_bias"; + } + + config.test_case_name = generated_test_case_name; + + for (const auto& storage_type : storage_types) { + if (vkcompute::api::context() + ->adapter_ptr() + ->supports_int8_dot_product()) { + // Test both activation+weight quantized and weight only quantized + test_cases.push_back( + create_test_case_from_config(config, storage_type, vkapi::kFloat)); + } + + LinearConfig wo_quant_config = config; + wo_quant_config.op_name = "linear_q8csw"; + test_cases.push_back(create_test_case_from_config( + wo_quant_config, storage_type, vkapi::kFloat)); + } + } + + return test_cases; +} + +// Reference implementation for weight only quantized linear +void linear_q8csw_reference_impl(TestCase& test_case) { + int32_t idx = 0; + const ValueSpec& input_spec = test_case.inputs()[idx++]; + const ValueSpec& weight_spec = test_case.inputs()[idx++]; + const ValueSpec& weight_scales_spec = test_case.inputs()[idx++]; + const ValueSpec& bias_spec = test_case.inputs()[idx++]; + + // Extract output specification (mutable reference) + ValueSpec& output_spec = test_case.outputs()[0]; + + // Get tensor dimensions + auto input_sizes = input_spec.get_tensor_sizes(); // [batch_size, in_features] + auto weight_sizes = + weight_spec.get_tensor_sizes(); // [out_features, in_features] + auto output_sizes = + output_spec.get_tensor_sizes(); // [batch_size, out_features] + + int64_t batch_size = input_sizes[0]; + int64_t in_features = input_sizes[1]; + int64_t out_features = weight_sizes[0]; + + // Skip for large tensors since computation time will be extremely slow + if (batch_size > kRefDimSizeLimit || in_features > kRefDimSizeLimit || + out_features > kRefDimSizeLimit) { + throw std::invalid_argument( + "One or more dimensions (batch_size, in_features, out_features) exceed the allowed limit for reference implementation."); + } + + if (input_spec.dtype != vkapi::kFloat) { + throw std::invalid_argument("Unsupported dtype"); + } + + // Get raw data pointers + auto& input_data = input_spec.get_float_data(); + + auto& weight_data = weight_spec.get_int8_data(); + auto& weight_scales_data = weight_scales_spec.get_float_data(); + auto& bias_data = bias_spec.get_float_data(); + + // Calculate number of output elements + int64_t num_output_elements = batch_size * out_features; + + auto& ref_data = output_spec.get_ref_float_data(); + ref_data.resize(num_output_elements); + + // Perform quantized linear transformation (matrix multiplication) + for (int64_t b = 0; b < batch_size; ++b) { + for (int64_t out_f = 0; out_f < out_features; ++out_f) { + float sum = 0.0f; + + // Matrix multiplication: output[b][out_f] = sum(input[b][in_f] * + // weight[out_f][in_f]) + for (int64_t in_f = 0; in_f < in_features; ++in_f) { + // Get input value and dequantize + int64_t input_idx = b * in_features + in_f; + float input_val = input_data[input_idx]; + + // Get weight value and dequantize + int64_t weight_idx = out_f * in_features + in_f; + float dequant_weight = (static_cast(weight_data[weight_idx])) * + weight_scales_data[out_f]; + + sum += input_val * dequant_weight; + } + + // Add bias and store result + if (!bias_spec.is_none()) { + sum += bias_data[out_f]; + } + int64_t output_idx = b * out_features + out_f; + ref_data[output_idx] = sum; + } + } +} + +void linear_q8ta_q8csw_reference_impl(TestCase& 
test_case) { + // Extract input specifications + int32_t idx = 0; + const ValueSpec& input_spec = test_case.inputs()[idx++]; + const ValueSpec& input_scale_spec = test_case.inputs()[idx++]; + const ValueSpec& input_zeros_spec = test_case.inputs()[idx++]; + const ValueSpec& weight_spec = test_case.inputs()[idx++]; + const ValueSpec& weight_sums_spec = test_case.inputs()[idx++]; + (void)weight_sums_spec; + const ValueSpec& weight_scales_spec = test_case.inputs()[idx++]; + const ValueSpec& bias_spec = test_case.inputs()[idx++]; + + // Extract output specification (mutable reference) + ValueSpec& output_spec = test_case.outputs()[0]; + + // Get tensor dimensions + auto input_sizes = input_spec.get_tensor_sizes(); // [batch_size, in_features] + auto weight_sizes = + weight_spec.get_tensor_sizes(); // [out_features, in_features] + auto output_sizes = + output_spec.get_tensor_sizes(); // [batch_size, out_features] + + int64_t batch_size = input_sizes[0]; + int64_t in_features = input_sizes[1]; + int64_t out_features = weight_sizes[0]; + + // Skip for large tensors since computation time will be extremely slow + if (batch_size > kRefDimSizeLimit || in_features > kRefDimSizeLimit || + out_features > kRefDimSizeLimit) { + throw std::invalid_argument( + "One or more dimensions (batch_size, in_features, out_features) exceed the allowed limit for reference implementation."); + } + + if (input_spec.dtype != vkapi::kFloat) { + throw std::invalid_argument("Unsupported dtype"); + } + + // Get raw data pointers + auto& input_data = input_spec.get_float_data(); + const float input_scale = input_scale_spec.get_float_value(); + const int32_t input_zero_point = input_zeros_spec.get_int_value(); + + auto& weight_data = weight_spec.get_int8_data(); + auto& weight_scales_data = weight_scales_spec.get_float_data(); + auto& bias_data = bias_spec.get_float_data(); + + // Calculate number of output elements + int64_t num_output_elements = batch_size * out_features; + + auto& ref_data = output_spec.get_ref_float_data(); + ref_data.resize(num_output_elements); + + // Perform quantized linear transformation (matrix multiplication) with + // integer accumulation + for (int64_t b = 0; b < batch_size; ++b) { + for (int64_t out_f = 0; out_f < out_features; ++out_f) { + int32_t int_sum = 0; + int32_t weight_sum = 0; // Track weight sum on the fly + + // Matrix multiplication with integer accumulation: + // int_sum = sum(quantized_input[b][in_f] * quantized_weight[out_f][in_f]) + for (int64_t in_f = 0; in_f < in_features; ++in_f) { + // Get input value and quantize to int8 + int64_t input_idx = b * in_features + in_f; + + float quant_input_f = + std::round(input_data[input_idx] / input_scale) + input_zero_point; + quant_input_f = std::min(std::max(quant_input_f, -128.0f), 127.0f); + int8_t quantized_input = static_cast(quant_input_f); + + // Get quantized weight (already int8) + int64_t weight_idx = out_f * in_features + in_f; + int8_t quantized_weight = weight_data[weight_idx]; + + // Integer multiplication and accumulation + int_sum += static_cast(quantized_input) * + static_cast(quantized_weight); + + // Track weight sum for this output channel on the fly + weight_sum += static_cast(quantized_weight); + } + + // Convert accumulated integer result to float and apply scales + // Final result = (int_sum - zero_point_correction) * input_scale * + // weight_scale + bias zero_point_correction = input_zero_point * + // sum_of_weights_for_this_output_channel + int32_t zero_point_correction = input_zero_point * weight_sum; + int32_t 
accum_adjusted = int_sum - zero_point_correction; + + float float_result = + accum_adjusted * input_scale * weight_scales_data[out_f]; + + // Add bias and store result + if (!bias_spec.is_none()) { + float_result += bias_data[out_f]; + } + int64_t output_idx = b * out_features + out_f; + ref_data[output_idx] = float_result; + } + } +} + +void reference_impl(TestCase& test_case) { + if (test_case.operator_name().find("q8ta") != std::string::npos) { + linear_q8ta_q8csw_reference_impl(test_case); + } else { + linear_q8csw_reference_impl(test_case); + } +} + +int64_t quantized_linear_flop_calculator(const TestCase& test_case) { + int input_idx = 0; + int weight_idx = 1; + if (test_case.operator_name().find("q8ta") != std::string::npos) { + input_idx = 0; + weight_idx = 3; + } + + // Get input and weight dimensions + const auto& input_sizes = test_case.inputs()[input_idx].get_tensor_sizes(); + const auto& weight_sizes = test_case.inputs()[weight_idx].get_tensor_sizes(); + + int64_t batch_size = input_sizes[0]; + int64_t in_features = input_sizes[1]; + int64_t out_features = weight_sizes[0]; + + // Calculate FLOPs for quantized linear operation + // Each output element requires: + // - in_features multiply-accumulate operations + // - Additional operations for quantization/dequantization + int64_t output_elements = batch_size * out_features; + int64_t ops_per_output = in_features; + + // Add quantization overhead (approximate) + // - Dequantize input: 1 op per input element used + // - Dequantize weight: 1 op per weight element used + // - Add bias: 1 op per output element + int64_t quantization_ops = ops_per_output + 1; // Simplified estimate + + int64_t flop = output_elements * (ops_per_output + quantization_ops); + + return flop; +} + +int main(int argc, char* argv[]) { + set_debugging(false); + set_print_output(false); + set_print_latencies(false); + set_use_gpu_timestamps(true); + + print_performance_header(); + std::cout << "Quantized Linear Operation Prototyping Framework" << std::endl; + print_separator(); + + ReferenceComputeFunc ref_fn = reference_impl; + + auto results = execute_test_cases( + generate_quantized_linear_test_cases, + quantized_linear_flop_calculator, + "QuantizedLinear", + 0, + 10, + ref_fn); + + return 0; +} diff --git a/backends/vulkan/test/custom_ops/targets.bzl b/backends/vulkan/test/custom_ops/targets.bzl index 2ddf49834e1..68bdc9e6fbd 100644 --- a/backends/vulkan/test/custom_ops/targets.bzl +++ b/backends/vulkan/test/custom_ops/targets.bzl @@ -12,7 +12,6 @@ def define_custom_op_test_binary(custom_op_name, extra_deps = [], src_file = Non ":operator_implementations", ":custom_ops_shaderlib", "//executorch/backends/vulkan:vulkan_graph_runtime", - "//executorch/runtime/core/exec_aten:lib", runtime.external_dep_location("libtorch"), ] + extra_deps @@ -68,8 +67,6 @@ def define_common_targets(is_fbcode = False): platforms = get_platforms(), deps = [ "//executorch/backends/vulkan:vulkan_graph_runtime", - "//executorch/runtime/core/exec_aten:lib", - runtime.external_dep_location("libtorch"), ], visibility = [ "//executorch/backends/vulkan/test/custom_ops/...", @@ -86,7 +83,6 @@ def define_common_targets(is_fbcode = False): platforms = get_platforms(), deps = [ "//executorch/backends/vulkan:vulkan_graph_runtime", - "//executorch/runtime/core/exec_aten:lib", ":custom_ops_shaderlib", ], visibility = [ @@ -97,3 +93,4 @@ def define_common_targets(is_fbcode = False): ) define_custom_op_test_binary("add") + define_custom_op_test_binary("q8csw_linear") diff --git 
a/backends/vulkan/test/custom_ops/utils.cpp b/backends/vulkan/test/custom_ops/utils.cpp index 235a6bd293e..ee2f6858025 100644 --- a/backends/vulkan/test/custom_ops/utils.cpp +++ b/backends/vulkan/test/custom_ops/utils.cpp @@ -694,7 +694,8 @@ void BenchmarkResult::print_summary( int case_number, const std::string& size_info, float total_gflops) const { - static constexpr int KERNEL_NAME_WIDTH = 140; + static constexpr int OPERATOR_NAME_WIDTH = 50; + static constexpr int KERNEL_NAME_WIDTH = 70; static constexpr int SIZE_INFO_WIDTH = 20; static constexpr int TIMING_WIDTH = 20; static constexpr int GFLOPS_WIDTH = 20; @@ -713,8 +714,10 @@ void BenchmarkResult::print_summary( break; } - std::cout << std::left << std::setw(KERNEL_NAME_WIDTH) << get_kernel_name() - << std::right << " " << std::setw(SIZE_INFO_WIDTH) << size_info + std::cout << std::left << std::setw(OPERATOR_NAME_WIDTH) + << get_operator_name() << " " << std::left + << std::setw(KERNEL_NAME_WIDTH) << get_kernel_name() << std::right + << " " << std::setw(SIZE_INFO_WIDTH) << size_info << std::setw(TIMING_WIDTH) << std::fixed << std::setprecision(3) << get_avg_time_us() << " μs " << std::setw(GFLOPS_WIDTH) << std::fixed << std::setprecision(3) << total_gflops << " GFLOP/s " @@ -999,7 +1002,9 @@ ComputeGraph setup_compute_graph(TestCase& test_case, std::string op_name) { for (size_t i = 0; i < test_case.num_inputs(); ++i) { const ValueSpec& input_spec = test_case.inputs()[i]; - if (input_spec.is_float()) { + if (input_spec.is_none()) { + input_values.push_back(graph.add_none()); + } else if (input_spec.is_float()) { ValueRef input_value = graph.add_scalar(static_cast(input_spec.get_float_value())); input_values.push_back(input_value); @@ -1246,9 +1251,11 @@ TestResult execute_test_cases( bool shader_not_supported = false; try { result = execute_test_case(test_case, warmup_runs, benchmark_runs); + result.set_operator_name(test_case.operator_name()); } catch (const vkcompute::vkapi::ShaderNotSupportedError& e) { result = BenchmarkResult( - test_case.name().empty() ? "unnamed_test_case" : test_case.name()); + test_case.name().empty() ? 
"unnamed_test_case" : test_case.name(), + test_case.operator_name()); shader_not_supported = true; } @@ -1606,20 +1613,21 @@ void compute_weight_sums( const ValueSpec& quantized_weight, int64_t out_features, int64_t elements_per_output_feature) { - auto& weight_sums_data = weight_sums.get_float_data(); + auto& weight_sums_data = weight_sums.get_int32_data(); auto& quantized_weight_data = quantized_weight.get_int8_data(); weight_sums_data.resize(out_features); // For each output feature, compute the sum of quantized weights for (int64_t out_f = 0; out_f < out_features; ++out_f) { - float sum = 0.0f; + int32_t sum = 0; for (int64_t elem = 0; elem < elements_per_output_feature; ++elem) { // Weight indexing depends on the layout: - // For linear: [in_features, out_features] -> elem * out_features + out_f - // For conv2d: [C_in * K_h * K_w, C_out] -> elem * out_features + out_f - int64_t weight_idx = elem * out_features + out_f; - sum += static_cast(quantized_weight_data[weight_idx]); + // For linear: [out_features, in_features] -> out_f * + // elements_per_output_feature + elem For conv2d: [C_out, C_in * K_h * + // K_w] -> out_f * elements_per_output_feature + elem + int64_t weight_idx = out_f * elements_per_output_feature + elem; + sum += static_cast(quantized_weight_data[weight_idx]); } weight_sums_data[out_f] = sum; } diff --git a/backends/vulkan/test/custom_ops/utils.h b/backends/vulkan/test/custom_ops/utils.h index 5ca05dc824f..6c4e2263fc1 100644 --- a/backends/vulkan/test/custom_ops/utils.h +++ b/backends/vulkan/test/custom_ops/utils.h @@ -66,6 +66,7 @@ struct ValueSpec { SpecType spec_type; DataGenType data_gen_type; bool is_constant_tensor; + bool is_none_flag; std::vector float_data; std::vector int32_data; @@ -90,7 +91,8 @@ struct ValueSpec { storage_type(storage_type), spec_type(SpecType::Tensor), data_gen_type(DataGenType::ZEROS), - is_constant_tensor(false) { + is_constant_tensor(false), + is_none_flag(false) { generate_tensor_data(); } @@ -107,7 +109,8 @@ struct ValueSpec { storage_type(storage_type), spec_type(SpecType::Tensor), data_gen_type(data_gen_type), - is_constant_tensor(false) { + is_constant_tensor(false), + is_none_flag(false) { generate_tensor_data(); } @@ -119,7 +122,8 @@ struct ValueSpec { storage_type(utils::kTexture3D), spec_type(SpecType::Int), data_gen_type(DataGenType::FIXED), - is_constant_tensor(false) { + is_constant_tensor(false), + is_none_flag(false) { int32_data.push_back(value); } @@ -131,7 +135,8 @@ struct ValueSpec { storage_type(utils::kTexture3D), spec_type(SpecType::Float), data_gen_type(DataGenType::FIXED), - is_constant_tensor(false) { + is_constant_tensor(false), + is_none_flag(false) { float_data.push_back(value); } @@ -143,7 +148,8 @@ struct ValueSpec { storage_type(utils::kTexture3D), spec_type(SpecType::Bool), data_gen_type(DataGenType::FIXED), - is_constant_tensor(false) { + is_constant_tensor(false), + is_none_flag(false) { int32_data.push_back(value ? 
1 : 0); } @@ -156,6 +162,7 @@ struct ValueSpec { spec_type(SpecType::IntList), data_gen_type(DataGenType::FIXED), is_constant_tensor(false), + is_none_flag(false), int32_data(values) {} // Default constructor @@ -165,7 +172,8 @@ struct ValueSpec { storage_type(utils::kTexture3D), spec_type(SpecType::Tensor), data_gen_type(DataGenType::ZEROS), - is_constant_tensor(false) {} + is_constant_tensor(false), + is_none_flag(false) {} int64_t numel() const; size_t nbytes() const; @@ -279,6 +287,14 @@ struct ValueSpec { is_constant_tensor = is_constant; } + // Set/get none flag + bool is_none() const { + return is_none_flag; + } + void set_none(bool is_none) { + is_none_flag = is_none; + } + const void* get_data_ptr() const; // Correctness checking against reference data @@ -401,6 +417,13 @@ class BenchmarkResult { BenchmarkResult(const std::string& name) : kernel_name(name), correctness_status_(CorrectnessStatus::SKIPPED) {} + BenchmarkResult( + const std::string& kernel_name, + const std::string& operator_name) + : kernel_name(kernel_name), + operator_name(operator_name), + correctness_status_(CorrectnessStatus::SKIPPED) {} + // Add timing for a single iteration void add_iter_timing(float time_us); @@ -408,6 +431,9 @@ class BenchmarkResult { const std::string& get_kernel_name() const { return kernel_name; } + const std::string& get_operator_name() const { + return operator_name; + } float get_avg_time_us() const; size_t get_num_iterations() const { return iter_timings.size(); @@ -423,6 +449,9 @@ class BenchmarkResult { void set_kernel_name(const std::string& name) { kernel_name = name; } + void set_operator_name(const std::string& name) { + operator_name = name; + } void set_correctness_status(CorrectnessStatus status) { correctness_status_ = status; } @@ -445,6 +474,7 @@ class BenchmarkResult { private: std::string kernel_name; + std::string operator_name; std::vector iter_timings; // Individual iteration timings in microseconds CorrectnessStatus correctness_status_;
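For reference, the identity that the integer-accumulation path relies on (and that linear_q8ta_q8csw_reference_impl checks against) can be exercised with a small standalone program. This is a minimal sketch, not part of the patch; all values and names in it are illustrative.

// Verifies that (sum_k qx*qw - zp * sum_k qw) * s_in * s_w equals the
// dequantize-then-accumulate result for one output channel.
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  const int K = 4;
  const float input_scale = 0.05f;   // per-tensor activation scale
  const int32_t input_zp = -2;       // per-tensor activation zero point
  const float weight_scale = 0.1f;   // per-channel weight scale (one channel)

  std::vector<float> x = {0.3f, -0.7f, 1.2f, 0.05f};
  std::vector<int8_t> qw = {12, -34, 56, -7};

  // Quantize activations: round, add zero point, clamp to the int8 range.
  std::vector<int8_t> qx(K);
  for (int k = 0; k < K; ++k) {
    float q = std::round(x[k] / input_scale) + input_zp;
    qx[k] = static_cast<int8_t>(std::fmin(std::fmax(q, -128.0f), 127.0f));
  }

  // Integer accumulation plus zero-point correction via the weight sum.
  int32_t acc = 0;
  int32_t weight_sum = 0;
  for (int k = 0; k < K; ++k) {
    acc += int32_t(qx[k]) * int32_t(qw[k]);
    weight_sum += int32_t(qw[k]);
  }
  const float int_path =
      (acc - input_zp * weight_sum) * input_scale * weight_scale;

  // Float reference: dequantize both operands first, then accumulate.
  float float_path = 0.0f;
  for (int k = 0; k < K; ++k) {
    float_path +=
        (int32_t(qx[k]) - input_zp) * input_scale * (qw[k] * weight_scale);
  }

  std::printf("int path: %f, float path: %f\n", int_path, float_path);
  return 0;
}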