@@ -27,6 +27,9 @@ limitations under the License.
27
27
#include " tensorflow/lite/kernels/internal/compatibility.h"
28
28
#include " tensorflow/lite/schema/schema_generated.h"
29
29
30
+ // TODO(sosagarcia): Rework all function implementations to wrap around the
31
+ // compiler flatbuffer_conversions.
32
+ // LINT.IfChange
30
33
namespace tflite {
31
34
32
35
namespace {
@@ -928,6 +931,9 @@ TfLiteStatus ParseOpDataTfLite(const Operator* op, BuiltinOperator op_type,
928
931
return ParseStablehloShiftLeft (op, error_reporter, allocator,
929
932
builtin_data);
930
933
}
934
+ case BuiltinOperator_STABLEHLO_CASE: {
935
+ return ParseStablehloCase (op, error_reporter, allocator, builtin_data);
936
+ }
931
937
// TODO: skip param parsing for now since ops below don't have kernels
932
938
case BuiltinOperator_STABLEHLO_SLICE:
933
939
case BuiltinOperator_STABLEHLO_BROADCAST_IN_DIM:
@@ -2421,6 +2427,46 @@ TfLiteStatus ParseStablehloShiftLeft(const Operator* op,
2421
2427
return kTfLiteOk ;
2422
2428
}
2423
2429
2430
+ TfLiteStatus ParseStablehloCase (const Operator* op,
2431
+ ErrorReporter* error_reporter,
2432
+ BuiltinDataAllocator* allocator,
2433
+ void ** builtin_data) {
2434
+ CheckParsePointerParams (op, error_reporter, allocator, builtin_data);
2435
+
2436
+ SafeBuiltinDataAllocator safe_allocator (allocator);
2437
+ auto params = safe_allocator.Allocate <TfLiteStablehloCaseParams>();
2438
+
2439
+ const StablehloCaseOptions* schema_params =
2440
+ op->builtin_options_2_as_StablehloCaseOptions ();
2441
+ if (schema_params) {
2442
+ auto LoadAttr =
2443
+ [&error_reporter](
2444
+ int32_t * params_array, const size_t params_array_size_bytes,
2445
+ const flatbuffers::Vector<int32_t >* const flatbuffer_vector,
2446
+ const char * const attr_name) -> TfLiteStatus {
2447
+ TfLiteStatus status = FlatBufferIntVectorToArray (
2448
+ params_array_size_bytes, flatbuffer_vector, params_array,
2449
+ error_reporter, " stablehlo.case" );
2450
+ if (status != kTfLiteOk ) {
2451
+ TF_LITE_REPORT_ERROR (error_reporter, " Check the '%s' attribute." ,
2452
+ attr_name);
2453
+ }
2454
+ return status;
2455
+ };
2456
+
2457
+ TF_LITE_ENSURE_STATUS (LoadAttr (params->branch_subgraph_indices ,
2458
+ sizeof (params->branch_subgraph_indices ),
2459
+ schema_params->branch_subgraph_indices (),
2460
+ " branch subgraph indices" ));
2461
+ params->num_branches = schema_params->branch_subgraph_indices ()->size ();
2462
+ *builtin_data = params.release ();
2463
+ return kTfLiteOk ;
2464
+ }
2465
+ TF_LITE_REPORT_ERROR (error_reporter,
2466
+ " Could not get 'stablehlo.case' operation parameters." );
2467
+ return kTfLiteError ;
2468
+ }
2469
+
2424
2470
// We have this parse function instead of directly returning kTfLiteOk from the
2425
2471
// switch-case in ParseOpData because this function is used as part of the
2426
2472
// selective registration for the OpResolver implementation in micro.
@@ -2943,3 +2989,4 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
2943
2989
}
2944
2990
2945
2991
} // namespace tflite
2992
+ // LINT.ThenChange(//tensorflow/compiler/mlir/lite/core/api/flatbuffer_conversions.cc)
0 commit comments