
Conversation

@Oleg-Goncharov (Collaborator) commented Nov 21, 2025

Description

This PR introduces a specialized persistent CUDA kernel optimized for NVFP4 quantization + transpose of BF16 inputs on the Blackwell architecture (sm100f family). The implementation achieves its performance improvements by leveraging architecture-specific features such as the Tensor Memory Accelerator (TMA), asynchronous mbarriers, and cluster launch control.

NVFP4 kernel performance

(Benchmark chart; legend: RN = round-to-nearest rounding mode, SR = stochastic rounding. The two modes are sketched below.)
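For context on the two rounding modes, here is a schematic sketch (not the kernel's code; the function names, and the abstract lo/hi neighbor representation, are illustrative assumptions): round-to-nearest deterministically picks the closest representable value, while stochastic rounding picks between the two neighbors with probability proportional to proximity, making it unbiased in expectation.

#include <cstdint>

// Schematic sketch of the two rounding modes on an abstract value grid
// (not the kernel's FP4 tables). lo/hi are the two nearest representable
// neighbors of x (lo < hi), and rng is a uniform random value in [0, 1).
__host__ __device__ inline float round_rn(float x, float lo, float hi) {
  // Deterministic: pick the nearest neighbor (ties go to hi in this sketch).
  return (x - lo < hi - x) ? lo : hi;
}

__host__ __device__ inline float round_sr(float x, float lo, float hi, float rng) {
  // Stochastic: the probability of rounding up equals the fractional
  // distance to hi, so the expected value of the result equals x.
  float frac = (x - lo) / (hi - lo);
  return (rng < frac) ? hi : lo;
}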

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Added a specialized persistent NVFP4 cast + transpose kernel
  • Added the dispatch logic to use it when the required conditions are met

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@greptile-apps (Contributor) bot commented Nov 21, 2025

Greptile Overview

Greptile Summary

Introduces a specialized persistent CUDA kernel for NVFP4 quantization + transpose operations optimized for Blackwell architecture (sm_100+). The implementation leverages Blackwell-specific features including TMA (Tensor Memory Accelerator), asynchronous barriers, and cluster launch control to achieve performance improvements for BF16 inputs with 128×128 tile sizes.

Key Changes:

  • Added is_supported_by_CC_100() runtime checks in dispatch layer to prevent execution on incompatible GPUs
  • New persistent kernel uses compile-time guards (__CUDA_ARCH__ >= 1000) and architecture-specific PTX wrappers with ARCH_BLACKWELL_FAMILY checks
  • All Blackwell-specific PTX instructions properly guarded with if constexpr checks or #if preprocessor directives (see the sketch after this list)
  • Memory alignment helper added for TMA's 16-byte alignment requirements
  • The PTX change from __CUDA_ARCH_HAS_FEATURE__ macros to ARCH_HAS_REDUX_F32 improves architecture detection consistency
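As a rough illustration of the guard pattern named in the bullets above, here is a minimal sketch. ARCH_BLACKWELL_FAMILY is mentioned in the PR, but its exact definition and the wrapper name mbarrier_arrive_sketch are hypothetical stand-ins for the actual code in ptx.cuh:

#include <cstdint>

// Hypothetical sketch of the compile-time guard pattern. The branch body is
// only code-generated when compiling for the Blackwell family (sm_100+);
// on older targets the wrapper collapses to a no-op.
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
#define ARCH_BLACKWELL_FAMILY 1
#else
#define ARCH_BLACKWELL_FAMILY 0
#endif

__device__ __forceinline__ void mbarrier_arrive_sketch(uint64_t *bar) {
  if constexpr (ARCH_BLACKWELL_FAMILY) {
    // Convert the generic pointer to a .shared state-space address.
    uint32_t addr = static_cast<uint32_t>(__cvta_generic_to_shared(bar));
    uint64_t state;
    asm volatile("mbarrier.arrive.shared::cta.b64 %0, [%1];"
                 : "=l"(state)
                 : "r"(addr)
                 : "memory");
    (void)state;  // arrival token, unused in this sketch
  }
}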

The implementation has proper multi-layer protection: runtime device capability checks in the dispatch layer, compile-time architecture guards in the kernel, and constexpr guards in PTX wrappers.
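The runtime layer of that protection can be pictured as follows. This is a sketch under assumptions: the actual is_supported_by_CC_100() implementation is not quoted in this thread, so the body below is illustrative:

#include <cuda_runtime.h>

// Sketch of a runtime capability check: true only when the current device's
// compute capability major version is at least 10 (Blackwell, sm_100+).
inline bool is_supported_by_CC_100_sketch() {
  int device = 0;
  cudaGetDevice(&device);
  int cc_major = 0;
  cudaDeviceGetAttribute(&cc_major, cudaDevAttrComputeCapabilityMajor, device);
  return cc_major >= 10;
}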

Confidence Score: 5/5

  • Safe to merge - well-architected performance optimization with comprehensive architecture guards
  • The PR demonstrates excellent engineering with multi-layer architecture protection (runtime checks in dispatch, compile-time guards in kernel, constexpr checks in PTX), proper use of Blackwell-specific features, and comprehensive error handling. All architecture-specific code paths have appropriate fallbacks.
  • No files require special attention - all changes follow established patterns with proper guards

Important Files Changed

File Analysis

  • transformer_engine/common/cast/dispatch/quantize.cuh (score 5/5): Added is_supported_by_CC_100() check to ensure the optimized kernel only runs on Blackwell+ GPUs (sm_100+)
  • transformer_engine/common/cast/core/common.cuh (score 5/5): Added a helper function to align shared memory pointers for TMA requirements (16-byte alignment; sketched below this list)
  • transformer_engine/common/util/ptx.cuh (score 5/5): Added Blackwell-specific PTX wrappers for mbarrier ops, cluster launch control, BF16→FP4 conversion, and shared memory ops with proper architecture guards
  • transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh (score 5/5): Added dispatch logic to route BF16 1D quantization with transpose to the specialized persistent kernel; already protected by the upstream is_supported_by_CC_100() check
  • transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_persistent_1D.cuh (score 5/5): New persistent kernel for NVFP4 quantization + transpose optimized for Blackwell with TMA, async copies, mbarriers, and cluster launch control; protected by a __CUDA_ARCH__ >= 1000 guard
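For the common.cuh entry above, the alignment helper can be pictured roughly as follows. The name align_smem_16B is hypothetical; the PR's actual helper is not quoted in this thread:

#include <cstdint>

// Hypothetical sketch: round a shared-memory pointer up to the 16-byte
// boundary that TMA requires for shared-memory addresses.
__device__ __forceinline__ char *align_smem_16B(char *ptr) {
  auto p = reinterpret_cast<uintptr_t>(ptr);
  return reinterpret_cast<char *>((p + 15) & ~static_cast<uintptr_t>(15));
}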

Sequence Diagram

sequenceDiagram
    participant User
    participant DispatchFwd as dispatch/quantize.cuh<br/>(quantize_fwd_helper)
    participant DispatchBwd as dispatch/quantize.cuh<br/>(quantize_bwd_helper)
    participant Dispatcher as nvfp4/quantize_transpose_nvfp4.cuh<br/>(quantize_transpose)
    participant PersistentKernel as specialized/quantize_transpose_nvfp4_persistent_1D.cuh
    participant OldKernel as nvfp4/quantize_transpose_nvfp4.cuh<br/>(existing kernel)
    
    User->>DispatchFwd: Forward pass quantization
    DispatchFwd->>DispatchFwd: Check is_supported_by_CC_100()
    alt sm_100+ AND dtype==BF16 AND dims % 32 == 0
        DispatchFwd->>Dispatcher: quantize_transpose<use_2d_quantization>()
        Dispatcher->>Dispatcher: Check conditions:<br/>!use_2d_quantization AND<br/>dtype==BF16 AND<br/>return_transpose
        alt Conditions met for persistent kernel
            Dispatcher->>PersistentKernel: quantize_transpose_persistent_1D()
            Note over PersistentKernel: __CUDA_ARCH__ >= 1000 guard<br/>Uses TMA, mbarriers,<br/>cluster launch control
            PersistentKernel-->>Dispatcher: Optimized BF16→NVFP4 + transpose
        else Use general kernel
            Dispatcher->>OldKernel: Launch general quantize kernel
            OldKernel-->>Dispatcher: Result
        end
        Dispatcher-->>DispatchFwd: Result
    else Use fallback
        DispatchFwd->>OldKernel: quantize_transpose_vector_blockwise_fp4()
        OldKernel-->>DispatchFwd: Result
    end
    DispatchFwd-->>User: Quantized output
    
    User->>DispatchBwd: Backward pass quantization
    Note over DispatchBwd: Same flow as forward pass

@greptile-apps bot left a comment: 5 files reviewed, 1 comment

Signed-off-by: Oleg Goncharov <[email protected]>
@greptile-apps bot left a comment: 5 files reviewed, 1 comment

Signed-off-by: Oleg Goncharov <[email protected]>
@greptile-apps bot left a comment: 5 files reviewed, no comments

@Oleg-Goncharov changed the title from "[Common] Persistent NVFP4 kernel" to "[Common] Persistent NVFP4 cast + transpose kernel" on Nov 22, 2025
Signed-off-by: Oleg Goncharov <[email protected]>
@greptile-apps bot left a comment: 5 files reviewed, no comments

Signed-off-by: Oleg Goncharov <[email protected]>
@Oleg-Goncharov force-pushed the pr_nvfp4_persistent_kernel branch from 445c870 to a7a0652 on November 22, 2025 01:34
@greptile-apps bot left a comment

Additional Comments (1)

  1. transformer_engine/common/cast/dispatch/quantize.cuh, line 250

    Logic: missing is_supported_by_CC_100() check. The forward pass (line 104) has this check, but the backward pass does not. Without it, the persistent kernel can be dispatched on non-Blackwell hardware and hit NVTE_DEVICE_ERROR at runtime. (A sketch of the fix follows.)
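A minimal sketch of the suggested fix, mirroring the forward-path condition quoted later in this thread (variable and function names are those from that snippet):

// Sketch of the suggested fix: apply the same guard the forward path uses,
// so the persistent kernel is never selected on pre-Blackwell devices.
bool use_optimized_kernel = (dtype == DType::kBFloat16) && (rows % 32 == 0) &&
                            (cols % 32 == 0) && output_tensor->has_data() &&
                            is_supported_by_CC_100();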

5 files reviewed, 1 comment


Signed-off-by: Oleg Goncharov <[email protected]>
@greptile-apps bot left a comment: 5 files reviewed, no comments

@Oleg-Goncharov (Collaborator, Author) commented:

/te-ci

@ptrendx added the 2.11.0 label on Nov 25, 2025
@Oleg-Goncharov added the performance and enhancement labels on Dec 4, 2025
  bool use_optimized_kernel = (dtype == DType::kBFloat16) && (rows % 32 == 0) &&
-                             (cols % 32 == 0) && output_tensor->has_data();
+                             (cols % 32 == 0) && output_tensor->has_data() &&
+                             is_supported_by_CC_100();
A reviewer (Member) commented:

Why do we need that additional check? NVFP4 is only supported on SM100+ anyway.

: "=r"(reinterpret_cast<uint32_t &>(out))
: "f"(in));
#endif
constexpr bool has_redux_f32 = ARCH_HAS_REDUX_F32;
A reviewer (Member) commented:

I'm changing this in #2482.

}

// Loads 8x BF16 values from shared memory state space
__device__ __forceinline__ void ld_shared_b128(uint64_t &elts03, uint64_t &elts47,
A reviewer (Member) commented:

Can we have a consistent API between those functions? Some return value, some output things in the arguments.
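One way to unify the API would be to return the loaded values in a small struct rather than writing through reference parameters. This is a hypothetical sketch, not the PR's code; the struct name B128 and the variant name ld_shared_b128_v2 are illustrative:

#include <cstdint>

// Two 64-bit registers holding 8 packed BF16 values (128 bits total).
struct B128 {
  uint64_t elts03;  // BF16 elements 0..3
  uint64_t elts47;  // BF16 elements 4..7
};

// Hypothetical return-by-value variant of ld_shared_b128: loads 128 bits
// from a .shared state-space address (e.g. obtained via
// __cvta_generic_to_shared) with a single vectorized load.
__device__ __forceinline__ B128 ld_shared_b128_v2(uint32_t smem_addr) {
  B128 out;
  asm volatile("ld.shared.v2.b64 {%0, %1}, [%2];"
               : "=l"(out.elts03), "=l"(out.elts47)
               : "r"(smem_addr));
  return out;
}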


Labels

2.11.0, enhancement (New feature or request), performance (Performance issues)
