diff --git a/api.go b/api.go index 5c152e9..d628953 100644 --- a/api.go +++ b/api.go @@ -25,8 +25,12 @@ type TokenUsageRecord struct { InterceptionID string MsgID string Input, Output int64 - Metadata Metadata - CreatedAt time.Time + // ExtraTokenTypes holds token types which *may* exist over and above input/output. + // These should ultimately get merged into [Metadata], but it's useful to keep these + // with their actual type (int64) since [Metadata] is a map[string]any. + ExtraTokenTypes map[string]int64 + Metadata Metadata + CreatedAt time.Time } type PromptUsageRecord struct { diff --git a/bridge.go b/bridge.go index 4f23127..4f5428d 100644 --- a/bridge.go +++ b/bridge.go @@ -47,20 +47,20 @@ var _ http.Handler = &RequestBridge{} // A [Recorder] is also required to record prompt, tool, and token use. // // mcpProxy will be closed when the [RequestBridge] is closed. -func NewRequestBridge(ctx context.Context, providers []Provider, logger slog.Logger, recorder Recorder, mcpProxy mcp.ServerProxier) (*RequestBridge, error) { +func NewRequestBridge(ctx context.Context, providers []Provider, recorder Recorder, mcpProxy mcp.ServerProxier, metrics *Metrics, logger slog.Logger) (*RequestBridge, error) { mux := http.NewServeMux() for _, provider := range providers { // Add the known provider-specific routes which are bridged (i.e. intercepted and augmented). for _, path := range provider.BridgedRoutes() { - mux.HandleFunc(path, newInterceptionProcessor(provider, logger, recorder, mcpProxy)) + mux.HandleFunc(path, newInterceptionProcessor(provider, logger, recorder, mcpProxy, metrics)) } // Any requests which passthrough to this will be reverse-proxied to the upstream. // - // We have to whitelist the known-safe routes because an API key with elevant privileges (i.e. admin) might be + // We have to whitelist the known-safe routes because an API key with elevated privileges (i.e. admin) might be // configured, so we should just reverse-proxy known-safe routes. 
- ftr := newPassthroughRouter(provider, logger.Named(fmt.Sprintf("passthrough.%s", provider.Name()))) + ftr := newPassthroughRouter(provider, logger.Named(fmt.Sprintf("passthrough.%s", provider.Name())), metrics) for _, path := range provider.PassthroughRoutes() { prefix := fmt.Sprintf("/%s", provider.Name()) route := fmt.Sprintf("%s%s", prefix, path) diff --git a/bridge_integration_test.go b/bridge_integration_test.go index 36f2a3e..f8ef57e 100644 --- a/bridge_integration_test.go +++ b/bridge_integration_test.go @@ -133,7 +133,7 @@ func TestAnthropicMessages(t *testing.T) { recorderClient := &mockRecorderClient{} logger := slogtest.Make(t, &slogtest.Options{}).Leveled(slog.LevelDebug) - b, err := aibridge.NewRequestBridge(ctx, []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(srv.URL, apiKey), nil)}, logger, recorderClient, mcp.NewServerProxyManager(nil)) + b, err := aibridge.NewRequestBridge(ctx, []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(srv.URL, apiKey), nil)}, recorderClient, mcp.NewServerProxyManager(nil), nil, logger) require.NoError(t, err) mockSrv := httptest.NewUnstartedServer(b) @@ -214,7 +214,7 @@ func TestAWSBedrockIntegration(t *testing.T) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug) b, err := aibridge.NewRequestBridge(ctx, []aibridge.Provider{ aibridge.NewAnthropicProvider(anthropicCfg("http://unused", apiKey), bedrockCfg), - }, logger, recorderClient, mcp.NewServerProxyManager(nil)) + }, recorderClient, mcp.NewServerProxyManager(nil), nil, logger) require.NoError(t, err) mockSrv := httptest.NewUnstartedServer(b) @@ -312,7 +312,7 @@ func TestAWSBedrockIntegration(t *testing.T) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug) b, err := aibridge.NewRequestBridge( ctx, []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(srv.URL, apiKey), bedrockCfg)}, - logger, recorderClient, mcp.NewServerProxyManager(nil)) + 
recorderClient, mcp.NewServerProxyManager(nil), nil, logger) require.NoError(t, err) mockBridgeSrv := httptest.NewUnstartedServer(b) @@ -399,7 +399,7 @@ func TestOpenAIChatCompletions(t *testing.T) { recorderClient := &mockRecorderClient{} logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - b, err := aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(srv.URL, apiKey))}, logger, recorderClient, mcp.NewServerProxyManager(nil)) + b, err := aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(srv.URL, apiKey))}, recorderClient, mcp.NewServerProxyManager(nil), nil, logger) require.NoError(t, err) mockSrv := httptest.NewUnstartedServer(b) @@ -466,7 +466,7 @@ func TestSimple(t *testing.T) { fixture: antSimple, configureFunc: func(addr string, client aibridge.Recorder) (*aibridge.RequestBridge, error) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)}, logger, client, mcp.NewServerProxyManager(nil)) + return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)}, client, mcp.NewServerProxyManager(nil), nil, logger) }, getResponseIDFunc: func(streaming bool, resp *http.Response) (string, error) { if streaming { @@ -504,7 +504,7 @@ func TestSimple(t *testing.T) { fixture: oaiSimple, configureFunc: func(addr string, client aibridge.Recorder) (*aibridge.RequestBridge, error) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))}, logger, client, mcp.NewServerProxyManager(nil)) + return aibridge.NewRequestBridge(t.Context(), 
[]aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))}, client, mcp.NewServerProxyManager(nil), nil, logger) }, getResponseIDFunc: func(streaming bool, resp *http.Response) (string, error) { if streaming { @@ -645,7 +645,7 @@ func TestFallthrough(t *testing.T) { configureFunc: func(addr string, client aibridge.Recorder) (aibridge.Provider, *aibridge.RequestBridge) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) provider := aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil) - bridge, err := aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{provider}, logger, client, mcp.NewServerProxyManager(nil)) + bridge, err := aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{provider}, client, mcp.NewServerProxyManager(nil), nil, logger) require.NoError(t, err) return provider, bridge }, @@ -656,7 +656,7 @@ func TestFallthrough(t *testing.T) { configureFunc: func(addr string, client aibridge.Recorder) (aibridge.Provider, *aibridge.RequestBridge) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) provider := aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey)) - bridge, err := aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{provider}, logger, client, mcp.NewServerProxyManager(nil)) + bridge, err := aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{provider}, client, mcp.NewServerProxyManager(nil), nil, logger) require.NoError(t, err) return provider, bridge }, @@ -762,7 +762,7 @@ func TestAnthropicInjectedTools(t *testing.T) { configureFn := func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)}, logger, client, srvProxyMgr) + return 
aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)}, client, srvProxyMgr, nil, logger) } // Build the requirements & make the assertions which are common to all providers. @@ -843,7 +843,7 @@ func TestOpenAIInjectedTools(t *testing.T) { configureFn := func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))}, logger, client, srvProxyMgr) + return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))}, client, srvProxyMgr, nil, logger) } // Build the requirements & make the assertions which are common to all providers. @@ -1029,7 +1029,7 @@ func TestErrorHandling(t *testing.T) { createRequestFunc: createAnthropicMessagesReq, configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)}, logger, client, srvProxyMgr) + return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)}, client, srvProxyMgr, nil, logger) }, responseHandlerFn: func(resp *http.Response) { require.Equal(t, http.StatusBadRequest, resp.StatusCode) @@ -1046,7 +1046,7 @@ func TestErrorHandling(t *testing.T) { createRequestFunc: createOpenAIChatCompletionsReq, configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: 
false}).Leveled(slog.LevelDebug) - return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))}, logger, client, srvProxyMgr) + return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))}, client, srvProxyMgr, nil, logger) }, responseHandlerFn: func(resp *http.Response) { require.Equal(t, http.StatusBadRequest, resp.StatusCode) @@ -1134,7 +1134,7 @@ func TestErrorHandling(t *testing.T) { createRequestFunc: createAnthropicMessagesReq, configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)}, logger, client, srvProxyMgr) + return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)}, client, srvProxyMgr, nil, logger) }, responseHandlerFn: func(resp *http.Response) { // Server responds first with 200 OK then starts streaming. @@ -1152,7 +1152,7 @@ func TestErrorHandling(t *testing.T) { createRequestFunc: createOpenAIChatCompletionsReq, configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))}, logger, client, srvProxyMgr) + return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))}, client, srvProxyMgr, nil, logger) }, responseHandlerFn: func(resp *http.Response) { // Server responds first with 200 OK then starts streaming. 
@@ -1238,7 +1238,7 @@ func TestStableRequestEncoding(t *testing.T) { fixture: antSimple, createRequestFunc: createAnthropicMessagesReq, configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) { - return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)}, logger, client, srvProxyMgr) + return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)}, client, srvProxyMgr, nil, logger) }, }, { @@ -1246,7 +1246,7 @@ func TestStableRequestEncoding(t *testing.T) { fixture: oaiSimple, createRequestFunc: createOpenAIChatCompletionsReq, configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) { - return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))}, logger, client, srvProxyMgr) + return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))}, client, srvProxyMgr, nil, logger) }, }, } @@ -1352,7 +1352,7 @@ func TestEnvironmentDoNotLeak(t *testing.T) { fixture: antSimple, configureFunc: func(addr string, client aibridge.Recorder) (*aibridge.RequestBridge, error) { logger := slogtest.Make(t, &slogtest.Options{}).Leveled(slog.LevelDebug) - return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)}, logger, client, mcp.NewServerProxyManager(nil)) + return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)}, client, mcp.NewServerProxyManager(nil), nil, logger) }, createRequest: createAnthropicMessagesReq, envVars: map[string]string{ @@ -1365,7 +1365,7 @@ func TestEnvironmentDoNotLeak(t *testing.T) { fixture: oaiSimple, configureFunc: func(addr 
string, client aibridge.Recorder) (*aibridge.RequestBridge, error) { logger := slogtest.Make(t, &slogtest.Options{}).Leveled(slog.LevelDebug) - return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))}, logger, client, mcp.NewServerProxyManager(nil)) + return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))}, client, mcp.NewServerProxyManager(nil), nil, logger) }, createRequest: createOpenAIChatCompletionsReq, envVars: map[string]string{ diff --git a/go.mod b/go.mod index 827a224..47fd45d 100644 --- a/go.mod +++ b/go.mod @@ -8,7 +8,8 @@ require ( github.com/google/uuid v1.6.0 github.com/hashicorp/go-multierror v1.1.1 github.com/mark3labs/mcp-go v0.38.0 - github.com/stretchr/testify v1.10.0 + github.com/prometheus/client_golang v1.23.2 + github.com/stretchr/testify v1.11.1 github.com/tidwall/gjson v1.18.0 github.com/tidwall/sjson v1.2.5 go.uber.org/goleak v1.3.0 @@ -40,18 +41,25 @@ require ( github.com/aws/smithy-go v1.20.3 // indirect github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect github.com/bahlo/generic-list-go v0.2.0 // indirect + github.com/beorn7/perks v1.0.1 // indirect github.com/buger/jsonparser v1.1.1 // indirect + github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/charmbracelet/lipgloss v0.7.1 // indirect github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/hashicorp/errwrap v1.0.0 // indirect github.com/invopop/jsonschema v0.13.0 // indirect + github.com/kylelemons/godebug v1.1.0 // indirect github.com/lucasb-eyer/go-colorful v1.2.0 // indirect github.com/mailru/easyjson v0.7.7 // indirect github.com/mattn/go-isatty v0.0.20 // indirect github.com/mattn/go-runewidth v0.0.15 // indirect github.com/muesli/reflow v0.3.0 // indirect github.com/muesli/termenv v0.15.2 // indirect + github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect github.com/pmezard/go-difflib 
v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect + github.com/prometheus/client_model v0.6.2 // indirect + github.com/prometheus/common v0.66.1 // indirect + github.com/prometheus/procfs v0.16.1 // indirect github.com/rivo/uniseg v0.4.4 // indirect github.com/rogpeppe/go-internal v1.13.1 // indirect github.com/spf13/cast v1.7.1 // indirect @@ -61,11 +69,11 @@ require ( github.com/yosida95/uritemplate/v3 v3.0.2 // indirect go.opentelemetry.io/otel v1.33.0 // indirect go.opentelemetry.io/otel/trace v1.33.0 // indirect + go.yaml.in/yaml/v2 v2.4.2 // indirect golang.org/x/sys v0.35.0 // indirect golang.org/x/term v0.34.0 // indirect - golang.org/x/text v0.28.0 // indirect golang.org/x/xerrors v0.0.0-20240716161551-93cc26a95ae9 // indirect google.golang.org/genproto/googleapis/api v0.0.0-20240318140521-94a12d6c2237 // indirect - google.golang.org/protobuf v1.36.3 // indirect + google.golang.org/protobuf v1.36.8 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 9ee76b3..d0b79c8 100644 --- a/go.sum +++ b/go.sum @@ -41,8 +41,12 @@ github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiE github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8= github.com/bahlo/generic-list-go v0.2.0 h1:5sz/EEAK+ls5wF+NeqDpk5+iNdMDXrh3z3nPnH1Wvgk= github.com/bahlo/generic-list-go v0.2.0/go.mod h1:2KvAjgMlE5NNynlg/5iLrrCCZ2+5xWbdbCW3pNTGyYg= +github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= +github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= github.com/buger/jsonparser v1.1.1 h1:2PnMjfWD7wBILjqQbt530v576A/cAbQvEW9gGIpYMUs= github.com/buger/jsonparser v1.1.1/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx27UK13J/0= +github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= +github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/charmbracelet/lipgloss v0.7.1 
h1:17WMwi7N1b1rVWOjMT+rCh7sQkvDU75B2hbZpc5Kc1E= github.com/charmbracelet/lipgloss v0.7.1/go.mod h1:yG0k3giv8Qj8edTCbbg6AlQ5e8KNWpFujkNawKNhE2c= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= @@ -53,8 +57,8 @@ github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY= github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= -github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= -github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= +github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/hashicorp/errwrap v1.0.0 h1:hLrqtEDnRye3+sgx6z4qVLNuviH3MR5aQ0ykNJa/UYA= @@ -68,6 +72,8 @@ github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= +github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY= github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0= github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0= @@ -83,10 +89,20 @@ 
github.com/muesli/reflow v0.3.0 h1:IFsN6K9NfGtjeggFP+68I4chLZV2yIKsXJFNZ+eWh6s= github.com/muesli/reflow v0.3.0/go.mod h1:pbwTDkVPibjO2kyvBQRBxTWEEGDGq0FlB1BIKtnHY/8= github.com/muesli/termenv v0.15.2 h1:GohcuySI0QmI3wN8Ok9PtKGkgkFIk7y6Vpb5PvrY+Wo= github.com/muesli/termenv v0.15.2/go.mod h1:Epx+iuz8sNs7mNKhxzH4fWXGNpZwUaJKRS1noLXviQ8= +github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= +github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= github.com/openai/openai-go/v2 v2.7.0 h1:/8MSFCXcasin7AyuWQ2au6FraXL71gzAs+VfbMv+J3k= github.com/openai/openai-go/v2 v2.7.0/go.mod h1:jrJs23apqJKKbT+pqtFgNKpRju/KP9zpUTZhz3GElQE= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/prometheus/client_golang v1.23.2 h1:Je96obch5RDVy3FDMndoUsjAhG5Edi49h0RJWRi/o0o= +github.com/prometheus/client_golang v1.23.2/go.mod h1:Tb1a6LWHB3/SPIzCoaDXI4I8UHKeFTEQ1YCr+0Gyqmg= +github.com/prometheus/client_model v0.6.2 h1:oBsgwpGs7iVziMvrGhE53c/GrLUsZdHnqNwqPLxwZyk= +github.com/prometheus/client_model v0.6.2/go.mod h1:y3m2F6Gdpfy6Ut/GBsUqTWZqCUvMVzSfMLjcu6wAwpE= +github.com/prometheus/common v0.66.1 h1:h5E0h5/Y8niHc5DlaLlWLArTQI7tMrsfQjHV+d9ZoGs= +github.com/prometheus/common v0.66.1/go.mod h1:gcaUsgf3KfRSwHY4dIMXLPV0K/Wg1oZ8+SbZk/HH/dA= +github.com/prometheus/procfs v0.16.1 h1:hZ15bTNuirocR6u0JZ6BAHHmwS1p8B4P6MRqxtzMyRg= +github.com/prometheus/procfs v0.16.1/go.mod h1:teAbpZRB1iIAJYREa1LsoWUXykVXA1KlTmWl8x/U+Is= github.com/rivo/uniseg v0.1.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= github.com/rivo/uniseg v0.4.4 h1:8TfxU8dW6PdqD27gjM8MVNuicgxIjxpm4K7x4jp8sis= @@ -95,8 +111,8 @@ 
github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o= github.com/spf13/cast v1.7.1 h1:cuNEagBQEHWN1FnbGEjCXL2szYEXqfJPbP2HNUaca9Y= github.com/spf13/cast v1.7.1/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo= -github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= -github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY= github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= @@ -126,6 +142,8 @@ go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= go.uber.org/mock v0.6.0 h1:hyF9dfmbgIX5EfOdasqLsWD6xqpNZlXblLB/Dbnwv3Y= go.uber.org/mock v0.6.0/go.mod h1:KiVJ4BqZJaMj4svdfmHM0AUx4NJYO8ZNpPnZn1Z+BBU= +go.yaml.in/yaml/v2 v2.4.2 h1:DzmwEr2rDGHl7lsFgAHxmNz/1NlQ7xLIrlN2h5d1eGI= +go.yaml.in/yaml/v2 v2.4.2/go.mod h1:081UH+NErpNdqlCXm3TtEran0rJZGxAYx9hb/ELlsPU= golang.org/x/exp v0.0.0-20250819193227-8b4c13bb791b h1:DXr+pvt3nC887026GRP39Ej11UATqWDmWuS99x26cD0= golang.org/x/exp v0.0.0-20250819193227-8b4c13bb791b/go.mod h1:4QTo5u+SEIbbKW1RacMZq1YEfOBqeXa19JeshGi+zc4= golang.org/x/net v0.43.0 h1:lat02VYK2j4aLzMzecihNvTlJNQUq316m2Mr9rnM6YE= @@ -149,10 +167,10 @@ google.golang.org/genproto/googleapis/rpc v0.0.0-20240722135656-d784300faade h1: google.golang.org/genproto/googleapis/rpc v0.0.0-20240722135656-d784300faade/go.mod h1:Ue6ibwXGpU+dqIcODieyLOcgj7z8+IcskoNIgZxtrFY= google.golang.org/grpc v1.64.1 
h1:LKtvyfbX3UGVPFcGqJ9ItpVWW6oN/2XqTxfAnwRRXiA= google.golang.org/grpc v1.64.1/go.mod h1:hiQF4LFZelK2WKaP6W0L92zGHtiQdZxk8CrSdvyjeP0= -google.golang.org/protobuf v1.36.3 h1:82DV7MYdb8anAVi3qge1wSnMDrnKK7ebr+I0hHRN1BU= -google.golang.org/protobuf v1.36.3/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE= +google.golang.org/protobuf v1.36.8 h1:xHScyCOEuuwZEc6UtSOvPbAT4zRh0xcNRYekJwfqyMc= +google.golang.org/protobuf v1.36.8/go.mod h1:fuxRtAxBytpl4zzqUh6/eyUujkJdNiuEkXntxiD/uRU= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= -gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/intercept_anthropic_messages_blocking.go b/intercept_anthropic_messages_blocking.go index 3aef2dd..a1f71e6 100644 --- a/intercept_anthropic_messages_blocking.go +++ b/intercept_anthropic_messages_blocking.go @@ -35,6 +35,10 @@ func (s *AnthropicMessagesBlockingInterception) Setup(logger slog.Logger, record s.AnthropicMessagesInterceptionBase.Setup(logger.Named("blocking"), recorder, mcpProxy) } +func (s *AnthropicMessagesBlockingInterception) Streaming() bool { + return false +} + func (i *AnthropicMessagesBlockingInterception) ProcessRequest(w http.ResponseWriter, r *http.Request) error { if i.req == nil { return fmt.Errorf("developer error: req is nil") @@ -103,7 +107,7 @@ func (i *AnthropicMessagesBlockingInterception) ProcessRequest(w http.ResponseWr MsgID: resp.ID, Input: resp.Usage.InputTokens, Output: 
resp.Usage.OutputTokens, - Metadata: Metadata{ + ExtraTokenTypes: map[string]int64{ "web_search_requests": resp.Usage.ServerToolUse.WebSearchRequests, "cache_creation_input": resp.Usage.CacheCreationInputTokens, "cache_read_input": resp.Usage.CacheReadInputTokens, diff --git a/intercept_anthropic_messages_streaming.go b/intercept_anthropic_messages_streaming.go index daa0352..ef8aabd 100644 --- a/intercept_anthropic_messages_streaming.go +++ b/intercept_anthropic_messages_streaming.go @@ -38,6 +38,10 @@ func (s *AnthropicMessagesStreamingInterception) Setup(logger slog.Logger, recor s.AnthropicMessagesInterceptionBase.Setup(logger.Named("streaming"), recorder, mcpProxy) } +func (s *AnthropicMessagesStreamingInterception) Streaming() bool { + return true +} + // ProcessRequest handles a request to /v1/messages. // This API has a state-machine behind it, which is described in https://docs.claude.com/en/docs/build-with-claude/streaming#event-types. // @@ -169,7 +173,7 @@ newStream: MsgID: message.ID, Input: start.Message.Usage.InputTokens, Output: start.Message.Usage.OutputTokens, - Metadata: Metadata{ + ExtraTokenTypes: map[string]int64{ "web_search_requests": start.Message.Usage.ServerToolUse.WebSearchRequests, "cache_creation_input": start.Message.Usage.CacheCreationInputTokens, "cache_read_input": start.Message.Usage.CacheReadInputTokens, diff --git a/intercept_openai_chat_blocking.go b/intercept_openai_chat_blocking.go index d9bb50e..757c933 100644 --- a/intercept_openai_chat_blocking.go +++ b/intercept_openai_chat_blocking.go @@ -35,6 +35,10 @@ func (s *OpenAIBlockingChatInterception) Setup(logger slog.Logger, recorder Reco s.OpenAIChatInterceptionBase.Setup(logger.Named("blocking"), recorder, mcpProxy) } +func (s *OpenAIBlockingChatInterception) Streaming() bool { + return false +} + func (i *OpenAIBlockingChatInterception) ProcessRequest(w http.ResponseWriter, r *http.Request) error { if i.req == nil { return fmt.Errorf("developer error: req is nil") @@ -83,7 
+87,7 @@ func (i *OpenAIBlockingChatInterception) ProcessRequest(w http.ResponseWriter, r MsgID: completion.ID, Input: calculateActualInputTokenUsage(lastUsage), Output: lastUsage.CompletionTokens, - Metadata: Metadata{ + ExtraTokenTypes: map[string]int64{ "prompt_audio": lastUsage.PromptTokensDetails.AudioTokens, "prompt_cached": lastUsage.PromptTokensDetails.CachedTokens, "completion_accepted_prediction": lastUsage.CompletionTokensDetails.AcceptedPredictionTokens, diff --git a/intercept_openai_chat_streaming.go b/intercept_openai_chat_streaming.go index 9798505..ccabb35 100644 --- a/intercept_openai_chat_streaming.go +++ b/intercept_openai_chat_streaming.go @@ -38,6 +38,10 @@ func (i *OpenAIStreamingChatInterception) Setup(logger slog.Logger, recorder Rec i.OpenAIChatInterceptionBase.Setup(logger.Named("streaming"), recorder, mcpProxy) } +func (i *OpenAIStreamingChatInterception) Streaming() bool { + return true +} + // ProcessRequest handles a request to /v1/chat/completions. // See https://platform.openai.com/docs/api-reference/chat-streaming/streaming. // @@ -161,7 +165,7 @@ func (i *OpenAIStreamingChatInterception) ProcessRequest(w http.ResponseWriter, MsgID: processor.getMsgID(), Input: calculateActualInputTokenUsage(lastUsage), Output: lastUsage.CompletionTokens, - Metadata: Metadata{ + ExtraTokenTypes: map[string]int64{ "prompt_audio": lastUsage.PromptTokensDetails.AudioTokens, "prompt_cached": lastUsage.PromptTokensDetails.CachedTokens, "completion_accepted_prediction": lastUsage.CompletionTokensDetails.AcceptedPredictionTokens, diff --git a/interception.go b/interception.go index 2b9d08b..8210c41 100644 --- a/interception.go +++ b/interception.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "net/http" + "strings" "time" "cdr.dev/slog" @@ -18,11 +19,12 @@ type Interceptor interface { // Setup injects some required dependencies. This MUST be called before using the interceptor // to process requests. 
Setup(logger slog.Logger, recorder Recorder, mcpProxy mcp.ServerProxier) - // Model returns the model in use for this [Interceptor]. Model() string // ProcessRequest handles the HTTP request. ProcessRequest(w http.ResponseWriter, r *http.Request) error + // Streaming reports whether this [Interceptor] streams its responses. + Streaming() bool } var UnknownRoute = errors.New("unknown route") @@ -32,7 +34,7 @@ const recordingTimeout = time.Second * 5 // newInterceptionProcessor returns an [http.HandlerFunc] which is capable of creating a new interceptor and processing a given request // using [Provider] p, recording all usage events using [Recorder] recorder. -func newInterceptionProcessor(p Provider, logger slog.Logger, recorder Recorder, mcpProxy mcp.ServerProxier) http.HandlerFunc { +func newInterceptionProcessor(p Provider, logger slog.Logger, recorder Recorder, mcpProxy mcp.ServerProxier, metrics *Metrics) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { interceptor, err := p.CreateInterceptor(w, r) if err != nil { @@ -41,9 +43,12 @@ func newInterceptionProcessor(p Provider, logger slog.Logger, recorder Recorder, return } - // Record usage in the background to not block request flow. - asyncRecorder := NewAsyncRecorder(logger, recorder, recordingTimeout) - interceptor.Setup(logger, asyncRecorder, mcpProxy) + if metrics != nil { + start := time.Now() + defer func() { + metrics.InterceptionDuration.WithLabelValues(p.Name(), interceptor.Model()).Observe(time.Since(start).Seconds()) + }() + } actor := actorFromContext(r.Context()) if actor == nil { @@ -52,6 +57,14 @@ func newInterceptionProcessor(p Provider, logger slog.Logger, recorder Recorder, return } + // Record usage in the background to not block request flow. 
+ asyncRecorder := NewAsyncRecorder(logger, recorder, recordingTimeout) + asyncRecorder.WithMetrics(metrics) + asyncRecorder.WithProvider(p.Name()) + asyncRecorder.WithModel(interceptor.Model()) + asyncRecorder.WithInitiatorID(actor.id) + interceptor.Setup(logger, asyncRecorder, mcpProxy) + if err := recorder.RecordInterception(r.Context(), &InterceptionRecord{ ID: interceptor.ID().String(), Metadata: actor.metadata, @@ -64,21 +77,37 @@ func newInterceptionProcessor(p Provider, logger slog.Logger, recorder Recorder, return } + route := strings.TrimPrefix(r.URL.Path, fmt.Sprintf("/%s", p.Name())) log := logger.With( - slog.F("route", r.URL.Path), + slog.F("route", route), slog.F("provider", p.Name()), slog.F("interception_id", interceptor.ID()), slog.F("user_agent", r.UserAgent()), + slog.F("streaming", interceptor.Streaming()), ) log.Debug(r.Context(), "interception started") + if metrics != nil { + metrics.InterceptionsInflight.WithLabelValues(p.Name(), interceptor.Model(), route).Add(1) + } + if err := interceptor.ProcessRequest(w, r); err != nil { + if metrics != nil { + metrics.InterceptionCount.WithLabelValues(p.Name(), interceptor.Model(), InterceptionCountStatusFailed, route, r.Method, actor.id).Add(1) + } log.Warn(r.Context(), "interception failed", slog.Error(err)) } else { + if metrics != nil { + metrics.InterceptionCount.WithLabelValues(p.Name(), interceptor.Model(), InterceptionCountStatusCompleted, route, r.Method, actor.id).Add(1) + } log.Debug(r.Context(), "interception ended") } asyncRecorder.RecordInterceptionEnded(r.Context(), &InterceptionRecordEnded{ID: interceptor.ID().String()}) + if metrics != nil { + metrics.InterceptionsInflight.WithLabelValues(p.Name(), interceptor.Model(), route).Sub(1) + } + // Ensure all recording have completed before completing request. 
asyncRecorder.Wait() } diff --git a/metrics.go b/metrics.go new file mode 100644 index 0000000..32d5a78 --- /dev/null +++ b/metrics.go @@ -0,0 +1,106 @@ +package aibridge + +import ( + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promauto" +) + +var baseLabels []string = []string{"provider", "model"} + +const ( + InterceptionCountStatusFailed = "failed" + InterceptionCountStatusCompleted = "completed" +) + +type Metrics struct { + // Interception-related metrics. + InterceptionDuration *prometheus.HistogramVec + InterceptionCount *prometheus.CounterVec + InterceptionsInflight *prometheus.GaugeVec + PassthroughCount *prometheus.CounterVec + + // Prompt-related metrics. + PromptCount *prometheus.CounterVec + + // Token-related metrics. + TokenUseCount *prometheus.CounterVec + + // Tool-related metrics. + InjectedToolUseCount *prometheus.CounterVec + NonInjectedToolUseCount *prometheus.CounterVec +} + +// NewMetrics creates AND registers metrics. It will panic if a collector has already been registered. +// Note: we are not specifying namespace in the metrics; the provided registerer may specify a "namespace" +// using [prometheus.WrapRegistererWithPrefix]. +func NewMetrics(reg prometheus.Registerer) *Metrics { + return &Metrics{ + // Interception-related metrics. + + // Pessimistic cardinality: 2 providers, 5 models, 2 statuses, 2 routes, 3 methods = up to 120 PER INITIATOR. + InterceptionCount: promauto.With(reg).NewCounterVec(prometheus.CounterOpts{ + Subsystem: "interceptions", + Name: "total", + Help: "The count of intercepted requests.", + }, append(baseLabels, "status", "route", "method", "initiator_id")), + // Pessimistic cardinality: 2 providers, 5 models, 2 routes = up to 20. + // NOTE: route is not unbounded because this is only for intercepted routes. 
+ InterceptionsInflight: promauto.With(reg).NewGaugeVec(prometheus.GaugeOpts{ + Subsystem: "interceptions", + Name: "inflight", + Help: "The number of intercepted requests which are being processed.", + }, append(baseLabels, "route")), + // Pessimistic cardinality: 2 providers, 5 models, 7 buckets + 3 extra series (count, sum, +Inf) = up to 100. + InterceptionDuration: promauto.With(reg).NewHistogramVec(prometheus.HistogramOpts{ + Subsystem: "interceptions", + Name: "duration_seconds", + Help: "The total duration of intercepted requests, in seconds. " + + "The majority of this time will be the upstream processing of the request. " + + "aibridge has no control over upstream processing time, so it's just an illustrative metric.", + // TODO: add docs around determining aibridge's *own* latency with distributed traces + // once https://github.com/coder/aibridge/issues/26 lands. + Buckets: []float64{0.5, 2, 5, 15, 30, 60, 120}, + }, baseLabels), + + // Pessimistic cardinality: 2 providers, 10 routes, 3 methods = up to 60. + // NOTE: route is not unbounded because PassthroughRoutes (see provider.go) is a static list. + PassthroughCount: promauto.With(reg).NewCounterVec(prometheus.CounterOpts{ + Subsystem: "passthrough", + Name: "total", + Help: "The count of requests which were not intercepted but passed through to the upstream.", + }, []string{"provider", "route", "method"}), + + // Prompt-related metrics. + + // Pessimistic cardinality: 2 providers, 5 models = up to 10 PER INITIATOR. + PromptCount: promauto.With(reg).NewCounterVec(prometheus.CounterOpts{ + Subsystem: "prompts", + Name: "total", + Help: "The number of prompts issued by users (initiators).", + }, append(baseLabels, "initiator_id")), + + // Token-related metrics. + + // Pessimistic cardinality: 2 providers, 5 models, 10 types = up to 100 PER INITIATOR. 
+ TokenUseCount: promauto.With(reg).NewCounterVec(prometheus.CounterOpts{ + Subsystem: "tokens", + Name: "total", + Help: "The number of tokens used by intercepted requests.", + }, append(baseLabels, "type", "initiator_id")), + + // Tool-related metrics. + + // Pessimistic cardinality: 2 providers, 5 models, 3 servers, 30 tools = up to 900. + InjectedToolUseCount: promauto.With(reg).NewCounterVec(prometheus.CounterOpts{ + Subsystem: "injected_tool_invocations", + Name: "total", + Help: "The number of times an injected MCP tool was invoked by aibridge.", + }, append(baseLabels, "server", "name")), + // Pessimistic cardinality: 2 providers, 5 models, 30 tools = up to 300. + NonInjectedToolUseCount: promauto.With(reg).NewCounterVec(prometheus.CounterOpts{ + Subsystem: "non_injected_tool_selections", + Name: "total", + Help: "The number of times an AI model selected a tool to be invoked by the client.", + }, append(baseLabels, "name")), + } +} diff --git a/metrics_integration_test.go b/metrics_integration_test.go new file mode 100644 index 0000000..3696de2 --- /dev/null +++ b/metrics_integration_test.go @@ -0,0 +1,292 @@ +package aibridge_test + +import ( + "context" + "io" + "net" + "net/http" + "net/http/httptest" + "testing" + "time" + + "cdr.dev/slog" + "cdr.dev/slog/sloggers/slogtest" + "github.com/coder/aibridge" + "github.com/coder/aibridge/mcp" + "github.com/prometheus/client_golang/prometheus" + promtest "github.com/prometheus/client_golang/prometheus/testutil" + "github.com/stretchr/testify/require" + "golang.org/x/tools/txtar" +) + +func TestMetrics_Interception(t *testing.T) { + t.Parallel() + + cases := []struct { + fixture []byte + expectedStatus string + }{ + { + fixture: antSimple, + expectedStatus: aibridge.InterceptionCountStatusCompleted, + }, + { + fixture: antNonStreamErr, + expectedStatus: aibridge.InterceptionCountStatusFailed, + }, + } + + for _, tc := range cases { + arc := txtar.Parse(tc.fixture) + files := filesMap(arc) + + ctx, cancel := 
context.WithTimeout(t.Context(), time.Second*30) + t.Cleanup(cancel) + + mockAPI := newMockServer(ctx, t, files, nil) + t.Cleanup(mockAPI.Close) + + metrics := aibridge.NewMetrics(prometheus.NewRegistry()) + provider := aibridge.NewAnthropicProvider(anthropicCfg(mockAPI.URL, apiKey), nil) + srv := newTestSrv(t, ctx, provider, metrics) + + req := createAnthropicMessagesReq(t, srv.URL, files[fixtureRequest]) + resp, err := http.DefaultClient.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + _, _ = io.ReadAll(resp.Body) + + count := promtest.ToFloat64(metrics.InterceptionCount.WithLabelValues( + aibridge.ProviderAnthropic, "claude-sonnet-4-0", tc.expectedStatus, "/v1/messages", "POST", userID)) + require.Equal(t, 1.0, count) + require.Equal(t, 1, promtest.CollectAndCount(metrics.InterceptionDuration)) + } +} + +func TestMetrics_InterceptionsInflight(t *testing.T) { + t.Parallel() + + arc := txtar.Parse(antSimple) + files := filesMap(arc) + + ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + t.Cleanup(cancel) + + blockCh := make(chan struct{}) + + // Setup a mock HTTP server which blocks until the request is marked as inflight then proceeds. + srv := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + <-blockCh + mock := newMockServer(ctx, t, files, nil) + defer mock.Close() + mock.Server.Config.Handler.ServeHTTP(w, r) + })) + srv.Config.BaseContext = func(_ net.Listener) context.Context { + return ctx + } + srv.Start() + t.Cleanup(srv.Close) + + metrics := aibridge.NewMetrics(prometheus.NewRegistry()) + provider := aibridge.NewAnthropicProvider(anthropicCfg(srv.URL, apiKey), nil) + bridgeSrv := newTestSrv(t, ctx, provider, metrics) + + // Make request in background. 
+ doneCh := make(chan struct{}) + go func() { + defer close(doneCh) + req := createAnthropicMessagesReq(t, bridgeSrv.URL, files[fixtureRequest]) + resp, err := http.DefaultClient.Do(req) + if err == nil { + defer resp.Body.Close() + _, _ = io.ReadAll(resp.Body) + } + }() + + // Wait until request is detected as inflight. + require.Eventually(t, func() bool { + return promtest.ToFloat64( + metrics.InterceptionsInflight.WithLabelValues(aibridge.ProviderAnthropic, "claude-sonnet-4-0", "/v1/messages"), + ) == 1 + }, time.Second*10, time.Millisecond*50) + + // Unblock request, await completion. + close(blockCh) + select { + case <-doneCh: + case <-ctx.Done(): + t.Fatal(ctx.Err()) + } + + // Metric is not updated immediately after request completes, so wait until it is. + require.Eventually(t, func() bool { + return promtest.ToFloat64( + metrics.InterceptionsInflight.WithLabelValues(aibridge.ProviderAnthropic, "claude-sonnet-4-0", "/v1/messages"), + ) == 0 + }, time.Second*10, time.Millisecond*50) +} + +func TestMetrics_PassthroughCount(t *testing.T) { + t.Parallel() + + arc := txtar.Parse(oaiFallthrough) + files := filesMap(arc) + + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write(files[fixtureResponse]) + })) + t.Cleanup(upstream.Close) + + metrics := aibridge.NewMetrics(prometheus.NewRegistry()) + provider := aibridge.NewOpenAIProvider(openaiCfg(upstream.URL, apiKey)) + srv := newTestSrv(t, t.Context(), provider, metrics) + + req, err := http.NewRequestWithContext(t.Context(), "GET", srv.URL+"/openai/v1/models", nil) + require.NoError(t, err) + + resp, err := http.DefaultClient.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + require.Equal(t, http.StatusOK, resp.StatusCode) + + count := promtest.ToFloat64(metrics.PassthroughCount.WithLabelValues( + aibridge.ProviderOpenAI, "/v1/models", "GET")) + 
require.Equal(t, 1.0, count) +} + +func TestMetrics_PromptCount(t *testing.T) { + t.Parallel() + + arc := txtar.Parse(oaiSimple) + files := filesMap(arc) + + ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + t.Cleanup(cancel) + + mockAPI := newMockServer(ctx, t, files, nil) + t.Cleanup(mockAPI.Close) + + metrics := aibridge.NewMetrics(prometheus.NewRegistry()) + provider := aibridge.NewOpenAIProvider(openaiCfg(mockAPI.URL, apiKey)) + srv := newTestSrv(t, ctx, provider, metrics) + + req := createOpenAIChatCompletionsReq(t, srv.URL, files[fixtureRequest]) + resp, err := http.DefaultClient.Do(req) + require.NoError(t, err) + require.Equal(t, http.StatusOK, resp.StatusCode) + defer resp.Body.Close() + _, _ = io.ReadAll(resp.Body) + + prompts := promtest.ToFloat64(metrics.PromptCount.WithLabelValues( + aibridge.ProviderOpenAI, "gpt-4.1", userID)) + require.Equal(t, 1.0, prompts) +} + +func TestMetrics_NonInjectedToolUseCount(t *testing.T) { + t.Parallel() + + arc := txtar.Parse(oaiSingleBuiltinTool) + files := filesMap(arc) + + ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + t.Cleanup(cancel) + + mockAPI := newMockServer(ctx, t, files, nil) + t.Cleanup(mockAPI.Close) + + metrics := aibridge.NewMetrics(prometheus.NewRegistry()) + provider := aibridge.NewOpenAIProvider(openaiCfg(mockAPI.URL, apiKey)) + srv := newTestSrv(t, ctx, provider, metrics) + + req := createOpenAIChatCompletionsReq(t, srv.URL, files[fixtureRequest]) + resp, err := http.DefaultClient.Do(req) + require.NoError(t, err) + require.Equal(t, http.StatusOK, resp.StatusCode) + defer resp.Body.Close() + _, _ = io.ReadAll(resp.Body) + + count := promtest.ToFloat64(metrics.NonInjectedToolUseCount.WithLabelValues( + aibridge.ProviderOpenAI, "gpt-4.1", "read_file")) + require.Equal(t, 1.0, count) +} + +func TestMetrics_InjectedToolUseCount(t *testing.T) { + t.Parallel() + + arc := txtar.Parse(antSingleInjectedTool) + files := filesMap(arc) + + ctx, cancel := 
context.WithTimeout(t.Context(), time.Second*30) + t.Cleanup(cancel) + + // First request returns the tool invocation, the second returns the mocked response to the tool result. + mockAPI := newMockServer(ctx, t, files, func(reqCount uint32, resp []byte) []byte { + if reqCount == 1 { + return resp + } + return files[fixtureNonStreamingToolResponse] + }) + t.Cleanup(mockAPI.Close) + + recorder := &mockRecorderClient{} + logger := slogtest.Make(t, &slogtest.Options{}).Leveled(slog.LevelDebug) + metrics := aibridge.NewMetrics(prometheus.NewRegistry()) + provider := aibridge.NewAnthropicProvider(anthropicCfg(mockAPI.URL, apiKey), nil) + + // Setup mocked MCP server & tools. + tools := setupMCPServerProxiesForTest(t) + mcpMgr := mcp.NewServerProxyManager(tools) + require.NoError(t, mcpMgr.Init(ctx)) + + bridge, err := aibridge.NewRequestBridge(ctx, []aibridge.Provider{provider}, recorder, mcpMgr, metrics, logger) + require.NoError(t, err) + + srv := httptest.NewUnstartedServer(bridge) + srv.Config.BaseContext = func(_ net.Listener) context.Context { + return aibridge.AsActor(ctx, userID, nil) + } + srv.Start() + t.Cleanup(srv.Close) + + req := createAnthropicMessagesReq(t, srv.URL, files[fixtureRequest]) + resp, err := http.DefaultClient.Do(req) + require.NoError(t, err) + require.Equal(t, http.StatusOK, resp.StatusCode) + defer resp.Body.Close() + _, _ = io.ReadAll(resp.Body) + + // Wait until full roundtrip has completed. 
+ require.Eventually(t, func() bool { + return mockAPI.callCount.Load() == 2 + }, time.Second*10, time.Millisecond*50) + + require.Len(t, recorder.toolUsages, 1) + require.True(t, recorder.toolUsages[0].Injected) + require.NotNil(t, recorder.toolUsages[0].ServerURL) + actualServerURL := *recorder.toolUsages[0].ServerURL + + count := promtest.ToFloat64(metrics.InjectedToolUseCount.WithLabelValues( + aibridge.ProviderAnthropic, "claude-sonnet-4-20250514", actualServerURL, mockToolName)) + require.Equal(t, 1.0, count) +} + +func newTestSrv(t *testing.T, ctx context.Context, provider aibridge.Provider, metrics *aibridge.Metrics) *httptest.Server { + t.Helper() + + recorder := &mockRecorderClient{} + logger := slogtest.Make(t, &slogtest.Options{}).Leveled(slog.LevelDebug) + + bridge, err := aibridge.NewRequestBridge(ctx, []aibridge.Provider{provider}, recorder, mcp.NewServerProxyManager(nil), metrics, logger) + require.NoError(t, err) + + srv := httptest.NewUnstartedServer(bridge) + srv.Config.BaseContext = func(_ net.Listener) context.Context { + return aibridge.AsActor(ctx, userID, nil) + } + srv.Start() + t.Cleanup(srv.Close) + + return srv +} diff --git a/passthrough.go b/passthrough.go index 5743bc1..6788672 100644 --- a/passthrough.go +++ b/passthrough.go @@ -12,8 +12,12 @@ import ( // newPassthroughRouter returns a simple reverse-proxy implementation which will be used when a route is not handled specifically // by a [Provider]. 
-func newPassthroughRouter(provider Provider, logger slog.Logger) http.HandlerFunc { +func newPassthroughRouter(provider Provider, logger slog.Logger, metrics *Metrics) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { + if metrics != nil { + metrics.PassthroughCount.WithLabelValues(provider.Name(), r.URL.Path, r.Method).Add(1) + } + upURL, err := url.Parse(provider.BaseURL()) if err != nil { logger.Warn(r.Context(), "failed to parse provider base URL", slog.Error(err)) diff --git a/recorder.go b/recorder.go index cf28387..edd68f7 100644 --- a/recorder.go +++ b/recorder.go @@ -104,6 +104,9 @@ type AsyncRecorder struct { logger slog.Logger wrapped Recorder timeout time.Duration + metrics *Metrics + + provider, model, initiatorID string wg sync.WaitGroup } @@ -112,6 +115,22 @@ func NewAsyncRecorder(logger slog.Logger, wrapped Recorder, timeout time.Duratio return &AsyncRecorder{logger: logger, wrapped: wrapped, timeout: timeout} } +func (a *AsyncRecorder) WithMetrics(metrics *Metrics) { + a.metrics = metrics +} + +func (a *AsyncRecorder) WithProvider(provider string) { + a.provider = provider +} + +func (a *AsyncRecorder) WithModel(model string) { + a.model = model +} + +func (a *AsyncRecorder) WithInitiatorID(initiatorID string) { + a.initiatorID = initiatorID +} + // RecordInterception must NOT be called asynchronously. // If an interception cannot be recorded, the whole request should fail. func (a *AsyncRecorder) RecordInterception(ctx context.Context, req *InterceptionRecord) error { @@ -145,6 +164,10 @@ func (a *AsyncRecorder) RecordPromptUsage(_ context.Context, req *PromptUsageRec if err != nil { a.logger.Warn(timedCtx, "failed to record usage", slog.F("type", "prompt"), slog.Error(err), slog.F("payload", req)) } + + if a.metrics != nil && req.Prompt != "" { // TODO: will be irrelevant once https://github.com/coder/aibridge/issues/55 is fixed. 
+ a.metrics.PromptCount.WithLabelValues(a.provider, a.model, a.initiatorID).Add(1) + } }() return nil // Caller is not interested in error. @@ -161,6 +184,14 @@ func (a *AsyncRecorder) RecordTokenUsage(_ context.Context, req *TokenUsageRecor if err != nil { a.logger.Warn(timedCtx, "failed to record usage", slog.F("type", "token"), slog.Error(err), slog.F("payload", req)) } + + if a.metrics != nil { + a.metrics.TokenUseCount.WithLabelValues(a.provider, a.model, "input", a.initiatorID).Add(float64(req.Input)) + a.metrics.TokenUseCount.WithLabelValues(a.provider, a.model, "output", a.initiatorID).Add(float64(req.Output)) + for k, v := range req.ExtraTokenTypes { + a.metrics.TokenUseCount.WithLabelValues(a.provider, a.model, k, a.initiatorID).Add(float64(v)) + } + } }() return nil // Caller is not interested in error. @@ -177,6 +208,18 @@ func (a *AsyncRecorder) RecordToolUsage(_ context.Context, req *ToolUsageRecord) if err != nil { a.logger.Warn(timedCtx, "failed to record usage", slog.F("type", "tool"), slog.Error(err), slog.F("payload", req)) } + + if a.metrics != nil { + if req.Injected { + var srvURL string + if req.ServerURL != nil { + srvURL = *req.ServerURL + } + a.metrics.InjectedToolUseCount.WithLabelValues(a.provider, a.model, srvURL, req.Tool).Add(1) + } else { + a.metrics.NonInjectedToolUseCount.WithLabelValues(a.provider, a.model, req.Tool).Add(1) + } + } }() return nil // Caller is not interested in error.