diff --git a/docs/sdk/reference/eval-test.mdx b/docs/sdk/reference/eval-test.mdx index e28590449..362a6e70a 100644 --- a/docs/sdk/reference/eval-test.mdx +++ b/docs/sdk/reference/eval-test.mdx @@ -26,18 +26,19 @@ new EvalTest(options: EvalTestConfig) ### EvalTestConfig -| Property | Type | Required | Description | -|----------|------|----------|-------------| -| `name` | `string` | Yes | Unique identifier for the test | -| `test` | `TestFunction` | Yes | The test function to run | +| Property | Type | Required | Description | +| -------- | -------------- | -------- | ------------------------------ | +| `name` | `string` | Yes | Unique identifier for the test | +| `test` | `TestFunction` | Yes | The test function to run | ### TestFunction Type ```typescript -type TestFunction = (agent: EvalAgent) => boolean | Promise +type TestFunction = (agent: EvalAgent) => boolean | Promise; ``` The test function receives an [`EvalAgent`](/sdk/reference/test-agent) and must return a `boolean`: + - `true` = test passed - `false` = test failed @@ -69,31 +70,33 @@ run(agent: EvalAgent, options: EvalTestRunOptions): Promise #### Parameters -| Parameter | Type | Description | -|-----------|------|-------------| -| `agent` | `EvalAgent` | The agent to test with (`TestAgent` or mock) | -| `options` | `EvalTestRunOptions` | Run configuration | +| Parameter | Type | Description | +| --------- | -------------------- | -------------------------------------------- | +| `agent` | `EvalAgent` | The agent to test with (`TestAgent` or mock) | +| `options` | `EvalTestRunOptions` | Run configuration | #### EvalTestRunOptions -| Property | Type | Required | Default | Description | -|----------|------|----------|---------|-------------| -| `iterations` | `number` | Yes | - | Number of test runs | -| `concurrency` | `number` | No | `5` | Parallel test runs | -| `retries` | `number` | No | `0` | Retry failed tests | -| `timeoutMs` | `number` | No | `30000` | Timeout per test (ms) | -| `onProgress` | `ProgressCallback` | No | - | Progress callback | -| `onFailure` | `(report: string) => void` | No | - | Called with a failure report if any iterations fail | -| `mcpjam` | [`MCPJamReportingConfig`](/sdk/reference/eval-reporting) | No | - | Auto-save results to MCPJam | +| Property | Type | Required | Default | Description | +| ------------- | -------------------------------------------------------- | -------- | ------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| `iterations` | `number` | Yes | - | Number of test runs | +| `concurrency` | `number` | No | `5` | Parallel test runs | +| `retries` | `number` | No | `0` | Retry failed tests | +| `timeoutMs` | `number` | No | `30000` | Per-iteration wall-clock timeout in ms. The active prompt is aborted at this deadline, then given a 1 second grace period to settle so partial tool calls and trace history can still be captured. | +| `onProgress` | `ProgressCallback` | No | - | Progress callback | +| `onFailure` | `(report: string) => void` | No | - | Called with a failure report if any iterations fail | +| `mcpjam` | [`MCPJamReportingConfig`](/sdk/reference/eval-reporting) | No | - | Auto-save results to MCPJam | -Results are automatically saved to MCPJam after the run completes when an API key is available via `mcpjam.apiKey` or the `MCPJAM_API_KEY` environment variable. Set `mcpjam.enabled: false` to disable. + Results are automatically saved to MCPJam after the run completes when an API + key is available via `mcpjam.apiKey` or the `MCPJAM_API_KEY` environment + variable. Set `mcpjam.enabled: false` to disable. #### ProgressCallback Type ```typescript -type ProgressCallback = (completed: number, total: number) => void +type ProgressCallback = (completed: number, total: number) => void; ``` #### Example @@ -226,26 +229,26 @@ getResults(): EvalRunResult | null #### EvalRunResult Type -| Property | Type | Description | -|----------|------|-------------| -| `iterations` | `number` | Total iterations run | -| `successes` | `number` | Number that passed | -| `failures` | `number` | Number that failed | -| `results` | `boolean[]` | Pass/fail per iteration | -| `iterationDetails` | `IterationResult[]` | Detailed per-iteration results | -| `tokenUsage` | `object` | Aggregate and per-iteration token usage | -| `latency` | `object` | Latency stats (e2e, llm, mcp) with p50/p95 | +| Property | Type | Description | +| ------------------ | ------------------- | ------------------------------------------ | +| `iterations` | `number` | Total iterations run | +| `successes` | `number` | Number that passed | +| `failures` | `number` | Number that failed | +| `results` | `boolean[]` | Pass/fail per iteration | +| `iterationDetails` | `IterationResult[]` | Detailed per-iteration results | +| `tokenUsage` | `object` | Aggregate and per-iteration token usage | +| `latency` | `object` | Latency stats (e2e, llm, mcp) with p50/p95 | #### IterationResult Type -| Property | Type | Description | -|----------|------|-------------| -| `passed` | `boolean` | Whether this iteration passed | -| `latencies` | `LatencyBreakdown[]` | Latency per prompt in this iteration | -| `tokens` | `{ total, input, output }` | Token usage | -| `error` | `string \| undefined` | Error message if failed | -| `retryCount` | `number \| undefined` | Number of retries attempted | -| `prompts` | `PromptResult[] \| undefined` | Prompt results from this iteration | +| Property | Type | Description | +| ------------ | ----------------------------- | ------------------------------------ | +| `passed` | `boolean` | Whether this iteration passed | +| `latencies` | `LatencyBreakdown[]` | Latency per prompt in this iteration | +| `tokens` | `{ total, input, output }` | Token usage | +| `error` | `string \| undefined` | Error message if failed | +| `retryCount` | `number \| undefined` | Number of retries attempted | +| `prompts` | `PromptResult[] \| undefined` | Prompt results from this iteration | --- @@ -337,7 +340,7 @@ if (test.accuracy() < 0.9) { The test's identifier (via `getName()`). ```typescript -test.getName() // "addition-accuracy" +test.getName(); // "addition-accuracy" ``` --- @@ -350,7 +353,7 @@ test.getName() // "addition-accuracy" test: async (agent) => { const result = await agent.prompt("Add 5 and 3"); return result.hasToolCall("add"); -} +}; ``` ### Argument Validation @@ -360,7 +363,7 @@ test: async (agent) => { const result = await agent.prompt("Add 10 and 20"); const args = result.getToolArguments("add"); return args?.a === 10 && args?.b === 20; -} +}; ``` ### Response Content @@ -369,7 +372,7 @@ test: async (agent) => { test: async (agent) => { const result = await agent.prompt("What is 5 + 5?"); return result.getText().includes("10"); -} +}; ``` ### Multiple Conditions @@ -382,7 +385,7 @@ test: async (agent) => { !result.hasError() && result.getText().length > 0 ); -} +}; ``` ### Multi-Turn Conversation @@ -392,7 +395,7 @@ test: async (agent) => { const r1 = await agent.prompt("Create a project"); const r2 = await agent.prompt("Add a task to it", { context: r1 }); return r1.hasToolCall("createProject") && r2.hasToolCall("createTask"); -} +}; ``` ### With Validators @@ -403,7 +406,7 @@ import { matchToolCallWithArgs } from "@mcpjam/sdk"; test: async (agent) => { const result = await agent.prompt("Add 2 and 3"); return matchToolCallWithArgs("add", { a: 2, b: 3 }, result.getToolCalls()); -} +}; ``` --- diff --git a/docs/sdk/reference/test-agent.mdx b/docs/sdk/reference/test-agent.mdx index ae74533b3..b75e2d110 100644 --- a/docs/sdk/reference/test-agent.mdx +++ b/docs/sdk/reference/test-agent.mdx @@ -26,15 +26,15 @@ new TestAgent(options: TestAgentOptions) ### TestAgentOptions -| Property | Type | Required | Default | Description | -|----------|------|----------|---------|-------------| -| `tools` | `Record` | Yes | - | MCP tools from `manager.getTools()` | -| `model` | `string` | Yes | - | Model identifier in `provider/model` format | -| `apiKey` | `string` | Yes | - | API key for the LLM provider | -| `systemPrompt` | `string` | No | `undefined` | System prompt for the LLM | -| `temperature` | `number` | No | `undefined` | Sampling temperature (0-2). If undefined, uses model default. | -| `maxSteps` | `number` | No | `10` | Maximum agentic loop iterations | -| `customProviders` | `Record` | No | `undefined` | Custom LLM provider definitions | +| Property | Type | Required | Default | Description | +| ----------------- | -------------------------------- | -------- | ----------- | ------------------------------------------------------------- | +| `tools` | `Record` | Yes | - | MCP tools from `manager.getTools()` | +| `model` | `string` | Yes | - | Model identifier in `provider/model` format | +| `apiKey` | `string` | Yes | - | API key for the LLM provider | +| `systemPrompt` | `string` | No | `undefined` | System prompt for the LLM | +| `temperature` | `number` | No | `undefined` | Sampling temperature (0-2). If undefined, uses model default. | +| `maxSteps` | `number` | No | `10` | Maximum agentic loop iterations | +| `customProviders` | `Record` | No | `undefined` | Custom LLM provider definitions | ### Example @@ -66,18 +66,21 @@ prompt( #### Parameters -| Parameter | Type | Required | Description | -|-----------|------|----------|-------------| -| `message` | `string` | Yes | The user prompt | -| `options` | `PromptOptions` | No | Additional options | +| Parameter | Type | Required | Description | +| --------- | --------------- | -------- | ------------------ | +| `message` | `string` | Yes | The user prompt | +| `options` | `PromptOptions` | No | Additional options | #### PromptOptions -| Property | Type | Description | -|----------|------|-------------| -| `context` | `PromptResult \| PromptResult[]` | Previous result(s) for multi-turn conversations | -| `stopWhen` | `StopCondition \| Array>` | Additional conditions for the multi-step prompt loop. Tools still execute normally. `TestAgent` always applies `stepCountIs(maxSteps)` as a safety guard. | -| `timeout` | `number \| { totalMs?: number; stepMs?: number; chunkMs?: number }` | Bounds prompt runtime. `number` and `totalMs` cap the full prompt, `stepMs` caps each generation step, and `chunkMs` is accepted for parity but is mainly relevant to streaming APIs. | +| Property | Type | Description | +| ------------------- | ------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| `context` | `PromptResult \| PromptResult[]` | Previous result(s) for multi-turn conversations | +| `abortSignal` | `AbortSignal` | Cancels the prompt runtime when aborted | +| `stopWhen` | `StopCondition \| Array>` | Additional conditions for the multi-step prompt loop. Tools still execute normally. `TestAgent` always applies `stepCountIs(maxSteps)` as a safety guard. | +| `timeout` | `number \| { totalMs?: number; stepMs?: number; chunkMs?: number }` | Bounds prompt runtime. `number` and `totalMs` cap the full prompt, `stepMs` caps each generation step, and `chunkMs` is accepted for parity but is mainly relevant to streaming APIs. | +| `timeoutMs` | `number` | Shortcut for a total prompt timeout in milliseconds | +| `stopAfterToolCall` | `string \| string[]` | Short-circuits the named tool(s) with a stub result and stops after the step where they were called. Tool names and args are still captured in the `PromptResult`. | #### Returns @@ -112,14 +115,31 @@ const r5 = await agent.prompt("Run a long workflow", { if (r5.hasError()) { console.error(r5.getError()); } + +// Exit early after selecting a tool, without waiting for the real MCP call +const r6 = await agent.prompt("Search for tasks", { + stopAfterToolCall: "search_tasks", + timeoutMs: 5_000, +}); +console.log(r6.getToolArguments("search_tasks")); ``` -`prompt()` never throws exceptions. Errors are captured in the `PromptResult`. Check `result.hasError()` to detect failures. + `prompt()` never throws exceptions. Errors are captured in the `PromptResult`. + Check `result.hasError()` to detect failures. + + + + `stopAfterToolCall` is intended for evals that only care about tool selection + and arguments. If the model emits multiple tool calls in the same step, + non-target sibling tools may still execute before the loop stops. -`timeout` bounds prompt runtime. The runtime creates an internal abort signal, so tools can stop early if their implementation respects the provided `abortSignal`. If a tool ignores that signal, its underlying work may continue briefly after the prompt returns an error result. + `timeout` bounds prompt runtime. The runtime creates an internal abort signal, + so tools can stop early if their implementation respects the provided + `abortSignal`. If a tool ignores that signal, its underlying work may continue + briefly after the prompt returns an error result. --- @@ -130,34 +150,34 @@ Models are specified as `provider/model`: ```typescript // Anthropic -"anthropic/claude-sonnet-4-20250514" -"anthropic/claude-3-haiku-20240307" +"anthropic/claude-sonnet-4-20250514"; +"anthropic/claude-3-haiku-20240307"; // OpenAI -"openai/gpt-4o" -"openai/gpt-4o-mini" +"openai/gpt-4o"; +"openai/gpt-4o-mini"; // Google -"google/gemini-1.5-pro" -"google/gemini-1.5-flash" +"google/gemini-1.5-pro"; +"google/gemini-1.5-flash"; // Azure -"azure/gpt-4o" +"azure/gpt-4o"; // Mistral -"mistral/mistral-large-latest" +"mistral/mistral-large-latest"; // DeepSeek -"deepseek/deepseek-chat" +"deepseek/deepseek-chat"; // Ollama (local) -"ollama/llama3" +"ollama/llama3"; // OpenRouter -"openrouter/anthropic/claude-3-opus" +"openrouter/anthropic/claude-3-opus"; // xAI -"xai/grok-beta" +"xai/grok-beta"; ``` --- @@ -168,14 +188,14 @@ Add custom OpenAI or Anthropic-compatible endpoints. ### CustomProvider Type -| Property | Type | Required | Description | -|----------|------|----------|-------------| -| `name` | `string` | Yes | Provider identifier | -| `protocol` | `"openai-compatible" \| "anthropic-compatible"` | Yes | API protocol | -| `baseUrl` | `string` | Yes | API endpoint URL | -| `modelIds` | `string[]` | Yes | Available model IDs | -| `useChatCompletions` | `boolean` | No | Use `/chat/completions` endpoint | -| `apiKeyEnvVar` | `string` | No | Custom env var for API key | +| Property | Type | Required | Description | +| -------------------- | ----------------------------------------------- | -------- | -------------------------------- | +| `name` | `string` | Yes | Provider identifier | +| `protocol` | `"openai-compatible" \| "anthropic-compatible"` | Yes | API protocol | +| `baseUrl` | `string` | Yes | API endpoint URL | +| `modelIds` | `string[]` | Yes | Available model IDs | +| `useChatCompletions` | `boolean` | No | Use `/chat/completions` endpoint | +| `apiKeyEnvVar` | `string` | No | Custom env var for API key | ### Example @@ -214,7 +234,7 @@ const agent = new TestAgent({ tools, ... }); The LLM model identifier. Format: `provider/model-id`. ```typescript -model: "anthropic/claude-sonnet-4-20250514" +model: "anthropic/claude-sonnet-4-20250514"; ``` ### apiKey @@ -222,7 +242,7 @@ model: "anthropic/claude-sonnet-4-20250514" The API key for the LLM provider. ```typescript -apiKey: process.env.ANTHROPIC_API_KEY +apiKey: process.env.ANTHROPIC_API_KEY; ``` ### systemPrompt @@ -230,7 +250,7 @@ apiKey: process.env.ANTHROPIC_API_KEY Optional system prompt to guide the LLM's behavior. ```typescript -systemPrompt: "You are a task management assistant. Be concise." +systemPrompt: "You are a task management assistant. Be concise."; ``` ### temperature @@ -238,8 +258,8 @@ systemPrompt: "You are a task management assistant. Be concise." Controls response randomness. Range: 0.0 (deterministic) to 1.0 (creative). ```typescript -temperature: 0.1 // More deterministic, better for testing -temperature: 0.9 // More creative +temperature: 0.1; // More deterministic, better for testing +temperature: 0.9; // More creative ``` ### maxSteps @@ -247,11 +267,12 @@ temperature: 0.9 // More creative Maximum iterations of the agentic loop (prompt → tool → result → continue). ```typescript -maxSteps: 5 // Stop after 5 tool calls +maxSteps: 5; // Stop after 5 tool calls ``` -Setting `maxSteps` too low may prevent complex tasks from completing. Setting it too high may allow runaway loops. + Setting `maxSteps` too low may prevent complex tasks from completing. Setting + it too high may allow runaway loops. --- @@ -277,7 +298,9 @@ const result2 = await agent.prompt("Do something", { ``` -`stopWhen` does not skip tool execution. It controls whether the prompt loop continues after the current step completes, and `TestAgent` also applies `stepCountIs(maxSteps)` as a safety guard. + `stopWhen` does not skip tool execution. It controls whether the prompt loop + continues after the current step completes, and `TestAgent` also applies + `stepCountIs(maxSteps)` as a safety guard. --- @@ -297,7 +320,9 @@ const result2 = await agent.prompt("Run a long workflow", { ``` -`chunkMs` is accepted for parity, but it is mainly useful for streaming APIs. For `TestAgent.prompt()`, `number`, `totalMs`, and `stepMs` are the main settings to focus on. + `chunkMs` is accepted for parity, but it is mainly useful for streaming APIs. + For `TestAgent.prompt()`, `number`, `totalMs`, and `stepMs` are the main + settings to focus on. --- diff --git a/sdk/README.md b/sdk/README.md index b01c0c24d..46660c68b 100644 --- a/sdk/README.md +++ b/sdk/README.md @@ -1,6 +1,6 @@ # @mcpjam/sdk -Use the MCPJam SDK to write unit tests and evals for your MCP server. +Use the MCPJam SDK to write unit tests and evals for your MCP server. ## Installation @@ -8,13 +8,13 @@ Use the MCPJam SDK to write unit tests and evals for your MCP server. npm install @mcpjam/sdk ``` -Compatible with your favorite testing framework like [Jest](https://jestjs.io/) and [Vitest](https://vitest.dev/) +Compatible with your favorite testing framework like [Jest](https://jestjs.io/) and [Vitest](https://vitest.dev/) ## Quick Start ### Unit Test -Test the individual parts, request response flow of your MCP server. MCP unit tests are deterministic. +Test the individual parts, request response flow of your MCP server. MCP unit tests are deterministic. ```ts import { MCPClientManager } from "@mcpjam/sdk"; @@ -40,15 +40,18 @@ describe("Everything MCP example", () => { }); test("get-sum tool returns correct result", async () => { - const result = await manager.executeTool("everything", "get-sum", { a: 2, b: 3 }); + const result = await manager.executeTool("everything", "get-sum", { + a: 2, + b: 3, + }); expect(result.content[0].text).toBe("5"); }); }); ``` -### MCP evals +### MCP evals -Test that an LLM correctly understands how to use your MCP server. Evals are non-deterministic and multiple runs are needed. +Test that an LLM correctly understands how to use your MCP server. Evals are non-deterministic and multiple runs are needed. ```ts import { MCPClientManager, TestAgent, EvalTest } from "@mcpjam/sdk"; @@ -89,10 +92,10 @@ describe("Asana MCP Evals", () => { await evalTest.run(agent, { iterations: 10, - onFailure: (report) => console.error(report), // Print the report when a test iteration fails. + onFailure: (report) => console.error(report), // Print the report when a test iteration fails. }); - expect(evalTest.accuracy()).toBeGreaterThan(0.8); // Pass threshold + expect(evalTest.accuracy()).toBeGreaterThan(0.8); // Pass threshold }); // Multi-turn eval @@ -103,7 +106,9 @@ describe("Asana MCP Evals", () => { const r1 = await agent.prompt("Who am I in Asana?"); if (!r1.hasToolCall("asana_get_user")) return false; - const r2 = await agent.prompt("Now list my projects", { context: [r1] }); // Continue the conversation from the previous prompt + const r2 = await agent.prompt("Now list my projects", { + context: [r1], + }); // Continue the conversation from the previous prompt return r2.hasToolCall("asana_get_projects"); }, }); @@ -121,9 +126,14 @@ describe("Asana MCP Evals", () => { const evalTest = new EvalTest({ name: "search-args", test: async (agent) => { - const result = await agent.prompt("Search for tasks containing 'bug' in my workspace"); + const result = await agent.prompt( + "Search for tasks containing 'bug' in my workspace" + ); const args = result.getToolArguments("asana_search_tasks"); - return result.hasToolCall("asana_search_tasks") && typeof args?.workspace_gid === "string"; + return ( + result.hasToolCall("asana_search_tasks") && + typeof args?.workspace_gid === "string" + ); }, }); @@ -191,11 +201,11 @@ import { hasToolCall } from "@mcpjam/sdk"; const agent = new TestAgent({ tools: await manager.getToolsForAiSdk(), - model: "openai/gpt-4o", // provider/model format + model: "openai/gpt-4o", // provider/model format apiKey: process.env.OPENAI_API_KEY!, - systemPrompt: "You are a helpful assistant.", // optional - temperature: 0.7, // optional, omit for reasoning models - maxSteps: 10, // optional, max tool call loops + systemPrompt: "You are a helpful assistant.", // optional + temperature: 0.7, // optional, omit for reasoning models + maxSteps: 10, // optional, max tool call loops }); // Run a prompt @@ -209,19 +219,28 @@ const r2 = await agent.prompt("List my projects", { context: [r1] }); const r3 = await agent.prompt("Search tasks", { stopWhen: hasToolCall("search_tasks"), }); -r3.hasToolCall("search_tasks"); // true +r3.hasToolCall("search_tasks"); // true // Bound prompt runtime const r4 = await agent.prompt("Run a long workflow", { timeout: { totalMs: 10_000, stepMs: 2_500 }, }); -r4.hasError(); // true if the prompt timed out +r4.hasError(); // true if the prompt timed out + +// Exit early after selecting a tool without waiting for the MCP round-trip +const r5 = await agent.prompt("Search tasks", { + stopAfterToolCall: "search_tasks", + timeoutMs: 5_000, +}); +r5.getToolArguments("search_tasks"); // captured even if the prompt stops early ``` `stopWhen` does not skip tool execution. It controls whether the prompt loop continues after the current step completes, and `TestAgent` also applies `stepCountIs(maxSteps)` as a safety guard. `timeout` bounds prompt runtime. `number` and `totalMs` cap the full prompt, `stepMs` caps each step, and `chunkMs` is accepted for parity but mainly matters in streaming flows. The runtime creates an internal abort signal, so tools can stop early if their implementation respects the provided `abortSignal`. +`stopAfterToolCall` is intended for evals that only care about tool selection and arguments. The targeted tool is short-circuited with a stub result, and the `PromptResult` still includes the tool name and args. If multiple tools are emitted in the same step, non-target siblings may still execute before the loop stops. + **Supported providers:** `openai`, `anthropic`, `azure`, `google`, `mistral`, `deepseek`, `ollama`, `openrouter`, `xai` @@ -235,24 +254,24 @@ Returned by `agent.prompt()`. Contains the LLM response and tool calls. const result = await agent.prompt("Add 2 and 3"); // Tool calls -result.hasToolCall("add"); // boolean -result.toolsCalled(); // ["add"] -result.getToolCalls(); // [{ toolName: "add", arguments: { a: 2, b: 3 } }] -result.getToolArguments("add"); // { a: 2, b: 3 } +result.hasToolCall("add"); // boolean +result.toolsCalled(); // ["add"] +result.getToolCalls(); // [{ toolName: "add", arguments: { a: 2, b: 3 } }] +result.getToolArguments("add"); // { a: 2, b: 3 } // Response -result.text; // "The result is 5" +result.text; // "The result is 5" // Messages (full conversation) -result.getMessages(); // CoreMessage[] -result.getUserMessages(); // user messages only -result.getAssistantMessages(); // assistant messages only -result.getToolMessages(); // tool result messages only +result.getMessages(); // CoreMessage[] +result.getUserMessages(); // user messages only +result.getAssistantMessages(); // assistant messages only +result.getToolMessages(); // tool result messages only // Latency -result.e2eLatencyMs(); // total wall-clock time -result.llmLatencyMs(); // LLM API time -result.mcpLatencyMs(); // MCP tool execution time +result.e2eLatencyMs(); // total wall-clock time +result.llmLatencyMs(); // LLM API time +result.mcpLatencyMs(); // MCP tool execution time // Tokens result.totalTokens(); @@ -285,22 +304,22 @@ const test = new EvalTest({ await test.run(agent, { iterations: 30, - concurrency: 5, // parallel iterations (default: 5) - retries: 2, // retry failed iterations (default: 0) - timeoutMs: 30000, // timeout per iteration (default: 30000) + concurrency: 5, // parallel iterations (default: 5) + retries: 2, // retry failed iterations (default: 0) + timeoutMs: 30000, // aborts the active prompt at 30s, then waits up to 1s for it to settle onProgress: (completed, total) => console.log(`${completed}/${total}`), - onFailure: (report) => console.error(report), // called if any iteration fails + onFailure: (report) => console.error(report), // called if any iteration fails }); // Metrics -test.accuracy(); // success rate (0-1) -test.averageTokenUse(); // avg tokens per iteration +test.accuracy(); // success rate (0-1) +test.averageTokenUse(); // avg tokens per iteration // Iteration details -test.getAllIterations(); // all iteration results -test.getFailedIterations(); // failed iterations only -test.getSuccessfulIterations(); // successful iterations only -test.getFailureReport(); // formatted string of failed traces +test.getAllIterations(); // all iteration results +test.getFailedIterations(); // failed iterations only +test.getSuccessfulIterations(); // successful iterations only +test.getFailureReport(); // formatted string of failed traces ``` @@ -313,32 +332,36 @@ Groups multiple `EvalTest` instances for aggregate metrics. ```ts const suite = new EvalSuite({ name: "Math Operations" }); -suite.add(new EvalTest({ - name: "addition", - test: async (agent) => { - const r = await agent.prompt("Add 2+3"); - return r.hasToolCall("add"); - }, -})); - -suite.add(new EvalTest({ - name: "multiply", - test: async (agent) => { - const r = await agent.prompt("Multiply 4*5"); - return r.hasToolCall("multiply"); - }, -})); +suite.add( + new EvalTest({ + name: "addition", + test: async (agent) => { + const r = await agent.prompt("Add 2+3"); + return r.hasToolCall("add"); + }, + }) +); + +suite.add( + new EvalTest({ + name: "multiply", + test: async (agent) => { + const r = await agent.prompt("Multiply 4*5"); + return r.hasToolCall("multiply"); + }, + }) +); await suite.run(agent, { iterations: 30 }); // Aggregate metrics -suite.accuracy(); // overall accuracy +suite.accuracy(); // overall accuracy suite.averageTokenUse(); // Individual test access suite.get("addition")?.accuracy(); suite.get("multiply")?.accuracy(); -suite.getAll(); // all EvalTest instances +suite.getAll(); // all EvalTest instances ``` @@ -361,30 +384,30 @@ import { matchToolArgumentWith, } from "@mcpjam/sdk"; -const tools = result.toolsCalled(); // ["add", "multiply"] -const calls = result.getToolCalls(); // ToolCall[] +const tools = result.toolsCalled(); // ["add", "multiply"] +const calls = result.getToolCalls(); // ToolCall[] // Exact match (order matters) -matchToolCalls(["add", "multiply"], tools); // true -matchToolCalls(["multiply", "add"], tools); // false +matchToolCalls(["add", "multiply"], tools); // true +matchToolCalls(["multiply", "add"], tools); // false // Subset match (order doesn't matter) -matchToolCallsSubset(["add"], tools); // true +matchToolCallsSubset(["add"], tools); // true // Any match (at least one) -matchAnyToolCall(["add", "subtract"], tools); // true +matchAnyToolCall(["add", "subtract"], tools); // true // Count match -matchToolCallCount("add", tools, 1); // true +matchToolCallCount("add", tools, 1); // true // No tools called -matchNoToolCalls([]); // true +matchNoToolCalls([]); // true // Argument matching -matchToolCallWithArgs("add", { a: 2, b: 3 }, calls); // exact match -matchToolCallWithPartialArgs("add", { a: 2 }, calls); // partial match -matchToolArgument("add", "a", 2, calls); // single arg -matchToolArgumentWith("add", "a", (v) => v > 0, calls); // predicate +matchToolCallWithArgs("add", { a: 2, b: 3 }, calls); // exact match +matchToolCallWithPartialArgs("add", { a: 2 }, calls); // partial match +matchToolArgument("add", "a", 2, calls); // single arg +matchToolArgumentWith("add", "a", (v) => v > 0, calls); // predicate ``` diff --git a/sdk/src/EvalAgent.ts b/sdk/src/EvalAgent.ts index 8b9fa1c19..f870ced3c 100644 --- a/sdk/src/EvalAgent.ts +++ b/sdk/src/EvalAgent.ts @@ -8,6 +8,9 @@ export interface PromptOptions { /** Previous PromptResult(s) to include as conversation context for multi-turn conversations */ context?: PromptResult | PromptResult[]; + /** Optional abort signal for cancelling the prompt runtime. */ + abortSignal?: AbortSignal; + /** * Additional stop conditions for the agentic loop. * Evaluated after each step completes (tools execute normally). @@ -46,6 +49,15 @@ export interface PromptOptions { * respect the `abortSignal` passed to `execute()`. */ timeout?: TimeoutConfiguration; + + /** Shortcut for a total prompt timeout in milliseconds. */ + timeoutMs?: number; + + /** + * Stop the prompt loop after the step where one of these tools is called and + * short-circuit that tool execution with a stub result. + */ + stopAfterToolCall?: string | string[]; } /** diff --git a/sdk/src/EvalTest.ts b/sdk/src/EvalTest.ts index 58d8e5570..61207e695 100644 --- a/sdk/src/EvalTest.ts +++ b/sdk/src/EvalTest.ts @@ -103,26 +103,7 @@ class Semaphore { } } -/** - * Timeout wrapper for promises - */ -function withTimeout(promise: Promise, ms: number): Promise { - return new Promise((resolve, reject) => { - const timer = setTimeout(() => { - reject(new Error(`Operation timed out after ${ms}ms`)); - }, ms); - - promise - .then((value) => { - clearTimeout(timer); - resolve(value); - }) - .catch((error) => { - clearTimeout(timer); - reject(error); - }); - }); -} +const ITERATION_ABORT_GRACE_MS = 1000; /** * Sleep for a given number of milliseconds @@ -131,6 +112,87 @@ function sleep(ms: number): Promise { return new Promise((resolve) => setTimeout(resolve, ms)); } +function mergeAbortSignals( + first?: AbortSignal, + second?: AbortSignal +): AbortSignal | undefined { + if (!first) { + return second; + } + + if (!second) { + return first; + } + + if (first.aborted) { + return AbortSignal.abort(first.reason); + } + + if (second.aborted) { + return AbortSignal.abort(second.reason); + } + + const controller = new AbortController(); + const abort = (signal: AbortSignal) => { + cleanup(); + controller.abort(signal.reason); + }; + const onFirstAbort = () => abort(first); + const onSecondAbort = () => abort(second); + const cleanup = () => { + first.removeEventListener("abort", onFirstAbort); + second.removeEventListener("abort", onSecondAbort); + }; + + first.addEventListener("abort", onFirstAbort, { once: true }); + second.addEventListener("abort", onSecondAbort, { once: true }); + + return controller.signal; +} + +function collectPromptMetrics( + promptResults: PromptResult[] +): Pick { + const latencies = promptResults.map((result) => result.getLatency()); + + return { + latencies: + latencies.length > 0 ? latencies : [{ e2eMs: 0, llmMs: 0, mcpMs: 0 }], + tokens: { + total: promptResults.reduce( + (sum, result) => sum + result.totalTokens(), + 0 + ), + input: promptResults.reduce( + (sum, result) => sum + result.inputTokens(), + 0 + ), + output: promptResults.reduce( + (sum, result) => sum + result.outputTokens(), + 0 + ), + }, + prompts: promptResults, + }; +} + +function wrapAgentWithAbortSignal( + agent: EvalAgent, + abortSignal: AbortSignal +): EvalAgent { + return { + prompt: (message, options) => + agent.prompt(message, { + ...options, + abortSignal: mergeAbortSignals(options?.abortSignal, abortSignal), + }), + withOptions: (options) => + wrapAgentWithAbortSignal(agent.withOptions(options), abortSignal), + getPromptHistory: () => agent.getPromptHistory(), + resetPromptHistory: () => agent.resetPromptHistory(), + }; +} + /** * EvalTest - Runs a single test scenario with iterations * @@ -191,39 +253,48 @@ export class EvalTest { await semaphore.acquire(); try { let lastError: string | undefined; + let iterationAgent: EvalAgent | undefined; for (let attempt = 0; attempt <= retries; attempt++) { + const abortController = new AbortController(); + const timeoutError = new Error( + `Operation timed out after ${timeoutMs}ms` + ); + let timeoutTriggered = false; + let timeoutId: ReturnType | undefined; + let hardTimeoutId: ReturnType | undefined; + try { // Create a fresh agent clone for this iteration to avoid race conditions // when multiple iterations run concurrently - const iterationAgent = agent.withOptions({}); - - const passed = await withTimeout( - Promise.resolve(testFn(iterationAgent)), - timeoutMs + iterationAgent = wrapAgentWithAbortSignal( + agent.withOptions({}), + abortController.signal ); - - // Get metrics from this iteration's prompt history + const hardTimeoutPromise = new Promise((_, reject) => { + timeoutId = setTimeout(() => { + timeoutTriggered = true; + abortController.abort(timeoutError); + hardTimeoutId = setTimeout( + () => reject(timeoutError), + ITERATION_ABORT_GRACE_MS + ); + }, timeoutMs); + }); + const passed = await Promise.race([ + Promise.resolve().then(() => testFn(iterationAgent!)), + hardTimeoutPromise, + ]); const promptResults = iterationAgent.getPromptHistory(); - const latencies = promptResults.map((r) => r.getLatency()); - const tokens = { - total: promptResults.reduce((sum, r) => sum + r.totalTokens(), 0), - input: promptResults.reduce((sum, r) => sum + r.inputTokens(), 0), - output: promptResults.reduce( - (sum, r) => sum + r.outputTokens(), - 0 - ), - }; + const promptMetrics = collectPromptMetrics(promptResults); return { passed, - latencies: - latencies.length > 0 - ? latencies - : [{ e2eMs: 0, llmMs: 0, mcpMs: 0 }], - tokens, + ...promptMetrics, + ...(timeoutTriggered && !passed + ? { error: timeoutError.message } + : {}), retryCount: attempt, - prompts: promptResults, }; } catch (error) { lastError = error instanceof Error ? error.message : String(error); @@ -231,13 +302,23 @@ export class EvalTest { if (attempt < retries) { await sleep(100 * Math.pow(2, attempt)); } + } finally { + if (timeoutId) { + clearTimeout(timeoutId); + } + if (hardTimeoutId) { + clearTimeout(hardTimeoutId); + } } } + const promptMetrics = collectPromptMetrics( + iterationAgent?.getPromptHistory() ?? [] + ); + return { passed: false, - latencies: [{ e2eMs: 0, llmMs: 0, mcpMs: 0 }], - tokens: { total: 0, input: 0, output: 0 }, + ...promptMetrics, error: lastError, retryCount: retries, }; diff --git a/sdk/src/TestAgent.ts b/sdk/src/TestAgent.ts index 53400f208..dce32288d 100644 --- a/sdk/src/TestAgent.ts +++ b/sdk/src/TestAgent.ts @@ -2,14 +2,25 @@ * TestAgent - Runs LLM prompts with tool calling for evals */ -import { generateText, stepCountIs, dynamicTool, jsonSchema } from "ai"; -import type { StopCondition, ToolSet, ModelMessage, UserModelMessage } from "ai"; +import { + generateText, + hasToolCall, + stepCountIs, + dynamicTool, + jsonSchema, +} from "ai"; +import type { + StopCondition, + ToolSet, + ModelMessage, + UserModelMessage, +} from "ai"; import { CallToolResultSchema } from "@modelcontextprotocol/sdk/types.js"; import { createModelFromString, parseLLMString } from "./model-factory.js"; import type { CreateModelOptions } from "./model-factory.js"; import { extractToolCalls } from "./tool-extraction.js"; import { PromptResult } from "./PromptResult.js"; -import type { CustomProvider } from "./types.js"; +import type { CustomProvider, ToolCall as PromptToolCall } from "./types.js"; import type { EvalAgent, PromptOptions } from "./EvalAgent.js"; import type { Tool, AiSdkTool } from "./mcp-client-manager/types.js"; import type { MCPClientManager } from "./mcp-client-manager/MCPClientManager.js"; @@ -86,6 +97,13 @@ function convertToToolSet(tools: Tool[]): ToolSet { return toolSet; } +type StartedToolCall = { + toolCallId: string; + toolName: string; + arguments: Record; + shortCircuited: boolean; +}; + /** * Agent for running LLM prompts with tool calling. * Wraps the AI SDK generateText function with proper tool integration. @@ -283,7 +301,9 @@ export class TestAgent implements EvalAgent { private createInstrumentedTools( onLatency: (ms: number) => void, - snapshotBuffer: Map + snapshotBuffer: Map, + pendingStepToolCalls: StartedToolCall[], + shortCircuitTools?: Set ): ToolSet { const instrumented: ToolSet = {}; for (const [name, tool] of Object.entries(this.tools)) { @@ -294,13 +314,38 @@ export class TestAgent implements EvalAgent { ...tool, execute: async (args: any, options: any) => { const start = Date.now(); + const toolCallId = + typeof options?.toolCallId === "string" + ? options.toolCallId + : `${name}-${pendingStepToolCalls.length + 1}`; + const toolInput = (args ?? {}) as Record; + const shouldShortCircuit = shortCircuitTools?.has(name) ?? false; + + pendingStepToolCalls.push({ + toolCallId, + toolName: name, + arguments: toolInput, + shortCircuited: shouldShortCircuit, + }); + try { + if (shouldShortCircuit) { + return CallToolResultSchema.parse({ + content: [ + { + type: "text", + text: "[skipped by stopAfterToolCall]", + }, + ], + }); + } + const result = await originalExecute(args, options); await this.captureMcpAppSnapshot({ toolName: name, tool, options, - toolInput: args as Record, + toolInput, toolOutput: result, snapshotBuffer, }); @@ -319,16 +364,49 @@ export class TestAgent implements EvalAgent { } private resolveStopWhen( - stopWhen?: PromptOptions["stopWhen"] + stopWhen?: PromptOptions["stopWhen"], + stopAfterToolCall?: PromptOptions["stopAfterToolCall"] ): Array> { const base = [stepCountIs(this.maxSteps)]; + const conditions = + stopWhen == null ? [] : Array.isArray(stopWhen) ? stopWhen : [stopWhen]; + const stopAfterConditions = this.normalizeStopAfterToolCall( + stopAfterToolCall + ).map((toolName) => hasToolCall(toolName)); + + return [...base, ...conditions, ...stopAfterConditions]; + } + + private normalizeStopAfterToolCall( + stopAfterToolCall?: PromptOptions["stopAfterToolCall"] + ): string[] { + if (stopAfterToolCall == null) { + return []; + } + + return Array.isArray(stopAfterToolCall) + ? stopAfterToolCall + : [stopAfterToolCall]; + } - if (stopWhen == null) { - return base; + private buildPartialAssistantMessages( + pendingStepToolCalls: StartedToolCall[] + ): ModelMessage[] { + if (pendingStepToolCalls.length === 0) { + return []; } - const conditions = Array.isArray(stopWhen) ? stopWhen : [stopWhen]; - return [...base, ...conditions]; + return [ + { + role: "assistant", + content: pendingStepToolCalls.map((toolCall) => ({ + type: "tool-call" as const, + toolCallId: toolCall.toolCallId, + toolName: toolCall.toolName, + input: toolCall.arguments, + })), + }, + ]; } /** @@ -387,6 +465,12 @@ export class TestAgent implements EvalAgent { let totalLlmMs = 0; let stepMcpMs = 0; // MCP time within current step const widgetSnapshots = new Map(); + const completedToolCalls: PromptToolCall[] = []; + const pendingStepToolCalls: StartedToolCall[] = []; + let lastCompletedStepMessages: ModelMessage[] = []; + let partialInputTokens = 0; + let partialOutputTokens = 0; + let lastCompletedStepText = ""; try { const modelOptions: CreateModelOptions = { @@ -394,6 +478,9 @@ export class TestAgent implements EvalAgent { customProviders: this.customProviders, }; const model = createModelFromString(this.model, modelOptions); + const stopAfterToolCallNames = this.normalizeStopAfterToolCall( + options?.stopAfterToolCall + ); // Instrument tools to track MCP execution time const instrumentedTools = this.createInstrumentedTools( @@ -401,12 +488,15 @@ export class TestAgent implements EvalAgent { totalMcpMs += ms; stepMcpMs += ms; // Accumulate per-step for LLM calculation }, - widgetSnapshots + widgetSnapshots, + pendingStepToolCalls, + new Set(stopAfterToolCallNames) ); // Build messages array if context is provided for multi-turn const contextMessages = this.buildContextMessages(options?.context); const userMessage: UserModelMessage = { role: "user", content: message }; + const resolvedTimeout = options?.timeout ?? options?.timeoutMs; // Cast model to any to handle AI SDK version compatibility const result = await generateText({ @@ -421,19 +511,41 @@ export class TestAgent implements EvalAgent { ...(this.temperature !== undefined && { temperature: this.temperature, }), - ...(options?.timeout !== undefined && { - timeout: options.timeout, + ...(options?.abortSignal !== undefined && { + abortSignal: options.abortSignal, }), - // Use stopWhen with stepCountIs for controlling max agentic steps - // AI SDK v6+ uses this instead of maxSteps - stopWhen: this.resolveStopWhen(options?.stopWhen), - onStepFinish: () => { + ...(resolvedTimeout !== undefined && { + timeout: resolvedTimeout, + }), + stopWhen: this.resolveStopWhen( + options?.stopWhen, + options?.stopAfterToolCall + ), + onStepFinish: (stepResult) => { const now = Date.now(); const stepDuration = now - lastStepEndTime; // LLM time for this step = step duration - MCP time in this step totalLlmMs += Math.max(0, stepDuration - stepMcpMs); lastStepEndTime = now; stepMcpMs = 0; // Reset for next step + + if (!stepResult) { + return; + } + + partialInputTokens += stepResult.usage?.inputTokens ?? 0; + partialOutputTokens += stepResult.usage?.outputTokens ?? 0; + lastCompletedStepText = stepResult.text ?? ""; + lastCompletedStepMessages = stepResult.response?.messages + ? [...stepResult.response.messages] + : []; + completedToolCalls.push( + ...stepResult.toolCalls.map((toolCall) => ({ + toolName: toolCall.toolName, + arguments: (toolCall.input ?? {}) as Record, + })) + ); + pendingStepToolCalls.length = 0; }, }); @@ -471,19 +583,51 @@ export class TestAgent implements EvalAgent { return this.lastResult; } catch (error) { const e2eMs = Date.now() - startTime; + const abortReason = options?.abortSignal?.aborted + ? options.abortSignal.reason + : undefined; const errorMessage = - error instanceof Error ? error.message : String(error); + abortReason instanceof Error + ? abortReason.message + : abortReason != null + ? String(abortReason) + : error instanceof Error + ? error.message + : String(error); + const partialMessages: ModelMessage[] = [ + { role: "user", content: message }, + ...lastCompletedStepMessages, + ...this.buildPartialAssistantMessages(pendingStepToolCalls), + ]; + const partialToolCalls = [ + ...completedToolCalls, + ...pendingStepToolCalls.map((toolCall) => ({ + toolName: toolCall.toolName, + arguments: toolCall.arguments, + })), + ]; + const totalTokens = partialInputTokens + partialOutputTokens; - this.lastResult = PromptResult.error( - errorMessage, - { + this.lastResult = PromptResult.from({ + prompt: message, + messages: partialMessages, + text: lastCompletedStepText, + toolCalls: partialToolCalls, + usage: { + inputTokens: partialInputTokens, + outputTokens: partialOutputTokens, + totalTokens, + }, + latency: { e2eMs, llmMs: totalLlmMs, mcpMs: totalMcpMs, }, - message, - { provider: this._parsedProvider, model: this._parsedModel } - ); + error: errorMessage, + provider: this._parsedProvider, + model: this._parsedModel, + widgetSnapshots: Array.from(widgetSnapshots.values()), + }); this.promptHistory.push(this.lastResult); return this.lastResult; } diff --git a/sdk/tests/EvalTest.test.ts b/sdk/tests/EvalTest.test.ts index 91aa88e25..a65c94002 100644 --- a/sdk/tests/EvalTest.test.ts +++ b/sdk/tests/EvalTest.test.ts @@ -36,13 +36,13 @@ function createMockPromptResult(options: { // Create a mock TestAgent with prompt history tracking function createMockAgent( - promptFn: (message: string) => Promise + promptFn: (message: string, options?: any) => Promise ): TestAgent { const createAgent = (): TestAgent => { let promptHistory: PromptResult[] = []; return { - prompt: async (message: string) => { - const result = await promptFn(message); + prompt: async (message: string, options?: any) => { + const result = await promptFn(message, options); promptHistory.push(result); return result; }, @@ -467,7 +467,7 @@ describe("EvalTest", () => { describe("timeout handling", () => { it("should timeout after timeoutMs", async () => { const agent = createMockAgent(async () => { - await new Promise((resolve) => setTimeout(resolve, 200)); + await new Promise((resolve) => setTimeout(resolve, 2000)); return createMockPromptResult({}); }); @@ -506,6 +506,153 @@ describe("EvalTest", () => { const result = await test.run(agent, { iterations: 1 }); expect(result.successes).toBe(1); }); + + it("should pass if a timed-out prompt captured the expected tool call", async () => { + const agent = createMockAgent(async (message, options) => { + await new Promise((resolve) => { + options?.abortSignal?.addEventListener("abort", () => resolve(), { + once: true, + }); + }); + + return createMockPromptResult({ + prompt: message, + toolsCalled: ["add"], + tokens: 0, + error: "Operation timed out after 50ms", + }); + }); + + const test = new EvalTest({ + name: "timeout-partial-pass", + test: async (agent) => { + const result = await agent.prompt("Add 2 and 3"); + return result.hasToolCall("add"); + }, + }); + + const result = await test.run(agent, { + iterations: 1, + timeoutMs: 50, + concurrency: 1, + }); + + expect(result.successes).toBe(1); + expect(result.failures).toBe(0); + expect(result.iterationDetails[0].passed).toBe(true); + expect(result.iterationDetails[0].error).toBeUndefined(); + expect(result.iterationDetails[0].prompts).toHaveLength(1); + expect(result.iterationDetails[0].prompts?.[0].hasToolCall("add")).toBe( + true + ); + expect(result.iterationDetails[0].prompts?.[0].hasError()).toBe(true); + }); + + it("should preserve earlier prompts and metrics when a later prompt times out", async () => { + let promptCount = 0; + const agent = createMockAgent(async (message, options) => { + promptCount++; + + if (promptCount === 1) { + return createMockPromptResult({ + prompt: message, + toolsCalled: ["lookup"], + tokens: 50, + latency: { e2eMs: 20, llmMs: 15, mcpMs: 5 }, + }); + } + + await new Promise((resolve) => { + options?.abortSignal?.addEventListener("abort", () => resolve(), { + once: true, + }); + }); + + return createMockPromptResult({ + prompt: message, + toolsCalled: ["add"], + tokens: 0, + latency: { e2eMs: 50, llmMs: 10, mcpMs: 40 }, + error: "Operation timed out after 50ms", + }); + }); + + const test = new EvalTest({ + name: "multi-turn-timeout", + test: async (agent) => { + const first = await agent.prompt("First"); + const second = await agent.prompt("Second"); + return first.hasToolCall("lookup") && second.hasToolCall("add"); + }, + }); + + const result = await test.run(agent, { + iterations: 1, + timeoutMs: 50, + concurrency: 1, + }); + + expect(result.successes).toBe(1); + expect(result.iterationDetails[0].prompts).toHaveLength(2); + expect( + result.iterationDetails[0].prompts?.map((prompt) => prompt.getPrompt()) + ).toEqual(["First", "Second"]); + expect(result.iterationDetails[0].tokens).toEqual({ + total: 50, + input: 25, + output: 25, + }); + expect(result.iterationDetails[0].latencies).toEqual([ + { e2eMs: 20, llmMs: 15, mcpMs: 5 }, + { e2eMs: 50, llmMs: 10, mcpMs: 40 }, + ]); + }); + + it("should fail after the hard-timeout grace if a prompt ignores abort but preserve captured history", async () => { + const createHungAgent = (): TestAgent => { + let promptHistory: PromptResult[] = []; + + return { + prompt: async (message: string) => { + promptHistory.push( + createMockPromptResult({ + prompt: message, + toolsCalled: ["add"], + tokens: 0, + }) + ); + + return await new Promise(() => {}); + }, + resetPromptHistory: () => { + promptHistory = []; + }, + getPromptHistory: () => [...promptHistory], + withOptions: () => createHungAgent(), + } as unknown as TestAgent; + }; + + const test = new EvalTest({ + name: "hung-timeout", + test: async (agent) => { + const result = await agent.prompt("Add 2 and 3"); + return result.hasToolCall("add"); + }, + }); + + const result = await test.run(createHungAgent(), { + iterations: 1, + timeoutMs: 25, + concurrency: 1, + }); + + expect(result.failures).toBe(1); + expect(result.iterationDetails[0].error).toContain("timed out"); + expect(result.iterationDetails[0].prompts).toHaveLength(1); + expect(result.iterationDetails[0].prompts?.[0].hasToolCall("add")).toBe( + true + ); + }, 5000); }); describe("progress callback", () => { diff --git a/sdk/tests/TestAgent.test.ts b/sdk/tests/TestAgent.test.ts index 7c1045fdf..6f7f0418e 100644 --- a/sdk/tests/TestAgent.test.ts +++ b/sdk/tests/TestAgent.test.ts @@ -6,6 +6,10 @@ import type { Tool } from "../src/mcp-client-manager/types"; // Mock the ai module jest.mock("ai", () => ({ generateText: jest.fn(), + hasToolCall: jest.fn((name: string) => ({ + type: "hasToolCall", + value: name, + })), stepCountIs: jest.fn((n: number) => ({ type: "stepCount", value: n })), dynamicTool: jest.fn((config: any) => ({ ...config, @@ -19,12 +23,13 @@ jest.mock("../src/model-factory", () => ({ createModelFromString: jest.fn(() => ({})), })); -import { generateText, jsonSchema, stepCountIs } from "ai"; +import { generateText, hasToolCall, jsonSchema, stepCountIs } from "ai"; import { createModelFromString } from "../src/model-factory"; const mockGenerateText = generateText as jest.MockedFunction< typeof generateText >; +const mockHasToolCall = hasToolCall as jest.MockedFunction; const mockStepCountIs = stepCountIs as jest.MockedFunction; const mockCreateModel = createModelFromString as jest.MockedFunction< typeof createModelFromString @@ -508,6 +513,86 @@ describe("TestAgent", () => { expect(result.getError()).toBe("String error"); }); + it("should preserve partial tool calls and messages when aborted mid-tool", async () => { + let notifyExecuteStarted: (() => void) | undefined; + const executeStarted = new Promise((resolve) => { + notifyExecuteStarted = resolve; + }); + const abortableToolSet: ToolSet = { + add: { + description: "Add two numbers", + inputSchema: jsonSchema({ + type: "object", + properties: { + a: { type: "number" }, + b: { type: "number" }, + }, + required: ["a", "b"], + }), + execute: jest.fn(async (_args, options) => { + notifyExecuteStarted?.(); + await new Promise((_, reject) => { + options?.abortSignal?.addEventListener( + "abort", + () => reject(new Error("tool execution aborted")), + { once: true } + ); + }); + }), + }, + }; + + mockGenerateText.mockImplementationOnce(async (params: any) => { + await params.tools.add.execute( + { a: 2, b: 3 }, + { + toolCallId: "call-1", + abortSignal: params.abortSignal, + } + ); + + return { + text: "unreachable", + steps: [], + usage: { inputTokens: 1, outputTokens: 1, totalTokens: 2 }, + } as any; + }); + + const agent = new TestAgent({ + tools: abortableToolSet, + model: "openai/gpt-4o", + apiKey: "test-api-key", + }); + const abortController = new AbortController(); + const promptPromise = agent.prompt("Add 2 and 3", { + abortSignal: abortController.signal, + }); + + await executeStarted; + abortController.abort(new Error("Prompt aborted")); + + const result = await promptPromise; + + expect(result.hasError()).toBe(true); + expect(result.getError()).toBe("Prompt aborted"); + expect(result.hasToolCall("add")).toBe(true); + expect(result.getToolArguments("add")).toEqual({ a: 2, b: 3 }); + expect(result.getMessages()).toEqual([ + { role: "user", content: "Add 2 and 3" }, + { + role: "assistant", + content: [ + { + type: "tool-call", + toolCallId: "call-1", + toolName: "add", + input: { a: 2, b: 3 }, + }, + ], + }, + ]); + }); + it("should call createModelFromString with correct options", async () => { mockGenerateText.mockResolvedValueOnce({ text: "OK", @@ -999,6 +1084,179 @@ describe("TestAgent", () => { }); }); + describe("stopAfterToolCall", () => { + it("should append hasToolCall stop conditions for targeted tools", async () => { + const guard = { kind: "max-step-guard" } as any; + const stopOnAdd = { kind: "stop-on-add" } as any; + mockStepCountIs.mockReturnValueOnce(guard); + mockHasToolCall.mockReturnValueOnce(stopOnAdd); + mockGenerateText.mockResolvedValueOnce({ + text: "Done", + steps: [], + usage: { inputTokens: 1, outputTokens: 1, totalTokens: 2 }, + } as any); + + const agent = new TestAgent({ + tools: mockToolSet, + model: "openai/gpt-4o", + apiKey: "test-key", + }); + + await agent.prompt("Add 2 and 3", { + stopAfterToolCall: "add", + }); + + const callArgs = mockGenerateText.mock.calls[0][0] as any; + expect(callArgs.stopWhen).toEqual([guard, stopOnAdd]); + expect(mockHasToolCall).toHaveBeenCalledWith("add"); + }); + + it("should short-circuit the targeted tool and preserve tool arguments", async () => { + const addExecute = jest.fn( + async (args: { a: number; b: number }) => args.a + args.b + ); + const subtractExecute = jest.fn( + async (args: { a: number; b: number }) => args.a - args.b + ); + const tools: ToolSet = { + add: { + description: "Add two numbers", + inputSchema: jsonSchema({ + type: "object", + properties: { + a: { type: "number" }, + b: { type: "number" }, + }, + required: ["a", "b"], + }), + execute: addExecute, + }, + subtract: { + description: "Subtract two numbers", + inputSchema: jsonSchema({ + type: "object", + properties: { + a: { type: "number" }, + b: { type: "number" }, + }, + required: ["a", "b"], + }), + execute: subtractExecute, + }, + }; + + mockGenerateText.mockImplementationOnce(async (params: any) => { + const addResult = await params.tools.add.execute( + { a: 2, b: 3 }, + { + toolCallId: "call-add", + abortSignal: { throwIfAborted: jest.fn() }, + } + ); + const subtractResult = await params.tools.subtract.execute( + { a: 5, b: 2 }, + { + toolCallId: "call-subtract", + abortSignal: { throwIfAborted: jest.fn() }, + } + ); + + expect(addResult).toEqual({ + content: [{ type: "text", text: "[skipped by stopAfterToolCall]" }], + }); + expect(subtractResult).toBe(3); + + return { + text: "Done", + steps: [ + { + toolCalls: [ + { + type: "tool-call", + toolCallId: "call-add", + toolName: "add", + input: { a: 2, b: 3 }, + }, + { + type: "tool-call", + toolCallId: "call-subtract", + toolName: "subtract", + input: { a: 5, b: 2 }, + }, + ], + }, + ], + usage: { inputTokens: 6, outputTokens: 4, totalTokens: 10 }, + } as any; + }); + + const agent = new TestAgent({ + tools, + model: "openai/gpt-4o", + apiKey: "test-key", + }); + const result = await agent.prompt("Do some math", { + stopAfterToolCall: "add", + }); + + expect(addExecute).not.toHaveBeenCalled(); + expect(subtractExecute).toHaveBeenCalledTimes(1); + expect(result.hasToolCall("add")).toBe(true); + expect(result.getToolArguments("add")).toEqual({ a: 2, b: 3 }); + expect(result.hasToolCall("subtract")).toBe(true); + }); + + it("should skip widget snapshot capture for short-circuited tools", async () => { + const tools = createMcpAppToolSet(); + const createViewExecute = (tools.create_view as any).execute as jest.Mock; + const mockManager = { + getToolMetadata: jest.fn().mockReturnValue({ + ui: { resourceUri: "ui://widget/create-view.html" }, + }), + readResource: jest.fn(), + }; + + mockGenerateText.mockImplementationOnce(async (params: any) => { + await params.tools.create_view.execute( + { title: "Flow 1" }, + { toolCallId: "call-1", abortSignal: { throwIfAborted: jest.fn() } } + ); + + return { + text: "Done", + steps: [ + { + toolCalls: [ + { + type: "tool-call", + toolCallId: "call-1", + toolName: "create_view", + input: { title: "Flow 1" }, + }, + ], + }, + ], + usage: { inputTokens: 4, outputTokens: 2, totalTokens: 6 }, + } as any; + }); + + const agent = new TestAgent({ + tools, + model: "openai/gpt-4o", + apiKey: "test-key", + mcpClientManager: mockManager as any, + }); + const result = await agent.prompt("Create a view", { + stopAfterToolCall: "create_view", + }); + + expect(result.getWidgetSnapshots()).toEqual([]); + expect(createViewExecute).not.toHaveBeenCalled(); + expect(mockManager.getToolMetadata).not.toHaveBeenCalled(); + expect(mockManager.readResource).not.toHaveBeenCalled(); + }); + }); + describe("timeout", () => { it("should pass through a numeric timeout", async () => { mockGenerateText.mockResolvedValueOnce({ @@ -1044,6 +1302,67 @@ describe("TestAgent", () => { }); }); + it("should pass through timeoutMs as a numeric timeout", async () => { + mockGenerateText.mockResolvedValueOnce({ + text: "OK", + steps: [], + usage: { inputTokens: 1, outputTokens: 1, totalTokens: 2 }, + } as any); + + const agent = new TestAgent({ + tools: mockToolSet, + model: "openai/gpt-4o", + apiKey: "test-key", + }); + + await agent.prompt("Test", { timeoutMs: 2500 }); + + const callArgs = mockGenerateText.mock.calls[0][0] as any; + expect(callArgs.timeout).toBe(2500); + }); + + it("should prefer timeout over timeoutMs when both are provided", async () => { + mockGenerateText.mockResolvedValueOnce({ + text: "OK", + steps: [], + usage: { inputTokens: 1, outputTokens: 1, totalTokens: 2 }, + } as any); + + const agent = new TestAgent({ + tools: mockToolSet, + model: "openai/gpt-4o", + apiKey: "test-key", + }); + + await agent.prompt("Test", { + timeout: { totalMs: 5000, stepMs: 1000 }, + timeoutMs: 2500, + }); + + const callArgs = mockGenerateText.mock.calls[0][0] as any; + expect(callArgs.timeout).toEqual({ totalMs: 5000, stepMs: 1000 }); + }); + + it("should pass through abortSignal", async () => { + mockGenerateText.mockResolvedValueOnce({ + text: "OK", + steps: [], + usage: { inputTokens: 1, outputTokens: 1, totalTokens: 2 }, + } as any); + + const agent = new TestAgent({ + tools: mockToolSet, + model: "openai/gpt-4o", + apiKey: "test-key", + }); + const abortController = new AbortController(); + + await agent.prompt("Test", { abortSignal: abortController.signal }); + + const callArgs = mockGenerateText.mock.calls[0][0] as any; + expect(callArgs.abortSignal).toBe(abortController.signal); + }); + it("should omit timeout when it is not set", async () => { mockGenerateText.mockResolvedValueOnce({ text: "OK",