Skip to content
Open
Show file tree
Hide file tree
Changes from 53 commits
Commits
Show all changes
56 commits
Select commit Hold shift + click to select a range
d1bf200
support bf16*mxfp4 gemm
Sep 5, 2025
4e205c4
rebase bf16*fp4 example to develop branch
k50112113 Sep 8, 2025
52c5ed5
Clean up commented debug code in GEMM kernel
eliotwang Sep 8, 2025
e1d0365
Merge branch 'develop' into bf16_fp4_gemm
eliotwang Sep 8, 2025
43db1f7
Merge branch 'develop' into bf16_fp4_gemm
eliotwang Sep 9, 2025
ff89459
rename example folder
Sep 9, 2025
1409d62
Merge branch 'develop' into bf16_fp4_gemm
eliotwang Sep 9, 2025
637f2e8
Merge branch 'develop' into bf16_fp4_gemm
eliotwang Sep 10, 2025
30450e3
Merge branch 'develop' into bf16_fp4_gemm
eliotwang Sep 10, 2025
291e36b
Merge branch 'develop' into bf16_fp4_gemm
eliotwang Sep 11, 2025
d2c79f8
Merge branch 'develop' into bf16_fp4_gemm
eliotwang Sep 12, 2025
ba84541
Merge branch 'develop' into bf16_fp4_gemm
eliotwang Sep 16, 2025
e31f9df
Merge branch 'develop' into bf16_fp4_gemm
eliotwang Sep 22, 2025
f2c0d77
Merge branch 'develop' into bf16_fp4_gemm
eliotwang Sep 24, 2025
f65c005
Merge branch 'develop' into bf16_fp4_gemm
eliotwang Sep 25, 2025
78ae8aa
support bf16*mxfp4 gemm
Sep 5, 2025
e953070
rebase bf16*fp4 example to develop branch
k50112113 Sep 8, 2025
9d01db5
Clean up commented debug code in GEMM kernel
eliotwang Sep 8, 2025
28d4d24
rename example folder
Sep 9, 2025
8229d64
rebase to new develop
k50112113 Oct 10, 2025
23c89b3
rebase to new develop
k50112113 Oct 10, 2025
ec53824
Merge branch 'develop' into bf16_fp4_gemm
eliotwang Oct 10, 2025
adb4bc3
Merge branch 'develop' into bf16_fp4_gemm
eliotwang Oct 16, 2025
5304448
Merge branch 'develop' into bf16_fp4_gemm
illsilin Oct 16, 2025
75dbf17
fix clang format
illsilin Oct 16, 2025
3efca0f
Merge branch 'develop' into bf16_fp4_gemm
eliotwang Oct 17, 2025
8aec6b9
Merge branch 'develop' into bf16_fp4_gemm
eliotwang Oct 20, 2025
5e91e9e
Merge branch 'develop' into bf16_fp4_gemm
eliotwang Oct 21, 2025
cba5ab1
Merge branch 'develop' into bf16_fp4_gemm
eliotwang Oct 22, 2025
5697816
Merge branch 'develop' into bf16_fp4_gemm
eliotwang Oct 23, 2025
8f272b3
Merge branch 'develop' into bf16_fp4_gemm
illsilin Oct 23, 2025
8bad07a
Merge branch 'develop' into bf16_fp4_gemm
eliotwang Oct 27, 2025
03406c0
Merge branch 'develop' into bf16_fp4_gemm
eliotwang Oct 29, 2025
984ed9f
Merge remote-tracking branch 'upstream/develop' into bf16_fp4_gemm
eliotwang Nov 11, 2025
46ff36a
update code according to reviewer's comment
eliotwang Nov 11, 2025
01f5c75
Update README.md
eliotwang Nov 11, 2025
92d7082
update code according to reviewer's comment
eliotwang Nov 13, 2025
7e32cb9
Merge branch 'bf16_fp4_gemm' of https://github.com/eliotwang/heyi_com…
eliotwang Nov 13, 2025
48e6393
update code according to reviewer's comment
eliotwang Nov 13, 2025
88c6a8c
Update CMakeLists.txt
eliotwang Nov 13, 2025
87c4e07
Merge branch 'develop' into bf16_fp4_gemm
eliotwang Nov 13, 2025
6883654
Merge remote-tracking branch 'upstream/develop' into bf16_fp4_gemm
eliotwang Nov 14, 2025
3579741
Update README.md
eliotwang Nov 14, 2025
9579e6f
Update CMakeLists.txt
eliotwang Nov 14, 2025
8a4ac27
Delete files
eliotwang Nov 14, 2025
f6ffb76
Delete files
eliotwang Nov 14, 2025
5225b4d
Merge branch 'develop' into bf16_fp4_gemm
eliotwang Nov 17, 2025
eb15154
Merge branch 'develop' into bf16_fp4_gemm
eliotwang Nov 18, 2025
ba12e7d
Merge branch 'develop' into bf16_fp4_gemm
eliotwang Nov 18, 2025
fde6e39
Add unit tests
eliotwang Nov 19, 2025
f54857f
Merge branch 'develop' into bf16_fp4_gemm
eliotwang Nov 19, 2025
7ceedeb
Merge branch 'develop' into bf16_fp4_gemm
eliotwang Nov 19, 2025
d5ce464
Merge branch 'develop' into bf16_fp4_gemm
eliotwang Nov 20, 2025
329c601
Update test_gemm_quant_base.hpp
eliotwang Nov 20, 2025
8c75bc1
Merge branch 'develop' into bf16_fp4_gemm
eliotwang Nov 20, 2025
7851abb
Merge branch 'develop' into bf16_fp4_gemm
eliotwang Nov 24, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions example/ck_tile/38_block_scale_gemm/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12")
gemm_aquant_quantgrouped.cpp
gemm_bquant_quantgrouped_prefill_bf8i4.cpp
gemm_bquant_quantgrouped_prefill_fp8i4.cpp
gemm_bquant_quantgrouped_prefill_bf16mxfp4.cpp
gemm_bquant_quantgrouped_prefill_bf8.cpp
gemm_bquant_quantgrouped_prefill_fp8.cpp
gemm_bquant_quantgrouped_preshuffleb_prefill.cpp
Expand Down
8 changes: 4 additions & 4 deletions example/ck_tile/38_block_scale_gemm/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ This folder contains examples of quant GEMMs using the ck_tile tile-programming
- **Preshuffled GEMM**: Shuffle the GEMM of B (weight) matrix in the warp layout and bypass the shared memory to do the GEMM calculation. Best performance solution for GEMM.
- **TransposeC**: Transpose the C Matrix Output layout to have the best coalesced scale reading
- **Preshuffled Quant**: Preshuffle the input matrix to load multiple Quant warp blocks along the selected dimension.
- **Precision**: Supports fp16, bf16, fp8, bf8, int4 (for B Matrix).
- **Precision**: Supports fp16, bf16, fp8, bf8, int4 (for B Matrix), and packed fp4 (two fp4 values per uint8, for B Matrix).
- **Validation**: CPU/GPU validation and error tolerance options.

