Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 42 additions & 7 deletions codebase_rag/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -470,11 +470,18 @@ class LanguageMetadata(NamedTuple):
"function_signature_item",
"function_signature",
)
FUNCTION_NODES_GENERATOR = ("generator_function_declaration", "function_expression")
FUNCTION_NODES_GENERATOR = (
"generator_function_declaration",
"function_expression",
)

CLASS_NODES_BASIC = ("class_declaration", "class_definition")
CLASS_NODES_STRUCT = ("struct_declaration", "struct_specifier", "struct_item")
CLASS_NODES_INTERFACE = ("interface_declaration", "trait_declaration", "trait_item")
CLASS_NODES_INTERFACE = (
"interface_declaration",
"trait_declaration",
"trait_item",
)
CLASS_NODES_ENUM = ("enum_declaration", "enum_item", "enum_specifier")
CLASS_NODES_TYPE_ALIAS = ("type_alias_declaration", "type_item")
CLASS_NODES_UNION = ("union_specifier", "union_item")
Expand All @@ -485,8 +492,16 @@ class LanguageMetadata(NamedTuple):
"member_call_expression",
"field_expression",
)
CALL_NODES_OPERATOR = ("binary_expression", "unary_expression", "update_expression")
CALL_NODES_SPECIAL = ("new_expression", "delete_expression", "macro_invocation")
CALL_NODES_OPERATOR = (
"binary_expression",
"unary_expression",
"update_expression",
)
CALL_NODES_SPECIAL = (
"new_expression",
"delete_expression",
"macro_invocation",
)

IMPORT_NODES_STANDARD = ("import_declaration", "import_statement")
IMPORT_NODES_FROM = ("import_from_statement",)
Expand All @@ -503,7 +518,11 @@ class LanguageMetadata(NamedTuple):
"method_definition",
)
JS_TS_CLASS_NODES = ("class_declaration", "class")
JS_TS_IMPORT_NODES = ("import_statement", "lexical_declaration", "export_statement")
JS_TS_IMPORT_NODES = (
"import_statement",
"lexical_declaration",
"export_statement",
)
JS_TS_LANGUAGES = frozenset({SupportedLanguage.JS, SupportedLanguage.TS})

# (H) C++ import node types
Expand Down Expand Up @@ -813,7 +832,11 @@ class Architecture(StrEnum):
MODULE_TRANSFORMERS = "transformers"
MODULE_QDRANT_CLIENT = "qdrant_client"

SEMANTIC_DEPENDENCIES = (MODULE_QDRANT_CLIENT, MODULE_TORCH, MODULE_TRANSFORMERS)
SEMANTIC_DEPENDENCIES = (
MODULE_QDRANT_CLIENT,
MODULE_TORCH,
MODULE_TRANSFORMERS,
)
ML_DEPENDENCIES = (MODULE_TORCH, MODULE_TRANSFORMERS)


Expand Down Expand Up @@ -2109,12 +2132,14 @@ class CppNodeType(StrEnum):
# (H) MCP tool names
class MCPToolName(StrEnum):
INDEX_REPOSITORY = "index_repository"
UPDATE_REPOSITORY = "update_repository"
QUERY_CODE_GRAPH = "query_code_graph"
GET_CODE_SNIPPET = "get_code_snippet"
SURGICAL_REPLACE_CODE = "surgical_replace_code"
READ_FILE = "read_file"
WRITE_FILE = "write_file"
LIST_DIRECTORY = "list_directory"
SEMANTIC_SEARCH = "semantic_search"


# (H) MCP environment variables
Expand Down Expand Up @@ -2151,6 +2176,7 @@ class MCPParamName(StrEnum):
LIMIT = "limit"
CONTENT = "content"
DIRECTORY_PATH = "directory_path"
TOP_K = "top_k"


