
Commit 099774d

feat(compression): implement tensor decompression in op transpose conv (#3018)
Implement tensor decompression in op transpose conv. Extend tests to validate operation on compressed tensors. BUG=part of #2636
1 parent f9fecab commit 099774d
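For orientation before reading the diff: with USE_TFLM_COMPRESSION defined, the kernel gains two per-node scratch indices and allocates a decompression scratch buffer for each compressible input during prepare. The C++ sketch below condenses that prepare-time wiring; identifiers are taken from the diff that follows, but it is a trimmed illustration, not a drop-in excerpt (tensor setup and error handling are omitted).

// Condensed sketch of the prepare-time compression wiring; identifiers match
// the diff below, surrounding code is omitted.
#ifdef USE_TFLM_COMPRESSION

struct OpData {
  ConvParams params;
  int scratch_buffer_index;         // existing quantization scratch buffer
  int filter_scratch_index;         // scratch for the decompressed filter
  int bias_scratch_index;           // scratch for the decompressed bias
  int bias_converted_buffer_index;  // existing 16x8 bias-widening buffer
};

TfLiteStatus TransposeConvPrepare(TfLiteContext* context, TfLiteNode* node) {
  MicroContext* micro_context = GetMicroContext(context);
  OpData* data = static_cast<OpData*>(node->user_data);
  // Per the diff, these buffers are only allocated if the corresponding
  // tensor is actually compressed.
  data->filter_scratch_index =
      micro_context->AllocateDecompressionScratchBuffer(
          node, kTransposeConvFilterTensor);
  data->bias_scratch_index =
      micro_context->AllocateDecompressionScratchBuffer(
          node, kTransposeConvBiasTensor);
  return kTfLiteOk;
}

#endif  // USE_TFLM_COMPRESSION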

File tree

4 files changed: +705 -78 lines changed

tensorflow/lite/micro/kernels/transpose_conv.cc

+92 -35
@@ -1,4 +1,4 @@
-/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -27,30 +27,26 @@ limitations under the License.
 #include "tensorflow/lite/kernels/kernel_util.h"
 #include "tensorflow/lite/kernels/padding.h"
 #include "tensorflow/lite/micro/kernels/kernel_util.h"
+#include "tensorflow/lite/micro/kernels/transpose_conv.h"
 #include "tensorflow/lite/micro/micro_log.h"
 
 namespace tflite {
 namespace {
 
-// For the TfLite transpose_conv implementation, input tensor 0 corresponds to
-// the OutputShapeTensor. However, since TFLM does not support dynamic tensors,
-// the TFLM implementation ignores input tensor 0 and the only inputs we care
-// about are kFilterTensor, kInputTensor and kBiasTensor.
-constexpr int kFilterTensor = 1;
-constexpr int kInputTensor = 2;
-constexpr int kBiasTensor = 3;
-constexpr int kOutputTensor = 0;
-
-// Conv is quantized along dimension 0:
-// https://www.tensorflow.org/lite/performance/quantization_spec
-constexpr int kConvQuantizedDimension = 0;
-
 struct OpData {
   ConvParams params;
 
   // A scratch buffer is required for quantized implementations.
   int scratch_buffer_index;
 
+#ifdef USE_TFLM_COMPRESSION
+
+  // scratch buffers for compressed tensors
+  int filter_scratch_index;
+  int bias_scratch_index;
+
+#endif  // USE_TFLM_COMPRESSION
+
   // Index to the converted 64-bit bias buffer from 16-bit bias. This is
   // required to handle 16x8 transpose convolutions where a 16-bit bias is
   // provided, whereas the kernel expects 64-bit biases.
@@ -102,17 +98,17 @@ TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node,
   MicroContext* micro_context = GetMicroContext(context);
 
   TfLiteTensor* input =
-      micro_context->AllocateTempInputTensor(node, kInputTensor);
+      micro_context->AllocateTempInputTensor(node, kTransposeConvInputTensor);
   TF_LITE_ENSURE(context, input != nullptr);
-  TfLiteTensor* filter =
-      micro_context->AllocateTempInputTensor(node, kFilterTensor);
+  TfLiteTensor* filter = micro_context->AllocateTempInputTensor(
+      node, kTransposeConvFilterTensor);
   TF_LITE_ENSURE(context, filter != nullptr);
   TfLiteTensor* bias =
-      micro_context->AllocateTempInputTensor(node, kBiasTensor);
-  TfLiteTensor* output =
-      micro_context->AllocateTempOutputTensor(node, kOutputTensor);
+      micro_context->AllocateTempInputTensor(node, kTransposeConvBiasTensor);
+  TfLiteTensor* output = micro_context->AllocateTempOutputTensor(
+      node, kTransposeConvOutputTensor);
   TF_LITE_ENSURE(context, output != nullptr);
-  int output_channels = filter->dims->data[kConvQuantizedDimension];
+  int output_channels = filter->dims->data[kTransposeConvQuantizedDimension];
 
   TF_LITE_ENSURE_STATUS(tflite::PopulateConvolutionQuantizationParams(
       context, input, filter, bias, output, kTfLiteActNone,
@@ -164,13 +160,13 @@ TfLiteStatus TransposeConvPrepare(TfLiteContext* context, TfLiteNode* node) {
   MicroContext* micro_context = GetMicroContext(context);
 
   TfLiteTensor* output =
-      micro_context->AllocateTempOutputTensor(node, kOutputTensor);
+      micro_context->AllocateTempOutputTensor(node, kTransposeConvOutputTensor);
   TF_LITE_ENSURE(context, output != nullptr);
   TfLiteTensor* input =
-      micro_context->AllocateTempInputTensor(node, kInputTensor);
+      micro_context->AllocateTempInputTensor(node, kTransposeConvInputTensor);
   TF_LITE_ENSURE(context, input != nullptr);
   TfLiteTensor* filter =
-      micro_context->AllocateTempInputTensor(node, kFilterTensor);
+      micro_context->AllocateTempInputTensor(node, kTransposeConvFilterTensor);
   TF_LITE_ENSURE(context, filter != nullptr);
 
   TF_LITE_ENSURE_MSG(
@@ -186,7 +182,7 @@ TfLiteStatus TransposeConvPrepare(TfLiteContext* context, TfLiteNode* node) {
   const int filter_height = SizeOfDimension(filter, 1);
 
   // Dynamically allocate per-channel quantization parameters.
-  const int num_channels = filter->dims->data[kConvQuantizedDimension];
+  const int num_channels = filter->dims->data[kTransposeConvQuantizedDimension];
   data->per_channel_output_multiplier =
       static_cast<int32_t*>(context->AllocatePersistentBuffer(
           context, num_channels * sizeof(int32_t)));
@@ -223,10 +219,10 @@ TfLiteStatus TransposeConvPrepare(TfLiteContext* context, TfLiteNode* node) {
     TF_LITE_ENSURE(context, affine_quantization->scale);
     TF_LITE_ENSURE(context, affine_quantization->zero_point);
 
-    TF_LITE_ENSURE(context,
-                   affine_quantization->scale->size == 1 ||
-                       affine_quantization->scale->size ==
-                           filter->dims->data[kConvQuantizedDimension]);
+    TF_LITE_ENSURE(
+        context, affine_quantization->scale->size == 1 ||
+                     affine_quantization->scale->size ==
+                         filter->dims->data[kTransposeConvQuantizedDimension]);
     TF_LITE_ENSURE_EQ(context, affine_quantization->scale->size,
                       affine_quantization->zero_point->size);
   }
@@ -244,6 +240,18 @@ TfLiteStatus TransposeConvPrepare(TfLiteContext* context, TfLiteNode* node) {
   data->params.stride_width = params->stride_width;
   data->params.stride_height = params->stride_height;
 
+#ifdef USE_TFLM_COMPRESSION
+
+  // Compression scratch buffers.
+  // These will only be allocated if the tensor is compressed.
+  data->filter_scratch_index =
+      micro_context->AllocateDecompressionScratchBuffer(
+          node, kTransposeConvFilterTensor);
+  data->bias_scratch_index = micro_context->AllocateDecompressionScratchBuffer(
+      node, kTransposeConvBiasTensor);
+
+#endif  // USE_TFLM_COMPRESSION
+
   micro_context->DeallocateTempTfLiteTensor(output);
   micro_context->DeallocateTempTfLiteTensor(input);
   micro_context->DeallocateTempTfLiteTensor(filter);
@@ -252,15 +260,26 @@ TfLiteStatus TransposeConvPrepare(TfLiteContext* context, TfLiteNode* node) {
 
 TfLiteStatus TransposeConvEval(TfLiteContext* context, TfLiteNode* node) {
   const TfLiteEvalTensor* input =
-      tflite::micro::GetEvalInput(context, node, kInputTensor);
+      tflite::micro::GetEvalInput(context, node, kTransposeConvInputTensor);
   const TfLiteEvalTensor* filter =
-      tflite::micro::GetEvalInput(context, node, kFilterTensor);
+      tflite::micro::GetEvalInput(context, node, kTransposeConvFilterTensor);
   const TfLiteEvalTensor* bias =
       (NumInputs(node) == 4)
-          ? tflite::micro::GetEvalInput(context, node, kBiasTensor)
+          ? tflite::micro::GetEvalInput(context, node, kTransposeConvBiasTensor)
          : nullptr;
   TfLiteEvalTensor* output =
-      tflite::micro::GetEvalOutput(context, node, kOutputTensor);
+      tflite::micro::GetEvalOutput(context, node, kTransposeConvOutputTensor);
+
+#ifdef USE_TFLM_COMPRESSION
+
+  MicroContext* micro_context = GetMicroContext(context);
+
+  const CompressionTensorData* filter_comp_td =
+      micro_context->GetTensorCompressionData(node, kTransposeConvFilterTensor);
+  const CompressionTensorData* bias_comp_td =
+      micro_context->GetTensorCompressionData(node, kTransposeConvBiasTensor);
+
+#endif  // USE_TFLM_COMPRESSION
 
   TFLITE_DCHECK(node->user_data != nullptr);
   const OpData& data = *(static_cast<const OpData*>(node->user_data));
@@ -280,9 +299,17 @@ TfLiteStatus TransposeConvEval(TfLiteContext* context, TfLiteNode* node) {
           op_params, tflite::micro::GetTensorShape(input),
          tflite::micro::GetTensorData<float>(input),
          tflite::micro::GetTensorShape(filter),
+#ifdef USE_TFLM_COMPRESSION
+          tflite::micro::GetTensorData<float>(
+              micro_context, filter, filter_comp_td, data.filter_scratch_index),
+          tflite::micro::GetTensorShape(bias),
+          tflite::micro::GetOptionalTensorData<float>(
+              micro_context, bias, bias_comp_td, data.bias_scratch_index),
+#else  // USE_TFLM_COMPRESSION
          tflite::micro::GetTensorData<float>(filter),
          tflite::micro::GetTensorShape(bias),
          tflite::micro::GetOptionalTensorData<float>(bias),
+#endif  // USE_TFLM_COMPRESSION
          tflite::micro::GetTensorShape(output),
          tflite::micro::GetTensorData<float>(output),
          tflite::micro::GetTensorShape(nullptr), nullptr);
@@ -296,9 +323,17 @@ TfLiteStatus TransposeConvEval(TfLiteContext* context, TfLiteNode* node) {
          data.per_channel_output_shift, tflite::micro::GetTensorShape(input),
          tflite::micro::GetTensorData<int8_t>(input),
          tflite::micro::GetTensorShape(filter),
+#ifdef USE_TFLM_COMPRESSION
+          tflite::micro::GetTensorData<int8_t>(
+              micro_context, filter, filter_comp_td, data.filter_scratch_index),
+          tflite::micro::GetTensorShape(bias),
+          tflite::micro::GetOptionalTensorData<int32_t>(
+              micro_context, bias, bias_comp_td, data.bias_scratch_index),
+#else  // USE_TFLM_COMPRESSION
          tflite::micro::GetTensorData<int8_t>(filter),
          tflite::micro::GetTensorShape(bias),
          tflite::micro::GetOptionalTensorData<int32_t>(bias),
+#endif  // USE_TFLM_COMPRESSION
          tflite::micro::GetTensorShape(output),
          tflite::micro::GetTensorData<int8_t>(output),
          tflite::micro::GetTensorShape(nullptr), nullptr, scratch_buffer);
@@ -311,16 +346,29 @@ TfLiteStatus TransposeConvEval(TfLiteContext* context, TfLiteNode* node) {
        auto* bias_converted_buffer =
            static_cast<int64_t*>(context->GetScratchBuffer(
                context, data.bias_converted_buffer_index));
+        const int16_t* const bias_int16_data =
+#ifdef USE_TFLM_COMPRESSION
+            tflite::micro::GetTensorData<int16_t>(
+                micro_context, bias, bias_comp_td, data.bias_scratch_index);
+#else  // USE_TFLM_COMPRESSION
+            static_cast<int16_t*>(bias->data.data);
+#endif  // USE_TFLM_COMPRESSION
        for (int i = 0; i < tflite::micro::GetTensorShape(bias).FlatSize();
             i++) {
-          bias_converted_buffer[i] = bias->data.i16[i];
+          bias_converted_buffer[i] = bias_int16_data[i];
        }
        reference_integer_ops::TransposeConv(
            data.params, data.per_channel_output_multiplier,
            data.per_channel_output_shift, tflite::micro::GetTensorShape(input),
            tflite::micro::GetTensorData<int16_t>(input),
            tflite::micro::GetTensorShape(filter),
+#ifdef USE_TFLM_COMPRESSION
+            tflite::micro::GetTensorData<int8_t>(micro_context, filter,
+                                                 filter_comp_td,
+                                                 data.filter_scratch_index),
+#else  // USE_TFLM_COMPRESSION
            tflite::micro::GetTensorData<int8_t>(filter),
+#endif  // USE_TFLM_COMPRESSION
            tflite::micro::GetTensorShape(bias), bias_converted_buffer,
            tflite::micro::GetTensorShape(output),
            tflite::micro::GetTensorData<int16_t>(output),
@@ -331,9 +379,18 @@ TfLiteStatus TransposeConvEval(TfLiteContext* context, TfLiteNode* node) {
            data.per_channel_output_shift, tflite::micro::GetTensorShape(input),
            tflite::micro::GetTensorData<int16_t>(input),
            tflite::micro::GetTensorShape(filter),
+#ifdef USE_TFLM_COMPRESSION
+            tflite::micro::GetTensorData<int8_t>(micro_context, filter,
+                                                 filter_comp_td,
+                                                 data.filter_scratch_index),
+            tflite::micro::GetTensorShape(bias),
+            tflite::micro::GetOptionalTensorData<int64_t>(
+                micro_context, bias, bias_comp_td, data.bias_scratch_index),
+#else  // USE_TFLM_COMPRESSION
            tflite::micro::GetTensorData<int8_t>(filter),
            tflite::micro::GetTensorShape(bias),
-            tflite::micro::GetOptionalTensorData<std::int64_t>(bias),
+            tflite::micro::GetOptionalTensorData<int64_t>(bias),
+#endif  // USE_TFLM_COMPRESSION
            tflite::micro::GetTensorShape(output),
            tflite::micro::GetTensorData<int16_t>(output),
            tflite::micro::GetTensorShape(nullptr), nullptr, scratch_buffer);

tensorflow/lite/micro/kernels/transpose_conv.h

+14 -1
@@ -1,4 +1,4 @@
-/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -23,6 +23,19 @@ limitations under the License.
 
 namespace tflite {
 
+// For the TfLite transpose_conv implementation, input tensor 0 corresponds to
+// the OutputShapeTensor. However, since TFLM does not support dynamic tensors,
+// the TFLM implementation ignores input tensor 0 and the only inputs we care
+// about are kFilterTensor, kInputTensor and kBiasTensor.
+constexpr int kTransposeConvFilterTensor = 1;
+constexpr int kTransposeConvInputTensor = 2;
+constexpr int kTransposeConvBiasTensor = 3;
+constexpr int kTransposeConvOutputTensor = 0;
+
+// Conv is quantized along dimension 0:
+// https://www.tensorflow.org/lite/performance/quantization_spec
+constexpr int kTransposeConvQuantizedDimension = 0;
+
 // This is the most generic TFLMRegistration. The actual supported types
 // may still be target dependent. The only requirement is that every
 // implementation (reference or optimized) must define this function.
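The header change above promotes the tensor-index constants from file-local constants in transpose_conv.cc to shared, TransposeConv-prefixed names, so optimized kernel variants and the compression-aware tests can refer to the same indices. An illustrative (hypothetical, not part of the commit) use from another translation unit:

// Hypothetical example: a translation unit that includes transpose_conv.h can
// use the shared indices instead of redefining local kFilterTensor constants.
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/micro/kernels/transpose_conv.h"
#include "tensorflow/lite/micro/micro_context.h"

TfLiteStatus InspectTransposeConvFilter(TfLiteContext* context,
                                        TfLiteNode* node) {
  tflite::MicroContext* micro_context = tflite::GetMicroContext(context);
  TfLiteTensor* filter = micro_context->AllocateTempInputTensor(
      node, tflite::kTransposeConvFilterTensor);
  if (filter == nullptr) {
    return kTfLiteError;
  }
  // Transpose conv filters are quantized per-channel along dimension 0.
  const int num_channels =
      filter->dims->data[tflite::kTransposeConvQuantizedDimension];
  (void)num_channels;
  micro_context->DeallocateTempTfLiteTensor(filter);
  return kTfLiteOk;
}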
