Skip to content

Commit a9a6301

Browse files
authored
[ET-VK][Ops] quantize ops skeleton test framework
Differential Revision: D75959066 Pull Request resolved: #11366
1 parent 7377b80 commit a9a6301

File tree

2 files changed

+167
-0
lines changed

2 files changed

+167
-0
lines changed
Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,158 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <gtest/gtest.h>
10+
11+
#include <ATen/ATen.h>
12+
13+
#include <executorch/backends/vulkan/runtime/api/api.h>
14+
#include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h>
15+
#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
16+
17+
#include <executorch/extension/aten_util/make_aten_functor_from_et_functor.h>
18+
#include <executorch/extension/kernel_util/make_boxed_from_unboxed_functor.h>
19+
20+
#include "test_utils.h"
21+
22+
#include <cassert>
23+
#include <iostream>
24+
25+
namespace torch {
26+
namespace executor {
27+
namespace native {
28+
29+
// Forward declarations of the functions we're testing
30+
Tensor& quantize_per_tensor_out(
31+
const Tensor& input,
32+
double scale,
33+
int64_t zero_point,
34+
int64_t quant_min,
35+
int64_t quant_max,
36+
ScalarType dtype,
37+
Tensor& out);
38+
39+
Tensor& quantize_per_token_out(
40+
const Tensor& input,
41+
const Tensor& scale,
42+
const Tensor& zero_point,
43+
int64_t quant_min,
44+
int64_t quant_max,
45+
ScalarType dtype,
46+
Tensor& out);
47+
48+
// Wrapper function for quantize_per_tensor_out without context
49+
Tensor& quantize_per_tensor_out_no_context(
50+
const Tensor& input,
51+
double scale,
52+
int64_t zero_point,
53+
int64_t quant_min,
54+
int64_t quant_max,
55+
ScalarType dtype,
56+
Tensor& out) {
57+
return torch::executor::native::quantize_per_tensor_out(
58+
input, scale, zero_point, quant_min, quant_max, dtype, out);
59+
}
60+
61+
// Wrapper function for quantize_per_token_out without context
62+
Tensor& quantize_per_token_out_no_context(
63+
const Tensor& input,
64+
const Tensor& scale,
65+
const Tensor& zero_point,
66+
int64_t quant_min,
67+
int64_t quant_max,
68+
ScalarType dtype,
69+
Tensor& out) {
70+
return torch::executor::native::quantize_per_token_out(
71+
input, scale, zero_point, quant_min, quant_max, dtype, out);
72+
}
73+
74+
// ATen wrapper for quantize_per_tensor
75+
at::Tensor quantize_per_tensor_aten(
76+
const at::Tensor& input,
77+
double scale,
78+
int64_t zero_point,
79+
int64_t quant_min,
80+
int64_t quant_max,
81+
at::ScalarType dtype) {
82+
auto out = at::empty_like(input, dtype);
83+
ScalarType et_dtype = at_scalartype_to_et_scalartype(dtype);
84+
85+
WRAP_TO_ATEN(quantize_per_tensor_out_no_context, 6)
86+
(input, scale, zero_point, quant_min, quant_max, et_dtype, out);
87+
return out;
88+
}
89+
90+
// ATen wrapper for quantize_per_token
91+
at::Tensor quantize_per_token_aten(
92+
const at::Tensor& input,
93+
const at::Tensor& scale,
94+
const at::Tensor& zero_point,
95+
int64_t quant_min,
96+
int64_t quant_max,
97+
at::ScalarType dtype) {
98+
auto out = at::empty_like(input, dtype);
99+
ScalarType et_dtype = at_scalartype_to_et_scalartype(dtype);
100+
101+
WRAP_TO_ATEN(quantize_per_token_out_no_context, 6)
102+
(input, scale, zero_point, quant_min, quant_max, et_dtype, out);
103+
return out;
104+
}
105+
106+
} // namespace native
107+
} // namespace executor
108+
} // namespace torch
109+
110+
void check_quantize_args(
111+
int64_t quant_min,
112+
int64_t quant_max,
113+
c10::ScalarType out_dtype) {
114+
using namespace vkcompute;
115+
int32_t quant_min_lower_bound = 0, quant_max_upper_bound = 0;
116+
switch (out_dtype) {
117+
case c10::kByte:
118+
quant_min_lower_bound =
119+
static_cast<int32_t>(std::numeric_limits<uint8_t>::min());
120+
quant_max_upper_bound =
121+
static_cast<int32_t>(std::numeric_limits<uint8_t>::max());
122+
break;
123+
case c10::kChar:
124+
quant_min_lower_bound =
125+
static_cast<int32_t>(std::numeric_limits<int8_t>::min());
126+
quant_max_upper_bound =
127+
static_cast<int32_t>(std::numeric_limits<int8_t>::max());
128+
break;
129+
case c10::kBits16:
130+
case c10::kUInt16:
131+
quant_min_lower_bound = std::numeric_limits<uint16_t>::min();
132+
quant_max_upper_bound = std::numeric_limits<uint16_t>::max();
133+
break;
134+
case c10::kShort:
135+
quant_min_lower_bound = std::numeric_limits<int16_t>::min();
136+
quant_max_upper_bound = std::numeric_limits<int16_t>::max();
137+
break;
138+
case c10::kInt:
139+
quant_min_lower_bound = std::numeric_limits<int32_t>::min();
140+
quant_max_upper_bound = std::numeric_limits<int32_t>::max();
141+
break;
142+
default:
143+
VK_CHECK_COND(false, "Unsupported dtype: ", scalar_type_name(out_dtype));
144+
}
145+
VK_CHECK_COND(
146+
quant_min >= quant_min_lower_bound,
147+
"quant_min out of bound for dtype, expected quant_min_lower_bound: ",
148+
quant_min_lower_bound,
149+
" actual quant_min: ",
150+
quant_min);
151+
152+
VK_CHECK_COND(
153+
quant_max <= quant_max_upper_bound,
154+
"quant_max out of bound for dtype, expected quant_max_upper_bound: ",
155+
quant_max_upper_bound,
156+
" actual quant_max: ",
157+
quant_max);
158+
}

backends/vulkan/test/op_tests/targets.bzl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,15 @@ def define_common_targets(is_fbcode = False):
177177
"//executorch/extension/tensor:tensor",
178178
]
179179
)
180+
define_test_targets(
181+
"quantize_test",
182+
extra_deps = [
183+
":test_utils",
184+
"//executorch/kernels/quantized/cpu:op_quantize",
185+
"//executorch/extension/tensor:tensor",
186+
"//executorch/extension/aten_util:aten_bridge",
187+
]
188+
)
180189
define_test_targets(
181190
"linear_weight_int4_test",
182191
extra_deps = [

0 commit comments

Comments
 (0)