diff --git a/CHANGELOG.md b/CHANGELOG.md index 87175b8..b1d9755 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,17 @@ # @copilotkit/llmock +## 1.3.0 + +### Minor Changes + +- Mid-stream interruption: `truncateAfterChunks` and `disconnectAfterMs` fixture fields to simulate abrupt server disconnects +- AbortSignal-based cancellation primitives (`createInterruptionSignal`, signal-aware `delay()`) +- Backward-compatible `writeSSEStream` overload with `StreamOptions` returning completion status +- Interruption support across all HTTP SSE and WebSocket streaming paths +- `destroy()` method on `WebSocketConnection` for abrupt disconnect simulation +- Journal records `interrupted` and `interruptReason` on interrupted streams +- LLMock convenience API extended with interruption options (`truncateAfterChunks`, `disconnectAfterMs`) + ## 1.2.0 ### Minor Changes diff --git a/README.md b/README.md index 83cdf96..2f40157 100644 --- a/README.md +++ b/README.md @@ -622,11 +622,6 @@ Areas where llmock could grow, and explicit non-goals for the current scope. - **WebSocket compression**: `permessage-deflate` is not supported. - **Session persistence**: Realtime and Gemini Live sessions exist only for the lifetime of a single WebSocket connection. There is no cross-connection session resumption. -### Streaming - -- **Mid-stream interruption**: No way to simulate a server disconnecting partway through a stream (e.g. `truncateAfterChunks`, `disconnectAfterMs`). -- **Abort/cancellation signaling**: Streaming functions do not accept an `AbortSignal` for client-side cancellation. - ### Fixtures - **Request metadata in predicates**: Predicate functions receive only the `ChatCompletionRequest`, not HTTP headers, method, or URL. diff --git a/package.json b/package.json index d77db94..cdecdae 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "@copilotkit/llmock", - "version": "1.2.0", + "version": "1.3.0", "description": "Deterministic mock LLM server for testing (OpenAI, Anthropic, Gemini)", "license": "MIT", "packageManager": "pnpm@10.28.2", diff --git a/src/__tests__/fixture-loader.test.ts b/src/__tests__/fixture-loader.test.ts index a0435be..d9f57dd 100644 --- a/src/__tests__/fixture-loader.test.ts +++ b/src/__tests__/fixture-loader.test.ts @@ -114,6 +114,71 @@ describe("loadFixtureFile", () => { expect(fixtures[0].chunkSize).toBeUndefined(); }); + it("passes through truncateAfterChunks when set", () => { + const filePath = writeJson(tmpDir, "truncate.json", { + fixtures: [ + { + match: { userMessage: "truncate me" }, + response: { content: "partial" }, + truncateAfterChunks: 3, + }, + ], + }); + + const fixtures = loadFixtureFile(filePath); + expect(fixtures).toHaveLength(1); + expect(fixtures[0].truncateAfterChunks).toBe(3); + }); + + it("passes through disconnectAfterMs when set", () => { + const filePath = writeJson(tmpDir, "disconnect.json", { + fixtures: [ + { + match: { userMessage: "disconnect me" }, + response: { content: "partial" }, + disconnectAfterMs: 500, + }, + ], + }); + + const fixtures = loadFixtureFile(filePath); + expect(fixtures).toHaveLength(1); + expect(fixtures[0].disconnectAfterMs).toBe(500); + }); + + it("passes through both truncateAfterChunks and disconnectAfterMs together", () => { + const filePath = writeJson(tmpDir, "both-interruptions.json", { + fixtures: [ + { + match: { userMessage: "both" }, + response: { content: "partial" }, + truncateAfterChunks: 5, + disconnectAfterMs: 1000, + }, + ], + }); + + const fixtures = loadFixtureFile(filePath); + expect(fixtures).toHaveLength(1); + expect(fixtures[0].truncateAfterChunks).toBe(5); + expect(fixtures[0].disconnectAfterMs).toBe(1000); + }); + + it("omits truncateAfterChunks and disconnectAfterMs when not present in JSON", () => { + const filePath = writeJson(tmpDir, "no-interruptions.json", { + fixtures: [ + { + match: { userMessage: "plain" }, + response: { content: "complete" }, + }, + ], + }); + + const fixtures = loadFixtureFile(filePath); + expect(fixtures[0].truncateAfterChunks).toBeUndefined(); + expect(fixtures[0].disconnectAfterMs).toBeUndefined(); + }); + it("warns and returns empty array for invalid JSON", () => { const filePath = join(tmpDir, "bad.json"); writeFileSync(filePath, "{ not valid json", "utf-8"); diff --git a/src/__tests__/interruption.test.ts b/src/__tests__/interruption.test.ts new file mode 100644 index 0000000..185879b --- /dev/null +++ b/src/__tests__/interruption.test.ts @@ -0,0 +1,142 @@ +import { describe, it, expect, vi, afterEach } from "vitest"; +import { createInterruptionSignal } from "../interruption.js"; +import type { Fixture } from "../types.js"; + +function makeFixture(overrides?: Partial): Fixture { + return { + match: { userMessage: "test" }, + response: { content: "hello" }, + ...overrides, + }; +} + +afterEach(() => { + vi.useRealTimers(); +}); + +describe("createInterruptionSignal", () => { + it("returns null when no interruption fields are set", () => { + const result = createInterruptionSignal(makeFixture()); + expect(result).toBeNull(); + }); + + it("returns null when both fields are undefined", () => { + const result = createInterruptionSignal( + makeFixture({ truncateAfterChunks: undefined, disconnectAfterMs: undefined }), + ); + expect(result).toBeNull(); + }); + + it("truncateAfterChunks: aborts after N ticks", () => { + const ctrl = createInterruptionSignal(makeFixture({ truncateAfterChunks: 3 })); + expect(ctrl).not.toBeNull(); + expect(ctrl!.signal.aborted).toBe(false); + + ctrl!.tick(); + expect(ctrl!.signal.aborted).toBe(false); + ctrl!.tick(); + expect(ctrl!.signal.aborted).toBe(false); + ctrl!.tick(); + expect(ctrl!.signal.aborted).toBe(true); + expect(ctrl!.reason()).toBe("truncateAfterChunks"); + + ctrl!.cleanup(); + }); + + it("truncateAfterChunks: extra ticks after abort are no-ops", () => { + const ctrl = createInterruptionSignal(makeFixture({ truncateAfterChunks: 1 })); + ctrl!.tick(); + expect(ctrl!.signal.aborted).toBe(true); + // Should not throw + ctrl!.tick(); + ctrl!.tick(); + expect(ctrl!.reason()).toBe("truncateAfterChunks"); + ctrl!.cleanup(); + }); + + it("disconnectAfterMs: aborts after timeout", async () => { + vi.useFakeTimers(); + const ctrl = createInterruptionSignal(makeFixture({ disconnectAfterMs: 100 })); + expect(ctrl).not.toBeNull(); + expect(ctrl!.signal.aborted).toBe(false); + + vi.advanceTimersByTime(99); + expect(ctrl!.signal.aborted).toBe(false); + + vi.advanceTimersByTime(1); + expect(ctrl!.signal.aborted).toBe(true); + expect(ctrl!.reason()).toBe("disconnectAfterMs"); + + ctrl!.cleanup(); + }); + + it("both set: truncateAfterChunks fires first wins", () => { + vi.useFakeTimers(); + const ctrl = createInterruptionSignal( + makeFixture({ truncateAfterChunks: 2, disconnectAfterMs: 10000 }), + ); + + ctrl!.tick(); + ctrl!.tick(); + expect(ctrl!.signal.aborted).toBe(true); + expect(ctrl!.reason()).toBe("truncateAfterChunks"); + + ctrl!.cleanup(); + }); + + it("both set: disconnectAfterMs fires first wins", () => { + vi.useFakeTimers(); + const ctrl = createInterruptionSignal( + makeFixture({ truncateAfterChunks: 100, disconnectAfterMs: 50 }), + ); + + ctrl!.tick(); // 1 of 100 + expect(ctrl!.signal.aborted).toBe(false); + + vi.advanceTimersByTime(50); + expect(ctrl!.signal.aborted).toBe(true); + expect(ctrl!.reason()).toBe("disconnectAfterMs"); + + ctrl!.cleanup(); + }); + + it("cleanup clears the timer", () => { + vi.useFakeTimers(); + const ctrl = createInterruptionSignal(makeFixture({ disconnectAfterMs: 100 })); + + ctrl!.cleanup(); + + vi.advanceTimersByTime(200); + expect(ctrl!.signal.aborted).toBe(false); + expect(ctrl!.reason()).toBeUndefined(); + }); + + it("reason returns undefined before abort", () => { + const ctrl = createInterruptionSignal(makeFixture({ truncateAfterChunks: 5 })); + expect(ctrl!.reason()).toBeUndefined(); + ctrl!.cleanup(); + }); + + it("truncateAfterChunks: 0 aborts immediately on first tick", () => { + const ctrl = createInterruptionSignal(makeFixture({ truncateAfterChunks: 0 })); + expect(ctrl).not.toBeNull(); + expect(ctrl!.signal.aborted).toBe(false); + + ctrl!.tick(); + expect(ctrl!.signal.aborted).toBe(true); + expect(ctrl!.reason()).toBe("truncateAfterChunks"); + + ctrl!.cleanup(); + }); + + it("disconnectAfterMs: 0 aborts promptly", async () => { + const ctrl = createInterruptionSignal(makeFixture({ disconnectAfterMs: 0 })); + expect(ctrl).not.toBeNull(); + + await new Promise((r) => setTimeout(r, 10)); + expect(ctrl!.signal.aborted).toBe(true); + expect(ctrl!.reason()).toBe("disconnectAfterMs"); + + ctrl!.cleanup(); + }); +}); diff --git a/src/__tests__/server.test.ts b/src/__tests__/server.test.ts index 5d3fdc9..814096a 100644 --- a/src/__tests__/server.test.ts +++ b/src/__tests__/server.test.ts @@ -942,3 +942,295 @@ describe("header forwarding in journal", () => { expect(entries[1].headers["authorization"]).toBe("Bearer key-two"); }); }); + +describe("stream interruption", () => { + // Helper that collects whatever data arrives before the server destroys the + // connection. Unlike `post()`, it does NOT reject on socket errors — it + // returns the partial body that was received. + function postPartial(url: string, body: unknown): Promise<{ body: string; aborted: boolean }> { + return new Promise((resolve) => { + const data = JSON.stringify(body); + const parsed = new URL(url); + const chunks: Buffer[] = []; + let aborted = false; + const req = http.request( + { + hostname: parsed.hostname, + port: parsed.port, + path: parsed.pathname, + method: "POST", + headers: { + "Content-Type": "application/json", + "Content-Length": Buffer.byteLength(data), + }, + }, + (res) => { + res.on("data", (c: Buffer) => chunks.push(c)); + res.on("end", () => { + resolve({ body: Buffer.concat(chunks).toString(), aborted }); + }); + res.on("error", () => { + aborted = true; + }); + res.on("aborted", () => { + aborted = true; + }); + res.on("close", () => { + resolve({ body: Buffer.concat(chunks).toString(), aborted }); + }); + }, + ); + req.on("error", () => { + aborted = true; + resolve({ body: Buffer.concat(chunks).toString(), aborted }); + }); + req.write(data); + req.end(); + }); + } + + it("truncateAfterChunks stops stream early and records interruption", async () => { + // Use enough chunks that without truncation, we'd get many more events. + // With truncateAfterChunks: 2, only 2 chunks should be written before abort. + // res.destroy() simulates abrupt disconnect — some data may be lost in + // transit, so we verify via the journal (which is always reliable). + const fixture: Fixture = { + match: { userMessage: "truncate-me" }, + response: { content: "ABCDEFGHIJKLMNO" }, // 15 chars, chunkSize 3 => 5 content + role + finish = 7 + chunkSize: 3, + latency: 5, + truncateAfterChunks: 2, + }; + instance = await createServer([fixture]); + const res = await postPartial(`${instance.url}/v1/chat/completions`, { + model: "gpt-4", + messages: [{ role: "user", content: "truncate-me" }], + }); + + // The body should NOT contain [DONE] since we interrupted + expect(res.body).not.toContain("data: [DONE]"); + + // The connection should have been aborted + expect(res.aborted).toBe(true); + + // Journal should record interruption + await new Promise((r) => setTimeout(r, 50)); + const entry = instance.journal.getLast(); + expect(entry).not.toBeNull(); + expect(entry!.response.interrupted).toBe(true); + expect(entry!.response.interruptReason).toBe("truncateAfterChunks"); + }); + + it("truncateAfterChunks is ignored for non-streaming requests", async () => { + const fixture: Fixture = { + match: { userMessage: "no-stream-truncate" }, + response: { content: "Hello world" }, + truncateAfterChunks: 1, + }; + instance = await createServer([fixture]); + const res = await post(`${instance.url}/v1/chat/completions`, { + model: "gpt-4", + messages: [{ role: "user", content: "no-stream-truncate" }], + stream: false, + }); + + expect(res.status).toBe(200); + const body = JSON.parse(res.body); + expect(body.choices[0].message.content).toBe("Hello world"); + + const entry = instance.journal.getLast(); + expect(entry!.response.interrupted).toBeUndefined(); + }); + + it("journal records interrupted: true with interruptReason", async () => { + const fixture: Fixture = { + match: { userMessage: "journal-int" }, + response: { content: "ABCDEFGHIJ" }, + chunkSize: 2, + truncateAfterChunks: 1, + }; + instance = await createServer([fixture]); + await postPartial(`${instance.url}/v1/chat/completions`, { + model: "gpt-4", + messages: [{ role: "user", content: "journal-int" }], + }); + + // Give server a moment to finish the async handler + await new Promise((r) => setTimeout(r, 50)); + const entry = instance.journal.getLast(); + expect(entry).not.toBeNull(); + expect(entry!.response.interrupted).toBe(true); + expect(entry!.response.interruptReason).toBe("truncateAfterChunks"); + }); + + it("disconnectAfterMs stops stream after timeout", async () => { + const fixture: Fixture = { + match: { userMessage: "disconnect-me" }, + response: { content: "A".repeat(200) }, + chunkSize: 10, + latency: 20, + disconnectAfterMs: 50, + }; + instance = await createServer([fixture]); + const res = await postPartial(`${instance.url}/v1/chat/completions`, { + model: "gpt-4", + messages: [{ role: "user", content: "disconnect-me" }], + }); + + // Should be a partial stream + expect(res.body).not.toContain("data: [DONE]"); + const events = parseSSEEvents(res.body); + // With 200 chars / 10 chunkSize = 20 content chunks + 1 role + 1 finish = 22 total + // But disconnectAfterMs: 50 with latency: 20 should only get a few + expect(events.length).toBeLessThan(22); + expect(events.length).toBeGreaterThanOrEqual(1); + + // Give server a moment to finish the async handler + await new Promise((r) => setTimeout(r, 100)); + const entry = instance.journal.getLast(); + expect(entry!.response.interrupted).toBe(true); + expect(entry!.response.interruptReason).toBe("disconnectAfterMs"); + }); + + it("tool call interruption via OpenAI /v1/chat/completions with truncateAfterChunks", async () => { + // Tool call stream: role chunk + N argument delta chunks + finish chunk + // With truncateAfterChunks: 2 we get at most 2 chunks before abort + const fixture: Fixture = { + match: { userMessage: "tool-truncate" }, + response: { + toolCalls: [{ name: "get_weather", arguments: '{"city":"New York","units":"metric"}' }], + }, + chunkSize: 3, + latency: 5, + truncateAfterChunks: 2, + }; + instance = await createServer([fixture]); + const res = await postPartial(`${instance.url}/v1/chat/completions`, { + model: "gpt-4", + messages: [{ role: "user", content: "tool-truncate" }], + }); + + // No [DONE] — stream was cut short + expect(res.body).not.toContain("data: [DONE]"); + + // Journal must record interruption + await new Promise((r) => setTimeout(r, 50)); + const entry = instance.journal.getLast(); + expect(entry).not.toBeNull(); + expect(entry!.response.interrupted).toBe(true); + expect(entry!.response.interruptReason).toBe("truncateAfterChunks"); + }); + + it("Claude Messages API /v1/messages with truncateAfterChunks stops stream early", async () => { + // Claude SSE events: message_start, content_block_start, N content_block_delta, content_block_stop, message_delta, message_stop + // With truncateAfterChunks: 2 the stream ends before message_stop + const fixture: Fixture = { + match: { userMessage: "claude-truncate" }, + response: { content: "ABCDEFGHIJKLMNO" }, // 15 chars, chunkSize 3 => 5 deltas + chunkSize: 3, + latency: 5, + truncateAfterChunks: 2, + }; + instance = await createServer([fixture]); + const res = await postPartial(`${instance.url}/v1/messages`, { + model: "claude-3-5-sonnet-20241022", + max_tokens: 1024, + stream: true, + messages: [{ role: "user", content: "claude-truncate" }], + }); + + // No message_stop event — stream was cut short + expect(res.body).not.toContain('"message_stop"'); + + // Journal records interruption + await new Promise((r) => setTimeout(r, 50)); + const entry = instance.journal.getLast(); + expect(entry).not.toBeNull(); + expect(entry!.response.interrupted).toBe(true); + expect(entry!.response.interruptReason).toBe("truncateAfterChunks"); + }); + + it("Claude Messages API /v1/messages with disconnectAfterMs stops stream early", async () => { + const fixture: Fixture = { + match: { userMessage: "claude-disconnect" }, + response: { content: "A".repeat(150) }, + chunkSize: 10, + latency: 20, + disconnectAfterMs: 50, + }; + instance = await createServer([fixture]); + const res = await postPartial(`${instance.url}/v1/messages`, { + model: "claude-3-5-sonnet-20241022", + max_tokens: 1024, + stream: true, + messages: [{ role: "user", content: "claude-disconnect" }], + }); + + // No message_stop event — stream was cut short + expect(res.body).not.toContain('"message_stop"'); + + // Journal records disconnectAfterMs reason + await new Promise((r) => setTimeout(r, 100)); + const entry = instance.journal.getLast(); + expect(entry).not.toBeNull(); + expect(entry!.response.interrupted).toBe(true); + expect(entry!.response.interruptReason).toBe("disconnectAfterMs"); + }); + + it("Gemini HTTP SSE streamGenerateContent with truncateAfterChunks stops stream early", async () => { + // Gemini SSE: N data-only chunks (no [DONE]). The last chunk has finishReason: "STOP". + // With truncateAfterChunks: 2 out of 5 content chunks, finishReason never appears. + const fixture: Fixture = { + match: { userMessage: "gemini-truncate" }, + response: { content: "ABCDEFGHIJKLMNO" }, // 15 chars, chunkSize 3 => 5 chunks + chunkSize: 3, + latency: 5, + truncateAfterChunks: 2, + }; + instance = await createServer([fixture]); + const res = await postPartial( + `${instance.url}/v1beta/models/gemini-2.0-flash:streamGenerateContent?alt=sse`, + { + contents: [{ role: "user", parts: [{ text: "gemini-truncate" }] }], + }, + ); + + // No STOP finishReason in the truncated stream + expect(res.body).not.toContain('"STOP"'); + + // Journal records interruption + await new Promise((r) => setTimeout(r, 50)); + const entry = instance.journal.getLast(); + expect(entry).not.toBeNull(); + expect(entry!.response.interrupted).toBe(true); + expect(entry!.response.interruptReason).toBe("truncateAfterChunks"); + }); + + it("HTTP Responses API /v1/responses with truncateAfterChunks stops stream early", async () => { + // Responses API SSE ends with response.completed event. + // With truncateAfterChunks: 2, that terminal event never appears. + const fixture: Fixture = { + match: { userMessage: "responses-truncate" }, + response: { content: "ABCDEFGHIJKLMNO" }, // 15 chars, chunkSize 3 => 5 deltas + chunkSize: 3, + latency: 5, + truncateAfterChunks: 2, + }; + instance = await createServer([fixture]); + const res = await postPartial(`${instance.url}/v1/responses`, { + model: "gpt-4o", + stream: true, + input: [{ role: "user", content: "responses-truncate" }], + }); + + // No response.completed event — stream was cut short + expect(res.body).not.toContain("response.completed"); + + // Journal records interruption + await new Promise((r) => setTimeout(r, 50)); + const entry = instance.journal.getLast(); + expect(entry).not.toBeNull(); + expect(entry!.response.interrupted).toBe(true); + expect(entry!.response.interruptReason).toBe("truncateAfterChunks"); + }); +}); diff --git a/src/__tests__/sse-writer.test.ts b/src/__tests__/sse-writer.test.ts index 2255769..ebca42d 100644 --- a/src/__tests__/sse-writer.test.ts +++ b/src/__tests__/sse-writer.test.ts @@ -1,7 +1,7 @@ -import { describe, it, expect, vi } from "vitest"; +import { describe, it, expect, vi, afterEach } from "vitest"; import { PassThrough } from "node:stream"; import type * as http from "node:http"; -import { writeSSEStream, writeErrorResponse } from "../sse-writer.js"; +import { writeSSEStream, writeErrorResponse, delay } from "../sse-writer.js"; import type { SSEChunk } from "../types.js"; function makeMockResponse(): { @@ -165,6 +165,148 @@ describe("writeSSEStream", () => { }); }); +describe("delay", () => { + afterEach(() => { + vi.useRealTimers(); + }); + + it("resolves after the specified time", async () => { + vi.useFakeTimers(); + let resolved = false; + const p = delay(100).then(() => { + resolved = true; + }); + expect(resolved).toBe(false); + vi.advanceTimersByTime(100); + await p; + expect(resolved).toBe(true); + }); + + it("resolves immediately when ms is 0", async () => { + const start = Date.now(); + await delay(0); + // Should return synchronously (Promise.resolve()) + expect(Date.now() - start).toBeLessThan(50); + }); + + it("resolves early when signal is aborted", async () => { + vi.useFakeTimers(); + const controller = new AbortController(); + let resolved = false; + const p = delay(10000, controller.signal).then(() => { + resolved = true; + }); + + vi.advanceTimersByTime(50); + expect(resolved).toBe(false); + + controller.abort(); + await p; + expect(resolved).toBe(true); + }); + + it("resolves immediately for negative ms", async () => { + await delay(-5); + // no error + }); + + it("resolves immediately when signal is already aborted", async () => { + const controller = new AbortController(); + controller.abort(); + + let resolved = false; + const raceResult = await Promise.race([ + delay(5000, controller.signal).then(() => { + resolved = true; + return "delay"; + }), + new Promise((r) => setTimeout(() => r("timeout"), 100)), + ]); + + expect(raceResult).toBe("delay"); + expect(resolved).toBe(true); + }); +}); + +describe("writeSSEStream with StreamOptions", () => { + it("accepts options object (backward compatible)", async () => { + const { res, output } = makeMockResponse(); + const chunks = [makeChunk("id1", "hello")]; + const result = await writeSSEStream(res, chunks, { latency: 0 }); + expect(result).toBe(true); + expect(output()).toContain("data: [DONE]"); + }); + + it("returns true when stream completes normally", async () => { + const { res } = makeMockResponse(); + const result = await writeSSEStream(res, [makeChunk("id1", "A")]); + expect(result).toBe(true); + }); + + it("stops mid-stream on abort signal and returns false", async () => { + const { res, output } = makeMockResponse(); + const controller = new AbortController(); + + const chunks = [makeChunk("id1", "A"), makeChunk("id2", "B"), makeChunk("id3", "C")]; + + // Abort after first chunk is sent + let chunksSent = 0; + const result = await writeSSEStream(res, chunks, { + signal: controller.signal, + onChunkSent: () => { + chunksSent++; + if (chunksSent === 1) controller.abort(); + }, + }); + + expect(result).toBe(false); + const body = output(); + expect(body).toContain(JSON.stringify(chunks[0])); + // Should not contain [DONE] + expect(body).not.toContain("[DONE]"); + }); + + it("skips [DONE] when interrupted", async () => { + const { res, output } = makeMockResponse(); + const controller = new AbortController(); + + const chunks = [makeChunk("id1", "A"), makeChunk("id2", "B")]; + const result = await writeSSEStream(res, chunks, { + signal: controller.signal, + onChunkSent: () => { + controller.abort(); + }, + }); + + expect(result).toBe(false); + expect(output()).not.toContain("[DONE]"); + }); + + it("onChunkSent fires per chunk", async () => { + const { res } = makeMockResponse(); + const chunks = [makeChunk("id1", "A"), makeChunk("id2", "B"), makeChunk("id3", "C")]; + let count = 0; + const result = await writeSSEStream(res, chunks, { + onChunkSent: () => { + count++; + }, + }); + + expect(result).toBe(true); + expect(count).toBe(3); + }); + + it("returns true for numeric latency arg (backward compat)", async () => { + vi.useFakeTimers(); + const { res } = makeMockResponse(); + const promise = writeSSEStream(res, [makeChunk("id1", "A")], 10); + await vi.runAllTimersAsync(); + const result = await promise; + expect(result).toBe(true); + vi.useRealTimers(); + }); +}); + describe("writeErrorResponse", () => { it("writes the given status code", () => { const { res, status } = makeMockResponse(); diff --git a/src/__tests__/ws-gemini-live.test.ts b/src/__tests__/ws-gemini-live.test.ts index 87be080..652866f 100644 --- a/src/__tests__/ws-gemini-live.test.ts +++ b/src/__tests__/ws-gemini-live.test.ts @@ -252,6 +252,121 @@ describe("WebSocket Gemini Live BidiGenerateContent", () => { ws.close(); }); + it("truncateAfterChunks stops stream early, no turnComplete: true", async () => { + const truncFixture: Fixture = { + match: { userMessage: "truncate-gemini" }, + response: { content: "ABCDEFGHIJKLMNO" }, // 15 chars, chunkSize 3 => 5 chunks + chunkSize: 3, + latency: 5, + truncateAfterChunks: 2, + }; + instance = await createServer([truncFixture]); + const ws = await connectWebSocket(instance.url, GEMINI_WS_PATH); + + ws.send(setupMsg()); + await ws.waitForMessages(1); // setupComplete + + ws.send(clientContentMsg("truncate-gemini")); + + // Wait for connection to be destroyed + await ws.waitForClose(); + + // Small pause for server-side processing + await new Promise((r) => setTimeout(r, 50)); + + // Check that no message with turnComplete: true was sent + const raw = await ws.waitForMessages(1).catch(() => [] as string[]); + if (raw.length > 1) { + const chunks = raw.slice(1).map((r) => JSON.parse(r)); + const hasTurnComplete = chunks.some((c) => c.serverContent?.turnComplete === true); + expect(hasTurnComplete).toBe(false); + } + }); + + it("truncateAfterChunks records interrupted: true in journal", async () => { + const truncFixture: Fixture = { + match: { userMessage: "truncate-journal-gemini" }, + response: { content: "ABCDEFGHIJKLMNO" }, + chunkSize: 3, + latency: 5, + truncateAfterChunks: 2, + }; + instance = await createServer([truncFixture]); + const ws = await connectWebSocket(instance.url, GEMINI_WS_PATH); + + ws.send(setupMsg()); + await ws.waitForMessages(1); // setupComplete + + ws.send(clientContentMsg("truncate-journal-gemini")); + + // Wait for connection to be destroyed + await ws.waitForClose(); + + // Give server time to finalize journal + await new Promise((r) => setTimeout(r, 50)); + + const entry = instance.journal.getLast(); + expect(entry).not.toBeNull(); + expect(entry!.response.interrupted).toBe(true); + expect(entry!.response.interruptReason).toBe("truncateAfterChunks"); + }); + + // Gemini Live sends all tool calls in a single WS frame, so truncateAfterChunks: 1 + // interrupts after that frame is sent (preventing conversation history update). + it("truncateAfterChunks with toolCalls records interrupted: true in journal", async () => { + const truncFixture: Fixture = { + match: { userMessage: "truncate-tool-gemini" }, + response: { + toolCalls: [{ name: "get_weather", arguments: '{"city":"NYC"}' }], + }, + latency: 5, + truncateAfterChunks: 1, + }; + instance = await createServer([truncFixture]); + const ws = await connectWebSocket(instance.url, GEMINI_WS_PATH); + + ws.send(setupMsg()); + await ws.waitForMessages(1); // setupComplete + + ws.send(clientContentMsg("truncate-tool-gemini")); + + // Wait for connection to be destroyed + await ws.waitForClose(); + + // Give server time to finalize journal + await new Promise((r) => setTimeout(r, 50)); + + const entry = instance.journal.getLast(); + expect(entry).not.toBeNull(); + expect(entry!.response.interrupted).toBe(true); + expect(entry!.response.interruptReason).toBe("truncateAfterChunks"); + }); + + it("disconnectAfterMs interrupts stream and records in journal", async () => { + const fixture: Fixture = { + match: { userMessage: "disconnect-gemini" }, + response: { content: "ABCDEFGHIJKLMNOPQRSTUVWXYZ" }, + chunkSize: 1, + latency: 20, + disconnectAfterMs: 30, + }; + instance = await createServer([fixture]); + const ws = await connectWebSocket(instance.url, GEMINI_WS_PATH); + + ws.send(setupMsg()); + await ws.waitForMessages(1); // setupComplete + + ws.send(clientContentMsg("disconnect-gemini")); + + await ws.waitForClose(); + await new Promise((r) => setTimeout(r, 50)); + + const entry = instance.journal.getLast(); + expect(entry).not.toBeNull(); + expect(entry!.response.interrupted).toBe(true); + expect(entry!.response.interruptReason).toBe("disconnectAfterMs"); + }); + it("returns error when message sent before setup", async () => { instance = await createServer(allFixtures); const ws = await connectWebSocket(instance.url, GEMINI_WS_PATH); diff --git a/src/__tests__/ws-realtime.test.ts b/src/__tests__/ws-realtime.test.ts index c87a811..f6d801d 100644 --- a/src/__tests__/ws-realtime.test.ts +++ b/src/__tests__/ws-realtime.test.ts @@ -299,6 +299,129 @@ describe("WebSocket /v1/realtime", () => { ws.close(); }); + it("truncateAfterChunks stops text stream early, no response.done event", async () => { + const truncFixture: Fixture = { + match: { userMessage: "truncate-rt" }, + response: { content: "ABCDEFGHIJKLMNO" }, // 15 chars, chunkSize 3 => 5 delta chunks + chunkSize: 3, + latency: 5, + truncateAfterChunks: 2, + }; + instance = await createServer([truncFixture]); + const ws = await connectWebSocket(instance.url, "/v1/realtime"); + + await ws.waitForMessages(1); // session.created + + ws.send(conversationItemCreate("user", "truncate-rt")); + await ws.waitForMessages(2); // + conversation.item.created + + ws.send(responseCreate()); + + // Wait for connection to be destroyed + await ws.waitForClose(); + + // Small pause for server-side processing + await new Promise((r) => setTimeout(r, 50)); + + // The connection was destroyed, so whatever messages arrived should NOT include response.done + // We got at least session.created + conversation.item.created = 2 before the response + const raw = await ws.waitForMessages(2).catch(() => [] as string[]); + if (raw.length > 2) { + const responseEvents = parseEvents(raw.slice(2)); + const types = responseEvents.map((e) => e.type); + expect(types).not.toContain("response.done"); + } + }); + + it("truncateAfterChunks records interrupted: true in journal", async () => { + const truncFixture: Fixture = { + match: { userMessage: "truncate-journal-rt" }, + response: { content: "ABCDEFGHIJKLMNO" }, + chunkSize: 3, + latency: 5, + truncateAfterChunks: 2, + }; + instance = await createServer([truncFixture]); + const ws = await connectWebSocket(instance.url, "/v1/realtime"); + + await ws.waitForMessages(1); // session.created + + ws.send(conversationItemCreate("user", "truncate-journal-rt")); + await ws.waitForMessages(2); // + conversation.item.created + + ws.send(responseCreate()); + + // Wait for connection to be destroyed + await ws.waitForClose(); + + // Give server time to finalize journal + await new Promise((r) => setTimeout(r, 50)); + + const entry = instance.journal.getLast(); + expect(entry).not.toBeNull(); + expect(entry!.response.interrupted).toBe(true); + expect(entry!.response.interruptReason).toBe("truncateAfterChunks"); + }); + + it("truncateAfterChunks with toolCalls records interrupted: true in journal", async () => { + const truncFixture: Fixture = { + match: { userMessage: "truncate-tool-rt" }, + response: { + toolCalls: [{ name: "search", arguments: '{"query":"hello world test string"}' }], + }, + chunkSize: 3, + latency: 5, + truncateAfterChunks: 2, + }; + instance = await createServer([truncFixture]); + const ws = await connectWebSocket(instance.url, "/v1/realtime"); + + await ws.waitForMessages(1); // session.created + + ws.send(conversationItemCreate("user", "truncate-tool-rt")); + await ws.waitForMessages(2); // + conversation.item.created + + ws.send(responseCreate()); + + // Wait for connection to be destroyed + await ws.waitForClose(); + + // Give server time to finalize journal + await new Promise((r) => setTimeout(r, 50)); + + const entry = instance.journal.getLast(); + expect(entry).not.toBeNull(); + expect(entry!.response.interrupted).toBe(true); + expect(entry!.response.interruptReason).toBe("truncateAfterChunks"); + }); + + it("disconnectAfterMs interrupts stream and records in journal", async () => { + const fixture: Fixture = { + match: { userMessage: "disconnect-rt" }, + response: { content: "ABCDEFGHIJKLMNOPQRSTUVWXYZ" }, + chunkSize: 1, + latency: 20, + disconnectAfterMs: 30, + }; + instance = await createServer([fixture]); + const ws = await connectWebSocket(instance.url, "/v1/realtime"); + + await ws.waitForMessages(1); // session.created + + ws.send(conversationItemCreate("user", "disconnect-rt")); + await ws.waitForMessages(2); // + conversation.item.created + + ws.send(responseCreate()); + + await ws.waitForClose(); + await new Promise((r) => setTimeout(r, 50)); + + const entry = instance.journal.getLast(); + expect(entry).not.toBeNull(); + expect(entry!.response.interrupted).toBe(true); + expect(entry!.response.interruptReason).toBe("disconnectAfterMs"); + }); + it("accumulates conversation state across multiple response.create calls", async () => { instance = await createServer(allFixtures); const ws = await connectWebSocket(instance.url, "/v1/realtime"); diff --git a/src/__tests__/ws-responses.test.ts b/src/__tests__/ws-responses.test.ts index 457ce14..7a6aebd 100644 --- a/src/__tests__/ws-responses.test.ts +++ b/src/__tests__/ws-responses.test.ts @@ -236,4 +236,108 @@ describe("WebSocket /v1/responses", () => { "Upgrade failed", ); }); + + it("truncateAfterChunks stops stream early, no response.completed event", async () => { + const truncFixture: Fixture = { + match: { userMessage: "truncate-ws" }, + response: { content: "ABCDEFGHIJKLMNO" }, // 15 chars, chunkSize 3 => 5 content chunks + chunkSize: 3, + latency: 5, + truncateAfterChunks: 2, + }; + instance = await createServer([truncFixture]); + const ws = await connectWebSocket(instance.url, "/v1/responses"); + + ws.send(responseCreateMsg("truncate-ws")); + + // Wait for the connection to be destroyed + await ws.waitForClose(); + + // Small pause to ensure server-side processing completed + await new Promise((r) => setTimeout(r, 50)); + + // Collect whatever messages were received + // We should have some events but NOT the response.completed event + const raw = await ws.waitForMessages(1).catch(() => [] as string[]); + // If we got messages, verify no response.completed + if (raw.length > 0) { + const events = parseEvents(raw); + const types = events.map((e) => e.type); + expect(types).not.toContain("response.completed"); + } + }); + + it("truncateAfterChunks records interrupted: true in journal", async () => { + const truncFixture: Fixture = { + match: { userMessage: "truncate-journal-ws" }, + response: { content: "ABCDEFGHIJKLMNO" }, + chunkSize: 3, + latency: 5, + truncateAfterChunks: 2, + }; + instance = await createServer([truncFixture]); + const ws = await connectWebSocket(instance.url, "/v1/responses"); + + ws.send(responseCreateMsg("truncate-journal-ws")); + + // Wait for the connection to be destroyed + await ws.waitForClose(); + + // Give server time to finalize journal + await new Promise((r) => setTimeout(r, 50)); + + const entry = instance.journal.getLast(); + expect(entry).not.toBeNull(); + expect(entry!.response.interrupted).toBe(true); + expect(entry!.response.interruptReason).toBe("truncateAfterChunks"); + }); + + it("truncateAfterChunks with toolCalls records interrupted: true in journal", async () => { + const truncFixture: Fixture = { + match: { userMessage: "truncate-tool-ws" }, + response: { + toolCalls: [{ name: "search", arguments: '{"query":"hello world test string"}' }], + }, + chunkSize: 3, + latency: 5, + truncateAfterChunks: 2, + }; + instance = await createServer([truncFixture]); + const ws = await connectWebSocket(instance.url, "/v1/responses"); + + ws.send(responseCreateMsg("truncate-tool-ws")); + + // Wait for the connection to be destroyed + await ws.waitForClose(); + + // Give server time to finalize journal + await new Promise((r) => setTimeout(r, 50)); + + const entry = instance.journal.getLast(); + expect(entry).not.toBeNull(); + expect(entry!.response.interrupted).toBe(true); + expect(entry!.response.interruptReason).toBe("truncateAfterChunks"); + }); + + it("disconnectAfterMs interrupts stream and records in journal", async () => { + const fixture: Fixture = { + match: { userMessage: "disconnect-ws" }, + response: { content: "ABCDEFGHIJKLMNOPQRSTUVWXYZ" }, + chunkSize: 1, + latency: 20, + disconnectAfterMs: 30, + }; + instance = await createServer([fixture]); + const ws = await connectWebSocket(instance.url, "/v1/responses"); + + ws.send(responseCreateMsg("disconnect-ws")); + + await ws.waitForClose(); + await new Promise((r) => setTimeout(r, 50)); + + const entry = instance.journal.getLast(); + expect(entry).not.toBeNull(); + expect(entry!.response.interrupted).toBe(true); + expect(entry!.response.interruptReason).toBe("disconnectAfterMs"); + }); }); diff --git a/src/fixture-loader.ts b/src/fixture-loader.ts index 298804c..8c154b7 100644 --- a/src/fixture-loader.ts +++ b/src/fixture-loader.ts @@ -13,6 +13,10 @@ function entryToFixture(entry: FixtureFileEntry): Fixture { response: entry.response, ...(entry.latency !== undefined && { latency: entry.latency }), ...(entry.chunkSize !== undefined && { chunkSize: entry.chunkSize }), + ...(entry.truncateAfterChunks !== undefined && { + truncateAfterChunks: entry.truncateAfterChunks, + }), + ...(entry.disconnectAfterMs !== undefined && { disconnectAfterMs: entry.disconnectAfterMs }), }; } diff --git a/src/gemini.ts b/src/gemini.ts index 6165e4b..3c81f0b 100644 --- a/src/gemini.ts +++ b/src/gemini.ts @@ -21,7 +21,8 @@ import { generateToolCallId, } from "./helpers.js"; import { matchFixture } from "./router.js"; -import { writeErrorResponse } from "./sse-writer.js"; +import { writeErrorResponse, delay } from "./sse-writer.js"; +import { createInterruptionSignal } from "./interruption.js"; import type { Journal } from "./journal.js"; // ─── Gemini request types ─────────────────────────────────────────────────── @@ -316,30 +317,42 @@ function buildGeminiToolCallResponse(toolCalls: ToolCall[]): GeminiResponseChunk // ─── SSE writer for Gemini streaming ──────────────────────────────────────── -function delay(ms: number): Promise { - return new Promise((resolve) => setTimeout(resolve, ms)); +interface GeminiStreamOptions { + latency?: number; + signal?: AbortSignal; + onChunkSent?: () => void; } async function writeGeminiSSEStream( res: http.ServerResponse, chunks: GeminiResponseChunk[], - latency = 0, -): Promise { - if (res.writableEnded) return; + optionsOrLatency?: number | GeminiStreamOptions, +): Promise { + const opts: GeminiStreamOptions = + typeof optionsOrLatency === "number" ? { latency: optionsOrLatency } : (optionsOrLatency ?? {}); + const latency = opts.latency ?? 0; + const signal = opts.signal; + const onChunkSent = opts.onChunkSent; + + if (res.writableEnded) return true; res.setHeader("Content-Type", "text/event-stream"); res.setHeader("Cache-Control", "no-cache"); res.setHeader("Connection", "keep-alive"); for (const chunk of chunks) { - if (latency > 0) await delay(latency); - if (res.writableEnded) return; + if (latency > 0) await delay(latency, signal); + if (signal?.aborted) return false; + if (res.writableEnded) return true; // Gemini uses data-only SSE (no event: prefix, no [DONE]) res.write(`data: ${JSON.stringify(chunk)}\n\n`); + onChunkSent?.(); + if (signal?.aborted) return false; } if (!res.writableEnded) { res.end(); } + return true; } // ─── Request handler ──────────────────────────────────────────────────────── @@ -423,7 +436,7 @@ export async function handleGemini( // Text response if (isTextResponse(response)) { - journal.add({ + const journalEntry = journal.add({ method: req.method ?? "POST", path, headers: {}, @@ -436,14 +449,25 @@ export async function handleGemini( res.end(JSON.stringify(body)); } else { const chunks = buildGeminiTextStreamChunks(response.content, chunkSize); - await writeGeminiSSEStream(res, chunks, latency); + const interruption = createInterruptionSignal(fixture); + const completed = await writeGeminiSSEStream(res, chunks, { + latency, + signal: interruption?.signal, + onChunkSent: interruption?.tick, + }); + if (!completed) { + if (!res.writableEnded) res.destroy(); + journalEntry.response.interrupted = true; + journalEntry.response.interruptReason = interruption?.reason(); + } + interruption?.cleanup(); } return; } // Tool call response if (isToolCallResponse(response)) { - journal.add({ + const journalEntry = journal.add({ method: req.method ?? "POST", path, headers: {}, @@ -456,7 +480,18 @@ export async function handleGemini( res.end(JSON.stringify(body)); } else { const chunks = buildGeminiToolCallStreamChunks(response.toolCalls); - await writeGeminiSSEStream(res, chunks, latency); + const interruption = createInterruptionSignal(fixture); + const completed = await writeGeminiSSEStream(res, chunks, { + latency, + signal: interruption?.signal, + onChunkSent: interruption?.tick, + }); + if (!completed) { + if (!res.writableEnded) res.destroy(); + journalEntry.response.interrupted = true; + journalEntry.response.interruptReason = interruption?.reason(); + } + interruption?.cleanup(); } return; } diff --git a/src/index.ts b/src/index.ts index 9a25501..80eb6ef 100644 --- a/src/index.ts +++ b/src/index.ts @@ -35,8 +35,13 @@ export { buildToolCallChunks, } from "./helpers.js"; +// Interruption +export { createInterruptionSignal } from "./interruption.js"; +export type { InterruptionControl } from "./interruption.js"; + // SSE -export { writeSSEStream, writeErrorResponse } from "./sse-writer.js"; +export { writeSSEStream, writeErrorResponse, delay } from "./sse-writer.js"; +export type { StreamOptions } from "./sse-writer.js"; // Types export type { diff --git a/src/interruption.ts b/src/interruption.ts new file mode 100644 index 0000000..3b34299 --- /dev/null +++ b/src/interruption.ts @@ -0,0 +1,54 @@ +import type { Fixture } from "./types.js"; + +export interface InterruptionControl { + signal: AbortSignal; + tick(): void; + cleanup(): void; + reason(): string | undefined; +} + +export function createInterruptionSignal(fixture: Fixture): InterruptionControl | null { + const { truncateAfterChunks, disconnectAfterMs } = fixture; + + if (truncateAfterChunks === undefined && disconnectAfterMs === undefined) { + return null; + } + + const controller = new AbortController(); + let abortReason: string | undefined; + let chunkCount = 0; + let timer: ReturnType | undefined; + + if (disconnectAfterMs !== undefined) { + timer = setTimeout(() => { + if (!controller.signal.aborted) { + abortReason = "disconnectAfterMs"; + controller.abort(); + } + }, disconnectAfterMs); + } + + return { + signal: controller.signal, + + tick() { + if (controller.signal.aborted) return; + chunkCount++; + if (truncateAfterChunks !== undefined && chunkCount >= truncateAfterChunks) { + abortReason = "truncateAfterChunks"; + controller.abort(); + } + }, + + cleanup() { + if (timer !== undefined) { + clearTimeout(timer); + timer = undefined; + } + }, + + reason() { + return abortReason; + }, + }; +} diff --git a/src/llmock.ts b/src/llmock.ts index 404d3fd..d372e2f 100644 --- a/src/llmock.ts +++ b/src/llmock.ts @@ -55,7 +55,12 @@ export class LLMock { on( match: FixtureMatch, response: FixtureResponse, - opts?: { latency?: number; chunkSize?: number }, + opts?: { + latency?: number; + chunkSize?: number; + truncateAfterChunks?: number; + disconnectAfterMs?: number; + }, ): this { return this.addFixture({ match, @@ -67,7 +72,12 @@ export class LLMock { onMessage( pattern: string | RegExp, response: FixtureResponse, - opts?: { latency?: number; chunkSize?: number }, + opts?: { + latency?: number; + chunkSize?: number; + truncateAfterChunks?: number; + disconnectAfterMs?: number; + }, ): this { return this.on({ userMessage: pattern }, response, opts); } @@ -75,7 +85,12 @@ export class LLMock { onToolCall( name: string, response: FixtureResponse, - opts?: { latency?: number; chunkSize?: number }, + opts?: { + latency?: number; + chunkSize?: number; + truncateAfterChunks?: number; + disconnectAfterMs?: number; + }, ): this { return this.on({ toolName: name }, response, opts); } @@ -83,7 +98,12 @@ export class LLMock { onToolResult( id: string, response: FixtureResponse, - opts?: { latency?: number; chunkSize?: number }, + opts?: { + latency?: number; + chunkSize?: number; + truncateAfterChunks?: number; + disconnectAfterMs?: number; + }, ): this { return this.on({ toolCallId: id }, response, opts); } diff --git a/src/messages.ts b/src/messages.ts index a401220..0879a12 100644 --- a/src/messages.ts +++ b/src/messages.ts @@ -22,7 +22,8 @@ import { isErrorResponse, } from "./helpers.js"; import { matchFixture } from "./router.js"; -import { writeErrorResponse } from "./sse-writer.js"; +import { writeErrorResponse, delay } from "./sse-writer.js"; +import { createInterruptionSignal } from "./interruption.js"; import type { Journal } from "./journal.js"; // ─── Claude Messages API request types ────────────────────────────────────── @@ -367,29 +368,41 @@ function buildClaudeToolCallResponse(toolCalls: ToolCall[], model: string): obje // ─── SSE writer for Claude Messages API ───────────────────────────────────── -function delay(ms: number): Promise { - return new Promise((resolve) => setTimeout(resolve, ms)); +interface ClaudeStreamOptions { + latency?: number; + signal?: AbortSignal; + onChunkSent?: () => void; } async function writeClaudeSSEStream( res: http.ServerResponse, events: ClaudeSSEEvent[], - latency = 0, -): Promise { - if (res.writableEnded) return; + optionsOrLatency?: number | ClaudeStreamOptions, +): Promise { + const opts: ClaudeStreamOptions = + typeof optionsOrLatency === "number" ? { latency: optionsOrLatency } : (optionsOrLatency ?? {}); + const latency = opts.latency ?? 0; + const signal = opts.signal; + const onChunkSent = opts.onChunkSent; + + if (res.writableEnded) return true; res.setHeader("Content-Type", "text/event-stream"); res.setHeader("Cache-Control", "no-cache"); res.setHeader("Connection", "keep-alive"); for (const event of events) { - if (latency > 0) await delay(latency); - if (res.writableEnded) return; + if (latency > 0) await delay(latency, signal); + if (signal?.aborted) return false; + if (res.writableEnded) return true; res.write(`event: ${event.type}\ndata: ${JSON.stringify(event)}\n\n`); + onChunkSent?.(); + if (signal?.aborted) return false; } if (!res.writableEnded) { res.end(); } + return true; } // ─── Request handler ──────────────────────────────────────────────────────── @@ -468,7 +481,7 @@ export async function handleMessages( // Text response if (isTextResponse(response)) { - journal.add({ + const journalEntry = journal.add({ method: req.method ?? "POST", path: req.url ?? "/v1/messages", headers: {}, @@ -481,14 +494,25 @@ export async function handleMessages( res.end(JSON.stringify(body)); } else { const events = buildClaudeTextStreamEvents(response.content, completionReq.model, chunkSize); - await writeClaudeSSEStream(res, events, latency); + const interruption = createInterruptionSignal(fixture); + const completed = await writeClaudeSSEStream(res, events, { + latency, + signal: interruption?.signal, + onChunkSent: interruption?.tick, + }); + if (!completed) { + if (!res.writableEnded) res.destroy(); + journalEntry.response.interrupted = true; + journalEntry.response.interruptReason = interruption?.reason(); + } + interruption?.cleanup(); } return; } // Tool call response if (isToolCallResponse(response)) { - journal.add({ + const journalEntry = journal.add({ method: req.method ?? "POST", path: req.url ?? "/v1/messages", headers: {}, @@ -505,7 +529,18 @@ export async function handleMessages( completionReq.model, chunkSize, ); - await writeClaudeSSEStream(res, events, latency); + const interruption = createInterruptionSignal(fixture); + const completed = await writeClaudeSSEStream(res, events, { + latency, + signal: interruption?.signal, + onChunkSent: interruption?.tick, + }); + if (!completed) { + if (!res.writableEnded) res.destroy(); + journalEntry.response.interrupted = true; + journalEntry.response.interruptReason = interruption?.reason(); + } + interruption?.cleanup(); } return; } diff --git a/src/responses.ts b/src/responses.ts index fdad644..1f40e54 100644 --- a/src/responses.ts +++ b/src/responses.ts @@ -22,7 +22,8 @@ import { isErrorResponse, } from "./helpers.js"; import { matchFixture } from "./router.js"; -import { writeErrorResponse } from "./sse-writer.js"; +import { writeErrorResponse, delay } from "./sse-writer.js"; +import { createInterruptionSignal } from "./interruption.js"; import type { Journal } from "./journal.js"; // ─── Responses API request types ──────────────────────────────────────────── @@ -442,29 +443,41 @@ function buildToolCallResponse(toolCalls: ToolCall[], model: string): object { // ─── SSE writer for Responses API ─────────────────────────────────────────── -function delay(ms: number): Promise { - return new Promise((resolve) => setTimeout(resolve, ms)); +interface ResponsesStreamOptions { + latency?: number; + signal?: AbortSignal; + onChunkSent?: () => void; } async function writeResponsesSSEStream( res: http.ServerResponse, events: ResponsesSSEEvent[], - latency = 0, -): Promise { - if (res.writableEnded) return; + optionsOrLatency?: number | ResponsesStreamOptions, +): Promise { + const opts: ResponsesStreamOptions = + typeof optionsOrLatency === "number" ? { latency: optionsOrLatency } : (optionsOrLatency ?? {}); + const latency = opts.latency ?? 0; + const signal = opts.signal; + const onChunkSent = opts.onChunkSent; + + if (res.writableEnded) return true; res.setHeader("Content-Type", "text/event-stream"); res.setHeader("Cache-Control", "no-cache"); res.setHeader("Connection", "keep-alive"); for (const event of events) { - if (latency > 0) await delay(latency); - if (res.writableEnded) return; + if (latency > 0) await delay(latency, signal); + if (signal?.aborted) return false; + if (res.writableEnded) return true; res.write(`event: ${event.type}\ndata: ${JSON.stringify(event)}\n\n`); + onChunkSent?.(); + if (signal?.aborted) return false; } if (!res.writableEnded) { res.end(); } + return true; } // ─── Request handler ──────────────────────────────────────────────────────── @@ -541,7 +554,7 @@ export async function handleResponses( // Text response if (isTextResponse(response)) { - journal.add({ + const journalEntry = journal.add({ method: req.method ?? "POST", path: req.url ?? "/v1/responses", headers: {}, @@ -554,14 +567,25 @@ export async function handleResponses( res.end(JSON.stringify(body)); } else { const events = buildTextStreamEvents(response.content, completionReq.model, chunkSize); - await writeResponsesSSEStream(res, events, latency); + const interruption = createInterruptionSignal(fixture); + const completed = await writeResponsesSSEStream(res, events, { + latency, + signal: interruption?.signal, + onChunkSent: interruption?.tick, + }); + if (!completed) { + if (!res.writableEnded) res.destroy(); + journalEntry.response.interrupted = true; + journalEntry.response.interruptReason = interruption?.reason(); + } + interruption?.cleanup(); } return; } // Tool call response if (isToolCallResponse(response)) { - journal.add({ + const journalEntry = journal.add({ method: req.method ?? "POST", path: req.url ?? "/v1/responses", headers: {}, @@ -574,7 +598,18 @@ export async function handleResponses( res.end(JSON.stringify(body)); } else { const events = buildToolCallStreamEvents(response.toolCalls, completionReq.model, chunkSize); - await writeResponsesSSEStream(res, events, latency); + const interruption = createInterruptionSignal(fixture); + const completed = await writeResponsesSSEStream(res, events, { + latency, + signal: interruption?.signal, + onChunkSent: interruption?.tick, + }); + if (!completed) { + if (!res.writableEnded) res.destroy(); + journalEntry.response.interrupted = true; + journalEntry.response.interruptReason = interruption?.reason(); + } + interruption?.cleanup(); } return; } diff --git a/src/server.ts b/src/server.ts index 3dc2b74..2b04405 100644 --- a/src/server.ts +++ b/src/server.ts @@ -3,6 +3,7 @@ import type { Fixture, ChatCompletionRequest, MockServerOptions } from "./types. import { Journal } from "./journal.js"; import { matchFixture } from "./router.js"; import { writeSSEStream, writeErrorResponse } from "./sse-writer.js"; +import { createInterruptionSignal } from "./interruption.js"; import { buildTextChunks, buildToolCallChunks, @@ -173,7 +174,7 @@ async function handleCompletions( // Text response if (isTextResponse(response)) { - journal.add({ + const journalEntry = journal.add({ method: req.method ?? "POST", path: req.url ?? COMPLETIONS_PATH, headers: flattenHeaders(req.headers), @@ -186,14 +187,25 @@ async function handleCompletions( res.end(JSON.stringify(completion)); } else { const chunks = buildTextChunks(response.content, body.model, chunkSize); - await writeSSEStream(res, chunks, latency); + const interruption = createInterruptionSignal(fixture); + const completed = await writeSSEStream(res, chunks, { + latency, + signal: interruption?.signal, + onChunkSent: interruption?.tick, + }); + if (!completed) { + if (!res.writableEnded) res.destroy(); + journalEntry.response.interrupted = true; + journalEntry.response.interruptReason = interruption?.reason(); + } + interruption?.cleanup(); } return; } // Tool call response if (isToolCallResponse(response)) { - journal.add({ + const journalEntry = journal.add({ method: req.method ?? "POST", path: req.url ?? COMPLETIONS_PATH, headers: flattenHeaders(req.headers), @@ -206,7 +218,18 @@ async function handleCompletions( res.end(JSON.stringify(completion)); } else { const chunks = buildToolCallChunks(response.toolCalls, body.model, chunkSize); - await writeSSEStream(res, chunks, latency); + const interruption = createInterruptionSignal(fixture); + const completed = await writeSSEStream(res, chunks, { + latency, + signal: interruption?.signal, + onChunkSent: interruption?.tick, + }); + if (!completed) { + if (!res.writableEnded) res.destroy(); + journalEntry.response.interrupted = true; + journalEntry.response.interruptReason = interruption?.reason(); + } + interruption?.cleanup(); } return; } diff --git a/src/sse-writer.ts b/src/sse-writer.ts index 3d9bc64..88f845e 100644 --- a/src/sse-writer.ts +++ b/src/sse-writer.ts @@ -1,32 +1,59 @@ import type * as http from "node:http"; import type { SSEChunk } from "./types.js"; -function delay(ms: number): Promise { - return new Promise((resolve) => setTimeout(resolve, ms)); +export function delay(ms: number, signal?: AbortSignal): Promise { + if (ms <= 0 || signal?.aborted) return Promise.resolve(); + return new Promise((resolve) => { + const timer = setTimeout(resolve, ms); + signal?.addEventListener( + "abort", + () => { + clearTimeout(timer); + resolve(); + }, + { once: true }, + ); + }); +} + +export interface StreamOptions { + latency?: number; + signal?: AbortSignal; + onChunkSent?: () => void; } export async function writeSSEStream( res: http.ServerResponse, chunks: SSEChunk[], - latency = 0, -): Promise { - if (res.writableEnded) return; + optionsOrLatency?: number | StreamOptions, +): Promise { + const opts: StreamOptions = + typeof optionsOrLatency === "number" ? { latency: optionsOrLatency } : (optionsOrLatency ?? {}); + const latency = opts.latency ?? 0; + const signal = opts.signal; + const onChunkSent = opts.onChunkSent; + + if (res.writableEnded) return true; res.setHeader("Content-Type", "text/event-stream"); res.setHeader("Cache-Control", "no-cache"); res.setHeader("Connection", "keep-alive"); for (const chunk of chunks) { if (latency > 0) { - await delay(latency); + await delay(latency, signal); } - if (res.writableEnded) return; + if (signal?.aborted) return false; + if (res.writableEnded) return true; res.write(`data: ${JSON.stringify(chunk)}\n\n`); + onChunkSent?.(); + if (signal?.aborted) return false; } if (!res.writableEnded) { res.write("data: [DONE]\n\n"); res.end(); } + return true; } export function writeErrorResponse(res: http.ServerResponse, status: number, body: string): void { diff --git a/src/types.ts b/src/types.ts index 4c19be6..598aea1 100644 --- a/src/types.ts +++ b/src/types.ts @@ -79,6 +79,8 @@ export interface Fixture { response: FixtureResponse; latency?: number; chunkSize?: number; + truncateAfterChunks?: number; + disconnectAfterMs?: number; } // Fixture file format (JSON on disk) @@ -98,6 +100,8 @@ export interface FixtureFileEntry { response: FixtureResponse; latency?: number; chunkSize?: number; + truncateAfterChunks?: number; + disconnectAfterMs?: number; } // Request journal @@ -109,7 +113,12 @@ export interface JournalEntry { path: string; headers: Record; body: ChatCompletionRequest; - response: { status: number; fixture: Fixture | null }; + response: { + status: number; + fixture: Fixture | null; + interrupted?: boolean; + interruptReason?: string; + }; } // SSE chunk types (OpenAI format) diff --git a/src/ws-framing.ts b/src/ws-framing.ts index 22fdd0f..643c74e 100644 --- a/src/ws-framing.ts +++ b/src/ws-framing.ts @@ -77,6 +77,15 @@ export class WebSocketConnection extends EventEmitter { }, 100); } + destroy(): void { + if (this.closed) return; + this.closed = true; + if (!this.socket.destroyed) { + this.socket.destroy(); + } + this.emit("close", 1006, "Connection destroyed"); + } + get isClosed(): boolean { return this.closed; } diff --git a/src/ws-gemini-live.ts b/src/ws-gemini-live.ts index 510f2ca..7aac86d 100644 --- a/src/ws-gemini-live.ts +++ b/src/ws-gemini-live.ts @@ -9,6 +9,8 @@ import type { Fixture, ChatMessage, ChatCompletionRequest, ToolDefinition } from "./types.js"; import { matchFixture } from "./router.js"; import { isTextResponse, isToolCallResponse, isErrorResponse } from "./helpers.js"; +import { createInterruptionSignal } from "./interruption.js"; +import { delay } from "./sse-writer.js"; import type { Journal } from "./journal.js"; import type { WebSocketConnection } from "./ws-framing.js"; @@ -73,10 +75,6 @@ interface SessionState { // ─── Helpers ──────────────────────────────────────────────────────────────── -function delay(ms: number): Promise { - return new Promise((resolve) => setTimeout(resolve, ms)); -} - const WS_PATH = "/ws/google.ai.generativelanguage.v1beta.GenerativeService.BidiGenerateContent"; /** @@ -314,7 +312,7 @@ async function processMessage( // Text response — stream chunks with serverContent if (isTextResponse(response)) { - journal.add({ + const journalEntry = journal.add({ method: "WS", path, headers: {}, @@ -343,10 +341,17 @@ async function processMessage( chunks.push(content.slice(i, i + chunkSize)); } + const interruption = createInterruptionSignal(fixture); + let interrupted = false; + for (let i = 0; i < chunks.length; i++) { - if (ws.isClosed) return; - if (latency > 0) await delay(latency); - if (ws.isClosed) return; + if (ws.isClosed) break; + if (latency > 0) await delay(latency, interruption?.signal); + if (interruption?.signal.aborted) { + interrupted = true; + break; + } + if (ws.isClosed) break; const isLast = i === chunks.length - 1; ws.send( @@ -357,8 +362,23 @@ async function processMessage( }, }), ); + interruption?.tick(); + if (interruption?.signal.aborted) { + interrupted = true; + break; + } + } + + if (interrupted) { + ws.destroy(); + journalEntry.response.interrupted = true; + journalEntry.response.interruptReason = interruption?.reason(); + interruption?.cleanup(); + return; } + interruption?.cleanup(); + // Add assistant response to conversation history session.conversationHistory.push({ role: "assistant", content }); return; @@ -366,7 +386,7 @@ async function processMessage( // Tool call response if (isToolCallResponse(response)) { - journal.add({ + const journalEntry = journal.add({ method: "WS", path, headers: {}, @@ -374,9 +394,24 @@ async function processMessage( response: { status: 200, fixture }, }); - if (ws.isClosed) return; - if (latency > 0) await delay(latency); - if (ws.isClosed) return; + const interruption = createInterruptionSignal(fixture); + + if (ws.isClosed) { + interruption?.cleanup(); + return; + } + if (latency > 0) await delay(latency, interruption?.signal); + if (interruption?.signal.aborted) { + ws.destroy(); + journalEntry.response.interrupted = true; + journalEntry.response.interruptReason = interruption?.reason(); + interruption?.cleanup(); + return; + } + if (ws.isClosed) { + interruption?.cleanup(); + return; + } const functionCalls = response.toolCalls.map((tc, i) => { let argsObj: Record; @@ -396,6 +431,17 @@ async function processMessage( }); ws.send(JSON.stringify({ toolCall: { functionCalls } })); + interruption?.tick(); + + if (interruption?.signal.aborted) { + ws.destroy(); + journalEntry.response.interrupted = true; + journalEntry.response.interruptReason = interruption?.reason(); + interruption?.cleanup(); + return; + } + + interruption?.cleanup(); // Add assistant tool_calls to conversation history session.conversationHistory.push({ diff --git a/src/ws-realtime.ts b/src/ws-realtime.ts index 8a4a64d..91d97e0 100644 --- a/src/ws-realtime.ts +++ b/src/ws-realtime.ts @@ -15,13 +15,11 @@ import { isToolCallResponse, isErrorResponse, } from "./helpers.js"; +import { createInterruptionSignal } from "./interruption.js"; +import { delay } from "./sse-writer.js"; import type { Journal } from "./journal.js"; import type { WebSocketConnection } from "./ws-framing.js"; -function delay(ms: number): Promise { - return new Promise((resolve) => setTimeout(resolve, ms)); -} - // ─── Realtime protocol types ──────────────────────────────────────────────── interface RealtimeItem { @@ -335,7 +333,7 @@ async function handleResponseCreate( // ── Text response ─────────────────────────────────────────────────── if (isTextResponse(response)) { - journal.add({ + const journalEntry = journal.add({ method: "WS", path: "/v1/realtime", headers: {}, @@ -383,10 +381,17 @@ async function handleResponseCreate( // response.text.delta (chunked) const content = response.content; + const interruption = createInterruptionSignal(fixture); + let interrupted = false; + for (let i = 0; i < content.length; i += chunkSize) { - if (ws.isClosed) return; - if (latency > 0) await delay(latency); - if (ws.isClosed) return; + if (ws.isClosed) break; + if (latency > 0) await delay(latency, interruption?.signal); + if (interruption?.signal.aborted) { + interrupted = true; + break; + } + if (ws.isClosed) break; const chunk = content.slice(i, i + chunkSize); ws.send( evt("response.text.delta", { @@ -397,8 +402,23 @@ async function handleResponseCreate( delta: chunk, }), ); + interruption?.tick(); + if (interruption?.signal.aborted) { + interrupted = true; + break; + } + } + + if (interrupted) { + ws.destroy(); + journalEntry.response.interrupted = true; + journalEntry.response.interruptReason = interruption?.reason(); + interruption?.cleanup(); + return; } + interruption?.cleanup(); + // response.text.done ws.send( evt("response.text.done", { @@ -449,7 +469,7 @@ async function handleResponseCreate( // ── Tool call response ────────────────────────────────────────────── if (isToolCallResponse(response)) { - journal.add({ + const journalEntry = journal.add({ method: "WS", path: "/v1/realtime", headers: {}, @@ -465,6 +485,8 @@ async function handleResponseCreate( ); const outputItems: unknown[] = []; + const interruption = createInterruptionSignal(fixture); + let interrupted = false; for (let tcIdx = 0; tcIdx < response.toolCalls.length; tcIdx++) { const tc = response.toolCalls[tcIdx]; @@ -497,9 +519,13 @@ async function handleResponseCreate( // response.function_call_arguments.delta (chunked) const args = tc.arguments; for (let i = 0; i < args.length; i += chunkSize) { - if (ws.isClosed) return; - if (latency > 0) await delay(latency); - if (ws.isClosed) return; + if (ws.isClosed) break; + if (latency > 0) await delay(latency, interruption?.signal); + if (interruption?.signal.aborted) { + interrupted = true; + break; + } + if (ws.isClosed) break; const chunk = args.slice(i, i + chunkSize); ws.send( evt("response.function_call_arguments.delta", { @@ -510,8 +536,15 @@ async function handleResponseCreate( delta: chunk, }), ); + interruption?.tick(); + if (interruption?.signal.aborted) { + interrupted = true; + break; + } } + if (interrupted) break; + // response.function_call_arguments.done ws.send( evt("response.function_call_arguments.done", { @@ -535,6 +568,16 @@ async function handleResponseCreate( outputItems.push(outputItem); } + if (interrupted) { + ws.destroy(); + journalEntry.response.interrupted = true; + journalEntry.response.interruptReason = interruption?.reason(); + interruption?.cleanup(); + return; + } + + interruption?.cleanup(); + // response.done ws.send( evt("response.done", { diff --git a/src/ws-responses.ts b/src/ws-responses.ts index 3bedb53..5f9495d 100644 --- a/src/ws-responses.ts +++ b/src/ws-responses.ts @@ -15,13 +15,11 @@ import { type ResponsesSSEEvent, } from "./responses.js"; import { isTextResponse, isToolCallResponse, isErrorResponse } from "./helpers.js"; +import { createInterruptionSignal } from "./interruption.js"; +import { delay } from "./sse-writer.js"; import type { Journal } from "./journal.js"; import type { WebSocketConnection } from "./ws-framing.js"; -function delay(ms: number): Promise { - return new Promise((resolve) => setTimeout(resolve, ms)); -} - interface ResponseCreateMessage { type: "response.create"; response: { @@ -182,7 +180,7 @@ async function processMessage( // Text response if (isTextResponse(response)) { - journal.add({ + const journalEntry = journal.add({ method: "WS", path: "/v1/responses", headers: {}, @@ -190,13 +188,26 @@ async function processMessage( response: { status: 200, fixture }, }); const events = buildTextStreamEvents(response.content, completionReq.model, chunkSize); - await sendEvents(ws, events, latency); + const interruption = createInterruptionSignal(fixture); + const completed = await sendEvents( + ws, + events, + latency, + interruption?.signal, + interruption?.tick, + ); + if (!completed) { + ws.destroy(); + journalEntry.response.interrupted = true; + journalEntry.response.interruptReason = interruption?.reason(); + } + interruption?.cleanup(); return; } // Tool call response if (isToolCallResponse(response)) { - journal.add({ + const journalEntry = journal.add({ method: "WS", path: "/v1/responses", headers: {}, @@ -204,7 +215,20 @@ async function processMessage( response: { status: 200, fixture }, }); const events = buildToolCallStreamEvents(response.toolCalls, completionReq.model, chunkSize); - await sendEvents(ws, events, latency); + const interruption = createInterruptionSignal(fixture); + const completed = await sendEvents( + ws, + events, + latency, + interruption?.signal, + interruption?.tick, + ); + if (!completed) { + ws.destroy(); + journalEntry.response.interrupted = true; + journalEntry.response.interruptReason = interruption?.reason(); + } + interruption?.cleanup(); return; } @@ -227,11 +251,17 @@ async function sendEvents( ws: WebSocketConnection, events: ResponsesSSEEvent[], latency: number, -): Promise { + signal?: AbortSignal, + onChunkSent?: () => void, +): Promise { for (const event of events) { - if (ws.isClosed) return; - if (latency > 0) await delay(latency); - if (ws.isClosed) return; + if (ws.isClosed) return true; + if (latency > 0) await delay(latency, signal); + if (signal?.aborted) return false; + if (ws.isClosed) return true; ws.send(JSON.stringify(event)); + onChunkSent?.(); + if (signal?.aborted) return false; } + return true; }