1
- /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
1
+ /* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
2
2
3
3
Licensed under the Apache License, Version 2.0 (the "License");
4
4
you may not use this file except in compliance with the License.
@@ -23,10 +23,23 @@ limitations under the License.
23
23
24
24
namespace tflite {
25
25
26
- FakeMicroContext::FakeMicroContext (TfLiteTensor* tensors,
27
- SingleArenaBufferAllocator* allocator,
28
- MicroGraph* micro_graph)
29
- : graph_(*micro_graph), tensors_(tensors), allocator_(allocator) {}
26
+ FakeMicroContext::FakeMicroContext (
27
+ TfLiteTensor* tensors, SingleArenaBufferAllocator* allocator,
28
+ MicroGraph* micro_graph
29
+ #ifdef USE_TFLM_COMPRESSION
30
+ ,
31
+ const CompressedTensorList* compressed_tensors
32
+ #endif // USE_TFLM_COMPRESSION
33
+ )
34
+ : graph_(*micro_graph),
35
+ tensors_ (tensors),
36
+ allocator_(allocator)
37
+ #ifdef USE_TFLM_COMPRESSION
38
+ ,
39
+ compressed_tensors_ (compressed_tensors)
40
+ #endif // USE_TFLM_COMPRESSION
41
+ {
42
+ }
30
43
31
44
TfLiteTensor* FakeMicroContext::AllocateTempTfLiteTensor (int tensor_index) {
32
45
allocated_temp_count_++;
@@ -112,4 +125,60 @@ void* FakeMicroContext::external_context() { return nullptr; }
112
125
113
126
// Returns the graph this context was constructed with.
MicroGraph& FakeMicroContext::graph() { return graph_; }
114
127
128
+ #ifdef USE_TFLM_COMPRESSION
129
+
130
+ // Available during Prepare & Eval. Returns false if tensor is not
131
+ // compressed.
132
+ bool FakeMicroContext::IsTensorCompressed (const TfLiteNode* node,
133
+ int tensor_idx) {
134
+ if (compressed_tensors_ != nullptr && tensor_idx < node->inputs ->size ) {
135
+ int index = node->inputs ->data [tensor_idx];
136
+ if (index >= 0 && compressed_tensors_->tensors [index ] != nullptr ) {
137
+ return true ;
138
+ }
139
+ }
140
+
141
+ return false ;
142
+ }
143
+
144
+ // Only available during Prepare. The kernel is responsible for storing the
145
+ // scratch buffer handle.
146
+ int FakeMicroContext::AllocateDecompressionScratchBuffer (const TfLiteNode* node,
147
+ int tensor_idx) {
148
+ if (compressed_tensors_ == nullptr || tensor_idx >= node->inputs ->size ) {
149
+ return -1 ;
150
+ }
151
+ int index = node->inputs ->data [tensor_idx];
152
+ if (index < 0 || compressed_tensors_->tensors [index ] == nullptr ) {
153
+ return -1 ;
154
+ }
155
+ TfLiteTensor* tensor = &tensors_[index ];
156
+ int scratch_index = -1 ;
157
+ TfLiteStatus result =
158
+ RequestScratchBufferInArena (tensor->bytes , &scratch_index);
159
+ if (result != kTfLiteOk ) {
160
+ return -1 ;
161
+ }
162
+
163
+ return scratch_index;
164
+ }
165
+
166
+ // Available during Prepare & Eval. Returns nullptr if tensor is not
167
+ // compressed.
168
+ const CompressionTensorData* FakeMicroContext::GetTensorCompressionData (
169
+ const TfLiteNode* node, int tensor_idx) {
170
+ if (compressed_tensors_ == nullptr || tensor_idx >= node->inputs ->size ) {
171
+ return nullptr ;
172
+ }
173
+
174
+ int index = node->inputs ->data [tensor_idx];
175
+ if (index < 0 ) {
176
+ return nullptr ;
177
+ }
178
+
179
+ return compressed_tensors_->tensors [index ];
180
+ }
181
+
182
+ #endif // USE_TFLM_COMPRESSION
183
+
115
184
} // namespace tflite
0 commit comments