 #include <unordered_map>
 #include <vector>
 
-#include "webgpu/webgpu.h"
 #include "utils/logging.h"
+#include "webgpu/webgpu.h"
 
 namespace gpu {
 
@@ -67,7 +67,7 @@ struct GPUTensor {
 };
 
 struct TensorPool {
-  TensorPool(GPUContext *ctx) : ctx(ctx), data(){};
+  TensorPool(GPUContext *ctx) : ctx(ctx), data() {};
   GPUContext *ctx;
   std::unordered_map<WGPUBuffer, GPUTensor> data;
   ~TensorPool();
@@ -121,9 +121,9 @@ const char *ToString(NumType type) {
 
 /* Tensor factory function */
 GPUTensor CreateTensor(TensorPool &pool, const Shape &shape, NumType dtype,
-                       WGPUBufferUsageFlags usage = WGPUBufferUsage_Storage |
-                                                    WGPUBufferUsage_CopyDst |
-                                                    WGPUBufferUsage_CopySrc) {
+                       WGPUBufferUsageFlags usage = WGPUBufferUsage_Storage |
+                                                    WGPUBufferUsage_CopyDst |
+                                                    WGPUBufferUsage_CopySrc) {
   log(kDefLog, kInfo, "Creating tensor");
   size_t numElements = 1;
   for (size_t dim = 0; dim < shape.rank; dim++) {
@@ -146,16 +146,17 @@ GPUTensor CreateTensor(TensorPool &pool, const Shape &shape, NumType dtype,
 /* Syntactic sugar - take in ctx instead of pool*/
 GPUTensor CreateTensor(GPUContext &ctx, const Shape &shape, NumType dtype) {
   return CreateTensor(ctx.pool, shape, dtype,
-                      WGPUBufferUsage_Storage | WGPUBufferUsage_CopyDst |
-                          WGPUBufferUsage_CopySrc);
+                      WGPUBufferUsage_Storage | WGPUBufferUsage_CopyDst |
+                          WGPUBufferUsage_CopySrc);
 }
 
 /* With Value Initialization (pointer) */
 GPUTensor CreateTensor(GPUContext &ctx, const Shape &shape, NumType dtype,
-                       float *data) {
-  GPUTensor tensor = CreateTensor(ctx.pool, shape, dtype,
-                                  WGPUBufferUsage_Storage | WGPUBufferUsage_CopyDst |
-                                  WGPUBufferUsage_CopySrc);
+                       float *data) {
+  GPUTensor tensor =
+      CreateTensor(ctx.pool, shape, dtype,
+                   WGPUBufferUsage_Storage | WGPUBufferUsage_CopyDst |
+                       WGPUBufferUsage_CopySrc);
   wgpuQueueWriteBuffer(ctx.queue, tensor.data.buffer, 0, data,
                        tensor.data.size);
   return tensor;
@@ -187,16 +188,33 @@ struct CallbackDataDyn {
 };
 
 struct ShaderCode {
-  std::string code;
+  std::string data;
   size_t wgSize; // workgroup size
 };
 
+void ReplaceAll(std::string &str, const std::string &from,
+                const std::string &to) {
+  size_t start_pos = 0;
+  while ((start_pos = str.find(from, start_pos)) != std::string::npos) {
+    str.replace(start_pos, from.length(), to);
+    start_pos += to.length();
+  }
+}
+
+ShaderCode CreateShader(const char *shaderRaw, size_t workgroupSize,
+                        NumType precision) {
+  std::string codeString(shaderRaw);
+  ReplaceAll(codeString, "{{workgroupSize}}", std::to_string(workgroupSize));
+  ReplaceAll(codeString, "{{precision}}", ToString(precision));
+  return ShaderCode{codeString, workgroupSize};
+}
+
 struct KernelDesc {
   const ShaderCode shader;
   const GPUTensor *inputs;
   size_t numInputs;
   const GPUTensor output;
-  const void * params;
+  const void * params;
   const size_t paramSize;
 };
 
@@ -441,9 +459,9 @@ void ToGPU(GPUContext &ctx, const float *data, GPUTensor &tensor) {
 }
 
 Kernel CreateKernel(GPUContext &ctx, const ShaderCode &shader,
-                    const GPUTensor *inputs, size_t numInputs,
-                    const GPUTensor &output, const void *params = nullptr,
-                    size_t paramsSize = 0) {
+                    const GPUTensor *inputs, size_t numInputs,
+                    const GPUTensor &output, const void *params = nullptr,
+                    size_t paramsSize = 0) {
   WGPUDevice device = ctx.device;
   WGPUQueue queue = ctx.queue;
   Kernel op;
@@ -591,7 +609,7 @@ Kernel CreateKernel(GPUContext &ctx, const ShaderCode &shader,
   pipelineLayout =
       wgpuDeviceCreatePipelineLayout(device, &pipelineLayoutDesc);
   WGPUShaderModuleWGSLDescriptor wgslDesc = {
-      .code = shader.code.c_str(),
+      .code = shader.data.c_str(),
   };
   wgslDesc.chain.sType = WGPUSType_ShaderModuleWGSLDescriptor;
   WGPUShaderModuleDescriptor shaderModuleDesc = {};
@@ -634,14 +652,14 @@ Kernel CreateKernel(GPUContext &ctx, const ShaderCode &shader,
 
 template <typename ParamsType = NoParam>
 Kernel CreateKernel(GPUContext &ctx, const ShaderCode &shader,
-                    const GPUTensor *inputs, size_t numInputs,
-                    const GPUTensor &output,
-                    const ParamsType &params = ParamsType{}) {
+                    const GPUTensor *inputs, size_t numInputs,
+                    const GPUTensor &output,
+                    const ParamsType &params = ParamsType{}) {
   if constexpr (!IsNoParam<ParamsType>) {
     log(kDefLog, kInfo, "Using params of size %d bytes", sizeof(ParamsType));
     return CreateKernel(ctx, shader, inputs, numInputs, output,
-                        reinterpret_cast<const void *>(&params),
-                        sizeof(ParamsType));
+                        reinterpret_cast<const void *>(&params),
+                        sizeof(ParamsType));
   } else {
     log(kDefLog, kInfo, "No params");
     return CreateKernel(ctx, shader, inputs, numInputs, output, nullptr, 0);
@@ -653,11 +671,11 @@ Kernel CreateKernel(GPUContext &ctx, const ShaderCode &shader,
 */
 template <typename ParamsType = NoParam, size_t numInputs>
 Kernel CreateKernel(GPUContext &ctx, const ShaderCode &shader,
-                    const std::array<GPUTensor, numInputs> &inputs,
-                    const GPUTensor &output,
-                    const ParamsType &params = ParamsType{}) {
-  return CreateKernel<ParamsType>(ctx, shader, inputs.data(), numInputs,
-                                  output, params);
+                    const std::array<GPUTensor, numInputs> &inputs,
+                    const GPUTensor &output,
+                    const ParamsType &params = ParamsType{}) {
+  return CreateKernel<ParamsType>(ctx, shader, inputs.data(), numInputs, output,
+                                  params);
 }
 
 MultiKernel CreateMultiKernel(GPUContext &ctx, const MultiKernelDesc &desc) {
@@ -791,7 +809,7 @@ MultiKernel CreateMultiKernel(GPUContext &ctx, const MultiKernelDesc &desc) {
     // Create shader module
    log(kDefLog, kInfo, "Create shader module");
     WGPUShaderModuleWGSLDescriptor wgslDesc = {
-        .code = desc.shader[shaderIndex].code.c_str(),
+        .code = desc.shader[shaderIndex].data.c_str(),
     };
     wgslDesc.chain.sType = WGPUSType_ShaderModuleWGSLDescriptor;
     WGPUShaderModuleDescriptor shaderModuleDesc = {
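
For reference, a minimal sketch of how the new CreateShader helper introduced in this change might be used: it textually substitutes the {{workgroupSize}} and {{precision}} placeholders in a raw WGSL template before the resulting ShaderCode is passed on to CreateKernel. The WGSL template and the kf32 precision value below are illustrative assumptions, not code from this diff:

    // Hypothetical WGSL template; CreateShader fills in the two placeholders.
    static const char *kCopyShader = R"(
    @group(0) @binding(0) var<storage, read_write> inp: array<{{precision}}>;
    @group(0) @binding(1) var<storage, read_write> outp: array<{{precision}}>;
    @compute @workgroup_size({{workgroupSize}})
    fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
      if (gid.x < arrayLength(&inp)) {
        outp[gid.x] = inp[gid.x];
      }
    }
    )";

    // kf32 is an assumed NumType enumerator; only ToString(NumType) appears in
    // this diff. CreateShader(shaderRaw, workgroupSize, precision) is the
    // signature added above.
    ShaderCode shader = CreateShader(kCopyShader, 256, kf32);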