@@ -206,13 +206,11 @@ struct KernelPool {
206
206
KernelPool (Context *ctx) : ctx(ctx), data() {}
207
207
Context *ctx;
208
208
std::set<Kernel *> data;
209
- // std::set<MultiKernel *> multiData;
210
209
~KernelPool () {
211
210
// Note : Some kernel resources such as commandBuffer are harvested by
212
211
// queue submission; explicitly destroying readback and callback buffers
213
212
// produces runtime errors.
214
213
data.clear ();
215
- // multiData.clear();
216
214
}
217
215
};
218
216
@@ -664,8 +662,6 @@ void ResetCommandBuffer(WGPUDevice &device, const Shape &nThreads, Kernel &op) {
664
662
op.commandBuffer = wgpuCommandEncoderFinish (commandEncoder, nullptr );
665
663
check (op.commandBuffer , " Create command buffer" , __FILE__, __LINE__);
666
664
}
667
- // op.promise = std::promise<void>();
668
- // op.future = op.promise.get_future();
669
665
}
670
666
671
667
/* *
@@ -800,7 +796,6 @@ Kernel CreateKernel(Context &ctx, const ShaderCode &shader,
800
796
.entries = bindGroupEntries.data (),
801
797
};
802
798
op.bindGroup = wgpuDeviceCreateBindGroup (device, &bindGroupDesc);
803
-
804
799
{
805
800
WGPUPipelineLayoutDescriptor pipelineLayoutDesc = {
806
801
.bindGroupLayoutCount = 1 ,
@@ -833,42 +828,6 @@ Kernel CreateKernel(Context &ctx, const ShaderCode &shader,
833
828
return op;
834
829
}
835
830
836
- /* *
837
- * @brief Overload which wraps the CreateKernel factory function to create a
838
- * kernel on the GPU with a statically determined ParamsType instead of casting
839
- * params to a void pointer. paramSize is then determined by the size of the
840
- * ParamsType.
841
- *
842
- * @param[in] ctx Context instance to manage the kernel
843
- * @param[in] shader Shader code for the kernel
844
- * @param[in] inputs A span of input tensors as a pointer
845
- * @param[in] numInputs Number of input tensors, effectively the size of the
846
- * *inputs span.
847
- * @param[in] output Output tensor for the kernel
848
- * @param[in] nThreads Shape of the workgroup size for the kernel, must be of
849
- * rank 3.
850
- * @param[in] params Optional parameters for the kernel. If the kernel does not
851
- * have any parameters, use NoParam.
852
- * @example Kernel kernel = CreateKernel(ctx, shader, inputs, numInputs, output,
853
- * nThreads, params);
854
- */
855
- template <typename ParamsType = NoParam>
856
- Kernel CreateKernel (Context &ctx, const ShaderCode &shader,
857
- const Tensor *inputs, size_t numInputs,
858
- const Shape &nThreads,
859
- const ParamsType &params = ParamsType{}) {
860
- if constexpr (!IsNoParam<ParamsType>) {
861
- log (kDefLog , kInfo , " Using params of size %d bytes" , sizeof (ParamsType));
862
- return CreateKernel (ctx, shader, inputs, numInputs, nThreads,
863
- reinterpret_cast <const void *>(&params),
864
- sizeof (ParamsType));
865
- } else {
866
- log (kDefLog , kInfo , " No params" );
867
- return CreateKernel (ctx, shader, inputs, numInputs, nThreads,
868
- nullptr , 0 );
869
- }
870
- }
871
-
872
831
/* *
873
832
* @brief Overload which wraps the CreateKernel factory function to create a
874
833
* kernel on the GPU. This overload takes a static collection of input
@@ -892,17 +851,16 @@ Kernel CreateKernel(Context &ctx, const ShaderCode &shader,
892
851
const TensorList<numInputs> &inputs,
893
852
const Shape &nThreads,
894
853
const ParamsType &params = ParamsType{}) {
895
- // first .data gets the array, second .data() gets the pointer
896
- return CreateKernel<ParamsType>(ctx, shader, inputs.data .data (), numInputs,
897
- nThreads, params);
898
- }
899
-
900
- // Convenience wrapper: specialization for single input passed by reference
901
- template <typename ParamsType = NoParam>
902
- Kernel CreateKernel (Context &ctx, const ShaderCode &shader, const Tensor &input,
903
- const Shape &nThreads,
904
- const ParamsType &params = ParamsType{}) {
905
- return CreateKernel (ctx, shader, &input, 1 , nThreads, params);
854
+ if constexpr (!IsNoParam<ParamsType>) {
855
+ log (kDefLog , kInfo , " Using params of size %d bytes" , sizeof (ParamsType));
856
+ return CreateKernel (ctx, shader, inputs.data .data (), numInputs, nThreads,
857
+ reinterpret_cast <const void *>(&params),
858
+ sizeof (ParamsType));
859
+ } else {
860
+ log (kDefLog , kInfo , " No params" );
861
+ return CreateKernel (ctx, shader, inputs.data .data (), numInputs, nThreads,
862
+ nullptr , 0 );
863
+ }
906
864
}
907
865
908
866
/* *
0 commit comments