From 4d3843210521a3167327ecfb83d5721cf24ef920 Mon Sep 17 00:00:00 2001 From: Butterfingrz <2395959141@qq.com> Date: Wed, 20 May 2026 17:32:22 +0800 Subject: [PATCH] add sm90 cutlass grouped_gemm swiglu epilogue fusion --- ...57_hopper_bias_swiglu_grouped_gemm_bf16.cu | 821 ++++++++++++++++++ .../57_hopper_grouped_gemm/CMakeLists.txt | 16 + .../sm90_gated_swiglu_store_tma.hpp | 428 +++++++++ ...m90_epilogue_array_tma_warpspecialized.hpp | 128 ++- .../sm90_visitor_tma_warpspecialized.hpp | 25 +- ..._array_tma_warpspecialized_cooperative.hpp | 25 + ...emm_array_tma_warpspecialized_pingpong.hpp | 25 + 7 files changed, 1463 insertions(+), 5 deletions(-) create mode 100644 examples/57_hopper_grouped_gemm/57_hopper_bias_swiglu_grouped_gemm_bf16.cu create mode 100644 examples/57_hopper_grouped_gemm/sm90_gated_swiglu_store_tma.hpp diff --git a/examples/57_hopper_grouped_gemm/57_hopper_bias_swiglu_grouped_gemm_bf16.cu b/examples/57_hopper_grouped_gemm/57_hopper_bias_swiglu_grouped_gemm_bf16.cu new file mode 100644 index 0000000000..b23c4f67e1 --- /dev/null +++ b/examples/57_hopper_grouped_gemm/57_hopper_bias_swiglu_grouped_gemm_bf16.cu @@ -0,0 +1,821 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Hopper Grouped GEMM with fused Bias + Gated-SwiGLU epilogue. + + This example extends 57_hopper_grouped_gemm.cu to demonstrate a fully-fused + epilogue that: + 1) Loads a per-group BF16 row bias [2*I] into smem via cp.async + (Sm90RowBroadcast, EVT child leaf). + 2) Adds the bias to the FP32 GEMM accumulator (Sm90Compute). + 3) Stores the bias-added pre-activation [M_i, 2*I] to gmem + (the "aux D" output, useful for backward pass). + 4) Computes Gated-SwiGLU on (gate, up) pairs and TMA-stores the + half-width result [M_i, I] to a separate output buffer + (Sm90GatedSwiGLUStoreTma, EVT root with HasAuxTmaStore=true). + + The kernel signature: per group i, + gate_up_i = A_i [M_i, K] @ B_i [K, 2*I] + bias_i [2*I] (bf16, "aux D") + swiglu_i = SwiGLU(gate_up_i) (bf16, [M_i, I]) + + The SwiGLU activation is the OpenAI "openai-style" approximation used by + GPT-OSS / mxfp4 swiglu: with alpha = 1.702 and limit = 7.0, + g_clip = min(gate, limit) + u_clip = clamp(up, -limit, limit) + s = 0.5 * (1 + tanh(alpha * 0.5 * g_clip)) + out = g_clip * s * u_clip + g_clip * s + + To run this example: + + $ ./examples/57_hopper_grouped_gemm/57_hopper_bias_swiglu_grouped_gemm \ + --m=1024 --n=4096 --k=2048 --groups=8 + + Here --n is the GEMM-N dimension which equals 2*I (gate + up + concatenated); the SwiGLU output has half that width per row. + + Both the Cooperative and Pingpong PtrArray TMA warp-specialized schedules + are exercised. +*/ + +#include +#include +#include +#include +#include +#include + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cutlass/tensor_ref.h" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/group_array_problem_shape.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/device_memory.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/reference/device/gemm.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/reference/device/tensor_fill.h" + +#include "helper.h" + +// Custom EVT visitors (copies of PR_todo headers, kept inside this example so +// the binary is self-contained relative to the cutlass include tree). +#include "sm90_gated_swiglu_store_tma.hpp" + +using namespace cute; + +using ProblemShape = cutlass::gemm::GroupProblemShape>; // + +using ElementA = cutlass::bfloat16_t; +using ElementB = cutlass::bfloat16_t; +using ElementOutput = cutlass::bfloat16_t; // aux_D and SwiGLU output +using ElementBias = cutlass::bfloat16_t; +using ElementAccum = float; + +#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED) + +static_assert(std::is_same_v, + "Sm90GatedSwiGLUStoreTma uses tanh.approx.bf16x2 and requires bf16 output"); + +///////////////////////////////////////////////////////////////////////////////////////////////// +// Layouts / alignments +///////////////////////////////////////////////////////////////////////////////////////////////// + +using LayoutA = cutlass::layout::RowMajor; +using LayoutB = cutlass::layout::ColumnMajor; +using LayoutOutput = cutlass::layout::RowMajor; + +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // 8 bf16 +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // 8 bf16 +constexpr int AlignmentOutput = 128 / cutlass::sizeof_bits::value; // 8 bf16 +constexpr int AlignmentBias = AlignmentOutput; + +using ArchTag = cutlass::arch::Sm90; +using OperatorClass = cutlass::arch::OpClassTensorOp; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// Schedule configurations +///////////////////////////////////////////////////////////////////////////////////////////////// + +struct CooperativeConfig { + using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative; + using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative; + using TileShape = Shape<_128, _256, _64>; + using ClusterShape = Shape<_2, _1, _1>; + static constexpr int EpiThreadCount = 256; +}; + +struct PingpongConfig { + using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong; + using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; + using TileShape = Shape<_128, _128, _64>; + using ClusterShape = Shape<_2, _1, _1>; + static constexpr int EpiThreadCount = 128; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// GEMM instantiation per schedule +// +// EVT tree: SwiGLUStore( Plus( AccFetch, RowBroadcastVec ) ) +// leaf Sm90AccFetch -> FP32 accumulator +// leaf Sm90RowBroadcast<...> -> per-group bf16 bias [2I], cp.async loaded +// inner Sm90Compute -> elementwise (acc + bias) +// root Sm90GatedSwiGLUStoreTma -> aux_D store + TMA SwiGLU half-width store +// +// We build a 'reference' epilogue first to pull EpiTile / SmemLayoutAtomD / +// CopyOpR2S / StagesD out of the builder (the SwiGLU root visitor needs them +// to size its own smem and pipeline correctly), then assemble the real fusion. +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct GemmGivenSchedule { + using TileShape = typename ScheduleConfig::TileShape; + using ClusterShape = typename ScheduleConfig::ClusterShape; + using KernelSchedule = typename ScheduleConfig::KernelSchedule; + using EpilogueSchedule = typename ScheduleConfig::EpilogueSchedule; + static constexpr int EpiThreadCount = ScheduleConfig::EpiThreadCount; + + static constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest; + + // ----- Phase 1: reference epilogue (stock Sm90RowBroadcast) ----- + using AccFetchRef = cutlass::epilogue::fusion::Sm90AccFetch; + using BiasBcastRef = cutlass::epilogue::fusion::Sm90RowBroadcast< + 0, TileShape, ElementBias*, ElementAccum, + Stride<_0, _1, _0>, AlignmentBias>; + using PlusRef = cutlass::epilogue::fusion::Sm90Compute< + cutlass::plus, ElementOutput, ElementAccum, RoundStyle>; + using FusionRef = cutlass::epilogue::fusion::Sm90EVT; + + using RefEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, OperatorClass, + TileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccum, ElementAccum, + void, void, 0, // no source C matrix + ElementOutput, LayoutOutput*, AlignmentOutput, + EpilogueSchedule, + FusionRef>::CollectiveOp; + + using EpiTile = typename RefEpilogue::EpilogueTile; + using SmemLayoutAtomD = typename RefEpilogue::SmemLayoutAtomD; + using CopyOpR2S = typename RefEpilogue::CopyOpR2S; + static constexpr int StagesD = RefEpilogue::DispatchPolicy::StagesD; + + // ----- Phase 2: real EVT tree ----- + using AccFetch = cutlass::epilogue::fusion::Sm90AccFetch; + using BiasBcast = cutlass::epilogue::fusion::Sm90RowBroadcast< + /*Stages=*/0, // visitor uses no smem stage + TileShape, + ElementBias*, // PtrArray (per-group bias) + ElementAccum, + Stride<_0, _1, _0>, + AlignmentBias>; + using PlusOp = cutlass::epilogue::fusion::Sm90Compute< + cutlass::plus, ElementOutput, ElementAccum, RoundStyle>; + using AccPlusBiasEVT = cutlass::epilogue::fusion::Sm90EVT; + + using SwiGLUStore = cutlass::epilogue::fusion::Sm90GatedSwiGLUStoreTma< + StagesD, + EpiTile, + ElementOutput, + ElementAccum, + RoundStyle, + SmemLayoutAtomD, + CopyOpR2S, + AlignmentOutput, + /*EnableNullptr=*/true, + EpiThreadCount>; + + using SwiGLUFusion = cutlass::epilogue::fusion::Sm90EVT; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, OperatorClass, + TileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccum, ElementAccum, + void, void, 0, // no source C matrix + ElementOutput, LayoutOutput*, AlignmentOutput, // aux D output + EpilogueSchedule, + SwiGLUFusion>::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, LayoutA*, AlignmentA, + ElementB, LayoutB*, AlignmentB, + ElementAccum, + TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule>::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + ProblemShape, CollectiveMainloop, CollectiveEpilogue>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +}; + +using Gemm = GemmGivenSchedule::Gemm; +using GemmPingpong = GemmGivenSchedule::Gemm; + +// Reference device GEMM (FP32 accum -> bf16 output, no bias / no activation). +using DeviceGemmReference = cutlass::reference::device::Gemm< + ElementA, LayoutA, + ElementB, LayoutB, + ElementOutput, LayoutOutput, + ElementAccum, ElementAccum>; + +using StrideA = typename Gemm::GemmKernel::InternalStrideA; +using StrideB = typename Gemm::GemmKernel::InternalStrideB; +using StrideD = typename Gemm::GemmKernel::InternalStrideD; + +// Cooperative and Pingpong share the same A/B/D stride types because they're +// derived from the layouts (RowMajor / ColMajor / RowMajor), not the kernel +// schedule. Pin this with a static_assert so a future schedule divergence +// fails at compile time rather than silently mismapping device strides. +static_assert(std::is_same_v, + "Cooperative / Pingpong StrideA mismatch"); +static_assert(std::is_same_v, + "Cooperative / Pingpong StrideB mismatch"); +static_assert(std::is_same_v, + "Cooperative / Pingpong StrideD mismatch"); + +///////////////////////////////////////////////////////////////////////////////////////////////// +// Reference kernels for bias + SwiGLU on device (used only by verify()) +///////////////////////////////////////////////////////////////////////////////////////////////// + +static __global__ void add_row_bias_row_major_kernel( + cutlass::bfloat16_t* D, cutlass::bfloat16_t const* bias, int M, int N) { + int64_t idx = blockIdx.x * static_cast(blockDim.x) + threadIdx.x; + int64_t total = static_cast(M) * N; + if (idx >= total) return; + int n = static_cast(idx % N); // row-major: contiguous in N + float v = static_cast(D[idx]) + static_cast(bias[n]); + D[idx] = cutlass::bfloat16_t(v); +} + +static __global__ void swiglu_reference_kernel( + cutlass::bfloat16_t const* D_pre, + cutlass::bfloat16_t* out, + int M, int Nhalf, + float alpha, float limit) { + int64_t idx = blockIdx.x * static_cast(blockDim.x) + threadIdx.x; + int64_t total = static_cast(M) * Nhalf; + if (idx >= total) return; + int m = static_cast(idx / Nhalf); + int n = static_cast(idx % Nhalf); + int64_t base = static_cast(m) * (2 * Nhalf); + float g = static_cast(D_pre[base + 2 * n + 0]); + float u = static_cast(D_pre[base + 2 * n + 1]); + g = fminf(g, limit); + u = fmaxf(fminf(u, limit), -limit); + float s = 0.5f * (1.f + tanhf(alpha * 0.5f * g)); + float gs = g * s; + float r = gs * u + gs; + out[idx] = cutlass::bfloat16_t(r); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// +// Host-side allocations +///////////////////////////////////////////////////////////////////////////////////////////////// + +std::vector offset_A; // start element of each group in block_A (size M_i*K) +std::vector offset_B; // size K * N (N = 2I) +std::vector offset_aux; // size M_i * N (aux D & ref pre-activation) +std::vector offset_swi; // size M_i * I + +std::vector stride_A_host; +std::vector stride_B_host; +std::vector stride_D_host; + +// Device-side allocations +cutlass::DeviceAllocation problem_sizes; + +cutlass::DeviceAllocation block_A; +cutlass::DeviceAllocation block_B; +cutlass::DeviceAllocation block_bias; // groups * N +cutlass::DeviceAllocation block_aux_D; // sum(M_i) * N +cutlass::DeviceAllocation block_swiglu; // sum(M_i) * I +cutlass::DeviceAllocation block_ref_pre; // sum(M_i) * N (pre-activation reference) +cutlass::DeviceAllocation block_ref_swiglu; // sum(M_i) * I + +cutlass::DeviceAllocation ptr_A; +cutlass::DeviceAllocation ptr_B; +cutlass::DeviceAllocation ptr_aux_D; // visitor side: ptr_D +cutlass::DeviceAllocation ptr_bias; // bias is bf16** (non-const) +cutlass::DeviceAllocation ptr_swiglu; + +cutlass::DeviceAllocation stride_A; +cutlass::DeviceAllocation stride_B; +cutlass::DeviceAllocation stride_D; // also reused for aux_D (same layout) + +#endif // defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +// Command-line options +///////////////////////////////////////////////////////////////////////////////////////////////// + +struct Options { + bool help = false; + int iterations = 10; + + // M can be randomized per group; N (= 2*I) and K are constant across groups + // because the SwiGLU visitor stores a single N_out / stride_swiglu in Params + // (Sm90GatedSwiGLUStoreTma::Arguments fields N_out / stride_swiglu apply to + // every group via aux_tensormaps_replace). + int m = 1024, n = 4096, k = 2048, groups = 8; + std::string benchmark_path; + std::vector problem_sizes_host; + + float swiglu_alpha = 1.702f; + float swiglu_limit = 7.0f; + + int const tma_alignment_bits = 128; + int const alignment = tma_alignment_bits / cutlass::sizeof_bits::value; + + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + if (cmd.check_cmd_line_flag("help")) { help = true; return; } + + cmd.get_cmd_line_argument("m", m); + cmd.get_cmd_line_argument("n", n); + cmd.get_cmd_line_argument("k", k); + cmd.get_cmd_line_argument("groups", groups); + cmd.get_cmd_line_argument("iterations", iterations); + cmd.get_cmd_line_argument("benchmark", benchmark_path); + cmd.get_cmd_line_argument("swiglu_alpha", swiglu_alpha); + cmd.get_cmd_line_argument("swiglu_limit", swiglu_limit); + + if (!benchmark_path.empty()) { + if (!benchmark_problems()) { problem_sizes_host.clear(); return; } + } else { + randomize_problems(cmd); + } + } + + void randomize_problems(cutlass::CommandLine &cmd) { + int cmd_line_m = -1; + cmd.get_cmd_line_argument("m", cmd_line_m); + problem_sizes_host.reserve(groups); + + // N and K are shared across groups (see field comment above). + // M can vary; if --m is omitted, randomize per group. + std::srand(2024); // deterministic problem shapes across runs + for (int i = 0; i < groups; ++i) { + int m_i = cmd_line_m; + if (m_i < 0) { + m_i = alignment * ((std::rand() % 64) + 1); + } + problem_sizes_host.push_back({m_i, n, k}); + } + } + + bool benchmark_problems() { + std::ifstream file(benchmark_path); + if (!file.good()) return false; + + // Force every problem to share the same N and K (taken from the FIRST + // entry); only M varies. + int n_first = -1; + int k_first = -1; + while (file.good()) { + int idx = -1; + std::string extent_str; + file >> idx >> extent_str; + if (idx < 0 || extent_str.empty()) break; + + cutlass::gemm::GemmCoord extent; + std::vector tokens; + cutlass::CommandLine::tokenize(tokens, extent_str, 'x'); + for (int i = 0; i < int(tokens.size()); ++i) { + extent.at(i) = std::atoi(tokens.at(i).c_str()); + } + if (n_first < 0) { n_first = extent.n(); k_first = extent.k(); n = n_first; k = k_first; } + problem_sizes_host.push_back({extent.m(), n_first, k_first}); + } + groups = static_cast(problem_sizes_host.size()); + return true; + } + + bool valid() const { + if (groups <= 0) return false; + if (n % 16 != 0) return false; // bias / aux D / TMA all require multiples of 16 + if ((n / 2) % 8 != 0) return false; // SwiGLU half-width must be 8-aligned + if (k <= 0 || k % 8 != 0) return false; + for (auto const& p : problem_sizes_host) { + if (get<0>(p) <= 0 || get<1>(p) <= 0 || get<2>(p) <= 0) return false; + if (get<1>(p) != n || get<2>(p) != k) return false; // N, K must be shared + } + return true; + } + + std::ostream & print_usage(std::ostream &out) const { + out << "57_hopper_bias_swiglu_grouped_gemm\n\n" + << " Hopper BF16 Grouped GEMM with fused per-group bias + Gated-SwiGLU\n" + << " epilogue. The GEMM produces [M_i, 2*I] = bias-added pre-activations\n" + << " (stored to aux_D), and the SwiGLU output [M_i, I] is TMA-stored to a\n" + << " separate buffer.\n\n" + << "Options:\n\n" + << " --help Display this usage statement\n" + << " --m= Per-group M (random if omitted)\n" + << " --n= GEMM N = 2*I (shared across groups)\n" + << " --k= Shared K\n" + << " --groups= Number of groups (problems)\n" + << " --iterations= Profiling iterations (0 to skip)\n" + << " --swiglu_alpha= SwiGLU alpha (default 1.702)\n" + << " --swiglu_limit= SwiGLU clamp limit (default 7.0)\n" + << " --benchmark= Benchmark file (idx MxNxK per line)\n\n" + << "Examples:\n\n" + << "$ 57_hopper_bias_swiglu_grouped_gemm --m=1024 --n=4096 --k=2048 --groups=8\n"; + return out; + } + + // GEMM-only flops: counts 2*M*N*K per group; bias-add and SwiGLU are timed + // by GpuTimer but intentionally not counted here. + double gflops(double runtime_s) const { + uint64_t fmas = 0; + for (auto const& p : problem_sizes_host) { + fmas += uint64_t(get<0>(p)) * uint64_t(get<1>(p)) * uint64_t(get<2>(p)); + } + return double(2 * fmas) / 1.0e9 / runtime_s; + } +}; + +struct Result { + double avg_runtime_ms = 0.0; + double gflops = 0.0; + cutlass::Status status = cutlass::Status::kSuccess; + cudaError_t error = cudaSuccess; + bool passed = false; +}; + +#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +// Init / allocate +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +bool initialize_block(cutlass::DeviceAllocation& block, uint64_t seed) { + Element scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + if (bits_input == 1) { scope_max = Element(2); scope_min = Element(0); } + else if (bits_input <= 8) { scope_max = Element(2); scope_min = Element(-2); } + else { scope_max = Element(4); scope_min = Element(-4); } + cutlass::reference::device::BlockFillRandomUniform( + block.get(), block.size(), seed, scope_max, scope_min, 0); + return true; +} + +void allocate(const Options &options) { + offset_A.clear(); offset_B.clear(); + offset_aux.clear(); offset_swi.clear(); + stride_A_host.clear(); stride_B_host.clear(); stride_D_host.clear(); + + int64_t total_A = 0, total_B = 0, total_aux = 0, total_swi = 0; + + for (int i = 0; i < options.groups; ++i) { + auto const& p = options.problem_sizes_host.at(i); + int M = get<0>(p); + int N = get<1>(p); // = 2*I + int K = get<2>(p); + int Ihalf = N / 2; + + offset_A.push_back(total_A); + offset_B.push_back(total_B); + offset_aux.push_back(total_aux); + offset_swi.push_back(total_swi); + + total_A += int64_t(M) * K; + total_B += int64_t(K) * N; + total_aux += int64_t(M) * N; + total_swi += int64_t(M) * Ihalf; + + stride_A_host.push_back(cutlass::make_cute_packed_stride(StrideA{}, {M, K, 1})); + stride_B_host.push_back(cutlass::make_cute_packed_stride(StrideB{}, {N, K, 1})); + stride_D_host.push_back(cutlass::make_cute_packed_stride(StrideD{}, {M, N, 1})); + } + + block_A.reset(total_A); + block_B.reset(total_B); + block_aux_D.reset(total_aux); + block_ref_pre.reset(total_aux); + block_swiglu.reset(total_swi); + block_ref_swiglu.reset(total_swi); + block_bias.reset(int64_t(options.groups) * options.n); +} + +void initialize(const Options &options) { + uint64_t seed = 2024; + + problem_sizes.reset(options.groups); + problem_sizes.copy_from_host(options.problem_sizes_host.data()); + + std::vector host_ptr_A(options.groups); + std::vector host_ptr_B(options.groups); + std::vector host_ptr_aux_D(options.groups); + std::vector host_ptr_bias(options.groups); + std::vector host_ptr_swiglu(options.groups); + + for (int i = 0; i < options.groups; ++i) { + host_ptr_A.at(i) = block_A.get() + offset_A.at(i); + host_ptr_B.at(i) = block_B.get() + offset_B.at(i); + host_ptr_aux_D.at(i) = block_aux_D.get() + offset_aux.at(i); + host_ptr_swiglu.at(i) = block_swiglu.get() + offset_swi.at(i); + host_ptr_bias.at(i) = block_bias.get() + int64_t(i) * options.n; + } + + ptr_A.reset(options.groups); ptr_A.copy_from_host(host_ptr_A.data()); + ptr_B.reset(options.groups); ptr_B.copy_from_host(host_ptr_B.data()); + ptr_aux_D.reset(options.groups); ptr_aux_D.copy_from_host(host_ptr_aux_D.data()); + ptr_bias.reset(options.groups); ptr_bias.copy_from_host(host_ptr_bias.data()); + ptr_swiglu.reset(options.groups); ptr_swiglu.copy_from_host(host_ptr_swiglu.data()); + + stride_A.reset(options.groups); stride_A.copy_from_host(stride_A_host.data()); + stride_B.reset(options.groups); stride_B.copy_from_host(stride_B_host.data()); + stride_D.reset(options.groups); stride_D.copy_from_host(stride_D_host.data()); + + initialize_block(block_A, seed + 1); + initialize_block(block_B, seed + 2); + initialize_block(block_bias, seed + 3); + + // Sentinel-fill output buffers so a kernel that fails to write some + // positions is caught by verify() rather than silently passing on + // residual cudaMalloc memory. + CUDA_CHECK(cudaMemset(block_aux_D.get(), 0xFF, + block_aux_D.size() * sizeof(ElementOutput))); + CUDA_CHECK(cudaMemset(block_swiglu.get(), 0xFF, + block_swiglu.size() * sizeof(ElementOutput))); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// +// Arguments construction +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +typename GemmT::Arguments args_from_options(const Options &options, + bool host_problem_shapes_available = true) { + int device_id = 0; + cutlass::KernelHardwareInfo kernel_hw_info = + cutlass::KernelHardwareInfo::make_kernel_hardware_info(device_id); + + // GEMM N is shared across groups (Options::valid enforces this); SwiGLU + // halves it to produce the output dimension I. + int32_t const N_out = options.n / 2; + int32_t const swiglu_stride = N_out; // row-major, contiguous over half-N + + typename GemmT::Arguments arguments { + cutlass::gemm::GemmUniversalMode::kGrouped, + {options.groups, problem_sizes.get(), + host_problem_shapes_available ? options.problem_sizes_host.data() : nullptr}, + {ptr_A.get(), stride_A.get(), ptr_B.get(), stride_B.get()}, + {/*epilogue.thread (below)*/ {}, + /*ptr_C =*/ nullptr, // no source matrix + /*stride_C =*/ nullptr, + /*ptr_D =*/ ptr_aux_D.get(), // bias-added pre-activation [M_i, 2I] + /*stride_D =*/ stride_D.get()}, + kernel_hw_info + }; + + // EVT args, leaves-then-root: + // child(AccPlusBiasEVT) = { AccFetch{}, RowBroadcastVec{ptr_row, null_default, dRow}, Compute{} } + // root (SwiGLUStore) = { ptr_swiglu, N_out, stride_swiglu, alpha, limit } + arguments.epilogue.thread = { + { // child EVT (AccPlusBiasEVT) + {}, // Sm90AccFetch + { // Sm90RowBroadcast + reinterpret_cast(ptr_bias.get()), + ElementBias(0), + {} // default Stride<_0,_1,_0> + }, + {} // Sm90Compute + }, + { // root SwiGLUStore + ptr_swiglu.get(), + N_out, + swiglu_stride, + options.swiglu_alpha, + options.swiglu_limit + } + }; + + return arguments; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// +// Reference computation and verification +///////////////////////////////////////////////////////////////////////////////////////////////// + +bool verify(const Options &options) { + bool passed = true; + + for (int i = 0; i < options.groups; ++i) { + auto const& p = options.problem_sizes_host.at(i); + int M = get<0>(p); + int N = get<1>(p); + int K = get<2>(p); + int Ihalf = N / 2; + + // 1) Pure GEMM (no bias / no activation) into block_ref_pre. + cutlass::TensorRef ref_A(block_A.get() + offset_A.at(i), LayoutA::packed({M, K})); + cutlass::TensorRef ref_B(block_B.get() + offset_B.at(i), LayoutB::packed({K, N})); + cutlass::TensorRef ref_C(block_ref_pre.get() + offset_aux.at(i), LayoutOutput::packed({M, N})); + cutlass::TensorRef ref_D(block_ref_pre.get() + offset_aux.at(i), LayoutOutput::packed({M, N})); + + DeviceGemmReference gemm_ref; + gemm_ref( + {M, N, K}, + ElementAccum(1.0f), + ref_A, ref_B, + ElementAccum(0.0f), + ref_C, ref_D); + CUDA_CHECK(cudaDeviceSynchronize()); + + // 2) Add row bias into block_ref_pre (row-major: bias is broadcast across M). + { + int64_t total = int64_t(M) * N; + int threads = 256; + int blocks = static_cast((total + threads - 1) / threads); + add_row_bias_row_major_kernel<<>>( + block_ref_pre.get() + offset_aux.at(i), + block_bias.get() + int64_t(i) * options.n, + M, N); + CUDA_CHECK(cudaDeviceSynchronize()); + } + + // 3) Apply SwiGLU into block_ref_swiglu. + { + int64_t total = int64_t(M) * Ihalf; + int threads = 256; + int blocks = static_cast((total + threads - 1) / threads); + swiglu_reference_kernel<<>>( + block_ref_pre.get() + offset_aux.at(i), + block_ref_swiglu.get() + offset_swi.at(i), + M, Ihalf, + options.swiglu_alpha, options.swiglu_limit); + CUDA_CHECK(cudaDeviceSynchronize()); + } + + // 4) Compare. bf16 + tanh.approx easily diverges past strict-eq. + // Use a nonzero_floor near the GEMM dynamic range so that bf16 + // quantization noise around "should be zero" outputs (~0.15 absolute) + // is absorbed without masking real algorithmic bugs in large-magnitude + // regions. + bool aux_ok = cutlass::reference::device::BlockCompareRelativelyEqual( + block_aux_D.get() + offset_aux.at(i), + block_ref_pre.get() + offset_aux.at(i), + int64_t(M) * N, + ElementOutput(5e-2f), + ElementOutput(4.0f)); + + bool swi_ok = cutlass::reference::device::BlockCompareRelativelyEqual( + block_swiglu.get() + offset_swi.at(i), + block_ref_swiglu.get() + offset_swi.at(i), + int64_t(M) * Ihalf, + ElementOutput(6e-2f), + ElementOutput(4.0f)); + + if (!aux_ok || !swi_ok) { + std::cerr << " Group " << i << " failed (aux=" << aux_ok + << ", swiglu=" << swi_ok << ")\n"; + } + passed = passed && aux_ok && swi_ok; + } + return passed; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// +// Driver +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +int run(Options &options, bool host_problem_shapes_available = true) { + allocate(options); + initialize(options); + + std::cout << " Problem sizes (M, 2I, K):\n"; + for (int i = 0; i < options.groups; ++i) { + std::cout << " " << options.problem_sizes_host.at(i) << "\n"; + } + std::cout << " Groups : " << options.groups << "\n" + << " SwiGLU alpha : " << options.swiglu_alpha << "\n" + << " SwiGLU limit : " << options.swiglu_limit << "\n"; + + GemmT gemm; + auto arguments = args_from_options(options, host_problem_shapes_available); + + size_t workspace_size = GemmT::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + CUTLASS_CHECK(gemm.can_implement(arguments)); + CUTLASS_CHECK(gemm.initialize(arguments, workspace.get())); + CUTLASS_CHECK(gemm.run()); + + Result result; + result.passed = verify(options); + + std::cout << " Disposition : " << (result.passed ? "Passed" : "Failed") << std::endl; + if (!result.passed) { std::exit(-1); } + + if (options.iterations > 0) { + GpuTimer timer; + timer.start(); + for (int i = 0; i < options.iterations; ++i) { + CUTLASS_CHECK(gemm.initialize(arguments, workspace.get())); + CUTLASS_CHECK(gemm.run()); + } + timer.stop(); + + float elapsed_ms = timer.elapsed_millis(); + result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations); + result.gflops = options.gflops(result.avg_runtime_ms / 1000.0); + + std::cout << " Avg runtime : " << result.avg_runtime_ms << " ms" << std::endl; + std::cout << " TFLOPS (GEMM) : " << result.gflops / 1000.0 << std::endl; + } + return 0; +} + +#endif // defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// + +int main(int argc, char const **args) { + if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 3)) { + std::cerr << "This example requires CUDA 12.3 or newer.\n"; + return 0; + } + + cudaDeviceProp props; + int current_device_id; + CUDA_CHECK(cudaGetDevice(¤t_device_id)); + CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id)); + if (props.major != 9 || props.minor != 0) { + std::cerr << "This example requires SM90 (Hopper).\n"; + return 0; + } + + Options options; + options.parse(argc, args); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + if (!options.valid()) { + std::cerr << "Invalid options: n must be a multiple of 16, n/2 a multiple of 8, k a multiple of 8.\n"; + return -1; + } + +#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED) + std::cout << "\n*** Cooperative schedule (bias + SwiGLU) ***\n"; + run(options); + + std::cout << "\n*** Cooperative schedule (host problem shapes unavailable) ***\n"; + run(options, /*host_problem_shapes_available=*/false); + + std::cout << "\n*** Pingpong schedule (bias + SwiGLU) ***\n"; + run(options); + + std::cout << "\n*** Pingpong schedule (host problem shapes unavailable) ***\n"; + run(options, /*host_problem_shapes_available=*/false); +#endif + return 0; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/57_hopper_grouped_gemm/CMakeLists.txt b/examples/57_hopper_grouped_gemm/CMakeLists.txt index 61d9d6a05a..138dd6d9d8 100644 --- a/examples/57_hopper_grouped_gemm/CMakeLists.txt +++ b/examples/57_hopper_grouped_gemm/CMakeLists.txt @@ -64,3 +64,19 @@ cutlass_example_add_executable( TEST_RANDOM_PERF TEST_RANDOM_PERF_LARGE_GROUP ) + +# ---------------------------------------------------------------------------- +# Bias + Gated-SwiGLU fused epilogue example (BF16) +# ---------------------------------------------------------------------------- +set(TEST_SWIGLU_SMALL --m=256 --n=512 --k=128 --groups=4 --iterations=0) +set(TEST_SWIGLU_RANDOM --groups=8 --iterations=0) +set(TEST_SWIGLU_FIXED --m=1024 --n=4096 --k=2048 --groups=8 --iterations=0) + +cutlass_example_add_executable( + 57_hopper_bias_swiglu_grouped_gemm + 57_hopper_bias_swiglu_grouped_gemm_bf16.cu + TEST_COMMAND_OPTIONS + TEST_SWIGLU_SMALL + TEST_SWIGLU_RANDOM + TEST_SWIGLU_FIXED + ) diff --git a/examples/57_hopper_grouped_gemm/sm90_gated_swiglu_store_tma.hpp b/examples/57_hopper_grouped_gemm/sm90_gated_swiglu_store_tma.hpp new file mode 100644 index 0000000000..be9fed446f --- /dev/null +++ b/examples/57_hopper_grouped_gemm/sm90_gated_swiglu_store_tma.hpp @@ -0,0 +1,428 @@ +// Copyright (c) Butterfingrz,13524387014@163.com + +#pragma once + +#include + +#include +#include +#include +#include +#include +#include + +namespace cutlass::epilogue::fusion { + +using namespace cute; + +CUTLASS_DEVICE __nv_bfloat162 tanh_approx_bf16x2(__nv_bfloat162 x) { + uint32_t in_bits = reinterpret_cast(x); + uint32_t out_bits; + asm("tanh.approx.bf16x2 %0, %1;" : "=r"(out_bits) : "r"(in_bits)); + return reinterpret_cast<__nv_bfloat162 const&>(out_bits); +} + +// Full-TMA-pipeline EVT visitor for gated SwiGLU with N→N/2 shape transformation. +// +// Replaces the scalar-store visitor (Sm90GatedSwiGLUStore) with a three-stage +// pipeline: visit() → registers, postreduce() → R2S, tma_store() → TMA S2G. +// +// The GEMM output is [M, 2I] but SwiGLU output is [M, I]. This visitor handles +// the 2:1 N-dimension reduction: visit() accumulates full-width fragments into +// a register tensor (Sm90AuxStore pattern), postreduce() computes SwiGLU from +// (gate, up) pairs and scatter-writes half-width results to [EPI_M, EPI_N/2] +// shared memory, and tma_store() issues TMA async stores from smem to gmem. +// +// For PtrArray grouped GEMM, per-group TMA descriptor updates are managed at +// the kernel dispatch layer (not in the visitor). The epilogue collective +// maintains aux TMA descriptors via HasAuxTmaStore trait, and the kernel passes +// a pre-updated tensormap pointer through ConsumerStoreArgs. + +template < + int Stages, + class EpilogueTile, + class Element, + class ElementCompute = float, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest, + class SmemLayoutAtom_ = void, // unused; kept for epilogue builder API compatibility + class CopyOpR2S_ = void, // unused; kept for epilogue builder API compatibility + int Alignment = 128 / sizeof_bits_v, + bool EnableNullptr = true, + int ThreadCount_ = 128 +> +struct Sm90GatedSwiGLUStoreTma { + using ElementAux = Element; + static constexpr bool HasAuxTmaStore = true; + static_assert(Alignment * sizeof_bits_v % 128 == 0, "sub-16B alignment not supported"); + + // --- Subtile geometry and per-thread counts --- + + // Half-width epilogue subtile for SwiGLU output: [EPI_M, EPI_N/2] + static constexpr int EpiM_val = size<0>(EpilogueTile{}); + static constexpr int EpiN_val = size<1>(EpilogueTile{}); + static constexpr int EpiNH_val = EpiN_val / 2; + using EpiM = Int; + using EpiNHalf = Int; + using HalfEpiTile = Shape; + + static constexpr int ThreadCount = ThreadCount_; + static constexpr int ResultsPerThread = EpiM_val * EpiNH_val / ThreadCount; + static_assert(EpiM_val * EpiNH_val % ThreadCount == 0, + "Half-width epilogue tile must be evenly divisible by ThreadCount"); + + // --- Smem layouts for TMA store --- + // DEBUG: swizzle disabled to isolate S2G writeback correctness + using SmemLayoutAtomHalf = Layout, Stride>; + + using SmemLayoutTma = decltype(tile_to_shape( + SmemLayoutAtomHalf{}, + make_shape(EpiM{}, EpiNHalf{}), + Step<_1, _2>{})); + + using SmemLayout = decltype(tile_to_shape( + SmemLayoutTma{}, + make_shape(EpiM{}, EpiNHalf{}, Int{}), + Step<_1, _2, _3>{})); + + // --- TMA descriptor type --- + + // TMA copy type for [M, N_out] row-major output + using TMA_Swiglu = decltype(make_tma_copy( + SM90_TMA_STORE{}, + make_tensor( + static_cast(nullptr), + make_layout(make_shape(int32_t(0), int32_t(0)), + make_stride(int32_t(0), _1{}))), + SmemLayoutTma{})); + + // --- Shared storage --- + + struct SharedStorage { + alignas(cutlass::detail::alignment_for_swizzle(SmemLayout{})) + array_aligned smem_swiglu; + }; + + // --- Host-side arguments and device params --- + + struct Arguments { + Element** ptr_swiglu_out_array = nullptr; + int32_t N_out = 0; + int32_t stride_swiglu = 0; + float swiglu_alpha = 1.702f; // SiLU approximation coefficient + float swiglu_limit = 7.0f; // bf16 clamp boundary (tanh saturates beyond this) + }; + + struct Params { + TMA_Swiglu tma_store_swiglu; + Element** ptr_swiglu_out_array; + int32_t N_out; + int32_t stride_swiglu; + float swiglu_alpha; + float swiglu_limit; + bool is_nullptr = false; + }; + + // --- Static host methods --- + + template + static constexpr Params + to_underlying_arguments(ProblemShape const&, Arguments const& args, void*) { + bool is_nullptr = (args.ptr_swiglu_out_array == nullptr); + TMA_Swiglu tma{}; + if (!is_nullptr) { + // 16B-aligned dummy address; real addresses set per-group on device + Element const* ptr_dummy = reinterpret_cast( + reinterpret_cast(args.ptr_swiglu_out_array) & ~uintptr_t(0xF)); + int32_t init_M = EpiM_val; // safe default, replaced on device + auto tensor_tmpl = make_tensor( + make_gmem_ptr(ptr_dummy), + make_layout(make_shape(init_M, args.N_out), + make_stride(static_cast(args.stride_swiglu), _1{}))); + tma = make_tma_copy(SM90_TMA_STORE{}, tensor_tmpl, SmemLayoutTma{}); + } + return {tma, args.ptr_swiglu_out_array, args.N_out, + args.stride_swiglu, args.swiglu_alpha, args.swiglu_limit, is_nullptr}; + } + + template + static bool + can_implement(ProblemShape const&, Arguments const&) { return true; } + + template + static size_t + get_workspace_size(ProblemShape const&, Arguments const&) { return 0; } + + template + static cutlass::Status + initialize_workspace(ProblemShape const&, Arguments const&, void*, cudaStream_t, + CudaHostAdapter* = nullptr) { + return cutlass::Status::kSuccess; + } + + // --- Constructors and member variables --- + + CUTLASS_HOST_DEVICE + Sm90GatedSwiGLUStoreTma() { } + + CUTLASS_HOST_DEVICE + Sm90GatedSwiGLUStoreTma(Params const& params, SharedStorage const& shared_storage) + : params_ptr(¶ms), + smem_swiglu(const_cast(shared_storage.smem_swiglu.data())) {} + + Params const* params_ptr; + Element* smem_swiglu; + + // --- Device predicate and TMA descriptor accessors --- + + CUTLASS_DEVICE bool + is_producer_load_needed() const { return false; } + + CUTLASS_DEVICE bool + is_C_load_needed() const { return false; } + + // Return the TMA descriptor template for aux_store_init() to copy into smem + CUTLASS_DEVICE auto + get_aux_tma_descriptor() const { + return params_ptr->tma_store_swiglu.get_tma_descriptor(); + } + + // Replace address + dims/strides in the smem descriptor for the given batch/group + template + CUTLASS_DEVICE void + aux_tensormaps_replace( + cute::TmaDescriptor& smem_desc, + ProblemShape_MNKL problem_shape_mnkl, + int32_t next_batch) { + if constexpr (EnableNullptr) { + if (params_ptr->is_nullptr) return; + } + Element* ptr_group = params_ptr->ptr_swiglu_out_array[next_batch]; + int32_t M = get<0>(problem_shape_mnkl); + int32_t N_out = params_ptr->N_out; + int32_t stride = params_ptr->stride_swiglu; + + cute::tma_descriptor_replace_addr_in_shared_mem(smem_desc, ptr_group); + + auto tensor_group = make_tensor( + make_gmem_ptr(ptr_group), + make_layout(make_shape(M, N_out), make_stride(stride, _1{}))); + constexpr int MaxRank = 5; + cute::array prob_shape = {1,1,1,1,1}; + cute::array prob_stride = {0,0,0,0,0}; + cute::detail::fill_tma_gmem_shape_stride( + params_ptr->tma_store_swiglu, tensor_group, prob_shape, prob_stride); + for (auto& s : prob_stride) { + s = (s * sizeof_bits_v) / 8; + } + cute::tma_descriptor_replace_dims_strides_in_shared_mem( + smem_desc, prob_shape, prob_stride); + } + + // --- Callback factories --- + + template + CUTLASS_DEVICE auto + get_producer_load_callbacks(ProducerLoadArgs const&) { + return EmptyProducerLoadCallbacks{}; + } + + template + struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { + CUTLASS_DEVICE + ConsumerStoreCallbacks( + cute::array const& result_m, + cute::array const& result_n_half, + RTensor&& tC_rFull, + SSwigluEpi&& sSwiglu_epi, + STensorS2G&& bSG_sSwiglu, + GTensorS2G&& bSG_gSwiglu, + Params const* params_ptr, + cute::TmaDescriptor const* kernel_swiglu_tensormap) + : result_m(result_m), + result_n_half(result_n_half), + tC_rFull(cute::forward(tC_rFull)), + sSwiglu_epi(cute::forward(sSwiglu_epi)), + bSG_sSwiglu(cute::forward(bSG_sSwiglu)), + bSG_gSwiglu(cute::forward(bSG_gSwiglu)), + params_ptr(params_ptr), + kernel_swiglu_tensormap(kernel_swiglu_tensormap) {} + + // --- Members --- + + cute::array result_m; // target M coordinate per result + cute::array result_n_half; // target N/2 coordinate per result + RTensor tC_rFull; // full-width register tensor (CPY,CPY_M,CPY_N) + SSwigluEpi sSwiglu_epi; // (EpiM, EpiNH, PIPE) unpartitioned smem tensor + STensorS2G bSG_sSwiglu; // (TMA,EPI_M,EPI_N,PIPE) TMA S2G smem partition + GTensorS2G bSG_gSwiglu; // (TMA,EPI_M,EPI_N,m,n) TMA S2G gmem partition + Params const* params_ptr; + cute::TmaDescriptor const* kernel_swiglu_tensormap; // pre-updated by kernel dispatch + + // --- Pipeline callbacks --- + + CUTLASS_DEVICE void + begin() {} + + CUTLASS_DEVICE bool + begin_sync_needed() const { + return false; + } + + // --- Compute and store --- + + // Accumulate each fragment into the full-width register tensor (Sm90AuxStore pattern). + // SwiGLU computation is deferred to postreduce() where the full subtile is available. + template + CUTLASS_DEVICE auto + visit(Array const& frg_acc, + int epi_v, int epi_m, int epi_n, + Array const& frg_input) { + + if constexpr (EnableNullptr) { + if (params_ptr->is_nullptr) return frg_input; + } + + using ConvertInput = NumericArrayConverter; + ConvertInput convert_input{}; + auto tC_rFull_frg = recast>(coalesce(tC_rFull)); + tC_rFull_frg(epi_v) = convert_input(frg_input); + + return frg_input; + } + + // Compute SwiGLU from the fully-accumulated register tensor and scatter-write + // results to swizzled half-width smem. CLayout guarantees compact index 2k = gate + // (even N), 2k+1 = up (odd N), so we read bf16x2 pairs directly from tC_rFull. + CUTLASS_DEVICE void + postreduce(int epi_m, int epi_n, int store_iteration, bool issue_smem_store) { + if constexpr (EnableNullptr) { + if (params_ptr->is_nullptr) return; + } + + if (!issue_smem_store) return; + + int pipe = store_iteration % Stages; + + auto const* full_ptr = reinterpret_cast<__nv_bfloat162 const*>(tC_rFull.data()); + + __nv_bfloat16 alpha_half_s = __float2bfloat16_rn( + params_ptr->swiglu_alpha * 0.5f); + __nv_bfloat16 limit_s = __float2bfloat16_rn(params_ptr->swiglu_limit); + __nv_bfloat162 alpha_half2 = __bfloat162bfloat162(alpha_half_s); + __nv_bfloat162 limit2 = __bfloat162bfloat162(limit_s); + __nv_bfloat162 neg_limit2 = __hneg2(limit2); + __nv_bfloat162 half2_val = __bfloat162bfloat162( + __float2bfloat16_rn(0.5f)); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < ResultsPerThread / 2; ++i) { + __nv_bfloat162 gu0 = full_ptr[2 * i]; + __nv_bfloat162 gu1 = full_ptr[2 * i + 1]; + __nv_bfloat162 g_pair = __halves2bfloat162( + __low2bfloat16(gu0), __low2bfloat16(gu1)); + __nv_bfloat162 u_pair = __halves2bfloat162( + __high2bfloat16(gu0), __high2bfloat16(gu1)); + + g_pair = __hmin2(g_pair, limit2); + u_pair = __hmax2(__hmin2(u_pair, limit2), neg_limit2); + + __nv_bfloat162 ag = __hmul2(alpha_half2, g_pair); + __nv_bfloat162 t = tanh_approx_bf16x2(ag); + __nv_bfloat162 s = __hfma2(t, half2_val, half2_val); + __nv_bfloat162 gs = __hmul2(g_pair, s); + __nv_bfloat162 result = __hfma2(gs, u_pair, gs); + + auto* result_bf16 = reinterpret_cast(&result); + sSwiglu_epi(result_m[2*i], result_n_half[2*i], pipe) = result_bf16[0]; + sSwiglu_epi(result_m[2*i+1], result_n_half[2*i+1], pipe) = result_bf16[1]; + } + } + + // S2G: TMA async store from smem to gmem using kernel-provided descriptor + CUTLASS_DEVICE void + tma_store(int epi_m, int epi_n, int store_iteration, bool issue_tma_store) { + if constexpr (EnableNullptr) { + if (params_ptr->is_nullptr) return; + } + + if (issue_tma_store) { + int pipe_idx = store_iteration % Stages; + copy(params_ptr->tma_store_swiglu.with(kernel_swiglu_tensormap), + bSG_sSwiglu(_,_,_,pipe_idx), + bSG_gSwiglu(_,_,_,epi_m,epi_n)); + } + } + }; + + template < + bool ReferenceSrc, + class... Args + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + + auto [M, N, K, L] = args.problem_shape_mnkl; + auto [m, n, k, l] = args.tile_coord_mnkl; + + // Smem view: [EpiM, EpiNH, Stages] — unpartitioned, for coordinate-based scatter writes + auto sSwiglu_epi = cute::as_position_independent_swizzle_tensor( + make_tensor(make_smem_ptr(smem_swiglu), SmemLayout{})); + + // --- Coordinate extraction from CLayout-aware tiled_copy --- + auto cEpi = make_identity_tensor(args.epi_tile); // (EpiM, EpiN) + auto thr_tc = args.tiled_copy.get_slice(args.thread_idx); + auto tC_cEpi = thr_tc.partition_S(cEpi); // (CPY, CPY_M, CPY_N) + + // Pre-compute target (M, N_half) for each SwiGLU result. + // Result k comes from gate/up pair at fragment indices (2k, 2k+1). + // Gate element (2k) gives the output coordinate; N_half = N / 2. + cute::array result_m; + cute::array result_n_half; + + constexpr int CPY = decltype(size<0>(tC_cEpi))::value; + constexpr int CPY_M = decltype(size<1>(tC_cEpi))::value; + + CUTLASS_PRAGMA_UNROLL + for (int k = 0; k < ResultsPerThread; ++k) { + int flat = 2 * k; + int c = flat % CPY; + int cm = (flat / CPY) % CPY_M; + int cn = flat / (CPY * CPY_M); + auto coord = tC_cEpi(c, cm, cn); + result_m[k] = get<0>(coord); + result_n_half[k] = get<1>(coord) / 2; + } + + // Full-width register tensor for Sm90AuxStore-style accumulation in visit() + auto tC_rFull = make_tensor(shape(tC_cEpi)); // (CPY,CPY_M,CPY_N) + + // --- TMA S2G partitions (unchanged) --- + int32_t N_out = params_ptr->N_out; + auto mSwiglu = params_ptr->tma_store_swiglu.get_tma_tensor( + make_shape(M, N_out)); + + constexpr int CtaM = size<0>(decltype(take<0,2>(args.tile_shape_mnk)){}); + constexpr int CtaNHalf = size<1>(decltype(take<0,2>(args.tile_shape_mnk)){}) / 2; + auto half_cta_tile = make_shape(Int{}, Int{}); + auto gSwiglu = local_tile(mSwiglu, half_cta_tile, make_coord(m, n)); + + auto gSwiglu_epi = flat_divide(gSwiglu, HalfEpiTile{}); + + auto thrblk_s2g = params_ptr->tma_store_swiglu.get_slice(_0{}); + auto bSG_sSwiglu = thrblk_s2g.partition_S(sSwiglu_epi); + auto bSG_gSwiglu = thrblk_s2g.partition_D(gSwiglu_epi); + + return ConsumerStoreCallbacks< + decltype(tC_rFull), decltype(sSwiglu_epi), decltype(bSG_sSwiglu), decltype(bSG_gSwiglu)>( + result_m, + result_n_half, + cute::move(tC_rFull), + cute::move(sSwiglu_epi), + cute::move(bSG_sSwiglu), + cute::move(bSG_gSwiglu), + params_ptr, + args.aux_store_tensormap); + } +}; + +} // namespace cutlass::epilogue::fusion \ No newline at end of file diff --git a/include/cutlass/epilogue/collective/sm90_epilogue_array_tma_warpspecialized.hpp b/include/cutlass/epilogue/collective/sm90_epilogue_array_tma_warpspecialized.hpp index 5601988cbd..120093908d 100644 --- a/include/cutlass/epilogue/collective/sm90_epilogue_array_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/collective/sm90_epilogue_array_tma_warpspecialized.hpp @@ -56,6 +56,41 @@ namespace cutlass { namespace epilogue { namespace collective { +namespace detail { + +// SFINAE trait: detect HasAuxTmaStore on a type or any Op in an Sm90VisitorImpl pack. +// Primary: false +template +struct has_aux_tma_store : cute::false_type {}; + +// Match types that directly define HasAuxTmaStore = true +template +struct has_aux_tma_store> + : cute::bool_constant {}; + +// Variadic OR over a pack +template +struct any_has_aux_tma_store : cute::bool_constant<(has_aux_tma_store::value || ...)> {}; + +// Match Sm90VisitorImpl — check all Ops in the EVT tree +template +struct has_aux_tma_store< + cutlass::epilogue::fusion::Sm90VisitorImpl, + cute::void_t::Params>> + : any_has_aux_tma_store {}; + +// Match Sm90TreeVisitor — delegates to Sm90VisitorImpl +template +struct has_aux_tma_store< + cutlass::epilogue::fusion::Sm90TreeVisitor, + cute::void_t::Params>> + : any_has_aux_tma_store {}; + +template +static constexpr bool has_aux_tma_store_v = has_aux_tma_store::value; + +} // namespace detail + ///////////////////////////////////////////////////////////////////////////////////////////////// template < @@ -230,6 +265,10 @@ class CollectiveEpilogue< cutlass::PipelineTmaStore>; using StorePipelineState = cutlass::PipelineState; + // Detect if the fusion tree contains an auxiliary TMA store (e.g. SwiGLU visitor) + static constexpr bool HasAuxTmaStore = detail::has_aux_tma_store_v; + static constexpr uint32_t NumAuxTmaTensors = HasAuxTmaStore ? NumEpilogueWarpGroups : 0; + struct SharedStorage { struct TensorStorage { using CollectiveStorage = cute::conditional_t { cute::TmaDescriptor smem_tensormap_C; cute::array smem_tensormap_D; + cute::array smem_tensormap_aux; } tensormaps; using PipelineStorage = typename LoadPipeline::SharedStorage; @@ -374,7 +414,7 @@ class CollectiveEpilogue< template static size_t get_workspace_size(ProblemShape const& problem_shape, Arguments const& args, int sm_count) { - constexpr uint32_t NumInputTensors = NumEpilogueWarpGroups + (cute::is_void_v ? 0 : 1); + constexpr uint32_t NumInputTensors = NumEpilogueWarpGroups + (cute::is_void_v ? 0 : 1) + NumAuxTmaTensors; auto descriptors_shape = cute::make_shape(sm_count, Int{}); constexpr size_t SizeOfCuTensorMap = sizeof(cute::TmaDescriptor); // Allocate gmem space for input tensormaps per each SM, A tensormap copies followed by B tensormap copies @@ -636,6 +676,7 @@ class CollectiveEpilogue< int thread_idx, TensorStorage& shared_tensors, TensorMapD const& store_tensormap, + cute::TmaDescriptor const* aux_store_tensormap = nullptr, int subtile_idx=-1) { using namespace cute; @@ -776,7 +817,8 @@ class CollectiveEpilogue< tRS_cD, residue_tRS_cD, tRS_rC, - thread_idx + thread_idx, + aux_store_tensormap }; auto cst_callbacks = fusion_callbacks.template get_consumer_store_callbacks(cst_args); bool is_producer_load_needed = fusion_callbacks.is_producer_load_needed(); @@ -1039,6 +1081,86 @@ class CollectiveEpilogue< return cute::make_tuple(null_tma_desc); } + // Initialize auxiliary TMA store descriptor (e.g. SwiGLU output). + // Mirrors store_init() but targets the aux tensormap workspace slots. + CUTLASS_DEVICE auto + aux_store_init( + Params const& params, + TensorMapStorage& shared_tensormaps, + int32_t sm_count, + int32_t sm_idx, + int32_t warp_group_idx) { + if constexpr (!HasAuxTmaStore) { + cute::TmaDescriptor* null_desc = nullptr; + return cute::make_tuple(null_desc); + } + else { + int warp_idx_in_warp_group = canonical_warp_idx_sync() % NumWarpsPerWarpGroup; + if (warp_idx_in_warp_group == 0) { + constexpr uint32_t NumInputTensors = NumEpilogueWarpGroups + (cute::is_void_v ? 0 : 1) + NumAuxTmaTensors; + constexpr uint32_t aux_base = NumEpilogueWarpGroups + (cute::is_void_v ? 0 : 1); + Layout desc_layout = make_layout(make_shape(sm_count, Int{})); + Tensor gmem_tensormap = make_tensor(params.tensormaps, desc_layout); + + auto const& aux_tma_desc = fusion_callbacks.get_aux_tma_descriptor(); + Tensor p_desc = make_tensor(aux_tma_desc, Int<1>{}, Int<1>{}); + Tensor s_desc = make_tensor( + make_smem_ptr(&shared_tensormaps.smem_tensormap_aux[warp_group_idx]), + Int<1>{}, Int<1>{}); + if (cute::elect_one_sync()) { + copy(recast(p_desc), recast(s_desc)); + } + __syncwarp(); + return cute::make_tuple(&gmem_tensormap(sm_idx, aux_base + warp_group_idx)); + } + cute::TmaDescriptor* null_desc = nullptr; + return cute::make_tuple(null_desc); + } + } + + // Patch the aux smem descriptor for the next batch/group + template + CUTLASS_DEVICE void + aux_tensormaps_perform_update( + TensorMapStorage& shared_tensormaps, + Params const& params, + cute::TmaDescriptor const* tensormap, + ProblemShape_MNKL problem_shape_mnkl, + int32_t next_batch, + int32_t warp_group_idx) { + if constexpr (HasAuxTmaStore) { + if (cute::elect_one_sync()) { + fusion_callbacks.aux_tensormaps_replace( + shared_tensormaps.smem_tensormap_aux[warp_group_idx], + problem_shape_mnkl, next_batch); + } + } + } + + // Wait for in-flight aux TMA stores, then publish updated descriptor to gmem + CUTLASS_DEVICE void + aux_tensormaps_cp_fence_release( + TensorMapStorage& shared_tensormaps, + cute::TmaDescriptor const* tensormap, + int32_t warp_group_idx) { + if constexpr (HasAuxTmaStore) { + if (cute::elect_one_sync()) { + cute::tma_desc_commit_group(); + cute::tma_desc_wait_group(); + } + cute::tma_descriptor_cp_fence_release( + tensormap, shared_tensormaps.smem_tensormap_aux[warp_group_idx]); + } + } + + // Acquire fence so subsequent TMA ops see the updated aux descriptor + CUTLASS_DEVICE void + aux_tensormaps_fence_acquire(cute::TmaDescriptor const* tensormap) { + if constexpr (HasAuxTmaStore) { + cute::tma_descriptor_fence_acquire(tensormap); + } + } + // // Methods to perform different parts of TMA/Tensormap modifications // @@ -1052,7 +1174,7 @@ class CollectiveEpilogue< int32_t sm_idx, int32_t warp_group_idx) { - constexpr uint32_t NumInputTensors = NumEpilogueWarpGroups + (cute::is_void_v ? 0 : 1); + constexpr uint32_t NumInputTensors = NumEpilogueWarpGroups + (cute::is_void_v ? 0 : 1) + NumAuxTmaTensors; Layout desc_layout = make_layout(make_shape(sm_count, Int{})); Tensor gmem_tensormap = make_tensor(params.tensormaps, desc_layout); // (SMs, NumInputTensors) diff --git a/include/cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp b/include/cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp index 5d4e9deb50..6149f24d1b 100644 --- a/include/cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp @@ -339,6 +339,7 @@ struct ConsumerStoreArgs { ThrResidue residue_tCcD; ThrSrcTensor & tCrC; int thread_idx; + cute::TmaDescriptor const* aux_store_tensormap; CUTLASS_DEVICE ConsumerStoreArgs( @@ -353,7 +354,8 @@ struct ConsumerStoreArgs { ThrCoordTensor tCcD, ThrResidue residue_tCcD, ThrSrcTensor & tCrC, - int thread_idx) + int thread_idx, + cute::TmaDescriptor const* aux_store_tensormap = nullptr) : problem_shape_mnkl(problem_shape_mnkl), tile_shape_mnk(tile_shape_mnk), tile_coord_mnkl(tile_coord_mnkl), @@ -365,7 +367,8 @@ struct ConsumerStoreArgs { tCcD(tCcD), residue_tCcD(residue_tCcD), tCrC(tCrC), - thread_idx(thread_idx) {} + thread_idx(thread_idx), + aux_store_tensormap(aux_store_tensormap) {} }; template @@ -619,6 +622,24 @@ struct Sm90TreeVisitor : Sm90VisitorImpl { template get_consumer_store_callbacks(args); return ConsumerStoreCallbacks(cute::move(callbacks_impl)); } + + // Forwarding methods for auxiliary TMA store descriptor management. + // The root NodeOp (last in the ops tuple) provides the actual implementation. + static constexpr int RootIdx = sizeof...(ChildOps); + + CUTLASS_DEVICE auto const& + get_aux_tma_descriptor() const { + return get(this->ops).get_aux_tma_descriptor(); + } + + template + CUTLASS_DEVICE void + aux_tensormaps_replace( + cute::TmaDescriptor& smem_desc, + ProblemShape_MNKL problem_shape_mnkl, + int32_t next_batch) { + get(this->ops).aux_tensormaps_replace(smem_desc, problem_shape_mnkl, next_batch); + } }; ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_cooperative.hpp b/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_cooperative.hpp index 1b9fcbf72c..f4ab152da3 100644 --- a/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_cooperative.hpp +++ b/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_cooperative.hpp @@ -915,6 +915,7 @@ class GemmUniversal< bool do_store_tail = false; // Get a copy of tensormaps auto epi_store_tensormap = get<0>(collective_epilogue.store_init(params.epilogue, shared_storage.tensormaps.epilogue, sm_count, sm_idx, consumer_warp_group_idx)); + auto aux_store_tensormap = get<0>(collective_epilogue.aux_store_init(params.epilogue, shared_storage.tensormaps.epilogue, sm_count, sm_idx, consumer_warp_group_idx)); bool did_batch_change = true; constexpr bool IsEpiLoad = false; @@ -928,12 +929,23 @@ class GemmUniversal< work_tile_info.L_idx, consumer_warp_group_idx ); + collective_epilogue.aux_tensormaps_perform_update( + shared_storage.tensormaps.epilogue, + params.epilogue, + aux_store_tensormap, + problem_shape_MNKL, + work_tile_info.L_idx, + consumer_warp_group_idx + ); // Converge before issuing tensormap fence release since fence is aligned __syncwarp(); collective_epilogue.template tensormaps_cp_fence_release(shared_storage.tensormaps.epilogue, epi_store_tensormap, consumer_warp_group_idx); + collective_epilogue.aux_tensormaps_cp_fence_release(shared_storage.tensormaps.epilogue, + aux_store_tensormap, + consumer_warp_group_idx); } do { @@ -998,6 +1010,7 @@ class GemmUniversal< if (did_batch_change) { collective_epilogue.template tensormaps_fence_acquire(epi_store_tensormap); + collective_epilogue.aux_tensormaps_fence_acquire(aux_store_tensormap); } if (TileScheduler::compute_epilogue(work_tile_info, params.scheduler)) { @@ -1017,6 +1030,7 @@ class GemmUniversal< mma_thread_idx, shared_storage.tensors.epilogue, epi_store_tensormap, + aux_store_tensormap, work_tile_info.reduction_subtile_idx() ); @@ -1051,12 +1065,23 @@ class GemmUniversal< work_tile_info.L_idx, consumer_warp_group_idx ); + collective_epilogue.aux_tensormaps_perform_update( + shared_storage.tensormaps.epilogue, + params.epilogue, + aux_store_tensormap, + problem_shape_MNKL, + work_tile_info.L_idx, + consumer_warp_group_idx + ); // Converge before issuing tensormap fence release since fence is aligned __syncwarp(); collective_epilogue.template tensormaps_cp_fence_release(shared_storage.tensormaps.epilogue, epi_store_tensormap, consumer_warp_group_idx); + collective_epilogue.aux_tensormaps_cp_fence_release(shared_storage.tensormaps.epilogue, + aux_store_tensormap, + consumer_warp_group_idx); } } diff --git a/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_pingpong.hpp b/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_pingpong.hpp index c828e82953..a26f5efe35 100644 --- a/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_pingpong.hpp +++ b/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_pingpong.hpp @@ -953,6 +953,7 @@ class GemmUniversal< bool do_store_tail = false; // Get a copy of tensormaps auto epi_store_tensormap = get<0>(collective_epilogue.store_init(params.epilogue, shared_storage.tensormaps.epilogue, sm_count, sm_idx, consumer_warp_group_idx)); + auto aux_store_tensormap = get<0>(collective_epilogue.aux_store_init(params.epilogue, shared_storage.tensormaps.epilogue, sm_count, sm_idx, consumer_warp_group_idx)); bool did_batch_change = true; constexpr bool IsEpiLoad = false; @@ -966,12 +967,23 @@ class GemmUniversal< work_tile_info.L_idx, consumer_warp_group_idx ); + collective_epilogue.aux_tensormaps_perform_update( + shared_storage.tensormaps.epilogue, + params.epilogue, + aux_store_tensormap, + problem_shape_MNKL, + work_tile_info.L_idx, + consumer_warp_group_idx + ); // Converge before issuing tensormap fence release since fence is aligned __syncwarp(); collective_epilogue.template tensormaps_cp_fence_release(shared_storage.tensormaps.epilogue, epi_store_tensormap, consumer_warp_group_idx); + collective_epilogue.aux_tensormaps_cp_fence_release(shared_storage.tensormaps.epilogue, + aux_store_tensormap, + consumer_warp_group_idx); } do { @@ -1042,6 +1054,7 @@ class GemmUniversal< if (did_batch_change) { collective_epilogue.template tensormaps_fence_acquire(epi_store_tensormap); + collective_epilogue.aux_tensormaps_fence_acquire(aux_store_tensormap); } if (TileScheduler::compute_epilogue(work_tile_info, params.scheduler)) { @@ -1061,6 +1074,7 @@ class GemmUniversal< mma_thread_idx, shared_storage.tensors.epilogue, epi_store_tensormap, + aux_store_tensormap, work_tile_info.reduction_subtile_idx() ); @@ -1111,12 +1125,23 @@ class GemmUniversal< work_tile_info.L_idx, consumer_warp_group_idx ); + collective_epilogue.aux_tensormaps_perform_update( + shared_storage.tensormaps.epilogue, + params.epilogue, + aux_store_tensormap, + problem_shape_MNKL, + work_tile_info.L_idx, + consumer_warp_group_idx + ); // Converge before issuing tensormap fence release since fence is aligned __syncwarp(); collective_epilogue.template tensormaps_cp_fence_release(shared_storage.tensormaps.epilogue, epi_store_tensormap, consumer_warp_group_idx); + collective_epilogue.aux_tensormaps_cp_fence_release(shared_storage.tensormaps.epilogue, + aux_store_tensormap, + consumer_warp_group_idx); } }