319 changes: 222 additions & 97 deletions bridge_integration_test.go

Large diffs are not rendered by default.

35 changes: 35 additions & 0 deletions fixtures/anthropic/non_stream_error.txtar
@@ -0,0 +1,35 @@
Simple request + an error that occurs before streaming begins (where applicable).

-- request --
{
"max_tokens": 8192,
"messages": [
{
"role": "user",
"content": [
{
"type": "text",
"text": "yo"
}
]
}
],
"model": "claude-sonnet-4-0",
"temperature": 1
}

-- streaming --
HTTP/2.0 400 Bad Request
Content-Length: 164
Content-Type: application/json

{"type":"error","error":{"type":"invalid_request_error","message":"prompt is too long: 205429 tokens > 200000 maximum"},"request_id":"req_011CV5Jab6gR3ZNs9Sj6apiD"}


-- non-streaming --
HTTP/2.0 400 Bad Request
Content-Length: 164
Content-Type: application/json

{"type":"error","error":{"type":"invalid_request_error","message":"prompt is too long: 205429 tokens > 200000 maximum"},"request_id":"req_011CV5Jab6gR3ZNs9Sj6apiD"}

12 changes: 2 additions & 10 deletions fixtures/anthropic/stream_error.txtar
@@ -15,7 +15,8 @@ Simple request + error.
}
],
"model": "claude-sonnet-4-0",
"temperature": 1
"temperature": 1,
"stream": true
}

-- streaming --
@@ -31,12 +32,3 @@ data: {"type": "ping"}
event: error
data: {"type": "error", "error": {"type": "api_error", "message": "Overloaded"}}

-- non-streaming --
{
"type": "error",
"error": {
"type": "api_error",
"message": "Overloaded"
},
"request_id": null
}
43 changes: 43 additions & 0 deletions fixtures/openai/non_stream_error.txtar
@@ -0,0 +1,43 @@
Simple request + an error that occurs before streaming begins (where applicable).

-- request --
{
"messages": [
{
"role": "user",
"content": "how many angels can dance on the head of a pin\n"
}
],
"model": "gpt-4.1",
"stream": true
}

-- streaming --
HTTP/2.0 400 Bad Request
Content-Length: 281
Content-Type: application/json

{
"error": {
"message": "Input tokens exceed the configured limit of 272000 tokens. Your messages resulted in 3148588 tokens. Please reduce the length of the messages.",
"type": "invalid_request_error",
"param": "messages",
"code": "context_length_exceeded"
}
}


-- non-streaming --
HTTP/2.0 400 Bad Request
Content-Length: 281
Content-Type: application/json

{
"error": {
"message": "Input tokens exceed the configured limit of 272000 tokens. Your messages resulted in 3148588 tokens. Please reduce the length of the messages.",
"type": "invalid_request_error",
"param": "messages",
"code": "context_length_exceeded"
}
}

10 changes: 2 additions & 8 deletions fixtures/openai/stream_error.txtar
@@ -8,7 +8,8 @@ Simple request + error.
"content": "how many angels can dance on the head of a pin\n"
}
],
"model": "gpt-4.1"
"model": "gpt-4.1",
"stream": true
}

-- streaming --
@@ -22,10 +23,3 @@ data: {"id":"chatcmpl-BwoiPTGRbKkY5rncfaM0s9KtWrq5N","object":"chat.completion.c

data: {"error": {"message": "The server had an error while processing your request. Sorry about that!", "type": "server_error"}}

-- non-streaming --
{
"error": {
"message": "The server had an error while processing your request. Sorry about that!",
"type": "server_error"
}
}
2 changes: 1 addition & 1 deletion go.mod
@@ -9,7 +9,7 @@ require (
github.com/hashicorp/go-multierror v1.1.1
github.com/mark3labs/mcp-go v0.38.0
github.com/stretchr/testify v1.10.0
github.com/tidwall/gjson v1.18.0 // indirect
github.com/tidwall/gjson v1.18.0
github.com/tidwall/sjson v1.2.5
go.uber.org/goleak v1.3.0
go.uber.org/mock v0.6.0
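go.mod now lists gjson as a direct dependency. This diff does not show the new call sites, so the following is only an illustrative sketch of what gjson is typically used for: path-based field extraction without struct definitions.

package main

import (
	"fmt"

	"github.com/tidwall/gjson"
)

func main() {
	payload := `{"type":"error","error":{"type":"overloaded_error","message":"Overloaded"}}`
	// gjson.Get extracts a value by dotted path without unmarshalling
	// the whole document into a struct.
	fmt.Println(gjson.Get(payload, "error.type").String()) // overloaded_error
}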
28 changes: 28 additions & 0 deletions intercept_anthropic_messages_base.go
@@ -2,6 +2,7 @@ package aibridge

import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/url"
@@ -173,6 +174,33 @@ func (i *AnthropicMessagesInterceptionBase) augmentRequestForBedrock() {
i.req.MessageNewParams.Model = anthropic.Model(i.Model())
}

// writeUpstreamError marshals the given upstream error and writes it to the response.
func (i *AnthropicMessagesInterceptionBase) writeUpstreamError(w http.ResponseWriter, antErr *AnthropicErrorResponse) {
if antErr == nil {
return
}

w.Header().Set("Content-Type", "application/json")
w.WriteHeader(antErr.StatusCode)

out, err := json.Marshal(antErr)
if err != nil {
i.logger.Warn(context.Background(), "failed to marshal upstream error", slog.Error(err), slog.F("error_payload", antErr))
// Response has to match expected format.
// See https://docs.claude.com/en/api/errors#error-shapes.
_, _ = w.Write([]byte(fmt.Sprintf(`{
"type":"error",
"error": {
"type": "error",
"message":"error marshaling upstream error"
},
"request_id": "%s"
}`, i.ID().String())))
} else {
_, _ = w.Write(out)
}
}

// redirectTransport is an HTTP RoundTripper that redirects requests to a different endpoint.
// This is useful for testing when we need to redirect AWS Bedrock requests to a mock server.
type redirectTransport struct {
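The writeUpstreamError helper above falls back to the error shape documented at https://docs.claude.com/en/api/errors#error-shapes, the same envelope seen in the fixture payloads earlier in this diff. A standalone sketch of that shape; the field names here are illustrative and not necessarily the repo's actual AnthropicErrorResponse:

package main

import (
	"encoding/json"
	"fmt"
)

// anthropicErrorEnvelope mirrors the documented Anthropic error shape.
type anthropicErrorEnvelope struct {
	Type  string `json:"type"` // always "error"
	Error struct {
		Type    string `json:"type"` // e.g. "invalid_request_error", "overloaded_error"
		Message string `json:"message"`
	} `json:"error"`
	RequestID string `json:"request_id"`
}

func main() {
	raw := `{"type":"error","error":{"type":"invalid_request_error","message":"prompt is too long"},"request_id":"req_123"}`
	var e anthropicErrorEnvelope
	if err := json.Unmarshal([]byte(raw), &e); err != nil {
		panic(err)
	}
	fmt.Printf("%s: %s\n", e.Error.Type, e.Error.Message) // invalid_request_error: prompt is too long
}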
10 changes: 4 additions & 6 deletions intercept_anthropic_messages_blocking.go
@@ -76,19 +76,17 @@ func (i *AnthropicMessagesBlockingInterception) ProcessRequest(w http.ResponseWr
resp, err = client.Messages.New(ctx, messages)
if err != nil {
if isConnError(err) {
logger.Warn(ctx, "upstream connection closed", slog.Error(err))
// Can't write a response, just error out.
return fmt.Errorf("upstream connection closed: %w", err)
}

logger.Warn(ctx, "anthropic API error", slog.Error(err))
if antErr := getAnthropicErrorResponse(err); antErr != nil {
http.Error(w, antErr.Error(), antErr.StatusCode)
return fmt.Errorf("api error: %w", err)
i.writeUpstreamError(w, antErr)
return fmt.Errorf("anthropic API error: %w", err)
}

logger.Warn(ctx, "upstream API error", slog.Error(err))
http.Error(w, "internal error", http.StatusInternalServerError)
return fmt.Errorf("upstream API error: %w", err)
return fmt.Errorf("internal error: %w", err)
}

if prompt != nil {
59 changes: 32 additions & 27 deletions intercept_anthropic_messages_streaming.go
@@ -97,7 +97,7 @@ func (i *AnthropicMessagesStreamingInterception) ProcessRequest(w http.ResponseW

// events will either terminate when shutdown after interaction with upstream completes, or when streamCtx is done.
events := newEventStream(streamCtx, logger.Named("sse-sender"), i.pingPayload())
go events.run(w, r)
go events.start(w, r)
defer func() {
_ = events.Shutdown(streamCtx) // Catch-all in case it doesn't get shutdown after stream completes.
}()
@@ -414,35 +414,40 @@ newStream:
prompt = nil
}

// Check if the stream encountered any errors.
if streamErr := stream.Err(); streamErr != nil {
if isUnrecoverableError(streamErr) {
logger.Debug(ctx, "stream terminated", slog.Error(streamErr))
// We can't reflect an error back if there's a connection error or the request context was canceled.
} else if antErr := getAnthropicErrorResponse(streamErr); antErr != nil {
logger.Warn(ctx, "anthropic stream error", slog.Error(streamErr))
interceptionErr = fmt.Errorf("stream error: %w", antErr)
} else {
logger.Warn(ctx, "unknown error", slog.Error(streamErr))
// Unfortunately, the Anthropic SDK does not support parsing errors received in the stream
// into known types (i.e. [shared.OverloadedError]).
// See https://github.com/anthropics/anthropic-sdk-go/blob/v1.12.0/packages/ssestream/ssestream.go#L172-L174
// All it does is wrap the payload in an error - which is all we can return, currently.
interceptionErr = newAnthropicErr(fmt.Errorf("unknown stream error: %w", streamErr))
if events.isStreaming() {
// Check if the stream encountered any errors.
if streamErr := stream.Err(); streamErr != nil {
if isUnrecoverableError(streamErr) {
logger.Debug(ctx, "stream terminated", slog.Error(streamErr))
// We can't reflect an error back if there's a connection error or the request context was canceled.
} else if antErr := getAnthropicErrorResponse(streamErr); antErr != nil {
logger.Warn(ctx, "anthropic stream error", slog.Error(streamErr))
interceptionErr = antErr
} else {
logger.Warn(ctx, "unknown error", slog.Error(streamErr))
// Unfortunately, the Anthropic SDK does not support parsing errors received in the stream
// into known types (e.g. [shared.OverloadedError]).
// See https://github.com/anthropics/anthropic-sdk-go/blob/v1.12.0/packages/ssestream/ssestream.go#L172-L174
// All it does is wrap the payload in an error - which is all we can return, currently.
interceptionErr = newAnthropicErr(fmt.Errorf("unknown stream error: %w", streamErr))
}
} else if lastErr != nil {
// Otherwise check if any logical errors occurred during processing.
logger.Warn(ctx, "stream failed", slog.Error(lastErr))
interceptionErr = newAnthropicErr(fmt.Errorf("processing error: %w", lastErr))
}
} else if lastErr != nil {
// Otherwise check if any logical errors occurred during processing.
logger.Warn(ctx, "stream failed", slog.Error(lastErr))
interceptionErr = newAnthropicErr(fmt.Errorf("processing error: %w", lastErr))
}

if interceptionErr != nil {
payload, err := i.marshal(interceptionErr)
if err != nil {
logger.Warn(ctx, "failed to marshal error", slog.Error(err), slog.F("error_payload", slog.F("%+v", interceptionErr)))
} else if err := events.Send(streamCtx, payload); err != nil {
logger.Warn(ctx, "failed to relay error", slog.Error(err), slog.F("payload", payload))
if interceptionErr != nil {
payload, err := i.marshal(interceptionErr)
if err != nil {
logger.Warn(ctx, "failed to marshal error", slog.Error(err), slog.F("error_payload", slog.F("%+v", interceptionErr)))
} else if err := events.Send(streamCtx, payload); err != nil {
logger.Warn(ctx, "failed to relay error", slog.Error(err), slog.F("payload", payload))
}
}
} else {
// Stream has not started yet; write the upstream error (if any) directly to the response.
i.writeUpstreamError(w, getAnthropicErrorResponse(stream.Err()))
}

shutdownCtx, shutdownCancel := context.WithTimeout(ctx, time.Second*30)
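The restructuring above hinges on one rule: after the first SSE bytes go out, an upstream error can only be relayed as an error event; before that, a plain JSON error response with a real status code is still possible. A condensed, self-contained sketch of that branching, using stand-ins for the diff's eventStream (isStreaming and Send are assumed to behave as shown):

package main

import (
	"context"
	"fmt"
	"net/http"
	"net/http/httptest"
)

// sseSender stands in for the diff's eventStream; only the two methods
// the error path needs are modeled.
type sseSender interface {
	isStreaming() bool
	Send(ctx context.Context, payload []byte) error
}

// relayError sketches the post-change control flow: SSE error event once
// streaming has begun, regular JSON error response otherwise.
func relayError(ctx context.Context, w http.ResponseWriter, events sseSender, payload []byte, status int) {
	if events.isStreaming() {
		_ = events.Send(ctx, payload) // framed upstream as "event: error\ndata: ..."
		return
	}
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(status)
	_, _ = w.Write(payload)
}

type fakeStream struct{ started bool }

func (f *fakeStream) isStreaming() bool                  { return f.started }
func (f *fakeStream) Send(context.Context, []byte) error { return nil }

func main() {
	rec := httptest.NewRecorder()
	relayError(context.Background(), rec, &fakeStream{started: false},
		[]byte(`{"type":"error","error":{"type":"overloaded_error","message":"Overloaded"}}`),
		http.StatusBadRequest)
	fmt.Println(rec.Code, rec.Body.String()) // 400 plus the JSON body
}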
35 changes: 35 additions & 0 deletions intercept_openai_chat_base.go
@@ -3,11 +3,13 @@ package aibridge
import (
"context"
"encoding/json"
"net/http"
"strings"

"github.com/coder/aibridge/mcp"
"github.com/google/uuid"
"github.com/openai/openai-go/v2"
"github.com/openai/openai-go/v2/option"
"github.com/openai/openai-go/v2/shared"

"cdr.dev/slog"
@@ -24,6 +26,14 @@ type OpenAIChatInterceptionBase struct {
mcpProxy mcp.ServerProxier
}

func (i *OpenAIChatInterceptionBase) newOpenAIClient(baseURL, key string) openai.Client {
var opts []option.RequestOption
opts = append(opts, option.WithAPIKey(key))
opts = append(opts, option.WithBaseURL(baseURL))

return openai.NewClient(opts...)
}

func (i *OpenAIChatInterceptionBase) ID() uuid.UUID {
return i.id
}
@@ -92,3 +102,28 @@ func (i *OpenAIChatInterceptionBase) unmarshalArgs(in string) (args ToolArgs) {

return args
}

// writeUpstreamError marshals the given upstream error and writes it to the response.
func (i *OpenAIChatInterceptionBase) writeUpstreamError(w http.ResponseWriter, oaiErr *OpenAIErrorResponse) {
if oaiErr == nil {
return
}

w.Header().Set("Content-Type", "application/json")
w.WriteHeader(oaiErr.StatusCode)

out, err := json.Marshal(oaiErr)
if err != nil {
i.logger.Warn(context.Background(), "failed to marshal upstream error", slog.Error(err), slog.F("error_payload", oaiErr))
// Response has to match expected format.
_, _ = w.Write([]byte(`{
"error": {
"type": "error",
"message":"error marshaling upstream error",
"code": "server_error"
},
}`))
} else {
_, _ = w.Write(out)
}
}
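For the OpenAI side, writeUpstreamError emits the envelope seen in the fixtures earlier in this diff. A sketch of that shape; field names are illustrative rather than the repo's actual OpenAIErrorResponse:

package main

import (
	"encoding/json"
	"fmt"
)

// openAIErrorEnvelope mirrors the error shape in the OpenAI fixtures.
type openAIErrorEnvelope struct {
	Error struct {
		Message string `json:"message"`
		Type    string `json:"type"`
		Param   string `json:"param,omitempty"`
		Code    string `json:"code,omitempty"`
	} `json:"error"`
}

func main() {
	raw := `{"error":{"message":"Input tokens exceed the configured limit.","type":"invalid_request_error","param":"messages","code":"context_length_exceeded"}}`
	var e openAIErrorEnvelope
	if err := json.Unmarshal([]byte(raw), &e); err != nil {
		panic(err)
	}
	fmt.Println(e.Error.Code, "-", e.Error.Message)
}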
12 changes: 4 additions & 8 deletions intercept_openai_chat_blocking.go
@@ -3,7 +3,6 @@ package aibridge
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"net/http"
"strings"
@@ -42,7 +41,7 @@ func (i *OpenAIBlockingChatInterception) ProcessRequest(w http.ResponseWriter, r
}

ctx := r.Context()
client := newOpenAIClient(i.baseURL, i.key)
client := i.newOpenAIClient(i.baseURL, i.key)
logger := i.logger.With(slog.F("model", i.req.Model))

var (
@@ -184,18 +183,15 @@ func (i *OpenAIBlockingChatInterception) ProcessRequest(w http.ResponseWriter, r
}
}

// TODO: these probably have to be formatted as JSON errs?
if err != nil {
if isConnError(err) {
http.Error(w, err.Error(), http.StatusInternalServerError)
return fmt.Errorf("upstream connection closed: %w", err)
}

logger.Warn(ctx, "openai API error", slog.Error(err))
var apierr *openai.Error
if errors.As(err, &apierr) {
http.Error(w, apierr.Message, apierr.StatusCode)
return fmt.Errorf("api error: %w", apierr)
if apiErr := getOpenAIErrorResponse(err); apiErr != nil {
i.writeUpstreamError(w, apiErr)
return fmt.Errorf("openai API error: %w", err)
}

http.Error(w, err.Error(), http.StatusInternalServerError)