Commit d8d618a

Make the Context lifetime more robust: don't rely on RVO, which seems to fail if createContext is called from, e.g., conditional branches. Remove the webgpu-from-scratch tutorial to avoid having to maintain and update the implementation.
1 parent 73f438a · commit d8d618a
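
The lifetime issue described above shows up when a Context is default-constructed and then assigned from createContext inside a conditional branch, which is exactly the pattern the matmul change below switches to. A minimal sketch of the two patterns, for illustration only (`useF16` and `f16Descriptor` are hypothetical stand-ins, not code from this commit):

{
  // Pattern A: initialization from the return value. Copy elision (RVO)
  // applies, so the Context is constructed in place and never moved.
  Context ctx = createContext({}, {}, {});
}
{
  // Pattern B: assignment inside a branch. No elision applies here; the
  // temporary returned by createContext is move-assigned into ctx and then
  // destroyed, so Context's move operations must transfer ownership of the
  // underlying WebGPU handles cleanly instead of relying on RVO.
  Context ctx;
  if (useF16) {                                  // hypothetical flag
    ctx = createContext({}, {}, f16Descriptor);  // hypothetical descriptor
  } else {
    ctx = createContext({}, {}, {});
  }
}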

File tree: 7 files changed, +158 −1810 lines

examples/hello_world/run.cpp

Lines changed: 3 additions & 5 deletions
@@ -3,9 +3,7 @@
 #include <cstdio>
 #include <future>
 
-using namespace gpu; // createContext, createTensor, createKernel,
-                     // createShader, dispatchKernel, wait, toCPU
-                     // Tensor, Kernel, Context, Shape, kf32
+using namespace gpu;
 
 static const char *kGelu = R"(
 const GELU_SCALING_FACTOR: f32 = 0.7978845608028654; // sqrt(2.0 / PI)
@@ -29,6 +27,7 @@ int main(int argc, char **argv) {
   printf("\nHello gpu.cpp!\n");
   printf("--------------\n\n");
 
+  // std::unique_ptr<Context> ctx = createContext();
   Context ctx = createContext();
   static constexpr size_t N = 10000;
   std::array<float, N> inputArr, outputArr;
@@ -41,7 +40,7 @@ int main(int argc, char **argv) {
   std::future<void> future = promise.get_future();
   Kernel op = createKernel(ctx, {kGelu, 256, kf32},
                            Bindings{input, output},
-                           /* nWorkgroups */ {cdiv(N, 256), 1, 1});
+                           {cdiv(N, 256), 1, 1});
   dispatchKernel(ctx, op, promise);
   wait(ctx, future);
   toCPU(ctx, output, outputArr.data(), sizeof(outputArr));
@@ -50,5 +49,4 @@ int main(int argc, char **argv) {
   }
   printf(" ...\n\n");
   printf("Computed %zu values of GELU(x)\n\n", N);
-  return 0;
 }

examples/matmul/run.cpp

Lines changed: 33 additions & 12 deletions
@@ -792,13 +792,40 @@ void runTest(int version, size_t M, size_t K, size_t N,
 }
 
 // Allocate GPU buffers and copy data
-Context ctx = createContext(
-    {}, {},
-    /*device descriptor, enabling f16 in WGSL*/
-    {
+WGPUDeviceDescriptor devDescriptor = {};
+devDescriptor.requiredFeatureCount = 1;
+devDescriptor.requiredFeatures = std::array{WGPUFeatureName_ShaderF16}.data();
+
+Context ctx;
+if (numtype == kf16) {
+  ctx = createContext(
+      {}, {},
+      /*device descriptor, enabling f16 in WGSL*/
+      {
        .requiredFeatureCount = 1,
-       .requiredFeatures = std::array{WGPUFeatureName_ShaderF16}.data(),
-    });
+       .requiredFeatures = std::array{WGPUFeatureName_ShaderF16}.data()
+      });
+  if (ctx.adapterStatus != WGPURequestAdapterStatus_Success) {
+    LOG(kDefLog, kError, "Failed to create adapter with f16 support, try running an f32 test instead (`export MATMUL_VERSION=9`).");
+    exit(1);
+  }
+  if (ctx.deviceStatus != WGPURequestDeviceStatus_Success) {
+    LOG(kDefLog, kError, "Failed to create device with f16 support, try running an f32 test instead (`export MATMUL_VERSION=9`).");
+    exit(1);
+  }
+}
+
+if (numtype == kf32) {
+  ctx = createContext({}, {}, {});
+  if (ctx.adapterStatus != WGPURequestAdapterStatus_Success ||
+      ctx.deviceStatus != WGPURequestDeviceStatus_Success) {
+    LOG(kDefLog, kError, "Failed to create adapter or device");
+    // stop execution
+    exit(1);
+  } else {
+    LOG(kDefLog, kInfo, "Successfully created adapter and device");
+  }
+}
 
 Tensor input = createTensor(ctx, Shape{M, K}, numtype, inputPtr.get());
 Tensor weights = createTensor(ctx, Shape{N, K}, numtype, weightsPtr.get()); // column-major
@@ -810,8 +837,6 @@ void runTest(int version, size_t M, size_t K, size_t N,
 #endif
 
 // Initialize Kernel and bind GPU buffers
-
-
 // pre-allocate for async dispatch
 std::array<std::promise<void>, nIter> promises;
 std::array<std::future<void>, nIter> futures;
@@ -823,10 +848,6 @@ void runTest(int version, size_t M, size_t K, size_t N,
   kernels[i] = selectMatmul(ctx, version, {input, weights, outputs[i]}, M, K, N, numtype);
 }
 
-#ifndef METAL_PROFILER
-printf("[ Press enter to start tests ... ]\n");
-getchar();
-#endif
 LOG(kDefLog, kInfo, "Dispatching Kernel version %d: %s, %d iterations ...",
     version, versionToStr(version).c_str(), nIter);

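The two numtype branches above bail out with exit(1) when the adapter or device request fails. The same adapterStatus/deviceStatus fields could instead drive a runtime downgrade from f16 to f32; a hypothetical sketch, not part of the commit (the names `features` and `desc` are introduced here, and the enum type `NumType` for `kf16`/`kf32` is assumed from gpu.cpp's API):

// Try an f16-capable device first; fall back to a plain f32 context if the
// adapter or device request fails, instead of exiting.
static const std::array features{WGPUFeatureName_ShaderF16};
WGPUDeviceDescriptor desc = {};
desc.requiredFeatureCount = 1;
desc.requiredFeatures = features.data();  // named array: no dangling pointer

NumType numtype = kf16;
Context ctx = createContext({}, {}, desc);
if (ctx.adapterStatus != WGPURequestAdapterStatus_Success ||
    ctx.deviceStatus != WGPURequestDeviceStatus_Success) {
  numtype = kf32;                   // downgrade kernel precision
  ctx = createContext({}, {}, {});  // default device, no required features
}

Note that reassigning ctx in the fallback relies on exactly the move-assignment robustness this commit adds.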
examples/webgpu_from_scratch/CMakeLists.txt

Lines changed: 0 additions & 21 deletions
This file was deleted.

examples/webgpu_from_scratch/Makefile

Lines changed: 0 additions & 8 deletions
This file was deleted.
