From 48b9b1411f2c93296933d2a0830647432563c290 Mon Sep 17 00:00:00 2001
From: RheagalFire
Date: Tue, 5 May 2026 23:56:17 +0530
Subject: [PATCH] feat: add LiteLLM as AI gateway provider

---
 requirements.txt                    |   3 +
 src/core/llm/factory.py             |  10 ++-
 src/core/llm/providers/litellm.py   | 151 ++++++++++++++++++++++++++++
 tests/unit/test_litellm_provider.py |  88 ++++++++++++++++
 4 files changed, 251 insertions(+), 1 deletion(-)
 create mode 100644 src/core/llm/providers/litellm.py
 create mode 100644 tests/unit/test_litellm_provider.py

diff --git a/requirements.txt b/requirements.txt
index ff64ff5..bd479a3 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -19,6 +19,9 @@ Pillow>=10.0.0
 mammoth>=1.0.0
 python-docx>=1.1.0
 
+# LiteLLM AI gateway (100+ providers) - OPTIONAL
+# pip install "litellm>=1.65,<1.85"
+
 # Chatterbox TTS (GPU-accelerated local TTS) - OPTIONAL
 # These dependencies have strict numpy version requirements that may conflict.
 # Install separately if needed:
diff --git a/src/core/llm/factory.py b/src/core/llm/factory.py
index 4df293a..4fd9a18 100644
--- a/src/core/llm/factory.py
+++ b/src/core/llm/factory.py
@@ -25,6 +25,7 @@
 from .providers.mistral import MistralProvider
 from .providers.deepseek import DeepSeekProvider
 from .providers.poe import PoeProvider
+from .providers.litellm import LiteLLMProvider
 
 
 def create_llm_provider(provider_type: str = "ollama", **kwargs) -> LLMProvider:
@@ -35,7 +36,7 @@ def create_llm_provider(provider_type: str = "ollama", **kwargs) -> LLMProvider:
     automatically switches to Gemini provider.
 
     Args:
-        provider_type: Type of provider ("ollama", "openai", "gemini", "openrouter", "mistral", "deepseek", "poe")
+        provider_type: Type of provider ("ollama", "openai", "gemini", "openrouter", "mistral", "deepseek", "poe", "litellm")
         **kwargs: Provider-specific parameters:
             - api_endpoint: API endpoint URL (Ollama, OpenAI)
             - model: Model name/identifier
@@ -154,5 +155,12 @@ def create_llm_provider(provider_type: str = "ollama", **kwargs) -> LLMProvider:
             api_endpoint=kwargs.get("api_endpoint", NIM_API_ENDPOINT)
         )
 
+    elif provider_type.lower() == "litellm":
+        return LiteLLMProvider(
+            model=kwargs.get("model", DEFAULT_MODEL),
+            api_key=kwargs.get("api_key"),
+            api_base=kwargs.get("api_endpoint") or kwargs.get("endpoint"),
+        )
+
     else:
         raise ValueError(f"Unknown provider type: {provider_type}")
diff --git a/src/core/llm/providers/litellm.py b/src/core/llm/providers/litellm.py
new file mode 100644
index 0000000..a16aecb
--- /dev/null
+++ b/src/core/llm/providers/litellm.py
@@ -0,0 +1,151 @@
"""
LiteLLM provider implementation.

Routes to 100+ LLM providers (OpenAI, Anthropic, Gemini, Bedrock, Vertex AI,
Groq, etc.) via a unified interface using provider-prefixed model names.

Install: pip install litellm

Example:
    >>> provider = LiteLLMProvider(model="anthropic/claude-sonnet-4-6")
    >>> response = await provider.generate("Translate: Hello")
"""

import asyncio
from typing import Optional

from src.config import REQUEST_TIMEOUT, MAX_TRANSLATION_ATTEMPTS
from ..base import LLMProvider, LLMResponse
from ..exceptions import ContextOverflowError


class LiteLLMProvider(LLMProvider):
    """
    Provider that uses LiteLLM to access 100+ LLM providers.

    Uses provider-prefixed model names for routing:
    - "openai/gpt-4o"
    - "anthropic/claude-sonnet-4-6"
    - "gemini/gemini-2.5-flash"
    - "bedrock/anthropic.claude-v2"

    API keys are read from provider-specific env vars (OPENAI_API_KEY,
    ANTHROPIC_API_KEY, etc.) or passed explicitly.
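
    Example (illustrative; the endpoint below is a placeholder for a local
    LiteLLM proxy or other OpenAI-compatible gateway):
        >>> provider = LiteLLMProvider(
        ...     model="openai/gpt-4o",
        ...     api_base="http://localhost:4000",
        ... )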
+ """ + + def __init__( + self, + model: str, + api_key: Optional[str] = None, + api_base: Optional[str] = None, + ): + super().__init__(model) + self.api_key = api_key + self.api_base = api_base + + def _build_kwargs(self) -> dict: + kwargs: dict = {"drop_params": True} + if self.api_key: + kwargs["api_key"] = self.api_key + if self.api_base: + kwargs["api_base"] = self.api_base + return kwargs + + async def generate( + self, + prompt: str, + timeout: int = REQUEST_TIMEOUT, + system_prompt: Optional[str] = None, + ) -> Optional[LLMResponse]: + import litellm + + messages = [] + if system_prompt: + messages.append({"role": "system", "content": system_prompt}) + messages.append({"role": "user", "content": prompt}) + + kwargs = self._build_kwargs() + + for attempt in range(MAX_TRANSLATION_ATTEMPTS): + try: + response = await litellm.acompletion( + model=self.model, + messages=messages, + timeout=timeout, + **kwargs, + ) + + choice = response.choices[0] + content = getattr(choice.message, "content", "") or "" + + usage = getattr(response, "usage", None) + prompt_tokens = getattr(usage, "prompt_tokens", 0) if usage else 0 + completion_tokens = getattr(usage, "completion_tokens", 0) if usage else 0 + + return LLMResponse( + content=content, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + context_used=prompt_tokens + completion_tokens, + context_limit=0, + was_truncated=False, + ) + + except Exception as e: + error_str = str(e).lower() + context_keywords = [ + "context_length", "maximum context", "token limit", + "too many tokens", "reduce the length", "max_tokens", + "context window", "exceeds", + ] + if any(kw in error_str for kw in context_keywords): + raise ContextOverflowError( + f"LiteLLM context overflow: {e}" + ) from e + + qualname = f"{type(e).__module__}.{type(e).__name__}" + transient = { + "litellm.exceptions.RateLimitError", + "litellm.exceptions.APIConnectionError", + "litellm.exceptions.Timeout", + "litellm.exceptions.InternalServerError", + "litellm.exceptions.ServiceUnavailableError", + } + if qualname in transient and attempt < MAX_TRANSLATION_ATTEMPTS - 1: + import asyncio + await asyncio.sleep(min(2 ** (attempt + 1), 10)) + continue + + print(f"[LiteLLM] Error (attempt {attempt + 1}/{MAX_TRANSLATION_ATTEMPTS}): {e}") + if attempt < MAX_TRANSLATION_ATTEMPTS - 1: + import asyncio + await asyncio.sleep(2) + continue + return None + + return None diff --git a/tests/unit/test_litellm_provider.py b/tests/unit/test_litellm_provider.py new file mode 100644 index 0000000..c9a018a --- /dev/null +++ b/tests/unit/test_litellm_provider.py @@ -0,0 +1,88 @@ +import sys +import types +from unittest import mock + +import pytest + + +def _install_litellm_stub(): + fake = types.ModuleType("litellm") + fake.acompletion = mock.AsyncMock(name="litellm.acompletion") + sys.modules["litellm"] = fake + return fake + + +@pytest.fixture(autouse=True) +def litellm_stub(): + fake = _install_litellm_stub() + yield fake + sys.modules.pop("litellm", None) + + +def _mock_response(content="Hello!", prompt_tokens=10, completion_tokens=5): + from types import SimpleNamespace + + return SimpleNamespace( + choices=[SimpleNamespace(message=SimpleNamespace(content=content))], + usage=SimpleNamespace( + prompt_tokens=prompt_tokens, completion_tokens=completion_tokens + ), + ) + + +@pytest.mark.asyncio +async def test_generate_calls_acompletion(litellm_stub): + litellm_stub.acompletion.return_value = _mock_response("translated text") + + from src.core.llm.providers.litellm import 
diff --git a/tests/unit/test_litellm_provider.py b/tests/unit/test_litellm_provider.py
new file mode 100644
index 0000000..c9a018a
--- /dev/null
+++ b/tests/unit/test_litellm_provider.py
@@ -0,0 +1,88 @@
import sys
import types
from unittest import mock

import pytest


def _install_litellm_stub():
    fake = types.ModuleType("litellm")
    fake.acompletion = mock.AsyncMock(name="litellm.acompletion")
    sys.modules["litellm"] = fake
    return fake


@pytest.fixture(autouse=True)
def litellm_stub():
    fake = _install_litellm_stub()
    yield fake
    sys.modules.pop("litellm", None)


def _mock_response(content="Hello!", prompt_tokens=10, completion_tokens=5):
    from types import SimpleNamespace

    return SimpleNamespace(
        choices=[SimpleNamespace(message=SimpleNamespace(content=content))],
        usage=SimpleNamespace(
            prompt_tokens=prompt_tokens, completion_tokens=completion_tokens
        ),
    )


@pytest.mark.asyncio
async def test_generate_calls_acompletion(litellm_stub):
    litellm_stub.acompletion.return_value = _mock_response("translated text")

    from src.core.llm.providers.litellm import LiteLLMProvider

    provider = LiteLLMProvider(model="anthropic/claude-haiku-4-5", api_key="sk-test")
    result = await provider.generate("Translate this")

    litellm_stub.acompletion.assert_called_once()
    kwargs = litellm_stub.acompletion.call_args.kwargs
    assert kwargs["model"] == "anthropic/claude-haiku-4-5"
    assert kwargs["api_key"] == "sk-test"
    assert kwargs["drop_params"] is True
    assert result.content == "translated text"
    assert result.prompt_tokens == 10
    assert result.completion_tokens == 5


@pytest.mark.asyncio
async def test_generate_omits_blank_credentials(litellm_stub):
    litellm_stub.acompletion.return_value = _mock_response()

    from src.core.llm.providers.litellm import LiteLLMProvider

    provider = LiteLLMProvider(model="openai/gpt-4o")
    await provider.generate("Hello")

    kwargs = litellm_stub.acompletion.call_args.kwargs
    assert "api_key" not in kwargs
    assert "api_base" not in kwargs


@pytest.mark.asyncio
async def test_generate_forwards_system_prompt(litellm_stub):
    litellm_stub.acompletion.return_value = _mock_response()

    from src.core.llm.providers.litellm import LiteLLMProvider

    provider = LiteLLMProvider(model="openai/gpt-4o", api_key="k")
    await provider.generate("Translate this", system_prompt="You are a translator")

    kwargs = litellm_stub.acompletion.call_args.kwargs
    messages = kwargs["messages"]
    assert messages[0]["role"] == "system"
    assert messages[0]["content"] == "You are a translator"
    assert messages[1]["role"] == "user"


def test_factory_creates_litellm_provider():
    from src.core.llm.factory import create_llm_provider
    from src.core.llm.providers.litellm import LiteLLMProvider

    provider = create_llm_provider("litellm", model="anthropic/claude-haiku-4-5", api_key="k")
    assert isinstance(provider, LiteLLMProvider)
    assert provider.model == "anthropic/claude-haiku-4-5"
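
Usage sketch (illustrative): with the factory wiring above, callers opt into
the gateway by passing provider_type="litellm" plus a provider-prefixed model
name. The model below is an example, and LiteLLM is assumed to find
ANTHROPIC_API_KEY in the environment.

    import asyncio

    from src.core.llm.factory import create_llm_provider

    async def main() -> None:
        # api_key/api_endpoint kwargs are optional; LiteLLM falls back to the
        # provider-specific environment variable for the key.
        provider = create_llm_provider("litellm", model="anthropic/claude-sonnet-4-6")
        response = await provider.generate("Translate: Hello")
        print(response.content if response else "request failed")

    asyncio.run(main())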