Skip to content

Commit

Permalink
test: adds support for streaming in testupstream (#159)
Browse files Browse the repository at this point in the history
Signed-off-by: Takeshi Yoneda <[email protected]>
  • Loading branch information
mathetake authored Jan 22, 2025
1 parent 36d2d76 commit b1b23e2
Show file tree
Hide file tree
Showing 3 changed files with 181 additions and 112 deletions.
102 changes: 84 additions & 18 deletions tests/extproc/testupstream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,15 @@ package extproc

import (
"encoding/base64"
"fmt"
"io"
"net/http"
"os"
"strings"
"testing"
"time"

"github.com/google/go-cmp/cmp"
"github.com/stretchr/testify/require"

"github.com/envoyproxy/ai-gateway/filterconfig"
Expand Down Expand Up @@ -62,12 +64,17 @@ func TestWithTestUpstream(t *testing.T) {
requestBody,
// responseBody is the response body to return from the test upstream.
responseBody,
// responseType is either empty, "sse" or "aws-event-stream" as implemented by the test upstream.
responseType,
// expPath is the expected path to be sent to the test upstream.
expPath string
// expRequestBody is the expected body to be sent to the test upstream.
// This can be used to test the request body translation.
expRequestBody string
// expStatus is the expected status code from the gateway.
expStatus int
// expBody is the expected body from the gateway.
expBody string
// expResponseBody is the expected body from the gateway to the client.
expResponseBody string
}{
{
name: "unknown path",
Expand All @@ -80,25 +87,78 @@ func TestWithTestUpstream(t *testing.T) {
expStatus: http.StatusInternalServerError,
},
{
name: "aws - /v1/chat/completions",
backend: "aws-bedrock",
path: "/v1/chat/completions",
requestBody: `{"model":"something","messages":[{"role":"system","content":"You are a chatbot."}]}`,
expPath: "/model/something/converse",
responseBody: `{"output":{"message":{"content":[{"text":"response"},{"text":"from"},{"text":"assistant"}],"role":"assistant"}},"stopReason":null,"usage":{"inputTokens":10,"outputTokens":20,"totalTokens":30}}`,
expStatus: http.StatusOK,
expBody: `{"choices":[{"finish_reason":"stop","index":0,"logprobs":{},"message":{"content":"response","role":"assistant"}},{"finish_reason":"stop","index":1,"logprobs":{},"message":{"content":"from","role":"assistant"}},{"finish_reason":"stop","index":2,"logprobs":{},"message":{"content":"assistant","role":"assistant"}}],"object":"chat.completion","usage":{"completion_tokens":20,"prompt_tokens":10,"total_tokens":30}}`,
name: "aws - /v1/chat/completions",
backend: "aws-bedrock",
path: "/v1/chat/completions",
requestBody: `{"model":"something","messages":[{"role":"system","content":"You are a chatbot."}]}`,
expPath: "/model/something/converse",
responseBody: `{"output":{"message":{"content":[{"text":"response"},{"text":"from"},{"text":"assistant"}],"role":"assistant"}},"stopReason":null,"usage":{"inputTokens":10,"outputTokens":20,"totalTokens":30}}`,
expRequestBody: `{"inferenceConfig":{},"messages":[],"modelId":null,"system":[{"text":"You are a chatbot."}]}`,
expStatus: http.StatusOK,
expResponseBody: `{"choices":[{"finish_reason":"stop","index":0,"logprobs":{},"message":{"content":"response","role":"assistant"}},{"finish_reason":"stop","index":1,"logprobs":{},"message":{"content":"from","role":"assistant"}},{"finish_reason":"stop","index":2,"logprobs":{},"message":{"content":"assistant","role":"assistant"}}],"object":"chat.completion","usage":{"completion_tokens":20,"prompt_tokens":10,"total_tokens":30}}`,
},
{
name: "openai - /v1/chat/completions",
backend: "openai",
path: "/v1/chat/completions",
method: http.MethodPost,
requestBody: `{"model":"something","messages":[{"role":"system","content":"You are a chatbot."}]}`,
expPath: "/v1/chat/completions",
responseBody: `{"choices":[{"message":{"content":"This is a test."}}]}`,
expStatus: http.StatusOK,
expResponseBody: `{"choices":[{"message":{"content":"This is a test."}}]}`,
},
{
name: "openai - /v1/chat/completions",
name: "aws - /v1/chat/completions - streaming",
backend: "aws-bedrock",
path: "/v1/chat/completions",
responseType: "aws-event-stream",
method: http.MethodPost,
requestBody: `{"model":"something","messages":[{"role":"system","content":"You are a chatbot."}], "stream": true}`,
expRequestBody: `{"inferenceConfig":{},"messages":[],"modelId":null,"system":[{"text":"You are a chatbot."}]}`,
expPath: "/model/something/converse-stream",
responseBody: `{"role":"assistant"}
{"delta":{"text":"Don"}}
{"delta":{"text":"'t worry, I'm here to help. It"}}
{"delta":{"text":" seems like you're testing my ability to respond appropriately"}}
{"stopReason":"end_turn"}
{"usage":{"inputTokens":41, "outputTokens":36, "totalTokens":77}}
`,
expStatus: http.StatusOK,
expResponseBody: `data: {"choices":[{"delta":{"content":"","role":"assistant"}}],"object":"chat.completion.chunk"}
data: {"choices":[{"delta":{"content":"Don"}}],"object":"chat.completion.chunk"}
data: {"choices":[{"delta":{"content":"'t worry, I'm here to help. It"}}],"object":"chat.completion.chunk"}
data: {"choices":[{"delta":{"content":" seems like you're testing my ability to respond appropriately"}}],"object":"chat.completion.chunk"}
data: {"object":"chat.completion.chunk","usage":{"completion_tokens":36,"prompt_tokens":41,"total_tokens":77}}
data: [DONE]
`,
},
{
name: "openai - /v1/chat/completions - streaming",
backend: "openai",
path: "/v1/chat/completions",
responseType: "sse",
method: http.MethodPost,
requestBody: `{"model":"something","messages":[{"role":"system","content":"You are a chatbot."}]}`,
requestBody: `{"model":"something","messages":[{"role":"system","content":"You are a chatbot."}], "stream": true}`,
expPath: "/v1/chat/completions",
responseBody: `{"choices":[{"message":{"content":"This is a test."}}]}`,
expStatus: http.StatusOK,
expBody: `{"choices":[{"message":{"content":"This is a test."}}]}`,
responseBody: `
{"id":"chatcmpl-foo","object":"chat.completion.chunk","created":1731618222,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_0ba0d124f1","choices":[{"index":0,"delta":{"role":"assistant","content":"","refusal":null},"logprobs":null,"finish_reason":null}],"usage":null}
{"id":"chatcmpl-foo","object":"chat.completion.chunk","created":1731618222,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_0ba0d124f1","choices":[],"usage":{"prompt_tokens":13,"completion_tokens":12,"total_tokens":25,"prompt_tokens_details":{"cached_tokens":0,"audio_tokens":0},"completion_tokens_details":{"reasoning_tokens":0,"audio_tokens":0,"accepted_prediction_tokens":0,"rejected_prediction_tokens":0}}}
[DONE]
`,
expStatus: http.StatusOK,
expResponseBody: `data: {"id":"chatcmpl-foo","object":"chat.completion.chunk","created":1731618222,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_0ba0d124f1","choices":[{"index":0,"delta":{"role":"assistant","content":"","refusal":null},"logprobs":null,"finish_reason":null}],"usage":null}
data: {"id":"chatcmpl-foo","object":"chat.completion.chunk","created":1731618222,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_0ba0d124f1","choices":[],"usage":{"prompt_tokens":13,"completion_tokens":12,"total_tokens":25,"prompt_tokens_details":{"cached_tokens":0,"audio_tokens":0},"completion_tokens_details":{"reasoning_tokens":0,"audio_tokens":0,"accepted_prediction_tokens":0,"rejected_prediction_tokens":0}}}
data: [DONE]
`,
},
} {
t.Run(tc.name, func(t *testing.T) {
Expand All @@ -108,6 +168,12 @@ func TestWithTestUpstream(t *testing.T) {
req.Header.Set("x-test-backend", tc.backend)
req.Header.Set("x-response-body", base64.StdEncoding.EncodeToString([]byte(tc.responseBody)))
req.Header.Set("x-expected-path", base64.StdEncoding.EncodeToString([]byte(tc.expPath)))
if tc.responseType != "" {
req.Header.Set("x-response-type", tc.responseType)
}
if tc.expRequestBody != "" {
req.Header.Set("x-expected-request-body", base64.StdEncoding.EncodeToString([]byte(tc.expRequestBody)))
}

resp, err := http.DefaultClient.Do(req)
if err != nil {
Expand All @@ -120,11 +186,11 @@ func TestWithTestUpstream(t *testing.T) {
t.Logf("unexpected status code: %d", resp.StatusCode)
return false
}
if tc.expBody != "" {
if tc.expResponseBody != "" {
body, err := io.ReadAll(resp.Body)
require.NoError(t, err)
if string(body) != tc.expBody {
t.Logf("unexpected response:\ngot: %s\nexp: %s", body, tc.expBody)
if string(body) != tc.expResponseBody {
fmt.Printf("unexpected response:\n%s", cmp.Diff(string(body), tc.expResponseBody))
return false
}
}
Expand Down
182 changes: 94 additions & 88 deletions tests/testupstream/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,14 @@ import (
var logger = log.New(os.Stdout, "[testupstream] ", 0)

const (
// responseTypeKey is the key for the response type in the request.
// This can be either empty, "sse", or "aws-event-stream".
// * If this is "sse", the response body is expected to be a Server-Sent Event stream.
// Each line in x-response-body is treated as a separate [data] payload.
// * If this is "aws-event-stream", the response body is expected to be an AWS Event Stream.
// Each line in x-response-body is treated as a separate event payload.
// * If this is empty, the response body is expected to be a regular JSON response.
responseTypeKey = "x-response-type"
// expectedHeadersKey is the key for the expected headers in the request.
// The value is a base64 encoded string of comma separated key-value pairs.
// E.g. "key1:value1,key2:value2".
Expand Down Expand Up @@ -66,7 +74,7 @@ func main() {
doMain(l)
}

var streamingInterval = time.Second
var streamingInterval = 200 * time.Millisecond

func doMain(l net.Listener) {
if raw := os.Getenv("STREAMING_INTERVAL"); raw != "" {
Expand All @@ -77,52 +85,11 @@ func doMain(l net.Listener) {
defer l.Close()
http.HandleFunc("/health", func(writer http.ResponseWriter, request *http.Request) { writer.WriteHeader(http.StatusOK) })
http.HandleFunc("/", handler)
http.HandleFunc("/sse", sseHandler)
http.HandleFunc("/aws-event-stream", awsEventStreamHandler)
if err := http.Serve(l, nil); err != nil { // nolint: gosec
logger.Printf("failed to serve: %v", err)
}
}

func sseHandler(w http.ResponseWriter, r *http.Request) {
expResponseBody, err := base64.StdEncoding.DecodeString(r.Header.Get(responseBodyHeaderKey))
if err != nil {
logger.Println("failed to decode the response body")
http.Error(w, "failed to decode the response body", http.StatusBadRequest)
return
}

w.Header().Set("Content-Type", "text/event-stream")
w.Header().Set("Cache-Control", "no-cache")
w.Header().Set("Connection", "keep-alive")
w.Header().Set("testupstream-id", os.Getenv("TESTUPSTREAM_ID"))

for _, line := range bytes.Split(expResponseBody, []byte("\n")) {
line := string(line)
time.Sleep(streamingInterval)

if _, err = w.Write([]byte("event: some event in testupstream\n")); err != nil {
logger.Println("failed to write the response body")
return
}

if _, err = w.Write([]byte(fmt.Sprintf("data: %s\n\n", line))); err != nil {
logger.Println("failed to write the response body")
return
}

if f, ok := w.(http.Flusher); ok {
f.Flush()
} else {
panic("expected http.ResponseWriter to be an http.Flusher")
}
logger.Println("response line sent:", line)
}

logger.Println("response sent")
r.Context().Done()
}

func handler(w http.ResponseWriter, r *http.Request) {
for k, v := range r.Header {
logger.Printf("header %q: %s\n", k, v)
Expand Down Expand Up @@ -193,17 +160,19 @@ func handler(w http.ResponseWriter, r *http.Request) {
logger.Println("no expected testupstream-id")
}

expectedPath, err := base64.StdEncoding.DecodeString(r.Header.Get(expectedPathHeaderKey))
if err != nil {
logger.Println("failed to decode the expected path")
http.Error(w, "failed to decode the expected path", http.StatusBadRequest)
return
}
if expectedPath := r.Header.Get(expectedPathHeaderKey); expectedPath != "" {
expectedPath, err := base64.StdEncoding.DecodeString(expectedPath)
if err != nil {
logger.Println("failed to decode the expected path")
http.Error(w, "failed to decode the expected path", http.StatusBadRequest)
return
}

if r.URL.Path != string(expectedPath) {
logger.Printf("unexpected path: got %q, expected %q\n", r.URL.Path, string(expectedPath))
http.Error(w, "unexpected path: got "+r.URL.Path+", expected "+string(expectedPath), http.StatusBadRequest)
return
if r.URL.Path != string(expectedPath) {
logger.Printf("unexpected path: got %q, expected %q\n", r.URL.Path, string(expectedPath))
http.Error(w, "unexpected path: got "+r.URL.Path+", expected "+string(expectedPath), http.StatusBadRequest)
return
}
}

requestBody, err := io.ReadAll(r.Body)
Expand All @@ -213,8 +182,8 @@ func handler(w http.ResponseWriter, r *http.Request) {
return
}

if r.Header.Get(expectedRequestBodyHeaderKey) != "" {
expectedBody, err := base64.StdEncoding.DecodeString(r.Header.Get(expectedRequestBodyHeaderKey))
if expectedReqBody := r.Header.Get(expectedRequestBodyHeaderKey); expectedReqBody != "" {
expectedBody, err := base64.StdEncoding.DecodeString(expectedReqBody)
if err != nil {
logger.Println("failed to decode the expected request body")
http.Error(w, "failed to decode the expected request body", http.StatusBadRequest)
Expand Down Expand Up @@ -261,7 +230,6 @@ func handler(w http.ResponseWriter, r *http.Request) {
} else {
logger.Println("no response headers")
}
w.Header().Set("Content-Type", "application/json")
w.Header().Set("testupstream-id", os.Getenv("TESTUPSTREAM_ID"))
status := http.StatusOK
if v := r.Header.Get(responseStatusKey); v != "" {
Expand All @@ -272,46 +240,84 @@ func handler(w http.ResponseWriter, r *http.Request) {
return
}
}
w.WriteHeader(status)
if _, err := w.Write(responseBody); err != nil {
logger.Println("failed to write the response body")
}
logger.Println("response sent:", string(responseBody))
}

func awsEventStreamHandler(w http.ResponseWriter, r *http.Request) {
expResponseBody, err := base64.StdEncoding.DecodeString(r.Header.Get(responseBodyHeaderKey))
if err != nil {
logger.Println("failed to decode the response body")
http.Error(w, "failed to decode the response body", http.StatusBadRequest)
return
}
switch r.Header.Get(responseTypeKey) {
case "sse":
w.Header().Set("Content-Type", "text/event-stream")
w.WriteHeader(status)

w.Header().Set("Content-Type", "application/vnd.amazon.eventstream")
w.Header().Set("Transfer-Encoding", "chunked")
w.Header().Set("testupstream-id", os.Getenv("TESTUPSTREAM_ID"))
expResponseBody, err := base64.StdEncoding.DecodeString(r.Header.Get(responseBodyHeaderKey))
if err != nil {
logger.Println("failed to decode the response body")
http.Error(w, "failed to decode the response body", http.StatusBadRequest)
return
}

for _, line := range bytes.Split(expResponseBody, []byte("\n")) {
line := string(line)
if line == "" {
continue
}
time.Sleep(streamingInterval)

if _, err = w.Write([]byte(fmt.Sprintf("data: %s\n\n", line))); err != nil {
logger.Println("failed to write the response body")
return
}

if f, ok := w.(http.Flusher); ok {
f.Flush()
} else {
panic("expected http.ResponseWriter to be an http.Flusher")
}
logger.Println("response line sent:", line)
}
logger.Println("response sent")
r.Context().Done()
case "aws-event-stream":
// w.Header().Set("Transfer-Encoding", "chunked")
w.Header().Set("Content-Type", "application/vnd.amazon.eventstream")
w.WriteHeader(status)

expResponseBody, err := base64.StdEncoding.DecodeString(r.Header.Get(responseBodyHeaderKey))
if err != nil {
logger.Println("failed to decode the response body")
http.Error(w, "failed to decode the response body", http.StatusBadRequest)
return
}

e := eventstream.NewEncoder()
for _, line := range bytes.Split(expResponseBody, []byte("\n")) {
// Write each line as a chunk with AWS Event Stream format.
if len(line) == 0 {
continue
}
time.Sleep(streamingInterval)
if err := e.Encode(w, eventstream.Message{
Headers: eventstream.Headers{{Name: "event-type", Value: eventstream.StringValue("content")}},
Payload: line,
}); err != nil {
logger.Println("failed to encode the response body")
}
w.(http.Flusher).Flush()
logger.Println("response line sent:", string(line))
}

e := eventstream.NewEncoder()
for _, line := range bytes.Split(expResponseBody, []byte("\n")) {
// Write each line as a chunk with AWS Event Stream format.
time.Sleep(streamingInterval)
if err := e.Encode(w, eventstream.Message{
Headers: eventstream.Headers{{Name: "event-type", Value: eventstream.StringValue("content")}},
Payload: line,
Headers: eventstream.Headers{{Name: "event-type", Value: eventstream.StringValue("end")}},
Payload: []byte("this-is-end"),
}); err != nil {
logger.Println("failed to encode the response body")
}
w.(http.Flusher).Flush()
logger.Println("response line sent:", string(line))
}

if err := e.Encode(w, eventstream.Message{
Headers: eventstream.Headers{{Name: "event-type", Value: eventstream.StringValue("end")}},
Payload: []byte("this-is-end"),
}); err != nil {
logger.Println("failed to encode the response body")
logger.Println("response sent")
r.Context().Done()
default:
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(status)
if _, err := w.Write(responseBody); err != nil {
logger.Println("failed to write the response body")
}
logger.Println("response sent:", string(responseBody))
}

logger.Println("response sent")
r.Context().Done()
}
Loading

0 comments on commit b1b23e2

Please sign in to comment.