diff --git a/.gitignore b/.gitignore
index 8da05107fd..d13345bd17 100644
--- a/.gitignore
+++ b/.gitignore
@@ -5,6 +5,7 @@
 .venv/
 venv/
 pr_agent/settings/.secrets.toml
+pr_agent/settings_prod/.secrets.toml
 __pycache__
 dist/
 *.egg-info/
diff --git a/pr_agent/algo/ai_handlers/litellm_ai_handler.py b/pr_agent/algo/ai_handlers/litellm_ai_handler.py
index de9993284d..e2992f313c 100644
--- a/pr_agent/algo/ai_handlers/litellm_ai_handler.py
+++ b/pr_agent/algo/ai_handlers/litellm_ai_handler.py
@@ -152,6 +152,25 @@ def __init__(self):
         # Models that require streaming
         self.streaming_required_models = STREAMING_REQUIRED_MODELS

+        self.force_streaming_provider = str(
+            getattr(get_settings().litellm, "force_streaming_custom_llm_provider", "") or ""
+        ).strip().lower()
+        raw_force_streaming_api_base_substrings = getattr(
+            get_settings().litellm, "force_streaming_api_base_substrings", []
+        )
+        if isinstance(raw_force_streaming_api_base_substrings, (list, tuple, set)):
+            self.force_streaming_api_base_substrings = [
+                str(value).strip().lower()
+                for value in raw_force_streaming_api_base_substrings
+                if value is not None and str(value).strip()
+            ]
+        else:
+            if raw_force_streaming_api_base_substrings:
+                get_logger().warning(
+                    "LITELLM.FORCE_STREAMING_API_BASE_SUBSTRINGS must be a list, tuple, or set. "
+                    "Ignoring invalid value."
+                )
+            self.force_streaming_api_base_substrings = []

     def prepare_logs(self, response, system, user, resp, finish_reason):
         response_log = response.dict().copy()
@@ -395,6 +414,12 @@ async def chat_completion(self, model: str, system: str, user: str, temperature:
         # Support for custom OpenAI body fields (e.g., Flex Processing)
         kwargs = _process_litellm_extra_body(kwargs)

+        custom_llm_provider = str(
+            getattr(get_settings().litellm, "custom_llm_provider", "") or ""
+        ).strip().lower()
+        if custom_llm_provider:
+            kwargs["custom_llm_provider"] = custom_llm_provider
+
         # Support for Bedrock custom inference profile via model_id
         model_id = get_settings().get("litellm.model_id")
         if model_id and 'bedrock/' in model:
@@ -442,9 +467,28 @@ async def _get_completion(self, **kwargs):
         Wrapper that automatically handles streaming for required models.
         """
         model = kwargs["model"]
-        if model in self.streaming_required_models:
+        custom_llm_provider = str(kwargs.get("custom_llm_provider") or "").strip().lower()
+        api_base_value = kwargs.get("api_base")
+        api_base = api_base_value.strip().lower() if isinstance(api_base_value, str) else ""
+        force_streaming = (
+            bool(self.force_streaming_provider)
+            and custom_llm_provider == self.force_streaming_provider
+            and bool(self.force_streaming_api_base_substrings)
+            and any(substring in api_base for substring in self.force_streaming_api_base_substrings)
+        )
+
+        # Some OpenAI-compatible endpoints can return an empty-string
+        # finish_reason on non-streaming responses, which LiteLLM rejects during
+        # response normalization. Streaming avoids that conversion path.
+        if model in self.streaming_required_models or force_streaming:
             kwargs["stream"] = True
-            get_logger().info(f"Using streaming mode for model {model}")
+            if force_streaming and model not in self.streaming_required_models:
+                get_logger().info(
+                    f"Using streaming mode for model {model} "
+                    "due to OpenAI-compatible endpoint compatibility"
+                )
+            else:
+                get_logger().info(f"Using streaming mode for model {model}")
             response = await acompletion(**kwargs)
             resp, finish_reason = await _handle_streaming_response(response)
             # Create MockResponse for streaming since we don't have the full response object
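For reference, the force-streaming gate added to _get_completion fires only when all four conditions hold: a force-streaming provider is configured, the request's custom_llm_provider matches it, at least one api_base substring is configured, and one of those substrings occurs in the request's api_base. A minimal standalone sketch of that decision, assuming it mirrors the handler logic above (the function name and sample values are illustrative, not part of this PR):

    # Illustrative sketch only; the real handler reads the configured values
    # from settings in __init__ and the per-request values from kwargs.
    def should_force_stream(configured_provider, configured_substrings,
                            request_provider, request_api_base):
        provider = str(request_provider or "").strip().lower()
        api_base = request_api_base.strip().lower() if isinstance(request_api_base, str) else ""
        return (
            bool(configured_provider)                # feature is opt-in
            and provider == configured_provider      # provider must match
            and bool(configured_substrings)          # need at least one substring
            and any(s in api_base for s in configured_substrings)  # endpoint must match
        )

    # Fires when provider and endpoint both line up:
    assert should_force_stream(
        "openai", ["snowflakecomputing.com"],
        "openai", "https://example-account.snowflakecomputing.com/api/v2/cortex/v1",
    )
    # A non-string api_base normalizes to "" and fails the substring
    # check instead of raising:
    assert not should_force_stream("openai", ["snowflakecomputing.com"], "openai", 123)

As in the handler, every comparison happens on stripped, lowercased strings, so provider and host matching are whitespace- and case-insensitive.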
diff --git a/pr_agent/settings/configuration.toml b/pr_agent/settings/configuration.toml
index 16ffbcae2a..10695f2605 100644
--- a/pr_agent/settings/configuration.toml
+++ b/pr_agent/settings/configuration.toml
@@ -323,6 +323,9 @@ enable_callbacks = false
 success_callback = []
 failure_callback = []
 service_callback = []
+custom_llm_provider = ""
+force_streaming_custom_llm_provider = ""
+force_streaming_api_base_substrings = []
 # model_id = "" # Optional: Custom inference profile ID for Amazon Bedrock

 [pr_similar_issue]
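All three new keys default to inert values, so existing deployments see no behavior change; a deployment routing models through an OpenAI-compatible gateway would opt in explicitly. A hypothetical example (the substring is illustrative, any stable fragment of the gateway's api_base works):

    [litellm]
    custom_llm_provider = "openai"                                    # forwarded to acompletion() verbatim
    force_streaming_custom_llm_provider = "openai"                    # gate: request provider must match
    force_streaming_api_base_substrings = ["snowflakecomputing.com"]  # gate: api_base must contain one of these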
diff --git a/tests/unittest/test_litellm_custom_provider.py b/tests/unittest/test_litellm_custom_provider.py
new file mode 100644
index 0000000000..b953de69d8
--- /dev/null
+++ b/tests/unittest/test_litellm_custom_provider.py
@@ -0,0 +1,306 @@
+from unittest.mock import AsyncMock, patch
+
+import pytest
+
+import pr_agent.algo.ai_handlers.litellm_ai_handler as litellm_handler
+from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAIHandler
+
+
+def create_mock_settings(
+    custom_llm_provider=None,
+    force_streaming_custom_llm_provider="openai",
+    force_streaming_api_base_substrings=None,
+):
+    if force_streaming_api_base_substrings is None:
+        force_streaming_api_base_substrings = ["snowflakecomputing.com"]
+
+    litellm_settings = type("", (), {"get": lambda self, key, default=None: default})()
+    if custom_llm_provider is not None:
+        litellm_settings.custom_llm_provider = custom_llm_provider
+    litellm_settings.force_streaming_custom_llm_provider = force_streaming_custom_llm_provider
+    litellm_settings.force_streaming_api_base_substrings = force_streaming_api_base_substrings
+
+    def get_value(key, default=None):
+        values = {
+            "LITELLM.CUSTOM_LLM_PROVIDER": custom_llm_provider,
+            "litellm.custom_llm_provider": custom_llm_provider,
+            "LITELLM.FORCE_STREAMING_CUSTOM_LLM_PROVIDER": force_streaming_custom_llm_provider,
+            "litellm.force_streaming_custom_llm_provider": force_streaming_custom_llm_provider,
+            "LITELLM.FORCE_STREAMING_API_BASE_SUBSTRINGS": force_streaming_api_base_substrings,
+            "litellm.force_streaming_api_base_substrings": force_streaming_api_base_substrings,
+        }
+        return values.get(key, default)
+
+    return type(
+        "",
+        (),
+        {
+            "config": type(
+                "",
+                (),
+                {
+                    "ai_timeout": 120,
+                    "custom_reasoning_model": False,
+                    "verbosity_level": 0,
+                    "get": lambda self, key, default=None: default,
+                },
+            )(),
+            "litellm": litellm_settings,
+            "get": staticmethod(get_value),
+        },
+    )()
+
+
+def create_mock_acompletion_response():
+    response_payload = {
+        "choices": [{"message": {"content": "test"}, "finish_reason": "stop"}]
+    }
+
+    class MockCompletionResponse(dict):
+        def dict(self):
+            return dict(self)
+
+    return MockCompletionResponse(response_payload)
+
+
+@pytest.mark.asyncio
+async def test_custom_llm_provider_is_forwarded_without_rewriting_model(monkeypatch):
+    fake_settings = create_mock_settings(" OpenAI ")
+    monkeypatch.setattr(litellm_handler, "get_settings", lambda: fake_settings)
+
+    with patch(
+        "pr_agent.algo.ai_handlers.litellm_ai_handler.acompletion",
+        new_callable=AsyncMock,
+    ) as mock_completion:
+        mock_completion.return_value = create_mock_acompletion_response()
+
+        handler = LiteLLMAIHandler()
+        await handler.chat_completion(
+            model="claude-sonnet-4-5",
+            system="test system",
+            user="test user",
+        )
+
+    call_kwargs = mock_completion.call_args[1]
+    assert call_kwargs["model"] == "claude-sonnet-4-5"
+    assert call_kwargs["custom_llm_provider"] == "openai"
+
+
+@pytest.mark.asyncio
+async def test_custom_llm_provider_is_omitted_when_unset(monkeypatch):
+    fake_settings = create_mock_settings()
+    monkeypatch.setattr(litellm_handler, "get_settings", lambda: fake_settings)
+
+    with patch(
+        "pr_agent.algo.ai_handlers.litellm_ai_handler.acompletion",
+        new_callable=AsyncMock,
+    ) as mock_completion:
+        mock_completion.return_value = create_mock_acompletion_response()
+
+        handler = LiteLLMAIHandler()
+        await handler.chat_completion(
+            model="claude-sonnet-4-5",
+            system="test system",
+            user="test user",
+        )
+
+    call_kwargs = mock_completion.call_args[1]
+    assert "custom_llm_provider" not in call_kwargs
+
+
+@pytest.mark.asyncio
+async def test_openai_compatible_endpoint_calls_force_streaming(monkeypatch):
+    fake_settings = create_mock_settings("openai")
+    monkeypatch.setattr(litellm_handler, "get_settings", lambda: fake_settings)
+
+    with (
+        patch(
+            "pr_agent.algo.ai_handlers.litellm_ai_handler.acompletion",
+            new_callable=AsyncMock,
+        ) as mock_completion,
+        patch(
+            "pr_agent.algo.ai_handlers.litellm_ai_handler._handle_streaming_response",
+            new_callable=AsyncMock,
+        ) as mock_stream_handler,
+    ):
+        mock_stream_handler.return_value = ("test", "stop")
+        handler = LiteLLMAIHandler()
+        await handler._get_completion(
+            model="claude-sonnet-4-5",
+            messages=[],
+            timeout=120,
+            api_base="https://example-account.snowflakecomputing.com/api/v2/cortex/v1",
+            custom_llm_provider="openai",
+        )
+
+    call_kwargs = mock_completion.call_args[1]
+    assert call_kwargs["stream"] is True
+
+
+@pytest.mark.asyncio
+async def test_openai_compatible_endpoint_normalizes_custom_provider_for_streaming(monkeypatch):
+    fake_settings = create_mock_settings(" OpenAI ")
+    monkeypatch.setattr(litellm_handler, "get_settings", lambda: fake_settings)
+
+    with (
+        patch(
+            "pr_agent.algo.ai_handlers.litellm_ai_handler.acompletion",
+            new_callable=AsyncMock,
+        ) as mock_completion,
+        patch(
+            "pr_agent.algo.ai_handlers.litellm_ai_handler._handle_streaming_response",
+            new_callable=AsyncMock,
+        ) as mock_stream_handler,
+    ):
+        mock_stream_handler.return_value = ("test", "stop")
+        handler = LiteLLMAIHandler()
+        await handler._get_completion(
+            model="claude-sonnet-4-5",
+            messages=[],
+            timeout=120,
+            api_base="https://example-account.snowflakecomputing.com/api/v2/cortex/v1",
+            custom_llm_provider=" OpenAI ",
+        )
+
+    call_kwargs = mock_completion.call_args[1]
+    assert call_kwargs["stream"] is True
+
+
+@pytest.mark.asyncio
+async def test_openai_compatible_endpoint_ignores_non_string_api_base(monkeypatch):
+    fake_settings = create_mock_settings("openai")
+    monkeypatch.setattr(litellm_handler, "get_settings", lambda: fake_settings)
+
+    with patch(
+        "pr_agent.algo.ai_handlers.litellm_ai_handler.acompletion",
+        new_callable=AsyncMock,
+    ) as mock_completion:
+        mock_completion.return_value = create_mock_acompletion_response()
+
+        handler = LiteLLMAIHandler()
+        await handler._get_completion(
+            model="claude-sonnet-4-5",
+            messages=[],
+            timeout=120,
+            api_base=123,
+            custom_llm_provider="openai",
+        )
+
+    call_kwargs = mock_completion.call_args[1]
+    assert "stream" not in call_kwargs
+
+
+@pytest.mark.asyncio
+async def test_force_streaming_is_settings_driven(monkeypatch):
+    fake_settings = create_mock_settings(
+        "openai",
+        force_streaming_custom_llm_provider="openai",
+        force_streaming_api_base_substrings=["example-gateway.local"],
+    )
+    monkeypatch.setattr(litellm_handler, "get_settings", lambda: fake_settings)
+
+    with patch(
+        "pr_agent.algo.ai_handlers.litellm_ai_handler.acompletion",
+        new_callable=AsyncMock,
+    ) as mock_completion:
+        mock_completion.return_value = create_mock_acompletion_response()
+
+        handler = LiteLLMAIHandler()
+        await handler._get_completion(
+            model="claude-sonnet-4-5",
+            messages=[],
+            timeout=120,
+            api_base="https://example-account.snowflakecomputing.com/api/v2/cortex/v1",
+            custom_llm_provider="openai",
+        )
+
+    call_kwargs = mock_completion.call_args[1]
+    assert "stream" not in call_kwargs
+
+
+@pytest.mark.asyncio
+async def test_force_streaming_requires_non_empty_provider_setting(monkeypatch):
+    fake_settings = create_mock_settings(
+        "openai",
+        force_streaming_custom_llm_provider="",
+        force_streaming_api_base_substrings=["snowflakecomputing.com"],
+    )
+    monkeypatch.setattr(litellm_handler, "get_settings", lambda: fake_settings)
+
+    with patch(
+        "pr_agent.algo.ai_handlers.litellm_ai_handler.acompletion",
+        new_callable=AsyncMock,
+    ) as mock_completion:
+        mock_completion.return_value = create_mock_acompletion_response()
+
+        handler = LiteLLMAIHandler()
+        await handler._get_completion(
+            model="claude-sonnet-4-5",
+            messages=[],
+            timeout=120,
+            api_base="https://example-account.snowflakecomputing.com/api/v2/cortex/v1",
+            custom_llm_provider="",
+        )
+
+    call_kwargs = mock_completion.call_args[1]
+    assert "stream" not in call_kwargs
+
+
+@pytest.mark.asyncio
+async def test_force_streaming_ignores_non_collection_substring_setting(monkeypatch):
+    fake_settings = create_mock_settings(
+        "openai",
+        force_streaming_custom_llm_provider="openai",
+        force_streaming_api_base_substrings="snowflakecomputing.com",
+    )
+    monkeypatch.setattr(litellm_handler, "get_settings", lambda: fake_settings)
+
+    with patch(
+        "pr_agent.algo.ai_handlers.litellm_ai_handler.acompletion",
+        new_callable=AsyncMock,
+    ) as mock_completion:
+        mock_completion.return_value = create_mock_acompletion_response()
+
+        handler = LiteLLMAIHandler()
+        await handler._get_completion(
+            model="claude-sonnet-4-5",
+            messages=[],
+            timeout=120,
+            api_base="https://example-account.snowflakecomputing.com/api/v2/cortex/v1",
+            custom_llm_provider="openai",
+        )
+
+    call_kwargs = mock_completion.call_args[1]
+    assert "stream" not in call_kwargs
+
+
+@pytest.mark.asyncio
+async def test_force_streaming_warns_on_invalid_substring_setting(monkeypatch):
+    fake_settings = create_mock_settings(
+        "openai",
+        force_streaming_custom_llm_provider="openai",
+        force_streaming_api_base_substrings="snowflakecomputing.com",
+    )
+    monkeypatch.setattr(litellm_handler, "get_settings", lambda: fake_settings)
+
+    with (
+        patch(
+            "pr_agent.algo.ai_handlers.litellm_ai_handler.acompletion",
+            new_callable=AsyncMock,
+        ) as mock_completion,
+        patch("pr_agent.algo.ai_handlers.litellm_ai_handler.get_logger") as mock_logger,
+    ):
+        mock_completion.return_value = create_mock_acompletion_response()
+        handler = LiteLLMAIHandler()
+        await handler._get_completion(
+            model="claude-sonnet-4-5",
+            messages=[],
+            timeout=120,
+            api_base="https://example-account.snowflakecomputing.com/api/v2/cortex/v1",
+            custom_llm_provider="openai",
+        )
+
+    mock_logger.return_value.warning.assert_called_once_with(
+        "LITELLM.FORCE_STREAMING_API_BASE_SUBSTRINGS must be a list, tuple, or set. "
+        "Ignoring invalid value."
+    )