Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: sets the host automatically to Backend #155

Merged
merged 12 commits into from
Jan 22, 2025
57 changes: 50 additions & 7 deletions internal/controller/sink.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@ import (
"context"
"fmt"

egv1a1 "github.com/envoyproxy/gateway/api/v1alpha1"
"github.com/go-logr/logr"
appsv1 "k8s.io/api/apps/v1"
corev1 "k8s.io/api/core/v1"
apierrors "k8s.io/apimachinery/pkg/api/errors"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/client-go/kubernetes"
"k8s.io/utils/ptr"
Expand All @@ -18,7 +20,10 @@ import (
"github.com/envoyproxy/ai-gateway/filterconfig"
)

const selectedBackendHeaderKey = "x-ai-eg-selected-backend"
const (
selectedBackendHeaderKey = "x-ai-eg-selected-backend"
hostRewriteHTTPFilterName = "ai-eg-host-rewrite"
)

// mountedExtProcSecretPath specifies the secret file mounted on the external proc. The idea is to update the mounted
//
Expand Down Expand Up @@ -110,16 +115,39 @@ func (c *configSink) handleEvent(event ConfigSinkEvent) {
}

func (c *configSink) syncAIGatewayRoute(aiGatewayRoute *aigv1a1.AIGatewayRoute) {
// Check if the HTTPRouteFilter exists in the namespace.
var httpRouteFilter egv1a1.HTTPRouteFilter
err := c.client.Get(context.Background(),
client.ObjectKey{Name: hostRewriteHTTPFilterName, Namespace: aiGatewayRoute.Namespace}, &httpRouteFilter)
if apierrors.IsNotFound(err) {
httpRouteFilter = egv1a1.HTTPRouteFilter{
ObjectMeta: metav1.ObjectMeta{
Name: hostRewriteHTTPFilterName,
Namespace: aiGatewayRoute.Namespace,
},
Spec: egv1a1.HTTPRouteFilterSpec{
URLRewrite: &egv1a1.HTTPURLRewriteFilter{
Hostname: &egv1a1.HTTPHostnameModifier{
Type: egv1a1.BackendHTTPHostnameModifier,
},
},
},
}
if err := c.client.Create(context.Background(), &httpRouteFilter); err != nil {
c.logger.Error(err, "failed to create HTTPRouteFilter", "namespace", aiGatewayRoute.Namespace, "name", hostRewriteHTTPFilterName)
return
}
} else if err != nil {
c.logger.Error(err, "failed to get HTTPRouteFilter", "namespace", aiGatewayRoute.Namespace, "name", hostRewriteHTTPFilterName, "error", err)
return
}

// Check if the HTTPRoute exists.
c.logger.Info("syncing AIGatewayRoute", "namespace", aiGatewayRoute.Namespace, "name", aiGatewayRoute.Name)
var httpRoute gwapiv1.HTTPRoute
err := c.client.Get(context.Background(), client.ObjectKey{Name: aiGatewayRoute.Name, Namespace: aiGatewayRoute.Namespace}, &httpRoute)
err = c.client.Get(context.Background(), client.ObjectKey{Name: aiGatewayRoute.Name, Namespace: aiGatewayRoute.Namespace}, &httpRoute)
existingRoute := err == nil
if client.IgnoreNotFound(err) != nil {
c.logger.Error(err, "failed to get HTTPRoute", "namespace", aiGatewayRoute.Namespace, "name", aiGatewayRoute.Name)
return
}
if !existingRoute {
if apierrors.IsNotFound(err) {
// This means that this AIGatewayRoute is a new one.
httpRoute = gwapiv1.HTTPRoute{
ObjectMeta: metav1.ObjectMeta{
Expand All @@ -129,6 +157,9 @@ func (c *configSink) syncAIGatewayRoute(aiGatewayRoute *aigv1a1.AIGatewayRoute)
},
Spec: gwapiv1.HTTPRouteSpec{},
}
} else if err != nil {
c.logger.Error(err, "failed to get HTTPRoute", "namespace", aiGatewayRoute.Namespace, "name", aiGatewayRoute.Name, "error", err)
return
}

// Update the HTTPRoute with the new AIGatewayRoute.
Expand Down Expand Up @@ -300,6 +331,16 @@ func (c *configSink) newHTTPRoute(dst *gwapiv1.HTTPRoute, aiGatewayRoute *aigv1a
}
}

rewriteFilters := []gwapiv1.HTTPRouteFilter{
{
Type: gwapiv1.HTTPRouteFilterExtensionRef,
ExtensionRef: &gwapiv1.LocalObjectReference{
Group: "gateway.envoyproxy.io",
Kind: "HTTPRouteFilter",
Name: hostRewriteHTTPFilterName,
},
},
}
rules := make([]gwapiv1.HTTPRouteRule, len(backends))
for i, b := range backends {
key := fmt.Sprintf("%s.%s", b.Name, b.Namespace)
Expand All @@ -310,6 +351,7 @@ func (c *configSink) newHTTPRoute(dst *gwapiv1.HTTPRoute, aiGatewayRoute *aigv1a
Matches: []gwapiv1.HTTPRouteMatch{
{Headers: []gwapiv1.HTTPHeaderMatch{{Name: selectedBackendHeaderKey, Value: key}}},
},
Filters: rewriteFilters,
}
rules[i] = rule
}
Expand All @@ -322,6 +364,7 @@ func (c *configSink) newHTTPRoute(dst *gwapiv1.HTTPRoute, aiGatewayRoute *aigv1a
BackendRefs: []gwapiv1.HTTPBackendRef{
{BackendRef: gwapiv1.BackendRef{BackendObjectReference: backends[0].Spec.BackendRef}},
},
Filters: rewriteFilters,
})

dst.Spec.Rules = rules
Expand Down
17 changes: 14 additions & 3 deletions internal/controller/sink_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"testing"
"time"

egv1a1 "github.com/envoyproxy/gateway/api/v1alpha1"
"github.com/go-logr/logr"
"github.com/stretchr/testify/require"
corev1 "k8s.io/api/core/v1"
Expand Down Expand Up @@ -81,10 +82,8 @@ func TestConfigSink_syncAIGatewayRoute(t *testing.T) {
}, metav1.CreateOptions{})
require.NoError(t, err)

