16
16
17
17
import abc
18
18
import json
19
- from typing import TYPE_CHECKING , Any , List , Optional , Iterable , Union , cast
19
+ from typing import (
20
+ TYPE_CHECKING ,
21
+ Any ,
22
+ Dict ,
23
+ List ,
24
+ Optional ,
25
+ Iterable ,
26
+ Sequence ,
27
+ Union ,
28
+ cast ,
29
+ )
20
30
from openai .types .chat import (
21
31
ChatCompletionMessageParam ,
22
32
ChatCompletionToolParam ,
39
49
UserMessage ,
40
50
)
41
51
52
+ from neo4j_graphrag .tool import Tool
53
+
42
54
if TYPE_CHECKING :
43
55
import openai
44
56
@@ -91,6 +103,27 @@ def get_messages(
91
103
messages .append (UserMessage (content = input ).model_dump ())
92
104
return messages # type: ignore
93
105
106
+ def _convert_tool_to_openai_format (self , tool : Tool ) -> Dict [str , Any ]:
107
+ """Convert a Tool object to OpenAI's expected format.
108
+
109
+ Args:
110
+ tool: A Tool object to convert to OpenAI's format.
111
+
112
+ Returns:
113
+ A dictionary in OpenAI's tool format.
114
+ """
115
+ try :
116
+ return {
117
+ "type" : "function" ,
118
+ "function" : {
119
+ "name" : tool .get_name (),
120
+ "description" : tool .get_description (),
121
+ "parameters" : tool .get_parameters (),
122
+ },
123
+ }
124
+ except AttributeError :
125
+ raise LLMGenerationError (f"Tool { tool } is not a valid Tool object" )
126
+
94
127
def invoke (
95
128
self ,
96
129
input : str ,
@@ -128,7 +161,7 @@ def invoke(
128
161
def invoke_with_tools (
129
162
self ,
130
163
input : str ,
131
- tools : List [ dict [ str , Any ]] , # Tools definition as a list of dictionaries
164
+ tools : Sequence [ Tool ] , # Tools definition as a sequence of Tool objects
132
165
message_history : Optional [Union [List [LLMMessage ], MessageHistory ]] = None ,
133
166
system_instruction : Optional [str ] = None ,
134
167
) -> ToolCallResponse :
@@ -137,7 +170,7 @@ def invoke_with_tools(
137
170
138
171
Args:
139
172
input (str): Text sent to the LLM.
140
- tools (List[Dict[str, Any]] ): List of tool definitions for the LLM to choose from.
173
+ tools (List[Tool] ): List of Tools for the LLM to choose from.
141
174
message_history (Optional[Union[List[LLMMessage], MessageHistory]]): A collection previous messages,
142
175
with each message having a specific role assigned.
143
176
system_instruction (Optional[str]): An option to override the llm system message for this invocation.
@@ -159,7 +192,8 @@ def invoke_with_tools(
159
192
# Convert tools to OpenAI's expected type
160
193
openai_tools : List [ChatCompletionToolParam ] = []
161
194
for tool in tools :
162
- openai_tools .append (cast (ChatCompletionToolParam , tool ))
195
+ openai_format_tool = self ._convert_tool_to_openai_format (tool )
196
+ openai_tools .append (cast (ChatCompletionToolParam , openai_format_tool ))
163
197
164
198
response = self .client .chat .completions .create (
165
199
messages = self .get_messages (input , message_history , system_instruction ),
@@ -235,7 +269,7 @@ async def ainvoke(
235
269
async def ainvoke_with_tools (
236
270
self ,
237
271
input : str ,
238
- tools : List [ dict [ str , Any ]] , # Tools definition as a list of dictionaries
272
+ tools : Sequence [ Tool ] , # Tools definition as a sequence of Tool objects
239
273
message_history : Optional [Union [List [LLMMessage ], MessageHistory ]] = None ,
240
274
system_instruction : Optional [str ] = None ,
241
275
) -> ToolCallResponse :
@@ -244,7 +278,7 @@ async def ainvoke_with_tools(
244
278
245
279
Args:
246
280
input (str): Text sent to the LLM.
247
- tools (List[Dict[str, Any]] ): List of tool definitions for the LLM to choose from.
281
+ tools (List[Tool] ): List of Tools for the LLM to choose from.
248
282
message_history (Optional[Union[List[LLMMessage], MessageHistory]]): A collection previous messages,
249
283
with each message having a specific role assigned.
250
284
system_instruction (Optional[str]): An option to override the llm system message for this invocation.
@@ -266,7 +300,8 @@ async def ainvoke_with_tools(
266
300
# Convert tools to OpenAI's expected type
267
301
openai_tools : List [ChatCompletionToolParam ] = []
268
302
for tool in tools :
269
- openai_tools .append (cast (ChatCompletionToolParam , tool ))
303
+ openai_format_tool = self ._convert_tool_to_openai_format (tool )
304
+ openai_tools .append (cast (ChatCompletionToolParam , openai_format_tool ))
270
305
271
306
response = await self .async_client .chat .completions .create (
272
307
messages = self .get_messages (input , message_history , system_instruction ),
0 commit comments