Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
Show all changes
55 commits
Select commit Hold shift + click to select a range
d1bf200
support bf16*mxfp4 gemm
Sep 5, 2025
4e205c4
rebase bf16*fp4 example to develop branch
k50112113 Sep 8, 2025
52c5ed5
Clean up commented debug code in GEMM kernel
eliotwang Sep 8, 2025
e1d0365
Merge branch 'develop' into bf16_fp4_gemm
eliotwang Sep 8, 2025
43db1f7
Merge branch 'develop' into bf16_fp4_gemm
eliotwang Sep 9, 2025
ff89459
rename example folder
Sep 9, 2025
1409d62
Merge branch 'develop' into bf16_fp4_gemm
eliotwang Sep 9, 2025
637f2e8
Merge branch 'develop' into bf16_fp4_gemm
eliotwang Sep 10, 2025
30450e3
Merge branch 'develop' into bf16_fp4_gemm
eliotwang Sep 10, 2025
291e36b
Merge branch 'develop' into bf16_fp4_gemm
eliotwang Sep 11, 2025
d2c79f8
Merge branch 'develop' into bf16_fp4_gemm
eliotwang Sep 12, 2025
ba84541
Merge branch 'develop' into bf16_fp4_gemm
eliotwang Sep 16, 2025
e31f9df
Merge branch 'develop' into bf16_fp4_gemm
eliotwang Sep 22, 2025
f2c0d77
Merge branch 'develop' into bf16_fp4_gemm
eliotwang Sep 24, 2025
f65c005
Merge branch 'develop' into bf16_fp4_gemm
eliotwang Sep 25, 2025
78ae8aa
support bf16*mxfp4 gemm
Sep 5, 2025
e953070
rebase bf16*fp4 example to develop branch
k50112113 Sep 8, 2025
9d01db5
Clean up commented debug code in GEMM kernel
eliotwang Sep 8, 2025
28d4d24
rename example folder
Sep 9, 2025
8229d64
rebase to new develop
k50112113 Oct 10, 2025
23c89b3
rebase to new develop
k50112113 Oct 10, 2025
ec53824
Merge branch 'develop' into bf16_fp4_gemm
eliotwang Oct 10, 2025
adb4bc3
Merge branch 'develop' into bf16_fp4_gemm
eliotwang Oct 16, 2025
5304448
Merge branch 'develop' into bf16_fp4_gemm
illsilin Oct 16, 2025
75dbf17
fix clang format
illsilin Oct 16, 2025
3efca0f
Merge branch 'develop' into bf16_fp4_gemm
eliotwang Oct 17, 2025
8aec6b9
Merge branch 'develop' into bf16_fp4_gemm
eliotwang Oct 20, 2025
5e91e9e
Merge branch 'develop' into bf16_fp4_gemm
eliotwang Oct 21, 2025
cba5ab1
Merge branch 'develop' into bf16_fp4_gemm
eliotwang Oct 22, 2025
5697816
Merge branch 'develop' into bf16_fp4_gemm
eliotwang Oct 23, 2025
8f272b3
Merge branch 'develop' into bf16_fp4_gemm
illsilin Oct 23, 2025
8bad07a
Merge branch 'develop' into bf16_fp4_gemm
eliotwang Oct 27, 2025
03406c0
Merge branch 'develop' into bf16_fp4_gemm
eliotwang Oct 29, 2025
984ed9f
Merge remote-tracking branch 'upstream/develop' into bf16_fp4_gemm
eliotwang Nov 11, 2025
46ff36a
update code according to reviewer's comment
eliotwang Nov 11, 2025
01f5c75
Update README.md
eliotwang Nov 11, 2025
92d7082
update code according to reviewer's comment
eliotwang Nov 13, 2025
7e32cb9
Merge branch 'bf16_fp4_gemm' of https://github.com/eliotwang/heyi_com…
eliotwang Nov 13, 2025
48e6393
update code according to reviewer's comment
eliotwang Nov 13, 2025
88c6a8c
Update CMakeLists.txt
eliotwang Nov 13, 2025
87c4e07
Merge branch 'develop' into bf16_fp4_gemm
eliotwang Nov 13, 2025
6883654
Merge remote-tracking branch 'upstream/develop' into bf16_fp4_gemm
eliotwang Nov 14, 2025
3579741
Update README.md
eliotwang Nov 14, 2025
9579e6f
Update CMakeLists.txt
eliotwang Nov 14, 2025
8a4ac27
Delete files
eliotwang Nov 14, 2025
f6ffb76
Delete files
eliotwang Nov 14, 2025
5225b4d
Merge branch 'develop' into bf16_fp4_gemm
eliotwang Nov 17, 2025
eb15154
Merge branch 'develop' into bf16_fp4_gemm
eliotwang Nov 18, 2025
ba12e7d
Merge branch 'develop' into bf16_fp4_gemm
eliotwang Nov 18, 2025
fde6e39
Add unit tests
eliotwang Nov 19, 2025
f54857f
Merge branch 'develop' into bf16_fp4_gemm
eliotwang Nov 19, 2025
7ceedeb
Merge branch 'develop' into bf16_fp4_gemm
eliotwang Nov 19, 2025
d5ce464
Merge branch 'develop' into bf16_fp4_gemm
eliotwang Nov 20, 2025
329c601
Update test_gemm_quant_base.hpp
eliotwang Nov 20, 2025
8c75bc1
Merge branch 'develop' into bf16_fp4_gemm
eliotwang Nov 20, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions example/ck_tile/40_fp4_uint8_gemm/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# FP4/uint8 GEMM example. EXCLUDE_FROM_ALL keeps the target out of the default
# build, so it has to be requested explicitly by name.
add_executable(tile_example_fp4_uint8_gemm EXCLUDE_FROM_ALL fp4_uint8_gemm.cpp)

