
Commit 1fd8245

Merge pull request #63 from multimindlab/ensemble_cli_error_solved

Add full Ollama provider integration for text generation, code review…

2 parents 233f6c1 + 95cf6d0

8 files changed: 1468 additions, 110 deletions

examples/cli/ensemble_cli.py

Lines changed: 478 additions & 69 deletions
Large diffs are not rendered by default.

multimind/core/router.py

Lines changed: 45 additions & 7 deletions
@@ -115,19 +115,28 @@ async def route(
         **kwargs
     ) -> Union[GenerationResult, EmbeddingResult, ImageAnalysisResult]:
         """Route a request to the appropriate provider(s)."""
-        if task_type not in self.task_configs:
+        provider_override = kwargs.get("provider")
+        if not provider_override and task_type not in self.task_configs:
             raise ValueError(f"No configuration found for task type: {task_type}")
 
-        config = self.task_configs[task_type]
         start_time = time.time()
 
         try:
-            if config.routing_strategy == RoutingStrategy.ENSEMBLE:
-                result = await self._handle_ensemble(task_type, input_data, config, **kwargs)
-            elif config.routing_strategy == RoutingStrategy.CASCADE:
-                result = await self._handle_cascade(task_type, input_data, config, **kwargs)
+            if provider_override:
+                result = await self._route_specific_provider(
+                    provider_override,
+                    task_type,
+                    input_data,
+                    **kwargs
+                )
             else:
-                result = await self._handle_single_provider(task_type, input_data, config, **kwargs)
+                config = self.task_configs[task_type]
+                if config.routing_strategy == RoutingStrategy.ENSEMBLE:
+                    result = await self._handle_ensemble(task_type, input_data, config, **kwargs)
+                elif config.routing_strategy == RoutingStrategy.CASCADE:
+                    result = await self._handle_cascade(task_type, input_data, config, **kwargs)
+                else:
+                    result = await self._handle_single_provider(task_type, input_data, config, **kwargs)
 
             # Record successful request metrics
             latency_ms = (time.time() - start_time) * 1000

@@ -179,6 +188,35 @@ async def route(
             )
             raise
 
+    async def _route_specific_provider(
+        self,
+        provider_name: str,
+        task_type: TaskType,
+        input_data: Any,
+        **kwargs
+    ) -> Union[GenerationResult, EmbeddingResult, ImageAnalysisResult]:
+        """
+        Route directly to a specific provider when explicitly requested.
+        This bypasses task configuration while still leveraging the same execution pipeline.
+        """
+        if provider_name not in self.providers:
+            raise ValueError(f"Provider '{provider_name}' is not registered with the router")
+
+        single_provider_config = TaskConfig(
+            preferred_providers=[provider_name],
+            fallback_providers=[],
+            routing_strategy=RoutingStrategy.COST_BASED
+        )
+        call_kwargs = dict(kwargs)
+        call_kwargs.pop("provider", None)
+        return await self._handle_single_provider(
+            task_type,
+            input_data,
+            single_provider_config,
+            use_adaptive_routing=False,
+            **call_kwargs
+        )
+
     async def _handle_single_provider(
         self,
         task_type: TaskType,
multimind/ensemble/advanced.py

Lines changed: 100 additions & 28 deletions
@@ -127,23 +127,24 @@ async def _weighted_voting(
     ) -> EnsembleResult:
         """Combine results using weighted voting (adaptive if enabled)."""
         if use_adaptive_weights or not weights:
-            providers = [result.provider for result in results]
+            providers = [self._get_provider_name(result) for result in results]
             weights = self.performance_tracker.get_all_weights(providers)
         # Normalize weights
         total_weight = sum(weights.values())
         normalized_weights = {k: v/total_weight for k, v in weights.items()}
         # Calculate weighted scores for each result
         weighted_scores = []
         for result in results:
-            weight = normalized_weights.get(result.provider, 0.0)
+            provider_name = self._get_provider_name(result)
+            weight = normalized_weights.get(provider_name, 0.0)
             weighted_scores.append((result, weight))
         # Select result with highest weight
         best_result, best_weight = max(weighted_scores, key=lambda x: x[1])
         return EnsembleResult(
             result=best_result,
             confidence=ConfidenceScore(
                 score=best_weight,
-                explanation=f"Selected result from {best_result.provider} with adaptive weight {best_weight:.2f}"
+                explanation=f"Selected result from {self._get_provider_name(best_result)} with adaptive weight {best_weight:.2f}"
             ),
             provider_votes=normalized_weights
         )
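Behind the rename, the voting logic is unchanged: normalize whatever weights the performance tracker returns, then take the argmax. A standalone sketch with made-up weights:

    # Illustrative only: normalize tracker weights, then pick the provider
    # whose result carries the largest normalized weight.
    weights = {"openai": 0.9, "claude": 0.6, "ollama": 0.3}   # hypothetical tracker output
    total = sum(weights.values())
    normalized = {name: w / total for name, w in weights.items()}
    best_provider = max(normalized, key=normalized.get)
    print(best_provider, round(normalized[best_provider], 2))  # openai 0.5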
@@ -171,15 +172,15 @@ async def _confidence_cascade(
                 return EnsembleResult(
                     result=result,
                     confidence=confidence,
-                    provider_votes={r.provider: c.score for r, c in confidence_scores}
+                    provider_votes={self._get_provider_name(r): c.score for r, c in confidence_scores}
                 )
 
         # If no result meets threshold, return highest confidence
         best_result, best_confidence = confidence_scores[0]
         return EnsembleResult(
             result=best_result,
             confidence=best_confidence,
-            provider_votes={r.provider: c.score for r, c in confidence_scores}
+            provider_votes={self._get_provider_name(r): c.score for r, c in confidence_scores}
         )
 
     async def _parallel_voting(
async def _parallel_voting(
@@ -203,7 +204,7 @@ async def _parallel_voting(
203204

204205
# Normalize scores
205206
total_score = sum(score for _, score in scores)
206-
normalized_scores = {r.provider: s/total_score for r, s in scores}
207+
normalized_scores = {self._get_provider_name(r): s/total_score for r, s in scores}
207208

208209
# Select best result
209210
best_result, best_score = max(scores, key=lambda x: x[1])
@@ -212,7 +213,7 @@ async def _parallel_voting(
             result=best_result,
             confidence=ConfidenceScore(
                 score=best_score,
-                explanation=f"Selected result from {best_result.provider} with LLM evaluation score {best_score:.2f}"
+                explanation=f"Selected result from {self._get_provider_name(best_result)} with LLM evaluation score {best_score:.2f}"
             ),
             provider_votes=normalized_scores
         )
@@ -234,7 +235,7 @@ async def _majority_voting(
             # Fallback to string equality if no embedder available
             embedder = None
 
-        texts = [str(r.result) for r in results]
+        texts = [self._extract_result_content(r) for r in results]
         if embedder is not None:
             embeddings = embedder.encode(texts, convert_to_tensor=True)
             import torch
@@ -264,19 +265,19 @@ def get_score(r):
             vote_count = len(largest_group)
             total_votes = len(results)
             explanation = f"Selected result by semantic majority voting: {vote_count}/{total_votes} semantically similar."
-            provider_votes = {r.provider: 1.0 if idx in largest_group else 0.0 for idx, r in enumerate(results)}
+            provider_votes = {self._get_provider_name(r): 1.0 if idx in largest_group else 0.0 for idx, r in enumerate(results)}
         else:
             # Fallback: string equality
             result_counts = {}
             for result in results:
-                key = str(result.result)
+                key = self._extract_result_content(result)
                 if key not in result_counts:
                     result_counts[key] = (result, 0)
                 result_counts[key] = (result, result_counts[key][1] + 1)
             best_result, vote_count = max(result_counts.values(), key=lambda x: x[1])
             total_votes = len(results)
             explanation = f"Selected result with {vote_count}/{total_votes} votes (string equality fallback)"
-            provider_votes = {r.provider: 1.0 for r in results}
+            provider_votes = {self._get_provider_name(r): 1.0 for r in results}
         return EnsembleResult(
             result=best_result,
             confidence=ConfidenceScore(
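When no embedder is available, the fallback above is a plain frequency count over the extracted texts; a minimal sketch with invented outputs:

    # Illustrative only: count identical response texts and keep the most common one.
    texts = ["Paris", "Paris", "Lyon"]            # hypothetical provider outputs
    counts = {}
    for text in texts:
        counts[text] = counts.get(text, 0) + 1
    best_text, votes = max(counts.items(), key=lambda kv: kv[1])
    print(f"{best_text}: {votes}/{len(texts)} votes")   # Paris: 2/3 votes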
@@ -303,15 +304,15 @@ async def _rank_based(
         borda_scores = {}
         for result, ranking in zip(results, rankings):
             score = self._calculate_borda_score(ranking, len(results))
-            borda_scores[result.provider] = score
+            borda_scores[self._get_provider_name(result)] = score
 
         # Normalize scores
         total_score = sum(borda_scores.values())
         normalized_scores = {k: v/total_score for k, v in borda_scores.items()}
 
         # Select result with highest Borda score
         best_provider = max(borda_scores.items(), key=lambda x: x[1])[0]
-        best_result = next(r for r in results if r.provider == best_provider)
+        best_result = next(r for r in results if self._get_provider_name(r) == best_provider)
 
         return EnsembleResult(
             result=best_result,
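_calculate_borda_score is not shown in this diff; for orientation, a classic Borda count gives n−1 points to a first-place vote, n−2 to second place, and so on. A self-contained illustration (rankings and scoring rule are illustrative, not necessarily the project's exact implementation):

    # Illustrative Borda count over hypothetical provider rankings.
    rankings = [
        ["openai", "claude", "ollama"],   # ranking reported by voter 1
        ["claude", "openai", "ollama"],   # ranking reported by voter 2
    ]
    n = 3
    borda = {}
    for ranking in rankings:
        for place, provider in enumerate(ranking):
            borda[provider] = borda.get(provider, 0) + (n - 1 - place)
    print(borda)   # {'openai': 3, 'claude': 3, 'ollama': 0}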
@@ -322,16 +323,40 @@ async def _rank_based(
             provider_votes=normalized_scores
         )
 
+    def _extract_result_content(self, result: Union[GenerationResult, EmbeddingResult, ImageAnalysisResult]) -> str:
+        """Extract text content from any result type for evaluation."""
+        if isinstance(result, GenerationResult):
+            return result.text
+        elif isinstance(result, EmbeddingResult):
+            return f"Embedding vector of length {len(result.embedding)}"
+        elif isinstance(result, ImageAnalysisResult):
+            # Combine text, captions, and objects for evaluation
+            parts = []
+            if result.text:
+                parts.append(f"Text: {result.text}")
+            if result.captions:
+                parts.append(f"Captions: {', '.join(result.captions)}")
+            if result.objects:
+                parts.append(f"Objects: {len(result.objects)} detected")
+            return " | ".join(parts) if parts else "No content extracted"
+        else:
+            return str(result)
+
+    def _get_provider_name(self, result: Union[GenerationResult, EmbeddingResult, ImageAnalysisResult]) -> str:
+        """Extract provider name from any result type."""
+        return getattr(result, 'provider', None) or getattr(result, 'provider_name', 'unknown')
+
     async def _evaluate_confidence(
         self,
         result: Union[GenerationResult, EmbeddingResult, ImageAnalysisResult],
         task_type: TaskType,
         **kwargs
     ) -> ConfidenceScore:
         """Evaluate confidence in a result using LLM."""
+        content = self._extract_result_content(result)
         prompt = f"""
         Evaluate the confidence in this {task_type} result:
-        {result.result}
+        {content}
 
         Consider:
         1. Completeness of the response
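The new _get_provider_name helper only probes two attribute names, which is what lets results exposing a provider attribute and Claude-style results (provider_name="claude") flow through the same voting code. A tiny sketch of that lookup (the Dummy class is invented):

    # Illustrative only: getattr falls back from .provider to .provider_name.
    class Dummy:
        provider_name = "claude"

    obj = Dummy()
    name = getattr(obj, "provider", None) or getattr(obj, "provider_name", "unknown")
    print(name)   # claude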
@@ -342,21 +367,33 @@ async def _evaluate_confidence(
         Provide a confidence score (0.0 to 1.0) and explanation.
         """
 
+        provider_name = self._get_provider_name(result)
+        evaluation_models = kwargs.get("evaluation_models", {})
+        evaluation_providers = kwargs.get("evaluation_providers", {})
+        default_model = kwargs.get("evaluation_model", "gpt-4")
+        eval_model = evaluation_models.get(provider_name, default_model)
+        eval_provider = evaluation_providers.get(provider_name, kwargs.get("evaluation_provider", provider_name))
+        route_kwargs = {
+            k: v for k, v in kwargs.items()
+            if k not in {"evaluation_models", "evaluation_providers", "evaluation_model", "evaluation_provider"}
+        }
         evaluation = await self.router.route(
             TaskType.TEXT_GENERATION,
             prompt,
-            model="gpt-4",
-            **kwargs
+            provider=eval_provider,
+            model=eval_model,
+            **route_kwargs
         )
 
         # Parse confidence score from evaluation
-        score = self._parse_confidence_score(evaluation.result)
-        explanation = self._parse_confidence_explanation(evaluation.result)
+        eval_content = self._extract_result_content(evaluation)
+        score = self._parse_confidence_score(eval_content)
+        explanation = self._parse_confidence_explanation(eval_content)
 
         return ConfidenceScore(
             score=score,
             explanation=explanation,
-            metadata={"raw_evaluation": evaluation.result}
+            metadata={"raw_evaluation": eval_content}
         )
 
     async def _evaluate_with_llm(
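The evaluation_models / evaluation_providers kwargs give callers per-provider control over which model judges each result, with evaluation_model as the catch-all default. A hedged call-site sketch; combine_results is a placeholder name for whatever ensemble entry point forwards **kwargs down to _evaluate_confidence:

    # Hypothetical call site, inside an async function.
    # Unlisted providers fall back to evaluation_model ("gpt-4" by default).
    ensemble_result = await ensemble.combine_results(
        results,
        TaskType.TEXT_GENERATION,
        evaluation_models={"ollama": "mistral", "claude": "claude-3-sonnet"},
        evaluation_providers={"ollama": "ollama"},
        evaluation_model="gpt-4",
    )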
@@ -366,9 +403,10 @@ async def _evaluate_with_llm(
         **kwargs
     ) -> str:
         """Evaluate a result using LLM."""
+        content = self._extract_result_content(result)
         prompt = f"""
         Evaluate this {task_type} result:
-        {result.result}
+        {content}
 
         Consider:
         1. Accuracy and correctness
@@ -379,14 +417,25 @@ async def _evaluate_with_llm(
         Provide a detailed evaluation with a numerical score (0-100).
         """
 
+        provider_name = self._get_provider_name(result)
+        evaluation_models = kwargs.get("evaluation_models", {})
+        evaluation_providers = kwargs.get("evaluation_providers", {})
+        default_model = kwargs.get("evaluation_model", "gpt-4")
+        eval_model = evaluation_models.get(provider_name, default_model)
+        eval_provider = evaluation_providers.get(provider_name, kwargs.get("evaluation_provider", provider_name))
+        route_kwargs = {
+            k: v for k, v in kwargs.items()
+            if k not in {"evaluation_models", "evaluation_providers", "evaluation_model", "evaluation_provider"}
+        }
         evaluation = await self.router.route(
             TaskType.TEXT_GENERATION,
             prompt,
-            model="gpt-4",
-            **kwargs
+            provider=eval_provider,
+            model=eval_model,
+            **route_kwargs
         )
 
-        return evaluation.result
+        return self._extract_result_content(evaluation)
 
     async def _get_provider_ranking(
         self,
@@ -395,9 +444,10 @@ async def _get_provider_ranking(
         **kwargs
     ) -> List[str]:
         """Get ranking of results from a provider."""
+        content = self._extract_result_content(result)
         prompt = f"""
         Rank the following {task_type} results from best to worst:
-        {result.result}
+        {content}
 
         Consider:
         1. Quality and accuracy
@@ -408,14 +458,36 @@ async def _get_provider_ranking(
         Provide a ranked list of provider names.
         """
 
+        provider_name = self._get_provider_name(result)
+        evaluation_models = kwargs.get("evaluation_models", {})
+        evaluation_providers = kwargs.get("evaluation_providers", {})
+        default_model = kwargs.get("evaluation_model", "gpt-4")
+
+        # Use provider-specific model if available, otherwise use smart defaults
+        eval_model = evaluation_models.get(provider_name)
+        if not eval_model:
+            # Use provider-appropriate default models
+            if provider_name == "ollama":
+                eval_model = "mistral"  # Ollama doesn't have gpt-4
+            elif provider_name == "anthropic" or provider_name == "claude":
+                eval_model = "claude-3-sonnet"
+            else:
+                eval_model = default_model  # Use gpt-4 for OpenAI and others
+
+        eval_provider = evaluation_providers.get(provider_name, kwargs.get("evaluation_provider", provider_name))
+        route_kwargs = {
+            k: v for k, v in kwargs.items()
+            if k not in {"evaluation_models", "evaluation_providers", "evaluation_model", "evaluation_provider"}
+        }
         ranking = await self.router.route(
             TaskType.TEXT_GENERATION,
             prompt,
-            model="gpt-4",
-            **kwargs
+            provider=eval_provider,
+            model=eval_model,
+            **route_kwargs
         )
 
-        return self._parse_ranking(ranking.result)
+        return self._parse_ranking(self._extract_result_content(ranking))
 
     def _parse_confidence_score(self, evaluation: str) -> float:
         """Parse confidence score from evaluation text."""
@@ -519,7 +591,7 @@ def tune_weights_with_optuna(self, results, task_type, eval_fn, n_trials=30):
         if not OPTUNA_AVAILABLE:
             raise ImportError("Optuna is required for hyperparameter tuning. Please install optuna.")
 
-        providers = [r.provider for r in results]
+        providers = [self._get_provider_name(r) for r in results]
         def objective(trial):
             weights = {p: trial.suggest_float(f"weight_{p}", 0.01, 1.0) for p in providers}
             # Normalize
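For readers unfamiliar with the tuning hook, here is a self-contained Optuna weight search in the same spirit; the provider list, target weights, and scoring function are stand-ins rather than the project's eval_fn:

    import optuna

    providers = ["openai", "claude", "ollama"]               # hypothetical
    target = {"openai": 0.5, "claude": 0.3, "ollama": 0.2}   # made-up "ideal" weights

    def objective(trial):
        weights = {p: trial.suggest_float(f"weight_{p}", 0.01, 1.0) for p in providers}
        total = sum(weights.values())
        normalized = {p: w / total for p, w in weights.items()}
        # Stand-in eval_fn: reward closeness to the made-up target weights.
        return -sum((normalized[p] - target[p]) ** 2 for p in providers)

    study = optuna.create_study(direction="maximize")
    study.optimize(objective, n_trials=30)
    print(study.best_params)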

multimind/providers/__init__.py

Lines changed: 3 additions & 1 deletion
@@ -6,8 +6,10 @@
 
 from .claude import ClaudeProvider
 from .openai import OpenAIProvider
+from .ollama import OllamaProvider
 
 __all__ = [
     "ClaudeProvider",
-    "OpenAIProvider"
+    "OpenAIProvider",
+    "OllamaProvider"
 ]
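With the export added, the new provider is importable next to the existing ones (constructor arguments are not part of this diff, so none are shown):

    from multimind.providers import ClaudeProvider, OllamaProvider, OpenAIProvider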

multimind/providers/claude.py

Lines changed: 2 additions & 2 deletions
@@ -97,10 +97,10 @@ async def chat(
         ) / 1000  # Convert to USD
 
         return GenerationResult(
+            text=result,
+            tokens_used=tokens_used,
             provider_name="claude",
             model_name=model,
-            result=result,
-            tokens_used=tokens_used,
             latency_ms=latency_ms,
             cost_estimate_usd=cost
         )
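The practical effect is that Claude generations now populate GenerationResult.text, the field _extract_result_content reads, instead of a result field. A sketch using only the field names that appear in this constructor call (placeholder values; GenerationResult may accept further optional fields):

    sample = GenerationResult(
        text="Hello from Claude",          # was result=... before this commit
        tokens_used=42,
        provider_name="claude",
        model_name="claude-3-sonnet",
        latency_ms=850.0,
        cost_estimate_usd=0.0004,
    )
    print(sample.text)   # downstream voting code reads .text via _extract_result_content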
