Skip to content

Allow obtaining custom implementation from wgpu api types #7541

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Apr 18, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 16 additions & 6 deletions examples/standalone/custom_backend/src/custom.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@ use std::pin::Pin;
use std::sync::Arc;

use wgpu::custom::{
AdapterInterface, DeviceInterface, DispatchAdapter, DispatchDevice, DispatchQueue,
DispatchShaderModule, DispatchSurface, InstanceInterface, QueueInterface, RequestAdapterFuture,
ShaderModuleInterface,
AdapterInterface, ComputePipelineInterface, DeviceInterface, DispatchAdapter, DispatchDevice,
DispatchQueue, DispatchShaderModule, DispatchSurface, InstanceInterface, QueueInterface,
RequestAdapterFuture, ShaderModuleInterface,
};

#[derive(Debug, Clone)]
Expand Down Expand Up @@ -163,9 +163,10 @@ impl DeviceInterface for CustomDevice {

fn create_compute_pipeline(
&self,
_desc: &wgpu::ComputePipelineDescriptor<'_>,
desc: &wgpu::ComputePipelineDescriptor<'_>,
) -> wgpu::custom::DispatchComputePipeline {
unimplemented!()
let module = desc.module.as_custom::<CustomShaderModule>().unwrap();
wgpu::custom::DispatchComputePipeline::custom(CustomComputePipeline(module.0.clone()))
}

unsafe fn create_pipeline_cache(
Expand Down Expand Up @@ -262,7 +263,7 @@ impl DeviceInterface for CustomDevice {
}

#[derive(Debug)]
struct CustomShaderModule(Counter);
pub struct CustomShaderModule(pub Counter);

impl ShaderModuleInterface for CustomShaderModule {
fn get_compilation_info(&self) -> Pin<Box<dyn wgpu::custom::ShaderCompilationInfoFuture>> {
Expand Down Expand Up @@ -343,3 +344,12 @@ impl QueueInterface for CustomQueue {
unimplemented!()
}
}

#[derive(Debug)]
pub struct CustomComputePipeline(pub Counter);

impl ComputePipelineInterface for CustomComputePipeline {
fn get_bind_group_layout(&self, _index: u32) -> wgpu::custom::DispatchBindGroupLayout {
unimplemented!()
}
}
18 changes: 16 additions & 2 deletions examples/standalone/custom_backend/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::marker::PhantomData;

use custom::Counter;
use custom::{Counter, CustomShaderModule};
use wgpu::{DeviceDescriptor, RequestAdapterOptions};

mod custom;
Expand Down Expand Up @@ -31,12 +31,26 @@ async fn main() {
.unwrap();
assert_eq!(counter.count(), 5);

let _module = device.create_shader_module(wgpu::ShaderModuleDescriptor {
let module = device.create_shader_module(wgpu::ShaderModuleDescriptor {
label: Some("shader"),
source: wgpu::ShaderSource::Dummy(PhantomData),
});

let custom_module = module.as_custom::<CustomShaderModule>().unwrap();
assert_eq!(custom_module.0.count(), 6);
let _module_clone = module.clone();
assert_eq!(counter.count(), 6);

let _pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
label: None,
layout: None,
module: &module,
entry_point: None,
compilation_options: Default::default(),
cache: None,
});

assert_eq!(counter.count(), 7);
}
assert_eq!(counter.count(), 1);
}
6 changes: 6 additions & 0 deletions wgpu/src/api/adapter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,12 @@ impl Adapter {
}
}

#[cfg(custom)]
/// Returns custom implementation of adapter (if custom backend and is internally T)
pub fn as_custom<T: custom::AdapterInterface>(&self) -> Option<&T> {
self.inner.as_custom()
}

#[cfg(custom)]
/// Creates Adapter from custom implementation
pub fn from_custom<T: custom::AdapterInterface>(adapter: T) -> Self {
Expand Down
8 changes: 8 additions & 0 deletions wgpu/src/api/bind_group.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,14 @@ static_assertions::assert_impl_all!(BindGroup: Send, Sync);

crate::cmp::impl_eq_ord_hash_proxy!(BindGroup => .inner);

impl BindGroup {
#[cfg(custom)]
/// Returns custom implementation of BindGroup (if custom backend and is internally T)
pub fn as_custom<T: custom::BindGroupInterface>(&self) -> Option<&T> {
self.inner.as_custom()
}
}

/// Resource to be bound by a [`BindGroup`] for use with a pipeline.
///
/// The pipeline’s [`BindGroupLayout`] must contain a matching [`BindingType`].
Expand Down
8 changes: 8 additions & 0 deletions wgpu/src/api/bind_group_layout.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,14 @@ static_assertions::assert_impl_all!(BindGroupLayout: Send, Sync);

crate::cmp::impl_eq_ord_hash_proxy!(BindGroupLayout => .inner);

impl BindGroupLayout {
#[cfg(custom)]
/// Returns custom implementation of BindGroupLayout (if custom backend and is internally T)
pub fn as_custom<T: custom::BindGroupLayoutInterface>(&self) -> Option<&T> {
self.inner.as_custom()
}
}

/// Describes a [`BindGroupLayout`].
///
/// For use with [`Device::create_bind_group_layout`].
Expand Down
6 changes: 6 additions & 0 deletions wgpu/src/api/blas.rs
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,12 @@ impl Blas {
hal_blas_callback(None)
}
}

#[cfg(custom)]
/// Returns custom implementation of Blas (if custom backend and is internally T)
pub fn as_custom<T: crate::custom::BlasInterface>(&self) -> Option<&T> {
self.inner.as_custom()
}
}

