Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main' into snnn/update_docker_im…
Browse files Browse the repository at this point in the history
…ages
  • Loading branch information
snnn committed Jan 7, 2025
2 parents 83b2dbe + 4b0cee3 commit 9e190cd
Show file tree
Hide file tree
Showing 10 changed files with 258 additions and 31 deletions.
29 changes: 27 additions & 2 deletions onnxruntime/core/optimizer/pad_fusion.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ namespace onnxruntime {

bool VerifyNotCastChild(const Node& child_node) {
if (!graph_utils::IsSupportedOptypeVersionAndDomain(child_node, "Conv", {1, 11}) &&
!graph_utils::IsSupportedOptypeVersionAndDomain(child_node, "AveragePool", {1, 7, 10, 11, 19}) &&
!graph_utils::IsSupportedOptypeVersionAndDomain(child_node, "AveragePool", {7, 10, 11, 19}) &&
!graph_utils::IsSupportedOptypeVersionAndDomain(child_node, "MaxPool", {1, 8, 10, 11, 12})) {
return false;
}
Expand All @@ -31,11 +31,32 @@ bool VerifyNotCastChild(const Node& child_node) {
return false;
}

if (child_node.OpType() == "AveragePool") {
// in case there's already padding and count_include_pad is 0, fusion can't be performed
auto has_pad = false;
if (child_node.GetAttributes().find("pads") != child_node.GetAttributes().end()) {
auto const& pads_values = child_node.GetAttributes().at("pads").ints();
if (!pads_values.empty()) {
has_pad = std::any_of(pads_values.begin(), pads_values.end(), [](int64_t value) { return value != 0; });
}
}
if (has_pad && child_node.GetAttributes().find("count_include_pad") != child_node.GetAttributes().end()) {
if (child_node.GetAttributes().at("count_include_pad").i() == 0) {
return false;
}
}
}

return true;
}

void UpdatePaddingAttribute(Node& child_node, const std::vector<int64_t>& pads_values, const uint32_t pads_size) {
if (child_node.GetAttributes().find("pads") == child_node.GetAttributes().end()) {
auto reset_pads = true;
if (child_node.GetAttributes().find("pads") != child_node.GetAttributes().end()) {
/* pads can be empty, overwrite pads attribute in this case */
reset_pads = child_node.GetAttributes().at("pads").ints().empty();
}
if (reset_pads) {
std::vector<int64_t> pads(pads_size - 4, 0);
child_node.AddAttribute("pads", pads);
}
Expand All @@ -49,6 +70,10 @@ void UpdatePaddingAttribute(Node& child_node, const std::vector<int64_t>& pads_v
uint32_t mirrored_pad_index = pads_index + (pads_size / 2);
child_pads->Set(mirrored_child_index, child_pads->Get(mirrored_child_index) + pads_values[mirrored_pad_index]);
}

if (child_node.OpType() == "AveragePool") {
child_node.AddAttribute("count_include_pad", static_cast<int64_t>(1));
}
}
/*
* Before:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -193,10 +193,7 @@ Status BinaryElementwise::ComputeInternal(ComputeContext& context) const {
.AddIndices(reshaped_output_shape)
.AddIndices(reshaped_lhs_shape)
.AddIndices(reshaped_rhs_shape)
.CacheHint("V" + absl::StrJoin({reshaped_lhs_shape.NumDimensions(),
reshaped_rhs_shape.NumDimensions(),
reshaped_output_shape.NumDimensions()},
";"));
.CacheHint("V");
} else {
// Mode Broadcast
// cache hint: "B"
Expand Down
23 changes: 18 additions & 5 deletions onnxruntime/core/providers/webgpu/program_cache_key.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ namespace webgpu {

namespace {
// append the info of an input or output to the cachekey
void AppendTensorInfo(std::ostream& ss, const Tensor& tensor, ProgramVariableDataType var_type, ProgramTensorMetadataDependency dependency,
void AppendTensorInfo(std::ostream& ss, const TensorShape& tensor_shape, ProgramVariableDataType var_type, ProgramTensorMetadataDependency dependency,
bool& first) {
if (first) {
first = false;
Expand All @@ -35,9 +35,9 @@ void AppendTensorInfo(std::ostream& ss, const Tensor& tensor, ProgramVariableDat
}

if ((dependency & ProgramTensorMetadataDependency::Shape) == ProgramTensorMetadataDependency::Shape) {
ss D("Dims=") << tensor.Shape().ToString();
ss D("Dims=") << tensor_shape.ToString();
} else if ((dependency & ProgramTensorMetadataDependency::Rank) == ProgramTensorMetadataDependency::Rank) {
ss D("Rank=") << tensor.Shape().NumDimensions();
ss D("Rank=") << tensor_shape.NumDimensions();
}
}
} // namespace
Expand Down Expand Up @@ -97,13 +97,26 @@ std::string CalculateProgramCacheKey(const ProgramBase& program, bool is_1d_disp
ss << ":" D("Inputs=");
first = true;
for (const auto& input : program.Inputs()) {
AppendTensorInfo(ss, *input.tensor, input.var_type, input.dependency, first);
AppendTensorInfo(ss, input.use_override_shape ? input.override_shape : input.tensor->Shape(), input.var_type, input.dependency, first);
}

ss << ":" D("Outputs=");
first = true;
for (const auto& output : program.Outputs()) {
AppendTensorInfo(ss, *output.tensor, output.var_type, output.dependency, first);
AppendTensorInfo(ss, output.use_override_shape ? output.override_shape : output.tensor->Shape(), output.var_type, output.dependency, first);
}

if (!program.Indices().empty()) {
ss << ":" D("Indices=");
first = true;
for (const auto& indices_shape : program.Indices()) {
if (first) {
first = false;
} else {
ss << '|';
}
ss D("Rank=") << indices_shape.NumDimensions();
}
}

return SS_GET(ss);
Expand Down
6 changes: 3 additions & 3 deletions onnxruntime/core/providers/webgpu/tensor/where.cc
Original file line number Diff line number Diff line change
Expand Up @@ -134,9 +134,9 @@ Status Where::ComputeInternal(ComputeContext& context) const {
program
.CacheHint(is_broadcast)
.SetDispatchGroupSize((vec_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE)
.AddInputs({{cond_tensor, ProgramTensorMetadataDependency::Rank, {(cond_shape.Size() + 3) / 4}, 4},
{x_tensor, ProgramTensorMetadataDependency::Rank, {(x_shape.Size() + 3) / 4}, 4},
{y_tensor, ProgramTensorMetadataDependency::Rank, {(y_shape.Size() + 3) / 4}, 4}})
.AddInputs({{cond_tensor, ProgramTensorMetadataDependency::Type, {(cond_shape.Size() + 3) / 4}, 4},
{x_tensor, ProgramTensorMetadataDependency::Type, {(x_shape.Size() + 3) / 4}, 4},
{y_tensor, ProgramTensorMetadataDependency::Type, {(y_shape.Size() + 3) / 4}, 4}})
.AddOutput({output_tensor, ProgramTensorMetadataDependency::Type, {vec_size}, 4})
.AddUniformVariables({
{static_cast<uint32_t>(vec_size)},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,6 @@ Status NormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder
options.set("epsilon", epsilon);

emscripten::val output = emscripten::val::undefined();
// SkipSimplifiedLayerNormalization's output: input_skip_bias_sum.
emscripten::val input_skip_bias_sum = emscripten::val::undefined();
if (op_type == "BatchNormalization") {
ORT_RETURN_IF_NOT(input_defs.size() == 5, "BatchNormalization requires five inputs.");
emscripten::val mean = model_builder.GetOperand(input_defs[3]->Name());
Expand Down Expand Up @@ -107,14 +105,31 @@ Status NormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder
| | | | | |
Y:2 axis B:epsilon A:X A:scale B:bias
If it is SkipSimplifiedLayerNormalization and its output input_skip_bias_sum exists,
If it is SkipSimplifiedLayerNormalization, X should be input_skip_bias_sum:
input_skip_bias_sum = X + skip + bias (if it exists)
*/

int32_t input_type;
ORT_RETURN_IF_NOT(GetType(*input_defs[0], input_type, logger), "Cannot get input type");
emscripten::val common_options = emscripten::val::object();

// If it is SkipSimplifiedLayerNormalization, add the skip and bias (if it exists) to the input.
if (op_type == "SkipSimplifiedLayerNormalization") {
emscripten::val skip = model_builder.GetOperand(input_defs[1]->Name());
common_options.set("label", node.Name() + "_add_skip");
input = model_builder.GetBuilder().call<emscripten::val>("add", input, skip, common_options);
if (!bias.isUndefined()) {
common_options.set("label", node.Name() + "_add_skip_bias");
input = model_builder.GetBuilder().call<emscripten::val>("add", input, bias, common_options);
}

// Add SkipSimplifiedLayerNormalization's output input_skip_bias_sum if it exists.
// Now input equals to input_skip_bias_sum.
if (TensorExists(output_defs, 3)) {
model_builder.AddOperand(output_defs[3]->Name(), input);
}
}

// Pow
emscripten::val pow_constant = model_builder.CreateOrGetConstant<float>(input_type, 2);
common_options.set("label", node.Name() + "_pow");
Expand Down Expand Up @@ -146,24 +161,11 @@ Status NormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder
common_options.set("label", node.Name() + "_mul");
output = model_builder.GetBuilder().call<emscripten::val>("mul", scale, div, common_options);

// Add (if bias exits)
// Add (if bias exists)
if (!bias.isUndefined()) {
common_options.set("label", node.Name() + "_add_bias");
output = model_builder.GetBuilder().call<emscripten::val>("add", output, bias, common_options);
}

// SkipSimplifiedLayerNormalization's output input_skip_bias_sum is the sum of input, skip, and bias.
if (op_type == "SkipSimplifiedLayerNormalization" && TensorExists(output_defs, 3)) {
emscripten::val skip = model_builder.GetOperand(input_defs[1]->Name());
common_options.set("label", node.Name() + "_add_skip");
input_skip_bias_sum = model_builder.GetBuilder().call<emscripten::val>("add", input, skip, common_options);
if (!bias.isUndefined()) {
common_options.set("label", node.Name() + "_add_skip_bias");
input_skip_bias_sum = model_builder.GetBuilder().call<emscripten::val>(
"add", input_skip_bias_sum, bias, common_options);
}
model_builder.AddOperand(output_defs[3]->Name(), std::move(input_skip_bias_sum));
}
}
} else if (op_type == "InstanceNormalization") {
// WebNN spec only supports 4D input for instanceNormalization.
Expand Down
122 changes: 122 additions & 0 deletions onnxruntime/test/optimizer/graph_transform_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1622,6 +1622,128 @@ TEST_F(GraphTransformationTests, FusePadWithMaxPoolOpsetLessThan11) {
}
}

