Skip to content

Commit

Permalink
fix: fix body rewrite error & Fix non chat type model mapping
Browse files Browse the repository at this point in the history
  • Loading branch information
pepesi committed Jan 23, 2025
1 parent 849a53f commit 7025c0e
Show file tree
Hide file tree
Showing 9 changed files with 77 additions and 46 deletions.
3 changes: 0 additions & 3 deletions plugins/wasm-go/extensions/ai-proxy/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -274,9 +274,6 @@ func getOpenAiApiName(path string) provider.ApiName {
if strings.HasSuffix(path, "/v1/embeddings") {
return provider.ApiNameEmbeddings
}
if strings.HasSuffix(path, "/v1/audio/transcriptions") {
return provider.ApiNameAudioTranscription
}
if strings.HasSuffix(path, "/v1/audio/speech") {
return provider.ApiNameAudioSpeech
}
Expand Down
6 changes: 6 additions & 0 deletions plugins/wasm-go/extensions/ai-proxy/provider/claude.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,9 @@ func (c *claudeProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName,
}

func (c *claudeProvider) TransformRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) {
if apiName != ApiNameChatCompletion {
return c.config.defaultTransformRequestBody(ctx, apiName, body, log)
}
request := &chatCompletionRequest{}
if err := c.config.parseRequestAndMapModel(ctx, request, body, log); err != nil {
return nil, err
Expand All @@ -148,6 +151,9 @@ func (c *claudeProvider) TransformRequestBody(ctx wrapper.HttpContext, apiName A
}

func (c *claudeProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) {
if apiName != ApiNameChatCompletion {
return body, nil
}
claudeResponse := &claudeTextGenResponse{}
if err := json.Unmarshal(body, claudeResponse); err != nil {
return nil, fmt.Errorf("unable to unmarshal claude response: %v", err)
Expand Down
3 changes: 3 additions & 0 deletions plugins/wasm-go/extensions/ai-proxy/provider/cohere.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,9 @@ func (m *cohereProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiNam
}

func (m *cohereProvider) TransformRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) {
if apiName != ApiNameChatCompletion {
return m.config.defaultTransformRequestBody(ctx, apiName, body, log)
}
request := &chatCompletionRequest{}
if err := m.config.parseRequestAndMapModel(ctx, request, body, log); err != nil {
return nil, err
Expand Down
10 changes: 7 additions & 3 deletions plugins/wasm-go/extensions/ai-proxy/provider/dify.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@ import (
"encoding/json"
"errors"
"fmt"
"net/http"
"strings"
"time"

"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"net/http"
"strings"
"time"
)

const (
Expand Down Expand Up @@ -83,6 +84,9 @@ func (d *difyProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, b
}

func (d *difyProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header, log wrapper.Log) ([]byte, error) {
if apiName != ApiNameChatCompletion {
return d.config.defaultTransformRequestBody(ctx, apiName, body, log)
}
request := &chatCompletionRequest{}
err := d.config.parseRequestAndMapModel(ctx, request, body, log)
if err != nil {
Expand Down
2 changes: 1 addition & 1 deletion plugins/wasm-go/extensions/ai-proxy/provider/hunyuan.go
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,7 @@ func (m *hunyuanProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName
// hunyuan 的 TransformRequestBodyHeaders 方法只在 failover 健康检查的时候会调用
func (m *hunyuanProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header, log wrapper.Log) ([]byte, error) {
if m.useOpenAICompatibleAPI() {
return body, nil
return m.config.defaultTransformRequestBody(ctx, apiName, body, log)
}
request := &chatCompletionRequest{}
err := m.config.parseRequestAndMapModel(ctx, request, body, log)
Expand Down
45 changes: 30 additions & 15 deletions plugins/wasm-go/extensions/ai-proxy/provider/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,21 +19,21 @@ const (
)

type chatCompletionRequest struct {
Model string `json:"model"`
Messages []chatMessage `json:"messages"`
MaxTokens int `json:"max_tokens,omitempty"`
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
N int `json:"n,omitempty"`
PresencePenalty float64 `json:"presence_penalty,omitempty"`
Seed int `json:"seed,omitempty"`
Stream bool `json:"stream,omitempty"`
StreamOptions *streamOptions `json:"stream_options,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
Tools []tool `json:"tools,omitempty"`
ToolChoice *toolChoice `json:"tool_choice,omitempty"`
User string `json:"user,omitempty"`
Stop []string `json:"stop,omitempty"`
Model string `json:"model"`
Messages []chatMessage `json:"messages"`
MaxTokens int `json:"max_tokens,omitempty"`
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
N int `json:"n,omitempty"`
PresencePenalty float64 `json:"presence_penalty,omitempty"`
Seed int `json:"seed,omitempty"`
Stream bool `json:"stream,omitempty"`
StreamOptions *streamOptions `json:"stream_options,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
Tools []tool `json:"tools,omitempty"`
ToolChoice *toolChoice `json:"tool_choice,omitempty"`
User string `json:"user,omitempty"`
Stop []string `json:"stop,omitempty"`
ResponseFormat map[string]interface{} `json:"response_format,omitempty"`
}

Expand Down Expand Up @@ -230,6 +230,21 @@ func (e *streamEvent) setValue(key, value string) {
}
}

// https://platform.openai.com/docs/guides/images
type imageGenerationRequest struct {
Model string `json:"model"`
Prompt string `json:"prompt"`
N int `json:"n,omitempty"`
Size string `json:"size,omitempty"`
}

// https://platform.openai.com/docs/guides/speech-to-text
type audioSpeechRequest struct {
Model string `json:"model"`
Input string `json:"input"`
Voice string `json:"voice"`
}

type embeddingsRequest struct {
Input interface{} `json:"input"`
Model string `json:"model"`
Expand Down
4 changes: 4 additions & 0 deletions plugins/wasm-go/extensions/ai-proxy/provider/moonshot.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,10 @@ func (m *moonshotProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiNam
if !m.config.isSupportedAPI(apiName) {
return types.ActionContinue, errUnsupportedApiName
}
// 非chat类型的请求,不做处理
if apiName != ApiNameChatCompletion {
return types.ActionContinue, nil
}

request := &chatCompletionRequest{}
if err := m.config.parseRequestAndMapModel(ctx, request, body, log); err != nil {
Expand Down
43 changes: 21 additions & 22 deletions plugins/wasm-go/extensions/ai-proxy/provider/provider.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package provider

import (
"encoding/json"
"errors"
"math/rand"
"net/http"
Expand All @@ -12,6 +11,7 @@ import (
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)

type ApiName string
Expand All @@ -22,11 +22,10 @@ const (
// ApiName 格式 {vendor}/{version}/{apitype}
// 表示遵循 厂商/版本/接口类型 的格式
// 目前openai是事实意义上的标准,但是也有其他厂商存在其他任务的一些可能的标准,比如cohere的rerank
ApiNameChatCompletion ApiName = "openai/v1/chatcompletions"
ApiNameEmbeddings ApiName = "openai/v1/embeddings"
ApiNameImageGeneration ApiName = "openai/v1/imagegeneration"
ApiNameAudioSpeech ApiName = "openai/v1/audiospeech"
ApiNameAudioTranscription ApiName = "openai/v1/audiotranscription"
ApiNameChatCompletion ApiName = "openai/v1/chatcompletions"
ApiNameEmbeddings ApiName = "openai/v1/embeddings"
ApiNameImageGeneration ApiName = "openai/v1/imagegeneration"
ApiNameAudioSpeech ApiName = "openai/v1/audiospeech"

PathOpenAIChatCompletions = "/v1/chat/completions"
PathOpenAIEmbeddings = "/v1/embeddings"
Expand Down Expand Up @@ -379,23 +378,19 @@ func (c *ProviderConfig) FromJson(json gjson.Result) {
c.outputVariable = json.Get("outputVariable").String()

c.capabilities = make(map[string]string)
for capability, pathJson := range json.Get("abilities").Map() {
for capability, pathJson := range json.Get("capabilities").Map() {
// 过滤掉不受支持的能力
switch capability {
case string(ApiNameChatCompletion),
string(ApiNameEmbeddings),
string(ApiNameImageGeneration),
string(ApiNameAudioSpeech),
string(ApiNameAudioTranscription):
string(ApiNameAudioSpeech):
c.capabilities[capability] = pathJson.String()
}
}
}

func (c *ProviderConfig) Validate() error {
if c.timeout < 0 {
return errors.New("invalid timeout in config")
}
if c.protocol != protocolOpenAI && c.protocol != protocolOriginal {
return errors.New("invalid protocol in config")
}
Expand Down Expand Up @@ -528,7 +523,7 @@ func getMappedModel(model string, modelMapping map[string]string, log wrapper.Lo
}

func doGetMappedModel(model string, modelMapping map[string]string, log wrapper.Log) string {
if modelMapping == nil || len(modelMapping) == 0 {
if len(modelMapping) == 0 {
return ""
}

Expand Down Expand Up @@ -618,17 +613,21 @@ func (c *ProviderConfig) handleRequestHeaders(provider Provider, ctx wrapper.Htt
}
}

// defaultTransformRequestBody 默认的请求体转换方法,只做模型映射,用slog替换模型名称,不用序列化和反序列化,提高性能
func (c *ProviderConfig) defaultTransformRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) {
var request interface{}
if apiName == ApiNameChatCompletion {
request = &chatCompletionRequest{}
} else {
request = &embeddingsRequest{}
}
if err := c.parseRequestAndMapModel(ctx, request, body, log); err != nil {
return nil, err
switch apiName {
case ApiNameChatCompletion:
stream := gjson.GetBytes(body, "stream").Bool()
if stream {
_ = proxywasm.ReplaceHttpRequestHeader("Accept", "text/event-stream")
ctx.SetContext(ctxKeyIsStreaming, true)
} else {
ctx.SetContext(ctxKeyIsStreaming, false)
}
}
return json.Marshal(request)
model := gjson.GetBytes(body, "model").String()
ctx.SetContext(ctxKeyOriginalRequestModel, model)
return sjson.SetBytes(body, "model", getMappedModel(model, c.modelMapping, log))
}

func (c *ProviderConfig) DefaultTransformResponseHeaders(ctx wrapper.HttpContext, headers http.Header) {
Expand Down
7 changes: 5 additions & 2 deletions plugins/wasm-go/extensions/ai-proxy/provider/qwen.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,10 +89,13 @@ func (m *qwenProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName
}

func (m *qwenProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header, log wrapper.Log) ([]byte, error) {
if apiName == ApiNameChatCompletion {
switch apiName {
case ApiNameChatCompletion:
return m.onChatCompletionRequestBody(ctx, body, headers, log)
} else {
case ApiNameEmbeddings:
return m.onEmbeddingsRequestBody(ctx, body, log)
default:
return m.config.defaultTransformRequestBody(ctx, apiName, body, log)
}
}

Expand Down

0 comments on commit 7025c0e

Please sign in to comment.