# Forward the OCP FP8 selection to the sources as a preprocessor definition.
if(CK_USE_OCP_FP8)
    target_compile_options(tile_example_fp4_uint8_gemm PRIVATE -DCK_TILE_USE_OCP_FP8)
endif()

# Always passed, after the optional definition above: an LLVM backend option
# that switches off the noalias-to-metadata conversion for this target.
target_compile_options(tile_example_fp4_uint8_gemm
                       PRIVATE -mllvm -enable-noalias-to-md-conversion=0)
37 changes: 37 additions & 0 deletions example/ck_tile/40_fp4_uint8_gemm/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# GEMM Matrix Multiplication

This folder contains an example of GEMM using the ck_tile tile-programming implementation. Currently, it supports only the basic features of CK Tile GEMM, but it provides placeholders for future support of different GEMM pipelines and GEMM modules. In the near future, we will gradually migrate all GEMM features from the old CK to CK Tile.

## build
```
# in the root of ck_tile
mkdir build && cd build
# you can replace <arch> with the appropriate architecture (for example gfx90a or gfx942) or leave it blank
sh ../script/cmake-ck-dev.sh ../ <arch>
# Build the FP4 GEMM example (the target is excluded from the default build, so name it explicitly)
make tile_example_fp4_uint8_gemm -j
```
This will result in an executable `build/bin/tile_example_fp4_uint8_gemm`.

## example
```
args:
-b batch size (default:1)
-m m dimension (default:1024)
-n n dimension (default:2048)
-k k dimension (default:64)
-a_layout Tensor A data layout (default: R)
-b_layout Tensor B data layout (default: C)
-c_layout Tensor C data layout (default: R)
-stride_a Tensor A stride (default:0)
-stride_b Tensor B stride (default:0)
-stride_c Tensor C stride (default:0)
-v 0. No validation, 1. Validation on CPU, 2. Validation on GPU (default:2)
-e Absolute error tolerance (default:1e-5)
-prec data type. fp16/bf16/fp8/bf8/int8 (default:fp16)
-warmup number of iterations before benchmark the kernel (default:10)
-repeat number of iterations to benchmark the kernel (default:100)
-timer gpu:gpu timer, cpu:cpu timer (default:gpu)
```
25 changes: 25 additions & 0 deletions example/ck_tile/40_fp4_uint8_gemm/benchmark_universal_fp4.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#!/bin/bash
# Sweeps the FP4 GEMM example over a list of M sizes for two fixed (N, K)
# problem shapes. Run it from a directory that contains the built binary
# somewhere below it.

# Locate the example binary below the current directory (first match wins).
EXE="$(find . -name tile_example_fp4_uint8_gemm -type f | head -n 1)"
if [ -z "$EXE" ]; then
    # Without this guard the loop below would try to execute "-prec=pk_fp4_t"
    # as a command on every iteration.
    echo "error: tile_example_fp4_uint8_gemm not found under $(pwd); build it first" >&2
    exit 1
fi

# Value forwarded to the example's -v (validation) option.
VALID=1

# The two (N, K) problem shapes exercised for every M.
N1=5888
K1=3072
N2=3072
K2=2944


#m_values=(1 16 32 64 128 256 512 1024 4096 16384)
m_values=(1 16 64 256 512 1024 4096 16384)

for m in "${m_values[@]}"; do
    #echo "Running tests for m=$m"

    # echo "Running test with m=$m, n=$N1, k=$K1"
    "$EXE" -prec=pk_fp4_t -m="$m" -n="$N1" -k="$K1" -v="$VALID"

    # echo "Running test with m=$m, n=$N2, k=$K2"
    "$EXE" -prec=pk_fp4_t -m="$m" -n="$N2" -k="$K2" -v="$VALID"

    # echo "Finished tests for m=$m"