/// Context version of [BlasTriangleGeometry].
Expand Down
6 changes: 6 additions & 0 deletions wgpu/src/api/buffer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,12 @@ impl Buffer {
) -> BufferViewMut<'_> {
self.slice(bounds).get_mapped_range_mut()
}

#[cfg(custom)]
/// Returns custom implementation of Buffer (if custom backend and is internally T)
pub fn as_custom<T: custom::BufferInterface>(&self) -> Option<&T> {
self.inner.as_custom()
}
}

/// A slice of a [`Buffer`], to be mapped, used for vertex or index data, or the like.
Expand Down
8 changes: 8 additions & 0 deletions wgpu/src/api/command_buffer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,11 @@ pub struct CommandBuffer {
}
#[cfg(send_sync)]
static_assertions::assert_impl_all!(CommandBuffer: Send, Sync);

impl CommandBuffer {
#[cfg(custom)]
/// Returns custom implementation of CommandBuffer (if custom backend and is internally T)
pub fn as_custom<T: custom::CommandBufferInterface>(&self) -> Option<&T> {
self.buffer.as_custom()
}
}
6 changes: 6 additions & 0 deletions wgpu/src/api/command_encoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,12 @@ impl CommandEncoder {
hal_command_encoder_callback(None)
}
}

#[cfg(custom)]
/// Returns custom implementation of CommandEncoder (if custom backend and is internally T)
pub fn as_custom<T: custom::CommandEncoderInterface>(&self) -> Option<&T> {
self.inner.as_custom()
}
}

/// [`Features::TIMESTAMP_QUERY_INSIDE_ENCODERS`] must be enabled on the device in order to call these functions.
Expand Down
6 changes: 6 additions & 0 deletions wgpu/src/api/compute_pass.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,12 @@ impl ComputePass<'_> {
self.inner
.dispatch_workgroups_indirect(&indirect_buffer.inner, indirect_offset);
}

#[cfg(custom)]
/// Returns custom implementation of ComputePass (if custom backend and is internally T)
pub fn as_custom<T: custom::ComputePassInterface>(&self) -> Option<&T> {
self.inner.as_custom()
}
}

/// [`Features::PUSH_CONSTANTS`] must be enabled on the device in order to call these functions.
Expand Down
6 changes: 6 additions & 0 deletions wgpu/src/api/compute_pipeline.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,12 @@ impl ComputePipeline {
let bind_group = self.inner.get_bind_group_layout(index);
BindGroupLayout { inner: bind_group }
}