// Then sync.
// Then sync, which should update the HTTPRoute.
s.syncAIGatewayRoute(route)
// Referencing backends should be updated.
// Also HTTPRoute should be updated.
var updatedHTTPRoute gwapiv1.HTTPRoute
err = fakeClient.Get(context.Background(), client.ObjectKey{Name: "route1", Namespace: "ns1"}, &updatedHTTPRoute)
require.NoError(t, err)
Expand All @@ -98,6 +97,12 @@ func TestConfigSink_syncAIGatewayRoute(t *testing.T) {
require.Equal(t, "some-backend1", string(updatedHTTPRoute.Spec.Rules[2].BackendRefs[0].BackendRef.Name))
require.Equal(t, "/", *updatedHTTPRoute.Spec.Rules[2].Matches[0].Path.Value)
})

// Check the namespace has the default host rewrite filter.
var f egv1a1.HTTPRouteFilter
err := s.client.Get(context.Background(), client.ObjectKey{Name: hostRewriteHTTPFilterName, Namespace: "ns1"}, &f)
require.NoError(t, err)
require.Equal(t, hostRewriteHTTPFilterName, f.Name)
}

func TestConfigSink_syncAIServiceBackend(t *testing.T) {
Expand Down Expand Up @@ -221,6 +226,12 @@ func Test_newHTTPRoute(t *testing.T) {
require.Equal(t, expRules[i].Matches, r.Matches)
require.Equal(t, expRules[i].BackendRefs, r.BackendRefs)
}

// Each rule should have a host rewrite filter by default.
require.Len(t, r.Filters, 1)
require.Equal(t, gwapiv1.HTTPRouteFilterExtensionRef, r.Filters[0].Type)
require.NotNil(t, r.Filters[0].ExtensionRef)
require.Equal(t, hostRewriteHTTPFilterName, string(r.Filters[0].ExtensionRef.Name))
})
}
}
Expand Down
22 changes: 22 additions & 0 deletions tests/controller/controller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -211,10 +211,32 @@ func TestStartControllers(t *testing.T) {
require.Len(t, httpRoute.Spec.Rules[1].Matches[0].Headers, 1)
require.Equal(t, "x-ai-eg-selected-backend", string(httpRoute.Spec.Rules[1].Matches[0].Headers[0].Name))
require.Equal(t, "backend2.default", httpRoute.Spec.Rules[1].Matches[0].Headers[0].Value)

// Check all rule has the host rewrite filter.
for _, rule := range httpRoute.Spec.Rules {
require.Len(t, rule.Filters, 1)
require.NotNil(t, rule.Filters[0].ExtensionRef)
require.Equal(t, "ai-eg-host-rewrite", string(rule.Filters[0].ExtensionRef.Name))
}
return true
}, 30*time.Second, 200*time.Millisecond)
})
}

