Skip to content

Commit c0272ba

Browse files
committed
format & lint
Signed-off-by: Yaoyao Ding <[email protected]>
1 parent b24fc51 commit c0272ba

File tree

1 file changed: 28 additions and 7 deletions.

examples/blackwell_matmul/matmul_v4.py

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,19 @@ def consumer_advance(self):
     def consumer_release_barrier(self) -> RegisterTensor:
         return self.full_barriers[self.consumer_stage]
 
+
 class BlockInfo(tilus.Class):
-    def __init__(self, m_size: int32, n_size: int, k_size: int, block_m: int, block_n: int, block_k: int, offset_m: int32, offset_n: int32):
+    def __init__(
+        self,
+        m_size: int32,
+        n_size: int,
+        k_size: int,
+        block_m: int,
+        block_n: int,
+        block_k: int,
+        offset_m: int32,
+        offset_n: int32,
+    ):
         self.m_size: int32 = m_size
         self.n_size: int = n_size
         self.k_size: int = k_size
@@ -60,6 +71,7 @@ def __init__(self, m_size: int32, n_size: int, k_size: int, block_m: int, block_
         self.offset_m: int32 = offset_m
         self.offset_n: int32 = offset_n
 
+
 class LoadPipeline(Pipeline):
     def __init__(
         self,
@@ -70,12 +82,18 @@ def __init__(
             num_stages=num_stages, producer_arrive_count=2, consumer_arrive_count=1
         )
         self.info: BlockInfo = info
-        self.s_a = self.shared_tensor(dtype=float16, shape=[num_stages, info.block_m, info.block_k])
-        self.s_b = self.shared_tensor(dtype=float16, shape=[num_stages, info.block_n, info.block_k])
+        self.s_a = self.shared_tensor(
+            dtype=float16, shape=[num_stages, info.block_m, info.block_k]
+        )
+        self.s_b = self.shared_tensor(
+            dtype=float16, shape=[num_stages, info.block_n, info.block_k]
+        )
 
 
 class LoadWorker(tilus.Class):
-    def __init__(self, pipe: LoadPipeline, g_a: GlobalTensor, g_b: GlobalTensor, info: BlockInfo):
+    def __init__(
+        self, pipe: LoadPipeline, g_a: GlobalTensor, g_b: GlobalTensor, info: BlockInfo
+    ):
         self.pipe: LoadPipeline = pipe
         self.g_a: GlobalTensor = g_a
         self.g_b: GlobalTensor = g_b
@@ -113,15 +131,18 @@ class MmaWorker(tilus.Class):
     def __init__(self, pipe: LoadPipeline, info: BlockInfo):
         self.pipe: LoadPipeline = pipe
         self.info: BlockInfo = info
-        self.t_acc = self.tcgen05.alloc(dtype=float32, shape=[info.block_m, info.block_n], init=0.0)
-
+        self.t_acc = self.tcgen05.alloc(
+            dtype=float32, shape=[info.block_m, info.block_n], init=0.0
+        )
 
     def async_run(self):
         pipe = self.pipe
         s_a, s_b = pipe.s_a, pipe.s_b
         num_stages: int = pipe.num_stages
         with self.thread_group(thread_begin=32, num_threads=32):
-            for offset_k in self.range(0, self.info.k_size, self.info.block_k, unroll=num_stages):
+            for offset_k in self.range(
+                0, self.info.k_size, self.info.block_k, unroll=num_stages
+            ):
                 pipe.consumer_acquire()
                 with self.single_thread():
                     self.tcgen05.mma(

0 commit comments

Comments
 (0)