Add fused mHC cuTile kernels (stacks on #4483) #4527
Closed
Connor-XY wants to merge 10 commits into NVIDIA:dsv4
Adds fused cuTile kernel implementations for the mHC primitives: sinkhorn, h_aggregate, h_post_bda, and proj_rms. Selected at config time via `use_fused_mhc=True`; the reference path in `hyper_connection.py` remains the default.

Implementation notes:
- `TILE_SIZE=1` in the H_post_bda fwd/bwd kernels is hard-required by the reshape pattern (asserted in the launcher); a higher tile size would silently mix data across batch elements and is left as future work.
- proj_rms backward plumbs `eps` through the kernel (closes a forward/backward eps mismatch found during strict review).
- H_post_bda bias gradient is reduced in fp32 to match the rest of the mHC fp32-accumulation practice; bf16 reduction across `s*b` rows accumulates noticeable rounding error for long sequences.

Unit tests cover forward/backward parity vs the native reference implementation across the supported shapes.

Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>
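The fp32-versus-bf16 reduction note above can be illustrated without a GPU. The following is a minimal sketch, not the kernel's code: it emulates bf16 and fp32 rounding in pure Python with `struct`, and it accumulates sequentially, which is a worst case compared with the tree-style reductions a real kernel is likely to use. The helper names, the row count, and the gradient value are all assumptions chosen for illustration.

```python
import math
import struct


def round_to_bf16(x: float) -> float:
    """Round a finite value to the nearest bfloat16 (ties to even)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # Add half of the discarded range, plus the tie-breaking bit, then
    # drop the low 16 bits that bfloat16 does not store.
    bits += 0x7FFF + ((bits >> 16) & 1)
    bits &= 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def round_to_fp32(x: float) -> float:
    """Round a Python float (fp64) to the nearest fp32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def accumulate(values, round_fn):
    """Sum sequentially, rounding the accumulator after every add."""
    acc = 0.0
    for v in values:
        acc = round_fn(acc + v)
    return acc


# Hypothetical workload: one bias-gradient column over s*b = 4096 rows,
# every row contributing the same small bf16-representable value.
rows = 4096
grad = round_to_bf16(0.01)
values = [grad] * rows

exact = math.fsum(values)
bf16_sum = accumulate(values, round_to_bf16)
fp32_sum = accumulate(values, round_to_fp32)

print(f"exact sum        : {exact}")
print(f"bf16 accumulator : {bf16_sum}")
print(f"fp32 accumulator : {fp32_sum}")
```

Once the bf16 accumulator grows large enough that a single contribution falls below half of its spacing, further adds round away entirely, which is why the error grows with the number of rows.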
This was referenced Apr 29, 2026
Summary
Adds the fused cuTile kernel path for mHC primitives. Carved out of PR #4469 so the cuTile kernel and its tests can be reviewed separately from the transformer core mHC reference implementation.
This PR's content (the only files this PR is asking to add or modify):
- `megatron/core/fusions/fused_mhc_kernels.py` (new): cuTile kernels for Sinkhorn, H_aggregate, H_post_bda, proj_rms. Selected at config time via `use_fused_mhc=True`.
- `tests/unit_tests/fusions/test_fused_mhc_kernels.py` (new): forward / backward parity vs the native reference impl across the supported shapes.
Stacking note
This branch is built on top of #4483 (`yxu1/mhc-transformer-core-code-dsv4`) so the test file's `from megatron.core.transformer.hyper_connection import native_h_aggregate, …` imports resolve. The "Files changed" tab consequently shows #4483's content as ancestry; please review only the two files listed above and rely on #4483 for the rest.
This PR cannot merge before #4483. Once #4483 lands, rebase / merge will leave only the two new files as this PR's diff.
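For orientation on the Sinkhorn primitive named in the file list, here is a textbook Sinkhorn-Knopp normalization in pure Python. It is a sketch under assumptions, not the kernel or the reference path from this PR: the function name, the use of `exp` on the inputs, the iteration count, and the 4x4 example matrix are all illustrative choices.

```python
import math


def sinkhorn_knopp(logits, num_iters=20):
    """Alternately normalize rows and columns of exp(logits).

    For a square matrix with strictly positive entries this converges
    towards a doubly stochastic matrix (all row and column sums near 1).
    """
    n = len(logits)
    # Subtract the global maximum before exponentiating for stability.
    peak = max(max(row) for row in logits)
    m = [[math.exp(v - peak) for v in row] for row in logits]
    for _ in range(num_iters):
        # Row step: make every row sum to 1.
        row_sums = [sum(row) for row in m]
        m = [[m[i][j] / row_sums[i] for j in range(n)] for i in range(n)]
        # Column step: make every column sum to 1.
        col_sums = [sum(m[i][j] for i in range(n)) for j in range(n)]
        m = [[m[i][j] / col_sums[j] for j in range(n)] for i in range(n)]
    return m


# Hypothetical 4x4 input, chosen only for illustration.
example_logits = [
    [0.2, -0.1, 0.4, 0.0],
    [-0.3, 0.5, 0.1, -0.2],
    [0.0, 0.3, -0.4, 0.2],
    [0.1, -0.5, 0.3, 0.4],
]
balanced = sinkhorn_knopp(example_logits)
row_totals = [sum(row) for row in balanced]
col_totals = [sum(balanced[i][j] for i in range(4)) for j in range(4)]
print("row sums:", row_totals)
print("col sums:", col_totals)
```

A slow, obviously correct version like this is the kind of baseline a forward parity test can compare a fused kernel against, with a tolerance suited to the kernel's dtype.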
Origin
Carved out of PR #4469 (`yxu1/mhc-hybridmodel-dsv4`) at commit `e3d0102ad`. Includes the fused-kernel-specific strict-review fixes collected over passes 3–11 of `/claude strict-review`:
- `eps` plumbed through `proj_rms` backward kernels (closes a fwd/bwd eps mismatch)
- `TILE_SIZE=1` assertion on `_ct_hpb_*_kernel` paths
- `g_bias` accumulator
- `math.gcd(sb, 4)` tiling for the aggregate kernels
Validation
`python3 -m compileall` on the touched files. Functional / GPU pytest is in a separate stacked PR.
🤖 Generated by Claude Opus 4.7 (1M context).
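The `eps` fix listed under Origin concerns a backward pass that used a different `eps` from the forward pass. A finite-difference check is one way to expose that kind of mismatch. The sketch below assumes a plain RMS normalization, `y = x / sqrt(mean(x^2) + eps)`, which may not match everything `proj_rms` does; the function names and input values are hypothetical.

```python
import math


def rms_forward(x, eps):
    """y_i = x_i / sqrt(mean(x^2) + eps)."""
    inv_rms = 1.0 / math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v * inv_rms for v in x]


def rms_backward(x, grad_out, eps):
    """Analytic gradient of sum(grad_out * y) with respect to x."""
    n = len(x)
    inv_rms = 1.0 / math.sqrt(sum(v * v for v in x) / n + eps)
    dot = sum(g * v for g, v in zip(grad_out, x))
    return [inv_rms * g - (inv_rms ** 3 / n) * v * dot
            for g, v in zip(grad_out, x)]


def numeric_grad(x, grad_out, eps, h=1e-6):
    """Central finite differences of the same scalar loss."""
    def loss(xs):
        return sum(g * y for g, y in zip(grad_out, rms_forward(xs, eps)))

    grads = []
    for i in range(len(x)):
        hi = list(x)
        lo = list(x)
        hi[i] += h
        lo[i] -= h
        grads.append((loss(hi) - loss(lo)) / (2.0 * h))
    return grads


def max_abs_diff(a, b):
    return max(abs(p - q) for p, q in zip(a, b))


# Small inputs make eps a large share of mean(x^2) + eps, so a
# mismatch between forward and backward eps becomes visible.
x = [0.01, -0.02, 0.03, 0.005]
grad_out = [1.0, 1.0, 1.0, 1.0]
fwd_eps = 1e-3

reference = numeric_grad(x, grad_out, fwd_eps)
matched = max_abs_diff(rms_backward(x, grad_out, fwd_eps), reference)
mismatched = max_abs_diff(rms_backward(x, grad_out, 1e-6), reference)
print(f"backward with forward eps   : max diff {matched:.3e}")
print(f"backward with different eps : max diff {mismatched:.3e}")
```

With activations whose mean square is far above `eps`, the two backward results nearly coincide, so a parity test on large-magnitude inputs alone could miss the mismatch.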