Skip to content

Commit 70022c5

Browse files
committed
Sync from upstream TF.
1 parent e38f488 commit 70022c5

File tree

6 files changed

+134
-8
lines changed

6 files changed

+134
-8
lines changed

tensorflow/lite/core/api/flatbuffer_conversions.cc

Lines changed: 97 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ limitations under the License.
1515

1616
#include "tensorflow/lite/core/api/flatbuffer_conversions.h"
1717

18+
#include <algorithm>
1819
#include <cstddef>
1920
#include <cstdint>
2021
#include <memory>
@@ -881,6 +882,10 @@ TfLiteStatus ParseOpDataTfLite(const Operator* op, BuiltinOperator op_type,
881882
case BuiltinOperator_STABLEHLO_GATHER: {
882883
return ParseStablehloGather(op, error_reporter, allocator, builtin_data);
883884
}
885+
case BuiltinOperator_STABLEHLO_REDUCE_WINDOW: {
886+
return ParseStablehloReduceWindow(op, error_reporter, allocator,
887+
builtin_data);
888+
}
884889
case BuiltinOperator_REDUCE_WINDOW: {
885890
auto params = safe_allocator.Allocate<TfLiteReduceWindowParams>();
886891
TF_LITE_ENSURE(error_reporter, params != nullptr);
@@ -949,7 +954,6 @@ TfLiteStatus ParseOpDataTfLite(const Operator* op, BuiltinOperator op_type,
949954
case BuiltinOperator_STABLEHLO_CONVERT:
950955
case BuiltinOperator_STABLEHLO_PAD:
951956
case BuiltinOperator_STABLEHLO_DOT_GENERAL:
952-
case BuiltinOperator_STABLEHLO_REDUCE_WINDOW:
953957
case BuiltinOperator_STABLEHLO_SORT:
954958
case BuiltinOperator_STABLEHLO_WHILE:
955959
case BuiltinOperator_STABLEHLO_TRANSPOSE:
@@ -2096,6 +2100,98 @@ TfLiteStatus ParseResizeNearestNeighbor(const Operator* op,
20962100
return kTfLiteOk;
20972101
}
20982102

2103+
TfLiteStatus ParseStablehloReduceWindow(const Operator* op,
2104+
ErrorReporter* error_reporter,
2105+
BuiltinDataAllocator* allocator,
2106+
void** builtin_data) {
2107+
CheckParsePointerParams(op, error_reporter, allocator, builtin_data);
2108+
2109+
SafeBuiltinDataAllocator safe_allocator(allocator);
2110+
auto params = safe_allocator.Allocate<TfLiteStablehloReduceWindowParams>();
2111+
2112+
const StablehloReduceWindowOptions* schema_params =
2113+
op->builtin_options_2_as_StablehloReduceWindowOptions();
2114+
if (schema_params) {
2115+
if (!schema_params->window_dimensions() ||
2116+
schema_params->window_dimensions()->size() == 0) {
2117+
TF_LITE_REPORT_ERROR(error_reporter,
2118+
"'window_dimensions' attribute is not optional for "
2119+
"'stablehlo.reduce_window' and cannot be empty.");
2120+
return kTfLiteError;
2121+
}
2122+
2123+
const size_t rank = schema_params->window_dimensions()->size();
2124+
2125+
auto LoadAttr = [&error_reporter](
2126+
auto& params_array, auto* const flatbuffer_vector,
2127+
const char* attr_name, const size_t expected_size,
2128+
const int64_t fill_value) -> TfLiteStatus {
2129+
if (flatbuffer_vector && flatbuffer_vector->size()) {
2130+
if (expected_size != 0 && flatbuffer_vector->size() != expected_size) {
2131+
TF_LITE_REPORT_ERROR(
2132+
error_reporter,
2133+
"'%s' attribute of 'stablehlo.reduce_window' does not have the "
2134+
"expected size (%llu != %llu).",
2135+
attr_name, flatbuffer_vector->size(), expected_size);
2136+
return kTfLiteError;
2137+
}
2138+
TfLiteStatus status = FlatBufferIntVectorToArray(
2139+
sizeof(params_array), flatbuffer_vector, params_array,
2140+
error_reporter, "stablehlo.reduce_window");
2141+
if (status != kTfLiteOk) {
2142+
TF_LITE_REPORT_ERROR(error_reporter, "Check the '%s' attribute.",
2143+
attr_name);
2144+
return status;
2145+
}
2146+
} else {
2147+
std::fill_n(params_array,
2148+
TFLITE_STABLEHLO_REDUCE_WINDOW_PARAMS_MAX_DIMENSION_COUNT,
2149+
fill_value);
2150+
}
2151+
return kTfLiteOk;
2152+
};
2153+
2154+
if (TfLiteStatus status = LoadAttr(
2155+
params->window_dimensions, schema_params->window_dimensions(),
2156+
"window_dimensions", /*expected_size=*/rank, /*fill_value=*/1);
2157+
status != kTfLiteOk) {
2158+
return status;
2159+
}
2160+
if (TfLiteStatus status = LoadAttr(
2161+
params->window_strides, schema_params->window_strides(),
2162+
"window_strides", /*expected_size=*/rank, /*fill_value=*/1);
2163+
status != kTfLiteOk) {
2164+
return status;
2165+
}
2166+
if (TfLiteStatus status = LoadAttr(
2167+
params->base_dilations, schema_params->base_dilations(),
2168+
"base_dilations", /*expected_size=*/rank, /*fill_value=*/1);
2169+
status != kTfLiteOk) {
2170+
return status;
2171+
}
2172+
if (TfLiteStatus status = LoadAttr(
2173+
params->window_dilations, schema_params->window_dilations(),
2174+
"window_dilations", /*expected_size=*/rank, /*fill_value=*/1);
2175+
status != kTfLiteOk) {
2176+
return status;
2177+
}
2178+
if (TfLiteStatus status =
2179+
LoadAttr(params->padding, schema_params->padding(), "padding",
2180+
/*expected_size=*/2 * rank, /*fill_value=*/0);
2181+
status != kTfLiteOk) {
2182+
return status;
2183+
}
2184+
2185+
params->body_subgraph_index = schema_params->body_subgraph_index();
2186+
*builtin_data = params.release();
2187+
return kTfLiteOk;
2188+
}
2189+
TF_LITE_REPORT_ERROR(
2190+
error_reporter,
2191+
"Could not get 'stablehlo.reduce_window' operation parameters.");
2192+
return kTfLiteError;
2193+
}
2194+
20992195
TfLiteStatus ParseStablehloScatter(const Operator* op,
21002196
ErrorReporter* error_reporter,
21012197
BuiltinDataAllocator* allocator,

tensorflow/lite/core/api/flatbuffer_conversions.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -435,6 +435,11 @@ TfLiteStatus ParseStablehloGather(const Operator* op,
435435
BuiltinDataAllocator* allocator,
436436
void** builtin_data);
437437

438+
TfLiteStatus ParseStablehloReduceWindow(const Operator* op,
439+
ErrorReporter* error_reporter,
440+
BuiltinDataAllocator* allocator,
441+
void** builtin_data);
442+
438443
} // namespace tflite
439444

440445
#endif // TENSORFLOW_LITE_CORE_API_FLATBUFFER_CONVERSIONS_H_

tensorflow/lite/core/c/builtin_op_data.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ extern "C" {
3434
#define TFLITE_RESHAPE_PARAMS_MAX_DIMENSION_COUNT 8
3535
#define TFLITE_STABLEHLO_SCATTER_PARAMS_MAX_DIMENSION_COUNT 8
3636
#define TFLITE_STABLEHLO_GATHER_PARAMS_MAX_DIMENSION_COUNT 8
37+
#define TFLITE_STABLEHLO_REDUCE_WINDOW_PARAMS_MAX_DIMENSION_COUNT 8
3738

3839
// TODO(aselle): Consider using "if this then that" for testing.
3940

@@ -605,6 +606,22 @@ typedef struct {
605606
bool indices_are_sorted;
606607
} TfLiteStablehloGatherParams;
607608

609+
typedef struct {
610+
// See the stablehlo spec for the explanation of the attributes:
611+
// https://github.com/openxla/stablehlo/blob/main/docs/spec.md#reduce_window
612+
int64_t window_dimensions
613+
[TFLITE_STABLEHLO_REDUCE_WINDOW_PARAMS_MAX_DIMENSION_COUNT];
614+
int64_t
615+
window_strides[TFLITE_STABLEHLO_REDUCE_WINDOW_PARAMS_MAX_DIMENSION_COUNT];
616+
int64_t
617+
base_dilations[TFLITE_STABLEHLO_REDUCE_WINDOW_PARAMS_MAX_DIMENSION_COUNT];
618+
int64_t window_dilations
619+
[TFLITE_STABLEHLO_REDUCE_WINDOW_PARAMS_MAX_DIMENSION_COUNT];
620+
int64_t
621+
padding[2 * TFLITE_STABLEHLO_REDUCE_WINDOW_PARAMS_MAX_DIMENSION_COUNT];
622+
int body_subgraph_index;
623+
} TfLiteStablehloReduceWindowParams;
624+
608625
enum TfLiteReduceWindowFunction {
609626
TfLiteReduceWindowFunctionUnsupported,
610627
TfLiteReduceWindowFunctionAdd,

tensorflow/lite/core/c/c_api_types.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,13 @@ limitations under the License.
3434
extern "C" {
3535
#endif
3636

37-
/** \addtogroup c_api_types tensorflow/lite/c/c_api_types.h
37+
// clang-format off
38+
// NOLINTBEGIN(whitespace/line_length)
39+
/** \defgroup c_api_types tensorflow/lite/c/c_api_types.h
3840
* @{
3941
*/
42+
// NOLINTEND(whitespace/line_length)
43+
// clang-format on
4044

4145
// Define TFL_CAPI_EXPORT macro to export a function properly with a shared
4246
// library.

tensorflow/lite/core/c/common.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,9 +54,13 @@ limitations under the License.
5454
extern "C" {
5555
#endif // __cplusplus
5656

57-
/** \addtogroup common tensorflow/lite/c/common.h
57+
// clang-format off
58+
// NOLINTBEGIN(whitespace/line_length)
59+
/** \defgroup common tensorflow/lite/c/common.h
5860
* @{
5961
*/
62+
// NOLINTEND(whitespace/line_length)
63+
// clang-format on
6064

6165
/// The list of external context types known to TF Lite. This list exists solely
6266
/// to avoid conflicts and to ensure ops can share the external contexts they

tensorflow/lite/schema/schema.fbs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -429,7 +429,7 @@ enum BuiltinOperator : int32 {
429429
STABLEHLO_LOGISTIC = 162, // WARNING: Do not have runtime support
430430
STABLEHLO_ADD = 163,
431431
STABLEHLO_DIVIDE = 164, // WARNING: No runtime support yet
432-
STABLEHLO_MULTIPLY = 165, // WARNING: No runtime support yet
432+
STABLEHLO_MULTIPLY = 165,
433433
STABLEHLO_MAXIMUM = 166, // WARNING: No runtime support yet
434434
STABLEHLO_RESHAPE = 167, // WARNING: No runtime support yet
435435
STABLEHLO_CLAMP = 168, // WARNING: No runtime support
@@ -462,14 +462,14 @@ enum BuiltinOperator : int32 {
462462
STABLEHLO_PAD = 195, // WARNING: No runtime support
463463
STABLEHLO_IOTA = 196, // WARNING: No runtime support
464464
STABLEHLO_DOT_GENERAL = 197, // WARNING: No runtime support
465-
STABLEHLO_REDUCE_WINDOW = 198, // WARNING: No runtime support
465+
STABLEHLO_REDUCE_WINDOW = 198,
466466
STABLEHLO_SORT = 199, // WARNING: No runtime support
467467
STABLEHLO_WHILE = 200, // WARNING: No runtime support
468468
STABLEHLO_GATHER = 201,
469469
STABLEHLO_TRANSPOSE = 202, // WARNING: No runtime support
470470
DILATE = 203,
471471
STABLEHLO_RNG_BIT_GENERATOR = 204,
472-
REDUCE_WINDOW = 205,
472+
REDUCE_WINDOW = 205 (deprecated),
473473
}
474474
// LINT.ThenChange(nnapi_linter/linter.proto)
475475

@@ -626,7 +626,7 @@ union BuiltinOptions2{
626626
StablehloTransposeOptions,
627627
DilateOptions,
628628
StablehloRngBitGeneratorOptions,
629-
ReduceWindowOptions,
629+
ReduceWindowOptions (deprecated),
630630
}
631631

632632
table StablehloGatherOptions{
@@ -1458,7 +1458,7 @@ enum ReduceWindowFunction : int {
14581458
ANY,
14591459
}
14601460

1461-
table ReduceWindowOptions{
1461+
table ReduceWindowOptions (deprecated) {
14621462
reduce_function: ReduceWindowFunction;
14631463
}
14641464

0 commit comments

Comments
 (0)