Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions api.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,12 @@ type TokenUsageRecord struct {
InterceptionID string
MsgID string
Input, Output int64
Metadata Metadata
CreatedAt time.Time
// ExtraTokenTypes holds token types which *may* exist over and above input/output.
// These should ultimately get merged into [Metadata], but it's useful to keep these
// with their actual type (int64) since [Metadata] is a map[string]any.
ExtraTokenTypes map[string]int64
Metadata Metadata
CreatedAt time.Time
}

type PromptUsageRecord struct {
Expand Down
8 changes: 4 additions & 4 deletions bridge.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,20 +47,20 @@ var _ http.Handler = &RequestBridge{}
// A [Recorder] is also required to record prompt, tool, and token use.
//
// mcpProxy will be closed when the [RequestBridge] is closed.
func NewRequestBridge(ctx context.Context, providers []Provider, logger slog.Logger, recorder Recorder, mcpProxy mcp.ServerProxier) (*RequestBridge, error) {
func NewRequestBridge(ctx context.Context, providers []Provider, recorder Recorder, mcpProxy mcp.ServerProxier, metrics *Metrics, logger slog.Logger) (*RequestBridge, error) {
mux := http.NewServeMux()

for _, provider := range providers {
// Add the known provider-specific routes which are bridged (i.e. intercepted and augmented).
for _, path := range provider.BridgedRoutes() {
mux.HandleFunc(path, newInterceptionProcessor(provider, logger, recorder, mcpProxy))
mux.HandleFunc(path, newInterceptionProcessor(provider, logger, recorder, mcpProxy, metrics))
}

// Any requests which passthrough to this will be reverse-proxied to the upstream.
//
// We have to whitelist the known-safe routes because an API key with elevant privileges (i.e. admin) might be
// We have to whitelist the known-safe routes because an API key with elevated privileges (i.e. admin) might be
// configured, so we should just reverse-proxy known-safe routes.
ftr := newPassthroughRouter(provider, logger.Named(fmt.Sprintf("passthrough.%s", provider.Name())))
ftr := newPassthroughRouter(provider, logger.Named(fmt.Sprintf("passthrough.%s", provider.Name())), metrics)
for _, path := range provider.PassthroughRoutes() {
prefix := fmt.Sprintf("/%s", provider.Name())
route := fmt.Sprintf("%s%s", prefix, path)
Expand Down
36 changes: 18 additions & 18 deletions bridge_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ func TestAnthropicMessages(t *testing.T) {
recorderClient := &mockRecorderClient{}

logger := slogtest.Make(t, &slogtest.Options{}).Leveled(slog.LevelDebug)
b, err := aibridge.NewRequestBridge(ctx, []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(srv.URL, apiKey), nil)}, logger, recorderClient, mcp.NewServerProxyManager(nil))
b, err := aibridge.NewRequestBridge(ctx, []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(srv.URL, apiKey), nil)}, recorderClient, mcp.NewServerProxyManager(nil), nil, logger)
require.NoError(t, err)

mockSrv := httptest.NewUnstartedServer(b)
Expand Down Expand Up @@ -214,7 +214,7 @@ func TestAWSBedrockIntegration(t *testing.T) {
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
b, err := aibridge.NewRequestBridge(ctx, []aibridge.Provider{
aibridge.NewAnthropicProvider(anthropicCfg("http://unused", apiKey), bedrockCfg),
}, logger, recorderClient, mcp.NewServerProxyManager(nil))
}, recorderClient, mcp.NewServerProxyManager(nil), nil, logger)
require.NoError(t, err)

mockSrv := httptest.NewUnstartedServer(b)
Expand Down Expand Up @@ -312,7 +312,7 @@ func TestAWSBedrockIntegration(t *testing.T) {
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
b, err := aibridge.NewRequestBridge(
ctx, []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(srv.URL, apiKey), bedrockCfg)},
logger, recorderClient, mcp.NewServerProxyManager(nil))
recorderClient, mcp.NewServerProxyManager(nil), nil, logger)
require.NoError(t, err)

mockBridgeSrv := httptest.NewUnstartedServer(b)
Expand Down Expand Up @@ -399,7 +399,7 @@ func TestOpenAIChatCompletions(t *testing.T) {
recorderClient := &mockRecorderClient{}

logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
b, err := aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(srv.URL, apiKey))}, logger, recorderClient, mcp.NewServerProxyManager(nil))
b, err := aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(srv.URL, apiKey))}, recorderClient, mcp.NewServerProxyManager(nil), nil, logger)
require.NoError(t, err)

