Skip to content

Commit 841b423

Browse files
committed
feat: implement SM-Constrained GEMM API
As requested in #591, this PR implements the `plan` function of GEMM with `num_ctas` as an argument to specify the grid size.
1 parent a0e99a3 commit 841b423

12 files changed

+163
-16
lines changed

csrc/flashinfer_gemm_ops.cu

+4-1
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,12 @@ void bmm_fp8(at::Tensor A, at::Tensor B, at::Tensor D, at::Tensor A_scale, at::T
2121
void CutlassSegmentGEMM(at::Tensor workspace_buffer, at::Tensor all_problems, at::Tensor x_ptr,
2222
at::Tensor w_ptr, at::Tensor y_ptr, at::Tensor x_ld, at::Tensor w_ld,
2323
at::Tensor y_ld, at::Tensor empty_x_data, bool weight_column_major,
24-
int64_t cuda_stream);
24+
std::vector<int64_t> plan_info_vec, int64_t cuda_stream);
25+
26+
std::vector<int64_t> CutlassSegmentGEMMPlan(unsigned int num_ctas);
2527

2628
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
2729
m.def("cutlass_segment_gemm", &CutlassSegmentGEMM, "Cutlass Segment GEMM");
30+
m.def("cutlass_segment_gemm_plan", &CutlassSegmentGEMMPlan, "Cutlass Segment GEMM Plan");
2831
m.def("bmm_fp8", &bmm_fp8, "BMM FP8");
2932
}

csrc/flashinfer_gemm_sm90_ops.cu

+27
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
/*
2+
* Copyright (c) 2025 by FlashInfer team.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
#include "pytorch_extension_utils.h"
17+
18+
void CutlassSegmentGEMMSM90(at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer,
19+
at::Tensor all_problems, at::Tensor x_ptr, at::Tensor w_ptr,
20+
at::Tensor y_ptr, at::Tensor x_stride, at::Tensor weight_stride,
21+
at::Tensor y_stride, at::Tensor empty_x_data, bool weight_column_major,
22+
std::vector<int64_t> plan_info_vec, int64_t cuda_stream);
23+
24+
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
25+
m.def("cutlass_segment_gemm_sm90", &CutlassSegmentGEMMSM90,
26+
"Cutlass Segment GEMM operator for SM90");
27+
}

csrc/flashinfer_ops.cu

+3-1
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,8 @@ void bmm_fp8(at::Tensor A, at::Tensor B, at::Tensor D, at::Tensor A_scale, at::T
6363
void CutlassSegmentGEMM(at::Tensor workspace_buffer, at::Tensor all_problems, at::Tensor x_ptr,
6464
at::Tensor w_ptr, at::Tensor y_ptr, at::Tensor x_ld, at::Tensor w_ld,
6565
at::Tensor y_ld, at::Tensor empty_x_data, bool weight_column_major,
66-
int64_t cuda_stream);
66+
std::vector<int64_t> plan_info_vec, int64_t cuda_stream);
67+
std::vector<int64_t> CutlassSegmentGEMMPlan(unsigned int num_ctas);
6768

6869
//========== norm ==========
6970

@@ -223,6 +224,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
223224

224225
// gemm
225226
m.def("bmm_fp8", &bmm_fp8, "BMM FP8");
227+
m.def("cutlass_segment_gemm_plan", &CutlassSegmentGEMMPlan, "Cutlass Segment GEMM plan");
226228
m.def("cutlass_segment_gemm", &CutlassSegmentGEMM, "Cutlass Segment GEMM operator");
227229

228230
// norm

csrc/flashinfer_ops_sm90.cu

+1-1
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ void CutlassSegmentGEMMSM90(at::Tensor float_workspace_buffer, at::Tensor int_wo
1919
at::Tensor all_problems, at::Tensor x_ptr, at::Tensor w_ptr,
2020
at::Tensor y_ptr, at::Tensor x_stride, at::Tensor weight_stride,
2121
at::Tensor y_stride, at::Tensor empty_x_data, bool weight_column_major,
22-
int64_t cuda_stream);
22+
std::vector<int64_t> plan_info_vec, int64_t cuda_stream);
2323

2424
void single_prefill_with_kv_cache_sm90(unsigned int mask_mode_code, at::Tensor q, at::Tensor k,
2525
at::Tensor v,

csrc/group_gemm.cu

+13-2
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
* limitations under the License.
1515
*/
1616
#include <flashinfer/gemm/group_gemm.cuh>
17+
#include <flashinfer/gemm/scheduler.cuh>
1718

