@@ -157,6 +157,29 @@ void layernorm_backward(Context& ctx, float* dinp, float* dweight, float* dbias,
157
157
toCPU (ctx, dbias_t , dbias, c * sizeof (float ));
158
158
}
159
159
160
+ static constexpr size_t MATMUL_BT = 64 ;
161
+ static constexpr size_t MATMUL_BC = 8 ;
162
+ static constexpr size_t MATMUL_BOC = 64 ;
163
+ static constexpr size_t MATMUL_TT = MATMUL_BT / MATMUL_BC;
164
+ static constexpr size_t MATMUL_TOC = MATMUL_BOC / MATMUL_BC;
165
+ static size_t MATMUL_num_threads = MATMUL_BT * MATMUL_BOC / (MATMUL_TT * MATMUL_TOC);
166
+ static Shape MATMUL_wgSize = {MATMUL_num_threads, 1 , 1 };
167
+ static std::string kShaderMatmul2DTiling_ (kShaderMatmul2DTiling );
168
+ static std::string kShaderMatmul2D (loopUnrolling(
169
+ replaceAll (kShaderMatmul2DTiling_ ,
170
+ {{" {{precision}}" , toString (kf32)},
171
+ {" {{BT}}" , toString (MATMUL_BT)},
172
+ {" {{BC}}" , toString (MATMUL_BC)},
173
+ {" {{BOC}}" , toString (MATMUL_BOC)},
174
+ {" {{TT}}" , toString (MATMUL_TT)},
175
+ {" {{TOC}}" , toString (MATMUL_TOC)},
176
+ {" {{NUM_TILEI}}" , toString (MATMUL_BT * MATMUL_BC / MATMUL_num_threads)},
177
+ {" {{NUM_TILEW}}" , toString (MATMUL_BOC * MATMUL_BC / MATMUL_num_threads)}
178
+ })
179
+ )
180
+ );
181
+
182
+
160
183
void matmul_forward (Context& ctx, float * out,
161
184
const float * inp, const float * weight, const float * bias,
162
185
int B, int T, int C, int OC){
@@ -181,27 +204,8 @@ void matmul_forward(Context& ctx, float* out,
181
204
assert ( (b*t) % 256 == 0 );
182
205
int version = 1 ;
183
206
if (version == 1 ){
184
- static constexpr size_t BT = 64 ;
185
- static constexpr size_t BC = 8 ;
186
- static constexpr size_t BOC = 64 ;
187
- static constexpr size_t TT = BT / BC;
188
- static constexpr size_t TOC = BOC / BC;
189
- size_t num_threads = BT * BOC / (TT * TOC);
190
- Shape wgSize = {num_threads, 1 , 1 }; // This is the same as BK * BK.
191
- Shape nWorkgroups = {b, cdiv (T, BT), cdiv (OC, BOC)};
192
-
193
- std::string codeString (kShaderMatmul2DTiling );
194
- replaceAll (codeString, {{" {{precision}}" , toString (kf32)},
195
- {" {{BT}}" , toString (BT)},
196
- {" {{BC}}" , toString (BC)},
197
- {" {{BOC}}" , toString (BOC)},
198
- {" {{TT}}" , toString (TT)},
199
- {" {{TOC}}" , toString (TOC)},
200
- {" {{NUM_TILEI}}" , toString (BT * BC / num_threads)},
201
- {" {{NUM_TILEW}}" , toString (BOC * BC / num_threads)}
202
- });
203
- std::string unrolledCode = loopUnrolling (codeString);
204
- Kernel op = createKernel (ctx, {unrolledCode, wgSize, kf32},
207
+ Shape nWorkgroups = {b, cdiv (T, MATMUL_BT), cdiv (OC, MATMUL_BOC)};
208
+ Kernel op = createKernel (ctx, {kShaderMatmul2D , MATMUL_wgSize, kf32},
205
209
Bindings{inp_i, weight_i, bias_i, out_o},
206
210
nWorkgroups,
207
211
/* params */
@@ -213,7 +217,6 @@ void matmul_forward(Context& ctx, float* out,
213
217
});
214
218
dispatchKernel (ctx, op, promise);
215
219
wait (ctx, future);
216
- toCPU (ctx, out_o, out, b * t * oc * sizeof (float ));
217
220
} else {
218
221
Kernel op = createKernel (ctx, {kShaderMatmul , 256 , kf32},
219
222
Bindings{inp_i, weight_i, bias_i, out_o},
0 commit comments