From a84a382f1d2630bfacedb457718e96beaf51e60b Mon Sep 17 00:00:00 2001 From: pepesi Date: Wed, 12 Feb 2025 15:23:44 +0800 Subject: [PATCH] =?UTF-8?q?feature:=20allow=20ai-proxy=20to=20forward=20st?= =?UTF-8?q?andard=20AI=20capabilities=20that=20are=20=E2=80=A6=20(#1704)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- plugins/wasm-go/extensions/ai-proxy/README.md | 3 +- plugins/wasm-go/extensions/ai-proxy/main.go | 42 +++++++--- .../extensions/ai-proxy/provider/ai360.go | 13 ++- .../extensions/ai-proxy/provider/azure.go | 13 ++- .../extensions/ai-proxy/provider/baichuan.go | 17 ++-- .../extensions/ai-proxy/provider/baidu.go | 15 +++- .../extensions/ai-proxy/provider/claude.go | 25 +++++- .../ai-proxy/provider/cloudflare.go | 10 ++- .../extensions/ai-proxy/provider/cohere.go | 21 ++++- .../extensions/ai-proxy/provider/coze.go | 5 ++ .../extensions/ai-proxy/provider/deepl.go | 14 +++- .../extensions/ai-proxy/provider/deepseek.go | 17 +++- .../extensions/ai-proxy/provider/dify.go | 16 +++- .../extensions/ai-proxy/provider/doubao.go | 18 ++++- .../extensions/ai-proxy/provider/gemini.go | 12 ++- .../extensions/ai-proxy/provider/github.go | 19 +++-- .../extensions/ai-proxy/provider/groq.go | 13 ++- .../extensions/ai-proxy/provider/hunyuan.go | 56 ++++++++++--- .../extensions/ai-proxy/provider/minimax.go | 32 ++++---- .../extensions/ai-proxy/provider/mistral.go | 13 ++- .../extensions/ai-proxy/provider/model.go | 45 +++++++---- .../extensions/ai-proxy/provider/moonshot.go | 20 ++++- .../extensions/ai-proxy/provider/ollama.go | 18 +++-- .../extensions/ai-proxy/provider/openai.go | 16 ++-- .../extensions/ai-proxy/provider/provider.go | 81 +++++++++++++++---- .../extensions/ai-proxy/provider/qwen.go | 28 +++++-- .../extensions/ai-proxy/provider/spark.go | 19 ++++- .../extensions/ai-proxy/provider/stepfun.go | 14 +++- .../ai-proxy/provider/together_ai.go | 18 +++-- .../extensions/ai-proxy/provider/yi.go | 13 ++- 
.../extensions/ai-proxy/provider/zhipuai.go | 18 ++++- .../wasm-go/extensions/ai-proxy/util/http.go | 11 +++ 32 files changed, 517 insertions(+), 158 deletions(-) diff --git a/plugins/wasm-go/extensions/ai-proxy/README.md b/plugins/wasm-go/extensions/ai-proxy/README.md index f0af574922..f1470e8ae8 100644 --- a/plugins/wasm-go/extensions/ai-proxy/README.md +++ b/plugins/wasm-go/extensions/ai-proxy/README.md @@ -42,7 +42,8 @@ description: AI 代理插件配置参考 | `customSettings` | array of customSetting | 非必填 | - | 为AI请求指定覆盖或者填充参数 | | `failover` | object | 非必填 | - | 配置 apiToken 的 failover 策略,当 apiToken 不可用时,将其移出 apiToken 列表,待健康检测通过后重新添加回 apiToken 列表 | | `retryOnFailure` | object | 非必填 | - | 当请求失败时立即进行重试 | - +| `capabilities` | map of string | 非必填 | - | 部分provider的部分ai能力原生兼容openai/v1格式,不需要重写,可以直接转发,通过此配置项指定来开启转发, key表示的是采用的厂商协议能力,values表示的真实的厂商该能力的api path, 厂商协议能力当前支持: openai/v1/chatcompletions, openai/v1/embeddings, openai/v1/imagegeneration, openai/v1/audiospeech, cohere/v1/rerank | +| `passthrough` | bool | 非必填 | - | 只要是不支持的API能力都直接转发, 此配置是capabilities配置的放大版本,允许任意api透传,就像没有ai-proxy插件一样 | `context`的配置字段说明如下: | 名称 | 数据类型 | 填写要求 | 默认值 | 描述 | diff --git a/plugins/wasm-go/extensions/ai-proxy/main.go b/plugins/wasm-go/extensions/ai-proxy/main.go index 220243f989..1dcea5878f 100644 --- a/plugins/wasm-go/extensions/ai-proxy/main.go +++ b/plugins/wasm-go/extensions/ai-proxy/main.go @@ -78,7 +78,7 @@ func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConf rawPath := ctx.Path() path, _ := url.Parse(rawPath) - apiName := getOpenAiApiName(path.Path) + apiName := getApiName(path.Path) providerConfig := pluginConfig.GetProviderConfig() if providerConfig.IsOriginal() { if handler, ok := activeProvider.(provider.ApiNameHandler); ok { @@ -103,20 +103,25 @@ func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConf // Set the apiToken for the current request. 
providerConfig.SetApiTokenInUse(ctx, log) - hasRequestBody := wrapper.HasRequestBody() err := handler.OnRequestHeaders(ctx, apiName, log) - if err == nil { - if hasRequestBody { - proxywasm.RemoveHttpRequestHeader("Content-Length") - ctx.SetRequestBodyBufferLimit(defaultMaxBodyBytes) - // Delay the header processing to allow changing in OnRequestBody - return types.HeaderStopIteration + if err != nil { + if providerConfig.PassthroughUnsupportedAPI() { + log.Warnf("[onHttpRequestHeader] passthrough unsupported API: %v", err) + ctx.DontReadRequestBody() + return types.ActionContinue } - ctx.DontReadRequestBody() + util.ErrorHandler("ai-proxy.proc_req_headers_failed", fmt.Errorf("failed to process request headers: %v", err)) return types.ActionContinue } - util.ErrorHandler("ai-proxy.proc_req_headers_failed", fmt.Errorf("failed to process request headers: %v", err)) + hasRequestBody := wrapper.HasRequestBody() + if hasRequestBody { + proxywasm.RemoveHttpRequestHeader("Content-Length") + ctx.SetRequestBodyBufferLimit(defaultMaxBodyBytes) + // Delay the header processing to allow changing in OnRequestBody + return types.HeaderStopIteration + } + ctx.DontReadRequestBody() return types.ActionContinue } @@ -151,6 +156,10 @@ func onHttpRequestBody(ctx wrapper.HttpContext, pluginConfig config.PluginConfig if err == nil { return action } + if pluginConfig.GetProviderConfig().PassthroughUnsupportedAPI() { + log.Warnf("[onHttpRequestBody] passthrough unsupported API: %v", err) + return types.ActionContinue + } util.ErrorHandler("ai-proxy.proc_req_body_failed", fmt.Errorf("failed to process request body: %v", err)) } return types.ActionContinue @@ -267,12 +276,23 @@ func checkStream(ctx wrapper.HttpContext, log wrapper.Log) { } } -func getOpenAiApiName(path string) provider.ApiName { +func getApiName(path string) provider.ApiName { + // openai style if strings.HasSuffix(path, "/v1/chat/completions") { return provider.ApiNameChatCompletion } if strings.HasSuffix(path, 
"/v1/embeddings") { return provider.ApiNameEmbeddings } + if strings.HasSuffix(path, "/v1/audio/speech") { + return provider.ApiNameAudioSpeech + } + if strings.HasSuffix(path, "/v1/images/generations") { + return provider.ApiNameImageGeneration + } + // cohere style + if strings.HasSuffix(path, "/v1/rerank") { + return provider.ApiNameCohereV1Rerank + } return "" } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/ai360.go b/plugins/wasm-go/extensions/ai-proxy/provider/ai360.go index 5a8f4d70cd..57b092cd82 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/ai360.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/ai360.go @@ -22,6 +22,13 @@ type ai360Provider struct { contextCache *contextCache } +func (m *ai360ProviderInitializer) DefaultCapabilities() map[string]string { + return map[string]string{ + string(ApiNameChatCompletion): PathOpenAIChatCompletions, + string(ApiNameEmbeddings): PathOpenAIEmbeddings, + } +} + func (m *ai360ProviderInitializer) ValidateConfig(config *ProviderConfig) error { if config.apiTokens == nil || len(config.apiTokens) == 0 { return errors.New("no apiToken found in provider config") @@ -30,6 +37,7 @@ func (m *ai360ProviderInitializer) ValidateConfig(config *ProviderConfig) error } func (m *ai360ProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { + config.setDefaultCapabilities(m.DefaultCapabilities()) return &ai360Provider{ config: config, contextCache: createContextCache(&config), @@ -41,7 +49,7 @@ func (m *ai360Provider) GetProviderType() string { } func (m *ai360Provider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { - if apiName != ApiNameChatCompletion && apiName != ApiNameEmbeddings { + if !m.config.isSupportedAPI(apiName) { return errUnsupportedApiName } m.config.handleRequestHeaders(m, ctx, apiName, log) @@ -50,7 +58,7 @@ func (m *ai360Provider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNam } func (m *ai360Provider) OnRequestBody(ctx 
wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { - if apiName != ApiNameChatCompletion && apiName != ApiNameEmbeddings { + if !m.config.isSupportedAPI(apiName) { return types.ActionContinue, errUnsupportedApiName } return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log) @@ -58,5 +66,6 @@ func (m *ai360Provider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, func (m *ai360Provider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { util.OverwriteRequestHostHeader(headers, ai360Domain) + util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), m.config.capabilities) util.OverwriteRequestAuthorizationHeader(headers, m.config.GetApiTokenInUse(ctx)) } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/azure.go b/plugins/wasm-go/extensions/ai-proxy/provider/azure.go index 4c107edf17..5fcc378d47 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/azure.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/azure.go @@ -15,6 +15,14 @@ import ( type azureProviderInitializer struct { } +func (m *azureProviderInitializer) DefaultCapabilities() map[string]string { + return map[string]string{ + // TODO: azure's pattern is the same as openai, just need to handle the prefix, can be done in TransformRequestHeaders to support general capabilities + string(ApiNameChatCompletion): PathOpenAIChatCompletions, + string(ApiNameEmbeddings): PathOpenAIEmbeddings, + } +} + func (m *azureProviderInitializer) ValidateConfig(config *ProviderConfig) error { if config.azureServiceUrl == "" { return errors.New("missing azureServiceUrl in provider config") @@ -35,6 +43,7 @@ func (m *azureProviderInitializer) CreateProvider(config ProviderConfig) (Provid } else { serviceUrl = u } + config.setDefaultCapabilities(m.DefaultCapabilities()) return &azureProvider{ config: config, serviceUrl: serviceUrl, @@ -54,7 +63,7 @@ func (m *azureProvider) 
GetProviderType() string { } func (m *azureProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { - if apiName != ApiNameChatCompletion { + if !m.config.isSupportedAPI(apiName) { return errUnsupportedApiName } m.config.handleRequestHeaders(m, ctx, apiName, log) @@ -62,7 +71,7 @@ func (m *azureProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNam } func (m *azureProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { - if apiName != ApiNameChatCompletion { + if !m.config.isSupportedAPI(apiName) { return types.ActionContinue, errUnsupportedApiName } return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log) diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/baichuan.go b/plugins/wasm-go/extensions/ai-proxy/provider/baichuan.go index 3b272e4f37..d04c5c7d85 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/baichuan.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/baichuan.go @@ -12,8 +12,7 @@ import ( // baichuanProvider is the provider for baichuan Ai service. 
const ( - baichuanDomain = "api.baichuan-ai.com" - baichuanChatCompletionPath = "/v1/chat/completions" + baichuanDomain = "api.baichuan-ai.com" ) type baichuanProviderInitializer struct { @@ -26,7 +25,15 @@ func (m *baichuanProviderInitializer) ValidateConfig(config *ProviderConfig) err return nil } +func (m *baichuanProviderInitializer) DefaultCapabilities() map[string]string { + return map[string]string{ + string(ApiNameChatCompletion): PathOpenAIChatCompletions, + string(ApiNameEmbeddings): PathOpenAIEmbeddings, + } +} + func (m *baichuanProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { + config.setDefaultCapabilities(m.DefaultCapabilities()) return &baichuanProvider{ config: config, contextCache: createContextCache(&config), @@ -43,7 +50,7 @@ func (m *baichuanProvider) GetProviderType() string { } func (m *baichuanProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { - if apiName != ApiNameChatCompletion { + if !m.config.isSupportedAPI(apiName) { return errUnsupportedApiName } m.config.handleRequestHeaders(m, ctx, apiName, log) @@ -51,14 +58,14 @@ func (m *baichuanProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName Api } func (m *baichuanProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { - if apiName != ApiNameChatCompletion { + if !m.config.isSupportedAPI(apiName) { return types.ActionContinue, errUnsupportedApiName } return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log) } func (m *baichuanProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { - util.OverwriteRequestPathHeader(headers, baichuanChatCompletionPath) + util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), m.config.capabilities) util.OverwriteRequestHostHeader(headers, baichuanDomain) util.OverwriteRequestAuthorizationHeader(headers, "Bearer 
"+m.config.GetApiTokenInUse(ctx)) headers.Del("Content-Length") diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/baidu.go b/plugins/wasm-go/extensions/ai-proxy/provider/baidu.go index f541d31fec..27bf5aaecc 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/baidu.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/baidu.go @@ -14,6 +14,7 @@ import ( const ( baiduDomain = "qianfan.baidubce.com" baiduChatCompletionPath = "/v2/chat/completions" + baiduEmbeddings = "/v2/embeddings" ) type baiduProviderInitializer struct{} @@ -25,7 +26,15 @@ func (g *baiduProviderInitializer) ValidateConfig(config *ProviderConfig) error return nil } +func (g *baiduProviderInitializer) DefaultCapabilities() map[string]string { + return map[string]string{ + string(ApiNameChatCompletion): baiduChatCompletionPath, + string(ApiNameEmbeddings): baiduEmbeddings, + } +} + func (g *baiduProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { + config.setDefaultCapabilities(g.DefaultCapabilities()) return &baiduProvider{ config: config, contextCache: createContextCache(&config), @@ -42,7 +51,7 @@ func (g *baiduProvider) GetProviderType() string { } func (g *baiduProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { - if apiName != ApiNameChatCompletion { + if !g.config.isSupportedAPI(apiName) { return errUnsupportedApiName } g.config.handleRequestHeaders(g, ctx, apiName, log) @@ -50,14 +59,14 @@ func (g *baiduProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNam } func (g *baiduProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { - if apiName != ApiNameChatCompletion { + if !g.config.isSupportedAPI(apiName) { return types.ActionContinue, errUnsupportedApiName } return g.config.handleRequestBody(g, g.contextCache, ctx, apiName, body, log) } func (g *baiduProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers 
http.Header, log wrapper.Log) { - util.OverwriteRequestPathHeader(headers, baiduChatCompletionPath) + util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), g.config.capabilities) util.OverwriteRequestHostHeader(headers, baiduDomain) util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+g.config.GetApiTokenInUse(ctx)) headers.Del("Content-Length") diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/claude.go b/plugins/wasm-go/extensions/ai-proxy/provider/claude.go index 9be84cc44e..8d75f5cae0 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/claude.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/claude.go @@ -85,7 +85,16 @@ func (c *claudeProviderInitializer) ValidateConfig(config *ProviderConfig) error return nil } +func (c *claudeProviderInitializer) DefaultCapabilities() map[string]string { + return map[string]string{ + string(ApiNameChatCompletion): claudeChatCompletionPath, + // docs: https://docs.anthropic.com/en/docs/build-with-claude/embeddings#voyage-http-api + string(ApiNameEmbeddings): PathOpenAIEmbeddings, + } +} + func (c *claudeProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { + config.setDefaultCapabilities(c.DefaultCapabilities()) return &claudeProvider{ config: config, contextCache: createContextCache(&config), @@ -102,7 +111,7 @@ func (c *claudeProvider) GetProviderType() string { } func (c *claudeProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { - if apiName != ApiNameChatCompletion { + if !c.config.isSupportedAPI(apiName) { return errUnsupportedApiName } c.config.handleRequestHeaders(c, ctx, apiName, log) @@ -110,7 +119,7 @@ func (c *claudeProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNa } func (c *claudeProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { - util.OverwriteRequestPathHeader(headers, claudeChatCompletionPath) + 
util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), c.config.capabilities) util.OverwriteRequestHostHeader(headers, claudeDomain) headers.Set("x-api-key", c.config.GetApiTokenInUse(ctx)) @@ -123,13 +132,16 @@ func (c *claudeProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiNam } func (c *claudeProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { - if apiName != ApiNameChatCompletion { + if !c.config.isSupportedAPI(apiName) { return types.ActionContinue, errUnsupportedApiName } return c.config.handleRequestBody(c, c.contextCache, ctx, apiName, body, log) } func (c *claudeProvider) TransformRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) { + if apiName != ApiNameChatCompletion { + return c.config.defaultTransformRequestBody(ctx, apiName, body, log) + } request := &chatCompletionRequest{} if err := c.config.parseRequestAndMapModel(ctx, request, body, log); err != nil { return nil, err @@ -139,6 +151,9 @@ func (c *claudeProvider) TransformRequestBody(ctx wrapper.HttpContext, apiName A } func (c *claudeProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) { + if apiName != ApiNameChatCompletion { + return body, nil + } claudeResponse := &claudeTextGenResponse{} if err := json.Unmarshal(body, claudeResponse); err != nil { return nil, fmt.Errorf("unable to unmarshal claude response: %v", err) @@ -154,6 +169,10 @@ func (c *claudeProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name A if isLastChunk || len(chunk) == 0 { return nil, nil } + // only process the response from chat completion, skip other responses + if name != ApiNameChatCompletion { + return chunk, nil + } responseBuilder := &strings.Builder{} lines := strings.Split(string(chunk), "\n") diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/cloudflare.go 
b/plugins/wasm-go/extensions/ai-proxy/provider/cloudflare.go index e191b89f37..22b9cd4286 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/cloudflare.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/cloudflare.go @@ -25,8 +25,14 @@ func (c *cloudflareProviderInitializer) ValidateConfig(config *ProviderConfig) e } return nil } +func (c *cloudflareProviderInitializer) DefaultCapabilities() map[string]string { + return map[string]string{ + string(ApiNameChatCompletion): cloudflareChatCompletionPath, + } +} func (c *cloudflareProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { + config.setDefaultCapabilities(c.DefaultCapabilities()) return &cloudflareProvider{ config: config, contextCache: createContextCache(&config), @@ -43,7 +49,7 @@ func (c *cloudflareProvider) GetProviderType() string { } func (c *cloudflareProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { - if apiName != ApiNameChatCompletion { + if !c.config.isSupportedAPI(apiName) { return errUnsupportedApiName } c.config.handleRequestHeaders(c, ctx, apiName, log) @@ -51,7 +57,7 @@ func (c *cloudflareProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName A } func (c *cloudflareProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { - if apiName != ApiNameChatCompletion { + if !c.config.isSupportedAPI(apiName) { return types.ActionContinue, errUnsupportedApiName } return c.config.handleRequestBody(c, c.contextCache, ctx, apiName, body, log) diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/cohere.go b/plugins/wasm-go/extensions/ai-proxy/provider/cohere.go index 7bee5efd58..a21964e497 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/cohere.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/cohere.go @@ -12,8 +12,10 @@ import ( ) const ( - cohereDomain = "api.cohere.com" + cohereDomain = "api.cohere.com" + // TODO: support more capabilities, 
upgrade to v2, docs: https://docs.cohere.com/v2/reference/chat cohereChatCompletionPath = "/v1/chat" + cohereRerankPath = "/v1/rerank" ) type cohereProviderInitializer struct{} @@ -25,7 +27,15 @@ func (m *cohereProviderInitializer) ValidateConfig(config *ProviderConfig) error return nil } +func (m *cohereProviderInitializer) DefaultCapabilities() map[string]string { + return map[string]string{ + string(ApiNameChatCompletion): cohereChatCompletionPath, + string(ApiNameCohereV1Rerank): cohereRerankPath, + } +} + func (m *cohereProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { + config.setDefaultCapabilities(m.DefaultCapabilities()) return &cohereProvider{ config: config, contextCache: createContextCache(&config), @@ -56,7 +66,7 @@ func (m *cohereProvider) GetProviderType() string { } func (m *cohereProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { - if apiName != ApiNameChatCompletion { + if !m.config.isSupportedAPI(apiName) { return errUnsupportedApiName } m.config.handleRequestHeaders(m, ctx, apiName, log) @@ -64,7 +74,7 @@ func (m *cohereProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNa } func (m *cohereProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { - if apiName != ApiNameChatCompletion { + if !m.config.isSupportedAPI(apiName) { return types.ActionContinue, errUnsupportedApiName } return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log) @@ -90,13 +100,16 @@ func (m *cohereProvider) buildCohereRequest(origin *chatCompletionRequest) *cohe } func (m *cohereProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { - util.OverwriteRequestPathHeader(headers, cohereChatCompletionPath) + util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), m.config.capabilities) util.OverwriteRequestHostHeader(headers, cohereDomain) 
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx)) headers.Del("Content-Length") } func (m *cohereProvider) TransformRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) { + if apiName != ApiNameChatCompletion { + return m.config.defaultTransformRequestBody(ctx, apiName, body, log) + } request := &chatCompletionRequest{} if err := m.config.parseRequestAndMapModel(ctx, request, body, log); err != nil { return nil, err diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/coze.go b/plugins/wasm-go/extensions/ai-proxy/provider/coze.go index 4e30ec27c3..8626b226af 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/coze.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/coze.go @@ -21,7 +21,12 @@ func (m *cozeProviderInitializer) ValidateConfig(config *ProviderConfig) error { return nil } +func (m *cozeProviderInitializer) DefaultCapabilities() map[string]string { + return map[string]string{} +} + func (m *cozeProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { + config.setDefaultCapabilities(m.DefaultCapabilities()) return &cozeProvider{ config: config, contextCache: createContextCache(&config), diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/deepl.go b/plugins/wasm-go/extensions/ai-proxy/provider/deepl.go index 192904ecc8..812bd32557 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/deepl.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/deepl.go @@ -64,7 +64,14 @@ func (d *deeplProviderInitializer) ValidateConfig(config *ProviderConfig) error return nil } +func (d *deeplProviderInitializer) DefaultCapabilities() map[string]string { + return map[string]string{ + string(ApiNameChatCompletion): deeplChatCompletionPath, + } +} + func (d *deeplProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { + config.setDefaultCapabilities(d.DefaultCapabilities()) return &deeplProvider{ config: config, 
contextCache: createContextCache(&config), @@ -76,7 +83,7 @@ func (d *deeplProvider) GetProviderType() string { } func (d *deeplProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { - if apiName != ApiNameChatCompletion { + if !d.config.isSupportedAPI(apiName) { return errUnsupportedApiName } d.config.handleRequestHeaders(d, ctx, apiName, log) @@ -89,7 +96,7 @@ func (d *deeplProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName } func (d *deeplProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { - if apiName != ApiNameChatCompletion { + if !d.config.isSupportedAPI(apiName) { return types.ActionContinue, errUnsupportedApiName } return d.config.handleRequestBody(d, d.contextCache, ctx, apiName, body, log) @@ -112,6 +119,9 @@ func (d *deeplProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, api } func (d *deeplProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) { + if apiName != ApiNameChatCompletion { + return body, nil + } deeplResponse := &deeplResponse{} if err := json.Unmarshal(body, deeplResponse); err != nil { return nil, fmt.Errorf("unable to unmarshal deepl response: %v", err) diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/deepseek.go b/plugins/wasm-go/extensions/ai-proxy/provider/deepseek.go index b6a842f1c6..c8eca82a5c 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/deepseek.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/deepseek.go @@ -12,7 +12,9 @@ import ( // deepseekProvider is the provider for deepseek Ai service. 
const ( - deepseekDomain = "api.deepseek.com" + deepseekDomain = "api.deepseek.com" + // TODO: docs: https://api-docs.deepseek.com/api/create-chat-completion + // according to the docs, the path should be /chat/completions, need to be verified deepseekChatCompletionPath = "/v1/chat/completions" ) @@ -26,7 +28,14 @@ func (m *deepseekProviderInitializer) ValidateConfig(config *ProviderConfig) err return nil } +func (m *deepseekProviderInitializer) DefaultCapabilities() map[string]string { + return map[string]string{ + string(ApiNameChatCompletion): deepseekChatCompletionPath, + } +} + func (m *deepseekProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { + config.setDefaultCapabilities(m.DefaultCapabilities()) return &deepseekProvider{ config: config, contextCache: createContextCache(&config), @@ -43,7 +52,7 @@ func (m *deepseekProvider) GetProviderType() string { } func (m *deepseekProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { - if apiName != ApiNameChatCompletion { + if !m.config.isSupportedAPI(apiName) { return errUnsupportedApiName } m.config.handleRequestHeaders(m, ctx, apiName, log) @@ -51,14 +60,14 @@ func (m *deepseekProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName Api } func (m *deepseekProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { - if apiName != ApiNameChatCompletion { + if !m.config.isSupportedAPI(apiName) { return types.ActionContinue, errUnsupportedApiName } return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log) } func (m *deepseekProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { - util.OverwriteRequestPathHeader(headers, deepseekChatCompletionPath) + util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), m.config.capabilities) util.OverwriteRequestHostHeader(headers, deepseekDomain)
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx)) headers.Del("Content-Length") diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/dify.go b/plugins/wasm-go/extensions/ai-proxy/provider/dify.go index 5c21cc3b63..b93461a0c6 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/dify.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/dify.go @@ -4,13 +4,14 @@ import ( "encoding/json" "errors" "fmt" + "net/http" + "strings" + "time" + "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" - "net/http" - "strings" - "time" ) const ( @@ -83,6 +84,9 @@ func (d *difyProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, b } func (d *difyProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header, log wrapper.Log) ([]byte, error) { + if apiName != ApiNameChatCompletion { + return d.config.defaultTransformRequestBody(ctx, apiName, body, log) + } request := &chatCompletionRequest{} err := d.config.parseRequestAndMapModel(ctx, request, body, log) if err != nil { @@ -95,6 +99,9 @@ func (d *difyProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiN } func (d *difyProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) { + if apiName != ApiNameChatCompletion { + return body, nil + } difyResponse := &DifyChatResponse{} if err := json.Unmarshal(body, difyResponse); err != nil { return nil, fmt.Errorf("unable to unmarshal dify response: %v", err) @@ -146,6 +153,9 @@ func (d *difyProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name Api if isLastChunk || len(chunk) == 0 { return nil, nil } + if name != ApiNameChatCompletion { + return chunk, nil + } // sample event response: // data: {"event": 
"agent_thought", "id": "8dcf3648-fbad-407a-85dd-73a6f43aeb9f", "task_id": "9cf1ddd7-f94b-459b-b942-b77b26c59e9b", "message_id": "1fb10045-55fd-4040-99e6-d048d07cbad3", "position": 1, "thought": "", "observation": "", "tool": "", "tool_input": "", "created_at": 1705639511, "message_files": [], "conversation_id": "c216c595-2d89-438c-b33c-aae5ddddd142"} diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/doubao.go b/plugins/wasm-go/extensions/ai-proxy/provider/doubao.go index ed8b3b18bf..a896078e12 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/doubao.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/doubao.go @@ -13,6 +13,7 @@ import ( const ( doubaoDomain = "ark.cn-beijing.volces.com" doubaoChatCompletionPath = "/api/v3/chat/completions" + doubaoEmbeddingsPath = "/api/v3/embeddings" ) type doubaoProviderInitializer struct{} @@ -24,7 +25,15 @@ func (m *doubaoProviderInitializer) ValidateConfig(config *ProviderConfig) error return nil } +func (m *doubaoProviderInitializer) DefaultCapabilities() map[string]string { + return map[string]string{ + string(ApiNameChatCompletion): doubaoChatCompletionPath, + string(ApiNameEmbeddings): doubaoEmbeddingsPath, + } +} + func (m *doubaoProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { + config.setDefaultCapabilities(m.DefaultCapabilities()) return &doubaoProvider{ config: config, contextCache: createContextCache(&config), @@ -41,7 +50,7 @@ func (m *doubaoProvider) GetProviderType() string { } func (m *doubaoProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { - if apiName != ApiNameChatCompletion { + if !m.config.isSupportedAPI(apiName) { return errUnsupportedApiName } m.config.handleRequestHeaders(m, ctx, apiName, log) @@ -49,14 +58,14 @@ func (m *doubaoProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNa } func (m *doubaoProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) 
(types.Action, error) { - if apiName != ApiNameChatCompletion { + if !m.config.isSupportedAPI(apiName) { return types.ActionContinue, errUnsupportedApiName } return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log) } func (m *doubaoProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { - util.OverwriteRequestPathHeader(headers, doubaoChatCompletionPath) + util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), m.config.capabilities) util.OverwriteRequestHostHeader(headers, doubaoDomain) util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx)) headers.Del("Content-Length") @@ -66,5 +75,8 @@ func (m *doubaoProvider) GetApiName(path string) ApiName { if strings.Contains(path, doubaoChatCompletionPath) { return ApiNameChatCompletion } + if strings.Contains(path, doubaoEmbeddingsPath) { + return ApiNameEmbeddings + } return "" } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/gemini.go b/plugins/wasm-go/extensions/ai-proxy/provider/gemini.go index 1017b0f028..1f8d877ea1 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/gemini.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/gemini.go @@ -35,7 +35,12 @@ func (g *geminiProviderInitializer) ValidateConfig(config *ProviderConfig) error return nil } +func (g *geminiProviderInitializer) DefaultCapabilities() map[string]string { + return map[string]string{} +} + func (g *geminiProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { + config.setDefaultCapabilities(g.DefaultCapabilities()) return &geminiProvider{ config: config, contextCache: createContextCache(&config), @@ -52,7 +57,7 @@ func (g *geminiProvider) GetProviderType() string { } func (g *geminiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { - if apiName != ApiNameChatCompletion && apiName != ApiNameEmbeddings { + if 
!g.config.isSupportedAPI(apiName) { return errUnsupportedApiName } g.config.handleRequestHeaders(g, ctx, apiName, log) @@ -66,7 +71,7 @@ func (g *geminiProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiNam } func (g *geminiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { - if apiName != ApiNameChatCompletion && apiName != ApiNameEmbeddings { + if !g.config.isSupportedAPI(apiName) { return types.ActionContinue, errUnsupportedApiName } return g.config.handleRequestBody(g, g.contextCache, ctx, apiName, body, log) @@ -110,6 +115,9 @@ func (g *geminiProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name A if isLastChunk || len(chunk) == 0 { return nil, nil } + if name != ApiNameChatCompletion { + return chunk, nil + } // sample end event response: // data: {"candidates": [{"content": {"parts": [{"text": "我是 Gemini,一个大型多模态模型,由 Google 训练。我的职责是尽我所能帮助您,并尽力提供全面且信息丰富的答复。"}],"role": "model"},"finishReason": "STOP","index": 0,"safetyRatings": [{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_HATE_SPEECH","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_HARASSMENT","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_DANGEROUS_CONTENT","probability": "NEGLIGIBLE"}]}],"usageMetadata": {"promptTokenCount": 2,"candidatesTokenCount": 35,"totalTokenCount": 37}} responseBuilder := &strings.Builder{} diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/github.go b/plugins/wasm-go/extensions/ai-proxy/provider/github.go index fb5649a673..e8a05cc1c9 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/github.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/github.go @@ -32,7 +32,15 @@ func (m *githubProviderInitializer) ValidateConfig(config *ProviderConfig) error return nil } +func (m *githubProviderInitializer) DefaultCapabilities() map[string]string { + return map[string]string{ + string(ApiNameChatCompletion): 
githubCompletionPath, + string(ApiNameEmbeddings): githubEmbeddingPath, + } +} + func (m *githubProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { + config.setDefaultCapabilities(m.DefaultCapabilities()) return &githubProvider{ config: config, contextCache: createContextCache(&config), @@ -44,7 +52,7 @@ func (m *githubProvider) GetProviderType() string { } func (m *githubProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { - if apiName != ApiNameChatCompletion && apiName != ApiNameEmbeddings { + if !m.config.isSupportedAPI(apiName) { return errUnsupportedApiName } m.config.handleRequestHeaders(m, ctx, apiName, log) @@ -53,7 +61,7 @@ func (m *githubProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNa } func (m *githubProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { - if apiName != ApiNameChatCompletion && apiName != ApiNameEmbeddings { + if !m.config.isSupportedAPI(apiName) { return types.ActionContinue, errUnsupportedApiName } return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log) @@ -61,12 +69,7 @@ func (m *githubProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, func (m *githubProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { util.OverwriteRequestHostHeader(headers, githubDomain) - if apiName == ApiNameChatCompletion { - util.OverwriteRequestPathHeader(headers, githubCompletionPath) - } - if apiName == ApiNameEmbeddings { - util.OverwriteRequestPathHeader(headers, githubEmbeddingPath) - } + util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), m.config.capabilities) util.OverwriteRequestAuthorizationHeader(headers, m.config.GetApiTokenInUse(ctx)) } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/groq.go b/plugins/wasm-go/extensions/ai-proxy/provider/groq.go index 04e29500df..c415a707b8 
100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/groq.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/groq.go @@ -25,7 +25,14 @@ func (g *groqProviderInitializer) ValidateConfig(config *ProviderConfig) error { return nil } +func (g *groqProviderInitializer) DefaultCapabilities() map[string]string { + return map[string]string{ + string(ApiNameChatCompletion): groqChatCompletionPath, + } +} + func (g *groqProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { + config.setDefaultCapabilities(g.DefaultCapabilities()) return &groqProvider{ config: config, contextCache: createContextCache(&config), @@ -42,7 +49,7 @@ func (g *groqProvider) GetProviderType() string { } func (g *groqProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { - if apiName != ApiNameChatCompletion { + if !g.config.isSupportedAPI(apiName) { return errUnsupportedApiName } g.config.handleRequestHeaders(g, ctx, apiName, log) @@ -50,14 +57,14 @@ func (g *groqProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName } func (g *groqProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { - if apiName != ApiNameChatCompletion { + if !g.config.isSupportedAPI(apiName) { return types.ActionContinue, errUnsupportedApiName } return g.config.handleRequestBody(g, g.contextCache, ctx, apiName, body, log) } func (g *groqProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { - util.OverwriteRequestPathHeader(headers, groqChatCompletionPath) + util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), g.config.capabilities) util.OverwriteRequestHostHeader(headers, groqDomain) util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+g.config.GetApiTokenInUse(ctx)) headers.Del("Content-Length") diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/hunyuan.go 
b/plugins/wasm-go/extensions/ai-proxy/provider/hunyuan.go index 583c838c2c..fc3fddca2f 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/hunyuan.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/hunyuan.go @@ -39,6 +39,11 @@ const ( hunyuanAuthKeyLen = 32 hunyuanAuthIdLen = 36 + + // docs: https://cloud.tencent.com/document/product/1729/111007 + hunyuanOpenAiDomain = "api.hunyuan.cloud.tencent.com" + hunyuanOpenAiRequestPath = "/v1/chat/completions" + hunyuanOpenAiEmbeddings = "/v1/embeddings" ) type hunyuanProviderInitializer struct { @@ -86,6 +91,10 @@ type hunyuanChatMessage struct { } func (m *hunyuanProviderInitializer) ValidateConfig(config *ProviderConfig) error { + // 允许 hunyuanauthid 和 hunyuanauthkey 为空, 当他们都为空的时候,认为是使用openai的 兼容接口 + if len(config.hunyuanAuthId) == 0 && len(config.hunyuanAuthKey) == 0 { + return nil + } // 校验hunyuan id 和 key的合法性 if len(config.hunyuanAuthId) != hunyuanAuthIdLen || len(config.hunyuanAuthKey) != hunyuanAuthKeyLen { return errors.New("hunyuanAuthId / hunyuanAuthKey is illegal in config file") @@ -93,7 +102,15 @@ func (m *hunyuanProviderInitializer) ValidateConfig(config *ProviderConfig) erro return nil } +func (m *hunyuanProviderInitializer) DefaultCapabilities() map[string]string { + return map[string]string{ + string(ApiNameChatCompletion): hunyuanOpenAiRequestPath, + string(ApiNameEmbeddings): hunyuanOpenAiEmbeddings, + } +} + func (m *hunyuanProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { + config.setDefaultCapabilities(m.DefaultCapabilities()) return &hunyuanProvider{ config: config, client: wrapper.NewClusterClient(wrapper.RouteCluster{ @@ -114,8 +131,12 @@ func (m *hunyuanProvider) GetProviderType() string { return providerTypeHunyuan } +func (m *hunyuanProvider) useOpenAICompatibleAPI() bool { + return len(m.config.hunyuanAuthId) == 0 && len(m.config.hunyuanAuthKey) == 0 +} + func (m *hunyuanProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) 
error { - if apiName != ApiNameChatCompletion { + if !m.config.isSupportedAPI(apiName) { return errUnsupportedApiName } m.config.handleRequestHeaders(m, ctx, apiName, log) @@ -124,19 +145,27 @@ func (m *hunyuanProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiN } func (m *hunyuanProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { - util.OverwriteRequestHostHeader(headers, hunyuanDomain) - util.OverwriteRequestPathHeader(headers, hunyuanRequestPath) - - // 添加 hunyuan 需要的自定义字段 - headers.Set(actionKey, hunyuanChatCompletionTCAction) - headers.Set(versionKey, versionValue) + if m.useOpenAICompatibleAPI() { + util.OverwriteRequestHostHeader(headers, hunyuanOpenAiDomain) + util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), m.config.capabilities) + util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx)) + } else { + util.OverwriteRequestHostHeader(headers, hunyuanDomain) + util.OverwriteRequestPathHeader(headers, hunyuanRequestPath) + // 添加 hunyuan 需要的自定义字段 + headers.Set(actionKey, hunyuanChatCompletionTCAction) + headers.Set(versionKey, versionValue) + } } // hunyuan 的 OnRequestBody 逻辑中包含了对 headers 签名的逻辑,并且插入 context 以后还要重新计算签名,因此无法复用 handleRequestBody 方法 func (m *hunyuanProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { - if apiName != ApiNameChatCompletion { + if !m.config.isSupportedAPI(apiName) { return types.ActionContinue, errUnsupportedApiName } + if m.useOpenAICompatibleAPI() { + return types.ActionContinue, nil + } // 为header添加时间戳字段 (因为需要根据body进行签名时依赖时间戳,故于body处理部分创建时间戳) var timestamp int64 = time.Now().Unix() @@ -264,6 +293,9 @@ func (m *hunyuanProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName // hunyuan 的 TransformRequestBodyHeaders 方法只在 failover 健康检查的时候会调用 func (m *hunyuanProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName 
ApiName, body []byte, headers http.Header, log wrapper.Log) ([]byte, error) { + if m.useOpenAICompatibleAPI() { + return m.config.defaultTransformRequestBody(ctx, apiName, body, log) + } request := &chatCompletionRequest{} err := m.config.parseRequestAndMapModel(ctx, request, body, log) if err != nil { @@ -289,7 +321,7 @@ func (m *hunyuanProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, a } func (m *hunyuanProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error) { - if m.config.protocol == protocolOriginal { + if m.config.IsOriginal() || m.useOpenAICompatibleAPI() || name != ApiNameChatCompletion { return chunk, nil } @@ -405,6 +437,12 @@ func (m *hunyuanProvider) convertChunkFromHunyuanToOpenAI(ctx wrapper.HttpContex } func (m *hunyuanProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) { + if m.config.IsOriginal() || m.useOpenAICompatibleAPI() { + return body, nil + } + if apiName != ApiNameChatCompletion { + return body, nil + } log.Debugf("#debug nash5# onRespBody's resp is: %s", string(body)) hunyuanResponse := &hunyuanTextGenResponseNonStreaming{} if err := json.Unmarshal(body, hunyuanResponse); err != nil { diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/minimax.go b/plugins/wasm-go/extensions/ai-proxy/provider/minimax.go index 02812c0b46..0f5e0d3695 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/minimax.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/minimax.go @@ -41,7 +41,7 @@ type minimaxProviderInitializer struct { func (m *minimaxProviderInitializer) ValidateConfig(config *ProviderConfig) error { // If using the chat completion Pro API, a group ID must be set. 
if minimaxApiTypePro == config.minimaxApiType && config.minimaxGroupId == "" { - return errors.New(fmt.Sprintf("missing minimaxGroupId in provider config when minimaxApiType is %s", minimaxApiTypePro)) + return fmt.Errorf("missing minimaxGroupId in provider config when minimaxApiType is %s", minimaxApiTypePro) } if config.apiTokens == nil || len(config.apiTokens) == 0 { return errors.New("no apiToken found in provider config") @@ -49,7 +49,15 @@ func (m *minimaxProviderInitializer) ValidateConfig(config *ProviderConfig) erro return nil } +func (m *minimaxProviderInitializer) DefaultCapabilities() map[string]string { + return map[string]string{ + // minimax 替换path的时候,要根据modelmapping替换,这儿的配置无实质作用,只是为了保持和其他provider的一致性 + string(ApiNameChatCompletion): minimaxChatCompletionV2Path, + } +} + func (m *minimaxProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { + config.setDefaultCapabilities(m.DefaultCapabilities()) return &minimaxProvider{ config: config, contextCache: createContextCache(&config), @@ -66,7 +74,7 @@ func (m *minimaxProvider) GetProviderType() string { } func (m *minimaxProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { - if apiName != ApiNameChatCompletion { + if !m.config.isSupportedAPI(apiName) { return errUnsupportedApiName } m.config.handleRequestHeaders(m, ctx, apiName, log) @@ -81,7 +89,7 @@ func (m *minimaxProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiNa } func (m *minimaxProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { - if apiName != ApiNameChatCompletion { + if !m.config.isSupportedAPI(apiName) { return types.ActionContinue, errUnsupportedApiName } if minimaxApiTypePro == m.config.minimaxApiType { @@ -159,6 +167,9 @@ func (m *minimaxProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name if isLastChunk || len(chunk) == 0 { return nil, nil } + if name != ApiNameChatCompletion { + 
return chunk, nil + } // Sample event response: // data: {"created":1689747645,"model":"abab6.5s-chat","reply":"","choices":[{"messages":[{"sender_type":"BOT","sender_name":"MM智能助理","text":"am from China."}]}],"output_sensitive":false} @@ -192,6 +203,9 @@ func (m *minimaxProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name // TransformResponseBody handles the final response body from the Minimax service only for requests using the OpenAI protocol and corresponding to the chat completion Pro API. func (m *minimaxProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) { + if apiName != ApiNameChatCompletion { + return body, nil + } minimaxResp := &minimaxChatCompletionProResp{} if err := json.Unmarshal(body, minimaxResp); err != nil { return nil, fmt.Errorf("unable to unmarshal minimax response: %v", err) @@ -268,18 +282,6 @@ type minimaxUsage struct { CompletionTokens int64 `json:"completion_tokens"` } -func (m *minimaxProvider) parseModel(body []byte) (string, error) { - var tempMap map[string]interface{} - if err := json.Unmarshal(body, &tempMap); err != nil { - return "", err - } - model, ok := tempMap["model"].(string) - if !ok { - return "", errors.New("missing model in chat completion request") - } - return model, nil -} - func (m *minimaxProvider) setBotSettings(request *minimaxChatCompletionProRequest, botSettingContent string) { if len(request.BotSettings) == 0 { request.BotSettings = []minimaxBotSetting{ diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/mistral.go b/plugins/wasm-go/extensions/ai-proxy/provider/mistral.go index eb5b319a20..3f361a27ac 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/mistral.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/mistral.go @@ -22,7 +22,16 @@ func (m *mistralProviderInitializer) ValidateConfig(config *ProviderConfig) erro return nil } +func (m *mistralProviderInitializer) DefaultCapabilities() map[string]string { + 
return map[string]string{ + // The chat interface of mistral is the same as that of OpenAI. docs: https://docs.mistral.ai/api/ + string(ApiNameChatCompletion): PathOpenAIChatCompletions, + string(ApiNameEmbeddings): PathOpenAIEmbeddings, + } +} + func (m *mistralProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { + config.setDefaultCapabilities(m.DefaultCapabilities()) return &mistralProvider{ config: config, contextCache: createContextCache(&config), @@ -39,7 +48,7 @@ func (m *mistralProvider) GetProviderType() string { } func (m *mistralProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { - if apiName != ApiNameChatCompletion { + if !m.config.isSupportedAPI(apiName) { return errUnsupportedApiName } m.config.handleRequestHeaders(m, ctx, apiName, log) @@ -47,7 +56,7 @@ func (m *mistralProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiN } func (m *mistralProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { - if apiName != ApiNameChatCompletion { + if !m.config.isSupportedAPI(apiName) { return types.ActionContinue, errUnsupportedApiName } return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log) diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/model.go b/plugins/wasm-go/extensions/ai-proxy/provider/model.go index 51a65555c7..61bef7467b 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/model.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/model.go @@ -19,21 +19,21 @@ const ( ) type chatCompletionRequest struct { - Model string `json:"model"` - Messages []chatMessage `json:"messages"` - MaxTokens int `json:"max_tokens,omitempty"` - FrequencyPenalty float64 `json:"frequency_penalty,omitempty"` - N int `json:"n,omitempty"` - PresencePenalty float64 `json:"presence_penalty,omitempty"` - Seed int `json:"seed,omitempty"` - Stream bool `json:"stream,omitempty"` - StreamOptions 
*streamOptions `json:"stream_options,omitempty"` - Temperature float64 `json:"temperature,omitempty"` - TopP float64 `json:"top_p,omitempty"` - Tools []tool `json:"tools,omitempty"` - ToolChoice *toolChoice `json:"tool_choice,omitempty"` - User string `json:"user,omitempty"` - Stop []string `json:"stop,omitempty"` + Model string `json:"model"` + Messages []chatMessage `json:"messages"` + MaxTokens int `json:"max_tokens,omitempty"` + FrequencyPenalty float64 `json:"frequency_penalty,omitempty"` + N int `json:"n,omitempty"` + PresencePenalty float64 `json:"presence_penalty,omitempty"` + Seed int `json:"seed,omitempty"` + Stream bool `json:"stream,omitempty"` + StreamOptions *streamOptions `json:"stream_options,omitempty"` + Temperature float64 `json:"temperature,omitempty"` + TopP float64 `json:"top_p,omitempty"` + Tools []tool `json:"tools,omitempty"` + ToolChoice *toolChoice `json:"tool_choice,omitempty"` + User string `json:"user,omitempty"` + Stop []string `json:"stop,omitempty"` ResponseFormat map[string]interface{} `json:"response_format,omitempty"` } @@ -230,6 +230,21 @@ func (e *streamEvent) setValue(key, value string) { } } +// https://platform.openai.com/docs/guides/images +type imageGenerationRequest struct { + Model string `json:"model"` + Prompt string `json:"prompt"` + N int `json:"n,omitempty"` + Size string `json:"size,omitempty"` +} + +// https://platform.openai.com/docs/guides/text-to-speech +type audioSpeechRequest struct { + Model string `json:"model"` + Input string `json:"input"` + Voice string `json:"voice"` +} + type embeddingsRequest struct { Input interface{} `json:"input"` Model string `json:"model"` diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/moonshot.go b/plugins/wasm-go/extensions/ai-proxy/provider/moonshot.go index 6477b8bda4..776e250836 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/moonshot.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/moonshot.go @@ -34,7 +34,14 @@ func (m
*moonshotProviderInitializer) ValidateConfig(config *ProviderConfig) err return nil } +func (m *moonshotProviderInitializer) DefaultCapabilities() map[string]string { + return map[string]string{ + string(ApiNameChatCompletion): moonshotChatCompletionPath, + } +} + func (m *moonshotProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { + config.setDefaultCapabilities(m.DefaultCapabilities()) return &moonshotProvider{ config: config, client: wrapper.NewClusterClient(wrapper.RouteCluster{ @@ -57,7 +64,7 @@ func (m *moonshotProvider) GetProviderType() string { } func (m *moonshotProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { - if apiName != ApiNameChatCompletion { + if !m.config.isSupportedAPI(apiName) { return errUnsupportedApiName } m.config.handleRequestHeaders(m, ctx, apiName, log) @@ -65,7 +72,7 @@ func (m *moonshotProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName Api } func (m *moonshotProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { - util.OverwriteRequestPathHeader(headers, moonshotChatCompletionPath) + util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), m.config.capabilities) util.OverwriteRequestHostHeader(headers, moonshotDomain) util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx)) headers.Del("Content-Length") @@ -74,9 +81,13 @@ func (m *moonshotProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiN // moonshot 有自己获取 context 的配置(moonshotFileId),因此无法复用 handleRequestBody 方法 // moonshot 的 body 没有修改,无须实现TransformRequestBody,使用默认的 defaultTransformRequestBody 方法 func (m *moonshotProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { - if apiName != ApiNameChatCompletion { + if !m.config.isSupportedAPI(apiName) { return types.ActionContinue, errUnsupportedApiName } + // 非chat类型的请求,不做处理 + if 
apiName != ApiNameChatCompletion { + return types.ActionContinue, nil + } request := &chatCompletionRequest{} if err := m.config.parseRequestAndMapModel(ctx, request, body, log); err != nil { @@ -154,6 +165,9 @@ func (m *moonshotProvider) sendRequest(method, path, body, apiKey string, callba } func (m *moonshotProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error) { + if name != ApiNameChatCompletion { + return chunk, nil + } receivedBody := chunk if bufferedStreamingBody, has := ctx.GetContext(ctxKeyStreamingBody).([]byte); has { receivedBody = append(bufferedStreamingBody, chunk...) diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/ollama.go b/plugins/wasm-go/extensions/ai-proxy/provider/ollama.go index 76b74f65fb..57ad424ffa 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/ollama.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/ollama.go @@ -12,10 +12,6 @@ import ( // ollamaProvider is the provider for Ollama service. 
-const ( - ollamaChatCompletionPath = "/v1/chat/completions" -) - type ollamaProviderInitializer struct { } @@ -29,9 +25,17 @@ func (m *ollamaProviderInitializer) ValidateConfig(config *ProviderConfig) error return nil } +func (m *ollamaProviderInitializer) DefaultCapabilities() map[string]string { + return map[string]string{ + // ollama的chat接口path和OpenAI的chat接口一样 + string(ApiNameChatCompletion): PathOpenAIChatCompletions, + } +} + func (m *ollamaProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { serverPortStr := fmt.Sprintf("%d", config.ollamaServerPort) serviceDomain := config.ollamaServerHost + ":" + serverPortStr + config.setDefaultCapabilities(m.DefaultCapabilities()) return &ollamaProvider{ config: config, serviceDomain: serviceDomain, @@ -50,7 +54,7 @@ func (m *ollamaProvider) GetProviderType() string { } func (m *ollamaProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { - if apiName != ApiNameChatCompletion { + if !m.config.isSupportedAPI(apiName) { return errUnsupportedApiName } m.config.handleRequestHeaders(m, ctx, apiName, log) @@ -58,14 +62,14 @@ func (m *ollamaProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNa } func (m *ollamaProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { - if apiName != ApiNameChatCompletion { + if !m.config.isSupportedAPI(apiName) { return types.ActionContinue, errUnsupportedApiName } return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log) } func (m *ollamaProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { - util.OverwriteRequestPathHeader(headers, ollamaChatCompletionPath) + util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), m.config.capabilities) util.OverwriteRequestHostHeader(headers, m.serviceDomain) headers.Del("Content-Length") } diff --git 
a/plugins/wasm-go/extensions/ai-proxy/provider/openai.go b/plugins/wasm-go/extensions/ai-proxy/provider/openai.go index 2315b09564..60767b0d02 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/openai.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/openai.go @@ -26,6 +26,13 @@ func (m *openaiProviderInitializer) ValidateConfig(config *ProviderConfig) error return nil } +func (m *openaiProviderInitializer) DefaultCapabilities() map[string]string { + return map[string]string{ + string(ApiNameChatCompletion): defaultOpenaiChatCompletionPath, + string(ApiNameEmbeddings): defaultOpenaiEmbeddingsPath, + } +} + func (m *openaiProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { if config.openaiCustomUrl == "" { return &openaiProvider{ @@ -38,6 +45,7 @@ func (m *openaiProviderInitializer) CreateProvider(config ProviderConfig) (Provi if len(pairs) != 2 { return nil, fmt.Errorf("invalid openaiCustomUrl:%s", config.openaiCustomUrl) } + config.setDefaultCapabilities(m.DefaultCapabilities()) return &openaiProvider{ config: config, customDomain: pairs[0], @@ -64,13 +72,7 @@ func (m *openaiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNa func (m *openaiProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { if m.customPath == "" { - switch apiName { - case ApiNameChatCompletion: - util.OverwriteRequestPathHeader(headers, defaultOpenaiChatCompletionPath) - case ApiNameEmbeddings: - ctx.DontReadRequestBody() - util.OverwriteRequestPathHeader(headers, defaultOpenaiEmbeddingsPath) - } + util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), m.config.capabilities) } else { util.OverwriteRequestPathHeader(headers, m.customPath) } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go index c59787e6da..ae2763c0aa 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go +++ 
b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go @@ -1,7 +1,6 @@ package provider import ( - "encoding/json" "errors" "math/rand" "net/http" @@ -12,14 +11,27 @@ import ( "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" "github.com/tidwall/gjson" + "github.com/tidwall/sjson" ) type ApiName string type Pointcut string const ( - ApiNameChatCompletion ApiName = "chatCompletion" - ApiNameEmbeddings ApiName = "embeddings" + + // ApiName 格式 {vendor}/{version}/{apitype} + // 表示遵循 厂商/版本/接口类型 的格式 + // 目前openai是事实意义上的标准,但是也有其他厂商存在其他任务的一些可能的标准,比如cohere的rerank + ApiNameChatCompletion ApiName = "openai/v1/chatcompletions" + ApiNameEmbeddings ApiName = "openai/v1/embeddings" + ApiNameImageGeneration ApiName = "openai/v1/imagegeneration" + ApiNameAudioSpeech ApiName = "openai/v1/audiospeech" + + PathOpenAIChatCompletions = "/v1/chat/completions" + PathOpenAIEmbeddings = "/v1/embeddings" + + // TODO: 以下是一些非标准的API名称,需要进一步确认是否支持 + ApiNameCohereV1Rerank ApiName = "cohere/v1/rerank" providerTypeMoonshot = "moonshot" providerTypeAzure = "azure" @@ -250,6 +262,12 @@ type ProviderConfig struct { inputVariable string `required:"false" yaml:"inputVariable" json:"inputVariable"` // @Title zh-CN dify中应用类型为workflow时需要设置输出变量,当botType为workflow时一起使用 outputVariable string `required:"false" yaml:"outputVariable" json:"outputVariable"` + // @Title zh-CN 额外支持的ai能力 + // @Description zh-CN 开放的ai能力和urlpath映射,例如: {"openai/v1/chatcompletions": "/v1/chat/completions"} + capabilities map[string]string + // @Title zh-CN 是否开启透传 + // @Description zh-CN 如果是插件不支持的API,是否透传请求, 默认为false + passthrough bool } func (c *ProviderConfig) GetId() string { @@ -361,12 +379,22 @@ func (c *ProviderConfig) FromJson(json gjson.Result) { c.botType = json.Get("botType").String() c.inputVariable = json.Get("inputVariable").String() c.outputVariable = json.Get("outputVariable").String() + + c.capabilities = make(map[string]string) + for capability, 
pathJson := range json.Get("capabilities").Map() { + // 过滤掉不受支持的能力 + switch capability { + case string(ApiNameChatCompletion), + string(ApiNameEmbeddings), + string(ApiNameImageGeneration), + string(ApiNameAudioSpeech), + string(ApiNameCohereV1Rerank): + c.capabilities[capability] = pathJson.String() + } + } } func (c *ProviderConfig) Validate() error { - if c.timeout < 0 { - return errors.New("invalid timeout in config") - } if c.protocol != protocolOpenAI && c.protocol != protocolOriginal { return errors.New("invalid protocol in config") } @@ -425,6 +453,10 @@ func (c *ProviderConfig) ReplaceByCustomSettings(body []byte) ([]byte, error) { return ReplaceByCustomSettings(body, c.customSettings) } +func (c *ProviderConfig) PassthroughUnsupportedAPI() bool { + return c.passthrough +} + func CreateProvider(pc ProviderConfig) (Provider, error) { initializer, has := providerInitializers[pc.typ] if !has { @@ -499,7 +531,7 @@ func getMappedModel(model string, modelMapping map[string]string, log wrapper.Lo } func doGetMappedModel(model string, modelMapping map[string]string, log wrapper.Log) string { - if modelMapping == nil || len(modelMapping) == 0 { + if len(modelMapping) == 0 { return "" } @@ -527,11 +559,22 @@ func doGetMappedModel(model string, modelMapping map[string]string, log wrapper. 
return "" } +func (c *ProviderConfig) isSupportedAPI(apiName ApiName) bool { + _, exist := c.capabilities[string(apiName)] + return exist +} + +func (c *ProviderConfig) setDefaultCapabilities(capabilities map[string]string) { + for capability, path := range capabilities { + c.capabilities[capability] = path + } +} + func (c *ProviderConfig) handleRequestBody( provider Provider, contextCache *contextCache, ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log, ) (types.Action, error) { // use original protocol - if c.protocol == protocolOriginal { + if c.IsOriginal() { return types.ActionContinue, nil } @@ -578,17 +621,21 @@ func (c *ProviderConfig) handleRequestHeaders(provider Provider, ctx wrapper.Htt } } +// defaultTransformRequestBody 默认的请求体转换方法,只做模型映射,用slog替换模型名称,不用序列化和反序列化,提高性能 func (c *ProviderConfig) defaultTransformRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) { - var request interface{} - if apiName == ApiNameChatCompletion { - request = &chatCompletionRequest{} - } else { - request = &embeddingsRequest{} - } - if err := c.parseRequestAndMapModel(ctx, request, body, log); err != nil { - return nil, err + switch apiName { + case ApiNameChatCompletion: + stream := gjson.GetBytes(body, "stream").Bool() + if stream { + _ = proxywasm.ReplaceHttpRequestHeader("Accept", "text/event-stream") + ctx.SetContext(ctxKeyIsStreaming, true) + } else { + ctx.SetContext(ctxKeyIsStreaming, false) + } } - return json.Marshal(request) + model := gjson.GetBytes(body, "model").String() + ctx.SetContext(ctxKeyOriginalRequestModel, model) + return sjson.SetBytes(body, "model", getMappedModel(model, c.modelMapping, log)) } func (c *ProviderConfig) DefaultTransformResponseHeaders(ctx wrapper.HttpContext, headers http.Header) { diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/qwen.go b/plugins/wasm-go/extensions/ai-proxy/provider/qwen.go index 26e77813cd..e5650f355f 100644 --- 
a/plugins/wasm-go/extensions/ai-proxy/provider/qwen.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/qwen.go @@ -52,7 +52,15 @@ func (m *qwenProviderInitializer) ValidateConfig(config *ProviderConfig) error { return nil } +func (m *qwenProviderInitializer) DefaultCapabilities() map[string]string { + return map[string]string{ + string(ApiNameChatCompletion): qwenChatCompletionPath, + string(ApiNameEmbeddings): qwenTextEmbeddingPath, + } +} + func (m *qwenProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { + config.setDefaultCapabilities(m.DefaultCapabilities()) return &qwenProvider{ config: config, contextCache: createContextCache(&config), @@ -75,18 +83,19 @@ func (m *qwenProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName if m.config.IsOriginal() { } else if m.config.qwenEnableCompatible { util.OverwriteRequestPathHeader(headers, qwenCompatiblePath) - } else if apiName == ApiNameChatCompletion { - util.OverwriteRequestPathHeader(headers, qwenChatCompletionPath) - } else if apiName == ApiNameEmbeddings { - util.OverwriteRequestPathHeader(headers, qwenTextEmbeddingPath) + } else if apiName == ApiNameChatCompletion || apiName == ApiNameEmbeddings { + util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), m.config.capabilities) } } func (m *qwenProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header, log wrapper.Log) ([]byte, error) { - if apiName == ApiNameChatCompletion { + switch apiName { + case ApiNameChatCompletion: return m.onChatCompletionRequestBody(ctx, body, headers, log) - } else { + case ApiNameEmbeddings: return m.onEmbeddingsRequestBody(ctx, body, log) + default: + return m.config.defaultTransformRequestBody(ctx, apiName, body, log) } } @@ -95,7 +104,7 @@ func (m *qwenProvider) GetProviderType() string { } func (m *qwenProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { - if apiName != 
ApiNameChatCompletion && apiName != ApiNameEmbeddings { + if !m.config.isSupportedAPI(apiName) { return errUnsupportedApiName } @@ -140,7 +149,7 @@ func (m *qwenProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, b return types.ActionContinue, nil } - if apiName != ApiNameChatCompletion && apiName != ApiNameEmbeddings { + if !m.config.isSupportedAPI(apiName) { return types.ActionContinue, errUnsupportedApiName } return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log) @@ -278,6 +287,9 @@ func (m *qwenProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName Ap if apiName == ApiNameEmbeddings { return m.onEmbeddingsResponseBody(ctx, body, log) } + if m.config.isSupportedAPI(apiName) { + return body, nil + } return nil, errUnsupportedApiName } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/spark.go b/plugins/wasm-go/extensions/ai-proxy/provider/spark.go index 72d17a2473..bac72e0239 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/spark.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/spark.go @@ -55,7 +55,14 @@ func (i *sparkProviderInitializer) ValidateConfig(config *ProviderConfig) error return nil } +func (i *sparkProviderInitializer) DefaultCapabilities() map[string]string { + return map[string]string{ + string(ApiNameChatCompletion): sparkChatCompletionPath, + } +} + func (i *sparkProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { + config.setDefaultCapabilities(i.DefaultCapabilities()) return &sparkProvider{ config: config, contextCache: createContextCache(&config), @@ -67,7 +74,7 @@ func (p *sparkProvider) GetProviderType() string { } func (p *sparkProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { - if apiName != ApiNameChatCompletion { + if !p.config.isSupportedAPI(apiName) { return errUnsupportedApiName } p.config.handleRequestHeaders(p, ctx, apiName, log) @@ -75,13 +82,16 @@ func (p *sparkProvider) OnRequestHeaders(ctx 
wrapper.HttpContext, apiName ApiNam } func (p *sparkProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { - if apiName != ApiNameChatCompletion { + if !p.config.isSupportedAPI(apiName) { return types.ActionContinue, errUnsupportedApiName } return p.config.handleRequestBody(p, p.contextCache, ctx, apiName, body, log) } func (p *sparkProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) { + if apiName != ApiNameChatCompletion { + return body, nil + } sparkResponse := &sparkResponse{} if err := json.Unmarshal(body, sparkResponse); err != nil { return nil, fmt.Errorf("unable to unmarshal spark response: %v", err) @@ -97,6 +107,9 @@ func (p *sparkProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name Ap if isLastChunk || len(chunk) == 0 { return nil, nil } + if name != ApiNameChatCompletion { + return chunk, nil + } responseBuilder := &strings.Builder{} lines := strings.Split(string(chunk), "\n") for _, data := range lines { @@ -168,7 +181,7 @@ func (p *sparkProvider) appendResponse(responseBuilder *strings.Builder, respons } func (p *sparkProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { - util.OverwriteRequestPathHeader(headers, sparkChatCompletionPath) + util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), p.config.capabilities) util.OverwriteRequestHostHeader(headers, sparkHost) util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+p.config.GetApiTokenInUse(ctx)) } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/stepfun.go b/plugins/wasm-go/extensions/ai-proxy/provider/stepfun.go index 4cd5de7dd5..71315621a4 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/stepfun.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/stepfun.go @@ -24,7 +24,15 @@ func (m *stepfunProviderInitializer) ValidateConfig(config *ProviderConfig) 
erro return nil } +func (m *stepfunProviderInitializer) DefaultCapabilities() map[string]string { + return map[string]string{ + // stepfun的chat接口path和OpenAI的chat接口一样 + string(ApiNameChatCompletion): stepfunChatCompletionPath, + } +} + func (m *stepfunProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { + config.setDefaultCapabilities(m.DefaultCapabilities()) return &stepfunProvider{ config: config, contextCache: createContextCache(&config), @@ -41,7 +49,7 @@ func (m *stepfunProvider) GetProviderType() string { } func (m *stepfunProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { - if apiName != ApiNameChatCompletion { + if !m.config.isSupportedAPI(apiName) { return errUnsupportedApiName } m.config.handleRequestHeaders(m, ctx, apiName, log) @@ -49,14 +57,14 @@ func (m *stepfunProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiN } func (m *stepfunProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { - if apiName != ApiNameChatCompletion { + if !m.config.isSupportedAPI(apiName) { return types.ActionContinue, errUnsupportedApiName } return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log) } func (m *stepfunProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { - util.OverwriteRequestPathHeader(headers, stepfunChatCompletionPath) + util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), m.config.capabilities) util.OverwriteRequestHostHeader(headers, stepfunDomain) util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx)) headers.Del("Content-Length") diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/together_ai.go b/plugins/wasm-go/extensions/ai-proxy/provider/together_ai.go index 8e24280308..dfbeb401ca 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/together_ai.go +++ 
b/plugins/wasm-go/extensions/ai-proxy/provider/together_ai.go @@ -2,11 +2,12 @@ package provider import ( "errors" + "net/http" + "strings" + "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" - "net/http" - "strings" ) const ( @@ -23,7 +24,14 @@ func (m *togetherAIProviderInitializer) ValidateConfig(config *ProviderConfig) e return nil } +func (m *togetherAIProviderInitializer) DefaultCapabilities() map[string]string { + return map[string]string{ + string(ApiNameChatCompletion): togetherAICompletionPath, + } +} + func (m *togetherAIProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { + config.setDefaultCapabilities(m.DefaultCapabilities()) return &togetherAIProvider{ config: config, contextCache: createContextCache(&config), @@ -40,7 +48,7 @@ func (m *togetherAIProvider) GetProviderType() string { } func (m *togetherAIProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { - if apiName != ApiNameChatCompletion { + if !m.config.isSupportedAPI(apiName) { return errUnsupportedApiName } m.config.handleRequestHeaders(m, ctx, apiName, log) @@ -48,14 +56,14 @@ func (m *togetherAIProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName A } func (m *togetherAIProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { - if apiName != ApiNameChatCompletion { + if !m.config.isSupportedAPI(apiName) { return types.ActionContinue, errUnsupportedApiName } return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log) } func (m *togetherAIProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { - util.OverwriteRequestPathHeader(headers, togetherAICompletionPath) + util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), 
m.config.capabilities) util.OverwriteRequestHostHeader(headers, togetherAIDomain) util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx)) headers.Del("Content-Length") diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/yi.go b/plugins/wasm-go/extensions/ai-proxy/provider/yi.go index a4d09c8774..3c3db4d5de 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/yi.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/yi.go @@ -24,7 +24,14 @@ func (m *yiProviderInitializer) ValidateConfig(config *ProviderConfig) error { return nil } +func (m *yiProviderInitializer) DefaultCapabilities() map[string]string { + return map[string]string{ + string(ApiNameChatCompletion): yiChatCompletionPath, + } +} + func (m *yiProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { + config.setDefaultCapabilities(m.DefaultCapabilities()) return &yiProvider{ config: config, contextCache: createContextCache(&config), @@ -41,7 +48,7 @@ func (m *yiProvider) GetProviderType() string { } func (m *yiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { - if apiName != ApiNameChatCompletion { + if !m.config.isSupportedAPI(apiName) { return errUnsupportedApiName } m.config.handleRequestHeaders(m, ctx, apiName, log) @@ -49,14 +56,14 @@ func (m *yiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, } func (m *yiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { - if apiName != ApiNameChatCompletion { + if !m.config.isSupportedAPI(apiName) { return types.ActionContinue, errUnsupportedApiName } return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log) } func (m *yiProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { - util.OverwriteRequestPathHeader(headers, yiChatCompletionPath) + 
util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), m.config.capabilities) util.OverwriteRequestHostHeader(headers, yiDomain) util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx)) headers.Del("Content-Length") diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/zhipuai.go b/plugins/wasm-go/extensions/ai-proxy/provider/zhipuai.go index e8a873c789..e95e99fc8b 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/zhipuai.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/zhipuai.go @@ -13,6 +13,7 @@ import ( const ( zhipuAiDomain = "open.bigmodel.cn" zhipuAiChatCompletionPath = "/api/paas/v4/chat/completions" + zhipuAiEmbeddingsPath = "/api/paas/v4/embeddings" ) type zhipuAiProviderInitializer struct{} @@ -24,7 +25,15 @@ func (m *zhipuAiProviderInitializer) ValidateConfig(config *ProviderConfig) erro return nil } +func (m *zhipuAiProviderInitializer) DefaultCapabilities() map[string]string { + return map[string]string{ + string(ApiNameChatCompletion): zhipuAiChatCompletionPath, + string(ApiNameEmbeddings): zhipuAiEmbeddingsPath, + } +} + func (m *zhipuAiProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { + config.setDefaultCapabilities(m.DefaultCapabilities()) return &zhipuAiProvider{ config: config, contextCache: createContextCache(&config), @@ -41,7 +50,7 @@ func (m *zhipuAiProvider) GetProviderType() string { } func (m *zhipuAiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { - if apiName != ApiNameChatCompletion { + if !m.config.isSupportedAPI(apiName) { return errUnsupportedApiName } m.config.handleRequestHeaders(m, ctx, apiName, log) @@ -49,14 +58,14 @@ func (m *zhipuAiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiN } func (m *zhipuAiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { - if apiName != ApiNameChatCompletion { + if 
!m.config.isSupportedAPI(apiName) { return types.ActionContinue, errUnsupportedApiName } return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log) } func (m *zhipuAiProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { - util.OverwriteRequestPathHeader(headers, zhipuAiChatCompletionPath) + util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), m.config.capabilities) util.OverwriteRequestHostHeader(headers, zhipuAiDomain) util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx)) headers.Del("Content-Length") @@ -66,5 +75,8 @@ func (m *zhipuAiProvider) GetApiName(path string) ApiName { if strings.Contains(path, zhipuAiChatCompletionPath) { return ApiNameChatCompletion } + if strings.Contains(path, zhipuAiEmbeddingsPath) { + return ApiNameEmbeddings + } return "" } diff --git a/plugins/wasm-go/extensions/ai-proxy/util/http.go b/plugins/wasm-go/extensions/ai-proxy/util/http.go index 4f36871b75..4e5d5067f4 100644 --- a/plugins/wasm-go/extensions/ai-proxy/util/http.go +++ b/plugins/wasm-go/extensions/ai-proxy/util/http.go @@ -57,6 +57,17 @@ func OverwriteRequestPathHeader(headers http.Header, path string) { headers.Set(":path", path) } +func OverwriteRequestPathHeaderByCapability(headers http.Header, apiName string, mapping map[string]string) { + mappedPath, exist := mapping[apiName] + if !exist { + return + } + if originPath, err := proxywasm.GetHttpRequestHeader(":path"); err == nil { + headers.Set("X-ENVOY-ORIGINAL-PATH", originPath) + } + headers.Set(":path", mappedPath) +} + func OverwriteRequestAuthorizationHeader(headers http.Header, credential string) { if exist := headers.Get("X-HI-ORIGINAL-AUTH"); exist == "" { if originAuth := headers.Get("Authorization"); originAuth != "" {