diff --git a/orttraining/orttraining/core/graph/gradient_builder.cc b/orttraining/orttraining/core/graph/gradient_builder.cc
index 1ba52ca9e51c4..0f0a849bb83fc 100755
--- a/orttraining/orttraining/core/graph/gradient_builder.cc
+++ b/orttraining/orttraining/core/graph/gradient_builder.cc
@@ -2253,6 +2253,66 @@ IMPLEMENT_GRADIENT_BUILDER(GetGlobalMaxPoolGradient) {
   result.push_back(NodeDef("Expand", {GO(0), IA("X_shape")}, {IA("expanded_dY")}));
   result.push_back(NodeDef("Mul", {IA("mask_cast"), IA("expanded_dY")}, {GI(0)}));
 
+  return result;
+}
+
+// Gradient of ReduceMax: route dY only to the input positions that equal the
+// reduced maximum. NOTE(review): when several elements tie for the max, the
+// gradient is replicated to every tied position (not split among them) —
+// confirm this matches the intended convention before relying on it.
+IMPLEMENT_GRADIENT_BUILDER(GetReduceMaxGradient) {
+  std::vector<NodeDef> result;
+  auto attributes = SrcNodeAttributes();
+  bool keepdims = true;
+
+  // Check the "keepdims" attribute (defaults to 1 per the ONNX spec).
+  if (attributes.find("keepdims") != attributes.end() &&
+      attributes.at("keepdims").has_i()) {
+    keepdims = static_cast<bool>(attributes.at("keepdims").i());
+  }
+
+  ArgDef grad = GO(0);
+  ArgDef reduced_output = O(0);
+
+  if (!keepdims) {
+    size_t numInputs = GetSrcNodeInputSize();
+    ArgDef unsqueeze_axes_arg;
+    bool axes_provided = false;
+
+    // Handle "axes" as attribute (opset < 18) or as a second input (opset >= 18).
+    if (attributes.find("axes") != attributes.end()) {
+      axes_provided = true;
+      std::vector<int64_t> axes_values = RetrieveValues<int64_t>(attributes.at("axes"));
+      if (SrcNodeOpsetVersion() >= 13) {
+        // Unsqueeze takes axes as an input from opset 13, so materialize a constant.
+        NodeDef axes_values_node = ConstantVectorNode(axes_values, Name("axes_values"));
+        result.push_back(axes_values_node);
+        unsqueeze_axes_arg = axes_values_node.output_args[0];
+      }
+    } else if (numInputs == 2) {
+      axes_provided = true;
+      unsqueeze_axes_arg = I(1);
+    }
+
+    if (axes_provided) {
+      // Restore the reduced dims so dY and O broadcast against the input shape.
+      grad = IA("Unsqueezed_Grad");
+      reduced_output = IA("Unsqueezed_Output");
+      if (SrcNodeOpsetVersion() < 13 && attributes.find("axes") != attributes.end()) {
+        std::vector<int64_t> axes_values = RetrieveValues<int64_t>(attributes.at("axes"));
+        result.push_back(NodeDef("Unsqueeze", {GO(0)}, {grad}, {MakeAttribute("axes", axes_values)}));
+        result.push_back(NodeDef("Unsqueeze", {O(0)}, {reduced_output}, {MakeAttribute("axes", axes_values)}));
+      } else {
+        result.push_back(NodeDef(OpDef{"Unsqueeze", kOnnxDomain, 13}, {GO(0), unsqueeze_axes_arg}, {grad}));
+        result.push_back(NodeDef(OpDef{"Unsqueeze", kOnnxDomain, 13}, {O(0), unsqueeze_axes_arg}, {reduced_output}));
+      }
+    }
+  }
+
+  // Step 1: Recreate the boolean mask tensor indicating max positions
+  result.push_back(NodeDef("Shape", {I(0)}, {IA("Shaped_X")}));
+  result.push_back(NodeDef("Expand", {reduced_output, IA("Shaped_X")}, {IA("Expanded_Output")}));
+  result.push_back(NodeDef("Equal", {I(0), IA("Expanded_Output")}, {IA("Mask")}));
+  // Step 2: Convert the boolean mask to the gradient's element type (0.0 / 1.0)
+  result.push_back(NodeDef("Cast", {IA("Mask")}, {IA("Mask_Float")}, {MakeAttribute("to", static_cast<int64_t>(OElemType(0)))}));
+  // Step 3: Multiply the (broadcast) incoming gradient by the mask
+  result.push_back(NodeDef("Mul", {grad, IA("Mask_Float")}, {IA("Masked_Grad")}));
+  // Step 4: Ensure the output gradient has the same shape as the input
+  result.push_back(NodeDef("Expand", {IA("Masked_Grad"), IA("Shaped_X")}, {GI(0)}));
   return result;
 }
 
