diff --git a/bridge_integration_test.go b/bridge_integration_test.go index 925571c..f73138e 100644 --- a/bridge_integration_test.go +++ b/bridge_integration_test.go @@ -171,21 +171,24 @@ func TestAnthropicMessages(t *testing.T) { // One for message_start, one for message_delta. expectedTokenRecordings = 2 } - require.Len(t, recorderClient.tokenUsages, expectedTokenRecordings) + tokenUsages := recorderClient.RecordedTokenUsages() + require.Len(t, tokenUsages, expectedTokenRecordings) - assert.EqualValues(t, tc.expectedInputTokens, calculateTotalInputTokens(recorderClient.tokenUsages), "input tokens miscalculated") - assert.EqualValues(t, tc.expectedOutputTokens, calculateTotalOutputTokens(recorderClient.tokenUsages), "output tokens miscalculated") + assert.EqualValues(t, tc.expectedInputTokens, calculateTotalInputTokens(tokenUsages), "input tokens miscalculated") + assert.EqualValues(t, tc.expectedOutputTokens, calculateTotalOutputTokens(tokenUsages), "output tokens miscalculated") - require.Len(t, recorderClient.toolUsages, 1) - assert.Equal(t, "Read", recorderClient.toolUsages[0].Tool) - require.IsType(t, json.RawMessage{}, recorderClient.toolUsages[0].Args) + toolUsages := recorderClient.RecordedToolUsages() + require.Len(t, toolUsages, 1) + assert.Equal(t, "Read", toolUsages[0].Tool) + require.IsType(t, json.RawMessage{}, toolUsages[0].Args) var args map[string]any - require.NoError(t, json.Unmarshal(recorderClient.toolUsages[0].Args.(json.RawMessage), &args)) + require.NoError(t, json.Unmarshal(toolUsages[0].Args.(json.RawMessage), &args)) require.Contains(t, args, "file_path") assert.Equal(t, "/tmp/blah/foo", args["file_path"]) - require.Len(t, recorderClient.userPrompts, 1) - assert.Equal(t, "read the foo file", recorderClient.userPrompts[0].Prompt) + promptUsages := recorderClient.RecordedPromptUsages() + require.Len(t, promptUsages, 1) + assert.Equal(t, "read the foo file", promptUsages[0].Prompt) recorderClient.verifyAllInterceptionsEnded(t) }) @@ -346,8 +349,9 @@ func TestAWSBedrockIntegration(t *testing.T) { // and the interception data. require.Equal(t, requestCount, 1) require.Equal(t, bedrockCfg.Model, receivedModelName) - require.Len(t, recorderClient.interceptions, 1) - require.Equal(t, recorderClient.interceptions[0].Model, bedrockCfg.Model) + interceptions := recorderClient.RecordedInterceptions() + require.Len(t, interceptions, 1) + require.Equal(t, interceptions[0].Model, bedrockCfg.Model) recorderClient.verifyAllInterceptionsEnded(t) }) } @@ -437,18 +441,21 @@ func TestOpenAIChatCompletions(t *testing.T) { assert.Equal(t, "[DONE]", lastEvent.Data) } - require.Len(t, recorderClient.tokenUsages, 1) - assert.EqualValues(t, tc.expectedInputTokens, calculateTotalInputTokens(recorderClient.tokenUsages), "input tokens miscalculated") - assert.EqualValues(t, tc.expectedOutputTokens, calculateTotalOutputTokens(recorderClient.tokenUsages), "output tokens miscalculated") + tokenUsages := recorderClient.RecordedTokenUsages() + require.Len(t, tokenUsages, 1) + assert.EqualValues(t, tc.expectedInputTokens, calculateTotalInputTokens(tokenUsages), "input tokens miscalculated") + assert.EqualValues(t, tc.expectedOutputTokens, calculateTotalOutputTokens(tokenUsages), "output tokens miscalculated") - require.Len(t, recorderClient.toolUsages, 1) - assert.Equal(t, "read_file", recorderClient.toolUsages[0].Tool) - require.IsType(t, map[string]any{}, recorderClient.toolUsages[0].Args) - require.Contains(t, recorderClient.toolUsages[0].Args, "path") - assert.Equal(t, "README.md", recorderClient.toolUsages[0].Args.(map[string]any)["path"]) + toolUsages := recorderClient.RecordedToolUsages() + require.Len(t, toolUsages, 1) + assert.Equal(t, "read_file", toolUsages[0].Tool) + require.IsType(t, map[string]any{}, toolUsages[0].Args) + require.Contains(t, toolUsages[0].Args, "path") + assert.Equal(t, "README.md", toolUsages[0].Args.(map[string]any)["path"]) - require.Len(t, recorderClient.userPrompts, 1) - assert.Equal(t, "how large is the README.md file in my current path", recorderClient.userPrompts[0].Prompt) + promptUsages := recorderClient.RecordedPromptUsages() + require.Len(t, promptUsages, 1) + assert.Equal(t, "how large is the README.md file in my current path", promptUsages[0].Prompt) recorderClient.verifyAllInterceptionsEnded(t) }) @@ -605,8 +612,9 @@ func TestSimple(t *testing.T) { resp.Body = io.NopCloser(bytes.NewReader(bodyBytes)) // Then: I expect the prompt to have been tracked. - require.NotEmpty(t, recorderClient.userPrompts, "no prompts tracked") - assert.Contains(t, recorderClient.userPrompts[0].Prompt, "how many angels can dance on the head of a pin") + promptUsages := recorderClient.RecordedPromptUsages() + require.NotEmpty(t, promptUsages, "no prompts tracked") + assert.Contains(t, promptUsages[0].Prompt, "how many angels can dance on the head of a pin") // Validate that responses have their IDs overridden with a interception ID rather than the original ID from the upstream provider. // The reason for this is that Bridge may make multiple upstream requests (i.e. to invoke injected tools), and clients will not be expecting @@ -615,8 +623,9 @@ func TestSimple(t *testing.T) { require.NoError(t, err, "failed to retrieve response ID") require.Nilf(t, uuid.Validate(id), "%s is not a valid UUID", id) - require.GreaterOrEqual(t, len(recorderClient.tokenUsages), 1) - require.Equal(t, recorderClient.tokenUsages[0].MsgID, tc.expectedMsgID) + tokenUsages := recorderClient.RecordedTokenUsages() + require.GreaterOrEqual(t, len(tokenUsages), 1) + require.Equal(t, tokenUsages[0].MsgID, tc.expectedMsgID) recorderClient.verifyAllInterceptionsEnded(t) }) @@ -770,11 +779,12 @@ func TestAnthropicInjectedTools(t *testing.T) { recorderClient, mcpCalls, _, resp := setupInjectedToolTest(t, antSingleInjectedTool, streaming, configureFn, createAnthropicMessagesReq) // Ensure expected tool was invoked with expected input. - require.Len(t, recorderClient.toolUsages, 1) - require.Equal(t, mockToolName, recorderClient.toolUsages[0].Tool) + toolUsages := recorderClient.RecordedToolUsages() + require.Len(t, toolUsages, 1) + require.Equal(t, mockToolName, toolUsages[0].Tool) expected, err := json.Marshal(map[string]any{"owner": "admin"}) require.NoError(t, err) - actual, err := json.Marshal(recorderClient.toolUsages[0].Args) + actual, err := json.Marshal(toolUsages[0].Args) require.NoError(t, err) require.EqualValues(t, expected, actual) invocations := mcpCalls.getCallsByTool(mockToolName) @@ -831,11 +841,13 @@ func TestAnthropicInjectedTools(t *testing.T) { assert.EqualValues(t, 204, message.Usage.OutputTokens) // Ensure tokens used during injected tool invocation are accounted for. - assert.EqualValues(t, 15308, calculateTotalInputTokens(recorderClient.tokenUsages)) - assert.EqualValues(t, 204, calculateTotalOutputTokens(recorderClient.tokenUsages)) + tokenUsages := recorderClient.RecordedTokenUsages() + assert.EqualValues(t, 15308, calculateTotalInputTokens(tokenUsages)) + assert.EqualValues(t, 204, calculateTotalOutputTokens(tokenUsages)) // Ensure we received exactly one prompt. - require.Len(t, recorderClient.userPrompts, 1) + promptUsages := recorderClient.RecordedPromptUsages() + require.Len(t, promptUsages, 1) }) } } @@ -857,11 +869,12 @@ func TestOpenAIInjectedTools(t *testing.T) { recorderClient, mcpCalls, _, resp := setupInjectedToolTest(t, oaiSingleInjectedTool, streaming, configureFn, createOpenAIChatCompletionsReq) // Ensure expected tool was invoked with expected input. - require.Len(t, recorderClient.toolUsages, 1) - require.Equal(t, mockToolName, recorderClient.toolUsages[0].Tool) + toolUsages := recorderClient.RecordedToolUsages() + require.Len(t, toolUsages, 1) + require.Equal(t, mockToolName, toolUsages[0].Tool) expected, err := json.Marshal(map[string]any{"owner": "admin"}) require.NoError(t, err) - actual, err := json.Marshal(recorderClient.toolUsages[0].Args) + actual, err := json.Marshal(toolUsages[0].Args) require.NoError(t, err) require.EqualValues(t, expected, actual) invocations := mcpCalls.getCallsByTool(mockToolName) @@ -933,11 +946,13 @@ func TestOpenAIInjectedTools(t *testing.T) { assert.EqualValues(t, 105, message.Usage.CompletionTokens) // Ensure tokens used during injected tool invocation are accounted for. - require.EqualValues(t, 5047, calculateTotalInputTokens(recorderClient.tokenUsages)) - require.EqualValues(t, 105, calculateTotalOutputTokens(recorderClient.tokenUsages)) + tokenUsages := recorderClient.RecordedTokenUsages() + require.EqualValues(t, 5047, calculateTotalInputTokens(tokenUsages)) + require.EqualValues(t, 105, calculateTotalOutputTokens(tokenUsages)) // Ensure we received exactly one prompt. - require.Len(t, recorderClient.userPrompts, 1) + promptUsages := recorderClient.RecordedPromptUsages() + require.Len(t, promptUsages, 1) }) } } @@ -1822,6 +1837,40 @@ func (m *mockRecorderClient) RecordToolUsage(ctx context.Context, req *aibridge. return nil } +// RecordedTokenUsages returns a copy of recorded token usages in a thread-safe manner. +// Note: This is a shallow clone - the slice is copied but the pointers reference the +// same underlying records. This is sufficient for our test assertions which only read +// the data and don't modify the records. +func (m *mockRecorderClient) RecordedTokenUsages() []*aibridge.TokenUsageRecord { + m.mu.Lock() + defer m.mu.Unlock() + return slices.Clone(m.tokenUsages) +} + +// RecordedPromptUsages returns a copy of recorded prompt usages in a thread-safe manner. +// Note: This is a shallow clone (see RecordedTokenUsages for details). +func (m *mockRecorderClient) RecordedPromptUsages() []*aibridge.PromptUsageRecord { + m.mu.Lock() + defer m.mu.Unlock() + return slices.Clone(m.userPrompts) +} + +// RecordedToolUsages returns a copy of recorded tool usages in a thread-safe manner. +// Note: This is a shallow clone (see RecordedTokenUsages for details). +func (m *mockRecorderClient) RecordedToolUsages() []*aibridge.ToolUsageRecord { + m.mu.Lock() + defer m.mu.Unlock() + return slices.Clone(m.toolUsages) +} + +// RecordedInterceptions returns a copy of recorded interceptions in a thread-safe manner. +// Note: This is a shallow clone (see RecordedTokenUsages for details). +func (m *mockRecorderClient) RecordedInterceptions() []*aibridge.InterceptionRecord { + m.mu.Lock() + defer m.mu.Unlock() + return slices.Clone(m.interceptions) +} + // verify all recorded interceptions has been marked as completed func (m *mockRecorderClient) verifyAllInterceptionsEnded(t *testing.T) { t.Helper()