14
14
15
15
using namespace gpu ;
16
16
17
+ const char * versionToStr (int version);
18
+
17
19
static const char *kShaderMatmul1 = R"(
18
20
@group(0) @binding(0) var<storage, read_write> A: array<{{precision}}>;
19
21
@group(0) @binding(1) var<storage, read_write> B: array<{{precision}}>;
@@ -220,7 +222,7 @@ inline KernelCode createMatmul3(const char *shaderTemplate, const size_t M,
220
222
{" {{BN}}" , toString (BN)},
221
223
{" {{TM}}" , toString (TM)}});
222
224
if (unrolling) {
223
- std::string unrolledCode = loopUnrolling (codeString);
225
+ std::string unrolledCode = loopUnrolling (removeUnnecessaryIfStatements ( codeString) );
224
226
// LOG(kDefLog, kInfo, "Unrolled code:\n%s", unrolledCode.c_str());
225
227
return {unrolledCode, workgroupSize};
226
228
} else {
@@ -260,9 +262,14 @@ fn main(
260
262
// incremented in the bkidx loop.
261
263
// cPtr is the starting position of the tile in c which is fixed.
262
264
263
- var aPtr = cRow * {{BM}} * {{K}};
264
- var bPtr = cCol * {{BN}} * {{K}};
265
- let cPtr = cRow * {{BM}} * {{N}} + cCol * {{BN}};
265
+ var aPtr: u32 = cRow * {{BM}} * {{K}};
266
+ var bPtr: u32 = 0;
267
+ if ({{TRANSPOSE}}) {
268
+ bPtr = cCol * {{BN}};
269
+ } else {
270
+ bPtr = cCol * {{BN}} * {{K}};
271
+ }
272
+ let cPtr: u32 = cRow * {{BM}} * {{N}} + cCol * {{BN}};
266
273
267
274
for (var bkidx = 0; bkidx < {{K}}; bkidx += {{BK}}) {
268
275
@@ -275,11 +282,19 @@ fn main(
275
282
// Load BK x BN by numThread(BM * BN / (TM * TN))
276
283
// The number of iteration == BK * BN / (BM * BN / (TM * TN))
277
284
for (var idx: u32 = 0; idx < {{NUM_TILEB}}; idx++) {
278
- tileB[localID.x + idx * numThread] = b[bPtr + ((localID.x + idx * numThread) / {{BK}}) * {{K}} + ((localID.x + idx * numThread) % {{BK}})];
285
+ if ({{TRANSPOSE}}) {
286
+ tileB[localID.x + idx * numThread] = b[bPtr + ((localID.x + idx * numThread) / {{BN}}) * {{N}} + ((localID.x + idx * numThread) % {{BN}})];
287
+ } else {
288
+ tileB[localID.x + idx * numThread] = b[bPtr + ((localID.x + idx * numThread) / {{BK}}) * {{K}} + ((localID.x + idx * numThread) % {{BK}})];
289
+ }
279
290
}
280
291
281
292
aPtr += {{BK}};
282
- bPtr += {{BK}};
293
+ if ({{TRANSPOSE}}) {
294
+ bPtr += {{BK}} * {{N}};
295
+ } else {
296
+ bPtr += {{BK}};
297
+ }
283
298
284
299
workgroupBarrier();
285
300
// Compute tile
@@ -288,7 +303,11 @@ fn main(
288
303
localM[idx] = tileA[(threadRow + idx) * {{BK}} + dotIdx];
289
304
}
290
305
for (var idx: u32 = 0; idx < {{TN}}; idx++) {
291
- localN[idx] = tileB[(threadCol + idx) * {{BK}} + dotIdx];
306
+ if ({{TRANSPOSE}}) {
307
+ localN[idx] = tileB[dotIdx * {{BN}} + threadCol + idx];
308
+ } else {
309
+ localN[idx] = tileB[(threadCol + idx) * {{BK}} + dotIdx];
310
+ }
292
311
}
293
312
for (var resIdxM: u32 = 0; resIdxM < {{TM}}; resIdxM++) {
294
313
for (var resIdxN: u32 = 0; resIdxN < {{TN}}; resIdxN++) {
@@ -311,6 +330,7 @@ inline KernelCode createMatmul4(const char *shaderTemplate, const size_t M,
311
330
const size_t K, const size_t N, const size_t BM,
312
331
const size_t BK, const size_t BN,
313
332
const size_t TM, const size_t TN,
333
+ const bool transpose = false ,
314
334
const Shape &workgroupSize = {256 , 1 , 1 },
315
335
NumType precision = kf32,
316
336
bool unrolling = false ) {
@@ -333,17 +353,19 @@ inline KernelCode createMatmul4(const char *shaderTemplate, const size_t M,
333
353
{" {{TM}}" , toString (TM)},
334
354
{" {{TN}}" , toString (TN)},
335
355
{" {{NUM_TILEA}}" , toString (BM * BK / num_threads)},
336
- {" {{NUM_TILEB}}" , toString (BN * BK / num_threads)}
356
+ {" {{NUM_TILEB}}" , toString (BN * BK / num_threads)},
357
+ {" {{TRANSPOSE}}" , transpose ? " true" : " false" },
337
358
});
338
359
if (unrolling) {
339
- std::string unrolledCode = loopUnrolling (codeString);
360
+ std::string unrolledCode = loopUnrolling (removeUnnecessaryIfStatements ( codeString) );
340
361
// LOG(kDefLog, kInfo, "Unrolled code:\n%s", unrolledCode.c_str());
341
362
return {unrolledCode, workgroupSize};
342
363
} else {
343
364
return {codeString, workgroupSize};
344
365
}
345
366
}
346
367
368
+
347
369
/* 2D block-tiling with vectorization
348
370
*
349
371
*/
@@ -376,9 +398,14 @@ fn main(
376
398
// incremented in the bkidx loop.
377
399
// cPtr is the starting position of the tile in c which is fixed.
378
400
379
- var aPtr = cRow * {{BM}} * {{K}};
380
- var bPtr = cCol * {{BN}} * {{K}};
381
- let cPtr = cRow * {{BM}} * {{N4}} + cCol * {{BN4}};
401
+ var aPtr: u32 = cRow * {{BM}} * {{K}};
402
+ var bPtr: u32 = 0;
403
+ if ({{TRANSPOSE}}) {
404
+ bPtr = cCol * {{BN}};
405
+ } else {
406
+ bPtr = cCol * {{BN}} * {{K}};
407
+ }
408
+ let cPtr: u32 = cRow * {{BM}} * {{N4}} + cCol * {{BN4}};
382
409
383
410
for (var bkidx = 0; bkidx < {{K}}; bkidx += {{BK}}) {
384
411
@@ -391,11 +418,19 @@ fn main(
391
418
// Load BK x BN by numThread(BM * BN / (TM * TN))
392
419
// The number of iteration == BK * BN / (BM * BN / (TM * TN))
393
420
for (var idx: u32 = 0; idx < {{NUM_TILEB}}; idx++) {
394
- tileB[localID.x + idx * numThread] = b[bPtr + ((localID.x + idx * numThread) / {{BK}}) * {{K}} + ((localID.x + idx * numThread) % {{BK}})];
421
+ if ({{TRANSPOSE}}) {
422
+ tileB[localID.x + idx * numThread] = b[bPtr + ((localID.x + idx * numThread) / {{BN}}) * {{N}} + ((localID.x + idx * numThread) % {{BN}})];
423
+ } else {
424
+ tileB[localID.x + idx * numThread] = b[bPtr + ((localID.x + idx * numThread) / {{BK}}) * {{K}} + ((localID.x + idx * numThread) % {{BK}})];
425
+ }
395
426
}
396
427
397
428
aPtr += {{BK}};
398
- bPtr += {{BK}};
429
+ if ({{TRANSPOSE}}) {
430
+ bPtr += {{BK}} * {{N}};
431
+ } else {
432
+ bPtr += {{BK}};
433
+ }
399
434
400
435
workgroupBarrier();
401
436
// Compute tile
@@ -404,10 +439,17 @@ fn main(
404
439
localM[idx] = tileA[(threadRow + idx) * {{BK}} + dotIdx];
405
440
}
406
441
for (var idx: u32 = 0; idx < {{TN4}}; idx++) {
407
- localN[idx] = vec4<{{precision}}>(tileB[(threadCol + idx*4 ) * {{BK}} + dotIdx],
408
- tileB[(threadCol + idx*4 + 1) * {{BK}} + dotIdx],
409
- tileB[(threadCol + idx*4 + 2) * {{BK}} + dotIdx],
410
- tileB[(threadCol + idx*4 + 3) * {{BK}} + dotIdx]);
442
+ if ({{TRANSPOSE}}) {
443
+ localN[idx] = vec4<{{precision}}>(tileB[(threadCol + idx*4 ) + dotIdx * {{BN}}],
444
+ tileB[(threadCol + idx*4 + 1) + dotIdx * {{BN}}],
445
+ tileB[(threadCol + idx*4 + 2) + dotIdx * {{BN}}],
446
+ tileB[(threadCol + idx*4 + 3) + dotIdx * {{BN}}]);
447
+ } else {
448
+ localN[idx] = vec4<{{precision}}>(tileB[(threadCol + idx*4 ) * {{BK}} + dotIdx],
449
+ tileB[(threadCol + idx*4 + 1) * {{BK}} + dotIdx],
450
+ tileB[(threadCol + idx*4 + 2) * {{BK}} + dotIdx],
451
+ tileB[(threadCol + idx*4 + 3) * {{BK}} + dotIdx]);
452
+ }
411
453
}
412
454
for (var resIdxM: u32 = 0; resIdxM < {{TM}}; resIdxM++) {
413
455
for (var resIdxN: u32 = 0; resIdxN < {{TN4}}; resIdxN++) {
@@ -430,6 +472,7 @@ inline KernelCode createMatmulWithVectorization(const char *shaderTemplate, cons
430
472
const size_t K, const size_t N, const size_t BM,
431
473
const size_t BK, const size_t BN,
432
474
const size_t TM, const size_t TN,
475
+ const bool transpose = false ,
433
476
const Shape &workgroupSize = {256 , 1 , 1 },
434
477
NumType precision = kf32,
435
478
bool unrolling = false ) {
@@ -456,9 +499,10 @@ inline KernelCode createMatmulWithVectorization(const char *shaderTemplate, cons
456
499
{" {{TN4}}" , toString (TN / 4 )},
457
500
{" {{N4}}" , toString (N / 4 )},
458
501
{" {{BN4}}" , toString (BN / 4 )},
502
+ {" {{TRANSPOSE}}" , transpose ? " true" : " false" },
459
503
});
460
504
if (unrolling) {
461
- std::string unrolledCode = loopUnrolling (codeString);
505
+ std::string unrolledCode = loopUnrolling (removeUnnecessaryIfStatements ( codeString) );
462
506
// LOG(kDefLog, kInfo, "Unrolled code:\n%s", unrolledCode.c_str());
463
507
return {unrolledCode, workgroupSize};
464
508
} else {
@@ -519,20 +563,26 @@ Kernel selectMatmul(Context &ctx, int version,
519
563
size_t M, size_t K, size_t N) {
520
564
Kernel kernel;
521
565
if (version == 1 ) {
566
+ Shape wgSize = {256 , 1 , 1 };
567
+ Shape nWorkgroups = cdiv ({M, N, 1 }, {16 , 16 , 1 });
568
+ KernelCode matmul = createNoOp (kShaderNoOp , /* wgsize*/ wgSize);
569
+ kernel = createKernel (ctx, matmul, bindings,
570
+ /* nWorkgroups*/ nWorkgroups);
571
+ } else if (version == 2 ) {
522
572
Shape wgSize = {16 , 16 , 1 };
523
573
LOG (kDefLog , kInfo , " wgSize: %s" , toString (wgSize).c_str ());
524
574
KernelCode matmul =
525
575
createMatmul1 (kShaderMatmul1 , M, K, N, /* wgsize*/ wgSize);
526
576
kernel = createKernel (ctx, matmul, bindings,
527
577
/* nWorkgroups*/ cdiv ({M, N, 1 }, wgSize));
528
- } else if (version == 2 ) {
578
+ } else if (version == 3 ) {
529
579
static constexpr size_t tileSize = 16 ;
530
580
KernelCode matmul = createMatmul2 (kShaderMatmul2 , M, K, N,
531
581
/* wgSize*/ {tileSize * tileSize, 1 , 1 });
532
582
kernel =
533
583
createKernel (ctx, matmul, bindings,
534
584
/* nWorkgroups*/ cdiv ({M, N, 1 }, {tileSize, tileSize, 1 }));
535
- } else if (version == 3 || version == 5 ) {
585
+ } else if (version == 4 || version == 6 ) {
536
586
static constexpr size_t BM = 64 ;
537
587
static constexpr size_t BK = 4 ;
538
588
static constexpr size_t BN = BM;
@@ -548,10 +598,10 @@ Kernel selectMatmul(Context &ctx, int version,
548
598
KernelCode matmul = createMatmul3 (kShaderMatmul3 , M, K, N, BM, BK, BN, TM,
549
599
/* wgSize*/ wgSize,
550
600
kf32,
551
- /* Loop unrolling*/ version == 5 ? true : false );
601
+ /* Loop unrolling*/ version == 6 ? true : false );
552
602
kernel = createKernel (ctx, matmul, bindings,
553
603
/* nWorkgroups*/ nWorkgroups);
554
- } else if (version == 4 || version == 6 ) {
604
+ } else if (version == 5 || version == 7 ) {
555
605
static constexpr size_t BM = 64 ;
556
606
static constexpr size_t BK = 8 ;
557
607
static constexpr size_t BN = 64 ;
@@ -564,12 +614,13 @@ Kernel selectMatmul(Context &ctx, int version,
564
614
LOG (kDefLog , kInfo , " wgSize: ( %s )" , toString (wgSize).c_str ());
565
615
LOG (kDefLog , kInfo , " nWorkgroups: ( %s )" , toString (nWorkgroups).c_str ());
566
616
KernelCode matmul = createMatmul4 (kShaderMatmul4 , M, K, N, BM, BK, BN, TM, TN,
617
+ /* tranpose*/ false ,
567
618
/* wgSize*/ wgSize,
568
619
kf32,
569
- /* Loop unrolling*/ version == 6 ? true : false );
620
+ /* Loop unrolling*/ version == 7 ? true : false );
570
621
kernel = createKernel (ctx, matmul, bindings,
571
622
/* nWorkgroups*/ nWorkgroups);
572
- } else if (version == 7 ) {
623
+ } else if (version == 8 || version == 9 ) {
573
624
static constexpr size_t BM = 64 ;
574
625
static constexpr size_t BK = 8 ;
575
626
static constexpr size_t BN = 64 ;
@@ -582,17 +633,12 @@ Kernel selectMatmul(Context &ctx, int version,
582
633
LOG (kDefLog , kInfo , " wgSize: ( %s )" , toString (wgSize).c_str ());
583
634
LOG (kDefLog , kInfo , " nWorkgroups: ( %s )" , toString (nWorkgroups).c_str ());
584
635
KernelCode matmul = createMatmulWithVectorization (kShaderMatmulWithVectorization , M, K, N, BM, BK, BN, TM, TN,
636
+ /* tranpose*/ version == 9 ,
585
637
/* wgSize*/ wgSize,
586
638
kf32,
587
639
/* Loop unrolling*/ true );
588
640
kernel = createKernel (ctx, matmul, bindings,
589
641
/* nWorkgroups*/ nWorkgroups);
590
- } else if (version == 8 ) {
591
- Shape wgSize = {256 , 1 , 1 };
592
- Shape nWorkgroups = cdiv ({M, N, 1 }, {16 , 16 , 1 });
593
- KernelCode matmul = createNoOp (kShaderNoOp , /* wgsize*/ wgSize);
594
- kernel = createKernel (ctx, matmul, bindings,
595
- /* nWorkgroups*/ nWorkgroups);
596
642
}
597
643
return kernel;
598
644
}
@@ -626,8 +672,8 @@ void runTest(int version, size_t M, size_t K, size_t N,
626
672
627
673
printf (" [ Press enter to start tests ... ]\n " );
628
674
getchar ();
629
- LOG (kDefLog , kInfo , " Dispatching Kernel version %d, %d iterations ..." ,
630
- version, nIter);
675
+ LOG (kDefLog , kInfo , " Dispatching Kernel version %d: %s , %d iterations ..." ,
676
+ version, versionToStr (version), nIter);
631
677
632
678
// Dispatch kernel nIter times
633
679
auto start = std::chrono::high_resolution_clock::now ();
@@ -662,26 +708,43 @@ void runTest(int version, size_t M, size_t K, size_t N,
662
708
M, K, N, nIter, duration.count () / static_cast <double >(nIter) / 1000.0 /* us -> ms */ , gflops);
663
709
}
664
710
711
+ const char * versionToStr (int version){
712
+ switch (version) {
713
+ case 1 : return " No-Op" ;
714
+ case 2 : return " naive matmul" ;
715
+ case 3 : return " tiling" ;
716
+ case 4 : return " 1D blocktiling" ;
717
+ case 5 : return " 2D blocktiling" ;
718
+ case 6 : return " 1D blocktiling with loop unrolling" ;
719
+ case 7 : return " 2D blocktiling with loop unrolling" ;
720
+ case 8 : return " 2D blocktiling with loop unrolling and vectorization" ;
721
+ case 9 : return " 2D blocktiling with loop unrolling, vectorization and transpose" ;
722
+ default : return " Not specified" ;
723
+ }
724
+ }
725
+
665
726
int main () {
666
727
char * version_str = getenv (" MATMUL_VERSION" );
667
- int version = version_str == NULL ? 7 : atoi (version_str);
668
- // 1 == naive matmul
669
- // 2 == tiling
670
- // 3 == 1D blocktiling
671
- // 4 == 2D blocktiling
672
- // 5 == 1D blocktiling with loop unrolling
673
- // 6 == 2D blocktiling with loop unrolling
674
- // 7 == 2D blocktiling with loop unrolling and vectorization
675
- // 8 == No-Op
728
+ char * kTestSize_str = getenv (" MATMUL_SIZE" );
729
+ int version = version_str == NULL ? 9 : atoi (version_str);
730
+ // 1 == No-Op
731
+ // 2 == naive matmul
732
+ // 3 == tiling
733
+ // 4 == 1D blocktiling
734
+ // 5 == 2D blocktiling
735
+ // 6 == 1D blocktiling with loop unrolling
736
+ // 7 == 2D blocktiling with loop unrolling
737
+ // 8 == 2D blocktiling with loop unrolling and vectorization
738
+ // 9 == 2D blocktiling with loop unrolling, vectorization and transpose (default)
676
739
677
740
size_t M, K, N; // Matrix dimensions
678
- static constexpr int kTestSize = 2 ;
679
- if constexpr (kTestSize == 0 ) {
741
+ int kTestSize = kTestSize_str == NULL ? 2 : atoi ( kTestSize_str ) ;
742
+ if (kTestSize == 0 ) {
680
743
// Tiny test
681
744
M = 32 ;
682
745
K = 32 ;
683
746
N = 32 ;
684
- } else if constexpr (kTestSize == 1 ) {
747
+ } else if (kTestSize == 1 ) {
685
748
// Small test
686
749
M = 256 ;
687
750
K = 128 ;
@@ -696,11 +759,19 @@ int main() {
696
759
std::unique_ptr<float []> inputPtr = std::make_unique<float []>(M * K);
697
760
std::unique_ptr<float []> weightsPtr = std::make_unique<float []>(N * K);
698
761
std::unique_ptr<float []> outputPtr = std::make_unique<float []>(M * N);
762
+ bool transposedInput = version == 9 ;
699
763
700
764
initData (M, K, N, inputPtr, weightsPtr);
701
- runTest (version, M, K, N, inputPtr, weightsPtr, outputPtr);
765
+ if (transposedInput) {
766
+ std::unique_ptr<float []> transposedWeightPtr = std::make_unique<float []>(K * N);
767
+ transpose (weightsPtr.get (), transposedWeightPtr.get (), N, K);
768
+ runTest (version, M, K, N, inputPtr, transposedWeightPtr, outputPtr);
769
+ } else {
770
+ runTest (version, M, K, N, inputPtr, weightsPtr, outputPtr);
771
+ }
772
+
702
773
703
- if constexpr (kTestSize <= 1 ) {
774
+ if (kTestSize <= 1 ) {
704
775
// Check result with CPU reference implementation for tiny/small tests
705
776
checkCPU (M, K, N, inputPtr, weightsPtr, outputPtr);
706
777
}
0 commit comments