@@ -206,13 +206,11 @@ struct KernelPool {
206206 KernelPool (Context *ctx) : ctx(ctx), data() {}
207207 Context *ctx;
208208 std::set<Kernel *> data;
209- // std::set<MultiKernel *> multiData;
210209 ~KernelPool () {
211210 // Note: Some kernel resources such as commandBuffer are harvested by
212211 // queue submission; explicitly destroying readback and callback buffers
213212 // produces runtime errors.
214213 data.clear ();
215- // multiData.clear();
216214 }
217215};
218216
@@ -664,8 +662,6 @@ void ResetCommandBuffer(WGPUDevice &device, const Shape &nThreads, Kernel &op) {
664662 op.commandBuffer = wgpuCommandEncoderFinish (commandEncoder, nullptr );
665663 check (op.commandBuffer , " Create command buffer" , __FILE__, __LINE__);
666664 }
667- // op.promise = std::promise<void>();
668- // op.future = op.promise.get_future();
669665}
670666
671667/* *
@@ -800,7 +796,6 @@ Kernel CreateKernel(Context &ctx, const ShaderCode &shader,
800796 .entries = bindGroupEntries.data (),
801797 };
802798 op.bindGroup = wgpuDeviceCreateBindGroup (device, &bindGroupDesc);
803-
804799 {
805800 WGPUPipelineLayoutDescriptor pipelineLayoutDesc = {
806801 .bindGroupLayoutCount = 1 ,
@@ -833,42 +828,6 @@ Kernel CreateKernel(Context &ctx, const ShaderCode &shader,
833828 return op;
834829}
835830
836- /* *
837- * @brief Overload which wraps the CreateKernel factory function to create a
838- * kernel on the GPU with a statically determined ParamsType instead of casting
839- * params to a void pointer. paramSize is then determined by the size of the
840- * ParamsType.
841- *
842- * @param[in] ctx Context instance to manage the kernel
843- * @param[in] shader Shader code for the kernel
844- * @param[in] inputs A span of input tensors as a pointer
845- * @param[in] numInputs Number of input tensors, effectively the size of the
846- * *inputs span.
847- * @param[in] output Output tensor for the kernel
848- * @param[in] nThreads Shape of the workgroup size for the kernel, must be of
849- * rank 3.
850- * @param[in] params Optional parameters for the kernel. If the kernel does not
851- * have any parameters, use NoParam.
852- * @example Kernel kernel = CreateKernel(ctx, shader, inputs, numInputs, output,
853- * nThreads, params);
854- */
855- template <typename ParamsType = NoParam>
856- Kernel CreateKernel (Context &ctx, const ShaderCode &shader,
857- const Tensor *inputs, size_t numInputs,
858- const Shape &nThreads,
859- const ParamsType &params = ParamsType{}) {
860- if constexpr (!IsNoParam<ParamsType>) {
861- log (kDefLog , kInfo , " Using params of size %d bytes" , sizeof (ParamsType));
862- return CreateKernel (ctx, shader, inputs, numInputs, nThreads,
863- reinterpret_cast <const void *>(&params),
864- sizeof (ParamsType));
865- } else {
866- log (kDefLog , kInfo , " No params" );
867- return CreateKernel (ctx, shader, inputs, numInputs, nThreads,
868- nullptr , 0 );
869- }
870- }
871-
872831/* *
873832 * @brief Overload which wraps the CreateKernel factory function to create a
873832 * kernel on the GPU. This overload takes a static collection of input
@@ -892,17 +851,16 @@ Kernel CreateKernel(Context &ctx, const ShaderCode &shader,
892851 const TensorList<numInputs> &inputs,
893852 const Shape &nThreads,
894853 const ParamsType &params = ParamsType{}) {
895- // first .data gets the array, second .data() gets the pointer
896- return CreateKernel<ParamsType>(ctx, shader, inputs.data .data (), numInputs,
897- nThreads, params);
898- }
899-
900- // Convenience wrapper: specialization for single input passed by reference
901- template <typename ParamsType = NoParam>
902- Kernel CreateKernel (Context &ctx, const ShaderCode &shader, const Tensor &input,
903- const Shape &nThreads,
904- const ParamsType &params = ParamsType{}) {
905- return CreateKernel (ctx, shader, &input, 1 , nThreads, params);
854+ if constexpr (!IsNoParam<ParamsType>) {
855+ log (kDefLog , kInfo , " Using params of size %d bytes" , sizeof (ParamsType));
856+ return CreateKernel (ctx, shader, inputs.data .data (), numInputs, nThreads,
857+ reinterpret_cast <const void *>(&params),
858+ sizeof (ParamsType));
859+ } else {
860+ log (kDefLog , kInfo , " No params" );
861+ return CreateKernel (ctx, shader, inputs.data .data (), numInputs, nThreads,
862+ nullptr , 0 );
863+ }
906864}
907865
908866/* *
0 commit comments