
Conversation


@tbqh tbqh commented Oct 7, 2025

Rewrite and rebase of #5121. Adds a new presegmentation pass, "fmin_fmax_promotion", which replaces min/max reductions with fmin/fmax reductions where possible. Original motivation in #319.

The new pass does dataflow analysis by attaching a status enum to TensorViews. It flows these states downward and checks whether any bad states end up in a fusion output (a rough sketch of the status enum follows the list below). The analysis only handles four operator types:

  1. UnaryOp
  2. ReduceOp
  3. BroadcastOp
  4. BinaryOp
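
A minimal sketch of the per-TensorView status tracking, using the state names that appear later in this PR (Unreduced, BadReduced, Mixed, GoodReduced); this is an illustration only, not the pass code itself:

    // Sketch only: statuses attached to TensorViews and merged by priority,
    // mirroring the rules in minMaxOpIsRepaired() quoted further down.
    enum class NanStatus {
      None,        // not reached by the analysis
      Unreduced,   // downstream of the target rop's input with NaN positions intact
      BadReduced,  // downstream of the (to-be-promoted) unsafe reduction
      Mixed,       // combination of Unreduced and BadReduced
      GoodReduced, // a safe reduction has restored NaN propagation
    };

    NanStatus mergeInputStatuses(
        bool any_unreduced, bool any_bad, bool any_mixed, bool any_good) {
      if (any_good) return NanStatus::GoodReduced;
      if (any_mixed || (any_unreduced && any_bad)) return NanStatus::Mixed;
      if (any_unreduced) return NanStatus::Unreduced;
      if (any_bad) return NanStatus::BadReduced;
      return NanStatus::None;
    }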

Test with:

./bin/test_nvfuser --gtest_filter=FMinFMaxPromotionTest*


github-actions bot commented Oct 7, 2025

Review updated until commit 188466c

Description

  • Promote min/max reductions to fmin/fmax when safe

  • Analyze dataflow to preserve NaN propagation semantics

  • Support normalization patterns with mixed state tracking

  • Enforce matching reduction and broadcast axes


Changes walkthrough 📝

Relevant files

Enhancement

fmin_fmax_promotion.cpp — Implement fmin/fmax promotion pass with dataflow analysis
(csrc/preseg_passes/fmin_fmax_promotion.cpp, +351/-0)

  • Implement dataflow analysis using NanStatus enum to track NaN propagation
  • Propagate states (Unreduced, BadReduced, Mixed, GoodReduced) downstream
  • Check if unsafe reductions are repaired by subsequent safe reductions
  • Replace eligible min/max reductions with fmin/fmax via expression replacement

fmin_fmax_promotion.h — Declare FMinFMaxPromotionPass with safety constraints
(csrc/preseg_passes/fmin_fmax_promotion.h, +70/-0)

  • Declare FMinFMaxPromotionPass class inheriting from OptimizationPass
  • Document pass purpose and limitations in comments
  • Define interface for safe min/max to fmin/fmax conversion

Configuration changes

pre_segmenter.cpp — Register FMinFMaxPromotionPass in pre-segmenter
(csrc/preseg_passes/pre_segmenter.cpp, +2/-0)

  • Include new fmin_fmax_promotion.h header
  • Insert FMinFMaxPromotionPass into pre-segmentation pass pipeline

CMakeLists.txt — Include fmin_fmax_promotion.cpp in build
(CMakeLists.txt, +1/-0)

  • Add fmin_fmax_promotion.cpp to NVFUSER_SRCS

Tests

test_math_opt.cpp — Add comprehensive tests for fmin/fmax promotion
(tests/cpp/test_math_opt.cpp, +190/-0)

  • Add FMinFMaxPromotionTest test fixture with SetUp/TearDown
  • Implement 12 test cases covering basic, normalization, and edge cases
  • Validate promotion decisions via kernel code inspection

    PR Reviewer Guide 🔍

    Here are some key observations to aid the review process:

    🧪 PR contains tests
    ⚡ Recommended focus areas for review

    Possible Issue

    The function isSafeReduction only checks if the reduction op is not in the promotedOps set, but does not verify whether the reduction operation is actually safe in terms of NAN propagation. This could lead to incorrect promotion decisions if a reduction is considered safe without proper validation of its impact on NAN values.

    bool isSafeReduction(Expr* expr, const PromotedOpSet& promotedOps) {
      if (auto* rop = dynamic_cast<ReductionOp*>(expr)) {
        // Check that this expr hasn't already been promoted to an unsafe reduction.
        return !promotedOps.contains(rop);
      }
    
      return false;
    }
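
    If the goal is to make the NaN-propagation intent explicit rather than relying solely on the promoted-op set, a hedged refinement (a sketch only, assuming ReductionOp::getReductionOpType() and the BinaryOpType::FMax/FMin values used elsewhere in this PR) could look like:

    bool isSafeReduction(Expr* expr, const PromotedOpSet& promotedOps) {
      auto* rop = dynamic_cast<ReductionOp*>(expr);
      if (rop == nullptr) {
        return false;
      }
      // Already promoted earlier in this pass, so it now drops NaNs.
      if (promotedOps.contains(rop)) {
        return false;
      }
      // Hypothetical extra guard: reject reduction op types that drop NaNs.
      BinaryOpType op = rop->getReductionOpType();
      return op != BinaryOpType::FMax && op != BinaryOpType::FMin;
    }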
    Performance Concern

    The dataflow analysis in minMaxOpIsRepaired may not handle cases where multiple reductions repair different axes of the same tensor, as seen in the MultiStageRepair test case. This limitation could prevent valid optimizations and should be addressed to improve the pass's effectiveness.

    bool minMaxOpIsRepaired(
        ReductionOp* targetRop,
        const PromotedOpSet& promotedOps) {
      Fusion* fusion = targetRop->fusion();
    
      auto* in_tv = targetRop->input(0)->as<TensorView>();
      auto* out_tv = targetRop->output(0)->as<TensorView>();
    
      NanStatusMap status_map;
    
      status_map.emplace(in_tv, NanStatus::Unreduced);
      status_map.emplace(out_tv, NanStatus::BadReduced);
    
      std::optional<BroadcastOp*> broadcastMatcher;
    
      // Topological traversal downstream of the targetRop input.
      // Note we start from the input, not the output, of the targetRop, because
      // we need to track the Unreduced state, so it can make repairs.
      auto traversal =
          StmtSort::getExprsBetween({targetRop->input(0)}, fusion->outputs());
    
      for (Expr* expr : traversal) {
        if (expr == targetRop) {
          // Skip the target rop. We already marked its status.
          continue;
        }
    
        // Get aggregate status from all inputs.
        bool anyUnreduced = false;
        bool anyBadReduced = false;
        bool anyMixed = false;
        bool anyGoodReduced = false;
    
        for (auto input : expr->inputs()) {
          if (auto* in_tv = dynamic_cast<TensorView*>(input)) {
            NanStatus status = NanStatus::None;
    
            auto it = status_map.find(in_tv);
            if (it != status_map.end()) {
              status = it->second;
            }
    
            if (status == NanStatus::Unreduced) {
              anyUnreduced = true;
            }
            if (status == NanStatus::BadReduced) {
              anyBadReduced = true;
            }
            if (status == NanStatus::Mixed) {
              anyMixed = true;
            }
            if (status == NanStatus::GoodReduced) {
              anyGoodReduced = true;
            }
          }
        }
    
        if (!canBeAnalyzed(expr, targetRop, broadcastMatcher)) {
          // Analysis is blocked for this node, treat it like a fusion output.
          if (anyBadReduced || anyMixed) {
            return false;
          } else {
            continue;
          }
        }
    
        NanStatus status = NanStatus::None;
        // Determine this node's status based on its inputs.
        // Status is mostly propped based on priority. For example, GoodReduced
        // beats all other states. There is also one combination rule with
        // BadReduced and Unreduced combining to become Mixed.
        if (anyGoodReduced) {
          status = NanStatus::GoodReduced;
        } else if (anyMixed) {
          status = NanStatus::Mixed;
        } else if (anyUnreduced && anyBadReduced) {
          status = NanStatus::Mixed;
        } else if (anyUnreduced) {
          status = NanStatus::Unreduced;
        } else if (anyBadReduced) {
          status = NanStatus::BadReduced;
        }
    
        if (isSafeReduction(expr, promotedOps)) {
          if (status == NanStatus::Unreduced || status == NanStatus::Mixed) {
            // Unreduced and Mixed states both indicate the targetRop's input has
            // propagated here pointwise, preserving its NAN values in unchanged
            // positions. Therefore, this reduction will create the tensor with
            // reduced NAN values matching the original targetRop if it propagated
            // its NANs.
            status = NanStatus::GoodReduced;
          }
        }
    
        auto* out_tv = dynamic_cast<TensorView*>(expr->output(0));
    
        status_map.emplace(out_tv, status);
      }
    
      // Check whether any bad status reached output nodes
      auto output_tvs = ir_utils::filterByType<TensorView>(fusion->outputs());
      for (TensorView* out_tv : output_tvs) {
        NanStatus status = NanStatus::None;
    
        auto it = status_map.find(out_tv);
        if (it != status_map.end()) {
          status = it->second;
        }
    
        if (status == NanStatus::BadReduced || status == NanStatus::Mixed) {
          return false;
        }
      }
    
      return true;
    }
    Possible Issue

    The status map uses status_map.emplace(out_tv, status) which does not overwrite existing entries. If a tensor view is revisited during traversal, its status may not be updated correctly, leading to incorrect analysis results.

    status_map.emplace(out_tv, status);
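
    For reference, a standalone illustration of the behavior being flagged: std::unordered_map::emplace is a no-op when the key already exists, whereas insert_or_assign overwrites the stored value.

    #include <cassert>
    #include <unordered_map>

    int main() {
      std::unordered_map<int, int> m;
      m.emplace(1, 10);
      m.emplace(1, 20);          // key already present: value stays 10
      assert(m.at(1) == 10);
      m.insert_or_assign(1, 20); // overwrites the existing entry
      assert(m.at(1) == 20);
      return 0;
    }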


    // Full-size statuses
    DEFAULT,
    BAD_BROADCAST,
    Collaborator

    Still trying to understand the analysis, but I'm wondering why we need a separate status for reduction and broadcast. Isn't just having GOOD and BAD enough?

    Collaborator

    It is still unclear to me why there is both DEFAULT and GOOD. I also don't understand why we need a separate state for broadcasted BAD.

    Collaborator Author

    I've reduced the number of states down to what I think is the minimal set. I've also renamed all these states to hopefully be more explicit (and not use overloaded words like "default").

    There's now a pretty large comment block at the top of this file, explaining every single state with a full example. I recommend going through the example.

    Regarding this specific point:

    Why can't we just have GOOD and BAD states?

    I guess the problem is that these states are communicating several pieces of information:

    1. Whether the node is downstream of the unsafe target reduction
    2. Whether the node is downstream from the target-rop input, and downstream in a way that the NAN data is still intact (i.e. there is a pointwise dataflow path that didn't squelch NANs or move them around).
    3. Whether the data has passed through a safe reduction

    These are 3 different types of information; in principle there could be 8 states (2^3), but since there is some redundancy we only need a subset of those states.


    namespace nvfuser::preseg_passes {

    // IterDomainStatus are attached to IterDomains and propagated with a
    Collaborator

    I'm actually not sure iter domains are the right granularity for the analysis. If one iter domain has a bad status, its tensor should be considered bad as well. Also, reductions remove iter domains, so "bad" iter domains would just disappear from the fusion. It seems to me tensors are the right level for this analysis. What do you think?

    Collaborator Author

    This is the central question and took a long time for me to figure out. It comes down to how powerful you want the analysis to be. If you want to allow arbitrary reductions and broadcasts, you need IterDomain tracking. For example, recognizing that it would be safe to promote this max() requires IterDomains:

    TensorView* tv1 = max(in_tv0_, {0, 1});
    TensorView* tv2 = sum(in_tv0_, {1});
    TensorView* tv3 = sum(tv2, {0});
    TensorView* tv4 = add(tv1, tv3);
    fusion_->addOutput(tv4);
    

    I was able to get an algorithm working that would recognize this. It's on this commit, but I ended up getting rid of that algorithm in favor of a simpler one. It actually required tracking on both the IterDomains and the TensorViews.

    The problem I found was that while it could handle reduction axis tracking, it could not handle broadcast axis tracking. So it would wrongfully promote this fusion:

    TensorView* tv1 = max(in_tv0_, {1});
    TensorView* tv2 = sum(in_tv0_, {1});
    TensorView* tv3 = broadcast(tv1, {true, false});
    TensorView* tv4 = broadcast(tv2, {false, true});
    // The reduction axes are basically transposed now; they do not repair each other
    TensorView* tv5 = add(tv3, tv4);
    fusion_->addOutput(tv5);
    

    To support both of these (flexible reduction and broadcast axes) at the same time is very challenging. You would need a state that is tracked on both TensorViews and IterDomains, with some semantics like:

    • When reducing an IterDomain, move its state onto the TensorView
    • When introducing a broadcast IterDomain, inherit the current state from the TensorView

    And then it's an open question how you resolve all these states coming into a given expression, (the TensorView and IterDomain states of all the input TV's).

    Ultimately I realized that for normalization patterns, we can enforce that all reductions are over identical axes, and all broadcasts introduce the same axes. As long as we enforce that, we can keep the interactions at the TensorView level, and it should serve the purpose we set out for: optimizing normalization patterns.
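
    For concreteness, a hypothetical softmax-style fusion that satisfies these restrictions (every reduction over axis 1, every broadcast reintroducing axis 1) and would therefore be in scope; this is a sketch, not one of the tests in this PR:

    TensorView* tv1 = max(in_tv0_, {1});             // target reduction over axis 1
    TensorView* tv2 = broadcast(tv1, {false, true}); // reintroduce axis 1
    TensorView* tv3 = sub(in_tv0_, tv2);             // x - max(x): NaN rows stay NaN
    TensorView* tv4 = exp(tv3);
    TensorView* tv5 = sum(tv4, {1});                 // safe reduction over the same axis
    TensorView* tv6 = broadcast(tv5, {false, true}); // same broadcast axes
    TensorView* tv7 = div(tv4, tv6);
    fusion_->addOutput(tv7);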

    Collaborator

    The problematic fusion you show there is handled correctly by the "vectors of ValGroups" tracking approach we derived last week, though, isn't it?

    @jacobhinkle jacobhinkle Nov 1, 2025

    TV   BAD   PARTIALLY REPAIRED
    tv0  {}    {}
    tv1  {1}   {}
    tv2  {}    {}
    tv3  {1}   {}
    tv4  {}    {}
    tv5  {}    {1}  // Contracting Bad and missing moves the ValGroup from Bad to Partially Repaired
    

    Summing the i1 dimension would fully repair it but as is we would detect that it's not safe to do this conversion.

    @tbqh tbqh Nov 2, 2025

    Are you thinking of an iterative algorithm that propagates the ValGroup through the TV's, one expr at a time, similar to how the current status enum is propagated? That makes more sense to me. The mistake I made with IdModel before was trying to use it to do mappings between distant expressions, when I guess I need to control the contraction rules myself.

    I think we still need the 4 different states for the ValGroups though. In your sketch, how do you know whether a missing state contains the original data, and not something else that would not produce a partial repair? See this counter-example where the final add() receives "bad" and "missing" states; this should not produce a partially-repaired state.

    TensorView* tv1 = max(in_tv0_, {1});
    TensorView* tv2 = sum(in_tv0_, {1});
    TensorView* tv3 = broadcast(tv1, {true, false});
    TensorView* tv5 = add(tv3, in_tv1_);
    

    In this case, tv5 should still have a bad state. This is where we need the Unreduced state to communicate whether a TV is downstream of in_tv0_.


    tbqh added 3 commits October 26, 2025 21:54
    - Function names start with lowercase letters
    - Use snake_case instead of camelCase
    - Add anonymous namespace to file-scoped things
    @greptile-apps greptile-apps bot left a comment

    Greptile Overview

    Greptile Summary

    This PR adds a new presegmentation optimization pass called FMinFMaxPromotionPass that converts standard min/max reductions to faster fmin/fmax reductions when safe to do so. The motivation is performance - CUDA's fmin/fmax functions are faster than the project's min/max helpers that properly handle NaN propagation, but they don't preserve NaN values as PyTorch requires. The pass performs dataflow analysis to identify patterns where promoting to fmin/fmax won't affect final outputs because NaN values either don't reach outputs or are properly "repaired" by downstream operations. This optimization targets normalization patterns common in ML workloads and can provide up to 20% speedups in operations like cross-entropy loss. The implementation is conservatively scoped to handle only 4 operator types (UnaryOp, ReduceOp, BroadcastOp, BinaryOp) to keep complexity manageable while covering the most common fusion patterns.

    Important Files Changed

    Filename Score Overview
    csrc/preseg_passes/fmin_fmax_promotion.cpp 4/5 Implements the core optimization pass with dataflow analysis for safe min/max to fmin/fmax promotion
    csrc/preseg_passes/fmin_fmax_promotion.h 5/5 Header file defining the optimization pass class with comprehensive documentation
    csrc/preseg_passes/pre_segmenter.cpp 5/5 Integrates the new pass into the pre-segmentation pipeline after axiom passes
    tests/cpp/test_math_opt.cpp 4/5 Adds comprehensive test coverage for the new optimization pass
    CMakeLists.txt 5/5 Standard addition of new source file to build system

    Confidence score: 4/5

    • This PR implements a well-scoped performance optimization with thorough testing and comprehensive documentation
    • Score reflects conservative implementation approach with clear limitations and good test coverage, but complexity of dataflow analysis requires careful review
    • Pay close attention to the dataflow analysis logic in fmin_fmax_promotion.cpp, particularly the status propagation and safety checks

    Sequence Diagram

    sequenceDiagram
        participant User
        participant PreSegmenter
        participant FMinFMaxPromotionPass
        participant NanStatusAnalyzer
        participant Fusion
    
        User->>PreSegmenter: "runPass(fusion)"
        PreSegmenter->>FMinFMaxPromotionPass: "runPass(fusion)"
        
        FMinFMaxPromotionPass->>Fusion: "fusion->exprs()"
        Fusion-->>FMinFMaxPromotionPass: "expression list"
        
        loop "For each expression"
            FMinFMaxPromotionPass->>FMinFMaxPromotionPass: "filter for min/max ReductionOp"
            alt "Is min/max reduction"
                FMinFMaxPromotionPass->>NanStatusAnalyzer: "minMaxOpIsRepaired(targetRop, promotedOps)"
                
                NanStatusAnalyzer->>NanStatusAnalyzer: "Initialize status map with Unreduced/BadReduced"
                NanStatusAnalyzer->>Fusion: "getExprsBetween(input, outputs)"
                Fusion-->>NanStatusAnalyzer: "traversal order"
                
                loop "For each expression in traversal"
                    NanStatusAnalyzer->>NanStatusAnalyzer: "canBeAnalyzed(expr, targetRop, broadcastMatcher)"
                    alt "Can analyze"
                        NanStatusAnalyzer->>NanStatusAnalyzer: "compute aggregate NanStatus from inputs"
                        alt "Is safe reduction"
                            NanStatusAnalyzer->>NanStatusAnalyzer: "promote Unreduced/Mixed to GoodReduced"
                        end
                        NanStatusAnalyzer->>NanStatusAnalyzer: "update status_map with computed status"
                    else "Cannot analyze"
                        alt "Has BadReduced or Mixed status"
                            NanStatusAnalyzer-->>FMinFMaxPromotionPass: "false (cannot promote)"
                        end
                    end
                end
                
                NanStatusAnalyzer->>NanStatusAnalyzer: "check fusion outputs for BadReduced/Mixed"
                alt "No bad status in outputs"
                    NanStatusAnalyzer-->>FMinFMaxPromotionPass: "true (can promote)"
                    FMinFMaxPromotionPass->>FMinFMaxPromotionPass: "add to promotedOps set"
                else "Bad status found"
                    NanStatusAnalyzer-->>FMinFMaxPromotionPass: "false (cannot promote)"
                end
            end
        end
        
        loop "For each promoted operation"
            FMinFMaxPromotionPass->>Fusion: "removeExpr(rop)"
            alt "Max -> FMax"
                FMinFMaxPromotionPass->>Fusion: "create ReductionOp with BinaryOpType::FMax"
            else "Min -> FMin"
                FMinFMaxPromotionPass->>Fusion: "create ReductionOp with BinaryOpType::FMin"
            end
        end
        
        FMinFMaxPromotionPass-->>PreSegmenter: "promotion complete"
        PreSegmenter-->>User: "pass complete"
    

    5 files reviewed, 3 comments



    tbqh commented Nov 1, 2025

    I went through a couple more iterations on the core algorithm and made a semi-large change. The algorithm now requires that all reductions have identical reduce dimensions, and all broadcasts have the same broadcast dimensions (if there are any broadcasts).

    This is a very strong requirement that allows us to stop dealing with IterDomains. The previous algorithms that used IterDomains were incomplete in various ways, and the things I was trying to support with IterDomains (differing reduce and broadcast dims) actually required significantly more work. You can see one commit where I got it working for differing reduce dimensions, but the code is pretty complicated, and to make it work with broadcast dimensions would have been even more so. We don't need any of this to support normalization / softmax patterns.

    While working on those algorithms, I tried multiple ways to do IterDomain mapping: PairwiseLogicalDomainMap, ComputeAtLogicalDomainMap, and all types of IdModel. None of these seemed to serve our needs for this algorithm, however. Take this fusion:

    in_tv0_;                             // [i0, i1]
    TensorView* tv1 = max(in_tv0_, {0}); // [r2, i3]
    TensorView* tv2 = abs(in_tv0_);      // [i4, i5]
    TensorView* tv3 = sum(tv2, {1});     // [i6, r7]
    TensorView* tv4 = add(tv1, tv3);     // [i8]
    

    To support analysis of more than one reduction shape, the difficulty comes when we try to assign state to tv3. In this case, the max ReductionOp will be our target rop. We can easily figure out that i0 is the IterDomain of interest. Now we need to figure out whether or not the data from i0 makes it into r7 intact (it does not). I thought IdModel would be able to answer this question, but there were many cases where it was too permissive. I believe in this case it would map everything into one ValGroup of IterDomains. I guess IdModel is used to do something different which I don't yet understand.

    The only thing I found that worked was manually propagating IterDomains expr-by-expr. For this, PairwiseLogicalDomainMap was able to map producer-to-consumer through a single expression, basically helping me map IDs across reduction and broadcast axes that change the dimension indices. A rough sketch of that propagation is below.
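
    As an illustration (a sketch under assumptions, not code from this PR: it assumes PairwiseLogicalDomainMap's producer-to-consumer mapping and the StmtSort::getExprsBetween traversal already used by this pass), the expr-by-expr propagation looks roughly like:

    // Forward a set of IterDomains of interest one expression at a time.
    std::unordered_set<IterDomain*> tracked = {id_of_interest}; // e.g. i0
    for (Expr* expr : StmtSort::getExprsBetween({start_tv}, fusion->outputs())) {
      auto* consumer = dynamic_cast<TensorView*>(expr->output(0));
      if (consumer == nullptr) {
        continue;
      }
      for (Val* in : expr->inputs()) {
        auto* producer = dynamic_cast<TensorView*>(in);
        if (producer == nullptr) {
          continue;
        }
        // Maps each producer logical ID to its consumer ID, skipping reduced IDs
        // and handling the index shifts introduced by reduction/broadcast axes.
        auto p2c =
            PairwiseLogicalDomainMap(producer, consumer).mapProducerToConsumer();
        for (auto [p_id, c_id] : p2c) {
          if (tracked.count(p_id)) {
            tracked.insert(c_id);
          }
        }
      }
    }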

    All this is to say, I think handling all kinds of reduction and broadcast shapes is out of the scope of what we wanted. And I think these new restrictions on reductions/broadcasts make the algorithm tractable and understandable. This PR is ready for a new round of reviews.

    NanStatusMap status_map;

    status_map.emplace(in_tv, NanStatus::Unreduced);
    status_map.emplace(out_tv, NanStatus::BadReduced);
    Collaborator Author

    @jacobhinkle regarding the question

    Why not have an initialization step where all IDs of fusion inputs are marked GOOD instead of NONE?

    That is kind of what we are doing here when we insert the Unreduced and BadReduced states.

    Why don't we do this for all fusion inputs? Because we only analyze a subgraph of the fusion, starting at targetRop->input(0).

    The reason we only work on this subgraph is because we only support status mapping for a small number of operators. If we seed data on fusion inputs, it's unlikely any meaningful status will survive up to the point where we start using it.

    @jacobhinkle jacobhinkle left a comment

    Please also see my replies to an earlier comment. I appreciate the argument that we should reduce scope to the known perf opportunities. I am mostly concerned with safety: to guarantee we don't incorrectly return non-NaN, we do need a careful analysis. ValGraphs are your friend here for guaranteeing dimensions are equivalent.

    Comment on lines +65 to +66
    // [5.0, 5.0, 5.0, 5.0, 5.0, 5.0] <- BadReduced
    // [NAN, NAN, NAN, NAN, NAN, NAN] <- GoodReduced
    Collaborator

    Wouldn't these just be 5.0 and NAN if they are reduced?

    Collaborator Author

    Yes they may very well be:

    [5.0] <- BadReduced
    [NAN] <- GoodReduced
    

    These states could end up on both shapes: the full-size domains and the reduced domains. The "Reduced" in the name actually refers to the fact that these are downstream of a bad reduction or a good reduction. But they still apply to shapes after they have (maybe) been broadcast back to the single allowed broadcast shape.

    The confusing naming goes back to that other thing I said, that these NanStatus states contain a couple different types of information. Perhaps it would be better to take "Reduced" out of the names. Then we would have:

    Unreduced
    Bad
    Mixed
    Good
    

    Let me know what you think.

    return false;
    }

    bool reductionMatches(ReductionOp* left, ReductionOp* right) {
    Collaborator

    Why not use the Broadcast graph here to check that the ValGroups with reduction IDs in each op match? The current function you have here is very fragile, because you could, for example, broadcast an outer dimension, which throws off the indexing. That's exactly the problem the ValGraph solves for you by tracking equivalent dimensions through the whole fusion.

    Collaborator Author

    Can you give a bit more detail about how to do that with the broadcast graph? I am not sure how to use it for this purpose.

    Regarding the safety of the current code: an outer broadcast that changes indexing is one of the cases that we want to reject outright, and these checks will reject it. This is true because of the combination of these checks:

    • All reductions must match each other. Since there is a targetRop, all reductions must match it.
    • All broadcasts must match each other (if there are any)

    In order for a reduction op to pass this test, it must be mapping from the "original" logical domain going into the target rop, into the "reduced" target domain coming out of that op. Suppose that a reduction is downstream of a broadcast op. The only way to pass this check is if the input domain matches the "original" logical domain. So the broadcast op could only have had 1 shape - broadcasting the reduced shape to the original shape.

    In all other cases where a reduction is downstream of a broadcast, the indexing is not guaranteed to be correct, and this function will reject it, which is what we want because it's a conservative check.

    @jacobhinkle jacobhinkle Nov 2, 2025

    I meant something like this

    bool reductionMatches(ReductionOp* a, ReductionOp* b) {
      IdModel id_model(a->fusion());
      const ValGraph& graph = id_model.idGraph(IdMappingMode::BROADCAST);
      auto get_groups = [&graph](ReductionOp* rop) {
        std::unordered_set<ValGroup> s;
        // Collect the ValGroups of the reduction IDs in the op's output.
        for (IterDomain* id : rop->out()->as<TensorView>()->getLogicalDomain()) {
          if (id->isReduction()) {
            s.insert(graph.toGroup(id));
          }
        }
        return s;
      };
      return get_groups(a) == get_groups(b);
    }

    But of course you'd want to reuse graph.

    Collaborator

    A SqueezeOp or PermuteOp could interfere with your current implementation but IdModel would be fine.

    @tbqh tbqh Nov 2, 2025

    I prototyped this and it doesn't seem to do what we need. Take this case which should be rejected:

    in_tv0_; // input
    TensorView* tv1 = max(in_tv0_, {0});
    TensorView* tv2 = sum(in_tv0_, {1});
    TensorView* tv3 = add(tv1, tv2);
    fusion_->addOutput(tv3);
    

    IdModel will match these reductions, placing all IterDomains into a single ValGroup. Here is the fusion printout:

    Inputs:
      T0_g_float[iS0{i0}, iS1{i2}]
    Outputs:
      T3_g_float[iS6{i2}]
    
    %kernel_math {
    T1_l_float[rS2{i0}, iS3{i2}]
       = reduction( T0_g_float[iS0{i0}, iS1{i2}], op = max, initial value = double(-inf), allreduce = false )
    T2_l_float[iS4{i0}, rS5{i2}]
       = reduction( T0_g_float[iS0{i0}, iS1{i2}], op = add, initial value = float(0), allreduce = false )
    T3_g_float[iS6{i2}]
       = T1_l_float[rS2{i0}, iS3{i2}]
       + T2_l_float[iS4{i0}, rS5{i2}];
    } // %kernel_math 
    

    And here is the Broadcast graph:

    IdGraph { 
    Disjoint Ids:
      (idgs){
        idg{0 1 2 3 4 5 6}
    }
    
    Disjoint Expression groups:
      (exprgs){
      }
     } IdGraph
    

    I tried the other mapping modes and they all had similar issues.

    A SqueezeOp or PermuteOp could interfere with your current implementation

    Those ops will be rejected by canBeAnalyzed(). We only allow 4 op types in an allow-list, so the default for all other ops is to reject them.


    While I understand there may be value in using IdModel in this way, to be more idiomatic and have proper domain mapping, I think we will still need these strict requirements on reduction and broadcast shape. And the current matching functions may be simpler than using IdModel to enforce these requirements.

    We are not trying to support things like transposes, though in principle we could. It would require IterDomain<->TensorView tracking and interaction logic, which is too complicated for the normalization case we are trying to support.

    Collaborator

    Yeah, that is kind of a problem, isn't it? It hints that there is possibly something a little finer than even the Exact graph, which would create a separate group when two groups are aligned together in an op where those two groups are present as distinct IDs in some single TV in the fusion. This problem is similar in that respect to x*x.T, which will also collapse to a single ValGroup in the exact graph (and all other graphs).

    Collaborator Author

    Is this because the IdModel groups domains based on iteration/load/compute, rather than tracking actual data? E.g. the loop walking over x in x*x.T could be walking over x.T at the same time. The IterDomains match each other computationally, but not in terms of dataflow.

    I'm not sure if this is more fine-grained than the Exact graph; it seems like an orthogonal concept. This new type of graph is what we would need to support transpose, permute, and other ops with this analysis.

    @jacobhinkle jacobhinkle Nov 3, 2025

    The Exact graph maps all corresponding IDs from producer to consumer in a UnaryOp, and in a BinaryOp it maps Iteration domains in the producer to Iteration domains in the consumer. Corresponding IDs are also mapped in a permute (which is a LoadStoreOp whose output has a logical domain that's a permutation of its root domain). These two rules are enough (along with transitivity of the mapping) to cause all IDs to be mapped to one another in this case:

    tv0 // [ iS0{i0}, iS1{i1} ]
    tv1 = transpose(tv0)  // [ iS2{i1}, iS3{i0} ]
    tv2 = mul(tv0, tv1)  // [ iS4{i0}, iS5{i0} ]
    

    The permutation mapping means we have {0, 3}, {1, 2}, then the binary op mapping for exact map means we also map iS4 to both iS0 and iS2 and similar for iS5, so we wind up with {0, 1, 2, 3, 4, 5}.

    The permutation is not a problem for the fmin analysis because it doesn't merge any distinct groups really, just forwards through a trivial op. But the BinaryOp behavior is a problem. So I'm proposing a new graph that does not merge aligned ValGroups at all. In that graph, we'd have the following groups: {0, 3}, {1, 2}, {4, 5} and we'd have an "Aligned" ExprGroup with those two inputs and that one output group. We could lift the fmin analysis to the ValGraph at that point potentially:

    flowchart TD
      g03["{0, 3}"]
      g12["{1, 2}"]
      g4["{4}"]
      g5["{5}"]
      e0{"{Align({0,3}, {1,2})}"}
      e1{"{Align({1,2}, {0,3})}"}
      g03 --> e0
      g12 --> e0
      e0 --> g4
      g03 --> e1
      g12 --> e1
      e1 --> g5
    

    For now, I think we should try to reduce this to be as simple and safe as possible. Using axis positions feels like a step in the wrong direction from the simplicity and safety perspective, but as your example points out, if our finest graph cannot distinguish those two reduction axes, then it's hard to recommend another way right now. You could safely support LoadStoreOp without any root domain, I think.

    Collaborator Author

    Small point on:

    You could support LoadStoreOp without any root domain safely

    This should already be supported, inasmuch as we support anything else that can't be analyzed. As an input, it will have a None state which will get ignored; as a node output, we will treat it as a graph output and ensure it has no bad inputs.

    Collaborator

    I mean that we could support it between unsafe and safe operations.

    return true;
    }

    bool broadcastMatches(BroadcastOp* left, BroadcastOp* right) {
    Collaborator

    Same comment wrt reductions and broadcasts


    @greptile-apps greptile-apps bot left a comment

    Greptile Overview

    Greptile Summary

    This review covers only the changes made since the last review, not the entire PR. This update refactors the FMinFMaxPromotionTest class to follow standard Google Test patterns by moving validation logic from a manually-called validateFusion() method into an automatic TearDown() override. The key changes include: converting manual validation calls to automatic teardown execution, replacing the method parameter shouldPromoteFMax with a member variable should_promote_fmax_, and updating all test cases to set the member variable instead of calling validation directly. This restructuring ensures validation runs after every test case automatically, making the test suite more robust and following established testing best practices. The change maintains identical validation logic while integrating it into the standard test lifecycle.

    Important Files Changed

    Filename Score Overview
    tests/cpp/test_math_opt.cpp 2/5 Refactored FMinFMaxPromotionTest validation from manual method calls to automatic TearDown() with concerning use of std::move()

    Confidence score: 2/5

    • This PR introduces a critical bug where std::move(fusion_) invalidates the class member, making subsequent test operations unsafe
    • Score reflects the dangerous use of std::move() on a class member that could lead to undefined behavior if validation runs multiple times or if other methods access fusion_ after TearDown()
    • Pay close attention to the std::move usage on line 136 which needs to be fixed before merging

    1 file reviewed, no comments


    @greptile-apps greptile-apps bot left a comment

    Greptile Overview

    Greptile Summary

    This review covers only the changes made since the last review, not the entire PR. The most recent change is a single line modification in the test file tests/cpp/test_math_opt.cpp that replaces NVF_CHECK with EXPECT_EQ for better test integration practices. This aligns the assertion with standard Google Test conventions, providing clearer error reporting when the test validation fails. The change occurs in the TearDown() method where the test framework validates that the expected fmin/fmax promotion behavior matches the actual kernel code generation by checking for "fmax(" strings in the generated code.

    Important Files Changed

    Filename Score Overview
    tests/cpp/test_math_opt.cpp 5/5 Single line change replacing NVF_CHECK with EXPECT_EQ for better Google Test integration

    Confidence score: 5/5

    • This PR is safe to merge with minimal risk as it only changes test assertion methodology
    • Score reflects a simple, well-understood change that improves test framework consistency without affecting functionality
    • No files require special attention as this is a trivial test improvement

    1 file reviewed, no comments

