Skip to content
Open
Show file tree
Hide file tree
Changes from 13 commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 28 additions & 10 deletions tensorflow/lite/micro/kernels/decode.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,23 @@ TfLiteStatus SetOutputTensorData(TfLiteContext* context, const TfLiteNode* node,
return kTfLiteOk;
}

DecodeState* GetDecodeStateFromCustomRegistration(const TfLiteContext* context,
uint8_t type) {
const MicroContext* mc = GetMicroContext(context);
auto registrations = mc->GetCustomDecodeRegistrations();
if (registrations == nullptr) {
return nullptr;
}

for (auto& reg : *registrations) {
if (reg.type == type && reg.func != nullptr) {
return reg.func(context, mc->GetAlternateProfiler());
}
}

return nullptr;
}

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
const size_t num_inputs = NumInputs(node);
const size_t num_outputs = NumOutputs(node);
Expand Down Expand Up @@ -113,21 +130,22 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
dsp = DecodeState::CreateDecodeStateHuffman(
context, micro_context->GetAlternateProfiler());
break;
case DecodeState::kDcmTypeCustom:
MicroPrintf("Custom decode type not yet supported");
break;
default:
MicroPrintf("unsupported decode type %u",
DecodeState::Type(*ancillary));
uint32_t type = DecodeState::Type(*ancillary);
if (type >= DecodeState::kDcmTypeCustomFirst &&
type <= DecodeState::kDcmTypeCustomLast) {
dsp = GetDecodeStateFromCustomRegistration(context, type);
} else {
MicroPrintf("unsupported decode type %u", type);
}
break;
}

status = SetOutputTensorData(context, node, i / 2, output);
if (status != kTfLiteOk) {
break;
}

if (dsp != nullptr) {
status = SetOutputTensorData(context, node, i / 2, output);
if (status != kTfLiteOk) {
break;
}
status = dsp->Setup(*input, *ancillary, *output);
if (status != kTfLiteOk) {
break;
Expand Down
3 changes: 2 additions & 1 deletion tensorflow/lite/micro/kernels/decode_state.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,8 @@ class DecodeState {
static constexpr uint8_t kDcmTypeLUT = 0;
static constexpr uint8_t kDcmTypeHuffman = 1;
static constexpr uint8_t kDcmTypePrune = 2;
static constexpr uint8_t kDcmTypeCustom = 127;
static constexpr uint8_t kDcmTypeCustomFirst = 128;
static constexpr uint8_t kDcmTypeCustomLast = 255;

static constexpr size_t kDcmSizeInBytes = 16;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ TF_LITE_MICRO_TEST(DecodeHuffmanTable16BitsInt16Fail) {
tflite::testing::TestDecode<encodes.size() + ancillaries.size(),
outputs.size()>(
encodes, ancillaries, outputs, expected, tflite::Register_DECODE(),
nullptr, kTfLiteError);
nullptr, nullptr, kTfLiteError);
}

TF_LITE_MICRO_TEST(DecodeHuffmanTable32BitsInt8) {
Expand Down
2 changes: 1 addition & 1 deletion tensorflow/lite/micro/kernels/decode_state_prune_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -575,7 +575,7 @@ TF_LITE_MICRO_TEST(DecodePruneQuantizedInvalidZeroPointInt16) {
tflite::testing::TestDecode<kEncodes.size() + kAncillaries.size(),
kOutputs.size()>(
kEncodes, kAncillaries, kOutputs, kExpected, tflite::Register_DECODE(),
nullptr, kTfLiteError);
nullptr, nullptr, kTfLiteError);
}

TF_LITE_MICRO_TESTS_END
129 changes: 129 additions & 0 deletions tensorflow/lite/micro/kernels/decode_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,76 @@ constexpr int kEncodedShapeLUT[] = {1, sizeof(kEncodedLUT)};
constexpr int8_t kExpectLUT0[] = {1, 2, 3, 4, 4, 3, 2, 1};
constexpr int16_t kExpectLUT1[] = {5, 6, 7, 8, 8, 7, 6, 5};

//
// Custom DECODE test data
//
constexpr int kDecodeTypeCustom = 200;

constexpr int8_t kAncillaryDataCustom[] = {0x42};

constexpr uint8_t kDcmCustom[tflite::DecodeState::kDcmSizeInBytes] = {
kDecodeTypeCustom, // type: custom
1, // DCM version: 1
};

// Align the tensor data the same as a Buffer in the TfLite schema
alignas(16) const uint8_t kEncodedCustom[] = {0x42, 0x43, 0x40, 0x46,
0x4A, 0x52, 0x62, 0x02};

// Tensor shapes as TfLiteIntArray
constexpr int kOutputShapeCustom[] = {1, 8};
constexpr int kEncodedShapeCustom[] = {1, sizeof(kEncodedCustom)};

constexpr int8_t kExpectCustom[] = {0x00, 0x01, 0x02, 0x04,
0x08, 0x10, 0x20, 0x40};

class DecodeStateCustom : public tflite::DecodeState {
public:
DecodeStateCustom() = delete;

DecodeStateCustom(const TfLiteContext* context,
tflite::MicroProfilerInterface* profiler)
: DecodeState(context, profiler) {}

virtual TfLiteStatus Setup(const TfLiteTensor& input,
const TfLiteTensor& ancillary,
const TfLiteTensor& output) override {
return kTfLiteOk;
}

virtual TfLiteStatus Decode(const TfLiteEvalTensor& input,
const TfLiteEvalTensor& ancillary,
const TfLiteEvalTensor& output) override {
const uint8_t* inp = tflite::micro::GetTensorData<uint8_t>(&input);
TF_LITE_ENSURE(const_cast<TfLiteContext*>(context_), inp != nullptr);
uint8_t* outp = tflite::micro::GetTensorData<uint8_t>(
const_cast<TfLiteEvalTensor*>(&output));
TF_LITE_ENSURE(const_cast<TfLiteContext*>(context_), outp != nullptr);
const uint8_t* vp = tflite::micro::GetTensorData<uint8_t>(&ancillary);
TF_LITE_ENSURE(const_cast<TfLiteContext*>(context_), vp != nullptr);
vp += kDcmSizeInBytes;

// simple XOR de-obfuscation
std::transform(inp, inp + input.dims->data[0], outp,
[vp](uint8_t i) { return i ^ *vp; });

return kTfLiteOk;
}

static DecodeState* CreateDecodeStateCustom(
const TfLiteContext* context, tflite::MicroProfilerInterface* profiler) {
alignas(4) static uint8_t buffer[sizeof(DecodeStateCustom)];
DecodeState* instance = new (buffer) DecodeStateCustom(context, profiler);
return instance;
}

protected:
virtual ~DecodeStateCustom() = default;

private:
TF_LITE_REMOVE_VIRTUAL_DELETE
};

} // namespace

TF_LITE_MICRO_TESTS_BEGIN
Expand Down Expand Up @@ -246,4 +316,63 @@ TF_LITE_MICRO_TEST(DecodeWithAltDecompressionMemory) {
encodes, ancillaries, outputs, expected, tflite::Register_DECODE(), &amr);
}

