Skip to content

Commit f4e1683

Browse files
authored
Merge pull request #68 from junjihashimoto/feature/cache
Implement kernel cache
2 parents 3c06137 + 30ed026 commit f4e1683

File tree

5 files changed

+1357
-426
lines changed

5 files changed

+1357
-426
lines changed

experimental/kernels/Makefile

+3-3
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ endef
7979
build/test_gpt2: llm.c build/unittest_kernels.o gpt2_124M.bin
8080
mkdir -p build
8181
$(call preprocess_file)
82-
$(CC) $(CFLAGS) $(LDFLAGS) -o $@ llm.c/test_gpt2.c build/unittest_kernels.o
82+
$(CC) $(CFLAGS) $(LDFLAGS) -o $@ llm.c/test_gpt2.c build/unittest_kernels.o -g
8383

8484
build/test_gpt2_with_metal_profiler: llm.c build/unittest_kernels.o gpt2_124M.bin
8585
mkdir -p build
@@ -90,12 +90,12 @@ build/test_gpt2_with_metal_profiler: llm.c build/unittest_kernels.o gpt2_124M.bi
9090
build/train_gpt2: llm.c build/unittest_kernels.o gpt2_124M.bin
9191
mkdir -p build
9292
$(call preprocess_file)
93-
$(CC) $(CFLAGS) $(LDFLAGS) -o $@ llm.c/train_gpt2.c build/unittest_kernels.o
93+
$(CC) $(CFLAGS) $(LDFLAGS) -o $@ llm.c/train_gpt2.c build/unittest_kernels.o -g
9494

9595
# Compile ops.cpp (the first prerequisite, i.e. $<) into build/ops.o,
# creating build/ first. ops.hpp, kernels.h and llm.c are listed only as
# prerequisites; they are not passed to the compiler in the recipe.
build/ops.o: ops.cpp ops.hpp kernels.h llm.c
9696
mkdir -p build && $(CXX) $(CXXFLAGS) -c -o $@ $<
9797

98-
build/gpt2_webgpu: llm.c gpt2_124M.bin llm.c
98+
# Build the gpt2_webgpu binary in one step from gpt2_webgpu.cpp and ops.cpp.
# Both sources are listed as prerequisites, so editing either one makes the
# target out of date.
# NOTE(review): llm.c appears twice in the prerequisite list; the second
# occurrence looks redundant.
# NOTE(review): the recipe hands $(CXXFLAGS) and .cpp sources to $(CC), while
# the build/ops.o rule uses $(CXX) -- confirm $(CC) is a C++-capable driver.
build/gpt2_webgpu: llm.c gpt2_124M.bin llm.c gpt2_webgpu.cpp ops.cpp
9999
mkdir -p build
100100
$(CC) $(CXXFLAGS) -Illm.c $(LDFLAGS) -o $@ gpt2_webgpu.cpp ops.cpp
101101

experimental/kernels/kernels.h

