Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

extproc: refactors router code #105

Merged
merged 1 commit into from
Jan 16, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions internal/extproc/mocks_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,9 +104,10 @@ type mockRouter struct {
}

// Calculate implements [router.Router.Calculate].
func (m mockRouter) Calculate(headers map[string]string) (string, filterconfig.VersionedAPISchema, error) {
func (m mockRouter) Calculate(headers map[string]string) (*filterconfig.Backend, error) {
require.Equal(m.t, m.expHeaders, headers)
return m.retBackendName, m.retVersionedAPISchema, m.retErr
b := &filterconfig.Backend{Name: m.retBackendName, OutputSchema: m.retVersionedAPISchema}
return b, m.retErr
}

// mockRequestBodyParser implements [router.RequestBodyParser] for testing.
Expand Down
10 changes: 5 additions & 5 deletions internal/extproc/processor.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,14 +73,14 @@ func (p *Processor) ProcessRequestBody(_ context.Context, rawBody *extprocv3.Htt
}

p.requestHeaders[p.config.ModelNameHeaderKey] = model
backendName, outputSchema, err := p.config.router.Calculate(p.requestHeaders)
b, err := p.config.router.Calculate(p.requestHeaders)
if err != nil {
return nil, fmt.Errorf("failed to calculate route: %w", err)
}

factory, ok := p.config.factories[outputSchema]
factory, ok := p.config.factories[b.OutputSchema]
if !ok {
return nil, fmt.Errorf("failed to find factory for output schema %q", outputSchema)
return nil, fmt.Errorf("failed to find factory for output schema %q", b.OutputSchema)
}

t, err := factory(path)
Expand All @@ -101,10 +101,10 @@ func (p *Processor) ProcessRequestBody(_ context.Context, rawBody *extprocv3.Htt
headerMutation.SetHeaders = append(headerMutation.SetHeaders, &corev3.HeaderValueOption{
Header: &corev3.HeaderValue{Key: p.config.ModelNameHeaderKey, RawValue: []byte(model)},
}, &corev3.HeaderValueOption{
Header: &corev3.HeaderValue{Key: p.config.selectedBackendHeaderKey, RawValue: []byte(backendName)},
Header: &corev3.HeaderValue{Key: p.config.selectedBackendHeaderKey, RawValue: []byte(b.Name)},
})

if authHandler, ok := p.config.backendAuthHandlers[backendName]; ok {
if authHandler, ok := p.config.backendAuthHandlers[b.Name]; ok {
if err := authHandler.Do(p.requestHeaders, headerMutation, bodyMutation); err != nil {
return nil, fmt.Errorf("failed to do auth request: %w", err)
}
Expand Down
20 changes: 10 additions & 10 deletions internal/extproc/router/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import (
type Router interface {
// Calculate determines the backend to route to based on the headers.
// Returns the backend name and the output schema.
Calculate(headers map[string]string) (backendName string, outputSchema filterconfig.VersionedAPISchema, err error)
Calculate(headers map[string]string) (backend *filterconfig.Backend, err error)
}

// router implements [Router].
Expand All @@ -28,7 +28,7 @@ func NewRouter(config *filterconfig.Config) (Router, error) {
}

// Calculate implements [Router.Calculate].
func (r *router) Calculate(headers map[string]string) (backendName string, outputSchema filterconfig.VersionedAPISchema, err error) {
func (r *router) Calculate(headers map[string]string) (backend *filterconfig.Backend, err error) {
var rule *filterconfig.RouteRule
for i := range r.rules {
_rule := &r.rules[i]
Expand All @@ -42,28 +42,28 @@ func (r *router) Calculate(headers map[string]string) (backendName string, outpu
}
}
if rule == nil {
return "", filterconfig.VersionedAPISchema{}, errors.New("no matching rule found")
return nil, errors.New("no matching rule found")
}
backendName, outputSchema = r.selectBackendFromRule(rule)
return
return r.selectBackendFromRule(rule), nil
}

func (r *router) selectBackendFromRule(rule *filterconfig.RouteRule) (backendName string, outputSchema filterconfig.VersionedAPISchema) {
func (r *router) selectBackendFromRule(rule *filterconfig.RouteRule) (backend *filterconfig.Backend) {
// Each backend has a weight, so we randomly select depending on the weight.
// This is a pretty naive implementation and can be buggy, so fix it later.
totalWeight := 0
for _, b := range rule.Backends {
totalWeight += b.Weight
}
if totalWeight == 0 {
return rule.Backends[0].Name, rule.Backends[0].OutputSchema
return &rule.Backends[0]
}
selected := r.rng.Intn(totalWeight)
for _, b := range rule.Backends {
for i := range rule.Backends {
b := &rule.Backends[i]
if selected < b.Weight {
return b.Name, b.OutputSchema
return b
}
selected -= b.Weight
}
return rule.Backends[0].Name, rule.Backends[0].OutputSchema
return &rule.Backends[0]
}
23 changes: 11 additions & 12 deletions internal/extproc/router/router_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,25 +36,24 @@ func TestRouter_Calculate(t *testing.T) {
require.True(t, ok)

t.Run("no matching rule", func(t *testing.T) {
backendName, outputSchema, err := r.Calculate(map[string]string{"x-model-name": "something-quirky"})
b, err := r.Calculate(map[string]string{"x-model-name": "something-quirky"})
require.Error(t, err)
require.Empty(t, backendName)
require.Empty(t, outputSchema)
require.Nil(t, b)
})
t.Run("matching rule - single backend choice", func(t *testing.T) {
backendName, outputSchema, err := r.Calculate(map[string]string{"x-model-name": "gpt4.4444"})
b, err := r.Calculate(map[string]string{"x-model-name": "gpt4.4444"})
require.NoError(t, err)
require.Equal(t, "openai", backendName)
require.Equal(t, outSchema, outputSchema)
require.Equal(t, "openai", b.Name)
require.Equal(t, outSchema, b.OutputSchema)
})
t.Run("matching rule - multiple backend choices", func(t *testing.T) {
chosenNames := make(map[string]int)
for i := 0; i < 1000; i++ {
backendName, outputSchema, err := r.Calculate(map[string]string{"x-model-name": "llama3.3333"})
b, err := r.Calculate(map[string]string{"x-model-name": "llama3.3333"})
require.NoError(t, err)
chosenNames[backendName]++
require.Contains(t, []string{"foo", "bar"}, backendName)
require.Equal(t, outSchema, outputSchema)
chosenNames[b.Name]++
require.Contains(t, []string{"foo", "bar"}, b.Name)
require.Equal(t, outSchema, b.OutputSchema)
}
require.Greater(t, chosenNames["bar"], chosenNames["foo"])
require.Greater(t, chosenNames["bar"], 700)
Expand All @@ -79,8 +78,8 @@ func TestRouter_selectBackendFromRule(t *testing.T) {

chosenNames := make(map[string]int)
for i := 0; i < 1000; i++ {
backendName, _ := r.selectBackendFromRule(rule)
chosenNames[backendName]++
b := r.selectBackendFromRule(rule)
chosenNames[b.Name]++
}

require.Greater(t, chosenNames["bar"], chosenNames["foo"])
Expand Down
Loading