Skip to content

Conversation

@eliotwang
Copy link

Proposed changes

Added an example of bf16*fp4 gemm, where fp4 and fp4_scale are in uint8 data format. In the pipeline, matrix B(fp4) will be dequantized to bf16 before performing multiplication operations.

Checklist

Please put an x into the boxes that apply. You can also fill these out after creating the PR. If you're not sure, please don't hesitate to ask.

  • I have added tests relevant to the introduced functionality, and the unit tests are passing locally
  • I have added the test to REGRESSION_TESTS list defined at the top of CMakeLists.txt in tests/CMakeLists.txt, IF the test takes more than 30 seconds to run.
  • I have added inline documentation which enables the maintainers with understanding the motivation
  • I have removed the stale documentation which is no longer relevant after this pull request
  • (If this change is user-facing) I have added release notes which provide the end users with a brief summary of the improvement from this pull request
  • I have run clang-format on all changed files
  • Any dependent changes have been merged

Discussion

If this is a relatively large or complex change, feel free to start a discussion by explaining why you chose the solution you did and what alternatives you considered

spolifroni-amd
spolifroni-amd previously approved these changes Sep 8, 2025
Copy link
Contributor

@spolifroni-amd spolifroni-amd left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Readme looks ok.

@eliotwang eliotwang changed the title Bf16 fp4 gemm Bf16*fp4 gemm Sep 9, 2025
Copy link
Contributor

@ThomasNing ThomasNing left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we add the gtest for the new developed kernel?

#include "ck_tile/host.hpp"
#include "gemm_utils.hpp"

template <typename GemmConfig,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should also put the mx gemm into the example of blockscale gemm and share the util and example.inc code.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So, for reusing the code, my understanding is that we need to add our own mx_gemm.cpp entry interface within example/38_* instead of defining a new example/45_* like we are doing now?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, we do not need to create a mxfp4_gemm example operator. We should just have a .cpp file under example/38_*/mx_gemm.cpp. The datatype should also be a configuration to that example..

make_tuple(kargs.stride_B, 1),
number<GemmPipeline::GetVectorSizeB()>{},
number<1>{});
if constexpr(std::is_same_v<BDataType, pk_fp4_raw_t>)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why in fp4 we will transposed in the data type comparing to other data types?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the original gemm_quant kernel, we noticed that for the tensor_view definitions of B and Bq, B's shape is taken as (N, K), while Bq is (K, N). We don't quite understand why Bq needs to be transposed here, as in our implementation both B and Bq are defined as (N, K).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see. After reviewing the solution, we could reunify into one version without the transpose. I will create a PR to unify that soon, so we do not need that unnecessary transpose branch.

@illsilin
Copy link
Collaborator

Hi @eliotwang, please resolve conflicts and sync branch to latest develop in order to proceed! Thanks!

@ThomasNing
Copy link
Contributor

@eliotwang LGTM overall. Please add the unit test.

@eliotwang eliotwang closed this Nov 18, 2025
@eliotwang eliotwang reopened this Nov 18, 2025
@eliotwang eliotwang closed this Nov 19, 2025
@eliotwang eliotwang reopened this Nov 19, 2025
@eliotwang
Copy link
Author

@eliotwang LGTM overall. Please add the unit test.

We have added unit tests for bf16_mxfp4_gemm in the test/ck_tile/gemm_block_scale/ directory. Please help review it.

{
using ComputeType =
std::conditional_t<sizeof(ADataType_) < sizeof(BDataType_), ADataType_, BDataType_>;
// using ComputeType =
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Leftovers?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Has been updated. Please help review it.

@ThomasNing
Copy link
Contributor

@eliotwang LGTM, we could do the last iteration of the merging after the PR #3245 merged to the develop. Thanks!

cc. @CongMa13

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

7 participants