 """
 
 from types import SimpleNamespace
-from typing import Optional
+from typing import List, Optional
 
 import torch
 import triton
@@ -104,6 +104,7 @@ def cutlass_segment_gemm(
         y: torch.Tensor,
         empty_x_data: torch.Tensor,
         weight_column_major: bool,
+        plan_info_vec: List[int],
     ) -> None:
         with x_data.device as device:
             module.cutlass_segment_gemm(
@@ -117,6 +118,7 @@ def cutlass_segment_gemm(
                 y_ld,
                 empty_x_data,
                 weight_column_major,
+                plan_info_vec,
                 get_cuda_stream(device),
             )
 
@@ -139,6 +141,7 @@ def _fake_cutlass_segment_gemm(
     # Register the module
     _gemm_module = SimpleNamespace(
         bmm_fp8=bmm_fp8,
+        plan=module.cutlass_segment_gemm_plan,
         cutlass_segment_gemm=cutlass_segment_gemm,
     )
 
@@ -181,6 +184,7 @@ def cutlass_segment_gemm_sm90(
         y: torch.Tensor,
         empty_x_data: torch.Tensor,
         weight_column_major: bool,
+        plan_info_vec: List[int],
     ) -> None:
         with x_data.device as device:
             module.cutlass_segment_gemm_sm90(
@@ -195,6 +199,7 @@ def cutlass_segment_gemm_sm90(
                 y_stride,
                 empty_x_data,
                 weight_column_major,
+                plan_info_vec,
                 get_cuda_stream(device),
             )
 
@@ -212,6 +217,7 @@ def _fake_cutlass_segment_gemm_sm90(
         y: torch.Tensor,
         empty_x_data: torch.Tensor,
         weight_column_major: bool,
+        plan_info_vec: List[int],
     ) -> None:
         pass
 
@@ -444,6 +450,8 @@ class SegmentGEMMWrapper:
     >>> x = torch.randn(10, 128, device="cuda", dtype=torch.float16)
     >>> # create weight tensor with 4 weights, each with 128 input and 256 output channels, column major
     >>> weights = torch.randn(4, 256, 128, device="cuda", dtype=torch.float16)
+    >>> # set the number of CTAs to 64
+    >>> segment_gemm.plan(64)
     >>> # compute the segment GEMM
     >>> y = segment_gemm.run(x, weights, 4, True, seg_lens=seq_lens)
     >>> y.shape
@@ -512,6 +520,29 @@ def reset_workspace_buffer(
         self._float_workspace_buffer = float_workspace_buffer
         self._int_workspace_buffer = int_workspace_buffer
 
+    def plan(self, num_ctas: int = 0) -> None:
+        r"""Plan the segment GEMM kernel launch for the given number of CTAs.
+
+        Parameters
+        ----------
+        num_ctas: int
+            The number of CTAs to use when launching the GEMM kernel. If set to 0, or to a value
+            greater than the number of CTAs available on the device, the device's CTA count is used.
+
+
+        Note
+        ----
+        The :meth:`plan` method should be called before any :meth:`run`.
+
+        The :meth:`plan` method cannot be used inside CUDA Graph capture or ``torch.compile``.
+        """
+        if num_ctas < 0:
+            raise ValueError("num_ctas should be greater than or equal to 0.")
+
+        self._plan_info = get_gemm_module().plan(
+            num_ctas,
+        )
+
     def run(
         self,
         x: torch.Tensor,
@@ -629,6 +660,7 @@ def run(
                     y,  # for torch compile mutates_args
                     empty_x_data,  # for kernel type dispatch
                     weight_column_major,
+                    self._plan_info,
                 )
             case "sm80":
                 (
@@ -660,6 +692,7 @@ def run(
                     y,
                     empty_x_data,
                     weight_column_major,
+                    self._plan_info,
                 )
             case _:
                 raise ValueError(f"Unsupported gemm backend: {backend}")
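
For reference, a minimal end-to-end sketch of how the new plan() step fits into the run() flow, assembled from the class docstring shown in this diff. The import path, workspace-buffer size, and the printed output shape are assumptions for illustration and are not part of the diff itself.

import torch
import flashinfer  # assumed top-level import; SegmentGEMMWrapper location may differ per version

# Workspace buffer size is an assumption; see the SegmentGEMMWrapper docstring for guidance.
workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")
segment_gemm = flashinfer.SegmentGEMMWrapper(workspace)

# Same shapes as the docstring example: 4 segments summing to 10 rows,
# 128 input channels, 256 output channels, column-major weights.
seq_lens = torch.tensor([1, 2, 3, 4], dtype=torch.int64, device="cuda")
x = torch.randn(10, 128, device="cuda", dtype=torch.float16)
weights = torch.randn(4, 256, 128, device="cuda", dtype=torch.float16)

# New in this change: plan() must be called before run(). Here 64 caps the CTA
# count; 0 (the default) would fall back to the number of CTAs on the device.
segment_gemm.plan(64)

y = segment_gemm.run(x, weights, 4, True, seg_lens=seq_lens)
print(y.shape)  # expected: torch.Size([10, 256])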