-
Notifications
You must be signed in to change notification settings - Fork 252
Bf16*fp4 gemm #2801
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: develop
Are you sure you want to change the base?
Bf16*fp4 gemm #2801
Conversation
spolifroni-amd
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Readme looks ok.
ThomasNing
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could we add the gtest for the new developed kernel?
| #include "ck_tile/host.hpp" | ||
| #include "gemm_utils.hpp" | ||
|
|
||
| template <typename GemmConfig, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should also put the mx gemm into the example of blockscale gemm and share the util and example.inc code.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So, for reusing the code, my understanding is that we need to add our own mx_gemm.cpp entry interface within example/38_* instead of defining a new example/45_* like we are doing now?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, we do not need to create a mxfp4_gemm example operator. We should just have a .cpp file under example/38_*/mx_gemm.cpp. The datatype should also be a configuration to that example..
| make_tuple(kargs.stride_B, 1), | ||
| number<GemmPipeline::GetVectorSizeB()>{}, | ||
| number<1>{}); | ||
| if constexpr(std::is_same_v<BDataType, pk_fp4_raw_t>) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why in fp4 we will transposed in the data type comparing to other data types?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In the original gemm_quant kernel, we noticed that for the tensor_view definitions of B and Bq, B's shape is taken as (N, K), while Bq is (K, N). We don't quite understand why Bq needs to be transposed here, as in our implementation both B and Bq are defined as (N, K).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I see. After reviewing the solution, we could reunify into one version without the transpose. I will create a PR to unify that soon, so we do not need that unnecessary transpose branch.
|
Hi @eliotwang, please resolve conflicts and sync branch to latest develop in order to proceed! Thanks! |
|
@eliotwang LGTM overall. Please add the unit test. |
We have added unit tests for bf16_mxfp4_gemm in the test/ck_tile/gemm_block_scale/ directory. Please help review it. |
| { | ||
| using ComputeType = | ||
| std::conditional_t<sizeof(ADataType_) < sizeof(BDataType_), ADataType_, BDataType_>; | ||
| // using ComputeType = |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code Leftovers?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Has been updated. Please help review it.
|
@eliotwang LGTM, we could do the last iteration of the merging after the PR #3245 merged to the develop. Thanks! cc. @CongMa13 |
Proposed changes
Added an example of bf16*fp4 gemm, where fp4 and fp4_scale are in uint8 data format. In the pipeline, matrix B(fp4) will be dequantized to bf16 before performing multiplication operations.
Checklist
Please put an
xinto the boxes that apply. You can also fill these out after creating the PR. If you're not sure, please don't hesitate to ask.clang-formaton all changed filesDiscussion
If this is a relatively large or complex change, feel free to start a discussion by explaining why you chose the solution you did and what alternatives you considered