// Check if the host rewrite filter exists in the default namespace.
t.Run("verify host rewrite filter", func(t *testing.T) {
require.Eventually(t, func() bool {
var filter egv1a1.HTTPRouteFilter
err := c.Get(ctx, client.ObjectKey{Name: "ai-eg-host-rewrite", Namespace: "default"}, &filter)
if err != nil {
t.Logf("failed to get filter: %v", err)
return false
}
require.Equal(t, "default", filter.Namespace)
require.Equal(t, "ai-eg-host-rewrite", filter.Name)
return true
}, 30*time.Second, 200*time.Millisecond)
})
}

func TestAIGatewayRouteController(t *testing.T) {
Expand Down
30 changes: 26 additions & 4 deletions tests/e2e/testdata/translation_testupstream.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,8 @@ spec:
name: OpenAI
backendRef:
name: testupstream
kind: Service
port: 80
kind: Backend
group: gateway.envoyproxy.io
---
apiVersion: aigateway.envoyproxy.io/v1alpha1
kind: AIServiceBackend
Expand All @@ -70,5 +70,27 @@ spec:
name: AWSBedrock
backendRef:
name: testupstream-canary
kind: Service
port: 80
kind: Backend
group: gateway.envoyproxy.io
---
apiVersion: gateway.envoyproxy.io/v1alpha1
kind: Backend
metadata:
name: testupstream
namespace: default
spec:
endpoints:
- fqdn:
hostname: testupstream.default.svc.cluster.local
port: 80
---
apiVersion: gateway.envoyproxy.io/v1alpha1
kind: Backend
metadata:
name: testupstream-canary
namespace: default
spec:
endpoints:
- fqdn:
hostname: testupstream-canary.default.svc.cluster.local
port: 80
7 changes: 6 additions & 1 deletion tests/e2e/translation_testupstream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ func TestTranslationWithTestUpstream(t *testing.T) {
for _, tc := range []struct {
name string
modelName string
expHost string
expTestUpstreamID string
expPath string
fakeResponseBody string
Expand All @@ -43,12 +44,14 @@ func TestTranslationWithTestUpstream(t *testing.T) {
modelName: "some-cool-model",
expTestUpstreamID: "primary",
expPath: "/v1/chat/completions",
expHost: "testupstream.default.svc.cluster.local",
fakeResponseBody: `{"choices":[{"message":{"content":"This is a test."}}]}`,
},
{
name: "aws-bedrock",
modelName: "another-cool-model",
expTestUpstreamID: "canary",
expHost: "testupstream-canary.default.svc.cluster.local",
expPath: "/model/another-cool-model/converse",
fakeResponseBody: `{"output":{"message":{"content":[{"text":"response"},{"text":"from"},{"text":"assistant"}],"role":"assistant"}},"stopReason":null,"usage":{"inputTokens":10,"outputTokens":20,"totalTokens":30}}`,
},
Expand All @@ -66,7 +69,9 @@ func TestTranslationWithTestUpstream(t *testing.T) {
"x-expected-path", base64.StdEncoding.EncodeToString([]byte(tc.expPath))),
option.WithHeader("x-response-body",
base64.StdEncoding.EncodeToString([]byte(tc.fakeResponseBody)),
))
),
option.WithHeader("x-expected-host", tc.expHost),
)

