Extract Input/Output token usage from request. #111

Open · wants to merge 4 commits into main

Changes from 3 commits
1 change: 1 addition & 0 deletions .gitignore
@@ -29,3 +29,4 @@ site/yarn-debug.log*
site/yarn-error.log*
site/static/.DS_Store
site/temp
/.idea/
Member: unnecessary change - already in here

Member: also you can have a global gitignore in your OS... that's where people usually place the gitignore for something that's not generated by the project
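For anyone following along, the setup the reviewer describes looks roughly like this; the file name below is a common convention, not something git requires:

    # Point git at a user-level ignore file; the name is conventional, not required.
    git config --global core.excludesFile ~/.gitignore_global
    # Ignore JetBrains project metadata in every repository on this machine.
    echo '.idea/' >> ~/.gitignore_global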

19 changes: 16 additions & 3 deletions filterconfig/filterconfig.go
@@ -35,7 +35,9 @@ modelNameHeaderKey: x-envoy-ai-gateway-model
// modelNameHeaderKey: x-envoy-ai-gateway-model
// tokenUsageMetadata:
// namespace: ai_gateway_llm_ns
// key: token_usage_key
// inputTokensKey: input_tokens_usage
// outputTokensKey: output_tokens_usage
// totalTokensKey: total_tokens_usage
// rules:
// - backends:
// - name: kserve
@@ -66,6 +68,13 @@ modelNameHeaderKey: x-envoy-ai-gateway-model
// From Envoy configuration perspective, configuring the header matching based on `x-envoy-ai-gateway-selected-backend` is enough to route the request to the selected backend.
// That is because the matching decision is made by the filter and the selected backend is populated in the header `x-envoy-ai-gateway-selected-backend`.
type Config struct {
// MonitorContinuousUsageStats controls whether the external processor inspects every response-body chunk
// for token usage stats when a request is in streaming mode.
// When true, it checks each chunk for token usage metadata; this is compatible with vLLM's
// 'continuous_usage_stats' option.
// When false, it stops checking once token usage metadata has been found for the first time; this matches
// OpenAI's streaming responses (https://platform.openai.com/docs/api-reference/chat/streaming#chat/streaming-usage).
// This field only affects requests in streaming mode.
MonitorContinuousUsageStats bool `yaml:"monitorContinuousUsageStats,omitempty"`
Comment on lines +71 to +77

@mathetake (Member) on Jan 16, 2025: could you remove the change related to this? I think this is another issue, and metadata is not cumulative, so basically it's overriding previous ones if it's emitted in the middle.
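Some background on the flag under discussion: OpenAI's API reports usage only in the final chunk of a stream, while vLLM can report it on every chunk. A request opting into the continuous behavior might look like the following sketch; 'continuous_usage_stats' is a vLLM-specific stream option, so treat the exact field shape as an assumption:

    {
      "model": "meta-llama/Llama-3.1-8B-Instruct",
      "messages": [{"role": "user", "content": "Hello"}],
      "stream": true,
      "stream_options": {"include_usage": true, "continuous_usage_stats": true}
    }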

// TokenUsageMetadata is the namespace and keys to be used in the filter metadata to store the token usage, optional.
// If this is provided, the filter will populate the token usage in the filter metadata at the end of the
// response body processing.
@@ -90,8 +99,12 @@ type Config struct {
type TokenUsageMetadata struct {
// Namespace is the namespace of the metadata.
Namespace string `yaml:"namespace"`
// Key is the key of the metadata.
Key string `yaml:"key"`
// InputTokensKey is the key of the metadata entry holding the input-token count parsed from the upstream response.
InputTokensKey string `yaml:"inputTokensKey"`
// OutputTokensKey is the key of the metadata entry holding the output-token count parsed from the upstream response.
OutputTokensKey string `yaml:"outputTokensKey"`
// TotalTokensKey is the key of the metadata entry holding the total-token count parsed from the upstream response.
TotalTokensKey string `yaml:"totalTokensKey"`
}

// VersionedAPISchema corresponds to LLMAPISchema in api/v1alpha1/api.go.
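One practical upside of splitting the single key into three is that each count can be referenced independently from Envoy, e.g. in an access-log format. A minimal sketch, assuming the namespace and key names from the example configuration above (this listener wiring is illustrative and not part of this PR):

    access_log:
    - name: envoy.access_loggers.stdout
      typed_config:
        "@type": type.googleapis.com/envoy.extensions.access_loggers.stream.v3.StdoutAccessLog
        log_format:
          json_format:
            model: "%REQ(x-envoy-ai-gateway-model)%"
            input_tokens: "%DYNAMIC_METADATA(ai_gateway_llm_ns:input_tokens_usage)%"
            output_tokens: "%DYNAMIC_METADATA(ai_gateway_llm_ns:output_tokens_usage)%"
            total_tokens: "%DYNAMIC_METADATA(ai_gateway_llm_ns:total_tokens_usage)%"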
8 changes: 6 additions & 2 deletions filterconfig/filterconfig_test.go
@@ -35,7 +35,9 @@ selectedBackendHeaderKey: x-envoy-ai-gateway-selected-backend
modelNameHeaderKey: x-envoy-ai-gateway-model
tokenUsageMetadata:
namespace: ai_gateway_llm_ns
key: token_usage_key
inputTokensKey: input_tokens_usage
outputTokensKey: output_tokens_usage
totalTokensKey: total_tokens_usage
rules:
- backends:
- name: kserve
@@ -61,7 +63,9 @@ rules:
cfg, err := filterconfig.UnmarshalConfigYaml(configPath)
require.NoError(t, err)
require.Equal(t, "ai_gateway_llm_ns", cfg.TokenUsageMetadata.Namespace)
require.Equal(t, "token_usage_key", cfg.TokenUsageMetadata.Key)
require.Equal(t, "input_tokens_usage", cfg.TokenUsageMetadata.InputTokensKey)
require.Equal(t, "output_tokens_usage", cfg.TokenUsageMetadata.OutputTokensKey)
require.Equal(t, "total_tokens_usage", cfg.TokenUsageMetadata.TotalTokensKey)
require.Equal(t, "OpenAI", string(cfg.InputSchema.Schema))
require.Equal(t, "x-envoy-ai-gateway-selected-backend", cfg.SelectedBackendHeaderKey)
require.Equal(t, "x-envoy-ai-gateway-model", cfg.ModelNameHeaderKey)
6 changes: 3 additions & 3 deletions internal/extproc/mocks_test.go
@@ -69,7 +69,7 @@ type mockTranslator struct {
retHeaderMutation *extprocv3.HeaderMutation
retBodyMutation *extprocv3.BodyMutation
retOverride *extprocv3http.ProcessingMode
retUsedToken uint32
retTokenUsage *translator.TokenUsage
retErr error
}

@@ -86,13 +86,13 @@ func (m mockTranslator) ResponseHeaders(headers map[string]string) (headerMutati
}

// ResponseBody implements [translator.Translator.ResponseBody].
func (m mockTranslator) ResponseBody(body io.Reader, _ bool) (headerMutation *extprocv3.HeaderMutation, bodyMutation *extprocv3.BodyMutation, usedToken uint32, err error) {
func (m mockTranslator) ResponseBody(body io.Reader, _ bool) (headerMutation *extprocv3.HeaderMutation, bodyMutation *extprocv3.BodyMutation, tknUsage *translator.TokenUsage, err error) {
if m.expResponseBody != nil {
buf, err := io.ReadAll(body)
require.NoError(m.t, err)
require.Equal(m.t, m.expResponseBody.Body, buf)
}
return m.retHeaderMutation, m.retBodyMutation, m.retUsedToken, m.retErr
return m.retHeaderMutation, m.retBodyMutation, m.retTokenUsage, m.retErr
}

// mockRouter implements [router.Router] for testing.
9 changes: 6 additions & 3 deletions internal/extproc/processor.go
@@ -180,20 +180,23 @@ func (p *Processor) ProcessResponseBody(_ context.Context, body *extprocv3.HttpB
},
},
}
if p.config.tokenUsageMetadata != nil {
if p.config.tokenUsageMetadata != nil && usedToken != nil {
resp.DynamicMetadata = buildTokenUsageDynamicMetadata(p.config.tokenUsageMetadata, usedToken)
}
return resp, nil
}

func buildTokenUsageDynamicMetadata(md *filterconfig.TokenUsageMetadata, usage uint32) *structpb.Struct {
func buildTokenUsageDynamicMetadata(md *filterconfig.TokenUsageMetadata, tknUsage *translator.TokenUsage) *structpb.Struct {
return &structpb.Struct{
Fields: map[string]*structpb.Value{
md.Namespace: {
Kind: &structpb.Value_StructValue{
StructValue: &structpb.Struct{
Fields: map[string]*structpb.Value{
md.Key: {Kind: &structpb.Value_NumberValue{NumberValue: float64(usage)}},
md.InputTokensKey: {Kind: &structpb.Value_NumberValue{NumberValue: float64(tknUsage.InputTokens)}},
md.OutputTokensKey: {Kind: &structpb.Value_NumberValue{NumberValue: float64(tknUsage.OutputTokens)}},
md.TotalTokensKey: {Kind: &structpb.Value_NumberValue{NumberValue: float64(tknUsage.TotalTokens)}},
},
},
},
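Serialized, the dynamic metadata built above comes out shaped like this (key names taken from the tests below; values illustrative):

    {
      "ai_gateway_llm_ns": {
        "input_tokens_usage": 11,
        "output_tokens_usage": 22,
        "total_tokens_usage": 33
      }
    }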
18 changes: 15 additions & 3 deletions internal/extproc/processor_test.go
@@ -59,9 +59,19 @@ func TestProcessor_ProcessResponseBody(t *testing.T) {
inBody := &extprocv3.HttpBody{Body: []byte("some-body")}
expBodyMut := &extprocv3.BodyMutation{}
expHeadMut := &extprocv3.HeaderMutation{}
mt := &mockTranslator{t: t, expResponseBody: inBody, retBodyMutation: expBodyMut, retHeaderMutation: expHeadMut, retUsedToken: 123}
mt := &mockTranslator{
t: t, expResponseBody: inBody, retBodyMutation: expBodyMut, retHeaderMutation: expHeadMut,
retTokenUsage: &translator.TokenUsage{
InputTokens: 11,
OutputTokens: 22,
TotalTokens: 33,
},
}
p := &Processor{translator: mt, config: &processorConfig{tokenUsageMetadata: &filterconfig.TokenUsageMetadata{
Namespace: "ai_gateway_llm_ns", Key: "token_usage",
Namespace: "ai_gateway_llm_ns",
InputTokensKey: "input_tokens_usage",
OutputTokensKey: "output_tokens_usage",
TotalTokensKey: "total_tokens_usage",
}}}
res, err := p.ProcessResponseBody(context.Background(), inBody)
require.NoError(t, err)
@@ -71,7 +81,9 @@

md := res.DynamicMetadata
require.NotNil(t, md)
require.Equal(t, float64(123), md.Fields["ai_gateway_llm_ns"].GetStructValue().Fields["token_usage"].GetNumberValue())
require.Equal(t, float64(11), md.Fields["ai_gateway_llm_ns"].GetStructValue().Fields["input_tokens_usage"].GetNumberValue())
require.Equal(t, float64(22), md.Fields["ai_gateway_llm_ns"].GetStructValue().Fields["output_tokens_usage"].GetNumberValue())
require.Equal(t, float64(33), md.Fields["ai_gateway_llm_ns"].GetStructValue().Fields["total_tokens_usage"].GetNumberValue())
})
}

2 changes: 1 addition & 1 deletion internal/extproc/server.go
@@ -48,7 +48,7 @@ func (s *Server[P]) LoadConfig(config *filterconfig.Config) error {
for _, r := range config.Rules {
for _, b := range r.Backends {
if _, ok := factories[b.OutputSchema]; !ok {
factories[b.OutputSchema], err = translator.NewFactory(config.InputSchema, b.OutputSchema)
factories[b.OutputSchema], err = translator.NewFactory(config.InputSchema, b.OutputSchema, config.MonitorContinuousUsageStats)
if err != nil {
return fmt.Errorf("cannot create translator factory: %w", err)
}
6 changes: 4 additions & 2 deletions internal/extproc/server_test.go
@@ -36,7 +36,7 @@ func TestServer_LoadConfig(t *testing.T) {
})
t.Run("ok", func(t *testing.T) {
config := &filterconfig.Config{
TokenUsageMetadata: &filterconfig.TokenUsageMetadata{Namespace: "ns", Key: "key"},
TokenUsageMetadata: &filterconfig.TokenUsageMetadata{Namespace: "ns", InputTokensKey: "input_tokens", OutputTokensKey: "output_tokens", TotalTokensKey: "total_tokens"},
InputSchema: filterconfig.VersionedAPISchema{Schema: filterconfig.APISchemaOpenAI},
SelectedBackendHeaderKey: "x-envoy-ai-gateway-selected-backend",
ModelNameHeaderKey: "x-model-name",
@@ -73,7 +73,9 @@
require.NotNil(t, s.config)
require.NotNil(t, s.config.tokenUsageMetadata)
require.Equal(t, "ns", s.config.tokenUsageMetadata.Namespace)
require.Equal(t, "key", s.config.tokenUsageMetadata.Key)
require.Equal(t, "input_tokens", s.config.tokenUsageMetadata.InputTokensKey)
require.Equal(t, "output_tokens", s.config.tokenUsageMetadata.OutputTokensKey)
require.Equal(t, "total_tokens", s.config.tokenUsageMetadata.TotalTokensKey)
require.NotNil(t, s.config.router)
require.NotNil(t, s.config.bodyParser)
require.Equal(t, "x-envoy-ai-gateway-selected-backend", s.config.selectedBackendHeaderKey)
30 changes: 21 additions & 9 deletions internal/extproc/translator/openai_awsbedrock.go
@@ -288,21 +288,25 @@ func (o *openAIToAWSBedrockTranslatorV1ChatCompletion) ResponseHeaders(headers m

// ResponseBody implements [Translator.ResponseBody].
func (o *openAIToAWSBedrockTranslatorV1ChatCompletion) ResponseBody(body io.Reader, endOfStream bool) (
headerMutation *extprocv3.HeaderMutation, bodyMutation *extprocv3.BodyMutation, usedToken uint32, err error,
headerMutation *extprocv3.HeaderMutation, bodyMutation *extprocv3.BodyMutation, tknUsage *TokenUsage, err error,
) {
mut := &extprocv3.BodyMutation_Body{}
if o.stream {
buf, err := io.ReadAll(body)
if err != nil {
return nil, nil, 0, fmt.Errorf("failed to read body: %w", err)
return nil, nil, nil, fmt.Errorf("failed to read body: %w", err)
}
o.bufferedBody = append(o.bufferedBody, buf...)
o.extractAmazonEventStreamEvents()

for i := range o.events {
event := &o.events[i]
if usage := event.Usage; usage != nil {
usedToken = uint32(usage.TotalTokens) //nolint:gosec
tknUsage = &TokenUsage{
InputTokens: uint32(usage.InputTokens), //nolint:gosec
OutputTokens: uint32(usage.OutputTokens), //nolint:gosec
TotalTokens: uint32(usage.TotalTokens), //nolint:gosec
}
}

oaiEvent, ok := o.convertEvent(event)
@@ -321,15 +325,19 @@
if endOfStream {
mut.Body = append(mut.Body, []byte("data: [DONE]\n")...)
}
return headerMutation, &extprocv3.BodyMutation{Mutation: mut}, usedToken, nil
return headerMutation, &extprocv3.BodyMutation{Mutation: mut}, tknUsage, nil
}

var bedrockResp awsbedrock.ConverseOutput
if err := json.NewDecoder(body).Decode(&bedrockResp); err != nil {
return nil, nil, 0, fmt.Errorf("failed to unmarshal body: %w", err)
return nil, nil, nil, fmt.Errorf("failed to unmarshal body: %w", err)
}

usedToken = uint32(bedrockResp.Usage.TotalTokens) //nolint:gosec
tknUsage = &TokenUsage{
InputTokens: uint32(bedrockResp.Usage.InputTokens), //nolint:gosec
OutputTokens: uint32(bedrockResp.Usage.OutputTokens), //nolint:gosec
TotalTokens: uint32(bedrockResp.Usage.TotalTokens), //nolint:gosec
}

openAIResp := openai.ChatCompletionResponse{
Object: "chat.completion",
@@ -341,7 +349,11 @@ func (o *openAIToAWSBedrockTranslatorV1ChatCompletion) ResponseBody(body io.Read
PromptTokens: bedrockResp.Usage.InputTokens,
CompletionTokens: bedrockResp.Usage.OutputTokens,
}
usedToken = uint32(bedrockResp.Usage.TotalTokens) //nolint:gosec
tknUsage = &TokenUsage{
InputTokens: uint32(bedrockResp.Usage.InputTokens), //nolint:gosec
OutputTokens: uint32(bedrockResp.Usage.OutputTokens), //nolint:gosec
TotalTokens: uint32(bedrockResp.Usage.TotalTokens), //nolint:gosec
}
}
for i, output := range bedrockResp.Output.Message.Content {
choice := openai.ChatCompletionResponseChoice{
@@ -367,13 +379,13 @@ func (o *openAIToAWSBedrockTranslatorV1ChatCompletion) ResponseBody(body io.Read
}

if body, err := json.Marshal(openAIResp); err != nil {
return nil, nil, 0, fmt.Errorf("failed to marshal body: %w", err)
return nil, nil, nil, fmt.Errorf("failed to marshal body: %w", err)
} else {
mut.Body = body
}
headerMutation = &extprocv3.HeaderMutation{}
setContentLength(headerMutation, mut.Body)
return headerMutation, &extprocv3.BodyMutation{Mutation: mut}, usedToken, nil
return headerMutation, &extprocv3.BodyMutation{Mutation: mut}, tknUsage, nil
}

// extractAmazonEventStreamEvents extracts [awsbedrock.ConverseStreamEvent] from the buffered body.
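For comparison with the OpenAI path, the Converse stream delivers usage in a metadata event near the end of the stream, roughly shaped as follows (field names per the Bedrock Converse API; the latency value is illustrative, the token counts match the test below):

    {"usage": {"inputTokens": 41, "outputTokens": 36, "totalTokens": 77}, "metrics": {"latencyMs": 500}}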
12 changes: 8 additions & 4 deletions internal/extproc/translator/openai_awsbedrock_test.go
@@ -439,7 +439,7 @@ func TestOpenAIToAWSBedrockTranslatorV1ChatCompletion_Streaming_ResponseBody(t *

var results []string
for i := 0; i < len(buf); i++ {
hm, bm, usedToken, err := o.ResponseBody(bytes.NewBuffer([]byte{buf[i]}), i == len(buf)-1)
hm, bm, tknUsage, err := o.ResponseBody(bytes.NewBuffer([]byte{buf[i]}), i == len(buf)-1)
require.NoError(t, err)
require.Nil(t, hm)
require.NotNil(t, bm)
@@ -448,8 +448,10 @@
if len(newBody) > 0 {
results = append(results, string(newBody))
}
if usedToken > 0 {
require.Equal(t, uint32(77), usedToken)
if tknUsage != nil {
require.Equal(t, uint32(41), tknUsage.InputTokens)
require.Equal(t, uint32(36), tknUsage.OutputTokens)
require.Equal(t, uint32(77), tknUsage.TotalTokens)
}
}

@@ -596,7 +598,9 @@ func TestOpenAIToAWSBedrockTranslatorV1ChatCompletion_ResponseBody(t *testing.T)
var openAIResp openai.ChatCompletionResponse
err = json.Unmarshal(newBody, &openAIResp)
require.NoError(t, err)
require.Equal(t, uint32(30), usedToken)
require.Equal(t, uint32(10), usedToken.InputTokens)
require.Equal(t, uint32(20), usedToken.OutputTokens)
require.Equal(t, uint32(30), usedToken.TotalTokens)
if !cmp.Equal(openAIResp, tt.output) {
t.Errorf("ConvertOpenAIToBedrock(), diff(got, expected) = %s\n", cmp.Diff(openAIResp, tt.output))
}
53 changes: 34 additions & 19 deletions internal/extproc/translator/openai_openai.go
@@ -13,21 +13,24 @@ import (
"github.com/envoyproxy/ai-gateway/internal/extproc/router"
)

// newOpenAIToOpenAITranslator implements [TranslatorFactory] for OpenAI to OpenAI translation.
func newOpenAIToOpenAITranslator(path string) (Translator, error) {
if path == "/v1/chat/completions" {
return &openAIToOpenAITranslatorV1ChatCompletion{}, nil
} else {
return nil, fmt.Errorf("unsupported path: %s", path)
// newOpenAIToOpenAITranslatorFactory returns a [Factory] for OpenAI to OpenAI translation.
func newOpenAIToOpenAITranslatorFactory(monitorContinuousUsageStats bool) Factory {
return func(path string) (Translator, error) {
if path == "/v1/chat/completions" {
return &openAIToOpenAITranslatorV1ChatCompletion{monitorContinuousUsageStats: monitorContinuousUsageStats}, nil
} else {
return nil, fmt.Errorf("unsupported path: %s", path)
}
}
}

// openAIToOpenAITranslatorV1ChatCompletion implements [Translator] for /v1/chat/completions.
type openAIToOpenAITranslatorV1ChatCompletion struct {
defaultTranslator
stream bool
buffered []byte
bufferingDone bool
stream bool
buffered []byte
bufferingDone bool
monitorContinuousUsageStats bool
}

// RequestBody implements [RequestBody].
@@ -50,36 +53,44 @@ func (o *openAIToOpenAITranslatorV1ChatCompletion) RequestBody(body router.Reque

// ResponseBody implements [Translator.ResponseBody].
func (o *openAIToOpenAITranslatorV1ChatCompletion) ResponseBody(body io.Reader, _ bool) (
headerMutation *extprocv3.HeaderMutation, bodyMutation *extprocv3.BodyMutation, usedToken uint32, err error,
headerMutation *extprocv3.HeaderMutation, bodyMutation *extprocv3.BodyMutation, tknUsage *TokenUsage, err error,
) {
if o.stream {
if !o.bufferingDone {
buf, err := io.ReadAll(body)
if !o.bufferingDone || o.monitorContinuousUsageStats {
// OpenAI's API sends usage info only in the final chunk of a stream (https://platform.openai.com/docs/api-reference/chat/streaming#chat/streaming-usage),
// whereas the vLLM model server can include usage info in every returned chunk.
// To support both behaviors, we check each chunk for usage info.
var buf []byte
buf, err = io.ReadAll(body)
if err != nil {
return nil, nil, 0, fmt.Errorf("failed to read body: %w", err)
return nil, nil, nil, fmt.Errorf("failed to read body: %w", err)
}
o.buffered = append(o.buffered, buf...)
usedToken = o.extractUsageFromBufferEvent()
tknUsage = o.extractUsageFromBufferEvent()
}
return
}
var resp openai.ChatCompletionResponse
if err := json.NewDecoder(body).Decode(&resp); err != nil {
return nil, nil, 0, fmt.Errorf("failed to unmarshal body: %w", err)
return nil, nil, nil, fmt.Errorf("failed to unmarshal body: %w", err)
}
tknUsage = &TokenUsage{
InputTokens: uint32(resp.Usage.PromptTokens), //nolint:gosec
OutputTokens: uint32(resp.Usage.CompletionTokens), //nolint:gosec
TotalTokens: uint32(resp.Usage.TotalTokens), //nolint:gosec
}
usedToken = uint32(resp.Usage.TotalTokens) //nolint:gosec
return
}

var dataPrefix = []byte("data: ")

// extractUsageFromBufferEvent extracts the token usage from the buffered event.
// Once the usage is extracted, it returns the token usage, sets bufferingDone to true, and drops the buffer.
func (o *openAIToOpenAITranslatorV1ChatCompletion) extractUsageFromBufferEvent() (usedToken uint32) {
func (o *openAIToOpenAITranslatorV1ChatCompletion) extractUsageFromBufferEvent() (tknUsage *TokenUsage) {
for {
i := bytes.IndexByte(o.buffered, '\n')
if i == -1 {
return 0
return nil
}
line := o.buffered[:i]
o.buffered = o.buffered[i+1:]
@@ -91,7 +102,11 @@ func (o *openAIToOpenAITranslatorV1ChatCompletion) extractUsageFromBufferEvent()
continue
}
if usage := event.Usage; usage != nil {
usedToken = uint32(usage.TotalTokens) //nolint:gosec
tknUsage = &TokenUsage{
InputTokens: uint32(usage.PromptTokens), //nolint:gosec
OutputTokens: uint32(usage.CompletionTokens), //nolint:gosec
TotalTokens: uint32(usage.TotalTokens), //nolint:gosec
}
o.bufferingDone = true
o.buffered = nil
return
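For reference, when a stream is requested with stream_options.include_usage, OpenAI sends a final chunk with an empty choices array and the usage object; that is the data: line the buffering loop above scans for (values illustrative):

    data: {"id":"chatcmpl-abc123","object":"chat.completion.chunk","choices":[],"usage":{"prompt_tokens":11,"completion_tokens":22,"total_tokens":33}}

    data: [DONE]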