# (H) MCP server constants
Expand All @@ -2168,6 +2194,12 @@ class MCPParamName(StrEnum):
MCP_WRITE_SUCCESS = "Successfully wrote file: {path}"
MCP_UNKNOWN_TOOL_ERROR = "Unknown tool: {name}"
MCP_TOOL_EXEC_ERROR = "Error executing tool '{name}': {error}"
MCP_UPDATE_SUCCESS = "Successfully updated repository at {path} (no database wipe)."
MCP_UPDATE_ERROR = "Error updating repository: {error}"
MCP_SEMANTIC_NOT_AVAILABLE_RESPONSE = (
"Semantic search is not available. Install with: uv sync --extra semantic"
)


# (H) MCP dict keys and values
MCP_KEY_RESULTS = "results"
Expand Down Expand Up @@ -2500,7 +2532,10 @@ class MCPParamName(StrEnum):
)

# (H) LANGUAGE_SPECS node type tuples for Lua
SPEC_LUA_FUNCTION_TYPES = (TS_LUA_FUNCTION_DECLARATION, TS_LUA_FUNCTION_DEFINITION)
SPEC_LUA_FUNCTION_TYPES = (
TS_LUA_FUNCTION_DECLARATION,
TS_LUA_FUNCTION_DEFINITION,
)
SPEC_LUA_CLASS_TYPES: tuple[str, ...] = ()
SPEC_LUA_MODULE_TYPES = (TS_LUA_CHUNK,)
SPEC_LUA_CALL_TYPES = (TS_LUA_FUNCTION_CALL,)
Expand Down
7 changes: 7 additions & 0 deletions codebase_rag/logs.py
Original file line number Diff line number Diff line change
Expand Up @@ -576,6 +576,13 @@
MCP_ERROR_WRITE = "[MCP] Error writing file: {error}"
MCP_LIST_DIR = "[MCP] list_directory: {path}"
MCP_ERROR_LIST_DIR = "[MCP] Error listing directory: {error}"
MCP_SEMANTIC_NOT_AVAILABLE = (
"[MCP] Semantic search not available. Install with: uv sync --extra semantic"
)
MCP_UPDATING_REPO = "[MCP] Updating repository at: {path}"
MCP_ERROR_UPDATING = "[MCP] Error updating repository: {error}"
MCP_SEMANTIC_SEARCH = "[MCP] semantic_search: {query}"


# (H) MCP server logs
MCP_SERVER_INFERRED_ROOT = "[GraphCode MCP] Using inferred project root: {path}"
Expand Down
85 changes: 83 additions & 2 deletions codebase_rag/mcp/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@
from codebase_rag.services.graph_service import MemgraphIngestor
from codebase_rag.services.llm import CypherGenerator
from codebase_rag.tools import tool_descriptions as td
from codebase_rag.tools.code_retrieval import CodeRetriever, create_code_retrieval_tool
from codebase_rag.tools.code_retrieval import (
CodeRetriever,
create_code_retrieval_tool,
)
from codebase_rag.tools.codebase_query import create_query_tool
from codebase_rag.tools.directory_lister import (
DirectoryLister,
Expand All @@ -29,6 +32,7 @@
MCPToolSchema,
QueryResultDict,
)
from codebase_rag.utils.dependencies import has_semantic_dependencies


class MCPToolsRegistry:
Expand Down Expand Up @@ -61,6 +65,19 @@ def __init__(
directory_lister=self.directory_lister
)

self._semantic_search_tool = None
self._semantic_search_available = False

if has_semantic_dependencies():
from codebase_rag.tools.semantic_search import (
create_semantic_search_tool,
)

self._semantic_search_tool = create_semantic_search_tool()
self._semantic_search_available = True
else:
logger.info(lg.MCP_SEMANTIC_NOT_AVAILABLE)

