Skip to content

Commit 85eaa5b

Browse files
authored
Add tool calling to the LLM base class, implement in OpenAI (#322)
1 parent a141d6c commit 85eaa5b

File tree

9 files changed

+1060
-19
lines changed

9 files changed

+1060
-19
lines changed

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
### Added
66

7+
- Added tool calling functionality to the LLM base class with OpenAI implementation, enabling structured parameter extraction and function calling.
78
- Added support for multi-vector collection in Qdrant driver.
89
- Added a `Pipeline.stream` method to stream pipeline progress.
910
- Added a new semantic match resolver to the KG Builder for entity resolution based on spaCy embeddings and cosine similarities so that nodes with similar textual properties get merged.

examples/README.md

+2
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,8 @@ are listed in [the last section of this file](#customize).
7878
- [Message history with Neo4j](./customize/llms/llm_with_neo4j_message_history.py)
7979
- [System Instruction](./customize/llms/llm_with_system_instructions.py)
8080

81+
- [Tool Calling with OpenAI](./customize/llms/openai_tool_calls.py)
82+
8183

8284
### Prompts
8385

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
"""
2+
Example showing how to use OpenAI tool calls with parameter extraction.
3+
Both synchronous and asynchronous examples are provided.
4+
5+
To run this example:
6+
1. Make sure you have the OpenAI API key in your .env file:
7+
OPENAI_API_KEY=your-api-key
8+
2. Run: python examples/tool_calls/openai_tool_calls.py
9+
"""
10+
11+
import asyncio
12+
import json
13+
import os
14+
from typing import Dict, Any
15+
16+
from dotenv import load_dotenv
17+
18+
from neo4j_graphrag.llm import OpenAILLM
19+
from neo4j_graphrag.llm.types import ToolCallResponse
20+
from neo4j_graphrag.tool import Tool, ObjectParameter, StringParameter, IntegerParameter
21+
22+
# Load environment variables from .env file (OPENAI_API_KEY required for this example)
23+
load_dotenv()
24+
25+
26+
# Create a custom Tool implementation for person info extraction
27+
parameters = ObjectParameter(
28+
description="Parameters for extracting person information",
29+
properties={
30+
"name": StringParameter(description="The person's full name"),
31+
"age": IntegerParameter(description="The person's age"),
32+
"occupation": StringParameter(description="The person's occupation"),
33+
},
34+
required_properties=["name"],
35+
additional_properties=False,
36+
)
37+
person_info_tool = Tool(
38+
name="extract_person_info",
39+
description="Extract information about a person from text",
40+
parameters=parameters,
41+
execute_func=lambda **kwargs: kwargs,
42+
)
43+
44+
# Create the tool instance
45+
TOOLS = [person_info_tool]
46+
47+
48+
def process_tool_calls(response: ToolCallResponse) -> Dict[str, Any]:
    """Print every tool call in *response* and return the first call's arguments.

    Args:
        response (ToolCallResponse): LLM response holding one or more tool calls.

    Returns:
        Dict[str, Any]: Arguments of the first tool call (kept for backward
        compatibility with callers that expect a single result).

    Raises:
        ValueError: If the response contains no tool calls.
    """
    if not response.tool_calls:
        raise ValueError("No tool calls found in response")

    print(f"\nNumber of tool calls: {len(response.tool_calls)}")
    print(f"Additional content: {response.content or 'None'}")

    extracted = []
    for call_number, tool_call in enumerate(response.tool_calls, start=1):
        print(f"\nTool call #{call_number}: {tool_call.name}")
        print(f"Arguments: {tool_call.arguments}")
        extracted.append(tool_call.arguments)

    # For backward compatibility, return the first tool call's arguments.
    # The guard above guarantees at least one entry exists.
    return extracted[0]
65+
66+
async def main() -> None:
67+
# Initialize the OpenAI LLM
68+
llm = OpenAILLM(
69+
api_key=os.getenv("OPENAI_API_KEY"),
70+
model_name="gpt-4o",
71+
model_params={"temperature": 0},
72+
)
73+
74+
# Example text containing information about a person
75+
text = "Stella Hane is a 35-year-old software engineer who loves coding."
76+
77+
print("\n=== Synchronous Tool Call ===")
78+
# Make a synchronous tool call
79+
sync_response = llm.invoke_with_tools(
80+
input=f"Extract information about the person from this text: {text}",
81+
tools=TOOLS,
82+
)
83+
sync_result = process_tool_calls(sync_response)
84+
print("\n=== Synchronous Tool Call Result ===")
85+
print(json.dumps(sync_result, indent=2))
86+
87+
print("\n=== Asynchronous Tool Call ===")
88+
# Make an asynchronous tool call with a different text
89+
text2 = "Molly Hane, 32, works as a data scientist and enjoys machine learning."
90+
async_response = await llm.ainvoke_with_tools(
91+
input=f"Extract information about the person from this text: {text2}",
92+
tools=TOOLS,
93+
)
94+
async_result = process_tool_calls(async_response)
95+
print("\n=== Asynchronous Tool Call Result ===")
96+
print(json.dumps(async_result, indent=2))
97+
98+
99+
if __name__ == "__main__":
100+
# Run the async main function
101+
asyncio.run(main())

src/neo4j_graphrag/llm/base.py

+58-2
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,14 @@
1515
from __future__ import annotations
1616

1717
from abc import ABC, abstractmethod
18-
from typing import Any, List, Optional, Union
18+
from typing import Any, List, Optional, Sequence, Union
1919

2020
from neo4j_graphrag.message_history import MessageHistory
2121
from neo4j_graphrag.types import LLMMessage
2222

23-
from .types import LLMResponse
23+
from .types import LLMResponse, ToolCallResponse
24+
25+
from neo4j_graphrag.tool import Tool
2426

2527

2628
class LLMInterface(ABC):
@@ -84,3 +86,57 @@ async def ainvoke(
8486
Raises:
8587
LLMGenerationError: If anything goes wrong.
8688
"""
89+
90+
def invoke_with_tools(
91+
self,
92+
input: str,
93+
tools: Sequence[Tool],
94+
message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None,
95+
system_instruction: Optional[str] = None,
96+
) -> ToolCallResponse:
97+
"""Sends a text input to the LLM with tool definitions and retrieves a tool call response.
98+
99+
This is a default implementation that should be overridden by LLM providers that support tool/function calling.
100+
101+
Args:
102+
input (str): Text sent to the LLM.
103+
tools (Sequence[Tool]): Sequence of Tools for the LLM to choose from. Each LLM implementation should handle the conversion to its specific format.
104+
message_history (Optional[Union[List[LLMMessage], MessageHistory]]): A collection previous messages,
105+
with each message having a specific role assigned.
106+
system_instruction (Optional[str]): An option to override the llm system message for this invocation.
107+
108+
Returns:
109+
ToolCallResponse: The response from the LLM containing a tool call.
110+
111+
Raises:
112+
LLMGenerationError: If anything goes wrong.
113+
NotImplementedError: If the LLM provider does not support tool calling.
114+
"""
115+
raise NotImplementedError("This LLM provider does not support tool calling.")
116+
async def ainvoke_with_tools(
    self,
    input: str,
    tools: Sequence[Tool],
    message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None,
    system_instruction: Optional[str] = None,
) -> ToolCallResponse:
    """Asynchronously sends a text input to the LLM with tool definitions and retrieves a tool call response.

    This is a default implementation that should be overridden by LLM providers that support tool/function calling.

    Args:
        input (str): Text sent to the LLM.
        tools (Sequence[Tool]): Sequence of Tools for the LLM to choose from. Each LLM implementation should handle the conversion to its specific format.
        message_history (Optional[Union[List[LLMMessage], MessageHistory]]): A collection of previous messages,
            with each message having a specific role assigned.
        system_instruction (Optional[str]): An option to override the llm system message for this invocation.

    Returns:
        ToolCallResponse: The response from the LLM containing a tool call.

    Raises:
        LLMGenerationError: If anything goes wrong.
        NotImplementedError: If the LLM provider does not support tool calling.
    """
    raise NotImplementedError("This LLM provider does not support tool calling.")

0 commit comments

Comments
 (0)