14
14
15
15
using namespace gpu ;
16
16
17
+ const char * versionToStr (int version);
18
+
17
19
static const char *kShaderMatmul1 = R"(
18
20
@group(0) @binding(0) var<storage, read_write> A: array<{{precision}}>;
19
21
@group(0) @binding(1) var<storage, read_write> B: array<{{precision}}>;
@@ -260,9 +262,9 @@ fn main(
260
262
// incremented in the bkidx loop.
261
263
// cPtr is the starting position of the tile in c which is fixed.
262
264
263
- var aPtr = cRow * {{BM}} * {{K}};
264
- var bPtr = cCol * {{BN}} * {{K}};
265
- let cPtr = cRow * {{BM}} * {{N}} + cCol * {{BN}};
265
+ var aPtr: u32 = cRow * {{BM}} * {{K}};
266
+ var bPtr: u32 = cCol * {{BN}} * {{K}};
267
+ let cPtr: u32 = cRow * {{BM}} * {{N}} + cCol * {{BN}};
266
268
267
269
for (var bkidx = 0; bkidx < {{K}}; bkidx += {{BK}}) {
268
270
@@ -275,7 +277,7 @@ fn main(
275
277
// Load BK x BN by numThread(BM * BN / (TM * TN))
276
278
// The number of iteration == BK * BN / (BM * BN / (TM * TN))
277
279
for (var idx: u32 = 0; idx < {{NUM_TILEB}}; idx++) {
278
- tileB[localID.x + idx * numThread] = b[bPtr + ((localID.x + idx * numThread) / {{BK}}) * {{K}} + ((localID.x + idx * numThread) % {{BK}})];
280
+ tileB[localID.x + idx * numThread] = b[bPtr + ((localID.x + idx * numThread) / {{BK}}) * {{K}} + ((localID.x + idx * numThread) % {{BK}})];
279
281
}
280
282
281
283
aPtr += {{BK}};
@@ -344,6 +346,7 @@ inline KernelCode createMatmul4(const char *shaderTemplate, const size_t M,
344
346
}
345
347
}
346
348
349
+
347
350
/* 2D block-tiling with vectorization
348
351
*
349
352
*/
@@ -376,9 +379,9 @@ fn main(
376
379
// incremented in the bkidx loop.
377
380
// cPtr is the starting position of the tile in c which is fixed.
378
381
379
- var aPtr = cRow * {{BM}} * {{K}};
380
- var bPtr = cCol * {{BN}} * {{K}};
381
- let cPtr = cRow * {{BM}} * {{N4}} + cCol * {{BN4}};
382
+ var aPtr: u32 = cRow * {{BM}} * {{K}};
383
+ var bPtr: u32 = cCol * {{BN}} * {{K}};
384
+ let cPtr: u32 = cRow * {{BM}} * {{N4}} + cCol * {{BN4}};
382
385
383
386
for (var bkidx = 0; bkidx < {{K}}; bkidx += {{BK}}) {
384
387
@@ -455,7 +458,7 @@ inline KernelCode createMatmulWithVectorization(const char *shaderTemplate, cons
455
458
{" {{NUM_TILEB}}" , toString (BN * BK / num_threads)},
456
459
{" {{TN4}}" , toString (TN / 4 )},
457
460
{" {{N4}}" , toString (N / 4 )},
458
- {" {{BN4}}" , toString (BN / 4 )},
461
+ {" {{BN4}}" , toString (BN / 4 )}
459
462
});
460
463
if (unrolling) {
461
464
std::string unrolledCode = loopUnrolling (codeString);
@@ -466,6 +469,123 @@ inline KernelCode createMatmulWithVectorization(const char *shaderTemplate, cons
466
469
}
467
470
}
468
471
472
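/* 2D block-tiling with transpose
 *
 * Unlike the other matmul shaders, `b` is indexed here as a row-major
 * K x N matrix (row stride {{N}}, tile origin cCol * {{BN}}), so the host
 * must upload the transposed weights. `c` is declared as an array of vec4,
 * which is why column offsets are expressed through {{N4}}, {{BN4}} and
 * {{TN4}} (the N, BN and TN values divided by four).
 */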
475
+ static const char *kShaderMatmulWithTranspose = R"(
476
+ @group(0) @binding(0) var<storage, read_write> a: array<{{precision}}>;
477
+ @group(0) @binding(1) var<storage, read_write> b: array<{{precision}}>;
478
+ @group(0) @binding(2) var<storage, read_write> c: array<vec4<{{precision}}>>;
479
+ var<workgroup> tileA: array<{{precision}}, {{BM}} * {{BK}}>;
480
+ var<workgroup> tileB: array<{{precision}}, {{BN}} * {{BK}}>;
481
+
482
+ @compute @workgroup_size({{workgroupSize}})
483
+ fn main(
484
+ @builtin(global_invocation_id) globalID : vec3<u32>,
485
+ @builtin(local_invocation_id) localID : vec3<u32>,
486
+ @builtin(workgroup_id) groupid : vec3<u32>) {
487
+
488
+ var threadResults: array<vec4<{{precision}}>, {{TM}} * {{TN4}}>;
489
+ var localM: array<{{precision}}, {{TM}}>;
490
+ var localN: array<vec4<{{precision}}>, {{TN4}}>;
491
+
492
+ let cRow: u32 = groupid.x;
493
+ let cCol: u32 = groupid.y;
494
+ let numThread: u32 = ({{BM}} * {{BN}}) / ({{TM}} * {{TN}});
495
+
496
+ // position of the first c element computed by the thread
497
+ let threadRow: u32 = (localID.x / ({{BN}} / {{TN}})) * {{TM}};
498
+ let threadCol: u32 = (localID.x % ({{BN}} / {{TN}})) * {{TN}};
499
+
500
+ // aPtr and bPtr are the starting positions of the tiles in a and b,
501
+ // incremented in the bkidx loop.
502
+ // cPtr is the starting position of the tile in c which is fixed.
503
+
504
+ var aPtr: u32 = cRow * {{BM}} * {{K}};
505
+ var bPtr: u32 = cCol * {{BN}};
506
+ let cPtr: u32 = cRow * {{BM}} * {{N4}} + cCol * {{BN4}};
507
+
508
+ for (var bkidx = 0; bkidx < {{K}}; bkidx += {{BK}}) {
509
+
510
+ // Load tile
511
+ // Load BM x BK by numThread(BM * BN / (TM * TN))
512
+ // The number of iteration == BM * BK / (BM * BN / (TM * TN))
513
+ for (var idx: u32 = 0; idx < {{NUM_TILEA}}; idx++) {
514
+ tileA[localID.x + idx * numThread] = a[aPtr + ((localID.x + idx * numThread) / {{BK}}) * {{K}} + (localID.x + idx * numThread) % {{BK}}];
515
+ }
516
+ // Load BK x BN by numThread(BM * BN / (TM * TN))
517
+ // The number of iteration == BK * BN / (BM * BN / (TM * TN))
518
+ for (var idx: u32 = 0; idx < {{NUM_TILEB}}; idx++) {
519
+ tileB[localID.x + idx * numThread] = b[bPtr + ((localID.x + idx * numThread) / {{BN}}) * {{N}} + ((localID.x + idx * numThread) % {{BN}})];
520
+ }
521
+
522
+ aPtr += {{BK}};
523
+ bPtr += {{BK}} * {{N}};
524
+
525
+ workgroupBarrier();
526
+ // Compute tile
527
+ for (var dotIdx: u32 = 0; dotIdx < {{BK}}; dotIdx = dotIdx + 1) {
528
+ for (var idx: u32 = 0; idx < {{TM}}; idx++) {
529
+ localM[idx] = tileA[(threadRow + idx) * {{BK}} + dotIdx];
530
+ }
531
+ for (var idx: u32 = 0; idx < {{TN4}}; idx++) {
532
+ localN[idx] = vec4<{{precision}}>(tileB[(threadCol + idx*4 ) + dotIdx * {{BN}}],
533
+ tileB[(threadCol + idx*4 + 1) + dotIdx * {{BN}}],
534
+ tileB[(threadCol + idx*4 + 2) + dotIdx * {{BN}}],
535
+ tileB[(threadCol + idx*4 + 3) + dotIdx * {{BN}}]);
536
+ }
537
+ for (var resIdxM: u32 = 0; resIdxM < {{TM}}; resIdxM++) {
538
+ for (var resIdxN: u32 = 0; resIdxN < {{TN4}}; resIdxN++) {
539
+ threadResults[resIdxM * {{TN4}} + resIdxN] += localM[resIdxM] * localN[resIdxN];
540
+ }
541
+ }
542
+ }
543
+ workgroupBarrier();
544
+ }
545
+
546
+ for (var resIdxM: u32 = 0; resIdxM < {{TM}}; resIdxM++) {
547
+ for (var resIdxN: u32 = 0; resIdxN < {{TN4}}; resIdxN++) {
548
+ c[cPtr + (threadRow + resIdxM) * {{N4}} + (threadCol/4) + resIdxN] = threadResults[resIdxM * {{TN4}} + resIdxN];
549
+ }
550
+ }
551
+ }
552
+ )" ;
553
+
554
+ inline KernelCode createMatmulWithTranspose (const char *shaderTemplate, const size_t M,
555
+ const size_t K, const size_t N, const size_t BM,
556
+ const size_t BK, const size_t BN,
557
+ const size_t TM, const size_t TN,
558
+ const Shape &workgroupSize = {256 , 1 , 1 },
559
+ NumType precision = kf32) {
560
+ assert (BM % TM == 0 );
561
+ assert (BN % TN == 0 );
562
+ assert (K % BK == 0 );
563
+ assert (M % BM == 0 );
564
+ assert (N % BN == 0 );
565
+ // # threads = tile A size == tile B size == # threads for computing C
566
+ int num_threads = BM * BN / (TM * TN);
567
+ std::string codeString (shaderTemplate);
568
+ replaceAll (codeString, {{" {{workgroupSize}}" , toString (workgroupSize)},
569
+ {" {{precision}}" , toString (precision)},
570
+ {" {{M}}" , toString (M)},
571
+ {" {{K}}" , toString (K)},
572
+ {" {{N}}" , toString (N)},
573
+ {" {{BM}}" , toString (BM)},
574
+ {" {{BK}}" , toString (BK)},
575
+ {" {{BN}}" , toString (BN)},
576
+ {" {{TM}}" , toString (TM)},
577
+ {" {{TN}}" , toString (TN)},
578
+ {" {{NUM_TILEA}}" , toString (BM * BK / num_threads)},
579
+ {" {{NUM_TILEB}}" , toString (BN * BK / num_threads)},
580
+ {" {{TN4}}" , toString (TN / 4 )},
581
+ {" {{N4}}" , toString (N / 4 )},
582
+ {" {{BN4}}" , toString (BN / 4 )},
583
+ });
584
+ std::string unrolledCode = loopUnrolling (codeString);
585
+ // LOG(kDefLog, kInfo, "Unrolled code:\n%s", unrolledCode.c_str());
586
+ return {unrolledCode, workgroupSize};
587
+ }
588
+
469
589
/* *
470
590
* @brief No-Op shader with matmul bindings for performance testing
471
591
*/
@@ -519,20 +639,26 @@ Kernel selectMatmul(Context &ctx, int version,
519
639
size_t M, size_t K, size_t N) {
520
640
Kernel kernel;
521
641
if (version == 1 ) {
642
+ Shape wgSize = {256 , 1 , 1 };
643
+ Shape nWorkgroups = cdiv ({M, N, 1 }, {16 , 16 , 1 });
644
+ KernelCode matmul = createNoOp (kShaderNoOp , /* wgsize*/ wgSize);
645
+ kernel = createKernel (ctx, matmul, bindings,
646
+ /* nWorkgroups*/ nWorkgroups);
647
+ } else if (version == 2 ) {
522
648
Shape wgSize = {16 , 16 , 1 };
523
649
LOG (kDefLog , kInfo , " wgSize: %s" , toString (wgSize).c_str ());
524
650
KernelCode matmul =
525
651
createMatmul1 (kShaderMatmul1 , M, K, N, /* wgsize*/ wgSize);
526
652
kernel = createKernel (ctx, matmul, bindings,
527
653
/* nWorkgroups*/ cdiv ({M, N, 1 }, wgSize));
528
- } else if (version == 2 ) {
654
+ } else if (version == 3 ) {
529
655
static constexpr size_t tileSize = 16 ;
530
656
KernelCode matmul = createMatmul2 (kShaderMatmul2 , M, K, N,
531
657
/* wgSize*/ {tileSize * tileSize, 1 , 1 });
532
658
kernel =
533
659
createKernel (ctx, matmul, bindings,
534
660
/* nWorkgroups*/ cdiv ({M, N, 1 }, {tileSize, tileSize, 1 }));
535
- } else if (version == 3 || version == 5 ) {
661
+ } else if (version == 4 || version == 6 ) {
536
662
static constexpr size_t BM = 64 ;
537
663
static constexpr size_t BK = 4 ;
538
664
static constexpr size_t BN = BM;
@@ -548,10 +674,10 @@ Kernel selectMatmul(Context &ctx, int version,
548
674
KernelCode matmul = createMatmul3 (kShaderMatmul3 , M, K, N, BM, BK, BN, TM,
549
675
/* wgSize*/ wgSize,
550
676
kf32,
551
- /* Loop unrolling*/ version == 5 ? true : false );
677
+ /* Loop unrolling*/ version == 6 ? true : false );
552
678
kernel = createKernel (ctx, matmul, bindings,
553
679
/* nWorkgroups*/ nWorkgroups);
554
- } else if (version == 4 || version == 6 ) {
680
+ } else if (version == 5 || version == 7 ) {
555
681
static constexpr size_t BM = 64 ;
556
682
static constexpr size_t BK = 8 ;
557
683
static constexpr size_t BN = 64 ;
@@ -566,10 +692,10 @@ Kernel selectMatmul(Context &ctx, int version,
566
692
KernelCode matmul = createMatmul4 (kShaderMatmul4 , M, K, N, BM, BK, BN, TM, TN,
567
693
/* wgSize*/ wgSize,
568
694
kf32,
569
- /* Loop unrolling*/ version == 6 ? true : false );
695
+ /* Loop unrolling*/ version == 7 ? true : false );
570
696
kernel = createKernel (ctx, matmul, bindings,
571
697
/* nWorkgroups*/ nWorkgroups);
572
- } else if (version == 7 ) {
698
+ } else if (version == 8 ) {
573
699
static constexpr size_t BM = 64 ;
574
700
static constexpr size_t BK = 8 ;
575
701
static constexpr size_t BN = 64 ;
@@ -587,10 +713,21 @@ Kernel selectMatmul(Context &ctx, int version,
587
713
/* Loop unrolling*/ true );
588
714
kernel = createKernel (ctx, matmul, bindings,
589
715
/* nWorkgroups*/ nWorkgroups);
590
- } else if (version == 8 ) {
591
- Shape wgSize = {256 , 1 , 1 };
592
- Shape nWorkgroups = cdiv ({M, N, 1 }, {16 , 16 , 1 });
593
- KernelCode matmul = createNoOp (kShaderNoOp , /* wgsize*/ wgSize);
716
+ } else if (version == 9 ) {
717
+ static constexpr size_t BM = 64 ;
718
+ static constexpr size_t BK = 8 ;
719
+ static constexpr size_t BN = 64 ;
720
+ static constexpr size_t TM = BM / BK;
721
+ static constexpr size_t TN = BN / BK;
722
+ Shape wgSize = {(BM / TM) * (BN / TN), 1 , 1 }; // This is the same as BK * BK.
723
+ Shape nWorkgroups = {cdiv (M, BM), cdiv (N, BN), 1 };
724
+ LOG (kDefLog , kInfo , " M: %d, K: %d, N: %d" , M, K, N);
725
+ LOG (kDefLog , kInfo , " BM: %d, BK: %d, BN: %d, TM: %d, TN: %d" , BM, BK, BN, TM, TN);
726
+ LOG (kDefLog , kInfo , " wgSize: ( %s )" , toString (wgSize).c_str ());
727
+ LOG (kDefLog , kInfo , " nWorkgroups: ( %s )" , toString (nWorkgroups).c_str ());
728
+ KernelCode matmul = createMatmulWithTranspose (kShaderMatmulWithTranspose , M, K, N, BM, BK, BN, TM, TN,
729
+ /* wgSize*/ wgSize,
730
+ kf32);
594
731
kernel = createKernel (ctx, matmul, bindings,
595
732
/* nWorkgroups*/ nWorkgroups);
596
733
}
@@ -626,8 +763,8 @@ void runTest(int version, size_t M, size_t K, size_t N,
626
763
627
764
printf (" [ Press enter to start tests ... ]\n " );
628
765
getchar ();
629
- LOG (kDefLog , kInfo , " Dispatching Kernel version %d, %d iterations ..." ,
630
- version, nIter);
766
+ LOG (kDefLog , kInfo , " Dispatching Kernel version %d: %s , %d iterations ..." ,
767
+ version, versionToStr (version), nIter);
631
768
632
769
// Dispatch kernel nIter times
633
770
auto start = std::chrono::high_resolution_clock::now ();
@@ -662,26 +799,43 @@ void runTest(int version, size_t M, size_t K, size_t N,
662
799
M, K, N, nIter, duration.count () / static_cast <double >(nIter) / 1000.0 /* us -> ms */ , gflops);
663
800
}
664
801
802
/**
 * @brief Returns a human-readable label for a matmul kernel version number.
 *
 * @param version Kernel version selector (1-9).
 * @return Pointer to a string literal naming the kernel, or "Not specified"
 *         for any value outside 1-9. The pointer is never null and must not
 *         be freed.
 */
const char *versionToStr(int version) {
  switch (version) {
  case 1:
    return "No-Op";
  case 2:
    return "naive matmul";
  case 3:
    return "tiling";
  case 4:
    return "1D blocktiling";
  case 5:
    return "2D blocktiling";
  case 6:
    return "1D blocktiling with loop unrolling";
  case 7:
    return "2D blocktiling with loop unrolling";
  case 8:
    return "2D blocktiling with loop unrolling and vectorization";
  case 9:
    return "2D blocktiling with loop unrolling, vectorization and transpose";
  default:
    return "Not specified";
  }
}
816
+
665
817
int main () {
666
818
char * version_str = getenv (" MATMUL_VERSION" );
667
- int version = version_str == NULL ? 7 : atoi (version_str);
668
- // 1 == naive matmul
669
- // 2 == tiling
670
- // 3 == 1D blocktiling
671
- // 4 == 2D blocktiling
672
- // 5 == 1D blocktiling with loop unrolling
673
- // 6 == 2D blocktiling with loop unrolling
674
- // 7 == 2D blocktiling with loop unrolling and vectorization
675
- // 8 == No-Op
819
+ char * kTestSize_str = getenv (" MATMUL_SIZE" );
820
+ int version = version_str == NULL ? 9 : atoi (version_str);
821
+ // 1 == No-Op
822
+ // 2 == naive matmul
823
+ // 3 == tiling
824
+ // 4 == 1D blocktiling
825
+ // 5 == 2D blocktiling
826
+ // 6 == 1D blocktiling with loop unrolling
827
+ // 7 == 2D blocktiling with loop unrolling
828
+ // 8 == 2D blocktiling with loop unrolling and vectorization
829
+ // 9 == 2D blocktiling with loop unrolling, vectorization and transpose (default)
676
830
677
831
size_t M, K, N; // Matrix dimensions
678
- static constexpr int kTestSize = 2 ;
679
- if constexpr (kTestSize == 0 ) {
832
+ int kTestSize = kTestSize_str == NULL ? 2 : atoi ( kTestSize_str ) ;
833
+ if (kTestSize == 0 ) {
680
834
// Tiny test
681
835
M = 32 ;
682
836
K = 32 ;
683
837
N = 32 ;
684
- } else if constexpr (kTestSize == 1 ) {
838
+ } else if (kTestSize == 1 ) {
685
839
// Small test
686
840
M = 256 ;
687
841
K = 128 ;
@@ -696,11 +850,19 @@ int main() {
696
850
std::unique_ptr<float []> inputPtr = std::make_unique<float []>(M * K);
697
851
std::unique_ptr<float []> weightsPtr = std::make_unique<float []>(N * K);
698
852
std::unique_ptr<float []> outputPtr = std::make_unique<float []>(M * N);
853
+ bool transposedInput = version == 9 ;
699
854
700
855
initData (M, K, N, inputPtr, weightsPtr);
701
- runTest (version, M, K, N, inputPtr, weightsPtr, outputPtr);
856
+ if (transposedInput) {
857
+ std::unique_ptr<float []> transposedWeightPtr = std::make_unique<float []>(K * N);
858
+ transpose (weightsPtr.get (), transposedWeightPtr.get (), N, K);
859
+ runTest (version, M, K, N, inputPtr, transposedWeightPtr, outputPtr);
860
+ } else {
861
+ runTest (version, M, K, N, inputPtr, weightsPtr, outputPtr);
862
+ }
863
+
702
864
703
- if constexpr (kTestSize <= 1 ) {
865
+ if (kTestSize <= 1 ) {
704
866
// Check result with CPU reference implementation for tiny/small tests
705
867
checkCPU (M, K, N, inputPtr, weightsPtr, outputPtr);
706
868
}
0 commit comments