
Commit afc22c6

[cortex-m] Add scalar c++ op for quantize_per_tensor
Differential Revision: D73141767
Pull Request resolved: #10266
1 parent 10f6563 commit afc22c6

7 files changed: +336 −5 lines
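For context, the new kernel implements standard per-tensor affine quantization from float32 to int8. In our notation (not taken from the commit itself), each input element x is mapped to

    q = clamp(round(x / scale) + zero_point, quant_min, quant_max)

which is exactly what the scalar loop in the new op file computes.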

backends/cortex_m/ops/TARGETS

+5-4
@@ -5,8 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
 load("@fbcode_macros//build_defs:python_library.bzl", "python_library")
-load("@fbcode_macros//build_defs:export_files.bzl", "export_file")
-load("@fbsource//xplat/executorch/codegen:codegen.bzl", "executorch_generated_lib")
+load("targets.bzl", "define_common_targets")
 
 oncall("executorch")
 
@@ -17,5 +16,7 @@ python_library(
     ],
     deps = [
         "fbcode//caffe2:torch",
-    ]
-)
+    ],
+)
+
+define_common_targets()
backends/cortex_m/ops/op_quantize_per_tensor.cpp

+154
@@ -0,0 +1,154 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/runtime/kernel/kernel_includes.h>
#include <algorithm>
#include <cinttypes>
#include <cmath>

// Check for Helium/MVE support
#if defined(__ARM_FEATURE_MVE) && (__ARM_FEATURE_MVE & 1)
#include <arm_mve.h>
#define HAS_HELIUM_SIMD 1
#endif

namespace cortex_m {
namespace native {

using Tensor = executorch::aten::Tensor;
using ScalarType = executorch::aten::ScalarType;
using KernelRuntimeContext = torch::executor::KernelRuntimeContext;

namespace {

/**
 * Asserts that the parameters are valid for float-to-int8 quantization.
 */
void check_quantize_args(
    const Tensor& input,
    int64_t quant_min,
    int64_t quant_max,
    ScalarType dtype,
    Tensor& out) {
  // Ensure input is float type
  ET_CHECK_MSG(
      input.scalar_type() == ScalarType::Float,
      "input.scalar_type() %" PRId8 " is not float type",
      static_cast<int8_t>(input.scalar_type()));

  // Check output dtype is int8 (Char)
  ET_CHECK_MSG(
      out.scalar_type() == ScalarType::Char,
      "out.scalar_type() %" PRId8 " is not int8 (Char)",
      static_cast<int8_t>(out.scalar_type()));

  // Check dtype is int8 (Char)
  ET_CHECK_MSG(
      dtype == ScalarType::Char,
      "dtype %" PRId8 " is not int8 (Char)",
      static_cast<int8_t>(dtype));

  // Validate quant_min and quant_max for int8
  int32_t quant_min_lower_bound = std::numeric_limits<int8_t>::min();
  int32_t quant_max_upper_bound = std::numeric_limits<int8_t>::max();

  ET_CHECK_MSG(
      quant_min >= quant_min_lower_bound,
      "quant_min out of bound for int8, expected quant_min_lower_bound: %" PRId32
      " actual quant_min: %" PRId64,
      quant_min_lower_bound,
      quant_min);

  ET_CHECK_MSG(
      quant_max <= quant_max_upper_bound,
      "quant_max out of bound for int8, expected quant_max_upper_bound: %" PRId32
      " actual quant_max: %" PRId64,
      quant_max_upper_bound,
      quant_max);
}

/**
 * Scalar implementation of quantization for a single value.
 */
template <typename T, typename K>
T quantize_val(
    float inv_scale,
    int32_t zero_point,
    K value,
    int64_t quant_min,
    int64_t quant_max) {
  int32_t qvalue =
      zero_point + static_cast<int32_t>(std::nearbyint(inv_scale * value));
  qvalue = std::max<int32_t>(qvalue, static_cast<int32_t>(quant_min));
  qvalue = std::min<int32_t>(qvalue, static_cast<int32_t>(quant_max));
  return static_cast<T>(qvalue);
}

} // namespace

Tensor& quantize_per_tensor_out(
    KernelRuntimeContext& context,
    const Tensor& input,
    double scale,
    int64_t zero_point,
    int64_t quant_min,
    int64_t quant_max,
    ScalarType dtype,
    Tensor& out) {
  // Ignore context for now
  (void)context;

  // Resize output tensor to match input dimensions
  torch::executor::Error err = resize_tensor(out, input.sizes());
  ET_CHECK_MSG(
      err == torch::executor::Error::Ok,
      "Failed to resize out Tensor in quantize_per_tensor_out");

  // Validate input parameters
  check_quantize_args(input, quant_min, quant_max, dtype, out);

  // Pre-compute inverse scale for better performance
  float inv_scale = 1.0f / static_cast<float>(scale);
  int32_t zp = static_cast<int32_t>(zero_point);
  int32_t qmin = static_cast<int32_t>(quant_min);
  int32_t qmax = static_cast<int32_t>(quant_max);

  // Get pointers to input and output data
  const float* input_data = input.const_data_ptr<float>();
  int8_t* out_data = out.mutable_data_ptr<int8_t>();
  const size_t numel = input.numel();

#if defined(HAS_HELIUM_SIMD)
  // Helium MVE implementation for float32 to int8 quantization
  #error "Implement MVE version!"
#else
  // Scalar implementation for float32 to int8 quantization
  for (size_t i = 0; i < numel; i++) {
    out_data[i] =
        quantize_val<int8_t, float>(inv_scale, zp, input_data[i], qmin, qmax);
  }
#endif

  return out;
}

Tensor& quantize_per_tensor_out(
    const Tensor& input,
    double scale,
    int64_t zero_point,
    int64_t quant_min,
    int64_t quant_max,
    ScalarType dtype,
    Tensor& out) {
  KernelRuntimeContext context;
  return quantize_per_tensor_out(
      context, input, scale, zero_point, quant_min, quant_max, dtype, out);
}

} // namespace native
} // namespace cortex_m
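As a quick sanity check of the math, here is a minimal standalone sketch (our own illustration, independent of ExecuTorch; the helper name quantize_one is hypothetical) that reproduces the same round-then-clamp behavior as quantize_val:

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>

// Same scalar math as quantize_val(): q = clamp(round(x / scale) + zp, qmin, qmax)
static int8_t quantize_one(
    float value, float scale, int32_t zero_point, int32_t qmin, int32_t qmax) {
  int32_t q = zero_point + static_cast<int32_t>(std::nearbyint(value / scale));
  q = std::min(std::max(q, qmin), qmax);
  return static_cast<int8_t>(q);
}

int main() {
  // Matches the unit test below: 4 / 0.5 + 108 = 116
  std::printf("%d\n", quantize_one(4.0f, 0.5f, 108, 0, 127)); // 116
  // Saturation: 100 / 0.5 + 108 = 308, clamped to quant_max = 127
  std::printf("%d\n", quantize_one(100.0f, 0.5f, 108, 0, 127)); // 127
  return 0;
}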

backends/cortex_m/ops/operators.yaml

+11
@@ -0,0 +1,11 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

- func: cortex_m::quantize_per_tensor.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)
  variants: function
  kernels:
    - arg_meta: null
      kernel_name: cortex_m::quantize_per_tensor_out
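This YAML schema is what ExecuTorch codegen consumes; the test file below includes the generated NativeFunctions.h header. Based on the kernel defined in this commit, the generated declaration should look roughly like the following sketch (the exact layout of the generated header is an assumption on our part):

// Sketch only: approximate shape of the declaration codegen emits for
// cortex_m::quantize_per_tensor.out, mirroring the hand-written kernel.
namespace cortex_m {
namespace native {

executorch::aten::Tensor& quantize_per_tensor_out(
    torch::executor::KernelRuntimeContext& context,
    const executorch::aten::Tensor& input,
    double scale,
    int64_t zero_point,
    int64_t quant_min,
    int64_t quant_max,
    executorch::aten::ScalarType dtype,
    executorch::aten::Tensor& out);

} // namespace native
} // namespace cortex_m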

backends/cortex_m/ops/targets.bzl

+68
@@ -0,0 +1,68 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

load("@fbsource//tools/build_defs:platform_defs.bzl", "CXX")
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
load("@fbsource//xplat/executorch/codegen:codegen.bzl", "et_operator_library", "executorch_generated_lib")
load("@fbcode_macros//build_defs:export_files.bzl", "export_file")

def define_operator_target(name: str):
    runtime.cxx_library(
        name = "op_{}".format(name),
        srcs = [
            "op_{}.cpp".format(name),
        ],
        platforms = CXX,
        deps = [
            "//executorch/runtime/kernel:kernel_includes"
        ],
        link_whole = True,
    )

OPERATORS = [
    "quantize_per_tensor",
]

def define_common_targets():
    """Defines targets that should be shared between fbcode and xplat.

    The directory containing this targets.bzl file should also contain both
    TARGETS and BUCK files that call this function.
    """
    for op in OPERATORS:
        define_operator_target(op)

    all_op_targets = [":op_{}".format(op) for op in OPERATORS]

    runtime.cxx_library(
        name = "cortex_m_operators",
        srcs = [],
        visibility = [
            "//executorch/...",
            "@EXECUTORCH_CLIENTS",
        ],
        exported_deps = all_op_targets,
    )

    export_file(name = "operators.yaml")

    et_operator_library(
        name = "ops_lib",
        _is_external_target = True,
        ops_schema_yaml_target = ":operators.yaml",
    )

    executorch_generated_lib(
        name = "cortex_m_generated_lib",
        deps = [
            ":ops_lib",
            ":cortex_m_operators",
        ],
        functions_yaml_target = ":operators.yaml",
        platforms = CXX,
        visibility = ["PUBLIC"],
        define_static_targets = True,
    )

backends/cortex_m/test/TARGETS

+6-1
@@ -5,6 +5,9 @@
 # LICENSE file in the root directory of this source tree.
 
 load("@fbcode_macros//build_defs:python_unittest.bzl", "python_unittest")
+load("targets.bzl", "define_common_targets")
+
+oncall("executorch")
 
 python_unittest(
     name = "test_replace_quant_nodes",
@@ -15,4 +18,6 @@ python_unittest(
         "//executorch/backends/cortex_m/passes:replace_quant_nodes_pass",
         "//executorch/backends/cortex_m/ops:ops",
     ],
-)
+)
+
+define_common_targets()
backends/cortex_m/test/op_quantize_per_tensor_test.cpp

+55
@@ -0,0 +1,55 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/backends/cortex_m/ops/NativeFunctions.h> // Declares the operator
#include <executorch/runtime/core/exec_aten/exec_aten.h>
#include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
#include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
#include <executorch/runtime/kernel/kernel_includes.h>
#include <gtest/gtest.h>

using executorch::aten::ScalarType;
using executorch::aten::Tensor;
using executorch::runtime::KernelRuntimeContext;
using torch::executor::testing::TensorFactory;

// Test op
using cortex_m::native::quantize_per_tensor_out;

void test_dtype() {
  TensorFactory<ScalarType::Float> tf;

  Tensor input = tf.full({3, 5}, 4);
  double scale = 0.5;

  int64_t zero_point = 108;
  int64_t quant_min = 0;
  int64_t quant_max = 127;

  TensorFactory<ScalarType::Char> tfo;
  Tensor out = tfo.zeros({3, 5});
  // 4 / 0.5 + 108 = 116
  Tensor expected = tfo.full({3, 5}, 116);

  KernelRuntimeContext ctx;
  quantize_per_tensor_out(
      ctx,
      input,
      scale,
      zero_point,
      quant_min,
      quant_max,
      ScalarType::Char,
      out);

  EXPECT_TENSOR_EQ(out, expected);
}

TEST(OpQuantizeOutTest, AllDtypesSupported) {
  test_dtype();
}
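The test above only exercises the nominal path. A natural follow-up, sketched here but not part of this commit, would cover saturation at quant_max using the same TensorFactory helpers:

// Hypothetical extra case (not in this commit): out-of-range values clamp to quant_max.
TEST(OpQuantizeOutTest, SaturatesAtQuantMax) {
  TensorFactory<ScalarType::Float> tf;
  TensorFactory<ScalarType::Char> tfo;

  Tensor input = tf.full({3, 5}, 100);     // 100 / 0.5 + 108 = 308
  Tensor out = tfo.zeros({3, 5});
  Tensor expected = tfo.full({3, 5}, 127); // clamped to quant_max

  KernelRuntimeContext ctx;
  quantize_per_tensor_out(
      ctx, input, /*scale=*/0.5, /*zero_point=*/108,
      /*quant_min=*/0, /*quant_max=*/127, ScalarType::Char, out);

  EXPECT_TENSOR_EQ(out, expected);
}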

backends/cortex_m/test/targets.bzl

+37
@@ -0,0 +1,37 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")

OPERATORS = [
    "quantize_per_tensor",
]

def define_operator_test_target(op):
    runtime.cxx_test(
        name = "op_{}_test".format(op),
        srcs = [
            "op_{}_test.cpp".format(op),
        ],
        deps = [
            "//executorch/runtime/kernel:kernel_includes",
            "//executorch/kernels/test:test_util",
            "//executorch/backends/cortex_m/ops:op_{}".format(op),
            "//executorch/backends/cortex_m/ops:cortex_m_generated_lib",
            "//executorch/backends/cortex_m/ops:cortex_m_generated_lib_headers",
        ]
    )

def define_common_targets():
    """Defines targets that should be shared between fbcode and xplat.

    The directory containing this targets.bzl file should also contain both
    TARGETS and BUCK files that call this function.
    """
    for op in OPERATORS:
        define_operator_test_target(op)