## build
Expand All @@ -33,9 +33,9 @@ mkdir build && cd build
# you can replace <arch> with the appropriate architecture (for example gfx942) or leave it blank
../script/cmake-ck-dev.sh ../ <arch>
# Compile the quant kernels
make tile_example_gemm_quant_basic -j
make tile_example_gemm_quant -j
```
This will result in an executable `build/bin/tile_example_gemm_quant_basic`
This will result in an executable `build/bin/tile_example_gemm_quant`

## example
```
Expand All @@ -53,7 +53,7 @@ args:
-stride_b Tensor B stride (default:0)
-stride_c Tensor C stride (default:0)
-v 0: No validation, 1: Validation on CPU, 2: Validation on GPU (default:1)
-prec Data type. For AQuant: fp8, bf8, i4fp8, or i4bf8; for Bquant: fp8, bf8, fp8i4, or bf8i4 (default for both AQuant and Bquant: fp8)
-prec Data type. For AQuant: fp8, bf8, i4fp8, or i4bf8; for Bquant: fp8, bf8, fp8i4, bf8i4, or bf16f4 (default for both AQuant and Bquant: fp8)
-warmup Number of iterations before benchmarking the kernel (default:50)
-repeat Number of iterations to benchmark the kernel (default:1000)
-timer gpu:gpu timer, cpu:cpu timer (default:gpu)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

