Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 86 additions & 37 deletions bridge_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,21 +171,24 @@ func TestAnthropicMessages(t *testing.T) {
// One for message_start, one for message_delta.
expectedTokenRecordings = 2
}
require.Len(t, recorderClient.tokenUsages, expectedTokenRecordings)
tokenUsages := recorderClient.RecordedTokenUsages()
require.Len(t, tokenUsages, expectedTokenRecordings)

assert.EqualValues(t, tc.expectedInputTokens, calculateTotalInputTokens(recorderClient.tokenUsages), "input tokens miscalculated")
assert.EqualValues(t, tc.expectedOutputTokens, calculateTotalOutputTokens(recorderClient.tokenUsages), "output tokens miscalculated")
assert.EqualValues(t, tc.expectedInputTokens, calculateTotalInputTokens(tokenUsages), "input tokens miscalculated")
assert.EqualValues(t, tc.expectedOutputTokens, calculateTotalOutputTokens(tokenUsages), "output tokens miscalculated")

require.Len(t, recorderClient.toolUsages, 1)
assert.Equal(t, "Read", recorderClient.toolUsages[0].Tool)
require.IsType(t, json.RawMessage{}, recorderClient.toolUsages[0].Args)
toolUsages := recorderClient.RecordedToolUsages()
require.Len(t, toolUsages, 1)
assert.Equal(t, "Read", toolUsages[0].Tool)
require.IsType(t, json.RawMessage{}, toolUsages[0].Args)
var args map[string]any
require.NoError(t, json.Unmarshal(recorderClient.toolUsages[0].Args.(json.RawMessage), &args))
require.NoError(t, json.Unmarshal(toolUsages[0].Args.(json.RawMessage), &args))
require.Contains(t, args, "file_path")
assert.Equal(t, "/tmp/blah/foo", args["file_path"])

require.Len(t, recorderClient.userPrompts, 1)
assert.Equal(t, "read the foo file", recorderClient.userPrompts[0].Prompt)
promptUsages := recorderClient.RecordedPromptUsages()
require.Len(t, promptUsages, 1)
assert.Equal(t, "read the foo file", promptUsages[0].Prompt)

recorderClient.verifyAllInterceptionsEnded(t)
})
Expand Down Expand Up @@ -346,8 +349,9 @@ func TestAWSBedrockIntegration(t *testing.T) {
// and the interception data.
require.Equal(t, requestCount, 1)
require.Equal(t, bedrockCfg.Model, receivedModelName)
require.Len(t, recorderClient.interceptions, 1)
require.Equal(t, recorderClient.interceptions[0].Model, bedrockCfg.Model)
interceptions := recorderClient.RecordedInterceptions()
require.Len(t, interceptions, 1)
require.Equal(t, interceptions[0].Model, bedrockCfg.Model)
recorderClient.verifyAllInterceptionsEnded(t)
})
}
Expand Down Expand Up @@ -437,18 +441,21 @@ func TestOpenAIChatCompletions(t *testing.T) {
assert.Equal(t, "[DONE]", lastEvent.Data)
}

require.Len(t, recorderClient.tokenUsages, 1)
assert.EqualValues(t, tc.expectedInputTokens, calculateTotalInputTokens(recorderClient.tokenUsages), "input tokens miscalculated")
assert.EqualValues(t, tc.expectedOutputTokens, calculateTotalOutputTokens(recorderClient.tokenUsages), "output tokens miscalculated")
tokenUsages := recorderClient.RecordedTokenUsages()
require.Len(t, tokenUsages, 1)
assert.EqualValues(t, tc.expectedInputTokens, calculateTotalInputTokens(tokenUsages), "input tokens miscalculated")
assert.EqualValues(t, tc.expectedOutputTokens, calculateTotalOutputTokens(tokenUsages), "output tokens miscalculated")

require.Len(t, recorderClient.toolUsages, 1)
assert.Equal(t, "read_file", recorderClient.toolUsages[0].Tool)
require.IsType(t, map[string]any{}, recorderClient.toolUsages[0].Args)
require.Contains(t, recorderClient.toolUsages[0].Args, "path")
assert.Equal(t, "README.md", recorderClient.toolUsages[0].Args.(map[string]any)["path"])
toolUsages := recorderClient.RecordedToolUsages()
require.Len(t, toolUsages, 1)
assert.Equal(t, "read_file", toolUsages[0].Tool)
require.IsType(t, map[string]any{}, toolUsages[0].Args)
require.Contains(t, toolUsages[0].Args, "path")
assert.Equal(t, "README.md", toolUsages[0].Args.(map[string]any)["path"])

require.Len(t, recorderClient.userPrompts, 1)
assert.Equal(t, "how large is the README.md file in my current path", recorderClient.userPrompts[0].Prompt)
promptUsages := recorderClient.RecordedPromptUsages()
require.Len(t, promptUsages, 1)
assert.Equal(t, "how large is the README.md file in my current path", promptUsages[0].Prompt)

recorderClient.verifyAllInterceptionsEnded(t)
})
Expand Down Expand Up @@ -605,8 +612,9 @@ func TestSimple(t *testing.T) {
resp.Body = io.NopCloser(bytes.NewReader(bodyBytes))

// Then: I expect the prompt to have been tracked.
require.NotEmpty(t, recorderClient.userPrompts, "no prompts tracked")
assert.Contains(t, recorderClient.userPrompts[0].Prompt, "how many angels can dance on the head of a pin")
promptUsages := recorderClient.RecordedPromptUsages()
require.NotEmpty(t, promptUsages, "no prompts tracked")
assert.Contains(t, promptUsages[0].Prompt, "how many angels can dance on the head of a pin")

// Validate that responses have their IDs overridden with a interception ID rather than the original ID from the upstream provider.
// The reason for this is that Bridge may make multiple upstream requests (i.e. to invoke injected tools), and clients will not be expecting
Expand All @@ -615,8 +623,9 @@ func TestSimple(t *testing.T) {
require.NoError(t, err, "failed to retrieve response ID")
require.Nilf(t, uuid.Validate(id), "%s is not a valid UUID", id)

require.GreaterOrEqual(t, len(recorderClient.tokenUsages), 1)
require.Equal(t, recorderClient.tokenUsages[0].MsgID, tc.expectedMsgID)
tokenUsages := recorderClient.RecordedTokenUsages()
require.GreaterOrEqual(t, len(tokenUsages), 1)
require.Equal(t, tokenUsages[0].MsgID, tc.expectedMsgID)

recorderClient.verifyAllInterceptionsEnded(t)
})
Expand Down Expand Up @@ -770,11 +779,12 @@ func TestAnthropicInjectedTools(t *testing.T) {
recorderClient, mcpCalls, _, resp := setupInjectedToolTest(t, antSingleInjectedTool, streaming, configureFn, createAnthropicMessagesReq)

// Ensure expected tool was invoked with expected input.
require.Len(t, recorderClient.toolUsages, 1)
require.Equal(t, mockToolName, recorderClient.toolUsages[0].Tool)
toolUsages := recorderClient.RecordedToolUsages()
require.Len(t, toolUsages, 1)
require.Equal(t, mockToolName, toolUsages[0].Tool)
expected, err := json.Marshal(map[string]any{"owner": "admin"})
require.NoError(t, err)
actual, err := json.Marshal(recorderClient.toolUsages[0].Args)
actual, err := json.Marshal(toolUsages[0].Args)
require.NoError(t, err)
require.EqualValues(t, expected, actual)
invocations := mcpCalls.getCallsByTool(mockToolName)
Expand Down Expand Up @@ -831,11 +841,13 @@ func TestAnthropicInjectedTools(t *testing.T) {
assert.EqualValues(t, 204, message.Usage.OutputTokens)

// Ensure tokens used during injected tool invocation are accounted for.
assert.EqualValues(t, 15308, calculateTotalInputTokens(recorderClient.tokenUsages))
assert.EqualValues(t, 204, calculateTotalOutputTokens(recorderClient.tokenUsages))
tokenUsages := recorderClient.RecordedTokenUsages()
assert.EqualValues(t, 15308, calculateTotalInputTokens(tokenUsages))
assert.EqualValues(t, 204, calculateTotalOutputTokens(tokenUsages))

// Ensure we received exactly one prompt.
require.Len(t, recorderClient.userPrompts, 1)
promptUsages := recorderClient.RecordedPromptUsages()
require.Len(t, promptUsages, 1)
})
}
}
Expand All @@ -857,11 +869,12 @@ func TestOpenAIInjectedTools(t *testing.T) {
recorderClient, mcpCalls, _, resp := setupInjectedToolTest(t, oaiSingleInjectedTool, streaming, configureFn, createOpenAIChatCompletionsReq)

// Ensure expected tool was invoked with expected input.
require.Len(t, recorderClient.toolUsages, 1)
require.Equal(t, mockToolName, recorderClient.toolUsages[0].Tool)
toolUsages := recorderClient.RecordedToolUsages()
require.Len(t, toolUsages, 1)
require.Equal(t, mockToolName, toolUsages[0].Tool)
expected, err := json.Marshal(map[string]any{"owner": "admin"})
require.NoError(t, err)
actual, err := json.Marshal(recorderClient.toolUsages[0].Args)
actual, err := json.Marshal(toolUsages[0].Args)
require.NoError(t, err)
require.EqualValues(t, expected, actual)
invocations := mcpCalls.getCallsByTool(mockToolName)
Expand Down Expand Up @@ -933,11 +946,13 @@ func TestOpenAIInjectedTools(t *testing.T) {
assert.EqualValues(t, 105, message.Usage.CompletionTokens)

// Ensure tokens used during injected tool invocation are accounted for.
require.EqualValues(t, 5047, calculateTotalInputTokens(recorderClient.tokenUsages))
require.EqualValues(t, 105, calculateTotalOutputTokens(recorderClient.tokenUsages))
tokenUsages := recorderClient.RecordedTokenUsages()
require.EqualValues(t, 5047, calculateTotalInputTokens(tokenUsages))
require.EqualValues(t, 105, calculateTotalOutputTokens(tokenUsages))

// Ensure we received exactly one prompt.
require.Len(t, recorderClient.userPrompts, 1)
promptUsages := recorderClient.RecordedPromptUsages()
require.Len(t, promptUsages, 1)
})
}
}
Expand Down Expand Up @@ -1822,6 +1837,40 @@ func (m *mockRecorderClient) RecordToolUsage(ctx context.Context, req *aibridge.
return nil
}

