@@ -613,6 +613,93 @@ inline KernelCode createMatmulWithTranspose(const char *shaderTemplate, const si
   return {unrolledCode, workgroupSize, precision};
 }
 
+inline KernelCode createMatmul12(const char *shaderTemplate, const size_t M,
+                                 const size_t K, const size_t N,
+                                 const size_t TM, const size_t TN,
+                                 const size_t LID,
+                                 const Shape &workgroupSize = {256, 1, 1},
+                                 NumType precision = kf32) {
+  std::string codeString(shaderTemplate);
+  replaceAll(codeString, {{"{{precision}}", toString(precision)},
+                          {"{{M}}", toString(M)},
+                          {"{{K}}", toString(K)},
+                          {"{{N}}", toString(N)},
+                          {"{{TM}}", toString(TM)},
+                          {"{{TN}}", toString(TN)},
+                          {"{{LID}}", toString(LID)}});
+  return {loopUnrolling(codeString), workgroupSize, precision};
+}
+
+// ─────────────────────────────────────────────────────────────────────────────
+// Optimised WGSL matrix-multiply kernel using subgroupMatrixLoad/Store
+// and subgroupMatrixMultiplyAccumulate
+// ─────────────────────────────────────────────────────────────────────────────
+const char *kShaderSubgroupMatrixMultiply = R"(
+enable chromium_experimental_subgroup_matrix;
+diagnostic(off, chromium.subgroup_matrix_uniformity);
+
+@group(0) @binding(0) var<storage, read_write> A: array<{{precision}}>;
+@group(0) @binding(1) var<storage, read_write> B: array<{{precision}}>;
+@group(0) @binding(2) var<storage, read_write> C: array<{{precision}}>;
+
+@compute @workgroup_size({{workgroupSize}})
+fn main(@builtin(workgroup_id) wg: vec3<u32>,
+        @builtin(local_invocation_id) localID: vec3<u32>) {
+
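+  // Each workgroup covers an 8*{{TM}}-row band of C; each of its {{LID}} subgroups
+  // (selected by localID.y) covers an 8*{{TN}}-column band within it. A is addressed
+  // row-major with stride {{K}}; B and C are addressed row-major with stride {{N}}.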
+  let rowStart: u32 = wg.x * 8u * {{TM}};
+  let colStart: u32 = (wg.y * {{LID}} + localID.y) * 8u * {{TN}};
+
+  let baseA: u32 = rowStart * {{K}};
+  let baseB: u32 = colStart;
+  let cBase: u32 = rowStart * {{N}} + colStart;
+
+  var Ax: array<subgroup_matrix_left<{{precision}}, 8, 8>, {{TM}}>;
+  var Bx: array<subgroup_matrix_right<{{precision}}, 8, 8>, {{TN}}>;
+
+  // {{TM}} x {{TN}} accumulator tiles (8x8 each); accxx[i + j * {{TM}}] holds the
+  // tile at row block i, column block j.
+  var accxx: array<subgroup_matrix_result<{{precision}}, 8, 8>, {{TM}} * {{TN}}>;
+  for (var idx_i: u32 = 0; idx_i < {{TM}}; idx_i++) {
+    Ax[idx_i] = subgroup_matrix_left<{{precision}}, 8, 8>(0);
+  }
+
+  for (var idx_i: u32 = 0; idx_i < {{TN}}; idx_i++) {
+    Bx[idx_i] = subgroup_matrix_right<{{precision}}, 8, 8>(0);
+  }
+
+  for (var idx_i: u32 = 0; idx_i < {{TM}}; idx_i++) {
+    for (var idx_j: u32 = 0; idx_j < {{TN}}; idx_j++) {
+      accxx[idx_i + idx_j * {{TM}}] = subgroup_matrix_result<{{precision}}, 8, 8>(0);
+    }
+  }
+
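+  // Main loop: march over K in steps of 8. Each step loads {{TM}} left tiles of A
+  // and {{TN}} right tiles of B, then performs {{TM}} * {{TN}} multiply-accumulates.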
+  for (var k: u32 = 0u; k < {{K}}; k = k + 8u) {
+    workgroupBarrier();
+    for (var idx_i: u32 = 0; idx_i < {{TM}}; idx_i++) {
+      Ax[idx_i] = subgroupMatrixLoad<subgroup_matrix_left<{{precision}}, 8, 8>>(&A, baseA + k + 8u * {{K}} * idx_i, false, {{K}});
+    }
+
+    for (var idx_i: u32 = 0; idx_i < {{TN}}; idx_i++) {
+      Bx[idx_i] = subgroupMatrixLoad<subgroup_matrix_right<{{precision}}, 8, 8>>(&B, baseB + k * {{N}} + 8u * idx_i, false, {{N}});
+    }
+
+    for (var idx_j: u32 = 0; idx_j < {{TN}}; idx_j++) {
+      for (var idx_i: u32 = 0; idx_i < {{TM}}; idx_i++) {
+        accxx[idx_j * {{TM}} + idx_i] = subgroupMatrixMultiplyAccumulate(Ax[idx_i], Bx[idx_j], accxx[idx_j * {{TM}} + idx_i]);
+      }
+    }
+  }
+
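+  // Write the {{TM}} x {{TN}} accumulated tiles back to C.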
+  workgroupBarrier();
+  for (var idx_i: u32 = 0; idx_i < {{TM}}; idx_i++) {
+    for (var idx_j: u32 = 0; idx_j < {{TN}}; idx_j++) {
+      subgroupMatrixStore(&C, cBase + idx_i * 8u * {{N}} + 8u * idx_j, accxx[idx_j * {{TM}} + idx_i], false, {{N}});
+    }
+  }
+}
+)";
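+
+// Illustrative call (TM/TN/LID and wgSize match the version-12 path in
+// selectMatmul below; the matrix dimensions are placeholders):
+//   createMatmul12(kShaderSubgroupMatrixMultiply, /*M*/ 4096, /*K*/ 4096, /*N*/ 4096,
+//                  /*TM*/ 4, /*TN*/ 8, /*LID*/ 2, /*wgSize*/ {32, 2, 1}, kf16);
+// With TM = 4, TN = 8, LID = 2, each workgroup produces an (8*TM) x (8*TN*LID)
+// = 32 x 128 tile of C.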
+
 /**
  * @brief No-Op shader with matmul bindings for performance testing
  */
@@ -683,26 +770,30 @@ Kernel selectMatmul(Context &ctx, int version,
                     const Bindings</* input, weights, output */ 3> &bindings,
                     size_t M, size_t K, size_t N, NumType numtype) {
   Kernel kernel;
+  CompilationInfo info;
   if (version == 1) {
     Shape wgSize = {256, 1, 1};
     Shape nWorkgroups = cdiv({M, N, 1}, {16, 16, 1});
     KernelCode matmul = createNoOp(kShaderNoOp, /*wgsize*/ wgSize);
     kernel = createKernel(ctx, matmul, bindings,
-                          /*nWorkgroups*/ nWorkgroups);
+                          /*nWorkgroups*/ nWorkgroups,
+                          NoParam{}, &info);
   } else if (version == 2) {
     Shape wgSize = {16, 16, 1};
     LOG(kDefLog, kInfo, "wgSize: %s", toString(wgSize).c_str());
     KernelCode matmul =
         createMatmul1(kShaderMatmul1, M, K, N, /*wgsize*/ wgSize, numtype);
     kernel = createKernel(ctx, matmul, bindings,
-                          /*nWorkgroups*/ cdiv({M, N, 1}, wgSize));
+                          /*nWorkgroups*/ cdiv({M, N, 1}, wgSize),
+                          NoParam{}, &info);
   } else if (version == 3) {
     static constexpr size_t tileSize = 16;
     KernelCode matmul = createMatmul2(kShaderMatmul2, M, K, N,
                                       /*wgSize*/ {tileSize * tileSize, 1, 1}, numtype);
     kernel =
         createKernel(ctx, matmul, bindings,
-                     /*nWorkgroups*/ cdiv({M, N, 1}, {tileSize, tileSize, 1}));
+                     /*nWorkgroups*/ cdiv({M, N, 1}, {tileSize, tileSize, 1}),
+                     NoParam{}, &info);
   } else if (version == 4 || version == 6) {
     static constexpr size_t BM = 64;
     static constexpr size_t BK = 4;
@@ -721,7 +812,8 @@ Kernel selectMatmul(Context &ctx, int version,
                               numtype,
                               /*Loop unrolling*/ version == 6 ? true : false);
     kernel = createKernel(ctx, matmul, bindings,
-                          /*nWorkgroups*/ nWorkgroups);
+                          /*nWorkgroups*/ nWorkgroups,
+                          NoParam{}, &info);
   } else if (version == 5 || version == 7) {
     static constexpr size_t BM = 64;
     static constexpr size_t BK = 8;
@@ -739,7 +831,8 @@ Kernel selectMatmul(Context &ctx, int version,
                               numtype,
                               /*Loop unrolling*/ version == 7 ? true : false);
     kernel = createKernel(ctx, matmul, bindings,
-                          /*nWorkgroups*/ nWorkgroups);
+                          /*nWorkgroups*/ nWorkgroups,
+                          NoParam{}, &info);
   } else if (version == 8 || version == 10) {
     static constexpr size_t BM = 64;
     static constexpr size_t BK = 8;
@@ -757,7 +850,8 @@ Kernel selectMatmul(Context &ctx, int version,
                               numtype,
                               /*Loop unrolling*/ true);
     kernel = createKernel(ctx, matmul, bindings,
-                          /*nWorkgroups*/ nWorkgroups);
+                          /*nWorkgroups*/ nWorkgroups,
+                          NoParam{}, &info);
   } else if (version == 9 || version == 11) {
     static constexpr size_t BM = 64;
     static constexpr size_t BK = 8;
@@ -774,8 +868,38 @@ Kernel selectMatmul(Context &ctx, int version,
                                            /*wgSize*/ wgSize,
                                            numtype);
     kernel = createKernel(ctx, matmul, bindings,
-                          /*nWorkgroups*/ nWorkgroups);
+                          /*nWorkgroups*/ nWorkgroups,
+                          NoParam{}, &info);
+  } else if (version == 12 || version == 13) {
+    // Subgroup matrix multiply (f16 for version 12, f32 for version 13)
+    static constexpr size_t TM = 4;
+    static constexpr size_t TN = 8;
+    static constexpr size_t LID = 2;
+    Shape wgSize = {32, LID, 1}; // LID subgroups per workgroup (assumes 32-wide subgroups)
+    Shape nWorkgroups = {cdiv(M, 8 * TM), cdiv(N, 8 * TN * LID), 1};
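+    // Each workgroup covers an (8 * TM) x (8 * TN * LID) tile of the output,
+    // hence the grid dimensions above.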
+    LOG(kDefLog, kInfo, "M: %zu, K: %zu, N: %zu", M, K, N);
+    LOG(kDefLog, kInfo, "wgSize: ( %s )", toString(wgSize).c_str());
+    LOG(kDefLog, kInfo, "nWorkgroups: ( %s )", toString(nWorkgroups).c_str());
+    KernelCode matmul = createMatmul12(kShaderSubgroupMatrixMultiply, M, K, N, TM, TN, LID, wgSize, numtype);
+    kernel = createKernel(ctx, matmul, bindings, nWorkgroups,
+                          NoParam{}, &info);
+  }
+
+  if (info.status != WGPUCompilationInfoRequestStatus_Success) {
+    LOG(kDefLog, kError, "Failed to compile shader");
+    for (size_t i = 0; i < info.messages.size(); i++) {
+      LOG(kDefLog, kError, "Line %llu, Pos %llu: %s", info.lineNums[i],
+          info.linePos[i], info.messages[i].c_str());
+    }
+    exit(1);
+  } else {
+    LOG(kDefLog, kInfo, "Shader compiled successfully");
+    for (size_t i = 0; i < info.messages.size(); i++) {
+      LOG(kDefLog, kInfo, "Line %llu, Pos %llu: %s", info.lineNums[i],
+          info.linePos[i], info.messages[i].c_str());
+    }
   }
+
   return kernel;
 }
 
@@ -791,41 +915,51 @@ void runTest(int version, size_t M, size_t K, size_t N,
     assert(numtype == kf16);
   }
 
-  // Allocate GPU buffers and copy data
-  WGPUDeviceDescriptor devDescriptor = {};
-  devDescriptor.requiredFeatureCount = 1;
-  devDescriptor.requiredFeatures = std::array{WGPUFeatureName_ShaderF16}.data();
-
-  Context ctx;
-  if (numtype == kf16) {
-    ctx = createContext(
-        {}, {},
-        /* device descriptor, enabling f16 in WGSL */
-        {
-            .requiredFeatureCount = 1,
-            .requiredFeatures = std::array{WGPUFeatureName_ShaderF16}.data()
-        });
-    if (ctx.adapterStatus != WGPURequestAdapterStatus_Success) {
-      LOG(kDefLog, kError, "Failed to create adapter with f16 support, try running an f32 test instead (`export MATMUL_VERSION=9).");
-      exit(1);
+  static WGPUDawnTogglesDescriptor toggles = {};
+  toggles.chain.sType = WGPUSType_DawnTogglesDescriptor;
+  const char *enableList[] = {"allow_unsafe_apis"};
+  toggles.enabledToggles = enableList;
+  toggles.enabledToggleCount = 1;
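+  // "allow_unsafe_apis" is the Dawn toggle that exposes experimental features,
+  // which the subgroup-matrix kernels (versions 12 and 13) rely on here.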
+
+  static WGPUDeviceDescriptor devDesc = {};
+  devDesc.nextInChain = &toggles.chain;
+  // Keep the feature list in storage that outlives this statement so that
+  // requiredFeatures does not dangle.
+  static const std::array features{
+      WGPUFeatureName_ShaderF16,
+      WGPUFeatureName_Subgroups,
+      WGPUFeatureName_ChromiumExperimentalSubgroupMatrix};
+  devDesc.requiredFeatureCount = features.size();
+  devDesc.requiredFeatures = features.data();
+  devDesc.uncapturedErrorCallbackInfo = WGPUUncapturedErrorCallbackInfo{
+      .callback = [](WGPUDevice const *device, WGPUErrorType type, WGPUStringView msg, void *, void *) {
+        LOG(kDefLog, kError, "[Uncaptured %d] %.*s\n", (int)type, (int)msg.length, msg.data);
       }
-    if (ctx.deviceStatus != WGPURequestDeviceStatus_Success) {
-      LOG(kDefLog, kError, "Failed to create device with f16 support, try running an f32 test instead. (`export MATMUL_VERSION=9)");
-      exit(1);
+  };
+  devDesc.deviceLostCallbackInfo = WGPUDeviceLostCallbackInfo{
+      .mode = WGPUCallbackMode_AllowSpontaneous,
+      .callback = [](WGPUDevice const *device, WGPUDeviceLostReason reason, WGPUStringView msg, void *, void *) {
+        LOG(kDefLog, kError, "[DeviceLost %d] %.*s\n", (int)reason, (int)msg.length, msg.data);
     }
-  }
+  };
 
-  if (numtype == kf32) {
-    ctx = createContext({}, {}, {});
-    if (ctx.adapterStatus != WGPURequestAdapterStatus_Success ||
-        ctx.deviceStatus != WGPURequestDeviceStatus_Success) {
-      LOG(kDefLog, kError, "Failed to create adapter or device");
-      // stop execution
-      exit(1);
-    } else {
-      LOG(kDefLog, kInfo, "Successfully created adapter and device");
+  static WGPULimits requiredLimits = WGPU_LIMITS_INIT;
+  devDesc.requiredLimits = &requiredLimits;
+  Context ctx = createContext({}, {}, devDesc);
+
+  WGPULoggingCallbackInfo logCb{
+      .callback = [](WGPULoggingType type, WGPUStringView msg, void *, void *) {
+        LOG(kDefLog, kError, "[WGPU %d] %.*s\n", (int)type, (int)msg.length, msg.data);
     }
-  }
+  };
+  wgpuDeviceSetLoggingCallback(ctx.device, logCb);
+
+  if (ctx.adapterStatus != WGPURequestAdapterStatus_Success ||
+      ctx.deviceStatus != WGPURequestDeviceStatus_Success) {
+    LOG(kDefLog, kError, "Failed to create adapter or device");
+    // stop execution
+    exit(1);
+  } else {
+    LOG(kDefLog, kInfo, "Successfully created adapter and device");
+  }
 
   Tensor input = createTensor(ctx, Shape{M, K}, numtype, inputPtr.get());
   Tensor weights = createTensor(ctx, Shape{N, K}, numtype, weightsPtr.get()); // column-major
@@ -859,7 +993,7 @@ void runTest(int version, size_t M, size_t K, size_t N,
   // Use microsecond for more accurate time measurement
   auto duration =
       std::chrono::duration_cast<std::chrono::microseconds>(end - start);
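+  // GFLOPS = (2 * M * N * K flops per matmul) * nIter dispatches,
+  // divided by elapsed seconds and by 1e9.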
-  float gflops = 2 * M * N *
+  float gflops = 2.0f * M * N *
                  K / // factor of 2 for multiplication & accumulation
                  (static_cast<double>(duration.count()) / 1000000.0) /
                  1000000000.0 * static_cast<float>(nIter);
@@ -870,7 +1004,7 @@ void runTest(int version, size_t M, size_t K, size_t N,
       show<precision>(outputPtr.get(), M, N, "Output[0]").c_str());
 
   LOG(kDefLog, kInfo, "\n\n===================================================================="
-      "============\nExecution Time: (M = %d, K = %d, N = %d) x %d iterations "
+      "============\nExecution Time: (M = %zu, K = %zu, N = %zu) x %zu iterations "
       ":\n%.1f "
       "milliseconds / dispatch ~ %.2f "
      "GFLOPS\n================================================================"
@@ -913,13 +1047,16 @@ const std::string versionToStr(int version){
   case 9: return "f32: 2D blocktiling with loop unrolling, vectorization and transpose";
   case 10: return "f16: 2D blocktiling with loop unrolling and vectorization";
   case 11: return "f16: 2D blocktiling with loop unrolling, vectorization and transpose";
+  case 12: return "f16: Subgroup matrix multiply with transpose (default)";
+  case 13: return "f32: Subgroup matrix multiply with transpose";
   default: return "Not specified";
   }
 }
 
 int main() {
+  std::cout << "Starting matmul test..." << std::endl;
   char *version_str = getenv("MATMUL_VERSION");
-  int version = version_str == NULL ? 10 : atoi(version_str);
+  int version = version_str == NULL ? 12 : atoi(version_str);
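+  // The kernel is selected at runtime, e.g. `export MATMUL_VERSION=13` picks the
+  // f32 subgroup-matrix path; unset, it defaults to version 12.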
   // 1 == f32: No-Op
   // 2 == f32: naive matmul
   // 3 == f32: tiling
@@ -931,8 +1068,10 @@ int main() {
   // 9 == f32: 2D blocktiling with loop unrolling, vectorization and transpose
   // 10 == f16: 2D blocktiling with loop unrolling and vectorization (default)
   // 11 == f16: 2D blocktiling with loop unrolling, vectorization and transpose
-  bool enableF16 = version == 10 || version == 11;
-  bool transposedInput = version == 9 || version == 11;
+  // 12 == f16: Subgroup matrix multiply with transpose (default)
+  // 13 == f32: Subgroup matrix multiply with transpose
+  bool enableF16 = version == 10 || version == 11 || version == 12;
+  bool transposedInput = version == 9 || version == 11 || version == 12 || version == 13;
   NumType numtype = enableF16 ? kf16 : kf32;
 
   size_t M, K, N; // Matrix dimensions