#include "run_gemm_quant_example.inc"

template <typename T>
using GemmConfig = GemmConfigBQuantPrefill<T>;

#define RUN_GEMM_EXAMPLE_PREC_TYPE \
run_gemm_example_prec_type<GemmConfig<ck_tile::pk_fp4_raw_t>, \
TypeConfig, \
QuantGroupSize, \
ck_tile::QuantType::BQuantGrouped>(arg_parser);

void bquant_quantgrouped_bf16f4_instance_factory(
    std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
{
    // Registers the bf16 (A) x packed-fp4 (B) block-scale GEMM example
    // instances. Each lut entry is keyed by the hash of {precision, quant
    // mode, preshuffle mode, quant group shape} and runs the example with
    // the corresponding QuantGroupSize along K.
    using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf16_t,
                                                    ck_tile::pk_fp4_raw_t,
                                                    ck_tile::bf16_t,
                                                    ck_tile::pk_fp4_raw_t>{});
    // NOTE: a previous revision registered the 1x1x32 entry twice — once
    // inside an "#ifndef CK_GFX950_SUPPORT" guard and once unconditionally —
    // with the same key and an identical lambda. The guarded duplicate was
    // dead code (the unconditional assignment always overwrote it) and has
    // been removed; behavior is unchanged.
    lut[hash_multiple_strings({"bf16f4", "bquant", "non-preshuffleb", "1x1x32"})] =
        [](const ck_tile::ArgParser& arg_parser) {
            using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 32>>;
            return RUN_GEMM_EXAMPLE_PREC_TYPE;
        };
    lut[hash_multiple_strings({"bf16f4", "bquant", "non-preshuffleb", "1x1x64"})] =
        [](const ck_tile::ArgParser& arg_parser) {
            using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 64>>;
            return RUN_GEMM_EXAMPLE_PREC_TYPE;
        };
    lut[hash_multiple_strings({"bf16f4", "bquant", "non-preshuffleb", "1x1x128"})] =
        [](const ck_tile::ArgParser& arg_parser) {
            using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
            return RUN_GEMM_EXAMPLE_PREC_TYPE;
        };
}
4 changes: 4 additions & 0 deletions example/ck_tile/38_block_scale_gemm/gemm_quant.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,8 @@ void bquant_quantgrouped_fp8i4_instance_factory(
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
void bquant_quantgrouped_bf8i4_instance_factory(
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
void bquant_quantgrouped_bf16f4_instance_factory(
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
void bquant_quantgrouped_preshuffleb_instance_factory(
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
void quant_rowcol_instance_factory(
Expand All @@ -110,9 +112,11 @@ int main(int argc, char* argv[])
bquant_quantgrouped_bf8_instance_factory(lut);
bquant_quantgrouped_fp8i4_instance_factory(lut);
bquant_quantgrouped_bf8i4_instance_factory(lut);
bquant_quantgrouped_bf16f4_instance_factory(lut);
bquant_quantgrouped_preshuffleb_instance_factory(lut);
quant_rowcol_instance_factory(lut);
quant_tensor_instance_factory(lut);


auto key = gen_lut_key(arg_parser);

Expand Down
11 changes: 9 additions & 2 deletions example/ck_tile/38_block_scale_gemm/gemm_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,9 @@ auto calculate_rtol_atol(const ck_tile::index_t K,
const ck_tile::index_t kbatch,
const float max_accumulated_value)
{
using ComputeType =
std::conditional_t<sizeof(ADataType) < sizeof(BDataType), ADataType, BDataType>;
using ComputeType =
std::conditional_t<std::is_same_v<BDataType, ck_tile::pk_fp4_raw_t>, ADataType,
std::conditional_t<sizeof(ADataType) < sizeof(BDataType), ADataType, BDataType>>;
// Calculate thresholds
const auto rtol = ck_tile::get_relative_threshold<ComputeType, CDataType, AccDataType>(
ck_tile::integer_divide_ceil(K, kbatch));
Expand Down Expand Up @@ -288,6 +289,12 @@ struct DataTypeTraits<ck_tile::bf8_t>
static constexpr const char* name = "bf8";
};

// Human-readable name for the packed-fp4 raw storage type (used in
// example logging / reporting).
// NOTE(review): sibling specializations drop the "_t" suffix (e.g. "fp8",
// "bf8") and typeToStr in ops/common/utils.hpp uses "pk_fp4_raw" — confirm
// whether this string should be "pk_fp4_raw" for consistency before
// anything keys off of it.
template <>
struct DataTypeTraits<ck_tile::pk_fp4_raw_t>
{
static constexpr const char* name = "pk_fp4_raw_t";
};

template <>
struct DataTypeTraits<ck_tile::pk_int4_t>
{
Expand Down
57 changes: 44 additions & 13 deletions example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,9 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
using BaseGemmPipeline = std::conditional_t<
GemmConfig::PreshuffleB == true,
ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2<GemmPipelineProblem>,
ck_tile::BaseBQuantGemmPipelineAgBgCrCompV3<GemmPipelineProblem>>;
std::conditional_t<std::is_same_v<typename TypeConfig::BDataType, ck_tile::pk_fp4_raw_t>,
ck_tile::BaseMxFp4GemmPipelineAgBgCrCompV3<GemmPipelineProblem>,
ck_tile::BaseBQuantGemmPipelineAgBgCrCompV3<GemmPipelineProblem>>>;

const ck_tile::index_t K_split =
(args.K + GemmConfig::K_Tile - 1) / GemmConfig::K_Tile * GemmConfig::K_Tile;
Expand Down Expand Up @@ -128,11 +130,15 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
ck_tile::AQuantGemmPipelineAgBgCrCompV3<PipelineProblem>,
std::conditional_t<GemmConfig::PreshuffleB == true,
ck_tile::WPQuantBPipelineAgBgCrV2<PipelineProblem>,
ck_tile::BQuantGemmPipelineAgBgCrCompV3<PipelineProblem>>>>;
std::conditional_t<std::is_same_v<typename TypeConfig::BDataType, ck_tile::pk_fp4_raw_t>,
ck_tile::MxFp4GemmPipelineAgBgCrCompV3<PipelineProblem>,
ck_tile::BQuantGemmPipelineAgBgCrCompV3<PipelineProblem>>>>>;

using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<typename TypeConfig::ADataType,
typename TypeConfig::BDataType,
std::conditional_t<
std::is_same_v<typename TypeConfig::BDataType, ck_tile::pk_fp4_raw_t>,
typename TypeConfig::ADataType, typename TypeConfig::BDataType>,
ck_tile::tuple<>,
typename TypeConfig::AccDataType,
typename TypeConfig::CDataType,
Expand Down Expand Up @@ -188,7 +194,8 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
ck_tile::HostTensor<typename TypeConfig::ADataType> a_m(ck_tile::host_tensor_descriptor(
args.M, args.K, args.stride_A, is_row_major(ALayout{})));
ck_tile::HostTensor<typename TypeConfig::BDataType> b_n(ck_tile::host_tensor_descriptor(
args.K, args.N, args.stride_B, is_row_major(BLayout{})));
std::is_same_v<typename TypeConfig::BDataType, ck_tile::pk_fp4_raw_t> ? args.K / 2 : args.K,
args.N, args.stride_B, is_row_major(BLayout{})));

auto size_a_buffer = a_m.get_element_space_size_in_bytes();
auto size_b_buffer = b_n.get_element_space_size_in_bytes();
Expand Down Expand Up @@ -404,9 +411,10 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
ck_tile::index_t init_method = arg_parser.get_int("init");
bool flush_cache = arg_parser.get_bool("flush_cache");
int rotating_count = arg_parser.get_int("rotating_count");

stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout));
stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout));
stride_B = ck_tile::get_default_stride((std::is_same_v<BDataType, ck_tile::pk_fp4_raw_t>) ? (K / 2) : K,
N, stride_B, is_row_major(b_layout));
stride_C = ck_tile::get_default_stride(M, N, stride_C, is_row_major(CLayout{}));

