Skip to content

Initial precompiled shaders implementation #7834

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 15 commits into
base: trunk
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions docs/api-specs/precompiled_shaders.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# Precompiled shaders
There are two main issues an implementation needs to cover
* Including and using reflection info
* Exposing how individual backends compile shaders outside of the backends
What changes need to be made
* I propose making a new crate, `wgpu-shaders`
* This crate would be a "wrapper" around `naga`, that would include all shader compiling logic
* This logic could then be used by both compile time macros and `wgpu-hal` itself
* This crate would include "backend"-specific parts, but it wouldn't need actual access to backends
* I also propose moving many `naga` types into `wgpu-types`, primarily those useful for reflection.
* The type to look out for here is `wgpu_core::validation::Interface`. This would also need to be moved into `wgpu-types`
18 changes: 18 additions & 0 deletions wgpu-core/src/device/global.rs
Original file line number Diff line number Diff line change
Expand Up @@ -984,6 +984,24 @@ impl Global {
runtime_checks: wgt::ShaderRuntimeChecks::unchecked(),
}
}
pipeline::ShaderModuleDescriptorPassthrough::Dxil(inner) => {
pipeline::ShaderModuleDescriptor {
label: inner.label.clone(),
runtime_checks: wgt::ShaderRuntimeChecks::unchecked(),
}
}
pipeline::ShaderModuleDescriptorPassthrough::Hlsl(inner) => {
pipeline::ShaderModuleDescriptor {
label: inner.label.clone(),
runtime_checks: wgt::ShaderRuntimeChecks::unchecked(),
}
}
pipeline::ShaderModuleDescriptorPassthrough::Generic(inner) => {
pipeline::ShaderModuleDescriptor {
label: inner.label.clone(),
runtime_checks: wgt::ShaderRuntimeChecks::unchecked(),
}
}
},
data,
});
Expand Down
28 changes: 27 additions & 1 deletion wgpu-core/src/device/resource.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1796,11 +1796,37 @@ impl Device {
pipeline::ShaderModuleDescriptorPassthrough::Msl(inner) => {
self.require_features(wgt::Features::MSL_SHADER_PASSTHROUGH)?;
hal::ShaderInput::Msl {
shader: inner.source.to_string(),
shader: inner.source,
entry_point: inner.entry_point.to_string(),
num_workgroups: inner.num_workgroups,
}
}
pipeline::ShaderModuleDescriptorPassthrough::Dxil(inner) => {
self.require_features(wgt::Features::HLSL_DXIL_SHADER_PASSTHROUGH)?;
hal::ShaderInput::Dxil {
shader: inner.source,
entry_point: inner.entry_point.clone(),
num_workgroups: inner.num_workgroups,
}
}
pipeline::ShaderModuleDescriptorPassthrough::Hlsl(inner) => {
self.require_features(wgt::Features::HLSL_DXIL_SHADER_PASSTHROUGH)?;
hal::ShaderInput::Hlsl {
shader: inner.source,
entry_point: inner.entry_point.clone(),
num_workgroups: inner.num_workgroups,
}
}
pipeline::ShaderModuleDescriptorPassthrough::Generic(inner) => {
self.require_features(wgt::Features::EXPERIMENTAL_PRECOMPILED_SHADERS)?;
hal::ShaderInput::Generic {
entry_point: inner.entry_point.clone(),
num_workgroups: inner.num_workgroups,
spirv: inner.spirv.as_deref(),
dxil: inner.dxil.as_deref(),
msl: inner.msl.as_deref(),
}
}
};

