diff --git a/codebase_rag/constants.py b/codebase_rag/constants.py index 8abcda648..0466740f1 100644 --- a/codebase_rag/constants.py +++ b/codebase_rag/constants.py @@ -123,6 +123,11 @@ class FileAction(StrEnum): DEFAULT_MODEL = "llama3.2" DEFAULT_API_KEY = "ollama" +ENV_OPENAI_API_KEY = "OPENAI_API_KEY" +ENV_GOOGLE_API_KEY = "GOOGLE_API_KEY" + +HELP_ARG = "help" + class GoogleProviderType(StrEnum): GLA = "gla" @@ -254,6 +259,15 @@ class GoogleProviderType(StrEnum): UI_GRAPH_EXPORT_STATS = "[bold cyan]Export contains {nodes} nodes and {relationships} relationships[/bold cyan]" UI_ERR_UNEXPECTED = "[bold red]An unexpected error occurred: {error}[/bold red]" UI_ERR_EXPORT_FAILED = "[bold red]Failed to export graph: {error}[/bold red]" +UI_MODEL_SWITCHED = "[bold green]Model switched to: {model}[/bold green]" +UI_MODEL_CURRENT = "[bold cyan]Current model: {model}[/bold cyan]" +UI_MODEL_SWITCH_ERROR = "[bold red]Failed to switch model: {error}[/bold red]" +UI_MODEL_USAGE = "[bold yellow]Usage: /model (e.g., /model google:gemini-2.0-flash)[/bold yellow]" +UI_HELP_COMMANDS = """[bold cyan]Available commands:[/bold cyan] + /model - Switch to a different model + /model - Show current model + /help - Show this help + exit, quit - Exit the session""" UI_TOOL_ARGS_FORMAT = " Arguments: {args}" UI_REFERENCE_DOC_INFO = " using the reference document: {reference_document}" UI_INPUT_PROMPT_HTML = ( @@ -555,6 +569,10 @@ class LanguageMetadata(NamedTuple): # (H) CLI exit commands EXIT_COMMANDS = frozenset({"exit", "quit"}) +# (H) CLI commands +MODEL_COMMAND_PREFIX = "/model" +HELP_COMMAND = "/help" + # (H) UI separators and formatting HORIZONTAL_SEPARATOR = "─" * 60 diff --git a/codebase_rag/exceptions.py b/codebase_rag/exceptions.py index bf7cb6359..f30202395 100644 --- a/codebase_rag/exceptions.py +++ b/codebase_rag/exceptions.py @@ -22,6 +22,10 @@ # (H) Configuration errors PROVIDER_EMPTY = "Provider name cannot be empty in 'provider:model' format." +MODEL_ID_EMPTY = "Model ID cannot be empty." +MODEL_FORMAT_INVALID = ( + "Model must be specified as 'provider:model' (e.g., openai:gpt-4o)." +) BATCH_SIZE_POSITIVE = "batch_size must be a positive integer" CONFIG = "{role} configuration error: {error}" diff --git a/codebase_rag/logs.py b/codebase_rag/logs.py index 6892bd1a1..899c477ac 100644 --- a/codebase_rag/logs.py +++ b/codebase_rag/logs.py @@ -614,3 +614,8 @@ # (H) Exclude prompt logs EXCLUDE_INVALID_INDEX = "Invalid index: {index} (out of range)" EXCLUDE_INVALID_INPUT = "Invalid input: '{input}' (expected number)" + +# (H) Model switching logs +MODEL_SWITCHED = "Model switched to: {model}" +MODEL_SWITCH_FAILED = "Failed to switch model: {error}" +MODEL_CURRENT = "Current model: {model}" diff --git a/codebase_rag/main.py b/codebase_rag/main.py index e59e1b76c..2b7626e44 100644 --- a/codebase_rag/main.py +++ b/codebase_rag/main.py @@ -9,6 +9,7 @@ import uuid from collections import deque from collections.abc import Coroutine +from dataclasses import replace from pathlib import Path from typing import TYPE_CHECKING @@ -27,9 +28,10 @@ from . import constants as cs from . import exceptions as ex from . import logs as ls -from .config import load_cgrignore_patterns, settings +from .config import ModelConfig, load_cgrignore_patterns, settings from .models import AppContext from .prompts import OPTIMIZATION_PROMPT, OPTIMIZATION_PROMPT_WITH_REFERENCE +from .providers.base import get_provider_from_config from .services import QueryProtocol from .services.graph_service import MemgraphIngestor from .services.llm import CypherGenerator, create_rag_orchestrator @@ -64,8 +66,7 @@ from prompt_toolkit.key_binding import KeyPressEvent from pydantic_ai import Agent from pydantic_ai.messages import ModelMessage - - from .config import ModelConfig + from pydantic_ai.models import Model def style( @@ -389,6 +390,7 @@ async def _run_agent_response_loop( question_with_context: str, config: AgentLoopUI, tool_names: ConfirmationToolNames, + model_override: Model | None = None, ) -> None: deferred_results: DeferredToolResults | None = None @@ -399,6 +401,7 @@ async def _run_agent_response_loop( question_with_context, message_history=message_history, deferred_tool_results=deferred_results, + model=model_override, ), ) @@ -529,6 +532,75 @@ def keyboard_interrupt(event: KeyPressEvent) -> None: return stripped +def _create_model_from_string( + model_string: str, current_override_config: ModelConfig | None = None +) -> tuple[Model, str, ModelConfig]: + base_config = current_override_config or settings.active_orchestrator_config + + if cs.CHAR_COLON not in model_string: + raise ValueError(ex.MODEL_FORMAT_INVALID) + provider_name, model_id = ( + p.strip() for p in settings.parse_model_string(model_string) + ) + if not model_id: + raise ValueError(ex.MODEL_ID_EMPTY) + if not provider_name: + raise ValueError(ex.PROVIDER_EMPTY) + + if provider_name == base_config.provider: + config = replace(base_config, model_id=model_id) + elif provider_name == cs.Provider.OLLAMA: + config = ModelConfig( + provider=provider_name, + model_id=model_id, + endpoint=str(settings.LOCAL_MODEL_ENDPOINT), + api_key=cs.DEFAULT_API_KEY, + ) + else: + config = ModelConfig(provider=provider_name, model_id=model_id) + + canonical_string = f"{provider_name}{cs.CHAR_COLON}{model_id}" + provider = get_provider_from_config(config) + return provider.create_model(model_id), canonical_string, config + + +def _handle_model_command( + command: str, + current_model: Model | None, + current_model_string: str | None, + current_config: ModelConfig | None, +) -> tuple[Model | None, str | None, ModelConfig | None]: + parts = command.strip().split(maxsplit=1) + arg = parts[1].strip() if len(parts) > 1 else None + + if not arg: + if current_model_string: + display_model = current_model_string + else: + config = settings.active_orchestrator_config + display_model = f"{config.provider}{cs.CHAR_COLON}{config.model_id}" + app_context.console.print(cs.UI_MODEL_CURRENT.format(model=display_model)) + return current_model, current_model_string, current_config + + if arg.lower() == cs.HELP_ARG: + app_context.console.print(cs.UI_MODEL_USAGE) + return current_model, current_model_string, current_config + + try: + new_model, canonical_model_string, new_config = _create_model_from_string( + arg, current_config + ) + logger.info(ls.MODEL_SWITCHED.format(model=canonical_model_string)) + app_context.console.print( + cs.UI_MODEL_SWITCHED.format(model=canonical_model_string) + ) + return new_model, canonical_model_string, new_config + except (ValueError, AssertionError) as e: + logger.error(ls.MODEL_SWITCH_FAILED.format(error=e)) + app_context.console.print(cs.UI_MODEL_SWITCH_ERROR.format(error=e)) + return current_model, current_model_string, current_config + + async def _run_interactive_loop( rag_agent: Agent[None, str | DeferredToolRequests], message_history: list[ModelMessage], @@ -540,15 +612,39 @@ async def _run_interactive_loop( ) -> None: init_session_log(project_root) question = initial_question or "" + model_override: Model | None = None + model_override_string: str | None = None + model_override_config: ModelConfig | None = None while True: try: if not initial_question or question != initial_question: question = await asyncio.to_thread(get_multiline_input, input_prompt) - if question.lower() in cs.EXIT_COMMANDS: + stripped_question = question.strip() + stripped_lower = stripped_question.lower() + + if stripped_lower in cs.EXIT_COMMANDS: break - if not question.strip(): + + if not stripped_question: + initial_question = None + continue + + command_parts = stripped_lower.split(maxsplit=1) + if command_parts[0] == cs.MODEL_COMMAND_PREFIX: + model_override, model_override_string, model_override_config = ( + _handle_model_command( + stripped_question, + model_override, + model_override_string, + model_override_config, + ) + ) + initial_question = None + continue + if command_parts[0] == cs.HELP_COMMAND: + app_context.console.print(cs.UI_HELP_COMMANDS) initial_question = None continue @@ -565,7 +661,12 @@ async def _run_interactive_loop( ) await _run_agent_response_loop( - rag_agent, message_history, question_with_context, config, tool_names + rag_agent, + message_history, + question_with_context, + config, + tool_names, + model_override, ) initial_question = None @@ -608,7 +709,7 @@ def _update_single_model_setting(role: cs.ModelRole, model_string: str) -> None: if provider == cs.Provider.OLLAMA and not kwargs[cs.FIELD_ENDPOINT]: kwargs[cs.FIELD_ENDPOINT] = str(settings.LOCAL_MODEL_ENDPOINT) - kwargs[cs.FIELD_API_KEY] = cs.Provider.OLLAMA + kwargs[cs.FIELD_API_KEY] = cs.DEFAULT_API_KEY set_method(provider, model, **kwargs) diff --git a/codebase_rag/providers/base.py b/codebase_rag/providers/base.py index cca68a8c6..7cac65f0f 100644 --- a/codebase_rag/providers/base.py +++ b/codebase_rag/providers/base.py @@ -1,5 +1,6 @@ from __future__ import annotations +import os from abc import ABC, abstractmethod from urllib.parse import urljoin @@ -48,7 +49,7 @@ def __init__( **kwargs: str | int | None, ) -> None: super().__init__(**kwargs) - self.api_key = api_key + self.api_key = api_key or os.environ.get(cs.ENV_GOOGLE_API_KEY) self.provider_type = provider_type self.project_id = project_id self.region = region @@ -104,7 +105,7 @@ def __init__( **kwargs: str | int | None, ) -> None: super().__init__(**kwargs) - self.api_key = api_key + self.api_key = api_key or os.environ.get(cs.ENV_OPENAI_API_KEY) self.endpoint = endpoint @property diff --git a/codebase_rag/tests/test_model_switching.py b/codebase_rag/tests/test_model_switching.py new file mode 100644 index 000000000..17a79351f --- /dev/null +++ b/codebase_rag/tests/test_model_switching.py @@ -0,0 +1,473 @@ +from __future__ import annotations + +import re +from typing import TYPE_CHECKING +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from codebase_rag import constants as cs +from codebase_rag import exceptions as ex +from codebase_rag.config import ModelConfig +from codebase_rag.main import _create_model_from_string, _handle_model_command + +if TYPE_CHECKING: + from collections.abc import Generator + + +@pytest.fixture +def mock_console() -> Generator[MagicMock]: + with patch("codebase_rag.main.app_context") as mock_ctx: + mock_ctx.console = MagicMock() + yield mock_ctx.console + + +@pytest.fixture +def mock_settings() -> Generator[MagicMock]: + with patch("codebase_rag.main.settings") as mock_s: + mock_s.active_orchestrator_config = ModelConfig( + provider="google", model_id="gemini-2.0-flash" + ) + mock_s.parse_model_string.side_effect = lambda x: ( + x.split(":") if ":" in x else ("ollama", x) + ) + yield mock_s + + +class TestHandleModelCommand: + def test_show_current_model_when_no_argument( + self, mock_console: MagicMock, mock_settings: MagicMock + ) -> None: + mock_model = MagicMock() + mock_config = MagicMock() + new_model, new_string, new_config = _handle_model_command( + "/model", mock_model, "custom-model", mock_config + ) + + assert new_model == mock_model + assert new_string == "custom-model" + assert new_config == mock_config + mock_console.print.assert_called_once() + call_arg = mock_console.print.call_args[0][0] + assert "custom-model" in call_arg + + def test_show_default_model_when_no_override( + self, mock_console: MagicMock, mock_settings: MagicMock + ) -> None: + new_model, new_string, new_config = _handle_model_command( + "/model", None, None, None + ) + + assert new_model is None + assert new_string is None + assert new_config is None + mock_console.print.assert_called_once() + call_arg = mock_console.print.call_args[0][0] + assert "google:gemini-2.0-flash" in call_arg + + def test_switch_to_new_model( + self, mock_console: MagicMock, mock_settings: MagicMock + ) -> None: + mock_new_model = MagicMock() + mock_new_config = MagicMock() + with ( + patch("codebase_rag.main.logger") as mock_logger, + patch( + "codebase_rag.main._create_model_from_string", + return_value=(mock_new_model, "openai:gpt-4o", mock_new_config), + ), + ): + new_model, new_string, new_config = _handle_model_command( + "/model openai:gpt-4o", None, None, None + ) + + assert new_model == mock_new_model + assert new_string == "openai:gpt-4o" + assert new_config == mock_new_config + mock_console.print.assert_called_once() + call_arg = mock_console.print.call_args[0][0] + assert "openai:gpt-4o" in call_arg + mock_logger.info.assert_called_once() + + def test_switch_model_with_extra_whitespace( + self, mock_console: MagicMock, mock_settings: MagicMock + ) -> None: + mock_new_model = MagicMock() + mock_new_config = MagicMock() + with ( + patch("codebase_rag.main.logger"), + patch( + "codebase_rag.main._create_model_from_string", + return_value=( + mock_new_model, + "anthropic:claude-3-opus", + mock_new_config, + ), + ), + ): + new_model, new_string, new_config = _handle_model_command( + "/model anthropic:claude-3-opus ", None, None, None + ) + + assert new_model == mock_new_model + assert new_string == "anthropic:claude-3-opus" + + def test_show_current_model_with_trailing_space( + self, mock_console: MagicMock, mock_settings: MagicMock + ) -> None: + new_model, new_string, new_config = _handle_model_command( + "/model ", None, None, None + ) + + assert new_model is None + assert new_string is None + assert new_config is None + mock_console.print.assert_called_once() + call_arg = mock_console.print.call_args[0][0] + assert "google:gemini-2.0-flash" in call_arg + + def test_preserves_previous_model_on_show( + self, mock_console: MagicMock, mock_settings: MagicMock + ) -> None: + mock_model = MagicMock() + mock_config = MagicMock() + new_model, new_string, new_config = _handle_model_command( + "/model", mock_model, "previous:model", mock_config + ) + + assert new_model == mock_model + assert new_string == "previous:model" + assert new_config == mock_config + + def test_model_creation_error_shows_error_message( + self, mock_console: MagicMock, mock_settings: MagicMock + ) -> None: + with ( + patch("codebase_rag.main.logger") as mock_logger, + patch( + "codebase_rag.main._create_model_from_string", + side_effect=ValueError("Invalid model"), + ), + ): + new_model, new_string, new_config = _handle_model_command( + "/model invalid:model", None, None, None + ) + + assert new_model is None + assert new_string is None + assert new_config is None + mock_logger.error.assert_called_once() + call_arg = mock_console.print.call_args[0][0] + assert "Invalid model" in call_arg + + +class TestModelOverrideInAgentLoop: + @pytest.mark.asyncio + async def test_model_override_passed_to_agent_run(self) -> None: + from codebase_rag.main import _run_agent_response_loop + from codebase_rag.types_defs import CHAT_LOOP_UI, ConfirmationToolNames + + mock_agent = MagicMock() + mock_response = MagicMock() + mock_response.output = "Test response" + mock_response.new_messages.return_value = [] + mock_agent.run = AsyncMock(return_value=mock_response) + + mock_model = MagicMock() + tool_names = ConfirmationToolNames( + replace_code="replace", create_file="create", shell_command="shell" + ) + + with ( + patch("codebase_rag.main.app_context") as mock_ctx, + patch("codebase_rag.main.log_session_event"), + ): + mock_ctx.console.status.return_value.__enter__ = MagicMock() + mock_ctx.console.status.return_value.__exit__ = MagicMock() + mock_ctx.console.print = MagicMock() + + await _run_agent_response_loop( + mock_agent, + [], + "test question", + CHAT_LOOP_UI, + tool_names, + model_override=mock_model, + ) + + mock_agent.run.assert_called_once() + _, kwargs = mock_agent.run.call_args + assert kwargs.get("model") is mock_model + + @pytest.mark.asyncio + async def test_model_override_none_by_default(self) -> None: + from codebase_rag.main import _run_agent_response_loop + from codebase_rag.types_defs import CHAT_LOOP_UI, ConfirmationToolNames + + mock_agent = MagicMock() + mock_response = MagicMock() + mock_response.output = "Test response" + mock_response.new_messages.return_value = [] + mock_agent.run = AsyncMock(return_value=mock_response) + + tool_names = ConfirmationToolNames( + replace_code="replace", create_file="create", shell_command="shell" + ) + + with ( + patch("codebase_rag.main.app_context") as mock_ctx, + patch("codebase_rag.main.log_session_event"), + ): + mock_ctx.console.status.return_value.__enter__ = MagicMock() + mock_ctx.console.status.return_value.__exit__ = MagicMock() + mock_ctx.console.print = MagicMock() + + await _run_agent_response_loop( + mock_agent, + [], + "test question", + CHAT_LOOP_UI, + tool_names, + ) + + mock_agent.run.assert_called_once() + _, kwargs = mock_agent.run.call_args + assert kwargs.get("model") is None + + +class TestCommandConstants: + def test_model_command_prefix(self) -> None: + assert cs.MODEL_COMMAND_PREFIX == "/model" + + def test_help_command(self) -> None: + assert cs.HELP_COMMAND == "/help" + + def test_ui_messages_exist(self) -> None: + assert hasattr(cs, "UI_MODEL_SWITCHED") + assert hasattr(cs, "UI_MODEL_CURRENT") + assert hasattr(cs, "UI_MODEL_USAGE") + assert hasattr(cs, "UI_HELP_COMMANDS") + + def test_ui_model_switched_format(self) -> None: + result = cs.UI_MODEL_SWITCHED.format(model="test-model") + assert "test-model" in result + + def test_ui_model_current_format(self) -> None: + result = cs.UI_MODEL_CURRENT.format(model="current-model") + assert "current-model" in result + + +class TestMultipleModelSwitches: + def test_multiple_switches_in_sequence( + self, mock_console: MagicMock, mock_settings: MagicMock + ) -> None: + mock_model_a = MagicMock(name="model-a") + mock_model_b = MagicMock(name="model-b") + mock_model_c = MagicMock(name="model-c") + mock_config_a = MagicMock(name="config-a") + mock_config_b = MagicMock(name="config-b") + mock_config_c = MagicMock(name="config-c") + + with ( + patch("codebase_rag.main.logger"), + patch("codebase_rag.main._create_model_from_string") as mock_create, + ): + mock_create.return_value = (mock_model_a, "ollama:model-a", mock_config_a) + model, model_str, config = _handle_model_command( + "/model ollama:model-a", None, None, None + ) + assert model == mock_model_a + assert model_str == "ollama:model-a" + + mock_create.return_value = (mock_model_b, "ollama:model-b", mock_config_b) + model, model_str, config = _handle_model_command( + "/model ollama:model-b", model, model_str, config + ) + assert model == mock_model_b + assert model_str == "ollama:model-b" + + mock_create.return_value = (mock_model_c, "ollama:model-c", mock_config_c) + model, model_str, config = _handle_model_command( + "/model ollama:model-c", model, model_str, config + ) + assert model == mock_model_c + assert model_str == "ollama:model-c" + + model, model_str, config = _handle_model_command( + "/model", model, model_str, config + ) + assert model == mock_model_c + assert model_str == "ollama:model-c" + + def test_switch_then_show_preserves_model( + self, mock_console: MagicMock, mock_settings: MagicMock + ) -> None: + mock_model = MagicMock() + mock_config = MagicMock() + with ( + patch("codebase_rag.main.logger"), + patch( + "codebase_rag.main._create_model_from_string", + return_value=(mock_model, "openai:gpt-4", mock_config), + ), + ): + model, model_str, config = _handle_model_command( + "/model openai:gpt-4", None, None, None + ) + assert model == mock_model + assert model_str == "openai:gpt-4" + + model, model_str, config = _handle_model_command( + "/model", model, model_str, config + ) + assert model == mock_model + assert model_str == "openai:gpt-4" + call_arg = mock_console.print.call_args[0][0] + assert "openai:gpt-4" in call_arg + + +class TestModelHelpCommand: + def test_model_help_shows_usage( + self, mock_console: MagicMock, mock_settings: MagicMock + ) -> None: + new_model, new_string, new_config = _handle_model_command( + "/model help", None, None, None + ) + + assert new_model is None + assert new_string is None + assert new_config is None + mock_console.print.assert_called_once() + call_arg = mock_console.print.call_args[0][0] + assert "Usage:" in call_arg or "provider:model" in call_arg + + def test_model_help_case_insensitive( + self, mock_console: MagicMock, mock_settings: MagicMock + ) -> None: + new_model, new_string, new_config = _handle_model_command( + "/model HELP", None, None, None + ) + + assert new_model is None + assert new_string is None + assert new_config is None + mock_console.print.assert_called_once() + + def test_model_help_preserves_current_model( + self, mock_console: MagicMock, mock_settings: MagicMock + ) -> None: + mock_model = MagicMock() + mock_config = MagicMock() + new_model, new_string, new_config = _handle_model_command( + "/model help", mock_model, "current:model", mock_config + ) + + assert new_model == mock_model + assert new_string == "current:model" + assert new_config == mock_config + + +class TestCreateModelFromString: + def test_missing_colon_raises_format_error(self, mock_settings: MagicMock) -> None: + with pytest.raises(ValueError, match=re.escape(ex.MODEL_FORMAT_INVALID)): + _create_model_from_string("modelwithoutcolon") + + def test_empty_model_id_raises_error(self, mock_settings: MagicMock) -> None: + with pytest.raises(ValueError, match=ex.MODEL_ID_EMPTY): + _create_model_from_string("openai:") + + def test_empty_provider_raises_error(self, mock_settings: MagicMock) -> None: + with pytest.raises(ValueError, match=ex.PROVIDER_EMPTY): + _create_model_from_string(":gpt-4o") + + def test_whitespace_around_colon_is_stripped( + self, mock_settings: MagicMock + ) -> None: + mock_model = MagicMock() + with patch("codebase_rag.main.get_provider_from_config") as mock_get_provider: + mock_provider = MagicMock() + mock_provider.create_model.return_value = mock_model + mock_get_provider.return_value = mock_provider + + model, canonical, config = _create_model_from_string("openai : gpt-4o") + + assert canonical == "openai:gpt-4o" + mock_provider.create_model.assert_called_once_with("gpt-4o") + + def test_invalid_provider_raises_error(self, mock_settings: MagicMock) -> None: + with patch( + "codebase_rag.main.get_provider_from_config", + side_effect=ValueError("Unknown provider"), + ): + with pytest.raises(ValueError, match="Unknown provider"): + _create_model_from_string("invalid:model") + + def test_same_provider_uses_current_config(self, mock_settings: MagicMock) -> None: + mock_model = MagicMock() + with patch("codebase_rag.main.get_provider_from_config") as mock_get_provider: + mock_provider = MagicMock() + mock_provider.create_model.return_value = mock_model + mock_get_provider.return_value = mock_provider + + model, canonical, config = _create_model_from_string("google:gemini-pro") + + assert canonical == "google:gemini-pro" + + def test_ollama_provider_uses_local_endpoint( + self, mock_settings: MagicMock + ) -> None: + mock_model = MagicMock() + mock_settings.LOCAL_MODEL_ENDPOINT = "http://localhost:11434/v1" + + with patch("codebase_rag.main.get_provider_from_config") as mock_get_provider: + mock_provider = MagicMock() + mock_provider.create_model.return_value = mock_model + mock_get_provider.return_value = mock_provider + + model, canonical, config = _create_model_from_string("ollama:llama3") + + assert canonical == "ollama:llama3" + + +class TestModelCommandEdgeCases: + def test_assertion_error_is_caught( + self, mock_console: MagicMock, mock_settings: MagicMock + ) -> None: + with ( + patch("codebase_rag.main.logger") as mock_logger, + patch( + "codebase_rag.main._create_model_from_string", + side_effect=AssertionError("Missing API key"), + ), + ): + new_model, new_string, new_config = _handle_model_command( + "/model openai:gpt-4o", None, None, None + ) + + assert new_model is None + assert new_string is None + assert new_config is None + mock_logger.error.assert_called_once() + call_arg = mock_console.print.call_args[0][0] + assert "Missing API key" in call_arg + + def test_value_error_is_caught( + self, mock_console: MagicMock, mock_settings: MagicMock + ) -> None: + with ( + patch("codebase_rag.main.logger") as mock_logger, + patch( + "codebase_rag.main._create_model_from_string", + side_effect=ValueError("Invalid configuration"), + ), + ): + new_model, new_string, new_config = _handle_model_command( + "/model bad:config", None, None, None + ) + + assert new_model is None + assert new_string is None + assert new_config is None + mock_logger.error.assert_called_once() + call_arg = mock_console.print.call_args[0][0] + assert "Invalid configuration" in call_arg diff --git a/uv.lock b/uv.lock index 7a18f23d3..eb58bb880 100644 --- a/uv.lock +++ b/uv.lock @@ -1026,7 +1026,7 @@ wheels = [ [[package]] name = "graph-code" -version = "0.0.24" +version = "0.0.25" source = { editable = "." } dependencies = [ { name = "click" },