// Conditional stride calculation based on QuantMode
Expand Down Expand Up @@ -434,7 +442,9 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
ck_tile::HostTensor<ADataType> a_m_k(
ck_tile::host_tensor_descriptor(M, K, stride_A, is_row_major(a_layout)));
ck_tile::HostTensor<BDataType> b_k_n(
ck_tile::host_tensor_descriptor(K, N, stride_B, is_row_major(b_layout)));
ck_tile::host_tensor_descriptor(
(std::is_same_v<BDataType, ck_tile::pk_fp4_raw_t>) ? (K / 2) : K,
N, stride_B, is_row_major(b_layout)));
ck_tile::HostTensor<CDataType> c_m_n_dev_result(
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));

Expand Down Expand Up @@ -478,13 +488,23 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
{
ck_tile::FillUniformDistribution<ck_tile::pk_int4_t>{-5.0f, 5.0f, fill_seed(gen)}(
b_k_n);
ck_tile::FillUniformDistribution<BQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
*bq_tensor_ptr);
}
else if constexpr(std::is_same_v<BDataType, ck_tile::pk_fp4_raw_t>)
{
ck_tile::FillUniformDistribution<BDataType>{-5.0f, 5.0f, fill_seed(gen)}(
b_k_n);
ck_tile::FillUniformDistribution<BQDataType>{125.f, 130.f, fill_seed(gen)}(
*bq_tensor_ptr);
}
else
{
ck_tile::FillUniformDistribution<BDataType>{-2.0f, 3.0f, fill_seed(gen)}(b_k_n);
}
ck_tile::FillUniformDistribution<BQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
ck_tile::FillUniformDistribution<BQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
*bq_tensor_ptr);
}

