Skip to content

Commit 07a3d25

Browse files
authored
Merge pull request ag2ai#310 from ag2ai/tool_captain
Integrating Tools from other frameworks into CaptainAgent
2 parents 721ee7a + d2ebc86 commit 07a3d25

File tree

4 files changed

+969
-55
lines changed

4 files changed

+969
-55
lines changed

autogen/agentchat/contrib/captainagent.py

+41-20
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,16 @@
44
import hashlib
55
import json
66
import os
7-
from typing import Callable, Dict, List, Literal, Optional, Union
7+
from typing import Callable, Literal, Optional, Union
8+
9+
from termcolor import colored
810

911
import autogen
1012
from autogen import UserProxyAgent
1113
from autogen.agentchat.conversable_agent import ConversableAgent
1214

1315
from .agent_builder import AgentBuilder
14-
from .tool_retriever import ToolBuilder, get_full_tool_description
16+
from .tool_retriever import ToolBuilder, format_ag2_tool, get_full_tool_description
1517

1618

1719
class CaptainAgent(ConversableAgent):
@@ -387,8 +389,9 @@ def _run_autobuild(self, group_name: str, execution_task: str, building_task: st
387389
# tool library is enabled, reload tools and bind them to the agents
388390
tool_root_dir = self.tool_root_dir
389391
tool_builder = ToolBuilder(
390-
corpus_path=os.path.join(tool_root_dir, "tool_description.tsv"),
392+
corpus_root=tool_root_dir,
391393
retriever=self._nested_config["autobuild_tool_config"].get("retriever", "all-mpnet-base-v2"),
394+
type=self.tool_type,
392395
)
393396
for idx, agent in enumerate(agent_list):
394397
if idx == len(self.tool_history[group_name]):
@@ -404,39 +407,57 @@ def _run_autobuild(self, group_name: str, execution_task: str, building_task: st
404407
self.build_history[group_name] = agent_configs.copy()
405408

406409
if self._nested_config.get("autobuild_tool_config", None) and agent_configs["coding"] is True:
407-
print("==> Retrieving tools...", flush=True)
408410
skills = building_task.split("\n")
409411
if len(skills) == 0:
410412
skills = [building_task]
411413

414+
tool_type = "default"
412415
if self._nested_config["autobuild_tool_config"].get("tool_root", "default") == "default":
416+
print(colored("==> Retrieving tools...", "green"), flush=True)
413417
cur_path = os.path.dirname(os.path.abspath(__file__))
414418
tool_root_dir = os.path.join(cur_path, "captainagent", "tools")
419+
elif isinstance(self._nested_config["autobuild_tool_config"].get("tool_root", "default"), list):
420+
# We get a list, in this case, we assume it contains several tools for the agents
421+
tool_root_dir = self._nested_config["autobuild_tool_config"]["tool_root"]
422+
tool_type = "user_defined"
415423
else:
416424
tool_root_dir = self._nested_config["autobuild_tool_config"]["tool_root"]
417425
self.tool_root_dir = tool_root_dir
426+
self.tool_type = tool_type
418427

419428
# Retrieve and build tools based on the similarities between the skills and the tool descriptions
420429
tool_builder = ToolBuilder(
421-
corpus_path=os.path.join(tool_root_dir, "tool_description.tsv"),
430+
corpus_root=tool_root_dir,
422431
retriever=self._nested_config["autobuild_tool_config"].get("retriever", "all-mpnet-base-v2"),
432+
type=tool_type,
423433
)
424-
for idx, skill in enumerate(skills):
425-
tools = tool_builder.retrieve(skill)
434+
if tool_type == "default":
435+
for idx, skill in enumerate(skills):
436+
tools = tool_builder.retrieve(skill)
437+
docstrings = []
438+
for tool in tools:
439+
category, tool_name = tool.split(" ")[0], tool.split(" ")[1]
440+
tool_path = os.path.join(tool_root_dir, category, f"{tool_name}.py")
441+
docstring = get_full_tool_description(tool_path)
442+
docstrings.append(docstring)
443+
tool_builder.bind(agent_list[idx], "\n\n".join(docstrings))
444+
# the last agent is the user proxy agent, we need special treatment
445+
agent_list[-1] = tool_builder.bind_user_proxy(agent_list[-1], tool_root_dir)
446+
else:
447+
# a list containing all the tools that the agents share
426448
docstrings = []
427-
for tool in tools:
428-
category, tool_name = tool.split(" ")[0], tool.split(" ")[1]
429-
tool_path = os.path.join(tool_root_dir, category, f"{tool_name}.py")
430-
docstring = get_full_tool_description(tool_path)
431-
docstrings.append(docstring)
432-
tool_builder.bind(agent_list[idx], "\n\n".join(docstrings))
433-
# log tools
434-
tool_history = self.tool_history.get(group_name, [])
435-
tool_history.append(docstrings)
436-
self.tool_history[group_name] = tool_history
437-
438-
agent_list[-1] = tool_builder.bind_user_proxy(agent_list[-1], tool_root_dir)
439-
449+
for tool in tool_root_dir:
450+
docstrings.append(format_ag2_tool(tool))
451+
for idx, agent in enumerate(agent_list):
452+
if idx == len(agent_list) - 1:
453+
break
454+
tool_builder.bind(agent, "\n\n".join(docstrings))
455+
agent_list[-1] = tool_builder.bind_user_proxy(agent_list[-1], tool_root_dir)
456+
457+
# log tools
458+
tool_history = self.tool_history.get(group_name, [])
459+
tool_history.append(docstrings)
460+
self.tool_history[group_name] = tool_history
440461
else:
441462
# Build agents from scratch
442463
agent_list, agent_configs = builder.build(

autogen/agentchat/contrib/tool_retriever.py

+117-34
Original file line numberDiff line numberDiff line change
@@ -20,26 +20,39 @@
2020
from sentence_transformers import SentenceTransformer, util
2121

2222
from autogen import AssistantAgent, UserProxyAgent
23-
from autogen.coding import LocalCommandLineCodeExecutor
23+
from autogen.coding import CodeExecutor, CodeExtractor, LocalCommandLineCodeExecutor, MarkdownCodeExtractor
2424
from autogen.coding.base import CodeBlock, CodeResult
25-
from autogen.function_utils import load_basemodels_if_needed
25+
from autogen.function_utils import get_function_schema, load_basemodels_if_needed
2626
from autogen.tools import Tool
2727

2828

2929
class ToolBuilder:
30-
TOOL_USING_PROMPT = """# Functions
31-
You have access to the following functions. They can be accessed from the module called 'functions' by their function names.
30+
TOOL_PROMPT_DEFAULT = """\n## Functions
31+
You have access to the following functions. They can be accessed from the module called 'functions' by their function names.
3232
For example, if there is a function called `foo` you could import it by writing `from functions import foo`
3333
{functions}
34+
"""
35+
TOOL_PROMPT_USER_DEFINED = """\n## Functions
36+
You have access to the following functions. You can write python code to call these functions directly without importing them.
37+
{functions}
3438
"""
3539

36-
def __init__(self, corpus_path, retriever="all-mpnet-base-v2"):
37-
38-
self.df = pd.read_csv(corpus_path, sep="\t")
39-
document_list = self.df["document_content"].tolist()
40+
def __init__(self, corpus_root, retriever="all-mpnet-base-v2", type="default"):
    """Select the tool prompt and embed the tool descriptions.

    Args:
        corpus_root: when ``type`` is ``"default"``, a directory containing
            ``tool_description.tsv``; otherwise an iterable of tool objects,
            each exposing a ``description`` attribute.
        retriever: model name passed to ``SentenceTransformer``.
        type: ``"default"`` reads the TSV-backed tool library; any other
            value is handled as user-defined tools.
    """
    if type != "default":
        self.TOOL_PROMPT = self.TOOL_PROMPT_USER_DEFINED
        # User-defined tools are not retrieved by similarity; their
        # descriptions are still embedded so both modes share one code path.
        document_list = [tool.description for tool in corpus_root]
    else:
        tsv_path = os.path.join(corpus_root, "tool_description.tsv")
        self.df = pd.read_csv(tsv_path, sep="\t")
        document_list = self.df["document_content"].tolist()
        self.TOOL_PROMPT = self.TOOL_PROMPT_DEFAULT

    self.model = SentenceTransformer(retriever)
    self.embeddings = self.model.encode(document_list)
    self.type = type
4356

4457
def retrieve(self, query, top_k=3):
4558
# Encode the query using the Sentence Transformer model
@@ -55,39 +68,59 @@ def retrieve(self, query, top_k=3):
5568
def bind(self, agent: AssistantAgent, functions: str):
5669
"""Binds the function to the agent so that agent is aware of it."""
5770
sys_message = agent.system_message
58-
sys_message += self.TOOL_USING_PROMPT.format(functions=functions)
71+
sys_message += self.TOOL_PROMPT.format(functions=functions)
5972
agent.update_system_message(sys_message)
6073
return
6174

62-
def bind_user_proxy(self, agent: UserProxyAgent, tool_root: Union[str, list]):
    """
    Updates user proxy agent with an executor so that code executor can successfully execute function-related code.
    Returns an updated user proxy.

    Args:
        agent: the user proxy to copy; its name, termination check, default
            auto reply and code execution settings are reused.
        tool_root: either a string, which is passed to ``find_callables`` to
            collect the functions, or a list of user-defined tools that is
            handed to ``LocalExecutorWithTools`` as-is.

    Returns:
        A new ``UserProxyAgent`` (``human_input_mode="NEVER"``) whose code
        execution config carries the freshly built executor.
    """
    previous_config = agent._code_execution_config
    work_dir = previous_config.get("work_dir", "coding")

    # Only the executor differs between the two kinds of tool_root; the
    # proxy itself is rebuilt the same way in both cases.
    if isinstance(tool_root, str):
        # Find all the functions in the tool root
        executor = LocalCommandLineCodeExecutor(
            timeout=previous_config.get("timeout", 180),
            work_dir=work_dir,
            functions=find_callables(tool_root),
        )
    else:
        # second case: user defined tools
        executor = LocalExecutorWithTools(
            tools=tool_root,
            work_dir=work_dir,
        )

    return UserProxyAgent(
        name=agent.name,
        is_termination_msg=agent._is_termination_msg,
        code_execution_config={
            "executor": executor,
            "last_n_messages": previous_config.get("last_n_messages", 1),
        },
        human_input_mode="NEVER",
        default_auto_reply=agent._default_auto_reply,
    )
121+
122+
123+
class LocalExecutorWithTools(CodeExecutor):
91124
"""
92125
An executor that executes code blocks with injected tools. In this executor, the func within the tools can be called directly without declaring in the code block.
93126
@@ -124,6 +157,11 @@ class LocalExecutorWithTools:
124157
work_dir: The working directory for the code execution. Default is the current directory.
125158
"""
126159

160+
@property
def code_extractor(self) -> CodeExtractor:
    """(Experimental) Export a code extractor that can be used by an agent.

    A new ``MarkdownCodeExtractor`` is created on every access.
    """
    extractor = MarkdownCodeExtractor()
    return extractor
164+
127165
def __init__(self, tools: Optional[List[Tool]] = None, work_dir: Union[Path, str] = Path(".")):
128166
self.tools = tools if tools is not None else []
129167
self.work_dir = work_dir
@@ -189,6 +227,51 @@ def restart(self):
189227
pass
190228

191229

230+
def format_ag2_tool(tool: Tool):
    """Render a tool as a Python-style function stub with a usage docstring.

    The stub consists of a ``def`` line, the tool description, a usage note
    with an example call, and one line per argument found in the schema.

    Args:
        tool: the tool to describe; its ``func``, ``name`` and ``description``
            attributes are read.

    Returns:
        The stub as a single string.

    Raises:
        IndexError: if ``tool.func`` takes no parameters.
        KeyError: if the schema of the first parameter has no nested
            ``"properties"`` entry.
    """
    # get the args first
    schema = get_function_schema(tool.func, description=tool.description)

    # All tool arguments are expected to be bundled in the first parameter.
    arg_name = list(inspect.signature(tool.func).parameters.keys())[0]
    arg_info = schema["function"]["parameters"]["properties"][arg_name]["properties"]

    # NOTE(review): the indent unit is a single space as it appears here;
    # confirm this is intended and not collapsed whitespace.
    pad = " "
    usage_note = (
        f"You must format all the arguments into a dictionary and pass them as **kwargs to {arg_name}. "
        "You should use print function to get the results."
    )
    example = f"For example:\n\tresult = {tool.name}({arg_name}={{'arg1': 'value1' }})"

    parts = [
        f'def {tool.name}({arg_name}):\n """\n',
        indent(tool.description, pad) + "\n",
        indent(usage_note, pad) + "\n",
        indent(example, pad) + "\n",
        indent(f"Arguments passed in {arg_name}:\n", pad),
    ]
    for arg, info in arg_info.items():
        # A property schema is not guaranteed to carry "type" or
        # "description"; fall back instead of raising KeyError.
        arg_type = info.get("type", "any")
        arg_description = info.get("description", "")
        parts.append(indent(f"{arg} ({arg_type}): {arg_description}\n", pad * 2))
    parts.append(' """\n')
    return "".join(parts)
253+
254+
255+
def _wrap_function(func):
    """Wrap ``func`` with ``load_basemodels_if_needed``.

    The inner wrapper forwards every positional and keyword argument to
    ``func`` and returns its result unchanged; ``functools.wraps`` copies
    the metadata of ``func`` onto it.

    NOTE(review): the wrapper itself is a synchronous pass-through. It does
    no JSON serialization and does not await coroutines; confirm whether
    ``load_basemodels_if_needed`` adds either.

    Args:
        func: the function to be wrapped.

    Returns:
        The wrapped function.
    """

    def _wrapped_func(*args, **kwargs):
        return func(*args, **kwargs)

    # Same order as stacked decorators: wraps first, then the loader.
    with_metadata = functools.wraps(func)(_wrapped_func)
    return load_basemodels_if_needed(with_metadata)
273+
274+
192275
def get_full_tool_description(py_file):
193276
"""
194277
Retrieves the function signature for a given Python file.

notebook/agentchat_captainagent.ipynb

-1
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,6 @@
6666
"import autogen\n",
6767
"\n",
6868
"config_path = \"OAI_CONFIG_LIST\"\n",
69-
"llm_config = {\"temperature\": 0}\n",
7069
"config_list = autogen.config_list_from_json(\n",
7170
" config_path, filter_dict={\"model\": [\"gpt-4o\"]}\n",
7271
") # You can modify the filter_dict to select your model"

0 commit comments

Comments
 (0)