diff --git a/bridge.go b/bridge.go index 9f2c424..8a26aee 100644 --- a/bridge.go +++ b/bridge.go @@ -9,9 +9,9 @@ import ( "cdr.dev/slog" "github.com/coder/aibridge/mcp" - "go.opentelemetry.io/otel/trace" - "github.com/hashicorp/go-multierror" + "github.com/sony/gobreaker/v2" + "go.opentelemetry.io/otel/trace" ) // RequestBridge is an [http.Handler] which is capable of masquerading as AI providers' APIs; @@ -48,13 +48,38 @@ var _ http.Handler = &RequestBridge{} // A [Recorder] is also required to record prompt, tool, and token use. // // mcpProxy will be closed when the [RequestBridge] is closed. +// +// Circuit breaker configuration is obtained from each provider's CircuitBreakerConfig() method. +// Providers returning nil will not have circuit breaker protection. func NewRequestBridge(ctx context.Context, providers []Provider, recorder Recorder, mcpProxy mcp.ServerProxier, logger slog.Logger, metrics *Metrics, tracer trace.Tracer) (*RequestBridge, error) { mux := http.NewServeMux() for _, provider := range providers { + // Create per-provider circuit breaker if configured + cfg := provider.CircuitBreakerConfig() + providerName := provider.Name() + onChange := func(endpoint string, from, to gobreaker.State) { + logger.Info(context.Background(), "circuit breaker state change", + slog.F("provider", providerName), + slog.F("endpoint", endpoint), + slog.F("from", from.String()), + slog.F("to", to.String()), + ) + if cfg != nil && metrics != nil { + metrics.CircuitBreakerState.WithLabelValues(providerName, endpoint).Set(stateToGaugeValue(to)) + if to == gobreaker.StateOpen { + metrics.CircuitBreakerTrips.WithLabelValues(providerName, endpoint).Inc() + } + } + } + cbs := NewProviderCircuitBreakers(providerName, cfg, onChange) + // Add the known provider-specific routes which are bridged (i.e. intercepted and augmented). 
for _, path := range provider.BridgedRoutes() { - mux.HandleFunc(path, newInterceptionProcessor(provider, recorder, mcpProxy, logger, metrics, tracer)) + handler := newInterceptionProcessor(provider, recorder, mcpProxy, logger, metrics, tracer) + // Wrap with circuit breaker middleware (nil cbs passes through) + wrapped := CircuitBreakerMiddleware(cbs, metrics)(handler) + mux.Handle(path, wrapped) } // Any requests which passthrough to this will be reverse-proxied to the upstream. diff --git a/circuit_breaker.go b/circuit_breaker.go new file mode 100644 index 0000000..50a3604 --- /dev/null +++ b/circuit_breaker.go @@ -0,0 +1,190 @@ +package aibridge + +import ( + "errors" + "fmt" + "net/http" + "strings" + "sync" + "time" + + "github.com/sony/gobreaker/v2" +) + +// CircuitBreakerConfig holds configuration for circuit breakers. +// Fields match gobreaker.Settings for clarity. +type CircuitBreakerConfig struct { + // MaxRequests is the maximum number of requests allowed in half-open state. + MaxRequests uint32 + // Interval is the cyclic period of the closed state for clearing internal counts. + Interval time.Duration + // Timeout is how long the circuit stays open before transitioning to half-open. + Timeout time.Duration + // FailureThreshold is the number of consecutive failures that triggers the circuit to open. + FailureThreshold uint32 + // IsFailure determines if a status code should count as a failure. + // If nil, defaults to 429, 503, and 529 (Anthropic overloaded). + IsFailure func(statusCode int) bool +} + +// DefaultCircuitBreakerConfig returns sensible defaults for circuit breaker configuration. 
+func DefaultCircuitBreakerConfig() CircuitBreakerConfig {
+	return CircuitBreakerConfig{
+		FailureThreshold: 5,
+		Interval:         10 * time.Second,
+		Timeout:          30 * time.Second,
+		MaxRequests:      3,
+		IsFailure:        DefaultIsFailure,
+	}
+}
+
+// DefaultIsFailure returns true for status codes that typically indicate
+// upstream overload: 429 (Too Many Requests), 503 (Service Unavailable),
+// and 529 (Anthropic Overloaded). Every other status code returns false.
+func DefaultIsFailure(statusCode int) bool {
+	switch statusCode {
+	case http.StatusTooManyRequests, // 429
+		http.StatusServiceUnavailable, // 503
+		529: // Anthropic "Overloaded"
+		return true
+	default:
+		return false
+	}
+}
+
+// ProviderCircuitBreakers manages per-endpoint circuit breakers for a single provider.
+type ProviderCircuitBreakers struct {
+	// provider is the provider name; it prefixes each breaker's name.
+	provider string
+	// config is a private copy of the settings given to NewProviderCircuitBreakers.
+	config CircuitBreakerConfig
+	// breakers holds the lazily created breakers, keyed by endpoint.
+	breakers sync.Map // endpoint -> *gobreaker.CircuitBreaker[struct{}]
+	// onChange, if non-nil, is invoked on every breaker state transition.
+	onChange func(endpoint string, from, to gobreaker.State)
+}
+
+// NewProviderCircuitBreakers creates circuit breakers for a single provider.
+// Returns nil if config is nil (no circuit breaker protection).
+//
+// A nil config.IsFailure is replaced by DefaultIsFailure. The default is applied
+// to an internal copy: the struct behind config is never modified, so callers
+// may safely share one *CircuitBreakerConfig between several providers.
+func NewProviderCircuitBreakers(provider string, config *CircuitBreakerConfig, onChange func(endpoint string, from, to gobreaker.State)) *ProviderCircuitBreakers {
+	if config == nil {
+		return nil
+	}
+	// Copy before defaulting; the pointer belongs to the caller.
+	cfg := *config
+	if cfg.IsFailure == nil {
+		cfg.IsFailure = DefaultIsFailure
+	}
+	return &ProviderCircuitBreakers{
+		provider: provider,
+		config:   cfg,
+		onChange: onChange,
+	}
+}
+
+// Get returns the circuit breaker for an endpoint, creating it if needed.
+func (p *ProviderCircuitBreakers) Get(endpoint string) *gobreaker.CircuitBreaker[struct{}] {
+	// Fast path: the breaker for this endpoint already exists.
+	if v, ok := p.breakers.Load(endpoint); ok {
+		return v.(*gobreaker.CircuitBreaker[struct{}])
+	}
+
+	settings := gobreaker.Settings{
+		Name:        p.provider + ":" + endpoint,
+		MaxRequests: p.config.MaxRequests,
+		Interval:    p.config.Interval,
+		Timeout:     p.config.Timeout,
+		ReadyToTrip: func(counts gobreaker.Counts) bool {
+			return counts.ConsecutiveFailures >= p.config.FailureThreshold
+		},
+		OnStateChange: func(_ string, from, to gobreaker.State) {
+			// Report the endpoint rather than the breaker's composite name.
+			if p.onChange != nil {
+				p.onChange(endpoint, from, to)
+			}
+		},
+	}
+
+	// Two goroutines may both build a breaker for a new endpoint; LoadOrStore
+	// keeps whichever was stored first and the other is discarded.
+	cb := gobreaker.NewCircuitBreaker[struct{}](settings)
+	actual, _ := p.breakers.LoadOrStore(endpoint, cb)
+	return actual.(*gobreaker.CircuitBreaker[struct{}])
+}
+
+// statusCapturingWriter wraps http.ResponseWriter to capture the status code.
+// It also implements http.Flusher to support streaming responses.
+//
+// Only the final status is recorded: informational 1xx headers are forwarded
+// but not latched, and once the header has been committed (by WriteHeader,
+// Write or Flush) later WriteHeader calls no longer change the recorded value.
+type statusCapturingWriter struct {
+	http.ResponseWriter
+	statusCode    int
+	headerWritten bool
+}
+
+// WriteHeader records code as the response status, unless a final status has
+// already been recorded or code is informational, and forwards the call.
+func (w *statusCapturingWriter) WriteHeader(code int) {
+	// net/http allows any number of 1xx headers before the final one, so a 1xx
+	// code (other than 101, which is final) must not be taken as the outcome.
+	informational := code >= 100 && code < 200 && code != http.StatusSwitchingProtocols
+	if !w.headerWritten && !informational {
+		w.statusCode = code
+		w.headerWritten = true
+	}
+	w.ResponseWriter.WriteHeader(code)
+}
+
+// Write forwards b to the wrapped writer. A Write before any WriteHeader
+// implies a 200 status, which is recorded here.
+func (w *statusCapturingWriter) Write(b []byte) (int, error) {
+	if !w.headerWritten {
+		w.statusCode = http.StatusOK
+		w.headerWritten = true
+	}
+	return w.ResponseWriter.Write(b)
+}
+
+// Flush flushes the wrapped writer when it implements http.Flusher and is a
+// no-op otherwise.
+func (w *statusCapturingWriter) Flush() {
+	if f, ok := w.ResponseWriter.(http.Flusher); ok {
+		// A flush before WriteHeader commits an implicit 200, exactly like Write;
+		// latch it so a later WriteHeader cannot overwrite the recorded status.
+		if !w.headerWritten {
+			w.statusCode = http.StatusOK
+			w.headerWritten = true
+		}
+		f.Flush()
+	}
+}
+
+// Unwrap returns the underlying ResponseWriter for interface checks.
+func (w *statusCapturingWriter) Unwrap() http.ResponseWriter {
+	return w.ResponseWriter
+}
+
+// CircuitBreakerMiddleware returns middleware that wraps handlers with circuit breaker protection.
+// It captures the response status code to determine success/failure without provider-specific logic.
+// If cbs is nil, requests pass through without circuit breaker protection.
+//
+// The endpoint used to select a breaker is the request path with the
+// "/<provider>" prefix removed. Requests rejected by the breaker receive a 503
+// JSON error; when the rejection is due to an open circuit the response also
+// carries a Retry-After header derived from the configured Timeout.
+func CircuitBreakerMiddleware(cbs *ProviderCircuitBreakers, metrics *Metrics) func(http.Handler) http.Handler {
+	return func(next http.Handler) http.Handler {
+		// No circuit breaker configured - pass through
+		if cbs == nil {
+			return next
+		}
+
+		// Retry-After is the open-state duration rounded up to whole seconds. It is
+		// an upper bound: the breaker may already have been open for a while.
+		// gobreaker substitutes 60s when Timeout is not positive, so mirror that.
+		openFor := cbs.config.Timeout
+		if openFor <= 0 {
+			openFor = 60 * time.Second
+		}
+		retryAfter := fmt.Sprintf("%d", int64((openFor+time.Second-1)/time.Second))
+
+		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+			endpoint := strings.TrimPrefix(r.URL.Path, "/"+cbs.provider)
+			cb := cbs.Get(endpoint)
+
+			// Wrap response writer to capture status code
+			sw := &statusCapturingWriter{ResponseWriter: w, statusCode: http.StatusOK}
+
+			_, err := cb.Execute(func() (struct{}, error) {
+				next.ServeHTTP(sw, r)
+				if cbs.config.IsFailure(sw.statusCode) {
+					return struct{}{}, fmt.Errorf("upstream error: %d", sw.statusCode)
+				}
+				return struct{}{}, nil
+			})
+
+			// Only breaker rejections are answered here; in every other case the
+			// wrapped handler has already produced the response.
+			rejectedOpen := errors.Is(err, gobreaker.ErrOpenState)
+			if rejectedOpen || errors.Is(err, gobreaker.ErrTooManyRequests) {
+				if metrics != nil {
+					metrics.CircuitBreakerRejects.WithLabelValues(cbs.provider, endpoint).Inc()
+				}
+				w.Header().Set("Content-Type", "application/json")
+				if rejectedOpen {
+					w.Header().Set("Retry-After", retryAfter)
+				}
+				w.WriteHeader(http.StatusServiceUnavailable)
+				// Best effort: nothing useful can be done if the client is gone.
+				_, _ = w.Write([]byte(`{"type":"error","error":{"type":"circuit_breaker_open","message":"circuit breaker is open"}}`))
+			}
+		})
+	}
+}
+
+// stateToGaugeValue converts gobreaker.State to a gauge value.
+// closed=0, half-open=0.5, open=1 +func stateToGaugeValue(s gobreaker.State) float64 { + switch s { + case gobreaker.StateClosed: + return 0 + case gobreaker.StateHalfOpen: + return 0.5 + case gobreaker.StateOpen: + return 1 + default: + return 0 + } +} diff --git a/circuit_breaker_integration_test.go b/circuit_breaker_integration_test.go new file mode 100644 index 0000000..061bcf2 --- /dev/null +++ b/circuit_breaker_integration_test.go @@ -0,0 +1,296 @@ +package aibridge_test + +import ( + "context" + "io" + "net" + "net/http" + "net/http/httptest" + "strings" + "sync/atomic" + "testing" + "time" + + "cdr.dev/slog" + "cdr.dev/slog/sloggers/slogtest" + "github.com/coder/aibridge" + "github.com/coder/aibridge/mcp" + "github.com/prometheus/client_golang/prometheus" + promtest "github.com/prometheus/client_golang/prometheus/testutil" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.opentelemetry.io/otel" +) + +func TestCircuitBreaker_WithNewRequestBridge(t *testing.T) { + t.Parallel() + + type testCase struct { + name string + providerName string + endpoint string + errorBody string + successBody string + requestBody string + setupHeaders func(req *http.Request) + createProvider func(baseURL string, cbConfig *aibridge.CircuitBreakerConfig) aibridge.Provider + } + + tests := []testCase{ + { + name: "Anthropic", + providerName: aibridge.ProviderAnthropic, + endpoint: "/v1/messages", + errorBody: `{"type":"error","error":{"type":"rate_limit_error","message":"rate limited"}}`, + successBody: `{"id":"msg_01","type":"message","role":"assistant","content":[{"type":"text","text":"Hello!"}],"model":"claude-sonnet-4-20250514","stop_reason":"end_turn","usage":{"input_tokens":10,"output_tokens":5}}`, + requestBody: `{"model":"claude-sonnet-4-20250514","max_tokens":1024,"messages":[{"role":"user","content":"hi"}]}`, + setupHeaders: func(req *http.Request) { + req.Header.Set("x-api-key", "test") + req.Header.Set("anthropic-version", 
"2023-06-01") + }, + createProvider: func(baseURL string, cbConfig *aibridge.CircuitBreakerConfig) aibridge.Provider { + return aibridge.NewAnthropicProvider(aibridge.AnthropicConfig{ + BaseURL: baseURL, + Key: "test-key", + CircuitBreaker: cbConfig, + }, nil) + }, + }, + { + name: "OpenAI", + providerName: aibridge.ProviderOpenAI, + endpoint: "/v1/chat/completions", + errorBody: `{"error":{"type":"rate_limit_error","message":"rate limited","code":"rate_limit_exceeded"}}`, + successBody: `{"id":"chatcmpl-123","object":"chat.completion","created":1677652288,"model":"gpt-4o","choices":[{"index":0,"message":{"role":"assistant","content":"Hello!"},"finish_reason":"stop"}],"usage":{"prompt_tokens":9,"completion_tokens":12,"total_tokens":21}}`, + requestBody: `{"model":"gpt-4o","messages":[{"role":"user","content":"hi"}]}`, + setupHeaders: func(req *http.Request) { + req.Header.Set("Authorization", "Bearer test-key") + }, + createProvider: func(baseURL string, cbConfig *aibridge.CircuitBreakerConfig) aibridge.Provider { + return aibridge.NewOpenAIProvider(aibridge.OpenAIConfig{ + BaseURL: baseURL, + Key: "test-key", + CircuitBreaker: cbConfig, + }) + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + var upstreamCalls atomic.Int32 + var shouldFail atomic.Bool + shouldFail.Store(true) + + // Mock upstream that returns 429 or 200 based on shouldFail flag. + // x-should-retry: false is required to disable SDK automatic retries (default MaxRetries=2). 
+ mockUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + upstreamCalls.Add(1) + w.Header().Set("Content-Type", "application/json") + w.Header().Set("x-should-retry", "false") + if shouldFail.Load() { + w.WriteHeader(http.StatusTooManyRequests) + _, _ = w.Write([]byte(tc.errorBody)) + } else { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(tc.successBody)) + } + })) + defer mockUpstream.Close() + + metrics := aibridge.NewMetrics(prometheus.NewRegistry()) + + // Create provider with circuit breaker config + cbConfig := &aibridge.CircuitBreakerConfig{ + FailureThreshold: 2, + Interval: time.Minute, + Timeout: 50 * time.Millisecond, + MaxRequests: 1, + } + provider := tc.createProvider(mockUpstream.URL, cbConfig) + + ctx := t.Context() + tracer := otel.Tracer("forTesting") + logger := slogtest.Make(t, &slogtest.Options{}).Leveled(slog.LevelDebug) + bridge, err := aibridge.NewRequestBridge(ctx, + []aibridge.Provider{provider}, + &mockRecorderClient{}, + mcp.NewServerProxyManager(nil, tracer), + logger, + metrics, + tracer, + ) + require.NoError(t, err) + + mockSrv := httptest.NewUnstartedServer(bridge) + t.Cleanup(mockSrv.Close) + mockSrv.Config.BaseContext = func(_ net.Listener) context.Context { + return aibridge.AsActor(ctx, "test-user-id", nil) + } + mockSrv.Start() + + makeRequest := func() *http.Response { + req, _ := http.NewRequest("POST", mockSrv.URL+"/"+tc.providerName+tc.endpoint, strings.NewReader(tc.requestBody)) + req.Header.Set("Content-Type", "application/json") + tc.setupHeaders(req) + resp, err := http.DefaultClient.Do(req) + require.NoError(t, err) + _, _ = io.ReadAll(resp.Body) + resp.Body.Close() + return resp + } + + // Phase 1: Trip the circuit breaker + // First FailureThreshold requests hit upstream, get 429 + for i := uint32(0); i < cbConfig.FailureThreshold; i++ { + resp := makeRequest() + assert.Equal(t, http.StatusTooManyRequests, resp.StatusCode) + } + assert.Equal(t, 
int32(cbConfig.FailureThreshold), upstreamCalls.Load()) + + // Phase 2: Verify circuit is open + // Request should be blocked by circuit breaker (no upstream call) + resp := makeRequest() + assert.Equal(t, http.StatusServiceUnavailable, resp.StatusCode) + assert.Equal(t, int32(cbConfig.FailureThreshold), upstreamCalls.Load(), "No new upstream call when circuit is open") + + // Verify metrics show circuit is open + trips := promtest.ToFloat64(metrics.CircuitBreakerTrips.WithLabelValues(tc.providerName, tc.endpoint)) + assert.Equal(t, 1.0, trips, "CircuitBreakerTrips should be 1") + + state := promtest.ToFloat64(metrics.CircuitBreakerState.WithLabelValues(tc.providerName, tc.endpoint)) + assert.Equal(t, 1.0, state, "CircuitBreakerState should be 1 (open)") + + rejects := promtest.ToFloat64(metrics.CircuitBreakerRejects.WithLabelValues(tc.providerName, tc.endpoint)) + assert.Equal(t, 1.0, rejects, "CircuitBreakerRejects should be 1") + + // Phase 3: Wait for timeout to transition to half-open + time.Sleep(cbConfig.Timeout + 10*time.Millisecond) + + // Switch upstream to return success + shouldFail.Store(false) + + // Phase 4: Recovery - request in half-open state should succeed and close circuit + upstreamCallsBefore := upstreamCalls.Load() + resp = makeRequest() + assert.Equal(t, http.StatusOK, resp.StatusCode, "Request should succeed in half-open state") + assert.Equal(t, upstreamCallsBefore+1, upstreamCalls.Load(), "Request should reach upstream in half-open state") + + // Verify circuit is now closed + state = promtest.ToFloat64(metrics.CircuitBreakerState.WithLabelValues(tc.providerName, tc.endpoint)) + assert.Equal(t, 0.0, state, "CircuitBreakerState should be 0 (closed) after recovery") + + // Phase 5: Verify circuit is fully functional again + // Multiple requests should all succeed and reach upstream + for i := 0; i < 3; i++ { + resp = makeRequest() + assert.Equal(t, http.StatusOK, resp.StatusCode, "Request should succeed after circuit closes") + } + + // All 
requests should have reached upstream + assert.Equal(t, upstreamCallsBefore+4, upstreamCalls.Load(), "All requests should reach upstream after circuit closes") + + // Rejects count should not have increased + rejects = promtest.ToFloat64(metrics.CircuitBreakerRejects.WithLabelValues(tc.providerName, tc.endpoint)) + assert.Equal(t, 1.0, rejects, "CircuitBreakerRejects should still be 1 (no new rejects)") + }) + } +} + +func TestCircuitBreaker_HalfOpenFailure(t *testing.T) { + t.Parallel() + + var upstreamCalls atomic.Int32 + + // Mock upstream that always returns 429. + mockUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + upstreamCalls.Add(1) + w.Header().Set("Content-Type", "application/json") + w.Header().Set("x-should-retry", "false") + w.WriteHeader(http.StatusTooManyRequests) + _, _ = w.Write([]byte(`{"error":{"type":"rate_limit_error","message":"rate limited","code":"rate_limit_exceeded"}}`)) + })) + defer mockUpstream.Close() + + metrics := aibridge.NewMetrics(prometheus.NewRegistry()) + + cbConfig := &aibridge.CircuitBreakerConfig{ + FailureThreshold: 2, + Interval: time.Minute, + Timeout: 50 * time.Millisecond, + MaxRequests: 1, + } + provider := aibridge.NewOpenAIProvider(aibridge.OpenAIConfig{ + BaseURL: mockUpstream.URL, + Key: "test-key", + CircuitBreaker: cbConfig, + }) + + ctx := t.Context() + tracer := otel.Tracer("forTesting") + logger := slogtest.Make(t, &slogtest.Options{}).Leveled(slog.LevelDebug) + bridge, err := aibridge.NewRequestBridge(ctx, + []aibridge.Provider{provider}, + &mockRecorderClient{}, + mcp.NewServerProxyManager(nil, tracer), + logger, + metrics, + tracer, + ) + require.NoError(t, err) + + mockSrv := httptest.NewUnstartedServer(bridge) + t.Cleanup(mockSrv.Close) + mockSrv.Config.BaseContext = func(_ net.Listener) context.Context { + return aibridge.AsActor(ctx, "test-user-id", nil) + } + mockSrv.Start() + + makeRequest := func() *http.Response { + req, _ := http.NewRequest("POST", 
mockSrv.URL+"/openai/v1/chat/completions", + strings.NewReader(`{"model":"gpt-4o","messages":[{"role":"user","content":"hi"}]}`)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer test-key") + resp, err := http.DefaultClient.Do(req) + require.NoError(t, err) + _, _ = io.ReadAll(resp.Body) + resp.Body.Close() + return resp + } + + // Phase 1: Trip the circuit + for i := uint32(0); i < cbConfig.FailureThreshold; i++ { + resp := makeRequest() + assert.Equal(t, http.StatusTooManyRequests, resp.StatusCode) + } + + // Verify circuit is open + resp := makeRequest() + assert.Equal(t, http.StatusServiceUnavailable, resp.StatusCode) + + trips := promtest.ToFloat64(metrics.CircuitBreakerTrips.WithLabelValues("openai", "/v1/chat/completions")) + assert.Equal(t, 1.0, trips, "CircuitBreakerTrips should be 1") + + // Phase 2: Wait for half-open state + time.Sleep(cbConfig.Timeout + 10*time.Millisecond) + + // Phase 3: Request in half-open state fails, circuit should re-open + upstreamCallsBefore := upstreamCalls.Load() + resp = makeRequest() + assert.Equal(t, http.StatusTooManyRequests, resp.StatusCode, "Request should fail in half-open state") + assert.Equal(t, upstreamCallsBefore+1, upstreamCalls.Load(), "Request should reach upstream in half-open state") + + // Circuit should be open again - next request should be rejected immediately + resp = makeRequest() + assert.Equal(t, http.StatusServiceUnavailable, resp.StatusCode, "Circuit should be open again after half-open failure") + assert.Equal(t, upstreamCallsBefore+1, upstreamCalls.Load(), "Request should NOT reach upstream when circuit re-opens") + + // Verify metrics: trips should be 2 now (tripped twice) + trips = promtest.ToFloat64(metrics.CircuitBreakerTrips.WithLabelValues("openai", "/v1/chat/completions")) + assert.Equal(t, 2.0, trips, "CircuitBreakerTrips should be 2 after half-open failure") + + state := 
promtest.ToFloat64(metrics.CircuitBreakerState.WithLabelValues("openai", "/v1/chat/completions")) + assert.Equal(t, 1.0, state, "CircuitBreakerState should be 1 (open) after half-open failure") +} diff --git a/circuit_breaker_test.go b/circuit_breaker_test.go new file mode 100644 index 0000000..afa8cfb --- /dev/null +++ b/circuit_breaker_test.go @@ -0,0 +1,148 @@ +package aibridge + +import ( + "net/http" + "net/http/httptest" + "sync/atomic" + "testing" + "time" + + "github.com/sony/gobreaker/v2" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestCircuitBreakerMiddleware_PerEndpointIsolation(t *testing.T) { + t.Parallel() + + chatCalls := atomic.Int32{} + responsesCalls := atomic.Int32{} + + // Mock upstream - /chat returns 429, /responses returns 200 + upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/test/v1/chat/completions" { + chatCalls.Add(1) + w.WriteHeader(http.StatusTooManyRequests) + } else { + responsesCalls.Add(1) + w.WriteHeader(http.StatusOK) + } + }) + + cbs := NewProviderCircuitBreakers("test", &CircuitBreakerConfig{ + FailureThreshold: 1, + Interval: time.Minute, + Timeout: time.Minute, + MaxRequests: 1, + }, func(endpoint string, from, to gobreaker.State) {}) + + handler := CircuitBreakerMiddleware(cbs, nil)(upstream) + server := httptest.NewServer(handler) + defer server.Close() + + // Trip circuit on /chat/completions + resp, err := http.Get(server.URL + "/test/v1/chat/completions") + require.NoError(t, err) + resp.Body.Close() + + // /chat/completions should now be blocked + resp, err = http.Get(server.URL + "/test/v1/chat/completions") + require.NoError(t, err) + resp.Body.Close() + assert.Equal(t, http.StatusServiceUnavailable, resp.StatusCode) + assert.Equal(t, int32(1), chatCalls.Load()) // Only 1 call, second was blocked + + // /responses should still work + resp, err = http.Get(server.URL + "/test/v1/responses") + require.NoError(t, err) + 
resp.Body.Close() + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.Equal(t, int32(1), responsesCalls.Load()) +} + +func TestCircuitBreakerMiddleware_NotConfigured(t *testing.T) { + t.Parallel() + + var upstreamCalls atomic.Int32 + + upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + upstreamCalls.Add(1) + w.WriteHeader(http.StatusTooManyRequests) + }) + + // No circuit breaker configured (nil) + handler := CircuitBreakerMiddleware(nil, nil)(upstream) + server := httptest.NewServer(handler) + defer server.Close() + + // All requests should pass through even with 429s + for i := 0; i < 10; i++ { + resp, err := http.Get(server.URL + "/test/v1/messages") + require.NoError(t, err) + resp.Body.Close() + assert.Equal(t, http.StatusTooManyRequests, resp.StatusCode) + } + assert.Equal(t, int32(10), upstreamCalls.Load()) +} + +func TestCircuitBreakerMiddleware_CustomIsFailure(t *testing.T) { + t.Parallel() + + upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadGateway) // 502 + }) + + // Custom IsFailure that treats 502 as failure + cbs := NewProviderCircuitBreakers("test", &CircuitBreakerConfig{ + FailureThreshold: 1, + Interval: time.Minute, + Timeout: time.Minute, + MaxRequests: 1, + IsFailure: func(statusCode int) bool { + return statusCode == http.StatusBadGateway + }, + }, func(endpoint string, from, to gobreaker.State) {}) + + handler := CircuitBreakerMiddleware(cbs, nil)(upstream) + server := httptest.NewServer(handler) + defer server.Close() + + // First request returns 502, trips circuit + resp, _ := http.Get(server.URL + "/test/v1/messages") + resp.Body.Close() + + // Second request should be blocked + resp, _ = http.Get(server.URL + "/test/v1/messages") + assert.Equal(t, http.StatusServiceUnavailable, resp.StatusCode) + resp.Body.Close() +} + +func TestDefaultIsFailure(t *testing.T) { + t.Parallel() + + tests := []struct { + statusCode int + isFailure bool + }{ + 
{http.StatusOK, false}, + {http.StatusBadRequest, false}, + {http.StatusUnauthorized, false}, + {http.StatusTooManyRequests, true}, // 429 + {http.StatusInternalServerError, false}, + {http.StatusBadGateway, false}, + {http.StatusServiceUnavailable, true}, // 503 + {529, true}, // Anthropic Overloaded + } + + for _, tt := range tests { + assert.Equal(t, tt.isFailure, DefaultIsFailure(tt.statusCode), "status code %d", tt.statusCode) + } +} + +func TestStateToGaugeValue(t *testing.T) { + t.Parallel() + + assert.Equal(t, float64(0), stateToGaugeValue(gobreaker.StateClosed)) + assert.Equal(t, float64(0.5), stateToGaugeValue(gobreaker.StateHalfOpen)) + assert.Equal(t, float64(1), stateToGaugeValue(gobreaker.StateOpen)) +} diff --git a/config.go b/config.go index 8dc6f1d..ff1f639 100644 --- a/config.go +++ b/config.go @@ -1,7 +1,8 @@ package aibridge type ProviderConfig struct { - BaseURL, Key string + BaseURL, Key string + CircuitBreaker *CircuitBreakerConfig } type ( diff --git a/go.mod b/go.mod index 9a62089..329c7bf 100644 --- a/go.mod +++ b/go.mod @@ -9,6 +9,7 @@ require ( github.com/hashicorp/go-multierror v1.1.1 github.com/mark3labs/mcp-go v0.38.0 github.com/prometheus/client_golang v1.23.2 + github.com/sony/gobreaker/v2 v2.3.0 github.com/stretchr/testify v1.11.1 github.com/tidwall/gjson v1.18.0 github.com/tidwall/sjson v1.2.5 diff --git a/go.sum b/go.sum index 385345d..fff1ee3 100644 --- a/go.sum +++ b/go.sum @@ -110,6 +110,8 @@ github.com/rivo/uniseg v0.4.4 h1:8TfxU8dW6PdqD27gjM8MVNuicgxIjxpm4K7x4jp8sis= github.com/rivo/uniseg v0.4.4/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII= github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o= +github.com/sony/gobreaker/v2 v2.3.0 h1:7VYxZ69QXRQ2Q4eEawHn6eU4FiuwovzJwsUMA03Lu4I= +github.com/sony/gobreaker/v2 v2.3.0/go.mod h1:pTyFJgcZ3h2tdQVLZZruK2C0eoFL1fb/G83wK1ZQl+s= github.com/spf13/cast 
v1.7.1 h1:cuNEagBQEHWN1FnbGEjCXL2szYEXqfJPbP2HNUaca9Y= github.com/spf13/cast v1.7.1/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= diff --git a/interception.go b/interception.go index 46ec7bd..039cf31 100644 --- a/interception.go +++ b/interception.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "net/http" + "net/url" "strings" "time" @@ -93,7 +94,8 @@ func newInterceptionProcessor(p Provider, recorder Recorder, mcpProxy mcp.Server return } - route := strings.TrimPrefix(r.URL.Path, fmt.Sprintf("/%s", p.Name())) + prefix, _ := url.JoinPath("/", p.Name()) + route := strings.TrimPrefix(r.URL.Path, prefix) log := logger.With( slog.F("route", route), slog.F("provider", p.Name()), diff --git a/metrics.go b/metrics.go index 32d5a78..e7fbb5c 100644 --- a/metrics.go +++ b/metrics.go @@ -28,6 +28,11 @@ type Metrics struct { // Tool-related metrics. InjectedToolUseCount *prometheus.CounterVec NonInjectedToolUseCount *prometheus.CounterVec + + // Circuit breaker metrics. + CircuitBreakerState *prometheus.GaugeVec // Current state (0=closed, 0.5=half-open, 1=open) + CircuitBreakerTrips *prometheus.CounterVec // Total times circuit opened + CircuitBreakerRejects *prometheus.CounterVec // Requests rejected due to open circuit } // NewMetrics creates AND registers metrics. It will panic if a collector has already been registered. @@ -102,5 +107,26 @@ func NewMetrics(reg prometheus.Registerer) *Metrics { Name: "total", Help: "The number of times an AI model selected a tool to be invoked by the client.", }, append(baseLabels, "name")), + + // Circuit breaker metrics. + + // Pessimistic cardinality: 2 providers, 5 endpoints = up to 10. 
+ CircuitBreakerState: promauto.With(reg).NewGaugeVec(prometheus.GaugeOpts{ + Subsystem: "circuit_breaker", + Name: "state", + Help: "Current state of the circuit breaker (0=closed, 0.5=half-open, 1=open).", + }, []string{"provider", "endpoint"}), + // Pessimistic cardinality: 2 providers, 5 endpoints = up to 10. + CircuitBreakerTrips: promauto.With(reg).NewCounterVec(prometheus.CounterOpts{ + Subsystem: "circuit_breaker", + Name: "trips_total", + Help: "Total number of times the circuit breaker has tripped open.", + }, []string{"provider", "endpoint"}), + // Pessimistic cardinality: 2 providers, 5 endpoints = up to 10. + CircuitBreakerRejects: promauto.With(reg).NewCounterVec(prometheus.CounterOpts{ + Subsystem: "circuit_breaker", + Name: "rejects_total", + Help: "Total number of requests rejected due to open circuit breaker.", + }, []string{"provider", "endpoint"}), } } diff --git a/provider.go b/provider.go index 20f8f52..a2d1d87 100644 --- a/provider.go +++ b/provider.go @@ -33,4 +33,7 @@ type Provider interface { AuthHeader() string // InjectAuthHeader allows [Provider]s to set its authentication header. InjectAuthHeader(*http.Header) + + // CircuitBreakerConfig returns the circuit breaker configuration for the provider. 
+ CircuitBreakerConfig() *CircuitBreakerConfig } diff --git a/provider_anthropic.go b/provider_anthropic.go index fb5d10b..d07502e 100644 --- a/provider_anthropic.go +++ b/provider_anthropic.go @@ -108,6 +108,10 @@ func (p *AnthropicProvider) InjectAuthHeader(headers *http.Header) { headers.Set(p.AuthHeader(), p.cfg.Key) } +func (p *AnthropicProvider) CircuitBreakerConfig() *CircuitBreakerConfig { + return p.cfg.CircuitBreaker +} + func getAnthropicErrorResponse(err error) *AnthropicErrorResponse { var apierr *anthropic.Error if !errors.As(err, &apierr) { diff --git a/provider_openai.go b/provider_openai.go index 68777e7..65288f6 100644 --- a/provider_openai.go +++ b/provider_openai.go @@ -17,7 +17,8 @@ var _ Provider = &OpenAIProvider{} // OpenAIProvider allows for interactions with the OpenAI API. type OpenAIProvider struct { - baseURL, key string + baseURL, key string + circuitBreaker *CircuitBreakerConfig } const ( @@ -36,8 +37,9 @@ func NewOpenAIProvider(cfg OpenAIConfig) *OpenAIProvider { } return &OpenAIProvider{ - baseURL: cfg.BaseURL, - key: cfg.Key, + baseURL: cfg.BaseURL, + key: cfg.Key, + circuitBreaker: cfg.CircuitBreaker, } } @@ -108,3 +110,7 @@ func (p *OpenAIProvider) InjectAuthHeader(headers *http.Header) { headers.Set(p.AuthHeader(), "Bearer "+p.key) } + +func (p *OpenAIProvider) CircuitBreakerConfig() *CircuitBreakerConfig { + return p.circuitBreaker +}