ck_tile::FillUniformDistribution<ADataType>{-5.0f, 5.0f, fill_seed(gen)}(a_m_k);
}
else if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped)
Expand Down Expand Up @@ -684,7 +704,16 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
}
else if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped)
{
ck_tile::reference_gemm_quant<ADataType,
if constexpr(std::is_same_v<BDataType, ck_tile::pk_fp4_raw_t>)
ck_tile::reference_mxfp4gemm_quant<ADataType,
BQDataType,
BDataType,
AccDataType,
CDataType,
QuantGroupSize,
false>(a_m_k, *bq_tensor_ptr, b_k_n, c_m_n_host_ref);
else
ck_tile::reference_gemm_quant<ADataType,
AQDataType,
BDataType,
AccDataType,
Expand Down Expand Up @@ -750,16 +779,18 @@ int run_gemm_example_prec_type(const ck_tile::ArgParser& arg_parser)
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;

if((QuantMode == ck_tile::QuantType::AQuantGrouped ||
QuantMode == ck_tile::QuantType::RowColQuant) &&
QuantMode == ck_tile::QuantType::RowColQuant ||
std::is_same_v<typename TypeConfig::BDataType, ck_tile::pk_fp4_raw_t>) &&
GemmConfig::PreshuffleB)
{
throw std::runtime_error(
"Preshuffling weight matrix is not supported for AQuant or RowColQuant");
"Preshuffling weight matrix is not supported for AQuant, RowColQuant or bf16_fp4_gemm");
}

