@@ -15,6 +15,7 @@ limitations under the License.
15
15
16
16
#include " tensorflow/lite/core/api/flatbuffer_conversions.h"
17
17
18
+ #include < algorithm>
18
19
#include < cstddef>
19
20
#include < cstdint>
20
21
#include < memory>
@@ -881,6 +882,10 @@ TfLiteStatus ParseOpDataTfLite(const Operator* op, BuiltinOperator op_type,
881
882
case BuiltinOperator_STABLEHLO_GATHER: {
882
883
return ParseStablehloGather (op, error_reporter, allocator, builtin_data);
883
884
}
885
+ case BuiltinOperator_STABLEHLO_REDUCE_WINDOW: {
886
+ return ParseStablehloReduceWindow (op, error_reporter, allocator,
887
+ builtin_data);
888
+ }
884
889
case BuiltinOperator_REDUCE_WINDOW: {
885
890
auto params = safe_allocator.Allocate <TfLiteReduceWindowParams>();
886
891
TF_LITE_ENSURE (error_reporter, params != nullptr );
@@ -949,7 +954,6 @@ TfLiteStatus ParseOpDataTfLite(const Operator* op, BuiltinOperator op_type,
949
954
case BuiltinOperator_STABLEHLO_CONVERT:
950
955
case BuiltinOperator_STABLEHLO_PAD:
951
956
case BuiltinOperator_STABLEHLO_DOT_GENERAL:
952
- case BuiltinOperator_STABLEHLO_REDUCE_WINDOW:
953
957
case BuiltinOperator_STABLEHLO_SORT:
954
958
case BuiltinOperator_STABLEHLO_WHILE:
955
959
case BuiltinOperator_STABLEHLO_TRANSPOSE:
@@ -2096,6 +2100,88 @@ TfLiteStatus ParseResizeNearestNeighbor(const Operator* op,
2096
2100
return kTfLiteOk ;
2097
2101
}
2098
2102
2103
+ TfLiteStatus ParseStablehloReduceWindow (const Operator* op,
2104
+ ErrorReporter* error_reporter,
2105
+ BuiltinDataAllocator* allocator,
2106
+ void ** builtin_data) {
2107
+ CheckParsePointerParams (op, error_reporter, allocator, builtin_data);
2108
+
2109
+ SafeBuiltinDataAllocator safe_allocator (allocator);
2110
+ auto params = safe_allocator.Allocate <TfLiteStablehloReduceWindowParams>();
2111
+
2112
+ const StablehloReduceWindowOptions* schema_params =
2113
+ op->builtin_options_2_as_StablehloReduceWindowOptions ();
2114
+ if (schema_params) {
2115
+ if (!schema_params->window_dimensions () ||
2116
+ schema_params->window_dimensions ()->size () == 0 ) {
2117
+ TF_LITE_REPORT_ERROR (error_reporter,
2118
+ " 'window_dimensions' attribute is not optional for "
2119
+ " 'stablehlo.reduce_window' and cannot be empty." );
2120
+ return kTfLiteError ;
2121
+ }
2122
+
2123
+ const size_t rank = schema_params->window_dimensions ()->size ();
2124
+
2125
+ auto LoadAttr = [&error_reporter](
2126
+ int64_t * params_array, size_t params_array_size_bytes,
2127
+ const flatbuffers::Vector<int64_t >* flatbuffer_vector,
2128
+ const char * attr_name, const size_t expected_size,
2129
+ const int64_t fill_value) -> TfLiteStatus {
2130
+ if (flatbuffer_vector && flatbuffer_vector->size ()) {
2131
+ if (expected_size != 0 && flatbuffer_vector->size () != expected_size) {
2132
+ TF_LITE_REPORT_ERROR (
2133
+ error_reporter,
2134
+ " '%s' attribute of 'stablehlo.reduce_window' does not have the "
2135
+ " expected size (%llu != %llu)." ,
2136
+ attr_name, flatbuffer_vector->size (), expected_size);
2137
+ return kTfLiteError ;
2138
+ }
2139
+ TfLiteStatus status = FlatBufferIntVectorToArray (
2140
+ params_array_size_bytes, flatbuffer_vector, params_array,
2141
+ error_reporter, " stablehlo.reduce_window" );
2142
+ if (status != kTfLiteOk ) {
2143
+ TF_LITE_REPORT_ERROR (error_reporter, " Check the '%s' attribute." ,
2144
+ attr_name);
2145
+ return status;
2146
+ }
2147
+ } else {
2148
+ std::fill_n (params_array, params_array_size_bytes / sizeof (int64_t ),
2149
+ fill_value);
2150
+ }
2151
+ return kTfLiteOk ;
2152
+ };
2153
+
2154
+ TF_LITE_ENSURE_STATUS (
2155
+ LoadAttr (params->window_dimensions , sizeof (params->window_dimensions ),
2156
+ schema_params->window_dimensions (), " window_dimensions" ,
2157
+ /* expected_size=*/ rank, /* fill_value=*/ 1 ));
2158
+ TF_LITE_ENSURE_STATUS (
2159
+ LoadAttr (params->window_strides , sizeof (params->window_strides ),
2160
+ schema_params->window_strides (), " window_strides" ,
2161
+ /* expected_size=*/ rank, /* fill_value=*/ 1 ));
2162
+ TF_LITE_ENSURE_STATUS (
2163
+ LoadAttr (params->base_dilations , sizeof (params->base_dilations ),
2164
+ schema_params->base_dilations (), " base_dilations" ,
2165
+ /* expected_size=*/ rank, /* fill_value=*/ 1 ));
2166
+ TF_LITE_ENSURE_STATUS (
2167
+ LoadAttr (params->window_dilations , sizeof (params->window_dilations ),
2168
+ schema_params->window_dilations (), " window_dilations" ,
2169
+ /* expected_size=*/ rank, /* fill_value=*/ 1 ));
2170
+ TF_LITE_ENSURE_STATUS (LoadAttr (params->padding , sizeof (params->padding ),
2171
+ schema_params->padding (), " padding" ,
2172
+ /* expected_size=*/ 2 * rank,
2173
+ /* fill_value=*/ 0 ));
2174
+
2175
+ params->body_subgraph_index = schema_params->body_subgraph_index ();
2176
+ *builtin_data = params.release ();
2177
+ return kTfLiteOk ;
2178
+ }
2179
+ TF_LITE_REPORT_ERROR (
2180
+ error_reporter,
2181
+ " Could not get 'stablehlo.reduce_window' operation parameters." );
2182
+ return kTfLiteError ;
2183
+ }
2184
+
2099
2185
TfLiteStatus ParseStablehloScatter (const Operator* op,
2100
2186
ErrorReporter* error_reporter,
2101
2187
BuiltinDataAllocator* allocator,
0 commit comments