Skip to content

Commit a536c8f

Browse files
committed
Temporarily revert matrix transpose PR #35 until the segfault is resolved
1 parent 6bdd778 commit a536c8f

File tree

1 file changed

+26
-187
lines changed

1 file changed

+26
-187
lines changed

examples/matmul/run.cpp

+26-187
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,6 @@
1414

1515
using namespace gpu;
1616

17-
const char* versionToStr(int version);
18-
1917
static const char *kShaderMatmul1 = R"(
2018
@group(0) @binding(0) var<storage, read_write> A: array<{{precision}}>;
2119
@group(0) @binding(1) var<storage, read_write> B: array<{{precision}}>;
@@ -468,123 +466,6 @@ inline KernelCode createMatmulWithVectorization(const char *shaderTemplate, cons
468466
}
469467
}
470468

471-
/* 2D block-tiling with transpose
472-
*
473-
*/
474-
static const char *kShaderMatmulWithTranspose = R"(
475-
@group(0) @binding(0) var<storage, read_write> a: array<{{precision}}>;
476-
@group(0) @binding(1) var<storage, read_write> b: array<{{precision}}>;
477-
@group(0) @binding(2) var<storage, read_write> c: array<vec4<{{precision}}>>;
478-
var<workgroup> tileA: array<{{precision}}, {{BM}} * {{BK}}>;
479-
var<workgroup> tileB: array<{{precision}}, {{BK}} * {{BN}}>;
480-
481-
@compute @workgroup_size({{workgroupSize}})
482-
fn main(
483-
@builtin(global_invocation_id) globalID : vec3<u32>,
484-
@builtin(local_invocation_id) localID : vec3<u32>,
485-
@builtin(workgroup_id) groupid : vec3<u32>) {
486-
487-
var threadResults: array<vec4<{{precision}}>, {{TM}} * {{TN4}}>;
488-
var localM: array<{{precision}}, {{TM}}>;
489-
var localN: array<vec4<{{precision}}>, {{TN4}}>;
490-
491-
let cRow: u32 = groupid.x;
492-
let cCol: u32 = groupid.y;
493-
let numThread: u32 = ({{BM}} * {{BN}}) / ({{TM}} * {{TN}});
494-
495-
// position of the first c element computed by the thread
496-
let threadRow: u32 = (localID.x / ({{BN}} / {{TN}})) * {{TM}};
497-
let threadCol: u32 = (localID.x % ({{BN}} / {{TN}})) * {{TN}};
498-
499-
// aPtr and bPtr are the starting positions of the tiles in a and b,
500-
// incremented in the bkidx loop.
501-
// cPtr is the starting position of the tile in c which is fixed.
502-
503-
var aPtr: u32 = cRow * {{BM}} * {{K}};
504-
var bPtr: u32 = cCol * {{BN}};
505-
let cPtr: u32 = cRow * {{BM}} * {{N4}} + cCol * {{BN4}};
506-
507-
for (var bkidx = 0; bkidx < {{K}}; bkidx += {{BK}}) {
508-
509-
// Load tile
510-
// Load BM x BK by numThread(BM * BN / (TM * TN))
511-
// The number of iteration == BM * BK / (BM * BN / (TM * TN))
512-
for (var idx: u32 = 0; idx < {{NUM_TILEA}}; idx++) {
513-
tileA[localID.x + idx * numThread] = a[aPtr + ((localID.x + idx * numThread) / {{BK}}) * {{K}} + (localID.x + idx * numThread) % {{BK}}];
514-
}
515-
// Load BK x BN by numThread(BM * BN / (TM * TN))
516-
// The number of iteration == BK * BN / (BM * BN / (TM * TN))
517-
for (var idx: u32 = 0; idx < {{NUM_TILEB}}; idx++) {
518-
tileB[localID.x + idx * numThread] = b[bPtr + ((localID.x + idx * numThread) / {{BN}}) * {{N}} + ((localID.x + idx * numThread) % {{BN}})];
519-
}
520-
521-
aPtr += {{BK}};
522-
bPtr += {{BK}} * {{N}};
523-
524-
workgroupBarrier();
525-
// Compute tile
526-
for (var dotIdx: u32 = 0; dotIdx < {{BK}}; dotIdx = dotIdx + 1) {
527-
for (var idx: u32 = 0; idx < {{TM}}; idx++) {
528-
localM[idx] = tileA[(threadRow + idx) * {{BK}} + dotIdx];
529-
}
530-
for (var idx: u32 = 0; idx < {{TN4}}; idx++) {
531-
localN[idx] = vec4<{{precision}}>(tileB[(threadCol + idx*4 ) + dotIdx * {{BN}}],
532-
tileB[(threadCol + idx*4 + 1) + dotIdx * {{BN}}],
533-
tileB[(threadCol + idx*4 + 2) + dotIdx * {{BN}}],
534-
tileB[(threadCol + idx*4 + 3) + dotIdx * {{BN}}]);
535-
}
536-
for (var resIdxM: u32 = 0; resIdxM < {{TM}}; resIdxM++) {
537-
for (var resIdxN: u32 = 0; resIdxN < {{TN4}}; resIdxN++) {
538-
threadResults[resIdxM * {{TN4}} + resIdxN] += localM[resIdxM] * localN[resIdxN];
539-
}
540-
}
541-
}
542-
workgroupBarrier();
543-
}
544-
545-
for (var resIdxM: u32 = 0; resIdxM < {{TM}}; resIdxM++) {
546-
for (var resIdxN: u32 = 0; resIdxN < {{TN4}}; resIdxN++) {
547-
c[cPtr + (threadRow + resIdxM) * {{N4}} + (threadCol/4) + resIdxN] = threadResults[resIdxM * {{TN4}} + resIdxN];
548-
}
549-
}
550-
}
551-
)";
552-
553-
inline KernelCode createMatmulWithTranspose(const char *shaderTemplate, const size_t M,
554-
const size_t K, const size_t N, const size_t BM,
555-
const size_t BK, const size_t BN,
556-
const size_t TM, const size_t TN,
557-
const Shape &workgroupSize = {256, 1, 1},
558-
NumType precision = kf32) {
559-
assert(BM % TM == 0);
560-
assert(BN % TN == 0);
561-
assert(K % BK == 0);
562-
assert(M % BM == 0);
563-
assert(N % BN == 0);
564-
// # threads = tile A size == tile B size == # threads for computing C
565-
int num_threads = BM * BN / (TM * TN);
566-
std::string codeString(shaderTemplate);
567-
replaceAll(codeString, {{"{{workgroupSize}}", toString(workgroupSize)},
568-
{"{{precision}}", toString(precision)},
569-
{"{{M}}", toString(M)},
570-
{"{{K}}", toString(K)},
571-
{"{{N}}", toString(N)},
572-
{"{{BM}}", toString(BM)},
573-
{"{{BK}}", toString(BK)},
574-
{"{{BN}}", toString(BN)},
575-
{"{{TM}}", toString(TM)},
576-
{"{{TN}}", toString(TN)},
577-
{"{{NUM_TILEA}}", toString(BM * BK / num_threads)},
578-
{"{{NUM_TILEB}}", toString(BN * BK / num_threads)},
579-
{"{{TN4}}", toString(TN / 4)},
580-
{"{{N4}}", toString(N / 4)},
581-
{"{{BN4}}", toString(BN / 4)},
582-
});
583-
std::string unrolledCode = loopUnrolling(codeString);
584-
// LOG(kDefLog, kInfo, "Unrolled code:\n%s", unrolledCode.c_str());
585-
return {unrolledCode, workgroupSize};
586-
}
587-
588469
/**
589470
* @brief No-Op shader with matmul bindings for performance testing
590471
*/
@@ -638,26 +519,20 @@ Kernel selectMatmul(Context &ctx, int version,
638519
size_t M, size_t K, size_t N) {
639520
Kernel kernel;
640521
if (version == 1) {
641-
Shape wgSize = {256, 1, 1};
642-
Shape nWorkgroups = cdiv({M, N, 1}, {16, 16, 1});
643-
KernelCode matmul = createNoOp(kShaderNoOp, /*wgsize*/ wgSize);
644-
kernel = createKernel(ctx, matmul, bindings,
645-
/*nWorkgroups*/ nWorkgroups);
646-
} else if (version == 2) {
647522
Shape wgSize = {16, 16, 1};
648523
LOG(kDefLog, kInfo, "wgSize: %s", toString(wgSize).c_str());
649524
KernelCode matmul =
650525
createMatmul1(kShaderMatmul1, M, K, N, /*wgsize*/ wgSize);
651526
kernel = createKernel(ctx, matmul, bindings,
652527
/*nWorkgroups*/ cdiv({M, N, 1}, wgSize));
653-
} else if (version == 3) {
528+
} else if (version == 2) {
654529
static constexpr size_t tileSize = 16;
655530
KernelCode matmul = createMatmul2(kShaderMatmul2, M, K, N,
656531
/*wgSize*/ {tileSize * tileSize, 1, 1});
657532
kernel =
658533
createKernel(ctx, matmul, bindings,
659534
/* nWorkgroups*/ cdiv({M, N, 1}, {tileSize, tileSize, 1}));
660-
} else if (version == 4 || version == 6) {
535+
} else if (version == 3 || version == 5) {
661536
static constexpr size_t BM = 64;
662537
static constexpr size_t BK = 4;
663538
static constexpr size_t BN = BM;
@@ -673,10 +548,10 @@ Kernel selectMatmul(Context &ctx, int version,
673548
KernelCode matmul = createMatmul3(kShaderMatmul3, M, K, N, BM, BK, BN, TM,
674549
/*wgSize*/ wgSize,
675550
kf32,
676-
/*Loop unrolling*/ version == 6 ? true: false);
551+
/*Loop unrolling*/ version == 5 ? true: false);
677552
kernel = createKernel(ctx, matmul, bindings,
678553
/*nWorkgroups*/ nWorkgroups);
679-
} else if (version == 5 || version == 7) {
554+
} else if (version == 4 || version == 6) {
680555
static constexpr size_t BM = 64;
681556
static constexpr size_t BK = 8;
682557
static constexpr size_t BN = 64;
@@ -691,10 +566,10 @@ Kernel selectMatmul(Context &ctx, int version,
691566
KernelCode matmul = createMatmul4(kShaderMatmul4, M, K, N, BM, BK, BN, TM, TN,
692567
/*wgSize*/ wgSize,
693568
kf32,
694-
/*Loop unrolling*/ version == 7 ? true: false);
569+
/*Loop unrolling*/ version == 6 ? true: false);
695570
kernel = createKernel(ctx, matmul, bindings,
696571
/*nWorkgroups*/ nWorkgroups);
697-
} else if (version == 8) {
572+
} else if (version == 7) {
698573
static constexpr size_t BM = 64;
699574
static constexpr size_t BK = 8;
700575
static constexpr size_t BN = 64;
@@ -712,21 +587,10 @@ Kernel selectMatmul(Context &ctx, int version,
712587
/*Loop unrolling*/ true);
713588
kernel = createKernel(ctx, matmul, bindings,
714589
/*nWorkgroups*/ nWorkgroups);
715-
} else if (version == 9) {
716-
static constexpr size_t BM = 64;
717-
static constexpr size_t BK = 8;
718-
static constexpr size_t BN = 64;
719-
static constexpr size_t TM = BM / BK;
720-
static constexpr size_t TN = BN / BK;
721-
Shape wgSize = {(BM / TM) * (BN / TN), 1, 1}; // This is the same as BK * BK.
722-
Shape nWorkgroups = {cdiv(M, BM), cdiv(N, BN), 1};
723-
LOG(kDefLog, kInfo, "M: %d, K: %d, N: %d", M, K, N);
724-
LOG(kDefLog, kInfo, "BM: %d, BK: %d, BN: %d, TM: %d, TN: %d", BM, BK, BN, TM, TN);
725-
LOG(kDefLog, kInfo, "wgSize: ( %s )", toString(wgSize).c_str());
726-
LOG(kDefLog, kInfo, "nWorkgroups: ( %s )", toString(nWorkgroups).c_str());
727-
KernelCode matmul = createMatmulWithTranspose(kShaderMatmulWithTranspose, M, K, N, BM, BK, BN, TM, TN,
728-
/*wgSize*/ wgSize,
729-
kf32);
590+
} else if (version == 8) {
591+
Shape wgSize = {256, 1, 1};
592+
Shape nWorkgroups = cdiv({M, N, 1}, {16, 16, 1});
593+
KernelCode matmul = createNoOp(kShaderNoOp, /*wgsize*/ wgSize);
730594
kernel = createKernel(ctx, matmul, bindings,
731595
/*nWorkgroups*/ nWorkgroups);
732596
}
@@ -762,8 +626,8 @@ void runTest(int version, size_t M, size_t K, size_t N,
762626

763627
printf("[ Press enter to start tests ... ]\n");
764628
getchar();
765-
LOG(kDefLog, kInfo, "Dispatching Kernel version %d: %s, %d iterations ...",
766-
version, versionToStr(version), nIter);
629+
LOG(kDefLog, kInfo, "Dispatching Kernel version %d, %d iterations ...",
630+
version, nIter);
767631

768632
// Dispatch kernel nIter times
769633
auto start = std::chrono::high_resolution_clock::now();
@@ -798,43 +662,26 @@ void runTest(int version, size_t M, size_t K, size_t N,
798662
M, K, N, nIter, duration.count() / static_cast<double>(nIter) / 1000.0 /* us -> ms */, gflops);
799663
}
800664

801-
const char* versionToStr(int version){
802-
switch (version) {
803-
case 1: return "No-Op";
804-
case 2: return "naive matmul";
805-
case 3: return "tiling";
806-
case 4: return "1D blocktiling";
807-
case 5: return "2D blocktiling";
808-
case 6: return "1D blocktiling with loop unrolling";
809-
case 7: return "2D blocktiling with loop unrolling";
810-
case 8: return "2D blocktiling with loop unrolling and vectorization";
811-
case 9: return "2D blocktiling with loop unrolling, vectorization and transpose";
812-
default: return "Not specified";
813-
}
814-
}
815-
816665
int main() {
817666
char* version_str = getenv("MATMUL_VERSION");
818-
char* kTestSize_str = getenv("MATMUL_SIZE");
819-
int version = version_str == NULL ? 9 : atoi(version_str);
820-
// 1 == No-Op
821-
// 2 == naive matmul
822-
// 3 == tiling
823-
// 4 == 1D blocktiling
824-
// 5 == 2D blocktiling
825-
// 6 == 1D blocktiling with loop unrolling
826-
// 7 == 2D blocktiling with loop unrolling
827-
// 8 == 2D blocktiling with loop unrolling and vectorization
828-
// 9 == 2D blocktiling with loop unrolling, vectorization and transpose (default)
667+
int version = version_str == NULL ? 7 : atoi(version_str);
668+
// 1 == naive matmul
669+
// 2 == tiling
670+
// 3 == 1D blocktiling
671+
// 4 == 2D blocktiling
672+
// 5 == 1D blocktiling with loop unrolling
673+
// 6 == 2D blocktiling with loop unrolling
674+
// 7 == 2D blocktiling with loop unrolling and vectorization
675+
// 8 == No-Op
829676

830677
size_t M, K, N; // Matrix dimensions
831-
int kTestSize = kTestSize_str == NULL ? 2 : atoi(kTestSize_str);
832-
if (kTestSize == 0) {
678+
static constexpr int kTestSize = 2;
679+
if constexpr (kTestSize == 0) {
833680
// Tiny test
834681
M = 32;
835682
K = 32;
836683
N = 32;
837-
} else if (kTestSize == 1) {
684+
} else if constexpr (kTestSize == 1) {
838685
// Small test
839686
M = 256;
840687
K = 128;
@@ -849,19 +696,11 @@ int main() {
849696
std::unique_ptr<float[]> inputPtr = std::make_unique<float[]>(M * K);
850697
std::unique_ptr<float[]> weightsPtr = std::make_unique<float[]>(N * K);
851698
std::unique_ptr<float[]> outputPtr = std::make_unique<float[]>(M * N);
852-
bool transposedInput = version == 9;
853699

854700
initData(M, K, N, inputPtr, weightsPtr);
855-
if (transposedInput) {
856-
std::unique_ptr<float[]> transposedWeightPtr = std::make_unique<float[]>(K * N);
857-
transpose(weightsPtr.get(), transposedWeightPtr.get(), N, K);
858-
runTest(version, M, K, N, inputPtr, transposedWeightPtr, outputPtr);
859-
} else {
860-
runTest(version, M, K, N, inputPtr, weightsPtr, outputPtr);
861-
}
862-
701+
runTest(version, M, K, N, inputPtr, weightsPtr, outputPtr);
863702

864-
if (kTestSize <= 1) {
703+
if constexpr (kTestSize <= 1) {
865704
// Check result with CPU reference implementation for tiny/small tests
866705
checkCPU(M, K, N, inputPtr, weightsPtr, outputPtr);
867706
}

0 commit comments

Comments
 (0)