Skip to content

Commit 8570b5c

Browse files
committed
remove deprecated deviceLostCallbackInfo, add array constructor for Bindings, improve error handling for wasm gpu puzzles
1 parent 73c07af commit 8570b5c

File tree

5 files changed

+119
-81
lines changed

5 files changed

+119
-81
lines changed

experimental/fasthtml/gpu_puzzles/Makefile

+9-5
Original file line numberDiff line numberDiff line change
@@ -3,26 +3,30 @@ COMMON_FLAGS=-std=c++17 -s USE_WEBGPU=1 -s ASYNCIFY=1 -I$(GPUCPP)
33
# COMMON_FLAGS=-std=c++17 -s USE_WEBGPU=1 -I$(GPUCPP)
44
# Note - no spaces after comma
55
# enable exceptions to recover from WGSL failure
6-
JS_FLAGS=-s EXPORTED_RUNTIME_METHODS=['UTF8ToString','setValue','addFunction'] -s EXPORTED_FUNCTIONS=['_malloc','_free','_executeKernel','_runCheck'] -s DISABLE_EXCEPTION_CATCHING=0
7-
WASM_FLAGS=-s STANDALONE_WASM -s ERROR_ON_UNDEFINED_SYMBOLS=0 -s EXPORTED_FUNCTIONS=['_executeKernel','_runCheck'] -s EXPORTED_RUNTIME_METHODS=['ccall','cwrap'] -DSTANDALONE_WASM
6+
JS_FLAGS=-s EXPORTED_RUNTIME_METHODS=['UTF8ToString','setValue','addFunction','customPrint'] -s EXPORTED_FUNCTIONS=['_malloc','_free','_evaluate'] -s DISABLE_EXCEPTION_CATCHING=0
7+
WASM_FLAGS=-s STANDALONE_WASM -s ERROR_ON_UNDEFINED_SYMBOLS=0 -s EXPORTED_FUNCTIONS=['_evaluate'] -s EXPORTED_RUNTIME_METHODS=['ccall','cwrap'] -DSTANDALONE_WASM
88
MODULARIZE_FLAGS=-s EXPORT_NAME='createModule' -s MODULARIZE=1 --bind
9-
NO_MODULARIZE_FLAGS=-s EXPORTED_FUNCTIONS=['_executeKernel','_runCheck'] -s EXPORTED_RUNTIME_METHODS=['ccall','cwrap'] --bind
9+
NO_MODULARIZE_FLAGS=-s EXPORTED_FUNCTIONS=['_evaluate'] -s EXPORTED_RUNTIME_METHODS=['ccall','cwrap'] --bind
1010

11-
.PHONY: default cmake check-emsdk browser clean server
11+
.PHONY: default cmake check-emsdk browser clean server debug
1212

1313
default: server
1414

1515
build/run.js: check-emsdk run.cpp
1616
em++ run.cpp -o build/run.js \
1717
$(COMMON_FLAGS) $(JS_FLAGS) $(MODULARIZE_FLAGS)
1818

19+
debug: check-emsdk run.cpp
20+
em++ -g -gsource-map run.cpp -o build/run.js \
21+
$(COMMON_FLAGS) $(JS_FLAGS) $(MODULARIZE_FLAGS)
22+
1923
build/run.wasm: check-emsdk run.cpp
2024
em++ run.cpp -o build/run.wasm \
2125
$(COMMON_FLAGS) $(WASM_FLAGS)
2226

2327
watch:
2428
@echo "Watching for changes..."
25-
ls run.cpp | entr -c make build/run.js
29+
ls run.cpp | entr -c make debug
2630

2731
server: build/run.js
2832
python3 run.py

experimental/fasthtml/gpu_puzzles/client.js

+32-10
Original file line numberDiff line numberDiff line change
@@ -209,14 +209,13 @@ function updateDispatchParams() {
209209
}
210210

