diff --git a/CHANGELOG.md b/CHANGELOG.md index 180ae0d1ea..fdff27d2fd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -205,6 +205,7 @@ By @wumpf in [#7144](https://github.com/gfx-rs/wgpu/pull/7144) - If you use Binding Arrays in a bind group, you may not use Dynamic Offset Buffers or Uniform Buffers in that bind group. By @cwfitzgerald in [#6811](https://github.com/gfx-rs/wgpu/pull/6811) - Rename `instance_id` and `instance_custom_index` to `instance_index` and `instance_custom_data` by @Vecvec in [#6780](https://github.com/gfx-rs/wgpu/pull/6780) +- Add mesh shader support to `wgpu` (currently vulkan + spirv-passthrough only). By @SupaMaggie70Incorporated in [#7345](https://github.com/gfx-rs/wgpu/pull/7345) #### Naga diff --git a/examples/features/src/lib.rs b/examples/features/src/lib.rs index f56f19c62f..7193306e81 100644 --- a/examples/features/src/lib.rs +++ b/examples/features/src/lib.rs @@ -12,6 +12,7 @@ pub mod hello_synchronization; pub mod hello_triangle; pub mod hello_windows; pub mod hello_workgroups; +pub mod mesh_shader; pub mod mipmap; pub mod msaa_line; pub mod multiple_render_targets; diff --git a/examples/features/src/main.rs b/examples/features/src/main.rs index d803ba249d..97790fe355 100644 --- a/examples/features/src/main.rs +++ b/examples/features/src/main.rs @@ -176,6 +176,12 @@ const EXAMPLES: &[ExampleDesc] = &[ webgl: false, // No Ray-tracing extensions webgpu: false, // No Ray-tracing extensions (yet) }, + ExampleDesc { + name: "mesh_shader", + function: wgpu_examples::mesh_shader::main, + webgl: false, + webgpu: false, + }, ]; fn get_example_name() -> Option { diff --git a/examples/features/src/mesh_shader/README.md b/examples/features/src/mesh_shader/README.md new file mode 100644 index 0000000000..9b57d3e490 --- /dev/null +++ b/examples/features/src/mesh_shader/README.md @@ -0,0 +1,9 @@ +# mesh_shader + +This example renders a triangle to a window with mesh shaders, while showcasing most mesh shader related features(task shaders, 
payloads, per primitive data). + +## To Run + +``` +cargo run --bin wgpu-examples mesh_shader +``` \ No newline at end of file diff --git a/examples/features/src/mesh_shader/mod.rs b/examples/features/src/mesh_shader/mod.rs new file mode 100644 index 0000000000..777018a609 --- /dev/null +++ b/examples/features/src/mesh_shader/mod.rs @@ -0,0 +1,130 @@ +use std::{io::Write, process::Stdio}; + +// Same as in mesh shader tests +fn compile_spv_asm(device: &wgpu::Device, data: &[u8]) -> wgpu::ShaderModule { + let cmd = std::process::Command::new("spirv-as") + .args(["-", "-o", "-"]) + .stdin(Stdio::piped()) + .stdout(Stdio::piped()) + .spawn() + .expect("Failed to call spirv-as"); + cmd.stdin.as_ref().unwrap().write_all(data).unwrap(); + let output = cmd.wait_with_output().expect("Error waiting for spirv-as"); + assert!(output.status.success()); + unsafe { + device.create_shader_module_spirv(&wgpu::ShaderModuleDescriptorSpirV { + label: None, + source: wgpu::util::make_spirv_raw(&output.stdout), + }) + } +} + +pub struct Example { + pipeline: wgpu::RenderPipeline, +} +impl crate::framework::Example for Example { + fn init( + config: &wgpu::SurfaceConfiguration, + _adapter: &wgpu::Adapter, + device: &wgpu::Device, + _queue: &wgpu::Queue, + ) -> Self { + let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor { + label: None, + bind_group_layouts: &[], + push_constant_ranges: &[], + }); + let (ts, ms, fs) = ( + compile_spv_asm(device, include_bytes!("shader.task.spv.asm")), + compile_spv_asm(device, include_bytes!("shader.mesh.spv.asm")), + compile_spv_asm(device, include_bytes!("shader.frag.spv.asm")), + ); + let pipeline = device.create_mesh_pipeline(&wgpu::MeshPipelineDescriptor { + label: None, + layout: Some(&pipeline_layout), + task: Some(wgpu::TaskState { + module: &ts, + entry_point: Some("main"), + compilation_options: Default::default(), + }), + mesh: wgpu::MeshState { + module: &ms, + entry_point: Some("main"), + compilation_options: 
Default::default(), + }, + fragment: Some(wgpu::FragmentState { + module: &fs, + entry_point: Some("main"), + compilation_options: Default::default(), + targets: &[Some(config.view_formats[0].into())], + }), + primitive: wgpu::PrimitiveState { + cull_mode: Some(wgpu::Face::Back), + ..Default::default() + }, + depth_stencil: None, + multisample: Default::default(), + multiview: None, + cache: None, + }); + Self { pipeline } + } + fn optional_features() -> wgpu::Features { + wgpu::Features::empty() + } + fn render(&mut self, view: &wgpu::TextureView, device: &wgpu::Device, queue: &wgpu::Queue) { + let mut encoder = + device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None }); + { + let mut rpass = encoder.begin_render_pass(&wgpu::RenderPassDescriptor { + label: None, + color_attachments: &[Some(wgpu::RenderPassColorAttachment { + view, + resolve_target: None, + ops: wgpu::Operations { + load: wgpu::LoadOp::Clear(wgpu::Color { + r: 0.1, + g: 0.2, + b: 0.3, + a: 1.0, + }), + store: wgpu::StoreOp::Store, + }, + })], + depth_stencil_attachment: None, + timestamp_writes: None, + occlusion_query_set: None, + }); + rpass.push_debug_group("Prepare data for draw."); + rpass.set_pipeline(&self.pipeline); + rpass.pop_debug_group(); + rpass.insert_debug_marker("Draw!"); + rpass.draw_mesh_tasks(1, 1, 1); + } + queue.submit(Some(encoder.finish())); + } + fn required_downlevel_capabilities() -> wgpu::DownlevelCapabilities { + Default::default() + } + fn required_features() -> wgpu::Features { + wgpu::Features::EXPERIMENTAL_MESH_SHADER | wgpu::Features::SPIRV_SHADER_PASSTHROUGH + } + fn required_limits() -> wgpu::Limits { + Default::default() + } + fn resize( + &mut self, + _config: &wgpu::SurfaceConfiguration, + _device: &wgpu::Device, + _queue: &wgpu::Queue, + ) { + // empty + } + fn update(&mut self, _event: winit::event::WindowEvent) { + // empty + } +} + +pub fn main() { + crate::framework::run::("mesh_shader"); +} diff --git 
a/examples/features/src/mesh_shader/shader.frag b/examples/features/src/mesh_shader/shader.frag new file mode 100644 index 0000000000..49624990f1 --- /dev/null +++ b/examples/features/src/mesh_shader/shader.frag @@ -0,0 +1,11 @@ +#version 450 +#extension GL_EXT_mesh_shader : require + +in VertexInput { layout(location = 0) vec4 color; } +vertexInput; +layout(location = 1) perprimitiveEXT in PrimitiveInput { vec4 colorMask; } +primitiveInput; + +layout(location = 0) out vec4 fragColor; + +void main() { fragColor = vertexInput.color * primitiveInput.colorMask; } \ No newline at end of file diff --git a/examples/features/src/mesh_shader/shader.frag.spv.asm b/examples/features/src/mesh_shader/shader.frag.spv.asm new file mode 100644 index 0000000000..88ca008242 --- /dev/null +++ b/examples/features/src/mesh_shader/shader.frag.spv.asm @@ -0,0 +1,53 @@ +; SPIR-V +; Version: 1.5 +; Generator: Khronos Glslang Reference Front End; 11 +; Bound: 24 +; Schema: 0 + OpCapability Shader + OpCapability MeshShadingEXT + OpExtension "SPV_EXT_mesh_shader" + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %main "main" %fragColor %vertexInput %primitiveInput + OpExecutionMode %main OriginUpperLeft + OpSource GLSL 450 + OpSourceExtension "GL_EXT_mesh_shader" + OpName %main "main" + OpName %fragColor "fragColor" + OpName %VertexInput "VertexInput" + OpMemberName %VertexInput 0 "color" + OpName %vertexInput "vertexInput" + OpName %PrimitiveInput "PrimitiveInput" + OpMemberName %PrimitiveInput 0 "colorMask" + OpName %primitiveInput "primitiveInput" + OpDecorate %fragColor Location 0 + OpDecorate %VertexInput Block + OpMemberDecorate %VertexInput 0 Location 0 + OpDecorate %PrimitiveInput Block + OpMemberDecorate %PrimitiveInput 0 PerPrimitiveEXT + OpDecorate %primitiveInput Location 1 + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %v4float = OpTypeVector %float 4 +%_ptr_Output_v4float = OpTypePointer Output %v4float + 
%fragColor = OpVariable %_ptr_Output_v4float Output +%VertexInput = OpTypeStruct %v4float +%_ptr_Input_VertexInput = OpTypePointer Input %VertexInput +%vertexInput = OpVariable %_ptr_Input_VertexInput Input + %int = OpTypeInt 32 1 + %int_0 = OpConstant %int 0 +%_ptr_Input_v4float = OpTypePointer Input %v4float +%PrimitiveInput = OpTypeStruct %v4float +%_ptr_Input_PrimitiveInput = OpTypePointer Input %PrimitiveInput +%primitiveInput = OpVariable %_ptr_Input_PrimitiveInput Input + %main = OpFunction %void None %3 + %5 = OpLabel + %16 = OpAccessChain %_ptr_Input_v4float %vertexInput %int_0 + %17 = OpLoad %v4float %16 + %21 = OpAccessChain %_ptr_Input_v4float %primitiveInput %int_0 + %22 = OpLoad %v4float %21 + %23 = OpFMul %v4float %17 %22 + OpStore %fragColor %23 + OpReturn + OpFunctionEnd diff --git a/examples/features/src/mesh_shader/shader.mesh b/examples/features/src/mesh_shader/shader.mesh new file mode 100644 index 0000000000..8805d39317 --- /dev/null +++ b/examples/features/src/mesh_shader/shader.mesh @@ -0,0 +1,36 @@ +#version 450 +#extension GL_EXT_mesh_shader : require + +const vec4[3] positions = {vec4(0., 1.0, 0., 1.0), vec4(-1.0, -1.0, 0., 1.0), + vec4(1.0, -1.0, 0., 1.0)}; +const vec4[3] colors = {vec4(0., 1., 0., 1.), vec4(0., 0., 1., 1.), + vec4(1., 0., 0., 1.)}; + +layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in; +struct PayloadData { + vec4 colorMask; + bool visible; +}; +taskPayloadSharedEXT PayloadData payloadData; + +out VertexOutput { layout(location = 0) vec4 color; } +vertexOutput[]; +layout(location = 1) perprimitiveEXT out PrimitiveOutput { vec4 colorMask; } +primitiveOutput[]; + +shared uint sharedData; + +layout(triangles, max_vertices = 3, max_primitives = 1) out; +void main() { + sharedData = 5; + SetMeshOutputsEXT(3, 1); + gl_MeshVerticesEXT[0].gl_Position = positions[0]; + gl_MeshVerticesEXT[1].gl_Position = positions[1]; + gl_MeshVerticesEXT[2].gl_Position = positions[2]; + vertexOutput[0].color = colors[0] * 
payloadData.colorMask; + vertexOutput[1].color = colors[1] * payloadData.colorMask; + vertexOutput[2].color = colors[2] * payloadData.colorMask; + gl_PrimitiveTriangleIndicesEXT[gl_LocalInvocationIndex] = uvec3(0, 1, 2); + primitiveOutput[0].colorMask = vec4(1.0, 0.0, 1.0, 1.0); + gl_MeshPrimitivesEXT[0].gl_CullPrimitiveEXT = !payloadData.visible; +} \ No newline at end of file diff --git a/examples/features/src/mesh_shader/shader.mesh.spv.asm b/examples/features/src/mesh_shader/shader.mesh.spv.asm new file mode 100644 index 0000000000..639a9a8d8c --- /dev/null +++ b/examples/features/src/mesh_shader/shader.mesh.spv.asm @@ -0,0 +1,164 @@ +; SPIR-V +; Version: 1.5 +; Generator: Khronos Glslang Reference Front End; 11 +; Bound: 89 +; Schema: 0 + OpCapability MeshShadingEXT + OpExtension "SPV_EXT_mesh_shader" + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint MeshEXT %main "main" %sharedData %gl_MeshVerticesEXT %vertexOutput %payloadData %gl_PrimitiveTriangleIndicesEXT %gl_LocalInvocationIndex %primitiveOutput %gl_MeshPrimitivesEXT + OpExecutionMode %main LocalSize 1 1 1 + OpExecutionMode %main OutputVertices 3 + OpExecutionMode %main OutputPrimitivesEXT 1 + OpExecutionMode %main OutputTrianglesEXT + OpSource GLSL 450 + OpSourceExtension "GL_EXT_mesh_shader" + OpName %main "main" + OpName %sharedData "sharedData" + OpName %gl_MeshPerVertexEXT "gl_MeshPerVertexEXT" + OpMemberName %gl_MeshPerVertexEXT 0 "gl_Position" + OpMemberName %gl_MeshPerVertexEXT 1 "gl_PointSize" + OpMemberName %gl_MeshPerVertexEXT 2 "gl_ClipDistance" + OpMemberName %gl_MeshPerVertexEXT 3 "gl_CullDistance" + OpName %gl_MeshVerticesEXT "gl_MeshVerticesEXT" + OpName %VertexOutput "VertexOutput" + OpMemberName %VertexOutput 0 "color" + OpName %vertexOutput "vertexOutput" + OpName %PayloadData "PayloadData" + OpMemberName %PayloadData 0 "colorMask" + OpMemberName %PayloadData 1 "visible" + OpName %payloadData "payloadData" + OpName %gl_PrimitiveTriangleIndicesEXT 
"gl_PrimitiveTriangleIndicesEXT" + OpName %gl_LocalInvocationIndex "gl_LocalInvocationIndex" + OpName %PrimitiveOutput "PrimitiveOutput" + OpMemberName %PrimitiveOutput 0 "colorMask" + OpName %primitiveOutput "primitiveOutput" + OpName %gl_MeshPerPrimitiveEXT "gl_MeshPerPrimitiveEXT" + OpMemberName %gl_MeshPerPrimitiveEXT 0 "gl_PrimitiveID" + OpMemberName %gl_MeshPerPrimitiveEXT 1 "gl_Layer" + OpMemberName %gl_MeshPerPrimitiveEXT 2 "gl_ViewportIndex" + OpMemberName %gl_MeshPerPrimitiveEXT 3 "gl_CullPrimitiveEXT" + OpName %gl_MeshPrimitivesEXT "gl_MeshPrimitivesEXT" + OpDecorate %gl_MeshPerVertexEXT Block + OpMemberDecorate %gl_MeshPerVertexEXT 0 BuiltIn Position + OpMemberDecorate %gl_MeshPerVertexEXT 1 BuiltIn PointSize + OpMemberDecorate %gl_MeshPerVertexEXT 2 BuiltIn ClipDistance + OpMemberDecorate %gl_MeshPerVertexEXT 3 BuiltIn CullDistance + OpDecorate %VertexOutput Block + OpMemberDecorate %VertexOutput 0 Location 0 + OpDecorate %gl_PrimitiveTriangleIndicesEXT BuiltIn PrimitiveTriangleIndicesEXT + OpDecorate %gl_LocalInvocationIndex BuiltIn LocalInvocationIndex + OpDecorate %PrimitiveOutput Block + OpMemberDecorate %PrimitiveOutput 0 PerPrimitiveEXT + OpDecorate %primitiveOutput Location 1 + OpDecorate %gl_MeshPerPrimitiveEXT Block + OpMemberDecorate %gl_MeshPerPrimitiveEXT 0 BuiltIn PrimitiveId + OpMemberDecorate %gl_MeshPerPrimitiveEXT 0 PerPrimitiveEXT + OpMemberDecorate %gl_MeshPerPrimitiveEXT 1 BuiltIn Layer + OpMemberDecorate %gl_MeshPerPrimitiveEXT 1 PerPrimitiveEXT + OpMemberDecorate %gl_MeshPerPrimitiveEXT 2 BuiltIn ViewportIndex + OpMemberDecorate %gl_MeshPerPrimitiveEXT 2 PerPrimitiveEXT + OpMemberDecorate %gl_MeshPerPrimitiveEXT 3 BuiltIn CullPrimitiveEXT + OpMemberDecorate %gl_MeshPerPrimitiveEXT 3 PerPrimitiveEXT + OpDecorate %gl_WorkGroupSize BuiltIn WorkgroupSize + %void = OpTypeVoid + %3 = OpTypeFunction %void + %uint = OpTypeInt 32 0 +%_ptr_Workgroup_uint = OpTypePointer Workgroup %uint + %sharedData = OpVariable %_ptr_Workgroup_uint 
Workgroup + %uint_5 = OpConstant %uint 5 + %uint_3 = OpConstant %uint 3 + %uint_1 = OpConstant %uint 1 + %float = OpTypeFloat 32 + %v4float = OpTypeVector %float 4 +%_arr_float_uint_1 = OpTypeArray %float %uint_1 +%gl_MeshPerVertexEXT = OpTypeStruct %v4float %float %_arr_float_uint_1 %_arr_float_uint_1 +%_arr_gl_MeshPerVertexEXT_uint_3 = OpTypeArray %gl_MeshPerVertexEXT %uint_3 +%_ptr_Output__arr_gl_MeshPerVertexEXT_uint_3 = OpTypePointer Output %_arr_gl_MeshPerVertexEXT_uint_3 +%gl_MeshVerticesEXT = OpVariable %_ptr_Output__arr_gl_MeshPerVertexEXT_uint_3 Output + %int = OpTypeInt 32 1 + %int_0 = OpConstant %int 0 + %float_0 = OpConstant %float 0 + %float_1 = OpConstant %float 1 + %23 = OpConstantComposite %v4float %float_0 %float_1 %float_0 %float_1 +%_ptr_Output_v4float = OpTypePointer Output %v4float + %int_1 = OpConstant %int 1 + %float_n1 = OpConstant %float -1 + %28 = OpConstantComposite %v4float %float_n1 %float_n1 %float_0 %float_1 + %int_2 = OpConstant %int 2 + %31 = OpConstantComposite %v4float %float_1 %float_n1 %float_0 %float_1 +%VertexOutput = OpTypeStruct %v4float +%_arr_VertexOutput_uint_3 = OpTypeArray %VertexOutput %uint_3 +%_ptr_Output__arr_VertexOutput_uint_3 = OpTypePointer Output %_arr_VertexOutput_uint_3 +%vertexOutput = OpVariable %_ptr_Output__arr_VertexOutput_uint_3 Output + %bool = OpTypeBool +%PayloadData = OpTypeStruct %v4float %bool +%_ptr_TaskPayloadWorkgroupEXT_PayloadData = OpTypePointer TaskPayloadWorkgroupEXT %PayloadData +%payloadData = OpVariable %_ptr_TaskPayloadWorkgroupEXT_PayloadData TaskPayloadWorkgroupEXT +%_ptr_TaskPayloadWorkgroupEXT_v4float = OpTypePointer TaskPayloadWorkgroupEXT %v4float + %46 = OpConstantComposite %v4float %float_0 %float_0 %float_1 %float_1 + %51 = OpConstantComposite %v4float %float_1 %float_0 %float_0 %float_1 + %v3uint = OpTypeVector %uint 3 +%_arr_v3uint_uint_1 = OpTypeArray %v3uint %uint_1 +%_ptr_Output__arr_v3uint_uint_1 = OpTypePointer Output %_arr_v3uint_uint_1 
+%gl_PrimitiveTriangleIndicesEXT = OpVariable %_ptr_Output__arr_v3uint_uint_1 Output +%_ptr_Input_uint = OpTypePointer Input %uint +%gl_LocalInvocationIndex = OpVariable %_ptr_Input_uint Input + %uint_0 = OpConstant %uint 0 + %uint_2 = OpConstant %uint 2 + %65 = OpConstantComposite %v3uint %uint_0 %uint_1 %uint_2 +%_ptr_Output_v3uint = OpTypePointer Output %v3uint +%PrimitiveOutput = OpTypeStruct %v4float +%_arr_PrimitiveOutput_uint_1 = OpTypeArray %PrimitiveOutput %uint_1 +%_ptr_Output__arr_PrimitiveOutput_uint_1 = OpTypePointer Output %_arr_PrimitiveOutput_uint_1 +%primitiveOutput = OpVariable %_ptr_Output__arr_PrimitiveOutput_uint_1 Output + %72 = OpConstantComposite %v4float %float_1 %float_0 %float_1 %float_1 +%gl_MeshPerPrimitiveEXT = OpTypeStruct %int %int %int %bool +%_arr_gl_MeshPerPrimitiveEXT_uint_1 = OpTypeArray %gl_MeshPerPrimitiveEXT %uint_1 +%_ptr_Output__arr_gl_MeshPerPrimitiveEXT_uint_1 = OpTypePointer Output %_arr_gl_MeshPerPrimitiveEXT_uint_1 +%gl_MeshPrimitivesEXT = OpVariable %_ptr_Output__arr_gl_MeshPerPrimitiveEXT_uint_1 Output + %int_3 = OpConstant %int 3 +%_ptr_TaskPayloadWorkgroupEXT_bool = OpTypePointer TaskPayloadWorkgroupEXT %bool +%_ptr_Output_bool = OpTypePointer Output %bool +%_arr_v4float_uint_3 = OpTypeArray %v4float %uint_3 + %86 = OpConstantComposite %_arr_v4float_uint_3 %23 %28 %31 + %87 = OpConstantComposite %_arr_v4float_uint_3 %23 %46 %51 +%gl_WorkGroupSize = OpConstantComposite %v3uint %uint_1 %uint_1 %uint_1 + %main = OpFunction %void None %3 + %5 = OpLabel + OpStore %sharedData %uint_5 + OpSetMeshOutputsEXT %uint_3 %uint_1 + %25 = OpAccessChain %_ptr_Output_v4float %gl_MeshVerticesEXT %int_0 %int_0 + OpStore %25 %23 + %29 = OpAccessChain %_ptr_Output_v4float %gl_MeshVerticesEXT %int_1 %int_0 + OpStore %29 %28 + %32 = OpAccessChain %_ptr_Output_v4float %gl_MeshVerticesEXT %int_2 %int_0 + OpStore %32 %31 + %42 = OpAccessChain %_ptr_TaskPayloadWorkgroupEXT_v4float %payloadData %int_0 + %43 = OpLoad %v4float %42 + %44 = OpFMul 
%v4float %23 %43 + %45 = OpAccessChain %_ptr_Output_v4float %vertexOutput %int_0 %int_0 + OpStore %45 %44 + %47 = OpAccessChain %_ptr_TaskPayloadWorkgroupEXT_v4float %payloadData %int_0 + %48 = OpLoad %v4float %47 + %49 = OpFMul %v4float %46 %48 + %50 = OpAccessChain %_ptr_Output_v4float %vertexOutput %int_1 %int_0 + OpStore %50 %49 + %52 = OpAccessChain %_ptr_TaskPayloadWorkgroupEXT_v4float %payloadData %int_0 + %53 = OpLoad %v4float %52 + %54 = OpFMul %v4float %51 %53 + %55 = OpAccessChain %_ptr_Output_v4float %vertexOutput %int_2 %int_0 + OpStore %55 %54 + %62 = OpLoad %uint %gl_LocalInvocationIndex + %67 = OpAccessChain %_ptr_Output_v3uint %gl_PrimitiveTriangleIndicesEXT %62 + OpStore %67 %65 + %73 = OpAccessChain %_ptr_Output_v4float %primitiveOutput %int_0 %int_0 + OpStore %73 %72 + %80 = OpAccessChain %_ptr_TaskPayloadWorkgroupEXT_bool %payloadData %int_1 + %81 = OpLoad %bool %80 + %82 = OpLogicalNot %bool %81 + %84 = OpAccessChain %_ptr_Output_bool %gl_MeshPrimitivesEXT %int_0 %int_3 + OpStore %84 %82 + OpReturn + OpFunctionEnd diff --git a/examples/features/src/mesh_shader/shader.task b/examples/features/src/mesh_shader/shader.task new file mode 100644 index 0000000000..6c766bc83a --- /dev/null +++ b/examples/features/src/mesh_shader/shader.task @@ -0,0 +1,16 @@ +#version 450 +#extension GL_EXT_mesh_shader : require + +layout(local_size_x = 4, local_size_y = 1, local_size_z = 1) in; + +struct TaskPayload { + vec4 colorMask; + bool visible; +}; +taskPayloadSharedEXT TaskPayload taskPayload; + +void main() { + taskPayload.colorMask = vec4(1.0, 1.0, 0.0, 1.0); + taskPayload.visible = true; + EmitMeshTasksEXT(3, 1, 1); +} \ No newline at end of file diff --git a/examples/features/src/mesh_shader/shader.task.spv.asm b/examples/features/src/mesh_shader/shader.task.spv.asm new file mode 100644 index 0000000000..4e3416941f --- /dev/null +++ b/examples/features/src/mesh_shader/shader.task.spv.asm @@ -0,0 +1,50 @@ +; SPIR-V +; Version: 1.5 +; Generator: Khronos 
Glslang Reference Front End; 11 +; Bound: 30 +; Schema: 0 + OpCapability MeshShadingEXT + OpExtension "SPV_EXT_mesh_shader" + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint TaskEXT %main "main" %taskPayload + OpExecutionMode %main LocalSize 4 1 1 + OpSource GLSL 450 + OpSourceExtension "GL_EXT_mesh_shader" + OpName %main "main" + OpName %TaskPayload "TaskPayload" + OpMemberName %TaskPayload 0 "colorMask" + OpMemberName %TaskPayload 1 "visible" + OpName %taskPayload "taskPayload" + OpDecorate %gl_WorkGroupSize BuiltIn WorkgroupSize + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %v4float = OpTypeVector %float 4 + %bool = OpTypeBool +%TaskPayload = OpTypeStruct %v4float %bool +%_ptr_TaskPayloadWorkgroupEXT_TaskPayload = OpTypePointer TaskPayloadWorkgroupEXT %TaskPayload +%taskPayload = OpVariable %_ptr_TaskPayloadWorkgroupEXT_TaskPayload TaskPayloadWorkgroupEXT + %int = OpTypeInt 32 1 + %int_0 = OpConstant %int 0 + %float_1 = OpConstant %float 1 + %float_0 = OpConstant %float 0 + %16 = OpConstantComposite %v4float %float_1 %float_1 %float_0 %float_1 +%_ptr_TaskPayloadWorkgroupEXT_v4float = OpTypePointer TaskPayloadWorkgroupEXT %v4float + %int_1 = OpConstant %int 1 + %true = OpConstantTrue %bool +%_ptr_TaskPayloadWorkgroupEXT_bool = OpTypePointer TaskPayloadWorkgroupEXT %bool + %uint = OpTypeInt 32 0 + %uint_3 = OpConstant %uint 3 + %uint_1 = OpConstant %uint 1 + %v3uint = OpTypeVector %uint 3 + %uint_4 = OpConstant %uint 4 +%gl_WorkGroupSize = OpConstantComposite %v3uint %uint_4 %uint_1 %uint_1 + %main = OpFunction %void None %3 + %5 = OpLabel + %18 = OpAccessChain %_ptr_TaskPayloadWorkgroupEXT_v4float %taskPayload %int_0 + OpStore %18 %16 + %22 = OpAccessChain %_ptr_TaskPayloadWorkgroupEXT_bool %taskPayload %int_1 + OpStore %22 %true + OpEmitMeshTasksEXT %uint_3 %uint_1 %uint_1 %taskPayload + OpFunctionEnd diff --git a/examples/standalone/03_custom_backend/src/custom.rs 
b/examples/standalone/03_custom_backend/src/custom.rs index 6a30b8f7d3..49bd9887fc 100644 --- a/examples/standalone/03_custom_backend/src/custom.rs +++ b/examples/standalone/03_custom_backend/src/custom.rs @@ -161,6 +161,13 @@ impl DeviceInterface for CustomDevice { unimplemented!() } + fn create_mesh_pipeline( + &self, + _desc: &wgpu::MeshPipelineDescriptor<'_>, + ) -> wgpu::custom::DispatchRenderPipeline { + unimplemented!() + } + fn create_compute_pipeline( &self, _desc: &wgpu::ComputePipelineDescriptor<'_>, diff --git a/player/src/lib.rs b/player/src/lib.rs index b1934036ba..563b66853f 100644 --- a/player/src/lib.rs +++ b/player/src/lib.rs @@ -377,6 +377,24 @@ impl GlobalPlay for wgc::global::Global { panic!("{e}"); } } + Action::CreateMeshPipeline { + id, + desc, + implicit_context, + } => { + let implicit_ids = + implicit_context + .as_ref() + .map(|ic| wgc::device::ImplicitPipelineIds { + root_id: ic.root_id, + group_ids: &ic.group_ids, + }); + let (_, error) = + self.device_create_mesh_pipeline(device, &desc, Some(id), implicit_ids); + if let Some(e) = error { + panic!("{e}"); + } + } Action::DestroyRenderPipeline(id) => { self.render_pipeline_drop(id); } diff --git a/tests/gpu-tests/mesh_shader/basic.frag b/tests/gpu-tests/mesh_shader/basic.frag new file mode 100644 index 0000000000..9d2b777326 --- /dev/null +++ b/tests/gpu-tests/mesh_shader/basic.frag @@ -0,0 +1,9 @@ +#version 450 +#extension GL_EXT_mesh_shader : require + +in VertexInput { layout(location = 0) vec4 color; } +vertexInput; + +layout(location = 0) out vec4 fragColor; + +void main() { fragColor = vertexInput.color; } \ No newline at end of file diff --git a/tests/gpu-tests/mesh_shader/basic.frag.spv.asm b/tests/gpu-tests/mesh_shader/basic.frag.spv.asm new file mode 100644 index 0000000000..956e71b46d --- /dev/null +++ b/tests/gpu-tests/mesh_shader/basic.frag.spv.asm @@ -0,0 +1,39 @@ +; SPIR-V +; Version: 1.5 +; Generator: Khronos Glslang Reference Front End; 11 +; Bound: 18 +; Schema: 0 + 
OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %main "main" %fragColor %vertexInput + OpExecutionMode %main OriginUpperLeft + OpSource GLSL 450 + OpSourceExtension "GL_EXT_mesh_shader" + OpName %main "main" + OpName %fragColor "fragColor" + OpName %VertexInput "VertexInput" + OpMemberName %VertexInput 0 "color" + OpName %vertexInput "vertexInput" + OpDecorate %fragColor Location 0 + OpDecorate %VertexInput Block + OpMemberDecorate %VertexInput 0 Location 0 + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %v4float = OpTypeVector %float 4 +%_ptr_Output_v4float = OpTypePointer Output %v4float + %fragColor = OpVariable %_ptr_Output_v4float Output +%VertexInput = OpTypeStruct %v4float +%_ptr_Input_VertexInput = OpTypePointer Input %VertexInput +%vertexInput = OpVariable %_ptr_Input_VertexInput Input + %int = OpTypeInt 32 1 + %int_0 = OpConstant %int 0 +%_ptr_Input_v4float = OpTypePointer Input %v4float + %main = OpFunction %void None %3 + %5 = OpLabel + %16 = OpAccessChain %_ptr_Input_v4float %vertexInput %int_0 + %17 = OpLoad %v4float %16 + OpStore %fragColor %17 + OpReturn + OpFunctionEnd diff --git a/tests/gpu-tests/mesh_shader/basic.mesh b/tests/gpu-tests/mesh_shader/basic.mesh new file mode 100644 index 0000000000..400cafb36f --- /dev/null +++ b/tests/gpu-tests/mesh_shader/basic.mesh @@ -0,0 +1,25 @@ +#version 450 +#extension GL_EXT_mesh_shader : require + +const vec4[3] positions = {vec4(0., 1.0, 0., 1.0), vec4(-1.0, -1.0, 0., 1.0), + vec4(1.0, -1.0, 0., 1.0)}; +const vec4[3] colors = {vec4(0., 1., 0., 1.), vec4(0., 0., 1., 1.), + vec4(1., 0., 0., 1.)}; + +layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in; + +out VertexOutput { layout(location = 0) vec4 color; } +vertexOutput[]; + +layout(triangles, max_vertices = 3, max_primitives = 1) out; + +void main() { + SetMeshOutputsEXT(3, 1); + gl_MeshVerticesEXT[0].gl_Position = positions[0]; + 
gl_MeshVerticesEXT[1].gl_Position = positions[1]; + gl_MeshVerticesEXT[2].gl_Position = positions[2]; + vertexOutput[0].color = colors[0]; + vertexOutput[1].color = colors[1]; + vertexOutput[2].color = colors[2]; + gl_PrimitiveTriangleIndicesEXT[gl_LocalInvocationIndex] = uvec3(0, 1, 2); +} \ No newline at end of file diff --git a/tests/gpu-tests/mesh_shader/basic.mesh.spv.asm b/tests/gpu-tests/mesh_shader/basic.mesh.spv.asm new file mode 100644 index 0000000000..f7ea01908c --- /dev/null +++ b/tests/gpu-tests/mesh_shader/basic.mesh.spv.asm @@ -0,0 +1,101 @@ +; SPIR-V +; Version: 1.5 +; Generator: Khronos Glslang Reference Front End; 11 +; Bound: 55 +; Schema: 0 + OpCapability MeshShadingEXT + OpExtension "SPV_EXT_mesh_shader" + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint MeshEXT %main "main" %gl_MeshVerticesEXT %vertexOutput %gl_PrimitiveTriangleIndicesEXT %gl_LocalInvocationIndex + OpExecutionMode %main LocalSize 1 1 1 + OpExecutionMode %main OutputVertices 3 + OpExecutionMode %main OutputPrimitivesEXT 1 + OpExecutionMode %main OutputTrianglesEXT + OpSource GLSL 450 + OpSourceExtension "GL_EXT_mesh_shader" + OpName %main "main" + OpName %gl_MeshPerVertexEXT "gl_MeshPerVertexEXT" + OpMemberName %gl_MeshPerVertexEXT 0 "gl_Position" + OpMemberName %gl_MeshPerVertexEXT 1 "gl_PointSize" + OpMemberName %gl_MeshPerVertexEXT 2 "gl_ClipDistance" + OpMemberName %gl_MeshPerVertexEXT 3 "gl_CullDistance" + OpName %gl_MeshVerticesEXT "gl_MeshVerticesEXT" + OpName %VertexOutput "VertexOutput" + OpMemberName %VertexOutput 0 "color" + OpName %vertexOutput "vertexOutput" + OpName %gl_PrimitiveTriangleIndicesEXT "gl_PrimitiveTriangleIndicesEXT" + OpName %gl_LocalInvocationIndex "gl_LocalInvocationIndex" + OpDecorate %gl_MeshPerVertexEXT Block + OpMemberDecorate %gl_MeshPerVertexEXT 0 BuiltIn Position + OpMemberDecorate %gl_MeshPerVertexEXT 1 BuiltIn PointSize + OpMemberDecorate %gl_MeshPerVertexEXT 2 BuiltIn ClipDistance + OpMemberDecorate 
%gl_MeshPerVertexEXT 3 BuiltIn CullDistance + OpDecorate %VertexOutput Block + OpMemberDecorate %VertexOutput 0 Location 0 + OpDecorate %gl_PrimitiveTriangleIndicesEXT BuiltIn PrimitiveTriangleIndicesEXT + OpDecorate %gl_LocalInvocationIndex BuiltIn LocalInvocationIndex + OpDecorate %gl_WorkGroupSize BuiltIn WorkgroupSize + %void = OpTypeVoid + %3 = OpTypeFunction %void + %uint = OpTypeInt 32 0 + %uint_3 = OpConstant %uint 3 + %uint_1 = OpConstant %uint 1 + %float = OpTypeFloat 32 + %v4float = OpTypeVector %float 4 +%_arr_float_uint_1 = OpTypeArray %float %uint_1 +%gl_MeshPerVertexEXT = OpTypeStruct %v4float %float %_arr_float_uint_1 %_arr_float_uint_1 +%_arr_gl_MeshPerVertexEXT_uint_3 = OpTypeArray %gl_MeshPerVertexEXT %uint_3 +%_ptr_Output__arr_gl_MeshPerVertexEXT_uint_3 = OpTypePointer Output %_arr_gl_MeshPerVertexEXT_uint_3 +%gl_MeshVerticesEXT = OpVariable %_ptr_Output__arr_gl_MeshPerVertexEXT_uint_3 Output + %int = OpTypeInt 32 1 + %int_0 = OpConstant %int 0 + %float_0 = OpConstant %float 0 + %float_1 = OpConstant %float 1 + %20 = OpConstantComposite %v4float %float_0 %float_1 %float_0 %float_1 +%_ptr_Output_v4float = OpTypePointer Output %v4float + %int_1 = OpConstant %int 1 + %float_n1 = OpConstant %float -1 + %25 = OpConstantComposite %v4float %float_n1 %float_n1 %float_0 %float_1 + %int_2 = OpConstant %int 2 + %28 = OpConstantComposite %v4float %float_1 %float_n1 %float_0 %float_1 +%VertexOutput = OpTypeStruct %v4float +%_arr_VertexOutput_uint_3 = OpTypeArray %VertexOutput %uint_3 +%_ptr_Output__arr_VertexOutput_uint_3 = OpTypePointer Output %_arr_VertexOutput_uint_3 +%vertexOutput = OpVariable %_ptr_Output__arr_VertexOutput_uint_3 Output + %35 = OpConstantComposite %v4float %float_0 %float_0 %float_1 %float_1 + %37 = OpConstantComposite %v4float %float_1 %float_0 %float_0 %float_1 + %v3uint = OpTypeVector %uint 3 +%_arr_v3uint_uint_1 = OpTypeArray %v3uint %uint_1 +%_ptr_Output__arr_v3uint_uint_1 = OpTypePointer Output %_arr_v3uint_uint_1 
+%gl_PrimitiveTriangleIndicesEXT = OpVariable %_ptr_Output__arr_v3uint_uint_1 Output +%_ptr_Input_uint = OpTypePointer Input %uint +%gl_LocalInvocationIndex = OpVariable %_ptr_Input_uint Input + %uint_0 = OpConstant %uint 0 + %uint_2 = OpConstant %uint 2 + %48 = OpConstantComposite %v3uint %uint_0 %uint_1 %uint_2 +%_ptr_Output_v3uint = OpTypePointer Output %v3uint +%_arr_v4float_uint_3 = OpTypeArray %v4float %uint_3 + %52 = OpConstantComposite %_arr_v4float_uint_3 %20 %25 %28 + %53 = OpConstantComposite %_arr_v4float_uint_3 %20 %35 %37 +%gl_WorkGroupSize = OpConstantComposite %v3uint %uint_1 %uint_1 %uint_1 + %main = OpFunction %void None %3 + %5 = OpLabel + OpSetMeshOutputsEXT %uint_3 %uint_1 + %22 = OpAccessChain %_ptr_Output_v4float %gl_MeshVerticesEXT %int_0 %int_0 + OpStore %22 %20 + %26 = OpAccessChain %_ptr_Output_v4float %gl_MeshVerticesEXT %int_1 %int_0 + OpStore %26 %25 + %29 = OpAccessChain %_ptr_Output_v4float %gl_MeshVerticesEXT %int_2 %int_0 + OpStore %29 %28 + %34 = OpAccessChain %_ptr_Output_v4float %vertexOutput %int_0 %int_0 + OpStore %34 %20 + %36 = OpAccessChain %_ptr_Output_v4float %vertexOutput %int_1 %int_0 + OpStore %36 %35 + %38 = OpAccessChain %_ptr_Output_v4float %vertexOutput %int_2 %int_0 + OpStore %38 %37 + %45 = OpLoad %uint %gl_LocalInvocationIndex + %50 = OpAccessChain %_ptr_Output_v3uint %gl_PrimitiveTriangleIndicesEXT %45 + OpStore %50 %48 + OpReturn + OpFunctionEnd diff --git a/tests/gpu-tests/mesh_shader/basic.task b/tests/gpu-tests/mesh_shader/basic.task new file mode 100644 index 0000000000..418cffa3f6 --- /dev/null +++ b/tests/gpu-tests/mesh_shader/basic.task @@ -0,0 +1,6 @@ +#version 450 +#extension GL_EXT_mesh_shader : require + +layout(local_size_x = 4, local_size_y = 1, local_size_z = 1) in; + +void main() { EmitMeshTasksEXT(1, 1, 1); } \ No newline at end of file diff --git a/tests/gpu-tests/mesh_shader/basic.task.spv.asm b/tests/gpu-tests/mesh_shader/basic.task.spv.asm new file mode 100644 index 0000000000..016715ac6a 
--- /dev/null +++ b/tests/gpu-tests/mesh_shader/basic.task.spv.asm @@ -0,0 +1,26 @@ +; SPIR-V +; Version: 1.5 +; Generator: Khronos Glslang Reference Front End; 11 +; Bound: 12 +; Schema: 0 + OpCapability MeshShadingEXT + OpExtension "SPV_EXT_mesh_shader" + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint TaskEXT %main "main" + OpExecutionMode %main LocalSize 4 1 1 + OpSource GLSL 450 + OpSourceExtension "GL_EXT_mesh_shader" + OpName %main "main" + OpDecorate %gl_WorkGroupSize BuiltIn WorkgroupSize + %void = OpTypeVoid + %3 = OpTypeFunction %void + %uint = OpTypeInt 32 0 + %uint_1 = OpConstant %uint 1 + %v3uint = OpTypeVector %uint 3 + %uint_4 = OpConstant %uint 4 +%gl_WorkGroupSize = OpConstantComposite %v3uint %uint_4 %uint_1 %uint_1 + %main = OpFunction %void None %3 + %5 = OpLabel + OpEmitMeshTasksEXT %uint_1 %uint_1 %uint_1 + OpFunctionEnd diff --git a/tests/gpu-tests/mesh_shader/mod.rs b/tests/gpu-tests/mesh_shader/mod.rs new file mode 100644 index 0000000000..8942a5f9bd --- /dev/null +++ b/tests/gpu-tests/mesh_shader/mod.rs @@ -0,0 +1,296 @@ +use std::{io::Write, process::Stdio}; + +use wgpu::util::DeviceExt; +use wgpu_test::{gpu_test, GpuTestConfiguration, TestParameters, TestingContext}; + +// Same as in mesh shader example +fn compile_spv_asm(device: &wgpu::Device, data: &[u8]) -> wgpu::ShaderModule { + let cmd = std::process::Command::new("spirv-as") + .args(["-", "-o", "-"]) + .stdin(Stdio::piped()) + .stdout(Stdio::piped()) + .spawn() + .expect("Failed to call spirv-as"); + cmd.stdin.as_ref().unwrap().write_all(data).unwrap(); + let output = cmd.wait_with_output().expect("Error waiting for spirv-as"); + assert!(output.status.success()); + unsafe { + device.create_shader_module_spirv(&wgpu::ShaderModuleDescriptorSpirV { + label: None, + source: wgpu::util::make_spirv_raw(&output.stdout), + }) + } +} + +fn create_depth( + device: &wgpu::Device, +) -> (wgpu::Texture, wgpu::TextureView, wgpu::DepthStencilState) { + let 
image_size = wgpu::Extent3d { + width: 64, + height: 64, + depth_or_array_layers: 1, + }; + let depth_texture = device.create_texture(&wgpu::TextureDescriptor { + label: None, + size: image_size, + mip_level_count: 1, + sample_count: 1, + dimension: wgpu::TextureDimension::D2, + format: wgpu::TextureFormat::Depth32Float, + usage: wgpu::TextureUsages::RENDER_ATTACHMENT | wgpu::TextureUsages::TEXTURE_BINDING, + view_formats: &[], + }); + let depth_view = depth_texture.create_view(&Default::default()); + let state = wgpu::DepthStencilState { + format: wgpu::TextureFormat::Depth32Float, + depth_write_enabled: true, + depth_compare: wgpu::CompareFunction::Less, // 1. + stencil: wgpu::StencilState::default(), // 2. + bias: wgpu::DepthBiasState::default(), + }; + (depth_texture, depth_view, state) +} + +fn mesh_pipeline_build( + ctx: &TestingContext, + task: Option<&[u8]>, + mesh: &[u8], + frag: Option<&[u8]>, + draw: bool, +) { + let device = &ctx.device; + let (_depth_image, depth_view, depth_state) = create_depth(device); + let task = task.map(|t| compile_spv_asm(device, t)); + let mesh = compile_spv_asm(device, mesh); + let frag = frag.map(|f| compile_spv_asm(device, f)); + let layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor { + label: None, + bind_group_layouts: &[], + push_constant_ranges: &[], + }); + let pipeline = device.create_mesh_pipeline(&wgpu::MeshPipelineDescriptor { + label: None, + layout: Some(&layout), + task: task.as_ref().map(|task| wgpu::TaskState { + module: task, + entry_point: Some("main"), + compilation_options: Default::default(), + }), + mesh: wgpu::MeshState { + module: &mesh, + entry_point: Some("main"), + compilation_options: Default::default(), + }, + fragment: frag.as_ref().map(|frag| wgpu::FragmentState { + module: frag, + entry_point: Some("main"), + targets: &[], + compilation_options: Default::default(), + }), + primitive: wgpu::PrimitiveState { + cull_mode: Some(wgpu::Face::Back), + ..Default::default() + }, + 
depth_stencil: Some(depth_state), + multisample: Default::default(), + multiview: None, + cache: None, + }); + if draw { + let mut encoder = + device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None }); + { + let mut pass = encoder.begin_render_pass(&wgpu::RenderPassDescriptor { + label: None, + color_attachments: &[], + depth_stencil_attachment: Some(wgpu::RenderPassDepthStencilAttachment { + view: &depth_view, + depth_ops: Some(wgpu::Operations { + load: wgpu::LoadOp::Clear(1.0), + store: wgpu::StoreOp::Store, + }), + stencil_ops: None, + }), + timestamp_writes: None, + occlusion_query_set: None, + }); + pass.set_pipeline(&pipeline); + pass.draw_mesh_tasks(1, 1, 1); + } + ctx.queue.submit(Some(encoder.finish())); + ctx.device.poll(wgpu::PollType::Wait).unwrap(); + } +} + +#[derive(PartialEq, Eq, Clone, Copy)] +pub enum DrawType { + #[allow(dead_code)] + Standard, + Indirect, + MultiIndirect, + MultiIndirectCount, +} + +fn mesh_draw(ctx: &TestingContext, draw_type: DrawType) { + let device = &ctx.device; + let (_depth_image, depth_view, depth_state) = create_depth(device); + let task = compile_spv_asm(device, BASIC_TASK); + let mesh = compile_spv_asm(device, BASIC_MESH); + let frag = compile_spv_asm(device, NO_WRITE_FRAG); + let layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor { + label: None, + bind_group_layouts: &[], + push_constant_ranges: &[], + }); + let pipeline = device.create_mesh_pipeline(&wgpu::MeshPipelineDescriptor { + label: None, + layout: Some(&layout), + task: Some(wgpu::TaskState { + module: &task, + entry_point: Some("main"), + compilation_options: Default::default(), + }), + mesh: wgpu::MeshState { + module: &mesh, + entry_point: Some("main"), + compilation_options: Default::default(), + }, + fragment: Some(wgpu::FragmentState { + module: &frag, + entry_point: Some("main"), + targets: &[], + compilation_options: Default::default(), + }), + primitive: wgpu::PrimitiveState { + cull_mode: 
Some(wgpu::Face::Back), + ..Default::default() + }, + depth_stencil: Some(depth_state), + multisample: Default::default(), + multiview: None, + cache: None, + }); + let buffer = match draw_type { + DrawType::Standard => None, + DrawType::Indirect | DrawType::MultiIndirect | DrawType::MultiIndirectCount => Some( + device.create_buffer_init(&wgpu::util::BufferInitDescriptor { + label: None, + usage: wgpu::BufferUsages::INDIRECT, + contents: bytemuck::bytes_of(&[1u32; 4]), + }), + ), + }; + let count_buffer = match draw_type { + DrawType::MultiIndirectCount => Some(device.create_buffer_init( + &wgpu::util::BufferInitDescriptor { + label: None, + usage: wgpu::BufferUsages::INDIRECT, + contents: bytemuck::bytes_of(&[1u32; 1]), + }, + )), + _ => None, + }; + let mut encoder = + device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None }); + { + let mut pass = encoder.begin_render_pass(&wgpu::RenderPassDescriptor { + label: None, + color_attachments: &[], + depth_stencil_attachment: Some(wgpu::RenderPassDepthStencilAttachment { + view: &depth_view, + depth_ops: Some(wgpu::Operations { + load: wgpu::LoadOp::Clear(1.0), + store: wgpu::StoreOp::Store, + }), + stencil_ops: None, + }), + timestamp_writes: None, + occlusion_query_set: None, + }); + pass.set_pipeline(&pipeline); + match draw_type { + DrawType::Standard => pass.draw_mesh_tasks(1, 1, 1), + DrawType::Indirect => pass.draw_mesh_tasks_indirect(buffer.as_ref().unwrap(), 0), + DrawType::MultiIndirect => { + pass.multi_draw_mesh_tasks_indirect(buffer.as_ref().unwrap(), 0, 1) + } + DrawType::MultiIndirectCount => pass.multi_draw_mesh_tasks_indirect_count( + buffer.as_ref().unwrap(), + 0, + count_buffer.as_ref().unwrap(), + 0, + 1, + ), + } + if let Some(buffer) = buffer.as_ref() { pass.draw_mesh_tasks_indirect(buffer, 0); } // DrawType::Standard has no indirect buffer + } + ctx.queue.submit(Some(encoder.finish())); + ctx.device.poll(wgpu::PollType::Wait).unwrap(); +} + +const BASIC_TASK: &[u8] = include_bytes!("basic.task.spv.asm"); +const BASIC_MESH: &[u8] = 
include_bytes!("basic.mesh.spv.asm"); +//const BASIC_FRAG: &[u8] = include_bytes!("basic.frag.spv"); +const NO_WRITE_FRAG: &[u8] = include_bytes!("no-write.frag.spv.asm"); + +fn default_gpu_test_config(draw_type: DrawType) -> GpuTestConfiguration { + GpuTestConfiguration::new().parameters( + TestParameters::default() + .test_features_limits() + .features( + wgpu::Features::EXPERIMENTAL_MESH_SHADER + | wgpu::Features::SPIRV_SHADER_PASSTHROUGH + | match draw_type { + DrawType::Standard | DrawType::Indirect => wgpu::Features::empty(), + DrawType::MultiIndirect => wgpu::Features::MULTI_DRAW_INDIRECT, + DrawType::MultiIndirectCount => wgpu::Features::MULTI_DRAW_INDIRECT_COUNT, + }, + ) + .limits(wgpu::Limits::default()), + ) +} + +// Mesh pipeline configs +#[gpu_test] +static MESH_PIPELINE_BASIC_MESH: GpuTestConfiguration = default_gpu_test_config(DrawType::Standard) + .run_sync(|ctx| { + mesh_pipeline_build(&ctx, None, BASIC_MESH, None, true); + }); +#[gpu_test] +static MESH_PIPELINE_BASIC_TASK_MESH: GpuTestConfiguration = + default_gpu_test_config(DrawType::Standard).run_sync(|ctx| { + mesh_pipeline_build(&ctx, Some(BASIC_TASK), BASIC_MESH, None, true); + }); +#[gpu_test] +static MESH_PIPELINE_BASIC_MESH_FRAG: GpuTestConfiguration = + default_gpu_test_config(DrawType::Standard).run_sync(|ctx| { + mesh_pipeline_build(&ctx, None, BASIC_MESH, Some(NO_WRITE_FRAG), true); + }); +#[gpu_test] +static MESH_PIPELINE_BASIC_TASK_MESH_FRAG: GpuTestConfiguration = + default_gpu_test_config(DrawType::Standard).run_sync(|ctx| { + mesh_pipeline_build( + &ctx, + Some(BASIC_TASK), + BASIC_MESH, + Some(NO_WRITE_FRAG), + true, + ); + }); + +// Mesh draw +#[gpu_test] +static MESH_DRAW_INDIRECT: GpuTestConfiguration = default_gpu_test_config(DrawType::Indirect) + .run_sync(|ctx| { + mesh_draw(&ctx, DrawType::Indirect); + }); +#[gpu_test] +static MESH_MULTI_DRAW_INDIRECT: GpuTestConfiguration = + default_gpu_test_config(DrawType::MultiIndirect).run_sync(|ctx| { + mesh_draw(&ctx, 
DrawType::MultiIndirect); + }); +#[gpu_test] +static MESH_MULTI_DRAW_INDIRECT_COUNT: GpuTestConfiguration = + default_gpu_test_config(DrawType::MultiIndirectCount).run_sync(|ctx| { + mesh_draw(&ctx, DrawType::MultiIndirectCount); + }); diff --git a/tests/gpu-tests/mesh_shader/no-write.frag b/tests/gpu-tests/mesh_shader/no-write.frag new file mode 100644 index 0000000000..d0512bb0fa --- /dev/null +++ b/tests/gpu-tests/mesh_shader/no-write.frag @@ -0,0 +1,7 @@ +#version 450 +#extension GL_EXT_mesh_shader : require + +in VertexInput { layout(location = 0) vec4 color; } +vertexInput; + +void main() {} \ No newline at end of file diff --git a/tests/gpu-tests/mesh_shader/no-write.frag.spv.asm b/tests/gpu-tests/mesh_shader/no-write.frag.spv.asm new file mode 100644 index 0000000000..5cd07d5abb --- /dev/null +++ b/tests/gpu-tests/mesh_shader/no-write.frag.spv.asm @@ -0,0 +1,29 @@ +; SPIR-V +; Version: 1.5 +; Generator: Khronos Glslang Reference Front End; 11 +; Bound: 11 +; Schema: 0 + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %main "main" %vertexInput + OpExecutionMode %main OriginUpperLeft + OpSource GLSL 450 + OpSourceExtension "GL_EXT_mesh_shader" + OpName %main "main" + OpName %VertexInput "VertexInput" + OpMemberName %VertexInput 0 "color" + OpName %vertexInput "vertexInput" + OpDecorate %VertexInput Block + OpMemberDecorate %VertexInput 0 Location 0 + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %v4float = OpTypeVector %float 4 +%VertexInput = OpTypeStruct %v4float +%_ptr_Input_VertexInput = OpTypePointer Input %VertexInput +%vertexInput = OpVariable %_ptr_Input_VertexInput Input + %main = OpFunction %void None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd diff --git a/tests/gpu-tests/root.rs b/tests/gpu-tests/root.rs index f4f10fa83b..f8bc5e1414 100644 --- a/tests/gpu-tests/root.rs +++ b/tests/gpu-tests/root.rs @@ -32,6 +32,7 @@ mod image_atomics; mod instance; mod 
life_cycle; mod mem_leaks; +mod mesh_shader; mod nv12_texture; mod occlusion_query; mod oob_indexing; diff --git a/wgpu-core/src/command/bundle.rs b/wgpu-core/src/command/bundle.rs index e192c0b99b..c7b43ae74c 100644 --- a/wgpu-core/src/command/bundle.rs +++ b/wgpu-core/src/command/bundle.rs @@ -118,7 +118,7 @@ use crate::{ use super::{ render_command::{ArcRenderCommand, RenderCommand}, - DrawKind, + DrawCommandFamily, DrawKind, }; /// Describes a [`RenderBundleEncoder`]. @@ -375,7 +375,7 @@ impl RenderBundleEncoder { } => { let scope = PassErrorScope::Draw { kind: DrawKind::Draw, - indexed: false, + family: DrawCommandFamily::Draw, }; draw( &mut state, @@ -396,7 +396,7 @@ impl RenderBundleEncoder { } => { let scope = PassErrorScope::Draw { kind: DrawKind::Draw, - indexed: true, + family: DrawCommandFamily::DrawIndexed, }; draw_indexed( &mut state, @@ -409,15 +409,33 @@ impl RenderBundleEncoder { ) .map_pass_err(scope)?; } + RenderCommand::DrawMeshTasks { + group_count_x, + group_count_y, + group_count_z, + } => { + let scope = PassErrorScope::Draw { + kind: DrawKind::Draw, + family: DrawCommandFamily::DrawMeshTasks, + }; + draw_mesh_tasks( + &mut state, + &base.dynamic_offsets, + group_count_x, + group_count_y, + group_count_z, + ) + .map_pass_err(scope)?; + } RenderCommand::DrawIndirect { buffer_id, offset, count: 1, - indexed, + family, } => { let scope = PassErrorScope::Draw { kind: DrawKind::DrawIndirect, - indexed, + family, }; multi_draw_indirect( &mut state, @@ -425,7 +443,7 @@ impl RenderBundleEncoder { &buffer_guard, buffer_id, offset, - indexed, + family, ) .map_pass_err(scope)?; } @@ -767,13 +785,48 @@ fn draw_indexed( Ok(()) } +fn draw_mesh_tasks( + state: &mut State, + dynamic_offsets: &[u32], + group_count_x: u32, + group_count_y: u32, + group_count_z: u32, +) -> Result<(), RenderBundleErrorInner> { + let pipeline = state.pipeline()?; + let used_bind_groups = pipeline.used_bind_groups; + + let groups_size_limit = 
state.device.limits.max_task_workgroups_per_dimension; + let max_groups = state.device.limits.max_task_workgroup_total_count; + if group_count_x > groups_size_limit + || group_count_y > groups_size_limit + || group_count_z > groups_size_limit + || group_count_x.checked_mul(group_count_y).and_then(|xy| xy.checked_mul(group_count_z)).map_or(true, |total| total > max_groups) // checked: the u32 product can overflow; overflow counts as over the limit + { + return Err(RenderBundleErrorInner::Draw(DrawError::InvalidGroupSize { + current: [group_count_x, group_count_y, group_count_z], + limit: groups_size_limit, + max_total: max_groups, + })); + } + + if group_count_x > 0 && group_count_y > 0 && group_count_z > 0 { + state.flush_binds(used_bind_groups, dynamic_offsets); + state.commands.push(ArcRenderCommand::DrawMeshTasks { + group_count_x, + group_count_y, + group_count_z, + }); + } + Ok(()) +} + fn multi_draw_indirect( state: &mut State, dynamic_offsets: &[u32], buffer_guard: &crate::storage::Storage>, buffer_id: id::Id, offset: u64, - indexed: bool, + family: DrawCommandFamily, ) -> Result<(), RenderBundleErrorInner> { state .device @@ -800,7 +853,7 @@ fn multi_draw_indirect( MemoryInitKind::NeedsInitializedMemory, )); - if indexed { + if family == DrawCommandFamily::DrawIndexed { let index = match state.index { Some(ref mut index) => index, None => return Err(DrawError::MissingIndexBuffer.into()), @@ -814,7 +867,7 @@ fn multi_draw_indirect( buffer, offset, count: 1, - indexed, + family, }); Ok(()) } @@ -1024,11 +1077,18 @@ impl RenderBundle { ) }; } + Cmd::DrawMeshTasks { + group_count_x, + group_count_y, + group_count_z, + } => unsafe { + raw.draw_mesh_tasks(*group_count_x, *group_count_y, *group_count_z); + }, Cmd::DrawIndirect { buffer, offset, count: 1, - indexed: false, + family: DrawCommandFamily::Draw, } => { let buffer = buffer.try_raw(snatch_guard)?; unsafe { raw.draw_indirect(buffer, *offset, 1) }; @@ -1037,7 +1097,7 @@ impl RenderBundle { buffer, offset, count: 1, - indexed: true, + family: DrawCommandFamily::DrawIndexed, } => { let buffer = buffer.try_raw(snatch_guard)?; unsafe { 
raw.draw_indexed_indirect(buffer, *offset, 1) }; @@ -1502,7 +1562,7 @@ where pub mod bundle_ffi { use super::{RenderBundleEncoder, RenderCommand}; - use crate::{id, RawString}; + use crate::{command::DrawCommandFamily, id, RawString}; use core::{convert::TryInto, slice}; use wgt::{BufferAddress, BufferSize, DynamicOffset, IndexFormat}; @@ -1657,7 +1717,7 @@ pub mod bundle_ffi { buffer_id, offset, count: 1, - indexed: false, + family: DrawCommandFamily::Draw, }); } @@ -1670,7 +1730,7 @@ pub mod bundle_ffi { buffer_id, offset, count: 1, - indexed: true, + family: DrawCommandFamily::DrawIndexed, }); } diff --git a/wgpu-core/src/command/draw.rs b/wgpu-core/src/command/draw.rs index d2bdd19038..3444ff2c2f 100644 --- a/wgpu-core/src/command/draw.rs +++ b/wgpu-core/src/command/draw.rs @@ -54,6 +54,21 @@ pub enum DrawError { }, #[error(transparent)] BindingSizeTooSmall(#[from] LateMinBufferBindingSizeMismatch), + + #[error( + "Wrong pipeline type for this draw command. Attempted to call {} draw command on {} pipeline", + if *wanted_mesh_pipeline {"mesh shader"} else {"standard"}, + if *wanted_mesh_pipeline {"standard"} else {"mesh shader"}, + )] + WrongPipelineType { wanted_mesh_pipeline: bool }, + #[error( + "Each current draw group size dimension ({current:?}) must be less or equal to {limit}, and the product must be less or equal to {max_total}" + )] + InvalidGroupSize { + current: [u32; 3], + limit: u32, + max_total: u32, + }, } /// Error encountered when encoding a render command. 
diff --git a/wgpu-core/src/command/mod.rs b/wgpu-core/src/command/mod.rs index abc8043732..3387925bf7 100644 --- a/wgpu-core/src/command/mod.rs +++ b/wgpu-core/src/command/mod.rs @@ -1024,6 +1024,15 @@ pub enum DrawKind { MultiDrawIndirectCount, } +/// The family of a draw command: non-indexed (`Draw`), indexed (`DrawIndexed`), or mesh shader (`DrawMeshTasks`). +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum DrawCommandFamily { + Draw, + DrawIndexed, + DrawMeshTasks, +} + #[derive(Clone, Copy, Debug, Error)] pub enum PassErrorScope { // TODO: Extract out the 2 error variants below so that we can always @@ -1053,7 +1062,10 @@ pub enum PassErrorScope { #[error("In a set_scissor_rect command")] SetScissorRect, #[error("In a draw command, kind: {kind:?}")] - Draw { kind: DrawKind, indexed: bool }, + Draw { + kind: DrawKind, + family: DrawCommandFamily, + }, #[error("In a write_timestamp command")] WriteTimestamp, #[error("In a begin_occlusion_query command")] diff --git a/wgpu-core/src/command/render.rs b/wgpu-core/src/command/render.rs index 6d59f1b186..30307a811f 100644 --- a/wgpu-core/src/command/render.rs +++ b/wgpu-core/src/command/render.rs @@ -53,7 +53,7 @@ use super::{ memory_init::TextureSurfaceDiscard, CommandBufferTextureMemoryActions, CommandEncoder, QueryResetMap, }; -use super::{DrawKind, Rect}; +use super::{DrawCommandFamily, DrawKind, Rect}; pub use wgt::{LoadOp, StoreOp}; @@ -524,7 +524,7 @@ struct State<'scope, 'snatch_guard, 'cmd_buf, 'raw_encoder> { impl<'scope, 'snatch_guard, 'cmd_buf, 'raw_encoder> State<'scope, 'snatch_guard, 'cmd_buf, 'raw_encoder> { - fn is_ready(&self, indexed: bool) -> Result<(), DrawError> { + fn is_ready(&self, family: DrawCommandFamily) -> Result<(), DrawError> { if let Some(pipeline) = self.pipeline.as_ref() { self.binder.check_compatibility(pipeline.as_ref())?; self.binder.check_late_buffer_bindings()?; @@ -548,7 +548,7 @@ impl<'scope, 'snatch_guard, 'cmd_buf, 'raw_encoder> }); } - if 
indexed { + if family == DrawCommandFamily::DrawIndexed { // Pipeline expects an index buffer if let Some(pipeline_index_format) = pipeline.strip_index_format { // We have a buffer bound @@ -567,6 +567,11 @@ impl<'scope, 'snatch_guard, 'cmd_buf, 'raw_encoder> } } } + if (family == DrawCommandFamily::DrawMeshTasks) != pipeline.is_mesh { + return Err(DrawError::WrongPipelineType { + wanted_mesh_pipeline: !pipeline.is_mesh, + }); + } Ok(()) } else { Err(DrawError::MissingPipeline) @@ -1782,7 +1787,7 @@ impl Global { } => { let scope = PassErrorScope::Draw { kind: DrawKind::Draw, - indexed: false, + family: DrawCommandFamily::Draw, }; draw( &mut state, @@ -1802,7 +1807,7 @@ impl Global { } => { let scope = PassErrorScope::Draw { kind: DrawKind::Draw, - indexed: true, + family: DrawCommandFamily::DrawIndexed, }; draw_indexed( &mut state, @@ -1814,11 +1819,23 @@ impl Global { ) .map_pass_err(scope)?; } + ArcRenderCommand::DrawMeshTasks { + group_count_x, + group_count_y, + group_count_z, + } => { + let scope = PassErrorScope::Draw { + kind: DrawKind::Draw, + family: DrawCommandFamily::DrawMeshTasks, + }; + draw_mesh_tasks(&mut state, group_count_x, group_count_y, group_count_z) + .map_pass_err(scope)?; + } ArcRenderCommand::DrawIndirect { buffer, offset, count, - indexed, + family, } => { let scope = PassErrorScope::Draw { kind: if count != 1 { @@ -1826,9 +1843,9 @@ impl Global { } else { DrawKind::DrawIndirect }, - indexed, + family, }; - multi_draw_indirect(&mut state, cmd_buf, buffer, offset, count, indexed) + multi_draw_indirect(&mut state, cmd_buf, buffer, offset, count, family) .map_pass_err(scope)?; } ArcRenderCommand::MultiDrawIndirectCount { @@ -1837,11 +1854,11 @@ impl Global { count_buffer, count_buffer_offset, max_count, - indexed, + family, } => { let scope = PassErrorScope::Draw { kind: DrawKind::MultiDrawIndirectCount, - indexed, + family, }; multi_draw_indirect_count( &mut state, @@ -1851,7 +1868,7 @@ impl Global { count_buffer, count_buffer_offset, 
max_count, - indexed, + family, ) .map_pass_err(scope)?; } @@ -2395,7 +2412,7 @@ fn draw( ) -> Result<(), DrawError> { api_log!("RenderPass::draw {vertex_count} {instance_count} {first_vertex} {first_instance}"); - state.is_ready(false)?; + state.is_ready(DrawCommandFamily::Draw)?; state .vertex @@ -2426,7 +2443,7 @@ fn draw_indexed( ) -> Result<(), DrawError> { api_log!("RenderPass::draw_indexed {index_count} {instance_count} {first_index} {base_vertex} {first_instance}"); - state.is_ready(true)?; + state.is_ready(DrawCommandFamily::DrawIndexed)?; let last_index = first_index as u64 + index_count as u64; let index_limit = state.index.limit; @@ -2455,24 +2472,59 @@ fn draw_indexed( Ok(()) } +fn draw_mesh_tasks( + state: &mut State, + group_count_x: u32, + group_count_y: u32, + group_count_z: u32, +) -> Result<(), DrawError> { + api_log!("RenderPass::draw_mesh_tasks {group_count_x} {group_count_y} {group_count_z}"); + + state.is_ready(DrawCommandFamily::DrawMeshTasks)?; + + let groups_size_limit = state.device.limits.max_task_workgroups_per_dimension; + let max_groups = state.device.limits.max_task_workgroup_total_count; + if group_count_x > groups_size_limit + || group_count_y > groups_size_limit + || group_count_z > groups_size_limit + || group_count_x.checked_mul(group_count_y).and_then(|xy| xy.checked_mul(group_count_z)).map_or(true, |total| total > max_groups) // checked: the u32 product can overflow; overflow counts as over the limit + { + return Err(DrawError::InvalidGroupSize { + current: [group_count_x, group_count_y, group_count_z], + limit: groups_size_limit, + max_total: max_groups, + }); + } + + unsafe { + if group_count_x > 0 && group_count_y > 0 && group_count_z > 0 { + state + .raw_encoder + .draw_mesh_tasks(group_count_x, group_count_y, group_count_z); + } + } + Ok(()) +} + fn multi_draw_indirect( state: &mut State, cmd_buf: &Arc, indirect_buffer: Arc, offset: u64, count: u32, - indexed: bool, + family: DrawCommandFamily, ) -> Result<(), RenderPassErrorInner> { api_log!( - "RenderPass::draw_indirect (indexed:{indexed}) {} {offset} {count:?}", + "RenderPass::draw_indirect (family:{family:?}) 
{} {offset} {count:?}", indirect_buffer.error_ident() ); - state.is_ready(indexed)?; + state.is_ready(family)?; - let stride = match indexed { - false => size_of::(), - true => size_of::(), + let stride = match family { + DrawCommandFamily::Draw => size_of::(), + DrawCommandFamily::DrawIndexed => size_of::(), + DrawCommandFamily::DrawMeshTasks => size_of::(), }; if count != 1 { @@ -2517,15 +2569,20 @@ fn multi_draw_indirect( ), ); - match indexed { - false => unsafe { + match family { + DrawCommandFamily::Draw => unsafe { state.raw_encoder.draw_indirect(indirect_raw, offset, count); }, - true => unsafe { + DrawCommandFamily::DrawIndexed => unsafe { state .raw_encoder .draw_indexed_indirect(indirect_raw, offset, count); }, + DrawCommandFamily::DrawMeshTasks => unsafe { + state + .raw_encoder + .draw_mesh_tasks_indirect(indirect_raw, offset, count); + }, } Ok(()) } @@ -2538,19 +2595,20 @@ fn multi_draw_indirect_count( count_buffer: Arc, count_buffer_offset: u64, max_count: u32, - indexed: bool, + family: DrawCommandFamily, ) -> Result<(), RenderPassErrorInner> { api_log!( - "RenderPass::multi_draw_indirect_count (indexed:{indexed}) {} {offset} {} {count_buffer_offset:?} {max_count:?}", + "RenderPass::multi_draw_indirect_count (family:{family:?}) {} {offset} {} {count_buffer_offset:?} {max_count:?}", indirect_buffer.error_ident(), count_buffer.error_ident() ); - state.is_ready(indexed)?; + state.is_ready(family)?; - let stride = match indexed { - false => size_of::(), - true => size_of::(), + let stride = match family { + DrawCommandFamily::Draw => size_of::(), + DrawCommandFamily::DrawIndexed => size_of::(), + DrawCommandFamily::DrawMeshTasks => size_of::(), } as u64; state @@ -2619,8 +2677,8 @@ fn multi_draw_indirect_count( ), ); - match indexed { - false => unsafe { + match family { + DrawCommandFamily::Draw => unsafe { state.raw_encoder.draw_indirect_count( indirect_raw, offset, @@ -2629,7 +2687,7 @@ fn multi_draw_indirect_count( max_count, ); }, - true => unsafe 
{ + DrawCommandFamily::DrawIndexed => unsafe { state.raw_encoder.draw_indexed_indirect_count( indirect_raw, offset, @@ -2638,6 +2696,15 @@ fn multi_draw_indirect_count( max_count, ); }, + DrawCommandFamily::DrawMeshTasks => unsafe { + state.raw_encoder.draw_mesh_tasks_indirect_count( + indirect_raw, + offset, + count_raw, + count_buffer_offset, + max_count, + ); + }, } Ok(()) } @@ -3040,7 +3107,7 @@ impl Global { ) -> Result<(), RenderPassError> { let scope = PassErrorScope::Draw { kind: DrawKind::Draw, - indexed: false, + family: DrawCommandFamily::Draw, }; let base = pass.base_mut(scope)?; @@ -3065,7 +3132,7 @@ impl Global { ) -> Result<(), RenderPassError> { let scope = PassErrorScope::Draw { kind: DrawKind::Draw, - indexed: true, + family: DrawCommandFamily::DrawIndexed, }; let base = pass.base_mut(scope)?; @@ -3080,6 +3147,27 @@ impl Global { Ok(()) } + pub fn render_pass_draw_mesh_tasks( + &self, + pass: &mut RenderPass, + group_count_x: u32, + group_count_y: u32, + group_count_z: u32, + ) -> Result<(), RenderPassError> { + let scope = PassErrorScope::Draw { + kind: DrawKind::Draw, + family: DrawCommandFamily::DrawMeshTasks, + }; + let base = pass.base_mut(scope)?; + + base.commands.push(ArcRenderCommand::DrawMeshTasks { + group_count_x, + group_count_y, + group_count_z, + }); + Ok(()) + } + pub fn render_pass_draw_indirect( &self, pass: &mut RenderPass, @@ -3088,7 +3176,7 @@ impl Global { ) -> Result<(), RenderPassError> { let scope = PassErrorScope::Draw { kind: DrawKind::DrawIndirect, - indexed: false, + family: DrawCommandFamily::Draw, }; let base = pass.base_mut(scope)?; @@ -3096,7 +3184,7 @@ impl Global { buffer: self.resolve_render_pass_buffer_id(scope, buffer_id)?, offset, count: 1, - indexed: false, + family: DrawCommandFamily::Draw, }); Ok(()) @@ -3110,7 +3198,29 @@ impl Global { ) -> Result<(), RenderPassError> { let scope = PassErrorScope::Draw { kind: DrawKind::DrawIndirect, - indexed: true, + family: DrawCommandFamily::DrawIndexed, + }; + let 
base = pass.base_mut(scope)?; + + base.commands.push(ArcRenderCommand::DrawIndirect { + buffer: self.resolve_render_pass_buffer_id(scope, buffer_id)?, + offset, + count: 1, + family: DrawCommandFamily::DrawIndexed, + }); + + Ok(()) + } + + pub fn render_pass_draw_mesh_tasks_indirect( + &self, + pass: &mut RenderPass, + buffer_id: id::BufferId, + offset: BufferAddress, + ) -> Result<(), RenderPassError> { + let scope = PassErrorScope::Draw { + kind: DrawKind::DrawIndirect, + family: DrawCommandFamily::DrawMeshTasks, }; let base = pass.base_mut(scope)?; @@ -3118,7 +3228,7 @@ impl Global { buffer: self.resolve_render_pass_buffer_id(scope, buffer_id)?, offset, count: 1, - indexed: true, + family: DrawCommandFamily::DrawMeshTasks, }); Ok(()) @@ -3133,7 +3243,7 @@ impl Global { ) -> Result<(), RenderPassError> { let scope = PassErrorScope::Draw { kind: DrawKind::MultiDrawIndirect, - indexed: false, + family: DrawCommandFamily::Draw, }; let base = pass.base_mut(scope)?; @@ -3141,7 +3251,7 @@ impl Global { buffer: self.resolve_render_pass_buffer_id(scope, buffer_id)?, offset, count, - indexed: false, + family: DrawCommandFamily::Draw, }); Ok(()) @@ -3156,7 +3266,7 @@ impl Global { ) -> Result<(), RenderPassError> { let scope = PassErrorScope::Draw { kind: DrawKind::MultiDrawIndirect, - indexed: true, + family: DrawCommandFamily::DrawIndexed, }; let base = pass.base_mut(scope)?; @@ -3164,7 +3274,30 @@ impl Global { buffer: self.resolve_render_pass_buffer_id(scope, buffer_id)?, offset, count, - indexed: true, + family: DrawCommandFamily::DrawIndexed, + }); + + Ok(()) + } + + pub fn render_pass_multi_draw_mesh_tasks_indirect( + &self, + pass: &mut RenderPass, + buffer_id: id::BufferId, + offset: BufferAddress, + count: u32, + ) -> Result<(), RenderPassError> { + let scope = PassErrorScope::Draw { + kind: DrawKind::MultiDrawIndirect, + family: DrawCommandFamily::DrawMeshTasks, + }; + let base = pass.base_mut(scope)?; + + base.commands.push(ArcRenderCommand::DrawIndirect { + 
buffer: self.resolve_render_pass_buffer_id(scope, buffer_id)?, + offset, + count, + family: DrawCommandFamily::DrawMeshTasks, }); Ok(()) @@ -3181,7 +3314,7 @@ impl Global { ) -> Result<(), RenderPassError> { let scope = PassErrorScope::Draw { kind: DrawKind::MultiDrawIndirectCount, - indexed: false, + family: DrawCommandFamily::Draw, }; let base = pass.base_mut(scope)?; @@ -3192,7 +3325,7 @@ impl Global { count_buffer: self.resolve_render_pass_buffer_id(scope, count_buffer_id)?, count_buffer_offset, max_count, - indexed: false, + family: DrawCommandFamily::Draw, }); Ok(()) @@ -3209,7 +3342,35 @@ impl Global { ) -> Result<(), RenderPassError> { let scope = PassErrorScope::Draw { kind: DrawKind::MultiDrawIndirectCount, - indexed: true, + family: DrawCommandFamily::DrawIndexed, + }; + let base = pass.base_mut(scope)?; + + base.commands + .push(ArcRenderCommand::MultiDrawIndirectCount { + buffer: self.resolve_render_pass_buffer_id(scope, buffer_id)?, + offset, + count_buffer: self.resolve_render_pass_buffer_id(scope, count_buffer_id)?, + count_buffer_offset, + max_count, + family: DrawCommandFamily::DrawIndexed, + }); + + Ok(()) + } + + pub fn render_pass_multi_draw_mesh_tasks_indirect_count( + &self, + pass: &mut RenderPass, + buffer_id: id::BufferId, + offset: BufferAddress, + count_buffer_id: id::BufferId, + count_buffer_offset: BufferAddress, + max_count: u32, + ) -> Result<(), RenderPassError> { + let scope = PassErrorScope::Draw { + kind: DrawKind::MultiDrawIndirectCount, + family: DrawCommandFamily::DrawMeshTasks, }; let base = pass.base_mut(scope)?; @@ -3220,7 +3381,7 @@ impl Global { count_buffer: self.resolve_render_pass_buffer_id(scope, count_buffer_id)?, count_buffer_offset, max_count, - indexed: true, + family: DrawCommandFamily::DrawMeshTasks, }); Ok(()) diff --git a/wgpu-core/src/command/render_command.rs b/wgpu-core/src/command/render_command.rs index b7fa739157..9fdb81e7eb 100644 --- a/wgpu-core/src/command/render_command.rs +++ 
b/wgpu-core/src/command/render_command.rs @@ -2,7 +2,7 @@ use alloc::sync::Arc; use wgt::{BufferAddress, BufferSize, Color}; -use super::{Rect, RenderBundle}; +use super::{DrawCommandFamily, Rect, RenderBundle}; use crate::{ binding_model::BindGroup, id, @@ -82,11 +82,16 @@ pub enum RenderCommand { base_vertex: i32, first_instance: u32, }, + DrawMeshTasks { + group_count_x: u32, + group_count_y: u32, + group_count_z: u32, + }, DrawIndirect { buffer_id: id::BufferId, offset: BufferAddress, count: u32, - indexed: bool, + family: DrawCommandFamily, }, MultiDrawIndirectCount { buffer_id: id::BufferId, @@ -94,7 +99,7 @@ pub enum RenderCommand { count_buffer_id: id::BufferId, count_buffer_offset: BufferAddress, max_count: u32, - indexed: bool, + family: DrawCommandFamily, }, PushDebugGroup { color: u32, @@ -310,12 +315,21 @@ impl RenderCommand { base_vertex, first_instance, }, + RenderCommand::DrawMeshTasks { + group_count_x, + group_count_y, + group_count_z, + } => ArcRenderCommand::DrawMeshTasks { + group_count_x, + group_count_y, + group_count_z, + }, RenderCommand::DrawIndirect { buffer_id, offset, count, - indexed, + family, } => ArcRenderCommand::DrawIndirect { buffer: buffers_guard.get(buffer_id).get().map_err(|e| { RenderPassError { @@ -325,14 +339,14 @@ impl RenderCommand { } else { DrawKind::DrawIndirect }, - indexed, + family, }, inner: e.into(), } })?, offset, count, - indexed, + family, }, RenderCommand::MultiDrawIndirectCount { @@ -341,11 +355,11 @@ impl RenderCommand { count_buffer_id, count_buffer_offset, max_count, - indexed, + family, } => { let scope = PassErrorScope::Draw { kind: DrawKind::MultiDrawIndirectCount, - indexed, + family, }; ArcRenderCommand::MultiDrawIndirectCount { buffer: buffers_guard.get(buffer_id).get().map_err(|e| { @@ -363,7 +377,7 @@ impl RenderCommand { )?, count_buffer_offset, max_count, - indexed, + family, } } @@ -459,11 +473,16 @@ pub enum ArcRenderCommand { base_vertex: i32, first_instance: u32, }, + DrawMeshTasks { + 
group_count_x: u32, + group_count_y: u32, + group_count_z: u32, + }, DrawIndirect { buffer: Arc, offset: BufferAddress, count: u32, - indexed: bool, + family: DrawCommandFamily, }, MultiDrawIndirectCount { buffer: Arc, @@ -471,7 +490,7 @@ pub enum ArcRenderCommand { count_buffer: Arc, count_buffer_offset: BufferAddress, max_count: u32, - indexed: bool, + family: DrawCommandFamily, }, PushDebugGroup { #[cfg_attr(not(any(feature = "serde", feature = "replay")), allow(dead_code))] diff --git a/wgpu-core/src/device/global.rs b/wgpu-core/src/device/global.rs index 5b56fe99c4..700df0c450 100644 --- a/wgpu-core/src/device/global.rs +++ b/wgpu-core/src/device/global.rs @@ -17,8 +17,9 @@ use crate::{ id::{self, AdapterId, DeviceId, QueueId, SurfaceId}, instance::{self, Adapter, Surface}, pipeline::{ - self, ResolvedComputePipelineDescriptor, ResolvedFragmentState, - ResolvedProgrammableStageDescriptor, ResolvedRenderPipelineDescriptor, ResolvedVertexState, + self, RenderPipelineVertexProcessor, ResolvedComputePipelineDescriptor, + ResolvedFragmentState, ResolvedGeneralRenderPipelineDescriptor, ResolvedMeshState, + ResolvedProgrammableStageDescriptor, ResolvedTaskState, ResolvedVertexState, }, present, resource::{ @@ -1190,33 +1191,89 @@ impl Global { id::RenderPipelineId, Option, ) { - profiling::scope!("Device::create_render_pipeline"); + let missing_implicit_pipeline_ids = + desc.layout.is_none() && id_in.is_some() && implicit_pipeline_ids.is_none(); let hub = &self.hub; + let fid = hub.render_pipelines.prepare(id_in); + let implicit_context = implicit_pipeline_ids.map(|ipi| ipi.prepare(hub)); + + let device = self.hub.devices.get(device_id); + #[cfg(feature = "trace")] + if let Some(ref mut trace) = *device.trace.lock() { + trace.add(trace::Action::CreateRenderPipeline { + id: fid.id(), + desc: desc.clone(), + implicit_context: implicit_context.clone(), + }); + } + self.device_create_general_render_pipeline( + &desc.clone().into(), + missing_implicit_pipeline_ids, + 
device, + fid, + implicit_context, + ) + } + + pub fn device_create_mesh_pipeline( + &self, + device_id: DeviceId, + desc: &pipeline::MeshPipelineDescriptor, + id_in: Option, + implicit_pipeline_ids: Option>, + ) -> ( + id::RenderPipelineId, + Option, + ) { let missing_implicit_pipeline_ids = desc.layout.is_none() && id_in.is_some() && implicit_pipeline_ids.is_none(); + let hub = &self.hub; + let fid = hub.render_pipelines.prepare(id_in); let implicit_context = implicit_pipeline_ids.map(|ipi| ipi.prepare(hub)); + let device = self.hub.devices.get(device_id); + #[cfg(feature = "trace")] + if let Some(ref mut trace) = *device.trace.lock() { + trace.add(trace::Action::CreateMeshPipeline { + id: fid.id(), + desc: desc.clone(), + implicit_context: implicit_context.clone(), + }); + } + self.device_create_general_render_pipeline( + &desc.clone().into(), + missing_implicit_pipeline_ids, + device, + fid, + implicit_context, + ) + } + + fn device_create_general_render_pipeline( + &self, + desc: &pipeline::GeneralRenderPipelineDescriptor, + missing_implicit_pipeline_ids: bool, + device: Arc, + fid: crate::registry::FutureId>, + implicit_context: Option, + ) -> ( + id::RenderPipelineId, + Option, + ) { + profiling::scope!("Device::create_general_render_pipeline"); + + let hub = &self.hub; + let error = 'error: { if missing_implicit_pipeline_ids { // TODO: categorize this error as API misuse break 'error pipeline::ImplicitLayoutError::MissingImplicitPipelineIds.into(); } - let device = self.hub.devices.get(device_id); - - #[cfg(feature = "trace")] - if let Some(ref mut trace) = *device.trace.lock() { - trace.add(trace::Action::CreateRenderPipeline { - id: fid.id(), - desc: desc.clone(), - implicit_context: implicit_context.clone(), - }); - } - let layout = desc .layout .map(|layout| hub.pipeline_layouts.get(layout).get()) @@ -1235,31 +1292,83 @@ impl Global { Err(e) => break 'error e.into(), }; - let vertex = { - let module = hub - .shader_modules - 
.get(desc.vertex.stage.module) - .get() - .map_err(|e| pipeline::CreateRenderPipelineError::Stage { - stage: wgt::ShaderStages::VERTEX, - error: e.into(), - }); - let module = match module { - Ok(module) => module, - Err(e) => break 'error e, - }; - let stage = ResolvedProgrammableStageDescriptor { - module, - entry_point: desc.vertex.stage.entry_point.clone(), - constants: desc.vertex.stage.constants.clone(), - zero_initialize_workgroup_memory: desc - .vertex - .stage - .zero_initialize_workgroup_memory, - }; - ResolvedVertexState { - stage, - buffers: desc.vertex.buffers.clone(), + let vertex = match desc.vertex { + RenderPipelineVertexProcessor::Vertex(ref vertex) => { + let module = hub + .shader_modules + .get(vertex.stage.module) + .get() + .map_err(|e| pipeline::CreateRenderPipelineError::Stage { + stage: wgt::ShaderStages::VERTEX, + error: e.into(), + }); + let module = match module { + Ok(module) => module, + Err(e) => break 'error e, + }; + let stage = ResolvedProgrammableStageDescriptor { + module, + entry_point: vertex.stage.entry_point.clone(), + constants: vertex.stage.constants.clone(), + zero_initialize_workgroup_memory: vertex + .stage + .zero_initialize_workgroup_memory, + }; + RenderPipelineVertexProcessor::Vertex(ResolvedVertexState { + stage, + buffers: vertex.buffers.clone(), + }) + } + RenderPipelineVertexProcessor::Mesh(ref task, ref mesh) => { + let task_module = if let Some(task) = task { + let module = hub + .shader_modules + .get(task.stage.module) + .get() + .map_err(|e| pipeline::CreateRenderPipelineError::Stage { + stage: wgt::ShaderStages::VERTEX, + error: e.into(), + }); + let module = match module { + Ok(module) => module, + Err(e) => break 'error e, + }; + let state = ResolvedProgrammableStageDescriptor { + module, + entry_point: task.stage.entry_point.clone(), + constants: task.stage.constants.clone(), + zero_initialize_workgroup_memory: task + .stage + .zero_initialize_workgroup_memory, + }; + Some(ResolvedTaskState { stage: 
state }) + } else { + None + }; + let mesh_module = + hub.shader_modules + .get(mesh.stage.module) + .get() + .map_err(|e| pipeline::CreateRenderPipelineError::Stage { + stage: wgt::ShaderStages::MESH, + error: e.into(), + }); + let mesh_module = match mesh_module { + Ok(module) => module, + Err(e) => break 'error e, + }; + let mesh_stage = ResolvedProgrammableStageDescriptor { + module: mesh_module, + entry_point: mesh.stage.entry_point.clone(), + constants: mesh.stage.constants.clone(), + zero_initialize_workgroup_memory: mesh + .stage + .zero_initialize_workgroup_memory, + }; + RenderPipelineVertexProcessor::Mesh( + task_module, + ResolvedMeshState { stage: mesh_stage }, + ) } }; @@ -1280,10 +1389,7 @@ impl Global { module, entry_point: state.stage.entry_point.clone(), constants: state.stage.constants.clone(), - zero_initialize_workgroup_memory: desc - .vertex - .stage - .zero_initialize_workgroup_memory, + zero_initialize_workgroup_memory: state.stage.zero_initialize_workgroup_memory, }; Some(ResolvedFragmentState { stage, @@ -1293,7 +1399,7 @@ impl Global { None }; - let desc = ResolvedRenderPipelineDescriptor { + let desc = ResolvedGeneralRenderPipelineDescriptor { label: desc.label.clone(), layout, vertex, diff --git a/wgpu-core/src/device/resource.rs b/wgpu-core/src/device/resource.rs index 83fbcc3327..8d14571541 100644 --- a/wgpu-core/src/device/resource.rs +++ b/wgpu-core/src/device/resource.rs @@ -2976,7 +2976,7 @@ impl Device { pub(crate) fn create_render_pipeline( self: &Arc, - desc: pipeline::ResolvedRenderPipelineDescriptor, + desc: pipeline::ResolvedGeneralRenderPipelineDescriptor, ) -> Result, pipeline::CreateRenderPipelineError> { use wgt::TextureFormatFeatureFlags as Tfff; @@ -3017,127 +3017,137 @@ impl Device { let mut io = validation::StageIo::default(); let mut validated_stages = wgt::ShaderStages::empty(); - let mut vertex_steps = Vec::with_capacity(desc.vertex.buffers.len()); - let mut vertex_buffers = 
Vec::with_capacity(desc.vertex.buffers.len()); - let mut total_attributes = 0; + let mut vertex_steps; + let mut vertex_buffers; + let mut total_attributes; let mut shader_expects_dual_source_blending = false; let mut pipeline_expects_dual_source_blending = false; - for (i, vb_state) in desc.vertex.buffers.iter().enumerate() { - // https://gpuweb.github.io/gpuweb/#abstract-opdef-validating-gpuvertexbufferlayout - - if vb_state.array_stride > self.limits.max_vertex_buffer_array_stride as u64 { - return Err(pipeline::CreateRenderPipelineError::VertexStrideTooLarge { - index: i as u32, - given: vb_state.array_stride as u32, - limit: self.limits.max_vertex_buffer_array_stride, - }); - } - if vb_state.array_stride % wgt::VERTEX_STRIDE_ALIGNMENT != 0 { - return Err(pipeline::CreateRenderPipelineError::UnalignedVertexStride { - index: i as u32, - stride: vb_state.array_stride, - }); - } - - let max_stride = if vb_state.array_stride == 0 { - self.limits.max_vertex_buffer_array_stride as u64 - } else { - vb_state.array_stride - }; - let mut last_stride = 0; - for attribute in vb_state.attributes.iter() { - let attribute_stride = attribute.offset + attribute.format.size(); - if attribute_stride > max_stride { - return Err( - pipeline::CreateRenderPipelineError::VertexAttributeStrideTooLarge { - location: attribute.shader_location, - given: attribute_stride as u32, - limit: max_stride as u32, - }, - ); + if let pipeline::RenderPipelineVertexProcessor::Vertex(ref vertex) = desc.vertex { + vertex_steps = Vec::with_capacity(vertex.buffers.len()); + vertex_buffers = Vec::with_capacity(vertex.buffers.len()); + total_attributes = 0; + shader_expects_dual_source_blending = false; + pipeline_expects_dual_source_blending = false; + for (i, vb_state) in vertex.buffers.iter().enumerate() { + // https://gpuweb.github.io/gpuweb/#abstract-opdef-validating-gpuvertexbufferlayout + + if vb_state.array_stride > self.limits.max_vertex_buffer_array_stride as u64 { + return 
Err(pipeline::CreateRenderPipelineError::VertexStrideTooLarge { + index: i as u32, + given: vb_state.array_stride as u32, + limit: self.limits.max_vertex_buffer_array_stride, + }); } - - let required_offset_alignment = attribute.format.size().min(4); - if attribute.offset % required_offset_alignment != 0 { - return Err( - pipeline::CreateRenderPipelineError::InvalidVertexAttributeOffset { - location: attribute.shader_location, - offset: attribute.offset, - }, - ); + if vb_state.array_stride % wgt::VERTEX_STRIDE_ALIGNMENT != 0 { + return Err(pipeline::CreateRenderPipelineError::UnalignedVertexStride { + index: i as u32, + stride: vb_state.array_stride, + }); } - if attribute.shader_location >= self.limits.max_vertex_attributes { - return Err( - pipeline::CreateRenderPipelineError::TooManyVertexAttributes { - given: attribute.shader_location, - limit: self.limits.max_vertex_attributes, - }, - ); - } + let max_stride = if vb_state.array_stride == 0 { + self.limits.max_vertex_buffer_array_stride as u64 + } else { + vb_state.array_stride + }; + let mut last_stride = 0; + for attribute in vb_state.attributes.iter() { + let attribute_stride = attribute.offset + attribute.format.size(); + if attribute_stride > max_stride { + return Err( + pipeline::CreateRenderPipelineError::VertexAttributeStrideTooLarge { + location: attribute.shader_location, + given: attribute_stride as u32, + limit: max_stride as u32, + }, + ); + } - last_stride = last_stride.max(attribute_stride); - } - vertex_steps.push(pipeline::VertexStep { - stride: vb_state.array_stride, - last_stride, - mode: vb_state.step_mode, - }); - if vb_state.attributes.is_empty() { - continue; - } - vertex_buffers.push(hal::VertexBufferLayout { - array_stride: vb_state.array_stride, - step_mode: vb_state.step_mode, - attributes: vb_state.attributes.as_ref(), - }); + let required_offset_alignment = attribute.format.size().min(4); + if attribute.offset % required_offset_alignment != 0 { + return Err( + 
pipeline::CreateRenderPipelineError::InvalidVertexAttributeOffset { + location: attribute.shader_location, + offset: attribute.offset, + }, + ); + } - for attribute in vb_state.attributes.iter() { - if attribute.offset >= 0x10000000 { - return Err( - pipeline::CreateRenderPipelineError::InvalidVertexAttributeOffset { - location: attribute.shader_location, - offset: attribute.offset, - }, - ); - } + if attribute.shader_location >= self.limits.max_vertex_attributes { + return Err( + pipeline::CreateRenderPipelineError::TooManyVertexAttributes { + given: attribute.shader_location, + limit: self.limits.max_vertex_attributes, + }, + ); + } - if let wgt::VertexFormat::Float64 - | wgt::VertexFormat::Float64x2 - | wgt::VertexFormat::Float64x3 - | wgt::VertexFormat::Float64x4 = attribute.format - { - self.require_features(wgt::Features::VERTEX_ATTRIBUTE_64BIT)?; + last_stride = last_stride.max(attribute_stride); } + vertex_steps.push(pipeline::VertexStep { + stride: vb_state.array_stride, + last_stride, + mode: vb_state.step_mode, + }); + if vb_state.attributes.is_empty() { + continue; + } + vertex_buffers.push(hal::VertexBufferLayout { + array_stride: vb_state.array_stride, + step_mode: vb_state.step_mode, + attributes: vb_state.attributes.as_ref(), + }); - let previous = io.insert( - attribute.shader_location, - validation::InterfaceVar::vertex_attribute(attribute.format), - ); + for attribute in vb_state.attributes.iter() { + if attribute.offset >= 0x10000000 { + return Err( + pipeline::CreateRenderPipelineError::InvalidVertexAttributeOffset { + location: attribute.shader_location, + offset: attribute.offset, + }, + ); + } - if previous.is_some() { - return Err(pipeline::CreateRenderPipelineError::ShaderLocationClash( + if let wgt::VertexFormat::Float64 + | wgt::VertexFormat::Float64x2 + | wgt::VertexFormat::Float64x3 + | wgt::VertexFormat::Float64x4 = attribute.format + { + self.require_features(wgt::Features::VERTEX_ATTRIBUTE_64BIT)?; + } + + let previous = io.insert( 
attribute.shader_location, - )); + validation::InterfaceVar::vertex_attribute(attribute.format), + ); + + if previous.is_some() { + return Err(pipeline::CreateRenderPipelineError::ShaderLocationClash( + attribute.shader_location, + )); + } } + total_attributes += vb_state.attributes.len(); } - total_attributes += vb_state.attributes.len(); - } - if vertex_buffers.len() > self.limits.max_vertex_buffers as usize { - return Err(pipeline::CreateRenderPipelineError::TooManyVertexBuffers { - given: vertex_buffers.len() as u32, - limit: self.limits.max_vertex_buffers, - }); - } - if total_attributes > self.limits.max_vertex_attributes as usize { - return Err( - pipeline::CreateRenderPipelineError::TooManyVertexAttributes { - given: total_attributes as u32, - limit: self.limits.max_vertex_attributes, - }, - ); - } + if vertex_buffers.len() > self.limits.max_vertex_buffers as usize { + return Err(pipeline::CreateRenderPipelineError::TooManyVertexBuffers { + given: vertex_buffers.len() as u32, + limit: self.limits.max_vertex_buffers, + }); + } + if total_attributes > self.limits.max_vertex_attributes as usize { + return Err( + pipeline::CreateRenderPipelineError::TooManyVertexAttributes { + given: total_attributes as u32, + limit: self.limits.max_vertex_attributes, + }, + ); + } + } else { + vertex_steps = Vec::new(); + vertex_buffers = Vec::new(); + }; if desc.primitive.strip_index_format.is_some() && !desc.primitive.topology.is_strip() { return Err( @@ -3347,44 +3357,132 @@ impl Device { sc }; - let vertex_entry_point_name; - let vertex_stage = { - let stage_desc = &desc.vertex.stage; - let stage = wgt::ShaderStages::VERTEX; + let mut vertex_stage = None; + let mut task_stage = None; + let mut mesh_stage = None; + let mut _vertex_entry_point_name = String::new(); + let mut _task_entry_point_name = String::new(); + let mut _mesh_entry_point_name = String::new(); + match desc.vertex { + pipeline::RenderPipelineVertexProcessor::Vertex(ref vertex) => { + vertex_stage = { + let 
stage_desc = &vertex.stage; + let stage = wgt::ShaderStages::VERTEX; - let vertex_shader_module = &stage_desc.module; - vertex_shader_module.same_device(self)?; + let vertex_shader_module = &stage_desc.module; + vertex_shader_module.same_device(self)?; - let stage_err = |error| pipeline::CreateRenderPipelineError::Stage { stage, error }; + let stage_err = + |error| pipeline::CreateRenderPipelineError::Stage { stage, error }; - vertex_entry_point_name = vertex_shader_module - .finalize_entry_point_name( - stage, - stage_desc.entry_point.as_ref().map(|ep| ep.as_ref()), - ) - .map_err(stage_err)?; - - if let Some(ref interface) = vertex_shader_module.interface { - io = interface - .check_stage( - &mut binding_layout_source, - &mut shader_binding_sizes, - &vertex_entry_point_name, - stage, - io, - desc.depth_stencil.as_ref().map(|d| d.depth_compare), - ) - .map_err(stage_err)?; - validated_stages |= stage; + _vertex_entry_point_name = vertex_shader_module + .finalize_entry_point_name( + stage, + stage_desc.entry_point.as_ref().map(|ep| ep.as_ref()), + ) + .map_err(stage_err)?; + + if let Some(ref interface) = vertex_shader_module.interface { + io = interface + .check_stage( + &mut binding_layout_source, + &mut shader_binding_sizes, + &_vertex_entry_point_name, + stage, + io, + desc.depth_stencil.as_ref().map(|d| d.depth_compare), + ) + .map_err(stage_err)?; + validated_stages |= stage; + } + Some(hal::ProgrammableStage { + module: vertex_shader_module.raw(), + entry_point: &_vertex_entry_point_name, + constants: &stage_desc.constants, + zero_initialize_workgroup_memory: stage_desc + .zero_initialize_workgroup_memory, + }) + }; } + pipeline::RenderPipelineVertexProcessor::Mesh(ref task, ref mesh) => { + task_stage = if let Some(task) = task { + let stage_desc = &task.stage; + let stage = wgt::ShaderStages::TASK; + let task_shader_module = &stage_desc.module; + task_shader_module.same_device(self)?; + + let stage_err = + |error| pipeline::CreateRenderPipelineError::Stage 
{ stage, error }; + + _task_entry_point_name = task_shader_module + .finalize_entry_point_name( + stage, + stage_desc.entry_point.as_ref().map(|ep| ep.as_ref()), + ) + .map_err(stage_err)?; + + if let Some(ref interface) = task_shader_module.interface { + io = interface + .check_stage( + &mut binding_layout_source, + &mut shader_binding_sizes, + &_task_entry_point_name, + stage, + io, + desc.depth_stencil.as_ref().map(|d| d.depth_compare), + ) + .map_err(stage_err)?; + validated_stages |= stage; + } + Some(hal::ProgrammableStage { + module: task_shader_module.raw(), + entry_point: &_task_entry_point_name, + constants: &stage_desc.constants, + zero_initialize_workgroup_memory: stage_desc + .zero_initialize_workgroup_memory, + }) + } else { + None + }; + mesh_stage = { + let stage_desc = &mesh.stage; + let stage = wgt::ShaderStages::MESH; + let mesh_shader_module = &stage_desc.module; + mesh_shader_module.same_device(self)?; + + let stage_err = + |error| pipeline::CreateRenderPipelineError::Stage { stage, error }; + + _mesh_entry_point_name = mesh_shader_module + .finalize_entry_point_name( + stage, + stage_desc.entry_point.as_ref().map(|ep| ep.as_ref()), + ) + .map_err(stage_err)?; - hal::ProgrammableStage { - module: vertex_shader_module.raw(), - entry_point: &vertex_entry_point_name, - constants: &stage_desc.constants, - zero_initialize_workgroup_memory: stage_desc.zero_initialize_workgroup_memory, + if let Some(ref interface) = mesh_shader_module.interface { + io = interface + .check_stage( + &mut binding_layout_source, + &mut shader_binding_sizes, + &_mesh_entry_point_name, + stage, + io, + desc.depth_stencil.as_ref().map(|d| d.depth_compare), + ) + .map_err(stage_err)?; + validated_stages |= stage; + } + Some(hal::ProgrammableStage { + module: mesh_shader_module.raw(), + entry_point: &_mesh_entry_point_name, + constants: &stage_desc.constants, + zero_initialize_workgroup_memory: stage_desc + .zero_initialize_workgroup_memory, + }) + }; } - }; + } let 
fragment_entry_point_name; let fragment_stage = match desc.fragment { @@ -3533,20 +3631,29 @@ impl Device { None => None, }; - let pipeline_desc = hal::RenderPipelineDescriptor { - label: desc.label.to_hal(self.instance_flags), - layout: pipeline_layout.raw(), - vertex_buffers: &vertex_buffers, - vertex_stage, - primitive: desc.primitive, - depth_stencil: desc.depth_stencil.clone(), - multisample: desc.multisample, - fragment_stage, - color_targets, - multiview: desc.multiview, - cache: cache.as_ref().map(|it| it.raw()), - }; - let raw = + let is_mesh = mesh_stage.is_some(); + let raw = { + let pipeline_desc = hal::RenderPipelineDescriptor { + label: desc.label.to_hal(self.instance_flags), + layout: pipeline_layout.raw(), + vertex_processor: match vertex_stage { + Some(vertex_stage) => hal::VertexProcessor::Standard { + vertex_buffers: &vertex_buffers, + vertex_stage, + }, + None => hal::VertexProcessor::Mesh { + task_stage, + mesh_stage: mesh_stage.unwrap(), + }, + }, + primitive: desc.primitive, + depth_stencil: desc.depth_stencil.clone(), + multisample: desc.multisample, + fragment_stage, + color_targets, + multiview: desc.multiview, + cache: cache.as_ref().map(|it| it.raw()), + }; unsafe { self.raw().create_render_pipeline(&pipeline_desc) }.map_err( |err| match err { hal::PipelineError::Device(error) => { @@ -3565,7 +3672,8 @@ impl Device { pipeline::CreateRenderPipelineError::PipelineConstants { stage, error } } }, - )?; + )? 
+ }; let pass_context = RenderPassContext { attachments: AttachmentData { @@ -3599,10 +3707,19 @@ impl Device { flags |= pipeline::PipelineFlags::WRITES_STENCIL; } } - let shader_modules = { let mut shader_modules = ArrayVec::new(); - shader_modules.push(desc.vertex.stage.module); + match desc.vertex { + pipeline::RenderPipelineVertexProcessor::Vertex(vertex) => { + shader_modules.push(vertex.stage.module) + } + pipeline::RenderPipelineVertexProcessor::Mesh(task, mesh) => { + if let Some(task) = task { + shader_modules.push(task.stage.module); + } + shader_modules.push(mesh.stage.module); + } + } shader_modules.extend(desc.fragment.map(|f| f.stage.module)); shader_modules }; @@ -3619,6 +3736,7 @@ impl Device { late_sized_buffer_groups, label: desc.label.to_string(), tracking_data: TrackingData::new(self.tracker_indices.render_pipelines.clone()), + is_mesh, }; let pipeline = Arc::new(pipeline); diff --git a/wgpu-core/src/device/trace.rs b/wgpu-core/src/device/trace.rs index 3932d4086b..8e60b64c16 100644 --- a/wgpu-core/src/device/trace.rs +++ b/wgpu-core/src/device/trace.rs @@ -100,6 +100,12 @@ pub enum Action<'a> { #[cfg_attr(feature = "replay", serde(default))] implicit_context: Option, }, + CreateMeshPipeline { + id: id::RenderPipelineId, + desc: crate::pipeline::MeshPipelineDescriptor<'a>, + #[cfg_attr(feature = "replay", serde(default))] + implicit_context: Option, + }, DestroyRenderPipeline(id::RenderPipelineId), CreatePipelineCache { id: id::PipelineCacheId, diff --git a/wgpu-core/src/pipeline.rs b/wgpu-core/src/pipeline.rs index ebf3cdcae2..f9ee541368 100644 --- a/wgpu-core/src/pipeline.rs +++ b/wgpu-core/src/pipeline.rs @@ -334,6 +334,33 @@ pub struct FragmentState<'a, SM = ShaderModuleId> { /// cbindgen:ignore pub type ResolvedFragmentState<'a> = FragmentState<'a, Arc>; +/// Describes the task shader in a mesh shader pipeline. 
+#[derive(Clone, Debug)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +pub struct TaskState<'a, SM = ShaderModuleId> { + /// The compiled task stage and its entry point. + pub stage: ProgrammableStageDescriptor<'a, SM>, +} + +pub type ResolvedTaskState<'a> = TaskState<'a, Arc>; + +/// Describes the mesh shader in a mesh shader pipeline. +#[derive(Clone, Debug)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +pub struct MeshState<'a, SM = ShaderModuleId> { + /// The compiled mesh stage and its entry point. + pub stage: ProgrammableStageDescriptor<'a, SM>, +} + +pub type ResolvedMeshState<'a> = MeshState<'a, Arc>; + +#[derive(Clone, Debug)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +pub(crate) enum RenderPipelineVertexProcessor<'a, SM = ShaderModuleId> { + Vertex(VertexState<'a, SM>), + Mesh(Option>, MeshState<'a, SM>), +} + /// Describes a render (graphics) pipeline. #[derive(Clone, Debug)] #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] @@ -365,10 +392,109 @@ pub struct RenderPipelineDescriptor< /// The pipeline cache to use when creating this pipeline. pub cache: Option, } +/// Describes a render (graphics) pipeline. +#[derive(Clone, Debug)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +pub struct MeshPipelineDescriptor< + 'a, + PLL = PipelineLayoutId, + SM = ShaderModuleId, + PLC = PipelineCacheId, +> { + pub label: Label<'a>, + /// The layout of bind groups for this pipeline. + pub layout: Option, + /// The task processing state for this pipeline. + pub task: Option>, + /// The mesh processing state for this pipeline + pub mesh: MeshState<'a, SM>, + /// The properties of the pipeline at the primitive assembly and rasterization level. 
+ #[cfg_attr(feature = "serde", serde(default))] + pub primitive: wgt::PrimitiveState, + /// The effect of draw calls on the depth and stencil aspects of the output target, if any. + #[cfg_attr(feature = "serde", serde(default))] + pub depth_stencil: Option, + /// The multi-sampling properties of the pipeline. + #[cfg_attr(feature = "serde", serde(default))] + pub multisample: wgt::MultisampleState, + /// The fragment processing state for this pipeline. + pub fragment: Option>, + /// If the pipeline will be used with a multiview render pass, this indicates how many array + /// layers the attachments will have. + pub multiview: Option, + /// The pipeline cache to use when creating this pipeline. + pub cache: Option, +} + +/// Describes a render (graphics) pipeline. +#[derive(Clone, Debug)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +pub(crate) struct GeneralRenderPipelineDescriptor< + 'a, + PLL = PipelineLayoutId, + SM = ShaderModuleId, + PLC = PipelineCacheId, +> { + pub label: Label<'a>, + /// The layout of bind groups for this pipeline. + pub layout: Option, + /// The vertex processing state for this pipeline. + pub vertex: RenderPipelineVertexProcessor<'a, SM>, + /// The properties of the pipeline at the primitive assembly and rasterization level. + #[cfg_attr(feature = "serde", serde(default))] + pub primitive: wgt::PrimitiveState, + /// The effect of draw calls on the depth and stencil aspects of the output target, if any. + #[cfg_attr(feature = "serde", serde(default))] + pub depth_stencil: Option, + /// The multi-sampling properties of the pipeline. + #[cfg_attr(feature = "serde", serde(default))] + pub multisample: wgt::MultisampleState, + /// The fragment processing state for this pipeline. + pub fragment: Option>, + /// If the pipeline will be used with a multiview render pass, this indicates how many array + /// layers the attachments will have. 
+ pub multiview: Option, + /// The pipeline cache to use when creating this pipeline. + pub cache: Option, +} +impl<'a, PLL, SM, PLC> From> + for GeneralRenderPipelineDescriptor<'a, PLL, SM, PLC> +{ + fn from(value: RenderPipelineDescriptor<'a, PLL, SM, PLC>) -> Self { + Self { + label: value.label, + layout: value.layout, + vertex: RenderPipelineVertexProcessor::Vertex(value.vertex), + primitive: value.primitive, + depth_stencil: value.depth_stencil, + multisample: value.multisample, + fragment: value.fragment, + multiview: value.multiview, + cache: value.cache, + } + } +} +impl<'a, PLL, SM, PLC> From> + for GeneralRenderPipelineDescriptor<'a, PLL, SM, PLC> +{ + fn from(value: MeshPipelineDescriptor<'a, PLL, SM, PLC>) -> Self { + Self { + label: value.label, + layout: value.layout, + vertex: RenderPipelineVertexProcessor::Mesh(value.task, value.mesh), + primitive: value.primitive, + depth_stencil: value.depth_stencil, + multisample: value.multisample, + fragment: value.fragment, + multiview: value.multiview, + cache: value.cache, + } + } +} /// cbindgen:ignore -pub type ResolvedRenderPipelineDescriptor<'a> = - RenderPipelineDescriptor<'a, Arc, Arc, Arc>; +pub(crate) type ResolvedGeneralRenderPipelineDescriptor<'a> = + GeneralRenderPipelineDescriptor<'a, Arc, Arc, Arc>; #[derive(Clone, Debug)] #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] @@ -545,6 +671,8 @@ pub struct RenderPipeline { /// The `label` from the descriptor used to create the resource. 
pub(crate) label: String, pub(crate) tracking_data: TrackingData, + /// Whether this is a mesh shader pipeline + pub(crate) is_mesh: bool, } impl Drop for RenderPipeline { diff --git a/wgpu-core/src/validation.rs b/wgpu-core/src/validation.rs index 58b93b00bc..92014b5091 100644 --- a/wgpu-core/src/validation.rs +++ b/wgpu-core/src/validation.rs @@ -1241,6 +1241,7 @@ impl Interface { ) } naga::ShaderStage::Compute => (false, 0), + // TODO: add validation for these naga::ShaderStage::Task | naga::ShaderStage::Mesh => { unreachable!() } diff --git a/wgpu-hal/examples/halmark/main.rs b/wgpu-hal/examples/halmark/main.rs index f7ce594d61..d8c10578d9 100644 --- a/wgpu-hal/examples/halmark/main.rs +++ b/wgpu-hal/examples/halmark/main.rs @@ -251,13 +251,15 @@ impl Example { let pipeline_desc = hal::RenderPipelineDescriptor { label: None, layout: &pipeline_layout, - vertex_stage: hal::ProgrammableStage { - module: &shader, - entry_point: "vs_main", - constants: &constants, - zero_initialize_workgroup_memory: true, + vertex_processor: hal::VertexProcessor::Standard { + vertex_stage: hal::ProgrammableStage { + module: &shader, + entry_point: "vs_main", + constants: &constants, + zero_initialize_workgroup_memory: true, + }, + vertex_buffers: &[], }, - vertex_buffers: &[], fragment_stage: Some(hal::ProgrammableStage { module: &shader, entry_point: "fs_main", diff --git a/wgpu-hal/src/dx12/adapter.rs b/wgpu-hal/src/dx12/adapter.rs index 82b62c5161..4f4c51b3f9 100644 --- a/wgpu-hal/src/dx12/adapter.rs +++ b/wgpu-hal/src/dx12/adapter.rs @@ -584,6 +584,11 @@ impl super::Adapter { // store buffer sizes using 32 bit ints (a situation we have already encountered with vulkan). 
max_buffer_size: i32::MAX as u64, max_non_sampler_bindings: 1_000_000, + + max_task_workgroup_total_count: 0, + max_task_workgroups_per_dimension: 0, + max_mesh_multiview_count: 0, + max_mesh_output_layers: 0, }, alignments: crate::Alignments { buffer_copy_offset: wgt::BufferSize::new( diff --git a/wgpu-hal/src/dx12/device.rs b/wgpu-hal/src/dx12/device.rs index 8b03df0e97..de059d9ba8 100644 --- a/wgpu-hal/src/dx12/device.rs +++ b/wgpu-hal/src/dx12/device.rs @@ -1707,8 +1707,16 @@ impl crate::Device for super::Device { let (topology_class, topology) = conv::map_topology(desc.primitive.topology); let mut shader_stages = wgt::ShaderStages::VERTEX; + let (vertex_stage_desc, vertex_buffers_desc) = match &desc.vertex_processor { + crate::VertexProcessor::Standard { + vertex_buffers, + vertex_stage, + } => (vertex_stage, *vertex_buffers), + crate::VertexProcessor::Mesh { .. } => unreachable!(), + }; + let blob_vs = self.load_shader( - &desc.vertex_stage, + vertex_stage_desc, desc.layout, naga::ShaderStage::Vertex, desc.fragment_stage.as_ref(), @@ -1725,7 +1733,7 @@ impl crate::Device for super::Device { let mut input_element_descs = Vec::new(); for (i, (stride, vbuf)) in vertex_strides .iter_mut() - .zip(desc.vertex_buffers) + .zip(vertex_buffers_desc) .enumerate() { *stride = NonZeroU32::new(vbuf.array_stride as u32); @@ -1885,17 +1893,6 @@ impl crate::Device for super::Device { }) } - unsafe fn create_mesh_pipeline( - &self, - _desc: &crate::MeshPipelineDescriptor< - ::PipelineLayout, - ::ShaderModule, - ::PipelineCache, - >, - ) -> Result<::RenderPipeline, crate::PipelineError> { - unreachable!() - } - unsafe fn destroy_render_pipeline(&self, _pipeline: super::RenderPipeline) { self.counters.render_pipelines.sub(1); } diff --git a/wgpu-hal/src/dynamic/device.rs b/wgpu-hal/src/dynamic/device.rs index c0de61f88c..7f8a37a449 100644 --- a/wgpu-hal/src/dynamic/device.rs +++ b/wgpu-hal/src/dynamic/device.rs @@ -4,10 +4,10 @@ use crate::{ AccelerationStructureBuildSizes, 
AccelerationStructureDescriptor, Api, BindGroupDescriptor, BindGroupLayoutDescriptor, BufferDescriptor, BufferMapping, CommandEncoderDescriptor, ComputePipelineDescriptor, Device, DeviceError, FenceValue, - GetAccelerationStructureBuildSizesDescriptor, Label, MemoryRange, MeshPipelineDescriptor, - PipelineCacheDescriptor, PipelineCacheError, PipelineError, PipelineLayoutDescriptor, - RenderPipelineDescriptor, SamplerDescriptor, ShaderError, ShaderInput, ShaderModuleDescriptor, - TextureDescriptor, TextureViewDescriptor, TlasInstance, + GetAccelerationStructureBuildSizesDescriptor, Label, MemoryRange, PipelineCacheDescriptor, + PipelineCacheError, PipelineError, PipelineLayoutDescriptor, RenderPipelineDescriptor, + SamplerDescriptor, ShaderError, ShaderInput, ShaderModuleDescriptor, TextureDescriptor, + TextureViewDescriptor, TlasInstance, }; use super::{ @@ -100,14 +100,6 @@ pub trait DynDevice: DynResource { dyn DynPipelineCache, >, ) -> Result, PipelineError>; - unsafe fn create_mesh_pipeline( - &self, - desc: &MeshPipelineDescriptor< - dyn DynPipelineLayout, - dyn DynShaderModule, - dyn DynPipelineCache, - >, - ) -> Result, PipelineError>; unsafe fn destroy_render_pipeline(&self, pipeline: Box); unsafe fn create_compute_pipeline( @@ -386,8 +378,22 @@ impl DynDevice for D { let desc = RenderPipelineDescriptor { label: desc.label, layout: desc.layout.expect_downcast_ref(), - vertex_buffers: desc.vertex_buffers, - vertex_stage: desc.vertex_stage.clone().expect_downcast(), + vertex_processor: match &desc.vertex_processor { + crate::VertexProcessor::Standard { + vertex_buffers, + vertex_stage, + } => crate::VertexProcessor::Standard { + vertex_buffers, + vertex_stage: vertex_stage.clone().expect_downcast(), + }, + crate::VertexProcessor::Mesh { + task_stage: task, + mesh_stage: mesh, + } => crate::VertexProcessor::Mesh { + task_stage: task.as_ref().map(|a| a.clone().expect_downcast()), + mesh_stage: mesh.clone().expect_downcast(), + }, + }, primitive: desc.primitive, 
depth_stencil: desc.depth_stencil.clone(), multisample: desc.multisample, @@ -401,32 +407,6 @@ impl DynDevice for D { .map(|b| -> Box { Box::new(b) }) } - unsafe fn create_mesh_pipeline( - &self, - desc: &MeshPipelineDescriptor< - dyn DynPipelineLayout, - dyn DynShaderModule, - dyn DynPipelineCache, - >, - ) -> Result, PipelineError> { - let desc = MeshPipelineDescriptor { - label: desc.label, - layout: desc.layout.expect_downcast_ref(), - task_stage: desc.task_stage.clone().map(|f| f.expect_downcast()), - mesh_stage: desc.mesh_stage.clone().expect_downcast(), - primitive: desc.primitive, - depth_stencil: desc.depth_stencil.clone(), - multisample: desc.multisample, - fragment_stage: desc.fragment_stage.clone().map(|f| f.expect_downcast()), - color_targets: desc.color_targets, - multiview: desc.multiview, - cache: desc.cache.map(|c| c.expect_downcast_ref()), - }; - - unsafe { D::create_mesh_pipeline(self, &desc) } - .map(|b| -> Box { Box::new(b) }) - } - unsafe fn destroy_render_pipeline(&self, pipeline: Box) { unsafe { D::destroy_render_pipeline(self, pipeline.unbox()) }; } diff --git a/wgpu-hal/src/gles/adapter.rs b/wgpu-hal/src/gles/adapter.rs index ff476645b8..4d116917b1 100644 --- a/wgpu-hal/src/gles/adapter.rs +++ b/wgpu-hal/src/gles/adapter.rs @@ -790,6 +790,11 @@ impl super::Adapter { max_compute_workgroups_per_dimension, max_buffer_size: i32::MAX as u64, max_non_sampler_bindings: u32::MAX, + + max_task_workgroup_total_count: 0, + max_task_workgroups_per_dimension: 0, + max_mesh_multiview_count: 0, + max_mesh_output_layers: 0, }; let mut workarounds = super::Workarounds::empty(); diff --git a/wgpu-hal/src/gles/device.rs b/wgpu-hal/src/gles/device.rs index e3f3f61a38..457f7eccd0 100644 --- a/wgpu-hal/src/gles/device.rs +++ b/wgpu-hal/src/gles/device.rs @@ -1348,9 +1348,16 @@ impl crate::Device for super::Device { super::PipelineCache, >, ) -> Result { + let (vertex_stage, vertex_buffers) = match &desc.vertex_processor { + crate::VertexProcessor::Standard { + 
vertex_buffers, + ref vertex_stage, + } => (vertex_stage, vertex_buffers), + crate::VertexProcessor::Mesh { .. } => unreachable!(), + }; let gl = &self.shared.context.lock(); let mut shaders = ArrayVec::new(); - shaders.push((naga::ShaderStage::Vertex, &desc.vertex_stage)); + shaders.push((naga::ShaderStage::Vertex, vertex_stage)); if let Some(ref fs) = desc.fragment_stage { shaders.push((naga::ShaderStage::Fragment, fs)); } @@ -1360,7 +1367,7 @@ impl crate::Device for super::Device { let (vertex_buffers, vertex_attributes) = { let mut buffers = Vec::new(); let mut attributes = Vec::new(); - for (index, vb_layout) in desc.vertex_buffers.iter().enumerate() { + for (index, vb_layout) in vertex_buffers.iter().enumerate() { buffers.push(super::VertexBufferDesc { step: vb_layout.step_mode, stride: vb_layout.array_stride as u32, @@ -1415,16 +1422,6 @@ impl crate::Device for super::Device { alpha_to_coverage_enabled: desc.multisample.alpha_to_coverage_enabled, }) } - unsafe fn create_mesh_pipeline( - &self, - _desc: &crate::MeshPipelineDescriptor< - ::PipelineLayout, - ::ShaderModule, - ::PipelineCache, - >, - ) -> Result<::RenderPipeline, crate::PipelineError> { - unreachable!() - } unsafe fn destroy_render_pipeline(&self, pipeline: super::RenderPipeline) { // If the pipeline only has 2 strong references remaining, they're `pipeline` and `program_cache` diff --git a/wgpu-hal/src/lib.rs b/wgpu-hal/src/lib.rs index aa997a2d9d..e60c6d9aae 100644 --- a/wgpu-hal/src/lib.rs +++ b/wgpu-hal/src/lib.rs @@ -901,15 +901,6 @@ pub trait Device: WasmNotSendSync { ::PipelineCache, >, ) -> Result<::RenderPipeline, PipelineError>; - #[allow(clippy::type_complexity)] - unsafe fn create_mesh_pipeline( - &self, - desc: &MeshPipelineDescriptor< - ::PipelineLayout, - ::ShaderModule, - ::PipelineCache, - >, - ) -> Result<::RenderPipeline, PipelineError>; unsafe fn destroy_render_pipeline(&self, pipeline: ::RenderPipeline); #[allow(clippy::type_complexity)] @@ -2156,6 +2147,20 @@ pub struct 
VertexBufferLayout<'a> { pub attributes: &'a [wgt::VertexAttribute], } +#[derive(Clone, Debug)] +pub enum VertexProcessor<'a, M: DynShaderModule + ?Sized> { + Standard { + /// The format of any vertex buffers used with this pipeline. + vertex_buffers: &'a [VertexBufferLayout<'a>], + /// The vertex stage for this pipeline. + vertex_stage: ProgrammableStage<'a, M>, + }, + Mesh { + task_stage: Option>, + mesh_stage: ProgrammableStage<'a, M>, + }, +} + /// Describes a render (graphics) pipeline. #[derive(Clone, Debug)] pub struct RenderPipelineDescriptor< @@ -2167,37 +2172,8 @@ pub struct RenderPipelineDescriptor< pub label: Label<'a>, /// The layout of bind groups for this pipeline. pub layout: &'a Pl, - /// The format of any vertex buffers used with this pipeline. - pub vertex_buffers: &'a [VertexBufferLayout<'a>], - /// The vertex stage for this pipeline. - pub vertex_stage: ProgrammableStage<'a, M>, - /// The properties of the pipeline at the primitive assembly and rasterization level. - pub primitive: wgt::PrimitiveState, - /// The effect of draw calls on the depth and stencil aspects of the output target, if any. - pub depth_stencil: Option, - /// The multi-sampling properties of the pipeline. - pub multisample: wgt::MultisampleState, - /// The fragment stage for this pipeline. - pub fragment_stage: Option>, - /// The effect of draw calls on the color aspect of the output target. - pub color_targets: &'a [Option], - /// If the pipeline will be used with a multiview render pass, this indicates how many array - /// layers the attachments will have. - pub multiview: Option, - /// The cache which will be used and filled when compiling this pipeline - pub cache: Option<&'a Pc>, -} -pub struct MeshPipelineDescriptor< - 'a, - Pl: DynPipelineLayout + ?Sized, - M: DynShaderModule + ?Sized, - Pc: DynPipelineCache + ?Sized, -> { - pub label: Label<'a>, - /// The layout of bind groups for this pipeline. 
- pub layout: &'a Pl, - pub task_stage: Option>, - pub mesh_stage: ProgrammableStage<'a, M>, + /// The vertex processing state (vertex shader + buffers or task + mesh shaders) + pub vertex_processor: VertexProcessor<'a, M>, /// The properties of the pipeline at the primitive assembly and rasterization level. pub primitive: wgt::PrimitiveState, /// The effect of draw calls on the depth and stencil aspects of the output target, if any. diff --git a/wgpu-hal/src/metal/adapter.rs b/wgpu-hal/src/metal/adapter.rs index f9e02cbde1..32f19fe90d 100644 --- a/wgpu-hal/src/metal/adapter.rs +++ b/wgpu-hal/src/metal/adapter.rs @@ -1063,6 +1063,11 @@ impl super::PrivateCapabilities { max_compute_workgroups_per_dimension: 0xFFFF, max_buffer_size: self.max_buffer_size, max_non_sampler_bindings: u32::MAX, + + max_task_workgroup_total_count: 0, + max_task_workgroups_per_dimension: 0, + max_mesh_multiview_count: 0, + max_mesh_output_layers: 0, }, alignments: crate::Alignments { buffer_copy_offset: wgt::BufferSize::new(self.buffer_alignment).unwrap(), diff --git a/wgpu-hal/src/metal/device.rs b/wgpu-hal/src/metal/device.rs index 6fb172d007..04bf59a38c 100644 --- a/wgpu-hal/src/metal/device.rs +++ b/wgpu-hal/src/metal/device.rs @@ -1006,6 +1006,14 @@ impl crate::Device for super::Device { super::PipelineCache, >, ) -> Result { + let (desc_vertex_stage, desc_vertex_buffers) = match &desc.vertex_processor { + crate::VertexProcessor::Standard { + vertex_buffers, + vertex_stage, + } => (vertex_stage, *vertex_buffers), + crate::VertexProcessor::Mesh { .. 
} => unreachable!(), + }; + objc::rc::autoreleasepool(|| { let descriptor = metal::RenderPipelineDescriptor::new(); @@ -1024,7 +1032,7 @@ impl crate::Device for super::Device { // Vertex shader let (vs_lib, vs_info) = { let mut vertex_buffer_mappings = Vec::::new(); - for (i, vbl) in desc.vertex_buffers.iter().enumerate() { + for (i, vbl) in desc_vertex_buffers.iter().enumerate() { let mut attributes = Vec::::new(); for attribute in vbl.attributes.iter() { attributes.push(naga::back::msl::AttributeMapping { @@ -1053,7 +1061,7 @@ impl crate::Device for super::Device { } let vs = self.load_shader( - &desc.vertex_stage, + desc_vertex_stage, &vertex_buffer_mappings, desc.layout, primitive_class, @@ -1167,12 +1175,12 @@ impl crate::Device for super::Device { None => None, }; - if desc.layout.total_counters.vs.buffers + (desc.vertex_buffers.len() as u32) + if desc.layout.total_counters.vs.buffers + (desc_vertex_buffers.len() as u32) > self.shared.private_caps.max_vertex_buffers { let msg = format!( "pipeline needs too many buffers in the vertex stage: {} vertex and {} layout", - desc.vertex_buffers.len(), + desc_vertex_buffers.len(), desc.layout.total_counters.vs.buffers ); return Err(crate::PipelineError::Linkage( @@ -1181,9 +1189,9 @@ impl crate::Device for super::Device { )); } - if !desc.vertex_buffers.is_empty() { + if !desc_vertex_buffers.is_empty() { let vertex_descriptor = metal::VertexDescriptor::new(); - for (i, vb) in desc.vertex_buffers.iter().enumerate() { + for (i, vb) in desc_vertex_buffers.iter().enumerate() { let buffer_index = self.shared.private_caps.max_vertex_buffers as u64 - 1 - i as u64; let buffer_desc = vertex_descriptor.layouts().object_at(buffer_index).unwrap(); @@ -1269,17 +1277,6 @@ impl crate::Device for super::Device { }) } - unsafe fn create_mesh_pipeline( - &self, - _desc: &crate::MeshPipelineDescriptor< - ::PipelineLayout, - ::ShaderModule, - ::PipelineCache, - >, - ) -> Result<::RenderPipeline, crate::PipelineError> { - unreachable!() - 
} - unsafe fn destroy_render_pipeline(&self, _pipeline: super::RenderPipeline) { self.counters.render_pipelines.sub(1); } diff --git a/wgpu-hal/src/noop/mod.rs b/wgpu-hal/src/noop/mod.rs index f5f9853928..ab983297eb 100644 --- a/wgpu-hal/src/noop/mod.rs +++ b/wgpu-hal/src/noop/mod.rs @@ -176,6 +176,11 @@ const CAPABILITIES: crate::Capabilities = { max_subgroup_size: ALLOC_MAX_U32, max_push_constant_size: ALLOC_MAX_U32, max_non_sampler_bindings: ALLOC_MAX_U32, + + max_task_workgroup_total_count: 0, + max_task_workgroups_per_dimension: 0, + max_mesh_multiview_count: 0, + max_mesh_output_layers: 0, }, alignments: crate::Alignments { // All maximally permissive @@ -368,16 +373,6 @@ impl crate::Device for Context { ) -> Result { Ok(Resource) } - unsafe fn create_mesh_pipeline( - &self, - desc: &crate::MeshPipelineDescriptor< - ::PipelineLayout, - ::ShaderModule, - ::PipelineCache, - >, - ) -> Result<::RenderPipeline, crate::PipelineError> { - Ok(Resource) - } unsafe fn destroy_render_pipeline(&self, pipeline: Resource) {} unsafe fn create_compute_pipeline( &self, diff --git a/wgpu-hal/src/vulkan/adapter.rs b/wgpu-hal/src/vulkan/adapter.rs index 8d315f042b..381baf9a95 100644 --- a/wgpu-hal/src/vulkan/adapter.rs +++ b/wgpu-hal/src/vulkan/adapter.rs @@ -909,7 +909,7 @@ pub struct PhysicalDeviceProperties { /// Additional `vk::PhysicalDevice` properties from the /// `VK_EXT_mesh_shader` extension. - _mesh_shader: Option>, + mesh_shader: Option>, /// The device API version. 
/// @@ -1136,6 +1136,20 @@ impl PhysicalDeviceProperties { let max_compute_workgroups_per_dimension = limits.max_compute_work_group_count[0] .min(limits.max_compute_work_group_count[1]) .min(limits.max_compute_work_group_count[2]); + let ( + max_task_workgroup_total_count, + max_task_workgroups_per_dimension, + max_mesh_multiview_count, + max_mesh_output_layers, + ) = match self.mesh_shader { + Some(m) => ( + m.max_task_work_group_total_count, + m.max_task_work_group_count.into_iter().min().unwrap(), + m.max_mesh_multiview_view_count, + m.max_mesh_output_layers, + ), + None => (0, 0, 0, 0), + }; // Prevent very large buffers on mesa and most android devices. let is_nvidia = self.properties.vendor_id == crate::auxil::db::nvidia::VENDOR; @@ -1231,6 +1245,10 @@ impl PhysicalDeviceProperties { max_compute_workgroups_per_dimension, max_buffer_size, max_non_sampler_bindings: u32::MAX, + max_task_workgroup_total_count, + max_task_workgroups_per_dimension, + max_mesh_multiview_count, + max_mesh_output_layers, } } @@ -1361,7 +1379,7 @@ impl super::InstanceShared { if supports_mesh_shader { let next = capabilities - ._mesh_shader + .mesh_shader .insert(vk::PhysicalDeviceMeshShaderPropertiesEXT::default()); properties2 = properties2.push_next(next); } diff --git a/wgpu-hal/src/vulkan/conv.rs b/wgpu-hal/src/vulkan/conv.rs index e1d5cb30e8..ef9108ac10 100644 --- a/wgpu-hal/src/vulkan/conv.rs +++ b/wgpu-hal/src/vulkan/conv.rs @@ -749,6 +749,12 @@ pub fn map_shader_stage(stage: wgt::ShaderStages) -> vk::ShaderStageFlags { if stage.contains(wgt::ShaderStages::COMPUTE) { flags |= vk::ShaderStageFlags::COMPUTE; } + if stage.contains(wgt::ShaderStages::TASK) { + flags |= vk::ShaderStageFlags::TASK_EXT; + } + if stage.contains(wgt::ShaderStages::MESH) { + flags |= vk::ShaderStageFlags::MESH_EXT; + } flags } diff --git a/wgpu-hal/src/vulkan/device.rs b/wgpu-hal/src/vulkan/device.rs index b71cd93e3d..4c72ad35e8 100644 --- a/wgpu-hal/src/vulkan/device.rs +++ 
b/wgpu-hal/src/vulkan/device.rs @@ -1907,25 +1907,32 @@ impl crate::Device for super::Device { ..Default::default() }; let mut stages = ArrayVec::<_, { crate::MAX_CONCURRENT_SHADER_STAGES }>::new(); - let mut vertex_buffers = Vec::with_capacity(desc.vertex_buffers.len()); + let mut vertex_buffers = Vec::new(); let mut vertex_attributes = Vec::new(); - for (i, vb) in desc.vertex_buffers.iter().enumerate() { - vertex_buffers.push(vk::VertexInputBindingDescription { - binding: i as u32, - stride: vb.array_stride as u32, - input_rate: match vb.step_mode { - wgt::VertexStepMode::Vertex => vk::VertexInputRate::VERTEX, - wgt::VertexStepMode::Instance => vk::VertexInputRate::INSTANCE, - }, - }); - for at in vb.attributes { - vertex_attributes.push(vk::VertexInputAttributeDescription { - location: at.shader_location, + if let crate::VertexProcessor::Standard { + vertex_buffers: desc_vertex_buffers, + vertex_stage: _, + } = &desc.vertex_processor + { + vertex_buffers = Vec::with_capacity(desc_vertex_buffers.len()); + for (i, vb) in desc_vertex_buffers.iter().enumerate() { + vertex_buffers.push(vk::VertexInputBindingDescription { binding: i as u32, - format: conv::map_vertex_format(at.format), - offset: at.offset as u32, + stride: vb.array_stride as u32, + input_rate: match vb.step_mode { + wgt::VertexStepMode::Vertex => vk::VertexInputRate::VERTEX, + wgt::VertexStepMode::Instance => vk::VertexInputRate::INSTANCE, + }, }); + for at in vb.attributes { + vertex_attributes.push(vk::VertexInputAttributeDescription { + location: at.shader_location, + binding: i as u32, + format: conv::map_vertex_format(at.format), + offset: at.offset as u32, + }); + } } } @@ -1937,12 +1944,41 @@ impl crate::Device for super::Device { .topology(conv::map_topology(desc.primitive.topology)) .primitive_restart_enable(desc.primitive.strip_index_format.is_some()); - let compiled_vs = self.compile_stage( - &desc.vertex_stage, - naga::ShaderStage::Vertex, - &desc.layout.binding_arrays, - )?; - 
stages.push(compiled_vs.create_info); + let mut compiled_vs = None; + let mut compiled_ms = None; + let mut compiled_ts = None; + match &desc.vertex_processor { + crate::VertexProcessor::Standard { + vertex_buffers: _, + vertex_stage, + } => { + compiled_vs = Some(self.compile_stage( + vertex_stage, + naga::ShaderStage::Vertex, + &desc.layout.binding_arrays, + )?); + stages.push(compiled_vs.as_ref().unwrap().create_info); + } + crate::VertexProcessor::Mesh { + task_stage, + mesh_stage, + } => { + if let Some(t) = task_stage.as_ref() { + compiled_ts = Some(self.compile_stage( + t, + naga::ShaderStage::Task, + &desc.layout.binding_arrays, + )?); + stages.push(compiled_ts.as_ref().unwrap().create_info); + } + compiled_ms = Some(self.compile_stage( + mesh_stage, + naga::ShaderStage::Mesh, + &desc.layout.binding_arrays, + )?); + stages.push(compiled_ms.as_ref().unwrap().create_info); + } + } let compiled_fs = match desc.fragment_stage { Some(ref stage) => { let compiled = self.compile_stage( @@ -2105,228 +2141,13 @@ impl crate::Device for super::Device { unsafe { self.shared.set_object_name(raw, label) }; } - if let Some(raw_module) = compiled_vs.temp_raw_module { - unsafe { self.shared.raw.destroy_shader_module(raw_module, None) }; - } if let Some(CompiledStage { temp_raw_module: Some(raw_module), .. 
- }) = compiled_fs + }) = compiled_vs { unsafe { self.shared.raw.destroy_shader_module(raw_module, None) }; } - - self.counters.render_pipelines.add(1); - - Ok(super::RenderPipeline { raw }) - } - unsafe fn create_mesh_pipeline( - &self, - desc: &crate::MeshPipelineDescriptor< - ::PipelineLayout, - ::ShaderModule, - ::PipelineCache, - >, - ) -> Result<::RenderPipeline, crate::PipelineError> { - let dynamic_states = [ - vk::DynamicState::VIEWPORT, - vk::DynamicState::SCISSOR, - vk::DynamicState::BLEND_CONSTANTS, - vk::DynamicState::STENCIL_REFERENCE, - ]; - let mut compatible_rp_key = super::RenderPassKey { - sample_count: desc.multisample.count, - multiview: desc.multiview, - ..Default::default() - }; - let mut stages = ArrayVec::<_, { crate::MAX_CONCURRENT_SHADER_STAGES }>::new(); - - let vk_input_assembly = vk::PipelineInputAssemblyStateCreateInfo::default() - .topology(conv::map_topology(desc.primitive.topology)) - .primitive_restart_enable(desc.primitive.strip_index_format.is_some()); - - let compiled_ts = match desc.task_stage { - Some(ref stage) => { - let mut compiled = self.compile_stage( - stage, - naga::ShaderStage::Task, - &desc.layout.binding_arrays, - )?; - compiled.create_info.stage = vk::ShaderStageFlags::TASK_EXT; - stages.push(compiled.create_info); - Some(compiled) - } - None => None, - }; - - let mut compiled_ms = self.compile_stage( - &desc.mesh_stage, - naga::ShaderStage::Mesh, - &desc.layout.binding_arrays, - )?; - compiled_ms.create_info.stage = vk::ShaderStageFlags::MESH_EXT; - stages.push(compiled_ms.create_info); - let compiled_fs = match desc.fragment_stage { - Some(ref stage) => { - let compiled = self.compile_stage( - stage, - naga::ShaderStage::Fragment, - &desc.layout.binding_arrays, - )?; - stages.push(compiled.create_info); - Some(compiled) - } - None => None, - }; - - let mut vk_rasterization = vk::PipelineRasterizationStateCreateInfo::default() - .polygon_mode(conv::map_polygon_mode(desc.primitive.polygon_mode)) - 
.front_face(conv::map_front_face(desc.primitive.front_face)) - .line_width(1.0) - .depth_clamp_enable(desc.primitive.unclipped_depth); - if let Some(face) = desc.primitive.cull_mode { - vk_rasterization = vk_rasterization.cull_mode(conv::map_cull_face(face)) - } - let mut vk_rasterization_conservative_state = - vk::PipelineRasterizationConservativeStateCreateInfoEXT::default() - .conservative_rasterization_mode( - vk::ConservativeRasterizationModeEXT::OVERESTIMATE, - ); - if desc.primitive.conservative { - vk_rasterization = vk_rasterization.push_next(&mut vk_rasterization_conservative_state); - } - - let mut vk_depth_stencil = vk::PipelineDepthStencilStateCreateInfo::default(); - if let Some(ref ds) = desc.depth_stencil { - let vk_format = self.shared.private_caps.map_texture_format(ds.format); - let vk_layout = if ds.is_read_only(desc.primitive.cull_mode) { - vk::ImageLayout::DEPTH_STENCIL_READ_ONLY_OPTIMAL - } else { - vk::ImageLayout::DEPTH_STENCIL_ATTACHMENT_OPTIMAL - }; - compatible_rp_key.depth_stencil = Some(super::DepthStencilAttachmentKey { - base: super::AttachmentKey::compatible(vk_format, vk_layout), - stencil_ops: crate::AttachmentOps::all(), - }); - - if ds.is_depth_enabled() { - vk_depth_stencil = vk_depth_stencil - .depth_test_enable(true) - .depth_write_enable(ds.depth_write_enabled) - .depth_compare_op(conv::map_comparison(ds.depth_compare)); - } - if ds.stencil.is_enabled() { - let s = &ds.stencil; - let front = conv::map_stencil_face(&s.front, s.read_mask, s.write_mask); - let back = conv::map_stencil_face(&s.back, s.read_mask, s.write_mask); - vk_depth_stencil = vk_depth_stencil - .stencil_test_enable(true) - .front(front) - .back(back); - } - - if ds.bias.is_enabled() { - vk_rasterization = vk_rasterization - .depth_bias_enable(true) - .depth_bias_constant_factor(ds.bias.constant as f32) - .depth_bias_clamp(ds.bias.clamp) - .depth_bias_slope_factor(ds.bias.slope_scale); - } - } - - let vk_viewport = 
vk::PipelineViewportStateCreateInfo::default() - .flags(vk::PipelineViewportStateCreateFlags::empty()) - .scissor_count(1) - .viewport_count(1); - - let vk_sample_mask = [ - desc.multisample.mask as u32, - (desc.multisample.mask >> 32) as u32, - ]; - let vk_multisample = vk::PipelineMultisampleStateCreateInfo::default() - .rasterization_samples(vk::SampleCountFlags::from_raw(desc.multisample.count)) - .alpha_to_coverage_enable(desc.multisample.alpha_to_coverage_enabled) - .sample_mask(&vk_sample_mask); - - let mut vk_attachments = Vec::with_capacity(desc.color_targets.len()); - for cat in desc.color_targets { - let (key, attarchment) = if let Some(cat) = cat.as_ref() { - let mut vk_attachment = vk::PipelineColorBlendAttachmentState::default() - .color_write_mask(vk::ColorComponentFlags::from_raw(cat.write_mask.bits())); - if let Some(ref blend) = cat.blend { - let (color_op, color_src, color_dst) = conv::map_blend_component(&blend.color); - let (alpha_op, alpha_src, alpha_dst) = conv::map_blend_component(&blend.alpha); - vk_attachment = vk_attachment - .blend_enable(true) - .color_blend_op(color_op) - .src_color_blend_factor(color_src) - .dst_color_blend_factor(color_dst) - .alpha_blend_op(alpha_op) - .src_alpha_blend_factor(alpha_src) - .dst_alpha_blend_factor(alpha_dst); - } - - let vk_format = self.shared.private_caps.map_texture_format(cat.format); - ( - Some(super::ColorAttachmentKey { - base: super::AttachmentKey::compatible( - vk_format, - vk::ImageLayout::COLOR_ATTACHMENT_OPTIMAL, - ), - resolve: None, - }), - vk_attachment, - ) - } else { - (None, vk::PipelineColorBlendAttachmentState::default()) - }; - - compatible_rp_key.colors.push(key); - vk_attachments.push(attarchment); - } - - let vk_color_blend = - vk::PipelineColorBlendStateCreateInfo::default().attachments(&vk_attachments); - - let vk_dynamic_state = - vk::PipelineDynamicStateCreateInfo::default().dynamic_states(&dynamic_states); - - let raw_pass = 
self.shared.make_render_pass(compatible_rp_key)?; - - let vk_infos = [{ - vk::GraphicsPipelineCreateInfo::default() - .layout(desc.layout.raw) - .stages(&stages) - .input_assembly_state(&vk_input_assembly) - .rasterization_state(&vk_rasterization) - .viewport_state(&vk_viewport) - .multisample_state(&vk_multisample) - .depth_stencil_state(&vk_depth_stencil) - .color_blend_state(&vk_color_blend) - .dynamic_state(&vk_dynamic_state) - .render_pass(raw_pass) - }]; - - let pipeline_cache = desc - .cache - .map(|it| it.raw) - .unwrap_or(vk::PipelineCache::null()); - - let mut raw_vec = { - profiling::scope!("vkCreateGraphicsPipelines"); - unsafe { - self.shared - .raw - .create_graphics_pipelines(pipeline_cache, &vk_infos, None) - .map_err(|(_, e)| super::map_pipeline_err(e)) - }? - }; - - let raw = raw_vec.pop().unwrap(); - if let Some(label) = desc.label { - unsafe { self.shared.set_object_name(raw, label) }; - } - // NOTE: this could leak shaders in case of an error. if let Some(CompiledStage { temp_raw_module: Some(raw_module), .. @@ -2334,7 +2155,11 @@ impl crate::Device for super::Device { { unsafe { self.shared.raw.destroy_shader_module(raw_module, None) }; } - if let Some(raw_module) = compiled_ms.temp_raw_module { + if let Some(CompiledStage { + temp_raw_module: Some(raw_module), + .. 
+ }) = compiled_ms + { unsafe { self.shared.raw.destroy_shader_module(raw_module, None) }; } if let Some(CompiledStage { diff --git a/wgpu-info/src/human.rs b/wgpu-info/src/human.rs index 1c2f7a841d..f6f4cbdba0 100644 --- a/wgpu-info/src/human.rs +++ b/wgpu-info/src/human.rs @@ -160,6 +160,11 @@ fn print_adapter(output: &mut impl io::Write, report: &AdapterReport, idx: usize max_subgroup_size, max_push_constant_size, max_non_sampler_bindings, + + max_task_workgroup_total_count, + max_task_workgroups_per_dimension, + max_mesh_multiview_count, + max_mesh_output_layers, } = limits; writeln!(output, "\t\t Max Texture Dimension 1d: {max_texture_dimension_1d}")?; writeln!(output, "\t\t Max Texture Dimension 2d: {max_texture_dimension_2d}")?; @@ -196,6 +201,10 @@ fn print_adapter(output: &mut impl io::Write, report: &AdapterReport, idx: usize writeln!(output, "\t\t Max Compute Workgroup Size Y: {max_compute_workgroup_size_y}")?; writeln!(output, "\t\t Max Compute Workgroup Size Z: {max_compute_workgroup_size_z}")?; writeln!(output, "\t\t Max Compute Workgroups Per Dimension: {max_compute_workgroups_per_dimension}")?; + writeln!(output, "\t\t Max Task Workgroup Total Count: {max_task_workgroup_total_count}")?; + writeln!(output, "\t\t Max Task Workgroups Per Dimension: {max_task_workgroups_per_dimension}")?; + writeln!(output, "\t\t Max Mesh Multiview Count: {max_mesh_multiview_count}")?; + writeln!(output, "\t\t Max Mesh Output Layers: {max_mesh_output_layers}")?; // This one reflects more of a wgpu implementation limitations than a hardware limit // so don't show it here. diff --git a/wgpu-types/src/lib.rs b/wgpu-types/src/lib.rs index e619bfd9dc..0b515386d6 100644 --- a/wgpu-types/src/lib.rs +++ b/wgpu-types/src/lib.rs @@ -600,6 +600,16 @@ pub struct Limits { /// This limit only affects the d3d12 backend. Using a large number will allow the device /// to create many bind groups at the cost of a large up-front allocation at device creation. 
pub max_non_sampler_bindings: u32, + + /// The maximum total value of x*y*z for a given `draw_mesh_tasks` command + pub max_task_workgroup_total_count: u32, + /// The maximum value for each dimension of a `RenderPass::draw_mesh_tasks(x, y, z)` operation. + /// Defaults to 65535. Higher is "better". + pub max_task_workgroups_per_dimension: u32, + /// The maximum number of layers that can be output from a mesh shader + pub max_mesh_output_layers: u32, + /// The maximum number of views that can be used by a mesh shader + pub max_mesh_multiview_count: u32, } impl Default for Limits { @@ -649,6 +659,14 @@ impl Limits { max_subgroup_size: 0, max_push_constant_size: 0, max_non_sampler_bindings: 1_000_000, + + // Literally just made this up as 1024^2. + // My GPU supports 4 times this, and compute shaders don't have this kind of limit. + // This very likely is never a real limiter + max_task_workgroup_total_count: 1048576, + max_task_workgroups_per_dimension: 65535, + max_mesh_multiview_count: 1, + max_mesh_output_layers: 1024, } } @@ -694,6 +712,11 @@ impl Limits { /// max_compute_workgroups_per_dimension: 65535, /// max_buffer_size: 256 << 20, // (256 MiB) /// max_non_sampler_bindings: 1_000_000, + /// + /// max_task_workgroup_total_count: 0, + /// max_task_workgroups_per_dimension: 0, + /// max_mesh_multiview_count: 0, + /// max_mesh_output_layers: 0, /// }); /// ``` #[must_use] @@ -707,6 +730,11 @@ impl Limits { max_color_attachments: 4, // see: https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf#page=7 max_compute_workgroup_storage_size: 16352, + + max_task_workgroups_per_dimension: 0, + max_task_workgroup_total_count: 0, + max_mesh_multiview_count: 0, + max_mesh_output_layers: 0, ..Self::defaults() } } @@ -754,6 +782,11 @@ impl Limits { /// max_compute_workgroups_per_dimension: 0, // + /// max_buffer_size: 256 << 20, // (256 MiB), /// max_non_sampler_bindings: 1_000_000, + /// + /// max_task_workgroup_total_count: 0, + /// 
max_task_workgroups_per_dimension: 0, + /// max_mesh_multiview_count: 0, + /// max_mesh_output_layers: 0, /// }); /// ``` #[must_use] @@ -777,6 +810,11 @@ impl Limits { // Value supported by Intel Celeron B830 on Windows (OpenGL 3.1) max_inter_stage_shader_components: 31, + max_task_workgroups_per_dimension: 0, + max_task_workgroup_total_count: 0, + max_mesh_multiview_count: 0, + max_mesh_output_layers: 0, + // Most of the values should be the same as the downlevel defaults ..Self::downlevel_defaults() } @@ -888,6 +926,13 @@ impl Limits { } compare!(max_push_constant_size, Less); compare!(max_non_sampler_bindings, Less); + + if self.max_task_workgroup_total_count > 0 { + compare!(max_task_workgroup_total_count, Less); + compare!(max_task_workgroups_per_dimension, Less); + compare!(max_mesh_multiview_count, Less); + compare!(max_mesh_output_layers, Less); + } } } @@ -1305,9 +1350,9 @@ bitflags::bitflags! { const COMPUTE = 1 << 2; /// Binding is visible from the vertex and fragment shaders of a render pipeline. const VERTEX_FRAGMENT = Self::VERTEX.bits() | Self::FRAGMENT.bits(); - /// Binding is visible from the task shader of a mesh pipeline + /// Binding is visible from the task shader of a mesh pipeline. const TASK = 1 << 3; - /// Binding is visible from the mesh shader of a mesh pipeline + /// Binding is visible from the mesh shader of a mesh pipeline. const MESH = 1 << 4; } } diff --git a/wgpu/src/api/device.rs b/wgpu/src/api/device.rs index 992d8e5717..c3dca103ac 100644 --- a/wgpu/src/api/device.rs +++ b/wgpu/src/api/device.rs @@ -240,6 +240,13 @@ impl Device { RenderPipeline { inner: pipeline } } + /// Creates a mesh shader based [`RenderPipeline`]. + #[must_use] + pub fn create_mesh_pipeline(&self, desc: &MeshPipelineDescriptor<'_>) -> RenderPipeline { + let pipeline = self.inner.create_mesh_pipeline(desc); + RenderPipeline { inner: pipeline } + } + /// Creates a [`ComputePipeline`]. 
#[must_use] pub fn create_compute_pipeline(&self, desc: &ComputePipelineDescriptor<'_>) -> ComputePipeline { diff --git a/wgpu/src/api/render_pass.rs b/wgpu/src/api/render_pass.rs index 6c30543f62..1c46b5386c 100644 --- a/wgpu/src/api/render_pass.rs +++ b/wgpu/src/api/render_pass.rs @@ -226,6 +226,12 @@ impl RenderPass<'_> { self.inner.draw_indexed(indices, base_vertex, instances); } + /// Draws using a mesh shader pipeline + pub fn draw_mesh_tasks(&mut self, group_count_x: u32, group_count_y: u32, group_count_z: u32) { + self.inner + .draw_mesh_tasks(group_count_x, group_count_y, group_count_z); + } + /// Draws primitives from the active vertex buffer(s) based on the contents of the `indirect_buffer`. /// /// This is like calling [`RenderPass::draw`] but the contents of the call are specified in the `indirect_buffer`. @@ -267,6 +273,25 @@ impl RenderPass<'_> { .draw_indexed_indirect(&indirect_buffer.inner, indirect_offset); } + /// Draws using a mesh shader pipeline, + /// based on the contents of the `indirect_buffer` + /// + /// This is like calling [`RenderPass::draw_mesh_tasks`] but the contents of the call are specified in the `indirect_buffer`. + /// The structure expected in the `indirect_buffer` must conform to [`DispatchIndirectArgs`](crate::util::DispatchIndirectArgs). + /// + /// Indirect drawing has some caveats depending on the features available. We are not currently able to validate + /// these and issue an error. + /// + /// See details on the individual flags for more information. + pub fn draw_mesh_tasks_indirect( + &mut self, + indirect_buffer: &Buffer, + indirect_offset: BufferAddress, + ) { + self.inner + .draw_mesh_tasks_indirect(&indirect_buffer.inner, indirect_offset); + } + /// Execute a [render bundle][RenderBundle], which is a set of pre-recorded commands /// that can be run together. 
/// @@ -324,6 +349,23 @@ impl RenderPass<'_> { self.inner .multi_draw_indexed_indirect(&indirect_buffer.inner, indirect_offset, count); } + + /// Dispatches multiple draw calls based on the contents of the `indirect_buffer`. + /// `count` draw calls are issued. + /// + /// The structure expected in the `indirect_buffer` must conform to [`DispatchIndirectArgs`](crate::util::DispatchIndirectArgs). + /// + /// This drawing command uses the current render state, as set by preceding `set_*()` methods. + /// It is not affected by changes to the state that are performed after it is called. + pub fn multi_draw_mesh_tasks_indirect( + &mut self, + indirect_buffer: &Buffer, + indirect_offset: BufferAddress, + count: u32, + ) { + self.inner + .multi_draw_mesh_tasks_indirect(&indirect_buffer.inner, indirect_offset, count); + } } /// [`Features::MULTI_DRAW_INDIRECT_COUNT`] must be enabled on the device in order to call these functions. @@ -407,6 +449,34 @@ impl RenderPass<'_> { max_count, ); } + + /// Dispatches multiple draw calls based on the contents of the `indirect_buffer`. The count buffer is read to determine how many draws to issue. + /// + /// The indirect buffer must be long enough to account for `max_count` draws, however only `count` + /// draws will be read. If `count` is greater than `max_count`, `max_count` will be used. + /// + /// The structure expected in the `indirect_buffer` must conform to [`DispatchIndirectArgs`](crate::util::DispatchIndirectArgs). + /// + /// These draw structures are expected to be tightly packed. + /// + /// This drawing command uses the current render state, as set by preceding `set_*()` methods. + /// It is not affected by changes to the state that are performed after it is called. 
+ pub fn multi_draw_mesh_tasks_indirect_count( + &mut self, + indirect_buffer: &Buffer, + indirect_offset: BufferAddress, + count_buffer: &Buffer, + count_offset: BufferAddress, + max_count: u32, + ) { + self.inner.multi_draw_mesh_tasks_indirect_count( + &indirect_buffer.inner, + indirect_offset, + &count_buffer.inner, + count_offset, + max_count, + ); + } } /// [`Features::PUSH_CONSTANTS`] must be enabled on the device in order to call these functions. diff --git a/wgpu/src/api/render_pipeline.rs b/wgpu/src/api/render_pipeline.rs index b033b2bfda..1f84c9971d 100644 --- a/wgpu/src/api/render_pipeline.rs +++ b/wgpu/src/api/render_pipeline.rs @@ -139,6 +139,48 @@ pub struct FragmentState<'a> { #[cfg(send_sync)] static_assertions::assert_impl_all!(FragmentState<'_>: Send, Sync); +/// Describes the task shader stage in a mesh shader pipeline. +/// +/// For use in [`MeshPipelineDescriptor`] +#[derive(Clone, Debug)] +pub struct TaskState<'a> { + /// The compiled shader module for this stage. + pub module: &'a ShaderModule, + /// The name of the entry point in the compiled shader to use. + /// + /// If [`Some`], there must be a task-stage shader entry point with this name in `module`. + /// Otherwise, expect exactly one task-stage entry point in `module`, which will be + /// selected. + pub entry_point: Option<&'a str>, + /// Advanced options for when this pipeline is compiled + /// + /// This implements `Default`, and for most users can be set to `Default::default()` + pub compilation_options: PipelineCompilationOptions<'a>, +} +#[cfg(send_sync)] +static_assertions::assert_impl_all!(TaskState<'_>: Send, Sync); + +/// Describes the mesh shader stage in a mesh shader pipeline. +/// +/// For use in [`MeshPipelineDescriptor`] +#[derive(Clone, Debug)] +pub struct MeshState<'a> { + /// The compiled shader module for this stage. + pub module: &'a ShaderModule, + /// The name of the entry point in the compiled shader to use.
+ /// + /// If [`Some`], there must be a mesh-stage shader entry point with this name in `module`. + /// Otherwise, expect exactly one mesh-stage entry point in `module`, which will be + /// selected. + pub entry_point: Option<&'a str>, + /// Advanced options for when this pipeline is compiled + /// + /// This implements `Default`, and for most users can be set to `Default::default()` + pub compilation_options: PipelineCompilationOptions<'a>, +} +#[cfg(send_sync)] +static_assertions::assert_impl_all!(MeshState<'_>: Send, Sync); + /// Describes a render (graphics) pipeline. /// /// For use with [`Device::create_render_pipeline`]. @@ -187,3 +229,51 @@ pub struct RenderPipelineDescriptor<'a> { } #[cfg(send_sync)] static_assertions::assert_impl_all!(RenderPipelineDescriptor<'_>: Send, Sync); + +/// Describes a mesh shader (graphics) pipeline. +/// +/// For use with [`Device::create_mesh_pipeline`]. +#[derive(Clone, Debug)] +pub struct MeshPipelineDescriptor<'a> { + /// Debug label of the pipeline. This will show up in graphics debuggers for easy identification. + pub label: Label<'a>, + /// The layout of bind groups for this pipeline. + /// + /// If this is set, then [`Device::create_mesh_pipeline`] will raise a validation error if + /// the layout doesn't match what the shader module(s) expect. + /// + /// Using the same [`PipelineLayout`] for many [`RenderPipeline`] or [`ComputePipeline`] + /// pipelines guarantees that you don't have to rebind any resources when switching between + /// those pipelines. + /// + /// ## Default pipeline layout + /// + /// If `layout` is `None`, then the pipeline has a [default layout] created and used instead. + /// The default layout is deduced from the shader modules. + /// + /// You can use [`RenderPipeline::get_bind_group_layout`] to create bind groups for use with the + /// default layout. However, these bind groups cannot be used with any other pipelines.
This is + /// convenient for simple pipelines, but using an explicit layout is recommended in most cases. + /// + /// [default layout]: https://www.w3.org/TR/webgpu/#default-pipeline-layout + pub layout: Option<&'a PipelineLayout>, + /// The compiled task stage and its entry point, if a task shader is used. + pub task: Option<TaskState<'a>>, + /// The compiled mesh stage and its entry point. + pub mesh: MeshState<'a>, + /// The properties of the pipeline at the primitive assembly and rasterization level. + pub primitive: PrimitiveState, + /// The effect of draw calls on the depth and stencil aspects of the output target, if any. + pub depth_stencil: Option<DepthStencilState>, + /// The multi-sampling properties of the pipeline. + pub multisample: MultisampleState, + /// The compiled fragment stage, its entry point, and the color targets. + pub fragment: Option<FragmentState<'a>>, + /// If the pipeline will be used with a multiview render pass, this indicates how many array + /// layers the attachments will have. + pub multiview: Option<NonZeroU32>, + /// The pipeline cache to use when creating this pipeline.
+ pub cache: Option<&'a PipelineCache>, +} +#[cfg(send_sync)] +static_assertions::assert_impl_all!(MeshPipelineDescriptor<'_>: Send, Sync); diff --git a/wgpu/src/backend/webgpu.rs b/wgpu/src/backend/webgpu.rs index d7a76384f9..d0c93a185c 100644 --- a/wgpu/src/backend/webgpu.rs +++ b/wgpu/src/backend/webgpu.rs @@ -816,6 +816,13 @@ fn map_wgt_limits(limits: webgpu_sys::GpuSupportedLimits) -> wgt::Limits { max_push_constant_size: wgt::Limits::default().max_push_constant_size, max_non_sampler_bindings: wgt::Limits::default().max_non_sampler_bindings, max_inter_stage_shader_components: wgt::Limits::default().max_inter_stage_shader_components, + + max_mesh_invocations_per_workgroup: wgt::Limits::default() + .max_mesh_invocations_per_workgroup, + max_mesh_workgroup_size_x: wgt::Limits::default().max_mesh_workgroup_size_x, + max_mesh_workgroup_size_y: wgt::Limits::default().max_mesh_workgroup_size_y, + max_mesh_workgroup_size_z: wgt::Limits::default().max_mesh_workgroup_size_z, + max_mesh_workgroups_per_dimension: wgt::Limits::default().max_mesh_workgroups_per_dimension, } } @@ -2156,6 +2163,13 @@ impl dispatch::DeviceInterface for WebDevice { .into() } + fn create_mesh_pipeline( + &self, + _desc: &crate::MeshPipelineDescriptor<'_>, + ) -> dispatch::DispatchRenderPipeline { + panic!("MESH_SHADER feature must be enabled to call create_mesh_pipeline") + } + fn create_compute_pipeline( &self, desc: &crate::ComputePipelineDescriptor<'_>, @@ -3352,6 +3366,10 @@ impl dispatch::RenderPassInterface for WebRenderPassEncoder { ) } + fn draw_mesh_tasks(&mut self, _group_count_x: u32, _group_count_y: u32, _group_count_z: u32) { + panic!("MESH_SHADER feature must be enabled to call draw_mesh_tasks") + } + fn draw_indirect( &mut self, indirect_buffer: &dispatch::DispatchBuffer, @@ -3372,6 +3390,14 @@ impl dispatch::RenderPassInterface for WebRenderPassEncoder { .draw_indexed_indirect_with_f64(&buffer.inner, indirect_offset as f64); } + fn draw_mesh_tasks_indirect( + &mut self, + 
_indirect_buffer: &dispatch::DispatchBuffer, + _indirect_offset: crate::BufferAddress, + ) { + panic!("MESH_SHADER feature must be enabled to call draw_mesh_tasks_indirect") + } + fn multi_draw_indirect( &mut self, indirect_buffer: &dispatch::DispatchBuffer, @@ -3402,6 +3428,15 @@ impl dispatch::RenderPassInterface for WebRenderPassEncoder { } } + fn multi_draw_mesh_tasks_indirect( + &mut self, + _indirect_buffer: &dispatch::DispatchBuffer, + _indirect_offset: crate::BufferAddress, + _count: u32, + ) { + panic!("MESH_SHADER feature must be enabled to call multi_draw_mesh_tasks_indirect") + } + fn multi_draw_indirect_count( &mut self, _indirect_buffer: &dispatch::DispatchBuffer, @@ -3426,6 +3461,17 @@ impl dispatch::RenderPassInterface for WebRenderPassEncoder { panic!("MULTI_DRAW_INDIRECT_COUNT feature must be enabled to call multi_draw_indexed_indirect_count") } + fn multi_draw_mesh_tasks_indirect_count( + &mut self, + _indirect_buffer: &dispatch::DispatchBuffer, + _indirect_offset: crate::BufferAddress, + _count_buffer: &dispatch::DispatchBuffer, + _count_buffer_offset: crate::BufferAddress, + _max_count: u32, + ) { + panic!("MESH_SHADER feature must be enabled to call multi_draw_mesh_tasks_indirect_count") + } + fn insert_debug_marker(&mut self, _label: &str) { // Not available in gecko yet // self.inner.insert_debug_marker(label); diff --git a/wgpu/src/backend/wgpu_core.rs b/wgpu/src/backend/wgpu_core.rs index e81d02c24b..5fc9b54f7b 100644 --- a/wgpu/src/backend/wgpu_core.rs +++ b/wgpu/src/backend/wgpu_core.rs @@ -1352,6 +1352,102 @@ impl dispatch::DeviceInterface for CoreDevice { .into() } + fn create_mesh_pipeline( + &self, + desc: &crate::MeshPipelineDescriptor<'_>, + ) -> dispatch::DispatchRenderPipeline { + use wgc::pipeline as pipe; + + let mesh_constants = desc + .mesh + .compilation_options + .constants + .iter() + .map(|&(key, value)| (String::from(key), value)) + .collect(); + let descriptor = pipe::MeshPipelineDescriptor { + label: 
desc.label.map(Borrowed), + task: desc.task.as_ref().map(|task| { + let task_constants = task + .compilation_options + .constants + .iter() + .map(|&(key, value)| (String::from(key), value)) + .collect(); + pipe::TaskState { + stage: pipe::ProgrammableStageDescriptor { + module: task.module.inner.as_core().id, + entry_point: task.entry_point.map(Borrowed), + constants: task_constants, + // Each stage is configured by its own compilation options, not the mesh stage's. + zero_initialize_workgroup_memory: task + .compilation_options + .zero_initialize_workgroup_memory, + }, + } + }), + mesh: pipe::MeshState { + stage: pipe::ProgrammableStageDescriptor { + module: desc.mesh.module.inner.as_core().id, + entry_point: desc.mesh.entry_point.map(Borrowed), + constants: mesh_constants, + zero_initialize_workgroup_memory: desc + .mesh + .compilation_options + .zero_initialize_workgroup_memory, + }, + }, + layout: desc.layout.map(|layout| layout.inner.as_core().id), + primitive: desc.primitive, + depth_stencil: desc.depth_stencil.clone(), + multisample: desc.multisample, + fragment: desc.fragment.as_ref().map(|frag| { + let frag_constants = frag + .compilation_options + .constants + .iter() + .map(|&(key, value)| (String::from(key), value)) + .collect(); + pipe::FragmentState { + stage: pipe::ProgrammableStageDescriptor { + module: frag.module.inner.as_core().id, + entry_point: frag.entry_point.map(Borrowed), + constants: frag_constants, + zero_initialize_workgroup_memory: frag + .compilation_options + .zero_initialize_workgroup_memory, + }, + targets: Borrowed(frag.targets), + } + }), + multiview: desc.multiview, + cache: desc.cache.map(|cache| cache.inner.as_core().id), + }; + + let (id, error) = + self.context + .0 + .device_create_mesh_pipeline(self.id, &descriptor, None, None); + if let Some(cause) = error { + if let wgc::pipeline::CreateRenderPipelineError::Internal { stage, ref error } = cause { + log::error!("Shader translation error for stage {:?}: {}", stage, error); + log::error!("Please report it to https://github.com/gfx-rs/wgpu"); + } +
self.context.handle_error( + &self.error_sink, + cause, + desc.label, + "Device::create_mesh_pipeline", + ); + } + CoreRenderPipeline { + context: self.context.clone(), + id, + error_sink: Arc::clone(&self.error_sink), + } + .into() + } + fn create_compute_pipeline( &self, desc: &crate::ComputePipelineDescriptor<'_>, @@ -3053,6 +3149,22 @@ impl dispatch::RenderPassInterface for CoreRenderPass { } } + fn draw_mesh_tasks(&mut self, group_count_x: u32, group_count_y: u32, group_count_z: u32) { + if let Err(cause) = self.context.0.render_pass_draw_mesh_tasks( + &mut self.pass, + group_count_x, + group_count_y, + group_count_z, + ) { + self.context.handle_error( + &self.error_sink, + cause, + self.pass.label(), + "RenderPass::draw_mesh_tasks", + ); + } + } + fn draw_indirect( &mut self, indirect_buffer: &dispatch::DispatchBuffer, @@ -3095,6 +3207,27 @@ impl dispatch::RenderPassInterface for CoreRenderPass { } } + fn draw_mesh_tasks_indirect( + &mut self, + indirect_buffer: &dispatch::DispatchBuffer, + indirect_offset: crate::BufferAddress, + ) { + let indirect_buffer = indirect_buffer.as_core(); + + if let Err(cause) = self.context.0.render_pass_draw_mesh_tasks_indirect( + &mut self.pass, + indirect_buffer.id, + indirect_offset, + ) { + self.context.handle_error( + &self.error_sink, + cause, + self.pass.label(), + "RenderPass::draw_mesh_tasks_indirect", + ); + } + } + fn multi_draw_indirect( &mut self, indirect_buffer: &dispatch::DispatchBuffer, @@ -3141,6 +3274,29 @@ impl dispatch::RenderPassInterface for CoreRenderPass { } } + fn multi_draw_mesh_tasks_indirect( + &mut self, + indirect_buffer: &dispatch::DispatchBuffer, + indirect_offset: crate::BufferAddress, + count: u32, + ) { + let indirect_buffer = indirect_buffer.as_core(); + + if let Err(cause) = self.context.0.render_pass_multi_draw_mesh_tasks_indirect( + &mut self.pass, + indirect_buffer.id, + indirect_offset, + count, + ) { + self.context.handle_error( + &self.error_sink, + cause, + self.pass.label(), +
"RenderPass::multi_draw_mesh_tasks_indirect", + ); + } + } + fn multi_draw_indirect_count( &mut self, indirect_buffer: &dispatch::DispatchBuffer, @@ -3201,6 +3357,38 @@ impl dispatch::RenderPassInterface for CoreRenderPass { } } + fn multi_draw_mesh_tasks_indirect_count( + &mut self, + indirect_buffer: &dispatch::DispatchBuffer, + indirect_offset: crate::BufferAddress, + count_buffer: &dispatch::DispatchBuffer, + count_buffer_offset: crate::BufferAddress, + max_count: u32, + ) { + let indirect_buffer = indirect_buffer.as_core(); + let count_buffer = count_buffer.as_core(); + + if let Err(cause) = self + .context + .0 + .render_pass_multi_draw_mesh_tasks_indirect_count( + &mut self.pass, + indirect_buffer.id, + indirect_offset, + count_buffer.id, + count_buffer_offset, + max_count, + ) + { + self.context.handle_error( + &self.error_sink, + cause, + self.pass.label(), + "RenderPass::multi_draw_mesh_tasks_indirect_count", + ); + } + } + fn insert_debug_marker(&mut self, label: &str) { if let Err(cause) = self .context diff --git a/wgpu/src/dispatch.rs b/wgpu/src/dispatch.rs index 07924d914c..18f5b2dca2 100644 --- a/wgpu/src/dispatch.rs +++ b/wgpu/src/dispatch.rs @@ -128,6 +128,10 @@ pub trait DeviceInterface: CommonTraits { &self, desc: &crate::RenderPipelineDescriptor<'_>, ) -> DispatchRenderPipeline; + fn create_mesh_pipeline( + &self, + desc: &crate::MeshPipelineDescriptor<'_>, + ) -> DispatchRenderPipeline; fn create_compute_pipeline( &self, desc: &crate::ComputePipelineDescriptor<'_>, @@ -393,6 +397,7 @@ pub trait RenderPassInterface: CommonTraits { fn draw(&mut self, vertices: Range, instances: Range); fn draw_indexed(&mut self, indices: Range, base_vertex: i32, instances: Range); + fn draw_mesh_tasks(&mut self, group_count_x: u32, group_count_y: u32, group_count_z: u32); fn draw_indirect( &mut self, indirect_buffer: &DispatchBuffer, @@ -403,6 +408,11 @@ pub trait RenderPassInterface: CommonTraits { indirect_buffer: &DispatchBuffer, indirect_offset: 
crate::BufferAddress, ); + fn draw_mesh_tasks_indirect( + &mut self, + indirect_buffer: &DispatchBuffer, + indirect_offset: crate::BufferAddress, + ); fn multi_draw_indirect( &mut self, @@ -424,6 +434,12 @@ pub trait RenderPassInterface: CommonTraits { count_buffer_offset: crate::BufferAddress, max_count: u32, ); + fn multi_draw_mesh_tasks_indirect( + &mut self, + indirect_buffer: &DispatchBuffer, + indirect_offset: crate::BufferAddress, + count: u32, + ); fn multi_draw_indexed_indirect_count( &mut self, indirect_buffer: &DispatchBuffer, @@ -432,6 +448,14 @@ pub trait RenderPassInterface: CommonTraits { count_buffer_offset: crate::BufferAddress, max_count: u32, ); + fn multi_draw_mesh_tasks_indirect_count( + &mut self, + indirect_buffer: &DispatchBuffer, + indirect_offset: crate::BufferAddress, + count_buffer: &DispatchBuffer, + count_buffer_offset: crate::BufferAddress, + max_count: u32, + ); fn insert_debug_marker(&mut self, label: &str); fn push_debug_group(&mut self, group_label: &str);