Skip to content

Commit 01e4a06

Browse files
author
morelos
committed
[ET-VK][Ops] enabling double support for quantization and dequantization ops
Pull Request resolved: #11553 With the added double support in the layout template, this diff is enabling it as input/output for dequantization. Since there are limitations with how 64bit can be supported, the expectation is that IO be downgraded to 32bit ghstack-source-id: 290041469 @exported-using-ghexport Differential Revision: [D76289197](https://our.internmc.facebook.com/intern/diff/D76289197/)
1 parent d46f3e9 commit 01e4a06

File tree

8 files changed

+96
-2
lines changed

8 files changed

+96
-2
lines changed

backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ dequantize_buffer:
1111
OUT_DTYPE:
1212
- VALUE: half
1313
- VALUE: float
14+
- VALUE: double
1415
shader_variants:
1516
- NAME: dequantize_per_tensor_buffer
1617
MODE: per_tensor

backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.glsl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,10 @@ $if MODE == "per_tensor":
6767
[[unroll]] for (int i = 0; i < 4; ++i) {
6868
IN_T qvalue = IN_T(intex[i]);
6969
OUT_T value = dequantize_val(qvalue, scale, zero_point);
70-
outtex[i] = value;
70+
$if OUT_DTYPE == "double":
71+
outtex[i] = float(value);
72+
$else:
73+
outtex[i] = value;
7174
}
7275
write_texel(t_out, pos, outtex);
7376

@@ -110,7 +113,10 @@ $if MODE == "per_token":
110113
[[unroll]] for (int i = 0; i < 4; ++i) {
111114
IN_T qvalue = IN_T(intex[i]);
112115
OUT_T value = dequantize_val(qvalue, scale_val, zero_point_val);
113-
outtex[i] = value;
116+
$if OUT_DTYPE == "double":
117+
outtex[i] = float(value);
118+
$else:
119+
outtex[i] = value;
114120
}
115121

116122
write_texel(t_out, pos, outtex);

backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ dequantize_texture:
1111
OUT_DTYPE:
1212
- VALUE: half
1313
- VALUE: float
14+
- VALUE: double
1415
shader_variants:
1516
- NAME: dequantize_per_tensor_texture3d
1617
MODE: per_tensor

backends/vulkan/runtime/graph/ops/glsl/quantize_buffer.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ quantize_buffer:
77
IN_DTYPE:
88
- VALUE: half
99
- VALUE: float
10+
- VALUE: double
1011
OUT_DTYPE:
1112
- VALUE: uint8
1213
- VALUE: int8

backends/vulkan/runtime/graph/ops/glsl/quantize_texture.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ quantize_texture:
77
IN_DTYPE:
88
- VALUE: half
99
- VALUE: float
10+
- VALUE: double
1011
OUT_DTYPE:
1112
- VALUE: uint8
1213
- VALUE: int8

