
Commit 228084f

Merge pull request #39 from junjihashimoto/feature/matmul-f16
Add matmul with float16
2 parents 305c25a + 23dd96e commit 228084f
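
In short, the change threads a NumType precision through the kernel factories in run.cpp, teaches gpu.h to prepend "enable f16;" to the generated WGSL when that precision is kf16, and adds versions 10 and 11, which run the vectorized and transposed matmuls in half precision against a new matmulf16_forward_cpu reference. Below is a minimal sketch of the f16 request path, distilled from the diff that follows; selectMatmul and initData are helpers defined in run.cpp, while the include paths and the omitted dispatch/timing loop are assumptions, not part of this commit.

#include <array>
#include <memory>
#include "gpu.h"                 // createContext, createTensor, Tensor, Shape, kf16
#include "numeric_types/half.h"  // half

using namespace gpu;

// Sketch only: mirrors the f16 branch of runTest() below, without timing.
void sketchF16Matmul(size_t M, size_t K, size_t N) {
  // The device must expose shader-f16 so the generated "enable f16;" WGSL compiles.
  Context ctx = createContext(
      {}, {},
      {.requiredFeatureCount = 1,
       .requiredFeatures = std::array{WGPUFeatureName_ShaderF16}.data()});

  auto inputPtr   = std::make_unique<half[]>(M * K);
  auto weightsPtr = std::make_unique<half[]>(N * K);
  initData(M, K, N, inputPtr, weightsPtr);  // the new half overload added below

  Tensor input   = createTensor(ctx, Shape{M, K}, kf16, inputPtr.get());
  Tensor weights = createTensor(ctx, Shape{N, K}, kf16, weightsPtr.get());
  Tensor output  = createTensor(ctx, Shape{M, N}, kf16);

  // version 10 == f16: 2D blocktiling with loop unrolling and vectorization.
  // The returned kernel is then dispatched exactly as in runTest() below.
  Kernel matmul = selectMatmul(ctx, /*version=*/10, {input, weights, output},
                               M, K, N, kf16);
}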

3 files changed (+189 −77 lines)

examples/matmul/run.cpp (+143 −73)

@@ -11,11 +11,35 @@
 #include "utils/array_utils.h" // show, isclose, randn, randint
 #include "utils/logging.h" // LOG
 #include "experimental/wgsl.h" // loopUnrolling
+#include "numeric_types/half.h"
 
 using namespace gpu;
 
 const std::string versionToStr(int version);
 
+void matmulf16_forward_cpu(half* out,
+                           const half* inp, const half* weight, const half* bias,
+                           int B, int T, int C, int OC) {
+  // OC is short for "output channels"
+  // inp is (B,T,C), weight is (OC, C)
+  // out will be (B,T,OC)
+#pragma omp parallel for collapse(2)
+  for (int b = 0; b < B; b++) {
+    for (int t = 0; t < T; t++) {
+      half* out_bt = out + b * T * OC + t * OC;
+      const half* inp_bt = inp + b * T * C + t * C;
+      for (int o = 0; o < OC; o++) {
+        float val = (bias != NULL) ? halfToFloat(bias[o]) : 0.0f;
+        const half* wrow = weight + o*C;
+        for (int i = 0; i < C; i++) {
+          val += halfToFloat(inp_bt[i]) * halfToFloat(wrow[i]);
+        }
+        out_bt[o] = val;
+      }
+    }
+  }
+}
+
 static const char *kShaderMatmul1 = R"(
 @group(0) @binding(0) var<storage, read_write> A: array<{{precision}}>;
 @group(0) @binding(1) var<storage, read_write> B: array<{{precision}}>;
@@ -47,7 +71,7 @@ inline KernelCode createMatmul1(const char *shaderTemplate, const size_t M,
                           {"{{M}}", toString(M)},
                           {"{{K}}", toString(K)},
                           {"{{N}}", toString(N)}});
-  return {codeString, workgroupSize};
+  return {codeString, workgroupSize, precision};
 }
 
 // Shared memory cache-blocking
@@ -108,7 +132,7 @@ inline KernelCode createMatmul2(const char *shaderTemplate, const size_t M,
                           {"{{N}}", toString(N)},
                           {"{{tileSize}}",
                            toString(static_cast<size_t>(sqrt(workgroupSize[0])))}});
-  return {codeString, workgroupSize};
+  return {codeString, workgroupSize, precision};
 }
 
 /* 1D block-tiling
@@ -224,9 +248,9 @@ inline KernelCode createMatmul3(const char *shaderTemplate, const size_t M,
   if (unrolling) {
     std::string unrolledCode = loopUnrolling(codeString);
     // LOG(kDefLog, kInfo, "Unrolled code:\n%s", unrolledCode.c_str());
-    return {unrolledCode, workgroupSize};
+    return {unrolledCode, workgroupSize, precision};
   } else {
-    return {codeString, workgroupSize};
+    return {codeString, workgroupSize, precision};
   }
 }
 
@@ -340,9 +364,9 @@ inline KernelCode createMatmul4(const char *shaderTemplate, const size_t M,
   if (unrolling) {
     std::string unrolledCode = loopUnrolling(codeString);
     // LOG(kDefLog, kInfo, "Unrolled code:\n%s", unrolledCode.c_str());
-    return {unrolledCode, workgroupSize};
+    return {unrolledCode, workgroupSize, precision};
   } else {
-    return {codeString, workgroupSize};
+    return {codeString, workgroupSize, precision};
   }
 }
 
@@ -462,9 +486,9 @@ inline KernelCode createMatmulWithVectorization(const char *shaderTemplate, cons
   if (unrolling) {
     std::string unrolledCode = loopUnrolling(codeString);
     // LOG(kDefLog, kInfo, "Unrolled code:\n%s", unrolledCode.c_str());
-    return {unrolledCode, workgroupSize};
+    return {unrolledCode, workgroupSize, precision};
   } else {
-    return {codeString, workgroupSize};
+    return {codeString, workgroupSize, precision};
   }
 }
 
@@ -582,7 +606,7 @@ inline KernelCode createMatmulWithTranspose(const char *shaderTemplate, const si
   });
   std::string unrolledCode = loopUnrolling(codeString);
   // LOG(kDefLog, kInfo, "Unrolled code:\n%s", unrolledCode.c_str());
-  return {unrolledCode, workgroupSize};
+  return {unrolledCode, workgroupSize, precision};
 }
 
 /**
@@ -604,7 +628,7 @@ inline KernelCode createNoOp(const char *shaderTemplate,
   std::string codeString(shaderTemplate);
   replaceAll(codeString, {{"{{workgroupSize}}", toString(workgroupSize)},
                           {"{{precision}}", toString(precision)}});
-  return {codeString, workgroupSize};
+  return {codeString, workgroupSize, precision};
 }
 
 void initData(size_t M, size_t K, size_t N, std::unique_ptr<float[]> &inputPtr,
@@ -619,23 +643,41 @@ void initData(size_t M, size_t K, size_t N, std::unique_ptr<float[]> &inputPtr,
       show<float>(weightsPtr.get(), N, K, "Weights").c_str());
 }
 
-void checkCPU(size_t M, size_t K, size_t N, std::unique_ptr<float[]> &inputPtr,
-              std::unique_ptr<float[]> &weightsPtr,
-              std::unique_ptr<float[]> &outputPtr) {
+void initData(size_t M, size_t K, size_t N, std::unique_ptr<half[]> &inputPtr,
+              std::unique_ptr<half[]> &weightsPtr) {
+  std::mt19937 gen(314159);
+  randn(inputPtr.get(), M * K, gen);
+  randn(weightsPtr.get(), N * K, gen);
+  // randint(inputPtr.get(), M * K, gen, 1, 2);
+  // randint(weightsPtr.get(), N * K, gen, 1, 2);
+  LOG(kDefLog, kInfo, "%s", show<half>(inputPtr.get(), M, K, "Input").c_str());
+  LOG(kDefLog, kInfo, "%s",
+      show<half>(weightsPtr.get(), N, K, "Weights").c_str());
+}
+
+template<class precision=float>
+void checkCPU(size_t M, size_t K, size_t N, std::unique_ptr<precision[]> &inputPtr,
+              std::unique_ptr<precision[]> &weightsPtr,
+              std::unique_ptr<precision[]> &outputPtr) {
   LOG(kDefLog, kInfo, "Computing CPU reference implementation");
-  std::unique_ptr<float[]> outputRefPtr = std::make_unique<float[]>(M * N);
-  ref::matmul_forward_cpu(outputRefPtr.get(), inputPtr.get(), weightsPtr.get(),
-                          nullptr, 1, M, K, N);
+  std::unique_ptr<precision[]> outputRefPtr = std::make_unique<precision[]>(M * N);
+  if constexpr (std::is_same<precision, float>::value) {
+    ref::matmul_forward_cpu(outputRefPtr.get(), inputPtr.get(), weightsPtr.get(),
+                            nullptr, 1, M, K, N);
+  } else if constexpr (std::is_same<precision, half>::value) {
+    matmulf16_forward_cpu(outputRefPtr.get(), inputPtr.get(), weightsPtr.get(),
+                          nullptr, 1, M, K, N);
+  }
   LOG(kDefLog, kInfo, "Reference Output: %s",
-      show<float>(outputRefPtr.get(), M, N, "Output (Reference)").c_str());
+      show<precision>(outputRefPtr.get(), M, N, "Output (Reference)").c_str());
   LOG(kDefLog, kInfo,
       isclose(outputPtr.get(), outputRefPtr.get(), M * N) ? "CPU Check: PASS"
                                                           : "CPU Check: FAIL");
 }
 
 Kernel selectMatmul(Context &ctx, int version,
                     const Bindings</* input, weights, output */ 3> &bindings,
-                    size_t M, size_t K, size_t N) {
+                    size_t M, size_t K, size_t N, NumType numtype) {
   Kernel kernel;
   if (version == 1) {
     Shape wgSize = {256, 1, 1};
@@ -647,13 +689,13 @@ Kernel selectMatmul(Context &ctx, int version,
     Shape wgSize = {16, 16, 1};
     LOG(kDefLog, kInfo, "wgSize: %s", toString(wgSize).c_str());
     KernelCode matmul =
-        createMatmul1(kShaderMatmul1, M, K, N, /*wgsize*/ wgSize);
+        createMatmul1(kShaderMatmul1, M, K, N, /*wgsize*/ wgSize, numtype);
     kernel = createKernel(ctx, matmul, bindings,
                           /*nWorkgroups*/ cdiv({M, N, 1}, wgSize));
   } else if (version == 3) {
     static constexpr size_t tileSize = 16;
     KernelCode matmul = createMatmul2(kShaderMatmul2, M, K, N,
-                                      /*wgSize*/ {tileSize * tileSize, 1, 1});
+                                      /*wgSize*/ {tileSize * tileSize, 1, 1}, numtype);
     kernel =
         createKernel(ctx, matmul, bindings,
                      /* nWorkgroups*/ cdiv({M, N, 1}, {tileSize, tileSize, 1}));
@@ -672,7 +714,7 @@ Kernel selectMatmul(Context &ctx, int version,
     LOG(kDefLog, kInfo, "nWorkgroups: ( %s )", toString(nWorkgroups).c_str());
     KernelCode matmul = createMatmul3(kShaderMatmul3, M, K, N, BM, BK, BN, TM,
                                       /*wgSize*/ wgSize,
-                                      kf32,
+                                      numtype,
                                       /*Loop unrolling*/ version == 6 ? true: false);
     kernel = createKernel(ctx, matmul, bindings,
                           /*nWorkgroups*/ nWorkgroups);
@@ -690,11 +732,11 @@ Kernel selectMatmul(Context &ctx, int version,
     LOG(kDefLog, kInfo, "nWorkgroups: ( %s )", toString(nWorkgroups).c_str());
     KernelCode matmul = createMatmul4(kShaderMatmul4, M, K, N, BM, BK, BN, TM, TN,
                                       /*wgSize*/ wgSize,
-                                      kf32,
+                                      numtype,
                                       /*Loop unrolling*/ version == 7 ? true: false);
     kernel = createKernel(ctx, matmul, bindings,
                           /*nWorkgroups*/ nWorkgroups);
-  } else if (version == 8) {
+  } else if (version == 8 || version == 10) {
     static constexpr size_t BM = 64;
     static constexpr size_t BK = 8;
     static constexpr size_t BN = 64;
@@ -708,11 +750,11 @@ Kernel selectMatmul(Context &ctx, int version,
     LOG(kDefLog, kInfo, "nWorkgroups: ( %s )", toString(nWorkgroups).c_str());
     KernelCode matmul = createMatmulWithVectorization(kShaderMatmulWithVectorization, M, K, N, BM, BK, BN, TM, TN,
                                                       /*wgSize*/ wgSize,
-                                                      kf32,
+                                                      numtype,
                                                       /*Loop unrolling*/ true);
     kernel = createKernel(ctx, matmul, bindings,
                           /*nWorkgroups*/ nWorkgroups);
-  } else if (version == 9) {
+  } else if (version == 9 || version == 11) {
     static constexpr size_t BM = 64;
     static constexpr size_t BK = 8;
     static constexpr size_t BN = 64;
@@ -726,23 +768,36 @@ Kernel selectMatmul(Context &ctx, int version,
     LOG(kDefLog, kInfo, "nWorkgroups: ( %s )", toString(nWorkgroups).c_str());
     KernelCode matmul = createMatmulWithTranspose(kShaderMatmulWithTranspose, M, K, N, BM, BK, BN, TM, TN,
                                                   /*wgSize*/ wgSize,
-                                                  kf32);
+                                                  numtype);
     kernel = createKernel(ctx, matmul, bindings,
                           /*nWorkgroups*/ nWorkgroups);
   }
   return kernel;
 }
 
+template<class precision=float>
 void runTest(int version, size_t M, size_t K, size_t N,
-             std::unique_ptr<float[]> &inputPtr,
-             std::unique_ptr<float[]> &weightsPtr,
-             std::unique_ptr<float[]> &outputPtr) {
+             std::unique_ptr<precision[]> &inputPtr,
+             std::unique_ptr<precision[]> &weightsPtr,
+             std::unique_ptr<precision[]> &outputPtr,
+             NumType numtype) {
+  if constexpr (std::is_same<precision, float>::value) {
+    assert(numtype == kf32);
+  } else if constexpr (std::is_same<precision, half>::value) {
+    assert(numtype == kf16);
+  }
 
   // Allocate GPU buffers and copy data
-  Context ctx = createContext();
-  Tensor input = createTensor(ctx, Shape{M, K}, kf32, inputPtr.get());
-  Tensor weights =
-      createTensor(ctx, Shape{N, K}, kf32, weightsPtr.get()); // column-major
+  Context ctx = createContext(
+      {}, {},
+      /*device descriptor, enabling f16 in WGSL*/
+      {
+          .requiredFeatureCount = 1,
+          .requiredFeatures = std::array{WGPUFeatureName_ShaderF16}.data(),
+      });
+
+  Tensor input = createTensor(ctx, Shape{M, K}, numtype, inputPtr.get());
+  Tensor weights = createTensor(ctx, Shape{N, K}, numtype, weightsPtr.get()); // column-major
 
   constexpr size_t nIter = 30;
 
@@ -756,8 +811,8 @@ void runTest(int version, size_t M, size_t K, size_t N,
   std::array<Tensor, nIter> outputs;
   for (int i = 0; i < nIter; i++) {
     futures[i] = promises[i].get_future();
-    outputs[i] = createTensor(ctx, Shape{M, N}, kf32);
-    kernels[i] = selectMatmul(ctx, version, {input, weights, outputs[i]}, M, K, N);
+    outputs[i] = createTensor(ctx, Shape{M, N}, numtype);
+    kernels[i] = selectMatmul(ctx, version, {input, weights, outputs[i]}, M, K, N, numtype);
   }
 
   printf("[ Press enter to start tests ... ]\n");
@@ -785,9 +840,9 @@ void runTest(int version, size_t M, size_t K, size_t N,
                  1000000000.0 * static_cast<float>(nIter);
 
   LOG(kDefLog, kInfo, "Copying result to CPU");
-  toCPU(ctx, outputs[0], outputPtr.get(), M * N * sizeof(float));
+  toCPU(ctx, outputs[0], outputPtr.get(), M * N * sizeof(precision));
   LOG(kDefLog, kInfo, "%s",
-      show<float>(outputPtr.get(), M, N, "Output[0]").c_str());
+      show<precision>(outputPtr.get(), M, N, "Output[0]").c_str());
 
   LOG(kDefLog, kInfo, "\n\n===================================================================="
                       "============\nExecution Time: (M = %d, K = %d, N = %d) x %d iterations "
@@ -798,33 +853,62 @@ void runTest(int version, size_t M, size_t K, size_t N,
       M, K, N, nIter, duration.count() / static_cast<double>(nIter) / 1000.0 /* us -> ms */, gflops);
 }
 
+template<class precision=float>
+void runTestWithCheck(int version, size_t M, size_t K, size_t N,
+                      bool transposedInput, int kTestSize, NumType numtype) {
+  std::unique_ptr<precision[]> inputPtr = std::make_unique<precision[]>(M * K);
+  std::unique_ptr<precision[]> weightsPtr = std::make_unique<precision[]>(N * K);
+  std::unique_ptr<precision[]> outputPtr = std::make_unique<precision[]>(M * N);
+
+  initData(M, K, N, inputPtr, weightsPtr);
+  if (transposedInput) {
+    std::unique_ptr<precision[]> transposedWeightPtr = std::make_unique<precision[]>(K * N);
+    transpose(weightsPtr.get(), transposedWeightPtr.get(), N, K);
+    runTest(version, M, K, N, inputPtr, transposedWeightPtr, outputPtr, numtype);
+  } else {
+    runTest(version, M, K, N, inputPtr, weightsPtr, outputPtr, numtype);
+  }
+
+  if (kTestSize <= 1) {
+    // Check result with CPU reference implementation for tiny/small tests
+    checkCPU(M, K, N, inputPtr, weightsPtr, outputPtr);
+  }
+}
+
 const std::string versionToStr(int version){
   switch (version) {
-  case 1: return "No-Op";
-  case 2: return "naive matmul";
-  case 3: return "tiling";
-  case 4: return "1D blocktiling";
-  case 5: return "2D blocktiling";
-  case 6: return "1D blocktiling with loop unrolling";
-  case 7: return "2D blocktiling with loop unrolling";
-  case 8: return "2D blocktiling with loop unrolling and vectorization";
-  case 9: return "2D blocktiling with loop unrolling, vectorization and transpose";
+  case 1: return "f32: No-Op";
+  case 2: return "f32: naive matmul";
+  case 3: return "f32: tiling";
+  case 4: return "f32: 1D blocktiling";
+  case 5: return "f32: 2D blocktiling";
+  case 6: return "f32: 1D blocktiling with loop unrolling";
+  case 7: return "f32: 2D blocktiling with loop unrolling";
+  case 8: return "f32: 2D blocktiling with loop unrolling and vectorization";
+  case 9: return "f32: 2D blocktiling with loop unrolling, vectorization and transpose";
+  case 10: return "f16: 2D blocktiling with loop unrolling and vectorization";
+  case 11: return "f16: 2D blocktiling with loop unrolling, vectorization and transpose";
   default: return "Not specified";
   }
 }
 
 int main() {
   char* version_str = getenv("MATMUL_VERSION");
-  int version = version_str == NULL ? 9 : atoi(version_str);
-  // 1 == No-Op
-  // 2 == naive matmul
-  // 3 == tiling
-  // 4 == 1D blocktiling
-  // 5 == 2D blocktiling
-  // 6 == 1D blocktiling with loop unrolling
-  // 7 == 2D blocktiling with loop unrolling
-  // 8 == 2D blocktiling with loop unrolling and vectorization
-  // 9 == 2D blocktiling with loop unrolling, vectorization and transpose (default)
+  int version = version_str == NULL ? 10 : atoi(version_str);
+  // 1 == f32: No-Op
+  // 2 == f32: naive matmul
+  // 3 == f32: tiling
+  // 4 == f32: 1D blocktiling
+  // 5 == f32: 2D blocktiling
+  // 6 == f32: 1D blocktiling with loop unrolling
+  // 7 == f32: 2D blocktiling with loop unrolling
+  // 8 == f32: 2D blocktiling with loop unrolling and vectorization
+  // 9 == f32: 2D blocktiling with loop unrolling, vectorization and transpose
+  // 10 == f16: 2D blocktiling with loop unrolling and vectorization (default)
+  // 11 == f16: 2D blocktiling with loop unrolling, vectorization and transpose
+  bool enableF16 = version == 10 || version ==11;
+  bool transposedInput = version == 9 || version == 11;
+  NumType numtype = enableF16 ? kf16 : kf32;
 
   size_t M, K, N; // Matrix dimensions
   char* kTestSize_str = getenv("MATMUL_SIZE");
@@ -846,24 +930,10 @@ int main() {
     N = 2 * 4096;
   }
 
-  std::unique_ptr<float[]> inputPtr = std::make_unique<float[]>(M * K);
-  std::unique_ptr<float[]> weightsPtr = std::make_unique<float[]>(N * K);
-  std::unique_ptr<float[]> outputPtr = std::make_unique<float[]>(M * N);
-  bool transposedInput = version == 9;
-
-  initData(M, K, N, inputPtr, weightsPtr);
-  if (transposedInput) {
-    std::unique_ptr<float[]> transposedWeightPtr = std::make_unique<float[]>(K * N);
-    transpose(weightsPtr.get(), transposedWeightPtr.get(), N, K);
-    runTest(version, M, K, N, inputPtr, transposedWeightPtr, outputPtr);
+  if (enableF16) {
+    runTestWithCheck<half>(version, M, K, N, transposedInput, kTestSize, numtype);
   } else {
-    runTest(version, M, K, N, inputPtr, weightsPtr, outputPtr);
-  }
-
-
-  if (kTestSize <= 1) {
-    // Check result with CPU reference implementation for tiny/small tests
-    checkCPU(M, K, N, inputPtr, weightsPtr, outputPtr);
+    runTestWithCheck<float>(version, M, K, N, transposedInput, kTestSize, numtype);
   }
 
   LOG(kDefLog, kInfo, "Done.");
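
As a side note, the new CPU reference can be exercised on its own. The following hypothetical snippet assumes it is compiled in the same translation unit as run.cpp, so matmulf16_forward_cpu and the randn/show/LOG helpers used above are visible; the problem sizes are arbitrary. It only demonstrates the call shape: each dot product accumulates in float via halfToFloat and is stored back as half.

#include <memory>
#include <random>
#include "numeric_types/half.h"  // half, halfToFloat
#include "utils/array_utils.h"   // randn, show
#include "utils/logging.h"       // LOG, kDefLog, kInfo

using namespace gpu;

void tinyF16ReferenceCheck() {
  const int B = 1, T = 4, C = 8, OC = 4;  // out is (B, T, OC) = 1 x 4 x 4
  auto inp    = std::make_unique<half[]>(B * T * C);
  auto weight = std::make_unique<half[]>(OC * C);
  auto out    = std::make_unique<half[]>(B * T * OC);
  std::mt19937 gen(42);
  randn(inp.get(), B * T * C, gen);
  randn(weight.get(), OC * C, gen);
  // bias may be NULL; accumulation happens in float, the result is stored as half.
  matmulf16_forward_cpu(out.get(), inp.get(), weight.get(), /*bias=*/NULL,
                        B, T, C, OC);
  LOG(kDefLog, kInfo, "%s", show<half>(out.get(), T, OC, "f16 reference out").c_str());
}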

gpu.h (+3 −0)

@@ -318,6 +318,9 @@ struct KernelCode {
                      const Shape &workgroupSize = {256, 1, 1},
                      NumType precision = kf32)
       : data(pData), workgroupSize(workgroupSize), precision(precision) {
+    if (precision == kf16) {
+      data = "enable f16;\n" + data;
+    }
     replaceAll(data, "{{workgroupSize}}", toString(workgroupSize));
     replaceAll(data, "{{precision}}", toString(precision));
     LOG(kDefLog, kInfo, "Shader code:\n%s", data.c_str());
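
The effect of this constructor change, as a hypothetical snippet; it assumes toString(kf16) substitutes the WGSL type name f16, which is what the {{precision}} placeholders in the shader templates above imply.

#include "gpu.h"

using namespace gpu;

void showF16Preamble() {
  // Constructing KernelCode with kf16 now prepends the WGSL extension directive.
  KernelCode code{
      "@group(0) @binding(0) var<storage, read_write> a: array<{{precision}}>;",
      /*workgroupSize=*/{256, 1, 1}, kf16};
  // code.data begins with "enable f16;" followed by the substituted shader,
  // so the buffer is declared as array<f16> and the shader compiles on devices
  // created with WGPUFeatureName_ShaderF16.
}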