211211
async function updateEditor() {
212-
213212
function waitForDispatchReady() {
214213
return new Promise((resolve) => {
215214
function checkReady() {
216215
if (AppState.isDispatchReady) {
217216
resolve();
218217
} else {
219-
console.log("Waiting...");
218+
console.log("Waiting for dispatch to be ready");
220219
setTimeout(checkReady, 100); // Check every 100ms
221220
}
222221
}
@@ -228,28 +227,51 @@ async function updateEditor() {
228227
createModule().then((Module) => {
229228
console.log("updateEditor() - Module ready");
230229
});
231-
if (AppState.module && AppState.module.runCheck) {
230+
if (AppState.module) {
232231
if (!AppState.isDispatchReady) {
233-
console.log("Waiting for dispatch to be ready");
234232
await waitForDispatchReady();
235233
}
236-
237234
console.log("Executing kernel");
238235
AppState.terminal.clear();
239236
console.log("Code:\n", AppState.preamble + AppState.editor.getValue());
240237
AppState.isDispatchReady = false;
241-
AppState.module
242-
.runCheck(
238+
try {
239+
promise = AppState.module.evaluate(
243240
AppState.preamble + AppState.editor.getValue(),
244241
AppState.wgSize,
245242
AppState.gridSize,
246243
)
244+
.catch((error) => {
245+
console.error("execution failed", error);
246+
AppState.isDispatchReady = true;
247+
console.log("dispatch ready");
248+
render();
249+
})
247250
.then((result) => {
248251
console.log("check:", result);
249252
AppState.checkAnswer = result;
250253
AppState.isDispatchReady = true;
254+
console.log("dispatch ready");
251255
render();
252-
});
256+
})
257+
.finally(() => {
258+
console.log("finally");
259+
AppState.isDispatchReady = true;
260+
console.log("dispatch ready");
261+
})
262+
;
263+
} catch (error) {
264+
console.error("execution failed 2", error);
265+
AppState.isDispatchReady = true;
266+
console.log("dispatch ready");
267+
}
268+
if (promise) {
269+
await promise;
270+
} else {
271+
console.error("did not get promise");
272+
AppState.isDispatchReady = true;
273+
console.log("dispatch ready");
274+
}
253275
} else {
254276
console.log("updateEditor() - Module not ready");
255277
}
@@ -258,10 +280,10 @@ async function updateEditor() {
258280
function update(event) {
259281
console.log("Updating");
260282
if ((event.type === "selectPuzzle") & (event.value === "prev")) {
261-
AppState.puzzleIndex = (AppState.puzzleIndex - 1);
283+
AppState.puzzleIndex = AppState.puzzleIndex - 1;
262284
}
263285
if ((event.type === "selectPuzzle") & (event.value === "next")) {
264-
AppState.puzzleIndex = (AppState.puzzleIndex + 1);
286+
AppState.puzzleIndex = AppState.puzzleIndex + 1;
265287
}
266288
if (AppState.puzzleIndex < 0) {
267289
AppState.puzzleIndex = PuzzleSpec.length - 1;

experimental/fasthtml/gpu_puzzles/evaluator.h

+5-5
Original file line numberDiff line numberDiff line change
@@ -127,21 +127,21 @@ std::vector<float> runPuzzle2(Context &ctx, const TestCase &testCase,
127127

128128
Tensor a = createTensor(ctx, {N}, kf32, aVec.data());
129129
Tensor b = createTensor(ctx, {N}, kf32, bVec.data());
130-
Tensor output = createTensor(ctx, {N}, kf32);
130+
Tensor outputTensor = createTensor(ctx, {N}, kf32);
131131

132-
Kernel op = createKernel(ctx, {kernelString, N}, Bindings{a, b, output},
132+
Kernel op = createKernel(ctx, {kernelString, N}, Bindings{a, b, outputTensor},
133133
testCase.gridSize);
134134

135135
std::promise<void> promise;
136136
std::future<void> future = promise.get_future();
137137

138138
dispatchKernel(ctx, op, promise);
139139

140-
std::vector<float> outputArr(N);
140+
std::vector<float> outputVec(N);
141141
wait(ctx, future);
142-
toCPU(ctx, output, outputArr.data(), outputArr.size() * sizeof(float));
142+
toCPU(ctx, outputTensor, outputVec.data(), outputVec.size() * sizeof(float));
143143

144-
return outputArr;
144+
return outputVec;
145145
}
146146

147147
// Function to initialize the test cases

experimental/fasthtml/gpu_puzzles/run.cpp

+62-45
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111

1212
using namespace gpu;
1313

14+
constexpr size_t kN = 100;
15+
1416
EM_JS(void, js_print, (const char *str), {
1517
if (typeof window != 'undefined' && window.customPrint) {
1618
window.customPrint(UTF8ToString(str));
@@ -20,60 +22,76 @@ EM_JS(void, js_print, (const char *str), {
2022
}
2123
});
2224

23-
constexpr size_t kN = 5000;
24-
25-
extern "C" {
26-
27-
EMSCRIPTEN_KEEPALIVE bool checkAnswer(std::array<float, kN> &outputArr) {
28-
return outputArr[0] == 10;
29-
// return false;
30-
}
31-
32-
EMSCRIPTEN_KEEPALIVE
33-
void executeKernel(Context& ctx, const char *kernelCode, const Shape &wgSize,
34-
const Shape &nWorkgroups,
35-
std::array<float, kN> &outputArr) {
25+
template <size_t nInputs>
26+
struct HostSpec {
27+
const Shape wgSize;
28+
const Shape nWorkgroups;
29+
const std::string kernelCode;
30+
std::array<std::vector<float>, nInputs> inputs;
31+
};
3632

37-
// TODO(avh): use puzzle dispatch from scaffold.h for host implementation
38-
char buffer[1024]; // for printing
39-
constexpr size_t N = 5000;
40-
std::array<float, N> inputArr;
41-
for (int i = 0; i < N; ++i) {
42-
inputArr[i] = static_cast<float>(i);
33+
template <size_t nInputs>
34+
void executeKernel(Context& ctx,
35+
const HostSpec<nInputs>& spec,
36+
float* outputPtr, size_t outputSize) {
37+
std::array<Tensor, nInputs + 1> bindingsArr; // + 1 for output binding
38+
for (size_t inputIndex = 0; inputIndex < nInputs; ++inputIndex) {
39+
bindingsArr[inputIndex] = createTensor(ctx, Shape{spec.inputs[inputIndex].size()}, kf32, spec.inputs[inputIndex].data());
4340
}
44-
Tensor input = createTensor(ctx, Shape{N}, kf32, inputArr.data());
45-
Tensor output = createTensor(ctx, Shape{N}, kf32);
41+
Tensor output = createTensor(ctx, Shape{outputSize}, kf32);
42+
bindingsArr[nInputs] = output;
43+
Bindings bindings{bindingsArr};
4644
std::promise<void> promise;
4745
std::future<void> future = promise.get_future();
48-
Kernel op = createKernel(ctx, {kernelCode, wgSize, kf32},
49-
Bindings{input, output}, nWorkgroups);
50-
46+
Kernel op = createKernel(ctx, {spec.kernelCode, spec.wgSize, kf32},
47+
bindings, spec.nWorkgroups);
5148
dispatchKernel(ctx, op, promise);
5249
wait(ctx, future);
53-
toCPU(ctx, output, outputArr.data(), sizeof(outputArr));
54-
for (int i = 0; i < 10; ++i) {
55-
snprintf(buffer, sizeof(buffer), " [%d] kernel(%.1f) = %.4f", i,
56-
inputArr[i], outputArr[i]);
57-
js_print(buffer);
50+
toCPU(ctx, output, outputPtr, outputSize * sizeof(float));
51+
}
52+
53+
extern "C" {
54+
55+
void generatePreamble(size_t nInputs, Shape& wgSize, Shape& nWorkgroups, const char* out, size_t outSize) {
56+
std::string result = "";
57+
for (size_t i = 0; i < nInputs; ++i) {
58+
result += "@group(0) @binding(" + std::to_string(i) + ") var input" + std::to_string(i) + " : array;\n";
5859
}
59-
js_print(" ...");
60-
for (int i = N - 10; i < N; ++i) {
61-
snprintf(buffer, sizeof(buffer), " [%d] kernel(%.1f) = %.4f", i,
62-
inputArr[i], outputArr[i]);
63-
js_print(buffer);
60+
result += "@group(0) @binding(" + std::to_string(nInputs) + ") var output : array;\n";
61+
result += "@compute @workgroup_size(" + std::to_string(wgSize[0]) + ", " + std::to_string(wgSize[1]) + ", " + std::to_string(wgSize[2]) + ")\n";
62+
std::strncpy(const_cast<char*>(out), result.c_str(), outSize);
63+
}
64+
65+
66+
/**
 * Runs `kernelCode` once on a single input buffer holding 0, 1, ..., kN - 1
 * and reads the kN output floats back into a local array.
 *
 * The result is not inspected or returned.
 *
 * @param kernelCode  Kernel source text.
 * @param wgSize      Workgroup size forwarded to the kernel spec.
 * @param nWorkgroups Workgroup count forwarded to the kernel spec.
 */
EMSCRIPTEN_KEEPALIVE
void runCheck(const char *kernelCode, const Shape &wgSize,
              const Shape &nWorkgroups) {
  Context ctx = createContext({});
  std::array<float, kN> output{};
  // Size the input with kN: the former local constant N no longer exists.
  std::vector<float> input(kN);
  for (size_t i = 0; i < kN; ++i) {
    input[i] = static_cast<float>(i);
  }
  HostSpec<1> spec = {
      wgSize,
      nWorkgroups,
      kernelCode,
      std::array<std::vector<float>, 1>{input}
  };
  executeKernel<1>(ctx, spec, output.data(), kN);
}
6883

6984
/**
 * Evaluates `kernelCode` against test case 0 of `createTestCases()`.
 *
 * Prints the workgroup size and workgroup count through `js_print`, creates a
 * fresh context, and forwards to the
 * `evaluate(ctx, testCases, kernelCode, 0)` overload.
 *
 * @param kernelCode  Kernel source text; a null pointer is rejected.
 * @param wgSize      Workgroup size; elements 0..2 are printed.
 * @param nWorkgroups Workgroup count; elements 0..2 are printed.
 * @return Result of the forwarded evaluation, or false when `kernelCode` is
 *         null.
 */
EMSCRIPTEN_KEEPALIVE
bool evaluate(const char *kernelCode, const Shape &wgSize,
              const Shape &nWorkgroups) {
  // A null pointer must never reach the std::string conversion below.
  if (kernelCode == nullptr) {
    js_print("evaluate: kernelCode is null");
    return false;
  }

  char buffer[1024]; // for printing

  snprintf(buffer, sizeof(buffer), "Evaluating kernel with workgroup size (%zu, %zu, %zu) and nWorkgroups (%zu, %zu, %zu)",
           wgSize[0], wgSize[1], wgSize[2], nWorkgroups[0], nWorkgroups[1], nWorkgroups[2]);
  js_print(buffer);
  Context ctx = createContext({});
  TestCases testCases = createTestCases();
  return evaluate(ctx, testCases, kernelCode, 0);
}
7896

7997
} // extern "C"
@@ -89,20 +107,19 @@ EMSCRIPTEN_BINDINGS(module) {
89107
emscripten::register_vector<std::vector<float>>("VectorFloat");
90108
emscripten::register_vector<std::vector<int>>("VectorInt");
91109

110+
92111
emscripten::function(
93-
"runCheck",
112+
"evaluate",
94113
emscripten::optional_override(
95114
[](const std::string &kernelCode, const std::array<size_t, 3> &wgSize,
96115
const std::array<size_t, 3> &nWorkgroups) {
97-
return runCheck(kernelCode.c_str(),
116+
return evaluate(kernelCode.c_str(),
98117
Shape{static_cast<size_t>(wgSize[0]),
99118
static_cast<size_t>(wgSize[1]),
100119
static_cast<size_t>(wgSize[2])},
101120
Shape{static_cast<size_t>(nWorkgroups[0]),
102121
static_cast<size_t>(nWorkgroups[1]),
103122
static_cast<size_t>(nWorkgroups[2])});
104123
}));
105-
106-
emscripten::function("checkAnswer", &checkAnswer);
107124
}
108125
#endif

gpu.h

+11-16
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,14 @@ template <std::size_t N> struct Bindings {
133133
}
134134
}
135135

136+
Bindings(const std::array<Tensor, N> &init) {
137+
std::copy(begin(init), end(init), begin(data));
138+
std::fill(begin(viewOffsets), end(viewOffsets), 0);
139+
for (size_t i = 0; i < N; ++i) {
140+
viewSpans[i] = data[i].data.size;
141+
}
142+
}
143+
136144
Bindings(const std::initializer_list<TensorView> &init) {
137145
size_t i = 0;
138146
for (const auto &tv : init) {
@@ -174,7 +182,7 @@ struct Context; // Forward declaration so that TensorPool can have a pointer to
174182
* resources.
175183
*/
176184
struct TensorPool {
177-
inline TensorPool(Context *ctx) : ctx(ctx), data(){};
185+
inline TensorPool(Context *ctx) : ctx(ctx), data() {};
178186
Context *ctx;
179187
std::unordered_map<WGPUBuffer, Tensor> data;
180188
~TensorPool();
@@ -718,7 +726,8 @@ inline Context createContext(const WGPUInstanceDescriptor &desc = {},
718726
"enabled, particularly on Linux.\n"
719727
"- Open `chrome://flags/` in the browser and make sure "
720728
"\"WebGPU Support\" is enabled.\n"
721-
"- Chrome is launched with vulkan enabled. From the command line launch chrome as `google-chrome --enable-features=Vulkan`\n");
729+
"- Chrome is launched with vulkan enabled. From the command line "
730+
"launch chrome as `google-chrome --enable-features=Vulkan`\n");
722731
}
723732
#endif
724733
check(status == WGPURequestAdapterStatus_Success,
@@ -755,20 +764,6 @@ inline Context createContext(const WGPUInstanceDescriptor &desc = {},
755764
devData.device = device;
756765
devData.requestEnded = true;
757766
};
758-
#if defined(WEBGPU_BACKEND_DAWN) && !defined(__EMSCRIPTEN__)
759-
devDescriptor.deviceLostCallbackInfo = {
760-
.callback =
761-
[](WGPUDevice const *device, WGPUDeviceLostReason reason,
762-
char const *message, void *userdata) {
763-
if (reason != WGPUDeviceLostReason_Destroyed) {
764-
LOG(kDefLog, kError, "Device lost (code %d):\n%s", reason,
765-
message);
766-
} else {
767-
LOG(kDefLog, kInfo, "Device destroyed: %s", message);
768-
}
769-
},
770-
};
771-
#endif
772767
wgpuAdapterRequestDevice(context.adapter, &devDescriptor,
773768
onDeviceRequestEnded, (void *)&devData);
774769
LOG(kDefLog, kInfo, "Waiting for device request to end");

0 commit comments

Comments
 (0)