mockSrv := httptest.NewUnstartedServer(b)
Expand Down Expand Up @@ -466,7 +466,7 @@ func TestSimple(t *testing.T) {
fixture: antSimple,
configureFunc: func(addr string, client aibridge.Recorder) (*aibridge.RequestBridge, error) {
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)}, logger, client, mcp.NewServerProxyManager(nil))
return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)}, client, mcp.NewServerProxyManager(nil), nil, logger)
},
getResponseIDFunc: func(streaming bool, resp *http.Response) (string, error) {
if streaming {
Expand Down Expand Up @@ -504,7 +504,7 @@ func TestSimple(t *testing.T) {
fixture: oaiSimple,
configureFunc: func(addr string, client aibridge.Recorder) (*aibridge.RequestBridge, error) {
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))}, logger, client, mcp.NewServerProxyManager(nil))
return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))}, client, mcp.NewServerProxyManager(nil), nil, logger)
},
getResponseIDFunc: func(streaming bool, resp *http.Response) (string, error) {
if streaming {
Expand Down Expand Up @@ -645,7 +645,7 @@ func TestFallthrough(t *testing.T) {
configureFunc: func(addr string, client aibridge.Recorder) (aibridge.Provider, *aibridge.RequestBridge) {
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
provider := aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)
bridge, err := aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{provider}, logger, client, mcp.NewServerProxyManager(nil))
bridge, err := aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{provider}, client, mcp.NewServerProxyManager(nil), nil, logger)
require.NoError(t, err)
return provider, bridge
},
Expand All @@ -656,7 +656,7 @@ func TestFallthrough(t *testing.T) {
configureFunc: func(addr string, client aibridge.Recorder) (aibridge.Provider, *aibridge.RequestBridge) {
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
provider := aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))
bridge, err := aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{provider}, logger, client, mcp.NewServerProxyManager(nil))
bridge, err := aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{provider}, client, mcp.NewServerProxyManager(nil), nil, logger)
require.NoError(t, err)
return provider, bridge
},
Expand Down Expand Up @@ -762,7 +762,7 @@ func TestAnthropicInjectedTools(t *testing.T) {

configureFn := func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) {
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)}, logger, client, srvProxyMgr)
return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)}, client, srvProxyMgr, nil, logger)
}

// Build the requirements & make the assertions which are common to all providers.
Expand Down Expand Up @@ -843,7 +843,7 @@ func TestOpenAIInjectedTools(t *testing.T) {

configureFn := func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) {
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))}, logger, client, srvProxyMgr)
return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))}, client, srvProxyMgr, nil, logger)
}