done
185 changes: 185 additions & 0 deletions example/ck_tile/40_fp4_uint8_gemm/block_quant_gemm_kernel.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <iostream>
#include <string>

#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/host/concat.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/host/stream_utils.hpp"
#include "ck_tile/core/utility/env.hpp"
#include "block_quant_universal_gemm_kernel.hpp"
#include "ck_tile/core/utility/type_traits.hpp"

namespace ck_tile {

/// @brief The block-quantized GEMM kernel host arguments.
///
/// @par Overview
///      This structure is passed to @ref Block_quant_GemmKernel "Block_quant_GemmKernel" when
///      creating the kernel arguments object. It contains all the information required to
///      build the kernel arguments and launch the kernel on the GPU.
///      This structure defines the GEMM problem configuration by stating all required
///      information like the M, N, K sizes and the respective strides, plus the scale
///      pointers, the B-scale stride and the scale block size.
///
/// @note A default-constructed object holds null pointers and zero sizes/strides. It is only
///       a blank to be filled in; every field must be set before the object is used to build
///       kernel arguments.
struct Block_quant_GemmHostArgs
{
    /// @brief Creates a blank argument set (null pointers, zero sizes and strides).
    CK_TILE_HOST Block_quant_GemmHostArgs() = default;

    /// @brief Creates a fully specified argument set; each parameter is stored in the member
    ///        of the same name (without the trailing underscore).
    CK_TILE_HOST Block_quant_GemmHostArgs(const void* a_ptr_,
                                          const void* b_ptr_,
                                          void* e_ptr_,
                                          const void* a_scale_ptr_,
                                          const void* b_scale_ptr_,
                                          index_t k_batch_,
                                          index_t M_,
                                          index_t N_,
                                          index_t K_,
                                          index_t stride_A_,
                                          index_t stride_B_,
                                          index_t stride_E_,
                                          index_t stride_B_scale_,
                                          index_t ScaleBlockSize_)
        : a_ptr(a_ptr_),
          b_ptr(b_ptr_),
          e_ptr(e_ptr_),
          a_scale_ptr(a_scale_ptr_),
          b_scale_ptr(b_scale_ptr_),
          M(M_),
          N(N_),
          K(K_),
          stride_A(stride_A_),
          stride_B(stride_B_),
          stride_E(stride_E_),
          stride_B_scale(stride_B_scale_),
          ScaleBlockSize(ScaleBlockSize_),
          k_batch(k_batch_)
    {
    }

    const void* a_ptr = nullptr;
    const void* b_ptr = nullptr;
    // e_ptr and c_ptr are two names for the same storage. Only one member of an anonymous
    // union may carry a default initializer, which is why c_ptr has none.
    union
    {
        void* e_ptr = nullptr;
        void* c_ptr;
    };
    const void* a_scale_ptr = nullptr;
    const void* b_scale_ptr = nullptr;
    index_t M               = 0;
    index_t N               = 0;
    index_t K               = 0;
    index_t stride_A        = 0;
    index_t stride_B        = 0;
    // stride_E and stride_C are two names for the same storage (see the pointer union above).
    union
    {
        index_t stride_E = 0;
        index_t stride_C;
    };

    index_t stride_B_scale = 0;
    index_t ScaleBlockSize = 0;
    index_t k_batch        = 0;
};

/// @brief Single-A / single-B front end for the block-quantized GEMM kernel.
///
/// Every operation is forwarded to Block_quant_UniversalGemmKernel; this wrapper only adapts
/// the scalar Block_quant_GemmHostArgs to the array-style host arguments expected by the
/// universal kernel (one A tensor, one B tensor, no D tensors).
///
/// @tparam TilePartitioner_  Tile partitioner type, forwarded to the universal kernel.
/// @tparam GemmPipeline_     GEMM pipeline type; supplies the layouts and the A/B data types.
/// @tparam EpiloguePipeline_ Epilogue pipeline type; supplies the output data type.
template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
struct Block_quant_GemmKernel
{
    /// @brief Inject the UniversalGemmKernel base class to support execution of all necessary
    /// functions.
    using UniversalGemmKernel =
        Block_quant_UniversalGemmKernel<TilePartitioner_, GemmPipeline_, EpiloguePipeline_>;

    using TilePartitioner  = remove_cvref_t<TilePartitioner_>;
    using GemmPipeline     = remove_cvref_t<GemmPipeline_>;
    using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;

    /// @brief Specify the layout configurations for A, B and E
    using ALayout = remove_cvref_t<typename GemmPipeline::ALayout>;
    using BLayout = remove_cvref_t<typename GemmPipeline::BLayout>;
    using ELayout = remove_cvref_t<typename GemmPipeline::CLayout>;

