3 changes: 3 additions & 0 deletions requirements.txt
@@ -19,6 +19,9 @@ Pillow>=10.0.0
mammoth>=1.0.0
python-docx>=1.1.0

# LiteLLM AI gateway (100+ providers) - OPTIONAL
# pip install "litellm>=1.65,<1.85"

# Chatterbox TTS (GPU-accelerated local TTS) - OPTIONAL
# These dependencies have strict numpy version requirements that may conflict.
# Install separately if needed:
10 changes: 9 additions & 1 deletion src/core/llm/factory.py
@@ -25,6 +25,7 @@
from .providers.mistral import MistralProvider
from .providers.deepseek import DeepSeekProvider
from .providers.poe import PoeProvider
from .providers.litellm import LiteLLMProvider


def create_llm_provider(provider_type: str = "ollama", **kwargs) -> LLMProvider:
@@ -35,7 +36,7 @@ def create_llm_provider(provider_type: str = "ollama", **kwargs) -> LLMProvider:
    automatically switches to Gemini provider.

    Args:
        provider_type: Type of provider ("ollama", "openai", "gemini", "openrouter", "mistral", "deepseek", "poe")
        provider_type: Type of provider ("ollama", "openai", "gemini", "openrouter", "mistral", "deepseek", "poe", "litellm")
        **kwargs: Provider-specific parameters:
            - api_endpoint: API endpoint URL (Ollama, OpenAI)
            - model: Model name/identifier
@@ -154,5 +155,12 @@ def create_llm_provider(provider_type: str = "ollama", **kwargs) -> LLMProvider:
            api_endpoint=kwargs.get("api_endpoint", NIM_API_ENDPOINT)
        )

    elif provider_type.lower() == "litellm":
        return LiteLLMProvider(
            model=kwargs.get("model", DEFAULT_MODEL),
            api_key=kwargs.get("api_key"),
            api_base=kwargs.get("api_endpoint") or kwargs.get("endpoint"),
        )

    else:
        raise ValueError(f"Unknown provider type: {provider_type}")
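
A minimal usage sketch for the new "litellm" branch, assuming the package layout shown in this diff, a valid key (the "sk-..." value is a placeholder; LiteLLM's provider env vars also work), and that DEFAULT_MODEL resolves from src.config as it does for the existing providers:

import asyncio

from src.core.llm.factory import create_llm_provider


async def main():
    # "litellm" routes through LiteLLMProvider; the model name carries
    # the upstream provider prefix.
    provider = create_llm_provider(
        "litellm",
        model="anthropic/claude-sonnet-4-6",
        api_key="sk-...",  # placeholder key; provider-specific env vars also work
    )
    response = await provider.generate("Translate: Hello")
    if response is not None:  # generate() returns None once retries are exhausted
        print(response.content)


asyncio.run(main())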
125 changes: 125 additions & 0 deletions src/core/llm/providers/litellm.py
@@ -0,0 +1,125 @@
"""
LiteLLM provider implementation.

Routes to 100+ LLM providers (OpenAI, Anthropic, Gemini, Bedrock, Vertex AI,
Groq, etc.) via a unified interface using provider-prefixed model names.

Install: pip install litellm

Example:
    >>> provider = LiteLLMProvider(model="anthropic/claude-sonnet-4-6")
    >>> response = await provider.generate("Translate: Hello")
"""

from typing import Optional

from src.config import REQUEST_TIMEOUT, MAX_TRANSLATION_ATTEMPTS
from ..base import LLMProvider, LLMResponse
from ..exceptions import ContextOverflowError


class LiteLLMProvider(LLMProvider):
"""
Provider that uses LiteLLM to access 100+ LLM providers.

Uses provider-prefixed model names for routing:
- "openai/gpt-4o"
- "anthropic/claude-sonnet-4-6"
- "gemini/gemini-2.5-flash"
- "bedrock/anthropic.claude-v2"

API keys are read from provider-specific env vars (OPENAI_API_KEY,
ANTHROPIC_API_KEY, etc.) or passed explicitly.
"""

    def __init__(
        self,
        model: str,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ):
        super().__init__(model)
        self.api_key = api_key
        self.api_base = api_base

    def _build_kwargs(self) -> dict:
        kwargs: dict = {"drop_params": True}
        if self.api_key:
            kwargs["api_key"] = self.api_key
        if self.api_base:
            kwargs["api_base"] = self.api_base
        return kwargs

    async def generate(
        self,
        prompt: str,
        timeout: int = REQUEST_TIMEOUT,
        system_prompt: Optional[str] = None,
    ) -> Optional[LLMResponse]:
        # Imported lazily so the module loads even when the optional
        # litellm dependency is not installed.
        import litellm

        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": prompt})

        kwargs = self._build_kwargs()

        for attempt in range(MAX_TRANSLATION_ATTEMPTS):
            try:
                response = await litellm.acompletion(
                    model=self.model,
                    messages=messages,
                    timeout=timeout,
                    **kwargs,
                )

                choice = response.choices[0]
                content = getattr(choice.message, "content", "") or ""

                usage = getattr(response, "usage", None)
                prompt_tokens = getattr(usage, "prompt_tokens", 0) if usage else 0
                completion_tokens = getattr(usage, "completion_tokens", 0) if usage else 0

                return LLMResponse(
                    content=content,
                    prompt_tokens=prompt_tokens,
                    completion_tokens=completion_tokens,
                    context_used=prompt_tokens + completion_tokens,
                    context_limit=0,
                    was_truncated=False,
                )

            except Exception as e:
                # Heuristic, provider-agnostic detection of context-window errors.
                error_str = str(e).lower()
                context_keywords = [
                    "context_length", "maximum context", "token limit",
                    "too many tokens", "reduce the length", "max_tokens",
                    "context window", "exceeds",
                ]
                if any(kw in error_str for kw in context_keywords):
                    raise ContextOverflowError(
                        f"LiteLLM context overflow: {e}"
                    ) from e

                # Known-transient failures are retried with capped exponential backoff.
                qualname = f"{type(e).__module__}.{type(e).__name__}"
                transient = {
                    "litellm.exceptions.RateLimitError",
                    "litellm.exceptions.APIConnectionError",
                    "litellm.exceptions.Timeout",
                    "litellm.exceptions.InternalServerError",
                    "litellm.exceptions.ServiceUnavailableError",
                }
                if qualname in transient and attempt < MAX_TRANSLATION_ATTEMPTS - 1:
                    import asyncio
                    await asyncio.sleep(min(2 ** (attempt + 1), 10))
                    continue

                print(f"[LiteLLM] Error (attempt {attempt + 1}/{MAX_TRANSLATION_ATTEMPTS}): {e}")
                if attempt < MAX_TRANSLATION_ATTEMPTS - 1:
                    import asyncio
                    await asyncio.sleep(2)
                    continue
                return None

        return None
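
An end-to-end sketch of the provider used directly, assuming litellm is installed and OPENAI_API_KEY is set in the environment (LiteLLM resolves provider-specific env vars when no api_key is passed):

import asyncio

from src.core.llm.providers.litellm import LiteLLMProvider


async def demo():
    # No explicit api_key: LiteLLM falls back to OPENAI_API_KEY.
    provider = LiteLLMProvider(model="openai/gpt-4o")
    response = await provider.generate(
        "Translate to French: Good morning",
        system_prompt="You are a translator.",
    )
    if response is None:
        print("all retry attempts failed")
    else:
        print(response.content, f"({response.context_used} tokens)")


asyncio.run(demo())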
88 changes: 88 additions & 0 deletions tests/unit/test_litellm_provider.py
@@ -0,0 +1,88 @@
import sys
import types
from unittest import mock

import pytest


def _install_litellm_stub():
    fake = types.ModuleType("litellm")
    fake.acompletion = mock.AsyncMock(name="litellm.acompletion")
    sys.modules["litellm"] = fake
    return fake


@pytest.fixture(autouse=True)
def litellm_stub():
    fake = _install_litellm_stub()
    yield fake
    sys.modules.pop("litellm", None)


def _mock_response(content="Hello!", prompt_tokens=10, completion_tokens=5):
    from types import SimpleNamespace

    return SimpleNamespace(
        choices=[SimpleNamespace(message=SimpleNamespace(content=content))],
        usage=SimpleNamespace(
            prompt_tokens=prompt_tokens, completion_tokens=completion_tokens
        ),
    )


@pytest.mark.asyncio
async def test_generate_calls_acompletion(litellm_stub):
    litellm_stub.acompletion.return_value = _mock_response("translated text")

    from src.core.llm.providers.litellm import LiteLLMProvider

    provider = LiteLLMProvider(model="anthropic/claude-haiku-4-5", api_key="sk-test")
    result = await provider.generate("Translate this")

    litellm_stub.acompletion.assert_called_once()
    kwargs = litellm_stub.acompletion.call_args.kwargs
    assert kwargs["model"] == "anthropic/claude-haiku-4-5"
    assert kwargs["api_key"] == "sk-test"
    assert kwargs["drop_params"] is True
    assert result.content == "translated text"
    assert result.prompt_tokens == 10
    assert result.completion_tokens == 5


@pytest.mark.asyncio
async def test_generate_omits_blank_credentials(litellm_stub):
    litellm_stub.acompletion.return_value = _mock_response()

    from src.core.llm.providers.litellm import LiteLLMProvider

    provider = LiteLLMProvider(model="openai/gpt-4o")
    await provider.generate("Hello")

    kwargs = litellm_stub.acompletion.call_args.kwargs
    assert "api_key" not in kwargs
    assert "api_base" not in kwargs


@pytest.mark.asyncio
async def test_generate_forwards_system_prompt(litellm_stub):
    litellm_stub.acompletion.return_value = _mock_response()

    from src.core.llm.providers.litellm import LiteLLMProvider

    provider = LiteLLMProvider(model="openai/gpt-4o", api_key="k")
    await provider.generate("Translate this", system_prompt="You are a translator")

    kwargs = litellm_stub.acompletion.call_args.kwargs
    messages = kwargs["messages"]
    assert messages[0]["role"] == "system"
    assert messages[0]["content"] == "You are a translator"
    assert messages[1]["role"] == "user"


def test_factory_creates_litellm_provider():
    from src.core.llm.factory import create_llm_provider
    from src.core.llm.providers.litellm import LiteLLMProvider

    provider = create_llm_provider("litellm", model="anthropic/claude-haiku-4-5", api_key="k")
    assert isinstance(provider, LiteLLMProvider)
    assert provider.model == "anthropic/claude-haiku-4-5"