diff --git a/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp b/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp index 0206aa88a8..be8e3d2bf8 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp @@ -184,6 +184,8 @@ struct GemmConfigPreshuffleB_Bquant_prefill : public GemmConfigBase static constexpr bool PreshuffleB = true; static constexpr bool DoubleSmemBuffer = true; + static constexpr bool kPadN = true; + static constexpr int kBlockPerCu = 2; }; template + CK_TILE_HOST_DEVICE static constexpr auto HotLoopScheduler() + { + constexpr index_t Aload_inst = + (kMPerBlock * kKPerBlock * sizeof(ADataType)) / BlockSize / VectorLoadSize; + constexpr index_t Bload_inst = + (kKPerBlock * kNPerBlock * sizeof(BDataType)) / BlockSize / VectorLoadSize; + constexpr index_t BQload_inst = ck_tile::integer_divide_ceil( + ck_tile::integer_divide_ceil(kKPerBlock * kNPerBlock * sizeof(BQDataType), + QuantGroupSize * QuantGroupSize), + VectorLoadSize); + + constexpr index_t buffer_load_inst = Aload_inst + Bload_inst + BQload_inst + 4; + constexpr index_t ds_read_inst = kMPerBlock / 8; + constexpr index_t ds_write_inst = Aload_inst; + constexpr index_t mfma_inst = (kMPerBlock / WG::kM) * (kNPerBlock / WG::kN); + constexpr index_t ds_rep = mfma_inst / (ds_read_inst + ds_write_inst); + constexpr index_t buffer_load_rep = mfma_inst / buffer_load_inst; + static_for<0, nloop, 1>{}([&](auto j_inst) { + ignore = j_inst; + static_for<0, mfma_inst, 1>{}([&](auto i_inst) { + __builtin_amdgcn_sched_group_barrier(LLVMSchedGroupMask::MFMA, 1, 0); // MFMA + + if constexpr(i_inst % ds_rep == 0) + { + __builtin_amdgcn_sched_group_barrier( + LLVMSchedGroupMask::DS_READ, 1, 0); // DS read + } + if constexpr(i_inst % ds_rep == 1) + { + __builtin_amdgcn_sched_group_barrier( + LLVMSchedGroupMask::DS_WRITE, 1, 0); // DS write + } + + if constexpr(i_inst % buffer_load_rep == 0) + { + if constexpr(ds_write_inst > 0) + { + __builtin_amdgcn_sched_group_barrier( + LLVMSchedGroupMask::VMEM_READ, 1, 0); // VMEM read + } + } + __builtin_amdgcn_sched_group_barrier(LLVMSchedGroupMask::VALU, 2, 0); // VALU + }); + }); + __builtin_amdgcn_sched_barrier(0); + } + static constexpr bool PreshuffleB = Problem::PreshuffleB; static constexpr auto TailNum = Problem::TailNum; @@ -302,24 +351,10 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV // MAIN LOOP index_t iCounter = (num_loop - 1) / 2; + while(iCounter > 0) { - // prefetch B(2i+1) - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; - - move_tile_window(b_flat_dram_windows(nIter)(kIter), - {nIter * flatNPerWarp, kIter * flatKPerWarp}); - load_int4_tile( - b_warp_tensor_pong(nIter)(kIter), b_flat_dram_windows(nIter)(kIter)); - }); - }); - move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock}); - - bq_block_tile_2 = load_tile(bq_copy_dram_window); - move_tile_window(bq_copy_dram_window, {0, KPerBlockBQ}); - + __builtin_amdgcn_sched_barrier(0); // Prefill A(2i+1) a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); store_tile(a_copy_lds_window_pong, a_block_tile_tmp); @@ -335,6 +370,21 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV b_warp_tensor_ping, bq_block_tile, a_warp_windows_ping); + // prefetch B(2i+1) + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; + + move_tile_window(b_flat_dram_windows(nIter)(kIter), + {nIter * flatNPerWarp, kIter * flatKPerWarp}); + load_int4_tile( + b_warp_tensor_pong(nIter)(kIter), b_flat_dram_windows(nIter)(kIter)); + }); + }); + move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock}); + + bq_block_tile_2 = load_tile(bq_copy_dram_window); + move_tile_window(bq_copy_dram_window, {0, KPerBlockBQ}); static_for<0, m_preload, 1>{}([&](auto loadIter) { constexpr auto mIter = loadIter % MIterPerWarp; @@ -342,7 +392,6 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV a_warp_tensor(loadIter) = load_tile(a_warp_windows_pong(number{})(number{})); }); - Base::HotLoopScheduler(); // Next K @@ -384,9 +433,8 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV a_warp_tensor(loadIter) = load_tile(a_warp_windows_ping(number{})(number{})); }); - Base::HotLoopScheduler(); - iCounter--; + HotLoopScheduler<2>(); } // tail @@ -424,15 +472,13 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV load_tile(a_warp_windows_pong(number{})(number{})); }); - Base::Last2ndHotLoopScheduler(); - // GEMM loopK block_weight_preshuffle(c_block_tile, a_warp_tensor, b_warp_tensor_pong, bq_block_tile_2, a_warp_windows_pong); - Base::LastHotLoopScheduler(); + HotLoopScheduler<2>(); } else if constexpr(TailNum == TailNumber::Odd) {