Skip to content

Commit 6666d52

Browse files
authored
Allow obtaining custom implementation from wgpu api types (#7541)
1 parent a9a3ea3 commit 6666d52

30 files changed

+262
-17
lines changed

examples/standalone/custom_backend/src/custom.rs

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@ use std::pin::Pin;
33
use std::sync::Arc;
44

55
use wgpu::custom::{
6-
AdapterInterface, DeviceInterface, DispatchAdapter, DispatchDevice, DispatchQueue,
7-
DispatchShaderModule, DispatchSurface, InstanceInterface, QueueInterface, RequestAdapterFuture,
8-
ShaderModuleInterface,
6+
AdapterInterface, ComputePipelineInterface, DeviceInterface, DispatchAdapter, DispatchDevice,
7+
DispatchQueue, DispatchShaderModule, DispatchSurface, InstanceInterface, QueueInterface,
8+
RequestAdapterFuture, ShaderModuleInterface,
99
};
1010

1111
#[derive(Debug, Clone)]
@@ -163,9 +163,10 @@ impl DeviceInterface for CustomDevice {
163163

164164
fn create_compute_pipeline(
165165
&self,
166-
_desc: &wgpu::ComputePipelineDescriptor<'_>,
166+
desc: &wgpu::ComputePipelineDescriptor<'_>,
167167
) -> wgpu::custom::DispatchComputePipeline {
168-
unimplemented!()
168+
let module = desc.module.as_custom::<CustomShaderModule>().unwrap();
169+
wgpu::custom::DispatchComputePipeline::custom(CustomComputePipeline(module.0.clone()))
169170
}
170171

171172
unsafe fn create_pipeline_cache(
@@ -265,7 +266,7 @@ impl DeviceInterface for CustomDevice {
265266
}
266267

267268
#[derive(Debug)]
268-
struct CustomShaderModule(Counter);
269+
pub struct CustomShaderModule(pub Counter);
269270

270271
impl ShaderModuleInterface for CustomShaderModule {
271272
fn get_compilation_info(&self) -> Pin<Box<dyn wgpu::custom::ShaderCompilationInfoFuture>> {
@@ -346,3 +347,12 @@ impl QueueInterface for CustomQueue {
346347
unimplemented!()
347348
}
348349
}
350+
351+
#[derive(Debug)]
352+
pub struct CustomComputePipeline(pub Counter);
353+
354+
impl ComputePipelineInterface for CustomComputePipeline {
355+
fn get_bind_group_layout(&self, _index: u32) -> wgpu::custom::DispatchBindGroupLayout {
356+
unimplemented!()
357+
}
358+
}

examples/standalone/custom_backend/src/main.rs

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use std::marker::PhantomData;
22

3-
use custom::Counter;
3+
use custom::{Counter, CustomShaderModule};
44
use wgpu::{DeviceDescriptor, RequestAdapterOptions};
55

66
mod custom;
@@ -31,12 +31,26 @@ async fn main() {
3131
.unwrap();
3232
assert_eq!(counter.count(), 5);
3333

34-
let _module = device.create_shader_module(wgpu::ShaderModuleDescriptor {
34+
let module = device.create_shader_module(wgpu::ShaderModuleDescriptor {
3535
label: Some("shader"),
3636
source: wgpu::ShaderSource::Dummy(PhantomData),
3737
});
3838

39+
let custom_module = module.as_custom::<CustomShaderModule>().unwrap();
40+
assert_eq!(custom_module.0.count(), 6);
41+
let _module_clone = module.clone();
3942
assert_eq!(counter.count(), 6);
43+
44+
let _pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
45+
label: None,
46+
layout: None,
47+
module: &module,
48+
entry_point: None,
49+
compilation_options: Default::default(),
50+
cache: None,
51+
});
52+
53+
assert_eq!(counter.count(), 7);
4054
}
4155
assert_eq!(counter.count(), 1);
4256
}

wgpu/src/api/adapter.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,12 @@ impl Adapter {
133133
}
134134
}
135135

136+
#[cfg(custom)]
137+
/// Returns custom implementation of adapter (if custom backend and is internally T)
138+
pub fn as_custom<T: custom::AdapterInterface>(&self) -> Option<&T> {
139+
self.inner.as_custom()
140+
}
141+
136142
#[cfg(custom)]
137143
/// Creates Adapter from custom implementation
138144
pub fn from_custom<T: custom::AdapterInterface>(adapter: T) -> Self {

wgpu/src/api/bind_group.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,14 @@ static_assertions::assert_impl_all!(BindGroup: Send, Sync);
1717

1818
crate::cmp::impl_eq_ord_hash_proxy!(BindGroup => .inner);
1919

20+
impl BindGroup {
21+
#[cfg(custom)]
22+
/// Returns custom implementation of BindGroup (if custom backend and is internally T)
23+
pub fn as_custom<T: custom::BindGroupInterface>(&self) -> Option<&T> {
24+
self.inner.as_custom()
25+
}
26+
}
27+
2028
/// Resource to be bound by a [`BindGroup`] for use with a pipeline.
2129
///
2230
/// The pipeline’s [`BindGroupLayout`] must contain a matching [`BindingType`].

wgpu/src/api/bind_group_layout.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,14 @@ static_assertions::assert_impl_all!(BindGroupLayout: Send, Sync);
2020

2121
crate::cmp::impl_eq_ord_hash_proxy!(BindGroupLayout => .inner);
2222

23+
impl BindGroupLayout {
24+
#[cfg(custom)]
25+
/// Returns custom implementation of BindGroupLayout (if custom backend and is internally T)
26+
pub fn as_custom<T: custom::BindGroupLayoutInterface>(&self) -> Option<&T> {
27+
self.inner.as_custom()
28+
}
29+
}
30+
2331
/// Describes a [`BindGroupLayout`].
2432
///
2533
/// For use with [`Device::create_bind_group_layout`].

wgpu/src/api/blas.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,12 @@ impl Blas {
174174
hal_blas_callback(None)
175175
}
176176
}
177+
178+
#[cfg(custom)]
179+
/// Returns custom implementation of Blas (if custom backend and is internally T)
180+
pub fn as_custom<T: crate::custom::BlasInterface>(&self) -> Option<&T> {
181+
self.inner.as_custom()
182+
}
177183
}
178184

179185
/// Context version of [BlasTriangleGeometry].

wgpu/src/api/buffer.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -386,6 +386,12 @@ impl Buffer {
386386
) -> BufferViewMut<'_> {
387387
self.slice(bounds).get_mapped_range_mut()
388388
}
389+
390+
#[cfg(custom)]
391+
/// Returns custom implementation of Buffer (if custom backend and is internally T)
392+
pub fn as_custom<T: custom::BufferInterface>(&self) -> Option<&T> {
393+
self.inner.as_custom()
394+
}
389395
}
390396

