#include "utils/array_utils.h"  // show, isclose, randn, randint
#include "utils/logging.h"      // LOG
#include "experimental/wgsl.h"  // loopUnrolling
+ #include "numeric_types/half.h"

using namespace gpu;

const std::string versionToStr(int version);

+ void matmulf16_forward_cpu(half* out,
+                            const half* inp, const half* weight, const half* bias,
+                            int B, int T, int C, int OC) {
+   // OC is short for "output channels"
+   // inp is (B,T,C), weight is (OC, C)
+   // out will be (B,T,OC)
+ #pragma omp parallel for collapse(2)
+   for (int b = 0; b < B; b++) {
+     for (int t = 0; t < T; t++) {
+       half* out_bt = out + b * T * OC + t * OC;
+       const half* inp_bt = inp + b * T * C + t * C;
+       for (int o = 0; o < OC; o++) {
+         float val = (bias != NULL) ? halfToFloat(bias[o]) : 0.0f;
+         const half* wrow = weight + o * C;
+         for (int i = 0; i < C; i++) {
+           val += halfToFloat(inp_bt[i]) * halfToFloat(wrow[i]);
+         }
+         out_bt[o] = val;
+       }
+     }
+   }
+ }
+
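A minimal usage sketch for the new f16 reference path (not part of the diff; it assumes half is assignable from float, the same implicit conversion the `out_bt[o] = val;` line above relies on):

#include "numeric_types/half.h" // half, halfToFloat
#include <cstdio>

int main() {
  // Tiny shapes: batch 1, 2 tokens, 3 input channels, 2 output channels.
  constexpr int B = 1, T = 2, C = 3, OC = 2;
  half inp[B * T * C], weight[OC * C], out[B * T * OC];
  for (int i = 0; i < B * T * C; i++) inp[i] = 1.0f;  // assumed float -> half
  for (int i = 0; i < OC * C; i++) weight[i] = 0.5f;
  matmulf16_forward_cpu(out, inp, weight, /*bias=*/NULL, B, T, C, OC);
  printf("out[0] = %f\n", halfToFloat(out[0])); // 3 * (1.0 * 0.5) = 1.5
  return 0;
}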
static const char *kShaderMatmul1 = R"(
@group(0) @binding(0) var<storage, read_write> A: array<{{precision}}>;
@group(0) @binding(1) var<storage, read_write> B: array<{{precision}}>;
@@ -47,7 +71,7 @@ inline KernelCode createMatmul1(const char *shaderTemplate, const size_t M,
                          {"{{M}}", toString(M)},
                          {"{{K}}", toString(K)},
                          {"{{N}}", toString(N)}});
-   return {codeString, workgroupSize};
+   return {codeString, workgroupSize, precision};
}
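Every createMatmul* change below follows this one pattern: the precision is now passed as a third member of the returned aggregate instead of being left to its default. A self-contained sketch with hypothetical stand-ins for the gpu.h types (field names are inferred from the initializer order, not quoted from gpu.h):

#include <array>
#include <cstddef>
#include <string>

using Shape = std::array<std::size_t, 3>;  // stand-in for gpu.h's Shape
enum NumType { kf32, kf16 };               // stand-in for gpu.h's NumType

struct KernelCode {
  std::string data;         // WGSL source after {{...}} substitution
  Shape workgroupSize;      // e.g. {256, 1, 1}
  NumType precision = kf32; // a two-element return leaves this defaulted
};

// The pattern each createMatmul* in this commit now follows:
KernelCode makeCode(const std::string &src, Shape wg, NumType p) {
  return {src, wg, p};
}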

// Shared memory cache-blocking
@@ -108,7 +132,7 @@ inline KernelCode createMatmul2(const char *shaderTemplate, const size_t M,
                          {"{{N}}", toString(N)},
                          {"{{tileSize}}",
                           toString(static_cast<size_t>(sqrt(workgroupSize[0])))}});
-   return {codeString, workgroupSize};
+   return {codeString, workgroupSize, precision};
}

/* 1D block-tiling
@@ -224,9 +248,9 @@ inline KernelCode createMatmul3(const char *shaderTemplate, const size_t M,
  if (unrolling) {
    std::string unrolledCode = loopUnrolling(codeString);
    // LOG(kDefLog, kInfo, "Unrolled code:\n%s", unrolledCode.c_str());
-   return {unrolledCode, workgroupSize};
+   return {unrolledCode, workgroupSize, precision};
  } else {
-   return {codeString, workgroupSize};
+   return {codeString, workgroupSize, precision};
  }
}

@@ -340,9 +364,9 @@ inline KernelCode createMatmul4(const char *shaderTemplate, const size_t M,
  if (unrolling) {
    std::string unrolledCode = loopUnrolling(codeString);
    // LOG(kDefLog, kInfo, "Unrolled code:\n%s", unrolledCode.c_str());
-   return {unrolledCode, workgroupSize};
+   return {unrolledCode, workgroupSize, precision};
  } else {
-   return {codeString, workgroupSize};
+   return {codeString, workgroupSize, precision};
  }
}

@@ -462,9 +486,9 @@ inline KernelCode createMatmulWithVectorization(const char *shaderTemplate, cons
  if (unrolling) {
    std::string unrolledCode = loopUnrolling(codeString);
    // LOG(kDefLog, kInfo, "Unrolled code:\n%s", unrolledCode.c_str());
-   return {unrolledCode, workgroupSize};
+   return {unrolledCode, workgroupSize, precision};
  } else {
-   return {codeString, workgroupSize};
+   return {codeString, workgroupSize, precision};
  }
}

@@ -582,7 +606,7 @@ inline KernelCode createMatmulWithTranspose(const char *shaderTemplate, const si
  });
  std::string unrolledCode = loopUnrolling(codeString);
  // LOG(kDefLog, kInfo, "Unrolled code:\n%s", unrolledCode.c_str());
- return {unrolledCode, workgroupSize};
+ return {unrolledCode, workgroupSize, precision};
}

/**
@@ -604,7 +628,7 @@ inline KernelCode createNoOp(const char *shaderTemplate,
  std::string codeString(shaderTemplate);
  replaceAll(codeString, {{"{{workgroupSize}}", toString(workgroupSize)},
                          {"{{precision}}", toString(precision)}});
- return {codeString, workgroupSize};
+ return {codeString, workgroupSize, precision};
}

void initData(size_t M, size_t K, size_t N, std::unique_ptr<float[]> &inputPtr,
@@ -619,23 +643,41 @@ void initData(size_t M, size_t K, size_t N, std::unique_ptr<float[]> &inputPtr,
      show<float>(weightsPtr.get(), N, K, "Weights").c_str());
}

- void checkCPU(size_t M, size_t K, size_t N, std::unique_ptr<float[]> &inputPtr,
-               std::unique_ptr<float[]> &weightsPtr,
-               std::unique_ptr<float[]> &outputPtr) {
+ void initData(size_t M, size_t K, size_t N, std::unique_ptr<half[]> &inputPtr,
+               std::unique_ptr<half[]> &weightsPtr) {
+   std::mt19937 gen(314159);
+   randn(inputPtr.get(), M * K, gen);
+   randn(weightsPtr.get(), N * K, gen);
+   // randint(inputPtr.get(), M * K, gen, 1, 2);
+   // randint(weightsPtr.get(), N * K, gen, 1, 2);
+   LOG(kDefLog, kInfo, "%s", show<half>(inputPtr.get(), M, K, "Input").c_str());
+   LOG(kDefLog, kInfo, "%s",
+       show<half>(weightsPtr.get(), N, K, "Weights").c_str());
+ }
+
+ template <class precision = float>
+ void checkCPU(size_t M, size_t K, size_t N, std::unique_ptr<precision[]> &inputPtr,
+               std::unique_ptr<precision[]> &weightsPtr,
+               std::unique_ptr<precision[]> &outputPtr) {
    LOG(kDefLog, kInfo, "Computing CPU reference implementation");
-   std::unique_ptr<float[]> outputRefPtr = std::make_unique<float[]>(M * N);
-   ref::matmul_forward_cpu(outputRefPtr.get(), inputPtr.get(), weightsPtr.get(),
-                           nullptr, 1, M, K, N);
+   std::unique_ptr<precision[]> outputRefPtr = std::make_unique<precision[]>(M * N);
+   if constexpr (std::is_same<precision, float>::value) {
+     ref::matmul_forward_cpu(outputRefPtr.get(), inputPtr.get(), weightsPtr.get(),
+                             nullptr, 1, M, K, N);
+   } else if constexpr (std::is_same<precision, half>::value) {
+     matmulf16_forward_cpu(outputRefPtr.get(), inputPtr.get(), weightsPtr.get(),
+                           nullptr, 1, M, K, N);
+   }
    LOG(kDefLog, kInfo, "Reference Output: %s",
-       show<float>(outputRefPtr.get(), M, N, "Output (Reference)").c_str());
+       show<precision>(outputRefPtr.get(), M, N, "Output (Reference)").c_str());
    LOG(kDefLog, kInfo,
        isclose(outputPtr.get(), outputRefPtr.get(), M * N) ? "CPU Check: PASS"
                                                            : "CPU Check: FAIL");
}
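checkCPU now selects its reference implementation at compile time: `if constexpr` discards the branch that doesn't match, so the f32 path is never instantiated for half and vice versa. A self-contained toy of the same dispatch pattern, with half_t as a hypothetical stand-in for gpu.cpp's half:

#include <cstdio>
#include <type_traits>

struct half_t {}; // hypothetical stand-in for gpu.cpp's half

template <class precision = float>
void refMatmul() {
  if constexpr (std::is_same<precision, float>::value) {
    puts("f32 reference path");
  } else if constexpr (std::is_same<precision, half_t>::value) {
    puts("f16 reference path");
  }
}

int main() {
  refMatmul<half_t>(); // prints "f16 reference path"
  return 0;
}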

Kernel selectMatmul(Context &ctx, int version,
                    const Bindings</* input, weights, output */ 3> &bindings,
-                   size_t M, size_t K, size_t N) {
+                   size_t M, size_t K, size_t N, NumType numtype) {
  Kernel kernel;
  if (version == 1) {
    Shape wgSize = {256, 1, 1};
@@ -647,13 +689,13 @@ Kernel selectMatmul(Context &ctx, int version,
    Shape wgSize = {16, 16, 1};
    LOG(kDefLog, kInfo, "wgSize: %s", toString(wgSize).c_str());
    KernelCode matmul =
-       createMatmul1(kShaderMatmul1, M, K, N, /*wgsize*/ wgSize);
+       createMatmul1(kShaderMatmul1, M, K, N, /*wgsize*/ wgSize, numtype);
    kernel = createKernel(ctx, matmul, bindings,
                          /*nWorkgroups*/ cdiv({M, N, 1}, wgSize));
  } else if (version == 3) {
    static constexpr size_t tileSize = 16;
    KernelCode matmul = createMatmul2(kShaderMatmul2, M, K, N,
-                                     /*wgSize*/ {tileSize * tileSize, 1, 1});
+                                     /*wgSize*/ {tileSize * tileSize, 1, 1}, numtype);
    kernel =
        createKernel(ctx, matmul, bindings,
                     /*nWorkgroups*/ cdiv({M, N, 1}, {tileSize, tileSize, 1}));
@@ -672,7 +714,7 @@ Kernel selectMatmul(Context &ctx, int version,
    LOG(kDefLog, kInfo, "nWorkgroups: ( %s )", toString(nWorkgroups).c_str());
    KernelCode matmul = createMatmul3(kShaderMatmul3, M, K, N, BM, BK, BN, TM,
                                      /*wgSize*/ wgSize,
-                                     kf32,
+                                     numtype,
                                      /*Loop unrolling*/ version == 6 ? true : false);
    kernel = createKernel(ctx, matmul, bindings,
                          /*nWorkgroups*/ nWorkgroups);
@@ -690,11 +732,11 @@ Kernel selectMatmul(Context &ctx, int version,
    LOG(kDefLog, kInfo, "nWorkgroups: ( %s )", toString(nWorkgroups).c_str());
    KernelCode matmul = createMatmul4(kShaderMatmul4, M, K, N, BM, BK, BN, TM, TN,
                                      /*wgSize*/ wgSize,
-                                     kf32,
+                                     numtype,
                                      /*Loop unrolling*/ version == 7 ? true : false);
    kernel = createKernel(ctx, matmul, bindings,
                          /*nWorkgroups*/ nWorkgroups);
- } else if (version == 8) {
+ } else if (version == 8 || version == 10) {
    static constexpr size_t BM = 64;
    static constexpr size_t BK = 8;
    static constexpr size_t BN = 64;
@@ -708,11 +750,11 @@ Kernel selectMatmul(Context &ctx, int version,
    LOG(kDefLog, kInfo, "nWorkgroups: ( %s )", toString(nWorkgroups).c_str());
    KernelCode matmul = createMatmulWithVectorization(kShaderMatmulWithVectorization, M, K, N, BM, BK, BN, TM, TN,
                                                      /*wgSize*/ wgSize,
-                                                     kf32,
+                                                     numtype,
                                                      /*Loop unrolling*/ true);
    kernel = createKernel(ctx, matmul, bindings,
                          /*nWorkgroups*/ nWorkgroups);
- } else if (version == 9) {
+ } else if (version == 9 || version == 11) {
    static constexpr size_t BM = 64;
    static constexpr size_t BK = 8;
    static constexpr size_t BN = 64;
@@ -726,23 +768,36 @@ Kernel selectMatmul(Context &ctx, int version,
    LOG(kDefLog, kInfo, "nWorkgroups: ( %s )", toString(nWorkgroups).c_str());
    KernelCode matmul = createMatmulWithTranspose(kShaderMatmulWithTranspose, M, K, N, BM, BK, BN, TM, TN,
                                                  /*wgSize*/ wgSize,
-                                                 kf32);
+                                                 numtype);
    kernel = createKernel(ctx, matmul, bindings,
                          /*nWorkgroups*/ nWorkgroups);
  }
  return kernel;
}

+ template <class precision = float>
  void runTest(int version, size_t M, size_t K, size_t N,
-              std::unique_ptr<float[]> &inputPtr,
-              std::unique_ptr<float[]> &weightsPtr,
-              std::unique_ptr<float[]> &outputPtr) {
+              std::unique_ptr<precision[]> &inputPtr,
+              std::unique_ptr<precision[]> &weightsPtr,
+              std::unique_ptr<precision[]> &outputPtr,
+              NumType numtype) {
+   if constexpr (std::is_same<precision, float>::value) {
+     assert(numtype == kf32);
+   } else if constexpr (std::is_same<precision, half>::value) {
+     assert(numtype == kf16);
+   }

    // Allocate GPU buffers and copy data
-   Context ctx = createContext();
-   Tensor input = createTensor(ctx, Shape{M, K}, kf32, inputPtr.get());
-   Tensor weights =
-       createTensor(ctx, Shape{N, K}, kf32, weightsPtr.get()); // column-major
+   Context ctx = createContext(
+       {}, {},
+       /*device descriptor, enabling f16 in WGSL*/
+       {
+           .requiredFeatureCount = 1,
+           .requiredFeatures = std::array{WGPUFeatureName_ShaderF16}.data(),
+       });
+
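Requesting WGPUFeatureName_ShaderF16 at context creation is what allows the generated WGSL to use f16 at all. As a sketch, once {{precision}} is substituted with f16 the shader must lead with the standard `enable f16;` directive; the exact preamble the templates emit is an assumption here:

// Sketch only: the preamble a {{precision}} == f16 shader needs.
// "enable f16;" is the standard WGSL directive gated by this feature;
// the exact text the templates generate is an assumption.
static const char *kF16PreambleSketch = R"(
enable f16;
@group(0) @binding(0) var<storage, read_write> A: array<f16>;
@group(0) @binding(1) var<storage, read_write> B: array<f16>;
)";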
+   Tensor input = createTensor(ctx, Shape{M, K}, numtype, inputPtr.get());
+   Tensor weights = createTensor(ctx, Shape{N, K}, numtype, weightsPtr.get()); // column-major

    constexpr size_t nIter = 30;

@@ -756,8 +811,8 @@ void runTest(int version, size_t M, size_t K, size_t N,
    std::array<Tensor, nIter> outputs;
    for (int i = 0; i < nIter; i++) {
      futures[i] = promises[i].get_future();
-     outputs[i] = createTensor(ctx, Shape{M, N}, kf32);
-     kernels[i] = selectMatmul(ctx, version, {input, weights, outputs[i]}, M, K, N);
+     outputs[i] = createTensor(ctx, Shape{M, N}, numtype);
+     kernels[i] = selectMatmul(ctx, version, {input, weights, outputs[i]}, M, K, N, numtype);
    }

    printf("[ Press enter to start tests ... ]\n");
@@ -785,9 +840,9 @@ void runTest(int version, size_t M, size_t K, size_t N,
                   1000000000.0 * static_cast<float>(nIter);

    LOG(kDefLog, kInfo, "Copying result to CPU");
-   toCPU(ctx, outputs[0], outputPtr.get(), M * N * sizeof(float));
+   toCPU(ctx, outputs[0], outputPtr.get(), M * N * sizeof(precision));
    LOG(kDefLog, kInfo, "%s",
-       show<float>(outputPtr.get(), M, N, "Output[0]").c_str());
+       show<precision>(outputPtr.get(), M, N, "Output[0]").c_str());

    LOG(kDefLog, kInfo, "\n\n===================================================================="
        "============\nExecution Time: (M = %d, K = %d, N = %d) x %d iterations "
@@ -798,33 +853,62 @@ void runTest(int version, size_t M, size_t K, size_t N,
        M, K, N, nIter, duration.count() / static_cast<double>(nIter) / 1000.0 /* us -> ms */, gflops);
}
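For reference, the gflops figure logged above follows the standard 2·M·K·N flop count per matmul (each multiply-accumulate counts as two flops). A worked example with hypothetical timings:

#include <cstdio>

int main() {
  // Hypothetical run: M = K = N = 4096, 30 iterations in 1.2 s total.
  double M = 4096, K = 4096, N = 4096, nIter = 30, seconds = 1.2;
  double flops = 2.0 * M * K * N * nIter;           // 2mnk per matmul
  printf("%.0f GFLOPS\n", flops / seconds / 1.0e9); // ~3436 GFLOPS
  return 0;
}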

+ template <class precision = float>
+ void runTestWithCheck(int version, size_t M, size_t K, size_t N,
+                       bool transposedInput, int kTestSize, NumType numtype) {
+   std::unique_ptr<precision[]> inputPtr = std::make_unique<precision[]>(M * K);
+   std::unique_ptr<precision[]> weightsPtr = std::make_unique<precision[]>(N * K);
+   std::unique_ptr<precision[]> outputPtr = std::make_unique<precision[]>(M * N);
+
+   initData(M, K, N, inputPtr, weightsPtr);
+   if (transposedInput) {
+     std::unique_ptr<precision[]> transposedWeightPtr = std::make_unique<precision[]>(K * N);
+     transpose(weightsPtr.get(), transposedWeightPtr.get(), N, K);
+     runTest(version, M, K, N, inputPtr, transposedWeightPtr, outputPtr, numtype);
+   } else {
+     runTest(version, M, K, N, inputPtr, weightsPtr, outputPtr, numtype);
+   }
+
+   if (kTestSize <= 1) {
+     // Check result with CPU reference implementation for tiny/small tests
+     checkCPU(M, K, N, inputPtr, weightsPtr, outputPtr);
+   }
+ }
+
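For the transpose variants (versions 9 and 11), runTestWithCheck first rewrites the (N, K) weight matrix into (K, N) layout. A sketch of that layout change, assuming transpose(src, dst, rows, cols) from utils/array_utils.h has the usual row-major semantics (that reading of the signature is an assumption):

#include <cstdio>

// Assumed behavior of transpose(src, dst, rows, cols): dst becomes (cols, rows).
void transposeSketch(const float *src, float *dst, int rows, int cols) {
  for (int i = 0; i < rows; i++)
    for (int j = 0; j < cols; j++)
      dst[j * rows + i] = src[i * cols + j];
}

int main() {
  const float w[2 * 3] = {1, 2, 3, 4, 5, 6}; // (N=2, K=3), row-major
  float wT[3 * 2];
  transposeSketch(w, wT, 2, 3);              // now (K=3, N=2)
  printf("wT = %g %g %g %g %g %g\n", wT[0], wT[1], wT[2], wT[3], wT[4], wT[5]);
  // prints 1 4 2 5 3 6
  return 0;
}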
const std::string versionToStr(int version){
  switch (version) {
- case 1: return "No-Op";
- case 2: return "naive matmul";
- case 3: return "tiling";
- case 4: return "1D blocktiling";
- case 5: return "2D blocktiling";
- case 6: return "1D blocktiling with loop unrolling";
- case 7: return "2D blocktiling with loop unrolling";
- case 8: return "2D blocktiling with loop unrolling and vectorization";
- case 9: return "2D blocktiling with loop unrolling, vectorization and transpose";
+ case 1: return "f32: No-Op";
+ case 2: return "f32: naive matmul";
+ case 3: return "f32: tiling";
+ case 4: return "f32: 1D blocktiling";
+ case 5: return "f32: 2D blocktiling";
+ case 6: return "f32: 1D blocktiling with loop unrolling";
+ case 7: return "f32: 2D blocktiling with loop unrolling";
+ case 8: return "f32: 2D blocktiling with loop unrolling and vectorization";
+ case 9: return "f32: 2D blocktiling with loop unrolling, vectorization and transpose";
+ case 10: return "f16: 2D blocktiling with loop unrolling and vectorization";
+ case 11: return "f16: 2D blocktiling with loop unrolling, vectorization and transpose";
  default: return "Not specified";
  }
}

int main() {
  char *version_str = getenv("MATMUL_VERSION");
- int version = version_str == NULL ? 9 : atoi(version_str);
- // 1 == No-Op
- // 2 == naive matmul
- // 3 == tiling
- // 4 == 1D blocktiling
- // 5 == 2D blocktiling
- // 6 == 1D blocktiling with loop unrolling
- // 7 == 2D blocktiling with loop unrolling
- // 8 == 2D blocktiling with loop unrolling and vectorization
- // 9 == 2D blocktiling with loop unrolling, vectorization and transpose (default)
+ int version = version_str == NULL ? 10 : atoi(version_str);
+ // 1 == f32: No-Op
+ // 2 == f32: naive matmul
+ // 3 == f32: tiling
+ // 4 == f32: 1D blocktiling
+ // 5 == f32: 2D blocktiling
+ // 6 == f32: 1D blocktiling with loop unrolling
+ // 7 == f32: 2D blocktiling with loop unrolling
+ // 8 == f32: 2D blocktiling with loop unrolling and vectorization
+ // 9 == f32: 2D blocktiling with loop unrolling, vectorization and transpose
+ // 10 == f16: 2D blocktiling with loop unrolling and vectorization (default)
+ // 11 == f16: 2D blocktiling with loop unrolling, vectorization and transpose
+ bool enableF16 = version == 10 || version == 11;
+ bool transposedInput = version == 9 || version == 11;
+ NumType numtype = enableF16 ? kf16 : kf32;

  size_t M, K, N; // Matrix dimensions
  char *kTestSize_str = getenv("MATMUL_SIZE");
@@ -846,24 +930,10 @@ int main() {
    N = 2 * 4096;
  }

- std::unique_ptr<float[]> inputPtr = std::make_unique<float[]>(M * K);
- std::unique_ptr<float[]> weightsPtr = std::make_unique<float[]>(N * K);
- std::unique_ptr<float[]> outputPtr = std::make_unique<float[]>(M * N);
- bool transposedInput = version == 9;
-
- initData(M, K, N, inputPtr, weightsPtr);
- if (transposedInput) {
-   std::unique_ptr<float[]> transposedWeightPtr = std::make_unique<float[]>(K * N);
-   transpose(weightsPtr.get(), transposedWeightPtr.get(), N, K);
-   runTest(version, M, K, N, inputPtr, transposedWeightPtr, outputPtr);
+ if (enableF16) {
+   runTestWithCheck<half>(version, M, K, N, transposedInput, kTestSize, numtype);
  } else {
-   runTest(version, M, K, N, inputPtr, weightsPtr, outputPtr);
- }
-
-
- if (kTestSize <= 1) {
-   // Check result with CPU reference implementation for tiny/small tests
-   checkCPU(M, K, N, inputPtr, weightsPtr, outputPtr);
+   runTestWithCheck<float>(version, M, K, N, transposedInput, kTestSize, numtype);
  }

  LOG(kDefLog, kInfo, "Done.");