1
- /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
1
+ /* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
2
2
3
3
Licensed under the Apache License, Version 2.0 (the "License");
4
4
you may not use this file except in compliance with the License.
@@ -52,16 +52,37 @@ TfLiteStatus DepthwiseConvEval(TfLiteContext* context, TfLiteNode* node) {
52
52
? tflite::micro::GetEvalInput (context, node, kDepthwiseConvBiasTensor )
53
53
: nullptr ;
54
54
55
+ #ifdef USE_TFLM_COMPRESSION
56
+
57
+ MicroContext* micro_context = GetMicroContext (context);
58
+
59
+ const CompressionTensorData* filter_comp_td =
60
+ micro_context->GetTensorCompressionData (node,
61
+ kDepthwiseConvWeightsTensor );
62
+ const CompressionTensorData* bias_comp_td =
63
+ micro_context->GetTensorCompressionData (node, kDepthwiseConvBiasTensor );
64
+
65
+ #endif // USE_TFLM_COMPRESSION
66
+
55
67
switch (input->type ) { // Already know in/out types are same.
56
68
case kTfLiteFloat32 : {
57
69
tflite::reference_ops::DepthwiseConv (
58
70
DepthwiseConvParamsFloat (params, data),
59
71
tflite::micro::GetTensorShape (input),
60
72
tflite::micro::GetTensorData<float >(input),
61
73
tflite::micro::GetTensorShape (filter),
74
+ #ifdef USE_TFLM_COMPRESSION
75
+ tflite::micro::GetTensorData<float >(micro_context, filter,
76
+ filter_comp_td,
77
+ data.weights_scratch_index ),
78
+ tflite::micro::GetTensorShape (bias),
79
+ tflite::micro::GetOptionalTensorData<float >(
80
+ micro_context, bias, bias_comp_td, data.bias_scratch_index ),
81
+ #else // USE_TFLM_COMPRESSION
62
82
tflite::micro::GetTensorData<float >(filter),
63
83
tflite::micro::GetTensorShape (bias),
64
84
tflite::micro::GetOptionalTensorData<float >(bias),
85
+ #endif // USE_TFLM_COMPRESSION
65
86
tflite::micro::GetTensorShape (output),
66
87
tflite::micro::GetTensorData<float >(output));
67
88
break ;
@@ -94,9 +115,18 @@ TfLiteStatus DepthwiseConvEval(TfLiteContext* context, TfLiteNode* node) {
94
115
tflite::micro::GetTensorShape (input),
95
116
tflite::micro::GetTensorData<int8_t >(input),
96
117
tflite::micro::GetTensorShape (filter),
118
+ #ifdef USE_TFLM_COMPRESSION
119
+ tflite::micro::GetTensorData<int8_t >(micro_context, filter,
120
+ filter_comp_td,
121
+ data.weights_scratch_index ),
122
+ tflite::micro::GetTensorShape (bias),
123
+ tflite::micro::GetOptionalTensorData<int32_t >(
124
+ micro_context, bias, bias_comp_td, data.bias_scratch_index ),
125
+ #else // USE_TFLM_COMPRESSION
97
126
tflite::micro::GetTensorData<int8_t >(filter),
98
127
tflite::micro::GetTensorShape (bias),
99
128
tflite::micro::GetOptionalTensorData<int32_t >(bias),
129
+ #endif // USE_TFLM_COMPRESSION
100
130
tflite::micro::GetTensorShape (output),
101
131
tflite::micro::GetTensorData<int8_t >(output));
102
132
break ;
@@ -118,9 +148,18 @@ TfLiteStatus DepthwiseConvEval(TfLiteContext* context, TfLiteNode* node) {
118
148
tflite::micro::GetTensorShape (input),
119
149
tflite::micro::GetTensorData<int16_t >(input),
120
150
tflite::micro::GetTensorShape (filter),
151
+ #ifdef USE_TFLM_COMPRESSION
152
+ tflite::micro::GetTensorData<int8_t >(micro_context, filter,
153
+ filter_comp_td,
154
+ data.weights_scratch_index ),
155
+ tflite::micro::GetTensorShape (bias),
156
+ tflite::micro::GetOptionalTensorData<int64_t >(
157
+ micro_context, bias, bias_comp_td, data.bias_scratch_index ),
158
+ #else // USE_TFLM_COMPRESSION
121
159
tflite::micro::GetTensorData<int8_t >(filter),
122
160
tflite::micro::GetTensorShape (bias),
123
161
tflite::micro::GetOptionalTensorData<int64_t >(bias),
162
+ #endif // USE_TFLM_COMPRESSION
124
163
tflite::micro::GetTensorShape (output),
125
164
tflite::micro::GetTensorData<int16_t >(output));
126
165
break ;
0 commit comments