[Common] Persistent NVFP4 cast + transpose kernel #2412
base: main
Conversation
Signed-off-by: Oleg Goncharov <[email protected]>
for more information, see https://pre-commit.ci
Greptile Overview

Greptile Summary: Introduces a specialized persistent CUDA kernel for NVFP4 quantization + transpose, optimized for the Blackwell architecture (sm_100+). The implementation leverages Blackwell-specific features, including TMA (Tensor Memory Accelerator), asynchronous barriers, and cluster launch control, to achieve performance improvements for BF16 inputs with 128×128 tile sizes.

The implementation has proper multi-layer protection: runtime device-capability checks in the dispatch layer, compile-time architecture guards in the kernel, and constexpr guards in the PTX wrappers.

Confidence Score: 5/5
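As a rough illustration of the first two protection layers, a minimal sketch (identifiers are illustrative, not the PR's actual names; the third layer, constexpr guards in PTX wrappers, appears in the review snippets further down):

```cuda
#include <cuda_runtime.h>

// Layer 1: runtime capability check used at dispatch time.
// Illustrative only; the codebase's is_supported_by_CC_100() may differ.
inline bool device_is_sm100_or_newer() {
  int device = 0, major = 0;
  cudaGetDevice(&device);
  cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device);
  return major >= 10;  // Blackwell is compute capability 10.x
}

// Layer 2: compile-time architecture guard inside the kernel body, so
// Blackwell-only instructions are never emitted for older targets.
__global__ void persistent_nvfp4_kernel_sketch(/* ... */) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 1000
  // TMA loads, mbarrier synchronization, cluster launch control, ...
#else
  __trap();  // safety net; unreachable when dispatch is correct
#endif
}
```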
Sequence Diagram:

```mermaid
sequenceDiagram
participant User
participant DispatchFwd as dispatch/quantize.cuh<br/>(quantize_fwd_helper)
participant DispatchBwd as dispatch/quantize.cuh<br/>(quantize_bwd_helper)
participant Dispatcher as nvfp4/quantize_transpose_nvfp4.cuh<br/>(quantize_transpose)
participant PersistentKernel as specialized/quantize_transpose_nvfp4_persistent_1D.cuh
participant OldKernel as nvfp4/quantize_transpose_nvfp4.cuh<br/>(existing kernel)
User->>DispatchFwd: Forward pass quantization
DispatchFwd->>DispatchFwd: Check is_supported_by_CC_100()
alt sm_100+ AND dtype==BF16 AND dims % 32 == 0
DispatchFwd->>Dispatcher: quantize_transpose<use_2d_quantization>()
Dispatcher->>Dispatcher: Check conditions:<br/>!use_2d_quantization AND<br/>dtype==BF16 AND<br/>return_transpose
alt Conditions met for persistent kernel
Dispatcher->>PersistentKernel: quantize_transpose_persistent_1D()
Note over PersistentKernel: __CUDA_ARCH__ >= 1000 guard<br/>Uses TMA, mbarriers,<br/>cluster launch control
PersistentKernel-->>Dispatcher: Optimized BF16→NVFP4 + transpose
else Use general kernel
Dispatcher->>OldKernel: Launch general quantize kernel
OldKernel-->>Dispatcher: Result
end
Dispatcher-->>DispatchFwd: Result
else Use fallback
DispatchFwd->>OldKernel: quantize_transpose_vector_blockwise_fp4()
OldKernel-->>DispatchFwd: Result
end
DispatchFwd-->>User: Quantized output
User->>DispatchBwd: Backward pass quantization
Note over DispatchBwd: Same flow as forward pass
```
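In code form, the selection logic in the diagram reduces to a predicate along these lines (a sketch built from the names used elsewhere in this thread; consult dispatch/quantize.cuh for the real logic):

```cuda
// Sketch of the persistent-kernel eligibility check described above.
const bool use_persistent_kernel =
    is_supported_by_CC_100() &&              // Blackwell (sm_100+) only
    (dtype == DType::kBFloat16) &&           // BF16 inputs only
    (rows % 32 == 0) && (cols % 32 == 0) &&  // dimensions divisible by 32
    !use_2d_quantization &&                  // 1D quantization path only
    return_transpose;                        // transposed output requested
```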
5 files reviewed, 1 comment
Comment on transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh (outdated, resolved)
Signed-off-by: Oleg Goncharov <[email protected]>
5 files reviewed, 1 comment
Comment on transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh (outdated, resolved)
Signed-off-by: Oleg Goncharov <[email protected]>
5 files reviewed, no comments
Signed-off-by: Oleg Goncharov <[email protected]>
5 files reviewed, no comments
Signed-off-by: Oleg Goncharov <[email protected]>
Force-pushed from 445c870 to a7a0652
for more information, see https://pre-commit.ci
Additional Comments (1)
- transformer_engine/common/cast/dispatch/quantize.cuh, line 250: logic: missing `is_supported_by_CC_100()` check. The forward pass (line 104) has this check, but the backward pass does not. Without it, the persistent kernel can be dispatched on non-Blackwell hardware and hit `NVTE_DEVICE_ERROR` at runtime.
5 files reviewed, 1 comment
Signed-off-by: Oleg Goncharov <[email protected]>
5 files reviewed, no comments
/te-ci
```diff
  bool use_optimized_kernel = (dtype == DType::kBFloat16) && (rows % 32 == 0) &&
-                             (cols % 32 == 0) && output_tensor->has_data();
+                             (cols % 32 == 0) && output_tensor->has_data() &&
+                             is_supported_by_CC_100();
```
Why do we need that additional check? NVFP4 is only supported on SM100+ anyway.
| : "=r"(reinterpret_cast<uint32_t &>(out)) | ||
| : "f"(in)); | ||
| #endif | ||
| constexpr bool has_redux_f32 = ARCH_HAS_REDUX_F32; |
I'm changing this in #2482.
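For context on the pattern under discussion, a minimal sketch of a constexpr-guarded PTX wrapper with a portable fallback (the instruction, constraints, and fallback are assumptions inferred from the snippet above, not the PR's exact code):

```cuda
// ARCH_HAS_REDUX_F32 is assumed to be a compile-time predicate for the
// f32 redux.sync instruction available on newer architectures.
__device__ __forceinline__ float warp_max_f32(float in) {
  float out = in;
  constexpr bool has_redux_f32 = ARCH_HAS_REDUX_F32;
  if constexpr (has_redux_f32) {
    asm volatile("redux.sync.max.f32 %0, %1, 0xffffffff;"
                 : "=f"(out)
                 : "f"(in));
  } else {
    // Fallback: butterfly shuffle reduction across the full warp.
    for (int offset = 16; offset > 0; offset >>= 1) {
      out = fmaxf(out, __shfl_xor_sync(0xffffffffU, out, offset));
    }
  }
  return out;
}
```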
```cuda
}

// Loads 8x BF16 values from shared memory state space
__device__ __forceinline__ void ld_shared_b128(uint64_t &elts03, uint64_t &elts47,
```
Can we have a consistent API between those functions? Some return a value, others write their outputs through reference arguments.
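One way to unify the API, sketched under the assumption that the load compiles to ld.shared.v2.b64 (the struct and parameter names here are hypothetical): return an aggregate instead of writing through reference parameters.

```cuda
#include <cstdint>

// Sketch: value-returning variant of the 128-bit shared-memory load.
struct B128 {
  uint64_t elts03;  // BF16 elements 0..3
  uint64_t elts47;  // BF16 elements 4..7
};

__device__ __forceinline__ B128 ld_shared_b128(uint32_t smem_addr) {
  B128 r;
  asm volatile("ld.shared.v2.b64 {%0, %1}, [%2];"
               : "=l"(r.elts03), "=l"(r.elts47)
               : "r"(smem_addr));
  return r;
}
```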
Description
This PR introduces a specialized persistent CUDA kernel optimized for NVFP4 quantization + transpose of BF16 inputs on the Blackwell architecture (sm100f family). The implementation achieves performance improvements by leveraging architecture-specific features (TMA, asynchronous barriers, and cluster launch control).
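To make that feature list concrete, here is a heavily simplified sketch of the asynchronous tile-staging handshake such a kernel builds on, written against the libcu++ barrier API rather than the raw PTX the kernel itself uses (all names and sizes are assumptions):

```cuda
#include <cuda/barrier>
#include <cuda_bf16.h>

// Simplified sketch: asynchronously stage one BF16 tile into shared memory,
// then quantize it. The real kernel drives TMA from PTX and loops
// persistently over many tiles; this only shows the copy/barrier handshake.
__global__ void stage_and_quantize_sketch(const __nv_bfloat16 *gmem,
                                          size_t tile_elems) {
  extern __shared__ __nv_bfloat16 smem[];
  __shared__ cuda::barrier<cuda::thread_scope_block> bar;
  if (threadIdx.x == 0) {
    init(&bar, blockDim.x);  // one arrival expected per thread in the block
  }
  __syncthreads();

  if (threadIdx.x == 0) {
    // Bind completion of the async copy to the barrier's current phase.
    cuda::memcpy_async(smem, gmem, sizeof(__nv_bfloat16) * tile_elems, bar);
  }
  bar.arrive_and_wait();  // tile is resident in shared memory past this point

  // ... cast the tile to NVFP4 and write both regular and transposed outputs
}
```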
- RN: round-to-nearest mode
- SR: stochastic rounding
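For readers unfamiliar with the distinction, a conceptual scalar sketch of the two modes (the kernel's packed NVFP4 implementation differs; `step` and `rand01` are hypothetical parameters):

```cuda
#include <cmath>

// RN: deterministically round to the nearest grid point.
__device__ float quantize_rn(float x, float step) {
  return rintf(x / step) * step;
}

// SR: round up with probability equal to the fractional remainder, so the
// quantized value equals x in expectation (unbiased).
__device__ float quantize_sr(float x, float step, float rand01) {
  float scaled = x / step;
  float lo = floorf(scaled);
  float frac = scaled - lo;
  return ((rand01 < frac) ? (lo + 1.0f) : lo) * step;
}
```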
Type of change
Changes
Checklist: