1
- /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
1
+ /* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
2
2
3
3
Licensed under the Apache License, Version 2.0 (the "License");
4
4
you may not use this file except in compliance with the License.
@@ -26,6 +26,7 @@ limitations under the License.
26
26
#include " tensorflow/lite/micro/micro_graph.h"
27
27
#include " tensorflow/lite/micro/micro_log.h"
28
28
#include " tensorflow/lite/micro/micro_resource_variable.h"
29
+ #include " tensorflow/lite/micro/micro_utils.h"
29
30
#include " tensorflow/lite/schema/schema_generated.h"
30
31
31
32
namespace tflite {
@@ -35,6 +36,20 @@ namespace {
35
36
constexpr int kInputVariableId = 0 ;
36
37
constexpr int kInputValue = 1 ;
37
38
39
#ifdef USE_TFLM_COMPRESSION

// Per-node state, needed only when compressed-tensor support is built in.
struct OpData {
  // Scratch buffer index used to decompress the input-value tensor.
  // Assigned in Prepare(); the underlying buffer is only allocated if the
  // tensor is actually compressed.
  int scratch_index;
};

// Allocates OpData in the interpreter's persistent arena and returns it as
// the node's user_data. The flexbuffer op options (`buffer`/`length`) are
// unused by ASSIGN_VARIABLE.
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
  return context->AllocatePersistentBuffer(context, sizeof(OpData));
}

#endif  // USE_TFLM_COMPRESSION
52
+
38
53
TfLiteStatus Prepare (TfLiteContext* context, TfLiteNode* node) {
39
54
TF_LITE_ENSURE_EQ (context, NumInputs (node), 2 );
40
55
TF_LITE_ENSURE_EQ (context, NumOutputs (node), 0 );
@@ -70,6 +85,17 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
70
85
context, input_value));
71
86
}
72
87
88
+ #ifdef USE_TFLM_COMPRESSION
89
+
90
+ TFLITE_DCHECK (node->user_data != nullptr );
91
+ OpData* data = static_cast <OpData*>(node->user_data );
92
+ // Compression scratch buffers.
93
+ // These will only be allocated if the tensor is compressed.
94
+ data->scratch_index =
95
+ micro_context->AllocateDecompressionScratchBuffer (node, kInputValue );
96
+
97
+ #endif // USE_TFLM_COMPRESSION
98
+
73
99
micro_context->DeallocateTempTfLiteTensor (input_value);
74
100
return kTfLiteOk ;
75
101
}
@@ -93,15 +119,36 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
93
119
" ResourceVariables and pass it to the interpreter." );
94
120
return kTfLiteError ;
95
121
}
122
+
123
+ #ifdef USE_TFLM_COMPRESSION
124
+ OpData* data = static_cast <OpData*>(node->user_data );
125
+ const CompressionTensorData* comp_td =
126
+ micro_context->GetTensorCompressionData (node, kInputValue );
127
+ const void * buffer = tflite::micro::GetTensorData<void >(
128
+ micro_context, input_value, comp_td, data->scratch_index );
129
+ #else // USE_TFLM_COMPRESSION
130
+ const void * buffer = tflite::micro::GetTensorData<void >(input_value);
131
+ #endif // USE_TFLM_COMPRESSION
132
+
96
133
TF_LITE_ENSURE_OK (context,
97
- resources->Assign (input_id->data .i32 [0 ], input_value));
134
+ resources->Assign (input_id->data .i32 [0 ],
135
+ EvalTensorBytes (input_value), buffer));
98
136
return kTfLiteOk ;
99
137
}
100
138
101
139
} // namespace.
102
140
141
+ #ifdef USE_TFLM_COMPRESSION
142
+
143
+ TFLMRegistration Register_ASSIGN_VARIABLE () {
144
+ return tflite::micro::RegisterOp (Init, Prepare, Eval);
145
+
146
+ #else // USE_TFLM_COMPRESSION
147
+
103
148
TFLMRegistration Register_ASSIGN_VARIABLE () {
104
149
return tflite::micro::RegisterOp (nullptr , Prepare, Eval);
150
+
151
+ #endif // USE_TFLM_COMPRESSION
105
152
}
106
153
107
154
} // namespace tflite
0 commit comments