[CK_TILE] Remove scratch usage from universal gemm #2001

Merged
jakpiase merged 15 commits into develop from jakpiase/gemm_scratch_fix
May 5, 2025

jakpiase
Contributor

Proposed changes

Remove scratch usage from universal GEMM by moving the k_batch-related if condition outside of the kernel and passing the memory operation enum as a template parameter.
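
A minimal sketch of the pattern described above, not the actual ck_tile code: the runtime k_batch check is done once on the host, and the selected memory operation reaches the kernel-side code as a compile-time constant. The names MemOp, store_result, dispatch_on_k_batch and run_once are hypothetical stand-ins, and the sketch assumes C++17.

```cpp
#include <cassert>
#include <type_traits>

// Hypothetical stand-in for ck_tile::memory_operation_enum (values are assumptions).
enum class MemOp { set, atomic_add };

// Stand-in for kernel-side code: the store mode is a template parameter,
// so the branch is resolved at compile time and no runtime check remains.
template <MemOp Op>
void store_result(float& dst, float value)
{
    if constexpr(Op == MemOp::set)
        dst = value;  // single K batch: overwrite the output
    else
        dst += value; // split-K: accumulate partial results
}

// Host-side dispatch: the runtime k_batch condition is evaluated here,
// outside of the kernel, and converted into a compile-time tag.
template <typename Launch>
void dispatch_on_k_batch(int k_batch, Launch&& launch)
{
    if(k_batch == 1)
        launch(std::integral_constant<MemOp, MemOp::set>{});
    else
        launch(std::integral_constant<MemOp, MemOp::atomic_add>{});
}

// Usage helper: applies the selected store mode to a single output element.
inline float run_once(int k_batch, float initial, float value)
{
    float c = initial;
    dispatch_on_k_batch(k_batch,
                        [&](auto op) { store_result<decltype(op)::value>(c, value); });
    return c;
}
```

With this shape, each instantiation of the kernel-side code contains only one store path, which is the property the change relies on.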

Checklist

Please put an x into the boxes that apply. You can also fill these out after creating the PR. If you're not sure, please don't hesitate to ask.

  • I have added tests relevant to the introduced functionality, and the unit tests are passing locally
  • I have added the test to REGRESSION_TESTS list defined at the top of CMakeLists.txt in tests/CMakeLists.txt, IF the test takes more than 30 seconds to run.
  • I have added inline documentation which helps the maintainers understand the motivation
  • I have removed the stale documentation which is no longer relevant after this pull request
  • (If this change is user-facing) I have added release notes which provide the end users with a brief summary of the improvement from this pull request
  • I have run clang-format on all changed files
  • Any dependent changes have been merged

@@ -115,6 +119,24 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
return ave_time;
};

const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
Contributor

Can we refactor this to use the previous "run" method? We can create two epilogues:

  • GemmEpilogue
  • GemmEpilogueSplitK

Then create GemmKernel and GemmKernelSplitK and launch the appropriate one.

Collaborator

You would still have the if k_batch == 1 ... else logic inside. For me it's OK to do this in two stages. I'd rather make an alias template, using GemmEpilogue = ck_tile::CShuffleEpilogue<, which you would parameterize with TransposeC (because it's known only inside the lambda, from UniversalGemmProblem) and with memory_operation. All other types are known earlier.
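
The alias-template suggestion can be sketched roughly as follows. EpilogueSketch and its parameter list are hypothetical; the real ck_tile::CShuffleEpilogue takes different template arguments. The point is only that the early-known types are fixed once, while TransposeC and the memory operation stay open until they are known.

```cpp
#include <cassert>
#include <type_traits>

// Hypothetical stand-in for ck_tile::memory_operation_enum.
enum class MemOp { set, atomic_add };

// Hypothetical epilogue type; this parameter list is an assumption made
// for illustration and does not mirror ck_tile::CShuffleEpilogue.
template <typename AccType, typename CType, bool TransposeC, MemOp Op>
struct EpilogueSketch
{
    static constexpr bool kTransposeC = TransposeC;
    static constexpr MemOp kMemOp     = Op;
};

// Alias template: the types known early (float, float here) are bound once;
// only the two late parameters remain to be supplied.
template <bool TransposeC, MemOp Op>
using GemmEpilogue = EpilogueSketch<float, float, TransposeC, Op>;

// The two variants discussed in the review differ only in the memory operation.
using PlainEpilogue  = GemmEpilogue<false, MemOp::set>;
using SplitKEpilogue = GemmEpilogue<false, MemOp::atomic_add>;

static_assert(PlainEpilogue::kMemOp == MemOp::set);
static_assert(SplitKEpilogue::kMemOp == MemOp::atomic_add);
```

Inside the dispatching lambda, both late parameters would then be filled in at the single place where they become available.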

{
Run_with_k_batch(has_hot_loop_,
tail_number_,
ck_tile::integral_constant<ck_tile::memory_operation_enum,
Contributor

The same for each operator

static constexpr index_t kMPerIteration = kMPerXdl * kMWave;
static constexpr index_t kNPerIteration = kNPerXdl * kNWave;
using CLayout = remove_cvref_t<typename Problem::CLayout>;
static constexpr memory_operation_enum MemoryOperation = Problem::MemoryOperation;
Collaborator

If we keep this information as a class member, you no longer need to pass it as an operator() template parameter.
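
As an illustration of this remark, here is a small sketch in which the memory operation is read from the problem type as a static member, so operator() needs no template parameter of its own. ProblemSketch and EpilogueWithMember are hypothetical names, and the float arithmetic only stands in for the real store logic (C++17 assumed).

```cpp
#include <cassert>

// Hypothetical stand-in for ck_tile::memory_operation_enum.
enum class MemOp { set, atomic_add };

// Hypothetical problem descriptor carrying the memory operation.
template <MemOp Op>
struct ProblemSketch
{
    static constexpr MemOp MemoryOperation = Op;
};

template <typename Problem>
struct EpilogueWithMember
{
    // Taken from the problem type once, at class scope.
    static constexpr MemOp MemoryOperation = Problem::MemoryOperation;

    // operator() is a plain member function; the store mode comes from
    // the class member rather than from a per-call template argument.
    void operator()(float& dst, float value) const
    {
        if constexpr(MemoryOperation == MemOp::set)
            dst = value;  // overwrite
        else
            dst += value; // accumulate
    }
};
```

Call sites then read `EpilogueWithMember<ProblemSketch<MemOp::atomic_add>>{}(c, v)` without repeating the memory operation at every call.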

@jakpiase jakpiase force-pushed the jakpiase/gemm_scratch_fix branch from 952699e to 1bec2e5 Compare April 7, 2025 15:54
Collaborator

@aosewski aosewski left a comment

Almost there ;)

@jakpiase jakpiase merged commit 0bcb804 into develop May 5, 2025
34 of 38 checks passed
@jakpiase jakpiase deleted the jakpiase/gemm_scratch_fix branch May 5, 2025 16:46
3 participants