diff --git a/backends/vulkan/test/op_tests/choose_qparams_test.cpp b/backends/vulkan/test/op_tests/choose_qparams_test.cpp
new file mode 100644
index 00000000000..ec839cdf6bf
--- /dev/null
+++ b/backends/vulkan/test/op_tests/choose_qparams_test.cpp
@@ -0,0 +1,116 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <gtest/gtest.h>
+
+#include <ATen/ATen.h>
+
+#include <executorch/backends/vulkan/runtime/api/api.h>
+#include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h>
+#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
+
+#include <executorch/extension/aten_util/make_aten_functor_from_et_functor.h>
+#include <executorch/extension/kernel_util/make_boxed_from_unboxed_functor.h>
+
+#include "test_utils.h"
+
+#include <cassert>
+#include <iostream>
+
+namespace torch {
+namespace executor {
+namespace native {
+
+// Forward declarations of the functions we're testing
+std::tuple<Tensor&, Tensor&> choose_qparams_tensor_out(
+    const Tensor& input,
+    int64_t quant_min,
+    int64_t quant_max,
+    ET_UNUSED double eps,
+    ScalarType dtype,
+    Tensor& scale_out,
+    Tensor& zero_point_out);
+
+std::tuple<Tensor&, Tensor&> choose_qparams_per_token_asymmetric_out(
+    const Tensor& input,
+    ScalarType dtype,
+    Tensor& scale_out,
+    Tensor& zero_point_out);
+
+// Wrapper function for choose_qparams_tensor_out without context
+Tensor& choose_qparams_tensor_out_no_context(
+    const Tensor& input,
+    int64_t quant_min,
+    int64_t quant_max,
+    ET_UNUSED double eps,
+    ScalarType dtype,
+    Tensor& scale_out,
+    Tensor& zero_point_out) {
+  torch::executor::native::choose_qparams_tensor_out(
+      input, quant_min, quant_max, eps, dtype, scale_out, zero_point_out);
+  return scale_out;
+}
+
+// Wrapper function for choose_qparams_per_token_asymmetric_out without context
+Tensor& choose_qparams_per_token_asymmetric_out_no_context(
+    const Tensor& input,
+    ScalarType dtype,
+    Tensor& scale_out,
+    Tensor& zero_point_out) {
+  torch::executor::native::choose_qparams_per_token_asymmetric_out(
+      input, dtype, scale_out, zero_point_out);
+  return scale_out;
+}
+
+// ATen wrapper for choose_qparams_tensor
+std::tuple<at::Tensor, at::Tensor> choose_qparams_tensor_aten(
+    const at::Tensor& input,
+    int64_t quant_min,
+    int64_t quant_max,
+    at::ScalarType dtype) {
+  auto scale_out = at::empty({}, at::device(at::kCPU).dtype(at::kDouble));
+  auto zero_point_out = at::empty({}, at::device(at::kCPU).dtype(at::kLong));
+  double eps = 1e-7;
+
+  ScalarType et_dtype = at_scalartype_to_et_scalartype(dtype);
+
+  // Use WRAP_TO_ATEN with the wrapper function
+  WRAP_TO_ATEN(choose_qparams_tensor_out_no_context, 5)
+  (input, quant_min, quant_max, eps, et_dtype, scale_out, zero_point_out);
+
+  return {scale_out, zero_point_out};
+}
+
+// ATen wrapper for choose_qparams_per_token_asymmetric
+std::tuple<at::Tensor, at::Tensor> choose_qparams_per_token_asymmetric_aten(
+    const at::Tensor& input,
+    at::ScalarType dtype) {
+  // Calculate output sizes for scale and zero_point tensors
+  std::vector<int64_t> output_sizes;
+  for (int64_t i = 0; i < input.dim() - 1; i++) {
+    output_sizes.push_back(input.size(i));
+  }
+  output_sizes.push_back(1);
+
+  auto scale_out =
+      at::empty(output_sizes, at::device(at::kCPU).dtype(at::kDouble));
+  auto zero_point_out =
+      at::empty(output_sizes, at::device(at::kCPU).dtype(at::kLong));
+
+  ScalarType et_dtype = at_scalartype_to_et_scalartype(dtype);
+
+  // Use WRAP_TO_ATEN with the wrapper function
+  WRAP_TO_ATEN(choose_qparams_per_token_asymmetric_out_no_context, 2)
+  (input, et_dtype, scale_out, zero_point_out);
+
+  return {scale_out, zero_point_out};
+}
+
+} // namespace native
+} // namespace executor
+} // namespace torch
diff --git a/backends/vulkan/test/op_tests/targets.bzl b/backends/vulkan/test/op_tests/targets.bzl
index a22f2323896..0d014c7ef29 100644
--- a/backends/vulkan/test/op_tests/targets.bzl
+++ b/backends/vulkan/test/op_tests/targets.bzl
@@ -195,6 +195,15 @@ def define_common_targets(is_fbcode = False):
             "//executorch/extension/aten_util:aten_bridge",
         ]
     )
+    define_test_targets(
+        "choose_qparams_test",
+        extra_deps = [
+            ":test_utils",
+            "//executorch/kernels/quantized/cpu:op_choose_qparams",
+            "//executorch/extension/tensor:tensor",
+            "//executorch/extension/aten_util:aten_bridge",
+        ]
+    )
     define_test_targets(
         "linear_weight_int4_test",
         extra_deps = [
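For reference, here is a minimal gtest-style sketch (not part of the patch above) of how the new ATen wrappers might be exercised. The test names, tensor shapes, and the at::kChar dtype are illustrative assumptions; the assertions only check structural properties that follow from the wrapper code itself (0-dim double/long outputs for the per-tensor path, a trailing dimension of 1 for the per-token path, and a zero point inside [quant_min, quant_max]), not the exact values the kernel computes.

#include <gtest/gtest.h>

#include <ATen/ATen.h>

#include <cstdint>
#include <vector>

// Assumes the wrapper declarations above (torch::executor::native::*_aten)
// are visible, e.g. because this sketch lives in the same translation unit.
TEST(ChooseQParamsWrapperSketch, PerTensorBasicProperties) {
  at::Tensor input = at::rand({4, 8}, at::device(at::kCPU).dtype(at::kFloat));
  const int64_t quant_min = -128;
  const int64_t quant_max = 127;

  auto [scale, zero_point] =
      torch::executor::native::choose_qparams_tensor_aten(
          input, quant_min, quant_max, at::kChar);

  // The wrapper allocates 0-dim double/long outputs for the per-tensor path.
  EXPECT_EQ(scale.scalar_type(), at::kDouble);
  EXPECT_EQ(zero_point.scalar_type(), at::kLong);
  EXPECT_GT(scale.item<double>(), 0.0);
  EXPECT_GE(zero_point.item<int64_t>(), quant_min);
  EXPECT_LE(zero_point.item<int64_t>(), quant_max);
}

TEST(ChooseQParamsWrapperSketch, PerTokenOutputShape) {
  at::Tensor input =
      at::rand({2, 3, 16}, at::device(at::kCPU).dtype(at::kFloat));

  auto [scale, zero_point] =
      torch::executor::native::choose_qparams_per_token_asymmetric_aten(
          input, at::kChar);

  // One (scale, zero_point) pair per token: leading dims kept, last dim is 1.
  const std::vector<int64_t> expected_sizes = {2, 3, 1};
  EXPECT_EQ(scale.sizes().vec(), expected_sizes);
  EXPECT_EQ(zero_point.sizes().vec(), expected_sizes);
}

In both WRAP_TO_ATEN calls in the patch, the integer argument (5 and 2) matches the zero-based position of scale_out, the first out tensor, in the wrapped signature.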
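Conceptually, both ops pick an affine mapping from the observed floating-point range to the integer range. A generic, self-contained version of that min/max-based selection is sketched below purely for illustration; the actual ExecuTorch kernel may differ in details such as eps handling, degenerate ranges, and zero-point rounding, so this is not a drop-in reference for bit-exact comparison.

#include <algorithm>
#include <cmath>
#include <cstdint>

struct QParams {
  double scale;
  int64_t zero_point;
};

// Textbook asymmetric quantization parameter selection (illustrative only).
inline QParams reference_choose_qparams(
    double min_val,
    double max_val,
    int64_t quant_min,
    int64_t quant_max,
    double eps = 1e-7) {
  // Extend the range to include 0 so that real 0.0 is exactly representable.
  const double min_v = std::min(min_val, 0.0);
  const double max_v = std::max(max_val, 0.0);
  double scale = (max_v - min_v) / static_cast<double>(quant_max - quant_min);
  scale = std::max(scale, eps); // avoid a zero or denormal scale
  const double zp = quant_min - std::nearbyint(min_v / scale);
  const int64_t zero_point = static_cast<int64_t>(std::clamp<double>(
      zp, static_cast<double>(quant_min), static_cast<double>(quant_max)));
  return {scale, zero_point};
}

For example, min_val = -0.4, max_val = 1.2 with quant_min = -128, quant_max = 127 gives scale ≈ 0.00627 and a zero point of -64 under this scheme.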