Skip to content

Commit 3ba36b5

Browse files
Reduce matmul-kernel creation time
1 parent 7addf83 commit 3ba36b5

File tree

3 files changed

+29
-24
lines changed

3 files changed

+29
-24
lines changed

Diff for: experimental/kernels/Makefile

+1-1
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ build/train_gpt2: llm.c build/unittest_kernels.o gpt2_124M.bin
9595
build/ops.o: ops.cpp ops.hpp kernels.h llm.c
9696
mkdir -p build && $(CXX) $(CXXFLAGS) -c -o $@ $<
9797

98-
build/gpt2_webgpu: llm.c gpt2_124M.bin llm.c
98+
build/gpt2_webgpu: llm.c gpt2_124M.bin llm.c gpt2_webgpu.cpp ops.cpp
9999
mkdir -p build
100100
$(CC) $(CXXFLAGS) -Illm.c $(LDFLAGS) -o $@ gpt2_webgpu.cpp ops.cpp
101101

Diff for: experimental/kernels/ops.cpp

+25-22
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,29 @@ void layernorm_backward(Context& ctx, float* dinp, float* dweight, float* dbias,
157157
toCPU(ctx, dbias_t, dbias, c * sizeof(float));
158158
}
159159

160+
static constexpr size_t MATMUL_BT = 64;
161+
static constexpr size_t MATMUL_BC = 8;
162+
static constexpr size_t MATMUL_BOC = 64;
163+
static constexpr size_t MATMUL_TT = MATMUL_BT / MATMUL_BC;
164+
static constexpr size_t MATMUL_TOC = MATMUL_BOC / MATMUL_BC;
165+
static size_t MATMUL_num_threads = MATMUL_BT * MATMUL_BOC / (MATMUL_TT * MATMUL_TOC);
166+
static Shape MATMUL_wgSize = {MATMUL_num_threads, 1, 1};
167+
static std::string kShaderMatmul2DTiling_(kShaderMatmul2DTiling);
168+
static std::string kShaderMatmul2D(loopUnrolling(
169+
replaceAll(kShaderMatmul2DTiling_,
170+
{{"{{precision}}", toString(kf32)},
171+
{"{{BT}}", toString(MATMUL_BT)},
172+
{"{{BC}}", toString(MATMUL_BC)},
173+
{"{{BOC}}", toString(MATMUL_BOC)},
174+
{"{{TT}}", toString(MATMUL_TT)},
175+
{"{{TOC}}", toString(MATMUL_TOC)},
176+
{"{{NUM_TILEI}}", toString(MATMUL_BT * MATMUL_BC / MATMUL_num_threads)},
177+
{"{{NUM_TILEW}}", toString(MATMUL_BOC * MATMUL_BC / MATMUL_num_threads)}
178+
})
179+
)
180+
);
181+
182+
160183
void matmul_forward(Context& ctx, float* out,
161184
const float* inp, const float* weight, const float* bias,
162185
int B, int T, int C, int OC){
@@ -181,27 +204,8 @@ void matmul_forward(Context& ctx, float* out,
181204
assert ( (b*t) % 256 == 0 );
182205
int version = 1;
183206
if (version == 1){
184-
static constexpr size_t BT = 64;
185-
static constexpr size_t BC = 8;
186-
static constexpr size_t BOC = 64;
187-
static constexpr size_t TT = BT / BC;
188-
static constexpr size_t TOC = BOC / BC;
189-
size_t num_threads = BT * BOC / (TT * TOC);
190-
Shape wgSize = {num_threads, 1, 1}; // This is the same as BK * BK.
191-
Shape nWorkgroups = {b, cdiv(T, BT), cdiv(OC, BOC)};
192-
193-
std::string codeString(kShaderMatmul2DTiling);
194-
replaceAll(codeString, {{"{{precision}}", toString(kf32)},
195-
{"{{BT}}", toString(BT)},
196-
{"{{BC}}", toString(BC)},
197-
{"{{BOC}}", toString(BOC)},
198-
{"{{TT}}", toString(TT)},
199-
{"{{TOC}}", toString(TOC)},
200-
{"{{NUM_TILEI}}", toString(BT * BC / num_threads)},
201-
{"{{NUM_TILEW}}", toString(BOC * BC / num_threads)}
202-
});
203-
std::string unrolledCode = loopUnrolling(codeString);
204-
Kernel op = createKernel(ctx, {unrolledCode, wgSize, kf32},
207+
Shape nWorkgroups = {b, cdiv(T, MATMUL_BT), cdiv(OC, MATMUL_BOC)};
208+
Kernel op = createKernel(ctx, {kShaderMatmul2D, MATMUL_wgSize, kf32},
205209
Bindings{inp_i, weight_i, bias_i, out_o},
206210
nWorkgroups,
207211
/* params */
@@ -213,7 +217,6 @@ void matmul_forward(Context& ctx, float* out,
213217
});
214218
dispatchKernel(ctx, op, promise);
215219
wait(ctx, future);
216-
toCPU(ctx, out_o, out, b * t * oc * sizeof(float));
217220
} else {
218221
Kernel op = createKernel(ctx, {kShaderMatmul, 256, kf32},
219222
Bindings{inp_i, weight_i, bias_i, out_o},

Diff for: gpu.hpp

+3-1
Original file line numberDiff line numberDiff line change
@@ -415,12 +415,14 @@ struct KernelCode {
415415
* @endcode
416416
* "f32"}});
417417
*/
418-
inline void
418+
inline const std::string
419419
replaceAll(std::string &str,
420420
const std::vector<std::pair<std::string, std::string>> &reps) {
421421
for (const auto &rep : reps) {
422422
replaceAll(str, rep.first, rep.second);
423423
}
424+
425+
return str;
424426
}
425427

426428
/**

0 commit comments

Comments
 (0)