From 519fae019bae52f8f3151e32841132855277aa91 Mon Sep 17 00:00:00 2001 From: Wanming Lin Date: Wed, 8 Jan 2025 06:24:26 +0800 Subject: [PATCH] [WebNN] Fix bug in SkipSimplifiedLayerNormalization (#23236) The skip and bias (if it exists) should be added to the input first. --- .../builders/impl/normalization_op_builder.cc | 36 ++++++++++--------- 1 file changed, 19 insertions(+), 17 deletions(-) diff --git a/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc index d1c0f598b79f4..77f4bdce52a84 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc @@ -76,8 +76,6 @@ Status NormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder options.set("epsilon", epsilon); emscripten::val output = emscripten::val::undefined(); - // SkipSimplifiedLayerNormalization's output: input_skip_bias_sum. 
- emscripten::val input_skip_bias_sum = emscripten::val::undefined(); if (op_type == "BatchNormalization") { ORT_RETURN_IF_NOT(input_defs.size() == 5, "BatchNormalization requires five inputs."); emscripten::val mean = model_builder.GetOperand(input_defs[3]->Name()); @@ -107,7 +105,7 @@ Status NormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder | | | | | | Y:2 axis B:epsilon A:X A:scale B:bias - If it is SkipSimplifiedLayerNormalization and its output input_skip_bias_sum exists, + If it is SkipSimplifiedLayerNormalization, X should be input_skip_bias_sum: input_skip_bias_sum = X + skip + bias (if it exists) */ @@ -115,6 +113,23 @@ Status NormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder ORT_RETURN_IF_NOT(GetType(*input_defs[0], input_type, logger), "Cannot get input type"); emscripten::val common_options = emscripten::val::object(); + // If it is SkipSimplifiedLayerNormalization, add the skip and bias (if it exists) to the input. + if (op_type == "SkipSimplifiedLayerNormalization") { + emscripten::val skip = model_builder.GetOperand(input_defs[1]->Name()); + common_options.set("label", node.Name() + "_add_skip"); + input = model_builder.GetBuilder().call("add", input, skip, common_options); + if (!bias.isUndefined()) { + common_options.set("label", node.Name() + "_add_skip_bias"); + input = model_builder.GetBuilder().call("add", input, bias, common_options); + } + + // Add SkipSimplifiedLayerNormalization's output input_skip_bias_sum if it exists. + // Now input equals to input_skip_bias_sum. 
+ if (TensorExists(output_defs, 3)) { + model_builder.AddOperand(output_defs[3]->Name(), input); + } + } + // Pow emscripten::val pow_constant = model_builder.CreateOrGetConstant(input_type, 2); common_options.set("label", node.Name() + "_pow"); @@ -146,24 +161,11 @@ Status NormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder common_options.set("label", node.Name() + "_mul"); output = model_builder.GetBuilder().call("mul", scale, div, common_options); - // Add (if bias exits) + // Add (if bias exists) if (!bias.isUndefined()) { common_options.set("label", node.Name() + "_add_bias"); output = model_builder.GetBuilder().call("add", output, bias, common_options); } - - // SkipSimplifiedLayerNormalization's output input_skip_bias_sum is the sum of input, skip, and bias. - if (op_type == "SkipSimplifiedLayerNormalization" && TensorExists(output_defs, 3)) { - emscripten::val skip = model_builder.GetOperand(input_defs[1]->Name()); - common_options.set("label", node.Name() + "_add_skip"); - input_skip_bias_sum = model_builder.GetBuilder().call("add", input, skip, common_options); - if (!bias.isUndefined()) { - common_options.set("label", node.Name() + "_add_skip_bias"); - input_skip_bias_sum = model_builder.GetBuilder().call( - "add", input_skip_bias_sum, bias, common_options); - } - model_builder.AddOperand(output_defs[3]->Name(), std::move(input_skip_bias_sum)); - } } } else if (op_type == "InstanceNormalization") { // WebNN spec only supports 4D input for instanceNormalization.