Skip to content
Open
Show file tree
Hide file tree
Changes from 16 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions example/ck_tile/01_fmha/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ set(FMHA_FWD_INSTANCES "tile_fmha_fwd_instances")
set(FMHA_BWD_INSTANCES "tile_fmha_bwd_instances")

message(DEBUG "adding instances ${FMHA_FWD_INSTANCES}")
# to save build time, exclude the target from "all" target of "01_fmha" directory and its ancestors
add_library(${FMHA_FWD_INSTANCES} OBJECT EXCLUDE_FROM_ALL)
target_include_directories(${FMHA_FWD_INSTANCES} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
target_sources(${FMHA_FWD_INSTANCES} PRIVATE ${FMHA_FWD_GEN_BLOBS})
Expand Down
2 changes: 1 addition & 1 deletion example/ck_tile/02_layernorm2d/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ add_custom_command(
set(EXAMPLE_LAYERNORM2D_FWD "tile_example_layernorm2d_fwd")

message(DEBUG "adding example ${EXAMPLE_LAYERNORM2D_FWD}")
add_executable(${EXAMPLE_LAYERNORM2D_FWD} EXCLUDE_FROM_ALL layernorm2d_fwd.cpp)
add_executable(${EXAMPLE_LAYERNORM2D_FWD} layernorm2d_fwd.cpp)
target_include_directories(${EXAMPLE_LAYERNORM2D_FWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
target_sources(${EXAMPLE_LAYERNORM2D_FWD} PRIVATE ${LAYERNORM2D_FWD_GEN_BLOBS})

Expand Down
40 changes: 21 additions & 19 deletions example/ck_tile/03_gemm/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,20 +1,22 @@
add_executable(tile_example_gemm_basic EXCLUDE_FROM_ALL gemm_basic.cpp)
add_executable(tile_example_gemm_universal EXCLUDE_FROM_ALL universal_gemm.cpp)
add_executable(tile_example_gemm_weight_preshuffle EXCLUDE_FROM_ALL gemm_weight_preshuffle.cpp)
add_executable(tile_example_gemm_reduce EXCLUDE_FROM_ALL gemm_splitk_two_stage_reduce.cpp)
add_executable(tile_example_gemm_splitk_two_stage EXCLUDE_FROM_ALL gemm_splitk_two_stage.cpp)
set(EXAMPLE_GEMM_COMPILE_OPTIONS)
set(EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS)
if(CK_USE_OCP_FP8)
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx90a|gfx11|gfx12")
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TBD: provide the list of valid GPU architectures for each target.

add_executable(tile_example_gemm_basic gemm_basic.cpp)
add_executable(tile_example_gemm_universal universal_gemm.cpp)
add_executable(tile_example_gemm_weight_preshuffle gemm_weight_preshuffle.cpp)
add_executable(tile_example_gemm_reduce gemm_splitk_two_stage_reduce.cpp)
add_executable(tile_example_gemm_splitk_two_stage gemm_splitk_two_stage.cpp)
set(EXAMPLE_GEMM_COMPILE_OPTIONS)
set(EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS)
if(CK_USE_OCP_FP8)
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
endif()
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -mllvm -enable-noalias-to-md-conversion=0)
list(APPEND EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS -Wno-unused-local-typedef)
list(APPEND EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS -Wno-gnu-line-marker)
list(APPEND EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS --save-temps)
list(APPEND EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm -enable-noalias-to-md-conversion=0")
target_compile_options(tile_example_gemm_basic PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
target_compile_options(tile_example_gemm_universal PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
target_compile_options(tile_example_gemm_weight_preshuffle PRIVATE ${EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS})
target_compile_options(tile_example_gemm_reduce PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
target_compile_options(tile_example_gemm_splitk_two_stage PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
endif()
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -mllvm -enable-noalias-to-md-conversion=0)
list(APPEND EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS -Wno-unused-local-typedef)
list(APPEND EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS -Wno-gnu-line-marker)
list(APPEND EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS --save-temps)
list(APPEND EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm -enable-noalias-to-md-conversion=0")
target_compile_options(tile_example_gemm_basic PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
target_compile_options(tile_example_gemm_universal PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
target_compile_options(tile_example_gemm_weight_preshuffle PRIVATE ${EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS})
target_compile_options(tile_example_gemm_reduce PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
target_compile_options(tile_example_gemm_splitk_two_stage PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
2 changes: 1 addition & 1 deletion example/ck_tile/03_gemm/gemm_splitk_two_stage_reduce.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -683,7 +683,7 @@ int run_gemm_example_with_layouts_two_stage(ck_tile::ArgParser& arg_parser,

if constexpr(preshuffle)
{
ck_tile::HostTensor<BDataType> b_shuffle_host = shuffle_b<GemmConfig>(b_k_n);
ck_tile::HostTensor<BDataType> b_shuffle_host = ck_tile::shuffle_b<GemmConfig>(b_k_n);
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

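Prior to C++20, a function template called with explicit template arguments (e.g. `shuffle_b<GemmConfig>(...)`) is not found by argument-dependent lookup unless a template of that name is already visible through ordinary lookup, so the call must be qualified explicitly with `ck_tile::`.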

// shuffled buffer B for device implementation
b_k_n_dev_buf.ToDevice(b_shuffle_host.data());
}
Expand Down
4 changes: 2 additions & 2 deletions example/ck_tile/03_gemm/run_gemm_example.inc
Original file line number Diff line number Diff line change
Expand Up @@ -283,12 +283,12 @@ int run_gemm_example_with_layouts(ck_tile::ArgParser& arg_parser,
if constexpr(GemmConfig::TiledMMAPermuteN)
{
std::cout << "Run with PermuteN" << std::endl;
return shuffle_b_permuteN<GemmConfig>(b_k_n);
return ck_tile::shuffle_b_permuteN<GemmConfig>(b_k_n);
}
else
{
std::cout << "Run without PermuteN" << std::endl;
return shuffle_b<GemmConfig>(b_k_n);
return ck_tile::shuffle_b<GemmConfig>(b_k_n);
}
}();
// shuffled buffer B for device implementation
Expand Down
2 changes: 1 addition & 1 deletion example/ck_tile/04_img2col/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# not using add_example_executable() to add this target, since we don't want this to have
# to be included in "make all/install/check"
add_executable(tile_example_img2col EXCLUDE_FROM_ALL image_to_column.cpp)
add_executable(tile_example_img2col image_to_column.cpp)
4 changes: 2 additions & 2 deletions example/ck_tile/05_reduce/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ set(EXAMPLE_REDUCE "tile_example_reduce")
# to be included in "make all/install/check"
message(DEBUG "adding example ${EXAMPLE_REDUCE}")

add_executable(${EXAMPLE_REDUCE} EXCLUDE_FROM_ALL reduce.cpp)
add_executable(${EXAMPLE_REDUCE} reduce.cpp)
target_include_directories(${EXAMPLE_REDUCE} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
set(EXAMPLE_REDUCE_COMPILE_OPTIONS)

Expand All @@ -16,4 +16,4 @@ target_compile_options(${EXAMPLE_REDUCE} PRIVATE ${EXAMPLE_REDUCE_COMPILE_OPTION
# by cmake will print too many files, execvp: /bin/sh: Argument list too long
# however, this property may affect global
# TODO: consider codegen a makefile by us
set_property(GLOBAL PROPERTY RULE_MESSAGES OFF)
set_property(GLOBAL PROPERTY RULE_MESSAGES OFF)
2 changes: 1 addition & 1 deletion example/ck_tile/06_permute/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# not using add_example_executable() to add this target, since we don't want this to have
# to be included in "make all/install/check"
add_executable(tile_example_permute EXCLUDE_FROM_ALL permute.cpp)
add_executable(tile_example_permute permute.cpp)

if(NOT DEFINED PERMUTE_USE_ALTERNATIVE_IMPL)
# set(PERMUTE_USE_ALTERNATIVE_IMPL false)
Expand Down
2 changes: 1 addition & 1 deletion example/ck_tile/09_topk_softmax/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
add_executable(tile_example_topk_softmax EXCLUDE_FROM_ALL topk_softmax.cpp topk_softmax_api.cpp)
add_executable(tile_example_topk_softmax topk_softmax.cpp topk_softmax_api.cpp)
target_include_directories(tile_example_topk_softmax PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/)

set(EXAMPLE_TOPK_SOFTMAX_COMPILE_OPTIONS)
Expand Down
4 changes: 2 additions & 2 deletions example/ck_tile/10_rmsnorm2d/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ add_custom_command(
set(TILE_RMSNORM2D_FWD "tile_rmsnorm2d_fwd")

message(DEBUG "adding ${TILE_RMSNORM2D_FWD}")
add_executable(${TILE_RMSNORM2D_FWD} EXCLUDE_FROM_ALL rmsnorm2d_fwd.cpp)
add_executable(${TILE_RMSNORM2D_FWD} rmsnorm2d_fwd.cpp)
target_include_directories(${TILE_RMSNORM2D_FWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
target_sources(${TILE_RMSNORM2D_FWD} PRIVATE ${RMSNORM2D_FWD_GEN_BLOBS})

Expand All @@ -38,7 +38,7 @@ list(APPEND TILE_RMSNORM2D_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -Wno
target_compile_options(${TILE_RMSNORM2D_FWD} PRIVATE ${TILE_RMSNORM2D_FWD_COMPILE_OPTIONS})

set(EXAMPLE_RMSNORM2D_FWD "tile_example_rmsnorm2d_fwd")
add_executable(${EXAMPLE_RMSNORM2D_FWD} EXCLUDE_FROM_ALL example_rmsnorm2d_fwd.cpp)
add_executable(${EXAMPLE_RMSNORM2D_FWD} example_rmsnorm2d_fwd.cpp)
target_compile_options(${EXAMPLE_RMSNORM2D_FWD} PRIVATE ${TILE_RMSNORM2D_FWD_COMPILE_OPTIONS})

# TODO: we have to turn off this global prop, otherwise the progress bar generated
Expand Down
4 changes: 2 additions & 2 deletions example/ck_tile/11_add_rmsnorm2d_rdquant/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ set(TILE_ADD_RMSNORM2D_RDQUANT_FWD "tile_add_rmsnorm2d_rdquant_fwd")
# to be included in "make all/install/check"
message(DEBUG "adding ${TILE_ADD_RMSNORM2D_RDQUANT_FWD}")
file(GLOB INSTANCE_SRCS instances/*.cpp)
add_executable(${TILE_ADD_RMSNORM2D_RDQUANT_FWD} EXCLUDE_FROM_ALL add_rmsnorm2d_rdquant_fwd.cpp)
add_executable(${TILE_ADD_RMSNORM2D_RDQUANT_FWD} add_rmsnorm2d_rdquant_fwd.cpp)
target_include_directories(${TILE_ADD_RMSNORM2D_RDQUANT_FWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
target_sources(${TILE_ADD_RMSNORM2D_RDQUANT_FWD} PRIVATE ${INSTANCE_SRCS})

Expand All @@ -15,7 +15,7 @@ list(APPEND TILE_ADD_RMSNORM2D_RDQUANT_FWD_COMPILE_OPTIONS -Wno-undefined-func-t
target_compile_options(${TILE_ADD_RMSNORM2D_RDQUANT_FWD} PRIVATE ${TILE_ADD_RMSNORM2D_RDQUANT_FWD_COMPILE_OPTIONS})

set(EXAMPLE_ADD_RMSNORM2D_RDQUANT_FWD "tile_example_add_rmsnorm2d_rdquant_fwd")
add_executable(${EXAMPLE_ADD_RMSNORM2D_RDQUANT_FWD} EXCLUDE_FROM_ALL example_add_rmsnorm2d_rdquant_fwd.cpp)
add_executable(${EXAMPLE_ADD_RMSNORM2D_RDQUANT_FWD} example_add_rmsnorm2d_rdquant_fwd.cpp)
target_compile_options(${EXAMPLE_ADD_RMSNORM2D_RDQUANT_FWD} PRIVATE ${TILE_ADD_RMSNORM2D_RDQUANT_FWD_COMPILE_OPTIONS})

# TODO: we have to turn off this global prop, otherwise the progress bar generated
Expand Down
2 changes: 1 addition & 1 deletion example/ck_tile/12_smoothquant/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ function (add_smoothquant_example TARGET_NAME MAIN_SRC)
message(DEBUG "adding ${TARGET_NAME}")
# not using add_example_executable() to add target, since we don't want this to have
# to be included in "make all/install/check"
add_executable(${TARGET_NAME} EXCLUDE_FROM_ALL ${MAIN_SRC})
add_executable(${TARGET_NAME} ${MAIN_SRC})
target_include_directories(${TARGET_NAME} PRIVATE ${CMAKE_CURRENT_LIST_DIR})

foreach(source IN LISTS ARGN)
Expand Down
2 changes: 1 addition & 1 deletion example/ck_tile/13_moe_sorting/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
add_executable(tile_example_moe_sorting EXCLUDE_FROM_ALL moe_sorting.cpp moe_sorting_api.cpp)
add_executable(tile_example_moe_sorting moe_sorting.cpp moe_sorting_api.cpp)
target_include_directories(tile_example_moe_sorting PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/)

set(EXAMPLE_MOE_SORTING_COMPILE_OPTIONS)
Expand Down
3 changes: 1 addition & 2 deletions example/ck_tile/14_moe_smoothquant/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ function (add_moe_smoothquant_example TARGET_NAME MAIN_SRC)
message(DEBUG "adding ${TARGET_NAME}")
# not using add_example_executable() to add target, since we don't want this to have
# to be included in "make all/install/check"
add_executable(${TARGET_NAME} EXCLUDE_FROM_ALL ${MAIN_SRC})
add_executable(${TARGET_NAME} ${MAIN_SRC})
target_include_directories(${TARGET_NAME} PRIVATE ${CMAKE_CURRENT_LIST_DIR})

foreach(source IN LISTS ARGN)
Expand All @@ -22,4 +22,3 @@ endfunction(add_moe_smoothquant_example TARGET_NAME MAIN_SRC)
file(GLOB INSTANCE_SRCS instances/*.cpp)

add_moe_smoothquant_example(tile_example_moe_smoothquant moe_smoothquant.cpp ${INSTANCE_SRCS})

2 changes: 1 addition & 1 deletion example/ck_tile/15_fused_moe/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ set(TILE_EXAPMLE_FUSED_MOE "tile_example_fused_moe")
# to be included in "make all/install/check"
message(DEBUG "adding ${TILE_EXAPMLE_FUSED_MOE}")
file(GLOB INSTANCE_SRCS instances/*.cpp)
add_executable(${TILE_EXAPMLE_FUSED_MOE} EXCLUDE_FROM_ALL main.cpp)
add_executable(${TILE_EXAPMLE_FUSED_MOE} main.cpp)
target_include_directories(${TILE_EXAPMLE_FUSED_MOE} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
target_sources(${TILE_EXAPMLE_FUSED_MOE} PRIVATE ${INSTANCE_SRCS})

Expand Down
16 changes: 0 additions & 16 deletions example/ck_tile/15_fused_moe/instances/fused_moesorting_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -402,22 +402,6 @@ float fused_moesorting_mp(fused_moesorting_trait t,
using ms_index_t = ck_tile::index_t;
using ms_weight_type = float;

auto maybe_clear_workspace = [=](const ck_tile::stream_config& s_) {
if(t.clear_workspace_inside_api)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

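There is no such member (`clear_workspace_inside_api`) on the trait, and the macro `MOR_SORTING_CLEAR_WS_DISPATCH_` is not defined, so this lambda cannot compile.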

{
if(is_local_token)
{
auto k = MOR_SORTING_CLEAR_WS_DISPATCH_(true, 1024, 1);
k(s_);
}
else
{
auto k = MOR_SORTING_CLEAR_WS_DISPATCH_(false, 1024, 1);
k(s_);
}
}
};

if(a.tokens < 2048)
{
if(ck_tile::impl::moe_sorting_get_smem_size_p23(a.num_experts) >
Expand Down
2 changes: 1 addition & 1 deletion example/ck_tile/16_batched_gemm/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
add_executable(tile_example_batched_gemm EXCLUDE_FROM_ALL batched_gemm.cpp)
add_executable(tile_example_batched_gemm batched_gemm.cpp)
10 changes: 5 additions & 5 deletions example/ck_tile/17_grouped_gemm/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
add_executable(tile_example_grouped_gemm EXCLUDE_FROM_ALL grouped_gemm.cpp)
add_executable(tile_example_quant_grouped_gemm EXCLUDE_FROM_ALL quant_grouped_gemm.cpp)
add_executable(tile_example_grouped_gemm_preshuffle EXCLUDE_FROM_ALL grouped_gemm_preshuffle.cpp)
add_executable(tile_example_grouped_gemm_multi_d EXCLUDE_FROM_ALL grouped_gemm_multi_d.cpp)
add_executable(tile_example_grouped_gemm grouped_gemm.cpp)
add_executable(tile_example_quant_grouped_gemm quant_grouped_gemm.cpp)
add_executable(tile_example_grouped_gemm_preshuffle grouped_gemm_preshuffle.cpp)
add_executable(tile_example_grouped_gemm_multi_d grouped_gemm_multi_d.cpp)
set(EXAMPLE_GEMM_COMPILE_OPTIONS)
if(CK_USE_OCP_FP8)
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
endif()
target_compile_options(tile_example_grouped_gemm PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
target_compile_options(tile_example_grouped_gemm_preshuffle PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
target_compile_options(tile_example_grouped_gemm_multi_d PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
target_compile_options(tile_example_quant_grouped_gemm PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
target_compile_options(tile_example_quant_grouped_gemm PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
16 changes: 15 additions & 1 deletion example/ck_tile/17_grouped_gemm/quant_grouped_gemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,20 @@
#include "ck_tile/host.hpp"
#include "quant_grouped_gemm.hpp"

// Compile-time group-size descriptor: a type that carries three static
// integer constants named kM, kN and kK.
// NOTE(review): presumably these are the per-dimension quantization group
// extents (M, N, K) — confirm against the pipeline that consumes this type.
struct QuantGroupSize
{
    static constexpr int kM{128};
    static constexpr int kN{1};
    static constexpr int kK{1};
};

struct BQuantGroupSize
{
static constexpr auto kM = 1;
static constexpr auto kN = 1;
static constexpr auto kK = 128;
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is my best guess at the tile shape.

};

template <typename GemmConfig,
typename ALayout,
typename AQLayout,
Expand Down Expand Up @@ -75,7 +89,7 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s,
AccDataType,
GemmShape,
GemmUniversalTraits,
128>, // QuantGroupSize
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This argument must be a type (not an integer value) that exposes `kM`, `kN`, and `kK` as static integer members.

BQuantGroupSize>,
ck_tile::GemmRowColTensorQuantPipelineProblem<ADataType,
BDataType,
AccDataType,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -159,13 +159,12 @@ int run_grouped_gemm_example_with_layouts(int argc,
return group_count != 0 && ((args.size() == static_cast<size_t>(group_count)) && ...);
};

const int group_count = arg_parser.get_int("group_count");
const int repeat = arg_parser.get_int("repeat");
const int warmup = arg_parser.get_int("warmup");
const int kbatch = arg_parser.get_int("kbatch");
const int init_method = arg_parser.get_int("init");
bool validate = arg_parser.get_bool("validate");
const ck_tile::index_t QuantGroupSize = 128;
const int group_count = arg_parser.get_int("group_count");
const int repeat = arg_parser.get_int("repeat");
const int warmup = arg_parser.get_int("warmup");
const int kbatch = arg_parser.get_int("kbatch");
const int init_method = arg_parser.get_int("init");
bool validate = arg_parser.get_bool("validate");

if(kbatch > 1 && validate && warmup + repeat > 1)
{
Expand Down Expand Up @@ -259,9 +258,9 @@ int run_grouped_gemm_example_with_layouts(int argc,
}
else if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped)
{
AQK = 0; // No A quantization
BQK = K / QuantGroupSize; // Group quantization: BQK = K / GroupSize
if(K % QuantGroupSize != 0)
AQK = 0; // No A quantization
BQK = K / QuantGroupSize::kM; // Group quantization: BQK = K / GroupSize
if(K % QuantGroupSize::kM != 0)
{
throw std::runtime_error("K must be divisible by 128 for BQuantGrouped mode");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,8 @@ int run_grouped_gemm_example_with_layouts(int argc,
// Perform preshuffle for B tensor
if constexpr(GemmConfig::Preshuffle)
{
ck_tile::HostTensor<BDataType> b_shuffle_host = shuffle_b<GemmConfig>(b_k_n_tensors[i]);
ck_tile::HostTensor<BDataType> b_shuffle_host =
ck_tile::shuffle_b<GemmConfig>(b_k_n_tensors[i]);
b_k_n_dev_buf.push_back(std::make_unique<ck_tile::DeviceMem>(b_shuffle_host));
}
else
Expand Down
13 changes: 6 additions & 7 deletions example/ck_tile/18_flatmm/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,16 @@ foreach(gpu IN LISTS GPU_TARGETS)
endforeach()

if(has_supported_gpu)
add_executable(tile_example_flatmm_basic EXCLUDE_FROM_ALL flatmm_basic.cpp)
add_executable(tile_example_mixed_prec_flatmm EXCLUDE_FROM_ALL mixed_prec/mixed_prec_flatmm.cpp)
add_executable(tile_example_moe_flatmm EXCLUDE_FROM_ALL moe_flatmm.cpp)
add_executable(tile_example_a16w4_moe_flatmm EXCLUDE_FROM_ALL mixed_prec/a16w4_moe_flatmm.cpp)
add_executable(tile_example_grouped_flatmm EXCLUDE_FROM_ALL grouped_flatmm.cpp)
add_executable(tile_example_flatmm_basic flatmm_basic.cpp)
add_executable(tile_example_mixed_prec_flatmm mixed_prec/mixed_prec_flatmm.cpp)
add_executable(tile_example_moe_flatmm moe_flatmm.cpp)
add_executable(tile_example_a16w4_moe_flatmm mixed_prec/a16w4_moe_flatmm.cpp)
add_executable(tile_example_grouped_flatmm grouped_flatmm.cpp)

include(mxgemm/mx_flatmm_instance.cmake)
mx_flatmm_instance_generate(EXAMPLE_MX_FLATMM_FILES)
message(STATUS "Generated MX FlatMM kernel files: ${EXAMPLE_MX_FLATMM_FILES}")
add_executable(tile_example_mx_flatmm EXCLUDE_FROM_ALL mxgemm/mx_flatmm.cpp ${EXAMPLE_MX_FLATMM_FILES})
add_executable(tile_example_mx_flatmm mxgemm/mx_flatmm.cpp ${EXAMPLE_MX_FLATMM_FILES})
target_include_directories(tile_example_mx_flatmm PRIVATE mxgemm)

# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations
Expand All @@ -37,4 +37,3 @@ if(has_supported_gpu)
target_compile_options(tile_example_grouped_flatmm PRIVATE ${EXAMPLE_FLATMM_COMPILE_OPTIONS})
target_compile_options(tile_example_mx_flatmm PRIVATE ${EXAMPLE_FLATMM_COMPILE_OPTIONS}) # TODO: 950 only
endif()

4 changes: 2 additions & 2 deletions example/ck_tile/18_flatmm/run_grouped_flatmm_example.inc
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ int run_contiguous_grouped_flatmm_example_with_layouts(
assert(N % N_Warp_Tile == 0 &&
"N must be divisible by N_Warp_Tile for contiguous grouped gemm");
ck_tile::HostTensor<BDataType> b_shuffle_host =
shuffle_b<FlatmmConfig, BDataType>(b_k_n_tensor);
ck_tile::shuffle_b<FlatmmConfig, BDataType>(b_k_n_tensor);

std::unique_ptr<ck_tile::DeviceMem> a_m_k_dev_buf(
std::make_unique<ck_tile::DeviceMem>(a_m_k_tensor.get_element_space_size_in_bytes()));
Expand Down Expand Up @@ -431,7 +431,7 @@ int run_masked_grouped_flatmm_example_with_layouts(
assert(N % N_Warp_Tile == 0 &&
"N must be divisible by N_Warp_Tile for contiguous grouped gemm");
ck_tile::HostTensor<BDataType> b_shuffle_host =
shuffle_b<FlatmmConfig, BDataType>(b_k_n_tensor);
ck_tile::shuffle_b<FlatmmConfig, BDataType>(b_k_n_tensor);

std::unique_ptr<ck_tile::DeviceMem> a_m_k_dev_buf(
std::make_unique<ck_tile::DeviceMem>(a_m_k_tensor.get_element_space_size_in_bytes()));
Expand Down
4 changes: 0 additions & 4 deletions example/ck_tile/18_flatmm/run_moe_flatmm_example.inc
Original file line number Diff line number Diff line change
Expand Up @@ -302,10 +302,6 @@ int run_moe_gemm_example_with_layouts(int argc,
static_cast<float*>(per_token_scale_dev_buf.GetDeviceBuffer()),
static_cast<float*>(per_channel_scale_dev_buf.GetDeviceBuffer()));

const float max_accumulated_value =
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
K, 1 /*kbatch*/, max_accumulated_value);
c_m_n_ref_buf->FromDevice(c_m_n_host_ref.data());

const float rtol = std::is_same_v<ADataType, ck_tile::half_t> && IsInputGemm ? 1e-3 : 1e-2;
Expand Down
Loading
Loading