TEST_F(GraphTransformationTests, FusePadWithAvgPool) {
constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/fuse-pad-avgpool.onnx";

std::shared_ptr<Model> p_model;
ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
Graph& graph = p_model->MainGraph();

std::vector<int64_t> expected_pads;
GraphViewer graphViewer(graph);
for (auto& node_index : graphViewer.GetNodesInTopologicalOrder()) {
auto& node = *graph.GetNode(node_index);
if (node.OpType() == "Pad") {
auto const& pads_proto = node.GetAttributes().at("pads").ints();
gsl::span<const int64_t> pads_values = gsl::make_span(pads_proto.data(), pads_proto.size());
expected_pads.resize(pads_values.size() - 4);
for (uint32_t pads_index = 2, index = 0; pads_index < pads_values.size() / 2; pads_index++, index++) {
expected_pads[index] = pads_values[pads_index];
expected_pads[index + (expected_pads.size() / 2)] = pads_values[pads_index + (pads_values.size() / 2)];
}
}
}

onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
auto rule_transformer_L1 = std::make_unique<RuleBasedGraphTransformer>("RuleTransformerL1");
ASSERT_STATUS_OK(rule_transformer_L1->Register(std::make_unique<PadFusion>()));
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1));

ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));

std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
ASSERT_EQ(op_to_count["Pad"], 0);
ASSERT_EQ(op_to_count["AveragePool"], 1);

