Commit 65c1db7
separate out CreateCommandBuffer
1 parent 0735132
1 file changed: gpu.h (34 additions, 44 deletions)

--- a/gpu.h
+++ b/gpu.h
@@ -531,26 +531,43 @@ void ToGPU(GPUContext &ctx, const float *data, GPUTensor &tensor) {
                      tensor.data.size);
 }
 
+// Separate this out since WGPUCommandBuffer is destroyed upon submission
 WGPUCommandBuffer
-CreateCommandBuffer(GPUContext &ctx,
-                    const WGPUComputePipeline &computePipeline) {
+CreateCommandBuffer(WGPUDevice &device,
+                    const WGPUComputePipeline &computePipeline,
+                    const WGPUBindGroup &bindGroup, const ShaderCode &shader,
+                    const Shape &nThreads) {
   WGPUCommandBuffer commandBuffer;
-  WGPUCommandEncoder commandEncoder;
-  WGPUComputePassEncoder computePassEncoder;
-  commandEncoder = wgpuDeviceCreateCommandEncoder(ctx.device, nullptr);
-  computePassEncoder =
-      wgpuCommandEncoderBeginComputePass(commandEncoder, nullptr);
-  wgpuComputePassEncoderSetPipeline(computePassEncoder, computePipeline);
-  // TODO(avh): WIP - set bind group etc and finish the command buffer
-  // then split CreateKernel / CreateMultiKernel s.t.
-  // CommandBuffer is prepared per-dispatch since it's not reusable
+  log(kDefLog, kInfo, "Create command buffer 0x%x", commandBuffer);
+  {
+    WGPUCommandEncoder commandEncoder =
+        wgpuDeviceCreateCommandEncoder(device, nullptr);
+    WGPUComputePassEncoder computePassEncoder =
+        wgpuCommandEncoderBeginComputePass(commandEncoder, nullptr);
+    wgpuComputePassEncoderSetPipeline(computePassEncoder, computePipeline);
+    wgpuComputePassEncoderSetBindGroup(computePassEncoder, 0, bindGroup, 0,
+                                       nullptr);
+    log(kDefLog, kInfo, "Dispatching workgroups for number of threads = %s",
+        ToString(nThreads).c_str());
+    wgpuComputePassEncoderDispatchWorkgroups(
+        computePassEncoder,
+        /* # X workgroups */ (nThreads[0] + (shader.workgroupSize[0] - 1)) /
+            shader.workgroupSize[0],
+        /* # Y workgroups */ (nThreads[1] + (shader.workgroupSize[1] - 1)) /
+            shader.workgroupSize[1],
+        /* # Z workgroups */ (nThreads[2] + (shader.workgroupSize[2] - 1)) /
+            shader.workgroupSize[2]);
+    wgpuComputePassEncoderEnd(computePassEncoder);
+    commandBuffer = wgpuCommandEncoderFinish(commandEncoder, nullptr);
+    check(commandBuffer, "Create command buffer", __FILE__, __LINE__);
+  }
   return commandBuffer;
 }
 
 Kernel CreateKernel(GPUContext &ctx, const ShaderCode &shader,
                     const GPUTensor *inputs, size_t numInputs,
-                    const GPUTensor &output, const Shape &nThreads, const void *params,
-                    size_t paramsSize) {
+                    const GPUTensor &output, const Shape &nThreads,
+                    const void *params, size_t paramsSize) {
   assert(nThreads.rank == 3);
   WGPUDevice device = ctx.device;
   WGPUQueue queue = ctx.queue;
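
The workgroup counts in this dispatch are ceiling divisions, so every thread in nThreads is covered even when a dimension is not an exact multiple of the shader's workgroup size. A minimal standalone sketch of the same arithmetic (the cdiv helper and the sample numbers are illustrative, not part of gpu.h):

```cpp
#include <cassert>
#include <cstddef>

// Ceiling division: workgroups needed to cover n threads when each
// workgroup spans `size` threads along one dimension.
constexpr size_t cdiv(size_t n, size_t size) { return (n + size - 1) / size; }

int main() {
  // 1000 threads at workgroup size 256: three full groups cover 768
  // threads, so a fourth is dispatched for the remaining 232.
  assert(cdiv(1000, 256) == 4);
  assert(cdiv(1024, 256) == 4); // exact multiple, no extra group
  return 0;
}
```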
@@ -710,35 +727,8 @@ Kernel CreateKernel(GPUContext &ctx, const ShaderCode &shader,
         wgpuDeviceCreateComputePipeline(device, &computePipelineDesc);
     check(op.computePipeline, "Create compute pipeline", __FILE__, __LINE__);
   }
-
-  log(kDefLog, kInfo, "Create the command encoder");
-  {
-    // After beginning the compute pass, use
-    // wgpuComputePassEncoderInsertDebugMarker instead of
-    // wgpuCommandEncoderInsertDebugMarker o/w the command encoder will be
-    // locked after wgpuComputePassEncoderEnd.
-    WGPUCommandEncoder commandEncoder =
-        wgpuDeviceCreateCommandEncoder(device, nullptr);
-    WGPUComputePassEncoder computePassEncoder =
-        wgpuCommandEncoderBeginComputePass(commandEncoder, nullptr);
-    wgpuComputePassEncoderSetPipeline(computePassEncoder, op.computePipeline);
-    wgpuComputePassEncoderSetBindGroup(computePassEncoder, 0, bindGroup, 0,
-                                       nullptr);
-
-    log(kDefLog, kInfo, "Dispatching workgroups for number of threads = %s",
-        ToString(nThreads).c_str());
-    wgpuComputePassEncoderDispatchWorkgroups(
-        computePassEncoder,
-        /* # X workgroups */ (nThreads[0] + (shader.workgroupSize[0] - 1)) /
-            shader.workgroupSize[0],
-        /* # Y workgroups */ (nThreads[1] + (shader.workgroupSize[1] - 1)) /
-            shader.workgroupSize[1],
-        /* # Z workgroups */ (nThreads[2] + (shader.workgroupSize[2] - 1)) /
-            shader.workgroupSize[2]);
-    wgpuComputePassEncoderEnd(computePassEncoder);
-    op.commandBuffer = wgpuCommandEncoderFinish(commandEncoder, nullptr);
-    check(op.commandBuffer, "Create command buffer", __FILE__, __LINE__);
-  }
+  op.commandBuffer = CreateCommandBuffer(device, op.computePipeline, bindGroup,
+                                         shader, nThreads);
 
   log(kDefLog, kInfo, "Initializing callbackData");
   op.callbackData = {op.readbackBuffer, op.outputSize, nullptr, &op.promise};
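
Since a WGPUCommandBuffer is destroyed when it is submitted, extracting CreateCommandBuffer lets callers re-record a fresh buffer per dispatch instead of reusing a stale handle. A hedged sketch of such a loop, using the Kernel and GPUContext fields visible in this diff; the DispatchNTimes wrapper itself is hypothetical:

```cpp
#include "gpu.h" // assumed header providing GPUContext, Kernel, ShaderCode, Shape

// Hypothetical helper: submit the same kernel n times, re-creating the
// single-use command buffer before every submit with the
// CreateCommandBuffer introduced in this commit.
void DispatchNTimes(GPUContext &ctx, Kernel &op, WGPUBindGroup &bindGroup,
                    const ShaderCode &shader, const Shape &nThreads, int n) {
  for (int i = 0; i < n; ++i) {
    op.commandBuffer = CreateCommandBuffer(ctx.device, op.computePipeline,
                                           bindGroup, shader, nThreads);
    wgpuQueueSubmit(ctx.queue, 1, &op.commandBuffer);
  }
}
```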
@@ -767,7 +757,8 @@ Kernel CreateKernel(GPUContext &ctx, const ShaderCode &shader,
   }
 }
 
-// Convenience wrapper: inputs is GPUTensors static collection instead of a pointer
+// Convenience wrapper: inputs is GPUTensors static collection instead of a
+// pointer
 template <typename ParamsType = NoParam, size_t numInputs>
 Kernel CreateKernel(GPUContext &ctx, const ShaderCode &shader,
                     const GPUTensors<numInputs> &inputs,
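
Only the replacement of the (pointer, count) pair with a GPUTensors<numInputs> collection is visible in this hunk; a usage sketch under the assumption that the remaining parameters mirror the raw overload above (output and nThreads, with params defaulting via NoParam):

```cpp
// Assumed call shape; the parameters past `inputs` are not shown in this
// hunk and are inferred from the pointer-based CreateKernel overload.
Kernel op = CreateKernel(ctx, shader, GPUTensors<2>{a, b}, output, nThreads);
```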
@@ -817,7 +808,6 @@ MultiKernel CreateMultiKernel(GPUContext &ctx, const MultiKernelDesc &desc) {
   WGPUCommandEncoder commandEncoder =
       wgpuDeviceCreateCommandEncoder(device, nullptr);
   size_t bufferIndex = 0;
-  commandEncoder = wgpuDeviceCreateCommandEncoder(device, nullptr);
 
   // Iterate over all shaders in the pipeline
   for (size_t shaderIndex = 0; shaderIndex < desc.numShaders; ++shaderIndex) {
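
Worth noting on this hunk: the deleted line was a redundant second call to wgpuDeviceCreateCommandEncoder. commandEncoder is already initialized two lines earlier, so the duplicate assignment orphaned the first encoder.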
