Conversation

cyanguwa (Collaborator) commented Dec 4, 2025

Description

This PR continues the work in #2195, extending support for max_logit (used in MuonClip) to the THD format, for both the non-CP and CP cases (cp_comm_type = {'p2p', 'a2a', 'all_gather', 'a2a_p2p'}).
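A minimal usage sketch of the feature, assuming the `return_max_logit` flag introduced in #2195 and TE's standard THD calling convention; the exact argument names and returned tuple are assumptions, not verified against this PR:

```python
import torch
import transformer_engine.pytorch as te

# Two packed sequences of lengths 5 and 3 -> 8 total tokens in THD layout.
cu_seqlens = torch.tensor([0, 5, 8], dtype=torch.int32, device="cuda")
t, h, d = 8, 16, 64
q, k, v = (torch.randn(t, h, d, dtype=torch.bfloat16, device="cuda")
           for _ in range(3))

attn = te.DotProductAttention(num_attention_heads=h, kv_channels=d,
                              qkv_format="thd")
# return_max_logit=True (from #2195) additionally returns per-head max
# logits, which MuonClip uses to rescale the query/key projections.
out, max_logit = attn(q, k, v,
                      cu_seqlens_q=cu_seqlens, cu_seqlens_kv=cu_seqlens,
                      max_seqlen_q=5, max_seqlen_kv=5,
                      return_max_logit=True)
```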

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Updated cudnn-frontend to enable THD support for max_logit
  • Changed the shape of Stats (and subsequently Max and Sum_Exp) from [max_tokens_q, h, 1] to [num_tokens_q, h, 1] (see the sketch below)
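
To make the shape change concrete, a small sketch of the two sizes for a ragged THD batch; computing num_tokens_q from cu_seqlens_q is an assumption based on standard THD conventions, not code from this PR:

```python
import torch

# Ragged THD batch: 3 sequences of lengths 5, 2, 7.
cu_seqlens_q = torch.tensor([0, 5, 7, 14], dtype=torch.int32)
batch_size, max_seqlen_q, h = 3, 7, 16

max_tokens_q = batch_size * max_seqlen_q  # padded upper bound: 21
num_tokens_q = int(cu_seqlens_q[-1])      # actual packed token count: 14

# Stats (and Max, Sum_Exp) are now sized by the real token count:
stats_shape = (num_tokens_q, h, 1)        # was (max_tokens_q, h, 1)
```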

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

cyanguwa added the 2.11.0 label Dec 5, 2025
greptile-apps bot (Contributor) commented Dec 5, 2025

Greptile Overview

Greptile Summary

This PR extends THD (total tokens × heads × head dim) format support for the max_logit feature used by the MuonClip optimizer. The changes build on #2195 to enable the THD format across all context-parallelism types. The core modifications: the dual Max/Sum_Exp tensors are unified into a single Stats tensor, tensor shapes change from max_tokens_q to num_tokens_q for proper ragged-tensor handling, and the backend restrictions that previously disabled FusedAttention and UnfusedDotProductAttention for THD format with max_logit are removed. The changes span the Python interface layer, backend-selection logic, and CUDA kernel implementations, yielding a simpler, unified path for statistics generation while maintaining backward compatibility.
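
As a hedged illustration of the unified-Stats extraction this summary (and the diagram below) describe; the reduction shown is an assumption inferred from the shapes quoted in this PR, not TE's actual internals:

```python
import torch

num_tokens_q, h = 14, 16
# Unified per-token, per-head statistics produced by the fused kernel,
# shaped [num_tokens_q, h, 1] as described in this PR.
stats = torch.randn(num_tokens_q, h, 1)

# Per-head max_logit for MuonClip: reduce over the packed token
# dimension, [num_tokens_q, h, 1] -> [h].
max_logit = stats.amax(dim=0).squeeze(-1)
assert max_logit.shape == (h,)
```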

Important Files Changed

| Filename | Score | Overview |
|----------|-------|----------|
| transformer_engine/pytorch/cpp_extensions/fused_attn.py | 5/5 | Updates the Python interface to handle the new CUDA-kernel tensor ordering, extracting max_logit from the third output tensor instead of the second |
| transformer_engine/pytorch/attention/dot_product_attention/utils.py | 5/5 | Removes the backend restrictions that disabled FusedAttention and UnfusedDotProductAttention for THD format with max_logit |
| transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu | 5/5 | Implements the CUDA-kernel changes for THD support, updates tensor shapes from max_tokens_q to num_tokens_q, and simplifies stats generation |

Confidence score: 5/5

  • This PR is safe to merge with minimal risk as it extends existing functionality without breaking changes
  • Score reflects well-structured changes that properly handle tensor format differences and maintain API compatibility
  • No files require special attention as all changes are focused, well-commented, and follow established patterns in the codebase

Sequence Diagram

```mermaid
sequenceDiagram
    participant User
    participant DotProductAttention as "DotProductAttention"
    participant FusedAttention as "FusedAttention Backend"
    participant cuDNN as "cuDNN Frontend"
    participant GPU as "GPU Kernels"

    User->>DotProductAttention: "forward() with THD format, return_max_logit=True"
    DotProductAttention->>DotProductAttention: "Check qkv_format == 'thd' and return_max_logit"
    DotProductAttention->>FusedAttention: "Call fused_attn_fwd() with THD tensors"
    FusedAttention->>FusedAttention: "Detect cuDNN runtime >= 90600"
    FusedAttention->>FusedAttention: "Set Stats/Max shape to [num_tokens_q, h, 1]"
    FusedAttention->>FusedAttention: "Apply ragged offset for THD format"
    FusedAttention->>cuDNN: "Create execution graph with updated tensor layouts"
    cuDNN->>GPU: "Execute attention kernels with THD support"
    GPU-->>cuDNN: "Return attention output and max logits"
    cuDNN-->>FusedAttention: "Return computed tensors"
    FusedAttention->>FusedAttention: "Extract max_logit from Stats tensor [tq, h, 1] -> [h]"
    FusedAttention-->>DotProductAttention: "Return output and max_logit"
    DotProductAttention-->>User: "Return attention output with max_logit"
```
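
A minimal sketch of the version gate the diagram references, assuming TE's convention of encoding cuDNN 9.6.0 as the integer 90600; the helper itself is hypothetical:

```python
def thd_max_logit_supported(cudnn_runtime_version: int) -> bool:
    """Hypothetical helper mirroring the gate in the diagram:
    THD + max_logit requires cuDNN >= 9.6.0 (encoded as 90600)."""
    return cudnn_runtime_version >= 90600

assert thd_max_logit_supported(90600)
assert not thd_max_logit_supported(90500)
```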

greptile-apps bot (Contributor) left a comment

3 files reviewed, no comments

cyanguwa (Collaborator, Author) commented Dec 5, 2025

/te-ci pytorch L1

greptile-apps bot (Contributor) left a comment

3 files reviewed, no comments

