Skip to content

Commit cd47dc2

Browse files
committed
Add tool calling to the LLM base class, implement in OpenAI
1 parent e99ebb0 commit cd47dc2

File tree

6 files changed

+617
-31
lines changed

6 files changed

+617
-31
lines changed

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
### Added
66

7+
- Added tool calling functionality to the LLM base class with OpenAI implementation, enabling structured parameter extraction and function calling.
78
- Added the `run_with_context` method to `Component`. This method includes a `context_` parameter, which provides information about the pipeline from which the component is executed (e.g., the `run_id`). It also enables the component to send events to the pipeline's callback function.
89

910

+95
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
"""
2+
Example showing how to use OpenAI tool calls with parameter extraction.
3+
Both synchronous and asynchronous examples are provided.
4+
5+
To run this example:
6+
1. Make sure you have the OpenAI API key in your .env file:
7+
OPENAI_API_KEY=your-api-key
8+
2. Run: python examples/tool_calls/openai_tool_calls.py
9+
"""
10+
11+
import asyncio
12+
import json
13+
import os
14+
from typing import Dict, Any
15+
16+
from dotenv import load_dotenv
17+
18+
from neo4j_graphrag.llm import OpenAILLM
19+
from neo4j_graphrag.llm.types import ToolCallResponse
20+
21+
# Load environment variables (e.g. OPENAI_API_KEY) from a local .env file.
load_dotenv()
23+
24+
# JSON-schema properties for the person-extraction tool's parameters.
_PERSON_PROPERTIES = {
    "name": {"type": "string", "description": "The person's full name"},
    "age": {"type": "integer", "description": "The person's age"},
    "occupation": {"type": "string", "description": "The person's occupation"},
}

# Tool definitions advertised to the model: a single function that pulls
# structured person details (name, age, occupation) out of free text.
# Only "name" is required; the other fields are optional.
TOOLS = [
    {
        "type": "function",
        "function": {
            "name": "extract_person_info",
            "description": "Extract information about a person from text",
            "parameters": {
                "type": "object",
                "properties": _PERSON_PROPERTIES,
                "required": ["name"],
            },
        },
    }
]
46+
47+
48+
def process_tool_call(response: ToolCallResponse) -> Dict[str, Any]:
    """Print the first tool call in *response* and return its extracted arguments.

    Args:
        response (ToolCallResponse): The LLM response containing tool calls.

    Returns:
        Dict[str, Any]: The arguments of the first tool call.

    Raises:
        ValueError: If the response contains no tool calls.
    """
    calls = response.tool_calls
    if not calls:
        raise ValueError("No tool calls found in response")

    # Only the first tool call is inspected in this example.
    first_call = calls[0]
    print(f"\nTool called: {first_call.name}")
    print(f"Arguments: {first_call.arguments}")
    print(f"Additional content: {response.content or 'None'}")
    return first_call.arguments
58+
59+
60+
async def main() -> None:
    """Demonstrate synchronous and asynchronous tool calls against OpenAI."""
    # NOTE(review): os.getenv may return None when the key is unset — the
    # example assumes OPENAI_API_KEY is provided via the environment/.env file.
    llm = OpenAILLM(
        api_key=os.getenv("OPENAI_API_KEY"),
        model_name="gpt-4o",
        model_params={"temperature": 0},
    )

    # Sample text containing information about a person.
    first_text = "Stella Hane is a 35-year-old software engineer who loves coding."

    print("\n=== Synchronous Tool Call ===")
    # Blocking tool call through the sync API.
    response_sync = llm.invoke_with_tools(
        input=f"Extract information about the person from this text: {first_text}",
        tools=TOOLS,
    )
    extracted_sync = process_tool_call(response_sync)
    print("\n=== Synchronous Tool Call Result ===")
    print(json.dumps(extracted_sync, indent=2))

    print("\n=== Asynchronous Tool Call ===")
    # A second text exercises the awaitable variant of the API.
    second_text = "Molly Hane, 32, works as a data scientist and enjoys machine learning."
    response_async = await llm.ainvoke_with_tools(
        input=f"Extract information about the person from this text: {second_text}",
        tools=TOOLS,
    )
    extracted_async = process_tool_call(response_async)
    print("\n=== Asynchronous Tool Call Result ===")
    print(json.dumps(extracted_async, indent=2))
91+
92+
93+
if __name__ == "__main__":
    # Entry point: drive the async example with a fresh event loop.
    asyncio.run(main())

src/neo4j_graphrag/llm/base.py

+56-2
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,12 @@
1515
from __future__ import annotations
1616

1717
from abc import ABC, abstractmethod
18-
from typing import Any, List, Optional, Union
18+
from typing import Any, Dict, List, Optional, Union
1919

2020
from neo4j_graphrag.message_history import MessageHistory
2121
from neo4j_graphrag.types import LLMMessage
2222

23-
from .types import LLMResponse
23+
from .types import LLMResponse, ToolCallResponse
2424

2525

2626
class LLMInterface(ABC):
@@ -84,3 +84,57 @@ async def ainvoke(
8484
Raises:
8585
LLMGenerationError: If anything goes wrong.
8686
"""
87+
88+
    def invoke_with_tools(
        self,
        input: str,
        tools: List[Dict[str, Any]],
        message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None,
        system_instruction: Optional[str] = None,
    ) -> ToolCallResponse:
        """Sends a text input to the LLM with tool definitions and retrieves a tool call response.

        This is a default implementation that should be overridden by LLM providers that support tool/function calling.

        Args:
            input (str): Text sent to the LLM.
            tools (List[Dict[str, Any]]): List of tool definitions for the LLM to choose from.
            message_history (Optional[Union[List[LLMMessage], MessageHistory]]): A collection of previous messages,
                with each message having a specific role assigned.
            system_instruction (Optional[str]): An option to override the llm system message for this invocation.

        Returns:
            ToolCallResponse: The response from the LLM containing a tool call.

        Raises:
            LLMGenerationError: If anything goes wrong.
            NotImplementedError: If the LLM provider does not support tool calling.
        """
        raise NotImplementedError("This LLM provider does not support tool calling.")
114+
115+
    async def ainvoke_with_tools(
        self,
        input: str,
        tools: List[Dict[str, Any]],
        message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None,
        system_instruction: Optional[str] = None,
    ) -> ToolCallResponse:
        """Asynchronously sends a text input to the LLM with tool definitions and retrieves a tool call response.

        This is a default implementation that should be overridden by LLM providers that support tool/function calling.

        Args:
            input (str): Text sent to the LLM.
            tools (List[Dict[str, Any]]): List of tool definitions for the LLM to choose from.
            message_history (Optional[Union[List[LLMMessage], MessageHistory]]): A collection of previous messages,
                with each message having a specific role assigned.
            system_instruction (Optional[str]): An option to override the llm system message for this invocation.

        Returns:
            ToolCallResponse: The response from the LLM containing a tool call.

        Raises:
            LLMGenerationError: If anything goes wrong.
            NotImplementedError: If the LLM provider does not support tool calling.
        """
        raise NotImplementedError("This LLM provider does not support tool calling.")

0 commit comments

Comments
 (0)