diff --git a/src/gpu/pipeline.rs b/src/gpu/pipeline.rs index 0c219f7..c56a97e 100644 --- a/src/gpu/pipeline.rs +++ b/src/gpu/pipeline.rs @@ -2,11 +2,12 @@ use super::{GpuConfig, GpuContext}; use anyhow::Result; -use std::sync::Arc; +use std::collections::HashMap; +use std::sync::{Arc, Mutex, OnceLock}; use tracing::info; use wgpu::{BindGroupLayout, ComputePipeline}; -#[derive(Clone, Copy, Debug, Eq, PartialEq)] +#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)] pub enum WorkgroupVariant { Wg64, Wg128, @@ -31,6 +32,17 @@ pub struct KangarooPipeline { impl KangarooPipeline { pub fn new(ctx: &GpuContext, variant: WorkgroupVariant) -> Result { + static PIPELINE_CACHE: OnceLock< + Mutex>, + > = OnceLock::new(); + + let device_key = Arc::as_ptr(&ctx.device) as usize; + let cache = PIPELINE_CACHE.get_or_init(|| Mutex::new(HashMap::new())); + let mut guard = cache.lock().expect("pipeline cache poisoned"); + if let Some(pipeline) = guard.get(&(device_key, variant)).cloned() { + return Ok(pipeline); + } + info!("Loading shader sources..."); let field = crate::gpu_crypto::shaders::FIELD_WGSL; @@ -148,10 +160,14 @@ impl KangarooPipeline { }); info!("Compute pipeline created"); - Ok(Self { + let pipeline = Self { pipeline: Arc::new(pipeline), bind_group_layout: Arc::new(bind_group_layout), variant, - }) + }; + + guard.insert((device_key, variant), pipeline.clone()); + + Ok(pipeline) } }