|
1 | 1 | import { |
2 | 2 | Message, |
3 | | - MessageAnnotation, |
4 | 3 | getChatUIAnnotation, |
| 4 | + getCustomAnnotation, |
5 | 5 | useChatMessage, |
6 | 6 | useChatUI, |
7 | 7 | } from "@llamaindex/chat-ui"; |
| 8 | +import { ChatEvents } from "@llamaindex/chat-ui/widgets"; |
8 | 9 | import { JSONValue } from "ai"; |
9 | 10 | import { useMemo } from "react"; |
| 11 | +import { z } from "zod"; |
10 | 12 | import { Artifact, CodeArtifact } from "./artifact"; |
11 | 13 | import { WeatherCard, WeatherData } from "./weather-card"; |
12 | 14 |
|
| 15 | +const ToolCallSchema = z.object({ |
| 16 | + tool_name: z.string(), |
| 17 | + tool_kwargs: z.record(z.unknown()), |
| 18 | + tool_id: z.string(), |
| 19 | + tool_output: z.optional( |
| 20 | + z |
| 21 | + .object({ |
| 22 | + content: z.string(), |
| 23 | + tool_name: z.string(), |
| 24 | + raw_input: z.object({ |
| 25 | + args: z.array(z.unknown()), |
| 26 | + kwargs: z.record(z.unknown()), |
| 27 | + }), |
| 28 | + raw_output: z.record(z.unknown()), |
| 29 | + is_error: z.boolean().optional(), |
| 30 | + }) |
| 31 | + .optional(), |
| 32 | + ), |
| 33 | + return_direct: z.boolean().optional(), |
| 34 | +}); |
| 35 | + |
| 36 | +type ToolCallEvent = z.infer<typeof ToolCallSchema>; |
| 37 | + |
| 38 | +type GroupedToolCall = { |
| 39 | + initial: ToolCallEvent; |
| 40 | + output?: ToolCallEvent; |
| 41 | +}; |
| 42 | + |
13 | 43 | export function ToolAnnotations() { |
14 | | - // TODO: This is a bit of a hack to get the artifact version. better to generate the version in the tool call and |
15 | | - // store it in CodeArtifact |
16 | 44 | const { messages } = useChatUI(); |
17 | 45 | const { message } = useChatMessage(); |
18 | 46 | const artifactVersion = useMemo( |
19 | 47 | () => getArtifactVersion(messages, message), |
20 | 48 | [messages, message], |
21 | 49 | ); |
22 | | - // Get the tool data from the message annotations |
23 | | - const annotations = message.annotations as MessageAnnotation[] | undefined; |
24 | | - const toolData = annotations |
25 | | - ? (getChatUIAnnotation(annotations, "tools") as unknown as ToolData[]) |
26 | | - : null; |
27 | | - return toolData?.[0] ? ( |
28 | | - <ChatTools data={toolData[0]} artifactVersion={artifactVersion} /> |
29 | | - ) : null; |
30 | | -} |
31 | 50 |
|
32 | | -// TODO: Used to render outputs of tools. If needed, add more renderers here. |
33 | | -function ChatTools({ |
34 | | - data, |
35 | | - artifactVersion, |
36 | | -}: { |
37 | | - data: ToolData; |
38 | | - artifactVersion: number | undefined; |
39 | | -}) { |
40 | | - if (!data) return null; |
41 | | - const { toolCall, toolOutput } = data; |
| 51 | + const toolCallEvents = getCustomAnnotation<ToolCallEvent>( |
| 52 | + message.annotations, |
| 53 | + (annotation) => { |
| 54 | + const result = ToolCallSchema.safeParse(annotation); |
| 55 | + return result.success; |
| 56 | + }, |
| 57 | + ); |
| 58 | + |
| 59 | + // Group tool calls by tool_id - we just need to take the latest event for each tool_id |
| 60 | + const groupedToolCalls = useMemo(() => { |
| 61 | + const groups = new Map<string, GroupedToolCall>(); |
42 | 62 |
|
43 | | - if (toolOutput.isError) { |
44 | | - return ( |
45 | | - <div className="border-l-2 border-red-400 pl-2"> |
46 | | - There was an error when calling the tool {toolCall.name} with input:{" "} |
47 | | - <br /> |
48 | | - {JSON.stringify(toolCall.input)} |
49 | | - </div> |
50 | | - ); |
51 | | - } |
| 63 | + toolCallEvents?.forEach((event) => { |
| 64 | + groups.set(event.tool_id, { initial: event }); |
| 65 | + }); |
52 | 66 |
|
53 | | - switch (toolCall.name) { |
54 | | - case "get_weather_information": |
55 | | - const weatherData = toolOutput.output as unknown as WeatherData; |
56 | | - return <WeatherCard data={weatherData} />; |
57 | | - case "artifact": |
58 | | - return ( |
59 | | - <Artifact |
60 | | - artifact={toolOutput.output as CodeArtifact} |
61 | | - version={artifactVersion} |
62 | | - /> |
63 | | - ); |
64 | | - default: |
65 | | - return null; |
66 | | - } |
| 67 | + return Array.from(groups.values()); |
| 68 | + }, [toolCallEvents]); |
| 69 | + |
| 70 | + return ( |
| 71 | + <div className="space-y-4"> |
| 72 | + {groupedToolCalls.map(({ initial }) => { |
| 73 | + switch (initial.tool_name) { |
| 74 | + case "query_index": { |
| 75 | + const query = initial.tool_kwargs.input; |
| 76 | + const eventData = [ |
| 77 | + { |
| 78 | + title: initial.tool_output |
| 79 | + ? `Got ${JSON.stringify((initial.tool_output?.raw_output as any).source_nodes?.length ?? 0)} sources for query: ${query}` |
| 80 | + : `Searching information for query: ${query}`, |
| 81 | + }, |
| 82 | + ]; |
| 83 | + |
| 84 | + return ( |
| 85 | + <ChatEvents |
| 86 | + key={initial.tool_id} |
| 87 | + data={eventData} |
| 88 | + showLoading={!initial.tool_output} |
| 89 | + /> |
| 90 | + ); |
| 91 | + } |
| 92 | + case "get_weather_information": { |
| 93 | + if (!initial.tool_output) |
| 94 | + return ( |
| 95 | + <ChatEvents |
| 96 | + key={initial.tool_id} |
| 97 | + data={[ |
| 98 | + { |
| 99 | + title: `Getting weather information for ${initial.tool_kwargs.location}`, |
| 100 | + }, |
| 101 | + ]} |
| 102 | + showLoading={false} |
| 103 | + /> |
| 104 | + ); |
| 105 | + const weatherData = initial.tool_output |
| 106 | + ?.raw_output as unknown as WeatherData; |
| 107 | + return <WeatherCard key={initial.tool_id} data={weatherData} />; |
| 108 | + } |
| 109 | + case "artifact": { |
| 110 | + const artifact = initial.tool_output |
| 111 | + ?.content as unknown as CodeArtifact; |
| 112 | + return ( |
| 113 | + <Artifact |
| 114 | + key={initial.tool_id} |
| 115 | + artifact={artifact} |
| 116 | + version={artifactVersion} |
| 117 | + /> |
| 118 | + ); |
| 119 | + } |
| 120 | + default: |
| 121 | + return null; |
| 122 | + } |
| 123 | + })} |
| 124 | + </div> |
| 125 | + ); |
67 | 126 | } |
68 | 127 |
|
69 | 128 | type ToolData = { |
|
0 commit comments