Skip to content

Commit a4d6c5c

Browse files
authored
Automated sync from github.com/tensorflow/tensorflow (#2307)
BUG=automated sync from upstream NO_CHECK_TFLITE_FILES=automated sync from upstream
1 parent 38c657a commit a4d6c5c

File tree

4 files changed

+71
-2
lines changed

4 files changed

+71
-2
lines changed

tensorflow/lite/core/api/flatbuffer_conversions.cc

+56-1
Original file line numberDiff line numberDiff line change
@@ -918,6 +918,9 @@ TfLiteStatus ParseOpDataTfLite(const Operator* op, BuiltinOperator op_type,
918918
*builtin_data = params.release();
919919
return kTfLiteOk;
920920
}
921+
case BuiltinOperator_STABLEHLO_PAD: {
922+
return ParseStablehloPad(op, error_reporter, allocator, builtin_data);
923+
}
921924
// TODO: skip param parsing for now since ops below don't have kernels
922925
case BuiltinOperator_STABLEHLO_SLICE:
923926
case BuiltinOperator_STABLEHLO_BROADCAST_IN_DIM:
@@ -952,7 +955,6 @@ TfLiteStatus ParseOpDataTfLite(const Operator* op, BuiltinOperator op_type,
952955
case BuiltinOperator_STABLEHLO_IOTA:
953956
case BuiltinOperator_STABLEHLO_COMPARE:
954957
case BuiltinOperator_STABLEHLO_CONVERT:
955-
case BuiltinOperator_STABLEHLO_PAD:
956958
case BuiltinOperator_STABLEHLO_DOT_GENERAL:
957959
case BuiltinOperator_STABLEHLO_SORT:
958960
case BuiltinOperator_STABLEHLO_WHILE:
@@ -2316,6 +2318,59 @@ TfLiteStatus ParseStablehloGather(const Operator* op,
23162318
return kTfLiteOk;
23172319
}
23182320

2321+
TfLiteStatus ParseStablehloPad(const Operator* op,
2322+
ErrorReporter* error_reporter,
2323+
BuiltinDataAllocator* allocator,
2324+
void** builtin_data) {
2325+
CheckParsePointerParams(op, error_reporter, allocator, builtin_data);
2326+
2327+
SafeBuiltinDataAllocator safe_allocator(allocator);
2328+
auto params = safe_allocator.Allocate<TfLiteStablehloPadParams>();
2329+
const StablehloPadOptions* schema_params =
2330+
op->builtin_options_2_as_StablehloPadOptions();
2331+
2332+
if (schema_params) {
2333+
auto LoadAttr =
2334+
[&error_reporter](
2335+
int64_t* params_array, const size_t params_array_size_bytes,
2336+
const flatbuffers::Vector<int64_t>* const flatbuffer_vector,
2337+
const char* const attr_name) -> TfLiteStatus {
2338+
TfLiteStatus status = FlatBufferIntVectorToArray(
2339+
params_array_size_bytes, flatbuffer_vector, params_array,
2340+
error_reporter, "stablehlo.pad");
2341+
if (status != kTfLiteOk) {
2342+
TF_LITE_REPORT_ERROR(error_reporter, "Check the '%s' attribute.",
2343+
attr_name);
2344+
}
2345+
return status;
2346+
};
2347+
2348+
TF_LITE_ENSURE_STATUS(
2349+
LoadAttr(params->edge_padding_low, sizeof(params->edge_padding_low),
2350+
schema_params->edge_padding_low(), "edge_padding_low"));
2351+
TF_LITE_ENSURE_STATUS(
2352+
LoadAttr(params->edge_padding_high, sizeof(params->edge_padding_high),
2353+
schema_params->edge_padding_high(), "edge_padding_high"));
2354+
TF_LITE_ENSURE_STATUS(
2355+
LoadAttr(params->interior_padding, sizeof(params->interior_padding),
2356+
schema_params->interior_padding(), "interior_padding"));
2357+
if (schema_params->edge_padding_low()->size() !=
2358+
schema_params->edge_padding_high()->size() ||
2359+
schema_params->edge_padding_low()->size() !=
2360+
schema_params->interior_padding()->size()) {
2361+
TF_LITE_REPORT_ERROR(error_reporter,
2362+
"'stablehlo.pad' operation parameter array sizes "
2363+
"are not consistent.");
2364+
return kTfLiteError;
2365+
}
2366+
*builtin_data = params.release();
2367+
return kTfLiteOk;
2368+
}
2369+
TF_LITE_REPORT_ERROR(error_reporter,
2370+
"Could not get 'stablehlo.pad' operation parameters.");
2371+
return kTfLiteError;
2372+
}
2373+
23192374
// We have this parse function instead of directly returning kTfLiteOk from the
23202375
// switch-case in ParseOpData because this function is used as part of the
23212376
// selective registration for the OpResolver implementation in micro.

tensorflow/lite/core/api/flatbuffer_conversions.h

+5
Original file line numberDiff line numberDiff line change
@@ -440,6 +440,11 @@ TfLiteStatus ParseStablehloReduceWindow(const Operator* op,
440440
BuiltinDataAllocator* allocator,
441441
void** builtin_data);
442442

443+
TfLiteStatus ParseStablehloPad(const Operator* op,
444+
ErrorReporter* error_reporter,
445+
BuiltinDataAllocator* allocator,
446+
void** builtin_data);
447+
443448
} // namespace tflite
444449

445450
#endif // TENSORFLOW_LITE_CORE_API_FLATBUFFER_CONVERSIONS_H_

tensorflow/lite/core/c/builtin_op_data.h

+9
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ extern "C" {
3535
#define TFLITE_STABLEHLO_SCATTER_PARAMS_MAX_DIMENSION_COUNT 8
3636
#define TFLITE_STABLEHLO_GATHER_PARAMS_MAX_DIMENSION_COUNT 8
3737
#define TFLITE_STABLEHLO_REDUCE_WINDOW_PARAMS_MAX_DIMENSION_COUNT 8
38+
#define TFLITE_STABLEHLO_PAD_PARAMS_MAX_DIMENSION_COUNT 8
3839

3940
// TODO(aselle): Consider using "if this then that" for testing.
4041

@@ -636,6 +637,14 @@ typedef struct {
636637
enum TfLiteReduceWindowFunction reduce_function;
637638
} TfLiteReduceWindowParams;
638639

640+
typedef struct {
641+
// See the stablehlo spec for the explanation of the attributes:
642+
// https://github.com/openxla/stablehlo/blob/main/docs/spec.md#pad
643+
int64_t edge_padding_low[TFLITE_STABLEHLO_PAD_PARAMS_MAX_DIMENSION_COUNT];
644+
int64_t edge_padding_high[TFLITE_STABLEHLO_PAD_PARAMS_MAX_DIMENSION_COUNT];
645+
int64_t interior_padding[TFLITE_STABLEHLO_PAD_PARAMS_MAX_DIMENSION_COUNT];
646+
} TfLiteStablehloPadParams;
647+
639648
#ifdef __cplusplus
640649
} // extern "C"
641650
#endif // __cplusplus

tensorflow/lite/schema/schema.fbs

+1-1
Original file line numberDiff line numberDiff line change
@@ -459,7 +459,7 @@ enum BuiltinOperator : int32 {
459459
STABLEHLO_CONVERT = 192, // WARNING: No runtime support
460460
STABLEHLO_DYNAMIC_SLICE = 193, // WARNING: No runtime support
461461
STABLEHLO_DYNAMIC_UPDATE_SLICE = 194, // WARNING: No runtime support
462-
STABLEHLO_PAD = 195, // WARNING: No runtime support
462+
STABLEHLO_PAD = 195,
463463
STABLEHLO_IOTA = 196, // WARNING: No runtime support
464464
STABLEHLO_DOT_GENERAL = 197, // WARNING: No runtime support
465465
STABLEHLO_REDUCE_WINDOW = 198,

0 commit comments

Comments
 (0)