[Draft][PyTorch][MOE] Support NVFP4 Grouped Linear #2215
base: main
Conversation
/te-ci pytorch L1
Force-pushed from 3c9e5ea to 4ac9df6 (Compare)
/te-ci pytorch L1
…ck the vec_load_size to 1 to unblock
// The current unit test won't capture this issue, but in E2E,
// using vec_load_size other than 1 will lead to a misaligned
// address error in MOE training.
vec_load_size = all_nvfp4 ? 1 : std::min(vec_load_size, vec_load_size_i);
@yaox12 do you have any idea why? Note that NVFP4 is TN-only, i.e. a transpose happens, unlike MXFP8. Also, the error only happens for WGRAD. That leads me to believe that padding to 32 for NVFP4 in m_splits may not be enough if we want vec_load_size greater than 1.
I then increased the padding from 32 to 64 and 128, and found that only 128 works. However, I haven't figured out why the vec_load_size calculation logic goes wrong, so I am overriding it to 1 as a hack.
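For reference, a minimal standalone sketch (not the actual TE kernel logic; `max_vec_bytes`, `pad_rows`, and the sizes below are hypothetical) of the kind of alignment reasoning involved: the usable vector load width is bounded by the largest power of two dividing the byte offset at which each group starts, and the padding applied to m_splits changes those offsets.

```cpp
#include <cstdint>
#include <cstdio>
#include <initializer_list>

// Largest power-of-two load width (in bytes, capped at 16) that divides the
// byte offset of a group's first element. NVFP4 packs two elements per byte.
int max_vec_bytes(std::int64_t elem_offset) {
  std::int64_t byte_offset = elem_offset / 2;
  int vec = 16;
  while (vec > 1 && byte_offset % vec != 0) vec /= 2;
  return vec;
}

// Round a per-expert row count up to a multiple of `align` rows, as the
// comment above experiments with for 32, 64, and 128.
std::int64_t pad_rows(std::int64_t rows, std::int64_t align) {
  return (rows + align - 1) / align * align;
}

int main() {
  const std::int64_t cols = 2880;  // hypothetical hidden size
  for (std::int64_t align : {32, 64, 128}) {
    const std::int64_t offset = pad_rows(100, align) * cols;  // second group's start
    std::printf("pad to multiple of %3lld rows -> group offset %lld elems, vec load %d bytes\n",
                static_cast<long long>(align), static_cast<long long>(offset),
                max_vec_bytes(offset));
  }
  return 0;
}
```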
// Check for size (not just pointer) for 0-dim or no token cases.
bool has_data() const noexcept { return data.dptr != nullptr || data.shape.size() != 0; }
Mathematically, a 0-D tensor is a scalar with 1 entry.
Suggested change:
- // Check for size (not just pointer) for 0-dim or no token cases.
- bool has_data() const noexcept { return data.dptr != nullptr || data.shape.size() != 0; }
+ bool has_data() const noexcept { return data.dptr != nullptr; }
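As a standalone illustration of that point (toy code, not TE's tensor class): the element count of a tensor is the product of its dimensions, so an empty shape (0-D) denotes one entry, while a zero-token shape such as {0, hidden} is the case that genuinely has no data.

```cpp
#include <cstddef>
#include <cstdio>
#include <functional>
#include <numeric>
#include <vector>

// Number of entries is the product of the dims; the empty product is 1,
// so a 0-D (scalar) tensor holds exactly one entry.
std::size_t numel(const std::vector<std::size_t>& shape) {
  return std::accumulate(shape.begin(), shape.end(), std::size_t{1},
                         std::multiplies<std::size_t>());
}

int main() {
  std::printf("0-D tensor entries: %zu\n", numel({}));                // 1
  std::printf("zero-token tensor entries: %zu\n", numel({0, 4096}));  // 0
  return 0;
}
```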
  TensorWrapper fake_te_output(
-     nullptr, te_input.shape(),
+     amax_ptr, te_input.shape(),
      DType::kFloat8E4M3,  // It doesn't matter because we only compute amax.
-     amax.data_ptr<float>());
+     amax_ptr, nullptr, amax_ptr);
This horrifying hack is needed because the tensor checking functions assume that the output tensor requires data:
NVTE_CHECK(t.has_data() || t.has_columnwise_data(), "Output ", name, " is not allocated!");
The right answer is to modify the API for nvte_compute_amax
so that the output tensor is an FP32 tensor with one entry. We might use that amax value later to compute an FP8 tensor, an NVFP4 tensor, whatever, but that is completely irrelevant.
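To make that concrete, here is a minimal CPU reference of the semantics being asked for (a sketch only; `compute_amax_reference` is hypothetical and this is not the existing nvte_compute_amax signature): the output is nothing more than a one-entry FP32 value holding max(|x|), and which format later consumes it is irrelevant.

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>

// Hypothetical reference of the proposed semantics: write max(|x|) into a
// single float. Whether the value later drives FP8, NVFP4, or any other
// quantization is irrelevant to the amax computation itself.
void compute_amax_reference(const float* input, std::size_t n, float* amax_out) {
  float amax = 0.0f;
  for (std::size_t i = 0; i < n; ++i) amax = std::max(amax, std::fabs(input[i]));
  *amax_out = amax;
}

int main() {
  const float x[] = {0.5f, -3.25f, 1.0f};
  float amax = 0.0f;
  compute_amax_reference(x, 3, &amax);  // amax == 3.25f
  return 0;
}
```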
/te-ci pytorch L1
Description
NVFP4 Grouped Linear support.
Fixes # (issue)
Type of change
Unit test
Checklist: