diff --git a/packages/create-llama/templates/components/use-cases/typescript/deep_research/src/app/workflow.ts b/packages/create-llama/templates/components/use-cases/typescript/deep_research/src/app/workflow.ts
index bb200f2e1..f0d0299bc 100644
--- a/packages/create-llama/templates/components/use-cases/typescript/deep_research/src/app/workflow.ts
+++ b/packages/create-llama/templates/components/use-cases/typescript/deep_research/src/app/workflow.ts
@@ -1,15 +1,17 @@
 import { artifactEvent, toSourceEvent } from "@llamaindex/server";
 import {
+  AgentInputData,
   agentStreamEvent,
   createStatefulMiddleware,
   createWorkflow,
   startAgentEvent,
+  StatefulContext,
   stopAgentEvent,
   workflowEvent,
 } from "@llamaindex/workflow";
 import {
-  ChatMemoryBuffer,
   LlamaCloudIndex,
+  Memory,
   MessageContent,
   Metadata,
   MetadataMode,
@@ -22,10 +24,11 @@ import {
 import { randomUUID } from "node:crypto";
 import { z } from "zod";
 import { getIndex } from "./data";
+import { type Message } from "ai";
 
 // workflow factory
-export const workflowFactory = async (reqBody: any) => {
-  const index = await getIndex(reqBody?.data);
+export const workflowFactory = async (chatBody: { messages: Message[] }) => {
+  const index = await getIndex(chatBody);
   return getWorkflow(index);
 };
 
@@ -124,7 +127,7 @@ export const UIEventSchema = z
   event: z
     .enum(["retrieve", "analyze", "answer"])
     .describe(
-      "The type of event. DeepResearch has 3 main stages:\n1. retrieve: Retrieve the context from the vector store\n2. analyze: Analyze the context and generate a research questions to answer\n3. answer: Answer the provided questions. Each question has a unique id, when the state is done, the event will have the answer for the question.",
+      "The type of event. DeepResearch has 3 main stages:\n1. retrieve: Retrieve the context from the vector store\n2. analyze: Analyze the context and generate research questions to answer\n3. answer: Answer the provided questions. Each question has a unique id; when the state is done, the event will have the answer for the question."
     ),
   state: z
     .enum(["pending", "inprogress", "done", "error"])
@@ -145,15 +148,20 @@ const uiEvent = workflowEvent<{
   data: UIEventData;
 }>();
 
+type DeepResearchState = {
+  memory: Memory;
+  contextNodes: NodeWithScore[];
+  userRequest: MessageContent;
+  totalQuestions: number;
+  researchResults: ResearchResult[];
+};
+
 // workflow definition
 export function getWorkflow(index: VectorStoreIndex | LlamaCloudIndex) {
   const retriever = index.asRetriever({ similarityTopK: TOP_K });
-  const { withState, getContext } = createStatefulMiddleware(() => {
+  const { withState } = createStatefulMiddleware(() => {
     return {
-      memory: new ChatMemoryBuffer({
-        llm: Settings.llm,
-        chatHistory: [],
-      }),
+      memory: new Memory([], {}),
       contextNodes: [] as NodeWithScore[],
       userRequest: "" as MessageContent,
       totalQuestions: 0,
@@ -162,210 +170,237 @@ export function getWorkflow(index: VectorStoreIndex | LlamaCloudIndex) {
   });
   const workflow = withState(createWorkflow());
-  workflow.handle([startAgentEvent], async ({ data }) => {
-    const { userInput, chatHistory = [] } = data;
-    const { sendEvent, state } = getContext();
-    if (!userInput) throw new Error("Invalid input");
-
-    state.memory.set(chatHistory);
-    state.memory.put({ role: "user", content: userInput });
-    state.userRequest = userInput;
-    sendEvent(
-      uiEvent.with({
-        type: "ui_event",
-        data: {
-          event: "retrieve",
-          state: "inprogress",
-        },
-      }),
-    );
-
-    const retrievedNodes = await retriever.retrieve({ query: userInput });
-
-    sendEvent(toSourceEvent(retrievedNodes));
-    sendEvent(
-      uiEvent.with({
-        type: "ui_event",
-        data: { event: "retrieve", state: "done" },
-      }),
-    );
-
-    state.contextNodes.push(...retrievedNodes);
-
-    return planResearchEvent.with({});
-  });
+  workflow.handle(
+    [startAgentEvent],
+    async (
+      context: StatefulContext<DeepResearchState>,
+      event: { data: AgentInputData }
+    ) => {
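+      // Record the user request, retrieve context, then hand off to planning.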
+      const { userInput } = event.data;
+      const { sendEvent, state } = context;
+
+      if (!userInput) throw new Error("Invalid input");
+
+      state.memory.add({ role: "user", content: userInput });
+      state.userRequest = userInput;
+      sendEvent(
+        uiEvent.with({
+          type: "ui_event",
+          data: {
+            event: "retrieve",
+            state: "inprogress",
+          },
+        })
+      );
+
+      const retrievedNodes = await retriever.retrieve({ query: userInput });
+
+      sendEvent(toSourceEvent(retrievedNodes));
+      sendEvent(
+        uiEvent.with({
+          type: "ui_event",
+          data: { event: "retrieve", state: "done" },
+        })
+      );
+
+      state.contextNodes.push(...retrievedNodes);
+
+      return planResearchEvent.with({});
+    }
+  );
 
-  workflow.handle([planResearchEvent], async ({ data }) => {
-    const { sendEvent, state, stream } = getContext();
-
-    sendEvent(
-      uiEvent.with({
-        type: "ui_event",
-        data: { event: "analyze", state: "inprogress" },
-      }),
-    );
-
-    const { decision, researchQuestions, cancelReason } =
-      await createResearchPlan(
-        state.memory,
-        state.contextNodes
-          .map((node) => node.node.getContent(MetadataMode.NONE))
-          .join("\n"),
-        enhancedPrompt(state.totalQuestions),
-        state.userRequest,
-      );
-    sendEvent(
-      uiEvent.with({
-        type: "ui_event",
-        data: { event: "analyze", state: "done" },
-      }),
-    );
-    if (decision === "cancel") {
-      return agentStreamEvent.with({
-        delta: cancelReason ?? "Research cancelled without any reason.",
-        response: cancelReason ?? "Research cancelled without any reason.",
-        currentAgentName: "",
-        raw: null,
-      });
-    }
-    if (decision === "research" && researchQuestions.length > 0) {
-      state.totalQuestions += researchQuestions.length;
-      state.memory.put({
-        role: "assistant",
-        content:
-          "We need to find answers to the following questions:\n" +
-          researchQuestions.join("\n"),
-      });
-      researchQuestions.forEach(({ questionId: id, question }) => {
-        sendEvent(
-          uiEvent.with({
-            type: "ui_event",
-            data: { event: "answer", state: "pending", id, question },
-          }),
-        );
-        sendEvent(researchEvent.with({ questionId: id, question }));
-      });
-      const events = await stream
-        .until(() => state.researchResults.length === researchQuestions.length)
-        .toArray();
-      return planResearchEvent.with({});
-    }
-    state.memory.put({
-      role: "assistant",
-      content: "No more idea to analyze. We should report the answers.",
-    });
-    sendEvent(
-      uiEvent.with({
-        type: "ui_event",
-        data: { event: "analyze", state: "done" },
-      }),
-    );
-    return reportEvent.with({});
-  });
+  workflow.handle(
+    [planResearchEvent],
+    async (context: StatefulContext<DeepResearchState>) => {
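+      // Plan the next step: research further, write the report, or cancel.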
+      const { sendEvent, state, stream } = context;
+
+      sendEvent(
+        uiEvent.with({
+          type: "ui_event",
+          data: { event: "analyze", state: "inprogress" },
+        })
+      );
+
+      const { decision, researchQuestions, cancelReason } =
+        await createResearchPlan(
+          state.memory,
+          state.contextNodes
+            .map((node) => node.node.getContent(MetadataMode.NONE))
+            .join("\n"),
+          enhancedPrompt(state.totalQuestions),
+          state.userRequest
+        );
+      sendEvent(
+        uiEvent.with({
+          type: "ui_event",
+          data: { event: "analyze", state: "done" },
+        })
+      );
+      if (decision === "cancel") {
+        sendEvent(
+          uiEvent.with({
+            type: "ui_event",
+            data: { event: "analyze", state: "done" },
+          })
+        );
+        return agentStreamEvent.with({
+          delta: cancelReason ?? "Research cancelled without a stated reason.",
+          response: cancelReason ?? "Research cancelled without a stated reason.",
+          currentAgentName: "",
+          raw: null,
+        });
+      }
+      if (decision === "research" && researchQuestions.length > 0) {
+        state.totalQuestions += researchQuestions.length;
+        state.memory.add({
+          role: "assistant",
+          content:
+            "We need to find answers to the following questions:\n" +
+            researchQuestions.join("\n"),
+        });
+        researchQuestions.forEach(({ questionId: id, question }) => {
+          sendEvent(
+            uiEvent.with({
+              type: "ui_event",
+              data: { event: "answer", state: "pending", id, question },
+            })
+          );
+          sendEvent(researchEvent.with({ questionId: id, question }));
+        });
+        await stream
+          .until(
+            () => state.researchResults.length === researchQuestions.length
+          )
+          .toArray();
+        return planResearchEvent.with({});
+      }
+      state.memory.add({
+        role: "assistant",
+        content: "No more ideas to analyze. We should report the answers.",
+      });
+      sendEvent(
+        uiEvent.with({
+          type: "ui_event",
+          data: { event: "analyze", state: "done" },
+        })
+      );
+      return reportEvent.with({});
+    }
+  );
 
-  workflow.handle([researchEvent], async ({ data }) => {
-    const { sendEvent, state } = getContext();
-    const { questionId, question } = data;
-
-    sendEvent(
-      uiEvent.with({
-        type: "ui_event",
-        data: {
-          event: "answer",
-          state: "inprogress",
-          id: questionId,
-          question,
-        },
-      }),
-    );
-
-    const answer = await answerQuestion(
-      contextStr(state.contextNodes),
-      question,
-    );
-    state.researchResults.push({ questionId, question, answer });
-
-    state.memory.put({
-      role: "assistant",
-      content: `${question}\n${answer}`,
-    });
-
-    sendEvent(
-      uiEvent.with({
-        type: "ui_event",
-        data: {
-          event: "answer",
-          state: "done",
-          id: questionId,
-          question,
-          answer,
-        },
-      }),
-    );
-  });
+  workflow.handle(
+    [researchEvent],
+    async (context: StatefulContext<DeepResearchState>, event) => {
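+      // Answer one question and record the result so the planner can continue.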
+      const { sendEvent, state } = context;
+      const { questionId, question } = event.data;
+
+      sendEvent(
+        uiEvent.with({
+          type: "ui_event",
+          data: {
+            event: "answer",
+            state: "inprogress",
+            id: questionId,
+            question,
+          },
+        })
+      );
+
+      const answer = await answerQuestion(
+        contextStr(state.contextNodes),
+        question
+      );
+      state.researchResults.push({ questionId, question, answer });
+
+      state.memory.add({
+        role: "assistant",
+        content: `${question}\n${answer}`,
+      });
+
+      sendEvent(
+        uiEvent.with({
+          type: "ui_event",
+          data: {
+            event: "answer",
+            state: "done",
+            id: questionId,
+            question,
+            answer,
+          },
+        })
+      );
+    }
+  );
 
-  workflow.handle([reportEvent], async ({ data }) => {
-    const { sendEvent, state } = getContext();
-    const chatHistory = await state.memory.getAllMessages();
-    const messages = chatHistory.concat([
-      {
-        role: "system",
-        content: WRITE_REPORT_PROMPT,
-      },
-      {
-        role: "user",
-        content:
-          "Write a report addressing the user request based on the research provided the context",
-      },
-    ]);
-
-    const stream = await Settings.llm.chat({ messages, stream: true });
-    let response = "";
-    for await (const chunk of stream) {
-      response += chunk.delta;
-      sendEvent(
-        agentStreamEvent.with({
-          delta: chunk.delta,
-          response,
-          currentAgentName: "",
-          raw: stream,
-        }),
-      );
-    }
-
-    // Open the generated report in Canvas
-    sendEvent(
-      artifactEvent.with({
-        type: "artifact",
-        data: {
-          type: "document",
-          created_at: Date.now(),
-          data: {
-            title: "DeepResearch Report",
-            content: response,
-            type: "markdown",
-            sources: state.contextNodes.map((node) => ({
-              id: node.node.id_,
-            })),
-          },
-        },
-      }),
-    );
-
-    return stopAgentEvent.with({
-      result: response,
-    });
-  });
+  workflow.handle(
+    [reportEvent],
+    async (context: StatefulContext<DeepResearchState>) => {
+      const { sendEvent, state } = context;
+      const chatHistory = await state.memory.get();
+      const messages = chatHistory.concat([
+        {
+          role: "system",
+          content: WRITE_REPORT_PROMPT,
+        },
+        {
+          role: "user",
+          content:
+            "Write a report addressing the user request based on the research provided in the context",
+        },
+      ]);
+
+      const stream = await Settings.llm.chat({ messages, stream: true });
+      let response = "";
+      for await (const chunk of stream) {
+        response += chunk.delta;
+        sendEvent(
+          agentStreamEvent.with({
+            delta: chunk.delta,
+            response,
+            currentAgentName: "",
+            raw: stream,
+          })
+        );
+      }
+
+      // Open the generated report in Canvas
+      sendEvent(
+        artifactEvent.with({
+          type: "artifact",
+          data: {
+            type: "document",
+            created_at: Date.now(),
+            data: {
+              title: "DeepResearch Report",
+              content: response,
+              type: "markdown",
+              sources: state.contextNodes.map((node) => ({
+                id: node.node.id_,
+              })),
+            },
+          },
+        })
+      );
+
+      return stopAgentEvent.with({
+        result: response,
+        message: { role: "assistant", content: "The research is complete." },
+      });
+    }
+  );
 
   return workflow;
 }
 
 const createResearchPlan = async (
-  memory: ChatMemoryBuffer,
+  memory: Memory,
   contextStr: string,
   enhancedPrompt: string,
-  userRequest: MessageContent,
+  userRequest: MessageContent
 ) => {
-  const chatHistory = await memory.getMessages();
+  const chatHistory = await memory.get();
 
   const conversationContext = chatHistory
     .map((message) => `${message.role}: ${message.content}`)
@@ -389,7 +424,7 @@ const createResearchPlan = async (
   const responseFormat = z.object({
     decision: z.enum(["research", "write", "cancel"]),
     researchQuestions: z.array(z.string()),
-    cancelReason: z.string().optional(),
+    cancelReason: z.string().nullable(),
   });
 
   const result = await Settings.llm.complete({ prompt, responseFormat });
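A quick sketch of how the retyped workflowFactory is consumed, for anyone trying this branch locally. It mirrors the template's server entry point, but the LlamaIndexServer option names here are assumptions to verify against the installed @llamaindex/server version:

import { LlamaIndexServer } from "@llamaindex/server";
import { workflowFactory } from "./app/workflow";

// The factory now receives the parsed chat body ({ messages: Message[] })
// directly, so there is no reqBody?.data unwrapping as in the old signature.
new LlamaIndexServer({
  workflow: workflowFactory,
  port: 3000,
}).start();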