@@ -15,6 +15,7 @@ limitations under the License.
1515
1616#include " tensorflow/lite/core/api/flatbuffer_conversions.h"
1717
18+ #include < algorithm>
1819#include < cstddef>
1920#include < cstdint>
2021#include < memory>
@@ -881,6 +882,10 @@ TfLiteStatus ParseOpDataTfLite(const Operator* op, BuiltinOperator op_type,
881882 case BuiltinOperator_STABLEHLO_GATHER: {
882883 return ParseStablehloGather (op, error_reporter, allocator, builtin_data);
883884 }
885+ case BuiltinOperator_STABLEHLO_REDUCE_WINDOW: {
886+ return ParseStablehloReduceWindow (op, error_reporter, allocator,
887+ builtin_data);
888+ }
884889 case BuiltinOperator_REDUCE_WINDOW: {
885890 auto params = safe_allocator.Allocate <TfLiteReduceWindowParams>();
886891 TF_LITE_ENSURE (error_reporter, params != nullptr );
@@ -949,7 +954,6 @@ TfLiteStatus ParseOpDataTfLite(const Operator* op, BuiltinOperator op_type,
949954 case BuiltinOperator_STABLEHLO_CONVERT:
950955 case BuiltinOperator_STABLEHLO_PAD:
951956 case BuiltinOperator_STABLEHLO_DOT_GENERAL:
952- case BuiltinOperator_STABLEHLO_REDUCE_WINDOW:
953957 case BuiltinOperator_STABLEHLO_SORT:
954958 case BuiltinOperator_STABLEHLO_WHILE:
955959 case BuiltinOperator_STABLEHLO_TRANSPOSE:
@@ -2096,6 +2100,98 @@ TfLiteStatus ParseResizeNearestNeighbor(const Operator* op,
20962100 return kTfLiteOk ;
20972101}
20982102
2103+ TfLiteStatus ParseStablehloReduceWindow (const Operator* op,
2104+ ErrorReporter* error_reporter,
2105+ BuiltinDataAllocator* allocator,
2106+ void ** builtin_data) {
2107+ CheckParsePointerParams (op, error_reporter, allocator, builtin_data);
2108+
2109+ SafeBuiltinDataAllocator safe_allocator (allocator);
2110+ auto params = safe_allocator.Allocate <TfLiteStablehloReduceWindowParams>();
2111+
2112+ const StablehloReduceWindowOptions* schema_params =
2113+ op->builtin_options_2_as_StablehloReduceWindowOptions ();
2114+ if (schema_params) {
2115+ if (!schema_params->window_dimensions () ||
2116+ schema_params->window_dimensions ()->size () == 0 ) {
2117+ TF_LITE_REPORT_ERROR (error_reporter,
2118+ " 'window_dimensions' attribute is not optional for "
2119+ " 'stablehlo.reduce_window' and cannot be empty." );
2120+ return kTfLiteError ;
2121+ }
2122+
2123+ const size_t rank = schema_params->window_dimensions ()->size ();
2124+
2125+ auto LoadAttr = [&error_reporter](
2126+ auto & params_array, auto * const flatbuffer_vector,
2127+ const char * attr_name, const size_t expected_size,
2128+ const int64_t fill_value) -> TfLiteStatus {
2129+ if (flatbuffer_vector && flatbuffer_vector->size ()) {
2130+ if (expected_size != 0 && flatbuffer_vector->size () != expected_size) {
2131+ TF_LITE_REPORT_ERROR (
2132+ error_reporter,
2133+ " '%s' attribute of 'stablehlo.reduce_window' does not have the "
2134+ " expected size (%llu != %llu)." ,
2135+ attr_name, flatbuffer_vector->size (), expected_size);
2136+ return kTfLiteError ;
2137+ }
2138+ TfLiteStatus status = FlatBufferIntVectorToArray (
2139+ sizeof (params_array), flatbuffer_vector, params_array,
2140+ error_reporter, " stablehlo.reduce_window" );
2141+ if (status != kTfLiteOk ) {
2142+ TF_LITE_REPORT_ERROR (error_reporter, " Check the '%s' attribute." ,
2143+ attr_name);
2144+ return status;
2145+ }
2146+ } else {
2147+ std::fill_n (params_array,
2148+ TFLITE_STABLEHLO_REDUCE_WINDOW_PARAMS_MAX_DIMENSION_COUNT,
2149+ fill_value);
2150+ }
2151+ return kTfLiteOk ;
2152+ };
2153+
2154+ if (TfLiteStatus status = LoadAttr (
2155+ params->window_dimensions , schema_params->window_dimensions (),
2156+ " window_dimensions" , /* expected_size=*/ rank, /* fill_value=*/ 1 );
2157+ status != kTfLiteOk ) {
2158+ return status;
2159+ }
2160+ if (TfLiteStatus status = LoadAttr (
2161+ params->window_strides , schema_params->window_strides (),
2162+ " window_strides" , /* expected_size=*/ rank, /* fill_value=*/ 1 );
2163+ status != kTfLiteOk ) {
2164+ return status;
2165+ }
2166+ if (TfLiteStatus status = LoadAttr (
2167+ params->base_dilations , schema_params->base_dilations (),
2168+ " base_dilations" , /* expected_size=*/ rank, /* fill_value=*/ 1 );
2169+ status != kTfLiteOk ) {
2170+ return status;
2171+ }
2172+ if (TfLiteStatus status = LoadAttr (
2173+ params->window_dilations , schema_params->window_dilations (),
2174+ " window_dilations" , /* expected_size=*/ rank, /* fill_value=*/ 1 );
2175+ status != kTfLiteOk ) {
2176+ return status;
2177+ }
2178+ if (TfLiteStatus status =
2179+ LoadAttr (params->padding , schema_params->padding (), " padding" ,
2180+ /* expected_size=*/ 2 * rank, /* fill_value=*/ 0 );
2181+ status != kTfLiteOk ) {
2182+ return status;
2183+ }
2184+
2185+ params->body_subgraph_index = schema_params->body_subgraph_index ();
2186+ *builtin_data = params.release ();
2187+ return kTfLiteOk ;
2188+ }
2189+ TF_LITE_REPORT_ERROR (
2190+ error_reporter,
2191+ " Could not get 'stablehlo.reduce_window' operation parameters." );
2192+ return kTfLiteError ;
2193+ }
2194+
20992195TfLiteStatus ParseStablehloScatter (const Operator* op,
21002196 ErrorReporter* error_reporter,
21012197 BuiltinDataAllocator* allocator,
0 commit comments