// RecordedTokenUsages returns a copy of recorded token usages in a thread-safe manner.
// Note: This is a shallow clone - the slice is copied but the pointers reference the
// same underlying records. This is sufficient for our test assertions which only read
// the data and don't modify the records.
func (m *mockRecorderClient) RecordedTokenUsages() []*aibridge.TokenUsageRecord {
m.mu.Lock()
defer m.mu.Unlock()
return slices.Clone(m.tokenUsages)
}

// RecordedPromptUsages returns a copy of recorded prompt usages in a thread-safe manner.
// Note: This is a shallow clone (see RecordedTokenUsages for details).
func (m *mockRecorderClient) RecordedPromptUsages() []*aibridge.PromptUsageRecord {
m.mu.Lock()
defer m.mu.Unlock()
return slices.Clone(m.userPrompts)
}

// RecordedToolUsages returns a copy of recorded tool usages in a thread-safe manner.
// Note: This is a shallow clone (see RecordedTokenUsages for details).
func (m *mockRecorderClient) RecordedToolUsages() []*aibridge.ToolUsageRecord {
m.mu.Lock()
defer m.mu.Unlock()
return slices.Clone(m.toolUsages)
}

// RecordedInterceptions returns a copy of recorded interceptions in a thread-safe manner.
// Note: This is a shallow clone (see RecordedTokenUsages for details).
func (m *mockRecorderClient) RecordedInterceptions() []*aibridge.InterceptionRecord {
m.mu.Lock()
defer m.mu.Unlock()
return slices.Clone(m.interceptions)
}

// verify all recorded interceptions has been marked as completed
func (m *mockRecorderClient) verifyAllInterceptionsEnded(t *testing.T) {
t.Helper()
Expand Down