[CK_TILE] Remove scratch usage from universal gemm #2001
@@ -115,6 +119,24 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
        return ave_time;
    };

    const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
Can we refactor this to use the previous "run" method? We can create two epilogues:
- GemmEpilogue
- GemmEpilogueSplitK
Then create GemmKernel and GemmKernelSplitK and launch the appropriate one.
You would still have the if k_batch == 1 ... else logic inside. For me it's OK to do this in two stages. I'd rather make an alias template, using GemmEpilogue = ck_tile::CShuffleEpilogue<...>, which you would parameterize with TransposeC (because it's known only in the lambda, from UniversalGemmProblem) and with memory_operation. All other types are known earlier.
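For illustration, the alias-template idea could look roughly like the sketch below. This is not the ck_tile code: ToyEpilogue, MemoryOp, AccType and OutType are hypothetical stand-ins, and the real CShuffleEpilogue takes a different parameter list. The only point is that the types known early are fixed once, while TransposeC and the memory operation stay open until they are known.

```cpp
#include <type_traits>

// Hypothetical stand-in for ck_tile::memory_operation_enum.
enum class MemoryOp { set, atomic_add };

// Hypothetical epilogue: two parameters are known early, two only later.
template <typename AccType, typename OutType, bool TransposeC, MemoryOp Op>
struct ToyEpilogue
{
    static constexpr bool kTransposeC          = TransposeC;
    static constexpr MemoryOp kMemoryOperation = Op;
};

// Types that are known before the lambda is entered (assumed for this sketch).
using AccType = float;
using OutType = float;

// Alias template: fix the early types, leave only the late parameters open.
template <bool TransposeC, MemoryOp Op>
using GemmEpilogue = ToyEpilogue<AccType, OutType, TransposeC, Op>;

// The two variants discussed in this thread differ only in the memory operation.
using PlainEpilogue  = GemmEpilogue<false, MemoryOp::set>;
using SplitKEpilogue = GemmEpilogue<false, MemoryOp::atomic_add>;

static_assert(PlainEpilogue::kMemoryOperation == MemoryOp::set);
static_assert(std::is_same_v<SplitKEpilogue,
                             ToyEpilogue<float, float, false, MemoryOp::atomic_add>>);
```

With an alias like this, the call site only spells out the two late parameters instead of repeating the full epilogue type for each variant.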
    {
        Run_with_k_batch(has_hot_loop_,
                         tail_number_,
                         ck_tile::integral_constant<ck_tile::memory_operation_enum,
The same for each operator
    static constexpr index_t kMPerIteration = kMPerXdl * kMWave;
    static constexpr index_t kNPerIteration = kNPerXdl * kNWave;
    using CLayout = remove_cvref_t<typename Problem::CLayout>;
    static constexpr memory_operation_enum MemoryOperation = Problem::MemoryOperation;
If we put this information in as a class member, you no longer need it as an operator() template parameter.
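A minimal sketch of what this suggestion amounts to, using hypothetical names (ToyProblem, ToyEpilogue, MemoryOp) rather than the actual ck_tile types: once the epilogue class reads the memory operation from its Problem, operator() can be a plain member function that consults the class-level constant. The += below merely stands in for an atomic add.

```cpp
// Hypothetical stand-in for ck_tile::memory_operation_enum.
enum class MemoryOp { set, atomic_add };

// Hypothetical problem descriptor carrying the memory operation.
template <MemoryOp Op>
struct ToyProblem
{
    static constexpr MemoryOp MemoryOperation = Op;
};

template <typename Problem>
struct ToyEpilogue
{
    // The operation is a class-level constant taken from the Problem.
    static constexpr MemoryOp MemoryOperation = Problem::MemoryOperation;

    // No template parameter on operator(): the branch is resolved at
    // compile time from the class member.
    void operator()(float& dst, float value) const
    {
        if constexpr(MemoryOperation == MemoryOp::set)
            dst = value;
        else
            dst += value; // stand-in for an atomic add of a partial result
    }
};
```

A caller then writes ToyEpilogue<ToyProblem<MemoryOp::set>>{}(c, value) and never passes the operation at the call site.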
I still see this as a tparam here:
https://github.com/ROCm/composable_kernel/pull/2001/files#diff-a2466bfef61d871813a8a210b406e7db886cc294935562fe0076017c0e7f75aaR127
Force-pushed from 952699e to 1bec2e5.
Almost there ;)
Proposed changes
Remove scratch usage from universal GEMM by moving the k_batch-related if condition outside of the kernel and passing the memory operation enum as a template parameter.
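As a rough host-side illustration of that change (a sketch under assumptions, not the code in this PR): the runtime check on k_batch is done once before launch, and the chosen memory operation is handed down as a compile-time constant, so the kernel body contains no runtime branch on it. MemoryOp, toy_kernel and launch are hypothetical names, std::integral_constant stands in for ck_tile::integral_constant, and += stands in for an atomic add.

```cpp
#include <type_traits>

// Hypothetical stand-in for ck_tile::memory_operation_enum.
enum class MemoryOp { set, atomic_add };

template <MemoryOp Op>
using MemoryOpConstant = std::integral_constant<MemoryOp, Op>;

// Toy "kernel": the memory operation is a template parameter, so the
// choice between set and accumulate is made at compile time.
template <MemoryOp Op>
void toy_kernel(float& c, float partial)
{
    if constexpr(Op == MemoryOp::set)
        c = partial;
    else
        c += partial; // stand-in for an atomic add of a split-K partial sum
}

// Host-side dispatch: the k_batch condition lives outside the kernel.
inline void launch(float& c, float partial, int k_batch)
{
    const auto run = [&](auto memory_op) {
        toy_kernel<decltype(memory_op)::value>(c, partial);
    };

    if(k_batch == 1)
        run(MemoryOpConstant<MemoryOp::set>{});
    else
        run(MemoryOpConstant<MemoryOp::atomic_add>{});
}
```

Calling launch(c, partial, 1) overwrites c, while any other k_batch accumulates into it; each path instantiates a separate toy_kernel specialization.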
Checklist
Please put an x into the boxes that apply. You can also fill these out after creating the PR. If you're not sure, please don't hesitate to ask.
- clang-format on all changed files