Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
149 changes: 149 additions & 0 deletions intercept_anthropic_message_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,155 @@ import (
"github.com/stretchr/testify/require"
)

func TestInjectTools_CacheBreakpoints(t *testing.T) {
t.Parallel()

t.Run("cache control preserved when no tools to inject", func(t *testing.T) {
t.Parallel()

// Request has existing tool with cache control, but no tools to inject.
i := &AnthropicMessagesInterceptionBase{
req: &MessageNewParamsWrapper{
MessageNewParams: anthropic.MessageNewParams{
Tools: []anthropic.ToolUnionParam{
{
OfTool: &anthropic.ToolParam{
Name: "existing_tool",
CacheControl: anthropic.CacheControlEphemeralParam{
Type: constant.ValueOf[constant.Ephemeral](),
},
},
},
},
},
},
mcpProxy: &mockServerProxier{tools: nil},
}

i.injectTools()

// Cache control should remain untouched since no tools were injected.
require.Len(t, i.req.Tools, 1)
require.Equal(t, constant.ValueOf[constant.Ephemeral](), i.req.Tools[0].OfTool.CacheControl.Type)
})

t.Run("cache control breakpoint is preserved by prepending injected tools", func(t *testing.T) {
t.Parallel()

// Request has existing tool with cache control.
i := &AnthropicMessagesInterceptionBase{
req: &MessageNewParamsWrapper{
MessageNewParams: anthropic.MessageNewParams{
Tools: []anthropic.ToolUnionParam{
{
OfTool: &anthropic.ToolParam{
Name: "existing_tool",
CacheControl: anthropic.CacheControlEphemeralParam{
Type: constant.ValueOf[constant.Ephemeral](),
},
},
},
},
},
},
mcpProxy: &mockServerProxier{
tools: []*mcp.Tool{
{ID: "injected_tool", Name: "injected", Description: "Injected tool"},
},
},
}

i.injectTools()

require.Len(t, i.req.Tools, 2)
// Injected tools are prepended.
require.Equal(t, "injected_tool", i.req.Tools[0].OfTool.Name)
require.Zero(t, i.req.Tools[0].OfTool.CacheControl)
// Original tool's cache control should be preserved at the end.
require.Equal(t, "existing_tool", i.req.Tools[1].OfTool.Name)
require.Equal(t, constant.ValueOf[constant.Ephemeral](), i.req.Tools[1].OfTool.CacheControl.Type)
})

// The cache breakpoint SHOULD be on the final tool, but may not be; we must preserve that intention.
t.Run("cache control breakpoint in non-standard location is preserved", func(t *testing.T) {
t.Parallel()

// Request has multiple tools with cache control breakpoints.
i := &AnthropicMessagesInterceptionBase{
req: &MessageNewParamsWrapper{
MessageNewParams: anthropic.MessageNewParams{
Tools: []anthropic.ToolUnionParam{
{
OfTool: &anthropic.ToolParam{
Name: "tool_with_cache_1",
CacheControl: anthropic.CacheControlEphemeralParam{
Type: constant.ValueOf[constant.Ephemeral](),
},
},
},
{
OfTool: &anthropic.ToolParam{
Name: "tool_with_cache_2",
},
},
},
},
},
mcpProxy: &mockServerProxier{
tools: []*mcp.Tool{
{ID: "injected_tool", Name: "injected", Description: "Injected tool"},
},
},
}

i.injectTools()

require.Len(t, i.req.Tools, 3)
// Injected tool is prepended without cache control.
require.Equal(t, "injected_tool", i.req.Tools[0].OfTool.Name)
require.Zero(t, i.req.Tools[0].OfTool.CacheControl)
// Both original tools' cache controls should remain.
require.Equal(t, "tool_with_cache_1", i.req.Tools[1].OfTool.Name)
require.Equal(t, constant.ValueOf[constant.Ephemeral](), i.req.Tools[1].OfTool.CacheControl.Type)
require.Equal(t, "tool_with_cache_2", i.req.Tools[2].OfTool.Name)
require.Zero(t, i.req.Tools[2].OfTool.CacheControl)
})

t.Run("no cache control added when none originally set", func(t *testing.T) {
t.Parallel()

// Request has tools but none with cache control.
i := &AnthropicMessagesInterceptionBase{
req: &MessageNewParamsWrapper{
MessageNewParams: anthropic.MessageNewParams{
Tools: []anthropic.ToolUnionParam{
{
OfTool: &anthropic.ToolParam{
Name: "existing_tool_no_cache",
},
},
},
},
},
mcpProxy: &mockServerProxier{
tools: []*mcp.Tool{
{ID: "injected_tool", Name: "injected", Description: "Injected tool"},
},
},
}

i.injectTools()

require.Len(t, i.req.Tools, 2)
// Injected tool is prepended without cache control.
require.Equal(t, "injected_tool", i.req.Tools[0].OfTool.Name)
require.Zero(t, i.req.Tools[0].OfTool.CacheControl)
// Original tool remains at the end without cache control.
require.Equal(t, "existing_tool_no_cache", i.req.Tools[1].OfTool.Name)
require.Zero(t, i.req.Tools[1].OfTool.CacheControl)
})
}

func TestInjectTools_ParallelToolCalls(t *testing.T) {
t.Parallel()

Expand Down
8 changes: 7 additions & 1 deletion intercept_anthropic_messages_base.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,9 @@ func (i *AnthropicMessagesInterceptionBase) injectTools() {
}

// Inject tools.
var injectedTools []anthropic.ToolUnionParam
for _, tool := range tools {
i.req.Tools = append(i.req.Tools, anthropic.ToolUnionParam{
injectedTools = append(injectedTools, anthropic.ToolUnionParam{
OfTool: &anthropic.ToolParam{
InputSchema: anthropic.ToolInputSchemaParam{
Properties: tool.Params,
Expand All @@ -102,6 +103,11 @@ func (i *AnthropicMessagesInterceptionBase) injectTools() {
})
}

// Prepend the injected tools in order to maintain any configured cache breakpoints.
// The order of injected tools is expected to be stable, and therefore will not cause
// any cache invalidation when prepended.
i.req.Tools = append(injectedTools, i.req.Tools...)

// Note: Parallel tool calls are disabled to avoid tool_use/tool_result block mismatches.
// https://github.com/coder/aibridge/issues/2
toolChoiceType := i.req.ToolChoice.GetType()
Expand Down
2 changes: 1 addition & 1 deletion mcp/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ type ServerProxier interface {
// See https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#session-management.
Shutdown(ctx context.Context) error

// ListTools lists all known tools.
// ListTools lists all known tools. These MUST be sorted in a stable order.
ListTools() []*Tool
// GetTool returns a given tool, if known, or returns nil.
GetTool(id string) *Tool
Expand Down