Skip to content

Commit

Permalink
Adds the raw output to the chat message for later inspection
Browse files Browse the repository at this point in the history
  • Loading branch information
clefourrier committed Jan 23, 2025
1 parent 0e8251b commit d06de7a
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 77 deletions.
29 changes: 13 additions & 16 deletions src/smolagents/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,19 @@
from rich.syntax import Syntax
from rich.text import Text

from smolagents.logger import ActionStep, PlanningStep, SystemPromptStep, TaskStep, ToolCall
from smolagents.types import AgentAudio, AgentImage, handle_agent_output_types
from smolagents.utils import (
AgentError,
AgentExecutionError,
AgentGenerationError,
AgentMaxStepsError,
AgentParsingError,
parse_code_blobs,
parse_json_tool_call,
truncate_content,
)

from .default_tools import TOOL_MAPPING, FinalAnswerTool
from .e2b_executor import E2BExecutor
from .local_python_executor import (
Expand Down Expand Up @@ -59,22 +72,6 @@
Tool,
get_tool_description_with_args,
)
from .types import AgentAudio, AgentImage, handle_agent_output_types
from .utils import (
ActionStep,
AgentError,
AgentExecutionError,
AgentGenerationError,
AgentMaxStepsError,
AgentParsingError,
PlanningStep,
SystemPromptStep,
TaskStep,
ToolCall,
parse_code_blobs,
parse_json_tool_call,
truncate_content,
)


def get_tool_descriptions(tools: Dict[str, Tool], tool_description_template: str) -> str:
Expand Down
7 changes: 4 additions & 3 deletions src/smolagents/gradio_ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,10 @@
import shutil
from typing import Optional

from .agents import ActionStep, MultiStepAgent
from .types import AgentAudio, AgentImage, AgentText, handle_agent_output_types
from .utils import AgentStepLog, _is_package_available
from smolagents.agents import ActionStep, MultiStepAgent
from smolagents.logger import AgentStepLog
from smolagents.types import AgentAudio, AgentImage, AgentText, handle_agent_output_types
from smolagents.utils import _is_package_available


def pull_messages_from_step(step_log: AgentStepLog):
Expand Down
94 changes: 45 additions & 49 deletions src/smolagents/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,19 @@ class AgentStepLog:
def dict(self):
    """Serialize this log step to a plain dict; subclasses must implement."""
    raise NotImplementedError

def to_memory(self, **kwargs) -> List[Dict[str, Any]]:
def to_messages(self, **kwargs) -> List[Dict[str, Any]]:
    """Convert this step into chat-format message dicts; subclasses must implement."""
    raise NotImplementedError


@dataclass
class Message:
    """A single chat message: a speaker role plus its text content."""

    # Speaker role (a MessageRole value, e.g. SYSTEM / ASSISTANT / USER).
    role: MessageRole
    # Plain-text body of the message.
    content: str

    def dict(self):
        """Return the message as a plain {"role": ..., "content": ...} dict."""
        return {"role": self.role, "content": self.content}


@dataclass
class ToolCall:
name: str
Expand Down Expand Up @@ -57,38 +66,29 @@ class ActionStep(AgentStepLog):
def dict(self):
    """Serialize this action step to a plain dict for logging/inspection.

    Optional members (tool_calls, error) are normalized: missing tool calls
    become an empty list and a missing error becomes None.
    """
    serialized_tool_calls = [tc.dict() for tc in self.tool_calls] if self.tool_calls else []
    serialized_error = self.error.dict() if self.error else None
    return {
        "agent_memory": self.agent_memory,
        "tool_calls": serialized_tool_calls,
        "start_time": self.start_time,
        "end_time": self.end_time,
        "step": self.step,
        "error": serialized_error,
        "duration": self.duration,
        "llm_output": self.llm_output,
        "observations": self.observations,
        "action_output": make_json_serializable(self.action_output),
    }

def to_memory(self, summary_mode: bool, return_memory: bool) -> List[Dict[str, Any]]:
def to_messages(self, summary_mode: bool, return_memory: bool) -> List[Dict[str, Any]]:
memory = []
if self.agent_memory is not None and return_memory:
thought_message = {
"role": MessageRole.SYSTEM,
"content": self.agent_memory,
}
memory.append(thought_message)
message = Message(MessageRole.SYSTEM, self.agent_memory)
memory.append(message.dict())
if self.llm_output is not None and not summary_mode:
thought_message = {
"role": MessageRole.ASSISTANT,
"content": self.llm_output.strip(),
}
memory.append(thought_message)
message = Message(MessageRole.ASSISTANT, self.llm_output.strip())
memory.append(message.dict())

if self.tool_calls is not None:
tool_call_message = {
"role": MessageRole.ASSISTANT,
"content": str([tc.dict() for tc in self.tool_calls]),
}
memory.append(tool_call_message)
message = Message(MessageRole.ASSISTANT, str([tc.dict() for tc in self.tool_calls]))
memory.append(message.dict())

if self.error is not None:
message_content = (
Expand All @@ -97,23 +97,20 @@ def to_memory(self, summary_mode: bool, return_memory: bool) -> List[Dict[str, A
+ "\nNow let's retry: take care not to repeat previous errors! If you have retried several times, try a completely different approach.\n"
)
if self.tool_calls is None:
tool_response_message = {
"role": MessageRole.ASSISTANT,
"content": message_content,
}
tool_response_message = Message(MessageRole.ASSISTANT, message_content)
else:
tool_response_message = {
"role": MessageRole.TOOL_RESPONSE,
"content": f"Call id: {self.tool_calls[0].id}\n{message_content}",
}
memory.append(tool_response_message)
tool_response_message = Message(
MessageRole.TOOL_RESPONSE, f"Call id: {self.tool_calls[0].id}\n{message_content}"
)

memory.append(tool_response_message.dict())
else:
if self.observations is not None and self.tool_calls is not None:
tool_response_message = {
"role": MessageRole.TOOL_RESPONSE,
"content": f"Call id: {self.tool_calls[0].id}\nObservation:\n{self.observations}",
}
memory.append(tool_response_message)
tool_response_message = Message(
MessageRole.TOOL_RESPONSE,
f"Call id: {self.tool_calls[0].id}\nObservation:\n{self.observations}",
)
memory.append(tool_response_message.dict())
return memory


Expand All @@ -125,20 +122,14 @@ class PlanningStep(AgentStepLog):
def dict(self, **kwargs):
    """Serialize the planning step; extra keyword arguments are accepted and ignored."""
    payload = {"plan": self.plan, "facts": self.facts}
    return payload

def to_messages(self, summary_mode: bool, **kwargs) -> List[Dict[str, str]]:
    """Render this planning step as chat messages.

    The facts list is always emitted; the full plan is skipped when
    summary_mode is True. Extra kwargs are accepted and ignored.
    """
    messages = [Message(MessageRole.ASSISTANT, f"[FACTS LIST]:\n{self.facts.strip()}").dict()]
    if not summary_mode:
        messages.append(Message(MessageRole.ASSISTANT, f"[PLAN]:\n{self.plan.strip()}").dict())
    return messages


Expand All @@ -149,8 +140,9 @@ class TaskStep(AgentStepLog):
def dict(self):
    """Serialize the task step to a plain dict."""
    payload = {"task": self.task}
    return payload

def to_messages(self, summary_mode: bool, **kwargs) -> List[Dict[str, str]]:
    """Render the task as a single user message (summary_mode is ignored here)."""
    return [Message(MessageRole.USER, f"New task:\n{self.task}").dict()]


@dataclass
Expand All @@ -160,9 +152,10 @@ class SystemPromptStep(AgentStepLog):
def dict(self):
    """Serialize the system prompt step to a plain dict."""
    payload = {"system_prompt": self.system_prompt}
    return payload

def to_messages(self, summary_mode: bool, **kwargs) -> List[Dict[str, str]]:
    """Render the system prompt as one system message; omitted entirely in summary mode."""
    if summary_mode:
        return []
    return [Message(MessageRole.SYSTEM, self.system_prompt).dict()]


Expand Down Expand Up @@ -226,9 +219,12 @@ def write_inner_memory_from_logs(
"""
memory = []
for step_log in self.steps:
memory.extend(step_log.to_memory(summary_mode=summary_mode, return_memory=return_memory))
memory.extend(step_log.to_messages(summary_mode=summary_mode, return_memory=return_memory))
return memory

def dict(self):
    """Serialize every recorded step by delegating to each step's own dict()."""
    serialized_steps = []
    for step in self.steps:
        serialized_steps.append(step.dict())
    return serialized_steps

def replay(self, with_memory: bool = False):
"""Prints a replay of the agent's steps.
Expand Down
21 changes: 12 additions & 9 deletions src/smolagents/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,10 @@
}


def get_dict_from_nested_dataclasses(obj, ignore_key=None):
    """Convert a dataclass instance (including nested dataclasses) to plain dicts.

    Top-level fields whose name equals ``ignore_key`` are dropped from the
    result (used e.g. to exclude the bulky ``raw`` API payload). Non-dataclass
    inputs are returned unchanged.
    """

    def _convert(value):
        # Guard clause: anything that is not a dataclass passes through as-is.
        if not hasattr(value, "__dataclass_fields__"):
            return value
        # asdict() already deep-converts nested dataclasses to dicts.
        return {key: _convert(val) for key, val in asdict(value).items() if key != ignore_key}

    return _convert(obj)
Expand Down Expand Up @@ -77,7 +77,7 @@ class ChatMessageToolCall:
type: str

@classmethod
def from_hf_api(cls, tool_call) -> "ChatMessageToolCall":
def from_hf_api(cls, tool_call, raw) -> "ChatMessageToolCall":
return cls(
function=ChatMessageToolCallDefinition.from_hf_api(tool_call.function),
id=tool_call.id,
Expand All @@ -90,16 +90,17 @@ class ChatMessage:
role: str
content: Optional[str] = None
tool_calls: Optional[List[ChatMessageToolCall]] = None
raw: Optional[Any] = None # Stores the raw output from the API

def model_dump_json(self):
    """JSON-serialize this message, excluding the `raw` provider payload.

    The raw API response is kept on the object for inspection but is not
    guaranteed to be JSON-serializable, so it is stripped before dumping.
    """
    serializable = get_dict_from_nested_dataclasses(self, ignore_key="raw")
    return json.dumps(serializable)

@classmethod
def from_hf_api(cls, message, raw) -> "ChatMessage":
    """Build a ChatMessage from a Hugging Face API response message.

    Args:
        message: the `choices[i].message` object returned by the HF client.
        raw: the full, unprocessed API response, stored for later inspection.

    Returns:
        A ChatMessage mirroring `message`, with `raw` attached.
    """
    tool_calls = None
    if getattr(message, "tool_calls", None) is not None:
        # Fix: ChatMessageToolCall.from_hf_api is declared as (tool_call, raw);
        # the original call omitted `raw` and would raise TypeError whenever
        # the message contained tool calls. NOTE(review): if `raw` was meant to
        # be optional there instead, give it a default at the definition site.
        tool_calls = [ChatMessageToolCall.from_hf_api(tool_call, raw) for tool_call in message.tool_calls]
    return cls(role=message.role, content=message.content, tool_calls=tool_calls, raw=raw)


def parse_json_if_needed(arguments: Union[str, dict]) -> Union[str, dict]:
Expand Down Expand Up @@ -307,7 +308,7 @@ def __call__(
)
self.last_input_token_count = response.usage.prompt_tokens
self.last_output_token_count = response.usage.completion_tokens
message = ChatMessage.from_hf_api(response.choices[0].message)
message = ChatMessage.from_hf_api(response.choices[0].message, raw=response)
if tools_to_call_from is not None:
return parse_tool_args_if_needed(message)
return message
Expand Down Expand Up @@ -539,7 +540,8 @@ def __call__(
)
self.last_input_token_count = response.usage.prompt_tokens
self.last_output_token_count = response.usage.completion_tokens
message = response.choices[0].message
message: ChatMessage = response.choices[0].message
message.raw = response
if tools_to_call_from is not None:
return parse_tool_args_if_needed(message)
return message
Expand Down Expand Up @@ -614,7 +616,8 @@ def __call__(
)
self.last_input_token_count = response.usage.prompt_tokens
self.last_output_token_count = response.usage.completion_tokens
message = response.choices[0].message
message: ChatMessage = response.choices[0].message
message.raw = response
if tools_to_call_from is not None:
return parse_tool_args_if_needed(message)
return message
Expand Down

0 comments on commit d06de7a

Please sign in to comment.