1111#include < unordered_map>
1212#include < vector>
1313
14- #include " webgpu/webgpu.h"
1514#include " utils/logging.h"
15+ #include " webgpu/webgpu.h"
1616
1717namespace gpu {
1818
@@ -67,7 +67,7 @@ struct GPUTensor {
6767};
6868
6969struct TensorPool {
70- TensorPool (GPUContext *ctx) : ctx(ctx), data(){};
70+ TensorPool (GPUContext *ctx) : ctx(ctx), data() {};
7171 GPUContext *ctx;
7272 std::unordered_map<WGPUBuffer, GPUTensor> data;
7373 ~TensorPool ();
@@ -121,9 +121,9 @@ const char *ToString(NumType type) {
121121
122122/* Tensor factory function */
123123GPUTensor CreateTensor (TensorPool &pool, const Shape &shape, NumType dtype,
124- WGPUBufferUsageFlags usage = WGPUBufferUsage_Storage |
125- WGPUBufferUsage_CopyDst |
126- WGPUBufferUsage_CopySrc) {
124+ WGPUBufferUsageFlags usage = WGPUBufferUsage_Storage |
125+ WGPUBufferUsage_CopyDst |
126+ WGPUBufferUsage_CopySrc) {
127127 log (kDefLog , kInfo , " Creating tensor" );
128128 size_t numElements = 1 ;
129129 for (size_t dim = 0 ; dim < shape.rank ; dim++) {
@@ -146,16 +146,17 @@ GPUTensor CreateTensor(TensorPool &pool, const Shape &shape, NumType dtype,
146146/* Syntactic sugar - take in ctx instead of pool*/
147147GPUTensor CreateTensor (GPUContext &ctx, const Shape &shape, NumType dtype) {
148148 return CreateTensor (ctx.pool , shape, dtype,
149- WGPUBufferUsage_Storage | WGPUBufferUsage_CopyDst |
150- WGPUBufferUsage_CopySrc);
149+ WGPUBufferUsage_Storage | WGPUBufferUsage_CopyDst |
150+ WGPUBufferUsage_CopySrc);
151151}
152152
153153/* With Value Initialization (pointer) */
154154GPUTensor CreateTensor (GPUContext &ctx, const Shape &shape, NumType dtype,
155- float *data) {
156- GPUTensor tensor = CreateTensor (ctx.pool , shape, dtype,
157- WGPUBufferUsage_Storage | WGPUBufferUsage_CopyDst |
158- WGPUBufferUsage_CopySrc);
155+ float *data) {
156+ GPUTensor tensor =
157+ CreateTensor (ctx.pool , shape, dtype,
158+ WGPUBufferUsage_Storage | WGPUBufferUsage_CopyDst |
159+ WGPUBufferUsage_CopySrc);
159160 wgpuQueueWriteBuffer (ctx.queue , tensor.data .buffer , 0 , data,
160161 tensor.data .size );
161162 return tensor;
@@ -187,16 +188,33 @@ struct CallbackDataDyn {
187188};
188189
/* Shader source text together with its workgroup size. */
struct ShaderCode {
  std::string data; // shader source; handed to the WGSL descriptor's `.code`
  size_t wgSize;    // workgroup size
};
193194
/*
 * Replaces every occurrence of `from` in `str` with `to`, in place.
 *
 * Scanning resumes right after each inserted `to`, so text introduced by a
 * replacement is never re-examined (replacing "a" with "aa" terminates).
 *
 * str  - string modified in place
 * from - substring to search for; an empty pattern leaves `str` unchanged
 * to   - replacement text (may be empty)
 */
void ReplaceAll(std::string &str, const std::string &from,
                const std::string &to) {
  // An empty pattern matches at every position, so the loop below would
  // never terminate; treat it as "nothing to replace".
  if (from.empty()) {
    return;
  }
  size_t pos = 0;
  while ((pos = str.find(from, pos)) != std::string::npos) {
    str.replace(pos, from.length(), to);
    pos += to.length(); // skip past the text we just inserted
  }
}
203+
204+ ShaderCode CreateShader (const char *shaderRaw, size_t workgroupSize,
205+ NumType precision) {
206+ std::string codeString (shaderRaw);
207+ ReplaceAll (codeString, " {{workgroupSize}}" , std::to_string (workgroupSize));
208+ ReplaceAll (codeString, " {{precision}}" , ToString (precision));
209+ return ShaderCode{codeString, workgroupSize};
210+ }
211+
194212struct KernelDesc {
195213 const ShaderCode shader;
196214 const GPUTensor *inputs;
197215 size_t numInputs;
198216 const GPUTensor output;
199- const void * params;
217+ const void * params;
200218 const size_t paramSize;
201219};
202220
@@ -441,9 +459,9 @@ void ToGPU(GPUContext &ctx, const float *data, GPUTensor &tensor) {
441459}
442460
443461Kernel CreateKernel (GPUContext &ctx, const ShaderCode &shader,
444- const GPUTensor *inputs, size_t numInputs,
445- const GPUTensor &output, const void *params = nullptr ,
446- size_t paramsSize = 0 ) {
462+ const GPUTensor *inputs, size_t numInputs,
463+ const GPUTensor &output, const void *params = nullptr ,
464+ size_t paramsSize = 0 ) {
447465 WGPUDevice device = ctx.device ;
448466 WGPUQueue queue = ctx.queue ;
449467 Kernel op;
@@ -591,7 +609,7 @@ Kernel CreateKernel(GPUContext &ctx, const ShaderCode &shader,
591609 pipelineLayout =
592610 wgpuDeviceCreatePipelineLayout (device, &pipelineLayoutDesc);
593611 WGPUShaderModuleWGSLDescriptor wgslDesc = {
594- .code = shader.code .c_str (),
612+ .code = shader.data .c_str (),
595613 };
596614 wgslDesc.chain .sType = WGPUSType_ShaderModuleWGSLDescriptor;
597615 WGPUShaderModuleDescriptor shaderModuleDesc = {};
@@ -634,14 +652,14 @@ Kernel CreateKernel(GPUContext &ctx, const ShaderCode &shader,
634652
635653template <typename ParamsType = NoParam>
636654Kernel CreateKernel (GPUContext &ctx, const ShaderCode &shader,
637- const GPUTensor *inputs, size_t numInputs,
638- const GPUTensor &output,
639- const ParamsType ¶ms = ParamsType{}) {
655+ const GPUTensor *inputs, size_t numInputs,
656+ const GPUTensor &output,
657+ const ParamsType ¶ms = ParamsType{}) {
640658 if constexpr (!IsNoParam<ParamsType>) {
641659 log (kDefLog , kInfo , " Using params of size %d bytes" , sizeof (ParamsType));
642660 return CreateKernel (ctx, shader, inputs, numInputs, output,
643- reinterpret_cast <const void *>(¶ms),
644- sizeof (ParamsType));
661+ reinterpret_cast <const void *>(¶ms),
662+ sizeof (ParamsType));
645663 } else {
646664 log (kDefLog , kInfo , " No params" );
647665 return CreateKernel (ctx, shader, inputs, numInputs, output, nullptr , 0 );
@@ -653,11 +671,11 @@ Kernel CreateKernel(GPUContext &ctx, const ShaderCode &shader,
653671 */
654672template <typename ParamsType = NoParam, size_t numInputs>
655673Kernel CreateKernel (GPUContext &ctx, const ShaderCode &shader,
656- const std::array<GPUTensor, numInputs> &inputs,
657- const GPUTensor &output,
658- const ParamsType ¶ms = ParamsType{}) {
659- return CreateKernel<ParamsType>(ctx, shader, inputs.data (), numInputs,
660- output, params);
674+ const std::array<GPUTensor, numInputs> &inputs,
675+ const GPUTensor &output,
676+ const ParamsType ¶ms = ParamsType{}) {
677+ return CreateKernel<ParamsType>(ctx, shader, inputs.data (), numInputs, output,
678+ params);
661679}
662680
663681MultiKernel CreateMultiKernel (GPUContext &ctx, const MultiKernelDesc &desc) {
@@ -791,7 +809,7 @@ MultiKernel CreateMultiKernel(GPUContext &ctx, const MultiKernelDesc &desc) {
791809 // Create shader module
792810 log (kDefLog , kInfo , " Create shader module" );
793811 WGPUShaderModuleWGSLDescriptor wgslDesc = {
794- .code = desc.shader [shaderIndex].code .c_str (),
812+ .code = desc.shader [shaderIndex].data .c_str (),
795813 };
796814 wgslDesc.chain .sType = WGPUSType_ShaderModuleWGSLDescriptor;
797815 WGPUShaderModuleDescriptor shaderModuleDesc = {
0 commit comments