for (auto& node : graph.Nodes()) {
if (node.OpType() == "AveragePool") {
auto const& child_pads = node.GetAttributes().at("pads").ints();
auto const& count_include_pad = node.GetAttributes().at("count_include_pad");
ASSERT_NE(count_include_pad.i(), 0) << "fusion should ensure count_include_pad!=0";
ASSERT_EQ(child_pads.size(), static_cast<int32_t>(expected_pads.size()))
<< "fusion should produce the same size of pads integer as the AvgPool node";
for (uint32_t index = 0; index < expected_pads.size(); index++) {
ASSERT_EQ(expected_pads[index], child_pads.Get(index))
<< "fusion does not produce correct padding value";
}
}
}
}

TEST_F(GraphTransformationTests, FusePadWithAvgPoolWithPad) {
constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/fuse-pad-avgpool_with_pad.onnx";

std::shared_ptr<Model> p_model;
ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
Graph& graph = p_model->MainGraph();

std::vector<int64_t> expected_pads;
GraphViewer graphViewer(graph);
for (auto& node_index : graphViewer.GetNodesInTopologicalOrder()) {
auto& node = *graph.GetNode(node_index);
if (node.OpType() == "Pad") {
auto const& pads_proto = node.GetAttributes().at("pads").ints();
gsl::span<const int64_t> pads_values = gsl::make_span(pads_proto.data(), pads_proto.size());
expected_pads.resize(pads_values.size() - 4);

for (uint32_t pads_index = 2, index = 0; pads_index < pads_values.size() / 2; pads_index++, index++) {
expected_pads[index] = pads_values[pads_index];
expected_pads[index + (expected_pads.size() / 2)] = pads_values[pads_index + (pads_values.size() / 2)];
}
} else if (node.OpType() == "AveragePool") {
auto const& child_pads = node.GetAttributes().at("pads").ints();
for (uint32_t index = 0; index < expected_pads.size(); index++) {
expected_pads[index] += child_pads.Get(index);
}
}
}

onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
auto rule_transformer_L1 = std::make_unique<RuleBasedGraphTransformer>("RuleTransformerL1");
ASSERT_STATUS_OK(rule_transformer_L1->Register(std::make_unique<PadFusion>()));
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1));

ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));

