Skip to content

Commit aab9aa7

Browse files
authored
[Refactor] Remove ClassifyCategory and add embedding classifier config (#620)
* [Refactor] Remove ClassifyCategory and add embedding classifier config

  - Removed ClassifyCategory method and all references, keeping only ClassifyCategoryWithEntropy
  - Integrated keywordEmbeddingClassifier call within ClassifyCategoryWithEntropy
  - Removed ClassifyAndSelectBestModel method and its tests
  - Removed findCategoryForClassification helper method
  - Updated all callers to use ClassifyCategoryWithEntropy with 4 return values
  - Added embedding classifier configuration in config/intelligent-routing/in-tree/embedding.yaml
  - Updated tests to use ClassifyCategoryWithEntropy with proper mock data

  This refactoring simplifies the classification API by consolidating to a single entropy-based classification method and adds comprehensive embedding-based classification rules configuration.

  Signed-off-by: bitliu <[email protected]>

* lint

  Signed-off-by: bitliu <[email protected]>

---------

Signed-off-by: bitliu <[email protected]>
1 parent d90deab commit aab9aa7

File tree

6 files changed

+243
-191
lines changed

6 files changed

+243
-191
lines changed
Lines changed: 183 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,183 @@
1+
bert_model:
2+
model_id: models/all-MiniLM-L12-v2
3+
threshold: 0.6
4+
use_cpu: true
5+
6+
semantic_cache:
7+
enabled: true
8+
backend_type: "memory" # Options: "memory", "milvus", or "hybrid"
9+
similarity_threshold: 0.8
10+
max_entries: 1000 # Only applies to memory backend
11+
ttl_seconds: 3600
12+
eviction_policy: "fifo"
13+
# HNSW index configuration (for memory backend only)
14+
use_hnsw: true # Enable HNSW index for faster similarity search
15+
hnsw_m: 16 # Number of bi-directional links (higher = better recall, more memory)
16+
hnsw_ef_construction: 200 # Construction parameter (higher = better quality, slower build)
17+
18+
# Hybrid cache configuration (when backend_type: "hybrid")
19+
# Combines in-memory HNSW for fast search with Milvus for scalable storage
20+
# max_memory_entries: 100000 # Max entries in HNSW index (default: 100,000)
21+
# backend_config_path: "config/milvus.yaml" # Path to Milvus config
22+
23+
# Embedding model for semantic similarity matching
24+
# Options: "bert" (fast, 384-dim), "qwen3" (high quality, 1024-dim, 32K context), "gemma" (balanced, 768-dim, 8K context)
25+
# Default: "bert" (fastest, lowest memory)
26+
embedding_model: "bert"
27+
28+
tools:
29+
enabled: true
30+
top_k: 3
31+
similarity_threshold: 0.2
32+
tools_db_path: "config/tools_db.json"
33+
fallback_to_empty: true
34+
35+
prompt_guard:
36+
enabled: true # Global default - can be overridden per category with jailbreak_enabled
37+
use_modernbert: true
38+
model_id: "models/jailbreak_classifier_modernbert-base_model"
39+
threshold: 0.7
40+
use_cpu: true
41+
jailbreak_mapping_path: "models/jailbreak_classifier_modernbert-base_model/jailbreak_type_mapping.json"
42+
43+
# vLLM Endpoints Configuration
44+
# IMPORTANT: 'address' field must be a valid IP address (IPv4 or IPv6)
45+
# Supported formats: 127.0.0.1, 192.168.1.1, ::1, 2001:db8::1
46+
# NOT supported: domain names (example.com), protocol prefixes (http://), paths (/api), ports in address (use 'port' field)
47+
vllm_endpoints:
48+
- name: "endpoint1"
49+
address: "172.28.0.20" # Static IPv4 of llm-katan within docker compose network
50+
port: 8002
51+
weight: 1
52+
53+
model_config:
54+
"qwen3":
55+
reasoning_family: "qwen3" # This model uses Qwen-3 reasoning syntax
56+
preferred_endpoints: ["endpoint1"]
57+
pii_policy:
58+
allow_by_default: true
59+
60+
# Classifier configuration
61+
classifier:
62+
category_model:
63+
model_id: "models/category_classifier_modernbert-base_model"
64+
use_modernbert: true
65+
threshold: 0.6
66+
use_cpu: true
67+
category_mapping_path: "models/category_classifier_modernbert-base_model/category_mapping.json"
68+
pii_model:
69+
model_id: "models/pii_classifier_modernbert-base_presidio_token_model"
70+
use_modernbert: true
71+
threshold: 0.7
72+
use_cpu: true
73+
pii_mapping_path: "models/pii_classifier_modernbert-base_presidio_token_model/pii_type_mapping.json"
74+
75+
# Embedding-based classification rules
76+
# These rules use semantic similarity between query text and keywords
77+
embedding_rules:
78+
- category: "technical_support"
79+
threshold: 0.75
80+
keywords:
81+
- "how to configure the system"
82+
- "installation guide"
83+
- "troubleshooting steps"
84+
- "error message explanation"
85+
- "setup instructions"
86+
aggregation_method: "max" # Options: "max", "avg", "any"
87+
model: "auto" # Options: "auto", "qwen3", "gemma"
88+
dimension: 768 # Options: 128, 256, 512, 768, 1024
89+
quality_priority: 0.7 # 0.0-1.0, only for "auto" model
90+
latency_priority: 0.3 # 0.0-1.0, only for "auto" model
91+
92+
- category: "product_inquiry"
93+
threshold: 0.70
94+
keywords:
95+
- "product features and specifications"
96+
- "pricing information"
97+
- "availability and stock"
98+
- "product comparison"
99+
- "warranty details"
100+
aggregation_method: "avg"
101+
model: "gemma"
102+
dimension: 768
103+
104+
- category: "account_management"
105+
threshold: 0.72
106+
keywords:
107+
- "password reset"
108+
- "account settings"
109+
- "profile update"
110+
- "subscription management"
111+
- "billing information"
112+
aggregation_method: "max"
113+
model: "qwen3"
114+
dimension: 1024
115+
116+
- category: "general_inquiry"
117+
threshold: 0.65
118+
keywords:
119+
- "general question"
120+
- "information request"
121+
- "help needed"
122+
- "customer service"
123+
aggregation_method: "any"
124+
model: "auto"
125+
dimension: 512
126+
quality_priority: 0.5
127+
latency_priority: 0.5
128+
129+
# Categories with model scores
130+
categories:
131+
# Embedding-based categories
132+
- name: technical_support
133+
system_prompt: "You are a technical support specialist. Provide detailed, step-by-step guidance for technical issues. Use clear explanations and include relevant troubleshooting steps."
134+
model_scores:
135+
- model: qwen3
136+
score: 0.9
137+
use_reasoning: true
138+
jailbreak_enabled: true
139+
pii_detection_enabled: true
140+
141+
- name: product_inquiry
142+
system_prompt: "You are a product specialist. Provide accurate information about products, features, pricing, and availability. Be helpful and informative."
143+
model_scores:
144+
- model: qwen3
145+
score: 0.85
146+
use_reasoning: false
147+
jailbreak_enabled: true
148+
pii_detection_enabled: false
149+
150+
- name: account_management
151+
system_prompt: "You are an account management assistant. Help users with account-related tasks such as password resets, profile updates, and subscription management. Prioritize security and privacy."
152+
model_scores:
153+
- model: qwen3
154+
score: 0.88
155+
use_reasoning: false
156+
jailbreak_enabled: true
157+
pii_detection_enabled: true
158+
159+
- name: general_inquiry
160+
system_prompt: "You are a helpful general assistant. Answer questions clearly and concisely. If you need more information, ask clarifying questions."
161+
model_scores:
162+
- model: qwen3
163+
score: 0.75
164+
use_reasoning: false
165+
jailbreak_enabled: true
166+
pii_detection_enabled: false
167+
168+
# Embedding Models Configuration
169+
# These models provide intelligent embedding generation with automatic routing:
170+
# - Qwen3-Embedding-0.6B: Up to 32K context, high quality, 1024-dim embeddings
171+
# - EmbeddingGemma-300M: Up to 8K context, fast inference, Matryoshka support (768/512/256/128)
172+
embedding_models:
173+
qwen3_model_path: "models/Qwen3-Embedding-0.6B"
174+
gemma_model_path: "models/embeddinggemma-300m"
175+
use_cpu: true # Set to false for GPU acceleration (requires CUDA)
176+
177+
# Default model for fallback
178+
default_model: "qwen3"
179+
180+
# Entropy-based reasoning configuration
181+
entropy_threshold: 0.5 # Threshold for entropy-based reasoning decision
182+
high_entropy_threshold: 0.8 # High entropy threshold for complex queries
183+

src/semantic-router/pkg/classification/classifier.go

Lines changed: 13 additions & 108 deletions
Original file line number | Diff line number | Diff line change
@@ -382,91 +382,6 @@ func (c *Classifier) initializeCategoryClassifier() error {
382382
return c.categoryInitializer.Init(c.Config.CategoryModel.ModelID, c.Config.CategoryModel.UseCPU, numClasses)
383383
}
384384

385-
// ClassifyCategory performs category classification on the given text
386-
func (c *Classifier) ClassifyCategory(text string) (string, float64, error) {
387-
// Try keyword classifier first
388-
if c.keywordClassifier != nil {
389-
category, confidence, err := c.keywordClassifier.Classify(text)
390-
if err != nil {
391-
return "", 0.0, err
392-
}
393-
if category != "" {
394-
return category, confidence, nil
395-
}
396-
}
397-
// TODO: a more sophisticated fusion engine needs to be designed and implemented to combine the classifiers' results
398-
// Try embedding based similarity classification if properly configured
399-
if c.keywordEmbeddingClassifier != nil {
400-
category, confidence, err := c.keywordEmbeddingClassifier.Classify(text)
401-
if err != nil {
402-
return "", 0.0, err
403-
}
404-
if category != "" {
405-
return category, confidence, nil
406-
}
407-
}
408-
// Try in-tree first if properly configured
409-
if c.IsCategoryEnabled() && c.categoryInference != nil {
410-
return c.classifyCategoryInTree(text)
411-
}
412-
413-
// If in-tree classifier was initialized but config is now invalid, return specific error
414-
if c.categoryInference != nil && !c.IsCategoryEnabled() {
415-
return "", 0.0, fmt.Errorf("category classification is not properly configured")
416-
}
417-
418-
// Fall back to MCP
419-
if c.IsMCPCategoryEnabled() && c.mcpCategoryInference != nil {
420-
return c.classifyCategoryMCP(text)
421-
}
422-
423-
return "", 0.0, fmt.Errorf("no category classification method available")
424-
}
425-
426-
// classifyCategoryInTree performs category classification using in-tree model
427-
func (c *Classifier) classifyCategoryInTree(text string) (string, float64, error) {
428-
if !c.IsCategoryEnabled() {
429-
return "", 0.0, fmt.Errorf("category classification is not properly configured")
430-
}
431-
432-
// Use appropriate classifier based on configuration
433-
var result candle_binding.ClassResult
434-
var err error
435-
436-
start := time.Now()
437-
result, err = c.categoryInference.Classify(text)
438-
metrics.RecordClassifierLatency("category", time.Since(start).Seconds())
439-
440-
if err != nil {
441-
return "", 0.0, fmt.Errorf("classification error: %w", err)
442-
}
443-
444-
logging.Infof("Classification result: class=%d, confidence=%.4f", result.Class, result.Confidence)
445-
446-
// Check confidence threshold
447-
if result.Confidence < c.Config.CategoryModel.Threshold {
448-
logging.Infof("Classification confidence (%.4f) below threshold (%.4f)",
449-
result.Confidence, c.Config.CategoryModel.Threshold)
450-
return "", float64(result.Confidence), nil
451-
}
452-
453-
// Convert class index to category name (MMLU-Pro)
454-
categoryName, ok := c.CategoryMapping.GetCategoryFromIndex(result.Class)
455-
if !ok {
456-
logging.Warnf("Class index %d not found in category mapping", result.Class)
457-
return "", float64(result.Confidence), nil
458-
}
459-
460-
// Translate to generic category if mapping is configured
461-
genericCategory := c.translateMMLUToGeneric(categoryName)
462-
463-
// Record the category classification metric using generic name when available
464-
metrics.RecordCategoryClassification(genericCategory)
465-
466-
logging.Infof("Classified as category: %s (mmlu=%s)", genericCategory, categoryName)
467-
return genericCategory, float64(result.Confidence), nil
468-
}
469-
470385
// IsJailbreakEnabled checks if jailbreak detection is enabled and properly configured
471386
func (c *Classifier) IsJailbreakEnabled() bool {
472387
return c.Config.PromptGuard.Enabled && c.Config.PromptGuard.ModelID != "" && c.Config.PromptGuard.JailbreakMappingPath != "" && c.JailbreakMapping != nil
@@ -611,6 +526,19 @@ func (c *Classifier) ClassifyCategoryWithEntropy(text string) (string, float64,
611526
}
612527
}
613528

529+
// Try embedding based similarity classification if properly configured
530+
if c.keywordEmbeddingClassifier != nil {
531+
category, confidence, err := c.keywordEmbeddingClassifier.Classify(text)
532+
if err != nil {
533+
return "", 0.0, entropy.ReasoningDecision{}, err
534+
}
535+
if category != "" {
536+
// Keyword embedding matched - determine reasoning mode from category configuration
537+
reasoningDecision := c.makeReasoningDecisionForKeywordCategory(category)
538+
return category, confidence, reasoningDecision, nil
539+
}
540+
}
541+
614542
// Try in-tree first if properly configured
615543
if c.IsCategoryEnabled() && c.categoryInference != nil {
616544
return c.classifyCategoryWithEntropyInTree(text)
@@ -926,29 +854,6 @@ func (c *Classifier) AnalyzeContentForPIIWithThreshold(contentList []string, thr
926854
return hasPII, analysisResults, nil
927855
}
928856

929-
// ClassifyAndSelectBestModel performs classification and selects the best model for the query
930-
func (c *Classifier) ClassifyAndSelectBestModel(query string) string {
931-
// If no categories defined, return default model
932-
if len(c.Config.Categories) == 0 {
933-
return c.Config.DefaultModel
934-
}
935-
936-
// First, classify the text to determine the category
937-
categoryName, confidence, err := c.ClassifyCategory(query)
938-
if err != nil {
939-
logging.Errorf("Classification error: %v, falling back to default model", err)
940-
return c.Config.DefaultModel
941-
}
942-
943-
if categoryName == "" {
944-
logging.Infof("Classification confidence (%.4f) below threshold, using default model", confidence)
945-
return c.Config.DefaultModel
946-
}
947-
948-
// Then select the best model from the determined category based on score and TTFT
949-
return c.SelectBestModelForCategory(categoryName)
950-
}
951-
952857
// SelectBestModelForCategory selects the best model from a category based on score and TTFT
953858
func (c *Classifier) SelectBestModelForCategory(categoryName string) string {
954859
cat := c.findCategory(categoryName)

0 commit comments

Comments
 (0)