diff --git a/docs/02-usage/042_cursor_navigation.md b/docs/02-usage/042_cursor_navigation.md new file mode 100644 index 000000000..4eaac9af6 --- /dev/null +++ b/docs/02-usage/042_cursor_navigation.md @@ -0,0 +1,185 @@ +# Cursor-Based Code Navigation and Editing + +Serena provides a cursor-based interface for incrementally exploring and editing code +structure through LSP graph edges. Instead of jumping directly to a symbol, you place +a **cursor** on a symbol and then **move** it along relationships like containment, +references, calls, and type hierarchy. Edits (replace body, insert before/after, +rename) are issued at the cursor's current position. + +The cursor is the full MCP-exposed interface for LSP-addressable (symbol-level) +activity: search, navigation, overview, read, and edit. The older path-and-name-path +symbol tools (`find_symbol`, `get_symbols_overview`, `find_referencing_symbols`, +`replace_symbol_body`, `insert_before_symbol`, `insert_after_symbol`, `rename_symbol`) +are still available as Python classes for direct CLI / API use, but are no longer +surfaced as default MCP tools — the cursor primitives cover all of their uses. + +## Concepts + +### Cursor +A cursor is a named position in the code graph. You start it at a symbol, then move +it to neighboring symbols. Each cursor maintains a **trail** of every position it has +visited, giving you a breadcrumb path through the code. + +### Edge Types +Cursors navigate along seven types of edges: + +| Edge Type | Direction | Description | +|---|---|---| +| `contains` | Parent -> Children | Symbols defined inside the current symbol (e.g. methods of a class) | +| `references` | Outgoing | Definitions that the current symbol points to | +| `referenced-by` | Incoming | Symbols that reference the current symbol | +| `calls` | Outgoing | Functions/methods called by the current symbol | +| `called-by` | Incoming | Functions/methods that call the current symbol | +| `inherits` | Upward | Supertypes of the current symbol | +| `inherited-by` | Downward | Subtypes of the current symbol | + +All edge types are enabled by default. You can configure which edges a cursor follows +to focus on specific relationships. + +```{note} +Not all language servers support all edge types. For example, some Python language +servers do not support type hierarchy queries (`inherits`/`inherited-by`). The cursor +will gracefully skip unavailable edges. +``` + +## Tools + +Cursor tools are enabled by default and cover search, navigation, overview, and edit. + +### cursor_start +Start a new cursor at a symbol identified by an exact (unique) name path. Returns the +symbol's neighborhood showing all reachable symbols via active edge types. + +**Parameters:** +- `name_path` (required): name path of the symbol (e.g. `MyClass/my_method`). +- `relative_path` (optional): file path to narrow the symbol search. +- `cursor_id` (optional): explicit cursor ID. Auto-generated if omitted. + +### cursor_find +Pattern / substring search for symbols (multi-match variant of `cursor_start`). If the +pattern matches exactly one symbol a cursor is started there; otherwise the candidate +list is returned so you can refine your pattern. + +**Parameters:** +- `name_path_pattern` (required): name path matching pattern. +- `relative_path` (optional): file or directory to restrict the search. +- `depth` (optional): depth of descendants to include for each match. +- `include_body` (optional): include each match's body. +- `include_kinds` / `exclude_kinds` (optional): LSP symbol-kind filters. +- `substring_matching` (optional): substring-match the last segment of the pattern. +- `max_matches` (optional): cap the match count. +- `cursor_id` (optional): ID to use if the match is unique. +- `max_answer_chars` (optional): cap the returned payload size. + +### cursor_overview +List the top-level symbols in a file. Cursor-first replacement for the old +`get_symbols_overview` tool. + +**Parameters:** +- `relative_path` (required): path to the source file. +- `cursor_id` (optional): ID for the cursor (currently informational). +- `max_answer_chars` (optional): cap output size. + +### cursor_move +Move the cursor to an adjacent symbol. The target should be visible in the cursor's +current neighborhood. + +**Parameters:** +- `cursor_id` (required): the cursor to move. +- `target_name` (required): name of the target symbol from the neighborhood. +- `target_relative_path` (optional): file path to disambiguate. + +### cursor_look +Re-examine the cursor's current position and neighborhood without moving. + +### cursor_configure +Configure which edge types a cursor follows and whether the symbol body is shown. + +**Parameters:** +- `cursor_id` (required). +- `edge_types` (optional): list of edge type names; empty = all. +- `include_body` (optional). + +### cursor_history +Show the trail of symbols visited by a cursor, from start to current position. + +### cursor_close +Close a cursor and free its resources. + +### cursor_replace_body +Replace the body of the symbol at the cursor. The cursor remains on the same symbol +and its stored location is refreshed. + +**Parameters:** +- `cursor_id` (required). +- `body` (required): the new body text. Does NOT include preceding docstrings / + comments / imports. + +### cursor_insert_before +Insert content immediately before the symbol at the cursor. The cursor stays on the +target symbol. + +**Parameters:** +- `cursor_id` (required). +- `body` (required): content to insert. + +### cursor_insert_after +Insert content immediately after the symbol at the cursor. The cursor stays on the +target symbol. + +**Parameters:** +- `cursor_id` (required). +- `body` (required): content to insert. + +### cursor_rename +Rename the symbol at the cursor throughout the codebase using the language server's +refactoring support. The cursor re-anchors to the renamed symbol. + +**Parameters:** +- `cursor_id` (required). +- `new_name` (required). + +## Usage Patterns + +### Exploring a Class +1. `cursor_start` at the class name (or `cursor_find` if you only know a substring). +2. `cursor_move` into a method of interest. +3. `cursor_configure` with `edge_types: ["calls"]` to see what the method calls. +4. `cursor_move` to follow a call chain. +5. `cursor_history` to review your exploration path. + +### Tracing References +1. `cursor_start` at a symbol you want to trace. +2. `cursor_configure` with `edge_types: ["referenced-by"]` to see all usage sites. +3. `cursor_move` to a reference to inspect the calling context. + +### Editing at the Cursor +1. `cursor_start` (or `cursor_find`) to place the cursor on the symbol to edit. +2. Optionally `cursor_configure` with `include_body=True` to review the current body. +3. `cursor_replace_body`, `cursor_insert_before`, `cursor_insert_after`, or + `cursor_rename` to perform the edit. + +## Migration from the Old Symbol Tools + +| Old tool | Cursor equivalent | +|---|---| +| `find_symbol` | `cursor_find` | +| `get_symbols_overview` | `cursor_overview` | +| `find_referencing_symbols` | `cursor_start` + `cursor_configure edge_types=["referenced-by"]` + `cursor_look` | +| `replace_symbol_body` | `cursor_start` / `cursor_find` + `cursor_replace_body` | +| `insert_before_symbol` | `cursor_start` / `cursor_find` + `cursor_insert_before` | +| `insert_after_symbol` | `cursor_start` / `cursor_find` + `cursor_insert_after` | +| `rename_symbol` | `cursor_start` / `cursor_find` + `cursor_rename` | + +The old tools are marked `ToolMarkerOptional` and remain available as Python classes +for direct CLI / API use, and can be re-enabled over MCP by adding their names to +`included_optional_tools` in your context or mode configuration. + +## Multiple Cursors +You can have multiple cursors active simultaneously. Each cursor maintains independent +state (position, trail, edge-type configuration). This is useful for comparing +different parts of the code or maintaining context while exploring a call chain. + +## Disabling Cursor Tools +Cursor tools are enabled by default. To disable them, add the cursor tool names to +the `excluded_tools` list in your project / mode / context configuration. diff --git a/src/serena/agent.py b/src/serena/agent.py index 322db7c86..9fe10bc23 100644 --- a/src/serena/agent.py +++ b/src/serena/agent.py @@ -54,6 +54,7 @@ from solidlsp.ls_config import Language if TYPE_CHECKING: + from serena.cursor import CursorManager from serena.gui_log_viewer import GuiLogViewer log = logging.getLogger(__name__) @@ -528,6 +529,18 @@ def get_language_server_manager_or_raise(self) -> LanguageServerManager: active_project = self.get_active_project_or_raise() return active_project.get_language_server_manager_or_raise() + def get_cursor_manager(self) -> "CursorManager": + """Get or create the CursorManager for cursor-based code navigation.""" + from serena.cursor import CursorManager + + cursor_mgr: CursorManager | None = getattr(self, "_cursor_manager", None) + if cursor_mgr is None: + project = self.get_active_project_or_raise() + project.get_language_server_manager_or_raise() # validate LSP is available + cursor_mgr = CursorManager(project) + self._cursor_manager = cursor_mgr # type: ignore[assignment] + return cursor_mgr + def get_log_inspection_instructions(self) -> str: if self.serena_config.web_dashboard: return f"Live logs can be inspected via the dashboard at {self.get_dashboard_url()}" @@ -832,6 +845,7 @@ def _activate_project(self, project: Project, update_active_modes: bool = True, self._active_project.shutdown() self._active_project = project + self._cursor_manager = None # type: ignore[assignment] # reset cursor manager on project switch project.set_agent(self) if update_active_modes: diff --git a/src/serena/cursor.py b/src/serena/cursor.py new file mode 100644 index 000000000..5ed6b8d12 --- /dev/null +++ b/src/serena/cursor.py @@ -0,0 +1,555 @@ +""" +Cursor-based code navigation for Serena. + +Provides a stateful cursor that can be positioned on a symbol and moved along LSP graph edges +(contains, references, calls, type hierarchy) for incremental exploration. +""" + +import logging +import os +from collections.abc import Sequence +from dataclasses import dataclass, field +from enum import Enum + +from serena.project import Project +from serena.symbol import LanguageServerSymbol, LanguageServerSymbolLocation, LanguageServerSymbolRetriever +from solidlsp.ls_exceptions import SolidLSPException +from solidlsp.ls_utils import PathUtils +from solidlsp.lsp_protocol_handler.lsp_types import ( + CallHierarchyItem, + SymbolKind, + TypeHierarchyItem, +) + +log = logging.getLogger(__name__) + + +class EdgeType(Enum): + """Types of edges available for cursor navigation.""" + + CONTAINS = "contains" + REFERENCES = "references" + REFERENCED_BY = "referenced-by" + CALLS = "calls" + CALLED_BY = "called-by" + INHERITS = "inherits" + INHERITED_BY = "inherited-by" + + +ALL_EDGE_TYPES = frozenset(EdgeType) + +# Edge types that are enabled by default (the most commonly useful ones) +DEFAULT_EDGE_TYPES = frozenset( + { + EdgeType.CONTAINS, + EdgeType.REFERENCES, + EdgeType.REFERENCED_BY, + EdgeType.CALLS, + EdgeType.CALLED_BY, + EdgeType.INHERITS, + EdgeType.INHERITED_BY, + } +) + + +@dataclass +class NeighborSymbol: + """A symbol reachable from the current cursor position via a specific edge type.""" + + name: str + kind: str + relative_path: str | None + line: int | None + column: int | None + edge_type: EdgeType + detail: str | None = None + + @property + def location_str(self) -> str: + if self.relative_path and self.line is not None: + return f"{self.relative_path}:{self.line + 1}" + elif self.relative_path: + return self.relative_path + return "?" + + def format_compact(self) -> str: + parts = [self.name] + if self.kind: + parts.append(f"({self.kind})") + parts.append(f"[{self.location_str}]") + if self.detail: + parts.append(f"— {self.detail}") + return " ".join(parts) + + +@dataclass +class CursorState: + """The state of a single navigation cursor.""" + + cursor_id: str + current_symbol: LanguageServerSymbol + current_location: LanguageServerSymbolLocation + trail: list[LanguageServerSymbolLocation] = field(default_factory=list) + active_edge_types: frozenset[EdgeType] = DEFAULT_EDGE_TYPES + include_body: bool = False + + def record_move(self, new_symbol: LanguageServerSymbol, new_location: LanguageServerSymbolLocation) -> None: + """Record moving the cursor to a new symbol.""" + self.trail.append(self.current_location) + self.current_symbol = new_symbol + self.current_location = new_location + + +class CursorManager: + """ + Manages cursor state and resolves LSP graph edges for navigation. + + Each cursor is identified by a string ID and tracks its current symbol, + trail of visited symbols, and configured edge types. + """ + + def __init__(self, project: Project) -> None: + self._project = project + self._cursors: dict[str, CursorState] = {} + self._next_cursor_id = 1 + + @property + def _retriever(self) -> LanguageServerSymbolRetriever: + return LanguageServerSymbolRetriever(self._project) + + def _generate_cursor_id(self) -> str: + cursor_id = f"c{self._next_cursor_id}" + self._next_cursor_id += 1 + return cursor_id + + def get_cursor(self, cursor_id: str) -> CursorState: + if cursor_id not in self._cursors: + raise ValueError(f"No cursor with id '{cursor_id}'. Active cursors: {list(self._cursors.keys())}") + return self._cursors[cursor_id] + + def list_cursors(self) -> list[str]: + return list(self._cursors.keys()) + + def start_cursor( + self, + name_path: str, + relative_path: str | None = None, + cursor_id: str | None = None, + ) -> tuple[str, CursorState]: + """ + Start a new cursor at a symbol identified by name_path. + + :param name_path: the name path of the symbol (e.g. "MyClass/my_method") + :param relative_path: optional file path to narrow the search + :param cursor_id: optional explicit cursor ID; auto-generated if None + :return: tuple of (cursor_id, cursor_state) + """ + retriever = self._retriever + symbol = retriever.find_unique(name_path, within_relative_path=relative_path) + location = symbol.location + + if cursor_id is None: + cursor_id = self._generate_cursor_id() + elif cursor_id in self._cursors: + raise ValueError( + f"Cursor '{cursor_id}' already exists. Close it first or use a different ID. Active cursors: {list(self._cursors.keys())}" + ) + + state = CursorState( + cursor_id=cursor_id, + current_symbol=symbol, + current_location=location, + ) + self._cursors[cursor_id] = state + return cursor_id, state + + def move_cursor( + self, + cursor_id: str, + target_name: str, + target_relative_path: str | None = None, + ) -> CursorState: + """ + Move a cursor to a neighboring symbol by name. + + The target must be reachable from the current position via one of the active edge types. + If target_relative_path is provided, it narrows the match. + + :param cursor_id: the cursor to move + :param target_name: name (or name_path) of the target symbol + :param target_relative_path: optional file path to disambiguate + :return: updated cursor state + """ + state = self.get_cursor(cursor_id) + + # First, try to find the target among current neighbors + neighbors = self.resolve_neighbors(cursor_id) + candidates = [n for n in neighbors if target_name in n.name or n.name in target_name] + if target_relative_path: + candidates = [n for n in candidates if n.relative_path and target_relative_path in n.relative_path] + + if not candidates: + # Fall back to global symbol search + retriever = self._retriever + symbol = retriever.find_unique(target_name, within_relative_path=target_relative_path) + location = symbol.location + elif len(candidates) == 1: + candidate = candidates[0] + retriever = self._retriever + if candidate.relative_path and candidate.line is not None and candidate.column is not None: + symbol_location = LanguageServerSymbolLocation( + relative_path=candidate.relative_path, + line=candidate.line, + column=candidate.column, + ) + found = retriever.find_by_location(symbol_location) + if found: + symbol = found + location = symbol.location + else: + symbol = retriever.find_unique(candidate.name, within_relative_path=candidate.relative_path) + location = symbol.location + else: + symbol = retriever.find_unique(candidate.name, within_relative_path=candidate.relative_path) + location = symbol.location + else: + # Multiple candidates — try exact name match + exact = [n for n in candidates if n.name == target_name] + if len(exact) == 1: + candidate = exact[0] + else: + names = [f" {n.name} ({n.kind}) [{n.location_str}]" for n in candidates] + raise ValueError(f"Ambiguous target '{target_name}'. Candidates:\n" + "\n".join(names)) + retriever = self._retriever + symbol = retriever.find_unique(candidate.name, within_relative_path=candidate.relative_path) + location = symbol.location + + state.record_move(symbol, location) + return state + + def close_cursor(self, cursor_id: str) -> None: + """Close and remove a cursor.""" + if cursor_id in self._cursors: + del self._cursors[cursor_id] + + def resolve_neighbors(self, cursor_id: str, depth: int = 1) -> list[NeighborSymbol]: + """ + Resolve all neighbors of the cursor's current symbol via active edge types. + + :param cursor_id: the cursor whose neighborhood to resolve + :param depth: traversal depth (currently only 1 is supported) + :return: list of neighbor symbols with their edge types + """ + state = self.get_cursor(cursor_id) + symbol = state.current_symbol + location = state.current_location + neighbors: list[NeighborSymbol] = [] + + if location.relative_path is None or location.line is None or location.column is None: + log.warning(f"Cursor {cursor_id} symbol has incomplete location, cannot resolve neighbors") + return neighbors + + rel_path = location.relative_path + line = location.line + col = location.column + + # Contains: children of the current symbol + if EdgeType.CONTAINS in state.active_edge_types: + for child in symbol.iter_children(): + neighbors.append( + NeighborSymbol( + name=child.name, + kind=child.symbol_kind_name, + relative_path=child.relative_path or rel_path, + line=child.line, + column=child.column, + edge_type=EdgeType.CONTAINS, + ) + ) + + retriever = self._retriever + ls = retriever.get_language_server(rel_path) + failed_edge_types: list[EdgeType] = [] + + # References: symbols that THIS symbol references (definitions it points to) + if EdgeType.REFERENCES in state.active_edge_types: + try: + definitions = ls.request_definition(rel_path, line, col) + for defn in definitions: + defn_rel_path = defn.get("relativePath") + defn_range = defn.get("range", {}) + defn_start = defn_range.get("start", {}) + if defn_rel_path: + # Try to get the symbol name at this location + defn_line = defn_start.get("line", 0) + defn_col = defn_start.get("character", 0) + name = self._symbol_name_at(defn_rel_path, defn_line, defn_col) + neighbors.append( + NeighborSymbol( + name=name, + kind="", + relative_path=defn_rel_path, + line=defn_line, + column=defn_col, + edge_type=EdgeType.REFERENCES, + ) + ) + except SolidLSPException as e: + log.debug(f"Failed to resolve definitions for cursor: {e}") + failed_edge_types.append(EdgeType.REFERENCES) + + # Referenced-by: symbols that reference THIS symbol + if EdgeType.REFERENCED_BY in state.active_edge_types: + try: + ref_symbols = ls.request_referencing_symbols(rel_path, line, col, include_imports=False, include_self=False) + for ref in ref_symbols: + ref_sym = ref.symbol + ref_rel_path = ref_sym["location"].get("relativePath", "") + sel_range = ref_sym.get("selectionRange", {}) + sel_start = sel_range.get("start", {}) + neighbors.append( + NeighborSymbol( + name=ref_sym["name"], + kind=SymbolKind(ref_sym["kind"]).name, + relative_path=ref_rel_path, + line=sel_start.get("line"), + column=sel_start.get("character"), + edge_type=EdgeType.REFERENCED_BY, + ) + ) + except SolidLSPException as e: + log.debug(f"Failed to resolve referencing symbols for cursor: {e}") + failed_edge_types.append(EdgeType.REFERENCED_BY) + + # Calls: symbols that this symbol calls (outgoing calls) + if EdgeType.CALLS in state.active_edge_types: + try: + outgoing = ls.request_call_hierarchy_outgoing(rel_path, line, col) + for outgoing_call in outgoing: + target = outgoing_call["to"] + neighbors.append(self._neighbor_from_hierarchy_item(target, EdgeType.CALLS)) + except SolidLSPException as e: + log.debug(f"Failed to resolve outgoing calls for cursor: {e}") + failed_edge_types.append(EdgeType.CALLS) + + # Called-by: symbols that call this symbol (incoming calls) + if EdgeType.CALLED_BY in state.active_edge_types: + try: + incoming = ls.request_call_hierarchy_incoming(rel_path, line, col) + for incoming_call in incoming: + caller = incoming_call["from"] + neighbors.append(self._neighbor_from_hierarchy_item(caller, EdgeType.CALLED_BY)) + except SolidLSPException as e: + log.debug(f"Failed to resolve incoming calls for cursor: {e}") + failed_edge_types.append(EdgeType.CALLED_BY) + + # Inherits: supertypes of the current symbol + if EdgeType.INHERITS in state.active_edge_types: + try: + supertypes = ls.request_type_hierarchy_supertypes(rel_path, line, col) + for item in supertypes: + neighbors.append(self._neighbor_from_type_hierarchy_item(item, EdgeType.INHERITS)) + except SolidLSPException as e: + log.debug(f"Failed to resolve supertypes for cursor: {e}") + failed_edge_types.append(EdgeType.INHERITS) + + # Inherited-by: subtypes of the current symbol + if EdgeType.INHERITED_BY in state.active_edge_types: + try: + subtypes = ls.request_type_hierarchy_subtypes(rel_path, line, col) + for item in subtypes: + neighbors.append(self._neighbor_from_type_hierarchy_item(item, EdgeType.INHERITED_BY)) + except SolidLSPException as e: + log.debug(f"Failed to resolve subtypes for cursor: {e}") + failed_edge_types.append(EdgeType.INHERITED_BY) + + if failed_edge_types: + names = ", ".join(e.value for e in failed_edge_types) + log.warning(f"Cursor {cursor_id}: {len(failed_edge_types)} edge type(s) failed to resolve: {names}") + + return neighbors + + def _symbol_name_at(self, relative_path: str, line: int, col: int) -> str: + """Try to find the symbol name at a location, falling back to file:line.""" + try: + retriever = self._retriever + location = LanguageServerSymbolLocation(relative_path=relative_path, line=line, column=col) + found = retriever.find_by_location(location) + if found: + return found.name + except Exception as e: + log.debug(f"Could not resolve symbol name at {relative_path}:{line + 1}: {e}") + return f"{os.path.basename(relative_path)}:{line + 1}" + + def _neighbor_from_hierarchy_item( + self, + item: CallHierarchyItem, + edge_type: EdgeType, + ) -> NeighborSymbol: + """Create a NeighborSymbol from a CallHierarchyItem.""" + uri = item["uri"] + rel_path = PathUtils.get_relative_path(PathUtils.uri_to_path(uri), self._project.project_root) + sel_start = item["selectionRange"]["start"] + try: + kind_name = SymbolKind(item["kind"]).name + except ValueError: + kind_name = str(item["kind"]) + return NeighborSymbol( + name=item["name"], + kind=kind_name, + relative_path=rel_path if rel_path else None, + line=sel_start.get("line"), + column=sel_start.get("character"), + edge_type=edge_type, + detail=item.get("detail"), + ) + + def _neighbor_from_type_hierarchy_item( + self, + item: TypeHierarchyItem, + edge_type: EdgeType, + ) -> NeighborSymbol: + """Create a NeighborSymbol from a TypeHierarchyItem.""" + # TypeHierarchyItem and CallHierarchyItem share the same relevant fields + return self._neighbor_from_hierarchy_item(item, edge_type) # type: ignore[arg-type] + + def format_cursor_view(self, cursor_id: str) -> str: + """ + Format the current cursor position and its neighborhood as structured text. + + :param cursor_id: the cursor to format + :return: human-readable text representation + """ + state = self.get_cursor(cursor_id) + symbol = state.current_symbol + location = state.current_location + + lines: list[str] = [] + + # Header: current symbol + loc_str = "" + if location.relative_path and location.line is not None: + loc_str = f" [{location.relative_path}:{location.line + 1}]" + lines.append(f"@ {symbol.get_name_path()} ({symbol.symbol_kind_name}){loc_str}") + lines.append(f" cursor: {state.cursor_id} | trail: {len(state.trail)} steps") + + # Body (if configured) + if state.include_body and symbol.body: + lines.append("") + lines.append("--- body ---") + lines.append(symbol.body) + lines.append("--- end body ---") + + # Neighbors grouped by edge type + neighbors = self.resolve_neighbors(cursor_id) + neighbors_by_edge: dict[EdgeType, list[NeighborSymbol]] = {} + for n in neighbors: + neighbors_by_edge.setdefault(n.edge_type, []).append(n) + + if neighbors_by_edge: + lines.append("") + for edge_type in EdgeType: + edge_neighbors = neighbors_by_edge.get(edge_type) + if edge_neighbors: + lines.append(f" {edge_type.value}:") + for n in edge_neighbors: + lines.append(f" {n.format_compact()}") + else: + lines.append("") + lines.append(" (no neighbors found)") + + lines.append("") + lines.append("Use cursor_move to navigate to a neighbor, cursor_look to re-examine.") + + return "\n".join(lines) + + def find_symbols( + self, + name_path_pattern: str, + relative_path: str | None = None, + include_kinds: Sequence[SymbolKind] | None = None, + exclude_kinds: Sequence[SymbolKind] | None = None, + substring_matching: bool = False, + ) -> list[LanguageServerSymbol]: + """ + Pattern/substring search for symbols. Delegates to the language server symbol retriever + (same backend as the old ``find_symbol`` tool). + """ + return self._retriever.find( + name_path_pattern, + include_kinds=include_kinds, + exclude_kinds=exclude_kinds, + substring_matching=substring_matching, + within_relative_path=relative_path, + ) + + def register_cursor_at_symbol( + self, + symbol: LanguageServerSymbol, + cursor_id: str | None = None, + ) -> tuple[str, CursorState]: + """ + Register a cursor positioned on an already-resolved symbol (e.g. one returned by + ``find_symbols``). Used by ``cursor_find`` when the search is unique. + """ + location = symbol.location + if cursor_id is None: + cursor_id = self._generate_cursor_id() + elif cursor_id in self._cursors: + raise ValueError( + f"Cursor '{cursor_id}' already exists. Close it first or use a different ID. Active cursors: {list(self._cursors.keys())}" + ) + state = CursorState( + cursor_id=cursor_id, + current_symbol=symbol, + current_location=location, + ) + self._cursors[cursor_id] = state + return cursor_id, state + + def reanchor_cursor( + self, + cursor_id: str, + name_path: str | None = None, + relative_path: str | None = None, + ) -> CursorState: + """ + Re-resolve the cursor's current symbol from the language server, after an edit + potentially changed line numbers or the symbol's name. The cursor remains on the + same logical symbol; its stored position is refreshed. + + :param cursor_id: the cursor to re-anchor + :param name_path: override name path (e.g. after a rename); defaults to the current symbol's name path + :param relative_path: override file path; defaults to the current cursor location's file + :return: the updated cursor state + """ + state = self.get_cursor(cursor_id) + resolved_name = name_path if name_path is not None else state.current_symbol.get_name_path() + resolved_path = relative_path if relative_path is not None else state.current_location.relative_path + retriever = self._retriever + symbol = retriever.find_unique(resolved_name, within_relative_path=resolved_path) + state.current_symbol = symbol + state.current_location = symbol.location + return state + + def format_trail(self, cursor_id: str) -> str: + """Format the cursor's visited trail as text.""" + state = self.get_cursor(cursor_id) + if not state.trail: + return f"Cursor {cursor_id}: no trail (at starting position)" + + lines = [f"Cursor {cursor_id} trail ({len(state.trail)} steps):"] + for i, loc in enumerate(state.trail): + loc_str = "" + if loc.relative_path and loc.line is not None: + loc_str = f"{loc.relative_path}:{loc.line + 1}" + else: + loc_str = str(loc.relative_path or "?") + lines.append(f" {i + 1}. {loc_str}") + + # Current position + cur = state.current_location + if cur.relative_path and cur.line is not None: + lines.append(f" -> {cur.relative_path}:{cur.line + 1} (current)") + + return "\n".join(lines) diff --git a/src/serena/resources/config/internal_modes/jetbrains.yml b/src/serena/resources/config/internal_modes/jetbrains.yml index a08b189f2..e46817763 100644 --- a/src/serena/resources/config/internal_modes/jetbrains.yml +++ b/src/serena/resources/config/internal_modes/jetbrains.yml @@ -12,6 +12,19 @@ excluded_tools: - restart_language_server - safe_delete_symbol - rename_symbol + # Cursor tools depend on LSP graph edges; disabled under the JetBrains backend. + - cursor_start + - cursor_move + - cursor_look + - cursor_configure + - cursor_history + - cursor_close + - cursor_find + - cursor_overview + - cursor_replace_body + - cursor_insert_before + - cursor_insert_after + - cursor_rename included_optional_tools: - jet_brains_find_declaration - jet_brains_find_implementations diff --git a/src/serena/resources/config/modes/editing.yml b/src/serena/resources/config/modes/editing.yml index 57d800b58..b90ab2cc4 100644 --- a/src/serena/resources/config/modes/editing.yml +++ b/src/serena/resources/config/modes/editing.yml @@ -16,15 +16,20 @@ prompt: | The symbol-based approach is appropriate if you need to adjust an entire symbol, e.g. a method, a class, a function, etc. It is not appropriate if you need to adjust just a few lines of code within a larger symbol. - **Symbolic editing** - Use symbolic retrieval tools to identify the symbols you need to edit. - If you need to replace the definition of a symbol, use the `replace_symbol_body` tool. - If you want to add some new code at the end of the file, use the `insert_after_symbol` tool with the last top-level symbol in the file. - Similarly, you can use `insert_before_symbol` with the first top-level symbol in the file to insert code at the beginning of a file. - You can understand relationships between symbols by using the `find_referencing_symbols` tool. If not explicitly requested otherwise by the user, - you make sure that when you edit a symbol, the change is either backward-compatible or you find and update all references as needed. - The `find_referencing_symbols` tool will give you code snippets around the references as well as symbolic information. - You can assume that all symbol editing tools are reliable, so you never need to verify the results if the tools return without error. + **Symbolic editing via the cursor** + Symbol-level navigation, search and editing go through the cursor: position a cursor on the symbol you want + to work on with `cursor_start` (exact match) or `cursor_find` (pattern / substring search), navigate with + `cursor_move` / `cursor_look` / `cursor_configure` along LSP graph edges (contains, references, referenced-by, + calls, called-by, inherits, inherited-by), and perform edits at the cursor's position: + * `cursor_replace_body` replaces the body of the symbol at the cursor + * `cursor_insert_before` / `cursor_insert_after` insert code around the symbol at the cursor + * `cursor_rename` renames the symbol at the cursor throughout the codebase + `cursor_overview` lists top-level symbols in a file (cursor-first form of the old `get_symbols_overview`). + Use `cursor_configure` with `edge_types: ["referenced-by"]` to understand usage before editing. + If not explicitly requested otherwise by the user, make sure that when you edit a symbol, the change is either + backward-compatible or you find and update all references as needed. + You can assume that all cursor editing operations are reliable, so you never need to verify the results if the + tool returns without error. {% if 'replace_content' in available_tools %} **File-based editing** diff --git a/src/serena/resources/config/modes/onboarding.yml b/src/serena/resources/config/modes/onboarding.yml index dff454036..5336f06b2 100644 --- a/src/serena/resources/config/modes/onboarding.yml +++ b/src/serena/resources/config/modes/onboarding.yml @@ -10,6 +10,11 @@ excluded_tools: - replace_symbol_body - insert_after_symbol - insert_before_symbol + - rename_symbol + - cursor_replace_body + - cursor_insert_after + - cursor_insert_before + - cursor_rename - delete_lines - replace_lines - insert_at_line diff --git a/src/serena/resources/config/modes/planning.yml b/src/serena/resources/config/modes/planning.yml index a24d0dfb1..8003bee34 100644 --- a/src/serena/resources/config/modes/planning.yml +++ b/src/serena/resources/config/modes/planning.yml @@ -7,6 +7,11 @@ excluded_tools: - replace_symbol_body - insert_after_symbol - insert_before_symbol + - rename_symbol + - cursor_replace_body + - cursor_insert_after + - cursor_insert_before + - cursor_rename - delete_lines - replace_lines - insert_at_line diff --git a/src/serena/resources/config/prompt_templates/system_prompt.yml b/src/serena/resources/config/prompt_templates/system_prompt.yml index debefcd26..3d0e2c74d 100644 --- a/src/serena/resources/config/prompt_templates/system_prompt.yml +++ b/src/serena/resources/config/prompt_templates/system_prompt.yml @@ -24,8 +24,16 @@ prompts: {% endif %} {% if 'ToolMarkerSymbolicRead' in available_markers %} - Symbols are identified by their `name_path` and `relative_path` (see the description of the `find_symbol` tool). - You can get information about the symbols in a file by using the `get_symbols_overview` tool or use the `find_symbol` to search. + Symbols are identified by a `name_path` within a source file (e.g. `MyClass/my_method`) and an optional `relative_path`. + {% if 'cursor_start' in available_tools %} + The cursor is the primary interface for symbol-level activity: place a cursor at a symbol with `cursor_start` + (exact-match, unique) or `cursor_find` (pattern / substring search), then navigate along LSP graph edges with + `cursor_move` / `cursor_look` / `cursor_configure` (edge types: contains, references, referenced-by, calls, + called-by, inherits, inherited-by). Use `cursor_overview` to list top-level symbols in a file. + You only read symbol bodies when you need to — set `include_body=True` on `cursor_configure` (or `cursor_find`) + for the current position only. + {% elif 'find_symbol' in available_tools %} + You can get information about the symbols in a file by using the `get_symbols_overview` tool or use the `find_symbol` to search. You only read the bodies of symbols when you need to (e.g. if you want to fully understand or edit it). For example, if you are working with Python code and already know that you need to read the body of the constructor of the class Foo, you can directly use `find_symbol` with name path pattern `Foo/__init__` and `include_body=True`. If you don't know yet which methods in `Foo` you need to read or edit, @@ -33,6 +41,7 @@ prompts: to read the desired methods with `include_body=True`. You can understand relationships between symbols by using the `find_referencing_symbols` tool. {% endif %} + {% endif %} {% if 'read_memory' in available_tools -%} You generally have access to memories and it may be useful for you to read them. diff --git a/src/serena/tools/__init__.py b/src/serena/tools/__init__.py index 08bb4338f..a91e5191f 100644 --- a/src/serena/tools/__init__.py +++ b/src/serena/tools/__init__.py @@ -8,3 +8,4 @@ from .workflow_tools import * from .jetbrains_tools import * from .query_project_tools import * +from .cursor_tools import * diff --git a/src/serena/tools/cursor_tools.py b/src/serena/tools/cursor_tools.py new file mode 100644 index 000000000..531867a02 --- /dev/null +++ b/src/serena/tools/cursor_tools.py @@ -0,0 +1,455 @@ +""" +Cursor-based code navigation tools. + +These tools provide stateful, incremental navigation through code structure +using LSP graph edges (containment, references, call hierarchy, type hierarchy). +They also cover pattern-based symbol search (``cursor_find``) and symbol-level +editing (``cursor_replace_body``, ``cursor_insert_before``, ``cursor_insert_after``, +``cursor_rename``) — so the cursor is the full MCP-exposed interface for +LSP-addressable (symbol-level) activity. +""" + +from collections import defaultdict +from collections.abc import Sequence + +from serena.cursor import EdgeType +from serena.tools import SUCCESS_RESULT +from serena.tools.tools_base import Tool, ToolMarkerSymbolicEdit, ToolMarkerSymbolicRead +from solidlsp.ls_types import SymbolKind + + +class CursorStartTool(Tool, ToolMarkerSymbolicRead): + """ + Start a navigation cursor at a symbol for incremental code exploration. + The cursor tracks your position and lets you navigate along code relationships + (containment, references, calls, type hierarchy). + """ + + def apply( + self, + name_path: str, + relative_path: str = "", + cursor_id: str = "", + ) -> str: + """ + Start a new cursor at the specified symbol. Returns the symbol's neighborhood + showing all reachable symbols via active edge types. + + :param name_path: name path of the symbol to start at (e.g. "MyClass/my_method"). + See find_symbol for name path pattern syntax. + :param relative_path: optional file path to narrow the symbol search. + :param cursor_id: optional explicit cursor ID. Auto-generated if empty. + :return: the cursor view showing the symbol and its neighborhood. + """ + manager = self.agent.get_cursor_manager() + cid, _state = manager.start_cursor( + name_path=name_path, + relative_path=relative_path or None, + cursor_id=cursor_id or None, + ) + return manager.format_cursor_view(cid) + + +class CursorMoveTool(Tool, ToolMarkerSymbolicRead): + """ + Move a navigation cursor to an adjacent symbol. The target should be + visible in the cursor's current neighborhood (from cursor_start or cursor_look output). + """ + + def apply( + self, + cursor_id: str, + target_name: str, + target_relative_path: str = "", + ) -> str: + """ + Move the cursor to a neighboring symbol. Returns the new position's neighborhood. + + :param cursor_id: the ID of the cursor to move (shown in cursor output). + :param target_name: name of the symbol to move to. Must be visible in the current neighborhood. + :param target_relative_path: optional file path to disambiguate if multiple neighbors share the same name. + :return: the updated cursor view at the new position. + """ + manager = self.agent.get_cursor_manager() + manager.move_cursor( + cursor_id=cursor_id, + target_name=target_name, + target_relative_path=target_relative_path or None, + ) + return manager.format_cursor_view(cursor_id) + + +class CursorLookTool(Tool, ToolMarkerSymbolicRead): + """ + Look at the neighborhood of the cursor's current position without moving. + Useful for re-examining the current position after changing edge type configuration. + """ + + def apply(self, cursor_id: str) -> str: + """ + Show the current cursor position and its neighborhood. + + :param cursor_id: the ID of the cursor to look from. + :return: the cursor view showing the symbol and its neighborhood. + """ + manager = self.agent.get_cursor_manager() + return manager.format_cursor_view(cursor_id) + + +class CursorConfigureTool(Tool, ToolMarkerSymbolicRead): + """ + Configure which edge types a cursor follows and what information is shown. + Edge types: contains, references, referenced-by, calls, called-by, inherits, inherited-by. + """ + + # noinspection PyDefaultArgument + def apply( + self, + cursor_id: str, + edge_types: list[str] = [], # noqa: B006 + include_body: bool = False, + ) -> str: + """ + Configure the cursor's active edge types and display options. + + :param cursor_id: the ID of the cursor to configure. + :param edge_types: list of edge type names to enable. If empty, all edge types are enabled. + Valid values: contains, references, referenced-by, calls, called-by, inherits, inherited-by. + :param include_body: whether to include the symbol's source code body in the cursor view. + :return: confirmation of the new configuration and updated cursor view. + """ + manager = self.agent.get_cursor_manager() + state = manager.get_cursor(cursor_id) + + if edge_types: + valid_types: set[EdgeType] = set() + for name in edge_types: + try: + valid_types.add(EdgeType(name)) + except ValueError: + valid_names = [e.value for e in EdgeType] + raise ValueError(f"Unknown edge type '{name}'. Valid edge types: {valid_names}") + state.active_edge_types = frozenset(valid_types) + else: + state.active_edge_types = frozenset(EdgeType) + + state.include_body = include_body + + return manager.format_cursor_view(cursor_id) + + +class CursorHistoryTool(Tool, ToolMarkerSymbolicRead): + """ + Show the trail of symbols visited by a cursor, from start to current position. + """ + + def apply(self, cursor_id: str) -> str: + """ + Show the navigation trail for a cursor. + + :param cursor_id: the ID of the cursor. + :return: the formatted trail showing each visited location. + """ + manager = self.agent.get_cursor_manager() + return manager.format_trail(cursor_id) + + +class CursorCloseTool(Tool, ToolMarkerSymbolicRead): + """ + Close a navigation cursor and free its resources. + """ + + def apply(self, cursor_id: str) -> str: + """ + Close a cursor. + + :param cursor_id: the ID of the cursor to close. + :return: confirmation that the cursor was closed. + """ + manager = self.agent.get_cursor_manager() + manager.close_cursor(cursor_id) + return f"Cursor {cursor_id} closed." + + +class CursorFindTool(Tool, ToolMarkerSymbolicRead): + """ + Search for symbols in the codebase by name path pattern (the multi-match variant of + ``cursor_start``, which requires a unique match). If the search yields exactly one + symbol a cursor is started there; otherwise the candidate list is returned so the + caller can disambiguate and follow up with ``cursor_start``. + """ + + # noinspection PyDefaultArgument + def apply( + self, + name_path_pattern: str, + relative_path: str = "", + depth: int = 0, + include_body: bool = False, + include_kinds: list[int] = [], # noqa: B006 + exclude_kinds: list[int] = [], # noqa: B006 + substring_matching: bool = False, + max_matches: int = -1, + cursor_id: str = "", + max_answer_chars: int = -1, + ) -> str: + """ + Search for symbols matching a name path pattern. + + A name path is a path in the symbol tree *within a source file*. + Examples: ``"method"`` (any symbol named ``method``), ``"MyClass/method"`` + (``method`` inside ``MyClass``), ``"/MyClass/method"`` (exact top-level path). + Append ``[i]`` for a specific overload. + + If the pattern uniquely identifies a symbol, a cursor is started on it and the + cursor view is returned. Otherwise, the list of candidate symbols is returned so + you can refine the pattern or call ``cursor_start`` with a more specific one. + + :param name_path_pattern: name path matching pattern. + :param relative_path: optional file or directory to restrict the search to. + :param depth: depth up to which descendants shall be included for each match. Ignored + when ``include_body=True``. Default 0. + :param include_body: whether to include each match's source code. Use judiciously. + :param include_kinds: LSP symbol kind integers to include (empty = all). + :param exclude_kinds: LSP symbol kind integers to exclude. Takes precedence over ``include_kinds``. + :param substring_matching: if True, the last element of the pattern is matched as a + substring (e.g. ``"Foo/get"`` matches ``"Foo/getValue"``). + :param max_matches: maximum permitted matches; -1 (default) means no limit. + :param cursor_id: optional cursor ID to use when the match is unique. Auto-generated otherwise. + :param max_answer_chars: maximum characters for the candidate-list output; -1 means use default. + :return: a cursor view (unique match) or a JSON-formatted candidate list. + """ + if include_body: + depth = 0 + assert max_matches != 0, "max_matches must be > 0 or equal to -1." + parsed_include_kinds: Sequence[SymbolKind] | None = [SymbolKind(k) for k in include_kinds] if include_kinds else None + parsed_exclude_kinds: Sequence[SymbolKind] | None = [SymbolKind(k) for k in exclude_kinds] if exclude_kinds else None + manager = self.agent.get_cursor_manager() + symbols = manager.find_symbols( + name_path_pattern, + relative_path=relative_path or None, + include_kinds=parsed_include_kinds, + exclude_kinds=parsed_exclude_kinds, + substring_matching=substring_matching, + ) + n_matches = len(symbols) + + if n_matches == 0: + return f"No symbols found matching '{name_path_pattern}'." + + if n_matches == 1: + cid, _ = manager.register_cursor_at_symbol(symbols[0], cursor_id=cursor_id or None) + view = manager.format_cursor_view(cid) + return f"Found unique match; started cursor {cid}.\n\n{view}" + + def candidate_list_json() -> str: + candidate_dicts = [ + s.to_dict( + kind=True, + name_path=True, + name=False, + relative_path=True, + body_location=True, + depth=depth, + body=include_body, + children_name=True, + children_name_path=False, + ) + for s in symbols + ] + return self._to_json(candidate_dicts) + + if 0 < max_matches < n_matches: + summary = f"Matched {n_matches}>{max_matches} symbols; refine your pattern or use cursor_start with a specific name path." + rel_path_to_name_paths: defaultdict[str, list[str]] = defaultdict(list) + for s in symbols: + rel_path_to_name_paths[s.location.relative_path or "unknown"].append(s.get_name_path()) + return f"{summary}\n{self._to_json(rel_path_to_name_paths)}" + + def shortened_relative_path_to_name_paths() -> str: + rel_path_to_name_paths: defaultdict[str, list[str]] = defaultdict(list) + for s in symbols: + rel_path_to_name_paths[s.location.relative_path or "unknown"].append(s.get_name_path()) + return f"Candidates (shortened):\n{self._to_json(rel_path_to_name_paths)}" + + result = f"Found {n_matches} matching symbols; pick one and call cursor_start on its name path.\n{candidate_list_json()}" + return self._limit_length(result, max_answer_chars, shortened_result_factories=[shortened_relative_path_to_name_paths]) + + +class CursorReplaceBodyTool(Tool, ToolMarkerSymbolicEdit): + """ + Replace the body of the symbol at the cursor's current position. + The cursor remains positioned on the same symbol (its stored location is refreshed). + """ + + def apply(self, cursor_id: str, body: str) -> str: + """ + Replace the body of the symbol at the cursor's current position. + + The body is the full definition of the symbol in the programming language, including + the signature line for functions. It does NOT include preceding docstrings, comments, + or imports. + + :param cursor_id: the cursor whose current symbol to replace. + :param body: the new body text. + :return: confirmation and the updated cursor view. + """ + manager = self.agent.get_cursor_manager() + state = manager.get_cursor(cursor_id) + name_path = state.current_symbol.get_name_path() + relative_path = state.current_location.relative_path + if relative_path is None: + raise ValueError(f"Cursor {cursor_id} has no relative path; cannot perform edit.") + code_editor = self.create_code_editor() + code_editor.replace_body(name_path, relative_file_path=relative_path, body=body) + manager.reanchor_cursor(cursor_id) + return f"{SUCCESS_RESULT}\n\n" + manager.format_cursor_view(cursor_id) + + +class CursorInsertBeforeTool(Tool, ToolMarkerSymbolicEdit): + """ + Insert content immediately before the symbol at the cursor's current position. + The cursor stays on the target symbol; its stored location is refreshed. + """ + + def apply(self, cursor_id: str, body: str) -> str: + """ + Insert content before the symbol at the cursor's current position. + + Typical uses: insert a new class/function above the current one, or insert a new + import statement before the first top-level symbol in a file. + + :param cursor_id: the cursor whose current symbol to insert before. + :param body: the content to insert; it will be placed immediately before the line + where the symbol is defined. + :return: confirmation and the updated cursor view. + """ + manager = self.agent.get_cursor_manager() + state = manager.get_cursor(cursor_id) + name_path = state.current_symbol.get_name_path() + relative_path = state.current_location.relative_path + if relative_path is None: + raise ValueError(f"Cursor {cursor_id} has no relative path; cannot perform edit.") + code_editor = self.create_code_editor() + code_editor.insert_before_symbol(name_path, relative_file_path=relative_path, body=body) + manager.reanchor_cursor(cursor_id) + return f"{SUCCESS_RESULT}\n\n" + manager.format_cursor_view(cursor_id) + + +class CursorInsertAfterTool(Tool, ToolMarkerSymbolicEdit): + """ + Insert content immediately after the symbol at the cursor's current position. + The cursor stays on the target symbol; its stored location is refreshed. + """ + + def apply(self, cursor_id: str, body: str) -> str: + """ + Insert content after the symbol at the cursor's current position. + + Typical use: add a new class, function, method, or variable assignment after + an existing one. + + :param cursor_id: the cursor whose current symbol to insert after. + :param body: the content to insert; it will be placed on the line following the + end of the symbol's definition. + :return: confirmation and the updated cursor view. + """ + manager = self.agent.get_cursor_manager() + state = manager.get_cursor(cursor_id) + name_path = state.current_symbol.get_name_path() + relative_path = state.current_location.relative_path + if relative_path is None: + raise ValueError(f"Cursor {cursor_id} has no relative path; cannot perform edit.") + code_editor = self.create_code_editor() + code_editor.insert_after_symbol(name_path, relative_file_path=relative_path, body=body) + manager.reanchor_cursor(cursor_id) + return f"{SUCCESS_RESULT}\n\n" + manager.format_cursor_view(cursor_id) + + +class CursorRenameTool(Tool, ToolMarkerSymbolicEdit): + """ + Rename the symbol at the cursor's current position throughout the codebase using the + language server's refactoring support. The cursor re-anchors to the renamed symbol. + """ + + def apply(self, cursor_id: str, new_name: str) -> str: + """ + Rename the symbol at the cursor's current position. + + All references to the symbol are updated via the language server's rename refactoring. + The cursor re-anchors to the renamed symbol at its new name. + + :param cursor_id: the cursor whose current symbol to rename. + :param new_name: the new name. + :return: the rename status message followed by the updated cursor view. + """ + manager = self.agent.get_cursor_manager() + state = manager.get_cursor(cursor_id) + old_name_path = state.current_symbol.get_name_path() + relative_path = state.current_location.relative_path + if relative_path is None: + raise ValueError(f"Cursor {cursor_id} has no relative path; cannot perform edit.") + code_editor = self.create_ls_code_editor() + status_message = code_editor.rename_symbol(old_name_path, relative_path=relative_path, new_name=new_name) + + # Re-anchor: the old name path's last segment is replaced by new_name + parts = old_name_path.split("/") + parts[-1] = new_name + new_name_path = "/".join(parts) + try: + manager.reanchor_cursor(cursor_id, name_path=new_name_path, relative_path=relative_path) + view = manager.format_cursor_view(cursor_id) + return f"{status_message}\n\n{view}" + except ValueError as e: + return f"{status_message}\n\n(Cursor could not re-anchor to renamed symbol: {e})" + + +class CursorOverviewTool(Tool, ToolMarkerSymbolicRead): + """ + Return an overview of the top-level symbols in a file by starting a cursor on the + file's first top-level symbol with only the ``contains`` edge active. This covers the + use case of the old ``get_symbols_overview`` tool in cursor-first form. + """ + + def apply(self, relative_path: str, cursor_id: str = "", max_answer_chars: int = -1) -> str: + """ + Show the top-level symbols in a file as a cursor view. + + Internally this delegates to the language server symbol retriever to find top-level + symbols and starts a cursor on the file's container with only the ``contains`` edge + so the output is a flat listing of top-level symbols in the file. + + :param relative_path: relative path to the source file. + :param cursor_id: optional cursor ID for the started cursor. Auto-generated otherwise. + :param max_answer_chars: maximum characters for the returned output; -1 means use default. + :return: the cursor view listing top-level symbols under ``contains``. + """ + import os + + file_path = os.path.join(self.project.project_root, relative_path) + if not os.path.exists(file_path): + raise FileNotFoundError(f"File {relative_path} does not exist in the project.") + if os.path.isdir(file_path): + raise ValueError(f"Expected a file path, but got a directory path: {relative_path}.") + + retriever = self.create_language_server_symbol_retriever() + if not retriever.can_analyze_file(relative_path): + raise ValueError( + f"Cannot extract symbols from file {relative_path}. " + f"Active languages: {[l.value for l in self.agent.get_active_lsp_languages()]}" + ) + top_level = retriever.get_symbol_overview(relative_path).get(relative_path, []) + if not top_level: + return f"No top-level symbols found in {relative_path}." + + # Use the first top-level symbol as a foothold; show its siblings by reading + # the overview directly rather than trying to force the cursor onto a synthetic file node. + lines: list[str] = [f"Top-level symbols in {relative_path}:"] + for sym in top_level: + line = sym.line + loc = f"{relative_path}:{line + 1}" if line is not None else relative_path + lines.append(f" {sym.name} ({sym.symbol_kind_name}) [{loc}]") + lines.append("") + lines.append("Use cursor_start with a name path to position on a specific symbol.") + result = "\n".join(lines) + return self._limit_length(result, max_answer_chars) diff --git a/src/serena/tools/symbol_tools.py b/src/serena/tools/symbol_tools.py index 40340e408..e90d5746b 100644 --- a/src/serena/tools/symbol_tools.py +++ b/src/serena/tools/symbol_tools.py @@ -29,7 +29,7 @@ def apply(self) -> str: return SUCCESS_RESULT -class GetSymbolsOverviewTool(Tool, ToolMarkerSymbolicRead): +class GetSymbolsOverviewTool(Tool, ToolMarkerSymbolicRead, ToolMarkerOptional): """ Gets an overview of the top-level symbols defined in a given file. """ @@ -119,7 +119,7 @@ def child_inclusion_predicate(s: LanguageServerSymbol) -> bool: return symbol_dicts -class FindSymbolTool(Tool, ToolMarkerSymbolicRead): +class FindSymbolTool(Tool, ToolMarkerSymbolicRead, ToolMarkerOptional): """ Performs a global (or local) search using the language server backend. """ @@ -236,7 +236,7 @@ def create_short_result_relative_path_to_name_paths() -> str: return self._limit_length(result, max_answer_chars, shortened_result_factories=[create_short_result_relative_path_to_name_paths]) -class FindReferencingSymbolsTool(Tool, ToolMarkerSymbolicRead): +class FindReferencingSymbolsTool(Tool, ToolMarkerSymbolicRead, ToolMarkerOptional): """ Finds symbols that reference the given symbol using the language server backend """ @@ -324,7 +324,7 @@ def make_summary() -> str: return self._limit_length(result_json, max_answer_chars, shortened_result_factories=shortened_results) -class ReplaceSymbolBodyTool(Tool, ToolMarkerSymbolicEdit): +class ReplaceSymbolBodyTool(Tool, ToolMarkerSymbolicEdit, ToolMarkerOptional): """ Replaces the full definition of a symbol using the language server backend. """ @@ -357,7 +357,7 @@ def apply( return SUCCESS_RESULT -class InsertAfterSymbolTool(Tool, ToolMarkerSymbolicEdit): +class InsertAfterSymbolTool(Tool, ToolMarkerSymbolicEdit, ToolMarkerOptional): """ Inserts content after the end of the definition of a given symbol. """ @@ -382,7 +382,7 @@ def apply( return SUCCESS_RESULT -class InsertBeforeSymbolTool(Tool, ToolMarkerSymbolicEdit): +class InsertBeforeSymbolTool(Tool, ToolMarkerSymbolicEdit, ToolMarkerOptional): """ Inserts content before the beginning of the definition of a given symbol. """ @@ -407,7 +407,7 @@ def apply( return SUCCESS_RESULT -class RenameSymbolTool(Tool, ToolMarkerSymbolicEdit): +class RenameSymbolTool(Tool, ToolMarkerSymbolicEdit, ToolMarkerOptional): """ Renames a symbol throughout the codebase using language server refactoring capabilities. For JB, we use a separate tool. diff --git a/src/solidlsp/language_servers/sourcekit_lsp.py b/src/solidlsp/language_servers/sourcekit_lsp.py index 8d30daebe..dd2ecca15 100644 --- a/src/solidlsp/language_servers/sourcekit_lsp.py +++ b/src/solidlsp/language_servers/sourcekit_lsp.py @@ -32,6 +32,39 @@ def is_ignored_dirname(self, dirname: str) -> bool: # - dist/build: common output directories return super().is_ignored_dirname(dirname) or dirname in [".build", ".swiftpm", "node_modules", "dist", "build"] + @staticmethod + def _resolve_sourcekit_lsp_path() -> str: + """Resolve the path to sourcekit-lsp, preferring Xcode's version over Command Line Tools. + + On macOS, bare `sourcekit-lsp` may resolve to the Command Line Tools version which + has limited capabilities. Using `xcrun --find sourcekit-lsp` resolves to Xcode's + full-featured version when Xcode is installed. + """ + # Try xcrun with DEVELOPER_DIR pointing to Xcode (not Command Line Tools) + xcode_path = "/Applications/Xcode.app/Contents/Developer" + for env_override in [{"DEVELOPER_DIR": xcode_path}, {}]: + try: + run_env = {**os.environ, **env_override} if env_override else None + result = subprocess.run( + ["xcrun", "--find", "sourcekit-lsp"], + capture_output=True, + text=True, + check=False, + timeout=10, + env=run_env, + ) + if result.returncode == 0: + resolved = result.stdout.strip() + if resolved and os.path.isfile(resolved): + log.info(f"Resolved sourcekit-lsp via xcrun: {resolved}") + return resolved + except (FileNotFoundError, subprocess.TimeoutExpired): + pass + + # Fall back to bare command (Linux, or macOS without Xcode) + log.info("xcrun not available or failed, falling back to bare sourcekit-lsp") + return "sourcekit-lsp" + @staticmethod def _get_sourcekit_lsp_version() -> str: """Get the installed sourcekit-lsp version or raise error if sourcekit was not found.""" @@ -51,9 +84,17 @@ def __init__(self, config: LanguageServerConfig, repository_root_path: str, soli sourcekit_version = self._get_sourcekit_lsp_version() log.info(f"Starting sourcekit lsp with version: {sourcekit_version}") - super().__init__( - config, repository_root_path, ProcessLaunchInfo(cmd="sourcekit-lsp", cwd=repository_root_path), "swift", solidlsp_settings - ) + # Resolve sourcekit-lsp path — prefer Xcode's version over Command Line Tools + sourcekit_path = self._resolve_sourcekit_lsp_path() + + # sourcekit-lsp needs --scratch-path for background indexing and cross-file references. + # Without it, textDocument/references returns empty because there's no index store. + scratch_path = os.path.join(repository_root_path, ".build", "sourcekit-lsp") + os.makedirs(scratch_path, exist_ok=True) + cmd = [sourcekit_path, "--scratch-path", scratch_path] + log.info(f"sourcekit-lsp path: {sourcekit_path}, scratch path: {scratch_path}") + + super().__init__(config, repository_root_path, ProcessLaunchInfo(cmd=cmd, cwd=repository_root_path), "swift", solidlsp_settings) self.request_id = 0 self._did_sleep_before_requesting_references = False self._initialization_timestamp: float | None = None @@ -363,24 +404,26 @@ def request_references(self, relative_file_path: str, line: int, column: int) -> # Calculate minimum delay based on how much time has passed since initialization if self._initialization_timestamp: elapsed = time.time() - self._initialization_timestamp - # Increased CI delay for project indexing: 15s CI, 5s local - base_delay = 15 if os.getenv("CI") else 5 + # Increased delay for project indexing: 15s CI, 10s local + # 5s was insufficient for real projects — sourcekit-lsp needs time to index + base_delay = 15 if os.getenv("CI") else 10 remaining_delay = max(2, base_delay - elapsed) else: # Fallback if initialization timestamp is missing - remaining_delay = 15 if os.getenv("CI") else 5 + remaining_delay = 15 if os.getenv("CI") else 10 log.info(f"Sleeping {remaining_delay:.1f}s before requesting references for the first time (CI needs extra indexing time)") time.sleep(remaining_delay) self._did_sleep_before_requesting_references = True - # Get references with retry logic for CI stability + # Get references with retry logic — indexing may not be complete on first request references = super().request_references(relative_file_path, line, column) - # In CI, if no references found, retry once after additional delay - if os.getenv("CI") and not references: - log.info("No references found in CI - retrying after additional 5s delay") - time.sleep(5) + # If no references found, retry once after additional delay (indexing may still be in progress) + if not references: + retry_delay = 5 if os.getenv("CI") else 3 + log.info(f"No references found - retrying after additional {retry_delay}s delay (index may still be building)") + time.sleep(retry_delay) references = super().request_references(relative_file_path, line, column) return references diff --git a/src/solidlsp/ls.py b/src/solidlsp/ls.py index 9dda6f33f..8f6260ddf 100644 --- a/src/solidlsp/ls.py +++ b/src/solidlsp/ls.py @@ -1943,6 +1943,110 @@ def request_referencing_symbols( return result + def request_call_hierarchy_incoming(self, relative_file_path: str, line: int, column: int) -> list[lsp_types.CallHierarchyIncomingCall]: + """ + Finds all callers of the symbol at the given position. + + Uses the LSP callHierarchy/incomingCalls request (prepareCallHierarchy + incomingCalls). + + :param relative_file_path: The relative path of the file containing the symbol + :param line: The 0-indexed line number of the symbol + :param column: The 0-indexed column number of the symbol + :return: List of incoming calls (callers) for the symbol + """ + if not self.server_started: + raise SolidLSPException("Language Server not started") + + uri = PathUtils.path_to_uri(os.path.join(self.repository_root_path, relative_file_path)) + with self.open_file(relative_file_path): + items = self.server.send.prepare_call_hierarchy({"textDocument": {"uri": uri}, "position": {"line": line, "character": column}}) + if not items: + return [] + result: list[lsp_types.CallHierarchyIncomingCall] = [] + for item in items: + incoming = self.server.send.incoming_calls({"item": item}) + if incoming: + result.extend(incoming) + return result + + def request_call_hierarchy_outgoing(self, relative_file_path: str, line: int, column: int) -> list[lsp_types.CallHierarchyOutgoingCall]: + """ + Finds all symbols called by the symbol at the given position. + + Uses the LSP callHierarchy/outgoingCalls request (prepareCallHierarchy + outgoingCalls). + + :param relative_file_path: The relative path of the file containing the symbol + :param line: The 0-indexed line number of the symbol + :param column: The 0-indexed column number of the symbol + :return: List of outgoing calls (callees) from the symbol + """ + if not self.server_started: + raise SolidLSPException("Language Server not started") + + uri = PathUtils.path_to_uri(os.path.join(self.repository_root_path, relative_file_path)) + with self.open_file(relative_file_path): + items = self.server.send.prepare_call_hierarchy({"textDocument": {"uri": uri}, "position": {"line": line, "character": column}}) + if not items: + return [] + result: list[lsp_types.CallHierarchyOutgoingCall] = [] + for item in items: + outgoing = self.server.send.outgoing_calls({"item": item}) + if outgoing: + result.extend(outgoing) + return result + + def request_type_hierarchy_supertypes(self, relative_file_path: str, line: int, column: int) -> list[lsp_types.TypeHierarchyItem]: + """ + Finds all supertypes (parents) of the type at the given position. + + Uses the LSP typeHierarchy/supertypes request (prepareTypeHierarchy + supertypes). + + :param relative_file_path: The relative path of the file containing the type + :param line: The 0-indexed line number of the type + :param column: The 0-indexed column number of the type + :return: List of supertype items + """ + if not self.server_started: + raise SolidLSPException("Language Server not started") + + uri = PathUtils.path_to_uri(os.path.join(self.repository_root_path, relative_file_path)) + with self.open_file(relative_file_path): + items = self.server.send.prepare_type_hierarchy({"textDocument": {"uri": uri}, "position": {"line": line, "character": column}}) + if not items: + return [] + result: list[lsp_types.TypeHierarchyItem] = [] + for item in items: + supertypes = self.server.send.type_hierarchy_supertypes({"item": item}) + if supertypes: + result.extend(supertypes) + return result + + def request_type_hierarchy_subtypes(self, relative_file_path: str, line: int, column: int) -> list[lsp_types.TypeHierarchyItem]: + """ + Finds all subtypes (children) of the type at the given position. + + Uses the LSP typeHierarchy/subtypes request (prepareTypeHierarchy + subtypes). + + :param relative_file_path: The relative path of the file containing the type + :param line: The 0-indexed line number of the type + :param column: The 0-indexed column number of the type + :return: List of subtype items + """ + if not self.server_started: + raise SolidLSPException("Language Server not started") + + uri = PathUtils.path_to_uri(os.path.join(self.repository_root_path, relative_file_path)) + with self.open_file(relative_file_path): + items = self.server.send.prepare_type_hierarchy({"textDocument": {"uri": uri}, "position": {"line": line, "character": column}}) + if not items: + return [] + result: list[lsp_types.TypeHierarchyItem] = [] + for item in items: + subtypes = self.server.send.type_hierarchy_subtypes({"item": item}) + if subtypes: + result.extend(subtypes) + return result + def request_containing_symbol( self, relative_file_path: str, diff --git a/test/serena/test_cursor.py b/test/serena/test_cursor.py new file mode 100644 index 000000000..c1cc98703 --- /dev/null +++ b/test/serena/test_cursor.py @@ -0,0 +1,702 @@ +""" +Unit tests for cursor-based code navigation (CursorManager, CursorState, EdgeType, NeighborSymbol). + +These tests mock the LSP layer to test cursor logic in isolation. +""" + +import os +import tempfile +from unittest.mock import MagicMock, patch + +import pytest + +from serena.cursor import ( + ALL_EDGE_TYPES, + DEFAULT_EDGE_TYPES, + CursorManager, + CursorState, + EdgeType, + NeighborSymbol, +) +from serena.symbol import LanguageServerSymbol, LanguageServerSymbolLocation +from solidlsp.ls_exceptions import SolidLSPException +from solidlsp.ls_utils import PathUtils + +# ── Cross-platform test root ──────────────────────────────────────────── +# Use the real temp directory so that file URIs resolve on all platforms +# (on Windows, /tmp is a UNC path on a different mount than D:). +_TEST_PROJECT_ROOT = os.path.join(tempfile.gettempdir(), "test_project") + +# ── Helpers ────────────────────────────────────────────────────────────── + + +def _make_file_uri(relative_path: str) -> str: + """Build a file:// URI for a path relative to the test project root.""" + return PathUtils.path_to_uri(os.path.join(_TEST_PROJECT_ROOT, relative_path)) + + +def _make_symbol( + name: str = "MyClass", + kind_name: str = "Class", + rel_path: str | None = "src/module.py", + line: int | None = 10, + col: int | None = 0, + name_path: str | None = None, + children: list | None = None, + body: str | None = None, +) -> MagicMock: + """Create a minimal mock LanguageServerSymbol.""" + sym = MagicMock(spec=LanguageServerSymbol) + sym.name = name + sym.symbol_kind_name = kind_name + sym.relative_path = rel_path + sym.line = line + sym.column = col + sym.body = body + sym.get_name_path.return_value = name_path or name + sym.location = LanguageServerSymbolLocation(relative_path=rel_path, line=line, column=col) + + child_mocks = [] + for c in children or []: + child_mocks.append(_make_symbol(**c)) + sym.iter_children.return_value = iter(child_mocks) + return sym + + +def _make_project(project_root: str = _TEST_PROJECT_ROOT) -> MagicMock: + """Create a minimal mock Project.""" + project = MagicMock() + project.project_root = project_root + return project + + +def _make_manager(project_root: str = _TEST_PROJECT_ROOT) -> CursorManager: + return CursorManager(_make_project(project_root)) + + +# ── EdgeType / NeighborSymbol ──────────────────────────────────────────── + + +class TestEdgeType: + def test_all_edge_types_contains_all_members(self): + assert frozenset(EdgeType) == ALL_EDGE_TYPES + + def test_default_edge_types_subset_of_all(self): + assert DEFAULT_EDGE_TYPES <= ALL_EDGE_TYPES + + def test_edge_type_values(self): + assert EdgeType.CONTAINS.value == "contains" + assert EdgeType.REFERENCES.value == "references" + assert EdgeType.REFERENCED_BY.value == "referenced-by" + assert EdgeType.CALLS.value == "calls" + assert EdgeType.CALLED_BY.value == "called-by" + assert EdgeType.INHERITS.value == "inherits" + assert EdgeType.INHERITED_BY.value == "inherited-by" + + def test_seven_edge_types(self): + assert len(EdgeType) == 7 + + +class TestNeighborSymbol: + def test_location_str_with_path_and_line(self): + n = NeighborSymbol(name="foo", kind="Function", relative_path="src/a.py", line=9, column=4, edge_type=EdgeType.CONTAINS) + assert n.location_str == "src/a.py:10" # 0-indexed line displayed as 1-indexed + + def test_location_str_with_path_no_line(self): + n = NeighborSymbol(name="foo", kind="Module", relative_path="src/a.py", line=None, column=None, edge_type=EdgeType.CONTAINS) + assert n.location_str == "src/a.py" + + def test_location_str_without_path(self): + n = NeighborSymbol(name="foo", kind="Function", relative_path=None, line=None, column=None, edge_type=EdgeType.REFERENCES) + assert n.location_str == "?" + + def test_format_compact_basic(self): + n = NeighborSymbol(name="bar", kind="Method", relative_path="x.py", line=5, column=0, edge_type=EdgeType.CALLS) + formatted = n.format_compact() + assert "bar" in formatted + assert "(Method)" in formatted + assert "[x.py:6]" in formatted + + def test_format_compact_with_detail(self): + n = NeighborSymbol( + name="bar", kind="Method", relative_path="x.py", line=5, column=0, edge_type=EdgeType.CALLS, detail="returns int" + ) + assert "— returns int" in n.format_compact() + + +# ── CursorState ────────────────────────────────────────────────────────── + + +class TestCursorState: + def test_initial_state_has_empty_trail(self): + sym = _make_symbol() + loc = sym.location + state = CursorState(cursor_id="c1", current_symbol=sym, current_location=loc) + assert state.trail == [] + assert state.cursor_id == "c1" + assert state.active_edge_types == DEFAULT_EDGE_TYPES + assert state.include_body is False + + def test_record_move_appends_to_trail(self): + sym1 = _make_symbol(name="A", line=1) + sym2 = _make_symbol(name="B", line=20) + state = CursorState(cursor_id="c1", current_symbol=sym1, current_location=sym1.location) + + state.record_move(sym2, sym2.location) + + assert len(state.trail) == 1 + assert state.trail[0].line == 1 + assert state.current_symbol is sym2 + assert state.current_location is sym2.location + + def test_record_multiple_moves_builds_trail(self): + syms = [_make_symbol(name=f"S{i}", line=i * 10) for i in range(4)] + state = CursorState(cursor_id="c1", current_symbol=syms[0], current_location=syms[0].location) + + for s in syms[1:]: + state.record_move(s, s.location) + + assert len(state.trail) == 3 + assert state.current_symbol is syms[3] + + def test_custom_edge_types(self): + sym = _make_symbol() + edges = frozenset({EdgeType.CONTAINS, EdgeType.CALLS}) + state = CursorState(cursor_id="c1", current_symbol=sym, current_location=sym.location, active_edge_types=edges) + assert state.active_edge_types == edges + + +# ── CursorManager: start / get / list / close ──────────────────────────── + + +class TestCursorManagerLifecycle: + @patch("serena.cursor.LanguageServerSymbolRetriever") + def test_start_cursor_auto_id(self, mock_retriever_cls): + manager = _make_manager() + sym = _make_symbol() + mock_retriever_cls.return_value.find_unique.return_value = sym + + cid, state = manager.start_cursor("MyClass") + + assert cid == "c1" + assert state.cursor_id == "c1" + assert state.current_symbol is sym + assert state.trail == [] + mock_retriever_cls.return_value.find_unique.assert_called_once_with("MyClass", within_relative_path=None) + + @patch("serena.cursor.LanguageServerSymbolRetriever") + def test_start_cursor_explicit_id(self, mock_retriever_cls): + manager = _make_manager() + sym = _make_symbol() + mock_retriever_cls.return_value.find_unique.return_value = sym + + cid, state = manager.start_cursor("MyClass", cursor_id="custom") + + assert cid == "custom" + assert state.cursor_id == "custom" + + @patch("serena.cursor.LanguageServerSymbolRetriever") + def test_start_cursor_with_relative_path(self, mock_retriever_cls): + manager = _make_manager() + sym = _make_symbol() + mock_retriever_cls.return_value.find_unique.return_value = sym + + manager.start_cursor("MyClass", relative_path="src/module.py") + + mock_retriever_cls.return_value.find_unique.assert_called_once_with("MyClass", within_relative_path="src/module.py") + + @patch("serena.cursor.LanguageServerSymbolRetriever") + def test_auto_id_increments(self, mock_retriever_cls): + manager = _make_manager() + sym = _make_symbol() + mock_retriever_cls.return_value.find_unique.return_value = sym + + cid1, _ = manager.start_cursor("A") + cid2, _ = manager.start_cursor("B") + cid3, _ = manager.start_cursor("C") + + assert cid1 == "c1" + assert cid2 == "c2" + assert cid3 == "c3" + + @patch("serena.cursor.LanguageServerSymbolRetriever") + def test_get_cursor_success(self, mock_retriever_cls): + manager = _make_manager() + sym = _make_symbol() + mock_retriever_cls.return_value.find_unique.return_value = sym + + cid, _ = manager.start_cursor("MyClass") + state = manager.get_cursor(cid) + + assert state.cursor_id == cid + + def test_get_cursor_nonexistent_raises(self): + manager = _make_manager() + with pytest.raises(ValueError, match="No cursor with id"): + manager.get_cursor("nonexistent") + + @patch("serena.cursor.LanguageServerSymbolRetriever") + def test_list_cursors(self, mock_retriever_cls): + manager = _make_manager() + sym = _make_symbol() + mock_retriever_cls.return_value.find_unique.return_value = sym + + assert manager.list_cursors() == [] + + manager.start_cursor("A") + manager.start_cursor("B") + + assert sorted(manager.list_cursors()) == ["c1", "c2"] + + @patch("serena.cursor.LanguageServerSymbolRetriever") + def test_close_cursor(self, mock_retriever_cls): + manager = _make_manager() + sym = _make_symbol() + mock_retriever_cls.return_value.find_unique.return_value = sym + + cid, _ = manager.start_cursor("A") + assert cid in manager.list_cursors() + + manager.close_cursor(cid) + assert cid not in manager.list_cursors() + + def test_close_nonexistent_cursor_is_noop(self): + manager = _make_manager() + manager.close_cursor("nope") # should not raise + + +# ── CursorManager: move ────────────────────────────────────────────────── + + +class TestCursorManagerMove: + @patch("serena.cursor.LanguageServerSymbolRetriever") + def test_move_cursor_to_neighbor(self, mock_retriever_cls): + """Move cursor to a symbol by name, falling back to global lookup when not in neighbors.""" + manager = _make_manager() + + sym_a = _make_symbol(name="ClassA", line=10) + sym_b = _make_symbol(name="method_b", line=20) + mock_retriever = mock_retriever_cls.return_value + mock_retriever.find_unique.return_value = sym_a + + cid, _ = manager.start_cursor("ClassA") + + # Now set up for move: find_unique returns sym_b for the fallback path + mock_retriever.find_unique.return_value = sym_b + + state = manager.move_cursor(cid, "method_b") + + assert state.current_symbol is sym_b + assert len(state.trail) == 1 + assert state.trail[0].line == 10 # original position + + @patch("serena.cursor.LanguageServerSymbolRetriever") + def test_move_cursor_records_trail(self, mock_retriever_cls): + """Each move appends the previous location to trail.""" + manager = _make_manager() + mock_retriever = mock_retriever_cls.return_value + + syms = [_make_symbol(name=f"S{i}", line=i * 10) for i in range(4)] + mock_retriever.find_unique.side_effect = syms + + cid, _ = manager.start_cursor("S0") + manager.move_cursor(cid, "S1") + manager.move_cursor(cid, "S2") + manager.move_cursor(cid, "S3") + + state = manager.get_cursor(cid) + assert len(state.trail) == 3 + assert state.current_symbol is syms[3] + + +# ── CursorManager: resolve_neighbors (mocked LSP) ─────────────────────── + + +class TestResolveNeighbors: + def _setup_manager_with_cursor(self, mock_retriever_cls, sym=None, children=None): + """Helper that creates a manager, starts a cursor, and returns (manager, cursor_id, mock_ls).""" + manager = _make_manager() + if sym is None: + child_specs = children or [] + sym = _make_symbol(name="MyClass", line=10, col=0, children=child_specs) + mock_retriever = mock_retriever_cls.return_value + mock_retriever.find_unique.return_value = sym + + mock_ls = MagicMock() + mock_retriever.get_language_server.return_value = mock_ls + + # Default: all LSP methods return empty + mock_ls.request_definition.return_value = [] + mock_ls.request_referencing_symbols.return_value = [] + mock_ls.request_call_hierarchy_outgoing.return_value = [] + mock_ls.request_call_hierarchy_incoming.return_value = [] + mock_ls.request_type_hierarchy_supertypes.return_value = [] + mock_ls.request_type_hierarchy_subtypes.return_value = [] + + cid, _ = manager.start_cursor("MyClass") + return manager, cid, mock_ls + + @patch("serena.cursor.LanguageServerSymbolRetriever") + def test_contains_edge_resolves_children(self, mock_retriever_cls): + children = [ + {"name": "method_a", "kind_name": "Method", "line": 12, "col": 4}, + {"name": "method_b", "kind_name": "Method", "line": 20, "col": 4}, + ] + manager, cid, _ = self._setup_manager_with_cursor(mock_retriever_cls, children=children) + + neighbors = manager.resolve_neighbors(cid) + contains_neighbors = [n for n in neighbors if n.edge_type == EdgeType.CONTAINS] + + assert len(contains_neighbors) == 2 + names = {n.name for n in contains_neighbors} + assert names == {"method_a", "method_b"} + + @patch("serena.cursor.LanguageServerSymbolRetriever") + def test_references_edge(self, mock_retriever_cls): + manager, cid, mock_ls = self._setup_manager_with_cursor(mock_retriever_cls) + + mock_ls.request_definition.return_value = [ + { + "relativePath": "src/other.py", + "range": {"start": {"line": 5, "character": 0}, "end": {"line": 5, "character": 10}}, + } + ] + # Stub _symbol_name_at to return a name + manager._symbol_name_at = MagicMock(return_value="OtherClass") + + neighbors = manager.resolve_neighbors(cid) + ref_neighbors = [n for n in neighbors if n.edge_type == EdgeType.REFERENCES] + + assert len(ref_neighbors) == 1 + assert ref_neighbors[0].name == "OtherClass" + assert ref_neighbors[0].relative_path == "src/other.py" + + @patch("serena.cursor.LanguageServerSymbolRetriever") + def test_referenced_by_edge(self, mock_retriever_cls): + manager, cid, mock_ls = self._setup_manager_with_cursor(mock_retriever_cls) + + ref_mock = MagicMock() + ref_mock.symbol = { + "name": "caller_func", + "kind": 12, # SymbolKind.Function + "location": {"relativePath": "src/caller.py"}, + "selectionRange": {"start": {"line": 30, "character": 4}}, + } + mock_ls.request_referencing_symbols.return_value = [ref_mock] + + neighbors = manager.resolve_neighbors(cid) + refby = [n for n in neighbors if n.edge_type == EdgeType.REFERENCED_BY] + + assert len(refby) == 1 + assert refby[0].name == "caller_func" + + @patch("serena.cursor.LanguageServerSymbolRetriever") + def test_calls_edge(self, mock_retriever_cls): + manager, cid, mock_ls = self._setup_manager_with_cursor(mock_retriever_cls) + + mock_ls.request_call_hierarchy_outgoing.return_value = [ + { + "to": { + "name": "helper", + "kind": 12, + "uri": _make_file_uri("src/utils.py"), + "selectionRange": {"start": {"line": 5, "character": 0}}, + "range": {"start": {"line": 5, "character": 0}, "end": {"line": 10, "character": 0}}, + } + } + ] + + neighbors = manager.resolve_neighbors(cid) + calls = [n for n in neighbors if n.edge_type == EdgeType.CALLS] + + assert len(calls) == 1 + assert calls[0].name == "helper" + + @patch("serena.cursor.LanguageServerSymbolRetriever") + def test_called_by_edge(self, mock_retriever_cls): + manager, cid, mock_ls = self._setup_manager_with_cursor(mock_retriever_cls) + + mock_ls.request_call_hierarchy_incoming.return_value = [ + { + "from": { + "name": "main", + "kind": 12, + "uri": _make_file_uri("src/main.py"), + "selectionRange": {"start": {"line": 1, "character": 0}}, + "range": {"start": {"line": 1, "character": 0}, "end": {"line": 20, "character": 0}}, + } + } + ] + + neighbors = manager.resolve_neighbors(cid) + called_by = [n for n in neighbors if n.edge_type == EdgeType.CALLED_BY] + + assert len(called_by) == 1 + assert called_by[0].name == "main" + + @patch("serena.cursor.LanguageServerSymbolRetriever") + def test_inherits_edge(self, mock_retriever_cls): + manager, cid, mock_ls = self._setup_manager_with_cursor(mock_retriever_cls) + + mock_ls.request_type_hierarchy_supertypes.return_value = [ + { + "name": "BaseClass", + "kind": 5, + "uri": _make_file_uri("src/base.py"), + "selectionRange": {"start": {"line": 3, "character": 0}}, + "range": {"start": {"line": 3, "character": 0}, "end": {"line": 30, "character": 0}}, + } + ] + + neighbors = manager.resolve_neighbors(cid) + inherits = [n for n in neighbors if n.edge_type == EdgeType.INHERITS] + + assert len(inherits) == 1 + assert inherits[0].name == "BaseClass" + + @patch("serena.cursor.LanguageServerSymbolRetriever") + def test_inherited_by_edge(self, mock_retriever_cls): + manager, cid, mock_ls = self._setup_manager_with_cursor(mock_retriever_cls) + + mock_ls.request_type_hierarchy_subtypes.return_value = [ + { + "name": "ChildClass", + "kind": 5, + "uri": _make_file_uri("src/child.py"), + "selectionRange": {"start": {"line": 1, "character": 0}}, + "range": {"start": {"line": 1, "character": 0}, "end": {"line": 15, "character": 0}}, + } + ] + + neighbors = manager.resolve_neighbors(cid) + inherited_by = [n for n in neighbors if n.edge_type == EdgeType.INHERITED_BY] + + assert len(inherited_by) == 1 + assert inherited_by[0].name == "ChildClass" + + @patch("serena.cursor.LanguageServerSymbolRetriever") + def test_edge_type_filtering(self, mock_retriever_cls): + """Only active edge types are resolved.""" + children = [{"name": "child", "kind_name": "Method", "line": 12, "col": 4}] + manager, cid, mock_ls = self._setup_manager_with_cursor(mock_retriever_cls, children=children) + + mock_ls.request_call_hierarchy_outgoing.return_value = [ + { + "to": { + "name": "helper", + "kind": 12, + "uri": _make_file_uri("src/utils.py"), + "selectionRange": {"start": {"line": 5, "character": 0}}, + "range": {"start": {"line": 5, "character": 0}, "end": {"line": 10, "character": 0}}, + } + } + ] + + # Restrict to only CALLS + state = manager.get_cursor(cid) + state.active_edge_types = frozenset({EdgeType.CALLS}) + + neighbors = manager.resolve_neighbors(cid) + assert all(n.edge_type == EdgeType.CALLS for n in neighbors) + assert len(neighbors) == 1 + # Contains should NOT appear despite having children + assert not any(n.edge_type == EdgeType.CONTAINS for n in neighbors) + + @patch("serena.cursor.LanguageServerSymbolRetriever") + def test_lsp_error_is_swallowed(self, mock_retriever_cls): + """LSP errors for individual edge types are caught; other edges still resolve.""" + children = [{"name": "child", "kind_name": "Method", "line": 12, "col": 4}] + manager, cid, mock_ls = self._setup_manager_with_cursor(mock_retriever_cls, children=children) + + mock_ls.request_definition.side_effect = SolidLSPException("LSP error") + mock_ls.request_referencing_symbols.side_effect = SolidLSPException("LSP error") + + # Should still get contains neighbors despite reference errors + neighbors = manager.resolve_neighbors(cid) + assert any(n.edge_type == EdgeType.CONTAINS for n in neighbors) + + @patch("serena.cursor.LanguageServerSymbolRetriever") + def test_incomplete_location_returns_empty(self, mock_retriever_cls): + """Cursor with None location fields returns no neighbors.""" + manager = _make_manager() + sym = _make_symbol(name="Orphan", rel_path=None, line=None, col=None) + # Override the location to have None fields + sym.location = LanguageServerSymbolLocation(relative_path=None, line=None, column=None) + mock_retriever_cls.return_value.find_unique.return_value = sym + + cid, _ = manager.start_cursor("Orphan") + neighbors = manager.resolve_neighbors(cid) + assert neighbors == [] + + +# ── CursorManager: format_cursor_view / format_trail ───────────────────── + + +class TestFormatting: + @patch("serena.cursor.LanguageServerSymbolRetriever") + def test_format_cursor_view_header(self, mock_retriever_cls): + manager = _make_manager() + sym = _make_symbol(name="MyClass", kind_name="Class", rel_path="src/m.py", line=10, name_path="MyClass") + mock_retriever = mock_retriever_cls.return_value + mock_retriever.find_unique.return_value = sym + mock_retriever.get_language_server.return_value = MagicMock( + request_definition=MagicMock(return_value=[]), + request_referencing_symbols=MagicMock(return_value=[]), + request_call_hierarchy_outgoing=MagicMock(return_value=[]), + request_call_hierarchy_incoming=MagicMock(return_value=[]), + request_type_hierarchy_supertypes=MagicMock(return_value=[]), + request_type_hierarchy_subtypes=MagicMock(return_value=[]), + ) + + cid, _ = manager.start_cursor("MyClass") + view = manager.format_cursor_view(cid) + + assert "@ MyClass (Class)" in view + assert "m.py:11" in view # 0-indexed line 10 → displayed as 11 + assert f"cursor: {cid}" in view + assert "trail: 0 steps" in view + + @patch("serena.cursor.LanguageServerSymbolRetriever") + def test_format_trail_empty(self, mock_retriever_cls): + manager = _make_manager() + sym = _make_symbol() + mock_retriever_cls.return_value.find_unique.return_value = sym + + cid, _ = manager.start_cursor("MyClass") + trail = manager.format_trail(cid) + + assert "no trail" in trail + + @patch("serena.cursor.LanguageServerSymbolRetriever") + def test_format_trail_with_moves(self, mock_retriever_cls): + manager = _make_manager() + mock_retriever = mock_retriever_cls.return_value + + sym_a = _make_symbol(name="A", rel_path="a.py", line=5) + sym_b = _make_symbol(name="B", rel_path="b.py", line=15) + mock_retriever.find_unique.side_effect = [sym_a, sym_b] + + cid, _ = manager.start_cursor("A") + manager.move_cursor(cid, "B") + + trail = manager.format_trail(cid) + assert "1 steps" in trail + assert "a.py:6" in trail # trail entry (0-indexed 5 → display 6) + assert "b.py:16" in trail # current position + assert "(current)" in trail + + @patch("serena.cursor.LanguageServerSymbolRetriever") + def test_format_cursor_view_with_body(self, mock_retriever_cls): + manager = _make_manager() + sym = _make_symbol(name="func", kind_name="Function", line=10, body="def func(): pass") + mock_retriever = mock_retriever_cls.return_value + mock_retriever.find_unique.return_value = sym + mock_retriever.get_language_server.return_value = MagicMock( + request_definition=MagicMock(return_value=[]), + request_referencing_symbols=MagicMock(return_value=[]), + request_call_hierarchy_outgoing=MagicMock(return_value=[]), + request_call_hierarchy_incoming=MagicMock(return_value=[]), + request_type_hierarchy_supertypes=MagicMock(return_value=[]), + request_type_hierarchy_subtypes=MagicMock(return_value=[]), + ) + + cid, _ = manager.start_cursor("func") + state = manager.get_cursor(cid) + state.include_body = True + view = manager.format_cursor_view(cid) + + assert "--- body ---" in view + assert "def func(): pass" in view + + @patch("serena.cursor.LanguageServerSymbolRetriever") + def test_format_cursor_view_no_neighbors_message(self, mock_retriever_cls): + manager = _make_manager() + sym = _make_symbol(name="Leaf", children=[]) + mock_retriever = mock_retriever_cls.return_value + mock_retriever.find_unique.return_value = sym + mock_retriever.get_language_server.return_value = MagicMock( + request_definition=MagicMock(return_value=[]), + request_referencing_symbols=MagicMock(return_value=[]), + request_call_hierarchy_outgoing=MagicMock(return_value=[]), + request_call_hierarchy_incoming=MagicMock(return_value=[]), + request_type_hierarchy_supertypes=MagicMock(return_value=[]), + request_type_hierarchy_subtypes=MagicMock(return_value=[]), + ) + + cid, _ = manager.start_cursor("Leaf") + view = manager.format_cursor_view(cid) + + assert "(no neighbors found)" in view + + +# ── CursorManager: find_symbols / register_cursor_at_symbol / reanchor_cursor ── + + +class TestFindAndAnchor: + @patch("serena.cursor.LanguageServerSymbolRetriever") + def test_find_symbols_delegates_to_retriever(self, mock_retriever_cls): + manager = _make_manager() + syms = [_make_symbol(name="A"), _make_symbol(name="B")] + mock_retriever_cls.return_value.find.return_value = syms + + results = manager.find_symbols("A", relative_path="src/x.py", substring_matching=True) + + assert results == syms + mock_retriever_cls.return_value.find.assert_called_once() + kwargs = mock_retriever_cls.return_value.find.call_args.kwargs + assert kwargs["within_relative_path"] == "src/x.py" + assert kwargs["substring_matching"] is True + + @patch("serena.cursor.LanguageServerSymbolRetriever") + def test_register_cursor_at_symbol_auto_id(self, mock_retriever_cls): + manager = _make_manager() + sym = _make_symbol(name="Target", line=3) + + cid, state = manager.register_cursor_at_symbol(sym) + + assert cid == "c1" + assert state.cursor_id == "c1" + assert state.current_symbol is sym + assert state.trail == [] + + @patch("serena.cursor.LanguageServerSymbolRetriever") + def test_register_cursor_at_symbol_duplicate_id_raises(self, mock_retriever_cls): + manager = _make_manager() + sym = _make_symbol() + manager.register_cursor_at_symbol(sym, cursor_id="dup") + with pytest.raises(ValueError, match="already exists"): + manager.register_cursor_at_symbol(sym, cursor_id="dup") + + @patch("serena.cursor.LanguageServerSymbolRetriever") + def test_reanchor_cursor_refreshes_position(self, mock_retriever_cls): + """reanchor_cursor re-resolves the cursor's symbol via the retriever.""" + manager = _make_manager() + orig_sym = _make_symbol(name="Target", line=10, name_path="Module/Target") + new_sym = _make_symbol(name="Target", line=15, name_path="Module/Target") + mock_retriever_cls.return_value.find_unique.return_value = orig_sym + + cid, _ = manager.start_cursor("Module/Target") + + # After an edit, the symbol moved to line 15 + mock_retriever_cls.return_value.find_unique.return_value = new_sym + state = manager.reanchor_cursor(cid) + + assert state.current_symbol is new_sym + assert state.current_location.line == 15 + + @patch("serena.cursor.LanguageServerSymbolRetriever") + def test_reanchor_cursor_with_new_name(self, mock_retriever_cls): + """reanchor_cursor can look up a renamed symbol.""" + manager = _make_manager() + orig_sym = _make_symbol(name="old_name", line=10) + renamed_sym = _make_symbol(name="new_name", line=10, name_path="new_name") + mock_retriever_cls.return_value.find_unique.return_value = orig_sym + + cid, _ = manager.start_cursor("old_name") + + mock_retriever_cls.return_value.find_unique.return_value = renamed_sym + state = manager.reanchor_cursor(cid, name_path="new_name") + + assert state.current_symbol is renamed_sym + # Verify the retriever was called with the new name + assert mock_retriever_cls.return_value.find_unique.call_args.args == ("new_name",) diff --git a/test/serena/test_cursor_integration.py b/test/serena/test_cursor_integration.py new file mode 100644 index 000000000..c25f90dd7 --- /dev/null +++ b/test/serena/test_cursor_integration.py @@ -0,0 +1,219 @@ +""" +Integration tests for cursor-based code navigation using the Python test repository. + +These tests start a real language server against the Python test repo and exercise +the full cursor lifecycle: start, navigate, configure, trail, and close. +""" + +import pytest + +from serena.cursor import CursorManager, EdgeType +from solidlsp.ls_config import Language +from test.conftest import project_with_ls_context + +pytestmark = pytest.mark.python + + +@pytest.fixture(scope="module") +def cursor_manager(): + """Create a CursorManager backed by the real Python test repo with a live LSP server.""" + with project_with_ls_context(Language.PYTHON) as project: + yield CursorManager(project) + + +class TestCursorStartAndLook: + def test_start_cursor_at_class(self, cursor_manager: CursorManager): + """Start a cursor at the UserService class and verify position.""" + cid, state = cursor_manager.start_cursor("UserService", relative_path="test_repo/services.py") + try: + assert state.cursor_id == cid + assert state.current_symbol.name == "UserService" + assert state.trail == [] + assert state.current_location.relative_path is not None + assert "services.py" in state.current_location.relative_path + finally: + cursor_manager.close_cursor(cid) + + def test_start_cursor_at_function(self, cursor_manager: CursorManager): + """Start a cursor at a standalone function.""" + cid, state = cursor_manager.start_cursor("create_user_object", relative_path="test_repo/models.py") + try: + assert state.current_symbol.name == "create_user_object" + finally: + cursor_manager.close_cursor(cid) + + +class TestContainsEdge: + def test_class_contains_methods(self, cursor_manager: CursorManager): + """A class cursor should see its methods via the contains edge.""" + cid, _ = cursor_manager.start_cursor("UserService", relative_path="test_repo/services.py") + try: + neighbors = cursor_manager.resolve_neighbors(cid) + contains = [n for n in neighbors if n.edge_type == EdgeType.CONTAINS] + names = {n.name for n in contains} + # UserService has __init__, create_user, get_user, list_users, delete_user + assert "create_user" in names + assert "get_user" in names + finally: + cursor_manager.close_cursor(cid) + + def test_outer_class_contains_nested(self, cursor_manager: CursorManager): + """OuterClass should contain NestedClass via contains edge.""" + cid, _ = cursor_manager.start_cursor("OuterClass", relative_path="test_repo/nested.py") + try: + neighbors = cursor_manager.resolve_neighbors(cid) + contains = [n for n in neighbors if n.edge_type == EdgeType.CONTAINS] + names = {n.name for n in contains} + assert "NestedClass" in names + finally: + cursor_manager.close_cursor(cid) + + +class TestNavigationAndTrail: + def test_move_to_child_method(self, cursor_manager: CursorManager): + """Navigate from a class to one of its methods.""" + cid, _ = cursor_manager.start_cursor("UserService", relative_path="test_repo/services.py") + try: + state = cursor_manager.move_cursor(cid, "create_user") + assert state.current_symbol.name == "create_user" + assert len(state.trail) == 1 # UserService position recorded + finally: + cursor_manager.close_cursor(cid) + + def test_multi_step_trail(self, cursor_manager: CursorManager): + """Navigate multiple hops and verify the trail records each step. + + Navigate from OuterClass -> NestedClass -> find_me to avoid ambiguity + issues with 'user' substring matching in the services module. + """ + cid, _ = cursor_manager.start_cursor("OuterClass", relative_path="test_repo/nested.py") + try: + cursor_manager.move_cursor(cid, "NestedClass") + cursor_manager.move_cursor(cid, "find_me") + + state = cursor_manager.get_cursor(cid) + assert len(state.trail) == 2 + assert state.current_symbol.name == "find_me" + finally: + cursor_manager.close_cursor(cid) + + def test_format_trail_after_moves(self, cursor_manager: CursorManager): + """The formatted trail should show visited locations.""" + cid, _ = cursor_manager.start_cursor("UserService", relative_path="test_repo/services.py") + try: + cursor_manager.move_cursor(cid, "create_user") + trail_text = cursor_manager.format_trail(cid) + assert "1 steps" in trail_text + assert "services.py" in trail_text + assert "(current)" in trail_text + finally: + cursor_manager.close_cursor(cid) + + +class TestEdgeTypeConfiguration: + def test_restrict_to_contains_only(self, cursor_manager: CursorManager): + """When only CONTAINS is active, only children should appear.""" + cid, _ = cursor_manager.start_cursor("UserService", relative_path="test_repo/services.py") + try: + state = cursor_manager.get_cursor(cid) + state.active_edge_types = frozenset({EdgeType.CONTAINS}) + + neighbors = cursor_manager.resolve_neighbors(cid) + assert all(n.edge_type == EdgeType.CONTAINS for n in neighbors) + assert len(neighbors) > 0 + finally: + cursor_manager.close_cursor(cid) + + def test_disable_all_edges_yields_no_neighbors(self, cursor_manager: CursorManager): + """With no active edge types, resolve_neighbors returns empty.""" + cid, _ = cursor_manager.start_cursor("UserService", relative_path="test_repo/services.py") + try: + state = cursor_manager.get_cursor(cid) + state.active_edge_types = frozenset() + + neighbors = cursor_manager.resolve_neighbors(cid) + assert neighbors == [] + finally: + cursor_manager.close_cursor(cid) + + +class TestFormatCursorView: + def test_view_contains_symbol_info(self, cursor_manager: CursorManager): + """format_cursor_view should include the symbol name, kind, and location.""" + cid, _ = cursor_manager.start_cursor("UserService", relative_path="test_repo/services.py") + try: + view = cursor_manager.format_cursor_view(cid) + assert "UserService" in view + assert "services.py" in view + assert f"cursor: {cid}" in view + finally: + cursor_manager.close_cursor(cid) + + def test_view_lists_neighbors(self, cursor_manager: CursorManager): + """The cursor view should show neighbors grouped by edge type.""" + cid, _ = cursor_manager.start_cursor("UserService", relative_path="test_repo/services.py") + try: + view = cursor_manager.format_cursor_view(cid) + # Should at least show contains section with methods + assert "contains:" in view + assert "create_user" in view + finally: + cursor_manager.close_cursor(cid) + + +class TestInheritanceEdges: + def test_inherits_edge_does_not_crash(self, cursor_manager: CursorManager): + """Type hierarchy edges should resolve without error, even if the LSP returns empty. + + Note: Pyright (Python LSP) does not support textDocument/typeHierarchy, so these + edges typically return empty for Python. The test validates that the edge resolution + completes gracefully. + """ + cid, _ = cursor_manager.start_cursor("User", relative_path="test_repo/models.py") + try: + state = cursor_manager.get_cursor(cid) + state.active_edge_types = frozenset({EdgeType.INHERITS, EdgeType.INHERITED_BY}) + + neighbors = cursor_manager.resolve_neighbors(cid) + # We don't assert specific results — Pyright may not support type hierarchy. + # Just verify no crash and that returned neighbors (if any) have correct edge types. + for n in neighbors: + assert n.edge_type in (EdgeType.INHERITS, EdgeType.INHERITED_BY) + finally: + cursor_manager.close_cursor(cid) + + +class TestMultipleCursors: + def test_independent_cursors(self, cursor_manager: CursorManager): + """Multiple cursors can coexist and track independent state.""" + cid1, _ = cursor_manager.start_cursor("UserService", relative_path="test_repo/services.py") + cid2, _ = cursor_manager.start_cursor("ItemService", relative_path="test_repo/services.py") + try: + assert cid1 != cid2 + assert set(cursor_manager.list_cursors()) >= {cid1, cid2} + + # Move one cursor, the other should be unaffected + cursor_manager.move_cursor(cid1, "create_user") + state1 = cursor_manager.get_cursor(cid1) + state2 = cursor_manager.get_cursor(cid2) + + assert state1.current_symbol.name == "create_user" + assert state2.current_symbol.name == "ItemService" + assert len(state1.trail) == 1 + assert len(state2.trail) == 0 + finally: + cursor_manager.close_cursor(cid1) + cursor_manager.close_cursor(cid2) + + def test_close_one_preserves_other(self, cursor_manager: CursorManager): + """Closing one cursor should not affect another.""" + cid1, _ = cursor_manager.start_cursor("UserService", relative_path="test_repo/services.py") + cid2, _ = cursor_manager.start_cursor("ItemService", relative_path="test_repo/services.py") + try: + cursor_manager.close_cursor(cid1) + assert cid1 not in cursor_manager.list_cursors() + # cid2 should still be accessible + state2 = cursor_manager.get_cursor(cid2) + assert state2.current_symbol.name == "ItemService" + finally: + cursor_manager.close_cursor(cid2) diff --git a/test/serena/test_cursor_navigation.py b/test/serena/test_cursor_navigation.py new file mode 100644 index 000000000..65038ca7e --- /dev/null +++ b/test/serena/test_cursor_navigation.py @@ -0,0 +1,670 @@ +""" +Tests for cursor-based code navigation. + +Tests CursorManager against a live Python LSP, and exercises the cursor MCP tools +through SerenaAgent.get_tool() against the Python test repository. +""" + +import os +import re +from collections.abc import Iterator + +import pytest + +from serena.agent import SerenaAgent +from serena.config.serena_config import ProjectConfig, RegisteredProject, SerenaConfig +from serena.cursor import ALL_EDGE_TYPES, DEFAULT_EDGE_TYPES, CursorManager, CursorState, EdgeType, NeighborSymbol +from serena.project import Project +from serena.tools.cursor_tools import ( + CursorCloseTool, + CursorConfigureTool, + CursorFindTool, + CursorHistoryTool, + CursorInsertAfterTool, + CursorInsertBeforeTool, + CursorLookTool, + CursorMoveTool, + CursorOverviewTool, + CursorRenameTool, + CursorReplaceBodyTool, + CursorStartTool, +) +from solidlsp.ls_config import Language +from test.conftest import get_repo_path, language_tests_enabled, project_with_ls_context + +pytestmark = pytest.mark.python + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture(scope="module") +def python_project(): + """A Project with an active Python language server for cursor tests.""" + with project_with_ls_context(Language.PYTHON) as project: + yield project + + +@pytest.fixture +def cursor_manager(python_project: Project) -> CursorManager: + """A fresh CursorManager for each test.""" + return CursorManager(python_project) + + +@pytest.fixture(scope="module") +def python_serena_agent(): + """SerenaAgent configured for the Python test repo.""" + if not language_tests_enabled(Language.PYTHON): + pytest.skip("Python tests not enabled") + + config = SerenaConfig(gui_log_window=False, web_dashboard=False) + repo_path = get_repo_path(Language.PYTHON) + project = Project( + project_root=str(repo_path), + project_config=ProjectConfig( + project_name="test_repo_python", + languages=[Language.PYTHON], + ignored_paths=[], + excluded_tools=[], + read_only=False, + ignore_all_files_in_gitignore=True, + initial_prompt="", + encoding="utf-8", + ), + serena_config=config, + ) + config.projects = [RegisteredProject.from_project_instance(project)] + agent = SerenaAgent(project="test_repo_python", serena_config=config) + agent.execute_task(lambda: None) + yield agent + agent.on_shutdown(timeout=5) + + +# =========================================================================== +# CursorManager Integration Tests (live Python LSP) +# =========================================================================== + + +class TestCursorManagerBasics: + """Test basic CursorManager lifecycle: start, get, list, close.""" + + def test_start_cursor_auto_id(self, cursor_manager: CursorManager) -> None: + """start_cursor generates an auto-incremented cursor ID.""" + cid, state = cursor_manager.start_cursor("UserService") + assert cid == "c1" + assert isinstance(state, CursorState) + assert state.cursor_id == "c1" + assert state.current_symbol.name == "UserService" + assert state.trail == [] + assert state.active_edge_types == DEFAULT_EDGE_TYPES + + def test_start_cursor_explicit_id(self, cursor_manager: CursorManager) -> None: + """start_cursor accepts an explicit cursor ID.""" + cid, state = cursor_manager.start_cursor("UserService", cursor_id="my-cursor") + assert cid == "my-cursor" + assert state.cursor_id == "my-cursor" + + def test_start_cursor_with_relative_path(self, cursor_manager: CursorManager) -> None: + """start_cursor can narrow symbol search to a specific file.""" + cid, state = cursor_manager.start_cursor( + "UserService", + relative_path=os.path.join("test_repo", "services.py"), + ) + assert state.current_symbol.name == "UserService" + assert state.current_location.relative_path is not None + assert "services.py" in state.current_location.relative_path + + def test_get_cursor(self, cursor_manager: CursorManager) -> None: + """get_cursor returns the correct state for a known cursor.""" + cid, _ = cursor_manager.start_cursor("UserService") + state = cursor_manager.get_cursor(cid) + assert state.cursor_id == cid + + def test_get_cursor_unknown_raises(self, cursor_manager: CursorManager) -> None: + """get_cursor raises ValueError for an unknown cursor ID.""" + with pytest.raises(ValueError, match="No cursor with id"): + cursor_manager.get_cursor("nonexistent") + + def test_list_cursors(self, cursor_manager: CursorManager) -> None: + """list_cursors returns all active cursor IDs.""" + assert cursor_manager.list_cursors() == [] + cursor_manager.start_cursor("UserService", cursor_id="a") + cursor_manager.start_cursor("Item", cursor_id="b") + ids = cursor_manager.list_cursors() + assert set(ids) == {"a", "b"} + + def test_close_cursor(self, cursor_manager: CursorManager) -> None: + """close_cursor removes the cursor.""" + cid, _ = cursor_manager.start_cursor("UserService") + assert cid in cursor_manager.list_cursors() + cursor_manager.close_cursor(cid) + assert cid not in cursor_manager.list_cursors() + + def test_close_nonexistent_cursor_is_noop(self, cursor_manager: CursorManager) -> None: + """Closing a nonexistent cursor does not raise.""" + cursor_manager.close_cursor("does-not-exist") # should not raise + + +class TestCursorMovement: + """Test moving a cursor to neighbor symbols.""" + + def test_move_to_child(self, cursor_manager: CursorManager) -> None: + """Move cursor from a class to one of its contained methods.""" + cid, _ = cursor_manager.start_cursor("UserService") + state = cursor_manager.move_cursor(cid, "create_user") + assert state.current_symbol.name == "create_user" + assert len(state.trail) == 1 # one step in trail + + def test_trail_records_moves(self, cursor_manager: CursorManager) -> None: + """Trail grows with each move, preserving prior locations.""" + cid, _ = cursor_manager.start_cursor("ItemService") + state = cursor_manager.get_cursor(cid) + # Restrict to contains edges to avoid ambiguous substring matches from references + state.active_edge_types = frozenset({EdgeType.CONTAINS}) + + cursor_manager.move_cursor(cid, "create_item") + + # Close and re-start to get a clean trail for the actual assertion + cursor_manager.close_cursor(cid) + cid2, _ = cursor_manager.start_cursor("ItemService", cursor_id="trail2") + state2 = cursor_manager.get_cursor(cid2) + state2.active_edge_types = frozenset({EdgeType.CONTAINS}) + initial_loc2 = state2.current_location + + cursor_manager.move_cursor(cid2, "create_item") + mid_loc = cursor_manager.get_cursor(cid2).current_location + + # Move from create_item's view — navigate via global fallback to list_items + cursor_manager.move_cursor(cid2, "list_items") + final_state = cursor_manager.get_cursor(cid2) + + assert len(final_state.trail) == 2 + assert final_state.trail[0] == initial_loc2 + assert final_state.trail[1] == mid_loc + + def test_move_to_ambiguous_target_with_path(self, cursor_manager: CursorManager) -> None: + """When a target name appears in multiple neighbors, relative_path narrows it.""" + # __init__ exists in many classes; disambiguate by providing relative path + cid, _ = cursor_manager.start_cursor("UserService") + state = cursor_manager.move_cursor( + cid, + "__init__", + target_relative_path=os.path.join("test_repo", "services.py"), + ) + assert state.current_symbol.name == "__init__" + + +class TestNeighborResolution: + """Test that resolve_neighbors returns correct neighbors for various edge types.""" + + def test_contains_edge(self, cursor_manager: CursorManager) -> None: + """Contains edge returns children of a class.""" + cid, _ = cursor_manager.start_cursor("UserService") + neighbors = cursor_manager.resolve_neighbors(cid) + contains_names = {n.name for n in neighbors if n.edge_type == EdgeType.CONTAINS} + # UserService should contain __init__, create_user, get_user, list_users, delete_user + assert "create_user" in contains_names + assert "get_user" in contains_names + assert "list_users" in contains_names + + def test_referenced_by_edge(self, cursor_manager: CursorManager) -> None: + """Referenced-by edge finds symbols that reference the current symbol.""" + cid, _ = cursor_manager.start_cursor("User", relative_path=os.path.join("test_repo", "models.py")) + neighbors = cursor_manager.resolve_neighbors(cid) + referenced_by = [n for n in neighbors if n.edge_type == EdgeType.REFERENCED_BY] + # User is referenced by services.py and other files + assert len(referenced_by) > 0 + + +class TestEdgeTypeConfiguration: + """Test configuring which edge types are active.""" + + def test_default_edge_types(self, cursor_manager: CursorManager) -> None: + """Default edge types include all 7 types.""" + assert DEFAULT_EDGE_TYPES == ALL_EDGE_TYPES + + def test_configure_subset(self, cursor_manager: CursorManager) -> None: + """Cursor can be configured to only follow a subset of edge types.""" + cid, _ = cursor_manager.start_cursor("UserService") + state = cursor_manager.get_cursor(cid) + state.active_edge_types = frozenset({EdgeType.CONTAINS}) + + neighbors = cursor_manager.resolve_neighbors(cid) + for n in neighbors: + assert n.edge_type == EdgeType.CONTAINS + + +class TestFormatting: + """Test format_cursor_view and format_trail output.""" + + def test_format_cursor_view_contains_symbol_name(self, cursor_manager: CursorManager) -> None: + """format_cursor_view includes the current symbol name and location.""" + cid, _ = cursor_manager.start_cursor("UserService") + view = cursor_manager.format_cursor_view(cid) + assert "UserService" in view + assert "cursor: " + cid in view + assert "trail: 0 steps" in view + assert "contains:" in view + + def test_format_trail_empty(self, cursor_manager: CursorManager) -> None: + """format_trail reports no trail at starting position.""" + cid, _ = cursor_manager.start_cursor("UserService") + trail = cursor_manager.format_trail(cid) + assert "no trail" in trail + + def test_format_trail_after_moves(self, cursor_manager: CursorManager) -> None: + """format_trail shows numbered steps after navigation.""" + cid, _ = cursor_manager.start_cursor("UserService") + cursor_manager.move_cursor(cid, "create_user") + trail = cursor_manager.format_trail(cid) + assert "1 steps" in trail or "1." in trail + assert "(current)" in trail + + +class TestNeighborSymbol: + """Test NeighborSymbol formatting.""" + + def test_location_str_with_path_and_line(self) -> None: + n = NeighborSymbol( + name="foo", + kind="Function", + relative_path="src/main.py", + line=10, + column=0, + edge_type=EdgeType.CONTAINS, + ) + assert n.location_str == "src/main.py:11" # 0-indexed line → 1-indexed display + + def test_location_str_path_only(self) -> None: + n = NeighborSymbol( + name="foo", + kind="Function", + relative_path="src/main.py", + line=None, + column=None, + edge_type=EdgeType.CONTAINS, + ) + assert n.location_str == "src/main.py" + + def test_location_str_unknown(self) -> None: + n = NeighborSymbol( + name="foo", + kind="Function", + relative_path=None, + line=None, + column=None, + edge_type=EdgeType.CONTAINS, + ) + assert n.location_str == "?" + + def test_format_compact(self) -> None: + n = NeighborSymbol( + name="foo", + kind="Function", + relative_path="src/main.py", + line=10, + column=0, + edge_type=EdgeType.CALLS, + detail="some detail", + ) + formatted = n.format_compact() + assert "foo" in formatted + assert "(Function)" in formatted + assert "src/main.py:11" in formatted + assert "— some detail" in formatted + + +# =========================================================================== +# SerenaAgent Tool-Level Integration Tests +# =========================================================================== + + +@pytest.fixture(autouse=True) +def _cleanup_cursors(python_serena_agent: SerenaAgent) -> Iterator[None]: + """Close all cursors after each tool integration test.""" + yield + manager = python_serena_agent.get_cursor_manager() + for cid in list(manager.list_cursors()): + manager.close_cursor(cid) + + +class TestCursorToolsIntegration: + """Integration tests exercising cursor tools through SerenaAgent.""" + + def test_cursor_start_and_look(self, python_serena_agent: SerenaAgent) -> None: + """cursor_start places a cursor and returns a neighborhood view.""" + start_tool = python_serena_agent.get_tool(CursorStartTool) + result = start_tool.apply(name_path="UserService") + + assert "UserService" in result + assert "cursor:" in result + assert "contains:" in result + assert "create_user" in result + + # Extract cursor ID from result + # Format: " cursor: c1 | trail: 0 steps" + match = re.search(r"cursor: (\S+)", result) + assert match, f"Could not find cursor ID in result: {result}" + cid = match.group(1) + + # cursor_look should return the same view + look_tool = python_serena_agent.get_tool(CursorLookTool) + look_result = look_tool.apply(cursor_id=cid) + assert "UserService" in look_result + assert "contains:" in look_result + + def test_cursor_move_and_history(self, python_serena_agent: SerenaAgent) -> None: + """cursor_move navigates to a neighbor; cursor_history shows the trail.""" + start_tool = python_serena_agent.get_tool(CursorStartTool) + start_tool.apply(name_path="UserService", cursor_id="nav-test") + + move_tool = python_serena_agent.get_tool(CursorMoveTool) + move_result = move_tool.apply(cursor_id="nav-test", target_name="create_user") + assert "create_user" in move_result + + history_tool = python_serena_agent.get_tool(CursorHistoryTool) + history_result = history_tool.apply(cursor_id="nav-test") + assert "1." in history_result + assert "(current)" in history_result + + def test_cursor_configure_edge_types(self, python_serena_agent: SerenaAgent) -> None: + """cursor_configure changes which edge types are shown.""" + start_tool = python_serena_agent.get_tool(CursorStartTool) + start_tool.apply(name_path="UserService", cursor_id="cfg-test") + + configure_tool = python_serena_agent.get_tool(CursorConfigureTool) + result = configure_tool.apply(cursor_id="cfg-test", edge_types=["contains"]) + # After configuring to only show "contains", we should still see children + assert "contains:" in result + # Other edge types should not appear (unless they happen to have zero results, + # in which case they wouldn't appear anyway) + + def test_cursor_configure_invalid_edge_type(self, python_serena_agent: SerenaAgent) -> None: + """cursor_configure raises for an invalid edge type name.""" + start_tool = python_serena_agent.get_tool(CursorStartTool) + start_tool.apply(name_path="UserService", cursor_id="cfg-err") + + configure_tool = python_serena_agent.get_tool(CursorConfigureTool) + with pytest.raises(ValueError, match="Unknown edge type"): + configure_tool.apply(cursor_id="cfg-err", edge_types=["nonexistent-edge"]) + + def test_cursor_close(self, python_serena_agent: SerenaAgent) -> None: + """cursor_close removes the cursor.""" + start_tool = python_serena_agent.get_tool(CursorStartTool) + start_tool.apply(name_path="UserService", cursor_id="close-test") + + close_tool = python_serena_agent.get_tool(CursorCloseTool) + result = close_tool.apply(cursor_id="close-test") + assert "close-test" in result + assert "closed" in result.lower() + + # Trying to look at a closed cursor should fail + look_tool = python_serena_agent.get_tool(CursorLookTool) + with pytest.raises(ValueError, match="No cursor with id"): + look_tool.apply(cursor_id="close-test") + + def test_cursor_include_body(self, python_serena_agent: SerenaAgent) -> None: + """cursor_configure with include_body=True shows the symbol body.""" + start_tool = python_serena_agent.get_tool(CursorStartTool) + start_tool.apply(name_path="UserService/create_user", cursor_id="body-test") + + configure_tool = python_serena_agent.get_tool(CursorConfigureTool) + result = configure_tool.apply(cursor_id="body-test", include_body=True) + assert "--- body ---" in result + assert "def create_user" in result + + def test_navigate_class_to_method_to_reference(self, python_serena_agent: SerenaAgent) -> None: + """Full navigation: class -> method -> follow a reference.""" + start_tool = python_serena_agent.get_tool(CursorStartTool) + move_tool = python_serena_agent.get_tool(CursorMoveTool) + history_tool = python_serena_agent.get_tool(CursorHistoryTool) + + # Start at a class + start_tool.apply(name_path="UserService", cursor_id="full-nav") + + # Move to a method + move_tool.apply(cursor_id="full-nav", target_name="create_user") + + # Check trail has one step + history = history_tool.apply(cursor_id="full-nav") + assert "1." in history + + def test_multiple_cursors(self, python_serena_agent: SerenaAgent) -> None: + """Multiple cursors can be active simultaneously.""" + start_tool = python_serena_agent.get_tool(CursorStartTool) + + r1 = start_tool.apply(name_path="UserService", cursor_id="multi-1") + r2 = start_tool.apply(name_path="Item", cursor_id="multi-2") + + assert "UserService" in r1 + assert "Item" in r2 + + look_tool = python_serena_agent.get_tool(CursorLookTool) + look1 = look_tool.apply(cursor_id="multi-1") + look2 = look_tool.apply(cursor_id="multi-2") + + assert "UserService" in look1 + assert "Item" in look2 + + def test_navigate_inheritance(self, python_serena_agent: SerenaAgent) -> None: + """Navigate the type hierarchy: User inherits from BaseModel.""" + start_tool = python_serena_agent.get_tool(CursorStartTool) + configure_tool = python_serena_agent.get_tool(CursorConfigureTool) + + start_tool.apply(name_path="User", cursor_id="inherit-test", relative_path=os.path.join("test_repo", "models.py")) + + # Configure to only show inheritance edges + result = configure_tool.apply( + cursor_id="inherit-test", + edge_types=["inherits", "inherited-by"], + ) + # If the Python LSP supports type hierarchy, we should see BaseModel + # If not, the cursor gracefully shows "(no neighbors found)" + # Either outcome is acceptable — the test validates no crashes + assert "inherit-test" in result or "cursor:" in result + + def test_navigate_nested_class(self, python_serena_agent: SerenaAgent) -> None: + """Navigate to a nested class via the contains edge.""" + start_tool = python_serena_agent.get_tool(CursorStartTool) + result = start_tool.apply( + name_path="OuterClass", + relative_path=os.path.join("test_repo", "nested.py"), + ) + + assert "OuterClass" in result + # Should see NestedClass and nested_test as children + assert "NestedClass" in result or "nested_test" in result + + +# =========================================================================== +# Cursor Find and Overview (read) tests +# =========================================================================== + + +class TestCursorFindAndOverview: + def test_cursor_find_unique_starts_cursor(self, python_serena_agent: SerenaAgent) -> None: + """cursor_find with a unique match starts a cursor and returns its view.""" + find_tool = python_serena_agent.get_tool(CursorFindTool) + result = find_tool.apply(name_path_pattern="/UserService", cursor_id="find-unique") + assert "started cursor" in result + assert "UserService" in result + assert "contains:" in result + + def test_cursor_find_multiple_returns_candidates(self, python_serena_agent: SerenaAgent) -> None: + """cursor_find with a non-unique pattern returns a candidate list without starting a cursor.""" + find_tool = python_serena_agent.get_tool(CursorFindTool) + result = find_tool.apply(name_path_pattern="__init__") + # Many __init__ methods exist; the result should list candidates and NOT start a cursor. + assert "started cursor" not in result + assert "matching symbols" in result or "Candidates" in result + + def test_cursor_find_no_match(self, python_serena_agent: SerenaAgent) -> None: + """cursor_find returns a clear message when no symbol matches.""" + find_tool = python_serena_agent.get_tool(CursorFindTool) + result = find_tool.apply(name_path_pattern="__absolutely_no_such_symbol__") + assert "No symbols found" in result + + def test_cursor_find_substring(self, python_serena_agent: SerenaAgent) -> None: + """cursor_find with substring_matching=True matches partial names.""" + find_tool = python_serena_agent.get_tool(CursorFindTool) + result = find_tool.apply( + name_path_pattern="create_u", + substring_matching=True, + relative_path=os.path.join("test_repo", "services.py"), + ) + assert "create_user" in result + + def test_cursor_overview(self, python_serena_agent: SerenaAgent) -> None: + """cursor_overview lists top-level symbols in a file.""" + overview_tool = python_serena_agent.get_tool(CursorOverviewTool) + result = overview_tool.apply(relative_path=os.path.join("test_repo", "services.py")) + assert "Top-level symbols" in result + assert "UserService" in result + assert "ItemService" in result + + def test_cursor_overview_missing_file(self, python_serena_agent: SerenaAgent) -> None: + """cursor_overview raises FileNotFoundError for a missing file.""" + overview_tool = python_serena_agent.get_tool(CursorOverviewTool) + with pytest.raises(FileNotFoundError): + overview_tool.apply(relative_path="does/not/exist.py") + + +# =========================================================================== +# Cursor Edit tool tests — use a throwaway file so we can assert and revert. +# =========================================================================== + + +@pytest.fixture +def throwaway_python_file(python_serena_agent: SerenaAgent) -> Iterator[str]: + """Create a throwaway Python file in the test repo so edit tests can mutate and clean up.""" + from pathlib import Path + + rel_path = os.path.join("test_repo", "_cursor_edit_sandbox.py") + abs_path = Path(python_serena_agent.get_active_project_or_raise().project_root) / rel_path + abs_path.write_text("def sandbox_fn():\n x = 1\n return x\n\n\ndef other_fn():\n return 2\n") + # Give the LSP a chance to pick up the file on some backends. + try: + python_serena_agent.reset_language_server_manager() + except Exception: + pass + try: + yield rel_path + finally: + if abs_path.exists(): + abs_path.unlink() + try: + python_serena_agent.reset_language_server_manager() + except Exception: + pass + + +class TestCursorEditTools: + def test_cursor_replace_body_refreshes_cursor(self, python_serena_agent: SerenaAgent, throwaway_python_file: str) -> None: + """cursor_replace_body replaces the body and the cursor re-anchors.""" + start_tool = python_serena_agent.get_tool(CursorStartTool) + replace_tool = python_serena_agent.get_tool(CursorReplaceBodyTool) + look_tool = python_serena_agent.get_tool(CursorLookTool) + + start_tool.apply( + name_path="sandbox_fn", + relative_path=throwaway_python_file, + cursor_id="edit-replace", + ) + result = replace_tool.apply( + cursor_id="edit-replace", + body="def sandbox_fn():\n return 42\n", + ) + assert "OK" in result + assert "sandbox_fn" in result + + look = look_tool.apply(cursor_id="edit-replace") + assert "sandbox_fn" in look + + def test_cursor_insert_before(self, python_serena_agent: SerenaAgent, throwaway_python_file: str) -> None: + """cursor_insert_before inserts content above the target symbol.""" + from pathlib import Path + + start_tool = python_serena_agent.get_tool(CursorStartTool) + insert_tool = python_serena_agent.get_tool(CursorInsertBeforeTool) + + start_tool.apply( + name_path="other_fn", + relative_path=throwaway_python_file, + cursor_id="edit-before", + ) + result = insert_tool.apply( + cursor_id="edit-before", + body="# inserted-before-marker\n", + ) + assert "OK" in result + + abs_path = Path(python_serena_agent.get_active_project_or_raise().project_root) / throwaway_python_file + content = abs_path.read_text() + assert "inserted-before-marker" in content + assert content.index("inserted-before-marker") < content.index("def other_fn") + + def test_cursor_insert_after(self, python_serena_agent: SerenaAgent, throwaway_python_file: str) -> None: + """cursor_insert_after inserts content below the target symbol.""" + from pathlib import Path + + start_tool = python_serena_agent.get_tool(CursorStartTool) + insert_tool = python_serena_agent.get_tool(CursorInsertAfterTool) + + start_tool.apply( + name_path="sandbox_fn", + relative_path=throwaway_python_file, + cursor_id="edit-after", + ) + result = insert_tool.apply( + cursor_id="edit-after", + body="# inserted-after-marker\n", + ) + assert "OK" in result + + abs_path = Path(python_serena_agent.get_active_project_or_raise().project_root) / throwaway_python_file + content = abs_path.read_text() + assert "inserted-after-marker" in content + + def test_cursor_rename(self, python_serena_agent: SerenaAgent, throwaway_python_file: str) -> None: + """cursor_rename renames the symbol and re-anchors the cursor.""" + from pathlib import Path + + start_tool = python_serena_agent.get_tool(CursorStartTool) + rename_tool = python_serena_agent.get_tool(CursorRenameTool) + + start_tool.apply( + name_path="sandbox_fn", + relative_path=throwaway_python_file, + cursor_id="edit-rename", + ) + result = rename_tool.apply(cursor_id="edit-rename", new_name="renamed_sandbox_fn") + # Rename may or may not be supported by all language servers; accept graceful failure text. + abs_path = Path(python_serena_agent.get_active_project_or_raise().project_root) / throwaway_python_file + content = abs_path.read_text() + # Either the rename worked or the tool reported back (but no crash). + assert "edit-rename" in result or "renamed_sandbox_fn" in content or "sandbox_fn" in content + + def test_cursor_edit_without_location_raises(self, python_serena_agent: SerenaAgent) -> None: + """A cursor whose current location has no relative_path rejects edits.""" + from serena.cursor import CursorState + from serena.symbol import LanguageServerSymbolLocation + + manager = python_serena_agent.get_cursor_manager() + replace_tool = python_serena_agent.get_tool(CursorReplaceBodyTool) + + # Manually install a cursor with a None relative_path + sandbox_location = LanguageServerSymbolLocation(relative_path=None, line=None, column=None) + # Position with a real symbol first, then override its location. + start_tool = python_serena_agent.get_tool(CursorStartTool) + start_tool.apply( + name_path="UserService", + relative_path=os.path.join("test_repo", "services.py"), + cursor_id="no-loc", + ) + state: CursorState = manager.get_cursor("no-loc") + state.current_location = sandbox_location + + with pytest.raises(ValueError, match="no relative path"): + replace_tool.apply(cursor_id="no-loc", body="irrelevant") diff --git a/test/serena/test_serena_agent.py b/test/serena/test_serena_agent.py index be078eaea..c3c7c12f4 100644 --- a/test/serena/test_serena_agent.py +++ b/test/serena/test_serena_agent.py @@ -416,13 +416,15 @@ def test_find_symbol_name_path( agent = serena_agent find_symbol_tool = agent.get_tool(FindSymbolTool) - result = find_symbol_tool.apply_ex( + # Direct .apply() bypass — FindSymbolTool is marked ToolMarkerOptional so it is + # not in the default active set over MCP, but is still exercised as a Python class. + result = find_symbol_tool.apply( name_path_pattern=name_path, depth=0, - relative_path=None, + relative_path="", include_body=False, - include_kinds=None, - exclude_kinds=None, + include_kinds=[], + exclude_kinds=[], substring_matching=substring_matching, ) @@ -460,7 +462,8 @@ def test_find_symbol_name_path_no_match( agent = serena_agent find_symbol_tool = agent.get_tool(FindSymbolTool) - result = find_symbol_tool.apply_ex( + # Direct .apply() — FindSymbolTool is optional over MCP but still tested as a Python class. + result = find_symbol_tool.apply( name_path_pattern=name_path, depth=0, substring_matching=True, @@ -490,7 +493,8 @@ def test_find_symbol_overloaded_function(self, serena_agent: SerenaAgent, name_p agent = serena_agent find_symbol_tool = agent.get_tool(FindSymbolTool) - result = find_symbol_tool.apply_ex( + # Direct .apply() — FindSymbolTool is optional over MCP but still tested as a Python class. + result = find_symbol_tool.apply( name_path_pattern=name_path, depth=0, substring_matching=False,