Skip to content

Commit 9ae73cb

Browse files
committed
Add Tool class
To avoid relying on raw OpenAI JSON-schema tool definitions
1 parent a93a1d5 commit 9ae73cb

File tree

5 files changed

+330
-89
lines changed

5 files changed

+330
-89
lines changed

examples/tool_calls/openai_tool_calls.py

+22-22
Original file line numberDiff line numberDiff line change
@@ -17,32 +17,32 @@
1717

1818
from neo4j_graphrag.llm import OpenAILLM
1919
from neo4j_graphrag.llm.types import ToolCallResponse
20+
from neo4j_graphrag.tool import Tool, ObjectParameter, StringParameter, IntegerParameter
2021

2122
# Load environment variables from .env file
2223
load_dotenv()
2324

24-
# Define a tool for extracting information from text
25-
TOOLS = [
26-
{
27-
"type": "function",
28-
"function": {
29-
"name": "extract_person_info",
30-
"description": "Extract information about a person from text",
31-
"parameters": {
32-
"type": "object",
33-
"properties": {
34-
"name": {"type": "string", "description": "The person's full name"},
35-
"age": {"type": "integer", "description": "The person's age"},
36-
"occupation": {
37-
"type": "string",
38-
"description": "The person's occupation",
39-
},
40-
},
41-
"required": ["name"],
42-
},
43-
},
44-
}
45-
]
25+
26+
# Create a custom Tool implementation for person info extraction
27+
parameters = ObjectParameter(
28+
description="Parameters for extracting person information",
29+
properties={
30+
"name": StringParameter(description="The person's full name"),
31+
"age": IntegerParameter(description="The person's age"),
32+
"occupation": StringParameter(description="The person's occupation"),
33+
},
34+
required_properties=["name"],
35+
additional_properties=False,
36+
)
37+
person_info_tool = Tool(
38+
name="extract_person_info",
39+
description="Extract information about a person from text",
40+
parameters=parameters,
41+
execute_func=lambda **kwargs: kwargs,
42+
)
43+
44+
# Create the tool instance
45+
TOOLS = [person_info_tool]
4646

4747

4848
def process_tool_call(response: ToolCallResponse) -> Dict[str, Any]:

src/neo4j_graphrag/llm/base.py

+7-5
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,15 @@
1515
from __future__ import annotations
1616

1717
from abc import ABC, abstractmethod
18-
from typing import Any, Dict, List, Optional, Union
18+
from typing import Any, List, Optional, Sequence, Union
1919

2020
from neo4j_graphrag.message_history import MessageHistory
2121
from neo4j_graphrag.types import LLMMessage
2222

2323
from .types import LLMResponse, ToolCallResponse
2424

25+
from neo4j_graphrag.tool import Tool
26+
2527