chatCompletion, err := client.Chat.Completions.New(ctx, openai.ChatCompletionNewParams{
Messages: openai.F([]openai.ChatCompletionMessageParamUnion{
Expand Down
15 changes: 9 additions & 6 deletions tests/envtest.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,15 @@ func NewEnvTest(t *testing.T) (c client.Client, cfg *rest.Config, k kubernetes.I
for _, file := range files {
crds = append(crds, filepath.Join(crdPath, file.Name()))
}
const (
extensionPolicyURL = "https://raw.githubusercontent.com/envoyproxy/gateway/refs/tags/v1.2.4/charts/gateway-helm/crds/generated/gateway.envoyproxy.io_envoyextensionpolicies.yaml"
httpRouteURL = "https://raw.githubusercontent.com/kubernetes-sigs/gateway-api/refs/tags/v1.2.1/config/crd/standard/gateway.networking.k8s.io_httproutes.yaml"
)
crds = append(crds, requireThirdPartyCRDDownloaded(t, "envoyextensionpolicies_crd_for_tests.yaml", extensionPolicyURL))
crds = append(crds, requireThirdPartyCRDDownloaded(t, "httproutes_crd_for_tests.yaml", httpRouteURL))

for _, url := range []string{
"https://raw.githubusercontent.com/envoyproxy/gateway/refs/tags/v1.2.4/charts/gateway-helm/crds/generated/gateway.envoyproxy.io_envoyextensionpolicies.yaml",
"https://raw.githubusercontent.com/envoyproxy/gateway/refs/tags/v1.2.5/charts/gateway-helm/crds/generated/gateway.envoyproxy.io_httproutefilters.yaml",
"https://raw.githubusercontent.com/kubernetes-sigs/gateway-api/refs/tags/v1.2.1/config/crd/standard/gateway.networking.k8s.io_httproutes.yaml",
} {
path := filepath.Base(url) + "_for_tests.yaml"
crds = append(crds, requireThirdPartyCRDDownloaded(t, path, url))
}

env := &envtest.Environment{CRDDirectoryPaths: crds}
cfg, err = env.Start()
Expand Down
12 changes: 12 additions & 0 deletions tests/testupstream/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ const (
// If the values do not match, the request will be rejected, meaning that the request
// was routed to the wrong upstream.
expectedTestUpstreamIDKey = "x-expected-testupstream-id"
// expectedHostKey is the key for the expected host in the request.
expectedHostKey = "x-expected-host"
)

// main starts a server that listens on port 1063 and responds with the expected response body and headers
Expand Down Expand Up @@ -94,6 +96,16 @@ func handler(w http.ResponseWriter, r *http.Request) {
for k, v := range r.Header {
logger.Printf("header %q: %s\n", k, v)
}
if v := r.Header.Get(expectedHostKey); v != "" {
if r.Host != v {
fmt.Printf("unexpected host: got %q, expected %q\n", r.Host, v)
http.Error(w, "unexpected host: got "+r.Host+", expected "+v, http.StatusBadRequest)
return
}
fmt.Println("host matched:", v)
} else {
fmt.Println("no expected host: got", r.Host)
}
if v := r.Header.Get(expectedHeadersKey); v != "" {
expectedHeaders, err := base64.StdEncoding.DecodeString(v)
if err != nil {
Expand Down
Loading