Changes from 1 commit

136 commits
4e9b960
Add fmin_fmax_promotion preseg pass
tbqh Oct 7, 2025
730f011
Rewrite fmin_fmax_promotion preseg pass
tbqh Oct 19, 2025
fb7409d
Fix dependent fmin/fmax promotion
tbqh Oct 27, 2025
a42414a
Address style nits
tbqh Oct 27, 2025
7b09901
Use IdModel for all IterDomain mapping needs
tbqh Oct 27, 2025
16c4e10
Rename states and cleanup logic
tbqh Oct 27, 2025
bac3247
Rename and add code comments
tbqh Oct 27, 2025
df9a676
Only build IdModel graph once per-pass
tbqh Oct 27, 2025
68224b4
Clean up tests and name them using a test class
tbqh Oct 30, 2025
1f11a81
Move min/max promotion out of analysis loop
tbqh Oct 30, 2025
9dda08c
Get rid of markUnsafe() and do expression replacement instead
tbqh Oct 30, 2025
55271b7
Add more test cases
tbqh Oct 31, 2025
0ebc938
Rewrite analysis algorithm to be a mixture of TensorView and IterDoma…
tbqh Oct 31, 2025
7761858
Rewrite the algorithm again. New restriction that reductions and broa…
tbqh Oct 31, 2025
0008034
Explicit insert and read from the status_map
tbqh Oct 31, 2025
a7ddcee
Update test comments
tbqh Oct 31, 2025
bf65b1e
Make all broadcasts explicit and put them on their own line
tbqh Oct 31, 2025
dbd2adf
Move comment
tbqh Oct 31, 2025
a2f4b67
Address greptile comments
tbqh Nov 3, 2025
c3470c7
Move test code into TearDown() function
tbqh Nov 3, 2025
188466c
Switch NVF_CHECK() to EXPECT_EQ()
tbqh Nov 3, 2025
86d6b62
Add microarchitecture_is_pre helper (#5338)
wujingyue Oct 7, 2025
6b238c2
push_back(std::nullopt) and remove unnecessary lint disablers (#5329)
wujingyue Oct 7, 2025
dafb15b
Remove a file accidentally checked (#5339)
wujingyue Oct 7, 2025
f0100c4
Allow allocation and loop to have different positions of device IDs (…
Priya2698 Oct 7, 2025
687871a
Fix allocation logic: unconnected alloc/logical (#5185)
jjsjann123 Oct 7, 2025
3300a33
Migrate `Tutorial.VectorizeStorePointwiseTMA` and `Tutorial.Pointwise…
rdspring1 Oct 8, 2025
0f94209
Revert "Disable UCC tests for CUDA 13 (#4680)" (#5334)
samnordmann Oct 8, 2025
148f640
add non_all_reduce version of cluster reduction (#5319)
liqiangxl Oct 8, 2025
e499693
count for static shared memory usage (#5272)
liqiangxl Oct 8, 2025
7786e4c
P2p cuda lowering (#5259)
samnordmann Oct 8, 2025
ab7a81d
[HostIR] Print index definitions in `HostIrContainer::print` (#5327)
samnordmann Oct 8, 2025
7022670
Fix loading jsons in codediff tool (#5280)
jacobhinkle Oct 8, 2025
b3e76aa
Migrate `Tutorial.TMABankConflictFreeTranspose` to direct bindings (#…
rdspring1 Oct 8, 2025
5a76779
WAR fix of a bug in compute map based loop ID retrieval (#5330)
liqiangxl Oct 8, 2025
58eff1a
Reapply "Disable UCC tests for CUDA 13 (#4680)" (#5334) (#5349)
samnordmann Oct 8, 2025
376334c
Remove legacy tests (#5342)
wujingyue Oct 8, 2025
ea6feb5
Refactor `haveDifferentShardings` for use with `Stream` parallel type…
Priya2698 Oct 9, 2025
54fcd62
Fixing tests for compute capability 12 devices. (#5284)
mdavis36 Oct 9, 2025
292c5e8
Support layout op in scheduler (#5174)
jjsjann123 Oct 9, 2025
1bc5bc3
Minor fix (#5313)
naoyam Oct 10, 2025
946f5bd
Greedy scheduler rejects resize to broadcast (#5318)
naoyam Oct 10, 2025
eb43a69
ScanOp supports bfloat16 and half inputs (#5312)
naoyam Oct 10, 2025
2da8b93
Use git submodule for Cutlass (#5351)
rdspring1 Oct 10, 2025
343730e
Fix itertype promotion for `GatherScatter` (#5365)
Priya2698 Oct 10, 2025
4944de8
add quack to pybm (#5336)
liqiangxl Oct 13, 2025
f845059
Pull llama4 config into benchmark_inference.py (#5360)
tbqh Oct 13, 2025
b89a83a
Fix inlining positions with constrained ops (#5317)
naoyam Oct 13, 2025
338ff98
Enabling Host IR at Runtime (#5344)
mdavis36 Oct 14, 2025
6faae14
Propagate `stream` in loop irrespective of device mesh (#5363)
Priya2698 Oct 14, 2025
c9d1f9f
Add CUPTI profiling to direct bindings (#5324)
rdspring1 Oct 14, 2025
972729c
renaming docstring (#5378)
jjsjann123 Oct 14, 2025
92972a0
Cutlass EVT Translation: Part 1 (#5359)
jacobhinkle Oct 14, 2025
7301aac
fix cached vectorization factor in scheduler hyper parameters (#5275)
liqiangxl Oct 14, 2025
e60fc12
Patch the exception check (#5386)
jjsjann123 Oct 14, 2025
bd56f00
Set env variables for `torch.compile` (#5264)
Priya2698 Oct 14, 2025
104bfb1
More validity checks in cutlass_kernels::grouped_mm (#5394)
wujingyue Oct 15, 2025
ca4f265
Fix predicate elimination with initialized tensors (#5355)
naoyam Oct 15, 2025
f26493d
Fix bug in EVT Sm90Compute type arguments (#5400)
jacobhinkle Oct 15, 2025
5f3ce30
GroupedMmaOp::evaluate invokes the cutlass kernel when possible (#5397)
wujingyue Oct 16, 2025
82483c8
Fix indexing traversal for vectorization validation (#5381)
naoyam Oct 16, 2025
d87899c
Propagate `Stream` parallel type in allocation (#5353)
Priya2698 Oct 16, 2025
c985e5e
Visit allIDs in IterVisitor (#5384)
jacobhinkle Oct 16, 2025
29cbf37
refactor number of groups in layout op (#5198)
jjsjann123 Oct 16, 2025
e2692b1
[Greedy scheduler] Cap fusion to limit shared memory usage (#5328)
naoyam Oct 16, 2025
1da95e5
Ensure the M dimension is larger than the expert offsets (#5352)
rdspring1 Oct 17, 2025
d66123a
Use std::unique_ptr for DeviceMesh and FusionExecutorCache in direct …
rdspring1 Oct 17, 2025
04544f7
Update benchmark_utils.py, pytorch_utils.py, and profile.cpp for benc…
rdspring1 Oct 17, 2025
3010ca2
Allocate CUPTI buffer dynamically (#5402)
Priya2698 Oct 17, 2025
3f69500
Fix a check (#5407)
wujingyue Oct 17, 2025
50269a0
fp4 packed dtype support on direct python API (#5380)
jjsjann123 Oct 17, 2025
f613322
missed one field during copy constructor (#5404)
jjsjann123 Oct 17, 2025
484ffc3
Improve configuration summary (#5408)
wujingyue Oct 17, 2025
8957c36
Remove accidentally added test (#5411)
naoyam Oct 17, 2025
27de5e5
Prevent passing -DNVMMH_INCLUDE_DIR to cmake if env var missing (#5409)
jacobhinkle Oct 18, 2025
414e894
Fix ShareIpcMemHandles test (#5412)
nsarka Oct 18, 2025
2cafb54
Lower stream-parallelized matmul (#5302)
wujingyue Oct 19, 2025
2dd0461
Skip test_deepseek_v3.py (#5416)
wujingyue Oct 20, 2025
20c4843
Add NVFUSER_DUMP=inlining to show verbose info about inlining positio…
naoyam Oct 20, 2025
91ff396
exposing layout op at direct python binding (#5345)
jjsjann123 Oct 21, 2025
7e9be2b
Add "Put" p2p Cuda protocol (#5372)
samnordmann Oct 21, 2025
2dcfcd7
Add __device__ so dumped source code can be compiled with nvcc (#5415)
rdspring1 Oct 21, 2025
4877075
Reapply #5344 (#5425)
wujingyue Oct 22, 2025
76781db
Support flexible fusion input order and workspace size in Cutlass cod…
jacobhinkle Oct 23, 2025
ff30b72
relaxing device split check in vectorization analysis (#5389)
jjsjann123 Oct 23, 2025
4bfa9d2
Skip TransformerEngine test on SM 10.x because it stalls the CI. (#5430)
rdspring1 Oct 25, 2025
1f6fdfb
import pytest to fix CI (#5432)
rdspring1 Oct 27, 2025
fde13f3
Fix and Reenable Ring Allgather Cuda Ipc Test (#5429)
nsarka Oct 27, 2025
e71516c
Drop support for CUDA <12 and clean driver_api (#5433)
wujingyue Oct 28, 2025
9a57500
Remove an outdated table (#5436)
wujingyue Oct 28, 2025
5a96a0c
Fix GroupedMmaOp so it's capturable (#5438)
wujingyue Oct 28, 2025
ac233a1
Hacky patch to support correct vectorization factor of nvfp4 for poin…
jjsjann123 Oct 28, 2025
bdabc9c
Refactor kernel launch to use only `cuLaunchKernelEx` (#5435)
rdspring1 Oct 28, 2025
9fcf1e5
Add support for automatic scheduler in direct bindings (#5374)
rdspring1 Oct 29, 2025
234af49
Refactor validation in test_repro.py (#5410)
rdspring1 Oct 29, 2025
2ccb362
Improve error messages (#5449)
wujingyue Oct 29, 2025
a1b58c9
Pass EVT arguments by configurable names (#5443)
jacobhinkle Oct 30, 2025
e6a6748
Add meta path for scan op (#5450)
zasdfgbnm Oct 30, 2025
f464f70
Improve error messages (#5452)
wujingyue Oct 31, 2025
c6914b5
Fix regiser alias cross for loops (#5350)
liqiangxl Oct 31, 2025
340a1c6
Move unshard helpers into pipeline test (#5458)
wujingyue Oct 31, 2025
080de3e
Remove dead code (#5459)
wujingyue Oct 31, 2025
cce2924
Factor multi-device execution helpers into execution_utils (#5461)
wujingyue Oct 31, 2025
85bf378
Adjust test_scaled_mm rtol AND use torch.testing.assert_close (#5462)
rdspring1 Oct 31, 2025
8dff405
Refactor multidevice allocation utilities (#5464)
wujingyue Nov 1, 2025
6e8875c
Inlining fix (#5465)
naoyam Nov 1, 2025
1b9f037
Create a new node for Block Quantization to NVFP4 and plumb it to a d…
protonu Nov 2, 2025
0b985cc
Debugging instructions for multi-GPU (#5469)
wujingyue Nov 3, 2025
16b1e27
Limit loop promotion analysis to those IDs that are in the logical-lo…
naoyam Nov 3, 2025
a6db940
Validate offsets only in debug mode (#5470)
wujingyue Nov 3, 2025
a96435e
Support block-scaled outputs in Cutlass EVT (#5441)
jacobhinkle Nov 3, 2025
8f708b1
Convert SDPA block comments to line comments (#5474)
wujingyue Nov 4, 2025
6705930
Minor fixes on pointwise scheduler support for fp4 (#5455)
jjsjann123 Nov 4, 2025
88a9d61
Upgrade clang-build-23 CI job to CUDA 13.0 and add ccache (#5471)
xwang233 Nov 4, 2025
12b9318
Convert python benchmarks to direct bindings (#5224)
rdspring1 Nov 4, 2025
cabc081
dump nvrtc compile params (#5466)
liqiangxl Nov 4, 2025
acc9ce5
Use std::list for HostIrContainer top-level exprs (#5476)
wujingyue Nov 4, 2025
890b375
[Refactoring] Use resize to represent TopKOp more faithfully (#5473)
naoyam Nov 4, 2025
5cfe2b8
Add meta device support for grouped mma (#5472)
zasdfgbnm Nov 5, 2025
dd2f940
Use std::list for IR scopes (#5475)
wujingyue Nov 5, 2025
753b4ed
Revert "Add meta device support for grouped mma" (#5482)
xwang233 Nov 5, 2025
1fffe1f
[Greedy scheduler] Make the topk support more flexible (#5478)
naoyam Nov 5, 2025
d735b05
Update Cutlass grouped gemm alignments constraints. (#5484)
rdspring1 Nov 6, 2025
3dfab04
Revert "Revert "Add meta device support for grouped mma"" (#5483)
zasdfgbnm Nov 6, 2025
45255f9
Add validation requirements for the block quantization op. (#5468)
protonu Nov 6, 2025
dc128b2
Add EmbeddingFwdOp::evaluate and tests for meta device (#5477)
zasdfgbnm Nov 6, 2025
8d8dbd4
Centralize shard allocation logic for stream-parallel tensors (#5463)
wujingyue Nov 6, 2025
8dae364
Extend the block quantization runtime function to also handle 2 and 8…
protonu Nov 8, 2025
c18df5d
Batching support of topk (#5371)
naoyam Nov 8, 2025
4dd3e7f
Modify heuristic parameters before scheduling fusion in direct bindin…
rdspring1 Nov 9, 2025
8a96590
Create `test_tutorial_scheduling_layer_norm_with_profiling` (#5481)
rdspring1 Nov 10, 2025
8b80a2c
Add meta device support for 1D @ 1D MatmulOp::evaluate (#5491)
zasdfgbnm Nov 10, 2025
4869ca6
Add metadata tensors for philox outputs in SDPA (#5492)
zasdfgbnm Nov 10, 2025
ef68868
Move common propagation logic to multidevice utilities (#5496)
Priya2698 Nov 11, 2025
003bcc8
Do not parallelize new iterdomains on device (#5497)
Priya2698 Nov 12, 2025
5218134
Change python softmax benchmark to a canonical form, to hit fmax pro…
tbqh Nov 12, 2025
1 change: 1 addition & 0 deletions CMakeLists.txt
@@ -330,6 +330,7 @@ list(APPEND NVFUSER_SRCS
${NVFUSER_SRCS_DIR}/polymorphic_value.cpp
${NVFUSER_SRCS_DIR}/predicate_compute.cpp
${NVFUSER_SRCS_DIR}/preseg_passes/add_axioms.cpp
${NVFUSER_SRCS_DIR}/preseg_passes/fmin_fmax_promotion.cpp
${NVFUSER_SRCS_DIR}/preseg_passes/allocation_order_inference.cpp
${NVFUSER_SRCS_DIR}/preseg_passes/consecutive_cast.cpp
${NVFUSER_SRCS_DIR}/preseg_passes/exact_mapped_extent_substitution.cpp
9 changes: 9 additions & 0 deletions csrc/ir/internal_nodes.h
@@ -1003,6 +1003,15 @@ class ReductionOp : public Expr {
return attribute<BinaryOpType>(1);
}

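// Demotes a NAN-propagating Min/Max reduction to FMin/FMax, which skip the
// NAN handling. Intended to be called only after an analysis (see
// FMinFMaxPromotionPass) shows that any dropped NANs would be repaired before
// reaching fusion outputs.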
void markUnsafe() {
if (attribute<BinaryOpType>(1) == BinaryOpType::Max) {
attribute<BinaryOpType>(1) = BinaryOpType::FMax;
}
if (attribute<BinaryOpType>(1) == BinaryOpType::Min) {
attribute<BinaryOpType>(1) = BinaryOpType::FMin;
}
}

bool isAllreduce() const {
return attribute<bool>(2);
}
298 changes: 298 additions & 0 deletions csrc/preseg_passes/fmin_fmax_promotion.cpp
@@ -0,0 +1,298 @@
// clang-format off
/*
* SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
* All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*/
// clang-format on
#include <preseg_passes/fmin_fmax_promotion.h>

#include <unordered_map>
#include <vector>

#include <ir/utils.h>
#include <logical_domain_map.h>

namespace nvfuser::preseg_passes {

// An IterDomainStatus is attached to each IterDomain and propagated with a
Collaborator:
I'm actually not sure iter domains are the right granularity of the analysis. If one iter domain has a bad status, its tensor should be considered bad as well. Also, reductions remove iter domains, so "bad" iter domains would just disappear from the fusion. It seems to me tensors are the right level of this analysis. What do you think?

Collaborator Author:
This is the central question and took a long time for me to figure out. It comes down to how powerful you want the analysis to be. If you want to allow arbitrary reductions and broadcasts, you need IterDomain tracking. For example, recognizing that it would be safe to promote this max() requires IterDomains:

TensorView* tv1 = max(in_tv0_, {0, 1});
TensorView* tv2 = sum(in_tv0_, {1});
TensorView* tv3 = sum(tv2, {0});
TensorView* tv4 = add(tv1, tv3);
fusion_->addOutput(tv4);

I was able to get an algorithm working that would recognize this. It's on this commit, but I ended up getting rid of that algorithm in favor of a simpler one. It actually required tracking on both the IterDomains and the TensorViews.

The problem I found was that while it could handle reduction axis tracking, it could not handle broadcast axis tracking. So it would wrongfully promote this fusion:

TensorView* tv1 = max(in_tv0_, {1});
TensorView* tv2 = sum(in_tv0_, {1});
TensorView* tv3 = broadcast(tv1, {true, false});
TensorView* tv4 = broadcast(tv2, {false, true});
// The reduction axes are basically transposed now; they do not repair each other
TensorView* tv5 = add(tv3, tv4);
fusion_->addOutput(tv5);

To support both of these (flexible reduction and broadcast axes) at the same time is very challenging. You would need a state that is tracked on both TensorViews and IterDomains, with some semantics like:

  • When reducing an IterDomain, move its state onto the TensorView
  • When introducing a broadcast IterDomain, inherit the current state from the TensorView

And then it's an open question how you resolve all these states coming into a given expression (the TensorView and IterDomain states of all the input TVs).

Ultimately I realized that for normalization patterns, we can enforce that all reductions are over identical axes and all broadcasts introduce the same axes. As long as we enforce that, we can keep the interactions at the TensorView level, and it should serve the purpose we set out for optimizing normalization patterns.
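
To make that restriction concrete, here is a sketch (my own example, not code from this PR) of the softmax-style pattern this targets, reusing the in_tv0_/fusion_ names from the snippets above. Every reduction is over axis {1} and every broadcast reintroduces that same axis, so the tracking can stay at the TensorView level:

TensorView* tv_max = max(in_tv0_, {1});                  // candidate for fmax promotion
TensorView* tv_max_b = broadcast(tv_max, {false, true});
TensorView* tv_shifted = sub(in_tv0_, tv_max_b);         // NANs from in_tv0_ re-enter here
TensorView* tv_exp = exp(tv_shifted);
TensorView* tv_sum = sum(tv_exp, {1});                   // NAN-propagating reduction repairs the output
TensorView* tv_sum_b = broadcast(tv_sum, {false, true});
TensorView* tv_out = div(tv_exp, tv_sum_b);
fusion_->addOutput(tv_out);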

Collaborator:
The problematic fusion you show there is handled correctly by the "vectors of ValGroups" tracking approach we derived last week though, isn't it?

Collaborator (@jacobhinkle, Nov 1, 2025):
TV   BAD   PARTIALLY REPAIRED
tv0  {}    {}
tv1  {1}   {}
tv2  {}    {}
tv3  {1}   {}
tv4  {}    {}
tv5  {}    {1}   // Contracting Bad and missing moves the ValGroup from Bad to Partially Repaired

Summing the i1 dimension would fully repair it, but as is, we would detect that it's not safe to do this conversion.
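
For reference, a hypothetical sketch of what that per-TensorView tracking could look like; the struct and field names here are my own, and only ValGroup is an existing nvFuser type:

struct TvNanStatus {
  std::vector<ValGroup> bad;                 // dims whose NANs may have been dropped
  std::vector<ValGroup> partially_repaired;  // bad dims later contracted away by a safe reduction
};
// At a binary op, a ValGroup that is bad on one input and already missing
// (reduced) on the other moves to partially_repaired on the output, as in the
// tv5 row of the table above.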

Collaborator Author (@tbqh, Nov 2, 2025):
Are you thinking of an iterative algorithm that propagates the ValGroups through the TVs, one expr at a time, similar to how the current status enum is propagated? That makes more sense to me. The mistake I made with IdModel before was trying to use it to do mappings between distant expressions, when I guess I need to control the contraction rules myself.

I think we still need the 4 different states for the ValGroups though. In your sketch, how do you know whether a missing state contains the original data, and not something else that would not produce a partial repair? See this counter-example where the final add() receives "bad" and "missing" states; this should not produce a partially-repaired state.

TensorView* tv1 = max(in_tv0_, {1});
TensorView* tv2 = sum(in_tv0_, {1});
TensorView* tv3 = broadcast(tv1, {true, false});
TensorView* tv5 = add(tv3, in_tv1_);

In this case, tv5 should still have a bad state. This is where we need the Unreduced state to communicate whether a TV is downstream of in_tv0_.

// downward-flow algorithm. The goal is to detect whether lost NANs propagated
// to the outputs of a fusion. If so, then it is not valid to do the fast
// min/max promotion.
//
// "BAD" statuses indicate ID's that potentially lost their NANs.
// "GOOD" statuses will repair BAD statuses.
// "REDUCE" statuses are reduced dimensions, and should never be mapped to ID's
// with non-reduced status.
// All other statuses are "full-sized" (besides NONE which says nothing).
//
// A reduction ID will convert a full-size status into a REDUCE status. Likewise
// a broadcast ID will convert a reduced status to a full size status.
//
// Statuses can interact via binary ops in a kind of algebra. The two reduced
// statuses have a simple interaction: GOOD_REDUCE beats BAD-REDUCE.
//
// The 4 full-sized statuses have a slightly more complicated interaction:
// 1. Any status matched with itself, produces itself.
// 2. GOOD_BROADCAST matched with anything produces GOOD_BROADCAST.
// 3. All other cases produce BAD_BROADCAST_DEFAULT
//
// An example of what each status means is below (with *max* reduction):
// [0.0 1.0 2.0 3.0 NAN 5.0] <- DEFAULT
// [5.0] <- BAD_REDUCE
// [NAN] <- GOOD_REDUCE
// [5.0 5.0 5.0 5.0 5.0 5.0] <- BAD_BROADCAST
// [5.0 5.0 5.0 5.0 NAN 5.0] <- BAD_BROADCAST_DEFAULT
// [NAN NAN NAN NAN NAN NAN] <- GOOD_BROADCAST
enum class IterDomainStatus {
// NONE is the status when hitting untracked IDs.
NONE,

// Reduced statuses
BAD_REDUCE,
GOOD_REDUCE,

// Full-size statuses
DEFAULT,
BAD_BROADCAST,
Collaborator:
Still trying to understand the analysis, but wondering why we need a separate status for reduction and broadcast. Just having GOOD and BAD not enough?

Collaborator:
It is still unclear to me why there are both DEFAULT and GOOD. I also don't understand why we need a separate state for broadcasted BAD.

Collaborator Author:
I've reduced the number of states down to what I think is the minimal set of states. I've also renamed all these states to hopefully be more explicit (and not use overloaded words like "default").

There's now a pretty large comment block at the top of this file, explaining every single state with a full example. I recommend going through the example.

Regarding this specific point:

Why can't we just have GOOD and BAD states?

I guess the problem is that these states are communicating several pieces of information:

  1. Whether the node is downstream of the unsafe target reduction
  2. Whether the node is downstream from the target-rop input, and downstream in a way that the NAN data is still intact (i.e. there is a pointwise dataflow path that didn't squelch NANs or move them around).
  3. Whether the data has passed through a safe reduction

These are 3 different types of information; perhaps there could be 8 states (2^3), but since there is some redundancy, we only need a subset of those states.

BAD_BROADCAST_DEFAULT,
GOOD_BROADCAST,
};

using IterStatusMap = std::unordered_map<IterDomain*, IterDomainStatus>;

IterDomainStatus BopStatus(IterDomainStatus lhs, IterDomainStatus rhs) {
if (lhs == IterDomainStatus::NONE) {
return rhs;
}

if (rhs == IterDomainStatus::NONE) {
return lhs;
}

if (lhs == IterDomainStatus::GOOD_REDUCE ||
rhs == IterDomainStatus::GOOD_REDUCE) {
return IterDomainStatus::GOOD_REDUCE;
} else if (
lhs == IterDomainStatus::BAD_REDUCE &&
rhs == IterDomainStatus::BAD_REDUCE) {
return IterDomainStatus::BAD_REDUCE;
}

if (lhs == rhs) {
return lhs;
}

if (lhs == IterDomainStatus::GOOD_BROADCAST ||
rhs == IterDomainStatus::GOOD_BROADCAST) {
return IterDomainStatus::GOOD_BROADCAST;
}

// The only remaining cases are combinations of DEFAULT,
// BAD_BROADCAST, and BAD_BROADCAST_DEFAULT.
return IterDomainStatus::BAD_BROADCAST_DEFAULT;
}

bool StatusIsBad(IterDomainStatus status) {
return status == IterDomainStatus::BAD_REDUCE ||
status == IterDomainStatus::BAD_BROADCAST ||
status == IterDomainStatus::BAD_BROADCAST_DEFAULT;
}

bool AnyBadInputs(Expr* expr, IterStatusMap& iterMap) {
for (auto input : expr->inputs()) {
if (auto* in_tv = dynamic_cast<TensorView*>(input)) {
for (IterDomain* id : in_tv->getLogicalDomain()) {
IterDomainStatus status = iterMap[id];
if (StatusIsBad(status)) {
return true;
}
}
}
}

return false;
}

// Once we identify a target reduction, we perform a downward pass starting from
// the target's direct input. The pass propagates IterDomainStatus information.
// At the end, we check all output TVs for bad statuses. If at any point we
// encounter a node we don't know how to propagate information through, we treat
// it like a graph output and fail if it has any incoming bad statuses.
bool AnalyzeMinMaxOp(ReductionOp* targetRop) {
Fusion* fusion = targetRop->fusion();

FusionGuard fg(fusion);
ComputeAtLogicalDomainMap logical_map;
logical_map.build(true);

IterStatusMap iterMap;

auto* in_tv = targetRop->input(0)->as<TensorView>();
for (IterDomain* in_id : in_tv->getLogicalDomain()) {
iterMap[in_id] = IterDomainStatus::DEFAULT;
}

auto* out_tv = targetRop->output(0)->as<TensorView>();
for (IterDomain* out_id : out_tv->getLogicalDomain()) {
if (out_id->isReduction()) {
iterMap[out_id] = IterDomainStatus::BAD_REDUCE;
} else {
iterMap[out_id] = IterDomainStatus::DEFAULT;
}
}

auto traversal =
StmtSort::getExprsBetween({targetRop->input(0)}, fusion->outputs());
for (Expr* expr : traversal) {
std::string opName = expr->getOpString();

if (expr == targetRop) {
// Skip the target rop. We already marked its status.
continue;
}

bool anyBadInputs = AnyBadInputs(expr, iterMap);

auto* out_tv = dynamic_cast<TensorView*>(expr->output(0));

if (!out_tv) {
if (anyBadInputs) {
return false;
} else {
continue;
}
}

if (expr->isA<UnaryOp>() || expr->isA<ReductionOp>() ||
expr->isA<BroadcastOp>()) {
auto in_tv = expr->input(0)->as<TensorView>();
auto p2c = logical_map.mapBestEffort(
in_tv->domain(),
in_tv->getLogicalDomain(),
out_tv->domain(),
out_tv->getLogicalDomain());

for (IterDomain* in_id : in_tv->getLogicalDomain()) {
IterDomainStatus status = iterMap[in_id];
auto out_id = p2c[in_id];

if (out_id) {
if (out_id->isReduction()) {
if (status == IterDomainStatus::BAD_BROADCAST) {
status = IterDomainStatus::BAD_REDUCE;
} else if (status != IterDomainStatus::NONE) {
status = IterDomainStatus::GOOD_REDUCE;
}
}

if (out_id->isBroadcast()) {
if (status == IterDomainStatus::BAD_REDUCE) {
status = IterDomainStatus::BAD_BROADCAST;
} else if (status != IterDomainStatus::NONE) {
status = IterDomainStatus::GOOD_BROADCAST;
}
}

iterMap[out_id] = status;
} else {
if (StatusIsBad(status)) {
return false;
}
}
}

} else if (expr->isA<BinaryOp>()) {
auto* left_tv = dynamic_cast<TensorView*>(expr->input(0));
auto* right_tv = dynamic_cast<TensorView*>(expr->input(1));

// One side (not both) might not be a TensorView.
// To handle this, just propagate the status of the other side.
if (!left_tv) {
left_tv = right_tv;
} else if (!right_tv) {
right_tv = left_tv;
}

auto left2right = logical_map.mapBestEffort(
left_tv->domain(),
left_tv->getLogicalDomain(),
right_tv->domain(),
right_tv->getLogicalDomain());

auto left2out = logical_map.mapBestEffort(
left_tv->domain(),
left_tv->getLogicalDomain(),
out_tv->domain(),
out_tv->getLogicalDomain());

for (IterDomain* left_id : left_tv->getLogicalDomain()) {
// Note: this assumes that the left <-> right mapping exists
// Does this need to handle left-right mapping failures?

IterDomainStatus leftStatus = iterMap[left_id];
IterDomainStatus rightStatus = iterMap[left2right[left_id]];

IterDomainStatus status = BopStatus(leftStatus, rightStatus);

auto out_id = left2out[left_id];

if (out_id) {
iterMap[out_id] = status;
} else {
if (StatusIsBad(status)) {
return false;
}
}
}

} else {
// unknown op type, ensure it has no bad status since information will not
// flow through it.
if (anyBadInputs) {
return false;
}
}
}

// Check whether any bad status reached output nodes
auto output_tvs = ir_utils::filterByType<TensorView>(fusion->outputs());
for (TensorView* tv : output_tvs) {
for (IterDomain* id : tv->getLogicalDomain()) {
IterDomainStatus status = iterMap[id];
if (StatusIsBad(status)) {
return false;
}
}
}

return true;
}

void FMinFMaxPromotionPass::runPass(Fusion* fusion) {
FusionGuard fusion_guard(fusion);

// The outer loop runs over all expressions, filtering out most of them.
// It stops only on min/max reductions, which become the target for the rest
// of the analysis.
for (Expr* targetExpr : fusion->exprs()) {
auto* targetRop = dynamic_cast<ReductionOp*>(targetExpr);

if (!targetRop) {
continue;
}

auto reduction_type = targetRop->getReductionOpType();

if (reduction_type == BinaryOpType::Min ||
reduction_type == BinaryOpType::Max) {
if (AnalyzeMinMaxOp(targetRop)) {
targetRop->markUnsafe();
}
}
}

return;
}

} // namespace nvfuser::preseg_passes
41 changes: 41 additions & 0 deletions csrc/preseg_passes/fmin_fmax_promotion.h
@@ -0,0 +1,41 @@
// clang-format off
/*
* SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
* All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*/
// clang-format on
#pragma once

#include <preseg_passes/optimization_pass.h>

namespace nvfuser::preseg_passes {

// Converts max & min reductions into faster versions,
// which don't propagate NANs.
//
// PyTorch propagates NANs for min and max reductions. However, fmax and fmin
// do not propagate NANs in CUDA. So for a kernel to match PyTorch behavior, it
// must contain additional branches, which are expensive. Other ops such as
// sum() propagate NANs by default, with no loss in performance. Take for
// example:
//
// tv1 = max(tv0, {0});
// tv2 = sum(tv0, {0});
// tv3 = add(tv1, tv2);
//
// Here, if max() fails to propagate NANs, it will be "repaired" by the
// downstream sum() reduction. This can also work if there are matching
// broadcasts following the reductions.
//
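// For example (a sketch of the matching-broadcast case described above, not
// part of the original example):
//
// tv1 = max(tv0, {0});
// tv2 = sum(tv0, {0});
// tv3 = broadcast(tv1, {true, false});
// tv4 = broadcast(tv2, {true, false});
// tv5 = add(tv3, tv4);
//
// The NAN-propagating sum() still sees any NANs that an unsafe max() dropped,
// so they reappear in tv5.
//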
class FMinFMaxPromotionPass : public OptimizationPass<FMinFMaxPromotionPass> {
friend class OptimizationPass<FMinFMaxPromotionPass>;

protected:
static void runPass(Fusion* fusion);
static constexpr std::string_view name() {
return "FMinFMaxPromotionPass";
}
};

} // namespace nvfuser::preseg_passes
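
As a side note on why the promotion pays off: matching PyTorch's NAN-propagating max requires an explicit NAN check on top of plain fmax, which otherwise returns the non-NAN operand. A minimal sketch (my own illustration, not nvFuser runtime code):

#include <cmath>

// PyTorch-style max must propagate NANs, so it needs a branch that the fast
// fmax path this pass promotes to does not.
float max_propagate_nan(float a, float b) {
  if (std::isnan(a) || std::isnan(b)) {
    return NAN;            // propagate, matching PyTorch's max reduction
  }
  return std::fmax(a, b);  // fmax alone would drop the NAN and return the other value
}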
2 changes: 2 additions & 0 deletions csrc/preseg_passes/pre_segmenter.cpp
@@ -17,6 +17,7 @@
#include <preseg_passes/decompose_reshardings.h>
#include <preseg_passes/exact_mapped_extent_substitution.h>
#include <preseg_passes/finalize_multidevice_domains.h>
#include <preseg_passes/fmin_fmax_promotion.h>
#include <preseg_passes/mark_aliases_prepare.h>
#include <preseg_passes/move_gather.h>
#include <preseg_passes/move_pad.h>
@@ -51,6 +52,7 @@ namespace nvfuser::preseg_passes {
// removes consecutive cast operations
OptimizationPass<ConsecutiveCastPass>::runPass(fusion);
OptimizationPass<AddAxiomsPass>::runPass(fusion);
OptimizationPass<FMinFMaxPromotionPass>::runPass(fusion);
OptimizationPass<MoveSplitCatPass>::runPass(fusion);
// MovePadPass needs to happen:
// 1. before MarkAliasPrepare; and