From 7d7b2844b73f26f1982479e7f68a3c9e171282fa Mon Sep 17 00:00:00 2001 From: ComputelessComputer <63365510+ComputelessComputer@users.noreply.github.com> Date: Fri, 26 Jun 2026 18:10:03 +0900 Subject: [PATCH] Fix first chat response transport Recreate chat sessions when transport becomes ready and persist first-send responses against the newly created chat group. --- apps/desktop/src/chat/components/content.tsx | 10 +- .../chat/components/session-provider.test.tsx | 289 +++++++++++++++++- .../src/chat/components/session-provider.tsx | 181 +++++++++-- .../src/chat/store/use-chat-actions.ts | 7 +- 4 files changed, 446 insertions(+), 41 deletions(-) diff --git a/apps/desktop/src/chat/components/content.tsx b/apps/desktop/src/chat/components/content.tsx index fce8a90282..5095a2ec7c 100644 --- a/apps/desktop/src/chat/components/content.tsx +++ b/apps/desktop/src/chat/components/content.tsx @@ -46,7 +46,10 @@ export function ChatContent({ layout?: "floating" | "right-panel"; sessionId: string; messages: HyprUIMessage[]; - sendMessage: (message: HyprUIMessage) => void; + sendMessage: ( + message: HyprUIMessage, + options?: { chatGroupId?: string }, + ) => void; regenerate: () => void; stop: () => void; status: ChatStatus; @@ -55,7 +58,10 @@ export function ChatContent({ handleSendMessage: ( content: string, parts: HyprUIMessage["parts"], - sendMessage: (message: HyprUIMessage) => void, + sendMessage: ( + message: HyprUIMessage, + options?: { chatGroupId?: string }, + ) => void, contextRefs?: ContextRef[], ) => void; contextEntities: DisplayEntity[]; diff --git a/apps/desktop/src/chat/components/session-provider.test.tsx b/apps/desktop/src/chat/components/session-provider.test.tsx index c9cd020184..d8caa60294 100644 --- a/apps/desktop/src/chat/components/session-provider.test.tsx +++ b/apps/desktop/src/chat/components/session-provider.test.tsx @@ -1,4 +1,10 @@ -import { cleanup, fireEvent, render, screen } from "@testing-library/react"; +import { + cleanup, + fireEvent, + render, + screen, + waitFor, +} from "@testing-library/react"; import { beforeEach, describe, expect, it, vi } from "vitest"; const mocks = vi.hoisted(() => ({ @@ -6,17 +12,29 @@ const mocks = vi.hoisted(() => ({ chatSendMessage: vi.fn(), chatSetMessages: vi.fn(), chatStop: vi.fn(), + chatInits: [] as unknown[], + chatMessagesTable: {} as Record, messages: [] as unknown[], + status: "ready", store: null as unknown, + transport: {} as unknown, })); vi.mock("@ai-sdk/react", () => ({ + Chat: class MockChat { + id: string; + + constructor(init: { id: string }) { + this.id = init.id; + mocks.chatInits.push(init); + } + }, useChat: () => ({ messages: mocks.messages, sendMessage: mocks.chatSendMessage, regenerate: mocks.chatRegenerate, stop: mocks.chatStop, - status: "ready", + status: mocks.status, error: undefined, setMessages: mocks.chatSetMessages, }), @@ -31,7 +49,7 @@ vi.mock("~/chat/context/use-chat-context-pipeline", () => ({ vi.mock("~/chat/transport/use-transport", () => ({ useTransport: () => ({ - transport: {}, + transport: mocks.transport, isSystemPromptReady: true, }), })); @@ -40,11 +58,12 @@ vi.mock("~/store/tinybase/store/main", () => ({ STORE_ID: "main", UI: { useStore: () => mocks.store, + useTable: () => mocks.chatMessagesTable, useValues: () => ({ user_id: "user-1" }), }, })); -import { ChatSession } from "./session-provider"; +import { ChatSession, type ChatSessionRenderProps } from "./session-provider"; import { buildPersistedChatMessageRow } from "~/chat/store/persisted-messages"; import type { HyprUIMessage } from "~/chat/types"; @@ -116,8 +135,12 @@ describe("ChatSession", () => { mocks.chatSendMessage.mockClear(); mocks.chatSetMessages.mockClear(); mocks.chatStop.mockClear(); + mocks.chatInits = []; + mocks.chatMessagesTable = {}; mocks.messages = []; + mocks.status = "ready"; mocks.store = createStore({}); + mocks.transport = {}; }); it("does not delete the previous persisted assistant when retrying an unpersisted empty assistant", () => { @@ -232,4 +255,262 @@ describe("ChatSession", () => { expect(store.getRow("chat_messages", "assistant-current")).toBeUndefined(); expect(mocks.chatRegenerate).toHaveBeenCalledTimes(1); }); + + it("recreates the sdk chat when transport becomes ready", () => { + const initialTransport = {}; + const readyTransport = {}; + mocks.transport = initialTransport; + + const { rerender } = render( + + {() => null} + , + ); + + mocks.transport = readyTransport; + rerender( + + {() => null} + , + ); + + expect(mocks.chatInits).toHaveLength(2); + expect((mocks.chatInits[0] as { transport: unknown }).transport).toBe( + initialTransport, + ); + expect((mocks.chatInits[1] as { transport: unknown }).transport).toBe( + readyTransport, + ); + }); + + it("syncs sdk messages when persisted chat rows load later", async () => { + const store = createStore({}); + mocks.store = store; + mocks.chatMessagesTable = {}; + const userMessage: HyprUIMessage = { + id: "user-1", + role: "user", + parts: [{ type: "text", text: "Question" }], + metadata: { createdAt: Date.parse("2024-01-01T00:00:00Z") }, + }; + + const { rerender } = render( + + {() => null} + , + ); + expect(mocks.chatInits).toHaveLength(1); + + store.setRow( + "chat_messages", + userMessage.id, + buildPersistedChatMessageRow({ + message: userMessage, + chatGroupId: "group-1", + userId: "user-1", + status: "ready", + }) as unknown as Record, + ); + mocks.chatMessagesTable = { [userMessage.id]: true }; + rerender( + + {() => null} + , + ); + + await waitFor(() => { + expect(mocks.chatSetMessages).toHaveBeenCalledWith([userMessage]); + }); + expect(mocks.chatInits).toHaveLength(1); + }); + + it("keeps the sdk chat when first send creates a chat group", () => { + const { rerender } = render( + {() => null}, + ); + + rerender( + + {() => null} + , + ); + + expect(mocks.chatInits).toHaveLength(1); + }); + + it("does not replace streaming sdk messages with stale persisted rows", () => { + const userMessage: HyprUIMessage = { + id: "user-1", + role: "user", + parts: [{ type: "text", text: "Question" }], + metadata: { createdAt: Date.parse("2024-01-01T00:00:00Z") }, + }; + const assistantMessage: HyprUIMessage = { + id: "assistant-1", + role: "assistant", + parts: [{ type: "text", text: "Partial answer" }], + metadata: { createdAt: Date.parse("2024-01-01T00:00:01Z") }, + }; + const store = createStore({ + "user-1": buildPersistedChatMessageRow({ + message: userMessage, + chatGroupId: "group-1", + userId: "user-1", + status: "ready", + }) as unknown as Record, + }); + mocks.store = store; + mocks.messages = [userMessage, assistantMessage]; + mocks.status = "streaming"; + + render( + + {() => null} + , + ); + + expect(mocks.chatSetMessages).not.toHaveBeenCalled(); + }); + + it("persists a first-send assistant response to the newly created group", () => { + const store = createStore({}); + mocks.store = store; + const captured: { send?: ChatSessionRenderProps["sendMessage"] } = {}; + + render( + + {(props) => { + captured.send = props.sendMessage; + return null; + }} + , + ); + + const sendMessage = captured.send; + expect(sendMessage).toBeDefined(); + sendMessage!( + { + id: "user-1", + role: "user", + parts: [{ type: "text", text: "Question" }], + }, + { chatGroupId: "new-group" }, + ); + + const onFinish = mocks.chatInits[0] as { + onFinish: (params: { + message: HyprUIMessage; + messages: HyprUIMessage[]; + isAbort: boolean; + }) => void; + }; + const userMessage: HyprUIMessage = { + id: "user-1", + role: "user", + parts: [{ type: "text", text: "Question" }], + }; + onFinish.onFinish({ + isAbort: false, + message: { + id: "assistant-1", + role: "assistant", + parts: [{ type: "text", text: "Answer" }], + metadata: { createdAt: Date.parse("2024-01-01T00:00:01Z") }, + }, + messages: [ + userMessage, + { + id: "assistant-1", + role: "assistant", + parts: [{ type: "text", text: "Answer" }], + metadata: { createdAt: Date.parse("2024-01-01T00:00:01Z") }, + }, + ], + }); + + expect(store.setRow).toHaveBeenCalledWith( + "chat_messages", + "assistant-1", + expect.objectContaining({ + chat_group_id: "new-group", + content: "Answer", + }), + ); + }); + + it("persists overlapping assistant responses to their submitted groups", () => { + const store = createStore({}); + mocks.store = store; + const captured: { send?: ChatSessionRenderProps["sendMessage"] } = {}; + + render( + + {(props) => { + captured.send = props.sendMessage; + return null; + }} + , + ); + + const userOne: HyprUIMessage = { + id: "user-1", + role: "user", + parts: [{ type: "text", text: "First question" }], + }; + const userTwo: HyprUIMessage = { + id: "user-2", + role: "user", + parts: [{ type: "text", text: "Second question" }], + }; + captured.send!(userOne, { chatGroupId: "group-1" }); + captured.send!(userTwo, { chatGroupId: "group-2" }); + + const onFinish = mocks.chatInits[0] as { + onFinish: (params: { + message: HyprUIMessage; + messages: HyprUIMessage[]; + isAbort: boolean; + }) => void; + }; + const assistantOne: HyprUIMessage = { + id: "assistant-1", + role: "assistant", + parts: [{ type: "text", text: "First answer" }], + metadata: { createdAt: Date.parse("2024-01-01T00:00:01Z") }, + }; + const assistantTwo: HyprUIMessage = { + id: "assistant-2", + role: "assistant", + parts: [{ type: "text", text: "Second answer" }], + metadata: { createdAt: Date.parse("2024-01-01T00:00:02Z") }, + }; + + onFinish.onFinish({ + isAbort: false, + message: assistantOne, + messages: [userOne, assistantOne, userTwo], + }); + onFinish.onFinish({ + isAbort: false, + message: assistantTwo, + messages: [userOne, assistantOne, userTwo, assistantTwo], + }); + + expect(store.setRow).toHaveBeenCalledWith( + "chat_messages", + "assistant-1", + expect.objectContaining({ + chat_group_id: "group-1", + content: "First answer", + }), + ); + expect(store.setRow).toHaveBeenCalledWith( + "chat_messages", + "assistant-2", + expect.objectContaining({ + chat_group_id: "group-2", + content: "Second answer", + }), + ); + }); }); diff --git a/apps/desktop/src/chat/components/session-provider.tsx b/apps/desktop/src/chat/components/session-provider.tsx index 8f0dda0020..c7f9c6e4b2 100644 --- a/apps/desktop/src/chat/components/session-provider.tsx +++ b/apps/desktop/src/chat/components/session-provider.tsx @@ -1,5 +1,5 @@ -import { useChat } from "@ai-sdk/react"; -import type { ChatStatus, LanguageModel, ToolSet } from "ai"; +import { Chat, useChat } from "@ai-sdk/react"; +import type { ChatStatus, ChatTransport, LanguageModel, ToolSet } from "ai"; import { type ReactNode, useCallback, @@ -31,7 +31,10 @@ export type ChatSessionRenderProps = { setMessages: ( msgs: HyprUIMessage[] | ((prev: HyprUIMessage[]) => HyprUIMessage[]), ) => void; - sendMessage: (message: HyprUIMessage) => void; + sendMessage: ( + message: HyprUIMessage, + options?: { chatGroupId?: string }, + ) => void; regenerate: () => void; stop: () => void; status: ChatStatus; @@ -55,6 +58,23 @@ interface ChatSessionProps { children: (props: ChatSessionRenderProps) => ReactNode; } +function areMessagesEqual(a: HyprUIMessage[], b: HyprUIMessage[]) { + if (a.length !== b.length) { + return false; + } + + return a.every((message, index) => { + const other = b[index]; + return ( + message.id === other?.id && + message.role === other.role && + JSON.stringify(message.parts) === JSON.stringify(other.parts) && + JSON.stringify(message.metadata ?? {}) === + JSON.stringify(other.metadata ?? {}) + ); + }); +} + export function ChatSession({ sessionId, chatGroupId, @@ -66,10 +86,19 @@ export function ChatSession({ children, }: ChatSessionProps) { const store = main.UI.useStore(main.STORE_ID); + const chatMessagesTable = main.UI.useTable("chat_messages", main.STORE_ID); const { user_id } = main.UI.useValues(main.STORE_ID); const [pendingManualRefs, setPendingManualRefs] = useState([]); const [pendingDraftRefs, setPendingDraftRefs] = useState([]); + const latestChatGroupIdRef = useRef(chatGroupId); + const latestStoreRef = useRef(store); + const latestUserIdRef = useRef(user_id); + const submittedChatGroupIdsRef = useRef(new Map()); + + latestChatGroupIdRef.current = chatGroupId; + latestStoreRef.current = store; + latestUserIdRef.current = user_id; const onAddContextEntity = useCallback((ref: ContextRef) => { setPendingManualRefs((prev) => @@ -98,6 +127,90 @@ export function ChatSession({ store, ); + const persistedVisibleMessages = useMemo( + () => + store && chatGroupId ? getVisibleChatMessages(store, chatGroupId) : [], + [store, chatGroupId, chatMessagesTable], + ); + + const chat = useMemo( + () => + new Chat({ + id: sessionId, + messages: + store && chatGroupId + ? getVisibleChatMessages(store, chatGroupId) + : [], + transport: transport ?? unavailableChatTransport, + onFinish: ({ message, messages, isAbort }) => { + const currentStore = latestStoreRef.current; + const currentUserId = latestUserIdRef.current; + const messageIndex = messages.findIndex((m) => m.id === message.id); + const lastMessageIndex = + messageIndex === -1 ? messages.length - 1 : messageIndex - 1; + let submittedUserMessage: HyprUIMessage | undefined; + for (let i = lastMessageIndex; i >= 0; i--) { + if (messages[i].role === "user") { + submittedUserMessage = messages[i]; + break; + } + } + const submittedChatGroupId = submittedUserMessage + ? submittedChatGroupIdsRef.current.get(submittedUserMessage.id) + : undefined; + if (submittedUserMessage) { + submittedChatGroupIdsRef.current.delete(submittedUserMessage.id); + } + + const persistedChatGroupId = + submittedUserMessage && currentStore + ? currentStore.getRow("chat_messages", submittedUserMessage.id) + ?.chat_group_id + : undefined; + const targetChatGroupId = + (typeof persistedChatGroupId === "string" && persistedChatGroupId + ? persistedChatGroupId + : undefined) ?? + submittedChatGroupId ?? + latestChatGroupIdRef.current; + + if ( + isAbort || + !targetChatGroupId || + !currentStore || + !currentUserId + ) { + return; + } + + const sanitizedParts = stripEphemeralToolContext(message.parts); + const sanitizedMessage = + sanitizedParts === message.parts + ? message + : { ...message, parts: sanitizedParts }; + if (!shouldPersistFinishedMessage(sanitizedMessage)) { + currentStore.delRow("chat_messages", sanitizedMessage.id); + return; + } + currentStore.setRow( + "chat_messages", + sanitizedMessage.id, + buildPersistedChatMessageRow({ + message: sanitizedMessage, + chatGroupId: targetChatGroupId, + userId: currentUserId, + status: "ready", + existingRow: currentStore.getRow( + "chat_messages", + sanitizedMessage.id, + ), + }), + ); + }, + }), + [sessionId, store, transport], + ); + const { messages, sendMessage: chatSendMessage, @@ -106,38 +219,33 @@ export function ChatSession({ status, error, setMessages: chatSetMessages, - } = useChat({ - id: sessionId, - messages: - store && chatGroupId ? getVisibleChatMessages(store, chatGroupId) : [], - transport: transport ?? undefined, - onFinish: ({ message, isAbort }) => { - if (isAbort || !chatGroupId || !store || !user_id) return; - const sanitizedParts = stripEphemeralToolContext(message.parts); - const sanitizedMessage = - sanitizedParts === message.parts - ? message - : { ...message, parts: sanitizedParts }; - if (!shouldPersistFinishedMessage(sanitizedMessage)) { - store.delRow("chat_messages", sanitizedMessage.id); - return; - } - store.setRow( - "chat_messages", - sanitizedMessage.id, - buildPersistedChatMessageRow({ - message: sanitizedMessage, - chatGroupId, - userId: user_id, - status: "ready", - existingRow: store.getRow("chat_messages", sanitizedMessage.id), - }), - ); - }, - }); + } = useChat({ chat }); + + useEffect(() => { + if ( + status !== "ready" || + !chatGroupId || + areMessagesEqual(messages, persistedVisibleMessages) + ) { + return; + } + + chatSetMessages(persistedVisibleMessages); + }, [ + chatGroupId, + messages, + persistedVisibleMessages, + status, + chatSetMessages, + ]); const sendMessage = useCallback( - (message: HyprUIMessage) => { + (message: HyprUIMessage, options?: { chatGroupId?: string }) => { + const targetChatGroupId = + options?.chatGroupId ?? latestChatGroupIdRef.current; + if (targetChatGroupId) { + submittedChatGroupIdsRef.current.set(message.id, targetChatGroupId); + } // HyprUIMessage is structurally compatible with CreateUIMessage: // no `text`/`files` so the SDK takes the `else` branch and uses message.id as the message id. void chatSendMessage(message as Parameters[0]); @@ -218,3 +326,10 @@ export function ChatSession({ return
{content}
; } + +const unavailableChatTransport: ChatTransport = { + sendMessages: async () => { + throw new Error("Chat model is not ready"); + }, + reconnectToStream: async () => null, +}; diff --git a/apps/desktop/src/chat/store/use-chat-actions.ts b/apps/desktop/src/chat/store/use-chat-actions.ts index 0dd13da8cc..b529a25baf 100644 --- a/apps/desktop/src/chat/store/use-chat-actions.ts +++ b/apps/desktop/src/chat/store/use-chat-actions.ts @@ -83,7 +83,10 @@ export function useChatActions({ ( content: string, parts: HyprUIMessage["parts"], - sendMessage: (message: HyprUIMessage) => void, + sendMessage: ( + message: HyprUIMessage, + options?: { chatGroupId?: string }, + ) => void, contextRefs?: ContextRef[], ) => { const messageId = id(); @@ -120,7 +123,7 @@ export function useChatActions({ metadata, }); - sendMessage(uiMessage); + sendMessage(uiMessage, { chatGroupId: currentGroupId }); }, [ groupId,