1- /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
1+ /* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
22
33Licensed under the Apache License, Version 2.0 (the "License");
44you may not use this file except in compliance with the License.
@@ -26,6 +26,7 @@ limitations under the License.
2626#include " tensorflow/lite/micro/micro_graph.h"
2727#include " tensorflow/lite/micro/micro_log.h"
2828#include " tensorflow/lite/micro/micro_resource_variable.h"
29+ #include " tensorflow/lite/micro/micro_utils.h"
2930#include " tensorflow/lite/schema/schema_generated.h"
3031
3132namespace tflite {
@@ -35,6 +36,20 @@ namespace {
3536constexpr int kInputVariableId = 0 ;
3637constexpr int kInputValue = 1 ;
3738
39+ #ifdef USE_TFLM_COMPRESSION
40+
41+ struct OpData {
42+ // scratch buffer for compressed input tensor
43+ int scratch_index;
44+ };
45+
46+ void * Init (TfLiteContext* context, const char * buffer, size_t length) {
47+ TFLITE_DCHECK (context->AllocatePersistentBuffer != nullptr );
48+ return context->AllocatePersistentBuffer (context, sizeof (OpData));
49+ }
50+
51+ #endif // USE_TFLM_COMPRESSION
52+
3853TfLiteStatus Prepare (TfLiteContext* context, TfLiteNode* node) {
3954 TF_LITE_ENSURE_EQ (context, NumInputs (node), 2 );
4055 TF_LITE_ENSURE_EQ (context, NumOutputs (node), 0 );
@@ -70,6 +85,17 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
7085 context, input_value));
7186 }
7287
88+ #ifdef USE_TFLM_COMPRESSION
89+
90+ TFLITE_DCHECK (node->user_data != nullptr );
91+ OpData* data = static_cast <OpData*>(node->user_data );
92+ // Compression scratch buffers.
93+ // These will only be allocated if the tensor is compressed.
94+ data->scratch_index =
95+ micro_context->AllocateDecompressionScratchBuffer (node, kInputValue );
96+
97+ #endif // USE_TFLM_COMPRESSION
98+
7399 micro_context->DeallocateTempTfLiteTensor (input_value);
74100 return kTfLiteOk ;
75101}
@@ -93,15 +119,36 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
93119 " ResourceVariables and pass it to the interpreter." );
94120 return kTfLiteError ;
95121 }
122+
123+ #ifdef USE_TFLM_COMPRESSION
124+ OpData* data = static_cast <OpData*>(node->user_data );
125+ const CompressionTensorData* comp_td =
126+ micro_context->GetTensorCompressionData (node, kInputValue );
127+ const void * buffer = tflite::micro::GetTensorData<void >(
128+ micro_context, input_value, comp_td, data->scratch_index );
129+ #else // USE_TFLM_COMPRESSION
130+ const void * buffer = tflite::micro::GetTensorData<void >(input_value);
131+ #endif // USE_TFLM_COMPRESSION
132+
96133 TF_LITE_ENSURE_OK (context,
97- resources->Assign (input_id->data .i32 [0 ], input_value));
134+ resources->Assign (input_id->data .i32 [0 ],
135+ EvalTensorBytes (input_value), buffer));
98136 return kTfLiteOk ;
99137}
100138
101139} // namespace.
102140
141+ #ifdef USE_TFLM_COMPRESSION
142+
143+ TFLMRegistration Register_ASSIGN_VARIABLE () {
144+ return tflite::micro::RegisterOp (Init, Prepare, Eval);
145+
146+ #else // USE_TFLM_COMPRESSION
147+
103148TFLMRegistration Register_ASSIGN_VARIABLE () {
104149 return tflite::micro::RegisterOp (nullptr , Prepare, Eval);
150+
151+ #endif // USE_TFLM_COMPRESSION
105152}
106153
107154} // namespace tflite
0 commit comments