std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
ASSERT_EQ(op_to_count["Pad"], 0);
ASSERT_EQ(op_to_count["AveragePool"], 1);

for (auto& node : graph.Nodes()) {
if (node.OpType() == "AveragePool") {
auto const& child_pads = node.GetAttributes().at("pads").ints();
auto const& count_include_pad = node.GetAttributes().at("count_include_pad");
ASSERT_NE(count_include_pad.i(), 0) << "fusion should ensure count_include_pad!=0";
ASSERT_EQ(child_pads.size(), static_cast<int32_t>(expected_pads.size()))
<< "fusion should produce the same size of pads integer as the AvgPool node";
for (uint32_t index = 0; index < expected_pads.size(); index++) {
ASSERT_EQ(expected_pads[index], child_pads.Get(index))
<< "fusion does not produce correct padding value";
}
}
}
}

// should not fuse
TEST_F(GraphTransformationTests, FusePadWithAvgPoolWithPadNoInclude) {
constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/fuse-pad-avgpool_with_pad-nofuse.onnx";

std::shared_ptr<Model> p_model;
ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
Graph& graph = p_model->MainGraph();

onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
auto rule_transformer_L1 = std::make_unique<RuleBasedGraphTransformer>("RuleTransformerL1");
ASSERT_STATUS_OK(rule_transformer_L1->Register(std::make_unique<PadFusion>()));
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1));

ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));