1819
#include "pytorch_extension_utils.h"
1920

@@ -23,7 +24,9 @@ using namespace flashinfer::group_gemm;
2324
void CutlassSegmentGEMM(at::Tensor workspace_buffer, at::Tensor all_problems, at::Tensor x_ptr,
2425
at::Tensor w_ptr, at::Tensor y_ptr, at::Tensor x_ld, at::Tensor w_ld,
2526
at::Tensor y_ld, at::Tensor empty_x_data, bool weight_column_major,
26-
int64_t cuda_stream) {
27+
std::vector<int64_t> plan_info_vec, int64_t cuda_stream) {
28+
GemmPlanInfo plan_info;
29+
plan_info.FromVector(plan_info_vec);
2730
unsigned int batch_size = x_ptr.size(0);
2831

2932
cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
@@ -32,9 +35,17 @@ void CutlassSegmentGEMM(at::Tensor workspace_buffer, at::Tensor all_problems, at
3235
auto status = CutlassSegmentGEMMRun<cutlass_t>(
3336
workspace_buffer.data_ptr(), workspace_buffer.element_size() * workspace_buffer.size(0),
3437
all_problems.data_ptr(), batch_size, x_ptr.data_ptr(), w_ptr.data_ptr(), y_ptr.data_ptr(),
35-
x_ld.data_ptr(), w_ld.data_ptr(), y_ld.data_ptr(), weight_column_major, stream);
38+
x_ld.data_ptr(), w_ld.data_ptr(), y_ld.data_ptr(), weight_column_major, plan_info.num_ctas,
39+
stream);
3640
TORCH_CHECK(status == cudaSuccess,
3741
"Failed to run CutlassSegmentGEMM: ", cudaGetErrorString(status));
3842
return true;
3943
});
4044
}
45+
46+
std::vector<int64_t> CutlassSegmentGEMMPlan(unsigned int num_ctas) {
47+
GemmPlanInfo plan_info;
48+
cudaError_t status = GemmPlan(num_ctas, plan_info);
49+
TORCH_CHECK(status == cudaSuccess, "GemmPlan failed with error: ", cudaGetErrorString(status));
50+
return plan_info.ToVector();
51+
}

csrc/group_gemm_sm90.cu

+6-2
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
* limitations under the License.
1515
*/
1616
#include <flashinfer/gemm/group_gemm_sm90.cuh>
17+
#include <flashinfer/gemm/scheduler.cuh>
1718

1819
#include "pytorch_extension_utils.h"
1920

@@ -24,7 +25,9 @@ void CutlassSegmentGEMMSM90(at::Tensor float_workspace_buffer, at::Tensor int_wo
2425
at::Tensor all_problems, at::Tensor x_ptr, at::Tensor w_ptr,
2526
at::Tensor y_ptr, at::Tensor x_stride, at::Tensor weight_stride,
2627
at::Tensor y_stride, at::Tensor empty_x_data, bool weight_column_major,
27-
int64_t cuda_stream) {
28+
std::vector<int64_t> plan_info_vec, int64_t cuda_stream) {
29+
GemmPlanInfo plan_info;
30+
plan_info.FromVector(plan_info_vec);
2831
unsigned int batch_size = x_ptr.size(0);
2932
auto device = float_workspace_buffer.device();
3033

@@ -37,7 +40,8 @@ void CutlassSegmentGEMMSM90(at::Tensor float_workspace_buffer, at::Tensor int_wo
3740
int_workspace_buffer.data_ptr(),
3841
int_workspace_buffer.element_size() * int_workspace_buffer.size(0), all_problems.data_ptr(),
3942
batch_size, x_ptr.data_ptr(), w_ptr.data_ptr(), y_ptr.data_ptr(), x_stride.data_ptr(),
40-
weight_stride.data_ptr(), y_stride.data_ptr(), weight_column_major, stream);
43+
weight_stride.data_ptr(), y_stride.data_ptr(), weight_column_major, plan_info.num_ctas,
44+
stream);
4145
TORCH_CHECK(status == cudaSuccess,
4246
"Failed to run CutlassSegmentGEMM: ", cudaGetErrorString(status));
4347
return true;