diff --git a/orttraining/orttraining/core/graph/gradient_builder.h b/orttraining/orttraining/core/graph/gradient_builder.h
index 2611e742f342a..cf6a3f9f95c57 100755
--- a/orttraining/orttraining/core/graph/gradient_builder.h
+++ b/orttraining/orttraining/core/graph/gradient_builder.h
@@ -95,6 +95,7 @@ DECLARE_GRADIENT_BUILDER(GetConvTransposeGradient)
 DECLARE_GRADIENT_BUILDER(GetResizeGradient)
 DECLARE_GRADIENT_BUILDER(GetAtanGradient)
 DECLARE_GRADIENT_BUILDER(GetGlobalMaxPoolGradient)
+DECLARE_GRADIENT_BUILDER(GetReduceMaxGradient)
 
 DECLARE_GRADIENT_BUILDER(GetExternalGradient)
 
diff --git a/orttraining/orttraining/core/graph/gradient_builder_registry.cc b/orttraining/orttraining/core/graph/gradient_builder_registry.cc
index a04d909267142..845b8cd7ba2b4 100755
--- a/orttraining/orttraining/core/graph/gradient_builder_registry.cc
+++ b/orttraining/orttraining/core/graph/gradient_builder_registry.cc
@@ -127,6 +127,7 @@ void GradientBuilderRegistry::RegisterGradientBuilders() {
   REGISTER_GRADIENT_BUILDER("Resize", GetResizeGradient);
   REGISTER_GRADIENT_BUILDER("Atan", GetAtanGradient);
   REGISTER_GRADIENT_BUILDER("GlobalMaxPool", GetGlobalMaxPoolGradient);
+  REGISTER_GRADIENT_BUILDER("ReduceMax", GetReduceMaxGradient);
 
   REGISTER_GRADIENT_BUILDER("ExternalGradient", GetExternalGradient);
 };
diff --git a/orttraining/orttraining/test/gradient/gradient_ops_test.cc b/orttraining/orttraining/test/gradient/gradient_ops_test.cc
index f4083d5b8f933..2d0181b69413c 100644
--- a/orttraining/orttraining/test/gradient/gradient_ops_test.cc
+++ b/orttraining/orttraining/test/gradient/gradient_ops_test.cc
@@ -3379,6 +3379,30 @@ TEST(GradientCheckerTest, GlobalMaxPoolGrad) {
   }
 }
 
+TEST(GradientCheckerTest, ReduceMaxGrad) {
+  // Attribute axes supports negative values from opset 11.
+  OpDef op_def_11{"ReduceMax", kOnnxDomain, 11};
+
+  RunReductionTests(op_def_11, false, true);
+
+  OpDef op_def_12{"ReduceMax", kOnnxDomain, 12};
+
+  RunReductionTests(op_def_12, false, true);
+
+  OpDef op_def_13{"ReduceMax", kOnnxDomain, 13};
+
+  RunReductionTests(op_def_13, false, true);
+
+  // axes is input from opset 18.
+  OpDef op_def_18{"ReduceMax", kOnnxDomain, 18};
+
+  RunReductionTests(op_def_18, true, true);
+
+  OpDef op_def_20{"ReduceMax", kOnnxDomain, 20};
+
+  RunReductionTests(op_def_20, true, true);
+}
+
 }  // namespace test
 }  // namespace onnxruntime