
Commit f8936cf

[cortex-m] Add scalar C++ op for dequantize_per_tensor
Only buck build for now; CMake is next. No MVE, scalar only. Strictly the dtypes we care about; arg_meta is updated to reflect that.

Differential Revision: [D73164576](https://our.internmc.facebook.com/intern/diff/D73164576/)

ghstack-source-id: 278739851
Pull Request resolved: #10267
1 parent: 8533de9

File tree

5 files changed: +203 -0 lines
@@ -0,0 +1,147 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/runtime/kernel/kernel_includes.h>
#include <cinttypes>
#include <limits>

// Check for Helium/MVE support
#if defined(__ARM_FEATURE_MVE) && (__ARM_FEATURE_MVE & 1)
#include <arm_mve.h>
#define HAS_HELIUM_SIMD 1
#endif

namespace cortex_m {
namespace native {

using Tensor = executorch::aten::Tensor;
using ScalarType = executorch::aten::ScalarType;
using KernelRuntimeContext = torch::executor::KernelRuntimeContext;

namespace {

/**
 * Asserts that the parameters are valid for int8 to float dequantization.
 */
void check_dequantize_args(
    const Tensor& input,
    int64_t quant_min,
    int64_t quant_max,
    ScalarType dtype,
    Tensor& out) {
  // Ensure input is char type
  ET_CHECK_MSG(
      input.scalar_type() == ScalarType::Char,
      "input.scalar_type() %" PRId8 " is not char type",
      static_cast<int8_t>(input.scalar_type()));

  // Check output dtype is float
  ET_CHECK_MSG(
      out.scalar_type() == ScalarType::Float,
      "out.scalar_type() %" PRId8 " is not float",
      static_cast<int8_t>(out.scalar_type()));

  // Check dtype is int8 (Char)
  ET_CHECK_MSG(
      dtype == ScalarType::Char,
      "dtype %" PRId8 " is not int8 (Char)",
      static_cast<int8_t>(dtype));

  // Validate quant_min and quant_max for int8
  int32_t quant_min_lower_bound = std::numeric_limits<int8_t>::min();
  int32_t quant_max_upper_bound = std::numeric_limits<int8_t>::max();

  ET_CHECK_MSG(
      quant_min >= quant_min_lower_bound,
      "quant_min out of bound for int8, expected quant_min_lower_bound: %" PRId32
      " actual quant_min: %" PRId64,
      quant_min_lower_bound,
      quant_min);

  ET_CHECK_MSG(
      quant_max <= quant_max_upper_bound,
      "quant_max out of bound for int8, expected quant_max_upper_bound: %" PRId32
      " actual quant_max: %" PRId64,
      quant_max_upper_bound,
      quant_max);
}

/**
 * Scalar implementation of dequantization for a single value.
 */
template <typename K, typename T>
T dequantize_val(
    float scale,
    int32_t zero_point,
    K value,
    int64_t quant_min,
    int64_t quant_max) {
  (void)quant_min;
  (void)quant_max;
  return static_cast<T>((static_cast<int32_t>(value) - zero_point) * scale);
}

} // namespace

Tensor& dequantize_per_tensor_out(
    KernelRuntimeContext& context,
    const Tensor& input,
    double scale,
    int64_t zero_point,
    int64_t quant_min,
    int64_t quant_max,
    ScalarType dtype,
    Tensor& out) {
  // Ignore context for now
  (void)context;

  // Resize output tensor to match input dimensions
  torch::executor::Error err = resize_tensor(out, input.sizes());
  ET_CHECK_MSG(
      err == torch::executor::Error::Ok,
      "Failed to resize out Tensor in dequantize_per_tensor_out");

  // Validate input parameters
  check_dequantize_args(input, quant_min, quant_max, dtype, out);

  // Narrow the quantization parameters for the scalar loop
  int32_t zp = static_cast<int32_t>(zero_point);
  int32_t qmin = static_cast<int32_t>(quant_min);
  int32_t qmax = static_cast<int32_t>(quant_max);

  // Get pointers to input and output data
  const int8_t* input_data = input.const_data_ptr<int8_t>();
  float* out_data = out.mutable_data_ptr<float>();
  const size_t numel = input.numel();

#if defined(HAS_HELIUM_SIMD)
  // Helium MVE implementation for int8 to float32 dequantization
  #error "Implement MVE version!"
#else
  // Scalar implementation for int8 to float32 dequantization
  for (size_t i = 0; i < numel; i++) {
    out_data[i] =
        dequantize_val<int8_t, float>(scale, zp, input_data[i], qmin, qmax);
  }
#endif

  return out;
}

Tensor& dequantize_per_tensor_out(
    const Tensor& input,
    double scale,
    int64_t zero_point,
    int64_t quant_min,
    int64_t quant_max,
    ScalarType dtype,
    Tensor& out) {
  KernelRuntimeContext context;
  return dequantize_per_tensor_out(
      context, input, scale, zero_point, quant_min, quant_max, dtype, out);
}

} // namespace native
} // namespace cortex_m
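For reference, a standalone sketch of the affine dequantization arithmetic the scalar kernel above applies (not part of the commit; the constants mirror the unit test added later in this diff):

#include <cstdint>
#include <cstdio>

int main() {
  const int8_t q = 4;             // quantized input value, as in the test
  const int32_t zero_point = 108; // quantization zero point
  const float scale = 0.5f;       // quantization scale
  // Affine dequantization: real = (q - zero_point) * scale
  const float real = (static_cast<int32_t>(q) - zero_point) * scale;
  std::printf("%f\n", real);      // prints -52.000000
  return 0;
}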

backends/cortex_m/ops/operators.yaml

+6
@@ -9,3 +9,9 @@
   kernels:
     - arg_meta: null
       kernel_name: cortex_m::quantize_per_tensor_out

+ - func: cortex_m::dequantize_per_tensor.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)
+   variants: function
+   kernels:
+     - arg_meta: null
+       kernel_name: cortex_m::dequantize_per_tensor_out
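A brief note on how this entry binds to the kernel, assuming the usual ExecuTorch/ATen schema type conventions: float and int arguments in a .out schema map to double and int64_t parameters in the registered C++ function, which matches the dequantize_per_tensor_out overload added in this commit:

// Schema argument            ->  C++ kernel parameter
// Tensor input               ->  const Tensor& input
// float scale                ->  double scale
// int zero_point             ->  int64_t zero_point
// int quant_min / quant_max  ->  int64_t quant_min / quant_max
// ScalarType dtype           ->  ScalarType dtype
// Tensor(a!) out             ->  Tensor& out (also the return value)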

backends/cortex_m/ops/targets.bzl

+1
@@ -24,6 +24,7 @@ def define_operator_target(name: str):

  OPERATORS = [
      "quantize_per_tensor",
+     "dequantize_per_tensor",
  ]

  def define_common_targets():
@@ -0,0 +1,48 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/backends/cortex_m/ops/NativeFunctions.h> // Declares the operator
#include <executorch/runtime/kernel/kernel_includes.h>
#include <executorch/runtime/core/exec_aten/exec_aten.h>
#include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
#include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
#include <gtest/gtest.h>

using executorch::aten::ScalarType;
using executorch::aten::Tensor;
using executorch::runtime::KernelRuntimeContext;
using torch::executor::testing::TensorFactory;

// Test op
using cortex_m::native::dequantize_per_tensor_out;

void test_dtype() {
  TensorFactory<ScalarType::Char> tf;

  Tensor input = tf.full({3, 5}, 4);
  double scale = 0.5;

  int64_t zero_point = 108;
  int64_t quant_min = -128;
  int64_t quant_max = 127;

  TensorFactory<ScalarType::Float> tfo;
  Tensor out = tfo.zeros({3, 5});
  // (4 - 108) * 0.5 = -52
  Tensor expected = tfo.full({3, 5}, -52.0);

  KernelRuntimeContext ctx;
  dequantize_per_tensor_out(
      ctx, input, scale, zero_point, quant_min, quant_max, ScalarType::Char, out);

  EXPECT_TENSOR_EQ(out, expected);
}

TEST(OpDequantizeOutTest, AllDtypesSupported) {
  test_dtype();
}

backends/cortex_m/test/targets.bzl

+1
@@ -8,6 +8,7 @@ load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")

  OPERATORS = [
      "quantize_per_tensor",
+     "dequantize_per_tensor",
  ]

  def define_operator_test_target(op):