    /// @brief Specify the data type configurations for A, B and E
    using ADataType   = remove_cvref_t<typename GemmPipeline::ADataType>;
    using BInDataType = remove_cvref_t<typename GemmPipeline::BInDataType>;
    using EDataType   = remove_cvref_t<typename EpiloguePipeline::ODataType>;

    /// @brief ALayout and ADataType are expected to be scalars, not a tuple.
    static_assert(
        !is_detected<is_tuple, ALayout>::value && !is_detected<is_tuple, ADataType>::value,
        "ALayout and ADataType must be scalars. Multiple parameters are not currently supported.");

    /// @brief BLayout and BInDataType are expected to be scalars, not a tuple.
    static_assert(
        !is_detected<is_tuple, BLayout>::value && !is_detected<is_tuple, BInDataType>::value,
        "BLayout and BDataType must be scalars. Multiple parameters are not currently supported.");

    /// @brief C/ELayout and C/EDataType are expected to be scalars, not a tuple.
    static_assert(!is_detected<is_tuple, ELayout>::value &&
                      !is_detected<is_tuple, EDataType>::value,
                  "C/ELayout and C/EDataType must be scalars.");

    static constexpr index_t NumATensor = 1;
    static constexpr index_t NumBTensor = 1;
    static constexpr index_t kBlockSize = UniversalGemmKernel::kBlockSize;

    /// @brief Returns the name reported by the universal kernel.
    CK_TILE_HOST static auto GetName() -> const std::string
    {
        return UniversalGemmKernel::GetName();
    }

    /// @brief Returns the grid size computed by the universal kernel for M, N and KBatch.
    CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N, index_t KBatch) -> dim3
    {
        return UniversalGemmKernel::GridSize(M, N, KBatch);
    }

    /// @brief Returns the universal kernel's max-occupancy grid size for stream config @p s.
    CK_TILE_HOST static auto MaxOccupancyGridSize(const stream_config& s) -> dim3
    {
        return UniversalGemmKernel::MaxOccupancyGridSize(s);
    }

    /// @brief Returns the block size used by the universal kernel.
    CK_TILE_HOST static constexpr auto BlockSize() -> dim3
    {
        return UniversalGemmKernel::BlockSize();
    }

    /// @brief Builds the universal kernel's arguments from the scalar host arguments.
    ///
    /// @param hostArgs Problem description (pointers, sizes, strides, scale information).
    /// @return The kernel arguments produced by UniversalGemmKernel::MakeKernelArgs.
    CK_TILE_HOST static constexpr auto MakeKernelArgs(const Block_quant_GemmHostArgs& hostArgs) ->
        typename UniversalGemmKernel::KernelArgs
    {
        /// @brief Universal GEMM requires array objects and corresponding stride information for
        /// matrices A, B.
        return UniversalGemmKernel::MakeKernelArgs(
            Block_quant_UniversalGemmHostArgs<NumATensor, NumBTensor /*NumDTensor = 0 */>(
                {hostArgs.a_ptr},
                {hostArgs.b_ptr},
                {}, // no D tensors
                hostArgs.e_ptr,
                hostArgs.a_scale_ptr,
                hostArgs.b_scale_ptr,
                hostArgs.k_batch,
                hostArgs.M,
                hostArgs.N,
                hostArgs.K,
                {hostArgs.stride_A},
                {hostArgs.stride_B},
                {/*hostArgs.stride_Ds*/},
                hostArgs.stride_E,
                // NOTE(review): stride_A is forwarded a second time here, directly before
                // stride_B_scale. If this slot is the A-scale stride, confirm that reusing
                // stride_A is intended; Block_quant_GemmHostArgs has no A-scale stride field.
                hostArgs.stride_A,
                hostArgs.stride_B_scale,
                hostArgs.ScaleBlockSize));
    }

    /// @brief Returns whether the universal kernel supports the given kernel arguments.
    CK_TILE_HOST static auto
    IsSupportedArgument(const typename UniversalGemmKernel::KernelArgs& kargs) -> bool
    {
        return UniversalGemmKernel::IsSupportedArgument(kargs);
    }

    /// @brief Device entry point: runs the universal kernel on @p kargs.
    CK_TILE_DEVICE auto operator()(typename UniversalGemmKernel::KernelArgs kargs) const -> void
    {
        // Plain call syntax on purpose: the `template` disambiguator must be followed by a
        // template argument list, so `.template operator()(kargs)` is ill-formed and is
        // diagnosed by recent Clang. A direct call resolves to the same overload.
        UniversalGemmKernel{}(kargs);
    }
};
} // namespace ck_tile
Loading
Loading