diff --git a/example/ck_tile/18_flatmm/run_moe_flatmm_example.inc b/example/ck_tile/18_flatmm/run_moe_flatmm_example.inc
index d898ed2f29..4303acec5a 100644
--- a/example/ck_tile/18_flatmm/run_moe_flatmm_example.inc
+++ b/example/ck_tile/18_flatmm/run_moe_flatmm_example.inc
@@ -302,10 +302,6 @@ int run_moe_gemm_example_with_layouts(int argc,
         static_cast(per_token_scale_dev_buf.GetDeviceBuffer()),
         static_cast(per_channel_scale_dev_buf.GetDeviceBuffer()));
 
-    const float max_accumulated_value =
-        *std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
-    const auto rtol_atol = calculate_rtol_atol(
-        K, 1 /*kbatch*/, max_accumulated_value);
     c_m_n_ref_buf->FromDevice(c_m_n_host_ref.data());
 
     const float rtol = std::is_same_v && IsInputGemm ? 1e-3 : 1e-2;
diff --git a/include/ck_tile/core/arch/generic_memory_space_atomic.hpp b/include/ck_tile/core/arch/generic_memory_space_atomic.hpp
index e56bcadcba..0ff97bb9a7 100644
--- a/include/ck_tile/core/arch/generic_memory_space_atomic.hpp
+++ b/include/ck_tile/core/arch/generic_memory_space_atomic.hpp
@@ -102,6 +102,9 @@ CK_TILE_DEVICE void atomic_add(X* p_dst, const X& x);
 template <>
 CK_TILE_DEVICE void atomic_add(bf16x2_t* p_dst, const bf16x2_t& x)
 {
+#if HAS_GLOBAL_ATOMIC_PK_ADD_BUILTIN
+    __builtin_amdgcn_global_atomic_fadd_v2bf16(c_style_pointer_cast(p_dst), x);
+#else
     union U32BF162_ADDR
     {
         uint32_t* u32_a;
@@ -128,6 +131,7 @@ CK_TILE_DEVICE void atomic_add(bf16x2_t* p_dst, const bf16x2_t& x)
         new_v     = new_.u32;
         cur_v.u32 = atomicCAS(dword_addr.u32_a, old_v, new_v);
     } while(cur_v.u32 != old_v);
+#endif
 }
 
 template <>
diff --git a/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp b/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp
index 8a9aa3cdd3..fb98a71b0f 100644
--- a/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp
+++ b/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp
@@ -623,7 +623,7 @@ struct MoeFlatmmKernel
     {
         return make_naive_tensor_view<address_space_enum::global>(
             e_ptr,
-            make_tuple(IsInputGemm ? kargs.NumTokens * kargs.TopK : kargs.NumToken,
+            make_tuple(IsInputGemm ? kargs.NumTokens * kargs.TopK : kargs.NumTokens,
                        IsGateUp ? kargs.N / 2 : kargs.N),
             make_tuple(1, kargs.stride_C),
             number<1>{},
@@ -1250,6 +1250,8 @@ struct MoeFlatmmKernel
         constexpr int MPerThread = TileEncodingPattern::Y2;
         statically_indexed_array<statically_indexed_array<index_t, MPerThread>, NumMEpiTile>
             c_scatter_offsets;
+        statically_indexed_array<statically_indexed_array<bool, MPerThread>, NumMEpiTile>
+            c_scatter_valids;
         auto c_coord = dram_tile_distribution.calculate_index();
         static_for<0, NumMEpiTile, 1>{}([&](auto mIter) {
             static_for<0, MPerThread, 1>{}([&](auto m0) {
@@ -1262,6 +1264,7 @@ struct MoeFlatmmKernel
                 scatter_token_id =
                     scatter_token_id * kargs.TopK + (fused_token >> token_id_offset);
                 c_scatter_offsets[mIter][m0] = scatter_token_id * kargs.stride_C;
+                c_scatter_valids[mIter][m0] = (scatter_token_id < kargs.NumTokens);
             });
         });
 
@@ -1302,7 +1305,8 @@ struct MoeFlatmmKernel
                 c_block_window.get_window_lengths(),
                 c_block_window.get_window_origin(),
                 dram_tile_distribution,
-                c_scatter_offsets[mIter]);
+                c_scatter_offsets[mIter],
+                c_scatter_valids[mIter]);
 
             if constexpr(!IsInputGemm ||
                          EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add)
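
For readers outside the ck_tile tree, the sketch below restates the CAS fallback that this patch keeps behind the `#else` branch as a self-contained HIP function. This is a minimal sketch, not the library code: the wrapper name `atomic_add_bf16x2` and the use of `__hip_bfloat162` from `<hip/hip_bf16.h>` are stand-ins for ck_tile's `bf16x2_t` machinery.

```cpp
#include <hip/hip_runtime.h>
#include <hip/hip_bf16.h>
#include <cstdint>

// Read-modify-write one packed bf16x2 (a single 4-byte word) with atomicCAS.
__device__ void atomic_add_bf16x2(__hip_bfloat162* p_dst, const __hip_bfloat162 x)
{
    union DwordView // view the same 32 bits as a dword or as two bf16 lanes
    {
        uint32_t        u32;
        __hip_bfloat162 bf162;
    };

    uint32_t* dword_addr = reinterpret_cast<uint32_t*>(p_dst);
    DwordView cur;
    cur.u32 = *dword_addr; // initial snapshot of the destination word

    uint32_t old_v, new_v;
    do
    {
        old_v = cur.u32;
        DwordView next;
        // Widen each lane to float for the add, then narrow back to bf16.
        next.bf162.x = __float2bfloat16(__bfloat162float(cur.bf162.x) + __bfloat162float(x.x));
        next.bf162.y = __float2bfloat16(__bfloat162float(cur.bf162.y) + __bfloat162float(x.y));
        new_v   = next.u32;
        // If another thread changed the word since our snapshot, retry from
        // the freshly observed value returned by atomicCAS.
        cur.u32 = atomicCAS(dword_addr, old_v, new_v);
    } while(cur.u32 != old_v);
}
```

On targets that expose the packed builtin (gfx90a-class and newer CDNA parts), the patch instead issues a single `__builtin_amdgcn_global_atomic_fadd_v2bf16`, avoiding the retry loop entirely; the CAS path remains for everything else.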
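The one-character fix in `moe_flatmm_kernel.hpp` (`kargs.NumToken` to `kargs.NumTokens`) also documents the output-view shape. A hedged restatement of the row-count logic, with `output_view_rows` and its parameter names as illustrative stand-ins: the first (gate/up) GEMM appears to emit one row per (token, expert) pair, while the second GEMM's view is sized per token, its TopK partial rows being combined, in this kernel, via the atomic-add epilogue.

```cpp
#include <cstdint>

// Row count of the output tensor view, mirroring the fixed ternary in the kernel.
// Assumption (for illustration): gemm-1 scatters one row per (token, expert) pair;
// gemm-2 reduces those TopK rows back to one row per token.
constexpr std::int64_t output_view_rows(bool is_input_gemm,
                                        std::int64_t num_tokens,
                                        std::int64_t top_k)
{
    return is_input_gemm ? num_tokens * top_k : num_tokens;
}

static_assert(output_view_rows(true, 128, 2) == 256);  // gate/up GEMM
static_assert(output_view_rows(false, 128, 2) == 128); // down GEMM, post-reduction
```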
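The functional change in the epilogue is the new per-row validity mask: alongside each scatter offset, every thread now records whether the scattered row index is in range, and passes both arrays to the scatter window so out-of-range rows are masked off rather than written. Below is a minimal, self-contained sketch of that precompute-offsets-plus-valids pattern; `sorted_token_ids`, `stride_c`, and the bounds test against `num_valid_rows` are illustrative stand-ins, not the kernel's exact arguments.

```cpp
#include <array>
#include <cstdint>

constexpr int kRowsPerThread = 4; // plays the role of MPerThread in the tile pattern

struct ScatterRows
{
    std::array<std::int32_t, kRowsPerThread> offsets; // destination offset per row
    std::array<bool, kRowsPerThread>         valids;  // false for padding rows
};

// Precompute where each of this thread's rows lands, plus a mask. Expert-sorted
// token ids are typically padded up to a block multiple, so some entries point
// past the real token range and must not be stored.
inline ScatterRows make_scatter_rows(const std::int32_t* sorted_token_ids,
                                     int row_base,
                                     std::int32_t stride_c,
                                     std::int32_t num_valid_rows)
{
    ScatterRows out{};
    for(int m = 0; m < kRowsPerThread; ++m)
    {
        const std::int32_t row = sorted_token_ids[row_base + m];
        out.offsets[m] = row * stride_c;
        out.valids[m]  = row < num_valid_rows; // guard consumed at store time
    }
    return out;
}

// Consumers then guard each store, e.g.:
//   if(rows.valids[m]) dst[rows.offsets[m] + n] = value[m];
```

This mirrors the extra `c_scatter_valids[mIter]` argument now threaded into the `c_block_window` scatter update in the hunk above.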