Cast opt pass #355
Changes from 92 commits
**`optimization/consecutive_cast.cpp`** (new file, 195 lines)

```cpp
// clang-format off
/*
 * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
 * All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 */
// clang-format on
#include <optimization/consecutive_cast.h>

#include <ir/utils.h>

namespace nvfuser::optimization {

namespace {

bool isCast(Expr* expr) {
  if (auto op = dynamic_cast<UnaryOp*>(expr)) {
    return op->getUnaryOpType() == UnaryOpType::Cast;
  }
  return false;
}

// replaces the input to the cast op that produces cast_output, returns the
// new cast_output
Val* replaceInputInCast(Val* cast_output, Val* new_input) {
  auto tmp_expr = cast_output->definition();
  // short-cut for cases when no substitution is needed
  if (cast_output == new_input || new_input == tmp_expr->input(0)) {
    return cast_output;
  }
  auto new_expr = nvfuser::ir_utils::replaceValInExpr(
      tmp_expr, tmp_expr->input(0), new_input);
  return new_expr->output(0);
}
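
// NOTE: per the review discussion further down, ir_utils::replaceValInExpr
// does not mutate the expression in place; the short-wired original expr is
// removed from the graph and its memory may be reused, so stale Expr*
// handles must not be dereferenced afterwards.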

// castOptimizationPass folds away consecutive cast operations, e.g. a chain
// of casts like fp16 -> fp32 -> fp16 can be simplified away without
// impacting the output of the fusion. For a section of consecutive casts,
// chained as:
//   first_tv -> cast_tv0 -> cast_tv1 -> cast_tv2 -> output_tv
// suppose we have a TensorView `lo_anchor` in the chain such that the dtype
// of every other tv is either the same as or wider than
// `lo_anchor->getDataType()`; we can then assume that all the other cast ops
// in the chain are no-ops (except the last one, which defines the output
// dtype). So the above chain can be re-wired with only two casts as:
//   first_tv -> lo_anchor -> output_tv
//
// One complication is that we might not have a narrowest dtype in the chain,
// i.e. think about pairs like fp16/bfloat16 or fp32/int32, where neither can
// represent the other. In order to handle this scenario, we keep track of a
// `starting_anchor` that indicates the starting point of the section that
// has a valid `lo_anchor`. When we encounter a new cast op that breaks the
// assumption, we optimize what we have seen in the existing section and
// start a new section with the next cast op.
//
// The algorithm:
// 1. iterate through all exprs in the fusion:
//   1.1 skip all exprs other than casts;
//   1.2 for each end cast-op `expr`, trace back its producers iteratively
//       and push the value(s) onto `chain_cast_vals`, until:
//
//       a. the producer is not a cast op; or
//
//       b. the producer is used by other ops, or is a fusion output.
//
//   1.3 at this point, `chain_cast_vals` holds the cast outputs in order,
//       with a straight-line dependency:
//     1.3.1 point `starting_anchor` at the beginning op, indicating the
//           starting point of our folding optimization; meanwhile, point
//           `lo_anchor` at the first op, indicating the narrowest dtype we
//           have seen in the segment;
//     1.3.2 loop over the items inside `chain_cast_vals`; for item `val`:
//
//       a. if `val_dtype` is the same as, or wider than, the `anchor_dtype`
//          of `lo_anchor`, the current cast is a no-op and can be ignored;
//
//       b. if `val_dtype` is narrower than `anchor_dtype`, the previous
//          cast to `lo_anchor` is a no-op and can be folded away; we update
//          `lo_anchor` to point to `val`;
//
//       c. otherwise, `val` and `lo_anchor` are incompatible casts and both
//          need to be preserved. We re-wire the section as
//          `starting_anchor` -> `lo_anchor` -> `val`, then update
//          `starting_anchor` and `lo_anchor` to both point at `val`.
//
//   1.4 at this point, look at the `anchor_dtype` of `lo_anchor` and the
//       `output_dtype` of `expr->output(0)`:
//
//       a. if `anchor_dtype` is the same as `output_dtype`, skip the last
//          cast op and replace all its uses with `lo_anchor`;
//
//       b. if `anchor_dtype` is wider than `output_dtype`, all previous
//          casts after `starting_anchor` are no-ops; re-wire
//          `starting_anchor` directly to `expr`;
//
//       c. otherwise, we can't bypass the `lo_anchor` cast; re-wire this
//          section as `starting_anchor` -> `lo_anchor` -> `expr->output(0)`.
void castOptimizationPass(Fusion* fusion) {
  auto exprs = fusion->exprs();
  std::unordered_set<Expr*> visited;
  for (auto iter = exprs.rbegin(); iter != exprs.rend(); ++iter) {
    auto expr = *iter;
    // skip current expr if it's not a foldable cast or it has already been
    // addressed
    if (visited.count(expr) != 0 || !isCast(expr)) {
```
> **Reviewer:** nit: suggested swapping the order of the two checks in this condition (i.e. testing `!isCast(expr)` first).
>
> **Author:** The `visited` check should be dirt cheap though, right?! Also, unfortunately I was getting a segfault earlier with that.
>
> **Author:** FYI, I think this is a desired behavior, i.e. for an expr that has been short-wired, it's being removed from the graph and that memory is also reused for something else?! I'm just surprised that we already have that. I thought our mutation passes were only adding exprs...
>
> **Reviewer:** But whatever, they are both super cheap. It's just cheap vs cheaper.
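
The ordering of the two checks is load-bearing, not just a micro-optimization; here is a hedged reading of the segfault mentioned above (my annotation, not from the thread):

```cpp
// visited.count(expr) only hashes and compares the pointer value and never
// dereferences it, so it stays safe even after the pointed-to Expr has been
// removed from the fusion and freed. isCast(expr), by contrast, performs a
// dynamic_cast on *expr, which is undefined behavior on a dangling pointer.
// Checking `visited` first therefore shields the dereference:
if (visited.count(expr) != 0 || !isCast(expr)) {
  continue;
}
```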
```cpp
      continue;
    }
    std::list<Val*> chain_cast_vals;
    auto prev_expr = expr->input(0)->definition();
    while (isCast(prev_expr)) {
      auto intermediate_cast = prev_expr->output(0);
      // 1.2 Note, if the output of prev_expr is used by other operation(s),
      // or is a direct fusion output, we stop chaining casts
      if (intermediate_cast->isFusionOutput() ||
          intermediate_cast->uses().size() > 1) {
        break;
      }

      // add prev_expr to the visited set so we'll short-cut it later
      visited.insert(prev_expr);
      // in the loop, we just repeatedly chain consecutive casts
      chain_cast_vals.push_front(intermediate_cast);
```
> **Reviewer:** Should we mark [...]?
>
> **Author:** That's exactly what I have in mind. I just got a little bit uncomfortable with the traversal inside [...]
```cpp
      prev_expr = prev_expr->input(0)->definition();
    }

    // skip current expr if there's no chain_cast_vals
    if (chain_cast_vals.empty()) {
      continue;
    }

    // 1.3.1 Note, chain_cast_vals has a straight-line use without branches
    auto lo_anchor = chain_cast_vals.front()->definition()->input(0);
    auto anchor_dtype = lo_anchor->getDataType().value();
    auto starting_anchor = lo_anchor;
    for (auto val : chain_cast_vals) {
      auto val_dtype = val->getDataType().value();

      // 1.3.2.a short-cut when we are not losing precision
      if (isInclusiveType(anchor_dtype, val_dtype)) {
        continue;
      }

      // 1.3.2.c NOTE: to enter here, we have
      //   !isInclusiveType(anchor_dtype, val_dtype) &&
      //   !isInclusiveType(val_dtype, anchor_dtype)
      // which means the dtypes of lo_anchor and val aren't compatible and
      // can't be folded away without losing information, e.g. pairs where
      // neither type is strictly wider than the other: bf16 & fp16,
      // int32 & float32, etc. So we update starting_anchor to the current
      // val, which ensures that we preserve the incompatible casts.
      if (!isInclusiveType(val_dtype, anchor_dtype)) {
        lo_anchor = replaceInputInCast(lo_anchor, starting_anchor);
        val = replaceInputInCast(val, lo_anchor);
        // we need to update starting_anchor for the fold to move past this
        // current cast.
        starting_anchor = val;
      }
      // 1.3.2.b/c update lo_anchor to the current val
      lo_anchor = val;
      anchor_dtype = lo_anchor->getDataType().value();
    }

    auto output_dtype = expr->output(0)->getDataType().value();
    if (anchor_dtype == output_dtype) {
      // 1.4.a the final cast has the same dtype as lo_anchor; replace the
      // output with lo_anchor in the fusion
      lo_anchor = replaceInputInCast(lo_anchor, starting_anchor);
      ir_utils::replaceValue(fusion, {{expr->output(0), lo_anchor}});
      if (expr->output(0)->isFusionOutput()) {
        fusion->replaceOutput(expr->output(0), lo_anchor);
      }
    } else if (isInclusiveType(output_dtype, anchor_dtype)) {
      // 1.4.b lo_anchor is wider than output_dtype, so casting to lo_anchor
      // isn't doing anything; we just fold away to starting_anchor instead
      replaceInputInCast(expr->output(0), starting_anchor);
    } else {
      // 1.4.c this is the case where we cannot fold away the lo_anchor
      // cast; we just re-wire the input of expr to lo_anchor
      lo_anchor = replaceInputInCast(lo_anchor, starting_anchor);
      replaceInputInCast(expr->output(0), lo_anchor);
    }
  }
}

} // namespace

void ConsecutiveCastPass::runPass(Fusion* fusion) {
  castOptimizationPass(fusion);
}

} // namespace nvfuser::optimization
```
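
To make the anchor bookkeeping concrete, here is a hedged walk-through of the pass on a small chain (my sketch, not part of the diff; `castOp`, `DataType`, and `makeSymbolicTensor` are the usual nvFuser C++ frontend helpers):

```cpp
// Chain: tv0 (fp16) -> tv1 (fp32) -> tv2 (fp16) -> tv3 (fp32, fusion output)
auto tv0 = makeSymbolicTensor(2, DataType::Half);
auto tv1 = castOp(DataType::Float, tv0);
auto tv2 = castOp(DataType::Half, tv1);
auto tv3 = castOp(DataType::Float, tv2);

// The pass starts at the last cast (the one producing tv3) and collects
// chain_cast_vals = {tv1, tv2}; lo_anchor starts at tv0 (fp16).
//  - tv1 (fp32): isInclusiveType(fp16, fp32) holds -> skipped per 1.3.2.a
//  - tv2 (fp16): isInclusiveType(fp16, fp16) holds -> skipped per 1.3.2.a
// The output dtype (fp32) is wider than the fp16 anchor, so per 1.4.c the
// last cast is re-wired to read tv0 directly: a single fp16 -> fp32 cast.
//
// Had tv2 been bf16 instead, fp16 and bf16 are mutually non-inclusive, so
// branch 1.3.2.c would preserve both casts: tv0 -(bf16)-> tv2 -(fp32)-> tv3.
```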
**`optimization/consecutive_cast.h`** (new file, 22 lines)

```cpp
// clang-format off
/*
 * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
 * All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 */
// clang-format on
#pragma once

#include <optimization/optimization_pass.h>

namespace nvfuser::optimization {

//! ConsecutiveCastPass removes redundant consecutive cast operations that
//! don't have any impact on the output of the fusion.
class TORCH_CUDA_CU_API ConsecutiveCastPass
    : public OptimizationPass<ConsecutiveCastPass> {
  friend class OptimizationPass<ConsecutiveCastPass>;

 protected:
  static void runPass(Fusion* fusion);
};

} // namespace nvfuser::optimization
```
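
Note the `friend` + `protected` combination: user code cannot call `ConsecutiveCastPass::runPass` directly and has to go through the CRTP base, which checks the enable flag first. A short illustration (mine, not from the diff):

```cpp
// ConsecutiveCastPass::runPass(fusion);                 // error: runPass is protected
OptimizationPass<ConsecutiveCastPass>::runPass(fusion);  // OK: checks flag_, then dispatches
```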
**`optimization/optimization_pass.h`** (new file, 79 lines)

```cpp
// clang-format off
/*
 * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
 * All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 */
// clang-format on
#pragma once

#include <ir/interface_nodes.h>

#include <atomic>

namespace nvfuser::optimization {

using FusionPass = std::function<void(Fusion*)>;

//! [experimental API]
//! Base class to unify optimization pass APIs.
//! An OptimizationPass can be turned on/off programmatically with the
//! `setEnabled` API. There's a helper template, OptimizationPassGuard, to
//! temporarily switch the enablement within a context. Note that we use the
//! curiously recurring template pattern here to ensure that the static
//! objects are unique for each DerivedClass. In order to apply an
//! OptimizationPass with the enable switch honored, you need to run it via
//! `OptimizationPass<DerivedClass>::runPass(...)`.
//!
//! A specific optimization pass should be created like:
//!
//!   class TORCH_CUDA_CU_API Pass0 : public OptimizationPass<Pass0> {
//!     friend class OptimizationPass<Pass0>;
//!
//!    protected:
//!     static void runPass(Fusion* fusion);
//!   };
template <typename DerivedClass>
class TORCH_CUDA_CU_API OptimizationPass {
 public:
  static void setEnabled(bool enabled) {
    flag_.store(enabled);
  }

  static bool getEnabled() {
    return flag_.load();
  }

  static void runPass(Fusion* fusion) {
    if (!flag_.load()) {
      return;
    }
    DerivedClass::runPass(fusion);
  }

  virtual ~OptimizationPass() = default;

 protected:
  static inline std::atomic<bool> flag_{true};
};

//! [experimental API]
//! OptimizationPassGuard is used to temporarily switch enable/disable on a
//! certain pass. The original status will be restored at destruction.
template <typename OptPass>
class TORCH_CUDA_CU_API OptimizationPassGuard {
 public:
  OptimizationPassGuard(bool enabled) : prev_status_(OptPass::getEnabled()) {
    if (prev_status_ != enabled) {
      OptPass::setEnabled(enabled);
    }
  }
  ~OptimizationPassGuard() {
    OptPass::setEnabled(prev_status_);
  }

 protected:
  bool prev_status_ = false;
};

} // namespace nvfuser::optimization
```
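
As a usage sketch (mine, not from the diff), the guard composes with any concrete pass, e.g. temporarily disabling the cast-folding pass around a region of code:

```cpp
// assumes <optimization/consecutive_cast.h> is also included
void runWithoutCastFolding(Fusion* fusion) {
  // disable ConsecutiveCastPass for the lifetime of the guard
  OptimizationPassGuard<ConsecutiveCastPass> guard(/*enabled=*/false);
  // this call early-returns because the pass's flag_ is now false
  OptimizationPass<ConsecutiveCastPass>::runPass(fusion);
}  // guard destructor restores the previous enablement
```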
**`optimization/pre_segmenter.cpp`** (new file, 19 lines)

```cpp
// clang-format off
/*
 * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
 * All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 */
// clang-format on
#include <optimization/pre_segmenter.h>

#include <optimization/consecutive_cast.h>

namespace nvfuser::optimization {

void PreSegmenter::runPass(Fusion* fusion) {
  // removes consecutive cast operations
  OptimizationPass<ConsecutiveCastPass>::runPass(fusion);
}

} // namespace nvfuser::optimization
```
**`optimization/pre_segmenter.h`** (new file, 23 lines)

```cpp
// clang-format off
/*
 * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
 * All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 */
// clang-format on
#pragma once

#include <optimization/optimization_pass.h>

namespace nvfuser::optimization {

//! PreSegmenter is an optimization group that runs right before the fusion
//! executor segments a fusion into multiple kernels.
class TORCH_CUDA_CU_API PreSegmenter : public OptimizationPass<PreSegmenter> {
  friend class OptimizationPass<PreSegmenter>;

 protected:
  static void runPass(Fusion* fusion);
};

} // namespace nvfuser::optimization
```
> **Reviewer:** Could you add a test for this case?
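
A hedged sketch of what such a regression test could look like, assuming the `NVFuserTest` gtest fixture and the `makeSymbolicTensor` helper from the repo's existing test utilities (those names are assumptions, not part of this diff):

```cpp
#include <gtest/gtest.h>

#include <optimization/consecutive_cast.h>
// NVFuserTest fixture and makeSymbolicTensor are assumed to come from the
// existing nvFuser test utilities.

using namespace nvfuser;
using namespace nvfuser::optimization;

TEST_F(NVFuserTest, FusionConsecutiveCastFold_CUDA) {
  auto fusion = std::make_unique<Fusion>();
  FusionGuard fg(fusion.get());

  // fp16 -> fp32 -> fp16 -> fp32: the two middle casts should fold away.
  auto tv0 = makeSymbolicTensor(2, DataType::Half);
  fusion->addInput(tv0);
  auto tv1 = castOp(DataType::Float, tv0);
  auto tv2 = castOp(DataType::Half, tv1);
  auto tv3 = castOp(DataType::Float, tv2);
  fusion->addOutput(tv3);

  OptimizationPass<ConsecutiveCastPass>::runPass(fusion.get());

  // After the pass, the output should be produced by a single cast that
  // reads the fusion input directly.
  auto def = fusion->outputs()[0]->definition();
  ASSERT_NE(def, nullptr);
  EXPECT_EQ(def->input(0), static_cast<Val*>(tv0));
}
```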