Skip to content

Commit cf19e1a

Browse files
Add matrix multiplication with transpose
1 parent 9586723 commit cf19e1a

File tree

2 files changed

+158
-47
lines changed

2 files changed

+158
-47
lines changed

examples/matmul/run.cpp

Lines changed: 118 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414

1515
using namespace gpu;
1616

17+
const char* versionToStr(int version);
18+
1719
static const char *kShaderMatmul1 = R"(
1820
@group(0) @binding(0) var<storage, read_write> A: array<{{precision}}>;
1921
@group(0) @binding(1) var<storage, read_write> B: array<{{precision}}>;
@@ -220,7 +222,7 @@ inline KernelCode createMatmul3(const char *shaderTemplate, const size_t M,
220222
{"{{BN}}", toString(BN)},
221223
{"{{TM}}", toString(TM)}});
222224
if (unrolling) {
223-
std::string unrolledCode = loopUnrolling(codeString);
225+
std::string unrolledCode = loopUnrolling(removeUnnecessaryIfStatements(codeString));
224226
// LOG(kDefLog, kInfo, "Unrolled code:\n%s", unrolledCode.c_str());
225227
return {unrolledCode, workgroupSize};
226228
} else {
@@ -260,9 +262,14 @@ fn main(
260262
// incremented in the bkidx loop.
261263
// cPtr is the starting position of the tile in c which is fixed.
262264
263-
var aPtr = cRow * {{BM}} * {{K}};
264-
var bPtr = cCol * {{BN}} * {{K}};
265-
let cPtr = cRow * {{BM}} * {{N}} + cCol * {{BN}};
265+
var aPtr: u32 = cRow * {{BM}} * {{K}};
266+
var bPtr: u32 = 0;
267+
if ({{TRANSPOSE}}) {
268+
bPtr = cCol * {{BN}};
269+
} else {
270+
bPtr = cCol * {{BN}} * {{K}};
271+
}
272+
let cPtr: u32 = cRow * {{BM}} * {{N}} + cCol * {{BN}};
266273
267274
for (var bkidx = 0; bkidx < {{K}}; bkidx += {{BK}}) {
268275
@@ -275,11 +282,19 @@ fn main(
275282
// Load BK x BN by numThread(BM * BN / (TM * TN))
276283
// The number of iteration == BK * BN / (BM * BN / (TM * TN))
277284
for (var idx: u32 = 0; idx < {{NUM_TILEB}}; idx++) {
278-
tileB[localID.x + idx * numThread] = b[bPtr + ((localID.x + idx * numThread) / {{BK}}) * {{K}} + ((localID.x + idx * numThread) % {{BK}})];
285+
if ({{TRANSPOSE}}) {
286+
tileB[localID.x + idx * numThread] = b[bPtr + ((localID.x + idx * numThread) / {{BN}}) * {{N}} + ((localID.x + idx * numThread) % {{BN}})];
287+
} else {
288+
tileB[localID.x + idx * numThread] = b[bPtr + ((localID.x + idx * numThread) / {{BK}}) * {{K}} + ((localID.x + idx * numThread) % {{BK}})];
289+
}
279290
}
280291
281292
aPtr += {{BK}};
282-
bPtr += {{BK}};
293+
if ({{TRANSPOSE}}) {
294+
bPtr += {{BK}} * {{N}};
295+
} else {
296+
bPtr += {{BK}};
297+
}
283298
284299
workgroupBarrier();
285300
// Compute tile
@@ -288,7 +303,11 @@ fn main(
288303
localM[idx] = tileA[(threadRow + idx) * {{BK}} + dotIdx];
289304
}
290305
for (var idx: u32 = 0; idx < {{TN}}; idx++) {
291-
localN[idx] = tileB[(threadCol + idx) * {{BK}} + dotIdx];
306+
if ({{TRANSPOSE}}) {
307+
localN[idx] = tileB[dotIdx * {{BN}} + threadCol + idx];
308+
} else {
309+
localN[idx] = tileB[(threadCol + idx) * {{BK}} + dotIdx];
310+
}
292311
}
293312
for (var resIdxM: u32 = 0; resIdxM < {{TM}}; resIdxM++) {
294313
for (var resIdxN: u32 = 0; resIdxN < {{TN}}; resIdxN++) {
@@ -311,6 +330,7 @@ inline KernelCode createMatmul4(const char *shaderTemplate, const size_t M,
311330
const size_t K, const size_t N, const size_t BM,
312331
const size_t BK, const size_t BN,
313332
const size_t TM, const size_t TN,
333+
const bool transpose = false,
314334
const Shape &workgroupSize = {256, 1, 1},
315335
NumType precision = kf32,
316336
bool unrolling = false) {
@@ -333,17 +353,19 @@ inline KernelCode createMatmul4(const char *shaderTemplate, const size_t M,
333353
{"{{TM}}", toString(TM)},
334354
{"{{TN}}", toString(TN)},
335355
{"{{NUM_TILEA}}", toString(BM * BK / num_threads)},
336-
{"{{NUM_TILEB}}", toString(BN * BK / num_threads)}
356+
{"{{NUM_TILEB}}", toString(BN * BK / num_threads)},
357+
{"{{TRANSPOSE}}", transpose ? "true" : "false"},
337358
});
338359
if (unrolling) {
339-
std::string unrolledCode = loopUnrolling(codeString);
360+
std::string unrolledCode = loopUnrolling(removeUnnecessaryIfStatements(codeString));
340361
// LOG(kDefLog, kInfo, "Unrolled code:\n%s", unrolledCode.c_str());
341362
return {unrolledCode, workgroupSize};
342363
} else {
343364
return {codeString, workgroupSize};
344365
}
345366
}
346367

368+
347369
/* 2D block-tiling with vectorization
348370
*
349371
*/
@@ -376,9 +398,14 @@ fn main(
376398
// incremented in the bkidx loop.
377399
// cPtr is the starting position of the tile in c which is fixed.
378400
379-
var aPtr = cRow * {{BM}} * {{K}};
380-
var bPtr = cCol * {{BN}} * {{K}};
381-
let cPtr = cRow * {{BM}} * {{N4}} + cCol * {{BN4}};
401+
var aPtr: u32 = cRow * {{BM}} * {{K}};
402+
var bPtr: u32 = 0;
403+
if ({{TRANSPOSE}}) {
404+
bPtr = cCol * {{BN}};
405+
} else {
406+
bPtr = cCol * {{BN}} * {{K}};
407+
}
408+
let cPtr: u32 = cRow * {{BM}} * {{N4}} + cCol * {{BN4}};
382409
383410
for (var bkidx = 0; bkidx < {{K}}; bkidx += {{BK}}) {
384411
@@ -391,11 +418,19 @@ fn main(
391418
// Load BK x BN by numThread(BM * BN / (TM * TN))
392419
// The number of iteration == BK * BN / (BM * BN / (TM * TN))
393420
for (var idx: u32 = 0; idx < {{NUM_TILEB}}; idx++) {
394-
tileB[localID.x + idx * numThread] = b[bPtr + ((localID.x + idx * numThread) / {{BK}}) * {{K}} + ((localID.x + idx * numThread) % {{BK}})];
421+
if ({{TRANSPOSE}}) {
422+
tileB[localID.x + idx * numThread] = b[bPtr + ((localID.x + idx * numThread) / {{BN}}) * {{N}} + ((localID.x + idx * numThread) % {{BN}})];
423+
} else {
424+
tileB[localID.x + idx * numThread] = b[bPtr + ((localID.x + idx * numThread) / {{BK}}) * {{K}} + ((localID.x + idx * numThread) % {{BK}})];
425+
}
395426
}
396427
397428
aPtr += {{BK}};
398-
bPtr += {{BK}};
429+
if ({{TRANSPOSE}}) {
430+
bPtr += {{BK}} * {{N}};
431+
} else {
432+
bPtr += {{BK}};
433+
}
399434
400435
workgroupBarrier();
401436
// Compute tile
@@ -404,10 +439,17 @@ fn main(
404439
localM[idx] = tileA[(threadRow + idx) * {{BK}} + dotIdx];
405440
}
406441
for (var idx: u32 = 0; idx < {{TN4}}; idx++) {
407-
localN[idx] = vec4<{{precision}}>(tileB[(threadCol + idx*4 ) * {{BK}} + dotIdx],
408-
tileB[(threadCol + idx*4 + 1) * {{BK}} + dotIdx],
409-
tileB[(threadCol + idx*4 + 2) * {{BK}} + dotIdx],
410-
tileB[(threadCol + idx*4 + 3) * {{BK}} + dotIdx]);
442+
if ({{TRANSPOSE}}) {
443+
localN[idx] = vec4<{{precision}}>(tileB[(threadCol + idx*4 ) + dotIdx * {{BN}}],
444+
tileB[(threadCol + idx*4 + 1) + dotIdx * {{BN}}],
445+
tileB[(threadCol + idx*4 + 2) + dotIdx * {{BN}}],
446+
tileB[(threadCol + idx*4 + 3) + dotIdx * {{BN}}]);
447+
} else {
448+
localN[idx] = vec4<{{precision}}>(tileB[(threadCol + idx*4 ) * {{BK}} + dotIdx],
449+
tileB[(threadCol + idx*4 + 1) * {{BK}} + dotIdx],
450+
tileB[(threadCol + idx*4 + 2) * {{BK}} + dotIdx],
451+
tileB[(threadCol + idx*4 + 3) * {{BK}} + dotIdx]);
452+
}
411453
}
412454
for (var resIdxM: u32 = 0; resIdxM < {{TM}}; resIdxM++) {
413455
for (var resIdxN: u32 = 0; resIdxN < {{TN4}}; resIdxN++) {
@@ -430,6 +472,7 @@ inline KernelCode createMatmulWithVectorization(const char *shaderTemplate, cons
430472
const size_t K, const size_t N, const size_t BM,
431473
const size_t BK, const size_t BN,
432474
const size_t TM, const size_t TN,
475+
const bool transpose = false,
433476
const Shape &workgroupSize = {256, 1, 1},
434477
NumType precision = kf32,
435478
bool unrolling = false) {
@@ -456,9 +499,10 @@ inline KernelCode createMatmulWithVectorization(const char *shaderTemplate, cons
456499
{"{{TN4}}", toString(TN / 4)},
457500
{"{{N4}}", toString(N / 4)},
458501
{"{{BN4}}", toString(BN / 4)},
502+
{"{{TRANSPOSE}}", transpose ? "true" : "false"},
459503
});
460504
if (unrolling) {
461-
std::string unrolledCode = loopUnrolling(codeString);
505+
std::string unrolledCode = loopUnrolling(removeUnnecessaryIfStatements(codeString));
462506
// LOG(kDefLog, kInfo, "Unrolled code:\n%s", unrolledCode.c_str());
463507
return {unrolledCode, workgroupSize};
464508
} else {
@@ -519,20 +563,26 @@ Kernel selectMatmul(Context &ctx, int version,
519563
size_t M, size_t K, size_t N) {
520564
Kernel kernel;
521565
if (version == 1) {
566+
Shape wgSize = {256, 1, 1};
567+
Shape nWorkgroups = cdiv({M, N, 1}, {16, 16, 1});
568+
KernelCode matmul = createNoOp(kShaderNoOp, /*wgsize*/ wgSize);
569+
kernel = createKernel(ctx, matmul, bindings,
570+
/*nWorkgroups*/ nWorkgroups);
571+
} else if (version == 2) {
522572
Shape wgSize = {16, 16, 1};
523573
LOG(kDefLog, kInfo, "wgSize: %s", toString(wgSize).c_str());
524574
KernelCode matmul =
525575
createMatmul1(kShaderMatmul1, M, K, N, /*wgsize*/ wgSize);
526576
kernel = createKernel(ctx, matmul, bindings,
527577
/*nWorkgroups*/ cdiv({M, N, 1}, wgSize));
528-
} else if (version == 2) {
578+
} else if (version == 3) {
529579
static constexpr size_t tileSize = 16;
530580
KernelCode matmul = createMatmul2(kShaderMatmul2, M, K, N,
531581
/*wgSize*/ {tileSize * tileSize, 1, 1});
532582
kernel =
533583
createKernel(ctx, matmul, bindings,
534584
/* nWorkgroups*/ cdiv({M, N, 1}, {tileSize, tileSize, 1}));
535-
} else if (version == 3 || version == 5) {
585+
} else if (version == 4 || version == 6) {
536586
static constexpr size_t BM = 64;
537587
static constexpr size_t BK = 4;
538588
static constexpr size_t BN = BM;
@@ -548,10 +598,10 @@ Kernel selectMatmul(Context &ctx, int version,
548598
KernelCode matmul = createMatmul3(kShaderMatmul3, M, K, N, BM, BK, BN, TM,
549599
/*wgSize*/ wgSize,
550600
kf32,
551-
/*Loop unrolling*/ version == 5 ? true: false);
601+
/*Loop unrolling*/ version == 6 ? true: false);
552602
kernel = createKernel(ctx, matmul, bindings,
553603
/*nWorkgroups*/ nWorkgroups);
554-
} else if (version == 4 || version == 6) {
604+
} else if (version == 5 || version == 7) {
555605
static constexpr size_t BM = 64;
556606
static constexpr size_t BK = 8;
557607
static constexpr size_t BN = 64;
@@ -564,12 +614,13 @@ Kernel selectMatmul(Context &ctx, int version,
564614
LOG(kDefLog, kInfo, "wgSize: ( %s )", toString(wgSize).c_str());
565615
LOG(kDefLog, kInfo, "nWorkgroups: ( %s )", toString(nWorkgroups).c_str());
566616
KernelCode matmul = createMatmul4(kShaderMatmul4, M, K, N, BM, BK, BN, TM, TN,
617+
/*tranpose*/ false,
567618
/*wgSize*/ wgSize,
568619
kf32,
569-
/*Loop unrolling*/ version == 6 ? true: false);
620+
/*Loop unrolling*/ version == 7 ? true: false);
570621
kernel = createKernel(ctx, matmul, bindings,
571622
/*nWorkgroups*/ nWorkgroups);
572-
} else if (version == 7) {
623+
} else if (version == 8 || version == 9) {
573624
static constexpr size_t BM = 64;
574625
static constexpr size_t BK = 8;
575626
static constexpr size_t BN = 64;
@@ -582,17 +633,12 @@ Kernel selectMatmul(Context &ctx, int version,
582633
LOG(kDefLog, kInfo, "wgSize: ( %s )", toString(wgSize).c_str());
583634
LOG(kDefLog, kInfo, "nWorkgroups: ( %s )", toString(nWorkgroups).c_str());
584635
KernelCode matmul = createMatmulWithVectorization(kShaderMatmulWithVectorization, M, K, N, BM, BK, BN, TM, TN,
636+
/*tranpose*/ version == 9,
585637
/*wgSize*/ wgSize,
586638
kf32,
587639
/*Loop unrolling*/ true);
588640
kernel = createKernel(ctx, matmul, bindings,
589641
/*nWorkgroups*/ nWorkgroups);
590-
} else if (version == 8) {
591-
Shape wgSize = {256, 1, 1};
592-
Shape nWorkgroups = cdiv({M, N, 1}, {16, 16, 1});
593-
KernelCode matmul = createNoOp(kShaderNoOp, /*wgsize*/ wgSize);
594-
kernel = createKernel(ctx, matmul, bindings,
595-
/*nWorkgroups*/ nWorkgroups);
596642
}
597643
return kernel;
598644
}
@@ -626,8 +672,8 @@ void runTest(int version, size_t M, size_t K, size_t N,
626672

627673
printf("[ Press enter to start tests ... ]\n");
628674
getchar();
629-
LOG(kDefLog, kInfo, "Dispatching Kernel version %d, %d iterations ...",
630-
version, nIter);
675+
LOG(kDefLog, kInfo, "Dispatching Kernel version %d: %s, %d iterations ...",
676+
version, versionToStr(version), nIter);
631677

632678
// Dispatch kernel nIter times
633679
auto start = std::chrono::high_resolution_clock::now();
@@ -662,26 +708,43 @@ void runTest(int version, size_t M, size_t K, size_t N,
662708
M, K, N, nIter, duration.count() / static_cast<double>(nIter) / 1000.0 /* us -> ms */, gflops);
663709
}
664710

711+
const char* versionToStr(int version){
712+
switch (version) {
713+
case 1: return "No-Op";
714+
case 2: return "naive matmul";
715+
case 3: return "tiling";
716+
case 4: return "1D blocktiling";
717+
case 5: return "2D blocktiling";
718+
case 6: return "1D blocktiling with loop unrolling";
719+
case 7: return "2D blocktiling with loop unrolling";
720+
case 8: return "2D blocktiling with loop unrolling and vectorization";
721+
case 9: return "2D blocktiling with loop unrolling, vectorization and transpose";
722+
default: return "Not specified";
723+
}
724+
}
725+
665726
int main() {
666727
char* version_str = getenv("MATMUL_VERSION");
667-
int version = version_str == NULL ? 7 : atoi(version_str);
668-
// 1 == naive matmul
669-
// 2 == tiling
670-
// 3 == 1D blocktiling
671-
// 4 == 2D blocktiling
672-
// 5 == 1D blocktiling with loop unrolling
673-
// 6 == 2D blocktiling with loop unrolling
674-
// 7 == 2D blocktiling with loop unrolling and vectorization
675-
// 8 == No-Op
728+
char* kTestSize_str = getenv("MATMUL_SIZE");
729+
int version = version_str == NULL ? 9 : atoi(version_str);
730+
// 1 == No-Op
731+
// 2 == naive matmul
732+
// 3 == tiling
733+
// 4 == 1D blocktiling
734+
// 5 == 2D blocktiling
735+
// 6 == 1D blocktiling with loop unrolling
736+
// 7 == 2D blocktiling with loop unrolling
737+
// 8 == 2D blocktiling with loop unrolling and vectorization
738+
// 9 == 2D blocktiling with loop unrolling, vectorization and transpose (default)
676739

677740
size_t M, K, N; // Matrix dimensions
678-
static constexpr int kTestSize = 2;
679-
if constexpr (kTestSize == 0) {
741+
int kTestSize = kTestSize_str == NULL ? 2 : atoi(kTestSize_str);
742+
if (kTestSize == 0) {
680743
// Tiny test
681744
M = 32;
682745
K = 32;
683746
N = 32;
684-
} else if constexpr (kTestSize == 1) {
747+
} else if (kTestSize == 1) {
685748
// Small test
686749
M = 256;
687750
K = 128;
@@ -696,11 +759,19 @@ int main() {
696759
std::unique_ptr<float[]> inputPtr = std::make_unique<float[]>(M * K);
697760
std::unique_ptr<float[]> weightsPtr = std::make_unique<float[]>(N * K);
698761
std::unique_ptr<float[]> outputPtr = std::make_unique<float[]>(M * N);
762+
bool transposedInput = version == 9;
699763

700764
initData(M, K, N, inputPtr, weightsPtr);
701-
runTest(version, M, K, N, inputPtr, weightsPtr, outputPtr);
765+
if (transposedInput) {
766+
std::unique_ptr<float[]> transposedWeightPtr = std::make_unique<float[]>(K * N);
767+
transpose(weightsPtr.get(), transposedWeightPtr.get(), N, K);
768+
runTest(version, M, K, N, inputPtr, transposedWeightPtr, outputPtr);
769+
} else {
770+
runTest(version, M, K, N, inputPtr, weightsPtr, outputPtr);
771+
}
772+
702773

703-
if constexpr (kTestSize <= 1) {
774+
if (kTestSize <= 1) {
704775
// Check result with CPU reference implementation for tiny/small tests
705776
checkCPU(M, K, N, inputPtr, weightsPtr, outputPtr);
706777
}

experimental/wgsl.h

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,46 @@ std::string loopUnrolling(const std::string& code, int threshold = 32) {
7979
return unrolledCode;
8080
}
8181

82+
83+
std::string removeUnnecessaryIfStatements(const std::string& code) {
84+
// Pattern to match if(true) {...} else {...}
85+
std::regex ifTrueElsePattern(R"(if\s*\(\s*true\s*\)\s*\{([^{}]*)\}\s*else\s*\{([^{}]*)\})");
86+
// Pattern to match if(false) {...} else {...}
87+
std::regex ifFalseElsePattern(R"(if\s*\(\s*false\s*\)\s*\{([^{}]*)\}\s*else\s*\{([^{}]*)\})");
88+
// Pattern to match if(true) {...}
89+
std::regex ifTruePattern(R"(if\s*\(\s*true\s*\)\s*\{([^{}]*)\})");
90+
// Pattern to match if(false) {...}
91+
std::regex ifFalsePattern(R"(if\s*\(\s*false\s*\)\s*\{([^{}]*)\})");
92+
93+
std::string optimizedCode = code;
94+
std::smatch match;
95+
96+
// Handle if(true) {...} else {...}
97+
while (std::regex_search(optimizedCode, match, ifTrueElsePattern)) {
98+
std::string trueBlock = match[1].str();
99+
optimizedCode = optimizedCode.substr(0, match.position()) + trueBlock + optimizedCode.substr(match.position() + match.length());
100+
}
101+
102+
// Handle if(false) {...} else {...}
103+
while (std::regex_search(optimizedCode, match, ifFalseElsePattern)) {
104+
std::string elseBlock = match[2].str();
105+
optimizedCode = optimizedCode.substr(0, match.position()) + elseBlock + optimizedCode.substr(match.position() + match.length());
106+
}
107+
108+
// Handle if(true) {...}
109+
while (std::regex_search(optimizedCode, match, ifTruePattern)) {
110+
std::string trueBlock = match[1].str();
111+
optimizedCode = optimizedCode.substr(0, match.position()) + trueBlock + optimizedCode.substr(match.position() + match.length());
112+
}
113+
114+
// Handle if(false) {...}
115+
while (std::regex_search(optimizedCode, match, ifFalsePattern)) {
116+
optimizedCode = optimizedCode.substr(0, match.position()) + optimizedCode.substr(match.position() + match.length());
117+
}
118+
119+
return optimizedCode;
120+
}
121+
82122
} // namespace gpu
83123

84124
#endif

0 commit comments

Comments
 (0)