diff --git a/cli/src/main.rs b/cli/src/main.rs index 7f20dcc..b163793 100644 --- a/cli/src/main.rs +++ b/cli/src/main.rs @@ -850,6 +850,29 @@ fn cmd_config_show(config: &EchoConfig) { "auto" } ); + println!(); + println!(" Embedding Provider:"); + println!( + " {:25} {:>15} {}", + "embedding_provider", + config.embedding_provider.to_string(), + source( + "SHRIMPK_EMBEDDING_PROVIDER", + fc.embedding_provider.is_some() + ) + ); + println!( + " {:25} {:>15} {}", + "embedding_model", + &config.embedding_model, + source("SHRIMPK_EMBEDDING_MODEL", fc.embedding_model.is_some()) + ); + println!( + " {:25} {:>15} {}", + "embedding_api_url", + &config.embedding_api_url, + source("SHRIMPK_EMBEDDING_API_URL", fc.embedding_api_url.is_some()) + ); } fn cmd_config_set(key: &str, value: &str) -> anyhow::Result<()> { @@ -872,6 +895,11 @@ fn cmd_config_set(key: &str, value: &str) -> anyhow::Result<()> { "enrichment_model" => fc.enrichment_model = Some(value.to_string()), "consolidation_provider" => fc.consolidation_provider = Some(value.to_string()), "max_facts_per_memory" => fc.max_facts_per_memory = Some(value.parse()?), + "embedding_provider" => { + fc.embedding_provider = Some(value.parse().map_err(|e: String| anyhow::anyhow!(e))?) + } + "embedding_model" => fc.embedding_model = Some(value.to_string()), + "embedding_api_url" => fc.embedding_api_url = Some(value.to_string()), other => anyhow::bail!("Unknown config key: \"{other}\""), } diff --git a/crates/shrimpk-core/src/config.rs b/crates/shrimpk-core/src/config.rs index e3254de..80b1521 100644 --- a/crates/shrimpk-core/src/config.rs +++ b/crates/shrimpk-core/src/config.rs @@ -45,6 +45,45 @@ impl std::str::FromStr for RerankerBackend { } } +/// Backend for embedding vector generation. +/// +/// Controls which embedding model and provider is used for memory storage +/// and echo queries. The default `Fastembed` backend uses a local ONNX model +/// (BGE-small-EN-v1.5, 384-dim) with zero external API calls. 
+#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)] +pub enum EmbeddingBackend { + /// Local fastembed ONNX model (default). Zero network calls. + /// Models: BGE-small-EN-v1.5 (384-dim), all-MiniLM-L6-v2 (384-dim), etc. + #[default] + Fastembed, + /// OpenAI-compatible embedding API (local or cloud). + /// Requires `embedding_api_url` and `embedding_model` to be set. + /// Works with: OpenAI, Ollama `/api/embeddings`, LiteLLM, vLLM, etc. + OpenAI, +} + +impl std::fmt::Display for EmbeddingBackend { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Fastembed => write!(f, "fastembed"), + Self::OpenAI => write!(f, "openai"), + } + } +} + +impl std::str::FromStr for EmbeddingBackend { + type Err = String; + fn from_str(s: &str) -> Result { + match s.to_lowercase().as_str() { + "fastembed" | "local" | "onnx" => Ok(Self::Fastembed), + "openai" | "api" | "ollama" => Ok(Self::OpenAI), + _ => Err(format!( + "invalid embedding provider '{s}': expected fastembed or openai" + )), + } + } +} + /// Quantization mode for embedding vectors in the echo index. #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)] pub enum QuantizationMode { @@ -277,6 +316,21 @@ pub struct EchoConfig { /// Embedding dimension for speech channel. Default: 640 (ECAPA-TDNN 256 + Whisper-tiny 384). #[serde(default = "default_speech_dim")] pub speech_embedding_dim: usize, + + // --- Embedding provider (KS75) --- + /// Embedding backend: `Fastembed` (local ONNX, default) or `OpenAI` (API). + #[serde(default)] + pub embedding_provider: EmbeddingBackend, + /// Model name for the embedding provider. + /// Fastembed: "BGE-small-EN-v1.5" (default), "all-MiniLM-L6-v2", etc. + /// OpenAI: "text-embedding-3-small", "nomic-embed-text", Ollama model name, etc. + #[serde(default = "default_embedding_model")] + pub embedding_model: String, + /// API URL for OpenAI-compatible embedding providers. 
+ /// Only used when `embedding_provider = OpenAI`. + /// Default: "http://127.0.0.1:11434" (Ollama). + #[serde(default = "default_embedding_api_url")] + pub embedding_api_url: String, } fn default_true() -> bool { @@ -317,6 +371,13 @@ fn default_speech_dim() -> usize { 640 } +fn default_embedding_model() -> String { + "BGE-small-EN-v1.5".to_string() +} +fn default_embedding_api_url() -> String { + "http://127.0.0.1:11434".to_string() +} + fn default_proxy_target() -> String { "http://127.0.0.1:11434".to_string() } @@ -411,6 +472,9 @@ impl Default for EchoConfig { enabled_modalities: default_modalities(), vision_embedding_dim: default_vision_dim(), speech_embedding_dim: default_speech_dim(), + embedding_provider: EmbeddingBackend::default(), + embedding_model: default_embedding_model(), + embedding_api_url: default_embedding_api_url(), } } } @@ -476,6 +540,37 @@ impl EchoConfig { } } + /// Infer the embedding dimension from the configured model name. + /// + /// Returns the known dimension for well-known models, or falls back to + /// `self.embedding_dim` (the explicitly configured value) if the model + /// is not recognized. This lets users set `embedding_model` without also + /// needing to manually set `embedding_dim`. 
+ pub fn infer_embedding_dim(&self) -> usize { + match self.embedding_model.to_lowercase().as_str() { + // fastembed ONNX models + s if s.contains("bge-small") => 384, + s if s.contains("bge-base") => 768, + s if s.contains("bge-large") => 1024, + s if s.contains("bge-m3") => 1024, + s if s.contains("gte-large") => 1024, + s if s.contains("gte-base") => 768, + s if s.contains("minilm-l6") => 384, + s if s.contains("minilm-l12") => 384, + // OpenAI + s if s.contains("text-embedding-3-small") => 1536, + s if s.contains("text-embedding-3-large") => 3072, + s if s.contains("text-embedding-ada") => 1536, + // Ollama common models + s if s.contains("nomic-embed-text") => 768, + s if s.contains("mxbai-embed-large") => 1024, + s if s.contains("all-minilm") => 384, + s if s.contains("snowflake-arctic-embed") => 1024, + // Fallback to explicit config + _ => self.embedding_dim, + } + } + /// Estimated index size in bytes for the current config. pub fn estimated_index_bytes(&self) -> u64 { let bytes_per_entry = self.quantization.bytes_per_vector(self.embedding_dim) + 100; @@ -537,6 +632,9 @@ pub struct FileConfig { pub enabled_modalities: Option>, pub vision_embedding_dim: Option, pub speech_embedding_dim: Option, + pub embedding_provider: Option, + pub embedding_model: Option, + pub embedding_api_url: Option, } /// Default data directory: `~/.shrimpk-kernel/` @@ -642,6 +740,7 @@ pub fn resolve_config() -> crate::Result { let mut config = EchoConfig::auto_detect(); // Layer 2: file overrides + let mut dim_set_by_file = false; if let Some(fc) = load_config_file()? 
{ if let Some(v) = fc.max_memories { config.max_memories = v; @@ -663,6 +762,7 @@ pub fn resolve_config() -> crate::Result { } if let Some(v) = fc.embedding_dim { config.embedding_dim = v; + dim_set_by_file = true; } if let Some(v) = fc.use_lsh { config.use_lsh = v; @@ -781,6 +881,15 @@ pub fn resolve_config() -> crate::Result { if let Some(v) = fc.speech_embedding_dim { config.speech_embedding_dim = v; } + if let Some(v) = fc.embedding_provider { + config.embedding_provider = v; + } + if let Some(v) = fc.embedding_model { + config.embedding_model = v; + } + if let Some(v) = fc.embedding_api_url { + config.embedding_api_url = v; + } } // Layer 3: env var overrides (highest priority) @@ -844,6 +953,26 @@ pub fn resolve_config() -> crate::Result { config.hebbian_prune_threshold = v; } + if let Ok(v) = std::env::var("SHRIMPK_EMBEDDING_PROVIDER") + && let Ok(provider) = v.parse::() + { + config.embedding_provider = provider; + } + if let Ok(v) = std::env::var("SHRIMPK_EMBEDDING_MODEL") { + config.embedding_model = v; + } + if let Ok(v) = std::env::var("SHRIMPK_EMBEDDING_API_URL") { + config.embedding_api_url = v; + } + + // Auto-infer embedding_dim from model name unless explicitly overridden + // by either env var (Layer 3) or config file (Layer 2). + if let Some(v) = env_usize("SHRIMPK_EMBEDDING_DIM")? 
{ + config.embedding_dim = v; + } else if !dim_set_by_file { + config.embedding_dim = config.infer_embedding_dim(); + } + // Backward compatibility: if reranker_enabled=true but backend=None, default to Llm if config.reranker_enabled && config.reranker_backend == RerankerBackend::None { config.reranker_backend = RerankerBackend::Llm; @@ -1265,4 +1394,108 @@ mod tests { "Explicit backend should override legacy reranker_enabled" ); } + + // --- KS75: EmbeddingBackend --- + + #[test] + fn embedding_provider_default_is_fastembed() { + let config = EchoConfig::default(); + assert_eq!(config.embedding_provider, EmbeddingBackend::Fastembed); + assert_eq!(config.embedding_model, "BGE-small-EN-v1.5"); + } + + #[test] + fn embedding_provider_parse_roundtrip() { + for (input, expected) in [ + ("fastembed", EmbeddingBackend::Fastembed), + ("local", EmbeddingBackend::Fastembed), + ("onnx", EmbeddingBackend::Fastembed), + ("openai", EmbeddingBackend::OpenAI), + ("api", EmbeddingBackend::OpenAI), + ("ollama", EmbeddingBackend::OpenAI), + ] { + let parsed: EmbeddingBackend = input.parse().unwrap(); + assert_eq!(parsed, expected, "parsing '{input}'"); + } + } + + #[test] + fn embedding_provider_parse_invalid() { + assert!("unknown".parse::().is_err()); + } + + #[test] + fn embedding_provider_display() { + assert_eq!(EmbeddingBackend::Fastembed.to_string(), "fastembed"); + assert_eq!(EmbeddingBackend::OpenAI.to_string(), "openai"); + } + + #[test] + fn infer_embedding_dim_known_models() { + let cases: &[(&str, usize)] = &[ + ("BGE-small-EN-v1.5", 384), + ("BGE-base-EN-v1.5", 768), + ("BGE-large-EN-v1.5", 1024), + ("bge-m3", 1024), + ("gte-large-en-v1.5", 1024), + ("gte-base-en-v1.5", 768), + ("text-embedding-3-small", 1536), + ("text-embedding-3-large", 3072), + ("nomic-embed-text", 768), + ]; + for &(model, expected_dim) in cases { + let config = EchoConfig { + embedding_model: model.into(), + ..Default::default() + }; + assert_eq!( + config.infer_embedding_dim(), + expected_dim, + 
"model '{model}'" + ); + } + } + + #[test] + fn infer_embedding_dim_unknown_falls_back() { + let config = EchoConfig { + embedding_model: "my-custom-model".into(), + embedding_dim: 512, + ..Default::default() + }; + assert_eq!( + config.infer_embedding_dim(), + 512, + "Unknown model should fall back to embedding_dim" + ); + } + + #[test] + fn embedding_provider_serde_roundtrip() { + let config = EchoConfig { + embedding_provider: EmbeddingBackend::OpenAI, + embedding_model: "text-embedding-3-small".into(), + embedding_api_url: "https://api.openai.com".into(), + ..Default::default() + }; + let json = serde_json::to_string(&config).unwrap(); + let parsed: EchoConfig = serde_json::from_str(&json).unwrap(); + assert_eq!(parsed.embedding_provider, EmbeddingBackend::OpenAI); + assert_eq!(parsed.embedding_model, "text-embedding-3-small"); + assert_eq!(parsed.embedding_api_url, "https://api.openai.com"); + } + + #[test] + fn file_config_embedding_fields_toml_roundtrip() { + let fc = FileConfig { + embedding_provider: Some(EmbeddingBackend::OpenAI), + embedding_model: Some("nomic-embed-text".into()), + embedding_api_url: Some("http://localhost:11434".into()), + ..Default::default() + }; + let toml_str = toml::to_string_pretty(&fc).unwrap(); + let parsed: FileConfig = toml::from_str(&toml_str).unwrap(); + assert_eq!(parsed.embedding_provider, Some(EmbeddingBackend::OpenAI)); + assert_eq!(parsed.embedding_model, Some("nomic-embed-text".into())); + } } diff --git a/crates/shrimpk-core/src/lib.rs b/crates/shrimpk-core/src/lib.rs index 1658303..0dd9f8c 100644 --- a/crates/shrimpk-core/src/lib.rs +++ b/crates/shrimpk-core/src/lib.rs @@ -13,8 +13,8 @@ pub mod traits; // Re-export commonly used types at crate root pub use config::{ - EchoConfig, FileConfig, QuantizationMode, RerankerBackend, config_dir, config_path, disk_usage, - load_config_file, resolve_config, save_config_file, + EchoConfig, EmbeddingBackend, FileConfig, QuantizationMode, RerankerBackend, config_dir, + 
config_path, disk_usage, load_config_file, resolve_config, save_config_file, }; pub use entity::{EntityFrame, EntityId, EntityKind}; pub use error::{Result, ShrimPKError}; @@ -26,5 +26,6 @@ pub use memory::{ }; pub use pii::{PiiMatch, PiiType}; pub use traits::{ - ConsolidationOutput, Consolidator, ExtractedFact, FactType, LabelSet, ModelBackend, Provider, + ConsolidationOutput, Consolidator, EmbeddingProvider, ExtractedFact, FactType, LabelSet, + ModelBackend, Provider, }; diff --git a/crates/shrimpk-core/src/traits.rs b/crates/shrimpk-core/src/traits.rs index 5540640..4e119db 100644 --- a/crates/shrimpk-core/src/traits.rs +++ b/crates/shrimpk-core/src/traits.rs @@ -86,6 +86,37 @@ pub struct ModelCapabilities { pub is_local: bool, } +/// A backend that generates embedding vectors from text. +/// +/// Implementations wrap either a local ONNX model (fastembed) or an HTTP API +/// (OpenAI-compatible). The engine holds the provider behind a `Mutex` since +/// fastembed requires `&mut self` for inference. +/// +/// # Providers +/// - `FastembedProvider` — local ONNX (default, offline, zero API calls) +/// - `OpenAIProvider` — any OpenAI-compatible API (cloud or local Ollama) +/// +/// # Contract +/// - `embed` and `embed_batch` must return vectors of length `dimension()`. +/// - On error, return `Err` — never panic. +/// - `Send` but not `Sync` (fastembed's `TextEmbedding` is `!Sync`). +pub trait EmbeddingProvider: Send { + /// Embed a single text string, returning a vector of length `dimension()`. + fn embed(&mut self, text: &str) -> Result>; + + /// Embed a batch of texts. Default implementation calls `embed()` in a loop. + /// Providers with native batching (fastembed, OpenAI) should override this. + fn embed_batch(&mut self, texts: Vec) -> Result>> { + texts.iter().map(|t| self.embed(t)).collect() + } + + /// The embedding dimension this provider produces (e.g., 384, 768, 1536). 
+ fn dimension(&self) -> usize; + + /// Human-readable name (e.g., "fastembed/bge-small-en-v1.5", "openai/text-embedding-3-small"). + fn name(&self) -> &str; +} + /// Type classification for extracted facts (KS67 — schema-driven extraction). /// /// Maps to distinct retrieval patterns and supersession thresholds. diff --git a/crates/shrimpk-daemon/src/routes.rs b/crates/shrimpk-daemon/src/routes.rs index f969ef7..52e34a1 100644 --- a/crates/shrimpk-daemon/src/routes.rs +++ b/crates/shrimpk-daemon/src/routes.rs @@ -346,7 +346,10 @@ pub async fn config_show(State(state): State) -> Json { "proxy_max_echo_results": c.proxy_max_echo_results, "hebbian_half_life_secs": c.hebbian_half_life_secs, "hebbian_prune_threshold": c.hebbian_prune_threshold, - "proxy_max_conversation_turns": c.proxy_max_conversation_turns + "proxy_max_conversation_turns": c.proxy_max_conversation_turns, + "embedding_provider": c.embedding_provider.to_string(), + "embedding_model": c.embedding_model, + "embedding_api_url": c.embedding_api_url })) } @@ -440,6 +443,15 @@ pub async fn config_set( ) })?) 
} + "embedding_provider" => { + fc.embedding_provider = Some( + req.value + .parse() + .map_err(|e: String| (StatusCode::BAD_REQUEST, Json(json!({"error": e}))))?, + ) + } + "embedding_model" => fc.embedding_model = Some(req.value.clone()), + "embedding_api_url" => fc.embedding_api_url = Some(req.value.clone()), other => { return Err(( StatusCode::BAD_REQUEST, diff --git a/crates/shrimpk-mcp/src/tools.rs b/crates/shrimpk-mcp/src/tools.rs index 359025e..3ec2ac2 100644 --- a/crates/shrimpk-mcp/src/tools.rs +++ b/crates/shrimpk-mcp/src/tools.rs @@ -701,6 +701,29 @@ pub fn handle_config_show(config: &EchoConfig) -> Result { } ), String::new(), + " Embedding Provider:".to_string(), + format!( + " {:25} {:>15} {}", + "embedding_provider", + config.embedding_provider.to_string(), + source( + "SHRIMPK_EMBEDDING_PROVIDER", + fc.embedding_provider.is_some() + ) + ), + format!( + " {:25} {:>15} {}", + "embedding_model", + truncate(&config.embedding_model, 30), + source("SHRIMPK_EMBEDDING_MODEL", fc.embedding_model.is_some()) + ), + format!( + " {:25} {:>15} {}", + "embedding_api_url", + truncate(&config.embedding_api_url, 30), + source("SHRIMPK_EMBEDDING_API_URL", fc.embedding_api_url.is_some()) + ), + String::new(), " Intelligence Engine:".to_string(), format!( " {:25} {:>15} {}", @@ -817,6 +840,9 @@ pub fn handle_config_set(args: &Value) -> Result { "use_full_actr_history" => { fc.use_full_actr_history = Some(value.parse().map_err(|_| "Invalid boolean")?) 
} + "embedding_provider" => fc.embedding_provider = Some(value.parse().map_err(|e: String| e)?), + "embedding_model" => fc.embedding_model = Some(value.to_string()), + "embedding_api_url" => fc.embedding_api_url = Some(value.to_string()), other => return Err(format!("Unknown config key: \"{other}\"")), } diff --git a/crates/shrimpk-memory/src/echo.rs b/crates/shrimpk-memory/src/echo.rs index b362c90..9bafce4 100644 --- a/crates/shrimpk-memory/src/echo.rs +++ b/crates/shrimpk-memory/src/echo.rs @@ -182,7 +182,7 @@ impl EchoEngine { /// Returns `ShrimPKError::Embedding` if the model fails to initialize. #[instrument(skip(config), fields(max_memories = config.max_memories, threshold = config.similarity_threshold))] pub fn new(config: EchoConfig) -> Result { - let mut embedder = MultiEmbedder::new()?; + let mut embedder = MultiEmbedder::new(&config)?; // Initialize label prototypes BEFORE wrapping embedder in Mutex (ADR-015 D4). // Prototype embeddings are computed once at startup. @@ -194,13 +194,43 @@ impl EchoEngine { let pii_filter = PiiFilter::new(); let reformulator = MemoryReformulator::new(); let store = RwLock::new(EchoStore::new()); - let text_lsh = CosineHash::new(config.embedding_dim, 16, 10); + // KS75: use embedder's actual dimension, not config's possibly-stale value + let text_dim = embedder.text_dimension(); + let text_lsh = CosineHash::new(text_dim, 16, 10); let bloom = TopicFilter::new(config.max_memories, 0.01); + // Pre-build vision/speech LSH before embedder is moved into Mutex + #[cfg(feature = "vision")] + let vision_lsh_init = if config + .enabled_modalities + .contains(&shrimpk_core::Modality::Vision) + { + Some(Mutex::new(CosineHash::new( + embedder.vision_dimension(), + 16, + 10, + ))) + } else { + None + }; + #[cfg(feature = "speech")] + let speech_lsh_init = if config + .enabled_modalities + .contains(&shrimpk_core::Modality::Speech) + { + Some(Mutex::new(CosineHash::new( + embedder.speech_dimension(), + 16, + 10, + ))) + } else { + None + 
}; + tracing::info!( max_memories = config.max_memories, threshold = config.similarity_threshold, - dim = config.embedding_dim, + dim = text_dim, use_lsh = config.use_lsh, use_bloom = config.use_bloom, "EchoEngine initialized (empty store)" @@ -213,31 +243,9 @@ impl EchoEngine { store, text_lsh: Mutex::new(text_lsh), #[cfg(feature = "vision")] - vision_lsh: if config - .enabled_modalities - .contains(&shrimpk_core::Modality::Vision) - { - Some(Mutex::new(CosineHash::new( - config.vision_embedding_dim, - 16, - 10, - ))) - } else { - None - }, + vision_lsh: vision_lsh_init, #[cfg(feature = "speech")] - speech_lsh: if config - .enabled_modalities - .contains(&shrimpk_core::Modality::Speech) - { - Some(Mutex::new(CosineHash::new( - config.speech_embedding_dim, - 16, - 10, - ))) - } else { - None - }, + speech_lsh: speech_lsh_init, bloom: RwLock::new(bloom), bloom_dirty: Mutex::new(false), pii_filter, @@ -346,12 +354,7 @@ impl EchoEngine { // - Reformulated text if available (structured form embeds better) // - Otherwise original text (semantic meaning preserved) let embed_text = reformulated.as_deref().unwrap_or(text); - let embedding = { - let mut embedder = self.embedder.lock().map_err(|e| { - ShrimPKError::Embedding(format!("MultiEmbedder lock poisoned: {e}")) - })?; - embedder.embed_text(embed_text)? - }; + let embedding = self.embed_blocking(|e| e.embed_text(embed_text))?; // 4. Build entry with auto-categorization for adaptive decay let category = self.reformulator.categorize(text); @@ -559,12 +562,7 @@ impl EchoEngine { } // 1. Embed image with CLIP - let vision_embedding = { - let mut embedder = self.embedder.lock().map_err(|e| { - ShrimPKError::Embedding(format!("MultiEmbedder lock poisoned: {e}")) - })?; - embedder.embed_image(image_data)? 
- }; + let vision_embedding = self.embed_blocking(|e| e.embed_image(image_data))?; let vision_embedding = vision_embedding.ok_or_else(|| { ShrimPKError::Embedding("Vision model not available — cannot embed image".into()) @@ -573,10 +571,7 @@ impl EchoEngine { // 2. Build content and optional text embedding for cross-modal recall let content = description.unwrap_or("[image]").to_string(); let text_embedding = if let Some(desc) = description { - let mut embedder = self.embedder.lock().map_err(|e| { - ShrimPKError::Embedding(format!("MultiEmbedder lock poisoned: {e}")) - })?; - embedder.embed_text(desc)? + self.embed_blocking(|e| e.embed_text(desc))? } else { Vec::new() }; @@ -700,12 +695,7 @@ impl EchoEngine { } // 1. Embed audio with speech stack - let speech_embedding = { - let mut embedder = self.embedder.lock().map_err(|e| { - ShrimPKError::Embedding(format!("MultiEmbedder lock poisoned: {e}")) - })?; - embedder.embed_audio(pcm_f32, sample_rate)? - }; + let speech_embedding = self.embed_blocking(|e| e.embed_audio(pcm_f32, sample_rate))?; let speech_embedding = speech_embedding.ok_or_else(|| { ShrimPKError::Embedding("Speech models not available — cannot embed audio".into()) @@ -714,10 +704,7 @@ impl EchoEngine { // 2. Build content and optional text embedding for cross-modal recall let content = description.unwrap_or("[audio]").to_string(); let text_embedding = if let Some(desc) = description { - let mut embedder = self.embedder.lock().map_err(|e| { - ShrimPKError::Embedding(format!("MultiEmbedder lock poisoned: {e}")) - })?; - embedder.embed_text(desc)? + self.embed_blocking(|e| e.embed_text(desc))? 
} else { Vec::new() }; @@ -840,35 +827,32 @@ impl EchoEngine { let reformulated = self.reformulator.reformulate(text_for_reformulation); let embed_text = reformulated.as_deref().unwrap_or(text); - let (text_embedding, vision_embedding, speech_embedding) = { - let mut embedder = self.embedder.lock().map_err(|e| { - ShrimPKError::Embedding(format!("MultiEmbedder lock poisoned: {e}")) - })?; - - let text_emb = embedder.embed_text(embed_text)?; + let (text_embedding, vision_embedding, speech_embedding) = + self.embed_blocking(|embedder| { + let text_emb = embedder.embed_text(embed_text)?; - // 2. Optional vision embedding - #[cfg(feature = "vision")] - let vis_emb = if let Some(img) = image_data { - embedder.embed_image(img)? - } else { - None - }; - #[cfg(not(feature = "vision"))] - let vis_emb: Option> = None; + // 2. Optional vision embedding + #[cfg(feature = "vision")] + let vis_emb = if let Some(img) = image_data { + embedder.embed_image(img)? + } else { + None + }; + #[cfg(not(feature = "vision"))] + let vis_emb: Option> = None; - // 3. Optional speech embedding - #[cfg(feature = "speech")] - let speech_emb = if let Some((pcm, sr)) = audio_pcm { - embedder.embed_audio(pcm, sr)? - } else { - None - }; - #[cfg(not(feature = "speech"))] - let speech_emb: Option> = None; + // 3. Optional speech embedding + #[cfg(feature = "speech")] + let speech_emb = if let Some((pcm, sr)) = audio_pcm { + embedder.embed_audio(pcm, sr)? + } else { + None + }; + #[cfg(not(feature = "speech"))] + let speech_emb: Option> = None; - (text_emb, vis_emb, speech_emb) - }; + Ok((text_emb, vis_emb, speech_emb)) + })?; // 4. Build entry with all embeddings let category = self.reformulator.categorize(text); @@ -1070,12 +1054,7 @@ impl EchoEngine { }; // 1. Embed the (possibly expanded) query - let query_embedding = { - let mut embedder = self.embedder.lock().map_err(|e| { - ShrimPKError::Embedding(format!("MultiEmbedder lock poisoned: {e}")) - })?; - embedder.embed_text(&effective_query)? 
- }; + let query_embedding = self.embed_blocking(|e| e.embed_text(&effective_query))?; // 2. Bloom filter pre-check — skip everything if no fingerprints match. // Bypass for small stores (< 50 entries) where Bloom adds risk without benefit. @@ -1817,12 +1796,7 @@ impl EchoEngine { let start = std::time::Instant::now(); // 1. Embed query with CLIP text encoder - let query_embedding = { - let mut embedder = self.embedder.lock().map_err(|e| { - ShrimPKError::Embedding(format!("MultiEmbedder lock poisoned: {e}")) - })?; - embedder.embed_text_for_vision(query)? - }; + let query_embedding = self.embed_blocking(|e| e.embed_text_for_vision(query))?; let query_embedding = match query_embedding { Some(emb) => emb, @@ -2300,12 +2274,7 @@ impl EchoEngine { } // Embed the entity name for ranking - let mut embedder = self - .embedder - .lock() - .map_err(|e| ShrimPKError::Memory(format!("lock: {e}")))?; - let query_emb = embedder.embed_text(entity)?; - drop(embedder); + let query_emb = self.embed_blocking(|e| e.embed_text(entity))?; let mut scored: Vec<(usize, f32)> = indices .iter() @@ -2662,6 +2631,12 @@ impl EchoEngine { // Persist entity store sidecar (KS73) crate::persistence::save_entities(&store, &self.config.data_dir)?; + // KS75: Write embedding model name sidecar for mismatch detection on next load + if let Ok(embedder) = self.embedder.lock() { + let model_path = self.config.data_dir.join("embedding_model.txt"); + let _ = std::fs::write(&model_path, embedder.text_provider_name()); + } + Ok(()) } @@ -2675,7 +2650,7 @@ impl EchoEngine { /// Returns `ShrimPKError::Persistence` if store file is corrupted. 
#[instrument(skip(config), fields(data_dir = %config.data_dir.display()))] pub fn load(config: EchoConfig) -> Result { - let mut embedder = MultiEmbedder::new()?; + let mut embedder = MultiEmbedder::new(&config)?; // Initialize label prototypes (ADR-015) let mut prototypes = crate::labels::LabelPrototypes::new_empty(); @@ -2689,6 +2664,37 @@ impl EchoEngine { let store_path = config.data_dir.join("echo_store.shrm"); let mut loaded_store = EchoStore::load(&store_path)?; + // KS75: Dimension mismatch detection — hard error if stored vectors don't match config + if let Some(first_emb) = loaded_store.all_embeddings().first() { + let stored_dim = first_emb.len(); + let config_dim = embedder.text_dimension(); + if stored_dim != config_dim { + return Err(ShrimPKError::Embedding(format!( + "Embedding dimension mismatch: stored data has {stored_dim}-dim vectors \ + but current model '{}' produces {config_dim}-dim. \ + Either switch back to the original model or clear the store with /api/clear.", + embedder.text_provider_name() + ))); + } + } + + // KS75: Model name sidecar — warn if model changed (same dim, different model = mixed space) + let model_sidecar = config.data_dir.join("embedding_model.txt"); + if model_sidecar.exists() + && let Ok(stored_model) = std::fs::read_to_string(&model_sidecar) + { + let stored_model = stored_model.trim(); + let current_model = embedder.text_provider_name(); + if stored_model != current_model && !loaded_store.all_entries().is_empty() { + tracing::warn!( + stored_model = %stored_model, + current_model = %current_model, + "Embedding model changed since last persist. \ + Vectors from different models in the same space may degrade similarity quality." 
+ ); + } + } + // Load community summaries sidecar (KS64) if let Err(e) = crate::persistence::load_community_summaries(&mut loaded_store, &config.data_dir) @@ -2701,8 +2707,11 @@ impl EchoEngine { tracing::warn!(error = %e, "Failed to load entities, continuing without"); } + // KS75: use embedder's actual dimension, not config's possibly-stale value + let text_dim = embedder.text_dimension(); + // Rebuild text LSH index from loaded embeddings - let mut text_lsh = CosineHash::new(config.embedding_dim, 16, 10); + let mut text_lsh = CosineHash::new(text_dim, 16, 10); if config.use_lsh { for (i, embedding) in loaded_store.all_embeddings().iter().enumerate() { text_lsh.insert(i as u32, embedding); @@ -2738,7 +2747,7 @@ impl EchoEngine { .enabled_modalities .contains(&shrimpk_core::Modality::Vision) { - let mut vlsh = CosineHash::new(config.vision_embedding_dim, 16, 10); + let mut vlsh = CosineHash::new(embedder.vision_dimension(), 16, 10); let mut vision_count = 0usize; for (i, entry) in loaded_store.all_entries().iter().enumerate() { if let Some(ref ve) = entry.vision_embedding { @@ -2763,7 +2772,7 @@ impl EchoEngine { .enabled_modalities .contains(&shrimpk_core::Modality::Speech) { - let mut slsh = CosineHash::new(config.speech_embedding_dim, 16, 10); + let mut slsh = CosineHash::new(embedder.speech_dimension(), 16, 10); let mut speech_count = 0usize; for (i, entry) in loaded_store.all_entries().iter().enumerate() { if let Some(ref se) = entry.speech_embedding { @@ -2967,6 +2976,40 @@ impl EchoEngine { } } + /// Lock the embedder and run a blocking embedding operation. + /// + /// On a **multi-thread** Tokio runtime (the daemon) this uses + /// `tokio::task::block_in_place` to inform the scheduler that the + /// current thread will block, preventing worker-thread starvation. + /// On a **current-thread** runtime (`#[tokio::test]`) or outside + /// Tokio entirely (sync tests, CLI) we call `f` directly, because + /// `block_in_place` panics on a single-threaded runtime. 
+ fn embed_blocking(&self, f: F) -> Result + where + F: FnOnce(&mut MultiEmbedder) -> Result, + { + let embedder_mutex = &self.embedder; + // Use block_in_place on multi-thread runtime to prevent worker starvation. + // Both the lock acquisition AND inference run inside block_in_place so that + // a contended lock() does not silently block a Tokio worker thread. + match tokio::runtime::Handle::try_current() { + Ok(handle) if handle.runtime_flavor() == tokio::runtime::RuntimeFlavor::MultiThread => { + tokio::task::block_in_place(|| { + let mut embedder = embedder_mutex.lock().map_err(|e| { + ShrimPKError::Embedding(format!("MultiEmbedder lock poisoned: {e}")) + })?; + f(&mut embedder) + }) + } + _ => { + let mut embedder = embedder_mutex.lock().map_err(|e| { + ShrimPKError::Embedding(format!("MultiEmbedder lock poisoned: {e}")) + })?; + f(&mut embedder) + } + } + } + /// Test-only: generate an embedding for text using the engine's embedder. /// /// Provides access to the same embedding model used by `store()` so that diff --git a/crates/shrimpk-memory/src/embedder.rs b/crates/shrimpk-memory/src/embedder.rs index 402d03e..d6c7207 100644 --- a/crates/shrimpk-memory/src/embedder.rs +++ b/crates/shrimpk-memory/src/embedder.rs @@ -1,34 +1,33 @@ -//! Multi-channel embedding via fastembed. +//! Multi-channel embedding with pluggable text provider (KS75). //! -//! Wraps `fastembed::TextEmbedding` with the BGE-small-EN-v1.5 model -//! for 384-dimensional sentence embeddings. Vision (CLIP 512-dim) and -//! Speech (640-dim) channels are gated behind `vision` and `speech` -//! feature flags. -//! -//! When `vision` is enabled, loads two additional models: -//! - CLIP ViT-B-32 *vision* encoder (`ImageEmbedding`) — embeds images to 512-dim. -//! - CLIP ViT-B-32 *text* encoder (`TextEmbedding`) — embeds text to the same 512-dim -//! space, enabling cross-modal text-to-image retrieval. +//! Text channel delegates to an `EmbeddingProvider` implementation selected +//! 
at runtime via `EchoConfig` (default: fastembed BGE-small-EN-v1.5, 384-dim). +//! Vision (CLIP 512-dim) and Speech (640-dim) channels are gated behind +//! `vision` and `speech` feature flags. +#[cfg(feature = "vision")] use fastembed::{EmbeddingModel, InitOptions, TextEmbedding}; -use shrimpk_core::{Result, ShrimPKError}; +#[cfg(any(feature = "vision", feature = "speech"))] +use shrimpk_core::ShrimPKError; +use shrimpk_core::{EchoConfig, EmbeddingProvider, Result}; use tracing::instrument; /// Multi-channel embedder for text, vision, and speech modalities. /// -/// Text channel (always available): BGE-small-EN-v1.5, 384-dim. +/// Text channel delegates to a pluggable `EmbeddingProvider` (KS75). /// Vision channel (feature = "vision"): CLIP ViT-B-32, 512-dim. /// Speech channel (feature = "speech"): ECAPA-TDNN (256) + Whisper-tiny encoder (384) = 640-dim. /// -/// Thread-safe: `TextEmbedding` and `ImageEmbedding` are `Send` (but not `Sync`), +/// Thread-safe: providers are `Send` (but not `Sync`), /// so share via `Mutex` or create per-thread instances. pub struct MultiEmbedder { - text: TextEmbedding, + /// Pluggable text embedding provider (fastembed or OpenAI-compatible API). + text_provider: Box, /// CLIP vision encoder — embeds images into 512-dim CLIP space. #[cfg(feature = "vision")] vision: Option, /// CLIP text encoder — embeds text into the same 512-dim CLIP space. - /// Separate from `text` (MiniLM 384-dim) because the embedding spaces are incompatible. + /// Separate from text provider because the embedding spaces are incompatible. #[cfg(feature = "vision")] vision_text: Option, /// Speech embedder — 2 ONNX models producing a 640-dim paralinguistic embedding. @@ -38,32 +37,18 @@ pub struct MultiEmbedder { } impl MultiEmbedder { - /// Initialize the multi-channel embedder. + /// Initialize the multi-channel embedder from config. /// - /// Always loads the text model (BGE-small-EN-v1.5, 384-dim). 
- /// When the `vision` feature is enabled, also attempts to load - /// CLIP ViT-B-32 vision + text encoders (512-dim). If CLIP fails - /// to initialize, vision is disabled gracefully — text still works. + /// The text channel is delegated to an `EmbeddingProvider` selected by + /// `config.embedding_provider` (default: fastembed BGE-small-EN-v1.5). + /// Vision/speech channels are unchanged (compile-time feature flags). /// /// # Errors - /// Returns `ShrimPKError::Embedding` if the *text* model fails to initialize. + /// Returns `ShrimPKError::Embedding` if the text provider fails to initialize. /// Vision model failures are logged as warnings and result in `vision = None`. - #[instrument] - pub fn new() -> Result { - let start = std::time::Instant::now(); - - let text = TextEmbedding::try_new(InitOptions::new(EmbeddingModel::BGESmallENV15)) - .map_err(|e| { - ShrimPKError::Embedding(format!("Failed to init BGE-small-EN-v1.5: {e}")) - })?; - - let elapsed = start.elapsed(); - tracing::info!( - elapsed_ms = elapsed.as_millis(), - model = "BGE-small-EN-v1.5", - dim = 384, - "MultiEmbedder initialized (text channel)" - ); + #[instrument(skip(config))] + pub fn new(config: &EchoConfig) -> Result { + let text_provider = crate::embedding_provider::from_config(config)?; #[cfg(feature = "vision")] let (vision, vision_text) = { @@ -119,7 +104,7 @@ impl MultiEmbedder { }; Ok(Self { - text, + text_provider, #[cfg(feature = "vision")] vision, #[cfg(feature = "vision")] @@ -129,7 +114,7 @@ impl MultiEmbedder { }) } - /// Embed a single text string into a 384-dimensional vector. + /// Embed a single text string into a vector of `text_dimension()` dimensions. /// /// # Errors /// Returns `ShrimPKError::Embedding` if embedding generation fails. 
@@ -137,20 +122,11 @@ impl MultiEmbedder { pub fn embed_text(&mut self, text: &str) -> Result<Vec<f32>> { let start = std::time::Instant::now(); - let results = self - .text - .embed(vec![text.to_string()], None) - .map_err(|e| ShrimPKError::Embedding(format!("Embed failed: {e}")))?; - - let embedding = results - .into_iter() - .next() - .ok_or_else(|| ShrimPKError::Embedding("Empty embedding result".into()))?; + let embedding = self.text_provider.embed(text)?; - let elapsed = start.elapsed(); tracing::debug!( dim = embedding.len(), - elapsed_us = elapsed.as_micros(), + elapsed_us = start.elapsed().as_micros(), "Single text embed complete" ); @@ -159,8 +135,7 @@ impl MultiEmbedder { /// Batch-embed multiple texts. /// - /// More efficient than calling `embed_text()` in a loop because - /// fastembed batches the ONNX inference. + /// Delegates to the provider's native batch implementation for efficiency. /// /// # Errors /// Returns `ShrimPKError::Embedding` if any embedding generation fails. @@ -173,17 +148,13 @@ impl MultiEmbedder { let start = std::time::Instant::now(); let count = texts.len(); - let results = self - .text - .embed(texts, None) - .map_err(|e| ShrimPKError::Embedding(format!("Batch embed failed: {e}")))?; + let results = self.text_provider.embed_batch(texts)?; - let elapsed = start.elapsed(); tracing::debug!( count = count, - elapsed_ms = elapsed.as_millis(), + elapsed_ms = start.elapsed().as_millis(), avg_us = if count > 0 { - elapsed.as_micros() / count as u128 + start.elapsed().as_micros() / count as u128 } else { 0 }, @@ -193,9 +164,14 @@ impl MultiEmbedder { Ok(results) } - /// Get the text embedding dimension (384 for BGE-small-EN-v1.5). + /// Get the text embedding dimension from the active provider. pub fn text_dimension(&self) -> usize { - 384 + self.text_provider.dimension() + } + + /// Get the human-readable name of the active text embedding provider. 
+ pub fn text_provider_name(&self) -> &str { + self.text_provider.name() } /// Embed an image into a 512-dimensional CLIP vector. @@ -337,14 +313,16 @@ mod tests { #[test] #[ignore = "requires fastembed model download"] fn embedder_initializes() { - let embedder = MultiEmbedder::new().expect("MultiEmbedder should init"); + let embedder = + MultiEmbedder::new(&EchoConfig::default()).expect("MultiEmbedder should init"); assert_eq!(embedder.text_dimension(), 384); } #[test] #[ignore = "requires fastembed model download"] fn embed_single_text() { - let mut embedder = MultiEmbedder::new().expect("MultiEmbedder should init"); + let mut embedder = + MultiEmbedder::new(&EchoConfig::default()).expect("MultiEmbedder should init"); let embedding = embedder.embed_text("Hello world").expect("Should embed"); assert_eq!( embedding.len(), @@ -356,7 +334,8 @@ mod tests { #[test] #[ignore = "requires fastembed model download"] fn embed_batch_texts() { - let mut embedder = MultiEmbedder::new().expect("MultiEmbedder should init"); + let mut embedder = + MultiEmbedder::new(&EchoConfig::default()).expect("MultiEmbedder should init"); let texts = vec![ "The cat sat on the mat".to_string(), "Dogs are loyal companions".to_string(), @@ -372,7 +351,8 @@ mod tests { #[test] #[ignore = "requires fastembed model download"] fn similar_texts_have_higher_similarity() { - let mut embedder = MultiEmbedder::new().expect("MultiEmbedder should init"); + let mut embedder = + MultiEmbedder::new(&EchoConfig::default()).expect("MultiEmbedder should init"); let cat = embedder.embed_text("The cat sat on the mat").unwrap(); let kitten = embedder.embed_text("A kitten rests on a rug").unwrap(); let code = embedder @@ -392,7 +372,8 @@ mod tests { #[test] #[ignore = "requires fastembed model download"] fn embed_batch_empty_returns_empty() { - let mut embedder = MultiEmbedder::new().expect("MultiEmbedder should init"); + let mut embedder = + MultiEmbedder::new(&EchoConfig::default()).expect("MultiEmbedder should 
init"); let embeddings = embedder .embed_batch(Vec::new()) .expect("Should handle empty"); @@ -406,7 +387,8 @@ mod tests { #[test] #[ignore = "requires CLIP model download (~352 MB)"] fn clip_vision_initializes() { - let embedder = MultiEmbedder::new().expect("MultiEmbedder should init"); + let embedder = + MultiEmbedder::new(&EchoConfig::default()).expect("MultiEmbedder should init"); assert!(embedder.has_vision(), "CLIP vision should be available"); assert_eq!(embedder.vision_dimension(), 512); } @@ -415,7 +397,8 @@ mod tests { #[test] #[ignore = "requires CLIP model download (~352 MB)"] fn embed_image_produces_512_dim() { - let mut embedder = MultiEmbedder::new().expect("MultiEmbedder should init"); + let mut embedder = + MultiEmbedder::new(&EchoConfig::default()).expect("MultiEmbedder should init"); // Create a minimal 2x2 red PNG image let png_data = create_test_png(2, 2, [255, 0, 0]); @@ -443,7 +426,8 @@ mod tests { #[test] #[ignore = "requires CLIP model download (~352 MB)"] fn embed_text_for_vision_produces_512_dim() { - let mut embedder = MultiEmbedder::new().expect("MultiEmbedder should init"); + let mut embedder = + MultiEmbedder::new(&EchoConfig::default()).expect("MultiEmbedder should init"); let embedding = embedder .embed_text_for_vision("a photo of a cat") @@ -463,7 +447,8 @@ mod tests { fn clip_cross_modal_similarity() { // CLIP's key property: text and image embeddings in the same space // should have positive similarity for matching concepts. 
- let mut embedder = MultiEmbedder::new().expect("MultiEmbedder should init"); + let mut embedder = + MultiEmbedder::new(&EchoConfig::default()).expect("MultiEmbedder should init"); // Embed a red image let red_png = create_test_png(32, 32, [255, 0, 0]); @@ -496,7 +481,7 @@ mod tests { #[ignore = "requires CLIP model download (~352 MB)"] fn clip_init_latency_under_5s() { let start = std::time::Instant::now(); - let _embedder = MultiEmbedder::new().expect("Should init"); + let _embedder = MultiEmbedder::new(&EchoConfig::default()).expect("Should init"); let elapsed = start.elapsed(); assert!( elapsed.as_secs() < 10, // generous to account for cold cache diff --git a/crates/shrimpk-memory/src/embedding_provider.rs b/crates/shrimpk-memory/src/embedding_provider.rs new file mode 100644 index 0000000..263c3b4 --- /dev/null +++ b/crates/shrimpk-memory/src/embedding_provider.rs @@ -0,0 +1,427 @@ +//! Pluggable embedding provider implementations (KS75). +//! +//! Two backends: +//! - `FastembedProvider` — local ONNX via `fastembed` (default, zero API calls) +//! - `OpenAIProvider` — any OpenAI-compatible embedding API (cloud or local Ollama) +//! +//! Factory function `from_config()` selects the appropriate provider based on `EchoConfig`. + +use fastembed::{EmbeddingModel, InitOptions, TextEmbedding}; +use shrimpk_core::{EchoConfig, EmbeddingBackend, EmbeddingProvider, Result, ShrimPKError}; + +// --------------------------------------------------------------------------- +// FastembedProvider +// --------------------------------------------------------------------------- + +/// Local ONNX embedding via fastembed. +/// +/// Wraps `fastembed::TextEmbedding` with runtime model selection. +/// Zero external API calls — all inference runs locally. +pub struct FastembedProvider { + model: TextEmbedding, + dim: usize, + model_name: String, +} + +impl FastembedProvider { + /// Create a new FastembedProvider for the given model name. 
+ /// + /// Supported model names (case-insensitive): + /// - "bge-small-en-v1.5" (384-dim, default) + /// - "bge-base-en-v1.5" (768-dim) + /// - "bge-large-en-v1.5" (1024-dim) + /// - "bge-m3" (1024-dim) + /// - "all-minilm-l6-v2" (384-dim) + /// - "all-minilm-l12-v2" (384-dim) + /// - "nomic-embed-text-v1.5" (768-dim) + /// - "mxbai-embed-large-v1" (1024-dim) + /// - "gte-large-en-v1.5" (1024-dim) + pub fn new(model_name: &str) -> Result<Self> { + let (variant, dim) = resolve_fastembed_model(model_name)?; + let display_name = format!("fastembed/{model_name}"); + + let start = std::time::Instant::now(); + let model = TextEmbedding::try_new(InitOptions::new(variant)).map_err(|e| { + ShrimPKError::Embedding(format!( + "Failed to init fastembed model '{model_name}': {e}" + )) + })?; + + tracing::info!( + elapsed_ms = start.elapsed().as_millis(), + model = %display_name, + dim = dim, + "FastembedProvider initialized" + ); + + Ok(Self { + model, + dim, + model_name: display_name, + }) + } +} + +impl EmbeddingProvider for FastembedProvider { + fn embed(&mut self, text: &str) -> Result<Vec<f32>> { + let results = self + .model + .embed(vec![text.to_string()], None) + .map_err(|e| ShrimPKError::Embedding(format!("fastembed embed failed: {e}")))?; + + results + .into_iter() + .next() + .ok_or_else(|| ShrimPKError::Embedding("Empty fastembed result".into())) + } + + fn embed_batch(&mut self, texts: Vec<String>) -> Result<Vec<Vec<f32>>> { + if texts.is_empty() { + return Ok(Vec::new()); + } + self.model + .embed(texts, None) + .map_err(|e| ShrimPKError::Embedding(format!("fastembed batch embed failed: {e}"))) + } + + fn dimension(&self) -> usize { + self.dim + } + + fn name(&self) -> &str { + &self.model_name + } +} + +/// Map a config model name to a fastembed `EmbeddingModel` variant + dimension. 
+fn resolve_fastembed_model(name: &str) -> Result<(EmbeddingModel, usize)> { + let lower = name.to_lowercase(); + match lower.as_str() { + s if s.contains("bge-small-en") => Ok((EmbeddingModel::BGESmallENV15, 384)), + s if s.contains("bge-base-en") => Ok((EmbeddingModel::BGEBaseENV15, 768)), + s if s.contains("bge-large-en") => Ok((EmbeddingModel::BGELargeENV15, 1024)), + s if s.contains("bge-m3") => Ok((EmbeddingModel::BGEM3, 1024)), + s if s.contains("all-minilm-l6") || s.contains("minilm-l6") => { + Ok((EmbeddingModel::AllMiniLML6V2, 384)) + } + s if s.contains("all-minilm-l12") || s.contains("minilm-l12") => { + Ok((EmbeddingModel::AllMiniLML12V2, 384)) + } + s if s.contains("nomic-embed-text") => Ok((EmbeddingModel::NomicEmbedTextV15, 768)), + s if s.contains("mxbai-embed-large") => Ok((EmbeddingModel::MxbaiEmbedLargeV1, 1024)), + s if s.contains("gte-large-en") => Ok((EmbeddingModel::GTELargeENV15, 1024)), + s if s.contains("gte-base-en") => Ok((EmbeddingModel::GTEBaseENV15, 768)), + _ => Err(ShrimPKError::Embedding(format!( + "Unknown fastembed model '{name}'. Supported: bge-small-en-v1.5, bge-base-en-v1.5, \ + bge-large-en-v1.5, bge-m3, all-minilm-l6-v2, all-minilm-l12-v2, \ + nomic-embed-text-v1.5, mxbai-embed-large-v1, gte-large-en-v1.5, gte-base-en-v1.5" + ))), + } +} + +// --------------------------------------------------------------------------- +// OpenAIProvider +// --------------------------------------------------------------------------- + +/// OpenAI-compatible embedding API provider. +/// +/// Works with any endpoint that implements the `/v1/embeddings` contract: +/// OpenAI, Ollama, LiteLLM, vLLM, Azure OpenAI, etc. +/// +/// API key is read from `SHRIMPK_EMBEDDING_API_KEY` env var -- never stored in config. +/// +/// # Blocking +/// +/// Uses [`ureq`] (synchronous HTTP). All calls block the current thread for up to 30 s. 
+/// Callers in async contexts **must** invoke this provider through +/// `EchoEngine::embed_blocking()` which uses +/// `tokio::task::block_in_place` to prevent worker-thread starvation. +pub struct OpenAIProvider { + url: String, + model: String, + api_key: Option<String>, + agent: ureq::Agent, + dim: usize, + display_name: String, +} + +impl OpenAIProvider { + /// Create a new OpenAI-compatible embedding provider. + /// + /// The `dim` parameter must match the actual dimension of the remote model. + /// Use `EchoConfig::infer_embedding_dim()` to auto-derive it from the model name. + pub fn new(url: &str, model: &str, dim: usize) -> Result<Self> { + let api_key = std::env::var("SHRIMPK_EMBEDDING_API_KEY").ok(); + let display_name = format!("openai/{model}"); + + let agent = ureq::Agent::new_with_config( + ureq::config::Config::builder() + .timeout_global(Some(std::time::Duration::from_secs(30))) + .build(), + ); + + tracing::info!( + url = %url, + model = %model, + dim = dim, + has_api_key = api_key.is_some(), + "OpenAIProvider initialized" + ); + + Ok(Self { + url: url.trim_end_matches('/').to_string(), + model: model.to_string(), + api_key, + agent, + dim, + display_name, + }) + } + + /// Call the embedding API for a batch of texts. + /// + /// # Blocking + /// + /// This method performs a synchronous HTTP POST via [`ureq`] and will block the + /// calling thread for up to 30 s (the global timeout configured in [`Self::new`]). + /// + /// When running on a **multi-thread** Tokio runtime (the daemon) the blocking + /// HTTP call is wrapped in [`tokio::task::block_in_place`] to inform the + /// scheduler and prevent worker-thread starvation. On a **current-thread** + /// runtime (`#[tokio::test]`) or outside Tokio entirely (sync tests, CLI) the + /// request runs directly, because `block_in_place` panics on a single-threaded + /// runtime. 
+ /// + /// This is defense-in-depth: [`EchoEngine::embed_blocking()`] also wraps the + /// outer call with `block_in_place` for the mutex-lock concern; the inner wrap + /// here covers the provider-specific HTTP concern. + fn call_api(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> { + let endpoint = format!("{}/v1/embeddings", self.url); + + let body = serde_json::json!({ + "model": self.model, + "input": texts, + }); + + // Audit logging (matches HttpConsolidator pattern) + let body_bytes = serde_json::to_vec(&body).unwrap_or_default(); + tracing::info!( + target: "shrimpk::audit", + endpoint = %endpoint, + data_bytes = body_bytes.len(), + batch_size = texts.len(), + direction = "outbound", + component = "embedding_provider", + "External embedding API call" + ); + + // Closure that performs the synchronous HTTP request and parses the response. + // Factored out so we can conditionally wrap it with block_in_place. + let do_request = || -> Result<Vec<Vec<f32>>> { + let mut req = self.agent.post(&endpoint); + if let Some(key) = &self.api_key { + req = req.header("Authorization", &format!("Bearer {key}")); + } + + let mut resp = req.send_json(&body).map_err(|e| { + ShrimPKError::Embedding(format!("OpenAI embedding API error at {endpoint}: {e}")) + })?; + + let json: serde_json::Value = resp.body_mut().read_json().map_err(|e| { + ShrimPKError::Embedding(format!("OpenAI embedding API parse error: {e}")) + })?; + + // Extract embeddings: {"data": [{"embedding": [...], "index": 0}, ...]} + let data = json["data"].as_array().ok_or_else(|| { + ShrimPKError::Embedding(format!( + "OpenAI embedding API: missing 'data' array in response: {}", + truncate_json(&json) + )) + })?; + + // Sort by index to maintain input order + let mut indexed: Vec<(usize, Vec<f32>)> = data + .iter() + .filter_map(|item| { + let index = item["index"].as_u64()? as usize; + let embedding: Vec<f32> = item["embedding"] + .as_array()? 
+ .iter() + .filter_map(|v| v.as_f64().map(|f| f as f32)) + .collect(); + Some((index, embedding)) + }) + .collect(); + + indexed.sort_by_key(|(i, _)| *i); + let embeddings: Vec<Vec<f32>> = indexed.into_iter().map(|(_, e)| e).collect(); + + if embeddings.len() != texts.len() { + return Err(ShrimPKError::Embedding(format!( + "OpenAI embedding API returned {} embeddings for {} inputs", + embeddings.len(), + texts.len() + ))); + } + + // Validate dimension + if let Some(first) = embeddings.first() + && first.len() != self.dim + { + return Err(ShrimPKError::Embedding(format!( + "OpenAI embedding dimension mismatch: expected {}, got {} from model '{}'", + self.dim, + first.len(), + self.model + ))); + } + + Ok(embeddings) + }; + + // Wrap blocking HTTP in block_in_place on multi-thread Tokio runtime + // to prevent worker-thread starvation (ureq has a 30 s timeout). + match tokio::runtime::Handle::try_current() { + Ok(handle) if handle.runtime_flavor() == tokio::runtime::RuntimeFlavor::MultiThread => { + tokio::task::block_in_place(do_request) + } + _ => do_request(), + } + } +} + +impl EmbeddingProvider for OpenAIProvider { + fn embed(&mut self, text: &str) -> Result<Vec<f32>> { + let results = self.call_api(&[text.to_string()])?; + results + .into_iter() + .next() + .ok_or_else(|| ShrimPKError::Embedding("Empty OpenAI embedding result".into())) + } + + fn embed_batch(&mut self, texts: Vec<String>) -> Result<Vec<Vec<f32>>> { + if texts.is_empty() { + return Ok(Vec::new()); + } + self.call_api(&texts) + } + + fn dimension(&self) -> usize { + self.dim + } + + fn name(&self) -> &str { + &self.display_name + } +} + +/// Truncate JSON for error messages. 
+fn truncate_json(v: &serde_json::Value) -> String { + let s = v.to_string(); + if s.len() > 200 { + format!("{}...", &s[..200]) + } else { + s + } +} + +// --------------------------------------------------------------------------- +// Factory +// --------------------------------------------------------------------------- + +/// Create an embedding provider based on config. +/// +/// Reads `config.embedding_provider` to select the backend: +/// - `Fastembed` → `FastembedProvider` with `config.embedding_model` +/// - `OpenAI` → `OpenAIProvider` with `config.embedding_api_url` + `config.embedding_model` +pub fn from_config(config: &EchoConfig) -> Result<Box<dyn EmbeddingProvider>> { + match config.embedding_provider { + EmbeddingBackend::Fastembed => { + let provider = FastembedProvider::new(&config.embedding_model)?; + Ok(Box::new(provider)) + } + EmbeddingBackend::OpenAI => { + let dim = config.infer_embedding_dim(); + let provider = + OpenAIProvider::new(&config.embedding_api_url, &config.embedding_model, dim)?; + Ok(Box::new(provider)) + } + } +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn resolve_fastembed_known_models() { + let cases = [ + ("BGE-small-EN-v1.5", 384), + ("bge-small-en-v1.5", 384), + ("bge-base-en-v1.5", 768), + ("BGE-large-EN-v1.5", 1024), + ("bge-m3", 1024), + ("all-MiniLM-L6-v2", 384), + ("all-minilm-l12-v2", 384), + ("nomic-embed-text-v1.5", 768), + ("mxbai-embed-large-v1", 1024), + ("gte-large-en-v1.5", 1024), + ("gte-base-en-v1.5", 768), + ]; + for (name, expected_dim) in cases { + let (_, dim) = + resolve_fastembed_model(name).unwrap_or_else(|_| panic!("should resolve '{name}'")); + assert_eq!(dim, expected_dim, "model '{name}'"); + } + } + + #[test] + fn resolve_fastembed_unknown_errors() { + let result = resolve_fastembed_model("my-custom-model"); + assert!(result.is_err()); + 
let err = result.unwrap_err().to_string(); + assert!( + err.contains("Unknown fastembed model"), + "error should mention unknown model: {err}" + ); + } + + #[test] + fn openai_provider_initializes() { + // OpenAI provider should initialize (api_key is read from env, may or may not be set) + let provider = OpenAIProvider::new("http://localhost:11434", "nomic-embed-text", 768); + assert!(provider.is_ok()); + let p = provider.unwrap(); + assert_eq!(p.dimension(), 768); + assert_eq!(p.name(), "openai/nomic-embed-text"); + } + + #[test] + fn from_config_default_selects_fastembed() { + // This test just checks that the factory selects the right backend. + // It doesn't actually initialize the model (that would download it). + let config = EchoConfig::default(); + assert_eq!(config.embedding_provider, EmbeddingBackend::Fastembed); + assert_eq!(config.embedding_model, "BGE-small-EN-v1.5"); + } + + #[test] + #[ignore = "requires fastembed model download"] + fn fastembed_provider_default_model_works() { + let mut provider = FastembedProvider::new("BGE-small-EN-v1.5").unwrap(); + assert_eq!(provider.dimension(), 384); + + let embedding = provider.embed("Hello world").unwrap(); + assert_eq!(embedding.len(), 384); + + let batch = provider + .embed_batch(vec!["Hello".into(), "World".into()]) + .unwrap(); + assert_eq!(batch.len(), 2); + assert_eq!(batch[0].len(), 384); + } +} diff --git a/crates/shrimpk-memory/src/lib.rs b/crates/shrimpk-memory/src/lib.rs index 932d915..b55a626 100644 --- a/crates/shrimpk-memory/src/lib.rs +++ b/crates/shrimpk-memory/src/lib.rs @@ -19,6 +19,7 @@ pub mod consolidation; pub mod consolidator; pub mod echo; pub mod embedder; +pub mod embedding_provider; pub mod hebbian; pub mod importance; pub mod labels; diff --git a/tests/echo_precision_tuning.rs b/tests/echo_precision_tuning.rs index e524521..f043fea 100644 --- a/tests/echo_precision_tuning.rs +++ b/tests/echo_precision_tuning.rs @@ -588,7 +588,8 @@ async fn threshold_range_sweep() { println!(); 
// Collect similarity scores for all pairs using raw embeddings - let mut embedder = MultiEmbedder::new().expect("embedder init"); + let mut embedder = + MultiEmbedder::new(&shrimpk_core::EchoConfig::default()).expect("embedder init"); // Build memory embeddings map let mut mem_embeddings: Vec<(&str, Vec<f32>)> = Vec::new(); @@ -730,7 +731,8 @@ async fn threshold_range_sweep() { #[tokio::test] #[ignore = "requires fastembed model download"] async fn query_formulation_analysis() { - let mut embedder = MultiEmbedder::new().expect("embedder init"); + let mut embedder = + MultiEmbedder::new(&shrimpk_core::EchoConfig::default()).expect("embedder init"); let positives = positive_pairs(); @@ -849,7 +851,8 @@ async fn query_formulation_analysis() { #[tokio::test] #[ignore = "requires fastembed model download"] async fn memory_formulation_analysis() { - let mut embedder = MultiEmbedder::new().expect("embedder init"); + let mut embedder = + MultiEmbedder::new(&shrimpk_core::EchoConfig::default()).expect("embedder init"); println!(); println!("===================================================================="); @@ -1198,7 +1201,8 @@ async fn context_window_simulation() { #[tokio::test] #[ignore = "requires fastembed model download"] async fn recommended_configuration() { - let mut embedder = MultiEmbedder::new().expect("embedder init"); + let mut embedder = + MultiEmbedder::new(&shrimpk_core::EchoConfig::default()).expect("embedder init"); let positives = positive_pairs(); let negatives = negative_pairs(); @@ -1533,7 +1537,8 @@ async fn recommended_configuration() { #[tokio::test] #[ignore = "requires fastembed model download"] async fn hardest_pairs_deep_dive() { - let mut embedder = MultiEmbedder::new().expect("embedder init"); + let mut embedder = + MultiEmbedder::new(&shrimpk_core::EchoConfig::default()).expect("embedder init"); println!(); println!("===================================================================="); diff --git a/tests/echo_precision_tuning.rs 
b/tests/echo_token_efficiency.rs index 0ab88c4..b69d394 100644 --- a/tests/echo_token_efficiency.rs +++ b/tests/echo_token_efficiency.rs @@ -919,7 +919,8 @@ async fn vllm_throughput_projection() { #[tokio::test] #[ignore = "requires fastembed model download"] async fn context_quality_comparison() { - let mut embedder = MultiEmbedder::new().expect("embedder should initialize"); + let mut embedder = MultiEmbedder::new(&shrimpk_core::EchoConfig::default()) + .expect("embedder should initialize"); let scenarios = scenarios(); println!("\n{}", "=".repeat(90));