+98
Original file line numberDiff line numberDiff line change
@@ -309,6 +309,104 @@ fn main(@builtin(global_invocation_id) global_id : vec3<u32>) {
309309
}
310310
}
311311
}
312+
313+
)";
314+
315+
316+
// WGSL template for a tiled matmul: out = inp * weight^T (+ bias).
//
// The indexing in the shader fixes the layouts: inp is read as (B*T) rows of
// C elements, weight as OC rows of C elements, and out is written as (B*T)
// rows of OC elements. B, T, C and OC come from the uniform `params`.
//
// Work split, as written in the shader:
//   - workgroup_id.x selects the batch, .y a BT-row tile along T, and .z a
//     BOC-column tile along OC.
//   - Each invocation accumulates a TT x TOC patch of that tile in
//     threadResults; threadRow/threadCol are derived from
//     local_invocation_id.x.
//   - C is walked in steps of BC. Each step stages a BT x BC slice of inp and
//     a BOC x BC slice of weight in var<workgroup> arrays, with a
//     workgroupBarrier() after the loads and another after the compute loop,
//     before the next step overwrites the tiles.
//
// Bias: a bias buffer whose arrayLength is exactly 1 skips the bias add; any
// other length adds bias[outOC * BOC + column] to each result.
//
// threadResults has no explicit initializer; the accumulation relies on WGSL
// zero-initializing function-scope variables.
//
// The in-shader comments mention "a", "b" and "c"; going by the pointers they
// describe, these are inp, weight and out respectively.
//
// NOTE(review): the {{...}} placeholders (precision, workgroupSize, BT, BC,
// BOC, TT, TOC, NUM_TILEW, NUM_TILEI) are presumably substituted by the host
// before compilation -- the substitution code is not visible here.
// NOTE(review): no index is bounds-checked, so this looks like it assumes T,
// OC and C are multiples of BT, BOC and BC, and that numThread * NUM_TILEW
// and numThread * NUM_TILEI cover the weight and input tiles exactly --
// confirm at the call site.
static const char *kShaderMatmul2DTiling = R"(
317+
@group(0) @binding(0) var<storage, read_write> inp : array<{{precision}}>;
318+
@group(0) @binding(1) var<storage, read_write> weight : array<{{precision}}>;
319+
@group(0) @binding(2) var<storage, read_write> bias : array<{{precision}}>;
320+
@group(0) @binding(3) var<storage, read_write> out : array<{{precision}}>;
321+
@group(0) @binding(4) var<uniform> params : Params;
322+
struct Params {
323+
B: u32,
324+
T: u32,
325+
C: u32,
326+
OC: u32,
327+
};
328+
var<workgroup> tileInp: array<{{precision}}, {{BT}} * {{BC}}>;
329+
var<workgroup> tileWeight: array<{{precision}}, {{BOC}} * {{BC}}>;
330+
331+
@compute @workgroup_size({{workgroupSize}})
332+
fn main(
333+
@builtin(local_invocation_id) localID : vec3<u32>,
334+
@builtin(workgroup_id) groupid : vec3<u32>) {
335+
let B : u32 = params.B;
336+
let T : u32 = params.T;
337+
let C : u32 = params.C;
338+
let OC : u32 = params.OC;
339+
340+
var localT: array<{{precision}}, {{TT}}>;
341+
var localOC: array<{{precision}}, {{TOC}}>;
342+
343+
let outB: u32 = groupid.x;
344+
let outT: u32 = groupid.y;
345+
let outOC: u32 = groupid.z;
346+
let numThread: u32 = ({{BT}} * {{BOC}}) / ({{TT}} * {{TOC}});
347+
348+
// position of the first c element computed by the thread
349+
let threadRow: u32 = (localID.x / ({{BOC}} / {{TOC}})) * {{TT}};
350+
let threadCol: u32 = (localID.x % ({{BOC}} / {{TOC}})) * {{TOC}};
351+
352+
// inpPtr and weightPtr are the starting positions of the tiles in a and b,
353+
// incremented in the bkidx loop.
354+
// outPtr is the starting position of the tile in c which is fixed.
355+
356+
var inpPtr = (outB * T + outT * {{BT}}) * C; // BTC
357+
var weightPtr = outOC * {{BOC}} * C; //OCC
358+
var threadResults: array<{{precision}}, {{TT}} * {{TOC}}>;
359+
let outPtr = (outB * T + outT * {{BT}}) * OC + outOC * {{BOC}}; //BTOC
360+
let biasPtr = outOC * {{BOC}};
361+
362+
for (var bkidx: u32 = 0; bkidx < C; bkidx += {{BC}}) {
363+
// Load BC x BOC by numThread(BT * BOC / (TT * TOC))
364+
// The number of iteration == BC * BOC / (BT * BOC / (TT * TOC))
365+
for (var idx: u32 = 0; idx < {{NUM_TILEW}}; idx++) {
366+
tileWeight[localID.x + idx * numThread] = weight[weightPtr + ((localID.x + idx * numThread) / {{BC}}) * C + ((localID.x + idx * numThread) % {{BC}})];
367+
}
368+
weightPtr += {{BC}};
369+
370+
// Load tile
371+
// Load BT x BC by numThread(BT * BOC / (TT * TOC))
372+
// The number of iteration == BT * BC / (BT * BOC / (TT * TOC))
373+
for (var idx: u32 = 0; idx < {{NUM_TILEI}}; idx++) {
374+
tileInp[localID.x + idx * numThread] = inp[inpPtr + ((localID.x + idx * numThread) / {{BC}}) * C + (localID.x + idx * numThread) % {{BC}}];
375+
}
376+
inpPtr += {{BC}};
377+
378+
workgroupBarrier();
379+
// Compute tile
380+
for (var dotIdx: u32 = 0; dotIdx < {{BC}}; dotIdx = dotIdx + 1) {
381+
for (var idx: u32 = 0; idx < {{TT}}; idx++) {
382+
localT[idx] = tileInp[(threadRow + idx) * {{BC}} + dotIdx];
383+
}
384+
for (var idx: u32 = 0; idx < {{TOC}}; idx++) {
385+
localOC[idx] = tileWeight[(threadCol + idx) * {{BC}} + dotIdx];
386+
}
387+
for (var resIdxT: u32 = 0; resIdxT < {{TT}}; resIdxT++) {
388+
for (var resIdxOC: u32 = 0; resIdxOC < {{TOC}}; resIdxOC++) {
389+
threadResults[resIdxT * {{TOC}} + resIdxOC] += localT[resIdxT] * localOC[resIdxOC];
390+
}
391+
}
392+
}
393+
workgroupBarrier();
394+
}
395+
396+
if (arrayLength(&bias) == 1) {
397+
for (var resIdxT: u32 = 0; resIdxT < {{TT}}; resIdxT++) {
398+
for (var resIdxOC: u32 = 0; resIdxOC < {{TOC}}; resIdxOC++) {
399+
out[outPtr + (threadRow + resIdxT) * OC + threadCol + resIdxOC] = threadResults[resIdxT * {{TOC}} + resIdxOC];
400+
}
401+
}
402+
} else {
403+
for (var resIdxT: u32 = 0; resIdxT < {{TT}}; resIdxT++) {
404+
for (var resIdxOC: u32 = 0; resIdxOC < {{TOC}}; resIdxOC++) {
405+
out[outPtr + (threadRow + resIdxT) * OC + threadCol + resIdxOC] = threadResults[resIdxT * {{TOC}} + resIdxOC] + bias[biasPtr + threadCol + resIdxOC];
406+
}
407+
}
408+
}
409+
}
312410
)";
313411

314412
static const char *kShaderMatmulBackward = R"(

0 commit comments

Comments
 (0)