if constexpr(std::is_same_v<typename TypeConfig::ADataType, ck_tile::pk_int4_t> ||
std::is_same_v<typename TypeConfig::ADataType, ck_tile::fp8_t> ||
std::is_same_v<typename TypeConfig::ADataType, ck_tile::bf8_t>)
std::is_same_v<typename TypeConfig::ADataType, ck_tile::bf8_t> ||
std::is_same_v<typename TypeConfig::ADataType, ck_tile::bf16_t>)
{
std::string a_layout = arg_parser.get_str("a_layout");
std::string b_layout = arg_parser.get_str("b_layout");
Expand Down
6 changes: 4 additions & 2 deletions include/ck_tile/core/arch/amd_buffer_addressing.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1550,9 +1550,11 @@ CK_TILE_DEVICE thread_buffer<T, N> amd_buffer_load_impl(int32x4_t src_wave_buffe
(std::is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(std::is_same<T, e8m0_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(std::is_same<T, pk_int4_t>::value &&
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32) ||
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32)) ||
(std::is_same<T, pk_fp4_raw_t>::value &&
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16))||
(std::is_same<T, pk_fp4_t>::value &&
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16))),
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16)),
"wrong! not implemented");

using rtn_type = thread_buffer<T, N>;
Expand Down
4 changes: 2 additions & 2 deletions include/ck_tile/host/check_err.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ CK_TILE_HOST double get_relative_threshold(const int number_of_accumulations = 1
{

static_assert(
is_any_of<ComputeDataType, F8, BF8, F16, BF16, F32, pk_int4_t, I8, I32, int>::value,
is_any_of<ComputeDataType, F8, BF8, F16, BF16, F32, pk_fp4_t, pk_fp4_raw_t, pk_int4_t, I8, I32, int>::value,
"Warning: Unhandled ComputeDataType for setting up the relative threshold!");

double compute_error = 0;
Expand Down Expand Up @@ -114,7 +114,7 @@ CK_TILE_HOST double get_absolute_threshold(const double max_possible_num,
{

static_assert(
is_any_of<ComputeDataType, F8, BF8, F16, BF16, F32, pk_int4_t, I8, I32, int>::value,
is_any_of<ComputeDataType, F8, BF8, F16, BF16, F32, pk_fp4_t, pk_fp4_raw_t, pk_int4_t, I8, I32, int>::value,
"Warning: Unhandled ComputeDataType for setting up the absolute threshold!");

auto expo = std::log2(std::abs(max_possible_num));
Expand Down
57 changes: 57 additions & 0 deletions include/ck_tile/host/reference/reference_gemm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,63 @@ CK_TILE_HOST void reference_gemm_tensor_quant(const HostTensor<ADataType>& a_m_k
make_ParallelTensorFunctor(f_mn, M, N)(std::thread::hardware_concurrency());
}

template <typename ADataType,
          typename QDataType,
          typename BDataType,
          typename AccDataType,
          typename CDataType,
          typename QuantGroupSize,
          bool aquant, // unused here; kept for signature parity with reference_gemm_quant
          typename AElementOp   = ck_tile::identity,
          typename BElementOp   = ck_tile::identity,
          typename ACCElementOp = ck_tile::identity>
CK_TILE_HOST void reference_mxfp4gemm_quant(const HostTensor<ADataType>& a_m_k,
                                            const HostTensor<QDataType>& q,
                                            const HostTensor<BDataType>& b_k_n,
                                            HostTensor<CDataType>& c_m_n,
                                            const AElementOp& a_element_op     = {},
                                            const BElementOp& b_element_op     = {},
                                            const ACCElementOp& acc_element_op = {})
{
    // Host reference GEMM for A x (block-scaled packed-fp4 B):
    //   C[m,n] = sum_k A[m,k] * decode_fp4(B[k,n]) * 2^(q(group(k), n) - 127)
    // B is stored packed (two fp4 values per pk_fp4_raw_t element), so the
    // packed storage spans K/2 rows; each packed element k covers logical
    // rows 2k and 2k+1. q holds e8m0-style biased exponents (bias 127), one
    // per QuantGroupSize::kK logical elements along K per column of B.
    //
    // Fix vs. previous revision: the old body only assigned v_b_0/v_b_1 inside
    // "if constexpr(is_same_v<BDataType, pk_fp4_raw_t>)", reading them
    // uninitialized (UB) for any other BDataType; a static_assert now rejects
    // unsupported instantiations outright. A stray debug
    // "std::cout << std::endl;" was also removed, and std::pow(2, e) was
    // replaced with the exact std::ldexp.
    static_assert(std::is_same_v<BDataType, pk_fp4_raw_t>,
                  "reference_mxfp4gemm_quant only supports packed-fp4 (pk_fp4_raw_t) B");

    const std::size_t M = a_m_k.get_length(0);
    const std::size_t N = b_k_n.get_length(1);
    const std::size_t K = a_m_k.get_length(1); // logical (unpacked) K

    auto f_mn = [&](auto m, auto n) {
        using ComputeType = float;
        AccDataType v_acc = 0;
        for(std::size_t k = 0; k < (K / 2); k++)
        {
            // Biased exponent -> exact power-of-two scale (ldexp avoids pow's
            // transcendental path and rounding).
            const auto b_scale =
                type_convert<int32_t>(q((2 * k) / QuantGroupSize::kK, n)) - 127;
            const float b_scale_val = std::ldexp(1.0f, b_scale);

            const ComputeType v_a_0 =
                ck_tile::type_convert<ComputeType>(a_element_op(a_m_k(m, 2 * k)));
            const ComputeType v_a_1 =
                ck_tile::type_convert<ComputeType>(a_element_op(a_m_k(m, 2 * k + 1)));

            // Unpack the two fp4 values sharing this byte and apply the
            // common block scale.
            const auto b_pack  = type_convert<pk_fp4_t>(b_element_op(b_k_n(k, n)));
            const auto b_f4_lo = type_convert<pk_fp4_t>(b_pack.unpack(number<0>{}));
            const auto b_f4_hi = type_convert<pk_fp4_t>(b_pack.unpack(number<1>{}));

            const ComputeType v_b_0 = type_convert<ComputeType>(b_f4_lo) * b_scale_val;
            const ComputeType v_b_1 = type_convert<ComputeType>(b_f4_hi) * b_scale_val;

            // Accumulate the partial dot product for this packed element.
            v_acc += v_a_0 * v_b_0 + v_a_1 * v_b_1;
        }
        c_m_n(m, n) = ck_tile::type_convert<CDataType>(acc_element_op(v_acc));
    };

    make_ParallelTensorFunctor(f_mn, M, N)(std::thread::hardware_concurrency());
}

template <typename ADataType,
typename BDataType,
typename AccDataType,
Expand Down
1 change: 1 addition & 0 deletions include/ck_tile/ops/common/utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ template <> struct typeToStr<bf8_t> { static constexpr const char * name = "bf8"
template <> struct typeToStr<int8_t> { static constexpr const char * name = "int8"; };
template <> struct typeToStr<pk_int4_t> { static constexpr const char * name = "pk_int4"; };
template <> struct typeToStr<pk_fp4_t> { static constexpr const char * name = "pk_fp4"; };
template <> struct typeToStr<pk_fp4_raw_t> { static constexpr const char * name = "pk_fp4_raw"; };

template <memory_operation_enum MemOp> struct memOpToStr;
template <> struct memOpToStr<memory_operation_enum::set> { static constexpr const char * name = "set"; };
Expand Down
Loading