Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
5f36e23
feat: add dynamic model switching via /model command
vitali87 Jan 10, 2026
91bae6c
refactor: simplify model switching with dataclasses.replace and speci…
vitali87 Jan 10, 2026
ecc99b2
refactor: return canonical provider:model_id format from model creation
vitali87 Jan 10, 2026
2ed0807
refactor: remove redundant boolean return value from _handle_model_co…
vitali87 Jan 10, 2026
2d72509
fix: load provider API keys from environment variables when not expli…
vitali87 Jan 10, 2026
5a7ad72
fix: validate model_id is not empty to prevent invalid model configur…
vitali87 Jan 10, 2026
085c918
feat: add /model help to display usage instructions
vitali87 Jan 10, 2026
3d57132
refactor: rename MODEL_COMMAND_HELP to HELP_COMMAND for clarity
vitali87 Jan 10, 2026
6e8d4f7
fix: require provider in model string format (provider:model)
vitali87 Jan 10, 2026
a52bae0
refactor: consolidate command handling logic to reduce repetition
vitali87 Jan 10, 2026
6c6d019
fix: use exact command matching and rename misleading test
vitali87 Jan 10, 2026
5a67501
fix: use DEFAULT_API_KEY constant for Ollama api_key field
vitali87 Jan 10, 2026
8aa3e33
test: add console output assertion to trailing space test
vitali87 Jan 10, 2026
eae9259
fix: use CHAR_COLON constant and correct provider name in example
vitali87 Jan 10, 2026
2670bb6
refactor: simplify command handling with direct continue statements
vitali87 Jan 10, 2026
471153e
fix: catch specific exceptions (ValueError, AssertionError) in model …
vitali87 Jan 10, 2026
29dc51e
refactor: consolidate argument parsing in model command handler
vitali87 Jan 10, 2026
f01aca8
fix: strip both provider and model_id when parsing model string
vitali87 Jan 10, 2026
14ede30
test: add unit tests for model switching and fix empty provider valid…
vitali87 Jan 10, 2026
e380d74
refactor: inline redundant new_model_string_arg variable
vitali87 Jan 10, 2026
44a8eb9
fix: track config when switching models to preserve settings for same…
vitali87 Jan 10, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions codebase_rag/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,11 @@ class FileAction(StrEnum):
DEFAULT_MODEL = "llama3.2"
DEFAULT_API_KEY = "ollama"

ENV_OPENAI_API_KEY = "OPENAI_API_KEY"
ENV_GOOGLE_API_KEY = "GOOGLE_API_KEY"

HELP_ARG = "help"


