Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

translator: handling for OpenAI and AWSBedrock error response #152

Merged
merged 10 commits into from
Jan 23, 2025
File renamed without changes.
5 changes: 5 additions & 0 deletions internal/apischema/awsbedrock/awsbedrock.go
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,11 @@ type ConverseStreamEventContentBlockDelta struct {
}

type BedrockException struct {
// Status code of the aws error response.
Code string `json:"code,omitempty"`
// Error type of the aws response.
Type string `json:"type,omitempty"`
// Error message of the aws response
Message string `json:"message"`
}

Expand Down
22 changes: 15 additions & 7 deletions internal/apischema/openai/openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -702,15 +702,23 @@ type ChatCompletionResponseChunkChoiceDelta struct {
// Error is described in the OpenAI API documentation
// https://platform.openai.com/docs/api-reference/realtime-server-events/error
type Error struct {
EventID *string `json:"event_id,omitempty"`
Type string `json:"type"`
Error ErrorType `json:"error"`
// The unique ID of the server event.
EventID *string `json:"event_id,omitempty"`
// The event type, must be error.
Type string `json:"type"`
// Details of the error.
Error ErrorType `json:"error"`
}

type ErrorType struct {
Type string `json:"type"`
Code *string `json:"code,omitempty"`
Message string `json:"message,omitempty"`
Param *string `json:"param,omitempty"`
// The type of error (e.g., "invalid_request_error", "server_error").
Type string `json:"type"`
// Error code, if any.
Code *string `json:"code,omitempty"`
// A human-readable error message.
Message string `json:"message,omitempty"`
// Parameter related to the error, if any.
Param *string `json:"param,omitempty"`
// The event_id of the client event that caused the error, if applicable.
EventID *string `json:"event_id,omitempty"`
}
12 changes: 11 additions & 1 deletion internal/extproc/mocks_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,18 @@ func (m mockTranslator) ResponseHeaders(headers map[string]string) (headerMutati
return m.retHeaderMutation, m.retErr
}

// ResponseError implements [translator.Translator.ResponseError].
func (m mockTranslator) ResponseError(_ map[string]string, body io.Reader) (headerMutation *extprocv3.HeaderMutation, bodyMutation *extprocv3.BodyMutation, err error) {
if m.expResponseBody != nil {
buf, err := io.ReadAll(body)
require.NoError(m.t, err)
require.Equal(m.t, m.expResponseBody.Body, buf)
}
return m.retHeaderMutation, m.retBodyMutation, m.retErr
}

// ResponseBody implements [translator.Translator.ResponseBody].
func (m mockTranslator) ResponseBody(body io.Reader, _ bool) (headerMutation *extprocv3.HeaderMutation, bodyMutation *extprocv3.BodyMutation, tokenUsage translator.LLMTokenUsage, err error) {
func (m mockTranslator) ResponseBody(respHeader map[string]string, body io.Reader, _ bool) (headerMutation *extprocv3.HeaderMutation, bodyMutation *extprocv3.BodyMutation, tokenUsage translator.LLMTokenUsage, err error) {
if m.expResponseBody != nil {
buf, err := io.ReadAll(body)
require.NoError(m.t, err)
Expand Down
16 changes: 10 additions & 6 deletions internal/extproc/processor.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ type Processor struct {
logger *slog.Logger
config *processorConfig
requestHeaders map[string]string
responseHeaders map[string]string
responseEncoding string
translator translator.Translator
// cost is the cost of the request that is accumulated during the processing of the response.
Expand Down Expand Up @@ -135,18 +136,20 @@ func (p *Processor) ProcessRequestBody(_ context.Context, rawBody *extprocv3.Htt

// ProcessResponseHeaders implements [Processor.ProcessResponseHeaders].
func (p *Processor) ProcessResponseHeaders(_ context.Context, headers *corev3.HeaderMap) (res *extprocv3.ProcessingResponse, err error) {
hs := headersToMap(headers)
if enc := hs["content-encoding"]; enc != "" {
p.responseHeaders = headersToMap(headers)
if enc := p.responseHeaders["content-encoding"]; enc != "" {
p.responseEncoding = enc
}
// The translator can be nil as there could be response event generated by previous ext proc without
// getting the request event.
if p.translator == nil {
return &extprocv3.ProcessingResponse{Response: &extprocv3.ProcessingResponse_ResponseHeaders{}}, nil
return &extprocv3.ProcessingResponse{Response: &extprocv3.ProcessingResponse_ResponseHeaders{
ResponseHeaders: &extprocv3.HeadersResponse{},
}}, nil
}
headerMutation, err := p.translator.ResponseHeaders(hs)
headerMutation, err := p.translator.ResponseHeaders(p.responseHeaders)
if err != nil {
return nil, fmt.Errorf("failed to transform response: %w", err)
return nil, fmt.Errorf("failed to transform response headers: %w", err)
}
return &extprocv3.ProcessingResponse{Response: &extprocv3.ProcessingResponse_ResponseHeaders{
ResponseHeaders: &extprocv3.HeadersResponse{
Expand All @@ -172,7 +175,8 @@ func (p *Processor) ProcessResponseBody(_ context.Context, body *extprocv3.HttpB
if p.translator == nil {
return &extprocv3.ProcessingResponse{Response: &extprocv3.ProcessingResponse_ResponseBody{}}, nil
}
headerMutation, bodyMutation, tokenUsage, err := p.translator.ResponseBody(br, body.EndOfStream)

headerMutation, bodyMutation, tokenUsage, err := p.translator.ResponseBody(p.responseHeaders, br, body.EndOfStream)
if err != nil {
return nil, fmt.Errorf("failed to transform response: %w", err)
}
Expand Down
17 changes: 17 additions & 0 deletions internal/extproc/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,23 @@ func TestServer_processMsg(t *testing.T) {
require.NotNil(t, resp)
require.Equal(t, expResponse, resp)
})
t.Run("error response headers", func(t *testing.T) {
s := requireNewServerWithMockProcessor(t)
p := s.newProcessor(nil, slog.Default())

hm := &corev3.HeaderMap{Headers: []*corev3.HeaderValue{{Key: ":status", Value: "504"}}}
expResponse := &extprocv3.ProcessingResponse{Response: &extprocv3.ProcessingResponse_ResponseHeaders{}}
p.t = t
p.expHeaderMap = hm
p.retProcessingResponse = expResponse
req := &extprocv3.ProcessingRequest{
Request: &extprocv3.ProcessingRequest_ResponseHeaders{ResponseHeaders: &extprocv3.HttpHeaders{Headers: hm}},
}
resp, err := s.processMsg(context.Background(), p, req)
require.NoError(t, err)
require.NotNil(t, resp)
require.Equal(t, expResponse, resp)
})
t.Run("response body", func(t *testing.T) {
s := requireNewServerWithMockProcessor(t)
p := s.newProcessor(nil, slog.Default())
Expand Down
74 changes: 64 additions & 10 deletions internal/extproc/translator/openai_awsbedrock.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"io"
"reflect"
"regexp"
"strconv"
"strings"

"github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream"
Expand Down Expand Up @@ -387,16 +388,14 @@ func (o *openAIToAWSBedrockTranslatorV1ChatCompletion) ResponseHeaders(headers m
) {
if o.stream {
contentType := headers["content-type"]
if contentType != "application/vnd.amazon.eventstream" {
return nil, fmt.Errorf("unexpected content-type for streaming: %s", contentType)
if contentType == "application/vnd.amazon.eventstream" {
// We need to change the content-type to text/event-stream for streaming responses.
return &extprocv3.HeaderMutation{
SetHeaders: []*corev3.HeaderValueOption{
{Header: &corev3.HeaderValue{Key: "content-type", Value: "text/event-stream"}},
},
}, nil
}

// We need to change the content-type to text/event-stream for streaming responses.
return &extprocv3.HeaderMutation{
SetHeaders: []*corev3.HeaderValueOption{
{Header: &corev3.HeaderValue{Key: "content-type", Value: "text/event-stream"}},
},
}, nil
}
return nil, nil
}
Expand Down Expand Up @@ -438,10 +437,65 @@ func (o *openAIToAWSBedrockTranslatorV1ChatCompletion) bedrockToolUseToOpenAICal
}
}

// ResponseError implements [Translator.ResponseError]
// Translate AWS Bedrock exceptions to OpenAI error type.
// The error type is stored in the "x-amzn-errortype" HTTP header for AWS error responses.
// If AWS Bedrock connection fails the error body is translated to OpenAI error type for events such as HTTP 503 or 504.
func (o *openAIToAWSBedrockTranslatorV1ChatCompletion) ResponseError(respHeaders map[string]string, body io.Reader) (
headerMutation *extprocv3.HeaderMutation, bodyMutation *extprocv3.BodyMutation, err error,
) {
statusCode := respHeaders[statusHeaderName]
var openaiError openai.Error
if v, ok := respHeaders[contentTypeHeaderName]; ok && v == jsonContentType {
var bedrockError awsbedrock.BedrockException
if err := json.NewDecoder(body).Decode(&bedrockError); err != nil {
return nil, nil, fmt.Errorf("failed to unmarshal error body: %w", err)
}
openaiError = openai.Error{
Type: "error",
Error: openai.ErrorType{
Type: respHeaders[awsErrorTypeHeaderName],
Message: bedrockError.Message,
Code: &statusCode,
},
}
} else {
buf, err := io.ReadAll(body)
if err != nil {
return nil, nil, fmt.Errorf("failed to read error body: %w", err)
}
openaiError = openai.Error{
Type: "error",
Error: openai.ErrorType{
Type: awsBedrockBackendError,
Message: string(buf),
Code: &statusCode,
},
}
}
mut := &extprocv3.BodyMutation_Body{}
if errBody, err := json.Marshal(openaiError); err != nil {
return nil, nil, fmt.Errorf("failed to marshal error body: %w", err)
} else {
mut.Body = errBody
}
headerMutation = &extprocv3.HeaderMutation{}
setContentLength(headerMutation, mut.Body)
return headerMutation, &extprocv3.BodyMutation{Mutation: mut}, nil
}

// ResponseBody implements [Translator.ResponseBody].
func (o *openAIToAWSBedrockTranslatorV1ChatCompletion) ResponseBody(body io.Reader, endOfStream bool) (
func (o *openAIToAWSBedrockTranslatorV1ChatCompletion) ResponseBody(respHeaders map[string]string, body io.Reader, endOfStream bool) (
headerMutation *extprocv3.HeaderMutation, bodyMutation *extprocv3.BodyMutation, tokenUsage LLMTokenUsage, err error,
) {
if v, ok := respHeaders[statusHeaderName]; ok {
if v, err := strconv.Atoi(v); err == nil {
if !isGoodStatusCode(v) {
headerMutation, bodyMutation, err = o.ResponseError(respHeaders, body)
return headerMutation, bodyMutation, LLMTokenUsage{}, err
}
}
}
mut := &extprocv3.BodyMutation_Body{}
if o.stream {
buf, err := io.ReadAll(body)
Expand Down
78 changes: 75 additions & 3 deletions internal/extproc/translator/openai_awsbedrock_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"encoding/base64"
"encoding/json"
"fmt"
"io"
"strconv"
"strings"
"testing"
Expand Down Expand Up @@ -734,7 +735,7 @@ func TestOpenAIToAWSBedrockTranslatorV1ChatCompletion_Streaming_ResponseBody(t *

var results []string
for i := 0; i < len(buf); i++ {
hm, bm, tokenUsage, err := o.ResponseBody(bytes.NewBuffer([]byte{buf[i]}), i == len(buf)-1)
hm, bm, tokenUsage, err := o.ResponseBody(nil, bytes.NewBuffer([]byte{buf[i]}), i == len(buf)-1)
require.NoError(t, err)
require.Nil(t, hm)
require.NotNil(t, bm)
Expand Down Expand Up @@ -769,10 +770,81 @@ data: [DONE]
})
}

func TestOpenAIToAWSBedrockTranslator_ResponseError(t *testing.T) {
tests := []struct {
name string
responseHeaders map[string]string
input io.Reader
output openai.Error
}{
{
name: "test unhealthy upstream",
responseHeaders: map[string]string{
":status": "503",
"content-type": "text/plain",
},
input: bytes.NewBuffer([]byte("service not available")),
output: openai.Error{
Type: "error",
Error: openai.ErrorType{
Type: awsBedrockBackendError,
Code: ptr.To("503"),
Message: "service not available",
},
},
},
{
name: "test AWS throttled error response",
responseHeaders: map[string]string{
":status": "429",
"content-type": "application/json",
awsErrorTypeHeaderName: "ThrottledException",
},
input: bytes.NewBuffer([]byte(`{"message": "aws bedrock rate limit exceeded"}`)),
output: openai.Error{
Type: "error",
Error: openai.ErrorType{
Type: "ThrottledException",
Code: ptr.To("429"),
Message: "aws bedrock rate limit exceeded",
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
body, err := json.Marshal(tt.input)
require.NoError(t, err)
fmt.Println(string(body))

o := &openAIToAWSBedrockTranslatorV1ChatCompletion{}
hm, bm, err := o.ResponseError(tt.responseHeaders, tt.input)
require.NoError(t, err)
require.NotNil(t, bm)
require.NotNil(t, bm.Mutation)
require.NotNil(t, bm.Mutation.(*extprocv3.BodyMutation_Body))
newBody := bm.Mutation.(*extprocv3.BodyMutation_Body).Body
require.NotNil(t, newBody)
require.NotNil(t, hm)
require.NotNil(t, hm.SetHeaders)
require.Len(t, hm.SetHeaders, 1)
require.Equal(t, "content-length", hm.SetHeaders[0].Header.Key)
require.Equal(t, strconv.Itoa(len(newBody)), string(hm.SetHeaders[0].Header.RawValue))

var openAIError openai.Error
err = json.Unmarshal(newBody, &openAIError)
require.NoError(t, err)
if !cmp.Equal(openAIError, tt.output) {
t.Errorf("ConvertAWSBedrockErrorResp(), diff(got, expected) = %s\n", cmp.Diff(openAIError, tt.output))
}
})
}
}

func TestOpenAIToAWSBedrockTranslatorV1ChatCompletion_ResponseBody(t *testing.T) {
t.Run("invalid body", func(t *testing.T) {
o := &openAIToAWSBedrockTranslatorV1ChatCompletion{}
_, _, _, err := o.ResponseBody(bytes.NewBuffer([]byte("invalid")), false)
_, _, _, err := o.ResponseBody(nil, bytes.NewBuffer([]byte("invalid")), false)
require.Error(t, err)
})
tests := []struct {
Expand Down Expand Up @@ -924,7 +996,7 @@ func TestOpenAIToAWSBedrockTranslatorV1ChatCompletion_ResponseBody(t *testing.T)
fmt.Println(string(body))

o := &openAIToAWSBedrockTranslatorV1ChatCompletion{}
hm, bm, usedToken, err := o.ResponseBody(bytes.NewBuffer(body), false)
hm, bm, usedToken, err := o.ResponseBody(nil, bytes.NewBuffer(body), false)
require.NoError(t, err)
require.NotNil(t, bm)
require.NotNil(t, bm.Mutation)
Expand Down
Loading
Loading