-
Notifications
You must be signed in to change notification settings - Fork 7
[MatMul] loop interleaving pass to interleave double buffered unrolled loops #1975
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: index_codegen
Are you sure you want to change the base?
Changes from all commits
5a6aa34
68ce333
9918f3d
222e053
bfc41f4
e648eea
a98f510
8a262b0
531d3e2
2d66bd3
77fdf76
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,7 +17,7 @@ unsigned int getDoubleBufferAxisPosition(const TensorView* tv) { | |
// which defines the loop where prefetching is applied. Therefore, | ||
// the CA position must be larger than 0. | ||
|
||
TORCH_INTERNAL_ASSERT(tv->getComputeAtPosition() > 0); | ||
TORCH_INTERNAL_ASSERT(tv->getComputeAtPosition() > 0, tv->toString()); | ||
|
||
// Unroll must not exist outside of double-buffer axis | ||
auto first_unroll_it = std::find_if( | ||
|
@@ -337,7 +337,10 @@ class DoubleBufferLoopCloner : public kir::IrVisitor { | |
} | ||
} | ||
|
||
if (stage_depth > 2) { | ||
// Need to insert commits for multi-stage circular buffering | ||
// on the prologs, but do not need to wait for them until | ||
// the main loop. | ||
if (stage_depth > 2 && loop_type_ == DoubleBufferLoopStage::Prolog) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is this a generic bug fix or is it related to the interleaving transformation? |
||
cloned_top_level_loop_->body().push_back( | ||
IrBuilder::create<kir::CpAsyncCommit>()); | ||
} | ||
|
@@ -821,6 +824,10 @@ class DoubleBufferInserter : private kir::ExprMutator { | |
main_loop->iter_domain()); | ||
auto cp_async_wait = IrBuilder::create<kir::CpAsyncWait>(stage_depth - 2); | ||
|
||
// Make sure the commit is inserted right before the | ||
// cp.async.wait in circular buffering. | ||
bool need_insert_commit = stage_depth > 2; | ||
|
||
// Check if a sync has been inserted by WAR sync pass. | ||
auto block_sync_it = std::find_if( | ||
main_loop->body().exprs().rbegin(), | ||
|
@@ -832,10 +839,18 @@ class DoubleBufferInserter : private kir::ExprMutator { | |
// it can just be anywhere in the loop. Chose to | ||
// place at the end arbitrarily. | ||
main_loop->body().insert_after(end_of_loop_expr, cp_async_wait); | ||
if (need_insert_commit) { | ||
main_loop->body().insert_after( | ||
end_of_loop_expr, IrBuilder::create<kir::CpAsyncCommit>()); | ||
} | ||
} else { | ||
// If a sync has been inserted, wait needs to be placed | ||
// before the sync. | ||
main_loop->body().insert_before(*block_sync_it, cp_async_wait); | ||
if (need_insert_commit) { | ||
main_loop->body().insert_before( | ||
*block_sync_it, IrBuilder::create<kir::CpAsyncCommit>()); | ||
} | ||
Comment on lines
+850
to
+853
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not completely following what should be done here, but the above comment on |
||
} | ||
} | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Add comments on the pair