[ET-VK][Ops] choose_qparams op shaders and impl #11557

Open · wants to merge 12 commits into base: gh/ahmtox/22/base
70 changes: 70 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/choose_qparams.glslh
@@ -0,0 +1,70 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#ifndef CHOOSE_QPARAMS_GLSLH
#define CHOOSE_QPARAMS_GLSLH

// Equivalent of the eps defined in the CPU implementation
#define SMALL_SCALE_THRESHOLD 6.1e-5

// Calculate scale and zero point from min and max values
void calculate_scale_and_zero_point(
float min_val,
float max_val,
int qmin,
int qmax,
out float scale_val,
out int zero_point_val) {
// ensure we have zero included in our range
min_val = min(min_val, 0.0);
max_val = max(max_val, 0.0);

scale_val = (max_val - min_val) / float(qmax - qmin);

// Handle zero or very small scale
if (scale_val == 0.0 || isinf(1.0 / scale_val)) {
scale_val = 0.1;
}

// Cut off small scale
if (scale_val < SMALL_SCALE_THRESHOLD) {
float org_scale = scale_val;
scale_val = SMALL_SCALE_THRESHOLD;

// Adjust min and max based on new scale
if (min_val == 0.0) {
max_val = SMALL_SCALE_THRESHOLD * float(qmax - qmin);
} else if (max_val == 0.0) {
min_val = -SMALL_SCALE_THRESHOLD * float(qmax - qmin);
} else {
float amplifier = SMALL_SCALE_THRESHOLD / org_scale;
min_val *= amplifier;
max_val *= amplifier;
}
}

// Calculate zero point
float zero_point_from_min = float(qmin) - min_val / scale_val;
float zero_point_from_max = float(qmax) - max_val / scale_val;
float zero_point_from_min_error = abs(float(qmin)) - abs(min_val / scale_val);
float zero_point_from_max_error = abs(float(qmax)) - abs(max_val / scale_val);
float initial_zero_point = zero_point_from_min_error < zero_point_from_max_error
? zero_point_from_min
: zero_point_from_max;

// Nudge the zero point to an integer within [qmin, qmax]
if (initial_zero_point < float(qmin)) {
zero_point_val = qmin;
} else if (initial_zero_point > float(qmax)) {
zero_point_val = qmax;
} else {
zero_point_val = int(round(initial_zero_point));
}
}

#endif // CHOOSE_QPARAMS_GLSLH
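
For reference, the helper above is a GLSL port of the eps-clamped scale/zero-point selection used by the CPU quantization utilities. A minimal host-side C++ restatement of the same logic can be useful for validating shader output on the CPU; this is an illustrative sketch only (not part of this PR, and the constant/function names are placeholders):

```cpp
#include <algorithm>
#include <cmath>
#include <cstdio>

// Mirrors SMALL_SCALE_THRESHOLD in choose_qparams.glslh.
constexpr float kSmallScaleThreshold = 6.1e-5f;

// Scalar restatement of calculate_scale_and_zero_point() for host-side checks.
void calculate_scale_and_zero_point(
    float min_val, float max_val, int qmin, int qmax,
    float& scale_val, int& zero_point_val) {
  // Ensure zero is representable in the quantized range.
  min_val = std::min(min_val, 0.0f);
  max_val = std::max(max_val, 0.0f);

  scale_val = (max_val - min_val) / static_cast<float>(qmax - qmin);

  // Handle zero or very small scale.
  if (scale_val == 0.0f || std::isinf(1.0f / scale_val)) {
    scale_val = 0.1f;
  }

  // Cut off small scale, widening [min, max] to stay consistent with it.
  if (scale_val < kSmallScaleThreshold) {
    const float org_scale = scale_val;
    scale_val = kSmallScaleThreshold;
    if (min_val == 0.0f) {
      max_val = kSmallScaleThreshold * static_cast<float>(qmax - qmin);
    } else if (max_val == 0.0f) {
      min_val = -kSmallScaleThreshold * static_cast<float>(qmax - qmin);
    } else {
      const float amplifier = kSmallScaleThreshold / org_scale;
      min_val *= amplifier;
      max_val *= amplifier;
    }
  }

  // Pick the candidate zero point with the smaller error term, exactly as the
  // shader does.
  const float zp_from_min = qmin - min_val / scale_val;
  const float zp_from_max = qmax - max_val / scale_val;
  const float zp_from_min_err =
      std::abs(static_cast<float>(qmin)) - std::abs(min_val / scale_val);
  const float zp_from_max_err =
      std::abs(static_cast<float>(qmax)) - std::abs(max_val / scale_val);
  const float initial_zp =
      zp_from_min_err < zp_from_max_err ? zp_from_min : zp_from_max;

  // Clamp into [qmin, qmax] and round to an integer.
  zero_point_val = static_cast<int>(std::round(
      std::clamp(initial_zp, static_cast<float>(qmin), static_cast<float>(qmax))));
}

int main() {
  float scale;
  int zp;
  calculate_scale_and_zero_point(-1.5f, 3.0f, -128, 127, scale, zp);
  std::printf("scale=%f zero_point=%d\n", scale, zp);
  // scale ~= 0.017647, zero_point = -43
}
```

With min = -1.5 and max = 3.0 over the int8 range, both candidate zero points agree at -43 because 0 is exactly representable, so the error-term comparison does not change the result in this example.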
203 changes: 203 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/choose_qparams_buffer.glsl
@@ -0,0 +1,203 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#version 450 core

#define PRECISION ${PRECISION}

#define IN_T ${buffer_scalar_type(IN_DTYPE)}

${define_active_storage_type("buffer")}
${define_required_extensions(IN_DTYPE)}

#extension GL_EXT_control_flow_attributes : require

layout(std430) buffer;

${layout_declare_tensor(B, "r", "t_in", IN_DTYPE, "buffer")}
${layout_declare_tensor(B, "w", "t_scale", "float", "buffer")}
${layout_declare_tensor(B, "w", "t_zero_point", "int", "buffer")}

$if MODE == "per_tensor":
layout(push_constant) uniform restrict Block {
int quant_min;
int quant_max;
};
$else:
layout(push_constant) uniform restrict Block {
int num_tokens;
int quant_min;
int quant_max;
};

${layout_declare_ubo(B, "ivec4", "t_in_sizes")}
${layout_declare_ubo(B, "ivec4", "t_in_strides")}
${layout_declare_ubo(B, "ivec4", "t_scale_sizes")}
${layout_declare_ubo(B, "ivec4", "t_scale_strides")}
${layout_declare_ubo(B, "ivec4", "t_zero_point_sizes")}
${layout_declare_ubo(B, "ivec4", "t_zero_point_strides")}

#include "indexing_utils.h"
#include "choose_qparams.glslh"

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

#define NWORKERS 64

// Shared memory for reduction - must match local work group size
shared float shared_min[NWORKERS];
shared float shared_max[NWORKERS];

void main() {
$if MODE == "per_tensor":
uint global_id = gl_GlobalInvocationID.x;
uint local_id = gl_LocalInvocationID.x;
uint group_id = gl_WorkGroupID.x;
uint total_threads = gl_NumWorkGroups.x * gl_WorkGroupSize.x;

uint total_elements = uint(t_in_sizes.x * t_in_sizes.y * t_in_sizes.z * t_in_sizes.w);

// Each thread processes multiple elements with stride
float thread_min = 1.0/0.0; // +infinity
float thread_max = -1.0/0.0; // -infinity
bool found_valid = false;

for (uint i = global_id; i < total_elements; i += total_threads) {
float val = t_in[i];
if (!isnan(val) && !isinf(val)) {
if (!found_valid) {
thread_min = val;
thread_max = val;
found_valid = true;
} else {
thread_min = min(thread_min, val);
thread_max = max(thread_max, val);
}
}
}

// Intra-group reduction using shared memory
shared_min[local_id] = thread_min;
shared_max[local_id] = thread_max;
barrier();

// Tree reduction within work group
for (uint stride = gl_WorkGroupSize.x / 2; stride > 0; stride >>= 1) {
if (local_id < stride) {
float other_min = shared_min[local_id + stride];
float other_max = shared_max[local_id + stride];

if (!isinf(other_min) && (isinf(shared_min[local_id]) || other_min < shared_min[local_id])) {
shared_min[local_id] = other_min;
}
if (!isinf(other_max) && (isinf(shared_max[local_id]) || other_max > shared_max[local_id])) {
shared_max[local_id] = other_max;
}
}
barrier();
}

// Final result calculation (single workgroup only)
if (local_id == 0) {
float global_min = shared_min[0];
float global_max = shared_max[0];

float scale_val;
int zero_point_val;
calculate_scale_and_zero_point(global_min, global_max, quant_min, quant_max, scale_val, zero_point_val);

t_scale[0] = scale_val;
t_zero_point[0] = zero_point_val;
}

$if MODE == "per_token":
uint global_id = gl_GlobalInvocationID.x;
uint local_id = gl_LocalInvocationID.x;
uint group_id = gl_WorkGroupID.x;
uint total_workgroups = gl_NumWorkGroups.x;

uint total_elements = uint(t_in_sizes.x * t_in_sizes.y * t_in_sizes.z * t_in_sizes.w);
uint token_size = total_elements / uint(num_tokens);

// Calculate how many tokens each workgroup should process
// This handles the case where we have more tokens than workgroups
uint tokens_per_workgroup = (uint(num_tokens) + total_workgroups - 1) / total_workgroups;

// Calculate which tokens this workgroup is responsible for
uint start_token = group_id * tokens_per_workgroup;
uint end_token = min(start_token + tokens_per_workgroup, uint(num_tokens));

// Early exit if this workgroup has no tokens to process
if (start_token >= uint(num_tokens)) {
return;
}

// Process each token assigned to this workgroup
for (uint token_id = start_token; token_id < end_token; token_id++) {
// Calculate the start and end indices for this token
uint token_start = token_id * token_size;
uint token_end = token_start + token_size;

// Each thread processes multiple elements within the token with stride
float thread_min = 1.0/0.0; // +infinity
float thread_max = -1.0/0.0; // -infinity
bool found_valid = false;

// Process elements within this token only
for (uint i = token_start + local_id; i < token_end; i += gl_WorkGroupSize.x) {
float val = t_in[i];
if (!isnan(val) && !isinf(val)) {
if (!found_valid) {
thread_min = val;
thread_max = val;
found_valid = true;
} else {
thread_min = min(thread_min, val);
thread_max = max(thread_max, val);
}
}
}

// Intra-group reduction using shared memory
shared_min[local_id] = thread_min;
shared_max[local_id] = thread_max;
barrier();

// Tree reduction within work group
for (uint stride = gl_WorkGroupSize.x / 2; stride > 0; stride >>= 1) {
if (local_id < stride) {
float other_min = shared_min[local_id + stride];
float other_max = shared_max[local_id + stride];

if (!isinf(other_min) && (isinf(shared_min[local_id]) || other_min < shared_min[local_id])) {
shared_min[local_id] = other_min;
}
if (!isinf(other_max) && (isinf(shared_max[local_id]) || other_max > shared_max[local_id])) {
shared_max[local_id] = other_max;
}
}
barrier();
}

// Final calculation for this token
if (local_id == 0) {
float token_min = shared_min[0];
float token_max = shared_max[0];

float scale_val;
int zero_point_val;
calculate_scale_and_zero_point(token_min, token_max, quant_min, quant_max, scale_val, zero_point_val);

t_scale[token_id] = scale_val;
t_zero_point[token_id] = zero_point_val;
}

// Synchronize before processing next token
barrier();
}
}
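
The per_token path above assigns ceil(num_tokens / num_workgroups) tokens to each workgroup, so every token is still covered when the dispatch launches fewer workgroups than there are tokens, and trailing workgroups exit early when the division leaves them with nothing. A small host-side C++ sketch of that mapping, with assumed example values (not part of this PR):

```cpp
#include <algorithm>
#include <cstdio>

int main() {
  const unsigned num_tokens = 10;       // e.g. a [2, 5, 64] input quantized per token
  const unsigned total_workgroups = 4;  // dispatch size chosen by the host

  // Same ceil-division the shader performs.
  const unsigned tokens_per_workgroup =
      (num_tokens + total_workgroups - 1) / total_workgroups;  // = 3

  for (unsigned group_id = 0; group_id < total_workgroups; ++group_id) {
    const unsigned start_token = group_id * tokens_per_workgroup;
    const unsigned end_token =
        std::min(start_token + tokens_per_workgroup, num_tokens);
    if (start_token >= num_tokens) {
      std::printf("workgroup %u: no tokens (early exit)\n", group_id);
      continue;
    }
    std::printf("workgroup %u: tokens [%u, %u)\n", group_id, start_token, end_token);
  }
  // Output: wg0 -> [0,3), wg1 -> [3,6), wg2 -> [6,9), wg3 -> [9,10)
}
```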
@@ -0,0 +1,12 @@
choose_qparams_buffer:
parameter_names_with_default_values:
IN_DTYPE: float
MODE: per_tensor
generate_variant_forall:
IN_DTYPE:
- VALUE: float
shader_variants:
- NAME: choose_qparams_tensor_buffer
MODE: per_tensor
- NAME: choose_qparams_per_token_asymmetric_buffer
MODE: per_token