Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 36 additions & 1 deletion tensorflow/lite/micro/compression/model_facade.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
import numpy as np
from numpy.typing import NDArray
from tflite_micro.tensorflow.lite.python import schema_py_generated as tflite
from typing import ByteString, Generic, TypeVar
from typing import ByteString, Generic, TypeVar, List

_IteratorTo = TypeVar("_IteratorTo")

Expand Down Expand Up @@ -100,10 +100,37 @@ def __init__(self, operator, index, subgraph):
def opcode(self) -> tflite.OperatorCodeT:
return self.subgraph.model.operatorCodes[self.operator.opcodeIndex]

@property
def builtin_opcode(self) -> int:
result: int = self.opcode.deprecatedBuiltinCode
if result == tflite.BuiltinOperator.PLACEHOLDER_FOR_GREATER_OP_CODES:
result = self.opcode.builtinCode
return result

@property
def inputs(self):
return _IndirectIterator(self.operator.inputs, self.subgraph.tensors)

@property
def outputs(self):
return _IndirectIterator(self.operator.outputs, self.subgraph.tensors)

@property
def inputs_indices(self) -> List[int]:
return self.operator.inputs

@property
def outputs_indices(self) -> List[int]:
return self.operator.outputs

@property
def builtin_options_type(self) -> int:
return self.operator.builtinOptionsType

@property
def builtin_options(self):
return self.operator.builtinOptions


_NP_DTYPES = {
tflite.TensorType.FLOAT16: np.dtype("<f2"),
Expand Down Expand Up @@ -208,6 +235,10 @@ def operators(self) -> _Iterator[_Operator]:
def tensors(self) -> _Iterator[_Tensor]:
return _Iterator(self._subgraph_t.tensors, _Tensor, parent=self)

@property
def outputs_indices(self) -> List[int]:
return self._subgraph_t.outputs


class _Model:
"""A facade for manipulating tflite.Model.
Expand Down Expand Up @@ -268,6 +299,10 @@ def subgraphs(self) -> _Iterator[_Subgraph]:
def buffers(self) -> _Iterator[_Buffer]:
return _Iterator(self._model_t.buffers, _Buffer, parent=self)

@property
def root(self) -> tflite.ModelT:
return self._model_t


def read(buffer: ByteString):
"""Reads a tflite.Model and returns a model facade.
Expand Down
8 changes: 8 additions & 0 deletions tensorflow/lite/micro/micro_allocator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -409,6 +409,14 @@ const tflite::micro::compression::Metadata* GetCompressionMetadata(
MicroPrintf("Compression: verification failure");
return nullptr;
} else {
tflite::micro::compression::MetadataT schema;
if (compression_metadata->schema_version() > schema.schema_version) {
MicroPrintf("Compression: schema version mismatch (using %d got %d)",
schema.schema_version,
compression_metadata->schema_version());
return nullptr;
}

return compression_metadata;
}
}
Expand Down
14 changes: 14 additions & 0 deletions tensorflow/lite/micro/tools/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -223,3 +223,17 @@ flatbuffer_py_library(
name = "layer_by_layer_schema_py",
srcs = ["layer_by_layer_schema.fbs"],
)

py_binary(
name = "reorder_ops",
srcs = [
"reorder_ops.py",
],
deps = [
":model_transforms_utils",
"//tensorflow/lite/micro/compression:model_facade",
"//tensorflow/lite/python:schema_py",
"@absl_py//absl:app",
"@absl_py//absl/flags",
],
)
Loading