Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add visualization method to display the agent' structure as a tree 🌳 #470

Merged
merged 1 commit into from
Feb 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/smolagents/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,10 @@ def write_memory_to_messages(
messages.extend(memory_step.to_messages(summary_mode=summary_mode))
return messages

def visualize(self):
"""Creates a rich tree visualization of the agent's structure."""
self.logger.visualize_agent_tree(self)

def extract_action(self, model_output: str, split_token: str) -> Tuple[str, str]:
"""
Parse action from the LLM output
Expand Down
10 changes: 7 additions & 3 deletions src/smolagents/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,8 +241,6 @@ class Model:
def __init__(self, **kwargs):
self.last_input_token_count = None
self.last_output_token_count = None
# Set default values for common parameters
kwargs.setdefault("max_tokens", 4096)
self.kwargs = kwargs

def _prepare_completion_kwargs(
Expand Down Expand Up @@ -643,15 +641,19 @@ class LiteLLMModel(Model):
The base URL of the OpenAI-compatible API server.
api_key (`str`, *optional*):
The API key to use for authentication.
custom_role_conversions (`dict[str, str]`, *optional*):
Custom role conversion mapping to convert message roles in others.
Useful for specific models that do not support specific message roles like "system".
**kwargs:
Additional keyword arguments to pass to the OpenAI API.
"""

def __init__(
self,
model_id="anthropic/claude-3-5-sonnet-20240620",
model_id: str = "anthropic/claude-3-5-sonnet-20240620",
api_base=None,
api_key=None,
custom_role_conversions: Optional[Dict[str, str]] = None,
**kwargs,
):
try:
Expand All @@ -667,6 +669,7 @@ def __init__(
litellm.add_function_to_prompt = True
self.api_base = api_base
self.api_key = api_key
self.custom_role_conversions = custom_role_conversions

def __call__(
self,
Expand All @@ -687,6 +690,7 @@ def __call__(
api_base=self.api_base,
api_key=self.api_key,
convert_images_to_image_urls=True,
custom_role_conversions=self.custom_role_conversions,
**kwargs,
)

Expand Down
39 changes: 39 additions & 0 deletions src/smolagents/monitoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@
from rich.panel import Panel
from rich.rule import Rule
from rich.syntax import Syntax
from rich.table import Table
from rich.text import Text
from rich.tree import Tree


class Monitor:
Expand Down Expand Up @@ -162,5 +164,42 @@ def log_messages(self, messages: List) -> None:
)
)

def visualize_agent_tree(self, agent):
def create_tools_section(tools_dict):
table = Table(show_header=True, header_style="bold")
table.add_column("Name", style="blue")
table.add_column("Description")
table.add_column("Arguments")

for name, tool in tools_dict.items():
args = [
f"{arg_name} (`{info.get('type', 'Any')}`{', optional' if info.get('optional') else ''}): {info.get('description', '')}"
for arg_name, info in getattr(tool, "inputs", {}).items()
]
table.add_row(name, getattr(tool, "description", str(tool)), "\n".join(args))

return Group(Text("🛠️ Tools", style="bold italic blue"), table)

def build_agent_tree(parent_tree, agent_obj):
"""Recursively builds the agent tree."""
if agent_obj.tools:
parent_tree.add(create_tools_section(agent_obj.tools))

if agent_obj.managed_agents:
agents_branch = parent_tree.add("[bold italic blue]🤖 Managed agents")
for name, managed_agent in agent_obj.managed_agents.items():
agent_node_text = f"[bold {YELLOW_HEX}]{name} - {managed_agent.agent.__class__.__name__}"
agent_tree = agents_branch.add(agent_node_text)
if hasattr(managed_agent, "description"):
agent_tree.add(
f"[bold italic blue]📝 Description:[/bold italic blue] {managed_agent.description}"
)
if hasattr(managed_agent, "agent"):
build_agent_tree(agent_tree, managed_agent.agent)

main_tree = Tree(f"[bold {YELLOW_HEX}]{agent.__class__.__name__}")
build_agent_tree(main_tree, agent)
self.console.print(main_tree)


__all__ = ["AgentLogger", "Monitor"]