13
13
#include " experimental/wgsl.h" // loopUnrolling
14
14
#include " numeric_types/half.hpp"
15
15
16
+ #ifdef METAL_PROFILER
17
+ #include " experimental/profiler/metal.hpp"
18
+ #endif
19
+
16
20
using namespace gpu ;
17
21
18
22
const std::string versionToStr (int version);
@@ -799,7 +803,11 @@ void runTest(int version, size_t M, size_t K, size_t N,
799
803
Tensor input = createTensor (ctx, Shape{M, K}, numtype, inputPtr.get ());
800
804
Tensor weights = createTensor (ctx, Shape{N, K}, numtype, weightsPtr.get ()); // column-major
801
805
806
+ #ifdef METAL_PROFILER
807
+ constexpr size_t nIter = 1 ;
808
+ #else
802
809
constexpr size_t nIter = 30 ;
810
+ #endif
803
811
804
812
// Initialize Kernel and bind GPU buffers
805
813
@@ -815,8 +823,10 @@ void runTest(int version, size_t M, size_t K, size_t N,
815
823
kernels[i] = selectMatmul (ctx, version, {input, weights, outputs[i]}, M, K, N, numtype);
816
824
}
817
825
826
+ #ifndef METAL_PROFILER
818
827
printf (" [ Press enter to start tests ... ]\n " );
819
828
getchar ();
829
+ #endif
820
830
LOG (kDefLog , kInfo , " Dispatching Kernel version %d: %s, %d iterations ..." ,
821
831
version, versionToStr (version).c_str (), nIter);
822
832
@@ -930,11 +940,17 @@ int main() {
930
940
N = 2 * 4096 ;
931
941
}
932
942
943
+ #ifdef METAL_PROFILER
944
+ startCapture ();
945
+ #endif
933
946
if (enableF16) {
934
947
runTestWithCheck<half>(version, M, K, N, transposedInput, kTestSize , numtype);
935
948
} else {
936
949
runTestWithCheck<float >(version, M, K, N, transposedInput, kTestSize , numtype);
937
950
}
951
+ #ifdef METAL_PROFILER
952
+ stopCapture ();
953
+ #endif
938
954
939
955
LOG (kDefLog , kInfo , " Done." );
940
956
return 0 ;
0 commit comments