
[CK_TILE] Add universal gemm mem skip A/B LDS pipelines #2056


Open · wants to merge 11 commits into base: develop
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -16,6 +16,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj
* Added GEMM pipeline for microscaling (MX) data types
* Added support for FP16 2:4 structured sparsity to universal GEMM.
* Added support for Split K for grouped convolution backward data.
* Added support for skipping LDS to universal GEMM
Contributor commented:

Is this supposed to be "to"? Because this sounds like the LDS is being skipped to go straight to universal GEMM, which doesn't sound quite right.

Is it maybe supposed to be "for" or "in" or "with"?

As in support's been added for skipping LDS when using universal GEMM?

Collaborator replied:

Could be "for" or "in".


### Optimized

9 changes: 9 additions & 0 deletions example/ck_tile/03_gemm/gemm_utils.hpp
@@ -52,6 +52,9 @@ struct GemmConfig
static constexpr ck_tile::index_t K_Warp_Tile = 8;

static constexpr bool DoubleSmemBuffer = false;

static constexpr bool SkipALds = true;
static constexpr bool SkipBLds = true;
#endif
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3)
// Compute friendly for Intrawave scheduler
@@ -68,6 +71,9 @@ struct GemmConfig
static constexpr ck_tile::index_t K_Warp_Tile = 32;

static constexpr bool DoubleSmemBuffer = false;

static constexpr bool SkipALds = false;
static constexpr bool SkipBLds = false;
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4)
// Compute friendly for Intrawave scheduler
// Using the ping pong reader in the lds level
@@ -84,6 +90,9 @@ struct GemmConfig
static constexpr ck_tile::index_t K_Warp_Tile = 16;

static constexpr bool DoubleSmemBuffer = true;

static constexpr bool SkipALds = false;
static constexpr bool SkipBLds = false;
#endif

static constexpr bool kPadM = false;
2 changes: 2 additions & 0 deletions example/ck_tile/03_gemm/run_gemm_example.inc
@@ -52,6 +52,8 @@ void permute_tensor_b(Tensor& tensor)
GemmConfig::kPadN,
GemmConfig::kPadK,
GemmConfig::DoubleSmemBuffer,
GemmConfig::SkipALds,
GemmConfig::SkipBLds,
ALayout,
BLayout,
CLayout,
2 changes: 2 additions & 0 deletions example/ck_tile/03_gemm/universal_gemm.cpp
@@ -43,6 +43,8 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
GemmConfig::kPadN,
GemmConfig::kPadK,
GemmConfig::DoubleSmemBuffer,
GemmConfig::SkipALds,
GemmConfig::SkipBLds,
ALayout,
BLayout,
CLayout,
3 changes: 3 additions & 0 deletions include/ck_tile/ops/gemm.hpp
@@ -22,6 +22,9 @@
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_problem.hpp"
#include "ck_tile/ops/gemm/block/block_universal_gemm_ar_br_cr.hpp"
#include "ck_tile/ops/gemm/block/block_universal_gemm_ar_bs_cr.hpp"
#include "ck_tile/ops/gemm/block/block_universal_gemm_as_br_cr.hpp"
#include "ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp"
#include "ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp"
#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp"
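The four block_universal_gemm_* headers above cover every combination of operand placement: "s" seems to denote an operand staged through shared memory (LDS), "r" one kept in registers, and "cr" the C accumulator in registers. Only BlockUniversalGemmAsBsCr appears in the visible diff, so the other struct names below are inferred from the new file names. A minimal, self-contained sketch (stand-in structs, not the real CK_TILE types) of how a policy could dispatch on the new SkipALds/SkipBLds flags:

#include <cstdio>

// Stand-ins for the four variants; the real structs live in the
// block_universal_gemm_{as_bs,as_br,ar_bs,ar_br}_cr.hpp headers added here.
template <typename P> struct BlockUniversalGemmAsBsCr { static constexpr const char* desc = "A: LDS, B: LDS"; };
template <typename P> struct BlockUniversalGemmAsBrCr { static constexpr const char* desc = "A: LDS, B: registers"; };
template <typename P> struct BlockUniversalGemmArBsCr { static constexpr const char* desc = "A: registers, B: LDS"; };
template <typename P> struct BlockUniversalGemmArBrCr { static constexpr const char* desc = "A: registers, B: registers"; };

// Hypothetical compile-time dispatch on the new flags.
template <typename Problem>
constexpr auto select_block_gemm()
{
    if constexpr(Problem::SkipALds && Problem::SkipBLds)
        return BlockUniversalGemmArBrCr<Problem>{};
    else if constexpr(Problem::SkipALds)
        return BlockUniversalGemmArBsCr<Problem>{};
    else if constexpr(Problem::SkipBLds)
        return BlockUniversalGemmAsBrCr<Problem>{};
    else
        return BlockUniversalGemmAsBsCr<Problem>{};
}

struct DemoProblem
{
    static constexpr bool SkipALds = true; // mirrors GemmConfig::SkipALds
    static constexpr bool SkipBLds = true; // mirrors GemmConfig::SkipBLds
};

int main()
{
    std::printf("%s\n", decltype(select_block_gemm<DemoProblem>())::desc);
    return 0;
}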
458 changes: 458 additions & 0 deletions include/ck_tile/ops/gemm/block/block_universal_gemm_ar_br_cr.hpp

Large diffs are not rendered by default.

526 changes: 526 additions & 0 deletions include/ck_tile/ops/gemm/block/block_universal_gemm_ar_bs_cr.hpp

Large diffs are not rendered by default.

539 changes: 539 additions & 0 deletions include/ck_tile/ops/gemm/block/block_universal_gemm_as_br_cr.hpp

Large diffs are not rendered by default.

58 changes: 40 additions & 18 deletions include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp
@@ -215,7 +215,7 @@ struct BlockUniversalGemmAsBsCr
using BLdsTile = decltype(make_static_distributed_tensor<ComputeDataType>(BLdsTileDistr));

ALdsTile a_warp_tile_;
ALdsTile b_warp_tile_;
BLdsTile b_warp_tile_;

// C += A * B
template <typename CBlockTensor, typename ASmemBlockWindow, typename BSmemBlockWindow>
@@ -300,9 +300,8 @@ struct BlockUniversalGemmAsBsCr
ALdsTile a_warp_tile_;
BLdsTile b_warp_tile_;

template <typename ASmemBlockWindow, typename BSmemBlockWindow>
CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window,
const BSmemBlockWindow& b_block_window)
template <typename ASmemBlockWindow>
CK_TILE_DEVICE void LocalPrefetchA(const ASmemBlockWindow& a_block_window)
{
if constexpr(std::is_same_v<ADataType, pk_int4_t>)
{
@@ -312,6 +311,11 @@
{
load_tile(a_warp_tile_, a_block_window);
}
}

template <typename BSmemBlockWindow>
CK_TILE_DEVICE void LocalPrefetchB(const BSmemBlockWindow& b_block_window)
{
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
{
load_interleaved_pk_type(b_warp_tile_, b_block_window);
@@ -390,27 +394,18 @@ struct BlockUniversalGemmAsBsCr
using BLdsTile = decltype(make_static_distributed_tensor<ComputeDataType>(BLdsTileDistr));

ALdsTile a_warp_tile_;
ALdsTile b_warp_tile_;
BLdsTile b_warp_tile_;

template <index_t KIdx, typename ASmemBlockWindow, typename BSmemBlockWindow>
CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window,
const BSmemBlockWindow& b_block_window)
template <index_t KIdx, typename ASmemBlockWindow>
CK_TILE_DEVICE void LocalPrefetchA(const ASmemBlockWindow& a_block_window)
{
constexpr auto a_lds_load_tile_distr =
make_static_tile_distribution(MakeABlockDistributionEncode());
constexpr auto b_lds_load_tile_distr =
make_static_tile_distribution(MakeBBlockDistributionEncode());

auto a_lds_gemm_window = make_tile_window(
a_block_window.get_bottom_tensor_view(),
make_tuple(number<GemmTraits::MPerBlock>{}, number<KPerInnerLoop>{}),
{0, KIdx * KPerInnerLoop},
a_lds_load_tile_distr);
auto b_lds_gemm_window = make_tile_window(
b_block_window.get_bottom_tensor_view(),
make_tuple(number<GemmTraits::NPerBlock>{}, number<KPerInnerLoop>{}),
{0, KIdx * KPerInnerLoop},
b_lds_load_tile_distr);

if constexpr(std::is_same_v<ADataType, pk_int4_t>)
{
@@ -420,6 +415,19 @@
{
load_tile(a_warp_tile_, a_lds_gemm_window);
}
}

template <index_t KIdx, typename BSmemBlockWindow>
CK_TILE_DEVICE void LocalPrefetchB(const BSmemBlockWindow& b_block_window)
{
constexpr auto b_lds_load_tile_distr =
make_static_tile_distribution(MakeBBlockDistributionEncode());
auto b_lds_gemm_window = make_tile_window(
b_block_window.get_bottom_tensor_view(),
make_tuple(number<GemmTraits::NPerBlock>{}, number<KPerInnerLoop>{}),
{0, KIdx * KPerInnerLoop},
b_lds_load_tile_distr);

if constexpr(std::is_same_v<BDataType, pk_int4_t>)
{
load_interleaved_pk_type(b_warp_tile_, b_block_window);
@@ -442,7 +450,8 @@

// hot loop:
static_for<0, KRepeat, 1>{}([&](auto kIter) {
LocalPrefetch<kIter.value>(a_block_window, b_block_window);
LocalPrefetchA<kIter.value>(a_block_window);
LocalPrefetchB<kIter.value>(b_block_window);
__builtin_amdgcn_sched_barrier(0);
// NOTE: Synchronize threads in a workgroup at the start of each MAC
// cluster, but except the first, as we can shorten non-MAC cluster a bit
@@ -547,7 +556,20 @@ struct BlockUniversalGemmAsBsCr
CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window,
const BSmemBlockWindow& b_block_window)
{
block_gemm_impl_.LocalPrefetch(a_block_window, b_block_window);
block_gemm_impl_.LocalPrefetchA(a_block_window);
block_gemm_impl_.LocalPrefetchB(b_block_window);
}

template <typename ASmemBlockWindow>
CK_TILE_DEVICE void LocalPrefetchA(const ASmemBlockWindow& a_block_window)
{
block_gemm_impl_.LocalPrefetchA(a_block_window);
}

template <typename BSmemBlockWindow>
CK_TILE_DEVICE void LocalPrefetchB(const BSmemBlockWindow& b_block_window)
{
block_gemm_impl_.LocalPrefetchB(b_block_window);
}

// C += A * B
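The point of splitting LocalPrefetch into LocalPrefetchA and LocalPrefetchB, as done above, is that a skip-LDS pipeline can prefetch only the operand that was actually staged in shared memory. A host-side mock of the calling pattern this enables (the types and flags are illustrative placeholders, not the CK_TILE API):

#include <cstdio>

// Mock of a block GEMM exposing the split prefetch entry points.
struct MockBlockGemm
{
    template <typename Window>
    void LocalPrefetchA(const Window&) { std::printf("prefetch A from LDS\n"); }
    template <typename Window>
    void LocalPrefetchB(const Window&) { std::printf("prefetch B from LDS\n"); }
};

struct LdsWindow {}; // placeholder for a ck_tile LDS tile window

// A pipeline that skips B's LDS stage prefetches only A; B is consumed
// directly from the registers it was loaded into from global memory.
template <bool SkipALds, bool SkipBLds>
void hot_loop_iteration(MockBlockGemm& gemm, const LdsWindow& a, const LdsWindow& b)
{
    if constexpr(!SkipALds)
        gemm.LocalPrefetchA(a);
    if constexpr(!SkipBLds)
        gemm.LocalPrefetchB(b);
}

int main()
{
    MockBlockGemm gemm;
    LdsWindow a, b;
    hot_loop_iteration<false, true>(gemm, a, b); // prints only the A prefetch
    return 0;
}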
61 changes: 43 additions & 18 deletions include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp
@@ -37,8 +37,8 @@ struct GemmPipelineAgBgCrImplBase
const SrcBlockTile& src_block_tile,
const ElementFunction& element_func) const
{
const auto block_tile_tmp = tile_elementwise_in(element_func, src_block_tile);
store_tile(lds_tile_window, block_tile_tmp);
const auto block_tile = tile_elementwise_in(element_func, src_block_tile);
store_tile(lds_tile_window, block_tile);
}

template <typename DstBlockTile, typename SrcTileWindow>
@@ -56,8 +56,11 @@
auto a_lds_block = make_tensor_view<address_space_enum::lds>(p_a_lds, a_lds_block_desc);

// TODO: LDS alignment should come from Policy!
constexpr index_t a_lds_block_space_size_aligned = integer_least_multiple(
sizeof(ADataType) * a_lds_block_desc.get_element_space_size(), 16);
constexpr index_t a_lds_block_space_size_aligned =
Problem::SkipALds
? 0
: integer_least_multiple(
sizeof(ADataType) * a_lds_block_desc.get_element_space_size(), 16);

// B tile in LDS
BDataType* __restrict__ p_b_lds = static_cast<BDataType*>(
@@ -68,8 +71,8 @@
return make_tuple(std::move(a_lds_block), std::move(b_lds_block));
}

template <typename ADramBlockWindowTmp, typename ALdsTensorView, typename ALdsLoadTileDistr>
CK_TILE_DEVICE constexpr auto GetAWindows(const ADramBlockWindowTmp& a_dram_block_window_tmp,
template <typename ADramBlockWindow, typename ALdsTensorView, typename ALdsLoadTileDistr>
CK_TILE_DEVICE constexpr auto GetAWindows(const ADramBlockWindow& a_dram_block_window,
const ALdsTensorView& a_lds_block_view,
const ALdsLoadTileDistr&) const
{
@@ -79,11 +82,22 @@
using XPerTile = std::conditional_t<is_col_major, number<MPerBlock>, number<KPerBlock>>;

// A DRAM tile window for load
auto a_copy_dram_window =
make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(YPerTile{}, XPerTile{}),
a_dram_block_window_tmp.get_window_origin(),
Policy::template MakeADramTileDistribution<Problem>());
decltype(auto) a_copy_dram_window = [&] {
if constexpr(Problem::SkipALds == false)
{
return make_tile_window(a_dram_block_window.get_bottom_tensor_view(),
make_tuple(YPerTile{}, XPerTile{}),
a_dram_block_window.get_window_origin(),
Policy::template MakeADramTileDistribution<Problem>());
}
else
{
return make_tile_window(a_dram_block_window.get_bottom_tensor_view(),
make_tuple(YPerTile{}, XPerTile{}),
a_dram_block_window.get_window_origin(),
ALdsLoadTileDistr{});
Collaborator commented:

This might not necessarily be the optimal solution. I imagine that regardless of using LDS we should read global memory with an optimal vectorized access pattern. Then, if needed, you would just adapt the data layout in registers.

}
}();

// A LDS tile window for store
auto a_copy_lds_window = make_tile_window(
@@ -100,8 +114,8 @@
std::move(a_lds_gemm_window));
}
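To see the reviewer's suggestion in isolation: keep the global-memory read vectorized (whole contiguous chunks), and produce whatever layout the compute side wants afterwards with a register-level permute. The toy below uses plain arrays instead of ck_tile windows and tile distributions, so every size and layout in it is illustrative only:

#include <array>
#include <cstdio>

int main()
{
    // Pretend global memory: a 2x8 row-major tile.
    std::array<float, 16> dram{};
    for(int i = 0; i < 16; ++i)
        dram[i] = static_cast<float>(i);

    // Step 1: vectorized loads - each inner loop models one float4 load of a
    // contiguous chunk, i.e. the DRAM-optimal access pattern.
    std::array<float, 16> regs{};
    for(int v = 0; v < 4; ++v)
        for(int lane = 0; lane < 4; ++lane)
            regs[v * 4 + lane] = dram[v * 4 + lane];

    // Step 2: in-register relayout - e.g. a transpose into the column-major
    // view a skip-LDS block GEMM might expect.
    std::array<float, 16> compute{};
    for(int r = 0; r < 2; ++r)
        for(int c = 0; c < 8; ++c)
            compute[c * 2 + r] = regs[r * 8 + c];

    std::printf("compute[0..3] = %g %g %g %g\n",
                compute[0], compute[1], compute[2], compute[3]);
    return 0;
}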

template <typename BDramBlockWindowTmp, typename BLdsTensorView, typename BLdsLoadTileDistr>
CK_TILE_DEVICE constexpr auto GetBWindows(const BDramBlockWindowTmp& b_dram_block_window_tmp,
template <typename BDramBlockWindow, typename BLdsTensorView, typename BLdsLoadTileDistr>
CK_TILE_DEVICE constexpr auto GetBWindows(const BDramBlockWindow& b_dram_block_window,
const BLdsTensorView& b_lds_block_view,
const BLdsLoadTileDistr&) const
{
@@ -110,11 +124,22 @@
using YPerTile = std::conditional_t<is_row_major, number<KPerBlock>, number<NPerBlock>>;
using XPerTile = std::conditional_t<is_row_major, number<NPerBlock>, number<KPerBlock>>;

auto b_copy_dram_window =
make_tile_window(b_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(YPerTile{}, XPerTile{}),
b_dram_block_window_tmp.get_window_origin(),
Policy::template MakeBDramTileDistribution<Problem>());
decltype(auto) b_copy_dram_window = [&] {
if constexpr(Problem::SkipBLds == false)
{
return make_tile_window(b_dram_block_window.get_bottom_tensor_view(),
make_tuple(YPerTile{}, XPerTile{}),
b_dram_block_window.get_window_origin(),
Policy::template MakeBDramTileDistribution<Problem>());
}
else
{
return make_tile_window(b_dram_block_window.get_bottom_tensor_view(),
make_tuple(YPerTile{}, XPerTile{}),
b_dram_block_window.get_window_origin(),
BLdsLoadTileDistr{});
Collaborator commented:

As above for GetAWindows

}
}();

// TODO: Do we really need those two tile windows???
// They're exactly the same...