diff --git a/CHANGELOG.md b/CHANGELOG.md index 390b32e2..20933baf 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,10 @@ ## Next +### Added + +- Added CypherRetriever for executing parameterized Cypher queries with strong type validation. + ## 1.6.0 ### Added diff --git a/docs/source/api.rst b/docs/source/api.rst index f27bf3af..21b5286a 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -187,6 +187,13 @@ Text2CypherRetriever :members: search +CypherRetriever +=============== + +.. autoclass:: neo4j_graphrag.retrievers.CypherRetriever + :members: search + + ******************* External Retrievers ******************* diff --git a/docs/source/user_guide_rag.rst b/docs/source/user_guide_rag.rst index b51c019a..73a6f102 100644 --- a/docs/source/user_guide_rag.rst +++ b/docs/source/user_guide_rag.rst @@ -349,6 +349,8 @@ We provide implementations for the following retrievers: - Same as HybridRetriever with a retrieval query similar to VectorCypherRetriever. * - :ref:`Text2Cypher ` - Translates the user question into a Cypher query to be run against a Neo4j database (or Knowledge Graph). The results of the query are then passed to the LLM to generate the final answer. + * - :ref:`CypherRetriever ` + - Uses a predefined Cypher query template with parameterized inputs to retrieve data from the database. * - :ref:`WeaviateNeo4jRetriever ` - Use this retriever when vectors are saved in a Weaviate vector database * - :ref:`PineconeNeo4jRetriever ` @@ -849,6 +851,176 @@ LLMs can be different. See :ref:`text2cypherretriever`. +.. _cypher-retriever-user-guide: + +Cypher Retriever +=============================== + +The `CypherRetriever` allows you to define a templated Cypher query with parameterized inputs. This retriever is useful when you need direct database access with dynamic parameters, but without the complexity of LLM-generated queries or vector similarity search. + +Basic Usage +---------- + +The simplest usage involves defining a query with parameters: + +.. code:: python + + from neo4j_graphrag.retrievers import CypherRetriever + + # Create a retriever for finding movies by title + retriever = CypherRetriever( + driver=driver, + query="MATCH (m:Movie {title: $movie_title}) RETURN m", + parameters={ + "movie_title": { + "type": "string", + "description": "Title of a movie" + } + } + ) + + # Use the retriever with specific parameter values + results = retriever.search(parameters={"movie_title": "The Matrix"}) + +Parameter Types +--------------- + +The CypherRetriever supports these parameter types: + +- `string`: For text values +- `number`: For floating point values +- `integer`: For whole number values +- `boolean`: For true/false values +- `array`: For lists of values + +Optional Parameters +------------------ + +You can make parameters optional by setting `required: false` in the parameter definition: + +.. code:: python + + retriever = CypherRetriever( + driver=driver, + query=""" + MATCH (m:Movie) + WHERE ($title IS NULL OR m.title CONTAINS $title) + AND ($year IS NULL OR m.released = $year) + RETURN m + """, + parameters={ + "title": { + "type": "string", + "description": "Movie title to search for", + "required": False + }, + "year": { + "type": "integer", + "description": "Release year", + "required": False + } + } + ) + + # Search with only one parameter + results = retriever.search(parameters={"title": "Matrix"}) + +Complex Queries +-------------- + +You can build more complex queries with multiple parameters and conditions: + +.. code:: python + + retriever = CypherRetriever( + driver=driver, + query=""" + MATCH (m:Movie) + WHERE ($title IS NULL OR m.title CONTAINS $title) + AND ($min_year IS NULL OR m.released >= $min_year) + AND ($max_year IS NULL OR m.released <= $max_year) + AND ($min_rating IS NULL OR m.rating >= $min_rating) + RETURN m + ORDER BY m.rating DESC + LIMIT $limit + """, + parameters={ + "title": { + "type": "string", + "description": "Partial movie title to search for", + "required": False + }, + "min_year": { + "type": "integer", + "description": "Minimum release year", + "required": False + }, + "max_year": { + "type": "integer", + "description": "Maximum release year", + "required": False + }, + "min_rating": { + "type": "number", + "description": "Minimum movie rating", + "required": False + }, + "limit": { + "type": "integer", + "description": "Maximum number of results to return", + "required": True + } + } + ) + +Custom Result Formatting +----------------------- + +You can customize how the results are formatted using a result formatter: + +.. code:: python + + def movie_formatter(record): + movie = record["m"] + return RetrieverResultItem( + content=f"{movie['title']} ({movie['released']})", + metadata={ + "rating": movie.get("rating"), + "tagline": movie.get("tagline"), + } + ) + + retriever = CypherRetriever( + driver=driver, + query="MATCH (m:Movie) WHERE m.title CONTAINS $title RETURN m", + parameters={"title": {"type": "string", "description": "Movie title"}}, + result_formatter=movie_formatter + ) + +Graph Traversals +--------------- + +The CypherRetriever is particularly useful for complex graph traversals: + +.. code:: python + + retriever = CypherRetriever( + driver=driver, + query=""" + MATCH (m:Movie {title: $movie_title})<-[r:ACTED_IN]-(a:Person) + RETURN a.name as actor, r.roles as roles + ORDER BY a.name + """, + parameters={ + "movie_title": { + "type": "string", + "description": "Title of a movie" + } + } + ) + +See :ref:`cypherretriever`. + .. _custom-retriever: Custom Retriever diff --git a/examples/retrieve/cypher_retriever.py b/examples/retrieve/cypher_retriever.py new file mode 100644 index 00000000..3b45d05b --- /dev/null +++ b/examples/retrieve/cypher_retriever.py @@ -0,0 +1,197 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Example of using CypherRetriever for parametrized Cypher queries. + +This example demonstrates how to use CypherRetriever to define a retriever with +a templated Cypher query that accepts parameters at runtime. +""" + +import neo4j +from neo4j import Record +from neo4j_graphrag.retrievers import CypherRetriever +from neo4j_graphrag.types import RetrieverResultItem + +# Connect to Neo4j +# Replace with your own connection details +NEO4J_URI = "bolt://localhost:7687" +NEO4J_USER = "neo4j" +NEO4J_PASSWORD = "password" # Change this in production + +driver = neo4j.GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USER, NEO4J_PASSWORD)) + + +# Simple example: Find a movie by title +def find_movie_by_title() -> None: + retriever = CypherRetriever( + driver=driver, + query="MATCH (m:Movie {title: $movie_title}) RETURN m", + parameters={ + "movie_title": {"type": "string", "description": "Title of a movie"} + }, + ) + + # Use the retriever to search for a movie + result = retriever.search(parameters={"movie_title": "The Matrix"}) + + print("=== Find Movie by Title ===") + for item in result.items: + print(f"Movie: {item.content}") + print() + + +# Advanced example: Find movies with multiple criteria +def find_movies_by_criteria() -> None: + # Custom formatter to extract specific information + def movie_formatter(record: Record) -> RetrieverResultItem: + movie = record["m"] + return RetrieverResultItem( + content=f"{movie['title']} ({movie['released']})", + metadata={ + "rating": movie.get("rating"), + "tagline": movie.get("tagline"), + }, + ) + + # Create a more complex retriever with multiple parameters + retriever = CypherRetriever( + driver=driver, + query=""" + MATCH (m:Movie) + WHERE ($title IS NULL OR m.title CONTAINS $title) + AND ($min_year IS NULL OR m.released >= $min_year) + AND ($max_year IS NULL OR m.released <= $max_year) + AND ($min_rating IS NULL OR m.rating >= $min_rating) + RETURN m + ORDER BY m.rating DESC + LIMIT $limit + """, + parameters={ + "title": { + "type": "string", + "description": "Partial movie title to search for", + "required": False, + }, + "min_year": { + "type": "integer", + "description": "Minimum release year", + "required": False, + }, + "max_year": { + "type": "integer", + "description": "Maximum release year", + "required": False, + }, + "min_rating": { + "type": "number", + "description": "Minimum movie rating", + "required": False, + }, + "limit": { + "type": "integer", + "description": "Maximum number of results to return", + "required": True, + }, + }, + result_formatter=movie_formatter, + ) + + # Search with optional parameters + result = retriever.search( + parameters={"title": "Matrix", "min_year": 1990, "min_rating": 7.5, "limit": 5} + ) + + print("=== Find Movies by Criteria ===") + for item in result.items: + print(f"Movie: {item.content}") + if item.metadata: + if "rating" in item.metadata: + print(f" Rating: {item.metadata['rating']}") + if "tagline" in item.metadata: + print(f" Tagline: {item.metadata['tagline']}") + print() + + +# Example with relationship traversal +def find_actors_in_movie() -> None: + retriever = CypherRetriever( + driver=driver, + query=""" + MATCH (m:Movie {title: $movie_title})<-[r:ACTED_IN]-(a:Person) + RETURN a.name as actor, r.roles as roles + ORDER BY a.name + """, + parameters={ + "movie_title": {"type": "string", "description": "Title of a movie"} + }, + ) + + result = retriever.search(parameters={"movie_title": "The Matrix"}) + + print("=== Find Actors in Movie ===") + for item in result.items: + record = eval(item.content) # Simple way to parse the string representation + actor = record.get("actor", "Unknown") + roles = record.get("roles", []) + roles_str = ", ".join(roles) if roles else "Unknown role" + print(f"Actor: {actor} as {roles_str}") + print() + + +if __name__ == "__main__": + try: + # Setup: Make sure we have some movie data + with driver.session() as session: + # Check if data exists + result = session.run("MATCH (m:Movie) RETURN count(m) as count") + record = result.single() + count = record["count"] if record else 0 + + if count == 0: + print("No movie data found. Creating sample data...") + # Create sample data if none exists + session.run(""" + CREATE (TheMatrix:Movie {title:'The Matrix', released:1999, tagline:'Welcome to the Real World', rating: 8.7}) + CREATE (Keanu:Person {name:'Keanu Reeves', born:1964}) + CREATE (Carrie:Person {name:'Carrie-Anne Moss', born:1967}) + CREATE (Laurence:Person {name:'Laurence Fishburne', born:1961}) + CREATE (Hugo:Person {name:'Hugo Weaving', born:1960}) + CREATE (Keanu)-[:ACTED_IN {roles:['Neo']}]->(TheMatrix) + CREATE (Carrie)-[:ACTED_IN {roles:['Trinity']}]->(TheMatrix) + CREATE (Laurence)-[:ACTED_IN {roles:['Morpheus']}]->(TheMatrix) + CREATE (Hugo)-[:ACTED_IN {roles:['Agent Smith']}]->(TheMatrix) + CREATE (TheMatrixReloaded:Movie {title:'The Matrix Reloaded', released:2003, tagline:'Free your mind', rating: 7.2}) + CREATE (Keanu)-[:ACTED_IN {roles:['Neo']}]->(TheMatrixReloaded) + CREATE (Carrie)-[:ACTED_IN {roles:['Trinity']}]->(TheMatrixReloaded) + CREATE (Laurence)-[:ACTED_IN {roles:['Morpheus']}]->(TheMatrixReloaded) + CREATE (Hugo)-[:ACTED_IN {roles:['Agent Smith']}]->(TheMatrixReloaded) + CREATE (TheMatrixRevolutions:Movie {title:'The Matrix Revolutions', released:2003, tagline:'Everything that has a beginning has an end', rating: 6.8}) + CREATE (Keanu)-[:ACTED_IN {roles:['Neo']}]->(TheMatrixRevolutions) + CREATE (Carrie)-[:ACTED_IN {roles:['Trinity']}]->(TheMatrixRevolutions) + CREATE (Laurence)-[:ACTED_IN {roles:['Morpheus']}]->(TheMatrixRevolutions) + CREATE (Hugo)-[:ACTED_IN {roles:['Agent Smith']}]->(TheMatrixRevolutions) + """) + print("Sample data created.") + else: + print(f"Found {count} movies in the database.") + + # Run the examples + find_movie_by_title() + find_movies_by_criteria() + find_actors_in_movie() + + finally: + # Close the driver + driver.close() diff --git a/src/neo4j_graphrag/retrievers/__init__.py b/src/neo4j_graphrag/retrievers/__init__.py index 595eac93..2e957552 100644 --- a/src/neo4j_graphrag/retrievers/__init__.py +++ b/src/neo4j_graphrag/retrievers/__init__.py @@ -16,6 +16,7 @@ from .hybrid import HybridCypherRetriever, HybridRetriever from .text2cypher import Text2CypherRetriever from .vector import VectorCypherRetriever, VectorRetriever +from .cypher import CypherRetriever __all__ = [ "VectorRetriever", @@ -23,6 +24,7 @@ "HybridRetriever", "HybridCypherRetriever", "Text2CypherRetriever", + "CypherRetriever", ] diff --git a/src/neo4j_graphrag/retrievers/cypher.py b/src/neo4j_graphrag/retrievers/cypher.py new file mode 100644 index 00000000..0690c0e0 --- /dev/null +++ b/src/neo4j_graphrag/retrievers/cypher.py @@ -0,0 +1,379 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +import logging +import re +from typing import Any, Callable, Dict, Optional + +import neo4j +from pydantic_core import ErrorDetails +from neo4j.exceptions import CypherSyntaxError +from pydantic import ValidationError + +from neo4j_graphrag.exceptions import ( + RetrieverInitializationError, + SearchValidationError, +) +from neo4j_graphrag.retrievers.base import Retriever +from neo4j_graphrag.types import ( + CypherParameterDefinition, + CypherParameterType, + CypherRetrieverModel, + CypherSearchModel, + Neo4jDriverModel, + RawSearchResult, + RetrieverResultItem, +) + +logger = logging.getLogger(__name__) + + +class CypherRetriever(Retriever): + """ + Allows for the retrieval of records from a Neo4j database using a parameterized Cypher query. + + This retriever enables direct execution of predefined Cypher queries with dynamic parameters. + It ensures type safety through parameter validation and provides the standard retriever result format. + + Example: + + .. code-block:: python + + import neo4j + from neo4j_graphrag.retrievers import CypherRetriever + + driver = neo4j.GraphDatabase.driver(URI, auth=AUTH) + + # Create a retriever for finding movies by title + retriever = CypherRetriever( + driver=driver, + query="MATCH (m:Movie {title: $movie_title}) RETURN m", + parameters={ + "movie_title": { + "type": "string", + "description": "Title of a movie" + } + } + ) + + # Use the retriever with specific parameter values + results = retriever.search(parameters={"movie_title": "The Matrix"}) + + Args: + driver (neo4j.Driver): The Neo4j Python driver. + query (str): Cypher query with parameter placeholders. + parameters (Dict[str, Dict]): Parameter definitions with types and descriptions. + Each parameter should have a 'type' and 'description' field. + Supported types: 'string', 'number', 'integer', 'boolean', 'array'. + result_formatter (Optional[Callable[[neo4j.Record], RetrieverResultItem]]): + Custom function to transform a neo4j.Record to a RetrieverResultItem. + neo4j_database (Optional[str]): The name of the Neo4j database to use. + + Raises: + RetrieverInitializationError: If validation of the input arguments fail. + """ + + def __init__( + self, + driver: neo4j.Driver, + query: str, + parameters: Dict[str, Dict[str, Any]], + result_formatter: Optional[ + Callable[[neo4j.Record], RetrieverResultItem] + ] = None, + neo4j_database: Optional[str] = None, + ) -> None: + # Convert parameter dictionaries to CypherParameterDefinition objects + param_definitions = {} + for param_name, param_def in parameters.items(): + param_type = param_def.get("type", "string") + description = param_def.get("description", "") + required = param_def.get("required", True) + + try: + param_definitions[param_name] = CypherParameterDefinition( + type=param_type, description=description, required=required + ) + except ValidationError as e: + raise RetrieverInitializationError( + [ + ErrorDetails( + loc=("parameters", param_name), + msg=f"Invalid parameter definition: {e.errors()}", + type="validation_error", + input=param_def, + ) + ] + ) from e + + try: + driver_model = Neo4jDriverModel(driver=driver) + validated_data = CypherRetrieverModel( + driver_model=driver_model, + query=query, + parameters=param_definitions, + result_formatter=result_formatter, + neo4j_database=neo4j_database, + ) + except ValidationError as e: + raise RetrieverInitializationError(e.errors()) from e + + # Validate that the query is syntactically valid Cypher + self._validate_cypher_query(query) + + # Validate that all parameters in the query are defined + self._validate_query_parameters(query, param_definitions) + + super().__init__( + validated_data.driver_model.driver, validated_data.neo4j_database + ) + self.query = validated_data.query + self.parameters = validated_data.parameters + self.result_formatter = validated_data.result_formatter + + def _validate_cypher_query(self, query: str) -> None: + """ + Validates that the query is syntactically valid Cypher. + + Args: + query (str): The Cypher query to validate. + + Raises: + RetrieverInitializationError: If the query is not valid Cypher. + """ + # We can't fully validate the query without executing it, but we can check for basic syntax + if not query.strip(): + raise RetrieverInitializationError( + [ + ErrorDetails( + loc=("query",), + msg="Query cannot be empty", + type="value_error.empty", + input="", + ) + ] + ) + + # Check for presence of common Cypher keywords + if not any( + keyword in query.upper() + for keyword in ["MATCH", "RETURN", "CREATE", "MERGE", "WITH"] + ): + raise RetrieverInitializationError( + [ + ErrorDetails( + loc=("query",), + msg="Query does not appear to be valid Cypher. It should contain at least one of: MATCH, RETURN, CREATE, MERGE, WITH", + type="value_error.invalid_cypher", + input="", + ) + ] + ) + + def _validate_query_parameters( + self, query: str, parameters: Dict[str, CypherParameterDefinition] + ) -> None: + """ + Validates that all parameters in the query are defined in the parameters dictionary. + + Args: + query (str): The Cypher query to validate. + parameters (Dict[str, CypherParameterDefinition]): The parameter definitions. + + Raises: + RetrieverInitializationError: If any parameters in the query are not defined. + """ + # Find all parameters in the query (starting with $) + param_pattern = r"\$([a-zA-Z0-9_]+)" + query_params = set(re.findall(param_pattern, query)) + + # Check that all parameters in the query are defined + undefined_params = query_params - set(parameters.keys()) + if undefined_params: + raise RetrieverInitializationError( + [ + ErrorDetails( + loc=("parameters",), + msg=f"The following parameters are used in the query but not defined: {', '.join(undefined_params)}", + type="value_error.undefined_parameters", + input=undefined_params, + ) + ] + ) + + def _validate_parameter_values(self, parameters: Dict[str, Any]) -> None: + """ + Validates that parameter values match their defined types. + + Args: + parameters (Dict[str, Any]): The parameter values to validate. + + Raises: + SearchValidationError: If any parameter values do not match their defined types. + """ + # Check that all required parameters are provided + for param_name, param_def in self.parameters.items(): + if param_def.required and param_name not in parameters: + raise SearchValidationError( + [ + ErrorDetails( + loc=("parameters", param_name), + msg=f"Required parameter '{param_name}' is missing", + type="value_error.missing", + input=None, + ) + ] + ) + + # Validate the type of each parameter + for param_name, param_value in parameters.items(): + if param_name not in self.parameters: + raise SearchValidationError( + [ + ErrorDetails( + loc=("parameters", param_name), + msg=f"Unexpected parameter: {param_name}", + type="value_error.unexpected", + input=param_name, + ) + ] + ) + + param_def = self.parameters[param_name] + + # Type validation + if param_def.type == CypherParameterType.STRING: + if not isinstance(param_value, str): + raise SearchValidationError( + [ + ErrorDetails( + loc=("parameters", param_name), + msg=f"Parameter '{param_name}' should be of type string, got {type(param_value).__name__}", + type="type_error.string", + input=param_value, + ) + ] + ) + elif param_def.type == CypherParameterType.NUMBER: + if not isinstance(param_value, (int, float)): + raise SearchValidationError( + [ + ErrorDetails( + loc=("parameters", param_name), + msg=f"Parameter '{param_name}' should be of type number, got {type(param_value).__name__}", + type="type_error.number", + input=param_value, + ) + ] + ) + elif param_def.type == CypherParameterType.INTEGER: + if not isinstance(param_value, int) or isinstance(param_value, bool): + raise SearchValidationError( + [ + ErrorDetails( + loc=("parameters", param_name), + msg=f"Parameter '{param_name}' should be of type integer, got {type(param_value).__name__}", + type="type_error.integer", + input=param_value, + ) + ] + ) + elif param_def.type == CypherParameterType.BOOLEAN: + if not isinstance(param_value, bool): + raise SearchValidationError( + [ + ErrorDetails( + loc=("parameters", param_name), + msg=f"Parameter '{param_name}' should be of type boolean, got {type(param_value).__name__}", + type="type_error.boolean", + input=param_value, + ) + ] + ) + elif param_def.type == CypherParameterType.ARRAY: + if not isinstance(param_value, (list, tuple)): + raise SearchValidationError( + [ + ErrorDetails( + loc=("parameters", param_name), + msg=f"Parameter '{param_name}' should be of type array, got {type(param_value).__name__}", + type="type_error.array", + input=param_value, + ) + ] + ) + + def get_search_results(self, parameters: Dict[str, Any]) -> RawSearchResult: + """ + Executes the Cypher query with the provided parameters and returns the results. + + Args: + parameters (Dict[str, Any]): Parameter values to use in the query. + Each parameter should match the type specified in the parameter definitions. + + Raises: + SearchValidationError: If validation of the parameters fails. + + Returns: + RawSearchResult: The results of the query as a list of neo4j.Record and an optional metadata dict. + """ + try: + validated_data = CypherSearchModel(parameters=parameters) + except ValidationError as e: + raise SearchValidationError(e.errors()) from e + + # Validate parameter values against their definitions + self._validate_parameter_values(validated_data.parameters) + + logger.debug("CypherRetriever query: %s", self.query) + logger.debug("CypherRetriever parameters: %s", validated_data.parameters) + + try: + records, _, _ = self.driver.execute_query( + query_=self.query, + parameters_=validated_data.parameters, + database_=self.neo4j_database, + routing_=neo4j.RoutingControl.READ, + ) + except CypherSyntaxError as e: + raise SearchValidationError( + [ + ErrorDetails( + loc=("query",), + msg=f"Cypher syntax error: {e.message}", + type="value_error.cypher_syntax", + input=self.query, + ) + ] + ) from e + except Exception as e: + raise SearchValidationError( + [ + ErrorDetails( + loc=("query",), + msg=f"Failed to execute query: {str(e)}", + type="execution_error", + input=self.query, + ) + ] + ) from e + + return RawSearchResult( + records=records, + metadata={ + "cypher": self.query, + }, + ) diff --git a/src/neo4j_graphrag/types.py b/src/neo4j_graphrag/types.py index 1c0b7454..df3c82dd 100644 --- a/src/neo4j_graphrag/types.py +++ b/src/neo4j_graphrag/types.py @@ -16,7 +16,7 @@ import warnings from enum import Enum -from typing import Any, Callable, Literal, Optional, TypedDict, Union +from typing import Any, Callable, Literal, Optional, TypedDict, Union, Dict import neo4j from pydantic import ( @@ -312,3 +312,37 @@ def validate_session_id(cls, v: Union[str, int]) -> Union[str, int]: class LLMMessage(TypedDict): role: Literal["system", "user", "assistant"] content: str + + +class CypherParameterType(str, Enum): + """Enumeration of parameter types.""" + + STRING = "string" + NUMBER = "number" + INTEGER = "integer" + BOOLEAN = "boolean" + ARRAY = "array" + + +class CypherParameterDefinition(BaseModel): + """Definition of a Cypher query parameter.""" + + type: CypherParameterType + description: str + required: bool = True + + +class CypherRetrieverModel(BaseModel): + """Model for validating CypherRetriever arguments.""" + + driver_model: Neo4jDriverModel + query: str + parameters: Dict[str, CypherParameterDefinition] + result_formatter: Optional[Callable[[neo4j.Record], RetrieverResultItem]] = None + neo4j_database: Optional[str] = None + + +class CypherSearchModel(BaseModel): + """Model for validating search parameters.""" + + parameters: Dict[str, Any] diff --git a/tests/e2e/retrievers/test_cypher_e2e.py b/tests/e2e/retrievers/test_cypher_e2e.py new file mode 100644 index 00000000..75282b63 --- /dev/null +++ b/tests/e2e/retrievers/test_cypher_e2e.py @@ -0,0 +1,213 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import random +import string +from typing import Generator + +import neo4j +from neo4j import Record +import pytest +from neo4j.exceptions import Neo4jError + +from neo4j_graphrag.retrievers import CypherRetriever +from neo4j_graphrag.types import RetrieverResultItem + + +# Fixture to create test data +@pytest.fixture +def sample_data(driver: neo4j.Driver) -> Generator[str, None, None]: + # Generate a random prefix for category names to avoid conflicts between test runs + prefix = "".join(random.choices(string.ascii_lowercase, k=8)) + category_name = f"Category_{prefix}" + + # Create test data + try: + with driver.session() as session: + session.run( + """ + CREATE (c:Category {name: $category_name}) + CREATE (p1:Product {name: "Product1", price: 10.99, stock: 100, featured: true}) + CREATE (p2:Product {name: "Product2", price: 25.50, stock: 50, featured: false}) + CREATE (p3:Product {name: "Product3", price: 5.99, stock: 200, featured: true}) + CREATE (p1)-[:BELONGS_TO]->(c) + CREATE (p2)-[:BELONGS_TO]->(c) + CREATE (p3)-[:BELONGS_TO]->(c) + """, + category_name=category_name, + ) + except Neo4jError as e: + pytest.fail(f"Failed to create test data: {e}") + + yield category_name + + # Clean up test data + try: + with driver.session() as session: + session.run( + """ + MATCH (p:Product)-[:BELONGS_TO]->(c:Category {name: $category_name}) + DETACH DELETE p, c + """, + category_name=category_name, + ) + except Neo4jError as e: + pytest.fail(f"Failed to clean up test data: {e}") + + +def test_cypher_retriever_basic_query(driver: neo4j.Driver, sample_data: str) -> None: + """Test basic query with CypherRetriever.""" + retriever = CypherRetriever( + driver=driver, + query="MATCH (p:Product) WHERE p.price > $min_price RETURN p ORDER BY p.price", + parameters={ + "min_price": {"type": "number", "description": "Minimum product price"} + }, + ) + + # Execute the query + result = retriever.search(parameters={"min_price": 10.0}) + + # Verify the results + assert len(result.items) == 2 + assert ( + "Product1" in result.items[0].content or "Product2" in result.items[0].content + ) + assert result.metadata is not None and "cypher" in result.metadata + + +def test_cypher_retriever_multiple_parameters( + driver: neo4j.Driver, sample_data: str +) -> None: + """Test query with multiple parameters.""" + retriever = CypherRetriever( + driver=driver, + query=""" + MATCH (p:Product) + WHERE p.price >= $min_price AND p.price <= $max_price + AND p.stock > $min_stock + RETURN p + """, + parameters={ + "min_price": {"type": "number", "description": "Minimum product price"}, + "max_price": {"type": "number", "description": "Maximum product price"}, + "min_stock": {"type": "integer", "description": "Minimum stock quantity"}, + }, + ) + + # Execute the query with parameters + result = retriever.search( + parameters={"min_price": 5.0, "max_price": 15.0, "min_stock": 50} + ) + + # Verify the results + assert len(result.items) == 2 + assert any("Product1" in item.content for item in result.items) + assert any("Product3" in item.content for item in result.items) + + +def test_cypher_retriever_optional_parameters( + driver: neo4j.Driver, sample_data: str +) -> None: + """Test query with optional parameters.""" + retriever = CypherRetriever( + driver=driver, + query=""" + MATCH (p:Product) + WHERE ($featured IS NULL OR p.featured = $featured) + RETURN p + """, + parameters={ + "featured": { + "type": "boolean", + "description": "Filter for featured products", + "required": False, + } + }, + ) + + # Execute the query with the optional parameter + result_with_param = retriever.search(parameters={"featured": True}) + + # Verify the results with parameter + assert len(result_with_param.items) == 2 + assert all("featured: true" in item.content for item in result_with_param.items) + + # Execute the query without the optional parameter + result_without_param = retriever.search(parameters={}) + + # Verify the results without parameter (should return all products) + assert len(result_without_param.items) == 3 + + +def test_cypher_retriever_relationship_traversal( + driver: neo4j.Driver, sample_data: str +) -> None: + """Test query with relationship traversal.""" + retriever = CypherRetriever( + driver=driver, + query=""" + MATCH (p:Product)-[:BELONGS_TO]->(c:Category {name: $category_name}) + RETURN p.name as product, p.price as price, c.name as category + """, + parameters={ + "category_name": {"type": "string", "description": "Category name"} + }, + ) + + # Execute the query + result = retriever.search(parameters={"category_name": sample_data}) + + # Verify the results + assert len(result.items) == 3 + assert all(sample_data in item.content for item in result.items) + + +def test_cypher_retriever_custom_formatter( + driver: neo4j.Driver, sample_data: str +) -> None: + """Test query with custom result formatter.""" + + # Custom formatter that extracts product info in a structured format + def product_formatter(record: Record) -> RetrieverResultItem: + product = record["p"] + return RetrieverResultItem( + content=f"{product['name']} - ${product['price']}", + metadata={ + "price": product["price"], + "stock": product["stock"], + "featured": product["featured"], + }, + ) + + retriever = CypherRetriever( + driver=driver, + query="MATCH (p:Product) RETURN p", + parameters={}, + result_formatter=product_formatter, + ) + + # Execute the query + result = retriever.search(parameters={}) + + # Verify the results + assert len(result.items) == 3 + + # Check custom formatting + for item in result.items: + assert " - $" in item.content + assert item.metadata is not None and "price" in item.metadata + assert item.metadata is not None and "stock" in item.metadata + assert item.metadata is not None and "featured" in item.metadata diff --git a/tests/unit/retrievers/test_cypher.py b/tests/unit/retrievers/test_cypher.py new file mode 100644 index 00000000..cc1a1c86 --- /dev/null +++ b/tests/unit/retrievers/test_cypher.py @@ -0,0 +1,299 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from unittest.mock import MagicMock, patch + +import pytest +import neo4j +from neo4j import Record + +from neo4j_graphrag.exceptions import ( + RetrieverInitializationError, + SearchValidationError, +) +from neo4j_graphrag.retrievers.cypher import CypherRetriever +from neo4j_graphrag.types import RetrieverResultItem + + +class TestCypherRetriever(unittest.TestCase): + # Define class attributes for mypy + patcher1: unittest.mock._patch[MagicMock] + patcher2: unittest.mock._patch[bool] + mock_check_driver: MagicMock + + @classmethod + def setUpClass(cls) -> None: + # Patch the Neo4jDriverModel.check_driver method to pass validation with MagicMock + cls.patcher1 = patch("neo4j_graphrag.types.Neo4jDriverModel.check_driver") + cls.mock_check_driver = cls.patcher1.start() + cls.mock_check_driver.side_effect = lambda v: v + + # Patch the version check in the Retriever base class to avoid Neo4j version validation + cls.patcher2 = patch( + "neo4j_graphrag.retrievers.base.Retriever.VERIFY_NEO4J_VERSION", False + ) + cls.patcher2.start() + + @classmethod + def tearDownClass(cls) -> None: + cls.patcher1.stop() + cls.patcher2.stop() + + def setUp(self) -> None: + # Create a mock driver + self.driver = MagicMock(spec=neo4j.Driver) + self.driver.execute_query.return_value = ( + [Record({"m": {"title": "Test Movie"}, "score": 0.9})], + None, + None, + ) + + # Sample query and parameters + self.valid_query = "MATCH (m:Movie {title: $movie_title}) RETURN m" + self.valid_parameters = { + "movie_title": {"type": "string", "description": "Title of a movie"} + } + + def test_init_success(self) -> None: + # Test successful initialization + retriever = CypherRetriever( + driver=self.driver, + query=self.valid_query, + parameters=self.valid_parameters, + ) + assert retriever.query == self.valid_query + assert "movie_title" in retriever.parameters + + def test_init_empty_query(self) -> None: + # Test initialization with empty query + with pytest.raises(RetrieverInitializationError): + CypherRetriever( + driver=self.driver, + query="", + parameters=self.valid_parameters, + ) + + def test_init_invalid_query(self) -> None: + # Test initialization with invalid query + with pytest.raises(RetrieverInitializationError): + CypherRetriever( + driver=self.driver, + query="SELECT * FROM movies", # SQL, not Cypher + parameters=self.valid_parameters, + ) + + def test_init_undefined_parameters(self) -> None: + # Test initialization with undefined parameters in query + with pytest.raises(RetrieverInitializationError): + CypherRetriever( + driver=self.driver, + query="MATCH (m:Movie {title: $movie_title, year: $year}) RETURN m", + parameters=self.valid_parameters, # Missing 'year' parameter + ) + + def test_init_invalid_parameter_type(self) -> None: + # Test initialization with invalid parameter type + with pytest.raises(RetrieverInitializationError): + CypherRetriever( + driver=self.driver, + query=self.valid_query, + parameters={ + "movie_title": { + "type": "invalid_type", + "description": "Title of a movie", + } + }, + ) + + def test_search_success(self) -> None: + # Test successful search + retriever = CypherRetriever( + driver=self.driver, + query=self.valid_query, + parameters=self.valid_parameters, + ) + result = retriever.search(parameters={"movie_title": "The Matrix"}) + + # Assert driver.execute_query was called with the right parameters + self.driver.execute_query.assert_called_once() + assert result.items + assert result.metadata and "cypher" in result.metadata + assert result.metadata["cypher"] == self.valid_query + + def test_search_missing_required_parameter(self) -> None: + # Test search with missing required parameter + retriever = CypherRetriever( + driver=self.driver, + query=self.valid_query, + parameters=self.valid_parameters, + ) + with pytest.raises(SearchValidationError): + retriever.search(parameters={}) # Missing 'movie_title' + + def test_search_unexpected_parameter(self) -> None: + # Test search with unexpected parameter + retriever = CypherRetriever( + driver=self.driver, + query=self.valid_query, + parameters=self.valid_parameters, + ) + with pytest.raises(SearchValidationError): + retriever.search( + parameters={"movie_title": "The Matrix", "year": 1999} + ) # 'year' not defined + + def test_search_type_mismatch(self) -> None: + # Test search with parameter type mismatch + retriever = CypherRetriever( + driver=self.driver, + query=self.valid_query, + parameters=self.valid_parameters, + ) + with pytest.raises(SearchValidationError): + retriever.search( + parameters={"movie_title": 123} + ) # Integer, expected string + + def test_different_parameter_types(self) -> None: + # Test with different parameter types + query = ( + "MATCH (m:Movie) WHERE m.title = $title AND m.year = $year AND m.rating > $rating " + "AND m.is_available = $available AND m.genres IN $genres RETURN m" + ) + parameters = { + "title": {"type": "string", "description": "Movie title"}, + "year": {"type": "integer", "description": "Release year"}, + "rating": {"type": "number", "description": "Minimum rating"}, + "available": {"type": "boolean", "description": "Is the movie available"}, + "genres": {"type": "array", "description": "List of genres"}, + } + + retriever = CypherRetriever( + driver=self.driver, + query=query, + parameters=parameters, + ) + + # Valid parameters of different types + result = retriever.search( + parameters={ + "title": "The Matrix", + "year": 1999, + "rating": 8.5, + "available": True, + "genres": ["Action", "Sci-Fi"], + } + ) + + assert result.items + + # Test integer type validation + with pytest.raises(SearchValidationError): + retriever.search( + parameters={ + "title": "The Matrix", + "year": "1999", # String, expected integer + "rating": 8.5, + "available": True, + "genres": ["Action", "Sci-Fi"], + } + ) + + # Test number type validation + with pytest.raises(SearchValidationError): + retriever.search( + parameters={ + "title": "The Matrix", + "year": 1999, + "rating": "8.5", # String, expected number + "available": True, + "genres": ["Action", "Sci-Fi"], + } + ) + + # Test boolean type validation + with pytest.raises(SearchValidationError): + retriever.search( + parameters={ + "title": "The Matrix", + "year": 1999, + "rating": 8.5, + "available": "yes", # String, expected boolean + "genres": ["Action", "Sci-Fi"], + } + ) + + # Test array type validation + with pytest.raises(SearchValidationError): + retriever.search( + parameters={ + "title": "The Matrix", + "year": 1999, + "rating": 8.5, + "available": True, + "genres": "Action, Sci-Fi", # String, expected array + } + ) + + def test_custom_result_formatter(self) -> None: + # Test with custom result formatter + def custom_formatter(record: Record) -> RetrieverResultItem: + return RetrieverResultItem( + content=f"Movie: {record['m']['title']}", + metadata={"score": record["score"]}, + ) + + retriever = CypherRetriever( + driver=self.driver, + query=self.valid_query, + parameters=self.valid_parameters, + result_formatter=custom_formatter, + ) + + result = retriever.search(parameters={"movie_title": "The Matrix"}) + assert result.items[0].content == "Movie: Test Movie" + if result.items[0].metadata: + assert result.items[0].metadata.get("score") == 0.9 + + def test_optional_parameters(self) -> None: + # Test with optional parameters + query = "MATCH (m:Movie {title: $title}) WHERE m.year = $year RETURN m" + parameters = { + "title": {"type": "string", "description": "Movie title", "required": True}, + "year": { + "type": "integer", + "description": "Release year", + "required": False, + }, + } + + retriever = CypherRetriever( + driver=self.driver, + query=query, + parameters=parameters, + ) + + # Should succeed with only required parameters + result = retriever.search(parameters={"title": "The Matrix"}) + assert result.items + + # Should also succeed with optional parameters + result = retriever.search(parameters={"title": "The Matrix", "year": 1999}) + assert result.items + + +if __name__ == "__main__": + unittest.main()