diff --git a/docs/sdk/concepts/connecting-servers.mdx b/docs/sdk/concepts/connecting-servers.mdx index d83fd905c..eef38d276 100644 --- a/docs/sdk/concepts/connecting-servers.mdx +++ b/docs/sdk/concepts/connecting-servers.mdx @@ -176,7 +176,7 @@ const manager = new MCPClientManager({ await manager.connectToServer("myServer"); -// Get tools in AI SDK format +// Get tools for TestAgent const tools = await manager.getTools(); // Create agent with those tools diff --git a/docs/sdk/concepts/testing-with-llms.mdx b/docs/sdk/concepts/testing-with-llms.mdx index 6605c360c..228c42e9c 100644 --- a/docs/sdk/concepts/testing-with-llms.mdx +++ b/docs/sdk/concepts/testing-with-llms.mdx @@ -143,6 +143,43 @@ const agent = new TestAgent({ }); ``` +## Control Multi-Step Loops with stopWhen + +Use `stopWhen` when you want to stop the multi-step loop after a particular step completes: + +```typescript +import { hasToolCall } from "@mcpjam/sdk"; + +// Stop after the step where the tool is called +const result = await agent.prompt("Search for open tasks", { + stopWhen: hasToolCall("search_tasks"), +}); + +expect(result.hasToolCall("search_tasks")).toBe(true); +``` + + +`stopWhen` does not skip tool execution. It controls whether the prompt loop continues after the current step completes, and `TestAgent` also applies `stepCountIs(maxSteps)` as a safety guard. + + +## Bound Prompt Runtime with timeout + +Use `timeout` when you want to bound how long `TestAgent.prompt()` can run: + +```typescript +const result = await agent.prompt("Run a long workflow", { + timeout: { totalMs: 10_000, stepMs: 2_500 }, +}); + +if (result.hasError()) { + console.error(result.getError()); +} +``` + + +`timeout` accepts `number`, `totalMs`, `stepMs`, and `chunkMs`. In practice, `number` and `totalMs` cap the full prompt, `stepMs` caps each step, and `chunkMs` mainly matters in streaming flows. 
The runtime creates an internal abort signal, so tools can stop early if their implementation respects the provided `abortSignal`. + + ## Writing Assertions Use validators to assert tool call behavior: diff --git a/docs/sdk/reference/llm-providers.mdx b/docs/sdk/reference/llm-providers.mdx index bd8ff88d7..424e781af 100644 --- a/docs/sdk/reference/llm-providers.mdx +++ b/docs/sdk/reference/llm-providers.mdx @@ -384,7 +384,7 @@ const { provider, modelId } = parseLLMString("anthropic/claude-sonnet-4-20250514 ### createModelFromString() -Create a Vercel AI SDK model directly. +Create a provider model instance directly. ```typescript import { createModelFromString } from "@mcpjam/sdk"; diff --git a/docs/sdk/reference/prompt-result.mdx b/docs/sdk/reference/prompt-result.mdx index 0c015ac5a..b1e24daf4 100644 --- a/docs/sdk/reference/prompt-result.mdx +++ b/docs/sdk/reference/prompt-result.mdx @@ -328,7 +328,7 @@ getMessages(): CoreMessage[] #### Returns -`CoreMessage[]` - Vercel AI SDK message format. +`CoreMessage[]` - The full conversation message format used by the SDK. #### Example diff --git a/docs/sdk/reference/test-agent.mdx b/docs/sdk/reference/test-agent.mdx index bea039067..ae74533b3 100644 --- a/docs/sdk/reference/test-agent.mdx +++ b/docs/sdk/reference/test-agent.mdx @@ -4,7 +4,7 @@ description: "API reference for TestAgent" icon: "book" --- -The `TestAgent` class wraps LLM providers via the Vercel AI SDK, enabling you to run prompts with MCP tools. It handles the agentic loop and returns rich result objects. +The `TestAgent` class runs prompts with MCP tools enabled. It handles the multi-step prompt loop and returns rich result objects. ## Import @@ -76,6 +76,8 @@ prompt( | Property | Type | Description | |----------|------|-------------| | `context` | `PromptResult \| PromptResult[]` | Previous result(s) for multi-turn conversations | +| `stopWhen` | `StopCondition<ToolSet> \| Array<StopCondition<ToolSet>>` | Additional conditions for the multi-step prompt loop. 
Tools still execute normally. `TestAgent` always applies `stepCountIs(maxSteps)` as a safety guard. | +| `timeout` | `number \| { totalMs?: number; stepMs?: number; chunkMs?: number }` | Bounds prompt runtime. `number` and `totalMs` cap the full prompt, `stepMs` caps each generation step, and `chunkMs` is accepted for parity but is mainly relevant to streaming APIs. | #### Returns @@ -84,6 +86,8 @@ prompt( #### Example ```typescript +import { hasToolCall } from "@mcpjam/sdk"; + // Simple prompt const result = await agent.prompt("Add 2 and 3"); @@ -93,12 +97,31 @@ const r2 = await agent.prompt("Mark it complete", { context: r1 }); // Multiple context items const r3 = await agent.prompt("Show summary", { context: [r1, r2] }); + +// Stop the loop after the step where a tool is called +const r4 = await agent.prompt("Search for tasks", { + stopWhen: hasToolCall("search_tasks"), +}); +console.log(r4.hasToolCall("search_tasks")); + +// Bound prompt runtime +const r5 = await agent.prompt("Run a long workflow", { + timeout: { totalMs: 10_000, stepMs: 2_500 }, +}); + +if (r5.hasError()) { + console.error(r5.getError()); +} ``` `prompt()` never throws exceptions. Errors are captured in the `PromptResult`. Check `result.hasError()` to detect failures. + +`timeout` bounds prompt runtime. The runtime creates an internal abort signal, so tools can stop early if their implementation respects the provided `abortSignal`. If a tool ignores that signal, its underlying work may continue briefly after the prompt returns an error result. + + --- ## Model String Format @@ -233,6 +256,52 @@ Setting `maxSteps` too low may prevent complex tasks from completing. Setting it --- +## Control Multi-Step Loops with stopWhen + +Use `stopWhen` to control whether the agent starts another step after the current step completes. 
+ +```typescript +import { hasToolCall } from "@mcpjam/sdk"; + +// Stop after the step where "search_tasks" is called +const result = await agent.prompt("Find my open tasks", { + stopWhen: hasToolCall("search_tasks"), +}); + +expect(result.hasToolCall("search_tasks")).toBe(true); + +// Stop after any of multiple conditions +const result2 = await agent.prompt("Do something", { + stopWhen: [hasToolCall("tool_a"), hasToolCall("tool_b")], +}); +``` + + +`stopWhen` does not skip tool execution. It controls whether the prompt loop continues after the current step completes, and `TestAgent` also applies `stepCountIs(maxSteps)` as a safety guard. + + +--- + +## Bound Prompt Runtime with timeout + +Use `timeout` when you want to bound how long `TestAgent.prompt()` can run: + +```typescript +const result = await agent.prompt("Run a long workflow", { + timeout: 10_000, +}); + +const result2 = await agent.prompt("Run a long workflow", { + timeout: { totalMs: 10_000, stepMs: 2_500, chunkMs: 1_000 }, +}); +``` + + +`chunkMs` is accepted for parity, but it is mainly useful for streaming APIs. For `TestAgent.prompt()`, `number`, `totalMs`, and `stepMs` are the main settings to focus on. + + +--- + ## Complete Example ```typescript diff --git a/sdk/README.md b/sdk/README.md index cd7597c95..b01c0c24d 100644 --- a/sdk/README.md +++ b/sdk/README.md @@ -163,7 +163,7 @@ await manager.connectToServer("asana", { }, }); -// Get tools for AI SDK integration +// Get tools for TestAgent const tools = await manager.getToolsForAiSdk(["everything", "asana"]); // Direct MCP operations @@ -187,6 +187,8 @@ await manager.disconnectServer("everything"); Runs LLM prompts with MCP tool access. 
```ts +import { hasToolCall } from "@mcpjam/sdk"; + const agent = new TestAgent({ tools: await manager.getToolsForAiSdk(), model: "openai/gpt-4o", // provider/model format @@ -202,8 +204,24 @@ const result = await agent.prompt("Add 2 and 3"); // Multi-turn with context const r1 = await agent.prompt("Who am I?"); const r2 = await agent.prompt("List my projects", { context: [r1] }); + +// Stop the loop after the step where a tool is called +const r3 = await agent.prompt("Search tasks", { + stopWhen: hasToolCall("search_tasks"), +}); +r3.hasToolCall("search_tasks"); // true + +// Bound prompt runtime +const r4 = await agent.prompt("Run a long workflow", { + timeout: { totalMs: 10_000, stepMs: 2_500 }, +}); +r4.hasError(); // true if the prompt timed out ``` +`stopWhen` does not skip tool execution. It controls whether the prompt loop continues after the current step completes, and `TestAgent` also applies `stepCountIs(maxSteps)` as a safety guard. + +`timeout` bounds prompt runtime. `number` and `totalMs` cap the full prompt, `stepMs` caps each step, and `chunkMs` is accepted for parity but mainly matters in streaming flows. The runtime creates an internal abort signal, so tools can stop early if their implementation respects the provided `abortSignal`. + **Supported providers:** `openai`, `anthropic`, `azure`, `google`, `mistral`, `deepseek`, `ollama`, `openrouter`, `xai` diff --git a/sdk/skills/create-mcp-eval/SKILL.md b/sdk/skills/create-mcp-eval/SKILL.md index ac48df0ff..22f326256 100644 --- a/sdk/skills/create-mcp-eval/SKILL.md +++ b/sdk/skills/create-mcp-eval/SKILL.md @@ -172,7 +172,7 @@ await manager.connectToServer("server-id", { env: { API_KEY: "..." 
}, }); -// Get AI SDK-compatible tools for TestAgent +// Get tools for TestAgent const tools = await manager.getToolsForAiSdk(["server-id"]); // Cleanup @@ -185,6 +185,7 @@ await manager.disconnectAllServers(); ```typescript import { TestAgent } from "@mcpjam/sdk"; +import { hasToolCall } from "@mcpjam/sdk"; const agent = new TestAgent({ tools, // from manager.getToolsForAiSdk() @@ -200,6 +201,18 @@ const result = await agent.prompt("List all projects"); const r1 = await agent.prompt("Get my user profile"); const r2 = await agent.prompt("List workspaces for that user", { context: r1 }); +// Stop the loop after the step where a tool is called +const r3 = await agent.prompt("Search tasks", { + stopWhen: hasToolCall("search_tasks"), +}); +r3.hasToolCall("search_tasks"); // true + +// Bound prompt runtime +const r4 = await agent.prompt("Run a long workflow", { + timeout: { totalMs: 10_000, stepMs: 2_500 }, +}); +r4.hasError(); // true if the prompt timed out + // Mock agent for deterministic tests (no LLM needed) const mockAgent = TestAgent.mock(async (message) => PromptResult.from({ @@ -216,6 +229,10 @@ const mockAgent = TestAgent.mock(async (message) => ); ``` +`stopWhen` does not skip tool execution. It controls whether the prompt loop continues after the current step completes, and `TestAgent` also applies `stepCountIs(maxSteps)` as a safety guard. + +`timeout` bounds prompt runtime. `number` and `totalMs` cap the full prompt, `stepMs` caps each step, and `chunkMs` is accepted for parity but mainly matters in streaming flows. The runtime creates an internal abort signal, so tools can stop early if their implementation respects the provided `abortSignal`. 
+ + ### PromptResult — Inspect Agent Responses ```typescript @@ -998,4 +1015,3 @@ it("selects search_tasks", async () => { expect(result.hasToolCall("search_tasks")).toBe(true); }, 90_000); ``` - diff --git a/sdk/src/EvalAgent.ts b/sdk/src/EvalAgent.ts index 0945fc55b..8b9fa1c19 100644 --- a/sdk/src/EvalAgent.ts +++ b/sdk/src/EvalAgent.ts @@ -1,3 +1,4 @@ +import type { StopCondition, TimeoutConfiguration, ToolSet } from "ai"; import type { PromptResult } from "./PromptResult.js"; /** @@ -6,6 +7,45 @@ import type { PromptResult } from "./PromptResult.js"; export interface PromptOptions { /** Previous PromptResult(s) to include as conversation context for multi-turn conversations */ context?: PromptResult | PromptResult[]; + + /** + * Additional stop conditions for the agentic loop. + * Evaluated after each step completes (tools execute normally). + * `stepCountIs(maxSteps)` is always applied as a safety guard + * in addition to any conditions provided here. + * + * Import helpers like `hasToolCall` and `stepCountIs` from `"@mcpjam/sdk"`. + * + * @example + * ```typescript + * import { hasToolCall } from "@mcpjam/sdk"; + * + * // Stop the loop after the step where "search_tasks" is called + * const result = await agent.prompt("Find my tasks", { + * stopWhen: hasToolCall("search_tasks"), + * }); + * expect(result.hasToolCall("search_tasks")).toBe(true); + * + * // Multiple conditions (any one being true stops the loop) + * const result = await agent.prompt("Do something", { + * stopWhen: [hasToolCall("tool_a"), hasToolCall("tool_b")], + * }); + * ``` + */ + stopWhen?: StopCondition<ToolSet> | Array<StopCondition<ToolSet>>; + + /** + * Timeout for the prompt runtime. + * + * - `number`: total timeout for the entire prompt call in milliseconds + * - `{ totalMs }`: total timeout across all steps + * - `{ stepMs }`: timeout for each generation step + * - `{ chunkMs }`: accepted for parity and primarily relevant to streaming APIs + * + * The runtime creates an internal abort signal. 
Tools can stop early if they + respect the `abortSignal` passed to `execute()`. + */ + timeout?: TimeoutConfiguration; } /** diff --git a/sdk/src/TestAgent.ts b/sdk/src/TestAgent.ts index 8e1c80cb8..53400f208 100644 --- a/sdk/src/TestAgent.ts +++ b/sdk/src/TestAgent.ts @@ -3,7 +3,7 @@ */ import { generateText, stepCountIs, dynamicTool, jsonSchema } from "ai"; -import type { ToolSet, ModelMessage, UserModelMessage } from "ai"; +import type { StopCondition, ToolSet, ModelMessage, UserModelMessage } from "ai"; import { CallToolResultSchema } from "@modelcontextprotocol/sdk/types.js"; import { createModelFromString, parseLLMString } from "./model-factory.js"; import type { CreateModelOptions } from "./model-factory.js"; @@ -318,6 +318,19 @@ export class TestAgent implements EvalAgent { return instrumented; } + private resolveStopWhen( + stopWhen?: PromptOptions["stopWhen"] + ): Array<StopCondition<ToolSet>> { + const base = [stepCountIs(this.maxSteps)]; + + if (stopWhen == null) { + return base; + } + + const conditions = Array.isArray(stopWhen) ? stopWhen : [stopWhen]; + return [...base, ...conditions]; + } + /** * Build an array of ModelMessages from previous PromptResult(s) for multi-turn context. 
* @param context - Single PromptResult or array of PromptResults to include as context @@ -383,10 +396,13 @@ export class TestAgent implements EvalAgent { const model = createModelFromString(this.model, modelOptions); // Instrument tools to track MCP execution time - const instrumentedTools = this.createInstrumentedTools((ms) => { - totalMcpMs += ms; - stepMcpMs += ms; // Accumulate per-step for LLM calculation - }, widgetSnapshots); + const instrumentedTools = this.createInstrumentedTools( + (ms) => { + totalMcpMs += ms; + stepMcpMs += ms; // Accumulate per-step for LLM calculation + }, + widgetSnapshots + ); // Build messages array if context is provided for multi-turn const contextMessages = this.buildContextMessages(options?.context); @@ -405,9 +421,12 @@ export class TestAgent implements EvalAgent { ...(this.temperature !== undefined && { temperature: this.temperature, }), + ...(options?.timeout !== undefined && { + timeout: options.timeout, + }), // Use stopWhen with stepCountIs for controlling max agentic steps // AI SDK v6+ uses this instead of maxSteps - stopWhen: stepCountIs(this.maxSteps), + stopWhen: this.resolveStopWhen(options?.stopWhen), onStepFinish: () => { const now = Date.now(); const stepDuration = now - lastStepEndTime; diff --git a/sdk/src/index.ts b/sdk/src/index.ts index 2ec293a72..93ce1b894 100644 --- a/sdk/src/index.ts +++ b/sdk/src/index.ts @@ -82,6 +82,10 @@ export { EvalReportingError, SdkError } from "./errors.js"; // EvalAgent interface (for deterministic testing without concrete TestAgent) export type { EvalAgent, PromptOptions } from "./EvalAgent.js"; +// AI SDK stop condition helpers re-exported for TestAgent.prompt() +export { hasToolCall, stepCountIs } from "ai"; +export type { StopCondition } from "ai"; + // TestAgent export { TestAgent } from "./TestAgent.js"; export type { TestAgentConfig } from "./TestAgent.js"; diff --git a/sdk/tests/TestAgent.stopWhen.integration.test.ts b/sdk/tests/TestAgent.stopWhen.integration.test.ts 
new file mode 100644 index 000000000..f192e3952 --- /dev/null +++ b/sdk/tests/TestAgent.stopWhen.integration.test.ts @@ -0,0 +1,91 @@ +import { dynamicTool, hasToolCall, jsonSchema } from "ai"; +import { MockLanguageModelV3 } from "ai/test"; +import { TestAgent } from "../src/TestAgent"; + +let currentModel: MockLanguageModelV3; +const mockCreateModelFromString = jest.fn(() => currentModel); + +jest.mock("../src/model-factory", () => { + const actual = jest.requireActual("../src/model-factory"); + return { + ...actual, + createModelFromString: (...args: any[]) => mockCreateModelFromString(...args), + }; +}); + +describe("TestAgent stopWhen integration", () => { + beforeEach(() => { + mockCreateModelFromString.mockClear(); + }); + + it("executes the tool and stops before the next generation step", async () => { + const toolExecutions: Array<Record<string, number>> = []; + let stepNumber = 0; + + currentModel = new MockLanguageModelV3({ + doGenerate: async () => { + stepNumber += 1; + + if (stepNumber === 1) { + return { + content: [ + { + type: "tool-call" as const, + toolCallId: "call-1", + toolName: "add", + input: JSON.stringify({ a: 2, b: 3 }), + }, + ], + finishReason: "tool-calls", + usage: { inputTokens: 5, outputTokens: 3, totalTokens: 8 }, + warnings: [], + }; + } + + return { + content: [{ type: "text" as const, text: "The result is 5" }], + finishReason: "stop", + usage: { inputTokens: 4, outputTokens: 2, totalTokens: 6 }, + warnings: [], + }; + }, + }); + + const agent = new TestAgent({ + tools: { + add: dynamicTool({ + description: "Add two numbers", + inputSchema: jsonSchema({ + type: "object", + properties: { + a: { type: "number" }, + b: { type: "number" }, + }, + required: ["a", "b"], + }), + execute: async (input) => { + const args = input as { a: number; b: number }; + toolExecutions.push(args); + return args.a + args.b; + }, + }), + }, + model: "openai/gpt-4o", + apiKey: "test-key", + }); + + const result = await agent.prompt("Add 2 and 3", { + stopWhen: 
hasToolCall("add"), + }); + + expect(toolExecutions).toEqual([{ a: 2, b: 3 }]); + expect(result.hasToolCall("add")).toBe(true); + expect(result.getToolArguments("add")).toEqual({ a: 2, b: 3 }); + expect(result.text).toBe(""); + expect(currentModel.doGenerateCalls).toHaveLength(1); + expect(mockCreateModelFromString).toHaveBeenCalledWith( + "openai/gpt-4o", + expect.objectContaining({ apiKey: "test-key" }) + ); + }); +}); diff --git a/sdk/tests/TestAgent.test.ts b/sdk/tests/TestAgent.test.ts index b0b66dfad..7c1045fdf 100644 --- a/sdk/tests/TestAgent.test.ts +++ b/sdk/tests/TestAgent.test.ts @@ -19,12 +19,13 @@ jest.mock("../src/model-factory", () => ({ createModelFromString: jest.fn(() => ({})), })); -import { generateText, jsonSchema } from "ai"; +import { generateText, jsonSchema, stepCountIs } from "ai"; import { createModelFromString } from "../src/model-factory"; const mockGenerateText = generateText as jest.MockedFunction< typeof generateText >; +const mockStepCountIs = stepCountIs as jest.MockedFunction<typeof stepCountIs>; const mockCreateModel = createModelFromString as jest.MockedFunction< typeof createModelFromString >; @@ -531,6 +532,9 @@ describe("TestAgent", () => { }); it("should pass system prompt and temperature to generateText", async () => { + const guard = { kind: "max-step-guard" } as any; + mockStepCountIs.mockReturnValueOnce(guard); + mockGenerateText.mockResolvedValueOnce({ text: "OK", steps: [], @@ -548,12 +552,13 @@ await agent.prompt("What is 2+2?"); + expect(mockStepCountIs).toHaveBeenCalledWith(15); expect(mockGenerateText).toHaveBeenCalledWith( expect.objectContaining({ system: "You are a math tutor.", prompt: "What is 2+2?", temperature: 0.3, - stopWhen: { type: "stepCount", value: 15 }, + stopWhen: [guard], }) ); @@ -894,6 +899,171 @@ }); }); + describe("stopWhen", () => { + it("should merge a single stop condition with maxSteps and still execute tools", async () => { + const stopCondition = 
jest.fn(() => false); + const guard = { kind: "max-step-guard" } as any; + mockStepCountIs.mockReturnValueOnce(guard); + + mockGenerateText.mockImplementationOnce(async (params: any) => { + const result = await params.tools.add.execute( + { a: 2, b: 3 }, + { abortSignal: { throwIfAborted: jest.fn() } } + ); + expect(result).toBe(5); + params.onStepFinish?.(); + return { + text: "Done", + steps: [ + { + toolCalls: [ + { + type: "tool-call", + toolCallId: "1", + toolName: "add", + input: { a: 2, b: 3 }, + }, + ], + }, + ], + usage: { inputTokens: 5, outputTokens: 3, totalTokens: 8 }, + } as any; + }); + + const agent = new TestAgent({ + tools: mockToolSet, + model: "openai/gpt-4o", + apiKey: "test-key", + }); + + const result = await agent.prompt("Add 2 and 3", { + stopWhen: stopCondition as any, + }); + + expect(mockStepCountIs).toHaveBeenCalledWith(10); + expect(result.hasToolCall("add")).toBe(true); + expect(result.getToolArguments("add")).toEqual({ a: 2, b: 3 }); + + const callArgs = mockGenerateText.mock.calls[0][0] as any; + expect(callArgs.stopWhen).toEqual([guard, stopCondition]); + }); + + it("should merge multiple stop conditions with maxSteps", async () => { + const stopA = jest.fn(() => false); + const stopB = jest.fn(() => true); + const guard = { kind: "max-step-guard" } as any; + mockStepCountIs.mockReturnValueOnce(guard); + + mockGenerateText.mockResolvedValueOnce({ + text: "Done", + steps: [], + usage: { inputTokens: 5, outputTokens: 3, totalTokens: 8 }, + } as any); + + const agent = new TestAgent({ + tools: mockToolSet, + model: "openai/gpt-4o", + apiKey: "test-key", + }); + + await agent.prompt("Do math", { + stopWhen: [stopA as any, stopB as any], + }); + + expect(mockStepCountIs).toHaveBeenCalledWith(10); + const callArgs = mockGenerateText.mock.calls[0][0] as any; + expect(callArgs.stopWhen).toEqual([guard, stopA, stopB]); + }); + + it("should default to stepCountIs when stopWhen is not set", async () => { + const guard = { kind: 
"max-step-guard" } as any; + mockStepCountIs.mockReturnValueOnce(guard); + + mockGenerateText.mockResolvedValueOnce({ + text: "OK", + steps: [], + usage: { inputTokens: 1, outputTokens: 1, totalTokens: 2 }, + } as any); + + const agent = new TestAgent({ + tools: mockToolSet, + model: "openai/gpt-4o", + apiKey: "test-key", + }); + + await agent.prompt("Test"); + + expect(mockStepCountIs).toHaveBeenCalledWith(10); + const callArgs = mockGenerateText.mock.calls[0][0] as any; + expect(callArgs.stopWhen).toEqual([guard]); + }); + }); + + describe("timeout", () => { + it("should pass through a numeric timeout", async () => { + mockGenerateText.mockResolvedValueOnce({ + text: "OK", + steps: [], + usage: { inputTokens: 1, outputTokens: 1, totalTokens: 2 }, + } as any); + + const agent = new TestAgent({ + tools: mockToolSet, + model: "openai/gpt-4o", + apiKey: "test-key", + }); + + await agent.prompt("Test", { timeout: 5000 }); + + const callArgs = mockGenerateText.mock.calls[0][0] as any; + expect(callArgs.timeout).toBe(5000); + }); + + it("should pass through an object timeout", async () => { + mockGenerateText.mockResolvedValueOnce({ + text: "OK", + steps: [], + usage: { inputTokens: 1, outputTokens: 1, totalTokens: 2 }, + } as any); + + const agent = new TestAgent({ + tools: mockToolSet, + model: "openai/gpt-4o", + apiKey: "test-key", + }); + + await agent.prompt("Test", { + timeout: { totalMs: 5000, stepMs: 1000, chunkMs: 250 }, + }); + + const callArgs = mockGenerateText.mock.calls[0][0] as any; + expect(callArgs.timeout).toEqual({ + totalMs: 5000, + stepMs: 1000, + chunkMs: 250, + }); + }); + + it("should omit timeout when it is not set", async () => { + mockGenerateText.mockResolvedValueOnce({ + text: "OK", + steps: [], + usage: { inputTokens: 1, outputTokens: 1, totalTokens: 2 }, + } as any); + + const agent = new TestAgent({ + tools: mockToolSet, + model: "openai/gpt-4o", + apiKey: "test-key", + }); + + await agent.prompt("Test"); + + const callArgs = 
mockGenerateText.mock.calls[0][0] as any; + expect(callArgs).not.toHaveProperty("timeout"); + }); + }); + describe("app-only tool filtering", () => { // Helper to create a mock Tool with visibility const createMockTool = ( diff --git a/sdk/tests/TestAgent.timeout.integration.test.ts b/sdk/tests/TestAgent.timeout.integration.test.ts new file mode 100644 index 000000000..205b4cc65 --- /dev/null +++ b/sdk/tests/TestAgent.timeout.integration.test.ts @@ -0,0 +1,122 @@ +import { dynamicTool, jsonSchema } from "ai"; +import { MockLanguageModelV3 } from "ai/test"; +import { TestAgent } from "../src/TestAgent"; + +let currentModel: MockLanguageModelV3; +const mockCreateModelFromString = jest.fn(() => currentModel); + +jest.mock("../src/model-factory", () => { + const actual = jest.requireActual("../src/model-factory"); + return { + ...actual, + createModelFromString: (...args: any[]) => + mockCreateModelFromString(...args), + }; +}); + +function toError(reason: unknown): Error { + if (reason instanceof Error) { + return reason; + } + + return new Error(String(reason ?? 
"aborted")); +} + +describe("TestAgent timeout integration", () => { + beforeEach(() => { + mockCreateModelFromString.mockClear(); + }); + + it("returns an error result when AI SDK timeout aborts a tool cooperatively", async () => { + let sawAbortSignal = false; + let abortObserved = false; + let stepNumber = 0; + + currentModel = new MockLanguageModelV3({ + doGenerate: async ({ abortSignal }) => { + stepNumber += 1; + + if (stepNumber === 1) { + return { + content: [ + { + type: "tool-call" as const, + toolCallId: "call-1", + toolName: "wait", + input: JSON.stringify({}), + }, + ], + finishReason: "tool-calls", + usage: { inputTokens: 5, outputTokens: 3, totalTokens: 8 }, + warnings: [], + }; + } + + if (abortSignal?.aborted) { + throw toError(abortSignal.reason); + } + + return { + content: [{ type: "text" as const, text: "unexpected follow-up" }], + finishReason: "stop", + usage: { inputTokens: 4, outputTokens: 2, totalTokens: 6 }, + warnings: [], + }; + }, + }); + + const agent = new TestAgent({ + tools: { + wait: dynamicTool({ + description: "Wait until the abort signal fires", + inputSchema: jsonSchema({ + type: "object", + properties: {}, + }), + execute: async (_input, { abortSignal }) => { + sawAbortSignal = abortSignal != null; + + if (abortSignal == null) { + throw new Error("missing abort signal"); + } + + if (abortSignal.aborted) { + abortObserved = true; + throw toError(abortSignal.reason); + } + + await new Promise((_, reject) => { + abortSignal.addEventListener( + "abort", + () => { + abortObserved = true; + reject(toError(abortSignal.reason)); + }, + { once: true } + ); + }); + + throw new Error("unreachable"); + }, + }), + }, + model: "openai/gpt-4o", + apiKey: "test-key", + }); + + const startedAt = Date.now(); + const result = await agent.prompt("Run the long tool", { timeout: 25 }); + const elapsedMs = Date.now() - startedAt; + + expect(sawAbortSignal).toBe(true); + expect(abortObserved).toBe(true); + expect(result.hasError()).toBe(true); + 
expect(result.getError()).toEqual(expect.any(String)); + expect(elapsedMs).toBeLessThan(1000); + expect(currentModel.doGenerateCalls).toHaveLength(2); + expect(mockCreateModelFromString).toHaveBeenCalledWith( + "openai/gpt-4o", + expect.objectContaining({ apiKey: "test-key" }) + ); + }); +});