Skip to content

Commit 84b517e

Browse files
committed
Sync from upstream TF.
1 parent c9e2319 commit 84b517e

10 files changed

+628
-701
lines changed

tensorflow/compiler/mlir/lite/schema/schema.fbs

+19
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,19 @@ table CustomQuantization {
7070
// Represents a specific quantization technique's parameters.
7171
union QuantizationDetails {
7272
CustomQuantization,
73+
BlockwiseQuantization,
74+
}
75+
76+
// Parameters for blockwise quantization.
77+
table BlockwiseQuantization {
78+
// index to the scale tensor, the tensor can be found in tensors array in
79+
// subgraph.
80+
scales: int;
81+
// index to the zero point tensor. If zero_points is -1, the zero point is
82+
// assumed to be 0, following the convention of optional tensors in tflite.
83+
zero_points: int;
84+
// The block size of the tensor.
85+
block_size: int;
7386
}
7487

7588
// Parameters for converting a quantized tensor back to float.
@@ -474,6 +487,7 @@ enum BuiltinOperator : int32 {
474487
STABLEHLO_COMPOSITE = 206, // WARNING: No runtime support
475488
STABLEHLO_SHIFT_LEFT = 207,
476489
STABLEHLO_CBRT = 208, // WARNING: No runtime support
490+
STABLEHLO_CASE = 209,
477491
}
478492
// LINT.ThenChange(nnapi_linter/linter.proto)
479493

@@ -633,6 +647,7 @@ union BuiltinOptions2{
633647
ReduceWindowOptions (deprecated),
634648
StableHLOCompositeOptions,
635649
StablehloShiftLeftOptions,
650+
StablehloCaseOptions,
636651
}
637652

638653
table StablehloGatherOptions{
@@ -777,6 +792,10 @@ table StablehloScatterOptions {
777792
update_computation_subgraph_index: int;
778793
}
779794

795+
table StablehloCaseOptions{
796+
branch_subgraph_indices : [int];
797+
}
798+
780799
enum RngAlgorithm : byte {
781800
// An algorithm auto-selected by the system according to device type.
782801
DEFAULT = 0,

tensorflow/lite/builtin_ops.h

+1
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,7 @@ typedef enum {
236236
kTfLiteBuiltinStablehloComposite = 206,
237237
kTfLiteBuiltinStablehloShiftLeft = 207,
238238
kTfLiteBuiltinStablehloCbrt = 208,
239+
kTfLiteBuiltinStablehloCase = 209,
239240
} TfLiteBuiltinOperator;
240241

241242
#ifdef __cplusplus

tensorflow/lite/core/api/flatbuffer_conversions.cc

+47
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@ limitations under the License.
2727
#include "tensorflow/lite/kernels/internal/compatibility.h"
2828
#include "tensorflow/lite/schema/schema_generated.h"
2929

30+
// TODO(sosagarcia): Rework all function implementations to wrap around the
31+
// compiler flatbuffer_conversions.
32+
// LINT.IfChange
3033
namespace tflite {
3134

3235
namespace {
@@ -928,6 +931,9 @@ TfLiteStatus ParseOpDataTfLite(const Operator* op, BuiltinOperator op_type,
928931
return ParseStablehloShiftLeft(op, error_reporter, allocator,
929932
builtin_data);
930933
}
934+
case BuiltinOperator_STABLEHLO_CASE: {
935+
return ParseStablehloCase(op, error_reporter, allocator, builtin_data);
936+
}
931937
// TODO: skip param parsing for now since ops below don't have kernels
932938
case BuiltinOperator_STABLEHLO_SLICE:
933939
case BuiltinOperator_STABLEHLO_BROADCAST_IN_DIM:
@@ -2421,6 +2427,46 @@ TfLiteStatus ParseStablehloShiftLeft(const Operator* op,
24212427
return kTfLiteOk;
24222428
}
24232429

2430+
TfLiteStatus ParseStablehloCase(const Operator* op,
2431+
ErrorReporter* error_reporter,
2432+
BuiltinDataAllocator* allocator,
2433+
void** builtin_data) {
2434+
CheckParsePointerParams(op, error_reporter, allocator, builtin_data);
2435+
2436+
SafeBuiltinDataAllocator safe_allocator(allocator);
2437+
auto params = safe_allocator.Allocate<TfLiteStablehloCaseParams>();
2438+
2439+
const StablehloCaseOptions* schema_params =
2440+
op->builtin_options_2_as_StablehloCaseOptions();
2441+
if (schema_params) {
2442+
auto LoadAttr =
2443+
[&error_reporter](
2444+
int32_t* params_array, const size_t params_array_size_bytes,
2445+
const flatbuffers::Vector<int32_t>* const flatbuffer_vector,
2446+
const char* const attr_name) -> TfLiteStatus {
2447+
TfLiteStatus status = FlatBufferIntVectorToArray(
2448+
params_array_size_bytes, flatbuffer_vector, params_array,
2449+
error_reporter, "stablehlo.case");
2450+
if (status != kTfLiteOk) {
2451+
TF_LITE_REPORT_ERROR(error_reporter, "Check the '%s' attribute.",
2452+
attr_name);
2453+
}
2454+
return status;
2455+
};
2456+
2457+
TF_LITE_ENSURE_STATUS(LoadAttr(params->branch_subgraph_indices,
2458+
sizeof(params->branch_subgraph_indices),
2459+
schema_params->branch_subgraph_indices(),
2460+
"branch subgraph indices"));
2461+
params->num_branches = schema_params->branch_subgraph_indices()->size();
2462+
*builtin_data = params.release();
2463+
return kTfLiteOk;
2464+
}
2465+
TF_LITE_REPORT_ERROR(error_reporter,
2466+
"Could not get 'stablehlo.case' operation parameters.");
2467+
return kTfLiteError;
2468+
}
2469+
24242470
// We have this parse function instead of directly returning kTfLiteOk from the
24252471
// switch-case in ParseOpData because this function is used as part of the
24262472
// selective registration for the OpResolver implementation in micro.
@@ -2943,3 +2989,4 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
29432989
}
29442990

29452991
} // namespace tflite
2992+
// LINT.ThenChange(//tensorflow/compiler/mlir/lite/core/api/flatbuffer_conversions.cc)

tensorflow/lite/core/api/flatbuffer_conversions.h

+5
Original file line numberDiff line numberDiff line change
@@ -456,6 +456,11 @@ TfLiteStatus ParseStablehloShiftLeft(const Operator* op,
456456
BuiltinDataAllocator* allocator,
457457
void** builtin_data);
458458

459+
TfLiteStatus ParseStablehloCase(const Operator* op,
460+
ErrorReporter* error_reporter,
461+
BuiltinDataAllocator* allocator,
462+
void** builtin_data);
463+
459464
} // namespace tflite
460465

461466
#endif // TENSORFLOW_LITE_CORE_API_FLATBUFFER_CONVERSIONS_H_

0 commit comments

Comments
 (0)