From ae8b7dc35d76202491ee1e2dce29d7e95afaa0d3 Mon Sep 17 00:00:00 2001
From: Vignesh Aigal
Date: Tue, 3 Dec 2024 16:44:05 -0800
Subject: [PATCH] Fix session support for agents

---
 llmstack/apps/apis.py                    |  1 +
 llmstack/apps/runner/agent_actor.py      | 31 ++++++++++---
 llmstack/apps/runner/agent_controller.py | 55 ++++++++++++++++++------
 llmstack/apps/runner/app_runner.py       |  5 ++-
 4 files changed, 71 insertions(+), 21 deletions(-)

diff --git a/llmstack/apps/apis.py b/llmstack/apps/apis.py
index f0367f444fb..cce1836426c 100644
--- a/llmstack/apps/apis.py
+++ b/llmstack/apps/apis.py
@@ -776,6 +776,7 @@ def _run_internal(self, request, app_uuid, input_data, source, app_data, session
         response = app_runner.run_until_complete(
             AppRunnerRequest(client_request_id=str(uuid.uuid4()), session_id=session_id, input=input_data), loop
         )
+        async_to_sync(app_runner.stop)()
         return response
 
     def run(self, request, uid, session_id=None):
diff --git a/llmstack/apps/runner/agent_actor.py b/llmstack/apps/runner/agent_actor.py
index 9dbc5589811..7dc5232578e 100644
--- a/llmstack/apps/runner/agent_actor.py
+++ b/llmstack/apps/runner/agent_actor.py
@@ -5,6 +5,7 @@
 from typing import Any, Dict, List
 
 from llmstack.apps.runner.agent_controller import (
+    AgentAssistantMessage,
     AgentController,
     AgentControllerConfig,
     AgentControllerData,
@@ -126,15 +127,29 @@ async def _process_output(self):
                 )
             elif controller_output.type == AgentControllerDataType.AGENT_OUTPUT_END:
+                agent_final_output = (
+                    self._stitched_data["agent"][str(message_index)].data.content[0].data
+                    if str(message_index) in self._stitched_data["agent"]
+                    else ""
+                )
+                self._agent_controller.process(
+                    AgentControllerData(
+                        type=AgentControllerDataType.AGENT_OUTPUT_END,
+                        data=AgentAssistantMessage(
+                            content=[
+                                AgentMessageContent(
+                                    type=AgentMessageContentType.TEXT,
+                                    data=agent_final_output,
+                                )
+                            ]
+                        ),
+                    )
+                )
                 self._content_queue.put_nowait(
                     {
                         "output": {
                             **self._agent_outputs,
-                            "output": (
-                                self._stitched_data["agent"][str(message_index)].data.content[0].data
-                                if str(message_index) in self._stitched_data["agent"]
-                                else ""
-                            ),
+                            "output": agent_final_output,
                         },
                         "chunks": self._stitched_data,
                     }
                 )
@@ -174,6 +189,12 @@ async def _process_output(self):
 
             elif controller_output.type == AgentControllerDataType.TOOL_CALLS_END:
                 tool_calls = self._stitched_data["agent"][str(message_index)].data.tool_calls
+                self._agent_controller.process(
+                    AgentControllerData(
+                        type=AgentControllerDataType.TOOL_CALLS_END,
+                        data=AgentToolCallsMessage(tool_calls=tool_calls),
+                    ),
+                )
 
                 for tool_call in tool_calls:
                     tool_call_args = tool_call.arguments
diff --git a/llmstack/apps/runner/agent_controller.py b/llmstack/apps/runner/agent_controller.py
index 4e6fab66f21..724c10b956b 100644
--- a/llmstack/apps/runner/agent_controller.py
+++ b/llmstack/apps/runner/agent_controller.py
@@ -130,11 +130,43 @@ class AgentControllerData(BaseModel):
     ] = None
 
 
+def save_messages_to_session_data(session_id, id, messages: List[AgentMessage]):
+    from llmstack.apps.app_session_utils import save_app_session_data
+
+    logger.info(f"Saving messages to session data: {messages}")
+
+    save_app_session_data(session_id, id, [m.model_dump_json() for m in messages])
+
+
+def load_messages_from_session_data(session_id, id):
+    from llmstack.apps.app_session_utils import get_app_session_data
+
+    messages = []
+
+    session_data = get_app_session_data(session_id, id)
+    if session_data and isinstance(session_data, list):
+        for data in session_data:
+            data_json = json.loads(data)
+            if data_json["role"] == "system":
+                messages.append(AgentSystemMessage(**data_json))
+            elif data_json["role"] == "assistant":
+                messages.append(AgentAssistantMessage(**data_json))
+            elif data_json["role"] == "user":
+                messages.append(AgentUserMessage(**data_json))
+
+    return messages
+
+
 class AgentController:
     def __init__(self, output_queue: asyncio.Queue, config: AgentControllerConfig):
+        self._session_id = config.metadata.get("session_id")
+        self._controller_id = f"{config.metadata.get('app_uuid')}_agent"
+        self._system_message = render_template(config.agent_config.system_message, {})
         self._output_queue = output_queue
         self._config = config
-        self._messages: List[AgentMessage] = []
+        self._messages: List[AgentMessage] = (
+            load_messages_from_session_data(self._session_id, self._controller_id) or []
+        )
         self._llm_client = None
         self._websocket = None
         self._provider_config = None
@@ -254,18 +286,6 @@ def _init_llm_client(self):
             ),
         )
 
-        self._messages.append(
-            AgentSystemMessage(
-                role=AgentMessageRole.SYSTEM,
-                content=[
-                    AgentMessageContent(
-                        type=AgentMessageContentType.TEXT,
-                        data=render_template(self._config.agent_config.system_message, {}),
-                    )
-                ],
-            )
-        )
-
     async def _process_input_audio_stream(self):
         if self._input_audio_stream:
             async for chunk in self._input_audio_stream.read_async():
@@ -387,6 +407,10 @@ def process(self, data: AgentControllerData):
         # Actor calls this to add a message to the conversation and trigger processing
         self._messages.append(data.data)
 
+        # AGENT_OUTPUT_END and TOOL_CALLS_END only record assistant output or tool calls in the message history to maintain state; no further processing is needed
+        if data.type == AgentControllerDataType.AGENT_OUTPUT_END or data.type == AgentControllerDataType.TOOL_CALLS_END:
+            return
+
         try:
             if len(self._messages) > self._config.agent_config.max_steps:
                 raise Exception(f"Max steps ({self._config.agent_config.max_steps}) exceeded: {len(self._messages)}")
@@ -465,7 +489,7 @@ async def process_messages(self, data: AgentControllerData):
         stream = True if self._config.agent_config.stream is None else self._config.agent_config.stream
         response = self._llm_client.chat.completions.create(
             model=self._config.agent_config.model,
-            messages=client_messages,
+            messages=[{"role": "system", "content": self._system_message}] + client_messages,
             stream=stream,
             tools=self._config.tools,
         )
@@ -703,6 +727,9 @@ async def add_ws_event_to_output_queue(self, event: Any):
             logger.error(f"WebSocket error: {event}")
 
     def terminate(self):
+        # Persist the conversation messages to session data before terminating
+        save_messages_to_session_data(self._session_id, self._controller_id, self._messages)
+
         # Create task for graceful websocket closure
         if hasattr(self, "_websocket") and self._websocket:
             asyncio.run_coroutine_threadsafe(self._websocket.close(), self._loop)
diff --git a/llmstack/apps/runner/app_runner.py b/llmstack/apps/runner/app_runner.py
index 0ea04692422..a35a6dc51c4 100644
--- a/llmstack/apps/runner/app_runner.py
+++ b/llmstack/apps/runner/app_runner.py
@@ -539,9 +539,10 @@ async def run(self, request: AppRunnerRequest):
         )
 
     def run_until_complete(self, request: AppRunnerRequest, event_loop):
+        final_response = None
        for response in iter_over_async(self.run(request), event_loop):
             if isinstance(response.data, AppRunnerResponseErrorsData) or isinstance(
                 response.data, AppRunnerResponseOutputData
             ):
-                break
-        return response
+                final_response = response
+        return final_response
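Note on the session round-trip this patch introduces: messages are serialized with pydantic's model_dump_json() when the controller terminates and rebuilt by role when the next controller for the same session starts; AGENT_OUTPUT_END and TOOL_CALLS_END are fed back through process() so the assistant's final output and tool calls land in self._messages before they are persisted. Below is a minimal, self-contained sketch of that save/load pattern. The simplified message models and the dict-backed store are illustrative stand-ins, not llmstack's actual AgentMessage classes or save_app_session_data/get_app_session_data helpers.

# Minimal sketch of the save/load round-trip (stand-ins only: the real code
# uses llmstack's AgentMessage models and app_session_utils helpers).
import json
from typing import Dict, List, Literal, Tuple, Union

from pydantic import BaseModel


class SystemMessage(BaseModel):
    role: Literal["system"] = "system"
    content: str


class UserMessage(BaseModel):
    role: Literal["user"] = "user"
    content: str


class AssistantMessage(BaseModel):
    role: Literal["assistant"] = "assistant"
    content: str


Message = Union[SystemMessage, UserMessage, AssistantMessage]

# Stand-in for the session store behind save/get_app_session_data
_SESSIONS: Dict[Tuple[str, str], List[str]] = {}


def save_messages(session_id: str, controller_id: str, messages: List[Message]) -> None:
    # Serialize each message to JSON, as terminate() does via model_dump_json()
    _SESSIONS[(session_id, controller_id)] = [m.model_dump_json() for m in messages]


def load_messages(session_id: str, controller_id: str) -> List[Message]:
    # Rebuild typed messages by role, as __init__ now does; roles without a
    # model here (e.g. tool call records) are skipped, mirroring the patch
    by_role = {"system": SystemMessage, "user": UserMessage, "assistant": AssistantMessage}
    restored: List[Message] = []
    for raw in _SESSIONS.get((session_id, controller_id), []):
        data = json.loads(raw)
        cls = by_role.get(data["role"])
        if cls is not None:
            restored.append(cls(**data))
    return restored


# Round-trip: what one run saves, the next run of the same session loads.
save_messages("session-1", "app-1_agent", [UserMessage(content="hello")])
assert load_messages("session-1", "app-1_agent")[0].content == "hello"

A side effect of restoring only system/user/assistant roles is that persisted tool call records are dropped on load, which presumably avoids replaying stale tool invocations when a session resumes.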