@@ -531,26 +531,43 @@ void ToGPU(GPUContext &ctx, const float *data, GPUTensor &tensor) {
                     tensor.data.size);
 }
 
+// Separate this out since WGPUCommandBuffer is destroyed upon submission
 WGPUCommandBuffer
-CreateCommandBuffer(GPUContext &ctx,
-                    const WGPUComputePipeline &computePipeline) {
+CreateCommandBuffer(WGPUDevice &device,
+                    const WGPUComputePipeline &computePipeline,
+                    const WGPUBindGroup &bindGroup, const ShaderCode &shader,
+                    const Shape &nThreads) {
   WGPUCommandBuffer commandBuffer;
-  WGPUCommandEncoder commandEncoder;
-  WGPUComputePassEncoder computePassEncoder;
-  commandEncoder = wgpuDeviceCreateCommandEncoder(ctx.device, nullptr);
-  computePassEncoder =
-      wgpuCommandEncoderBeginComputePass(commandEncoder, nullptr);
-  wgpuComputePassEncoderSetPipeline(computePassEncoder, computePipeline);
-  // TODO(avh): WIP - set bind group etc and finish the command buffer
-  // then split CreateKernel / CreateMultiKernel s.t.
-  // CommandBuffer is prepared per-dispatch since it's not reusable
+  log(kDefLog, kInfo, "Create command buffer 0x%x", commandBuffer);
+  {
+    WGPUCommandEncoder commandEncoder =
+        wgpuDeviceCreateCommandEncoder(device, nullptr);
+    WGPUComputePassEncoder computePassEncoder =
+        wgpuCommandEncoderBeginComputePass(commandEncoder, nullptr);
+    wgpuComputePassEncoderSetPipeline(computePassEncoder, computePipeline);
+    wgpuComputePassEncoderSetBindGroup(computePassEncoder, 0, bindGroup, 0,
+                                       nullptr);
+    log(kDefLog, kInfo, "Dispatching workgroups for number of threads = %s",
+        ToString(nThreads).c_str());
+    wgpuComputePassEncoderDispatchWorkgroups(
+        computePassEncoder,
+        /* # X workgroups */ (nThreads[0] + (shader.workgroupSize[0] - 1)) /
+            shader.workgroupSize[0],
+        /* # Y workgroups */ (nThreads[1] + (shader.workgroupSize[1] - 1)) /
+            shader.workgroupSize[1],
+        /* # Z workgroups */ (nThreads[2] + (shader.workgroupSize[2] - 1)) /
+            shader.workgroupSize[2]);
+    wgpuComputePassEncoderEnd(computePassEncoder);
+    commandBuffer = wgpuCommandEncoderFinish(commandEncoder, nullptr);
+    check(commandBuffer, "Create command buffer", __FILE__, __LINE__);
+  }
   return commandBuffer;
 }
 
 Kernel CreateKernel(GPUContext &ctx, const ShaderCode &shader,
                     const GPUTensor *inputs, size_t numInputs,
-                    const GPUTensor &output, const Shape &nThreads, const void *params,
-                    size_t paramsSize) {
+                    const GPUTensor &output, const Shape &nThreads,
+                    const void *params, size_t paramsSize) {
   assert(nThreads.rank == 3);
   WGPUDevice device = ctx.device;
   WGPUQueue queue = ctx.queue;
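The workgroup-count expressions in the dispatch above are a ceil-division: each axis of the thread grid is rounded up to a whole number of workgroups. A minimal standalone sketch of that arithmetic (plain integers, no WebGPU types; the function name is illustrative and not part of the change):

#include <cstddef>

// Round-up division: how many workgroups of size `workgroupSize` are needed
// to cover `nThreads` threads along one axis.
size_t NumWorkgroups(size_t nThreads, size_t workgroupSize) {
  return (nThreads + workgroupSize - 1) / workgroupSize;
}

// Example: 1000 threads with 256-wide workgroups -> (1000 + 255) / 256 = 4.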
@@ -710,35 +727,8 @@ Kernel CreateKernel(GPUContext &ctx, const ShaderCode &shader,
         wgpuDeviceCreateComputePipeline(device, &computePipelineDesc);
     check(op.computePipeline, "Create compute pipeline", __FILE__, __LINE__);
   }
-
-  log(kDefLog, kInfo, "Create the command encoder");
-  {
-    // After beginning the compute pass, use
-    // wgpuComputePassEncoderInsertDebugMarker instead of
-    // wgpuCommandEncoderInsertDebugMarker o/w the command encoder will be
-    // locked after wgpuComputePassEncoderEnd.
-    WGPUCommandEncoder commandEncoder =
-        wgpuDeviceCreateCommandEncoder(device, nullptr);
-    WGPUComputePassEncoder computePassEncoder =
-        wgpuCommandEncoderBeginComputePass(commandEncoder, nullptr);
-    wgpuComputePassEncoderSetPipeline(computePassEncoder, op.computePipeline);
-    wgpuComputePassEncoderSetBindGroup(computePassEncoder, 0, bindGroup, 0,
-                                       nullptr);
-
-    log(kDefLog, kInfo, "Dispatching workgroups for number of threads = %s",
-        ToString(nThreads).c_str());
-    wgpuComputePassEncoderDispatchWorkgroups(
-        computePassEncoder,
-        /* # X workgroups */ (nThreads[0] + (shader.workgroupSize[0] - 1)) /
-            shader.workgroupSize[0],
-        /* # Y workgroups */ (nThreads[1] + (shader.workgroupSize[1] - 1)) /
-            shader.workgroupSize[1],
-        /* # Z workgroups */ (nThreads[2] + (shader.workgroupSize[2] - 1)) /
-            shader.workgroupSize[2]);
-    wgpuComputePassEncoderEnd(computePassEncoder);
-    op.commandBuffer = wgpuCommandEncoderFinish(commandEncoder, nullptr);
-    check(op.commandBuffer, "Create command buffer", __FILE__, __LINE__);
-  }
+  op.commandBuffer = CreateCommandBuffer(device, op.computePipeline, bindGroup,
+                                         shader, nThreads);
 
   log(kDefLog, kInfo, "Initializing callbackData");
   op.callbackData = {op.readbackBuffer, op.outputSize, nullptr, &op.promise};
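As the comment added in the first hunk notes, a WGPUCommandBuffer is destroyed when it is submitted, so the factored-out CreateCommandBuffer can be called again before every dispatch. A usage sketch under that assumption, reusing the names visible in this diff (the SubmitNTimes helper itself is hypothetical, not part of the change):

// Hypothetical helper: re-encode and submit the single-use command buffer
// n times, relying on the declarations shown in this diff (GPUContext,
// Kernel, ShaderCode, Shape, CreateCommandBuffer).
void SubmitNTimes(GPUContext &ctx, Kernel &op, const WGPUBindGroup &bindGroup,
                  const ShaderCode &shader, const Shape &nThreads, int n) {
  for (int i = 0; i < n; ++i) {
    // Each WGPUCommandBuffer is consumed by wgpuQueueSubmit, so rebuild it.
    op.commandBuffer = CreateCommandBuffer(ctx.device, op.computePipeline,
                                           bindGroup, shader, nThreads);
    wgpuQueueSubmit(ctx.queue, 1, &op.commandBuffer);
  }
}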
@@ -767,7 +757,8 @@ Kernel CreateKernel(GPUContext &ctx, const ShaderCode &shader,
   }
 }
 
-// Convenience wrapper: inputs is GPUTensors static collection instead of a pointer
+// Convenience wrapper: inputs is GPUTensors static collection instead of a
+// pointer
 template <typename ParamsType = NoParam, size_t numInputs>
 Kernel CreateKernel(GPUContext &ctx, const ShaderCode &shader,
                     const GPUTensors<numInputs> &inputs,
@@ -817,7 +808,6 @@ MultiKernel CreateMultiKernel(GPUContext &ctx, const MultiKernelDesc &desc) {
   WGPUCommandEncoder commandEncoder =
       wgpuDeviceCreateCommandEncoder(device, nullptr);
   size_t bufferIndex = 0;
-  commandEncoder = wgpuDeviceCreateCommandEncoder(device, nullptr);
 
   // Iterate over all shaders in the pipeline
   for (size_t shaderIndex = 0; shaderIndex < desc.numShaders; ++shaderIndex) {