TF_LITE_MICRO_TEST(DecodeWithCustomRegistration) {
// Align the tensor data the same as a Buffer in the TfLite schema
alignas(16) int8_t output_data[std::size(kExpectCustom)] = {};
alignas(16) const AncillaryData<int8_t, std::size(kAncillaryDataCustom)>
kAncillaryData = {{kDcmCustom}, {kAncillaryDataCustom}};

constexpr int kAncillaryShapeCustom[] = {1, sizeof(kAncillaryData)};

const TfLiteIntArray* const encoded_dims =
tflite::testing::IntArrayFromInts(kEncodedShapeCustom);
static const TensorInDatum tid_encode = {
kEncodedCustom,
*encoded_dims,
};
static constexpr std::initializer_list<const TensorInDatum*> encodes = {
&tid_encode,
};

const TfLiteIntArray* const ancillary_dims =
tflite::testing::IntArrayFromInts(kAncillaryShapeCustom);
static const TensorInDatum tid_ancillary = {
&kAncillaryData,
*ancillary_dims,
};
static constexpr std::initializer_list<const TensorInDatum*> ancillaries = {
&tid_ancillary};

const TfLiteIntArray* const output_dims =
tflite::testing::IntArrayFromInts(kOutputShapeCustom);
constexpr int kOutputZeroPointsData[] = {0};
const TfLiteIntArray* const kOutputZeroPoints =
tflite::testing::IntArrayFromInts(kOutputZeroPointsData);
const TfLiteFloatArray kOutputScales = {kOutputZeroPoints->size};
static const TensorOutDatum tod = {
output_data, *output_dims, kTfLiteInt8, kOutputScales, *kOutputZeroPoints,
0, {},
};
static constexpr std::initializer_list<const TensorOutDatum*> outputs = {
&tod};

const std::initializer_list<const void*> expected = {kExpectCustom};

const std::initializer_list<tflite::MicroContext::CustomDecodeRegistration>
cdr = {
{
kDecodeTypeCustom,
0, // reserved
0, // reserved
0, // reserved
DecodeStateCustom::CreateDecodeStateCustom,
},
};

tflite::testing::TestDecode<encodes.size() + ancillaries.size(),
outputs.size()>(
encodes, ancillaries, outputs, expected, tflite::Register_DECODE(),
nullptr, &cdr);
}

