@@ -49,8 +49,19 @@ def consumer_advance(self):
4949 def consumer_release_barrier (self ) -> RegisterTensor :
5050 return self .full_barriers [self .consumer_stage ]
5151
52+
5253class BlockInfo (tilus .Class ):
53- def __init__ (self , m_size : int32 , n_size : int , k_size : int , block_m : int , block_n : int , block_k : int , offset_m : int32 , offset_n : int32 ):
54+ def __init__ (
55+ self ,
56+ m_size : int32 ,
57+ n_size : int ,
58+ k_size : int ,
59+ block_m : int ,
60+ block_n : int ,
61+ block_k : int ,
62+ offset_m : int32 ,
63+ offset_n : int32 ,
64+ ):
5465 self .m_size : int32 = m_size
5566 self .n_size : int = n_size
5667 self .k_size : int = k_size
@@ -60,6 +71,7 @@ def __init__(self, m_size: int32, n_size: int, k_size: int, block_m: int, block_
6071 self .offset_m : int32 = offset_m
6172 self .offset_n : int32 = offset_n
6273
74+
6375class LoadPipeline (Pipeline ):
6476 def __init__ (
6577 self ,
@@ -70,12 +82,18 @@ def __init__(
7082 num_stages = num_stages , producer_arrive_count = 2 , consumer_arrive_count = 1
7183 )
7284 self .info : BlockInfo = info
73- self .s_a = self .shared_tensor (dtype = float16 , shape = [num_stages , info .block_m , info .block_k ])
74- self .s_b = self .shared_tensor (dtype = float16 , shape = [num_stages , info .block_n , info .block_k ])
85+ self .s_a = self .shared_tensor (
86+ dtype = float16 , shape = [num_stages , info .block_m , info .block_k ]
87+ )
88+ self .s_b = self .shared_tensor (
89+ dtype = float16 , shape = [num_stages , info .block_n , info .block_k ]
90+ )
7591
7692
7793class LoadWorker (tilus .Class ):
78- def __init__ (self , pipe : LoadPipeline , g_a : GlobalTensor , g_b : GlobalTensor , info : BlockInfo ):
94+ def __init__ (
95+ self , pipe : LoadPipeline , g_a : GlobalTensor , g_b : GlobalTensor , info : BlockInfo
96+ ):
7997 self .pipe : LoadPipeline = pipe
8098 self .g_a : GlobalTensor = g_a
8199 self .g_b : GlobalTensor = g_b
@@ -113,15 +131,18 @@ class MmaWorker(tilus.Class):
113131 def __init__ (self , pipe : LoadPipeline , info : BlockInfo ):
114132 self .pipe : LoadPipeline = pipe
115133 self .info : BlockInfo = info
116- self .t_acc = self .tcgen05 .alloc (dtype = float32 , shape = [info .block_m , info .block_n ], init = 0.0 )
117-
134+ self .t_acc = self .tcgen05 .alloc (
135+ dtype = float32 , shape = [info .block_m , info .block_n ], init = 0.0
136+ )
118137
119138 def async_run (self ):
120139 pipe = self .pipe
121140 s_a , s_b = pipe .s_a , pipe .s_b
122141 num_stages : int = pipe .num_stages
123142 with self .thread_group (thread_begin = 32 , num_threads = 32 ):
124- for offset_k in self .range (0 , self .info .k_size , self .info .block_k , unroll = num_stages ):
143+ for offset_k in self .range (
144+ 0 , self .info .k_size , self .info .block_k , unroll = num_stages
145+ ):
125146 pipe .consumer_acquire ()
126147 with self .single_thread ():
127148 self .tcgen05 .mma (
0 commit comments