Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,19 @@ CYPHER_ENDPOINT=http://localhost:11434/v1
# CYPHER_MODEL=gemini-2.5-flash
# CYPHER_API_KEY=your-google-api-key

# Example 4: Mixed - Google orchestrator + Ollama cypher
# Example 4: All Azure OpenAI
# ORCHESTRATOR_PROVIDER=azure_openai
# ORCHESTRATOR_MODEL=gpt-5
# ORCHESTRATOR_API_KEY=your-azure-api-key
# ORCHESTRATOR_ENDPOINT=your-azure-endpoint
# AZURE_OPEN_AI_API_VERSION=2025-03-01-preview

# CYPHER_PROVIDER=azure_openai
# CYPHER_MODEL=gpt-5
# CYPHER_API_KEY=your-azure-api-key
# CYPHER_ENDPOINT=your-azure-endpoint

# Example 5: Mixed - Google orchestrator + Ollama cypher
# ORCHESTRATOR_PROVIDER=google
# ORCHESTRATOR_MODEL=gemini-2.5-pro
# ORCHESTRATOR_API_KEY=your-google-api-key
Expand Down
7 changes: 7 additions & 0 deletions codebase_rag/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ class ModelConfig:
provider_type: str | None = None
thinking_budget: int | None = None
service_account_file: str | None = None
api_version: str | None = None


class AppConfig(BaseSettings):
Expand Down Expand Up @@ -67,6 +68,11 @@ class AppConfig(BaseSettings):
CYPHER_THINKING_BUDGET: int | None = None
CYPHER_SERVICE_ACCOUNT_FILE: str | None = None

# OpenAI API Version for Azure
AZURE_OPEN_AI_API_VERSION: str | None = (
None # For models compatible with the OpenAI API, as specified in: https://ai.pydantic.dev/models/overview/#openai-compatible-providers
)
Comment on lines +72 to +74
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The comment for AZURE_OPEN_AI_API_VERSION is placed inside the parentheses of the value assignment. This is an unconventional style and can be confusing to read. For better readability and to adhere to common Python style practices, it's recommended to place the comment on the same line after the assignment or on a separate line before it.

    # OpenAI API Version for Azure
    AZURE_OPEN_AI_API_VERSION: str | None = None  # For models compatible with the OpenAI API, as specified in: https://ai.pydantic.dev/models/overview/#openai-compatible-providers


# Fallback endpoint for ollama
LOCAL_MODEL_ENDPOINT: AnyHttpUrl = AnyHttpUrl("http://localhost:11434/v1")

Expand Down Expand Up @@ -100,6 +106,7 @@ def _get_default_config(self, role: str) -> ModelConfig:
service_account_file=getattr(
self, f"{role_upper}_SERVICE_ACCOUNT_FILE", None
),
api_version=getattr(self, "AZURE_OPEN_AI_API_VERSION", None),
)

# Default to Ollama
Expand Down
1 change: 1 addition & 0 deletions codebase_rag/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -726,6 +726,7 @@ def _validate_provider_config(role: str, config: Any) -> None:
provider_type=config.provider_type,
thinking_budget=config.thinking_budget,
service_account_file=config.service_account_file,
api_version=config.api_version,
)
provider.validate_config()
except Exception as e:
Expand Down
52 changes: 52 additions & 0 deletions codebase_rag/providers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from loguru import logger
from pydantic_ai.models.gemini import GeminiModel, GeminiModelSettings
from pydantic_ai.models.openai import OpenAIModel, OpenAIResponsesModel
from pydantic_ai.providers.azure import AzureProvider
from pydantic_ai.providers.google_gla import GoogleGLAProvider
from pydantic_ai.providers.google_vertex import GoogleVertexProvider, VertexAiRegion
from pydantic_ai.providers.openai import OpenAIProvider as PydanticOpenAIProvider
Expand Down Expand Up @@ -94,6 +95,56 @@ def create_model(self, model_id: str, **kwargs: Any) -> GeminiModel:
)


class AzureOpenAIProvider(ModelProvider):
"""Azure OpenAI provider."""

def __init__(
self,
api_key: str | None = None,
endpoint: str | None = None,
api_version: str | None = None,
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
self.api_key = api_key
self.endpoint = endpoint
self.api_version = api_version

@property
def provider_name(self) -> str:
return "azure_openai"

def validate_config(self) -> None:
if not self.api_key:
raise ValueError(
"Azure OpenAI provider requires api key. "
"Set AZURE_OPENAI_API_KEY in .env file."
)

if not self.endpoint:
raise ValueError(
"Azure OpenAI provider requires endpoint. "
"Set AZURE_OPENAI_ENDPOINT in .env file."
)

if not self.api_version:
raise ValueError(
"Azure OpenAI provider requires api version. "
"Set AZURE_OPEN_AI_API_VERSION in .env file."
)

def create_model(self, model_id: str, **kwargs: Any) -> OpenAIModel:
self.validate_config()

provider = AzureProvider(
azure_endpoint=self.endpoint,
api_version=self.api_version,
api_key=self.api_key,
)

return OpenAIModel(model_id, provider=provider, **kwargs)


class OpenAIProvider(ModelProvider):
"""OpenAI provider."""

Expand Down Expand Up @@ -164,6 +215,7 @@ def create_model(self, model_id: str, **kwargs: Any) -> OpenAIModel:
"google": GoogleProvider,
"openai": OpenAIProvider,
"ollama": OllamaProvider,
"azure_openai": AzureOpenAIProvider,
}


