Skip to content

Commit d59136a

Browse files
authored
feat(compression): add decompression library (#2996)
Add a decompression library, defining structures for compressed tensors and decompression logic to be used by kernels. Add a unit test to validate decompression logic. BUG=part of #2636
1 parent 4a8bb6b commit d59136a

File tree

9 files changed

+1230
-1
lines changed

9 files changed

+1230
-1
lines changed

tensorflow/lite/micro/BUILD

+10
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,16 @@ tflm_cc_library(
2727
],
2828
)
2929

30+
tflm_cc_library(
31+
name = "compression",
32+
hdrs = [
33+
"compression.h",
34+
],
35+
deps = [
36+
"//tensorflow/lite/c:common",
37+
],
38+
)
39+
3040
tflm_cc_library(
3141
# TODO(b/187093492): Rename to micro_interpreter.
3242
name = "micro_framework",

tensorflow/lite/micro/compression.h

+68
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#ifndef TENSORFLOW_LITE_MICRO_MICRO_COMPRESSION_H_
17+
#define TENSORFLOW_LITE_MICRO_MICRO_COMPRESSION_H_
18+
19+
#ifdef USE_TFLM_COMPRESSION
20+
21+
#include "tensorflow/lite/c/common.h"
22+
23+
namespace tflite {
24+
25+
//
26+
// Compressed tensors
27+
//
28+
29+
static constexpr const char* kCompressionMetadataString =
30+
"COMPRESSION_METADATA";
31+
32+
enum class CompressionScheme : uint8_t {
33+
kBinQuant,
34+
};
35+
36+
struct LookupTableData {
37+
static constexpr size_t kMaxBitWidth = 7;
38+
static constexpr size_t kMaxValueTableChannelStride = 128;
39+
40+
const void* value_table; // Pointer into FlatBuffer Values.
41+
uint8_t value_table_channel_stride; // elements per channel
42+
uint8_t compressed_bit_width : 3; // 1 to 7 bits
43+
bool is_per_channel_quantized : 1; // tensor is per-channel quantized
44+
bool use_alternate_axis : 1; // shape default channel:
45+
// 0 = first, 1 = last
46+
uint8_t reserved : 3;
47+
};
48+
49+
union CompressionData {
50+
LookupTableData* lut_data;
51+
};
52+
53+
struct CompressionTensorData {
54+
CompressionScheme scheme;
55+
CompressionData data;
56+
};
57+
58+
struct CompressedTensorList {
59+
// Sparsely populated array with the same number of elements as there are
60+
// tensors in the Subgraph. An alternative would include a tensor index in
61+
// the struct for each and walk the list on look up. This could be slow.
62+
const CompressionTensorData** tensors;
63+
};
64+
65+
} // namespace tflite
66+
67+
#endif // USE_TFLM_COMPRESSION
68+
#endif // TENSORFLOW_LITE_MICRO_MICRO_COMPRESSION_H_

tensorflow/lite/micro/kernels/BUILD

+42
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,29 @@ tflm_cc_library(
7979
],
8080
)
8181

