[PyTorch] Add THD support for max_logit/MuonClip #2480
Conversation
Signed-off-by: Charlene Yang <[email protected]>
Greptile Summary
This PR extends THD (Total-sequence-length, Head, Dimension) format support to the max_logit feature (used in MuonClip).
Confidence score: 5/5
Sequence Diagram

```mermaid
sequenceDiagram
    participant User
    participant DotProductAttention as "DotProductAttention"
    participant FusedAttention as "FusedAttention Backend"
    participant cuDNN as "cuDNN Frontend"
    participant GPU as "GPU Kernels"
    User->>DotProductAttention: "forward() with THD format, return_max_logit=True"
    DotProductAttention->>DotProductAttention: "Check qkv_format == 'thd' and return_max_logit"
    DotProductAttention->>FusedAttention: "Call fused_attn_fwd() with THD tensors"
    FusedAttention->>FusedAttention: "Detect cuDNN runtime >= 90600"
    FusedAttention->>FusedAttention: "Set Stats/Max shape to [num_tokens_q, h, 1]"
    FusedAttention->>FusedAttention: "Apply ragged offset for THD format"
    FusedAttention->>cuDNN: "Create execution graph with updated tensor layouts"
    cuDNN->>GPU: "Execute attention kernels with THD support"
    GPU-->>cuDNN: "Return attention output and max logits"
    cuDNN-->>FusedAttention: "Return computed tensors"
    FusedAttention->>FusedAttention: "Extract max_logit from Stats tensor [tq, h, 1] -> [h]"
    FusedAttention-->>DotProductAttention: "Return output and max_logit"
    DotProductAttention-->>User: "Return attention output with max_logit"
```
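The final reduction in the diagram collapses the per-token stats tensor to one value per head. A minimal numpy sketch of that shape transformation (illustrative only, not TE's actual kernel code; the tensor sizes are made up):

```python
import numpy as np

# Hypothetical THD-layout stats tensor of shape [num_tokens_q, h, 1]:
# one softmax-stats entry per query token (across all sequences) per head.
num_tokens_q, h = 6, 4
stats = np.random.default_rng(0).normal(size=(num_tokens_q, h, 1))

# The [tq, h, 1] -> [h] step: take the max over all query tokens
# (and the trailing singleton dim), leaving one max logit per head.
max_logit = stats.max(axis=(0, 2))
print(max_logit.shape)  # (4,)
```

Because THD packs all sequences into one token axis, a single reduction over axis 0 covers the whole batch without any padding-aware masking.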
3 files reviewed, no comments
/te-ci pytorch L1
Signed-off-by: Charlene Yang <[email protected]>
3 files reviewed, no comments
Description
This PR continues the work in #2195 and extends support for max_logit (used in MuonClip) to the THD format, covering both non-CP and CP cases (cp_comm_type = {'p2p', 'a2a', 'all_gather', 'a2a_p2p'}).

Type of change
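For context, the max_logit returned by attention feeds MuonClip's qk-clip step. A hedged sketch of one common formulation (the function name, threshold tau, and balance factor alpha here are illustrative assumptions, not this PR's code, which only returns max_logit):

```python
import numpy as np

def muonclip_rescale(w_q, w_k, max_logit, tau=100.0, alpha=0.5):
    """Hypothetical qk-clip: when the observed max attention logit
    exceeds a threshold tau, shrink the query/key projection weights
    so future logits stay bounded. gamma splits between W_q and W_k
    via alpha, since the logit is bilinear in the two projections."""
    gamma = min(1.0, tau / float(max_logit))
    return w_q * gamma ** alpha, w_k * gamma ** (1.0 - alpha)

# If max_logit <= tau, gamma == 1 and the weights are untouched.
w_q, w_k = muonclip_rescale(np.ones(3), np.ones(3), max_logit=200.0)
```

With max_logit=200 and tau=100, gamma is 0.5 and each weight is scaled by sqrt(0.5), so the product (and hence the logit) is halved.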
Changes
Please list the changes introduced in this PR:
- For max_logit, change the Stats (and subsequently Max, Sum_Exp) shape from [max_tokens_q, h, 1] to [num_tokens_q, h, 1].

Checklist:
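The shape change above matters because in THD layout the actual token count can be far below the padded upper bound. A small sketch of the two sizes, assuming cumulative sequence lengths as THD inputs (the example lengths are made up):

```python
import numpy as np

# THD inputs describe the batch via cumulative sequence lengths
# instead of a padded [batch, seqlen] layout.
cu_seqlens_q = np.array([0, 3, 5, 12])   # 3 sequences of lengths 3, 2, 7
h = 8

num_tokens_q = int(cu_seqlens_q[-1])     # 12: tokens actually present

batch = len(cu_seqlens_q) - 1
max_seqlen_q = int(np.diff(cu_seqlens_q).max())
max_tokens_q = batch * max_seqlen_q      # 21: padded upper bound

stats_old = (max_tokens_q, h, 1)         # (21, 8, 1) -> 168 elements
stats_new = (num_tokens_q, h, 1)         # (12, 8, 1) -> 96 elements
```

Allocating Stats/Max/Sum_Exp at [num_tokens_q, h, 1] both saves memory on ragged batches and matches the ragged-offset indexing used by the THD kernels.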