diff --git a/example/ck_tile/38_block_scale_gemm/CMakeLists.txt b/example/ck_tile/38_block_scale_gemm/CMakeLists.txt index 932acb72fd..6d16927dfe 100644 --- a/example/ck_tile/38_block_scale_gemm/CMakeLists.txt +++ b/example/ck_tile/38_block_scale_gemm/CMakeLists.txt @@ -12,6 +12,7 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12") gemm_aquant_quantgrouped.cpp gemm_bquant_quantgrouped_prefill_bf8i4.cpp gemm_bquant_quantgrouped_prefill_fp8i4.cpp + gemm_bquant_quantgrouped_prefill_bf16mxfp4.cpp gemm_bquant_quantgrouped_prefill_bf8.cpp gemm_bquant_quantgrouped_prefill_fp8.cpp gemm_bquant_quantgrouped_preshuffleb_prefill.cpp diff --git a/example/ck_tile/38_block_scale_gemm/README.md b/example/ck_tile/38_block_scale_gemm/README.md index 64ecebd15a..5c39bc3dff 100644 --- a/example/ck_tile/38_block_scale_gemm/README.md +++ b/example/ck_tile/38_block_scale_gemm/README.md @@ -23,7 +23,7 @@ This folder contains examples of quant GEMMs using the ck_tile tile-programming - **Preshuffled GEMM**: Shuffle the GEMM of B (weight) matrix in the warp layout and bypass the shared memory to do the GEMM calculation. Best performance solution for GEMM. - **TransposeC**: Transpose the C Matrix Output layout to have the best coalesced scale reading - **Preshuffled Quant**: Preshuffle the input matrix to load multiple Quant warp blocks along the selected dimension. -- **Precision**: Supports fp16, bf16, fp8, bf8, int4 (for B Matrix). +- **Precision**: Supports fp16, bf16, fp8, bf8, int4 (for B Matrix), and uint8, where each byte packs two fp4 values that are unpacked in the pipeline (for B Matrix). - **Validation**: CPU/GPU validation and error tolerance options. ## build ``` mkdir build && cd build # you can replace with the appropriate architecture (for example gfx942) or leave it blank ../script/cmake-ck-dev.sh ../ # Compile the quant kernels -make tile_example_gemm_quant_basic -j +make tile_example_gemm_quant -j ``` -This will result in an executable `build/bin/tile_example_gemm_quant_basic` +This will result in an executable `build/bin/tile_example_gemm_quant` ## example ``` args: -stride_b Tensor B stride (default:0) -stride_c Tensor C stride (default:0) -v 0: No validation, 1: Validation on CPU, 2: Validation on GPU (default:1) - -prec Data type. For AQuant: fp8, bf8, i4fp8, or i4bf8; for Bquant: fp8, bf8, fp8i4, or bf8i4 (default for both AQuant and Bquant: fp8) + -prec Data type. For AQuant: fp8, bf8, i4fp8, or i4bf8; for Bquant: fp8, bf8, fp8i4, bf8i4, or bf16f4 (default for both AQuant and Bquant: fp8) -warmup Number of iterations before benchmarking the kernel (default:50) -repeat Number of iterations to benchmark the kernel (default:1000) -timer gpu:gpu timer, cpu:cpu timer (default:gpu) diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_prefill_bf16mxfp4.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_prefill_bf16mxfp4.cpp new file mode 100644 index 0000000000..5e7342d78a --- /dev/null +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_prefill_bf16mxfp4.cpp @@ -0,0 +1,44 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
+ +#include "run_gemm_quant_example.inc" + +template +using GemmConfig = GemmConfigBQuantPrefill; + +#define RUN_GEMM_EXAMPLE_PREC_TYPE \ + run_gemm_example_prec_type, \ + TypeConfig, \ + QuantGroupSize, \ + ck_tile::QuantType::BQuantGrouped>(arg_parser); + +void bquant_quantgrouped_bf16f4_instance_factory( + std::unordered_map>& lut) +{ + using TypeConfig = decltype(GemmQuantTypeConfig{}); +#ifndef CK_GFX950_SUPPORT + lut[hash_multiple_strings({"bf16f4", "bquant", "non-preshuffleb", "1x1x32"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; +#endif + lut[hash_multiple_strings({"bf16f4", "bquant", "non-preshuffleb", "1x1x32"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings({"bf16f4", "bquant", "non-preshuffleb", "1x1x64"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings({"bf16f4", "bquant", "non-preshuffleb", "1x1x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; +} diff --git a/example/ck_tile/38_block_scale_gemm/gemm_quant.cpp b/example/ck_tile/38_block_scale_gemm/gemm_quant.cpp index a35f867f5d..bb86fc8a7f 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_quant.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_quant.cpp @@ -84,6 +84,8 @@ void bquant_quantgrouped_fp8i4_instance_factory( std::unordered_map>& lut); void bquant_quantgrouped_bf8i4_instance_factory( std::unordered_map>& lut); +void bquant_quantgrouped_bf16f4_instance_factory( + std::unordered_map>& lut); void bquant_quantgrouped_preshuffleb_instance_factory( std::unordered_map>& lut); void quant_rowcol_instance_factory( @@ -110,9 +112,11 @@ int main(int argc, char* argv[]) bquant_quantgrouped_bf8_instance_factory(lut); bquant_quantgrouped_fp8i4_instance_factory(lut); bquant_quantgrouped_bf8i4_instance_factory(lut); + bquant_quantgrouped_bf16f4_instance_factory(lut); bquant_quantgrouped_preshuffleb_instance_factory(lut); quant_rowcol_instance_factory(lut); quant_tensor_instance_factory(lut); + auto key = gen_lut_key(arg_parser); diff --git a/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp b/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp index cf120e1dd0..29d5bc747e 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp @@ -69,8 +69,9 @@ auto calculate_rtol_atol(const ck_tile::index_t K, const ck_tile::index_t kbatch, const float max_accumulated_value) { - using ComputeType = - std::conditional_t; + using ComputeType = + std::conditional_t, ADataType, + std::conditional_t>; // Calculate thresholds const auto rtol = ck_tile::get_relative_threshold( ck_tile::integer_divide_ceil(K, kbatch)); @@ -288,6 +289,12 @@ struct DataTypeTraits static constexpr const char* name = "bf8"; }; +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "pk_fp4_raw_t"; +}; + template <> struct DataTypeTraits { diff --git a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc index 5089a6ea9a..efb11f1f92 100644 --- a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc +++ b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc @@ 
-66,7 +66,9 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str using BaseGemmPipeline = std::conditional_t< GemmConfig::PreshuffleB == true, ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2, - ck_tile::BaseBQuantGemmPipelineAgBgCrCompV3>; + std::conditional_t, + ck_tile::BaseMxFp4GemmPipelineAgBgCrCompV3, + ck_tile::BaseBQuantGemmPipelineAgBgCrCompV3>>; const ck_tile::index_t K_split = (args.K + GemmConfig::K_Tile - 1) / GemmConfig::K_Tile * GemmConfig::K_Tile; @@ -128,11 +130,15 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str ck_tile::AQuantGemmPipelineAgBgCrCompV3, std::conditional_t, - ck_tile::BQuantGemmPipelineAgBgCrCompV3>>>; + std::conditional_t, + ck_tile::MxFp4GemmPipelineAgBgCrCompV3, + ck_tile::BQuantGemmPipelineAgBgCrCompV3>>>>; using GemmEpilogue = ck_tile::CShuffleEpilogue< ck_tile::CShuffleEpilogueProblem, + typename TypeConfig::ADataType, typename TypeConfig::BDataType>, ck_tile::tuple<>, typename TypeConfig::AccDataType, typename TypeConfig::CDataType, @@ -188,7 +194,8 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str ck_tile::HostTensor a_m(ck_tile::host_tensor_descriptor( args.M, args.K, args.stride_A, is_row_major(ALayout{}))); ck_tile::HostTensor b_n(ck_tile::host_tensor_descriptor( - args.K, args.N, args.stride_B, is_row_major(BLayout{}))); + std::is_same_v ? args.K / 2 : args.K, + args.N, args.stride_B, is_row_major(BLayout{}))); auto size_a_buffer = a_m.get_element_space_size_in_bytes(); auto size_b_buffer = b_n.get_element_space_size_in_bytes(); @@ -404,9 +411,10 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, ck_tile::index_t init_method = arg_parser.get_int("init"); bool flush_cache = arg_parser.get_bool("flush_cache"); int rotating_count = arg_parser.get_int("rotating_count"); - + stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout)); - stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout)); + stride_B = ck_tile::get_default_stride((std::is_same_v) ? (K / 2) : K, + N, stride_B, is_row_major(b_layout)); stride_C = ck_tile::get_default_stride(M, N, stride_C, is_row_major(CLayout{})); // Conditional stride calculation based on QuantMode @@ -434,7 +442,9 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, ck_tile::HostTensor a_m_k( ck_tile::host_tensor_descriptor(M, K, stride_A, is_row_major(a_layout))); ck_tile::HostTensor b_k_n( - ck_tile::host_tensor_descriptor(K, N, stride_B, is_row_major(b_layout))); + ck_tile::host_tensor_descriptor( + (std::is_same_v) ? 
(K / 2) : K, + N, stride_B, is_row_major(b_layout))); ck_tile::HostTensor c_m_n_dev_result( ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{}))); @@ -478,13 +488,23 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, { ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}( b_k_n); + ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( + *bq_tensor_ptr); + } + else if constexpr(std::is_same_v) + { + ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}( + b_k_n); + ck_tile::FillUniformDistribution{125.f, 130.f, fill_seed(gen)}( + *bq_tensor_ptr); } else { ck_tile::FillUniformDistribution{-2.0f, 3.0f, fill_seed(gen)}(b_k_n); - } - ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( + ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( *bq_tensor_ptr); + } + ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}(a_m_k); } else if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped) @@ -684,7 +704,16 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, } else if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped) { - ck_tile::reference_gemm_quant) + ck_tile::reference_mxfp4gemm_quant(a_m_k, *bq_tensor_ptr, b_k_n, c_m_n_host_ref); + else + ck_tile::reference_gemm_quant) && GemmConfig::PreshuffleB) { throw std::runtime_error( - "Preshuffling weight matrix is not supported for AQuant or RowColQuant"); + "Preshuffling weight matrix is not supported for AQuant, RowColQuant or bf16_fp4_gemm"); } if constexpr(std::is_same_v || std::is_same_v || - std::is_same_v) + std::is_same_v || + std::is_same_v) { std::string a_layout = arg_parser.get_str("a_layout"); std::string b_layout = arg_parser.get_str("b_layout"); diff --git a/include/ck_tile/core/arch/amd_buffer_addressing.hpp b/include/ck_tile/core/arch/amd_buffer_addressing.hpp index 90137331f6..4c5b2b2bfc 100644 --- a/include/ck_tile/core/arch/amd_buffer_addressing.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing.hpp @@ -1550,9 +1550,11 @@ CK_TILE_DEVICE thread_buffer amd_buffer_load_impl(int32x4_t src_wave_buffe (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || (std::is_same::value && - (N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32) || + (N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32)) || + (std::is_same::value && + (N == 1 || N == 2 || N == 4 || N == 8 || N == 16))|| (std::is_same::value && - (N == 1 || N == 2 || N == 4 || N == 8 || N == 16))), + (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)), "wrong! 
not implemented"); using rtn_type = thread_buffer; diff --git a/include/ck_tile/host/check_err.hpp b/include/ck_tile/host/check_err.hpp index 1ef6b040eb..2faec8c946 100644 --- a/include/ck_tile/host/check_err.hpp +++ b/include/ck_tile/host/check_err.hpp @@ -53,7 +53,7 @@ CK_TILE_HOST double get_relative_threshold(const int number_of_accumulations = 1 { static_assert( - is_any_of::value, + is_any_of::value, "Warning: Unhandled ComputeDataType for setting up the relative threshold!"); double compute_error = 0; @@ -114,7 +114,7 @@ CK_TILE_HOST double get_absolute_threshold(const double max_possible_num, { static_assert( - is_any_of::value, + is_any_of::value, "Warning: Unhandled ComputeDataType for setting up the absolute threshold!"); auto expo = std::log2(std::abs(max_possible_num)); diff --git a/include/ck_tile/host/reference/reference_gemm.hpp b/include/ck_tile/host/reference/reference_gemm.hpp index 4d0f92f3e0..31ad141da4 100644 --- a/include/ck_tile/host/reference/reference_gemm.hpp +++ b/include/ck_tile/host/reference/reference_gemm.hpp @@ -246,6 +246,63 @@ CK_TILE_HOST void reference_gemm_tensor_quant(const HostTensor& a_m_k make_ParallelTensorFunctor(f_mn, M, N)(std::thread::hardware_concurrency()); } +template +CK_TILE_HOST void reference_mxfp4gemm_quant(const HostTensor& a_m_k, + const HostTensor& q, + const HostTensor& b_k_n, + HostTensor& c_m_n, + const AElementOp& a_element_op = {}, + const BElementOp& b_element_op = {}, + const ACCElementOp& acc_element_op = {}) +{ + const std::size_t M = a_m_k.get_length(0); + const std::size_t N = b_k_n.get_length(1); + const std::size_t K = a_m_k.get_length(1); + + auto f_mn = [&](auto m, auto n) { + AccDataType v_acc = 0; + AccDataType pasual = 0; + for(std::size_t k = 0; k < (K / 2); k++) + { + using ComputeType = float; + auto b_scale = type_convert(q((2 * k) / QuantGroupSize::kK, n)) - 127; + ComputeType v_a_0, v_a_1; + ComputeType v_b_0, v_b_1; + + v_a_0 = ck_tile::type_convert((a_element_op(a_m_k(m, 2 * k)))); + v_a_1 = ck_tile::type_convert((a_element_op(a_m_k(m, 2 * k + 1)))); + + if constexpr(std::is_same_v) + { + auto b_pack = type_convert(b_element_op(b_k_n(k, n))); + auto b_scale_fp4 = type_convert(std::pow(2.0f, b_scale)); + + auto b_f4_lo = type_convert(b_pack.unpack(number<0>{})); + auto b_f4_hi = type_convert(b_pack.unpack(number<1>{})); + + v_b_0 = type_convert(b_f4_lo) * b_scale_fp4; + v_b_1 = type_convert(b_f4_hi) * b_scale_fp4; + } + + pasual = v_a_0 * v_b_0 + v_a_1 * v_b_1; + v_acc += pasual; + } + c_m_n(m, n) = ck_tile::type_convert(acc_element_op(v_acc)); + }; + + make_ParallelTensorFunctor(f_mn, M, N)(std::thread::hardware_concurrency()); + std::cout << std::endl; +} + template struct typeToStr { static constexpr const char * name = "bf8" template <> struct typeToStr { static constexpr const char * name = "int8"; }; template <> struct typeToStr { static constexpr const char * name = "pk_int4"; }; template <> struct typeToStr { static constexpr const char * name = "pk_fp4"; }; +template <> struct typeToStr { static constexpr const char * name = "pk_fp4_raw"; }; template struct memOpToStr; template <> struct memOpToStr { static constexpr const char * name = "set"; }; diff --git a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp index 8a84f7e9bf..5dddb814ad 100644 --- a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp +++ b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp @@ -92,11 +92,17 @@ struct CShuffleEpilogue using ADataType = remove_cvref_t{}, 
AsDataTypeTuple>>; using BDataType = remove_cvref_t{}, BsDataTypeTuple>>; - using ATypeToUse = - std::conditional_t, BDataType, ADataType>; + using ATypeToUse = std::conditional_t || + std::is_same_v, + BDataType, + ADataType>; // Used for weight-only quantization kernel, B would be dequantized to the same data type as A - using BTypeToUse = - std::conditional_t, ADataType, BDataType>; + using BTypeToUse = std::conditional_t || + std::is_same_v || + std::is_same_v, + ADataType, + BDataType>; + using ELayout = remove_cvref_t; using CDElementwise = remove_cvref_t; static constexpr memory_operation_enum MemoryOperation = Problem::MemoryOperation; diff --git a/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp b/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp index 75a424e31e..b8cafbd6cc 100644 --- a/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp +++ b/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp @@ -97,7 +97,8 @@ struct BlockUniversalGemmAsBsCr using ATypeToUse = std::conditional_t, BDataType, ADataType>; using BTypeToUse = - std::conditional_t, ADataType, BDataType>; + std::conditional_t || std::is_same_v, + ADataType, BDataType>; using WarpGemm = remove_cvref_t; diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp index a05e07bbc4..33bee028d5 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp @@ -17,10 +17,12 @@ struct GemmPipelineAgBgCrImplBase using BsLayout = remove_cvref_t; using BlockGemmShape = remove_cvref_t; - using ADataType = remove_cvref_t{}, AsDataType>>; - using ALayout = remove_cvref_t{}, AsLayout>>; - using BDataType = remove_cvref_t{}, BsDataType>>; - using BLayout = remove_cvref_t{}, BsLayout>>; + using ADataType = remove_cvref_t{}, AsDataType>>; + using ALayout = remove_cvref_t{}, AsLayout>>; + using BInDataType = remove_cvref_t{}, BsDataType>>; + using BDataType = + std::conditional_t, ADataType, BInDataType>; + using BLayout = remove_cvref_t{}, BsLayout>>; static constexpr index_t MPerBlock = BlockGemmShape::kM; static constexpr index_t NPerBlock = BlockGemmShape::kN; @@ -254,12 +256,17 @@ struct GemmPipelineAgBgCrImplBase }(); auto b_copy_lds_window = make_tile_window(b_lds_block_view, b_lds_shape, {0, 0}); + using BLdsDataType = + std::conditional_t, + typename Problem::ADataType, + typename Problem::BDataType>; + auto b_lds_load_tile_distr = []() { if constexpr(is_b_load_tr) return make_static_tile_distribution( - typename InputTileDistributionTraits< - typename BLdsLoadTileDistr::DstrEncode, - typename Problem::BDataType>::TransposedDstrEncode{}); + typename InputTileDistributionTraits::TransposedDstrEncode{}); + else return BLdsLoadTileDistr{}; }(); diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp index fd95958995..ab7fec7ff5 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp @@ -281,8 +281,11 @@ struct UniversalGemmBasePolicy template CK_TILE_DEVICE static constexpr auto MakeBLdsBlockDescriptor() { + using BDataType = + std::conditional_t, + typename Problem::ADataType, + typename Problem::BDataType>; using BLayout = remove_cvref_t; - using BDataType = 
remove_cvref_t; constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; @@ -549,14 +552,18 @@ struct UniversalGemmBasePolicy template CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeB() { - using BsLayout = remove_cvref_t; - using BsDataType = remove_cvref_t; + using BsLayout = remove_cvref_t; + using BsDataType = remove_cvref_t; + using BLayout = remove_cvref_t{}, BsLayout>>; + using BInDataType = remove_cvref_t{}, BsDataType>>; + + using BDataType = std::conditional_t, + typename Problem::ADataType, + typename Problem::BDataType>; + constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; - using BLayout = remove_cvref_t{}, BsLayout>>; - using BDataType = remove_cvref_t{}, BsDataType>>; - if constexpr(std::is_same_v) { return GetGlobalVectorLoadSize; constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; - constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t KPerBlock = std::is_same_v + ? Problem::BlockGemmShape::kK / 2 + : Problem::BlockGemmShape::kK; constexpr index_t VecLoadSize = - Problem::FixedVectorSize ? Problem::VectorSizeB : GetVectorSizeB(); + std::is_same_v + ? 4 + : (Problem::FixedVectorSize ? Problem::VectorSizeB : GetVectorSizeB()); constexpr index_t NumWaveGroups = Problem::NumWaveGroups; - - using BLayout = remove_cvref_t< - std::tuple_element_t{}, remove_cvref_t>>; + using BLayout = remove_cvref_t< + std::tuple_element_t{}, remove_cvref_t>>; // Tile: KPerBlock X NPerBlock if constexpr(std::is_same_v) { @@ -796,10 +807,12 @@ struct UniversalGemmBasePolicy template CK_TILE_DEVICE static constexpr index_t GetSmemSizeB() { - constexpr index_t smem_size_b = - integer_least_multiple(sizeof(typename Problem::BDataType) * - Problem::BlockGemmShape::kN * Problem::BlockGemmShape::kK, - 16); + using BDataType = + std::conditional_t, + typename Problem::ADataType, + typename Problem::BDataType>; + constexpr index_t smem_size_b = integer_least_multiple( + sizeof(BDataType) * Problem::BlockGemmShape::kN * Problem::BlockGemmShape::kK, 16); return smem_size_b; } @@ -837,8 +850,9 @@ struct UniversalGemmPipelineAgBgCrPolicy using BDataType = remove_cvref_t; using ATypeToUse = std::conditional_t, BDataType, ADataType>; - using BTypeToUse = - std::conditional_t, ADataType, BDataType>; + using BTypeToUse = + std::conditional_t || std::is_same_v, + ADataType, BDataType>; using WarpGemm = WarpGemmDispatcher( - b_ptr, - make_tuple(kargs.N, splitk_batch_offset.splitted_k), - make_tuple(kargs.stride_B, 1), - number{}, - number<1>{}); + if constexpr(std::is_same_v) + return make_naive_tensor_view( + b_ptr, + make_tuple(kargs.N, splitk_batch_offset.splitted_k / 2), + make_tuple(kargs.stride_B, 1), + number{}, + number<1>{}); + else + return make_naive_tensor_view( + b_ptr, + make_tuple(kargs.N, splitk_batch_offset.splitted_k), + make_tuple(kargs.stride_B, 1), + number{}, + number<1>{}); } } } @@ -688,12 +696,20 @@ struct QuantGemmKernel { static_assert(std::is_same_v); using QuantGroupSize = remove_cvref_t; - return make_naive_tensor_view( - bq_ptr, - make_tuple(kargs.QK_B, integer_divide_ceil(kargs.N, QuantGroupSize::kN)), - make_tuple(1, kargs.stride_BQ), - number{}, - number<1>{}); + if constexpr(std::is_same_v) + return make_naive_tensor_view( + bq_ptr, + make_tuple(integer_divide_ceil(kargs.N, QuantGroupSize::kN), kargs.QK_B), + make_tuple(kargs.stride_BQ, 1), + number{}, + number<1>{}); + else + return make_naive_tensor_view( + 
bq_ptr, + make_tuple(kargs.QK_B, integer_divide_ceil(kargs.N, QuantGroupSize::kN)), + make_tuple(1, kargs.stride_BQ), + number{}, + number<1>{}); } else { @@ -757,7 +773,13 @@ struct QuantGemmKernel const auto& b_tensor_view = views.at(I2); if constexpr(std::is_same_v) { - return pad_tensor_view(b_tensor_view, + if constexpr(std::is_same_v) + return pad_tensor_view(b_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); + else + return pad_tensor_view(b_tensor_view, make_tuple(number{}, number{}), sequence{}); @@ -885,7 +907,13 @@ struct QuantGemmKernel { if constexpr(std::is_same_v) { - return make_tile_window(b_pad_view, + if constexpr(std::is_same_v) + return make_tile_window(b_pad_view, + make_tuple(number{}, + number{}), + {i_n, 0}); + else + return make_tile_window(b_pad_view, make_tuple(number{}, number{}), {i_n, 0}); @@ -912,11 +940,18 @@ struct QuantGemmKernel { static_assert(std::is_same_v); using QuantGroupSize = remove_cvref_t; - return make_tile_window( - bq_pad_view, - make_tuple(number{}, - number{}), - {0, i_n / QuantGroupSize::kN}); + if constexpr(std::is_same_v) + return make_tile_window( + bq_pad_view, + make_tuple(number{}, + number{}), + {i_n / QuantGroupSize::kN, 0}); + else + return make_tile_window( + bq_pad_view, + make_tuple(number{}, + number{}), + {0, i_n / QuantGroupSize::kN}); } else { @@ -986,7 +1021,7 @@ struct QuantGemmKernel { const auto& bq_block_window = gemm_tile_windows.at(I3); return GemmPipeline{}.template operator()( - a_block_window, b_block_window, bq_block_window, num_loop, smem_ptr_0); + a_block_window, b_block_window, bq_block_window, num_loop, smem_ptr_0); } else if constexpr(kQuantType == QuantType::RowColQuant || kQuantType == QuantType::TensorQuant) diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_base.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_base.hpp new file mode 100644 index 0000000000..58019d703e --- /dev/null +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_base.hpp @@ -0,0 +1,59 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/common.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp" + +namespace ck_tile { + +template +struct GemmMxFp4PipelineAgBgCrImplBase : public GemmPipelineAgBgCrImplBase +{ + using Base = GemmPipelineAgBgCrImplBase; + using ADataType = typename Base::ADataType; + using ALayout = typename Base::ALayout; + using BDataType = typename Base::BDataType; + using BLayout = typename Base::BLayout; + using BlockGemmShape = typename Base::BlockGemmShape; + using QuantGroupSize = remove_cvref_t; + + using BQLayout = remove_cvref_t; + + static constexpr index_t MPerBlock = BlockGemmShape::kM; + static constexpr index_t NPerBlock = BlockGemmShape::kN; + static constexpr index_t KPerBlock = BlockGemmShape::kK; + + static constexpr index_t NPerBlockBQ = NPerBlock / QuantGroupSize::kN; + static constexpr index_t KPerBlockBQ = KPerBlock / QuantGroupSize::kK; + + static_assert(NPerBlockBQ >= 1, "NPerBlock must be >= QuantGroupSize"); + static_assert(KPerBlockBQ >= 1, "KPerBlock must be >= QuantGroupSize"); + + static_assert(NPerBlock % QuantGroupSize::kN == 0, + "NPerBlock must be a multiple of QuantGroupSize::kN"); + static_assert(KPerBlock % QuantGroupSize::kK == 0, + "KPerBlock must be a multiple of QuantGroupSize::kK"); + + // Create DRAM tile window for BQ + template + CK_TILE_DEVICE constexpr auto + GetBQDramLoadWindow(const BQDramBlockWindowTmp& bq_dram_block_window_tmp) const + { + static_assert(std::is_same_v); + + using YPerTile = number; + using XPerTile = number; + + auto bq_copy_dram_window = + make_tile_window(bq_dram_block_window_tmp.get_bottom_tensor_view(), + make_tuple(YPerTile(), XPerTile()), + bq_dram_block_window_tmp.get_window_origin(), + Policy::template MakeBQDramTileDistribution()); + return bq_copy_dram_window; + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_policy.hpp new file mode 100644 index 0000000000..6ce2ff10fa --- /dev/null +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_policy.hpp @@ -0,0 +1,140 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp" +#include "gemm_group_quant_utils.hpp" + +namespace ck_tile { + +struct GemmMxFp4PipelineAgBgCrPolicy : public UniversalGemmPipelineAgBgCrPolicy +{ + using Base = UniversalGemmPipelineAgBgCrPolicy; + using Base::I0; + using Base::I1; + using Base::I2; + + template + CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeBQ() + { + using BQLayout = remove_cvref_t; + using BQDataType = remove_cvref_t; + constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t NPerBlockBQ = NPerBlock / Problem::QuantGroupSize::kN; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t KPerBlockBQ = KPerBlock / Problem::QuantGroupSize::kK; + + static_assert(std::is_same_v); + return GetABQGlobalVectorLoadSize(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeBRegTileDistribution() + { + using BLayout = remove_cvref_t; + + constexpr index_t BlockSize = Problem::kBlockSize; + constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t VecLoadSize = + Problem::FixedVectorSize ? Problem::VectorSizeB : GetVectorSizeB(); + constexpr index_t NumWaveGroups = Problem::NumWaveGroups; + // Tile: KPerBlock X NPerBlock + if constexpr(std::is_same_v) + { + using TileEncodingPattern = + tile_distribution_encoding_pattern_2d; + return TileEncodingPattern::make_2d_static_tile_distribution(); + } + // Tile: NPerBlock X KPerBlock + else + { + using TileEncodingPattern = + tile_distribution_encoding_pattern_2d; + return TileEncodingPattern::make_2d_static_tile_distribution(); + } + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeBQDramTileDistribution() + { + // using BLayout = remove_cvref_t; + + constexpr index_t BlockSize = Problem::kBlockSize; + constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + + constexpr index_t KScale = KPerBlock / Problem::QuantGroupSize::kK; // k_scale num //2 + constexpr index_t VecLoadSize = + Problem::FixedVectorSize ? Problem::VectorSizeB : GetVectorSizeB(); + constexpr index_t NumWaveGroups = Problem::NumWaveGroups; + + constexpr index_t warp_size = get_warp_size(); + constexpr index_t num_warps = BlockSize / get_warp_size(); + constexpr index_t LargestVec = (KPerBlock * NPerBlock) / (num_warps * warp_size); + constexpr index_t b_vec = VecLoadSize > LargestVec ? 
LargestVec : VecLoadSize; + constexpr index_t K0 = KPerBlock / b_vec; + constexpr index_t K1 = K0 / KScale; + constexpr index_t K3 = K0 / K1; + constexpr index_t K2 = 1; + + constexpr index_t N0 = num_warps / NumWaveGroups; + constexpr index_t N1 = warp_size / K0; + constexpr index_t N2 = NPerBlock / (N0 * N1); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2, 0>>, + tuple, sequence<1, 0, 0>>, + sequence<1, 2>, + sequence<2, 1>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm() + { + using BlockWarps = typename Problem::BlockGemmShape::BlockWarps; + using WarpTile = typename Problem::BlockGemmShape::WarpTile; + + static_assert(Problem::QuantGroupSize::kK % WarpTile::at(I2) == 0, + "KPerWarpGemm must be a multiple of QuantGroupSize!"); + + using WarpGemm = WarpGemmDispatcher; + static_assert(std::is_same_v || + std::is_same_v || + std::is_same_v); + static_assert(std::is_same_v); + + using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1CustomPolicy< + typename Problem::ADataType, + std::conditional_t, + typename Problem::ADataType, + typename Problem::BDataType>, + typename Problem::CDataType, + BlockWarps, + WarpGemm>; + + return BlockUniversalGemmAsBsCr{}; + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_v3.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_v3.hpp new file mode 100644 index 0000000000..52b9243fac --- /dev/null +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_v3.hpp @@ -0,0 +1,728 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" +#include "ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_base.hpp" +#include "ck_tile/host/concat.hpp" + +namespace ck_tile { + +// A Tile Window: global memory +// B Tile Window: global memory +// C Distributed tensor: register + +// Compute optimized pipeline +// GlobalPrefetchStages: 2 +// LocalPreFillStages: 1 +// LocalPreFetchStages: 1 +// LocalSharedMemoryBuffer: 1 +template +struct BaseMxFp4GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 +{ + template + CK_TILE_HOST_DEVICE static auto + TailHandler(const RunFunction& run_func, bool has_hot_loop, TailNumber tail_number) + { + if(has_hot_loop) + { + if(tail_number == ck_tile::TailNumber::Full) + { + return run_func( + ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else if(tail_number == ck_tile::TailNumber::Odd) + { + return run_func( + ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else if(tail_number == ck_tile::TailNumber::Even) + { + return run_func( + ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else + { + throw std::runtime_error("Unsupported tail number for this operation !!!"); + } + } + else + { + if(tail_number == ck_tile::TailNumber::Full) + { + return run_func( + ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else if(tail_number == ck_tile::TailNumber::Odd) + { + return run_func( + ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else if(tail_number == ck_tile::TailNumber::Even) + { + return run_func( + ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else 
+ { + throw std::runtime_error("Unsupported tail number for this operation !!!"); + } + } + } +}; + +template +struct MxFp4GemmPipelineAgBgCrCompV3 : public BaseMxFp4GemmPipelineAgBgCrCompV3 +{ + using Base = BaseGemmPipelineAgBgCrCompV3; + using PipelineImplBase = GemmMxFp4PipelineAgBgCrImplBase; + + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using BDqDataType = remove_cvref_t; + using BQDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using BlockGemmShape = remove_cvref_t; + using QuantGroupSize = remove_cvref_t; + + static_assert(QuantGroupSize::kM == 1, "only N/K blocks for BQuant kernel!"); + + using I0 = number<0>; + using I1 = number<1>; + using I2 = number<2>; + + static constexpr index_t APackedSize = + ck_tile::numeric_traits>::PackedSize; + static constexpr index_t BPackedSize = + ck_tile::numeric_traits>::PackedSize; + + static constexpr index_t BQPackedSize = + ck_tile::numeric_traits>::PackedSize; + + using ALayout = remove_cvref_t; + using BQLayout = remove_cvref_t; + using BLayout = remove_cvref_t; + using CLayout = remove_cvref_t; + + using BlockGemm = remove_cvref_t())>; + + static constexpr index_t BlockSize = Problem::kBlockSize; + static constexpr index_t MPerBlock = BlockGemmShape::kM; + static constexpr index_t NPerBlock = BlockGemmShape::kN; + static constexpr index_t KPerBlock = BlockGemmShape::kK; + + static constexpr index_t NPerBlockBQ = BlockGemmShape::kN / QuantGroupSize::kN; + static constexpr index_t KPerBlockBQ = BlockGemmShape::kK / QuantGroupSize::kK; + + static constexpr index_t GetVectorSizeA() { return Policy::template GetVectorSizeA(); } + static constexpr index_t GetVectorSizeB() { return Policy::template GetVectorSizeB(); } + static constexpr index_t GetVectorSizeC() { return Policy::template GetVectorSizeC(); } + static constexpr index_t GetVectorSizeBQ() + { + return Policy::template GetVectorSizeBQ(); + } + + static constexpr index_t GetSmemPackA() { return Policy::template GetSmemPackA(); } + static constexpr index_t GetSmemPackB() { return Policy::template GetSmemPackB(); } + + static constexpr bool kPadM = Problem::kPadM; + static constexpr bool kPadN = Problem::kPadN; + static constexpr bool kPadK = Problem::kPadK; + + static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer; + + static constexpr bool HasHotLoop = Problem::HasHotLoop; + static constexpr auto TailNum = Problem::TailNum; + static constexpr auto Scheduler = Problem::Scheduler; + + using Base::PrefetchStages; + + [[nodiscard]] CK_TILE_HOST static const std::string GetName() + { + // clang-format off + constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0{}); + constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{}); + return concat('_', "mxfp4gemm_pipeline_AgBgCrCompV3", + concat('x', MPerBlock, NPerBlock, KPerBlock), BlockSize, + concat('x', WaveNumM, WaveNumN), + concat('x', kPadM, kPadN, kPadK), + concat('x', kPadM, kPadN, kPadK), QuantGroupSize::GetName()); + // clang-format on + } + + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() + { + return Policy::template GetSmemSize(); + } + + CK_TILE_HOST static std::string Print() + { + constexpr index_t MPerXDL = BlockGemm::WarpGemm::kM; + constexpr index_t NPerXDL = BlockGemm::WarpGemm::kN; + constexpr index_t KPerXDL = BlockGemm::WarpGemm::WarpGemmAttribute::Impl::kK; + + constexpr index_t WaveSize = 64; + constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0{}); + constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{}); + + 
constexpr index_t A_LDS_Read_Width = GetSmemPackA(); + constexpr index_t B_LDS_Read_Width = GetSmemPackB(); + + constexpr index_t A_LDS_Write_Width = GetSmemPackA(); + constexpr index_t B_LDS_Write_Width = GetSmemPackB(); + + constexpr index_t A_Buffer_Load_Inst_Num = + MPerBlock * KPerBlock / (BlockSize * GetVectorSizeA()); + constexpr index_t B_Buffer_Load_Inst_Num = + NPerBlock * KPerBlock / (BlockSize * GetVectorSizeB()); + constexpr index_t BQ_Buffer_Load_Inst_Num = + NPerBlock * KPerBlockBQ / (BlockSize * GetVectorSizeBQ()); + + constexpr index_t A_LDS_Write_Inst_Num = + MPerBlock * KPerBlock / (BlockSize * A_LDS_Write_Width); + constexpr index_t B_LDS_Write_Inst_Num = + NPerBlock * KPerBlock / (BlockSize * B_LDS_Write_Width); + + constexpr index_t A_LDS_Read_Inst_Num = + WaveNumN * MPerBlock * KPerBlock / (BlockSize * A_LDS_Read_Width); + constexpr index_t B_LDS_Read_Inst_Num = + WaveNumM * NPerBlock * KPerBlock / (BlockSize * B_LDS_Read_Width); + + constexpr index_t C_MFMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock / + (BlockSize / WaveSize) / (MPerXDL * NPerXDL * KPerXDL); + + auto str = std::stringstream{}; + + str << "A/B vector size: " << GetVectorSizeA() << ", " << GetVectorSizeB() << ", " + << "BQ vector size: " << GetVectorSizeBQ() << "\n" + << "A/B LDS read/write width: " << A_LDS_Read_Width << ", " << B_LDS_Read_Width << "\n" + << "A/B buffer load inst: " << A_Buffer_Load_Inst_Num << ", " << B_Buffer_Load_Inst_Num + << ", " << "BQ buffer load inst: " << BQ_Buffer_Load_Inst_Num << "\n" + << "A/B LDS write inst: " << A_LDS_Write_Inst_Num << ", " << B_LDS_Write_Inst_Num + << "\n" + << "A/B LDS read inst: " << A_LDS_Read_Inst_Num << ", " << B_LDS_Read_Inst_Num << "\n" + << "C MFMA inst: " << C_MFMA_Inst_Num << "\n" + << "QuantGroupSize: " << QuantGroupSize::GetName() << "\n" + << "KPack: " << BlockGemm::Traits::KPack << "\n" + << "PrefetchStages: " << PrefetchStages << "\n"; + return str.str(); + } + + template + struct PipelineImpl : public PipelineImplBase + { + }; + + template <> + struct PipelineImpl : public PipelineImplBase + { + using Base = PipelineImplBase; + + CK_TILE_DEVICE static constexpr auto HotLoopScheduler() + { + constexpr index_t MPerXDL = BlockGemm::WarpGemm::kM; + constexpr index_t NPerXDL = BlockGemm::WarpGemm::kN; + constexpr index_t KPerXDL = BlockGemm::WarpGemm::WarpGemmAttribute::Impl::kK; + + constexpr index_t WaveSize = 64; + constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0{}); + constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{}); + + // Below should be equal to AK1|BK1 + constexpr index_t A_LDS_Read_Width = GetSmemPackA(); + constexpr index_t B_LDS_Read_Width = GetSmemPackB(); + + constexpr index_t A_LDS_Write_Width = GetSmemPackA(); + constexpr index_t B_LDS_Write_Width = GetSmemPackB(); + + constexpr index_t A_Buffer_Load_Inst_Num = + MPerBlock * KPerBlock / (BlockSize * GetVectorSizeA()); + constexpr index_t B_Buffer_Load_Inst_Num = + NPerBlock * KPerBlock / (BlockSize * GetVectorSizeB()); + + constexpr index_t A_LDS_Write_Inst_Num = + MPerBlock * KPerBlock / (BlockSize * A_LDS_Write_Width); + constexpr index_t B_LDS_Write_Inst_Num = + NPerBlock * KPerBlock / (BlockSize * B_LDS_Write_Width); + + constexpr index_t A_LDS_Read_Inst_Num = + WaveNumN * MPerBlock * KPerBlock / (BlockSize * A_LDS_Read_Width); + constexpr index_t B_LDS_Read_Inst_Num = + WaveNumM * NPerBlock * KPerBlock / (BlockSize * B_LDS_Read_Width); + + constexpr index_t C_MFMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock / + (BlockSize / WaveSize) / + 
(MPerXDL * NPerXDL * KPerXDL); + + // A/B split schedule + // compiler is likely to use ds_read2 when instruction width smaller than 16bytes + constexpr auto num_ds_read_inst_a = + A_LDS_Read_Width * sizeof(ADataType) / APackedSize == 16 ? A_LDS_Read_Inst_Num + : A_LDS_Read_Inst_Num / 2; + constexpr auto num_ds_read_inst_b = + B_LDS_Read_Width * sizeof(BDqDataType) / BPackedSize == 16 + ? B_LDS_Read_Inst_Num + : B_LDS_Read_Inst_Num / 2; + + constexpr auto num_ds_write_inst_a = A_LDS_Write_Inst_Num; + constexpr auto num_ds_write_inst_b = B_LDS_Write_Inst_Num; + + constexpr auto num_buffer_load_inst_a = A_Buffer_Load_Inst_Num; + constexpr auto num_buffer_load_inst_b = B_Buffer_Load_Inst_Num; + + constexpr auto num_mfma_inst = C_MFMA_Inst_Num; + + constexpr auto mfma_cycle = NPerXDL == 16 ? 16 : 32; + constexpr auto ds_read_a_issue_cycle = + A_LDS_Read_Width * sizeof(ADataType) / APackedSize == 16 ? 8 : 4; + constexpr auto ds_read_b_issue_cycle = + B_LDS_Read_Width * sizeof(BDqDataType) / BPackedSize == 16 ? 8 : 4; + constexpr auto ds_read_a_mfma_rate = + (mfma_cycle - 4 + 2 * ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle); + constexpr auto ds_read_b_mfma_rate = + (mfma_cycle - 4 + 2 * ds_read_b_issue_cycle - 1) / (2 * ds_read_b_issue_cycle); + + constexpr auto num_dsread_a_mfma = + (num_ds_read_inst_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate; + constexpr auto num_dsread_b_mfma = + (num_ds_read_inst_b + ds_read_b_mfma_rate - 1) / ds_read_b_mfma_rate; + + // stage 1 + // Separate this part? + // constexpr auto num_mfma_per_ds_read = sizeof(ComputeDataType) / sizeof(ADataType) > + // sizeof(ComputeDataType) / + // sizeof(BDataType) + // ? sizeof(ComputeDataType) / + // sizeof(ADataType) : sizeof(ComputeDataType) + // / sizeof(BDataType); + constexpr auto num_mfma_stage1 = + num_mfma_inst - (num_dsread_a_mfma + num_dsread_b_mfma); + constexpr auto num_mfma_per_issue = + num_mfma_stage1 / (num_buffer_load_inst_a + num_buffer_load_inst_b); + constexpr auto num_dswrite_per_issue_a = num_ds_write_inst_a / num_buffer_load_inst_a; + constexpr auto num_dswrite_per_issue_b = num_ds_write_inst_b / num_buffer_load_inst_b; + + static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) { + ignore = i; + static_for<0, num_dswrite_per_issue_a, 1>{}([&](auto idswrite) { + ignore = idswrite; + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier( + 0x008, num_mfma_per_issue - num_dswrite_per_issue_a, 0); // MFMA + }); + static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) { + ignore = i; + static_for<0, num_dswrite_per_issue_b, 1>{}([&](auto idswrite) { + ignore = idswrite; + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier( + 0x008, num_mfma_per_issue - num_dswrite_per_issue_b, 0); // MFMA + }); + + // stage 2 + static_for<0, num_dsread_a_mfma, 1>{}([&](auto i) { + if constexpr((num_ds_read_inst_a - (i + 1) * ds_read_a_mfma_rate) >= + ds_read_a_mfma_rate) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read + } + else + { + __builtin_amdgcn_sched_group_barrier( + 0x100, + num_ds_read_inst_a - (num_dsread_a_mfma - 1) * ds_read_a_mfma_rate, + 0); // DS read + } + 
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + + static_for<0, num_dsread_b_mfma, 1>{}([&](auto i) { + if constexpr((num_ds_read_inst_b - (i + 1) * ds_read_b_mfma_rate) >= + ds_read_b_mfma_rate) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); // DS read + } + else + { + __builtin_amdgcn_sched_group_barrier( + 0x100, + num_ds_read_inst_b - (num_dsread_b_mfma - 1) * ds_read_b_mfma_rate, + 0); // DS read + } + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + } + + template + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const AElementFunction& a_element_func, + const BDramBlockWindowTmp& b_dram_block_window_tmp, + const BElementFunction& b_element_func, + const BQDramBlockWindowTmp& bq_dram_block_window_tmp, + index_t num_loop, + void* p_smem) const + { + static_assert( + std::is_same_v> && + std::is_same_v> && + std::is_same_v>, + "A/B/BQ Dram block window should have the same data type as appropriate " + "([A|B|BQ]DataType) defined in Problem definition!"); + + constexpr bool is_a_col_major = + std::is_same_v; + constexpr bool is_bq_col_major = + std::is_same_v; + constexpr bool is_b_row_major = std::is_same_v; + + static_assert(is_bq_col_major, "Bq must be col major (row major not supported yet)"); + static_assert(NPerBlockBQ == BQDramBlockWindowTmp{}.get_window_lengths()[I0{}] && + KPerBlockBQ == BQDramBlockWindowTmp{}.get_window_lengths()[I1{}], + "Bq block window has incorrect lengths for defined BqLayout!"); + + static_assert(is_a_col_major + ? (KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] && + MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]) + : (MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] && + KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]), + "A block window has incorrect lengths for defined ALayout!"); + static_assert( + is_b_row_major + ? 
(KPerBlock / 2 == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] && + NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]) + : (NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] && + KPerBlock / 2 == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]), + "B block window has incorrect lengths for defined BLayout!"); + + // ------------------------------------------------------------------------------------ + // Definitions of all needed tiles + // int b_block_stride = 0; + // A/B tiles in LDS + auto&& [a_lds_block, b_lds_block] = Base::GetABLdsTensorViews(p_smem); + + // Tile distribution for load from lds + constexpr auto a_lds_load_tile_distr = + make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode()); + constexpr auto b_lds_load_tile_distr = + make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode()); + + // A DRAM tile window for load + // A LDS tile window for store + // A LDS tile for block GEMM + auto&& [a_copy_dram_window, a_copy_lds_window, a_lds_gemm_window] = + Base::GetAWindows(a_dram_block_window_tmp, a_lds_block, a_lds_load_tile_distr); + + // B DRAM tile window for load, (kN, kK/2) + // B LDS tile window for store, (kN, kK) + // B LDS tile for block GEMM + auto&& [b_copy_dram_window, b_copy_lds_window, b_lds_gemm_window] = + Base::GetBWindows(b_dram_block_window_tmp, b_lds_block, b_lds_load_tile_distr); + + // B scale DRAM tile window for load + // auto b_scale_copy_dram_window = + // make_tile_window(bq_dram_block_window_tmp.get_bottom_tensor_view(), + // bq_dram_block_window_tmp.get_window_lengths(), + // bq_dram_block_window_tmp.get_window_origin(), + // Policy::template GetBQDramLoadWindow()); + auto bq_copy_dram_window = Base::GetBQDramLoadWindow(bq_dram_block_window_tmp); + + auto bq_block_tile = decltype(load_tile(bq_copy_dram_window)){}; + + // Block GEMM + auto block_gemm = BlockGemm(); + auto c_block_tile = block_gemm.MakeCBlockTile(); + using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution()); + // using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution()); + using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution()); + + using ABlockTile = + decltype(make_static_distributed_tensor(ABlockTileDistr{})); + using BBlockTile = + decltype(make_static_distributed_tensor(BBlockTileDistr{})); + + ABlockTile a_block_tile; + BBlockTile b_fp4_block_tile; + + using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex; + using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex; + + constexpr ADramTileWindowStep a_dram_tile_window_step = + is_a_col_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock); + constexpr BDramTileWindowStep b_dram_tile_window_step = + is_b_row_major ? 
make_array(KPerBlock / 2, 0) : make_array(0, KPerBlock / 2); + + constexpr index_t b_scale_dram_tile_window_step = KPerBlock / QuantGroupSize::kK; + // ----------------------------------------------------------------------------------------- + // Gemm pipeline start + + // prefetch + // global read 0 + // auto a_scale_block_tile = decltype(load_tile(a_scale_copy_dram_window)){}; + Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step); + Base::GlobalPrefetch(b_fp4_block_tile, b_copy_dram_window, b_dram_tile_window_step); + // BDataType + auto b_block_tile = make_static_distributed_tensor( + Policy::template MakeBRegTileDistribution()); + + bq_block_tile = load_tile(bq_copy_dram_window); + move_tile_window(bq_copy_dram_window, {0, b_scale_dram_tile_window_step}); + + constexpr auto idx1_js = tile_distributed_index<0>{}; + constexpr auto b_block = decltype(b_fp4_block_tile)::get_distributed_spans(); + sweep_tile_span(b_block[number<0>{}], [&](auto idx0) { + sweep_tile_span(b_block[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + constexpr auto i_j_idx_scale = make_tuple(idx0, idx1_js); + auto b_scale_uint = type_convert(bq_block_tile(i_j_idx_scale)) - 127; + auto b_scale = type_convert(std::pow(2.0f, b_scale_uint)); + constexpr auto idx1_lo = tile_distributed_index{}; + constexpr auto idx1_hi = tile_distributed_index{}; + constexpr auto i_j_idx_lo = make_tuple(idx0, idx1_lo); + constexpr auto i_j_idx_hi = make_tuple(idx0, idx1_hi); + + auto b_pack = type_convert(b_fp4_block_tile(i_j_idx)); + auto b_f4_lo = type_convert(b_pack.unpack(number<0>{})); + auto b_f4_hi = type_convert(b_pack.unpack(number<1>{})); + b_block_tile(i_j_idx_lo) = + type_convert(type_convert(b_f4_lo) * b_scale); + b_block_tile(i_j_idx_hi) = + type_convert(type_convert(b_f4_hi) * b_scale); + }); + }); + + // initialize C + tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); + block_sync_lds(); + + // LDS write 0 + if constexpr(is_a_col_major) + { + auto a_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledARegTileDistribution()); + transpose_tile2d(a_shuffle_tmp, a_block_tile); + Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); + } + else + { + Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func); + } + + if constexpr(is_b_row_major) + { + auto b_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledBRegTileDistribution()); + transpose_tile2d(b_shuffle_tmp, b_block_tile); + Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func); + } + else + { + Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func); + } + + Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step); + Base::GlobalPrefetch(b_fp4_block_tile, b_copy_dram_window, b_dram_tile_window_step); + + bq_block_tile = load_tile(bq_copy_dram_window); + move_tile_window(bq_copy_dram_window, {0, b_scale_dram_tile_window_step}); + + sweep_tile_span(b_block[number<0>{}], [&](auto idx0) { + sweep_tile_span(b_block[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + constexpr auto i_j_idx_scale = make_tuple(idx0, idx1_js); + + auto b_scale_uint = type_convert(bq_block_tile(i_j_idx_scale)) - 127; + auto b_scale = type_convert(std::pow(2.0f, b_scale_uint)); + constexpr auto idx1_lo = tile_distributed_index{}; + constexpr auto idx1_hi = tile_distributed_index{}; + constexpr auto i_j_idx_lo = make_tuple(idx0, idx1_lo); + constexpr auto i_j_idx_hi = 
make_tuple(idx0, idx1_hi); + + auto b_pack = type_convert(b_fp4_block_tile(i_j_idx)); + auto b_f4_lo = type_convert(b_pack.unpack(number<0>{})); + auto b_f4_hi = type_convert(b_pack.unpack(number<1>{})); + b_block_tile(i_j_idx_lo) = + type_convert(type_convert(b_f4_lo) * b_scale); + b_block_tile(i_j_idx_hi) = + type_convert(type_convert(b_f4_hi) * b_scale); + }); + }); + + block_sync_lds(); + block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); + + __builtin_amdgcn_sched_barrier(0); + + // main body + if constexpr(HasHotLoop) + { + index_t i = 0; + do + { + block_sync_lds(); + + if constexpr(is_a_col_major) + { + auto a_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledARegTileDistribution()); + transpose_tile2d(a_shuffle_tmp, a_block_tile); + Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); + } + else + { + Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func); + } + if constexpr(is_b_row_major) + { + auto b_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledBRegTileDistribution()); + transpose_tile2d(b_shuffle_tmp, b_block_tile); + Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func); + } + else + { + Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func); + } + + Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step); + Base::GlobalPrefetch( + b_fp4_block_tile, b_copy_dram_window, b_dram_tile_window_step); + + bq_block_tile = load_tile(bq_copy_dram_window); + move_tile_window(bq_copy_dram_window, {0, b_scale_dram_tile_window_step}); + + sweep_tile_span(b_block[number<0>{}], [&](auto idx0) { + sweep_tile_span(b_block[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + constexpr auto i_j_idx_scale = make_tuple(idx0, idx1_js); + + auto b_scale_uint = + type_convert(bq_block_tile(i_j_idx_scale)) - 127; + auto b_scale = type_convert(std::pow(2.0f, b_scale_uint)); + constexpr auto idx1_lo = tile_distributed_index{}; + constexpr auto idx1_hi = + tile_distributed_index{}; + constexpr auto i_j_idx_lo = make_tuple(idx0, idx1_lo); + constexpr auto i_j_idx_hi = make_tuple(idx0, idx1_hi); + + auto b_pack = type_convert(b_fp4_block_tile(i_j_idx)); + auto b_f4_lo = type_convert(b_pack.unpack(number<0>{})); + auto b_f4_hi = type_convert(b_pack.unpack(number<1>{})); + b_block_tile(i_j_idx_lo) = + type_convert(type_convert(b_f4_lo) * b_scale); + b_block_tile(i_j_idx_hi) = + type_convert(type_convert(b_f4_hi) * b_scale); + }); + }); + + block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); + block_sync_lds(); + + block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); + HotLoopScheduler(); + __builtin_amdgcn_sched_barrier(0); + + i += 1; + // b_block_stride +=1; + } while(i < (num_loop - 1)); + } + // tile_elementwise_inout([](auto& c) { c = 0; }, acc_block_tile); + // tail + if constexpr((TailNum == TailNumber::Full) || (TailNum == TailNumber::Odd)) + { + // Leak last MFMA block to epilogue region, cover the potential lds-shuffle + // latency + block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); + } + else + { + block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); + block_sync_lds(); + + if constexpr(is_a_col_major) + { + auto a_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledARegTileDistribution()); + transpose_tile2d(a_shuffle_tmp, a_block_tile); + Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); + } + else + { + 
Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func); + } + if constexpr(is_b_row_major) + { + auto b_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledBRegTileDistribution()); + transpose_tile2d(b_shuffle_tmp, b_block_tile); + Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func); + } + else + { + Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func); + } + + block_sync_lds(); + block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); + + block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); + block_sync_lds(); + } + __builtin_amdgcn_sched_barrier(0); + return c_block_tile; + } + }; + + /** + * @brief This function runs the pipeline using compile-time known hot loop and tail number. + * @param num_loop The number of loop iterations. This is determined at runtime due to e.g. + * SplitK. + * @note This is used by the kernel variants that are able to determine + * hot loop and tail number on the host side, e.g. non-persistent gemm kernel. + */ + template + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const BDramBlockWindowTmp& b_dram_block_window_tmp, + const BQDramBlockWindowTmp& bq_dram_block_window_tmp, + index_t num_loop, + void* p_smem) const + { + return PipelineImpl{}.template operator()( + a_dram_block_window_tmp, + [](const ADataType& a) { return a; }, + b_dram_block_window_tmp, + [](const BDqDataType& b) { return b; }, + bq_dram_block_window_tmp, + num_loop, + p_smem); + } +}; + +} // namespace ck_tile diff --git a/test/ck_tile/gemm_block_scale/CMakeLists.txt b/test/ck_tile/gemm_block_scale/CMakeLists.txt old mode 100644 new mode 100755 diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp old mode 100644 new mode 100755 index 08232f81be..e91a90ab79 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp @@ -127,8 +127,9 @@ class TestCkTileGemmQuantBase : public ::testing::Test const ck_tile::index_t kbatch, const float max_accumulated_value) { - using ComputeType = - std::conditional_t; + using ComputeType = + std::conditional_t, ADataType_, + std::conditional_t>; // Calculate thresholds const auto rtol = ck_tile::get_relative_threshold( ck_tile::integer_divide_ceil(K, kbatch)); diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp old mode 100644 new mode 100755 index 0be276de8d..5f9606ae4b --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp @@ -42,7 +42,7 @@ struct GemmConfigBase // Default GEMM tile sizes for tests static constexpr ck_tile::index_t M_Tile = 16; static constexpr ck_tile::index_t N_Tile = 128; - static constexpr ck_tile::index_t K_Tile = 256; + static constexpr ck_tile::index_t K_Tile = 128; static constexpr ck_tile::index_t M_Warp = 1; static constexpr ck_tile::index_t N_Warp = 4; @@ -382,7 +382,7 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase ? 
(K / 2) : K; const ck_tile::index_t stride_C = N; // BQuant uses block/grouped quantization for B matrix @@ -394,14 +394,24 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase a_m_k( ck_tile::host_tensor_descriptor(M, K, stride_A, this->is_row_major(ALayout{}))); ck_tile::HostTensor b_k_n( - ck_tile::host_tensor_descriptor(K, N, stride_B, this->is_row_major(BLayout{}))); + ck_tile::host_tensor_descriptor(std::is_same_v ? K / 2 : K, + N, stride_B, this->is_row_major(BLayout{}))); ck_tile::HostTensor bq_bqk_bqn( ck_tile::host_tensor_descriptor(BQK, BQN, stride_BQ, this->is_row_major(BLayout{}))); // Initialize data with random values ck_tile::FillUniformDistribution{-0.5f, 0.5f}(a_m_k); - ck_tile::FillUniformDistribution{0.f, 1.f}(b_k_n); - ck_tile::FillUniformDistribution{-1.0f, 1.0f}(bq_bqk_bqn); + if constexpr(std::is_same_v) + { + ck_tile::FillUniformDistribution{-5.0f, 5.0f}(b_k_n); + ck_tile::FillUniformDistribution{125.f, 130.f}(bq_bqk_bqn); + } + else + { + ck_tile::FillUniformDistribution{0.f, 1.f}(b_k_n); + ck_tile::FillUniformDistribution{-1.0f, 1.0f}(bq_bqk_bqn); + } + // Allocate device memory ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size() * sizeof(ADataType)); ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size() * sizeof(BDataType)); @@ -474,13 +484,22 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase(a_m_k, bq_bqk_bqn, b_k_n, c_m_n_host_ref); + if constexpr(std::is_same_v) + ck_tile::reference_mxfp4gemm_quant(a_m_k, bq_bqk_bqn, b_k_n, c_m_n_host_ref); + else + ck_tile::reference_gemm_quant(a_m_k, bq_bqk_bqn, b_k_n, c_m_n_host_ref); // Get device result ck_tile::HostTensor c_m_n_dev_result( @@ -528,7 +547,9 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase, + std::conditional_t, + ck_tile::BaseMxFp4GemmPipelineAgBgCrCompV3, + ck_tile::BaseGemmPipelineAgBgCrCompV3>, ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2>; const ck_tile::index_t K_split = (args.K + Base::K_Tile - 1) / Base::K_Tile * Base::K_Tile; @@ -555,12 +576,15 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase, + std::conditional_t, + ck_tile::MxFp4GemmPipelineAgBgCrCompV3, + ck_tile::BQuantGemmPipelineAgBgCrCompV3>, ck_tile::WPQuantBPipelineAgBgCrV2>; using GemmEpilogue = ck_tile::CShuffleEpilogue< ck_tile::CShuffleEpilogueProblem, + ADataType, BDataType>, ck_tile::tuple<>, AccDataType, CDataType, diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_typed.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_typed.cpp old mode 100644 new mode 100755 index 3ace9188cc..2eb2778a78 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_typed.cpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_typed.cpp @@ -16,12 +16,15 @@ using FP8 = ck_tile::fp8_t; using BF8 = ck_tile::bf8_t; using Half = ck_tile::half_t; using PkInt4 = ck_tile::pk_int4_t; +using BF16 = ck_tile::bf16_t; +using UInt8 = ck_tile::pk_fp4_raw_t; using AQuantGrouped = std::integral_constant; using BQuantGrouped = std::integral_constant; using RowColQuant = std::integral_constant; using TensorQuant = std::integral_constant; using GroupSize = ck_tile::QuantGroupShape>; using GroupSize64 = ck_tile::QuantGroupShape>; +using GroupSize32 = ck_tile::QuantGroupShape>; // 2d block sizes for BQuant using GroupSize2D8N = ck_tile::QuantGroupShape>; @@ -39,6 +42,7 @@ using AQuantTypes = ::testing::Types< std::tuple, std::tuple, + // PreshuffleQuant = false && TransposeC = true std::tuple, std::tuple, @@ -66,11 +70,15 @@ using BQuantTypes = ::testing::Types< std::tuple, std::tuple, 
std::tuple, + std::tuple, std::tuple, std::tuple, std::tuple, std::tuple, + std::tuple, + + std::tuple, // 2d cases with grouping also on the n axis std::tuple, diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_ut_cases.inc b/test/ck_tile/gemm_block_scale/test_gemm_quant_ut_cases.inc old mode 100644 new mode 100755
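Note on the bf16 x mxfp4 path introduced by this patch: B is stored as one uint8 per two packed fp4 (e2m1) values, which is why the B host tensors and stride_B above use K/2 rows, and each quant group carries one e8m0 scale byte applied as 2^(scale - 127), as in reference_mxfp4gemm_quant. The following is a minimal, self-contained C++ sketch of that host-side dequantization under those assumptions; the helper names and the nibble order (low nibble holds the even K index) are illustrative assumptions and not the ck_tile API.

#include <cmath>
#include <cstdint>
#include <vector>

// Decode one fp4 (e2m1) nibble: 1 sign bit, 2 exponent bits, 1 mantissa bit.
// Representable magnitudes: 0, 0.5, 1, 1.5, 2, 3, 4, 6.
inline float fp4_e2m1_to_float(std::uint8_t nibble)
{
    const float sign = (nibble & 0x8) ? -1.0f : 1.0f;
    const int exp    = (nibble >> 1) & 0x3;
    const int man    = nibble & 0x1;
    if(exp == 0)
        return sign * 0.5f * static_cast<float>(man); // subnormal: 0 or 0.5
    return sign * std::ldexp(1.0f + 0.5f * static_cast<float>(man), exp - 1); // 1 .. 6
}

// Dequantize one column of B: K/2 packed bytes plus one e8m0 scale byte per
// group_k consecutive K elements (bias 127, i.e. scale = 2^(q - 127)).
inline std::vector<float> dequant_mxfp4_column(const std::vector<std::uint8_t>& b_packed, // K/2 bytes
                                               const std::vector<std::uint8_t>& scales,   // K/group_k bytes
                                               int K,
                                               int group_k)
{
    std::vector<float> b(static_cast<std::size_t>(K));
    for(int k = 0; k < K / 2; ++k)
    {
        const float scale = std::ldexp(1.0f, static_cast<int>(scales[(2 * k) / group_k]) - 127);
        const std::uint8_t byte = b_packed[static_cast<std::size_t>(k)];
        b[2 * k]     = fp4_e2m1_to_float(byte & 0x0F) * scale;        // assumed: low nibble = element 2k
        b[2 * k + 1] = fp4_e2m1_to_float((byte >> 4) & 0x0F) * scale; // assumed: high nibble = element 2k+1
    }
    return b;
}

The device pipeline performs the same two-way unpack and 2^(q - 127) scaling in registers before feeding bf16 values to the block GEMM, so the host reference and the kernel should agree up to rounding.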