Skip to content

Commit 9a32964

Browse files
authored
feat(compression): allocate resource variables in persistent buffer (#3013)
Allocate resource variables in a persistent buffer when the input tensor is compressed. Extend tests to validate operation. BUG=part of #2636
1 parent b2f2718 commit 9a32964

File tree

4 files changed

+63
-12
lines changed

4 files changed

+63
-12
lines changed

tensorflow/lite/micro/kernels/assign_variable.cc

+49-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
1+
/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
22
33
Licensed under the Apache License, Version 2.0 (the "License");
44
you may not use this file except in compliance with the License.
@@ -26,6 +26,7 @@ limitations under the License.
2626
#include "tensorflow/lite/micro/micro_graph.h"
2727
#include "tensorflow/lite/micro/micro_log.h"
2828
#include "tensorflow/lite/micro/micro_resource_variable.h"
29+
#include "tensorflow/lite/micro/micro_utils.h"
2930
#include "tensorflow/lite/schema/schema_generated.h"
3031

3132
namespace tflite {
@@ -35,6 +36,20 @@ namespace {
3536
constexpr int kInputVariableId = 0;
3637
constexpr int kInputValue = 1;
3738

39+
#ifdef USE_TFLM_COMPRESSION
40+
41+
struct OpData {
42+
// scratch buffer for compressed input tensor
43+
int scratch_index;
44+
};
45+
46+
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
47+
TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
48+
return context->AllocatePersistentBuffer(context, sizeof(OpData));
49+
}
50+
51+
#endif // USE_TFLM_COMPRESSION
52+
3853
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
3954
TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
4055
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 0);
@@ -70,6 +85,17 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
7085
context, input_value));
7186
}
7287

88+
#ifdef USE_TFLM_COMPRESSION
89+
90+
TFLITE_DCHECK(node->user_data != nullptr);
91+
OpData* data = static_cast<OpData*>(node->user_data);
92+
// Compression scratch buffers.
93+
// These will only be allocated if the tensor is compressed.
94+
data->scratch_index =
95+
micro_context->AllocateDecompressionScratchBuffer(node, kInputValue);
96+
97+
#endif // USE_TFLM_COMPRESSION
98+
7399
micro_context->DeallocateTempTfLiteTensor(input_value);
74100
return kTfLiteOk;
75101
}
@@ -93,15 +119,36 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
93119
"ResourceVariables and pass it to the interpreter.");
94120
return kTfLiteError;
95121
}
122+
123+
#ifdef USE_TFLM_COMPRESSION
124+
OpData* data = static_cast<OpData*>(node->user_data);
125+
const CompressionTensorData* comp_td =
126+
micro_context->GetTensorCompressionData(node, kInputValue);
127+
const void* buffer = tflite::micro::GetTensorData<void>(
128+
micro_context, input_value, comp_td, data->scratch_index);
129+
#else // USE_TFLM_COMPRESSION
130+
const void* buffer = tflite::micro::GetTensorData<void>(input_value);
131+
#endif // USE_TFLM_COMPRESSION
132+
96133
TF_LITE_ENSURE_OK(context,
97-
resources->Assign(input_id->data.i32[0], input_value));
134+
resources->Assign(input_id->data.i32[0],
135+
EvalTensorBytes(input_value), buffer));
98136
return kTfLiteOk;
99137
}
100138

101139
} // namespace.
102140

141+
#ifdef USE_TFLM_COMPRESSION
142+
143+
TFLMRegistration Register_ASSIGN_VARIABLE() {
144+
return tflite::micro::RegisterOp(Init, Prepare, Eval);
145+
146+
#else // USE_TFLM_COMPRESSION
147+
103148
TFLMRegistration Register_ASSIGN_VARIABLE() {
104149
return tflite::micro::RegisterOp(nullptr, Prepare, Eval);
150+
151+
#endif // USE_TFLM_COMPRESSION
105152
}
106153

107154
} // namespace tflite

tensorflow/lite/micro/micro_resource_variable.cc

+6-5
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
1+
/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
22
33
Licensed under the Apache License, Version 2.0 (the "License");
44
you may not use this file except in compliance with the License.
@@ -113,8 +113,8 @@ TfLiteStatus MicroResourceVariables::Allocate(int id, TfLiteContext* context,
113113
return kTfLiteOk;
114114
}
115115

116-
TfLiteStatus MicroResourceVariables::Assign(int id,
117-
const TfLiteEvalTensor* tensor) {
116+
TfLiteStatus MicroResourceVariables::Assign(int id, size_t count_bytes,
117+
const void* input_buffer) {
118118
if (id < 0 || id >= num_resource_variables_) {
119119
MicroPrintf("Attempting to read non-existent resource variable %d", id);
120120
return kTfLiteError;
@@ -128,8 +128,9 @@ TfLiteStatus MicroResourceVariables::Assign(int id,
128128
"with a TfLiteTensor first.");
129129
return kTfLiteError;
130130
}
131-
TFLITE_DCHECK(EvalTensorBytes(tensor) == variable.bytes);
132-
memcpy(variable.resource_buffer, tensor->data.raw, variable.bytes);
131+
TFLITE_DCHECK(count_bytes == variable.bytes);
132+
TFLITE_DCHECK(input_buffer != nullptr);
133+
memcpy(variable.resource_buffer, input_buffer, variable.bytes);
133134
return kTfLiteOk;
134135
}
135136

tensorflow/lite/micro/micro_resource_variable.h

+3-3
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
1+
/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
22
33
Licensed under the Apache License, Version 2.0 (the "License");
44
you may not use this file except in compliance with the License.
@@ -46,10 +46,10 @@ class MicroResourceVariables {
4646
TfLiteStatus Allocate(int id, TfLiteContext* context,
4747
const TfLiteTensor* tensor);
4848

49-
// Copies input tensor contents to the resource buffer.
49+
// Copies input_buffer contents to the resource buffer.
5050
// AllocateResourceVariable with a TFLite tensor must have been called first
5151
// in order to allocate the resource buffer.
52-
TfLiteStatus Assign(int id, const TfLiteEvalTensor* tensor);
52+
TfLiteStatus Assign(int id, size_t count_bytes, const void* input_buffer);
5353

5454
// Zeros out all resource buffers.
5555
TfLiteStatus ResetAll();

tensorflow/lite/micro/micro_resource_variable_test.cc

+5-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
1+
/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
22
33
Licensed under the Apache License, Version 2.0 (the "License");
44
you may not use this file except in compliance with the License.
@@ -16,6 +16,7 @@ limitations under the License.
1616
#include "tensorflow/lite/micro/micro_resource_variable.h"
1717

1818
#include "tensorflow/lite/c/common.h"
19+
#include "tensorflow/lite/micro/micro_utils.h"
1920
#include "tensorflow/lite/micro/test_helpers.h"
2021
#include "tensorflow/lite/micro/testing/micro_test.h"
2122

@@ -120,7 +121,9 @@ TF_LITE_MICRO_TEST(VerifyAssignAndReadResourceBuffer) {
120121

121122
.type = kTfLiteFloat32,
122123
};
123-
resource_variables->Assign(id, &assign_tensor);
124+
resource_variables->Assign(
125+
id, tflite::EvalTensorBytes(&assign_tensor),
126+
tflite::micro::GetTensorData<void>(&assign_tensor));
124127

125128
int32_t buffer[32];
126129
TfLiteEvalTensor read_tensor = {

0 commit comments

Comments
 (0)