Expand Down
2 changes: 2 additions & 0 deletions codebase_rag/services/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def __init__(self) -> None:
region=config.region,
provider_type=config.provider_type,
thinking_budget=config.thinking_budget,
api_version=config.api_version,
)

# Create model using provider
Expand Down Expand Up @@ -102,6 +103,7 @@ def create_rag_orchestrator(tools: list[Tool]) -> Agent:
region=config.region,
provider_type=config.provider_type,
thinking_budget=config.thinking_budget,
api_version=config.api_version,
)

# Create model using provider
Expand Down
13 changes: 12 additions & 1 deletion codebase_rag/tests/test_provider_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import pytest

from codebase_rag.providers.base import (
AzureOpenAIProvider,
GoogleProvider,
ModelProvider,
OllamaProvider,
Expand Down Expand Up @@ -40,6 +41,15 @@ def test_get_valid_providers(self) -> None:
assert isinstance(ollama_provider, OllamaProvider)
assert ollama_provider.provider_name == "ollama"

azure_openai_provider = get_provider(
"azure_openai",
api_key="test-key",
endpoint="https://example.com",
api_version="2024-05-01-preview",
)
assert isinstance(azure_openai_provider, AzureOpenAIProvider)
assert azure_openai_provider.provider_name == "azure_openai"

def test_get_invalid_provider(self) -> None:
"""Test that invalid provider names raise ValueError."""
with pytest.raises(ValueError, match="Unknown provider 'invalid_provider'"):
Expand All @@ -51,7 +61,8 @@ def test_list_providers(self) -> None:
assert "google" in providers
assert "openai" in providers
assert "ollama" in providers
assert len(providers) >= 3
assert "azure_openai" in providers
assert len(providers) >= 4

def test_register_custom_provider(self) -> None:
"""Test registering a custom provider."""
Expand Down
33 changes: 33 additions & 0 deletions codebase_rag/tests/test_provider_configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,3 +253,36 @@ def test_openai_custom_endpoint(self) -> None:
assert orch_config.model_id == "gpt-4o"
assert orch_config.api_key == "sk-test-key"
assert orch_config.endpoint == "https://api.custom-openai.com/v1"

def test_azure_openai_endpoint_configuration(self) -> None:
"""Test Azure OpenAI endpoint configuration with API version."""
with patch.dict(
os.environ,
{
"ORCHESTRATOR_PROVIDER": "azure_openai",
"ORCHESTRATOR_MODEL": "DeepSeek-V3.1",
"ORCHESTRATOR_API_KEY": "test-azure-key",
"ORCHESTRATOR_ENDPOINT": "https://resource.openai.azure.com/",
"AZURE_OPEN_AI_API_VERSION": "2024-02-15-preview",
"CYPHER_PROVIDER": "azure_openai",
"CYPHER_MODEL": "DeepSeek-V3.1",
"CYPHER_API_KEY": "test-azure-key",
"CYPHER_ENDPOINT": "https://resource.openai.azure.com/",
},
):
config = AppConfig()

# Test orchestrator Azure OpenAI config
orch_config = config.active_orchestrator_config
assert orch_config.provider == "azure_openai"
assert orch_config.model_id == "DeepSeek-V3.1"
assert orch_config.api_key == "test-azure-key"
assert orch_config.endpoint == "https://resource.openai.azure.com/"
assert orch_config.api_version == "2024-02-15-preview"

# Test cypher Azure OpenAI config
cypher_config = config.active_cypher_config
assert cypher_config.provider == "azure_openai"
assert cypher_config.model_id == "DeepSeek-V3.1"
assert cypher_config.api_key == "test-azure-key"
assert cypher_config.endpoint == "https://resource.openai.azure.com/"
Comment on lines +283 to +288
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The test for the cypher Azure OpenAI configuration is missing an assertion for api_version. Since AZURE_OPEN_AI_API_VERSION is a global setting, it should be applied to both the orchestrator and cypher configurations when they use the azure_openai provider. Adding this assertion will make the test more complete and ensure the configuration is applied correctly for both roles.

Suggested change
# Test cypher Azure OpenAI config
cypher_config = config.active_cypher_config
assert cypher_config.provider == "azure_openai"
assert cypher_config.model_id == "DeepSeek-V3.1"
assert cypher_config.api_key == "test-azure-key"
assert cypher_config.endpoint == "https://resource.openai.azure.com/"
# Test cypher Azure OpenAI config
cypher_config = config.active_cypher_config
assert cypher_config.provider == "azure_openai"
assert cypher_config.model_id == "DeepSeek-V3.1"
assert cypher_config.api_key == "test-azure-key"
assert cypher_config.endpoint == "https://resource.openai.azure.com/"
assert cypher_config.api_version == "2024-02-15-preview"