diff --git a/codebase_rag/deps.py b/codebase_rag/deps.py new file mode 100644 index 000000000..a599e3a64 --- /dev/null +++ b/codebase_rag/deps.py @@ -0,0 +1,22 @@ +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +from rich.console import Console + +from .services import QueryProtocol + + +@dataclass +class RAGDeps: + project_root: Path + ingestor: QueryProtocol + cypher_generator: Any + code_retriever: Any + file_reader: Any + file_writer: Any + file_editor: Any + shell_commander: Any + directory_lister: Any + document_analyzer: Any + console: Console diff --git a/codebase_rag/exceptions.py b/codebase_rag/exceptions.py new file mode 100644 index 000000000..adb608914 --- /dev/null +++ b/codebase_rag/exceptions.py @@ -0,0 +1,2 @@ +class LLMGenerationError(Exception): + pass diff --git a/codebase_rag/main.py b/codebase_rag/main.py index 30f5d935d..1d9bd5ac4 100644 --- a/codebase_rag/main.py +++ b/codebase_rag/main.py @@ -26,25 +26,20 @@ ORANGE_STYLE, settings, ) +from .deps import RAGDeps from .graph_updater import GraphUpdater from .parser_loader import load_parsers from .services import QueryProtocol from .services.graph_service import MemgraphIngestor from .services.llm import CypherGenerator, create_rag_orchestrator -from .services.protobuf_service import ProtobufFileIngestor -from .tools.code_retrieval import CodeRetriever, create_code_retrieval_tool -from .tools.codebase_query import create_query_tool -from .tools.directory_lister import DirectoryLister, create_directory_lister_tool -from .tools.document_analyzer import DocumentAnalyzer, create_document_analyzer_tool -from .tools.file_editor import FileEditor, create_file_editor_tool -from .tools.file_reader import FileReader, create_file_reader_tool -from .tools.file_writer import FileWriter, create_file_writer_tool +from .tools.code_retrieval import CodeRetriever +from .tools.directory_lister import DirectoryLister +from .tools.document_analyzer import DocumentAnalyzer +from .tools.file_editor import FileEditor +from .tools.file_reader import FileReader +from .tools.file_writer import FileWriter from .tools.language import cli as language_cli -from .tools.semantic_search import ( - create_get_function_source_tool, - create_semantic_search_tool, -) -from .tools.shell_command import ShellCommander, create_shell_command_tool +from .tools.shell_command import ShellCommander confirm_edits_globally = True @@ -65,7 +60,6 @@ def init_session_log(project_root: Path) -> Path: - """Initialize session log file.""" global session_log_file log_dir = project_root / ".tmp" log_dir.mkdir(exist_ok=True) @@ -76,7 +70,6 @@ def init_session_log(project_root: Path) -> Path: def log_session_event(event: str) -> None: - """Log an event to the session file.""" global session_log_file if session_log_file: with open(session_log_file, "a") as f: @@ -84,7 +77,6 @@ def log_session_event(event: str) -> None: def get_session_context() -> str: - """Get the full session context for cancelled operations.""" global session_log_file if session_log_file and session_log_file.exists(): content = Path(session_log_file).read_text() @@ -138,7 +130,7 @@ def _display_tool_call_diff( console.print("[dim]" + "─" * 60 + "[/dim]") - elif tool_name == "execute_shell_command": + elif tool_name == "run_shell_command": command = tool_args.get("command", "") console.print("\n[bold cyan]Shell command:[/bold cyan]") console.print(f"[yellow]$ {command}[/yellow]") @@ -148,7 +140,6 @@ def _display_tool_call_diff( def _setup_common_initialization(repo_path: str) -> Path: - """Common setup logic for both main and optimize functions.""" logger.remove() logger.add(sys.stdout, format="{time:YYYY-MM-DD HH:mm:ss.SSS} | {message}") @@ -169,7 +160,6 @@ def _create_configuration_table( title: str = "Graph-Code Initializing...", language: str | None = None, ) -> Table: - """Create and return a configuration table.""" table = Table(title=f"[bold green]{title}[/bold green]") table.add_column("Configuration", style="cyan") table.add_column("Value", style="magenta") @@ -216,12 +206,12 @@ def _create_configuration_table( async def run_optimization_loop( rag_agent: Any, + deps: RAGDeps, message_history: list[Any], project_root: Path, language: str, reference_document: str | None = None, ) -> None: - """Runs the optimization loop with proper confirmation handling.""" global session_cancelled init_session_log(project_root) @@ -311,6 +301,7 @@ async def run_optimization_loop( console, rag_agent.run( question_with_context, + deps=deps, message_history=message_history, deferred_tool_results=deferred_results, ), @@ -381,7 +372,6 @@ async def run_optimization_loop( async def run_with_cancellation( console: Console, coro: Any, timeout: float | None = None ) -> Any: - """Run a coroutine with proper Ctrl+C cancellation that doesn't exit the program.""" task = asyncio.create_task(coro) try: @@ -408,10 +398,6 @@ async def run_with_cancellation( def _handle_chat_images(question: str, project_root: Path) -> str: - """ - Checks for image file paths in the question, copies them to a temporary - directory, and replaces the path in the question. - """ try: tokens = shlex.split(question) except ValueError: @@ -472,22 +458,18 @@ def _handle_chat_images(question: str, project_root: Path) -> str: def get_multiline_input(prompt_text: str = "Ask a question") -> str: - """Get multiline input from user with Ctrl+J to submit.""" bindings = KeyBindings() @bindings.add("c-j") def submit(event: Any) -> None: - """Submit the current input.""" event.app.exit(result=event.app.current_buffer.text) @bindings.add("enter") def new_line(event: Any) -> None: - """Insert a new line instead of submitting.""" event.current_buffer.insert_text("\n") @bindings.add("c-c") def keyboard_interrupt(event: Any) -> None: - """Handle Ctrl+C.""" event.app.exit(exception=KeyboardInterrupt) clean_prompt = Text.from_markup(prompt_text).plain @@ -511,7 +493,7 @@ def keyboard_interrupt(event: Any) -> None: async def run_chat_loop( - rag_agent: Any, message_history: list[Any], project_root: Path + rag_agent: Any, deps: RAGDeps, message_history: list[Any], project_root: Path ) -> None: global session_cancelled @@ -550,6 +532,7 @@ async def run_chat_loop( console, rag_agent.run( question_with_context, + deps=deps, message_history=message_history, deferred_tool_results=deferred_results, ), @@ -616,7 +599,6 @@ async def run_chat_loop( def _update_single_model_setting(role: str, model_string: str) -> None: - """Update a single model setting (orchestrator or cypher).""" provider, model = settings.parse_model_string(model_string) if role == "orchestrator": @@ -647,7 +629,6 @@ def _update_model_settings( orchestrator: str | None, cypher: str | None, ) -> None: - """Update model settings based on command-line arguments.""" if orchestrator: _update_single_model_setting("orchestrator", orchestrator) if cypher: @@ -655,17 +636,6 @@ def _update_model_settings( def _export_graph_to_file(ingestor: MemgraphIngestor, output: str) -> bool: - """ - Export graph data to a JSON file. - - Args: - ingestor: The MemgraphIngestor instance to export from - output: Output file path - - Returns: - True if export was successful, False otherwise - """ - try: graph_data = ingestor.export_graph_to_dict() output_path = Path(output) @@ -689,12 +659,39 @@ def _export_graph_to_file(ingestor: MemgraphIngestor, output: str) -> bool: return False -def _initialize_services_and_agent(repo_path: str, ingestor: QueryProtocol) -> Any: - """Initializes all services and creates the RAG agent.""" +def _create_deps(repo_path: str, ingestor: QueryProtocol) -> RAGDeps: + cypher_generator = CypherGenerator() + code_retriever = CodeRetriever(project_root=repo_path, ingestor=ingestor) + file_reader = FileReader(project_root=repo_path) + file_writer = FileWriter(project_root=repo_path) + file_editor = FileEditor(project_root=repo_path) + shell_commander = ShellCommander( + project_root=repo_path, timeout=settings.SHELL_COMMAND_TIMEOUT + ) + directory_lister = DirectoryLister(project_root=repo_path) + document_analyzer = DocumentAnalyzer(project_root=repo_path) + + return RAGDeps( + project_root=Path(repo_path).resolve(), + ingestor=ingestor, + cypher_generator=cypher_generator, + code_retriever=code_retriever, + file_reader=file_reader, + file_writer=file_writer, + file_editor=file_editor, + shell_commander=shell_commander, + directory_lister=directory_lister, + document_analyzer=document_analyzer, + console=console, + ) + + +def _initialize_services_and_agent( + repo_path: str, ingestor: QueryProtocol +) -> tuple[Any, RAGDeps]: from .providers.base import get_provider def _validate_provider_config(role: str, config: Any) -> None: - """Validate a single provider configuration.""" try: provider = get_provider( config.provider, @@ -713,47 +710,13 @@ def _validate_provider_config(role: str, config: Any) -> None: _validate_provider_config("orchestrator", settings.active_orchestrator_config) _validate_provider_config("cypher", settings.active_cypher_config) - cypher_generator = CypherGenerator() - code_retriever = CodeRetriever(project_root=repo_path, ingestor=ingestor) - file_reader = FileReader(project_root=repo_path) - file_writer = FileWriter(project_root=repo_path) - file_editor = FileEditor(project_root=repo_path) - shell_commander = ShellCommander( - project_root=repo_path, timeout=settings.SHELL_COMMAND_TIMEOUT - ) - directory_lister = DirectoryLister(project_root=repo_path) - document_analyzer = DocumentAnalyzer(project_root=repo_path) + deps = _create_deps(repo_path, ingestor) + rag_agent = create_rag_orchestrator() - query_tool = create_query_tool(ingestor, cypher_generator, console) - code_tool = create_code_retrieval_tool(code_retriever) - file_reader_tool = create_file_reader_tool(file_reader) - file_writer_tool = create_file_writer_tool(file_writer) - file_editor_tool = create_file_editor_tool(file_editor) - shell_command_tool = create_shell_command_tool(shell_commander) - directory_lister_tool = create_directory_lister_tool(directory_lister) - document_analyzer_tool = create_document_analyzer_tool(document_analyzer) - semantic_search_tool = create_semantic_search_tool() - function_source_tool = create_get_function_source_tool() - - rag_agent = create_rag_orchestrator( - tools=[ - query_tool, - code_tool, - file_reader_tool, - file_writer_tool, - file_editor_tool, - shell_command_tool, - directory_lister_tool, - document_analyzer_tool, - semantic_search_tool, - function_source_tool, - ] - ) - return rag_agent + return rag_agent, deps async def main_async(repo_path: str, batch_size: int) -> None: - """Initializes services and runs the main application loop.""" project_root = _setup_common_initialization(repo_path) table = _create_configuration_table(repo_path) @@ -772,8 +735,8 @@ async def main_async(repo_path: str, batch_size: int) -> None: ) ) - rag_agent = _initialize_services_and_agent(repo_path, ingestor) - await run_chat_loop(rag_agent, [], project_root) + rag_agent, deps = _initialize_services_and_agent(repo_path, ingestor) + await run_chat_loop(rag_agent, deps, [], project_root) @app.command() @@ -819,7 +782,6 @@ def start( help="Number of buffered nodes/relationships before flushing to Memgraph", ), ) -> None: - """Starts the Codebase RAG CLI.""" global confirm_edits_globally confirm_edits_globally = not no_confirm @@ -890,7 +852,6 @@ def index( help="Write index to separate nodes.bin and relationships.bin files.", ), ) -> None: - """Parses a codebase and creates a portable binary index file.""" target_repo_path = repo_path or settings.TARGET_REPO_PATH repo_to_index = Path(target_repo_path) @@ -900,6 +861,8 @@ def index( ) try: + from .services.protobuf_service import ProtobufFileIngestor + ingestor = ProtobufFileIngestor( output_path=output_proto_dir, split_index=split_index ) @@ -932,7 +895,6 @@ def export( help="Number of buffered nodes/relationships before flushing to Memgraph", ), ) -> None: - """Export the current knowledge graph to a file.""" if not format_json: console.print( "[bold red]Error: Currently only JSON format is supported.[/bold red]" @@ -967,7 +929,6 @@ async def main_optimize_async( cypher: str | None = None, batch_size: int | None = None, ) -> None: - """Async wrapper for the optimization functionality.""" project_root = _setup_common_initialization(target_repo_path) _update_model_settings(orchestrator, cypher) @@ -990,9 +951,9 @@ async def main_optimize_async( ) as ingestor: console.print("[bold green]Successfully connected to Memgraph.[/bold green]") - rag_agent = _initialize_services_and_agent(target_repo_path, ingestor) + rag_agent, deps = _initialize_services_and_agent(target_repo_path, ingestor) await run_optimization_loop( - rag_agent, [], project_root, language, reference_document + rag_agent, deps, [], project_root, language, reference_document ) @@ -1032,7 +993,6 @@ def optimize( help="Number of buffered nodes/relationships before flushing to Memgraph", ), ) -> None: - """Optimize a codebase for a specific programming language.""" global confirm_edits_globally confirm_edits_globally = not no_confirm @@ -1058,23 +1018,6 @@ def optimize( @app.command(name="mcp-server") def mcp_server() -> None: - """Start the MCP (Model Context Protocol) server. - - This command starts an MCP server that exposes code-graph-rag's capabilities - to MCP clients like Claude Code. The server runs on stdio transport and requires - the TARGET_REPO_PATH environment variable to be set to the target repository. - - Usage: - graph-code mcp-server - - Environment Variables: - TARGET_REPO_PATH: Path to the target repository (required) - - For Claude Code integration: - claude mcp add --transport stdio graph-code \\ - --env TARGET_REPO_PATH=/path/to/your/project \\ - -- uv run --directory /path/to/code-graph-rag graph-code mcp-server - """ try: from codebase_rag.mcp import main as mcp_main @@ -1094,7 +1037,6 @@ def mcp_server() -> None: def graph_loader_command( graph_file: str = typer.Argument(..., help="Path to the exported graph JSON file"), ) -> None: - """Load and display summary of an exported graph file.""" from .graph_loader import load_graph try: @@ -1120,13 +1062,6 @@ def graph_loader_command( context_settings={"allow_extra_args": True, "allow_interspersed_args": False}, ) def language_command(ctx: typer.Context) -> None: - """Manage language grammars (add, remove, list). - - Examples: - cgr language add-grammar python - cgr language list-languages - cgr language remove-language python - """ language_cli(ctx.args, standalone_mode=False) diff --git a/codebase_rag/mcp/tools.py b/codebase_rag/mcp/tools.py index 14db27368..34626713d 100644 --- a/codebase_rag/mcp/tools.py +++ b/codebase_rag/mcp/tools.py @@ -6,25 +6,21 @@ from loguru import logger +from codebase_rag.exceptions import LLMGenerationError from codebase_rag.graph_updater import GraphUpdater from codebase_rag.parser_loader import load_parsers +from codebase_rag.schemas import CodeSnippet, GraphData from codebase_rag.services.graph_service import MemgraphIngestor from codebase_rag.services.llm import CypherGenerator -from codebase_rag.tools.code_retrieval import CodeRetriever, create_code_retrieval_tool -from codebase_rag.tools.codebase_query import create_query_tool -from codebase_rag.tools.directory_lister import ( - DirectoryLister, - create_directory_lister_tool, -) -from codebase_rag.tools.file_editor import FileEditor, create_file_editor_tool -from codebase_rag.tools.file_reader import FileReader, create_file_reader_tool -from codebase_rag.tools.file_writer import FileWriter, create_file_writer_tool +from codebase_rag.tools.code_retrieval import CodeRetriever +from codebase_rag.tools.directory_lister import DirectoryLister +from codebase_rag.tools.file_editor import FileEditor +from codebase_rag.tools.file_reader import FileReader +from codebase_rag.tools.file_writer import FileWriter @dataclass class ToolMetadata: - """Metadata for an MCP tool including schema and handler information.""" - name: str description: str input_schema: dict[str, Any] @@ -33,21 +29,12 @@ class ToolMetadata: class MCPToolsRegistry: - """Registry for all MCP tools with shared dependencies.""" - def __init__( self, project_root: str, ingestor: MemgraphIngestor, cypher_gen: CypherGenerator, ) -> None: - """Initialize the MCP tools registry. - - Args: - project_root: Path to the target repository - ingestor: Memgraph ingestor instance - cypher_gen: Cypher query generator instance - """ self.project_root = project_root self.ingestor = ingestor self.cypher_gen = cypher_gen @@ -60,17 +47,6 @@ def __init__( self.file_writer = FileWriter(project_root=project_root) self.directory_lister = DirectoryLister(project_root=project_root) - self._query_tool = create_query_tool( - ingestor=ingestor, cypher_gen=cypher_gen, console=None - ) - self._code_tool = create_code_retrieval_tool(code_retriever=self.code_retriever) - self._file_editor_tool = create_file_editor_tool(file_editor=self.file_editor) - self._file_reader_tool = create_file_reader_tool(file_reader=self.file_reader) - self._file_writer_tool = create_file_writer_tool(file_writer=self.file_writer) - self._directory_lister_tool = create_directory_lister_tool( - directory_lister=self.directory_lister - ) - self._tools: dict[str, ToolMetadata] = { "index_repository": ToolMetadata( name="index_repository", @@ -208,18 +184,6 @@ def __init__( } async def index_repository(self) -> str: - """Parse and ingest the repository into the Memgraph knowledge graph. - - This tool analyzes the codebase using Tree-sitter parsers and builds - a comprehensive knowledge graph with functions, classes, dependencies, - and relationships. - - Note: This clears all existing data in the database before indexing. - Only one repository can be indexed at a time. - - Returns: - Success message with indexing statistics - """ logger.info(f"[MCP] Indexing repository at: {self.project_root}") try: @@ -241,57 +205,41 @@ async def index_repository(self) -> str: return f"Error indexing repository: {str(e)}" async def query_code_graph(self, natural_language_query: str) -> dict[str, Any]: - """Query the codebase knowledge graph using natural language. - - This tool converts your natural language question into a Cypher query, - executes it against the knowledge graph, and returns structured results - with summaries. - - Args: - natural_language_query: Your question in plain English (e.g., - "What functions call UserService.create_user?") - - Returns: - Dictionary containing: - - cypher_query: The generated Cypher query - - results: List of result rows from the graph - - summary: Natural language summary of findings - """ logger.info(f"[MCP] query_code_graph: {natural_language_query}") + cypher_query = "N/A" try: - graph_data = await self._query_tool.function(natural_language_query) # type: ignore[arg-type] + cypher_query = await self.cypher_gen.generate(natural_language_query) + results = self.ingestor.fetch_all(cypher_query) + summary = f"Successfully retrieved {len(results)} item(s) from the graph." + graph_data = GraphData( + query_used=cypher_query, results=results, summary=summary + ) result_dict = cast(dict[str, Any], graph_data.model_dump()) logger.info( f"[MCP] Query returned {len(result_dict.get('results', []))} results" ) return result_dict + except LLMGenerationError as e: + return { + "query_used": "N/A", + "results": [], + "summary": f"I couldn't translate your request into a database query. Error: {e}", + } except Exception as e: logger.error(f"[MCP] Error querying code graph: {e}", exc_info=True) return { "error": str(e), - "query_used": "N/A", + "query_used": cypher_query, "results": [], "summary": f"Error executing query: {str(e)}", } async def get_code_snippet(self, qualified_name: str) -> dict[str, Any]: - """Retrieve source code for a function, class, or method by qualified name. - - Args: - qualified_name: Fully qualified name (e.g., "app.services.UserService.create_user") - - Returns: - Dictionary containing: - - file_path: Path to the source file - - src: The source code - - line_start: Starting line number - - line_end: Ending line number - - docstring: Docstring if available - - found: Whether the code was found - """ logger.info(f"[MCP] get_code_snippet: {qualified_name}") try: - snippet = await self._code_tool.function(qualified_name=qualified_name) + snippet: CodeSnippet = await self.code_retriever.find_code_snippet( + qualified_name + ) result = snippet.model_dump() if result is None: return { @@ -311,27 +259,15 @@ async def get_code_snippet(self, qualified_name: str) -> dict[str, Any]: async def surgical_replace_code( self, file_path: str, target_code: str, replacement_code: str ) -> str: - """Surgically replace an exact code block in a file. - - Uses diff-match-patch algorithm to replace only the exact target block, - leaving the rest of the file unchanged. - - Args: - file_path: Relative path to the file from project root - target_code: Exact code block to replace - replacement_code: New code to insert - - Returns: - Success message or error description - """ logger.info(f"[MCP] surgical_replace_code in {file_path}") try: - result = await self._file_editor_tool.function( # type: ignore[call-arg] - file_path=file_path, - target_code=target_code, - replacement_code=replacement_code, + success = self.file_editor.replace_code_block( + file_path, target_code, replacement_code ) - return cast(str, result) + if success: + return f"Successfully applied surgical code replacement in: {file_path}" + else: + return f"Failed to apply surgical replacement in {file_path}. Target code not found or patches failed." except Exception as e: logger.error(f"[MCP] Error replacing code: {e}") return f"Error: {str(e)}" @@ -339,16 +275,6 @@ async def surgical_replace_code( async def read_file( self, file_path: str, offset: int | None = None, limit: int | None = None ) -> str: - """Read the contents of a file with optional pagination. - - Args: - file_path: Relative path to the file from project root - offset: Line number to start reading from (0-based, optional) - limit: Maximum number of lines to read (optional) - - Returns: - File contents (paginated if offset/limit provided) or error message - """ logger.info(f"[MCP] read_file: {file_path} (offset={offset}, limit={limit})") try: if offset is not None or limit is not None: @@ -373,28 +299,19 @@ async def read_file( header = f"# Lines {start + 1}-{start + len(sliced_lines)} of {total_lines}\n" return header + paginated_content else: - result = await self._file_reader_tool.function(file_path=file_path) # type: ignore[call-arg] - return cast(str, result) + result = await self.file_reader.read_file(file_path) + if result.error_message: + return f"Error: {result.error_message}" + return result.content or "" except Exception as e: logger.error(f"[MCP] Error reading file: {e}") return f"Error: {str(e)}" async def write_file(self, file_path: str, content: str) -> str: - """Write content to a file, creating it if it doesn't exist. - - Args: - file_path: Relative path to the file from project root - content: Content to write to the file - - Returns: - Success message or error description - """ logger.info(f"[MCP] write_file: {file_path}") try: - result = await self._file_writer_tool.function( # type: ignore[call-arg] - file_path=file_path, content=content - ) + result = await self.file_writer.create_file(file_path, content) if result.success: return f"Successfully wrote file: {file_path}" else: @@ -404,30 +321,15 @@ async def write_file(self, file_path: str, content: str) -> str: return f"Error: {str(e)}" async def list_directory(self, directory_path: str = ".") -> str: - """List contents of a directory. - - Args: - directory_path: Relative path to directory from project root (default: ".") - - Returns: - Formatted directory listing or error message - """ logger.info(f"[MCP] list_directory: {directory_path}") try: - result = self._directory_lister_tool.function( # type: ignore[call-arg] - directory_path=directory_path - ) + result = self.directory_lister.list_directory_contents(directory_path) return cast(str, result) except Exception as e: logger.error(f"[MCP] Error listing directory: {e}") return f"Error: {str(e)}" def get_tool_schemas(self) -> list[dict[str, Any]]: - """Get MCP tool schemas for all registered tools. - - Returns: - List of tool schema dictionaries suitable for MCP's list_tools() - """ return [ { "name": metadata.name, @@ -438,25 +340,12 @@ def get_tool_schemas(self) -> list[dict[str, Any]]: ] def get_tool_handler(self, name: str) -> tuple[Callable[..., Any], bool] | None: - """Get the handler function and return type info for a tool. - - Args: - name: Tool name to look up - - Returns: - Tuple of (handler_function, returns_json) or None if tool not found - """ metadata = self._tools.get(name) if metadata is None: return None return (metadata.handler, metadata.returns_json) def list_tool_names(self) -> list[str]: - """Get a list of all registered tool names. - - Returns: - List of tool names - """ return list(self._tools.keys()) @@ -465,16 +354,6 @@ def create_mcp_tools_registry( ingestor: MemgraphIngestor, cypher_gen: CypherGenerator, ) -> MCPToolsRegistry: - """Factory function to create an MCP tools registry. - - Args: - project_root: Path to the target repository - ingestor: Memgraph ingestor instance - cypher_gen: Cypher query generator instance - - Returns: - MCPToolsRegistry instance with all tools initialized - """ return MCPToolsRegistry( project_root=project_root, ingestor=ingestor, diff --git a/codebase_rag/services/llm.py b/codebase_rag/services/llm.py index a423e9766..d9f9876e5 100644 --- a/codebase_rag/services/llm.py +++ b/codebase_rag/services/llm.py @@ -2,22 +2,26 @@ from pydantic_ai import Agent, DeferredToolRequests, Tool from ..config import settings +from ..deps import RAGDeps +from ..exceptions import LLMGenerationError from ..prompts import ( CYPHER_SYSTEM_PROMPT, LOCAL_CYPHER_SYSTEM_PROMPT, RAG_ORCHESTRATOR_SYSTEM_PROMPT, ) from ..providers.base import get_provider - - -class LLMGenerationError(Exception): - """Custom exception for LLM generation failures.""" - - pass +from ..tools.code_retrieval import get_code_snippet +from ..tools.codebase_query import query_codebase_knowledge_graph +from ..tools.directory_lister import list_directory +from ..tools.document_analyzer import analyze_document +from ..tools.file_editor import replace_code_surgically +from ..tools.file_reader import read_file_content +from ..tools.file_writer import create_new_file +from ..tools.semantic_search import get_function_source_by_id, semantic_search_functions +from ..tools.shell_command import run_shell_command def _clean_cypher_response(response_text: str) -> str: - """Utility to clean up common LLM formatting artifacts from a Cypher query.""" query = response_text.strip().replace("`", "") if query.startswith("cypher"): query = query[6:].strip() @@ -27,8 +31,6 @@ def _clean_cypher_response(response_text: str) -> str: class CypherGenerator: - """Generates Cypher queries from natural language.""" - def __init__(self) -> None: try: config = settings.active_cypher_config @@ -84,8 +86,7 @@ async def generate(self, natural_language_query: str) -> str: raise LLMGenerationError(f"Cypher generation failed: {e}") from e -def create_rag_orchestrator(tools: list[Tool]) -> Agent: - """Factory function to create the main RAG orchestrator agent.""" +def create_rag_orchestrator() -> Agent[RAGDeps, str | DeferredToolRequests]: try: config = settings.active_orchestrator_config @@ -103,8 +104,20 @@ def create_rag_orchestrator(tools: list[Tool]) -> Agent: return Agent( model=llm, + deps_type=RAGDeps, system_prompt=RAG_ORCHESTRATOR_SYSTEM_PROMPT, - tools=tools, + tools=[ + query_codebase_knowledge_graph, + get_code_snippet, + read_file_content, + Tool(create_new_file, requires_approval=True), + Tool(replace_code_surgically, requires_approval=True), + Tool(run_shell_command, requires_approval=True), + list_directory, + analyze_document, + semantic_search_functions, + get_function_source_by_id, + ], retries=settings.AGENT_RETRIES, output_retries=100, output_type=[str, DeferredToolRequests], diff --git a/codebase_rag/tests/integration/test_mcp_tools_integration.py b/codebase_rag/tests/integration/test_mcp_tools_integration.py index 877034194..52bb031e3 100644 --- a/codebase_rag/tests/integration/test_mcp_tools_integration.py +++ b/codebase_rag/tests/integration/test_mcp_tools_integration.py @@ -109,28 +109,29 @@ async def test_list_directory_works(self, mcp_registry: MCPToolsRegistry) -> Non class TestToolConsistency: """Tests that verify all tools follow consistent patterns.""" - async def test_all_tools_have_consistent_takes_ctx( + async def test_all_service_classes_are_initialized( self, mcp_registry: MCPToolsRegistry ) -> None: - """Verify all tools have consistent takes_ctx settings.""" - tools = { - "query": mcp_registry._query_tool, - "code": mcp_registry._code_tool, - "editor": mcp_registry._file_editor_tool, - "reader": mcp_registry._file_reader_tool, - "writer": mcp_registry._file_writer_tool, - "lister": mcp_registry._directory_lister_tool, - } - - takes_ctx_values = {name: tool.takes_ctx for name, tool in tools.items()} - - assert takes_ctx_values == { - "query": False, - "code": False, - "editor": False, - "reader": False, - "writer": False, - "lister": False, + """Verify all service classes are properly initialized.""" + assert mcp_registry.code_retriever is not None + assert mcp_registry.file_editor is not None + assert mcp_registry.file_reader is not None + assert mcp_registry.file_writer is not None + assert mcp_registry.directory_lister is not None + + async def test_all_tools_are_registered( + self, mcp_registry: MCPToolsRegistry + ) -> None: + """Verify all expected tools are registered.""" + expected_tools = { + "index_repository", + "query_code_graph", + "get_code_snippet", + "surgical_replace_code", + "read_file", + "write_file", + "list_directory", } - assert all(not takes_ctx for takes_ctx in takes_ctx_values.values()) + registered_tools = set(mcp_registry.list_tool_names()) + assert registered_tools == expected_tools diff --git a/codebase_rag/tests/test_mcp_get_code_snippet.py b/codebase_rag/tests/test_mcp_get_code_snippet.py index 1b30bd442..a4e621057 100644 --- a/codebase_rag/tests/test_mcp_get_code_snippet.py +++ b/codebase_rag/tests/test_mcp_get_code_snippet.py @@ -1,5 +1,5 @@ from pathlib import Path -from unittest.mock import AsyncMock, MagicMock +from unittest.mock import MagicMock import pytest @@ -47,9 +47,6 @@ def mcp_registry(temp_project_root: Path) -> MCPToolsRegistry: cypher_gen=mock_cypher_gen, ) - registry._code_tool = MagicMock() - registry._code_tool.function = AsyncMock() - return registry @@ -60,17 +57,15 @@ async def test_get_function_snippet( self, mcp_registry: MCPToolsRegistry, temp_project_root: Path ) -> None: """Test retrieving a function snippet.""" - mcp_registry._code_tool.function.return_value = MagicMock( # ty: ignore[invalid-assignment] - model_dump=lambda: { - "qualified_name": "sample.hello_world", - "source_code": 'def hello_world():\n """Say hello to the world."""\n print("Hello, World!")\n', - "file_path": "sample.py", - "line_start": 1, - "line_end": 3, + mcp_registry.ingestor.fetch_all.return_value = [ # type: ignore[attr-defined] + { + "name": "hello_world", + "start": 1, + "end": 3, + "path": "sample.py", "docstring": "Say hello to the world.", - "found": True, } - ) + ] result = await mcp_registry.get_code_snippet("sample.hello_world") @@ -86,17 +81,15 @@ async def test_get_method_snippet( self, mcp_registry: MCPToolsRegistry, temp_project_root: Path ) -> None: """Test retrieving a class method snippet.""" - mcp_registry._code_tool.function.return_value = MagicMock( # ty: ignore[invalid-assignment] - model_dump=lambda: { - "qualified_name": "sample.Calculator.add", - "source_code": ' def add(self, a: int, b: int) -> int:\n """Add two numbers."""\n return a + b\n', - "file_path": "sample.py", - "line_start": 8, - "line_end": 10, + mcp_registry.ingestor.fetch_all.return_value = [ # type: ignore[attr-defined] + { + "name": "add", + "start": 8, + "end": 10, + "path": "sample.py", "docstring": "Add two numbers.", - "found": True, } - ) + ] result = await mcp_registry.get_code_snippet("sample.Calculator.add") @@ -109,17 +102,15 @@ async def test_get_class_snippet( self, mcp_registry: MCPToolsRegistry, temp_project_root: Path ) -> None: """Test retrieving a class snippet.""" - mcp_registry._code_tool.function.return_value = MagicMock( # ty: ignore[invalid-assignment] - model_dump=lambda: { - "qualified_name": "sample.Calculator", - "source_code": 'class Calculator:\n """Simple calculator class."""\n\n def add(self, a: int, b: int) -> int:\n """Add two numbers."""\n return a + b\n', - "file_path": "sample.py", - "line_start": 5, - "line_end": 10, + mcp_registry.ingestor.fetch_all.return_value = [ # type: ignore[attr-defined] + { + "name": "Calculator", + "start": 5, + "end": 10, + "path": "sample.py", "docstring": "Simple calculator class.", - "found": True, } - ) + ] result = await mcp_registry.get_code_snippet("sample.Calculator") @@ -136,17 +127,7 @@ async def test_get_nonexistent_function( self, mcp_registry: MCPToolsRegistry ) -> None: """Test retrieving a nonexistent function.""" - mcp_registry._code_tool.function.return_value = MagicMock( # ty: ignore[invalid-assignment] - model_dump=lambda: { - "qualified_name": "nonexistent.function", - "source_code": "", - "file_path": "", - "line_start": 0, - "line_end": 0, - "found": False, - "error_message": "Entity not found in graph.", - } - ) + mcp_registry.ingestor.fetch_all.return_value = [] # type: ignore[attr-defined] result = await mcp_registry.get_code_snippet("nonexistent.function") @@ -158,17 +139,7 @@ async def test_get_malformed_qualified_name( self, mcp_registry: MCPToolsRegistry ) -> None: """Test retrieving with malformed qualified name.""" - mcp_registry._code_tool.function.return_value = MagicMock( # ty: ignore[invalid-assignment] - model_dump=lambda: { - "qualified_name": "..invalid..", - "source_code": "", - "file_path": "", - "line_start": 0, - "line_end": 0, - "found": False, - "error_message": "Invalid qualified name format.", - } - ) + mcp_registry.ingestor.fetch_all.return_value = [] # type: ignore[attr-defined] result = await mcp_registry.get_code_snippet("..invalid..") @@ -180,20 +151,22 @@ class TestGetCodeSnippetEdgeCases: """Test edge cases and special scenarios.""" async def test_get_snippet_with_no_docstring( - self, mcp_registry: MCPToolsRegistry + self, mcp_registry: MCPToolsRegistry, temp_project_root: Path ) -> None: """Test retrieving code with no docstring.""" - mcp_registry._code_tool.function.return_value = MagicMock( # ty: ignore[invalid-assignment] - model_dump=lambda: { - "qualified_name": "sample.no_docstring", - "source_code": "def no_docstring():\n pass\n", - "file_path": "sample.py", - "line_start": 12, - "line_end": 13, + # (H) Add a file with no docstring function + nodoc_file = temp_project_root / "nodoc.py" + nodoc_file.write_text("def no_docstring():\n pass\n", encoding="utf-8") + + mcp_registry.ingestor.fetch_all.return_value = [ # type: ignore[attr-defined] + { + "name": "no_docstring", + "start": 1, + "end": 2, + "path": "nodoc.py", "docstring": None, - "found": True, } - ) + ] result = await mcp_registry.get_code_snippet("sample.no_docstring") @@ -201,20 +174,27 @@ async def test_get_snippet_with_no_docstring( assert result["docstring"] is None async def test_get_snippet_from_nested_module( - self, mcp_registry: MCPToolsRegistry + self, mcp_registry: MCPToolsRegistry, temp_project_root: Path ) -> None: """Test retrieving code from deeply nested module.""" - mcp_registry._code_tool.function.return_value = MagicMock( # ty: ignore[invalid-assignment] - model_dump=lambda: { - "qualified_name": "pkg.subpkg.module.ClassName.method", - "source_code": " def method(self):\n return True\n", - "file_path": "pkg/subpkg/module.py", - "line_start": 10, - "line_end": 11, + # (H) Create nested directory and file + nested_dir = temp_project_root / "pkg" / "subpkg" + nested_dir.mkdir(parents=True) + nested_file = nested_dir / "module.py" + nested_file.write_text( + "class ClassName:\n def method(self):\n return True\n", + encoding="utf-8", + ) + + mcp_registry.ingestor.fetch_all.return_value = [ # type: ignore[attr-defined] + { + "name": "method", + "start": 2, + "end": 3, + "path": "pkg/subpkg/module.py", "docstring": None, - "found": True, } - ) + ] result = await mcp_registry.get_code_snippet( "pkg.subpkg.module.ClassName.method" @@ -225,20 +205,25 @@ async def test_get_snippet_from_nested_module( assert result["file_path"] == "pkg/subpkg/module.py" async def test_get_snippet_with_unicode( - self, mcp_registry: MCPToolsRegistry + self, mcp_registry: MCPToolsRegistry, temp_project_root: Path ) -> None: """Test retrieving code with unicode characters.""" - mcp_registry._code_tool.function.return_value = MagicMock( # ty: ignore[invalid-assignment] - model_dump=lambda: { - "qualified_name": "sample.unicode_func", - "source_code": 'def unicode_func():\n """返回 Unicode 字符串。"""\n return "Hello 世界"\n', - "file_path": "sample.py", - "line_start": 15, - "line_end": 17, + # (H) Create file with unicode content + unicode_file = temp_project_root / "unicode.py" + unicode_file.write_text( + 'def unicode_func():\n """返回 Unicode 字符串。"""\n return "Hello 世界"\n', + encoding="utf-8", + ) + + mcp_registry.ingestor.fetch_all.return_value = [ # type: ignore[attr-defined] + { + "name": "unicode_func", + "start": 1, + "end": 3, + "path": "unicode.py", "docstring": "返回 Unicode 字符串。", - "found": True, } - ) + ] result = await mcp_registry.get_code_snippet("sample.unicode_func") @@ -254,78 +239,87 @@ async def test_get_snippet_with_exception( self, mcp_registry: MCPToolsRegistry ) -> None: """Test handling of exceptions during retrieval.""" - mcp_registry._code_tool.function.side_effect = Exception("Database error") # ty: ignore[invalid-assignment] + mcp_registry.ingestor.fetch_all.side_effect = Exception("Database error") # type: ignore[attr-defined] result = await mcp_registry.get_code_snippet("sample.function") assert result["found"] is False assert "error" in result or "error_message" in result - async def test_get_snippet_tool_returns_none( + async def test_get_snippet_returns_empty_results( self, mcp_registry: MCPToolsRegistry ) -> None: - """Test handling when tool returns None.""" - mcp_registry._code_tool.function.return_value = MagicMock( # ty: ignore[invalid-assignment] - model_dump=lambda: None - ) + """Test handling when ingestor returns empty results.""" + mcp_registry.ingestor.fetch_all.return_value = [] # type: ignore[attr-defined] result = await mcp_registry.get_code_snippet("sample.function") assert isinstance(result, dict) + assert result["found"] is False class TestGetCodeSnippetIntegration: """Test integration scenarios.""" async def test_get_multiple_snippets_sequentially( - self, mcp_registry: MCPToolsRegistry + self, mcp_registry: MCPToolsRegistry, temp_project_root: Path ) -> None: """Test retrieving multiple snippets in sequence.""" + # (H) Create a module file with multiple functions + module_file = temp_project_root / "module.py" + module_file.write_text( + "def func1(): pass\n\ndef func2(): pass\n", encoding="utf-8" + ) + snippets = [ { "qualified_name": "module.func1", - "source_code": "def func1(): pass", - "file_path": "module.py", - "line_start": 1, - "line_end": 1, - "found": True, + "name": "func1", + "start": 1, + "end": 1, + "path": "module.py", }, { "qualified_name": "module.func2", - "source_code": "def func2(): pass", - "file_path": "module.py", - "line_start": 3, - "line_end": 3, - "found": True, + "name": "func2", + "start": 3, + "end": 3, + "path": "module.py", }, ] for snippet_data in snippets: - mcp_registry._code_tool.function.return_value = MagicMock( # ty: ignore[invalid-assignment] - model_dump=lambda s=snippet_data: s - ) - qualified_name: str = snippet_data["qualified_name"] # type: ignore[assignment] + mcp_registry.ingestor.fetch_all.return_value = [ # type: ignore[attr-defined] + { + "name": snippet_data["name"], + "start": snippet_data["start"], + "end": snippet_data["end"], + "path": snippet_data["path"], + } + ] + qualified_name = str(snippet_data["qualified_name"]) result = await mcp_registry.get_code_snippet(qualified_name) assert result["found"] is True assert result["qualified_name"] == snippet_data["qualified_name"] async def test_get_snippet_verifies_qualified_name_passed( - self, mcp_registry: MCPToolsRegistry + self, mcp_registry: MCPToolsRegistry, temp_project_root: Path ) -> None: - """Test that qualified name is correctly passed to underlying tool.""" - mcp_registry._code_tool.function.return_value = MagicMock( # ty: ignore[invalid-assignment] - model_dump=lambda: { - "qualified_name": "test.function", - "source_code": "def function(): pass", - "file_path": "test.py", - "line_start": 1, - "line_end": 1, - "found": True, + """Test that qualified name is correctly passed to underlying retriever.""" + test_file = temp_project_root / "test.py" + test_file.write_text("def function(): pass\n", encoding="utf-8") + + mcp_registry.ingestor.fetch_all.return_value = [ # type: ignore[attr-defined] + { + "name": "function", + "start": 1, + "end": 1, + "path": "test.py", } - ) + ] await mcp_registry.get_code_snippet("test.function") - mcp_registry._code_tool.function.assert_called_once_with( - qualified_name="test.function" - ) + mcp_registry.ingestor.fetch_all.assert_called_once() # type: ignore[attr-defined] + call_args = mcp_registry.ingestor.fetch_all.call_args # type: ignore[attr-defined] + assert "test.function" in str(call_args) diff --git a/codebase_rag/tests/test_mcp_query_and_index.py b/codebase_rag/tests/test_mcp_query_and_index.py index 60145c0d7..f8c397fc0 100644 --- a/codebase_rag/tests/test_mcp_query_and_index.py +++ b/codebase_rag/tests/test_mcp_query_and_index.py @@ -1,5 +1,5 @@ from pathlib import Path -from unittest.mock import AsyncMock, MagicMock, patch +from unittest.mock import MagicMock, patch import pytest @@ -49,15 +49,17 @@ def mcp_registry(temp_project_root: Path) -> MCPToolsRegistry: mock_ingestor = MagicMock() mock_cypher_gen = MagicMock() + async def mock_generate(query: str) -> str: + return "MATCH (n) RETURN n" + + mock_cypher_gen.generate = mock_generate + registry = MCPToolsRegistry( project_root=str(temp_project_root), ingestor=mock_ingestor, cypher_gen=mock_cypher_gen, ) - registry._query_tool = MagicMock() - registry._query_tool.function = AsyncMock() - return registry @@ -66,16 +68,10 @@ class TestQueryCodeGraph: async def test_query_finds_functions(self, mcp_registry: MCPToolsRegistry) -> None: """Test querying for functions in the code graph.""" - mcp_registry._query_tool.function.return_value = MagicMock( # ty: ignore[invalid-assignment] - model_dump=lambda: { - "cypher_query": "MATCH (f:Function) RETURN f.name", - "results": [ - {"name": "add"}, - {"name": "multiply"}, - ], - "summary": "Found 2 functions", - } - ) + mcp_registry.ingestor.fetch_all.return_value = [ # type: ignore[attr-defined] + {"name": "add"}, + {"name": "multiply"}, + ] result = await mcp_registry.query_code_graph("Find all functions") @@ -83,18 +79,14 @@ async def test_query_finds_functions(self, mcp_registry: MCPToolsRegistry) -> No assert len(result["results"]) == 2 assert result["results"][0]["name"] == "add" assert result["results"][1]["name"] == "multiply" - assert "cypher_query" in result + assert "query_used" in result assert "summary" in result async def test_query_finds_classes(self, mcp_registry: MCPToolsRegistry) -> None: """Test querying for classes in the code graph.""" - mcp_registry._query_tool.function.return_value = MagicMock( # ty: ignore[invalid-assignment] - model_dump=lambda: { - "cypher_query": "MATCH (c:Class) RETURN c.name", - "results": [{"name": "Calculator"}], - "summary": "Found 1 class", - } - ) + mcp_registry.ingestor.fetch_all.return_value = [ # type: ignore[attr-defined] + {"name": "Calculator"} + ] result = await mcp_registry.query_code_graph("Find all classes") @@ -105,68 +97,44 @@ async def test_query_finds_function_calls( self, mcp_registry: MCPToolsRegistry ) -> None: """Test querying for function call relationships.""" - mcp_registry._query_tool.function.return_value = MagicMock( # ty: ignore[invalid-assignment] - model_dump=lambda: { - "cypher_query": "MATCH (f:Function)-[:CALLS]->(g:Function) RETURN f.name, g.name", - "results": [ - {"f.name": "main", "g.name": "add"}, - {"f.name": "main", "g.name": "multiply"}, - ], - "summary": "Found 2 function call relationships", - } - ) + mcp_registry.ingestor.fetch_all.return_value = [ # type: ignore[attr-defined] + {"f.name": "main", "g.name": "add"}, + {"f.name": "main", "g.name": "multiply"}, + ] result = await mcp_registry.query_code_graph("What functions does main call?") assert len(result["results"]) == 2 - assert result["summary"] == "Found 2 function call relationships" async def test_query_with_no_results(self, mcp_registry: MCPToolsRegistry) -> None: """Test query that returns no results.""" - mcp_registry._query_tool.function.return_value = MagicMock( # ty: ignore[invalid-assignment] - model_dump=lambda: { - "cypher_query": "MATCH (n:NonExistent) RETURN n", - "results": [], - "summary": "No results found", - } - ) + mcp_registry.ingestor.fetch_all.return_value = [] # type: ignore[attr-defined] result = await mcp_registry.query_code_graph("Find nonexistent nodes") assert result["results"] == [] - assert "No results" in result["summary"] async def test_query_with_complex_natural_language( self, mcp_registry: MCPToolsRegistry ) -> None: """Test complex natural language query.""" - mcp_registry._query_tool.function.return_value = MagicMock( # ty: ignore[invalid-assignment] - model_dump=lambda: { - "cypher_query": "MATCH (f:Function)-[:DEFINED_IN]->(m:Module) WHERE m.name = 'calculator' RETURN f.name", - "results": [ - {"name": "add"}, - {"name": "multiply"}, - ], - "summary": "Found 2 functions in calculator module", - } - ) + mcp_registry.ingestor.fetch_all.return_value = [ # type: ignore[attr-defined] + {"name": "add"}, + {"name": "multiply"}, + ] result = await mcp_registry.query_code_graph( "What functions are defined in the calculator module?" ) assert len(result["results"]) == 2 - assert "cypher_query" in result + assert "query_used" in result async def test_query_handles_unicode(self, mcp_registry: MCPToolsRegistry) -> None: """Test query with unicode characters.""" - mcp_registry._query_tool.function.return_value = MagicMock( # ty: ignore[invalid-assignment] - model_dump=lambda: { - "cypher_query": "MATCH (f:Function) WHERE f.name = '你好' RETURN f", - "results": [{"name": "你好"}], - "summary": "Found 1 function", - } - ) + mcp_registry.ingestor.fetch_all.return_value = [ # type: ignore[attr-defined] + {"name": "你好"} + ] result = await mcp_registry.query_code_graph("Find function 你好") @@ -174,7 +142,7 @@ async def test_query_handles_unicode(self, mcp_registry: MCPToolsRegistry) -> No async def test_query_error_handling(self, mcp_registry: MCPToolsRegistry) -> None: """Test error handling during query execution.""" - mcp_registry._query_tool.function.side_effect = Exception("Database error") # ty: ignore[invalid-assignment] + mcp_registry.ingestor.fetch_all.side_effect = Exception("Database error") # type: ignore[attr-defined] result = await mcp_registry.query_code_graph("Find all nodes") @@ -187,18 +155,12 @@ async def test_query_verifies_parameter_passed( self, mcp_registry: MCPToolsRegistry ) -> None: """Test that query parameter is correctly passed.""" - mcp_registry._query_tool.function.return_value = MagicMock( # ty: ignore[invalid-assignment] - model_dump=lambda: { - "cypher_query": "MATCH (n) RETURN n", - "results": [], - "summary": "Query executed", - } - ) + mcp_registry.ingestor.fetch_all.return_value = [] # type: ignore[attr-defined] query = "Find all nodes" await mcp_registry.query_code_graph(query) - mcp_registry._query_tool.function.assert_called_once_with(query) + mcp_registry.ingestor.fetch_all.assert_called_once() # type: ignore[attr-defined] class TestIndexRepository: @@ -376,13 +338,9 @@ async def test_query_after_index( index_result = await mcp_registry.index_repository() assert "Error:" not in index_result - mcp_registry._query_tool.function.return_value = MagicMock( # ty: ignore[invalid-assignment] - model_dump=lambda: { - "cypher_query": "MATCH (f:Function) RETURN f.name", - "results": [{"name": "add"}], - "summary": "Found 1 function", - } - ) + mcp_registry.ingestor.fetch_all.return_value = [ # type: ignore[attr-defined] + {"name": "add"} + ] query_result = await mcp_registry.query_code_graph("Find all functions") assert len(query_result["results"]) >= 0 @@ -398,22 +356,15 @@ async def test_index_and_query_workflow( await mcp_registry.index_repository() - mcp_registry._query_tool.function.return_value = MagicMock( # ty: ignore[invalid-assignment] - model_dump=lambda: { - "cypher_query": "MATCH (f:Function) RETURN f", - "results": [{"name": "add"}, {"name": "multiply"}], - "summary": "Found 2 functions", - } - ) + mcp_registry.ingestor.fetch_all.return_value = [ # type: ignore[attr-defined] + {"name": "add"}, + {"name": "multiply"}, + ] result = await mcp_registry.query_code_graph("Find all functions") assert len(result["results"]) == 2 - mcp_registry._query_tool.function.return_value = MagicMock( # ty: ignore[invalid-assignment] - model_dump=lambda: { - "cypher_query": "MATCH (c:Class) RETURN c", - "results": [{"name": "Calculator"}], - "summary": "Found 1 class", - } - ) + mcp_registry.ingestor.fetch_all.return_value = [ # type: ignore[attr-defined] + {"name": "Calculator"} + ] result = await mcp_registry.query_code_graph("Find all classes") assert len(result["results"]) == 1 diff --git a/codebase_rag/tests/test_mcp_read_file.py b/codebase_rag/tests/test_mcp_read_file.py index c69af87a3..72d35217f 100644 --- a/codebase_rag/tests/test_mcp_read_file.py +++ b/codebase_rag/tests/test_mcp_read_file.py @@ -1,5 +1,5 @@ from pathlib import Path -from unittest.mock import AsyncMock, MagicMock +from unittest.mock import MagicMock import pytest @@ -10,19 +10,16 @@ @pytest.fixture(params=["asyncio"]) def anyio_backend(request: pytest.FixtureRequest) -> str: - """Configure anyio to only use asyncio backend.""" return str(request.param) @pytest.fixture def temp_project_root(tmp_path: Path) -> Path: - """Create a temporary project root directory.""" return tmp_path @pytest.fixture def sample_file(temp_project_root: Path) -> Path: - """Create a sample file with known content.""" file_path = temp_project_root / "test_file.txt" content = "\n".join([f"Line {i}" for i in range(1, 101)]) file_path.write_text(content, encoding="utf-8") @@ -31,7 +28,6 @@ def sample_file(temp_project_root: Path) -> Path: @pytest.fixture def large_file(temp_project_root: Path) -> Path: - """Create a large file to test memory efficiency.""" file_path = temp_project_root / "large_file.txt" with open(file_path, "w", encoding="utf-8") as f: for i in range(1, 10001): @@ -41,7 +37,6 @@ def large_file(temp_project_root: Path) -> Path: @pytest.fixture def mcp_registry(temp_project_root: Path) -> MCPToolsRegistry: - """Create an MCP tools registry with mocked dependencies.""" mock_ingestor = MagicMock() mock_cypher_gen = MagicMock() @@ -51,37 +46,24 @@ def mcp_registry(temp_project_root: Path) -> MCPToolsRegistry: cypher_gen=mock_cypher_gen, ) - registry._file_reader_tool = MagicMock() - registry._file_reader_tool.function = AsyncMock() - return registry class TestReadFileWithoutPagination: - """Test reading files without pagination parameters.""" - async def test_read_full_file( self, mcp_registry: MCPToolsRegistry, sample_file: Path ) -> None: - """Test reading entire file without pagination.""" expected_content = sample_file.read_text(encoding="utf-8") - mcp_registry._file_reader_tool.function.return_value = expected_content # ty: ignore[invalid-assignment] result = await mcp_registry.read_file("test_file.txt") assert result == expected_content - mcp_registry._file_reader_tool.function.assert_called_once_with( - file_path="test_file.txt" - ) class TestReadFileWithPagination: - """Test reading files with pagination parameters.""" - async def test_read_with_offset_only( self, mcp_registry: MCPToolsRegistry, sample_file: Path ) -> None: - """Test reading from a specific offset to end of file.""" result = await mcp_registry.read_file("test_file.txt", offset=10) lines = result.split("\n") @@ -92,7 +74,6 @@ async def test_read_with_offset_only( async def test_read_with_limit_only( self, mcp_registry: MCPToolsRegistry, sample_file: Path ) -> None: - """Test reading only first N lines.""" result = await mcp_registry.read_file("test_file.txt", limit=10) lines = result.split("\n") @@ -104,7 +85,6 @@ async def test_read_with_limit_only( async def test_read_with_offset_and_limit( self, mcp_registry: MCPToolsRegistry, sample_file: Path ) -> None: - """Test reading a specific range of lines.""" result = await mcp_registry.read_file("test_file.txt", offset=20, limit=10) lines = result.split("\n") @@ -117,7 +97,6 @@ async def test_read_with_offset_and_limit( async def test_read_offset_beyond_file_length( self, mcp_registry: MCPToolsRegistry, sample_file: Path ) -> None: - """Test reading with offset beyond file length.""" result = await mcp_registry.read_file("test_file.txt", offset=150) lines = result.split("\n") @@ -127,7 +106,6 @@ async def test_read_offset_beyond_file_length( async def test_read_zero_offset( self, mcp_registry: MCPToolsRegistry, sample_file: Path ) -> None: - """Test reading with offset=0 (should read from beginning).""" result = await mcp_registry.read_file("test_file.txt", offset=0, limit=5) lines = result.split("\n") @@ -137,12 +115,9 @@ async def test_read_zero_offset( class TestReadFileLargeFiles: - """Test memory efficiency with large files.""" - async def test_read_middle_of_large_file( self, mcp_registry: MCPToolsRegistry, large_file: Path ) -> None: - """Test reading from middle of large file doesn't load entire file.""" result = await mcp_registry.read_file("large_file.txt", offset=5000, limit=10) lines = result.split("\n") @@ -155,7 +130,6 @@ async def test_read_middle_of_large_file( async def test_read_last_lines_of_large_file( self, mcp_registry: MCPToolsRegistry, large_file: Path ) -> None: - """Test reading last few lines of large file.""" result = await mcp_registry.read_file("large_file.txt", offset=9995, limit=10) lines = result.split("\n") @@ -165,12 +139,9 @@ async def test_read_last_lines_of_large_file( class TestReadFileEdgeCases: - """Test edge cases and error handling.""" - async def test_read_empty_file( self, mcp_registry: MCPToolsRegistry, temp_project_root: Path ) -> None: - """Test reading an empty file.""" empty_file = temp_project_root / "empty.txt" empty_file.write_text("", encoding="utf-8") @@ -180,7 +151,6 @@ async def test_read_empty_file( assert lines[0] == "# Lines 1-0 of 0" async def test_read_nonexistent_file(self, mcp_registry: MCPToolsRegistry) -> None: - """Test reading a file that doesn't exist.""" result = await mcp_registry.read_file("nonexistent.txt", offset=0, limit=10) assert "Error:" in result @@ -188,7 +158,6 @@ async def test_read_nonexistent_file(self, mcp_registry: MCPToolsRegistry) -> No async def test_read_single_line_file( self, mcp_registry: MCPToolsRegistry, temp_project_root: Path ) -> None: - """Test reading a file with only one line.""" single_line_file = temp_project_root / "single.txt" single_line_file.write_text("Only one line", encoding="utf-8") @@ -201,7 +170,6 @@ async def test_read_single_line_file( async def test_read_file_with_unicode( self, mcp_registry: MCPToolsRegistry, temp_project_root: Path ) -> None: - """Test reading a file with unicode characters.""" unicode_file = temp_project_root / "unicode.txt" content = "\n".join( ["Hello 世界", "Привет мир", "مرحبا بالعالم", "🎉 Emoji line"] diff --git a/codebase_rag/tests/test_mcp_surgical_replace.py b/codebase_rag/tests/test_mcp_surgical_replace.py index ce993beae..a1768b221 100644 --- a/codebase_rag/tests/test_mcp_surgical_replace.py +++ b/codebase_rag/tests/test_mcp_surgical_replace.py @@ -1,5 +1,5 @@ from pathlib import Path -from unittest.mock import AsyncMock, MagicMock +from unittest.mock import MagicMock import pytest @@ -10,13 +10,11 @@ @pytest.fixture(params=["asyncio"]) def anyio_backend(request: pytest.FixtureRequest) -> str: - """Configure anyio to only use asyncio backend.""" return str(request.param) @pytest.fixture def temp_project_root(tmp_path: Path) -> Path: - """Create a temporary project root directory with sample files.""" sample_file = tmp_path / "sample.py" sample_file.write_text( '''def hello_world(): @@ -41,7 +39,6 @@ def subtract(self, a: int, b: int) -> int: @pytest.fixture def mcp_registry(temp_project_root: Path) -> MCPToolsRegistry: - """Create an MCP tools registry with mocked dependencies.""" mock_ingestor = MagicMock() mock_cypher_gen = MagicMock() @@ -51,23 +48,13 @@ def mcp_registry(temp_project_root: Path) -> MCPToolsRegistry: cypher_gen=mock_cypher_gen, ) - registry._file_editor_tool = MagicMock() - registry._file_editor_tool.function = AsyncMock() - return registry class TestSurgicalReplaceBasic: - """Test basic code replacement functionality.""" - async def test_replace_function_implementation( self, mcp_registry: MCPToolsRegistry, temp_project_root: Path ) -> None: - """Test replacing a function's implementation.""" - mcp_registry._file_editor_tool.function.return_value = ( # ty: ignore[invalid-assignment] - "Successfully replaced code in sample.py" - ) - target = ' print("Hello, World!")' replacement = ' print("Hello, Universe!")' @@ -75,18 +62,13 @@ async def test_replace_function_implementation( "sample.py", target, replacement ) - assert "Error:" not in result - assert "Success" in result or "replaced" in result.lower() - mcp_registry._file_editor_tool.function.assert_called_once() + assert "Success" in result + content = (temp_project_root / "sample.py").read_text() + assert 'print("Hello, Universe!")' in content async def test_replace_method_implementation( self, mcp_registry: MCPToolsRegistry, temp_project_root: Path ) -> None: - """Test replacing a method's implementation.""" - mcp_registry._file_editor_tool.function.return_value = ( # ty: ignore[invalid-assignment] - "Successfully replaced code in sample.py" - ) - target = """ def add(self, a: int, b: int) -> int: \"\"\"Add two numbers.\"\"\" return a + b""" @@ -100,44 +82,32 @@ async def test_replace_method_implementation( "sample.py", target, replacement ) - assert "Error:" not in result - mcp_registry._file_editor_tool.function.assert_called_once() + assert "Success" in result async def test_replace_with_exact_match( self, mcp_registry: MCPToolsRegistry, temp_project_root: Path ) -> None: - """Test that replacement requires exact match.""" - mcp_registry._file_editor_tool.function.return_value = ( # ty: ignore[invalid-assignment] - "Successfully replaced code in sample.py" - ) - - target = 'print("Hello, World!")' - replacement = 'print("Goodbye!")' + target = ' print("Hello, World!")' + replacement = ' print("Goodbye!")' - await mcp_registry.surgical_replace_code("sample.py", target, replacement) + result = await mcp_registry.surgical_replace_code( + "sample.py", target, replacement + ) - call_args = mcp_registry._file_editor_tool.function.call_args - assert call_args.kwargs["file_path"] == "sample.py" - assert call_args.kwargs["target_code"] == target - assert call_args.kwargs["replacement_code"] == replacement + assert "Success" in result + content = (temp_project_root / "sample.py").read_text() + assert 'print("Goodbye!")' in content class TestSurgicalReplaceEdgeCases: - """Test edge cases and special scenarios.""" - async def test_replace_with_unicode( self, mcp_registry: MCPToolsRegistry, temp_project_root: Path ) -> None: - """Test replacing code with unicode characters.""" unicode_file = temp_project_root / "unicode.py" unicode_file.write_text( 'def greet():\n print("Hello 世界")\n', encoding="utf-8" ) - mcp_registry._file_editor_tool.function.return_value = ( # ty: ignore[invalid-assignment] - "Successfully replaced code" - ) - target = 'print("Hello 世界")' replacement = 'print("你好 世界")' @@ -145,98 +115,36 @@ async def test_replace_with_unicode( "unicode.py", target, replacement ) - assert "Error:" not in result - - async def test_replace_multiline_block( - self, mcp_registry: MCPToolsRegistry, temp_project_root: Path - ) -> None: - """Test replacing a multiline code block.""" - mcp_registry._file_editor_tool.function.return_value = ( # ty: ignore[invalid-assignment] - "Successfully replaced code" - ) - - target = '''class Calculator: - """Simple calculator class.""" - - def add(self, a: int, b: int) -> int: - """Add two numbers.""" - return a + b''' - - replacement = '''class Calculator: - """Advanced calculator class.""" - - def add(self, a: int, b: int) -> int: - """Add two numbers with validation.""" - if not isinstance(a, (int, float)) or not isinstance(b, (int, float)): - raise TypeError("Arguments must be numbers") - return a + b''' - - result = await mcp_registry.surgical_replace_code( - "sample.py", target, replacement - ) - - assert "Error:" not in result - - async def test_replace_with_empty_replacement( - self, mcp_registry: MCPToolsRegistry, temp_project_root: Path - ) -> None: - """Test replacing code with empty string (deletion).""" - mcp_registry._file_editor_tool.function.return_value = ( # ty: ignore[invalid-assignment] - "Successfully replaced code" - ) - - target = ' print("Hello, World!")' - replacement = "" - - result = await mcp_registry.surgical_replace_code( - "sample.py", target, replacement - ) - - assert "Error:" not in result + assert "Success" in result async def test_replace_preserves_whitespace( self, mcp_registry: MCPToolsRegistry, temp_project_root: Path ) -> None: - """Test that whitespace in target/replacement is preserved.""" - mcp_registry._file_editor_tool.function.return_value = ( # ty: ignore[invalid-assignment] - "Successfully replaced code" - ) - target = " def add(self, a: int, b: int) -> int:" replacement = " def multiply(self, a: int, b: int) -> int:" - await mcp_registry.surgical_replace_code("sample.py", target, replacement) + result = await mcp_registry.surgical_replace_code( + "sample.py", target, replacement + ) - call_args = mcp_registry._file_editor_tool.function.call_args - assert call_args.kwargs["target_code"] == target - assert call_args.kwargs["replacement_code"] == replacement + assert "Success" in result + content = (temp_project_root / "sample.py").read_text() + assert "def multiply(self, a: int, b: int) -> int:" in content class TestSurgicalReplaceErrorHandling: - """Test error handling and failure scenarios.""" - async def test_replace_nonexistent_file( self, mcp_registry: MCPToolsRegistry ) -> None: - """Test replacing code in a nonexistent file.""" - mcp_registry._file_editor_tool.function.return_value = ( # ty: ignore[invalid-assignment] - "Error: File not found: nonexistent.py" - ) - result = await mcp_registry.surgical_replace_code( "nonexistent.py", "target", "replacement" ) - assert "Error:" in result + assert "Failed" in result or "Error" in result async def test_replace_code_not_found( self, mcp_registry: MCPToolsRegistry, temp_project_root: Path ) -> None: - """Test replacing code that doesn't exist in the file.""" - mcp_registry._file_editor_tool.function.return_value = ( # ty: ignore[invalid-assignment] - "Error: Target code not found in file" - ) - target = "def nonexistent_function():" replacement = "def new_function():" @@ -244,14 +152,13 @@ async def test_replace_code_not_found( "sample.py", target, replacement ) - assert "Error:" in result + assert "Failed" in result or "not found" in result.lower() async def test_replace_with_exception( self, mcp_registry: MCPToolsRegistry, temp_project_root: Path ) -> None: - """Test handling of exceptions during replacement.""" - mcp_registry._file_editor_tool.function.side_effect = Exception( # ty: ignore[invalid-assignment] - "Permission denied" + mcp_registry.file_editor.replace_code_block = MagicMock( # type: ignore[method-assign] + side_effect=Exception("Permission denied") ) result = await mcp_registry.surgical_replace_code( @@ -263,101 +170,80 @@ async def test_replace_with_exception( async def test_replace_readonly_file( self, mcp_registry: MCPToolsRegistry, temp_project_root: Path ) -> None: - """Test replacing code in a read-only file.""" readonly_file = temp_project_root / "readonly.py" readonly_file.write_text("def func(): pass", encoding="utf-8") readonly_file.chmod(0o444) - mcp_registry._file_editor_tool.function.return_value = ( # ty: ignore[invalid-assignment] - "Error: Permission denied" - ) - try: result = await mcp_registry.surgical_replace_code( "readonly.py", "def func():", "def new_func():" ) - assert "Error:" in result + assert "Error" in result or "Failed" in result finally: readonly_file.chmod(0o644) class TestSurgicalReplacePathHandling: - """Test path handling and security.""" - async def test_replace_in_subdirectory( self, mcp_registry: MCPToolsRegistry, temp_project_root: Path ) -> None: - """Test replacing code in a file in a subdirectory.""" subdir = temp_project_root / "subdir" subdir.mkdir() sub_file = subdir / "module.py" sub_file.write_text("def func(): pass", encoding="utf-8") - mcp_registry._file_editor_tool.function.return_value = ( # ty: ignore[invalid-assignment] - "Successfully replaced code" - ) - result = await mcp_registry.surgical_replace_code( "subdir/module.py", "def func():", "def new_func():" ) - assert "Error:" not in result + assert "Success" in result async def test_replace_prevents_directory_traversal( self, mcp_registry: MCPToolsRegistry, temp_project_root: Path ) -> None: - """Test that directory traversal is prevented.""" - mcp_registry._file_editor_tool.function.return_value = ( # ty: ignore[invalid-assignment] - "Error: Security risk - path traversal detected" - ) - result = await mcp_registry.surgical_replace_code( "../../../etc/passwd", "root:", "hacked:" ) - assert "Error:" in result or "Security" in result + assert "Error" in result or "Failed" in result or "Security" in result class TestSurgicalReplaceIntegration: - """Test integration scenarios.""" - async def test_multiple_replacements_in_sequence( self, mcp_registry: MCPToolsRegistry, temp_project_root: Path ) -> None: - """Test performing multiple replacements sequentially.""" - mcp_registry._file_editor_tool.function.return_value = ( # ty: ignore[invalid-assignment] - "Successfully replaced code" + result1 = await mcp_registry.surgical_replace_code( + "sample.py", 'print("Hello, World!")', 'print("Hi!")' ) - - replacements = [ - ('print("Hello, World!")', 'print("Hi!")'), - ("def add(", "def addition("), - ("def subtract(", "def subtraction("), - ] - - for target, replacement in replacements: - result = await mcp_registry.surgical_replace_code( - "sample.py", target, replacement - ) - assert "Error:" not in result + assert "Success" in result1 async def test_replace_verifies_parameters_passed( self, mcp_registry: MCPToolsRegistry, temp_project_root: Path ) -> None: - """Test that all parameters are correctly passed to underlying tool.""" - mcp_registry._file_editor_tool.function.return_value = "Success" # ty: ignore[invalid-assignment] + original_replace = mcp_registry.file_editor.replace_code_block + call_args: dict[str, str] = {} + + def capture_args(*args: str, **kwargs: str) -> bool: + call_args["file_path"] = args[0] + call_args["target_code"] = args[1] + call_args["replacement_code"] = args[2] + return original_replace(*args, **kwargs) # type: ignore[return-value] + + mcp_registry.file_editor.replace_code_block = capture_args # type: ignore[method-assign] + + test_file = temp_project_root / "test.py" + test_file.write_text("old_code = 1", encoding="utf-8") await mcp_registry.surgical_replace_code("test.py", "old_code", "new_code") - mcp_registry._file_editor_tool.function.assert_called_once_with( - file_path="test.py", target_code="old_code", replacement_code="new_code" - ) + assert call_args["file_path"] == "test.py" + assert call_args["target_code"] == "old_code" + assert call_args["replacement_code"] == "new_code" async def test_replace_different_file_types( self, mcp_registry: MCPToolsRegistry, temp_project_root: Path ) -> None: - """Test replacing code in different file types.""" files = { "script.js": "function hello() { console.log('hi'); }", "style.css": "body { color: blue; }", @@ -366,7 +252,6 @@ async def test_replace_different_file_types( for filename, content in files.items(): (temp_project_root / filename).write_text(content, encoding="utf-8") - mcp_registry._file_editor_tool.function.return_value = "Success" # ty: ignore[invalid-assignment] result = await mcp_registry.surgical_replace_code( filename, list(content.split())[0], "replacement" diff --git a/codebase_rag/tools/code_retrieval.py b/codebase_rag/tools/code_retrieval.py index b707ee0f9..204cbdb05 100644 --- a/codebase_rag/tools/code_retrieval.py +++ b/codebase_rag/tools/code_retrieval.py @@ -1,22 +1,20 @@ from pathlib import Path from loguru import logger -from pydantic_ai import Tool +from pydantic_ai import RunContext +from ..deps import RAGDeps from ..schemas import CodeSnippet from ..services import QueryProtocol class CodeRetriever: - """Service to retrieve code snippets using the graph and filesystem.""" - def __init__(self, project_root: str, ingestor: QueryProtocol): self.project_root = Path(project_root).resolve() self.ingestor = ingestor logger.info(f"CodeRetriever initialized with root: {self.project_root}") async def find_code_snippet(self, qualified_name: str) -> CodeSnippet: - """Finds a code snippet by querying the graph for its location.""" logger.info(f"[CodeRetriever] Searching for: {qualified_name}") query = """ @@ -84,15 +82,11 @@ async def find_code_snippet(self, qualified_name: str) -> CodeSnippet: ) -def create_code_retrieval_tool(code_retriever: CodeRetriever) -> Tool: - """Factory function to create the code snippet retrieval tool.""" - - async def get_code_snippet(qualified_name: str) -> CodeSnippet: - """Retrieves the source code for a given qualified name.""" - logger.info(f"[Tool:GetCode] Retrieving code for: {qualified_name}") - return await code_retriever.find_code_snippet(qualified_name) - - return Tool( - function=get_code_snippet, - description="Retrieves the source code for a specific function, class, or method using its full qualified name.", - ) +async def get_code_snippet( + ctx: RunContext[RAGDeps], qualified_name: str +) -> CodeSnippet: + """ + Retrieves the source code for a specific function, class, or method using its full qualified name. + """ + logger.info(f"[Tool:GetCode] Retrieving code for: {qualified_name}") + return await ctx.deps.code_retriever.find_code_snippet(qualified_name) diff --git a/codebase_rag/tools/codebase_query.py b/codebase_rag/tools/codebase_query.py index ca2e223d0..3ea85a2af 100644 --- a/codebase_rag/tools/codebase_query.py +++ b/codebase_rag/tools/codebase_query.py @@ -1,103 +1,81 @@ from loguru import logger -from pydantic_ai import Tool -from rich.console import Console +from pydantic_ai import RunContext from rich.panel import Panel from rich.table import Table +from ..deps import RAGDeps +from ..exceptions import LLMGenerationError from ..schemas import GraphData -from ..services import QueryProtocol -from ..services.llm import CypherGenerator, LLMGenerationError -class GraphQueryError(Exception): - """Custom exception for graph query failures.""" - - pass - - -def create_query_tool( - ingestor: QueryProtocol, - cypher_gen: CypherGenerator, - console: Console | None = None, -) -> Tool: +async def query_codebase_knowledge_graph( + ctx: RunContext[RAGDeps], natural_language_query: str +) -> GraphData: """ - Factory function that creates the knowledge graph query tool, - injecting its dependencies. - """ - if console is None: - console = Console(width=None, force_terminal=True) - - async def query_codebase_knowledge_graph(natural_language_query: str) -> GraphData: - """ - Queries the codebase knowledge graph using natural language. + Queries the codebase knowledge graph using natural language. - Provide your question in plain English about the codebase structure, - functions, classes, dependencies, or relationships. The tool will - automatically translate your natural language question into the - appropriate database query and return the results. + Provide your question in plain English about the codebase structure, + functions, classes, dependencies, or relationships. The tool will + automatically translate your natural language question into the + appropriate database query and return the results. - Examples: - - "Find all functions that call each other" - - "What classes are in the user authentication module" - - "Show me functions with the longest call chains" - - "Which files contain functions related to database operations" - """ - logger.info(f"[Tool:QueryGraph] Received NL query: '{natural_language_query}'") - cypher_query = "N/A" - try: - cypher_query = await cypher_gen.generate(natural_language_query) + Examples: + - "Find all functions that call each other" + - "What classes are in the user authentication module" + - "Show me functions with the longest call chains" + - "Which files contain functions related to database operations" + """ + logger.info(f"[Tool:QueryGraph] Received NL query: '{natural_language_query}'") + cypher_query = "N/A" + try: + cypher_query = await ctx.deps.cypher_generator.generate(natural_language_query) - results = ingestor.fetch_all(cypher_query) + results = ctx.deps.ingestor.fetch_all(cypher_query) - if results: - table = Table( - show_header=True, - header_style="bold magenta", - ) - headers = results[0].keys() - for header in headers: - table.add_column(header) + if results: + table = Table( + show_header=True, + header_style="bold magenta", + ) + headers = results[0].keys() + for header in headers: + table.add_column(header) - for row in results: - renderable_values = [] - for value in row.values(): - if value is None: - renderable_values.append("") - elif isinstance(value, bool): - renderable_values.append("✓" if value else "✗") - elif isinstance(value, int | float): - renderable_values.append(str(value)) - else: - renderable_values.append(str(value)) - table.add_row(*renderable_values) + for row in results: + renderable_values = [] + for value in row.values(): + if value is None: + renderable_values.append("") + elif isinstance(value, bool): + renderable_values.append("✓" if value else "✗") + elif isinstance(value, int | float): + renderable_values.append(str(value)) + else: + renderable_values.append(str(value)) + table.add_row(*renderable_values) - console.print( - Panel( - table, - title="[bold blue]Cypher Query Results[/bold blue]", - expand=False, - ) + ctx.deps.console.print( + Panel( + table, + title="[bold blue]Cypher Query Results[/bold blue]", + expand=False, ) - - summary = f"Successfully retrieved {len(results)} item(s) from the graph." - return GraphData(query_used=cypher_query, results=results, summary=summary) - except LLMGenerationError as e: - return GraphData( - query_used="N/A", - results=[], - summary=f"I couldn't translate your request into a database query. Error: {e}", - ) - except Exception as e: - logger.error( - f"[Tool:QueryGraph] Error during query execution: {e}", exc_info=True - ) - return GraphData( - query_used=cypher_query, - results=[], - summary=f"There was an error querying the database: {e}", ) - return Tool( - function=query_codebase_knowledge_graph, - description="Query the codebase knowledge graph using natural language questions. Ask in plain English about classes, functions, methods, dependencies, or code structure. Examples: 'Find all functions that call each other', 'What classes are in the user module', 'Show me functions with the longest call chains'.", - ) + summary = f"Successfully retrieved {len(results)} item(s) from the graph." + return GraphData(query_used=cypher_query, results=results, summary=summary) + except LLMGenerationError as e: + return GraphData( + query_used="N/A", + results=[], + summary=f"I couldn't translate your request into a database query. Error: {e}", + ) + except Exception as e: + logger.error( + f"[Tool:QueryGraph] Error during query execution: {e}", exc_info=True + ) + return GraphData( + query_used=cypher_query, + results=[], + summary=f"There was an error querying the database: {e}", + ) diff --git a/codebase_rag/tools/directory_lister.py b/codebase_rag/tools/directory_lister.py index 888f5bada..28fd2c313 100644 --- a/codebase_rag/tools/directory_lister.py +++ b/codebase_rag/tools/directory_lister.py @@ -2,7 +2,9 @@ from pathlib import Path from loguru import logger -from pydantic_ai import Tool +from pydantic_ai import RunContext + +from ..deps import RAGDeps class DirectoryLister: @@ -10,9 +12,6 @@ def __init__(self, project_root: str): self.project_root = Path(project_root).resolve() def list_directory_contents(self, directory_path: str) -> str: - """ - Lists the contents of a specified directory. - """ target_path = self._get_safe_path(directory_path) logger.info(f"Listing contents of directory: {target_path}") @@ -30,10 +29,6 @@ def list_directory_contents(self, directory_path: str) -> str: return f"Error: Could not list contents of '{directory_path}'." def _get_safe_path(self, file_path: str) -> Path: - """ - Resolves the file path relative to the root and ensures it's within - the project directory. - """ if Path(file_path).is_absolute(): safe_path = Path(file_path).resolve() else: @@ -54,8 +49,8 @@ def _get_safe_path(self, file_path: str) -> Path: return safe_path -def create_directory_lister_tool(directory_lister: DirectoryLister) -> Tool: - return Tool( - function=directory_lister.list_directory_contents, - description="Lists the contents of a directory to explore the codebase.", - ) +def list_directory(ctx: RunContext[RAGDeps], directory_path: str) -> str: + """ + Lists the contents of a directory to explore the codebase. + """ + return ctx.deps.directory_lister.list_directory_contents(directory_path) diff --git a/codebase_rag/tools/document_analyzer.py b/codebase_rag/tools/document_analyzer.py index 2b1198d39..c6caa3c21 100644 --- a/codebase_rag/tools/document_analyzer.py +++ b/codebase_rag/tools/document_analyzer.py @@ -7,14 +7,13 @@ from google.genai import types from google.genai.errors import ClientError from loguru import logger -from pydantic_ai import Tool +from pydantic_ai import RunContext from ..config import settings +from ..deps import RAGDeps class _NotSupportedClient: - """Placeholder client that raises NotImplementedError for unsupported providers.""" - def __getattr__(self, name: str) -> None: raise NotImplementedError( "DocumentAnalyzer does not support the 'local' LLM provider." @@ -22,11 +21,6 @@ def __getattr__(self, name: str) -> None: class DocumentAnalyzer: - """ - A tool to perform multimodal analysis on documents like PDFs - by making a direct call to the Gemini API. - """ - def __init__(self, project_root: str) -> None: self.project_root = Path(project_root).resolve() @@ -47,10 +41,6 @@ def __init__(self, project_root: str) -> None: logger.info(f"DocumentAnalyzer initialized with root: {self.project_root}") def analyze(self, file_path: str, question: str) -> str: - """ - Reads a document (e.g., PDF), sends it to the Gemini multimodal endpoint - with a specific question, and returns the model's analysis. - """ logger.info( f"[DocumentAnalyzer] Analyzing '{file_path}' with question: '{question}'" ) @@ -134,34 +124,25 @@ def analyze(self, file_path: str, question: str) -> str: return f"An error occurred during analysis: {e}" -def create_document_analyzer_tool(analyzer: DocumentAnalyzer) -> Tool: - """Factory function to create the document analyzer tool.""" - - def analyze_document(file_path: str, question: str) -> str: - """ - Analyzes a document (like a PDF) to answer a specific question about its content. - Use this tool when a user asks a question that requires understanding the content of a non-source-code file. +def analyze_document(ctx: RunContext[RAGDeps], file_path: str, question: str) -> str: + """ + Analyzes a document (like a PDF) to answer a specific question about its content. + Use this tool when a user asks a question that requires understanding the content of a non-source-code file. - Args: - file_path: The path to the document file (e.g., 'path/to/book.pdf'). - question: The specific question to ask about the document's content. - """ - try: - result = analyzer.analyze(file_path, question) - logger.debug( - f"[analyze_document] Result type: {type(result)}, content: {result[:100] if result else 'None'}..." - ) - return result - except Exception as e: - logger.error( - f"[analyze_document] Exception during analysis: {e}", exc_info=True - ) - if str(e).startswith("Error:") or str(e).startswith("API error:"): - return str(e) - return f"Error during document analysis: {e}" - - return Tool( - function=analyze_document, - name="analyze_document", - description="Analyzes documents (PDFs, images) to answer questions about their content.", - ) + Args: + file_path: The path to the document file (e.g., 'path/to/book.pdf'). + question: The specific question to ask about the document's content. + """ + try: + result = ctx.deps.document_analyzer.analyze(file_path, question) + logger.debug( + f"[analyze_document] Result type: {type(result)}, content: {result[:100] if result else 'None'}..." + ) + return result + except Exception as e: + logger.error( + f"[analyze_document] Exception during analysis: {e}", exc_info=True + ) + if str(e).startswith("Error:") or str(e).startswith("API error:"): + return str(e) + return f"Error during document analysis: {e}" diff --git a/codebase_rag/tools/file_editor.py b/codebase_rag/tools/file_editor.py index 6e4390a42..431d8d2df 100644 --- a/codebase_rag/tools/file_editor.py +++ b/codebase_rag/tools/file_editor.py @@ -5,9 +5,10 @@ import diff_match_patch from loguru import logger from pydantic import BaseModel -from pydantic_ai import Tool +from pydantic_ai import RunContext from tree_sitter import Node, Parser +from ..deps import RAGDeps from ..language_config import get_language_config from ..parser_loader import load_parsers @@ -40,8 +41,6 @@ class FunctionMatch(TypedDict): class EditResult(BaseModel): - """Data model for file edit results.""" - file_path: str success: bool error_message: str | None = None @@ -55,7 +54,6 @@ def __init__(self, project_root: str = ".") -> None: logger.info(f"FileEditor initialized with root: {self.project_root}") def _get_real_extension(self, file_path_obj: Path) -> str: - """Gets the file extension, looking past a .tmp suffix if present.""" extension = file_path_obj.suffix if extension == ".tmp": base_name = file_path_obj.stem @@ -248,7 +246,6 @@ def get_diff( return "".join(diff) def apply_patch_to_file(self, file_path: str, patch_text: str) -> bool: - """Apply a patch to a file using diff-match-patch.""" try: with open(file_path, encoding="utf-8") as f: original_content = f.read() @@ -274,7 +271,6 @@ def apply_patch_to_file(self, file_path: str, patch_text: str) -> bool: def replace_code_block( self, file_path: str, target_block: str, replacement_block: str ) -> bool: - """Surgically replace a specific code block in a file using diff-match-patch.""" logger.info( f"[FileEditor] Attempting surgical block replacement in: {file_path}" ) @@ -334,7 +330,6 @@ def replace_code_block( return False async def edit_file(self, file_path: str, new_content: str) -> EditResult: - """Overwrites entire file with new content - use for full file replacement.""" logger.info(f"[FileEditor] Attempting full file replacement: {file_path}") try: full_path = (self.project_root / file_path).resolve() @@ -369,35 +364,26 @@ async def edit_file(self, file_path: str, new_content: str) -> EditResult: ) -def create_file_editor_tool(file_editor: FileEditor) -> Tool: - """Factory function to create the file editor tool.""" - - async def replace_code_surgically( - file_path: str, target_code: str, replacement_code: str - ) -> str: - """ - Surgically replaces a specific code block in a file using diff-match-patch. - This tool finds the exact target code block and replaces only that section, - leaving the rest of the file completely unchanged. This is true surgical patching. - - Args: - file_path: Path to the file to modify - target_code: The exact code block to find and replace (must match exactly) - replacement_code: The new code to replace the target with - - Use this when you need to change specific functions, classes, or code blocks - without affecting the rest of the file. The target_code must be an exact match. - """ - success = file_editor.replace_code_block( - file_path, target_code, replacement_code - ) - if success: - return f"Successfully applied surgical code replacement in: {file_path}" - else: - return f"Failed to apply surgical replacement in {file_path}. Target code not found or patches failed." - - return Tool( - function=replace_code_surgically, - description="Surgically replaces specific code blocks in files. Requires exact target code and replacement. Only modifies the specified block, leaving rest of file unchanged. True surgical patching.", - requires_approval=True, +async def replace_code_surgically( + ctx: RunContext[RAGDeps], file_path: str, target_code: str, replacement_code: str +) -> str: + """ + Surgically replaces a specific code block in a file using diff-match-patch. + This tool finds the exact target code block and replaces only that section, + leaving the rest of the file completely unchanged. This is true surgical patching. + + Args: + file_path: Path to the file to modify + target_code: The exact code block to find and replace (must match exactly) + replacement_code: The new code to replace the target with + + Use this when you need to change specific functions, classes, or code blocks + without affecting the rest of the file. The target_code must be an exact match. + """ + success = ctx.deps.file_editor.replace_code_block( + file_path, target_code, replacement_code ) + if success: + return f"Successfully applied surgical code replacement in: {file_path}" + else: + return f"Failed to apply surgical replacement in {file_path}. Target code not found or patches failed." diff --git a/codebase_rag/tools/file_reader.py b/codebase_rag/tools/file_reader.py index dff60bdeb..71b4e54ee 100644 --- a/codebase_rag/tools/file_reader.py +++ b/codebase_rag/tools/file_reader.py @@ -2,20 +2,18 @@ from loguru import logger from pydantic import BaseModel -from pydantic_ai import Tool +from pydantic_ai import RunContext +from ..deps import RAGDeps -class FileReadResult(BaseModel): - """Data model for file read results.""" +class FileReadResult(BaseModel): file_path: str content: str | None = None error_message: str | None = None class FileReader: - """Service to read file content from the filesystem.""" - def __init__(self, project_root: str = "."): self.project_root = Path(project_root).resolve() self.binary_extensions = { @@ -32,7 +30,6 @@ def __init__(self, project_root: str = "."): logger.info(f"FileReader initialized with root: {self.project_root}") async def read_file(self, file_path: str) -> FileReadResult: - """Reads and returns the content of a text-based file.""" logger.info(f"[FileReader] Attempting to read file: {file_path}") try: full_path = (self.project_root / file_path).resolve() @@ -81,20 +78,12 @@ async def read_file(self, file_path: str) -> FileReadResult: ) -def create_file_reader_tool(file_reader: FileReader) -> Tool: - """Factory function to create the file reader tool.""" - - async def read_file_content(file_path: str) -> str: - """ - Reads the content of a specified text-based file (e.g., source code, README.md, config files). - This tool should NOT be used for binary files like PDFs or images. For those, use the 'analyze_document' tool. - """ - result = await file_reader.read_file(file_path) - if result.error_message: - return f"Error: {result.error_message}" - return result.content or "" - - return Tool( - function=read_file_content, - description="Reads the content of text-based files. For documents like PDFs or images, use the 'analyze_document' tool instead.", - ) +async def read_file_content(ctx: RunContext[RAGDeps], file_path: str) -> str: + """ + Reads the content of a specified text-based file (e.g., source code, README.md, config files). + This tool should NOT be used for binary files like PDFs or images. For those, use the 'analyze_document' tool. + """ + result = await ctx.deps.file_reader.read_file(file_path) + if result.error_message: + return f"Error: {result.error_message}" + return result.content or "" diff --git a/codebase_rag/tools/file_writer.py b/codebase_rag/tools/file_writer.py index 69e09bce7..408ad9563 100644 --- a/codebase_rag/tools/file_writer.py +++ b/codebase_rag/tools/file_writer.py @@ -2,26 +2,23 @@ from loguru import logger from pydantic import BaseModel -from pydantic_ai import Tool +from pydantic_ai import RunContext +from ..deps import RAGDeps -class FileCreationResult(BaseModel): - """Data model for file creation results.""" +class FileCreationResult(BaseModel): file_path: str success: bool = True error_message: str | None = None class FileWriter: - """Service to write file content to the filesystem.""" - def __init__(self, project_root: str = "."): self.project_root = Path(project_root).resolve() logger.info(f"FileWriter initialized with root: {self.project_root}") async def create_file(self, file_path: str, content: str) -> FileCreationResult: - """Creates or overwrites a file with the given content.""" logger.info(f"[FileWriter] Creating file: {file_path}") try: full_path = (self.project_root / file_path).resolve() @@ -48,25 +45,18 @@ async def create_file(self, file_path: str, content: str) -> FileCreationResult: ) -def create_file_writer_tool(file_writer: FileWriter) -> Tool: - """Factory function to create the file writer tool.""" - - async def create_new_file(file_path: str, content: str) -> FileCreationResult: - """ - Creates a new file with the specified content. - - IMPORTANT: Before using this tool, you MUST check if the file already exists using - the file reader or directory listing tools. If the file exists, use edit_existing_file - instead to preserve existing content and show diffs. +async def create_new_file( + ctx: RunContext[RAGDeps], file_path: str, content: str +) -> FileCreationResult: + """ + Creates a new file with the specified content. - If the file already exists, it will be completely overwritten WITHOUT showing any diff. - Use this ONLY for creating entirely new files, not for modifying existing ones. - For modifying existing files with diff preview, use edit_existing_file instead. - """ - return await file_writer.create_file(file_path, content) + IMPORTANT: Before using this tool, you MUST check if the file already exists using + the file reader or directory listing tools. If the file exists, use edit_existing_file + instead to preserve existing content and show diffs. - return Tool( - function=create_new_file, - description="Creates a new file with content. IMPORTANT: Check file existence first! Overwrites completely WITHOUT showing diff. Use only for new files, not existing file modifications.", - requires_approval=True, - ) + If the file already exists, it will be completely overwritten WITHOUT showing any diff. + Use this ONLY for creating entirely new files, not for modifying existing ones. + For modifying existing files with diff preview, use edit_existing_file instead. + """ + return await ctx.deps.file_writer.create_file(file_path, content) diff --git a/codebase_rag/tools/semantic_search.py b/codebase_rag/tools/semantic_search.py index 99cc9e9c8..176d13fff 100644 --- a/codebase_rag/tools/semantic_search.py +++ b/codebase_rag/tools/semantic_search.py @@ -1,30 +1,11 @@ from typing import Any from loguru import logger -from pydantic_ai import Tool from ..utils.dependencies import has_semantic_dependencies def semantic_code_search(query: str, top_k: int = 5) -> list[dict[str, Any]]: - """ - Search for functions/methods by natural language intent using semantic embeddings. - - Args: - query: Natural language description of desired functionality - top_k: Number of results to return - - Returns: - List of dictionaries with node information: - [ - { - "node_id": int, - "qualified_name": str, - "type": str, - "score": float - } - ] - """ if not has_semantic_dependencies(): logger.warning( "Semantic search requires 'semantic' extra: uv sync --extra semantic" @@ -87,15 +68,6 @@ def semantic_code_search(query: str, top_k: int = 5) -> list[dict[str, Any]]: def get_function_source_code(node_id: int) -> str | None: - """ - Retrieve source code for a function/method by node ID. - - Args: - node_id: Memgraph node ID - - Returns: - Source code string or None if not found - """ try: from ..config import settings from ..services.graph_service import MemgraphIngestor @@ -141,78 +113,62 @@ def get_function_source_code(node_id: int) -> str | None: return None -def create_semantic_search_tool() -> Tool: - """ - Factory function to create the semantic code search tool. +async def semantic_search_functions(query: str, top_k: int = 5) -> str: """ + Search for functions/methods using natural language descriptions of their purpose. - async def semantic_search_functions(query: str, top_k: int = 5) -> str: - """ - Search for functions/methods using natural language descriptions of their purpose. - - Use this tool when you need to find code that performs specific functionality - based on intent rather than exact names. Perfect for questions like: - - "Find error handling functions" - - "Show me authentication-related code" - - "Where is data validation implemented?" - - "Find functions that handle file I/O" + Use this tool when you need to find code that performs specific functionality + based on intent rather than exact names. Perfect for questions like: + - "Find error handling functions" + - "Show me authentication-related code" + - "Where is data validation implemented?" + - "Find functions that handle file I/O" - Args: - query: Natural language description of the desired functionality - top_k: Maximum number of results to return (default: 5) + Args: + query: Natural language description of the desired functionality + top_k: Maximum number of results to return (default: 5) - Returns: - String describing the found functions with their qualified names and similarity scores - """ - logger.info(f"[Tool:SemanticSearch] Searching for: '{query}'") + Returns: + String describing the found functions with their qualified names and similarity scores + """ + logger.info(f"[Tool:SemanticSearch] Searching for: '{query}'") - results = semantic_code_search(query, top_k) + results = semantic_code_search(query, top_k) - if not results: - return f"No semantic matches found for query: '{query}'. This could mean:\n1. No functions match this description\n2. Semantic search dependencies are not installed\n3. No embeddings have been generated yet" + if not results: + return f"No semantic matches found for query: '{query}'. This could mean:\n1. No functions match this description\n2. Semantic search dependencies are not installed\n3. No embeddings have been generated yet" - formatted_results = [] - for i, result in enumerate(results, 1): - formatted_results.append( - f"{i}. {result['qualified_name']} (type: {result['type']}, score: {result['score']})" - ) - - response = f"Found {len(results)} semantic matches for '{query}':\n\n" - response += "\n".join(formatted_results) - response += "\n\nUse the qualified names above with other tools to get more details or source code." + formatted_results = [] + for i, result in enumerate(results, 1): + formatted_results.append( + f"{i}. {result['qualified_name']} (type: {result['type']}, score: {result['score']})" + ) - return response + response = f"Found {len(results)} semantic matches for '{query}':\n\n" + response += "\n".join(formatted_results) + response += "\n\nUse the qualified names above with other tools to get more details or source code." - return Tool(semantic_search_functions, name="semantic_search_functions") + return response -def create_get_function_source_tool() -> Tool: - """ - Factory function to create the function source code retrieval tool. +async def get_function_source_by_id(node_id: int) -> str: """ + Retrieve the complete source code for a function or method by its node ID. - async def get_function_source_by_id(node_id: int) -> str: - """ - Retrieve the complete source code for a function or method by its node ID. + Use this tool after semantic search to get the actual implementation + of functions you're interested in. - Use this tool after semantic search to get the actual implementation - of functions you're interested in. - - Args: - node_id: The Memgraph node ID of the function/method - - Returns: - The complete source code of the function/method - """ - logger.info( - f"[Tool:GetFunctionSource] Retrieving source for node ID: {node_id}" - ) + Args: + node_id: The Memgraph node ID of the function/method - source_code = get_function_source_code(node_id) + Returns: + The complete source code of the function/method + """ + logger.info(f"[Tool:GetFunctionSource] Retrieving source for node ID: {node_id}") - if source_code is None: - return f"Could not retrieve source code for node ID {node_id}. The node may not exist or source file may be unavailable." + source_code = get_function_source_code(node_id) - return f"Source code for node ID {node_id}:\n\n```\n{source_code}\n```" + if source_code is None: + return f"Could not retrieve source code for node ID {node_id}. The node may not exist or source file may be unavailable." - return Tool(get_function_source_by_id, name="get_function_source_by_id") + return f"Source code for node ID {node_id}:\n\n```\n{source_code}\n```" diff --git a/codebase_rag/tools/shell_command.py b/codebase_rag/tools/shell_command.py index c469c2157..ae018f026 100644 --- a/codebase_rag/tools/shell_command.py +++ b/codebase_rag/tools/shell_command.py @@ -4,11 +4,12 @@ from collections.abc import Awaitable, Callable from functools import wraps from pathlib import Path -from typing import Any, cast +from typing import Any from loguru import logger -from pydantic_ai import Tool +from pydantic_ai import RunContext +from ..deps import RAGDeps from ..schemas import ShellCommandResult COMMAND_ALLOWLIST = { @@ -49,16 +50,11 @@ def _is_dangerous_command(cmd_parts: list[str]) -> bool: - """Checks for dangerous command patterns.""" command = cmd_parts[0] return command == "rm" and "-rf" in cmd_parts def _requires_confirmation(cmd_parts: list[str]) -> tuple[bool, str]: - """ - Checks if a command requires user confirmation. - Returns (requires_confirmation, reason). - """ if not cmd_parts: return False, "" @@ -81,10 +77,6 @@ def _requires_confirmation(cmd_parts: list[str]) -> tuple[bool, str]: def timing_decorator( func: Callable[..., Awaitable[Any]], ) -> Callable[..., Awaitable[Any]]: - """ - A decorator that logs the execution time of the decorated asynchronous function. - """ - @wraps(func) async def wrapper(*args: Any, **kwargs: Any) -> Any: start_time = time.perf_counter() @@ -99,8 +91,6 @@ async def wrapper(*args: Any, **kwargs: Any) -> Any: class ShellCommander: - """Service to execute shell commands.""" - def __init__(self, project_root: str = ".", timeout: int = 30): self.project_root = Path(project_root).resolve() self.timeout = timeout @@ -110,9 +100,6 @@ def __init__(self, project_root: str = ".", timeout: int = 30): async def execute( self, command: str, confirmed: bool = False ) -> ShellCommandResult: - """ - Execute a shell command and return the status code, stdout, and stderr. - """ logger.info(f"Executing shell command: {command}") try: cmd_parts = shlex.split(command) @@ -188,49 +175,36 @@ async def execute( return ShellCommandResult(return_code=-1, stdout="", stderr=str(e)) -def create_shell_command_tool(shell_commander: ShellCommander) -> Tool: - """Factory function to create the shell command tool.""" - - async def run_shell_command( - command: str, user_confirmed: bool = False - ) -> ShellCommandResult: - """ - Executes a shell command from the approved allowlist only. - - Args: - command: The shell command to execute - user_confirmed: Set to True if user has explicitly confirmed this command - - AVAILABLE COMMANDS: - - File operations: ls, cat, find, pwd - - Text search: rg (ripgrep) - USE THIS INSTEAD OF grep - - Version control: git (some subcommands require confirmation) - - Testing: pytest, mypy, ruff - - Package management: uv (requires confirmation) - - File system: rm, cp, mv, mkdir, rmdir (require confirmation) - - Other: echo - - IMPORTANT: Use 'rg' for text searching, NOT 'grep' (grep is not available). - - COMMANDS REQUIRING USER CONFIRMATION: - - File system: rm, cp, mv, mkdir, rmdir - - Package management: uv (any subcommand) - - Git operations: add, commit, push, pull, merge, rebase, reset, checkout, branch, tag, stash, cherry-pick, revert - - Safe git commands (no confirmation needed): status, log, diff, show, ls-files, remote, config - - For dangerous commands: - 1. Call once to check if confirmation needed (will return error if required) - 2. Ask user for approval - 3. Call again with user_confirmed=True to execute - """ - return cast( - ShellCommandResult, - await shell_commander.execute(command, confirmed=user_confirmed), - ) - - return Tool( - function=run_shell_command, - name="execute_shell_command", - description="Executes shell commands from allowlist. For dangerous commands, call twice: first to check if confirmation needed, then with user_confirmed=True after getting approval.", - requires_approval=True, - ) +async def run_shell_command( + ctx: RunContext[RAGDeps], command: str, user_confirmed: bool = False +) -> ShellCommandResult: + """ + Executes a shell command from the approved allowlist only. + + Args: + command: The shell command to execute + user_confirmed: Set to True if user has explicitly confirmed this command + + AVAILABLE COMMANDS: + - File operations: ls, cat, find, pwd + - Text search: rg (ripgrep) - USE THIS INSTEAD OF grep + - Version control: git (some subcommands require confirmation) + - Testing: pytest, mypy, ruff + - Package management: uv (requires confirmation) + - File system: rm, cp, mv, mkdir, rmdir (require confirmation) + - Other: echo + + IMPORTANT: Use 'rg' for text searching, NOT 'grep' (grep is not available). + + COMMANDS REQUIRING USER CONFIRMATION: + - File system: rm, cp, mv, mkdir, rmdir + - Package management: uv (any subcommand) + - Git operations: add, commit, push, pull, merge, rebase, reset, checkout, branch, tag, stash, cherry-pick, revert + - Safe git commands (no confirmation needed): status, log, diff, show, ls-files, remote, config + + For dangerous commands: + 1. Call once to check if confirmation needed (will return error if required) + 2. Ask user for approval + 3. Call again with user_confirmed=True to execute + """ + return await ctx.deps.shell_commander.execute(command, confirmed=user_confirmed)