391397
/// A slice of a [`Buffer`], to be mapped, used for vertex or index data, or the like.

wgpu/src/api/command_buffer.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,11 @@ pub struct CommandBuffer {
1313
}
1414
#[cfg(send_sync)]
1515
static_assertions::assert_impl_all!(CommandBuffer: Send, Sync);
16+
17+
impl CommandBuffer {
18+
#[cfg(custom)]
19+
/// Returns custom implementation of CommandBuffer (if custom backend and is internally T)
20+
pub fn as_custom<T: custom::CommandBufferInterface>(&self) -> Option<&T> {
21+
self.buffer.as_custom()
22+
}
23+
}

wgpu/src/api/command_encoder.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,12 @@ impl CommandEncoder {
262262
hal_command_encoder_callback(None)
263263
}
264264
}
265+
266+
#[cfg(custom)]
267+
/// Returns custom implementation of CommandEncoder (if custom backend and is internally T)
268+
pub fn as_custom<T: custom::CommandEncoderInterface>(&self) -> Option<&T> {
269+
self.inner.as_custom()
270+
}
265271
}
266272

267273
/// [`Features::TIMESTAMP_QUERY_INSIDE_ENCODERS`] must be enabled on the device in order to call these functions.

wgpu/src/api/compute_pass.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,12 @@ impl ComputePass<'_> {
9494
self.inner
9595
.dispatch_workgroups_indirect(&indirect_buffer.inner, indirect_offset);
9696
}
97+
98+
#[cfg(custom)]
99+
/// Returns custom implementation of ComputePass (if custom backend and is internally T)
100+
pub fn as_custom<T: custom::ComputePassInterface>(&self) -> Option<&T> {
101+
self.inner.as_custom()
102+
}
97103
}
98104

99105
/// [`Features::PUSH_CONSTANTS`] must be enabled on the device in order to call these functions.

0 commit comments

Comments
 (0)