@@ -792,13 +792,40 @@ void runTest(int version, size_t M, size_t K, size_t N,
792
792
}
793
793
794
794
// Allocate GPU buffers and copy data
795
- Context ctx = createContext (
796
- {}, {},
797
- /* device descriptor, enabling f16 in WGSL*/
798
- {
795
+ WGPUDeviceDescriptor devDescriptor = {};
796
+ devDescriptor.requiredFeatureCount = 1 ;
797
+ devDescriptor.requiredFeatures = std::array{WGPUFeatureName_ShaderF16}.data ();
798
+
799
+ Context ctx;
800
+ if (numtype == kf16) {
801
+ ctx = createContext (
802
+ {}, {},
803
+ /* device descriptor, enabling f16 in WGSL*/
804
+ {
799
805
.requiredFeatureCount = 1 ,
800
- .requiredFeatures = std::array{WGPUFeatureName_ShaderF16}.data (),
801
- });
806
+ .requiredFeatures = std::array{WGPUFeatureName_ShaderF16}.data ()
807
+ });
808
+ if (ctx.adapterStatus != WGPURequestAdapterStatus_Success) {
809
+ LOG (kDefLog , kError , " Failed to create adapter with f16 support, try running an f32 test instead (`export MATMUL_VERSION=9)." );
810
+ exit (1 );
811
+ }
812
+ if (ctx.deviceStatus != WGPURequestDeviceStatus_Success) {
813
+ LOG (kDefLog , kError , " Failed to create device with f16 support, try running an f32 test instead. (`export MATMUL_VERSION=9)" );
814
+ exit (1 );
815
+ }
816
+ }
817
+
818
+ if (numtype == kf32) {
819
+ ctx = createContext ({}, {}, {});
820
+ if (ctx.adapterStatus != WGPURequestAdapterStatus_Success ||
821
+ ctx.deviceStatus != WGPURequestDeviceStatus_Success) {
822
+ LOG (kDefLog , kError , " Failed to create adapter or device" );
823
+ // stop execution
824
+ exit (1 );
825
+ } else {
826
+ LOG (kDefLog , kInfo , " Successfully created adapter and device" );
827
+ }
828
+ }
802
829
803
830
Tensor input = createTensor (ctx, Shape{M, K}, numtype, inputPtr.get ());
804
831
Tensor weights = createTensor (ctx, Shape{N, K}, numtype, weightsPtr.get ()); // column-major
@@ -810,8 +837,6 @@ void runTest(int version, size_t M, size_t K, size_t N,
810
837
#endif
811
838
812
839
// Initialize Kernel and bind GPU buffers
813
-
814
-
815
840
// pre-allocate for async dispatch
816
841
std::array<std::promise<void >, nIter> promises;
817
842
std::array<std::future<void >, nIter> futures;
@@ -823,10 +848,6 @@ void runTest(int version, size_t M, size_t K, size_t N,
823
848
kernels[i] = selectMatmul (ctx, version, {input, weights, outputs[i]}, M, K, N, numtype);
824
849
}
825
850
826
- #ifndef METAL_PROFILER
827
- printf (" [ Press enter to start tests ... ]\n " );
828
- getchar ();
829
- #endif
830
851
LOG (kDefLog , kInfo , " Dispatching Kernel version %d: %s, %d iterations ..." ,
831
852
version, versionToStr (version).c_str (), nIter);
832
853
0 commit comments