flashinfer/gemm.py

+34-1
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
"""
1616

1717
from types import SimpleNamespace
18-
from typing import Optional
18+
from typing import List, Optional
1919

2020
import torch
2121
import triton
@@ -104,6 +104,7 @@ def cutlass_segment_gemm(
104104
y: torch.Tensor,
105105
empty_x_data: torch.Tensor,
106106
weight_column_major: bool,
107+
plan_info_vec: List[int],
107108
) -> None:
108109
with x_data.device as device:
109110
module.cutlass_segment_gemm(
@@ -117,6 +118,7 @@ def cutlass_segment_gemm(
117118
y_ld,
118119
empty_x_data,
119120
weight_column_major,
121+
plan_info_vec,
120122
get_cuda_stream(device),
121123
)
122124

@@ -139,6 +141,7 @@ def _fake_cutlass_segment_gemm(
139141
# Register the module
140142
_gemm_module = SimpleNamespace(
141143
bmm_fp8=bmm_fp8,
144+
plan=module.cutlass_segment_gemm_plan,
142145
cutlass_segment_gemm=cutlass_segment_gemm,
143146
)
144147

@@ -181,6 +184,7 @@ def cutlass_segment_gemm_sm90(
181184
y: torch.Tensor,
182185
empty_x_data: torch.Tensor,
183186
weight_column_major: bool,
187+
plan_info_vec: List[int],
184188
) -> None:
185189
with x_data.device as device:
186190
module.cutlass_segment_gemm_sm90(
@@ -195,6 +199,7 @@ def cutlass_segment_gemm_sm90(
195199
y_stride,
196200
empty_x_data,
197201
weight_column_major,
202+
plan_info_vec,
198203
get_cuda_stream(device),
199204
)
200205

@@ -212,6 +217,7 @@ def _fake_cutlass_segment_gemm_sm90(
212217
y: torch.Tensor,
213218
empty_x_data: torch.Tensor,
214219
weight_column_major: bool,
220+
plan_info_vec: List[int],
215221
) -> None:
216222
pass
217223

@@ -444,6 +450,8 @@ class SegmentGEMMWrapper:
444450
>>> x = torch.randn(10, 128, device="cuda", dtype=torch.float16)
445451
>>> # create weight tensor with 4 weights, each with 128 input and 256 output channels, column major
446452
>>> weights = torch.randn(4, 256, 128, device="cuda", dtype=torch.float16)
453+
>>> # set the number of CTAs to 64
454+
>>> segment_gemm.plan(64)
447455
>>> # compute the segment GEMM
448456
>>> y = segment_gemm.run(x, weights, 4, True, seg_lens=seq_lens)
449457
>>> y.shape
@@ -512,6 +520,29 @@ def reset_workspace_buffer(
512520
self._float_workspace_buffer = float_workspace_buffer
513521
self._int_workspace_buffer = int_workspace_buffer
514522

523+
def plan(self, num_ctas: int = 0) -> None:
524+
r"""Plan gemm for given num_ctas.
525+
526+
Parameters
527+
----------
528+
num_ctas: int
529+
The number of CTAs to run gemm kernel. If equal to 0 or greater than
530+
the number of CTAs on device, it will be set to the number of CTAs on device.
531+
532+
533+
Note
534+
----
535+
The :meth:`plan` method should be called before any :meth:`run`.
536+
537+
The :meth:`plan` method cannot be used in Cuda Graph or in ``torch.compile``.
538+
"""
539+
if num_ctas < 0:
540+
raise ValueError("Num_ctas should be greater than or equal to 0.")
541+
542+
self._plan_info = get_gemm_module().plan(
543+
num_ctas,
544+
)
545+
515546
def run(
516547
self,
517548
x: torch.Tensor,
@@ -629,6 +660,7 @@ def run(
629660
y, # for torch compile mutates_args
630661
empty_x_data, # for kernel type dispatch
631662
weight_column_major,
663+
self._plan_info,
632664
)
633665
case "sm80":
634666
(
@@ -660,6 +692,7 @@ def run(
660692
y,
661693
empty_x_data,
662694
weight_column_major,
695+
self._plan_info,
663696
)
664697
case _:
665698
raise ValueError(f"Unsupported gemm backend: {backend}")

flashinfer/jit/core.py

+1
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ def load_cuda_ops(
9393
"--threads",
9494
"4",
9595
"-use_fast_math",
96+
"-DFLASHINFER_ENABLE_F16",
9697
"-DFLASHINFER_ENABLE_BF16",
9798
"-DFLASHINFER_ENABLE_FP8_E4M3",
9899
"-DFLASHINFER_ENABLE_FP8_E5M2",

include/flashinfer/gemm/group_gemm.cuh

+5-5
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ template <typename DType>
3838
cudaError_t CutlassSegmentGEMMRun(void* workspace_buffer, size_t workspace_buffer_size_in_bytes,
3939
void* all_problems, unsigned int batch_size, void* x, void* w,
4040
void* y, void* x_ld, void* w_ld, void* y_ld,
41-
bool weight_column_major, cudaStream_t stream) {
41+
bool weight_column_major, int num_ctas, cudaStream_t stream) {
4242
using cutlass::epilogue::thread::LinearCombination;
4343
using cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle;
4444
DISPATCH_WEIGHT_LAYOUT(weight_column_major, WEIGHT_LAYOUT, {
@@ -69,10 +69,10 @@ cudaError_t CutlassSegmentGEMMRun(void* workspace_buffer, size_t workspace_buffe
6969
using GemmGrouped = cutlass::gemm::device::GemmGrouped<GemmKernel>;
7070
typename GemmGrouped::Arguments args(
7171
reinterpret_cast<cutlass::gemm::GemmCoord*>(all_problems), (int)batch_size,
72-
/*threadblock_count=*/4, epilogue_op, static_cast<DType**>(x), static_cast<DType**>(w),
73-
static_cast<DType**>(y), static_cast<DType**>(y), reinterpret_cast<int64_t*>(x_ld),
74-
reinterpret_cast<int64_t*>(w_ld), reinterpret_cast<int64_t*>(y_ld),
75-
reinterpret_cast<int64_t*>(y_ld));
72+
/*threadblock_count=*/num_ctas, epilogue_op, static_cast<DType**>(x),
73+
static_cast<DType**>(w), static_cast<DType**>(y), static_cast<DType**>(y),
74+
reinterpret_cast<int64_t*>(x_ld), reinterpret_cast<int64_t*>(w_ld),
75+
reinterpret_cast<int64_t*>(y_ld), reinterpret_cast<int64_t*>(y_ld));
7676

7777
GemmGrouped gemm;
7878
auto status = gemm.initialize(args, nullptr, stream);

include/flashinfer/gemm/group_gemm_sm90.cuh

+2-3
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ cudaError_t CutlassSegmentGEMMSM90Run(void* float_buffer, size_t float_buffer_si
5353
void* int_buffer, size_t int_buffer_size_in_bytes,
5454
void* all_problems, unsigned int batch_size, void* x, void* w,
5555
void* y, void* x_stride, void* w_stride, void* y_stride,
56-
bool weight_column_major, cudaStream_t stream) {
56+
bool weight_column_major, int num_ctas, cudaStream_t stream) {
5757
auto compute_capacity = GetCudaComputeCapability();
5858
if (compute_capacity.first < 9) {
5959
std::cerr << "CutlassSegmentGEMMSM90Run requires compute capability of at least 9.0"
@@ -121,8 +121,7 @@ cudaError_t CutlassSegmentGEMMSM90Run(void* float_buffer, size_t float_buffer_si
121121

122122
cutlass::KernelHardwareInfo hw_info;
123123
cudaGetDevice(&hw_info.device_id);
124-
hw_info.sm_count =
125-
cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
124+
hw_info.sm_count = num_ctas;
126125

127126
typename Gemm::EpilogueOutputOp::Params params;
128127
params =

include/flashinfer/gemm/scheduler.cuh

+64
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
/*
2+
* Copyright (c) 2025 by FlashInfer team.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
#ifndef FLASHINFER_GEMM_SCHEDULER_CUH_
17+
#define FLASHINFER_GEMM_SCHEDULER_CUH_
18+
19+
#include <cuda_runtime_api.h>
20+
21+
#include <algorithm>
22+
#include <cstddef>
23+
#include <cstdint>
24+
#include <sstream>
25+
#include <vector>
26+
27+
#include "../utils.cuh"
28+
29+
namespace flashinfer {
30+
31+
struct GemmPlanInfo {
32+
int64_t num_ctas;
33+
34+
GemmPlanInfo() : num_ctas(0) {}
35+
36+
// convert GemmPlanInfo to std::vector<int64_t>
37+
std::vector<int64_t> ToVector() const { return {num_ctas}; }
38+
39+
// From std::vector<int64_t> to GemmPlanInfo
40+
void FromVector(const std::vector<int64_t>& vec) {
41+
if (vec.size() != 1) {
42+
std::ostringstream err_msg;
43+
err_msg << "GemmPlanInfo::FromVector: vec.size() should be 1, but got " << vec.size();
44+
FLASHINFER_ERROR(err_msg.str());
45+
}
46+
num_ctas = vec[0];
47+
}
48+
};
49+
50+
inline cudaError_t GemmPlan(uint32_t num_ctas, GemmPlanInfo& plan_info) {
51+
int dev_id = 0;
52+
int num_sms = 0;
53+
FLASHINFER_CUDA_CALL(cudaGetDevice(&dev_id));
54+
FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute(&num_sms, cudaDevAttrMultiProcessorCount, dev_id));
55+
if (num_ctas > 0 && num_ctas < num_sms) {
56+
plan_info.num_ctas = num_ctas;
57+
} else {
58+
plan_info.num_ctas = num_sms;
59+
}
60+
return cudaSuccess;
61+
}
62+
63+
} // namespace flashinfer
64+
#endif // FLASHINFER_GEMM_SCHEDULER_CUH_

tests/test_group_gemm.py

+3
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
@pytest.mark.parametrize("dtype", DTYPES)
3434
@pytest.mark.parametrize("device", CUDA_DEVICES)
3535
@pytest.mark.parametrize("backend", ["auto", "sm90", "sm80"])
36+
@pytest.mark.parametrize("num_ctas", [0, 4, 16, 64])
3637
def test_segment_gemm(
3738
batch_size,
3839
num_rows_per_batch,
@@ -43,6 +44,7 @@ def test_segment_gemm(
4344
dtype,
4445
device,
4546
backend,
47+
num_ctas,
4648
):
4749
if batch_size * num_rows_per_batch > 8192:
4850
pytest.skip("batch_size * num_rows_per_batch too large for test.")
@@ -64,6 +66,7 @@ def test_segment_gemm(
6466
weight = torch.randn(batch_size, d_out, d_in, dtype=dtype).to(device)
6567
else:
6668
weight = torch.randn(batch_size, d_in, d_out, dtype=dtype).to(device)
69+
segment_gemm.plan(num_ctas)
6770
y = segment_gemm.run(
6871
x,
6972
weight,

0 commit comments

Comments
 (0)