diff --git a/experimental/kernels/Makefile b/experimental/kernels/Makefile index c233ef5..5817e23 100644 --- a/experimental/kernels/Makefile +++ b/experimental/kernels/Makefile @@ -79,7 +79,7 @@ endef build/test_gpt2: llm.c build/unittest_kernels.o gpt2_124M.bin mkdir -p build $(call preprocess_file) - $(CC) $(CFLAGS) $(LDFLAGS) -o $@ llm.c/test_gpt2.c build/unittest_kernels.o + $(CC) $(CFLAGS) $(LDFLAGS) -o $@ llm.c/test_gpt2.c build/unittest_kernels.o -g build/test_gpt2_with_metal_profiler: llm.c build/unittest_kernels.o gpt2_124M.bin mkdir -p build @@ -90,12 +90,12 @@ build/test_gpt2_with_metal_profiler: llm.c build/unittest_kernels.o gpt2_124M.bi build/train_gpt2: llm.c build/unittest_kernels.o gpt2_124M.bin mkdir -p build $(call preprocess_file) - $(CC) $(CFLAGS) $(LDFLAGS) -o $@ llm.c/train_gpt2.c build/unittest_kernels.o + $(CC) $(CFLAGS) $(LDFLAGS) -o $@ llm.c/train_gpt2.c build/unittest_kernels.o -g build/ops.o: ops.cpp ops.hpp kernels.h llm.c mkdir -p build && $(CXX) $(CXXFLAGS) -c -o $@ $< -build/gpt2_webgpu: llm.c gpt2_124M.bin llm.c +build/gpt2_webgpu: llm.c gpt2_124M.bin llm.c gpt2_webgpu.cpp ops.cpp mkdir -p build $(CC) $(CXXFLAGS) -Illm.c $(LDFLAGS) -o $@ gpt2_webgpu.cpp ops.cpp diff --git a/experimental/kernels/kernels.h b/experimental/kernels/kernels.h index 212c075..1bce081 100644 --- a/experimental/kernels/kernels.h +++ b/experimental/kernels/kernels.h @@ -309,6 +309,104 @@ fn main(@builtin(global_invocation_id) global_id : vec3<u32>) { } } } + +)"; + + +static const char *kShaderMatmul2DTiling = R"( +@group(0) @binding(0) var<storage, read_write> inp : array<{{precision}}>; +@group(0) @binding(1) var<storage, read_write> weight : array<{{precision}}>; +@group(0) @binding(2) var<storage, read_write> bias : array<{{precision}}>; +@group(0) @binding(3) var<storage, read_write> out : array<{{precision}}>; +@group(0) @binding(4) var<uniform> params : Params; +struct Params { + B: u32, + T: u32, + C: u32, + OC: u32, +}; +var<workgroup> tileInp: array<{{precision}}, {{BT}} * {{BC}}>; +var<workgroup> tileWeight: array<{{precision}}, {{BOC}} * {{BC}}>; + +@compute @workgroup_size({{workgroupSize}}) +fn main( + @builtin(local_invocation_id) localID : vec3<u32>, + @builtin(workgroup_id) groupid : vec3<u32>) { + let B : u32 = params.B; + let T : u32 = params.T; + let C : u32 = params.C; + let OC : u32 = params.OC; + + var localT: array<{{precision}}, {{TT}}>; + var localOC: array<{{precision}}, {{TOC}}>; + + let outB: u32 = groupid.x; + let outT: u32 = groupid.y; + let outOC: u32 = groupid.z; + let numThread: u32 = ({{BT}} * {{BOC}}) / ({{TT}} * {{TOC}}); + + // position of the first c element computed by the thread + let threadRow: u32 = (localID.x / ({{BOC}} / {{TOC}})) * {{TT}}; + let threadCol: u32 = (localID.x % ({{BOC}} / {{TOC}})) * {{TOC}}; + + // inpPtr and weightPtr are the starting positions of the tiles in a and b, + // incremented in the bkidx loop. + // outPtr is the starting position of the tile in c which is fixed. 
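+  // Tiling layout: each workgroup computes a {{BT}} x {{BOC}} tile of out, each thread a
+  // {{TT}} x {{TOC}} sub-tile of that tile, and the bkidx loop below walks the C dimension
+  // in steps of {{BC}}, staging slices of inp and weight in the workgroup arrays above.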
+ + var inpPtr = (outB * T + outT * {{BT}}) * C; // BTC + var weightPtr = outOC * {{BOC}} * C; //OCC + var threadResults: array<{{precision}}, {{TT}} * {{TOC}}>; + let outPtr = (outB * T + outT * {{BT}}) * OC + outOC * {{BOC}}; //BTOC + let biasPtr = outOC * {{BOC}}; + + for (var bkidx: u32 = 0; bkidx < C; bkidx += {{BC}}) { + // Load BC x BOC by numThread(BT * BOC / (TT * TOC)) + // The number of iteration == BC * BOC / (BT * BOC / (TT * TOC)) + for (var idx: u32 = 0; idx < {{NUM_TILEW}}; idx++) { + tileWeight[localID.x + idx * numThread] = weight[weightPtr + ((localID.x + idx * numThread) / {{BC}}) * C + ((localID.x + idx * numThread) % {{BC}})]; + } + weightPtr += {{BC}}; + + // Load tile + // Load BT x BC by numThread(BT * BOC / (TT * TOC)) + // The number of iteration == BT * BC / (BT * BOC / (TT * TOC)) + for (var idx: u32 = 0; idx < {{NUM_TILEI}}; idx++) { + tileInp[localID.x + idx * numThread] = inp[inpPtr + ((localID.x + idx * numThread) / {{BC}}) * C + (localID.x + idx * numThread) % {{BC}}]; + } + inpPtr += {{BC}}; + + workgroupBarrier(); + // Compute tile + for (var dotIdx: u32 = 0; dotIdx < {{BC}}; dotIdx = dotIdx + 1) { + for (var idx: u32 = 0; idx < {{TT}}; idx++) { + localT[idx] = tileInp[(threadRow + idx) * {{BC}} + dotIdx]; + } + for (var idx: u32 = 0; idx < {{TOC}}; idx++) { + localOC[idx] = tileWeight[(threadCol + idx) * {{BC}} + dotIdx]; + } + for (var resIdxT: u32 = 0; resIdxT < {{TT}}; resIdxT++) { + for (var resIdxOC: u32 = 0; resIdxOC < {{TOC}}; resIdxOC++) { + threadResults[resIdxT * {{TOC}} + resIdxOC] += localT[resIdxT] * localOC[resIdxOC]; + } + } + } + workgroupBarrier(); + } + + if (arrayLength(&bias) == 1) { + for (var resIdxT: u32 = 0; resIdxT < {{TT}}; resIdxT++) { + for (var resIdxOC: u32 = 0; resIdxOC < {{TOC}}; resIdxOC++) { + out[outPtr + (threadRow + resIdxT) * OC + threadCol + resIdxOC] = threadResults[resIdxT * {{TOC}} + resIdxOC]; + } + } + } else { + for (var resIdxT: u32 = 0; resIdxT < {{TT}}; resIdxT++) { + for (var resIdxOC: u32 = 0; resIdxOC < {{TOC}}; resIdxOC++) { + out[outPtr + (threadRow + resIdxT) * OC + threadCol + resIdxOC] = threadResults[resIdxT * {{TOC}} + resIdxOC] + bias[biasPtr + threadCol + resIdxOC]; + } + } + } +} )"; static const char *kShaderMatmulBackward = R"( diff --git a/experimental/kernels/ops.cpp b/experimental/kernels/ops.cpp index 67fc679..0e9c076 100644 --- a/experimental/kernels/ops.cpp +++ b/experimental/kernels/ops.cpp @@ -6,6 +6,7 @@ #include "kernels.h" #include "ops.hpp" +#include "experimental/wgsl.h" // loopUnrolling using namespace gpu; @@ -22,27 +23,39 @@ void encoder_forward(Context& ctx, float* out, uint32_t C; }; setLogLevel(kError); - printf("Creating tensors\n"); - printf("Creating input tensor\%pn", inp); - Tensor input = createTensor(ctx, Shape{b * t}, ki32, inp); - printf("Created input tensor\n"); - Tensor wte_t = createTensor(ctx, Shape{v, c}, kf32, wte); - printf("Created wte tensor\n"); - Tensor wpe_t = createTensor(ctx, Shape{t, c}, kf32, wpe); - printf("Created wpe tensor\n"); - Tensor output = createTensor(ctx, Shape{b * t * c}, kf32); - printf("Created tensors\n"); + // Generate the key of the cache by arguments. 
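+  // The kernel and its bound tensors are created once per (B, T, C) shape and cached in
+  // ctx.kernelPool under this key; later calls with the same shape reuse the cached kernel
+  // and only re-upload the host buffers with toGPU before dispatching. The same caching
+  // pattern is used by the other ops below.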
+ std::string key = "encoder_forward_" + std::to_string(B) + "_" + std::to_string(T) + "_" + std::to_string(C); + Kernel op; + if (ctx.kernelPool.data.find(key) == ctx.kernelPool.data.end()) { + Tensor input = createTensor(ctx, Shape{b * t}, ki32); + Tensor wte_t = createTensor(ctx, Shape{v, c}, kf32); + Tensor wpe_t = createTensor(ctx, Shape{t, c}, kf32); + Tensor output = createTensor(ctx, Shape{b * t * c}, kf32); + op = createKernel(ctx, {kShaderEncoder, 256, kf32}, + Bindings{input, wte_t, wpe_t, output}, + /* nWorkgroups */ {cdiv(b * t, 256), 1, 1}, + /* params */ + EncoderParams{ + static_cast<uint32_t>(b), + static_cast<uint32_t>(t), + static_cast<uint32_t>(c) + }, + nullptr, + key.c_str()); + } else { + op = ctx.kernelPool.data[key]; + } + Tensor& input = ctx.pool.data[op->buffers[0]]; + Tensor& wte_t = ctx.pool.data[op->buffers[1]]; + Tensor& wpe_t = ctx.pool.data[op->buffers[2]]; + Tensor& output = ctx.pool.data[op->buffers[3]]; + + toGPU(ctx, inp, input); + toGPU(ctx, wte, wte_t); + toGPU(ctx, wpe, wpe_t); + std::promise<void> promise; std::future<void> future = promise.get_future(); - Kernel op = createKernel(ctx, {kShaderEncoder, 256, kf32}, - Bindings{input, wte_t, wpe_t, output}, - /* nWorkgroups */ {cdiv(b * t, 256), 1, 1}, - /* params */ - EncoderParams{ - static_cast<uint32_t>(b), - static_cast<uint32_t>(t), - static_cast<uint32_t>(c) - }); dispatchKernel(ctx, op, promise); wait(ctx, future); toCPU(ctx, output, out, b * t * c * sizeof(float)); @@ -61,21 +74,40 @@ void encoder_backward(Context& ctx, float* dwte, float* dwpe, uint32_t C; }; setLogLevel(kError); - Tensor dwte_t = createTensor(ctx, Shape{v, c}, kf32, dwte); - Tensor dwpe_t = createTensor(ctx, Shape{t, c}, kf32, dwpe); - Tensor dout_t = createTensor(ctx, Shape{b * t * c}, kf32, dout); - Tensor input = createTensor(ctx, Shape{b * t}, ki32, inp); + // Generate the key of the cache by arguments. 
+ std::string key = "encoder_backward_" + std::to_string(B) + "_" + std::to_string(T) + "_" + std::to_string(C); + Kernel op; + if (ctx.kernelPool.data.find(key) == ctx.kernelPool.data.end()) { + Tensor dwte_t = createTensor(ctx, Shape{v, c}, kf32); + Tensor dwpe_t = createTensor(ctx, Shape{t, c}, kf32); + Tensor dout_t = createTensor(ctx, Shape{b * t * c}, kf32); + Tensor input = createTensor(ctx, Shape{b * t}, ki32); + op = createKernel(ctx, {kShaderEncoderBackward, 256, kf32}, + Bindings{dwte_t, dwpe_t, dout_t, input}, + /* nWorkgroups */ {cdiv(b * t, 256), 1, 1}, + /* params */ + EncoderParams{ + static_cast<uint32_t>(b), + static_cast<uint32_t>(t), + static_cast<uint32_t>(c) + }, + nullptr, + key.c_str()); + } else { + op = ctx.kernelPool.data[key]; + } + Tensor& dwte_t = ctx.pool.data[op->buffers[0]]; + Tensor& dwpe_t = ctx.pool.data[op->buffers[1]]; + Tensor& dout_t = ctx.pool.data[op->buffers[2]]; + Tensor& input = ctx.pool.data[op->buffers[3]]; + + toGPU(ctx, dwte, dwte_t); + toGPU(ctx, dwpe, dwpe_t); + toGPU(ctx, dout, dout_t); + toGPU(ctx, inp, input); + std::promise<void> promise; std::future<void> future = promise.get_future(); - Kernel op = createKernel(ctx, {kShaderEncoderBackward, 256, kf32}, - Bindings{dwte_t, dwpe_t, dout_t, input}, - /* nWorkgroups */ {cdiv(b * t, 256), 1, 1}, - /* params */ - EncoderParams{ - static_cast<uint32_t>(b), - static_cast<uint32_t>(t), - static_cast<uint32_t>(c) - }); dispatchKernel(ctx, op, promise); wait(ctx, future); toCPU(ctx, dwte_t, dwte, v * c * sizeof(float)); @@ -94,23 +126,43 @@ void layernorm_forward(Context& ctx, float* out, float* mean, float* rstd, uint32_t C; }; setLogLevel(kError); - Tensor inp_t = createTensor(ctx, Shape{b * t * c}, kf32, inp); - Tensor weight_t = createTensor(ctx, Shape{c}, kf32, weight); - Tensor bias_t = createTensor(ctx, Shape{c}, kf32, bias); - Tensor out_t = createTensor(ctx, Shape{b * t * c}, kf32); - Tensor mean_t = createTensor(ctx, Shape{b * t}, kf32); - Tensor rstd_t = createTensor(ctx, Shape{b * t}, kf32); + // Generate the key of the cache by arguments. 
+ std::string key = "layernorm_forward_" + std::to_string(B) + "_" + std::to_string(T) + "_" + std::to_string(C); + Kernel op; + if (ctx.kernelPool.data.find(key) == ctx.kernelPool.data.end()) { + Tensor inp_t = createTensor(ctx, Shape{b * t * c}, kf32); + Tensor weight_t = createTensor(ctx, Shape{c}, kf32); + Tensor bias_t = createTensor(ctx, Shape{c}, kf32); + Tensor out_t = createTensor(ctx, Shape{b * t * c}, kf32); + Tensor mean_t = createTensor(ctx, Shape{b * t}, kf32); + Tensor rstd_t = createTensor(ctx, Shape{b * t}, kf32); + op = createKernel(ctx, {kShaderLayerNorm, 256, kf32}, + Bindings{inp_t, weight_t, bias_t, out_t, mean_t, rstd_t}, + /* nWorkgroups */ {cdiv(b * t, 256), 1, 1}, + /* params */ + LayerNormParams{ + static_cast<uint32_t>(b), + static_cast<uint32_t>(t), + static_cast<uint32_t>(c) + }, + nullptr, + key.c_str()); + } else { + op = ctx.kernelPool.data[key]; + } + Tensor& inp_t = ctx.pool.data[op->buffers[0]]; + Tensor& weight_t = ctx.pool.data[op->buffers[1]]; + Tensor& bias_t = ctx.pool.data[op->buffers[2]]; + Tensor& out_t = ctx.pool.data[op->buffers[3]]; + Tensor& mean_t = ctx.pool.data[op->buffers[4]]; + Tensor& rstd_t = ctx.pool.data[op->buffers[5]]; + + toGPU(ctx, inp, inp_t); + toGPU(ctx, weight, weight_t); + toGPU(ctx, bias, bias_t); + std::promise<void> promise; std::future<void> future = promise.get_future(); - Kernel op = createKernel(ctx, {kShaderLayerNorm, 256, kf32}, - Bindings{inp_t, weight_t, bias_t, out_t, mean_t, rstd_t}, - /* nWorkgroups */ {cdiv(b * t, 256), 1, 1}, - /* params */ - LayerNormParams{ - static_cast<uint32_t>(b), - static_cast<uint32_t>(t), - static_cast<uint32_t>(c) - }); dispatchKernel(ctx, op, promise); wait(ctx, future); toCPU(ctx, out_t, out, b * t * c * sizeof(float)); @@ -130,25 +182,52 @@ void layernorm_backward(Context& ctx, float* dinp, float* dweight, float* dbias, uint32_t C; }; setLogLevel(kError); - Tensor dinp_t = createTensor(ctx, Shape{b * t * c}, kf32, dinp); - Tensor dweight_t = createTensor(ctx, Shape{c}, kf32, dweight); - Tensor dbias_t = createTensor(ctx, Shape{c}, kf32, dbias); - Tensor dout_t = createTensor(ctx, Shape{b * t * c}, kf32, dout); - Tensor inp_t = createTensor(ctx, Shape{b * t * c}, kf32, inp); - Tensor weight_t = createTensor(ctx, Shape{c}, kf32, weight); - Tensor mean_t = createTensor(ctx, Shape{b * t}, kf32, mean); - Tensor rstd_t = createTensor(ctx, Shape{b * t}, kf32, rstd); + // Generate the key of the cache by arguments. 
+ std::string key = "layernorm_backward_" + std::to_string(B) + "_" + std::to_string(T) + "_" + std::to_string(C); + Kernel op; + if (ctx.kernelPool.data.find(key) == ctx.kernelPool.data.end()) { + Tensor dinp_t = createTensor(ctx, Shape{b * t * c}, kf32); + Tensor dweight_t = createTensor(ctx, Shape{c}, kf32); + Tensor dbias_t = createTensor(ctx, Shape{c}, kf32); + Tensor dout_t = createTensor(ctx, Shape{b * t * c}, kf32); + Tensor inp_t = createTensor(ctx, Shape{b * t * c}, kf32); + Tensor weight_t = createTensor(ctx, Shape{c}, kf32); + Tensor mean_t = createTensor(ctx, Shape{b * t}, kf32); + Tensor rstd_t = createTensor(ctx, Shape{b * t}, kf32); + op = createKernel(ctx, {kShaderLayerNormBackward, 256, kf32}, + Bindings{dinp_t, dweight_t, dbias_t, dout_t, inp_t, weight_t, mean_t, rstd_t}, + /* nWorkgroups */ {cdiv(b * t, 256), 1, 1}, + /* params */ + LayerNormParams{ + static_cast<uint32_t>(b), + static_cast<uint32_t>(t), + static_cast<uint32_t>(c) + }, + nullptr, + key.c_str()); + } else { + op = ctx.kernelPool.data[key]; + } + Tensor& dinp_t = ctx.pool.data[op->buffers[0]]; + Tensor& dweight_t = ctx.pool.data[op->buffers[1]]; + Tensor& dbias_t = ctx.pool.data[op->buffers[2]]; + Tensor& dout_t = ctx.pool.data[op->buffers[3]]; + Tensor& inp_t = ctx.pool.data[op->buffers[4]]; + Tensor& weight_t = ctx.pool.data[op->buffers[5]]; + Tensor& mean_t = ctx.pool.data[op->buffers[6]]; + Tensor& rstd_t = ctx.pool.data[op->buffers[7]]; + + toGPU(ctx, dinp, dinp_t); + toGPU(ctx, dweight, dweight_t); + toGPU(ctx, dbias, dbias_t); + toGPU(ctx, dout, dout_t); + toGPU(ctx, inp, inp_t); + toGPU(ctx, weight, weight_t); + toGPU(ctx, mean, mean_t); + toGPU(ctx, rstd, rstd_t); + std::promise<void> promise; std::future<void> future = promise.get_future(); - Kernel op = createKernel(ctx, {kShaderLayerNormBackward, 256, kf32}, - Bindings{dinp_t, dweight_t, dbias_t, dout_t, inp_t, weight_t, mean_t, rstd_t}, - /* nWorkgroups */ {cdiv(b * t, 256), 1, 1}, - /* params */ - LayerNormParams{ - static_cast<uint32_t>(b), - static_cast<uint32_t>(t), - static_cast<uint32_t>(c) - }); dispatchKernel(ctx, op, promise); wait(ctx, future); toCPU(ctx, dinp_t, dinp, b * t * c * sizeof(float)); @@ -156,9 +235,34 @@ void layernorm_backward(Context& ctx, float* dinp, float* dweight, float* dbias, toCPU(ctx, dbias_t, dbias, c * sizeof(float)); } +struct DurationTime { + std::chrono::high_resolution_clock::time_point start; + std::chrono::high_resolution_clock::time_point end; + std::chrono::microseconds duration; + std::string src; + bool verbose; + + inline DurationTime(const std::string& src, bool verbose = true) { + this->src = src; + this->verbose = verbose; + start = std::chrono::high_resolution_clock::now(); + } + + inline ~DurationTime() { + end = std::chrono::high_resolution_clock::now(); + duration = std::chrono::duration_cast<std::chrono::microseconds>(end - start); + if (this->verbose) { + printf("Duration(%s): %.1f microseconds\n", src.c_str(), static_cast<double>(duration.count())); + } + } +}; + + void matmul_forward(Context& ctx, float* out, const float* inp, const float* weight, const float* bias, int B, int T, int C, int OC){ + bool verbose = false; + DurationTime duration("matmul_forward_gpu", verbose); struct MatmulParams { uint32_t B; uint32_t T; @@ -171,25 +275,76 @@ void matmul_forward(Context& ctx, float* out, unsigned long oc = static_cast<unsigned long>(OC); setLogLevel(kError); - Tensor inp_i = createTensor(ctx, Shape{b * t * c}, kf32, inp); - Tensor weight_i = createTensor(ctx, Shape{oc * c}, kf32, weight); 
- Tensor bias_i = bias == NULL ? createTensor(ctx, Shape{1}, kf32) : createTensor(ctx, Shape{oc}, kf32, bias); - Tensor out_o = createTensor(ctx, Shape{b * t * oc}, kf32); + // Generate the key of the cache by arguments. + std::string key = "matmul_forward_" + std::to_string(B) + "_" + std::to_string(T) + "_" + std::to_string(C) + "_" + std::to_string(OC); + Kernel op; + if (ctx.kernelPool.data.find(key) == ctx.kernelPool.data.end()) { + constexpr size_t BT = 64; + constexpr size_t BC = 8; + constexpr size_t BOC = 64; + constexpr size_t TT = BT / BC; + constexpr size_t TOC = BOC / BC; + size_t num_threads = BT * BOC / (TT * TOC); + Shape wgSize = {num_threads, 1, 1}; + Shape nWorkgroups = {b, cdiv(T, BT), cdiv(OC, BOC)}; + + std::string kShaderMatmul2DTiling_(kShaderMatmul2DTiling); + std::string kShaderMatmul2D(loopUnrolling( + replaceAll(kShaderMatmul2DTiling_, + {{"{{precision}}", toString(kf32)}, + {"{{BT}}", toString(BT)}, + {"{{BC}}", toString(BC)}, + {"{{BOC}}", toString(BOC)}, + {"{{TT}}", toString(TT)}, + {"{{TOC}}", toString(TOC)}, + {"{{NUM_TILEI}}", toString(BT * BC / num_threads)}, + {"{{NUM_TILEW}}", toString(BOC * BC / num_threads)} + }) + ) + ); + + Tensor inp_i = createTensor(ctx, Shape{b * t * c}, kf32); + Tensor weight_i = createTensor(ctx, Shape{oc * c}, kf32); + Tensor bias_i = bias == NULL ? createTensor(ctx, Shape{1}, kf32) : createTensor(ctx, Shape{oc}, kf32); + Tensor out_o = createTensor(ctx, Shape{b * t * oc}, kf32); + + op = createKernel(ctx, {kShaderMatmul2D, wgSize, kf32}, + Bindings{inp_i, weight_i, bias_i, out_o}, + nWorkgroups, + /* params */ + MatmulParams{ + static_cast<uint32_t>(b), + static_cast<uint32_t>(t), + static_cast<uint32_t>(c), + static_cast<uint32_t>(oc) + }, + nullptr, + key.c_str()); + } else { + op = ctx.kernelPool.data[key]; + } + Tensor& inp_i = ctx.pool.data[op->buffers[0]]; + Tensor& weight_i = ctx.pool.data[op->buffers[1]]; + Tensor& bias_i = ctx.pool.data[op->buffers[2]]; + Tensor& out_o = ctx.pool.data[op->buffers[3]]; + + toGPU(ctx, inp, inp_i); + toGPU(ctx, weight, weight_i); + if (bias != NULL) { + toGPU(ctx, bias, bias_i); + } + std::promise<void> promise; std::future<void> future = promise.get_future(); - assert ( (b*t) % 256 == 0 ); - Kernel op = createKernel(ctx, {kShaderMatmul, 256, kf32}, - Bindings{inp_i, weight_i, bias_i, out_o}, - /* nWorkgroups */ {cdiv(b * t, 256), 1, 1}, - /* params */ - MatmulParams{ - static_cast<uint32_t>(b), - static_cast<uint32_t>(t), - static_cast<uint32_t>(c), - static_cast<uint32_t>(oc) - }); - dispatchKernel(ctx, op, promise); - wait(ctx, future); + { + DurationTime duration("matmul_forward_gpu without creating tensors", verbose); + { + DurationTime duration("matmul_forward_gpu without creating kernel", verbose); + dispatchKernel(ctx, op, promise); + wait(ctx, future); + toCPU(ctx, out_o, out, b * t * oc * sizeof(float)); + } + } toCPU(ctx, out_o, out, b * t * oc * sizeof(float)); } @@ -207,24 +362,47 @@ void matmul_backward(Context& ctx, float* dinp, float* dweight, float* dbias, unsigned long c = static_cast<unsigned long>(C); unsigned long oc = static_cast<unsigned long>(OC); setLogLevel(kError); - Tensor dinp_t = createTensor(ctx, Shape{b * t * c}, kf32, dinp); - Tensor dweight_t = createTensor(ctx, Shape{oc * c}, kf32, dweight); - Tensor dbias_t = createTensor(ctx, Shape{oc}, kf32, dbias); - Tensor dout_t = createTensor(ctx, Shape{b * t * oc}, kf32, dout); - Tensor inp_t = createTensor(ctx, Shape{b * t * c}, kf32, inp); - Tensor weight_t = createTensor(ctx, Shape{oc * c}, kf32, weight); + 
// Generate the key of the cache by arguments. + std::string key = "matmul_backward_" + std::to_string(B) + "_" + std::to_string(T) + "_" + std::to_string(C) + "_" + std::to_string(OC); + Kernel op; + if (ctx.kernelPool.data.find(key) == ctx.kernelPool.data.end()) { + Tensor dinp_t = createTensor(ctx, Shape{b * t * c}, kf32); + Tensor dweight_t = createTensor(ctx, Shape{oc * c}, kf32); + Tensor dbias_t = createTensor(ctx, Shape{oc}, kf32); + Tensor dout_t = createTensor(ctx, Shape{b * t * oc}, kf32); + Tensor inp_t = createTensor(ctx, Shape{b * t * c}, kf32); + Tensor weight_t = createTensor(ctx, Shape{oc * c}, kf32); + op = createKernel(ctx, {kShaderMatmulBackward, 256, kf32}, + Bindings{dinp_t, dweight_t, dbias_t, dout_t, inp_t, weight_t}, + /* nWorkgroups */ {cdiv(b * t, 256), 1, 1}, + /* params */ + MatmulParams{ + static_cast<uint32_t>(b), + static_cast<uint32_t>(t), + static_cast<uint32_t>(c), + static_cast<uint32_t>(oc) + }, + nullptr, + key.c_str()); + } else { + op = ctx.kernelPool.data[key]; + } + Tensor& dinp_t = ctx.pool.data[op->buffers[0]]; + Tensor& dweight_t = ctx.pool.data[op->buffers[1]]; + Tensor& dbias_t = ctx.pool.data[op->buffers[2]]; + Tensor& dout_t = ctx.pool.data[op->buffers[3]]; + Tensor& inp_t = ctx.pool.data[op->buffers[4]]; + Tensor& weight_t = ctx.pool.data[op->buffers[5]]; + + toGPU(ctx, dinp, dinp_t); + toGPU(ctx, dweight, dweight_t); + toGPU(ctx, dbias, dbias_t); + toGPU(ctx, dout, dout_t); + toGPU(ctx, inp, inp_t); + toGPU(ctx, weight, weight_t); + std::promise<void> promise; std::future<void> future = promise.get_future(); - Kernel op = createKernel(ctx, {kShaderMatmulBackward, 256, kf32}, - Bindings{dinp_t, dweight_t, dbias_t, dout_t, inp_t, weight_t}, - /* nWorkgroups */ {cdiv(b * t, 256), 1, 1}, - /* params */ - MatmulParams{ - static_cast<uint32_t>(b), - static_cast<uint32_t>(t), - static_cast<uint32_t>(c), - static_cast<uint32_t>(oc) - }); dispatchKernel(ctx, op, promise); wait(ctx, future); toCPU(ctx, dinp_t, dinp, b * t * c * sizeof(float)); @@ -246,22 +424,40 @@ void attention_forward(Context& ctx, float* out, float* preatt, float* att, unsigned long c = static_cast<unsigned long>(C); unsigned long nh = static_cast<unsigned long>(NH); setLogLevel(kError); - Tensor inp_t = createTensor(ctx, Shape{b * t * c * 3}, kf32, inp); - Tensor preatt_t = createTensor(ctx, Shape{b * nh * t * t}, kf32, preatt); - Tensor att_t = createTensor(ctx, Shape{b * nh * t * t}, kf32, att); - Tensor out_t = createTensor(ctx, Shape{b * t * c}, kf32); + // Generate the key of the cache by arguments. 
+ std::string key = "attention_forward_" + std::to_string(B) + "_" + std::to_string(T) + "_" + std::to_string(C) + "_" + std::to_string(NH); + Kernel op; + if (ctx.kernelPool.data.find(key) == ctx.kernelPool.data.end()) { + Tensor inp_t = createTensor(ctx, Shape{b * t * c * 3}, kf32); + Tensor preatt_t = createTensor(ctx, Shape{b * nh * t * t}, kf32); + Tensor att_t = createTensor(ctx, Shape{b * nh * t * t}, kf32); + Tensor out_t = createTensor(ctx, Shape{b * t * c}, kf32); + op = createKernel(ctx, {kShaderAttention, 256, kf32}, + Bindings{inp_t, preatt_t, att_t, out_t}, + /* nWorkgroups */ {cdiv(b * t, 256), 1, 1}, + /* params */ + AttentionParams{ + static_cast<uint32_t>(b), + static_cast<uint32_t>(t), + static_cast<uint32_t>(c), + static_cast<uint32_t>(nh) + }, + nullptr, + key.c_str()); + } else { + op = ctx.kernelPool.data[key]; + } + Tensor& inp_t = ctx.pool.data[op->buffers[0]]; + Tensor& preatt_t = ctx.pool.data[op->buffers[1]]; + Tensor& att_t = ctx.pool.data[op->buffers[2]]; + Tensor& out_t = ctx.pool.data[op->buffers[3]]; + + toGPU(ctx, inp, inp_t); + toGPU(ctx, preatt, preatt_t); + toGPU(ctx, att, att_t); + std::promise<void> promise; std::future<void> future = promise.get_future(); - Kernel op = createKernel(ctx, {kShaderAttention, 256, kf32}, - Bindings{inp_t, preatt_t, att_t, out_t}, - /* nWorkgroups */ {cdiv(b * t, 256), 1, 1}, - /* params */ - AttentionParams{ - static_cast<uint32_t>(b), - static_cast<uint32_t>(t), - static_cast<uint32_t>(c), - static_cast<uint32_t>(nh) - }); dispatchKernel(ctx, op, promise); wait(ctx, future); toCPU(ctx, preatt_t, preatt, b * nh * t * t * sizeof(float)); @@ -283,24 +479,47 @@ void attention_backward(Context& ctx, float* dinp, float* dpreatt, float* datt, unsigned long c = static_cast<unsigned long>(C); unsigned long nh = static_cast<unsigned long>(NH); setLogLevel(kError); - Tensor dinp_t = createTensor(ctx, Shape{b * t * c * 3}, kf32, dinp); - Tensor dpreatt_t = createTensor(ctx, Shape{b * nh * t * t}, kf32, dpreatt); - Tensor datt_t = createTensor(ctx, Shape{b * nh * t * t}, kf32, datt); - Tensor dout_t = createTensor(ctx, Shape{b * t * c}, kf32, dout); - Tensor inp_t = createTensor(ctx, Shape{b * t * c * 3}, kf32, inp); - Tensor att_t = createTensor(ctx, Shape{b * nh * t * t}, kf32, att); + // Generate the key of the cache by arguments. 
+ std::string key = "attention_backward_" + std::to_string(B) + "_" + std::to_string(T) + "_" + std::to_string(C) + "_" + std::to_string(NH); + Kernel op; + if (ctx.kernelPool.data.find(key) == ctx.kernelPool.data.end()) { + Tensor dinp_t = createTensor(ctx, Shape{b * t * c * 3}, kf32); + Tensor dpreatt_t = createTensor(ctx, Shape{b * nh * t * t}, kf32); + Tensor datt_t = createTensor(ctx, Shape{b * nh * t * t}, kf32); + Tensor dout_t = createTensor(ctx, Shape{b * t * c}, kf32); + Tensor inp_t = createTensor(ctx, Shape{b * t * c * 3}, kf32); + Tensor att_t = createTensor(ctx, Shape{b * nh * t * t}, kf32); + op = createKernel(ctx, {kShaderAttentionBackward, 256, kf32}, + Bindings{dinp_t, dpreatt_t, datt_t, dout_t, inp_t, att_t}, + /* nWorkgroups */ {cdiv(b * t, 256), 1, 1}, + /* params */ + AttentionParams{ + static_cast<uint32_t>(b), + static_cast<uint32_t>(t), + static_cast<uint32_t>(c), + static_cast<uint32_t>(nh) + }, + nullptr, + key.c_str()); + } else { + op = ctx.kernelPool.data[key]; + } + Tensor& dinp_t = ctx.pool.data[op->buffers[0]]; + Tensor& dpreatt_t = ctx.pool.data[op->buffers[1]]; + Tensor& datt_t = ctx.pool.data[op->buffers[2]]; + Tensor& dout_t = ctx.pool.data[op->buffers[3]]; + Tensor& inp_t = ctx.pool.data[op->buffers[4]]; + Tensor& att_t = ctx.pool.data[op->buffers[5]]; + + toGPU(ctx, dinp, dinp_t); + toGPU(ctx, dpreatt, dpreatt_t); + toGPU(ctx, datt, datt_t); + toGPU(ctx, dout, dout_t); + toGPU(ctx, inp, inp_t); + toGPU(ctx, att, att_t); + std::promise<void> promise; std::future<void> future = promise.get_future(); - Kernel op = createKernel(ctx, {kShaderAttentionBackward, 256, kf32}, - Bindings{dinp_t, dpreatt_t, datt_t, dout_t, inp_t, att_t}, - /* nWorkgroups */ {cdiv(b * t, 256), 1, 1}, - /* params */ - AttentionParams{ - static_cast<uint32_t>(b), - static_cast<uint32_t>(t), - static_cast<uint32_t>(c), - static_cast<uint32_t>(nh) - }); dispatchKernel(ctx, op, promise); wait(ctx, future); toCPU(ctx, dinp_t, dinp, b * t * c * 3 * sizeof(float)); @@ -311,13 +530,28 @@ void attention_backward(Context& ctx, float* dinp, float* dpreatt, float* datt, void gelu_forward(Context& ctx, float* out, float* inp, int n) { unsigned long N = static_cast<unsigned long>(n); setLogLevel(kError); - Tensor input = createTensor(ctx, Shape{N}, kf32, inp); - Tensor output = createTensor(ctx, Shape{N}, kf32); + // Generate the key of the cache by arguments. 
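+  // The element-wise GELU kernel takes no uniform params, so createKernel is called with a
+  // null params slot; only the bindings, workgroup count, and cache key are supplied.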
+ std::string key = "gelu_forward_" + std::to_string(n); + Kernel op; + if (ctx.kernelPool.data.find(key) == ctx.kernelPool.data.end()) { + Tensor input = createTensor(ctx, Shape{N}, kf32); + Tensor output = createTensor(ctx, Shape{N}, kf32); + op = createKernel(ctx, {kShaderGelu, 256, kf32}, + Bindings{input, output}, + /* nWorkgroups */ {cdiv(N, 256), 1, 1}, + nullptr, + nullptr, + key.c_str()); + } else { + op = ctx.kernelPool.data[key]; + } + Tensor& input = ctx.pool.data[op->buffers[0]]; + Tensor& output = ctx.pool.data[op->buffers[1]]; + + toGPU(ctx, inp, input); + std::promise<void> promise; std::future<void> future = promise.get_future(); - Kernel op = createKernel(ctx, {kShaderGelu, 256, kf32}, - Bindings{input, output}, - /* nWorkgroups */ {cdiv(N, 256), 1, 1}); dispatchKernel(ctx, op, promise); wait(ctx, future); toCPU(ctx, output, out, N * sizeof(float)); @@ -326,14 +560,32 @@ void gelu_forward(Context& ctx, float* out, float* inp, int n) { void gelu_backward(Context& ctx, float* dinp, float* inp, float* dout, int N){ unsigned long n = static_cast<unsigned long>(N); setLogLevel(kError); - Tensor inp_i = createTensor(ctx, Shape{n}, kf32, inp); - Tensor dout_i = createTensor(ctx, Shape{n}, kf32, dout); - Tensor dinp_o = createTensor(ctx, Shape{n}, kf32, dinp); + // Generate the key of the cache by arguments. + std::string key = "gelu_backward_" + std::to_string(N); + Kernel op; + if (ctx.kernelPool.data.find(key) == ctx.kernelPool.data.end()) { + Tensor inp_i = createTensor(ctx, Shape{n}, kf32); + Tensor dout_i = createTensor(ctx, Shape{n}, kf32); + Tensor dinp_o = createTensor(ctx, Shape{n}, kf32); + op = createKernel(ctx, {kShaderGeluBackward, 256, kf32}, + Bindings{inp_i, dout_i, dinp_o}, + /* nWorkgroups */ {cdiv(n, 256), 1, 1}, + nullptr, + nullptr, + key.c_str()); + } else { + op = ctx.kernelPool.data[key]; + } + Tensor& inp_i = ctx.pool.data[op->buffers[0]]; + Tensor& dout_i = ctx.pool.data[op->buffers[1]]; + Tensor& dinp_o = ctx.pool.data[op->buffers[2]]; + + toGPU(ctx, inp, inp_i); + toGPU(ctx, dout, dout_i); + toGPU(ctx, dinp, dinp_o); + std::promise<void> promise; std::future<void> future = promise.get_future(); - Kernel op = createKernel(ctx, {kShaderGeluBackward, 256, kf32}, - Bindings{inp_i, dout_i, dinp_o}, - /* nWorkgroups */ {cdiv(n, 256), 1, 1}); dispatchKernel(ctx, op, promise); wait(ctx, future); toCPU(ctx, dinp_o, dinp, n * sizeof(float)); @@ -342,14 +594,31 @@ void gelu_backward(Context& ctx, float* dinp, float* inp, float* dout, int N){ void residual_forward(Context& ctx, float* out, float* inp1, float* inp2, int N){ unsigned long n = static_cast<unsigned long>(N); setLogLevel(kError); - Tensor inp1_i = createTensor(ctx, Shape{n}, kf32, inp1); - Tensor inp2_i = createTensor(ctx, Shape{n}, kf32, inp2); - Tensor out_o = createTensor(ctx, Shape{n}, kf32); + // Generate the key of the cache by arguments. 
+ std::string key = "residual_forward_" + std::to_string(N); + Kernel op; + if (ctx.kernelPool.data.find(key) == ctx.kernelPool.data.end()) { + Tensor inp1_i = createTensor(ctx, Shape{n}, kf32); + Tensor inp2_i = createTensor(ctx, Shape{n}, kf32); + Tensor out_o = createTensor(ctx, Shape{n}, kf32); + op = createKernel(ctx, {kShaderResidual, 256, kf32}, + Bindings{inp1_i, inp2_i, out_o}, + /* nWorkgroups */ {cdiv(n, 256), 1, 1}, + nullptr, + nullptr, + key.c_str()); + } else { + op = ctx.kernelPool.data[key]; + } + Tensor& inp1_i = ctx.pool.data[op->buffers[0]]; + Tensor& inp2_i = ctx.pool.data[op->buffers[1]]; + Tensor& out_o = ctx.pool.data[op->buffers[2]]; + + toGPU(ctx, inp1, inp1_i); + toGPU(ctx, inp2, inp2_i); + std::promise<void> promise; std::future<void> future = promise.get_future(); - Kernel op = createKernel(ctx, {kShaderResidual, 256, kf32}, - Bindings{inp1_i, inp2_i, out_o}, - /* nWorkgroups */ {cdiv(n, 256), 1, 1}); dispatchKernel(ctx, op, promise); wait(ctx, future); toCPU(ctx, out_o, out, n * sizeof(float)); @@ -358,14 +627,32 @@ void residual_forward(Context& ctx, float* out, float* inp1, float* inp2, int N) void residual_backward(Context& ctx, float* dinp1, float* dinp2, float* dout, int N){ unsigned long n = static_cast<unsigned long>(N); setLogLevel(kError); - Tensor dout_i = createTensor(ctx, Shape{n}, kf32, dout); - Tensor dinp1_o = createTensor(ctx, Shape{n}, kf32, dinp1); - Tensor dinp2_o = createTensor(ctx, Shape{n}, kf32, dinp2); + // Generate the key of the cache by arguments. + std::string key = "residual_backward_" + std::to_string(N); + Kernel op; + if (ctx.kernelPool.data.find(key) == ctx.kernelPool.data.end()) { + Tensor dout_i = createTensor(ctx, Shape{n}, kf32); + Tensor dinp1_o = createTensor(ctx, Shape{n}, kf32); + Tensor dinp2_o = createTensor(ctx, Shape{n}, kf32); + op = createKernel(ctx, {kShaderResidualBackward, 256, kf32}, + Bindings{dout_i, dinp1_o, dinp2_o}, + /* nWorkgroups */ {cdiv(n, 256), 1, 1}, + nullptr, + nullptr, + key.c_str()); + } else { + op = ctx.kernelPool.data[key]; + } + Tensor& dout_i = ctx.pool.data[op->buffers[0]]; + Tensor& dinp1_o = ctx.pool.data[op->buffers[1]]; + Tensor& dinp2_o = ctx.pool.data[op->buffers[2]]; + + toGPU(ctx, dout, dout_i); + toGPU(ctx, dinp1, dinp1_o); + toGPU(ctx, dinp2, dinp2_o); + std::promise<void> promise; std::future<void> future = promise.get_future(); - Kernel op = createKernel(ctx, {kShaderResidualBackward, 256, kf32}, - Bindings{dout_i, dinp1_o, dinp2_o}, - /* nWorkgroups */ {cdiv(n, 256), 1, 1}); dispatchKernel(ctx, op, promise); wait(ctx, future); toCPU(ctx, dinp1_o, dinp1, n * sizeof(float)); @@ -382,14 +669,28 @@ void softmax_forward(Context& ctx, float* probs, float* logits, int B, int T, in uint32_t t = static_cast<uint32_t>(T); uint32_t c = static_cast<uint32_t>(V); uint32_t cp = static_cast<uint32_t>(Vp); - Tensor input = createTensor(ctx, {b * t, cp}, kf32, logits); - Tensor output = createTensor(ctx, {b * t, cp}, kf32); + // Generate the key of the cache by arguments. 
+ std::string key = "softmax_forward_" + std::to_string(B) + "_" + std::to_string(T) + "_" + std::to_string(V) + "_" + std::to_string(Vp); + Kernel op; + if (ctx.kernelPool.data.find(key) == ctx.kernelPool.data.end()) { + Tensor input = createTensor(ctx, {b * t, cp}, kf32); + Tensor output = createTensor(ctx, {b * t, cp}, kf32); + assert( (B*T) % 256 == 0); + op = createKernel( + ctx, {kShaderSoftmax1, 256, kf32}, Bindings{input, output}, + Shape{cdiv(B * T, 256), 1, 1}, SoftmaxParam{b * t, c, cp}, + nullptr, + key.c_str()); + } else { + op = ctx.kernelPool.data[key]; + } + Tensor& input = ctx.pool.data[op->buffers[0]]; + Tensor& output = ctx.pool.data[op->buffers[1]]; + + toGPU(ctx, logits, input); + std::promise<void> promise; std::future<void> future = promise.get_future(); - assert( (B*T) % 256 == 0); - Kernel op = createKernel( - ctx, {kShaderSoftmax1, 256, kf32}, Bindings{input, output}, - Shape{cdiv(B * T, 256), 1, 1}, SoftmaxParam{b * t, c, cp}); dispatchKernel(ctx, op, promise); wait(ctx, future); toCPU(ctx, output, probs, sizeof(float)*b*t*cp); @@ -407,20 +708,37 @@ void crossentropy_forward(Context& ctx, float* losses, unsigned long t = static_cast<unsigned long>(T); unsigned long vp = static_cast<unsigned long>(Vp); setLogLevel(kError); - Tensor losses_t = createTensor(ctx, Shape{b * t}, kf32, losses); - Tensor probs_t = createTensor(ctx, Shape{b * t * vp}, kf32, probs); - Tensor targets_t = createTensor(ctx, Shape{b * t}, ki32, targets); + // Generate the key of the cache by arguments. + std::string key = "crossentropy_forward_" + std::to_string(B) + "_" + std::to_string(T) + "_" + std::to_string(Vp); + Kernel op; + if (ctx.kernelPool.data.find(key) == ctx.kernelPool.data.end()) { + Tensor losses_t = createTensor(ctx, Shape{b * t}, kf32); + Tensor probs_t = createTensor(ctx, Shape{b * t * vp}, kf32); + Tensor targets_t = createTensor(ctx, Shape{b * t}, ki32); + op = createKernel(ctx, {kShaderCrossEntropyForward, 256, kf32}, + Bindings{losses_t, probs_t, targets_t}, + /* nWorkgroups */ {cdiv(b * t, 256), 1, 1}, + /* params */ + CrossEntropyParams{ + static_cast<uint32_t>(b), + static_cast<uint32_t>(t), + static_cast<uint32_t>(vp) + }, + nullptr, + key.c_str()); + } else { + op = ctx.kernelPool.data[key]; + } + Tensor& losses_t = ctx.pool.data[op->buffers[0]]; + Tensor& probs_t = ctx.pool.data[op->buffers[1]]; + Tensor& targets_t = ctx.pool.data[op->buffers[2]]; + + toGPU(ctx, losses, losses_t); + toGPU(ctx, probs, probs_t); + toGPU(ctx, targets, targets_t); + std::promise<void> promise; std::future<void> future = promise.get_future(); - Kernel op = createKernel(ctx, {kShaderCrossEntropyForward, 256, kf32}, - Bindings{losses_t, probs_t, targets_t}, - /* nWorkgroups */ {cdiv(b * t, 256), 1, 1}, - /* params */ - CrossEntropyParams{ - static_cast<uint32_t>(b), - static_cast<uint32_t>(t), - static_cast<uint32_t>(vp) - }); dispatchKernel(ctx, op, promise); wait(ctx, future); toCPU(ctx, losses_t, losses, b * t * sizeof(float)); @@ -440,22 +758,41 @@ void crossentropy_softmax_backward(Context& ctx, float* dlogits, unsigned long v = static_cast<unsigned long>(V); unsigned long vp = static_cast<unsigned long>(Vp); setLogLevel(kError); - Tensor dlogits_t = createTensor(ctx, Shape{b * t * vp}, kf32, dlogits); - Tensor dlosses_t = createTensor(ctx, Shape{b * t}, kf32, dlosses); - Tensor probs_t = createTensor(ctx, Shape{b * t * vp}, kf32, probs); - Tensor targets_t = createTensor(ctx, Shape{b * t}, ki32, targets); + // Generate the key of the cache by arguments. 
+ std::string key = "crossentropy_softmax_backward_" + std::to_string(B) + "_" + std::to_string(T) + "_" + std::to_string(V) + "_" + std::to_string(Vp); + Kernel op; + if (ctx.kernelPool.data.find(key) == ctx.kernelPool.data.end()) { + Tensor dlogits_t = createTensor(ctx, Shape{b * t * vp}, kf32); + Tensor dlosses_t = createTensor(ctx, Shape{b * t}, kf32); + Tensor probs_t = createTensor(ctx, Shape{b * t * vp}, kf32); + Tensor targets_t = createTensor(ctx, Shape{b * t}, ki32); + op = createKernel(ctx, {kShaderCrossEntropySoftmaxBackward, 256, kf32}, + Bindings{dlogits_t, dlosses_t, probs_t, targets_t}, + /* nWorkgroups */ {cdiv(b * t, 256), 1, 1}, + /* params */ + CrossEntropySoftmaxBackwardParams{ + static_cast<uint32_t>(b), + static_cast<uint32_t>(t), + static_cast<uint32_t>(v), + static_cast<uint32_t>(vp) + }, + nullptr, + key.c_str()); + } else { + op = ctx.kernelPool.data[key]; + } + Tensor& dlogits_t = ctx.pool.data[op->buffers[0]]; + Tensor& dlosses_t = ctx.pool.data[op->buffers[1]]; + Tensor& probs_t = ctx.pool.data[op->buffers[2]]; + Tensor& targets_t = ctx.pool.data[op->buffers[3]]; + + toGPU(ctx, dlogits, dlogits_t); + toGPU(ctx, dlosses, dlosses_t); + toGPU(ctx, probs, probs_t); + toGPU(ctx, targets, targets_t); + std::promise<void> promise; std::future<void> future = promise.get_future(); - Kernel op = createKernel(ctx, {kShaderCrossEntropySoftmaxBackward, 256, kf32}, - Bindings{dlogits_t, dlosses_t, probs_t, targets_t}, - /* nWorkgroups */ {cdiv(b * t, 256), 1, 1}, - /* params */ - CrossEntropySoftmaxBackwardParams{ - static_cast<uint32_t>(b), - static_cast<uint32_t>(t), - static_cast<uint32_t>(v), - static_cast<uint32_t>(vp) - }); dispatchKernel(ctx, op, promise); wait(ctx, future); toCPU(ctx, dlogits_t, dlogits, b * t * vp * sizeof(float)); diff --git a/experimental/kernels/unittest_llmc/unittest_kernels.cpp b/experimental/kernels/unittest_llmc/unittest_kernels.cpp index 37cdcaf..d037eac 100644 --- a/experimental/kernels/unittest_llmc/unittest_kernels.cpp +++ b/experimental/kernels/unittest_llmc/unittest_kernels.cpp @@ -2,9 +2,11 @@ #include <array> #include <cstdio> #include <future> +#include <map> #include "kernels.h" #include "unittest_llmc/unittest_kernels.h" +#include "experimental/wgsl.h" // loopUnrolling using namespace gpu; // createContext, createTensor, createKernel, // createShader, dispatchKernel, wait, toCPU @@ -51,6 +53,33 @@ using namespace gpu; // createContext, createTensor, createKernel, } \ } +struct DurationTime { + std::chrono::high_resolution_clock::time_point start; + std::chrono::high_resolution_clock::time_point end; + std::chrono::microseconds duration; + std::string src; + bool verbose; + + inline DurationTime(const std::string& src, bool verbose = true) { + this->src = src; + this->verbose = verbose; + start = std::chrono::high_resolution_clock::now(); + } + + inline ~DurationTime() { + end = std::chrono::high_resolution_clock::now(); + duration = std::chrono::duration_cast<std::chrono::microseconds>(end - start); + if (this->verbose) { + printf("Duration(%s): %.1f microseconds\n", src.c_str(), static_cast<double>(duration.count())); + } + } +}; + +static WGPURequiredLimits requiredLimits = LIMITS_BUFFER_SIZE_1GB; +static Context ctx = createContext({},{},{ + .requiredLimits = &requiredLimits + }); + void ENCODER_FORWARD_GPU(float* out, int* inp, float* wte, float* wpe, int B, int T, int C){ @@ -64,25 +93,40 @@ void ENCODER_FORWARD_GPU(float* out, uint32_t C; }; setLogLevel(kError); - WGPURequiredLimits requiredLimits = LIMITS_BUFFER_SIZE_1GB; - 
Context ctx = createContext({},{},{ - .requiredLimits = &requiredLimits - }); - Tensor input = createTensor(ctx, Shape{b * t}, ki32, inp); - Tensor wte_t = createTensor(ctx, Shape{v, c}, kf32, wte); - Tensor wpe_t = createTensor(ctx, Shape{t, c}, kf32, wpe); - Tensor output = createTensor(ctx, Shape{b * t * c}, kf32); + + // Generate the key of the cache by arguments. + std::string key = "ENCODER_FORWARD_GPU_" + std::to_string(B) + "_" + std::to_string(T) + "_" + std::to_string(C); + Kernel op; + if (ctx.kernelPool.data.find(key) == ctx.kernelPool.data.end()) { + Tensor input = createTensor(ctx, Shape{b * t}, ki32); + Tensor wte_t = createTensor(ctx, Shape{v, c}, kf32); + Tensor wpe_t = createTensor(ctx, Shape{t, c}, kf32); + Tensor output = createTensor(ctx, Shape{b * t * c}, kf32); + op = createKernel(ctx, {kShaderEncoder, 256, kf32}, + Bindings{input, wte_t, wpe_t, output}, + /* nWorkgroups */ {cdiv(b * t, 256), 1, 1}, + /* params */ + EncoderParams{ + static_cast<uint32_t>(b), + static_cast<uint32_t>(t), + static_cast<uint32_t>(c) + }, + nullptr, + key.c_str()); + } else { + op = ctx.kernelPool.data[key]; + } + Tensor& input = ctx.pool.data[op->buffers[0]]; + Tensor& wte_t = ctx.pool.data[op->buffers[1]]; + Tensor& wpe_t = ctx.pool.data[op->buffers[2]]; + Tensor& output = ctx.pool.data[op->buffers[3]]; + + toGPU(ctx, inp, input); + toGPU(ctx, wte, wte_t); + toGPU(ctx, wpe, wpe_t); + std::promise<void> promise; std::future<void> future = promise.get_future(); - Kernel op = createKernel(ctx, {kShaderEncoder, 256, kf32}, - Bindings{input, wte_t, wpe_t, output}, - /* nWorkgroups */ {cdiv(b * t, 256), 1, 1}, - /* params */ - EncoderParams{ - static_cast<uint32_t>(b), - static_cast<uint32_t>(t), - static_cast<uint32_t>(c) - }); dispatchKernel(ctx, op, promise); wait(ctx, future); toCPU(ctx, output, out, b * t * c * sizeof(float)); @@ -101,25 +145,41 @@ void ENCODER_BACKWARD_GPU(float* dwte, float* dwpe, uint32_t C; }; setLogLevel(kError); - WGPURequiredLimits requiredLimits = LIMITS_BUFFER_SIZE_1GB; - Context ctx = createContext({},{},{ - .requiredLimits = &requiredLimits - }); - Tensor dwte_t = createTensor(ctx, Shape{v, c}, kf32, dwte); - Tensor dwpe_t = createTensor(ctx, Shape{t, c}, kf32, dwpe); - Tensor dout_t = createTensor(ctx, Shape{b * t * c}, kf32, dout); - Tensor input = createTensor(ctx, Shape{b * t}, ki32, inp); + + // Generate the key of the cache by arguments. 
+ std::string key = "ENCODER_BACKWARD_GPU_" + std::to_string(B) + "_" + std::to_string(T) + "_" + std::to_string(C); + Kernel op; + if (ctx.kernelPool.data.find(key) == ctx.kernelPool.data.end()) { + Tensor dwte_t = createTensor(ctx, Shape{v, c}, kf32); + Tensor dwpe_t = createTensor(ctx, Shape{t, c}, kf32); + Tensor dout_t = createTensor(ctx, Shape{b * t * c}, kf32); + Tensor input = createTensor(ctx, Shape{b * t}, ki32); + op = createKernel(ctx, {kShaderEncoderBackward, 256, kf32}, + Bindings{dwte_t, dwpe_t, dout_t, input}, + /* nWorkgroups */ {cdiv(b * t, 256), 1, 1}, + /* params */ + EncoderParams{ + static_cast<uint32_t>(b), + static_cast<uint32_t>(t), + static_cast<uint32_t>(c) + }, + nullptr, + key.c_str()); + } else { + op = ctx.kernelPool.data[key]; + } + Tensor& dwte_t = ctx.pool.data[op->buffers[0]]; + Tensor& dwpe_t = ctx.pool.data[op->buffers[1]]; + Tensor& dout_t = ctx.pool.data[op->buffers[2]]; + Tensor& input = ctx.pool.data[op->buffers[3]]; + + toGPU(ctx, dwte, dwte_t); + toGPU(ctx, dwpe, dwpe_t); + toGPU(ctx, dout, dout_t); + toGPU(ctx, inp, input); + std::promise<void> promise; std::future<void> future = promise.get_future(); - Kernel op = createKernel(ctx, {kShaderEncoderBackward, 256, kf32}, - Bindings{dwte_t, dwpe_t, dout_t, input}, - /* nWorkgroups */ {cdiv(b * t, 256), 1, 1}, - /* params */ - EncoderParams{ - static_cast<uint32_t>(b), - static_cast<uint32_t>(t), - static_cast<uint32_t>(c) - }); dispatchKernel(ctx, op, promise); wait(ctx, future); toCPU(ctx, dwte_t, dwte, v * c * sizeof(float)); @@ -138,24 +198,44 @@ void LAYERNORM_FORWARD_GPU(float* out, float* mean, float* rstd, uint32_t C; }; setLogLevel(kError); - Context ctx = createContext(); - Tensor inp_t = createTensor(ctx, Shape{b * t * c}, kf32, inp); - Tensor weight_t = createTensor(ctx, Shape{c}, kf32, weight); - Tensor bias_t = createTensor(ctx, Shape{c}, kf32, bias); - Tensor out_t = createTensor(ctx, Shape{b * t * c}, kf32); - Tensor mean_t = createTensor(ctx, Shape{b * t}, kf32); - Tensor rstd_t = createTensor(ctx, Shape{b * t}, kf32); + + // Generate the key of the cache by arguments. 
+ std::string key = "LAYERNORM_FORWARD_GPU_" + std::to_string(B) + "_" + std::to_string(T) + "_" + std::to_string(C); + Kernel op; + if (ctx.kernelPool.data.find(key) == ctx.kernelPool.data.end()) { + Tensor inp_t = createTensor(ctx, Shape{b * t * c}, kf32); + Tensor weight_t = createTensor(ctx, Shape{c}, kf32); + Tensor bias_t = createTensor(ctx, Shape{c}, kf32); + Tensor out_t = createTensor(ctx, Shape{b * t * c}, kf32); + Tensor mean_t = createTensor(ctx, Shape{b * t}, kf32); + Tensor rstd_t = createTensor(ctx, Shape{b * t}, kf32); + op = createKernel(ctx, {kShaderLayerNorm, 256, kf32}, + Bindings{inp_t, weight_t, bias_t, out_t, mean_t, rstd_t}, + /* nWorkgroups */ {cdiv(b * t, 256), 1, 1}, + /* params */ + LayerNormParams{ + static_cast<uint32_t>(b), + static_cast<uint32_t>(t), + static_cast<uint32_t>(c) + }, + nullptr, + key.c_str()); + } else { + op = ctx.kernelPool.data[key]; + } + Tensor& inp_t = ctx.pool.data[op->buffers[0]]; + Tensor& weight_t = ctx.pool.data[op->buffers[1]]; + Tensor& bias_t = ctx.pool.data[op->buffers[2]]; + Tensor& out_t = ctx.pool.data[op->buffers[3]]; + Tensor& mean_t = ctx.pool.data[op->buffers[4]]; + Tensor& rstd_t = ctx.pool.data[op->buffers[5]]; + + toGPU(ctx, inp, inp_t); + toGPU(ctx, weight, weight_t); + toGPU(ctx, bias, bias_t); + std::promise<void> promise; std::future<void> future = promise.get_future(); - Kernel op = createKernel(ctx, {kShaderLayerNorm, 256, kf32}, - Bindings{inp_t, weight_t, bias_t, out_t, mean_t, rstd_t}, - /* nWorkgroups */ {cdiv(b * t, 256), 1, 1}, - /* params */ - LayerNormParams{ - static_cast<uint32_t>(b), - static_cast<uint32_t>(t), - static_cast<uint32_t>(c) - }); dispatchKernel(ctx, op, promise); wait(ctx, future); toCPU(ctx, out_t, out, b * t * c * sizeof(float)); @@ -175,26 +255,53 @@ void LAYERNORM_BACKWARD_GPU(float* dinp, float* dweight, float* dbias, uint32_t C; }; setLogLevel(kError); - Context ctx = createContext(); - Tensor dinp_t = createTensor(ctx, Shape{b * t * c}, kf32, dinp); - Tensor dweight_t = createTensor(ctx, Shape{c}, kf32, dweight); - Tensor dbias_t = createTensor(ctx, Shape{c}, kf32, dbias); - Tensor dout_t = createTensor(ctx, Shape{b * t * c}, kf32, dout); - Tensor inp_t = createTensor(ctx, Shape{b * t * c}, kf32, inp); - Tensor weight_t = createTensor(ctx, Shape{c}, kf32, weight); - Tensor mean_t = createTensor(ctx, Shape{b * t}, kf32, mean); - Tensor rstd_t = createTensor(ctx, Shape{b * t}, kf32, rstd); + + // Generate the key of the cache by arguments. 
+ std::string key = "LAYERNORM_BACKWARD_GPU_" + std::to_string(B) + "_" + std::to_string(T) + "_" + std::to_string(C); + Kernel op; + if (ctx.kernelPool.data.find(key) == ctx.kernelPool.data.end()) { + Tensor dinp_t = createTensor(ctx, Shape{b * t * c}, kf32); + Tensor dweight_t = createTensor(ctx, Shape{c}, kf32); + Tensor dbias_t = createTensor(ctx, Shape{c}, kf32); + Tensor dout_t = createTensor(ctx, Shape{b * t * c}, kf32); + Tensor inp_t = createTensor(ctx, Shape{b * t * c}, kf32); + Tensor weight_t = createTensor(ctx, Shape{c}, kf32); + Tensor mean_t = createTensor(ctx, Shape{b * t}, kf32); + Tensor rstd_t = createTensor(ctx, Shape{b * t}, kf32); + op = createKernel(ctx, {kShaderLayerNormBackward, 256, kf32}, + Bindings{dinp_t, dweight_t, dbias_t, dout_t, inp_t, weight_t, mean_t, rstd_t}, + /* nWorkgroups */ {cdiv(b * t, 256), 1, 1}, + /* params */ + LayerNormParams{ + static_cast<uint32_t>(b), + static_cast<uint32_t>(t), + static_cast<uint32_t>(c) + }, + nullptr, + key.c_str()); + } else { + op = ctx.kernelPool.data[key]; + } + Tensor& dinp_t = ctx.pool.data[op->buffers[0]]; + Tensor& dweight_t = ctx.pool.data[op->buffers[1]]; + Tensor& dbias_t = ctx.pool.data[op->buffers[2]]; + Tensor& dout_t = ctx.pool.data[op->buffers[3]]; + Tensor& inp_t = ctx.pool.data[op->buffers[4]]; + Tensor& weight_t = ctx.pool.data[op->buffers[5]]; + Tensor& mean_t = ctx.pool.data[op->buffers[6]]; + Tensor& rstd_t = ctx.pool.data[op->buffers[7]]; + + toGPU(ctx, dinp, dinp_t); + toGPU(ctx, dweight, dweight_t); + toGPU(ctx, dbias, dbias_t); + toGPU(ctx, dout, dout_t); + toGPU(ctx, inp, inp_t); + toGPU(ctx, weight, weight_t); + toGPU(ctx, mean, mean_t); + toGPU(ctx, rstd, rstd_t); + std::promise<void> promise; std::future<void> future = promise.get_future(); - Kernel op = createKernel(ctx, {kShaderLayerNormBackward, 256, kf32}, - Bindings{dinp_t, dweight_t, dbias_t, dout_t, inp_t, weight_t, mean_t, rstd_t}, - /* nWorkgroups */ {cdiv(b * t, 256), 1, 1}, - /* params */ - LayerNormParams{ - static_cast<uint32_t>(b), - static_cast<uint32_t>(t), - static_cast<uint32_t>(c) - }); dispatchKernel(ctx, op, promise); wait(ctx, future); toCPU(ctx, dinp_t, dinp, b * t * c * sizeof(float)); @@ -202,9 +309,29 @@ void LAYERNORM_BACKWARD_GPU(float* dinp, float* dweight, float* dbias, toCPU(ctx, dbias_t, dbias, c * sizeof(float)); } +void matmul_forward_dummy(float* out, + const float* inp, const float* weight, const float* bias, + int B, int T, int C, int OC); + + void MATMUL_FORWARD_GPU(float* out, const float* inp, const float* weight, const float* bias, int B, int T, int C, int OC){ + int version = 2; + bool verbose = false; + bool debug = false; + float *out_exp; + DurationTime duration("matmul_forward_gpu with preparing a kernel", verbose); + if (verbose) { + printf("matmul forward: B=%d, T=%d, C=%d, OC=%d, bias=%d\n", B, T, C, OC, bias != NULL); + } + if (debug) { + out_exp = new float[B*T*OC]; + { + DurationTime duration("matmul_forward_cpu", verbose); + matmul_forward_dummy(out_exp, inp, weight, bias, B, T, C, OC); + } + } struct MatmulParams { uint32_t B; uint32_t T; @@ -216,31 +343,132 @@ void MATMUL_FORWARD_GPU(float* out, unsigned long c = static_cast<unsigned long>(C); unsigned long oc = static_cast<unsigned long>(OC); setLogLevel(kError); - WGPURequiredLimits requiredLimits = LIMITS_BUFFER_SIZE_1GB; - Context ctx = createContext({},{},{ - .requiredLimits = &requiredLimits - }); - - Tensor inp_i = createTensor(ctx, Shape{b * t * c}, kf32, inp); - Tensor weight_i = createTensor(ctx, Shape{oc * c}, kf32, weight); - 
Tensor bias_i = bias == NULL ? createTensor(ctx, Shape{1}, kf32) : createTensor(ctx, Shape{oc}, kf32, bias); - Tensor out_o = createTensor(ctx, Shape{b * t * oc}, kf32); - std::promise<void> promise; - std::future<void> future = promise.get_future(); - assert ( (b*t) % 256 == 0 ); - Kernel op = createKernel(ctx, {kShaderMatmul, 256, kf32}, - Bindings{inp_i, weight_i, bias_i, out_o}, - /* nWorkgroups */ {cdiv(b * t, 256), 1, 1}, - /* params */ - MatmulParams{ - static_cast<uint32_t>(b), - static_cast<uint32_t>(t), - static_cast<uint32_t>(c), - static_cast<uint32_t>(oc) - }); - dispatchKernel(ctx, op, promise); - wait(ctx, future); - toCPU(ctx, out_o, out, b * t * oc * sizeof(float)); + + if (version == 2 || version == 1) { + // Generate the key of the cache by arguments. + std::string key = "MATMUL_FORWARD_GPU_" + std::to_string(B) + "_" + std::to_string(T) + "_" + std::to_string(C) + "_" + std::to_string(OC); + Kernel op; + if (ctx.kernelPool.data.find(key) == ctx.kernelPool.data.end()) { + Tensor inp_i = createTensor(ctx, Shape{b * t * c}, kf32); + Tensor weight_i = createTensor(ctx, Shape{oc * c}, kf32); + Tensor bias_i = bias == NULL ? createTensor(ctx, Shape{1}, kf32) : createTensor(ctx, Shape{oc}, kf32); + Tensor out_o = createTensor(ctx, Shape{b * t * oc}, kf32); + + if (version == 2) { + constexpr size_t BT = 64; + constexpr size_t BC = 16; + constexpr size_t BOC = 64; + constexpr size_t TT = BT / BC; + constexpr size_t TOC = BOC / BC; + constexpr size_t num_threads = BT * BOC / (TT * TOC); + Shape wgSize = {num_threads, 1, 1}; + + std::string codeString(kShaderMatmul2DTiling); + std::string unrolledCode = loopUnrolling(replaceAll(codeString, {{"{{precision}}", toString(kf32)}, + {"{{BT}}", toString(BT)}, + {"{{BC}}", toString(BC)}, + {"{{BOC}}", toString(BOC)}, + {"{{TT}}", toString(TT)}, + {"{{TOC}}", toString(TOC)}, + {"{{NUM_TILEI}}", toString(BT * BC / num_threads)}, + {"{{NUM_TILEW}}", toString(BOC * BC / num_threads)} + })); + + Shape nWorkgroups = {b, cdiv(T, BT), cdiv(OC, BOC)}; + op = createKernel(ctx, {unrolledCode, wgSize, kf32}, + Bindings{inp_i, weight_i, bias_i, out_o}, + nWorkgroups, + /* params */ + MatmulParams{ + static_cast<uint32_t>(b), + static_cast<uint32_t>(t), + static_cast<uint32_t>(c), + static_cast<uint32_t>(oc) + }, + nullptr, + key.c_str() + ); + } else { + op = createKernel(ctx, {kShaderMatmul, 256, kf32}, + Bindings{inp_i, weight_i, bias_i, out_o}, + /* nWorkgroups */ {cdiv(b * t, 256), 1, 1}, + /* params */ + MatmulParams{ + static_cast<uint32_t>(b), + static_cast<uint32_t>(t), + static_cast<uint32_t>(c), + static_cast<uint32_t>(oc) + }, + nullptr, + key.c_str() + ); + } + } else { + op = ctx.kernelPool.data[key]; + } + Tensor& inp_i = ctx.pool.data[op->buffers[0]]; + Tensor& weight_i = ctx.pool.data[op->buffers[1]]; + Tensor& bias_i = ctx.pool.data[op->buffers[2]]; + Tensor& out_o = ctx.pool.data[op->buffers[3]]; + + toGPU(ctx, inp, inp_i); + toGPU(ctx, weight, weight_i); + if (bias != NULL) { + toGPU(ctx, bias, bias_i); + } + + std::promise<void> promise; + std::future<void> future = promise.get_future(); + + { + DurationTime duration("matmul_forward_gpu", verbose); + dispatchKernel(ctx, op, promise); + wait(ctx, future); + } + toCPU(ctx, out_o, out, b * t * oc * sizeof(float)); + } else { + DurationTime duration("matmul_forward_cpu", verbose); + matmul_forward_dummy(out, inp, weight, bias, B, T, C, OC); + } + + if (debug) { // compare out with out_exp. 
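+    // Debug path: element-wise comparison of the GPU result against the CPU reference
+    // (matmul_forward_dummy); on the first difference above 1e-2 it dumps 4x4 corners of
+    // inp, weight, out, and out_exp, then exits.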
+ for (int i = 0; i < B*T*OC; i++) { + if (fabs(out[i] - out_exp[i]) > 1e-2) { + printf("matmul forward: out[%d] = %f, out_exp[%d] = %f\n", i, out[i], i, out_exp[i]); + //Dump the first 4 x 4 elements by table, at first output out, then output out_exp + printf("inp:\n"); + for (int j = 0; j < 4; j++) { + for (int k = 0; k < 4; k++) { + printf("%f ", inp[j * C + k]); + } + printf("\n"); + } + printf("weight:\n"); + for (int j = 0; j < 4; j++) { + for (int k = 0; k < 4; k++) { + printf("%f ", weight[j * OC + k]); + } + printf("\n"); + } + printf("out:\n"); + for (int j = 0; j < 4; j++) { + for (int k = 0; k < 4; k++) { + printf("%f ", out[j * OC + k]); + } + printf("\n"); + } + printf("out_exp:\n"); + for (int j = 0; j < 4; j++) { + for (int k = 0; k < 4; k++) { + printf("%f ", out_exp[j * OC + k]); + } + printf("\n"); + } + exit(1); + } + } + delete[] out_exp; + } } void MATMUL_BACKWARD_GPU(float* dinp, float* dweight, float* dbias, @@ -257,28 +485,48 @@ void MATMUL_BACKWARD_GPU(float* dinp, float* dweight, float* dbias, unsigned long c = static_cast<unsigned long>(C); unsigned long oc = static_cast<unsigned long>(OC); setLogLevel(kError); - WGPURequiredLimits requiredLimits = LIMITS_BUFFER_SIZE_1GB; - Context ctx = createContext({},{},{ - .requiredLimits = &requiredLimits - }); - Tensor dinp_t = createTensor(ctx, Shape{b * t * c}, kf32, dinp); - Tensor dweight_t = createTensor(ctx, Shape{oc * c}, kf32, dweight); - Tensor dbias_t = createTensor(ctx, Shape{oc}, kf32, dbias); - Tensor dout_t = createTensor(ctx, Shape{b * t * oc}, kf32, dout); - Tensor inp_t = createTensor(ctx, Shape{b * t * c}, kf32, inp); - Tensor weight_t = createTensor(ctx, Shape{oc * c}, kf32, weight); + + // Generate the key of the cache by arguments. + std::string key = "MATMUL_BACKWARD_GPU_" + std::to_string(B) + "_" + std::to_string(T) + "_" + std::to_string(C) + "_" + std::to_string(OC); + Kernel op; + if (ctx.kernelPool.data.find(key) == ctx.kernelPool.data.end()) { + Tensor dinp_t = createTensor(ctx, Shape{b * t * c}, kf32); + Tensor dweight_t = createTensor(ctx, Shape{oc * c}, kf32); + Tensor dbias_t = createTensor(ctx, Shape{oc}, kf32); + Tensor dout_t = createTensor(ctx, Shape{b * t * oc}, kf32); + Tensor inp_t = createTensor(ctx, Shape{b * t * c}, kf32); + Tensor weight_t = createTensor(ctx, Shape{oc * c}, kf32); + op = createKernel(ctx, {kShaderMatmulBackward, 256, kf32}, + Bindings{dinp_t, dweight_t, dbias_t, dout_t, inp_t, weight_t}, + /* nWorkgroups */ {cdiv(b * t, 256), 1, 1}, + /* params */ + MatmulParams{ + static_cast<uint32_t>(b), + static_cast<uint32_t>(t), + static_cast<uint32_t>(c), + static_cast<uint32_t>(oc) + }, + nullptr, + key.c_str()); + } else { + op = ctx.kernelPool.data[key]; + } + Tensor& dinp_t = ctx.pool.data[op->buffers[0]]; + Tensor& dweight_t = ctx.pool.data[op->buffers[1]]; + Tensor& dbias_t = ctx.pool.data[op->buffers[2]]; + Tensor& dout_t = ctx.pool.data[op->buffers[3]]; + Tensor& inp_t = ctx.pool.data[op->buffers[4]]; + Tensor& weight_t = ctx.pool.data[op->buffers[5]]; + + toGPU(ctx, dinp, dinp_t); + toGPU(ctx, dweight, dweight_t); + toGPU(ctx, dbias, dbias_t); + toGPU(ctx, dout, dout_t); + toGPU(ctx, inp, inp_t); + toGPU(ctx, weight, weight_t); + std::promise<void> promise; std::future<void> future = promise.get_future(); - Kernel op = createKernel(ctx, {kShaderMatmulBackward, 256, kf32}, - Bindings{dinp_t, dweight_t, dbias_t, dout_t, inp_t, weight_t}, - /* nWorkgroups */ {cdiv(b * t, 256), 1, 1}, - /* params */ - MatmulParams{ - static_cast<uint32_t>(b), - 
-                             static_cast<uint32_t>(c),
-                             static_cast<uint32_t>(oc)
-                           });
   dispatchKernel(ctx, op, promise);
   wait(ctx, future);
   toCPU(ctx, dinp_t, dinp, b * t * c * sizeof(float));
@@ -300,23 +548,41 @@ void ATTENTION_FORWARD_GPU(float* out, float* preatt, float* att,
   unsigned long c = static_cast<unsigned long>(C);
   unsigned long nh = static_cast<unsigned long>(NH);
   setLogLevel(kError);
-  Context ctx = createContext();
-  Tensor inp_t = createTensor(ctx, Shape{b * t * c * 3}, kf32, inp);
-  Tensor preatt_t = createTensor(ctx, Shape{b * nh * t * t}, kf32, preatt);
-  Tensor att_t = createTensor(ctx, Shape{b * nh * t * t}, kf32, att);
-  Tensor out_t = createTensor(ctx, Shape{b * t * c}, kf32);
+
+  // Generate the key of the cache by arguments.
+  std::string key = "ATTENTION_FORWARD_GPU_" + std::to_string(B) + "_" + std::to_string(T) + "_" + std::to_string(C) + "_" + std::to_string(NH);
+  Kernel op;
+  if (ctx.kernelPool.data.find(key) == ctx.kernelPool.data.end()) {
+    Tensor inp_t = createTensor(ctx, Shape{b * t * c * 3}, kf32);
+    Tensor preatt_t = createTensor(ctx, Shape{b * nh * t * t}, kf32);
+    Tensor att_t = createTensor(ctx, Shape{b * nh * t * t}, kf32);
+    Tensor out_t = createTensor(ctx, Shape{b * t * c}, kf32);
+    op = createKernel(ctx, {kShaderAttention, 256, kf32},
+                      Bindings{inp_t, preatt_t, att_t, out_t},
+                      /* nWorkgroups */ {cdiv(b * t, 256), 1, 1},
+                      /* params */
+                      AttentionParams{
+                        static_cast<uint32_t>(b),
+                        static_cast<uint32_t>(t),
+                        static_cast<uint32_t>(c),
+                        static_cast<uint32_t>(nh)
+                      },
+                      nullptr,
+                      key.c_str());
+  } else {
+    op = ctx.kernelPool.data[key];
+  }
+  Tensor& inp_t = ctx.pool.data[op->buffers[0]];
+  Tensor& preatt_t = ctx.pool.data[op->buffers[1]];
+  Tensor& att_t = ctx.pool.data[op->buffers[2]];
+  Tensor& out_t = ctx.pool.data[op->buffers[3]];
+
+  toGPU(ctx, inp, inp_t);
+  toGPU(ctx, preatt, preatt_t);
+  toGPU(ctx, att, att_t);
+
   std::promise<void> promise;
   std::future<void> future = promise.get_future();
-  Kernel op = createKernel(ctx, {kShaderAttention, 256, kf32},
-                           Bindings{inp_t, preatt_t, att_t, out_t},
-                           /* nWorkgroups */ {cdiv(b * t, 256), 1, 1},
-                           /* params */
-                           AttentionParams{
-                             static_cast<uint32_t>(b),
-                             static_cast<uint32_t>(t),
-                             static_cast<uint32_t>(c),
-                             static_cast<uint32_t>(nh)
-                           });
   dispatchKernel(ctx, op, promise);
   wait(ctx, future);
   toCPU(ctx, preatt_t, preatt, b * nh * t * t * sizeof(float));
@@ -338,25 +604,48 @@ void ATTENTION_BACKWARD_GPU(float* dinp, float* dpreatt, float* datt,
   unsigned long c = static_cast<unsigned long>(C);
   unsigned long nh = static_cast<unsigned long>(NH);
   setLogLevel(kError);
-  Context ctx = createContext();
-  Tensor dinp_t = createTensor(ctx, Shape{b * t * c * 3}, kf32, dinp);
-  Tensor dpreatt_t = createTensor(ctx, Shape{b * nh * t * t}, kf32, dpreatt);
-  Tensor datt_t = createTensor(ctx, Shape{b * nh * t * t}, kf32, datt);
-  Tensor dout_t = createTensor(ctx, Shape{b * t * c}, kf32, dout);
-  Tensor inp_t = createTensor(ctx, Shape{b * t * c * 3}, kf32, inp);
-  Tensor att_t = createTensor(ctx, Shape{b * nh * t * t}, kf32, att);
+
+  // Generate the key of the cache by arguments.
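+  // Note: as in the other ops in this file, the kernel and its six device buffers are created once
+  // per (B, T, C, NH) and then reused from ctx.kernelPool on later calls; only the host-to-GPU
+  // uploads and the dispatch below run every time.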
+  std::string key = "ATTENTION_BACKWARD_GPU_" + std::to_string(B) + "_" + std::to_string(T) + "_" + std::to_string(C) + "_" + std::to_string(NH);
+  Kernel op;
+  if (ctx.kernelPool.data.find(key) == ctx.kernelPool.data.end()) {
+    Tensor dinp_t = createTensor(ctx, Shape{b * t * c * 3}, kf32);
+    Tensor dpreatt_t = createTensor(ctx, Shape{b * nh * t * t}, kf32);
+    Tensor datt_t = createTensor(ctx, Shape{b * nh * t * t}, kf32);
+    Tensor dout_t = createTensor(ctx, Shape{b * t * c}, kf32);
+    Tensor inp_t = createTensor(ctx, Shape{b * t * c * 3}, kf32);
+    Tensor att_t = createTensor(ctx, Shape{b * nh * t * t}, kf32);
+    op = createKernel(ctx, {kShaderAttentionBackward, 256, kf32},
+                      Bindings{dinp_t, dpreatt_t, datt_t, dout_t, inp_t, att_t},
+                      /* nWorkgroups */ {cdiv(b * t, 256), 1, 1},
+                      /* params */
+                      AttentionParams{
+                        static_cast<uint32_t>(b),
+                        static_cast<uint32_t>(t),
+                        static_cast<uint32_t>(c),
+                        static_cast<uint32_t>(nh)
+                      },
+                      nullptr,
+                      key.c_str());
+  } else {
+    op = ctx.kernelPool.data[key];
+  }
+  Tensor& dinp_t = ctx.pool.data[op->buffers[0]];
+  Tensor& dpreatt_t = ctx.pool.data[op->buffers[1]];
+  Tensor& datt_t = ctx.pool.data[op->buffers[2]];
+  Tensor& dout_t = ctx.pool.data[op->buffers[3]];
+  Tensor& inp_t = ctx.pool.data[op->buffers[4]];
+  Tensor& att_t = ctx.pool.data[op->buffers[5]];
+
+  toGPU(ctx, dinp, dinp_t);
+  toGPU(ctx, dpreatt, dpreatt_t);
+  toGPU(ctx, datt, datt_t);
+  toGPU(ctx, dout, dout_t);
+  toGPU(ctx, inp, inp_t);
+  toGPU(ctx, att, att_t);
+
   std::promise<void> promise;
   std::future<void> future = promise.get_future();
-  Kernel op = createKernel(ctx, {kShaderAttentionBackward, 256, kf32},
-                           Bindings{dinp_t, dpreatt_t, datt_t, dout_t, inp_t, att_t},
-                           /* nWorkgroups */ {cdiv(b * t, 256), 1, 1},
-                           /* params */
-                           AttentionParams{
-                             static_cast<uint32_t>(b),
-                             static_cast<uint32_t>(t),
-                             static_cast<uint32_t>(c),
-                             static_cast<uint32_t>(nh)
-                           });
   dispatchKernel(ctx, op, promise);
   wait(ctx, future);
   toCPU(ctx, dinp_t, dinp, b * t * c * 3 * sizeof(float));
@@ -367,14 +656,29 @@ void ATTENTION_BACKWARD_GPU(float* dinp, float* dpreatt, float* datt,
 
 void GELU_FORWARD_GPU(float* out, float* inp, int n) {
   unsigned long N = static_cast<unsigned long>(n);
   setLogLevel(kError);
-  Context ctx = createContext();
-  Tensor input = createTensor(ctx, Shape{N}, kf32, inp);
-  Tensor output = createTensor(ctx, Shape{N}, kf32);
+
+  // Generate the key of the cache by arguments.
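+  // GELU takes no uniform params struct, so the kernel is keyed on the element count alone and
+  // createKernel is passed null params/compilation info along with the cache key.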
+  std::string key = "GELU_FORWARD_GPU_" + std::to_string(n);
+  Kernel op;
+  if (ctx.kernelPool.data.find(key) == ctx.kernelPool.data.end()) {
+    Tensor input = createTensor(ctx, Shape{N}, kf32);
+    Tensor output = createTensor(ctx, Shape{N}, kf32);
+    op = createKernel(ctx, {kShaderGelu, 256, kf32},
+                      Bindings{input, output},
+                      /* nWorkgroups */ {cdiv(N, 256), 1, 1},
+                      nullptr,
+                      nullptr,
+                      key.c_str());
+  } else {
+    op = ctx.kernelPool.data[key];
+  }
+  Tensor& input = ctx.pool.data[op->buffers[0]];
+  Tensor& output = ctx.pool.data[op->buffers[1]];
+
+  toGPU(ctx, inp, input);
+
   std::promise<void> promise;
   std::future<void> future = promise.get_future();
-  Kernel op = createKernel(ctx, {kShaderGelu, 256, kf32},
-                           Bindings{input, output},
-                           /* nWorkgroups */ {cdiv(N, 256), 1, 1});
   dispatchKernel(ctx, op, promise);
   wait(ctx, future);
   toCPU(ctx, output, out, N * sizeof(float));
@@ -383,15 +687,33 @@ void GELU_FORWARD_GPU(float* out, float* inp, int n) {
 
 void GELU_BACKWARD_GPU(float* dinp, float* inp, float* dout, int N){
   unsigned long n = static_cast<unsigned long>(N);
   setLogLevel(kError);
-  Context ctx = createContext();
-  Tensor inp_i = createTensor(ctx, Shape{n}, kf32, inp);
-  Tensor dout_i = createTensor(ctx, Shape{n}, kf32, dout);
-  Tensor dinp_o = createTensor(ctx, Shape{n}, kf32, dinp);
+
+  // Generate the key of the cache by arguments.
+  std::string key = "GELU_BACKWARD_GPU_" + std::to_string(N);
+  Kernel op;
+  if (ctx.kernelPool.data.find(key) == ctx.kernelPool.data.end()) {
+    Tensor inp_i = createTensor(ctx, Shape{n}, kf32);
+    Tensor dout_i = createTensor(ctx, Shape{n}, kf32);
+    Tensor dinp_o = createTensor(ctx, Shape{n}, kf32);
+    op = createKernel(ctx, {kShaderGeluBackward, 256, kf32},
+                      Bindings{inp_i, dout_i, dinp_o},
+                      /* nWorkgroups */ {cdiv(n, 256), 1, 1},
+                      nullptr,
+                      nullptr,
+                      key.c_str());
+  } else {
+    op = ctx.kernelPool.data[key];
+  }
+  Tensor& inp_i = ctx.pool.data[op->buffers[0]];
+  Tensor& dout_i = ctx.pool.data[op->buffers[1]];
+  Tensor& dinp_o = ctx.pool.data[op->buffers[2]];
+
+  toGPU(ctx, inp, inp_i);
+  toGPU(ctx, dout, dout_i);
+  toGPU(ctx, dinp, dinp_o);
+
   std::promise<void> promise;
   std::future<void> future = promise.get_future();
-  Kernel op = createKernel(ctx, {kShaderGeluBackward, 256, kf32},
-                           Bindings{inp_i, dout_i, dinp_o},
-                           /* nWorkgroups */ {cdiv(n, 256), 1, 1});
   dispatchKernel(ctx, op, promise);
   wait(ctx, future);
   toCPU(ctx, dinp_o, dinp, n * sizeof(float));
@@ -400,15 +722,32 @@ void GELU_BACKWARD_GPU(float* dinp, float* inp, float* dout, int N){
 
 void RESIDUAL_FORWARD_GPU(float* out, float* inp1, float* inp2, int N){
   unsigned long n = static_cast<unsigned long>(N);
   setLogLevel(kError);
-  Context ctx = createContext();
-  Tensor inp1_i = createTensor(ctx, Shape{n}, kf32, inp1);
-  Tensor inp2_i = createTensor(ctx, Shape{n}, kf32, inp2);
-  Tensor out_o = createTensor(ctx, Shape{n}, kf32);
+
+  // Generate the key of the cache by arguments.
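+  // Only inp1 and inp2 are uploaded here; out_o stays resident on the GPU until the toCPU readback.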
+  std::string key = "RESIDUAL_FORWARD_GPU_" + std::to_string(N);
+  Kernel op;
+  if (ctx.kernelPool.data.find(key) == ctx.kernelPool.data.end()) {
+    Tensor inp1_i = createTensor(ctx, Shape{n}, kf32);
+    Tensor inp2_i = createTensor(ctx, Shape{n}, kf32);
+    Tensor out_o = createTensor(ctx, Shape{n}, kf32);
+    op = createKernel(ctx, {kShaderResidual, 256, kf32},
+                      Bindings{inp1_i, inp2_i, out_o},
+                      /* nWorkgroups */ {cdiv(n, 256), 1, 1},
+                      nullptr,
+                      nullptr,
+                      key.c_str());
+  } else {
+    op = ctx.kernelPool.data[key];
+  }
+  Tensor& inp1_i = ctx.pool.data[op->buffers[0]];
+  Tensor& inp2_i = ctx.pool.data[op->buffers[1]];
+  Tensor& out_o = ctx.pool.data[op->buffers[2]];
+
+  toGPU(ctx, inp1, inp1_i);
+  toGPU(ctx, inp2, inp2_i);
+
   std::promise<void> promise;
   std::future<void> future = promise.get_future();
-  Kernel op = createKernel(ctx, {kShaderResidual, 256, kf32},
-                           Bindings{inp1_i, inp2_i, out_o},
-                           /* nWorkgroups */ {cdiv(n, 256), 1, 1});
   dispatchKernel(ctx, op, promise);
   wait(ctx, future);
   toCPU(ctx, out_o, out, n * sizeof(float));
@@ -417,15 +756,33 @@ void RESIDUAL_FORWARD_GPU(float* out, float* inp1, float* inp2, int N){
 
 void RESIDUAL_BACKWARD_GPU(float* dinp1, float* dinp2, float* dout, int N){
   unsigned long n = static_cast<unsigned long>(N);
   setLogLevel(kError);
-  Context ctx = createContext();
-  Tensor dout_i = createTensor(ctx, Shape{n}, kf32, dout);
-  Tensor dinp1_o = createTensor(ctx, Shape{n}, kf32, dinp1);
-  Tensor dinp2_o = createTensor(ctx, Shape{n}, kf32, dinp2);
+
+  // Generate the key of the cache by arguments.
+  std::string key = "RESIDUAL_BACKWARD_GPU_" + std::to_string(N);
+  Kernel op;
+  if (ctx.kernelPool.data.find(key) == ctx.kernelPool.data.end()) {
+    Tensor dout_i = createTensor(ctx, Shape{n}, kf32);
+    Tensor dinp1_o = createTensor(ctx, Shape{n}, kf32);
+    Tensor dinp2_o = createTensor(ctx, Shape{n}, kf32);
+    op = createKernel(ctx, {kShaderResidualBackward, 256, kf32},
+                      Bindings{dout_i, dinp1_o, dinp2_o},
+                      /* nWorkgroups */ {cdiv(n, 256), 1, 1},
+                      nullptr,
+                      nullptr,
+                      key.c_str());
+  } else {
+    op = ctx.kernelPool.data[key];
+  }
+  Tensor& dout_i = ctx.pool.data[op->buffers[0]];
+  Tensor& dinp1_o = ctx.pool.data[op->buffers[1]];
+  Tensor& dinp2_o = ctx.pool.data[op->buffers[2]];
+
+  toGPU(ctx, dout, dout_i);
+  toGPU(ctx, dinp1, dinp1_o);
+  toGPU(ctx, dinp2, dinp2_o);
+
   std::promise<void> promise;
   std::future<void> future = promise.get_future();
-  Kernel op = createKernel(ctx, {kShaderResidualBackward, 256, kf32},
-                           Bindings{dout_i, dinp1_o, dinp2_o},
-                           /* nWorkgroups */ {cdiv(n, 256), 1, 1});
   dispatchKernel(ctx, op, promise);
   wait(ctx, future);
   toCPU(ctx, dinp1_o, dinp1, n * sizeof(float));
@@ -442,15 +799,29 @@ void SOFTMAX_FORWARD_GPU(float* probs, float* logits, int B, int T, int V, int V
   uint32_t t = static_cast<uint32_t>(T);
   uint32_t c = static_cast<uint32_t>(V);
   uint32_t cp = static_cast<uint32_t>(Vp);
-  Context ctx = createContext();
-  Tensor input = createTensor(ctx, {b * t, cp}, kf32, logits);
-  Tensor output = createTensor(ctx, {b * t, cp}, kf32);
+
+  // Generate the key of the cache by arguments.
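+  // Vp (the padded vocab size) is part of the key because it determines the buffer shapes; the
+  // assert below still assumes B*T is a multiple of the 256-thread workgroup.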
+  std::string key = "SOFTMAX_FORWARD_GPU_" + std::to_string(B) + "_" + std::to_string(T) + "_" + std::to_string(V) + "_" + std::to_string(Vp);
+  Kernel op;
+  if (ctx.kernelPool.data.find(key) == ctx.kernelPool.data.end()) {
+    Tensor input = createTensor(ctx, {b * t, cp}, kf32);
+    Tensor output = createTensor(ctx, {b * t, cp}, kf32);
+    assert( (B*T) % 256 == 0);
+    op = createKernel(
+        ctx, {kShaderSoftmax1, 256, kf32}, Bindings{input, output},
+        Shape{cdiv(B * T, 256), 1, 1}, SoftmaxParam{b * t, c, cp},
+        nullptr,
+        key.c_str());
+  } else {
+    op = ctx.kernelPool.data[key];
+  }
+  Tensor& input = ctx.pool.data[op->buffers[0]];
+  Tensor& output = ctx.pool.data[op->buffers[1]];
+
+  toGPU(ctx, logits, input);
+
   std::promise<void> promise;
   std::future<void> future = promise.get_future();
-  assert( (B*T) % 256 == 0);
-  Kernel op = createKernel(
-      ctx, {kShaderSoftmax1, 256, kf32}, Bindings{input, output},
-      Shape{cdiv(B * T, 256), 1, 1}, SoftmaxParam{b * t, c, cp});
   dispatchKernel(ctx, op, promise);
   wait(ctx, future);
   toCPU(ctx, output, probs, sizeof(float)*b*t*cp);
@@ -468,21 +839,38 @@ void CROSSENTROPY_FORWARD_GPU(float* losses,
   unsigned long t = static_cast<unsigned long>(T);
   unsigned long vp = static_cast<unsigned long>(Vp);
   setLogLevel(kError);
-  Context ctx = createContext();
-  Tensor losses_t = createTensor(ctx, Shape{b * t}, kf32, losses);
-  Tensor probs_t = createTensor(ctx, Shape{b * t * vp}, kf32, probs);
-  Tensor targets_t = createTensor(ctx, Shape{b * t}, ki32, targets);
+
+  // Generate the key of the cache by arguments.
+  std::string key = "CROSSENTROPY_FORWARD_GPU_" + std::to_string(B) + "_" + std::to_string(T) + "_" + std::to_string(Vp);
+  Kernel op;
+  if (ctx.kernelPool.data.find(key) == ctx.kernelPool.data.end()) {
+    Tensor losses_t = createTensor(ctx, Shape{b * t}, kf32);
+    Tensor probs_t = createTensor(ctx, Shape{b * t * vp}, kf32);
+    Tensor targets_t = createTensor(ctx, Shape{b * t}, ki32);
+    op = createKernel(ctx, {kShaderCrossEntropyForward, 256, kf32},
+                      Bindings{losses_t, probs_t, targets_t},
+                      /* nWorkgroups */ {cdiv(b * t, 256), 1, 1},
+                      /* params */
+                      CrossEntropyParams{
+                        static_cast<uint32_t>(b),
+                        static_cast<uint32_t>(t),
+                        static_cast<uint32_t>(vp)
+                      },
+                      nullptr,
+                      key.c_str());
+  } else {
+    op = ctx.kernelPool.data[key];
+  }
+  Tensor& losses_t = ctx.pool.data[op->buffers[0]];
+  Tensor& probs_t = ctx.pool.data[op->buffers[1]];
+  Tensor& targets_t = ctx.pool.data[op->buffers[2]];
+
+  toGPU(ctx, losses, losses_t);
+  toGPU(ctx, probs, probs_t);
+  toGPU(ctx, targets, targets_t);
+
   std::promise<void> promise;
   std::future<void> future = promise.get_future();
-  Kernel op = createKernel(ctx, {kShaderCrossEntropyForward, 256, kf32},
-                           Bindings{losses_t, probs_t, targets_t},
-                           /* nWorkgroups */ {cdiv(b * t, 256), 1, 1},
-                           /* params */
-                           CrossEntropyParams{
-                             static_cast<uint32_t>(b),
-                             static_cast<uint32_t>(t),
-                             static_cast<uint32_t>(vp)
-                           });
   dispatchKernel(ctx, op, promise);
   wait(ctx, future);
   toCPU(ctx, losses_t, losses, b * t * sizeof(float));
@@ -502,23 +890,42 @@ void CROSSENTROPY_SOFTMAX_BACKWARD_GPU(float* dlogits,
   unsigned long v = static_cast<unsigned long>(V);
   unsigned long vp = static_cast<unsigned long>(Vp);
   setLogLevel(kError);
-  Context ctx = createContext();
-  Tensor dlogits_t = createTensor(ctx, Shape{b * t * vp}, kf32, dlogits);
-  Tensor dlosses_t = createTensor(ctx, Shape{b * t}, kf32, dlosses);
-  Tensor probs_t = createTensor(ctx, Shape{b * t * vp}, kf32, probs);
-  Tensor targets_t = createTensor(ctx, Shape{b * t}, ki32, targets);
+
+  // Generate the key of the cache by arguments.
+  std::string key = "CROSSENTROPY_SOFTMAX_BACKWARD_GPU_" + std::to_string(B) + "_" + std::to_string(T) + "_" + std::to_string(V) + "_" + std::to_string(Vp);
+  Kernel op;
+  if (ctx.kernelPool.data.find(key) == ctx.kernelPool.data.end()) {
+    Tensor dlogits_t = createTensor(ctx, Shape{b * t * vp}, kf32);
+    Tensor dlosses_t = createTensor(ctx, Shape{b * t}, kf32);
+    Tensor probs_t = createTensor(ctx, Shape{b * t * vp}, kf32);
+    Tensor targets_t = createTensor(ctx, Shape{b * t}, ki32);
+    op = createKernel(ctx, {kShaderCrossEntropySoftmaxBackward, 256, kf32},
+                      Bindings{dlogits_t, dlosses_t, probs_t, targets_t},
+                      /* nWorkgroups */ {cdiv(b * t, 256), 1, 1},
+                      /* params */
+                      CrossEntropySoftmaxBackwardParams{
+                        static_cast<uint32_t>(b),
+                        static_cast<uint32_t>(t),
+                        static_cast<uint32_t>(v),
+                        static_cast<uint32_t>(vp)
+                      },
+                      nullptr,
+                      key.c_str());
+  } else {
+    op = ctx.kernelPool.data[key];
+  }
+  Tensor& dlogits_t = ctx.pool.data[op->buffers[0]];
+  Tensor& dlosses_t = ctx.pool.data[op->buffers[1]];
+  Tensor& probs_t = ctx.pool.data[op->buffers[2]];
+  Tensor& targets_t = ctx.pool.data[op->buffers[3]];
+
+  toGPU(ctx, dlogits, dlogits_t);
+  toGPU(ctx, dlosses, dlosses_t);
+  toGPU(ctx, probs, probs_t);
+  toGPU(ctx, targets, targets_t);
+
   std::promise<void> promise;
   std::future<void> future = promise.get_future();
-  Kernel op = createKernel(ctx, {kShaderCrossEntropySoftmaxBackward, 256, kf32},
-                           Bindings{dlogits_t, dlosses_t, probs_t, targets_t},
-                           /* nWorkgroups */ {cdiv(b * t, 256), 1, 1},
-                           /* params */
-                           CrossEntropySoftmaxBackwardParams{
-                             static_cast<uint32_t>(b),
-                             static_cast<uint32_t>(t),
-                             static_cast<uint32_t>(v),
-                             static_cast<uint32_t>(vp)
-                           });
   dispatchKernel(ctx, op, promise);
   wait(ctx, future);
   toCPU(ctx, dlogits_t, dlogits, b * t * vp * sizeof(float));
diff --git a/gpu.hpp b/gpu.hpp
index 83fc94b..941656e 100644
--- a/gpu.hpp
+++ b/gpu.hpp
@@ -415,12 +415,14 @@ struct KernelCode {
  * @endcode
  * "f32"}});
  */
-inline void
+inline const std::string
 replaceAll(std::string &str,
            const std::vector<std::pair<std::string, std::string>> &reps) {
   for (const auto &rep : reps) {
     replaceAll(str, rep.first, rep.second);
   }
+
+  return str;
 }
 
 /**
@@ -452,7 +454,7 @@ struct CopyData {
  * The struct members can be divided into "consumed upon dispatch"
  * (commandBuffer) and reusable ahead-of-time setup (all other members).
  */
-struct Kernel {
+struct RawKernel {
   std::unique_ptr<WGPUBuffer[]> buffers; // non-owning
   std::unique_ptr<size_t[]> bufferSizes;
   size_t numBindings;
@@ -460,8 +462,11 @@ struct Kernel {
   WGPUBindGroup bindGroup; // persists between submission
   WGPUComputePipeline computePipeline; // persists between submission
   WGPUCommandBuffer commandBuffer; // destroyed upon submission
+  bool used;
 };
 
+typedef std::shared_ptr<RawKernel> Kernel;
+
 /**
  * @brief A struct to package the result of a WGSL code compilation.
@@ -481,7 +486,7 @@ struct CompilationInfo {
  * @return True if lhs < rhs, false otherwise
  */
 inline bool operator<(const Kernel &lhs, const Kernel &rhs) {
-  return lhs.commandBuffer < rhs.commandBuffer;
+  return lhs->commandBuffer < rhs->commandBuffer;
 }
 
 /**
@@ -492,7 +497,7 @@ inline bool operator<(const Kernel &lhs, const Kernel &rhs) {
 struct KernelPool {
   inline KernelPool(Context *ctx) : ctx(ctx), data() {}
   Context *ctx;
-  std::set<Kernel *> data;
+  std::unordered_map<std::string, Kernel> data;
   inline ~KernelPool() {
     // Note : Some kernel resources such as commandBuffer are harvested by
     // queue submission, explicitly destroying readback and callback buffers
@@ -997,6 +1002,7 @@ inline void wait(Context &ctx, std::future<void> &future) {
 inline void toCPU(Context &ctx, Tensor &tensor, void *data, size_t bufferSize,
                   CopyData &op) {
   wgpuQueueSubmit(ctx.queue, 1, &op.commandBuffer);
+  wgpuCommandBufferRelease(op.commandBuffer);
   CallbackData callbackData = {op.readbackBuffer, bufferSize, data, &op.promise,
                                &op.future};
   wgpuQueueOnSubmittedWorkDone(
@@ -1052,14 +1058,17 @@ inline void toCPU(Context &ctx, Tensor &tensor, void *data, size_t bufferSize) {
   }
   {
     WGPUCommandEncoder commandEncoder;
-    WGPUComputePassEncoder computePassEncoder;
    commandEncoder = wgpuDeviceCreateCommandEncoder(ctx.device, nullptr);
     wgpuCommandEncoderCopyBufferToBuffer(commandEncoder, tensor.data.buffer, 0,
                                          op.readbackBuffer, 0, bufferSize);
     op.commandBuffer = wgpuCommandEncoderFinish(commandEncoder, nullptr);
+    wgpuCommandEncoderRelease(commandEncoder);
     check(op.commandBuffer, "Create command buffer", __FILE__, __LINE__);
   }
   toCPU(ctx, tensor, data, bufferSize, op);
+  if (op.readbackBuffer) {
+    wgpuBufferRelease(op.readbackBuffer);
+  }
 }
 
 /**
@@ -1078,6 +1087,61 @@ void toCPU(Context &ctx, Tensor &tensor, std::array<float, N> &data) {
   toCPU(ctx, tensor, data.data(), sizeof(data));
 }
 
+inline void toCPU(Context &ctx, WGPUBuffer buffer, void *data,
+                  size_t size) {
+  uint64_t bufferSize = size;
+  CopyData op;
+  op.future = op.promise.get_future();
+  {
+    WGPUBufferDescriptor readbackBufferDescriptor = {
+        .usage = WGPUBufferUsage_CopyDst | WGPUBufferUsage_MapRead,
+        .size = bufferSize,
+    };
+    op.readbackBuffer =
+        wgpuDeviceCreateBuffer(ctx.device, &readbackBufferDescriptor);
+  }
+  {
+    WGPUCommandEncoder commandEncoder;
+    commandEncoder = wgpuDeviceCreateCommandEncoder(ctx.device, nullptr);
+    wgpuCommandEncoderCopyBufferToBuffer(commandEncoder, buffer, 0,
+                                         op.readbackBuffer, 0, bufferSize);
+    op.commandBuffer = wgpuCommandEncoderFinish(commandEncoder, nullptr);
+    wgpuCommandEncoderRelease(commandEncoder);
+    check(op.commandBuffer, "Create command buffer", __FILE__, __LINE__);
+  }
+  wgpuQueueSubmit(ctx.queue, 1, &op.commandBuffer);
+  wgpuCommandBufferRelease(op.commandBuffer);
+  CallbackData callbackData = {op.readbackBuffer, bufferSize, data, &op.promise,
+                               &op.future};
+  wgpuQueueOnSubmittedWorkDone(
+      ctx.queue,
+      [](WGPUQueueWorkDoneStatus status, void *callbackData) {
+        check(status == WGPUQueueWorkDoneStatus_Success, "Queue work done",
+              __FILE__, __LINE__);
+        const auto *data = static_cast<CallbackData *>(callbackData);
+        wgpuBufferMapAsync(
+            data->buffer, WGPUMapMode_Read, 0, data->bufferSize,
+            [](WGPUBufferMapAsyncStatus status, void *captureData) {
+              const auto *data = static_cast<CallbackData *>(captureData);
+              check(status == WGPUBufferMapAsyncStatus_Success,
+                    "Map readbackBuffer", __FILE__, __LINE__);
+              const void *mappedData = wgpuBufferGetConstMappedRange(
+                  data->buffer, /*offset=*/0, data->bufferSize);
+              check(mappedData, "Get mapped range", __FILE__, __LINE__);
+              memcpy(data->output, mappedData, data->bufferSize);
+              wgpuBufferUnmap(data->buffer);
+              data->promise->set_value();
+            },
+            callbackData);
+      },
+      &callbackData);
+  wait(ctx, op.future);
+  if (op.readbackBuffer) {
+    wgpuBufferRelease(op.readbackBuffer);
+  }
+}
+
+
 /**
  * @brief Copies data from CPU memory to a GPU buffer. The toGPU overloads are
  * effectively a convenience wrapper around the WebGPU API call
@@ -1119,13 +1183,18 @@ inline void toGPU(Context &ctx, const half *data, Tensor &tensor) {
                        tensor.data.size);
 }
 
+inline void toGPU(Context &ctx, const int *data, Tensor &tensor) {
+  wgpuQueueWriteBuffer(ctx.queue, tensor.data.buffer, 0, data,
+                       tensor.data.size);
+}
+
 template <typename Params>
 inline void toGPU(Context &ctx, Params &params, Kernel &op) {
   // TODO(avh): Maintain params metadata in Kernel and check for consistency.
   // If a kernel does not have parameters this will quietly overwrite
   // the last buffer in the bind group with the parameters buffer.
-  if (op.numBindings > 0) {
-    wgpuQueueWriteBuffer(ctx.queue, op.buffers[op.numBindings - 1], 0,
+  if (op->numBindings > 0) {
+    wgpuQueueWriteBuffer(ctx.queue, op->buffers[op->numBindings - 1], 0,
                          static_cast<void *>(&params), sizeof(params));
   }
 }
@@ -1148,14 +1217,17 @@ inline void resetCommandBuffer(WGPUDevice &device, Kernel &op) {
         wgpuDeviceCreateCommandEncoder(device, nullptr);
     WGPUComputePassEncoder computePassEncoder =
         wgpuCommandEncoderBeginComputePass(commandEncoder, nullptr);
-    wgpuComputePassEncoderSetPipeline(computePassEncoder, op.computePipeline);
-    wgpuComputePassEncoderSetBindGroup(computePassEncoder, 0, op.bindGroup, 0,
+    wgpuComputePassEncoderSetPipeline(computePassEncoder, op->computePipeline);
+    wgpuComputePassEncoderSetBindGroup(computePassEncoder, 0, op->bindGroup, 0,
                                        nullptr);
     wgpuComputePassEncoderDispatchWorkgroups(
-        computePassEncoder, op.totalWorkgroups[0], op.totalWorkgroups[1],
-        op.totalWorkgroups[2]);
+        computePassEncoder, op->totalWorkgroups[0], op->totalWorkgroups[1],
+        op->totalWorkgroups[2]);
     wgpuComputePassEncoderEnd(computePassEncoder);
-    op.commandBuffer = wgpuCommandEncoderFinish(commandEncoder, nullptr);
+    wgpuComputePassEncoderRelease(computePassEncoder);
+    op->commandBuffer = wgpuCommandEncoderFinish(commandEncoder, nullptr);
+    wgpuCommandEncoderRelease(commandEncoder);
+    op->used = false;
   }
 }
 
@@ -1217,11 +1289,19 @@ inline Kernel createKernel(Context &ctx, const KernelCode &code,
                            const size_t *viewOffsets, const Shape &totalWorkgroups,
                            const void *params = nullptr, size_t paramsSize = 0,
-                           CompilationInfo* compilationInfo = nullptr) {
+                           CompilationInfo* compilationInfo = nullptr,
+                           const char* cacheKey = nullptr) {
+  // If a cache key was provided and a kernel with this key already exists, return the cached kernel.
+  if (cacheKey != nullptr && ctx.kernelPool.data.find(cacheKey) != ctx.kernelPool.data.end()) {
+    LOG(kDefLog, kInfo, "Kernel cache hit");
+    return ctx.kernelPool.data[cacheKey];
+  }
+
   assert(totalWorkgroups.rank == 3);
   WGPUDevice device = ctx.device;
   WGPUQueue queue = ctx.queue;
-  Kernel op;
+  Kernel op(new RawKernel());
+
   // paramIndex is the index into bgLayoutEntries for the parameters buffer If
   // there are no parameters for the kernel, paramsSize == 0 and paramIndex is
   // effectively undefined (== -1)
@@ -1234,9 +1314,9 @@ inline Kernel createKernel(Context &ctx, const KernelCode &code,
     // op.buffers, op.bufferSizes and
     // bgLayoutEntries
   }
-  op.buffers = std::make_unique<WGPUBuffer[]>(numBindings);
-  op.bufferSizes = std::make_unique<size_t[]>(numBindings);
-  op.numBindings = numBindings;
+  op->buffers = std::make_unique<WGPUBuffer[]>(numBindings);
+  op->bufferSizes = std::make_unique<size_t[]>(numBindings);
+  op->numBindings = numBindings;
   std::vector<WGPUBindGroupLayoutEntry> bgLayoutEntries(numBindings);
   // Create layout entries for input buffers
   for (size_t i = 0; i < numTensors; ++i) {
@@ -1270,8 +1350,8 @@ inline Kernel createKernel(Context &ctx, const KernelCode &code,
   WGPUBindGroupLayout bgLayout =
       wgpuDeviceCreateBindGroupLayout(device, &bgLayoutDesc);
   for (size_t i = 0; i < numTensors; ++i) {
-    op.buffers[i] = dataBindings[i].data.buffer;
-    op.bufferSizes[i] = dataBindings[i].data.size;
+    op->buffers[i] = dataBindings[i].data.buffer;
+    op->bufferSizes[i] = dataBindings[i].data.size;
   }
   // Create a buffer for the Params struct
   if (paramsSize > 0) {
@@ -1280,9 +1360,9 @@ inline Kernel createKernel(Context &ctx, const KernelCode &code,
         .size = paramsSize,
         .mappedAtCreation = false,
     };
-    op.buffers[paramIndex] = wgpuDeviceCreateBuffer(device, &paramsBufferDesc);
-    op.bufferSizes[paramIndex] = paramsSize;
-    wgpuQueueWriteBuffer(queue, op.buffers[paramIndex], 0, params, paramsSize);
+    op->buffers[paramIndex] = wgpuDeviceCreateBuffer(device, &paramsBufferDesc);
+    op->bufferSizes[paramIndex] = paramsSize;
+    wgpuQueueWriteBuffer(queue, op->buffers[paramIndex], 0, params, paramsSize);
     LOG(kDefLog, kTrace, "Params buffer written");
   } else {
     LOG(kDefLog, kTrace, "No params buffer needed");
@@ -1291,9 +1371,9 @@ inline Kernel createKernel(Context &ctx, const KernelCode &code,
   for (size_t i = 0; i < numTensors; ++i) {
     bindGroupEntries[i] = WGPUBindGroupEntry{
         .binding = static_cast<uint32_t>(i),
-        .buffer = op.buffers[i],
+        .buffer = op->buffers[i],
         .offset = viewOffsets[i],
-        .size = op.bufferSizes[i],
+        .size = op->bufferSizes[i],
     };
   }
   if (paramsSize > 0) {
@@ -1301,7 +1381,7 @@ inline Kernel createKernel(Context &ctx, const KernelCode &code,
     LOG(kDefLog, kInfo, "paramIndex: %d", paramIndex);
     bindGroupEntries[paramIndex] = WGPUBindGroupEntry{
         .binding = static_cast<uint32_t>(paramIndex),
-        .buffer = op.buffers[paramIndex],
+        .buffer = op->buffers[paramIndex],
         .offset = 0,
         .size = paramsSize,
     };
@@ -1312,7 +1392,7 @@ inline Kernel createKernel(Context &ctx, const KernelCode &code,
       .entryCount = static_cast<uint32_t>(numBindings),
       .entries = bindGroupEntries.data(),
   };
-  op.bindGroup = wgpuDeviceCreateBindGroup(device, &bindGroupDesc);
+  op->bindGroup = wgpuDeviceCreateBindGroup(device, &bindGroupDesc);
 
   WGPUPipelineLayoutDescriptor pipelineLayoutDesc = {
       .bindGroupLayoutCount = 1,
@@ -1334,12 +1414,13 @@ inline Kernel createKernel(Context &ctx, const KernelCode &code,
   computePipelineDesc.compute.entryPoint = code.entryPoint.c_str();
   computePipelineDesc.label = code.label.c_str();
-  op.computePipeline =
+  op->computePipeline =
       wgpuDeviceCreateComputePipeline(device, &computePipelineDesc);
-  op.totalWorkgroups = {totalWorkgroups[0], totalWorkgroups[1], totalWorkgroups[2]};
+  op->totalWorkgroups = {totalWorkgroups[0], totalWorkgroups[1], totalWorkgroups[2]};
   resetCommandBuffer(device, op);
-  ctx.kernelPool.data.insert(&op);
-
+  if (cacheKey != nullptr)
+    ctx.kernelPool.data[cacheKey] = op;
+
   WGPUCompilationInfoCallback cb =
       [](WGPUCompilationInfoRequestStatus status,
          WGPUCompilationInfo const *compilationInfo, void *userData) {
@@ -1394,17 +1475,20 @@ Kernel createKernel(Context &ctx, const KernelCode &code,
                     const Bindings<numInputs> &dataBindings,
                     const Shape &totalWorkgroups,
                     const ParamsType &params = ParamsType{},
-                    CompilationInfo* compilationInfo = nullptr
+                    CompilationInfo* compilationInfo = nullptr,
+                    const char* cacheKey = nullptr
 ) {
   if constexpr (!IsNoParam<ParamsType>) {
     return createKernel(ctx, code, dataBindings.data.data(), numInputs,
                         dataBindings.viewOffsets.data(), totalWorkgroups,
                         reinterpret_cast<const void *>(&params),
-                        sizeof(ParamsType), compilationInfo);
+                        sizeof(ParamsType), compilationInfo,
+                        cacheKey);
   } else {
     return createKernel(ctx, code, dataBindings.data.data(), numInputs,
                         dataBindings.viewOffsets.data(), totalWorkgroups, nullptr,
-                        0, compilationInfo);
+                        0, compilationInfo,
+                        cacheKey);
   }
 }
 
@@ -1429,7 +1513,12 @@ Kernel createKernel(Context &ctx, const KernelCode &code,
 inline void dispatchKernel(Context &ctx, Kernel &kernel,
                            std::promise<void> &promise) {
   // Submit the command buffer
-  wgpuQueueSubmit(ctx.queue, 1, &kernel.commandBuffer);
+  if (kernel->used) {
+    resetCommandBuffer(ctx.device, kernel);
+  }
+  wgpuQueueSubmit(ctx.queue, 1, &kernel->commandBuffer);
+  wgpuCommandBufferRelease(kernel->commandBuffer);
+  kernel->used = true;
   wgpuQueueOnSubmittedWorkDone(
       ctx.queue,
       [](WGPUQueueWorkDoneStatus status, void *data) {