11 changes: 11 additions & 0 deletions .env.example
@@ -45,6 +45,17 @@ CYPHER_ENDPOINT=http://localhost:11434/v1
# CYPHER_MODEL=gemini-2.5-flash
# CYPHER_API_KEY=your-google-api-key

# Example 6: LiteLLM with custom provider
# ORCHESTRATOR_PROVIDER=litellm_proxy
# ORCHESTRATOR_MODEL=gpt-oss:120b
# ORCHESTRATOR_ENDPOINT=http://litellm:4000/v1
# ORCHESTRATOR_API_KEY=sk-your-litellm-key

# CYPHER_PROVIDER=litellm_proxy
# CYPHER_MODEL=openrouter/gpt-oss:120b
# CYPHER_ENDPOINT=http://litellm:4000/v1
# CYPHER_API_KEY=sk-your-litellm-key

# Memgraph settings
MEMGRAPH_HOST=localhost
MEMGRAPH_PORT=7687
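For orientation, here is a rough sketch of how the "Example 6" variables could map onto the provider factory added in this PR. The `os.environ`/`load_dotenv` wiring and the exact keyword names forwarded by `get_provider` are assumptions for illustration, not code from this diff (python-dotenv is already a project dependency):

```python
# Hypothetical startup wiring: assumes get_provider forwards api_key/endpoint
# keywords to the provider constructor and that the app loads .env itself.
import os

from dotenv import load_dotenv

from codebase_rag.providers.base import get_provider

load_dotenv()  # picks up the "Example 6" block above

orchestrator = get_provider(
    os.environ.get("ORCHESTRATOR_PROVIDER", "litellm_proxy"),
    api_key=os.environ.get("ORCHESTRATOR_API_KEY"),
    endpoint=os.environ.get("ORCHESTRATOR_ENDPOINT", "http://litellm:4000/v1"),
)
model = orchestrator.create_model(os.environ.get("ORCHESTRATOR_MODEL", "gpt-oss:120b"))
```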
51 changes: 51 additions & 0 deletions codebase_rag/providers/base.py
@@ -166,6 +166,16 @@ def create_model(self, model_id: str, **kwargs: Any) -> OpenAIModel:
"ollama": OllamaProvider,
}

# Import LiteLLM provider after base classes are defined to avoid circular import
try:
    from .litellm import LiteLLMProvider

    PROVIDER_REGISTRY["litellm_proxy"] = LiteLLMProvider
    _litellm_available = True
except ImportError as e:
    logger.debug(f"LiteLLM provider not available: {e}")
    _litellm_available = False


def get_provider(provider_name: str, **config: Any) -> ModelProvider:
"""Factory function to create a provider instance."""
@@ -199,3 +209,44 @@ def check_ollama_running(endpoint: str = "http://localhost:11434") -> bool:
        return bool(response.status_code == 200)
    except (httpx.RequestError, httpx.TimeoutException):
        return False


def check_litellm_proxy_running(
    endpoint: str = "http://localhost:4000", api_key: str | None = None
) -> bool:
    """Check if the LiteLLM proxy is running and accessible.

    Args:
        endpoint: Base URL of the LiteLLM proxy server
        api_key: Optional API key for authenticated proxies

    Returns:
        True if the proxy is accessible, False otherwise
    """
    try:
        # Strip a trailing "/v1" suffix (removesuffix, not rstrip, so that
        # hosts or ports ending in "v" or "1" are not mangled).
        base_url = endpoint.rstrip("/").removesuffix("/v1")

        # Try the health endpoint first (works for unauthenticated proxies)
        health_url = urljoin(base_url, "/health")
        headers = {}
        if api_key:
            headers["Authorization"] = f"Bearer {api_key}"

        with httpx.Client(timeout=5.0) as client:
            response = client.get(health_url, headers=headers)

            # If the health endpoint works, we're good
            if response.status_code == 200:
                return True

            # If the health endpoint fails (401, 404, 405, 500, etc.),
            # try the models endpoint as a fallback when we have an API key
            if api_key:
                models_url = urljoin(base_url, "/v1/models")
                response = client.get(models_url, headers=headers)
                # Accept 200 (success) - the server is up and the API key works
                return bool(response.status_code == 200)

            return False
    except (httpx.RequestError, httpx.TimeoutException):
        return False
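As a sanity check, the new probe can be called on its own before any model is constructed. A minimal sketch, assuming the import path shown in this diff; the endpoint and key are placeholders:

```python
# Minimal standalone probe; endpoint and API key are placeholders.
from codebase_rag.providers.base import check_litellm_proxy_running

if check_litellm_proxy_running("http://localhost:4000", api_key="sk-your-litellm-key"):
    print("LiteLLM proxy is reachable")
else:
    print("LiteLLM proxy is down, unreachable, or the API key was rejected")
```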
62 changes: 62 additions & 0 deletions codebase_rag/providers/litellm.py
@@ -0,0 +1,62 @@
"""LiteLLM provider using pydantic-ai's native LiteLLMProvider."""

from typing import Any

from loguru import logger
from pydantic_ai.models.openai import OpenAIChatModel
from pydantic_ai.providers.litellm import LiteLLMProvider as PydanticLiteLLMProvider

from .base import ModelProvider


class LiteLLMProvider(ModelProvider):
    """Provider that routes model requests through a LiteLLM proxy server."""

    def __init__(
        self,
        api_key: str | None = None,
        endpoint: str = "http://localhost:4000/v1",
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)
        self.api_key = api_key
        self.endpoint = endpoint

    @property
    def provider_name(self) -> str:
        return "litellm_proxy"

    def validate_config(self) -> None:
        if not self.endpoint:
            raise ValueError(
                "LiteLLM provider requires an endpoint. "
                "Set ORCHESTRATOR_ENDPOINT or CYPHER_ENDPOINT in the .env file."
            )

        # Check that the LiteLLM proxy is running.
        # Import locally to avoid a circular import with base.py.
        from .base import check_litellm_proxy_running

        base_url = self.endpoint.rstrip("/").removesuffix("/v1")
        if not check_litellm_proxy_running(base_url, api_key=self.api_key):
            raise ValueError(
                f"LiteLLM proxy server not responding at {base_url}. "
                "Make sure the LiteLLM proxy is running and the API key is valid."
            )

    def create_model(self, model_id: str, **kwargs: Any) -> OpenAIChatModel:
        """Create an OpenAI-compatible model for the LiteLLM proxy.

        Args:
            model_id: Model identifier (e.g., "openai/gpt-3.5-turbo", "anthropic/claude-3")
            **kwargs: Additional arguments passed to OpenAIChatModel

        Returns:
            OpenAIChatModel configured to use the LiteLLM proxy
        """
        self.validate_config()

        logger.info(f"Creating LiteLLM proxy model: {model_id} at {self.endpoint}")

        # Use pydantic-ai's native LiteLLMProvider
        provider = PydanticLiteLLMProvider(api_key=self.api_key, api_base=self.endpoint)

        return OpenAIChatModel(model_id, provider=provider, **kwargs)
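For illustration, direct use of the new provider class (bypassing the registry) might look like the sketch below; the model id and credentials are placeholder values from .env.example, and anything beyond returning the OpenAIChatModel is outside this diff:

```python
# Hypothetical direct usage; values are placeholders from .env.example.
from codebase_rag.providers.litellm import LiteLLMProvider

provider = LiteLLMProvider(
    api_key="sk-your-litellm-key",
    endpoint="http://litellm:4000/v1",
)
# validate_config() runs inside create_model and raises ValueError
# if the proxy is unreachable or the key is rejected.
model = provider.create_model("openrouter/gpt-oss:120b")
```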
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -7,7 +7,7 @@ requires-python = ">=3.12"
dependencies = [
"loguru>=0.7.3",
"mcp>=1.21.1",
"pydantic-ai-slim[google,openai,vertexai]>=0.2.18",
"pydantic-ai-slim[google,openai,vertexai]>=1.18.0",
"pydantic-settings>=2.0.0",
"pymgclient>=1.4.0",
"python-dotenv>=1.1.0",