
Commit b1d8a08

feat(compression): implement tensor decompression in op depthwise conv (#3017)
Implement tensor decompression in the depthwise conv op. Extend tests to validate the operation on compressed tensors.

BUG=part of #2636
1 parent 099774d · commit b1d8a08

7 files changed: +581 -28 lines
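The change follows a two-phase pattern that can be read directly off the diffs below: Prepare reserves a per-tensor scratch buffer, and Eval routes tensor access through compression-aware overloads that decompress into that buffer before the reference kernel runs. A condensed sketch of the flow, using only calls that appear in this commit (surrounding kernel boilerplate is elided, so this is illustrative rather than compilable on its own):

    // Prepare (depthwise_conv_common.cc): reserve scratch space for each
    // tensor that may be compressed.
    data->weights_scratch_index =
        micro_context->AllocateDecompressionScratchBuffer(
            node, kDepthwiseConvWeightsTensor);
    data->bias_scratch_index = micro_context->AllocateDecompressionScratchBuffer(
        node, kDepthwiseConvBiasTensor);

    // Eval (depthwise_conv.cc): look up the compression metadata, then fetch
    // the filter through the compression-aware overload, which decompresses
    // into the scratch buffer reserved above.
    MicroContext* micro_context = GetMicroContext(context);
    const CompressionTensorData* filter_comp_td =
        micro_context->GetTensorCompressionData(node,
                                                kDepthwiseConvWeightsTensor);
    const int8_t* filter_data = tflite::micro::GetTensorData<int8_t>(
        micro_context, filter, filter_comp_td, data.weights_scratch_index);

All of this is compiled only when USE_TFLM_COMPRESSION is defined; without the macro the original code path is unchanged.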

tensorflow/lite/micro/kernels/depthwise_conv.cc (+40 -1)
@@ -1,4 +1,4 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -52,16 +52,37 @@ TfLiteStatus DepthwiseConvEval(TfLiteContext* context, TfLiteNode* node) {
           ? tflite::micro::GetEvalInput(context, node, kDepthwiseConvBiasTensor)
           : nullptr;
 
+#ifdef USE_TFLM_COMPRESSION
+
+  MicroContext* micro_context = GetMicroContext(context);
+
+  const CompressionTensorData* filter_comp_td =
+      micro_context->GetTensorCompressionData(node,
+                                              kDepthwiseConvWeightsTensor);
+  const CompressionTensorData* bias_comp_td =
+      micro_context->GetTensorCompressionData(node, kDepthwiseConvBiasTensor);
+
+#endif  // USE_TFLM_COMPRESSION
+
   switch (input->type) {  // Already know in/out types are same.
     case kTfLiteFloat32: {
       tflite::reference_ops::DepthwiseConv(
           DepthwiseConvParamsFloat(params, data),
           tflite::micro::GetTensorShape(input),
           tflite::micro::GetTensorData<float>(input),
           tflite::micro::GetTensorShape(filter),
+#ifdef USE_TFLM_COMPRESSION
+          tflite::micro::GetTensorData<float>(micro_context, filter,
+                                              filter_comp_td,
+                                              data.weights_scratch_index),
+          tflite::micro::GetTensorShape(bias),
+          tflite::micro::GetOptionalTensorData<float>(
+              micro_context, bias, bias_comp_td, data.bias_scratch_index),
+#else   // USE_TFLM_COMPRESSION
           tflite::micro::GetTensorData<float>(filter),
           tflite::micro::GetTensorShape(bias),
           tflite::micro::GetOptionalTensorData<float>(bias),
+#endif  // USE_TFLM_COMPRESSION
           tflite::micro::GetTensorShape(output),
           tflite::micro::GetTensorData<float>(output));
       break;
@@ -94,9 +115,18 @@ TfLiteStatus DepthwiseConvEval(TfLiteContext* context, TfLiteNode* node) {
           tflite::micro::GetTensorShape(input),
           tflite::micro::GetTensorData<int8_t>(input),
           tflite::micro::GetTensorShape(filter),
+#ifdef USE_TFLM_COMPRESSION
+          tflite::micro::GetTensorData<int8_t>(micro_context, filter,
+                                               filter_comp_td,
+                                               data.weights_scratch_index),
+          tflite::micro::GetTensorShape(bias),
+          tflite::micro::GetOptionalTensorData<int32_t>(
+              micro_context, bias, bias_comp_td, data.bias_scratch_index),
+#else   // USE_TFLM_COMPRESSION
           tflite::micro::GetTensorData<int8_t>(filter),
           tflite::micro::GetTensorShape(bias),
           tflite::micro::GetOptionalTensorData<int32_t>(bias),
+#endif  // USE_TFLM_COMPRESSION
           tflite::micro::GetTensorShape(output),
           tflite::micro::GetTensorData<int8_t>(output));
       break;
@@ -118,9 +148,18 @@ TfLiteStatus DepthwiseConvEval(TfLiteContext* context, TfLiteNode* node) {
           tflite::micro::GetTensorShape(input),
           tflite::micro::GetTensorData<int16_t>(input),
           tflite::micro::GetTensorShape(filter),
+#ifdef USE_TFLM_COMPRESSION
+          tflite::micro::GetTensorData<int8_t>(micro_context, filter,
+                                               filter_comp_td,
+                                               data.weights_scratch_index),
+          tflite::micro::GetTensorShape(bias),
+          tflite::micro::GetOptionalTensorData<int64_t>(
+              micro_context, bias, bias_comp_td, data.bias_scratch_index),
+#else   // USE_TFLM_COMPRESSION
           tflite::micro::GetTensorData<int8_t>(filter),
           tflite::micro::GetTensorShape(bias),
           tflite::micro::GetOptionalTensorData<int64_t>(bias),
+#endif  // USE_TFLM_COMPRESSION
           tflite::micro::GetTensorShape(output),
           tflite::micro::GetTensorData<int16_t>(output));
       break;
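A detail worth noting across the three cases above: the filter decompresses as float in the float32 path but as int8_t in both the int8 and int16 activation paths (quantized depthwise conv keeps int8 filters even with int16 activations), while the bias type widens from float to int32_t to int64_t. The bias goes through the GetOptionalTensorData variant because the bias input may be absent; the diff does not show the helper's internals, but the unconditional call sites only make sense if a null bias or a null bias_comp_td falls back gracefully to the plain (or absent) tensor data, letting one call site serve compressed and uncompressed models alike.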

tensorflow/lite/micro/kernels/depthwise_conv_common.cc (+21 -2)
@@ -1,4 +1,4 @@
-/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -127,7 +127,9 @@ TfLiteStatus CalculateOpDataDepthwiseConv(
 
   micro_context->DeallocateTempTfLiteTensor(input);
   micro_context->DeallocateTempTfLiteTensor(filter);
-  micro_context->DeallocateTempTfLiteTensor(bias);
+  if (has_bias) {
+    micro_context->DeallocateTempTfLiteTensor(bias);
+  }
   micro_context->DeallocateTempTfLiteTensor(output);
 
   return kTfLiteOk;
@@ -209,6 +211,23 @@ TfLiteStatus DepthwiseConvPrepare(TfLiteContext* context, TfLiteNode* node) {
       context, node, params, input_width, input_height, filter_width,
       filter_height, output_width, output_height, input->type, data));
 
+#ifdef USE_TFLM_COMPRESSION
+
+  // Compression scratch buffers.
+  // These will only be allocated if the tensor is compressed.
+  if (micro_context->IsTensorCompressed(node, kDepthwiseConvWeightsTensor) &&
+      filter->type == kTfLiteInt4) {
+    MicroPrintf("Compression not supported with INT4 tensors");
+    return kTfLiteError;
+  }
+  data->weights_scratch_index =
+      micro_context->AllocateDecompressionScratchBuffer(
+          node, kDepthwiseConvWeightsTensor);
+  data->bias_scratch_index = micro_context->AllocateDecompressionScratchBuffer(
+      node, kDepthwiseConvBiasTensor);
+
+#endif  // USE_TFLM_COMPRESSION
+
   micro_context->DeallocateTempTfLiteTensor(output);
   micro_context->DeallocateTempTfLiteTensor(input);
   micro_context->DeallocateTempTfLiteTensor(filter);
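Two Prepare-time consequences of this hunk: compressed INT4 filters are rejected up front with a clear error instead of failing later in Eval, and scratch allocation is requested unconditionally for both weights and bias. The in-diff comment ("These will only be allocated if the tensor is compressed") implies AllocateDecompressionScratchBuffer is a no-op, presumably returning a sentinel index, when a tensor carries no compression metadata; that behavior lives outside this diff and is an inference, not something shown here. The remaining five changed files, not rendered on this page, presumably carry the test extensions mentioned in the commit message.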
