Skip to content

Commit 112d7a6

Browse files
authored
feat(compression): extend interpreter to handle compressed tensors (#3002)
Add methods and data structures to handle compressed tensors in the core allocator, context, and interpreter. Extend the interpreter unit test. Add test helpers to build models with compressed tensors and compression metadata. BUG=part of #2636
1 parent d59136a commit 112d7a6

20 files changed

+1668
-149
lines changed

tensorflow/lite/micro/BUILD

+9
Original file line numberDiff line numberDiff line change
@@ -72,10 +72,14 @@ tflm_cc_library(
7272
"micro_context.h",
7373
],
7474
deps = [
75+
":compression",
7576
":micro_common",
7677
":micro_graph",
7778
":micro_log",
79+
":micro_profiler",
80+
"//tensorflow/lite:type_to_tflitetype",
7881
"//tensorflow/lite/c:common",
82+
"//tensorflow/lite/micro/kernels:decompress",
7983
],
8084
)
8185

@@ -145,6 +149,7 @@ tflm_cc_library(
145149
":memory_helpers",
146150
":micro_allocator",
147151
":micro_common",
152+
":micro_context",
148153
":micro_graph",
149154
":micro_log",
150155
":micro_profiler",
@@ -180,6 +185,7 @@ tflm_cc_library(
180185
"micro_allocator.h",
181186
],
182187
deps = [
188+
":compression",
183189
":flatbuffer_utils",
184190
":memory_helpers",
185191
":micro_arena_constants",
@@ -192,6 +198,7 @@ tflm_cc_library(
192198
"//tensorflow/lite/micro/arena_allocator:non_persistent_arena_buffer_allocator",
193199
"//tensorflow/lite/micro/arena_allocator:persistent_arena_buffer_allocator",
194200
"//tensorflow/lite/micro/arena_allocator:simple_memory_allocator",
201+
"//tensorflow/lite/micro/compression:metadata_saved",
195202
"//tensorflow/lite/micro/memory_planner:greedy_memory_planner",
196203
"//tensorflow/lite/micro/memory_planner:linear_memory_planner",
197204
"//tensorflow/lite/micro/memory_planner:micro_memory_planner",
@@ -245,7 +252,9 @@ tflm_cc_library(
245252
"test_helpers.h",
246253
],
247254
deps = [
255+
":compression",
248256
":memory_helpers",
257+
":micro_log",
249258
":micro_utils",
250259
":op_resolvers",
251260
"//tensorflow/lite:type_to_tflitetype",

tensorflow/lite/micro/fake_micro_context.cc

+74-5
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
1+
/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
22
33
Licensed under the Apache License, Version 2.0 (the "License");
44
you may not use this file except in compliance with the License.
@@ -23,10 +23,23 @@ limitations under the License.
2323

2424
namespace tflite {
2525

26-
FakeMicroContext::FakeMicroContext(TfLiteTensor* tensors,
27-
SingleArenaBufferAllocator* allocator,
28-
MicroGraph* micro_graph)
29-
: graph_(*micro_graph), tensors_(tensors), allocator_(allocator) {}
26+
FakeMicroContext::FakeMicroContext(
27+
TfLiteTensor* tensors, SingleArenaBufferAllocator* allocator,
28+
MicroGraph* micro_graph
29+
#ifdef USE_TFLM_COMPRESSION
30+
,
31+
const CompressedTensorList* compressed_tensors
32+
#endif // USE_TFLM_COMPRESSION
33+
)
34+
: graph_(*micro_graph),
35+
tensors_(tensors),
36+
allocator_(allocator)
37+
#ifdef USE_TFLM_COMPRESSION
38+
,
39+
compressed_tensors_(compressed_tensors)
40+
#endif // USE_TFLM_COMPRESSION
41+
{
42+
}
3043

3144
TfLiteTensor* FakeMicroContext::AllocateTempTfLiteTensor(int tensor_index) {
3245
allocated_temp_count_++;
@@ -112,4 +125,60 @@ void* FakeMicroContext::external_context() { return nullptr; }
112125

113126
MicroGraph& FakeMicroContext::graph() { return graph_; }
114127

128+
#ifdef USE_TFLM_COMPRESSION
129+
130+
// Available during Prepare & Eval. Returns false if tensor is not
131+
// compressed.
132+
bool FakeMicroContext::IsTensorCompressed(const TfLiteNode* node,
133+
int tensor_idx) {
134+
if (compressed_tensors_ != nullptr && tensor_idx < node->inputs->size) {
135+
int index = node->inputs->data[tensor_idx];
136+
if (index >= 0 && compressed_tensors_->tensors[index] != nullptr) {
137+
return true;
138+
}
139+
}
140+
141+
return false;
142+
}
143+
144+
// Only available during Prepare. The kernel is responsible for storing the
145+
// scratch buffer handle.
146+
int FakeMicroContext::AllocateDecompressionScratchBuffer(const TfLiteNode* node,
147+
int tensor_idx) {
148+
if (compressed_tensors_ == nullptr || tensor_idx >= node->inputs->size) {
149+
return -1;
150+
}
151+
int index = node->inputs->data[tensor_idx];
152+
if (index < 0 || compressed_tensors_->tensors[index] == nullptr) {
153+
return -1;
154+
}
155+
TfLiteTensor* tensor = &tensors_[index];
156+
int scratch_index = -1;
157+
TfLiteStatus result =
158+
RequestScratchBufferInArena(tensor->bytes, &scratch_index);
159+
if (result != kTfLiteOk) {
160+
return -1;
161+
}
162+
163+
return scratch_index;
164+
}
165+
166+
// Available during Prepare & Eval. Returns nullptr if tensor is not
167+
// compressed.
168+
const CompressionTensorData* FakeMicroContext::GetTensorCompressionData(
169+
const TfLiteNode* node, int tensor_idx) {
170+
if (compressed_tensors_ == nullptr || tensor_idx >= node->inputs->size) {
171+
return nullptr;
172+
}
173+
174+
int index = node->inputs->data[tensor_idx];
175+
if (index < 0) {
176+
return nullptr;
177+
}
178+
179+
return compressed_tensors_->tensors[index];
180+
}
181+
182+
#endif // USE_TFLM_COMPRESSION
183+
115184
} // namespace tflite

tensorflow/lite/micro/fake_micro_context.h

+34-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
1+
/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
22
33
Licensed under the Apache License, Version 2.0 (the "License");
44
you may not use this file except in compliance with the License.
@@ -30,7 +30,12 @@ class FakeMicroContext : public MicroContext {
3030
~FakeMicroContext() = default;
3131

3232
FakeMicroContext(TfLiteTensor* tensors, SingleArenaBufferAllocator* allocator,
33-
MicroGraph* micro_graph);
33+
MicroGraph* micro_graph
34+
#ifdef USE_TFLM_COMPRESSION
35+
,
36+
const CompressedTensorList* compressed_tensors = nullptr
37+
#endif // USE_TFLM_COMPRESSION
38+
);
3439

3540
void* AllocatePersistentBuffer(size_t bytes) override;
3641
TfLiteStatus RequestScratchBufferInArena(size_t bytes,
@@ -50,6 +55,24 @@ class FakeMicroContext : public MicroContext {
5055
void* external_context() override;
5156
MicroGraph& graph() override;
5257

58+
#ifdef USE_TFLM_COMPRESSION
59+
60+
// Available during Prepare & Eval. Returns false if tensor is not
61+
// compressed.
62+
bool IsTensorCompressed(const TfLiteNode* node, int tensor_idx) override;
63+
64+
// Only available during Prepare. The kernel is responsible for storing the
65+
// scratch buffer handle.
66+
int AllocateDecompressionScratchBuffer(const TfLiteNode* node,
67+
int tensor_idx) override;
68+
69+
// Available during Prepare & Eval. Returns nullptr if tensor is not
70+
// compressed.
71+
const CompressionTensorData* GetTensorCompressionData(
72+
const TfLiteNode* node, int tensor_idx) override;
73+
74+
#endif // USE_TFLM_COMPRESSION
75+
5376
private:
5477
static constexpr int kNumScratchBuffers_ = 12;
5578

@@ -62,6 +85,15 @@ class FakeMicroContext : public MicroContext {
6285

6386
SingleArenaBufferAllocator* allocator_;
6487

88+
#ifdef USE_TFLM_COMPRESSION
89+
90+
//
91+
// Compression
92+
//
93+
const CompressedTensorList* compressed_tensors_;
94+
95+
#endif // USE_TFLM_COMPRESSION
96+
6597
TF_LITE_REMOVE_VIRTUAL_DELETE
6698
};
6799

tensorflow/lite/micro/kernels/kernel_runner.cc

+13-4
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
1+
/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
22
33
Licensed under the Apache License, Version 2.0 (the "License");
44
you may not use this file except in compliance with the License.
@@ -18,7 +18,6 @@ limitations under the License.
1818
#include "tensorflow/lite/micro/arena_allocator/single_arena_buffer_allocator.h"
1919
#include "tensorflow/lite/micro/micro_arena_constants.h"
2020
#include "tensorflow/lite/micro/micro_log.h"
21-
#include "tensorflow/lite/micro/test_helpers.h"
2221

2322
namespace tflite {
2423
namespace micro {
@@ -38,12 +37,22 @@ KernelRunner::KernelRunner(const TFLMRegistration& registration,
3837
TfLiteTensor* tensors, int tensors_size,
3938
TfLiteIntArray* inputs, TfLiteIntArray* outputs,
4039
const void* builtin_data,
41-
TfLiteIntArray* intermediates)
40+
TfLiteIntArray* intermediates
41+
#ifdef USE_TFLM_COMPRESSION
42+
,
43+
const CompressedTensorList* compressed_tensors
44+
#endif // USE_TFLM_COMPRESSION
45+
)
4246
: registration_(registration),
4347
allocator_(SingleArenaBufferAllocator::Create(kKernelRunnerBuffer_,
4448
kKernelRunnerBufferSize_)),
4549
mock_micro_graph_(allocator_),
46-
fake_micro_context_(tensors, allocator_, &mock_micro_graph_) {
50+
fake_micro_context_(tensors, allocator_, &mock_micro_graph_
51+
#ifdef USE_TFLM_COMPRESSION
52+
,
53+
compressed_tensors
54+
#endif // USE_TFLM_COMPRESSION
55+
) {
4756
// Prepare TfLiteContext:
4857
context_.impl_ = static_cast<void*>(&fake_micro_context_);
4958
context_.ReportError = MicroContextReportOpError;

tensorflow/lite/micro/kernels/kernel_runner.h

+7-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
1+
/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
22
33
Licensed under the Apache License, Version 2.0 (the "License");
44
you may not use this file except in compliance with the License.
@@ -36,7 +36,12 @@ class KernelRunner {
3636
KernelRunner(const TFLMRegistration& registration, TfLiteTensor* tensors,
3737
int tensors_size, TfLiteIntArray* inputs,
3838
TfLiteIntArray* outputs, const void* builtin_data,
39-
TfLiteIntArray* intermediates = nullptr);
39+
TfLiteIntArray* intermediates = nullptr
40+
#ifdef USE_TFLM_COMPRESSION
41+
,
42+
const CompressedTensorList* compressed_tensors = nullptr
43+
#endif // USE_TFLM_COMPRESSION
44+
);
4045

4146
// Calls init and prepare on the kernel (i.e. TFLMRegistration) struct.
4247
// Any exceptions will be DebugLog'd and returned as a status code.

tensorflow/lite/micro/kernels/kernel_util.h

+40-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
1+
/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
22
33
Licensed under the Apache License, Version 2.0 (the "License");
44
you may not use this file except in compliance with the License.
@@ -25,6 +25,13 @@ limitations under the License.
2525
#include "tensorflow/lite/kernels/internal/types.h"
2626
#include "tensorflow/lite/micro/micro_context.h"
2727

28+
#ifdef USE_TFLM_COMPRESSION
29+
30+
#include "tensorflow/lite/micro/micro_arena_constants.h"
31+
#include "tensorflow/lite/micro/micro_utils.h"
32+
33+
#endif // USE_TFLM_COMPRESSION
34+
2835
namespace tflite {
2936
namespace micro {
3037

@@ -91,6 +98,38 @@ const T* GetOptionalTensorData(const TfLiteEvalTensor* tensor) {
9198
: reinterpret_cast<const T*>(tensor->data.raw);
9299
}
93100

101+
#ifdef USE_TFLM_COMPRESSION
102+
103+
// Overloads existing GetTensorData. If not compressed, this will return
104+
// tensor->data.
105+
template <typename T>
106+
const T* GetTensorData(MicroContext* micro_context,
107+
const TfLiteEvalTensor* tensor,
108+
const CompressionTensorData* compression_data,
109+
int scratch_buffer_handle) {
110+
if (tensor == nullptr) {
111+
return nullptr;
112+
}
113+
if (compression_data == nullptr) {
114+
return reinterpret_cast<const T*>(tensor->data.data);
115+
}
116+
117+
void* scratch_buffer = nullptr;
118+
if (scratch_buffer_handle != -1) {
119+
scratch_buffer = micro_context->GetScratchBuffer(scratch_buffer_handle);
120+
} else {
121+
size_t bytes_to_allocate = EvalTensorBytes(tensor);
122+
scratch_buffer = micro_context->AllocateDecompressionMemory(
123+
bytes_to_allocate, MicroArenaBufferAlignment());
124+
}
125+
TFLITE_DCHECK(scratch_buffer != nullptr);
126+
void* uncompressed_data = micro_context->DecompressTensorToBuffer(
127+
*tensor, *compression_data, scratch_buffer);
128+
return reinterpret_cast<const T*>(uncompressed_data);
129+
}
130+
131+
#endif // USE_TFLM_COMPRESSION
132+
94133
// Returns the shape of a TfLiteEvalTensor struct.
95134
const RuntimeShape GetTensorShape(const TfLiteEvalTensor* tensor);
96135

0 commit comments

Comments
 (0)