Commit 17614f7

Add the SubgroupMatrixMultiply shader
1 parent 3d6e51c commit 17614f7

File tree

1 file changed: +182 -43 lines changed

examples/matmul/run.cpp

Lines changed: 182 additions & 43 deletions
@@ -613,6 +613,93 @@ inline KernelCode createMatmulWithTranspose(const char *shaderTemplate, const si
   return {unrolledCode, workgroupSize, precision};
 }
 
+inline KernelCode createMatmul12(const char *shaderTemplate, const size_t M,
+                                 const size_t K, const size_t N,
+                                 const size_t TM, const size_t TN,
+                                 const size_t LID,
+                                 const Shape &workgroupSize = {256, 1, 1},
+                                 NumType precision = kf32) {
+  std::string codeString(shaderTemplate);
+  replaceAll(codeString, {{"{{precision}}", toString(precision)},
+                          {"{{M}}", toString(M)},
+                          {"{{K}}", toString(K)},
+                          {"{{N}}", toString(N)},
+                          {"{{TM}}", toString(TM)},
+                          {"{{TN}}", toString(TN)},
+                          {"{{LID}}", toString(LID)}
+                         });
+  return {loopUnrolling(codeString), workgroupSize, precision};
+}
+
+// ─────────────────────────────────────────────────────────────────────────────
+// Optimised WGSL matrix-multiply kernel using subgroupMatrixLoad/Store
+// and subgroupMatrixMultiplyAccumulate
+// ─────────────────────────────────────────────────────────────────────────────
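+//
+// Tiling (as implemented below): workgroup x owns a block of 8*TM rows of C;
+// each of the LID subgroups in a workgroup (indexed by localID.y) owns 8*TN
+// columns, so one workgroup covers an (8*TM) x (8*TN*LID) tile of C. The k-loop
+// advances 8 elements at a time, loading 8x8 fragments of A (row stride K) and
+// B (row stride N) and accumulating TM*TN result fragments with
+// subgroupMatrixMultiplyAccumulate before storing them back to C.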
+const char* kShaderSubgroupMatrixMultiply = R"(
+enable chromium_experimental_subgroup_matrix;
+diagnostic (off, chromium.subgroup_matrix_uniformity);
+
+@group(0) @binding(0) var<storage, read_write> A: array<{{precision}}>;
+@group(0) @binding(1) var<storage, read_write> B: array<{{precision}}>;
+@group(0) @binding(2) var<storage, read_write> C: array<{{precision}}>;
+
+@compute @workgroup_size({{workgroupSize}})
+fn main(@builtin(workgroup_id) wg: vec3<u32>,
+        @builtin(local_invocation_id) localID : vec3<u32>) {
+
+    let rowStart: u32 = wg.x * 8u * {{TM}};
+    let colStart: u32 = (wg.y * {{LID}} + localID.y) * 8u * {{TN}};
+
+    let baseA: u32 = rowStart * {{K}};
+    let baseB: u32 = colStart;
+    let cBase: u32 = rowStart * {{N}} + colStart;
+
+    var Ax: array<subgroup_matrix_left<{{precision}}, 8, 8>, {{TM}}>;
+    var Bx: array<subgroup_matrix_right<{{precision}}, 8, 8>, {{TN}}>;
+
+    // TM x TN accumulators (8x8 each)
+    var accxx: array<subgroup_matrix_result<{{precision}}, 8, 8>, {{TM}} * {{TN}}>;
+
+    for (var idx_i: u32 = 0; idx_i < {{TM}}; idx_i++) {
+        Ax[idx_i] = subgroup_matrix_left<{{precision}}, 8, 8>(0);
+    }
+
+    for (var idx_i: u32 = 0; idx_i < {{TN}}; idx_i++) {
+        Bx[idx_i] = subgroup_matrix_right<{{precision}}, 8, 8>(0);
+    }
+
+    for (var idx_i: u32 = 0; idx_i < {{TM}}; idx_i++) {
+        for (var idx_j: u32 = 0; idx_j < {{TN}}; idx_j++) {
+            accxx[idx_i+idx_j*{{TM}}] = subgroup_matrix_result<{{precision}}, 8, 8>(0);
+        }
+    }
+
+    for (var k: u32 = 0u; k < {{K}}; k = k + 8u) {
+        workgroupBarrier();
+        for (var idx_i: u32 = 0; idx_i < {{TM}}; idx_i++) {
+            Ax[idx_i] = subgroupMatrixLoad<subgroup_matrix_left<{{precision}},8,8>>(&A, baseA + k + 8u * {{K}} * idx_i, false, {{K}});
+        }
+
+        for (var idx_i: u32 = 0; idx_i < {{TN}}; idx_i++) {
+            Bx[idx_i] = subgroupMatrixLoad<subgroup_matrix_right<{{precision}},8,8>>(&B, baseB + k * {{N}} + 8u * idx_i, false, {{N}});
+        }
+
+        for (var idx_j: u32 = 0; idx_j < {{TN}}; idx_j++) {
+            for (var idx_i: u32 = 0; idx_i < {{TM}}; idx_i++) {
+                accxx[idx_j*{{TM}} + idx_i] = subgroupMatrixMultiplyAccumulate(Ax[idx_i], Bx[idx_j], accxx[idx_j*{{TM}} + idx_i]);
+            }
+        }
+    }
+
+    workgroupBarrier();
+    for (var idx_i: u32 = 0; idx_i < {{TM}}; idx_i++) {
+        for (var idx_j: u32 = 0; idx_j < {{TN}}; idx_j++) {
+            subgroupMatrixStore(&C, cBase + idx_i * 8u * {{N}} + 8u * idx_j, accxx[idx_j*{{TM}} + idx_i], false, {{N}});
+        }
+    }
+}
+)";
+
 /**
  * @brief No-Op shader with matmul bindings for performance testing
  */
@@ -683,26 +770,30 @@ Kernel selectMatmul(Context &ctx, int version,
                     const Bindings</* input, weights, output */ 3> &bindings,
                     size_t M, size_t K, size_t N, NumType numtype) {
   Kernel kernel;
+  CompilationInfo info;
   if (version == 1) {
     Shape wgSize = {256, 1, 1};
     Shape nWorkgroups = cdiv({M, N, 1}, {16, 16, 1});
     KernelCode matmul = createNoOp(kShaderNoOp, /*wgsize*/ wgSize);
     kernel = createKernel(ctx, matmul, bindings,
-                          /*nWorkgroups*/ nWorkgroups);
+                          /*nWorkgroups*/ nWorkgroups,
+                          NoParam{}, &info);
   } else if (version == 2) {
     Shape wgSize = {16, 16, 1};
     LOG(kDefLog, kInfo, "wgSize: %s", toString(wgSize).c_str());
     KernelCode matmul =
         createMatmul1(kShaderMatmul1, M, K, N, /*wgsize*/ wgSize, numtype);
     kernel = createKernel(ctx, matmul, bindings,
-                          /*nWorkgroups*/ cdiv({M, N, 1}, wgSize));
+                          /*nWorkgroups*/ cdiv({M, N, 1}, wgSize),
+                          NoParam{}, &info);
   } else if (version == 3) {
     static constexpr size_t tileSize = 16;
     KernelCode matmul = createMatmul2(kShaderMatmul2, M, K, N,
                                       /*wgSize*/ {tileSize * tileSize, 1, 1}, numtype);
     kernel =
         createKernel(ctx, matmul, bindings,
-                     /* nWorkgroups*/ cdiv({M, N, 1}, {tileSize, tileSize, 1}));
+                     /* nWorkgroups*/ cdiv({M, N, 1}, {tileSize, tileSize, 1}),
+                     NoParam{}, &info);
   } else if (version == 4 || version == 6) {
     static constexpr size_t BM = 64;
     static constexpr size_t BK = 4;
@@ -721,7 +812,8 @@ Kernel selectMatmul(Context &ctx, int version,
                     numtype,
                     /*Loop unrolling*/ version == 6 ? true: false);
     kernel = createKernel(ctx, matmul, bindings,
-                          /*nWorkgroups*/ nWorkgroups);
+                          /*nWorkgroups*/ nWorkgroups,
+                          NoParam{}, &info);
   } else if (version == 5 || version == 7) {
     static constexpr size_t BM = 64;
     static constexpr size_t BK = 8;
@@ -739,7 +831,8 @@ Kernel selectMatmul(Context &ctx, int version,
                     numtype,
                     /*Loop unrolling*/ version == 7 ? true: false);
     kernel = createKernel(ctx, matmul, bindings,
-                          /*nWorkgroups*/ nWorkgroups);
+                          /*nWorkgroups*/ nWorkgroups,
+                          NoParam{}, &info);
   } else if (version == 8 || version == 10) {
     static constexpr size_t BM = 64;
     static constexpr size_t BK = 8;
@@ -757,7 +850,8 @@ Kernel selectMatmul(Context &ctx, int version,
                     numtype,
                     /*Loop unrolling*/ true);
     kernel = createKernel(ctx, matmul, bindings,
-                          /*nWorkgroups*/ nWorkgroups);
+                          /*nWorkgroups*/ nWorkgroups,
+                          NoParam{}, &info);
   } else if (version == 9 || version == 11) {
     static constexpr size_t BM = 64;
     static constexpr size_t BK = 8;
@@ -774,8 +868,38 @@ Kernel selectMatmul(Context &ctx, int version,
                     /*wgSize*/ wgSize,
                     numtype);
     kernel = createKernel(ctx, matmul, bindings,
-                          /*nWorkgroups*/ nWorkgroups);
+                          /*nWorkgroups*/ nWorkgroups,
+                          NoParam{}, &info);
+  } else if (version == 12 || version == 13) {
+    // Subgroup matrix multiply (f16 for version 12, f32 for version 13)
+    static constexpr size_t TM = 4;
+    static constexpr size_t TN = 8;
+    static constexpr size_t LID = 2;
+    Shape wgSize = {32, LID, 1}; // LID subgroups (32 invocations each) per workgroup
+    Shape nWorkgroups = {cdiv(M, 8 * TM), cdiv(N, 8 * TN * LID), 1};
+    LOG(kDefLog, kInfo, "M: %zu, K: %zu, N: %zu", M, K, N);
+    LOG(kDefLog, kInfo, "wgSize: ( %s )", toString(wgSize).c_str());
+    LOG(kDefLog, kInfo, "nWorkgroups: ( %s )", toString(nWorkgroups).c_str());
+    KernelCode matmul = createMatmul12(kShaderSubgroupMatrixMultiply, M, K, N, TM, TN, LID, wgSize, numtype);
+    kernel = createKernel(ctx, matmul, bindings, nWorkgroups,
+                          NoParam{}, &info);
+  }
+
+  if (info.status != WGPUCompilationInfoRequestStatus_Success) {
+    LOG(kDefLog, kError, "Failed to compile shader");
+    for (size_t i = 0; i < info.messages.size(); i++) {
+      LOG(kDefLog, kError, "Line %llu, Pos %llu: %s", info.lineNums[i],
+          info.linePos[i], info.messages[i].c_str());
+    }
+    exit(1);
+  } else {
+    LOG(kDefLog, kInfo, "Shader compiled successfully");
+    for (size_t i = 0; i < info.messages.size(); i++) {
+      LOG(kDefLog, kInfo, "Line %llu, Pos %llu: %s", info.lineNums[i],
+          info.linePos[i], info.messages[i].c_str());
+    }
   }
+
   return kernel;
 }
 
@@ -791,41 +915,51 @@ void runTest(int version, size_t M, size_t K, size_t N,
     assert(numtype == kf16);
   }
 
-  // Allocate GPU buffers and copy data
-  WGPUDeviceDescriptor devDescriptor = {};
-  devDescriptor.requiredFeatureCount = 1;
-  devDescriptor.requiredFeatures = std::array{WGPUFeatureName_ShaderF16}.data();
-
-  Context ctx;
-  if (numtype == kf16) {
-    ctx = createContext(
-        {}, {},
-        /*device descriptor, enabling f16 in WGSL*/
-        {
-            .requiredFeatureCount = 1,
-            .requiredFeatures = std::array{WGPUFeatureName_ShaderF16}.data()
-        });
-    if (ctx.adapterStatus != WGPURequestAdapterStatus_Success) {
-      LOG(kDefLog, kError, "Failed to create adapter with f16 support, try running an f32 test instead (`export MATMUL_VERSION=9).");
-      exit(1);
+  static WGPUDawnTogglesDescriptor toggles = {};
+  toggles.chain.sType = WGPUSType_DawnTogglesDescriptor;
+  const char* enableList[] = {"allow_unsafe_apis"};
+  toggles.enabledToggles = enableList;
+  toggles.enabledToggleCount = 1;
+
+  static WGPUDeviceDescriptor devDesc = {};
+  devDesc.nextInChain = &toggles.chain;
+  devDesc.requiredFeatureCount = 3;
+  devDesc.requiredFeatures = std::array{
+      WGPUFeatureName_ShaderF16,
+      WGPUFeatureName_Subgroups,
+      WGPUFeatureName_ChromiumExperimentalSubgroupMatrix
+  }.data();
+  devDesc.uncapturedErrorCallbackInfo = WGPUUncapturedErrorCallbackInfo {
+    .callback = [](WGPUDevice const * device, WGPUErrorType type, WGPUStringView msg, void*, void*) {
+      LOG(kDefLog, kError, "[Uncaptured %d] %.*s\n", (int)type, (int)msg.length, msg.data);
     }
-    if (ctx.deviceStatus != WGPURequestDeviceStatus_Success) {
-      LOG(kDefLog, kError, "Failed to create device with f16 support, try running an f32 test instead. (`export MATMUL_VERSION=9)");
-      exit(1);
+  };
+  devDesc.deviceLostCallbackInfo = WGPUDeviceLostCallbackInfo {
+    .mode = WGPUCallbackMode_AllowSpontaneous,
+    .callback = [](WGPUDevice const * device, WGPUDeviceLostReason reason, WGPUStringView msg, void*, void*) {
+      LOG(kDefLog, kError, "[DeviceLost %d] %.*s\n", (int)reason, (int)msg.length, msg.data);
     }
-  }
+  };
 
-  if (numtype == kf32) {
-    ctx = createContext({}, {}, {});
-    if (ctx.adapterStatus != WGPURequestAdapterStatus_Success ||
-        ctx.deviceStatus != WGPURequestDeviceStatus_Success) {
-      LOG(kDefLog, kError, "Failed to create adapter or device");
-      // stop execution
-      exit(1);
-    } else {
-      LOG(kDefLog, kInfo, "Successfully created adapter and device");
+  static WGPULimits requiredLimits = WGPU_LIMITS_INIT;
+  devDesc.requiredLimits = &requiredLimits;
+  Context ctx = createContext({}, {}, devDesc);
+
+  WGPULoggingCallbackInfo logCb{
+    .callback = [](WGPULoggingType type, WGPUStringView msg, void*, void*) {
+      LOG(kDefLog, kError, "[WGPU %d] %.*s\n", (int)type, (int)msg.length, msg.data);
     }
-  }
+  };
+  wgpuDeviceSetLoggingCallback(ctx.device, logCb);
+
+  if (ctx.adapterStatus != WGPURequestAdapterStatus_Success ||
+      ctx.deviceStatus != WGPURequestDeviceStatus_Success) {
+    LOG(kDefLog, kError, "Failed to create adapter or device");
+    // stop execution
+    exit(1);
+  } else {
+    LOG(kDefLog, kInfo, "Successfully created adapter and device");
+  }
 
   Tensor input = createTensor(ctx, Shape{M, K}, numtype, inputPtr.get());
   Tensor weights = createTensor(ctx, Shape{N, K}, numtype, weightsPtr.get()); // column-major
@@ -859,7 +993,7 @@ void runTest(int version, size_t M, size_t K, size_t N,
   // Use microsecond for more accurate time measurement
   auto duration =
       std::chrono::duration_cast<std::chrono::microseconds>(end - start);
-  float gflops = 2 * M * N *
+  float gflops = 2.0f * M * N *
                  K / // factor of 2 for multiplication & accumulation
                  (static_cast<double>(duration.count()) / 1000000.0) /
                  1000000000.0 * static_cast<float>(nIter);
@@ -870,7 +1004,7 @@ void runTest(int version, size_t M, size_t K, size_t N,
       show<precision>(outputPtr.get(), M, N, "Output[0]").c_str());
 
   LOG(kDefLog, kInfo, "\n\n===================================================================="
-      "============\nExecution Time: (M = %d, K = %d, N = %d) x %d iterations "
+      "============\nExecution Time: (M = %zu, K = %zu, N = %zu) x %zu iterations "
       ":\n%.1f "
       "milliseconds / dispatch ~ %.2f "
       "GFLOPS\n================================================================"
@@ -913,13 +1047,16 @@ const std::string versionToStr(int version){
   case 9: return "f32: 2D blocktiling with loop unrolling, vectorization and transpose";
   case 10: return "f16: 2D blocktiling with loop unrolling and vectorization";
   case 11: return "f16: 2D blocktiling with loop unrolling, vectorization and transpose";
+  case 12: return "f16: Subgroup matrix multiply with transpose (default)";
+  case 13: return "f32: Subgroup matrix multiply with transpose";
   default: return "Not specified";
   }
 }
 
 int main() {
+  std::cout << "Starting matmul test..." << std::endl;
   char* version_str = getenv("MATMUL_VERSION");
-  int version = version_str == NULL ? 10 : atoi(version_str);
+  int version = version_str == NULL ? 12 : atoi(version_str);
   // 1 == f32: No-Op
   // 2 == f32: naive matmul
   // 3 == f32: tiling
@@ -931,8 +1068,10 @@ int main() {
   // 9 == f32: 2D blocktiling with loop unrolling, vectorization and transpose
   // 10 == f16: 2D blocktiling with loop unrolling and vectorization (default)
   // 11 == f16: 2D blocktiling with loop unrolling, vectorization and transpose
-  bool enableF16 = version == 10 || version ==11;
-  bool transposedInput = version == 9 || version == 11;
+  // 12 == f16: Subgroup matrix multiply with transpose (default)
+  // 13 == f32: Subgroup matrix multiply with transpose
+  bool enableF16 = version == 10 || version == 11 || version == 12;
+  bool transposedInput = version == 9 || version == 11 || version == 12 || version == 13;
   NumType numtype = enableF16 ? kf16 : kf32;
 
   size_t M, K, N; // Matrix dimensions
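
For orientation, below is a small standalone sketch of the dispatch geometry implied by the new version-12/13 path. The TM, TN, and LID values mirror the constants in selectMatmul above; the 4096x4096 output size is only an illustrative assumption, not a value taken from the commit.

// dispatch_geometry_sketch.cpp — standalone, mirrors the version-12/13 branch above
#include <cstddef>
#include <cstdio>

int main() {
  constexpr std::size_t TM = 4;   // 8x8 A fragments per subgroup (rows)
  constexpr std::size_t TN = 8;   // 8x8 B fragments per subgroup (columns)
  constexpr std::size_t LID = 2;  // subgroups per workgroup along y
  constexpr std::size_t M = 4096, N = 4096;  // illustrative output size (assumption)

  // Each workgroup covers an (8*TM) x (8*TN*LID) tile of C.
  constexpr std::size_t tileRows = 8 * TM;        // 32 rows of C per workgroup
  constexpr std::size_t tileCols = 8 * TN * LID;  // 128 columns of C per workgroup

  // Same arithmetic as cdiv(M, 8*TM) and cdiv(N, 8*TN*LID) in selectMatmul.
  const std::size_t wgX = (M + tileRows - 1) / tileRows;  // 128
  const std::size_t wgY = (N + tileCols - 1) / tileCols;  // 32
  std::printf("workgroup size = {32, %zu, 1}, nWorkgroups = {%zu, %zu, 1}\n", LID, wgX, wgY);
  return 0;
}

After this commit the subgroup-matrix path is the default: with MATMUL_VERSION unset, main() resolves to version 12 (f16), and setting MATMUL_VERSION=13 selects the f32 variant of the same kernel.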