backends/vulkan/runtime/graph/ops/impl/Quantize.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,7 @@ void quantize_per_tensor_impl(
162162

163163
// Verify input is a floating point type
164164
VK_CHECK_COND(
165+
graph.dtype_of(input) == vkapi::kDouble ||
165166
graph.dtype_of(input) == vkapi::kFloat ||
166167
graph.dtype_of(input) == vkapi::kHalf);
167168

@@ -185,6 +186,7 @@ void quantize_per_token_impl(
185186

186187
// Verify input is a floating point type
187188
VK_CHECK_COND(
189+
graph.dtype_of(input) == vkapi::kDouble ||
188190
graph.dtype_of(input) == vkapi::kFloat ||
189191
graph.dtype_of(input) == vkapi::kHalf);
190192

backends/vulkan/test/op_tests/dequantize_test.cpp

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -364,6 +364,12 @@ void test_vulkan_dequantize_per_tensor(
364364
vkcompute::utils::kBuffer,
365365
vkcompute::utils::kBuffer);
366366

367+
// Telling the system to expect a float instead of a double
368+
// since the shader can only return 32bit anyways
369+
if (out_dtype == at::kDouble) {
370+
out_dtype = at::kFloat;
371+
}
372+
367373
// Test with texture storage
368374
test_vulkan_dequantize_per_tensor_impl(
369375
input_sizes,
@@ -398,6 +404,12 @@ void test_vulkan_dequantize_per_token(
398404
vkcompute::utils::kBuffer,
399405
vkcompute::utils::kBuffer);
400406

407+
// Telling the system to expect a float instead of a double
408+
// since the shader can only return 32bit anyways
409+
if (out_dtype == at::kDouble) {
410+
out_dtype = at::kFloat;
411+
}
412+
401413
// Test with texture storage
402414
test_vulkan_dequantize_per_token_impl(
403415
input_sizes,
@@ -767,6 +779,19 @@ TEST(
767779
at::kHalf); // output dtype
768780
}
769781

782+
TEST(
783+
VulkanDequantizePerTensorTest,
784+
test_vulkan_dequantize_per_tensor_int32_to_double) {
785+
test_vulkan_dequantize_per_tensor(
786+
{2, 4, 3}, // input sizes
787+
0.0001, // scale
788+
100, // zero_point
789+
-2147483648, // quant_min
790+
2147483647, // quant_max
791+
at::kInt, // input dtype
792+
at::kDouble); // output dtype
793+
}
794+
770795
void test_reference_dequantize_per_token(
771796
const std::vector<int>& input_sizes,
772797
const std::vector<float>& scales,
@@ -1232,3 +1257,19 @@ TEST(
12321257
at::kInt, // input dtype
12331258
at::kHalf); // output dtype
12341259
}
1260+
1261+
TEST(
1262+
VulkanDequantizePerTokenTest,
1263+
test_vulkan_dequantize_per_token_int32_to_double) {
1264+
std::vector<float> scales = {0.0001, 0.0002, 0.0003, 0.0};
1265+
std::vector<int> zero_points = {100, -100, 50, -50};
1266+
1267+
test_vulkan_dequantize_per_token(
1268+
{2, 2, 8}, // input sizes (2*2=4 tokens)
1269+
scales,
1270+
zero_points,
1271+
-2147483648, // quant_min
1272+
2147483647, // quant_max
1273+
at::kInt, // input dtype
1274+
at::kDouble); // output dtype
1275+
}

backends/vulkan/test/op_tests/quantize_test.cpp

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -314,6 +314,12 @@ void test_vulkan_quantize_per_tensor(
314314
vkcompute::utils::kBuffer,
315315
vkcompute::utils::kBuffer);
316316

317+
// If the in_dtype is a double, convert to float for texture implementation
318+
// since they don't support 64bit as inputs
319+
if (in_dtype == at::kDouble) {
320+
in_dtype = at::kFloat;
321+
}
322+
317323
// Test with texture storage
318324
test_vulkan_quantize_per_tensor_impl(
319325
input_sizes,
@@ -348,6 +354,12 @@ void test_vulkan_quantize_per_token(
348354
vkcompute::utils::kBuffer,
349355
vkcompute::utils::kBuffer);
350356

357+
// If the in_dtype is a double, convert to float for texture implementation
358+
// since they don't support 64bit as inputs
359+
if (in_dtype == at::kDouble) {
360+
in_dtype = at::kFloat;
361+
}
362+
351363
// Test with texture storage
352364
test_vulkan_quantize_per_token_impl(
353365
input_sizes,
@@ -639,6 +651,19 @@ TEST(
639651
at::kChar); // output dtype
640652
}
641653

654+
TEST(
655+
VulkanQuantizePerTensorTest,
656+
test_vulkan_quantize_per_tensor_double_to_int8) {
657+
test_vulkan_quantize_per_tensor(
658+
{2, 3}, // input sizes
659+
0.01, // scale
660+
1, // zero_point
661+
-128, // quant_min
662+
127, // quant_max
663+
at::kDouble, // input dtype
664+
at::kChar); // output dtype
665+
}
666+
642667
void test_reference_quantize_per_token(
643668
const std::vector<int>& input_sizes,
644669
const std::vector<float>& pre_scales,
@@ -1033,3 +1058,19 @@ TEST(VulkanQuantizePerTensorTest, test_vulkan_quantize_per_token_half_to_int8) {
10331058
at::kHalf, // input dtype
10341059
at::kChar); // output dtype
10351060
}
1061+
1062+
TEST(
1063+
VulkanQuantizePerTensorTest,
1064+
test_vulkan_quantize_per_token_double_to_int8) {
1065+
std::vector<float> scales = {0.1, 0.2};
1066+
std::vector<int> zero_points = {0, 5};
1067+
1068+
test_vulkan_quantize_per_token(
1069+
{2, 2}, // input sizes (2*2=4 tokens)
1070+
scales,
1071+
zero_points,
1072+
-128, // quant_min
1073+
127, // quant_max
1074+
at::kDouble, // input dtype
1075+
at::kChar); // output dtype
1076+
}

0 commit comments

Comments
 (0)