
[cortex-m] Add scalar C++ op for dequantize_per_tensor #10383

Merged
merged 1 commit into from Apr 24, 2025
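
For context, per-tensor dequantization maps an int8 value q back to floating point as x = (q - zero_point) * scale; the unit test below exercises exactly this with (4 - 108) * 0.5 = -52. A minimal standalone sketch of that arithmetic (illustrative only; dequantize_one is a hypothetical helper, not part of this PR):

#include <cstdint>

// Per-tensor dequantization of a single int8 value (hypothetical helper).
inline float dequantize_one(int8_t q, float scale, int32_t zero_point) {
  return static_cast<float>(static_cast<int32_t>(q) - zero_point) * scale;
}

// Example: dequantize_one(4, 0.5f, 108) == -52.0f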
149 changes: 149 additions & 0 deletions backends/cortex_m/ops/op_dequantize_per_tensor.cpp
@@ -0,0 +1,149 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/runtime/kernel/kernel_includes.h>
#include <cinttypes>
#include <limits>

// Check for Helium/MVE support
#if defined(__ARM_FEATURE_MVE) && (__ARM_FEATURE_MVE & 1)
#include <arm_mve.h>
#define HAS_HELIUM_SIMD 1
#endif

namespace cortex_m {
namespace native {

using Tensor = executorch::aten::Tensor;
using ScalarType = executorch::aten::ScalarType;
using KernelRuntimeContext = torch::executor::KernelRuntimeContext;

namespace {

/**
 * Asserts that the parameters are valid for int8-to-float dequantization.
 */
void check_dequantize_args(
const Tensor& input,
int64_t quant_min,
int64_t quant_max,
ScalarType dtype,
Tensor& out) {
// Ensure input is char type
ET_CHECK_MSG(
input.scalar_type() == ScalarType::Char,
"input.scalar_type() %" PRId8 " is not char type",
static_cast<int8_t>(input.scalar_type()));

// Check output dtype is float
ET_CHECK_MSG(
out.scalar_type() == ScalarType::Float,
"out.scalar_type() %" PRId8 " is not float",
static_cast<int8_t>(out.scalar_type()));

// Check dtype is int8 (Char)
ET_CHECK_MSG(
dtype == ScalarType::Char,
"dtype %" PRId8 " is not int8 (Char)",
static_cast<int8_t>(dtype));

// Validate quant_min and quant_max for int8
int32_t quant_min_lower_bound = std::numeric_limits<int8_t>::min();
int32_t quant_max_upper_bound = std::numeric_limits<int8_t>::max();

ET_CHECK_MSG(
quant_min >= quant_min_lower_bound,
"quant_min out of bound for int8, expected quant_min_lower_bound: %" PRId32
" actual quant_min: %" PRId64,
quant_min_lower_bound,
quant_min);

ET_CHECK_MSG(
quant_max <= quant_max_upper_bound,
"quant_max out of bound for int8, expected quant_max_upper_bound: %" PRId32
" actual quant_max: %" PRId64,
quant_max_upper_bound,
quant_max);
}

/**
 * Scalar implementation of dequantization for a single value.
 */
template <typename K, typename T>
T dequantize_val(
float scale,
int32_t zero_point,
K value,
int64_t quant_min,
int64_t quant_max) {
(void)quant_min;
(void)quant_max;
return static_cast<T>((static_cast<int32_t>(value) - zero_point) * scale);
}

} // namespace

Tensor& dequantize_per_tensor_out(
KernelRuntimeContext& context,
const Tensor& input,
double scale,
int64_t zero_point,
int64_t quant_min,
int64_t quant_max,
ScalarType dtype,
Tensor& out) {
// Ignore context for now
(void)context;

// Resize output tensor to match input dimensions
torch::executor::Error err = resize_tensor(out, input.sizes());
ET_CHECK_MSG(
err == torch::executor::Error::Ok,
"Failed to resize out Tensor in dequantize_per_tensor_out");

// Validate input parameters
check_dequantize_args(input, quant_min, quant_max, dtype, out);

// Narrow the quantization parameters to 32-bit values for the scalar loop
int32_t zp = static_cast<int32_t>(zero_point);
int32_t qmin = static_cast<int32_t>(quant_min);
int32_t qmax = static_cast<int32_t>(quant_max);

// Get pointers to input and output data
const int8_t* input_data = input.const_data_ptr<int8_t>();
float* out_data = out.mutable_data_ptr<float>();
const size_t numel = input.numel();

#if defined(HAS_HELIUM_SIMD)
// Helium/MVE implementation for int8 to float32 dequantization
#error "Implement MVE version!"
#else
// Scalar implementation for int8 to float32 dequantization
for (size_t i = 0; i < numel; i++) {
out_data[i] =
dequantize_val<int8_t, float>(scale, zp, input_data[i], qmin, qmax);
}
#endif

return out;
}

Tensor& dequantize_per_tensor_out(
const Tensor& input,
double scale,
int64_t zero_point,
int64_t quant_min,
int64_t quant_max,
ScalarType dtype,
Tensor& out) {
KernelRuntimeContext context;
return dequantize_per_tensor_out(
context, input, scale, zero_point, quant_min, quant_max, dtype, out);
}

} // namespace native
} // namespace cortex_m
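
The Helium/MVE branch above is intentionally left as #error "Implement MVE version!". For illustration only, the scalar loop could map to MVE roughly as sketched below; this assumes the ACLE MVE intrinsics from <arm_mve.h> (vctp32q, vldrbq_z_s32, vsubq_n_s32, vcvtq_f32_s32, vmulq_n_f32, vstrwq_p_f32), is untested, and is not part of this PR:

// Hypothetical MVE sketch: dequantize int8 -> float32, four lanes per iteration.
const float scale_f = static_cast<float>(scale);
for (size_t i = 0; i < numel; i += 4) {
  // Tail predicate: covers the remaining lanes when fewer than 4 elements are left.
  mve_pred16_t p = vctp32q(numel - i);
  // Load four int8 values and sign-extend each into a 32-bit lane.
  int32x4_t q = vldrbq_z_s32(&input_data[i], p);
  // (q - zero_point) * scale, lane-wise.
  float32x4_t result = vmulq_n_f32(vcvtq_f32_s32(vsubq_n_s32(q, zp)), scale_f);
  // Predicated store so tail lanes beyond numel are not written.
  vstrwq_p_f32(&out_data[i], result, p);
}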
6 changes: 6 additions & 0 deletions backends/cortex_m/ops/operators.yaml
@@ -9,3 +9,9 @@
kernels:
- arg_meta: null
kernel_name: cortex_m::quantize_per_tensor_out

- func: cortex_m::dequantize_per_tensor.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)
variants: function
kernels:
- arg_meta: null
kernel_name: cortex_m::dequantize_per_tensor_out
1 change: 1 addition & 0 deletions backends/cortex_m/ops/targets.bzl
@@ -24,6 +24,7 @@ def define_operator_target(name: str):

OPERATORS = [
"quantize_per_tensor",
"dequantize_per_tensor",
]

def define_common_targets():
55 changes: 55 additions & 0 deletions backends/cortex_m/test/op_dequantize_per_tensor_test.cpp
@@ -0,0 +1,55 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/backends/cortex_m/ops/NativeFunctions.h> // Declares the operator
#include <executorch/runtime/core/exec_aten/exec_aten.h>
#include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
#include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
#include <executorch/runtime/kernel/kernel_includes.h>
#include <gtest/gtest.h>

using executorch::aten::ScalarType;
using executorch::aten::Tensor;
using executorch::runtime::KernelRuntimeContext;
using torch::executor::testing::TensorFactory;

// Test op
using cortex_m::native::dequantize_per_tensor_out;

void test_dtype() {
TensorFactory<ScalarType::Char> tf;

Tensor input = tf.full({3, 5}, 4);
double scale = 0.5;

int64_t zero_point = 108;
int64_t quant_min = -128;
int64_t quant_max = 127;

TensorFactory<ScalarType::Float> tfo;
Tensor out = tfo.zeros({3, 5});
// (4 - 108) * 0.5 = -52
Tensor expected = tfo.full({3, 5}, -52.0);

KernelRuntimeContext ctx;
dequantize_per_tensor_out(
ctx,
input,
scale,
zero_point,
quant_min,
quant_max,
ScalarType::Char,
out);

EXPECT_TENSOR_EQ(out, expected);
}

TEST(OpDequantizeOutTest, AllDtypesSupported) {
test_dtype();
}
1 change: 1 addition & 0 deletions backends/cortex_m/test/targets.bzl
@@ -8,6 +8,7 @@ load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")

OPERATORS = [
"quantize_per_tensor",
"dequantize_per_tensor",
]

def define_operator_test_target(op):