TF_LITE_MICRO_TESTS_END
9 changes: 8 additions & 1 deletion tensorflow/lite/micro/kernels/decode_test_helpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,8 @@ TfLiteStatus ExecuteDecodeTest(
TfLiteTensor* tensors, const TFLMRegistration& registration,
const std::initializer_list<const void*>& expected,
const std::initializer_list<MicroContext::AlternateMemoryRegion>* amr =
nullptr,
const std::initializer_list<MicroContext::CustomDecodeRegistration>* cdr =
nullptr) {
int kInputArrayData[kNumInputs + 1] = {kNumInputs};
for (size_t i = 0; i < kNumInputs; i++) {
Expand All @@ -104,6 +106,9 @@ TfLiteStatus ExecuteDecodeTest(
if (amr != nullptr) {
runner.GetFakeMicroContext()->SetDecompressionMemory(*amr);
}
if (cdr != nullptr) {
runner.GetFakeMicroContext()->SetCustomDecodeRegistrations(*cdr);
}

if (runner.InitAndPrepare() != kTfLiteOk || runner.Invoke() != kTfLiteOk) {
return kTfLiteError;
Expand Down Expand Up @@ -149,6 +154,8 @@ void TestDecode(
const TFLMRegistration& registration,
const std::initializer_list<MicroContext::AlternateMemoryRegion>* amr =
nullptr,
const std::initializer_list<MicroContext::CustomDecodeRegistration>* cdr =
nullptr,
const TfLiteStatus expected_status = kTfLiteOk) {
TfLiteTensor tensors[kNumInputs + kNumOutputs] = {};

Expand Down Expand Up @@ -182,7 +189,7 @@ void TestDecode(
}

TfLiteStatus s = ExecuteDecodeTest<kNumInputs, kNumOutputs>(
tensors, registration, expected, amr);
tensors, registration, expected, amr, cdr);
TF_LITE_MICRO_EXPECT_EQ(s, expected_status);
}

Expand Down
2 changes: 2 additions & 0 deletions tensorflow/lite/micro/kernels/kernel_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@ limitations under the License.
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/micro/micro_common.h"
#include "tensorflow/lite/micro/micro_context.h"
#include "tensorflow/lite/micro/micro_graph.h"

#ifdef USE_TFLM_COMPRESSION

Expand Down
32 changes: 31 additions & 1 deletion tensorflow/lite/micro/micro_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ namespace tflite {
// TODO(b/149795762): kTfLiteAbort cannot be part of the tflite TfLiteStatus.
const TfLiteStatus kTfLiteAbort = static_cast<TfLiteStatus>(15);

class DecodeState; // can't use decode_state.h due to circular include

// MicroContext is eventually going to become the API between TFLM and the
// kernels, replacing all the functions in TfLiteContext. The end state is code
// kernels to have code like:
Expand Down Expand Up @@ -136,7 +138,7 @@ class MicroContext {
};

// Set the alternate decompression memory regions.
// Can only be called during the MicroInterpreter kInit state.
// Can only be called during the kInit state.
virtual TfLiteStatus SetDecompressionMemory(
const std::initializer_list<AlternateMemoryRegion>& regions);

Expand Down Expand Up @@ -169,12 +171,40 @@ class MicroContext {
return nullptr;
}

struct CustomDecodeRegistration {
uint8_t type; // custom decode type
uint8_t reserved1; // reserved
uint8_t reserved2; // reserved
uint8_t reserved3; // reserved
tflite::DecodeState* (*func)(const TfLiteContext*, MicroProfilerInterface*);
};

// Set the custom DECODE operator registrations.
// Can only be called during the kInit state.
virtual TfLiteStatus SetCustomDecodeRegistrations(
const std::initializer_list<CustomDecodeRegistration>& registrations) {
if (custom_decode_registrations_ != nullptr) {
return kTfLiteError;
}
custom_decode_registrations_ = &registrations;
return kTfLiteOk;
}

// Get the custom decompression registrations.
virtual const std::initializer_list<CustomDecodeRegistration>*
GetCustomDecodeRegistrations() const {
return custom_decode_registrations_;
}

private:
const std::initializer_list<AlternateMemoryRegion>* decompress_regions_ =
nullptr;
// array of size_t elements with length equal to decompress_regions_.size()
size_t* decompress_regions_allocations_ = nullptr;

const std::initializer_list<CustomDecodeRegistration>*
custom_decode_registrations_ = nullptr;

TF_LITE_REMOVE_VIRTUAL_DELETE
};

Expand Down
13 changes: 12 additions & 1 deletion tensorflow/lite/micro/micro_interpreter_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ class MicroInterpreterContext : public MicroContext {
#endif // USE_TFLM_COMPRESSION

// Set the alternate decompression memory regions.
// Can only be called during the MicroInterpreter kInit state.
// Can only be called during the kInit state.
TfLiteStatus SetDecompressionMemory(
const std::initializer_list<AlternateMemoryRegion>& regions) override;

Expand All @@ -159,6 +159,17 @@ class MicroInterpreterContext : public MicroContext {
// decompression subsystem.
MicroProfilerInterface* GetAlternateProfiler() const override;

// Set the custom DECODE operator registrations.
// Can only be called during the kInit state.
virtual TfLiteStatus SetCustomDecodeRegistrations(
const std::initializer_list<CustomDecodeRegistration>& registrations)
override {
if (state_ != InterpreterState::kInit) {
return kTfLiteError;
}
return MicroContext::SetCustomDecodeRegistrations(registrations);
}

private:
MicroAllocator& allocator_;
MicroInterpreterGraph& graph_;
Expand Down
Loading