// Build the requirements & make the assertions which are common to all providers.
Expand Down Expand Up @@ -1029,7 +1029,7 @@ func TestErrorHandling(t *testing.T) {
createRequestFunc: createAnthropicMessagesReq,
configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) {
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)}, logger, client, srvProxyMgr)
return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)}, client, srvProxyMgr, nil, logger)
},
responseHandlerFn: func(resp *http.Response) {
require.Equal(t, http.StatusBadRequest, resp.StatusCode)
Expand All @@ -1046,7 +1046,7 @@ func TestErrorHandling(t *testing.T) {
createRequestFunc: createOpenAIChatCompletionsReq,
configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) {
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))}, logger, client, srvProxyMgr)
return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))}, client, srvProxyMgr, nil, logger)
},
responseHandlerFn: func(resp *http.Response) {
require.Equal(t, http.StatusBadRequest, resp.StatusCode)
Expand Down Expand Up @@ -1134,7 +1134,7 @@ func TestErrorHandling(t *testing.T) {
createRequestFunc: createAnthropicMessagesReq,
configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) {
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)}, logger, client, srvProxyMgr)
return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)}, client, srvProxyMgr, nil, logger)
},
responseHandlerFn: func(resp *http.Response) {
// Server responds first with 200 OK then starts streaming.
Expand All @@ -1152,7 +1152,7 @@ func TestErrorHandling(t *testing.T) {
createRequestFunc: createOpenAIChatCompletionsReq,
configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) {
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))}, logger, client, srvProxyMgr)
return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))}, client, srvProxyMgr, nil, logger)
},
responseHandlerFn: func(resp *http.Response) {
// Server responds first with 200 OK then starts streaming.
Expand Down Expand Up @@ -1238,15 +1238,15 @@ func TestStableRequestEncoding(t *testing.T) {
fixture: antSimple,
createRequestFunc: createAnthropicMessagesReq,
configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) {
return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)}, logger, client, srvProxyMgr)
return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)}, client, srvProxyMgr, nil, logger)
},
},
{
name: aibridge.ProviderOpenAI,
fixture: oaiSimple,
createRequestFunc: createOpenAIChatCompletionsReq,
configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) {
return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))}, logger, client, srvProxyMgr)
return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))}, client, srvProxyMgr, nil, logger)
},
},
}
Expand Down Expand Up @@ -1352,7 +1352,7 @@ func TestEnvironmentDoNotLeak(t *testing.T) {
fixture: antSimple,
configureFunc: func(addr string, client aibridge.Recorder) (*aibridge.RequestBridge, error) {
logger := slogtest.Make(t, &slogtest.Options{}).Leveled(slog.LevelDebug)
return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)}, logger, client, mcp.NewServerProxyManager(nil))
return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)}, client, mcp.NewServerProxyManager(nil), nil, logger)
},
createRequest: createAnthropicMessagesReq,
envVars: map[string]string{
Expand All @@ -1365,7 +1365,7 @@ func TestEnvironmentDoNotLeak(t *testing.T) {
fixture: oaiSimple,
configureFunc: func(addr string, client aibridge.Recorder) (*aibridge.RequestBridge, error) {
logger := slogtest.Make(t, &slogtest.Options{}).Leveled(slog.LevelDebug)
return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))}, logger, client, mcp.NewServerProxyManager(nil))
return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))}, client, mcp.NewServerProxyManager(nil), nil, logger)
},
createRequest: createOpenAIChatCompletionsReq,
envVars: map[string]string{
Expand Down
14 changes: 11 additions & 3 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ require (
github.com/google/uuid v1.6.0
github.com/hashicorp/go-multierror v1.1.1
github.com/mark3labs/mcp-go v0.38.0
github.com/stretchr/testify v1.10.0
github.com/prometheus/client_golang v1.23.2
github.com/stretchr/testify v1.11.1
github.com/tidwall/gjson v1.18.0
github.com/tidwall/sjson v1.2.5
go.uber.org/goleak v1.3.0
Expand Down Expand Up @@ -40,18 +41,25 @@ require (
github.com/aws/smithy-go v1.20.3 // indirect
github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect
github.com/bahlo/generic-list-go v0.2.0 // indirect
github.com/beorn7/perks v1.0.1 // indirect
github.com/buger/jsonparser v1.1.1 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/charmbracelet/lipgloss v0.7.1 // indirect
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
github.com/hashicorp/errwrap v1.0.0 // indirect
github.com/invopop/jsonschema v0.13.0 // indirect
github.com/kylelemons/godebug v1.1.0 // indirect
github.com/lucasb-eyer/go-colorful v1.2.0 // indirect
github.com/mailru/easyjson v0.7.7 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/mattn/go-runewidth v0.0.15 // indirect
github.com/muesli/reflow v0.3.0 // indirect
github.com/muesli/termenv v0.15.2 // indirect
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
github.com/prometheus/client_model v0.6.2 // indirect
github.com/prometheus/common v0.66.1 // indirect
github.com/prometheus/procfs v0.16.1 // indirect
github.com/rivo/uniseg v0.4.4 // indirect
github.com/rogpeppe/go-internal v1.13.1 // indirect
github.com/spf13/cast v1.7.1 // indirect
Expand All @@ -61,11 +69,11 @@ require (
github.com/yosida95/uritemplate/v3 v3.0.2 // indirect
go.opentelemetry.io/otel v1.33.0 // indirect
go.opentelemetry.io/otel/trace v1.33.0 // indirect
go.yaml.in/yaml/v2 v2.4.2 // indirect
golang.org/x/sys v0.35.0 // indirect
golang.org/x/term v0.34.0 // indirect
golang.org/x/text v0.28.0 // indirect
golang.org/x/xerrors v0.0.0-20240716161551-93cc26a95ae9 // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20240318140521-94a12d6c2237 // indirect
google.golang.org/protobuf v1.36.3 // indirect
google.golang.org/protobuf v1.36.8 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)
Loading