82+
tflm_cc_library(
83+
name = "decompress",
84+
srcs = [
85+
"decompress.cc",
86+
"decompress_common.cc",
87+
],
88+
hdrs = [
89+
"decompress.h",
90+
],
91+
visibility = [
92+
":kernel_friends",
93+
":tflite_micro",
94+
],
95+
deps = [
96+
"//tensorflow/lite:type_to_tflitetype",
97+
"//tensorflow/lite/kernels/internal:compatibility",
98+
"//tensorflow/lite/micro:compression",
99+
"//tensorflow/lite/micro:micro_common",
100+
"//tensorflow/lite/micro:micro_log",
101+
"//tensorflow/lite/micro:micro_profiler",
102+
],
103+
)
104+
82105
tflm_cc_library(
83106
name = "detection_postprocess_flexbuffers_generated_data",
84107
srcs = [
@@ -613,6 +636,25 @@ tflm_cc_test(
613636
],
614637
)
615638

639+
tflm_cc_test(
640+
name = "decompress_test",
641+
srcs = [
642+
"decompress_test.cc",
643+
],
644+
target_compatible_with = select({
645+
"//conditions:default": ["@platforms//:incompatible"],
646+
"//:with_compression_enabled": [],
647+
}),
648+
deps = [
649+
":decompress",
650+
"//tensorflow/lite/c:common",
651+
"//tensorflow/lite/micro:micro_arena_constants",
652+
"//tensorflow/lite/micro:micro_log",
653+
"//tensorflow/lite/micro:test_helpers",
654+
"//tensorflow/lite/micro/testing:micro_test",
655+
],
656+
)
657+
616658
tflm_cc_test(
617659
name = "depth_to_space_test",
618660
srcs = [

tensorflow/lite/micro/kernels/Makefile.inc

+5-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
1+
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.
@@ -180,6 +180,10 @@ $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/unpack_test.cc \
180180
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/while_test.cc \
181181
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/zeros_like_test.cc
182182

183+
ifeq ($(ENABLE_COMPRESSION), yes)
184+
MICROLITE_KERNEL_SIMPLE_TEST_SRCS += $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/decompress_test.cc
185+
endif
186+
183187
# Generate simple kernel test targets in a common way
184188
$(foreach TEST_TARGET,$(MICROLITE_KERNEL_SIMPLE_TEST_SRCS),\
185189
$(eval $(call microlite_test,kernel_$(notdir $(basename $(TEST_TARGET))),$(TEST_TARGET))))
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#ifdef USE_TFLM_COMPRESSION
17+
18+
#include "tensorflow/lite/micro/kernels/decompress.h"
19+
20+
#include <cstddef>
21+
#include <type_traits>
22+
23+
#include "tensorflow/lite/kernels/internal/compatibility.h"
24+
#include "tensorflow/lite/micro/micro_common.h"
25+
26+
namespace tflite {
27+
28+
template <typename T>
29+
T* DecompressionState::DecompressToBuffer(void* buffer) {
30+
TFLITE_DCHECK(compressed_bit_width_ <= LookupTableData::kMaxBitWidth);
31+
TFLITE_DCHECK(compressed_bit_width_ > 0);
32+
33+
if (std::is_same<T, int8_t>::value &&
34+
comp_data_.data.lut_data->compressed_bit_width == 4 &&
35+
!comp_data_.data.lut_data->use_alternate_axis) {
36+
DecompressToBufferWidth4_16(static_cast<int8_t*>(buffer));
37+
} else if (std::is_same<T, int8_t>::value &&
38+
comp_data_.data.lut_data->compressed_bit_width == 3 &&
39+
!comp_data_.data.lut_data->use_alternate_axis) {
40+
DecompressToBufferWidth3_32(static_cast<int8_t*>(buffer));
41+
} else if (std::is_same<T, int8_t>::value &&
42+
comp_data_.data.lut_data->compressed_bit_width == 2 &&
43+
!comp_data_.data.lut_data->use_alternate_axis) {
44+
DecompressToBufferWidth2_16(static_cast<int8_t*>(buffer));
45+
} else {
46+
DecompressToBufferWidthAny<T>(static_cast<T*>(buffer));
47+
}
48+
49+
return static_cast<T*>(buffer);
50+
}
51+
52+
template bool* DecompressionState::DecompressToBuffer<bool>(void*);
53+
template float* DecompressionState::DecompressToBuffer<float>(void*);
54+
template int8_t* DecompressionState::DecompressToBuffer<int8_t>(void*);
55+
template int16_t* DecompressionState::DecompressToBuffer<int16_t>(void*);
56+
template int32_t* DecompressionState::DecompressToBuffer<int32_t>(void*);
57+
template int64_t* DecompressionState::DecompressToBuffer<int64_t>(void*);
58+
59+
} // namespace tflite
60+
61+
#endif // USE_TFLM_COMPRESSION
+89
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#ifndef TENSORFLOW_LITE_MICRO_MICRO_KERNELS_DECOMPRESS_H_
17+
#define TENSORFLOW_LITE_MICRO_MICRO_KERNELS_DECOMPRESS_H_
18+
19+
#include <cstdint>
20+
21+
#include "tensorflow/lite/micro/compression.h"
22+
#include "tensorflow/lite/micro/micro_profiler.h"
23+
24+
namespace tflite {
25+
26+
#ifdef USE_TFLM_COMPRESSION
27+
28+
struct DecompressionState {
29+
DecompressionState() = delete;
30+
31+
DecompressionState(const uint8_t* compressed_indices,
32+
const size_t count_indices,
33+
const CompressionTensorData& comp_data,
34+
const size_t num_channels,
35+
MicroProfilerInterface* profiler = nullptr)
36+
: compressed_indices_(compressed_indices),
37+
count_indices_(count_indices),
38+
comp_data_(comp_data),
39+
num_channels_(num_channels),
40+
micro_profiler_(profiler) {}
41+
42+
DecompressionState(const DecompressionState& other)
43+
: compressed_indices_(other.compressed_indices_),
44+
count_indices_(other.count_indices_),
45+
comp_data_(other.comp_data_),
46+
num_channels_(other.num_channels_),
47+
micro_profiler_(other.micro_profiler_) {}
48+
49+
template <typename T>
50+
T* DecompressToBuffer(void* buffer);
51+
52+
protected:
53+
// optimized C++ for INT8, use_alt_axis == false
54+
void DecompressToBufferWidth4_16(int8_t* buffer);
55+
void DecompressToBufferWidth3_32(int8_t* buffer);
56+
void DecompressToBufferWidth2_16(int8_t* buffer);
57+
58+
// generic C++ for any bit width and value table type
59+
template <typename T>
60+
void DecompressToBufferWidthAny(T* buffer);
61+
62+
// Optimized C++ table index fetch
63+
inline size_t GetNextTableIndexWidth7(const size_t current_offset);
64+
inline size_t GetNextTableIndexWidth6(const size_t current_offset);
65+
inline size_t GetNextTableIndexWidth5(const size_t current_offset);
66+
inline size_t GetNextTableIndexWidth4(const size_t current_offset);
67+
inline size_t GetNextTableIndexWidth3(const size_t current_offset);
68+
inline size_t GetNextTableIndexWidth2(const size_t current_offset);
69+
inline size_t GetNextTableIndexWidth1(const size_t current_offset);
70+
71+
protected:
72+
const uint8_t* compressed_indices_;
73+
const size_t count_indices_;
74+
const CompressionTensorData& comp_data_;
75+
const size_t num_channels_;
76+
const size_t compressed_bit_width_ =
77+
comp_data_.data.lut_data->compressed_bit_width;
78+
const size_t elements_per_channel_ =
79+
comp_data_.data.lut_data->use_alternate_axis
80+
? 1
81+
: count_indices_ / num_channels_;
82+
MicroProfilerInterface* micro_profiler_;
83+
};
84+
85+
#endif // USE_TFLM_COMPRESSION
86+
87+
} // namespace tflite
88+
89+
#endif // TENSORFLOW_LITE_MICRO_MICRO_KERNELS_DECOMPRESS_H_

0 commit comments

Comments
 (0)