Skip to content

Commit d0f5e87

Browse files
Add matrix multiplication with transpose
1 parent 9586723 commit d0f5e87

File tree

1 file changed

+187
-26
lines changed

1 file changed

+187
-26
lines changed

examples/matmul/run.cpp

+187-26
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414

1515
using namespace gpu;
1616

17+
const char* versionToStr(int version);
18+
1719
static const char *kShaderMatmul1 = R"(
1820
@group(0) @binding(0) var<storage, read_write> A: array<{{precision}}>;
1921
@group(0) @binding(1) var<storage, read_write> B: array<{{precision}}>;
@@ -466,6 +468,123 @@ inline KernelCode createMatmulWithVectorization(const char *shaderTemplate, cons
466468
}
467469
}
468470

471+
/* 2D block-tiling with transpose
472+
*
473+
*/
474+
static const char *kShaderMatmulWithTranspose = R"(
475+
@group(0) @binding(0) var<storage, read_write> a: array<{{precision}}>;
476+
@group(0) @binding(1) var<storage, read_write> b: array<{{precision}}>;
477+
@group(0) @binding(2) var<storage, read_write> c: array<vec4<{{precision}}>>;
478+
var<workgroup> tileA: array<{{precision}}, {{BM}} * {{BK}}>;
479+
var<workgroup> tileB: array<{{precision}}, {{BK}} * {{BN}}>;
480+
481+
@compute @workgroup_size({{workgroupSize}})
482+
fn main(
483+
@builtin(global_invocation_id) globalID : vec3<u32>,
484+
@builtin(local_invocation_id) localID : vec3<u32>,
485+
@builtin(workgroup_id) groupid : vec3<u32>) {
486+
487+
var threadResults: array<vec4<{{precision}}>, {{TM}} * {{TN4}}>;
488+
var localM: array<{{precision}}, {{TM}}>;
489+
var localN: array<vec4<{{precision}}>, {{TN4}}>;
490+
491+
let cRow: u32 = groupid.x;
492+
let cCol: u32 = groupid.y;
493+
let numThread: u32 = ({{BM}} * {{BN}}) / ({{TM}} * {{TN}});
494+
495+
// position of the first c element computed by the thread
496+
let threadRow: u32 = (localID.x / ({{BN}} / {{TN}})) * {{TM}};
497+
let threadCol: u32 = (localID.x % ({{BN}} / {{TN}})) * {{TN}};
498+
499+
// aPtr and bPtr are the starting positions of the tiles in a and b,
500+
// incremented in the bkidx loop.
501+
// cPtr is the starting position of the tile in c which is fixed.
502+
503+
var aPtr: u32 = cRow * {{BM}} * {{K}};
504+
var bPtr: u32 = cCol * {{BN}};
505+
let cPtr: u32 = cRow * {{BM}} * {{N4}} + cCol * {{BN4}};
506+
507+
for (var bkidx = 0; bkidx < {{K}}; bkidx += {{BK}}) {
508+
509+
// Load tile
510+
// Load BM x BK by numThread(BM * BN / (TM * TN))
511+
// The number of iteration == BM * BK / (BM * BN / (TM * TN))
512+
for (var idx: u32 = 0; idx < {{NUM_TILEA}}; idx++) {
513+
tileA[localID.x + idx * numThread] = a[aPtr + ((localID.x + idx * numThread) / {{BK}}) * {{K}} + (localID.x + idx * numThread) % {{BK}}];
514+
}
515+
// Load BK x BN by numThread(BM * BN / (TM * TN))
516+
// The number of iteration == BK * BN / (BM * BN / (TM * TN))
517+
for (var idx: u32 = 0; idx < {{NUM_TILEB}}; idx++) {
518+
tileB[localID.x + idx * numThread] = b[bPtr + ((localID.x + idx * numThread) / {{BN}}) * {{N}} + ((localID.x + idx * numThread) % {{BN}})];
519+
}
520+
521+
aPtr += {{BK}};
522+
bPtr += {{BK}} * {{N}};
523+
524+
workgroupBarrier();
525+
// Compute tile
526+
for (var dotIdx: u32 = 0; dotIdx < {{BK}}; dotIdx = dotIdx + 1) {
527+
for (var idx: u32 = 0; idx < {{TM}}; idx++) {
528+
localM[idx] = tileA[(threadRow + idx) * {{BK}} + dotIdx];
529+
}
530+
for (var idx: u32 = 0; idx < {{TN4}}; idx++) {
531+
localN[idx] = vec4<{{precision}}>(tileB[(threadCol + idx*4 ) + dotIdx * {{BN}}],
532+
tileB[(threadCol + idx*4 + 1) + dotIdx * {{BN}}],
533+
tileB[(threadCol + idx*4 + 2) + dotIdx * {{BN}}],
534+
tileB[(threadCol + idx*4 + 3) + dotIdx * {{BN}}]);
535+
}
536+
for (var resIdxM: u32 = 0; resIdxM < {{TM}}; resIdxM++) {
537+
for (var resIdxN: u32 = 0; resIdxN < {{TN4}}; resIdxN++) {
538+
threadResults[resIdxM * {{TN4}} + resIdxN] += localM[resIdxM] * localN[resIdxN];
539+
}
540+
}
541+
}
542+
workgroupBarrier();
543+
}
544+
545+
for (var resIdxM: u32 = 0; resIdxM < {{TM}}; resIdxM++) {
546+
for (var resIdxN: u32 = 0; resIdxN < {{TN4}}; resIdxN++) {
547+
c[cPtr + (threadRow + resIdxM) * {{N4}} + (threadCol/4) + resIdxN] = threadResults[resIdxM * {{TN4}} + resIdxN];
548+
}
549+
}
550+
}
551+
)";
552+
553+
inline KernelCode createMatmulWithTranspose(const char *shaderTemplate, const size_t M,
554+
const size_t K, const size_t N, const size_t BM,
555+
const size_t BK, const size_t BN,
556+
const size_t TM, const size_t TN,
557+
const Shape &workgroupSize = {256, 1, 1},
558+
NumType precision = kf32) {
559+
assert(BM % TM == 0);
560+
assert(BN % TN == 0);
561+
assert(K % BK == 0);
562+
assert(M % BM == 0);
563+
assert(N % BN == 0);
564+
// # threads = tile A size == tile B size == # threads for computing C
565+
int num_threads = BM * BN / (TM * TN);
566+
std::string codeString(shaderTemplate);
567+
replaceAll(codeString, {{"{{workgroupSize}}", toString(workgroupSize)},
568+
{"{{precision}}", toString(precision)},
569+
{"{{M}}", toString(M)},
570+
{"{{K}}", toString(K)},
571+
{"{{N}}", toString(N)},
572+
{"{{BM}}", toString(BM)},
573+
{"{{BK}}", toString(BK)},
574+
{"{{BN}}", toString(BN)},
575+
{"{{TM}}", toString(TM)},
576+
{"{{TN}}", toString(TN)},
577+
{"{{NUM_TILEA}}", toString(BM * BK / num_threads)},
578+
{"{{NUM_TILEB}}", toString(BN * BK / num_threads)},
579+
{"{{TN4}}", toString(TN / 4)},
580+
{"{{N4}}", toString(N / 4)},
581+
{"{{BN4}}", toString(BN / 4)},
582+
});
583+
std::string unrolledCode = loopUnrolling(codeString);
584+
// LOG(kDefLog, kInfo, "Unrolled code:\n%s", unrolledCode.c_str());
585+
return {unrolledCode, workgroupSize};
586+
}
587+
469588
/**
470589
* @brief No-Op shader with matmul bindings for performance testing
471590
*/
@@ -519,20 +638,26 @@ Kernel selectMatmul(Context &ctx, int version,
519638
size_t M, size_t K, size_t N) {
520639
Kernel kernel;
521640
if (version == 1) {
641+
Shape wgSize = {256, 1, 1};
642+
Shape nWorkgroups = cdiv({M, N, 1}, {16, 16, 1});
643+
KernelCode matmul = createNoOp(kShaderNoOp, /*wgsize*/ wgSize);
644+
kernel = createKernel(ctx, matmul, bindings,
645+
/*nWorkgroups*/ nWorkgroups);
646+
} else if (version == 2) {
522647
Shape wgSize = {16, 16, 1};
523648
LOG(kDefLog, kInfo, "wgSize: %s", toString(wgSize).c_str());
524649
KernelCode matmul =
525650
createMatmul1(kShaderMatmul1, M, K, N, /*wgsize*/ wgSize);
526651
kernel = createKernel(ctx, matmul, bindings,
527652
/*nWorkgroups*/ cdiv({M, N, 1}, wgSize));
528-
} else if (version == 2) {
653+
} else if (version == 3) {
529654
static constexpr size_t tileSize = 16;
530655
KernelCode matmul = createMatmul2(kShaderMatmul2, M, K, N,
531656
/*wgSize*/ {tileSize * tileSize, 1, 1});
532657
kernel =
533658
createKernel(ctx, matmul, bindings,
534659
/* nWorkgroups*/ cdiv({M, N, 1}, {tileSize, tileSize, 1}));
535-
} else if (version == 3 || version == 5) {
660+
} else if (version == 4 || version == 6) {
536661
static constexpr size_t BM = 64;
537662
static constexpr size_t BK = 4;
538663
static constexpr size_t BN = BM;
@@ -548,10 +673,10 @@ Kernel selectMatmul(Context &ctx, int version,
548673
KernelCode matmul = createMatmul3(kShaderMatmul3, M, K, N, BM, BK, BN, TM,
549674
/*wgSize*/ wgSize,
550675
kf32,
551-
/*Loop unrolling*/ version == 5 ? true: false);
676+
/*Loop unrolling*/ version == 6 ? true: false);
552677
kernel = createKernel(ctx, matmul, bindings,
553678
/*nWorkgroups*/ nWorkgroups);
554-
} else if (version == 4 || version == 6) {
679+
} else if (version == 5 || version == 7) {
555680
static constexpr size_t BM = 64;
556681
static constexpr size_t BK = 8;
557682
static constexpr size_t BN = 64;
@@ -566,10 +691,10 @@ Kernel selectMatmul(Context &ctx, int version,
566691
KernelCode matmul = createMatmul4(kShaderMatmul4, M, K, N, BM, BK, BN, TM, TN,
567692
/*wgSize*/ wgSize,
568693
kf32,
569-
/*Loop unrolling*/ version == 6 ? true: false);
694+
/*Loop unrolling*/ version == 7 ? true: false);
570695
kernel = createKernel(ctx, matmul, bindings,
571696
/*nWorkgroups*/ nWorkgroups);
572-
} else if (version == 7) {
697+
} else if (version == 8) {
573698
static constexpr size_t BM = 64;
574699
static constexpr size_t BK = 8;
575700
static constexpr size_t BN = 64;
@@ -587,10 +712,21 @@ Kernel selectMatmul(Context &ctx, int version,
587712
/*Loop unrolling*/ true);
588713
kernel = createKernel(ctx, matmul, bindings,
589714
/*nWorkgroups*/ nWorkgroups);
590-
} else if (version == 8) {
591-
Shape wgSize = {256, 1, 1};
592-
Shape nWorkgroups = cdiv({M, N, 1}, {16, 16, 1});
593-
KernelCode matmul = createNoOp(kShaderNoOp, /*wgsize*/ wgSize);
715+
} else if (version == 9) {
716+
static constexpr size_t BM = 64;
717+
static constexpr size_t BK = 8;
718+
static constexpr size_t BN = 64;
719+
static constexpr size_t TM = BM / BK;
720+
static constexpr size_t TN = BN / BK;
721+
Shape wgSize = {(BM / TM) * (BN / TN), 1, 1}; // This is the same as BK * BK.
722+
Shape nWorkgroups = {cdiv(M, BM), cdiv(N, BN), 1};
723+
LOG(kDefLog, kInfo, "M: %d, K: %d, N: %d", M, K, N);
724+
LOG(kDefLog, kInfo, "BM: %d, BK: %d, BN: %d, TM: %d, TN: %d", BM, BK, BN, TM, TN);
725+
LOG(kDefLog, kInfo, "wgSize: ( %s )", toString(wgSize).c_str());
726+
LOG(kDefLog, kInfo, "nWorkgroups: ( %s )", toString(nWorkgroups).c_str());
727+
KernelCode matmul = createMatmulWithTranspose(kShaderMatmulWithTranspose, M, K, N, BM, BK, BN, TM, TN,
728+
/*wgSize*/ wgSize,
729+
kf32);
594730
kernel = createKernel(ctx, matmul, bindings,
595731
/*nWorkgroups*/ nWorkgroups);
596732
}
@@ -626,8 +762,8 @@ void runTest(int version, size_t M, size_t K, size_t N,
626762

627763
printf("[ Press enter to start tests ... ]\n");
628764
getchar();
629-
LOG(kDefLog, kInfo, "Dispatching Kernel version %d, %d iterations ...",
630-
version, nIter);
765+
LOG(kDefLog, kInfo, "Dispatching Kernel version %d: %s, %d iterations ...",
766+
version, versionToStr(version), nIter);
631767

632768
// Dispatch kernel nIter times
633769
auto start = std::chrono::high_resolution_clock::now();
@@ -662,26 +798,43 @@ void runTest(int version, size_t M, size_t K, size_t N,
662798
M, K, N, nIter, duration.count() / static_cast<double>(nIter) / 1000.0 /* us -> ms */, gflops);
663799
}
664800

801+
const char* versionToStr(int version){
802+
switch (version) {
803+
case 1: return "No-Op";
804+
case 2: return "naive matmul";
805+
case 3: return "tiling";
806+
case 4: return "1D blocktiling";
807+
case 5: return "2D blocktiling";
808+
case 6: return "1D blocktiling with loop unrolling";
809+
case 7: return "2D blocktiling with loop unrolling";
810+
case 8: return "2D blocktiling with loop unrolling and vectorization";
811+
case 9: return "2D blocktiling with loop unrolling, vectorization and transpose";
812+
default: return "Not specified";
813+
}
814+
}
815+
665816
int main() {
666817
char* version_str = getenv("MATMUL_VERSION");
667-
int version = version_str == NULL ? 7 : atoi(version_str);
668-
// 1 == naive matmul
669-
// 2 == tiling
670-
// 3 == 1D blocktiling
671-
// 4 == 2D blocktiling
672-
// 5 == 1D blocktiling with loop unrolling
673-
// 6 == 2D blocktiling with loop unrolling
674-
// 7 == 2D blocktiling with loop unrolling and vectorization
675-
// 8 == No-Op
818+
char* kTestSize_str = getenv("MATMUL_SIZE");
819+
int version = version_str == NULL ? 9 : atoi(version_str);
820+
// 1 == No-Op
821+
// 2 == naive matmul
822+
// 3 == tiling
823+
// 4 == 1D blocktiling
824+
// 5 == 2D blocktiling
825+
// 6 == 1D blocktiling with loop unrolling
826+
// 7 == 2D blocktiling with loop unrolling
827+
// 8 == 2D blocktiling with loop unrolling and vectorization
828+
// 9 == 2D blocktiling with loop unrolling, vectorization and transpose (default)
676829

677830
size_t M, K, N; // Matrix dimensions
678-
static constexpr int kTestSize = 2;
679-
if constexpr (kTestSize == 0) {
831+
int kTestSize = kTestSize_str == NULL ? 2 : atoi(kTestSize_str);
832+
if (kTestSize == 0) {
680833
// Tiny test
681834
M = 32;
682835
K = 32;
683836
N = 32;
684-
} else if constexpr (kTestSize == 1) {
837+
} else if (kTestSize == 1) {
685838
// Small test
686839
M = 256;
687840
K = 128;
@@ -696,11 +849,19 @@ int main() {
696849
std::unique_ptr<float[]> inputPtr = std::make_unique<float[]>(M * K);
697850
std::unique_ptr<float[]> weightsPtr = std::make_unique<float[]>(N * K);
698851
std::unique_ptr<float[]> outputPtr = std::make_unique<float[]>(M * N);
852+
bool transposedInput = version == 9;
699853

700854
initData(M, K, N, inputPtr, weightsPtr);
701-
runTest(version, M, K, N, inputPtr, weightsPtr, outputPtr);
855+
if (transposedInput) {
856+
std::unique_ptr<float[]> transposedWeightPtr = std::make_unique<float[]>(K * N);
857+
transpose(weightsPtr.get(), transposedWeightPtr.get(), N, K);
858+
runTest(version, M, K, N, inputPtr, transposedWeightPtr, outputPtr);
859+
} else {
860+
runTest(version, M, K, N, inputPtr, weightsPtr, outputPtr);
861+
}
862+
702863

703-
if constexpr (kTestSize <= 1) {
864+
if (kTestSize <= 1) {
704865
// Check result with CPU reference implementation for tiny/small tests
705866
checkCPU(M, K, N, inputPtr, weightsPtr, outputPtr);
706867
}

0 commit comments

Comments
 (0)