unify mma gemm #727
Conversation
Walkthrough

Adds SM100 target detection and dispatch alongside SM120, introduces a new SM100 GEMM wrapper header, updates CUDA GEMM dispatch to include SM100, centralizes MMA implementations into gemm_mma.h with namespace changes, removes legacy tl_mma definitions from sm90, and adds uniform tl::gemm_* wrappers across SM80/89/100/120.

Changes
Sequence Diagram(s)

sequenceDiagram
    participant Op as Gemm::InferLayout
    participant T as Target utils
    participant H as cuda/gemm.h (PP dispatch)
    participant Arch as gemm_sm80/89/100/120.h
    Note over Op: Compile-time + target-driven selection
    Op->>T: TargetIsSM100(target)? / TargetIsSM120(target)?
    T-->>Op: true/false
    Op->>H: Include gemm.h for arch
    H-->>Arch: Select header by __CUDA_ARCH__
    Arch-->>Op: tl::gemm_* wrappers use cute::tl_mma::GemmTensorOp

sequenceDiagram
    participant tl as tl::gemm_ss/rs/sr
    participant mma as cute::tl_mma::GemmTensorOp
    tl->>mma: instantiate GemmTensorOp<...>
    tl->>mma: call body/body_rs/body_sr(pA,pB,accum)
    mma-->>tl: completes device operation
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
Summary of Changes
Hello @oraluben, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request unifies the implementation of Matrix Multiply-Accumulate (MMA) operations by moving them into a dedicated cute::tl_mma namespace. A key outcome of this refactoring is the enablement of MMA instruction support for the SM100 CUDA architecture, ensuring that newer GPU generations can leverage optimized GEMM operations.
Highlights
- SM100 Architecture Support: Added TargetIsSM100 detection and integrated it into the GEMM layout inference, allowing SM100 devices to utilize specialized MMA instructions.
- Unified MMA Namespace: Refactored MMA-related templates and classes, such as GemmTensorOp, into a new cute::tl_mma namespace for better organization and clarity.
- Architecture-Specific GEMM Dispatch: Introduced gemm_sm100.h and updated existing architecture-specific GEMM headers (e.g., gemm_sm120.h, gemm_sm80.h) to define tl::gemm_ss/rs/sr functions, which now consistently call into the cute::tl_mma namespace (see the wrapper sketch after this list).
- Streamlined Includes: Adjusted includes in gemm_sm90.h by removing redundant cute/arch/mma_sm80.hpp and cute/atom/mma_atom.hpp, aligning with the new unified structure.
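The wrapper shape shared by these headers looks roughly like the sketch below. This is illustrative only: the template parameter names and the exact signature of cute::tl_mma::GemmTensorOp are assumptions, not quoted from the PR; only the forwarding pattern (alias the unified GemmTensorOp, then call body/body_rs/body_sr) comes from the review.

template <int M, int N, int K, int num_warp_m, int num_warp_n,
          bool trans_A, bool trans_B, bool clear_accum,
          typename A_type, typename B_type, typename C_type>
CUTLASS_DEVICE void gemm_ss(A_type *pA, B_type *pB, C_type *accum) {
  // Alias the unified MMA implementation from gemm_mma.h and forward to its
  // shared/shared body; gemm_rs and gemm_sr call body_rs / body_sr instead.
  using MMA = cute::tl_mma::GemmTensorOp<M, N, K, num_warp_m, num_warp_n,
                                         trans_A, trans_B, clear_accum,
                                         A_type, B_type, C_type>;
  MMA::body(pA, pB, accum);
}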
Code Review
This pull request unifies the MMA GEMM operations into the cute::tl_mma namespace and adds support for the SM100 architecture. The refactoring centralizes the common MMA logic, which is a good improvement for maintainability. However, it has also introduced significant code duplication across several architecture-specific header files (gemm_sm80.h, gemm_sm89.h, gemm_sm100.h, gemm_sm120.h), which I've commented on. Additionally, the new TargetIsSM100 function in src/target/utils.cc has a potentially incorrect or misleading architecture range. Addressing these points will improve the maintainability and correctness of the code.
Actionable comments posted: 2
🔭 Outside diff range comments (1)
src/tl_templates/cuda/gemm_mma.h (1)
248-251: Bug: B_type alias incorrectly falls back to A_type_raw

The else branch uses A_type_raw instead of B_type_raw, which will mis-type B when B_type_raw != A_type_raw (e.g., mixed-type GEMMs such as e4m3×e5m2). This can produce incorrect codegen and/or miscompiles.

Fix by using B_type_raw in the fallback:

-  using B_type =
-      typename std::conditional<std::is_same<B_type_raw, float>::value,
-                                tfloat32_t, A_type_raw>::type;
+  using B_type =
+      typename std::conditional<std::is_same<B_type_raw, float>::value,
+                                tfloat32_t, B_type_raw>::type;

Optionally, add a static_assert in paths that don't support mixed A/B types to fail fast at compile-time.
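The optional static_assert could look roughly like this; placement and message are illustrative, assuming A_type_raw/B_type_raw are the raw template parameters referenced above:

// Illustrative guard for paths that do not support mixed A/B element types.
static_assert(std::is_same<A_type_raw, B_type_raw>::value,
              "this GemmTensorOp path does not support mixed A/B element types");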
🧹 Nitpick comments (5)
src/tl_templates/cuda/gemm_sm100.h (1)
7-38: Avoid repetition across sm80/89/100/120 wrappers

These three nearly identical wrappers (gemm_ss/gemm_rs/gemm_sr) are duplicated across four headers. Consider centralizing them in a small helper macro or a shared inline header (e.g., gemm_wrappers_common.h) to reduce the maintenance surface.
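One possible shape for such a shared header; the header and macro names (gemm_wrappers_common.h, TL_DEFINE_GEMM_WRAPPER) are hypothetical, and the template parameter list is assumed to match the existing per-arch wrappers:

// Hypothetical gemm_wrappers_common.h: stamp out the three wrappers once.
#define TL_DEFINE_GEMM_WRAPPER(NAME, BODY_FN)                                  \
  template <int M, int N, int K, int num_warp_m, int num_warp_n,               \
            bool trans_A, bool trans_B, bool clear_accum,                      \
            typename A_type, typename B_type, typename C_type>                 \
  CUTLASS_DEVICE void NAME(A_type *pA, B_type *pB, C_type *accum) {            \
    using MMA = cute::tl_mma::GemmTensorOp<M, N, K, num_warp_m, num_warp_n,    \
                                           trans_A, trans_B, clear_accum,      \
                                           A_type, B_type, C_type>;            \
    MMA::BODY_FN(pA, pB, accum);                                               \
  }

TL_DEFINE_GEMM_WRAPPER(gemm_ss, body)
TL_DEFINE_GEMM_WRAPPER(gemm_rs, body_rs)
TL_DEFINE_GEMM_WRAPPER(gemm_sr, body_sr)

Each per-arch header could then include this file inside namespace tl instead of repeating the three bodies.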
src/tl_templates/cuda/gemm_sm89.h (2)
22-42: Reduce duplication of wrapper bodies

Same suggestion as the other SM headers: generate gemm_ss/gemm_rs/gemm_sr via a shared template/macro to avoid repeating the aliasing and single-line body calls.

3-7: Trim redundant includes in gemm_sm89.h

gemm_mma.h already pulls in <cute/arch/mma_sm89.hpp> and "cuda_fp8.h", so you can simplify the header.

• File needing cleanup: src/tl_templates/cuda/gemm_sm89.h — remove the explicit SM89 and FP8 includes.

Suggested diff:

--- a/src/tl_templates/cuda/gemm_sm89.h
+++ b/src/tl_templates/cuda/gemm_sm89.h
@@
-#include <cute/arch/mma_sm89.hpp>
-#include "cuda_fp8.h"
-#include "gemm_mma.h"
+#include "gemm_mma.h"

Optional cleanup to keep headers lean.
src/tl_templates/cuda/gemm_sm80.h (1)
7-38: Optional: factor out common wrapper pattern

Same three functions appear unchanged across sm80/89/100/120. A shared header or macro would simplify future edits.
src/tl_templates/cuda/gemm_sm120.h (1)
7-38: Optional: de-duplicate wrapper boilerplate

Same de-duplication opportunity here for gemm_ss/rs/sr generation as in other headers.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
💡 Knowledge Base configuration:
- MCP integration is disabled by default for public repositories
- Jira integration is disabled by default for public repositories
- Linear integration is disabled by default for public repositories
You can enable these sources in your CodeRabbit configuration.
📒 Files selected for processing (10)
- src/op/gemm.cc (1 hunks)
- src/target/utils.cc (1 hunks)
- src/target/utils.h (1 hunks)
- src/tl_templates/cuda/gemm.h (1 hunks)
- src/tl_templates/cuda/gemm_mma.h (2 hunks)
- src/tl_templates/cuda/gemm_sm100.h (1 hunks)
- src/tl_templates/cuda/gemm_sm120.h (1 hunks)
- src/tl_templates/cuda/gemm_sm80.h (1 hunks)
- src/tl_templates/cuda/gemm_sm89.h (1 hunks)
- src/tl_templates/cuda/gemm_sm90.h (1 hunks)
🧰 Additional context used
🧬 Code Graph Analysis (8)
src/tl_templates/cuda/gemm_sm100.h (2)
- src/tl_templates/cuda/gemm_sm120.h (3): tl (5-40), gemm_ss (10-16), gemm_rs (21-27)
- src/tl_templates/cuda/gemm_sm70.h (1): tl (159-188)

src/target/utils.h (1)
- src/target/utils.cc (2): TargetIsSM100 (56-61), TargetIsSM100 (56-56)

src/op/gemm.cc (1)
- src/target/utils.cc (4): TargetIsSM100 (56-61), TargetIsSM100 (56-56), TargetIsSM120 (63-68), TargetIsSM120 (63-63)

src/target/utils.cc (1)
- src/op/gemm.cc (2): GetArchInt (291-302), GetArchInt (291-291)

src/tl_templates/cuda/gemm_sm120.h (2)
- src/target/utils.h (1): tl (13-31)
- src/tl_templates/cuda/gemm_sm100.h (4): tl (5-40), gemm_ss (10-16), gemm_rs (21-27), gemm_sr (32-38)

src/tl_templates/cuda/gemm_sm80.h (2)
- src/tl_templates/cuda/gemm_sm100.h (3): tl (5-40), gemm_ss (10-16), gemm_rs (21-27)
- src/tl_templates/cuda/gemm_sm120.h (3): tl (5-40), gemm_ss (10-16), gemm_rs (21-27)

src/tl_templates/cuda/gemm_sm89.h (2)
- src/tl_templates/cuda/gemm_sm100.h (4): tl (5-40), gemm_ss (10-16), gemm_rs (21-27), gemm_sr (32-38)
- src/tl_templates/cuda/gemm_sm120.h (4): tl (5-40), gemm_ss (10-16), gemm_rs (21-27), gemm_sr (32-38)

src/tl_templates/cuda/gemm_mma.h (1)
- src/tl_templates/cuda/gemm_sm90.h (1): cute (12-147)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: build-test-nvidia
- GitHub Check: build-test-amd
- GitHub Check: bot-task
🔇 Additional comments (10)
src/target/utils.h (1)
22-22: Good addition: SM100 target classifier added consistently

The new TargetIsSM100 declaration aligns with existing helpers and the PR's intent. No issues spotted here.
src/op/gemm.cc (1)
374-376: SM100 grouping in the Ampere/Turing layout path is correct; please add a clarifying comment

Verified that:
- makeGemmFragmentC and makeGemmABLayout (in src/layout/layout.h and src/layout/gemm_layouts.cc) are fully generic and contain no SM100-specific branches.
- gemm_sm100.h only defines MMA compute wrappers, not layout logic.
- No Hopper-specific continuity adjustments (makeGemmABLayoutHopper) are used here, and none are required for SM100.

To aid future maintainers, please add a brief comment above the SM100 case in src/op/gemm.cc (around line 374), for example:

  // SM100 uses the same C fragment and generic ABLayout as Ampere/Turing
  } else if (TargetIsAmpere(T.target) || TargetIsTuring(T.target) ||
             TargetIsSM100(T.target) || TargetIsSM120(T.target)) { …

src/target/utils.cc (1)

56-61: SM100 detection mirrors existing SMx helpers

Pattern and guard match the rest of the helpers. Range [100, 120) is consistent with SM120 using [120, 130). Looks good.
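For reference, a minimal sketch of the helper implied by this range check; the Target type, the surrounding guard, and the GetArchInt signature are assumptions based on the cross-references above:

// Minimal sketch, assuming GetArchInt(target) yields the numeric sm_XX value.
bool TargetIsSM100(Target target) {
  int arch = GetArchInt(target);
  return arch >= 100 && arch < 120;  // [100, 120); SM120 covers [120, 130)
}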
src/tl_templates/cuda/gemm.h (1)
4-5: Arch dispatch extended to SM100

Including gemm_sm100.h for CUDA_ARCH_LIST >= 1000 is in the right position/order and won't shadow higher SM120. Looks correct.
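A rough sketch of the resulting dispatch order; the guard macro spelling and the thresholds other than the 1000/1200 boundary are assumptions here:

#if defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 1200)
#include "gemm_sm120.h"
#elif defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 1000)
#include "gemm_sm100.h"  // new SM100 branch, ordered below SM120 so it cannot shadow it
#elif defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 900)
#include "gemm_sm90.h"
#elif defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 890)
#include "gemm_sm89.h"
#else
#include "gemm_sm80.h"
#endif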
src/tl_templates/cuda/gemm_mma.h (2)
424-424: Closing namespace updated

Matches the new cute::tl_mma open. No issues.

13-13: Verified: CUDA arch-specific wrappers consistently reference cute::tl_mma

I've confirmed that all CUDA SM80/89/100/120 wrappers now use cute::tl_mma::GemmTensorOp<…>, and SM90 falls back to it when WGMMA isn't used. The specialized SP-SM90 and HIP variants remain in their own namespaces by design. No further changes needed here.

src/tl_templates/cuda/gemm_sm100.h (1)

7-16: SM100 tl wrappers correctly forward to unified cute::tl_mma

GemmTensorOp aliasing and body() dispatch look consistent with the new unified namespace. This mirrors sm80/89/120 and should enable mma on SM100 as intended.
src/tl_templates/cuda/gemm_sm89.h (1)
11-20: Forwarding to cute::tl_mma is consistent and correct

The tl::gemm_ss wrapper is aligned with the unified GemmTensorOp and matches the sm80/100/120 patterns. Same applies to gemm_rs/sr below.
src/tl_templates/cuda/gemm_sm80.h (1)
7-16: SM80 tl wrappers match the unified cute::tl_mma path

Alias and dispatch are correct and consistent with other SM targets.
src/tl_templates/cuda/gemm_sm120.h (1)
7-16: SM120 tl wrappers look good and align with unified MMA

Forwarding to cute::tl_mma via GemmTensorOp is correct; mirrors sm100/sm80/sm89.
@@ -0,0 +1,40 @@
#pragma once

#include "gemm_mma.h"
💡 Verification agent
🧩 Analysis chain
Confirm CUTLASS_DEVICE availability via gemm_mma.h
If gemm_mma.h doesn’t include cutlass/cutlass.h (or otherwise define CUTLASS_DEVICE), these wrappers may fail to compile on some TU orders. Please verify.
Run:
🏁 Script executed:
#!/bin/bash
set -euo pipefail
# Locate gemm_mma.h and verify it (or its includes) define CUTLASS_DEVICE
fd -a gemm_mma.h
rg -n -C2 'CUTLASS_DEVICE' $(fd -a gemm_mma.h || true)
# As a fallback, ensure cutlass/cutlass.h is pulled somewhere above these wrappers
rg -n -C2 'include\s*<cutlass/cutlass\.h>' src || true
Length of output: 189
Let's drill into gemm_mma.h's include chain and locate the CUTLASS_DEVICE macro definition:
🏁 Script executed:
#!/bin/bash
set -euo pipefail
# Locate gemm_mma.h
file=$(fd -a gemm_mma.h)
echo "Inspecting includes in $file:"
rg -n '^#\s*include' "$file"
echo -e "\nSearching for CUTLASS_DEVICE definition in the repo:"
rg -n '#\s*define\s*CUTLASS_DEVICE' -C2 .
echo -e "\nSearching for any inclusion of cutlass/cutlass.h (quotes or brackets):"
rg -n 'include\s*[<"]cutlass/cutlass.h[>"]' -C2 .
Length of output: 763
Ensure CUTLASS_DEVICE is defined in gemm_mma.h

The gemm_sm100.h wrappers pull in gemm_mma.h (via line 3), but we didn't find any CUTLASS_DEVICE definition or an #include <cutlass/cutlass.h> in its include chain. Without this, builds can break depending on translation-unit include order.

Please update src/tl_templates/cuda/gemm_mma.h to guarantee the macro is available, for example by adding at the top:

#include <cutlass/cutlass.h> // for CUTLASS_DEVICE

or by otherwise defining CUTLASS_DEVICE before it's used in the wrappers.
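For concreteness, the guarded fallback alternative would look like this near the top of gemm_mma.h (including <cutlass/cutlass.h> instead achieves the same thing):

// Ensure CUTLASS_DEVICE exists regardless of translation-unit include order.
#ifndef CUTLASS_DEVICE
#define CUTLASS_DEVICE __host__ __device__
#endif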
🤖 Prompt for AI Agents
In src/tl_templates/cuda/gemm_mma.h (top of file), ensure CUTLASS_DEVICE is
always defined by adding an include of Cutlass's main header before any use of
CUTLASS_DEVICE (e.g., add #include <cutlass/cutlass.h> as the first include) or
alternatively add a guarded fallback definition (wrap a #ifndef CUTLASS_DEVICE /
#define CUTLASS_DEVICE __host__ __device__ / #endif) so the macro is available
regardless of translation-unit include order.
#include "gemm_mma.h"
💡 Verification agent
🧩 Analysis chain
Including gemm_mma.h here is the right move to centralize tl_mma
This aligns sm90 with the new source of tl_mma definitions and avoids duplicate declarations formerly in this header. Please also ensure there is no remaining tl_mma content here to prevent ODR issues.
Run:
🏁 Script executed:
#!/bin/bash
set -euo pipefail
# Ensure only gemm_mma.h defines cute::tl_mma::GemmTensorOp
rg -n -C3 '(namespace\s+cute\s*{)?\s*namespace\s+tl_mma\b' src | sed -n '1,200p'
rg -n -C2 '\bclass\s+GemmTensorOp\b' src | sed -n '1,200p'
# Sanity check that gemm.h dispatch includes sm100 and only one sm* header per arch
fd -a gemm.h
rg -n -C3 'gemm_sm(80|89|90|100|120)\.h' $(fd -a gemm.h || true)
Length of output: 250
🏁 Script executed:
#!/bin/bash
set -euo pipefail
echo "1. Checking for any remaining tl_mma content in gemm_sm90.h"
rg -n 'namespace\s+tl_mma' src/tl_templates/cuda/gemm_sm90.h || true
rg -n 'GemmTensorOp' src/tl_templates/cuda/gemm_sm90.h || true
echo "2. Searching for class GemmTensorOp definitions across the codebase"
rg -n 'class\s+GemmTensorOp' -g '*.h' src || true
echo "3. Locating gemm.h files and listing included sm headers"
for f in $(fd -e h gemm.h); do
echo "In $f:"
rg -n '#include\s+"gemm_sm[0-9]' -C2 "$f" || true
done
Length of output: 2672
Centralize GemmTensorOp: remove duplicate definition in gemm_sm90.h
The include of gemm_mma.h correctly centralizes the cute::tl_mma::GemmTensorOp definition, but src/tl_templates/cuda/gemm_sm90.h still declares its own class GemmTensorOp (around line 23). Please remove that local definition and any remaining namespace tl_mma blocks in this file to prevent ODR conflicts.

- File: src/tl_templates/cuda/gemm_sm90.h
  • Remove the class GemmTensorOp { … } block starting at line 23
  • Verify no other tl_mma namespaces or related definitions remain

Confirmed that the cuda/gemm.h dispatch includes exactly one SM-90 header and that all uses now refer to the definition in gemm_mma.h.
🤖 Prompt for AI Agents
In src/tl_templates/cuda/gemm_sm90.h around lines 10–11 and starting at line 23,
the file includes gemm_mma.h but still defines a local class GemmTensorOp and
one or more tl_mma namespace blocks; remove the entire local class GemmTensorOp
block that begins at line 23 and delete any remaining namespace tl_mma blocks or
duplicate definitions in this file so all references use the centralized
cute::tl_mma::GemmTensorOp from gemm_mma.h; after removal, ensure includes and
forward declarations (if any) remain correct and run a build to confirm no ODR
or missing-symbol issues.
Suppressed by #793
Unify mma ops to cute::tl_mma ns. After this PR, sm100 can run with mma instructions.

Summary by CodeRabbit

New Features
Refactor