diff --git a/.env.example b/.env.example index 44ef9b936..2c6771399 100644 --- a/.env.example +++ b/.env.example @@ -27,7 +27,19 @@ CYPHER_ENDPOINT=http://localhost:11434/v1 # CYPHER_MODEL=gemini-2.5-flash # CYPHER_API_KEY=your-google-api-key -# Example 4: Mixed - Google orchestrator + Ollama cypher +# Example 4: All Azure OpenAI +# ORCHESTRATOR_PROVIDER=azure_openai +# ORCHESTRATOR_MODEL=gpt-5 +# ORCHESTRATOR_API_KEY=your-azure-api-key +# ORCHESTRATOR_ENDPOINT=your-azure-endpoint +# AZURE_OPEN_AI_API_VERSION=2025-03-01-preview + +# CYPHER_PROVIDER=azure_openai +# CYPHER_MODEL=gpt-5 +# CYPHER_API_KEY=your-azure-api-key +# CYPHER_ENDPOINT=your-azure-endpoint + +# Example 5: Mixed - Google orchestrator + Ollama cypher # ORCHESTRATOR_PROVIDER=google # ORCHESTRATOR_MODEL=gemini-2.5-pro # ORCHESTRATOR_API_KEY=your-google-api-key diff --git a/codebase_rag/config.py b/codebase_rag/config.py index eccf7bab5..b7352ac61 100644 --- a/codebase_rag/config.py +++ b/codebase_rag/config.py @@ -24,6 +24,7 @@ class ModelConfig: provider_type: str | None = None thinking_budget: int | None = None service_account_file: str | None = None + api_version: str | None = None class AppConfig(BaseSettings): @@ -67,6 +68,11 @@ class AppConfig(BaseSettings): CYPHER_THINKING_BUDGET: int | None = None CYPHER_SERVICE_ACCOUNT_FILE: str | None = None + # OpenAI API Version for Azure + AZURE_OPEN_AI_API_VERSION: str | None = ( + None # For models compatible with the OpenAI API, as specified in: https://ai.pydantic.dev/models/overview/#openai-compatible-providers + ) + # Fallback endpoint for ollama LOCAL_MODEL_ENDPOINT: AnyHttpUrl = AnyHttpUrl("http://localhost:11434/v1") @@ -100,6 +106,7 @@ def _get_default_config(self, role: str) -> ModelConfig: service_account_file=getattr( self, f"{role_upper}_SERVICE_ACCOUNT_FILE", None ), + api_version=getattr(self, "AZURE_OPEN_AI_API_VERSION", None), ) # Default to Ollama diff --git a/codebase_rag/main.py b/codebase_rag/main.py index 0586a9f3c..aa72d9ef3 100644 --- a/codebase_rag/main.py +++ b/codebase_rag/main.py @@ -726,6 +726,7 @@ def _validate_provider_config(role: str, config: Any) -> None: provider_type=config.provider_type, thinking_budget=config.thinking_budget, service_account_file=config.service_account_file, + api_version=config.api_version, ) provider.validate_config() except Exception as e: diff --git a/codebase_rag/providers/base.py b/codebase_rag/providers/base.py index d525c1e27..00a0b87d4 100644 --- a/codebase_rag/providers/base.py +++ b/codebase_rag/providers/base.py @@ -8,6 +8,7 @@ from loguru import logger from pydantic_ai.models.gemini import GeminiModel, GeminiModelSettings from pydantic_ai.models.openai import OpenAIModel, OpenAIResponsesModel +from pydantic_ai.providers.azure import AzureProvider from pydantic_ai.providers.google_gla import GoogleGLAProvider from pydantic_ai.providers.google_vertex import GoogleVertexProvider, VertexAiRegion from pydantic_ai.providers.openai import OpenAIProvider as PydanticOpenAIProvider @@ -94,6 +95,56 @@ def create_model(self, model_id: str, **kwargs: Any) -> GeminiModel: ) +class AzureOpenAIProvider(ModelProvider): + """Azure OpenAI provider.""" + + def __init__( + self, + api_key: str | None = None, + endpoint: str | None = None, + api_version: str | None = None, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + self.api_key = api_key + self.endpoint = endpoint + self.api_version = api_version + + @property + def provider_name(self) -> str: + return "azure_openai" + + def validate_config(self) -> None: + if not self.api_key: + raise ValueError( + "Azure OpenAI provider requires an API key. " + "Set ORCHESTRATOR_API_KEY or CYPHER_API_KEY in your .env file." + ) + + if not self.endpoint: + raise ValueError( + "Azure OpenAI provider requires an endpoint. " + "Set ORCHESTRATOR_ENDPOINT or CYPHER_ENDPOINT in your .env file." + ) + + if not self.api_version: + raise ValueError( + "Azure OpenAI provider requires api version. " + "Set AZURE_OPEN_AI_API_VERSION in .env file." + ) + + def create_model(self, model_id: str, **kwargs: Any) -> OpenAIModel: + self.validate_config() + + provider = AzureProvider( + azure_endpoint=self.endpoint, + api_version=self.api_version, + api_key=self.api_key, + ) + + return OpenAIModel(model_id, provider=provider, **kwargs) + + class OpenAIProvider(ModelProvider): """OpenAI provider.""" @@ -164,6 +215,7 @@ def create_model(self, model_id: str, **kwargs: Any) -> OpenAIModel: "google": GoogleProvider, "openai": OpenAIProvider, "ollama": OllamaProvider, + "azure_openai": AzureOpenAIProvider, } diff --git a/codebase_rag/services/llm.py b/codebase_rag/services/llm.py index 204b40b2b..9d3d2e8d3 100644 --- a/codebase_rag/services/llm.py +++ b/codebase_rag/services/llm.py @@ -43,6 +43,7 @@ def __init__(self) -> None: region=config.region, provider_type=config.provider_type, thinking_budget=config.thinking_budget, + api_version=config.api_version, ) # Create model using provider @@ -102,6 +103,7 @@ def create_rag_orchestrator(tools: list[Tool]) -> Agent: region=config.region, provider_type=config.provider_type, thinking_budget=config.thinking_budget, + api_version=config.api_version, ) # Create model using provider diff --git a/codebase_rag/tests/test_provider_classes.py b/codebase_rag/tests/test_provider_classes.py index f831a5428..e2f837247 100644 --- a/codebase_rag/tests/test_provider_classes.py +++ b/codebase_rag/tests/test_provider_classes.py @@ -8,6 +8,7 @@ import pytest from codebase_rag.providers.base import ( + AzureOpenAIProvider, GoogleProvider, ModelProvider, OllamaProvider, @@ -40,6 +41,15 @@ def test_get_valid_providers(self) -> None: assert isinstance(ollama_provider, OllamaProvider) assert ollama_provider.provider_name == "ollama" + azure_openai_provider = get_provider( + "azure_openai", + api_key="test-key", + endpoint="https://example.com", + api_version="2024-05-01-preview", + ) + assert isinstance(azure_openai_provider, AzureOpenAIProvider) + assert azure_openai_provider.provider_name == "azure_openai" + def test_get_invalid_provider(self) -> None: """Test that invalid provider names raise ValueError.""" with pytest.raises(ValueError, match="Unknown provider 'invalid_provider'"): @@ -51,7 +61,8 @@ def test_list_providers(self) -> None: assert "google" in providers assert "openai" in providers assert "ollama" in providers - assert len(providers) >= 3 + assert "azure_openai" in providers + assert len(providers) >= 4 def test_register_custom_provider(self) -> None: """Test registering a custom provider.""" diff --git a/codebase_rag/tests/test_provider_configuration.py b/codebase_rag/tests/test_provider_configuration.py index d2035210a..2cd3110fb 100644 --- a/codebase_rag/tests/test_provider_configuration.py +++ b/codebase_rag/tests/test_provider_configuration.py @@ -253,3 +253,36 @@ def test_openai_custom_endpoint(self) -> None: assert orch_config.model_id == "gpt-4o" assert orch_config.api_key == "sk-test-key" assert orch_config.endpoint == "https://api.custom-openai.com/v1" + + def test_azure_openai_endpoint_configuration(self) -> None: + """Test Azure OpenAI endpoint configuration with API version.""" + with patch.dict( + os.environ, + { + "ORCHESTRATOR_PROVIDER": "azure_openai", + "ORCHESTRATOR_MODEL": "DeepSeek-V3.1", + "ORCHESTRATOR_API_KEY": "test-azure-key", + "ORCHESTRATOR_ENDPOINT": "https://resource.openai.azure.com/", + "AZURE_OPEN_AI_API_VERSION": "2024-02-15-preview", + "CYPHER_PROVIDER": "azure_openai", + "CYPHER_MODEL": "DeepSeek-V3.1", + "CYPHER_API_KEY": "test-azure-key", + "CYPHER_ENDPOINT": "https://resource.openai.azure.com/", + }, + ): + config = AppConfig() + + # Test orchestrator Azure OpenAI config + orch_config = config.active_orchestrator_config + assert orch_config.provider == "azure_openai" + assert orch_config.model_id == "DeepSeek-V3.1" + assert orch_config.api_key == "test-azure-key" + assert orch_config.endpoint == "https://resource.openai.azure.com/" + assert orch_config.api_version == "2024-02-15-preview" + + # Test cypher Azure OpenAI config + cypher_config = config.active_cypher_config + assert cypher_config.provider == "azure_openai" + assert cypher_config.model_id == "DeepSeek-V3.1" + assert cypher_config.api_key == "test-azure-key" + assert cypher_config.endpoint == "https://resource.openai.azure.com/"