Merged
94 commits
6772128
initial code shelve
jjsjann123 May 16, 2023
3bb0edb
wip add build files
jjsjann123 May 17, 2023
28b488a
fixing build
jjsjann123 May 17, 2023
6675a7e
adding optimization; adding test
jjsjann123 May 17, 2023
ec9827a
fixing tests
jjsjann123 May 17, 2023
623ae4f
fixing tests
jjsjann123 May 17, 2023
2c4d4c1
remove debug print
jjsjann123 May 17, 2023
1405827
Merge remote-tracking branch 'origin/main' into HEAD
jjsjann123 May 17, 2023
b2d9ed8
clangformat
jjsjann123 May 17, 2023
72da69c
short-cut to skip trivial casting
jjsjann123 May 17, 2023
b786160
fixing logic in cast operation. updating cpp tests
jjsjann123 May 17, 2023
816a5a0
fixing logic in safety check for cast optimization; fixing test
jjsjann123 May 18, 2023
875164f
a few knobs to switch optimization pass
jjsjann123 May 18, 2023
aca1abd
fixing logic
jjsjann123 May 18, 2023
de7dee7
fixing logic in disabling flags
jjsjann123 May 18, 2023
87d55df
patching tests
jjsjann123 May 18, 2023
c5dd6ad
fixing tests
jjsjann123 May 18, 2023
0344dfa
Merge remote-tracking branch 'origin/main' into HEAD
jjsjann123 May 18, 2023
d258a37
clangformat
jjsjann123 May 18, 2023
faa269a
clangtidy
jjsjann123 May 18, 2023
b42bba4
clangformat
jjsjann123 May 18, 2023
6949fef
CLANGTIDY
jjsjann123 May 18, 2023
47da287
typo
jjsjann123 May 18, 2023
f51481e
initial value for optimization pass
jjsjann123 May 18, 2023
56bd373
addressing review comments
jjsjann123 May 18, 2023
41b09c6
patching tests
jjsjann123 May 18, 2023
d5938e1
Merge remote-tracking branch 'origin/main' into HEAD
jjsjann123 May 18, 2023
3e404d1
typo
jjsjann123 May 18, 2023
81ec162
Merge remote-tracking branch 'origin/main' into HEAD
jjsjann123 May 19, 2023
e1c4e42
renaming header
jjsjann123 May 19, 2023
206915d
header change
jjsjann123 May 19, 2023
f1c1e6b
refactor cast opt pass
jjsjann123 May 21, 2023
44605e1
refactor registration
jjsjann123 May 22, 2023
a4fa8d0
Merge remote-tracking branch 'origin/main' into HEAD
jjsjann123 May 22, 2023
109a4d6
fixing typo and stuff
jjsjann123 May 22, 2023
36c32d6
fixing build issue
jjsjann123 May 22, 2023
9e81880
fixing static member in template
jjsjann123 May 22, 2023
2b849a3
clangformat
jjsjann123 May 22, 2023
f1eeafa
linter; fixing flag check for groups
jjsjann123 May 22, 2023
6ed2312
Merge remote-tracking branch 'origin/main' into HEAD
jjsjann123 May 22, 2023
ce88a4c
adding comment
jjsjann123 May 22, 2023
5c47da4
updating skip logic
jjsjann123 May 22, 2023
b8a56ca
comment
jjsjann123 May 22, 2023
f0ea352
clangformat
jjsjann123 May 22, 2023
d7cb3e3
Apply suggestions from code review
jjsjann123 May 23, 2023
352025e
addressing review comments
jjsjann123 May 23, 2023
249bb9e
Merge remote-tracking branch 'origin/main' into HEAD
jjsjann123 May 25, 2023
83f5889
WIP
jjsjann123 May 26, 2023
6fece9a
filling in implementation
jjsjann123 May 26, 2023
3d1ad95
fixing typo; updating logic
jjsjann123 May 26, 2023
082c7fc
patched tests
jjsjann123 May 27, 2023
ba66797
fixing typo
jjsjann123 May 27, 2023
8233762
fixing logic, prints in test
jjsjann123 May 27, 2023
a6066c3
more test case added
jjsjann123 May 27, 2023
b08054d
fixing tests
jjsjann123 May 27, 2023
3cdd2dc
changing container type
jjsjann123 May 27, 2023
48ac0cf
moving wide type check to type.h/cpp
jjsjann123 May 27, 2023
a198517
Merge remote-tracking branch 'origin/main' into HEAD
jjsjann123 May 27, 2023
527ec3d
lintrunner
jjsjann123 May 27, 2023
c3bcd35
removing print
jjsjann123 May 27, 2023
fcba589
clang-format clang-tidy
jjsjann123 May 27, 2023
0f7ddcc
clangformat
jjsjann123 May 27, 2023
5966632
review comment
jjsjann123 May 30, 2023
738a814
Merge remote-tracking branch 'origin/main' into HEAD
jjsjann123 May 30, 2023
9940dd5
Merge remote-tracking branch 'origin/main' into HEAD
jjsjann123 May 30, 2023
e95f898
addressing review comments
jjsjann123 May 30, 2023
0cf81a1
clangformat
jjsjann123 May 30, 2023
9b175ed
refactoring optimization passes
jjsjann123 May 31, 2023
7c8ea52
fixing typoe
jjsjann123 May 31, 2023
7846600
comment added
jjsjann123 May 31, 2023
1216b38
added more documentation
jjsjann123 May 31, 2023
10520e9
added test case with mixed dtype categories
jjsjann123 May 31, 2023
90fabd4
merging OptimizationGroup to OptimizationPass
jjsjann123 May 31, 2023
1501a76
renaming methods
jjsjann123 May 31, 2023
96c2398
fixing re-wiring bug
jjsjann123 May 31, 2023
80c98bc
clangformat
jjsjann123 May 31, 2023
0fdfd61
file renaming and class renaming
jjsjann123 May 31, 2023
1c3c6ce
moving tests to a separate file
jjsjann123 May 31, 2023
29f6403
fixing test; fixing context switch
jjsjann123 May 31, 2023
9b0fd8e
Merge remote-tracking branch 'origin/main' into HEAD
jjsjann123 May 31, 2023
861353e
Merge remote-tracking branch 'origin/main' into HEAD
jjsjann123 May 31, 2023
e87ce9a
code cleaning; review comments addressed
jjsjann123 May 31, 2023
0429b12
added missing test case
jjsjann123 Jun 1, 2023
3a8724e
reverting unwanted changes
jjsjann123 Jun 1, 2023
6153e2f
review comments
jjsjann123 Jun 1, 2023
a85e1a6
err test is wrong
jjsjann123 Jun 1, 2023
29a6137
quick refactor on redundant processing of short-wired casts
jjsjann123 Jun 1, 2023
e665635
type was wrong
jjsjann123 Jun 1, 2023
55241f6
fixing short-cut check to avoid segfault
jjsjann123 Jun 1, 2023
57cf5ea
clangformat
jjsjann123 Jun 1, 2023
e13eac1
Merge branch 'main' into cast_opt_pass
jjsjann123 Jun 1, 2023
a5bef36
clangtidy
jjsjann123 Jun 1, 2023
99c54b6
merging logic since we are only looking at straight-line logic now
jjsjann123 Jun 1, 2023
44a6aa3
patching the case where the cast needs to be removed
jjsjann123 Jun 1, 2023
9 changes: 4 additions & 5 deletions CMakeLists.txt
@@ -41,11 +41,7 @@ else() # PROJECT_IS_TOP_LEVEL
return()
endif()

-if(NOT USE_ROCM)
-  set(TORCHLIB_FLAVOR torch_cuda)
-else()
-  set(TORCHLIB_FLAVOR torch_hip)
-endif()
+set(TORCHLIB_FLAVOR torch_cuda)

# TODO: have TORCH_ROOT setup as a variable instead
# currently we are expecting nvfuser to be added from the pytorch root cmake file.
@@ -185,6 +181,8 @@ list(APPEND NVFUSER_SRCS
${NVFUSER_SRCS_DIR}/utils.cpp
${NVFUSER_SRCS_DIR}/mma_type.cpp
${NVFUSER_SRCS_DIR}/scheduler/mma_utils.cpp
${NVFUSER_SRCS_DIR}/optimization/consecutive_cast.cpp
${NVFUSER_SRCS_DIR}/optimization/pre_segmenter.cpp
)

set(NVFUSER_CODEGEN ${PROJECT_NAME}_codegen)
@@ -376,6 +374,7 @@ if(BUILD_TEST)
${NVFUSER_ROOT}/test/test_gpu_multidevice.cpp
${NVFUSER_ROOT}/test/test_multicluster_fusion.cpp
${NVFUSER_ROOT}/test/test_combined_inner_outer_reduction.cpp
${NVFUSER_ROOT}/test/test_optimization_pass.cpp
)
list(APPEND JIT_TEST_CU_SRCS ${NVFUSER_ROOT}/test/test_gpu_rng.cu)

4 changes: 4 additions & 0 deletions csrc/kernel_cache.cpp
@@ -11,6 +11,7 @@
#include <executor_params.h>
#include <instrumentation.h>
#include <ir/utils.h>
#include <optimization/pre_segmenter.h>
#include <parser.h>
#include <scheduler/debug_utils.h>
#include <scheduler/registry.h>
@@ -665,6 +666,9 @@ FusionKernelRuntime::FusionKernelRuntime(
!fusion->hasDynamicTransform(),
"Fusion must be concretized before constructing FusionKernelRuntime");

optimization::OptimizationPass<optimization::PreSegmenter>::runPass(
fusion.get());

all_tvs_ = ir_utils::allTvs(fusion.get());

// Run segmentation on the copied fusion
195 changes: 195 additions & 0 deletions csrc/optimization/consecutive_cast.cpp
@@ -0,0 +1,195 @@
// clang-format off
/*
* SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
* All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*/
// clang-format on
#include <optimization/consecutive_cast.h>

#include <ir/utils.h>

namespace nvfuser::optimization {

namespace {

bool isCast(Expr* expr) {
if (auto op = dynamic_cast<UnaryOp*>(expr)) {
return op->getUnaryOpType() == UnaryOpType::Cast;
}
return false;
}

// replaces the input to the cast op that produces cast_output, returns the new
// cast_output
Val* replaceInputInCast(Val* cast_output, Val* new_input) {
auto tmp_expr = cast_output->definition();
// short-cut for cases when no substitution is needed;
if (cast_output == new_input || new_input == tmp_expr->input(0)) {
return cast_output;
}
auto new_expr = nvfuser::ir_utils::replaceValInExpr(
tmp_expr, tmp_expr->input(0), new_input);
return new_expr->output(0);
}

// castOptimizationPass folds away consecutive cast operations, e.g. a chain of
// casts like fp16 -> fp32 -> fp16 can be simplified away without impacting the
// output of the fusion. For a section of consecutive casts, chained as:
// first_tv -> cast_tv0 -> cast_tv1 -> cast_tv2 -> output_tv
// Suppose we have a TensorView `lo_anchor` in the chain such that the dtype of
// every other tv is either the same as or wider than
// `lo_anchor->getDataType()`; we can then assume that all the other cast ops in
// the chain are no-ops (except the last one, which defines the output dtype),
// so the above chain can be re-wired with only two casts as:
// first_tv -> lo_anchor -> output_tv
//
// One complication is that we might not have a narrowest dtype in the chain,
// e.g. pairs like fp16/bfloat16 or fp32/int32, where neither can represent the
// other. To handle this scenario, we keep track of a `starting_anchor` that
// indicates the starting point of the section that has a valid `lo_anchor`.
// When we encounter a new cast op that breaks the assumption, we optimize what
// we have seen in the existing section and start a new section with the next
// cast op.
//
// The algorithm:
// 1. iterate through all exprs in the fusion:
// 1.1 we skip all exprs other than casts;
// 1.2 for each final cast-op `expr`, we trace back its producers iteratively
// and push the value(s) onto the front of `chain_cast_vals`, until:
//
// a. the producer is not a cast op; or
//
// b. the producer is used by other ops, or is a fusion output.
Review comment (Collaborator): Could you add a test for this case?

//
// 1.3 at this point, `chain_cast_vals` holds an ordered list of cast outputs
// with a straight-line dependency:
// 1.3.1 we point starting_anchor at the beginning op, indicating the
// starting point of our folding optimization; meanwhile, we point
// lo_anchor at the first op, indicating the narrowest dtype we have seen
// in the segment;
// 1.3.2 we iterate through the items inside `chain_cast_vals`; for each
// item `val`:
//
// a. if `val_dtype` is the same as, or wider than, `anchor_dtype`
// of `lo_anchor`, the current cast is a no-op and can be ignored;
//
// b. if `val_dtype` is narrower than `anchor_dtype`, the previous cast
// to `lo_anchor` is a no-op and can be folded away. We update
// `lo_anchor` to point to `val`;
//
// c. otherwise, `val` and `lo_anchor` are incompatible casts and
// both need to be preserved. We'll rewire the chain as
// `starting_anchor`->`lo_anchor`->`val`. Afterwards, we'll update
// `starting_anchor` and `lo_anchor` to both point at `val`.
//
// 1.4 At this point we look at `anchor_dtype` of `lo_anchor` and
// `output_dtype` of `expr->output(0)`:
//
// a. if `anchor_dtype` is the same as `output_dtype`, we skip the last
// cast op and replace all its uses with `lo_anchor`;
Review comment (Collaborator): Actually, I think this can be a special case of b, and handled the same way?

Reply (Author): sure. it's just merging that tiny logic in the two blocks.

//
// b. if `anchor_dtype` is wider than `output_dtype`, every cast
// after `starting_anchor` is a no-op; we re-wire `starting_anchor`
// directly to `expr`;
//
// c. otherwise, we can't bypass the `lo_anchor` cast; we rewire this
// section as `starting_anchor`->`lo_anchor`->`expr->output(0)`
void castOptimizationPass(Fusion* fusion) {
auto exprs = fusion->exprs();
std::unordered_set<Expr*> visited;
for (auto iter = exprs.rbegin(); iter != exprs.rend(); ++iter) {
auto expr = *iter;
// skip current expr if it's not a foldable cast or it has already been
// addressed
if (visited.count(expr) != 0 || !isCast(expr)) {
Review comment (@zasdfgbnm, Jun 1, 2023): nit:

Suggested change:
-if (visited.count(expr) != 0 || !isCast(expr)) {
+if (!isCast(expr) || visited.count(expr) != 0) {

Because I think isCast is faster to check and more likely to happen, so we should put it first to short-circuit the later, relatively expensive check.

Reply (Author): The visited check should be dirt cheap though, right?!

Also, unfortunately I was getting a segfault earlier with that. It's very surprising to me that apparently some fusion mutation somehow leaves expr in a limbo state, i.e. without the short-cut of visited.count(expr), the dynamic cast on expr gave me a segfault... which is pretty surprising.

Reply (Author): FYI, I think this is desired behavior, i.e. for an expr that has been short-wired, it's removed from the graph and that memory is also reused for something else?!

I'm just surprised that we already have that. I thought our mutation passes were only adding exprs...

Reply (Collaborator): "I thought our mutation passes are only adding exprs..." Haha, I am thinking the same.

But whatever, they are both super cheap. It's just cheap vs cheaper.

continue;
}
std::list<Val*> chain_cast_vals;
auto prev_expr = expr->input(0)->definition();
while (isCast(prev_expr)) {
auto intermediate_cast = prev_expr->output(0);
// 1.2 Note, if the output of prev_expr
// is used by other operation(s), or
// is a direct output of the fusion,
// we stop chaining the casts
if (intermediate_cast->isFusionOutput() ||
intermediate_cast->uses().size() > 1) {
break;
}

// add prev_expr to the visited set so we'll short-cut it later.
visited.insert(prev_expr);
// in the loop, we just repeatedly chain consecutive casts.
chain_cast_vals.push_front(intermediate_cast);
Review comment (Collaborator): Should we mark intermediate_cast as visited and skip this expr in the outer loop? I think your TODO is pretty easy to implement: just reverse the order of the outer loop and use a set marking all visited exprs.

Reply (Author): That's exactly what I have in mind. I just got a little bit uncomfortable with the traversal inside fusion->exprs(). 😆 I'll refactor this.

prev_expr = prev_expr->input(0)->definition();
}

// skip the current expr if chain_cast_vals is empty
if (chain_cast_vals.empty()) {
continue;
}

// 1.3.1 Note, chain_cast_vals has a straight-line use without branches
auto lo_anchor = chain_cast_vals.front()->definition()->input(0);
auto anchor_dtype = lo_anchor->getDataType().value();
auto starting_anchor = lo_anchor;
for (auto val : chain_cast_vals) {
auto val_dtype = val->getDataType().value();

// 1.3.2.a short-cut when we are not losing precision
if (isInclusiveType(anchor_dtype, val_dtype)) {
continue;
}

// 1.3.2.c NOTE: To enter here, we have
// !isInclusiveType(anchor_dtype, val_dtype) &&
// !isInclusiveType(val_dtype, anchor_dtype)
//
// which means the dtypes of lo_anchor and val aren't compatible and the
// casts can't be folded away without losing information. So we update
// starting_anchor to the current val, which ensures that we preserve the
// incompatible casts. e.g. for cases where neither type is strictly wider
// than the other: bf16 & fp16, int32 & float32, etc.
if (!isInclusiveType(val_dtype, anchor_dtype)) {
lo_anchor = replaceInputInCast(lo_anchor, starting_anchor);
val = replaceInputInCast(val, lo_anchor);
// We need to update starting_anchor so the folding starts past this
// current cast.
starting_anchor = val;
}
// 1.3.2.b/c update lo_anchor to the current val
lo_anchor = val;
anchor_dtype = lo_anchor->getDataType().value();
}

auto output_dtype = expr->output(0)->getDataType().value();
if (anchor_dtype == output_dtype) {
Review comment (Collaborator): Could you add a test for this case? I don't think the logic inside this if clause is correct. For example, I didn't see the casts between start and anchor being removed.

// 1.4.a the final cast has the same dtype as the previous lo_anchor;
// replace the output with lo_anchor in the fusion
lo_anchor = replaceInputInCast(lo_anchor, starting_anchor);
ir_utils::replaceValue(fusion, {{expr->output(0), lo_anchor}});
if (expr->output(0)->isFusionOutput()) {
fusion->replaceOutput(expr->output(0), lo_anchor);
}
} else if (isInclusiveType(output_dtype, anchor_dtype)) {
Review comment (Collaborator): Could you also add a test for this case?

// 1.4.b: if lo_anchor is wider than output_dtype, casting to lo_anchor
// isn't doing anything; we just fold the section away and cast
// directly from starting_anchor instead
replaceInputInCast(expr->output(0), starting_anchor);
} else {
// 1.4.c: This is the case where we cannot fold away the cast to
// lo_anchor; we just re-wire the input of expr to lo_anchor
lo_anchor = replaceInputInCast(lo_anchor, starting_anchor);
replaceInputInCast(expr->output(0), lo_anchor);
}
}
}

} // namespace

void ConsecutiveCastPass::runPass(Fusion* fusion) {
castOptimizationPass(fusion);
}

} // namespace nvfuser::optimization
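
The following is an illustrative sketch, not part of this diff, of the kind of cast chain the pass above folds and how the pass is invoked. It builds on the existing nvfuser C++ API (TensorViewBuilder, castOp, FusionGuard); the exact include paths are assumptions.

// Minimal sketch: a redundant fp16 -> fp32 -> fp16 -> fp32 cast chain that
// ConsecutiveCastPass should collapse into a single cast from the input.
#include <fusion.h>
#include <ops/all_ops.h>
#include <optimization/consecutive_cast.h>

#include <memory>

void consecutiveCastSketch() {
  using namespace nvfuser;
  auto fusion = std::make_unique<Fusion>();
  FusionGuard fg(fusion.get());

  // fp16 input
  TensorView* tv0 = TensorViewBuilder().ndims(2).dtype(DataType::Half).build();
  fusion->addInput(tv0);

  // Every intermediate dtype is at least as wide as fp16, so the middle casts
  // carry no extra information.
  auto tv1 = castOp(DataType::Float, tv0);
  auto tv2 = castOp(DataType::Half, tv1);
  auto tv3 = castOp(DataType::Float, tv2);
  fusion->addOutput(tv3);

  // After the pass, tv3 is expected to be produced by a single fp16 -> fp32
  // cast directly from tv0 (case 1.4.c in the comments above).
  optimization::OptimizationPass<optimization::ConsecutiveCastPass>::runPass(
      fusion.get());
}
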
22 changes: 22 additions & 0 deletions csrc/optimization/consecutive_cast.h
@@ -0,0 +1,22 @@
// clang-format off
/*
* SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
* All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*/
// clang-format on
#include <optimization/optimization_pass.h>

namespace nvfuser::optimization {

//! ConsecutiveCastPass removes redundant consecutive cast operations that
//! don't have any impact on the fusion's output.
class TORCH_CUDA_CU_API ConsecutiveCastPass
: public OptimizationPass<ConsecutiveCastPass> {
friend class OptimizationPass<ConsecutiveCastPass>;

protected:
static void runPass(Fusion* fusion);
};

} // namespace nvfuser::optimization
79 changes: 79 additions & 0 deletions csrc/optimization/optimization_pass.h
@@ -0,0 +1,79 @@
// clang-format off
/*
* SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
* All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*/
// clang-format on
#pragma once

#include <ir/interface_nodes.h>

#include <atomic>

namespace nvfuser::optimization {

using FusionPass = std::function<void(Fusion*)>;

//! [experimental API]
//! Base class to unify optimization pass APIs.
//! OptimizationPass can be turned on/off programmatically with the `setEnabled`
//! API. There's a helper template, OptimizationPassGuard, to temporarily switch
//! the enablement within a context. Note that we are using the curiously
//! recurring template pattern here to ensure that the static objects are unique
//! for each DerivedClass. To run an OptimizationPass with the enablement switch
//! respected, call
//! `OptimizationPass<DerivedClass>::runPass(...)`
//!
//! A specific optimization pass should be created like:
//!
//! class TORCH_CUDA_CU_API Pass0 : public OptimizationPass<Pass0> {
//! friend class OptimizationPass<Pass0>;
//!
//! protected:
//! static void runPass(Fusion* fusion);
//! };
template <typename DerivedClass>
class TORCH_CUDA_CU_API OptimizationPass {
public:
static void setEnabled(bool enabled) {
flag_.store(enabled);
}

static bool getEnabled() {
return flag_.load();
}

static void runPass(Fusion* fusion) {
if (!flag_.load()) {
return;
}
DerivedClass::runPass(fusion);
}

virtual ~OptimizationPass() = default;

protected:
static inline std::atomic<bool> flag_{true};
};

//! [experimental API]
//! OptimizationPassGuard is used to temporarily enable or disable a given
//! pass. The original status is restored at destruction.
template <typename OptPass>
class TORCH_CUDA_CU_API OptimizationPassGuard {
public:
OptimizationPassGuard(bool enabled) : prev_status_(OptPass::getEnabled()) {
if (prev_status_ != enabled) {
OptPass::setEnabled(enabled);
}
}
~OptimizationPassGuard() {
OptPass::setEnabled(prev_status_);
}

protected:
bool prev_status_ = false;
};

} // namespace nvfuser::optimization
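
As a usage note, a new pass plugs into the CRTP base above roughly as sketched below; this is not part of the diff, and the pass name RemoveNoopPass and its empty body are hypothetical.

#include <optimization/optimization_pass.h>

namespace nvfuser::optimization {

// Hypothetical pass: the CRTP base calls the protected static runPass of the
// derived class, gated by the per-pass enablement flag.
class TORCH_CUDA_CU_API RemoveNoopPass
    : public OptimizationPass<RemoveNoopPass> {
  friend class OptimizationPass<RemoveNoopPass>;

 protected:
  static void runPass(Fusion* fusion) {
    // the actual graph mutation would go here
  }
};

} // namespace nvfuser::optimization

void optimizationPassUsageSketch(nvfuser::Fusion* fusion) {
  using namespace nvfuser::optimization;

  // Runs the pass (passes are enabled by default).
  OptimizationPass<RemoveNoopPass>::runPass(fusion);

  // Temporarily disable it, e.g. inside a test; the previous enablement is
  // restored when the guard is destroyed.
  OptimizationPassGuard<RemoveNoopPass> guard(/*enabled=*/false);
  OptimizationPass<RemoveNoopPass>::runPass(fusion); // early-returns, no-op
}
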
19 changes: 19 additions & 0 deletions csrc/optimization/pre_segmenter.cpp
@@ -0,0 +1,19 @@
// clang-format off
/*
* SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
* All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*/
// clang-format on
#include <optimization/pre_segmenter.h>

#include <optimization/consecutive_cast.h>

namespace nvfuser::optimization {

void PreSegmenter::runPass(Fusion* fusion) {
// removes consecutive cast operations
OptimizationPass<ConsecutiveCastPass>::runPass(fusion);
}

} // namespace nvfuser::optimization
23 changes: 23 additions & 0 deletions csrc/optimization/pre_segmenter.h
@@ -0,0 +1,23 @@
// clang-format off
/*
* SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
* All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*/
// clang-format on
#pragma once

#include <optimization/optimization_pass.h>

namespace nvfuser::optimization {

//! PreSegmenter is an optimization group that runs right before the fusion
//! executor segments a fusion into multiple kernels.
class TORCH_CUDA_CU_API PreSegmenter : public OptimizationPass<PreSegmenter> {
friend class OptimizationPass<PreSegmenter>;

protected:
static void runPass(Fusion* fusion);
};

} // namespace nvfuser::optimization
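
For completeness, here is a sketch, again not part of the diff, of how a test in the spirit of the new test/test_optimization_pass.cpp might temporarily turn this group off; the helper shown is hypothetical.

#include <optimization/pre_segmenter.h>

// Hypothetical helper: run the PreSegmenter group on a fusion with it
// disabled, then again with the default (enabled) behavior.
void preSegmenterGuardSketch(nvfuser::Fusion* fusion) {
  using namespace nvfuser::optimization;
  {
    // Disabled only within this scope; the previous state is restored when
    // the guard goes out of scope.
    OptimizationPassGuard<PreSegmenter> guard(/*enabled=*/false);
    OptimizationPass<PreSegmenter>::runPass(fusion); // returns immediately
  }
  // Enabled again here.
  OptimizationPass<PreSegmenter>::runPass(fusion);
}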