|
| 1 | +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. |
| 2 | +
|
| 3 | +Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +you may not use this file except in compliance with the License. |
| 5 | +You may obtain a copy of the License at |
| 6 | +
|
| 7 | + http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +
|
| 9 | +Unless required by applicable law or agreed to in writing, software |
| 10 | +distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +See the License for the specific language governing permissions and |
| 13 | +limitations under the License. |
| 14 | +==============================================================================*/ |
| 15 | + |
| 16 | +#ifndef TENSORFLOW_LITE_MICRO_MICRO_KERNELS_DECOMPRESS_H_ |
| 17 | +#define TENSORFLOW_LITE_MICRO_MICRO_KERNELS_DECOMPRESS_H_ |
| 18 | + |
| 19 | +#include <cstdint> |
| 20 | + |
| 21 | +#include "tensorflow/lite/micro/compression.h" |
| 22 | +#include "tensorflow/lite/micro/micro_profiler.h" |
| 23 | + |
| 24 | +namespace tflite { |
| 25 | + |
| 26 | +#ifdef USE_TFLM_COMPRESSION |
| 27 | + |
| 28 | +struct DecompressionState { |
| 29 | + DecompressionState() = delete; |
| 30 | + |
| 31 | + DecompressionState(const uint8_t* compressed_indices, |
| 32 | + const size_t count_indices, |
| 33 | + const CompressionTensorData& comp_data, |
| 34 | + const size_t num_channels, |
| 35 | + MicroProfilerInterface* profiler = nullptr) |
| 36 | + : compressed_indices_(compressed_indices), |
| 37 | + count_indices_(count_indices), |
| 38 | + comp_data_(comp_data), |
| 39 | + num_channels_(num_channels), |
| 40 | + micro_profiler_(profiler) {} |
| 41 | + |
| 42 | + DecompressionState(const DecompressionState& other) |
| 43 | + : compressed_indices_(other.compressed_indices_), |
| 44 | + count_indices_(other.count_indices_), |
| 45 | + comp_data_(other.comp_data_), |
| 46 | + num_channels_(other.num_channels_), |
| 47 | + micro_profiler_(other.micro_profiler_) {} |
| 48 | + |
| 49 | + template <typename T> |
| 50 | + T* DecompressToBuffer(void* buffer); |
| 51 | + |
| 52 | + protected: |
| 53 | + // optimized C++ for INT8, use_alt_axis == false |
| 54 | + void DecompressToBufferWidth4_16(int8_t* buffer); |
| 55 | + void DecompressToBufferWidth3_32(int8_t* buffer); |
| 56 | + void DecompressToBufferWidth2_16(int8_t* buffer); |
| 57 | + |
| 58 | + // generic C++ for any bit width and value table type |
| 59 | + template <typename T> |
| 60 | + void DecompressToBufferWidthAny(T* buffer); |
| 61 | + |
| 62 | + // Optimized C++ table index fetch |
| 63 | + inline size_t GetNextTableIndexWidth7(const size_t current_offset); |
| 64 | + inline size_t GetNextTableIndexWidth6(const size_t current_offset); |
| 65 | + inline size_t GetNextTableIndexWidth5(const size_t current_offset); |
| 66 | + inline size_t GetNextTableIndexWidth4(const size_t current_offset); |
| 67 | + inline size_t GetNextTableIndexWidth3(const size_t current_offset); |
| 68 | + inline size_t GetNextTableIndexWidth2(const size_t current_offset); |
| 69 | + inline size_t GetNextTableIndexWidth1(const size_t current_offset); |
| 70 | + |
| 71 | + protected: |
| 72 | + const uint8_t* compressed_indices_; |
| 73 | + const size_t count_indices_; |
| 74 | + const CompressionTensorData& comp_data_; |
| 75 | + const size_t num_channels_; |
| 76 | + const size_t compressed_bit_width_ = |
| 77 | + comp_data_.data.lut_data->compressed_bit_width; |
| 78 | + const size_t elements_per_channel_ = |
| 79 | + comp_data_.data.lut_data->use_alternate_axis |
| 80 | + ? 1 |
| 81 | + : count_indices_ / num_channels_; |
| 82 | + MicroProfilerInterface* micro_profiler_; |
| 83 | +}; |
| 84 | + |
| 85 | +#endif // USE_TFLM_COMPRESSION |
| 86 | + |
| 87 | +} // namespace tflite |
| 88 | + |
| 89 | +#endif // TENSORFLOW_LITE_MICRO_MICRO_KERNELS_DECOMPRESS_H_ |
0 commit comments