diff --git a/CMakeLists.txt b/CMakeLists.txt
index 758fe426db2..5928809907b 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -325,6 +325,7 @@ list(APPEND NVFUSER_SRCS
     ${NVFUSER_SRCS_DIR}/polymorphic_value.cpp
     ${NVFUSER_SRCS_DIR}/predicate_compute.cpp
     ${NVFUSER_SRCS_DIR}/preseg_passes/add_axioms.cpp
+    ${NVFUSER_SRCS_DIR}/preseg_passes/fmin_fmax_promotion.cpp
     ${NVFUSER_SRCS_DIR}/preseg_passes/allocation_order_inference.cpp
     ${NVFUSER_SRCS_DIR}/preseg_passes/consecutive_cast.cpp
     ${NVFUSER_SRCS_DIR}/preseg_passes/exact_mapped_extent_substitution.cpp
diff --git a/benchmarks/python/test_softmax_fwd.py b/benchmarks/python/test_softmax_fwd.py
index f12581e9d55..329ee7a26fd 100644
--- a/benchmarks/python/test_softmax_fwd.py
+++ b/benchmarks/python/test_softmax_fwd.py
@@ -24,24 +24,17 @@ def softmax_fwd_fusion(
     T0 = fd.ops.cast(T0, dtype=DataType.Float)
     T2 = fd.ops.max(T0, dims=[reduction_axis], keepdim=False, dtype=DataType.Null)
 
-    if reduction_axis:
-        shape_v6 = [T0.size(0), 1]
-    else:
-        shape_v6 = [1, T0.size(1)]
-    bcast_dim = 1 - reduction_axis
+    bcast_dims = [False, False]
+    bcast_dims[reduction_axis] = True
 
-    T7 = fd.ops.broadcast_in_dim(T2, shape=shape_v6, broadcast_dims=[bcast_dim])
-
-    V11 = T0.shape()
-    T12 = fd.ops.broadcast_in_dim(T7, shape=V11, broadcast_dims=[0, 1])
-    T13 = fd.ops.sub(T0, T12)
+    T7 = fd.ops.broadcast(T2, is_broadcast_dim=bcast_dims)
+    T13 = fd.ops.sub(T0, T7)
     T14 = fd.ops.exp(T13)
     T15 = fd.ops.sum(T14, dims=[reduction_axis], keepdim=False, dtype=DataType.Null)
 
-    T20 = fd.ops.broadcast_in_dim(T15, shape=shape_v6, broadcast_dims=[bcast_dim])
-    T25 = fd.ops.broadcast_in_dim(T20, shape=V11, broadcast_dims=[0, 1])
+    T20 = fd.ops.broadcast(T15, is_broadcast_dim=bcast_dims)
 
-    T26 = fd.ops.reciprocal(T25)
+    T26 = fd.ops.reciprocal(T20)
     T27 = fd.ops.mul(T14, T26)
 
     if dtype in PROMOTE_DTYPES:
diff --git a/csrc/preseg_passes/fmin_fmax_promotion.cpp b/csrc/preseg_passes/fmin_fmax_promotion.cpp
new file mode 100644
index 00000000000..c24a5f4b2c6
--- /dev/null
+++ b/csrc/preseg_passes/fmin_fmax_promotion.cpp
@@ -0,0 +1,351 @@
+// clang-format off
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2025-present NVIDIA CORPORATION & AFFILIATES.
+ * All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ */
+// clang-format on
+#include <preseg_passes/fmin_fmax_promotion.h>
+
+#include <optional>
+#include <unordered_map>
+#include <unordered_set>
+
+#include <fusion.h>
+#include <ir/builder.h>
+#include <ir/utils.h>
+#include <iter_visitor.h>
+
+namespace nvfuser::preseg_passes {
+
+namespace {
+
+// To analyze a single min or max reduction op, the "target" op, we perform
+// a downstream dataflow analysis to detect whether a NAN squelched by the
+// promotion would reach any fusion outputs, or whether it would be repaired
+// by downstream "safe" reductions.
+//
+// The entire analysis happens on TensorViews, not IterDomains.
+enum class NanStatus {
+  // "None" corresponds to a lack of relevant information. This status is
+  // the default when looking up an un-tracked node in the NanStatusMap. This
+  // can happen when checking e.g. an input to a binary op, which is a TV from
+  // some other part of the fusion not traversed during analysis. "None" is
+  // the lowest-precedence state; everything else overwrites it.
+  None = 0,
+
+  // "Unreduced" is the state attached to the input of the target reduction
+  // op. It becomes "GoodReduced" if it passes through a safe reduction. It is
+  // safe in the sense that it is allowed to reach fusion outputs, but it
+  // won't fix a bad status.
+  Unreduced,
+
+  // "BadReduced" is attached to the output of the target reduction op. It
+  // corresponds to a TV that has its data modified by the fmin/fmax
+  // promotion. If it reaches an output node, we know we cannot do the
+  // fmin/fmax promotion because it might change fusion output data.
+  BadReduced,
+
+  // "Mixed" is a combination of the "Unreduced" and "BadReduced" states. It
+  // contains the NANs from an "Unreduced" state, so a safe reduction will
+  // transform it into a "GoodReduced" state. But it is also downstream of an
+  // unrepaired bad reduction, so if it reaches an output, that output may
+  // have lost NAN data.
+  Mixed,
+
+  // "GoodReduced" is the status reached after "Unreduced" data passes through
+  // a safe reduction. It is the highest-precedence state, and repairs all
+  // other states.
+  GoodReduced,
+};
+
+// An example of what each status looks like is below, using a *max* reduction
+// that is broadcast back to the original shape:
+//
+// [0.0, 1.0, 2.0, 3.0, NAN, 5.0] <- Unreduced
+// [5.0, 5.0, 5.0, 5.0, 5.0, 5.0] <- BadReduced
+// [NAN, NAN, NAN, NAN, NAN, NAN] <- GoodReduced
+// [5.0, 5.0, 5.0, 5.0, NAN, 5.0] <- Mixed
+//
+// Note that "Mixed" appears the same as "Unreduced" - what is the difference?
+// The difference is that "Mixed" signals that a node is downstream of the bad
+// reduction. If the bad reduction had been a good reduction, then a TV with
+// a "Mixed" state would have data like a "GoodReduced" node.
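+//
+// How statuses combine when an expression has several tracked inputs is
+// summarized below. This is an informal sketch of the merge rule implemented
+// in minMaxOpIsRepaired() further down; "highest precedence" refers to the
+// enum ordering above:
+//
+//   GoodReduced merged with anything -> GoodReduced
+//   Mixed merged with anything lower -> Mixed
+//   Unreduced merged with BadReduced -> Mixed
+//   otherwise                        -> highest-precedence input status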
+
+// --- Dataflow Example:
+// Let tv0, tv1 be fusion input TVs. Here is the NanStatus assigned to each TV
+// in an example fusion subgraph:
+//   tv0                                 None
+//   tv1                                 Unreduced
+//   tv2 = max(tv1, {0, 1})              BadReduced
+//   tv3 = broadcast(tv2, {true, true})  BadReduced
+//   tv4 = add(tv3, tv1)                 Mixed
+//   tv5 = sum(tv4, {0, 1})              GoodReduced
+//   tv6 = broadcast(tv5, {true, true})  GoodReduced
+//   tv7 = add(tv6, tv5)                 GoodReduced
+//   tv8 = add(tv7, tv0)                 GoodReduced
+//   fusion->addOutput(tv8)
+//
+// ---- Step-by-step:
+// 1. All TVs have a None state by default.
+// 2. The analysis is launched on the max reduction which produces tv2.
+// 3. We assign tv2 a BadReduced status, and tv1 an Unreduced status.
+// 4. We traverse to the broadcast producing tv3. We propagate the BadReduced
+//    state to tv3. This is the first broadcast op we have seen, so we save it
+//    to subsequently enforce that all broadcasts match its axes.
+// 5. We traverse the tv4 add expr. This is a binary op between an Unreduced
+//    and a BadReduced state, so tv4 gets a Mixed state.
+// 6. We traverse the tv5 sum expr. This is a safe reduction of a Mixed state.
+//    Since the Mixed state carries the original NAN data that entered tv2,
+//    this safe reduction creates a GoodReduced state. We also check that this
+//    reduction's axes match the target reduction's axes.
+// 7. We traverse the tv6 broadcast expr. We already saw a broadcast before,
+//    so we ensure that this broadcast matches the axes of the prior one. It
+//    does, so we propagate the GoodReduced state and continue.
+// 8. The subsequent add() exprs simply propagate the GoodReduced state, since
+//    it has the highest priority.
+// 9. No output TVs contain a BadReduced or Mixed state, so the max() op can
+//    be promoted.
+
+using NanStatusMap = std::unordered_map<TensorView*, NanStatus>;
+using PromotedOpSet = std::unordered_set<ReductionOp*>;
+
+bool isSafeReduction(Expr* expr, const PromotedOpSet& promotedOps) {
+  if (auto* rop = dynamic_cast<ReductionOp*>(expr)) {
+    // Check that this expr hasn't already been promoted to an unsafe
+    // reduction.
+    return !promotedOps.contains(rop);
+  }
+
+  return false;
+}
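+
+// Two reductions "match" when they reduce exactly the same axis positions.
+// Illustrative example: sum(tv, {0, 1}) matches max(tv, {0, 1}), while
+// sum(tv, {0}) does not match max(tv, {0, 1}).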
+bool reductionMatches(ReductionOp* left, ReductionOp* right) {
+  auto* left_tv = dynamic_cast<TensorView*>(left->output(0));
+  auto* right_tv = dynamic_cast<TensorView*>(right->output(0));
+
+  if (left_tv->nDims() != right_tv->nDims()) {
+    return false;
+  }
+
+  for (int64_t i = 0; i < left_tv->nDims(); ++i) {
+    if (left_tv->getLogicalDomain()[i]->isReduction() !=
+        right_tv->getLogicalDomain()[i]->isReduction()) {
+      return false;
+    }
+  }
+
+  return true;
+}
+
+bool broadcastMatches(BroadcastOp* left, BroadcastOp* right) {
+  auto* left_tv = dynamic_cast<TensorView*>(left->output(0));
+  auto* right_tv = dynamic_cast<TensorView*>(right->output(0));
+
+  if (left_tv->nDims() != right_tv->nDims()) {
+    return false;
+  }
+
+  for (int64_t i = 0; i < left_tv->nDims(); ++i) {
+    if (left_tv->getLogicalDomain()[i]->isBroadcast() !=
+        right_tv->getLogicalDomain()[i]->isBroadcast()) {
+      return false;
+    }
+  }
+
+  return true;
+}
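+
+// Illustrative example of the single-broadcast-pattern restriction enforced
+// below: if the first broadcast seen is broadcast(tvA, {false, true}), then a
+// later broadcast(tvB, {true, false}) does not match it and blocks the
+// analysis at that expression.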
+bool canBeAnalyzed(
+    Expr* expr,
+    ReductionOp* compareRop,
+    std::optional<BroadcastOp*>& compareBop) {
+  // This is where we enforce the restricted-subgraph rules. Arbitrary unary
+  // and binary ops are allowed, and do not affect the analysis. Reductions
+  // and broadcasts have strict requirements to simplify the state tracking.
+
+  if (expr->isA<UnaryOp>() || expr->isA<BinaryOp>()) {
+    return true;
+  } else if (auto* rop = dynamic_cast<ReductionOp*>(expr)) {
+    // We require that all reduction ops exactly match the target's reduction
+    // axes. This avoids the need for complicated IterDomain handling.
+    return reductionMatches(rop, compareRop);
+  } else if (auto* bop = dynamic_cast<BroadcastOp*>(expr)) {
+    // As with reductions, we require all broadcasts to have the same axes.
+    if (!compareBop) {
+      compareBop = bop;
+      return true;
+    } else {
+      return broadcastMatches(bop, *compareBop);
+    }
+  }
+
+  return false;
+}
+
+// Traverses the restricted subgraph around the target rop and checks whether
+// NANs that would be squelched by a promotion are subsequently repaired by
+// safe reductions.
+bool minMaxOpIsRepaired(
+    ReductionOp* targetRop,
+    const PromotedOpSet& promotedOps) {
+  Fusion* fusion = targetRop->fusion();
+
+  auto* in_tv = targetRop->input(0)->as<TensorView>();
+  auto* out_tv = targetRop->output(0)->as<TensorView>();
+
+  NanStatusMap status_map;
+
+  status_map.emplace(in_tv, NanStatus::Unreduced);
+  status_map.emplace(out_tv, NanStatus::BadReduced);
+
+  std::optional<BroadcastOp*> broadcastMatcher;
+
+  // Topological traversal downstream of the targetRop input.
+  // Note we start from the input, not the output, of the targetRop, because
+  // we need to track the Unreduced state so it can make repairs.
+  auto traversal =
+      StmtSort::getExprsBetween({targetRop->input(0)}, fusion->outputs());
+
+  for (Expr* expr : traversal) {
+    if (expr == targetRop) {
+      // Skip the target rop. We already marked its status.
+      continue;
+    }
+
+    // Get the aggregate status of all inputs.
+    bool anyUnreduced = false;
+    bool anyBadReduced = false;
+    bool anyMixed = false;
+    bool anyGoodReduced = false;
+
+    for (auto input : expr->inputs()) {
+      if (auto* in_tv = dynamic_cast<TensorView*>(input)) {
+        NanStatus status = NanStatus::None;
+
+        auto it = status_map.find(in_tv);
+        if (it != status_map.end()) {
+          status = it->second;
+        }
+
+        if (status == NanStatus::Unreduced) {
+          anyUnreduced = true;
+        }
+        if (status == NanStatus::BadReduced) {
+          anyBadReduced = true;
+        }
+        if (status == NanStatus::Mixed) {
+          anyMixed = true;
+        }
+        if (status == NanStatus::GoodReduced) {
+          anyGoodReduced = true;
+        }
+      }
+    }
+
+    if (!canBeAnalyzed(expr, targetRop, broadcastMatcher)) {
+      // Analysis is blocked at this node; treat it like a fusion output.
+      if (anyBadReduced || anyMixed) {
+        return false;
+      } else {
+        continue;
+      }
+    }
+
+    NanStatus status = NanStatus::None;
+    // Determine this node's status based on its inputs.
+    // Status is mostly propagated based on priority. For example, GoodReduced
+    // beats all other states. There is also one combination rule: BadReduced
+    // and Unreduced combine to become Mixed.
+    if (anyGoodReduced) {
+      status = NanStatus::GoodReduced;
+    } else if (anyMixed) {
+      status = NanStatus::Mixed;
+    } else if (anyUnreduced && anyBadReduced) {
+      status = NanStatus::Mixed;
+    } else if (anyUnreduced) {
+      status = NanStatus::Unreduced;
+    } else if (anyBadReduced) {
+      status = NanStatus::BadReduced;
+    }
+
+    if (isSafeReduction(expr, promotedOps)) {
+      if (status == NanStatus::Unreduced || status == NanStatus::Mixed) {
+        // The Unreduced and Mixed states both indicate the targetRop's input
+        // has propagated here pointwise, preserving its NAN values in
+        // unchanged positions. Therefore, this reduction creates a tensor
+        // with reduced NAN values matching what the original targetRop would
+        // have produced had it propagated its NANs.
+        status = NanStatus::GoodReduced;
+      }
+    }
+
+    auto* out_tv = dynamic_cast<TensorView*>(expr->output(0));
+
+    status_map.emplace(out_tv, status);
+  }
+
+  // Check whether any bad status reached output nodes.
+  auto output_tvs = ir_utils::filterByType<TensorView>(fusion->outputs());
+  for (TensorView* out_tv : output_tvs) {
+    NanStatus status = NanStatus::None;
+
+    auto it = status_map.find(out_tv);
+    if (it != status_map.end()) {
+      status = it->second;
+    }
+
+    if (status == NanStatus::BadReduced || status == NanStatus::Mixed) {
+      return false;
+    }
+  }
+
+  return true;
+}
+
+} // namespace
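+
+// Conceptually, promotion only swaps the reduction's op type; the
+// surrounding IR is unchanged. Illustrative effect on the generated device
+// code:
+//   before: out = max(out, in);   // NAN-propagating helper
+//   after:  out = fmax(out, in);  // faster, but silently drops NANs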
+void FMinFMaxPromotionPass::runPass(Fusion* fusion) {
+  FusionGuard fusion_guard(fusion);
+
+  PromotedOpSet promotedOps;
+
+  // This outer loop runs over all expressions, filtering for min/max
+  // reductions, which become the target for the rest of the analysis.
+  for (Expr* targetExpr : fusion->exprs()) {
+    auto* targetRop = dynamic_cast<ReductionOp*>(targetExpr);
+
+    if (!targetRop) {
+      continue;
+    }
+
+    auto reduction_type = targetRop->getReductionOpType();
+
+    if (reduction_type == BinaryOpType::Min ||
+        reduction_type == BinaryOpType::Max) {
+      if (minMaxOpIsRepaired(targetRop, promotedOps)) {
+        promotedOps.insert(targetRop);
+      }
+    }
+  }
+
+  for (auto* rop : promotedOps) {
+    // Promote the reduction op by replacing the expression.
+    auto red_op_type = rop->getReductionOpType();
+    auto init = rop->init();
+    auto out = rop->out();
+    auto in = rop->in();
+
+    if (red_op_type == BinaryOpType::Max) {
+      red_op_type = BinaryOpType::FMax;
+    }
+    if (red_op_type == BinaryOpType::Min) {
+      red_op_type = BinaryOpType::FMin;
+    }
+
+    fusion->removeExpr(rop);
+    IrBuilder::create<ReductionOp>(red_op_type, init, out, in, true);
+  }
+}
+
+} // namespace nvfuser::preseg_passes
diff --git a/csrc/preseg_passes/fmin_fmax_promotion.h b/csrc/preseg_passes/fmin_fmax_promotion.h
new file mode 100644
index 00000000000..e9b026ebc24
--- /dev/null
+++ b/csrc/preseg_passes/fmin_fmax_promotion.h
@@ -0,0 +1,70 @@
+// clang-format off
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
+ * All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ */
+// clang-format on
+#pragma once
+
+#include <preseg_passes/optimization_pass.h>
+
+namespace nvfuser::preseg_passes {
+
+// CUDA has fmax() and fmin() functions that are faster than our max() and
+// min() helper functions. However, fmax and fmin don't propagate NAN values,
+// so we must generally use max() and min() in kernels (PyTorch behavior is to
+// propagate NANs).
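+//
+// Concretely (standard IEEE-754 fmax behavior, shown for illustration):
+//   fmaxf(NAN, 1.0f) == 1.0f  // the NAN is silently dropped
+//   fmaxf(NAN, NAN)  == NAN   // NAN survives only if every input is NAN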
+//
+// Normalization is a common fusion pattern whereby a max or min reduction is
+// followed by a sum reduction, and the result is combined with a binary op.
+// For example:
+//
+//   tv1 = max(tv0, {0, 1})
+//   tv2 = broadcast(tv1, {true, true})
+//   tv3 = add(tv2, tv0)
+//   tv4 = sum(tv3, {0, 1})
+//   tv5 = broadcast(tv4, {true, true})
+//   tv6 = add(tv5, tv3)
+//
+// In this example, it would actually be OK if the max were performed using
+// fmax, because the NAN values would flow into the sum() and be combined
+// during the final add(). Any loss of NANs that occurs at fmax will be
+// repaired, assuming tv1 through tv5 aren't consumed elsewhere in the fusion.
+//
+// The purpose of this pass is to identify situations like this, and "promote"
+// max and min ops into fmax and fmin where possible.
+//
+// The scope of this analysis can be quite large, and could e.g. apply to
+// pointwise min/max as well as reductions. However, this pass currently only
+// targets normalization-style cases, so the promotion algorithm is
+// simplified with the following restrictions:
+//
+// 1. Only promotes min() and max() reduction ops, not pointwise min and max.
+// 2. Analyzes a restricted subgraph around the "target" min/max reduction.
+//    Specifically, we only analyze UnaryOp, BinaryOp, ReductionOp, and
+//    BroadcastOp.
+// 3. Limited support for reduction and broadcast axes. All ReductionOps in
+//    the subgraph must match the target reduction axes. Likewise with
+//    BroadcastOps: the first broadcast we encounter becomes the structure
+//    that all broadcasts in the subgraph must conform to. This simplifies
+//    the analysis and avoids the need for a complicated IterDomain
+//    propagation and interaction tracker.
+// 4. Restricted subgraph-input analysis. We start from the input of the
+//    target ReductionOp, and we do not look any further upstream. This means
+//    we conservatively reject the following example:
+//      tv1 = abs(tv0)
+//      tv2 = max(tv1, {0})
+//      tv3 = sum(tv0, {0})
+//      tv4 = add(tv2, tv3)
+class FMinFMaxPromotionPass : public OptimizationPass<FMinFMaxPromotionPass> {
+  friend class OptimizationPass<FMinFMaxPromotionPass>;
+
+ protected:
+  static void runPass(Fusion* fusion);
+  static constexpr std::string_view name() {
+    return "FMinFMaxPromotionPass";
+  }
+};
+
+} // namespace nvfuser::preseg_passes
diff --git a/csrc/preseg_passes/pre_segmenter.cpp b/csrc/preseg_passes/pre_segmenter.cpp
index 8d60505cf04..519f5978e87 100644
--- a/csrc/preseg_passes/pre_segmenter.cpp
+++ b/csrc/preseg_passes/pre_segmenter.cpp
@@ -17,6 +17,7 @@
 #include <preseg_passes/...>
 #include <preseg_passes/...>
 #include <preseg_passes/...>
+#include <preseg_passes/fmin_fmax_promotion.h>
 #include <preseg_passes/...>
 #include <preseg_passes/...>
 #include <preseg_passes/...>
@@ -51,6 +52,7 @@ namespace nvfuser::preseg_passes {
   // removes consecutive cast operations
   OptimizationPass<ConsecutiveCastPass>::runPass(fusion);
   OptimizationPass<...>::runPass(fusion);
+  OptimizationPass<FMinFMaxPromotionPass>::runPass(fusion);
   OptimizationPass<...>::runPass(fusion);
   // MovePadPass needs to happen:
   // 1. before MarkAliasPrepare; and
diff --git a/tests/cpp/test_math_opt.cpp b/tests/cpp/test_math_opt.cpp
index 46dffc63edd..15389a5e922 100644
--- a/tests/cpp/test_math_opt.cpp
+++ b/tests/cpp/test_math_opt.cpp
@@ -109,4 +109,194 @@ INSTANTIATE_TEST_SUITE_P(
       return sanitizeTestName(ss.str());
     });
 
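+// Test harness: each TEST_F body below only builds a fusion and sets
+// should_promote_fmax_ to the expected outcome. TearDown() then plants a NAN
+// in input t0, runs the fusion, validates the outputs, and scans the
+// generated kernel strings for "fmax(" to check whether promotion happened.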
+class FMinFMaxPromotionTest : public NVFuserTest {
+ protected:
+  void SetUp() override {
+    NVFuserTest::SetUp();
+
+    fusion_ = std::make_unique<Fusion>();
+    fg_ = std::make_unique<FusionGuard>(fusion_.get());
+
+    in_tv0_ = makeSymbolicTensor(2);
+    fusion_->addInput(in_tv0_);
+    in_tv1_ = makeSymbolicTensor(2);
+    fusion_->addInput(in_tv1_);
+    in_tv2_ = makeSymbolicTensor(2);
+    fusion_->addInput(in_tv2_);
+  }
+
+  void TearDown() override {
+    auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+    auto t0 = at::randn({32, 32}, options);
+    t0[0][1] = std::numeric_limits<float>::quiet_NaN();
+
+    auto t1 = at::randn({32, 32}, options);
+    auto t2 = at::randn({32, 32}, options);
+
+    FusionExecutorCache executor_cache(std::move(fusion_));
+    auto outputs = executor_cache.runFusionWithInputs({t0, t1, t2});
+
+    testValidate(
+        executor_cache.fusion(), outputs, {t0, t1, t2}, __LINE__, __FILE__);
+
+    auto kernel_runtime = executor_cache.getMostRecentKernelRuntime();
+
+    bool anyFMax = false;
+    for (auto& segment : kernel_runtime->fusionSegments()->groups()) {
+      const auto* ke = kernel_runtime->executors()
+                           .at(segment->groupId())
+                           ->as<KernelExecutor>();
+      std::string kernel_code = ke->compiledKernel()->kernelString();
+      if (kernel_code.find("fmax(") != std::string::npos) {
+        anyFMax = true;
+      }
+    }
+
+    EXPECT_EQ(anyFMax, should_promote_fmax_);
+  }
+
+  std::unique_ptr<Fusion> fusion_;
+  std::unique_ptr<FusionGuard> fg_;
+  TensorView* in_tv0_;
+  TensorView* in_tv1_;
+  TensorView* in_tv2_;
+  bool should_promote_fmax_ = false;
+};
+
+// The most basic case of promotion: the sum covers the max reduction.
+TEST_F(FMinFMaxPromotionTest, BasicMaxSum) {
+  TensorView* tv1 = max(in_tv0_, {0, 1});
+  TensorView* tv2 = sum(in_tv0_, {0, 1});
+  // At tv3, the damage done by an fmax promotion is repaired by the BinaryOp
+  // with tv2.
+  TensorView* tv3 = add(tv1, tv2);
+  fusion_->addOutput(tv3);
+  should_promote_fmax_ = true;
+}
+
+// Like BasicMaxSum, but reducing over different axes, so the sum doesn't
+// cover the max.
+TEST_F(FMinFMaxPromotionTest, MaxSumDifferentAxes) {
+  TensorView* tv1 = max(in_tv0_, {0});
+  TensorView* tv2 = sum(in_tv0_, {1});
+  TensorView* tv3 = add(tv1, tv2);
+  fusion_->addOutput(tv3);
+  should_promote_fmax_ = false;
+}
+
+// Like BasicMaxSum, but the reduced tensors differ, so the sum doesn't cover
+// the max.
+TEST_F(FMinFMaxPromotionTest, MaxSumDifferentTensorViews) {
+  TensorView* tv1 = max(in_tv0_, {0});
+  TensorView* tv2 = sum(in_tv1_, {0});
+  TensorView* tv3 = add(tv1, tv2);
+  fusion_->addOutput(tv3);
+  should_promote_fmax_ = false;
+}
+
+// Like BasicMaxSum, but with unary ops inserted.
+// Unary ops should not affect the promotion at all.
+TEST_F(FMinFMaxPromotionTest, MaxSumSameAxesUnary) {
+  TensorView* tv1 = max(in_tv0_, {0, 1});
+  TensorView* tv2 = sum(in_tv0_, {0, 1});
+  TensorView* tv3 = abs(tv1);
+  TensorView* tv4 = abs(tv2);
+  TensorView* tv5 = add(tv3, tv4);
+  fusion_->addOutput(tv5);
+  should_promote_fmax_ = true;
+}
+
+// Like BasicMaxSum, but with binary ops connected to unrelated inputs.
+// Like unary ops, binary ops with unrelated inputs do not affect the
+// promotion.
+TEST_F(FMinFMaxPromotionTest, MaxSumSameAxesBinary) {
+  TensorView* tv1 = max(in_tv0_, {0, 1});
+  TensorView* tv2 = sum(in_tv0_, {0, 1});
+  TensorView* tv3 = broadcast(tv1, {true, true});
+  TensorView* tv4 = broadcast(tv2, {true, true});
+  TensorView* tv5 = add(tv3, in_tv1_);
+  TensorView* tv6 = add(tv4, in_tv2_);
+  TensorView* tv7 = add(tv5, tv6);
+  fusion_->addOutput(tv7);
+  should_promote_fmax_ = true;
+}
+
+// The axes are repaired separately by multiple safe reductions.
+// Although this is safe to promote, the current algorithm cannot verify it.
+TEST_F(FMinFMaxPromotionTest, MultiStageRepair) {
+  TensorView* tv1 = max(in_tv0_, {0, 1});
+  TensorView* tv2 = sum(in_tv0_, {1});
+  TensorView* tv3 = sum(tv2, {0});
+  TensorView* tv4 = add(tv1, tv3);
+  fusion_->addOutput(tv4);
+  should_promote_fmax_ = false;
+}
+
+// Here the reductions broadcast back up to 2D along different axes.
+// They are effectively transposed relative to each other, so no repair
+// happens.
+TEST_F(FMinFMaxPromotionTest, WrongBroadcast) {
+  TensorView* tv1 = max(in_tv0_, {1});
+  TensorView* tv2 = sum(in_tv0_, {1});
+  TensorView* tv3 = broadcast(tv1, {true, false});
+  TensorView* tv4 = broadcast(tv2, {false, true});
+  TensorView* tv5 = add(tv3, tv4);
+  fusion_->addOutput(tv5);
+  should_promote_fmax_ = false;
+}
+
+// Normalization pattern requiring a mixed state.
+TEST_F(FMinFMaxPromotionTest, Normalization) {
+  TensorView* tv1 = max(in_tv0_, {1});
+  TensorView* tv2 = broadcast(tv1, {false, true});
+
+  // tv3 is in a mixed state. It's not a safe output, but it can still be
+  // repaired by a safe reduction.
+  TensorView* tv3 = add(tv2, in_tv0_);
+
+  TensorView* tv4 = sum(tv3, {1});
+  TensorView* tv5 = broadcast(tv4, {false, true});
+  TensorView* tv6 = add(tv5, tv4);
+  fusion_->addOutput(tv6);
+  should_promote_fmax_ = true;
+}
+
+// Normalization with unary and binary ops thrown in.
+// These should not affect promotion.
+TEST_F(FMinFMaxPromotionTest, NormalizationUnaryBinary) {
+  TensorView* tv1 = max(in_tv0_, {0});
+
+  // Unary op
+  TensorView* tv2 = abs(tv1);
+  TensorView* tv3 = broadcast(tv2, {true, false});
+  TensorView* tv4 = add(tv3, in_tv0_);
+  TensorView* tv5 = sum(tv4, {0});
+  TensorView* tv6 = broadcast(tv5, {true, false});
+
+  // Unrelated binary op
+  TensorView* tv7 = add(tv6, in_tv1_);
+  fusion_->addOutput(tv7);
+  should_promote_fmax_ = true;
+}
+
+// A normalization-style pattern, but with different axes, which breaks
+// promotion.
+TEST_F(FMinFMaxPromotionTest, NormalizationDifferentAxes) {
+  TensorView* tv1 = max(in_tv0_, {0});
+  TensorView* tv2 = broadcast(tv1, {true, false});
+  TensorView* tv3 = add(tv2, in_tv0_);
+  TensorView* tv4 = sum(tv3, {1});
+  TensorView* tv5 = broadcast(tv4, {true, false});
+  TensorView* tv6 = add(tv5, tv3);
+  fusion_->addOutput(tv6);
+  should_promote_fmax_ = false;
+}
+
+// Two unsafe reductions on the same input. Exactly one should be promoted.
+// This tests that the promotion considers its previous promotion decisions.
+TEST_F(FMinFMaxPromotionTest, SiblingReduction) {
+  TensorView* tv1 = max(in_tv0_, {0, 1});
+  TensorView* tv2 = min(in_tv0_, {0, 1});
+  // tv1 must be the first argument to add, simply because we check for fmax
+  // and not fmin.
+  TensorView* tv3 = add(tv1, tv2);
+  fusion_->addOutput(tv3);
+  should_promote_fmax_ = true;
+}
+
 } // namespace nvfuser