Add fused mHC cuTile kernels (stacks on #4483) #4527

Closed
Connor-XY wants to merge 10 commits into NVIDIA:dsv4 from Connor-XY:yxu1/mhc-fused-kernels-stacked-dsv4

@Connor-XY

Summary

Adds the fused cuTile kernel path for mHC primitives. Carved out of PR #4469 so the cuTile kernel and its tests can be reviewed separately from the transformer core mHC reference implementation.

This PR's content (the only files this PR is asking to add or modify):

  • megatron/core/fusions/fused_mhc_kernels.py (new): cuTile kernels for Sinkhorn, H_aggregate, H_post_bda, proj_rms. Selected at config time via use_fused_mhc=True.
  • tests/unit_tests/fusions/test_fused_mhc_kernels.py (new): forward / backward parity vs the native reference impl across the supported shapes.
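
For orientation, the Sinkhorn primitive named in the list above can be sketched in plain Python. This page does not show the kernel's formulation, so the sketch assumes the textbook Sinkhorn-Knopp iteration (alternating row and column normalization of a positive square matrix). The names `sinkhorn_reference` and `n_iters` and the 4x4 input are hypothetical and are not taken from `fused_mhc_kernels.py`.

```python
import math


def sinkhorn_reference(logits, n_iters=20):
    """Assumed textbook Sinkhorn-Knopp on a square matrix of logits.

    Hypothetical reference, not the PR's kernel: exponentiate the logits,
    then alternately normalize rows and columns so the result approaches
    a doubly stochastic matrix.
    """
    n = len(logits)
    # Subtract the global max before exponentiating for numerical stability.
    peak = max(max(row) for row in logits)
    mat = [[math.exp(v - peak) for v in row] for row in logits]
    for _ in range(n_iters):
        # Row step: every row sums to 1.
        mat = [[v / sum(row) for v in row] for row in mat]
        # Column step: every column sums to 1.
        col_sums = [sum(mat[i][j] for i in range(n)) for j in range(n)]
        mat = [[mat[i][j] / col_sums[j] for j in range(n)] for i in range(n)]
    return mat


example_logits = [
    [0.1, 0.5, -0.3, 0.2],
    [-0.4, 0.0, 0.3, 0.1],
    [0.2, -0.1, 0.4, -0.5],
    [0.0, 0.3, -0.2, 0.5],
]
balanced = sinkhorn_reference(example_logits)
row_sums = [sum(row) for row in balanced]
col_sums = [sum(balanced[i][j] for i in range(4)) for j in range(4)]
print("row sums:", row_sums)
print("col sums:", col_sums)
```

A parity test of the kind described for `test_fused_mhc_kernels.py` would compare a fused kernel's output against a reference like this one within a tolerance; the actual tolerances and shapes used by the PR are not stated here.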

Stacking note

This branch is built on top of #4483 (yxu1/mhc-transformer-core-code-dsv4) so that the test file's `from megatron.core.transformer.hyper_connection import native_h_aggregate, …` imports resolve. The "Files changed" tab consequently shows #4483's content as ancestry. Please review only the two files listed above and rely on #4483 for the rest.

This PR cannot merge before #4483. Once #4483 lands, rebase / merge will leave only the two new files as this PR's diff.

Origin

Carved out of PR #4469 (yxu1/mhc-hybridmodel-dsv4) at commit e3d0102ad. Includes the fused-kernel-specific strict-review fixes collected over passes 3–11 of /claude strict-review:

  • eps plumbed through proj_rms backward kernels (closes a fwd/bwd eps mismatch)
  • explicit TILE_SIZE=1 assertion on _ct_hpb_*_kernel paths
  • fp32 reduction for the g_bias accumulator
  • math.gcd(sb, 4) tiling for the aggregate kernels
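
The `math.gcd(sb, 4)` tiling in the last item can be illustrated with a small sketch. It assumes `sb` is the flattened `s*b` row count and that the goal is a tile size of at most 4 that always divides `sb` evenly, so no partial tail tile is needed. The helper name `aggregate_row_tile` and the `max_tile` parameter are hypothetical; only the `math.gcd(sb, 4)` expression comes from the PR text.

```python
import math


def aggregate_row_tile(sb, max_tile=4):
    """Hypothetical helper: pick a row tile that evenly divides sb.

    gcd(sb, 4) can only be 1, 2, or 4, so the result is the largest
    divisor of 4 that also divides sb.
    """
    tile = math.gcd(sb, max_tile)
    assert sb % tile == 0, "gcd always divides sb, so this cannot fire"
    return tile


for sb in (4096, 6, 7):
    tile = aggregate_row_tile(sb)
    print(f"sb={sb}: tile={tile}, tiles launched={sb // tile}")
```

With this choice an odd `sb` falls back to a tile of 1 instead of failing or reading past the end of the row range, at the cost of launching more tiles.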

Validation

Validation so far is limited to running python3 -m compileall on the touched files, which checks only that they byte-compile. Functional / GPU pytest runs are in a separate stacked PR.

🤖 Generated by Claude Opus 4.7 (1M context).

Connor-XY and others added 10 commits April 27, 2026 09:15
Adds fused cuTile kernel implementations for the mHC primitives:
sinkhorn, h_aggregate, h_post_bda, and proj_rms. Selected at config
time via `use_fused_mhc=True`; the reference path in
`hyper_connection.py` remains the default.

Implementation notes:
- TILE_SIZE=1 in the H_post_bda fwd/bwd kernels is hard-required by
  the reshape pattern (asserted in the launcher); a higher tile size
  would silently mix data across batch elements and is left as future
  work.
- proj_rms backward plumbs `eps` through the kernel (closes a
  forward/backward eps mismatch found during strict review).
- H_post_bda bias gradient is reduced in fp32 to match the rest of
  the mHC fp32-accumulation practice; bf16 reduction across `s*b`
  rows accumulates noticeable rounding error for long sequences.
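
The fp32-versus-bf16 point in the last note can be illustrated with a CPU-only sketch. It is not the kernel's reduction: it emulates bf16 by rounding fp32 bit patterns with `struct`, and it sums serially, which is a pessimistic model compared with a tree reduction. The helper names, the 4096-row count, and the 0.01 gradient value are all hypothetical.

```python
import struct


def round_to_fp32(x):
    """Round a Python float (fp64) to the nearest fp32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def round_to_bf16(x):
    """Emulate bf16: keep the top 16 bits of the fp32 pattern,
    rounding to nearest with ties to even."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)
    bits &= 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def serial_sum(values, round_fn):
    """Accumulate one value at a time, rounding the running sum
    to the accumulator's precision after every addition."""
    acc = 0.0
    for v in values:
        acc = round_fn(acc + v)
    return acc


n_rows = 4096                      # stand-in for s*b
grad = round_to_bf16(0.01)         # one bf16 gradient value per row
rows = [grad] * n_rows
exact = grad * n_rows              # fp64 product used as the reference

err_bf16 = abs(serial_sum(rows, round_to_bf16) - exact)
err_fp32 = abs(serial_sum(rows, round_to_fp32) - exact)
print("bf16 accumulator error:", err_bf16)
print("fp32 accumulator error:", err_fp32)
```

In this model the bf16 accumulator stops growing once the spacing between representable sums exceeds twice the addend, while the fp32 accumulator tracks the reference closely for the same inputs.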

Unit tests cover forward/backward parity vs the native reference
implementation across the supported shapes.

Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>

copy-pr-bot Bot commented Apr 29, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.
