diff --git a/.env.example b/.env.example
index 44ef9b936..7edceedd0 100644
--- a/.env.example
+++ b/.env.example
@@ -45,6 +45,17 @@ CYPHER_ENDPOINT=http://localhost:11434/v1
 # CYPHER_MODEL=gemini-2.5-flash
 # CYPHER_API_KEY=your-google-api-key
 
+# Example 6: LiteLLM with custom provider
+# ORCHESTRATOR_PROVIDER=litellm_proxy
+# ORCHESTRATOR_MODEL=gpt-oss:120b
+# ORCHESTRATOR_ENDPOINT=http://litellm:4000/v1
+# ORCHESTRATOR_API_KEY=sk-your-litellm-key
+
+# CYPHER_PROVIDER=litellm_proxy
+# CYPHER_MODEL=openrouter/gpt-oss:120b
+# CYPHER_ENDPOINT=http://litellm:4000/v1
+# CYPHER_API_KEY=sk-your-litellm-key
+
 # Memgraph settings
 MEMGRAPH_HOST=localhost
 MEMGRAPH_PORT=7687
diff --git a/codebase_rag/providers/base.py b/codebase_rag/providers/base.py
index d525c1e27..a94b4c283 100644
--- a/codebase_rag/providers/base.py
+++ b/codebase_rag/providers/base.py
@@ -166,6 +166,16 @@ def create_model(self, model_id: str, **kwargs: Any) -> OpenAIModel:
     "ollama": OllamaProvider,
 }
 
+# Import LiteLLM provider after base classes are defined to avoid circular import
+try:
+    from .litellm import LiteLLMProvider
+
+    PROVIDER_REGISTRY["litellm_proxy"] = LiteLLMProvider
+    _litellm_available = True
+except ImportError as e:
+    logger.debug(f"LiteLLM provider not available: {e}")
+    _litellm_available = False
+
 
 def get_provider(provider_name: str, **config: Any) -> ModelProvider:
     """Factory function to create a provider instance."""
@@ -199,3 +209,44 @@ def check_ollama_running(endpoint: str = "http://localhost:11434") -> bool:
         return bool(response.status_code == 200)
     except (httpx.RequestError, httpx.TimeoutException):
         return False
+
+
+def check_litellm_proxy_running(
+    endpoint: str = "http://localhost:4000", api_key: str | None = None
+) -> bool:
+    """Check if LiteLLM proxy is running and accessible.
+
+    Args:
+        endpoint: Base URL of the LiteLLM proxy server
+        api_key: Optional API key for authenticated proxies
+
+    Returns:
+        True if the proxy is accessible, False otherwise
+    """
+    try:
+        base_url = endpoint.rstrip("/").removesuffix("/v1")
+
+        # Try health endpoint first (works for unauthenticated proxies)
+        health_url = urljoin(base_url, "/health")
+        headers = {}
+        if api_key:
+            headers["Authorization"] = f"Bearer {api_key}"
+
+        with httpx.Client(timeout=5.0) as client:
+            response = client.get(health_url, headers=headers)
+
+            # If health endpoint works, we're good
+            if response.status_code == 200:
+                return True
+
+            # If health endpoint fails (401, 404, 405, 500, etc.),
+            # try the models endpoint as a fallback when we have an API key
+            if api_key:
+                models_url = urljoin(base_url, "/v1/models")
+                response = client.get(models_url, headers=headers)
+                # Accept 200 (success) - server is up and API key works
+                return bool(response.status_code == 200)
+
+            return False
+    except (httpx.RequestError, httpx.TimeoutException):
+        return False
diff --git a/codebase_rag/providers/litellm.py b/codebase_rag/providers/litellm.py
new file mode 100644
index 000000000..9095bfe52
--- /dev/null
+++ b/codebase_rag/providers/litellm.py
@@ -0,0 +1,62 @@
+"""LiteLLM provider using pydantic-ai's native LiteLLMProvider."""
+
+from typing import Any
+
+from loguru import logger
+from pydantic_ai.models.openai import OpenAIChatModel
+from pydantic_ai.providers.litellm import LiteLLMProvider as PydanticLiteLLMProvider
+
+from .base import ModelProvider
+
+
+class LiteLLMProvider(ModelProvider):
+    def __init__(
+        self,
+        api_key: str | None = None,
+        endpoint: str = "http://localhost:4000/v1",
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(**kwargs)
+        self.api_key = api_key
+        self.endpoint = endpoint
+
+    @property
+    def provider_name(self) -> str:
+        return "litellm_proxy"
+
+    def validate_config(self) -> None:
+        if not self.endpoint:
+            raise ValueError(
+                "LiteLLM provider requires endpoint. "
+                "Set ORCHESTRATOR_ENDPOINT or CYPHER_ENDPOINT in .env file."
+            )
+
+        # Check if LiteLLM proxy is running
+        # Import locally to avoid circular import
+        from .base import check_litellm_proxy_running
+
+        base_url = self.endpoint.rstrip("/").removesuffix("/v1")
+        if not check_litellm_proxy_running(base_url, api_key=self.api_key):
+            raise ValueError(
+                f"LiteLLM proxy server not responding at {base_url}. "
+                f"Make sure LiteLLM proxy is running and API key is valid."
+            )
+
+    def create_model(self, model_id: str, **kwargs: Any) -> OpenAIChatModel:
+        """Create OpenAI-compatible model for LiteLLM proxy.
+
+        Args:
+            model_id: Model identifier (e.g., "openai/gpt-3.5-turbo", "anthropic/claude-3")
+            **kwargs: Additional arguments passed to OpenAIChatModel
+
+        Returns:
+            OpenAIChatModel configured to use the LiteLLM proxy
+        """
+        self.validate_config()
+
+        logger.info(f"Creating LiteLLM proxy model: {model_id} at {self.endpoint}")
+
+        # Use pydantic-ai's native LiteLLMProvider
+        provider = PydanticLiteLLMProvider(api_key=self.api_key, api_base=self.endpoint)
+
+        return OpenAIChatModel(model_id, provider=provider, **kwargs)
diff --git a/pyproject.toml b/pyproject.toml
index 6be0c2979..28c8ea85e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -7,7 +7,7 @@ requires-python = ">=3.12"
 dependencies = [
     "loguru>=0.7.3",
     "mcp>=1.21.1",
-    "pydantic-ai-slim[google,openai,vertexai]>=0.2.18",
+    "pydantic-ai-slim[google,openai,vertexai]>=1.18.0",
     "pydantic-settings>=2.0.0",
     "pymgclient>=1.4.0",
    "python-dotenv>=1.1.0",
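For reviewers who want to exercise the new provider locally, here is a minimal usage sketch, not part of the diff. It assumes `get_provider` forwards its keyword config to the provider constructor (consistent with its `**config: Any` signature) and that a LiteLLM proxy is reachable at the placeholder endpoint; the model name and API key are illustrative values only.

```python
# Sketch (assumption-laden): wiring the new litellm_proxy provider end to end.
from codebase_rag.providers.base import check_litellm_proxy_running, get_provider

ENDPOINT = "http://localhost:4000/v1"  # placeholder LiteLLM proxy URL
API_KEY = "sk-your-litellm-key"        # placeholder key

# Optional pre-flight check; validate_config() performs the same probe internally.
if not check_litellm_proxy_running("http://localhost:4000", api_key=API_KEY):
    raise SystemExit("LiteLLM proxy is not reachable")

# Assumes get_provider passes api_key/endpoint through to LiteLLMProvider.__init__.
provider = get_provider("litellm_proxy", api_key=API_KEY, endpoint=ENDPOINT)
model = provider.create_model("openrouter/gpt-oss:120b")  # OpenAIChatModel via the proxy
```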