#[cfg(custom)]
/// Returns custom implementation of ComputePipeline (if custom backend and is internally T)
pub fn as_custom<T: custom::ComputePipelineInterface>(&self) -> Option<&T> {
self.inner.as_custom()
}
}

/// Describes a compute pipeline.
Expand Down
6 changes: 6 additions & 0 deletions wgpu/src/api/device.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,12 @@ pub type DeviceDescriptor<'a> = wgt::DeviceDescriptor<Label<'a>>;
static_assertions::assert_impl_all!(DeviceDescriptor<'_>: Send, Sync);

impl Device {
#[cfg(custom)]
/// Returns custom implementation of Device (if custom backend and is internally T)
pub fn as_custom<T: custom::DeviceInterface>(&self) -> Option<&T> {
self.inner.as_custom()
}

#[cfg(custom)]
/// Creates Device from custom implementation
pub fn from_custom<T: custom::DeviceInterface>(device: T) -> Self {
Expand Down
6 changes: 6 additions & 0 deletions wgpu/src/api/instance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,12 @@ impl Instance {
}
}

#[cfg(custom)]
/// Returns custom implementation of Instance (if custom backend and is internally T)
pub fn as_custom<T: custom::InstanceInterface>(&self) -> Option<&T> {
self.inner.as_custom()
}

/// Retrieves all available [`Adapter`]s that match the given [`Backends`].
///
/// # Arguments
Expand Down
6 changes: 6 additions & 0 deletions wgpu/src/api/pipeline_cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,4 +87,10 @@ impl PipelineCache {
pub fn get_data(&self) -> Option<Vec<u8>> {
self.inner.get_data()
}

#[cfg(custom)]
/// Returns custom implementation of PipelineCache (if custom backend and is internally T)
pub fn as_custom<T: custom::PipelineCacheInterface>(&self) -> Option<&T> {
self.inner.as_custom()
}
}
8 changes: 8 additions & 0 deletions wgpu/src/api/pipeline_layout.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,14 @@ static_assertions::assert_impl_all!(PipelineLayout: Send, Sync);

crate::cmp::impl_eq_ord_hash_proxy!(PipelineLayout => .inner);

impl PipelineLayout {
#[cfg(custom)]
/// Returns custom implementation of PipelineLayout (if custom backend and is internally T)
pub fn as_custom<T: custom::PipelineLayoutInterface>(&self) -> Option<&T> {
self.inner.as_custom()
}
}

/// Describes a [`PipelineLayout`].
///
/// For use with [`Device::create_pipeline_layout`].
Expand Down
8 changes: 8 additions & 0 deletions wgpu/src/api/query_set.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,14 @@ static_assertions::assert_impl_all!(QuerySet: Send, Sync);

crate::cmp::impl_eq_ord_hash_proxy!(QuerySet => .inner);

impl QuerySet {
#[cfg(custom)]
/// Returns custom implementation of QuerySet (if custom backend and is internally T)
pub fn as_custom<T: custom::QuerySetInterface>(&self) -> Option<&T> {
self.inner.as_custom()
}
}

/// Describes a [`QuerySet`].
///
/// For use with [`Device::create_query_set`].
Expand Down
32 changes: 24 additions & 8 deletions wgpu/src/api/queue.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,22 @@ static_assertions::assert_impl_all!(Queue: Send, Sync);

crate::cmp::impl_eq_ord_hash_proxy!(Queue => .inner);

impl Queue {
#[cfg(custom)]
/// Returns custom implementation of Queue (if custom backend and is internally T)
pub fn as_custom<T: custom::QueueInterface>(&self) -> Option<&T> {
self.inner.as_custom()
}

#[cfg(custom)]
/// Creates Queue from custom implementation
pub fn from_custom<T: custom::QueueInterface>(queue: T) -> Self {
Self {
inner: dispatch::DispatchQueue::custom(queue),
}
}
}

/// Identifier for a particular call to [`Queue::submit`]. Can be used
/// as part of an argument to [`Device::poll`] to block for a particular
/// submission to finish.
Expand Down Expand Up @@ -52,6 +68,14 @@ pub struct QueueWriteBufferView<'a> {
#[cfg(send_sync)]
static_assertions::assert_impl_all!(QueueWriteBufferView<'_>: Send, Sync);

impl QueueWriteBufferView<'_> {
#[cfg(custom)]
/// Returns custom implementation of QueueWriteBufferView (if custom backend and is internally T)
pub fn as_custom<T: custom::QueueWriteBufferInterface>(&self) -> Option<&T> {
self.inner.as_custom()
}
}

impl Deref for QueueWriteBufferView<'_> {
type Target = [u8];

Expand Down Expand Up @@ -82,14 +106,6 @@ impl Drop for QueueWriteBufferView<'_> {
}

impl Queue {
#[cfg(custom)]
/// Creates Queue from custom implementation
pub fn from_custom<T: custom::QueueInterface>(queue: T) -> Self {
Self {
inner: dispatch::DispatchQueue::custom(queue),
}
}

/// Copies the bytes of `data` into `buffer` starting at `offset`.
///
/// The data must be written fully in-bounds, that is, `offset + data.len() <= buffer.len()`.
Expand Down
8 changes: 8 additions & 0 deletions wgpu/src/api/render_bundle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,14 @@ static_assertions::assert_impl_all!(RenderBundle: Send, Sync);

crate::cmp::impl_eq_ord_hash_proxy!(RenderBundle => .inner);

impl RenderBundle {
#[cfg(custom)]
/// Returns custom implementation of RenderBundle (if custom backend and is internally T)
pub fn as_custom<T: custom::RenderBundleInterface>(&self) -> Option<&T> {
self.inner.as_custom()
}
}

/// Describes a [`RenderBundle`].
///
/// For use with [`RenderBundleEncoder::finish`].
Expand Down
6 changes: 6 additions & 0 deletions wgpu/src/api/render_bundle_encoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,12 @@ impl<'a> RenderBundleEncoder<'a> {
self.inner
.draw_indexed_indirect(&indirect_buffer.inner, indirect_offset);
}

#[cfg(custom)]
/// Returns custom implementation of RenderBundleEncoder (if custom backend and is internally T)
pub fn as_custom<T: custom::RenderBundleEncoderInterface>(&self) -> Option<&T> {
self.inner.as_custom()
}
}

/// [`Features::PUSH_CONSTANTS`] must be enabled on the device in order to call these functions.
Expand Down
6 changes: 6 additions & 0 deletions wgpu/src/api/render_pass.rs
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,12 @@ impl RenderPass<'_> {
self.inner
.multi_draw_indexed_indirect(&indirect_buffer.inner, indirect_offset, count);
}

