diff --git a/CHANGELOG.md b/CHANGELOG.md
index fcec4548..3aaf3660 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -4,6 +4,7 @@
 
 ### Added
 
+- Added tool calling functionality to the LLM base class with OpenAI implementation, enabling structured parameter extraction and function calling.
 - Added support for multi-vector collection in Qdrant driver.
 - Added a `Pipeline.stream` method to stream pipeline progress.
 - Added a new semantic match resolver to the KG Builder for entity resolution based on spaCy embeddings and cosine similarities so that nodes with similar textual properties get merged.
diff --git a/examples/README.md b/examples/README.md
index 9c2c8eb0..b1b06f93 100644
--- a/examples/README.md
+++ b/examples/README.md
@@ -78,6 +78,8 @@ are listed in [the last section of this file](#customize).
 - [Message history with Neo4j](./customize/llms/llm_with_neo4j_message_history.py)
 - [System Instruction](./customize/llms/llm_with_system_instructions.py)
 
+- [Tool Calling with OpenAI](./customize/llms/openai_tool_calls.py)
+
 ### Prompts
 
diff --git a/examples/customize/llms/openai_tool_calls.py b/examples/customize/llms/openai_tool_calls.py
new file mode 100644
index 00000000..166fb724
--- /dev/null
+++ b/examples/customize/llms/openai_tool_calls.py
@@ -0,0 +1,101 @@
+"""
+Example showing how to use OpenAI tool calls with parameter extraction.
+Both synchronous and asynchronous examples are provided.
+
+To run this example:
+1. Make sure you have the OpenAI API key in your .env file:
+   OPENAI_API_KEY=your-api-key
+2. Run: python examples/customize/llms/openai_tool_calls.py
+"""
+
+import asyncio
+import json
+import os
+from typing import Dict, Any
+
+from dotenv import load_dotenv
+
+from neo4j_graphrag.llm import OpenAILLM
+from neo4j_graphrag.llm.types import ToolCallResponse
+from neo4j_graphrag.tool import Tool, ObjectParameter, StringParameter, IntegerParameter
+
+# Load environment variables from .env file (OPENAI_API_KEY required for this example)
+load_dotenv()
+
+
+# Create a custom Tool implementation for person info extraction
+parameters = ObjectParameter(
+    description="Parameters for extracting person information",
+    properties={
+        "name": StringParameter(description="The person's full name"),
+        "age": IntegerParameter(description="The person's age"),
+        "occupation": StringParameter(description="The person's occupation"),
+    },
+    required_properties=["name"],
+    additional_properties=False,
+)
+person_info_tool = Tool(
+    name="extract_person_info",
+    description="Extract information about a person from text",
+    parameters=parameters,
+    execute_func=lambda **kwargs: kwargs,
+)
+
+# Collect the tools to pass to the LLM
+TOOLS = [person_info_tool]
+
+
+def process_tool_calls(response: ToolCallResponse) -> Dict[str, Any]:
+    """Process all tool calls in the response and return the extracted parameters."""
+    if not response.tool_calls:
+        raise ValueError("No tool calls found in response")
+
+    print(f"\nNumber of tool calls: {len(response.tool_calls)}")
+    print(f"Additional content: {response.content or 'None'}")
+
+    results = []
+    for i, tool_call in enumerate(response.tool_calls):
+        print(f"\nTool call #{i + 1}: {tool_call.name}")
+        print(f"Arguments: {tool_call.arguments}")
+        results.append(tool_call.arguments)
+
+    # For backward compatibility, return the first tool call's arguments
+    return results[0] if results else {}
+
+
+async def main() -> None:
+    # Initialize the OpenAI LLM
+    llm = OpenAILLM(
+        api_key=os.getenv("OPENAI_API_KEY"),
+        model_name="gpt-4o",
+        model_params={"temperature": 0},
+    )
+
+    # Example text containing information about a person
+    text = "Stella Hane is a 35-year-old software engineer who loves coding."
+
+    print("\n=== Synchronous Tool Call ===")
+    # Make a synchronous tool call
+    sync_response = llm.invoke_with_tools(
+        input=f"Extract information about the person from this text: {text}",
+        tools=TOOLS,
+    )
+    sync_result = process_tool_calls(sync_response)
+    print("\n=== Synchronous Tool Call Result ===")
+    print(json.dumps(sync_result, indent=2))
+
+    print("\n=== Asynchronous Tool Call ===")
+    # Make an asynchronous tool call with a different text
+    text2 = "Molly Hane, 32, works as a data scientist and enjoys machine learning."
+    async_response = await llm.ainvoke_with_tools(
+        input=f"Extract information about the person from this text: {text2}",
+        tools=TOOLS,
+    )
+    async_result = process_tool_calls(async_response)
+    print("\n=== Asynchronous Tool Call Result ===")
+    print(json.dumps(async_result, indent=2))
+
+
+if __name__ == "__main__":
+    # Run the async main function
+    asyncio.run(main())
diff --git a/src/neo4j_graphrag/llm/base.py b/src/neo4j_graphrag/llm/base.py
index f2ca5170..87d28179 100644
--- a/src/neo4j_graphrag/llm/base.py
+++ b/src/neo4j_graphrag/llm/base.py
@@ -15,12 +15,14 @@
 from __future__ import annotations
 
 from abc import ABC, abstractmethod
-from typing import Any, List, Optional, Union
+from typing import Any, List, Optional, Sequence, Union
 
 from neo4j_graphrag.message_history import MessageHistory
 from neo4j_graphrag.types import LLMMessage
 
-from .types import LLMResponse
+from .types import LLMResponse, ToolCallResponse
+
+from neo4j_graphrag.tool import Tool
 
 
 class LLMInterface(ABC):
@@ -84,3 +86,57 @@ async def ainvoke(
         Raises:
             LLMGenerationError: If anything goes wrong.
         """
+
+    def invoke_with_tools(
+        self,
+        input: str,
+        tools: Sequence[Tool],
+        message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None,
+        system_instruction: Optional[str] = None,
+    ) -> ToolCallResponse:
+        """Sends a text input to the LLM with tool definitions and retrieves a tool call response.
+
+        This is a default implementation that should be overridden by LLM providers that support tool/function calling.
+
+        Args:
+            input (str): Text sent to the LLM.
+            tools (Sequence[Tool]): Sequence of Tools for the LLM to choose from. Each LLM implementation should handle the conversion to its specific format.
+            message_history (Optional[Union[List[LLMMessage], MessageHistory]]): A collection of previous messages,
+                with each message having a specific role assigned.
+            system_instruction (Optional[str]): An option to override the LLM system message for this invocation.
+
+        Returns:
+            ToolCallResponse: The response from the LLM containing a tool call.
+
+        Raises:
+            LLMGenerationError: If anything goes wrong.
+            NotImplementedError: If the LLM provider does not support tool calling.
+        """
+        raise NotImplementedError("This LLM provider does not support tool calling.")
+
+    async def ainvoke_with_tools(
+        self,
+        input: str,
+        tools: Sequence[Tool],
+        message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None,
+        system_instruction: Optional[str] = None,
+    ) -> ToolCallResponse:
+        """Asynchronously sends a text input to the LLM with tool definitions and retrieves a tool call response.
+
+        This is a default implementation that should be overridden by LLM providers that support tool/function calling.
+
+        Args:
+            input (str): Text sent to the LLM.
+            tools (Sequence[Tool]): Sequence of Tools for the LLM to choose from. Each LLM implementation should handle the conversion to its specific format.
+            message_history (Optional[Union[List[LLMMessage], MessageHistory]]): A collection of previous messages,
+                with each message having a specific role assigned.
+            system_instruction (Optional[str]): An option to override the LLM system message for this invocation.
+
+        Returns:
+            ToolCallResponse: The response from the LLM containing a tool call.
+
+        Raises:
+            LLMGenerationError: If anything goes wrong.
+            NotImplementedError: If the LLM provider does not support tool calling.
+        """
+        raise NotImplementedError("This LLM provider does not support tool calling.")
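Because the base-class defaults raise `NotImplementedError`, code written against `LLMInterface` can degrade gracefully when a provider lacks tool support. A minimal sketch of that fallback pattern (the `llm`, `prompt`, and `TOOLS` names here are illustrative placeholders, not part of this diff):

```python
# Hedged sketch: fall back to a plain completion when the provider
# does not override invoke_with_tools and therefore raises NotImplementedError.
try:
    result = llm.invoke_with_tools(prompt, tools=TOOLS)
except NotImplementedError:
    result = llm.invoke(prompt)  # provider without tool calling support
```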
diff --git a/src/neo4j_graphrag/llm/openai_llm.py b/src/neo4j_graphrag/llm/openai_llm.py
index cbf889d6..1e0228e4 100644
--- a/src/neo4j_graphrag/llm/openai_llm.py
+++ b/src/neo4j_graphrag/llm/openai_llm.py
@@ -15,7 +15,22 @@
 from __future__ import annotations
 
 import abc
-from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Union, cast
+import json
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Dict,
+    List,
+    Optional,
+    Iterable,
+    Sequence,
+    Union,
+    cast,
+)
+from openai.types.chat import (
+    ChatCompletionMessageParam,
+    ChatCompletionToolParam,
+)
 
 from pydantic import ValidationError
 
@@ -28,15 +43,16 @@
     BaseMessage,
     LLMResponse,
     MessageList,
+    ToolCall,
+    ToolCallResponse,
     SystemMessage,
     UserMessage,
 )
 
+from neo4j_graphrag.tool import Tool
+
 if TYPE_CHECKING:
     import openai
-    from openai.types.chat.chat_completion_message_param import (
-        ChatCompletionMessageParam,
-    )
 
 
 class BaseOpenAILLM(LLMInterface, abc.ABC):
@@ -87,6 +103,27 @@ def get_messages(
             messages.append(UserMessage(content=input).model_dump())
         return messages  # type: ignore
 
+    def _convert_tool_to_openai_format(self, tool: Tool) -> Dict[str, Any]:
+        """Convert a Tool object to OpenAI's expected format.
+
+        Args:
+            tool: A Tool object to convert to OpenAI's format.
+
+        Returns:
+            A dictionary in OpenAI's tool format.
+        """
+        try:
+            return {
+                "type": "function",
+                "function": {
+                    "name": tool.get_name(),
+                    "description": tool.get_description(),
+                    "parameters": tool.get_parameters(),
+                },
+            }
+        except AttributeError:
+            raise LLMGenerationError(f"Tool {tool} is not a valid Tool object")
+
     def invoke(
         self,
         input: str,
@@ -121,6 +158,80 @@ def invoke(
         except self.openai.OpenAIError as e:
             raise LLMGenerationError(e)
 
+    def invoke_with_tools(
+        self,
+        input: str,
+        tools: Sequence[Tool],  # Tools definition as a sequence of Tool objects
+        message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None,
+        system_instruction: Optional[str] = None,
+    ) -> ToolCallResponse:
+        """Sends a text input to the OpenAI chat completion model with tool definitions
+        and retrieves a tool call response.
+
+        Args:
+            input (str): Text sent to the LLM.
+            tools (Sequence[Tool]): Sequence of Tools for the LLM to choose from.
+            message_history (Optional[Union[List[LLMMessage], MessageHistory]]): A collection of previous messages,
+                with each message having a specific role assigned.
+            system_instruction (Optional[str]): An option to override the LLM system message for this invocation.
+
+        Returns:
+            ToolCallResponse: The response from the LLM containing a tool call.
+
+        Raises:
+            LLMGenerationError: If anything goes wrong.
+        """
+        try:
+            if isinstance(message_history, MessageHistory):
+                message_history = message_history.messages
+
+            params = self.model_params.copy() if self.model_params else {}
+            if "temperature" not in params:
+                params["temperature"] = 0.0
+
+            # Convert tools to OpenAI's expected type
+            openai_tools: List[ChatCompletionToolParam] = []
+            for tool in tools:
+                openai_format_tool = self._convert_tool_to_openai_format(tool)
+                openai_tools.append(cast(ChatCompletionToolParam, openai_format_tool))
+
+            response = self.client.chat.completions.create(
+                messages=self.get_messages(input, message_history, system_instruction),
+                model=self.model_name,
+                tools=openai_tools,
+                tool_choice="auto",
+                **params,
+            )
+
+            message = response.choices[0].message
+
+            # If there's no tool call, return the content as a regular response
+            if not message.tool_calls or len(message.tool_calls) == 0:
+                return ToolCallResponse(
+                    tool_calls=[],
+                    content=message.content,
+                )
+
+            # Process all tool calls
+            tool_calls = []
+
+            for tool_call in message.tool_calls:
+                try:
+                    args = json.loads(tool_call.function.arguments)
+                except (json.JSONDecodeError, AttributeError) as e:
+                    raise LLMGenerationError(
+                        f"Failed to parse tool call arguments: {e}"
+                    )
+
+                tool_calls.append(
+                    ToolCall(name=tool_call.function.name, arguments=args)
+                )
+
+            return ToolCallResponse(tool_calls=tool_calls, content=message.content)
+
+        except self.openai.OpenAIError as e:
+            raise LLMGenerationError(e)
+
     async def ainvoke(
         self,
         input: str,
@@ -155,6 +266,80 @@ async def ainvoke(
         except self.openai.OpenAIError as e:
             raise LLMGenerationError(e)
 
+    async def ainvoke_with_tools(
+        self,
+        input: str,
+        tools: Sequence[Tool],  # Tools definition as a sequence of Tool objects
+        message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None,
+        system_instruction: Optional[str] = None,
+    ) -> ToolCallResponse:
+        """Asynchronously sends a text input to the OpenAI chat completion model with tool definitions
+        and retrieves a tool call response.
+
+        Args:
+            input (str): Text sent to the LLM.
+            tools (Sequence[Tool]): Sequence of Tools for the LLM to choose from.
+            message_history (Optional[Union[List[LLMMessage], MessageHistory]]): A collection of previous messages,
+                with each message having a specific role assigned.
+            system_instruction (Optional[str]): An option to override the LLM system message for this invocation.
+
+        Returns:
+            ToolCallResponse: The response from the LLM containing a tool call.
+
+        Raises:
+            LLMGenerationError: If anything goes wrong.
+        """
+        try:
+            if isinstance(message_history, MessageHistory):
+                message_history = message_history.messages
+
+            params = self.model_params.copy() if self.model_params else {}
+            if "temperature" not in params:
+                params["temperature"] = 0.0
+
+            # Convert tools to OpenAI's expected type
+            openai_tools: List[ChatCompletionToolParam] = []
+            for tool in tools:
+                openai_format_tool = self._convert_tool_to_openai_format(tool)
+                openai_tools.append(cast(ChatCompletionToolParam, openai_format_tool))
+
+            response = await self.async_client.chat.completions.create(
+                messages=self.get_messages(input, message_history, system_instruction),
+                model=self.model_name,
+                tools=openai_tools,
+                tool_choice="auto",
+                **params,
+            )
+
+            message = response.choices[0].message
+
+            # If there's no tool call, return the content as a regular response
+            if not message.tool_calls or len(message.tool_calls) == 0:
+                return ToolCallResponse(
+                    tool_calls=[],
+                    content=message.content,
+                )
+
+            # Process all tool calls
+            tool_calls = []
+
+            for tool_call in message.tool_calls:
+                try:
+                    args = json.loads(tool_call.function.arguments)
+                except (json.JSONDecodeError, AttributeError) as e:
+                    raise LLMGenerationError(
+                        f"Failed to parse tool call arguments: {e}"
+                    )
+
+                tool_calls.append(
+                    ToolCall(name=tool_call.function.name, arguments=args)
+                )
+
+            return ToolCallResponse(tool_calls=tool_calls, content=message.content)
+
+        except self.openai.OpenAIError as e:
+            raise LLMGenerationError(e)
+
 
 class OpenAILLM(BaseOpenAILLM):
     def __init__(
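For reference, `_convert_tool_to_openai_format` chains `Tool.get_parameters()` into `ToolParameter.model_dump_tool()`, so the `person_info_tool` defined in the example file would be sent to OpenAI roughly as the following dict (an illustrative reconstruction from the schema code, not captured output):

```python
{
    "type": "function",
    "function": {
        "name": "extract_person_info",
        "description": "Extract information about a person from text",
        "parameters": {
            "type": "object",
            "description": "Parameters for extracting person information",
            "properties": {
                "name": {"type": "string", "description": "The person's full name"},
                "age": {"type": "integer", "description": "The person's age"},
                "occupation": {"type": "string", "description": "The person's occupation"},
            },
            "required": ["name"],
            "additionalProperties": False,
        },
    },
}
```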
diff --git a/src/neo4j_graphrag/llm/types.py b/src/neo4j_graphrag/llm/types.py
index 21f10922..34d68c33 100644
--- a/src/neo4j_graphrag/llm/types.py
+++ b/src/neo4j_graphrag/llm/types.py
@@ -1,5 +1,5 @@
 import warnings
-from typing import Any, Literal
+from typing import Any, Dict, List, Literal, Optional
 
 from pydantic import BaseModel
 
@@ -36,3 +36,17 @@ class SystemMessage(BaseMessage):
 
 class MessageList(BaseModel):
     messages: list[BaseMessage]
+
+
+class ToolCall(BaseModel):
+    """A tool call made by an LLM."""
+
+    name: str
+    arguments: Dict[str, Any]
+
+
+class ToolCallResponse(BaseModel):
+    """Response from an LLM containing tool calls."""
+
+    tool_calls: List[ToolCall]
+    content: Optional[str] = None
diff --git a/src/neo4j_graphrag/tool.py b/src/neo4j_graphrag/tool.py
new file mode 100644
index 00000000..63aac668
--- /dev/null
+++ b/src/neo4j_graphrag/tool.py
@@ -0,0 +1,263 @@
+from abc import ABC
+from enum import Enum
+from typing import Any, Dict, List, Callable, Optional, Union, ClassVar
+from pydantic import BaseModel, Field, model_validator
+
+
+class ParameterType(str, Enum):
+    """Enum for parameter types supported in tool parameters."""
+
+    STRING = "string"
+    INTEGER = "integer"
+    NUMBER = "number"
+    BOOLEAN = "boolean"
+    OBJECT = "object"
+    ARRAY = "array"
+
+
+class ToolParameter(BaseModel):
+    """Base class for all tool parameters using Pydantic."""
+
+    description: str
+    required: bool = False
+    type: ClassVar[ParameterType]
+
+    def model_dump_tool(self) -> Dict[str, Any]:
+        """Convert the parameter to a dictionary format for tool usage."""
+        result: Dict[str, Any] = {"type": self.type, "description": self.description}
+        if self.required:
+            result["required"] = True
+        return result
+
+    @classmethod
+    def from_dict(cls, data: Dict[str, Any]) -> "ToolParameter":
+        """Create a parameter from a dictionary."""
+        param_type = data.get("type")
+        if not param_type:
+            raise ValueError("Parameter type is required")
+
+        # Find the appropriate class based on the type
+        param_classes = {
+            ParameterType.STRING: StringParameter,
+            ParameterType.INTEGER: IntegerParameter,
+            ParameterType.NUMBER: NumberParameter,
+            ParameterType.BOOLEAN: BooleanParameter,
+            ParameterType.OBJECT: ObjectParameter,
+            ParameterType.ARRAY: ArrayParameter,
+        }
+
+        param_class = param_classes.get(param_type)
+        if not param_class:
+            raise ValueError(f"Unknown parameter type: {param_type}")
+
+        # Use type ignore since mypy doesn't understand dynamic class instantiation
+        return param_class.model_validate(data)  # type: ignore
+
+
+class StringParameter(ToolParameter):
+    """String parameter for tools."""
+
+    type: ClassVar[ParameterType] = ParameterType.STRING
+    enum: Optional[List[str]] = None
+
+    def model_dump_tool(self) -> Dict[str, Any]:
+        result = super().model_dump_tool()
+        if self.enum:
+            result["enum"] = self.enum
+        return result
+
+
+class IntegerParameter(ToolParameter):
+    """Integer parameter for tools."""
+
+    type: ClassVar[ParameterType] = ParameterType.INTEGER
+    minimum: Optional[int] = None
+    maximum: Optional[int] = None
+
+    def model_dump_tool(self) -> Dict[str, Any]:
+        result = super().model_dump_tool()
+        if self.minimum is not None:
+            result["minimum"] = self.minimum
+        if self.maximum is not None:
+            result["maximum"] = self.maximum
+        return result
+
+
+class NumberParameter(ToolParameter):
+    """Number parameter for tools."""
+
+    type: ClassVar[ParameterType] = ParameterType.NUMBER
+    minimum: Optional[float] = None
+    maximum: Optional[float] = None
+
+    def model_dump_tool(self) -> Dict[str, Any]:
+        result = super().model_dump_tool()
+        if self.minimum is not None:
+            result["minimum"] = self.minimum
+        if self.maximum is not None:
+            result["maximum"] = self.maximum
+        return result
+
+
+class BooleanParameter(ToolParameter):
+    """Boolean parameter for tools."""
+
+    type: ClassVar[ParameterType] = ParameterType.BOOLEAN
+
+
+class ArrayParameter(ToolParameter):
+    """Array parameter for tools."""
+
+    type: ClassVar[ParameterType] = ParameterType.ARRAY
+    items: "ToolParameter"
+    min_items: Optional[int] = None
+    max_items: Optional[int] = None
+
+    @model_validator(mode="before")
+    @classmethod
+    def _preprocess_items(cls, values: dict[str, Any]) -> dict[str, Any]:
+        # Convert items from dict to ToolParameter if needed
+        items = values.get("items")
+        if isinstance(items, dict):
+            values["items"] = ToolParameter.from_dict(items)
+        return values
+
+    def model_dump_tool(self) -> Dict[str, Any]:
+        result = super().model_dump_tool()
+        result["items"] = self.items.model_dump_tool()
+        if self.min_items is not None:
+            result["minItems"] = self.min_items
+        if self.max_items is not None:
+            result["maxItems"] = self.max_items
+        return result
+
+    @model_validator(mode="after")
+    def validate_items(self) -> "ArrayParameter":
+        if not isinstance(self.items, ToolParameter):
+            if isinstance(self.items, dict):
+                self.items = ToolParameter.from_dict(self.items)
+            else:
+                raise ValueError(
+                    f"Items must be a ToolParameter or dict, got {type(self.items)}"
+                )
+        elif type(self.items) is ToolParameter:
+            # Promote base ToolParameter to correct subclass if possible
+            self.items = ToolParameter.from_dict(self.items.model_dump())
+        return self
+
+
+class ObjectParameter(ToolParameter):
+    """Object parameter for tools."""
+
+    type: ClassVar[ParameterType] = ParameterType.OBJECT
+    properties: Dict[str, ToolParameter]
+    required_properties: List[str] = Field(default_factory=list)
+    additional_properties: bool = True
+
+    @model_validator(mode="before")
+    @classmethod
+    def _preprocess_properties(cls, values: dict[str, Any]) -> dict[str, Any]:
+        # Convert properties from dicts to ToolParameter if needed
+        props = values.get("properties")
+        if isinstance(props, dict):
+            new_props = {}
+            for k, v in props.items():
+                if isinstance(v, dict):
+                    new_props[k] = ToolParameter.from_dict(v)
+                else:
+                    new_props[k] = v
+            values["properties"] = new_props
+        return values
+
+    def model_dump_tool(self) -> Dict[str, Any]:
+        properties_dict: Dict[str, Any] = {}
+        for name, param in self.properties.items():
+            properties_dict[name] = param.model_dump_tool()
+
+        result = super().model_dump_tool()
+        result["properties"] = properties_dict
+
+        if self.required_properties:
+            result["required"] = self.required_properties
+
+        if not self.additional_properties:
+            result["additionalProperties"] = False
+
+        return result
+
+    @model_validator(mode="after")
+    def validate_properties(self) -> "ObjectParameter":
+        validated_properties = {}
+        for name, param in self.properties.items():
+            if not isinstance(param, ToolParameter):
+                if isinstance(param, dict):
+                    validated_properties[name] = ToolParameter.from_dict(param)
+                else:
+                    raise ValueError(
+                        f"Property {name} must be a ToolParameter or dict, got {type(param)}"
+                    )
+            elif type(param) is ToolParameter:
+                # Promote base ToolParameter to correct subclass if possible
+                validated_properties[name] = ToolParameter.from_dict(param.model_dump())
+            else:
+                validated_properties[name] = param
+        self.properties = validated_properties
+        return self
+
+
+class Tool(ABC):
+    """Abstract base class defining the interface for all tools in the neo4j-graphrag library."""
+
+    def __init__(
+        self,
+        name: str,
+        description: str,
+        parameters: Union[ObjectParameter, Dict[str, Any]],
+        execute_func: Callable[..., Any],
+    ):
+        self._name = name
+        self._description = description
+
+        # Allow parameters to be provided as a dictionary
+        if isinstance(parameters, dict):
+            self._parameters = ObjectParameter.model_validate(parameters)
+        else:
+            self._parameters = parameters
+
+        self._execute_func = execute_func
+
+    def get_name(self) -> str:
+        """Get the name of the tool.
+
+        Returns:
+            str: Name of the tool.
+        """
+        return self._name
+
+    def get_description(self) -> str:
+        """Get a detailed description of what the tool does.
+
+        Returns:
+            str: Description of the tool.
+        """
+        return self._description
+
+    def get_parameters(self) -> Dict[str, Any]:
+        """Get the parameters the tool accepts in a dictionary format suitable for LLM providers.
+
+        Returns:
+            Dict[str, Any]: Dictionary containing parameter schema information.
+        """
+        return self._parameters.model_dump_tool()
+
+    def execute(self, query: str, **kwargs: Any) -> Any:
+        """Execute the tool with the given query and additional parameters.
+
+        Args:
+            query (str): The query or input for the tool to process.
+            **kwargs (Any): Additional parameters for the tool.
+
+        Returns:
+            Any: The result of the tool execution.
+        """
+        return self._execute_func(query, **kwargs)
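`ToolParameter.from_dict` recurses through the `_preprocess_items` and `_preprocess_properties` validators, so nested schemas written as plain dicts are promoted to typed parameters all the way down. A small illustration (the schema itself is made up for this example):

```python
from neo4j_graphrag.tool import ArrayParameter, ObjectParameter, ToolParameter

contacts = ToolParameter.from_dict(
    {
        "type": "array",
        "description": "A list of contacts",
        "items": {
            "type": "object",
            "description": "A single contact",
            "properties": {
                "email": {"type": "string", "description": "Email address"},
            },
        },
    }
)
assert isinstance(contacts, ArrayParameter)
assert isinstance(contacts.items, ObjectParameter)  # promoted from the nested dict
```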
diff --git a/tests/unit/llm/test_openai_llm.py b/tests/unit/llm/test_openai_llm.py
index 03fbf120..4220f3b3 100644
--- a/tests/unit/llm/test_openai_llm.py
+++ b/tests/unit/llm/test_openai_llm.py
@@ -19,6 +19,8 @@
 from neo4j_graphrag.exceptions import LLMGenerationError
 from neo4j_graphrag.llm import LLMResponse
 from neo4j_graphrag.llm.openai_llm import AzureOpenAILLM, OpenAILLM
+from neo4j_graphrag.llm.types import ToolCallResponse
+from neo4j_graphrag.tool import Tool, ObjectParameter, StringParameter
 
 
 def get_mock_openai() -> MagicMock:
@@ -27,6 +29,25 @@ def get_mock_openai() -> MagicMock:
     return mock
 
 
+class TestTool(Tool):
+    """Test tool for unit tests."""
+
+    def __init__(self, name: str = "test_tool", description: str = "A test tool"):
+        parameters = ObjectParameter(
+            description="Test parameters",
+            properties={"param1": StringParameter(description="Test parameter")},
+            required_properties=["param1"],
+            additional_properties=False,
+        )
+
+        super().__init__(
+            name=name,
+            description=description,
+            parameters=parameters,
+            execute_func=lambda **kwargs: kwargs,
+        )
+
+
 @patch("builtins.__import__", side_effect=ImportError)
 def test_openai_llm_missing_dependency(mock_import: Mock) -> None:
     with pytest.raises(ImportError):
@@ -65,10 +86,14 @@ def test_openai_llm_with_message_history_happy_path(mock_import: Mock) -> None:
     assert isinstance(res, LLMResponse)
     assert res.content == "openai chat response"
     message_history.append({"role": "user", "content": question})
-    llm.client.chat.completions.create.assert_called_once_with(  # type: ignore
-        messages=message_history,
-        model="gpt",
-    )
+    # Use assert_called_once() instead of assert_called_once_with() to avoid issues with overloaded functions
+    llm.client.chat.completions.create.assert_called_once()  # type: ignore
+    # Check call arguments individually
+    call_args = llm.client.chat.completions.create.call_args[  # type: ignore
+        1
+    ]  # Get the keyword arguments
+    assert call_args["messages"] == message_history
+    assert call_args["model"] == "gpt"
 
 
 @patch("builtins.__import__")
@@ -97,10 +122,14 @@ def test_openai_llm_with_message_history_and_system_instruction(
     messages = [{"role": "system", "content": system_instruction}]
     messages.extend(message_history)
     messages.append({"role": "user", "content": question})
-    llm.client.chat.completions.create.assert_called_once_with(  # type: ignore
-        messages=messages,
-        model="gpt",
-    )
+    # Use assert_called_once() instead of assert_called_once_with() to avoid issues with overloaded functions
+    llm.client.chat.completions.create.assert_called_once()  # type: ignore
+    # Check call arguments individually
+    call_args = llm.client.chat.completions.create.call_args[  # type: ignore
+        1
+    ]  # Get the keyword arguments
+    assert call_args["messages"] == messages
+    assert call_args["model"] == "gpt"
 
     assert llm.client.chat.completions.create.call_count == 1  # type: ignore
 
@@ -124,6 +153,183 @@
     assert "Input should be 'user', 'assistant' or 'system'" in str(exc_info.value)
 
 
+@patch("builtins.__import__")
+@patch("json.loads")
+def test_openai_llm_invoke_with_tools_happy_path(
+    mock_json_loads: Mock, mock_import: Mock
+) -> None:
+    # Set up json.loads to return a dictionary
+    mock_json_loads.return_value = {"param1": "value1"}
+
+    mock_openai = get_mock_openai()
+    mock_import.return_value = mock_openai
+
+    # Mock the tool call response
+    mock_function = MagicMock()
+    mock_function.name = "test_tool"
+    mock_function.arguments = '{"param1": "value1"}'
+
+    mock_tool_call = MagicMock()
+    mock_tool_call.function = mock_function
+
+    mock_openai.OpenAI.return_value.chat.completions.create.return_value = MagicMock(
+        choices=[
+            MagicMock(
+                message=MagicMock(
+                    content="openai tool response", tool_calls=[mock_tool_call]
+                )
+            )
+        ],
+    )
+
+    llm = OpenAILLM(api_key="my key", model_name="gpt")
+    tools = [TestTool()]
+
+    res = llm.invoke_with_tools("my text", tools)
+    assert isinstance(res, ToolCallResponse)
+    assert len(res.tool_calls) == 1
+    assert res.tool_calls[0].name == "test_tool"
+    assert res.tool_calls[0].arguments == {"param1": "value1"}
+    assert res.content == "openai tool response"
+
+
+@patch("builtins.__import__")
+@patch("json.loads")
+def test_openai_llm_invoke_with_tools_with_message_history(
+    mock_json_loads: Mock, mock_import: Mock
+) -> None:
+    # Set up json.loads to return a dictionary
+    mock_json_loads.return_value = {"param1": "value1"}
+
+    mock_openai = get_mock_openai()
+    mock_import.return_value = mock_openai
+
+    # Mock the tool call response
+    mock_function = MagicMock()
+    mock_function.name = "test_tool"
+    mock_function.arguments = '{"param1": "value1"}'
+
+    mock_tool_call = MagicMock()
+    mock_tool_call.function = mock_function
+
+    mock_openai.OpenAI.return_value.chat.completions.create.return_value = MagicMock(
+        choices=[
+            MagicMock(
+                message=MagicMock(
+                    content="openai tool response", tool_calls=[mock_tool_call]
+                )
+            )
+        ],
+    )
+
+    llm = OpenAILLM(api_key="my key", model_name="gpt")
+    tools = [TestTool()]
+
+    message_history = [
+        {"role": "user", "content": "When does the sun come up in the summer?"},
+        {"role": "assistant", "content": "Usually around 6am."},
+    ]
+    question = "What about next season?"
+
+    res = llm.invoke_with_tools(question, tools, message_history)  # type: ignore
+    assert isinstance(res, ToolCallResponse)
+    assert len(res.tool_calls) == 1
+    assert res.tool_calls[0].name == "test_tool"
+    assert res.tool_calls[0].arguments == {"param1": "value1"}
+
+    # Verify the correct messages were passed
+    message_history.append({"role": "user", "content": question})
+    # Use assert_called_once() instead of assert_called_once_with() to avoid issues with overloaded functions
+    llm.client.chat.completions.create.assert_called_once()  # type: ignore
+    # Check call arguments individually
+    call_args = llm.client.chat.completions.create.call_args[  # type: ignore
+        1
+    ]  # Get the keyword arguments
+    assert call_args["messages"] == message_history
+    assert call_args["model"] == "gpt"
+    # Check tools content rather than direct equality
+    assert len(call_args["tools"]) == 1
+    assert call_args["tools"][0]["type"] == "function"
+    assert call_args["tools"][0]["function"]["name"] == "test_tool"
+    assert call_args["tools"][0]["function"]["description"] == "A test tool"
+    assert call_args["tool_choice"] == "auto"
+    assert call_args["temperature"] == 0.0
+
+
+@patch("builtins.__import__")
+@patch("json.loads")
+def test_openai_llm_invoke_with_tools_with_system_instruction(
+    mock_json_loads: Mock, mock_import: Mock
+) -> None:
+    # Set up json.loads to return a dictionary
+    mock_json_loads.return_value = {"param1": "value1"}
+
+    mock_openai = get_mock_openai()
+    mock_import.return_value = mock_openai
+
+    # Mock the tool call response
+    mock_function = MagicMock()
+    mock_function.name = "test_tool"
+    mock_function.arguments = '{"param1": "value1"}'
+
+    mock_tool_call = MagicMock()
+    mock_tool_call.function = mock_function
+
+    mock_openai.OpenAI.return_value.chat.completions.create.return_value = MagicMock(
+        choices=[
+            MagicMock(
+                message=MagicMock(
+                    content="openai tool response", tool_calls=[mock_tool_call]
+                )
+            )
+        ],
+    )
+
+    llm = OpenAILLM(api_key="my key", model_name="gpt")
+    tools = [TestTool()]
+
+    system_instruction = "You are a helpful assistant."
+
+    res = llm.invoke_with_tools("my text", tools, system_instruction=system_instruction)
+    assert isinstance(res, ToolCallResponse)
+
+    # Verify system instruction was included
+    messages = [{"role": "system", "content": system_instruction}]
+    messages.append({"role": "user", "content": "my text"})
+    # Use assert_called_once() instead of assert_called_once_with() to avoid issues with overloaded functions
+    llm.client.chat.completions.create.assert_called_once()  # type: ignore
+    # Check call arguments individually
+    call_args = llm.client.chat.completions.create.call_args[  # type: ignore
+        1
+    ]  # Get the keyword arguments
+    assert call_args["messages"] == messages
+    assert call_args["model"] == "gpt"
+    # Check tools content rather than direct equality
+    assert len(call_args["tools"]) == 1
+    assert call_args["tools"][0]["type"] == "function"
+    assert call_args["tools"][0]["function"]["name"] == "test_tool"
+    assert call_args["tools"][0]["function"]["description"] == "A test tool"
+    assert call_args["tool_choice"] == "auto"
+    assert call_args["temperature"] == 0.0
+
+
+@patch("builtins.__import__")
+def test_openai_llm_invoke_with_tools_error(mock_import: Mock) -> None:
+    mock_openai = get_mock_openai()
+    mock_import.return_value = mock_openai
+
+    # Mock an OpenAI error
+    mock_openai.OpenAI.return_value.chat.completions.create.side_effect = (
+        openai.OpenAIError("Test error")
+    )
+
+    llm = OpenAILLM(api_key="my key", model_name="gpt")
+    tools = [TestTool()]
+
+    with pytest.raises(LLMGenerationError):
+        llm.invoke_with_tools("my text", tools)
+
+
 @patch("builtins.__import__", side_effect=ImportError)
 def test_azure_openai_llm_missing_dependency(mock_import: Mock) -> None:
     with pytest.raises(ImportError):
@@ -177,10 +383,14 @@ def test_azure_openai_llm_with_message_history_happy_path(mock_import: Mock) ->
     assert isinstance(res, LLMResponse)
     assert res.content == "openai chat response"
     message_history.append({"role": "user", "content": question})
-    llm.client.chat.completions.create.assert_called_once_with(  # type: ignore
-        messages=message_history,
-        model="gpt",
-    )
+    # Use assert_called_once() instead of assert_called_once_with() to avoid issues with overloaded functions
+    llm.client.chat.completions.create.assert_called_once()  # type: ignore
+    # Check call arguments individually
+    call_args = llm.client.chat.completions.create.call_args[  # type: ignore
+        1
+    ]  # Get the keyword arguments
+    assert call_args["messages"] == message_history
+    assert call_args["model"] == "gpt"
 
 
 @patch("builtins.__import__")
diff --git a/tests/unit/tool/test_tool.py b/tests/unit/tool/test_tool.py
new file mode 100644
index 00000000..b3b1d5dd
--- /dev/null
+++ b/tests/unit/tool/test_tool.py
@@ -0,0 +1,208 @@
+import pytest
+from typing import Any
+from neo4j_graphrag.tool import (
+    StringParameter,
+    IntegerParameter,
+    NumberParameter,
+    BooleanParameter,
+    ArrayParameter,
+    ObjectParameter,
+    Tool,
+    ToolParameter,
+    ParameterType,
+)
+
+
+def test_string_parameter() -> None:
+    param = StringParameter(description="A string", required=True, enum=["a", "b"])
+    assert param.description == "A string"
+    assert param.required is True
+    assert param.enum == ["a", "b"]
+    d = param.model_dump_tool()
+    assert d["type"] == ParameterType.STRING
+    assert d["enum"] == ["a", "b"]
+    assert d["required"] is True
+
+
+def test_integer_parameter() -> None:
+    param = IntegerParameter(description="An int", minimum=0, maximum=10)
+    d = param.model_dump_tool()
+    assert d["type"] == ParameterType.INTEGER
+    assert d["minimum"] == 0
+    assert d["maximum"] == 10
+
+
+def test_number_parameter() -> None:
+    param = NumberParameter(description="A number", minimum=1.5, maximum=3.5)
+    d = param.model_dump_tool()
+    assert d["type"] == ParameterType.NUMBER
+    assert d["minimum"] == 1.5
+    assert d["maximum"] == 3.5
+
+
+def test_boolean_parameter() -> None:
+    param = BooleanParameter(description="A bool")
+    d = param.model_dump_tool()
+    assert d["type"] == ParameterType.BOOLEAN
+    assert d["description"] == "A bool"
+
+
+def test_array_parameter_and_validation() -> None:
+    arr_param = ArrayParameter(
+        description="An array",
+        items=StringParameter(description="str"),
+        min_items=1,
+        max_items=5,
+    )
+    d = arr_param.model_dump_tool()
+    assert d["type"] == ParameterType.ARRAY
+    assert d["items"]["type"] == ParameterType.STRING
+    assert d["minItems"] == 1
+    assert d["maxItems"] == 5
+
+    # Test items as dict
+    arr_param2 = ArrayParameter(
+        description="Arr with dict",
+        items={"type": "string", "description": "str"},  # type: ignore
+    )
+    assert isinstance(arr_param2.items, StringParameter)
+
+    # Test error on invalid items
+    with pytest.raises(ValueError):
+        # Use type: ignore to bypass type checking for this intentional error case
+        ArrayParameter(description="bad", items=123).validate_items()  # type: ignore
+
+
+def test_object_parameter_and_validation() -> None:
+    obj_param = ObjectParameter(
+        description="Obj",
+        properties={
+            "foo": StringParameter(description="foo"),
+            "bar": IntegerParameter(description="bar"),
+        },
+        required_properties=["foo"],
+        additional_properties=False,
+    )
+    d = obj_param.model_dump_tool()
+    assert d["type"] == ParameterType.OBJECT
+    assert d["properties"]["foo"]["type"] == ParameterType.STRING
+    assert d["required"] == ["foo"]
+    assert d["additionalProperties"] is False
+
+    # Test properties as dicts
+    obj_param2 = ObjectParameter(
+        description="Obj2",
+        properties={
+            "foo": {"type": "string", "description": "foo"},  # type: ignore
+        },
+    )
+    assert isinstance(obj_param2.properties["foo"], StringParameter)
+
+    # Test error on invalid property
+    with pytest.raises(ValueError):
+        # Use type: ignore to bypass type checking for this intentional error case
+        ObjectParameter(
+            description="bad",
+            properties={"foo": 123},  # type: ignore
+        ).validate_properties()
+
+
+def test_from_dict() -> None:
+    d = {"type": ParameterType.STRING, "description": "desc"}
+    param = ToolParameter.from_dict(d)
+    assert isinstance(param, StringParameter)
+    assert param.description == "desc"
+
+    obj_dict = {
+        "type": "object",
+        "description": "obj",
+        "properties": {"foo": {"type": "string", "description": "foo"}},
+    }
+    obj_param = ToolParameter.from_dict(obj_dict)
+    assert isinstance(obj_param, ObjectParameter)
+    assert isinstance(obj_param.properties["foo"], StringParameter)
+
+    arr_dict = {
+        "type": "array",
+        "description": "arr",
+        "items": {"type": "integer", "description": "int"},
+    }
+    arr_param = ToolParameter.from_dict(arr_dict)
+    assert isinstance(arr_param, ArrayParameter)
+    assert isinstance(arr_param.items, IntegerParameter)
+
+    # Test unknown type
+    with pytest.raises(ValueError):
+        ToolParameter.from_dict({"type": "unknown", "description": "bad"})
+
+    # Test missing type
+    with pytest.raises(ValueError):
+        ToolParameter.from_dict({"description": "no type"})
+
+
+def test_required_parameter() -> None:
+    # Test that required=True is included in model_dump_tool output for different parameter types
+    string_param = StringParameter(description="Required string", required=True)
+    assert string_param.model_dump_tool()["required"] is True
+
+    integer_param = IntegerParameter(description="Required integer", required=True)
+    assert integer_param.model_dump_tool()["required"] is True
+
+    number_param = NumberParameter(description="Required number", required=True)
+    assert number_param.model_dump_tool()["required"] is True
+
+    boolean_param = BooleanParameter(description="Required boolean", required=True)
+    assert boolean_param.model_dump_tool()["required"] is True
+
+    array_param = ArrayParameter(
+        description="Required array",
+        items=StringParameter(description="item"),
+        required=True,
+    )
+    assert array_param.model_dump_tool()["required"] is True
+
+    object_param = ObjectParameter(
+        description="Required object",
+        properties={"prop": StringParameter(description="property")},
+        required=True,
+    )
+    assert object_param.model_dump_tool()["required"] is True
+
+    # Test that required=False doesn't include the required field
+    optional_param = StringParameter(description="Optional string", required=False)
+    assert "required" not in optional_param.model_dump_tool()
+
+
+def test_tool_class() -> None:
+    def dummy_func(query: str, **kwargs: Any) -> dict[str, Any]:
+        return kwargs
+
+    params = ObjectParameter(
+        description="params",
+        properties={"a": StringParameter(description="a")},
+    )
+    tool = Tool(
+        name="mytool",
+        description="desc",
+        parameters=params,
+        execute_func=dummy_func,
+    )
+    assert tool.get_name() == "mytool"
+    assert tool.get_description() == "desc"
+    assert tool.get_parameters()["type"] == ParameterType.OBJECT
+    assert tool.execute("query", a="b") == {"a": "b"}
+
+    # Test parameters as dict
+    params_dict = {
+        "type": "object",
+        "description": "params",
+        "properties": {"a": {"type": "string", "description": "a"}},
+    }
+    tool2 = Tool(
+        name="mytool2",
+        description="desc2",
+        parameters=params_dict,
+        execute_func=dummy_func,
+    )
+    assert tool2.get_parameters()["type"] == ParameterType.OBJECT
+    assert tool2.execute("query", a="b") == {"a": "b"}
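Taken together, the pieces line up as follows: `Tool` describes the schema, `invoke_with_tools` returns a `ToolCallResponse`, and the caller decides what to do with each `ToolCall`. One possible dispatch loop over those types (hypothetical glue, not part of this diff; note that `Tool.execute` passes the query positionally, so the tool's `execute_func` must accept it):

```python
from typing import Any, List

from neo4j_graphrag.llm.types import ToolCallResponse
from neo4j_graphrag.tool import Tool


def dispatch_tool_calls(
    response: ToolCallResponse, tools: List[Tool], query: str
) -> List[Any]:
    """Route each ToolCall back to the Tool with the matching name."""
    tools_by_name = {tool.get_name(): tool for tool in tools}
    results = []
    for call in response.tool_calls:
        tool = tools_by_name[call.name]
        # Tool.execute forwards the query plus the model-extracted arguments as kwargs
        results.append(tool.execute(query, **call.arguments))
    return results
```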