std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
ASSERT_EQ(op_to_count["Pad"], 1);
ASSERT_EQ(op_to_count["AveragePool"], 1);
}

TEST_F(GraphTransformationTests, FuseMatmulBNWithInBetweenNodes) {
constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/fuse-matmul-bn-with-reshape.onnx";

Expand Down
68 changes: 68 additions & 0 deletions onnxruntime/test/testdata/transform/fusion/fuse-pad-avgpool-gen.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
from pathlib import Path

import numpy as np
import onnx

HERE = Path(__file__).parent.resolve(strict=True)
TEST = False

if TEST:
import onnxruntime


def generate_fuse_pad_avgpool():
parameters = {
"fuse-pad-avgpool": (
{},
[[1.333333, 2.333333, 1.777778], [3.0, 5.0, 3.666667], [2.666667, 4.333333, 3.111111]],
),
"fuse-pad-avgpool_with_pad": (
{"pads": [1, 1, 0, 0], "count_include_pad": 1},
[
[0.111111, 0.333333, 0.666667, 0.555556],
[0.555556, 1.333333, 2.333333, 1.777778],
[1.333333, 3.0, 5.0, 3.666667],
[1.222222, 2.666667, 4.333333, 3.111111],
],
),
"fuse-pad-avgpool_with_pad-nofuse": (
{"pads": [1, 1, 0, 0]},
[
[0.25, 0.5, 1.0, 0.833333],
[0.833333, 1.333333, 2.333333, 1.777778],
[2.0, 3.0, 5.0, 3.666667],
[1.833333, 2.666667, 4.333333, 3.111111],
],
),
}
for name in parameters:
model_path = HERE / f"{name}.onnx"
input_ = onnx.helper.make_tensor_value_info("input", onnx.TensorProto.FLOAT, (1, 1, 3, 3))
pad = onnx.helper.make_node("Pad", ["input"], ["tp"], mode="constant", pads=[0, 0, 1, 1, 0, 0, 1, 1])
pool = onnx.helper.make_node("AveragePool", ["tp"], ["output"], kernel_shape=[3, 3], **parameters[name][0])
nodes = [pad, pool]
output_shape = (1, 1, 3, 3) if name == "fuse-pad-avgpool" else (1, 1, 4, 4)
output_ = onnx.helper.make_tensor_value_info("output", onnx.TensorProto.FLOAT, output_shape)
graph = onnx.helper.make_graph(nodes, name, [input_], [output_])
model = onnx.helper.make_model(graph, opset_imports=[onnx.helper.make_opsetid("", 7)])
onnx.checker.check_model(model)
onnx.save_model(model, model_path)
if TEST:
input_array = np.array([[[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]], dtype=np.float32)
expected = np.array(parameters[name][1], dtype=np.float32)
session_options = onnxruntime.SessionOptions()
session_options.execution_mode = onnxruntime.ExecutionMode.ORT_SEQUENTIAL
session_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL
session = onnxruntime.InferenceSession(model_path, session_options)
out = session.run(["output"], {"input": input_array})
actual = out[0].squeeze()
np.testing.assert_allclose(actual, expected, rtol=1e-5, atol=0.0)
session_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
session = onnxruntime.InferenceSession(model_path, session_options)
out = session.run(["output"], {"input": input_array})
actual = out[0].squeeze()
np.testing.assert_allclose(actual, expected, rtol=1e-5, atol=0.0)


if __name__ == "__main__":
generate_fuse_pad_avgpool()
Binary file not shown.
Binary file not shown.
Binary file not shown.

0 comments on commit 9e190cd

Please sign in to comment.