@@ -309,6 +309,104 @@ fn main(@builtin(global_invocation_id) global_id : vec3<u32>) {
309
309
}
310
310
}
311
311
}
312
+
313
+ )" ;
314
+
315
+
316
// kShaderMatmul2DTiling: WGSL source template for a tiled matmul-with-bias compute kernel.
//
// Template placeholders ({{precision}}, {{workgroupSize}}, {{BT}}, {{BC}}, {{BOC}}, {{TT}},
// {{TOC}}, {{NUM_TILEW}}, {{NUM_TILEI}}) appear verbatim in the string; presumably the host
// substitutes them before the shader is compiled -- TODO confirm against the caller.
//
// Bindings (group 0): 0 = inp, 1 = weight, 2 = bias, 3 = out (all storage, read_write),
// 4 = params (uniform struct with B, T, C, OC as u32).
//
// Indexing used by the kernel:
//   - inp is read at (b * T + t) * C + c, weight at oc * C + c, out is written at
//     (b * T + t) * OC + oc, and bias is read at oc.
//   - workgroup_id.x / .y / .z select the batch index, the BT-row tile and the BOC-column tile.
//   - Each invocation accumulates a TT x TOC patch in threadResults; threadRow / threadCol are
//     derived from local_invocation_id.x.
//
// Per step of the bkidx loop (stride BC over C): every invocation copies NUM_TILEW weight
// elements and NUM_TILEI inp elements into the var<workgroup> tiles, a workgroupBarrier()
// separates the loads from the accumulation, and a second barrier precedes the next load.
//
// Bias handling: when arrayLength(&bias) == 1 the bias add is skipped; otherwise bias[oc] is
// added on store. NOTE(review): this looks like a "no bias" sentinel chosen by the host --
// confirm, since a genuine OC == 1 bias buffer would also take the skip path.
//
// NOTE(review): no bounds checks are visible here; presumably T, OC and C must be multiples
// of BT, BOC and BC respectively, and NUM_TILEW / NUM_TILEI must match the tile sizes divided
// by numThread -- verify against the host-side dispatch code.
+ static const char *kShaderMatmul2DTiling = R"(
317
+ @group(0) @binding(0) var<storage, read_write> inp : array<{{precision}}>;
318
+ @group(0) @binding(1) var<storage, read_write> weight : array<{{precision}}>;
319
+ @group(0) @binding(2) var<storage, read_write> bias : array<{{precision}}>;
320
+ @group(0) @binding(3) var<storage, read_write> out : array<{{precision}}>;
321
+ @group(0) @binding(4) var<uniform> params : Params;
322
+ struct Params {
323
+ B: u32,
324
+ T: u32,
325
+ C: u32,
326
+ OC: u32,
327
+ };
328
+ var<workgroup> tileInp: array<{{precision}}, {{BT}} * {{BC}}>;
329
+ var<workgroup> tileWeight: array<{{precision}}, {{BOC}} * {{BC}}>;
330
+
331
+ @compute @workgroup_size({{workgroupSize}})
332
+ fn main(
333
+ @builtin(local_invocation_id) localID : vec3<u32>,
334
+ @builtin(workgroup_id) groupid : vec3<u32>) {
335
+ let B : u32 = params.B;
336
+ let T : u32 = params.T;
337
+ let C : u32 = params.C;
338
+ let OC : u32 = params.OC;
339
+
340
+ var localT: array<{{precision}}, {{TT}}>;
341
+ var localOC: array<{{precision}}, {{TOC}}>;
342
+
343
+ let outB: u32 = groupid.x;
344
+ let outT: u32 = groupid.y;
345
+ let outOC: u32 = groupid.z;
346
+ let numThread: u32 = ({{BT}} * {{BOC}}) / ({{TT}} * {{TOC}});
347
+
348
+ // position of the first c element computed by the thread
349
+ let threadRow: u32 = (localID.x / ({{BOC}} / {{TOC}})) * {{TT}};
350
+ let threadCol: u32 = (localID.x % ({{BOC}} / {{TOC}})) * {{TOC}};
351
+
352
+ // inpPtr and weightPtr are the starting positions of the tiles in a and b,
353
+ // incremented in the bkidx loop.
354
+ // outPtr is the starting position of the tile in c which is fixed.
355
+
356
+ var inpPtr = (outB * T + outT * {{BT}}) * C; // BTC
357
+ var weightPtr = outOC * {{BOC}} * C; //OCC
358
+ var threadResults: array<{{precision}}, {{TT}} * {{TOC}}>;
359
+ let outPtr = (outB * T + outT * {{BT}}) * OC + outOC * {{BOC}}; //BTOC
360
+ let biasPtr = outOC * {{BOC}};
361
+
362
+ for (var bkidx: u32 = 0; bkidx < C; bkidx += {{BC}}) {
363
+ // Load BC x BOC by numThread(BT * BOC / (TT * TOC))
364
+ // The number of iteration == BC * BOC / (BT * BOC / (TT * TOC))
365
+ for (var idx: u32 = 0; idx < {{NUM_TILEW}}; idx++) {
366
+ tileWeight[localID.x + idx * numThread] = weight[weightPtr + ((localID.x + idx * numThread) / {{BC}}) * C + ((localID.x + idx * numThread) % {{BC}})];
367
+ }
368
+ weightPtr += {{BC}};
369
+
370
+ // Load tile
371
+ // Load BT x BC by numThread(BT * BOC / (TT * TOC))
372
+ // The number of iteration == BT * BC / (BT * BOC / (TT * TOC))
373
+ for (var idx: u32 = 0; idx < {{NUM_TILEI}}; idx++) {
374
+ tileInp[localID.x + idx * numThread] = inp[inpPtr + ((localID.x + idx * numThread) / {{BC}}) * C + (localID.x + idx * numThread) % {{BC}}];
375
+ }
376
+ inpPtr += {{BC}};
377
+
378
+ workgroupBarrier();
379
+ // Compute tile
380
+ for (var dotIdx: u32 = 0; dotIdx < {{BC}}; dotIdx = dotIdx + 1) {
381
+ for (var idx: u32 = 0; idx < {{TT}}; idx++) {
382
+ localT[idx] = tileInp[(threadRow + idx) * {{BC}} + dotIdx];
383
+ }
384
+ for (var idx: u32 = 0; idx < {{TOC}}; idx++) {
385
+ localOC[idx] = tileWeight[(threadCol + idx) * {{BC}} + dotIdx];
386
+ }
387
+ for (var resIdxT: u32 = 0; resIdxT < {{TT}}; resIdxT++) {
388
+ for (var resIdxOC: u32 = 0; resIdxOC < {{TOC}}; resIdxOC++) {
389
+ threadResults[resIdxT * {{TOC}} + resIdxOC] += localT[resIdxT] * localOC[resIdxOC];
390
+ }
391
+ }
392
+ }
393
+ workgroupBarrier();
394
+ }
395
+
396
+ if (arrayLength(&bias) == 1) {
397
+ for (var resIdxT: u32 = 0; resIdxT < {{TT}}; resIdxT++) {
398
+ for (var resIdxOC: u32 = 0; resIdxOC < {{TOC}}; resIdxOC++) {
399
+ out[outPtr + (threadRow + resIdxT) * OC + threadCol + resIdxOC] = threadResults[resIdxT * {{TOC}} + resIdxOC];
400
+ }
401
+ }
402
+ } else {
403
+ for (var resIdxT: u32 = 0; resIdxT < {{TT}}; resIdxT++) {
404
+ for (var resIdxOC: u32 = 0; resIdxOC < {{TOC}}; resIdxOC++) {
405
+ out[outPtr + (threadRow + resIdxT) * OC + threadCol + resIdxOC] = threadResults[resIdxT * {{TOC}} + resIdxOC] + bias[biasPtr + threadCol + resIdxOC];
406
+ }
407
+ }
408
+ }
409
+ }
312
410
)" ;
313
411
314
412
static const char *kShaderMatmulBackward = R"(
0 commit comments