From b33be2d03c33ffeb437521f88c326b53cbcc407f Mon Sep 17 00:00:00 2001 From: Daniel Chalef <131175+danielchalef@users.noreply.github.com> Date: Tue, 12 Aug 2025 16:56:16 -0700 Subject: [PATCH 1/4] Add Qwen3 reranker support for sequence classification models MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Added Qwen3ClassificationHead with flexible tensor loading that handles: - score.weight at top level (for converted Qwen3 rerankers) - classifier.weight/bias patterns for standard models - Updated Qwen3Model and FlashQwen3Model to support classification - Added predict() method implementations for both model variants - Extended Qwen3Config with id2label field for classification - Added test case for Qwen3 reranker models with snapshot The implementation supports Qwen3 models converted to sequence classifiers for reranking tasks (e.g., tomaarsen/Qwen3-Reranker-0.6B-seq-cls). The classification head gracefully handles different tensor naming conventions from various conversion approaches. Tested with both embedding and reranking Qwen3 models. 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude --- backends/candle/src/models/flash_qwen3.rs | 24 +++++-- backends/candle/src/models/mod.rs | 2 +- backends/candle/src/models/qwen3.rs | 71 ++++++++++++++++++- .../test_qwen3__qwen3_reranker_single.snap | 7 ++ backends/candle/tests/test_qwen3.rs | 41 ++++++++++- 5 files changed, 135 insertions(+), 10 deletions(-) create mode 100644 backends/candle/tests/snapshots/test_qwen3__qwen3_reranker_single.snap diff --git a/backends/candle/src/models/flash_qwen3.rs b/backends/candle/src/models/flash_qwen3.rs index 10f27bdd..48c3a0c2 100644 --- a/backends/candle/src/models/flash_qwen3.rs +++ b/backends/candle/src/models/flash_qwen3.rs @@ -1,6 +1,6 @@ use crate::flash_attn::flash_attn_varlen; use crate::layers::{get_cos_sin, get_inv_freqs, HiddenAct, Linear, RMSNorm}; -use crate::models::{Model, Qwen3Config}; +use crate::models::{Model, Qwen3Config, Qwen3ClassificationHead}; use candle::{DType, Device, IndexOp, Result, Tensor}; use candle_nn::{Embedding, Module, VarBuilder}; use candle_rotary::apply_rotary_inplace; @@ -288,6 +288,7 @@ pub struct FlashQwen3Model { cos_cache: Tensor, sin_cache: Tensor, pool: Pool, + classification_head: Option, pub device: Device, span: tracing::Span, @@ -304,11 +305,13 @@ impl FlashQwen3Model { candle::bail!("FlashQwen3 requires DType::F16") } - let pool = match model_type { + let (pool, classification_head) = match model_type { ModelType::Classifier => { - candle::bail!("`classifier` model type is not supported for Qwen3") + // Load classification head before the vb is modified + let classification_head = Some(Qwen3ClassificationHead::load(vb.clone(), config)?); + (Pool::Cls, classification_head) // Use CLS pooling for classification } - ModelType::Embedding(pool) => pool, + ModelType::Embedding(pool) => (pool, None), }; // The Qwen3-Reranker models contain the `model` key @@ -351,6 +354,7 @@ impl FlashQwen3Model { cos_cache, sin_cache, pool, + classification_head, device: vb.device().clone(), span: tracing::span!(tracing::Level::TRACE, "model"), }) @@ -512,4 +516,16 @@ impl Model for FlashQwen3Model { fn embed(&self, batch: Batch) -> Result<(Option, Option)> { self.forward(batch) } + + fn predict(&self, batch: Batch) -> Result { + match &self.classification_head { + None => candle::bail!("`predict` is not implemented for this model"), + Some(classification_head) => { + let (pooled_embeddings, _raw_embeddings) = self.forward(batch)?; + let pooled_embeddings = + pooled_embeddings.expect("pooled_embeddings is empty. This is a bug."); + classification_head.forward(&pooled_embeddings) + } + } + } } diff --git a/backends/candle/src/models/mod.rs b/backends/candle/src/models/mod.rs index 65fb8744..05bcacb3 100644 --- a/backends/candle/src/models/mod.rs +++ b/backends/candle/src/models/mod.rs @@ -61,7 +61,7 @@ pub use modernbert::{ModernBertConfig, ModernBertModel}; pub use mpnet::{MPNetConfig, MPNetModel}; pub use nomic::{NomicBertModel, NomicConfig}; pub use qwen2::Qwen2Config; -pub use qwen3::{Qwen3Config, Qwen3Model}; +pub use qwen3::{Qwen3Config, Qwen3Model, Qwen3ClassificationHead}; use text_embeddings_backend_core::Batch; #[cfg(feature = "cuda")] diff --git a/backends/candle/src/models/qwen3.rs b/backends/candle/src/models/qwen3.rs index 13309927..3c979fdc 100644 --- a/backends/candle/src/models/qwen3.rs +++ b/backends/candle/src/models/qwen3.rs @@ -5,6 +5,7 @@ use crate::models::Model; use candle::{DType, Device, IndexOp, Result, Tensor, D}; use candle_nn::{Embedding, Module, VarBuilder}; use serde::Deserialize; +use std::collections::HashMap; use text_embeddings_backend_core::{Batch, ModelType, Pool}; #[derive(Debug, Clone, PartialEq, Deserialize)] @@ -24,6 +25,7 @@ pub struct Qwen3Config { pub sliding_window: Option, pub use_sliding_window: bool, pub eos_token_id: usize, + pub id2label: Option>, } struct Qwen3Attention { @@ -375,6 +377,54 @@ impl Qwen3Layer { } } +pub struct Qwen3ClassificationHead { + classifier: Linear, + span: tracing::Span, +} + +impl Qwen3ClassificationHead { + pub fn load(vb: VarBuilder, config: &Qwen3Config) -> Result { + let n_classes = match &config.id2label { + None => candle::bail!("`id2label` must be set for classifier models"), + Some(id2label) => id2label.len(), + }; + + // Try different common classification head layer names + // The tomaarsen/Qwen3-Reranker models have score.weight at the top level with no bias + let classifier = if let Ok(weight) = vb.get((n_classes, config.hidden_size), "score.weight") { + // No bias for score layer in converted Qwen3 rerankers + Linear::new(weight, None, None) + } else if let (Ok(weight), Ok(bias)) = ( + vb.pp("classifier").get((n_classes, config.hidden_size), "weight"), + vb.pp("classifier").get(n_classes, "bias") + ) { + Linear::new(weight, Some(bias), None) + } else if let (Ok(weight), Ok(bias)) = ( + vb.pp("score").get((n_classes, config.hidden_size), "weight"), + vb.pp("score").get(n_classes, "bias") + ) { + Linear::new(weight, Some(bias), None) + } else if let (Ok(weight), Ok(bias)) = ( + vb.get((n_classes, config.hidden_size), "classifier.weight"), + vb.get(n_classes, "classifier.bias") + ) { + Linear::new(weight, Some(bias), None) + } else { + candle::bail!("Could not find classification head weights. Tried: score.weight, classifier.weight"); + }; + + Ok(Self { + classifier, + span: tracing::span!(tracing::Level::TRACE, "classification_head"), + }) + } + + pub fn forward(&self, hidden_states: &Tensor) -> Result { + let _enter = self.span.enter(); + self.classifier.forward(hidden_states) + } +} + pub struct Qwen3Model { embeddings: Embedding, layers: Vec, @@ -382,6 +432,7 @@ pub struct Qwen3Model { rotary_cache: (Tensor, Tensor), rotary_dim: usize, pool: Pool, + classification_head: Option, num_attention_heads: usize, pad_token_id: u32, @@ -393,11 +444,12 @@ pub struct Qwen3Model { impl Qwen3Model { pub fn load(vb: VarBuilder, config: &Qwen3Config, model_type: ModelType) -> Result { - let pool = match model_type { + let (pool, classification_head) = match model_type { ModelType::Classifier => { - candle::bail!("`classifier` model type is not supported for Qwen3") + let classification_head = Some(Qwen3ClassificationHead::load(vb.clone(), config)?); + (Pool::Cls, classification_head) // Use CLS pooling for classification } - ModelType::Embedding(pool) => pool, + ModelType::Embedding(pool) => (pool, None), }; // The Qwen3-Reranker models contain the `model` key @@ -436,6 +488,7 @@ impl Qwen3Model { rotary_cache, rotary_dim, pool, + classification_head, pad_token_id: config.eos_token_id as u32, num_attention_heads: config.num_attention_heads, dtype: vb.dtype(), @@ -700,4 +753,16 @@ impl Model for Qwen3Model { fn embed(&self, batch: Batch) -> Result<(Option, Option)> { self.forward(batch) } + + fn predict(&self, batch: Batch) -> Result { + match &self.classification_head { + None => candle::bail!("`predict` is not implemented for this model"), + Some(classification_head) => { + let (pooled_embeddings, _raw_embeddings) = self.forward(batch)?; + let pooled_embeddings = + pooled_embeddings.expect("pooled_embeddings is empty. This is a bug."); + classification_head.forward(&pooled_embeddings) + } + } + } } diff --git a/backends/candle/tests/snapshots/test_qwen3__qwen3_reranker_single.snap b/backends/candle/tests/snapshots/test_qwen3__qwen3_reranker_single.snap new file mode 100644 index 00000000..dca4c1aa --- /dev/null +++ b/backends/candle/tests/snapshots/test_qwen3__qwen3_reranker_single.snap @@ -0,0 +1,7 @@ +--- +source: backends/candle/tests/test_qwen3.rs +assertion_line: 86 +expression: predictions_single +--- +- - 2.0719934 + diff --git a/backends/candle/tests/test_qwen3.rs b/backends/candle/tests/test_qwen3.rs index 8f6a980a..021b69f9 100644 --- a/backends/candle/tests/test_qwen3.rs +++ b/backends/candle/tests/test_qwen3.rs @@ -1,8 +1,8 @@ mod common; -use crate::common::{sort_embeddings, SnapshotEmbeddings}; +use crate::common::{sort_embeddings, SnapshotEmbeddings, SnapshotScores}; use anyhow::Result; -use common::{batch, cosine_matcher, download_artifacts, load_tokenizer}; +use common::{batch, cosine_matcher, download_artifacts, load_tokenizer, relative_matcher}; use text_embeddings_backend_candle::CandleBackend; use text_embeddings_backend_core::{Backend, ModelType, Pool}; @@ -50,3 +50,40 @@ fn test_qwen3() -> Result<()> { Ok(()) } + +#[test] +#[serial_test::serial] +fn test_qwen3_reranker() -> Result<()> { + let model_root = download_artifacts("tomaarsen/Qwen3-Reranker-0.6B-seq-cls", None, None)?; + let tokenizer = load_tokenizer(&model_root)?; + + let backend = CandleBackend::new( + &model_root, + "float32".to_string(), + ModelType::Classifier, + None, + )?; + + let input_single = batch( + vec![tokenizer + .encode( + "What is Deep Learning?", + true, + ) + .unwrap()], + [0].to_vec(), + vec![], + ); + + let predictions: Vec> = backend + .predict(input_single)? + .into_iter() + .map(|(_, v)| v) + .collect(); + let predictions_single = SnapshotScores::from(predictions); + + let matcher = relative_matcher(); + insta::assert_yaml_snapshot!("qwen3_reranker_single", predictions_single, &matcher); + + Ok(()) +} From 1c27e67a7cd8805c7965ca36fb489af91b077cdb Mon Sep 17 00:00:00 2001 From: Daniel Chalef <131175+danielchalef@users.noreply.github.com> Date: Tue, 12 Aug 2025 18:17:02 -0700 Subject: [PATCH 2/4] add support for qwen3 reranker sequence classifier Signed-off-by: Daniel Chalef <131175+danielchalef@users.noreply.github.com> --- core/src/templates.rs | 110 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 110 insertions(+) create mode 100644 core/src/templates.rs diff --git a/core/src/templates.rs b/core/src/templates.rs new file mode 100644 index 00000000..ad2737aa --- /dev/null +++ b/core/src/templates.rs @@ -0,0 +1,110 @@ +use std::fmt::Write; + +/// Template formatter for models that require structured prompts +pub trait TemplateFormatter { + /// Format a query-document pair for reranking + fn format_rerank( + &self, + query: &str, + document: &str, + instruction: Option<&str>, + ) -> String; +} + +/// Qwen3 reranker template formatter +pub struct Qwen3RerankerTemplate { + default_instruction: String, +} + +impl Qwen3RerankerTemplate { + pub fn new() -> Self { + Self { + default_instruction: "Select only the Documents that are semantically similar to the Query.".to_string(), + } + } +} + +impl TemplateFormatter for Qwen3RerankerTemplate { + fn format_rerank( + &self, + query: &str, + document: &str, + instruction: Option<&str>, + ) -> String { + let instruction = instruction.unwrap_or(&self.default_instruction); + + let mut result = String::with_capacity(512); + + // System prompt + result.push_str("<|im_start|>system\n"); + result.push_str("Judge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be \"yes\" or \"no\".<|im_end|>\n"); + + // User prompt with instruction, query, and document + result.push_str("<|im_start|>user\n"); + write!(&mut result, ": {}\n", instruction).unwrap(); + write!(&mut result, ": {}\n", query).unwrap(); + write!(&mut result, ": {}", document).unwrap(); + result.push_str("<|im_end|>\n"); + + // Assistant prompt to trigger reasoning + result.push_str("<|im_start|>assistant\n"); + result.push_str("\n\n\n\n"); + + result + } +} + +/// Check if a model requires template formatting +pub fn requires_template(model_name: &str) -> bool { + // Check if this is a Qwen3 sequence classification model + model_name.contains("Qwen3") && model_name.contains("seq-cls") +} + +/// Get the appropriate template formatter for a model +pub fn get_template_formatter(model_name: &str) -> Option> { + if requires_template(model_name) { + Some(Box::new(Qwen3RerankerTemplate::new())) + } else { + None + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_qwen3_template() { + let template = Qwen3RerankerTemplate::new(); + let formatted = template.format_rerank( + "What is Deep Learning?", + "Deep Learning is a branch of machine learning", + None, + ); + + assert!(formatted.contains("<|im_start|>system")); + assert!(formatted.contains(": What is Deep Learning?")); + assert!(formatted.contains(": Deep Learning is a branch of machine learning")); + assert!(formatted.contains("")); + } + + #[test] + fn test_custom_instruction() { + let template = Qwen3RerankerTemplate::new(); + let formatted = template.format_rerank( + "test query", + "test doc", + Some("Custom instruction"), + ); + + assert!(formatted.contains(": Custom instruction")); + } + + #[test] + fn test_requires_template() { + assert!(requires_template("tomaarsen/Qwen3-Reranker-0.6B-seq-cls")); + assert!(requires_template("Qwen3-Something-seq-cls")); + assert!(!requires_template("BAAI/bge-reranker")); + assert!(!requires_template("Qwen3-Embed")); + } +} \ No newline at end of file From b3cf408caafecf22ccbcb641530ef60733c6a1a4 Mon Sep 17 00:00:00 2001 From: Daniel Chalef <131175+danielchalef@users.noreply.github.com> Date: Wed, 13 Aug 2025 14:22:40 -0700 Subject: [PATCH 3/4] Enhance Qwen3 reranker functionality with template support - Added optional fields `instruction` and `use_template` to `RerankRequest` for custom instructions and template usage. - Updated rerank logic to apply templates conditionally based on model type and user input. - Introduced template formatting for Qwen3 models to improve reranking results. This update allows for more flexible and context-aware reranking capabilities, particularly for Qwen3 models. --- backends/candle/src/models/mod.rs | 2 +- core/src/lib.rs | 1 + proto/tei.proto | 4 ++++ router/src/http/server.rs | 28 ++++++++++++++++++++++++++-- router/src/http/types.rs | 10 ++++++++++ 5 files changed, 42 insertions(+), 3 deletions(-) diff --git a/backends/candle/src/models/mod.rs b/backends/candle/src/models/mod.rs index 05bcacb3..65fb8744 100644 --- a/backends/candle/src/models/mod.rs +++ b/backends/candle/src/models/mod.rs @@ -61,7 +61,7 @@ pub use modernbert::{ModernBertConfig, ModernBertModel}; pub use mpnet::{MPNetConfig, MPNetModel}; pub use nomic::{NomicBertModel, NomicConfig}; pub use qwen2::Qwen2Config; -pub use qwen3::{Qwen3Config, Qwen3Model, Qwen3ClassificationHead}; +pub use qwen3::{Qwen3Config, Qwen3Model}; use text_embeddings_backend_core::Batch; #[cfg(feature = "cuda")] diff --git a/core/src/lib.rs b/core/src/lib.rs index 5ce449e2..e09aded8 100644 --- a/core/src/lib.rs +++ b/core/src/lib.rs @@ -1,6 +1,7 @@ pub mod download; pub mod infer; pub mod queue; +pub mod templates; pub mod tokenization; use text_embeddings_backend::BackendError; diff --git a/proto/tei.proto b/proto/tei.proto index ea96457c..4ea37ab6 100644 --- a/proto/tei.proto +++ b/proto/tei.proto @@ -152,6 +152,8 @@ message RerankRequest { bool raw_scores = 4; bool return_text = 5; TruncationDirection truncation_direction = 6; + optional string instruction = 7; + optional bool use_template = 8; } message RerankStreamRequest{ @@ -163,6 +165,8 @@ message RerankStreamRequest{ // The server will only consider the first value bool return_text = 5; TruncationDirection truncation_direction = 6; + optional string instruction = 7; + optional bool use_template = 8; } message Rank { diff --git a/router/src/http/server.rs b/router/src/http/server.rs index a22af962..e92e38a1 100644 --- a/router/src/http/server.rs +++ b/router/src/http/server.rs @@ -345,13 +345,35 @@ async fn rerank( ErrorResponse::from(err) })?; + // Apply template if needed for Qwen3 rerankers + let model_id = info.model_id.clone(); + let use_template = req.use_template.unwrap_or_else(|| { + // Default to true for Qwen3 sequence classification models + text_embeddings_core::templates::requires_template(&model_id) + }); + // Closure for rerank - let rerank_inner = move |query: String, text: String, truncate: bool, infer: Infer| async move { + let rerank_inner = move |query: String, text: String, truncate: bool, instruction: Option, model_id: String, infer: Infer| async move { let permit = infer.acquire_permit().await; + // Apply template formatting if needed + let input: text_embeddings_core::tokenization::EncodingInput = if use_template { + if let Some(formatter) = text_embeddings_core::templates::get_template_formatter(&model_id) { + // Format as single string with template + let formatted = formatter.format_rerank(&query, &text, instruction.as_deref()); + formatted.into() + } else { + // No template, use dual input + (query, text).into() + } + } else { + // Template disabled, use dual input + (query, text).into() + }; + let response = infer .predict( - (query, text), + input, truncate, req.truncation_direction.into(), req.raw_scores, @@ -404,6 +426,8 @@ async fn rerank( req.query.clone(), text.clone(), truncate, + req.instruction.clone(), + model_id.clone(), local_infer.0, )) } diff --git a/router/src/http/types.rs b/router/src/http/types.rs index dedaab60..abde2fce 100644 --- a/router/src/http/types.rs +++ b/router/src/http/types.rs @@ -255,6 +255,16 @@ pub(crate) struct RerankRequest { #[serde(default)] #[schema(default = "false", example = "false")] pub return_text: bool, + /// Custom instruction for reranking (e.g., "Select only semantically similar documents") + /// Used with models that support templated prompts like Qwen3 rerankers + #[serde(default)] + #[schema(default = "null", example = "Select only the Documents that are semantically similar to the Query.", nullable = true)] + pub instruction: Option, + /// Whether to use the model's chat template for formatting. + /// Defaults to true for models that support it (e.g., Qwen3 rerankers) + #[serde(default)] + #[schema(default = "null", example = "true", nullable = true)] + pub use_template: Option, } #[derive(Serialize, ToSchema)] From 6d4c27220bf99627f267f7ea9e94b7cb6d61b14c Mon Sep 17 00:00:00 2001 From: Daniel Chalef <131175+danielchalef@users.noreply.github.com> Date: Wed, 13 Aug 2025 14:41:20 -0700 Subject: [PATCH 4/4] Changed pooling method from `Pool::Cls` to `Pool::LastToken` for the classification head in the Qwen3 model. --- backends/candle/src/models/qwen3.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backends/candle/src/models/qwen3.rs b/backends/candle/src/models/qwen3.rs index 3c979fdc..44003399 100644 --- a/backends/candle/src/models/qwen3.rs +++ b/backends/candle/src/models/qwen3.rs @@ -447,7 +447,7 @@ impl Qwen3Model { let (pool, classification_head) = match model_type { ModelType::Classifier => { let classification_head = Some(Qwen3ClassificationHead::load(vb.clone(), config)?); - (Pool::Cls, classification_head) // Use CLS pooling for classification + (Pool::LastToken, classification_head) } ModelType::Embedding(pool) => (pool, None), };