@@ -309,6 +309,104 @@ fn main(@builtin(global_invocation_id) global_id : vec3<u32>) {
309309 }
310310 }
311311}
312+
313+ )" ;
314+
315+
// kShaderMatmul2DTiling: WGSL compute-shader template for a tiled matrix
// multiply with an optional bias add. From the indexing in the shader body:
//
//   out[b, t, oc] = sum over c of inp[b, t, c] * weight[oc, c]   (+ bias[oc])
//
// inp is addressed as (B, T, C), weight as (OC, C) and out as (B, T, OC),
// each row-major with the last dimension contiguous.
//
// Template placeholders ({{...}}) -- presumably substituted by the host before
// the shader is compiled; TODO confirm against the caller:
//   precision      element type of every buffer, tile and accumulator
//   BT, BOC        rows (along T) / columns (along OC) of the output tile
//                  produced by one workgroup
//   BC             width of the slice along C consumed per outer-loop step
//   TT, TOC        rows / columns of the output sub-tile accumulated by one
//                  invocation in threadResults
//   NUM_TILEW      strided loads per invocation used to fill tileWeight
//   NUM_TILEI      strided loads per invocation used to fill tileInp
//   workgroupSize  value placed in the @workgroup_size attribute
//
// Work distribution: workgroup_id.x selects the batch index, .y the tile
// index along T and .z the tile index along OC. Within a workgroup,
// local_invocation_id.x is mapped to (threadRow, threadCol), the top-left
// corner of that invocation's TT x TOC sub-tile.
//
// Main loop: for each BC-wide slice of C, every invocation copies elements of
// weight and inp into the workgroup arrays tileWeight / tileInp, a
// workgroupBarrier() follows, the partial products are accumulated, and a
// second workgroupBarrier() precedes the next round of loads.
//
// Epilogue: when arrayLength(&bias) == 1 the accumulators are stored as-is;
// otherwise bias[outOC * BOC + column] is added to each stored element.
// NOTE(review): presumably the host binds a 1-element dummy buffer to mean
// "no bias" -- confirm against the dispatch code.
//
// NOTE(review): the shader performs no bounds checks, so it appears to assume
// T, OC and C are multiples of BT, BOC and BC respectively, and that the
// workgroup size equals (BT * BOC) / (TT * TOC) (the numThread stride used by
// the tile loads) -- confirm that the host enforces this.
// params.B is copied into a local but is not otherwise used by the shader.
316+ static const char *kShaderMatmul2DTiling = R"(
317+ @group(0) @binding(0) var<storage, read_write> inp : array<{{precision}}>;
318+ @group(0) @binding(1) var<storage, read_write> weight : array<{{precision}}>;
319+ @group(0) @binding(2) var<storage, read_write> bias : array<{{precision}}>;
320+ @group(0) @binding(3) var<storage, read_write> out : array<{{precision}}>;
321+ @group(0) @binding(4) var<uniform> params : Params;
322+ struct Params {
323+ B: u32,
324+ T: u32,
325+ C: u32,
326+ OC: u32,
327+ };
328+ var<workgroup> tileInp: array<{{precision}}, {{BT}} * {{BC}}>;
329+ var<workgroup> tileWeight: array<{{precision}}, {{BOC}} * {{BC}}>;
330+
331+ @compute @workgroup_size({{workgroupSize}})
332+ fn main(
333+ @builtin(local_invocation_id) localID : vec3<u32>,
334+ @builtin(workgroup_id) groupid : vec3<u32>) {
335+ let B : u32 = params.B;
336+ let T : u32 = params.T;
337+ let C : u32 = params.C;
338+ let OC : u32 = params.OC;
339+
340+ var localT: array<{{precision}}, {{TT}}>;
341+ var localOC: array<{{precision}}, {{TOC}}>;
342+
343+ let outB: u32 = groupid.x;
344+ let outT: u32 = groupid.y;
345+ let outOC: u32 = groupid.z;
346+ let numThread: u32 = ({{BT}} * {{BOC}}) / ({{TT}} * {{TOC}});
347+
348+ // position of the first c element computed by the thread
349+ let threadRow: u32 = (localID.x / ({{BOC}} / {{TOC}})) * {{TT}};
350+ let threadCol: u32 = (localID.x % ({{BOC}} / {{TOC}})) * {{TOC}};
351+
352+ // inpPtr and weightPtr are the starting positions of the tiles in a and b,
353+ // incremented in the bkidx loop.
354+ // outPtr is the starting position of the tile in c which is fixed.
355+
356+ var inpPtr = (outB * T + outT * {{BT}}) * C; // BTC
357+ var weightPtr = outOC * {{BOC}} * C; //OCC
358+ var threadResults: array<{{precision}}, {{TT}} * {{TOC}}>;
359+ let outPtr = (outB * T + outT * {{BT}}) * OC + outOC * {{BOC}}; //BTOC
360+ let biasPtr = outOC * {{BOC}};
361+
362+ for (var bkidx: u32 = 0; bkidx < C; bkidx += {{BC}}) {
363+ // Load BC x BOC by numThread(BT * BOC / (TT * TOC))
364+ // The number of iteration == BC * BOC / (BT * BOC / (TT * TOC))
365+ for (var idx: u32 = 0; idx < {{NUM_TILEW}}; idx++) {
366+ tileWeight[localID.x + idx * numThread] = weight[weightPtr + ((localID.x + idx * numThread) / {{BC}}) * C + ((localID.x + idx * numThread) % {{BC}})];
367+ }
368+ weightPtr += {{BC}};
369+
370+ // Load tile
371+ // Load BT x BC by numThread(BT * BOC / (TT * TOC))
372+ // The number of iteration == BT * BC / (BT * BOC / (TT * TOC))
373+ for (var idx: u32 = 0; idx < {{NUM_TILEI}}; idx++) {
374+ tileInp[localID.x + idx * numThread] = inp[inpPtr + ((localID.x + idx * numThread) / {{BC}}) * C + (localID.x + idx * numThread) % {{BC}}];
375+ }
376+ inpPtr += {{BC}};
377+
378+ workgroupBarrier();
379+ // Compute tile
380+ for (var dotIdx: u32 = 0; dotIdx < {{BC}}; dotIdx = dotIdx + 1) {
381+ for (var idx: u32 = 0; idx < {{TT}}; idx++) {
382+ localT[idx] = tileInp[(threadRow + idx) * {{BC}} + dotIdx];
383+ }
384+ for (var idx: u32 = 0; idx < {{TOC}}; idx++) {
385+ localOC[idx] = tileWeight[(threadCol + idx) * {{BC}} + dotIdx];
386+ }
387+ for (var resIdxT: u32 = 0; resIdxT < {{TT}}; resIdxT++) {
388+ for (var resIdxOC: u32 = 0; resIdxOC < {{TOC}}; resIdxOC++) {
389+ threadResults[resIdxT * {{TOC}} + resIdxOC] += localT[resIdxT] * localOC[resIdxOC];
390+ }
391+ }
392+ }
393+ workgroupBarrier();
394+ }
395+
396+ if (arrayLength(&bias) == 1) {
397+ for (var resIdxT: u32 = 0; resIdxT < {{TT}}; resIdxT++) {
398+ for (var resIdxOC: u32 = 0; resIdxOC < {{TOC}}; resIdxOC++) {
399+ out[outPtr + (threadRow + resIdxT) * OC + threadCol + resIdxOC] = threadResults[resIdxT * {{TOC}} + resIdxOC];
400+ }
401+ }
402+ } else {
403+ for (var resIdxT: u32 = 0; resIdxT < {{TT}}; resIdxT++) {
404+ for (var resIdxOC: u32 = 0; resIdxOC < {{TOC}}; resIdxOC++) {
405+ out[outPtr + (threadRow + resIdxT) * OC + threadCol + resIdxOC] = threadResults[resIdxT * {{TOC}} + resIdxOC] + bias[biasPtr + threadCol + resIdxOC];
406+ }
407+ }
408+ }
409+ }
312410)" ;
313411
314412static const char *kShaderMatmulBackward = R"(
0 commit comments