let hal_desc = hal::ShaderModuleDescriptor {
Expand Down
3 changes: 2 additions & 1 deletion wgpu-hal/src/dx12/adapter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,8 @@ impl super::Adapter {
| wgt::Features::DUAL_SOURCE_BLENDING
| wgt::Features::TEXTURE_FORMAT_NV12
| wgt::Features::FLOAT32_FILTERABLE
| wgt::Features::TEXTURE_ATOMIC;
| wgt::Features::TEXTURE_ATOMIC
| wgt::Features::EXPERIMENTAL_PRECOMPILED_SHADERS;

//TODO: in order to expose this, we need to run a compute shader
// that extract the necessary statistics out of the D3D12 result.
Expand Down
177 changes: 122 additions & 55 deletions wgpu-hal/src/dx12/device.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use alloc::borrow::ToOwned;
use alloc::{
borrow::Cow,
string::{String, ToString as _},
Expand Down Expand Up @@ -264,27 +265,8 @@ impl super::Device {
naga_stage: naga::ShaderStage,
fragment_stage: Option<&crate::ProgrammableStage<super::ShaderModule>>,
) -> Result<super::CompiledShader, crate::PipelineError> {
use naga::back::hlsl;

let frag_ep = fragment_stage
.map(|fs_stage| {
hlsl::FragmentEntryPoint::new(&fs_stage.module.naga.module, fs_stage.entry_point)
.ok_or(crate::PipelineError::EntryPoint(
naga::ShaderStage::Fragment,
))
})
.transpose()?;

let stage_bit = auxil::map_naga_stage(naga_stage);

let (module, info) = naga::back::pipeline_constants::process_overrides(
&stage.module.naga.module,
&stage.module.naga.info,
Some((naga_stage, stage.entry_point)),
stage.constants,
)
.map_err(|e| crate::PipelineError::PipelineConstants(stage_bit, format!("HLSL: {e:?}")))?;

let needs_temp_options = stage.zero_initialize_workgroup_memory
!= layout.naga_options.zero_initialize_workgroup_memory
|| stage.module.runtime_checks.bounds_checks != layout.naga_options.restrict_indexing
Expand All @@ -301,43 +283,90 @@ impl super::Device {
&layout.naga_options
};

let pipeline_options = hlsl::PipelineOptions {
entry_point: Some((naga_stage, stage.entry_point.to_string())),
};
let key = match &stage.module.source {
super::ShaderModuleSource::Naga(naga_shader) => {
use naga::back::hlsl;

let frag_ep = match fragment_stage {
Some(crate::ProgrammableStage {
module:
super::ShaderModule {
source: super::ShaderModuleSource::Naga(naga_shader),
..
},
entry_point,
..
}) => Some(
hlsl::FragmentEntryPoint::new(&naga_shader.module, entry_point).ok_or(
crate::PipelineError::EntryPoint(naga::ShaderStage::Fragment),
),
),
_ => None,
}
.transpose()?;
let (module, info) = naga::back::pipeline_constants::process_overrides(
&naga_shader.module,
&naga_shader.info,
Some((naga_stage, stage.entry_point)),
stage.constants,
)
.map_err(|e| {
crate::PipelineError::PipelineConstants(stage_bit, format!("HLSL: {e:?}"))
})?;

//TODO: reuse the writer
let (source, entry_point) = {
let mut source = String::new();
let mut writer = hlsl::Writer::new(&mut source, naga_options, &pipeline_options);
let pipeline_options = hlsl::PipelineOptions {
entry_point: Some((naga_stage, stage.entry_point.to_string())),
};

profiling::scope!("naga::back::hlsl::write");
let mut reflection_info = writer
.write(&module, &info, frag_ep.as_ref())
.map_err(|e| crate::PipelineError::Linkage(stage_bit, format!("HLSL: {e:?}")))?;
//TODO: reuse the writer
let (source, entry_point) = {
let mut source = String::new();
let mut writer =
hlsl::Writer::new(&mut source, naga_options, &pipeline_options);

assert_eq!(reflection_info.entry_point_names.len(), 1);
profiling::scope!("naga::back::hlsl::write");
let mut reflection_info = writer
.write(&module, &info, frag_ep.as_ref())
.map_err(|e| {
crate::PipelineError::Linkage(stage_bit, format!("HLSL: {e:?}"))
})?;

let entry_point = reflection_info
.entry_point_names
.pop()
.unwrap()
.map_err(|e| crate::PipelineError::Linkage(stage_bit, format!("{e}")))?;
assert_eq!(reflection_info.entry_point_names.len(), 1);

(source, entry_point)
};
let entry_point = reflection_info
.entry_point_names
.pop()
.unwrap()
.map_err(|e| crate::PipelineError::Linkage(stage_bit, format!("{e}")))?;

log::info!(
"Naga generated shader for {:?} at {:?}:\n{}",
entry_point,
naga_stage,
source
);
(source, entry_point)
};
log::info!(
"Naga generated shader for {:?} at {:?}:\n{}",
entry_point,
naga_stage,
source
);

let key = ShaderCacheKey {
source,
entry_point,
stage: naga_stage,
shader_model: naga_options.shader_model,
ShaderCacheKey {
source,
entry_point,
stage: naga_stage,
shader_model: naga_options.shader_model,
}
}
super::ShaderModuleSource::HlslPassthrough(passthrough) => ShaderCacheKey {
source: passthrough.shader.clone(),
entry_point: passthrough.entry_point.clone(),
stage: naga_stage,
shader_model: naga_options.shader_model,
},

super::ShaderModuleSource::DxilPassthrough(passthrough) => {
return Ok(super::CompiledShader::Precompiled(
passthrough.shader.clone(),
))
}
};

{
Expand All @@ -351,11 +380,7 @@ impl super::Device {

let source_name = stage.module.raw_name.as_deref();

let full_stage = format!(
"{}_{}",
naga_stage.to_hlsl_str(),
naga_options.shader_model.to_str()
);
let full_stage = format!("{}_{}", naga_stage.to_hlsl_str(), key.shader_model.to_str());

let compiled_shader = self.compiler_container.compile(
self,
Expand Down Expand Up @@ -1671,7 +1696,7 @@ impl crate::Device for super::Device {
.and_then(|label| alloc::ffi::CString::new(label).ok());
match shader {
crate::ShaderInput::Naga(naga) => Ok(super::ShaderModule {
naga,
source: super::ShaderModuleSource::Naga(naga),
raw_name,
runtime_checks: desc.runtime_checks,
}),
Expand All @@ -1681,6 +1706,48 @@ impl crate::Device for super::Device {
crate::ShaderInput::Msl { .. } => {
panic!("MSL_SHADER_PASSTHROUGH is not enabled for this backend")
}
crate::ShaderInput::Dxil {
shader,
entry_point,
num_workgroups,
} => Ok(super::ShaderModule {
source: super::ShaderModuleSource::DxilPassthrough(super::DxilPassthroughShader {
shader: shader.to_vec(),
entry_point,
num_workgroups,
}),
raw_name,
runtime_checks: desc.runtime_checks,
}),
crate::ShaderInput::Hlsl {
shader,
entry_point,
num_workgroups,
} => Ok(super::ShaderModule {
source: super::ShaderModuleSource::HlslPassthrough(super::HlslPassthroughShader {
shader: shader.to_owned(),
entry_point,
num_workgroups,
}),
raw_name,
runtime_checks: desc.runtime_checks,
}),
crate::ShaderInput::Generic {
dxil,
entry_point,
num_workgroups,
..
} => Ok(super::ShaderModule {
source: super::ShaderModuleSource::DxilPassthrough(super::DxilPassthroughShader {
shader: dxil
.expect("Generic passthrough was given to dx12 backend without DXIL data")
.to_vec(),
entry_point,
num_workgroups,
}),
raw_name,
runtime_checks: desc.runtime_checks,
}),
}
}
unsafe fn destroy_shader_module(&self, _module: super::ShaderModule) {
Expand Down
25 changes: 24 additions & 1 deletion wgpu-hal/src/dx12/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1072,7 +1072,7 @@ impl crate::DynPipelineLayout for PipelineLayout {}

#[derive(Debug)]
pub struct ShaderModule {
naga: crate::NagaShader,
source: ShaderModuleSource,
raw_name: Option<alloc::ffi::CString>,
runtime_checks: wgt::ShaderRuntimeChecks,
}
Expand Down Expand Up @@ -1104,6 +1104,7 @@ pub(super) struct ShaderCacheValue {
pub(super) enum CompiledShader {
Dxc(Direct3D::Dxc::IDxcBlob),
Fxc(Direct3D::ID3DBlob),
Precompiled(Vec<u8>),
}

impl CompiledShader {
Expand All @@ -1117,6 +1118,10 @@ impl CompiledShader {
pShaderBytecode: unsafe { shader.GetBufferPointer() },
BytecodeLength: unsafe { shader.GetBufferSize() },
},
CompiledShader::Precompiled(shader) => Direct3D12::D3D12_SHADER_BYTECODE {
pShaderBytecode: shader.as_ptr().cast(),
BytecodeLength: shader.len(),
},
}
}
}
Expand Down Expand Up @@ -1485,3 +1490,21 @@ impl crate::Queue for Queue {
(1_000_000_000.0 / frequency as f64) as f32
}
}
#[derive(Debug)]
pub struct DxilPassthroughShader {
pub shader: Vec<u8>,
pub entry_point: String,
pub num_workgroups: (u32, u32, u32),
}
#[derive(Debug)]
pub struct HlslPassthroughShader {
pub shader: String,
pub entry_point: String,
pub num_workgroups: (u32, u32, u32),
}
#[derive(Debug)]
pub enum ShaderModuleSource {
Naga(crate::NagaShader),
DxilPassthrough(DxilPassthroughShader),
HlslPassthrough(HlslPassthroughShader),
}
6 changes: 6 additions & 0 deletions wgpu-hal/src/gles/device.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1346,6 +1346,12 @@ impl crate::Device for super::Device {
panic!("`Features::MSL_SHADER_PASSTHROUGH` is not enabled")
}
crate::ShaderInput::Naga(naga) => naga,
crate::ShaderInput::Dxil { .. } | crate::ShaderInput::Hlsl { .. } => {
panic!("`Features::HLSL_DXIL_SHADER_PASSTHROUGH` is not enabled")
}
crate::ShaderInput::Generic { .. } => {
panic!("`Features::EXPERIMENTAL_PRECOMPILED_SHADERS` is not enabled")
}
},
label: desc.label.map(|str| str.to_string()),
id: self.shared.next_shader_id.fetch_add(1, Ordering::Relaxed),
Expand Down
Loading