diff --git a/tensorflow/lite/micro/kernels/comparisons.cc b/tensorflow/lite/micro/kernels/comparisons.cc index 69b3c61c32d..ca308a59e9d 100644 --- a/tensorflow/lite/micro/kernels/comparisons.cc +++ b/tensorflow/lite/micro/kernels/comparisons.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2025 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -286,6 +286,19 @@ TfLiteStatus GreaterEval(TfLiteContext* context, TfLiteNode* node) { tflite::micro::GetTensorData(input2), output_shape, output_data); break; + case kTfLiteInt16: + requires_broadcast + ? reference_ops::Broadcast4DSlowGreaterWithScaling( + data->params, input1_shape, + tflite::micro::GetTensorData(input1), input2_shape, + tflite::micro::GetTensorData(input2), output_shape, + output_data) + : reference_ops::GreaterWithScaling( + data->params, input1_shape, + tflite::micro::GetTensorData(input1), input2_shape, + tflite::micro::GetTensorData(input2), output_shape, + output_data); + break; default: MicroPrintf("Type %s (%d) not supported.", TfLiteTypeGetName(input1->type), input1->type); diff --git a/tensorflow/lite/micro/kernels/comparisons_test.cc b/tensorflow/lite/micro/kernels/comparisons_test.cc index eec57d62ae4..f1342fe5cdc 100644 --- a/tensorflow/lite/micro/kernels/comparisons_test.cc +++ b/tensorflow/lite/micro/kernels/comparisons_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2025 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -126,6 +126,29 @@ void TestComparisonQuantizedInt8(const TFLMRegistration& registration, TestComparison(registration, tensors, expected_output_data, output_data); } +void TestComparisonQuantizedInt16(const TFLMRegistration& registration, + int* input1_dims_data, float* input1_data, + int16_t* input1_quantized, float input1_scale, + int input1_zero_point, int* input2_dims_data, + float* input2_data, int16_t* input2_quantized, + float input2_scale, int input2_zero_point, + bool* expected_output_data, + int* output_dims_data, bool* output_data) { + TfLiteIntArray* input1_dims = IntArrayFromInts(input1_dims_data); + TfLiteIntArray* input2_dims = IntArrayFromInts(input2_dims_data); + TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data); + + TfLiteTensor tensors[tensors_size] = { + CreateQuantizedTensor(input1_data, input1_quantized, input1_dims, + input1_scale, input1_zero_point), + CreateQuantizedTensor(input2_data, input2_quantized, input2_dims, + input2_scale, input2_zero_point), + CreateTensor(output_data, output_dims), + }; + + TestComparison(registration, tensors, expected_output_data, output_data); +} + } // namespace } // namespace testing } // namespace tflite @@ -656,6 +679,35 @@ TF_LITE_MICRO_TEST(GreaterQuantizedInt8WithBroadcast) { } } +TF_LITE_MICRO_TEST(GreaterQuantizedInt16WithBroadcast) { + const int num_shapes = 4; + const int max_shape_size = 5; + int test_shapes[num_shapes][max_shape_size] = { + {1, 6}, {2, 2, 3}, {3, 2, 1, 3}, {4, 1, 3, 1, 2}}; + + for (int i = 0; i < num_shapes; ++i) { + int* input1_dim = test_shapes[i]; + int input2_dim[] = {1, 1}; + float input1_data[] = {20, -2, -71, 8, 11, 20}; + float input2_data[] = {8}; + + bool expected_data[] = {true, false, false, false, true, true}; + int* expected_dim = input1_dim; + + const float input1_scale = 0.5; + const int input1_zero_point = -9; + int16_t input1_quantized[6]; + int16_t input2_quantized[6]; + + bool output_data[6]; + tflite::testing::TestComparisonQuantizedInt16( + tflite::Register_GREATER(), input1_dim, input1_data, input1_quantized, + input1_scale, input1_zero_point, input2_dim, input2_data, + input2_quantized, input1_scale, input1_zero_point, expected_data, + expected_dim, output_data); + } +} + TF_LITE_MICRO_TEST(GreaterEqualQuantizedInt8WithBroadcast) { const int num_shapes = 4; const int max_shape_size = 5; diff --git a/tensorflow/lite/micro/kernels/fully_connected.cc b/tensorflow/lite/micro/kernels/fully_connected.cc index 6902728043f..6bf7665a81f 100644 --- a/tensorflow/lite/micro/kernels/fully_connected.cc +++ b/tensorflow/lite/micro/kernels/fully_connected.cc @@ -1,4 +1,4 @@ -/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2025 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -238,25 +238,97 @@ TfLiteStatus FullyConnectedEval(TfLiteContext* context, TfLiteNode* node) { case kTfLiteInt16: { switch (filter->type) { case kTfLiteInt8: { - tflite::reference_integer_ops::FullyConnected( - FullyConnectedParamsQuantized(data), - tflite::micro::GetTensorShape(input), - tflite::micro::GetTensorData(input), - tflite::micro::GetTensorShape(filter), + if (bias == nullptr || bias->type == kTfLiteInt32) { + data.is_per_channel + ? tflite::reference_integer_ops::FullyConnectedPerChannel( + FullyConnectedParamsQuantized(data), + data.per_channel_output_multiplier, + reinterpret_cast( + data.per_channel_output_shift), + tflite::micro::GetTensorShape(input), + tflite::micro::GetTensorData(input), + tflite::micro::GetTensorShape(filter), #ifdef USE_TFLM_COMPRESSION - tflite::micro::GetTensorData(micro_context, filter, - weights_comp_td, - data.weights_scratch_index), - tflite::micro::GetTensorShape(bias), - tflite::micro::GetOptionalTensorData( - micro_context, bias, bias_comp_td, data.bias_scratch_index), + tflite::micro::GetTensorData( + micro_context, filter, weights_comp_td, + data.weights_scratch_index), + tflite::micro::GetTensorShape(bias), + tflite::micro::GetOptionalTensorData( + micro_context, bias, bias_comp_td, + data.bias_scratch_index), #else // USE_TFLM_COMPRESSION - tflite::micro::GetTensorData(filter), - tflite::micro::GetTensorShape(bias), - tflite::micro::GetOptionalTensorData(bias), + tflite::micro::GetTensorData(filter), + tflite::micro::GetTensorShape(bias), + tflite::micro::GetOptionalTensorData(bias), #endif // USE_TFLM_COMPRESSION - tflite::micro::GetTensorShape(output), - tflite::micro::GetTensorData(output)); + tflite::micro::GetTensorShape(output), + tflite::micro::GetTensorData(output)) + : tflite::reference_integer_ops::FullyConnected( + FullyConnectedParamsQuantized(data), + tflite::micro::GetTensorShape(input), + tflite::micro::GetTensorData(input), + tflite::micro::GetTensorShape(filter), +#ifdef USE_TFLM_COMPRESSION + tflite::micro::GetTensorData( + micro_context, filter, weights_comp_td, + data.weights_scratch_index), + tflite::micro::GetTensorShape(bias), + tflite::micro::GetOptionalTensorData( + micro_context, bias, bias_comp_td, + data.bias_scratch_index), +#else // USE_TFLM_COMPRESSION + tflite::micro::GetTensorData(filter), + tflite::micro::GetTensorShape(bias), + tflite::micro::GetOptionalTensorData(bias), +#endif // USE_TFLM_COMPRESSION + tflite::micro::GetTensorShape(output), + tflite::micro::GetTensorData(output)); + } else if (bias->type == kTfLiteInt64) { + data.is_per_channel + ? tflite::reference_integer_ops::FullyConnectedPerChannel( + FullyConnectedParamsQuantized(data), + data.per_channel_output_multiplier, + reinterpret_cast( + data.per_channel_output_shift), + tflite::micro::GetTensorShape(input), + tflite::micro::GetTensorData(input), + tflite::micro::GetTensorShape(filter), +#ifdef USE_TFLM_COMPRESSION + tflite::micro::GetTensorData( + micro_context, filter, weights_comp_td, + data.weights_scratch_index), + tflite::micro::GetTensorShape(bias), + tflite::micro::GetOptionalTensorData( + micro_context, bias, bias_comp_td, + data.bias_scratch_index), +#else // USE_TFLM_COMPRESSION + tflite::micro::GetTensorData(filter), + tflite::micro::GetTensorShape(bias), + tflite::micro::GetOptionalTensorData(bias), +#endif // USE_TFLM_COMPRESSION + tflite::micro::GetTensorShape(output), + tflite::micro::GetTensorData(output)) + : tflite::reference_integer_ops::FullyConnected( + FullyConnectedParamsQuantized(data), + tflite::micro::GetTensorShape(input), + tflite::micro::GetTensorData(input), + tflite::micro::GetTensorShape(filter), +#ifdef USE_TFLM_COMPRESSION + tflite::micro::GetTensorData( + micro_context, filter, weights_comp_td, + data.weights_scratch_index), + tflite::micro::GetTensorShape(bias), + tflite::micro::GetOptionalTensorData( + micro_context, bias, bias_comp_td, + data.bias_scratch_index), +#else // USE_TFLM_COMPRESSION + tflite::micro::GetTensorData(filter), + tflite::micro::GetTensorShape(bias), + tflite::micro::GetOptionalTensorData(bias), +#endif // USE_TFLM_COMPRESSION + tflite::micro::GetTensorShape(output), + tflite::micro::GetTensorData(output)); + } break; } default: { diff --git a/tensorflow/lite/micro/kernels/fully_connected_common.cc b/tensorflow/lite/micro/kernels/fully_connected_common.cc index 53709d366bf..9170d14d0a3 100644 --- a/tensorflow/lite/micro/kernels/fully_connected_common.cc +++ b/tensorflow/lite/micro/kernels/fully_connected_common.cc @@ -1,4 +1,4 @@ -/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2025 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -95,9 +95,14 @@ TfLiteStatus CalculateOpDataFullyConnected( filter->quantization.params); const int per_channel_quantization_size = affine_quantization->scale->size; - // Currently only Int8 is supported for per channel quantization. - TF_LITE_ENSURE(context, - input->type == kTfLiteInt8 && filter->type != kTfLiteInt4); + // Currently only Int8/Int16 are supported for per channel quantization. + TF_LITE_ENSURE( + context, + (input->type == kTfLiteInt8 && filter->type != kTfLiteInt4) || + (input->type == kTfLiteInt16 && filter->type != kTfLiteInt4)); + + TF_LITE_ENSURE_EQ(context, affine_quantization->scale->size, + per_channel_quantization_size); TF_LITE_ENSURE_EQ( context, per_channel_quantization_size, diff --git a/tensorflow/lite/micro/kernels/fully_connected_test.cc b/tensorflow/lite/micro/kernels/fully_connected_test.cc index 1197b105534..2ceed9ae983 100644 --- a/tensorflow/lite/micro/kernels/fully_connected_test.cc +++ b/tensorflow/lite/micro/kernels/fully_connected_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2025 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -45,6 +45,14 @@ const float simple_weights_data[] = { int simple_bias_dims[] = {1, 3}; const float simple_bias_data[] = {1, 2, 3}; +#if (defined(USE_TFLM_COMPRESSION) || (!defined(XTENSA) && !defined(HEXAGON))) + +constexpr size_t simple_bias_size = + std::extent::value; + +#endif // (defined(USE_TFLM_COMPRESSION) || (!defined(XTENSA) && + // !defined(HEXAGON))) + #ifdef USE_TFLM_COMPRESSION // compressed filter data for kBinQuant scheme @@ -60,8 +68,6 @@ constexpr int kBinQuantWeightBitWidth = 4; // Align the tensor data the same as a Buffer in the schema alignas(16) constexpr uint8_t kBinQuantBiasData[] = {0x18}; constexpr int kBinQuantBiasBitWidth = 2; -constexpr size_t simple_bias_size = - std::extent::value; #endif // USE_TFLM_COMPRESSION @@ -504,6 +510,58 @@ TfLiteStatus TestFullyConnectedQuantizedCompressed( #endif // USE_TFLM_COMPRESSION +template +TfLiteStatus TestFullyConnectedQuantizedPerChannel( + int* input_dims_data, const float* input_data, dataT* input_quantized, + const float input_scale, const int input_zero_point, int* weights_dims_data, + const float* weights_data, weightT* weights_quantized, + float* weights_scales, int* weights_zero_points, int* bias_dims_data, + const float* bias_data, biasT* bias_quantized, const float* golden, + dataT* golden_quantized, int* output_dims_data, const float output_scale, + const int output_zero_point, TfLiteFusedActivation activation, + dataT* output_data, TfLiteType weights_packed_type = kTfLiteNoType) { + TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data); + TfLiteIntArray* weights_dims = IntArrayFromInts(weights_dims_data); + TfLiteIntArray* bias_dims = IntArrayFromInts(bias_dims_data); + TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data); + const int output_dims_count = ElementCount(*output_dims); + bool null_bias = bias_data == nullptr ? true : false; + + constexpr int array_size = 4; // Avoid variable length array warning. + const int inputs_size = null_bias ? 2 : 3; + constexpr int outputs_size = 1; + const int tensors_size = inputs_size + outputs_size; + TfLiteTensor tensors[array_size]; + TfLiteAffineQuantization weights_quant, bias_quant; + float bias_scales[5]; + int bias_zero_points[5]; + + tensors[0] = CreateQuantizedTensor(input_data, input_quantized, input_dims, + input_scale, input_zero_point); + tensors[1] = CreateSymmetricPerChannelQuantizedTensorWithoutScaleEstimation( + weights_data, weights_quantized, weights_dims, weights_scales, + weights_zero_points, &weights_quant, 0 /* quantized dimension */, false, + weights_packed_type); + + if (null_bias) { + tensors[2] = CreateQuantizedTensor(output_data, output_dims, output_scale, + output_zero_point); + } else { + tensors[2] = CreatePerChannelQuantizedBiasTensor( + bias_data, bias_quantized, bias_dims, input_scale, &weights_scales[1], + bias_scales, bias_zero_points, &bias_quant, + 0 /* quantized dimension */); + tensors[3] = CreateQuantizedTensor(output_data, output_dims, output_scale, + output_zero_point); + } + + Quantize(golden, golden_quantized, output_dims_count, output_scale, + output_zero_point); + return ValidateFullyConnectedGoldens( + tensors, tensors_size, null_bias, activation, 1.0f /* tolerance */, + output_dims_count, golden_quantized, output_data); +} + } // namespace } // namespace testing } // namespace tflite @@ -652,6 +710,40 @@ TF_LITE_MICRO_TEST(SimpleTestQuantizedInt8Compressed) { #endif // USE_TFLM_COMPRESSION #if !defined(HEXAGON) + +#if !defined(XTENSA) + +TF_LITE_MICRO_TEST(SimpleTestQuantizedPerChannelInt8) { + const float input_scale = 0.5f; + const int input_zero_point = -1; + const float output_scale = 1.0f; + const int output_zero_point = -1; + int weights_zero_points[tflite::testing::simple_bias_size + 1] = { + tflite::testing::simple_bias_size, 0, 0, 0}; + float weights_scales[tflite::testing::simple_bias_size + 1] = { + tflite::testing::simple_bias_size, 0.2f, 0.25f, 0.5f}; + + int8_t input_quantized[tflite::testing::simple_input_size]; + int8_t weights_quantized[tflite::testing::simple_weights_size]; + int32_t bias_quantized[tflite::testing::simple_output_size]; + int8_t golden_quantized[tflite::testing::simple_output_size]; + int8_t output_data[tflite::testing::simple_output_size]; + + TF_LITE_MICRO_EXPECT_EQ( + tflite::testing::TestFullyConnectedQuantizedPerChannel( + tflite::testing::simple_input_dims, + tflite::testing::simple_input_data, input_quantized, input_scale, + input_zero_point, tflite::testing::simple_weights_dims, + tflite::testing::simple_weights_data, weights_quantized, + weights_scales, weights_zero_points, + tflite::testing::simple_bias_dims, tflite::testing::simple_bias_data, + bias_quantized, tflite::testing::simple_golden, golden_quantized, + tflite::testing::simple_output_dims, output_scale, output_zero_point, + kTfLiteActNone, output_data), + kTfLiteOk); +} +#endif // #if !defined(XTENSA) + TF_LITE_MICRO_TEST(SimpleTestQuantizedInt16) { const float input_scale = 128.0 / 65536; const int input_zero_point = 0; @@ -732,6 +824,40 @@ TF_LITE_MICRO_TEST(SimpleTestQuantizedInt16Compressed) { #endif // USE_TFLM_COMPRESSION +#if !defined(XTENSA) && !defined(CMSIS_NN) + +TF_LITE_MICRO_TEST(SimpleTestPerChannelQuantizedInt16) { + const float input_scale = 128.0 / 65536; + const int input_zero_point = 0; + const float output_scale = 128.0 / 65536; + const int output_zero_point = 0; + int weights_zero_points[tflite::testing::simple_bias_size + 1] = { + tflite::testing::simple_bias_size, 0, 0, 0}; + float weights_scales[tflite::testing::simple_bias_size + 1] = { + tflite::testing::simple_bias_size, 0.2f, 0.25f, 0.5f}; + + int16_t input_quantized[tflite::testing::simple_input_size]; + int8_t weights_quantized[tflite::testing::simple_weights_size]; + int64_t bias_quantized[tflite::testing::simple_output_size]; + int16_t golden_quantized[tflite::testing::simple_output_size]; + int16_t output_data[tflite::testing::simple_output_size]; + + TF_LITE_MICRO_EXPECT_EQ( + tflite::testing::TestFullyConnectedQuantizedPerChannel( + tflite::testing::simple_input_dims, + tflite::testing::simple_input_data, input_quantized, input_scale, + input_zero_point, tflite::testing::simple_weights_dims, + tflite::testing::simple_weights_data, weights_quantized, + weights_scales, weights_zero_points, + tflite::testing::simple_bias_dims, tflite::testing::simple_bias_data, + bias_quantized, tflite::testing::simple_golden, golden_quantized, + tflite::testing::simple_output_dims, output_scale, output_zero_point, + kTfLiteActNone, output_data), + kTfLiteOk); +} + +#endif // !defined(XTENSA) && !defined(CMSIS_NN) + #endif // !defined(HEXAGON) TF_LITE_MICRO_TEST(SimpleTest4DInputQuantizedInt8) { @@ -916,6 +1042,6 @@ TF_LITE_MICRO_TEST(SimpleTestQuantizedInt4Weights) { output_zero_point, kTfLiteActNone, output_data, kTfLiteInt4), kTfLiteOk); } -#endif +#endif // !defined(HEXAGON) TF_LITE_MICRO_TESTS_END diff --git a/tensorflow/lite/micro/kernels/transpose.cc b/tensorflow/lite/micro/kernels/transpose.cc index fd17e893937..70d53e2c449 100644 --- a/tensorflow/lite/micro/kernels/transpose.cc +++ b/tensorflow/lite/micro/kernels/transpose.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2025 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -103,10 +103,16 @@ TfLiteStatus TransposeEval(TfLiteContext* context, TfLiteNode* node) { tflite::micro::GetTensorShape(output), tflite::micro::GetTensorData(output)); break; + case kTfLiteInt16: + reference_ops::Transpose(params, tflite::micro::GetTensorShape(input), + tflite::micro::GetTensorData(input), + tflite::micro::GetTensorShape(output), + tflite::micro::GetTensorData(output)); + break; default: MicroPrintf( "Type %s is currently not supported by Transpose. " - "Only float32 and int8 is supported", + "Only float32, int8 and int16 are supported", TfLiteTypeGetName(input->type)); return kTfLiteError; } diff --git a/tensorflow/lite/micro/kernels/transpose_test.cc b/tensorflow/lite/micro/kernels/transpose_test.cc index 12bc431fbbd..14f433795e9 100644 --- a/tensorflow/lite/micro/kernels/transpose_test.cc +++ b/tensorflow/lite/micro/kernels/transpose_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2025 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -149,6 +149,20 @@ TF_LITE_MICRO_TEST(1D) { expected_output_data, output_data, ¶ms); } +TF_LITE_MICRO_TEST(1DInt16) { + int input_dims_data[] = {1, 3}; + int output_dims_data[] = {1, 3}; + + int16_t input_data[3]; + int16_t output_data[3]; + const int16_t expected_output_data[] = {0, 1, 2}; + + tflite::TransposeParams params = {1, {0}}; + + tflite::testing::TestTranspose(input_dims_data, input_data, output_dims_data, + expected_output_data, output_data, ¶ms); +} + TF_LITE_MICRO_TEST(2DPerm1) { int input_dims_data[] = {2, 3, 2}; int output_dims_data[] = {2, 3, 2}; @@ -163,6 +177,20 @@ TF_LITE_MICRO_TEST(2DPerm1) { expected_output_data, output_data, ¶ms); } +TF_LITE_MICRO_TEST(2DPerm1Int16) { + int input_dims_data[] = {2, 3, 2}; + int output_dims_data[] = {2, 3, 2}; + + int16_t input_data[6]; + int16_t output_data[6]; + const int16_t expected_output_data[] = {0, 2, 4, 1, 3, 5}; + + tflite::TransposeParams params = {2, {1, 0}}; + + tflite::testing::TestTranspose(input_dims_data, input_data, output_dims_data, + expected_output_data, output_data, ¶ms); +} + TF_LITE_MICRO_TEST(2D4x4KernelLeftOverRightSide) { int input_dims_data[] = {2, 4, 6}; int output_dims_data[] = {2, 4, 6}; @@ -179,6 +207,22 @@ TF_LITE_MICRO_TEST(2D4x4KernelLeftOverRightSide) { expected_output_data, output_data, ¶ms); } +TF_LITE_MICRO_TEST(2D4x4KernelLeftOverRightSideInt16) { + int input_dims_data[] = {2, 4, 6}; + int output_dims_data[] = {2, 4, 6}; + + int16_t input_data[24]; + int16_t output_data[24]; + const int16_t expected_output_data[] = {0, 6, 12, 18, 1, 7, 13, 19, + 2, 8, 14, 20, 3, 9, 15, 21, + 4, 10, 16, 22, 5, 11, 17, 23}; + + tflite::TransposeParams params = {2, {1, 0}}; + + tflite::testing::TestTranspose(input_dims_data, input_data, output_dims_data, + expected_output_data, output_data, ¶ms); +} + TF_LITE_MICRO_TEST(2D4x4KernelLeftOverBottomSide) { int input_dims_data[] = {2, 6, 4}; int output_dims_data[] = {2, 4, 6}; @@ -195,6 +239,22 @@ TF_LITE_MICRO_TEST(2D4x4KernelLeftOverBottomSide) { expected_output_data, output_data, ¶ms); } +TF_LITE_MICRO_TEST(2D4x4KernelLeftOverBottomSideInt16) { + int input_dims_data[] = {2, 6, 4}; + int output_dims_data[] = {2, 4, 6}; + + int16_t input_data[24]; + int16_t output_data[24]; + const int16_t expected_output_data[] = {0, 4, 8, 12, 16, 20, 1, 5, + 9, 13, 17, 21, 2, 6, 10, 14, + 18, 22, 3, 7, 11, 15, 19, 23}; + + tflite::TransposeParams params = {2, {1, 0}}; + + tflite::testing::TestTranspose(input_dims_data, input_data, output_dims_data, + expected_output_data, output_data, ¶ms); +} + TF_LITE_MICRO_TEST(3D) { int input_dims_data[] = {3, 2, 3, 4}; int output_dims_data[] = {3, 2, 3, 4}; @@ -211,6 +271,22 @@ TF_LITE_MICRO_TEST(3D) { expected_output_data, output_data, ¶ms); } +TF_LITE_MICRO_TEST(3DInt16) { + int input_dims_data[] = {3, 2, 3, 4}; + int output_dims_data[] = {3, 2, 3, 4}; + + int16_t input_data[24]; + int16_t output_data[24]; + const int16_t expected_output_data[] = {0, 4, 8, 12, 16, 20, 1, 5, + 9, 13, 17, 21, 2, 6, 10, 14, + 18, 22, 3, 7, 11, 15, 19, 23}; + + tflite::TransposeParams params = {3, {2, 0, 1}}; + + tflite::testing::TestTranspose(input_dims_data, input_data, output_dims_data, + expected_output_data, output_data, ¶ms); +} + TF_LITE_MICRO_TEST(1DNotShrinked) { int input_dims_data[] = {1, 1}; int output_dims_data[] = {1, 1}; diff --git a/tensorflow/lite/micro/kernels/unpack.cc b/tensorflow/lite/micro/kernels/unpack.cc index 9ce168384a4..8967decbec5 100644 --- a/tensorflow/lite/micro/kernels/unpack.cc +++ b/tensorflow/lite/micro/kernels/unpack.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2025 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -89,6 +89,9 @@ TfLiteStatus UnpackEval(TfLiteContext* context, TfLiteNode* node) { case kTfLiteInt8: { return UnpackImpl(context, node, input, data->num, data->axis); } + case kTfLiteInt16: { + return UnpackImpl(context, node, input, data->num, data->axis); + } default: { MicroPrintf("Type '%s' is not supported by unpack.", TfLiteTypeGetName(input->type)); diff --git a/tensorflow/lite/micro/test_helpers.cc b/tensorflow/lite/micro/test_helpers.cc index 9faa991dc9e..ff786a7b0b1 100644 --- a/tensorflow/lite/micro/test_helpers.cc +++ b/tensorflow/lite/micro/test_helpers.cc @@ -1,4 +1,4 @@ -/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2025 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -2055,6 +2055,31 @@ size_t GetModelTensorCount(const Model* model) { return 0; } +TfLiteTensor CreateSymmetricPerChannelQuantizedTensorWithoutScaleEstimation( + const float* input, int8_t* quantized, TfLiteIntArray* dims, float* scales, + int* zero_points, TfLiteAffineQuantization* affine_quant, + int quantized_dimension, bool is_variable, TfLiteType tensor_weight_type) { + int input_size = ElementCount(*dims); + int channel_count = dims->data[quantized_dimension]; + scales[0] = static_cast(channel_count); + zero_points[0] = channel_count; + + SymmetricPerChannelQuantize(input, quantized, input_size, + channel_count, &scales[1]); + + for (int i = 0; i < channel_count; i++) { + zero_points[i + 1] = 0; + } + + affine_quant->scale = FloatArrayFromFloats(scales); + affine_quant->zero_point = IntArrayFromInts(zero_points); + affine_quant->quantized_dimension = quantized_dimension; + TfLiteTensor result = + CreateTensor(quantized, dims, is_variable, tensor_weight_type); + result.quantization = {kTfLiteAffineQuantization, affine_quant}; + return result; +} + void PackInt4ValuesDenselyInPlace(uint8_t* src_buffer, int buffer_size) { for (int i = 0; i < buffer_size; ++i) { if (i % 2 == 0) { diff --git a/tensorflow/lite/micro/test_helpers.h b/tensorflow/lite/micro/test_helpers.h index 86eaf778f7b..f7bb3791415 100644 --- a/tensorflow/lite/micro/test_helpers.h +++ b/tensorflow/lite/micro/test_helpers.h @@ -1,4 +1,4 @@ -/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2025 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -376,6 +376,13 @@ TfLiteTensor CreateSymmetricPerChannelQuantizedTensor( int quantized_dimension, bool is_variable = false, TfLiteType tensor_weight_type = kTfLiteNoType); +// This function uses the scales provided to it and quantize based on the +// provided scale values +TfLiteTensor CreateSymmetricPerChannelQuantizedTensorWithoutScaleEstimation( + const float* input, int8_t* quantized, TfLiteIntArray* dims, float* scales, + int* zero_points, TfLiteAffineQuantization* affine_quant, + int quantized_dimension, bool is_variable, TfLiteType tensor_weight_type); + // Returns the number of tensors in the default subgraph for a tflite::Model. size_t GetModelTensorCount(const Model* model);