Skip to content

Commit 34816ac

Browse files
authored
[ET-VK][Ops] choose_qparams op shaders and impl
Differential Revision: D76436933 Pull Request resolved: #11557
1 parent b324f8b commit 34816ac

File tree

7 files changed

+1213
-0
lines changed

7 files changed

+1213
-0
lines changed
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#ifndef CHOOSE_QPARAMS_GLSLH
10+
#define CHOOSE_QPARAMS_GLSLH
11+
12+
// equivalent of the eps defined in the cpu implementation
13+
#define SMALL_SCALE_THRESHOLD 6.1e-5
14+
15+
// Calculate scale and zero point from min and max values
16+
void calculate_scale_and_zero_point(
17+
float min_val,
18+
float max_val,
19+
int qmin,
20+
int qmax,
21+
out float scale_val,
22+
out int zero_point_val) {
23+
// ensure we have zero included in our range
24+
min_val = min(min_val, 0.0);
25+
max_val = max(max_val, 0.0);
26+
27+
scale_val = (max_val - min_val) / float(qmax - qmin);
28+
29+
// Handle zero or very small scale
30+
if (scale_val == 0.0 || isinf(1.0 / scale_val)) {
31+
scale_val = 0.1;
32+
}
33+
34+
// Cut off small scale
35+
if (scale_val < SMALL_SCALE_THRESHOLD) {
36+
float org_scale = scale_val;
37+
scale_val = SMALL_SCALE_THRESHOLD;
38+
39+
// Adjust min and max based on new scale
40+
if (min_val == 0.0) {
41+
max_val = SMALL_SCALE_THRESHOLD * float(qmax - qmin);
42+
} else if (max_val == 0.0) {
43+
min_val = -SMALL_SCALE_THRESHOLD * float(qmax - qmin);
44+
} else {
45+
float amplifier = SMALL_SCALE_THRESHOLD / org_scale;
46+
min_val *= amplifier;
47+
max_val *= amplifier;
48+
}
49+
}
50+
51+
// Calculate zero point
52+
float zero_point_from_min = float(qmin) - min_val / scale_val;
53+
float zero_point_from_max = float(qmax) - max_val / scale_val;
54+
float zero_point_from_min_error = abs(float(qmin)) - abs(min_val / scale_val);
55+
float zero_point_from_max_error = abs(float(qmax)) - abs(max_val / scale_val);
56+
float initial_zero_point = zero_point_from_min_error < zero_point_from_max_error
57+
? zero_point_from_min
58+
: zero_point_from_max;
59+
60+
// Nudge zero point to integer
61+
if (initial_zero_point < float(qmin)) {
62+
zero_point_val = qmin;
63+
} else if (initial_zero_point > float(qmax)) {
64+
zero_point_val = qmax;
65+
} else {
66+
zero_point_val = int(round(initial_zero_point));
67+
}
68+
}
69+
70+
#endif // CHOOSE_QPARAMS_GLSLH
Lines changed: 278 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,278 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#version 450 core
10+
11+
#define PRECISION ${PRECISION}
12+
13+
#define IN_T ${buffer_scalar_type(IN_DTYPE)}
14+
15+
#define ${MODE}
16+
17+
${define_active_storage_type("buffer")}
18+
${define_required_extensions(IN_DTYPE)}
19+
20+
#extension GL_EXT_control_flow_attributes : require
21+
22+
layout(std430) buffer;
23+
24+
${layout_declare_tensor(B, "w", "t_scale", "float", "buffer")}
25+
${layout_declare_tensor(B, "w", "t_zero_point", "int", "buffer")}
26+
${layout_declare_tensor(B, "r", "t_in", IN_DTYPE, "buffer")}
27+
28+
$if MODE == "per_tensor":
29+
layout(push_constant) uniform restrict Block {
30+
int quant_min;
31+
int quant_max;
32+
};
33+
$else:
34+
layout(push_constant) uniform restrict Block {
35+
int num_tokens;
36+
int quant_min;
37+
int quant_max;
38+
};
39+
40+
${layout_declare_ubo(B, "ivec4", "t_in_sizes")}
41+
${layout_declare_ubo(B, "ivec4", "t_in_strides")}
42+
${layout_declare_ubo(B, "ivec4", "t_scale_sizes")}
43+
${layout_declare_ubo(B, "ivec4", "t_scale_strides")}
44+
${layout_declare_ubo(B, "ivec4", "t_zero_point_sizes")}
45+
${layout_declare_ubo(B, "ivec4", "t_zero_point_strides")}
46+
47+
#include "indexing_utils.h"
48+
#include "choose_qparams.glslh"
49+
50+
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
51+
52+
#define NWORKERS 64
53+
54+
// Shared memory for reduction - must match local work group size
55+
shared float shared_min[NWORKERS];
56+
shared float shared_max[NWORKERS];
57+
58+
/*
59+
* QUANTIZATION PARAMETER COMPUTATION SHADER (BUFFER STORAGE)
60+
*
61+
* This shader computes quantization parameters (scale and zero_point) for converting
62+
* floating-point tensors to n-bit integer representations while preserving the
63+
* original data range as much as possible.
64+
*
65+
* ALGORITHM:
66+
* 1. Find global min/max values across tensor elements using parallel reduction
67+
* 2. Use tree reduction with shared memory for efficient min/max computation
68+
* 3. Calculate scale = (max - min) / (quant_max - quant_min)
69+
* 4. Calculate zero_point to map floating-point zero to integer value
70+
*
71+
* WORKGROUP CONFIGURATION:
72+
* - Per-Tensor Mode:
73+
* - Global WG Size: {1, 1, 1} (single workgroup processes entire tensor)
74+
* - Local WG Size: {64, 1, 1} (matches NWORKERS for shared memory)
75+
* - Per-Token Mode:
76+
* - Global WG Size: {num_tokens, 1, 1} (one workgroup per token)
77+
* - Local WG Size: {64, 1, 1} (matches NWORKERS for shared memory)
78+
*
79+
* SUPPORTED CONFIGURATIONS:
80+
* - Buffer Storage: Uses simple linear indexing through buffer elements
81+
* - No axis mapping or packing considerations - processes elements sequentially
82+
* - Works with any tensor layout since it accesses buffer data linearly
83+
*
84+
* TREE REDUCTION VISUALIZATION FOR MIN/MAX FINDING:
85+
* For 8 threads processing elements [10, 1, 8, 1, 0, 2, 3, 5]:
86+
*
87+
* Initial shared_min/shared_max arrays populated by each thread:
88+
* shared_min: | 10 | 1 | 8 | 1 | 0 | 2 | 3 | 5 |
89+
* shared_max: | 10 | 1 | 8 | 1 | 0 | 2 | 3 | 5 |
90+
* Thread: | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 |
91+
*
92+
* Stride 1 (compare pairs, keep min/max):
93+
* shared_min: | 1 | | 1 | | 0 | | 3 | | (min(10,1), min(8,1), min(0,2), min(3,5))
94+
* shared_max: | 10 | | 8 | | 2 | | 5 | | (max(10,1), max(8,1), max(0,2), max(3,5))
95+
* Active: | 0 | | 2 | | 4 | | 6 | |
96+
*
97+
* Stride 2 (compare pairs, keep min/max):
98+
* shared_min: | 0 | | | | 0 | | | | (min(1,1), min(0,3))
99+
* shared_max: | 10 | | | | 5 | | | | (max(10,8), max(2,5))
100+
* Active: | 0 | | | | 4 | | | |
101+
*
102+
* Stride 4 (final comparison):
103+
* shared_min: | 0 | | | | | | | | (min(0,0) = 0)
104+
* shared_max: | 10 | | | | | | | | (max(10,5) = 10)
105+
* Active: | 0 | | | | | | | |
106+
*
107+
* Final result: global_min = 0, global_max = 10 (stored in shared_min[0], shared_max[0])
108+
*
109+
* PER-TENSOR QUANTIZATION:
110+
* - Single workgroup processes entire tensor with strided access
111+
* - Each thread processes elements [thread_id, thread_id + 64, thread_id + 128, ...]
112+
* - Tree reduction combines all thread results into global min/max
113+
* - Output: Single scale and zero_point values
114+
*
115+
* PER-TOKEN QUANTIZATION:
116+
* - Multiple workgroups, each processing one token
117+
* - Token = all elements except last dimension (e.g., for [B,S,H]: B*S tokens of H elements)
118+
* - Each workgroup finds min/max within its assigned token
119+
* - Output: Array of scale and zero_point values (one per token)
120+
*/
121+
122+
#ifdef per_tensor
123+
124+
void choose_qparams_per_tensor() {
125+
uint global_id = gl_GlobalInvocationID.x;
126+
uint local_id = gl_LocalInvocationID.x;
127+
uint total_threads = gl_NumWorkGroups.x * gl_WorkGroupSize.x;
128+
129+
uint total_elements = uint(t_in_sizes.x * t_in_sizes.y * t_in_sizes.z * t_in_sizes.w);
130+
131+
// Each thread processes multiple elements with stride
132+
float thread_min = 1.0/0.0; // +infinity
133+
float thread_max = -1.0/0.0; // -infinity
134+
bool found_valid = false;
135+
136+
for (uint i = global_id; i < total_elements; i += total_threads) {
137+
float val = t_in[i];
138+
if (!isnan(val) && !isinf(val)) {
139+
if (!found_valid) {
140+
thread_min = val;
141+
thread_max = val;
142+
found_valid = true;
143+
} else {
144+
thread_min = min(thread_min, val);
145+
thread_max = max(thread_max, val);
146+
}
147+
}
148+
}
149+
150+
// Intra-group reduction using shared memory
151+
shared_min[local_id] = thread_min;
152+
shared_max[local_id] = thread_max;
153+
barrier();
154+
155+
// Tree reduction within work group
156+
for (uint stride = gl_WorkGroupSize.x / 2; stride > 0; stride >>= 1) {
157+
if (local_id < stride) {
158+
float other_min = shared_min[local_id + stride];
159+
float other_max = shared_max[local_id + stride];
160+
161+
if (!isinf(other_min) && (isinf(shared_min[local_id]) || other_min < shared_min[local_id])) {
162+
shared_min[local_id] = other_min;
163+
}
164+
if (!isinf(other_max) && (isinf(shared_max[local_id]) || other_max > shared_max[local_id])) {
165+
shared_max[local_id] = other_max;
166+
}
167+
}
168+
barrier();
169+
}
170+
171+
// Final result calculation (single workgroup only)
172+
if (local_id == 0) {
173+
float global_min = shared_min[0];
174+
float global_max = shared_max[0];
175+
176+
float scale_val;
177+
int zero_point_val;
178+
calculate_scale_and_zero_point(global_min, global_max, quant_min, quant_max, scale_val, zero_point_val);
179+
180+
t_scale[0] = scale_val;
181+
t_zero_point[0] = zero_point_val;
182+
}
183+
}
184+
185+
#else
186+
187+
void choose_qparams_per_token() {
188+
uint global_id = gl_GlobalInvocationID.x;
189+
uint local_id = gl_LocalInvocationID.x;
190+
uint group_id = gl_WorkGroupID.x;
191+
uint total_workgroups = gl_NumWorkGroups.x;
192+
193+
uint total_elements = uint(t_in_sizes.x * t_in_sizes.y * t_in_sizes.z * t_in_sizes.w);
194+
uint token_size = total_elements / uint(num_tokens);
195+
196+
// Calculate how many tokens each workgroup should process
197+
// This handles the case where we have more tokens than workgroups
198+
uint tokens_per_workgroup = (uint(num_tokens) + total_workgroups - 1) / total_workgroups;
199+
200+
// Calculate which tokens this workgroup is responsible for
201+
uint start_token = group_id * tokens_per_workgroup;
202+
uint end_token = min(start_token + tokens_per_workgroup, uint(num_tokens));
203+
204+
// Early exit if this workgroup has no tokens to process
205+
if (start_token >= uint(num_tokens)) {
206+
return;
207+
}
208+
209+
// Process each token assigned to this workgroup
210+
for (uint token_id = start_token; token_id < end_token; token_id++) {
211+
// Calculate the start and end indices for this token
212+
uint token_start = token_id * token_size;
213+
uint token_end = token_start + token_size;
214+
215+
// Each thread processes multiple elements within the token with stride
216+
float thread_min = 1.0/0.0; // +infinity
217+
float thread_max = -1.0/0.0; // -infinity
218+
bool found_valid = false;
219+
220+
// Process elements within this token only
221+
for (uint i = token_start + local_id; i < token_end; i += gl_WorkGroupSize.x) {
222+
float val = t_in[i];
223+
if (!isnan(val) && !isinf(val)) {
224+
if (!found_valid) {
225+
thread_min = val;
226+
thread_max = val;
227+
found_valid = true;
228+
} else {
229+
thread_min = min(thread_min, val);
230+
thread_max = max(thread_max, val);
231+
}
232+
}
233+
}
234+
235+
// Intra-group reduction using shared memory
236+
shared_min[local_id] = thread_min;
237+
shared_max[local_id] = thread_max;
238+
barrier();
239+
240+
// Tree reduction within work group
241+
for (uint stride = gl_WorkGroupSize.x / 2; stride > 0; stride >>= 1) {
242+
if (local_id < stride) {
243+
float other_min = shared_min[local_id + stride];
244+
float other_max = shared_max[local_id + stride];
245+
246+
if (!isinf(other_min) && (isinf(shared_min[local_id]) || other_min < shared_min[local_id])) {
247+
shared_min[local_id] = other_min;
248+
}
249+
if (!isinf(other_max) && (isinf(shared_max[local_id]) || other_max > shared_max[local_id])) {
250+
shared_max[local_id] = other_max;
251+
}
252+
}
253+
barrier();
254+
}
255+
256+
// Final calculation for this token
257+
if (local_id == 0) {
258+
float token_min = shared_min[0];
259+
float token_max = shared_max[0];
260+
261+
float scale_val;
262+
int zero_point_val;
263+
calculate_scale_and_zero_point(token_min, token_max, quant_min, quant_max, scale_val, zero_point_val);
264+
265+
t_scale[token_id] = scale_val;
266+
t_zero_point[token_id] = zero_point_val;
267+
}
268+
269+
// Synchronize before processing next token
270+
barrier();
271+
}
272+
}
273+
274+
#endif
275+
276+
void main() {
277+
choose_qparams_${MODE}();
278+
}
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
choose_qparams_buffer:
2+
parameter_names_with_default_values:
3+
IN_DTYPE: float
4+
MODE: per_tensor
5+
generate_variant_forall:
6+
IN_DTYPE:
7+
- VALUE: float
8+
shader_variants:
9+
- NAME: choose_qparams_tensor_buffer
10+
MODE: per_tensor
11+
- NAME: choose_qparams_per_token_asymmetric_buffer
12+
MODE: per_token

0 commit comments

Comments
 (0)