Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions example/ck_tile/38_block_scale_gemm/gemm_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,8 @@ struct GemmConfigPreshuffleB_Bquant_prefill : public GemmConfigBase

static constexpr bool PreshuffleB = true;
static constexpr bool DoubleSmemBuffer = true;
static constexpr bool kPadN = true;
static constexpr int kBlockPerCu = 2;
};

template <typename ADataType_,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV

using Base::m_preload;

static constexpr index_t VectorLoadSize = Problem::VectorLoadSize;
static constexpr index_t QuantGroupSize = Problem::kQuantGroupSize;
static constexpr index_t KPerBlockBQ = BlockGemmShape::kK / QuantGroupSize;
static constexpr index_t QScalesPerBlockRow =
Expand All @@ -93,6 +94,54 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
// clang-format on
}

template <index_t nloop>
CK_TILE_HOST_DEVICE static constexpr auto HotLoopScheduler()
{
constexpr index_t Aload_inst =
(kMPerBlock * kKPerBlock * sizeof(ADataType)) / BlockSize / VectorLoadSize;
constexpr index_t Bload_inst =
(kKPerBlock * kNPerBlock * sizeof(BDataType)) / BlockSize / VectorLoadSize;
constexpr index_t BQload_inst = ck_tile::integer_divide_ceil(
ck_tile::integer_divide_ceil(kKPerBlock * kNPerBlock * sizeof(BQDataType),
QuantGroupSize * QuantGroupSize),
VectorLoadSize);

constexpr index_t buffer_load_inst = Aload_inst + Bload_inst + BQload_inst + 4;
constexpr index_t ds_read_inst = kMPerBlock / 8;
constexpr index_t ds_write_inst = Aload_inst;
constexpr index_t mfma_inst = (kMPerBlock / WG::kM) * (kNPerBlock / WG::kN);
constexpr index_t ds_rep = mfma_inst / (ds_read_inst + ds_write_inst);
constexpr index_t buffer_load_rep = mfma_inst / buffer_load_inst;
static_for<0, nloop, 1>{}([&](auto j_inst) {
ignore = j_inst;
static_for<0, mfma_inst, 1>{}([&](auto i_inst) {
__builtin_amdgcn_sched_group_barrier(LLVMSchedGroupMask::MFMA, 1, 0); // MFMA

if constexpr(i_inst % ds_rep == 0)
{
__builtin_amdgcn_sched_group_barrier(
LLVMSchedGroupMask::DS_READ, 1, 0); // DS read
}
if constexpr(i_inst % ds_rep == 1)
{
__builtin_amdgcn_sched_group_barrier(
LLVMSchedGroupMask::DS_WRITE, 1, 0); // DS write
}

if constexpr(i_inst % buffer_load_rep == 0)
{
if constexpr(ds_write_inst > 0)
{
__builtin_amdgcn_sched_group_barrier(
LLVMSchedGroupMask::VMEM_READ, 1, 0); // VMEM read
}
}
__builtin_amdgcn_sched_group_barrier(LLVMSchedGroupMask::VALU, 2, 0); // VALU
});
});
__builtin_amdgcn_sched_barrier(0);
}

static constexpr bool PreshuffleB = Problem::PreshuffleB;
static constexpr auto TailNum = Problem::TailNum;

Expand Down Expand Up @@ -302,24 +351,10 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV

// MAIN LOOP
index_t iCounter = (num_loop - 1) / 2;

while(iCounter > 0)
{
// prefetch B(2i+1)
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;

move_tile_window(b_flat_dram_windows(nIter)(kIter),
{nIter * flatNPerWarp, kIter * flatKPerWarp});
load_int4_tile<BDataType, ADataType, UnaryOpSize_>(
b_warp_tensor_pong(nIter)(kIter), b_flat_dram_windows(nIter)(kIter));
});
});
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});

bq_block_tile_2 = load_tile(bq_copy_dram_window);
move_tile_window(bq_copy_dram_window, {0, KPerBlockBQ});

__builtin_amdgcn_sched_barrier(0);
// Prefill A(2i+1)
a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
store_tile(a_copy_lds_window_pong, a_block_tile_tmp);
Expand All @@ -335,14 +370,28 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
b_warp_tensor_ping,
bq_block_tile,
a_warp_windows_ping);
// prefetch B(2i+1)
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;

move_tile_window(b_flat_dram_windows(nIter)(kIter),
{nIter * flatNPerWarp, kIter * flatKPerWarp});
load_int4_tile<BDataType, ADataType, UnaryOpSize_>(
b_warp_tensor_pong(nIter)(kIter), b_flat_dram_windows(nIter)(kIter));
});
});
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});

bq_block_tile_2 = load_tile(bq_copy_dram_window);
move_tile_window(bq_copy_dram_window, {0, KPerBlockBQ});

static_for<0, m_preload, 1>{}([&](auto loadIter) {
constexpr auto mIter = loadIter % MIterPerWarp;
constexpr auto kIter = loadIter / MIterPerWarp;
a_warp_tensor(loadIter) =
load_tile(a_warp_windows_pong(number<mIter>{})(number<kIter>{}));
});
Base::HotLoopScheduler();

// Next K

Expand Down Expand Up @@ -384,9 +433,8 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
a_warp_tensor(loadIter) =
load_tile(a_warp_windows_ping(number<mIter>{})(number<kIter>{}));
});
Base::HotLoopScheduler();

iCounter--;
HotLoopScheduler<2>();
}

// tail
Expand Down Expand Up @@ -424,15 +472,13 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
load_tile(a_warp_windows_pong(number<mIter>{})(number<kIter>{}));
});

Base::Last2ndHotLoopScheduler();

// GEMM loopK
block_weight_preshuffle(c_block_tile,
a_warp_tensor,
b_warp_tensor_pong,
bq_block_tile_2,
a_warp_windows_pong);
Base::LastHotLoopScheduler();
HotLoopScheduler<2>();
}
else if constexpr(TailNum == TailNumber::Odd)
{
Expand Down