Skip to content
12 changes: 12 additions & 0 deletions wgpu-core/src/device/global.rs
Original file line number Diff line number Diff line change
Expand Up @@ -984,6 +984,18 @@ impl Global {
runtime_checks: wgt::ShaderRuntimeChecks::unchecked(),
}
}
pipeline::ShaderModuleDescriptorPassthrough::Dxil(inner) => {
pipeline::ShaderModuleDescriptor {
label: inner.label.clone(),
runtime_checks: wgt::ShaderRuntimeChecks::unchecked(),
}
}
pipeline::ShaderModuleDescriptorPassthrough::Hlsl(inner) => {
pipeline::ShaderModuleDescriptor {
label: inner.label.clone(),
runtime_checks: wgt::ShaderRuntimeChecks::unchecked(),
}
}
},
data,
});
Expand Down
16 changes: 16 additions & 0 deletions wgpu-core/src/device/resource.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1801,6 +1801,22 @@ impl Device {
num_workgroups: inner.num_workgroups,
}
}
pipeline::ShaderModuleDescriptorPassthrough::Dxil(inner) => {
self.require_features(wgt::Features::HLSL_DXIL_SHADER_PASSTHROUGH)?;
hal::ShaderInput::Dxil {
shader: inner.source,
entry_point: inner.entry_point.clone(),
num_workgroups: inner.num_workgroups,
}
}
pipeline::ShaderModuleDescriptorPassthrough::Hlsl(inner) => {
self.require_features(wgt::Features::HLSL_DXIL_SHADER_PASSTHROUGH)?;
hal::ShaderInput::Hlsl {
shader: inner.source,
entry_point: inner.entry_point.clone(),
num_workgroups: inner.num_workgroups,
}
}
};