#[cfg(custom)]
/// Returns custom implementation of RenderPass (if custom backend and is internally T)
pub fn as_custom<T: custom::RenderPassInterface>(&self) -> Option<&T> {
self.inner.as_custom()
}
}

/// [`Features::MULTI_DRAW_INDIRECT_COUNT`] must be enabled on the device in order to call these functions.
Expand Down
6 changes: 6 additions & 0 deletions wgpu/src/api/render_pipeline.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,12 @@ impl RenderPipeline {
let layout = self.inner.get_bind_group_layout(index);
BindGroupLayout { inner: layout }
}

#[cfg(custom)]
/// Returns custom implementation of RenderPipeline (if custom backend and is internally T)
pub fn as_custom<T: custom::RenderPipelineInterface>(&self) -> Option<&T> {
self.inner.as_custom()
}
}

/// Specifies an interpretation of the bytes of a vertex buffer as vertex attributes.
Expand Down
8 changes: 8 additions & 0 deletions wgpu/src/api/sampler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,14 @@ static_assertions::assert_impl_all!(Sampler: Send, Sync);

crate::cmp::impl_eq_ord_hash_proxy!(Sampler => .inner);

impl Sampler {
#[cfg(custom)]
/// Returns custom implementation of Sampler (if custom backend and is internally T)
pub fn as_custom<T: custom::SamplerInterface>(&self) -> Option<&T> {
self.inner.as_custom()
}
}

/// Describes a [`Sampler`].
///
/// For use with [`Device::create_sampler`].
Expand Down
Loading