From a1b4bcc66e73d388a66ce4695211f72c31ffafd5 Mon Sep 17 00:00:00 2001 From: Thomas Kosiewski Date: Tue, 23 Dec 2025 20:02:15 +0100 Subject: [PATCH] testutil: export harness + refactor integration tests Add exported github.com/coder/aibridge/testutil helpers (txtar fixtures, upstream/mcp/bridge harnesses) and migrate integration tests to use them. Fixes #73 Change-Id: I7cdf803a08643a6c6683044bb576fc3da4622d9a Signed-off-by: Thomas Kosiewski --- bridge_integration_injected_tools_test.go | 258 ++++ bridge_integration_test.go | 1313 ++++----------------- metrics_integration_test.go | 176 ++- testutil/bridge_server.go | 111 ++ testutil/doc.go | 15 + testutil/fixture_llm.go | 118 ++ testutil/fixture_txtar.go | 79 ++ testutil/helpers.go | 28 + testutil/http_reflector.go | 60 + testutil/inspector.go | 81 ++ testutil/json.go | 32 + testutil/mcp_server.go | 122 ++ testutil/recorder_spy.go | 142 +++ testutil/upstream_server.go | 261 ++++ trace_integration_test.go | 246 ++-- 15 files changed, 1723 insertions(+), 1319 deletions(-) create mode 100644 bridge_integration_injected_tools_test.go create mode 100644 testutil/bridge_server.go create mode 100644 testutil/doc.go create mode 100644 testutil/fixture_llm.go create mode 100644 testutil/fixture_txtar.go create mode 100644 testutil/helpers.go create mode 100644 testutil/http_reflector.go create mode 100644 testutil/inspector.go create mode 100644 testutil/json.go create mode 100644 testutil/mcp_server.go create mode 100644 testutil/recorder_spy.go create mode 100644 testutil/upstream_server.go diff --git a/bridge_integration_injected_tools_test.go b/bridge_integration_injected_tools_test.go new file mode 100644 index 0000000..b7aa24c --- /dev/null +++ b/bridge_integration_injected_tools_test.go @@ -0,0 +1,258 @@ +package aibridge_test + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "testing" + "time" + + "cdr.dev/slog" + "cdr.dev/slog/sloggers/slogtest" + "github.com/anthropics/anthropic-sdk-go" + "github.com/anthropics/anthropic-sdk-go/packages/ssestream" + "github.com/coder/aibridge" + "github.com/coder/aibridge/mcp" + "github.com/coder/aibridge/testutil" + "github.com/openai/openai-go/v2" + oaissestream "github.com/openai/openai-go/v2/packages/ssestream" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.opentelemetry.io/otel/trace" +) + +type injectedToolHarness struct { + Recorder *testutil.RecorderSpy + MCP *testutil.MCPServer + MCPProxiers map[string]mcp.ServerProxier + Upstream *testutil.UpstreamServer + Bridge *testutil.BridgeServer + Inspector *testutil.Inspector + Response *http.Response + + RequestBody []byte + RequestPath string +} + +func runInjectedToolTest(t *testing.T, providerName string, fixture []byte, streaming bool, tracer trace.Tracer, makeProviders func(upstreamURL string) []aibridge.Provider) injectedToolHarness { + t.Helper() + + if tracer == nil { + tracer = testTracer + } + ctx, cancel := context.WithTimeout(t.Context(), 30*time.Second) + t.Cleanup(cancel) + + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) + + // Fixture-driven upstream. + arc := testutil.MustParseTXTAR(t, fixture) + llm := testutil.MustLLMFixture(t, arc) + upstream := testutil.NewUpstreamServer(t, ctx, llm) + + // MCP server + proxies. + mcpSrv := testutil.NewMCPServer(t, testutil.DefaultCoderToolNames()) + mcpProxiers := mcpSrv.Proxiers(t, "coder", logger, tracer) + + recorder := &testutil.RecorderSpy{} + bridge := testutil.NewBridgeServer(t, testutil.BridgeConfig{ + Ctx: ctx, + ActorID: userID, + Providers: makeProviders(upstream.URL), + Recorder: recorder, + MCPProxiers: mcpProxiers, + Logger: logger, + Tracer: tracer, + }) + + reqBody := llm.MustRequestBody(t, streaming) + req := bridge.NewProviderRequest(t, providerName, reqBody) + + resp, err := bridge.Client.Do(req) + require.NoError(t, err) + require.Equal(t, http.StatusOK, resp.StatusCode) + t.Cleanup(func() { _ = resp.Body.Close() }) + + // Injected tool tests must always produce exactly 2 upstream calls. + upstream.RequireCallCountEventually(t, 2) + + inspector := testutil.NewInspector(recorder, mcpSrv, upstream) + + return injectedToolHarness{ + Recorder: recorder, + MCP: mcpSrv, + MCPProxiers: mcpProxiers, + Upstream: upstream, + Bridge: bridge, + Inspector: inspector, + Response: resp, + + RequestBody: reqBody, + RequestPath: req.URL.Path, + } +} + +func TestAnthropicInjectedTools(t *testing.T) { + t.Parallel() + + for _, streaming := range []bool{true, false} { + t.Run(fmt.Sprintf("streaming=%v", streaming), func(t *testing.T) { + t.Parallel() + + h := runInjectedToolTest(t, aibridge.ProviderAnthropic, antSingleInjectedTool, streaming, testTracer, func(upstreamURL string) []aibridge.Provider { + return []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(upstreamURL, apiKey), nil)} + }) + resp := h.Response + + // Ensure expected tool was invoked with expected input. + h.Inspector.RequireToolCalledOnceWithArgs(t, testutil.ToolCoderListWorkspaces, map[string]any{"owner": "admin"}) + + var ( + content *anthropic.ContentBlockUnion + message anthropic.Message + ) + if streaming { + // Parse the response stream. + decoder := ssestream.NewDecoder(resp) + stream := ssestream.NewStream[anthropic.MessageStreamEventUnion](decoder, nil) + for stream.Next() { + event := stream.Current() + require.NoError(t, message.Accumulate(event), "accumulate event") + } + + require.NoError(t, stream.Err(), "stream error") + require.Len(t, message.Content, 2) + + content = &message.Content[1] + } else { + // Parse & unmarshal the response. + body, err := io.ReadAll(resp.Body) + require.NoError(t, err, "read response body") + + require.NoError(t, json.Unmarshal(body, &message), "unmarshal response") + require.GreaterOrEqual(t, len(message.Content), 1) + + content = &message.Content[0] + } + + // Ensure tool returned expected value. + require.NotNil(t, content) + require.Contains(t, content.Text, "dd711d5c-83c6-4c08-a0af-b73055906e8c") // The ID of the workspace to be returned. + + // Check the token usage from the client's perspective. + // + // We overwrite the final message_delta which is relayed to the client to include the + // accumulated tokens but currently the SDK only supports accumulating output tokens + // for message_delta events. + // + // For non-streaming requests the token usage is also overwritten and should be faithfully + // represented in the response. + // + // See https://github.com/anthropics/anthropic-sdk-go/blob/v1.12.0/message.go#L2619-L2622 + if !streaming { + assert.EqualValues(t, 15308, message.Usage.InputTokens) + } + assert.EqualValues(t, 204, message.Usage.OutputTokens) + + // Ensure tokens used during injected tool invocation are accounted for. + tokenUsages := h.Recorder.RecordedTokenUsages() + assert.EqualValues(t, 15308, testutil.TotalInputTokens(tokenUsages)) + assert.EqualValues(t, 204, testutil.TotalOutputTokens(tokenUsages)) + + // Ensure we received exactly one prompt. + promptUsages := h.Recorder.RecordedPromptUsages() + require.Len(t, promptUsages, 1) + }) + } +} + +func TestOpenAIInjectedTools(t *testing.T) { + t.Parallel() + + for _, streaming := range []bool{true, false} { + t.Run(fmt.Sprintf("streaming=%v", streaming), func(t *testing.T) { + t.Parallel() + + h := runInjectedToolTest(t, aibridge.ProviderOpenAI, oaiSingleInjectedTool, streaming, testTracer, func(upstreamURL string) []aibridge.Provider { + return []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(upstreamURL, apiKey))} + }) + resp := h.Response + + // Ensure expected tool was invoked with expected input. + h.Inspector.RequireToolCalledOnceWithArgs(t, testutil.ToolCoderListWorkspaces, map[string]any{"owner": "admin"}) + + var ( + content *openai.ChatCompletionChoice + message openai.ChatCompletion + ) + if streaming { + // Parse the response stream. + decoder := oaissestream.NewDecoder(resp) + stream := oaissestream.NewStream[openai.ChatCompletionChunk](decoder, nil) + var acc openai.ChatCompletionAccumulator + detectedToolCalls := make(map[string]struct{}) + for stream.Next() { + chunk := stream.Current() + acc.AddChunk(chunk) + + if len(chunk.Choices) == 0 { + continue + } + + for _, c := range chunk.Choices { + if len(c.Delta.ToolCalls) == 0 { + continue + } + + for _, t := range c.Delta.ToolCalls { + if t.Function.Name == "" { + continue + } + + detectedToolCalls[t.Function.Name] = struct{}{} + } + } + } + + // Verify that no injected tool call events (or partials thereof) were sent to the client. + require.Len(t, detectedToolCalls, 0) + + message = acc.ChatCompletion + require.NoError(t, stream.Err(), "stream error") + } else { + // Parse & unmarshal the response. + body, err := io.ReadAll(resp.Body) + require.NoError(t, err, "read response body") + require.NoError(t, json.Unmarshal(body, &message), "unmarshal response") + + // Verify that no injected tools were sent to the client. + require.GreaterOrEqual(t, len(message.Choices), 1) + require.Len(t, message.Choices[0].Message.ToolCalls, 0) + } + + require.GreaterOrEqual(t, len(message.Choices), 1) + content = &message.Choices[0] + + // Ensure tool returned expected value. + require.NotNil(t, content) + require.Contains(t, content.Message.Content, "dd711d5c-83c6-4c08-a0af-b73055906e8c") // The ID of the workspace to be returned. + + // Check the token usage from the client's perspective. + // This *should* work but the openai SDK doesn't accumulate the prompt token details :(. + // See https://github.com/openai/openai-go/blob/v2.7.0/streamaccumulator.go#L145-L147. + // assert.EqualValues(t, 5047, message.Usage.PromptTokens-message.Usage.PromptTokensDetails.CachedTokens) + assert.EqualValues(t, 105, message.Usage.CompletionTokens) + + // Ensure tokens used during injected tool invocation are accounted for. + tokenUsages := h.Recorder.RecordedTokenUsages() + require.EqualValues(t, 5047, testutil.TotalInputTokens(tokenUsages)) + require.EqualValues(t, 105, testutil.TotalOutputTokens(tokenUsages)) + + // Ensure we received exactly one prompt. + promptUsages := h.Recorder.RecordedPromptUsages() + require.Len(t, promptUsages, 1) + }) + } +} diff --git a/bridge_integration_test.go b/bridge_integration_test.go index f73138e..39215e4 100644 --- a/bridge_integration_test.go +++ b/bridge_integration_test.go @@ -1,20 +1,15 @@ package aibridge_test import ( - "bufio" "bytes" "context" _ "embed" "encoding/json" "fmt" "io" - "net" "net/http" "net/http/httptest" - "slices" "strings" - "sync" - "sync/atomic" "testing" "time" @@ -25,19 +20,15 @@ import ( "github.com/anthropics/anthropic-sdk-go/shared/constant" "github.com/coder/aibridge" "github.com/coder/aibridge/mcp" + "github.com/coder/aibridge/testutil" "github.com/google/uuid" - mcplib "github.com/mark3labs/mcp-go/mcp" - "github.com/mark3labs/mcp-go/server" "github.com/openai/openai-go/v2" oaissestream "github.com/openai/openai-go/v2/packages/ssestream" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/tidwall/gjson" - "github.com/tidwall/sjson" "go.opentelemetry.io/otel" - "go.opentelemetry.io/otel/trace" "go.uber.org/goleak" - "golang.org/x/tools/txtar" ) var ( @@ -71,13 +62,6 @@ var ( ) const ( - fixtureRequest = "request" - fixtureStreamingResponse = "streaming" - fixtureNonStreamingResponse = "non-streaming" - fixtureStreamingToolResponse = "streaming/tool-call" - fixtureNonStreamingToolResponse = "non-streaming/tool-call" - fixtureResponse = "response" - apiKey = "api-key" userID = "ae235cc1-9f8f-417d-a636-a7b170bac62e" ) @@ -113,45 +97,31 @@ func TestAnthropicMessages(t *testing.T) { t.Run(fmt.Sprintf("%s/streaming=%v", t.Name(), tc.streaming), func(t *testing.T) { t.Parallel() - arc := txtar.Parse(antSingleBuiltinTool) + arc := testutil.MustParseTXTAR(t, antSingleBuiltinTool) t.Logf("%s: %s", t.Name(), arc.Comment) - files := filesMap(arc) - require.Len(t, files, 3) - require.Contains(t, files, fixtureRequest) - require.Contains(t, files, fixtureStreamingResponse) - require.Contains(t, files, fixtureNonStreamingResponse) - - reqBody := files[fixtureRequest] + llm := testutil.MustLLMFixture(t, arc) - // Add the stream param to the request. - newBody, err := setJSON(reqBody, "stream", tc.streaming) - require.NoError(t, err) - reqBody = newBody - - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + ctx, cancel := context.WithTimeout(t.Context(), 30*time.Second) t.Cleanup(cancel) - srv := newMockServer(ctx, t, files, nil) - t.Cleanup(srv.Close) - recorderClient := &mockRecorderClient{} + upstream := testutil.NewUpstreamServer(t, ctx, llm) - logger := slogtest.Make(t, &slogtest.Options{}).Leveled(slog.LevelDebug) - providers := []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(srv.URL, apiKey), nil)} - b, err := aibridge.NewRequestBridge(ctx, providers, recorderClient, mcp.NewServerProxyManager(nil, testTracer), logger, nil, testTracer) - require.NoError(t, err) + recorderClient := &testutil.RecorderSpy{} - mockSrv := httptest.NewUnstartedServer(b) - t.Cleanup(mockSrv.Close) - mockSrv.Config.BaseContext = func(_ net.Listener) context.Context { - return aibridge.AsActor(ctx, userID, nil) - } - mockSrv.Start() + logger := slogtest.Make(t, &slogtest.Options{}).Leveled(slog.LevelDebug) + bridgeSrv := testutil.NewBridgeServer(t, testutil.BridgeConfig{ + Ctx: ctx, + ActorID: userID, + Providers: []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(upstream.URL, apiKey), nil)}, + Recorder: recorderClient, + Logger: logger, + Tracer: testTracer, + }) - // Make API call to aibridge for Anthropic /v1/messages - req := createAnthropicMessagesReq(t, mockSrv.URL, reqBody) - client := &http.Client{} - resp, err := client.Do(req) + reqBody := llm.MustRequestBody(t, tc.streaming) + req := bridgeSrv.NewProviderRequest(t, aibridge.ProviderAnthropic, reqBody) + resp, err := bridgeSrv.Client.Do(req) require.NoError(t, err) require.Equal(t, http.StatusOK, resp.StatusCode) defer resp.Body.Close() @@ -174,8 +144,8 @@ func TestAnthropicMessages(t *testing.T) { tokenUsages := recorderClient.RecordedTokenUsages() require.Len(t, tokenUsages, expectedTokenRecordings) - assert.EqualValues(t, tc.expectedInputTokens, calculateTotalInputTokens(tokenUsages), "input tokens miscalculated") - assert.EqualValues(t, tc.expectedOutputTokens, calculateTotalOutputTokens(tokenUsages), "output tokens miscalculated") + assert.EqualValues(t, tc.expectedInputTokens, testutil.TotalInputTokens(tokenUsages), "input tokens miscalculated") + assert.EqualValues(t, tc.expectedOutputTokens, testutil.TotalOutputTokens(tokenUsages), "output tokens miscalculated") toolUsages := recorderClient.RecordedToolUsages() require.Len(t, toolUsages, 1) @@ -190,7 +160,7 @@ func TestAnthropicMessages(t *testing.T) { require.Len(t, promptUsages, 1) assert.Equal(t, "read the foo file", promptUsages[0].Prompt) - recorderClient.verifyAllInterceptionsEnded(t) + recorderClient.RequireAllInterceptionsEnded(t) }) } }) @@ -202,11 +172,10 @@ func TestAWSBedrockIntegration(t *testing.T) { t.Run("invalid config", func(t *testing.T) { t.Parallel() - arc := txtar.Parse(antSingleBuiltinTool) - files := filesMap(arc) - reqBody := files[fixtureRequest] + arc := testutil.MustParseTXTAR(t, antSingleBuiltinTool) + reqBody := arc.MustFile(t, testutil.FixtureRequest) - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + ctx, cancel := context.WithTimeout(t.Context(), 30*time.Second) t.Cleanup(cancel) // Invalid bedrock config - missing region @@ -218,22 +187,21 @@ func TestAWSBedrockIntegration(t *testing.T) { SmallFastModel: "test-haiku", } - recorderClient := &mockRecorderClient{} + recorderClient := &testutil.RecorderSpy{} logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug) - b, err := aibridge.NewRequestBridge(ctx, []aibridge.Provider{ - aibridge.NewAnthropicProvider(anthropicCfg("http://unused", apiKey), bedrockCfg), - }, recorderClient, mcp.NewServerProxyManager(nil, testTracer), logger, nil, testTracer) - require.NoError(t, err) - - mockSrv := httptest.NewUnstartedServer(b) - t.Cleanup(mockSrv.Close) - mockSrv.Config.BaseContext = func(_ net.Listener) context.Context { - return aibridge.AsActor(ctx, userID, nil) - } - mockSrv.Start() + bridgeSrv := testutil.NewBridgeServer(t, testutil.BridgeConfig{ + Ctx: ctx, + ActorID: userID, + Providers: []aibridge.Provider{ + aibridge.NewAnthropicProvider(anthropicCfg("http://unused", apiKey), bedrockCfg), + }, + Recorder: recorderClient, + Logger: logger, + Tracer: testTracer, + }) - req := createAnthropicMessagesReq(t, mockSrv.URL, reqBody) - resp, err := http.DefaultClient.Do(req) + req := bridgeSrv.NewProviderRequest(t, aibridge.ProviderAnthropic, reqBody) + resp, err := bridgeSrv.Client.Do(req) require.NoError(t, err) defer resp.Body.Close() @@ -249,60 +217,15 @@ func TestAWSBedrockIntegration(t *testing.T) { t.Run(fmt.Sprintf("%s/streaming=%v", t.Name(), streaming), func(t *testing.T) { t.Parallel() - arc := txtar.Parse(antSingleBuiltinTool) + arc := testutil.MustParseTXTAR(t, antSingleBuiltinTool) t.Logf("%s: %s", t.Name(), arc.Comment) - files := filesMap(arc) - require.Len(t, files, 3) - require.Contains(t, files, fixtureRequest) - - reqBody := files[fixtureRequest] + llm := testutil.MustLLMFixture(t, arc) - newBody, err := setJSON(reqBody, "stream", streaming) - require.NoError(t, err) - reqBody = newBody - - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + ctx, cancel := context.WithTimeout(t.Context(), 30*time.Second) t.Cleanup(cancel) - var receivedModelName string - var requestCount int - - // Create a mock server that intercepts requests to capture model name and return fixtures. - srv := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - requestCount++ - t.Logf("Mock server received request #%d: %s %s (streaming=%v)", requestCount, r.Method, r.URL.Path, streaming) - t.Logf("Request headers: %v", r.Header) - - // AWS Bedrock encodes the model name in the URL path: /model/{model-id}/invoke or /model/{model-id}/invoke-with-response-stream. - // Extract the model name from the path. - pathParts := strings.Split(r.URL.Path, "/") - if len(pathParts) >= 3 && pathParts[1] == "model" { - receivedModelName = pathParts[2] - t.Logf("Extracted model name from path: %s", receivedModelName) - } - - // Return appropriate fixture response. - var respBody []byte - if streaming { - respBody = files[fixtureStreamingResponse] - w.Header().Set("Content-Type", "text/event-stream") - w.Header().Set("Cache-Control", "no-cache") - w.Header().Set("Connection", "keep-alive") - } else { - respBody = files[fixtureNonStreamingResponse] - w.Header().Set("Content-Type", "application/json") - } - - w.WriteHeader(http.StatusOK) - _, _ = w.Write(respBody) - })) - - srv.Config.BaseContext = func(_ net.Listener) context.Context { - return ctx - } - srv.Start() - t.Cleanup(srv.Close) + upstream := testutil.NewUpstreamServer(t, ctx, llm) // Configure Bedrock with test credentials and model names. // The EndpointOverride will make requests go to the mock server instead of real AWS endpoints. @@ -312,47 +235,51 @@ func TestAWSBedrockIntegration(t *testing.T) { AccessKeySecret: "test-secret-key", Model: "danthropic", // This model should override the request's given one. SmallFastModel: "danthropic-mini", // Unused but needed for validation. - EndpointOverride: srv.URL, + EndpointOverride: upstream.URL, } - recorderClient := &mockRecorderClient{} + recorderClient := &testutil.RecorderSpy{} logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug) - b, err := aibridge.NewRequestBridge( - ctx, []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(srv.URL, apiKey), bedrockCfg)}, - recorderClient, mcp.NewServerProxyManager(nil, testTracer), logger, nil, testTracer) - require.NoError(t, err) - - mockBridgeSrv := httptest.NewUnstartedServer(b) - t.Cleanup(mockBridgeSrv.Close) - mockBridgeSrv.Config.BaseContext = func(_ net.Listener) context.Context { - return aibridge.AsActor(ctx, userID, nil) - } - mockBridgeSrv.Start() + bridgeSrv := testutil.NewBridgeServer(t, testutil.BridgeConfig{ + Ctx: ctx, + ActorID: userID, + Providers: []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(upstream.URL, apiKey), bedrockCfg)}, + Recorder: recorderClient, + Logger: logger, + Tracer: testTracer, + }) - // Make API call to aibridge for Anthropic /v1/messages, which will be routed via AWS Bedrock. - // We override the AWS Bedrock client to route requests through our mock server. - req := createAnthropicMessagesReq(t, mockBridgeSrv.URL, reqBody) - client := &http.Client{} - resp, err := client.Do(req) + reqBody := llm.MustRequestBody(t, streaming) + req := bridgeSrv.NewProviderRequest(t, aibridge.ProviderAnthropic, reqBody) + resp, err := bridgeSrv.Client.Do(req) require.NoError(t, err) + require.Equal(t, http.StatusOK, resp.StatusCode) defer resp.Body.Close() // For streaming responses, consume the body to allow the stream to complete. if streaming { - // Read the streaming response. _, err = io.ReadAll(resp.Body) require.NoError(t, err) } + reqs := upstream.Requests() + require.Len(t, reqs, 1) + + // AWS Bedrock encodes the model name in the URL path: /model/{model-id}/invoke or /model/{model-id}/invoke-with-response-stream. + pathParts := strings.Split(reqs[0].Path, "/") + receivedModelName := "" + if len(pathParts) >= 3 && pathParts[1] == "model" { + receivedModelName = pathParts[2] + } + // Verify that Bedrock-specific model name was used in the request to the mock server // and the interception data. - require.Equal(t, requestCount, 1) require.Equal(t, bedrockCfg.Model, receivedModelName) interceptions := recorderClient.RecordedInterceptions() require.Len(t, interceptions, 1) require.Equal(t, interceptions[0].Model, bedrockCfg.Model) - recorderClient.verifyAllInterceptionsEnded(t) + recorderClient.RequireAllInterceptionsEnded(t) }) } }) @@ -384,45 +311,32 @@ func TestOpenAIChatCompletions(t *testing.T) { t.Run(fmt.Sprintf("%s/streaming=%v", t.Name(), tc.streaming), func(t *testing.T) { t.Parallel() - arc := txtar.Parse(oaiSingleBuiltinTool) + arc := testutil.MustParseTXTAR(t, oaiSingleBuiltinTool) t.Logf("%s: %s", t.Name(), arc.Comment) - files := filesMap(arc) - require.Len(t, files, 3) - require.Contains(t, files, fixtureRequest) - require.Contains(t, files, fixtureStreamingResponse) - require.Contains(t, files, fixtureNonStreamingResponse) - - reqBody := files[fixtureRequest] + llm := testutil.MustLLMFixture(t, arc) - // Add the stream param to the request. - newBody, err := setJSON(reqBody, "stream", tc.streaming) - require.NoError(t, err) - reqBody = newBody - - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + ctx, cancel := context.WithTimeout(t.Context(), 30*time.Second) t.Cleanup(cancel) - srv := newMockServer(ctx, t, files, nil) - t.Cleanup(srv.Close) - recorderClient := &mockRecorderClient{} + upstream := testutil.NewUpstreamServer(t, ctx, llm) + + recorderClient := &testutil.RecorderSpy{} logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - providers := []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(srv.URL, apiKey))} - b, err := aibridge.NewRequestBridge(t.Context(), providers, recorderClient, mcp.NewServerProxyManager(nil, testTracer), logger, nil, testTracer) - require.NoError(t, err) + bridgeSrv := testutil.NewBridgeServer(t, testutil.BridgeConfig{ + Ctx: ctx, + ActorID: userID, + Providers: []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(upstream.URL, apiKey))}, + Recorder: recorderClient, + Logger: logger, + Tracer: testTracer, + }) - mockSrv := httptest.NewUnstartedServer(b) - t.Cleanup(mockSrv.Close) - mockSrv.Config.BaseContext = func(_ net.Listener) context.Context { - return aibridge.AsActor(ctx, userID, nil) - } - mockSrv.Start() - // Make API call to aibridge for OpenAI /v1/chat/completions - req := createOpenAIChatCompletionsReq(t, mockSrv.URL, reqBody) + reqBody := llm.MustRequestBody(t, tc.streaming) + req := bridgeSrv.NewProviderRequest(t, aibridge.ProviderOpenAI, reqBody) - client := &http.Client{} - resp, err := client.Do(req) + resp, err := bridgeSrv.Client.Do(req) require.NoError(t, err) require.Equal(t, http.StatusOK, resp.StatusCode) defer resp.Body.Close() @@ -443,8 +357,8 @@ func TestOpenAIChatCompletions(t *testing.T) { tokenUsages := recorderClient.RecordedTokenUsages() require.Len(t, tokenUsages, 1) - assert.EqualValues(t, tc.expectedInputTokens, calculateTotalInputTokens(tokenUsages), "input tokens miscalculated") - assert.EqualValues(t, tc.expectedOutputTokens, calculateTotalOutputTokens(tokenUsages), "output tokens miscalculated") + assert.EqualValues(t, tc.expectedInputTokens, testutil.TotalInputTokens(tokenUsages), "input tokens miscalculated") + assert.EqualValues(t, tc.expectedOutputTokens, testutil.TotalOutputTokens(tokenUsages), "output tokens miscalculated") toolUsages := recorderClient.RecordedToolUsages() require.Len(t, toolUsages, 1) @@ -457,7 +371,7 @@ func TestOpenAIChatCompletions(t *testing.T) { require.Len(t, promptUsages, 1) assert.Equal(t, "how large is the README.md file in my current path", promptUsages[0].Prompt) - recorderClient.verifyAllInterceptionsEnded(t) + recorderClient.RequireAllInterceptionsEnded(t) }) } }) @@ -471,7 +385,6 @@ func TestSimple(t *testing.T) { fixture []byte configureFunc func(string, aibridge.Recorder) (*aibridge.RequestBridge, error) getResponseIDFunc func(bool, *http.Response) (string, error) - createRequest func(*testing.T, string, []byte) *http.Request expectedMsgID string }{ { @@ -510,7 +423,6 @@ func TestSimple(t *testing.T) { } return message.ID, nil }, - createRequest: createAnthropicMessagesReq, expectedMsgID: "msg_01Pvyf26bY17RcjmWfJsXGBn", }, { @@ -549,7 +461,6 @@ func TestSimple(t *testing.T) { } return message.ID, nil }, - createRequest: createOpenAIChatCompletionsReq, expectedMsgID: "chatcmpl-BwoiPTGRbKkY5rncfaM0s9KtWrq5N", }, } @@ -562,43 +473,30 @@ func TestSimple(t *testing.T) { t.Run(fmt.Sprintf("streaming=%v", streaming), func(t *testing.T) { t.Parallel() - arc := txtar.Parse(tc.fixture) + arc := testutil.MustParseTXTAR(t, tc.fixture) t.Logf("%s: %s", t.Name(), arc.Comment) - files := filesMap(arc) - require.Len(t, files, 3) - require.Contains(t, files, fixtureRequest) - require.Contains(t, files, fixtureStreamingResponse) - require.Contains(t, files, fixtureNonStreamingResponse) + llm := testutil.MustLLMFixture(t, arc) - reqBody := files[fixtureRequest] - - // Add the stream param to the request. - newBody, err := setJSON(reqBody, "stream", streaming) - require.NoError(t, err) - reqBody = newBody - - // Given: a mock API server and a Bridge through which the requests will flow. - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + ctx, cancel := context.WithTimeout(t.Context(), 30*time.Second) t.Cleanup(cancel) - srv := newMockServer(ctx, t, files, nil) - t.Cleanup(srv.Close) - recorderClient := &mockRecorderClient{} + upstream := testutil.NewUpstreamServer(t, ctx, llm) + + recorderClient := &testutil.RecorderSpy{} - b, err := tc.configureFunc(srv.URL, recorderClient) + bridge, err := tc.configureFunc(upstream.URL, recorderClient) require.NoError(t, err) - mockSrv := httptest.NewUnstartedServer(b) - t.Cleanup(mockSrv.Close) - mockSrv.Config.BaseContext = func(_ net.Listener) context.Context { - return aibridge.AsActor(ctx, userID, nil) - } - mockSrv.Start() - // When: calling the "API server" with the fixture's request body. - req := tc.createRequest(t, mockSrv.URL, reqBody) - client := &http.Client{} - resp, err := client.Do(req) + bridgeSrv := testutil.NewBridgeServer(t, testutil.BridgeConfig{ + Ctx: ctx, + ActorID: userID, + Handler: bridge, + }) + + reqBody := llm.MustRequestBody(t, streaming) + req := bridgeSrv.NewProviderRequest(t, tc.name, reqBody) + resp, err := bridgeSrv.Client.Do(req) require.NoError(t, err) require.Equal(t, http.StatusOK, resp.StatusCode) defer resp.Body.Close() @@ -627,18 +525,13 @@ func TestSimple(t *testing.T) { require.GreaterOrEqual(t, len(tokenUsages), 1) require.Equal(t, tokenUsages[0].MsgID, tc.expectedMsgID) - recorderClient.verifyAllInterceptionsEnded(t) + recorderClient.RequireAllInterceptionsEnded(t) }) } }) } } -func setJSON(in []byte, key string, val bool) ([]byte, error) { - out, err := sjson.Set(string(in), key, val) - return []byte(out), err -} - func TestFallthrough(t *testing.T) { t.Parallel() @@ -675,14 +568,11 @@ func TestFallthrough(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - arc := txtar.Parse(tc.fixture) + arc := testutil.MustParseTXTAR(t, tc.fixture) t.Logf("%s: %s", t.Name(), arc.Comment) - files := filesMap(arc) - require.Contains(t, files, fixtureResponse) - var receivedHeaders *http.Header - respBody := files[fixtureResponse] + respBody := arc.MustFile(t, testutil.FixtureResponse) upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path != "/v1/models" { t.Errorf("unexpected request path: %q", r.URL.Path) @@ -697,21 +587,20 @@ func TestFallthrough(t *testing.T) { })) t.Cleanup(upstream.Close) - recorderClient := &mockRecorderClient{} + recorderClient := &testutil.RecorderSpy{} provider, bridge := tc.configureFunc(upstream.URL, recorderClient) - bridgeSrv := httptest.NewUnstartedServer(bridge) - bridgeSrv.Config.BaseContext = func(_ net.Listener) context.Context { - return aibridge.AsActor(t.Context(), userID, nil) - } - bridgeSrv.Start() - t.Cleanup(bridgeSrv.Close) + bridgeSrv := testutil.NewBridgeServer(t, testutil.BridgeConfig{ + Ctx: t.Context(), + ActorID: userID, + Handler: bridge, + }) req, err := http.NewRequestWithContext(t.Context(), "GET", fmt.Sprintf("%s/%s/v1/models", bridgeSrv.URL, tc.name), nil) require.NoError(t, err) - resp, err := http.DefaultClient.Do(req) + resp, err := bridgeSrv.Client.Do(req) require.NoError(t, err) defer resp.Body.Close() @@ -734,310 +623,10 @@ func TestFallthrough(t *testing.T) { } } -// setupMCPServerProxiesForTest creates a mock MCP server, initializes the MCP bridge, and returns the tools -func setupMCPServerProxiesForTest(t *testing.T, tracer trace.Tracer) (map[string]mcp.ServerProxier, *callAccumulator) { - t.Helper() - - // Setup Coder MCP integration - srv, acc := createMockMCPSrv(t) - mcpSrv := httptest.NewServer(srv) - t.Cleanup(mcpSrv.Close) - - logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - proxy, err := mcp.NewStreamableHTTPServerProxy("coder", mcpSrv.URL, nil, nil, nil, logger, tracer) - require.NoError(t, err) - - // Initialize MCP client, fetch tools, and inject into bridge - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) - t.Cleanup(cancel) - require.NoError(t, proxy.Init(ctx)) - tools := proxy.ListTools() - require.NotEmpty(t, tools) - - return map[string]mcp.ServerProxier{proxy.Name(): proxy}, acc -} - type ( - configureFunc func(string, aibridge.Recorder, *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) - createRequestFunc func(*testing.T, string, []byte) *http.Request + configureFunc func(string, aibridge.Recorder, *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) ) -func TestAnthropicInjectedTools(t *testing.T) { - t.Parallel() - - for _, streaming := range []bool{true, false} { - t.Run(fmt.Sprintf("streaming=%v", streaming), func(t *testing.T) { - t.Parallel() - - configureFn := func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) { - logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - providers := []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)} - return aibridge.NewRequestBridge(t.Context(), providers, client, srvProxyMgr, logger, nil, testTracer) - } - - // Build the requirements & make the assertions which are common to all providers. - recorderClient, mcpCalls, _, resp := setupInjectedToolTest(t, antSingleInjectedTool, streaming, configureFn, createAnthropicMessagesReq) - - // Ensure expected tool was invoked with expected input. - toolUsages := recorderClient.RecordedToolUsages() - require.Len(t, toolUsages, 1) - require.Equal(t, mockToolName, toolUsages[0].Tool) - expected, err := json.Marshal(map[string]any{"owner": "admin"}) - require.NoError(t, err) - actual, err := json.Marshal(toolUsages[0].Args) - require.NoError(t, err) - require.EqualValues(t, expected, actual) - invocations := mcpCalls.getCallsByTool(mockToolName) - require.Len(t, invocations, 1) - actual, err = json.Marshal(invocations[0]) - require.NoError(t, err) - require.EqualValues(t, expected, actual) - - var ( - content *anthropic.ContentBlockUnion - message anthropic.Message - ) - if streaming { - // Parse the response stream. - decoder := ssestream.NewDecoder(resp) - stream := ssestream.NewStream[anthropic.MessageStreamEventUnion](decoder, nil) - for stream.Next() { - event := stream.Current() - require.NoError(t, message.Accumulate(event), "accumulate event") - } - - require.NoError(t, stream.Err(), "stream error") - require.Len(t, message.Content, 2) - - content = &message.Content[1] - } else { - // Parse & unmarshal the response. - body, err := io.ReadAll(resp.Body) - require.NoError(t, err, "read response body") - - require.NoError(t, json.Unmarshal(body, &message), "unmarshal response") - require.GreaterOrEqual(t, len(message.Content), 1) - - content = &message.Content[0] - } - - // Ensure tool returned expected value. - require.NotNil(t, content) - require.Contains(t, content.Text, "dd711d5c-83c6-4c08-a0af-b73055906e8c") // The ID of the workspace to be returned. - - // Check the token usage from the client's perspective. - // - // We overwrite the final message_delta which is relayed to the client to include the - // accumulated tokens but currently the SDK only supports accumulating output tokens - // for message_delta events. - // - // For non-streaming requests the token usage is also overwritten and should be faithfully - // represented in the response. - // - // See https://github.com/anthropics/anthropic-sdk-go/blob/v1.12.0/message.go#L2619-L2622 - if !streaming { - assert.EqualValues(t, 15308, message.Usage.InputTokens) - } - assert.EqualValues(t, 204, message.Usage.OutputTokens) - - // Ensure tokens used during injected tool invocation are accounted for. - tokenUsages := recorderClient.RecordedTokenUsages() - assert.EqualValues(t, 15308, calculateTotalInputTokens(tokenUsages)) - assert.EqualValues(t, 204, calculateTotalOutputTokens(tokenUsages)) - - // Ensure we received exactly one prompt. - promptUsages := recorderClient.RecordedPromptUsages() - require.Len(t, promptUsages, 1) - }) - } -} - -func TestOpenAIInjectedTools(t *testing.T) { - t.Parallel() - - for _, streaming := range []bool{true, false} { - t.Run(fmt.Sprintf("streaming=%v", streaming), func(t *testing.T) { - t.Parallel() - - configureFn := func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) { - logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - providers := []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))} - return aibridge.NewRequestBridge(t.Context(), providers, client, srvProxyMgr, logger, nil, testTracer) - } - - // Build the requirements & make the assertions which are common to all providers. - recorderClient, mcpCalls, _, resp := setupInjectedToolTest(t, oaiSingleInjectedTool, streaming, configureFn, createOpenAIChatCompletionsReq) - - // Ensure expected tool was invoked with expected input. - toolUsages := recorderClient.RecordedToolUsages() - require.Len(t, toolUsages, 1) - require.Equal(t, mockToolName, toolUsages[0].Tool) - expected, err := json.Marshal(map[string]any{"owner": "admin"}) - require.NoError(t, err) - actual, err := json.Marshal(toolUsages[0].Args) - require.NoError(t, err) - require.EqualValues(t, expected, actual) - invocations := mcpCalls.getCallsByTool(mockToolName) - require.Len(t, invocations, 1) - actual, err = json.Marshal(invocations[0]) - require.NoError(t, err) - require.EqualValues(t, expected, actual) - - var ( - content *openai.ChatCompletionChoice - message openai.ChatCompletion - ) - if streaming { - // Parse the response stream. - decoder := oaissestream.NewDecoder(resp) - stream := oaissestream.NewStream[openai.ChatCompletionChunk](decoder, nil) - var acc openai.ChatCompletionAccumulator - detectedToolCalls := make(map[string]struct{}) - for stream.Next() { - chunk := stream.Current() - acc.AddChunk(chunk) - - if len(chunk.Choices) == 0 { - continue - } - - for _, c := range chunk.Choices { - if len(c.Delta.ToolCalls) == 0 { - continue - } - - for _, t := range c.Delta.ToolCalls { - if t.Function.Name == "" { - continue - } - - detectedToolCalls[t.Function.Name] = struct{}{} - } - } - } - - // Verify that no injected tool call events (or partials thereof) were sent to the client. - require.Len(t, detectedToolCalls, 0) - - message = acc.ChatCompletion - require.NoError(t, stream.Err(), "stream error") - } else { - // Parse & unmarshal the response. - body, err := io.ReadAll(resp.Body) - require.NoError(t, err, "read response body") - require.NoError(t, json.Unmarshal(body, &message), "unmarshal response") - - // Verify that no injected tools were sent to the client. - require.GreaterOrEqual(t, len(message.Choices), 1) - require.Len(t, message.Choices[0].Message.ToolCalls, 0) - } - - require.GreaterOrEqual(t, len(message.Choices), 1) - content = &message.Choices[0] - - // Ensure tool returned expected value. - require.NotNil(t, content) - require.Contains(t, content.Message.Content, "dd711d5c-83c6-4c08-a0af-b73055906e8c") // The ID of the workspace to be returned. - - // Check the token usage from the client's perspective. - // This *should* work but the openai SDK doesn't accumulate the prompt token details :(. - // See https://github.com/openai/openai-go/blob/v2.7.0/streamaccumulator.go#L145-L147. - // assert.EqualValues(t, 5047, message.Usage.PromptTokens-message.Usage.PromptTokensDetails.CachedTokens) - assert.EqualValues(t, 105, message.Usage.CompletionTokens) - - // Ensure tokens used during injected tool invocation are accounted for. - tokenUsages := recorderClient.RecordedTokenUsages() - require.EqualValues(t, 5047, calculateTotalInputTokens(tokenUsages)) - require.EqualValues(t, 105, calculateTotalOutputTokens(tokenUsages)) - - // Ensure we received exactly one prompt. - promptUsages := recorderClient.RecordedPromptUsages() - require.Len(t, promptUsages, 1) - }) - } -} - -// setupInjectedToolTest abstracts the common aspects required for the Test*InjectedTools tests. -// Kinda fugly right now, we can refactor this later. -func setupInjectedToolTest(t *testing.T, fixture []byte, streaming bool, configureFn configureFunc, createRequestFn func(*testing.T, string, []byte) *http.Request) (*mockRecorderClient, *callAccumulator, map[string]mcp.ServerProxier, *http.Response) { - t.Helper() - - arc := txtar.Parse(fixture) - t.Logf("%s: %s", t.Name(), arc.Comment) - - files := filesMap(arc) - require.Len(t, files, 5) - require.Contains(t, files, fixtureRequest) - require.Contains(t, files, fixtureStreamingResponse) - require.Contains(t, files, fixtureNonStreamingResponse) - require.Contains(t, files, fixtureStreamingToolResponse) - require.Contains(t, files, fixtureNonStreamingToolResponse) - - reqBody := files[fixtureRequest] - - // Add the stream param to the request. - newBody, err := setJSON(reqBody, "stream", streaming) - require.NoError(t, err) - reqBody = newBody - - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) - t.Cleanup(cancel) - - // Setup mock server with response mutator for multi-turn interaction. - mockSrv := newMockServer(ctx, t, files, func(reqCount uint32, resp []byte) []byte { - if reqCount == 1 { - return resp // First request gets the normal response (with tool call). - } - - if reqCount > 2 { - // This should not happen in single injected tool tests. - return resp - } - - // Second request gets the tool response. - if streaming { - return files[fixtureStreamingToolResponse] - } - return files[fixtureNonStreamingToolResponse] - }) - t.Cleanup(mockSrv.Close) - - recorderClient := &mockRecorderClient{} - - // Setup MCP mcpProxiers. - mcpProxiers, acc := setupMCPServerProxiesForTest(t, testTracer) - - // Configure the bridge with injected tools. - mcpMgr := mcp.NewServerProxyManager(mcpProxiers, testTracer) - require.NoError(t, mcpMgr.Init(ctx)) - b, err := configureFn(mockSrv.URL, recorderClient, mcpMgr) - require.NoError(t, err) - - // Invoke request to mocked API via aibridge. - bridgeSrv := httptest.NewUnstartedServer(b) - bridgeSrv.Config.BaseContext = func(_ net.Listener) context.Context { - return aibridge.AsActor(ctx, userID, nil) - } - bridgeSrv.Start() - t.Cleanup(bridgeSrv.Close) - - req := createRequestFn(t, bridgeSrv.URL, reqBody) - client := &http.Client{} - resp, err := client.Do(req) - require.NoError(t, err) - require.Equal(t, http.StatusOK, resp.StatusCode) - t.Cleanup(func() { - _ = resp.Body.Close() - }) - - // We must ALWAYS have 2 calls to the bridge for injected tool tests. - require.Eventually(t, func() bool { - return mockSrv.callCount.Load() == 2 - }, time.Second*10, time.Millisecond*50) - - return recorderClient, acc, mcpProxiers, resp -} - func TestErrorHandling(t *testing.T) { t.Parallel() @@ -1046,14 +635,12 @@ func TestErrorHandling(t *testing.T) { cases := []struct { name string fixture []byte - createRequestFunc createRequestFunc configureFunc configureFunc responseHandlerFn func(resp *http.Response) }{ { - name: aibridge.ProviderAnthropic, - fixture: antNonStreamErr, - createRequestFunc: createAnthropicMessagesReq, + name: aibridge.ProviderAnthropic, + fixture: antNonStreamErr, configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) providers := []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)} @@ -1069,9 +656,8 @@ func TestErrorHandling(t *testing.T) { }, }, { - name: aibridge.ProviderOpenAI, - fixture: oaiNonStreamErr, - createRequestFunc: createOpenAIChatCompletionsReq, + name: aibridge.ProviderOpenAI, + fixture: oaiNonStreamErr, configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) providers := []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))} @@ -1096,52 +682,37 @@ func TestErrorHandling(t *testing.T) { t.Run(fmt.Sprintf("streaming=%v", streaming), func(t *testing.T) { t.Parallel() - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + ctx, cancel := context.WithTimeout(t.Context(), 30*time.Second) t.Cleanup(cancel) - arc := txtar.Parse(tc.fixture) + arc := testutil.MustParseTXTAR(t, tc.fixture) t.Logf("%s: %s", t.Name(), arc.Comment) - files := filesMap(arc) - require.Len(t, files, 3) - require.Contains(t, files, fixtureRequest) - require.Contains(t, files, fixtureStreamingResponse) - require.Contains(t, files, fixtureNonStreamingResponse) + llm := testutil.MustLLMFixture(t, arc) + reqBody := llm.MustRequestBody(t, streaming) - reqBody := files[fixtureRequest] - // Add the stream param to the request. - newBody, err := setJSON(reqBody, "stream", streaming) + mockResp, err := llm.Response(1, streaming) require.NoError(t, err) - reqBody = newBody - - // Setup mock server. - mockResp := files[fixtureStreamingResponse] - if !streaming { - mockResp = files[fixtureNonStreamingResponse] - } - mockSrv := newMockHTTPReflector(ctx, t, mockResp) - t.Cleanup(mockSrv.Close) + mockSrv := testutil.NewHTTPReflectorServer(t, ctx, mockResp) - recorderClient := &mockRecorderClient{} + recorderClient := &testutil.RecorderSpy{} - b, err := tc.configureFunc(mockSrv.URL, recorderClient, mcp.NewServerProxyManager(nil, testTracer)) + bridge, err := tc.configureFunc(mockSrv.URL, recorderClient, mcp.NewServerProxyManager(nil, testTracer)) require.NoError(t, err) - // Invoke request to mocked API via aibridge. - bridgeSrv := httptest.NewUnstartedServer(b) - bridgeSrv.Config.BaseContext = func(_ net.Listener) context.Context { - return aibridge.AsActor(ctx, userID, nil) - } - bridgeSrv.Start() - t.Cleanup(bridgeSrv.Close) + bridgeSrv := testutil.NewBridgeServer(t, testutil.BridgeConfig{ + Ctx: ctx, + ActorID: userID, + Handler: bridge, + }) - req := tc.createRequestFunc(t, bridgeSrv.URL, reqBody) - resp, err := http.DefaultClient.Do(req) - t.Cleanup(func() { _ = resp.Body.Close() }) + req := bridgeSrv.NewProviderRequest(t, tc.name, reqBody) + resp, err := bridgeSrv.Client.Do(req) require.NoError(t, err) + defer resp.Body.Close() tc.responseHandlerFn(resp) - recorderClient.verifyAllInterceptionsEnded(t) + recorderClient.RequireAllInterceptionsEnded(t) }) } }) @@ -1153,14 +724,12 @@ func TestErrorHandling(t *testing.T) { cases := []struct { name string fixture []byte - createRequestFunc createRequestFunc configureFunc configureFunc responseHandlerFn func(resp *http.Response) }{ { - name: aibridge.ProviderAnthropic, - fixture: antMidStreamErr, - createRequestFunc: createAnthropicMessagesReq, + name: aibridge.ProviderAnthropic, + fixture: antMidStreamErr, configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) providers := []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)} @@ -1177,9 +746,8 @@ func TestErrorHandling(t *testing.T) { }, }, { - name: aibridge.ProviderOpenAI, - fixture: oaiMidStreamErr, - createRequestFunc: createOpenAIChatCompletionsReq, + name: aibridge.ProviderOpenAI, + fixture: oaiMidStreamErr, configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) providers := []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))} @@ -1206,45 +774,36 @@ func TestErrorHandling(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + ctx, cancel := context.WithTimeout(t.Context(), 30*time.Second) t.Cleanup(cancel) - arc := txtar.Parse(tc.fixture) + arc := testutil.MustParseTXTAR(t, tc.fixture) t.Logf("%s: %s", t.Name(), arc.Comment) - files := filesMap(arc) - require.Len(t, files, 2) - require.Contains(t, files, fixtureRequest) - require.Contains(t, files, fixtureStreamingResponse) + llm := testutil.MustLLMFixture(t, arc) - reqBody := files[fixtureRequest] + upstream := testutil.NewUpstreamServer(t, ctx, llm) - // Setup mock server. - mockSrv := newMockServer(ctx, t, files, nil) - mockSrv.statusCode = http.StatusInternalServerError - t.Cleanup(mockSrv.Close) + reqBody := llm.MustRequestBody(t, true) - recorderClient := &mockRecorderClient{} + recorderClient := &testutil.RecorderSpy{} - b, err := tc.configureFunc(mockSrv.URL, recorderClient, mcp.NewServerProxyManager(nil, testTracer)) + bridge, err := tc.configureFunc(upstream.URL, recorderClient, mcp.NewServerProxyManager(nil, testTracer)) require.NoError(t, err) - // Invoke request to mocked API via aibridge. - bridgeSrv := httptest.NewUnstartedServer(b) - bridgeSrv.Config.BaseContext = func(_ net.Listener) context.Context { - return aibridge.AsActor(ctx, userID, nil) - } - bridgeSrv.Start() - t.Cleanup(bridgeSrv.Close) + bridgeSrv := testutil.NewBridgeServer(t, testutil.BridgeConfig{ + Ctx: ctx, + ActorID: userID, + Handler: bridge, + }) - req := tc.createRequestFunc(t, bridgeSrv.URL, reqBody) - resp, err := http.DefaultClient.Do(req) - t.Cleanup(func() { _ = resp.Body.Close() }) + req := bridgeSrv.NewProviderRequest(t, tc.name, reqBody) + resp, err := bridgeSrv.Client.Do(req) require.NoError(t, err) - bridgeSrv.Close() + defer resp.Body.Close() tc.responseHandlerFn(resp) - recorderClient.verifyAllInterceptionsEnded(t) + recorderClient.RequireAllInterceptionsEnded(t) }) } }) @@ -1260,27 +819,22 @@ func TestStableRequestEncoding(t *testing.T) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) cases := []struct { - name string - fixture []byte - createRequestFunc createRequestFunc - configureFunc configureFunc + name string + fixture []byte + newProvider func(upstreamURL string) aibridge.Provider }{ { - name: aibridge.ProviderAnthropic, - fixture: antSimple, - createRequestFunc: createAnthropicMessagesReq, - configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) { - providers := []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)} - return aibridge.NewRequestBridge(t.Context(), providers, client, srvProxyMgr, logger, nil, testTracer) + name: aibridge.ProviderAnthropic, + fixture: antSimple, + newProvider: func(upstreamURL string) aibridge.Provider { + return aibridge.NewAnthropicProvider(anthropicCfg(upstreamURL, apiKey), nil) }, }, { - name: aibridge.ProviderOpenAI, - fixture: oaiSimple, - createRequestFunc: createOpenAIChatCompletionsReq, - configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) { - providers := []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))} - return aibridge.NewRequestBridge(t.Context(), providers, client, srvProxyMgr, logger, nil, testTracer) + name: aibridge.ProviderOpenAI, + fixture: oaiSimple, + newProvider: func(upstreamURL string) aibridge.Provider { + return aibridge.NewOpenAIProvider(openaiCfg(upstreamURL, apiKey)) }, }, } @@ -1289,81 +843,52 @@ func TestStableRequestEncoding(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + ctx, cancel := context.WithTimeout(t.Context(), 30*time.Second) t.Cleanup(cancel) - // Setup MCP tools. - mcpProxiers, _ := setupMCPServerProxiesForTest(t, testTracer) - - // Configure the bridge with injected tools. - mcpMgr := mcp.NewServerProxyManager(mcpProxiers, testTracer) - require.NoError(t, mcpMgr.Init(ctx)) - - arc := txtar.Parse(tc.fixture) + arc := testutil.MustParseTXTAR(t, tc.fixture) t.Logf("%s: %s", t.Name(), arc.Comment) - files := filesMap(arc) - require.Contains(t, files, fixtureRequest) - require.Contains(t, files, fixtureNonStreamingResponse) - - var ( - reference []byte - reqCount atomic.Int32 - ) - - // Create a mock server that captures and compares request bodies. - mockSrv := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - reqCount.Add(1) + llm := testutil.MustLLMFixture(t, arc) + reqBody := arc.MustFile(t, testutil.FixtureRequest) - // Capture the raw request body. - raw, err := io.ReadAll(r.Body) - defer r.Body.Close() - require.NoError(t, err) - require.NotEmpty(t, raw) - - // Store the first instance as the reference value. - if reference == nil { - reference = raw - } else { - // Compare all subsequent requests to the reference. - assert.JSONEq(t, string(reference), string(raw)) - } - - // Return a valid API response. - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - _, _ = w.Write(files[fixtureNonStreamingResponse]) - })) - mockSrv.Config.BaseContext = func(_ net.Listener) context.Context { - return ctx - } - mockSrv.Start() - t.Cleanup(mockSrv.Close) + upstream := testutil.NewUpstreamServer(t, ctx, llm) - recorder := &mockRecorderClient{} - bridge, err := tc.configureFunc(mockSrv.URL, recorder, mcpMgr) - require.NoError(t, err) + // Setup mocked MCP server & tools. + mcpSrv := testutil.NewMCPServer(t, testutil.DefaultCoderToolNames()) + mcpProxiers := mcpSrv.Proxiers(t, "coder", logger, testTracer) - // Invoke request to mocked API via aibridge. - bridgeSrv := httptest.NewUnstartedServer(bridge) - bridgeSrv.Config.BaseContext = func(_ net.Listener) context.Context { - return aibridge.AsActor(ctx, userID, nil) - } - bridgeSrv.Start() - t.Cleanup(bridgeSrv.Close) + recorder := &testutil.RecorderSpy{} + bridgeSrv := testutil.NewBridgeServer(t, testutil.BridgeConfig{ + Ctx: ctx, + ActorID: userID, + Providers: []aibridge.Provider{tc.newProvider(upstream.URL)}, + Recorder: recorder, + MCPProxiers: mcpProxiers, + Logger: logger, + Tracer: testTracer, + }) // Make multiple requests and verify they all have identical payloads. count := 10 - for range count { - req := tc.createRequestFunc(t, bridgeSrv.URL, files[fixtureRequest]) - client := &http.Client{} - resp, err := client.Do(req) + for i := 0; i < count; i++ { + req := bridgeSrv.NewProviderRequest(t, tc.name, reqBody) + resp, err := bridgeSrv.Client.Do(req) require.NoError(t, err) require.Equal(t, http.StatusOK, resp.StatusCode) _ = resp.Body.Close() } - require.EqualValues(t, count, reqCount.Load()) + upstream.RequireCallCountEventually(t, count) + reqs := upstream.Requests() + require.Len(t, reqs, count) + + reference := string(reqs[0].Body) + for i := 1; i < len(reqs); i++ { + assert.JSONEq(t, reference, string(reqs[i].Body)) + } + + recorder.RequireAllInterceptionsEnded(t) }) } } @@ -1423,70 +948,48 @@ func TestAnthropicToolChoiceParallelDisabled(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + ctx, cancel := context.WithTimeout(t.Context(), 30*time.Second) t.Cleanup(cancel) - // Configure the bridge. - mcpMgr := mcp.NewServerProxyManager(nil, testTracer) - require.NoError(t, mcpMgr.Init(ctx)) - - arc := txtar.Parse(antSimple) - files := filesMap(arc) - require.Contains(t, files, fixtureRequest) - require.Contains(t, files, fixtureNonStreamingResponse) + arc := testutil.MustParseTXTAR(t, antSimple) + llm := testutil.MustLLMFixture(t, arc) // Prepare request body with tool_choice set. var reqJSON map[string]any - require.NoError(t, json.Unmarshal(files[fixtureRequest], &reqJSON)) + require.NoError(t, json.Unmarshal(arc.MustFile(t, testutil.FixtureRequest), &reqJSON)) if tc.toolChoice != nil { reqJSON["tool_choice"] = tc.toolChoice } reqBody, err := json.Marshal(reqJSON) require.NoError(t, err) - var receivedRequest map[string]any - - // Create a mock server that captures the request body sent upstream. - mockSrv := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Capture the raw request body. - raw, err := io.ReadAll(r.Body) - defer r.Body.Close() - require.NoError(t, err) - - require.NoError(t, json.Unmarshal(raw, &receivedRequest)) - - // Return a valid API response. - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - _, _ = w.Write(files[fixtureNonStreamingResponse]) - })) - mockSrv.Config.BaseContext = func(_ net.Listener) context.Context { - return ctx - } - mockSrv.Start() - t.Cleanup(mockSrv.Close) + upstream := testutil.NewUpstreamServer(t, ctx, llm) - recorder := &mockRecorderClient{} + recorder := &testutil.RecorderSpy{} logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - providers := []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(mockSrv.URL, apiKey), nil)} - bridge, err := aibridge.NewRequestBridge(ctx, providers, recorder, mcpMgr, logger, nil, testTracer) - require.NoError(t, err) - - // Invoke request to mocked API via aibridge. - bridgeSrv := httptest.NewUnstartedServer(bridge) - bridgeSrv.Config.BaseContext = func(_ net.Listener) context.Context { - return aibridge.AsActor(ctx, userID, nil) - } - bridgeSrv.Start() - t.Cleanup(bridgeSrv.Close) + providers := []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(upstream.URL, apiKey), nil)} + bridgeSrv := testutil.NewBridgeServer(t, testutil.BridgeConfig{ + Ctx: ctx, + ActorID: userID, + Providers: providers, + Recorder: recorder, + Logger: logger, + Tracer: testTracer, + }) - req := createAnthropicMessagesReq(t, bridgeSrv.URL, reqBody) - client := &http.Client{} - resp, err := client.Do(req) + req := bridgeSrv.NewProviderRequest(t, aibridge.ProviderAnthropic, reqBody) + resp, err := bridgeSrv.Client.Do(req) require.NoError(t, err) require.Equal(t, http.StatusOK, resp.StatusCode) _ = resp.Body.Close() + upstream.RequireCallCountEventually(t, 1) + reqs := upstream.Requests() + require.Len(t, reqs, 1) + + var receivedRequest map[string]any + require.NoError(t, json.Unmarshal(reqs[0].Body, &receivedRequest)) + // Verify tool_choice in the upstream request. require.NotNil(t, receivedRequest) toolChoice, ok := receivedRequest["tool_choice"].(map[string]any) @@ -1523,7 +1026,6 @@ func TestEnvironmentDoNotLeak(t *testing.T) { name string fixture []byte configureFunc func(string, aibridge.Recorder) (*aibridge.RequestBridge, error) - createRequest func(*testing.T, string, []byte) *http.Request envVars map[string]string headerName string }{ @@ -1535,7 +1037,6 @@ func TestEnvironmentDoNotLeak(t *testing.T) { providers := []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)} return aibridge.NewRequestBridge(t.Context(), providers, client, mcp.NewServerProxyManager(nil, testTracer), logger, nil, testTracer) }, - createRequest: createAnthropicMessagesReq, envVars: map[string]string{ "ANTHROPIC_AUTH_TOKEN": "should-not-leak", }, @@ -1549,7 +1050,6 @@ func TestEnvironmentDoNotLeak(t *testing.T) { providers := []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))} return aibridge.NewRequestBridge(t.Context(), providers, client, mcp.NewServerProxyManager(nil, testTracer), logger, nil, testTracer) }, - createRequest: createOpenAIChatCompletionsReq, envVars: map[string]string{ "OPENAI_ORG_ID": "should-not-leak", }, @@ -1561,26 +1061,14 @@ func TestEnvironmentDoNotLeak(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // NOTE: Cannot use t.Parallel() here because t.Setenv requires sequential execution. - arc := txtar.Parse(tc.fixture) - files := filesMap(arc) - reqBody := files[fixtureRequest] + arc := testutil.MustParseTXTAR(t, tc.fixture) + llm := testutil.MustLLMFixture(t, arc) + reqBody := arc.MustFile(t, testutil.FixtureRequest) - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + ctx, cancel := context.WithTimeout(t.Context(), 30*time.Second) t.Cleanup(cancel) - // Track headers received by the upstream server. - var receivedHeaders http.Header - srv := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - receivedHeaders = r.Header.Clone() - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - _, _ = w.Write(files[fixtureNonStreamingResponse]) - })) - srv.Config.BaseContext = func(_ net.Listener) context.Context { - return ctx - } - srv.Start() - t.Cleanup(srv.Close) + upstream := testutil.NewUpstreamServer(t, ctx, llm) // Set environment variables that the SDK would automatically read. // These should NOT leak into upstream requests. @@ -1588,357 +1076,34 @@ func TestEnvironmentDoNotLeak(t *testing.T) { t.Setenv(key, val) } - recorderClient := &mockRecorderClient{} - b, err := tc.configureFunc(srv.URL, recorderClient) + recorderClient := &testutil.RecorderSpy{} + bridge, err := tc.configureFunc(upstream.URL, recorderClient) require.NoError(t, err) - mockSrv := httptest.NewUnstartedServer(b) - t.Cleanup(mockSrv.Close) - mockSrv.Config.BaseContext = func(_ net.Listener) context.Context { - return aibridge.AsActor(ctx, userID, nil) - } - mockSrv.Start() + bridgeSrv := testutil.NewBridgeServer(t, testutil.BridgeConfig{ + Ctx: ctx, + ActorID: userID, + Handler: bridge, + }) - req := tc.createRequest(t, mockSrv.URL, reqBody) - client := &http.Client{} - resp, err := client.Do(req) + req := bridgeSrv.NewProviderRequest(t, tc.name, reqBody) + resp, err := bridgeSrv.Client.Do(req) require.NoError(t, err) require.Equal(t, http.StatusOK, resp.StatusCode) defer resp.Body.Close() + upstream.RequireCallCountEventually(t, 1) + reqs := upstream.Requests() + require.Len(t, reqs, 1) + // Verify that environment values did not leak. + receivedHeaders := reqs[0].Header require.NotNil(t, receivedHeaders) require.Empty(t, receivedHeaders.Get(tc.headerName)) }) } } -func calculateTotalInputTokens(in []*aibridge.TokenUsageRecord) int64 { - var total int64 - for _, el := range in { - total += el.Input - } - return total -} - -func calculateTotalOutputTokens(in []*aibridge.TokenUsageRecord) int64 { - var total int64 - for _, el := range in { - total += el.Output - } - return total -} - -type archiveFileMap map[string][]byte - -func filesMap(archive *txtar.Archive) archiveFileMap { - if len(archive.Files) == 0 { - return nil - } - - out := make(archiveFileMap, len(archive.Files)) - for _, f := range archive.Files { - out[f.Name] = f.Data - } - return out -} - -func createAnthropicMessagesReq(t *testing.T, baseURL string, input []byte) *http.Request { - t.Helper() - - req, err := http.NewRequestWithContext(t.Context(), "POST", baseURL+"/anthropic/v1/messages", bytes.NewReader(input)) - require.NoError(t, err) - req.Header.Set("Content-Type", "application/json") - - return req -} - -func createOpenAIChatCompletionsReq(t *testing.T, baseURL string, input []byte) *http.Request { - t.Helper() - - req, err := http.NewRequestWithContext(t.Context(), "POST", baseURL+"/openai/v1/chat/completions", bytes.NewReader(input)) - require.NoError(t, err) - req.Header.Set("Content-Type", "application/json") - - return req -} - -type mockHTTPReflector struct { - *httptest.Server -} - -func newMockHTTPReflector(ctx context.Context, t *testing.T, resp []byte) *mockHTTPReflector { - ref := &mockHTTPReflector{} - - srv := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - mock, err := http.ReadResponse(bufio.NewReader(bytes.NewBuffer(resp)), r) - require.NoError(t, err) - defer mock.Body.Close() - - // Copy headers from the mocked response. - for key, values := range mock.Header { - for _, value := range values { - w.Header().Add(key, value) - } - } - - // Write the status code. - w.WriteHeader(mock.StatusCode) - - // Copy the body. - _, err = io.Copy(w, mock.Body) - require.NoError(t, err) - })) - srv.Config.BaseContext = func(_ net.Listener) context.Context { - return ctx - } - - srv.Start() - t.Cleanup(srv.Close) - - ref.Server = srv - return ref -} - -// TODO: replace this with mockHTTPReflector. -type mockServer struct { - *httptest.Server - - callCount atomic.Uint32 - - statusCode int -} - -func newMockServer(ctx context.Context, t *testing.T, files archiveFileMap, responseMutatorFn func(reqCount uint32, resp []byte) []byte) *mockServer { - t.Helper() - - ms := &mockServer{} - srv := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - statusCode := http.StatusOK - if ms.statusCode != 0 { - statusCode = ms.statusCode - } - - ms.callCount.Add(1) - - body, err := io.ReadAll(r.Body) - defer r.Body.Close() - require.NoError(t, err) - - type msg struct { - Stream bool `json:"stream"` - } - var reqMsg msg - require.NoError(t, json.Unmarshal(body, &reqMsg)) - - if !reqMsg.Stream && !strings.HasSuffix(r.URL.Path, "invoke-with-response-stream") { - resp := files[fixtureNonStreamingResponse] - if responseMutatorFn != nil { - resp = responseMutatorFn(ms.callCount.Load(), resp) - } - - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(statusCode) - w.Write(resp) - return - } - - w.Header().Set("Content-Type", "text/event-stream") - w.Header().Set("Cache-Control", "no-cache") - w.Header().Set("Connection", "keep-alive") - w.Header().Set("Access-Control-Allow-Origin", "*") - - resp := files[fixtureStreamingResponse] - if responseMutatorFn != nil { - resp = responseMutatorFn(ms.callCount.Load(), resp) - } - - scanner := bufio.NewScanner(bytes.NewReader(resp)) - flusher, ok := w.(http.Flusher) - if !ok { - http.Error(w, "Streaming unsupported", http.StatusInternalServerError) - return - } - - for scanner.Scan() { - line := scanner.Text() - - fmt.Fprintf(w, "%s\n", line) - flusher.Flush() - } - - if err := scanner.Err(); err != nil { - http.Error(w, fmt.Sprintf("Error reading fixture: %v", err), http.StatusInternalServerError) - return - } - })) - - srv.Config.BaseContext = func(_ net.Listener) context.Context { - return ctx - } - - srv.Start() - t.Cleanup(srv.Close) - - ms.Server = srv - return ms -} - -var _ aibridge.Recorder = &mockRecorderClient{} - -type mockRecorderClient struct { - mu sync.Mutex - - interceptions []*aibridge.InterceptionRecord - tokenUsages []*aibridge.TokenUsageRecord - userPrompts []*aibridge.PromptUsageRecord - toolUsages []*aibridge.ToolUsageRecord - interceptionsEnd map[string]time.Time -} - -func (m *mockRecorderClient) RecordInterception(ctx context.Context, req *aibridge.InterceptionRecord) error { - m.mu.Lock() - defer m.mu.Unlock() - m.interceptions = append(m.interceptions, req) - return nil -} - -func (m *mockRecorderClient) RecordInterceptionEnded(ctx context.Context, req *aibridge.InterceptionRecordEnded) error { - m.mu.Lock() - defer m.mu.Unlock() - if m.interceptionsEnd == nil { - m.interceptionsEnd = make(map[string]time.Time) - } - if !slices.ContainsFunc(m.interceptions, func(intc *aibridge.InterceptionRecord) bool { return intc.ID == req.ID }) { - return fmt.Errorf("id not found") - } - m.interceptionsEnd[req.ID] = req.EndedAt - return nil -} - -func (m *mockRecorderClient) RecordPromptUsage(ctx context.Context, req *aibridge.PromptUsageRecord) error { - m.mu.Lock() - defer m.mu.Unlock() - m.userPrompts = append(m.userPrompts, req) - return nil -} - -func (m *mockRecorderClient) RecordTokenUsage(ctx context.Context, req *aibridge.TokenUsageRecord) error { - m.mu.Lock() - defer m.mu.Unlock() - m.tokenUsages = append(m.tokenUsages, req) - return nil -} - -func (m *mockRecorderClient) RecordToolUsage(ctx context.Context, req *aibridge.ToolUsageRecord) error { - m.mu.Lock() - defer m.mu.Unlock() - m.toolUsages = append(m.toolUsages, req) - return nil -} - -// RecordedTokenUsages returns a copy of recorded token usages in a thread-safe manner. -// Note: This is a shallow clone - the slice is copied but the pointers reference the -// same underlying records. This is sufficient for our test assertions which only read -// the data and don't modify the records. -func (m *mockRecorderClient) RecordedTokenUsages() []*aibridge.TokenUsageRecord { - m.mu.Lock() - defer m.mu.Unlock() - return slices.Clone(m.tokenUsages) -} - -// RecordedPromptUsages returns a copy of recorded prompt usages in a thread-safe manner. -// Note: This is a shallow clone (see RecordedTokenUsages for details). -func (m *mockRecorderClient) RecordedPromptUsages() []*aibridge.PromptUsageRecord { - m.mu.Lock() - defer m.mu.Unlock() - return slices.Clone(m.userPrompts) -} - -// RecordedToolUsages returns a copy of recorded tool usages in a thread-safe manner. -// Note: This is a shallow clone (see RecordedTokenUsages for details). -func (m *mockRecorderClient) RecordedToolUsages() []*aibridge.ToolUsageRecord { - m.mu.Lock() - defer m.mu.Unlock() - return slices.Clone(m.toolUsages) -} - -// RecordedInterceptions returns a copy of recorded interceptions in a thread-safe manner. -// Note: This is a shallow clone (see RecordedTokenUsages for details). -func (m *mockRecorderClient) RecordedInterceptions() []*aibridge.InterceptionRecord { - m.mu.Lock() - defer m.mu.Unlock() - return slices.Clone(m.interceptions) -} - -// verify all recorded interceptions has been marked as completed -func (m *mockRecorderClient) verifyAllInterceptionsEnded(t *testing.T) { - t.Helper() - - m.mu.Lock() - defer m.mu.Unlock() - require.Equalf(t, len(m.interceptions), len(m.interceptionsEnd), "got %v interception ended calls, want: %v", len(m.interceptionsEnd), len(m.interceptions)) - for _, intc := range m.interceptions { - require.Containsf(t, m.interceptionsEnd, intc.ID, "interception with id: %v has not been ended", intc.ID) - } -} - -const mockToolName = "coder_list_workspaces" - -// callAccumulator tracks all tool invocations by name and each instance's arguments. -type callAccumulator struct { - calls map[string][]any - callsMu sync.Mutex -} - -func newCallAccumulator() *callAccumulator { - return &callAccumulator{ - calls: make(map[string][]any), - } -} - -func (a *callAccumulator) addCall(tool string, args any) { - a.callsMu.Lock() - defer a.callsMu.Unlock() - - a.calls[tool] = append(a.calls[tool], args) -} - -func (a *callAccumulator) getCallsByTool(name string) []any { - a.callsMu.Lock() - defer a.callsMu.Unlock() - - // Protect against concurrent access of the slice. - result := make([]any, len(a.calls[name])) - copy(result, a.calls[name]) - return result -} - -func createMockMCPSrv(t *testing.T) (http.Handler, *callAccumulator) { - t.Helper() - - s := server.NewMCPServer( - "Mock coder MCP server", - "1.0.0", - server.WithToolCapabilities(true), - ) - - // Accumulate tool calls & their arguments. - acc := newCallAccumulator() - - for _, name := range []string{mockToolName, "coder_list_templates", "coder_template_version_parameters", "coder_get_authenticated_user", "coder_create_workspace_build"} { - tool := mcplib.NewTool(name, - mcplib.WithDescription(fmt.Sprintf("Mock of the %s tool", name)), - ) - s.AddTool(tool, func(ctx context.Context, request mcplib.CallToolRequest) (*mcplib.CallToolResult, error) { - acc.addCall(request.Params.Name, request.Params.Arguments) - return mcplib.NewToolResultText("mock"), nil - }) - } - - return server.NewStreamableHTTPServer(s), acc -} - func openaiCfg(url, key string) aibridge.OpenAIConfig { return aibridge.OpenAIConfig{ BaseURL: url, diff --git a/metrics_integration_test.go b/metrics_integration_test.go index f326dec..5b1a0db 100644 --- a/metrics_integration_test.go +++ b/metrics_integration_test.go @@ -12,12 +12,11 @@ import ( "cdr.dev/slog" "cdr.dev/slog/sloggers/slogtest" "github.com/coder/aibridge" - "github.com/coder/aibridge/mcp" + "github.com/coder/aibridge/testutil" "github.com/prometheus/client_golang/prometheus" promtest "github.com/prometheus/client_golang/prometheus/testutil" "github.com/stretchr/testify/require" "go.opentelemetry.io/otel/trace" - "golang.org/x/tools/txtar" ) func TestMetrics_Interception(t *testing.T) { @@ -38,21 +37,20 @@ func TestMetrics_Interception(t *testing.T) { } for _, tc := range cases { - arc := txtar.Parse(tc.fixture) - files := filesMap(arc) + fixture := testutil.MustParseTXTAR(t, tc.fixture) + llm := testutil.MustLLMFixture(t, fixture) - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + ctx, cancel := context.WithTimeout(t.Context(), 30*time.Second) t.Cleanup(cancel) - mockAPI := newMockServer(ctx, t, files, nil) - t.Cleanup(mockAPI.Close) + upstream := testutil.NewUpstreamServer(t, ctx, llm) metrics := aibridge.NewMetrics(prometheus.NewRegistry()) - provider := aibridge.NewAnthropicProvider(anthropicCfg(mockAPI.URL, apiKey), nil) - srv, _ := newTestSrv(t, ctx, provider, metrics, testTracer) + provider := aibridge.NewAnthropicProvider(anthropicCfg(upstream.URL, apiKey), nil) + bridgeSrv, _ := newTestSrv(t, ctx, provider, metrics, testTracer) - req := createAnthropicMessagesReq(t, srv.URL, files[fixtureRequest]) - resp, err := http.DefaultClient.Do(req) + req := bridgeSrv.NewProviderRequest(t, provider.Name(), fixture.MustFile(t, testutil.FixtureRequest)) + resp, err := bridgeSrv.Client.Do(req) require.NoError(t, err) defer resp.Body.Close() _, _ = io.ReadAll(resp.Body) @@ -67,20 +65,20 @@ func TestMetrics_Interception(t *testing.T) { func TestMetrics_InterceptionsInflight(t *testing.T) { t.Parallel() - arc := txtar.Parse(antSimple) - files := filesMap(arc) + fixture := testutil.MustParseTXTAR(t, antSimple) + llm := testutil.MustLLMFixture(t, fixture) - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + ctx, cancel := context.WithTimeout(t.Context(), 30*time.Second) t.Cleanup(cancel) + upstream := testutil.NewUpstreamServer(t, ctx, llm) + blockCh := make(chan struct{}) // Setup a mock HTTP server which blocks until the request is marked as inflight then proceeds. srv := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { <-blockCh - mock := newMockServer(ctx, t, files, nil) - defer mock.Close() - mock.Server.Config.Handler.ServeHTTP(w, r) + upstream.Config.Handler.ServeHTTP(w, r) })) srv.Config.BaseContext = func(_ net.Listener) context.Context { return ctx @@ -92,12 +90,13 @@ func TestMetrics_InterceptionsInflight(t *testing.T) { provider := aibridge.NewAnthropicProvider(anthropicCfg(srv.URL, apiKey), nil) bridgeSrv, _ := newTestSrv(t, ctx, provider, metrics, testTracer) + req := bridgeSrv.NewProviderRequest(t, provider.Name(), fixture.MustFile(t, testutil.FixtureRequest)) + // Make request in background. doneCh := make(chan struct{}) go func() { defer close(doneCh) - req := createAnthropicMessagesReq(t, bridgeSrv.URL, files[fixtureRequest]) - resp, err := http.DefaultClient.Do(req) + resp, err := bridgeSrv.Client.Do(req) if err == nil { defer resp.Body.Close() _, _ = io.ReadAll(resp.Body) @@ -109,7 +108,7 @@ func TestMetrics_InterceptionsInflight(t *testing.T) { return promtest.ToFloat64( metrics.InterceptionsInflight.WithLabelValues(aibridge.ProviderAnthropic, "claude-sonnet-4-0", "/v1/messages"), ) == 1 - }, time.Second*10, time.Millisecond*50) + }, 10*time.Second, 50*time.Millisecond) // Unblock request, await completion. close(blockCh) @@ -124,30 +123,30 @@ func TestMetrics_InterceptionsInflight(t *testing.T) { return promtest.ToFloat64( metrics.InterceptionsInflight.WithLabelValues(aibridge.ProviderAnthropic, "claude-sonnet-4-0", "/v1/messages"), ) == 0 - }, time.Second*10, time.Millisecond*50) + }, 10*time.Second, 50*time.Millisecond) } func TestMetrics_PassthroughCount(t *testing.T) { t.Parallel() - arc := txtar.Parse(oaiFallthrough) - files := filesMap(arc) + fixture := testutil.MustParseTXTAR(t, oaiFallthrough) + respBody := fixture.MustFile(t, testutil.FixtureResponse) upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) - _, _ = w.Write(files[fixtureResponse]) + _, _ = w.Write(respBody) })) t.Cleanup(upstream.Close) metrics := aibridge.NewMetrics(prometheus.NewRegistry()) provider := aibridge.NewOpenAIProvider(openaiCfg(upstream.URL, apiKey)) - srv, _ := newTestSrv(t, t.Context(), provider, metrics, testTracer) + bridgeSrv, _ := newTestSrv(t, t.Context(), provider, metrics, testTracer) - req, err := http.NewRequestWithContext(t.Context(), "GET", srv.URL+"/openai/v1/models", nil) + req, err := http.NewRequestWithContext(t.Context(), "GET", bridgeSrv.URL+"/openai/v1/models", nil) require.NoError(t, err) - resp, err := http.DefaultClient.Do(req) + resp, err := bridgeSrv.Client.Do(req) require.NoError(t, err) defer resp.Body.Close() require.Equal(t, http.StatusOK, resp.StatusCode) @@ -160,21 +159,20 @@ func TestMetrics_PassthroughCount(t *testing.T) { func TestMetrics_PromptCount(t *testing.T) { t.Parallel() - arc := txtar.Parse(oaiSimple) - files := filesMap(arc) + fixture := testutil.MustParseTXTAR(t, oaiSimple) + llm := testutil.MustLLMFixture(t, fixture) - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + ctx, cancel := context.WithTimeout(t.Context(), 30*time.Second) t.Cleanup(cancel) - mockAPI := newMockServer(ctx, t, files, nil) - t.Cleanup(mockAPI.Close) + upstream := testutil.NewUpstreamServer(t, ctx, llm) metrics := aibridge.NewMetrics(prometheus.NewRegistry()) - provider := aibridge.NewOpenAIProvider(openaiCfg(mockAPI.URL, apiKey)) - srv, _ := newTestSrv(t, ctx, provider, metrics, testTracer) + provider := aibridge.NewOpenAIProvider(openaiCfg(upstream.URL, apiKey)) + bridgeSrv, _ := newTestSrv(t, ctx, provider, metrics, testTracer) - req := createOpenAIChatCompletionsReq(t, srv.URL, files[fixtureRequest]) - resp, err := http.DefaultClient.Do(req) + req := bridgeSrv.NewProviderRequest(t, provider.Name(), fixture.MustFile(t, testutil.FixtureRequest)) + resp, err := bridgeSrv.Client.Do(req) require.NoError(t, err) require.Equal(t, http.StatusOK, resp.StatusCode) defer resp.Body.Close() @@ -188,21 +186,20 @@ func TestMetrics_PromptCount(t *testing.T) { func TestMetrics_NonInjectedToolUseCount(t *testing.T) { t.Parallel() - arc := txtar.Parse(oaiSingleBuiltinTool) - files := filesMap(arc) + fixture := testutil.MustParseTXTAR(t, oaiSingleBuiltinTool) + llm := testutil.MustLLMFixture(t, fixture) - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + ctx, cancel := context.WithTimeout(t.Context(), 30*time.Second) t.Cleanup(cancel) - mockAPI := newMockServer(ctx, t, files, nil) - t.Cleanup(mockAPI.Close) + upstream := testutil.NewUpstreamServer(t, ctx, llm) metrics := aibridge.NewMetrics(prometheus.NewRegistry()) - provider := aibridge.NewOpenAIProvider(openaiCfg(mockAPI.URL, apiKey)) - srv, _ := newTestSrv(t, ctx, provider, metrics, testTracer) + provider := aibridge.NewOpenAIProvider(openaiCfg(upstream.URL, apiKey)) + bridgeSrv, _ := newTestSrv(t, ctx, provider, metrics, testTracer) - req := createOpenAIChatCompletionsReq(t, srv.URL, files[fixtureRequest]) - resp, err := http.DefaultClient.Do(req) + req := bridgeSrv.NewProviderRequest(t, provider.Name(), fixture.MustFile(t, testutil.FixtureRequest)) + resp, err := bridgeSrv.Client.Do(req) require.NoError(t, err) require.Equal(t, http.StatusOK, resp.StatusCode) defer resp.Body.Close() @@ -216,82 +213,75 @@ func TestMetrics_NonInjectedToolUseCount(t *testing.T) { func TestMetrics_InjectedToolUseCount(t *testing.T) { t.Parallel() - arc := txtar.Parse(antSingleInjectedTool) - files := filesMap(arc) + fixture := testutil.MustParseTXTAR(t, antSingleInjectedTool) + llm := testutil.MustLLMFixture(t, fixture) ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) t.Cleanup(cancel) - // First request returns the tool invocation, the second returns the mocked response to the tool result. - mockAPI := newMockServer(ctx, t, files, func(reqCount uint32, resp []byte) []byte { - if reqCount == 1 { - return resp - } - return files[fixtureNonStreamingToolResponse] - }) - t.Cleanup(mockAPI.Close) + upstream := testutil.NewUpstreamServer(t, ctx, llm) - recorder := &mockRecorderClient{} logger := slogtest.Make(t, &slogtest.Options{}).Leveled(slog.LevelDebug) metrics := aibridge.NewMetrics(prometheus.NewRegistry()) - provider := aibridge.NewAnthropicProvider(anthropicCfg(mockAPI.URL, apiKey), nil) + provider := aibridge.NewAnthropicProvider(anthropicCfg(upstream.URL, apiKey), nil) // Setup mocked MCP server & tools. - mcpProxiers, _ := setupMCPServerProxiesForTest(t, testTracer) - mcpMgr := mcp.NewServerProxyManager(mcpProxiers, testTracer) - require.NoError(t, mcpMgr.Init(ctx)) - - bridge, err := aibridge.NewRequestBridge(ctx, []aibridge.Provider{provider}, recorder, mcpMgr, logger, metrics, testTracer) - require.NoError(t, err) - - srv := httptest.NewUnstartedServer(bridge) - srv.Config.BaseContext = func(_ net.Listener) context.Context { - return aibridge.AsActor(ctx, userID, nil) - } - srv.Start() - t.Cleanup(srv.Close) + mcpSrv := testutil.NewMCPServer(t, testutil.DefaultCoderToolNames()) + mcpProxiers := mcpSrv.Proxiers(t, "coder", logger, testTracer) + + recorder := &testutil.RecorderSpy{} + bridge := testutil.NewBridgeServer(t, testutil.BridgeConfig{ + Ctx: ctx, + ActorID: userID, + Providers: []aibridge.Provider{provider}, + Recorder: recorder, + MCPProxiers: mcpProxiers, + Logger: logger, + Metrics: metrics, + Tracer: testTracer, + }) - req := createAnthropicMessagesReq(t, srv.URL, files[fixtureRequest]) - resp, err := http.DefaultClient.Do(req) + reqBody := fixture.MustFile(t, testutil.FixtureRequest) + req := bridge.NewProviderRequest(t, aibridge.ProviderAnthropic, reqBody) + resp, err := bridge.Client.Do(req) require.NoError(t, err) require.Equal(t, http.StatusOK, resp.StatusCode) defer resp.Body.Close() _, _ = io.ReadAll(resp.Body) // Wait until full roundtrip has completed. - require.Eventually(t, func() bool { - return mockAPI.callCount.Load() == 2 - }, time.Second*10, time.Millisecond*50) + upstream.RequireCallCountEventually(t, 2) - require.Len(t, recorder.toolUsages, 1) - require.True(t, recorder.toolUsages[0].Injected) - require.NotNil(t, recorder.toolUsages[0].ServerURL) - actualServerURL := *recorder.toolUsages[0].ServerURL + toolUsages := recorder.RecordedToolUsages() + require.Len(t, toolUsages, 1) + require.True(t, toolUsages[0].Injected) + require.NotNil(t, toolUsages[0].ServerURL) + actualServerURL := *toolUsages[0].ServerURL count := promtest.ToFloat64(metrics.InjectedToolUseCount.WithLabelValues( - aibridge.ProviderAnthropic, "claude-sonnet-4-20250514", actualServerURL, mockToolName)) + aibridge.ProviderAnthropic, "claude-sonnet-4-20250514", actualServerURL, testutil.ToolCoderListWorkspaces)) require.Equal(t, 1.0, count) } -func newTestSrv(t *testing.T, ctx context.Context, provider aibridge.Provider, metrics *aibridge.Metrics, tracer trace.Tracer) (*httptest.Server, *mockRecorderClient) { +func newTestSrv(t *testing.T, ctx context.Context, provider aibridge.Provider, metrics *aibridge.Metrics, tracer trace.Tracer) (*testutil.BridgeServer, *testutil.RecorderSpy) { t.Helper() logger := slogtest.Make(t, &slogtest.Options{}).Leveled(slog.LevelDebug) - mockRecorder := &mockRecorderClient{} + spy := &testutil.RecorderSpy{} clientFn := func() (aibridge.Recorder, error) { - return mockRecorder, nil + return spy, nil } wrappedRecorder := aibridge.NewRecorder(logger, tracer, clientFn) - bridge, err := aibridge.NewRequestBridge(ctx, []aibridge.Provider{provider}, wrappedRecorder, mcp.NewServerProxyManager(nil, testTracer), logger, metrics, tracer) - require.NoError(t, err) - - srv := httptest.NewUnstartedServer(bridge) - srv.Config.BaseContext = func(_ net.Listener) context.Context { - return aibridge.AsActor(ctx, userID, nil) - } - srv.Start() - t.Cleanup(srv.Close) + bridgeSrv := testutil.NewBridgeServer(t, testutil.BridgeConfig{ + Ctx: ctx, + ActorID: userID, + Providers: []aibridge.Provider{provider}, + Recorder: wrappedRecorder, + Logger: logger, + Metrics: metrics, + Tracer: tracer, + }) - return srv, mockRecorder + return bridgeSrv, spy } diff --git a/testutil/bridge_server.go b/testutil/bridge_server.go new file mode 100644 index 0000000..a357bb9 --- /dev/null +++ b/testutil/bridge_server.go @@ -0,0 +1,111 @@ +package testutil + +import ( + "bytes" + "context" + "net" + "net/http" + "net/http/httptest" + "testing" + + "cdr.dev/slog" + "github.com/coder/aibridge" + "github.com/coder/aibridge/mcp" + "go.opentelemetry.io/otel/trace" + "go.opentelemetry.io/otel/trace/noop" +) + +type BridgeConfig struct { + Ctx context.Context + ActorID string + + // Exactly one of Handler or Providers must be set. + Handler http.Handler + Providers []aibridge.Provider + + Recorder aibridge.Recorder + + MCPProxiers map[string]mcp.ServerProxier + + Logger slog.Logger + Metrics *aibridge.Metrics + Tracer trace.Tracer +} + +type BridgeServer struct { + *httptest.Server + Client *http.Client +} + +func NewBridgeServer(t testing.TB, cfg BridgeConfig) *BridgeServer { + t.Helper() + + ctx := cfg.Ctx + if ctx == nil { + ctx = context.Background() + } + + if cfg.Tracer == nil { + cfg.Tracer = noop.NewTracerProvider().Tracer("aibridge/testutil") + } + + if cfg.Handler == nil { + if len(cfg.Providers) == 0 { + t.Fatalf("BridgeConfig: must set either Handler or Providers") + } + if cfg.Recorder == nil { + t.Fatalf("BridgeConfig: Recorder is required when building a RequestBridge") + } + + mgr := mcp.NewServerProxyManager(cfg.MCPProxiers, cfg.Tracer) + // Only init when there are proxiers. This keeps trace output consistent with + // tests that intentionally pass a nil/empty proxy map. + if len(cfg.MCPProxiers) > 0 { + if err := mgr.Init(ctx); err != nil { + t.Fatalf("init MCP manager: %v", err) + } + } + + bridge, err := aibridge.NewRequestBridge(ctx, cfg.Providers, cfg.Recorder, mgr, cfg.Logger, cfg.Metrics, cfg.Tracer) + if err != nil { + t.Fatalf("create RequestBridge: %v", err) + } + cfg.Handler = bridge + } + + srv := httptest.NewUnstartedServer(cfg.Handler) + srv.Config.BaseContext = func(_ net.Listener) context.Context { + if cfg.ActorID == "" { + return ctx + } + return aibridge.AsActor(ctx, cfg.ActorID, nil) + } + srv.Start() + t.Cleanup(srv.Close) + + return &BridgeServer{ + Server: srv, + Client: &http.Client{}, + } +} + +func (b *BridgeServer) NewProviderRequest(t testing.TB, provider string, body []byte) *http.Request { + t.Helper() + + path := "" + switch provider { + case aibridge.ProviderAnthropic: + path = "/anthropic/v1/messages" + case aibridge.ProviderOpenAI: + path = "/openai/v1/chat/completions" + default: + t.Fatalf("unknown provider %q", provider) + } + + req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, b.URL+path, bytes.NewReader(body)) + if err != nil { + t.Fatalf("create request: %v", err) + } + req.Header.Set("Content-Type", "application/json") + return req +} diff --git a/testutil/doc.go b/testutil/doc.go new file mode 100644 index 0000000..03af093 --- /dev/null +++ b/testutil/doc.go @@ -0,0 +1,15 @@ +// Package testutil contains helpers for testing AIBridge. +// +// # Stability +// +// This package is intended for tests and examples within the Coder ecosystem. +// It is not considered a stable public API. +// +// The goal is to make AIBridge tests: +// - easier to read +// - easier to extend +// - less reliant on copy/paste setup +// +// In particular, it provides typed accessors for txtar fixtures and small +// harness structs for standing up mock upstream/MCP/bridge servers. +package testutil diff --git a/testutil/fixture_llm.go b/testutil/fixture_llm.go new file mode 100644 index 0000000..e779aa2 --- /dev/null +++ b/testutil/fixture_llm.go @@ -0,0 +1,118 @@ +package testutil + +import ( + "fmt" + "testing" +) + +const ( + FixtureRequest = "request" + FixtureStreamingResponse = "streaming" + FixtureNonStreamingResponse = "non-streaming" + + FixtureStreamingToolResponse = "streaming/tool-call" + FixtureNonStreamingToolResponse = "non-streaming/tool-call" + + FixtureResponse = "response" +) + +// LLMFixture is a typed view over a TXTAR fixture used for bridged LLM +// interactions (OpenAI/Anthropic-like request + streaming/non-streaming +// responses). +// +// It knows how to: +// - derive a request body for streaming/non-streaming modes +// - select the correct upstream response for the Nth call +// +// It does NOT attempt to interpret or validate the JSON/SSE payload contents. +type LLMFixture struct { + TXTAR TXTARFixture +} + +func NewLLMFixture(txtarFixture TXTARFixture) (LLMFixture, error) { + if len(txtarFixture.Files) == 0 { + return LLMFixture{}, fmt.Errorf("empty txtar fixture") + } + + // We require at minimum a request. + if !txtarFixture.Has(FixtureRequest) { + return LLMFixture{}, fmt.Errorf("missing %q section", FixtureRequest) + } + + return LLMFixture{TXTAR: txtarFixture}, nil +} + +func MustLLMFixture(t testing.TB, txtarFixture TXTARFixture) LLMFixture { + t.Helper() + f, err := NewLLMFixture(txtarFixture) + if err != nil { + t.Fatalf("create LLM fixture: %v", err) + } + return f +} + +// RequestBody returns the request body with the stream flag set according to +// streaming. +func (f LLMFixture) RequestBody(streaming bool) ([]byte, error) { + body, ok := f.TXTAR.File(FixtureRequest) + if !ok { + // NewLLMFixture requires this, but keep it defensive. + return nil, fmt.Errorf("missing %q section", FixtureRequest) + } + return SetJSON(body, "stream", streaming) +} + +// MustRequestBody is a convenience helper for tests. +func (f LLMFixture) MustRequestBody(t testing.TB, streaming bool) []byte { + t.Helper() + body, err := f.RequestBody(streaming) + if err != nil { + t.Fatalf("fixture request body: %v", err) + } + return body +} + +// HasToolCallResponses reports whether the fixture includes the */tool-call +// response sections. +func (f LLMFixture) HasToolCallResponses() bool { + return f.TXTAR.Has(FixtureStreamingToolResponse) || f.TXTAR.Has(FixtureNonStreamingToolResponse) +} + +// Response returns the upstream response body for the given call number. +// +// call is 1-indexed. +func (f LLMFixture) Response(call int, streaming bool) ([]byte, error) { + if call < 1 { + return nil, fmt.Errorf("call must be >= 1, got %d", call) + } + + if call == 1 || !f.HasToolCallResponses() { + if streaming { + if !f.TXTAR.Has(FixtureStreamingResponse) { + return nil, fmt.Errorf("missing %q section", FixtureStreamingResponse) + } + return f.TXTAR.Files[FixtureStreamingResponse], nil + } + + if !f.TXTAR.Has(FixtureNonStreamingResponse) { + return nil, fmt.Errorf("missing %q section", FixtureNonStreamingResponse) + } + return f.TXTAR.Files[FixtureNonStreamingResponse], nil + } + + if call != 2 { + return nil, fmt.Errorf("unexpected call %d; this fixture only supports 1 or 2 calls", call) + } + + if streaming { + if !f.TXTAR.Has(FixtureStreamingToolResponse) { + return nil, fmt.Errorf("missing %q section", FixtureStreamingToolResponse) + } + return f.TXTAR.Files[FixtureStreamingToolResponse], nil + } + + if !f.TXTAR.Has(FixtureNonStreamingToolResponse) { + return nil, fmt.Errorf("missing %q section", FixtureNonStreamingToolResponse) + } + return f.TXTAR.Files[FixtureNonStreamingToolResponse], nil +} diff --git a/testutil/fixture_txtar.go b/testutil/fixture_txtar.go new file mode 100644 index 0000000..063ecbd --- /dev/null +++ b/testutil/fixture_txtar.go @@ -0,0 +1,79 @@ +package testutil + +import ( + "fmt" + "testing" + + "golang.org/x/tools/txtar" +) + +// TXTARFixture is a parsed txtar archive (see golang.org/x/tools/txtar) with a +// convenient map-based API. +// +// Tests should prefer TXTARFixture over ad-hoc txtar.Parse + manual file maps. +// It provides early validation and clearer error messages. +type TXTARFixture struct { + Comment string + Files map[string][]byte +} + +func ParseTXTAR(data []byte) (TXTARFixture, error) { + if len(data) == 0 { + return TXTARFixture{}, fmt.Errorf("empty txtar input") + } + + arc := txtar.Parse(data) + + files := make(map[string][]byte, len(arc.Files)) + for _, f := range arc.Files { + if f.Name == "" { + return TXTARFixture{}, fmt.Errorf("txtar contains a file with an empty name") + } + if _, exists := files[f.Name]; exists { + return TXTARFixture{}, fmt.Errorf("txtar contains duplicate file name %q", f.Name) + } + files[f.Name] = f.Data + } + + return TXTARFixture{ + Comment: string(arc.Comment), + Files: files, + }, nil +} + +func MustParseTXTAR(t testing.TB, data []byte) TXTARFixture { + t.Helper() + f, err := ParseTXTAR(data) + if err != nil { + t.Fatalf("parse txtar: %v", err) + } + return f +} + +func (f TXTARFixture) Has(name string) bool { + _, ok := f.Files[name] + return ok +} + +func (f TXTARFixture) File(name string) ([]byte, bool) { + b, ok := f.Files[name] + return b, ok +} + +func (f TXTARFixture) MustFile(t testing.TB, name string) []byte { + t.Helper() + b, ok := f.File(name) + if !ok { + t.Fatalf("txtar missing section %q; have %v", name, sortedKeys(f.Files)) + } + return b +} + +func (f TXTARFixture) RequireFiles(t testing.TB, names ...string) { + t.Helper() + for _, name := range names { + if !f.Has(name) { + t.Fatalf("txtar missing required section %q; have %v", name, sortedKeys(f.Files)) + } + } +} diff --git a/testutil/helpers.go b/testutil/helpers.go new file mode 100644 index 0000000..13cdefa --- /dev/null +++ b/testutil/helpers.go @@ -0,0 +1,28 @@ +package testutil + +import ( + "fmt" + "sort" + "testing" +) + +func sortedKeys[M ~map[string]V, V any](m M) []string { + keys := make([]string, 0, len(m)) + for k := range m { + keys = append(keys, k) + } + sort.Strings(keys) + return keys +} + +func mustNoError(t testing.TB, err error, format string, args ...any) { + t.Helper() + if err == nil { + return + } + prefix := "" + if format != "" { + prefix = fmt.Sprintf(format, args...) + ": " + } + t.Fatalf("%s%v", prefix, err) +} diff --git a/testutil/http_reflector.go b/testutil/http_reflector.go new file mode 100644 index 0000000..2f55668 --- /dev/null +++ b/testutil/http_reflector.go @@ -0,0 +1,60 @@ +package testutil + +import ( + "bufio" + "bytes" + "context" + "io" + "net" + "net/http" + "net/http/httptest" + "testing" +) + +// HTTPReflectorServer is an httptest.Server that responds with a raw HTTP +// response loaded from a fixture. +// +// The fixture bytes must be a complete HTTP response (status line, headers, +// blank line, body) as accepted by http.ReadResponse. +// +// This is useful for simulating upstream error responses. +type HTTPReflectorServer struct { + *httptest.Server +} + +func NewHTTPReflectorServer(t testing.TB, ctx context.Context, rawHTTPResponse []byte) *HTTPReflectorServer { + t.Helper() + if ctx == nil { + ctx = context.Background() + } + + s := &HTTPReflectorServer{} + + srv := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + mock, err := http.ReadResponse(bufio.NewReader(bytes.NewReader(rawHTTPResponse)), r) + if err != nil { + t.Fatalf("read mock response: %v", err) + } + defer mock.Body.Close() + + for key, values := range mock.Header { + for _, value := range values { + w.Header().Add(key, value) + } + } + + w.WriteHeader(mock.StatusCode) + _, err = io.Copy(w, mock.Body) + if err != nil { + t.Fatalf("copy mock response body: %v", err) + } + })) + srv.Config.BaseContext = func(_ net.Listener) context.Context { + return ctx + } + srv.Start() + t.Cleanup(srv.Close) + + s.Server = srv + return s +} diff --git a/testutil/inspector.go b/testutil/inspector.go new file mode 100644 index 0000000..8582ce3 --- /dev/null +++ b/testutil/inspector.go @@ -0,0 +1,81 @@ +package testutil + +import ( + "encoding/json" + "testing" + + "github.com/coder/aibridge" +) + +// Inspector provides a single place to access data needed for assertions. +// +// It is intentionally small; tests can always drop down to the underlying +// Recorder/MCP/Upstream objects when needed. +type Inspector struct { + Recorder *RecorderSpy + MCP *MCPServer + Upstream *UpstreamServer +} + +func NewInspector(recorder *RecorderSpy, mcpServer *MCPServer, upstream *UpstreamServer) *Inspector { + return &Inspector{Recorder: recorder, MCP: mcpServer, Upstream: upstream} +} + +func (i *Inspector) UpstreamCalls() int { + if i == nil || i.Upstream == nil { + return 0 + } + return i.Upstream.CallCount() +} + +// RequireToolCalledOnceWithArgs asserts that: +// - the bridge recorded exactly one tool usage with the given name +// - the mock MCP server received exactly one call with matching args +func (i *Inspector) RequireToolCalledOnceWithArgs(t testing.TB, tool string, wantArgs any) { + t.Helper() + if i == nil { + t.Fatalf("inspector is nil") + } + if i.Recorder == nil { + t.Fatalf("inspector Recorder is nil") + } + if i.MCP == nil { + t.Fatalf("inspector MCP is nil") + } + + // Verify bridge-side record. + var toolUsages []*aibridge.ToolUsageRecord + for _, u := range i.Recorder.RecordedToolUsages() { + if u.Tool == tool { + toolUsages = append(toolUsages, u) + } + } + if len(toolUsages) != 1 { + t.Fatalf("tool usages for %q: got %d, want 1", tool, len(toolUsages)) + } + + wantJSON, err := json.Marshal(wantArgs) + if err != nil { + t.Fatalf("marshal wantArgs: %v", err) + } + gotJSON, err := json.Marshal(toolUsages[0].Args) + if err != nil { + t.Fatalf("marshal recorded tool args: %v", err) + } + if string(wantJSON) != string(gotJSON) { + t.Fatalf("recorded tool args mismatch\nwant: %s\ngot: %s", string(wantJSON), string(gotJSON)) + } + + // Verify MCP-side receipt. + invocations := i.MCP.CallsByTool(tool) + if len(invocations) != 1 { + t.Fatalf("MCP calls for %q: got %d, want 1", tool, len(invocations)) + } + gotJSON, err = json.Marshal(invocations[0]) + if err != nil { + t.Fatalf("marshal MCP call args: %v", err) + } + if string(wantJSON) != string(gotJSON) { + t.Fatalf("MCP call args mismatch\nwant: %s\ngot: %s", string(wantJSON), string(gotJSON)) + } +} diff --git a/testutil/json.go b/testutil/json.go new file mode 100644 index 0000000..69aa10d --- /dev/null +++ b/testutil/json.go @@ -0,0 +1,32 @@ +package testutil + +import ( + "fmt" + "testing" + + "github.com/tidwall/sjson" +) + +// SetJSON sets a JSON value at the given key/path (tidwall/sjson syntax) and +// returns the updated JSON bytes. +func SetJSON(in []byte, key string, val any) ([]byte, error) { + if len(in) == 0 { + return nil, fmt.Errorf("empty JSON input") + } + if key == "" { + return nil, fmt.Errorf("empty JSON key") + } + + out, err := sjson.SetBytes(in, key, val) + if err != nil { + return nil, err + } + return out, nil +} + +func MustSetJSON(t testing.TB, in []byte, key string, val any) []byte { + t.Helper() + out, err := SetJSON(in, key, val) + mustNoError(t, err, "set JSON") + return out +} diff --git a/testutil/mcp_server.go b/testutil/mcp_server.go new file mode 100644 index 0000000..4cc2f82 --- /dev/null +++ b/testutil/mcp_server.go @@ -0,0 +1,122 @@ +package testutil + +import ( + "context" + "fmt" + "net/http" + "net/http/httptest" + "sync" + "testing" + + "cdr.dev/slog" + "github.com/coder/aibridge/mcp" + mcplib "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" + "go.opentelemetry.io/otel/trace" +) + +const ( + // ToolCoderListWorkspaces matches the tool name used throughout the current tests. + ToolCoderListWorkspaces = "coder_list_workspaces" +) + +func DefaultCoderToolNames() []string { + return []string{ + ToolCoderListWorkspaces, + "coder_list_templates", + "coder_template_version_parameters", + "coder_get_authenticated_user", + "coder_create_workspace_build", + } +} + +type MCPToolResultFunc func(ctx context.Context, request mcplib.CallToolRequest) (*mcplib.CallToolResult, error) + +type MCPServer struct { + *httptest.Server + + callsMu sync.Mutex + calls map[string][]any +} + +type MCPServerOption func(*mcpServerConfig) + +type mcpServerConfig struct { + toolResultFn MCPToolResultFunc +} + +func WithMCPToolResult(fn MCPToolResultFunc) MCPServerOption { + return func(cfg *mcpServerConfig) { + cfg.toolResultFn = fn + } +} + +func NewMCPServer(t testing.TB, toolNames []string, opts ...MCPServerOption) *MCPServer { + t.Helper() + + cfg := mcpServerConfig{ + toolResultFn: func(ctx context.Context, request mcplib.CallToolRequest) (*mcplib.CallToolResult, error) { + return mcplib.NewToolResultText("mock"), nil + }, + } + for _, opt := range opts { + if opt != nil { + opt(&cfg) + } + } + + s := &MCPServer{calls: make(map[string][]any)} + + mcpSrv := server.NewMCPServer( + "Mock MCP server", + "1.0.0", + server.WithToolCapabilities(true), + ) + + for _, name := range toolNames { + name := name // capture + tool := mcplib.NewTool(name, mcplib.WithDescription(fmt.Sprintf("Mock of the %s tool", name))) + mcpSrv.AddTool(tool, func(ctx context.Context, request mcplib.CallToolRequest) (*mcplib.CallToolResult, error) { + s.addCall(request.Params.Name, request.Params.Arguments) + return cfg.toolResultFn(ctx, request) + }) + } + + h := server.NewStreamableHTTPServer(mcpSrv) + s.Server = httptest.NewServer(h) + t.Cleanup(s.Server.Close) + + return s +} + +func (s *MCPServer) addCall(tool string, args any) { + s.callsMu.Lock() + defer s.callsMu.Unlock() + + s.calls[tool] = append(s.calls[tool], args) +} + +func (s *MCPServer) CallsByTool(name string) []any { + s.callsMu.Lock() + defer s.callsMu.Unlock() + + calls := s.calls[name] + out := make([]any, len(calls)) + copy(out, calls) + return out +} + +func (s *MCPServer) Proxiers(t testing.TB, serverName string, logger slog.Logger, tracer trace.Tracer) map[string]mcp.ServerProxier { + t.Helper() + + proxy, err := mcp.NewStreamableHTTPServerProxy(serverName, s.URL, nil, nil, nil, logger, tracer) + mustNoError(t, err, "create MCP proxy") + return map[string]mcp.ServerProxier{proxy.Name(): proxy} +} + +func (s *MCPServer) Handler() http.Handler { + if s.Server == nil { + return nil + } + return s.Server.Config.Handler +} diff --git a/testutil/recorder_spy.go b/testutil/recorder_spy.go new file mode 100644 index 0000000..a599d48 --- /dev/null +++ b/testutil/recorder_spy.go @@ -0,0 +1,142 @@ +package testutil + +import ( + "context" + "fmt" + "slices" + "sync" + "testing" + "time" + + "github.com/coder/aibridge" +) + +var _ aibridge.Recorder = (*RecorderSpy)(nil) + +// RecorderSpy is a thread-safe in-memory implementation of [aibridge.Recorder] +// intended for tests. +// +// It provides query helpers so tests can assert on recorded interceptions, +// token/prompt usage, and tool usage. +type RecorderSpy struct { + mu sync.Mutex + + interceptions []*aibridge.InterceptionRecord + tokenUsages []*aibridge.TokenUsageRecord + userPrompts []*aibridge.PromptUsageRecord + toolUsages []*aibridge.ToolUsageRecord + interceptionsEnd map[string]time.Time +} + +func (m *RecorderSpy) RecordInterception(ctx context.Context, req *aibridge.InterceptionRecord) error { + m.mu.Lock() + defer m.mu.Unlock() + m.interceptions = append(m.interceptions, req) + return nil +} + +func (m *RecorderSpy) RecordInterceptionEnded(ctx context.Context, req *aibridge.InterceptionRecordEnded) error { + m.mu.Lock() + defer m.mu.Unlock() + if m.interceptionsEnd == nil { + m.interceptionsEnd = make(map[string]time.Time) + } + if !slices.ContainsFunc(m.interceptions, func(intc *aibridge.InterceptionRecord) bool { return intc.ID == req.ID }) { + return fmt.Errorf("interception id not found: %q", req.ID) + } + m.interceptionsEnd[req.ID] = req.EndedAt + return nil +} + +func (m *RecorderSpy) RecordPromptUsage(ctx context.Context, req *aibridge.PromptUsageRecord) error { + m.mu.Lock() + defer m.mu.Unlock() + m.userPrompts = append(m.userPrompts, req) + return nil +} + +func (m *RecorderSpy) RecordTokenUsage(ctx context.Context, req *aibridge.TokenUsageRecord) error { + m.mu.Lock() + defer m.mu.Unlock() + m.tokenUsages = append(m.tokenUsages, req) + return nil +} + +func (m *RecorderSpy) RecordToolUsage(ctx context.Context, req *aibridge.ToolUsageRecord) error { + m.mu.Lock() + defer m.mu.Unlock() + m.toolUsages = append(m.toolUsages, req) + return nil +} + +// RecordedTokenUsages returns a shallow clone of recorded token usages. +func (m *RecorderSpy) RecordedTokenUsages() []*aibridge.TokenUsageRecord { + m.mu.Lock() + defer m.mu.Unlock() + return slices.Clone(m.tokenUsages) +} + +// RecordedPromptUsages returns a shallow clone of recorded prompt usages. +func (m *RecorderSpy) RecordedPromptUsages() []*aibridge.PromptUsageRecord { + m.mu.Lock() + defer m.mu.Unlock() + return slices.Clone(m.userPrompts) +} + +// RecordedToolUsages returns a shallow clone of recorded tool usages. +func (m *RecorderSpy) RecordedToolUsages() []*aibridge.ToolUsageRecord { + m.mu.Lock() + defer m.mu.Unlock() + return slices.Clone(m.toolUsages) +} + +// RecordedInterceptions returns a shallow clone of recorded interceptions. +func (m *RecorderSpy) RecordedInterceptions() []*aibridge.InterceptionRecord { + m.mu.Lock() + defer m.mu.Unlock() + return slices.Clone(m.interceptions) +} + +// RequireAllInterceptionsEnded fails the test if any recorded interception did +// not receive a corresponding RecordInterceptionEnded call. +func (m *RecorderSpy) RequireAllInterceptionsEnded(t testing.TB) { + t.Helper() + + m.mu.Lock() + defer m.mu.Unlock() + + gotEnded := 0 + if m.interceptionsEnd != nil { + gotEnded = len(m.interceptionsEnd) + } + + if len(m.interceptions) != gotEnded { + t.Fatalf("got %d interception ended calls, want %d", gotEnded, len(m.interceptions)) + } + for _, intc := range m.interceptions { + if m.interceptionsEnd == nil { + t.Fatalf("interception with id %q has not been ended", intc.ID) + } + if _, ok := m.interceptionsEnd[intc.ID]; !ok { + t.Fatalf("interception with id %q has not been ended", intc.ID) + } + } +} + +// TotalInputTokens sums input tokens from token usage records. +func TotalInputTokens(in []*aibridge.TokenUsageRecord) int64 { + var total int64 + for _, el := range in { + total += el.Input + } + return total +} + +// TotalOutputTokens sums output tokens from token usage records. +func TotalOutputTokens(in []*aibridge.TokenUsageRecord) int64 { + var total int64 + for _, el := range in { + total += el.Output + } + return total +} diff --git a/testutil/upstream_server.go b/testutil/upstream_server.go new file mode 100644 index 0000000..6e9d161 --- /dev/null +++ b/testutil/upstream_server.go @@ -0,0 +1,261 @@ +package testutil + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net" + "net/http" + "net/http/httptest" + "strings" + "sync" + "sync/atomic" + "testing" + "time" +) + +// UpstreamRequest captures a single request received by an [UpstreamServer]. +// +// It is intended for test assertions (e.g. validating headers, paths, and +// request bodies). +type UpstreamRequest struct { + Call int + Method string + Path string + Header http.Header + Body []byte +} + +// UpstreamServer is an httptest.Server that mimics an upstream LLM provider. +// +// It is driven by an [LLMFixture]. It chooses between the streaming and +// non-streaming fixture responses based on the request body ("stream": true) +// and the request path (AWS Bedrock uses a streaming-specific path). +// +// For injected-tool tests, it can serve a second response ("*/tool-call") on +// the 2nd request. +type UpstreamServer struct { + *httptest.Server + + fixture LLMFixture + + callCount atomic.Uint32 + + requestsMu sync.Mutex + requests []UpstreamRequest + + // statusCode is only applied for non-streaming responses, matching the current + // test behavior (streaming responses default to 200 once the stream begins). + statusCode atomic.Int32 + + responseMutator func(call int, resp []byte) []byte +} + +type UpstreamOption func(*UpstreamServer) + +func WithUpstreamNonStreamingStatusCode(code int) UpstreamOption { + return func(s *UpstreamServer) { + s.statusCode.Store(int32(code)) + } +} + +func WithUpstreamResponseMutator(fn func(call int, resp []byte) []byte) UpstreamOption { + return func(s *UpstreamServer) { + s.responseMutator = fn + } +} + +func NewUpstreamServer(t testing.TB, ctx context.Context, fixture LLMFixture, opts ...UpstreamOption) *UpstreamServer { + t.Helper() + if ctx == nil { + ctx = context.Background() + } + + s := &UpstreamServer{fixture: fixture} + for _, opt := range opts { + if opt != nil { + opt(s) + } + } + + h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + call := int(s.callCount.Add(1)) + + body, err := io.ReadAll(r.Body) + _ = r.Body.Close() + if err != nil { + t.Fatalf("read upstream request body: %v", err) + } + + s.recordRequest(call, r, body) + + type msg struct { + Stream bool `json:"stream"` + } + var reqMsg msg + if err := json.Unmarshal(body, &reqMsg); err != nil { + t.Fatalf("unmarshal upstream request body: %v", err) + } + + // AWS Bedrock uses a streaming-specific path suffix. + isStreaming := reqMsg.Stream || strings.HasSuffix(r.URL.Path, "invoke-with-response-stream") + + if !isStreaming { + respBody, err := s.fixture.Response(call, false) + if err != nil { + t.Fatalf("select upstream response: %v", err) + } + if s.responseMutator != nil { + respBody = s.responseMutator(call, respBody) + } + + code := int(s.statusCode.Load()) + if code == 0 { + code = http.StatusOK + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(code) + _, _ = w.Write(respBody) + return + } + + // Streaming response. + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Cache-Control", "no-cache") + w.Header().Set("Connection", "keep-alive") + w.Header().Set("Access-Control-Allow-Origin", "*") + + respBody, err := s.fixture.Response(call, true) + if err != nil { + t.Fatalf("select upstream response: %v", err) + } + if s.responseMutator != nil { + respBody = s.responseMutator(call, respBody) + } + + scanner := bufio.NewScanner(bytes.NewReader(respBody)) + // Allow large SSE lines; fixtures may exceed the default 64KiB scanner limit. + scanner.Buffer(make([]byte, 0, 64*1024), 2*1024*1024) + flusher, ok := w.(http.Flusher) + if !ok { + http.Error(w, "streaming unsupported", http.StatusInternalServerError) + return + } + + for scanner.Scan() { + line := scanner.Text() + fmt.Fprintf(w, "%s\n", line) + flusher.Flush() + } + + if err := scanner.Err(); err != nil { + http.Error(w, fmt.Sprintf("error reading fixture: %v", err), http.StatusInternalServerError) + return + } + }) + + srv := httptest.NewUnstartedServer(h) + srv.Config.BaseContext = func(_ net.Listener) context.Context { + return ctx + } + srv.Start() + t.Cleanup(srv.Close) + + s.Server = srv + return s +} + +func (s *UpstreamServer) recordRequest(call int, r *http.Request, body []byte) { + if s == nil { + return + } + if r == nil { + return + } + + bodyCopy := make([]byte, len(body)) + copy(bodyCopy, body) + + s.requestsMu.Lock() + s.requests = append(s.requests, UpstreamRequest{ + Call: call, + Method: r.Method, + Path: r.URL.Path, + Header: r.Header.Clone(), + Body: bodyCopy, + }) + s.requestsMu.Unlock() +} + +// Requests returns a snapshot of requests received by this server. +func (s *UpstreamServer) Requests() []UpstreamRequest { + if s == nil { + return nil + } + + s.requestsMu.Lock() + defer s.requestsMu.Unlock() + + out := make([]UpstreamRequest, len(s.requests)) + for i, req := range s.requests { + bodyCopy := make([]byte, len(req.Body)) + copy(bodyCopy, req.Body) + + out[i] = UpstreamRequest{ + Call: req.Call, + Method: req.Method, + Path: req.Path, + Header: req.Header.Clone(), + Body: bodyCopy, + } + } + return out +} + +// LastRequest returns the most recently received request, if any. +func (s *UpstreamServer) LastRequest() (UpstreamRequest, bool) { + if s == nil { + return UpstreamRequest{}, false + } + + s.requestsMu.Lock() + defer s.requestsMu.Unlock() + if len(s.requests) == 0 { + return UpstreamRequest{}, false + } + req := s.requests[len(s.requests)-1] + + bodyCopy := make([]byte, len(req.Body)) + copy(bodyCopy, req.Body) + + return UpstreamRequest{ + Call: req.Call, + Method: req.Method, + Path: req.Path, + Header: req.Header.Clone(), + Body: bodyCopy, + }, true +} + +func (s *UpstreamServer) CallCount() int { + return int(s.callCount.Load()) +} + +func (s *UpstreamServer) RequireCallCountEventually(t testing.TB, want int) { + t.Helper() + + deadline := time.Now().Add(10 * time.Second) + for { + if s.CallCount() == want { + return + } + if time.Now().After(deadline) { + t.Fatalf("upstream call count: got %d, want %d", s.CallCount(), want) + } + time.Sleep(50 * time.Millisecond) + } +} diff --git a/trace_integration_test.go b/trace_integration_test.go index ee6574d..b41ea63 100644 --- a/trace_integration_test.go +++ b/trace_integration_test.go @@ -3,6 +3,7 @@ package aibridge_test import ( "context" "fmt" + "io" "net/http" "net/http/httptest" "slices" @@ -14,6 +15,7 @@ import ( "cdr.dev/slog/sloggers/slogtest" "github.com/coder/aibridge" "github.com/coder/aibridge/mcp" + "github.com/coder/aibridge/testutil" "github.com/coder/aibridge/tracing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -22,7 +24,6 @@ import ( "go.opentelemetry.io/otel/codes" sdktrace "go.opentelemetry.io/otel/sdk/trace" "go.opentelemetry.io/otel/sdk/trace/tracetest" - "golang.org/x/tools/txtar" ) // expect 'count' amount of traces named 'name' with status 'status' @@ -85,18 +86,13 @@ func TestTraceAnthropic(t *testing.T) { }, } - arc := txtar.Parse(antSingleBuiltinTool) - - files := filesMap(arc) - require.Contains(t, files, fixtureRequest) - require.Contains(t, files, fixtureStreamingResponse) - require.Contains(t, files, fixtureNonStreamingResponse) - - fixtureReqBody := files[fixtureRequest] + fixture := testutil.MustParseTXTAR(t, antSingleBuiltinTool) + fixture.RequireFiles(t, testutil.FixtureRequest, testutil.FixtureStreamingResponse, testutil.FixtureNonStreamingResponse) + llm := testutil.MustLLMFixture(t, fixture) for _, tc := range cases { t.Run(fmt.Sprintf("%s/streaming=%v", t.Name(), tc.streaming), func(t *testing.T) { - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + ctx, cancel := context.WithTimeout(t.Context(), 30*time.Second) t.Cleanup(cancel) sr := tracetest.NewSpanRecorder() @@ -104,29 +100,28 @@ func TestTraceAnthropic(t *testing.T) { tracer := tp.Tracer(t.Name()) defer func() { _ = tp.Shutdown(t.Context()) }() - reqBody, err := setJSON(fixtureReqBody, "stream", tc.streaming) - require.NoError(t, err) + reqBody := llm.MustRequestBody(t, tc.streaming) - mockAPI := newMockServer(ctx, t, files, nil) - t.Cleanup(mockAPI.Close) + upstream := testutil.NewUpstreamServer(t, ctx, llm) var bedrockCfg *aibridge.AWSBedrockConfig if tc.bedrock { - bedrockCfg = testBedrockCfg(mockAPI.URL) + bedrockCfg = testBedrockCfg(upstream.URL) } - provider := aibridge.NewAnthropicProvider(anthropicCfg(mockAPI.URL, apiKey), bedrockCfg) + provider := aibridge.NewAnthropicProvider(anthropicCfg(upstream.URL, apiKey), bedrockCfg) srv, recorder := newTestSrv(t, ctx, provider, nil, tracer) - req := createAnthropicMessagesReq(t, srv.URL, reqBody) - client := &http.Client{} - resp, err := client.Do(req) + req := srv.NewProviderRequest(t, provider.Name(), reqBody) + resp, err := srv.Client.Do(req) require.NoError(t, err) require.Equal(t, http.StatusOK, resp.StatusCode) - defer resp.Body.Close() - srv.Close() + _, err = io.Copy(io.Discard, resp.Body) + require.NoError(t, err) + require.NoError(t, resp.Body.Close()) - require.Equal(t, 1, len(recorder.interceptions)) - intcID := recorder.interceptions[0].ID + interceptions := recorder.RecordedInterceptions() + require.Len(t, interceptions, 1) + intcID := interceptions[0].ID model := gjson.Get(string(reqBody), "model").Str if tc.bedrock { @@ -205,58 +200,48 @@ func TestTraceAnthropicErr(t *testing.T) { for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + ctx, cancel := context.WithTimeout(t.Context(), 30*time.Second) t.Cleanup(cancel) - var arc *txtar.Archive + fixtureBytes := antNonStreamErr if tc.streaming { - arc = txtar.Parse(antMidStreamErr) - } else { - arc = txtar.Parse(antNonStreamErr) - } - - files := filesMap(arc) - require.Contains(t, files, fixtureRequest) - if tc.streaming { - require.Contains(t, files, fixtureStreamingResponse) - } else { - require.Contains(t, files, fixtureNonStreamingResponse) + fixtureBytes = antMidStreamErr } - fixtureReqBody := files[fixtureRequest] + fixture := testutil.MustParseTXTAR(t, fixtureBytes) + llm := testutil.MustLLMFixture(t, fixture) sr := tracetest.NewSpanRecorder() tp := sdktrace.NewTracerProvider(sdktrace.WithSpanProcessor(sr)) tracer := tp.Tracer(t.Name()) defer func() { _ = tp.Shutdown(t.Context()) }() - reqBody, err := setJSON(fixtureReqBody, "stream", tc.streaming) - require.NoError(t, err) + reqBody := llm.MustRequestBody(t, tc.streaming) - mockAPI := newMockServer(ctx, t, files, nil) - t.Cleanup(mockAPI.Close) + upstream := testutil.NewUpstreamServer(t, ctx, llm) var bedrockCfg *aibridge.AWSBedrockConfig if tc.bedrock { - bedrockCfg = testBedrockCfg(mockAPI.URL) + bedrockCfg = testBedrockCfg(upstream.URL) } - provider := aibridge.NewAnthropicProvider(anthropicCfg(mockAPI.URL, apiKey), bedrockCfg) + provider := aibridge.NewAnthropicProvider(anthropicCfg(upstream.URL, apiKey), bedrockCfg) srv, recorder := newTestSrv(t, ctx, provider, nil, tracer) - req := createAnthropicMessagesReq(t, srv.URL, reqBody) - client := &http.Client{} - resp, err := client.Do(req) + req := srv.NewProviderRequest(t, provider.Name(), reqBody) + resp, err := srv.Client.Do(req) require.NoError(t, err) if tc.streaming { require.Equal(t, http.StatusOK, resp.StatusCode) } else { require.Equal(t, http.StatusInternalServerError, resp.StatusCode) } - defer resp.Body.Close() - srv.Close() + _, err = io.Copy(io.Discard, resp.Body) + require.NoError(t, err) + require.NoError(t, resp.Body.Close()) - require.Equal(t, 1, len(recorder.interceptions)) - intcID := recorder.interceptions[0].ID + interceptions := recorder.RecordedInterceptions() + require.Len(t, interceptions, 1) + intcID := interceptions[0].ID totalCount := 0 for _, e := range tc.expect { @@ -326,44 +311,32 @@ func TestAnthropicInjectedToolsTrace(t *testing.T) { tracer := tp.Tracer(t.Name()) defer func() { _ = tp.Shutdown(t.Context()) }() - configureFn := func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) { - logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) + h := runInjectedToolTest(t, aibridge.ProviderAnthropic, antSingleInjectedTool, tc.streaming, tracer, func(upstreamURL string) []aibridge.Provider { var bedrockCfg *aibridge.AWSBedrockConfig if tc.bedrock { - bedrockCfg = testBedrockCfg(addr) + bedrockCfg = testBedrockCfg(upstreamURL) } - providers := []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), bedrockCfg)} - return aibridge.NewRequestBridge(t.Context(), providers, client, srvProxyMgr, logger, nil, tracer) - } - - var reqBody string - var reqPath string - reqFunc := func(t *testing.T, baseURL string, input []byte) *http.Request { - reqBody = string(input) - r := createAnthropicMessagesReq(t, baseURL, input) - reqPath = r.URL.Path - return r - } - - // Build the requirements & make the assertions which are common to all providers. - recorderClient, _, proxies, resp := setupInjectedToolTest(t, antSingleInjectedTool, tc.streaming, configureFn, reqFunc) + return []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(upstreamURL, apiKey), bedrockCfg)} + }) - defer resp.Body.Close() + defer h.Response.Body.Close() - require.Len(t, recorderClient.interceptions, 1) - intcID := recorderClient.interceptions[0].ID + interceptions := h.Recorder.RecordedInterceptions() + require.Len(t, interceptions, 1) + intcID := interceptions[0].ID - model := gjson.Get(string(reqBody), "model").Str + reqBody := string(h.RequestBody) + model := gjson.Get(reqBody, "model").Str if tc.bedrock { model = "beddel" } - for _, proxy := range proxies { + for _, proxy := range h.MCPProxiers { require.NotEmpty(t, proxy.ListTools()) tool := proxy.ListTools()[0] attrs := []attribute.KeyValue{ - attribute.String(tracing.RequestPath, reqPath), + attribute.String(tracing.RequestPath, h.RequestPath), attribute.String(tracing.InterceptionID, intcID), attribute.String(tracing.Provider, aibridge.ProviderAnthropic), attribute.String(tracing.Model, model), @@ -421,42 +394,36 @@ func TestTraceOpenAI(t *testing.T) { } for _, tc := range cases { - t.Run(t.Name(), func(t *testing.T) { - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + t.Run(tc.name, func(t *testing.T) { + ctx, cancel := context.WithTimeout(t.Context(), 30*time.Second) t.Cleanup(cancel) - arc := txtar.Parse(tc.fixture) - - files := filesMap(arc) - require.Contains(t, files, fixtureRequest) - require.Contains(t, files, fixtureStreamingResponse) - require.Contains(t, files, fixtureNonStreamingResponse) - - fixtureReqBody := files[fixtureRequest] + fixture := testutil.MustParseTXTAR(t, tc.fixture) + fixture.RequireFiles(t, testutil.FixtureRequest, testutil.FixtureStreamingResponse, testutil.FixtureNonStreamingResponse) + llm := testutil.MustLLMFixture(t, fixture) sr := tracetest.NewSpanRecorder() tp := sdktrace.NewTracerProvider(sdktrace.WithSpanProcessor(sr)) tracer := tp.Tracer(t.Name()) defer func() { _ = tp.Shutdown(t.Context()) }() - reqBody, err := setJSON(fixtureReqBody, "stream", tc.streaming) - require.NoError(t, err) + reqBody := llm.MustRequestBody(t, tc.streaming) - mockAPI := newMockServer(ctx, t, files, nil) - t.Cleanup(mockAPI.Close) - provider := aibridge.NewOpenAIProvider(openaiCfg(mockAPI.URL, apiKey)) + upstream := testutil.NewUpstreamServer(t, ctx, llm) + provider := aibridge.NewOpenAIProvider(openaiCfg(upstream.URL, apiKey)) srv, recorder := newTestSrv(t, ctx, provider, nil, tracer) - req := createOpenAIChatCompletionsReq(t, srv.URL, reqBody) - client := &http.Client{} - resp, err := client.Do(req) + req := srv.NewProviderRequest(t, provider.Name(), reqBody) + resp, err := srv.Client.Do(req) require.NoError(t, err) require.Equal(t, http.StatusOK, resp.StatusCode) - defer resp.Body.Close() - srv.Close() + _, err = io.Copy(io.Discard, resp.Body) + require.NoError(t, err) + require.NoError(t, resp.Body.Close()) - require.Equal(t, 1, len(recorder.interceptions)) - intcID := recorder.interceptions[0].ID + interceptions := recorder.RecordedInterceptions() + require.Len(t, interceptions, 1) + intcID := interceptions[0].ID totalCount := 0 for _, e := range tc.expect { @@ -511,54 +478,44 @@ func TestTraceOpenAIErr(t *testing.T) { } for _, tc := range cases { - t.Run(t.Name(), func(t *testing.T) { - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + t.Run(tc.name, func(t *testing.T) { + ctx, cancel := context.WithTimeout(t.Context(), 30*time.Second) t.Cleanup(cancel) - var arc *txtar.Archive - if tc.streaming { - arc = txtar.Parse(oaiMidStreamErr) - } else { - arc = txtar.Parse(oaiNonStreamErr) - } - - files := filesMap(arc) - require.Contains(t, files, fixtureRequest) + fixtureBytes := oaiNonStreamErr if tc.streaming { - require.Contains(t, files, fixtureStreamingResponse) - } else { - require.Contains(t, files, fixtureNonStreamingResponse) + fixtureBytes = oaiMidStreamErr } - fixtureReqBody := files[fixtureRequest] + fixture := testutil.MustParseTXTAR(t, fixtureBytes) + llm := testutil.MustLLMFixture(t, fixture) sr := tracetest.NewSpanRecorder() tp := sdktrace.NewTracerProvider(sdktrace.WithSpanProcessor(sr)) tracer := tp.Tracer(t.Name()) defer func() { _ = tp.Shutdown(t.Context()) }() - reqBody, err := setJSON(fixtureReqBody, "stream", tc.streaming) - require.NoError(t, err) + reqBody := llm.MustRequestBody(t, tc.streaming) - mockAPI := newMockServer(ctx, t, files, nil) - t.Cleanup(mockAPI.Close) - provider := aibridge.NewOpenAIProvider(openaiCfg(mockAPI.URL, apiKey)) + upstream := testutil.NewUpstreamServer(t, ctx, llm) + provider := aibridge.NewOpenAIProvider(openaiCfg(upstream.URL, apiKey)) srv, recorder := newTestSrv(t, ctx, provider, nil, tracer) - req := createOpenAIChatCompletionsReq(t, srv.URL, reqBody) - client := &http.Client{} - resp, err := client.Do(req) + req := srv.NewProviderRequest(t, provider.Name(), reqBody) + resp, err := srv.Client.Do(req) require.NoError(t, err) if tc.streaming { require.Equal(t, http.StatusOK, resp.StatusCode) } else { require.Equal(t, http.StatusInternalServerError, resp.StatusCode) } - defer resp.Body.Close() - srv.Close() + _, err = io.Copy(io.Discard, resp.Body) + require.NoError(t, err) + require.NoError(t, resp.Body.Close()) - require.Equal(t, 1, len(recorder.interceptions)) - intcID := recorder.interceptions[0].ID + interceptions := recorder.RecordedInterceptions() + require.Len(t, interceptions, 1) + intcID := interceptions[0].ID totalCount := 0 for _, e := range tc.expect { @@ -591,38 +548,25 @@ func TestOpenAIInjectedToolsTrace(t *testing.T) { tracer := tp.Tracer(t.Name()) defer func() { _ = tp.Shutdown(t.Context()) }() - configureFn := func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) { - logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - providers := []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))} - return aibridge.NewRequestBridge(t.Context(), providers, client, srvProxyMgr, logger, nil, tracer) - } - - var reqBody string - var reqPath string - reqFunc := func(t *testing.T, baseURL string, input []byte) *http.Request { - reqBody = string(input) - r := createOpenAIChatCompletionsReq(t, baseURL, input) - reqPath = r.URL.Path - return r - } - - // Build the requirements & make the assertions which are common to all providers. - recorderClient, _, proxies, resp := setupInjectedToolTest(t, oaiSingleInjectedTool, streaming, configureFn, reqFunc) + h := runInjectedToolTest(t, aibridge.ProviderOpenAI, oaiSingleInjectedTool, streaming, tracer, func(upstreamURL string) []aibridge.Provider { + return []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(upstreamURL, apiKey))} + }) - defer resp.Body.Close() + defer h.Response.Body.Close() - require.Len(t, recorderClient.interceptions, 1) - intcID := recorderClient.interceptions[0].ID + interceptions := h.Recorder.RecordedInterceptions() + require.Len(t, interceptions, 1) + intcID := interceptions[0].ID - for _, proxy := range proxies { + for _, proxy := range h.MCPProxiers { require.NotEmpty(t, proxy.ListTools()) tool := proxy.ListTools()[0] attrs := []attribute.KeyValue{ - attribute.String(tracing.RequestPath, reqPath), + attribute.String(tracing.RequestPath, h.RequestPath), attribute.String(tracing.InterceptionID, intcID), attribute.String(tracing.Provider, aibridge.ProviderOpenAI), - attribute.String(tracing.Model, gjson.Get(reqBody, "model").Str), + attribute.String(tracing.Model, gjson.Get(string(h.RequestBody), "model").Str), attribute.String(tracing.InitiatorID, userID), attribute.String(tracing.MCPInput, "{\"owner\":\"admin\"}"), attribute.String(tracing.MCPToolName, "coder_list_workspaces"), @@ -639,13 +583,13 @@ func TestOpenAIInjectedToolsTrace(t *testing.T) { func TestTracePassthrough(t *testing.T) { t.Parallel() - arc := txtar.Parse(oaiFallthrough) - files := filesMap(arc) + fixture := testutil.MustParseTXTAR(t, oaiFallthrough) + respBody := fixture.MustFile(t, testutil.FixtureResponse) upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) - _, _ = w.Write(files[fixtureResponse]) + _, _ = w.Write(respBody) })) t.Cleanup(upstream.Close) @@ -687,9 +631,7 @@ func TestNewServerProxyManagerTraces(t *testing.T) { defer func() { _ = tp.Shutdown(t.Context()) }() serverName := "serverName" - srv, _ := createMockMCPSrv(t) - mcpSrv := httptest.NewServer(srv) - t.Cleanup(mcpSrv.Close) + mcpSrv := testutil.NewMCPServer(t, testutil.DefaultCoderToolNames()) logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) proxy, err := mcp.NewStreamableHTTPServerProxy(serverName, mcpSrv.URL, nil, nil, nil, logger, tracer)