Skip to content

Commit d1bb95e

Browse files
committed
fasthtml live wgsl editing experiment works
1 parent b3819a5 commit d1bb95e

File tree

6 files changed

+143
-120
lines changed

6 files changed

+143
-120
lines changed

Diff for: examples/web/run.cpp

-5
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
#include <memory>
55
#include "gpu.h"
66

7-
// #include <webgpu/webgpu.h>
87
#include "emscripten/emscripten.h"
98

109
using namespace gpu; // createContext, createTensor, createKernel,
@@ -33,10 +32,6 @@ int main(int argc, char **argv) {
3332
printf("\nHello gpu.cpp!\n");
3433
printf("--------------\n\n");
3534

36-
// const WGPUInstanceDescriptor descriptor = { };
37-
// std::unique_ptr<WGPUInstanceDescriptor> descriptor = std::make_unique<WGPUInstanceDescriptor>();
38-
39-
// WGPUInstance instance = wgpuCreateInstance({});
4035
Context ctx = createContext({});
4136
static constexpr size_t N = 5000;
4237
std::array<float, N> inputArr, outputArr;

Diff for: experimental/fasthtml/Makefile

+15-9
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,29 @@
11
GPUCPP=../..
2-
FLAGS=-std=c++17 -s USE_WEBGPU=1 -s ASYNCIFY=1 -I$(GPUCPP)
2+
COMMON_FLAGS=-std=c++17 -s USE_WEBGPU=1 -s ASYNCIFY=1 -I$(GPUCPP)
3+
# Note - no spaces after comma
4+
# enable exceptions to recover from WGSL failure
5+
JS_FLAGS=-s EXPORTED_RUNTIME_METHODS=['UTF8ToString','setValue','addFunction'] -s EXPORTED_FUNCTIONS=['_malloc','_free','_executeKernel'] -s NO_DISABLE_EXCEPTION_CATCHING
6+
WASM_FLAGS=-s STANDALONE_WASM -s ERROR_ON_UNDEFINED_SYMBOLS=0 -s EXPORTED_FUNCTIONS=['_executeKernel'] -s EXPORTED_RUNTIME_METHODS=['ccall','cwrap'] -DSTANDALONE_WASM
7+
MODULARIZE_FLAGS=-s EXPORT_NAME='createModule' -s MODULARIZE=1 --bind
8+
NO_MODULARIZE_FLAGS=-s EXPORTED_FUNCTIONS=['_executeKernel'] -s EXPORTED_RUNTIME_METHODS=['ccall','cwrap'] --bind
39

4-
.PHONY: default cmake check-emsdk browser clean
10+
.PHONY: default cmake check-emsdk browser clean server
511

612
default: server
713

8-
build/run.html: check-emsdk run.cpp
9-
em++ run.cpp -o build/run.html \
10-
$(FLAGS) --shell-file ./custom_shell.html
14+
# build/run.html: check-emsdk run.cpp
15+
# em++ run.cpp -o build/run.html \
16+
# $(COMMON_FLAGS) $(JS_FLAGS) --shell-file ./custom_shell.html
1117

1218
build/run.js: check-emsdk run.cpp
13-
em++ run.cpp -o build/run.js --shell-file ./custom_shell.html \
14-
$(FLAGS)
19+
em++ run.cpp -o build/run.js \
20+
$(COMMON_FLAGS) $(JS_FLAGS) $(MODULARIZE_FLAGS)
1521

1622
build/run.wasm: check-emsdk run.cpp
1723
em++ run.cpp -o build/run.wasm \
18-
$(FLAGS)
24+
$(COMMON_FLAGS) $(WASM_FLAGS)
1925

20-
server: build/run.wasm
26+
server: build/run.js
2127
python3 run.py
2228

2329
clean:

Diff for: experimental/fasthtml/components/code_editor.py

+23-8
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
let completionTippy;
99
let currentCompletion = '';
1010
11+
1112
function initEditor() {
1213
editor = ace.edit("editor");
1314
editor.setTheme("ace/theme/monokai");
@@ -23,20 +24,16 @@
2324
editor.setKeyboardHandler("ace/keyboard/vim");
2425
2526
editor.setValue(
26-
"// Puzzle 1 : Map\\n\
27-
// Implement a kernel that adds 10 to each position of vector\\n\
28-
// a and stores it in vector out. You have 1 thread per position.\\n\
29-
//\\n\
30-
// Warning: You are in vim mode.\\n\
27+
"// Start editing here to see the results.\\n\// Warning: You are in vim mode.\\n\
3128
\\n\
32-
@group(0) @binding(0) var<storage, read_write> a: array<f32>;\\n\
29+
@group(0) @binding(0) var<storage, read_write> input: array<f32>;\\n\
3330
@group(0) @binding(1) var<storage, read_write> output : array<f32>;\\n\
3431
@compute @workgroup_size(256)\\n\
3532
fn main(\\n\
3633
@builtin(local_invocation_id) LocalInvocationID: vec3<u32>) {\\n\
3734
let local_idx = LocalInvocationID.x;\\n\
38-
if (local_idx < arrayLength(&a)) {\\n\
39-
output[local_idx] = a[local_idx] + 10;\\n\
35+
if (local_idx < arrayLength(&input)) {\\n\
36+
output[local_idx] = input[local_idx] + 1;\\n\
4037
}\\n\
4138
}\\n\
4239
");
@@ -54,6 +51,24 @@
5451
if (delta.action === 'insert' && (delta.lines[0] === '.' || delta.lines[0] === ' ')) {
5552
showCompletionSuggestion();
5653
}
54+
55+
// Recover from errors TODO(avh): only do this if there's an error
56+
createModule().then((Module) => {
57+
// Keep your existing Module setup
58+
Module.print = window.customPrint;
59+
Module.printErr = window.customPrint;
60+
window.Module = Module;
61+
console.log("Module ready");
62+
});
63+
64+
if (window.Module && window.Module.executeKernel) {
65+
console.log("Executing kernel");
66+
window.terminal.clear();
67+
window.Module.executeKernel(editor.getValue());
68+
} else {
69+
console.log("Module not ready");
70+
}
71+
5772
});
5873
5974
completionTippy = tippy(document.getElementById('editor'), {

Diff for: experimental/fasthtml/custom_shell.html

-53
This file was deleted.

Diff for: experimental/fasthtml/run.cpp

+51-41
Original file line numberDiff line numberDiff line change
@@ -1,62 +1,72 @@
1+
#include "gpu.h"
12
#include <array>
23
#include <cstdio>
4+
#include <emscripten/emscripten.h>
35
#include <future>
46
#include <memory>
5-
#include "gpu.h"
7+
#include <string>
68

7-
// #include <webgpu/webgpu.h>
8-
#include "emscripten/emscripten.h"
9-
10-
using namespace gpu; // createContext, createTensor, createKernel,
11-
// createShader, dispatchKernel, wait, toCPU
12-
// Tensor, Kernel, Context, Shape, kf32
13-
14-
static const char *kGelu = R"(
15-
const GELU_SCALING_FACTOR: f32 = 0.7978845608028654; // sqrt(2.0 / PI)
16-
@group(0) @binding(0) var<storage, read_write> inp: array<{{precision}}>;
17-
@group(0) @binding(1) var<storage, read_write> out: array<{{precision}}>;
18-
@group(0) @binding(1) var<storage, read_write> dummy: array<{{precision}}>;
19-
@compute @workgroup_size({{workgroupSize}})
20-
fn main(
21-
@builtin(global_invocation_id) GlobalInvocationID: vec3<u32>) {
22-
let i: u32 = GlobalInvocationID.x;
23-
if (i < arrayLength(&inp)) {
24-
let x: f32 = inp[i];
25-
out[i] = select(0.5 * x * (1.0 + tanh(GELU_SCALING_FACTOR
26-
* (x + .044715 * x * x * x))), x, x > 10.0);
27-
}
28-
}
29-
)";
9+
using namespace gpu;
3010

31-
int main(int argc, char **argv) {
32-
printf("\033[2J\033[1;1H");
33-
printf("\nHello gpu.cpp!\n");
34-
printf("--------------\n\n");
11+
EM_JS(void, js_print, (const char *str), {
12+
if (typeof window != 'undefined' && window.customPrint) {
13+
window.customPrint(UTF8ToString(str));
14+
} else {
15+
console.log("window.customPrint is not defined.");
16+
console.log(UTF8ToString(str));
17+
}
18+
});
3519

36-
// const WGPUInstanceDescriptor descriptor = { };
37-
// std::unique_ptr<WGPUInstanceDescriptor> descriptor = std::make_unique<WGPUInstanceDescriptor>();
20+
extern "C" {
3821

39-
// WGPUInstance instance = wgpuCreateInstance({});
22+
EMSCRIPTEN_KEEPALIVE
23+
void executeKernel(const char *kernelCode) {
4024
Context ctx = createContext({});
4125
static constexpr size_t N = 5000;
4226
std::array<float, N> inputArr, outputArr;
27+
4328
for (int i = 0; i < N; ++i) {
44-
inputArr[i] = static_cast<float>(i) / 10.0; // dummy input data
29+
inputArr[i] = static_cast<float>(i);
4530
}
31+
4632
Tensor input = createTensor(ctx, Shape{N}, kf32, inputArr.data());
4733
Tensor output = createTensor(ctx, Shape{N}, kf32);
34+
4835
std::promise<void> promise;
4936
std::future<void> future = promise.get_future();
50-
Kernel op = createKernel(ctx, {kGelu, 256, kf32},
51-
Bindings{input, output},
52-
{cdiv(N, 256), 1, 1});
53-
dispatchKernel(ctx, op, promise);
54-
wait(ctx, future);
37+
38+
try {
39+
Kernel op = createKernel(ctx, {kernelCode, 256, kf32},
40+
Bindings{input, output}, {cdiv(N, 256), 1, 1});
41+
42+
dispatchKernel(ctx, op, promise);
43+
wait(ctx, future);
44+
} catch (const std::exception &e) {
45+
js_print("Invalid kernel code.");
46+
exit(1);
47+
}
48+
5549
toCPU(ctx, output, outputArr.data(), sizeof(outputArr));
50+
51+
char buffer[1024];
5652
for (int i = 0; i < 12; ++i) {
57-
printf(" gelu(%.2f) = %.2f\n", inputArr[i], outputArr[i]);
53+
snprintf(buffer, sizeof(buffer), " kernel(%.1f) = %.4f", inputArr[i],
54+
outputArr[i]);
55+
js_print(buffer);
5856
}
59-
printf(" ...\n\n");
60-
printf("Computed %zu values of GELU(x)\n\n", N);
61-
return 0;
57+
snprintf(buffer, sizeof(buffer), " ...");
58+
js_print(buffer);
59+
snprintf(buffer, sizeof(buffer), "Computed %zu values", N);
60+
js_print(buffer);
61+
}
62+
}
63+
64+
#ifndef STANDALONE_WASM
65+
#include "emscripten/bind.h"
66+
EMSCRIPTEN_BINDINGS(module) {
67+
emscripten::function("executeKernel", emscripten::optional_override(
68+
[](const std::string &kernelCode) {
69+
executeKernel(kernelCode.c_str());
70+
}));
6271
}
72+
#endif

Diff for: experimental/fasthtml/run.py

+54-4
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,12 @@
99
ace_editor = Script(src="https://cdnjs.cloudflare.com/ajax/libs/ace/1.4.12/ace.js")
1010
gpucpp_runtime = Script(src="/build/run.js")
1111
gpucpp_wasm = Script(src="/build/run.wasm")
12+
tippy_css = Link(rel="stylesheet", href="https://unpkg.com/tippy.js@6/dist/tippy.css")
13+
tippy_js = Script(src="https://unpkg.com/@popperjs/core@2")
14+
tippy_js2 = Script(src="https://unpkg.com/tippy.js@6")
15+
xterm_css = Link(rel="stylesheet", href="https://cdn.jsdelivr.net/npm/xterm/css/xterm.css")
16+
xterm_js = Script(src="https://cdn.jsdelivr.net/npm/xterm/lib/xterm.js")
17+
xterm_fit_js = Script(src="https://cdn.jsdelivr.net/npm/xterm-addon-fit/lib/xterm-addon-fit.js")
1218

1319
global_style = Style("""
1420
#editor {
@@ -17,11 +23,53 @@
1723
}
1824
""")
1925

26+
terminal_init = Script("""
27+
console.log("Terminal initialized");
28+
const terminal = new Terminal();
29+
const fitAddon = new FitAddon.FitAddon();
30+
terminal.loadAddon(fitAddon);
31+
// terminal.open(document.getElementById('output'));
32+
window.terminal = terminal;
33+
fitAddon.fit();
34+
console.log("Terminal initialized");
35+
""");
36+
37+
print_script = Script("""
38+
window.customPrint = function(text) {
39+
console.log(text);
40+
if (window.terminal) {
41+
window.terminal.writeln(text);
42+
} else {
43+
console.warn("Terminal not initialized");
44+
}
45+
};
46+
createModule().then((Module) => {
47+
Module.print = window.customPrint;
48+
Module.printErr = window.customPrint;
49+
window.Module = Module;
50+
console.log("Module ready");
51+
});
52+
"""),
53+
54+
bind_terminal = Script("window.terminal.open(document.getElementById('output'));")
55+
56+
# TODO(avh): Global state handling of terminal binding, module creation, etc.
57+
# could be improved
58+
2059
HDRS = (
2160
picolink,
2261
ace_editor,
62+
xterm_css,
63+
xterm_js,
64+
xterm_fit_js,
65+
terminal_init,
2366
gpucpp_runtime,
67+
print_script,
68+
bind_terminal,
2469
global_style,
70+
tippy_css,
71+
tippy_js,
72+
tippy_js2,
2573
*Socials(
2674
title="gpu.cpp gpu puzzles",
2775
description="",
@@ -39,9 +87,9 @@
3987

4088
rt = app.route
4189

42-
@app.get("/build/{fname:path}.{ext:static}")
43-
async def build(fname: str, ext: str):
44-
return FileResponse(f"build/{fname}.{ext}")
90+
@app.get("/build/run.js")
91+
async def serve_wasm(fname: str, ext: str):
92+
return FileResponse(f"build/run.js")
4593

4694
@app.get("/build/run.wasm")
4795
async def serve_wasm(fname: str, ext: str):
@@ -56,10 +104,12 @@ def page():
56104
),
57105
Div(
58106
"Output",
107+
id="output",
59108
style="width: 34vw; height:100vh; background-color: #444; float: right;",
60109
),
61110
),
62-
style="height: 100vh; overflow: hidden;")
111+
bind_terminal,
112+
style="height: 100vh; overflow: hidden;"),
63113

64114

65115
@rt("/")

0 commit comments

Comments
 (0)