diff --git a/backends/vulkan/runtime/graph/ops/glsl/choose_qparams.glslh b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams.glslh
new file mode 100644
index 00000000000..66620e9b174
--- /dev/null
+++ b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams.glslh
@@ -0,0 +1,70 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#ifndef CHOOSE_QPARAMS_GLSLH
+#define CHOOSE_QPARAMS_GLSLH
+
+// equivalent of the eps defined in the cpu implementation
+#define SMALL_SCALE_THRESHOLD 6.1e-5
+
+// Calculate scale and zero point from min and max values
+void calculate_scale_and_zero_point(
+    float min_val,
+    float max_val,
+    int qmin,
+    int qmax,
+    out float scale_val,
+    out int zero_point_val) {
+  // ensure we have zero included in our range
+  min_val = min(min_val, 0.0);
+  max_val = max(max_val, 0.0);
+
+  scale_val = (max_val - min_val) / float(qmax - qmin);
+
+  // Handle zero or very small scale
+  if (scale_val == 0.0 || isinf(1.0 / scale_val)) {
+    scale_val = 0.1;
+  }
+
+  // Cut off small scale
+  if (scale_val < SMALL_SCALE_THRESHOLD) {
+    float org_scale = scale_val;
+    scale_val = SMALL_SCALE_THRESHOLD;
+
+    // Adjust min and max based on new scale
+    if (min_val == 0.0) {
+      max_val = SMALL_SCALE_THRESHOLD * float(qmax - qmin);
+    } else if (max_val == 0.0) {
+      min_val = -SMALL_SCALE_THRESHOLD * float(qmax - qmin);
+    } else {
+      float amplifier = SMALL_SCALE_THRESHOLD / org_scale;
+      min_val *= amplifier;
+      max_val *= amplifier;
+    }
+  }
+
+  // Calculate zero point
+  float zero_point_from_min = float(qmin) - min_val / scale_val;
+  float zero_point_from_max = float(qmax) - max_val / scale_val;
+  float zero_point_from_min_error = abs(float(qmin)) - abs(min_val / scale_val);
+  float zero_point_from_max_error = abs(float(qmax)) - abs(max_val / scale_val);
+  float initial_zero_point = zero_point_from_min_error < zero_point_from_max_error
+      ? zero_point_from_min
+      : zero_point_from_max;
+
+  // Nudge zero point to integer
+  if (initial_zero_point < float(qmin)) {
+    zero_point_val = qmin;
+  } else if (initial_zero_point > float(qmax)) {
+    zero_point_val = qmax;
+  } else {
+    zero_point_val = int(round(initial_zero_point));
+  }
+}
+
+#endif // CHOOSE_QPARAMS_GLSLH
diff --git a/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_buffer.glsl b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_buffer.glsl
new file mode 100644
index 00000000000..dcbfe493f34
--- /dev/null
+++ b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_buffer.glsl
@@ -0,0 +1,278 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */ + +#version 450 core + +#define PRECISION ${PRECISION} + +#define IN_T ${buffer_scalar_type(IN_DTYPE)} + +#define ${MODE} + +${define_active_storage_type("buffer")} +${define_required_extensions(IN_DTYPE)} + +#extension GL_EXT_control_flow_attributes : require + +layout(std430) buffer; + +${layout_declare_tensor(B, "w", "t_scale", "float", "buffer")} +${layout_declare_tensor(B, "w", "t_zero_point", "int", "buffer")} +${layout_declare_tensor(B, "r", "t_in", IN_DTYPE, "buffer")} + +$if MODE == "per_tensor": + layout(push_constant) uniform restrict Block { + int quant_min; + int quant_max; + }; +$else: + layout(push_constant) uniform restrict Block { + int num_tokens; + int quant_min; + int quant_max; + }; + +${layout_declare_ubo(B, "ivec4", "t_in_sizes")} +${layout_declare_ubo(B, "ivec4", "t_in_strides")} +${layout_declare_ubo(B, "ivec4", "t_scale_sizes")} +${layout_declare_ubo(B, "ivec4", "t_scale_strides")} +${layout_declare_ubo(B, "ivec4", "t_zero_point_sizes")} +${layout_declare_ubo(B, "ivec4", "t_zero_point_strides")} + +#include "indexing_utils.h" +#include "choose_qparams.glslh" + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +#define NWORKERS 64 + +// Shared memory for reduction - must match local work group size +shared float shared_min[NWORKERS]; +shared float shared_max[NWORKERS]; + +/* + * QUANTIZATION PARAMETER COMPUTATION SHADER (BUFFER STORAGE) + * + * This shader computes quantization parameters (scale and zero_point) for converting + * floating-point tensors to n-bit integer representations while preserving the + * original data range as much as possible. + * + * ALGORITHM: + * 1. Find global min/max values across tensor elements using parallel reduction + * 2. Use tree reduction with shared memory for efficient min/max computation + * 3. Calculate scale = (max - min) / (quant_max - quant_min) + * 4. 
Calculate zero_point to map floating-point zero to integer value + * + * WORKGROUP CONFIGURATION: + * - Per-Tensor Mode: + * - Global WG Size: {1, 1, 1} (single workgroup processes entire tensor) + * - Local WG Size: {64, 1, 1} (matches NWORKERS for shared memory) + * - Per-Token Mode: + * - Global WG Size: {num_tokens, 1, 1} (one workgroup per token) + * - Local WG Size: {64, 1, 1} (matches NWORKERS for shared memory) + * + * SUPPORTED CONFIGURATIONS: + * - Buffer Storage: Uses simple linear indexing through buffer elements + * - No axis mapping or packing considerations - processes elements sequentially + * - Works with any tensor layout since it accesses buffer data linearly + * + * TREE REDUCTION VISUALIZATION FOR MIN/MAX FINDING: + * For 8 threads processing elements [10, 1, 8, 1, 0, 2, 3, 5]: + * + * Initial shared_min/shared_max arrays populated by each thread: + * shared_min: | 10 | 1 | 8 | 1 | 0 | 2 | 3 | 5 | + * shared_max: | 10 | 1 | 8 | 1 | 0 | 2 | 3 | 5 | + * Thread: | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | + * + * Stride 1 (compare pairs, keep min/max): + * shared_min: | 1 | | 1 | | 0 | | 3 | | (min(10,1), min(8,1), min(0,2), min(3,5)) + * shared_max: | 10 | | 8 | | 2 | | 5 | | (max(10,1), max(8,1), max(0,2), max(3,5)) + * Active: | 0 | | 2 | | 4 | | 6 | | + * + * Stride 2 (compare pairs, keep min/max): + * shared_min: | 0 | | | | 0 | | | | (min(1,1), min(0,3)) + * shared_max: | 10 | | | | 5 | | | | (max(10,8), max(2,5)) + * Active: | 0 | | | | 4 | | | | + * + * Stride 4 (final comparison): + * shared_min: | 0 | | | | | | | | (min(0,0) = 0) + * shared_max: | 10 | | | | | | | | (max(10,5) = 10) + * Active: | 0 | | | | | | | | + * + * Final result: global_min = 0, global_max = 10 (stored in shared_min[0], shared_max[0]) + * + * PER-TENSOR QUANTIZATION: + * - Single workgroup processes entire tensor with strided access + * - Each thread processes elements [thread_id, thread_id + 64, thread_id + 128, ...] 
+ * - Tree reduction combines all thread results into global min/max + * - Output: Single scale and zero_point values + * + * PER-TOKEN QUANTIZATION: + * - Multiple workgroups, each processing one token + * - Token = all elements except last dimension (e.g., for [B,S,H]: B*S tokens of H elements) + * - Each workgroup finds min/max within its assigned token + * - Output: Array of scale and zero_point values (one per token) + */ + +#ifdef per_tensor + +void choose_qparams_per_tensor() { + uint global_id = gl_GlobalInvocationID.x; + uint local_id = gl_LocalInvocationID.x; + uint total_threads = gl_NumWorkGroups.x * gl_WorkGroupSize.x; + + uint total_elements = uint(t_in_sizes.x * t_in_sizes.y * t_in_sizes.z * t_in_sizes.w); + + // Each thread processes multiple elements with stride + float thread_min = 1.0/0.0; // +infinity + float thread_max = -1.0/0.0; // -infinity + bool found_valid = false; + + for (uint i = global_id; i < total_elements; i += total_threads) { + float val = t_in[i]; + if (!isnan(val) && !isinf(val)) { + if (!found_valid) { + thread_min = val; + thread_max = val; + found_valid = true; + } else { + thread_min = min(thread_min, val); + thread_max = max(thread_max, val); + } + } + } + + // Intra-group reduction using shared memory + shared_min[local_id] = thread_min; + shared_max[local_id] = thread_max; + barrier(); + + // Tree reduction within work group + for (uint stride = gl_WorkGroupSize.x / 2; stride > 0; stride >>= 1) { + if (local_id < stride) { + float other_min = shared_min[local_id + stride]; + float other_max = shared_max[local_id + stride]; + + if (!isinf(other_min) && (isinf(shared_min[local_id]) || other_min < shared_min[local_id])) { + shared_min[local_id] = other_min; + } + if (!isinf(other_max) && (isinf(shared_max[local_id]) || other_max > shared_max[local_id])) { + shared_max[local_id] = other_max; + } + } + barrier(); + } + + // Final result calculation (single workgroup only) + if (local_id == 0) { + float global_min = shared_min[0]; + float global_max = shared_max[0]; + + float scale_val; + int zero_point_val; + calculate_scale_and_zero_point(global_min, global_max, quant_min, quant_max, scale_val, zero_point_val); + + t_scale[0] = scale_val; + t_zero_point[0] = zero_point_val; + } +} + +#else + +void choose_qparams_per_token() { + uint global_id = gl_GlobalInvocationID.x; + uint local_id = gl_LocalInvocationID.x; + uint group_id = gl_WorkGroupID.x; + uint total_workgroups = gl_NumWorkGroups.x; + + uint total_elements = uint(t_in_sizes.x * t_in_sizes.y * t_in_sizes.z * t_in_sizes.w); + uint token_size = total_elements / uint(num_tokens); + + // Calculate how many tokens each workgroup should process + // This handles the case where we have more tokens than workgroups + uint tokens_per_workgroup = (uint(num_tokens) + total_workgroups - 1) / total_workgroups; + + // Calculate which tokens this workgroup is responsible for + uint start_token = group_id * tokens_per_workgroup; + uint end_token = min(start_token + tokens_per_workgroup, uint(num_tokens)); + + // Early exit if this workgroup has no tokens to process + if (start_token >= uint(num_tokens)) { + return; + } + + // Process each token assigned to this workgroup + for (uint token_id = start_token; token_id < end_token; token_id++) { + // Calculate the start and end indices for this token + uint token_start = token_id * token_size; + uint token_end = token_start + token_size; + + // Each thread processes multiple elements within the token with stride + float thread_min = 1.0/0.0; // +infinity + float 
thread_max = -1.0/0.0; // -infinity + bool found_valid = false; + + // Process elements within this token only + for (uint i = token_start + local_id; i < token_end; i += gl_WorkGroupSize.x) { + float val = t_in[i]; + if (!isnan(val) && !isinf(val)) { + if (!found_valid) { + thread_min = val; + thread_max = val; + found_valid = true; + } else { + thread_min = min(thread_min, val); + thread_max = max(thread_max, val); + } + } + } + + // Intra-group reduction using shared memory + shared_min[local_id] = thread_min; + shared_max[local_id] = thread_max; + barrier(); + + // Tree reduction within work group + for (uint stride = gl_WorkGroupSize.x / 2; stride > 0; stride >>= 1) { + if (local_id < stride) { + float other_min = shared_min[local_id + stride]; + float other_max = shared_max[local_id + stride]; + + if (!isinf(other_min) && (isinf(shared_min[local_id]) || other_min < shared_min[local_id])) { + shared_min[local_id] = other_min; + } + if (!isinf(other_max) && (isinf(shared_max[local_id]) || other_max > shared_max[local_id])) { + shared_max[local_id] = other_max; + } + } + barrier(); + } + + // Final calculation for this token + if (local_id == 0) { + float token_min = shared_min[0]; + float token_max = shared_max[0]; + + float scale_val; + int zero_point_val; + calculate_scale_and_zero_point(token_min, token_max, quant_min, quant_max, scale_val, zero_point_val); + + t_scale[token_id] = scale_val; + t_zero_point[token_id] = zero_point_val; + } + + // Synchronize before processing next token + barrier(); + } +} + +#endif + +void main() { + choose_qparams_${MODE}(); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_buffer.yaml b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_buffer.yaml new file mode 100644 index 00000000000..c37039f68e9 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_buffer.yaml @@ -0,0 +1,12 @@ +choose_qparams_buffer: + parameter_names_with_default_values: + IN_DTYPE: float + MODE: per_tensor + generate_variant_forall: + IN_DTYPE: + - VALUE: float + shader_variants: + - NAME: choose_qparams_tensor_buffer + MODE: per_tensor + - NAME: choose_qparams_per_token_asymmetric_buffer + MODE: per_token diff --git a/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_texture.glsl b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_texture.glsl new file mode 100644 index 00000000000..282f1de170a --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_texture.glsl @@ -0,0 +1,398 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#version 450 core + +#define PRECISION ${PRECISION} + +#define IN_T ${buffer_scalar_type(IN_DTYPE)} +#define FVEC4_T ${texel_load_type(IN_DTYPE, "texture3d")} + +#define ${MODE} + +${define_active_storage_type("texture3d")} +${define_required_extensions(IN_DTYPE)} + +#extension GL_EXT_control_flow_attributes : require + +layout(std430) buffer; + +${layout_declare_tensor(B, "w", "t_scale", "float", "texture3d")} +${layout_declare_tensor(B, "w", "t_zero_point", "int", "texture3d")} +${layout_declare_tensor(B, "r", "t_in", IN_DTYPE, "texture3d")} + +$if MODE == "per_tensor": + layout(push_constant) uniform restrict Block { + int quant_min; + int quant_max; + }; +$else: + layout(push_constant) uniform restrict Block { + int num_tokens; + int quant_min; + int quant_max; + }; + +${layout_declare_ubo(B, "ivec3", "t_in_limits")} +${layout_declare_ubo(B, "ivec3", "t_scale_limits")} +${layout_declare_ubo(B, "ivec3", "t_zero_point_limits")} + +#include "indexing_utils.h" +#include "choose_qparams.glslh" + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +#define NWORKERS 64 + +// Shared memory for reduction - must match local work group size +shared float shared_min[NWORKERS]; +shared float shared_max[NWORKERS]; + +/* + * QUANTIZATION PARAMETER COMPUTATION SHADER (TEXTURE STORAGE) + * + * This shader computes quantization parameters (scale and zero_point) for converting + * floating-point tensors to n-bit integer representations while preserving the + * original data range as much as possible. + * + * ALGORITHM: + * 1. Find global min/max values across tensor elements using parallel reduction + * 2. Use tree reduction with shared memory for efficient min/max computation + * 3. Calculate scale = (max - min) / (quant_max - quant_min) + * 4. 
Calculate zero_point to map floating-point zero to integer value + * + * WORKGROUP CONFIGURATION: + * - Per-Tensor Mode: + * - Global WG Size: Default (typically {num_elements, 1, 1}) + * - Local WG Size: Default (typically {64, 1, 1}) + * - Per-Token Mode: + * - Global WG Size: Default (typically based on tensor dimensions) + * - Local WG Size: Default (typically {64, 1, 1}, or based on global WG size) + * + * SUPPORTED CONFIGURATIONS: + * - Texture Storage: Uses 3D texture indexing with linear texel iteration + * - Assumes width-packed layout (packed_dim = 0) in current implementation + * - Handles texel padding for non-multiple-of-4 tensor dimensions + * - Note: Axis mapping support depends on indexing utilities + * + * TREE REDUCTION VISUALIZATION FOR MIN/MAX FINDING: + * For 8 threads processing elements [10, 1, 8, 1, 0, 2, 3, 5]: + * + * Initial shared_min/shared_max arrays populated by each thread: + * shared_min: | 10 | 1 | 8 | 1 | 0 | 2 | 3 | 5 | + * shared_max: | 10 | 1 | 8 | 1 | 0 | 2 | 3 | 5 | + * Thread: | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | + * + * Stride 1 (compare pairs, keep min/max): + * shared_min: | 1 | | 1 | | 0 | | 3 | | (min(10,1), min(8,1), min(0,2), min(3,5)) + * shared_max: | 10 | | 8 | | 2 | | 5 | | (max(10,1), max(8,1), max(0,2), max(3,5)) + * Active: | 0 | | 2 | | 4 | | 6 | | + * + * Stride 2 (compare pairs, keep min/max): + * shared_min: | 0 | | | | 0 | | | | (min(1,1), min(0,3)) + * shared_max: | 10 | | | | 5 | | | | (max(10,8), max(2,5)) + * Active: | 0 | | | | 4 | | | | + * + * Stride 4 (final comparison): + * shared_min: | 0 | | | | | | | | (min(0,0) = 0) + * shared_max: | 10 | | | | | | | | (max(10,5) = 10) + * Active: | 0 | | | | | | | | + * + * Final result: global_min = 0, global_max = 10 (stored in shared_min[0], shared_max[0]) + * + * PER-TENSOR QUANTIZATION: + * - Single workgroup processes entire tensor + * - Each thread processes multiple texels with stride + * - Thread 0: texels [0, 64, 128, ...] -> elements [0-3, 256-259, 512-515, ...] + * - Thread 1: texels [1, 65, 129, ...] -> elements [4-7, 260-263, 516-519, ...] 
+ * - Tree reduction combines all thread results into global min/max + * - Output: Single scale and zero_point values + * + * PER-TOKEN QUANTIZATION: + * - Multiple workgroups, each processing subset of tokens + * - Token = all elements except last dimension (e.g., for [B,S,H]: B*S tokens of H elements) + * - Each workgroup processes multiple tokens if num_tokens > num_workgroups + * - Within each token, threads process texels containing token elements + * - Output: Array of scale and zero_point values (one per token) + */ + +#ifdef per_tensor + +void choose_qparams_per_tensor() { + uint global_id = gl_GlobalInvocationID.x; + uint local_id = gl_LocalInvocationID.x; + uint group_id = gl_WorkGroupID.x; + uint total_threads = gl_NumWorkGroups.x * gl_WorkGroupSize.x; + + uint total_texels = uint(t_in_limits.x * t_in_limits.y * t_in_limits.z); + + // Each thread processes multiple texels with stride + float thread_min = 1.0/0.0; // +infinity + float thread_max = -1.0/0.0; // -infinity + bool found_valid = false; + + // Process texels with stride across all threads + for (uint texel_idx = global_id; texel_idx < total_texels; texel_idx += total_threads) { + // Convert linear texel index to 3D coordinates + uint z = texel_idx / uint(t_in_limits.x * t_in_limits.y); + uint remainder = texel_idx % uint(t_in_limits.x * t_in_limits.y); + uint y = remainder / uint(t_in_limits.x); + uint x = remainder % uint(t_in_limits.x); + ivec3 texel_pos = ivec3(int(x), int(y), int(z)); + + FVEC4_T texel_data = load_texel(t_in, texel_pos); + + // For texture storage, we assume width-packed (packed_dim = 0) + // Calculate number of valid elements in this texel (handle padding) + int packed_dim = 0; // Width dimension is packed + ivec4 sizes = ivec4(t_in_limits, 1); // Convert limits to sizes format + ivec4 tensor_coord = to_tensor_idx(texel_pos, sizes, packed_dim); + + // Calculate total tensor elements to determine padding + int total_elements = t_in_limits.x * t_in_limits.y * t_in_limits.z * 4; + int linear_tensor_idx = tensor_coord.x + tensor_coord.y * sizes.x + + tensor_coord.z * sizes.x * sizes.y; + int remaining_elements = total_elements - (linear_tensor_idx); + int valid_elements = min(4, remaining_elements); + + // Find min/max within this texel, considering only valid elements + if (valid_elements >= 1 && !isnan(texel_data.x) && !isinf(texel_data.x)) { + if (!found_valid) { + thread_min = texel_data.x; + thread_max = texel_data.x; + found_valid = true; + } else { + thread_min = min(thread_min, texel_data.x); + thread_max = max(thread_max, texel_data.x); + } + } + + if (valid_elements >= 2 && !isnan(texel_data.y) && !isinf(texel_data.y)) { + if (!found_valid) { + thread_min = texel_data.y; + thread_max = texel_data.y; + found_valid = true; + } else { + thread_min = min(thread_min, texel_data.y); + thread_max = max(thread_max, texel_data.y); + } + } + + if (valid_elements >= 3 && !isnan(texel_data.z) && !isinf(texel_data.z)) { + if (!found_valid) { + thread_min = texel_data.z; + thread_max = texel_data.z; + found_valid = true; + } else { + thread_min = min(thread_min, texel_data.z); + thread_max = max(thread_max, texel_data.z); + } + } + + if (valid_elements >= 4 && !isnan(texel_data.w) && !isinf(texel_data.w)) { + if (!found_valid) { + thread_min = texel_data.w; + thread_max = texel_data.w; + found_valid = true; + } else { + thread_min = min(thread_min, texel_data.w); + thread_max = max(thread_max, texel_data.w); + } + } + } + + // Intra-workgroup reduction using shared memory + shared_min[local_id] = 
thread_min; + shared_max[local_id] = thread_max; + barrier(); + + // Tree reduction within work group + for (uint stride = gl_WorkGroupSize.x / 2; stride > 0; stride >>= 1) { + if (local_id < stride) { + float other_min = shared_min[local_id + stride]; + float other_max = shared_max[local_id + stride]; + + if (!isinf(other_min) && (isinf(shared_min[local_id]) || other_min < shared_min[local_id])) { + shared_min[local_id] = other_min; + } + if (!isinf(other_max) && (isinf(shared_max[local_id]) || other_max > shared_max[local_id])) { + shared_max[local_id] = other_max; + } + } + barrier(); + } + + // Final result calculation (single workgroup only for reliability) + if (local_id == 0 && group_id == 0) { + float global_min = shared_min[0]; + float global_max = shared_max[0]; + + float scale_val; + int zero_point_val; + calculate_scale_and_zero_point(global_min, global_max, quant_min, quant_max, scale_val, zero_point_val); + + write_texel(t_scale, ivec3(0, 0, 0), vec4(scale_val, 0.0, 0.0, 0.0)); + write_texel(t_zero_point, ivec3(0, 0, 0), ivec4(zero_point_val, 0, 0, 0)); + } +} + +#else + +void choose_qparams_per_token() { + // Each token is processed by multiple workgroups for parallel reduction + uint local_id = gl_LocalInvocationID.x; + uint group_id = gl_WorkGroupID.x; + uint total_workgroups = gl_NumWorkGroups.x; + + uint total_texels = uint(t_in_limits.x * t_in_limits.y * t_in_limits.z); + + // Calculate texels per token (assuming last dimension contains the token data) + // For per-token quantization, we assume tokens are along the last dimension + uint texels_per_token = total_texels / uint(num_tokens); + + // Calculate how many tokens each workgroup should process + uint tokens_per_workgroup = (uint(num_tokens) + total_workgroups - 1) / total_workgroups; + + // Calculate which tokens this workgroup is responsible for + uint start_token = group_id * tokens_per_workgroup; + uint end_token = min(start_token + tokens_per_workgroup, uint(num_tokens)); + + // Process each token assigned to this workgroup + for (uint token_id = start_token; token_id < end_token; token_id++) { + // Calculate the texel range for this token + uint token_start_texel = token_id * texels_per_token; + uint token_end_texel = token_start_texel + texels_per_token; + + // Each thread processes multiple texels within the token + float thread_min = 1.0/0.0; // +infinity + float thread_max = -1.0/0.0; // -infinity + bool found_valid = false; + + // Process texels within this token only + for (uint texel_idx = token_start_texel + local_id; texel_idx < token_end_texel; texel_idx += gl_WorkGroupSize.x) { + // Convert linear texel index to 3D coordinates + uint z = texel_idx / uint(t_in_limits.x * t_in_limits.y); + uint remainder = texel_idx % uint(t_in_limits.x * t_in_limits.y); + uint y = remainder / uint(t_in_limits.x); + uint x = remainder % uint(t_in_limits.x); + ivec3 texel_pos = ivec3(int(x), int(y), int(z)); + + FVEC4_T texel_data = load_texel(t_in, texel_pos); + + // For texture storage, we assume width-packed (packed_dim = 0) + // Calculate number of valid elements in this texel (handle padding) + int packed_dim = 0; // Width dimension is packed + ivec4 sizes = ivec4(t_in_limits, 1); // Convert limits to sizes format + ivec4 tensor_coord = to_tensor_idx(texel_pos, sizes, packed_dim); + + // Calculate total tensor elements to determine padding + int total_elements = t_in_limits.x * t_in_limits.y * t_in_limits.z * 4; + int linear_tensor_idx = tensor_coord.x + tensor_coord.y * sizes.x + + tensor_coord.z * sizes.x * 
sizes.y; + int remaining_elements = total_elements - (linear_tensor_idx); + int valid_elements = min(4, remaining_elements); + + // Find min/max within this texel, considering only valid elements + if (valid_elements >= 1 && !isnan(texel_data.x) && !isinf(texel_data.x)) { + if (!found_valid) { + thread_min = texel_data.x; + thread_max = texel_data.x; + found_valid = true; + } else { + thread_min = min(thread_min, texel_data.x); + thread_max = max(thread_max, texel_data.x); + } + } + + if (valid_elements >= 2 && !isnan(texel_data.y) && !isinf(texel_data.y)) { + if (!found_valid) { + thread_min = texel_data.y; + thread_max = texel_data.y; + found_valid = true; + } else { + thread_min = min(thread_min, texel_data.y); + thread_max = max(thread_max, texel_data.y); + } + } + + if (valid_elements >= 3 && !isnan(texel_data.z) && !isinf(texel_data.z)) { + if (!found_valid) { + thread_min = texel_data.z; + thread_max = texel_data.z; + found_valid = true; + } else { + thread_min = min(thread_min, texel_data.z); + thread_max = max(thread_max, texel_data.z); + } + } + + if (valid_elements >= 4 && !isnan(texel_data.w) && !isinf(texel_data.w)) { + if (!found_valid) { + thread_min = texel_data.w; + thread_max = texel_data.w; + found_valid = true; + } else { + thread_min = min(thread_min, texel_data.w); + thread_max = max(thread_max, texel_data.w); + } + } + } + + // Intra-workgroup reduction using shared memory + shared_min[local_id] = thread_min; + shared_max[local_id] = thread_max; + barrier(); + + // Tree reduction within work group + for (uint stride = gl_WorkGroupSize.x / 2; stride > 0; stride >>= 1) { + if (local_id < stride) { + float other_min = shared_min[local_id + stride]; + float other_max = shared_max[local_id + stride]; + + // Handle infinity values properly + if (!isinf(other_min) && (isinf(shared_min[local_id]) || other_min < shared_min[local_id])) { + shared_min[local_id] = other_min; + } + if (!isinf(other_max) && (isinf(shared_max[local_id]) || other_max > shared_max[local_id])) { + shared_max[local_id] = other_max; + } + } + barrier(); + } + + // Final calculation for this token + if (local_id == 0) { + float token_min = shared_min[0]; + float token_max = shared_max[0]; + + float scale_val; + int zero_point_val; + calculate_scale_and_zero_point(token_min, token_max, quant_min, quant_max, scale_val, zero_point_val); + + // Convert token_id to 3D coordinates for output texture + // Assuming output tensors have the same layout as input but with different dimensions + uint out_z = token_id / uint(t_scale_limits.x * t_scale_limits.y); + uint out_remainder = token_id % uint(t_scale_limits.x * t_scale_limits.y); + uint out_y = out_remainder / uint(t_scale_limits.x); + uint out_x = out_remainder % uint(t_scale_limits.x); + ivec3 out_pos = ivec3(int(out_x), int(out_y), int(out_z)); + + write_texel(t_scale, out_pos, vec4(scale_val, 0.0, 0.0, 0.0)); + write_texel(t_zero_point, out_pos, ivec4(zero_point_val, 0, 0, 0)); + } + + // Synchronize before processing next token + barrier(); + } +} + +#endif + +void main() { + choose_qparams_${MODE}(); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_texture.yaml b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_texture.yaml new file mode 100644 index 00000000000..f3961b87a0f --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_texture.yaml @@ -0,0 +1,12 @@ +choose_qparams_texture: + parameter_names_with_default_values: + IN_DTYPE: float + MODE: per_tensor + generate_variant_forall: + IN_DTYPE: + - VALUE: 
float + shader_variants: + - NAME: choose_qparams_tensor_texture3d + MODE: per_tensor + - NAME: choose_qparams_per_token_asymmetric_texture3d + MODE: per_token diff --git a/backends/vulkan/runtime/graph/ops/glsl/dequantize.glslh b/backends/vulkan/runtime/graph/ops/glsl/dequantize.glslh new file mode 100644 index 00000000000..7194bebda35 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/dequantize.glslh @@ -0,0 +1,16 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#ifndef DEQUANTIZE_GLSLH +#define DEQUANTIZE_GLSLH + +OUT_T dequantize_val(IN_T qvalue, float scale_val, int zero_point_val) { + return OUT_T(float(int(qvalue) - zero_point_val) * scale_val); +} + +#endif // DEQUANTIZE_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.glsl b/backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.glsl new file mode 100644 index 00000000000..2a1f62719a0 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.glsl @@ -0,0 +1,183 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} + +#define IN_T ${buffer_scalar_type(IN_DTYPE)} +#define OUT_T ${buffer_scalar_type(OUT_DTYPE)} + +#define ${MODE} + +${define_active_storage_type("buffer")} +${define_required_extensions(IN_DTYPE)} +${define_required_extensions(OUT_DTYPE)} + +layout(std430) buffer; + +#include "indexing_utils.h" + +${layout_declare_tensor(B, "w", "t_out", OUT_DTYPE, "buffer")} +${layout_declare_tensor(B, "r", "t_in", IN_DTYPE, "buffer")} + +$if MODE == "per_tensor": + layout(push_constant) uniform restrict Block { + float scale; + int zero_point; + int quant_min; + int quant_max; + }; +$if MODE == "per_token": + ${layout_declare_tensor(B, "r", "t_scale", "float", "buffer")} + ${layout_declare_tensor(B, "r", "t_zero_point", "int", "buffer")} + + layout(push_constant) uniform restrict Block { + int num_tokens; + int quant_min; + int quant_max; + }; + +${layout_declare_ubo(B, "int", "out_numel")} +${layout_declare_ubo(B, "ivec4", "t_in_sizes")} +${layout_declare_ubo(B, "ivec4", "t_in_strides")} +${layout_declare_ubo(B, "ivec4", "t_out_sizes")} +${layout_declare_ubo(B, "ivec4", "t_out_strides")} + +${layout_declare_spec_const(C, "int", "out_layout", "DEFAULT_LAYOUT")} +${layout_declare_spec_const(C, "int", "in_layout", "DEFAULT_LAYOUT")} + +#include "dequantize.glslh" + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +const lowp ivec4 out_dim_order = unhash_dim_order(out_layout); +const lowp ivec4 in_dim_order = unhash_dim_order(in_layout); + +/* + * DEQUANTIZATION SHADER (BUFFER STORAGE) + * + * This shader converts n-bit integer tensor values back to floating-point representations + * using pre-computed quantization parameters (scale and zero_point). The dequantization + * reconstructs the original floating-point values from their discrete integer representations + * with minimal precision loss. + * + * ALGORITHM: + * 1. Load quantized integer value from buffer + * 2. Apply dequantization formula: value = (qvalue - zero_point) * scale + * 3. 
Store reconstructed floating-point value to output buffer + * + * WORKGROUP CONFIGURATION: + * - Per-Tensor Mode: + * - Global WG Size: {num_elements, 1, 1} (one thread per tensor element) + * - Local WG Size: Default (typically {64, 1, 1} or based on global WG size) + * - Per-Token Mode: + * - Global WG Size: {num_elements, 1, 1} (one thread per tensor element) + * - Local WG Size: Default (typically {64, 1, 1} or based on global WG size) + * + * SUPPORTED CONFIGURATIONS: + * - Buffer Storage: Uses linear buffer indexing with stride-based tensor access + * - Per-Tensor: Supports any tensor layout through stride calculations and dimension ordering + * - Per-Token: Supports only width packed tensors (packed_dim = 0) and standard axis mapping + * - Scale/zero_point tensors: Must use buffer storage with width packing (packed_dim = 0) + * + * DEQUANTIZATION FORMULA VISUALIZATION: + * For integer range [quant_min, quant_max] mapped back to [min_val, max_val]: + * + * Integer Domain: Floating Point Domain: + * quant_min ──────────────► min_val + * │ │ + * │ scale = (max_val - min_val) / (quant_max - quant_min) + * │ zero_point = quant_min - round(min_val / scale) + * │ │ + * quant_max ──────────────► max_val + * + * Dequantization Process: + * Input: -103 (int8) + * Step 1: qvalue - zero_point = -103 - (-128) = 25 + * Step 2: result * scale = 25 * 0.1 = 2.5 + * Output: 2.5 (float) + * + * PER-TENSOR DEQUANTIZATION: + * - Single scale and zero_point values for entire tensor + * - All elements use same dequantization parameters + * - Parameters passed as push constants for efficiency + * - Formula: value = (qvalue - zero_point) * scale + * + * PER-TOKEN DEQUANTIZATION: + * - Separate scale and zero_point for each token + * - Token = all elements except last dimension (e.g., for [B,S,H]: B*S tokens of H elements) + * - Parameters stored in buffer arrays indexed by token_id + * - Each thread calculates its token_id from tensor coordinates + * - Formula: value = (qvalue - zero_point[token_id]) * scale[token_id] + * + * Token ID calculation for element at tensor index (w, z, y, x): + * - 4D tensor: token_id = w * (sizes.z * sizes.y) + z * sizes.y + y + * - 3D tensor: token_id = z * sizes.y + y + * - 2D tensor: token_id = y + * - 1D tensor: token_id = 0 + */ + +#ifdef per_tensor + +void dequantize_per_tensor() { + const int out_bufi = int(gl_GlobalInvocationID.x); + + if (out_bufi >= out_numel) { + return; + } + + const ivec4 out_tidx = bufi_to_tidx(out_bufi, t_out_strides, out_dim_order); + const int in_bufi = tidx_to_bufi(out_tidx, t_in_strides); + + IN_T qvalue = t_in[in_bufi]; + OUT_T value = dequantize_val(qvalue, scale, zero_point); + + t_out[out_bufi] = value; +} + +#else + +void dequantize_per_token() { + const int out_bufi = int(gl_GlobalInvocationID.x); + + if (out_bufi >= out_numel) { + return; + } + + const ivec4 out_tidx = bufi_to_tidx(out_bufi, t_out_strides, out_dim_order); + const int in_bufi = tidx_to_bufi(out_tidx, t_in_strides); + + IN_T qvalue = t_in[in_bufi]; + + int token_idx = 0; + + if (t_out_sizes.w > 1) { + // 4D tensor + token_idx = out_tidx.w * (t_out_sizes.z * t_out_sizes.y) + out_tidx.z * t_out_sizes.y + out_tidx.y; + } else if (t_out_sizes.z > 1) { + // 3D tensor + token_idx = out_tidx.z * t_out_sizes.y + out_tidx.y; + } else if (t_out_sizes.y > 1) { + // 2D tensor + token_idx = out_tidx.y; + } + // For 1D tensor, token_idx remains 0 + + token_idx = min(token_idx, num_tokens - 1); + + OUT_T value = dequantize_val(qvalue, t_scale[token_idx], t_zero_point[token_idx]); + + 
t_out[out_bufi] = value; +} + +#endif + +void main() { + dequantize_${MODE}(); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.yaml b/backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.yaml new file mode 100644 index 00000000000..4e434935356 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.yaml @@ -0,0 +1,18 @@ +dequantize_buffer: + parameter_names_with_default_values: + IN_DTYPE: int32 + OUT_DTYPE: float + MODE: per_tensor + generate_variant_forall: + IN_DTYPE: + - VALUE: uint8 + - VALUE: int8 + - VALUE: int32 + OUT_DTYPE: + - VALUE: half + - VALUE: float + shader_variants: + - NAME: dequantize_per_tensor_buffer + MODE: per_tensor + - NAME: dequantize_per_token_buffer + MODE: per_token diff --git a/backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.glsl b/backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.glsl new file mode 100644 index 00000000000..cfc61dd1816 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.glsl @@ -0,0 +1,190 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} + +#define IN_T ${buffer_scalar_type(IN_DTYPE)} +#define IVEC4_T ${texel_load_type(IN_DTYPE, "texture3d")} + +#define OUT_T ${buffer_scalar_type(OUT_DTYPE)} +#define FVEC4_T ${texel_load_type(OUT_DTYPE, "texture3d")} + +#define ${MODE} + +${define_active_storage_type("texture3d")} +${define_required_extensions(IN_DTYPE)} +${define_required_extensions(OUT_DTYPE)} + +#extension GL_EXT_control_flow_attributes : require + +layout(std430) buffer; + +${layout_declare_tensor(B, "w", "t_out", OUT_DTYPE, "texture3d")} +${layout_declare_tensor(B, "r", "t_in", IN_DTYPE, "texture3d")} + +$if MODE == "per_tensor": + layout(push_constant) uniform restrict Block { + float scale; + int zero_point; + int quant_min; + int quant_max; + }; +$if MODE == "per_token": + ${layout_declare_tensor(B, "r", "t_scale", "float", "buffer")} + ${layout_declare_tensor(B, "r", "t_zero_point", "int", "buffer")} + + layout(push_constant) uniform restrict Block { + int num_tokens; + int quant_min; + int quant_max; + }; + +${layout_declare_ubo(B, "ivec3", "t_in_limits")} +${layout_declare_ubo(B, "ivec3", "t_out_limits")} + +#include "indexing_utils.h" +#include "dequantize.glslh" + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +/* + * DEQUANTIZATION SHADER (TEXTURE STORAGE) + * + * This shader converts n-bit integer tensor values back to floating-point representations + * using pre-computed quantization parameters (scale and zero_point). The dequantization + * reconstructs the original floating-point values from their discrete integer representations + * with minimal precision loss. + * + * ALGORITHM: + * 1. Load quantized integer texel (4 values) from 3D texture + * 2. Apply dequantization formula to each component: value = (qvalue - zero_point) * scale + * 3. 
Store reconstructed floating-point texel to output texture + * + * WORKGROUP CONFIGURATION: + * - Per-Tensor Mode: + * - Global WG Size: {W, H, C/4} for input size (W, H, C) with width-packing + * - Local WG Size: Default (typically {8, 8, 1} or based on global WG size) + * - Per-Token Mode: + * - Global WG Size: {W, H, C/4} for input size (W, H, C) with width-packing + * - Local WG Size: Default (typically {8, 8, 1} or based on global WG size) + * + * SUPPORTED CONFIGURATIONS: + * - Texture Storage: Uses 3D texture indexing with texel-based processing + * - Assumes width-packed layout (packed_dim = 0) for input/output textures + * - Handles texel padding for non-multiple-of-4 tensor dimensions + * - For per-token mode: scale/zero_point tensors must use buffer storage + * - Input/output textures: Must use standard axis mapping for per-token mode + * + * DEQUANTIZATION FORMULA VISUALIZATION: + * For integer range [quant_min, quant_max] mapped back to [min_val, max_val]: + * + * Integer Domain: Floating Point Domain: + * quant_min ──────────────► min_val + * │ │ + * │ scale = (max_val - min_val) / (quant_max - quant_min) + * │ zero_point = quant_min - round(min_val / scale) + * │ │ + * quant_max ──────────────► max_val + * + * Texel Dequantization Process: + * Input Texel: [-103, -128, -123, -96] (int4) + * Per-component dequantization with scale=0.1, zero_point=-128: + * Component 0: (-103 - (-128)) * 0.1 = 25 * 0.1 = 2.5 + * Component 1: (-128 - (-128)) * 0.1 = 0 * 0.1 = 0.0 + * Component 2: (-123 - (-128)) * 0.1 = 5 * 0.1 = 0.5 + * Component 3: (-96 - (-128)) * 0.1 = 32 * 0.1 = 3.2 + * Output Texel: [2.5, 0.0, 0.5, 3.2] (float4) + * + * PER-TENSOR DEQUANTIZATION: + * - Single scale and zero_point values for entire tensor + * - All texel components use same dequantization parameters + * - Parameters passed as push constants for efficiency + * - Each thread processes one texel (4 elements) independently + * - Formula: value[i] = (qvalue[i] - zero_point) * scale + * + * PER-TOKEN DEQUANTIZATION: + * - Separate scale and zero_point for each token + * - Token = all elements except last dimension (e.g., for [B,S,H]: B*S tokens of H elements) + * - Parameters stored in buffer arrays indexed by token_id + * - Each thread calculates token_id from its 3D texture position + * - Scale/zero_point buffers accessed directly (not as textures) + * - Formula: value[i] = (qvalue[i] - zero_point[token_id]) * scale[token_id] + * + * Token ID calculation for texel at position (x, y, z): + * - 3D tensor: token_id = z * texture_height + y + * - 2D tensor: token_id = y + * - 1D tensor: token_id = 0 + */ + +#ifdef per_tensor + +void dequantize_per_tensor() { + const ivec3 pos = ivec3(gl_GlobalInvocationID); + + // Skip if out of bounds + if (any(greaterThanEqual(pos, t_in_limits))) { + return; + } + + IVEC4_T intex = load_texel(t_in, pos); + FVEC4_T outtex; + + [[unroll]] for (int i = 0; i < 4; ++i) { + IN_T qvalue = IN_T(intex[i]); + OUT_T value = dequantize_val(qvalue, scale, zero_point); + outtex[i] = value; + } + write_texel(t_out, pos, outtex); +} + +#else + +void dequantize_per_token() { + const ivec3 pos = ivec3(gl_GlobalInvocationID); + + if (any(greaterThanEqual(pos, t_in_limits))) { + return; + } + + IVEC4_T intex = load_texel(t_in, pos); + + int token_idx = 0; + ivec3 dims = t_in_limits; + + if (dims.z > 1) { + // 3D tensor + token_idx = pos.z * dims.y + pos.y; + } else if (dims.y > 1) { + // 2D tensor + token_idx = pos.y; + } + // For 1D tensor, token_idx remains 0 + + token_idx = min(token_idx, 
num_tokens - 1); + + // Scale and zero_point are prepacked as buffers, so direct access + float scale_val = t_scale[token_idx]; + int zero_point_val = t_zero_point[token_idx]; + + FVEC4_T outtex; + [[unroll]] for (int i = 0; i < 4; ++i) { + IN_T qvalue = IN_T(intex[i]); + OUT_T value = dequantize_val(qvalue, scale_val, zero_point_val); + outtex[i] = value; + } + + write_texel(t_out, pos, outtex); +} + +#endif + +void main() { + dequantize_${MODE}(); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.yaml b/backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.yaml new file mode 100644 index 00000000000..fc8c18468ed --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.yaml @@ -0,0 +1,18 @@ +dequantize_texture: + parameter_names_with_default_values: + IN_DTYPE: int32 + OUT_DTYPE: float + MODE: per_tensor + generate_variant_forall: + IN_DTYPE: + - VALUE: uint8 + - VALUE: int8 + - VALUE: int32 + OUT_DTYPE: + - VALUE: half + - VALUE: float + shader_variants: + - NAME: dequantize_per_tensor_texture3d + MODE: per_tensor + - NAME: dequantize_per_token_texture3d + MODE: per_token diff --git a/backends/vulkan/runtime/graph/ops/glsl/quantize.glslh b/backends/vulkan/runtime/graph/ops/glsl/quantize.glslh new file mode 100644 index 00000000000..cde72e41ac7 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/quantize.glslh @@ -0,0 +1,25 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#ifndef QUANTIZE_GLSLH +#define QUANTIZE_GLSLH + +OUT_T quantize_val(IN_T value, float scale_val, int zero_point_val) { + float inv_scale = 1.0 / scale_val; + + float rounded_float = round(inv_scale * float(value)); + + int qvalue = zero_point_val + int(rounded_float); + + qvalue = max(qvalue, quant_min); + qvalue = min(qvalue, quant_max); + + return OUT_T(qvalue); +} + +#endif // QUANTIZE_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/quantize_buffer.glsl b/backends/vulkan/runtime/graph/ops/glsl/quantize_buffer.glsl new file mode 100644 index 00000000000..ea0c2f7dce7 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/quantize_buffer.glsl @@ -0,0 +1,179 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#version 450 core + +#define PRECISION ${PRECISION} + +#define IN_T ${buffer_scalar_type(IN_DTYPE)} +#define OUT_T ${buffer_scalar_type(OUT_DTYPE)} + +#define ${MODE} + +${define_active_storage_type("buffer")} +${define_required_extensions(IN_DTYPE)} +${define_required_extensions(OUT_DTYPE)} + +layout(std430) buffer; + +#include "indexing_utils.h" + +${layout_declare_tensor(B, "w", "t_out", OUT_DTYPE, "buffer")} +${layout_declare_tensor(B, "r", "t_in", IN_DTYPE, "buffer")} + +$if MODE == "per_tensor": + layout(push_constant) uniform restrict Block { + float scale; + int zero_point; + int quant_min; + int quant_max; + }; +$if MODE == "per_token": + ${layout_declare_tensor(B, "r", "t_scale", "float", "buffer")} + ${layout_declare_tensor(B, "r", "t_zero_point", "int", "buffer")} + + layout(push_constant) uniform restrict Block { + int num_tokens; + int quant_min; + int quant_max; + }; + +${layout_declare_ubo(B, "int", "out_numel")} +${layout_declare_ubo(B, "ivec4", "t_in_sizes")} +${layout_declare_ubo(B, "ivec4", "t_in_strides")} +${layout_declare_ubo(B, "ivec4", "t_out_sizes")} +${layout_declare_ubo(B, "ivec4", "t_out_strides")} + +${layout_declare_spec_const(C, "int", "out_layout", "DEFAULT_LAYOUT")} +${layout_declare_spec_const(C, "int", "in_layout", "DEFAULT_LAYOUT")} + +#include "quantize.glslh" + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +const lowp ivec4 out_dim_order = unhash_dim_order(out_layout); +const lowp ivec4 in_dim_order = unhash_dim_order(in_layout); + +/* + * QUANTIZATION SHADER (BUFFER STORAGE) + * + * This shader converts floating-point tensor values to n-bit integer representations + * using pre-computed quantization parameters (scale and zero_point). The quantization + * maps floating-point values to a discrete integer range while preserving the + * original data distribution as much as possible. + * + * ALGORITHM: + * 1. Load floating-point input value from buffer + * 2. Apply quantization formula: qvalue = round(value / scale) + zero_point + * 3. Clamp result to [quant_min, quant_max] range + * 4. 
Store quantized integer value to output buffer + * + * WORKGROUP CONFIGURATION: + * - Per-Tensor Mode: + * - Global WG Size: {num_elements, 1, 1} (one thread per tensor element) + * - Local WG Size: Default (typically {64, 1, 1} or based on global WG size) + * - Per-Token Mode: + * - Global WG Size: {num_elements, 1, 1} (one thread per tensor element) + * - Local WG Size: Default (typically {64, 1, 1} or based on global WG size) + * + * SUPPORTED CONFIGURATIONS: + * - Per-Tensor Config: Uses linear buffer indexing with stride-based tensor access + * - and supports any tensor layout through stride calculations and dimension ordering + * - Per-Token Config: Assumes width-packed layout (packed_dim = 0) + * - since that is how token index is calculated + * + * QUANTIZATION FORMULA VISUALIZATION: + * For input range [min_val, max_val] mapped to integer range [quant_min, quant_max]: + * + * Floating Point Domain: Integer Domain: + * min_val ────────────────► quant_min + * │ │ + * │ scale = (max_val - min_val) / (quant_max - quant_min) + * │ zero_point = quant_min - round(min_val / scale) + * │ │ + * max_val ────────────────► quant_max + * + * Quantization Process: + * Input: 2.5 (float) + * Step 1: value / scale = 2.5 / 0.1 = 25.0 + * Step 2: round(25.0) + zero_point = 25 + (-128) = -103 + * Step 3: clamp(-103, -128, 127) = -103 + * Output: -103 (int8) + * + * PER-TENSOR QUANTIZATION: + * - Single scale and zero_point values for entire tensor + * - All elements use same quantization parameters + * - Parameters passed as push constants for efficiency + * - Formula: qvalue = clamp(round(value / scale) + zero_point, quant_min, quant_max) + * + * PER-TOKEN QUANTIZATION: + * - Separate scale and zero_point for each token + * - Token = all elements except last dimension (e.g., for [B,S,H]: B*S tokens of H elements) + * - Parameters stored in buffer arrays indexed by token_id + * - Each thread calculates its token_id from tensor coordinates + * - Formula: qvalue = clamp(round(value / scale[token_id]) + zero_point[token_id], quant_min, quant_max) + */ + +#ifdef per_tensor + +void quantize_per_tensor() { + const int out_bufi = int(gl_GlobalInvocationID.x); + + if (out_bufi >= out_numel) { + return; + } + + const ivec4 out_tidx = bufi_to_tidx(out_bufi, t_out_strides, out_dim_order); + const int in_bufi = tidx_to_bufi(out_tidx, t_in_strides); + + IN_T value = t_in[in_bufi]; + OUT_T qvalue = quantize_val(value, scale, zero_point); + + t_out[out_bufi] = qvalue; +} + +#else + +void quantize_per_token() { + const int out_bufi = int(gl_GlobalInvocationID.x); + + if (out_bufi >= out_numel) { + return; + } + + const ivec4 out_tidx = bufi_to_tidx(out_bufi, t_out_strides, out_dim_order); + const int in_bufi = tidx_to_bufi(out_tidx, t_in_strides); + + IN_T value = t_in[in_bufi]; + + int token_idx = 0; + + if (t_out_sizes.w > 1) { + // 4D tensor + token_idx = out_tidx.w * (t_out_sizes.z * t_out_sizes.y) + out_tidx.z * t_out_sizes.y + out_tidx.y; + } else if (t_out_sizes.z > 1) { + // 3D tensor + token_idx = out_tidx.z * t_out_sizes.y + out_tidx.y; + } else if (t_out_sizes.y > 1) { + // 2D tensor + token_idx = out_tidx.y; + } + // For 1D tensor, token_idx remains 0 + + token_idx = min(token_idx, num_tokens - 1); + + OUT_T qvalue = quantize_val(value, t_scale[token_idx], t_zero_point[token_idx]); + + t_out[out_bufi] = qvalue; +} + +#endif + +void main() { + quantize_${MODE}(); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/quantize_buffer.yaml b/backends/vulkan/runtime/graph/ops/glsl/quantize_buffer.yaml new 
file mode 100644 index 00000000000..90af2590936 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/quantize_buffer.yaml @@ -0,0 +1,18 @@ +quantize_buffer: + parameter_names_with_default_values: + IN_DTYPE: float + OUT_DTYPE: int32 + MODE: per_tensor + generate_variant_forall: + IN_DTYPE: + - VALUE: half + - VALUE: float + OUT_DTYPE: + - VALUE: uint8 + - VALUE: int8 + - VALUE: int32 + shader_variants: + - NAME: quantize_per_tensor_buffer + MODE: per_tensor + - NAME: quantize_per_token_buffer + MODE: per_token diff --git a/backends/vulkan/runtime/graph/ops/glsl/quantize_texture.glsl b/backends/vulkan/runtime/graph/ops/glsl/quantize_texture.glsl new file mode 100644 index 00000000000..9ba7074f75b --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/quantize_texture.glsl @@ -0,0 +1,184 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} + +#define IN_T ${buffer_scalar_type(IN_DTYPE)} +#define FVEC4_T ${texel_load_type(IN_DTYPE, "texture3d")} + +#define OUT_T ${buffer_scalar_type(OUT_DTYPE)} +#define IVEC4_T ${texel_load_type(OUT_DTYPE, "texture3d")} + +#define ${MODE} + +${define_active_storage_type("texture3d")} +${define_required_extensions(IN_DTYPE)} +${define_required_extensions(OUT_DTYPE)} + +#extension GL_EXT_control_flow_attributes : require + +layout(std430) buffer; + +${layout_declare_tensor(B, "w", "t_out", OUT_DTYPE, "texture3d")} +${layout_declare_tensor(B, "r", "t_in", IN_DTYPE, "texture3d")} + +$if MODE == "per_tensor": + layout(push_constant) uniform restrict Block { + float scale; + int zero_point; + int quant_min; + int quant_max; + }; +$if MODE == "per_token": + ${layout_declare_tensor(B, "r", "t_scale", "float", "buffer")} + ${layout_declare_tensor(B, "r", "t_zero_point", "int", "buffer")} + + layout(push_constant) uniform restrict Block { + int num_tokens; + int quant_min; + int quant_max; + }; + +${layout_declare_ubo(B, "ivec3", "t_in_limits")} +${layout_declare_ubo(B, "ivec3", "t_out_limits")} + +#include "indexing_utils.h" +#include "quantize.glslh" + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +/* + * QUANTIZATION SHADER (TEXTURE STORAGE) + * + * This shader converts floating-point tensor values to n-bit integer representations + * using pre-computed quantization parameters (scale and zero_point). The quantization + * maps floating-point values to a discrete integer range while preserving the + * original data distribution as much as possible. + * + * ALGORITHM: + * 1. Load floating-point texel (4 values) from 3D texture + * 2. Apply quantization formula to each component: qvalue = round(value / scale) + zero_point + * 3. Clamp each result to [quant_min, quant_max] range + * 4. 
Store quantized integer texel to output texture + * + * WORKGROUP CONFIGURATION: + * - Per-Tensor Mode: + * - Global WG Size: {W, H, C/4} for input size (W, H, C) with width-packing + * - Local WG Size: Default (typically {8, 8, 1} or based on global WG size) + * - Per-Token Mode: + * - Global WG Size: {W, H, C/4} for input size (W, H, C) with width-packing + * - Local WG Size: Default (typically {8, 8, 1} or based on global WG size) + * + * SUPPORTED CONFIGURATIONS: + * - Texture Storage: Uses 3D texture indexing with texel-based processing + * - Assumes width-packed layout (packed_dim = 0) in current implementation + * - Handles texel padding for non-multiple-of-4 tensor dimensions + * - For per-token mode: scale/zero_point tensors must use buffer storage + * + * QUANTIZATION FORMULA VISUALIZATION: + * For input range [min_val, max_val] mapped to integer range [quant_min, quant_max]: + * + * Floating Point Domain: Integer Domain: + * min_val ────────────────► quant_min + * │ │ + * │ scale = (max_val - min_val) / (quant_max - quant_min) + * │ zero_point = quant_min - round(min_val / scale) + * │ │ + * max_val ────────────────► quant_max + * + * Texel Quantization Process: + * Input Texel: [2.5, -1.0, 0.5, 3.2] (float4) + * Per-component quantization with scale=0.1, zero_point=-128: + * Component 0: round(2.5 / 0.1) + (-128) = 25 + (-128) = -103 + * Component 1: round(-1.0 / 0.1) + (-128) = -10 + (-128) = -138 → clamp to -128 + * Component 2: round(0.5 / 0.1) + (-128) = 5 + (-128) = -123 + * Component 3: round(3.2 / 0.1) + (-128) = 32 + (-128) = -96 + * Output Texel: [-103, -128, -123, -96] (int4) + * + * PER-TENSOR QUANTIZATION: + * - Single scale and zero_point values for entire tensor + * - All texel components use same quantization parameters + * - Parameters passed as push constants for efficiency + * - Each thread processes one texel (4 elements) independently + * - Formula: qvalue[i] = clamp(round(value[i] / scale) + zero_point, quant_min, quant_max) + * + * PER-TOKEN QUANTIZATION: + * - Separate scale and zero_point for each token + * - Token = all elements except last dimension (e.g., for [B,S,H]: B*S tokens of H elements) + * - Parameters stored in buffer arrays indexed by token_id + * - Each thread calculates token_id from its 3D texture position + * - Scale/zero_point buffers accessed directly (not as textures) + * - Formula: qvalue[i] = clamp(round(value[i] / scale[token_id]) + zero_point[token_id], quant_min, quant_max) + */ + +#ifdef per_tensor + +void quantize_per_tensor() { + const ivec3 pos = ivec3(gl_GlobalInvocationID); + + if (any(greaterThanEqual(pos, t_in_limits))) { + return; + } + + FVEC4_T intex = load_texel(t_in, pos); + IVEC4_T outtex; + + [[unroll]] for (int i = 0; i < 4; ++i) { + IN_T value = IN_T(intex[i]); + OUT_T qvalue = quantize_val(value, scale, zero_point); + outtex[i] = qvalue; + } + write_texel(t_out, pos, outtex); +} + +#else + +void quantize_per_token() { + const ivec3 pos = ivec3(gl_GlobalInvocationID); + + if (any(greaterThanEqual(pos, t_in_limits))) { + return; + } + + FVEC4_T intex = load_texel(t_in, pos); + + int token_idx = 0; + ivec3 dims = t_in_limits; + + if (dims.z > 1) { + // 3D tensor + token_idx = pos.z * dims.y + pos.y; + } else if (dims.y > 1) { + // 2D tensor + token_idx = pos.y; + } + // For 1D tensor, token_idx remains 0 + + token_idx = min(token_idx, num_tokens - 1); + + // Scale and zero_point are prepacked as buffers, so direct access + float scale_val = t_scale[token_idx]; + int zero_point_val = t_zero_point[token_idx]; + + 
IVEC4_T outtex; + [[unroll]] for (int i = 0; i < 4; ++i) { + IN_T value = IN_T(intex[i]); + OUT_T qvalue = quantize_val(value, scale_val, zero_point_val); + outtex[i] = qvalue; + } + + write_texel(t_out, pos, outtex); +} + +#endif + +void main() { + quantize_${MODE}(); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/quantize_texture.yaml b/backends/vulkan/runtime/graph/ops/glsl/quantize_texture.yaml new file mode 100644 index 00000000000..042eb0f8196 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/quantize_texture.yaml @@ -0,0 +1,18 @@ +quantize_texture: + parameter_names_with_default_values: + IN_DTYPE: float + OUT_DTYPE: int32 + MODE: per_tensor + generate_variant_forall: + IN_DTYPE: + - VALUE: half + - VALUE: float + OUT_DTYPE: + - VALUE: uint8 + - VALUE: int8 + - VALUE: int32 + shader_variants: + - NAME: quantize_per_tensor_texture3d + MODE: per_tensor + - NAME: quantize_per_token_texture3d + MODE: per_token diff --git a/backends/vulkan/runtime/graph/ops/impl/ChooseQParams.cpp b/backends/vulkan/runtime/graph/ops/impl/ChooseQParams.cpp new file mode 100644 index 00000000000..1dc2d34afbf --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/impl/ChooseQParams.cpp @@ -0,0 +1,347 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include + +#include + +namespace vkcompute { + +namespace { + +void resize_choose_qparams_tensor_output( + ComputeGraph* graph, + const std::vector& args, + const std::vector& extra_args) { + (void)extra_args; + const ValueRef scale_out = args.at(0).refs.at(0); + const ValueRef zero_point_out = args.at(0).refs.at(1); + + // Both scale and zero_point are scalar tensors for per-tensor quantization + // Since we use single workgroup approach, no extra buffer space needed + graph->virtual_resize(scale_out, {}); + graph->virtual_resize(zero_point_out, {}); +} + +void resize_choose_qparams_per_token_output( + ComputeGraph* graph, + const std::vector& args, + const std::vector& extra_args) { + (void)extra_args; + const ValueRef scale_out = args.at(0).refs.at(0); + const ValueRef zero_point_out = args.at(0).refs.at(1); + const ValueRef input = args.at(1).refs.at(0); + + // Calculate output sizes for scale and zero_point tensors + const auto input_sizes = graph->sizes_of(input); + std::vector output_sizes; + output_sizes.reserve(input_sizes.size() - 1); + for (size_t i = 0; i < input_sizes.size() - 1; i++) { + output_sizes.push_back(input_sizes[i]); + } + output_sizes.push_back(1); + + graph->virtual_resize(scale_out, output_sizes); + graph->virtual_resize(zero_point_out, output_sizes); +} + +// Custom workgroup size pickers for ChooseQParams operations +utils::uvec3 choose_qparams_pick_global_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const std::vector& args, + const std::vector& resize_args) { + (void)shader; + (void)resize_args; + + // For per-tensor quantization, we want a single workgroup that can handle + // all elements with proper reduction. The shader uses NWORKERS=64 threads. 
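+  //
+  // Illustration (hypothetical element count, not from this diff): with
+  // NWORKERS = 64 and a 1024-element buffer, the shader's strided loop has
+  // thread 0 read indices 0, 64, ..., 960 and thread 63 read 63, 127, ...,
+  // 1023, so each element is visited exactly once before the shared-memory
+  // tree reduction collapses the per-thread min/max values.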
+ const ValueRef input = args.at(1).refs.at(0); + + if (graph->is_buffer_storage(input)) { + // For buffer storage, use a single workgroup in X dimension + // The shader will handle strided access across all elements + return {1u, 1u, 1u}; + } else { + // For texture storage, use the default logic + return graph->create_global_wg_size(args.at(0).refs.at(0)); + } +} + +utils::uvec3 choose_qparams_pick_local_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const utils::uvec3& global_workgroup_size, + const std::vector& args, + const std::vector& resize_args) { + (void)shader; + (void)resize_args; + + const ValueRef input = args.at(1).refs.at(0); + + if (graph->is_buffer_storage(input)) { + // For buffer storage, use 64 threads in X dimension to match NWORKERS + // This ensures the shared memory arrays are properly sized + return {64u, 1u, 1u}; + } else { + // For texture storage, use the default logic + return graph->create_local_wg_size(global_workgroup_size); + } +} + +utils::uvec3 choose_qparams_per_token_pick_global_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const std::vector& args, + const std::vector& resize_args) { + (void)shader; + (void)resize_args; + + const ValueRef input = args.at(1).refs.at(0); + + if (graph->is_buffer_storage(input)) { + // For per-token quantization, we need one workgroup per token + // Calculate number of tokens (product of all dimensions except the last + // one) + const auto input_sizes = graph->sizes_of(input); + int64_t num_tokens = 1; + for (size_t i = 0; i < input_sizes.size() - 1; i++) { + num_tokens *= input_sizes[i]; + } + + return {static_cast(num_tokens), 1u, 1u}; + } else { + // For texture storage, use the default logic + return graph->create_global_wg_size(args.at(0).refs.at(0)); + } +} + +utils::uvec3 choose_qparams_per_token_pick_local_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const utils::uvec3& global_workgroup_size, + const std::vector& args, + const std::vector& resize_args) { + (void)shader; + (void)resize_args; + + const ValueRef input = args.at(1).refs.at(0); + + if (graph->is_buffer_storage(input)) { + // For buffer storage, use 64 threads in X dimension to match NWORKERS + return {64u, 1u, 1u}; + } else { + // For texture storage, use the default logic + return graph->create_local_wg_size(global_workgroup_size); + } +} + +} // namespace + +void add_choose_qparams_tensor_node( + ComputeGraph& graph, + const ValueRef& input, + const ValueRef& quant_min, + const ValueRef& quant_max, + const ValueRef& scale_out, + const ValueRef& zero_point_out) { + std::string kernel_name("choose_qparams_tensor"); + add_storage_type_suffix(kernel_name, graph.storage_type_of(input)); + add_dtype_suffix(kernel_name, graph.dtype_of(input)); + + int quant_min_val = static_cast(graph.get_int(quant_min)); + int quant_max_val = static_cast(graph.get_int(quant_max)); + + vkapi::ParamsBindList param_ubos; + + if (graph.is_buffer_storage(input)) { + param_ubos = { + graph.sizes_ubo(input), + graph.strides_ubo(input), + graph.sizes_ubo(scale_out), + graph.strides_ubo(scale_out), + graph.sizes_ubo(zero_point_out), + graph.strides_ubo(zero_point_out)}; + } else { + param_ubos = { + graph.logical_limits_ubo(input), + graph.logical_limits_ubo(scale_out), + graph.logical_limits_ubo(zero_point_out)}; + } + + std::vector push_constants; + push_constants = { + PushConstantDataInfo(&quant_min_val, sizeof(int)), + PushConstantDataInfo(&quant_max_val, sizeof(int)), + }; + + 
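  // Ordering note (assumption): the push constants above (quant_min, then
  // quant_max) and the UBO list are expected to be bound in the order the shader
  // declares them, so reordering either list would silently misbind parameters.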
graph.execute_nodes().emplace_back(new DynamicDispatchNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + choose_qparams_pick_global_wg_size, + choose_qparams_pick_local_wg_size, + // Inputs and Outputs + {{scale_out, vkapi::kWrite}, + {zero_point_out, vkapi::kWrite}, + {input, vkapi::kRead}}, + // Shader param buffers + param_ubos, + // Push Constants + push_constants, + // Specialization Constants + {}, + // Resize Args + {}, + // Resizing Logic + resize_choose_qparams_tensor_output)); +} + +void add_choose_qparams_per_token_asymmetric_node( + ComputeGraph& graph, + const ValueRef& input, + const ValueRef& scale_out, + const ValueRef& zero_point_out) { + std::string kernel_name("choose_qparams_per_token_asymmetric"); + add_storage_type_suffix(kernel_name, graph.storage_type_of(input)); + add_dtype_suffix(kernel_name, graph.dtype_of(input)); + + // Calculate number of tokens (product of all dimensions except the last one) + int64_t num_tokens = 1; + const auto input_sizes = graph.sizes_of(input); + for (size_t i = 0; i < input_sizes.size() - 1; i++) { + num_tokens *= input_sizes[i]; + } + + int num_tokens_val = static_cast(num_tokens); + int quant_min_val = -128; // Fixed for asymmetric quantization + int quant_max_val = 127; // Fixed for asymmetric quantization + + vkapi::ParamsBindList param_ubos; + + if (graph.is_buffer_storage(input)) { + param_ubos = { + graph.sizes_ubo(input), + graph.strides_ubo(input), + graph.sizes_ubo(scale_out), + graph.strides_ubo(scale_out), + graph.sizes_ubo(zero_point_out), + graph.strides_ubo(zero_point_out)}; + } else { + param_ubos = { + graph.logical_limits_ubo(input), + graph.logical_limits_ubo(scale_out), + graph.logical_limits_ubo(zero_point_out)}; + } + + std::vector push_constants; + push_constants = { + PushConstantDataInfo(&num_tokens_val, sizeof(int)), + PushConstantDataInfo(&quant_min_val, sizeof(int)), + PushConstantDataInfo(&quant_max_val, sizeof(int)), + }; + + graph.execute_nodes().emplace_back(new DynamicDispatchNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + choose_qparams_per_token_pick_global_wg_size, + choose_qparams_per_token_pick_local_wg_size, + // Inputs and Outputs + {{scale_out, vkapi::kWrite}, + {zero_point_out, vkapi::kWrite}, + {input, vkapi::kRead}}, + // Shader param buffers + param_ubos, + // Push Constants + push_constants, + // Specialization Constants + {}, + // Resize Args + {}, + // Resizing Logic + resize_choose_qparams_per_token_output)); +} + +void choose_qparams_tensor_impl( + ComputeGraph& graph, + const std::vector& args) { + int arg_idx = 0; + const ValueRef input = args[arg_idx++]; + const ValueRef quant_min = args[arg_idx++]; + const ValueRef quant_max = args[arg_idx++]; + const ValueRef scale_out = args[arg_idx++]; + const ValueRef zero_point_out = args[arg_idx++]; + + // Check tensor types + VK_CHECK_COND(graph.val_is_tensor(input)); + VK_CHECK_COND(graph.val_is_tensor(scale_out)); + VK_CHECK_COND(graph.val_is_tensor(zero_point_out)); + + // Verify input is a floating point type + VK_CHECK_COND( + graph.dtype_of(input) == vkapi::kFloat || + graph.dtype_of(input) == vkapi::kHalf || + graph.dtype_of(input) == vkapi::kDouble); + + // Verify output types - accept CPU types but convert to GPU types + VK_CHECK_COND( + graph.dtype_of(scale_out) == vkapi::kFloat || + graph.dtype_of(scale_out) == vkapi::kDouble); + VK_CHECK_COND( + graph.dtype_of(zero_point_out) == vkapi::kInt || + graph.dtype_of(zero_point_out) == vkapi::kLong); + + // Check that texture storage is width packed + if 
(!graph.is_buffer_storage(input)) { + VK_CHECK_COND(graph.packed_dim_of(input) == WHCN::kWidthDim); + } + + add_choose_qparams_tensor_node( + graph, input, quant_min, quant_max, scale_out, zero_point_out); +} + +void choose_qparams_per_token_asymmetric_impl( + ComputeGraph& graph, + const std::vector& args) { + int arg_idx = 0; + const ValueRef input = args[arg_idx++]; + const ValueRef scale_out = args[arg_idx++]; + const ValueRef zero_point_out = args[arg_idx++]; + + // Check tensor types + VK_CHECK_COND(graph.val_is_tensor(input)); + VK_CHECK_COND(graph.val_is_tensor(scale_out)); + VK_CHECK_COND(graph.val_is_tensor(zero_point_out)); + + // Verify input is a floating point type + VK_CHECK_COND( + graph.dtype_of(input) == vkapi::kFloat || + graph.dtype_of(input) == vkapi::kHalf || + graph.dtype_of(input) == vkapi::kDouble); + + // Verify output types - accept CPU types but convert to GPU types + VK_CHECK_COND( + graph.dtype_of(scale_out) == vkapi::kFloat || + graph.dtype_of(scale_out) == vkapi::kDouble); + VK_CHECK_COND( + graph.dtype_of(zero_point_out) == vkapi::kInt || + graph.dtype_of(zero_point_out) == vkapi::kLong); + + add_choose_qparams_per_token_asymmetric_node( + graph, input, scale_out, zero_point_out); +} + +REGISTER_OPERATORS { + VK_REGISTER_OP(choose_qparams.tensor, choose_qparams_tensor_impl); + VK_REGISTER_OP( + choose_qparams_per_token_asymmetric.default, + choose_qparams_per_token_asymmetric_impl); +} + +} // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/Dequantize.cpp b/backends/vulkan/runtime/graph/ops/impl/Dequantize.cpp new file mode 100644 index 00000000000..77a51ce24f9 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/impl/Dequantize.cpp @@ -0,0 +1,274 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include +#include +#include + +#include +#include +#include + +namespace vkcompute { + +namespace { + +void resize_dequantize_output( + ComputeGraph* graph, + const std::vector& args, + const std::vector& extra_args) { + (void)extra_args; + const ValueRef out = args.at(0).refs.at(0); + const ValueRef in = args.at(1).refs.at(0); + graph->virtual_resize(out, graph->sizes_of(in)); +} + +} // namespace + +void add_dequantize_per_tensor_node( + ComputeGraph& graph, + const ValueRef& input, + const ValueRef& scale, + const ValueRef& zero_point, + const ValueRef& quant_min, + const ValueRef& quant_max, + const ValueRef& output) { + std::string kernel_name("dequantize_per_tensor"); + add_storage_type_suffix(kernel_name, graph.storage_type_of(input)); + add_dtype_suffix(kernel_name, graph.dtype_of(input)); + add_dtype_suffix(kernel_name, graph.dtype_of(output)); + + float scale_val = static_cast(graph.get_double(scale)); + int zero_point_val = static_cast(graph.get_int(zero_point)); + int quant_min_val = static_cast(graph.get_int(quant_min)); + int quant_max_val = static_cast(graph.get_int(quant_max)); + + vkapi::ParamsBindList param_ubos; + std::vector push_constants; + + if (graph.is_buffer_storage(input)) { + param_ubos = { + graph.numel_ubo(input), + graph.sizes_ubo(input), + graph.strides_ubo(input), + graph.sizes_ubo(output), + graph.strides_ubo(output)}; + push_constants = { + PushConstantDataInfo(&scale_val, sizeof(float)), + PushConstantDataInfo(&zero_point_val, sizeof(int)), + PushConstantDataInfo(&quant_min_val, sizeof(int)), + PushConstantDataInfo(&quant_max_val, sizeof(int)), + }; + } else { + param_ubos = { + graph.logical_limits_ubo(input), graph.logical_limits_ubo(output)}; + push_constants = { + PushConstantDataInfo(&scale_val, sizeof(float)), + PushConstantDataInfo(&zero_point_val, sizeof(int)), + PushConstantDataInfo(&quant_min_val, sizeof(int)), + PushConstantDataInfo(&quant_max_val, sizeof(int)), + }; + } + + vkapi::SpecVarList spec_vars = { + graph.hashed_layout_of(output), + graph.hashed_layout_of(input), + }; + + graph.execute_nodes().emplace_back(new DynamicDispatchNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + default_pick_global_wg_size, + default_pick_local_wg_size, + // Inputs and Outputs + {{output, vkapi::kWrite}, {input, vkapi::kRead}}, + // Shader param buffers + param_ubos, + // Push Constants + push_constants, + // Specialization Constants + spec_vars, + // Resize Args + {}, + // Resizing Logic + resize_dequantize_output)); +} + +void add_dequantize_per_token_node( + ComputeGraph& graph, + const ValueRef& input, + const ValueRef& scale, + const ValueRef& zero_point, + const ValueRef& quant_min, + const ValueRef& quant_max, + const ValueRef& output) { + std::string kernel_name("dequantize_per_token"); + add_storage_type_suffix(kernel_name, graph.storage_type_of(input)); + add_dtype_suffix(kernel_name, graph.dtype_of(input)); + add_dtype_suffix(kernel_name, graph.dtype_of(output)); + + int quant_min_val = static_cast(graph.get_int(quant_min)); + int quant_max_val = static_cast(graph.get_int(quant_max)); + + int num_tokens = static_cast(graph.sizes_of(scale)[0]); + + vkapi::ParamsBindList param_ubos; + std::vector push_constants; + + if (graph.is_buffer_storage(input)) { + param_ubos = { + graph.numel_ubo(input), + graph.sizes_ubo(input), + graph.strides_ubo(input), + graph.sizes_ubo(output), + graph.strides_ubo(output), + }; + push_constants = { + PushConstantDataInfo(&num_tokens, sizeof(int)), + PushConstantDataInfo(&quant_min_val, sizeof(int)), + 
PushConstantDataInfo(&quant_max_val, sizeof(int)), + }; + } else { + param_ubos = { + graph.logical_limits_ubo(input), + graph.logical_limits_ubo(output), + }; + push_constants = { + PushConstantDataInfo(&num_tokens, sizeof(int)), + PushConstantDataInfo(&quant_min_val, sizeof(int)), + PushConstantDataInfo(&quant_max_val, sizeof(int)), + }; + } + + vkapi::SpecVarList spec_vars = { + graph.hashed_layout_of(output), + graph.hashed_layout_of(input), + }; + + graph.execute_nodes().emplace_back(new DynamicDispatchNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + default_pick_global_wg_size, + default_pick_local_wg_size, + // Inputs and Outputs + {{output, vkapi::kWrite}, + {input, vkapi::kRead}, + {{scale, zero_point}, vkapi::kRead}}, + // Shader param buffers + param_ubos, + // Push Constants + push_constants, + // Specialization Constants + spec_vars, + // Resize Args + {}, + // Resizing Logic + resize_dequantize_output)); +} + +void dequantize_per_tensor_impl( + ComputeGraph& graph, + const std::vector& args) { + int arg_idx = 0; + const ValueRef input = args[arg_idx++]; + const ValueRef scale = args[arg_idx++]; + const ValueRef zero_point = args[arg_idx++]; + const ValueRef quant_min = args[arg_idx++]; + const ValueRef quant_max = args[arg_idx++]; + const ValueRef output = args[arg_idx++]; + + // Check tensor types + VK_CHECK_COND(graph.val_is_tensor(input)); + VK_CHECK_COND(graph.val_is_tensor(output)); + + // Verify input is an integer type + VK_CHECK_COND( + graph.dtype_of(input) == vkapi::kByte || + graph.dtype_of(input) == vkapi::kChar || + graph.dtype_of(input) == vkapi::kShort || + graph.dtype_of(input) == vkapi::kInt); + + // Verify output is a floating point type + VK_CHECK_COND( + graph.dtype_of(output) == vkapi::kHalf || + graph.dtype_of(output) == vkapi::kFloat || + graph.dtype_of(output) == vkapi::kDouble); + + add_dequantize_per_tensor_node( + graph, input, scale, zero_point, quant_min, quant_max, output); +} + +void dequantize_per_token_impl( + ComputeGraph& graph, + const std::vector& args) { + int arg_idx = 0; + const ValueRef input = args[arg_idx++]; + const ValueRef scale = args[arg_idx++]; + const ValueRef zero_point = args[arg_idx++]; + const ValueRef quant_min = args[arg_idx++]; + const ValueRef quant_max = args[arg_idx++]; + const ValueRef output = args[arg_idx++]; + + // Check tensor types + VK_CHECK_COND(graph.val_is_tensor(input)); + VK_CHECK_COND(graph.val_is_tensor(scale)); + VK_CHECK_COND(graph.val_is_tensor(zero_point)); + VK_CHECK_COND(graph.val_is_tensor(output)); + + // Verify input is an integer type + VK_CHECK_COND( + graph.dtype_of(input) == vkapi::kByte || + graph.dtype_of(input) == vkapi::kChar || + graph.dtype_of(input) == vkapi::kShort || + graph.dtype_of(input) == vkapi::kInt); + + // Verify output is a floating point type + VK_CHECK_COND( + graph.dtype_of(output) == vkapi::kHalf || + graph.dtype_of(output) == vkapi::kFloat || + graph.dtype_of(output) == vkapi::kDouble); + + // Check that scale and zero_point have buffer storage and width packing + VK_CHECK_COND(graph.is_buffer_storage(scale)); + VK_CHECK_COND(graph.packed_dim_of(scale) == WHCN::kWidthDim); + VK_CHECK_COND(graph.is_buffer_storage(zero_point)); + VK_CHECK_COND(graph.packed_dim_of(zero_point) == WHCN::kWidthDim); + + // Check that tensors with texture storage have standard axis map + if (!graph.is_buffer_storage(input)) { + VK_CHECK_COND(graph.has_standard_axis_map(input)); + } + if (!graph.is_buffer_storage(output)) { + VK_CHECK_COND(graph.has_standard_axis_map(output)); + } + + // 
Calculate number of tokens (product of all dimensions except the last one) + int64_t num_tokens = 1; + const auto input_sizes = graph.sizes_of(input); + for (size_t i = 0; i < input_sizes.size() - 1; i++) { + num_tokens *= input_sizes[i]; + } + + const auto scale_sizes = graph.sizes_of(scale); + const auto zero_point_sizes = graph.sizes_of(zero_point); + + VK_CHECK_COND(scale_sizes.size() == 1); + VK_CHECK_COND(zero_point_sizes.size() == 1); + VK_CHECK_COND(scale_sizes[0] == num_tokens); + VK_CHECK_COND(zero_point_sizes[0] == num_tokens); + + add_dequantize_per_token_node( + graph, input, scale, zero_point, quant_min, quant_max, output); +} + +REGISTER_OPERATORS { + VK_REGISTER_OP(dequantize_per_tensor.default, dequantize_per_tensor_impl); + VK_REGISTER_OP(dequantize_per_token.default, dequantize_per_token_impl); +} + +} // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/Quantize.cpp b/backends/vulkan/runtime/graph/ops/impl/Quantize.cpp new file mode 100644 index 00000000000..35712d59fb9 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/impl/Quantize.cpp @@ -0,0 +1,258 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include + +#include +#include +#include + +namespace vkcompute { + +namespace { + +void resize_quantize_output( + ComputeGraph* graph, + const std::vector& args, + const std::vector& extra_args) { + (void)extra_args; + const ValueRef out = args.at(0).refs.at(0); + const ValueRef in = args.at(1).refs.at(0); + graph->virtual_resize(out, graph->sizes_of(in)); +} + +} // namespace + +void add_quantize_per_tensor_node( + ComputeGraph& graph, + const ValueRef& input, + const ValueRef& scale, + const ValueRef& zero_point, + const ValueRef& quant_min, + const ValueRef& quant_max, + const ValueRef& output) { + std::string kernel_name("quantize_per_tensor"); + add_storage_type_suffix(kernel_name, graph.storage_type_of(input)); + add_dtype_suffix(kernel_name, graph.dtype_of(input)); + add_dtype_suffix(kernel_name, graph.dtype_of(output)); + + float scale_val = static_cast(graph.get_double(scale)); + int zero_point_val = static_cast(graph.get_int(zero_point)); + int quant_min_val = static_cast(graph.get_int(quant_min)); + int quant_max_val = static_cast(graph.get_int(quant_max)); + + vkapi::ParamsBindList param_ubos; + std::vector push_constants; + + if (graph.is_buffer_storage(input)) { + param_ubos = { + graph.numel_ubo(input), + graph.sizes_ubo(input), + graph.strides_ubo(input), + graph.sizes_ubo(output), + graph.strides_ubo(output)}; + push_constants = { + PushConstantDataInfo(&scale_val, sizeof(float)), + PushConstantDataInfo(&zero_point_val, sizeof(int)), + PushConstantDataInfo(&quant_min_val, sizeof(int)), + PushConstantDataInfo(&quant_max_val, sizeof(int)), + }; + } else { + param_ubos = { + graph.logical_limits_ubo(input), graph.logical_limits_ubo(output)}; + push_constants = { + PushConstantDataInfo(&scale_val, sizeof(float)), + PushConstantDataInfo(&zero_point_val, sizeof(int)), + PushConstantDataInfo(&quant_min_val, sizeof(int)), + PushConstantDataInfo(&quant_max_val, sizeof(int)), + }; + } + + vkapi::SpecVarList spec_vars = { + graph.hashed_layout_of(output), + graph.hashed_layout_of(input), + }; + + graph.execute_nodes().emplace_back(new DynamicDispatchNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + default_pick_global_wg_size, + 
default_pick_local_wg_size, + // Inputs and Outputs + {{output, vkapi::kWrite}, {input, vkapi::kRead}}, + // Shader param buffers + param_ubos, + // Push Constants + push_constants, + // Specialization Constants + spec_vars, + // Resize Args + {}, + // Resizing Logic + resize_quantize_output)); +} + +void add_quantize_per_token_node( + ComputeGraph& graph, + const ValueRef& input, + const ValueRef& scale, + const ValueRef& zero_point, + const ValueRef& quant_min, + const ValueRef& quant_max, + const ValueRef& output) { + std::string kernel_name("quantize_per_token"); + add_storage_type_suffix(kernel_name, graph.storage_type_of(input)); + add_dtype_suffix(kernel_name, graph.dtype_of(input)); + add_dtype_suffix(kernel_name, graph.dtype_of(output)); + + int quant_min_val = static_cast(graph.get_int(quant_min)); + int quant_max_val = static_cast(graph.get_int(quant_max)); + + int num_tokens = static_cast(graph.sizes_of(scale)[0]); + + vkapi::ParamsBindList param_ubos; + std::vector push_constants; + + if (graph.is_buffer_storage(input)) { + param_ubos = { + graph.numel_ubo(input), + graph.sizes_ubo(input), + graph.strides_ubo(input), + graph.sizes_ubo(output), + graph.strides_ubo(output), + }; + push_constants = { + PushConstantDataInfo(&num_tokens, sizeof(int)), + PushConstantDataInfo(&quant_min_val, sizeof(int)), + PushConstantDataInfo(&quant_max_val, sizeof(int)), + }; + } else { + param_ubos = { + graph.logical_limits_ubo(input), + graph.logical_limits_ubo(output), + }; + push_constants = { + PushConstantDataInfo(&num_tokens, sizeof(int)), + PushConstantDataInfo(&quant_min_val, sizeof(int)), + PushConstantDataInfo(&quant_max_val, sizeof(int)), + }; + } + + vkapi::SpecVarList spec_vars = { + graph.hashed_layout_of(output), + graph.hashed_layout_of(input), + }; + + graph.execute_nodes().emplace_back(new DynamicDispatchNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + default_pick_global_wg_size, + default_pick_local_wg_size, + // Inputs and Outputs + {{output, vkapi::kWrite}, + {input, vkapi::kRead}, + {{scale, zero_point}, vkapi::kRead}}, + // Shader param buffers + param_ubos, + // Push Constants + push_constants, + // Specialization Constants + spec_vars, + // Resize Args + {}, + // Resizing Logic + resize_quantize_output)); +} + +void quantize_per_tensor_impl( + ComputeGraph& graph, + const std::vector& args) { + int arg_idx = 0; + const ValueRef input = args[arg_idx++]; + const ValueRef scale = args[arg_idx++]; + const ValueRef zero_point = args[arg_idx++]; + const ValueRef quant_min = args[arg_idx++]; + const ValueRef quant_max = args[arg_idx++]; + const ValueRef output = args[arg_idx++]; + + // Check tensor types + VK_CHECK_COND(graph.val_is_tensor(input)); + VK_CHECK_COND(graph.val_is_tensor(output)); + + // Verify input is a floating point type + VK_CHECK_COND( + graph.dtype_of(input) == vkapi::kFloat || + graph.dtype_of(input) == vkapi::kHalf); + + add_quantize_per_tensor_node( + graph, input, scale, zero_point, quant_min, quant_max, output); +} + +void quantize_per_token_impl( + ComputeGraph& graph, + const std::vector& args) { + int arg_idx = 0; + const ValueRef input = args[arg_idx++]; + const ValueRef scale = args[arg_idx++]; + const ValueRef zero_point = args[arg_idx++]; + const ValueRef quant_min = args[arg_idx++]; + const ValueRef quant_max = args[arg_idx++]; + const ValueRef output = args[arg_idx++]; + + // Check tensor types + VK_CHECK_COND(graph.val_is_tensor(input)); + VK_CHECK_COND(graph.val_is_tensor(scale)); + VK_CHECK_COND(graph.val_is_tensor(zero_point)); + 
VK_CHECK_COND(graph.val_is_tensor(output)); + + // Verify input is a floating point type + VK_CHECK_COND( + graph.dtype_of(input) == vkapi::kFloat || + graph.dtype_of(input) == vkapi::kHalf); + + // Check that scale and zero_point have buffer storage and width packing + VK_CHECK_COND(graph.is_buffer_storage(scale)); + VK_CHECK_COND(graph.packed_dim_of(scale) == WHCN::kWidthDim); + VK_CHECK_COND(graph.is_buffer_storage(zero_point)); + VK_CHECK_COND(graph.packed_dim_of(zero_point) == WHCN::kWidthDim); + + // Check that tensors with texture storage have standard axis map + if (!graph.is_buffer_storage(input)) { + VK_CHECK_COND(graph.has_standard_axis_map(input)); + } + if (!graph.is_buffer_storage(output)) { + VK_CHECK_COND(graph.has_standard_axis_map(output)); + } + + // Calculate number of tokens (product of all dimensions except the last one) + int64_t num_tokens = 1; + const auto input_sizes = graph.sizes_of(input); + for (size_t i = 0; i < input_sizes.size() - 1; i++) { + num_tokens *= input_sizes[i]; + } + + const auto scale_sizes = graph.sizes_of(scale); + const auto zero_point_sizes = graph.sizes_of(zero_point); + + VK_CHECK_COND(scale_sizes.size() == 1); + VK_CHECK_COND(zero_point_sizes.size() == 1); + VK_CHECK_COND(scale_sizes[0] == num_tokens); + VK_CHECK_COND(zero_point_sizes[0] == num_tokens); + + add_quantize_per_token_node( + graph, input, scale, zero_point, quant_min, quant_max, output); +} + +REGISTER_OPERATORS { + VK_REGISTER_OP(quantize_per_tensor.default, quantize_per_tensor_impl); + VK_REGISTER_OP(quantize_per_token.default, quantize_per_token_impl); +} + +} // namespace vkcompute diff --git a/backends/vulkan/test/op_tests/choose_qparams_test.cpp b/backends/vulkan/test/op_tests/choose_qparams_test.cpp index 24c856e9d46..55e96151387 100644 --- a/backends/vulkan/test/op_tests/choose_qparams_test.cpp +++ b/backends/vulkan/test/op_tests/choose_qparams_test.cpp @@ -516,6 +516,58 @@ TEST(VulkanChooseQparamsTest, test_reference_choose_qparams_tensor_int8) { at::kChar); } +TEST(VulkanChooseQparamsTest, test_vulkan_choose_qparams_tensor_uint8_4D) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + test_vulkan_choose_qparams_tensor( + {5, 3, 2, 4}, // input sizes + 0, // quant_min + 255, // quant_max + at::kByte); +} + +TEST(VulkanChooseQparamsTest, test_vulkan_choose_qparams_tensor_int8_2D) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + test_vulkan_choose_qparams_tensor( + {5, 5}, // input sizes + -128, // quant_min + 127, // quant_max + at::kChar); +} + +TEST(VulkanChooseQparamsTest, test_vulkan_choose_qparams_tensor_int8_3D) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + test_vulkan_choose_qparams_tensor( + {12, 8, 2}, // input sizes + -128, // quant_min + 127, // quant_max + at::kChar); +} + +TEST(VulkanChooseQparamsTest, test_vulkan_choose_qparams_tensor_int8_4D) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + test_vulkan_choose_qparams_tensor( + {10, 10, 6, 4}, // input sizes + -128, // quant_min + 127, // quant_max + at::kChar); +} + void test_reference_choose_qparams_per_token_asymmetric( const std::vector& input_sizes, at::ScalarType dtype) { @@ -673,3 +725,47 @@ TEST( {2, 3, 4}, // input sizes (2*3=6 tokens) at::kChar); } + +TEST( + VulkanChooseQparamsTest, + test_vulkan_choose_qparams_per_token_asymmetric_int8_1D) 
{ + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + test_vulkan_choose_qparams_per_token_asymmetric({7}, at::kChar); +} + +TEST( + VulkanChooseQparamsTest, + test_vulkan_choose_qparams_per_token_asymmetric_int8_2D) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + test_vulkan_choose_qparams_per_token_asymmetric({2, 2}, at::kChar); +} + +TEST( + VulkanChooseQparamsTest, + test_vulkan_choose_qparams_per_token_asymmetric_int8_3D) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + test_vulkan_choose_qparams_per_token_asymmetric({3, 6, 4}, at::kChar); +} + +TEST( + VulkanChooseQparamsTest, + test_vulkan_choose_qparams_per_token_asymmetric_int8_4D) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + test_vulkan_choose_qparams_per_token_asymmetric({128, 2, 16, 3}, at::kChar); +} diff --git a/backends/vulkan/test/op_tests/dequantize_test.cpp b/backends/vulkan/test/op_tests/dequantize_test.cpp index 7b155c8f98b..1ec0602a4f2 100644 --- a/backends/vulkan/test/op_tests/dequantize_test.cpp +++ b/backends/vulkan/test/op_tests/dequantize_test.cpp @@ -20,6 +20,7 @@ #include "test_utils.h" #include +#include #include #include @@ -481,6 +482,8 @@ void test_reference_dequantize_per_tensor( std::cout << " zero_point: " << zero_point << std::endl; std::cout << " quant_min: " << quant_min << std::endl; std::cout << " quant_max: " << quant_max << std::endl; + std::cout << " input dtype: " << dtype << std::endl; + std::cout << " output dtype: " << out_dtype << std::endl; std::cout << "input:" << std::endl; std::cout << input << std::endl; @@ -598,8 +601,15 @@ void test_vulkan_dequantize_per_tensor_impl( graph.copy_from_staging( staging_out, vk_out.mutable_data_ptr(), vk_out.numel()); - // Compare outputs - const bool output_correct = at::allclose(reference_out, vk_out); + // Compare outputs with appropriate tolerance for half precision + bool output_correct; + if (out_dtype == at::kHalf) { + // Use higher tolerance for half precision due to limited precision + output_correct = + at::allclose(reference_out, vk_out, /*rtol=*/1e-2, /*atol=*/1e-2); + } else { + output_correct = at::allclose(reference_out, vk_out); + } if (!output_correct) { std::cout << "\n" << "Failed with parameters: " << std::endl; @@ -611,6 +621,8 @@ void test_vulkan_dequantize_per_tensor_impl( << (in_storage == vkcompute::utils::kBuffer ? 
"buffer" : "texture") << std::endl; + std::cout << " input dtype: " << dtype << std::endl; + std::cout << " output dtype: " << out_dtype << std::endl; std::cout << "input:" << std::endl; std::cout << input << std::endl; @@ -623,7 +635,6 @@ void test_vulkan_dequantize_per_tensor_impl( ASSERT_TRUE(output_correct); } -// Test cases for dequantize_per_tensor TEST( VulkanDequantizePerTensorTest, test_reference_dequantize_per_tensor_uint8_to_float) { @@ -689,6 +700,99 @@ TEST( at::kHalf); // output dtype } +TEST( + VulkanDequantizePerTensorTest, + test_vulkan_dequantize_per_tensor_uint8_to_float) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + test_vulkan_dequantize_per_tensor( + {2, 3, 4}, // input sizes + 0.1, // scale + 5, // zero_point + 0, // quant_min + 255, // quant_max + at::kByte, // input dtype + at::kFloat); // output dtype +} + +TEST( + VulkanDequantizePerTensorTest, + test_vulkan_dequantize_per_tensor_int8_to_float) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + test_vulkan_dequantize_per_tensor( + {3, 4}, // input sizes + 0.05, // scale + 0, // zero_point + -128, // quant_min + 127, // quant_max + at::kChar, // input dtype + at::kFloat); // output dtype +} + +TEST( + VulkanDequantizePerTensorTest, + test_vulkan_dequantize_per_tensor_int32_to_float) { + test_vulkan_dequantize_per_tensor( + {2, 4, 3, 12}, // input sizes + 0.0001, // scale + 100, // zero_point + -2147483648, // quant_min + 2147483647, // quant_max + at::kInt, // input dtype + at::kFloat); // output dtype +} + +TEST( + VulkanDequantizePerTensorTest, + test_vulkan_dequantize_per_tensor_int8_to_half) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_float16_buffers_support()) { + GTEST_SKIP(); + } + test_vulkan_dequantize_per_tensor( + {2, 3}, // input sizes + 0.05, // scale + 10, // zero_point + -128, // quant_min + 127, // quant_max + at::kChar, // input dtype + at::kHalf); // output dtype +} + +TEST( + VulkanDequantizePerTensorTest, + test_vulkan_dequantize_per_tensor_int32_to_half) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_float16_buffers_support()) { + GTEST_SKIP(); + } + // Use much smaller scale to avoid overflow to infinity in half precision + // Half precision max value is ~65504, so with int32 values around 2e9, + // we need scales smaller than 65504/2e9 ≈ 3e-5 to avoid overflow + test_vulkan_dequantize_per_tensor( + {7}, // input sizes + 1e-5, // scale (much smaller to avoid overflow) + 5, // zero_point + std::numeric_limits::min(), // quant_min + std::numeric_limits::max(), // quant_max + at::kInt, // input dtype + at::kHalf); // output dtype +} + void test_reference_dequantize_per_token( const std::vector& input_sizes, const std::vector& scales, @@ -793,6 +897,8 @@ void test_reference_dequantize_per_token( std::cout << "" << std::endl; std::cout << " quant_min: " << quant_min << std::endl; std::cout << " quant_max: " << quant_max << std::endl; + std::cout << " input dtype: " << dtype << std::endl; + std::cout << " output dtype: " << out_dtype << std::endl; std::cout << "input:" << std::endl; std::cout << input << std::endl; @@ -894,9 +1000,15 @@ void test_vulkan_dequantize_per_token_impl( IOValueRef r_input = graph.add_input_tensor( input.sizes().vec(), from_at_scalartype(dtype), in_storage); IOValueRef r_scale = 
graph.add_input_tensor( - scale_tensor.sizes().vec(), vkapi::kFloat, in_storage); + scale_tensor.sizes().vec(), + vkapi::kFloat, + utils::kBuffer, + utils::kWidthPacked); IOValueRef r_zero_point = graph.add_input_tensor( - zero_point_tensor.sizes().vec(), vkapi::kInt, in_storage); + zero_point_tensor.sizes().vec(), + vkapi::kInt, + utils::kBuffer, + utils::kWidthPacked); const ValueRef r_quant_min = graph.add_scalar(quant_min); const ValueRef r_quant_max = graph.add_scalar(quant_max); @@ -946,8 +1058,15 @@ void test_vulkan_dequantize_per_token_impl( graph.copy_from_staging( staging_out, vk_out.mutable_data_ptr(), vk_out.numel()); - // Compare outputs - const bool output_correct = at::allclose(reference_out, vk_out); + // Compare outputs with appropriate tolerance for half precision + bool output_correct; + if (out_dtype == at::kHalf) { + // Use higher tolerance for half precision due to limited precision + output_correct = + at::allclose(reference_out, vk_out, /*rtol=*/1e-2, /*atol=*/1e-2); + } else { + output_correct = at::allclose(reference_out, vk_out); + } if (!output_correct) { std::cout << "\n" << "Failed with parameters: " << std::endl; @@ -967,6 +1086,8 @@ void test_vulkan_dequantize_per_token_impl( << (in_storage == vkcompute::utils::kBuffer ? "buffer" : "texture") << std::endl; + std::cout << " input dtype: " << dtype << std::endl; + std::cout << " output dtype: " << out_dtype << std::endl; std::cout << "input:" << std::endl; std::cout << input << std::endl; @@ -979,7 +1100,6 @@ void test_vulkan_dequantize_per_token_impl( ASSERT_TRUE(output_correct); } -// Test cases for dequantize_per_token TEST( VulkanDequantizePerTokenTest, test_reference_dequantize_per_token_uint8_to_float) { @@ -1059,3 +1179,112 @@ TEST( at::kInt, // input dtype at::kHalf); // output dtype } + +TEST( + VulkanDequantizePerTokenTest, + test_vulkan_dequantize_per_token_uint8_to_float) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + std::vector scales = {0.1, 0.2, 0.3, 0.4, 0.5, 0.6}; + std::vector zero_points = {5, 10, 15, 20, 25, 30}; + + test_vulkan_dequantize_per_token( + {2, 3, 6}, // input sizes (2*3=6 tokens) + scales, + zero_points, + 0, // quant_min + 255, // quant_max + at::kByte, // input dtype + at::kFloat); // output dtype +} + +TEST( + VulkanDequantizePerTokenTest, + test_vulkan_dequantize_per_token_int8_to_float) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + std::vector scales = {0.05, 0.0}; + std::vector zero_points = {10, -5}; + + test_vulkan_dequantize_per_token( + {2, 2}, // input sizes (2*2=4 tokens) + scales, + zero_points, + -128, // quant_min + 127, // quant_max + at::kChar, // input dtype + at::kFloat); // output dtype +} + +TEST( + VulkanDequantizePerTokenTest, + test_vulkan_dequantize_per_token_int32_to_float) { + std::vector scales = { + 0.0001, 0.0002, 0.0003, 0.0, 0.0011, 0.0102, 0.1003, 0.0}; + std::vector zero_points = {100, -100, 50, -50, 12, -6, 4, -24}; + + test_vulkan_dequantize_per_token( + {2, 2, 2, 12}, // input sizes (2*2=4 tokens) + scales, + zero_points, + -2147483648, // quant_min + 2147483647, // quant_max + at::kInt, // input dtype + at::kFloat); // output dtype +} + +TEST( + VulkanDequantizePerTokenTest, + test_vulkan_dequantize_per_token_int8_to_half) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + if (!vkcompute::api::context() + ->adapter_ptr() + 
->has_full_float16_buffers_support()) { + GTEST_SKIP(); + } + std::vector scales = {0.05, 0.2}; + std::vector zero_points = {2, -5}; + + test_vulkan_dequantize_per_token( + {2, 2}, // input sizes (2=4 tokens) + scales, + zero_points, + -128, // quant_min + 127, // quant_max + at::kChar, // input dtype + at::kHalf); // output dtype +} + +TEST( + VulkanDequantizePerTokenTest, + test_vulkan_dequantize_per_token_int32_to_half) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_float16_buffers_support()) { + GTEST_SKIP(); + } + // Use much smaller scales to avoid overflow to infinity in half precision + // Half precision max value is ~65504, so with int32 values around 2e9, + // we need scales smaller than 65504/2e9 ≈ 3e-5 to avoid overflow + std::vector scales = {1e-5, 2e-5, 1.5e-5}; + std::vector zero_points = {20, -15, 1}; + + test_vulkan_dequantize_per_token( + {3, 6}, // input sizes (3 tokens) + scales, + zero_points, + std::numeric_limits::min(), // quant_min + std::numeric_limits::max(), // quant_max + at::kInt, // input dtype + at::kHalf); // output dtype +} diff --git a/backends/vulkan/test/op_tests/quantize_test.cpp b/backends/vulkan/test/op_tests/quantize_test.cpp index 8b79dc1ce6b..7ea98b14fb2 100644 --- a/backends/vulkan/test/op_tests/quantize_test.cpp +++ b/backends/vulkan/test/op_tests/quantize_test.cpp @@ -21,6 +21,9 @@ #include #include +#include + +float eps = 1e-7; namespace torch { namespace executor { @@ -383,6 +386,8 @@ void test_reference_quantize_per_tensor( // Reshape back to original dimensions input = flat_input.reshape(input_sizes_int64); + scale = scale < eps ? eps : scale; + // Get reference output at::Tensor reference_out = quantize_per_tensor_reference_impl( input, scale, zero_point, quant_min, quant_max, dtype); @@ -435,6 +440,8 @@ void test_vulkan_quantize_per_tensor_impl( at::Tensor input = at::rand(input_sizes_int64, at::device(at::kCPU).dtype(in_dtype)); + scale = scale < eps ? eps : scale; + // Get reference output at::Tensor reference_out = torch::executor::native::quantize_per_tensor_aten( input, scale, zero_point, quant_min, quant_max, dtype); @@ -490,7 +497,7 @@ void test_vulkan_quantize_per_tensor_impl( at::Tensor reference_int = reference_out.to(at::kInt); at::Tensor vk_int = vk_out.to(at::kInt); - const bool output_correct = at::equal(reference_int, vk_int); + const bool output_correct = at::allclose(reference_int, vk_int); if (!output_correct) { at::Tensor diffs = at::abs(reference_int - vk_int); @@ -500,6 +507,10 @@ void test_vulkan_quantize_per_tensor_impl( std::cout << " zero_point: " << zero_point << std::endl; std::cout << " quant_min: " << quant_min << std::endl; std::cout << " quant_max: " << quant_max << std::endl; + std::cout << " storage type: " + << (in_storage == vkcompute::utils::kBuffer ? 
"buffer" + : "texture") + << std::endl; std::cout << "input:" << std::endl; std::cout << input << std::endl; @@ -564,9 +575,89 @@ TEST( at::kInt); } +TEST( + VulkanQuantizePerTensorTest, + test_vulkan_quantize_per_tensor_float_to_uint8) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + test_vulkan_quantize_per_tensor( + {5, 3, 2, 4}, // input sizes + 0.01, // scale + 1, // zero_point + 0, // quant_min + 255, // quant_max + at::kFloat, + at::kByte); +} + +TEST( + VulkanQuantizePerTensorTest, + test_vulkan_quantize_per_tensor_float_to_int8) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + test_vulkan_quantize_per_tensor( + {5, 3, 2, 4}, // input sizes + 0.01, // scale + 1, // zero_point + -128, // quant_min + 127, // quant_max + at::kFloat, + at::kChar); +} + +TEST( + VulkanQuantizePerTensorTest, + test_vulkan_quantize_per_tensor_float_to_int32) { + test_vulkan_quantize_per_tensor( + {5, 3, 2, 4}, // input sizes + 0.01, // scale + 1, // zero_point + -2147483648, // quant_min + 2147483647, // quant_max + at::kFloat, + at::kInt); +} + +TEST( + VulkanQuantizePerTensorTest, + test_vulkan_quantize_per_tensor_float_to_int32_small_scale) { + test_vulkan_quantize_per_tensor( + {2, 8, 1, 3}, // input sizes + 0.0, // scale + 20, // zero_point + -2147483648, // quant_min + 2147483647, // quant_max + at::kFloat, + at::kInt); +} + +TEST( + VulkanQuantizePerTensorTest, + test_vulkan_quantize_per_tensor_half_to_int8) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_float16_buffers_support()) { + GTEST_SKIP(); + } + test_vulkan_quantize_per_tensor( + {2, 3}, // input sizes + 0.01, // scale + 1, // zero_point + -128, // quant_min + 127, // quant_max + at::kHalf, // input dtype + at::kChar); // output dtype +} + void test_reference_quantize_per_token( const std::vector& input_sizes, - const std::vector& scales, + const std::vector& pre_scales, const std::vector& zero_points, int64_t quant_min, int64_t quant_max, @@ -595,9 +686,14 @@ void test_reference_quantize_per_token( } // Verify that the number of tokens matches the size of scales and zero_points - ASSERT_EQ(num_tokens, scales.size()); + ASSERT_EQ(num_tokens, pre_scales.size()); ASSERT_EQ(num_tokens, zero_points.size()); + std::vector scales = pre_scales; + for (auto& s : scales) { + s = s < eps ? eps : s; + } + // Create scale and zero_point tensors at::Tensor scale_tensor = at::tensor(scales, at::device(at::kCPU).dtype(at::kDouble)); @@ -646,7 +742,7 @@ void test_reference_quantize_per_token( void test_vulkan_quantize_per_token_impl( const std::vector& input_sizes, - const std::vector& scales, + const std::vector& pre_scales, const std::vector& zero_points, int64_t quant_min, int64_t quant_max, @@ -662,9 +758,14 @@ void test_vulkan_quantize_per_token_impl( num_tokens *= input_sizes[i]; } - ASSERT_EQ(num_tokens, scales.size()); + ASSERT_EQ(num_tokens, pre_scales.size()); ASSERT_EQ(num_tokens, zero_points.size()); + std::vector scales = pre_scales; + for (auto& s : scales) { + s = s < eps ? 
eps : s; + } + // Create input tensor with random values std::vector input_sizes_int64( input_sizes.begin(), input_sizes.end()); @@ -688,9 +789,15 @@ void test_vulkan_quantize_per_token_impl( IOValueRef r_input = graph.add_input_tensor( input.sizes().vec(), from_at_scalartype(input.scalar_type()), in_storage); IOValueRef r_scale = graph.add_input_tensor( - scale_tensor.sizes().vec(), vkapi::kFloat, in_storage); + scale_tensor.sizes().vec(), + vkapi::kFloat, + utils::kBuffer, + utils::kWidthPacked); IOValueRef r_zero_point = graph.add_input_tensor( - zero_point_tensor.sizes().vec(), vkapi::kInt, in_storage); + zero_point_tensor.sizes().vec(), + vkapi::kInt, + utils::kBuffer, + utils::kWidthPacked); const ValueRef r_quant_min = graph.add_scalar(quant_min); const ValueRef r_quant_max = graph.add_scalar(quant_max); @@ -744,7 +851,7 @@ void test_vulkan_quantize_per_token_impl( at::Tensor reference_int = reference_out.to(at::kInt); at::Tensor vk_int = vk_out.to(at::kInt); - const bool output_correct = at::equal(reference_int, vk_int); + const bool output_correct = at::allclose(reference_int, vk_int); if (!output_correct) { at::Tensor diffs = at::abs(reference_int - vk_int); @@ -841,3 +948,130 @@ TEST( at::kHalf, at::kByte); } + +TEST( + VulkanQuantizePerTensorTest, + test_vulkan_quantize_per_token_float_to_uint8) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + std::vector scales = { + -0.5, -0.3, -0.2, 0, 0.1, 0.8, 0.1, 0.2, 0.3, 0.4}; + std::vector zero_points = {-8, 0, 15, 20, 19, 12, 47, 1, -50, -12}; + + test_vulkan_quantize_per_token( + {5, 2, 4}, // input sizes (5*2=10 tokens) + scales, + zero_points, + 0, // quant_min + 255, // quant_max + at::kFloat, + at::kByte); +} + +TEST( + VulkanQuantizePerTensorTest, + test_vulkan_quantize_per_token_float_to_int8) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + std::vector scales = { + -0.5, -0.3, -0.2, 0, 0.1, 0.8, 0.1, 0.2, 0.3, 0.4}; + std::vector zero_points = {-8, 0, 15, 20, 19, 12, 47, 1, -50, -12}; + + test_vulkan_quantize_per_token( + {5, 2, 4}, // input sizes (5 tokens) + scales, + zero_points, + -128, // quant_min + 127, // quant_max + at::kFloat, + at::kChar); +} + +TEST( + VulkanQuantizePerTensorTest, + test_vulkan_quantize_per_token_float_to_int32) { + std::vector scales = { + -0.5, -0.3, -0.2, 0, 0.1, 0.8, 0.1, 0.2, 0.3, 0.4}; + std::vector zero_points = {-8, 0, 15, 20, 19, 12, 47, 1, -50, -12}; + + test_vulkan_quantize_per_token( + {5, 2, 4}, // input sizes (5*2=10 tokens) + scales, + zero_points, + -2147483648, // quant_min + 2147483647, // quant_max + at::kFloat, + at::kInt); +} + +TEST( + VulkanQuantizePerTensorTest, + test_vulkan_quantize_per_token_float_to_int32_small_scales) { + std::vector scales = { + 0, + 2.9387358770557188e-39f, + 1.40129846e-45f, + 1.17549435e-38f, + 0.0000000000001}; + std::vector zero_points = {20, -10, 15, 200, 50}; + + test_vulkan_quantize_per_token( + {5, 2}, // input sizes (3 tokens) + scales, + zero_points, + -2147483648, // quant_min + 2147483647, // quant_max + at::kFloat, + at::kInt); +} + +TEST( + VulkanQuantizePerTensorTest, + test_vulkan_quantize_per_token_float_to_uint8_many_tokens) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + std::vector scales(18, 0.1); + std::vector zero_points(18, 5); + + // Alternate scale values + for (size_t i = 0; i < scales.size(); i++) { + scales[i] = (i % 2 == 0) ? 
0.3 : -0.5; + } + + test_vulkan_quantize_per_token( + {3, 3, 2, 3}, // input sizes (3*3*2=18 tokens) + scales, + zero_points, + 0, // quant_min + 125, // quant_max + at::kFloat, + at::kByte); +} + +TEST(VulkanQuantizePerTensorTest, test_vulkan_quantize_per_token_half_to_int8) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_float16_buffers_support()) { + GTEST_SKIP(); + } + std::vector<float> scales = {0.1, 0.2}; + std::vector<int> zero_points = {0, 5}; + + test_vulkan_quantize_per_token( + {2, 2}, // input sizes (2 tokens) + scales, + zero_points, + -128, // quant_min + 127, // quant_max + at::kHalf, // input dtype + at::kChar); // output dtype +} diff --git a/backends/vulkan/test/op_tests/test_utils.cpp b/backends/vulkan/test/op_tests/test_utils.cpp index 196f079be2c..c5702abd079 100644 --- a/backends/vulkan/test/op_tests/test_utils.cpp +++ b/backends/vulkan/test/op_tests/test_utils.cpp @@ -94,7 +94,8 @@ vkcompute::vkapi::ScalarType from_at_scalartype(c10::ScalarType at_scalartype) { case c10::kInt: return vkapi::kInt; case c10::kLong: - return vkapi::kLong; + // No support for 64-bit integers + return vkapi::kInt; case c10::kChar: return vkapi::kChar; case c10::kByte: