Skip to content

Commit b558a20

Browse files
Add matrix multiplication with transpose
1 parent 9586723 commit b558a20

File tree

1 file changed

+196
-34
lines changed

1 file changed

+196
-34
lines changed

examples/matmul/run.cpp

Lines changed: 196 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414

1515
using namespace gpu;
1616

17+
const char* versionToStr(int version);
18+
1719
static const char *kShaderMatmul1 = R"(
1820
@group(0) @binding(0) var<storage, read_write> A: array<{{precision}}>;
1921
@group(0) @binding(1) var<storage, read_write> B: array<{{precision}}>;
@@ -260,9 +262,9 @@ fn main(
260262
// incremented in the bkidx loop.
261263
// cPtr is the starting position of the tile in c which is fixed.
262264
263-
var aPtr = cRow * {{BM}} * {{K}};
264-
var bPtr = cCol * {{BN}} * {{K}};
265-
let cPtr = cRow * {{BM}} * {{N}} + cCol * {{BN}};
265+
var aPtr: u32 = cRow * {{BM}} * {{K}};
266+
var bPtr: u32 = cCol * {{BN}} * {{K}};
267+
let cPtr: u32 = cRow * {{BM}} * {{N}} + cCol * {{BN}};
266268
267269
for (var bkidx = 0; bkidx < {{K}}; bkidx += {{BK}}) {
268270
@@ -275,7 +277,7 @@ fn main(
275277
// Load BK x BN by numThread(BM * BN / (TM * TN))
276278
// The number of iteration == BK * BN / (BM * BN / (TM * TN))
277279
for (var idx: u32 = 0; idx < {{NUM_TILEB}}; idx++) {
278-
tileB[localID.x + idx * numThread] = b[bPtr + ((localID.x + idx * numThread) / {{BK}}) * {{K}} + ((localID.x + idx * numThread) % {{BK}})];
280+
tileB[localID.x + idx * numThread] = b[bPtr + ((localID.x + idx * numThread) / {{BK}}) * {{K}} + ((localID.x + idx * numThread) % {{BK}})];
279281
}
280282
281283
aPtr += {{BK}};
@@ -344,6 +346,7 @@ inline KernelCode createMatmul4(const char *shaderTemplate, const size_t M,
344346
}
345347
}
346348

349+
347350
/* 2D block-tiling with vectorization
348351
*
349352
*/
@@ -376,9 +379,9 @@ fn main(
376379
// incremented in the bkidx loop.
377380
// cPtr is the starting position of the tile in c which is fixed.
378381
379-
var aPtr = cRow * {{BM}} * {{K}};
380-
var bPtr = cCol * {{BN}} * {{K}};
381-
let cPtr = cRow * {{BM}} * {{N4}} + cCol * {{BN4}};
382+
var aPtr: u32 = cRow * {{BM}} * {{K}};
383+
var bPtr: u32 = cCol * {{BN}} * {{K}};
384+
let cPtr: u32 = cRow * {{BM}} * {{N4}} + cCol * {{BN4}};
382385
383386
for (var bkidx = 0; bkidx < {{K}}; bkidx += {{BK}}) {
384387
@@ -455,7 +458,7 @@ inline KernelCode createMatmulWithVectorization(const char *shaderTemplate, cons
455458
{"{{NUM_TILEB}}", toString(BN * BK / num_threads)},
456459
{"{{TN4}}", toString(TN / 4)},
457460
{"{{N4}}", toString(N / 4)},
458-
{"{{BN4}}", toString(BN / 4)},
461+
{"{{BN4}}", toString(BN / 4)}
459462
});
460463
if (unrolling) {
461464
std::string unrolledCode = loopUnrolling(codeString);
@@ -466,6 +469,123 @@ inline KernelCode createMatmulWithVectorization(const char *shaderTemplate, cons
466469
}
467470
}
468471

472+
/* 2D block-tiling with vec4 output, reading B from a pre-transposed (row-major K x N) buffer
473+
*
474+
*/
475+
static const char *kShaderMatmulWithTranspose = R"(
476+
@group(0) @binding(0) var<storage, read_write> a: array<{{precision}}>;
477+
@group(0) @binding(1) var<storage, read_write> b: array<{{precision}}>;
478+
@group(0) @binding(2) var<storage, read_write> c: array<vec4<{{precision}}>>;
479+
var<workgroup> tileA: array<{{precision}}, {{BM}} * {{BK}}>;
480+
var<workgroup> tileB: array<{{precision}}, {{BN}} * {{BK}}>;
481+
482+
@compute @workgroup_size({{workgroupSize}})
483+
fn main(
484+
@builtin(global_invocation_id) globalID : vec3<u32>,
485+
@builtin(local_invocation_id) localID : vec3<u32>,
486+
@builtin(workgroup_id) groupid : vec3<u32>) {
487+
488+
var threadResults: array<vec4<{{precision}}>, {{TM}} * {{TN4}}>;
489+
var localM: array<{{precision}}, {{TM}}>;
490+
var localN: array<vec4<{{precision}}>, {{TN4}}>;
491+
492+
let cRow: u32 = groupid.x;
493+
let cCol: u32 = groupid.y;
494+
let numThread: u32 = ({{BM}} * {{BN}}) / ({{TM}} * {{TN}});
495+
496+
// position of the first c element computed by the thread
497+
let threadRow: u32 = (localID.x / ({{BN}} / {{TN}})) * {{TM}};
498+
let threadCol: u32 = (localID.x % ({{BN}} / {{TN}})) * {{TN}};
499+
500+
// aPtr and bPtr are the starting positions of the tiles in a and b,
501+
// incremented in the bkidx loop.
502+
// cPtr is the starting position of the tile in c which is fixed.
503+
504+
var aPtr: u32 = cRow * {{BM}} * {{K}};
505+
var bPtr: u32 = cCol * {{BN}};
506+
let cPtr: u32 = cRow * {{BM}} * {{N4}} + cCol * {{BN4}};
507+
508+
for (var bkidx = 0; bkidx < {{K}}; bkidx += {{BK}}) {
509+
510+
// Load tile
511+
// Load BM x BK by numThread(BM * BN / (TM * TN))
512+
// The number of iteration == BM * BK / (BM * BN / (TM * TN))
513+
for (var idx: u32 = 0; idx < {{NUM_TILEA}}; idx++) {
514+
tileA[localID.x + idx * numThread] = a[aPtr + ((localID.x + idx * numThread) / {{BK}}) * {{K}} + (localID.x + idx * numThread) % {{BK}}];
515+
}
516+
// Load BK x BN by numThread(BM * BN / (TM * TN))
517+
// The number of iteration == BK * BN / (BM * BN / (TM * TN))
518+
for (var idx: u32 = 0; idx < {{NUM_TILEB}}; idx++) {
519+
tileB[localID.x + idx * numThread] = b[bPtr + ((localID.x + idx * numThread) / {{BN}}) * {{N}} + ((localID.x + idx * numThread) % {{BN}})];
520+
}
521+
522+
aPtr += {{BK}};
523+
bPtr += {{BK}} * {{N}};
524+
525+
workgroupBarrier();
526+
// Compute tile
527+
for (var dotIdx: u32 = 0; dotIdx < {{BK}}; dotIdx = dotIdx + 1) {
528+
for (var idx: u32 = 0; idx < {{TM}}; idx++) {
529+
localM[idx] = tileA[(threadRow + idx) * {{BK}} + dotIdx];
530+
}
531+
for (var idx: u32 = 0; idx < {{TN4}}; idx++) {
532+
localN[idx] = vec4<{{precision}}>(tileB[(threadCol + idx*4 ) + dotIdx * {{BN}}],
533+
tileB[(threadCol + idx*4 + 1) + dotIdx * {{BN}}],
534+
tileB[(threadCol + idx*4 + 2) + dotIdx * {{BN}}],
535+
tileB[(threadCol + idx*4 + 3) + dotIdx * {{BN}}]);
536+
}
537+
for (var resIdxM: u32 = 0; resIdxM < {{TM}}; resIdxM++) {
538+
for (var resIdxN: u32 = 0; resIdxN < {{TN4}}; resIdxN++) {
539+
threadResults[resIdxM * {{TN4}} + resIdxN] += localM[resIdxM] * localN[resIdxN];
540+
}
541+
}
542+
}
543+
workgroupBarrier();
544+
}
545+
546+
for (var resIdxM: u32 = 0; resIdxM < {{TM}}; resIdxM++) {
547+
for (var resIdxN: u32 = 0; resIdxN < {{TN4}}; resIdxN++) {
548+
c[cPtr + (threadRow + resIdxM) * {{N4}} + (threadCol/4) + resIdxN] = threadResults[resIdxM * {{TN4}} + resIdxN];
549+
}
550+
}
551+
}
552+
)";
553+
554+
inline KernelCode createMatmulWithTranspose(const char *shaderTemplate, const size_t M,
555+
const size_t K, const size_t N, const size_t BM,
556+
const size_t BK, const size_t BN,
557+
const size_t TM, const size_t TN,
558+
const Shape &workgroupSize = {256, 1, 1},
559+
NumType precision = kf32) {
560+
assert(BM % TM == 0);
561+
assert(BN % TN == 0);
562+
assert(K % BK == 0);
563+
assert(M % BM == 0);
564+
assert(N % BN == 0);
565+
// # threads = tile A size == tile B size == # threads for computing C
566+
int num_threads = BM * BN / (TM * TN);
567+
std::string codeString(shaderTemplate);
568+
replaceAll(codeString, {{"{{workgroupSize}}", toString(workgroupSize)},
569+
{"{{precision}}", toString(precision)},
570+
{"{{M}}", toString(M)},
571+
{"{{K}}", toString(K)},
572+
{"{{N}}", toString(N)},
573+
{"{{BM}}", toString(BM)},
574+
{"{{BK}}", toString(BK)},
575+
{"{{BN}}", toString(BN)},
576+
{"{{TM}}", toString(TM)},
577+
{"{{TN}}", toString(TN)},
578+
{"{{NUM_TILEA}}", toString(BM * BK / num_threads)},
579+
{"{{NUM_TILEB}}", toString(BN * BK / num_threads)},
580+
{"{{TN4}}", toString(TN / 4)},
581+
{"{{N4}}", toString(N / 4)},
582+
{"{{BN4}}", toString(BN / 4)},
583+
});
584+
std::string unrolledCode = loopUnrolling(codeString);
585+
// LOG(kDefLog, kInfo, "Unrolled code:\n%s", unrolledCode.c_str());
586+
return {unrolledCode, workgroupSize};
587+
}
588+
469589
/**
470590
* @brief No-Op shader with matmul bindings for performance testing
471591
*/
@@ -519,20 +639,26 @@ Kernel selectMatmul(Context &ctx, int version,
519639
size_t M, size_t K, size_t N) {
520640
Kernel kernel;
521641
if (version == 1) {
642+
Shape wgSize = {256, 1, 1};
643+
Shape nWorkgroups = cdiv({M, N, 1}, {16, 16, 1});
644+
KernelCode matmul = createNoOp(kShaderNoOp, /*wgsize*/ wgSize);
645+
kernel = createKernel(ctx, matmul, bindings,
646+
/*nWorkgroups*/ nWorkgroups);
647+
} else if (version == 2) {
522648
Shape wgSize = {16, 16, 1};
523649
LOG(kDefLog, kInfo, "wgSize: %s", toString(wgSize).c_str());
524650
KernelCode matmul =
525651
createMatmul1(kShaderMatmul1, M, K, N, /*wgsize*/ wgSize);
526652
kernel = createKernel(ctx, matmul, bindings,
527653
/*nWorkgroups*/ cdiv({M, N, 1}, wgSize));
528-
} else if (version == 2) {
654+
} else if (version == 3) {
529655
static constexpr size_t tileSize = 16;
530656
KernelCode matmul = createMatmul2(kShaderMatmul2, M, K, N,
531657
/*wgSize*/ {tileSize * tileSize, 1, 1});
532658
kernel =
533659
createKernel(ctx, matmul, bindings,
534660
/* nWorkgroups*/ cdiv({M, N, 1}, {tileSize, tileSize, 1}));
535-
} else if (version == 3 || version == 5) {
661+
} else if (version == 4 || version == 6) {
536662
static constexpr size_t BM = 64;
537663
static constexpr size_t BK = 4;
538664
static constexpr size_t BN = BM;
@@ -548,10 +674,10 @@ Kernel selectMatmul(Context &ctx, int version,
548674
KernelCode matmul = createMatmul3(kShaderMatmul3, M, K, N, BM, BK, BN, TM,
549675
/*wgSize*/ wgSize,
550676
kf32,
551-
/*Loop unrolling*/ version == 5 ? true: false);
677+
/*Loop unrolling*/ version == 6 ? true: false);
552678
kernel = createKernel(ctx, matmul, bindings,
553679
/*nWorkgroups*/ nWorkgroups);
554-
} else if (version == 4 || version == 6) {
680+
} else if (version == 5 || version == 7) {
555681
static constexpr size_t BM = 64;
556682
static constexpr size_t BK = 8;
557683
static constexpr size_t BN = 64;
@@ -566,10 +692,10 @@ Kernel selectMatmul(Context &ctx, int version,
566692
KernelCode matmul = createMatmul4(kShaderMatmul4, M, K, N, BM, BK, BN, TM, TN,
567693
/*wgSize*/ wgSize,
568694
kf32,
569-
/*Loop unrolling*/ version == 6 ? true: false);
695+
/*Loop unrolling*/ version == 7 ? true: false);
570696
kernel = createKernel(ctx, matmul, bindings,
571697
/*nWorkgroups*/ nWorkgroups);
572-
} else if (version == 7) {
698+
} else if (version == 8) {
573699
static constexpr size_t BM = 64;
574700
static constexpr size_t BK = 8;
575701
static constexpr size_t BN = 64;
@@ -587,10 +713,21 @@ Kernel selectMatmul(Context &ctx, int version,
587713
/*Loop unrolling*/ true);
588714
kernel = createKernel(ctx, matmul, bindings,
589715
/*nWorkgroups*/ nWorkgroups);
590-
} else if (version == 8) {
591-
Shape wgSize = {256, 1, 1};
592-
Shape nWorkgroups = cdiv({M, N, 1}, {16, 16, 1});
593-
KernelCode matmul = createNoOp(kShaderNoOp, /*wgsize*/ wgSize);
716+
} else if (version == 9) {
717+
static constexpr size_t BM = 64;
718+
static constexpr size_t BK = 8;
719+
static constexpr size_t BN = 64;
720+
static constexpr size_t TM = BM / BK;
721+
static constexpr size_t TN = BN / BK;
722+
Shape wgSize = {(BM / TM) * (BN / TN), 1, 1}; // This is the same as BK * BK.
723+
Shape nWorkgroups = {cdiv(M, BM), cdiv(N, BN), 1};
724+
// M, K, N and the tile constants are size_t, so they are printed with %zu;
// passing a size_t for a %d conversion is undefined behavior.
LOG(kDefLog, kInfo, "M: %zu, K: %zu, N: %zu", M, K, N);
LOG(kDefLog, kInfo, "BM: %zu, BK: %zu, BN: %zu, TM: %zu, TN: %zu", BM, BK, BN, TM, TN);
726+
LOG(kDefLog, kInfo, "wgSize: ( %s )", toString(wgSize).c_str());
727+
LOG(kDefLog, kInfo, "nWorkgroups: ( %s )", toString(nWorkgroups).c_str());
728+
KernelCode matmul = createMatmulWithTranspose(kShaderMatmulWithTranspose, M, K, N, BM, BK, BN, TM, TN,
729+
/*wgSize*/ wgSize,
730+
kf32);
594731
kernel = createKernel(ctx, matmul, bindings,
595732
/*nWorkgroups*/ nWorkgroups);
596733
}
@@ -626,8 +763,8 @@ void runTest(int version, size_t M, size_t K, size_t N,
626763

627764
printf("[ Press enter to start tests ... ]\n");
628765
getchar();
629-
LOG(kDefLog, kInfo, "Dispatching Kernel version %d, %d iterations ...",
630-
version, nIter);
766+
LOG(kDefLog, kInfo, "Dispatching Kernel version %d: %s, %d iterations ...",
767+
version, versionToStr(version), nIter);
631768

632769
// Dispatch kernel nIter times
633770
auto start = std::chrono::high_resolution_clock::now();
@@ -662,26 +799,43 @@ void runTest(int version, size_t M, size_t K, size_t N,
662799
M, K, N, nIter, duration.count() / static_cast<double>(nIter) / 1000.0 /* us -> ms */, gflops);
663800
}
664801

802+
/**
 * @brief Maps a matmul kernel version number to its descriptive name.
 *
 * @param version Kernel version number; 1 through 9 have a name.
 * @return The name of the kernel version, or "Not specified" for any other
 *         value. The returned pointer refers to a string literal.
 */
const char* versionToStr(int version) {
  // Entry i holds the name of version i + 1.
  static const char* const kVersionNames[] = {
      "No-Op",
      "naive matmul",
      "tiling",
      "1D blocktiling",
      "2D blocktiling",
      "1D blocktiling with loop unrolling",
      "2D blocktiling with loop unrolling",
      "2D blocktiling with loop unrolling and vectorization",
      "2D blocktiling with loop unrolling, vectorization and transpose",
  };
  const int nameCount =
      static_cast<int>(sizeof(kVersionNames) / sizeof(kVersionNames[0]));

  const bool isKnown = version >= 1 && version <= nameCount;
  if (!isKnown) {
    return "Not specified";
  }
  return kVersionNames[version - 1];
}
816+
665817
int main() {
666818
char* version_str = getenv("MATMUL_VERSION");
667-
int version = version_str == NULL ? 7 : atoi(version_str);
668-
// 1 == naive matmul
669-
// 2 == tiling
670-
// 3 == 1D blocktiling
671-
// 4 == 2D blocktiling
672-
// 5 == 1D blocktiling with loop unrolling
673-
// 6 == 2D blocktiling with loop unrolling
674-
// 7 == 2D blocktiling with loop unrolling and vectorization
675-
// 8 == No-Op
819+
char* kTestSize_str = getenv("MATMUL_SIZE");
820+
int version = version_str == NULL ? 9 : atoi(version_str);
821+
// 1 == No-Op
822+
// 2 == naive matmul
823+
// 3 == tiling
824+
// 4 == 1D blocktiling
825+
// 5 == 2D blocktiling
826+
// 6 == 1D blocktiling with loop unrolling
827+
// 7 == 2D blocktiling with loop unrolling
828+
// 8 == 2D blocktiling with loop unrolling and vectorization
829+
// 9 == 2D blocktiling with loop unrolling, vectorization and transpose (default)
676830

677831
size_t M, K, N; // Matrix dimensions
678-
static constexpr int kTestSize = 2;
679-
if constexpr (kTestSize == 0) {
832+
int kTestSize = kTestSize_str == NULL ? 2 : atoi(kTestSize_str);
833+
if (kTestSize == 0) {
680834
// Tiny test
681835
M = 32;
682836
K = 32;
683837
N = 32;
684-
} else if constexpr (kTestSize == 1) {
838+
} else if (kTestSize == 1) {
685839
// Small test
686840
M = 256;
687841
K = 128;
@@ -696,11 +850,19 @@ int main() {
696850
std::unique_ptr<float[]> inputPtr = std::make_unique<float[]>(M * K);
697851
std::unique_ptr<float[]> weightsPtr = std::make_unique<float[]>(N * K);
698852
std::unique_ptr<float[]> outputPtr = std::make_unique<float[]>(M * N);
853+
bool transposedInput = version == 9;
699854

700855
initData(M, K, N, inputPtr, weightsPtr);
701-
runTest(version, M, K, N, inputPtr, weightsPtr, outputPtr);
856+
if (transposedInput) {
857+
std::unique_ptr<float[]> transposedWeightPtr = std::make_unique<float[]>(K * N);
858+
transpose(weightsPtr.get(), transposedWeightPtr.get(), N, K);
859+
runTest(version, M, K, N, inputPtr, transposedWeightPtr, outputPtr);
860+
} else {
861+
runTest(version, M, K, N, inputPtr, weightsPtr, outputPtr);
862+
}
863+
702864

703-
if constexpr (kTestSize <= 1) {
865+
if (kTestSize <= 1) {
704866
// Check result with CPU reference implementation for tiny/small tests
705867
checkCPU(M, K, N, inputPtr, weightsPtr, outputPtr);
706868
}

0 commit comments

Comments
 (0)