Skip to content

Commit 850c3d4

Browse files
[naga] Write only the current entrypoint (#7626)
Changes the MSL and HLSL backends to support writing only a single entry point, and uses them that way in wgpu-hal. This is working towards a fix for #5885. * Increase the limit in test_stack_size
1 parent 9fccdf5 commit 850c3d4

File tree

12 files changed

+141
-45
lines changed

12 files changed

+141
-45
lines changed

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ Naga now infers the correct binding layout when a resource appears only in an as
7979
#### Naga
8080

8181
- Mark `readonly_and_readwrite_storage_textures` & `packed_4x8_integer_dot_product` language extensions as implemented. By @teoxoy in [#7543](https://github.com/gfx-rs/wgpu/pull/7543)
82+
- `naga::back::hlsl::Writer::new` has a new `pipeline_options` argument. `hlsl::PipelineOptions::default()` can be passed as a default. The `shader_stage` and `entry_point` members of `pipeline_options` can be used to write only a single entry point when using the HLSL and MSL backends (GLSL and SPIR-V already had this functionality). The Metal and DX12 HALs now write only a single entry point when loading shaders. By @andyleiserson in [#7626](https://github.com/gfx-rs/wgpu/pull/7626).
8283

8384
#### D3D12
8485

benches/benches/wgpu-benchmark/shader.rs

+3-1
Original file line numberDiff line numberDiff line change
@@ -349,7 +349,9 @@ fn backends(c: &mut Criterion) {
349349
let options = naga::back::hlsl::Options::default();
350350
let mut string = String::new();
351351
for input in &inputs.inner {
352-
let mut writer = naga::back::hlsl::Writer::new(&mut string, &options);
352+
let pipeline_options = Default::default();
353+
let mut writer =
354+
naga::back::hlsl::Writer::new(&mut string, &options, &pipeline_options);
353355
let _ = writer.write(
354356
input.module.as_ref().unwrap(),
355357
input.module_info.as_ref().unwrap(),

naga-cli/src/bin/naga.rs

+2-1
Original file line numberDiff line numberDiff line change
@@ -824,7 +824,8 @@ fn write_output(
824824
.unwrap_pretty();
825825

826826
let mut buffer = String::new();
827-
let mut writer = hlsl::Writer::new(&mut buffer, &params.hlsl);
827+
let pipeline_options = Default::default();
828+
let mut writer = hlsl::Writer::new(&mut buffer, &params.hlsl, &pipeline_options);
828829
writer.write(&module, &info, None).unwrap_pretty();
829830
fs::write(output_path, buffer)?;
830831
}

naga/src/back/glsl/mod.rs

+2-1
Original file line numberDiff line numberDiff line change
@@ -349,7 +349,8 @@ pub struct PipelineOptions {
349349
pub shader_stage: ShaderStage,
350350
/// The name of the entry point.
351351
///
352-
/// If no entry point that matches is found while creating a [`Writer`], a error will be thrown.
352+
/// If no entry point that matches is found while creating a [`Writer`], an
353+
/// error will be thrown.
353354
pub entry_point: String,
354355
/// How many views to render to, if doing multiview rendering.
355356
pub multiview: Option<core::num::NonZeroU32>,

naga/src/back/hlsl/mod.rs

+22-2
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ use core::fmt::Error as FmtError;
119119

120120
use thiserror::Error;
121121

122-
use crate::{back, proc};
122+
use crate::{back, ir, proc};
123123

124124
#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, Hash)]
125125
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
@@ -434,6 +434,22 @@ pub struct ReflectionInfo {
434434
pub entry_point_names: Vec<Result<String, EntryPointError>>,
435435
}
436436

437+
/// A subset of options that are meant to be changed per pipeline.
438+
#[derive(Debug, Default, Clone)]
439+
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
440+
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
441+
#[cfg_attr(feature = "deserialize", serde(default))]
442+
pub struct PipelineOptions {
443+
/// The entry point to write.
444+
///
445+
/// Entry points are identified by a shader stage specification,
446+
/// and a name.
447+
///
448+
/// If `None`, all entry points will be written. If `Some` and the entry
449+
/// point is not found, an error will be thrown while writing.
450+
pub entry_point: Option<(ir::ShaderStage, String)>,
451+
}
452+
437453
#[derive(Error, Debug)]
438454
pub enum Error {
439455
#[error(transparent)]
@@ -448,6 +464,8 @@ pub enum Error {
448464
Override,
449465
#[error(transparent)]
450466
ResolveArraySizeError(#[from] proc::ResolveArraySizeError),
467+
#[error("entry point with stage {0:?} and name '{1}' not found")]
468+
EntryPointNotFound(ir::ShaderStage, String),
451469
}
452470

453471
#[derive(PartialEq, Eq, Hash)]
@@ -519,8 +537,10 @@ pub struct Writer<'a, W> {
519537
namer: proc::Namer,
520538
/// HLSL backend options
521539
options: &'a Options,
540+
/// Per-stage backend options
541+
pipeline_options: &'a PipelineOptions,
522542
/// Information about entry point arguments and result types.
523-
entry_point_io: Vec<writer::EntryPointInterface>,
543+
entry_point_io: crate::FastHashMap<usize, writer::EntryPointInterface>,
524544
/// Set of expressions that have associated temporary variables
525545
named_expressions: crate::NamedExpressions,
526546
wrapped: Wrapped,

naga/src/back/hlsl/writer.rs

+44-18
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,10 @@ use super::{
1212
WrappedZeroValue,
1313
},
1414
storage::StoreValue,
15-
BackendResult, Error, FragmentEntryPoint, Options, ShaderModel,
15+
BackendResult, Error, FragmentEntryPoint, Options, PipelineOptions, ShaderModel,
1616
};
1717
use crate::{
18-
back::{self, Baked},
18+
back::{self, get_entry_points, Baked},
1919
common,
2020
proc::{self, index, NameKey},
2121
valid, Handle, Module, RayQueryFunction, Scalar, ScalarKind, ShaderStage, TypeInner,
@@ -123,13 +123,14 @@ struct BindingArraySamplerInfo {
123123
}
124124

125125
impl<'a, W: fmt::Write> super::Writer<'a, W> {
126-
pub fn new(out: W, options: &'a Options) -> Self {
126+
pub fn new(out: W, options: &'a Options, pipeline_options: &'a PipelineOptions) -> Self {
127127
Self {
128128
out,
129129
names: crate::FastHashMap::default(),
130130
namer: proc::Namer::default(),
131131
options,
132-
entry_point_io: Vec::new(),
132+
pipeline_options,
133+
entry_point_io: crate::FastHashMap::default(),
133134
named_expressions: crate::NamedExpressions::default(),
134135
wrapped: super::Wrapped::default(),
135136
written_committed_intersection: false,
@@ -387,8 +388,12 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
387388
writeln!(self.out)?;
388389
}
389390

391+
let ep_range = get_entry_points(module, self.pipeline_options.entry_point.as_ref())
392+
.map_err(|(stage, name)| Error::EntryPointNotFound(stage, name))?;
393+
390394
// Write all entry points wrapped structs
391-
for (index, ep) in module.entry_points.iter().enumerate() {
395+
for index in ep_range.clone() {
396+
let ep = &module.entry_points[index];
392397
let ep_name = self.names[&NameKey::EntryPoint(index as u16)].clone();
393398
let ep_io = self.write_ep_interface(
394399
module,
@@ -397,7 +402,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
397402
&ep_name,
398403
fragment_entry_point,
399404
)?;
400-
self.entry_point_io.push(ep_io);
405+
self.entry_point_io.insert(index, ep_io);
401406
}
402407

403408
// Write all regular functions
@@ -442,10 +447,11 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
442447
writeln!(self.out)?;
443448
}
444449

445-
let mut entry_point_names = Vec::with_capacity(module.entry_points.len());
450+
let mut translated_ep_names = Vec::with_capacity(ep_range.len());
446451

447452
// Write all entry points
448-
for (index, ep) in module.entry_points.iter().enumerate() {
453+
for index in ep_range {
454+
let ep = &module.entry_points[index];
449455
let info = module_info.get_entry_point(index);
450456

451457
if !self.options.fake_missing_bindings {
@@ -462,7 +468,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
462468
}
463469
}
464470
if let Some(err) = ep_error {
465-
entry_point_names.push(Err(err));
471+
translated_ep_names.push(Err(err));
466472
continue;
467473
}
468474
}
@@ -493,10 +499,12 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
493499
writeln!(self.out)?;
494500
}
495501

496-
entry_point_names.push(Ok(name));
502+
translated_ep_names.push(Ok(name));
497503
}
498504

499-
Ok(super::ReflectionInfo { entry_point_names })
505+
Ok(super::ReflectionInfo {
506+
entry_point_names: translated_ep_names,
507+
})
500508
}
501509

502510
fn write_modifier(&mut self, binding: &crate::Binding) -> BackendResult {
@@ -816,7 +824,13 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
816824
ep_index: u16,
817825
) -> BackendResult {
818826
let ep = &module.entry_points[ep_index as usize];
819-
let ep_input = match self.entry_point_io[ep_index as usize].input.take() {
827+
let ep_input = match self
828+
.entry_point_io
829+
.get_mut(&(ep_index as usize))
830+
.unwrap()
831+
.input
832+
.take()
833+
{
820834
Some(ep_input) => ep_input,
821835
None => return Ok(()),
822836
};
@@ -1432,7 +1446,9 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
14321446
}
14331447
}
14341448
back::FunctionType::EntryPoint(index) => {
1435-
if let Some(ref ep_output) = self.entry_point_io[index as usize].output {
1449+
if let Some(ref ep_output) =
1450+
self.entry_point_io.get(&(index as usize)).unwrap().output
1451+
{
14361452
write!(self.out, "{}", ep_output.ty_name)?;
14371453
} else {
14381454
self.write_type(module, result.ty)?;
@@ -1479,7 +1495,9 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
14791495
}
14801496
}
14811497
back::FunctionType::EntryPoint(ep_index) => {
1482-
if let Some(ref ep_input) = self.entry_point_io[ep_index as usize].input {
1498+
if let Some(ref ep_input) =
1499+
self.entry_point_io.get(&(ep_index as usize)).unwrap().input
1500+
{
14831501
write!(self.out, "{} {}", ep_input.ty_name, ep_input.arg_name)?;
14841502
} else {
14851503
let stage = module.entry_points[ep_index as usize].stage;
@@ -1501,7 +1519,12 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
15011519
}
15021520
}
15031521
if need_workgroup_variables_initialization {
1504-
if self.entry_point_io[ep_index as usize].input.is_some()
1522+
if self
1523+
.entry_point_io
1524+
.get(&(ep_index as usize))
1525+
.unwrap()
1526+
.input
1527+
.is_some()
15051528
|| !func.arguments.is_empty()
15061529
{
15071530
write!(self.out, ", ")?;
@@ -1870,9 +1893,12 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
18701893
// for entry point returns, we may need to reshuffle the outputs into a different struct
18711894
let ep_output = match func_ctx.ty {
18721895
back::FunctionType::Function(_) => None,
1873-
back::FunctionType::EntryPoint(index) => {
1874-
self.entry_point_io[index as usize].output.as_ref()
1875-
}
1896+
back::FunctionType::EntryPoint(index) => self
1897+
.entry_point_io
1898+
.get(&(index as usize))
1899+
.unwrap()
1900+
.output
1901+
.as_ref(),
18761902
};
18771903
let final_name = match ep_output {
18781904
Some(ep_output) => {

naga/src/back/mod.rs

+27
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,33 @@ impl core::fmt::Display for Level {
7979
}
8080
}
8181

82+
/// Locate the entry point(s) to write.
83+
///
84+
/// If `entry_point` is given, and the specified entry point exists, returns a
85+
/// length-1 `Range` containing the index of that entry point. If no
86+
/// `entry_point` is given, returns the complete range of entry point indices.
87+
/// If `entry_point` is given but does not exist, returns an error.
88+
#[cfg(any(hlsl_out, msl_out))]
89+
fn get_entry_points(
90+
module: &crate::ir::Module,
91+
entry_point: Option<&(crate::ir::ShaderStage, String)>,
92+
) -> Result<core::ops::Range<usize>, (crate::ir::ShaderStage, String)> {
93+
use alloc::borrow::ToOwned;
94+
95+
if let Some(&(stage, ref name)) = entry_point {
96+
let Some(ep_index) = module
97+
.entry_points
98+
.iter()
99+
.position(|ep| ep.stage == stage && ep.name == *name)
100+
else {
101+
return Err((stage, name.to_owned()));
102+
};
103+
Ok(ep_index..ep_index + 1)
104+
} else {
105+
Ok(0..module.entry_points.len())
106+
}
107+
}
108+
82109
/// Whether we're generating an entry point or a regular function.
83110
///
84111
/// Backend languages often require different code for a [`Function`]

naga/src/back/msl/mod.rs

+14-3
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ use alloc::{
5252
};
5353
use core::fmt::{Error as FmtError, Write};
5454

55-
use crate::{arena::Handle, proc::index, valid::ModuleInfo};
55+
use crate::{arena::Handle, ir, proc::index, valid::ModuleInfo};
5656

5757
mod keywords;
5858
pub mod sampler;
@@ -184,7 +184,7 @@ pub enum Error {
184184
#[error("can not use writeable storage buffers in fragment stage prior to MSL 1.2")]
185185
UnsupportedWriteableStorageBuffer,
186186
#[error("can not use writeable storage textures in {0:?} stage prior to MSL 1.2")]
187-
UnsupportedWriteableStorageTexture(crate::ShaderStage),
187+
UnsupportedWriteableStorageTexture(ir::ShaderStage),
188188
#[error("can not use read-write storage textures prior to MSL 1.2")]
189189
UnsupportedRWStorageTexture,
190190
#[error("array of '{0}' is not supported for target MSL version")]
@@ -199,6 +199,8 @@ pub enum Error {
199199
UnsupportedBitCast(crate::TypeInner),
200200
#[error(transparent)]
201201
ResolveArraySizeError(#[from] crate::proc::ResolveArraySizeError),
202+
#[error("entry point with stage {0:?} and name '{1}' not found")]
203+
EntryPointNotFound(ir::ShaderStage, String),
202204
}
203205

204206
#[derive(Clone, Debug, PartialEq, thiserror::Error)]
@@ -420,6 +422,15 @@ pub struct VertexBufferMapping {
420422
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
421423
#[cfg_attr(feature = "deserialize", serde(default))]
422424
pub struct PipelineOptions {
425+
/// The entry point to write.
426+
///
427+
/// Entry points are identified by a shader stage specification,
428+
/// and a name.
429+
///
430+
/// If `None`, all entry points will be written. If `Some` and the entry
431+
/// point is not found, an error will be thrown while writing.
432+
pub entry_point: Option<(ir::ShaderStage, String)>,
433+
423434
/// Allow `BuiltIn::PointSize` and inject it if doesn't exist.
424435
///
425436
/// Metal doesn't like this for non-point primitive topologies and requires it for
@@ -737,5 +748,5 @@ pub fn write_string(
737748

738749
#[test]
739750
fn test_error_size() {
740-
assert_eq!(size_of::<Error>(), 32);
751+
assert_eq!(size_of::<Error>(), 40);
741752
}

naga/src/back/msl/writer.rs

+10-5
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ use half::f16;
1616
use super::{sampler as sm, Error, LocationMode, Options, PipelineOptions, TranslationInfo};
1717
use crate::{
1818
arena::{Handle, HandleSet},
19-
back::{self, Baked},
19+
back::{self, get_entry_points, Baked},
2020
common,
2121
proc::{
2222
self,
@@ -5872,10 +5872,15 @@ template <typename A>
58725872
self.named_expressions.clear();
58735873
}
58745874

5875+
let ep_range = get_entry_points(module, pipeline_options.entry_point.as_ref())
5876+
.map_err(|(stage, name)| Error::EntryPointNotFound(stage, name))?;
5877+
58755878
let mut info = TranslationInfo {
5876-
entry_point_names: Vec::with_capacity(module.entry_points.len()),
5879+
entry_point_names: Vec::with_capacity(ep_range.len()),
58775880
};
5878-
for (ep_index, ep) in module.entry_points.iter().enumerate() {
5881+
5882+
for ep_index in ep_range {
5883+
let ep = &module.entry_points[ep_index];
58795884
let fun = &ep.function;
58805885
let fun_info = mod_info.get_entry_point(ep_index);
58815886
let mut ep_error = None;
@@ -7076,8 +7081,8 @@ fn test_stack_size() {
70767081
}
70777082
let stack_size = addresses_end - addresses_start;
70787083
// check the size (in debug only)
7079-
// last observed macOS value: 20528 (CI)
7080-
if !(11000..=25000).contains(&stack_size) {
7084+
// last observed macOS value: 25904 (CI), 2025-04-29
7085+
if !(11000..=27000).contains(&stack_size) {
70817086
panic!("`put_expression` stack size {stack_size} has changed!");
70827087
}
70837088
}

naga/tests/naga/snapshots.rs

+2-1
Original file line numberDiff line numberDiff line change
@@ -741,7 +741,8 @@ fn write_output_hlsl(
741741
.expect("override evaluation failed");
742742

743743
let mut buffer = String::new();
744-
let mut writer = hlsl::Writer::new(&mut buffer, options);
744+
let pipeline_options = Default::default();
745+
let mut writer = hlsl::Writer::new(&mut buffer, options, &pipeline_options);
745746
let reflection_info = writer
746747
.write(&module, &info, frag_ep.as_ref())
747748
.expect("HLSL write failed");

0 commit comments

Comments
 (0)