-
Notifications
You must be signed in to change notification settings - Fork 70
Add fmin_fmax_promotion presegmentation pass #5337
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 1 commit
4e9b960
730f011
fb7409d
a42414a
7b09901
16c4e10
bac3247
df9a676
68224b4
1f11a81
9dda08c
55271b7
0ebc938
7761858
0008034
a7ddcee
bf65b1e
dbd2adf
a2f4b67
c3470c7
188466c
86d6b62
6b238c2
dafb15b
f0100c4
687871a
3300a33
0f94209
148f640
e499693
7786e4c
ab7a81d
7022670
b3e76aa
5a76779
58eff1a
376334c
ea6feb5
54fcd62
292c5e8
1bc5bc3
946f5bd
eb43a69
2da8b93
343730e
4944de8
f845059
b89a83a
338ff98
6faae14
c9d1f9f
972729c
92972a0
7301aac
e60fc12
bd56f00
104bfb1
ca4f265
f26493d
5f3ce30
82483c8
d87899c
c985e5e
29cbf37
e2692b1
1da95e5
d66123a
04544f7
3010ca2
3f69500
50269a0
f613322
484ffc3
8957c36
27de5e5
414e894
2cafb54
2dd0461
20c4843
91ff396
7e9be2b
2dcfcd7
4877075
76781db
ff30b72
4bfa9d2
1f6fdfb
fde13f3
e71516c
9a57500
5a96a0c
ac233a1
bdabc9c
9fcf1e5
234af49
2ccb362
a1b58c9
e6a6748
f464f70
c6914b5
340a1c6
080de3e
cce2924
85bf378
8dff405
6e8875c
1b9f037
0b985cc
16b1e27
a6db940
a96435e
8f708b1
6705930
88a9d61
12b9318
cabc081
acc9ce5
890b375
5cfe2b8
dd2f940
753b4ed
1fffe1f
d735b05
3dfab04
45255f9
dc128b2
8d8dbd4
8dae364
c18df5d
4dd3e7f
8a96590
8b80a2c
4869ca6
ef68868
003bcc8
5218134
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,298 @@ | ||
| // clang-format off | ||
| /* | ||
| * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. | ||
tbqh marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| * All rights reserved. | ||
| * SPDX-License-Identifier: BSD-3-Clause | ||
| */ | ||
| // clang-format on | ||
| #include <preseg_passes/fmin_fmax_promotion.h> | ||
|
|
||
| #include <unordered_map> | ||
| #include <vector> | ||
|
|
||
| #include <ir/utils.h> | ||
| #include <logical_domain_map.h> | ||
|
|
||
| namespace nvfuser::preseg_passes { | ||
|
|
||
tbqh marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| // IterDomainStatus values are attached to IterDomains and propagated with a | ||
|
||
| // downward-flow algorithm. The goal is to detect whether lost NANs propagated | ||
| // to the outputs of a fusion. If so, then it is not valid to do the fast | ||
| // min/max promotion. | ||
| // | ||
| // "BAD" statuses indicate IDs that potentially lost their NANs. | ||
| // "GOOD" statuses will repair BAD statuses. | ||
| // "REDUCE" statuses are reduced dimensions, and should never be mapped to IDs | ||
| // with non-reduced status. | ||
| // All other statuses are "full-sized" (besides NONE, which says nothing). | ||
| // | ||
| // A reduction ID will convert a full-size status into a REDUCE status. Likewise | ||
| // a broadcast ID will convert a reduced status to a full size status. | ||
| // | ||
| // Statuses can interact via binary ops in a kind of algebra. The two reduced | ||
| // statuses have a simple interaction: GOOD_REDUCE beats BAD_REDUCE. | ||
| // | ||
| // The 4 full-sized statuses have a slightly more complicated interaction: | ||
| // 1. Any status matched with itself produces itself. | ||
tbqh marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| // 2. GOOD_BROADCAST matched with anything produces GOOD_BROADCAST. | ||
| // 3. All other cases produce BAD_BROADCAST_DEFAULT. | ||
| // | ||
| // An example of what each status means is below (with *max* reduction): | ||
| // [0.0 1.0 2.0 3.0 NAN 5.0] <- DEFAULT | ||
| // [5.0] <- BAD_REDUCE | ||
| // [NAN] <- GOOD_REDUCE | ||
| // [5.0 5.0 5.0 5.0 5.0 5.0] <- BAD_BROADCAST | ||
| // [5.0 5.0 5.0 5.0 NAN 5.0] <- BAD_BROADCAST_DEFAULT | ||
| // [NAN NAN NAN NAN NAN NAN] <- GOOD_BROADCAST | ||
| enum class IterDomainStatus { | ||
| // NONE is the status when hitting untracked IDs. | ||
| NONE, | ||
|
|
||
| // Reduced statuses | ||
| BAD_REDUCE, | ||
| GOOD_REDUCE, | ||
|
|
||
| // Full-size statuses | ||
| DEFAULT, | ||
| BAD_BROADCAST, | ||
|
||
| BAD_BROADCAST_DEFAULT, | ||
| GOOD_BROADCAST, | ||
| }; | ||
|
|
||
| using IterStatusMap = std::unordered_map<IterDomain*, IterDomainStatus>; | ||
|
|
||
| IterDomainStatus BopStatus(IterDomainStatus lhs, IterDomainStatus rhs) { | ||
tbqh marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| if (lhs == IterDomainStatus::NONE) { | ||
| return rhs; | ||
tbqh marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| } | ||
|
|
||
| if (rhs == IterDomainStatus::NONE) { | ||
| return lhs; | ||
| } | ||
|
|
||
| if (lhs == IterDomainStatus::GOOD_REDUCE || | ||
| rhs == IterDomainStatus::GOOD_REDUCE) { | ||
| return IterDomainStatus::GOOD_REDUCE; | ||
| } else if ( | ||
| lhs == IterDomainStatus::BAD_REDUCE && | ||
| rhs == IterDomainStatus::BAD_REDUCE) { | ||
| return IterDomainStatus::BAD_REDUCE; | ||
| } | ||
|
|
||
| if (lhs == rhs) { | ||
| return lhs; | ||
| } | ||
|
|
||
| if (lhs == IterDomainStatus::GOOD_BROADCAST || | ||
| rhs == IterDomainStatus::GOOD_BROADCAST) { | ||
| return IterDomainStatus::GOOD_BROADCAST; | ||
| } | ||
|
|
||
| // The only remaining cases are combinations of DEFAULT, | ||
| // BAD_BROADCAST, and BAD_BROADCAST_DEFAULT. | ||
| return IterDomainStatus::BAD_BROADCAST_DEFAULT; | ||
| } | ||
|
|
||
| bool StatusIsBad(IterDomainStatus status) { | ||
| return status == IterDomainStatus::BAD_REDUCE || | ||
| status == IterDomainStatus::BAD_BROADCAST || | ||
| status == IterDomainStatus::BAD_BROADCAST_DEFAULT; | ||
| } | ||
|
|
||
| bool AnyBadInputs(Expr* expr, IterStatusMap& iterMap) { | ||
tbqh marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
tbqh marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| for (auto input : expr->inputs()) { | ||
| if (auto* in_tv = dynamic_cast<TensorView*>(input)) { | ||
| for (IterDomain* id : in_tv->getLogicalDomain()) { | ||
| IterDomainStatus status = iterMap[id]; | ||
tbqh marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| if (StatusIsBad(status)) { | ||
| return true; | ||
| } | ||
| } | ||
| } | ||
| } | ||
|
|
||
| return false; | ||
| } | ||
|
|
||
| // Once we identify a target reduction, we perform a downward pass starting from | ||
| // the target's direct input. The pass propagates IterDomainStatus information. | ||
| // At the end, we check all output TVs for bad statuses. If at any point we | ||
| // encounter a node we don't know how to propagate information through, we treat | ||
| // it like a graph output and fail if it has any incoming bad statuses. | ||
tbqh marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| bool AnalyzeMinMaxOp(ReductionOp* targetRop) { | ||
| Fusion* fusion = targetRop->fusion(); | ||
|
|
||
| FusionGuard fg(fusion); | ||
| ComputeAtLogicalDomainMap logical_map; | ||
tbqh marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| logical_map.build(true); | ||
|
|
||
| IterStatusMap iterMap; | ||
|
|
||
| auto* in_tv = targetRop->input(0)->as<TensorView>(); | ||
| for (IterDomain* in_id : in_tv->getLogicalDomain()) { | ||
| iterMap[in_id] = IterDomainStatus::DEFAULT; | ||
| } | ||
|
|
||
| auto* out_tv = targetRop->output(0)->as<TensorView>(); | ||
| for (IterDomain* out_id : out_tv->getLogicalDomain()) { | ||
| if (out_id->isReduction()) { | ||
| iterMap[out_id] = IterDomainStatus::BAD_REDUCE; | ||
| } else { | ||
| iterMap[out_id] = IterDomainStatus::DEFAULT; | ||
| } | ||
| } | ||
|
|
||
| auto traversal = | ||
| StmtSort::getExprsBetween({targetRop->input(0)}, fusion->outputs()); | ||
tbqh marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| for (Expr* expr : traversal) { | ||
| std::string opName = expr->getOpString(); | ||
tbqh marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| if (expr == targetRop) { | ||
| // Skip the target rop. We already marked its status. | ||
| continue; | ||
| } | ||
|
|
||
| bool anyBadInputs = AnyBadInputs(expr, iterMap); | ||
|
|
||
| auto* out_tv = dynamic_cast<TensorView*>(expr->output(0)); | ||
|
|
||
| if (!out_tv) { | ||
| if (anyBadInputs) { | ||
| return false; | ||
naoyam marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| } else { | ||
| continue; | ||
| } | ||
| } | ||
|
|
||
| if (expr->isA<UnaryOp>() || expr->isA<ReductionOp>() || | ||
| expr->isA<BroadcastOp>()) { | ||
| auto in_tv = expr->input(0)->as<TensorView>(); | ||
| auto p2c = logical_map.mapBestEffort( | ||
| in_tv->domain(), | ||
| in_tv->getLogicalDomain(), | ||
| out_tv->domain(), | ||
| out_tv->getLogicalDomain()); | ||
|
|
||
| for (IterDomain* in_id : in_tv->getLogicalDomain()) { | ||
| IterDomainStatus status = iterMap[in_id]; | ||
| auto out_id = p2c[in_id]; | ||
tbqh marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| if (out_id) { | ||
| if (out_id->isReduction()) { | ||
| if (status == IterDomainStatus::BAD_BROADCAST) { | ||
tbqh marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| status = IterDomainStatus::BAD_REDUCE; | ||
| } else if (status != IterDomainStatus::NONE) { | ||
| status = IterDomainStatus::GOOD_REDUCE; | ||
| } | ||
| } | ||
|
|
||
| if (out_id->isBroadcast()) { | ||
| if (status == IterDomainStatus::BAD_REDUCE) { | ||
| status = IterDomainStatus::BAD_BROADCAST; | ||
| } else if (status != IterDomainStatus::NONE) { | ||
| status = IterDomainStatus::GOOD_BROADCAST; | ||
| } | ||
| } | ||
|
|
||
| iterMap[out_id] = status; | ||
| } else { | ||
| if (StatusIsBad(status)) { | ||
tbqh marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| return false; | ||
| } | ||
| } | ||
| } | ||
|
|
||
| } else if (expr->isA<BinaryOp>()) { | ||
| auto* left_tv = dynamic_cast<TensorView*>(expr->input(0)); | ||
| auto* right_tv = dynamic_cast<TensorView*>(expr->input(1)); | ||
|
|
||
| // One side (not both) might not be a TensorView. | ||
| // To handle this, just propagate the status of the other side. | ||
| if (!left_tv) { | ||
| left_tv = right_tv; | ||
| } else if (!right_tv) { | ||
| right_tv = left_tv; | ||
| } | ||
|
|
||
| auto left2right = logical_map.mapBestEffort( | ||
| left_tv->domain(), | ||
| left_tv->getLogicalDomain(), | ||
| right_tv->domain(), | ||
| right_tv->getLogicalDomain()); | ||
|
|
||
| auto left2out = logical_map.mapBestEffort( | ||
| left_tv->domain(), | ||
| left_tv->getLogicalDomain(), | ||
| out_tv->domain(), | ||
| out_tv->getLogicalDomain()); | ||
|
|
||
| for (IterDomain* left_id : left_tv->getLogicalDomain()) { | ||
| // Note: this assumes that the left <-> right mapping exists. | ||
| // TODO: does this need to handle left-right mapping failures? | ||
|
|
||
| IterDomainStatus leftStatus = iterMap[left_id]; | ||
| IterDomainStatus rightStatus = iterMap[left2right[left_id]]; | ||
|
|
||
| IterDomainStatus status = BopStatus(leftStatus, rightStatus); | ||
|
|
||
| auto out_id = left2out[left_id]; | ||
|
|
||
| if (out_id) { | ||
| iterMap[out_id] = status; | ||
| } else { | ||
| if (StatusIsBad(status)) { | ||
| return false; | ||
| } | ||
| } | ||
| } | ||
|
|
||
| } else { | ||
| // Unknown op type: ensure it has no bad status, since information will not | ||
| // flow through it. | ||
| if (anyBadInputs) { | ||
| return false; | ||
| } | ||
| } | ||
| } | ||
|
|
||
| // Check whether any bad status reached output nodes | ||
| auto output_tvs = ir_utils::filterByType<TensorView>(fusion->outputs()); | ||
| for (TensorView* tv : output_tvs) { | ||
| for (IterDomain* id : tv->getLogicalDomain()) { | ||
| IterDomainStatus status = iterMap[id]; | ||
| if (StatusIsBad(status)) { | ||
| return false; | ||
| } | ||
| } | ||
| } | ||
|
|
||
| return true; | ||
| } | ||
|
|
||
| void FMinFMaxPromotionPass::runPass(Fusion* fusion) { | ||
| FusionGuard fusion_guard(fusion); | ||
|
|
||
| // The outer loop runs over all expressions, filtering out most of them. | ||
| // It stops only on min/max reductions, which become the target for the rest | ||
| // of the analysis. | ||
| for (Expr* targetExpr : fusion->exprs()) { | ||
| auto* targetRop = dynamic_cast<ReductionOp*>(targetExpr); | ||
|
|
||
| if (!targetRop) { | ||
| continue; | ||
| } | ||
|
|
||
| auto reduction_type = targetRop->getReductionOpType(); | ||
|
|
||
| if (reduction_type == BinaryOpType::Min || | ||
| reduction_type == BinaryOpType::Max) { | ||
| if (AnalyzeMinMaxOp(targetRop)) { | ||
tbqh marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| targetRop->markUnsafe(); | ||
| } | ||
| } | ||
| } | ||
|
|
||
| return; | ||
| } | ||
|
|
||
| } // namespace nvfuser::preseg_passes | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,41 @@ | ||
| // clang-format off | ||
| /* | ||
| * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. | ||
| * All rights reserved. | ||
| * SPDX-License-Identifier: BSD-3-Clause | ||
| */ | ||
| // clang-format on | ||
| #pragma once | ||
|
|
||
| #include <preseg_passes/optimization_pass.h> | ||
|
|
||
| namespace nvfuser::preseg_passes { | ||
|
|
||
| // Converts max & min reductions into faster versions, | ||
| // which don't propagate NANs. | ||
| // | ||
| // PyTorch propagates NANs for min and max reductions. However, fmax and fmin | ||
| // do not propagate NANs in CUDA. So for a kernel to match PyTorch behavior, it | ||
| // must contain additional branches, which are expensive. Other ops such as sum() | ||
| // propagate NANs by default, with no loss in performance. Take, for | ||
| // example: | ||
| // | ||
| // tv1 = max(tv0, {0}); | ||
| // tv2 = sum(tv0, {0}); | ||
| // tv3 = add(tv1, tv2); | ||
| // | ||
| // Here, if max() fails to propagate NANs, the lost NAN will be "repaired" by | ||
| // the sum() reduction of the same input once the two results are combined by | ||
| // add(). This can also work if there are matching broadcasts following the | ||
| // reductions. | ||
| // | ||
| class FMinFMaxPromotionPass : public OptimizationPass<FMinFMaxPromotionPass> { | ||
| friend class OptimizationPass<FMinFMaxPromotionPass>; | ||
|
|
||
| protected: | ||
| static void runPass(Fusion* fusion); | ||
| static constexpr std::string_view name() { | ||
| return "FMinFMaxPromotionPass"; | ||
| } | ||
| }; | ||
|
|
||
| } // namespace nvfuser::preseg_passes |
Uh oh!
There was an error while loading. Please reload this page.