class GoogleProviderType(StrEnum):
GLA = "gla"
Expand Down Expand Up @@ -254,6 +259,15 @@ class GoogleProviderType(StrEnum):
UI_GRAPH_EXPORT_STATS = "[bold cyan]Export contains {nodes} nodes and {relationships} relationships[/bold cyan]"
UI_ERR_UNEXPECTED = "[bold red]An unexpected error occurred: {error}[/bold red]"
UI_ERR_EXPORT_FAILED = "[bold red]Failed to export graph: {error}[/bold red]"
UI_MODEL_SWITCHED = "[bold green]Model switched to: {model}[/bold green]"
UI_MODEL_CURRENT = "[bold cyan]Current model: {model}[/bold cyan]"
UI_MODEL_SWITCH_ERROR = "[bold red]Failed to switch model: {error}[/bold red]"
UI_MODEL_USAGE = "[bold yellow]Usage: /model <provider:model> (e.g., /model google:gemini-2.0-flash)[/bold yellow]"
UI_HELP_COMMANDS = """[bold cyan]Available commands:[/bold cyan]
/model <provider:model> - Switch to a different model
/model - Show current model
/help - Show this help
exit, quit - Exit the session"""
UI_TOOL_ARGS_FORMAT = " Arguments: {args}"
UI_REFERENCE_DOC_INFO = " using the reference document: {reference_document}"
UI_INPUT_PROMPT_HTML = (
Expand Down Expand Up @@ -555,6 +569,10 @@ class LanguageMetadata(NamedTuple):
# (H) CLI exit commands
EXIT_COMMANDS = frozenset({"exit", "quit"})

# (H) CLI commands
MODEL_COMMAND_PREFIX = "/model"
HELP_COMMAND = "/help"

# (H) UI separators and formatting
HORIZONTAL_SEPARATOR = "─" * 60

Expand Down
4 changes: 4 additions & 0 deletions codebase_rag/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@

# (H) Configuration errors
PROVIDER_EMPTY = "Provider name cannot be empty in 'provider:model' format."
MODEL_ID_EMPTY = "Model ID cannot be empty."
MODEL_FORMAT_INVALID = (
"Model must be specified as 'provider:model' (e.g., openai:gpt-4o)."
)
BATCH_SIZE_POSITIVE = "batch_size must be a positive integer"
CONFIG = "{role} configuration error: {error}"

Expand Down
5 changes: 5 additions & 0 deletions codebase_rag/logs.py
Original file line number Diff line number Diff line change
Expand Up @@ -614,3 +614,8 @@
# (H) Exclude prompt logs
EXCLUDE_INVALID_INDEX = "Invalid index: {index} (out of range)"
EXCLUDE_INVALID_INPUT = "Invalid input: '{input}' (expected number)"

# (H) Model switching logs
MODEL_SWITCHED = "Model switched to: {model}"
MODEL_SWITCH_FAILED = "Failed to switch model: {error}"
MODEL_CURRENT = "Current model: {model}"
115 changes: 108 additions & 7 deletions codebase_rag/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import uuid
from collections import deque
from collections.abc import Coroutine
from dataclasses import replace
from pathlib import Path
from typing import TYPE_CHECKING

Expand All @@ -27,9 +28,10 @@
from . import constants as cs
from . import exceptions as ex
from . import logs as ls
from .config import load_cgrignore_patterns, settings
from .config import ModelConfig, load_cgrignore_patterns, settings
from .models import AppContext
from .prompts import OPTIMIZATION_PROMPT, OPTIMIZATION_PROMPT_WITH_REFERENCE
from .providers.base import get_provider_from_config
from .services import QueryProtocol
from .services.graph_service import MemgraphIngestor
from .services.llm import CypherGenerator, create_rag_orchestrator
Expand Down Expand Up @@ -64,8 +66,7 @@
from prompt_toolkit.key_binding import KeyPressEvent
from pydantic_ai import Agent
from pydantic_ai.messages import ModelMessage

from .config import ModelConfig
from pydantic_ai.models import Model


def style(
Expand Down Expand Up @@ -389,6 +390,7 @@ async def _run_agent_response_loop(
question_with_context: str,
config: AgentLoopUI,
tool_names: ConfirmationToolNames,
model_override: Model | None = None,
) -> None:
deferred_results: DeferredToolResults | None = None

Expand All @@ -399,6 +401,7 @@ async def _run_agent_response_loop(
question_with_context,
message_history=message_history,
deferred_tool_results=deferred_results,
model=model_override,
),
)

Expand Down Expand Up @@ -529,6 +532,75 @@ def keyboard_interrupt(event: KeyPressEvent) -> None:
return stripped


def _create_model_from_string(
    model_string: str, current_override_config: ModelConfig | None = None
) -> tuple[Model, str, ModelConfig]:
    """Build a model instance from a ``provider:model`` string.

    Returns a tuple of (model instance, canonical ``provider:model`` string,
    the ``ModelConfig`` used to build it) so the caller can track the active
    override. Raises ``ValueError`` when the string is malformed or either
    part is empty.
    """
    # Fall back to the configured orchestrator model when no override exists.
    fallback = current_override_config or settings.active_orchestrator_config

    if cs.CHAR_COLON not in model_string:
        raise ValueError(ex.MODEL_FORMAT_INVALID)

    provider_name, model_id = (
        part.strip() for part in settings.parse_model_string(model_string)
    )
    # Validate model id first, then provider, preserving the error precedence.
    if not model_id:
        raise ValueError(ex.MODEL_ID_EMPTY)
    if not provider_name:
        raise ValueError(ex.PROVIDER_EMPTY)

    if provider_name == fallback.provider:
        # Same provider: keep endpoint/API key, swap only the model id.
        config = replace(fallback, model_id=model_id)
    elif provider_name == cs.Provider.OLLAMA:
        # Local Ollama requires its endpoint and a placeholder API key.
        config = ModelConfig(
            provider=provider_name,
            model_id=model_id,
            endpoint=str(settings.LOCAL_MODEL_ENDPOINT),
            api_key=cs.DEFAULT_API_KEY,
        )
    else:
        config = ModelConfig(provider=provider_name, model_id=model_id)

    canonical = f"{provider_name}{cs.CHAR_COLON}{model_id}"
    model = get_provider_from_config(config).create_model(model_id)
    return model, canonical, config


def _handle_model_command(
    command: str,
    current_model: Model | None,
    current_model_string: str | None,
    current_config: ModelConfig | None,
) -> tuple[Model | None, str | None, ModelConfig | None]:
    """Process a ``/model`` command and return the (possibly updated) override.

    Behavior:
      * no argument  -> print the active model (override or configured default)
      * ``help``     -> print usage instructions
      * otherwise    -> try to switch to the given ``provider:model``

    On failure the previous (model, string, config) triple is returned
    unchanged, so the caller keeps its current override state.
    """
    unchanged = (current_model, current_model_string, current_config)

    parts = command.strip().split(maxsplit=1)
    arg = parts[1].strip() if len(parts) > 1 else None

    if not arg:
        if current_model_string:
            display_model = current_model_string
        else:
            # No override tracked yet: show the configured orchestrator model.
            config = settings.active_orchestrator_config
            display_model = f"{config.provider}{cs.CHAR_COLON}{config.model_id}"
        app_context.console.print(cs.UI_MODEL_CURRENT.format(model=display_model))
        return unchanged

    if arg.lower() == cs.HELP_ARG:
        app_context.console.print(cs.UI_MODEL_USAGE)
        return unchanged

    # Keep the try body minimal: only model creation can legitimately fail
    # with ValueError/AssertionError. Success-path logging and printing must
    # not be swallowed as a "switch failed" error.
    try:
        new_model, canonical_model_string, new_config = _create_model_from_string(
            arg, current_config
        )
    except (ValueError, AssertionError) as e:
        logger.error(ls.MODEL_SWITCH_FAILED.format(error=e))
        app_context.console.print(cs.UI_MODEL_SWITCH_ERROR.format(error=e))
        return unchanged

    logger.info(ls.MODEL_SWITCHED.format(model=canonical_model_string))
    app_context.console.print(
        cs.UI_MODEL_SWITCHED.format(model=canonical_model_string)
    )
    return new_model, canonical_model_string, new_config


async def _run_interactive_loop(
rag_agent: Agent[None, str | DeferredToolRequests],
message_history: list[ModelMessage],
Expand All @@ -540,15 +612,39 @@ async def _run_interactive_loop(
) -> None:
init_session_log(project_root)
question = initial_question or ""
model_override: Model | None = None
model_override_string: str | None = None
model_override_config: ModelConfig | None = None

while True:
try:
if not initial_question or question != initial_question:
question = await asyncio.to_thread(get_multiline_input, input_prompt)

if question.lower() in cs.EXIT_COMMANDS:
stripped_question = question.strip()
stripped_lower = stripped_question.lower()

if stripped_lower in cs.EXIT_COMMANDS:
break
if not question.strip():

if not stripped_question:
initial_question = None
continue

command_parts = stripped_lower.split(maxsplit=1)
if command_parts[0] == cs.MODEL_COMMAND_PREFIX:
model_override, model_override_string, model_override_config = (
_handle_model_command(
stripped_question,
model_override,
model_override_string,
model_override_config,
)
)
initial_question = None
continue
if command_parts[0] == cs.HELP_COMMAND:
app_context.console.print(cs.UI_HELP_COMMANDS)
initial_question = None
continue

Expand All @@ -565,7 +661,12 @@ async def _run_interactive_loop(
)

await _run_agent_response_loop(
rag_agent, message_history, question_with_context, config, tool_names
rag_agent,
message_history,
question_with_context,
config,
tool_names,
model_override,
)

initial_question = None
Expand Down Expand Up @@ -608,7 +709,7 @@ def _update_single_model_setting(role: cs.ModelRole, model_string: str) -> None:

if provider == cs.Provider.OLLAMA and not kwargs[cs.FIELD_ENDPOINT]:
kwargs[cs.FIELD_ENDPOINT] = str(settings.LOCAL_MODEL_ENDPOINT)
kwargs[cs.FIELD_API_KEY] = cs.Provider.OLLAMA
kwargs[cs.FIELD_API_KEY] = cs.DEFAULT_API_KEY

set_method(provider, model, **kwargs)

Expand Down
5 changes: 3 additions & 2 deletions codebase_rag/providers/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import os
from abc import ABC, abstractmethod
from urllib.parse import urljoin

Expand Down Expand Up @@ -48,7 +49,7 @@ def __init__(
**kwargs: str | int | None,
) -> None:
super().__init__(**kwargs)
self.api_key = api_key
self.api_key = api_key or os.environ.get(cs.ENV_GOOGLE_API_KEY)
self.provider_type = provider_type
self.project_id = project_id
self.region = region
Expand Down Expand Up @@ -104,7 +105,7 @@ def __init__(
**kwargs: str | int | None,
) -> None:
super().__init__(**kwargs)
self.api_key = api_key
self.api_key = api_key or os.environ.get(cs.ENV_OPENAI_API_KEY)
self.endpoint = endpoint

@property
Expand Down
Loading