diff --git a/filterconfig/filterconfig.go b/filterconfig/filterconfig.go
index 8956fe1e8..d2ce45be9 100644
--- a/filterconfig/filterconfig.go
+++ b/filterconfig/filterconfig.go
@@ -71,6 +71,13 @@ type Config struct {
 	// LLMRequestCost configures the cost of each LLM-related request. Optional. If this is provided, the filter will populate
 	// the "calculated" cost in the filter metadata at the end of the response body processing.
 	LLMRequestCosts []LLMRequestCost `json:"llmRequestCosts,omitempty"`
+	// MonitorContinuousUsageStats controls whether the external processor monitors every response-body chunk for usage stats.
+	// When true, it checks every response-body chunk received during a streaming request for token usage metadata,
+	// which is compatible with vLLM's 'continuous_usage_stats' option.
+	// When false, it stops monitoring once token usage metadata has been found for the first time,
+	// which matches OpenAI's streaming responses (https://platform.openai.com/docs/api-reference/chat/streaming#chat/streaming-usage).
+	// This only affects requests in streaming mode.
+	MonitorContinuousUsageStats bool `json:"monitorContinuousUsageStats,omitempty"`
 	// InputSchema specifies the API schema of the input format of requests to the filter.
 	Schema VersionedAPISchema `json:"schema"`
 	// ModelNameHeaderKey is the header key to be populated with the model name by the filter.
diff --git a/filterconfig/filterconfig_test.go b/filterconfig/filterconfig_test.go
index b55acb580..82250acd7 100644
--- a/filterconfig/filterconfig_test.go
+++ b/filterconfig/filterconfig_test.go
@@ -74,6 +74,7 @@ rules:
 
 	require.Equal(t, "OpenAI", string(cfg.Schema.Name))
 	require.Equal(t, "x-ai-eg-selected-backend", cfg.SelectedBackendHeaderKey)
 	require.Equal(t, "x-ai-eg-model", cfg.ModelNameHeaderKey)
+	require.Len(t, cfg.Rules, 2)
 	require.Equal(t, "llama3.3333", cfg.Rules[0].Headers[0].Value)
 	require.Equal(t, "gpt4.4444", cfg.Rules[1].Headers[0].Value)
diff --git a/internal/extproc/processor.go b/internal/extproc/processor.go
index dcc7b3b21..ce5b4abed 100644
--- a/internal/extproc/processor.go
+++ b/internal/extproc/processor.go
@@ -246,6 +246,7 @@ func (p *Processor) maybeBuildDynamicMetadata() (*structpb.Struct, error) {
 	if len(metadata) == 0 {
 		return nil, nil
 	}
+
 	return &structpb.Struct{
 		Fields: map[string]*structpb.Value{
 			p.config.metadataNamespace: {
diff --git a/internal/extproc/server.go b/internal/extproc/server.go
index 1e14b6b16..341ac3e80 100644
--- a/internal/extproc/server.go
+++ b/internal/extproc/server.go
@@ -50,7 +50,7 @@ func (s *Server[P]) LoadConfig(config *filterconfig.Config) error {
 	for _, r := range config.Rules {
 		for _, b := range r.Backends {
 			if _, ok := factories[b.Schema]; !ok {
-				factories[b.Schema], err = translator.NewFactory(config.Schema, b.Schema)
+				factories[b.Schema], err = translator.NewFactory(config.Schema, b.Schema, config.MonitorContinuousUsageStats)
 				if err != nil {
 					return fmt.Errorf("cannot create translator factory: %w", err)
 				}
diff --git a/internal/extproc/translator/openai_awsbedrock.go b/internal/extproc/translator/openai_awsbedrock.go
index 67a34b7ab..6294180c8 100644
--- a/internal/extproc/translator/openai_awsbedrock.go
+++ b/internal/extproc/translator/openai_awsbedrock.go
@@ -555,7 +555,6 @@ func (o *openAIToAWSBedrockTranslatorV1ChatCompletion) ResponseBody(respHeaders
 	if err := json.NewDecoder(body).Decode(&bedrockResp); err != nil {
 		return nil, nil, tokenUsage, fmt.Errorf("failed to unmarshal body: %w", err)
 	}
-
 	openAIResp := openai.ChatCompletionResponse{
 		Object: "chat.completion",
 		Choices: make([]openai.ChatCompletionResponseChoice, 0, len(bedrockResp.Output.Message.Content)),
diff --git a/internal/extproc/translator/openai_openai.go b/internal/extproc/translator/openai_openai.go
index ce16d55ba..f58286ae7 100644
--- a/internal/extproc/translator/openai_openai.go
+++ b/internal/extproc/translator/openai_openai.go
@@ -14,21 +14,24 @@ import (
 	"github.com/envoyproxy/ai-gateway/internal/extproc/router"
 )
 
-// newOpenAIToOpenAITranslator implements [TranslatorFactory] for OpenAI to OpenAI translation.
-func newOpenAIToOpenAITranslator(path string) (Translator, error) {
-	if path == "/v1/chat/completions" {
-		return &openAIToOpenAITranslatorV1ChatCompletion{}, nil
-	} else {
-		return nil, fmt.Errorf("unsupported path: %s", path)
+// newOpenAIToOpenAITranslatorFactory returns a [Factory] for OpenAI to OpenAI translation.
+func newOpenAIToOpenAITranslatorFactory(monitorContinuousUsageStats bool) Factory {
+	return func(path string) (Translator, error) {
+		if path == "/v1/chat/completions" {
+			return &openAIToOpenAITranslatorV1ChatCompletion{monitorContinuousUsageStats: monitorContinuousUsageStats}, nil
+		} else {
+			return nil, fmt.Errorf("unsupported path: %s", path)
+		}
 	}
 }
 
 // openAIToOpenAITranslatorV1ChatCompletion implements [Translator] for /v1/chat/completions.
 type openAIToOpenAITranslatorV1ChatCompletion struct {
 	defaultTranslator
-	stream        bool
-	buffered      []byte
-	bufferingDone bool
+	stream                      bool
+	buffered                    []byte
+	bufferingDone               bool
+	monitorContinuousUsageStats bool
 }
 
 // RequestBody implements [RequestBody].
@@ -96,8 +99,12 @@ func (o *openAIToOpenAITranslatorV1ChatCompletion) ResponseBody(respHeaders map[
 		}
 	}
 	if o.stream {
-		if !o.bufferingDone {
-			buf, err := io.ReadAll(body)
+		if !o.bufferingDone || o.monitorContinuousUsageStats {
+			// The OpenAI API sends usage info only in the last chunk (https://platform.openai.com/docs/api-reference/chat/streaming#chat/streaming-usage),
+			// whereas the vLLM model server can include usage info in every returned chunk.
+			// To support both behaviors, we check for usage info in each chunk.
+			var buf []byte
+			buf, err = io.ReadAll(body)
 			if err != nil {
 				return nil, nil, tokenUsage, fmt.Errorf("failed to read body: %w", err)
 			}
diff --git a/internal/extproc/translator/openai_openai_test.go b/internal/extproc/translator/openai_openai_test.go
index b114edb33..f87bdef94 100644
--- a/internal/extproc/translator/openai_openai_test.go
+++ b/internal/extproc/translator/openai_openai_test.go
@@ -20,11 +20,11 @@ import (
 
 func TestNewOpenAIToOpenAITranslator(t *testing.T) {
 	t.Run("unsupported path", func(t *testing.T) {
-		_, err := newOpenAIToOpenAITranslator("/v1/foo/bar")
+		_, err := newOpenAIToOpenAITranslatorFactory(false)("/v1/foo/bar")
 		require.Error(t, err)
 	})
 	t.Run("/v1/chat/completions", func(t *testing.T) {
-		translator, err := newOpenAIToOpenAITranslator("/v1/chat/completions")
+		translator, err := newOpenAIToOpenAITranslatorFactory(false)("/v1/chat/completions")
 		require.NoError(t, err)
 		require.NotNil(t, translator)
 	})
@@ -184,6 +184,12 @@ data: [DONE]
 			if tokenUsage.OutputTokens > 0 {
 				require.Equal(t, uint32(12), tokenUsage.OutputTokens)
 			}
+			if tokenUsage.InputTokens > 0 {
+				require.Equal(t, uint32(13), tokenUsage.InputTokens)
+			}
+			if tokenUsage.TotalTokens > 0 {
+				require.Equal(t, uint32(25), tokenUsage.TotalTokens)
+			}
 		}
 	})
 	t.Run("non-streaming", func(t *testing.T) {
@@ -194,13 +200,17 @@ data: [DONE]
 		})
 		t.Run("valid body", func(t *testing.T) {
 			var resp openai.ChatCompletionResponse
-			resp.Usage.TotalTokens = 42
+			resp.Usage = openai.ChatCompletionResponseUsage{
+				PromptTokens:     11,
+				CompletionTokens: 22,
+				TotalTokens:      33,
+			}
 			body, err := json.Marshal(resp)
 			require.NoError(t, err)
 			o := &openAIToOpenAITranslatorV1ChatCompletion{}
 			_, _, usedToken, err := o.ResponseBody(nil, bytes.NewBuffer(body), false)
 			require.NoError(t, err)
-			require.Equal(t, LLMTokenUsage{TotalTokens: 42}, usedToken)
+			require.Equal(t, LLMTokenUsage{InputTokens: 11, OutputTokens: 22, TotalTokens: 33}, usedToken)
 		})
 	})
 }
@@ -208,18 +218,18 @@ data: [DONE]
 
 func TestExtractUsageFromBufferEvent(t *testing.T) {
 	t.Run("valid usage data", func(t *testing.T) {
 		o := &openAIToOpenAITranslatorV1ChatCompletion{}
-		o.buffered = []byte("data: {\"usage\": {\"total_tokens\": 42}}\n")
+		o.buffered = []byte("data: {\"usage\": {\"completion_tokens\":22,\"prompt_tokens\":11,\"total_tokens\": 33}}\n")
 		usedToken := o.extractUsageFromBufferEvent()
-		require.Equal(t, LLMTokenUsage{TotalTokens: 42}, usedToken)
+		require.Equal(t, LLMTokenUsage{TotalTokens: 33, InputTokens: 11, OutputTokens: 22}, usedToken)
 		require.True(t, o.bufferingDone)
 		require.Nil(t, o.buffered)
 	})
 	t.Run("valid usage data after invalid", func(t *testing.T) {
 		o := &openAIToOpenAITranslatorV1ChatCompletion{}
-		o.buffered = []byte("data: invalid\ndata: {\"usage\": {\"total_tokens\": 42}}\n")
+		o.buffered = []byte("data: invalid\ndata: {\"usage\": {\"completion_tokens\":22,\"prompt_tokens\":11,\"total_tokens\": 33}}\n")
 		usedToken := o.extractUsageFromBufferEvent()
-		require.Equal(t, LLMTokenUsage{TotalTokens: 42}, usedToken)
+		require.Equal(t, LLMTokenUsage{TotalTokens: 33, InputTokens: 11, OutputTokens: 22}, usedToken)
 		require.True(t, o.bufferingDone)
 		require.Nil(t, o.buffered)
 	})
@@ -232,9 +242,9 @@ func TestExtractUsageFromBufferEvent(t *testing.T) {
 		require.False(t, o.bufferingDone)
 		require.NotNil(t, o.buffered)
 
-		o.buffered = append(o.buffered, []byte("{\"usage\": {\"total_tokens\": 42}}\n")...)
+		o.buffered = append(o.buffered, []byte("{\"usage\": {\"completion_tokens\":22,\"prompt_tokens\":11,\"total_tokens\": 33}}\n")...)
 		usedToken = o.extractUsageFromBufferEvent()
-		require.Equal(t, LLMTokenUsage{TotalTokens: 42}, usedToken)
+		require.Equal(t, LLMTokenUsage{TotalTokens: 33, InputTokens: 11, OutputTokens: 22}, usedToken)
 		require.True(t, o.bufferingDone)
 		require.Nil(t, o.buffered)
 	})
diff --git a/internal/extproc/translator/translator.go b/internal/extproc/translator/translator.go
index ed2eb603a..5ec29e53d 100644
--- a/internal/extproc/translator/translator.go
+++ b/internal/extproc/translator/translator.go
@@ -35,12 +35,12 @@ func isGoodStatusCode(code int) bool {
 type Factory func(path string) (Translator, error)
 
 // NewFactory returns a callback function that creates a translator for the given API schema combination.
-func NewFactory(in, out filterconfig.VersionedAPISchema) (Factory, error) {
+func NewFactory(in, out filterconfig.VersionedAPISchema, monitorContinuousUsageStats bool) (Factory, error) {
 	if in.Name == filterconfig.APISchemaOpenAI {
 		// TODO: currently, we ignore the LLMAPISchema."Version" field.
 		switch out.Name {
 		case filterconfig.APISchemaOpenAI:
-			return newOpenAIToOpenAITranslator, nil
+			return newOpenAIToOpenAITranslatorFactory(monitorContinuousUsageStats), nil
 		case filterconfig.APISchemaAWSBedrock:
 			return newOpenAIToAWSBedrockTranslator, nil
 		}
diff --git a/internal/extproc/translator/translator_test.go b/internal/extproc/translator/translator_test.go
index 64c420236..b89a183b0 100644
--- a/internal/extproc/translator/translator_test.go
+++ b/internal/extproc/translator/translator_test.go
@@ -13,6 +13,7 @@ func TestNewFactory(t *testing.T) {
 		_, err := NewFactory(
 			filterconfig.VersionedAPISchema{Name: "Foo", Version: "v100"},
 			filterconfig.VersionedAPISchema{Name: "Bar", Version: "v123"},
+			false,
 		)
 		require.ErrorContains(t, err, "unsupported API schema combination: client={Foo v100}, backend={Bar v123}")
 	})
@@ -20,6 +21,7 @@ func TestNewFactory(t *testing.T) {
 		f, err := NewFactory(
 			filterconfig.VersionedAPISchema{Name: filterconfig.APISchemaOpenAI},
 			filterconfig.VersionedAPISchema{Name: filterconfig.APISchemaOpenAI},
+			false,
 		)
 		require.NoError(t, err)
 		require.NotNil(t, f)
@@ -34,6 +36,7 @@ func TestNewFactory(t *testing.T) {
 		f, err := NewFactory(
 			filterconfig.VersionedAPISchema{Name: filterconfig.APISchemaOpenAI},
 			filterconfig.VersionedAPISchema{Name: filterconfig.APISchemaAWSBedrock},
+			false,
 		)
 		require.NoError(t, err)
 		require.NotNil(t, f)
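
Below is a minimal sketch of a filter configuration that enables the new flag. It assumes the config YAML is decoded with sigs.k8s.io/yaml (which honors the json struct tags on Config) and that the keys shown besides monitorContinuousUsageStats (schema, name, modelNameHeaderKey, selectedBackendHeaderKey) map to the existing fields; it is illustrative, not a complete or validated configuration.

package main

import (
	"fmt"
	"log"

	"sigs.k8s.io/yaml"

	"github.com/envoyproxy/ai-gateway/filterconfig"
)

// Hypothetical minimal config: only the fields relevant to this change are shown,
// and the exact key names are assumptions based on the struct tags in the diff above.
const configYAML = `
schema:
  name: OpenAI
modelNameHeaderKey: x-ai-eg-model
selectedBackendHeaderKey: x-ai-eg-selected-backend
# Check every streaming chunk for usage stats (e.g. vLLM's continuous_usage_stats).
monitorContinuousUsageStats: true
`

func main() {
	var cfg filterconfig.Config
	// sigs.k8s.io/yaml converts YAML to JSON first, so the json tags apply.
	if err := yaml.Unmarshal([]byte(configYAML), &cfg); err != nil {
		log.Fatal(err)
	}
	fmt.Println(cfg.MonitorContinuousUsageStats) // true
}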
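
To make the two streaming behaviors concrete, the self-contained sketch below uses illustrative (not captured) chunk sequences: an OpenAI-style stream attaches a usage object only to the final chunk, so scanning can stop after the first hit, while a vLLM stream with continuous_usage_stats carries a running usage object in every chunk, which is what monitorContinuousUsageStats=true keeps checking for.

package main

import (
	"encoding/json"
	"fmt"
	"strings"
)

// Illustrative chunk sequences; the payload shapes are simplified for the example.
const openAIStyle = `data: {"choices":[{"delta":{"content":"Hello"}}]}
data: {"choices":[],"usage":{"prompt_tokens":13,"completion_tokens":12,"total_tokens":25}}
data: [DONE]`

const vllmContinuous = `data: {"choices":[{"delta":{"content":"He"}}],"usage":{"prompt_tokens":13,"completion_tokens":1,"total_tokens":14}}
data: {"choices":[{"delta":{"content":"llo"}}],"usage":{"prompt_tokens":13,"completion_tokens":2,"total_tokens":15}}
data: [DONE]`

// lastUsage scans every "data:" event for a usage object and returns the last total
// seen plus how many chunks carried usage, mirroring the idea of monitoring each
// chunk rather than stopping at the first hit.
func lastUsage(stream string) (total, found int) {
	for _, line := range strings.Split(stream, "\n") {
		payload := strings.TrimPrefix(line, "data: ")
		var chunk struct {
			Usage *struct {
				TotalTokens int `json:"total_tokens"`
			} `json:"usage"`
		}
		if json.Unmarshal([]byte(payload), &chunk) == nil && chunk.Usage != nil {
			total = chunk.Usage.TotalTokens
			found++
		}
	}
	return total, found
}

func main() {
	fmt.Println(lastUsage(openAIStyle))    // 25 1: usage only in the final chunk
	fmt.Println(lastUsage(vllmContinuous)) // 15 2: usage in every chunk
}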