let hal_desc = hal::ShaderModuleDescriptor {
Expand Down
161 changes: 106 additions & 55 deletions wgpu-hal/src/dx12/device.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use alloc::borrow::ToOwned;
use alloc::{
borrow::Cow,
string::{String, ToString as _},
Expand Down Expand Up @@ -264,27 +265,8 @@ impl super::Device {
naga_stage: naga::ShaderStage,
fragment_stage: Option<&crate::ProgrammableStage<super::ShaderModule>>,
) -> Result<super::CompiledShader, crate::PipelineError> {
use naga::back::hlsl;

let frag_ep = fragment_stage
.map(|fs_stage| {
hlsl::FragmentEntryPoint::new(&fs_stage.module.naga.module, fs_stage.entry_point)
.ok_or(crate::PipelineError::EntryPoint(
naga::ShaderStage::Fragment,
))
})
.transpose()?;

let stage_bit = auxil::map_naga_stage(naga_stage);

let (module, info) = naga::back::pipeline_constants::process_overrides(
&stage.module.naga.module,
&stage.module.naga.info,
Some((naga_stage, stage.entry_point)),
stage.constants,
)
.map_err(|e| crate::PipelineError::PipelineConstants(stage_bit, format!("HLSL: {e:?}")))?;

let needs_temp_options = stage.zero_initialize_workgroup_memory
!= layout.naga_options.zero_initialize_workgroup_memory
|| stage.module.runtime_checks.bounds_checks != layout.naga_options.restrict_indexing
Expand All @@ -301,43 +283,90 @@ impl super::Device {
&layout.naga_options
};

let pipeline_options = hlsl::PipelineOptions {
entry_point: Some((naga_stage, stage.entry_point.to_string())),
};
let key = match &stage.module.source {
super::ShaderModuleSource::Naga(naga_shader) => {
use naga::back::hlsl;

//TODO: reuse the writer
let (source, entry_point) = {
let mut source = String::new();
let mut writer = hlsl::Writer::new(&mut source, naga_options, &pipeline_options);
let frag_ep = match fragment_stage {
Some(crate::ProgrammableStage {
module:
super::ShaderModule {
source: super::ShaderModuleSource::Naga(naga_shader),
..
},
entry_point,
..
}) => Some(
hlsl::FragmentEntryPoint::new(&naga_shader.module, entry_point).ok_or(
crate::PipelineError::EntryPoint(naga::ShaderStage::Fragment),
),
),
_ => None,
}
.transpose()?;
let (module, info) = naga::back::pipeline_constants::process_overrides(
&naga_shader.module,
&naga_shader.info,
Some((naga_stage, stage.entry_point)),
stage.constants,
)
.map_err(|e| {
crate::PipelineError::PipelineConstants(stage_bit, format!("HLSL: {e:?}"))
})?;

profiling::scope!("naga::back::hlsl::write");
let mut reflection_info = writer
.write(&module, &info, frag_ep.as_ref())
.map_err(|e| crate::PipelineError::Linkage(stage_bit, format!("HLSL: {e:?}")))?;
let pipeline_options = hlsl::PipelineOptions {
entry_point: Some((naga_stage, stage.entry_point.to_string())),
};

assert_eq!(reflection_info.entry_point_names.len(), 1);
//TODO: reuse the writer
let (source, entry_point) = {
let mut source = String::new();
let mut writer =
hlsl::Writer::new(&mut source, naga_options, &pipeline_options);

let entry_point = reflection_info
.entry_point_names
.pop()
.unwrap()
.map_err(|e| crate::PipelineError::Linkage(stage_bit, format!("{e}")))?;
profiling::scope!("naga::back::hlsl::write");
let mut reflection_info = writer
.write(&module, &info, frag_ep.as_ref())
.map_err(|e| {
crate::PipelineError::Linkage(stage_bit, format!("HLSL: {e:?}"))
})?;

(source, entry_point)
};
assert_eq!(reflection_info.entry_point_names.len(), 1);

log::info!(
"Naga generated shader for {:?} at {:?}:\n{}",
entry_point,
naga_stage,
source
);
let entry_point = reflection_info
.entry_point_names
.pop()
.unwrap()
.map_err(|e| crate::PipelineError::Linkage(stage_bit, format!("{e}")))?;

let key = ShaderCacheKey {
source,
entry_point,
stage: naga_stage,
shader_model: naga_options.shader_model,
(source, entry_point)
};
log::info!(
"Naga generated shader for {:?} at {:?}:\n{}",
entry_point,
naga_stage,
source
);

ShaderCacheKey {
source,
entry_point,
stage: naga_stage,
shader_model: naga_options.shader_model,
}
}
super::ShaderModuleSource::HlslPassthrough(passthrough) => ShaderCacheKey {
source: passthrough.shader.clone(),
entry_point: passthrough.entry_point.clone(),
stage: naga_stage,
shader_model: naga_options.shader_model,
},

super::ShaderModuleSource::DxilPassthrough(passthrough) => {
return Ok(super::CompiledShader::Precompiled(
passthrough.shader.clone(),
))
}
};

{
Expand All @@ -351,11 +380,7 @@ impl super::Device {

let source_name = stage.module.raw_name.as_deref();

let full_stage = format!(
"{}_{}",
naga_stage.to_hlsl_str(),
naga_options.shader_model.to_str()
);
let full_stage = format!("{}_{}", naga_stage.to_hlsl_str(), key.shader_model.to_str());

let compiled_shader = self.compiler_container.compile(
self,
Expand Down Expand Up @@ -1671,7 +1696,7 @@ impl crate::Device for super::Device {
.and_then(|label| alloc::ffi::CString::new(label).ok());
match shader {
crate::ShaderInput::Naga(naga) => Ok(super::ShaderModule {
naga,
source: super::ShaderModuleSource::Naga(naga),
raw_name,
runtime_checks: desc.runtime_checks,
}),
Expand All @@ -1681,6 +1706,32 @@ impl crate::Device for super::Device {
crate::ShaderInput::Msl { .. } => {
panic!("MSL_SHADER_PASSTHROUGH is not enabled for this backend")
}
crate::ShaderInput::Dxil {
shader,
entry_point,
num_workgroups,
} => Ok(super::ShaderModule {
source: super::ShaderModuleSource::DxilPassthrough(super::DxilPassthroughShader {
shader: shader.to_vec(),
entry_point,
num_workgroups,
}),
raw_name,
runtime_checks: desc.runtime_checks,
}),
crate::ShaderInput::Hlsl {
shader,
entry_point,
num_workgroups,
} => Ok(super::ShaderModule {
source: super::ShaderModuleSource::HlslPassthrough(super::HlslPassthroughShader {
shader: shader.to_owned(),
entry_point,
num_workgroups,
}),
raw_name,
runtime_checks: desc.runtime_checks,
}),
}
}
unsafe fn destroy_shader_module(&self, _module: super::ShaderModule) {
Expand Down
27 changes: 26 additions & 1 deletion wgpu-hal/src/dx12/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1077,7 +1077,7 @@ impl crate::DynPipelineLayout for PipelineLayout {}

#[derive(Debug)]
pub struct ShaderModule {
naga: crate::NagaShader,
source: ShaderModuleSource,
raw_name: Option<alloc::ffi::CString>,
runtime_checks: wgt::ShaderRuntimeChecks,
}
Expand Down Expand Up @@ -1109,6 +1109,7 @@ pub(super) struct ShaderCacheValue {
pub(super) enum CompiledShader {
Dxc(Direct3D::Dxc::IDxcBlob),
Fxc(Direct3D::ID3DBlob),
Precompiled(Vec<u8>),
}

impl CompiledShader {
Expand All @@ -1122,6 +1123,10 @@ impl CompiledShader {
pShaderBytecode: unsafe { shader.GetBufferPointer() },
BytecodeLength: unsafe { shader.GetBufferSize() },
},
CompiledShader::Precompiled(shader) => Direct3D12::D3D12_SHADER_BYTECODE {
pShaderBytecode: shader.as_ptr().cast(),
BytecodeLength: shader.len(),
},
}
}
}
Expand Down Expand Up @@ -1490,3 +1495,23 @@ impl crate::Queue for Queue {
(1_000_000_000.0 / frequency as f64) as f32
}
}
#[derive(Debug)]
pub struct DxilPassthroughShader {
pub shader: Vec<u8>,
pub entry_point: String,
pub num_workgroups: (u32, u32, u32),
}

#[derive(Debug)]
pub struct HlslPassthroughShader {
pub shader: String,
pub entry_point: String,
pub num_workgroups: (u32, u32, u32),
}

#[derive(Debug)]
pub enum ShaderModuleSource {
Naga(crate::NagaShader),
DxilPassthrough(DxilPassthroughShader),
HlslPassthrough(HlslPassthroughShader),
}
3 changes: 3 additions & 0 deletions wgpu-hal/src/gles/device.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1346,6 +1346,9 @@ impl crate::Device for super::Device {
panic!("`Features::MSL_SHADER_PASSTHROUGH` is not enabled")
}
crate::ShaderInput::Naga(naga) => naga,
crate::ShaderInput::Dxil { .. } | crate::ShaderInput::Hlsl { .. } => {
panic!("`Features::HLSL_DXIL_SHADER_PASSTHROUGH` is not enabled")
}
},
label: desc.label.map(|str| str.to_string()),
id: self.shared.next_shader_id.fetch_add(1, Ordering::Relaxed),
Expand Down
10 changes: 10 additions & 0 deletions wgpu-hal/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2104,6 +2104,16 @@ pub enum ShaderInput<'a> {
num_workgroups: (u32, u32, u32),
},
SpirV(&'a [u32]),
Dxil {
shader: &'a [u8],
entry_point: String,
num_workgroups: (u32, u32, u32),
},
Hlsl {
shader: &'a str,
entry_point: String,
num_workgroups: (u32, u32, u32),
},
}

pub struct ShaderModuleDescriptor<'a> {
Expand Down
3 changes: 3 additions & 0 deletions wgpu-hal/src/metal/device.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1039,6 +1039,9 @@ impl crate::Device for super::Device {
crate::ShaderInput::SpirV(_) => {
panic!("SPIRV_SHADER_PASSTHROUGH is not enabled for this backend")
}
crate::ShaderInput::Dxil { .. } | crate::ShaderInput::Hlsl { .. } => {
panic!("`Features::HLSL_DXIL_SHADER_PASSTHROUGH` is not enabled for this backend")
}
}
}

Expand Down
3 changes: 3 additions & 0 deletions wgpu-hal/src/vulkan/device.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1908,6 +1908,9 @@ impl crate::Device for super::Device {
crate::ShaderInput::Msl { .. } => {
panic!("MSL_SHADER_PASSTHROUGH is not enabled for this backend")
}
crate::ShaderInput::Dxil { .. } | crate::ShaderInput::Hlsl { .. } => {
panic!("`Features::HLSL_DXIL_SHADER_PASSTHROUGH` is not enabled")
}
crate::ShaderInput::SpirV(spv) => Cow::Borrowed(spv),
};

Expand Down
10 changes: 10 additions & 0 deletions wgpu-types/src/features.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1244,6 +1244,16 @@ bitflags_array! {
///
/// [BlasTriangleGeometrySizeDescriptor::vertex_format]: super::BlasTriangleGeometrySizeDescriptor
const EXTENDED_ACCELERATION_STRUCTURE_VERTEX_FORMATS = 1 << 51;

/// Enables creating shader modules from DirectX HLSL or DXIL shaders (unsafe)
///
/// HLSL/DXIL data is not parsed or interpreted in any way
///
/// Supported platforms:
/// - DX12
///
/// This is a native only feature.
const HLSL_DXIL_SHADER_PASSTHROUGH = 1 << 53;
}

/// Features that are not guaranteed to be supported.
Expand Down
Loading