Skip to content
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions tensorflow/lite/micro/kernels/comparisons.cc
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,19 @@ TfLiteStatus GreaterEval(TfLiteContext* context, TfLiteNode* node) {
tflite::micro::GetTensorData<int8_t>(input2), output_shape,
output_data);
break;
case kTfLiteInt16:
requires_broadcast
? reference_ops::Broadcast4DSlowGreaterWithScaling(
data->params, input1_shape,
tflite::micro::GetTensorData<int16_t>(input1), input2_shape,
tflite::micro::GetTensorData<int16_t>(input2), output_shape,
output_data)
: reference_ops::GreaterWithScaling(
data->params, input1_shape,
tflite::micro::GetTensorData<int16_t>(input1), input2_shape,
tflite::micro::GetTensorData<int16_t>(input2), output_shape,
output_data);
break;
default:
MicroPrintf("Type %s (%d) not supported.",
TfLiteTypeGetName(input1->type), input1->type);
Expand Down
52 changes: 52 additions & 0 deletions tensorflow/lite/micro/kernels/comparisons_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,29 @@ void TestComparisonQuantizedInt8(const TFLMRegistration& registration,
TestComparison(registration, tensors, expected_output_data, output_data);
}

void TestComparisonQuantizedInt16(const TFLMRegistration& registration,
int* input1_dims_data, float* input1_data,
int16_t* input1_quantized, float input1_scale,
int input1_zero_point, int* input2_dims_data,
float* input2_data, int16_t* input2_quantized,
float input2_scale, int input2_zero_point,
bool* expected_output_data,
int* output_dims_data, bool* output_data) {
TfLiteIntArray* input1_dims = IntArrayFromInts(input1_dims_data);
TfLiteIntArray* input2_dims = IntArrayFromInts(input2_dims_data);
TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data);

TfLiteTensor tensors[tensors_size] = {
CreateQuantizedTensor(input1_data, input1_quantized, input1_dims,
input1_scale, input1_zero_point),
CreateQuantizedTensor(input2_data, input2_quantized, input2_dims,
input2_scale, input2_zero_point),
CreateTensor(output_data, output_dims),
};

TestComparison(registration, tensors, expected_output_data, output_data);
}

} // namespace
} // namespace testing
} // namespace tflite
Expand Down Expand Up @@ -656,6 +679,35 @@ TF_LITE_MICRO_TEST(GreaterQuantizedInt8WithBroadcast) {
}
}

TF_LITE_MICRO_TEST(GreaterQuantizedInt16WithBroadcast) {
  // Broadcast a single-element RHS against several LHS shapes that all hold
  // the same six values, so the expected results are identical per shape.
  constexpr int kNumShapes = 4;
  constexpr int kMaxShapeSize = 5;
  int test_shapes[kNumShapes][kMaxShapeSize] = {
      {1, 6}, {2, 2, 3}, {3, 2, 1, 3}, {4, 1, 3, 1, 2}};

  for (int shape_index = 0; shape_index < kNumShapes; ++shape_index) {
    int* input1_dim = test_shapes[shape_index];
    int input2_dim[] = {1, 1};  // one element, broadcast over input1
    float input1_data[] = {20, -2, -71, 8, 11, 20};
    float input2_data[] = {8};

    // Elementwise input1 > 8; note 8 > 8 is false.
    bool expected_data[] = {true, false, false, false, true, true};
    int* expected_dim = input1_dim;

    const float input1_scale = 0.5;
    const int input1_zero_point = -9;
    int16_t input1_quantized[6];
    int16_t input2_quantized[6];

    bool output_data[6];
    // Both inputs share the same scale/zero point so the comparison happens
    // on a common quantized grid.
    tflite::testing::TestComparisonQuantizedInt16(
        tflite::Register_GREATER(), input1_dim, input1_data, input1_quantized,
        input1_scale, input1_zero_point, input2_dim, input2_data,
        input2_quantized, input1_scale, input1_zero_point, expected_data,
        expected_dim, output_data);
  }
}

