From 7addf836cfd28dd66d461dcff8e4a47c8b91b5d1 Mon Sep 17 00:00:00 2001 From: Junji Hashimoto Date: Sun, 13 Oct 2024 10:32:00 +0900 Subject: [PATCH 01/13] Add kShaderMatmul2DTiling in kernels.h --- experimental/kernels/kernels.h | 98 ++++++++++ experimental/kernels/ops.cpp | 62 +++++-- .../unittest_llmc/unittest_kernels.cpp | 171 ++++++++++++++++-- 3 files changed, 299 insertions(+), 32 deletions(-) diff --git a/experimental/kernels/kernels.h b/experimental/kernels/kernels.h index 212c075..1bce081 100644 --- a/experimental/kernels/kernels.h +++ b/experimental/kernels/kernels.h @@ -309,6 +309,104 @@ fn main(@builtin(global_invocation_id) global_id : vec3) { } } } + +)"; + + +static const char *kShaderMatmul2DTiling = R"( +@group(0) @binding(0) var inp : array<{{precision}}>; +@group(0) @binding(1) var weight : array<{{precision}}>; +@group(0) @binding(2) var bias : array<{{precision}}>; +@group(0) @binding(3) var out : array<{{precision}}>; +@group(0) @binding(4) var params : Params; +struct Params { + B: u32, + T: u32, + C: u32, + OC: u32, +}; +var tileInp: array<{{precision}}, {{BT}} * {{BC}}>; +var tileWeight: array<{{precision}}, {{BOC}} * {{BC}}>; + +@compute @workgroup_size({{workgroupSize}}) +fn main( + @builtin(local_invocation_id) localID : vec3, + @builtin(workgroup_id) groupid : vec3) { + let B : u32 = params.B; + let T : u32 = params.T; + let C : u32 = params.C; + let OC : u32 = params.OC; + + var localT: array<{{precision}}, {{TT}}>; + var localOC: array<{{precision}}, {{TOC}}>; + + let outB: u32 = groupid.x; + let outT: u32 = groupid.y; + let outOC: u32 = groupid.z; + let numThread: u32 = ({{BT}} * {{BOC}}) / ({{TT}} * {{TOC}}); + + // position of the first c element computed by the thread + let threadRow: u32 = (localID.x / ({{BOC}} / {{TOC}})) * {{TT}}; + let threadCol: u32 = (localID.x % ({{BOC}} / {{TOC}})) * {{TOC}}; + + // inpPtr and weightPtr are the starting positions of the tiles in a and b, + // incremented in the bkidx loop. + // outPtr is the starting position of the tile in c which is fixed. 
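+  // As a concrete illustration (assuming the tile sizes instantiated in ops.cpp,
+  // BT = BOC = 64 and BC = 8): TT = TOC = 8, so numThread = (64 * 64) / (8 * 8) = 64
+  // threads per workgroup, and each thread accumulates an 8 x 8 (TT x TOC)
+  // sub-tile of the 64 x 64 (BT x BOC) output block, with threadRow/threadCol
+  // giving its offset inside that block.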
+ + var inpPtr = (outB * T + outT * {{BT}}) * C; // BTC + var weightPtr = outOC * {{BOC}} * C; //OCC + var threadResults: array<{{precision}}, {{TT}} * {{TOC}}>; + let outPtr = (outB * T + outT * {{BT}}) * OC + outOC * {{BOC}}; //BTOC + let biasPtr = outOC * {{BOC}}; + + for (var bkidx: u32 = 0; bkidx < C; bkidx += {{BC}}) { + // Load BC x BOC by numThread(BT * BOC / (TT * TOC)) + // The number of iteration == BC * BOC / (BT * BOC / (TT * TOC)) + for (var idx: u32 = 0; idx < {{NUM_TILEW}}; idx++) { + tileWeight[localID.x + idx * numThread] = weight[weightPtr + ((localID.x + idx * numThread) / {{BC}}) * C + ((localID.x + idx * numThread) % {{BC}})]; + } + weightPtr += {{BC}}; + + // Load tile + // Load BT x BC by numThread(BT * BOC / (TT * TOC)) + // The number of iteration == BT * BC / (BT * BOC / (TT * TOC)) + for (var idx: u32 = 0; idx < {{NUM_TILEI}}; idx++) { + tileInp[localID.x + idx * numThread] = inp[inpPtr + ((localID.x + idx * numThread) / {{BC}}) * C + (localID.x + idx * numThread) % {{BC}}]; + } + inpPtr += {{BC}}; + + workgroupBarrier(); + // Compute tile + for (var dotIdx: u32 = 0; dotIdx < {{BC}}; dotIdx = dotIdx + 1) { + for (var idx: u32 = 0; idx < {{TT}}; idx++) { + localT[idx] = tileInp[(threadRow + idx) * {{BC}} + dotIdx]; + } + for (var idx: u32 = 0; idx < {{TOC}}; idx++) { + localOC[idx] = tileWeight[(threadCol + idx) * {{BC}} + dotIdx]; + } + for (var resIdxT: u32 = 0; resIdxT < {{TT}}; resIdxT++) { + for (var resIdxOC: u32 = 0; resIdxOC < {{TOC}}; resIdxOC++) { + threadResults[resIdxT * {{TOC}} + resIdxOC] += localT[resIdxT] * localOC[resIdxOC]; + } + } + } + workgroupBarrier(); + } + + if (arrayLength(&bias) == 1) { + for (var resIdxT: u32 = 0; resIdxT < {{TT}}; resIdxT++) { + for (var resIdxOC: u32 = 0; resIdxOC < {{TOC}}; resIdxOC++) { + out[outPtr + (threadRow + resIdxT) * OC + threadCol + resIdxOC] = threadResults[resIdxT * {{TOC}} + resIdxOC]; + } + } + } else { + for (var resIdxT: u32 = 0; resIdxT < {{TT}}; resIdxT++) { + for (var resIdxOC: u32 = 0; resIdxOC < {{TOC}}; resIdxOC++) { + out[outPtr + (threadRow + resIdxT) * OC + threadCol + resIdxOC] = threadResults[resIdxT * {{TOC}} + resIdxOC] + bias[biasPtr + threadCol + resIdxOC]; + } + } + } +} )"; static const char *kShaderMatmulBackward = R"( diff --git a/experimental/kernels/ops.cpp b/experimental/kernels/ops.cpp index 67fc679..48edf9b 100644 --- a/experimental/kernels/ops.cpp +++ b/experimental/kernels/ops.cpp @@ -6,6 +6,7 @@ #include "kernels.h" #include "ops.hpp" +#include "experimental/wgsl.h" // loopUnrolling using namespace gpu; @@ -178,18 +179,55 @@ void matmul_forward(Context& ctx, float* out, std::promise promise; std::future future = promise.get_future(); assert ( (b*t) % 256 == 0 ); - Kernel op = createKernel(ctx, {kShaderMatmul, 256, kf32}, - Bindings{inp_i, weight_i, bias_i, out_o}, - /* nWorkgroups */ {cdiv(b * t, 256), 1, 1}, - /* params */ - MatmulParams{ - static_cast(b), - static_cast(t), - static_cast(c), - static_cast(oc) - }); - dispatchKernel(ctx, op, promise); - wait(ctx, future); + int version = 1; + if (version == 1){ + static constexpr size_t BT = 64; + static constexpr size_t BC = 8; + static constexpr size_t BOC = 64; + static constexpr size_t TT = BT / BC; + static constexpr size_t TOC = BOC / BC; + size_t num_threads = BT * BOC / (TT * TOC); + Shape wgSize = {num_threads, 1, 1}; // This is the same as BK * BK. 
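+    // With BT = 64, BC = 8, BOC = 64: TT = TOC = 8, so num_threads = (64 * 64) / (8 * 8) = 64.
+    // One workgroup of 64 threads computes one 64 x 64 tile of the (B*T) x OC output,
+    // which is why nWorkgroups below is {B, ceil(T / BT), ceil(OC / BOC)}.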
+ Shape nWorkgroups = {b, cdiv(T, BT), cdiv(OC, BOC)}; + + std::string codeString(kShaderMatmul2DTiling); + replaceAll(codeString, {{"{{precision}}", toString(kf32)}, + {"{{BT}}", toString(BT)}, + {"{{BC}}", toString(BC)}, + {"{{BOC}}", toString(BOC)}, + {"{{TT}}", toString(TT)}, + {"{{TOC}}", toString(TOC)}, + {"{{NUM_TILEI}}", toString(BT * BC / num_threads)}, + {"{{NUM_TILEW}}", toString(BOC * BC / num_threads)} + }); + std::string unrolledCode = loopUnrolling(codeString); + Kernel op = createKernel(ctx, {unrolledCode, wgSize, kf32}, + Bindings{inp_i, weight_i, bias_i, out_o}, + nWorkgroups, + /* params */ + MatmulParams{ + static_cast(b), + static_cast(t), + static_cast(c), + static_cast(oc) + }); + dispatchKernel(ctx, op, promise); + wait(ctx, future); + toCPU(ctx, out_o, out, b * t * oc * sizeof(float)); + } else { + Kernel op = createKernel(ctx, {kShaderMatmul, 256, kf32}, + Bindings{inp_i, weight_i, bias_i, out_o}, + /* nWorkgroups */ {cdiv(b * t, 256), 1, 1}, + /* params */ + MatmulParams{ + static_cast(b), + static_cast(t), + static_cast(c), + static_cast(oc) + }); + dispatchKernel(ctx, op, promise); + wait(ctx, future); + } toCPU(ctx, out_o, out, b * t * oc * sizeof(float)); } diff --git a/experimental/kernels/unittest_llmc/unittest_kernels.cpp b/experimental/kernels/unittest_llmc/unittest_kernels.cpp index 37cdcaf..9278163 100644 --- a/experimental/kernels/unittest_llmc/unittest_kernels.cpp +++ b/experimental/kernels/unittest_llmc/unittest_kernels.cpp @@ -5,6 +5,7 @@ #include "kernels.h" #include "unittest_llmc/unittest_kernels.h" +#include "experimental/wgsl.h" // loopUnrolling using namespace gpu; // createContext, createTensor, createKernel, // createShader, dispatchKernel, wait, toCPU @@ -51,6 +52,28 @@ using namespace gpu; // createContext, createTensor, createKernel, } \ } +struct DurationTime { + std::chrono::high_resolution_clock::time_point start; + std::chrono::high_resolution_clock::time_point end; + std::chrono::microseconds duration; + std::string src; + bool verbose; + + inline DurationTime(const std::string& src, bool verbose = true) { + this->src = src; + this->verbose = verbose; + start = std::chrono::high_resolution_clock::now(); + } + + inline ~DurationTime() { + end = std::chrono::high_resolution_clock::now(); + duration = std::chrono::duration_cast(end - start); + if (this->verbose) { + printf("Duration(%s): %.1f microseconds\n", src.c_str(), static_cast(duration.count())); + } + } +}; + void ENCODER_FORWARD_GPU(float* out, int* inp, float* wte, float* wpe, int B, int T, int C){ @@ -202,9 +225,25 @@ void LAYERNORM_BACKWARD_GPU(float* dinp, float* dweight, float* dbias, toCPU(ctx, dbias_t, dbias, c * sizeof(float)); } +void matmul_forward_dummy(float* out, + const float* inp, const float* weight, const float* bias, + int B, int T, int C, int OC); + void MATMUL_FORWARD_GPU(float* out, const float* inp, const float* weight, const float* bias, int B, int T, int C, int OC){ + int version = 2; + bool verbose = false; + bool debug = false; + float *out_exp; + DurationTime duration("matmul_forward_gpu with creating context", verbose); + if (verbose) { + printf("matmul forward: B=%d, T=%d, C=%d, OC=%d, bias=%d\n", B, T, C, OC, bias != NULL); + } + if (debug) { + out_exp = new float[B*T*OC]; + matmul_forward_dummy(out_exp, inp, weight, bias, B, T, C, OC); + } struct MatmulParams { uint32_t B; uint32_t T; @@ -221,26 +260,118 @@ void MATMUL_FORWARD_GPU(float* out, .requiredLimits = &requiredLimits }); - Tensor inp_i = createTensor(ctx, Shape{b * t * c}, kf32, inp); - 
Tensor weight_i = createTensor(ctx, Shape{oc * c}, kf32, weight); - Tensor bias_i = bias == NULL ? createTensor(ctx, Shape{1}, kf32) : createTensor(ctx, Shape{oc}, kf32, bias); - Tensor out_o = createTensor(ctx, Shape{b * t * oc}, kf32); - std::promise promise; - std::future future = promise.get_future(); - assert ( (b*t) % 256 == 0 ); - Kernel op = createKernel(ctx, {kShaderMatmul, 256, kf32}, - Bindings{inp_i, weight_i, bias_i, out_o}, - /* nWorkgroups */ {cdiv(b * t, 256), 1, 1}, - /* params */ - MatmulParams{ - static_cast(b), - static_cast(t), - static_cast(c), - static_cast(oc) - }); - dispatchKernel(ctx, op, promise); - wait(ctx, future); - toCPU(ctx, out_o, out, b * t * oc * sizeof(float)); + { + DurationTime duration("matmul_forward_gpu: before creating tensors", verbose); + Tensor inp_i = createTensor(ctx, Shape{b * t * c}, kf32, inp); + Tensor weight_i = createTensor(ctx, Shape{oc * c}, kf32, weight); + Tensor bias_i = bias == NULL ? createTensor(ctx, Shape{1}, kf32) : createTensor(ctx, Shape{oc}, kf32, bias); + Tensor out_o = createTensor(ctx, Shape{b * t * oc}, kf32); + std::promise promise; + std::future future = promise.get_future(); + + if (version == 2) { + DurationTime duration("matmul_forward_gpu: after creating tensors", verbose); + static constexpr size_t BT = 64; + static constexpr size_t BC = 16; + static constexpr size_t BOC = 64; + static constexpr size_t TT = BT / BC; + static constexpr size_t TOC = BOC / BC; + static constexpr size_t num_threads = BT * BOC / (TT * TOC); + Shape wgSize = {num_threads, 1, 1}; + Shape nWorkgroups = {b, cdiv(T, BT), cdiv(OC, BOC)}; + + std::string codeString(kShaderMatmul2DTiling); + replaceAll(codeString, {{"{{precision}}", toString(kf32)}, + {"{{BT}}", toString(BT)}, + {"{{BC}}", toString(BC)}, + {"{{BOC}}", toString(BOC)}, + {"{{TT}}", toString(TT)}, + {"{{TOC}}", toString(TOC)}, + {"{{NUM_TILEI}}", toString(BT * BC / num_threads)}, + {"{{NUM_TILEW}}", toString(BOC * BC / num_threads)} + }); + std::string unrolledCode = loopUnrolling(codeString); + { + DurationTime duration("matmul_forward_gpu: before creating kernels", verbose); + + Kernel op = createKernel(ctx, {unrolledCode, wgSize, kf32}, + Bindings{inp_i, weight_i, bias_i, out_o}, + nWorkgroups, + /* params */ + MatmulParams{ + static_cast(b), + static_cast(t), + static_cast(c), + static_cast(oc) + }); + { + DurationTime duration("matmul_forward_gpu without creating context", verbose); + dispatchKernel(ctx, op, promise); + wait(ctx, future); + toCPU(ctx, out_o, out, b * t * oc * sizeof(float)); + } + } + } else if (version == 1) { + Kernel op = createKernel(ctx, {kShaderMatmul, 256, kf32}, + Bindings{inp_i, weight_i, bias_i, out_o}, + /* nWorkgroups */ {cdiv(b * t, 256), 1, 1}, + /* params */ + MatmulParams{ + static_cast(b), + static_cast(t), + static_cast(c), + static_cast(oc) + }); + { + DurationTime duration("matmul_forward_gpu without creating context", verbose); + dispatchKernel(ctx, op, promise); + wait(ctx, future); + toCPU(ctx, out_o, out, b * t * oc * sizeof(float)); + } + } else { + DurationTime duration("matmul_forward_cpu", verbose); + matmul_forward_dummy(out, inp, weight, bias, B, T, C, OC); + } + } + + if (debug) { // compare out with out_exp. 
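+    // The GPU result is checked against the CPU reference with an absolute
+    // tolerance of 1e-2; on the first mismatch the top-left 4 x 4 corners of
+    // inp, weight, out and out_exp are dumped and the process exits.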
+ for (int i = 0; i < B*T*OC; i++) { + if (fabs(out[i] - out_exp[i]) > 1e-2) { + printf("matmul forward: out[%d] = %f, out_exp[%d] = %f\n", i, out[i], i, out_exp[i]); + //Dump the first 4 x 4 elements by table, at first output out, then output out_exp + printf("inp:\n"); + for (int j = 0; j < 4; j++) { + for (int k = 0; k < 4; k++) { + printf("%f ", inp[j * C + k]); + } + printf("\n"); + } + printf("weight:\n"); + for (int j = 0; j < 4; j++) { + for (int k = 0; k < 4; k++) { + printf("%f ", weight[j * OC + k]); + } + printf("\n"); + } + printf("out:\n"); + for (int j = 0; j < 4; j++) { + for (int k = 0; k < 4; k++) { + printf("%f ", out[j * OC + k]); + } + printf("\n"); + } + printf("out_exp:\n"); + for (int j = 0; j < 4; j++) { + for (int k = 0; k < 4; k++) { + printf("%f ", out_exp[j * OC + k]); + } + printf("\n"); + } + exit(1); + } + } + delete[] out_exp; + } } void MATMUL_BACKWARD_GPU(float* dinp, float* dweight, float* dbias, From da1f32d2ce43001b774b05158d560f402173b33e Mon Sep 17 00:00:00 2001 From: Junji Hashimoto Date: Sun, 13 Oct 2024 10:51:59 +0900 Subject: [PATCH 02/13] Reduce matmul-kernel creation time --- experimental/kernels/Makefile | 2 +- experimental/kernels/ops.cpp | 46 +++++++++++++++++++---------------- gpu.hpp | 4 ++- 3 files changed, 29 insertions(+), 23 deletions(-) diff --git a/experimental/kernels/Makefile b/experimental/kernels/Makefile index c233ef5..3f6f763 100644 --- a/experimental/kernels/Makefile +++ b/experimental/kernels/Makefile @@ -95,7 +95,7 @@ build/train_gpt2: llm.c build/unittest_kernels.o gpt2_124M.bin build/ops.o: ops.cpp ops.hpp kernels.h llm.c mkdir -p build && $(CXX) $(CXXFLAGS) -c -o $@ $< -build/gpt2_webgpu: llm.c gpt2_124M.bin llm.c +build/gpt2_webgpu: llm.c gpt2_124M.bin llm.c gpt2_webgpu.cpp ops.cpp mkdir -p build $(CC) $(CXXFLAGS) -Illm.c $(LDFLAGS) -o $@ gpt2_webgpu.cpp ops.cpp diff --git a/experimental/kernels/ops.cpp b/experimental/kernels/ops.cpp index 48edf9b..b8ab6ac 100644 --- a/experimental/kernels/ops.cpp +++ b/experimental/kernels/ops.cpp @@ -157,6 +157,29 @@ void layernorm_backward(Context& ctx, float* dinp, float* dweight, float* dbias, toCPU(ctx, dbias_t, dbias, c * sizeof(float)); } +static constexpr size_t MATMUL_BT = 64; +static constexpr size_t MATMUL_BC = 8; +static constexpr size_t MATMUL_BOC = 64; +static constexpr size_t MATMUL_TT = MATMUL_BT / MATMUL_BC; +static constexpr size_t MATMUL_TOC = MATMUL_BOC / MATMUL_BC; +static size_t MATMUL_num_threads = MATMUL_BT * MATMUL_BOC / (MATMUL_TT * MATMUL_TOC); +static Shape MATMUL_wgSize = {MATMUL_num_threads, 1, 1}; +static std::string kShaderMatmul2DTiling_(kShaderMatmul2DTiling); +static std::string kShaderMatmul2D(loopUnrolling( + replaceAll(kShaderMatmul2DTiling_, + {{"{{precision}}", toString(kf32)}, + {"{{BT}}", toString(MATMUL_BT)}, + {"{{BC}}", toString(MATMUL_BC)}, + {"{{BOC}}", toString(MATMUL_BOC)}, + {"{{TT}}", toString(MATMUL_TT)}, + {"{{TOC}}", toString(MATMUL_TOC)}, + {"{{NUM_TILEI}}", toString(MATMUL_BT * MATMUL_BC / MATMUL_num_threads)}, + {"{{NUM_TILEW}}", toString(MATMUL_BOC * MATMUL_BC / MATMUL_num_threads)} + }) + ) + ); + + void matmul_forward(Context& ctx, float* out, const float* inp, const float* weight, const float* bias, int B, int T, int C, int OC){ @@ -181,27 +204,8 @@ void matmul_forward(Context& ctx, float* out, assert ( (b*t) % 256 == 0 ); int version = 1; if (version == 1){ - static constexpr size_t BT = 64; - static constexpr size_t BC = 8; - static constexpr size_t BOC = 64; - static constexpr size_t TT = BT / BC; - static constexpr size_t 
TOC = BOC / BC; - size_t num_threads = BT * BOC / (TT * TOC); - Shape wgSize = {num_threads, 1, 1}; // This is the same as BK * BK. - Shape nWorkgroups = {b, cdiv(T, BT), cdiv(OC, BOC)}; - - std::string codeString(kShaderMatmul2DTiling); - replaceAll(codeString, {{"{{precision}}", toString(kf32)}, - {"{{BT}}", toString(BT)}, - {"{{BC}}", toString(BC)}, - {"{{BOC}}", toString(BOC)}, - {"{{TT}}", toString(TT)}, - {"{{TOC}}", toString(TOC)}, - {"{{NUM_TILEI}}", toString(BT * BC / num_threads)}, - {"{{NUM_TILEW}}", toString(BOC * BC / num_threads)} - }); - std::string unrolledCode = loopUnrolling(codeString); - Kernel op = createKernel(ctx, {unrolledCode, wgSize, kf32}, + Shape nWorkgroups = {b, cdiv(T, MATMUL_BT), cdiv(OC, MATMUL_BOC)}; + Kernel op = createKernel(ctx, {kShaderMatmul2D, MATMUL_wgSize, kf32}, Bindings{inp_i, weight_i, bias_i, out_o}, nWorkgroups, /* params */ diff --git a/gpu.hpp b/gpu.hpp index 83fc94b..8b68463 100644 --- a/gpu.hpp +++ b/gpu.hpp @@ -415,12 +415,14 @@ struct KernelCode { * @endcode * "f32"}}); */ -inline void +inline const std::string replaceAll(std::string &str, const std::vector> &reps) { for (const auto &rep : reps) { replaceAll(str, rep.first, rep.second); } + + return str; } /** From dd2a25f530dabb784456e44cf5b1f5049f626bea Mon Sep 17 00:00:00 2001 From: Junji Hashimoto Date: Mon, 14 Oct 2024 04:24:37 +0900 Subject: [PATCH 03/13] Change Kernel to shared_ptr to support cached kernels --- experimental/kernels/Makefile | 4 +- experimental/kernels/ops.cpp | 34 ++++- .../unittest_llmc/unittest_kernels.cpp | 135 ++++++++++-------- gpu.hpp | 135 +++++++++++++----- 4 files changed, 215 insertions(+), 93 deletions(-) diff --git a/experimental/kernels/Makefile b/experimental/kernels/Makefile index 3f6f763..5817e23 100644 --- a/experimental/kernels/Makefile +++ b/experimental/kernels/Makefile @@ -79,7 +79,7 @@ endef build/test_gpt2: llm.c build/unittest_kernels.o gpt2_124M.bin mkdir -p build $(call preprocess_file) - $(CC) $(CFLAGS) $(LDFLAGS) -o $@ llm.c/test_gpt2.c build/unittest_kernels.o + $(CC) $(CFLAGS) $(LDFLAGS) -o $@ llm.c/test_gpt2.c build/unittest_kernels.o -g build/test_gpt2_with_metal_profiler: llm.c build/unittest_kernels.o gpt2_124M.bin mkdir -p build @@ -90,7 +90,7 @@ build/test_gpt2_with_metal_profiler: llm.c build/unittest_kernels.o gpt2_124M.bi build/train_gpt2: llm.c build/unittest_kernels.o gpt2_124M.bin mkdir -p build $(call preprocess_file) - $(CC) $(CFLAGS) $(LDFLAGS) -o $@ llm.c/train_gpt2.c build/unittest_kernels.o + $(CC) $(CFLAGS) $(LDFLAGS) -o $@ llm.c/train_gpt2.c build/unittest_kernels.o -g build/ops.o: ops.cpp ops.hpp kernels.h llm.c mkdir -p build && $(CXX) $(CXXFLAGS) -c -o $@ $< diff --git a/experimental/kernels/ops.cpp b/experimental/kernels/ops.cpp index b8ab6ac..b4b7555 100644 --- a/experimental/kernels/ops.cpp +++ b/experimental/kernels/ops.cpp @@ -179,10 +179,34 @@ static std::string kShaderMatmul2D(loopUnrolling( ) ); +struct DurationTime { + std::chrono::high_resolution_clock::time_point start; + std::chrono::high_resolution_clock::time_point end; + std::chrono::microseconds duration; + std::string src; + bool verbose; + + inline DurationTime(const std::string& src, bool verbose = true) { + this->src = src; + this->verbose = verbose; + start = std::chrono::high_resolution_clock::now(); + } + + inline ~DurationTime() { + end = std::chrono::high_resolution_clock::now(); + duration = std::chrono::duration_cast(end - start); + if (this->verbose) { + printf("Duration(%s): %.1f microseconds\n", src.c_str(), 
static_cast(duration.count())); + } + } +}; + void matmul_forward(Context& ctx, float* out, const float* inp, const float* weight, const float* bias, int B, int T, int C, int OC){ + bool verbose = false; + DurationTime duration("matmul_forward_gpu", verbose); struct MatmulParams { uint32_t B; uint32_t T; @@ -204,6 +228,7 @@ void matmul_forward(Context& ctx, float* out, assert ( (b*t) % 256 == 0 ); int version = 1; if (version == 1){ + DurationTime duration("matmul_forward_gpu without creating tensors", verbose); Shape nWorkgroups = {b, cdiv(T, MATMUL_BT), cdiv(OC, MATMUL_BOC)}; Kernel op = createKernel(ctx, {kShaderMatmul2D, MATMUL_wgSize, kf32}, Bindings{inp_i, weight_i, bias_i, out_o}, @@ -215,9 +240,12 @@ void matmul_forward(Context& ctx, float* out, static_cast(c), static_cast(oc) }); - dispatchKernel(ctx, op, promise); - wait(ctx, future); - toCPU(ctx, out_o, out, b * t * oc * sizeof(float)); + { + DurationTime duration("matmul_forward_gpu without creating kernel", verbose); + dispatchKernel(ctx, op, promise); + wait(ctx, future); + toCPU(ctx, out_o, out, b * t * oc * sizeof(float)); + } } else { Kernel op = createKernel(ctx, {kShaderMatmul, 256, kf32}, Bindings{inp_i, weight_i, bias_i, out_o}, diff --git a/experimental/kernels/unittest_llmc/unittest_kernels.cpp b/experimental/kernels/unittest_llmc/unittest_kernels.cpp index 9278163..4593bf5 100644 --- a/experimental/kernels/unittest_llmc/unittest_kernels.cpp +++ b/experimental/kernels/unittest_llmc/unittest_kernels.cpp @@ -2,6 +2,7 @@ #include #include #include +#include #include "kernels.h" #include "unittest_llmc/unittest_kernels.h" @@ -229,6 +230,31 @@ void matmul_forward_dummy(float* out, const float* inp, const float* weight, const float* bias, int B, int T, int C, int OC); +static WGPURequiredLimits requiredLimits = LIMITS_BUFFER_SIZE_1GB; +static Context ctx = createContext({},{},{ + .requiredLimits = &requiredLimits + }); + +static constexpr size_t BT = 64; +static constexpr size_t BC = 16; +static constexpr size_t BOC = 64; +static constexpr size_t TT = BT / BC; +static constexpr size_t TOC = BOC / BC; +static constexpr size_t num_threads = BT * BOC / (TT * TOC); +static Shape wgSize = {num_threads, 1, 1}; + +static std::string codeString(kShaderMatmul2DTiling); +static std::string unrolledCode = loopUnrolling(replaceAll(codeString, {{"{{precision}}", toString(kf32)}, + {"{{BT}}", toString(BT)}, + {"{{BC}}", toString(BC)}, + {"{{BOC}}", toString(BOC)}, + {"{{TT}}", toString(TT)}, + {"{{TOC}}", toString(TOC)}, + {"{{NUM_TILEI}}", toString(BT * BC / num_threads)}, + {"{{NUM_TILEW}}", toString(BOC * BC / num_threads)} + })); + + void MATMUL_FORWARD_GPU(float* out, const float* inp, const float* weight, const float* bias, int B, int T, int C, int OC){ @@ -255,55 +281,52 @@ void MATMUL_FORWARD_GPU(float* out, unsigned long c = static_cast(C); unsigned long oc = static_cast(OC); setLogLevel(kError); - WGPURequiredLimits requiredLimits = LIMITS_BUFFER_SIZE_1GB; - Context ctx = createContext({},{},{ - .requiredLimits = &requiredLimits - }); { DurationTime duration("matmul_forward_gpu: before creating tensors", verbose); - Tensor inp_i = createTensor(ctx, Shape{b * t * c}, kf32, inp); - Tensor weight_i = createTensor(ctx, Shape{oc * c}, kf32, weight); - Tensor bias_i = bias == NULL ? createTensor(ctx, Shape{1}, kf32) : createTensor(ctx, Shape{oc}, kf32, bias); - Tensor out_o = createTensor(ctx, Shape{b * t * oc}, kf32); + // Generate the key of the cache by arguments. 
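+    // The cache key encodes only the shapes, e.g. B=4, T=64, C=768, OC=2304
+    // would give "4_64_768_2304" (example values); the kernel, its pipeline and
+    // its bound tensors are created once per key and reused on later calls.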
+ std::string key = std::to_string(B) + "_" + std::to_string(T) + "_" + std::to_string(C) + "_" + std::to_string(OC); + Kernel op; + if (ctx.kernelPool.data.find(key) == ctx.kernelPool.data.end()) { + Shape nWorkgroups = {b, cdiv(T, BT), cdiv(OC, BOC)}; + Tensor inp_i = createTensor(ctx, Shape{b * t * c}, kf32); + Tensor weight_i = createTensor(ctx, Shape{oc * c}, kf32); + Tensor bias_i = bias == NULL ? createTensor(ctx, Shape{1}, kf32) : createTensor(ctx, Shape{oc}, kf32); + Tensor out_o = createTensor(ctx, Shape{b * t * oc}, kf32); + op = createKernel(ctx, {unrolledCode, wgSize, kf32}, + Bindings{inp_i, weight_i, bias_i, out_o}, + nWorkgroups, + /* params */ + MatmulParams{ + static_cast(b), + static_cast(t), + static_cast(c), + static_cast(oc) + }, + nullptr, + key.c_str() + ); + } else { + op = ctx.kernelPool.data[key]; + } + Tensor& inp_i = ctx.pool.data[op->buffers[0]]; + Tensor& weight_i = ctx.pool.data[op->buffers[1]]; + Tensor& bias_i = ctx.pool.data[op->buffers[2]]; + Tensor& out_o = ctx.pool.data[op->buffers[3]]; + + toGPU(ctx, inp, inp_i); + toGPU(ctx, weight, weight_i); + if (bias != NULL) { + toGPU(ctx, bias, bias_i); + } + std::promise promise; std::future future = promise.get_future(); if (version == 2) { DurationTime duration("matmul_forward_gpu: after creating tensors", verbose); - static constexpr size_t BT = 64; - static constexpr size_t BC = 16; - static constexpr size_t BOC = 64; - static constexpr size_t TT = BT / BC; - static constexpr size_t TOC = BOC / BC; - static constexpr size_t num_threads = BT * BOC / (TT * TOC); - Shape wgSize = {num_threads, 1, 1}; - Shape nWorkgroups = {b, cdiv(T, BT), cdiv(OC, BOC)}; - - std::string codeString(kShaderMatmul2DTiling); - replaceAll(codeString, {{"{{precision}}", toString(kf32)}, - {"{{BT}}", toString(BT)}, - {"{{BC}}", toString(BC)}, - {"{{BOC}}", toString(BOC)}, - {"{{TT}}", toString(TT)}, - {"{{TOC}}", toString(TOC)}, - {"{{NUM_TILEI}}", toString(BT * BC / num_threads)}, - {"{{NUM_TILEW}}", toString(BOC * BC / num_threads)} - }); - std::string unrolledCode = loopUnrolling(codeString); { DurationTime duration("matmul_forward_gpu: before creating kernels", verbose); - - Kernel op = createKernel(ctx, {unrolledCode, wgSize, kf32}, - Bindings{inp_i, weight_i, bias_i, out_o}, - nWorkgroups, - /* params */ - MatmulParams{ - static_cast(b), - static_cast(t), - static_cast(c), - static_cast(oc) - }); { DurationTime duration("matmul_forward_gpu without creating context", verbose); dispatchKernel(ctx, op, promise); @@ -311,23 +334,23 @@ void MATMUL_FORWARD_GPU(float* out, toCPU(ctx, out_o, out, b * t * oc * sizeof(float)); } } - } else if (version == 1) { - Kernel op = createKernel(ctx, {kShaderMatmul, 256, kf32}, - Bindings{inp_i, weight_i, bias_i, out_o}, - /* nWorkgroups */ {cdiv(b * t, 256), 1, 1}, - /* params */ - MatmulParams{ - static_cast(b), - static_cast(t), - static_cast(c), - static_cast(oc) - }); - { - DurationTime duration("matmul_forward_gpu without creating context", verbose); - dispatchKernel(ctx, op, promise); - wait(ctx, future); - toCPU(ctx, out_o, out, b * t * oc * sizeof(float)); - } + // } else if (version == 1) { + // Kernel op = createKernel(ctx, {kShaderMatmul, 256, kf32}, + // Bindings{inp_i, weight_i, bias_i, out_o}, + // /* nWorkgroups */ {cdiv(b * t, 256), 1, 1}, + // /* params */ + // MatmulParams{ + // static_cast(b), + // static_cast(t), + // static_cast(c), + // static_cast(oc) + // }); + // { + // DurationTime duration("matmul_forward_gpu without creating context", verbose); + // 
dispatchKernel(ctx, op, promise); + // wait(ctx, future); + // toCPU(ctx, out_o, out, b * t * oc * sizeof(float)); + // } } else { DurationTime duration("matmul_forward_cpu", verbose); matmul_forward_dummy(out, inp, weight, bias, B, T, C, OC); diff --git a/gpu.hpp b/gpu.hpp index 8b68463..9c34d78 100644 --- a/gpu.hpp +++ b/gpu.hpp @@ -454,7 +454,7 @@ struct CopyData { * The struct members can be divided into "consumed upon dispatch" * (commandBuffer) and reusable ahead-of-time setup (all other members). */ -struct Kernel { +struct RawKernel { std::unique_ptr buffers; // non-owning std::unique_ptr bufferSizes; size_t numBindings; @@ -462,8 +462,11 @@ struct Kernel { WGPUBindGroup bindGroup; // persists between submission WGPUComputePipeline computePipeline; // persists between submission WGPUCommandBuffer commandBuffer; // destroyed upon submission + bool used; }; +typedef std::shared_ptr Kernel; + /** * @brief A struct to package the result of a WGSL code compilation. @@ -483,7 +486,7 @@ struct CompilationInfo { * @return True if lhs < rhs, false otherwise */ inline bool operator<(const Kernel &lhs, const Kernel &rhs) { - return lhs.commandBuffer < rhs.commandBuffer; + return lhs->commandBuffer < rhs->commandBuffer; } /** @@ -494,7 +497,7 @@ inline bool operator<(const Kernel &lhs, const Kernel &rhs) { struct KernelPool { inline KernelPool(Context *ctx) : ctx(ctx), data() {} Context *ctx; - std::set data; + std::unordered_map data; inline ~KernelPool() { // Note : Some kernel resources such as commandBuffer are harvested by // queue submission, explicitly destroying readback and callback buffers @@ -1080,6 +1083,57 @@ void toCPU(Context &ctx, Tensor &tensor, std::array &data) { toCPU(ctx, tensor, data.data(), sizeof(data)); } +inline void toCPU(Context &ctx, WGPUBuffer buffer, void *data, + size_t size) { + uint64_t bufferSize = size; + CopyData op; + op.future = op.promise.get_future(); + { + WGPUBufferDescriptor readbackBufferDescriptor = { + .usage = WGPUBufferUsage_CopyDst | WGPUBufferUsage_MapRead, + .size = bufferSize, + }; + op.readbackBuffer = + wgpuDeviceCreateBuffer(ctx.device, &readbackBufferDescriptor); + } + { + WGPUCommandEncoder commandEncoder; + WGPUComputePassEncoder computePassEncoder; + commandEncoder = wgpuDeviceCreateCommandEncoder(ctx.device, nullptr); + wgpuCommandEncoderCopyBufferToBuffer(commandEncoder, buffer, 0, + op.readbackBuffer, 0, bufferSize); + op.commandBuffer = wgpuCommandEncoderFinish(commandEncoder, nullptr); + check(op.commandBuffer, "Create command buffer", __FILE__, __LINE__); + } + wgpuQueueSubmit(ctx.queue, 1, &op.commandBuffer); + CallbackData callbackData = {op.readbackBuffer, bufferSize, data, &op.promise, + &op.future}; + wgpuQueueOnSubmittedWorkDone( + ctx.queue, + [](WGPUQueueWorkDoneStatus status, void *callbackData) { + check(status == WGPUQueueWorkDoneStatus_Success, "Queue work done", + __FILE__, __LINE__); + const auto *data = static_cast(callbackData); + wgpuBufferMapAsync( + data->buffer, WGPUMapMode_Read, 0, data->bufferSize, + [](WGPUBufferMapAsyncStatus status, void *captureData) { + const auto *data = static_cast(captureData); + check(status == WGPUBufferMapAsyncStatus_Success, + "Map readbackBuffer", __FILE__, __LINE__); + const void *mappedData = wgpuBufferGetConstMappedRange( + data->buffer, /*offset=*/0, data->bufferSize); + check(mappedData, "Get mapped range", __FILE__, __LINE__); + memcpy(data->output, mappedData, data->bufferSize); + wgpuBufferUnmap(data->buffer); + data->promise->set_value(); + }, + callbackData); + }, + 
&callbackData); + wait(ctx, op.future); +} + + /** * @brief Copies data from CPU memory to a GPU buffer. The toGPU overloads are * effectively a convenience wrapper around the WebGPU API call @@ -1126,8 +1180,8 @@ inline void toGPU(Context &ctx, Params ¶ms, Kernel &op) { // TODO(avh): Maintain params metadata in Kernel and check for consistency. // If a kernel does not have parameters this will quietly overwrite // the last buffer in the bind group with the parameters buffer. - if (op.numBindings > 0) { - wgpuQueueWriteBuffer(ctx.queue, op.buffers[op.numBindings - 1], 0, + if (op->numBindings > 0) { + wgpuQueueWriteBuffer(ctx.queue, op->buffers[op->numBindings - 1], 0, static_cast(¶ms), sizeof(params)); } } @@ -1150,14 +1204,15 @@ inline void resetCommandBuffer(WGPUDevice &device, Kernel &op) { wgpuDeviceCreateCommandEncoder(device, nullptr); WGPUComputePassEncoder computePassEncoder = wgpuCommandEncoderBeginComputePass(commandEncoder, nullptr); - wgpuComputePassEncoderSetPipeline(computePassEncoder, op.computePipeline); - wgpuComputePassEncoderSetBindGroup(computePassEncoder, 0, op.bindGroup, 0, + wgpuComputePassEncoderSetPipeline(computePassEncoder, op->computePipeline); + wgpuComputePassEncoderSetBindGroup(computePassEncoder, 0, op->bindGroup, 0, nullptr); wgpuComputePassEncoderDispatchWorkgroups( - computePassEncoder, op.totalWorkgroups[0], op.totalWorkgroups[1], - op.totalWorkgroups[2]); + computePassEncoder, op->totalWorkgroups[0], op->totalWorkgroups[1], + op->totalWorkgroups[2]); wgpuComputePassEncoderEnd(computePassEncoder); - op.commandBuffer = wgpuCommandEncoderFinish(commandEncoder, nullptr); + op->commandBuffer = wgpuCommandEncoderFinish(commandEncoder, nullptr); + op->used = false; } } @@ -1219,11 +1274,19 @@ inline Kernel createKernel(Context &ctx, const KernelCode &code, const size_t *viewOffsets, const Shape &totalWorkgroups, const void *params = nullptr, size_t paramsSize = 0, - CompilationInfo* compilationInfo = nullptr) { + CompilationInfo* compilationInfo = nullptr, + const char* cacheKey = nullptr) { + // Create a cache key by the pointer values of the data bindings and the kernel code + if (cacheKey != nullptr && ctx.kernelPool.data.find(cacheKey) != ctx.kernelPool.data.end()) { + LOG(kDefLog, kInfo, "Kernel cache hit"); + return ctx.kernelPool.data[cacheKey]; + } + assert(totalWorkgroups.rank == 3); WGPUDevice device = ctx.device; WGPUQueue queue = ctx.queue; - Kernel op; + Kernel op(new RawKernel()); + // paramIndex is the index into bgLayoutEntries for the parameters buffer If // there are no parameters for the kernel, paramsSize == 0 and paramIndex is // effectively undefined (== -1) @@ -1236,9 +1299,9 @@ inline Kernel createKernel(Context &ctx, const KernelCode &code, // op.buffers, op.bufferSizes and // bgLayoutEntries } - op.buffers = std::make_unique(numBindings); - op.bufferSizes = std::make_unique(numBindings); - op.numBindings = numBindings; + op->buffers = std::make_unique(numBindings); + op->bufferSizes = std::make_unique(numBindings); + op->numBindings = numBindings; std::vector bgLayoutEntries(numBindings); // Create layout entries for input buffers for (size_t i = 0; i < numTensors; ++i) { @@ -1272,8 +1335,8 @@ inline Kernel createKernel(Context &ctx, const KernelCode &code, WGPUBindGroupLayout bgLayout = wgpuDeviceCreateBindGroupLayout(device, &bgLayoutDesc); for (size_t i = 0; i < numTensors; ++i) { - op.buffers[i] = dataBindings[i].data.buffer; - op.bufferSizes[i] = dataBindings[i].data.size; + op->buffers[i] = dataBindings[i].data.buffer; + 
op->bufferSizes[i] = dataBindings[i].data.size; } // Create a buffer for the Params struct if (paramsSize > 0) { @@ -1282,9 +1345,9 @@ inline Kernel createKernel(Context &ctx, const KernelCode &code, .size = paramsSize, .mappedAtCreation = false, }; - op.buffers[paramIndex] = wgpuDeviceCreateBuffer(device, ¶msBufferDesc); - op.bufferSizes[paramIndex] = paramsSize; - wgpuQueueWriteBuffer(queue, op.buffers[paramIndex], 0, params, paramsSize); + op->buffers[paramIndex] = wgpuDeviceCreateBuffer(device, ¶msBufferDesc); + op->bufferSizes[paramIndex] = paramsSize; + wgpuQueueWriteBuffer(queue, op->buffers[paramIndex], 0, params, paramsSize); LOG(kDefLog, kTrace, "Params buffer written"); } else { LOG(kDefLog, kTrace, "No params buffer needed"); @@ -1293,9 +1356,9 @@ inline Kernel createKernel(Context &ctx, const KernelCode &code, for (size_t i = 0; i < numTensors; ++i) { bindGroupEntries[i] = WGPUBindGroupEntry{ .binding = static_cast(i), - .buffer = op.buffers[i], + .buffer = op->buffers[i], .offset = viewOffsets[i], - .size = op.bufferSizes[i], + .size = op->bufferSizes[i], }; } if (paramsSize > 0) { @@ -1303,7 +1366,7 @@ inline Kernel createKernel(Context &ctx, const KernelCode &code, LOG(kDefLog, kInfo, "paramIndex: %d", paramIndex); bindGroupEntries[paramIndex] = WGPUBindGroupEntry{ .binding = static_cast(paramIndex), - .buffer = op.buffers[paramIndex], + .buffer = op->buffers[paramIndex], .offset = 0, .size = paramsSize, }; @@ -1314,7 +1377,7 @@ inline Kernel createKernel(Context &ctx, const KernelCode &code, .entryCount = static_cast(numBindings), .entries = bindGroupEntries.data(), }; - op.bindGroup = wgpuDeviceCreateBindGroup(device, &bindGroupDesc); + op->bindGroup = wgpuDeviceCreateBindGroup(device, &bindGroupDesc); WGPUPipelineLayoutDescriptor pipelineLayoutDesc = { .bindGroupLayoutCount = 1, @@ -1336,12 +1399,13 @@ inline Kernel createKernel(Context &ctx, const KernelCode &code, computePipelineDesc.compute.entryPoint = code.entryPoint.c_str(); computePipelineDesc.label = code.label.c_str(); - op.computePipeline = + op->computePipeline = wgpuDeviceCreateComputePipeline(device, &computePipelineDesc); - op.totalWorkgroups = {totalWorkgroups[0], totalWorkgroups[1], totalWorkgroups[2]}; + op->totalWorkgroups = {totalWorkgroups[0], totalWorkgroups[1], totalWorkgroups[2]}; resetCommandBuffer(device, op); - ctx.kernelPool.data.insert(&op); - + if (cacheKey != nullptr) + ctx.kernelPool.data[cacheKey]=op; + WGPUCompilationInfoCallback cb = [](WGPUCompilationInfoRequestStatus status, WGPUCompilationInfo const *compilationInfo, void *userData) { @@ -1396,17 +1460,20 @@ Kernel createKernel(Context &ctx, const KernelCode &code, const Bindings &dataBindings, const Shape &totalWorkgroups, const ParamsType ¶ms = ParamsType{}, - CompilationInfo* compilationInfo = nullptr + CompilationInfo* compilationInfo = nullptr, + const char* cacheKey = nullptr ) { if constexpr (!IsNoParam) { return createKernel(ctx, code, dataBindings.data.data(), numInputs, dataBindings.viewOffsets.data(), totalWorkgroups, reinterpret_cast(¶ms), - sizeof(ParamsType), compilationInfo); + sizeof(ParamsType), compilationInfo, + cacheKey); } else { return createKernel(ctx, code, dataBindings.data.data(), numInputs, dataBindings.viewOffsets.data(), totalWorkgroups, nullptr, - 0, compilationInfo); + 0, compilationInfo, + cacheKey); } } @@ -1431,7 +1498,11 @@ Kernel createKernel(Context &ctx, const KernelCode &code, inline void dispatchKernel(Context &ctx, Kernel &kernel, std::promise &promise) { // Submit the command buffer - 
wgpuQueueSubmit(ctx.queue, 1, &kernel.commandBuffer); + if (kernel->used) { + resetCommandBuffer(ctx.device, kernel); + } + wgpuQueueSubmit(ctx.queue, 1, &kernel->commandBuffer); + kernel->used = true; wgpuQueueOnSubmittedWorkDone( ctx.queue, [](WGPUQueueWorkDoneStatus status, void *data) { From 9fff8cd39a9e89a7b4633f97b788a67c232d0658 Mon Sep 17 00:00:00 2001 From: Junji Hashimoto Date: Wed, 16 Oct 2024 17:43:16 +0900 Subject: [PATCH 04/13] Add caches for ops.cpp --- experimental/kernels/ops.cpp | 639 ++++++++++++++++++++++++----------- gpu.hpp | 5 + 2 files changed, 448 insertions(+), 196 deletions(-) diff --git a/experimental/kernels/ops.cpp b/experimental/kernels/ops.cpp index b4b7555..6390314 100644 --- a/experimental/kernels/ops.cpp +++ b/experimental/kernels/ops.cpp @@ -23,27 +23,39 @@ void encoder_forward(Context& ctx, float* out, uint32_t C; }; setLogLevel(kError); - printf("Creating tensors\n"); - printf("Creating input tensor\%pn", inp); - Tensor input = createTensor(ctx, Shape{b * t}, ki32, inp); - printf("Created input tensor\n"); - Tensor wte_t = createTensor(ctx, Shape{v, c}, kf32, wte); - printf("Created wte tensor\n"); - Tensor wpe_t = createTensor(ctx, Shape{t, c}, kf32, wpe); - printf("Created wpe tensor\n"); - Tensor output = createTensor(ctx, Shape{b * t * c}, kf32); - printf("Created tensors\n"); + // Generate the key of the cache by arguments. + std::string key = "encoder_forward_" + std::to_string(B) + "_" + std::to_string(T) + "_" + std::to_string(C); + Kernel op; + if (ctx.kernelPool.data.find(key) == ctx.kernelPool.data.end()) { + Tensor input = createTensor(ctx, Shape{b * t}, ki32); + Tensor wte_t = createTensor(ctx, Shape{v, c}, kf32); + Tensor wpe_t = createTensor(ctx, Shape{t, c}, kf32); + Tensor output = createTensor(ctx, Shape{b * t * c}, kf32); + op = createKernel(ctx, {kShaderEncoder, 256, kf32}, + Bindings{input, wte_t, wpe_t, output}, + /* nWorkgroups */ {cdiv(b * t, 256), 1, 1}, + /* params */ + EncoderParams{ + static_cast(b), + static_cast(t), + static_cast(c) + }, + nullptr, + key.c_str()); + } else { + op = ctx.kernelPool.data[key]; + } + Tensor& input = ctx.pool.data[op->buffers[0]]; + Tensor& wte_t = ctx.pool.data[op->buffers[1]]; + Tensor& wpe_t = ctx.pool.data[op->buffers[2]]; + Tensor& output = ctx.pool.data[op->buffers[3]]; + + toGPU(ctx, inp, input); + toGPU(ctx, wte, wte_t); + toGPU(ctx, wpe, wpe_t); + std::promise promise; std::future future = promise.get_future(); - Kernel op = createKernel(ctx, {kShaderEncoder, 256, kf32}, - Bindings{input, wte_t, wpe_t, output}, - /* nWorkgroups */ {cdiv(b * t, 256), 1, 1}, - /* params */ - EncoderParams{ - static_cast(b), - static_cast(t), - static_cast(c) - }); dispatchKernel(ctx, op, promise); wait(ctx, future); toCPU(ctx, output, out, b * t * c * sizeof(float)); @@ -62,21 +74,38 @@ void encoder_backward(Context& ctx, float* dwte, float* dwpe, uint32_t C; }; setLogLevel(kError); - Tensor dwte_t = createTensor(ctx, Shape{v, c}, kf32, dwte); - Tensor dwpe_t = createTensor(ctx, Shape{t, c}, kf32, dwpe); - Tensor dout_t = createTensor(ctx, Shape{b * t * c}, kf32, dout); - Tensor input = createTensor(ctx, Shape{b * t}, ki32, inp); + // Generate the key of the cache by arguments. 
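+  // Same caching pattern as encoder_forward above: the key is the op name plus
+  // the shape arguments, the cached kernel's tensors are looked up through
+  // ctx.pool.data[op->buffers[i]], and fresh host data is uploaded with toGPU
+  // on every call.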
+ std::string key = "encoder_backward_" + std::to_string(B) + "_" + std::to_string(T) + "_" + std::to_string(C); + Kernel op; + if (ctx.kernelPool.data.find(key) == ctx.kernelPool.data.end()) { + Tensor dwte_t = createTensor(ctx, Shape{v, c}, kf32); + Tensor dwpe_t = createTensor(ctx, Shape{t, c}, kf32); + Tensor dout_t = createTensor(ctx, Shape{b * t * c}, kf32); + Tensor input = createTensor(ctx, Shape{b * t}, ki32); + op = createKernel(ctx, {kShaderEncoderBackward, 256, kf32}, + Bindings{dwte_t, dwpe_t, dout_t, input}, + /* nWorkgroups */ {cdiv(b * t, 256), 1, 1}, + /* params */ + EncoderParams{ + static_cast(b), + static_cast(t), + static_cast(c) + }, + nullptr, + key.c_str()); + } else { + op = ctx.kernelPool.data[key]; + } + Tensor& dwte_t = ctx.pool.data[op->buffers[0]]; + Tensor& dwpe_t = ctx.pool.data[op->buffers[1]]; + Tensor& dout_t = ctx.pool.data[op->buffers[2]]; + Tensor& input = ctx.pool.data[op->buffers[3]]; + + toGPU(ctx, dout, dout_t); + toGPU(ctx, inp, input); + std::promise promise; std::future future = promise.get_future(); - Kernel op = createKernel(ctx, {kShaderEncoderBackward, 256, kf32}, - Bindings{dwte_t, dwpe_t, dout_t, input}, - /* nWorkgroups */ {cdiv(b * t, 256), 1, 1}, - /* params */ - EncoderParams{ - static_cast(b), - static_cast(t), - static_cast(c) - }); dispatchKernel(ctx, op, promise); wait(ctx, future); toCPU(ctx, dwte_t, dwte, v * c * sizeof(float)); @@ -95,23 +124,43 @@ void layernorm_forward(Context& ctx, float* out, float* mean, float* rstd, uint32_t C; }; setLogLevel(kError); - Tensor inp_t = createTensor(ctx, Shape{b * t * c}, kf32, inp); - Tensor weight_t = createTensor(ctx, Shape{c}, kf32, weight); - Tensor bias_t = createTensor(ctx, Shape{c}, kf32, bias); - Tensor out_t = createTensor(ctx, Shape{b * t * c}, kf32); - Tensor mean_t = createTensor(ctx, Shape{b * t}, kf32); - Tensor rstd_t = createTensor(ctx, Shape{b * t}, kf32); + // Generate the key of the cache by arguments. 
+ std::string key = "layernorm_forward_" + std::to_string(B) + "_" + std::to_string(T) + "_" + std::to_string(C); + Kernel op; + if (ctx.kernelPool.data.find(key) == ctx.kernelPool.data.end()) { + Tensor inp_t = createTensor(ctx, Shape{b * t * c}, kf32); + Tensor weight_t = createTensor(ctx, Shape{c}, kf32); + Tensor bias_t = createTensor(ctx, Shape{c}, kf32); + Tensor out_t = createTensor(ctx, Shape{b * t * c}, kf32); + Tensor mean_t = createTensor(ctx, Shape{b * t}, kf32); + Tensor rstd_t = createTensor(ctx, Shape{b * t}, kf32); + op = createKernel(ctx, {kShaderLayerNorm, 256, kf32}, + Bindings{inp_t, weight_t, bias_t, out_t, mean_t, rstd_t}, + /* nWorkgroups */ {cdiv(b * t, 256), 1, 1}, + /* params */ + LayerNormParams{ + static_cast(b), + static_cast(t), + static_cast(c) + }, + nullptr, + key.c_str()); + } else { + op = ctx.kernelPool.data[key]; + } + Tensor& inp_t = ctx.pool.data[op->buffers[0]]; + Tensor& weight_t = ctx.pool.data[op->buffers[1]]; + Tensor& bias_t = ctx.pool.data[op->buffers[2]]; + Tensor& out_t = ctx.pool.data[op->buffers[3]]; + Tensor& mean_t = ctx.pool.data[op->buffers[4]]; + Tensor& rstd_t = ctx.pool.data[op->buffers[5]]; + + toGPU(ctx, inp, inp_t); + toGPU(ctx, weight, weight_t); + toGPU(ctx, bias, bias_t); + std::promise promise; std::future future = promise.get_future(); - Kernel op = createKernel(ctx, {kShaderLayerNorm, 256, kf32}, - Bindings{inp_t, weight_t, bias_t, out_t, mean_t, rstd_t}, - /* nWorkgroups */ {cdiv(b * t, 256), 1, 1}, - /* params */ - LayerNormParams{ - static_cast(b), - static_cast(t), - static_cast(c) - }); dispatchKernel(ctx, op, promise); wait(ctx, future); toCPU(ctx, out_t, out, b * t * c * sizeof(float)); @@ -131,25 +180,49 @@ void layernorm_backward(Context& ctx, float* dinp, float* dweight, float* dbias, uint32_t C; }; setLogLevel(kError); - Tensor dinp_t = createTensor(ctx, Shape{b * t * c}, kf32, dinp); - Tensor dweight_t = createTensor(ctx, Shape{c}, kf32, dweight); - Tensor dbias_t = createTensor(ctx, Shape{c}, kf32, dbias); - Tensor dout_t = createTensor(ctx, Shape{b * t * c}, kf32, dout); - Tensor inp_t = createTensor(ctx, Shape{b * t * c}, kf32, inp); - Tensor weight_t = createTensor(ctx, Shape{c}, kf32, weight); - Tensor mean_t = createTensor(ctx, Shape{b * t}, kf32, mean); - Tensor rstd_t = createTensor(ctx, Shape{b * t}, kf32, rstd); + // Generate the key of the cache by arguments. 
+ std::string key = "layernorm_backward_" + std::to_string(B) + "_" + std::to_string(T) + "_" + std::to_string(C); + Kernel op; + if (ctx.kernelPool.data.find(key) == ctx.kernelPool.data.end()) { + Tensor dinp_t = createTensor(ctx, Shape{b * t * c}, kf32); + Tensor dweight_t = createTensor(ctx, Shape{c}, kf32); + Tensor dbias_t = createTensor(ctx, Shape{c}, kf32); + Tensor dout_t = createTensor(ctx, Shape{b * t * c}, kf32); + Tensor inp_t = createTensor(ctx, Shape{b * t * c}, kf32); + Tensor weight_t = createTensor(ctx, Shape{c}, kf32); + Tensor mean_t = createTensor(ctx, Shape{b * t}, kf32); + Tensor rstd_t = createTensor(ctx, Shape{b * t}, kf32); + op = createKernel(ctx, {kShaderLayerNormBackward, 256, kf32}, + Bindings{dinp_t, dweight_t, dbias_t, dout_t, inp_t, weight_t, mean_t, rstd_t}, + /* nWorkgroups */ {cdiv(b * t, 256), 1, 1}, + /* params */ + LayerNormParams{ + static_cast(b), + static_cast(t), + static_cast(c) + }, + nullptr, + key.c_str()); + } else { + op = ctx.kernelPool.data[key]; + } + Tensor& dinp_t = ctx.pool.data[op->buffers[0]]; + Tensor& dweight_t = ctx.pool.data[op->buffers[1]]; + Tensor& dbias_t = ctx.pool.data[op->buffers[2]]; + Tensor& dout_t = ctx.pool.data[op->buffers[3]]; + Tensor& inp_t = ctx.pool.data[op->buffers[4]]; + Tensor& weight_t = ctx.pool.data[op->buffers[5]]; + Tensor& mean_t = ctx.pool.data[op->buffers[6]]; + Tensor& rstd_t = ctx.pool.data[op->buffers[7]]; + + toGPU(ctx, dout, dout_t); + toGPU(ctx, inp, inp_t); + toGPU(ctx, weight, weight_t); + toGPU(ctx, mean, mean_t); + toGPU(ctx, rstd, rstd_t); + std::promise promise; std::future future = promise.get_future(); - Kernel op = createKernel(ctx, {kShaderLayerNormBackward, 256, kf32}, - Bindings{dinp_t, dweight_t, dbias_t, dout_t, inp_t, weight_t, mean_t, rstd_t}, - /* nWorkgroups */ {cdiv(b * t, 256), 1, 1}, - /* params */ - LayerNormParams{ - static_cast(b), - static_cast(t), - static_cast(c) - }); dispatchKernel(ctx, op, promise); wait(ctx, future); toCPU(ctx, dinp_t, dinp, b * t * c * sizeof(float)); @@ -219,46 +292,51 @@ void matmul_forward(Context& ctx, float* out, unsigned long oc = static_cast(OC); setLogLevel(kError); - Tensor inp_i = createTensor(ctx, Shape{b * t * c}, kf32, inp); - Tensor weight_i = createTensor(ctx, Shape{oc * c}, kf32, weight); - Tensor bias_i = bias == NULL ? createTensor(ctx, Shape{1}, kf32) : createTensor(ctx, Shape{oc}, kf32, bias); - Tensor out_o = createTensor(ctx, Shape{b * t * oc}, kf32); + // Generate the key of the cache by arguments. + std::string key = "matmul_forward_" + std::to_string(B) + "_" + std::to_string(T) + "_" + std::to_string(C) + "_" + std::to_string(OC); + Kernel op; + if (ctx.kernelPool.data.find(key) == ctx.kernelPool.data.end()) { + Shape nWorkgroups = {b, cdiv(T, MATMUL_BT), cdiv(OC, MATMUL_BOC)}; + Tensor inp_i = createTensor(ctx, Shape{b * t * c}, kf32); + Tensor weight_i = createTensor(ctx, Shape{oc * c}, kf32); + Tensor bias_i = bias == NULL ? 
createTensor(ctx, Shape{1}, kf32) : createTensor(ctx, Shape{oc}, kf32); + Tensor out_o = createTensor(ctx, Shape{b * t * oc}, kf32); + op = createKernel(ctx, {kShaderMatmul2D, MATMUL_wgSize, kf32}, + Bindings{inp_i, weight_i, bias_i, out_o}, + nWorkgroups, + /* params */ + MatmulParams{ + static_cast(b), + static_cast(t), + static_cast(c), + static_cast(oc) + }, + nullptr, + key.c_str()); + } else { + op = ctx.kernelPool.data[key]; + } + Tensor& inp_i = ctx.pool.data[op->buffers[0]]; + Tensor& weight_i = ctx.pool.data[op->buffers[1]]; + Tensor& bias_i = ctx.pool.data[op->buffers[2]]; + Tensor& out_o = ctx.pool.data[op->buffers[3]]; + + toGPU(ctx, inp, inp_i); + toGPU(ctx, weight, weight_i); + if (bias != NULL) { + toGPU(ctx, bias, bias_i); + } + std::promise promise; std::future future = promise.get_future(); - assert ( (b*t) % 256 == 0 ); - int version = 1; - if (version == 1){ + { DurationTime duration("matmul_forward_gpu without creating tensors", verbose); - Shape nWorkgroups = {b, cdiv(T, MATMUL_BT), cdiv(OC, MATMUL_BOC)}; - Kernel op = createKernel(ctx, {kShaderMatmul2D, MATMUL_wgSize, kf32}, - Bindings{inp_i, weight_i, bias_i, out_o}, - nWorkgroups, - /* params */ - MatmulParams{ - static_cast(b), - static_cast(t), - static_cast(c), - static_cast(oc) - }); { DurationTime duration("matmul_forward_gpu without creating kernel", verbose); dispatchKernel(ctx, op, promise); wait(ctx, future); toCPU(ctx, out_o, out, b * t * oc * sizeof(float)); } - } else { - Kernel op = createKernel(ctx, {kShaderMatmul, 256, kf32}, - Bindings{inp_i, weight_i, bias_i, out_o}, - /* nWorkgroups */ {cdiv(b * t, 256), 1, 1}, - /* params */ - MatmulParams{ - static_cast(b), - static_cast(t), - static_cast(c), - static_cast(oc) - }); - dispatchKernel(ctx, op, promise); - wait(ctx, future); } toCPU(ctx, out_o, out, b * t * oc * sizeof(float)); } @@ -277,24 +355,44 @@ void matmul_backward(Context& ctx, float* dinp, float* dweight, float* dbias, unsigned long c = static_cast(C); unsigned long oc = static_cast(OC); setLogLevel(kError); - Tensor dinp_t = createTensor(ctx, Shape{b * t * c}, kf32, dinp); - Tensor dweight_t = createTensor(ctx, Shape{oc * c}, kf32, dweight); - Tensor dbias_t = createTensor(ctx, Shape{oc}, kf32, dbias); - Tensor dout_t = createTensor(ctx, Shape{b * t * oc}, kf32, dout); - Tensor inp_t = createTensor(ctx, Shape{b * t * c}, kf32, inp); - Tensor weight_t = createTensor(ctx, Shape{oc * c}, kf32, weight); + // Generate the key of the cache by arguments. 
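+  // Note: only the forward pass uses the 2D-tiled kShaderMatmul2D kernel; the
+  // backward pass below still dispatches the plain kShaderMatmulBackward kernel
+  // with 256-thread workgroups over ceil(B*T / 256) groups.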
+ std::string key = "matmul_backward_" + std::to_string(B) + "_" + std::to_string(T) + "_" + std::to_string(C) + "_" + std::to_string(OC); + Kernel op; + if (ctx.kernelPool.data.find(key) == ctx.kernelPool.data.end()) { + Tensor dinp_t = createTensor(ctx, Shape{b * t * c}, kf32); + Tensor dweight_t = createTensor(ctx, Shape{oc * c}, kf32); + Tensor dbias_t = createTensor(ctx, Shape{oc}, kf32); + Tensor dout_t = createTensor(ctx, Shape{b * t * oc}, kf32); + Tensor inp_t = createTensor(ctx, Shape{b * t * c}, kf32); + Tensor weight_t = createTensor(ctx, Shape{oc * c}, kf32); + op = createKernel(ctx, {kShaderMatmulBackward, 256, kf32}, + Bindings{dinp_t, dweight_t, dbias_t, dout_t, inp_t, weight_t}, + /* nWorkgroups */ {cdiv(b * t, 256), 1, 1}, + /* params */ + MatmulParams{ + static_cast(b), + static_cast(t), + static_cast(c), + static_cast(oc) + }, + nullptr, + key.c_str()); + } else { + op = ctx.kernelPool.data[key]; + } + Tensor& dinp_t = ctx.pool.data[op->buffers[0]]; + Tensor& dweight_t = ctx.pool.data[op->buffers[1]]; + Tensor& dbias_t = ctx.pool.data[op->buffers[2]]; + Tensor& dout_t = ctx.pool.data[op->buffers[3]]; + Tensor& inp_t = ctx.pool.data[op->buffers[4]]; + Tensor& weight_t = ctx.pool.data[op->buffers[5]]; + + toGPU(ctx, dout, dout_t); + toGPU(ctx, inp, inp_t); + toGPU(ctx, weight, weight_t); + std::promise promise; std::future future = promise.get_future(); - Kernel op = createKernel(ctx, {kShaderMatmulBackward, 256, kf32}, - Bindings{dinp_t, dweight_t, dbias_t, dout_t, inp_t, weight_t}, - /* nWorkgroups */ {cdiv(b * t, 256), 1, 1}, - /* params */ - MatmulParams{ - static_cast(b), - static_cast(t), - static_cast(c), - static_cast(oc) - }); dispatchKernel(ctx, op, promise); wait(ctx, future); toCPU(ctx, dinp_t, dinp, b * t * c * sizeof(float)); @@ -316,22 +414,38 @@ void attention_forward(Context& ctx, float* out, float* preatt, float* att, unsigned long c = static_cast(C); unsigned long nh = static_cast(NH); setLogLevel(kError); - Tensor inp_t = createTensor(ctx, Shape{b * t * c * 3}, kf32, inp); - Tensor preatt_t = createTensor(ctx, Shape{b * nh * t * t}, kf32, preatt); - Tensor att_t = createTensor(ctx, Shape{b * nh * t * t}, kf32, att); - Tensor out_t = createTensor(ctx, Shape{b * t * c}, kf32); + // Generate the key of the cache by arguments. 
+ std::string key = "attention_forward_" + std::to_string(B) + "_" + std::to_string(T) + "_" + std::to_string(C) + "_" + std::to_string(NH); + Kernel op; + if (ctx.kernelPool.data.find(key) == ctx.kernelPool.data.end()) { + Tensor inp_t = createTensor(ctx, Shape{b * t * c * 3}, kf32); + Tensor preatt_t = createTensor(ctx, Shape{b * nh * t * t}, kf32); + Tensor att_t = createTensor(ctx, Shape{b * nh * t * t}, kf32); + Tensor out_t = createTensor(ctx, Shape{b * t * c}, kf32); + op = createKernel(ctx, {kShaderAttention, 256, kf32}, + Bindings{inp_t, preatt_t, att_t, out_t}, + /* nWorkgroups */ {cdiv(b * t, 256), 1, 1}, + /* params */ + AttentionParams{ + static_cast(b), + static_cast(t), + static_cast(c), + static_cast(nh) + }, + nullptr, + key.c_str()); + } else { + op = ctx.kernelPool.data[key]; + } + Tensor& inp_t = ctx.pool.data[op->buffers[0]]; + Tensor& preatt_t = ctx.pool.data[op->buffers[1]]; + Tensor& att_t = ctx.pool.data[op->buffers[2]]; + Tensor& out_t = ctx.pool.data[op->buffers[3]]; + + toGPU(ctx, inp, inp_t); + std::promise promise; std::future future = promise.get_future(); - Kernel op = createKernel(ctx, {kShaderAttention, 256, kf32}, - Bindings{inp_t, preatt_t, att_t, out_t}, - /* nWorkgroups */ {cdiv(b * t, 256), 1, 1}, - /* params */ - AttentionParams{ - static_cast(b), - static_cast(t), - static_cast(c), - static_cast(nh) - }); dispatchKernel(ctx, op, promise); wait(ctx, future); toCPU(ctx, preatt_t, preatt, b * nh * t * t * sizeof(float)); @@ -353,24 +467,44 @@ void attention_backward(Context& ctx, float* dinp, float* dpreatt, float* datt, unsigned long c = static_cast(C); unsigned long nh = static_cast(NH); setLogLevel(kError); - Tensor dinp_t = createTensor(ctx, Shape{b * t * c * 3}, kf32, dinp); - Tensor dpreatt_t = createTensor(ctx, Shape{b * nh * t * t}, kf32, dpreatt); - Tensor datt_t = createTensor(ctx, Shape{b * nh * t * t}, kf32, datt); - Tensor dout_t = createTensor(ctx, Shape{b * t * c}, kf32, dout); - Tensor inp_t = createTensor(ctx, Shape{b * t * c * 3}, kf32, inp); - Tensor att_t = createTensor(ctx, Shape{b * nh * t * t}, kf32, att); + // Generate the key of the cache by arguments. 
+ std::string key = "attention_backward_" + std::to_string(B) + "_" + std::to_string(T) + "_" + std::to_string(C) + "_" + std::to_string(NH); + Kernel op; + if (ctx.kernelPool.data.find(key) == ctx.kernelPool.data.end()) { + Tensor dinp_t = createTensor(ctx, Shape{b * t * c * 3}, kf32); + Tensor dpreatt_t = createTensor(ctx, Shape{b * nh * t * t}, kf32); + Tensor datt_t = createTensor(ctx, Shape{b * nh * t * t}, kf32); + Tensor dout_t = createTensor(ctx, Shape{b * t * c}, kf32); + Tensor inp_t = createTensor(ctx, Shape{b * t * c * 3}, kf32); + Tensor att_t = createTensor(ctx, Shape{b * nh * t * t}, kf32); + op = createKernel(ctx, {kShaderAttentionBackward, 256, kf32}, + Bindings{dinp_t, dpreatt_t, datt_t, dout_t, inp_t, att_t}, + /* nWorkgroups */ {cdiv(b * t, 256), 1, 1}, + /* params */ + AttentionParams{ + static_cast(b), + static_cast(t), + static_cast(c), + static_cast(nh) + }, + nullptr, + key.c_str()); + } else { + op = ctx.kernelPool.data[key]; + } + Tensor& dinp_t = ctx.pool.data[op->buffers[0]]; + Tensor& dpreatt_t = ctx.pool.data[op->buffers[1]]; + Tensor& datt_t = ctx.pool.data[op->buffers[2]]; + Tensor& dout_t = ctx.pool.data[op->buffers[3]]; + Tensor& inp_t = ctx.pool.data[op->buffers[4]]; + Tensor& att_t = ctx.pool.data[op->buffers[5]]; + + toGPU(ctx, dout, dout_t); + toGPU(ctx, inp, inp_t); + toGPU(ctx, att, att_t); + std::promise promise; std::future future = promise.get_future(); - Kernel op = createKernel(ctx, {kShaderAttentionBackward, 256, kf32}, - Bindings{dinp_t, dpreatt_t, datt_t, dout_t, inp_t, att_t}, - /* nWorkgroups */ {cdiv(b * t, 256), 1, 1}, - /* params */ - AttentionParams{ - static_cast(b), - static_cast(t), - static_cast(c), - static_cast(nh) - }); dispatchKernel(ctx, op, promise); wait(ctx, future); toCPU(ctx, dinp_t, dinp, b * t * c * 3 * sizeof(float)); @@ -381,13 +515,28 @@ void attention_backward(Context& ctx, float* dinp, float* dpreatt, float* datt, void gelu_forward(Context& ctx, float* out, float* inp, int n) { unsigned long N = static_cast(n); setLogLevel(kError); - Tensor input = createTensor(ctx, Shape{N}, kf32, inp); - Tensor output = createTensor(ctx, Shape{N}, kf32); + // Generate the key of the cache by arguments. + std::string key = "gelu_forward_" + std::to_string(n); + Kernel op; + if (ctx.kernelPool.data.find(key) == ctx.kernelPool.data.end()) { + Tensor input = createTensor(ctx, Shape{N}, kf32); + Tensor output = createTensor(ctx, Shape{N}, kf32); + op = createKernel(ctx, {kShaderGelu, 256, kf32}, + Bindings{input, output}, + /* nWorkgroups */ {cdiv(N, 256), 1, 1}, + nullptr, + nullptr, + key.c_str()); + } else { + op = ctx.kernelPool.data[key]; + } + Tensor& input = ctx.pool.data[op->buffers[0]]; + Tensor& output = ctx.pool.data[op->buffers[1]]; + + toGPU(ctx, inp, input); + std::promise promise; std::future future = promise.get_future(); - Kernel op = createKernel(ctx, {kShaderGelu, 256, kf32}, - Bindings{input, output}, - /* nWorkgroups */ {cdiv(N, 256), 1, 1}); dispatchKernel(ctx, op, promise); wait(ctx, future); toCPU(ctx, output, out, N * sizeof(float)); @@ -396,14 +545,31 @@ void gelu_forward(Context& ctx, float* out, float* inp, int n) { void gelu_backward(Context& ctx, float* dinp, float* inp, float* dout, int N){ unsigned long n = static_cast(N); setLogLevel(kError); - Tensor inp_i = createTensor(ctx, Shape{n}, kf32, inp); - Tensor dout_i = createTensor(ctx, Shape{n}, kf32, dout); - Tensor dinp_o = createTensor(ctx, Shape{n}, kf32, dinp); + // Generate the key of the cache by arguments. 
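+  // Element-wise ops such as gelu_backward and the residual kernels key the
+  // cache on the flattened length N alone, since their launch geometry depends
+  // only on N.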
+ std::string key = "gelu_backward_" + std::to_string(N); + Kernel op; + if (ctx.kernelPool.data.find(key) == ctx.kernelPool.data.end()) { + Tensor inp_i = createTensor(ctx, Shape{n}, kf32); + Tensor dout_i = createTensor(ctx, Shape{n}, kf32); + Tensor dinp_o = createTensor(ctx, Shape{n}, kf32); + op = createKernel(ctx, {kShaderGeluBackward, 256, kf32}, + Bindings{inp_i, dout_i, dinp_o}, + /* nWorkgroups */ {cdiv(n, 256), 1, 1}, + nullptr, + nullptr, + key.c_str()); + } else { + op = ctx.kernelPool.data[key]; + } + Tensor& inp_i = ctx.pool.data[op->buffers[0]]; + Tensor& dout_i = ctx.pool.data[op->buffers[1]]; + Tensor& dinp_o = ctx.pool.data[op->buffers[2]]; + + toGPU(ctx, inp, inp_i); + toGPU(ctx, dout, dout_i); + std::promise promise; std::future future = promise.get_future(); - Kernel op = createKernel(ctx, {kShaderGeluBackward, 256, kf32}, - Bindings{inp_i, dout_i, dinp_o}, - /* nWorkgroups */ {cdiv(n, 256), 1, 1}); dispatchKernel(ctx, op, promise); wait(ctx, future); toCPU(ctx, dinp_o, dinp, n * sizeof(float)); @@ -412,14 +578,31 @@ void gelu_backward(Context& ctx, float* dinp, float* inp, float* dout, int N){ void residual_forward(Context& ctx, float* out, float* inp1, float* inp2, int N){ unsigned long n = static_cast(N); setLogLevel(kError); - Tensor inp1_i = createTensor(ctx, Shape{n}, kf32, inp1); - Tensor inp2_i = createTensor(ctx, Shape{n}, kf32, inp2); - Tensor out_o = createTensor(ctx, Shape{n}, kf32); + // Generate the key of the cache by arguments. + std::string key = "residual_forward_" + std::to_string(N); + Kernel op; + if (ctx.kernelPool.data.find(key) == ctx.kernelPool.data.end()) { + Tensor inp1_i = createTensor(ctx, Shape{n}, kf32); + Tensor inp2_i = createTensor(ctx, Shape{n}, kf32); + Tensor out_o = createTensor(ctx, Shape{n}, kf32); + op = createKernel(ctx, {kShaderResidual, 256, kf32}, + Bindings{inp1_i, inp2_i, out_o}, + /* nWorkgroups */ {cdiv(n, 256), 1, 1}, + nullptr, + nullptr, + key.c_str()); + } else { + op = ctx.kernelPool.data[key]; + } + Tensor& inp1_i = ctx.pool.data[op->buffers[0]]; + Tensor& inp2_i = ctx.pool.data[op->buffers[1]]; + Tensor& out_o = ctx.pool.data[op->buffers[2]]; + + toGPU(ctx, inp1, inp1_i); + toGPU(ctx, inp2, inp2_i); + std::promise promise; std::future future = promise.get_future(); - Kernel op = createKernel(ctx, {kShaderResidual, 256, kf32}, - Bindings{inp1_i, inp2_i, out_o}, - /* nWorkgroups */ {cdiv(n, 256), 1, 1}); dispatchKernel(ctx, op, promise); wait(ctx, future); toCPU(ctx, out_o, out, n * sizeof(float)); @@ -428,14 +611,30 @@ void residual_forward(Context& ctx, float* out, float* inp1, float* inp2, int N) void residual_backward(Context& ctx, float* dinp1, float* dinp2, float* dout, int N){ unsigned long n = static_cast(N); setLogLevel(kError); - Tensor dout_i = createTensor(ctx, Shape{n}, kf32, dout); - Tensor dinp1_o = createTensor(ctx, Shape{n}, kf32, dinp1); - Tensor dinp2_o = createTensor(ctx, Shape{n}, kf32, dinp2); + // Generate the key of the cache by arguments. 
+ std::string key = "residual_backward_" + std::to_string(N); + Kernel op; + if (ctx.kernelPool.data.find(key) == ctx.kernelPool.data.end()) { + Tensor dout_i = createTensor(ctx, Shape{n}, kf32); + Tensor dinp1_o = createTensor(ctx, Shape{n}, kf32); + Tensor dinp2_o = createTensor(ctx, Shape{n}, kf32); + op = createKernel(ctx, {kShaderResidualBackward, 256, kf32}, + Bindings{dout_i, dinp1_o, dinp2_o}, + /* nWorkgroups */ {cdiv(n, 256), 1, 1}, + nullptr, + nullptr, + key.c_str()); + } else { + op = ctx.kernelPool.data[key]; + } + Tensor& dout_i = ctx.pool.data[op->buffers[0]]; + Tensor& dinp1_o = ctx.pool.data[op->buffers[1]]; + Tensor& dinp2_o = ctx.pool.data[op->buffers[2]]; + + toGPU(ctx, dout, dout_i); + std::promise promise; std::future future = promise.get_future(); - Kernel op = createKernel(ctx, {kShaderResidualBackward, 256, kf32}, - Bindings{dout_i, dinp1_o, dinp2_o}, - /* nWorkgroups */ {cdiv(n, 256), 1, 1}); dispatchKernel(ctx, op, promise); wait(ctx, future); toCPU(ctx, dinp1_o, dinp1, n * sizeof(float)); @@ -452,14 +651,28 @@ void softmax_forward(Context& ctx, float* probs, float* logits, int B, int T, in uint32_t t = static_cast(T); uint32_t c = static_cast(V); uint32_t cp = static_cast(Vp); - Tensor input = createTensor(ctx, {b * t, cp}, kf32, logits); - Tensor output = createTensor(ctx, {b * t, cp}, kf32); + // Generate the key of the cache by arguments. + std::string key = "softmax_forward_" + std::to_string(B) + "_" + std::to_string(T) + "_" + std::to_string(V) + "_" + std::to_string(Vp); + Kernel op; + if (ctx.kernelPool.data.find(key) == ctx.kernelPool.data.end()) { + Tensor input = createTensor(ctx, {b * t, cp}, kf32); + Tensor output = createTensor(ctx, {b * t, cp}, kf32); + assert( (B*T) % 256 == 0); + op = createKernel( + ctx, {kShaderSoftmax1, 256, kf32}, Bindings{input, output}, + Shape{cdiv(B * T, 256), 1, 1}, SoftmaxParam{b * t, c, cp}, + nullptr, + key.c_str()); + } else { + op = ctx.kernelPool.data[key]; + } + Tensor& input = ctx.pool.data[op->buffers[0]]; + Tensor& output = ctx.pool.data[op->buffers[1]]; + + toGPU(ctx, logits, input); + std::promise promise; std::future future = promise.get_future(); - assert( (B*T) % 256 == 0); - Kernel op = createKernel( - ctx, {kShaderSoftmax1, 256, kf32}, Bindings{input, output}, - Shape{cdiv(B * T, 256), 1, 1}, SoftmaxParam{b * t, c, cp}); dispatchKernel(ctx, op, promise); wait(ctx, future); toCPU(ctx, output, probs, sizeof(float)*b*t*cp); @@ -477,20 +690,36 @@ void crossentropy_forward(Context& ctx, float* losses, unsigned long t = static_cast(T); unsigned long vp = static_cast(Vp); setLogLevel(kError); - Tensor losses_t = createTensor(ctx, Shape{b * t}, kf32, losses); - Tensor probs_t = createTensor(ctx, Shape{b * t * vp}, kf32, probs); - Tensor targets_t = createTensor(ctx, Shape{b * t}, ki32, targets); + // Generate the key of the cache by arguments. 
+ std::string key = "crossentropy_forward_" + std::to_string(B) + "_" + std::to_string(T) + "_" + std::to_string(Vp); + Kernel op; + if (ctx.kernelPool.data.find(key) == ctx.kernelPool.data.end()) { + Tensor losses_t = createTensor(ctx, Shape{b * t}, kf32); + Tensor probs_t = createTensor(ctx, Shape{b * t * vp}, kf32); + Tensor targets_t = createTensor(ctx, Shape{b * t}, ki32); + op = createKernel(ctx, {kShaderCrossEntropyForward, 256, kf32}, + Bindings{losses_t, probs_t, targets_t}, + /* nWorkgroups */ {cdiv(b * t, 256), 1, 1}, + /* params */ + CrossEntropyParams{ + static_cast(b), + static_cast(t), + static_cast(vp) + }, + nullptr, + key.c_str()); + } else { + op = ctx.kernelPool.data[key]; + } + Tensor& losses_t = ctx.pool.data[op->buffers[0]]; + Tensor& probs_t = ctx.pool.data[op->buffers[1]]; + Tensor& targets_t = ctx.pool.data[op->buffers[2]]; + + toGPU(ctx, probs, probs_t); + toGPU(ctx, targets, targets_t); + std::promise promise; std::future future = promise.get_future(); - Kernel op = createKernel(ctx, {kShaderCrossEntropyForward, 256, kf32}, - Bindings{losses_t, probs_t, targets_t}, - /* nWorkgroups */ {cdiv(b * t, 256), 1, 1}, - /* params */ - CrossEntropyParams{ - static_cast(b), - static_cast(t), - static_cast(vp) - }); dispatchKernel(ctx, op, promise); wait(ctx, future); toCPU(ctx, losses_t, losses, b * t * sizeof(float)); @@ -510,22 +739,40 @@ void crossentropy_softmax_backward(Context& ctx, float* dlogits, unsigned long v = static_cast(V); unsigned long vp = static_cast(Vp); setLogLevel(kError); - Tensor dlogits_t = createTensor(ctx, Shape{b * t * vp}, kf32, dlogits); - Tensor dlosses_t = createTensor(ctx, Shape{b * t}, kf32, dlosses); - Tensor probs_t = createTensor(ctx, Shape{b * t * vp}, kf32, probs); - Tensor targets_t = createTensor(ctx, Shape{b * t}, ki32, targets); + // Generate the key of the cache by arguments. 
+ std::string key = "crossentropy_softmax_backward_" + std::to_string(B) + "_" + std::to_string(T) + "_" + std::to_string(V) + "_" + std::to_string(Vp); + Kernel op; + if (ctx.kernelPool.data.find(key) == ctx.kernelPool.data.end()) { + Tensor dlogits_t = createTensor(ctx, Shape{b * t * vp}, kf32); + Tensor dlosses_t = createTensor(ctx, Shape{b * t}, kf32); + Tensor probs_t = createTensor(ctx, Shape{b * t * vp}, kf32); + Tensor targets_t = createTensor(ctx, Shape{b * t}, ki32); + op = createKernel(ctx, {kShaderCrossEntropySoftmaxBackward, 256, kf32}, + Bindings{dlogits_t, dlosses_t, probs_t, targets_t}, + /* nWorkgroups */ {cdiv(b * t, 256), 1, 1}, + /* params */ + CrossEntropySoftmaxBackwardParams{ + static_cast(b), + static_cast(t), + static_cast(v), + static_cast(vp) + }, + nullptr, + key.c_str()); + } else { + op = ctx.kernelPool.data[key]; + } + Tensor& dlogits_t = ctx.pool.data[op->buffers[0]]; + Tensor& dlosses_t = ctx.pool.data[op->buffers[1]]; + Tensor& probs_t = ctx.pool.data[op->buffers[2]]; + Tensor& targets_t = ctx.pool.data[op->buffers[3]]; + + toGPU(ctx, dlosses, dlosses_t); + toGPU(ctx, probs, probs_t); + toGPU(ctx, targets, targets_t); + std::promise promise; std::future future = promise.get_future(); - Kernel op = createKernel(ctx, {kShaderCrossEntropySoftmaxBackward, 256, kf32}, - Bindings{dlogits_t, dlosses_t, probs_t, targets_t}, - /* nWorkgroups */ {cdiv(b * t, 256), 1, 1}, - /* params */ - CrossEntropySoftmaxBackwardParams{ - static_cast(b), - static_cast(t), - static_cast(v), - static_cast(vp) - }); dispatchKernel(ctx, op, promise); wait(ctx, future); toCPU(ctx, dlogits_t, dlogits, b * t * vp * sizeof(float)); diff --git a/gpu.hpp b/gpu.hpp index 9c34d78..8f8bd9b 100644 --- a/gpu.hpp +++ b/gpu.hpp @@ -1175,6 +1175,11 @@ inline void toGPU(Context &ctx, const half *data, Tensor &tensor) { tensor.data.size); } +inline void toGPU(Context &ctx, const int *data, Tensor &tensor) { + wgpuQueueWriteBuffer(ctx.queue, tensor.data.buffer, 0, data, + tensor.data.size); +} + template inline void toGPU(Context &ctx, Params ¶ms, Kernel &op) { // TODO(avh): Maintain params metadata in Kernel and check for consistency. 
From b9fb38bb6d4f8e5abc8783c11f3c7ea6e5d059a2 Mon Sep 17 00:00:00 2001 From: Junji Hashimoto Date: Wed, 16 Oct 2024 18:10:59 +0900 Subject: [PATCH 05/13] Add caches for unittests.cpp --- .../unittest_llmc/unittest_kernels.cpp | 652 ++++++++++++------ 1 file changed, 446 insertions(+), 206 deletions(-) diff --git a/experimental/kernels/unittest_llmc/unittest_kernels.cpp b/experimental/kernels/unittest_llmc/unittest_kernels.cpp index 4593bf5..a883124 100644 --- a/experimental/kernels/unittest_llmc/unittest_kernels.cpp +++ b/experimental/kernels/unittest_llmc/unittest_kernels.cpp @@ -75,6 +75,11 @@ struct DurationTime { } }; +static WGPURequiredLimits requiredLimits = LIMITS_BUFFER_SIZE_1GB; +static Context ctx = createContext({},{},{ + .requiredLimits = &requiredLimits + }); + void ENCODER_FORWARD_GPU(float* out, int* inp, float* wte, float* wpe, int B, int T, int C){ @@ -88,25 +93,40 @@ void ENCODER_FORWARD_GPU(float* out, uint32_t C; }; setLogLevel(kError); - WGPURequiredLimits requiredLimits = LIMITS_BUFFER_SIZE_1GB; - Context ctx = createContext({},{},{ - .requiredLimits = &requiredLimits - }); - Tensor input = createTensor(ctx, Shape{b * t}, ki32, inp); - Tensor wte_t = createTensor(ctx, Shape{v, c}, kf32, wte); - Tensor wpe_t = createTensor(ctx, Shape{t, c}, kf32, wpe); - Tensor output = createTensor(ctx, Shape{b * t * c}, kf32); + + // Generate the key of the cache by arguments. + std::string key = "ENCODER_FORWARD_GPU_" + std::to_string(B) + "_" + std::to_string(T) + "_" + std::to_string(C); + Kernel op; + if (ctx.kernelPool.data.find(key) == ctx.kernelPool.data.end()) { + Tensor input = createTensor(ctx, Shape{b * t}, ki32); + Tensor wte_t = createTensor(ctx, Shape{v, c}, kf32); + Tensor wpe_t = createTensor(ctx, Shape{t, c}, kf32); + Tensor output = createTensor(ctx, Shape{b * t * c}, kf32); + op = createKernel(ctx, {kShaderEncoder, 256, kf32}, + Bindings{input, wte_t, wpe_t, output}, + /* nWorkgroups */ {cdiv(b * t, 256), 1, 1}, + /* params */ + EncoderParams{ + static_cast(b), + static_cast(t), + static_cast(c) + }, + nullptr, + key.c_str()); + } else { + op = ctx.kernelPool.data[key]; + } + Tensor& input = ctx.pool.data[op->buffers[0]]; + Tensor& wte_t = ctx.pool.data[op->buffers[1]]; + Tensor& wpe_t = ctx.pool.data[op->buffers[2]]; + Tensor& output = ctx.pool.data[op->buffers[3]]; + + toGPU(ctx, inp, input); + toGPU(ctx, wte, wte_t); + toGPU(ctx, wpe, wpe_t); + std::promise promise; std::future future = promise.get_future(); - Kernel op = createKernel(ctx, {kShaderEncoder, 256, kf32}, - Bindings{input, wte_t, wpe_t, output}, - /* nWorkgroups */ {cdiv(b * t, 256), 1, 1}, - /* params */ - EncoderParams{ - static_cast(b), - static_cast(t), - static_cast(c) - }); dispatchKernel(ctx, op, promise); wait(ctx, future); toCPU(ctx, output, out, b * t * c * sizeof(float)); @@ -125,25 +145,39 @@ void ENCODER_BACKWARD_GPU(float* dwte, float* dwpe, uint32_t C; }; setLogLevel(kError); - WGPURequiredLimits requiredLimits = LIMITS_BUFFER_SIZE_1GB; - Context ctx = createContext({},{},{ - .requiredLimits = &requiredLimits - }); - Tensor dwte_t = createTensor(ctx, Shape{v, c}, kf32, dwte); - Tensor dwpe_t = createTensor(ctx, Shape{t, c}, kf32, dwpe); - Tensor dout_t = createTensor(ctx, Shape{b * t * c}, kf32, dout); - Tensor input = createTensor(ctx, Shape{b * t}, ki32, inp); + + // Generate the key of the cache by arguments. 
+ std::string key = "ENCODER_BACKWARD_GPU_" + std::to_string(B) + "_" + std::to_string(T) + "_" + std::to_string(C); + Kernel op; + if (ctx.kernelPool.data.find(key) == ctx.kernelPool.data.end()) { + Tensor dwte_t = createTensor(ctx, Shape{v, c}, kf32); + Tensor dwpe_t = createTensor(ctx, Shape{t, c}, kf32); + Tensor dout_t = createTensor(ctx, Shape{b * t * c}, kf32); + Tensor input = createTensor(ctx, Shape{b * t}, ki32); + op = createKernel(ctx, {kShaderEncoderBackward, 256, kf32}, + Bindings{dwte_t, dwpe_t, dout_t, input}, + /* nWorkgroups */ {cdiv(b * t, 256), 1, 1}, + /* params */ + EncoderParams{ + static_cast(b), + static_cast(t), + static_cast(c) + }, + nullptr, + key.c_str()); + } else { + op = ctx.kernelPool.data[key]; + } + Tensor& dwte_t = ctx.pool.data[op->buffers[0]]; + Tensor& dwpe_t = ctx.pool.data[op->buffers[1]]; + Tensor& dout_t = ctx.pool.data[op->buffers[2]]; + Tensor& input = ctx.pool.data[op->buffers[3]]; + + toGPU(ctx, dout, dout_t); + toGPU(ctx, inp, input); + std::promise promise; std::future future = promise.get_future(); - Kernel op = createKernel(ctx, {kShaderEncoderBackward, 256, kf32}, - Bindings{dwte_t, dwpe_t, dout_t, input}, - /* nWorkgroups */ {cdiv(b * t, 256), 1, 1}, - /* params */ - EncoderParams{ - static_cast(b), - static_cast(t), - static_cast(c) - }); dispatchKernel(ctx, op, promise); wait(ctx, future); toCPU(ctx, dwte_t, dwte, v * c * sizeof(float)); @@ -162,24 +196,44 @@ void LAYERNORM_FORWARD_GPU(float* out, float* mean, float* rstd, uint32_t C; }; setLogLevel(kError); - Context ctx = createContext(); - Tensor inp_t = createTensor(ctx, Shape{b * t * c}, kf32, inp); - Tensor weight_t = createTensor(ctx, Shape{c}, kf32, weight); - Tensor bias_t = createTensor(ctx, Shape{c}, kf32, bias); - Tensor out_t = createTensor(ctx, Shape{b * t * c}, kf32); - Tensor mean_t = createTensor(ctx, Shape{b * t}, kf32); - Tensor rstd_t = createTensor(ctx, Shape{b * t}, kf32); + + // Generate the key of the cache by arguments. 
+ std::string key = "LAYERNORM_FORWARD_GPU_" + std::to_string(B) + "_" + std::to_string(T) + "_" + std::to_string(C); + Kernel op; + if (ctx.kernelPool.data.find(key) == ctx.kernelPool.data.end()) { + Tensor inp_t = createTensor(ctx, Shape{b * t * c}, kf32); + Tensor weight_t = createTensor(ctx, Shape{c}, kf32); + Tensor bias_t = createTensor(ctx, Shape{c}, kf32); + Tensor out_t = createTensor(ctx, Shape{b * t * c}, kf32); + Tensor mean_t = createTensor(ctx, Shape{b * t}, kf32); + Tensor rstd_t = createTensor(ctx, Shape{b * t}, kf32); + op = createKernel(ctx, {kShaderLayerNorm, 256, kf32}, + Bindings{inp_t, weight_t, bias_t, out_t, mean_t, rstd_t}, + /* nWorkgroups */ {cdiv(b * t, 256), 1, 1}, + /* params */ + LayerNormParams{ + static_cast(b), + static_cast(t), + static_cast(c) + }, + nullptr, + key.c_str()); + } else { + op = ctx.kernelPool.data[key]; + } + Tensor& inp_t = ctx.pool.data[op->buffers[0]]; + Tensor& weight_t = ctx.pool.data[op->buffers[1]]; + Tensor& bias_t = ctx.pool.data[op->buffers[2]]; + Tensor& out_t = ctx.pool.data[op->buffers[3]]; + Tensor& mean_t = ctx.pool.data[op->buffers[4]]; + Tensor& rstd_t = ctx.pool.data[op->buffers[5]]; + + toGPU(ctx, inp, inp_t); + toGPU(ctx, weight, weight_t); + toGPU(ctx, bias, bias_t); + std::promise promise; std::future future = promise.get_future(); - Kernel op = createKernel(ctx, {kShaderLayerNorm, 256, kf32}, - Bindings{inp_t, weight_t, bias_t, out_t, mean_t, rstd_t}, - /* nWorkgroups */ {cdiv(b * t, 256), 1, 1}, - /* params */ - LayerNormParams{ - static_cast(b), - static_cast(t), - static_cast(c) - }); dispatchKernel(ctx, op, promise); wait(ctx, future); toCPU(ctx, out_t, out, b * t * c * sizeof(float)); @@ -199,26 +253,50 @@ void LAYERNORM_BACKWARD_GPU(float* dinp, float* dweight, float* dbias, uint32_t C; }; setLogLevel(kError); - Context ctx = createContext(); - Tensor dinp_t = createTensor(ctx, Shape{b * t * c}, kf32, dinp); - Tensor dweight_t = createTensor(ctx, Shape{c}, kf32, dweight); - Tensor dbias_t = createTensor(ctx, Shape{c}, kf32, dbias); - Tensor dout_t = createTensor(ctx, Shape{b * t * c}, kf32, dout); - Tensor inp_t = createTensor(ctx, Shape{b * t * c}, kf32, inp); - Tensor weight_t = createTensor(ctx, Shape{c}, kf32, weight); - Tensor mean_t = createTensor(ctx, Shape{b * t}, kf32, mean); - Tensor rstd_t = createTensor(ctx, Shape{b * t}, kf32, rstd); + + // Generate the key of the cache by arguments. 
+ std::string key = "LAYERNORM_BACKWARD_GPU_" + std::to_string(B) + "_" + std::to_string(T) + "_" + std::to_string(C); + Kernel op; + if (ctx.kernelPool.data.find(key) == ctx.kernelPool.data.end()) { + Tensor dinp_t = createTensor(ctx, Shape{b * t * c}, kf32); + Tensor dweight_t = createTensor(ctx, Shape{c}, kf32); + Tensor dbias_t = createTensor(ctx, Shape{c}, kf32); + Tensor dout_t = createTensor(ctx, Shape{b * t * c}, kf32); + Tensor inp_t = createTensor(ctx, Shape{b * t * c}, kf32); + Tensor weight_t = createTensor(ctx, Shape{c}, kf32); + Tensor mean_t = createTensor(ctx, Shape{b * t}, kf32); + Tensor rstd_t = createTensor(ctx, Shape{b * t}, kf32); + op = createKernel(ctx, {kShaderLayerNormBackward, 256, kf32}, + Bindings{dinp_t, dweight_t, dbias_t, dout_t, inp_t, weight_t, mean_t, rstd_t}, + /* nWorkgroups */ {cdiv(b * t, 256), 1, 1}, + /* params */ + LayerNormParams{ + static_cast(b), + static_cast(t), + static_cast(c) + }, + nullptr, + key.c_str()); + } else { + op = ctx.kernelPool.data[key]; + } + Tensor& dinp_t = ctx.pool.data[op->buffers[0]]; + Tensor& dweight_t = ctx.pool.data[op->buffers[1]]; + Tensor& dbias_t = ctx.pool.data[op->buffers[2]]; + Tensor& dout_t = ctx.pool.data[op->buffers[3]]; + Tensor& inp_t = ctx.pool.data[op->buffers[4]]; + Tensor& weight_t = ctx.pool.data[op->buffers[5]]; + Tensor& mean_t = ctx.pool.data[op->buffers[6]]; + Tensor& rstd_t = ctx.pool.data[op->buffers[7]]; + + toGPU(ctx, dout, dout_t); + toGPU(ctx, inp, inp_t); + toGPU(ctx, weight, weight_t); + toGPU(ctx, mean, mean_t); + toGPU(ctx, rstd, rstd_t); + std::promise promise; std::future future = promise.get_future(); - Kernel op = createKernel(ctx, {kShaderLayerNormBackward, 256, kf32}, - Bindings{dinp_t, dweight_t, dbias_t, dout_t, inp_t, weight_t, mean_t, rstd_t}, - /* nWorkgroups */ {cdiv(b * t, 256), 1, 1}, - /* params */ - LayerNormParams{ - static_cast(b), - static_cast(t), - static_cast(c) - }); dispatchKernel(ctx, op, promise); wait(ctx, future); toCPU(ctx, dinp_t, dinp, b * t * c * sizeof(float)); @@ -230,28 +308,24 @@ void matmul_forward_dummy(float* out, const float* inp, const float* weight, const float* bias, int B, int T, int C, int OC); -static WGPURequiredLimits requiredLimits = LIMITS_BUFFER_SIZE_1GB; -static Context ctx = createContext({},{},{ - .requiredLimits = &requiredLimits - }); -static constexpr size_t BT = 64; -static constexpr size_t BC = 16; -static constexpr size_t BOC = 64; -static constexpr size_t TT = BT / BC; -static constexpr size_t TOC = BOC / BC; -static constexpr size_t num_threads = BT * BOC / (TT * TOC); -static Shape wgSize = {num_threads, 1, 1}; - -static std::string codeString(kShaderMatmul2DTiling); -static std::string unrolledCode = loopUnrolling(replaceAll(codeString, {{"{{precision}}", toString(kf32)}, - {"{{BT}}", toString(BT)}, - {"{{BC}}", toString(BC)}, - {"{{BOC}}", toString(BOC)}, - {"{{TT}}", toString(TT)}, - {"{{TOC}}", toString(TOC)}, - {"{{NUM_TILEI}}", toString(BT * BC / num_threads)}, - {"{{NUM_TILEW}}", toString(BOC * BC / num_threads)} +static constexpr size_t MATMUL_BT = 64; +static constexpr size_t MATMUL_BC = 16; +static constexpr size_t MATMUL_BOC = 64; +static constexpr size_t MATMUL_TT = MATMUL_BT / MATMUL_BC; +static constexpr size_t MATMUL_TOC = MATMUL_BOC / MATMUL_BC; +static constexpr size_t MATMUL_num_threads = MATMUL_BT * MATMUL_BOC / (MATMUL_TT * MATMUL_TOC); +static Shape MATMUL_wgSize = {MATMUL_num_threads, 1, 1}; + +static std::string MATMUL_codeString(kShaderMatmul2DTiling); +static std::string MATMUL_unrolledCode = 
loopUnrolling(replaceAll(MATMUL_codeString, {{"{{precision}}", toString(kf32)}, + {"{{BT}}", toString(MATMUL_BT)}, + {"{{BC}}", toString(MATMUL_BC)}, + {"{{BOC}}", toString(MATMUL_BOC)}, + {"{{TT}}", toString(MATMUL_TT)}, + {"{{TOC}}", toString(MATMUL_TOC)}, + {"{{NUM_TILEI}}", toString(MATMUL_BT * MATMUL_BC / MATMUL_num_threads)}, + {"{{NUM_TILEW}}", toString(MATMUL_BOC * MATMUL_BC / MATMUL_num_threads)} })); @@ -285,15 +359,15 @@ void MATMUL_FORWARD_GPU(float* out, { DurationTime duration("matmul_forward_gpu: before creating tensors", verbose); // Generate the key of the cache by arguments. - std::string key = std::to_string(B) + "_" + std::to_string(T) + "_" + std::to_string(C) + "_" + std::to_string(OC); + std::string key = "MATMUL_FORWARD_GPU_" + std::to_string(B) + "_" + std::to_string(T) + "_" + std::to_string(C) + "_" + std::to_string(OC); Kernel op; if (ctx.kernelPool.data.find(key) == ctx.kernelPool.data.end()) { - Shape nWorkgroups = {b, cdiv(T, BT), cdiv(OC, BOC)}; + Shape nWorkgroups = {b, cdiv(T, MATMUL_BT), cdiv(OC, MATMUL_BOC)}; Tensor inp_i = createTensor(ctx, Shape{b * t * c}, kf32); Tensor weight_i = createTensor(ctx, Shape{oc * c}, kf32); Tensor bias_i = bias == NULL ? createTensor(ctx, Shape{1}, kf32) : createTensor(ctx, Shape{oc}, kf32); Tensor out_o = createTensor(ctx, Shape{b * t * oc}, kf32); - op = createKernel(ctx, {unrolledCode, wgSize, kf32}, + op = createKernel(ctx, {MATMUL_unrolledCode, MATMUL_wgSize, kf32}, Bindings{inp_i, weight_i, bias_i, out_o}, nWorkgroups, /* params */ @@ -411,28 +485,45 @@ void MATMUL_BACKWARD_GPU(float* dinp, float* dweight, float* dbias, unsigned long c = static_cast(C); unsigned long oc = static_cast(OC); setLogLevel(kError); - WGPURequiredLimits requiredLimits = LIMITS_BUFFER_SIZE_1GB; - Context ctx = createContext({},{},{ - .requiredLimits = &requiredLimits - }); - Tensor dinp_t = createTensor(ctx, Shape{b * t * c}, kf32, dinp); - Tensor dweight_t = createTensor(ctx, Shape{oc * c}, kf32, dweight); - Tensor dbias_t = createTensor(ctx, Shape{oc}, kf32, dbias); - Tensor dout_t = createTensor(ctx, Shape{b * t * oc}, kf32, dout); - Tensor inp_t = createTensor(ctx, Shape{b * t * c}, kf32, inp); - Tensor weight_t = createTensor(ctx, Shape{oc * c}, kf32, weight); + + // Generate the key of the cache by arguments. 
+ std::string key = "MATMUL_BACKWARD_GPU_" + std::to_string(B) + "_" + std::to_string(T) + "_" + std::to_string(C) + "_" + std::to_string(OC); + Kernel op; + if (ctx.kernelPool.data.find(key) == ctx.kernelPool.data.end()) { + Tensor dinp_t = createTensor(ctx, Shape{b * t * c}, kf32); + Tensor dweight_t = createTensor(ctx, Shape{oc * c}, kf32); + Tensor dbias_t = createTensor(ctx, Shape{oc}, kf32); + Tensor dout_t = createTensor(ctx, Shape{b * t * oc}, kf32); + Tensor inp_t = createTensor(ctx, Shape{b * t * c}, kf32); + Tensor weight_t = createTensor(ctx, Shape{oc * c}, kf32); + op = createKernel(ctx, {kShaderMatmulBackward, 256, kf32}, + Bindings{dinp_t, dweight_t, dbias_t, dout_t, inp_t, weight_t}, + /* nWorkgroups */ {cdiv(b * t, 256), 1, 1}, + /* params */ + MatmulParams{ + static_cast(b), + static_cast(t), + static_cast(c), + static_cast(oc) + }, + nullptr, + key.c_str()); + } else { + op = ctx.kernelPool.data[key]; + } + Tensor& dinp_t = ctx.pool.data[op->buffers[0]]; + Tensor& dweight_t = ctx.pool.data[op->buffers[1]]; + Tensor& dbias_t = ctx.pool.data[op->buffers[2]]; + Tensor& dout_t = ctx.pool.data[op->buffers[3]]; + Tensor& inp_t = ctx.pool.data[op->buffers[4]]; + Tensor& weight_t = ctx.pool.data[op->buffers[5]]; + + toGPU(ctx, dout, dout_t); + toGPU(ctx, inp, inp_t); + toGPU(ctx, weight, weight_t); + std::promise promise; std::future future = promise.get_future(); - Kernel op = createKernel(ctx, {kShaderMatmulBackward, 256, kf32}, - Bindings{dinp_t, dweight_t, dbias_t, dout_t, inp_t, weight_t}, - /* nWorkgroups */ {cdiv(b * t, 256), 1, 1}, - /* params */ - MatmulParams{ - static_cast(b), - static_cast(t), - static_cast(c), - static_cast(oc) - }); dispatchKernel(ctx, op, promise); wait(ctx, future); toCPU(ctx, dinp_t, dinp, b * t * c * sizeof(float)); @@ -454,23 +545,39 @@ void ATTENTION_FORWARD_GPU(float* out, float* preatt, float* att, unsigned long c = static_cast(C); unsigned long nh = static_cast(NH); setLogLevel(kError); - Context ctx = createContext(); - Tensor inp_t = createTensor(ctx, Shape{b * t * c * 3}, kf32, inp); - Tensor preatt_t = createTensor(ctx, Shape{b * nh * t * t}, kf32, preatt); - Tensor att_t = createTensor(ctx, Shape{b * nh * t * t}, kf32, att); - Tensor out_t = createTensor(ctx, Shape{b * t * c}, kf32); + + // Generate the key of the cache by arguments. 
+ std::string key = "ATTENTION_FORWARD_GPU_" + std::to_string(B) + "_" + std::to_string(T) + "_" + std::to_string(C) + "_" + std::to_string(NH); + Kernel op; + if (ctx.kernelPool.data.find(key) == ctx.kernelPool.data.end()) { + Tensor inp_t = createTensor(ctx, Shape{b * t * c * 3}, kf32); + Tensor preatt_t = createTensor(ctx, Shape{b * nh * t * t}, kf32); + Tensor att_t = createTensor(ctx, Shape{b * nh * t * t}, kf32); + Tensor out_t = createTensor(ctx, Shape{b * t * c}, kf32); + op = createKernel(ctx, {kShaderAttention, 256, kf32}, + Bindings{inp_t, preatt_t, att_t, out_t}, + /* nWorkgroups */ {cdiv(b * t, 256), 1, 1}, + /* params */ + AttentionParams{ + static_cast(b), + static_cast(t), + static_cast(c), + static_cast(nh) + }, + nullptr, + key.c_str()); + } else { + op = ctx.kernelPool.data[key]; + } + Tensor& inp_t = ctx.pool.data[op->buffers[0]]; + Tensor& preatt_t = ctx.pool.data[op->buffers[1]]; + Tensor& att_t = ctx.pool.data[op->buffers[2]]; + Tensor& out_t = ctx.pool.data[op->buffers[3]]; + + toGPU(ctx, inp, inp_t); + std::promise promise; std::future future = promise.get_future(); - Kernel op = createKernel(ctx, {kShaderAttention, 256, kf32}, - Bindings{inp_t, preatt_t, att_t, out_t}, - /* nWorkgroups */ {cdiv(b * t, 256), 1, 1}, - /* params */ - AttentionParams{ - static_cast(b), - static_cast(t), - static_cast(c), - static_cast(nh) - }); dispatchKernel(ctx, op, promise); wait(ctx, future); toCPU(ctx, preatt_t, preatt, b * nh * t * t * sizeof(float)); @@ -492,25 +599,45 @@ void ATTENTION_BACKWARD_GPU(float* dinp, float* dpreatt, float* datt, unsigned long c = static_cast(C); unsigned long nh = static_cast(NH); setLogLevel(kError); - Context ctx = createContext(); - Tensor dinp_t = createTensor(ctx, Shape{b * t * c * 3}, kf32, dinp); - Tensor dpreatt_t = createTensor(ctx, Shape{b * nh * t * t}, kf32, dpreatt); - Tensor datt_t = createTensor(ctx, Shape{b * nh * t * t}, kf32, datt); - Tensor dout_t = createTensor(ctx, Shape{b * t * c}, kf32, dout); - Tensor inp_t = createTensor(ctx, Shape{b * t * c * 3}, kf32, inp); - Tensor att_t = createTensor(ctx, Shape{b * nh * t * t}, kf32, att); + + // Generate the key of the cache by arguments. 
+ std::string key = "ATTENTION_BACKWARD_GPU_" + std::to_string(B) + "_" + std::to_string(T) + "_" + std::to_string(C) + "_" + std::to_string(NH); + Kernel op; + if (ctx.kernelPool.data.find(key) == ctx.kernelPool.data.end()) { + Tensor dinp_t = createTensor(ctx, Shape{b * t * c * 3}, kf32); + Tensor dpreatt_t = createTensor(ctx, Shape{b * nh * t * t}, kf32); + Tensor datt_t = createTensor(ctx, Shape{b * nh * t * t}, kf32); + Tensor dout_t = createTensor(ctx, Shape{b * t * c}, kf32); + Tensor inp_t = createTensor(ctx, Shape{b * t * c * 3}, kf32); + Tensor att_t = createTensor(ctx, Shape{b * nh * t * t}, kf32); + op = createKernel(ctx, {kShaderAttentionBackward, 256, kf32}, + Bindings{dinp_t, dpreatt_t, datt_t, dout_t, inp_t, att_t}, + /* nWorkgroups */ {cdiv(b * t, 256), 1, 1}, + /* params */ + AttentionParams{ + static_cast(b), + static_cast(t), + static_cast(c), + static_cast(nh) + }, + nullptr, + key.c_str()); + } else { + op = ctx.kernelPool.data[key]; + } + Tensor& dinp_t = ctx.pool.data[op->buffers[0]]; + Tensor& dpreatt_t = ctx.pool.data[op->buffers[1]]; + Tensor& datt_t = ctx.pool.data[op->buffers[2]]; + Tensor& dout_t = ctx.pool.data[op->buffers[3]]; + Tensor& inp_t = ctx.pool.data[op->buffers[4]]; + Tensor& att_t = ctx.pool.data[op->buffers[5]]; + + toGPU(ctx, dout, dout_t); + toGPU(ctx, inp, inp_t); + toGPU(ctx, att, att_t); + std::promise promise; std::future future = promise.get_future(); - Kernel op = createKernel(ctx, {kShaderAttentionBackward, 256, kf32}, - Bindings{dinp_t, dpreatt_t, datt_t, dout_t, inp_t, att_t}, - /* nWorkgroups */ {cdiv(b * t, 256), 1, 1}, - /* params */ - AttentionParams{ - static_cast(b), - static_cast(t), - static_cast(c), - static_cast(nh) - }); dispatchKernel(ctx, op, promise); wait(ctx, future); toCPU(ctx, dinp_t, dinp, b * t * c * 3 * sizeof(float)); @@ -521,14 +648,29 @@ void ATTENTION_BACKWARD_GPU(float* dinp, float* dpreatt, float* datt, void GELU_FORWARD_GPU(float* out, float* inp, int n) { unsigned long N = static_cast(n); setLogLevel(kError); - Context ctx = createContext(); - Tensor input = createTensor(ctx, Shape{N}, kf32, inp); - Tensor output = createTensor(ctx, Shape{N}, kf32); + + // Generate the key of the cache by arguments. + std::string key = "GELU_FORWARD_GPU_" + std::to_string(n); + Kernel op; + if (ctx.kernelPool.data.find(key) == ctx.kernelPool.data.end()) { + Tensor input = createTensor(ctx, Shape{N}, kf32); + Tensor output = createTensor(ctx, Shape{N}, kf32); + op = createKernel(ctx, {kShaderGelu, 256, kf32}, + Bindings{input, output}, + /* nWorkgroups */ {cdiv(N, 256), 1, 1}, + nullptr, + nullptr, + key.c_str()); + } else { + op = ctx.kernelPool.data[key]; + } + Tensor& input = ctx.pool.data[op->buffers[0]]; + Tensor& output = ctx.pool.data[op->buffers[1]]; + + toGPU(ctx, inp, input); + std::promise promise; std::future future = promise.get_future(); - Kernel op = createKernel(ctx, {kShaderGelu, 256, kf32}, - Bindings{input, output}, - /* nWorkgroups */ {cdiv(N, 256), 1, 1}); dispatchKernel(ctx, op, promise); wait(ctx, future); toCPU(ctx, output, out, N * sizeof(float)); @@ -537,15 +679,32 @@ void GELU_FORWARD_GPU(float* out, float* inp, int n) { void GELU_BACKWARD_GPU(float* dinp, float* inp, float* dout, int N){ unsigned long n = static_cast(N); setLogLevel(kError); - Context ctx = createContext(); - Tensor inp_i = createTensor(ctx, Shape{n}, kf32, inp); - Tensor dout_i = createTensor(ctx, Shape{n}, kf32, dout); - Tensor dinp_o = createTensor(ctx, Shape{n}, kf32, dinp); + + // Generate the key of the cache by arguments. 
+ std::string key = "GELU_BACKWARD_GPU_" + std::to_string(N); + Kernel op; + if (ctx.kernelPool.data.find(key) == ctx.kernelPool.data.end()) { + Tensor inp_i = createTensor(ctx, Shape{n}, kf32); + Tensor dout_i = createTensor(ctx, Shape{n}, kf32); + Tensor dinp_o = createTensor(ctx, Shape{n}, kf32); + op = createKernel(ctx, {kShaderGeluBackward, 256, kf32}, + Bindings{inp_i, dout_i, dinp_o}, + /* nWorkgroups */ {cdiv(n, 256), 1, 1}, + nullptr, + nullptr, + key.c_str()); + } else { + op = ctx.kernelPool.data[key]; + } + Tensor& inp_i = ctx.pool.data[op->buffers[0]]; + Tensor& dout_i = ctx.pool.data[op->buffers[1]]; + Tensor& dinp_o = ctx.pool.data[op->buffers[2]]; + + toGPU(ctx, inp, inp_i); + toGPU(ctx, dout, dout_i); + std::promise promise; std::future future = promise.get_future(); - Kernel op = createKernel(ctx, {kShaderGeluBackward, 256, kf32}, - Bindings{inp_i, dout_i, dinp_o}, - /* nWorkgroups */ {cdiv(n, 256), 1, 1}); dispatchKernel(ctx, op, promise); wait(ctx, future); toCPU(ctx, dinp_o, dinp, n * sizeof(float)); @@ -554,15 +713,32 @@ void GELU_BACKWARD_GPU(float* dinp, float* inp, float* dout, int N){ void RESIDUAL_FORWARD_GPU(float* out, float* inp1, float* inp2, int N){ unsigned long n = static_cast(N); setLogLevel(kError); - Context ctx = createContext(); - Tensor inp1_i = createTensor(ctx, Shape{n}, kf32, inp1); - Tensor inp2_i = createTensor(ctx, Shape{n}, kf32, inp2); - Tensor out_o = createTensor(ctx, Shape{n}, kf32); + + // Generate the key of the cache by arguments. + std::string key = "RESIDUAL_FORWARD_GPU_" + std::to_string(N); + Kernel op; + if (ctx.kernelPool.data.find(key) == ctx.kernelPool.data.end()) { + Tensor inp1_i = createTensor(ctx, Shape{n}, kf32); + Tensor inp2_i = createTensor(ctx, Shape{n}, kf32); + Tensor out_o = createTensor(ctx, Shape{n}, kf32); + op = createKernel(ctx, {kShaderResidual, 256, kf32}, + Bindings{inp1_i, inp2_i, out_o}, + /* nWorkgroups */ {cdiv(n, 256), 1, 1}, + nullptr, + nullptr, + key.c_str()); + } else { + op = ctx.kernelPool.data[key]; + } + Tensor& inp1_i = ctx.pool.data[op->buffers[0]]; + Tensor& inp2_i = ctx.pool.data[op->buffers[1]]; + Tensor& out_o = ctx.pool.data[op->buffers[2]]; + + toGPU(ctx, inp1, inp1_i); + toGPU(ctx, inp2, inp2_i); + std::promise promise; std::future future = promise.get_future(); - Kernel op = createKernel(ctx, {kShaderResidual, 256, kf32}, - Bindings{inp1_i, inp2_i, out_o}, - /* nWorkgroups */ {cdiv(n, 256), 1, 1}); dispatchKernel(ctx, op, promise); wait(ctx, future); toCPU(ctx, out_o, out, n * sizeof(float)); @@ -571,15 +747,31 @@ void RESIDUAL_FORWARD_GPU(float* out, float* inp1, float* inp2, int N){ void RESIDUAL_BACKWARD_GPU(float* dinp1, float* dinp2, float* dout, int N){ unsigned long n = static_cast(N); setLogLevel(kError); - Context ctx = createContext(); - Tensor dout_i = createTensor(ctx, Shape{n}, kf32, dout); - Tensor dinp1_o = createTensor(ctx, Shape{n}, kf32, dinp1); - Tensor dinp2_o = createTensor(ctx, Shape{n}, kf32, dinp2); + + // Generate the key of the cache by arguments. 
+ std::string key = "RESIDUAL_BACKWARD_GPU_" + std::to_string(N); + Kernel op; + if (ctx.kernelPool.data.find(key) == ctx.kernelPool.data.end()) { + Tensor dout_i = createTensor(ctx, Shape{n}, kf32); + Tensor dinp1_o = createTensor(ctx, Shape{n}, kf32); + Tensor dinp2_o = createTensor(ctx, Shape{n}, kf32); + op = createKernel(ctx, {kShaderResidualBackward, 256, kf32}, + Bindings{dout_i, dinp1_o, dinp2_o}, + /* nWorkgroups */ {cdiv(n, 256), 1, 1}, + nullptr, + nullptr, + key.c_str()); + } else { + op = ctx.kernelPool.data[key]; + } + Tensor& dout_i = ctx.pool.data[op->buffers[0]]; + Tensor& dinp1_o = ctx.pool.data[op->buffers[1]]; + Tensor& dinp2_o = ctx.pool.data[op->buffers[2]]; + + toGPU(ctx, dout, dout_i); + std::promise promise; std::future future = promise.get_future(); - Kernel op = createKernel(ctx, {kShaderResidualBackward, 256, kf32}, - Bindings{dout_i, dinp1_o, dinp2_o}, - /* nWorkgroups */ {cdiv(n, 256), 1, 1}); dispatchKernel(ctx, op, promise); wait(ctx, future); toCPU(ctx, dinp1_o, dinp1, n * sizeof(float)); @@ -596,15 +788,29 @@ void SOFTMAX_FORWARD_GPU(float* probs, float* logits, int B, int T, int V, int V uint32_t t = static_cast(T); uint32_t c = static_cast(V); uint32_t cp = static_cast(Vp); - Context ctx = createContext(); - Tensor input = createTensor(ctx, {b * t, cp}, kf32, logits); - Tensor output = createTensor(ctx, {b * t, cp}, kf32); + + // Generate the key of the cache by arguments. + std::string key = "SOFTMAX_FORWARD_GPU_" + std::to_string(B) + "_" + std::to_string(T) + "_" + std::to_string(V) + "_" + std::to_string(Vp); + Kernel op; + if (ctx.kernelPool.data.find(key) == ctx.kernelPool.data.end()) { + Tensor input = createTensor(ctx, {b * t, cp}, kf32); + Tensor output = createTensor(ctx, {b * t, cp}, kf32); + assert( (B*T) % 256 == 0); + op = createKernel( + ctx, {kShaderSoftmax1, 256, kf32}, Bindings{input, output}, + Shape{cdiv(B * T, 256), 1, 1}, SoftmaxParam{b * t, c, cp}, + nullptr, + key.c_str()); + } else { + op = ctx.kernelPool.data[key]; + } + Tensor& input = ctx.pool.data[op->buffers[0]]; + Tensor& output = ctx.pool.data[op->buffers[1]]; + + toGPU(ctx, logits, input); + std::promise promise; std::future future = promise.get_future(); - assert( (B*T) % 256 == 0); - Kernel op = createKernel( - ctx, {kShaderSoftmax1, 256, kf32}, Bindings{input, output}, - Shape{cdiv(B * T, 256), 1, 1}, SoftmaxParam{b * t, c, cp}); dispatchKernel(ctx, op, promise); wait(ctx, future); toCPU(ctx, output, probs, sizeof(float)*b*t*cp); @@ -622,21 +828,37 @@ void CROSSENTROPY_FORWARD_GPU(float* losses, unsigned long t = static_cast(T); unsigned long vp = static_cast(Vp); setLogLevel(kError); - Context ctx = createContext(); - Tensor losses_t = createTensor(ctx, Shape{b * t}, kf32, losses); - Tensor probs_t = createTensor(ctx, Shape{b * t * vp}, kf32, probs); - Tensor targets_t = createTensor(ctx, Shape{b * t}, ki32, targets); + + // Generate the key of the cache by arguments. 
+ std::string key = "CROSSENTROPY_FORWARD_GPU_" + std::to_string(B) + "_" + std::to_string(T) + "_" + std::to_string(Vp); + Kernel op; + if (ctx.kernelPool.data.find(key) == ctx.kernelPool.data.end()) { + Tensor losses_t = createTensor(ctx, Shape{b * t}, kf32); + Tensor probs_t = createTensor(ctx, Shape{b * t * vp}, kf32); + Tensor targets_t = createTensor(ctx, Shape{b * t}, ki32); + op = createKernel(ctx, {kShaderCrossEntropyForward, 256, kf32}, + Bindings{losses_t, probs_t, targets_t}, + /* nWorkgroups */ {cdiv(b * t, 256), 1, 1}, + /* params */ + CrossEntropyParams{ + static_cast(b), + static_cast(t), + static_cast(vp) + }, + nullptr, + key.c_str()); + } else { + op = ctx.kernelPool.data[key]; + } + Tensor& losses_t = ctx.pool.data[op->buffers[0]]; + Tensor& probs_t = ctx.pool.data[op->buffers[1]]; + Tensor& targets_t = ctx.pool.data[op->buffers[2]]; + + toGPU(ctx, probs, probs_t); + toGPU(ctx, targets, targets_t); + std::promise promise; std::future future = promise.get_future(); - Kernel op = createKernel(ctx, {kShaderCrossEntropyForward, 256, kf32}, - Bindings{losses_t, probs_t, targets_t}, - /* nWorkgroups */ {cdiv(b * t, 256), 1, 1}, - /* params */ - CrossEntropyParams{ - static_cast(b), - static_cast(t), - static_cast(vp) - }); dispatchKernel(ctx, op, promise); wait(ctx, future); toCPU(ctx, losses_t, losses, b * t * sizeof(float)); @@ -656,23 +878,41 @@ void CROSSENTROPY_SOFTMAX_BACKWARD_GPU(float* dlogits, unsigned long v = static_cast(V); unsigned long vp = static_cast(Vp); setLogLevel(kError); - Context ctx = createContext(); - Tensor dlogits_t = createTensor(ctx, Shape{b * t * vp}, kf32, dlogits); - Tensor dlosses_t = createTensor(ctx, Shape{b * t}, kf32, dlosses); - Tensor probs_t = createTensor(ctx, Shape{b * t * vp}, kf32, probs); - Tensor targets_t = createTensor(ctx, Shape{b * t}, ki32, targets); + + // Generate the key of the cache by arguments. 
+ std::string key = "CROSSENTROPY_SOFTMAX_BACKWARD_GPU_" + std::to_string(B) + "_" + std::to_string(T) + "_" + std::to_string(V) + "_" + std::to_string(Vp); + Kernel op; + if (ctx.kernelPool.data.find(key) == ctx.kernelPool.data.end()) { + Tensor dlogits_t = createTensor(ctx, Shape{b * t * vp}, kf32); + Tensor dlosses_t = createTensor(ctx, Shape{b * t}, kf32); + Tensor probs_t = createTensor(ctx, Shape{b * t * vp}, kf32); + Tensor targets_t = createTensor(ctx, Shape{b * t}, ki32); + op = createKernel(ctx, {kShaderCrossEntropySoftmaxBackward, 256, kf32}, + Bindings{dlogits_t, dlosses_t, probs_t, targets_t}, + /* nWorkgroups */ {cdiv(b * t, 256), 1, 1}, + /* params */ + CrossEntropySoftmaxBackwardParams{ + static_cast(b), + static_cast(t), + static_cast(v), + static_cast(vp) + }, + nullptr, + key.c_str()); + } else { + op = ctx.kernelPool.data[key]; + } + Tensor& dlogits_t = ctx.pool.data[op->buffers[0]]; + Tensor& dlosses_t = ctx.pool.data[op->buffers[1]]; + Tensor& probs_t = ctx.pool.data[op->buffers[2]]; + Tensor& targets_t = ctx.pool.data[op->buffers[3]]; + + toGPU(ctx, dlosses, dlosses_t); + toGPU(ctx, probs, probs_t); + toGPU(ctx, targets, targets_t); + std::promise promise; std::future future = promise.get_future(); - Kernel op = createKernel(ctx, {kShaderCrossEntropySoftmaxBackward, 256, kf32}, - Bindings{dlogits_t, dlosses_t, probs_t, targets_t}, - /* nWorkgroups */ {cdiv(b * t, 256), 1, 1}, - /* params */ - CrossEntropySoftmaxBackwardParams{ - static_cast(b), - static_cast(t), - static_cast(v), - static_cast(vp) - }); dispatchKernel(ctx, op, promise); wait(ctx, future); toCPU(ctx, dlogits_t, dlogits, b * t * vp * sizeof(float)); From efb87eebce70acaf5dcc5b9a2a0d67339e719142 Mon Sep 17 00:00:00 2001 From: Junji Hashimoto Date: Wed, 16 Oct 2024 19:59:28 +0900 Subject: [PATCH 06/13] Fix bugs --- experimental/kernels/ops.cpp | 18 +++++++++++++++++ .../unittest_llmc/unittest_kernels.cpp | 20 ++++++++++++++++++- 2 files changed, 37 insertions(+), 1 deletion(-) diff --git a/experimental/kernels/ops.cpp b/experimental/kernels/ops.cpp index 6390314..3d5f528 100644 --- a/experimental/kernels/ops.cpp +++ b/experimental/kernels/ops.cpp @@ -101,6 +101,8 @@ void encoder_backward(Context& ctx, float* dwte, float* dwpe, Tensor& dout_t = ctx.pool.data[op->buffers[2]]; Tensor& input = ctx.pool.data[op->buffers[3]]; + toGPU(ctx, dwte, dwte_t); + toGPU(ctx, dwpe, dwpe_t); toGPU(ctx, dout, dout_t); toGPU(ctx, inp, input); @@ -215,6 +217,9 @@ void layernorm_backward(Context& ctx, float* dinp, float* dweight, float* dbias, Tensor& mean_t = ctx.pool.data[op->buffers[6]]; Tensor& rstd_t = ctx.pool.data[op->buffers[7]]; + toGPU(ctx, dinp, dinp_t); + toGPU(ctx, dweight, dweight_t); + toGPU(ctx, dbias, dbias_t); toGPU(ctx, dout, dout_t); toGPU(ctx, inp, inp_t); toGPU(ctx, weight, weight_t); @@ -387,6 +392,9 @@ void matmul_backward(Context& ctx, float* dinp, float* dweight, float* dbias, Tensor& inp_t = ctx.pool.data[op->buffers[4]]; Tensor& weight_t = ctx.pool.data[op->buffers[5]]; + toGPU(ctx, dinp, dinp_t); + toGPU(ctx, dweight, dweight_t); + toGPU(ctx, dbias, dbias_t); toGPU(ctx, dout, dout_t); toGPU(ctx, inp, inp_t); toGPU(ctx, weight, weight_t); @@ -443,6 +451,8 @@ void attention_forward(Context& ctx, float* out, float* preatt, float* att, Tensor& out_t = ctx.pool.data[op->buffers[3]]; toGPU(ctx, inp, inp_t); + toGPU(ctx, preatt, preatt_t); + toGPU(ctx, att, att_t); std::promise promise; std::future future = promise.get_future(); @@ -499,6 +509,9 @@ void attention_backward(Context& ctx, float* dinp, 
float* dpreatt, float* datt, Tensor& inp_t = ctx.pool.data[op->buffers[4]]; Tensor& att_t = ctx.pool.data[op->buffers[5]]; + toGPU(ctx, dinp, dinp_t); + toGPU(ctx, dpreatt, dpreatt_t); + toGPU(ctx, datt, datt_t); toGPU(ctx, dout, dout_t); toGPU(ctx, inp, inp_t); toGPU(ctx, att, att_t); @@ -567,6 +580,7 @@ void gelu_backward(Context& ctx, float* dinp, float* inp, float* dout, int N){ toGPU(ctx, inp, inp_i); toGPU(ctx, dout, dout_i); + toGPU(ctx, dinp, dinp_o); std::promise promise; std::future future = promise.get_future(); @@ -632,6 +646,8 @@ void residual_backward(Context& ctx, float* dinp1, float* dinp2, float* dout, in Tensor& dinp2_o = ctx.pool.data[op->buffers[2]]; toGPU(ctx, dout, dout_i); + toGPU(ctx, dinp1, dinp1_o); + toGPU(ctx, dinp2, dinp2_o); std::promise promise; std::future future = promise.get_future(); @@ -715,6 +731,7 @@ void crossentropy_forward(Context& ctx, float* losses, Tensor& probs_t = ctx.pool.data[op->buffers[1]]; Tensor& targets_t = ctx.pool.data[op->buffers[2]]; + toGPU(ctx, losses, losses_t); toGPU(ctx, probs, probs_t); toGPU(ctx, targets, targets_t); @@ -767,6 +784,7 @@ void crossentropy_softmax_backward(Context& ctx, float* dlogits, Tensor& probs_t = ctx.pool.data[op->buffers[2]]; Tensor& targets_t = ctx.pool.data[op->buffers[3]]; + toGPU(ctx, dlogits, dlogits_t); toGPU(ctx, dlosses, dlosses_t); toGPU(ctx, probs, probs_t); toGPU(ctx, targets, targets_t); diff --git a/experimental/kernels/unittest_llmc/unittest_kernels.cpp b/experimental/kernels/unittest_llmc/unittest_kernels.cpp index a883124..a455780 100644 --- a/experimental/kernels/unittest_llmc/unittest_kernels.cpp +++ b/experimental/kernels/unittest_llmc/unittest_kernels.cpp @@ -173,6 +173,8 @@ void ENCODER_BACKWARD_GPU(float* dwte, float* dwpe, Tensor& dout_t = ctx.pool.data[op->buffers[2]]; Tensor& input = ctx.pool.data[op->buffers[3]]; + toGPU(ctx, dwte, dwte_t); + toGPU(ctx, dwpe, dwpe_t); toGPU(ctx, dout, dout_t); toGPU(ctx, inp, input); @@ -289,6 +291,9 @@ void LAYERNORM_BACKWARD_GPU(float* dinp, float* dweight, float* dbias, Tensor& mean_t = ctx.pool.data[op->buffers[6]]; Tensor& rstd_t = ctx.pool.data[op->buffers[7]]; + toGPU(ctx, dinp, dinp_t); + toGPU(ctx, dweight, dweight_t); + toGPU(ctx, dbias, dbias_t); toGPU(ctx, dout, dout_t); toGPU(ctx, inp, inp_t); toGPU(ctx, weight, weight_t); @@ -518,6 +523,9 @@ void MATMUL_BACKWARD_GPU(float* dinp, float* dweight, float* dbias, Tensor& inp_t = ctx.pool.data[op->buffers[4]]; Tensor& weight_t = ctx.pool.data[op->buffers[5]]; + toGPU(ctx, dinp, dinp_t); + toGPU(ctx, dweight, dweight_t); + toGPU(ctx, dbias, dbias_t); toGPU(ctx, dout, dout_t); toGPU(ctx, inp, inp_t); toGPU(ctx, weight, weight_t); @@ -575,6 +583,8 @@ void ATTENTION_FORWARD_GPU(float* out, float* preatt, float* att, Tensor& out_t = ctx.pool.data[op->buffers[3]]; toGPU(ctx, inp, inp_t); + toGPU(ctx, preatt, preatt_t); + toGPU(ctx, att, att_t); std::promise promise; std::future future = promise.get_future(); @@ -632,10 +642,13 @@ void ATTENTION_BACKWARD_GPU(float* dinp, float* dpreatt, float* datt, Tensor& inp_t = ctx.pool.data[op->buffers[4]]; Tensor& att_t = ctx.pool.data[op->buffers[5]]; + toGPU(ctx, dinp, dinp_t); + toGPU(ctx, dpreatt, dpreatt_t); + toGPU(ctx, datt, datt_t); toGPU(ctx, dout, dout_t); toGPU(ctx, inp, inp_t); toGPU(ctx, att, att_t); - + std::promise promise; std::future future = promise.get_future(); dispatchKernel(ctx, op, promise); @@ -702,6 +715,7 @@ void GELU_BACKWARD_GPU(float* dinp, float* inp, float* dout, int N){ toGPU(ctx, inp, inp_i); toGPU(ctx, dout, dout_i); + 
toGPU(ctx, dinp, dinp_o); std::promise promise; std::future future = promise.get_future(); @@ -769,6 +783,8 @@ void RESIDUAL_BACKWARD_GPU(float* dinp1, float* dinp2, float* dout, int N){ Tensor& dinp2_o = ctx.pool.data[op->buffers[2]]; toGPU(ctx, dout, dout_i); + toGPU(ctx, dinp1, dinp1_o); + toGPU(ctx, dinp2, dinp2_o); std::promise promise; std::future future = promise.get_future(); @@ -854,6 +870,7 @@ void CROSSENTROPY_FORWARD_GPU(float* losses, Tensor& probs_t = ctx.pool.data[op->buffers[1]]; Tensor& targets_t = ctx.pool.data[op->buffers[2]]; + toGPU(ctx, losses, losses_t); toGPU(ctx, probs, probs_t); toGPU(ctx, targets, targets_t); @@ -907,6 +924,7 @@ void CROSSENTROPY_SOFTMAX_BACKWARD_GPU(float* dlogits, Tensor& probs_t = ctx.pool.data[op->buffers[2]]; Tensor& targets_t = ctx.pool.data[op->buffers[3]]; + toGPU(ctx, dlogits, dlogits_t); toGPU(ctx, dlosses, dlosses_t); toGPU(ctx, probs, probs_t); toGPU(ctx, targets, targets_t); From c474fba4560a942a0a21c1323d83b5799c277fad Mon Sep 17 00:00:00 2001 From: Junji Hashimoto Date: Wed, 16 Oct 2024 20:07:16 +0900 Subject: [PATCH 07/13] Remove global variables of kernels --- experimental/kernels/ops.cpp | 50 ++++++++++--------- .../unittest_llmc/unittest_kernels.cpp | 43 ++++++++-------- 2 files changed, 47 insertions(+), 46 deletions(-) diff --git a/experimental/kernels/ops.cpp b/experimental/kernels/ops.cpp index 3d5f528..0e9c076 100644 --- a/experimental/kernels/ops.cpp +++ b/experimental/kernels/ops.cpp @@ -235,28 +235,6 @@ void layernorm_backward(Context& ctx, float* dinp, float* dweight, float* dbias, toCPU(ctx, dbias_t, dbias, c * sizeof(float)); } -static constexpr size_t MATMUL_BT = 64; -static constexpr size_t MATMUL_BC = 8; -static constexpr size_t MATMUL_BOC = 64; -static constexpr size_t MATMUL_TT = MATMUL_BT / MATMUL_BC; -static constexpr size_t MATMUL_TOC = MATMUL_BOC / MATMUL_BC; -static size_t MATMUL_num_threads = MATMUL_BT * MATMUL_BOC / (MATMUL_TT * MATMUL_TOC); -static Shape MATMUL_wgSize = {MATMUL_num_threads, 1, 1}; -static std::string kShaderMatmul2DTiling_(kShaderMatmul2DTiling); -static std::string kShaderMatmul2D(loopUnrolling( - replaceAll(kShaderMatmul2DTiling_, - {{"{{precision}}", toString(kf32)}, - {"{{BT}}", toString(MATMUL_BT)}, - {"{{BC}}", toString(MATMUL_BC)}, - {"{{BOC}}", toString(MATMUL_BOC)}, - {"{{TT}}", toString(MATMUL_TT)}, - {"{{TOC}}", toString(MATMUL_TOC)}, - {"{{NUM_TILEI}}", toString(MATMUL_BT * MATMUL_BC / MATMUL_num_threads)}, - {"{{NUM_TILEW}}", toString(MATMUL_BOC * MATMUL_BC / MATMUL_num_threads)} - }) - ) - ); - struct DurationTime { std::chrono::high_resolution_clock::time_point start; std::chrono::high_resolution_clock::time_point end; @@ -301,12 +279,36 @@ void matmul_forward(Context& ctx, float* out, std::string key = "matmul_forward_" + std::to_string(B) + "_" + std::to_string(T) + "_" + std::to_string(C) + "_" + std::to_string(OC); Kernel op; if (ctx.kernelPool.data.find(key) == ctx.kernelPool.data.end()) { - Shape nWorkgroups = {b, cdiv(T, MATMUL_BT), cdiv(OC, MATMUL_BOC)}; + constexpr size_t BT = 64; + constexpr size_t BC = 8; + constexpr size_t BOC = 64; + constexpr size_t TT = BT / BC; + constexpr size_t TOC = BOC / BC; + size_t num_threads = BT * BOC / (TT * TOC); + Shape wgSize = {num_threads, 1, 1}; + Shape nWorkgroups = {b, cdiv(T, BT), cdiv(OC, BOC)}; + + std::string kShaderMatmul2DTiling_(kShaderMatmul2DTiling); + std::string kShaderMatmul2D(loopUnrolling( + replaceAll(kShaderMatmul2DTiling_, + {{"{{precision}}", toString(kf32)}, + {"{{BT}}", toString(BT)}, + {"{{BC}}", 
toString(BC)}, + {"{{BOC}}", toString(BOC)}, + {"{{TT}}", toString(TT)}, + {"{{TOC}}", toString(TOC)}, + {"{{NUM_TILEI}}", toString(BT * BC / num_threads)}, + {"{{NUM_TILEW}}", toString(BOC * BC / num_threads)} + }) + ) + ); + Tensor inp_i = createTensor(ctx, Shape{b * t * c}, kf32); Tensor weight_i = createTensor(ctx, Shape{oc * c}, kf32); Tensor bias_i = bias == NULL ? createTensor(ctx, Shape{1}, kf32) : createTensor(ctx, Shape{oc}, kf32); Tensor out_o = createTensor(ctx, Shape{b * t * oc}, kf32); - op = createKernel(ctx, {kShaderMatmul2D, MATMUL_wgSize, kf32}, + + op = createKernel(ctx, {kShaderMatmul2D, wgSize, kf32}, Bindings{inp_i, weight_i, bias_i, out_o}, nWorkgroups, /* params */ diff --git a/experimental/kernels/unittest_llmc/unittest_kernels.cpp b/experimental/kernels/unittest_llmc/unittest_kernels.cpp index a455780..d58432e 100644 --- a/experimental/kernels/unittest_llmc/unittest_kernels.cpp +++ b/experimental/kernels/unittest_llmc/unittest_kernels.cpp @@ -314,26 +314,6 @@ void matmul_forward_dummy(float* out, int B, int T, int C, int OC); -static constexpr size_t MATMUL_BT = 64; -static constexpr size_t MATMUL_BC = 16; -static constexpr size_t MATMUL_BOC = 64; -static constexpr size_t MATMUL_TT = MATMUL_BT / MATMUL_BC; -static constexpr size_t MATMUL_TOC = MATMUL_BOC / MATMUL_BC; -static constexpr size_t MATMUL_num_threads = MATMUL_BT * MATMUL_BOC / (MATMUL_TT * MATMUL_TOC); -static Shape MATMUL_wgSize = {MATMUL_num_threads, 1, 1}; - -static std::string MATMUL_codeString(kShaderMatmul2DTiling); -static std::string MATMUL_unrolledCode = loopUnrolling(replaceAll(MATMUL_codeString, {{"{{precision}}", toString(kf32)}, - {"{{BT}}", toString(MATMUL_BT)}, - {"{{BC}}", toString(MATMUL_BC)}, - {"{{BOC}}", toString(MATMUL_BOC)}, - {"{{TT}}", toString(MATMUL_TT)}, - {"{{TOC}}", toString(MATMUL_TOC)}, - {"{{NUM_TILEI}}", toString(MATMUL_BT * MATMUL_BC / MATMUL_num_threads)}, - {"{{NUM_TILEW}}", toString(MATMUL_BOC * MATMUL_BC / MATMUL_num_threads)} - })); - - void MATMUL_FORWARD_GPU(float* out, const float* inp, const float* weight, const float* bias, int B, int T, int C, int OC){ @@ -367,12 +347,31 @@ void MATMUL_FORWARD_GPU(float* out, std::string key = "MATMUL_FORWARD_GPU_" + std::to_string(B) + "_" + std::to_string(T) + "_" + std::to_string(C) + "_" + std::to_string(OC); Kernel op; if (ctx.kernelPool.data.find(key) == ctx.kernelPool.data.end()) { - Shape nWorkgroups = {b, cdiv(T, MATMUL_BT), cdiv(OC, MATMUL_BOC)}; + constexpr size_t BT = 64; + constexpr size_t BC = 16; + constexpr size_t BOC = 64; + constexpr size_t TT = BT / BC; + constexpr size_t TOC = BOC / BC; + constexpr size_t num_threads = BT * BOC / (TT * TOC); + Shape wgSize = {num_threads, 1, 1}; + + std::string codeString(kShaderMatmul2DTiling); + std::string unrolledCode = loopUnrolling(replaceAll(codeString, {{"{{precision}}", toString(kf32)}, + {"{{BT}}", toString(BT)}, + {"{{BC}}", toString(BC)}, + {"{{BOC}}", toString(BOC)}, + {"{{TT}}", toString(TT)}, + {"{{TOC}}", toString(TOC)}, + {"{{NUM_TILEI}}", toString(BT * BC / num_threads)}, + {"{{NUM_TILEW}}", toString(BOC * BC / num_threads)} + })); + + Shape nWorkgroups = {b, cdiv(T, BT), cdiv(OC, BOC)}; Tensor inp_i = createTensor(ctx, Shape{b * t * c}, kf32); Tensor weight_i = createTensor(ctx, Shape{oc * c}, kf32); Tensor bias_i = bias == NULL ? 
createTensor(ctx, Shape{1}, kf32) : createTensor(ctx, Shape{oc}, kf32); Tensor out_o = createTensor(ctx, Shape{b * t * oc}, kf32); - op = createKernel(ctx, {MATMUL_unrolledCode, MATMUL_wgSize, kf32}, + op = createKernel(ctx, {unrolledCode, wgSize, kf32}, Bindings{inp_i, weight_i, bias_i, out_o}, nWorkgroups, /* params */ From 438be22df671de669ebeae58698e3560ab5c7b17 Mon Sep 17 00:00:00 2001 From: Junji Hashimoto Date: Wed, 16 Oct 2024 21:09:44 +0900 Subject: [PATCH 08/13] Fix the matmul of version 1 in unittests --- .../unittest_llmc/unittest_kernels.cpp | 129 +++++++++--------- 1 file changed, 61 insertions(+), 68 deletions(-) diff --git a/experimental/kernels/unittest_llmc/unittest_kernels.cpp b/experimental/kernels/unittest_llmc/unittest_kernels.cpp index d58432e..311d895 100644 --- a/experimental/kernels/unittest_llmc/unittest_kernels.cpp +++ b/experimental/kernels/unittest_llmc/unittest_kernels.cpp @@ -321,7 +321,7 @@ void MATMUL_FORWARD_GPU(float* out, bool verbose = false; bool debug = false; float *out_exp; - DurationTime duration("matmul_forward_gpu with creating context", verbose); + DurationTime duration("matmul_forward_gpu with preparing a kernel", verbose); if (verbose) { printf("matmul forward: B=%d, T=%d, C=%d, OC=%d, bias=%d\n", B, T, C, OC, bias != NULL); } @@ -341,49 +341,65 @@ void MATMUL_FORWARD_GPU(float* out, unsigned long oc = static_cast(OC); setLogLevel(kError); - { - DurationTime duration("matmul_forward_gpu: before creating tensors", verbose); + if (version == 2 || version == 1) { // Generate the key of the cache by arguments. std::string key = "MATMUL_FORWARD_GPU_" + std::to_string(B) + "_" + std::to_string(T) + "_" + std::to_string(C) + "_" + std::to_string(OC); Kernel op; if (ctx.kernelPool.data.find(key) == ctx.kernelPool.data.end()) { - constexpr size_t BT = 64; - constexpr size_t BC = 16; - constexpr size_t BOC = 64; - constexpr size_t TT = BT / BC; - constexpr size_t TOC = BOC / BC; - constexpr size_t num_threads = BT * BOC / (TT * TOC); - Shape wgSize = {num_threads, 1, 1}; - - std::string codeString(kShaderMatmul2DTiling); - std::string unrolledCode = loopUnrolling(replaceAll(codeString, {{"{{precision}}", toString(kf32)}, - {"{{BT}}", toString(BT)}, - {"{{BC}}", toString(BC)}, - {"{{BOC}}", toString(BOC)}, - {"{{TT}}", toString(TT)}, - {"{{TOC}}", toString(TOC)}, - {"{{NUM_TILEI}}", toString(BT * BC / num_threads)}, - {"{{NUM_TILEW}}", toString(BOC * BC / num_threads)} - })); - - Shape nWorkgroups = {b, cdiv(T, BT), cdiv(OC, BOC)}; Tensor inp_i = createTensor(ctx, Shape{b * t * c}, kf32); Tensor weight_i = createTensor(ctx, Shape{oc * c}, kf32); Tensor bias_i = bias == NULL ? 
createTensor(ctx, Shape{1}, kf32) : createTensor(ctx, Shape{oc}, kf32); Tensor out_o = createTensor(ctx, Shape{b * t * oc}, kf32); - op = createKernel(ctx, {unrolledCode, wgSize, kf32}, - Bindings{inp_i, weight_i, bias_i, out_o}, - nWorkgroups, - /* params */ - MatmulParams{ - static_cast(b), - static_cast(t), - static_cast(c), - static_cast(oc) - }, - nullptr, - key.c_str() - ); + + if (version == 2) { + constexpr size_t BT = 64; + constexpr size_t BC = 16; + constexpr size_t BOC = 64; + constexpr size_t TT = BT / BC; + constexpr size_t TOC = BOC / BC; + constexpr size_t num_threads = BT * BOC / (TT * TOC); + Shape wgSize = {num_threads, 1, 1}; + + std::string codeString(kShaderMatmul2DTiling); + std::string unrolledCode = loopUnrolling(replaceAll(codeString, {{"{{precision}}", toString(kf32)}, + {"{{BT}}", toString(BT)}, + {"{{BC}}", toString(BC)}, + {"{{BOC}}", toString(BOC)}, + {"{{TT}}", toString(TT)}, + {"{{TOC}}", toString(TOC)}, + {"{{NUM_TILEI}}", toString(BT * BC / num_threads)}, + {"{{NUM_TILEW}}", toString(BOC * BC / num_threads)} + })); + + Shape nWorkgroups = {b, cdiv(T, BT), cdiv(OC, BOC)}; + op = createKernel(ctx, {unrolledCode, wgSize, kf32}, + Bindings{inp_i, weight_i, bias_i, out_o}, + nWorkgroups, + /* params */ + MatmulParams{ + static_cast(b), + static_cast(t), + static_cast(c), + static_cast(oc) + }, + nullptr, + key.c_str() + ); + } else { + op = createKernel(ctx, {kShaderMatmul, 256, kf32}, + Bindings{inp_i, weight_i, bias_i, out_o}, + /* nWorkgroups */ {cdiv(b * t, 256), 1, 1}, + /* params */ + MatmulParams{ + static_cast(b), + static_cast(t), + static_cast(c), + static_cast(oc) + }, + nullptr, + key.c_str() + ); + } } else { op = ctx.kernelPool.data[key]; } @@ -400,39 +416,16 @@ void MATMUL_FORWARD_GPU(float* out, std::promise promise; std::future future = promise.get_future(); - - if (version == 2) { - DurationTime duration("matmul_forward_gpu: after creating tensors", verbose); - { - DurationTime duration("matmul_forward_gpu: before creating kernels", verbose); - { - DurationTime duration("matmul_forward_gpu without creating context", verbose); - dispatchKernel(ctx, op, promise); - wait(ctx, future); - toCPU(ctx, out_o, out, b * t * oc * sizeof(float)); - } - } - // } else if (version == 1) { - // Kernel op = createKernel(ctx, {kShaderMatmul, 256, kf32}, - // Bindings{inp_i, weight_i, bias_i, out_o}, - // /* nWorkgroups */ {cdiv(b * t, 256), 1, 1}, - // /* params */ - // MatmulParams{ - // static_cast(b), - // static_cast(t), - // static_cast(c), - // static_cast(oc) - // }); - // { - // DurationTime duration("matmul_forward_gpu without creating context", verbose); - // dispatchKernel(ctx, op, promise); - // wait(ctx, future); - // toCPU(ctx, out_o, out, b * t * oc * sizeof(float)); - // } - } else { - DurationTime duration("matmul_forward_cpu", verbose); - matmul_forward_dummy(out, inp, weight, bias, B, T, C, OC); + + { + DurationTime duration("matmul_forward_gpu", verbose); + dispatchKernel(ctx, op, promise); + wait(ctx, future); + toCPU(ctx, out_o, out, b * t * oc * sizeof(float)); } + } else { + DurationTime duration("matmul_forward_cpu", verbose); + matmul_forward_dummy(out, inp, weight, bias, B, T, C, OC); } if (debug) { // compare out with out_exp. 
From dd145acd76a00d399ea93760dd837c7927b87cd7 Mon Sep 17 00:00:00 2001 From: Junji Hashimoto Date: Thu, 17 Oct 2024 02:00:18 +0900 Subject: [PATCH 09/13] Add the duration-time of matmul_forward_dummy to compare GPU's one with CPU's one --- experimental/kernels/unittest_llmc/unittest_kernels.cpp | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/experimental/kernels/unittest_llmc/unittest_kernels.cpp b/experimental/kernels/unittest_llmc/unittest_kernels.cpp index 311d895..d037eac 100644 --- a/experimental/kernels/unittest_llmc/unittest_kernels.cpp +++ b/experimental/kernels/unittest_llmc/unittest_kernels.cpp @@ -327,7 +327,10 @@ void MATMUL_FORWARD_GPU(float* out, } if (debug) { out_exp = new float[B*T*OC]; - matmul_forward_dummy(out_exp, inp, weight, bias, B, T, C, OC); + { + DurationTime duration("matmul_forward_cpu", verbose); + matmul_forward_dummy(out_exp, inp, weight, bias, B, T, C, OC); + } } struct MatmulParams { uint32_t B; @@ -421,8 +424,8 @@ void MATMUL_FORWARD_GPU(float* out, DurationTime duration("matmul_forward_gpu", verbose); dispatchKernel(ctx, op, promise); wait(ctx, future); - toCPU(ctx, out_o, out, b * t * oc * sizeof(float)); } + toCPU(ctx, out_o, out, b * t * oc * sizeof(float)); } else { DurationTime duration("matmul_forward_cpu", verbose); matmul_forward_dummy(out, inp, weight, bias, B, T, C, OC); From 590f2571710ebca733e316b96e1360d502df36a7 Mon Sep 17 00:00:00 2001 From: Junji Hashimoto Date: Fri, 18 Oct 2024 16:09:51 +0900 Subject: [PATCH 10/13] Add wgpuBufferRelease for CopyData --- gpu.hpp | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/gpu.hpp b/gpu.hpp index 8f8bd9b..2db9f66 100644 --- a/gpu.hpp +++ b/gpu.hpp @@ -1065,6 +1065,9 @@ inline void toCPU(Context &ctx, Tensor &tensor, void *data, size_t bufferSize) { check(op.commandBuffer, "Create command buffer", __FILE__, __LINE__); } toCPU(ctx, tensor, data, bufferSize, op); + if (op.readbackBuffer) { + wgpuBufferRelease(op.readbackBuffer); + } } /** @@ -1131,6 +1134,9 @@ inline void toCPU(Context &ctx, WGPUBuffer buffer, void *data, }, &callbackData); wait(ctx, op.future); + if (op.readbackBuffer) { + wgpuBufferRelease(op.readbackBuffer); + } } From 6c52b98dd80ac174b46ba1b439f74956a5392d01 Mon Sep 17 00:00:00 2001 From: Junji Hashimoto Date: Fri, 18 Oct 2024 17:26:43 +0900 Subject: [PATCH 11/13] Add wgpuCommandBufferRelease after calling wgpuQueueSubmit --- gpu.hpp | 3 +++ 1 file changed, 3 insertions(+) diff --git a/gpu.hpp b/gpu.hpp index 2db9f66..4cca2a1 100644 --- a/gpu.hpp +++ b/gpu.hpp @@ -1002,6 +1002,7 @@ inline void wait(Context &ctx, std::future &future) { inline void toCPU(Context &ctx, Tensor &tensor, void *data, size_t bufferSize, CopyData &op) { wgpuQueueSubmit(ctx.queue, 1, &op.commandBuffer); + wgpuCommandBufferRelease(op.commandBuffer); CallbackData callbackData = {op.readbackBuffer, bufferSize, data, &op.promise, &op.future}; wgpuQueueOnSubmittedWorkDone( @@ -1109,6 +1110,7 @@ inline void toCPU(Context &ctx, WGPUBuffer buffer, void *data, check(op.commandBuffer, "Create command buffer", __FILE__, __LINE__); } wgpuQueueSubmit(ctx.queue, 1, &op.commandBuffer); + wgpuCommandBufferRelease(op.commandBuffer); CallbackData callbackData = {op.readbackBuffer, bufferSize, data, &op.promise, &op.future}; wgpuQueueOnSubmittedWorkDone( @@ -1513,6 +1515,7 @@ inline void dispatchKernel(Context &ctx, Kernel &kernel, resetCommandBuffer(ctx.device, kernel); } wgpuQueueSubmit(ctx.queue, 1, &kernel->commandBuffer); + wgpuCommandBufferRelease(kernel->commandBuffer); kernel->used = true; 
wgpuQueueOnSubmittedWorkDone( ctx.queue, From a6140e0d83bf6c27de6d1264dd1b15f3abcaf67f Mon Sep 17 00:00:00 2001 From: Junji Hashimoto Date: Fri, 18 Oct 2024 17:31:16 +0900 Subject: [PATCH 12/13] Add wgpuCommandEncoderRelease after calling wgpuCommandEncoderFinish --- gpu.hpp | 3 +++ 1 file changed, 3 insertions(+) diff --git a/gpu.hpp b/gpu.hpp index 4cca2a1..9262dcb 100644 --- a/gpu.hpp +++ b/gpu.hpp @@ -1063,6 +1063,7 @@ inline void toCPU(Context &ctx, Tensor &tensor, void *data, size_t bufferSize) { wgpuCommandEncoderCopyBufferToBuffer(commandEncoder, tensor.data.buffer, 0, op.readbackBuffer, 0, bufferSize); op.commandBuffer = wgpuCommandEncoderFinish(commandEncoder, nullptr); + wgpuCommandEncoderRelease(commandEncoder); check(op.commandBuffer, "Create command buffer", __FILE__, __LINE__); } toCPU(ctx, tensor, data, bufferSize, op); @@ -1107,6 +1108,7 @@ inline void toCPU(Context &ctx, WGPUBuffer buffer, void *data, wgpuCommandEncoderCopyBufferToBuffer(commandEncoder, buffer, 0, op.readbackBuffer, 0, bufferSize); op.commandBuffer = wgpuCommandEncoderFinish(commandEncoder, nullptr); + wgpuCommandEncoderRelease(commandEncoder); check(op.commandBuffer, "Create command buffer", __FILE__, __LINE__); } wgpuQueueSubmit(ctx.queue, 1, &op.commandBuffer); @@ -1225,6 +1227,7 @@ inline void resetCommandBuffer(WGPUDevice &device, Kernel &op) { op->totalWorkgroups[2]); wgpuComputePassEncoderEnd(computePassEncoder); op->commandBuffer = wgpuCommandEncoderFinish(commandEncoder, nullptr); + wgpuCommandEncoderRelease(commandEncoder); op->used = false; } } From 30ed026e1f62e70ca1c9bfbd5a4541f8e76845ea Mon Sep 17 00:00:00 2001 From: Junji Hashimoto Date: Fri, 18 Oct 2024 17:41:48 +0900 Subject: [PATCH 13/13] Add wgpuComputePassEncoderRelease after calling wgpuComputePassEncoderEnd --- gpu.hpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/gpu.hpp b/gpu.hpp index 9262dcb..941656e 100644 --- a/gpu.hpp +++ b/gpu.hpp @@ -1058,7 +1058,6 @@ inline void toCPU(Context &ctx, Tensor &tensor, void *data, size_t bufferSize) { } { WGPUCommandEncoder commandEncoder; - WGPUComputePassEncoder computePassEncoder; commandEncoder = wgpuDeviceCreateCommandEncoder(ctx.device, nullptr); wgpuCommandEncoderCopyBufferToBuffer(commandEncoder, tensor.data.buffer, 0, op.readbackBuffer, 0, bufferSize); @@ -1103,7 +1102,6 @@ inline void toCPU(Context &ctx, WGPUBuffer buffer, void *data, } { WGPUCommandEncoder commandEncoder; - WGPUComputePassEncoder computePassEncoder; commandEncoder = wgpuDeviceCreateCommandEncoder(ctx.device, nullptr); wgpuCommandEncoderCopyBufferToBuffer(commandEncoder, buffer, 0, op.readbackBuffer, 0, bufferSize); @@ -1226,6 +1224,7 @@ inline void resetCommandBuffer(WGPUDevice &device, Kernel &op) { computePassEncoder, op->totalWorkgroups[0], op->totalWorkgroups[1], op->totalWorkgroups[2]); wgpuComputePassEncoderEnd(computePassEncoder); + wgpuComputePassEncoderRelease(computePassEncoder); op->commandBuffer = wgpuCommandEncoderFinish(commandEncoder, nullptr); wgpuCommandEncoderRelease(commandEncoder); op->used = false;
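Note on PATCHES 10/13 through 13/13: together they close per-dispatch leaks by releasing each transient WebGPU object as soon as its role ends — the compute pass encoder after End, the command encoder after Finish, the command buffer after QueueSubmit, and the readback buffer once the copy has completed. The sketch below is not part of the patch series; it is an illustrative encode/submit/release sequence using standard webgpu.h entry points, with the pipeline, bind group, and header path assumed to come from the surrounding project.

// Sketch: the release discipline the patches converge on, applied to a single
// compute dispatch. Pipeline, bind group, and workgroup counts are assumed inputs.
#include <webgpu/webgpu.h>   // header path may differ per backend (Dawn / wgpu-native)

void encodeAndSubmit(WGPUDevice device, WGPUQueue queue,
                     WGPUComputePipeline pipeline, WGPUBindGroup bindGroup,
                     uint32_t wgX, uint32_t wgY, uint32_t wgZ) {
  WGPUCommandEncoder encoder = wgpuDeviceCreateCommandEncoder(device, nullptr);

  WGPUComputePassEncoder pass =
      wgpuCommandEncoderBeginComputePass(encoder, nullptr);
  wgpuComputePassEncoderSetPipeline(pass, pipeline);
  wgpuComputePassEncoderSetBindGroup(pass, 0, bindGroup, 0, nullptr);
  wgpuComputePassEncoderDispatchWorkgroups(pass, wgX, wgY, wgZ);
  wgpuComputePassEncoderEnd(pass);
  wgpuComputePassEncoderRelease(pass);    // PATCH 13: release the pass after End

  WGPUCommandBuffer commands = wgpuCommandEncoderFinish(encoder, nullptr);
  wgpuCommandEncoderRelease(encoder);     // PATCH 12: release the encoder after Finish

  wgpuQueueSubmit(queue, 1, &commands);
  wgpuCommandBufferRelease(commands);     // PATCH 11: release the buffer after Submit
}

PATCH 10's wgpuBufferRelease on the readback buffer follows the same principle, but only after the asynchronous map/copy callback has delivered the data, which is why it sits at the end of toCPU rather than inside this sequence.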