2628
class LLMInterface(ABC):
2729
"""Interface for large language models.
@@ -88,7 +90,7 @@ async def ainvoke(
8890
def invoke_with_tools(
8991
self,
9092
input: str,
91-
tools: List[Dict[str, Any]],
93+
tools: Sequence[Tool],
9294
message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None,
9395
system_instruction: Optional[str] = None,
9496
) -> ToolCallResponse:
@@ -98,7 +100,7 @@ def invoke_with_tools(
98100
99101
Args:
100102
input (str): Text sent to the LLM.
101-
tools (List[Dict[str, Any]]): List of tool definitions for the LLM to choose from.
103+
tools (Sequence[Tool]): Sequence of Tools for the LLM to choose from. Each LLM implementation should handle the conversion to its specific format.
102104
message_history (Optional[Union[List[LLMMessage], MessageHistory]]): A collection of previous messages,
103105
with each message having a specific role assigned.
104106
system_instruction (Optional[str]): An option to override the llm system message for this invocation.
@@ -115,7 +117,7 @@ def invoke_with_tools(
115117
async def ainvoke_with_tools(
116118
self,
117119
input: str,
118-
tools: List[Dict[str, Any]],
120+
tools: Sequence[Tool],
119121
message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None,
120122
system_instruction: Optional[str] = None,
121123
) -> ToolCallResponse:
@@ -125,7 +127,7 @@ async def ainvoke_with_tools(
125127
126128
Args:
127129
input (str): Text sent to the LLM.
128-
tools (List[Dict[str, Any]]): List of tool definitions for the LLM to choose from.
130+
tools (Sequence[Tool]): Sequence of Tools for the LLM to choose from. Each LLM implementation should handle the conversion to its specific format.
129131
message_history (Optional[Union[List[LLMMessage], MessageHistory]]): A collection previous messages,
130132
with each message having a specific role assigned.
131133
system_instruction (Optional[str]): An option to override the llm system message for this invocation.

src/neo4j_graphrag/llm/openai_llm.py

+42-7
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,17 @@
1616

1717
import abc
1818
import json
19-
from typing import TYPE_CHECKING, Any, List, Optional, Iterable, Union, cast
19+
from typing import (
20+
TYPE_CHECKING,
21+
Any,
22+
Dict,
23+
List,
24+
Optional,
25+
Iterable,
26+
Sequence,
27+
Union,
28+
cast,
29+
)
2030
from openai.types.chat import (
2131
ChatCompletionMessageParam,
2232
ChatCompletionToolParam,
@@ -39,6 +49,8 @@
3949
UserMessage,
4050
)
4151

52+
from neo4j_graphrag.tool import Tool
53+
4254
if TYPE_CHECKING:
4355
import openai
4456

@@ -91,6 +103,27 @@ def get_messages(
91103
messages.append(UserMessage(content=input).model_dump())
92104
return messages # type: ignore
93105

106+
def _convert_tool_to_openai_format(self, tool: Tool) -> Dict[str, Any]:
107+
"""Convert a Tool object to OpenAI's expected format.
108+
109+
Args:
110+
tool: A Tool object to convert to OpenAI's format.
111+
112+
Returns:
113+
A dictionary in OpenAI's tool format.
114+
"""
115+
try:
116+
return {
117+
"type": "function",
118+
"function": {
119+
"name": tool.get_name(),
120+
"description": tool.get_description(),
121+
"parameters": tool.get_parameters(),
122+
},
123+
}
124+
except AttributeError:
125+
raise LLMGenerationError(f"Tool {tool} is not a valid Tool object")
126+
94127
def invoke(
95128
self,
96129
input: str,
@@ -128,7 +161,7 @@ def invoke(
128161
def invoke_with_tools(
129162
self,
130163
input: str,
131-
tools: List[dict[str, Any]], # Tools definition as a list of dictionaries
164+
tools: Sequence[Tool], # Tools definition as a sequence of Tool objects
132165
message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None,
133166
system_instruction: Optional[str] = None,
134167
) -> ToolCallResponse:
@@ -137,7 +170,7 @@ def invoke_with_tools(
137170
138171
Args:
139172
input (str): Text sent to the LLM.
140-
tools (List[Dict[str, Any]]): List of tool definitions for the LLM to choose from.
173+
tools (Sequence[Tool]): Sequence of Tools for the LLM to choose from.
141174
message_history (Optional[Union[List[LLMMessage], MessageHistory]]): A collection of previous messages,
142175
with each message having a specific role assigned.
143176
system_instruction (Optional[str]): An option to override the llm system message for this invocation.
@@ -159,7 +192,8 @@ def invoke_with_tools(
159192
# Convert tools to OpenAI's expected type
160193
openai_tools: List[ChatCompletionToolParam] = []
161194
for tool in tools:
162-
openai_tools.append(cast(ChatCompletionToolParam, tool))
195+
openai_format_tool = self._convert_tool_to_openai_format(tool)
196+
openai_tools.append(cast(ChatCompletionToolParam, openai_format_tool))
163197

164198
response = self.client.chat.completions.create(
165199
messages=self.get_messages(input, message_history, system_instruction),
@@ -235,7 +269,7 @@ async def ainvoke(
235269
async def ainvoke_with_tools(
236270
self,
237271
input: str,
238-
tools: List[dict[str, Any]], # Tools definition as a list of dictionaries
272+
tools: Sequence[Tool], # Tools definition as a sequence of Tool objects
239273
message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None,
240274
system_instruction: Optional[str] = None,
241275
) -> ToolCallResponse:
@@ -244,7 +278,7 @@ async def ainvoke_with_tools(
244278
245279
Args:
246280
input (str): Text sent to the LLM.
247-
tools (List[Dict[str, Any]]): List of tool definitions for the LLM to choose from.
281+
tools (Sequence[Tool]): Sequence of Tools for the LLM to choose from.
248282
message_history (Optional[Union[List[LLMMessage], MessageHistory]]): A collection of previous messages,
249283
with each message having a specific role assigned.
250284
system_instruction (Optional[str]): An option to override the llm system message for this invocation.
@@ -266,7 +300,8 @@ async def ainvoke_with_tools(
266300
# Convert tools to OpenAI's expected type
267301
openai_tools: List[ChatCompletionToolParam] = []
268302
for tool in tools:
269-
openai_tools.append(cast(ChatCompletionToolParam, tool))
303+
openai_format_tool = self._convert_tool_to_openai_format(tool)
304+
openai_tools.append(cast(ChatCompletionToolParam, openai_format_tool))
270305

271306
response = await self.async_client.chat.completions.create(
272307
messages=self.get_messages(input, message_history, system_instruction),

0 commit comments

Comments
 (0)