11
11
12
12
using namespace gpu ;
13
13
14
+ constexpr size_t kN = 100 ;
15
+
14
16
EM_JS (void , js_print, (const char *str), {
15
17
if (typeof window != ' undefined' && window.customPrint ) {
16
18
window.customPrint (UTF8ToString (str));
@@ -20,60 +22,76 @@ EM_JS(void, js_print, (const char *str), {
20
22
}
21
23
});
22
24
23
- constexpr size_t kN = 5000 ;
24
-
25
- extern " C" {
26
-
27
- EMSCRIPTEN_KEEPALIVE bool checkAnswer (std::array<float , kN > &outputArr) {
28
- return outputArr[0 ] == 10 ;
29
- // return false;
30
- }
31
-
32
- EMSCRIPTEN_KEEPALIVE
33
- void executeKernel (Context& ctx, const char *kernelCode, const Shape &wgSize,
34
- const Shape &nWorkgroups,
35
- std::array<float , kN > &outputArr) {
25
+ template <size_t nInputs>
26
+ struct HostSpec {
27
+ const Shape wgSize;
28
+ const Shape nWorkgroups;
29
+ const std::string kernelCode;
30
+ std::array<std::vector<float >, nInputs> inputs;
31
+ };
36
32
37
- // TODO(avh): use puzzle dispatch from scaffold.h for host implementation
38
- char buffer[1024 ]; // for printing
39
- constexpr size_t N = 5000 ;
40
- std::array<float , N> inputArr;
41
- for (int i = 0 ; i < N; ++i) {
42
- inputArr[i] = static_cast <float >(i);
33
+ template <size_t nInputs>
34
+ void executeKernel (Context& ctx,
35
+ const HostSpec<nInputs>& spec,
36
+ float * outputPtr, size_t outputSize) {
37
+ std::array<Tensor, nInputs + 1 > bindingsArr; // + 1 for output binding
38
+ for (size_t inputIndex = 0 ; inputIndex < nInputs; ++inputIndex) {
39
+ bindingsArr[inputIndex] = createTensor (ctx, Shape{spec.inputs [inputIndex].size ()}, kf32, spec.inputs [inputIndex].data ());
43
40
}
44
- Tensor input = createTensor (ctx, Shape{N}, kf32, inputArr.data ());
45
- Tensor output = createTensor (ctx, Shape{N}, kf32);
41
+ Tensor output = createTensor (ctx, Shape{outputSize}, kf32);
42
+ bindingsArr[nInputs] = output;
43
+ Bindings bindings{bindingsArr};
46
44
std::promise<void > promise;
47
45
std::future<void > future = promise.get_future ();
48
- Kernel op = createKernel (ctx, {kernelCode, wgSize, kf32},
49
- Bindings{input, output}, nWorkgroups);
50
-
46
+ Kernel op = createKernel (ctx, {spec.kernelCode , spec.wgSize , kf32},
47
+ bindings, spec.nWorkgroups );
51
48
dispatchKernel (ctx, op, promise);
52
49
wait (ctx, future);
53
- toCPU (ctx, output, outputArr.data (), sizeof (outputArr));
54
- for (int i = 0 ; i < 10 ; ++i) {
55
- snprintf (buffer, sizeof (buffer), " [%d] kernel(%.1f) = %.4f" , i,
56
- inputArr[i], outputArr[i]);
57
- js_print (buffer);
50
+ toCPU (ctx, output, outputPtr, outputSize * sizeof (float ));
51
+ }
52
+
53
+ extern " C" {
54
+
55
+ void generatePreamble (size_t nInputs, Shape& wgSize, Shape& nWorkgroups, const char * out, size_t outSize) {
56
+ std::string result = " " ;
57
+ for (size_t i = 0 ; i < nInputs; ++i) {
58
+ result += " @group(0) @binding(" + std::to_string (i) + " ) var input" + std::to_string (i) + " : array;\n " ;
58
59
}
59
- js_print (" ..." );
60
- for (int i = N - 10 ; i < N; ++i) {
61
- snprintf (buffer, sizeof (buffer), " [%d] kernel(%.1f) = %.4f" , i,
62
- inputArr[i], outputArr[i]);
63
- js_print (buffer);
60
+ result += " @group(0) @binding(" + std::to_string (nInputs) + " ) var output : array;\n " ;
61
+ result += " @compute @workgroup_size(" + std::to_string (wgSize[0 ]) + " , " + std::to_string (wgSize[1 ]) + " , " + std::to_string (wgSize[2 ]) + " )\n " ;
62
+ std::strncpy (const_cast <char *>(out), result.c_str (), outSize);
63
+ }
64
+
65
+
66
+ EMSCRIPTEN_KEEPALIVE
67
+ void runCheck (const char *kernelCode, const Shape &wgSize,
68
+ const Shape &nWorkgroups) {
69
+ Context ctx = createContext ({});
70
+ std::array<float , kN > output;
71
+ std::vector<float > input (N);
72
+ for (int i = 0 ; i < kN ; ++i) {
73
+ input[i] = static_cast <float >(i);
64
74
}
65
- snprintf (buffer, sizeof (buffer), " Computed %zu values" , N);
66
- js_print (buffer);
67
- } // executeKernel
75
+ HostSpec<1 > spec = {
76
+ wgSize,
77
+ nWorkgroups,
78
+ kernelCode,
79
+ std::array<std::vector<float >, 1 > {input}
80
+ };
81
+ executeKernel<1 >(ctx, spec, output.data (), kN );
82
+ }
68
83
69
84
EMSCRIPTEN_KEEPALIVE
70
- bool runCheck (const char *kernelCode, const Shape &wgSize,
85
+ bool evaluate (const char *kernelCode, const Shape &wgSize,
71
86
const Shape &nWorkgroups) {
87
+ char buffer[1024 ]; // for printing
88
+
89
+ snprintf (buffer, sizeof (buffer), " Evaluating kernel with workgroup size (%zu, %zu, %zu) and nWorkgroups (%zu, %zu, %zu)" ,
90
+ wgSize[0 ], wgSize[1 ], wgSize[2 ], nWorkgroups[0 ], nWorkgroups[1 ], nWorkgroups[2 ]);
91
+ js_print (buffer);
72
92
Context ctx = createContext ({});
73
- std::array<float , kN > outputArr;
74
- executeKernel (ctx, kernelCode, wgSize, nWorkgroups, outputArr);
75
93
TestCases testCases = createTestCases ();
76
- return evaluate (ctx, testCases, std::string ( kernelCode) , 0 );
94
+ return evaluate (ctx, testCases, kernelCode, 0 );
77
95
}
78
96
79
97
} // extern "C"
@@ -89,20 +107,19 @@ EMSCRIPTEN_BINDINGS(module) {
89
107
emscripten::register_vector<std::vector<float >>(" VectorFloat" );
90
108
emscripten::register_vector<std::vector<int >>(" VectorInt" );
91
109
110
+
92
111
emscripten::function (
93
- " runCheck " ,
112
+ " evaluate " ,
94
113
emscripten::optional_override (
95
114
[](const std::string &kernelCode, const std::array<size_t , 3 > &wgSize,
96
115
const std::array<size_t , 3 > &nWorkgroups) {
97
- return runCheck (kernelCode.c_str (),
116
+ return evaluate (kernelCode.c_str (),
98
117
Shape{static_cast <size_t >(wgSize[0 ]),
99
118
static_cast <size_t >(wgSize[1 ]),
100
119
static_cast <size_t >(wgSize[2 ])},
101
120
Shape{static_cast <size_t >(nWorkgroups[0 ]),
102
121
static_cast <size_t >(nWorkgroups[1 ]),
103
122
static_cast <size_t >(nWorkgroups[2 ])});
104
123
}));
105
-
106
- emscripten::function (" checkAnswer" , &checkAnswer);
107
124
}
108
125
#endif
0 commit comments