Skip to content

[CK_TILE] Add universal gemm mem skip A/B LDS pipelines #2056

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 11 commits into
base: develop
Choose a base branch
from

Conversation

jakpiase
Copy link
Contributor

@jakpiase jakpiase commented Apr 5, 2025

Proposed changes

[CK_TILE] Add universal gemm mem skip A/B LDS pipelines for tall and skinny gemms.

Checklist

Please put an x into the boxes that apply. You can also fill these out after creating the PR. If you're not sure, please don't hesitate to ask.

  • I have added tests relevant to the introduced functionality, and the unit tests are passing locally
  • I have added the test to REGRESSION_TESTS list defined at the top of CMakeLists.txt in tests/CMakeLists.txt, IF the test takes more than 30 seconds to run.
  • I have added inline documentation which enables the maintainers with understanding the motivation
  • I have removed the stale documentation which is no longer relevant after this pull request
  • (If this change is user-facing) I have added release notes which provide the end users with a brief summary of the improvement from this pull request
  • I have run clang-format on all changed files
  • Any dependent changes have been merged

Copy link
Collaborator

@aosewski aosewski left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good start! But I think we should work more on reusing existing code. Additionally I think we should have skipping A/B LDS functionality controlled (turned on/off) from policy. It would be better if we would have single pipeline (mem in this case) which could be configured by policy to either use LDS for any of inputs, both or none.

Comment on lines 156 to 158
template <typename ADramBlockWindowTmp, typename ALdsLoadTileDistr>
CK_TILE_DEVICE constexpr auto
GetADramWindowSkipLds(const ADramBlockWindowTmp& a_dram_block_window_tmp,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you please get rid of this tmp suffix throughout this file ? :)

namespace ck_tile {

template <typename Derived>
struct UniversalGemmSkipBLdsBasePolicy
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can't you actually reuse here this class:

You don't have to use all it's functionality.
Maybe you could even refactor it out to separate file?

auto b_lds_block = Base::GetBLdsTensorView(p_smem);

// Tile distribution for load from lds
constexpr auto a_lds_load_tile_distr = decltype(make_static_tile_distribution(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The variable name here is misleading, since you skip A LDS.


// LDS write 0
// TODO add a colmajor support
static_assert(is_a_col_major == false, "AColMajor not supported yet!");
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the problem with col-major? You can reuse available logic of reading DRAM->VGPR and then transpose tile if it's in col-major.

using ALdsTile = decltype(make_static_distributed_tensor<ComputeDataType>(ALdsTileDistr));
using BLdsTile = decltype(make_static_distributed_tensor<ComputeDataType>(BLdsTileDistr));

ALdsTile a_warp_tile_;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In this case you actually even don't need this.

"The ADataType and BDataType as defined in "
"traits should be the same as correspoinding block window data type!");

a_warp_tile_.get_thread_buffer() = a_block_tensor.get_thread_buffer();
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can just use a_block_tensor.

@aosewski
Copy link
Collaborator

And don't forget to update CHANGELOG.md

@jakpiase jakpiase requested a review from a team as a code owner May 5, 2025 18:10
@@ -16,6 +16,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj
* Added GEMM pipeline for microscaling (MX) data types
* Added support for FP16 2:4 structured sparsity to universal GEMM.
* Added support for Split K for grouped convolution backward data.
* Added support for skipping LDS to universal GEMM
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this supposed to be "to"? Because this sounds like the LDS is being skipped to go straight to universal GEMM, which doesn't sound quite right.

Is it maybe supposed to be "for" or "in" or "with"?

As in support's been added for skipping LDS when using universal GEMM?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could be "for" or "in".

Copy link
Collaborator

@aosewski aosewski left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a step into a good direction, however I feel like there's still a lot we can improve. I wonder about just single block gemm version named like: BlockUniversalGemmAxBxCr. From pipeline problem you could get information about skipping A/B LDS, and leverage it in local prefetching of A/B. Other than that core parts of the code would remain the same.

@@ -16,6 +16,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj
* Added GEMM pipeline for microscaling (MX) data types
* Added support for FP16 2:4 structured sparsity to universal GEMM.
* Added support for Split K for grouped convolution backward data.
* Added support for skipping LDS to universal GEMM
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could be "for" or "in".

Comment on lines 58 to 61
static constexpr bool TransposeC = TransposeC_;
static constexpr bool SkipALds = SkipALds_;
static constexpr bool SkipBLds = SkipBLds_;
static constexpr bool UseStructuredSparsity = UseStructuredSparsity_;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you please add doc to all those members ?

Comment on lines 29 to 32
static constexpr bool TransposeC = false;
static constexpr bool SkipALds = false;
static constexpr bool SkipBLds = false;
static constexpr bool UseStructuredSparsity = false;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you please add doc to all those members ?

Comment on lines +103 to +104
constexpr bool SkipALds = false;
constexpr bool SkipBLds = false;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should rather be parameterized in tests or you should create a separate test suite for that.

return make_tile_window(a_dram_block_window.get_bottom_tensor_view(),
make_tuple(YPerTile{}, XPerTile{}),
a_dram_block_window.get_window_origin(),
ALdsLoadTileDistr{});
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This might not necessarily be the optimal solution. I imagine that regardless of using LDS we should read global memory in with an optimal vectorized access pattern. Then if it's needed you would just adapt the data layout in registers.

}

// C = A * B
template <typename ARegBlockTensor, typename BSmemBlockWindow>
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
template <typename ARegBlockTensor, typename BSmemBlockWindow>
template <typename ARegBlockTensor, typename BRegBlockWindow>

// C = A * B
template <typename ARegBlockTensor, typename BSmemBlockWindow>
CK_TILE_DEVICE auto operator()(const ARegBlockTensor& a_block_tensor,
const BSmemBlockWindow& b_block_window)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
const BSmemBlockWindow& b_block_window)
const BRegBlockWindow& b_block_window)

"traits should be the same as correspoinding block window data type!");

// hot loop:
static_for<0, GemmTraits::KIterPerWarp, 1>{}([&](auto kIter) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
static_for<0, GemmTraits::KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {

};

template <typename GemmTraits>
struct BlockGemmImpl<GemmPipelineScheduler::Intrawave, GemmTraits>
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like this one is entirely same as the Default one now, so you can just derive from it.


if constexpr(std::is_same_v<BDataType, pk_int4_t>)
{
load_interleaved_pk_type(b_warp_tile_, b_block_window);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like we have support for B preshuffle with int4 packed data type only when we load B to LDS ... i think we should be able to support this in all situations.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants