Migrate AgentRunner to Agent Workflow (Python) #502
Status: Closed
Commits (31):
- b4f0767 stg (leehuwuj)
- bc2d503 raise error if there is no tools (leehuwuj)
- cbebd03 stg (leehuwuj)
- 5ec1947 support request api (leehuwuj)
- 6d5749d remove --no-files e2e test for python (leehuwuj)
- 22e4be9 use agent workflow for financial report use case (leehuwuj)
- 6ba5023 migrate form_filling to AgentWorkflow (leehuwuj)
- 0e4ee4a refactor: chat message content (thucpn)
- 86610e6 rename function in chat-ui (thucpn)
- 8d3db71 Create cool-cars-promise.md (marcusschiesser)
- 5a230be bump chat-ui (leehuwuj)
- 7e23d77 add new query index and weather card for agent workflows (leehuwuj)
- 0139a11 support source nodes (leehuwuj)
- dae3249 remove unused function (leehuwuj)
- 798f378 fix empty chunk (leehuwuj)
- d09ae65 keep the old code for financial report and form-filling (leehuwuj)
- c7e4696 fix annotation message (leehuwuj)
- c83fa96 fix mypy (leehuwuj)
- 25144dc add artifact tool component (leehuwuj)
- fe5982e fix render empty div (leehuwuj)
- 1e90a6a improve typing (leehuwuj)
- 087a45e Merge remote-tracking branch 'origin' into lee/agent-workflows (leehuwuj)
- d38eb3c unify chat.py file (leehuwuj)
- 9fd6d0c remove multiagent folder (python) (leehuwuj)
- d0f606d fix linting (leehuwuj)
- 21b7df1 fix missing import (leehuwuj)
- c996508 support non-streaming api (leehuwuj)
- be5870c update citation prompt (leehuwuj)
- 8004c9f remove dead code (leehuwuj)
- b60618a remove dead code (leehuwuj)
- 7514736 add comment (leehuwuj)
`cool-cars-promise.md` (new changeset file, 5 additions):

```md
---
"create-llama": patch
---

Migrate AgentRunner to Agent Workflow (Python)
```
`templates/components/agents/python/financial_report/app/workflows/events.py` (new file, 45 additions):
```python
from enum import Enum
from typing import List, Optional

from llama_index.core.schema import NodeWithScore
from llama_index.core.workflow import Event

from app.api.routers.models import SourceNodes


class AgentRunEventType(Enum):
    TEXT = "text"
    PROGRESS = "progress"


class AgentRunEvent(Event):
    name: str
    msg: str
    event_type: AgentRunEventType = AgentRunEventType.TEXT
    data: Optional[dict] = None

    def to_response(self) -> dict:
        return {
            "type": "agent",
            "data": {
                "agent": self.name,
                "type": self.event_type.value,
                "text": self.msg,
                "data": self.data,
            },
        }


class SourceNodesEvent(Event):
    nodes: List[NodeWithScore]

    def to_response(self):
        return {
            "type": "sources",
            "data": {
                "nodes": [
                    SourceNodes.from_source_node(node).model_dump()
                    for node in self.nodes
                ]
            },
        }
```
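For orientation, here is a minimal sketch (not part of this PR) of how a workflow step might emit these events to the chat stream. `ReportWorkflow`, its single step, and the empty `nodes` list are hypothetical stand-ins; a real step would obtain nodes from a retriever or query engine.

```python
# Hypothetical usage sketch; only AgentRunEvent/SourceNodesEvent come from this PR.
from typing import List

from llama_index.core.schema import NodeWithScore
from llama_index.core.workflow import Context, StartEvent, StopEvent, Workflow, step

from app.workflows.events import AgentRunEvent, SourceNodesEvent


class ReportWorkflow(Workflow):  # hypothetical example workflow
    @step
    async def research(self, ctx: Context, ev: StartEvent) -> StopEvent:
        # Progress messages surface in the UI as "agent" events
        ctx.write_event_to_stream(
            AgentRunEvent(name="researcher", msg="Looking up filings")
        )
        # A real step would fill this from a retriever; left empty here
        nodes: List[NodeWithScore] = []
        ctx.write_event_to_stream(SourceNodesEvent(nodes=nodes))
        return StopEvent(result="done")
```

The frontend can then branch on the `type` field of each `to_response()` payload (`"agent"` vs. `"sources"`).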
`templates/components/agents/python/financial_report/app/workflows/tools.py` (new file, 230 additions):
```python
import logging
import uuid
from abc import ABC, abstractmethod
from typing import Any, AsyncGenerator, Callable, Optional

from llama_index.core.base.llms.types import ChatMessage, ChatResponse, MessageRole
from llama_index.core.llms.function_calling import FunctionCallingLLM
from llama_index.core.tools import (
    BaseTool,
    FunctionTool,
    ToolOutput,
    ToolSelection,
)
from llama_index.core.workflow import Context
from pydantic import BaseModel, ConfigDict

from app.workflows.events import AgentRunEvent, AgentRunEventType

logger = logging.getLogger("uvicorn")


class ContextAwareTool(FunctionTool, ABC):
    @abstractmethod
    async def acall(self, ctx: Context, input: Any) -> ToolOutput:  # type: ignore
        pass


class ChatWithToolsResponse(BaseModel):
    """
    A tool call response from chat_with_tools.
    """

    tool_calls: Optional[list[ToolSelection]]
    tool_call_message: Optional[ChatMessage]
    generator: Optional[AsyncGenerator[ChatResponse | None, None]]

    model_config = ConfigDict(arbitrary_types_allowed=True)

    def is_calling_different_tools(self) -> bool:
        tool_names = {tool_call.tool_name for tool_call in self.tool_calls}
        return len(tool_names) > 1

    def has_tool_calls(self) -> bool:
        return self.tool_calls is not None and len(self.tool_calls) > 0

    def tool_name(self) -> str:
        assert self.has_tool_calls()
        assert not self.is_calling_different_tools()
        return self.tool_calls[0].tool_name

    async def full_response(self) -> str:
        assert self.generator is not None
        full_response = ""
        async for chunk in self.generator:
            content = chunk.message.content
            if content:
                full_response += content
        return full_response


async def chat_with_tools(  # type: ignore
    llm: FunctionCallingLLM,
    tools: list[BaseTool],
    chat_history: list[ChatMessage],
) -> ChatWithToolsResponse:
    """
    Ask the LLM whether to call tools.
    This function doesn't change the memory.
    """
    generator = _tool_call_generator(llm, tools, chat_history)
    is_tool_call = await generator.__anext__()
    if is_tool_call:
        # The last chunk yielded by the generator is the full response;
        # drain the stream to get it
        full_response = None
        async for chunk in generator:
            full_response = chunk
        assert isinstance(full_response, ChatResponse)
        return ChatWithToolsResponse(
            tool_calls=llm.get_tool_calls_from_response(full_response),
            tool_call_message=full_response.message,
            generator=None,
        )
    else:
        return ChatWithToolsResponse(
            tool_calls=None,
            tool_call_message=None,
            generator=generator,
        )


async def call_tools(
    ctx: Context,
    agent_name: str,
    tools: list[BaseTool],
    tool_calls: list[ToolSelection],
    emit_agent_events: bool = True,
) -> list[ChatMessage]:
    if len(tool_calls) == 0:
        return []

    tools_by_name = {tool.metadata.get_name(): tool for tool in tools}
    if len(tool_calls) == 1:
        return [
            await call_tool(
                ctx,
                tools_by_name[tool_calls[0].tool_name],
                tool_calls[0],
                lambda msg: ctx.write_event_to_stream(
                    AgentRunEvent(
                        name=agent_name,
                        msg=msg,
                    )
                ),
            )
        ]
    # Multiple tool calls, show progress
    tool_msgs: list[ChatMessage] = []

    progress_id = str(uuid.uuid4())
    total_steps = len(tool_calls)
    if emit_agent_events:
        ctx.write_event_to_stream(
            AgentRunEvent(
                name=agent_name,
                msg=f"Making {total_steps} tool calls",
            )
        )
    for i, tool_call in enumerate(tool_calls):
        tool = tools_by_name.get(tool_call.tool_name)
        if not tool:
            tool_msgs.append(
                ChatMessage(
                    role=MessageRole.ASSISTANT,
                    content=f"Tool {tool_call.tool_name} does not exist",
                )
            )
            continue
        tool_msg = await call_tool(
            ctx,
            tool,
            tool_call,
            event_emitter=lambda msg: ctx.write_event_to_stream(
                AgentRunEvent(
                    name=agent_name,
                    msg=msg,
                    event_type=AgentRunEventType.PROGRESS,
                    data={
                        "id": progress_id,
                        "total": total_steps,
                        "current": i,
                    },
                )
            ),
        )
        tool_msgs.append(tool_msg)
    return tool_msgs


async def call_tool(
    ctx: Context,
    tool: BaseTool,
    tool_call: ToolSelection,
    event_emitter: Optional[Callable[[str], None]],
) -> ChatMessage:
    if event_emitter:
        event_emitter(
            f"Calling tool {tool_call.tool_name}, {str(tool_call.tool_kwargs)}"
        )
    try:
        if isinstance(tool, ContextAwareTool):
            if ctx is None:
                raise ValueError("Context is required for a context-aware tool")
            # inject the context when calling a context-aware tool
            response = await tool.acall(ctx=ctx, **tool_call.tool_kwargs)
        else:
            response = await tool.acall(**tool_call.tool_kwargs)  # type: ignore
        return ChatMessage(
            role=MessageRole.TOOL,
            content=str(response.raw_output),
            additional_kwargs={
                "tool_call_id": tool_call.tool_id,
                "name": tool.metadata.get_name(),
            },
        )
    except Exception as e:
        logger.error(f"Got error in tool {tool_call.tool_name}: {str(e)}")
        if event_emitter:
            event_emitter(f"Got error in tool {tool_call.tool_name}: {str(e)}")
        return ChatMessage(
            role=MessageRole.TOOL,
            content=f"Error: {str(e)}",
            additional_kwargs={
                "tool_call_id": tool_call.tool_id,
                "name": tool.metadata.get_name(),
            },
        )


async def _tool_call_generator(
    llm: FunctionCallingLLM,
    tools: list[BaseTool],
    chat_history: list[ChatMessage],
) -> AsyncGenerator[ChatResponse | bool, None]:
    response_stream = await llm.astream_chat_with_tools(
        tools,
        chat_history=chat_history,
        allow_parallel_tool_calls=False,
    )

    full_response = None
    yielded_indicator = False
    async for chunk in response_stream:
        if "tool_calls" not in chunk.message.additional_kwargs:
            # Yield a boolean to indicate whether the response is a tool call
            if not yielded_indicator:
                yield False
                yielded_indicator = True

            # if not a tool call, yield the chunks!
            yield chunk  # type: ignore
        elif not yielded_indicator:
            # Yield the indicator for a tool call
            yield True
            yielded_indicator = True

        full_response = chunk

    # Finally, yield the accumulated full response as the last item
    if full_response:
        yield full_response  # type: ignore
```
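And a sketch of how the two entry points compose, assuming a function-calling LLM is configured on `Settings.llm`. The weather tool, the `weather_agent` name, and `handle_turn` are hypothetical and not part of the PR.

```python
# Hypothetical driver; chat_with_tools/call_tools are the PR's helpers.
from llama_index.core import Settings
from llama_index.core.base.llms.types import ChatMessage
from llama_index.core.tools import FunctionTool
from llama_index.core.workflow import Context

from app.workflows.tools import call_tools, chat_with_tools


def get_weather(city: str) -> str:
    """Return a canned weather report (stand-in for a real tool)."""
    return f"It is sunny in {city}."


weather_tool = FunctionTool.from_defaults(fn=get_weather)


async def handle_turn(ctx: Context, chat_history: list[ChatMessage]):
    # First ask the LLM whether this turn needs a tool call
    # (assumes Settings.llm is a FunctionCallingLLM)
    response = await chat_with_tools(Settings.llm, [weather_tool], chat_history)
    if response.has_tool_calls():
        # Execute the calls; returns TOOL-role ChatMessages to append to memory
        return await call_tools(
            ctx, "weather_agent", [weather_tool], response.tool_calls
        )
    # Otherwise drain the streaming generator into the final answer text
    return await response.full_response()
```

Internally, `_tool_call_generator` first yields a boolean indicating whether the stream is a tool call; `chat_with_tools` consumes that indicator and either drains the stream for the final tool-call message or hands the still-open generator back to the caller for token streaming.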