diff --git a/experimental/kernels/Makefile b/experimental/kernels/Makefile
index c233ef5..5817e23 100644
--- a/experimental/kernels/Makefile
+++ b/experimental/kernels/Makefile
@@ -79,7 +79,7 @@ endef
 build/test_gpt2: llm.c build/unittest_kernels.o gpt2_124M.bin
 	mkdir -p build
 	$(call preprocess_file)
-	$(CC) $(CFLAGS) $(LDFLAGS) -o $@ llm.c/test_gpt2.c build/unittest_kernels.o
+	$(CC) $(CFLAGS) $(LDFLAGS) -o $@ llm.c/test_gpt2.c build/unittest_kernels.o -g
 
 build/test_gpt2_with_metal_profiler: llm.c build/unittest_kernels.o gpt2_124M.bin
 	mkdir -p build
@@ -90,12 +90,12 @@ build/test_gpt2_with_metal_profiler: llm.c build/unittest_kernels.o gpt2_124M.bi
 build/train_gpt2: llm.c build/unittest_kernels.o gpt2_124M.bin
 	mkdir -p build
 	$(call preprocess_file)
-	$(CC) $(CFLAGS) $(LDFLAGS) -o $@ llm.c/train_gpt2.c build/unittest_kernels.o
+	$(CC) $(CFLAGS) $(LDFLAGS) -o $@ llm.c/train_gpt2.c build/unittest_kernels.o -g
 
 build/ops.o: ops.cpp ops.hpp kernels.h llm.c
 	mkdir -p build && $(CXX) $(CXXFLAGS) -c -o $@ $<
 
-build/gpt2_webgpu: llm.c gpt2_124M.bin llm.c
+build/gpt2_webgpu: llm.c gpt2_124M.bin gpt2_webgpu.cpp ops.cpp
 	mkdir -p build
 	$(CC) $(CXXFLAGS) -Illm.c $(LDFLAGS) -o $@ gpt2_webgpu.cpp ops.cpp
 
diff --git a/experimental/kernels/kernels.h b/experimental/kernels/kernels.h
index 212c075..1bce081 100644
--- a/experimental/kernels/kernels.h
+++ b/experimental/kernels/kernels.h
@@ -309,6 +309,104 @@ fn main(@builtin(global_invocation_id) global_id : vec3<u32>) {
       }
     }
 }
+
+)";
+
+
+static const char *kShaderMatmul2DTiling = R"(
+@group(0) @binding(0) var<storage, read_write> inp : array<{{precision}}>;
+@group(0) @binding(1) var<storage, read_write> weight : array<{{precision}}>;
+@group(0) @binding(2) var<storage, read_write> bias : array<{{precision}}>;
+@group(0) @binding(3) var<storage, read_write> out : array<{{precision}}>;
+@group(0) @binding(4) var<uniform> params : Params;
+struct Params {
+    B: u32,
+    T: u32,
+    C: u32,
+    OC: u32,
+};
+var<workgroup> tileInp: array<{{precision}}, {{BT}} * {{BC}}>;
+var<workgroup> tileWeight: array<{{precision}}, {{BOC}} * {{BC}}>;
+
+@compute @workgroup_size({{workgroupSize}})
+fn main(
+    @builtin(local_invocation_id) localID : vec3<u32>,
+    @builtin(workgroup_id) groupid : vec3<u32>) {
+    let B : u32 = params.B;
+    let T : u32 = params.T;
+    let C : u32 = params.C;
+    let OC : u32 = params.OC;
+
+    var localT: array<{{precision}}, {{TT}}>;
+    var localOC: array<{{precision}}, {{TOC}}>;
+
+    let outB: u32 = groupid.x;
+    let outT: u32 = groupid.y;
+    let outOC: u32 = groupid.z;
+    let numThread: u32 = ({{BT}} * {{BOC}}) / ({{TT}} * {{TOC}});
+
+    // position of the first output element computed by this thread
+    let threadRow: u32 = (localID.x / ({{BOC}} / {{TOC}})) * {{TT}};
+    let threadCol: u32 = (localID.x % ({{BOC}} / {{TOC}})) * {{TOC}};
+
+    // inpPtr and weightPtr are the starting positions of the tiles in inp and
+    // weight, advanced inside the bkidx loop.
+    // outPtr is the starting position of the tile in out, which is fixed.
+
+    var inpPtr = (outB * T + outT * {{BT}}) * C; // inp is laid out [B, T, C]
+    var weightPtr = outOC * {{BOC}} * C; // weight is laid out [OC, C]
+    var threadResults: array<{{precision}}, {{TT}} * {{TOC}}>;
+    let outPtr = (outB * T + outT * {{BT}}) * OC + outOC * {{BOC}}; // out is laid out [B, T, OC]
+    let biasPtr = outOC * {{BOC}};
+
+    for (var bkidx: u32 = 0; bkidx < C; bkidx += {{BC}}) {
+      // Load the BOC x BC weight tile with numThread (= BT * BOC / (TT * TOC)) threads;
+      // each thread performs BOC * BC / numThread (= NUM_TILEW) iterations.
+      for (var idx: u32 = 0; idx < {{NUM_TILEW}}; idx++) {
+        tileWeight[localID.x + idx * numThread] = weight[weightPtr + ((localID.x + idx * numThread) / {{BC}}) * C + ((localID.x + idx * numThread) % {{BC}})];
+      }
+      weightPtr += {{BC}};
+    
+      // Load tile
+      // Load the BT x BC input tile with numThread (= BT * BOC / (TT * TOC)) threads;
+      // each thread performs BT * BC / numThread (= NUM_TILEI) iterations.
+      for (var idx: u32 = 0; idx < {{NUM_TILEI}}; idx++) {
+        tileInp[localID.x + idx * numThread] = inp[inpPtr + ((localID.x + idx * numThread) / {{BC}}) * C + (localID.x + idx * numThread) % {{BC}}];
+      }
+      inpPtr += {{BC}};
+    
+      workgroupBarrier();
+      // Compute tile
+      for (var dotIdx: u32 = 0; dotIdx < {{BC}}; dotIdx = dotIdx + 1) {
+        for (var idx: u32 = 0; idx < {{TT}}; idx++) {
+          localT[idx] = tileInp[(threadRow + idx) * {{BC}} + dotIdx];
+        }
+        for (var idx: u32 = 0; idx < {{TOC}}; idx++) {
+          localOC[idx] = tileWeight[(threadCol + idx) * {{BC}} + dotIdx];
+        }
+        for (var resIdxT: u32 = 0; resIdxT < {{TT}}; resIdxT++) {
+          for (var resIdxOC: u32 = 0; resIdxOC < {{TOC}}; resIdxOC++) {
+            threadResults[resIdxT * {{TOC}} + resIdxOC] += localT[resIdxT] * localOC[resIdxOC];
+          }
+        }
+      }
+      workgroupBarrier();
+    }
+    
+    if (arrayLength(&bias) == 1) {
+      for (var resIdxT: u32 = 0; resIdxT < {{TT}}; resIdxT++) {
+        for (var resIdxOC: u32 = 0; resIdxOC < {{TOC}}; resIdxOC++) {
+          out[outPtr + (threadRow + resIdxT) * OC + threadCol + resIdxOC] = threadResults[resIdxT * {{TOC}} + resIdxOC];
+        }
+      }
+    } else {
+      for (var resIdxT: u32 = 0; resIdxT < {{TT}}; resIdxT++) {
+        for (var resIdxOC: u32 = 0; resIdxOC < {{TOC}}; resIdxOC++) {
+          out[outPtr + (threadRow + resIdxT) * OC + threadCol + resIdxOC] = threadResults[resIdxT * {{TOC}} + resIdxOC] + bias[biasPtr + threadCol + resIdxOC];
+        }
+      }
+    }
+}
 )";
 
 static const char *kShaderMatmulBackward = R"(
diff --git a/experimental/kernels/ops.cpp b/experimental/kernels/ops.cpp
index 67fc679..0e9c076 100644
--- a/experimental/kernels/ops.cpp
+++ b/experimental/kernels/ops.cpp
@@ -6,6 +6,7 @@
 
 #include "kernels.h"
 #include "ops.hpp"
+#include "experimental/wgsl.h"      // loopUnrolling
 
 using namespace gpu;
 
@@ -22,27 +23,39 @@ void encoder_forward(Context& ctx, float* out,
     uint32_t C;
   };
   setLogLevel(kError);
-  printf("Creating tensors\n");
-  printf("Creating input tensor\%pn", inp);
-  Tensor input = createTensor(ctx, Shape{b * t}, ki32, inp);
-  printf("Created input tensor\n");
-  Tensor wte_t = createTensor(ctx, Shape{v, c}, kf32, wte);
-  printf("Created wte tensor\n");
-  Tensor wpe_t = createTensor(ctx, Shape{t, c}, kf32, wpe);
-  printf("Created wpe tensor\n");
-  Tensor output = createTensor(ctx, Shape{b * t * c}, kf32);
-  printf("Created tensors\n");
+  // Generate the cache key from the arguments.
+  std::string key = "encoder_forward_" + std::to_string(B) + "_" + std::to_string(T) + "_" + std::to_string(C);
+  Kernel op;
+  if (ctx.kernelPool.data.find(key) == ctx.kernelPool.data.end()) {
+    Tensor input = createTensor(ctx, Shape{b * t}, ki32);
+    Tensor wte_t = createTensor(ctx, Shape{v, c}, kf32);
+    Tensor wpe_t = createTensor(ctx, Shape{t, c}, kf32);
+    Tensor output = createTensor(ctx, Shape{b * t * c}, kf32);
+    op = createKernel(ctx, {kShaderEncoder, 256, kf32},
+                      Bindings{input, wte_t, wpe_t, output},
+                      /* nWorkgroups */ {cdiv(b * t, 256), 1, 1},
+                      /* params */
+                      EncoderParams{
+                        static_cast<uint32_t>(b),
+                        static_cast<uint32_t>(t),
+                        static_cast<uint32_t>(c)
+                      },
+                      nullptr,
+                      key.c_str());
+  } else {
+    op = ctx.kernelPool.data[key];
+  }
+  Tensor& input = ctx.pool.data[op->buffers[0]];
+  Tensor& wte_t = ctx.pool.data[op->buffers[1]];
+  Tensor& wpe_t = ctx.pool.data[op->buffers[2]];
+  Tensor& output = ctx.pool.data[op->buffers[3]];
+
+  toGPU(ctx, inp, input);
+  toGPU(ctx, wte, wte_t);
+  toGPU(ctx, wpe, wpe_t);
+
   std::promise<void> promise;
   std::future<void> future = promise.get_future();
-  Kernel op = createKernel(ctx, {kShaderEncoder, 256, kf32},
-                           Bindings{input, wte_t, wpe_t, output},
-                           /* nWorkgroups */ {cdiv(b * t, 256), 1, 1},
-                           /* params */
-                           EncoderParams{
-                             static_cast<uint32_t>(b),
-                             static_cast<uint32_t>(t),
-                             static_cast<uint32_t>(c)
-                           });
   dispatchKernel(ctx, op, promise);
   wait(ctx, future);
   toCPU(ctx, output, out, b * t * c * sizeof(float));
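Every op in ops.cpp now follows the same get-or-create flow shown above: build a string key from the call's shape arguments, reuse the cached kernel on a hit, and otherwise create it with the key so it is registered in ctx.kernelPool. A condensed sketch of that flow; the pool and buffer member names are taken from this diff, while the map-style access on ctx.kernelPool.data is an assumption about its container type.

```cpp
// Condensed sketch of the caching pattern used by every op in this file; illustrative only.
auto getOrCreate = [&](const std::string& key, auto makeKernel) -> Kernel {
  auto it = ctx.kernelPool.data.find(key);
  if (it != ctx.kernelPool.data.end()) {
    return it->second;   // hit: reuse the compiled pipeline and its pooled buffers
  }
  return makeKernel();   // miss: createKernel(..., nullptr, key.c_str()) registers it
};
```

The Tensor locals created on the miss path go out of scope immediately, which is why every call re-resolves the persistent buffers through ctx.pool.data[op->buffers[i]] and re-uploads the host pointers with toGPU before dispatching.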
@@ -61,21 +74,40 @@ void encoder_backward(Context& ctx, float* dwte, float* dwpe,
     uint32_t C;
   };
   setLogLevel(kError);
-  Tensor dwte_t = createTensor(ctx, Shape{v, c}, kf32, dwte);
-  Tensor dwpe_t = createTensor(ctx, Shape{t, c}, kf32, dwpe);
-  Tensor dout_t = createTensor(ctx, Shape{b * t * c}, kf32, dout);
-  Tensor input = createTensor(ctx, Shape{b * t}, ki32, inp);
+  // Generate the cache key from the arguments.
+  std::string key = "encoder_backward_" + std::to_string(B) + "_" + std::to_string(T) + "_" + std::to_string(C);
+  Kernel op;
+  if (ctx.kernelPool.data.find(key) == ctx.kernelPool.data.end()) {
+    Tensor dwte_t = createTensor(ctx, Shape{v, c}, kf32);
+    Tensor dwpe_t = createTensor(ctx, Shape{t, c}, kf32);
+    Tensor dout_t = createTensor(ctx, Shape{b * t * c}, kf32);
+    Tensor input = createTensor(ctx, Shape{b * t}, ki32);
+    op = createKernel(ctx, {kShaderEncoderBackward, 256, kf32},
+                      Bindings{dwte_t, dwpe_t, dout_t, input},
+                      /* nWorkgroups */ {cdiv(b * t, 256), 1, 1},
+                      /* params */
+                      EncoderParams{
+                        static_cast<uint32_t>(b),
+                        static_cast<uint32_t>(t),
+                        static_cast<uint32_t>(c)
+                      },
+                      nullptr,
+                      key.c_str());
+  } else {
+    op = ctx.kernelPool.data[key];
+  }
+  Tensor& dwte_t = ctx.pool.data[op->buffers[0]];
+  Tensor& dwpe_t = ctx.pool.data[op->buffers[1]];
+  Tensor& dout_t = ctx.pool.data[op->buffers[2]];
+  Tensor& input = ctx.pool.data[op->buffers[3]];
+
+  toGPU(ctx, dwte, dwte_t);
+  toGPU(ctx, dwpe, dwpe_t);
+  toGPU(ctx, dout, dout_t);
+  toGPU(ctx, inp, input);
+
   std::promise<void> promise;
   std::future<void> future = promise.get_future();
-  Kernel op = createKernel(ctx, {kShaderEncoderBackward, 256, kf32},
-                           Bindings{dwte_t, dwpe_t, dout_t, input},
-                           /* nWorkgroups */ {cdiv(b * t, 256), 1, 1},
-                           /* params */
-                           EncoderParams{
-                             static_cast<uint32_t>(b),
-                             static_cast<uint32_t>(t),
-                             static_cast<uint32_t>(c)
-                           });
   dispatchKernel(ctx, op, promise);
   wait(ctx, future);
   toCPU(ctx, dwte_t, dwte, v * c * sizeof(float));
@@ -94,23 +126,43 @@ void layernorm_forward(Context& ctx, float* out, float* mean, float* rstd,
     uint32_t C;
   };
   setLogLevel(kError);
-  Tensor inp_t = createTensor(ctx, Shape{b * t * c}, kf32, inp);
-  Tensor weight_t = createTensor(ctx, Shape{c}, kf32, weight);
-  Tensor bias_t = createTensor(ctx, Shape{c}, kf32, bias);
-  Tensor out_t = createTensor(ctx, Shape{b * t * c}, kf32);
-  Tensor mean_t = createTensor(ctx, Shape{b * t}, kf32);
-  Tensor rstd_t = createTensor(ctx, Shape{b * t}, kf32);
+  // Generate the cache key from the arguments.
+  std::string key = "layernorm_forward_" + std::to_string(B) + "_" + std::to_string(T) + "_" + std::to_string(C);
+  Kernel op;
+  if (ctx.kernelPool.data.find(key) == ctx.kernelPool.data.end()) {
+    Tensor inp_t = createTensor(ctx, Shape{b * t * c}, kf32);
+    Tensor weight_t = createTensor(ctx, Shape{c}, kf32);
+    Tensor bias_t = createTensor(ctx, Shape{c}, kf32);
+    Tensor out_t = createTensor(ctx, Shape{b * t * c}, kf32);
+    Tensor mean_t = createTensor(ctx, Shape{b * t}, kf32);
+    Tensor rstd_t = createTensor(ctx, Shape{b * t}, kf32);
+    op = createKernel(ctx, {kShaderLayerNorm, 256, kf32},
+                      Bindings{inp_t, weight_t, bias_t, out_t, mean_t, rstd_t},
+                      /* nWorkgroups */ {cdiv(b * t, 256), 1, 1},
+                      /* params */
+                      LayerNormParams{
+                        static_cast<uint32_t>(b),
+                        static_cast<uint32_t>(t),
+                        static_cast<uint32_t>(c)
+                      },
+                      nullptr,
+                      key.c_str());
+  } else {
+    op = ctx.kernelPool.data[key];
+  }
+  Tensor& inp_t = ctx.pool.data[op->buffers[0]];
+  Tensor& weight_t = ctx.pool.data[op->buffers[1]];
+  Tensor& bias_t = ctx.pool.data[op->buffers[2]];
+  Tensor& out_t = ctx.pool.data[op->buffers[3]];
+  Tensor& mean_t = ctx.pool.data[op->buffers[4]];
+  Tensor& rstd_t = ctx.pool.data[op->buffers[5]];
+
+  toGPU(ctx, inp, inp_t);
+  toGPU(ctx, weight, weight_t);
+  toGPU(ctx, bias, bias_t);
+
   std::promise<void> promise;
   std::future<void> future = promise.get_future();
-  Kernel op = createKernel(ctx, {kShaderLayerNorm, 256, kf32},
-                           Bindings{inp_t, weight_t, bias_t, out_t, mean_t, rstd_t},
-                           /* nWorkgroups */ {cdiv(b * t, 256), 1, 1},
-                           /* params */
-                           LayerNormParams{
-                             static_cast<uint32_t>(b),
-                             static_cast<uint32_t>(t),
-                             static_cast<uint32_t>(c)
-                           });
   dispatchKernel(ctx, op, promise);
   wait(ctx, future);
   toCPU(ctx, out_t, out, b * t * c * sizeof(float));
@@ -130,25 +182,52 @@ void layernorm_backward(Context& ctx, float* dinp, float* dweight, float* dbias,
     uint32_t C;
   };
   setLogLevel(kError);
-  Tensor dinp_t = createTensor(ctx, Shape{b * t * c}, kf32, dinp);
-  Tensor dweight_t = createTensor(ctx, Shape{c}, kf32, dweight);
-  Tensor dbias_t = createTensor(ctx, Shape{c}, kf32, dbias);
-  Tensor dout_t = createTensor(ctx, Shape{b * t * c}, kf32, dout);
-  Tensor inp_t = createTensor(ctx, Shape{b * t * c}, kf32, inp);
-  Tensor weight_t = createTensor(ctx, Shape{c}, kf32, weight);
-  Tensor mean_t = createTensor(ctx, Shape{b * t}, kf32, mean);
-  Tensor rstd_t = createTensor(ctx, Shape{b * t}, kf32, rstd);
+  // Generate the cache key from the arguments.
+  std::string key = "layernorm_backward_" + std::to_string(B) + "_" + std::to_string(T) + "_" + std::to_string(C);
+  Kernel op;
+  if (ctx.kernelPool.data.find(key) == ctx.kernelPool.data.end()) {
+    Tensor dinp_t = createTensor(ctx, Shape{b * t * c}, kf32);
+    Tensor dweight_t = createTensor(ctx, Shape{c}, kf32);
+    Tensor dbias_t = createTensor(ctx, Shape{c}, kf32);
+    Tensor dout_t = createTensor(ctx, Shape{b * t * c}, kf32);
+    Tensor inp_t = createTensor(ctx, Shape{b * t * c}, kf32);
+    Tensor weight_t = createTensor(ctx, Shape{c}, kf32);
+    Tensor mean_t = createTensor(ctx, Shape{b * t}, kf32);
+    Tensor rstd_t = createTensor(ctx, Shape{b * t}, kf32);
+    op = createKernel(ctx, {kShaderLayerNormBackward, 256, kf32},
+                      Bindings{dinp_t, dweight_t, dbias_t, dout_t, inp_t, weight_t, mean_t, rstd_t},
+                      /* nWorkgroups */ {cdiv(b * t, 256), 1, 1},
+                      /* params */
+                      LayerNormParams{
+                        static_cast<uint32_t>(b),
+                        static_cast<uint32_t>(t),
+                        static_cast<uint32_t>(c)
+                      },
+                      nullptr,
+                      key.c_str());
+  } else {
+    op = ctx.kernelPool.data[key];
+  }
+  Tensor& dinp_t = ctx.pool.data[op->buffers[0]];
+  Tensor& dweight_t = ctx.pool.data[op->buffers[1]];
+  Tensor& dbias_t = ctx.pool.data[op->buffers[2]];
+  Tensor& dout_t = ctx.pool.data[op->buffers[3]];
+  Tensor& inp_t = ctx.pool.data[op->buffers[4]];
+  Tensor& weight_t = ctx.pool.data[op->buffers[5]];
+  Tensor& mean_t = ctx.pool.data[op->buffers[6]];
+  Tensor& rstd_t = ctx.pool.data[op->buffers[7]];
+
+  toGPU(ctx, dinp, dinp_t);
+  toGPU(ctx, dweight, dweight_t);
+  toGPU(ctx, dbias, dbias_t);
+  toGPU(ctx, dout, dout_t);
+  toGPU(ctx, inp, inp_t);
+  toGPU(ctx, weight, weight_t);
+  toGPU(ctx, mean, mean_t);
+  toGPU(ctx, rstd, rstd_t);
+
   std::promise<void> promise;
   std::future<void> future = promise.get_future();
-  Kernel op = createKernel(ctx, {kShaderLayerNormBackward, 256, kf32},
-                           Bindings{dinp_t, dweight_t, dbias_t, dout_t, inp_t, weight_t, mean_t, rstd_t},
-                           /* nWorkgroups */ {cdiv(b * t, 256), 1, 1},
-                           /* params */
-                           LayerNormParams{
-                             static_cast<uint32_t>(b),
-                             static_cast<uint32_t>(t),
-                             static_cast<uint32_t>(c)
-                           });
   dispatchKernel(ctx, op, promise);
   wait(ctx, future);
   toCPU(ctx, dinp_t, dinp, b * t * c * sizeof(float));
@@ -156,9 +235,34 @@ void layernorm_backward(Context& ctx, float* dinp, float* dweight, float* dbias,
   toCPU(ctx, dbias_t, dbias, c * sizeof(float));
 }
 
+struct DurationTime {
+  std::chrono::high_resolution_clock::time_point start;
+  std::chrono::high_resolution_clock::time_point end;
+  std::chrono::microseconds duration;
+  std::string src;
+  bool verbose;
+  
+  inline DurationTime(const std::string& src, bool verbose = true) {
+    this->src = src;
+    this->verbose = verbose;
+    start = std::chrono::high_resolution_clock::now();
+  }
+
+  inline ~DurationTime() {
+    end = std::chrono::high_resolution_clock::now();
+    duration = std::chrono::duration_cast<std::chrono::microseconds>(end - start);
+    if (this->verbose) {
+      printf("Duration(%s): %.1f microseconds\n", src.c_str(), static_cast<double>(duration.count()));
+    }
+  }
+};
+
+
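DurationTime is a small RAII timer: the constructor records a start timestamp and the destructor prints the elapsed microseconds when the enclosing scope ends, or stays silent when verbose is false (it presumably relies on <chrono> being available transitively). A usage sketch with a hypothetical label:

```cpp
{
  DurationTime t("attention_forward_gpu", /*verbose=*/true);  // hypothetical label
  dispatchKernel(ctx, op, promise);
  wait(ctx, future);
}   // destructor fires here: "Duration(attention_forward_gpu): ... microseconds"
```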
 void matmul_forward(Context& ctx, float* out,
                         const float* inp, const float* weight, const float* bias,
                         int B, int T, int C, int OC){
+  bool verbose = false;
+  DurationTime duration("matmul_forward_gpu", verbose);
   struct MatmulParams {
     uint32_t B;
     uint32_t T;
@@ -171,25 +275,76 @@ void matmul_forward(Context& ctx, float* out,
   unsigned long oc = static_cast<unsigned long>(OC);
   setLogLevel(kError);
 
-  Tensor inp_i = createTensor(ctx, Shape{b * t * c}, kf32, inp);
-  Tensor weight_i = createTensor(ctx, Shape{oc * c}, kf32, weight);
-  Tensor bias_i = bias == NULL ? createTensor(ctx, Shape{1}, kf32) : createTensor(ctx, Shape{oc}, kf32, bias);
-  Tensor out_o = createTensor(ctx, Shape{b * t * oc}, kf32);
+  // Generate the cache key from the arguments.
+  std::string key = "matmul_forward_" + std::to_string(B) + "_" + std::to_string(T) + "_" + std::to_string(C) + "_" + std::to_string(OC);
+  Kernel op;
+  if (ctx.kernelPool.data.find(key) == ctx.kernelPool.data.end()) {
+    constexpr size_t BT = 64;
+    constexpr size_t BC = 8;
+    constexpr size_t BOC = 64;
+    constexpr size_t TT = BT / BC;
+    constexpr size_t TOC = BOC / BC;
+    size_t num_threads = BT * BOC / (TT * TOC);
+    Shape wgSize = {num_threads, 1, 1};
+    Shape nWorkgroups = {b, cdiv(T, BT), cdiv(OC, BOC)};
+
+    std::string kShaderMatmul2DTiling_(kShaderMatmul2DTiling);
+    std::string kShaderMatmul2D(loopUnrolling(
+                                              replaceAll(kShaderMatmul2DTiling_,
+                                                         {{"{{precision}}", toString(kf32)},
+                                                          {"{{BT}}", toString(BT)},
+                                                          {"{{BC}}", toString(BC)},
+                                                          {"{{BOC}}", toString(BOC)},
+                                                          {"{{TT}}", toString(TT)},
+                                                          {"{{TOC}}", toString(TOC)},
+                                                          {"{{NUM_TILEI}}", toString(BT * BC / num_threads)},
+                                                          {"{{NUM_TILEW}}", toString(BOC * BC / num_threads)}
+                                                         })
+                                              )
+                                );
+
+    Tensor inp_i = createTensor(ctx, Shape{b * t * c}, kf32);
+    Tensor weight_i = createTensor(ctx, Shape{oc * c}, kf32);
+    Tensor bias_i = bias == NULL ? createTensor(ctx, Shape{1}, kf32) : createTensor(ctx, Shape{oc}, kf32);
+    Tensor out_o = createTensor(ctx, Shape{b * t * oc}, kf32);
+    
+    op = createKernel(ctx, {kShaderMatmul2D, wgSize, kf32},
+                      Bindings{inp_i, weight_i, bias_i, out_o},
+                      nWorkgroups,
+                      /* params */
+                      MatmulParams{
+                        static_cast<uint32_t>(b),
+                        static_cast<uint32_t>(t),
+                        static_cast<uint32_t>(c),
+                        static_cast<uint32_t>(oc)
+                      },
+                      nullptr,
+                      key.c_str());
+  } else {
+    op = ctx.kernelPool.data[key];
+  }
+  Tensor& inp_i = ctx.pool.data[op->buffers[0]];
+  Tensor& weight_i = ctx.pool.data[op->buffers[1]];
+  Tensor& bias_i = ctx.pool.data[op->buffers[2]];
+  Tensor& out_o = ctx.pool.data[op->buffers[3]];
+
+  toGPU(ctx, inp, inp_i);
+  toGPU(ctx, weight, weight_i);
+  if (bias != NULL) {
+    toGPU(ctx, bias, bias_i);
+  }
+
   std::promise<void> promise;
   std::future<void> future = promise.get_future();
-  assert ( (b*t) % 256 == 0 );
-  Kernel op = createKernel(ctx, {kShaderMatmul, 256, kf32},
-                           Bindings{inp_i, weight_i, bias_i, out_o},
-                           /* nWorkgroups */ {cdiv(b * t, 256), 1, 1},
-                           /* params */
-                           MatmulParams{
-                             static_cast<uint32_t>(b),
-                             static_cast<uint32_t>(t),
-                             static_cast<uint32_t>(c),
-                             static_cast<uint32_t>(oc)
-                           });
-  dispatchKernel(ctx, op, promise);
-  wait(ctx, future);
+  {
+    DurationTime duration("matmul_forward_gpu without creating tensors", verbose);
+    {
+      DurationTime duration("matmul_forward_gpu without creating kernel", verbose);
+      dispatchKernel(ctx, op, promise);
+      wait(ctx, future);
+      toCPU(ctx, out_o, out, b * t * oc * sizeof(float));
+    }
+  }
   toCPU(ctx, out_o, out, b * t * oc * sizeof(float));
 }
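The dispatch above assigns one BT x BOC output tile per workgroup, so the grid is {B, cdiv(T, BT), cdiv(OC, BOC)} with 64 threads per group. A worked example with hypothetical GPT-2-124M-sized shapes (C = 768 is the model's channel width; B and T are illustrative):

```cpp
// Illustrative shapes only: B and T are made up; C = 768 and OC = 4 * C = 3072
// correspond to GPT-2 124M's MLP up-projection.
#include <cstddef>

constexpr std::size_t B = 4, T = 1024, C = 768, OC = 4 * C;
constexpr std::size_t BT = 64, BOC = 64;
constexpr std::size_t groupsT  = (T  + BT  - 1) / BT;   // cdiv(T, BT)  = 16
constexpr std::size_t groupsOC = (OC + BOC - 1) / BOC;  // cdiv(OC, BOC) = 48

static_assert(B * groupsT * groupsOC == 3072, "total workgroups in the dispatch");
static_assert(BT * BOC == 4096, "outputs per workgroup: 64 threads x (8 x 8) each");
static_assert(B * groupsT * groupsOC * BT * BOC == B * T * OC,
              "the grid tiles the B x T x OC output exactly for these shapes");
```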
 
@@ -207,24 +362,47 @@ void matmul_backward(Context& ctx, float* dinp, float* dweight, float* dbias,
   unsigned long c = static_cast<unsigned long>(C);
   unsigned long oc = static_cast<unsigned long>(OC);
   setLogLevel(kError);
-  Tensor dinp_t = createTensor(ctx, Shape{b * t * c}, kf32, dinp);
-  Tensor dweight_t = createTensor(ctx, Shape{oc * c}, kf32, dweight);
-  Tensor dbias_t = createTensor(ctx, Shape{oc}, kf32, dbias);
-  Tensor dout_t = createTensor(ctx, Shape{b * t * oc}, kf32, dout);
-  Tensor inp_t = createTensor(ctx, Shape{b * t * c}, kf32, inp);
-  Tensor weight_t = createTensor(ctx, Shape{oc * c}, kf32, weight);
+  // Generate the cache key from the arguments.
+  std::string key = "matmul_backward_" + std::to_string(B) + "_" + std::to_string(T) + "_" + std::to_string(C) + "_" + std::to_string(OC);
+  Kernel op;
+  if (ctx.kernelPool.data.find(key) == ctx.kernelPool.data.end()) {
+    Tensor dinp_t = createTensor(ctx, Shape{b * t * c}, kf32);
+    Tensor dweight_t = createTensor(ctx, Shape{oc * c}, kf32);
+    Tensor dbias_t = createTensor(ctx, Shape{oc}, kf32);
+    Tensor dout_t = createTensor(ctx, Shape{b * t * oc}, kf32);
+    Tensor inp_t = createTensor(ctx, Shape{b * t * c}, kf32);
+    Tensor weight_t = createTensor(ctx, Shape{oc * c}, kf32);
+    op = createKernel(ctx, {kShaderMatmulBackward, 256, kf32},
+                      Bindings{dinp_t, dweight_t, dbias_t, dout_t, inp_t, weight_t},
+                      /* nWorkgroups */ {cdiv(b * t, 256), 1, 1},
+                      /* params */
+                      MatmulParams{
+                        static_cast<uint32_t>(b),
+                        static_cast<uint32_t>(t),
+                        static_cast<uint32_t>(c),
+                        static_cast<uint32_t>(oc)
+                      },
+                      nullptr,
+                      key.c_str());
+  } else {
+    op = ctx.kernelPool.data[key];
+  }
+  Tensor& dinp_t = ctx.pool.data[op->buffers[0]];
+  Tensor& dweight_t = ctx.pool.data[op->buffers[1]];
+  Tensor& dbias_t = ctx.pool.data[op->buffers[2]];
+  Tensor& dout_t = ctx.pool.data[op->buffers[3]];
+  Tensor& inp_t = ctx.pool.data[op->buffers[4]];
+  Tensor& weight_t = ctx.pool.data[op->buffers[5]];
+
+  toGPU(ctx, dinp, dinp_t);
+  toGPU(ctx, dweight, dweight_t);
+  toGPU(ctx, dbias, dbias_t);
+  toGPU(ctx, dout, dout_t);
+  toGPU(ctx, inp, inp_t);
+  toGPU(ctx, weight, weight_t);
+
   std::promise<void> promise;
   std::future<void> future = promise.get_future();
-  Kernel op = createKernel(ctx, {kShaderMatmulBackward, 256, kf32},
-                           Bindings{dinp_t, dweight_t, dbias_t, dout_t, inp_t, weight_t},
-                           /* nWorkgroups */ {cdiv(b * t, 256), 1, 1},
-                           /* params */
-                           MatmulParams{
-                             static_cast<uint32_t>(b),
-                             static_cast<uint32_t>(t),
-                             static_cast<uint32_t>(c),
-                             static_cast<uint32_t>(oc)
-                           });
   dispatchKernel(ctx, op, promise);
   wait(ctx, future);
   toCPU(ctx, dinp_t, dinp, b * t * c * sizeof(float));
@@ -246,22 +424,40 @@ void attention_forward(Context& ctx, float* out, float* preatt, float* att,
   unsigned long c = static_cast<unsigned long>(C);
   unsigned long nh = static_cast<unsigned long>(NH);
   setLogLevel(kError);
-  Tensor inp_t = createTensor(ctx, Shape{b * t * c * 3}, kf32, inp);
-  Tensor preatt_t = createTensor(ctx, Shape{b * nh * t * t}, kf32, preatt);
-  Tensor att_t = createTensor(ctx, Shape{b * nh * t * t}, kf32, att);
-  Tensor out_t = createTensor(ctx, Shape{b * t * c}, kf32);
+  // Generate the cache key from the arguments.
+  std::string key = "attention_forward_" + std::to_string(B) + "_" + std::to_string(T) + "_" + std::to_string(C) + "_" + std::to_string(NH);
+  Kernel op;
+  if (ctx.kernelPool.data.find(key) == ctx.kernelPool.data.end()) {
+    Tensor inp_t = createTensor(ctx, Shape{b * t * c * 3}, kf32);
+    Tensor preatt_t = createTensor(ctx, Shape{b * nh * t * t}, kf32);
+    Tensor att_t = createTensor(ctx, Shape{b * nh * t * t}, kf32);
+    Tensor out_t = createTensor(ctx, Shape{b * t * c}, kf32);
+    op = createKernel(ctx, {kShaderAttention, 256, kf32},
+                      Bindings{inp_t, preatt_t, att_t, out_t},
+                      /* nWorkgroups */ {cdiv(b * t, 256), 1, 1},
+                      /* params */
+                      AttentionParams{
+                        static_cast<uint32_t>(b),
+                        static_cast<uint32_t>(t),
+                        static_cast<uint32_t>(c),
+                        static_cast<uint32_t>(nh)
+                      },
+                      nullptr,
+                      key.c_str());
+  } else {
+    op = ctx.kernelPool.data[key];
+  }
+  Tensor& inp_t = ctx.pool.data[op->buffers[0]];
+  Tensor& preatt_t = ctx.pool.data[op->buffers[1]];
+  Tensor& att_t = ctx.pool.data[op->buffers[2]];
+  Tensor& out_t = ctx.pool.data[op->buffers[3]];
+
+  toGPU(ctx, inp, inp_t);
+  toGPU(ctx, preatt, preatt_t);
+  toGPU(ctx, att, att_t);
+
   std::promise<void> promise;
   std::future<void> future = promise.get_future();
-  Kernel op = createKernel(ctx, {kShaderAttention, 256, kf32},
-                           Bindings{inp_t, preatt_t, att_t, out_t},
-                           /* nWorkgroups */ {cdiv(b * t, 256), 1, 1},
-                           /* params */
-                           AttentionParams{
-                             static_cast<uint32_t>(b),
-                             static_cast<uint32_t>(t),
-                             static_cast<uint32_t>(c),
-                             static_cast<uint32_t>(nh)
-                           });
   dispatchKernel(ctx, op, promise);
   wait(ctx, future);
   toCPU(ctx, preatt_t, preatt, b * nh * t * t * sizeof(float));
@@ -283,24 +479,47 @@ void attention_backward(Context& ctx, float* dinp, float* dpreatt, float* datt,
   unsigned long c = static_cast<unsigned long>(C);
   unsigned long nh = static_cast<unsigned long>(NH);
   setLogLevel(kError);
-  Tensor dinp_t = createTensor(ctx, Shape{b * t * c * 3}, kf32, dinp);
-  Tensor dpreatt_t = createTensor(ctx, Shape{b * nh * t * t}, kf32, dpreatt);
-  Tensor datt_t = createTensor(ctx, Shape{b * nh * t * t}, kf32, datt);
-  Tensor dout_t = createTensor(ctx, Shape{b * t * c}, kf32, dout);
-  Tensor inp_t = createTensor(ctx, Shape{b * t * c * 3}, kf32, inp);
-  Tensor att_t = createTensor(ctx, Shape{b * nh * t * t}, kf32, att);
+  // Generate the cache key from the arguments.
+  std::string key = "attention_backward_" + std::to_string(B) + "_" + std::to_string(T) + "_" + std::to_string(C) + "_" + std::to_string(NH);
+  Kernel op;
+  if (ctx.kernelPool.data.find(key) == ctx.kernelPool.data.end()) {
+    Tensor dinp_t = createTensor(ctx, Shape{b * t * c * 3}, kf32);
+    Tensor dpreatt_t = createTensor(ctx, Shape{b * nh * t * t}, kf32);
+    Tensor datt_t = createTensor(ctx, Shape{b * nh * t * t}, kf32);
+    Tensor dout_t = createTensor(ctx, Shape{b * t * c}, kf32);
+    Tensor inp_t = createTensor(ctx, Shape{b * t * c * 3}, kf32);
+    Tensor att_t = createTensor(ctx, Shape{b * nh * t * t}, kf32);
+    op = createKernel(ctx, {kShaderAttentionBackward, 256, kf32},
+                      Bindings{dinp_t, dpreatt_t, datt_t, dout_t, inp_t, att_t},
+                      /* nWorkgroups */ {cdiv(b * t, 256), 1, 1},
+                      /* params */
+                      AttentionParams{
+                        static_cast<uint32_t>(b),
+                        static_cast<uint32_t>(t),
+                        static_cast<uint32_t>(c),
+                        static_cast<uint32_t>(nh)
+                      },
+                      nullptr,
+                      key.c_str());
+  } else {
+    op = ctx.kernelPool.data[key];
+  }
+  Tensor& dinp_t = ctx.pool.data[op->buffers[0]];
+  Tensor& dpreatt_t = ctx.pool.data[op->buffers[1]];
+  Tensor& datt_t = ctx.pool.data[op->buffers[2]];
+  Tensor& dout_t = ctx.pool.data[op->buffers[3]];
+  Tensor& inp_t = ctx.pool.data[op->buffers[4]];
+  Tensor& att_t = ctx.pool.data[op->buffers[5]];
+
+  toGPU(ctx, dinp, dinp_t);
+  toGPU(ctx, dpreatt, dpreatt_t);
+  toGPU(ctx, datt, datt_t);
+  toGPU(ctx, dout, dout_t);
+  toGPU(ctx, inp, inp_t);
+  toGPU(ctx, att, att_t);
+
   std::promise<void> promise;
   std::future<void> future = promise.get_future();
-  Kernel op = createKernel(ctx, {kShaderAttentionBackward, 256, kf32},
-                           Bindings{dinp_t, dpreatt_t, datt_t, dout_t, inp_t, att_t},
-                           /* nWorkgroups */ {cdiv(b * t, 256), 1, 1},
-                           /* params */
-                           AttentionParams{
-                             static_cast<uint32_t>(b),
-                             static_cast<uint32_t>(t),
-                             static_cast<uint32_t>(c),
-                             static_cast<uint32_t>(nh)
-                           });
   dispatchKernel(ctx, op, promise);
   wait(ctx, future);
   toCPU(ctx, dinp_t, dinp, b * t * c * 3 * sizeof(float));
@@ -311,13 +530,28 @@ void attention_backward(Context& ctx, float* dinp, float* dpreatt, float* datt,
 void gelu_forward(Context& ctx, float* out, float* inp, int n) {
   unsigned long N = static_cast<unsigned long>(n);
   setLogLevel(kError);
-  Tensor input = createTensor(ctx, Shape{N}, kf32, inp);
-  Tensor output = createTensor(ctx, Shape{N}, kf32);
+  // Generate the cache key from the arguments.
+  std::string key = "gelu_forward_" + std::to_string(n);
+  Kernel op;
+  if (ctx.kernelPool.data.find(key) == ctx.kernelPool.data.end()) {
+    Tensor input = createTensor(ctx, Shape{N}, kf32);
+    Tensor output = createTensor(ctx, Shape{N}, kf32);
+    op = createKernel(ctx, {kShaderGelu, 256, kf32},
+                      Bindings{input, output},
+                      /* nWorkgroups */ {cdiv(N, 256), 1, 1},
+                      nullptr,
+                      nullptr,
+                      key.c_str());
+  } else {
+    op = ctx.kernelPool.data[key];
+  }
+  Tensor& input = ctx.pool.data[op->buffers[0]];
+  Tensor& output = ctx.pool.data[op->buffers[1]];
+
+  toGPU(ctx, inp, input);
+
   std::promise<void> promise;
   std::future<void> future = promise.get_future();
-  Kernel op = createKernel(ctx, {kShaderGelu, 256, kf32},
-                           Bindings{input, output},
-                           /* nWorkgroups */ {cdiv(N, 256), 1, 1});
   dispatchKernel(ctx, op, promise);
   wait(ctx, future);
   toCPU(ctx, output, out, N * sizeof(float));
@@ -326,14 +560,32 @@ void gelu_forward(Context& ctx, float* out, float* inp, int n) {
 void gelu_backward(Context& ctx, float* dinp, float* inp, float* dout, int N){
   unsigned long n = static_cast<unsigned long>(N);
   setLogLevel(kError);
-  Tensor inp_i = createTensor(ctx, Shape{n}, kf32, inp);
-  Tensor dout_i = createTensor(ctx, Shape{n}, kf32, dout);
-  Tensor dinp_o = createTensor(ctx, Shape{n}, kf32, dinp);
+  // Generate the cache key from the arguments.
+  std::string key = "gelu_backward_" + std::to_string(N);
+  Kernel op;
+  if (ctx.kernelPool.data.find(key) == ctx.kernelPool.data.end()) {
+    Tensor inp_i = createTensor(ctx, Shape{n}, kf32);
+    Tensor dout_i = createTensor(ctx, Shape{n}, kf32);
+    Tensor dinp_o = createTensor(ctx, Shape{n}, kf32);
+    op = createKernel(ctx, {kShaderGeluBackward, 256, kf32},
+                      Bindings{inp_i, dout_i, dinp_o},
+                      /* nWorkgroups */ {cdiv(n, 256), 1, 1},
+                      nullptr,
+                      nullptr,
+                      key.c_str());
+  } else {
+    op = ctx.kernelPool.data[key];
+  }
+  Tensor& inp_i = ctx.pool.data[op->buffers[0]];
+  Tensor& dout_i = ctx.pool.data[op->buffers[1]];
+  Tensor& dinp_o = ctx.pool.data[op->buffers[2]];
+
+  toGPU(ctx, inp, inp_i);
+  toGPU(ctx, dout, dout_i);
+  toGPU(ctx, dinp, dinp_o);
+
   std::promise<void> promise;
   std::future<void> future = promise.get_future();
-  Kernel op = createKernel(ctx, {kShaderGeluBackward, 256, kf32},
-                           Bindings{inp_i, dout_i, dinp_o},
-                           /* nWorkgroups */ {cdiv(n, 256), 1, 1});
   dispatchKernel(ctx, op, promise);
   wait(ctx, future);
   toCPU(ctx, dinp_o, dinp, n * sizeof(float));
@@ -342,14 +594,31 @@ void gelu_backward(Context& ctx, float* dinp, float* inp, float* dout, int N){
 void residual_forward(Context& ctx, float* out, float* inp1, float* inp2, int N){
   unsigned long n = static_cast<unsigned long>(N);
   setLogLevel(kError);
-  Tensor inp1_i = createTensor(ctx, Shape{n}, kf32, inp1);
-  Tensor inp2_i = createTensor(ctx, Shape{n}, kf32, inp2);
-  Tensor out_o = createTensor(ctx, Shape{n}, kf32);
+  // Generate the cache key from the arguments.
+  std::string key = "residual_forward_" + std::to_string(N);
+  Kernel op;
+  if (ctx.kernelPool.data.find(key) == ctx.kernelPool.data.end()) {
+    Tensor inp1_i = createTensor(ctx, Shape{n}, kf32);
+    Tensor inp2_i = createTensor(ctx, Shape{n}, kf32);
+    Tensor out_o = createTensor(ctx, Shape{n}, kf32);
+    op = createKernel(ctx, {kShaderResidual, 256, kf32},
+                      Bindings{inp1_i, inp2_i, out_o},
+                      /* nWorkgroups */ {cdiv(n, 256), 1, 1},
+                      nullptr,
+                      nullptr,
+                      key.c_str());
+  } else {
+    op = ctx.kernelPool.data[key];
+  }
+  Tensor& inp1_i = ctx.pool.data[op->buffers[0]];
+  Tensor& inp2_i = ctx.pool.data[op->buffers[1]];
+  Tensor& out_o = ctx.pool.data[op->buffers[2]];
+
+  toGPU(ctx, inp1, inp1_i);
+  toGPU(ctx, inp2, inp2_i);
+
   std::promise<void> promise;
   std::future<void> future = promise.get_future();
-  Kernel op = createKernel(ctx, {kShaderResidual, 256, kf32},
-                           Bindings{inp1_i, inp2_i, out_o},
-                           /* nWorkgroups */ {cdiv(n, 256), 1, 1});
   dispatchKernel(ctx, op, promise);
   wait(ctx, future);
   toCPU(ctx, out_o, out, n * sizeof(float));
@@ -358,14 +627,32 @@ void residual_forward(Context& ctx, float* out, float* inp1, float* inp2, int N)
 void residual_backward(Context& ctx, float* dinp1, float* dinp2, float* dout, int N){
   unsigned long n = static_cast<unsigned long>(N);
   setLogLevel(kError);
-  Tensor dout_i = createTensor(ctx, Shape{n}, kf32, dout);
-  Tensor dinp1_o = createTensor(ctx, Shape{n}, kf32, dinp1);
-  Tensor dinp2_o = createTensor(ctx, Shape{n}, kf32, dinp2);
+  // Generate the cache key from the arguments.
+  std::string key = "residual_backward_" + std::to_string(N);
+  Kernel op;
+  if (ctx.kernelPool.data.find(key) == ctx.kernelPool.data.end()) {
+    Tensor dout_i = createTensor(ctx, Shape{n}, kf32);
+    Tensor dinp1_o = createTensor(ctx, Shape{n}, kf32);
+    Tensor dinp2_o = createTensor(ctx, Shape{n}, kf32);
+    op = createKernel(ctx, {kShaderResidualBackward, 256, kf32},
+                      Bindings{dout_i, dinp1_o, dinp2_o},
+                      /* nWorkgroups */ {cdiv(n, 256), 1, 1},
+                      nullptr,
+                      nullptr,
+                      key.c_str());
+  } else {
+    op = ctx.kernelPool.data[key];
+  }
+  Tensor& dout_i = ctx.pool.data[op->buffers[0]];
+  Tensor& dinp1_o = ctx.pool.data[op->buffers[1]];
+  Tensor& dinp2_o = ctx.pool.data[op->buffers[2]];
+
+  toGPU(ctx, dout, dout_i);
+  toGPU(ctx, dinp1, dinp1_o);
+  toGPU(ctx, dinp2, dinp2_o);
+
   std::promise<void> promise;
   std::future<void> future = promise.get_future();
-  Kernel op = createKernel(ctx, {kShaderResidualBackward, 256, kf32},
-                           Bindings{dout_i, dinp1_o, dinp2_o},
-                           /* nWorkgroups */ {cdiv(n, 256), 1, 1});
   dispatchKernel(ctx, op, promise);
   wait(ctx, future);
   toCPU(ctx, dinp1_o, dinp1, n * sizeof(float));
@@ -382,14 +669,28 @@ void softmax_forward(Context& ctx, float* probs, float* logits, int B, int T, in
   uint32_t t = static_cast<uint32_t>(T);
   uint32_t c = static_cast<uint32_t>(V);
   uint32_t cp = static_cast<uint32_t>(Vp);
-  Tensor input = createTensor(ctx, {b * t, cp}, kf32, logits);
-  Tensor output = createTensor(ctx, {b * t, cp}, kf32);
+  // Generate the cache key from the arguments.
+  std::string key = "softmax_forward_" + std::to_string(B) + "_" + std::to_string(T) + "_" + std::to_string(V) + "_" + std::to_string(Vp);
+  Kernel op;
+  if (ctx.kernelPool.data.find(key) == ctx.kernelPool.data.end()) {
+    Tensor input = createTensor(ctx, {b * t, cp}, kf32);
+    Tensor output = createTensor(ctx, {b * t, cp}, kf32);
+    assert( (B*T) % 256 == 0);
+    op = createKernel(
+        ctx, {kShaderSoftmax1, 256, kf32}, Bindings{input, output},
+        Shape{cdiv(B * T, 256), 1, 1}, SoftmaxParam{b * t, c, cp},
+        nullptr,
+        key.c_str());
+  } else {
+    op = ctx.kernelPool.data[key];
+  }
+  Tensor& input = ctx.pool.data[op->buffers[0]];
+  Tensor& output = ctx.pool.data[op->buffers[1]];
+
+  toGPU(ctx, logits, input);
+
   std::promise<void> promise;
   std::future<void> future = promise.get_future();
-  assert( (B*T) % 256 == 0);
-  Kernel op = createKernel(
-      ctx, {kShaderSoftmax1, 256, kf32}, Bindings{input, output},
-      Shape{cdiv(B * T, 256), 1, 1}, SoftmaxParam{b * t, c, cp});
   dispatchKernel(ctx, op, promise);
   wait(ctx, future);
   toCPU(ctx, output, probs, sizeof(float)*b*t*cp);
@@ -407,20 +708,37 @@ void crossentropy_forward(Context& ctx, float* losses,
   unsigned long t = static_cast<unsigned long>(T);
   unsigned long vp = static_cast<unsigned long>(Vp);
   setLogLevel(kError);
-  Tensor losses_t = createTensor(ctx, Shape{b * t}, kf32, losses);
-  Tensor probs_t = createTensor(ctx, Shape{b * t * vp}, kf32, probs);
-  Tensor targets_t = createTensor(ctx, Shape{b * t}, ki32, targets);
+  // Generate the cache key from the arguments.
+  std::string key = "crossentropy_forward_" + std::to_string(B) + "_" + std::to_string(T) + "_" + std::to_string(Vp);
+  Kernel op;
+  if (ctx.kernelPool.data.find(key) == ctx.kernelPool.data.end()) {
+    Tensor losses_t = createTensor(ctx, Shape{b * t}, kf32);
+    Tensor probs_t = createTensor(ctx, Shape{b * t * vp}, kf32);
+    Tensor targets_t = createTensor(ctx, Shape{b * t}, ki32);
+    op = createKernel(ctx, {kShaderCrossEntropyForward, 256, kf32},
+                      Bindings{losses_t, probs_t, targets_t},
+                      /* nWorkgroups */ {cdiv(b * t, 256), 1, 1},
+                      /* params */
+                      CrossEntropyParams{
+                        static_cast<uint32_t>(b),
+                        static_cast<uint32_t>(t),
+                        static_cast<uint32_t>(vp)
+                      },
+                      nullptr,
+                      key.c_str());
+  } else {
+    op = ctx.kernelPool.data[key];
+  }
+  Tensor& losses_t = ctx.pool.data[op->buffers[0]];
+  Tensor& probs_t = ctx.pool.data[op->buffers[1]];
+  Tensor& targets_t = ctx.pool.data[op->buffers[2]];
+
+  toGPU(ctx, losses, losses_t);
+  toGPU(ctx, probs, probs_t);
+  toGPU(ctx, targets, targets_t);
+
   std::promise<void> promise;
   std::future<void> future = promise.get_future();
-  Kernel op = createKernel(ctx, {kShaderCrossEntropyForward, 256, kf32},
-                           Bindings{losses_t, probs_t, targets_t},
-                           /* nWorkgroups */ {cdiv(b * t, 256), 1, 1},
-                           /* params */
-                           CrossEntropyParams{
-                             static_cast<uint32_t>(b),
-                             static_cast<uint32_t>(t),
-                             static_cast<uint32_t>(vp)
-                           });
   dispatchKernel(ctx, op, promise);
   wait(ctx, future);
   toCPU(ctx, losses_t, losses, b * t * sizeof(float));
@@ -440,22 +758,41 @@ void crossentropy_softmax_backward(Context& ctx, float* dlogits,
   unsigned long v = static_cast<unsigned long>(V);
   unsigned long vp = static_cast<unsigned long>(Vp);
   setLogLevel(kError);
-  Tensor dlogits_t = createTensor(ctx, Shape{b * t * vp}, kf32, dlogits);
-  Tensor dlosses_t = createTensor(ctx, Shape{b * t}, kf32, dlosses);
-  Tensor probs_t = createTensor(ctx, Shape{b * t * vp}, kf32, probs);
-  Tensor targets_t = createTensor(ctx, Shape{b * t}, ki32, targets);
+  // Generate the cache key from the arguments.
+  std::string key = "crossentropy_softmax_backward_" + std::to_string(B) + "_" + std::to_string(T) + "_" + std::to_string(V) + "_" + std::to_string(Vp);
+  Kernel op;
+  if (ctx.kernelPool.data.find(key) == ctx.kernelPool.data.end()) {
+    Tensor dlogits_t = createTensor(ctx, Shape{b * t * vp}, kf32);
+    Tensor dlosses_t = createTensor(ctx, Shape{b * t}, kf32);
+    Tensor probs_t = createTensor(ctx, Shape{b * t * vp}, kf32);
+    Tensor targets_t = createTensor(ctx, Shape{b * t}, ki32);
+    op = createKernel(ctx, {kShaderCrossEntropySoftmaxBackward, 256, kf32},
+                      Bindings{dlogits_t, dlosses_t, probs_t, targets_t},
+                      /* nWorkgroups */ {cdiv(b * t, 256), 1, 1},
+                      /* params */
+                      CrossEntropySoftmaxBackwardParams{
+                        static_cast<uint32_t>(b),
+                        static_cast<uint32_t>(t),
+                        static_cast<uint32_t>(v),
+                        static_cast<uint32_t>(vp)
+                      },
+                      nullptr,
+                      key.c_str());
+  } else {
+    op = ctx.kernelPool.data[key];
+  }
+  Tensor& dlogits_t = ctx.pool.data[op->buffers[0]];
+  Tensor& dlosses_t = ctx.pool.data[op->buffers[1]];
+  Tensor& probs_t = ctx.pool.data[op->buffers[2]];
+  Tensor& targets_t = ctx.pool.data[op->buffers[3]];
+
+  toGPU(ctx, dlogits, dlogits_t);
+  toGPU(ctx, dlosses, dlosses_t);
+  toGPU(ctx, probs, probs_t);
+  toGPU(ctx, targets, targets_t);
+
   std::promise<void> promise;
   std::future<void> future = promise.get_future();
-  Kernel op = createKernel(ctx, {kShaderCrossEntropySoftmaxBackward, 256, kf32},
-                           Bindings{dlogits_t, dlosses_t, probs_t, targets_t},
-                           /* nWorkgroups */ {cdiv(b * t, 256), 1, 1},
-                           /* params */
-                           CrossEntropySoftmaxBackwardParams{
-                             static_cast<uint32_t>(b),
-                             static_cast<uint32_t>(t),
-                             static_cast<uint32_t>(v),
-                             static_cast<uint32_t>(vp)
-                           });
   dispatchKernel(ctx, op, promise);
   wait(ctx, future);
   toCPU(ctx, dlogits_t, dlogits, b * t * vp * sizeof(float));
diff --git a/experimental/kernels/unittest_llmc/unittest_kernels.cpp b/experimental/kernels/unittest_llmc/unittest_kernels.cpp
index 37cdcaf..d037eac 100644
--- a/experimental/kernels/unittest_llmc/unittest_kernels.cpp
+++ b/experimental/kernels/unittest_llmc/unittest_kernels.cpp
@@ -2,9 +2,11 @@
 #include <array>
 #include <cstdio>
 #include <future>
+#include <map>
 
 #include "kernels.h"
 #include "unittest_llmc/unittest_kernels.h"
+#include "experimental/wgsl.h"      // loopUnrolling
 
 using namespace gpu; // createContext, createTensor, createKernel,
                      // createShader, dispatchKernel, wait, toCPU
@@ -51,6 +53,33 @@ using namespace gpu; // createContext, createTensor, createKernel,
     } \
   }
 
+struct DurationTime {
+  std::chrono::high_resolution_clock::time_point start;
+  std::chrono::high_resolution_clock::time_point end;
+  std::chrono::microseconds duration;
+  std::string src;
+  bool verbose;
+  
+  inline DurationTime(const std::string& src, bool verbose = true) {
+    this->src = src;
+    this->verbose = verbose;
+    start = std::chrono::high_resolution_clock::now();
+  }
+
+  inline ~DurationTime() {
+    end = std::chrono::high_resolution_clock::now();
+    duration = std::chrono::duration_cast<std::chrono::microseconds>(end - start);
+    if (this->verbose) {
+      printf("Duration(%s): %.1f microseconds\n", src.c_str(), static_cast<double>(duration.count()));
+    }
+  }
+};
+
+static WGPURequiredLimits requiredLimits = LIMITS_BUFFER_SIZE_1GB;
+static Context ctx = createContext({},{},{
+    .requiredLimits = &requiredLimits
+  });
+
 void ENCODER_FORWARD_GPU(float* out,
                          int* inp, float* wte, float* wpe,
                          int B, int T, int C){
@@ -64,25 +93,40 @@ void ENCODER_FORWARD_GPU(float* out,
     uint32_t C;
   };
   setLogLevel(kError);
-  WGPURequiredLimits requiredLimits = LIMITS_BUFFER_SIZE_1GB;
-  Context ctx = createContext({},{},{
-      .requiredLimits = &requiredLimits
-    });
-  Tensor input = createTensor(ctx, Shape{b * t}, ki32, inp);
-  Tensor wte_t = createTensor(ctx, Shape{v, c}, kf32, wte);
-  Tensor wpe_t = createTensor(ctx, Shape{t, c}, kf32, wpe);
-  Tensor output = createTensor(ctx, Shape{b * t * c}, kf32);
+
+  // Generate the cache key from the arguments.
+  std::string key = "ENCODER_FORWARD_GPU_" + std::to_string(B) + "_" + std::to_string(T) + "_" + std::to_string(C);
+  Kernel op;
+  if (ctx.kernelPool.data.find(key) == ctx.kernelPool.data.end()) {
+    Tensor input = createTensor(ctx, Shape{b * t}, ki32);
+    Tensor wte_t = createTensor(ctx, Shape{v, c}, kf32);
+    Tensor wpe_t = createTensor(ctx, Shape{t, c}, kf32);
+    Tensor output = createTensor(ctx, Shape{b * t * c}, kf32);
+    op = createKernel(ctx, {kShaderEncoder, 256, kf32},
+                      Bindings{input, wte_t, wpe_t, output},
+                      /* nWorkgroups */ {cdiv(b * t, 256), 1, 1},
+                      /* params */
+                      EncoderParams{
+                        static_cast<uint32_t>(b),
+                        static_cast<uint32_t>(t),
+                        static_cast<uint32_t>(c)
+                      },
+                      nullptr,
+                      key.c_str());
+  } else {
+    op = ctx.kernelPool.data[key];
+  }
+  Tensor& input = ctx.pool.data[op->buffers[0]];
+  Tensor& wte_t = ctx.pool.data[op->buffers[1]];
+  Tensor& wpe_t = ctx.pool.data[op->buffers[2]];
+  Tensor& output = ctx.pool.data[op->buffers[3]];
+
+  toGPU(ctx, inp, input);
+  toGPU(ctx, wte, wte_t);
+  toGPU(ctx, wpe, wpe_t);
+  
   std::promise<void> promise;
   std::future<void> future = promise.get_future();
-  Kernel op = createKernel(ctx, {kShaderEncoder, 256, kf32},
-                           Bindings{input, wte_t, wpe_t, output},
-                           /* nWorkgroups */ {cdiv(b * t, 256), 1, 1},
-                           /* params */
-                           EncoderParams{
-                             static_cast<uint32_t>(b),
-                             static_cast<uint32_t>(t),
-                             static_cast<uint32_t>(c)
-                           });
   dispatchKernel(ctx, op, promise);
   wait(ctx, future);
   toCPU(ctx, output, out, b * t * c * sizeof(float));
@@ -101,25 +145,41 @@ void ENCODER_BACKWARD_GPU(float* dwte, float* dwpe,
     uint32_t C;
   };
   setLogLevel(kError);
-  WGPURequiredLimits requiredLimits = LIMITS_BUFFER_SIZE_1GB;
-  Context ctx = createContext({},{},{
-      .requiredLimits = &requiredLimits
-    });
-  Tensor dwte_t = createTensor(ctx, Shape{v, c}, kf32, dwte);
-  Tensor dwpe_t = createTensor(ctx, Shape{t, c}, kf32, dwpe);
-  Tensor dout_t = createTensor(ctx, Shape{b * t * c}, kf32, dout);
-  Tensor input = createTensor(ctx, Shape{b * t}, ki32, inp);
+
+  // Generate the cache key from the arguments.
+  std::string key = "ENCODER_BACKWARD_GPU_" + std::to_string(B) + "_" + std::to_string(T) + "_" + std::to_string(C);
+  Kernel op;
+  if (ctx.kernelPool.data.find(key) == ctx.kernelPool.data.end()) {
+    Tensor dwte_t = createTensor(ctx, Shape{v, c}, kf32);
+    Tensor dwpe_t = createTensor(ctx, Shape{t, c}, kf32);
+    Tensor dout_t = createTensor(ctx, Shape{b * t * c}, kf32);
+    Tensor input = createTensor(ctx, Shape{b * t}, ki32);
+    op = createKernel(ctx, {kShaderEncoderBackward, 256, kf32},
+                      Bindings{dwte_t, dwpe_t, dout_t, input},
+                      /* nWorkgroups */ {cdiv(b * t, 256), 1, 1},
+                      /* params */
+                      EncoderParams{
+                        static_cast<uint32_t>(b),
+                        static_cast<uint32_t>(t),
+                        static_cast<uint32_t>(c)
+                      },
+                      nullptr,
+                      key.c_str());
+  } else {
+    op = ctx.kernelPool.data[key];
+  }
+  Tensor& dwte_t = ctx.pool.data[op->buffers[0]];
+  Tensor& dwpe_t = ctx.pool.data[op->buffers[1]];
+  Tensor& dout_t = ctx.pool.data[op->buffers[2]];
+  Tensor& input = ctx.pool.data[op->buffers[3]];
+
+  toGPU(ctx, dwte, dwte_t);
+  toGPU(ctx, dwpe, dwpe_t);
+  toGPU(ctx, dout, dout_t);
+  toGPU(ctx, inp, input);
+
   std::promise<void> promise;
   std::future<void> future = promise.get_future();
-  Kernel op = createKernel(ctx, {kShaderEncoderBackward, 256, kf32},
-                           Bindings{dwte_t, dwpe_t, dout_t, input},
-                           /* nWorkgroups */ {cdiv(b * t, 256), 1, 1},
-                           /* params */
-                           EncoderParams{
-                             static_cast<uint32_t>(b),
-                             static_cast<uint32_t>(t),
-                             static_cast<uint32_t>(c)
-                           });
   dispatchKernel(ctx, op, promise);
   wait(ctx, future);
   toCPU(ctx, dwte_t, dwte, v * c * sizeof(float));
@@ -138,24 +198,44 @@ void LAYERNORM_FORWARD_GPU(float* out, float* mean, float* rstd,
     uint32_t C;
   };
   setLogLevel(kError);
-  Context ctx = createContext();
-  Tensor inp_t = createTensor(ctx, Shape{b * t * c}, kf32, inp);
-  Tensor weight_t = createTensor(ctx, Shape{c}, kf32, weight);
-  Tensor bias_t = createTensor(ctx, Shape{c}, kf32, bias);
-  Tensor out_t = createTensor(ctx, Shape{b * t * c}, kf32);
-  Tensor mean_t = createTensor(ctx, Shape{b * t}, kf32);
-  Tensor rstd_t = createTensor(ctx, Shape{b * t}, kf32);
+
+  // Generate the cache key from the arguments.
+  std::string key = "LAYERNORM_FORWARD_GPU_" + std::to_string(B) + "_" + std::to_string(T) + "_" + std::to_string(C);
+  Kernel op;
+  if (ctx.kernelPool.data.find(key) == ctx.kernelPool.data.end()) {
+    Tensor inp_t = createTensor(ctx, Shape{b * t * c}, kf32);
+    Tensor weight_t = createTensor(ctx, Shape{c}, kf32);
+    Tensor bias_t = createTensor(ctx, Shape{c}, kf32);
+    Tensor out_t = createTensor(ctx, Shape{b * t * c}, kf32);
+    Tensor mean_t = createTensor(ctx, Shape{b * t}, kf32);
+    Tensor rstd_t = createTensor(ctx, Shape{b * t}, kf32);
+    op = createKernel(ctx, {kShaderLayerNorm, 256, kf32},
+                      Bindings{inp_t, weight_t, bias_t, out_t, mean_t, rstd_t},
+                      /* nWorkgroups */ {cdiv(b * t, 256), 1, 1},
+                      /* params */
+                      LayerNormParams{
+                        static_cast<uint32_t>(b),
+                        static_cast<uint32_t>(t),
+                        static_cast<uint32_t>(c)
+                      },
+                      nullptr,
+                      key.c_str());
+  } else {
+    op = ctx.kernelPool.data[key];
+  }
+  Tensor& inp_t = ctx.pool.data[op->buffers[0]];
+  Tensor& weight_t = ctx.pool.data[op->buffers[1]];
+  Tensor& bias_t = ctx.pool.data[op->buffers[2]];
+  Tensor& out_t = ctx.pool.data[op->buffers[3]];
+  Tensor& mean_t = ctx.pool.data[op->buffers[4]];
+  Tensor& rstd_t = ctx.pool.data[op->buffers[5]];
+
+  toGPU(ctx, inp, inp_t);
+  toGPU(ctx, weight, weight_t);
+  toGPU(ctx, bias, bias_t);
+
   std::promise<void> promise;
   std::future<void> future = promise.get_future();
-  Kernel op = createKernel(ctx, {kShaderLayerNorm, 256, kf32},
-                           Bindings{inp_t, weight_t, bias_t, out_t, mean_t, rstd_t},
-                           /* nWorkgroups */ {cdiv(b * t, 256), 1, 1},
-                           /* params */
-                           LayerNormParams{
-                             static_cast<uint32_t>(b),
-                             static_cast<uint32_t>(t),
-                             static_cast<uint32_t>(c)
-                           });
   dispatchKernel(ctx, op, promise);
   wait(ctx, future);
   toCPU(ctx, out_t, out, b * t * c * sizeof(float));
@@ -175,26 +255,53 @@ void LAYERNORM_BACKWARD_GPU(float* dinp, float* dweight, float* dbias,
     uint32_t C;
   };
   setLogLevel(kError);
-  Context ctx = createContext();
-  Tensor dinp_t = createTensor(ctx, Shape{b * t * c}, kf32, dinp);
-  Tensor dweight_t = createTensor(ctx, Shape{c}, kf32, dweight);
-  Tensor dbias_t = createTensor(ctx, Shape{c}, kf32, dbias);
-  Tensor dout_t = createTensor(ctx, Shape{b * t * c}, kf32, dout);
-  Tensor inp_t = createTensor(ctx, Shape{b * t * c}, kf32, inp);
-  Tensor weight_t = createTensor(ctx, Shape{c}, kf32, weight);
-  Tensor mean_t = createTensor(ctx, Shape{b * t}, kf32, mean);
-  Tensor rstd_t = createTensor(ctx, Shape{b * t}, kf32, rstd);
+
+  // Generate the cache key from the arguments.
+  std::string key = "LAYERNORM_BACKWARD_GPU_" + std::to_string(B) + "_" + std::to_string(T) + "_" + std::to_string(C);
+  Kernel op;
+  if (ctx.kernelPool.data.find(key) == ctx.kernelPool.data.end()) {
+    Tensor dinp_t = createTensor(ctx, Shape{b * t * c}, kf32);
+    Tensor dweight_t = createTensor(ctx, Shape{c}, kf32);
+    Tensor dbias_t = createTensor(ctx, Shape{c}, kf32);
+    Tensor dout_t = createTensor(ctx, Shape{b * t * c}, kf32);
+    Tensor inp_t = createTensor(ctx, Shape{b * t * c}, kf32);
+    Tensor weight_t = createTensor(ctx, Shape{c}, kf32);
+    Tensor mean_t = createTensor(ctx, Shape{b * t}, kf32);
+    Tensor rstd_t = createTensor(ctx, Shape{b * t}, kf32);
+    op = createKernel(ctx, {kShaderLayerNormBackward, 256, kf32},
+                      Bindings{dinp_t, dweight_t, dbias_t, dout_t, inp_t, weight_t, mean_t, rstd_t},
+                      /* nWorkgroups */ {cdiv(b * t, 256), 1, 1},
+                      /* params */
+                      LayerNormParams{
+                        static_cast<uint32_t>(b),
+                        static_cast<uint32_t>(t),
+                        static_cast<uint32_t>(c)
+                      },
+                      nullptr,
+                      key.c_str());
+  } else {
+    op = ctx.kernelPool.data[key];
+  }
+  Tensor& dinp_t = ctx.pool.data[op->buffers[0]];
+  Tensor& dweight_t = ctx.pool.data[op->buffers[1]];
+  Tensor& dbias_t = ctx.pool.data[op->buffers[2]];
+  Tensor& dout_t = ctx.pool.data[op->buffers[3]];
+  Tensor& inp_t = ctx.pool.data[op->buffers[4]];
+  Tensor& weight_t = ctx.pool.data[op->buffers[5]];
+  Tensor& mean_t = ctx.pool.data[op->buffers[6]];
+  Tensor& rstd_t = ctx.pool.data[op->buffers[7]];
+
+  toGPU(ctx, dinp, dinp_t);
+  toGPU(ctx, dweight, dweight_t);
+  toGPU(ctx, dbias, dbias_t);
+  toGPU(ctx, dout, dout_t);
+  toGPU(ctx, inp, inp_t);
+  toGPU(ctx, weight, weight_t);
+  toGPU(ctx, mean, mean_t);
+  toGPU(ctx, rstd, rstd_t);
+
   std::promise<void> promise;
   std::future<void> future = promise.get_future();
-  Kernel op = createKernel(ctx, {kShaderLayerNormBackward, 256, kf32},
-                           Bindings{dinp_t, dweight_t, dbias_t, dout_t, inp_t, weight_t, mean_t, rstd_t},
-                           /* nWorkgroups */ {cdiv(b * t, 256), 1, 1},
-                           /* params */
-                           LayerNormParams{
-                             static_cast<uint32_t>(b),
-                             static_cast<uint32_t>(t),
-                             static_cast<uint32_t>(c)
-                           });
   dispatchKernel(ctx, op, promise);
   wait(ctx, future);
   toCPU(ctx, dinp_t, dinp, b * t * c * sizeof(float));
@@ -202,9 +309,29 @@ void LAYERNORM_BACKWARD_GPU(float* dinp, float* dweight, float* dbias,
   toCPU(ctx, dbias_t, dbias, c * sizeof(float));
 }
 
+void matmul_forward_dummy(float* out,
+                          const float* inp, const float* weight, const float* bias,
+                          int B, int T, int C, int OC);
+
+
 void MATMUL_FORWARD_GPU(float* out,
                         const float* inp, const float* weight, const float* bias,
                         int B, int T, int C, int OC){
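+  // version selects the execution path: 2 = 2D block-tiled WGSL kernel,
+  // 1 = naive per-output kernel, anything else = CPU reference (matmul_forward_dummy).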
+  int version = 2;
+  bool verbose = false;
+  bool debug = false;
+  float *out_exp = nullptr;
+  DurationTime duration("matmul_forward_gpu with preparing a kernel", verbose);
+  if (verbose) {
+    printf("matmul forward: B=%d, T=%d, C=%d, OC=%d, bias=%d\n", B, T, C, OC, bias != NULL);
+  }
+  if (debug) {
+    out_exp = new float[B*T*OC];
+    {
+      DurationTime duration("matmul_forward_cpu", verbose);
+      matmul_forward_dummy(out_exp, inp, weight, bias, B, T, C, OC);
+    }
+  }
   struct MatmulParams {
     uint32_t B;
     uint32_t T;
@@ -216,31 +343,132 @@ void MATMUL_FORWARD_GPU(float* out,
   unsigned long c = static_cast<unsigned long>(C);
   unsigned long oc = static_cast<unsigned long>(OC);
   setLogLevel(kError);
-  WGPURequiredLimits requiredLimits = LIMITS_BUFFER_SIZE_1GB;
-  Context ctx = createContext({},{},{
-      .requiredLimits = &requiredLimits
-    });
-
-  Tensor inp_i = createTensor(ctx, Shape{b * t * c}, kf32, inp);
-  Tensor weight_i = createTensor(ctx, Shape{oc * c}, kf32, weight);
-  Tensor bias_i = bias == NULL ? createTensor(ctx, Shape{1}, kf32) : createTensor(ctx, Shape{oc}, kf32, bias);
-  Tensor out_o = createTensor(ctx, Shape{b * t * oc}, kf32);
-  std::promise<void> promise;
-  std::future<void> future = promise.get_future();
-  assert ( (b*t) % 256 == 0 );
-  Kernel op = createKernel(ctx, {kShaderMatmul, 256, kf32},
-                           Bindings{inp_i, weight_i, bias_i, out_o},
-                           /* nWorkgroups */ {cdiv(b * t, 256), 1, 1},
-                           /* params */
-                           MatmulParams{
-                             static_cast<uint32_t>(b),
-                             static_cast<uint32_t>(t),
-                             static_cast<uint32_t>(c),
-                             static_cast<uint32_t>(oc)
-                           });
-  dispatchKernel(ctx, op, promise);
-  wait(ctx, future);
-  toCPU(ctx, out_o, out, b * t * oc * sizeof(float));
+
+  if (version == 2 || version == 1) {
+    // Generate the cache key from the arguments.
+    std::string key = "MATMUL_FORWARD_GPU_" + std::to_string(B) + "_" + std::to_string(T) + "_" + std::to_string(C) + "_" + std::to_string(OC);
+    Kernel op;
+    if (ctx.kernelPool.data.find(key) == ctx.kernelPool.data.end()) {
+      Tensor inp_i = createTensor(ctx, Shape{b * t * c}, kf32);
+      Tensor weight_i = createTensor(ctx, Shape{oc * c}, kf32);
+      Tensor bias_i = bias == NULL ? createTensor(ctx, Shape{1}, kf32) : createTensor(ctx, Shape{oc}, kf32);
+      Tensor out_o = createTensor(ctx, Shape{b * t * oc}, kf32);
+
+      if (version == 2) {
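+        // Tile configuration for the 2D-tiled matmul shader: each workgroup
+        // computes a BT x BOC tile of the output, streaming the C dimension in
+        // BC-wide chunks, and each thread accumulates a TT x TOC register tile.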
+        constexpr size_t BT = 64;
+        constexpr size_t BC = 16;
+        constexpr size_t BOC = 64;
+        constexpr size_t TT = BT / BC;
+        constexpr size_t TOC = BOC / BC;
+        constexpr size_t num_threads = BT * BOC / (TT * TOC);
+        Shape wgSize = {num_threads, 1, 1};
+
+        std::string codeString(kShaderMatmul2DTiling);
+        std::string unrolledCode = loopUnrolling(replaceAll(codeString, {{"{{precision}}", toString(kf32)},
+                                                                         {"{{BT}}", toString(BT)},
+                                                                         {"{{BC}}", toString(BC)},
+                                                                         {"{{BOC}}", toString(BOC)},
+                                                                         {"{{TT}}", toString(TT)},
+                                                                         {"{{TOC}}", toString(TOC)},
+                                                                         {"{{NUM_TILEI}}", toString(BT * BC / num_threads)},
+                                                                         {"{{NUM_TILEW}}", toString(BOC * BC / num_threads)}
+            }));
+
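+        // One workgroup per (batch, BT-block of T, BOC-block of OC) output tile.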
+        Shape nWorkgroups = {b, cdiv(T, BT), cdiv(OC, BOC)};
+        op = createKernel(ctx, {unrolledCode, wgSize, kf32},
+                          Bindings{inp_i, weight_i, bias_i, out_o},
+                          nWorkgroups,
+                          /* params */
+                          MatmulParams{
+                            static_cast<uint32_t>(b),
+                            static_cast<uint32_t>(t),
+                            static_cast<uint32_t>(c),
+                            static_cast<uint32_t>(oc)
+                          },
+                          nullptr,
+                          key.c_str()
+                          );
+      } else {
+        op = createKernel(ctx, {kShaderMatmul, 256, kf32},
+                          Bindings{inp_i, weight_i, bias_i, out_o},
+                          /* nWorkgroups */ {cdiv(b * t, 256), 1, 1},
+                          /* params */
+                          MatmulParams{
+                            static_cast<uint32_t>(b),
+                            static_cast<uint32_t>(t),
+                            static_cast<uint32_t>(c),
+                            static_cast<uint32_t>(oc)
+                          },
+                          nullptr,
+                          key.c_str()
+                          );
+      }
+    } else {
+      op = ctx.kernelPool.data[key];
+    }
+    Tensor& inp_i = ctx.pool.data[op->buffers[0]];
+    Tensor& weight_i = ctx.pool.data[op->buffers[1]];
+    Tensor& bias_i = ctx.pool.data[op->buffers[2]];
+    Tensor& out_o = ctx.pool.data[op->buffers[3]];
+      
+    toGPU(ctx, inp, inp_i);
+    toGPU(ctx, weight, weight_i);
+    if (bias != NULL) {
+      toGPU(ctx, bias, bias_i);
+    }
+    
+    std::promise<void> promise;
+    std::future<void> future = promise.get_future();
+
+    {
+      DurationTime duration("matmul_forward_gpu", verbose);
+      dispatchKernel(ctx, op, promise);
+      wait(ctx, future);
+    }
+    toCPU(ctx, out_o, out, b * t * oc * sizeof(float));
+  } else {
+    DurationTime duration("matmul_forward_cpu", verbose);
+    matmul_forward_dummy(out, inp, weight, bias, B, T, C, OC);
+  }
+
+  if (debug) { // compare out with out_exp.
+    for (int i = 0; i < B*T*OC; i++) {
+      if (fabs(out[i] - out_exp[i]) > 1e-2) {
+        printf("matmul forward: out[%d] = %f, out_exp[%d] = %f\n", i, out[i], i, out_exp[i]);
+        // Dump the first 4 x 4 elements of inp, weight, out, and out_exp for inspection.
+        printf("inp:\n");
+        for (int j = 0; j < 4; j++) {
+          for (int k = 0; k < 4; k++) {
+            printf("%f ", inp[j * C + k]);
+          }
+          printf("\n");
+        }
+        printf("weight:\n");
+        for (int j = 0; j < 4; j++) {
+          for (int k = 0; k < 4; k++) {
+            printf("%f ", weight[j * OC + k]);
+          }
+          printf("\n");
+        }
+        printf("out:\n");
+        for (int j = 0; j < 4; j++) {
+          for (int k = 0; k < 4; k++) {
+            printf("%f ", out[j * OC + k]);
+          }
+          printf("\n");
+        }
+        printf("out_exp:\n");
+        for (int j = 0; j < 4; j++) {
+          for (int k = 0; k < 4; k++) {
+            printf("%f ", out_exp[j * OC + k]);
+          }
+          printf("\n");
+        }
+        exit(1);
+      }
+    } 
+    delete[] out_exp;
+  }
 }
 
 void MATMUL_BACKWARD_GPU(float* dinp, float* dweight, float* dbias,
@@ -257,28 +485,48 @@ void MATMUL_BACKWARD_GPU(float* dinp, float* dweight, float* dbias,
   unsigned long c = static_cast<unsigned long>(C);
   unsigned long oc = static_cast<unsigned long>(OC);
   setLogLevel(kError);
-  WGPURequiredLimits requiredLimits = LIMITS_BUFFER_SIZE_1GB;
-  Context ctx = createContext({},{},{
-      .requiredLimits = &requiredLimits
-    });
-  Tensor dinp_t = createTensor(ctx, Shape{b * t * c}, kf32, dinp);
-  Tensor dweight_t = createTensor(ctx, Shape{oc * c}, kf32, dweight);
-  Tensor dbias_t = createTensor(ctx, Shape{oc}, kf32, dbias);
-  Tensor dout_t = createTensor(ctx, Shape{b * t * oc}, kf32, dout);
-  Tensor inp_t = createTensor(ctx, Shape{b * t * c}, kf32, inp);
-  Tensor weight_t = createTensor(ctx, Shape{oc * c}, kf32, weight);
+
+  // Generate the cache key from the arguments.
+  std::string key = "MATMUL_BACKWARD_GPU_" + std::to_string(B) + "_" + std::to_string(T) + "_" + std::to_string(C) + "_" + std::to_string(OC);
+  Kernel op;
+  if (ctx.kernelPool.data.find(key) == ctx.kernelPool.data.end()) {
+    Tensor dinp_t = createTensor(ctx, Shape{b * t * c}, kf32);
+    Tensor dweight_t = createTensor(ctx, Shape{oc * c}, kf32);
+    Tensor dbias_t = createTensor(ctx, Shape{oc}, kf32);
+    Tensor dout_t = createTensor(ctx, Shape{b * t * oc}, kf32);
+    Tensor inp_t = createTensor(ctx, Shape{b * t * c}, kf32);
+    Tensor weight_t = createTensor(ctx, Shape{oc * c}, kf32);
+    op = createKernel(ctx, {kShaderMatmulBackward, 256, kf32},
+                      Bindings{dinp_t, dweight_t, dbias_t, dout_t, inp_t, weight_t},
+                      /* nWorkgroups */ {cdiv(b * t, 256), 1, 1},
+                      /* params */
+                      MatmulParams{
+                        static_cast<uint32_t>(b),
+                        static_cast<uint32_t>(t),
+                        static_cast<uint32_t>(c),
+                        static_cast<uint32_t>(oc)
+                      },
+                      nullptr,
+                      key.c_str());
+  } else {
+    op = ctx.kernelPool.data[key];
+  }
+  Tensor& dinp_t = ctx.pool.data[op->buffers[0]];
+  Tensor& dweight_t = ctx.pool.data[op->buffers[1]];
+  Tensor& dbias_t = ctx.pool.data[op->buffers[2]];
+  Tensor& dout_t = ctx.pool.data[op->buffers[3]];
+  Tensor& inp_t = ctx.pool.data[op->buffers[4]];
+  Tensor& weight_t = ctx.pool.data[op->buffers[5]];
+
+  toGPU(ctx, dinp, dinp_t);
+  toGPU(ctx, dweight, dweight_t);
+  toGPU(ctx, dbias, dbias_t);
+  toGPU(ctx, dout, dout_t);
+  toGPU(ctx, inp, inp_t);
+  toGPU(ctx, weight, weight_t);
+
   std::promise<void> promise;
   std::future<void> future = promise.get_future();
-  Kernel op = createKernel(ctx, {kShaderMatmulBackward, 256, kf32},
-                           Bindings{dinp_t, dweight_t, dbias_t, dout_t, inp_t, weight_t},
-                           /* nWorkgroups */ {cdiv(b * t, 256), 1, 1},
-                           /* params */
-                           MatmulParams{
-                             static_cast<uint32_t>(b),
-                             static_cast<uint32_t>(t),
-                             static_cast<uint32_t>(c),
-                             static_cast<uint32_t>(oc)
-                           });
   dispatchKernel(ctx, op, promise);
   wait(ctx, future);
   toCPU(ctx, dinp_t, dinp, b * t * c * sizeof(float));
@@ -300,23 +548,41 @@ void ATTENTION_FORWARD_GPU(float* out, float* preatt, float* att,
   unsigned long c = static_cast<unsigned long>(C);
   unsigned long nh = static_cast<unsigned long>(NH);
   setLogLevel(kError);
-  Context ctx = createContext();
-  Tensor inp_t = createTensor(ctx, Shape{b * t * c * 3}, kf32, inp);
-  Tensor preatt_t = createTensor(ctx, Shape{b * nh * t * t}, kf32, preatt);
-  Tensor att_t = createTensor(ctx, Shape{b * nh * t * t}, kf32, att);
-  Tensor out_t = createTensor(ctx, Shape{b * t * c}, kf32);
+
+  // Generate the cache key from the arguments.
+  std::string key = "ATTENTION_FORWARD_GPU_" + std::to_string(B) + "_" + std::to_string(T) + "_" + std::to_string(C) + "_" + std::to_string(NH);
+  Kernel op;
+  if (ctx.kernelPool.data.find(key) == ctx.kernelPool.data.end()) {
+    Tensor inp_t = createTensor(ctx, Shape{b * t * c * 3}, kf32);
+    Tensor preatt_t = createTensor(ctx, Shape{b * nh * t * t}, kf32);
+    Tensor att_t = createTensor(ctx, Shape{b * nh * t * t}, kf32);
+    Tensor out_t = createTensor(ctx, Shape{b * t * c}, kf32);
+    op = createKernel(ctx, {kShaderAttention, 256, kf32},
+                      Bindings{inp_t, preatt_t, att_t, out_t},
+                      /* nWorkgroups */ {cdiv(b * t, 256), 1, 1},
+                      /* params */
+                      AttentionParams{
+                        static_cast<uint32_t>(b),
+                        static_cast<uint32_t>(t),
+                        static_cast<uint32_t>(c),
+                        static_cast<uint32_t>(nh)
+                      },
+                      nullptr,
+                      key.c_str());
+  } else {
+    op = ctx.kernelPool.data[key];
+  }
+  Tensor& inp_t = ctx.pool.data[op->buffers[0]];
+  Tensor& preatt_t = ctx.pool.data[op->buffers[1]];
+  Tensor& att_t = ctx.pool.data[op->buffers[2]];
+  Tensor& out_t = ctx.pool.data[op->buffers[3]];
+
+  toGPU(ctx, inp, inp_t);
+  toGPU(ctx, preatt, preatt_t);
+  toGPU(ctx, att, att_t);
+
   std::promise<void> promise;
   std::future<void> future = promise.get_future();
-  Kernel op = createKernel(ctx, {kShaderAttention, 256, kf32},
-                           Bindings{inp_t, preatt_t, att_t, out_t},
-                           /* nWorkgroups */ {cdiv(b * t, 256), 1, 1},
-                           /* params */
-                           AttentionParams{
-                             static_cast<uint32_t>(b),
-                             static_cast<uint32_t>(t),
-                             static_cast<uint32_t>(c),
-                             static_cast<uint32_t>(nh)
-                           });
   dispatchKernel(ctx, op, promise);
   wait(ctx, future);
   toCPU(ctx, preatt_t, preatt, b * nh * t * t * sizeof(float));
@@ -338,25 +604,48 @@ void ATTENTION_BACKWARD_GPU(float* dinp, float* dpreatt, float* datt,
   unsigned long c = static_cast<unsigned long>(C);
   unsigned long nh = static_cast<unsigned long>(NH);
   setLogLevel(kError);
-  Context ctx = createContext();
-  Tensor dinp_t = createTensor(ctx, Shape{b * t * c * 3}, kf32, dinp);
-  Tensor dpreatt_t = createTensor(ctx, Shape{b * nh * t * t}, kf32, dpreatt);
-  Tensor datt_t = createTensor(ctx, Shape{b * nh * t * t}, kf32, datt);
-  Tensor dout_t = createTensor(ctx, Shape{b * t * c}, kf32, dout);
-  Tensor inp_t = createTensor(ctx, Shape{b * t * c * 3}, kf32, inp);
-  Tensor att_t = createTensor(ctx, Shape{b * nh * t * t}, kf32, att);
+
+  // Generate the cache key from the arguments.
+  std::string key = "ATTENTION_BACKWARD_GPU_" + std::to_string(B) + "_" + std::to_string(T) + "_" + std::to_string(C) + "_" + std::to_string(NH);
+  Kernel op;
+  if (ctx.kernelPool.data.find(key) == ctx.kernelPool.data.end()) {
+    Tensor dinp_t = createTensor(ctx, Shape{b * t * c * 3}, kf32);
+    Tensor dpreatt_t = createTensor(ctx, Shape{b * nh * t * t}, kf32);
+    Tensor datt_t = createTensor(ctx, Shape{b * nh * t * t}, kf32);
+    Tensor dout_t = createTensor(ctx, Shape{b * t * c}, kf32);
+    Tensor inp_t = createTensor(ctx, Shape{b * t * c * 3}, kf32);
+    Tensor att_t = createTensor(ctx, Shape{b * nh * t * t}, kf32);
+    op = createKernel(ctx, {kShaderAttentionBackward, 256, kf32},
+                      Bindings{dinp_t, dpreatt_t, datt_t, dout_t, inp_t, att_t},
+                      /* nWorkgroups */ {cdiv(b * t, 256), 1, 1},
+                      /* params */
+                      AttentionParams{
+                        static_cast<uint32_t>(b),
+                        static_cast<uint32_t>(t),
+                        static_cast<uint32_t>(c),
+                        static_cast<uint32_t>(nh)
+                      },
+                      nullptr,
+                      key.c_str());
+  } else {
+    op = ctx.kernelPool.data[key];
+  }
+  Tensor& dinp_t = ctx.pool.data[op->buffers[0]];
+  Tensor& dpreatt_t = ctx.pool.data[op->buffers[1]];
+  Tensor& datt_t = ctx.pool.data[op->buffers[2]];
+  Tensor& dout_t = ctx.pool.data[op->buffers[3]];
+  Tensor& inp_t = ctx.pool.data[op->buffers[4]];
+  Tensor& att_t = ctx.pool.data[op->buffers[5]];
+
+  toGPU(ctx, dinp, dinp_t);
+  toGPU(ctx, dpreatt, dpreatt_t);
+  toGPU(ctx, datt, datt_t);
+  toGPU(ctx, dout, dout_t);
+  toGPU(ctx, inp, inp_t);
+  toGPU(ctx, att, att_t);
+  
   std::promise<void> promise;
   std::future<void> future = promise.get_future();
-  Kernel op = createKernel(ctx, {kShaderAttentionBackward, 256, kf32},
-                           Bindings{dinp_t, dpreatt_t, datt_t, dout_t, inp_t, att_t},
-                           /* nWorkgroups */ {cdiv(b * t, 256), 1, 1},
-                           /* params */
-                           AttentionParams{
-                             static_cast<uint32_t>(b),
-                             static_cast<uint32_t>(t),
-                             static_cast<uint32_t>(c),
-                             static_cast<uint32_t>(nh)
-                           });
   dispatchKernel(ctx, op, promise);
   wait(ctx, future);
   toCPU(ctx, dinp_t, dinp, b * t * c * 3 * sizeof(float));
@@ -367,14 +656,29 @@ void ATTENTION_BACKWARD_GPU(float* dinp, float* dpreatt, float* datt,
 void GELU_FORWARD_GPU(float* out, float* inp, int n) {
   unsigned long N = static_cast<unsigned long>(n);
   setLogLevel(kError);
-  Context ctx = createContext();
-  Tensor input = createTensor(ctx, Shape{N}, kf32, inp);
-  Tensor output = createTensor(ctx, Shape{N}, kf32);
+
+  // Generate the cache key from the arguments.
+  std::string key = "GELU_FORWARD_GPU_" + std::to_string(n);
+  Kernel op;
+  if (ctx.kernelPool.data.find(key) == ctx.kernelPool.data.end()) {
+    Tensor input = createTensor(ctx, Shape{N}, kf32);
+    Tensor output = createTensor(ctx, Shape{N}, kf32);
+    op = createKernel(ctx, {kShaderGelu, 256, kf32},
+                      Bindings{input, output},
+                      /* nWorkgroups */ {cdiv(N, 256), 1, 1},
+                      nullptr,
+                      nullptr,
+                      key.c_str());
+  } else {
+    op = ctx.kernelPool.data[key];
+  }
+  Tensor& input = ctx.pool.data[op->buffers[0]];
+  Tensor& output = ctx.pool.data[op->buffers[1]];
+
+  toGPU(ctx, inp, input);
+
   std::promise<void> promise;
   std::future<void> future = promise.get_future();
-  Kernel op = createKernel(ctx, {kShaderGelu, 256, kf32},
-                           Bindings{input, output},
-                           /* nWorkgroups */ {cdiv(N, 256), 1, 1});
   dispatchKernel(ctx, op, promise);
   wait(ctx, future);
   toCPU(ctx, output, out, N * sizeof(float));
@@ -383,15 +687,33 @@ void GELU_FORWARD_GPU(float* out, float* inp, int n) {
 void GELU_BACKWARD_GPU(float* dinp, float* inp, float* dout, int N){
   unsigned long n = static_cast<unsigned long>(N);
   setLogLevel(kError);
-  Context ctx = createContext();
-  Tensor inp_i = createTensor(ctx, Shape{n}, kf32, inp);
-  Tensor dout_i = createTensor(ctx, Shape{n}, kf32, dout);
-  Tensor dinp_o = createTensor(ctx, Shape{n}, kf32, dinp);
+
+  // Generate the cache key from the arguments.
+  std::string key = "GELU_BACKWARD_GPU_" + std::to_string(N);
+  Kernel op;
+  if (ctx.kernelPool.data.find(key) == ctx.kernelPool.data.end()) {
+    Tensor inp_i = createTensor(ctx, Shape{n}, kf32);
+    Tensor dout_i = createTensor(ctx, Shape{n}, kf32);
+    Tensor dinp_o = createTensor(ctx, Shape{n}, kf32);
+    op = createKernel(ctx, {kShaderGeluBackward, 256, kf32},
+                      Bindings{inp_i, dout_i, dinp_o},
+                      /* nWorkgroups */ {cdiv(n, 256), 1, 1},
+                      nullptr,
+                      nullptr,
+                      key.c_str());
+  } else {
+    op = ctx.kernelPool.data[key];
+  }
+  Tensor& inp_i = ctx.pool.data[op->buffers[0]];
+  Tensor& dout_i = ctx.pool.data[op->buffers[1]];
+  Tensor& dinp_o = ctx.pool.data[op->buffers[2]];
+
+  toGPU(ctx, inp, inp_i);
+  toGPU(ctx, dout, dout_i);
+  toGPU(ctx, dinp, dinp_o);
+
   std::promise<void> promise;
   std::future<void> future = promise.get_future();
-  Kernel op = createKernel(ctx, {kShaderGeluBackward, 256, kf32},
-                           Bindings{inp_i, dout_i, dinp_o},
-                           /* nWorkgroups */ {cdiv(n, 256), 1, 1});
   dispatchKernel(ctx, op, promise);
   wait(ctx, future);
   toCPU(ctx, dinp_o, dinp, n * sizeof(float));
@@ -400,15 +722,32 @@ void GELU_BACKWARD_GPU(float* dinp, float* inp, float* dout, int N){
 void RESIDUAL_FORWARD_GPU(float* out, float* inp1, float* inp2, int N){
   unsigned long n = static_cast<unsigned long>(N);
   setLogLevel(kError);
-  Context ctx = createContext();
-  Tensor inp1_i = createTensor(ctx, Shape{n}, kf32, inp1);
-  Tensor inp2_i = createTensor(ctx, Shape{n}, kf32, inp2);
-  Tensor out_o = createTensor(ctx, Shape{n}, kf32);
+
+  // Generate the cache key from the arguments.
+  std::string key = "RESIDUAL_FORWARD_GPU_" + std::to_string(N);
+  Kernel op;
+  if (ctx.kernelPool.data.find(key) == ctx.kernelPool.data.end()) {
+    Tensor inp1_i = createTensor(ctx, Shape{n}, kf32);
+    Tensor inp2_i = createTensor(ctx, Shape{n}, kf32);
+    Tensor out_o = createTensor(ctx, Shape{n}, kf32);
+    op = createKernel(ctx, {kShaderResidual, 256, kf32},
+                      Bindings{inp1_i, inp2_i, out_o},
+                      /* nWorkgroups */ {cdiv(n, 256), 1, 1},
+                      nullptr,
+                      nullptr,
+                      key.c_str());
+  } else {
+    op = ctx.kernelPool.data[key];
+  }
+  Tensor& inp1_i = ctx.pool.data[op->buffers[0]];
+  Tensor& inp2_i = ctx.pool.data[op->buffers[1]];
+  Tensor& out_o = ctx.pool.data[op->buffers[2]];
+
+  toGPU(ctx, inp1, inp1_i);
+  toGPU(ctx, inp2, inp2_i);
+
   std::promise<void> promise;
   std::future<void> future = promise.get_future();
-  Kernel op = createKernel(ctx, {kShaderResidual, 256, kf32},
-                           Bindings{inp1_i, inp2_i, out_o},
-                           /* nWorkgroups */ {cdiv(n, 256), 1, 1});
   dispatchKernel(ctx, op, promise);
   wait(ctx, future);
   toCPU(ctx, out_o, out, n * sizeof(float));
@@ -417,15 +756,33 @@ void RESIDUAL_FORWARD_GPU(float* out, float* inp1, float* inp2, int N){
 void RESIDUAL_BACKWARD_GPU(float* dinp1, float* dinp2, float* dout, int N){
   unsigned long n = static_cast<unsigned long>(N);
   setLogLevel(kError);
-  Context ctx = createContext();
-  Tensor dout_i = createTensor(ctx, Shape{n}, kf32, dout);
-  Tensor dinp1_o = createTensor(ctx, Shape{n}, kf32, dinp1);
-  Tensor dinp2_o = createTensor(ctx, Shape{n}, kf32, dinp2);
+
+  // Generate the cache key from the arguments.
+  std::string key = "RESIDUAL_BACKWARD_GPU_" + std::to_string(N);
+  Kernel op;
+  if (ctx.kernelPool.data.find(key) == ctx.kernelPool.data.end()) {
+    Tensor dout_i = createTensor(ctx, Shape{n}, kf32);
+    Tensor dinp1_o = createTensor(ctx, Shape{n}, kf32);
+    Tensor dinp2_o = createTensor(ctx, Shape{n}, kf32);
+    op = createKernel(ctx, {kShaderResidualBackward, 256, kf32},
+                      Bindings{dout_i, dinp1_o, dinp2_o},
+                      /* nWorkgroups */ {cdiv(n, 256), 1, 1},
+                      nullptr,
+                      nullptr,
+                      key.c_str());
+  } else {
+    op = ctx.kernelPool.data[key];
+  }
+  Tensor& dout_i = ctx.pool.data[op->buffers[0]];
+  Tensor& dinp1_o = ctx.pool.data[op->buffers[1]];
+  Tensor& dinp2_o = ctx.pool.data[op->buffers[2]];
+
+  toGPU(ctx, dout, dout_i);
+  toGPU(ctx, dinp1, dinp1_o);
+  toGPU(ctx, dinp2, dinp2_o);
+
   std::promise<void> promise;
   std::future<void> future = promise.get_future();
-  Kernel op = createKernel(ctx, {kShaderResidualBackward, 256, kf32},
-                           Bindings{dout_i, dinp1_o, dinp2_o},
-                           /* nWorkgroups */ {cdiv(n, 256), 1, 1});
   dispatchKernel(ctx, op, promise);
   wait(ctx, future);
   toCPU(ctx, dinp1_o, dinp1, n * sizeof(float));
@@ -442,15 +799,29 @@ void SOFTMAX_FORWARD_GPU(float* probs, float* logits, int B, int T, int V, int V
   uint32_t t = static_cast<uint32_t>(T);
   uint32_t c = static_cast<uint32_t>(V);
   uint32_t cp = static_cast<uint32_t>(Vp);
-  Context ctx = createContext();
-  Tensor input = createTensor(ctx, {b * t, cp}, kf32, logits);
-  Tensor output = createTensor(ctx, {b * t, cp}, kf32);
+
+  // Generate the cache key from the arguments.
+  std::string key = "SOFTMAX_FORWARD_GPU_" + std::to_string(B) + "_" + std::to_string(T) + "_" + std::to_string(V) + "_" + std::to_string(Vp);
+  Kernel op;
+  if (ctx.kernelPool.data.find(key) == ctx.kernelPool.data.end()) {
+    Tensor input = createTensor(ctx, {b * t, cp}, kf32);
+    Tensor output = createTensor(ctx, {b * t, cp}, kf32);
+    assert( (B*T) % 256 == 0);
+    op = createKernel(
+        ctx, {kShaderSoftmax1, 256, kf32}, Bindings{input, output},
+        Shape{cdiv(B * T, 256), 1, 1}, SoftmaxParam{b * t, c, cp},
+        nullptr,
+        key.c_str());
+  } else {
+    op = ctx.kernelPool.data[key];
+  }
+  Tensor& input = ctx.pool.data[op->buffers[0]];
+  Tensor& output = ctx.pool.data[op->buffers[1]];
+
+  toGPU(ctx, logits, input);
+
   std::promise<void> promise;
   std::future<void> future = promise.get_future();
-  assert( (B*T) % 256 == 0);
-  Kernel op = createKernel(
-      ctx, {kShaderSoftmax1, 256, kf32}, Bindings{input, output},
-      Shape{cdiv(B * T, 256), 1, 1}, SoftmaxParam{b * t, c, cp});
   dispatchKernel(ctx, op, promise);
   wait(ctx, future);
   toCPU(ctx, output, probs, sizeof(float)*b*t*cp);
@@ -468,21 +839,38 @@ void CROSSENTROPY_FORWARD_GPU(float* losses,
   unsigned long t = static_cast<unsigned long>(T);
   unsigned long vp = static_cast<unsigned long>(Vp);
   setLogLevel(kError);
-  Context ctx = createContext();
-  Tensor losses_t = createTensor(ctx, Shape{b * t}, kf32, losses);
-  Tensor probs_t = createTensor(ctx, Shape{b * t * vp}, kf32, probs);
-  Tensor targets_t = createTensor(ctx, Shape{b * t}, ki32, targets);
+
+  // Generate the cache key from the arguments.
+  std::string key = "CROSSENTROPY_FORWARD_GPU_" + std::to_string(B) + "_" + std::to_string(T) + "_" + std::to_string(Vp);
+  Kernel op;
+  if (ctx.kernelPool.data.find(key) == ctx.kernelPool.data.end()) {
+    Tensor losses_t = createTensor(ctx, Shape{b * t}, kf32);
+    Tensor probs_t = createTensor(ctx, Shape{b * t * vp}, kf32);
+    Tensor targets_t = createTensor(ctx, Shape{b * t}, ki32);
+    op = createKernel(ctx, {kShaderCrossEntropyForward, 256, kf32},
+                      Bindings{losses_t, probs_t, targets_t},
+                      /* nWorkgroups */ {cdiv(b * t, 256), 1, 1},
+                      /* params */
+                      CrossEntropyParams{
+                        static_cast<uint32_t>(b),
+                        static_cast<uint32_t>(t),
+                        static_cast<uint32_t>(vp)
+                      },
+                      nullptr,
+                      key.c_str());
+  } else {
+    op = ctx.kernelPool.data[key];
+  }
+  Tensor& losses_t = ctx.pool.data[op->buffers[0]];
+  Tensor& probs_t = ctx.pool.data[op->buffers[1]];
+  Tensor& targets_t = ctx.pool.data[op->buffers[2]];
+
+  toGPU(ctx, losses, losses_t);
+  toGPU(ctx, probs, probs_t);
+  toGPU(ctx, targets, targets_t);
+
   std::promise<void> promise;
   std::future<void> future = promise.get_future();
-  Kernel op = createKernel(ctx, {kShaderCrossEntropyForward, 256, kf32},
-                           Bindings{losses_t, probs_t, targets_t},
-                           /* nWorkgroups */ {cdiv(b * t, 256), 1, 1},
-                           /* params */
-                           CrossEntropyParams{
-                             static_cast<uint32_t>(b),
-                             static_cast<uint32_t>(t),
-                             static_cast<uint32_t>(vp)
-                           });
   dispatchKernel(ctx, op, promise);
   wait(ctx, future);
   toCPU(ctx, losses_t, losses, b * t * sizeof(float));
@@ -502,23 +890,42 @@ void CROSSENTROPY_SOFTMAX_BACKWARD_GPU(float* dlogits,
   unsigned long v = static_cast<unsigned long>(V);
   unsigned long vp = static_cast<unsigned long>(Vp);
   setLogLevel(kError);
-  Context ctx = createContext();
-  Tensor dlogits_t = createTensor(ctx, Shape{b * t * vp}, kf32, dlogits);
-  Tensor dlosses_t = createTensor(ctx, Shape{b * t}, kf32, dlosses);
-  Tensor probs_t = createTensor(ctx, Shape{b * t * vp}, kf32, probs);
-  Tensor targets_t = createTensor(ctx, Shape{b * t}, ki32, targets);
+
+  // Generate the cache key from the arguments.
+  std::string key = "CROSSENTROPY_SOFTMAX_BACKWARD_GPU_" + std::to_string(B) + "_" + std::to_string(T) + "_" + std::to_string(V) + "_" + std::to_string(Vp);
+  Kernel op;
+  if (ctx.kernelPool.data.find(key) == ctx.kernelPool.data.end()) {
+    Tensor dlogits_t = createTensor(ctx, Shape{b * t * vp}, kf32);
+    Tensor dlosses_t = createTensor(ctx, Shape{b * t}, kf32);
+    Tensor probs_t = createTensor(ctx, Shape{b * t * vp}, kf32);
+    Tensor targets_t = createTensor(ctx, Shape{b * t}, ki32);
+    op = createKernel(ctx, {kShaderCrossEntropySoftmaxBackward, 256, kf32},
+                      Bindings{dlogits_t, dlosses_t, probs_t, targets_t},
+                      /* nWorkgroups */ {cdiv(b * t, 256), 1, 1},
+                      /* params */
+                      CrossEntropySoftmaxBackwardParams{
+                        static_cast<uint32_t>(b),
+                        static_cast<uint32_t>(t),
+                        static_cast<uint32_t>(v),
+                        static_cast<uint32_t>(vp)
+                      },
+                      nullptr,
+                      key.c_str());
+  } else {
+    op = ctx.kernelPool.data[key];
+  }
+  Tensor& dlogits_t = ctx.pool.data[op->buffers[0]];
+  Tensor& dlosses_t = ctx.pool.data[op->buffers[1]];
+  Tensor& probs_t = ctx.pool.data[op->buffers[2]];
+  Tensor& targets_t = ctx.pool.data[op->buffers[3]];
+
+  toGPU(ctx, dlogits, dlogits_t);
+  toGPU(ctx, dlosses, dlosses_t);
+  toGPU(ctx, probs, probs_t);
+  toGPU(ctx, targets, targets_t);
+
   std::promise<void> promise;
   std::future<void> future = promise.get_future();
-  Kernel op = createKernel(ctx, {kShaderCrossEntropySoftmaxBackward, 256, kf32},
-                           Bindings{dlogits_t, dlosses_t, probs_t, targets_t},
-                           /* nWorkgroups */ {cdiv(b * t, 256), 1, 1},
-                           /* params */
-                           CrossEntropySoftmaxBackwardParams{
-                             static_cast<uint32_t>(b),
-                             static_cast<uint32_t>(t),
-                             static_cast<uint32_t>(v),
-                             static_cast<uint32_t>(vp)
-                           });
   dispatchKernel(ctx, op, promise);
   wait(ctx, future);
   toCPU(ctx, dlogits_t, dlogits, b * t * vp * sizeof(float));
diff --git a/gpu.hpp b/gpu.hpp
index 83fc94b..941656e 100644
--- a/gpu.hpp
+++ b/gpu.hpp
@@ -415,12 +415,14 @@ struct KernelCode {
  * @endcode
  * "f32"}});
  */
-inline void
+inline const std::string
 replaceAll(std::string &str,
            const std::vector<std::pair<std::string, std::string>> &reps) {
   for (const auto &rep : reps) {
     replaceAll(str, rep.first, rep.second);
   }
+
+  return str;
 }
 
 /**
@@ -452,7 +454,7 @@ struct CopyData {
  * The struct members can be divided into "consumed upon dispatch"
  * (commandBuffer) and reusable ahead-of-time setup (all other members).
  */
-struct Kernel {
+struct RawKernel {
   std::unique_ptr<WGPUBuffer[]> buffers; // non-owning
   std::unique_ptr<size_t[]> bufferSizes;
   size_t numBindings;
@@ -460,8 +462,11 @@ struct Kernel {
   WGPUBindGroup bindGroup;             // persists between submission
   WGPUComputePipeline computePipeline; // persists between submission
   WGPUCommandBuffer commandBuffer;     // destroyed upon submission
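+  // Set once the command buffer has been submitted; dispatchKernel() calls
+  // resetCommandBuffer() to rebuild it before a cached kernel is dispatched again.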
+  bool used;
 };
 
+typedef std::shared_ptr<RawKernel> Kernel;
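+// A Kernel is a shared handle to a RawKernel: createKernel() can hand out the
+// same instance on cache hits, and it stays valid across repeated dispatches.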
+
 
 /**
  * @brief A struct to package the result of a WGSL code compilation.
@@ -481,7 +486,7 @@ struct CompilationInfo {
  * @return True if lhs < rhs, false otherwise
  */
 inline bool operator<(const Kernel &lhs, const Kernel &rhs) {
-  return lhs.commandBuffer < rhs.commandBuffer;
+  return lhs->commandBuffer < rhs->commandBuffer;
 }
 
 /**
@@ -492,7 +497,7 @@ inline bool operator<(const Kernel &lhs, const Kernel &rhs) {
 struct KernelPool {
   inline KernelPool(Context *ctx) : ctx(ctx), data() {}
   Context *ctx;
-  std::set<Kernel *> data;
+  std::unordered_map<std::string, Kernel> data;
   inline ~KernelPool() {
     // Note : Some kernel resources such as commandBuffer are harvested by
     // queue submission, explicitly destroying readback and callback buffers
@@ -997,6 +1002,7 @@ inline void wait(Context &ctx, std::future<void> &future) {
 inline void toCPU(Context &ctx, Tensor &tensor, void *data, size_t bufferSize,
                   CopyData &op) {
   wgpuQueueSubmit(ctx.queue, 1, &op.commandBuffer);
+  wgpuCommandBufferRelease(op.commandBuffer);
   CallbackData callbackData = {op.readbackBuffer, bufferSize, data, &op.promise,
                                &op.future};
   wgpuQueueOnSubmittedWorkDone(
@@ -1052,14 +1058,17 @@ inline void toCPU(Context &ctx, Tensor &tensor, void *data, size_t bufferSize) {
   }
   {
     WGPUCommandEncoder commandEncoder;
-    WGPUComputePassEncoder computePassEncoder;
     commandEncoder = wgpuDeviceCreateCommandEncoder(ctx.device, nullptr);
     wgpuCommandEncoderCopyBufferToBuffer(commandEncoder, tensor.data.buffer, 0,
                                          op.readbackBuffer, 0, bufferSize);
     op.commandBuffer = wgpuCommandEncoderFinish(commandEncoder, nullptr);
+    wgpuCommandEncoderRelease(commandEncoder);
     check(op.commandBuffer, "Create command buffer", __FILE__, __LINE__);
   }
   toCPU(ctx, tensor, data, bufferSize, op);
+  if (op.readbackBuffer) {
+    wgpuBufferRelease(op.readbackBuffer);
+  }
 }
 
 /**
@@ -1078,6 +1087,61 @@ void toCPU(Context &ctx, Tensor &tensor, std::array<float, N> &data) {
   toCPU(ctx, tensor, data.data(), sizeof(data));
 }
 
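+/**
+ * @brief Overload of toCPU that reads back a raw WGPUBuffer (rather than a
+ * Tensor), staging the copy through a temporary readback buffer.
+ */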
+inline void toCPU(Context &ctx, WGPUBuffer buffer, void *data,
+                  size_t size) {
+  uint64_t bufferSize = size;
+  CopyData op;
+  op.future = op.promise.get_future();
+  {
+    WGPUBufferDescriptor readbackBufferDescriptor = {
+        .usage = WGPUBufferUsage_CopyDst | WGPUBufferUsage_MapRead,
+        .size = bufferSize,
+    };
+    op.readbackBuffer =
+        wgpuDeviceCreateBuffer(ctx.device, &readbackBufferDescriptor);
+  }
+  {
+    WGPUCommandEncoder commandEncoder;
+    commandEncoder = wgpuDeviceCreateCommandEncoder(ctx.device, nullptr);
+    wgpuCommandEncoderCopyBufferToBuffer(commandEncoder, buffer, 0,
+                                         op.readbackBuffer, 0, bufferSize);
+    op.commandBuffer = wgpuCommandEncoderFinish(commandEncoder, nullptr);
+    wgpuCommandEncoderRelease(commandEncoder);
+    check(op.commandBuffer, "Create command buffer", __FILE__, __LINE__);
+  }
+  wgpuQueueSubmit(ctx.queue, 1, &op.commandBuffer);
+  wgpuCommandBufferRelease(op.commandBuffer);
+  CallbackData callbackData = {op.readbackBuffer, bufferSize, data, &op.promise,
+                               &op.future};
+  wgpuQueueOnSubmittedWorkDone(
+      ctx.queue,
+      [](WGPUQueueWorkDoneStatus status, void *callbackData) {
+        check(status == WGPUQueueWorkDoneStatus_Success, "Queue work done",
+              __FILE__, __LINE__);
+        const auto *data = static_cast<CallbackData *>(callbackData);
+        wgpuBufferMapAsync(
+            data->buffer, WGPUMapMode_Read, 0, data->bufferSize,
+            [](WGPUBufferMapAsyncStatus status, void *captureData) {
+              const auto *data = static_cast<CallbackData *>(captureData);
+              check(status == WGPUBufferMapAsyncStatus_Success,
+                    "Map readbackBuffer", __FILE__, __LINE__);
+              const void *mappedData = wgpuBufferGetConstMappedRange(
+                  data->buffer, /*offset=*/0, data->bufferSize);
+              check(mappedData, "Get mapped range", __FILE__, __LINE__);
+              memcpy(data->output, mappedData, data->bufferSize);
+              wgpuBufferUnmap(data->buffer);
+              data->promise->set_value();
+            },
+            callbackData);
+      },
+      &callbackData);
+  wait(ctx, op.future);
+  if (op.readbackBuffer) {
+    wgpuBufferRelease(op.readbackBuffer);
+  }
+}
+
+  
 /**
  * @brief Copies data from CPU memory to a GPU buffer. The toGPU overloads are
  * effectively a convenience wrapper around the WebGPU API call
@@ -1119,13 +1183,18 @@ inline void toGPU(Context &ctx, const half *data, Tensor &tensor) {
                        tensor.data.size);
 }
 
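+/**
+ * @brief Overload of toGPU for int data, e.g. token index / target tensors.
+ */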
+inline void toGPU(Context &ctx, const int *data, Tensor &tensor) {
+  wgpuQueueWriteBuffer(ctx.queue, tensor.data.buffer, 0, data,
+                       tensor.data.size);
+}
+
 template <typename Params>
 inline void toGPU(Context &ctx, Params &params, Kernel &op) {
   // TODO(avh): Maintain params metadata in Kernel and check for consistency.
   // If a kernel does not have parameters this will quietly overwrite
   // the last buffer in the bind group with the parameters buffer.
-  if (op.numBindings > 0) {
-    wgpuQueueWriteBuffer(ctx.queue, op.buffers[op.numBindings - 1], 0,
+  if (op->numBindings > 0) {
+    wgpuQueueWriteBuffer(ctx.queue, op->buffers[op->numBindings - 1], 0,
                          static_cast<void *>(&params), sizeof(params));
   }
 }
@@ -1148,14 +1217,17 @@ inline void resetCommandBuffer(WGPUDevice &device, Kernel &op) {
         wgpuDeviceCreateCommandEncoder(device, nullptr);
     WGPUComputePassEncoder computePassEncoder =
         wgpuCommandEncoderBeginComputePass(commandEncoder, nullptr);
-    wgpuComputePassEncoderSetPipeline(computePassEncoder, op.computePipeline);
-    wgpuComputePassEncoderSetBindGroup(computePassEncoder, 0, op.bindGroup, 0,
+    wgpuComputePassEncoderSetPipeline(computePassEncoder, op->computePipeline);
+    wgpuComputePassEncoderSetBindGroup(computePassEncoder, 0, op->bindGroup, 0,
                                        nullptr);
     wgpuComputePassEncoderDispatchWorkgroups(
-        computePassEncoder, op.totalWorkgroups[0], op.totalWorkgroups[1],
-        op.totalWorkgroups[2]);
+        computePassEncoder, op->totalWorkgroups[0], op->totalWorkgroups[1],
+        op->totalWorkgroups[2]);
     wgpuComputePassEncoderEnd(computePassEncoder);
-    op.commandBuffer = wgpuCommandEncoderFinish(commandEncoder, nullptr);
+    wgpuComputePassEncoderRelease(computePassEncoder);
+    op->commandBuffer = wgpuCommandEncoderFinish(commandEncoder, nullptr);
+    wgpuCommandEncoderRelease(commandEncoder);
+    op->used = false;
   }
 }
 
@@ -1217,11 +1289,19 @@ inline Kernel createKernel(Context &ctx, const KernelCode &code,
                            const size_t *viewOffsets, const Shape &totalWorkgroups,
                            const void *params = nullptr,
                            size_t paramsSize = 0,
-                           CompilationInfo* compilationInfo = nullptr) {
+                           CompilationInfo* compilationInfo = nullptr,
+                           const char* cacheKey = nullptr) {
+  // If a cache key was provided and a kernel with that key is already in the pool, reuse it.
+  if (cacheKey != nullptr && ctx.kernelPool.data.find(cacheKey) != ctx.kernelPool.data.end()) {
+    LOG(kDefLog, kInfo, "Kernel cache hit");
+    return ctx.kernelPool.data[cacheKey];
+  }
+
   assert(totalWorkgroups.rank == 3);
   WGPUDevice device = ctx.device;
   WGPUQueue queue = ctx.queue;
-  Kernel op;
+  Kernel op = std::make_shared<RawKernel>();
+
   // paramIndex is the index into bgLayoutEntries for the parameters buffer If
   // there are no parameters for the kernel, paramsSize == 0 and paramIndex is
   // effectively undefined (== -1)
@@ -1234,9 +1314,9 @@ inline Kernel createKernel(Context &ctx, const KernelCode &code,
                                   // op.buffers, op.bufferSizes and
                                   // bgLayoutEntries
   }
-  op.buffers = std::make_unique<WGPUBuffer[]>(numBindings);
-  op.bufferSizes = std::make_unique<size_t[]>(numBindings);
-  op.numBindings = numBindings;
+  op->buffers = std::make_unique<WGPUBuffer[]>(numBindings);
+  op->bufferSizes = std::make_unique<size_t[]>(numBindings);
+  op->numBindings = numBindings;
   std::vector<WGPUBindGroupLayoutEntry> bgLayoutEntries(numBindings);
   // Create layout entries for input buffers
   for (size_t i = 0; i < numTensors; ++i) {
@@ -1270,8 +1350,8 @@ inline Kernel createKernel(Context &ctx, const KernelCode &code,
   WGPUBindGroupLayout bgLayout =
       wgpuDeviceCreateBindGroupLayout(device, &bgLayoutDesc);
   for (size_t i = 0; i < numTensors; ++i) {
-    op.buffers[i] = dataBindings[i].data.buffer;
-    op.bufferSizes[i] = dataBindings[i].data.size;
+    op->buffers[i] = dataBindings[i].data.buffer;
+    op->bufferSizes[i] = dataBindings[i].data.size;
   }
   // Create a buffer for the Params struct
   if (paramsSize > 0) {
@@ -1280,9 +1360,9 @@ inline Kernel createKernel(Context &ctx, const KernelCode &code,
         .size = paramsSize,
         .mappedAtCreation = false,
     };
-    op.buffers[paramIndex] = wgpuDeviceCreateBuffer(device, &paramsBufferDesc);
-    op.bufferSizes[paramIndex] = paramsSize;
-    wgpuQueueWriteBuffer(queue, op.buffers[paramIndex], 0, params, paramsSize);
+    op->buffers[paramIndex] = wgpuDeviceCreateBuffer(device, &paramsBufferDesc);
+    op->bufferSizes[paramIndex] = paramsSize;
+    wgpuQueueWriteBuffer(queue, op->buffers[paramIndex], 0, params, paramsSize);
     LOG(kDefLog, kTrace, "Params buffer written");
   } else {
     LOG(kDefLog, kTrace, "No params buffer needed");
@@ -1291,9 +1371,9 @@ inline Kernel createKernel(Context &ctx, const KernelCode &code,
   for (size_t i = 0; i < numTensors; ++i) {
     bindGroupEntries[i] = WGPUBindGroupEntry{
         .binding = static_cast<uint32_t>(i),
-        .buffer = op.buffers[i],
+        .buffer = op->buffers[i],
         .offset = viewOffsets[i],
-        .size = op.bufferSizes[i],
+        .size = op->bufferSizes[i],
     };
   }
   if (paramsSize > 0) {
@@ -1301,7 +1381,7 @@ inline Kernel createKernel(Context &ctx, const KernelCode &code,
     LOG(kDefLog, kInfo, "paramIndex: %d", paramIndex);
     bindGroupEntries[paramIndex] = WGPUBindGroupEntry{
         .binding = static_cast<uint32_t>(paramIndex),
-        .buffer = op.buffers[paramIndex],
+        .buffer = op->buffers[paramIndex],
         .offset = 0,
         .size = paramsSize,
     };
@@ -1312,7 +1392,7 @@ inline Kernel createKernel(Context &ctx, const KernelCode &code,
       .entryCount = static_cast<uint32_t>(numBindings),
       .entries = bindGroupEntries.data(),
   };
-  op.bindGroup = wgpuDeviceCreateBindGroup(device, &bindGroupDesc);
+  op->bindGroup = wgpuDeviceCreateBindGroup(device, &bindGroupDesc);
 
   WGPUPipelineLayoutDescriptor pipelineLayoutDesc = {
       .bindGroupLayoutCount = 1,
@@ -1334,12 +1414,13 @@ inline Kernel createKernel(Context &ctx, const KernelCode &code,
 
   computePipelineDesc.compute.entryPoint = code.entryPoint.c_str();
   computePipelineDesc.label = code.label.c_str();
-  op.computePipeline =
+  op->computePipeline =
       wgpuDeviceCreateComputePipeline(device, &computePipelineDesc);
-  op.totalWorkgroups = {totalWorkgroups[0], totalWorkgroups[1], totalWorkgroups[2]};
+  op->totalWorkgroups = {totalWorkgroups[0], totalWorkgroups[1], totalWorkgroups[2]};
   resetCommandBuffer(device, op);
-  ctx.kernelPool.data.insert(&op);
-
+  if (cacheKey != nullptr) {
+    ctx.kernelPool.data[cacheKey] = op;
+  }
+
   WGPUCompilationInfoCallback cb =
       [](WGPUCompilationInfoRequestStatus status,
          WGPUCompilationInfo const *compilationInfo, void *userData) {
@@ -1394,17 +1475,20 @@ Kernel createKernel(Context &ctx, const KernelCode &code,
                     const Bindings<numInputs> &dataBindings,
                     const Shape &totalWorkgroups,
                     const ParamsType &params = ParamsType{},
-                    CompilationInfo* compilationInfo = nullptr 
+                    CompilationInfo* compilationInfo = nullptr,
+                    const char* cacheKey = nullptr
                     ) {
   if constexpr (!IsNoParam<ParamsType>) {
     return createKernel(ctx, code, dataBindings.data.data(), numInputs,
                         dataBindings.viewOffsets.data(), totalWorkgroups,
                         reinterpret_cast<const void *>(&params),
-                        sizeof(ParamsType), compilationInfo);
+                        sizeof(ParamsType), compilationInfo,
+                        cacheKey);
   } else {
     return createKernel(ctx, code, dataBindings.data.data(), numInputs,
                         dataBindings.viewOffsets.data(), totalWorkgroups, nullptr,
-                        0, compilationInfo);
+                        0, compilationInfo,
+                        cacheKey);
   }
 }
 
@@ -1429,7 +1513,12 @@ Kernel createKernel(Context &ctx, const KernelCode &code,
 inline void dispatchKernel(Context &ctx, Kernel &kernel,
                            std::promise<void> &promise) {
   // Submit the command buffer
-  wgpuQueueSubmit(ctx.queue, 1, &kernel.commandBuffer);
+  if (kernel->used) {
+    resetCommandBuffer(ctx.device, kernel);
+  }
+  wgpuQueueSubmit(ctx.queue, 1, &kernel->commandBuffer);
+  wgpuCommandBufferRelease(kernel->commandBuffer);
+  kernel->used = true;
   wgpuQueueOnSubmittedWorkDone(
       ctx.queue,
       [](WGPUQueueWorkDoneStatus status, void *data) {