- /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+ /* Copyright 2024 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -27,30 +27,26 @@ limitations under the License.
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/padding.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
+ #include "tensorflow/lite/micro/kernels/transpose_conv.h"
#include "tensorflow/lite/micro/micro_log.h"

namespace tflite {
namespace {

- // For the TfLite transpose_conv implementation, input tensor 0 corresponds to
- // the OutputShapeTensor. However, since TFLM does not support dynamic tensors,
- // the TFLM implementation ignores input tensor 0 and the only inputs we care
- // about are kFilterTensor, kInputTensor and kBiasTensor.
- constexpr int kFilterTensor = 1;
- constexpr int kInputTensor = 2;
- constexpr int kBiasTensor = 3;
- constexpr int kOutputTensor = 0;
-
- // Conv is quantized along dimension 0:
- // https://www.tensorflow.org/lite/performance/quantization_spec
- constexpr int kConvQuantizedDimension = 0;
-
struct OpData {
  ConvParams params;

  // A scratch buffer is required for quantized implementations.
  int scratch_buffer_index;

+ #ifdef USE_TFLM_COMPRESSION
+
+   // scratch buffers for compressed tensors
+   int filter_scratch_index;
+   int bias_scratch_index;
+
+ #endif  // USE_TFLM_COMPRESSION
+
  // Index to the converted 64-bit bias buffer from 16-bit bias. This is
  // required to handle 16x8 transpose convolutions where a 16-bit bias is
  // provided, whereas the kernel expects 64-bit biases.
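The tensor-index constants removed in this hunk are not simply dropped; they appear to move into the new tensorflow/lite/micro/kernels/transpose_conv.h header added to the includes above, under shared kTransposeConv* names. A minimal sketch of what that header presumably declares (names taken from the call sites later in this diff, values carried over from the deleted lines; the actual header may differ):

// Sketch only -- inferred from this diff, not copied from the real header.
// As the deleted comment explained, input tensor 0 is the output-shape tensor
// used by TfLite's dynamic implementation; TFLM ignores it.
constexpr int kTransposeConvFilterTensor = 1;
constexpr int kTransposeConvInputTensor = 2;
constexpr int kTransposeConvBiasTensor = 3;
constexpr int kTransposeConvOutputTensor = 0;

// Quantized along dimension 0:
// https://www.tensorflow.org/lite/performance/quantization_spec
constexpr int kTransposeConvQuantizedDimension = 0;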
@@ -102,17 +98,17 @@ TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node,
  MicroContext* micro_context = GetMicroContext(context);

  TfLiteTensor* input =
-       micro_context->AllocateTempInputTensor(node, kInputTensor);
+       micro_context->AllocateTempInputTensor(node, kTransposeConvInputTensor);
  TF_LITE_ENSURE(context, input != nullptr);
-   TfLiteTensor* filter =
-       micro_context->AllocateTempInputTensor(node, kFilterTensor);
+   TfLiteTensor* filter = micro_context->AllocateTempInputTensor(
+       node, kTransposeConvFilterTensor);
  TF_LITE_ENSURE(context, filter != nullptr);
  TfLiteTensor* bias =
-       micro_context->AllocateTempInputTensor(node, kBiasTensor);
-   TfLiteTensor* output =
-       micro_context->AllocateTempOutputTensor(node, kOutputTensor);
+       micro_context->AllocateTempInputTensor(node, kTransposeConvBiasTensor);
+   TfLiteTensor* output = micro_context->AllocateTempOutputTensor(
+       node, kTransposeConvOutputTensor);
  TF_LITE_ENSURE(context, output != nullptr);
-   int output_channels = filter->dims->data[kConvQuantizedDimension];
+   int output_channels = filter->dims->data[kTransposeConvQuantizedDimension];

  TF_LITE_ENSURE_STATUS(tflite::PopulateConvolutionQuantizationParams(
      context, input, filter, bias, output, kTfLiteActNone,
@@ -164,13 +160,13 @@ TfLiteStatus TransposeConvPrepare(TfLiteContext* context, TfLiteNode* node) {
  MicroContext* micro_context = GetMicroContext(context);

  TfLiteTensor* output =
-       micro_context->AllocateTempOutputTensor(node, kOutputTensor);
+       micro_context->AllocateTempOutputTensor(node, kTransposeConvOutputTensor);
  TF_LITE_ENSURE(context, output != nullptr);
  TfLiteTensor* input =
-       micro_context->AllocateTempInputTensor(node, kInputTensor);
+       micro_context->AllocateTempInputTensor(node, kTransposeConvInputTensor);
  TF_LITE_ENSURE(context, input != nullptr);
  TfLiteTensor* filter =
-       micro_context->AllocateTempInputTensor(node, kFilterTensor);
+       micro_context->AllocateTempInputTensor(node, kTransposeConvFilterTensor);
  TF_LITE_ENSURE(context, filter != nullptr);

  TF_LITE_ENSURE_MSG(
@@ -186,7 +182,7 @@ TfLiteStatus TransposeConvPrepare(TfLiteContext* context, TfLiteNode* node) {
  const int filter_height = SizeOfDimension(filter, 1);

  // Dynamically allocate per-channel quantization parameters.
-   const int num_channels = filter->dims->data[kConvQuantizedDimension];
+   const int num_channels = filter->dims->data[kTransposeConvQuantizedDimension];
  data->per_channel_output_multiplier =
      static_cast<int32_t*>(context->AllocatePersistentBuffer(
          context, num_channels * sizeof(int32_t)));
@@ -223,10 +219,10 @@ TfLiteStatus TransposeConvPrepare(TfLiteContext* context, TfLiteNode* node) {
    TF_LITE_ENSURE(context, affine_quantization->scale);
    TF_LITE_ENSURE(context, affine_quantization->zero_point);

-     TF_LITE_ENSURE(context,
-                    affine_quantization->scale->size == 1 ||
-                        affine_quantization->scale->size ==
-                            filter->dims->data[kConvQuantizedDimension]);
+     TF_LITE_ENSURE(
+         context, affine_quantization->scale->size == 1 ||
+                      affine_quantization->scale->size ==
+                          filter->dims->data[kTransposeConvQuantizedDimension]);
    TF_LITE_ENSURE_EQ(context, affine_quantization->scale->size,
                      affine_quantization->zero_point->size);
  }
@@ -244,6 +240,18 @@ TfLiteStatus TransposeConvPrepare(TfLiteContext* context, TfLiteNode* node) {
  data->params.stride_width = params->stride_width;
  data->params.stride_height = params->stride_height;

+ #ifdef USE_TFLM_COMPRESSION
+
+   // Compression scratch buffers.
+   // These will only be allocated if the tensor is compressed.
+   data->filter_scratch_index =
+       micro_context->AllocateDecompressionScratchBuffer(
+           node, kTransposeConvFilterTensor);
+   data->bias_scratch_index = micro_context->AllocateDecompressionScratchBuffer(
+       node, kTransposeConvBiasTensor);
+
+ #endif  // USE_TFLM_COMPRESSION
+
  micro_context->DeallocateTempTfLiteTensor(output);
  micro_context->DeallocateTempTfLiteTensor(input);
  micro_context->DeallocateTempTfLiteTensor(filter);
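The two AllocateDecompressionScratchBuffer() calls added above reserve per-tensor scratch space at Prepare time; the Eval changes further down fetch the matching compression metadata and decompress into that space on demand. A condensed sketch of the pattern as this PR uses it (not code from the PR itself), assuming the returned index is simply handed back to the decompressing GetTensorData() overload and that uncompressed tensors fall back to their regular buffers:

#ifdef USE_TFLM_COMPRESSION
// Prepare (OpData* data): reserve scratch space for the decompressed filter.
data->filter_scratch_index =
    micro_context->AllocateDecompressionScratchBuffer(
        node, kTransposeConvFilterTensor);

// Eval (const OpData& data): look up the compression metadata and decompress
// into that scratch buffer, obtaining a dense pointer for the kernel.
const CompressionTensorData* filter_comp_td =
    micro_context->GetTensorCompressionData(node, kTransposeConvFilterTensor);
const int8_t* filter_data = tflite::micro::GetTensorData<int8_t>(
    micro_context, filter, filter_comp_td, data.filter_scratch_index);
#endif  // USE_TFLM_COMPRESSION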
@@ -252,15 +260,26 @@ TfLiteStatus TransposeConvPrepare(TfLiteContext* context, TfLiteNode* node) {

TfLiteStatus TransposeConvEval(TfLiteContext* context, TfLiteNode* node) {
  const TfLiteEvalTensor* input =
-       tflite::micro::GetEvalInput(context, node, kInputTensor);
+       tflite::micro::GetEvalInput(context, node, kTransposeConvInputTensor);
  const TfLiteEvalTensor* filter =
-       tflite::micro::GetEvalInput(context, node, kFilterTensor);
+       tflite::micro::GetEvalInput(context, node, kTransposeConvFilterTensor);
  const TfLiteEvalTensor* bias =
      (NumInputs(node) == 4)
-           ? tflite::micro::GetEvalInput(context, node, kBiasTensor)
+           ? tflite::micro::GetEvalInput(context, node, kTransposeConvBiasTensor)
          : nullptr;
  TfLiteEvalTensor* output =
-       tflite::micro::GetEvalOutput(context, node, kOutputTensor);
+       tflite::micro::GetEvalOutput(context, node, kTransposeConvOutputTensor);
+
+ #ifdef USE_TFLM_COMPRESSION
+
+   MicroContext* micro_context = GetMicroContext(context);
+
+   const CompressionTensorData* filter_comp_td =
+       micro_context->GetTensorCompressionData(node, kTransposeConvFilterTensor);
+   const CompressionTensorData* bias_comp_td =
+       micro_context->GetTensorCompressionData(node, kTransposeConvBiasTensor);
+
+ #endif  // USE_TFLM_COMPRESSION

  TFLITE_DCHECK(node->user_data != nullptr);
  const OpData& data = *(static_cast<const OpData*>(node->user_data));
@@ -280,9 +299,17 @@ TfLiteStatus TransposeConvEval(TfLiteContext* context, TfLiteNode* node) {
          op_params, tflite::micro::GetTensorShape(input),
          tflite::micro::GetTensorData<float>(input),
          tflite::micro::GetTensorShape(filter),
+ #ifdef USE_TFLM_COMPRESSION
+           tflite::micro::GetTensorData<float>(
+               micro_context, filter, filter_comp_td, data.filter_scratch_index),
+           tflite::micro::GetTensorShape(bias),
+           tflite::micro::GetOptionalTensorData<float>(
+               micro_context, bias, bias_comp_td, data.bias_scratch_index),
+ #else  // USE_TFLM_COMPRESSION
          tflite::micro::GetTensorData<float>(filter),
          tflite::micro::GetTensorShape(bias),
          tflite::micro::GetOptionalTensorData<float>(bias),
+ #endif  // USE_TFLM_COMPRESSION
          tflite::micro::GetTensorShape(output),
          tflite::micro::GetTensorData<float>(output),
          tflite::micro::GetTensorShape(nullptr), nullptr);
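In the float path above (and the integer paths below), the bias tensor may be absent: the earlier Eval hunk binds bias to nullptr when the node has only three inputs, which is why these calls go through GetOptionalTensorData<T>() rather than GetTensorData<T>(). A minimal sketch of the semantics this relies on, assuming the optional accessor simply returns nullptr for a missing tensor (hypothetical helper name, for illustration only):

// Illustration only: the behavior the call sites above are assumed to rely
// on. A missing bias tensor must reach the reference kernel as a null
// pointer, never be dereferenced.
template <typename T>
const T* OptionalData(const TfLiteEvalTensor* tensor) {
  return tensor == nullptr ? nullptr : tflite::micro::GetTensorData<T>(tensor);
}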
@@ -296,9 +323,17 @@ TfLiteStatus TransposeConvEval(TfLiteContext* context, TfLiteNode* node) {
          data.per_channel_output_shift, tflite::micro::GetTensorShape(input),
          tflite::micro::GetTensorData<int8_t>(input),
          tflite::micro::GetTensorShape(filter),
+ #ifdef USE_TFLM_COMPRESSION
+           tflite::micro::GetTensorData<int8_t>(
+               micro_context, filter, filter_comp_td, data.filter_scratch_index),
+           tflite::micro::GetTensorShape(bias),
+           tflite::micro::GetOptionalTensorData<int32_t>(
+               micro_context, bias, bias_comp_td, data.bias_scratch_index),
+ #else  // USE_TFLM_COMPRESSION
          tflite::micro::GetTensorData<int8_t>(filter),
          tflite::micro::GetTensorShape(bias),
          tflite::micro::GetOptionalTensorData<int32_t>(bias),
+ #endif  // USE_TFLM_COMPRESSION
          tflite::micro::GetTensorShape(output),
          tflite::micro::GetTensorData<int8_t>(output),
          tflite::micro::GetTensorShape(nullptr), nullptr, scratch_buffer);
@@ -311,16 +346,29 @@ TfLiteStatus TransposeConvEval(TfLiteContext* context, TfLiteNode* node) {
        auto* bias_converted_buffer =
            static_cast<int64_t*>(context->GetScratchBuffer(
                context, data.bias_converted_buffer_index));
+         const int16_t* const bias_int16_data =
+ #ifdef USE_TFLM_COMPRESSION
+             tflite::micro::GetTensorData<int16_t>(
+                 micro_context, bias, bias_comp_td, data.bias_scratch_index);
+ #else  // USE_TFLM_COMPRESSION
+             static_cast<int16_t*>(bias->data.data);
+ #endif  // USE_TFLM_COMPRESSION
        for (int i = 0; i < tflite::micro::GetTensorShape(bias).FlatSize();
             i++) {
-           bias_converted_buffer[i] = bias->data.i16[i];
+           bias_converted_buffer[i] = bias_int16_data[i];
        }
        reference_integer_ops::TransposeConv(
            data.params, data.per_channel_output_multiplier,
            data.per_channel_output_shift, tflite::micro::GetTensorShape(input),
            tflite::micro::GetTensorData<int16_t>(input),
            tflite::micro::GetTensorShape(filter),
+ #ifdef USE_TFLM_COMPRESSION
+             tflite::micro::GetTensorData<int8_t>(micro_context, filter,
+                                                  filter_comp_td,
+                                                  data.filter_scratch_index),
+ #else  // USE_TFLM_COMPRESSION
            tflite::micro::GetTensorData<int8_t>(filter),
+ #endif  // USE_TFLM_COMPRESSION
            tflite::micro::GetTensorShape(bias), bias_converted_buffer,
            tflite::micro::GetTensorShape(output),
            tflite::micro::GetTensorData<int16_t>(output),
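The converted-bias loop above keeps its structure; only the source pointer changes so it can read either the decompressed bias or the tensor's own buffer. The loop exists because this 16x8 path receives a 16-bit bias while the reference kernel expects 64-bit (accumulator-width) biases, as the OpData comment in the first hunk notes. A standalone sketch of the widening, with hypothetical names, for illustration:

#include <cstdint>

// Values are preserved exactly; only the storage width changes from int16_t
// to int64_t so the 16x8 reference kernel can consume the bias directly.
void WidenBiasToInt64(const int16_t* bias_int16, int64_t* bias_int64,
                      int num_channels) {
  for (int i = 0; i < num_channels; ++i) {
    bias_int64[i] = static_cast<int64_t>(bias_int16[i]);
  }
}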
@@ -331,9 +379,18 @@ TfLiteStatus TransposeConvEval(TfLiteContext* context, TfLiteNode* node) {
            data.per_channel_output_shift, tflite::micro::GetTensorShape(input),
            tflite::micro::GetTensorData<int16_t>(input),
            tflite::micro::GetTensorShape(filter),
+ #ifdef USE_TFLM_COMPRESSION
+             tflite::micro::GetTensorData<int8_t>(micro_context, filter,
+                                                  filter_comp_td,
+                                                  data.filter_scratch_index),
+             tflite::micro::GetTensorShape(bias),
+             tflite::micro::GetOptionalTensorData<int64_t>(
+                 micro_context, bias, bias_comp_td, data.bias_scratch_index),
+ #else  // USE_TFLM_COMPRESSION
            tflite::micro::GetTensorData<int8_t>(filter),
            tflite::micro::GetTensorShape(bias),
-             tflite::micro::GetOptionalTensorData<std::int64_t>(bias),
+             tflite::micro::GetOptionalTensorData<int64_t>(bias),
+ #endif  // USE_TFLM_COMPRESSION
            tflite::micro::GetTensorShape(output),
            tflite::micro::GetTensorData<int16_t>(output),
            tflite::micro::GetTensorShape(nullptr), nullptr, scratch_buffer);