diff --git a/Makefile b/Makefile index 233bffb3..fbb4c4a3 100644 --- a/Makefile +++ b/Makefile @@ -113,6 +113,7 @@ test-cel: envtest apigen # This requires the extproc binary to be built as well as Envoy binary to be available in the PATH. .PHONY: test-extproc # This requires the extproc binary to be built. test-extproc: build.extproc + @$(MAKE) build.extproc_custom_router CMD_PATH_PREFIX=examples @$(MAKE) build.testupstream CMD_PATH_PREFIX=tests @echo "Run ExtProc test" @go test ./tests/extproc/... -tags test_extproc -v -count=1 @@ -140,6 +141,7 @@ test-e2e: kind # Example: # - `make build.controller`: will build the cmd/controller directory. # - `make build.extproc`: will build the cmd/extproc directory. +# - `make build.extproc_custom_router CMD_PATH_PREFIX=examples`: will build the examples/extproc_custom_router directory. # - `make build.testupstream CMD_PATH_PREFIX=tests`: will build the tests/testupstream directory. # # By default, this will build for the current GOOS and GOARCH. diff --git a/cmd/extproc/main.go b/cmd/extproc/main.go index b7b2b31d..35687fc9 100644 --- a/cmd/extproc/main.go +++ b/cmd/extproc/main.go @@ -1,78 +1,5 @@ package main -import ( - "context" - "flag" - "log" - "log/slog" - "net" - "os" - "os/signal" - "syscall" - "time" +import "github.com/envoyproxy/ai-gateway/cmd/extproc/mainlib" - extprocv3 "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3" - "google.golang.org/grpc" - "google.golang.org/grpc/health/grpc_health_v1" - - "github.com/envoyproxy/ai-gateway/internal/extproc" - "github.com/envoyproxy/ai-gateway/internal/version" -) - -var ( - configPath = flag.String("configPath", "", "path to the configuration file. "+ - "The file must be in YAML format specified in extprocconfig.Config type. The configuration file is watched for changes.") - // TODO: unix domain socket support. - extProcPort = flag.String("extProcPort", ":1063", "gRPC port for the external processor") - logLevel = flag.String("logLevel", "info", "log level") -) - -func main() { - flag.Parse() - - var level slog.Level - if err := level.UnmarshalText([]byte(*logLevel)); err != nil { - log.Fatalf("failed to unmarshal log level: %v", err) - } - l := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{ - Level: level, - })) - - l.Info("starting external processor", slog.String("version", version.Version)) - - if *configPath == "" { - log.Fatal("configPath must be provided") - } - - ctx, cancel := context.WithCancel(context.Background()) - signalsChan := make(chan os.Signal, 1) - signal.Notify(signalsChan, syscall.SIGINT, syscall.SIGTERM) - go func() { - <-signalsChan - cancel() - }() - - // TODO: unix domain socket support. - lis, err := net.Listen("tcp", *extProcPort) - if err != nil { - log.Fatalf("failed to listen: %v", err) - } - - server, err := extproc.NewServer[*extproc.Processor](l, extproc.NewProcessor) - if err != nil { - log.Fatalf("failed to create external processor server: %v", err) - } - - if err := extproc.StartConfigWatcher(ctx, *configPath, server, l, time.Second*5); err != nil { - log.Fatalf("failed to start config watcher: %v", err) - } - - s := grpc.NewServer() - extprocv3.RegisterExternalProcessorServer(s, server) - grpc_health_v1.RegisterHealthServer(s, server) - go func() { - <-ctx.Done() - s.GracefulStop() - }() - _ = s.Serve(lis) -} +func main() { mainlib.Main() } diff --git a/cmd/extproc/mainlib/main.go b/cmd/extproc/mainlib/main.go new file mode 100644 index 00000000..3c82b710 --- /dev/null +++ b/cmd/extproc/mainlib/main.go @@ -0,0 +1,80 @@ +package mainlib + +import ( + "context" + "flag" + "log" + "log/slog" + "net" + "os" + "os/signal" + "syscall" + "time" + + extprocv3 "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3" + "google.golang.org/grpc" + "google.golang.org/grpc/health/grpc_health_v1" + + "github.com/envoyproxy/ai-gateway/internal/extproc" + "github.com/envoyproxy/ai-gateway/internal/version" +) + +var ( + configPath = flag.String("configPath", "", "path to the configuration file. "+ + "The file must be in YAML format specified in extprocconfig.Config type. The configuration file is watched for changes.") + // TODO: unix domain socket support. + extProcPort = flag.String("extProcPort", ":1063", "gRPC port for the external processor") + logLevel = flag.String("logLevel", "info", "log level") +) + +// Main is a main function for the external processor exposed +// for allowing users to build their own external processor. +func Main() { + flag.Parse() + + var level slog.Level + if err := level.UnmarshalText([]byte(*logLevel)); err != nil { + log.Fatalf("failed to unmarshal log level: %v", err) + } + l := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{ + Level: level, + })) + + l.Info("starting external processor", slog.String("version", version.Version)) + + if *configPath == "" { + log.Fatal("configPath must be provided") + } + + ctx, cancel := context.WithCancel(context.Background()) + signalsChan := make(chan os.Signal, 1) + signal.Notify(signalsChan, syscall.SIGINT, syscall.SIGTERM) + go func() { + <-signalsChan + cancel() + }() + + // TODO: unix domain socket support. + lis, err := net.Listen("tcp", *extProcPort) + if err != nil { + log.Fatalf("failed to listen: %v", err) + } + + server, err := extproc.NewServer[*extproc.Processor](l, extproc.NewProcessor) + if err != nil { + log.Fatalf("failed to create external processor server: %v", err) + } + + if err := extproc.StartConfigWatcher(ctx, *configPath, server, l, time.Second*5); err != nil { + log.Fatalf("failed to start config watcher: %v", err) + } + + s := grpc.NewServer() + extprocv3.RegisterExternalProcessorServer(s, server) + grpc_health_v1.RegisterHealthServer(s, server) + go func() { + <-ctx.Done() + s.GracefulStop() + }() + _ = s.Serve(lis) +} diff --git a/examples/extproc_custom_router/README.md b/examples/extproc_custom_router/README.md new file mode 100644 index 00000000..af9d29cb --- /dev/null +++ b/examples/extproc_custom_router/README.md @@ -0,0 +1 @@ +This example shows how to insert a custom router in the custom external process using `filterconfig` package. diff --git a/examples/extproc_custom_router/main.go b/examples/extproc_custom_router/main.go new file mode 100644 index 00000000..49ad5120 --- /dev/null +++ b/examples/extproc_custom_router/main.go @@ -0,0 +1,41 @@ +package main + +import ( + "fmt" + + "github.com/envoyproxy/ai-gateway/cmd/extproc/mainlib" + "github.com/envoyproxy/ai-gateway/extprocapi" + "github.com/envoyproxy/ai-gateway/filterconfig" +) + +// newCustomRouter implements [extprocapi.NewCustomRouter]. +func newCustomRouter(defaultRouter extprocapi.Router, config *filterconfig.Config) extprocapi.Router { + // You can poke the current configuration of the routes, and the list of backends + // specified in the AIGatewayRoute.Rules, etc. + return &myCustomRouter{config: config, defaultRouter: defaultRouter} +} + +// myCustomRouter implements [extprocapi.Router]. +type myCustomRouter struct { + config *filterconfig.Config + defaultRouter extprocapi.Router +} + +// Calculate implements [extprocapi.Router.Calculate]. +func (m *myCustomRouter) Calculate(headers map[string]string) (backend *filterconfig.Backend, err error) { + // Simply logs the headers and delegates the calculation to the default router. + modelName, ok := headers[m.config.ModelNameHeaderKey] + if !ok { + panic("model name not found in the headers") + } + fmt.Printf("model name: %s\n", modelName) + return m.defaultRouter.Calculate(headers) +} + +// This demonstrates how to build a custom router for the external processor. +func main() { + // Initializes the custom router. + extprocapi.NewCustomRouter = newCustomRouter + // Executes the main function of the external processor. + mainlib.Main() +} diff --git a/extprocapi/exptorcapi.go b/extprocapi/exptorcapi.go new file mode 100644 index 00000000..7efbfa22 --- /dev/null +++ b/extprocapi/exptorcapi.go @@ -0,0 +1,29 @@ +// Package extprocapi is for building a custom external process. +package extprocapi + +import "github.com/envoyproxy/ai-gateway/filterconfig" + +// NewCustomRouter is the function to create a custom router over the default router. +// This is nil by default and can be set by the custom build of external processor. +var NewCustomRouter NewCustomRouterFn + +// NewCustomRouterFn is the function signature for [NewCustomRouter]. +// +// It accepts the exptproc config passed to the AI Gateway filter and returns a [Router]. +// This is called when the new configuration is loaded. +// +// The defaultRouter can be used to delegate the calculation to the default router implementation. +type NewCustomRouterFn func(defaultRouter Router, config *filterconfig.Config) Router + +// Router is the interface for the router. +// +// Router must be goroutine-safe as it is shared across multiple requests. +type Router interface { + // Calculate determines the backend to route to based on the request headers. + // + // The request headers include the populated [filterconfig.Config.ModelNameHeaderKey] + // with the parsed model name based on the [filterconfig.Config] given to the NewCustomRouterFn. + // + // Returns the backend. + Calculate(requestHeaders map[string]string) (backend *filterconfig.Backend, err error) +} diff --git a/internal/extproc/mocks_test.go b/internal/extproc/mocks_test.go index cf29d53b..7c3666f7 100644 --- a/internal/extproc/mocks_test.go +++ b/internal/extproc/mocks_test.go @@ -11,6 +11,7 @@ import ( "github.com/stretchr/testify/require" "google.golang.org/grpc/metadata" + "github.com/envoyproxy/ai-gateway/extprocapi" "github.com/envoyproxy/ai-gateway/filterconfig" "github.com/envoyproxy/ai-gateway/internal/extproc/router" "github.com/envoyproxy/ai-gateway/internal/extproc/translator" @@ -19,7 +20,7 @@ import ( var ( _ ProcessorIface = &mockProcessor{} _ translator.Translator = &mockTranslator{} - _ router.Router = &mockRouter{} + _ extprocapi.Router = &mockRouter{} ) func newMockProcessor(_ *processorConfig) *mockProcessor { diff --git a/internal/extproc/processor.go b/internal/extproc/processor.go index ed0dfe15..fdb0c564 100644 --- a/internal/extproc/processor.go +++ b/internal/extproc/processor.go @@ -12,6 +12,7 @@ import ( extprocv3 "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3" "google.golang.org/protobuf/types/known/structpb" + "github.com/envoyproxy/ai-gateway/extprocapi" "github.com/envoyproxy/ai-gateway/filterconfig" "github.com/envoyproxy/ai-gateway/internal/extproc/backendauth" "github.com/envoyproxy/ai-gateway/internal/extproc/router" @@ -22,7 +23,7 @@ import ( // This will be created by the server and passed to the processor when it detects a new configuration. type processorConfig struct { bodyParser router.RequestBodyParser - router router.Router + router extprocapi.Router ModelNameHeaderKey, selectedBackendHeaderKey string factories map[filterconfig.VersionedAPISchema]translator.Factory backendAuthHandlers map[string]backendauth.Handler diff --git a/internal/extproc/router/router.go b/internal/extproc/router/router.go index 43b3f62c..3cd27d93 100644 --- a/internal/extproc/router/router.go +++ b/internal/extproc/router/router.go @@ -6,28 +6,27 @@ import ( "golang.org/x/exp/rand" + "github.com/envoyproxy/ai-gateway/extprocapi" "github.com/envoyproxy/ai-gateway/filterconfig" ) -// Router is the interface for the router. -type Router interface { - // Calculate determines the backend to route to based on the headers. - // Returns the backend name and the output schema. - Calculate(headers map[string]string) (backend *filterconfig.Backend, err error) -} - -// router implements [Router]. +// router implements [extprocapi.Router]. type router struct { rules []filterconfig.RouteRule rng *rand.Rand } -// NewRouter creates a new [Router] implementation for the given config. -func NewRouter(config *filterconfig.Config) (Router, error) { - return &router{rules: config.Rules, rng: rand.New(rand.NewSource(uint64(time.Now().UnixNano())))}, nil +// NewRouter creates a new [extprocapi.Router] implementation for the given config. +func NewRouter(config *filterconfig.Config, newCustomFn extprocapi.NewCustomRouterFn) (extprocapi.Router, error) { + r := &router{rules: config.Rules, rng: rand.New(rand.NewSource(uint64(time.Now().UnixNano())))} + if newCustomFn != nil { + customRouter := newCustomFn(r, config) + return customRouter, nil + } + return r, nil } -// Calculate implements [Router.Calculate]. +// Calculate implements [extprocapi.Router.Calculate]. func (r *router) Calculate(headers map[string]string) (backend *filterconfig.Backend, err error) { var rule *filterconfig.RouteRule for i := range r.rules { diff --git a/internal/extproc/router/router_test.go b/internal/extproc/router/router_test.go index 98109533..e8d3bac8 100644 --- a/internal/extproc/router/router_test.go +++ b/internal/extproc/router/router_test.go @@ -5,9 +5,34 @@ import ( "github.com/stretchr/testify/require" + "github.com/envoyproxy/ai-gateway/extprocapi" "github.com/envoyproxy/ai-gateway/filterconfig" ) +// dummyCustomRouter implements [extprocapi.Router]. +type dummyCustomRouter struct{ called bool } + +func (c *dummyCustomRouter) Calculate(map[string]string) (*filterconfig.Backend, error) { + c.called = true + return nil, nil +} + +func TestRouter_NewRouter_Custom(t *testing.T) { + r, err := NewRouter(&filterconfig.Config{}, func(defaultRouter extprocapi.Router, config *filterconfig.Config) extprocapi.Router { + require.NotNil(t, defaultRouter) + _, ok := defaultRouter.(*router) + require.True(t, ok) // Checking if the default router is correctly passed. + return &dummyCustomRouter{} + }) + require.NoError(t, err) + _, ok := r.(*dummyCustomRouter) + require.True(t, ok) + + _, err = r.Calculate(nil) + require.NoError(t, err) + require.True(t, r.(*dummyCustomRouter).called) +} + func TestRouter_Calculate(t *testing.T) { outSchema := filterconfig.VersionedAPISchema{Schema: filterconfig.APISchemaOpenAI} _r, err := NewRouter(&filterconfig.Config{ @@ -30,7 +55,7 @@ func TestRouter_Calculate(t *testing.T) { }, }, }, - }) + }, nil) require.NoError(t, err) r, ok := _r.(*router) require.True(t, ok) @@ -62,7 +87,7 @@ func TestRouter_Calculate(t *testing.T) { } func TestRouter_selectBackendFromRule(t *testing.T) { - _r, err := NewRouter(&filterconfig.Config{}) + _r, err := NewRouter(&filterconfig.Config{}, nil) require.NoError(t, err) r, ok := _r.(*router) require.True(t, ok) diff --git a/internal/extproc/server.go b/internal/extproc/server.go index 810d52cd..91874fc4 100644 --- a/internal/extproc/server.go +++ b/internal/extproc/server.go @@ -12,6 +12,7 @@ import ( "google.golang.org/grpc/health/grpc_health_v1" "google.golang.org/grpc/status" + "github.com/envoyproxy/ai-gateway/extprocapi" "github.com/envoyproxy/ai-gateway/filterconfig" "github.com/envoyproxy/ai-gateway/internal/extproc/backendauth" "github.com/envoyproxy/ai-gateway/internal/extproc/router" @@ -37,7 +38,7 @@ func (s *Server[P]) LoadConfig(config *filterconfig.Config) error { if err != nil { return fmt.Errorf("cannot create request body parser: %w", err) } - rt, err := router.NewRouter(config) + rt, err := router.NewRouter(config, extprocapi.NewCustomRouter) if err != nil { return fmt.Errorf("cannot create router: %w", err) } diff --git a/tests/extproc/custom_extproc_test.go b/tests/extproc/custom_extproc_test.go new file mode 100644 index 00000000..dc187069 --- /dev/null +++ b/tests/extproc/custom_extproc_test.go @@ -0,0 +1,67 @@ +//go:build test_extproc + +package extproc + +import ( + "bytes" + "context" + "encoding/base64" + "fmt" + "runtime" + "testing" + "time" + + "github.com/openai/openai-go" + "github.com/openai/openai-go/option" + "github.com/stretchr/testify/require" + + "github.com/envoyproxy/ai-gateway/filterconfig" +) + +// TestExtProcCustomRouter tests examples/extproc_custom_router. +func TestExtProcCustomRouter(t *testing.T) { + requireBinaries(t) + requireRunEnvoy(t, "/dev/null", "dummy") + requireTestUpstream(t) + configPath := t.TempDir() + "/extproc-config.yaml" + requireWriteExtProcConfig(t, configPath, &filterconfig.Config{ + InputSchema: openAISchema, + // This can be any header key, but it must match the envoy.yaml routing configuration. + SelectedBackendHeaderKey: "x-selected-backend-name", + ModelNameHeaderKey: "x-model-name", + Rules: []filterconfig.RouteRule{ + { + Backends: []filterconfig.Backend{{Name: "testupstream", OutputSchema: openAISchema}}, + Headers: []filterconfig.HeaderMatch{{Name: "x-model-name", Value: "something-cool"}}, + }, + }, + }) + stdout := &bytes.Buffer{} + requireExtProc(t, stdout, fmt.Sprintf("../../out/extproc_custom_router-%s-%s", + runtime.GOOS, runtime.GOARCH), configPath) + + require.Eventually(t, func() bool { + client := openai.NewClient(option.WithBaseURL(listenerAddress+"/v1/"), + option.WithHeader( + "x-expected-path", base64.StdEncoding.EncodeToString([]byte("/v1/chat/completions"))), + option.WithHeader("x-response-body", + base64.StdEncoding.EncodeToString([]byte(`{"choices":[{"message":{"content":"This is a test."}}]}`)), + )) + chatCompletion, err := client.Chat.Completions.New(context.Background(), openai.ChatCompletionNewParams{ + Messages: openai.F([]openai.ChatCompletionMessageParamUnion{ + openai.UserMessage("Say this is a test"), + }), + Model: openai.F("something-cool"), + }) + if err != nil { + t.Logf("error: %v", err) + return false + } + for _, choice := range chatCompletion.Choices { + t.Logf("choice: %s", choice.Message.Content) + } + return true + }, 10*time.Second, 1*time.Second) + + require.Contains(t, stdout.String(), "model name: something-cool") // This must be logged by the custom router. +} diff --git a/tests/extproc/envoy.yaml b/tests/extproc/envoy.yaml index b6f76b15..83ed68e7 100644 --- a/tests/extproc/envoy.yaml +++ b/tests/extproc/envoy.yaml @@ -46,6 +46,14 @@ static_resources: - header: key: 'Authorization' value: 'Bearer TEST_OPENAI_API_KEY' + - match: + prefix: "/" + headers: + - name: x-selected-backend-name + string_match: + exact: testupstream + route: + cluster: testupstream http_filters: - name: envoy.filters.http.ext_proc typed_config: @@ -71,6 +79,19 @@ static_resources: suppressEnvoyHeaders: true clusters: + - name: testupstream + connect_timeout: 0.25s + type: STATIC + lb_policy: ROUND_ROBIN + load_assignment: + cluster_name: testupstream + endpoints: + - lb_endpoints: + - endpoint: + address: + socket_address: + address: 127.0.0.1 + port_value: 8080 - name: extproc_cluster connect_timeout: 0.25s type: STATIC diff --git a/tests/extproc/extproc_test.go b/tests/extproc/extproc_test.go index 90caa038..b6a0efa3 100644 --- a/tests/extproc/extproc_test.go +++ b/tests/extproc/extproc_test.go @@ -9,6 +9,7 @@ import ( _ "embed" "encoding/json" "fmt" + "io" "os" "os/exec" "runtime" @@ -25,6 +26,8 @@ import ( "github.com/envoyproxy/ai-gateway/filterconfig" ) +const listenerAddress = "http://localhost:1062" + //go:embed envoy.yaml var envoyYamlBase string @@ -44,7 +47,8 @@ var ( func TestE2E(t *testing.T) { requireBinaries(t) accessLogPath := t.TempDir() + "/access.log" - requireRunEnvoy(t, accessLogPath) + openAIAPIKey := getEnvVarOrSkip(t, "TEST_OPENAI_API_KEY") + requireRunEnvoy(t, accessLogPath, openAIAPIKey) configPath := t.TempDir() + "/extproc-config.yaml" requireWriteExtProcConfig(t, configPath, &filterconfig.Config{ TokenUsageMetadata: &filterconfig.TokenUsageMetadata{ @@ -68,10 +72,10 @@ func TestE2E(t *testing.T) { }, }, }) - requireExtProc(t, configPath) + requireExtProcWithAWSCredentials(t, configPath) t.Run("health-checking", func(t *testing.T) { - client := openai.NewClient(option.WithBaseURL("http://localhost:1062/v1/")) + client := openai.NewClient(option.WithBaseURL(listenerAddress + "/v1/")) for _, tc := range []struct { testCaseName, modelName string @@ -139,28 +143,45 @@ func TestE2E(t *testing.T) { // TODO: add more tests like updating the config, signal handling, etc. } -// requireExtProc starts the external processor with the provided configPath. +// requireExtProcWithAWSCredentials starts the external processor with the provided executable and configPath +// with additional environment variables for AWS credentials. +// // The config must be in YAML format specified in [filterconfig.Config] type. -func requireExtProc(t *testing.T, configPath string) { +func requireExtProcWithAWSCredentials(t *testing.T, configPath string) { awsAccessKeyID := getEnvVarOrSkip(t, "TEST_AWS_ACCESS_KEY_ID") awsSecretAccessKey := getEnvVarOrSkip(t, "TEST_AWS_SECRET_ACCESS_KEY") - - cmd := exec.Command(extProcBinaryPath()) // #nosec G204 - cmd.Stdout = os.Stdout - cmd.Stderr = os.Stderr - cmd.Args = append(cmd.Args, "-configPath", configPath) - cmd.Env = append(os.Environ(), + requireExtProc(t, os.Stdout, extProcExecutablePath(), configPath, fmt.Sprintf("AWS_ACCESS_KEY_ID=%s", awsAccessKeyID), fmt.Sprintf("AWS_SECRET_ACCESS_KEY=%s", awsSecretAccessKey), ) +} + +// requireExtProc starts the external processor with the provided executable and configPath +// with additional environment variables. +// +// The config must be in YAML format specified in [filterconfig.Config] type. +func requireExtProc(t *testing.T, stdout io.Writer, executable, configPath string, envs ...string) { + cmd := exec.Command(executable) + cmd.Stdout = stdout + cmd.Stderr = os.Stderr + cmd.Args = append(cmd.Args, "-configPath", configPath) + cmd.Env = append(os.Environ(), envs...) require.NoError(t, cmd.Start()) t.Cleanup(func() { _ = cmd.Process.Signal(os.Interrupt) }) } -// requireRunEnvoy starts the Envoy proxy with the provided configuration. -func requireRunEnvoy(t *testing.T, accessLogPath string) { - openAIAPIKey := getEnvVarOrSkip(t, "TEST_OPENAI_API_KEY") +func requireTestUpstream(t *testing.T) { + // Starts the Envoy proxy. + envoyCmd := exec.Command(testUpstreamExecutablePath()) // #nosec G204 + envoyCmd.Stdout = os.Stdout + envoyCmd.Stderr = os.Stderr + envoyCmd.Env = []string{"TESTUPSTREAM_ID=extproc_test"} + require.NoError(t, envoyCmd.Start()) + t.Cleanup(func() { _ = envoyCmd.Process.Signal(os.Interrupt) }) +} +// requireRunEnvoy starts the Envoy proxy with the provided configuration. +func requireRunEnvoy(t *testing.T, accessLogPath string, openAIAPIKey string) { tmpDir := t.TempDir() envoyYaml := strings.Replace(envoyYamlBase, "TEST_OPENAI_API_KEY", openAIAPIKey, 1) envoyYaml = strings.Replace(envoyYaml, "ACCESS_LOG_PATH", accessLogPath, 1) @@ -189,9 +210,15 @@ func requireBinaries(t *testing.T) { } // Check if the Extproc binary is present in the root of the repository - _, err = os.Stat(extProcBinaryPath()) + _, err = os.Stat(extProcExecutablePath()) + if err != nil { + t.Fatalf("%s binary not found in the root of the repository", extProcExecutablePath()) + } + + // Check if the TestUpstream binary is present in the root of the repository + _, err = os.Stat(testUpstreamExecutablePath()) if err != nil { - t.Fatalf("%s binary not found in the root of the repository", extProcBinaryPath()) + t.Fatalf("%s binary not found in the root of the repository", testUpstreamExecutablePath()) } } @@ -211,6 +238,10 @@ func requireWriteExtProcConfig(t *testing.T, configPath string, config *filterco require.NoError(t, os.WriteFile(configPath, configBytes, 0o600)) } -func extProcBinaryPath() string { +func extProcExecutablePath() string { return fmt.Sprintf("../../out/extproc-%s-%s", runtime.GOOS, runtime.GOARCH) } + +func testUpstreamExecutablePath() string { + return fmt.Sprintf("../../out/testupstream-%s-%s", runtime.GOOS, runtime.GOARCH) +} diff --git a/tests/testupstream/main.go b/tests/testupstream/main.go index 4ef0bd04..d5545315 100644 --- a/tests/testupstream/main.go +++ b/tests/testupstream/main.go @@ -180,23 +180,28 @@ func handler(w http.ResponseWriter, r *http.Request) { return } - expectedBody, err := base64.StdEncoding.DecodeString(r.Header.Get(expectedRequestBodyHeaderKey)) - if err != nil { - fmt.Println("failed to decode the expected request body") - http.Error(w, "failed to decode the expected request body", http.StatusBadRequest) - return - } - actual, err := io.ReadAll(r.Body) + requestBody, err := io.ReadAll(r.Body) if err != nil { fmt.Println("failed to read the request body") http.Error(w, "failed to read the request body", http.StatusInternalServerError) return } - if string(expectedBody) != string(actual) { - fmt.Println("unexpected request body: got", string(actual), "expected", string(expectedBody)) - http.Error(w, "unexpected request body: got "+string(actual)+", expected "+string(expectedBody), http.StatusBadRequest) - return + if r.Header.Get(expectedRequestBodyHeaderKey) != "" { + expectedBody, err := base64.StdEncoding.DecodeString(r.Header.Get(expectedRequestBodyHeaderKey)) + if err != nil { + fmt.Println("failed to decode the expected request body") + http.Error(w, "failed to decode the expected request body", http.StatusBadRequest) + return + } + + if string(expectedBody) != string(requestBody) { + fmt.Println("unexpected request body: got", string(requestBody), "expected", string(expectedBody)) + http.Error(w, "unexpected request body: got "+string(requestBody)+", expected "+string(expectedBody), http.StatusBadRequest) + return + } + } else { + fmt.Println("no expected request body") } responseBody, err := base64.StdEncoding.DecodeString(r.Header.Get(responseBodyHeaderKey))