diff --git a/plugins/wasm-go/extensions/ai-proxy/main.go b/plugins/wasm-go/extensions/ai-proxy/main.go index a41a73cd8d..4f08ab82c0 100644 --- a/plugins/wasm-go/extensions/ai-proxy/main.go +++ b/plugins/wasm-go/extensions/ai-proxy/main.go @@ -274,9 +274,6 @@ func getOpenAiApiName(path string) provider.ApiName { if strings.HasSuffix(path, "/v1/embeddings") { return provider.ApiNameEmbeddings } - if strings.HasSuffix(path, "/v1/audio/transcriptions") { - return provider.ApiNameAudioTranscription - } if strings.HasSuffix(path, "/v1/audio/speech") { return provider.ApiNameAudioSpeech } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/claude.go b/plugins/wasm-go/extensions/ai-proxy/provider/claude.go index 470ea4de87..a1412114f5 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/claude.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/claude.go @@ -139,6 +139,9 @@ func (c *claudeProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, } func (c *claudeProvider) TransformRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) { + if apiName != ApiNameChatCompletion { + return c.config.defaultTransformRequestBody(ctx, apiName, body, log) + } request := &chatCompletionRequest{} if err := c.config.parseRequestAndMapModel(ctx, request, body, log); err != nil { return nil, err @@ -148,6 +151,9 @@ func (c *claudeProvider) TransformRequestBody(ctx wrapper.HttpContext, apiName A } func (c *claudeProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) { + if apiName != ApiNameChatCompletion { + return body, nil + } claudeResponse := &claudeTextGenResponse{} if err := json.Unmarshal(body, claudeResponse); err != nil { return nil, fmt.Errorf("unable to unmarshal claude response: %v", err) diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/cohere.go b/plugins/wasm-go/extensions/ai-proxy/provider/cohere.go index 0964051454..931d1b677c 100644 
--- a/plugins/wasm-go/extensions/ai-proxy/provider/cohere.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/cohere.go @@ -107,6 +107,9 @@ func (m *cohereProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiNam } func (m *cohereProvider) TransformRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) { + if apiName != ApiNameChatCompletion { + return m.config.defaultTransformRequestBody(ctx, apiName, body, log) + } request := &chatCompletionRequest{} if err := m.config.parseRequestAndMapModel(ctx, request, body, log); err != nil { return nil, err diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/dify.go b/plugins/wasm-go/extensions/ai-proxy/provider/dify.go index 28b6dec794..2395e97f66 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/dify.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/dify.go @@ -4,13 +4,14 @@ import ( "encoding/json" "errors" "fmt" + "net/http" + "strings" + "time" + "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" - "net/http" - "strings" - "time" ) const ( @@ -83,6 +84,9 @@ func (d *difyProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, b } func (d *difyProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header, log wrapper.Log) ([]byte, error) { + if apiName != ApiNameChatCompletion { + return d.config.defaultTransformRequestBody(ctx, apiName, body, log) + } request := &chatCompletionRequest{} err := d.config.parseRequestAndMapModel(ctx, request, body, log) if err != nil { diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/hunyuan.go b/plugins/wasm-go/extensions/ai-proxy/provider/hunyuan.go index a61dff054d..7aa0a4ae66 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/hunyuan.go +++ 
b/plugins/wasm-go/extensions/ai-proxy/provider/hunyuan.go @@ -294,7 +294,7 @@ func (m *hunyuanProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName // hunyuan 的 TransformRequestBodyHeaders 方法只在 failover 健康检查的时候会调用 func (m *hunyuanProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header, log wrapper.Log) ([]byte, error) { if m.useOpenAICompatibleAPI() { - return body, nil + return m.config.defaultTransformRequestBody(ctx, apiName, body, log) } request := &chatCompletionRequest{} err := m.config.parseRequestAndMapModel(ctx, request, body, log) diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/model.go b/plugins/wasm-go/extensions/ai-proxy/provider/model.go index 51a65555c7..61bef7467b 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/model.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/model.go @@ -19,21 +19,21 @@ const ( ) type chatCompletionRequest struct { - Model string `json:"model"` - Messages []chatMessage `json:"messages"` - MaxTokens int `json:"max_tokens,omitempty"` - FrequencyPenalty float64 `json:"frequency_penalty,omitempty"` - N int `json:"n,omitempty"` - PresencePenalty float64 `json:"presence_penalty,omitempty"` - Seed int `json:"seed,omitempty"` - Stream bool `json:"stream,omitempty"` - StreamOptions *streamOptions `json:"stream_options,omitempty"` - Temperature float64 `json:"temperature,omitempty"` - TopP float64 `json:"top_p,omitempty"` - Tools []tool `json:"tools,omitempty"` - ToolChoice *toolChoice `json:"tool_choice,omitempty"` - User string `json:"user,omitempty"` - Stop []string `json:"stop,omitempty"` + Model string `json:"model"` + Messages []chatMessage `json:"messages"` + MaxTokens int `json:"max_tokens,omitempty"` + FrequencyPenalty float64 `json:"frequency_penalty,omitempty"` + N int `json:"n,omitempty"` + PresencePenalty float64 `json:"presence_penalty,omitempty"` + Seed int `json:"seed,omitempty"` + Stream bool `json:"stream,omitempty"` + 
StreamOptions *streamOptions `json:"stream_options,omitempty"` + Temperature float64 `json:"temperature,omitempty"` + TopP float64 `json:"top_p,omitempty"` + Tools []tool `json:"tools,omitempty"` + ToolChoice *toolChoice `json:"tool_choice,omitempty"` + User string `json:"user,omitempty"` + Stop []string `json:"stop,omitempty"` ResponseFormat map[string]interface{} `json:"response_format,omitempty"` } @@ -230,6 +230,21 @@ func (e *streamEvent) setValue(key, value string) { } } +// https://platform.openai.com/docs/guides/images +type imageGenerationRequest struct { + Model string `json:"model"` + Prompt string `json:"prompt"` + N int `json:"n,omitempty"` + Size string `json:"size,omitempty"` +} + +// https://platform.openai.com/docs/guides/text-to-speech +type audioSpeechRequest struct { + Model string `json:"model"` + Input string `json:"input"` + Voice string `json:"voice"` +} + type embeddingsRequest struct { Input interface{} `json:"input"` Model string `json:"model"` diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/moonshot.go b/plugins/wasm-go/extensions/ai-proxy/provider/moonshot.go index b49800d621..ae7315c459 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/moonshot.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/moonshot.go @@ -84,6 +84,10 @@ func (m *moonshotProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiNam if !m.config.isSupportedAPI(apiName) { return types.ActionContinue, errUnsupportedApiName } + // 非chat类型的请求,不做处理 + if apiName != ApiNameChatCompletion { + return types.ActionContinue, nil + } request := &chatCompletionRequest{} if err := m.config.parseRequestAndMapModel(ctx, request, body, log); err != nil { diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go index 4ff62411bc..6c810f6dbe 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go @@ -1,7 +1,6 @@ 
package provider import ( - "encoding/json" "errors" "math/rand" "net/http" @@ -12,6 +11,7 @@ import ( "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" "github.com/tidwall/gjson" + "github.com/tidwall/sjson" ) type ApiName string @@ -22,11 +22,10 @@ const ( // ApiName 格式 {vendor}/{version}/{apitype} // 表示遵循 厂商/版本/接口类型 的格式 // 目前openai是事实意义上的标准,但是也有其他厂商存在其他任务的一些可能的标准,比如cohere的rerank - ApiNameChatCompletion ApiName = "openai/v1/chatcompletions" - ApiNameEmbeddings ApiName = "openai/v1/embeddings" - ApiNameImageGeneration ApiName = "openai/v1/imagegeneration" - ApiNameAudioSpeech ApiName = "openai/v1/audiospeech" - ApiNameAudioTranscription ApiName = "openai/v1/audiotranscription" + ApiNameChatCompletion ApiName = "openai/v1/chatcompletions" + ApiNameEmbeddings ApiName = "openai/v1/embeddings" + ApiNameImageGeneration ApiName = "openai/v1/imagegeneration" + ApiNameAudioSpeech ApiName = "openai/v1/audiospeech" PathOpenAIChatCompletions = "/v1/chat/completions" PathOpenAIEmbeddings = "/v1/embeddings" @@ -379,23 +378,19 @@ func (c *ProviderConfig) FromJson(json gjson.Result) { c.outputVariable = json.Get("outputVariable").String() c.capabilities = make(map[string]string) - for capability, pathJson := range json.Get("abilities").Map() { + for capability, pathJson := range json.Get("capabilities").Map() { // 过滤掉不受支持的能力 switch capability { case string(ApiNameChatCompletion), string(ApiNameEmbeddings), string(ApiNameImageGeneration), - string(ApiNameAudioSpeech), - string(ApiNameAudioTranscription): + string(ApiNameAudioSpeech): c.capabilities[capability] = pathJson.String() } } } func (c *ProviderConfig) Validate() error { - if c.timeout < 0 { - return errors.New("invalid timeout in config") - } if c.protocol != protocolOpenAI && c.protocol != protocolOriginal { return errors.New("invalid protocol in config") } @@ -528,7 +523,7 @@ func getMappedModel(model string, modelMapping map[string]string, log 
wrapper.Lo } func doGetMappedModel(model string, modelMapping map[string]string, log wrapper.Log) string { - if modelMapping == nil || len(modelMapping) == 0 { + if len(modelMapping) == 0 { return "" } @@ -618,17 +613,21 @@ func (c *ProviderConfig) handleRequestHeaders(provider Provider, ctx wrapper.Htt } } +// defaultTransformRequestBody 默认的请求体转换方法,只做模型映射,用sjson替换模型名称,不用序列化和反序列化,提高性能 func (c *ProviderConfig) defaultTransformRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) { - var request interface{} - if apiName == ApiNameChatCompletion { - request = &chatCompletionRequest{} - } else { - request = &embeddingsRequest{} - } - if err := c.parseRequestAndMapModel(ctx, request, body, log); err != nil { - return nil, err + switch apiName { + case ApiNameChatCompletion: + stream := gjson.GetBytes(body, "stream").Bool() + if stream { + _ = proxywasm.ReplaceHttpRequestHeader("Accept", "text/event-stream") + ctx.SetContext(ctxKeyIsStreaming, true) + } else { + ctx.SetContext(ctxKeyIsStreaming, false) + } } - return json.Marshal(request) + model := gjson.GetBytes(body, "model").String() + ctx.SetContext(ctxKeyOriginalRequestModel, model) + return sjson.SetBytes(body, "model", getMappedModel(model, c.modelMapping, log)) } func (c *ProviderConfig) DefaultTransformResponseHeaders(ctx wrapper.HttpContext, headers http.Header) { diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/qwen.go b/plugins/wasm-go/extensions/ai-proxy/provider/qwen.go index d3b7dc7953..e5650f355f 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/qwen.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/qwen.go @@ -89,10 +89,13 @@ func (m *qwenProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName } func (m *qwenProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header, log wrapper.Log) ([]byte, error) { - if apiName == ApiNameChatCompletion { + switch apiName { + case 
ApiNameChatCompletion: return m.onChatCompletionRequestBody(ctx, body, headers, log) - } else { + case ApiNameEmbeddings: return m.onEmbeddingsRequestBody(ctx, body, log) + default: + return m.config.defaultTransformRequestBody(ctx, apiName, body, log) } }