Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions example/ck_tile/18_flatmm/run_moe_flatmm_example.inc
Original file line number Diff line number Diff line change
Expand Up @@ -304,8 +304,9 @@ int run_moe_gemm_example_with_layouts(int argc,

const float max_accumulated_value =
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
K, 1 /*kbatch*/, max_accumulated_value);
[[maybe_unused]] const auto rtol_atol =
calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
K, 1 /*kbatch*/, max_accumulated_value);
c_m_n_ref_buf->FromDevice(c_m_n_host_ref.data());

const float rtol = std::is_same_v<ADataType, ck_tile::half_t> && IsInputGemm ? 1e-3 : 1e-2;
Expand Down
4 changes: 4 additions & 0 deletions include/ck_tile/core/arch/generic_memory_space_atomic.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,9 @@ CK_TILE_DEVICE void atomic_add(X* p_dst, const X& x);
template <>
CK_TILE_DEVICE void atomic_add<bf16x2_t>(bf16x2_t* p_dst, const bf16x2_t& x)
{
#if HAS_GLOBAL_ATOMIC_PK_ADD_BUILTIN
__builtin_amdgcn_global_atomic_fadd_v2bf16(c_style_pointer_cast<bf16x2_t*>(p_dst), x);
#else
union U32BF162_ADDR
{
uint32_t* u32_a;
Expand All @@ -128,6 +131,7 @@ CK_TILE_DEVICE void atomic_add<bf16x2_t>(bf16x2_t* p_dst, const bf16x2_t& x)
new_v = new_.u32;
cur_v.u32 = atomicCAS(dword_addr.u32_a, old_v, new_v);
} while(cur_v.u32 != old_v);
#endif
}

template <>
Expand Down
8 changes: 6 additions & 2 deletions include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -623,7 +623,7 @@ struct MoeFlatmmKernel
{
return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
e_ptr,
make_tuple(IsInputGemm ? kargs.NumTokens * kargs.TopK : kargs.NumToken,
make_tuple(IsInputGemm ? kargs.NumTokens * kargs.TopK : kargs.NumTokens,
IsGateUp ? kargs.N / 2 : kargs.N),
make_tuple(1, kargs.stride_C),
number<1>{},
Expand Down Expand Up @@ -1250,6 +1250,8 @@ struct MoeFlatmmKernel
constexpr int MPerThread = TileEncodingPattern::Y2;
statically_indexed_array<statically_indexed_array<index_t, MPerThread>, NumMEpiTile>
c_scatter_offsets;
statically_indexed_array<statically_indexed_array<bool, MPerThread>, NumMEpiTile>
c_scatter_valids;
auto c_coord = dram_tile_distribution.calculate_index();
static_for<0, NumMEpiTile, 1>{}([&](auto mIter) {
static_for<0, MPerThread, 1>{}([&](auto m0) {
Expand All @@ -1262,6 +1264,7 @@ struct MoeFlatmmKernel
scatter_token_id =
scatter_token_id * kargs.TopK + (fused_token >> token_id_offset);
c_scatter_offsets[mIter][m0] = scatter_token_id * kargs.stride_C;
c_scatter_valids[mIter][m0] = (scatter_token_id < kargs.NumTokens);
});
});

Expand Down Expand Up @@ -1302,7 +1305,8 @@ struct MoeFlatmmKernel
c_block_window.get_window_lengths(),
c_block_window.get_window_origin(),
dram_tile_distribution,
c_scatter_offsets[mIter]);
c_scatter_offsets[mIter],
c_scatter_valids[mIter]);

if constexpr(!IsInputGemm ||
EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add)
Expand Down