Implement kernel cache #68

Merged 13 commits on Oct 20, 2024
6 changes: 3 additions & 3 deletions experimental/kernels/Makefile
@@ -79,7 +79,7 @@ endef
build/test_gpt2: llm.c build/unittest_kernels.o gpt2_124M.bin
	mkdir -p build
	$(call preprocess_file)
-	$(CC) $(CFLAGS) $(LDFLAGS) -o $@ llm.c/test_gpt2.c build/unittest_kernels.o
+	$(CC) $(CFLAGS) $(LDFLAGS) -o $@ llm.c/test_gpt2.c build/unittest_kernels.o -g

build/test_gpt2_with_metal_profiler: llm.c build/unittest_kernels.o gpt2_124M.bin
	mkdir -p build

@@ -90,12 +90,12 @@ build/test_gpt2_with_metal_profiler: llm.c build/unittest_kernels.o gpt2_124M.bin
build/train_gpt2: llm.c build/unittest_kernels.o gpt2_124M.bin
	mkdir -p build
	$(call preprocess_file)
-	$(CC) $(CFLAGS) $(LDFLAGS) -o $@ llm.c/train_gpt2.c build/unittest_kernels.o
+	$(CC) $(CFLAGS) $(LDFLAGS) -o $@ llm.c/train_gpt2.c build/unittest_kernels.o -g

build/ops.o: ops.cpp ops.hpp kernels.h llm.c
	mkdir -p build && $(CXX) $(CXXFLAGS) -c -o $@ $<

-build/gpt2_webgpu: llm.c gpt2_124M.bin llm.c
+build/gpt2_webgpu: llm.c gpt2_124M.bin llm.c gpt2_webgpu.cpp ops.cpp
	mkdir -p build
	$(CC) $(CXXFLAGS) -Illm.c $(LDFLAGS) -o $@ gpt2_webgpu.cpp ops.cpp

98 changes: 98 additions & 0 deletions experimental/kernels/kernels.h
@@ -309,6 +309,104 @@ fn main(@builtin(global_invocation_id) global_id : vec3<u32>) {
}
}
}

)";


static const char *kShaderMatmul2DTiling = R"(
@group(0) @binding(0) var<storage, read_write> inp : array<{{precision}}>;
@group(0) @binding(1) var<storage, read_write> weight : array<{{precision}}>;
@group(0) @binding(2) var<storage, read_write> bias : array<{{precision}}>;
@group(0) @binding(3) var<storage, read_write> out : array<{{precision}}>;
@group(0) @binding(4) var<uniform> params : Params;
struct Params {
B: u32,
T: u32,
C: u32,
OC: u32,
};
var<workgroup> tileInp: array<{{precision}}, {{BT}} * {{BC}}>;
var<workgroup> tileWeight: array<{{precision}}, {{BOC}} * {{BC}}>;

@compute @workgroup_size({{workgroupSize}})
fn main(
@builtin(local_invocation_id) localID : vec3<u32>,
@builtin(workgroup_id) groupid : vec3<u32>) {
let B : u32 = params.B;
let T : u32 = params.T;
let C : u32 = params.C;
let OC : u32 = params.OC;

var localT: array<{{precision}}, {{TT}}>;
var localOC: array<{{precision}}, {{TOC}}>;

let outB: u32 = groupid.x;
let outT: u32 = groupid.y;
let outOC: u32 = groupid.z;
let numThread: u32 = ({{BT}} * {{BOC}}) / ({{TT}} * {{TOC}});

// Position of the first output element computed by this thread.
let threadRow: u32 = (localID.x / ({{BOC}} / {{TOC}})) * {{TT}};
let threadCol: u32 = (localID.x % ({{BOC}} / {{TOC}})) * {{TOC}};
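// Illustrative sizing (assumed values, not fixed by this shader): with BOC = 64
// and TT = TOC = 8, thread 9 gets threadRow = (9 / 8) * 8 = 8 and
// threadCol = (9 % 8) * 8 = 8, i.e. it owns the 8 x 8 block of outputs starting
// at row 8, column 8 of this workgroup's tile.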

// inpPtr and weightPtr are the starting positions of the tiles in inp and weight,
// advanced on each iteration of the bkidx loop.
// outPtr is the starting position of the tile in out, which is fixed.

var inpPtr = (outB * T + outT * {{BT}}) * C; // BTC
var weightPtr = outOC * {{BOC}} * C; //OCC
var threadResults: array<{{precision}}, {{TT}} * {{TOC}}>;
let outPtr = (outB * T + outT * {{BT}}) * OC + outOC * {{BOC}}; //BTOC
let biasPtr = outOC * {{BOC}};

for (var bkidx: u32 = 0; bkidx < C; bkidx += {{BC}}) {
// Load the BOC x BC weight tile using numThread (= BT * BOC / (TT * TOC)) threads.
// The number of iterations == BC * BOC / (BT * BOC / (TT * TOC))
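// Illustrative: with BT = BOC = 64, BC = 16, TT = TOC = 8 (assumed values),
// numThread = (64 * 64) / (8 * 8) = 64 and this loop runs (16 * 64) / 64 = 16 times.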
for (var idx: u32 = 0; idx < {{NUM_TILEW}}; idx++) {
tileWeight[localID.x + idx * numThread] = weight[weightPtr + ((localID.x + idx * numThread) / {{BC}}) * C + ((localID.x + idx * numThread) % {{BC}})];
}
weightPtr += {{BC}};

// Load tile
// Load the BT x BC input tile using numThread (= BT * BOC / (TT * TOC)) threads.
// The number of iterations == BT * BC / (BT * BOC / (TT * TOC))
for (var idx: u32 = 0; idx < {{NUM_TILEI}}; idx++) {
tileInp[localID.x + idx * numThread] = inp[inpPtr + ((localID.x + idx * numThread) / {{BC}}) * C + (localID.x + idx * numThread) % {{BC}}];
}
inpPtr += {{BC}};

workgroupBarrier();
// Compute tile
for (var dotIdx: u32 = 0; dotIdx < {{BC}}; dotIdx = dotIdx + 1) {
for (var idx: u32 = 0; idx < {{TT}}; idx++) {
localT[idx] = tileInp[(threadRow + idx) * {{BC}} + dotIdx];
}
for (var idx: u32 = 0; idx < {{TOC}}; idx++) {
localOC[idx] = tileWeight[(threadCol + idx) * {{BC}} + dotIdx];
}
for (var resIdxT: u32 = 0; resIdxT < {{TT}}; resIdxT++) {
for (var resIdxOC: u32 = 0; resIdxOC < {{TOC}}; resIdxOC++) {
threadResults[resIdxT * {{TOC}} + resIdxOC] += localT[resIdxT] * localOC[resIdxOC];
}
}
}
workgroupBarrier();
}

if (arrayLength(&bias) == 1) {
for (var resIdxT: u32 = 0; resIdxT < {{TT}}; resIdxT++) {
for (var resIdxOC: u32 = 0; resIdxOC < {{TOC}}; resIdxOC++) {
out[outPtr + (threadRow + resIdxT) * OC + threadCol + resIdxOC] = threadResults[resIdxT * {{TOC}} + resIdxOC];
}
}
} else {
for (var resIdxT: u32 = 0; resIdxT < {{TT}}; resIdxT++) {
for (var resIdxOC: u32 = 0; resIdxOC < {{TOC}}; resIdxOC++) {
out[outPtr + (threadRow + resIdxT) * OC + threadCol + resIdxOC] = threadResults[resIdxT * {{TOC}} + resIdxOC] + bias[biasPtr + threadCol + resIdxOC];
}
}
}
}
)";

static const char *kShaderMatmulBackward = R"(