Add ReduceMax Gradient #23501

Merged: 2 commits, Jan 31, 2025
60 changes: 60 additions & 0 deletions orttraining/orttraining/core/graph/gradient_builder.cc
@@ -2253,6 +2253,66 @@ IMPLEMENT_GRADIENT_BUILDER(GetGlobalMaxPoolGradient) {

result.push_back(NodeDef("Expand", {GO(0), IA("X_shape")}, {IA("expanded_dY")}));
result.push_back(NodeDef("Mul", {IA("mask_cast"), IA("expanded_dY")}, {GI(0)}));
return result;
}

IMPLEMENT_GRADIENT_BUILDER(GetReduceMaxGradient) {
  std::vector<NodeDef> result;
  auto attributes = SrcNodeAttributes();
  bool keepdims = true;

  // "keepdims" defaults to 1 in the ONNX spec; honor an explicit attribute if present.
  if (attributes.find("keepdims") != attributes.end() &&
      attributes.at("keepdims").has_i()) {
    keepdims = static_cast<bool>(attributes.at("keepdims").i());
  }

  ArgDef grad = GO(0);
  ArgDef reduced_output = O(0);

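  // If keepdims=0 the reduced dims were dropped from dY and Y; the Unsqueeze
  // nodes below restore them as size-1 dims so that the Equal/Expand mask
  // steps can broadcast against the original input shape.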
  if (!keepdims) {
    size_t numInputs = GetSrcNodeInputSize();
    ArgDef unsqueeze_axes_arg;
    bool axes_provided = false;

    // Handle "axes" supplied as an attribute (opset < 18) or as an input (opset >= 18).
    if (attributes.find("axes") != attributes.end()) {
      axes_provided = true;
      std::vector<int64_t> axes_values = RetrieveValues<int64_t>(attributes.at("axes"));
      if (SrcNodeOpsetVersion() >= 13) {
        NodeDef axes_values_node = ConstantVectorNode(axes_values, Name("axes_values"));
        result.push_back(axes_values_node);
        unsqueeze_axes_arg = axes_values_node.output_args[0];
      }
    } else if (numInputs == 2) {
      axes_provided = true;
      unsqueeze_axes_arg = I(1);
    }

    if (axes_provided) {
      grad = IA("Unsqueezed_Grad");
      reduced_output = IA("Unsqueezed_Output");
      // Opset < 13: Unsqueeze takes axes as an attribute; opset >= 13: as an input.
      if (SrcNodeOpsetVersion() < 13 && attributes.find("axes") != attributes.end()) {
        std::vector<int64_t> axes_values = RetrieveValues<int64_t>(attributes.at("axes"));
        result.push_back(NodeDef("Unsqueeze", {GO(0)}, {grad}, {MakeAttribute("axes", axes_values)}));
        result.push_back(NodeDef("Unsqueeze", {O(0)}, {reduced_output}, {MakeAttribute("axes", axes_values)}));
      } else {
        result.push_back(NodeDef(OpDef{"Unsqueeze", kOnnxDomain, 13}, {GO(0), unsqueeze_axes_arg}, {grad}));
        result.push_back(NodeDef(OpDef{"Unsqueeze", kOnnxDomain, 13}, {O(0), unsqueeze_axes_arg}, {reduced_output}));
      }
    }
  }

  // Step 1: Rebuild the boolean mask marking the positions of the maxima.
  result.push_back(NodeDef("Shape", {I(0)}, {IA("Shaped_X")}));
  result.push_back(NodeDef("Expand", {reduced_output, IA("Shaped_X")}, {IA("Expanded_Output")}));
  result.push_back(NodeDef("Equal", {I(0), IA("Expanded_Output")}, {IA("Mask")}));
  // Step 2: Cast the boolean mask to the output element type (0.0 / 1.0).
  result.push_back(NodeDef("Cast", {IA("Mask")}, {IA("Mask_Float")}, {MakeAttribute("to", static_cast<int64_t>(OElemType(0)))}));
  // Step 3: Multiply the incoming gradient by the mask.
  result.push_back(NodeDef("Mul", {grad, IA("Mask_Float")}, {IA("Masked_Grad")}));
  // Step 4: Expand so the output gradient has the same shape as the input.
  result.push_back(NodeDef("Expand", {IA("Masked_Grad"), IA("Shaped_X")}, {GI(0)}));

  return result;
}
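For intuition, the node sequence above (Shape -> Expand -> Equal -> Cast -> Mul -> Expand) encodes the rule dX = (X == broadcast(max(X))) * broadcast(dY). Below is a minimal, self-contained C++ sketch of that arithmetic; this is not ORT code, and the 2x3 input, the axis-1 reduction, and all names are assumptions chosen for illustration:

```cpp
#include <array>
#include <cstdio>

int main() {
  // Forward: Y = ReduceMax(X, axes={1}) for a 2x3 input.
  const std::array<std::array<float, 3>, 2> X = {{{1.f, 5.f, 3.f},
                                                  {4.f, 2.f, 4.f}}};
  std::array<float, 2> Y{};
  for (int i = 0; i < 2; ++i) {
    Y[i] = X[i][0];
    for (int j = 1; j < 3; ++j) {
      if (X[i][j] > Y[i]) Y[i] = X[i][j];
    }
  }

  // Backward, mirroring Expand -> Equal -> Cast -> Mul:
  // dX = (X == broadcast(Y)) * broadcast(dY).
  const std::array<float, 2> dY = {1.f, 1.f};  // upstream gradient
  std::array<std::array<float, 3>, 2> dX{};
  for (int i = 0; i < 2; ++i) {
    for (int j = 0; j < 3; ++j) {
      dX[i][j] = (X[i][j] == Y[i] ? 1.f : 0.f) * dY[i];  // mask * grad
    }
  }

  // Note: on ties (row 1 has two 4s) the Equal-based mask routes the full
  // upstream gradient to every tied maximum position.
  for (int i = 0; i < 2; ++i) {
    std::printf("dX[%d] = {%g, %g, %g}\n", i, dX[i][0], dX[i][1], dX[i][2]);
  }
  return 0;
}
```

This prints dX[0] = {0, 1, 0} and dX[1] = {1, 0, 1}, matching what the mask formulation yields for unique and tied maxima respectively.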
1 change: 1 addition & 0 deletions orttraining/orttraining/core/graph/gradient_builder.h
@@ -95,6 +95,7 @@ DECLARE_GRADIENT_BUILDER(GetConvTransposeGradient)
DECLARE_GRADIENT_BUILDER(GetResizeGradient)
DECLARE_GRADIENT_BUILDER(GetAtanGradient)
DECLARE_GRADIENT_BUILDER(GetGlobalMaxPoolGradient)
DECLARE_GRADIENT_BUILDER(GetReduceMaxGradient)

DECLARE_GRADIENT_BUILDER(GetExternalGradient)

@@ -127,6 +127,7 @@ void GradientBuilderRegistry::RegisterGradientBuilders() {
REGISTER_GRADIENT_BUILDER("Resize", GetResizeGradient);
REGISTER_GRADIENT_BUILDER("Atan", GetAtanGradient);
REGISTER_GRADIENT_BUILDER("GlobalMaxPool", GetGlobalMaxPoolGradient);
REGISTER_GRADIENT_BUILDER("ReduceMax", GetReduceMaxGradient);

REGISTER_GRADIENT_BUILDER("ExternalGradient", GetExternalGradient);
};
24 changes: 24 additions & 0 deletions orttraining/orttraining/test/gradient/gradient_ops_test.cc
@@ -3379,6 +3379,30 @@ TEST(GradientCheckerTest, GlobalMaxPoolGrad) {
  }
}

TEST(GradientCheckerTest, ReduceMaxGrad) {
  // The axes attribute supports negative values starting from opset 11.
  OpDef op_def_11{"ReduceMax", kOnnxDomain, 11};
  RunReductionTests(op_def_11, false, true);

  OpDef op_def_12{"ReduceMax", kOnnxDomain, 12};
  RunReductionTests(op_def_12, false, true);

  OpDef op_def_13{"ReduceMax", kOnnxDomain, 13};
  RunReductionTests(op_def_13, false, true);

  // axes becomes an input from opset 18.
  OpDef op_def_18{"ReduceMax", kOnnxDomain, 18};
  RunReductionTests(op_def_18, true, true);

  OpDef op_def_20{"ReduceMax", kOnnxDomain, 20};
  RunReductionTests(op_def_20, true, true);
}

} // namespace test
} // namespace onnxruntime
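
The gradient checker validates the analytic gradient against numerical differentiation. As a rough, standalone sketch of that idea, assuming central finite differences over a whole-tensor max (the helper name ReduceMaxAll and the test values are hypothetical; this is not the GradientChecker API):

```cpp
#include <cmath>
#include <cstdio>
#include <vector>

// Op under test: max over all elements of a vector.
static float ReduceMaxAll(const std::vector<float>& x) {
  float m = x[0];
  for (float v : x) m = std::fmax(m, v);
  return m;
}

int main() {
  const std::vector<float> x = {0.3f, 1.7f, -0.4f};
  const float eps = 1e-3f;
  const float max_val = ReduceMaxAll(x);
  for (size_t i = 0; i < x.size(); ++i) {
    // Central difference: d(max)/dx_i ~ (f(x + eps*e_i) - f(x - eps*e_i)) / (2*eps).
    std::vector<float> xp = x, xm = x;
    xp[i] += eps;
    xm[i] -= eps;
    const float numeric = (ReduceMaxAll(xp) - ReduceMaxAll(xm)) / (2.f * eps);
    // Analytic gradient from the mask rule: 1 at the max position, 0 elsewhere.
    const float analytic = (x[i] == max_val) ? 1.f : 0.f;
    std::printf("i=%zu numeric=%g analytic=%g\n", i, numeric, analytic);
  }
  return 0;
}
```

Both columns agree (0, 1, 0) because the values are well separated relative to eps; near ties, the numerical estimate straddles the kink in max, which is why real checkers compare within a tolerance.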