self._tools: dict[str, ToolMetadata] = {
cs.MCPToolName.INDEX_REPOSITORY: ToolMetadata(
name=cs.MCPToolName.INDEX_REPOSITORY,
Expand All @@ -73,6 +90,17 @@ def __init__(
handler=self.index_repository,
returns_json=False,
),
cs.MCPToolName.UPDATE_REPOSITORY: ToolMetadata(
name=cs.MCPToolName.UPDATE_REPOSITORY,
description=td.MCP_TOOLS[cs.MCPToolName.UPDATE_REPOSITORY],
input_schema=MCPInputSchema(
type=cs.MCPSchemaType.OBJECT,
properties={},
required=[],
),
handler=self.update_repository,
returns_json=False,
),
cs.MCPToolName.QUERY_CODE_GRAPH: ToolMetadata(
name=cs.MCPToolName.QUERY_CODE_GRAPH,
description=td.MCP_TOOLS[cs.MCPToolName.QUERY_CODE_GRAPH],
Expand Down Expand Up @@ -198,6 +226,28 @@ def __init__(
returns_json=False,
),
}
if self._semantic_search_available:
self._tools[cs.MCPToolName.SEMANTIC_SEARCH] = ToolMetadata(
name=cs.MCPToolName.SEMANTIC_SEARCH,
description=td.MCP_TOOLS[cs.MCPToolName.SEMANTIC_SEARCH],
input_schema=MCPInputSchema(
type=cs.MCPSchemaType.OBJECT,
properties={
cs.MCPParamName.NATURAL_LANGUAGE_QUERY: MCPInputSchemaProperty(
type=cs.MCPSchemaType.STRING,
description=td.MCP_PARAM_NATURAL_LANGUAGE_QUERY,
),
cs.MCPParamName.TOP_K: MCPInputSchemaProperty(
type=cs.MCPSchemaType.INTEGER,
description=td.MCP_PARAM_TOP_K,
default="5",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The TOP_K parameter is defined with type INTEGER, but its default value is provided as a string "5". This type mismatch can lead to schema validation errors or unexpected behavior when the tool is used with default parameters. The default value should be an integer to match the specified type.

Suggested change
default="5",
default=5,

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

class MCPInputSchemaProperty(TypedDict, total=False):
type: str
description: str
default: str

),
},
required=[cs.MCPParamName.NATURAL_LANGUAGE_QUERY],
),
handler=self.semantic_search,
returns_json=False,
)

async def index_repository(self) -> str:
logger.info(lg.MCP_INDEXING_REPO.format(path=self.project_root))
Expand All @@ -220,6 +270,23 @@ async def index_repository(self) -> str:
logger.error(lg.MCP_ERROR_INDEXING.format(error=e))
return cs.MCP_INDEX_ERROR.format(error=e)

async def update_repository(self) -> str:
logger.info(lg.MCP_UPDATING_REPO.format(path=self.project_root))

try:
updater = GraphUpdater(
ingestor=self.ingestor,
repo_path=Path(self.project_root),
parsers=self.parsers,
queries=self.queries,
)
updater.run()

return cs.MCP_UPDATE_SUCCESS.format(path=self.project_root)
except Exception as e:
logger.error(lg.MCP_ERROR_UPDATING.format(error=e))
return cs.MCP_UPDATE_ERROR.format(error=e)
Comment on lines +286 to +288
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Catching a broad Exception can obscure the underlying cause of an error and make debugging more difficult. It's better to catch more specific exceptions that you expect updater.run() to raise. This allows for more precise error logging and handling. If GraphUpdater can raise specific custom exceptions, they should be caught here.


async def query_code_graph(self, natural_language_query: str) -> QueryResultDict:
logger.info(lg.MCP_QUERY_CODE_GRAPH.format(query=natural_language_query))
try:
Expand Down Expand Up @@ -278,7 +345,10 @@ async def surgical_replace_code(
return te.ERROR_WRAPPER.format(message=e)

async def read_file(
self, file_path: str, offset: int | None = None, limit: int | None = None
self,
file_path: str,
offset: int | None = None,
limit: int | None = None,
) -> str:
logger.info(lg.MCP_READ_FILE.format(path=file_path, offset=offset, limit=limit))
try:
Expand Down Expand Up @@ -339,6 +409,17 @@ async def list_directory(
logger.error(lg.MCP_ERROR_LIST_DIR.format(error=e))
return te.ERROR_WRAPPER.format(message=e)

async def semantic_search(self, natural_language_query: str, top_k: int = 5) -> str:
if self._semantic_search_tool is None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

You've introduced a _semantic_search_available flag in the constructor, which is a great way to track the availability of this optional feature. For consistency and clarity, it would be better to use this flag here instead of checking if _semantic_search_tool is None.

Suggested change
if self._semantic_search_tool is None:
if not self._semantic_search_available:

return cs.MCP_SEMANTIC_NOT_AVAILABLE_RESPONSE

logger.info(lg.MCP_SEMANTIC_SEARCH.format(query=natural_language_query))

result = await self._semantic_search_tool.function(
query=natural_language_query, top_k=top_k
)
return str(result)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The _semantic_search_tool.function already returns a formatted string. The str() cast here is redundant and can be removed.

Suggested change
return str(result)
return result


def get_tool_schemas(self) -> list[MCPToolSchema]:
return [
MCPToolSchema(
Expand Down
18 changes: 17 additions & 1 deletion codebase_rag/tools/tool_descriptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,12 +72,19 @@ class AgenticToolName(StrEnum):

# (H) MCP tool descriptions
MCP_INDEX_REPOSITORY = (
"WARNING: Clears the entire database including embeddings. "
"Parse and ingest the repository into the Memgraph knowledge graph. "
"This builds a comprehensive graph of functions, classes, dependencies, and relationships."
"Use update_repository for incremental updates. Only use when explicitly requested."
)

MCP_UPDATE_REPOSITORY = (
"Update the repository in the Memgraph knowledge graph without clearing existing data. "
"Use this for incremental updates."
)

MCP_QUERY_CODE_GRAPH = (
"Query the codebase knowledge graph using natural language. "
"Use semantic_search unless you know the exact names of classes/functions you are searching for. "
"Ask questions like 'What functions call UserService.create_user?' or "
"'Show me all classes that implement the Repository interface'."
)
Expand All @@ -92,6 +99,12 @@ class AgenticToolName(StrEnum):
"Only modifies the exact target block, leaving the rest unchanged."
)

MCP_SEMANTIC_SEARCH = (
"Performs a semantic search for functions based on a natural language query "
"describing their purpose, returning a list of potential matches with similarity scores. "
"Requires the 'semantic' extra to be installed."
)

MCP_READ_FILE = (
"Read the contents of a file from the project. Supports pagination for large files."
)
Expand All @@ -111,16 +124,19 @@ class AgenticToolName(StrEnum):
MCP_PARAM_LIMIT = "Maximum number of lines to read (optional)"
MCP_PARAM_CONTENT = "Content to write to the file"
MCP_PARAM_DIRECTORY_PATH = "Relative path to directory from project root (default: '.')"
MCP_PARAM_TOP_K = "Max number of results to return (optional, default: 5)"


MCP_TOOLS: dict[MCPToolName, str] = {
MCPToolName.INDEX_REPOSITORY: MCP_INDEX_REPOSITORY,
MCPToolName.UPDATE_REPOSITORY: MCP_UPDATE_REPOSITORY,
MCPToolName.QUERY_CODE_GRAPH: MCP_QUERY_CODE_GRAPH,
MCPToolName.GET_CODE_SNIPPET: MCP_GET_CODE_SNIPPET,
MCPToolName.SURGICAL_REPLACE_CODE: MCP_SURGICAL_REPLACE_CODE,
MCPToolName.READ_FILE: MCP_READ_FILE,
MCPToolName.WRITE_FILE: MCP_WRITE_FILE,
MCPToolName.LIST_DIRECTORY: MCP_LIST_DIRECTORY,
MCPToolName.SEMANTIC_SEARCH: MCP_SEMANTIC_SEARCH,
}

AGENTIC_TOOLS: dict[AgenticToolName, str] = {
Expand Down