TF_LITE_MICRO_TEST(GreaterEqualQuantizedInt8WithBroadcast) {
const int num_shapes = 4;
const int max_shape_size = 5;
Expand Down
104 changes: 88 additions & 16 deletions tensorflow/lite/micro/kernels/fully_connected.cc
Original file line number Diff line number Diff line change
Expand Up @@ -238,25 +238,97 @@ TfLiteStatus FullyConnectedEval(TfLiteContext* context, TfLiteNode* node) {
case kTfLiteInt16: {
switch (filter->type) {
case kTfLiteInt8: {
tflite::reference_integer_ops::FullyConnected(
FullyConnectedParamsQuantized(data),
tflite::micro::GetTensorShape(input),
tflite::micro::GetTensorData<int16_t>(input),
tflite::micro::GetTensorShape(filter),
if (bias == nullptr || bias->type == kTfLiteInt32) {
data.is_per_channel
? tflite::reference_integer_ops::FullyConnectedPerChannel(
FullyConnectedParamsQuantized(data),
data.per_channel_output_multiplier,
reinterpret_cast<const int*>(
data.per_channel_output_shift),
tflite::micro::GetTensorShape(input),
tflite::micro::GetTensorData<int16_t>(input),
tflite::micro::GetTensorShape(filter),
#ifdef USE_TFLM_COMPRESSION
tflite::micro::GetTensorData<int8_t>(micro_context, filter,
weights_comp_td,
data.weights_scratch_index),
tflite::micro::GetTensorShape(bias),
tflite::micro::GetOptionalTensorData<int64_t>(
micro_context, bias, bias_comp_td, data.bias_scratch_index),
tflite::micro::GetTensorData<int8_t>(
micro_context, filter, weights_comp_td,
data.weights_scratch_index),
tflite::micro::GetTensorShape(bias),
tflite::micro::GetOptionalTensorData<int32_t>(
micro_context, bias, bias_comp_td,
data.bias_scratch_index),
#else // USE_TFLM_COMPRESSION
tflite::micro::GetTensorData<int8_t>(filter),
tflite::micro::GetTensorShape(bias),
tflite::micro::GetOptionalTensorData<int64_t>(bias),
tflite::micro::GetTensorData<int8_t>(filter),
tflite::micro::GetTensorShape(bias),
tflite::micro::GetOptionalTensorData<int32_t>(bias),
#endif // USE_TFLM_COMPRESSION
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<int16_t>(output));
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<int16_t>(output))
: tflite::reference_integer_ops::FullyConnected(
FullyConnectedParamsQuantized(data),
tflite::micro::GetTensorShape(input),
tflite::micro::GetTensorData<int16_t>(input),
tflite::micro::GetTensorShape(filter),
#ifdef USE_TFLM_COMPRESSION
tflite::micro::GetTensorData<int8_t>(
micro_context, filter, weights_comp_td,
data.weights_scratch_index),
tflite::micro::GetTensorShape(bias),
tflite::micro::GetOptionalTensorData<int32_t>(
micro_context, bias, bias_comp_td,
data.bias_scratch_index),
#else // USE_TFLM_COMPRESSION
tflite::micro::GetTensorData<int8_t>(filter),
tflite::micro::GetTensorShape(bias),
tflite::micro::GetOptionalTensorData<int32_t>(bias),
#endif // USE_TFLM_COMPRESSION
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<int16_t>(output));
} else if (bias->type == kTfLiteInt64) {
data.is_per_channel
? tflite::reference_integer_ops::FullyConnectedPerChannel(
FullyConnectedParamsQuantized(data),
data.per_channel_output_multiplier,
reinterpret_cast<const int*>(
data.per_channel_output_shift),
tflite::micro::GetTensorShape(input),
tflite::micro::GetTensorData<int16_t>(input),
tflite::micro::GetTensorShape(filter),
#ifdef USE_TFLM_COMPRESSION
tflite::micro::GetTensorData<int8_t>(
micro_context, filter, weights_comp_td,
data.weights_scratch_index),
tflite::micro::GetTensorShape(bias),
tflite::micro::GetOptionalTensorData<int64_t>(
micro_context, bias, bias_comp_td,
data.bias_scratch_index),
#else // USE_TFLM_COMPRESSION
tflite::micro::GetTensorData<int8_t>(filter),
tflite::micro::GetTensorShape(bias),
tflite::micro::GetOptionalTensorData<int64_t>(bias),
#endif // USE_TFLM_COMPRESSION
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<int16_t>(output))
: tflite::reference_integer_ops::FullyConnected(
FullyConnectedParamsQuantized(data),
tflite::micro::GetTensorShape(input),
tflite::micro::GetTensorData<int16_t>(input),
tflite::micro::GetTensorShape(filter),
#ifdef USE_TFLM_COMPRESSION
tflite::micro::GetTensorData<int8_t>(
micro_context, filter, weights_comp_td,
data.weights_scratch_index),
tflite::micro::GetTensorShape(bias),
tflite::micro::GetOptionalTensorData<int64_t>(
micro_context, bias, bias_comp_td,
data.bias_scratch_index),
#else // USE_TFLM_COMPRESSION
tflite::micro::GetTensorData<int8_t>(filter),
tflite::micro::GetTensorShape(bias),
tflite::micro::GetOptionalTensorData<int64_t>(bias),
#endif // USE_TFLM_COMPRESSION
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<int16_t>(output));
}
break;
}
default: {
Expand Down
11 changes: 8 additions & 3 deletions tensorflow/lite/micro/kernels/fully_connected_common.cc
Original file line number Diff line number Diff line change
Expand Up @@ -95,9 +95,14 @@ TfLiteStatus CalculateOpDataFullyConnected(
filter->quantization.params);
const int per_channel_quantization_size = affine_quantization->scale->size;

// Currently only Int8 is supported for per channel quantization.
TF_LITE_ENSURE(context,
input->type == kTfLiteInt8 && filter->type != kTfLiteInt4);
// Currently only Int8/Int16 are supported for per channel quantization.
TF_LITE_ENSURE(
context,
(input->type == kTfLiteInt8 && filter->type != kTfLiteInt4) ||
(input->type == kTfLiteInt16 && filter->type != kTfLiteInt4));

TF_LITE_ENSURE_EQ(context, affine_quantization->scale->size,
per_channel_quantization_size);

TF_LITE_ENSURE_EQ(
context, per_channel_quantization_size,
Expand Down
129 changes: 127 additions & 2 deletions tensorflow/lite/micro/kernels/fully_connected_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,14 @@ const float simple_weights_data[] = {
int simple_bias_dims[] = {1, 3};
const float simple_bias_data[] = {1, 2, 3};

#if (defined(USE_TFLM_COMPRESSION) || (!defined(XTENSA) && !defined(CMSIS_NN)))

constexpr size_t simple_bias_size =
std::extent<decltype(simple_bias_data)>::value;

#endif  // (defined(USE_TFLM_COMPRESSION) ||
        // (!defined(XTENSA) && !defined(CMSIS_NN)))

#ifdef USE_TFLM_COMPRESSION

// compressed filter data for kBinQuant scheme
Expand All @@ -60,8 +68,6 @@ constexpr int kBinQuantWeightBitWidth = 4;
// Align the tensor data the same as a Buffer in the schema
alignas(16) constexpr uint8_t kBinQuantBiasData[] = {0x18};
constexpr int kBinQuantBiasBitWidth = 2;
constexpr size_t simple_bias_size =
std::extent<decltype(simple_bias_data)>::value;

#endif // USE_TFLM_COMPRESSION

Expand Down Expand Up @@ -504,6 +510,58 @@ TfLiteStatus TestFullyConnectedQuantizedCompressed(

#endif // USE_TFLM_COMPRESSION

// Runs the fully-connected kernel with per-channel quantized weights and
// validates the quantized output against the float `golden` values.
//
// dataT:   activation type (int8_t or int16_t in the tests below)
// weightT: quantized weights storage type
// biasT:   bias type (int32_t with int8 activations, int64_t with int16)
//
// `weights_scales` / `weights_zero_points` follow the TFLM test convention:
// element 0 holds the channel count, followed by one entry per channel.
// A null `bias_data` builds the graph without a bias tensor (two inputs
// instead of three). Returns the kernel invocation status.
template <typename dataT, typename weightT, typename biasT>
TfLiteStatus TestFullyConnectedQuantizedPerChannel(
    int* input_dims_data, const float* input_data, dataT* input_quantized,
    const float input_scale, const int input_zero_point, int* weights_dims_data,
    const float* weights_data, weightT* weights_quantized,
    float* weights_scales, int* weights_zero_points, int* bias_dims_data,
    const float* bias_data, biasT* bias_quantized, const float* golden,
    dataT* golden_quantized, int* output_dims_data, const float output_scale,
    const int output_zero_point, TfLiteFusedActivation activation,
    dataT* output_data, TfLiteType weights_packed_type = kTfLiteNoType) {
  TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data);
  TfLiteIntArray* weights_dims = IntArrayFromInts(weights_dims_data);
  TfLiteIntArray* bias_dims = IntArrayFromInts(bias_dims_data);
  TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data);
  const int output_dims_count = ElementCount(*output_dims);
  bool null_bias = bias_data == nullptr ? true : false;

  constexpr int array_size = 4;  // Avoid variable length array warning.
  const int inputs_size = null_bias ? 2 : 3;
  constexpr int outputs_size = 1;
  const int tensors_size = inputs_size + outputs_size;
  TfLiteTensor tensors[array_size];
  TfLiteAffineQuantization weights_quant, bias_quant;
  // NOTE(review): fixed capacity of 5 assumes at most 4 output channels
  // (count element + 4 per-channel entries) — confirm against callers.
  float bias_scales[5];
  int bias_zero_points[5];

  tensors[0] = CreateQuantizedTensor(input_data, input_quantized, input_dims,
                                     input_scale, input_zero_point);
  tensors[1] = CreateSymmetricPerChannelQuantizedTensorWithoutScaleEstimation(
      weights_data, weights_quantized, weights_dims, weights_scales,
      weights_zero_points, &weights_quant, 0 /* quantized dimension */, false,
      weights_packed_type);

  if (null_bias) {
    // No bias: the output tensor takes the bias slot.
    tensors[2] = CreateQuantizedTensor(output_data, output_dims, output_scale,
                                       output_zero_point);
  } else {
    // &weights_scales[1] skips the leading channel-count element so the bias
    // helper sees only the per-channel scales.
    tensors[2] = CreatePerChannelQuantizedBiasTensor(
        bias_data, bias_quantized, bias_dims, input_scale, &weights_scales[1],
        bias_scales, bias_zero_points, &bias_quant,
        0 /* quantized dimension */);
    tensors[3] = CreateQuantizedTensor(output_data, output_dims, output_scale,
                                       output_zero_point);
  }

  // Quantize the float golden values into the output's quantization domain
  // before comparing against the kernel output.
  Quantize(golden, golden_quantized, output_dims_count, output_scale,
           output_zero_point);
  return ValidateFullyConnectedGoldens(
      tensors, tensors_size, null_bias, activation, 1.0f /* tolerance */,
      output_dims_count, golden_quantized, output_data);
}

} // namespace
} // namespace testing
} // namespace tflite
Expand Down Expand Up @@ -651,6 +709,40 @@ TF_LITE_MICRO_TEST(SimpleTestQuantizedInt8Compressed) {

#endif // USE_TFLM_COMPRESSION

#if !defined(XTENSA) && !defined(CMSIS_NN)

// Per-channel quantized fully-connected: int8 activations, int8 per-channel
// weights, int32 bias. Lives inside a guard that excludes XTENSA/CMSIS_NN
// builds.
TF_LITE_MICRO_TEST(SimpleTestQuantizedPerChannelInt8) {
  const float input_scale = 0.5f;
  const int input_zero_point = -1;
  const float output_scale = 1.0f;
  const int output_zero_point = -1;
  // TFLM test convention: element 0 is the channel count, followed by one
  // entry per output channel. Zero points are all 0 (symmetric weights).
  int weights_zero_points[tflite::testing::simple_bias_size + 1] = {
      tflite::testing::simple_bias_size, 0, 0, 0};
  float weights_scales[tflite::testing::simple_bias_size + 1] = {
      tflite::testing::simple_bias_size, 0.2f, 0.25f, 0.5f};

  int8_t input_quantized[tflite::testing::simple_input_size];
  int8_t weights_quantized[tflite::testing::simple_weights_size];
  int32_t bias_quantized[tflite::testing::simple_output_size];
  int8_t golden_quantized[tflite::testing::simple_output_size];
  int8_t output_data[tflite::testing::simple_output_size];

  TF_LITE_MICRO_EXPECT_EQ(
      tflite::testing::TestFullyConnectedQuantizedPerChannel(
          tflite::testing::simple_input_dims,
          tflite::testing::simple_input_data, input_quantized, input_scale,
          input_zero_point, tflite::testing::simple_weights_dims,
          tflite::testing::simple_weights_data, weights_quantized,
          weights_scales, weights_zero_points,
          tflite::testing::simple_bias_dims, tflite::testing::simple_bias_data,
          bias_quantized, tflite::testing::simple_golden, golden_quantized,
          tflite::testing::simple_output_dims, output_scale, output_zero_point,
          kTfLiteActNone, output_data),
      kTfLiteOk);
}

#endif // #if !defined(XTENSA) && !defined(CMSIS_NN)

#if !defined(HEXAGON)
TF_LITE_MICRO_TEST(SimpleTestQuantizedInt16) {
const float input_scale = 128.0 / 65536;
Expand Down Expand Up @@ -732,6 +824,39 @@ TF_LITE_MICRO_TEST(SimpleTestQuantizedInt16Compressed) {

#endif // USE_TFLM_COMPRESSION

#if !defined(XTENSA) && !defined(CMSIS_NN)

// Per-channel quantized fully-connected: int16 activations, int8 per-channel
// weights, int64 bias. Zero points are 0, as required for int16 activations.
TF_LITE_MICRO_TEST(SimpleTestPerChannelQuantizedInt16) {
  const float input_scale = 128.0 / 65536;
  const int input_zero_point = 0;
  const float output_scale = 128.0 / 65536;
  const int output_zero_point = 0;
  // TFLM test convention: element 0 is the channel count, followed by one
  // entry per output channel. Zero points are all 0 (symmetric weights).
  int weights_zero_points[tflite::testing::simple_bias_size + 1] = {
      tflite::testing::simple_bias_size, 0, 0, 0};
  float weights_scales[tflite::testing::simple_bias_size + 1] = {
      tflite::testing::simple_bias_size, 0.2f, 0.25f, 0.5f};

  int16_t input_quantized[tflite::testing::simple_input_size];
  int8_t weights_quantized[tflite::testing::simple_weights_size];
  // int16 activations pair with 64-bit bias accumulation.
  int64_t bias_quantized[tflite::testing::simple_output_size];
  int16_t golden_quantized[tflite::testing::simple_output_size];
  int16_t output_data[tflite::testing::simple_output_size];

  TF_LITE_MICRO_EXPECT_EQ(
      tflite::testing::TestFullyConnectedQuantizedPerChannel(
          tflite::testing::simple_input_dims,
          tflite::testing::simple_input_data, input_quantized, input_scale,
          input_zero_point, tflite::testing::simple_weights_dims,
          tflite::testing::simple_weights_data, weights_quantized,
          weights_scales, weights_zero_points,
          tflite::testing::simple_bias_dims, tflite::testing::simple_bias_data,
          bias_quantized, tflite::testing::simple_golden, golden_quantized,
          tflite::testing::simple_output_dims, output_scale, output_zero_point,
          kTfLiteActNone, output_data),
      kTfLiteOk);
}

#endif // #if !defined(XTENSA) && !defined(CMSIS_NN)
#endif // !defined(HEXAGON)

TF_LITE_MICRO_TEST(SimpleTest4DInputQuantizedInt8) {
Expand Down
8 changes: 7 additions & 1 deletion tensorflow/lite/micro/kernels/transpose.cc
Original file line number Diff line number Diff line change
Expand Up @@ -103,10 +103,16 @@ TfLiteStatus TransposeEval(TfLiteContext* context, TfLiteNode* node) {
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<int8_t>(output));
break;
case kTfLiteInt16:
reference_ops::Transpose(params, tflite::micro::GetTensorShape(input),
tflite::micro::GetTensorData<int16_t>(input),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<int16_t>(output));
break;
default:
MicroPrintf(
"Type %s is currently not supported by Transpose. "
"Only float32 and int8 is supported",
"Only float32, int8 and int16 are supported",
TfLiteTypeGetName(input->type));
return kTfLiteError;
}
Expand Down
Loading
Loading