From 3aae782aec699b1c6d0c773cae49634e1f8c4d21 Mon Sep 17 00:00:00 2001 From: "yu.deng" Date: Wed, 22 Jan 2025 11:27:11 +0800 Subject: [PATCH 01/10] feature: allow ai-proxy to forward standard AI capabilities that are natively supported --- plugins/wasm-go/extensions/ai-proxy/README.md | 2 +- plugins/wasm-go/extensions/ai-proxy/main.go | 9 ++++ .../extensions/ai-proxy/provider/ai360.go | 3 +- .../extensions/ai-proxy/provider/azure.go | 5 ++- .../extensions/ai-proxy/provider/baichuan.go | 5 ++- .../extensions/ai-proxy/provider/baidu.go | 5 ++- .../extensions/ai-proxy/provider/claude.go | 5 ++- .../ai-proxy/provider/cloudflare.go | 5 ++- .../extensions/ai-proxy/provider/cohere.go | 5 ++- .../extensions/ai-proxy/provider/coze.go | 1 + .../extensions/ai-proxy/provider/deepl.go | 5 ++- .../extensions/ai-proxy/provider/deepseek.go | 5 ++- .../extensions/ai-proxy/provider/doubao.go | 5 ++- .../extensions/ai-proxy/provider/gemini.go | 5 ++- .../extensions/ai-proxy/provider/github.go | 5 ++- .../extensions/ai-proxy/provider/groq.go | 5 ++- .../extensions/ai-proxy/provider/hunyuan.go | 5 ++- .../extensions/ai-proxy/provider/minimax.go | 5 ++- .../extensions/ai-proxy/provider/mistral.go | 5 ++- .../extensions/ai-proxy/provider/moonshot.go | 5 ++- .../extensions/ai-proxy/provider/ollama.go | 5 ++- .../extensions/ai-proxy/provider/openai.go | 1 + .../extensions/ai-proxy/provider/provider.go | 45 ++++++++++++++++++- .../extensions/ai-proxy/provider/qwen.go | 8 +++- .../extensions/ai-proxy/provider/spark.go | 5 ++- .../extensions/ai-proxy/provider/stepfun.go | 5 ++- .../ai-proxy/provider/together_ai.go | 10 +++-- .../extensions/ai-proxy/provider/yi.go | 5 ++- .../extensions/ai-proxy/provider/zhipuai.go | 5 ++- 29 files changed, 132 insertions(+), 52 deletions(-) diff --git a/plugins/wasm-go/extensions/ai-proxy/README.md b/plugins/wasm-go/extensions/ai-proxy/README.md index 8f2ae49a5d..4062a6f220 100644 --- a/plugins/wasm-go/extensions/ai-proxy/README.md +++ 
b/plugins/wasm-go/extensions/ai-proxy/README.md @@ -42,7 +42,7 @@ description: AI 代理插件配置参考 | `customSettings` | array of customSetting | 非必填 | - | 为AI请求指定覆盖或者填充参数 | | `failover` | object | 非必填 | - | 配置 apiToken 的 failover 策略,当 apiToken 不可用时,将其移出 apiToken 列表,待健康检测通过后重新添加回 apiToken 列表 | | `retryOnFailure` | object | 非必填 | - | 当请求失败时立即进行重试 | - +| `capabilities` | array of string | 非必填 | - | 部分provider的部分ai能力原生兼容openai/v1格式,不需要重写,可以直接转发,通过此配置项指定来开启转发, 当前支持: openai/v1/chatcompletions, openai/v1/embeddings, openai/v1/imagegeneration, openai/v1/audiospeech, openai/v1/audiotranscription | `context`的配置字段说明如下: | 名称 | 数据类型 | 填写要求 | 默认值 | 描述 | diff --git a/plugins/wasm-go/extensions/ai-proxy/main.go b/plugins/wasm-go/extensions/ai-proxy/main.go index 220243f989..a41a73cd8d 100644 --- a/plugins/wasm-go/extensions/ai-proxy/main.go +++ b/plugins/wasm-go/extensions/ai-proxy/main.go @@ -274,5 +274,14 @@ func getOpenAiApiName(path string) provider.ApiName { if strings.HasSuffix(path, "/v1/embeddings") { return provider.ApiNameEmbeddings } + if strings.HasSuffix(path, "/v1/audio/transcriptions") { + return provider.ApiNameAudioTranscription + } + if strings.HasSuffix(path, "/v1/audio/speech") { + return provider.ApiNameAudioSpeech + } + if strings.HasSuffix(path, "/v1/images/generations") { + return provider.ApiNameImageGeneration + } return "" } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/ai360.go b/plugins/wasm-go/extensions/ai-proxy/provider/ai360.go index 5a8f4d70cd..44d2176591 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/ai360.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/ai360.go @@ -30,6 +30,7 @@ func (m *ai360ProviderInitializer) ValidateConfig(config *ProviderConfig) error } func (m *ai360ProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { + config.setDefaultCapabilities(ApiNameChatCompletion, ApiNameEmbeddings) return &ai360Provider{ config: config, contextCache: createContextCache(&config), @@ -41,7 +42,7 @@ 
func (m *ai360Provider) GetProviderType() string { } func (m *ai360Provider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { - if apiName != ApiNameChatCompletion && apiName != ApiNameEmbeddings { + if !m.config.isSupportedAPI(apiName) { return errUnsupportedApiName } m.config.handleRequestHeaders(m, ctx, apiName, log) diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/azure.go b/plugins/wasm-go/extensions/ai-proxy/provider/azure.go index 4c107edf17..1e92d5182a 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/azure.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/azure.go @@ -35,6 +35,7 @@ func (m *azureProviderInitializer) CreateProvider(config ProviderConfig) (Provid } else { serviceUrl = u } + config.setDefaultCapabilities(ApiNameChatCompletion) return &azureProvider{ config: config, serviceUrl: serviceUrl, @@ -54,7 +55,7 @@ func (m *azureProvider) GetProviderType() string { } func (m *azureProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { - if apiName != ApiNameChatCompletion { + if !m.config.isSupportedAPI(apiName) { return errUnsupportedApiName } m.config.handleRequestHeaders(m, ctx, apiName, log) @@ -62,7 +63,7 @@ func (m *azureProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNam } func (m *azureProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { - if apiName != ApiNameChatCompletion { + if !m.config.isSupportedAPI(apiName) { return types.ActionContinue, errUnsupportedApiName } return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log) diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/baichuan.go b/plugins/wasm-go/extensions/ai-proxy/provider/baichuan.go index 3b272e4f37..2367270245 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/baichuan.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/baichuan.go @@ -27,6 +27,7 @@ func (m 
*baichuanProviderInitializer) ValidateConfig(config *ProviderConfig) err } func (m *baichuanProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { + config.setDefaultCapabilities(ApiNameChatCompletion) return &baichuanProvider{ config: config, contextCache: createContextCache(&config), @@ -43,7 +44,7 @@ func (m *baichuanProvider) GetProviderType() string { } func (m *baichuanProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { - if apiName != ApiNameChatCompletion { + if !m.config.isSupportedAPI(apiName) { return errUnsupportedApiName } m.config.handleRequestHeaders(m, ctx, apiName, log) @@ -51,7 +52,7 @@ func (m *baichuanProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName Api } func (m *baichuanProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { - if apiName != ApiNameChatCompletion { + if !m.config.isSupportedAPI(apiName) { return types.ActionContinue, errUnsupportedApiName } return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log) diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/baidu.go b/plugins/wasm-go/extensions/ai-proxy/provider/baidu.go index f541d31fec..fd5f422608 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/baidu.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/baidu.go @@ -26,6 +26,7 @@ func (g *baiduProviderInitializer) ValidateConfig(config *ProviderConfig) error } func (g *baiduProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { + config.setDefaultCapabilities(ApiNameChatCompletion) return &baiduProvider{ config: config, contextCache: createContextCache(&config), @@ -42,7 +43,7 @@ func (g *baiduProvider) GetProviderType() string { } func (g *baiduProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { - if apiName != ApiNameChatCompletion { + if !g.config.isSupportedAPI(apiName) { return 
errUnsupportedApiName } g.config.handleRequestHeaders(g, ctx, apiName, log) @@ -50,7 +51,7 @@ func (g *baiduProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNam } func (g *baiduProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { - if apiName != ApiNameChatCompletion { + if !g.config.isSupportedAPI(apiName) { return types.ActionContinue, errUnsupportedApiName } return g.config.handleRequestBody(g, g.contextCache, ctx, apiName, body, log) diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/claude.go b/plugins/wasm-go/extensions/ai-proxy/provider/claude.go index 9be84cc44e..604be7574d 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/claude.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/claude.go @@ -86,6 +86,7 @@ func (c *claudeProviderInitializer) ValidateConfig(config *ProviderConfig) error } func (c *claudeProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { + config.setDefaultCapabilities(ApiNameChatCompletion) return &claudeProvider{ config: config, contextCache: createContextCache(&config), @@ -102,7 +103,7 @@ func (c *claudeProvider) GetProviderType() string { } func (c *claudeProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { - if apiName != ApiNameChatCompletion { + if !c.config.isSupportedAPI(apiName) { return errUnsupportedApiName } c.config.handleRequestHeaders(c, ctx, apiName, log) @@ -123,7 +124,7 @@ func (c *claudeProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiNam } func (c *claudeProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { - if apiName != ApiNameChatCompletion { + if !c.config.isSupportedAPI(apiName) { return types.ActionContinue, errUnsupportedApiName } return c.config.handleRequestBody(c, c.contextCache, ctx, apiName, body, log) diff --git 
a/plugins/wasm-go/extensions/ai-proxy/provider/cloudflare.go b/plugins/wasm-go/extensions/ai-proxy/provider/cloudflare.go index e191b89f37..0c5c96a350 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/cloudflare.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/cloudflare.go @@ -27,6 +27,7 @@ func (c *cloudflareProviderInitializer) ValidateConfig(config *ProviderConfig) e } func (c *cloudflareProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { + config.setDefaultCapabilities(ApiNameChatCompletion) return &cloudflareProvider{ config: config, contextCache: createContextCache(&config), @@ -43,7 +44,7 @@ func (c *cloudflareProvider) GetProviderType() string { } func (c *cloudflareProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { - if apiName != ApiNameChatCompletion { + if !c.config.isSupportedAPI(apiName) { return errUnsupportedApiName } c.config.handleRequestHeaders(c, ctx, apiName, log) @@ -51,7 +52,7 @@ func (c *cloudflareProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName A } func (c *cloudflareProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { - if apiName != ApiNameChatCompletion { + if !c.config.isSupportedAPI(apiName) { return types.ActionContinue, errUnsupportedApiName } return c.config.handleRequestBody(c, c.contextCache, ctx, apiName, body, log) diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/cohere.go b/plugins/wasm-go/extensions/ai-proxy/provider/cohere.go index 7bee5efd58..f904ccdd22 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/cohere.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/cohere.go @@ -26,6 +26,7 @@ func (m *cohereProviderInitializer) ValidateConfig(config *ProviderConfig) error } func (m *cohereProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { + config.setDefaultCapabilities(ApiNameChatCompletion) return &cohereProvider{ config: 
config, contextCache: createContextCache(&config), @@ -56,7 +57,7 @@ func (m *cohereProvider) GetProviderType() string { } func (m *cohereProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { - if apiName != ApiNameChatCompletion { + if !m.config.isSupportedAPI(apiName) { return errUnsupportedApiName } m.config.handleRequestHeaders(m, ctx, apiName, log) @@ -64,7 +65,7 @@ func (m *cohereProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNa } func (m *cohereProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { - if apiName != ApiNameChatCompletion { + if !m.config.isSupportedAPI(apiName) { return types.ActionContinue, errUnsupportedApiName } return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log) diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/coze.go b/plugins/wasm-go/extensions/ai-proxy/provider/coze.go index 4e30ec27c3..ca419552cf 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/coze.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/coze.go @@ -22,6 +22,7 @@ func (m *cozeProviderInitializer) ValidateConfig(config *ProviderConfig) error { } func (m *cozeProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { + config.setDefaultCapabilities(ApiNameChatCompletion) return &cozeProvider{ config: config, contextCache: createContextCache(&config), diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/deepl.go b/plugins/wasm-go/extensions/ai-proxy/provider/deepl.go index 192904ecc8..12a0fdb893 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/deepl.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/deepl.go @@ -65,6 +65,7 @@ func (d *deeplProviderInitializer) ValidateConfig(config *ProviderConfig) error } func (d *deeplProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { + config.setDefaultCapabilities(ApiNameChatCompletion) return &deeplProvider{ 
config: config, contextCache: createContextCache(&config), @@ -76,7 +77,7 @@ func (d *deeplProvider) GetProviderType() string { } func (d *deeplProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { - if apiName != ApiNameChatCompletion { + if !d.config.isSupportedAPI(apiName) { return errUnsupportedApiName } d.config.handleRequestHeaders(d, ctx, apiName, log) @@ -89,7 +90,7 @@ func (d *deeplProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName } func (d *deeplProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { - if apiName != ApiNameChatCompletion { + if !d.config.isSupportedAPI(apiName) { return types.ActionContinue, errUnsupportedApiName } return d.config.handleRequestBody(d, d.contextCache, ctx, apiName, body, log) diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/deepseek.go b/plugins/wasm-go/extensions/ai-proxy/provider/deepseek.go index b6a842f1c6..dd146282a8 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/deepseek.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/deepseek.go @@ -27,6 +27,7 @@ func (m *deepseekProviderInitializer) ValidateConfig(config *ProviderConfig) err } func (m *deepseekProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { + config.setDefaultCapabilities(ApiNameChatCompletion) return &deepseekProvider{ config: config, contextCache: createContextCache(&config), @@ -43,7 +44,7 @@ func (m *deepseekProvider) GetProviderType() string { } func (m *deepseekProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { - if apiName != ApiNameChatCompletion { + if !m.config.isSupportedAPI(apiName) { return errUnsupportedApiName } m.config.handleRequestHeaders(m, ctx, apiName, log) @@ -51,7 +52,7 @@ func (m *deepseekProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName Api } func (m *deepseekProvider) OnRequestBody(ctx wrapper.HttpContext, 
apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { - if apiName != ApiNameChatCompletion { + if !m.config.isSupportedAPI(apiName) { return types.ActionContinue, errUnsupportedApiName } return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log) diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/doubao.go b/plugins/wasm-go/extensions/ai-proxy/provider/doubao.go index ed8b3b18bf..291b9a27f2 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/doubao.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/doubao.go @@ -25,6 +25,7 @@ func (m *doubaoProviderInitializer) ValidateConfig(config *ProviderConfig) error } func (m *doubaoProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { + config.setDefaultCapabilities(ApiNameChatCompletion) return &doubaoProvider{ config: config, contextCache: createContextCache(&config), @@ -41,7 +42,7 @@ func (m *doubaoProvider) GetProviderType() string { } func (m *doubaoProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { - if apiName != ApiNameChatCompletion { + if !m.config.isSupportedAPI(apiName) { return errUnsupportedApiName } m.config.handleRequestHeaders(m, ctx, apiName, log) @@ -49,7 +50,7 @@ func (m *doubaoProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNa } func (m *doubaoProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { - if apiName != ApiNameChatCompletion { + if !m.config.isSupportedAPI(apiName) { return types.ActionContinue, errUnsupportedApiName } return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log) diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/gemini.go b/plugins/wasm-go/extensions/ai-proxy/provider/gemini.go index 1017b0f028..4eecfd56a6 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/gemini.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/gemini.go @@ -36,6 +36,7 @@ func 
(g *geminiProviderInitializer) ValidateConfig(config *ProviderConfig) error } func (g *geminiProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { + config.setDefaultCapabilities(ApiNameChatCompletion, ApiNameEmbeddings) return &geminiProvider{ config: config, contextCache: createContextCache(&config), @@ -52,7 +53,7 @@ func (g *geminiProvider) GetProviderType() string { } func (g *geminiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { - if apiName != ApiNameChatCompletion && apiName != ApiNameEmbeddings { + if !g.config.isSupportedAPI(apiName) { return errUnsupportedApiName } g.config.handleRequestHeaders(g, ctx, apiName, log) @@ -66,7 +67,7 @@ func (g *geminiProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiNam } func (g *geminiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { - if apiName != ApiNameChatCompletion && apiName != ApiNameEmbeddings { + if !g.config.isSupportedAPI(apiName) { return types.ActionContinue, errUnsupportedApiName } return g.config.handleRequestBody(g, g.contextCache, ctx, apiName, body, log) diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/github.go b/plugins/wasm-go/extensions/ai-proxy/provider/github.go index fb5649a673..5879f63e59 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/github.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/github.go @@ -33,6 +33,7 @@ func (m *githubProviderInitializer) ValidateConfig(config *ProviderConfig) error } func (m *githubProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { + config.setDefaultCapabilities(ApiNameChatCompletion, ApiNameEmbeddings) return &githubProvider{ config: config, contextCache: createContextCache(&config), @@ -44,7 +45,7 @@ func (m *githubProvider) GetProviderType() string { } func (m *githubProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error 
{ - if apiName != ApiNameChatCompletion && apiName != ApiNameEmbeddings { + if !m.config.isSupportedAPI(apiName) { return errUnsupportedApiName } m.config.handleRequestHeaders(m, ctx, apiName, log) @@ -53,7 +54,7 @@ func (m *githubProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNa } func (m *githubProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { - if apiName != ApiNameChatCompletion && apiName != ApiNameEmbeddings { + if !m.config.isSupportedAPI(apiName) { return types.ActionContinue, errUnsupportedApiName } return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log) diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/groq.go b/plugins/wasm-go/extensions/ai-proxy/provider/groq.go index 04e29500df..052448e97a 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/groq.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/groq.go @@ -26,6 +26,7 @@ func (g *groqProviderInitializer) ValidateConfig(config *ProviderConfig) error { } func (g *groqProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { + config.setDefaultCapabilities(ApiNameChatCompletion) return &groqProvider{ config: config, contextCache: createContextCache(&config), @@ -42,7 +43,7 @@ func (g *groqProvider) GetProviderType() string { } func (g *groqProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { - if apiName != ApiNameChatCompletion { + if !g.config.isSupportedAPI(apiName) { return errUnsupportedApiName } g.config.handleRequestHeaders(g, ctx, apiName, log) @@ -50,7 +51,7 @@ func (g *groqProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName } func (g *groqProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { - if apiName != ApiNameChatCompletion { + if !g.config.isSupportedAPI(apiName) { return types.ActionContinue, errUnsupportedApiName } return 
g.config.handleRequestBody(g, g.contextCache, ctx, apiName, body, log) diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/hunyuan.go b/plugins/wasm-go/extensions/ai-proxy/provider/hunyuan.go index 583c838c2c..dd38fbb0c8 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/hunyuan.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/hunyuan.go @@ -94,6 +94,7 @@ func (m *hunyuanProviderInitializer) ValidateConfig(config *ProviderConfig) erro } func (m *hunyuanProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { + config.setDefaultCapabilities(ApiNameChatCompletion) return &hunyuanProvider{ config: config, client: wrapper.NewClusterClient(wrapper.RouteCluster{ @@ -115,7 +116,7 @@ func (m *hunyuanProvider) GetProviderType() string { } func (m *hunyuanProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { - if apiName != ApiNameChatCompletion { + if !m.config.isSupportedAPI(apiName) { return errUnsupportedApiName } m.config.handleRequestHeaders(m, ctx, apiName, log) @@ -134,7 +135,7 @@ func (m *hunyuanProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiNa // hunyuan 的 OnRequestBody 逻辑中包含了对 headers 签名的逻辑,并且插入 context 以后还要重新计算签名,因此无法复用 handleRequestBody 方法 func (m *hunyuanProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { - if apiName != ApiNameChatCompletion { + if !m.config.isSupportedAPI(apiName) { return types.ActionContinue, errUnsupportedApiName } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/minimax.go b/plugins/wasm-go/extensions/ai-proxy/provider/minimax.go index 02812c0b46..adc98ce53e 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/minimax.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/minimax.go @@ -50,6 +50,7 @@ func (m *minimaxProviderInitializer) ValidateConfig(config *ProviderConfig) erro } func (m *minimaxProviderInitializer) CreateProvider(config ProviderConfig) (Provider, 
error) { + config.setDefaultCapabilities(ApiNameChatCompletion) return &minimaxProvider{ config: config, contextCache: createContextCache(&config), @@ -66,7 +67,7 @@ func (m *minimaxProvider) GetProviderType() string { } func (m *minimaxProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { - if apiName != ApiNameChatCompletion { + if !m.config.isSupportedAPI(apiName) { return errUnsupportedApiName } m.config.handleRequestHeaders(m, ctx, apiName, log) @@ -81,7 +82,7 @@ func (m *minimaxProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiNa } func (m *minimaxProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { - if apiName != ApiNameChatCompletion { + if !m.config.isSupportedAPI(apiName) { return types.ActionContinue, errUnsupportedApiName } if minimaxApiTypePro == m.config.minimaxApiType { diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/mistral.go b/plugins/wasm-go/extensions/ai-proxy/provider/mistral.go index eb5b319a20..a3142fb4ed 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/mistral.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/mistral.go @@ -23,6 +23,7 @@ func (m *mistralProviderInitializer) ValidateConfig(config *ProviderConfig) erro } func (m *mistralProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { + config.setDefaultCapabilities(ApiNameChatCompletion) return &mistralProvider{ config: config, contextCache: createContextCache(&config), @@ -39,7 +40,7 @@ func (m *mistralProvider) GetProviderType() string { } func (m *mistralProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { - if apiName != ApiNameChatCompletion { + if !m.config.isSupportedAPI(apiName) { return errUnsupportedApiName } m.config.handleRequestHeaders(m, ctx, apiName, log) @@ -47,7 +48,7 @@ func (m *mistralProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiN } func (m 
*mistralProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { - if apiName != ApiNameChatCompletion { + if !m.config.isSupportedAPI(apiName) { return types.ActionContinue, errUnsupportedApiName } return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log) diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/moonshot.go b/plugins/wasm-go/extensions/ai-proxy/provider/moonshot.go index 6477b8bda4..910f9cd924 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/moonshot.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/moonshot.go @@ -35,6 +35,7 @@ func (m *moonshotProviderInitializer) ValidateConfig(config *ProviderConfig) err } func (m *moonshotProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { + config.setDefaultCapabilities(ApiNameChatCompletion) return &moonshotProvider{ config: config, client: wrapper.NewClusterClient(wrapper.RouteCluster{ @@ -57,7 +58,7 @@ func (m *moonshotProvider) GetProviderType() string { } func (m *moonshotProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { - if apiName != ApiNameChatCompletion { + if !m.config.isSupportedAPI(apiName) { return errUnsupportedApiName } m.config.handleRequestHeaders(m, ctx, apiName, log) @@ -74,7 +75,7 @@ func (m *moonshotProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiN // moonshot 有自己获取 context 的配置(moonshotFileId),因此无法复用 handleRequestBody 方法 // moonshot 的 body 没有修改,无须实现TransformRequestBody,使用默认的 defaultTransformRequestBody 方法 func (m *moonshotProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { - if apiName != ApiNameChatCompletion { + if !m.config.isSupportedAPI(apiName) { return types.ActionContinue, errUnsupportedApiName } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/ollama.go b/plugins/wasm-go/extensions/ai-proxy/provider/ollama.go index 
76b74f65fb..62729e28ed 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/ollama.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/ollama.go @@ -32,6 +32,7 @@ func (m *ollamaProviderInitializer) ValidateConfig(config *ProviderConfig) error func (m *ollamaProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { serverPortStr := fmt.Sprintf("%d", config.ollamaServerPort) serviceDomain := config.ollamaServerHost + ":" + serverPortStr + config.setDefaultCapabilities(ApiNameChatCompletion) return &ollamaProvider{ config: config, serviceDomain: serviceDomain, @@ -50,7 +51,7 @@ func (m *ollamaProvider) GetProviderType() string { } func (m *ollamaProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { - if apiName != ApiNameChatCompletion { + if !m.config.isSupportedAPI(apiName) { return errUnsupportedApiName } m.config.handleRequestHeaders(m, ctx, apiName, log) @@ -58,7 +59,7 @@ func (m *ollamaProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNa } func (m *ollamaProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { - if apiName != ApiNameChatCompletion { + if !m.config.isSupportedAPI(apiName) { return types.ActionContinue, errUnsupportedApiName } return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log) diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/openai.go b/plugins/wasm-go/extensions/ai-proxy/provider/openai.go index 2315b09564..6bf29cabb0 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/openai.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/openai.go @@ -38,6 +38,7 @@ func (m *openaiProviderInitializer) CreateProvider(config ProviderConfig) (Provi if len(pairs) != 2 { return nil, fmt.Errorf("invalid openaiCustomUrl:%s", config.openaiCustomUrl) } + config.setDefaultCapabilities(ApiNameChatCompletion, ApiNameEmbeddings) return &openaiProvider{ config: config, customDomain: 
pairs[0], diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go index c59787e6da..24f694f8cb 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go @@ -18,8 +18,18 @@ type ApiName string type Pointcut string const ( - ApiNameChatCompletion ApiName = "chatCompletion" - ApiNameEmbeddings ApiName = "embeddings" + + // ApiName 格式 {vendor}/{version}/{apitype} + // 表示遵循 厂商/版本/接口类型 的格式 + // 目前openai是事实意义上的标准,但是也有其他厂商存在其他任务的一些可能的标准,比如cohere的rerank + ApiNameChatCompletion ApiName = "openai/v1/chatcompletions" + ApiNameEmbeddings ApiName = "openai/v1/embeddings" + ApiNameImageGeneration ApiName = "openai/v1/imagegeneration" + ApiNameAudioSpeech ApiName = "openai/v1/audiospeech" + ApiNameAudioTranscription ApiName = "openai/v1/audiotranscription" + + // TODO: 以下是一些非标准的API名称,需要进一步确认 + // ApiNameCohereRerank ApiName = "cohere/v1/rerank" providerTypeMoonshot = "moonshot" providerTypeAzure = "azure" @@ -250,6 +260,9 @@ type ProviderConfig struct { inputVariable string `required:"false" yaml:"inputVariable" json:"inputVariable"` // @Title zh-CN dify中应用类型为workflow时需要设置输出变量,当botType为workflow时一起使用 outputVariable string `required:"false" yaml:"outputVariable" json:"outputVariable"` + // @Title zh-CN 支持的能力 + // @Description zh-CN 开放的ai能力,例如: "openai/v1/chatcompletions" "openai/v1/embeddings" "openai/v1/imagegeneration" "openai/v1/audiospeech" "openai/v1/audiotranscription" + capabilities []string } func (c *ProviderConfig) GetId() string { @@ -361,6 +374,19 @@ func (c *ProviderConfig) FromJson(json gjson.Result) { c.botType = json.Get("botType").String() c.inputVariable = json.Get("inputVariable").String() c.outputVariable = json.Get("outputVariable").String() + + c.capabilities = make([]string, 0) + for _, ability := range json.Get("capabilities").Array() { + // Keep only recognized capability names; unrecognized entries are silently dropped. + switch ability.String() { + case string(ApiNameChatCompletion), 
+ string(ApiNameEmbeddings), + string(ApiNameImageGeneration), + string(ApiNameAudioSpeech), + string(ApiNameAudioTranscription): + c.capabilities = append(c.capabilities, ability.String()) + } + } } func (c *ProviderConfig) Validate() error { @@ -527,6 +553,21 @@ func doGetMappedModel(model string, modelMapping map[string]string, log wrapper. return "" } +func (c *ProviderConfig) isSupportedAPI(apiName ApiName) bool { + for _, ability := range c.capabilities { + if ability == string(apiName) { + return true + } + } + return false +} + +func (c *ProviderConfig) setDefaultCapabilities(capabilities ...ApiName) { + for _, ability := range capabilities { + c.capabilities = append(c.capabilities, string(ability)) + } +} + func (c *ProviderConfig) handleRequestBody( provider Provider, contextCache *contextCache, ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log, ) (types.Action, error) { diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/qwen.go b/plugins/wasm-go/extensions/ai-proxy/provider/qwen.go index 26e77813cd..722e848d3d 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/qwen.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/qwen.go @@ -53,6 +53,7 @@ func (m *qwenProviderInitializer) ValidateConfig(config *ProviderConfig) error { } func (m *qwenProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { + config.setDefaultCapabilities(ApiNameChatCompletion, ApiNameEmbeddings) return &qwenProvider{ config: config, contextCache: createContextCache(&config), @@ -95,7 +96,7 @@ func (m *qwenProvider) GetProviderType() string { } func (m *qwenProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { - if apiName != ApiNameChatCompletion && apiName != ApiNameEmbeddings { + if !m.config.isSupportedAPI(apiName) { return errUnsupportedApiName } @@ -140,7 +141,7 @@ func (m *qwenProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, b return types.ActionContinue, nil } - 
if apiName != ApiNameChatCompletion && apiName != ApiNameEmbeddings { + if !m.config.isSupportedAPI(apiName) { return types.ActionContinue, errUnsupportedApiName } return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log) @@ -278,6 +279,9 @@ func (m *qwenProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName Ap if apiName == ApiNameEmbeddings { return m.onEmbeddingsResponseBody(ctx, body, log) } + if m.config.isSupportedAPI(apiName) { + return body, nil + } return nil, errUnsupportedApiName } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/spark.go b/plugins/wasm-go/extensions/ai-proxy/provider/spark.go index 72d17a2473..2148d18dbe 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/spark.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/spark.go @@ -56,6 +56,7 @@ func (i *sparkProviderInitializer) ValidateConfig(config *ProviderConfig) error } func (i *sparkProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { + config.setDefaultCapabilities(ApiNameChatCompletion) return &sparkProvider{ config: config, contextCache: createContextCache(&config), @@ -67,7 +68,7 @@ func (p *sparkProvider) GetProviderType() string { } func (p *sparkProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { - if apiName != ApiNameChatCompletion { + if !p.config.isSupportedAPI(apiName) { return errUnsupportedApiName } p.config.handleRequestHeaders(p, ctx, apiName, log) @@ -75,7 +76,7 @@ func (p *sparkProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNam } func (p *sparkProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { - if apiName != ApiNameChatCompletion { + if !p.config.isSupportedAPI(apiName) { return types.ActionContinue, errUnsupportedApiName } return p.config.handleRequestBody(p, p.contextCache, ctx, apiName, body, log) diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/stepfun.go 
b/plugins/wasm-go/extensions/ai-proxy/provider/stepfun.go index 4cd5de7dd5..b473bfebff 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/stepfun.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/stepfun.go @@ -25,6 +25,7 @@ func (m *stepfunProviderInitializer) ValidateConfig(config *ProviderConfig) erro } func (m *stepfunProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { + config.setDefaultCapabilities(ApiNameChatCompletion) return &stepfunProvider{ config: config, contextCache: createContextCache(&config), @@ -41,7 +42,7 @@ func (m *stepfunProvider) GetProviderType() string { } func (m *stepfunProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { - if apiName != ApiNameChatCompletion { + if !m.config.isSupportedAPI(apiName) { return errUnsupportedApiName } m.config.handleRequestHeaders(m, ctx, apiName, log) @@ -49,7 +50,7 @@ func (m *stepfunProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiN } func (m *stepfunProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { - if apiName != ApiNameChatCompletion { + if !m.config.isSupportedAPI(apiName) { return types.ActionContinue, errUnsupportedApiName } return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log) diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/together_ai.go b/plugins/wasm-go/extensions/ai-proxy/provider/together_ai.go index 8e24280308..e6337c6a7a 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/together_ai.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/together_ai.go @@ -2,11 +2,12 @@ package provider import ( "errors" + "net/http" + "strings" + "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" - "net/http" - "strings" ) const ( @@ -24,6 +25,7 @@ func (m 
*togetherAIProviderInitializer) ValidateConfig(config *ProviderConfig) e } func (m *togetherAIProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { + config.setDefaultCapabilities(ApiNameChatCompletion) return &togetherAIProvider{ config: config, contextCache: createContextCache(&config), @@ -40,7 +42,7 @@ func (m *togetherAIProvider) GetProviderType() string { } func (m *togetherAIProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { - if apiName != ApiNameChatCompletion { + if !m.config.isSupportedAPI(apiName) { return errUnsupportedApiName } m.config.handleRequestHeaders(m, ctx, apiName, log) @@ -48,7 +50,7 @@ func (m *togetherAIProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName A } func (m *togetherAIProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { - if apiName != ApiNameChatCompletion { + if !m.config.isSupportedAPI(apiName) { return types.ActionContinue, errUnsupportedApiName } return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log) diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/yi.go b/plugins/wasm-go/extensions/ai-proxy/provider/yi.go index a4d09c8774..cd2cb7edb8 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/yi.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/yi.go @@ -25,6 +25,7 @@ func (m *yiProviderInitializer) ValidateConfig(config *ProviderConfig) error { } func (m *yiProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { + config.setDefaultCapabilities(ApiNameChatCompletion) return &yiProvider{ config: config, contextCache: createContextCache(&config), @@ -41,7 +42,7 @@ func (m *yiProvider) GetProviderType() string { } func (m *yiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { - if apiName != ApiNameChatCompletion { + if !m.config.isSupportedAPI(apiName) { return errUnsupportedApiName } 
m.config.handleRequestHeaders(m, ctx, apiName, log) @@ -49,7 +50,7 @@ func (m *yiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, } func (m *yiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { - if apiName != ApiNameChatCompletion { + if !m.config.isSupportedAPI(apiName) { return types.ActionContinue, errUnsupportedApiName } return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log) diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/zhipuai.go b/plugins/wasm-go/extensions/ai-proxy/provider/zhipuai.go index e8a873c789..9f395977f1 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/zhipuai.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/zhipuai.go @@ -25,6 +25,7 @@ func (m *zhipuAiProviderInitializer) ValidateConfig(config *ProviderConfig) erro } func (m *zhipuAiProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { + config.setDefaultCapabilities(ApiNameChatCompletion) return &zhipuAiProvider{ config: config, contextCache: createContextCache(&config), @@ -41,7 +42,7 @@ func (m *zhipuAiProvider) GetProviderType() string { } func (m *zhipuAiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { - if apiName != ApiNameChatCompletion { + if !m.config.isSupportedAPI(apiName) { return errUnsupportedApiName } m.config.handleRequestHeaders(m, ctx, apiName, log) @@ -49,7 +50,7 @@ func (m *zhipuAiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiN } func (m *zhipuAiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { - if apiName != ApiNameChatCompletion { + if !m.config.isSupportedAPI(apiName) { return types.ActionContinue, errUnsupportedApiName } return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log) From aa1693b25b2f0685a6729eef42002daaad795e11 Mon Sep 17 00:00:00 2001 From: "yu.deng" 
Date: Wed, 22 Jan 2025 17:52:16 +0800 Subject: [PATCH 02/10] feature: support custom model capability-path mapping --- .../extensions/ai-proxy/provider/ai360.go | 12 +++++-- .../extensions/ai-proxy/provider/azure.go | 11 +++++- .../extensions/ai-proxy/provider/baichuan.go | 15 +++++--- .../extensions/ai-proxy/provider/baidu.go | 12 +++++-- .../extensions/ai-proxy/provider/claude.go | 12 +++++-- .../ai-proxy/provider/cloudflare.go | 7 +++- .../extensions/ai-proxy/provider/cohere.go | 13 +++++-- .../extensions/ai-proxy/provider/coze.go | 8 ++++- .../extensions/ai-proxy/provider/deepl.go | 8 ++++- .../extensions/ai-proxy/provider/deepseek.go | 13 +++++-- .../extensions/ai-proxy/provider/doubao.go | 15 ++++++-- .../extensions/ai-proxy/provider/gemini.go | 10 +++++- .../extensions/ai-proxy/provider/github.go | 16 +++++---- .../extensions/ai-proxy/provider/groq.go | 10 ++++-- .../extensions/ai-proxy/provider/hunyuan.go | 9 ++++- .../extensions/ai-proxy/provider/minimax.go | 23 +++++------- .../extensions/ai-proxy/provider/mistral.go | 11 +++++- .../extensions/ai-proxy/provider/moonshot.go | 10 ++++-- .../extensions/ai-proxy/provider/ollama.go | 15 ++++---- .../extensions/ai-proxy/provider/openai.go | 17 ++++----- .../extensions/ai-proxy/provider/provider.go | 35 +++++++++---------- .../extensions/ai-proxy/provider/qwen.go | 15 +++++--- .../extensions/ai-proxy/provider/spark.go | 10 ++++-- .../extensions/ai-proxy/provider/stepfun.go | 11 ++++-- .../ai-proxy/provider/together_ai.go | 10 ++++-- .../extensions/ai-proxy/provider/yi.go | 10 ++++-- .../extensions/ai-proxy/provider/zhipuai.go | 15 ++++++-- .../wasm-go/extensions/ai-proxy/util/http.go | 11 ++++++ 28 files changed, 268 insertions(+), 96 deletions(-) diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/ai360.go b/plugins/wasm-go/extensions/ai-proxy/provider/ai360.go index 44d2176591..57b092cd82 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/ai360.go +++ 
b/plugins/wasm-go/extensions/ai-proxy/provider/ai360.go @@ -22,6 +22,13 @@ type ai360Provider struct { contextCache *contextCache } +func (m *ai360ProviderInitializer) DefaultCapabilities() map[string]string { + return map[string]string{ + string(ApiNameChatCompletion): PathOpenAIChatCompletions, + string(ApiNameEmbeddings): PathOpenAIEmbeddings, + } +} + func (m *ai360ProviderInitializer) ValidateConfig(config *ProviderConfig) error { if config.apiTokens == nil || len(config.apiTokens) == 0 { return errors.New("no apiToken found in provider config") @@ -30,7 +37,7 @@ func (m *ai360ProviderInitializer) ValidateConfig(config *ProviderConfig) error } func (m *ai360ProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { - config.setDefaultCapabilities(ApiNameChatCompletion, ApiNameEmbeddings) + config.setDefaultCapabilities(m.DefaultCapabilities()) return &ai360Provider{ config: config, contextCache: createContextCache(&config), @@ -51,7 +58,7 @@ func (m *ai360Provider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNam } func (m *ai360Provider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { - if apiName != ApiNameChatCompletion && apiName != ApiNameEmbeddings { + if !m.config.isSupportedAPI(apiName) { return types.ActionContinue, errUnsupportedApiName } return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log) @@ -59,5 +66,6 @@ func (m *ai360Provider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, func (m *ai360Provider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { util.OverwriteRequestHostHeader(headers, ai360Domain) + util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), m.config.capabilities) util.OverwriteRequestAuthorizationHeader(headers, m.config.GetApiTokenInUse(ctx)) } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/azure.go 
b/plugins/wasm-go/extensions/ai-proxy/provider/azure.go index 1e92d5182a..bddfccdf09 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/azure.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/azure.go @@ -15,6 +15,15 @@ import ( type azureProviderInitializer struct { } +func (m *azureProviderInitializer) DefaultCapabilities() map[string]string { + return map[string]string{ + // azure 此配置无实质作用,只是为了保持和其他provider的一致性 + // TODO: azure的模式和openai是一致的,只是需要处理前缀,可以在TransformRequestHeaders中处理,以支持通用能力 + string(ApiNameChatCompletion): PathOpenAIChatCompletions, + string(ApiNameEmbeddings): PathOpenAIEmbeddings, + } +} + func (m *azureProviderInitializer) ValidateConfig(config *ProviderConfig) error { if config.azureServiceUrl == "" { return errors.New("missing azureServiceUrl in provider config") @@ -35,7 +44,7 @@ func (m *azureProviderInitializer) CreateProvider(config ProviderConfig) (Provid } else { serviceUrl = u } - config.setDefaultCapabilities(ApiNameChatCompletion) + config.setDefaultCapabilities(m.DefaultCapabilities()) return &azureProvider{ config: config, serviceUrl: serviceUrl, diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/baichuan.go b/plugins/wasm-go/extensions/ai-proxy/provider/baichuan.go index 2367270245..7e0c362148 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/baichuan.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/baichuan.go @@ -12,8 +12,7 @@ import ( // baichuanProvider is the provider for baichuan Ai service. 
const ( - baichuanDomain = "api.baichuan-ai.com" - baichuanChatCompletionPath = "/v1/chat/completions" + baichuanDomain = "api.baichuan-ai.com" ) type baichuanProviderInitializer struct { @@ -26,8 +25,16 @@ func (m *baichuanProviderInitializer) ValidateConfig(config *ProviderConfig) err return nil } +func (m *baichuanProviderInitializer) DefaultCapabilities() map[string]string { + return map[string]string{ + // 百川AI的chat和embeddings接口和OpenAI的chat和embeddings接口一样 + string(ApiNameChatCompletion): PathOpenAIChatCompletions, + string(ApiNameEmbeddings): PathOpenAIEmbeddings, + } +} + func (m *baichuanProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { - config.setDefaultCapabilities(ApiNameChatCompletion) + config.setDefaultCapabilities(m.DefaultCapabilities()) return &baichuanProvider{ config: config, contextCache: createContextCache(&config), @@ -59,7 +66,7 @@ func (m *baichuanProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiNam } func (m *baichuanProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { - util.OverwriteRequestPathHeader(headers, baichuanChatCompletionPath) + util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), m.config.capabilities) util.OverwriteRequestHostHeader(headers, baichuanDomain) util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx)) headers.Del("Content-Length") diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/baidu.go b/plugins/wasm-go/extensions/ai-proxy/provider/baidu.go index fd5f422608..27bf5aaecc 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/baidu.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/baidu.go @@ -14,6 +14,7 @@ import ( const ( baiduDomain = "qianfan.baidubce.com" baiduChatCompletionPath = "/v2/chat/completions" + baiduEmbeddings = "/v2/embeddings" ) type baiduProviderInitializer struct{} @@ -25,8 +26,15 @@ func (g *baiduProviderInitializer) 
ValidateConfig(config *ProviderConfig) error return nil } +func (g *baiduProviderInitializer) DefaultCapabilities() map[string]string { + return map[string]string{ + string(ApiNameChatCompletion): baiduChatCompletionPath, + string(ApiNameEmbeddings): baiduEmbeddings, + } +} + func (g *baiduProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { - config.setDefaultCapabilities(ApiNameChatCompletion) + config.setDefaultCapabilities(g.DefaultCapabilities()) return &baiduProvider{ config: config, contextCache: createContextCache(&config), @@ -58,7 +66,7 @@ func (g *baiduProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, } func (g *baiduProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { - util.OverwriteRequestPathHeader(headers, baiduChatCompletionPath) + util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), g.config.capabilities) util.OverwriteRequestHostHeader(headers, baiduDomain) util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+g.config.GetApiTokenInUse(ctx)) headers.Del("Content-Length") diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/claude.go b/plugins/wasm-go/extensions/ai-proxy/provider/claude.go index 604be7574d..470ea4de87 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/claude.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/claude.go @@ -85,8 +85,16 @@ func (c *claudeProviderInitializer) ValidateConfig(config *ProviderConfig) error return nil } +func (c *claudeProviderInitializer) DefaultCapabilities() map[string]string { + return map[string]string{ + string(ApiNameChatCompletion): claudeChatCompletionPath, + // docs: https://docs.anthropic.com/en/docs/build-with-claude/embeddings#voyage-http-api + string(ApiNameEmbeddings): PathOpenAIEmbeddings, + } +} + func (c *claudeProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { - config.setDefaultCapabilities(ApiNameChatCompletion) + 
config.setDefaultCapabilities(c.DefaultCapabilities()) return &claudeProvider{ config: config, contextCache: createContextCache(&config), @@ -111,7 +119,7 @@ func (c *claudeProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNa } func (c *claudeProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { - util.OverwriteRequestPathHeader(headers, claudeChatCompletionPath) + util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), c.config.capabilities) util.OverwriteRequestHostHeader(headers, claudeDomain) headers.Set("x-api-key", c.config.GetApiTokenInUse(ctx)) diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/cloudflare.go b/plugins/wasm-go/extensions/ai-proxy/provider/cloudflare.go index 0c5c96a350..22b9cd4286 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/cloudflare.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/cloudflare.go @@ -25,9 +25,14 @@ func (c *cloudflareProviderInitializer) ValidateConfig(config *ProviderConfig) e } return nil } +func (c *cloudflareProviderInitializer) DefaultCapabilities() map[string]string { + return map[string]string{ + string(ApiNameChatCompletion): cloudflareChatCompletionPath, + } +} func (c *cloudflareProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { - config.setDefaultCapabilities(ApiNameChatCompletion) + config.setDefaultCapabilities(c.DefaultCapabilities()) return &cloudflareProvider{ config: config, contextCache: createContextCache(&config), diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/cohere.go b/plugins/wasm-go/extensions/ai-proxy/provider/cohere.go index f904ccdd22..0964051454 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/cohere.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/cohere.go @@ -12,8 +12,10 @@ import ( ) const ( - cohereDomain = "api.cohere.com" + cohereDomain = "api.cohere.com" + // TODO: 现在cohere有v2, 也有embeddings, 考虑更多支持: 
https://docs.cohere.com/v2/reference/rerank cohereChatCompletionPath = "/v1/chat" + cohereRerankPath = "/v1/rerank" ) type cohereProviderInitializer struct{} @@ -25,8 +27,15 @@ func (m *cohereProviderInitializer) ValidateConfig(config *ProviderConfig) error return nil } +func (m *cohereProviderInitializer) DefaultCapabilities() map[string]string { + return map[string]string{ + string(ApiNameChatCompletion): cohereChatCompletionPath, + string(ApiNameCohereV1Rerank): cohereRerankPath, + } +} + func (m *cohereProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { - config.setDefaultCapabilities(ApiNameChatCompletion) + config.setDefaultCapabilities(m.DefaultCapabilities()) return &cohereProvider{ config: config, contextCache: createContextCache(&config), diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/coze.go b/plugins/wasm-go/extensions/ai-proxy/provider/coze.go index ca419552cf..874f7a48ae 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/coze.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/coze.go @@ -21,8 +21,14 @@ func (m *cozeProviderInitializer) ValidateConfig(config *ProviderConfig) error { return nil } +func (m *cozeProviderInitializer) DefaultCapabilities() map[string]string { + return map[string]string{ + // 此配置暂时无实质作用,只是为了保持和其他provider的一致性 + } +} + func (m *cozeProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { - config.setDefaultCapabilities(ApiNameChatCompletion) + config.setDefaultCapabilities(m.DefaultCapabilities()) return &cozeProvider{ config: config, contextCache: createContextCache(&config), diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/deepl.go b/plugins/wasm-go/extensions/ai-proxy/provider/deepl.go index 12a0fdb893..46b0dbab69 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/deepl.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/deepl.go @@ -64,8 +64,14 @@ func (d *deeplProviderInitializer) ValidateConfig(config *ProviderConfig) error return nil 
} +func (d *deeplProviderInitializer) DefaultCapabilities() map[string]string { + return map[string]string{ + string(ApiNameChatCompletion): deeplChatCompletionPath, + } +} + func (d *deeplProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { - config.setDefaultCapabilities(ApiNameChatCompletion) + config.setDefaultCapabilities(d.DefaultCapabilities()) return &deeplProvider{ config: config, contextCache: createContextCache(&config), diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/deepseek.go b/plugins/wasm-go/extensions/ai-proxy/provider/deepseek.go index dd146282a8..81ba9e8562 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/deepseek.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/deepseek.go @@ -12,7 +12,8 @@ import ( // deepseekProvider is the provider for deepseek Ai service. const ( - deepseekDomain = "api.deepseek.com" + deepseekDomain = "api.deepseek.com" + // TODO: 根据文档 docs: https://api-docs.deepseek.com/api/create-chat-completion, path应该是 /chat/completions, 待验证 deepseekChatCompletionPath = "/v1/chat/completions" ) @@ -26,8 +27,14 @@ func (m *deepseekProviderInitializer) ValidateConfig(config *ProviderConfig) err return nil } +func (m *deepseekProviderInitializer) DefaultCapabilities() map[string]string { + return map[string]string{ + string(ApiNameChatCompletion): deepseekChatCompletionPath, + } +} + func (m *deepseekProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { - config.setDefaultCapabilities(ApiNameChatCompletion) + config.setDefaultCapabilities(m.DefaultCapabilities()) return &deepseekProvider{ config: config, contextCache: createContextCache(&config), @@ -59,7 +66,7 @@ func (m *deepseekProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiNam } func (m *deepseekProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { - util.OverwriteRequestPathHeader(headers, deepseekChatCompletionPath) + 
util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), m.config.capabilities) util.OverwriteRequestHostHeader(headers, deepseekDomain) util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx)) headers.Del("Content-Length") diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/doubao.go b/plugins/wasm-go/extensions/ai-proxy/provider/doubao.go index 291b9a27f2..a896078e12 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/doubao.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/doubao.go @@ -13,6 +13,7 @@ import ( const ( doubaoDomain = "ark.cn-beijing.volces.com" doubaoChatCompletionPath = "/api/v3/chat/completions" + doubaoEmbeddingsPath = "/api/v3/embeddings" ) type doubaoProviderInitializer struct{} @@ -24,8 +25,15 @@ func (m *doubaoProviderInitializer) ValidateConfig(config *ProviderConfig) error return nil } +func (m *doubaoProviderInitializer) DefaultCapabilities() map[string]string { + return map[string]string{ + string(ApiNameChatCompletion): doubaoChatCompletionPath, + string(ApiNameEmbeddings): doubaoEmbeddingsPath, + } +} + func (m *doubaoProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { - config.setDefaultCapabilities(ApiNameChatCompletion) + config.setDefaultCapabilities(m.DefaultCapabilities()) return &doubaoProvider{ config: config, contextCache: createContextCache(&config), @@ -57,7 +65,7 @@ func (m *doubaoProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, } func (m *doubaoProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { - util.OverwriteRequestPathHeader(headers, doubaoChatCompletionPath) + util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), m.config.capabilities) util.OverwriteRequestHostHeader(headers, doubaoDomain) util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx)) headers.Del("Content-Length") @@ -67,5 +75,8 @@ func 
(m *doubaoProvider) GetApiName(path string) ApiName { if strings.Contains(path, doubaoChatCompletionPath) { return ApiNameChatCompletion } + if strings.Contains(path, doubaoEmbeddingsPath) { + return ApiNameEmbeddings + } return "" } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/gemini.go b/plugins/wasm-go/extensions/ai-proxy/provider/gemini.go index 4eecfd56a6..7e456115c0 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/gemini.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/gemini.go @@ -35,8 +35,16 @@ func (g *geminiProviderInitializer) ValidateConfig(config *ProviderConfig) error return nil } +func (g *geminiProviderInitializer) DefaultCapabilities() map[string]string { + return map[string]string{ + // path在gemini中没有实际意义,只是为了保持和其他provider的一致性 + string(ApiNameChatCompletion): "_", + string(ApiNameEmbeddings): "_", + } +} + func (g *geminiProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { - config.setDefaultCapabilities(ApiNameChatCompletion, ApiNameEmbeddings) + config.setDefaultCapabilities(g.DefaultCapabilities()) return &geminiProvider{ config: config, contextCache: createContextCache(&config), diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/github.go b/plugins/wasm-go/extensions/ai-proxy/provider/github.go index 5879f63e59..e8a05cc1c9 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/github.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/github.go @@ -32,8 +32,15 @@ func (m *githubProviderInitializer) ValidateConfig(config *ProviderConfig) error return nil } +func (m *githubProviderInitializer) DefaultCapabilities() map[string]string { + return map[string]string{ + string(ApiNameChatCompletion): githubCompletionPath, + string(ApiNameEmbeddings): githubEmbeddingPath, + } +} + func (m *githubProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { - config.setDefaultCapabilities(ApiNameChatCompletion, ApiNameEmbeddings) + 
config.setDefaultCapabilities(m.DefaultCapabilities()) return &githubProvider{ config: config, contextCache: createContextCache(&config), @@ -62,12 +69,7 @@ func (m *githubProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, func (m *githubProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { util.OverwriteRequestHostHeader(headers, githubDomain) - if apiName == ApiNameChatCompletion { - util.OverwriteRequestPathHeader(headers, githubCompletionPath) - } - if apiName == ApiNameEmbeddings { - util.OverwriteRequestPathHeader(headers, githubEmbeddingPath) - } + util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), m.config.capabilities) util.OverwriteRequestAuthorizationHeader(headers, m.config.GetApiTokenInUse(ctx)) } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/groq.go b/plugins/wasm-go/extensions/ai-proxy/provider/groq.go index 052448e97a..c415a707b8 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/groq.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/groq.go @@ -25,8 +25,14 @@ func (g *groqProviderInitializer) ValidateConfig(config *ProviderConfig) error { return nil } +func (g *groqProviderInitializer) DefaultCapabilities() map[string]string { + return map[string]string{ + string(ApiNameChatCompletion): groqChatCompletionPath, + } +} + func (g *groqProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { - config.setDefaultCapabilities(ApiNameChatCompletion) + config.setDefaultCapabilities(g.DefaultCapabilities()) return &groqProvider{ config: config, contextCache: createContextCache(&config), @@ -58,7 +64,7 @@ func (g *groqProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, b } func (g *groqProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { - util.OverwriteRequestPathHeader(headers, groqChatCompletionPath) + 
util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), g.config.capabilities) util.OverwriteRequestHostHeader(headers, groqDomain) util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+g.config.GetApiTokenInUse(ctx)) headers.Del("Content-Length") diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/hunyuan.go b/plugins/wasm-go/extensions/ai-proxy/provider/hunyuan.go index dd38fbb0c8..4cb6e3a3f8 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/hunyuan.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/hunyuan.go @@ -93,8 +93,14 @@ func (m *hunyuanProviderInitializer) ValidateConfig(config *ProviderConfig) erro return nil } +func (m *hunyuanProviderInitializer) DefaultCapabilities() map[string]string { + return map[string]string{ + string(ApiNameChatCompletion): hunyuanRequestPath, + } +} + func (m *hunyuanProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { - config.setDefaultCapabilities(ApiNameChatCompletion) + config.setDefaultCapabilities(m.DefaultCapabilities()) return &hunyuanProvider{ config: config, client: wrapper.NewClusterClient(wrapper.RouteCluster{ @@ -127,6 +133,7 @@ func (m *hunyuanProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiN func (m *hunyuanProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { util.OverwriteRequestHostHeader(headers, hunyuanDomain) util.OverwriteRequestPathHeader(headers, hunyuanRequestPath) + util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), m.config.capabilities) // 添加 hunyuan 需要的自定义字段 headers.Set(actionKey, hunyuanChatCompletionTCAction) diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/minimax.go b/plugins/wasm-go/extensions/ai-proxy/provider/minimax.go index adc98ce53e..c9f7aa1030 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/minimax.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/minimax.go @@ -41,7 +41,7 @@ type minimaxProviderInitializer 
struct { func (m *minimaxProviderInitializer) ValidateConfig(config *ProviderConfig) error { // If using the chat completion Pro API, a group ID must be set. if minimaxApiTypePro == config.minimaxApiType && config.minimaxGroupId == "" { - return errors.New(fmt.Sprintf("missing minimaxGroupId in provider config when minimaxApiType is %s", minimaxApiTypePro)) + return fmt.Errorf("missing minimaxGroupId in provider config when minimaxApiType is %s", minimaxApiTypePro) } if config.apiTokens == nil || len(config.apiTokens) == 0 { return errors.New("no apiToken found in provider config") @@ -49,8 +49,15 @@ func (m *minimaxProviderInitializer) ValidateConfig(config *ProviderConfig) erro return nil } +func (m *minimaxProviderInitializer) DefaultCapabilities() map[string]string { + return map[string]string{ + // minimax 替换path的时候,要根据modelmapping替换,这儿的配置无实质作用,只是为了保持和其他provider的一致性 + string(ApiNameChatCompletion): minimaxChatCompletionV2Path, + } +} + func (m *minimaxProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { - config.setDefaultCapabilities(ApiNameChatCompletion) + config.setDefaultCapabilities(m.DefaultCapabilities()) return &minimaxProvider{ config: config, contextCache: createContextCache(&config), @@ -269,18 +276,6 @@ type minimaxUsage struct { CompletionTokens int64 `json:"completion_tokens"` } -func (m *minimaxProvider) parseModel(body []byte) (string, error) { - var tempMap map[string]interface{} - if err := json.Unmarshal(body, &tempMap); err != nil { - return "", err - } - model, ok := tempMap["model"].(string) - if !ok { - return "", errors.New("missing model in chat completion request") - } - return model, nil -} - func (m *minimaxProvider) setBotSettings(request *minimaxChatCompletionProRequest, botSettingContent string) { if len(request.BotSettings) == 0 { request.BotSettings = []minimaxBotSetting{ diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/mistral.go b/plugins/wasm-go/extensions/ai-proxy/provider/mistral.go 
index a3142fb4ed..9f7a71f06f 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/mistral.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/mistral.go @@ -22,8 +22,17 @@ func (m *mistralProviderInitializer) ValidateConfig(config *ProviderConfig) erro return nil } +func (m *mistralProviderInitializer) DefaultCapabilities() map[string]string { + return map[string]string{ + // mistral的chat接口和OpenAI的chat接口一样 + // docs: https://docs.mistral.ai/api/ + string(ApiNameChatCompletion): PathOpenAIChatCompletions, + string(ApiNameEmbeddings): PathOpenAIEmbeddings, + } +} + func (m *mistralProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { - config.setDefaultCapabilities(ApiNameChatCompletion) + config.setDefaultCapabilities(m.DefaultCapabilities()) return &mistralProvider{ config: config, contextCache: createContextCache(&config), diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/moonshot.go b/plugins/wasm-go/extensions/ai-proxy/provider/moonshot.go index 910f9cd924..b49800d621 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/moonshot.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/moonshot.go @@ -34,8 +34,14 @@ func (m *moonshotProviderInitializer) ValidateConfig(config *ProviderConfig) err return nil } +func (m *moonshotProviderInitializer) DefaultCapabilities() map[string]string { + return map[string]string{ + string(ApiNameChatCompletion): moonshotChatCompletionPath, + } +} + func (m *moonshotProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { - config.setDefaultCapabilities(ApiNameChatCompletion) + config.setDefaultCapabilities(m.DefaultCapabilities()) return &moonshotProvider{ config: config, client: wrapper.NewClusterClient(wrapper.RouteCluster{ @@ -66,7 +72,7 @@ func (m *moonshotProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName Api } func (m *moonshotProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { - 
util.OverwriteRequestPathHeader(headers, moonshotChatCompletionPath) + util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), m.config.capabilities) util.OverwriteRequestHostHeader(headers, moonshotDomain) util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx)) headers.Del("Content-Length") diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/ollama.go b/plugins/wasm-go/extensions/ai-proxy/provider/ollama.go index 62729e28ed..57ad424ffa 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/ollama.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/ollama.go @@ -12,10 +12,6 @@ import ( // ollamaProvider is the provider for Ollama service. -const ( - ollamaChatCompletionPath = "/v1/chat/completions" -) - type ollamaProviderInitializer struct { } @@ -29,10 +25,17 @@ func (m *ollamaProviderInitializer) ValidateConfig(config *ProviderConfig) error return nil } +func (m *ollamaProviderInitializer) DefaultCapabilities() map[string]string { + return map[string]string{ + // ollama的chat接口path和OpenAI的chat接口一样 + string(ApiNameChatCompletion): PathOpenAIChatCompletions, + } +} + func (m *ollamaProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { serverPortStr := fmt.Sprintf("%d", config.ollamaServerPort) serviceDomain := config.ollamaServerHost + ":" + serverPortStr - config.setDefaultCapabilities(ApiNameChatCompletion) + config.setDefaultCapabilities(m.DefaultCapabilities()) return &ollamaProvider{ config: config, serviceDomain: serviceDomain, @@ -66,7 +69,7 @@ func (m *ollamaProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, } func (m *ollamaProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { - util.OverwriteRequestPathHeader(headers, ollamaChatCompletionPath) + util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), m.config.capabilities) util.OverwriteRequestHostHeader(headers, m.serviceDomain) 
headers.Del("Content-Length") } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/openai.go b/plugins/wasm-go/extensions/ai-proxy/provider/openai.go index 6bf29cabb0..60767b0d02 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/openai.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/openai.go @@ -26,6 +26,13 @@ func (m *openaiProviderInitializer) ValidateConfig(config *ProviderConfig) error return nil } +func (m *openaiProviderInitializer) DefaultCapabilities() map[string]string { + return map[string]string{ + string(ApiNameChatCompletion): defaultOpenaiChatCompletionPath, + string(ApiNameEmbeddings): defaultOpenaiEmbeddingsPath, + } +} + func (m *openaiProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { if config.openaiCustomUrl == "" { return &openaiProvider{ @@ -38,7 +45,7 @@ func (m *openaiProviderInitializer) CreateProvider(config ProviderConfig) (Provi if len(pairs) != 2 { return nil, fmt.Errorf("invalid openaiCustomUrl:%s", config.openaiCustomUrl) } - config.setDefaultCapabilities(ApiNameChatCompletion, ApiNameEmbeddings) + config.setDefaultCapabilities(m.DefaultCapabilities()) return &openaiProvider{ config: config, customDomain: pairs[0], @@ -65,13 +72,7 @@ func (m *openaiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNa func (m *openaiProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { if m.customPath == "" { - switch apiName { - case ApiNameChatCompletion: - util.OverwriteRequestPathHeader(headers, defaultOpenaiChatCompletionPath) - case ApiNameEmbeddings: - ctx.DontReadRequestBody() - util.OverwriteRequestPathHeader(headers, defaultOpenaiEmbeddingsPath) - } + util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), m.config.capabilities) } else { util.OverwriteRequestPathHeader(headers, m.customPath) } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go 
b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go index 24f694f8cb..874b625ee9 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go @@ -28,8 +28,11 @@ const ( ApiNameAudioSpeech ApiName = "openai/v1/audiospeech" ApiNameAudioTranscription ApiName = "openai/v1/audiotranscription" - // TODO: 以下是一些非标准的API名称,需要进一步确认 - // ApiNameCohereRerank ApiName = "cohere/v1/rerank" + PathOpenAIChatCompletions = "/v1/chat/completions" + PathOpenAIEmbeddings = "/v1/embeddings" + + // TODO: 以下是一些非标准的API名称,需要进一步确认是否支持 + ApiNameCohereV1Rerank ApiName = "cohere/v1/rerank" providerTypeMoonshot = "moonshot" providerTypeAzure = "azure" @@ -260,9 +263,9 @@ type ProviderConfig struct { inputVariable string `required:"false" yaml:"inputVariable" json:"inputVariable"` // @Title zh-CN dify中应用类型为workflow时需要设置输出变量,当botType为workflow时一起使用 outputVariable string `required:"false" yaml:"outputVariable" json:"outputVariable"` - // @Title zh-CN 支持的能力 - // @Description zh-CN 开放的ai能力,例如: "openai/v1/chatcompletions" "openai/v1/embeddings" "openai/v1/imagegeneration" "openai/v1/audiospeech" "openai/v1/audiotranscription" - capabilities []string + // @Title zh-CN 额外支持的ai能力 + // @Description zh-CN 开放的ai能力和urlpath映射,例如: {"openai/v1/chatcompletions": "/v1/chat/completions"} + capabilities map[string]string } func (c *ProviderConfig) GetId() string { @@ -375,16 +378,16 @@ func (c *ProviderConfig) FromJson(json gjson.Result) { c.inputVariable = json.Get("inputVariable").String() c.outputVariable = json.Get("outputVariable").String() - c.capabilities = make([]string, 0) - for _, ability := range json.Get("abilities").Array() { + c.capabilities = make(map[string]string) + for capability, pathJson := range json.Get("abilities").Map() { // 过滤掉不受支持的能力 - switch ability.String() { + switch capability { case string(ApiNameChatCompletion), string(ApiNameEmbeddings), string(ApiNameImageGeneration), string(ApiNameAudioSpeech), 
string(ApiNameAudioTranscription): - c.capabilities = append(c.capabilities, ability.String()) + c.capabilities[capability] = pathJson.String() } } } @@ -554,17 +557,13 @@ func doGetMappedModel(model string, modelMapping map[string]string, log wrapper. } func (c *ProviderConfig) isSupportedAPI(apiName ApiName) bool { - for _, ability := range c.capabilities { - if ability == string(apiName) { - return true - } - } - return false + _, exist := c.capabilities[string(apiName)] + return exist } -func (c *ProviderConfig) setDefaultCapabilities(capabilities ...ApiName) { - for _, ability := range capabilities { - c.capabilities = append(c.capabilities, string(ability)) +func (c *ProviderConfig) setDefaultCapabilities(capabilities map[string]string) { + for capability, path := range capabilities { + c.capabilities[capability] = path } } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/qwen.go b/plugins/wasm-go/extensions/ai-proxy/provider/qwen.go index 722e848d3d..d3b7dc7953 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/qwen.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/qwen.go @@ -52,8 +52,15 @@ func (m *qwenProviderInitializer) ValidateConfig(config *ProviderConfig) error { return nil } +func (m *qwenProviderInitializer) DefaultCapabilities() map[string]string { + return map[string]string{ + string(ApiNameChatCompletion): qwenChatCompletionPath, + string(ApiNameEmbeddings): qwenTextEmbeddingPath, + } +} + func (m *qwenProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { - config.setDefaultCapabilities(ApiNameChatCompletion, ApiNameEmbeddings) + config.setDefaultCapabilities(m.DefaultCapabilities()) return &qwenProvider{ config: config, contextCache: createContextCache(&config), @@ -76,10 +83,8 @@ func (m *qwenProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName if m.config.IsOriginal() { } else if m.config.qwenEnableCompatible { util.OverwriteRequestPathHeader(headers, qwenCompatiblePath) - } else if 
apiName == ApiNameChatCompletion { - util.OverwriteRequestPathHeader(headers, qwenChatCompletionPath) - } else if apiName == ApiNameEmbeddings { - util.OverwriteRequestPathHeader(headers, qwenTextEmbeddingPath) + } else if apiName == ApiNameChatCompletion || apiName == ApiNameEmbeddings { + util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), m.config.capabilities) } } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/spark.go b/plugins/wasm-go/extensions/ai-proxy/provider/spark.go index 2148d18dbe..0fd10a8f3b 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/spark.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/spark.go @@ -55,8 +55,14 @@ func (i *sparkProviderInitializer) ValidateConfig(config *ProviderConfig) error return nil } +func (i *sparkProviderInitializer) DefaultCapabilities() map[string]string { + return map[string]string{ + string(ApiNameChatCompletion): sparkChatCompletionPath, + } +} + func (i *sparkProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { - config.setDefaultCapabilities(ApiNameChatCompletion) + config.setDefaultCapabilities(i.DefaultCapabilities()) return &sparkProvider{ config: config, contextCache: createContextCache(&config), @@ -169,7 +175,7 @@ func (p *sparkProvider) appendResponse(responseBuilder *strings.Builder, respons } func (p *sparkProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { - util.OverwriteRequestPathHeader(headers, sparkChatCompletionPath) + util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), p.config.capabilities) util.OverwriteRequestHostHeader(headers, sparkHost) util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+p.config.GetApiTokenInUse(ctx)) } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/stepfun.go b/plugins/wasm-go/extensions/ai-proxy/provider/stepfun.go index b473bfebff..71315621a4 100644 --- 
a/plugins/wasm-go/extensions/ai-proxy/provider/stepfun.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/stepfun.go @@ -24,8 +24,15 @@ func (m *stepfunProviderInitializer) ValidateConfig(config *ProviderConfig) erro return nil } +func (m *stepfunProviderInitializer) DefaultCapabilities() map[string]string { + return map[string]string{ + // stepfun的chat接口path和OpenAI的chat接口一样 + string(ApiNameChatCompletion): stepfunChatCompletionPath, + } +} + func (m *stepfunProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { - config.setDefaultCapabilities(ApiNameChatCompletion) + config.setDefaultCapabilities(m.DefaultCapabilities()) return &stepfunProvider{ config: config, contextCache: createContextCache(&config), @@ -57,7 +64,7 @@ func (m *stepfunProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName } func (m *stepfunProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { - util.OverwriteRequestPathHeader(headers, stepfunChatCompletionPath) + util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), m.config.capabilities) util.OverwriteRequestHostHeader(headers, stepfunDomain) util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx)) headers.Del("Content-Length") diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/together_ai.go b/plugins/wasm-go/extensions/ai-proxy/provider/together_ai.go index e6337c6a7a..dfbeb401ca 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/together_ai.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/together_ai.go @@ -24,8 +24,14 @@ func (m *togetherAIProviderInitializer) ValidateConfig(config *ProviderConfig) e return nil } +func (m *togetherAIProviderInitializer) DefaultCapabilities() map[string]string { + return map[string]string{ + string(ApiNameChatCompletion): togetherAICompletionPath, + } +} + func (m *togetherAIProviderInitializer) CreateProvider(config ProviderConfig) 
(Provider, error) { - config.setDefaultCapabilities(ApiNameChatCompletion) + config.setDefaultCapabilities(m.DefaultCapabilities()) return &togetherAIProvider{ config: config, contextCache: createContextCache(&config), @@ -57,7 +63,7 @@ func (m *togetherAIProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiN } func (m *togetherAIProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { - util.OverwriteRequestPathHeader(headers, togetherAICompletionPath) + util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), m.config.capabilities) util.OverwriteRequestHostHeader(headers, togetherAIDomain) util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx)) headers.Del("Content-Length") diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/yi.go b/plugins/wasm-go/extensions/ai-proxy/provider/yi.go index cd2cb7edb8..3c3db4d5de 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/yi.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/yi.go @@ -24,8 +24,14 @@ func (m *yiProviderInitializer) ValidateConfig(config *ProviderConfig) error { return nil } +func (m *yiProviderInitializer) DefaultCapabilities() map[string]string { + return map[string]string{ + string(ApiNameChatCompletion): yiChatCompletionPath, + } +} + func (m *yiProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { - config.setDefaultCapabilities(ApiNameChatCompletion) + config.setDefaultCapabilities(m.DefaultCapabilities()) return &yiProvider{ config: config, contextCache: createContextCache(&config), @@ -57,7 +63,7 @@ func (m *yiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, bod } func (m *yiProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { - util.OverwriteRequestPathHeader(headers, yiChatCompletionPath) + util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), 
m.config.capabilities) util.OverwriteRequestHostHeader(headers, yiDomain) util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx)) headers.Del("Content-Length") diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/zhipuai.go b/plugins/wasm-go/extensions/ai-proxy/provider/zhipuai.go index 9f395977f1..e95e99fc8b 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/zhipuai.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/zhipuai.go @@ -13,6 +13,7 @@ import ( const ( zhipuAiDomain = "open.bigmodel.cn" zhipuAiChatCompletionPath = "/api/paas/v4/chat/completions" + zhipuAiEmbeddingsPath = "/api/paas/v4/embeddings" ) type zhipuAiProviderInitializer struct{} @@ -24,8 +25,15 @@ func (m *zhipuAiProviderInitializer) ValidateConfig(config *ProviderConfig) erro return nil } +func (m *zhipuAiProviderInitializer) DefaultCapabilities() map[string]string { + return map[string]string{ + string(ApiNameChatCompletion): zhipuAiChatCompletionPath, + string(ApiNameEmbeddings): zhipuAiEmbeddingsPath, + } +} + func (m *zhipuAiProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { - config.setDefaultCapabilities(ApiNameChatCompletion) + config.setDefaultCapabilities(m.DefaultCapabilities()) return &zhipuAiProvider{ config: config, contextCache: createContextCache(&config), @@ -57,7 +65,7 @@ func (m *zhipuAiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName } func (m *zhipuAiProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { - util.OverwriteRequestPathHeader(headers, zhipuAiChatCompletionPath) + util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), m.config.capabilities) util.OverwriteRequestHostHeader(headers, zhipuAiDomain) util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx)) headers.Del("Content-Length") @@ -67,5 +75,8 @@ func (m *zhipuAiProvider) GetApiName(path string) ApiName { if 
strings.Contains(path, zhipuAiChatCompletionPath) { return ApiNameChatCompletion } + if strings.Contains(path, zhipuAiEmbeddingsPath) { + return ApiNameEmbeddings + } return "" } diff --git a/plugins/wasm-go/extensions/ai-proxy/util/http.go b/plugins/wasm-go/extensions/ai-proxy/util/http.go index 4f36871b75..4e5d5067f4 100644 --- a/plugins/wasm-go/extensions/ai-proxy/util/http.go +++ b/plugins/wasm-go/extensions/ai-proxy/util/http.go @@ -57,6 +57,17 @@ func OverwriteRequestPathHeader(headers http.Header, path string) { headers.Set(":path", path) } +func OverwriteRequestPathHeaderByCapability(headers http.Header, apiName string, mapping map[string]string) { + mappedPath, exist := mapping[apiName] + if !exist { + return + } + if originPath, err := proxywasm.GetHttpRequestHeader(":path"); err == nil { + headers.Set("X-ENVOY-ORIGINAL-PATH", originPath) + } + headers.Set(":path", mappedPath) +} + func OverwriteRequestAuthorizationHeader(headers http.Header, credential string) { if exist := headers.Get("X-HI-ORIGINAL-AUTH"); exist == "" { if originAuth := headers.Get("Authorization"); originAuth != "" { From 9130e67bb341a46f3165df62c501f5e93e80ddba Mon Sep 17 00:00:00 2001 From: "yu.deng" Date: Wed, 22 Jan 2025 20:51:28 +0800 Subject: [PATCH 03/10] feature: support hunyuan openai compatibale api --- .../extensions/ai-proxy/provider/hunyuan.go | 40 ++++++++++++++----- .../extensions/ai-proxy/provider/provider.go | 2 +- 2 files changed, 32 insertions(+), 10 deletions(-) diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/hunyuan.go b/plugins/wasm-go/extensions/ai-proxy/provider/hunyuan.go index 4cb6e3a3f8..c71227b1d2 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/hunyuan.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/hunyuan.go @@ -39,6 +39,10 @@ const ( hunyuanAuthKeyLen = 32 hunyuanAuthIdLen = 36 + + // docs: https://cloud.tencent.com/document/product/1729/111007 + hunyuanOpenAiDomain = "api.hunyuan.cloud.tencent.com" + hunyuanOpenAiRequestPath = 
"/v1/chat/completions" ) type hunyuanProviderInitializer struct { @@ -86,6 +90,10 @@ type hunyuanChatMessage struct { } func (m *hunyuanProviderInitializer) ValidateConfig(config *ProviderConfig) error { + // 允许 hunyuanauthid 和 hunyuanauthkey 为空, 当他们都为空的时候,认为是使用openai的 兼容接口 + if len(config.hunyuanAuthId) == 0 && len(config.hunyuanAuthKey) == 0 { + return nil + } // 校验hunyuan id 和 key的合法性 if len(config.hunyuanAuthId) != hunyuanAuthIdLen || len(config.hunyuanAuthKey) != hunyuanAuthKeyLen { return errors.New("hunyuanAuthId / hunyuanAuthKey is illegal in config file") @@ -95,7 +103,7 @@ func (m *hunyuanProviderInitializer) ValidateConfig(config *ProviderConfig) erro func (m *hunyuanProviderInitializer) DefaultCapabilities() map[string]string { return map[string]string{ - string(ApiNameChatCompletion): hunyuanRequestPath, + string(ApiNameChatCompletion): hunyuanOpenAiRequestPath, } } @@ -121,6 +129,10 @@ func (m *hunyuanProvider) GetProviderType() string { return providerTypeHunyuan } +func (m *hunyuanProvider) useOpenAICompatibleAPI() bool { + return len(m.config.hunyuanAuthId) == 0 && len(m.config.hunyuanAuthKey) == 0 +} + func (m *hunyuanProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { if !m.config.isSupportedAPI(apiName) { return errUnsupportedApiName @@ -131,13 +143,17 @@ func (m *hunyuanProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiN } func (m *hunyuanProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { - util.OverwriteRequestHostHeader(headers, hunyuanDomain) - util.OverwriteRequestPathHeader(headers, hunyuanRequestPath) - util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), m.config.capabilities) - - // 添加 hunyuan 需要的自定义字段 - headers.Set(actionKey, hunyuanChatCompletionTCAction) - headers.Set(versionKey, versionValue) + if m.useOpenAICompatibleAPI() { + util.OverwriteRequestHostHeader(headers, hunyuanOpenAiDomain) + 
util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), m.config.capabilities) + util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx)) + } else { + util.OverwriteRequestHostHeader(headers, hunyuanDomain) + util.OverwriteRequestPathHeader(headers, hunyuanRequestPath) + // 添加 hunyuan 需要的自定义字段 + headers.Set(actionKey, hunyuanChatCompletionTCAction) + headers.Set(versionKey, versionValue) + } } // hunyuan 的 OnRequestBody 逻辑中包含了对 headers 签名的逻辑,并且插入 context 以后还要重新计算签名,因此无法复用 handleRequestBody 方法 @@ -145,6 +161,9 @@ func (m *hunyuanProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName if !m.config.isSupportedAPI(apiName) { return types.ActionContinue, errUnsupportedApiName } + if m.useOpenAICompatibleAPI() { + return types.ActionContinue, nil + } // 为header添加时间戳字段 (因为需要根据body进行签名时依赖时间戳,故于body处理部分创建时间戳) var timestamp int64 = time.Now().Unix() @@ -272,6 +291,9 @@ func (m *hunyuanProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName // hunyuan 的 TransformRequestBodyHeaders 方法只在 failover 健康检查的时候会调用 func (m *hunyuanProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header, log wrapper.Log) ([]byte, error) { + if m.useOpenAICompatibleAPI() { + return body, nil + } request := &chatCompletionRequest{} err := m.config.parseRequestAndMapModel(ctx, request, body, log) if err != nil { @@ -297,7 +319,7 @@ func (m *hunyuanProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, a } func (m *hunyuanProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error) { - if m.config.protocol == protocolOriginal { + if m.config.IsOriginal() || m.useOpenAICompatibleAPI() { return chunk, nil } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go index 874b625ee9..4ff62411bc 100644 --- 
a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go @@ -571,7 +571,7 @@ func (c *ProviderConfig) handleRequestBody( provider Provider, contextCache *contextCache, ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log, ) (types.Action, error) { // use original protocol - if c.protocol == protocolOriginal { + if c.IsOriginal() { return types.ActionContinue, nil } From d031d522b67c156556238b104424718d5f297e48 Mon Sep 17 00:00:00 2001 From: "yu.deng" Date: Wed, 22 Jan 2025 20:56:23 +0800 Subject: [PATCH 04/10] feature: support tencent hunyuan embeddings --- plugins/wasm-go/extensions/ai-proxy/provider/hunyuan.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/hunyuan.go b/plugins/wasm-go/extensions/ai-proxy/provider/hunyuan.go index c71227b1d2..a61dff054d 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/hunyuan.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/hunyuan.go @@ -43,6 +43,7 @@ const ( // docs: https://cloud.tencent.com/document/product/1729/111007 hunyuanOpenAiDomain = "api.hunyuan.cloud.tencent.com" hunyuanOpenAiRequestPath = "/v1/chat/completions" + hunyuanOpenAiEmbeddings = "/v1/embeddings" ) type hunyuanProviderInitializer struct { @@ -104,6 +105,7 @@ func (m *hunyuanProviderInitializer) ValidateConfig(config *ProviderConfig) erro func (m *hunyuanProviderInitializer) DefaultCapabilities() map[string]string { return map[string]string{ string(ApiNameChatCompletion): hunyuanOpenAiRequestPath, + string(ApiNameEmbeddings): hunyuanOpenAiEmbeddings, } } @@ -435,6 +437,9 @@ func (m *hunyuanProvider) convertChunkFromHunyuanToOpenAI(ctx wrapper.HttpContex } func (m *hunyuanProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) { + if m.config.IsOriginal() || m.useOpenAICompatibleAPI() { + return body, nil + } log.Debugf("#debug nash5# onRespBody's 
resp is: %s", string(body)) hunyuanResponse := &hunyuanTextGenResponseNonStreaming{} if err := json.Unmarshal(body, hunyuanResponse); err != nil { From c1b958652b5f012da7e05e5078a94d76a6f849de Mon Sep 17 00:00:00 2001 From: "yu.deng" Date: Thu, 23 Jan 2025 11:49:55 +0800 Subject: [PATCH 05/10] fix: fix body rewrite error & Fix non chat type model mapping --- plugins/wasm-go/extensions/ai-proxy/main.go | 3 -- .../extensions/ai-proxy/provider/claude.go | 6 +++ .../extensions/ai-proxy/provider/cohere.go | 3 ++ .../extensions/ai-proxy/provider/dify.go | 10 +++-- .../extensions/ai-proxy/provider/hunyuan.go | 2 +- .../extensions/ai-proxy/provider/model.go | 45 ++++++++++++------- .../extensions/ai-proxy/provider/moonshot.go | 4 ++ .../extensions/ai-proxy/provider/provider.go | 43 +++++++++--------- .../extensions/ai-proxy/provider/qwen.go | 7 ++- 9 files changed, 77 insertions(+), 46 deletions(-) diff --git a/plugins/wasm-go/extensions/ai-proxy/main.go b/plugins/wasm-go/extensions/ai-proxy/main.go index a41a73cd8d..4f08ab82c0 100644 --- a/plugins/wasm-go/extensions/ai-proxy/main.go +++ b/plugins/wasm-go/extensions/ai-proxy/main.go @@ -274,9 +274,6 @@ func getOpenAiApiName(path string) provider.ApiName { if strings.HasSuffix(path, "/v1/embeddings") { return provider.ApiNameEmbeddings } - if strings.HasSuffix(path, "/v1/audio/transcriptions") { - return provider.ApiNameAudioTranscription - } if strings.HasSuffix(path, "/v1/audio/speech") { return provider.ApiNameAudioSpeech } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/claude.go b/plugins/wasm-go/extensions/ai-proxy/provider/claude.go index 470ea4de87..a1412114f5 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/claude.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/claude.go @@ -139,6 +139,9 @@ func (c *claudeProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, } func (c *claudeProvider) TransformRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) 
([]byte, error) { + if apiName != ApiNameChatCompletion { + return c.config.defaultTransformRequestBody(ctx, apiName, body, log) + } request := &chatCompletionRequest{} if err := c.config.parseRequestAndMapModel(ctx, request, body, log); err != nil { return nil, err @@ -148,6 +151,9 @@ func (c *claudeProvider) TransformRequestBody(ctx wrapper.HttpContext, apiName A } func (c *claudeProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) { + if apiName != ApiNameChatCompletion { + return body, nil + } claudeResponse := &claudeTextGenResponse{} if err := json.Unmarshal(body, claudeResponse); err != nil { return nil, fmt.Errorf("unable to unmarshal claude response: %v", err) diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/cohere.go b/plugins/wasm-go/extensions/ai-proxy/provider/cohere.go index 0964051454..931d1b677c 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/cohere.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/cohere.go @@ -107,6 +107,9 @@ func (m *cohereProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiNam } func (m *cohereProvider) TransformRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) { + if apiName != ApiNameChatCompletion { + return m.config.defaultTransformRequestBody(ctx, apiName, body, log) + } request := &chatCompletionRequest{} if err := m.config.parseRequestAndMapModel(ctx, request, body, log); err != nil { return nil, err diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/dify.go b/plugins/wasm-go/extensions/ai-proxy/provider/dify.go index 28b6dec794..2395e97f66 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/dify.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/dify.go @@ -4,13 +4,14 @@ import ( "encoding/json" "errors" "fmt" + "net/http" + "strings" + "time" + "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" 
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" - "net/http" - "strings" - "time" ) const ( @@ -83,6 +84,9 @@ func (d *difyProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, b } func (d *difyProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header, log wrapper.Log) ([]byte, error) { + if apiName != ApiNameChatCompletion { + return d.config.defaultTransformRequestBody(ctx, apiName, body, log) + } request := &chatCompletionRequest{} err := d.config.parseRequestAndMapModel(ctx, request, body, log) if err != nil { diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/hunyuan.go b/plugins/wasm-go/extensions/ai-proxy/provider/hunyuan.go index a61dff054d..7aa0a4ae66 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/hunyuan.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/hunyuan.go @@ -294,7 +294,7 @@ func (m *hunyuanProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName // hunyuan 的 TransformRequestBodyHeaders 方法只在 failover 健康检查的时候会调用 func (m *hunyuanProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header, log wrapper.Log) ([]byte, error) { if m.useOpenAICompatibleAPI() { - return body, nil + return m.config.defaultTransformRequestBody(ctx, apiName, body, log) } request := &chatCompletionRequest{} err := m.config.parseRequestAndMapModel(ctx, request, body, log) diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/model.go b/plugins/wasm-go/extensions/ai-proxy/provider/model.go index 51a65555c7..61bef7467b 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/model.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/model.go @@ -19,21 +19,21 @@ const ( ) type chatCompletionRequest struct { - Model string `json:"model"` - Messages []chatMessage `json:"messages"` - MaxTokens int 
`json:"max_tokens,omitempty"` - FrequencyPenalty float64 `json:"frequency_penalty,omitempty"` - N int `json:"n,omitempty"` - PresencePenalty float64 `json:"presence_penalty,omitempty"` - Seed int `json:"seed,omitempty"` - Stream bool `json:"stream,omitempty"` - StreamOptions *streamOptions `json:"stream_options,omitempty"` - Temperature float64 `json:"temperature,omitempty"` - TopP float64 `json:"top_p,omitempty"` - Tools []tool `json:"tools,omitempty"` - ToolChoice *toolChoice `json:"tool_choice,omitempty"` - User string `json:"user,omitempty"` - Stop []string `json:"stop,omitempty"` + Model string `json:"model"` + Messages []chatMessage `json:"messages"` + MaxTokens int `json:"max_tokens,omitempty"` + FrequencyPenalty float64 `json:"frequency_penalty,omitempty"` + N int `json:"n,omitempty"` + PresencePenalty float64 `json:"presence_penalty,omitempty"` + Seed int `json:"seed,omitempty"` + Stream bool `json:"stream,omitempty"` + StreamOptions *streamOptions `json:"stream_options,omitempty"` + Temperature float64 `json:"temperature,omitempty"` + TopP float64 `json:"top_p,omitempty"` + Tools []tool `json:"tools,omitempty"` + ToolChoice *toolChoice `json:"tool_choice,omitempty"` + User string `json:"user,omitempty"` + Stop []string `json:"stop,omitempty"` ResponseFormat map[string]interface{} `json:"response_format,omitempty"` } @@ -230,6 +230,21 @@ func (e *streamEvent) setValue(key, value string) { } } +// https://platform.openai.com/docs/guides/images +type imageGenerationRequest struct { + Model string `json:"model"` + Prompt string `json:"prompt"` + N int `json:"n,omitempty"` + Size string `json:"size,omitempty"` +} + +// https://platform.openai.com/docs/guides/text-to-speech +type audioSpeechRequest struct { + Model string `json:"model"` + Input string `json:"input"` + Voice string `json:"voice"` +} + type embeddingsRequest struct { Input interface{} `json:"input"` Model string `json:"model"` diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/moonshot.go
b/plugins/wasm-go/extensions/ai-proxy/provider/moonshot.go index b49800d621..ae7315c459 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/moonshot.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/moonshot.go @@ -84,6 +84,10 @@ func (m *moonshotProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiNam if !m.config.isSupportedAPI(apiName) { return types.ActionContinue, errUnsupportedApiName } + // 非chat类型的请求,不做处理 + if apiName != ApiNameChatCompletion { + return types.ActionContinue, nil + } request := &chatCompletionRequest{} if err := m.config.parseRequestAndMapModel(ctx, request, body, log); err != nil { diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go index 4ff62411bc..6c810f6dbe 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go @@ -1,7 +1,6 @@ package provider import ( - "encoding/json" "errors" "math/rand" "net/http" @@ -12,6 +11,7 @@ import ( "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" "github.com/tidwall/gjson" + "github.com/tidwall/sjson" ) type ApiName string @@ -22,11 +22,10 @@ const ( // ApiName 格式 {vendor}/{version}/{apitype} // 表示遵循 厂商/版本/接口类型 的格式 // 目前openai是事实意义上的标准,但是也有其他厂商存在其他任务的一些可能的标准,比如cohere的rerank - ApiNameChatCompletion ApiName = "openai/v1/chatcompletions" - ApiNameEmbeddings ApiName = "openai/v1/embeddings" - ApiNameImageGeneration ApiName = "openai/v1/imagegeneration" - ApiNameAudioSpeech ApiName = "openai/v1/audiospeech" - ApiNameAudioTranscription ApiName = "openai/v1/audiotranscription" + ApiNameChatCompletion ApiName = "openai/v1/chatcompletions" + ApiNameEmbeddings ApiName = "openai/v1/embeddings" + ApiNameImageGeneration ApiName = "openai/v1/imagegeneration" + ApiNameAudioSpeech ApiName = "openai/v1/audiospeech" PathOpenAIChatCompletions = "/v1/chat/completions" PathOpenAIEmbeddings =
"/v1/embeddings" @@ -379,23 +378,19 @@ func (c *ProviderConfig) FromJson(json gjson.Result) { c.outputVariable = json.Get("outputVariable").String() c.capabilities = make(map[string]string) - for capability, pathJson := range json.Get("abilities").Map() { + for capability, pathJson := range json.Get("capabilities").Map() { // 过滤掉不受支持的能力 switch capability { case string(ApiNameChatCompletion), string(ApiNameEmbeddings), string(ApiNameImageGeneration), - string(ApiNameAudioSpeech), - string(ApiNameAudioTranscription): + string(ApiNameAudioSpeech): c.capabilities[capability] = pathJson.String() } } } func (c *ProviderConfig) Validate() error { - if c.timeout < 0 { - return errors.New("invalid timeout in config") - } if c.protocol != protocolOpenAI && c.protocol != protocolOriginal { return errors.New("invalid protocol in config") } @@ -528,7 +523,7 @@ func getMappedModel(model string, modelMapping map[string]string, log wrapper.Lo } func doGetMappedModel(model string, modelMapping map[string]string, log wrapper.Log) string { - if modelMapping == nil || len(modelMapping) == 0 { + if len(modelMapping) == 0 { return "" } @@ -618,17 +613,21 @@ func (c *ProviderConfig) handleRequestHeaders(provider Provider, ctx wrapper.Htt } } +// defaultTransformRequestBody 默认的请求体转换方法,只做模型映射,用sjson替换模型名称,不用序列化和反序列化,提高性能 func (c *ProviderConfig) defaultTransformRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) { - var request interface{} - if apiName == ApiNameChatCompletion { - request = &chatCompletionRequest{} - } else { - request = &embeddingsRequest{} - } - if err := c.parseRequestAndMapModel(ctx, request, body, log); err != nil { - return nil, err + switch apiName { + case ApiNameChatCompletion: + stream := gjson.GetBytes(body, "stream").Bool() + if stream { + _ = proxywasm.ReplaceHttpRequestHeader("Accept", "text/event-stream") + ctx.SetContext(ctxKeyIsStreaming, true) + } else { + ctx.SetContext(ctxKeyIsStreaming, false) + } } -
return json.Marshal(request) + model := gjson.GetBytes(body, "model").String() + ctx.SetContext(ctxKeyOriginalRequestModel, model) + return sjson.SetBytes(body, "model", getMappedModel(model, c.modelMapping, log)) } func (c *ProviderConfig) DefaultTransformResponseHeaders(ctx wrapper.HttpContext, headers http.Header) { diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/qwen.go b/plugins/wasm-go/extensions/ai-proxy/provider/qwen.go index d3b7dc7953..e5650f355f 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/qwen.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/qwen.go @@ -89,10 +89,13 @@ func (m *qwenProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName } func (m *qwenProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header, log wrapper.Log) ([]byte, error) { - if apiName == ApiNameChatCompletion { + switch apiName { + case ApiNameChatCompletion: return m.onChatCompletionRequestBody(ctx, body, headers, log) - } else { + case ApiNameEmbeddings: return m.onEmbeddingsRequestBody(ctx, body, log) + default: + return m.config.defaultTransformRequestBody(ctx, apiName, body, log) } } From 78ced6e57e33d6a8549cb2900c5f9c1d317d9162 Mon Sep 17 00:00:00 2001 From: "yu.deng" Date: Thu, 23 Jan 2025 14:08:46 +0800 Subject: [PATCH 06/10] fix cohere rerank --- plugins/wasm-go/extensions/ai-proxy/main.go | 4 ++++ plugins/wasm-go/extensions/ai-proxy/provider/azure.go | 3 +-- plugins/wasm-go/extensions/ai-proxy/provider/baichuan.go | 1 - plugins/wasm-go/extensions/ai-proxy/provider/cohere.go | 4 ++-- plugins/wasm-go/extensions/ai-proxy/provider/coze.go | 4 +--- plugins/wasm-go/extensions/ai-proxy/provider/deepseek.go | 3 ++- plugins/wasm-go/extensions/ai-proxy/provider/gemini.go | 6 +----- plugins/wasm-go/extensions/ai-proxy/provider/mistral.go | 3 +-- plugins/wasm-go/extensions/ai-proxy/provider/provider.go | 3 ++- 9 files changed, 14 insertions(+), 17 deletions(-) diff --git 
a/plugins/wasm-go/extensions/ai-proxy/main.go b/plugins/wasm-go/extensions/ai-proxy/main.go index 4f08ab82c0..6244a0e60e 100644 --- a/plugins/wasm-go/extensions/ai-proxy/main.go +++ b/plugins/wasm-go/extensions/ai-proxy/main.go @@ -280,5 +280,9 @@ func getOpenAiApiName(path string) provider.ApiName { if strings.HasSuffix(path, "/v1/images/generations") { return provider.ApiNameImageGeneration } + // rerank + if strings.HasSuffix(path, "/v1/rerank") { + return provider.ApiNameCohereV1Rerank + } return "" } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/azure.go b/plugins/wasm-go/extensions/ai-proxy/provider/azure.go index bddfccdf09..5fcc378d47 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/azure.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/azure.go @@ -17,8 +17,7 @@ type azureProviderInitializer struct { func (m *azureProviderInitializer) DefaultCapabilities() map[string]string { return map[string]string{ - // azure 此配置无实质作用,只是为了保持和其他provider的一致性 - // TODO: azure的模式和openai是一致的,只是需要处理前缀,可以在TransformRequestHeaders中处理,以支持通用能力 + // TODO: azure's pattern is the same as openai, just need to handle the prefix, can be done in TransformRequestHeaders to support general capabilities string(ApiNameChatCompletion): PathOpenAIChatCompletions, string(ApiNameEmbeddings): PathOpenAIEmbeddings, } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/baichuan.go b/plugins/wasm-go/extensions/ai-proxy/provider/baichuan.go index 7e0c362148..d04c5c7d85 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/baichuan.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/baichuan.go @@ -27,7 +27,6 @@ func (m *baichuanProviderInitializer) ValidateConfig(config *ProviderConfig) err func (m *baichuanProviderInitializer) DefaultCapabilities() map[string]string { return map[string]string{ - // 百川AI的chat和embeddings接口和OpenAI的chat和embeddings接口一样 string(ApiNameChatCompletion): PathOpenAIChatCompletions, string(ApiNameEmbeddings): PathOpenAIEmbeddings, } diff 
--git a/plugins/wasm-go/extensions/ai-proxy/provider/cohere.go b/plugins/wasm-go/extensions/ai-proxy/provider/cohere.go index 931d1b677c..a21964e497 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/cohere.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/cohere.go @@ -13,7 +13,7 @@ import ( const ( cohereDomain = "api.cohere.com" - // TODO: 现在cohere有v2, 也有embeddings, 考虑更多支持: https://docs.cohere.com/v2/reference/rerank + // TODO: support more capabilities, upgrade to v2, docs: https://docs.cohere.com/v2/reference/chat cohereChatCompletionPath = "/v1/chat" cohereRerankPath = "/v1/rerank" ) @@ -100,7 +100,7 @@ func (m *cohereProvider) buildCohereRequest(origin *chatCompletionRequest) *cohe } func (m *cohereProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { - util.OverwriteRequestPathHeader(headers, cohereChatCompletionPath) + util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), m.config.capabilities) util.OverwriteRequestHostHeader(headers, cohereDomain) util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx)) headers.Del("Content-Length") diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/coze.go b/plugins/wasm-go/extensions/ai-proxy/provider/coze.go index 874f7a48ae..8626b226af 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/coze.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/coze.go @@ -22,9 +22,7 @@ func (m *cozeProviderInitializer) ValidateConfig(config *ProviderConfig) error { } func (m *cozeProviderInitializer) DefaultCapabilities() map[string]string { - return map[string]string{ - // 此配置暂时无实质作用,只是为了保持和其他provider的一致性 - } + return map[string]string{} } func (m *cozeProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/deepseek.go b/plugins/wasm-go/extensions/ai-proxy/provider/deepseek.go index 81ba9e8562..c8eca82a5c 100644 --- 
a/plugins/wasm-go/extensions/ai-proxy/provider/deepseek.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/deepseek.go @@ -13,7 +13,8 @@ import ( const ( deepseekDomain = "api.deepseek.com" - // TODO: 根据文档 docs: https://api-docs.deepseek.com/api/create-chat-completion, path应该是 /chat/completions, 待验证 + // TODO: docs: https://api-docs.deepseek.com/api/create-chat-completion + // according to the docs, the path should be /chat/completions, need to be verified deepseekChatCompletionPath = "/v1/chat/completions" ) diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/gemini.go b/plugins/wasm-go/extensions/ai-proxy/provider/gemini.go index 7e456115c0..757f3937d7 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/gemini.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/gemini.go @@ -36,11 +36,7 @@ func (g *geminiProviderInitializer) ValidateConfig(config *ProviderConfig) error } func (g *geminiProviderInitializer) DefaultCapabilities() map[string]string { - return map[string]string{ - // path在gemini中没有实际意义,只是为了保持和其他provider的一致性 - string(ApiNameChatCompletion): "_", - string(ApiNameEmbeddings): "_", - } + return map[string]string{} } func (g *geminiProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/mistral.go b/plugins/wasm-go/extensions/ai-proxy/provider/mistral.go index 9f7a71f06f..3f361a27ac 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/mistral.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/mistral.go @@ -24,8 +24,7 @@ func (m *mistralProviderInitializer) ValidateConfig(config *ProviderConfig) erro func (m *mistralProviderInitializer) DefaultCapabilities() map[string]string { return map[string]string{ - // mistral的chat接口和OpenAI的chat接口一样 - // docs: https://docs.mistral.ai/api/ + // The chat interface of mistral is the same as that of OpenAI.
docs: https://docs.mistral.ai/api/ string(ApiNameChatCompletion): PathOpenAIChatCompletions, string(ApiNameEmbeddings): PathOpenAIEmbeddings, } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go index 6c810f6dbe..d2d9efbbc1 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go @@ -384,7 +384,8 @@ func (c *ProviderConfig) FromJson(json gjson.Result) { case string(ApiNameChatCompletion), string(ApiNameEmbeddings), string(ApiNameImageGeneration), - string(ApiNameAudioSpeech): + string(ApiNameAudioSpeech), + string(ApiNameCohereV1Rerank): c.capabilities[capability] = pathJson.String() } } From d64444388d9eb4ad639d4db44ac76f2d1f6c4485 Mon Sep 17 00:00:00 2001 From: "yu.deng" Date: Thu, 23 Jan 2025 16:07:12 +0800 Subject: [PATCH 07/10] chore: fix readme example error --- plugins/wasm-go/extensions/ai-proxy/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/wasm-go/extensions/ai-proxy/README.md b/plugins/wasm-go/extensions/ai-proxy/README.md index 4062a6f220..b90a32c057 100644 --- a/plugins/wasm-go/extensions/ai-proxy/README.md +++ b/plugins/wasm-go/extensions/ai-proxy/README.md @@ -42,7 +42,7 @@ description: AI 代理插件配置参考 | `customSettings` | array of customSetting | 非必填 | - | 为AI请求指定覆盖或者填充参数 | | `failover` | object | 非必填 | - | 配置 apiToken 的 failover 策略,当 apiToken 不可用时,将其移出 apiToken 列表,待健康检测通过后重新添加回 apiToken 列表 | | `retryOnFailure` | object | 非必填 | - | 当请求失败时立即进行重试 | -| `capabilities` | array of string | 非必填 | - | 部分provider的部分ai能力原生兼容openai/v1格式,不需要重写,可以直接转发,通过此配置项指定来开启转发, 当前支持: openai/v1/chatcompletions, openai/v1/embeddings, openai/v1/imagegeneration, openai/v1/audiospeech, openai/v1/audiotranscription | +| `capabilities` | map of string | 非必填 | - | 部分provider的部分ai能力原生兼容openai/v1格式,不需要重写,可以直接转发,通过此配置项指定来开启转发, key表示的是采用的厂商协议能力,values表示的真实的厂商该能力的api path, 厂商协议能力当前支持: 
openai/v1/chatcompletions, openai/v1/embeddings, openai/v1/imagegeneration, openai/v1/audiospeech, cohere/v1/rerank | `context`的配置字段说明如下: | 名称 | 数据类型 | 填写要求 | 默认值 | 描述 | From 290868b22bc3844a31d2eb3f7cad90c070c440f0 Mon Sep 17 00:00:00 2001 From: "yu.deng" Date: Thu, 23 Jan 2025 20:53:14 +0800 Subject: [PATCH 08/10] feature: allow passthrough unsupported api --- plugins/wasm-go/extensions/ai-proxy/provider/ai360.go | 4 ++-- plugins/wasm-go/extensions/ai-proxy/provider/azure.go | 4 ++-- .../wasm-go/extensions/ai-proxy/provider/baichuan.go | 4 ++-- plugins/wasm-go/extensions/ai-proxy/provider/baidu.go | 4 ++-- plugins/wasm-go/extensions/ai-proxy/provider/claude.go | 4 ++-- .../wasm-go/extensions/ai-proxy/provider/cloudflare.go | 4 ++-- plugins/wasm-go/extensions/ai-proxy/provider/cohere.go | 4 ++-- plugins/wasm-go/extensions/ai-proxy/provider/deepl.go | 4 ++-- .../wasm-go/extensions/ai-proxy/provider/deepseek.go | 4 ++-- plugins/wasm-go/extensions/ai-proxy/provider/dify.go | 4 ++-- plugins/wasm-go/extensions/ai-proxy/provider/doubao.go | 4 ++-- plugins/wasm-go/extensions/ai-proxy/provider/gemini.go | 4 ++-- plugins/wasm-go/extensions/ai-proxy/provider/github.go | 4 ++-- plugins/wasm-go/extensions/ai-proxy/provider/groq.go | 4 ++-- .../wasm-go/extensions/ai-proxy/provider/hunyuan.go | 4 ++-- .../wasm-go/extensions/ai-proxy/provider/minimax.go | 4 ++-- .../wasm-go/extensions/ai-proxy/provider/mistral.go | 4 ++-- .../wasm-go/extensions/ai-proxy/provider/moonshot.go | 4 ++-- plugins/wasm-go/extensions/ai-proxy/provider/ollama.go | 4 ++-- .../wasm-go/extensions/ai-proxy/provider/provider.go | 10 ++++++++++ plugins/wasm-go/extensions/ai-proxy/provider/qwen.go | 6 +++--- plugins/wasm-go/extensions/ai-proxy/provider/spark.go | 4 ++-- .../wasm-go/extensions/ai-proxy/provider/stepfun.go | 4 ++-- .../extensions/ai-proxy/provider/together_ai.go | 4 ++-- plugins/wasm-go/extensions/ai-proxy/provider/yi.go | 4 ++-- .../wasm-go/extensions/ai-proxy/provider/zhipuai.go | 4 ++-- 26 
files changed, 61 insertions(+), 51 deletions(-) diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/ai360.go b/plugins/wasm-go/extensions/ai-proxy/provider/ai360.go index 57b092cd82..52751a0600 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/ai360.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/ai360.go @@ -50,7 +50,7 @@ func (m *ai360Provider) GetProviderType() string { func (m *ai360Provider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { if !m.config.isSupportedAPI(apiName) { - return errUnsupportedApiName + return m.config.handleUnsupportedAPI() } m.config.handleRequestHeaders(m, ctx, apiName, log) // Delay the header processing to allow changing streaming mode in OnRequestBody @@ -59,7 +59,7 @@ func (m *ai360Provider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNam func (m *ai360Provider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { if !m.config.isSupportedAPI(apiName) { - return types.ActionContinue, errUnsupportedApiName + return types.ActionContinue, m.config.handleUnsupportedAPI() } return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log) } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/azure.go b/plugins/wasm-go/extensions/ai-proxy/provider/azure.go index 5fcc378d47..b663a59dff 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/azure.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/azure.go @@ -64,7 +64,7 @@ func (m *azureProvider) GetProviderType() string { func (m *azureProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { if !m.config.isSupportedAPI(apiName) { - return errUnsupportedApiName + return m.config.handleUnsupportedAPI() } m.config.handleRequestHeaders(m, ctx, apiName, log) return nil @@ -72,7 +72,7 @@ func (m *azureProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNam func (m *azureProvider) OnRequestBody(ctx 
wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { if !m.config.isSupportedAPI(apiName) { - return types.ActionContinue, errUnsupportedApiName + return types.ActionContinue, m.config.handleUnsupportedAPI() } return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log) } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/baichuan.go b/plugins/wasm-go/extensions/ai-proxy/provider/baichuan.go index d04c5c7d85..d0b8e90586 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/baichuan.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/baichuan.go @@ -51,7 +51,7 @@ func (m *baichuanProvider) GetProviderType() string { func (m *baichuanProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { if !m.config.isSupportedAPI(apiName) { - return errUnsupportedApiName + return m.config.handleUnsupportedAPI() } m.config.handleRequestHeaders(m, ctx, apiName, log) return nil @@ -59,7 +59,7 @@ func (m *baichuanProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName Api func (m *baichuanProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { if !m.config.isSupportedAPI(apiName) { - return types.ActionContinue, errUnsupportedApiName + return types.ActionContinue, m.config.handleUnsupportedAPI() } return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log) } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/baidu.go b/plugins/wasm-go/extensions/ai-proxy/provider/baidu.go index 27bf5aaecc..f23f185ed3 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/baidu.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/baidu.go @@ -52,7 +52,7 @@ func (g *baiduProvider) GetProviderType() string { func (g *baiduProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { if !g.config.isSupportedAPI(apiName) { - return errUnsupportedApiName + return 
g.config.handleUnsupportedAPI() } g.config.handleRequestHeaders(g, ctx, apiName, log) return nil @@ -60,7 +60,7 @@ func (g *baiduProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNam func (g *baiduProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { if !g.config.isSupportedAPI(apiName) { - return types.ActionContinue, errUnsupportedApiName + return types.ActionContinue, g.config.handleUnsupportedAPI() } return g.config.handleRequestBody(g, g.contextCache, ctx, apiName, body, log) } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/claude.go b/plugins/wasm-go/extensions/ai-proxy/provider/claude.go index a1412114f5..2e1f95f072 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/claude.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/claude.go @@ -112,7 +112,7 @@ func (c *claudeProvider) GetProviderType() string { func (c *claudeProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { if !c.config.isSupportedAPI(apiName) { - return errUnsupportedApiName + return c.config.handleUnsupportedAPI() } c.config.handleRequestHeaders(c, ctx, apiName, log) return nil @@ -133,7 +133,7 @@ func (c *claudeProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiNam func (c *claudeProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { if !c.config.isSupportedAPI(apiName) { - return types.ActionContinue, errUnsupportedApiName + return types.ActionContinue, c.config.handleUnsupportedAPI() } return c.config.handleRequestBody(c, c.contextCache, ctx, apiName, body, log) } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/cloudflare.go b/plugins/wasm-go/extensions/ai-proxy/provider/cloudflare.go index 22b9cd4286..59d9ce5492 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/cloudflare.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/cloudflare.go @@ -50,7 +50,7 @@ 
func (c *cloudflareProvider) GetProviderType() string { func (c *cloudflareProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { if !c.config.isSupportedAPI(apiName) { - return errUnsupportedApiName + return c.config.handleUnsupportedAPI() } c.config.handleRequestHeaders(c, ctx, apiName, log) return nil @@ -58,7 +58,7 @@ func (c *cloudflareProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName A func (c *cloudflareProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { if !c.config.isSupportedAPI(apiName) { - return types.ActionContinue, errUnsupportedApiName + return types.ActionContinue, c.config.handleUnsupportedAPI() } return c.config.handleRequestBody(c, c.contextCache, ctx, apiName, body, log) } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/cohere.go b/plugins/wasm-go/extensions/ai-proxy/provider/cohere.go index a21964e497..55ab910084 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/cohere.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/cohere.go @@ -67,7 +67,7 @@ func (m *cohereProvider) GetProviderType() string { func (m *cohereProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { if !m.config.isSupportedAPI(apiName) { - return errUnsupportedApiName + return m.config.handleUnsupportedAPI() } m.config.handleRequestHeaders(m, ctx, apiName, log) return nil @@ -75,7 +75,7 @@ func (m *cohereProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNa func (m *cohereProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { if !m.config.isSupportedAPI(apiName) { - return types.ActionContinue, errUnsupportedApiName + return types.ActionContinue, m.config.handleUnsupportedAPI() } return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log) } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/deepl.go 
b/plugins/wasm-go/extensions/ai-proxy/provider/deepl.go index 46b0dbab69..02da91f977 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/deepl.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/deepl.go @@ -84,7 +84,7 @@ func (d *deeplProvider) GetProviderType() string { func (d *deeplProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { if !d.config.isSupportedAPI(apiName) { - return errUnsupportedApiName + return d.config.handleUnsupportedAPI() } d.config.handleRequestHeaders(d, ctx, apiName, log) return nil @@ -97,7 +97,7 @@ func (d *deeplProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName func (d *deeplProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { if !d.config.isSupportedAPI(apiName) { - return types.ActionContinue, errUnsupportedApiName + return types.ActionContinue, d.config.handleUnsupportedAPI() } return d.config.handleRequestBody(d, d.contextCache, ctx, apiName, body, log) } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/deepseek.go b/plugins/wasm-go/extensions/ai-proxy/provider/deepseek.go index c8eca82a5c..eb56b3ab0e 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/deepseek.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/deepseek.go @@ -53,7 +53,7 @@ func (m *deepseekProvider) GetProviderType() string { func (m *deepseekProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { if !m.config.isSupportedAPI(apiName) { - return errUnsupportedApiName + return m.config.handleUnsupportedAPI() } m.config.handleRequestHeaders(m, ctx, apiName, log) return nil @@ -61,7 +61,7 @@ func (m *deepseekProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName Api func (m *deepseekProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { if !m.config.isSupportedAPI(apiName) { - return types.ActionContinue, 
errUnsupportedApiName + return types.ActionContinue, m.config.handleUnsupportedAPI() } return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log) } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/dify.go b/plugins/wasm-go/extensions/ai-proxy/provider/dify.go index 2395e97f66..c6427312ad 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/dify.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/dify.go @@ -52,7 +52,7 @@ func (d *difyProvider) GetProviderType() string { func (d *difyProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { if apiName != ApiNameChatCompletion { - return errUnsupportedApiName + return d.config.handleUnsupportedAPI() } d.config.handleRequestHeaders(d, ctx, apiName, log) return nil @@ -78,7 +78,7 @@ func (d *difyProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName func (d *difyProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { if apiName != ApiNameChatCompletion { - return types.ActionContinue, errUnsupportedApiName + return types.ActionContinue, d.config.handleUnsupportedAPI() } return d.config.handleRequestBody(d, d.contextCache, ctx, apiName, body, log) } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/doubao.go b/plugins/wasm-go/extensions/ai-proxy/provider/doubao.go index a896078e12..1f421b63a4 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/doubao.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/doubao.go @@ -51,7 +51,7 @@ func (m *doubaoProvider) GetProviderType() string { func (m *doubaoProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { if !m.config.isSupportedAPI(apiName) { - return errUnsupportedApiName + return m.config.handleUnsupportedAPI() } m.config.handleRequestHeaders(m, ctx, apiName, log) return nil @@ -59,7 +59,7 @@ func (m *doubaoProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNa 
func (m *doubaoProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { if !m.config.isSupportedAPI(apiName) { - return types.ActionContinue, errUnsupportedApiName + return types.ActionContinue, m.config.handleUnsupportedAPI() } return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log) } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/gemini.go b/plugins/wasm-go/extensions/ai-proxy/provider/gemini.go index 757f3937d7..d67a2a584d 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/gemini.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/gemini.go @@ -58,7 +58,7 @@ func (g *geminiProvider) GetProviderType() string { func (g *geminiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { if !g.config.isSupportedAPI(apiName) { - return errUnsupportedApiName + return g.config.handleUnsupportedAPI() } g.config.handleRequestHeaders(g, ctx, apiName, log) // Delay the header processing to allow changing streaming mode in OnRequestBody @@ -72,7 +72,7 @@ func (g *geminiProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiNam func (g *geminiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { if !g.config.isSupportedAPI(apiName) { - return types.ActionContinue, errUnsupportedApiName + return types.ActionContinue, g.config.handleUnsupportedAPI() } return g.config.handleRequestBody(g, g.contextCache, ctx, apiName, body, log) } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/github.go b/plugins/wasm-go/extensions/ai-proxy/provider/github.go index e8a05cc1c9..71793b473a 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/github.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/github.go @@ -53,7 +53,7 @@ func (m *githubProvider) GetProviderType() string { func (m *githubProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log 
wrapper.Log) error { if !m.config.isSupportedAPI(apiName) { - return errUnsupportedApiName + return m.config.handleUnsupportedAPI() } m.config.handleRequestHeaders(m, ctx, apiName, log) // Delay the header processing to allow changing streaming mode in OnRequestBody @@ -62,7 +62,7 @@ func (m *githubProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNa func (m *githubProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { if !m.config.isSupportedAPI(apiName) { - return types.ActionContinue, errUnsupportedApiName + return types.ActionContinue, m.config.handleUnsupportedAPI() } return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log) } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/groq.go b/plugins/wasm-go/extensions/ai-proxy/provider/groq.go index c415a707b8..b6004d3ce4 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/groq.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/groq.go @@ -50,7 +50,7 @@ func (g *groqProvider) GetProviderType() string { func (g *groqProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { if !g.config.isSupportedAPI(apiName) { - return errUnsupportedApiName + return g.config.handleUnsupportedAPI() } g.config.handleRequestHeaders(g, ctx, apiName, log) return nil @@ -58,7 +58,7 @@ func (g *groqProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName func (g *groqProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { if !g.config.isSupportedAPI(apiName) { - return types.ActionContinue, errUnsupportedApiName + return types.ActionContinue, g.config.handleUnsupportedAPI() } return g.config.handleRequestBody(g, g.contextCache, ctx, apiName, body, log) } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/hunyuan.go b/plugins/wasm-go/extensions/ai-proxy/provider/hunyuan.go index 7aa0a4ae66..c9849f333c 100644 --- 
a/plugins/wasm-go/extensions/ai-proxy/provider/hunyuan.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/hunyuan.go @@ -137,7 +137,7 @@ func (m *hunyuanProvider) useOpenAICompatibleAPI() bool { func (m *hunyuanProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { if !m.config.isSupportedAPI(apiName) { - return errUnsupportedApiName + return m.config.handleUnsupportedAPI() } m.config.handleRequestHeaders(m, ctx, apiName, log) // Delay the header processing to allow changing streaming mode in OnRequestBody @@ -161,7 +161,7 @@ func (m *hunyuanProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiNa // hunyuan 的 OnRequestBody 逻辑中包含了对 headers 签名的逻辑,并且插入 context 以后还要重新计算签名,因此无法复用 handleRequestBody 方法 func (m *hunyuanProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { if !m.config.isSupportedAPI(apiName) { - return types.ActionContinue, errUnsupportedApiName + return types.ActionContinue, m.config.handleUnsupportedAPI() } if m.useOpenAICompatibleAPI() { return types.ActionContinue, nil diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/minimax.go b/plugins/wasm-go/extensions/ai-proxy/provider/minimax.go index c9f7aa1030..19c011393a 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/minimax.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/minimax.go @@ -75,7 +75,7 @@ func (m *minimaxProvider) GetProviderType() string { func (m *minimaxProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { if !m.config.isSupportedAPI(apiName) { - return errUnsupportedApiName + return m.config.handleUnsupportedAPI() } m.config.handleRequestHeaders(m, ctx, apiName, log) // Delay the header processing to allow changing streaming mode in OnRequestBody @@ -90,7 +90,7 @@ func (m *minimaxProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiNa func (m *minimaxProvider) OnRequestBody(ctx wrapper.HttpContext, apiName 
ApiName, body []byte, log wrapper.Log) (types.Action, error) { if !m.config.isSupportedAPI(apiName) { - return types.ActionContinue, errUnsupportedApiName + return types.ActionContinue, m.config.handleUnsupportedAPI() } if minimaxApiTypePro == m.config.minimaxApiType { // Use chat completion Pro API. diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/mistral.go b/plugins/wasm-go/extensions/ai-proxy/provider/mistral.go index 3f361a27ac..7cc992fb65 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/mistral.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/mistral.go @@ -49,7 +49,7 @@ func (m *mistralProvider) GetProviderType() string { func (m *mistralProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { if !m.config.isSupportedAPI(apiName) { - return errUnsupportedApiName + return m.config.handleUnsupportedAPI() } m.config.handleRequestHeaders(m, ctx, apiName, log) return nil @@ -57,7 +57,7 @@ func (m *mistralProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiN func (m *mistralProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { if !m.config.isSupportedAPI(apiName) { - return types.ActionContinue, errUnsupportedApiName + return types.ActionContinue, m.config.handleUnsupportedAPI() } return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log) } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/moonshot.go b/plugins/wasm-go/extensions/ai-proxy/provider/moonshot.go index ae7315c459..646f970098 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/moonshot.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/moonshot.go @@ -65,7 +65,7 @@ func (m *moonshotProvider) GetProviderType() string { func (m *moonshotProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { if !m.config.isSupportedAPI(apiName) { - return errUnsupportedApiName + return 
m.config.handleUnsupportedAPI() } m.config.handleRequestHeaders(m, ctx, apiName, log) return nil @@ -82,7 +82,7 @@ func (m *moonshotProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiN // moonshot 的 body 没有修改,无须实现TransformRequestBody,使用默认的 defaultTransformRequestBody 方法 func (m *moonshotProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { if !m.config.isSupportedAPI(apiName) { - return types.ActionContinue, errUnsupportedApiName + return types.ActionContinue, m.config.handleUnsupportedAPI() } // 非chat类型的请求,不做处理 if apiName != ApiNameChatCompletion { diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/ollama.go b/plugins/wasm-go/extensions/ai-proxy/provider/ollama.go index 57ad424ffa..e15fd89b88 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/ollama.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/ollama.go @@ -55,7 +55,7 @@ func (m *ollamaProvider) GetProviderType() string { func (m *ollamaProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { if !m.config.isSupportedAPI(apiName) { - return errUnsupportedApiName + return m.config.handleUnsupportedAPI() } m.config.handleRequestHeaders(m, ctx, apiName, log) return nil @@ -63,7 +63,7 @@ func (m *ollamaProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNa func (m *ollamaProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { if !m.config.isSupportedAPI(apiName) { - return types.ActionContinue, errUnsupportedApiName + return types.ActionContinue, m.config.handleUnsupportedAPI() } return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log) } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go index d2d9efbbc1..eb984cc28a 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go +++ 
b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go @@ -265,6 +265,9 @@ type ProviderConfig struct { // @Title zh-CN 额外支持的ai能力 // @Description zh-CN 开放的ai能力和urlpath映射,例如: {"openai/v1/chatcompletions": "/v1/chat/completions"} capabilities map[string]string + // @Title zh-CN 是否开启透传 + // @Description zh-CN 如果是插件不支持的API,是否透传请求, 默认为false + passthrsough bool } func (c *ProviderConfig) GetId() string { @@ -450,6 +453,13 @@ func (c *ProviderConfig) ReplaceByCustomSettings(body []byte) ([]byte, error) { return ReplaceByCustomSettings(body, c.customSettings) } +func (c *ProviderConfig) handleUnsupportedAPI() error { + if c.passthrsough { + return nil + } + return errUnsupportedApiName +} + func CreateProvider(pc ProviderConfig) (Provider, error) { initializer, has := providerInitializers[pc.typ] if !has { diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/qwen.go b/plugins/wasm-go/extensions/ai-proxy/provider/qwen.go index e5650f355f..0d4807fcfd 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/qwen.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/qwen.go @@ -105,7 +105,7 @@ func (m *qwenProvider) GetProviderType() string { func (m *qwenProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { if !m.config.isSupportedAPI(apiName) { - return errUnsupportedApiName + return m.config.handleUnsupportedAPI() } m.config.handleRequestHeaders(m, ctx, apiName, log) @@ -150,7 +150,7 @@ func (m *qwenProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, b } if !m.config.isSupportedAPI(apiName) { - return types.ActionContinue, errUnsupportedApiName + return types.ActionContinue, m.config.handleUnsupportedAPI() } return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log) } @@ -290,7 +290,7 @@ func (m *qwenProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName Ap if m.config.isSupportedAPI(apiName) { return body, nil } - return nil, errUnsupportedApiName + return nil, 
m.config.handleUnsupportedAPI() } func (m *qwenProvider) onChatCompletionResponseBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) ([]byte, error) { diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/spark.go b/plugins/wasm-go/extensions/ai-proxy/provider/spark.go index 0fd10a8f3b..5ad514eb19 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/spark.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/spark.go @@ -75,7 +75,7 @@ func (p *sparkProvider) GetProviderType() string { func (p *sparkProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { if !p.config.isSupportedAPI(apiName) { - return errUnsupportedApiName + return p.config.handleUnsupportedAPI() } p.config.handleRequestHeaders(p, ctx, apiName, log) return nil @@ -83,7 +83,7 @@ func (p *sparkProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNam func (p *sparkProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { if !p.config.isSupportedAPI(apiName) { - return types.ActionContinue, errUnsupportedApiName + return types.ActionContinue, p.config.handleUnsupportedAPI() } return p.config.handleRequestBody(p, p.contextCache, ctx, apiName, body, log) } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/stepfun.go b/plugins/wasm-go/extensions/ai-proxy/provider/stepfun.go index 71315621a4..71e208710a 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/stepfun.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/stepfun.go @@ -50,7 +50,7 @@ func (m *stepfunProvider) GetProviderType() string { func (m *stepfunProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { if !m.config.isSupportedAPI(apiName) { - return errUnsupportedApiName + return m.config.handleUnsupportedAPI() } m.config.handleRequestHeaders(m, ctx, apiName, log) return nil @@ -58,7 +58,7 @@ func (m *stepfunProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName 
ApiN func (m *stepfunProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { if !m.config.isSupportedAPI(apiName) { - return types.ActionContinue, errUnsupportedApiName + return types.ActionContinue, m.config.handleUnsupportedAPI() } return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log) } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/together_ai.go b/plugins/wasm-go/extensions/ai-proxy/provider/together_ai.go index dfbeb401ca..33a4203df9 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/together_ai.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/together_ai.go @@ -49,7 +49,7 @@ func (m *togetherAIProvider) GetProviderType() string { func (m *togetherAIProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { if !m.config.isSupportedAPI(apiName) { - return errUnsupportedApiName + return m.config.handleUnsupportedAPI() } m.config.handleRequestHeaders(m, ctx, apiName, log) return nil @@ -57,7 +57,7 @@ func (m *togetherAIProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName A func (m *togetherAIProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { if !m.config.isSupportedAPI(apiName) { - return types.ActionContinue, errUnsupportedApiName + return types.ActionContinue, m.config.handleUnsupportedAPI() } return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log) } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/yi.go b/plugins/wasm-go/extensions/ai-proxy/provider/yi.go index 3c3db4d5de..80b0fbd743 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/yi.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/yi.go @@ -49,7 +49,7 @@ func (m *yiProvider) GetProviderType() string { func (m *yiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { if !m.config.isSupportedAPI(apiName) { - 
return errUnsupportedApiName + return m.config.handleUnsupportedAPI() } m.config.handleRequestHeaders(m, ctx, apiName, log) return nil @@ -57,7 +57,7 @@ func (m *yiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, func (m *yiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { if !m.config.isSupportedAPI(apiName) { - return types.ActionContinue, errUnsupportedApiName + return types.ActionContinue, m.config.handleUnsupportedAPI() } return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log) } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/zhipuai.go b/plugins/wasm-go/extensions/ai-proxy/provider/zhipuai.go index e95e99fc8b..12936667f7 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/zhipuai.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/zhipuai.go @@ -51,7 +51,7 @@ func (m *zhipuAiProvider) GetProviderType() string { func (m *zhipuAiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { if !m.config.isSupportedAPI(apiName) { - return errUnsupportedApiName + return m.config.handleUnsupportedAPI() } m.config.handleRequestHeaders(m, ctx, apiName, log) return nil @@ -59,7 +59,7 @@ func (m *zhipuAiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiN func (m *zhipuAiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { if !m.config.isSupportedAPI(apiName) { - return types.ActionContinue, errUnsupportedApiName + return types.ActionContinue, m.config.handleUnsupportedAPI() } return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log) } From 21df731baef29863aaf33b06bc1f90060ce05d72 Mon Sep 17 00:00:00 2001 From: "yu.deng" Date: Thu, 23 Jan 2025 21:20:07 +0800 Subject: [PATCH 09/10] chore: add docs for passthrought for ai-proxy --- plugins/wasm-go/extensions/ai-proxy/README.md | 1 + 1 file changed, 1 insertion(+) 
diff --git a/plugins/wasm-go/extensions/ai-proxy/README.md b/plugins/wasm-go/extensions/ai-proxy/README.md index b90a32c057..aba67f6b88 100644 --- a/plugins/wasm-go/extensions/ai-proxy/README.md +++ b/plugins/wasm-go/extensions/ai-proxy/README.md @@ -43,6 +43,7 @@ description: AI 代理插件配置参考 | `failover` | object | 非必填 | - | 配置 apiToken 的 failover 策略,当 apiToken 不可用时,将其移出 apiToken 列表,待健康检测通过后重新添加回 apiToken 列表 | | `retryOnFailure` | object | 非必填 | - | 当请求失败时立即进行重试 | | `capabilities` | map of string | 非必填 | - | 部分provider的部分ai能力原生兼容openai/v1格式,不需要重写,可以直接转发,通过此配置项指定来开启转发, key表示的是采用的厂商协议能力,values表示的真实的厂商该能力的api path, 厂商协议能力当前支持: openai/v1/chatcompletions, openai/v1/embeddings, openai/v1/imagegeneration, openai/v1/audiospeech, cohere/v1/rerank | +| `passthrough` | bool | 非必填 | - | 只要是不支持的API能力都直接转发, 此配置是capabilities配置的放大版本,允许任意api透传,就像没有ai-proxy插件一样 | `context`的配置字段说明如下: | 名称 | 数据类型 | 填写要求 | 默认值 | 描述 | From eb23ddff955d73f3ab372c8e3fc9e57bd31b789f Mon Sep 17 00:00:00 2001 From: "yu.deng" Date: Fri, 24 Jan 2025 10:26:32 +0800 Subject: [PATCH 10/10] fix: correcting some errors in 'rewriting logic' --- plugins/wasm-go/extensions/ai-proxy/main.go | 34 ++++++++++++------- .../extensions/ai-proxy/provider/ai360.go | 4 +-- .../extensions/ai-proxy/provider/azure.go | 4 +-- .../extensions/ai-proxy/provider/baichuan.go | 4 +-- .../extensions/ai-proxy/provider/baidu.go | 4 +-- .../extensions/ai-proxy/provider/claude.go | 8 +++-- .../ai-proxy/provider/cloudflare.go | 4 +-- .../extensions/ai-proxy/provider/cohere.go | 4 +-- .../extensions/ai-proxy/provider/deepl.go | 7 ++-- .../extensions/ai-proxy/provider/deepseek.go | 4 +-- .../extensions/ai-proxy/provider/dify.go | 10 ++++-- .../extensions/ai-proxy/provider/doubao.go | 4 +-- .../extensions/ai-proxy/provider/gemini.go | 7 ++-- .../extensions/ai-proxy/provider/github.go | 4 +-- .../extensions/ai-proxy/provider/groq.go | 4 +-- .../extensions/ai-proxy/provider/hunyuan.go | 9 +++-- .../extensions/ai-proxy/provider/minimax.go | 10 ++++-- 
.../extensions/ai-proxy/provider/mistral.go | 4 +-- .../extensions/ai-proxy/provider/moonshot.go | 7 ++-- .../extensions/ai-proxy/provider/ollama.go | 4 +-- .../extensions/ai-proxy/provider/provider.go | 9 ++--- .../extensions/ai-proxy/provider/qwen.go | 6 ++-- .../extensions/ai-proxy/provider/spark.go | 10 ++++-- .../extensions/ai-proxy/provider/stepfun.go | 4 +-- .../ai-proxy/provider/together_ai.go | 4 +-- .../extensions/ai-proxy/provider/yi.go | 4 +-- .../extensions/ai-proxy/provider/zhipuai.go | 4 +-- 27 files changed, 111 insertions(+), 70 deletions(-) diff --git a/plugins/wasm-go/extensions/ai-proxy/main.go b/plugins/wasm-go/extensions/ai-proxy/main.go index 6244a0e60e..1dcea5878f 100644 --- a/plugins/wasm-go/extensions/ai-proxy/main.go +++ b/plugins/wasm-go/extensions/ai-proxy/main.go @@ -78,7 +78,7 @@ func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConf rawPath := ctx.Path() path, _ := url.Parse(rawPath) - apiName := getOpenAiApiName(path.Path) + apiName := getApiName(path.Path) providerConfig := pluginConfig.GetProviderConfig() if providerConfig.IsOriginal() { if handler, ok := activeProvider.(provider.ApiNameHandler); ok { @@ -103,20 +103,25 @@ func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConf // Set the apiToken for the current request. 
providerConfig.SetApiTokenInUse(ctx, log) - hasRequestBody := wrapper.HasRequestBody() err := handler.OnRequestHeaders(ctx, apiName, log) - if err == nil { - if hasRequestBody { - proxywasm.RemoveHttpRequestHeader("Content-Length") - ctx.SetRequestBodyBufferLimit(defaultMaxBodyBytes) - // Delay the header processing to allow changing in OnRequestBody - return types.HeaderStopIteration + if err != nil { + if providerConfig.PassthroughUnsupportedAPI() { + log.Warnf("[onHttpRequestHeader] passthrough unsupported API: %v", err) + ctx.DontReadRequestBody() + return types.ActionContinue } - ctx.DontReadRequestBody() + util.ErrorHandler("ai-proxy.proc_req_headers_failed", fmt.Errorf("failed to process request headers: %v", err)) return types.ActionContinue } - util.ErrorHandler("ai-proxy.proc_req_headers_failed", fmt.Errorf("failed to process request headers: %v", err)) + hasRequestBody := wrapper.HasRequestBody() + if hasRequestBody { + proxywasm.RemoveHttpRequestHeader("Content-Length") + ctx.SetRequestBodyBufferLimit(defaultMaxBodyBytes) + // Delay the header processing to allow changing in OnRequestBody + return types.HeaderStopIteration + } + ctx.DontReadRequestBody() return types.ActionContinue } @@ -151,6 +156,10 @@ func onHttpRequestBody(ctx wrapper.HttpContext, pluginConfig config.PluginConfig if err == nil { return action } + if pluginConfig.GetProviderConfig().PassthroughUnsupportedAPI() { + log.Warnf("[onHttpRequestBody] passthrough unsupported API: %v", err) + return types.ActionContinue + } util.ErrorHandler("ai-proxy.proc_req_body_failed", fmt.Errorf("failed to process request body: %v", err)) } return types.ActionContinue @@ -267,7 +276,8 @@ func checkStream(ctx wrapper.HttpContext, log wrapper.Log) { } } -func getOpenAiApiName(path string) provider.ApiName { +func getApiName(path string) provider.ApiName { + // openai style if strings.HasSuffix(path, "/v1/chat/completions") { return provider.ApiNameChatCompletion } @@ -280,7 +290,7 @@ func 
getOpenAiApiName(path string) provider.ApiName { if strings.HasSuffix(path, "/v1/images/generations") { return provider.ApiNameImageGeneration } - // rerank + // cohere style if strings.HasSuffix(path, "/v1/rerank") { return provider.ApiNameCohereV1Rerank } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/ai360.go b/plugins/wasm-go/extensions/ai-proxy/provider/ai360.go index 52751a0600..57b092cd82 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/ai360.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/ai360.go @@ -50,7 +50,7 @@ func (m *ai360Provider) GetProviderType() string { func (m *ai360Provider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { if !m.config.isSupportedAPI(apiName) { - return m.config.handleUnsupportedAPI() + return errUnsupportedApiName } m.config.handleRequestHeaders(m, ctx, apiName, log) // Delay the header processing to allow changing streaming mode in OnRequestBody @@ -59,7 +59,7 @@ func (m *ai360Provider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNam func (m *ai360Provider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { if !m.config.isSupportedAPI(apiName) { - return types.ActionContinue, m.config.handleUnsupportedAPI() + return types.ActionContinue, errUnsupportedApiName } return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log) } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/azure.go b/plugins/wasm-go/extensions/ai-proxy/provider/azure.go index b663a59dff..5fcc378d47 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/azure.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/azure.go @@ -64,7 +64,7 @@ func (m *azureProvider) GetProviderType() string { func (m *azureProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { if !m.config.isSupportedAPI(apiName) { - return m.config.handleUnsupportedAPI() + return errUnsupportedApiName } 
m.config.handleRequestHeaders(m, ctx, apiName, log) return nil @@ -72,7 +72,7 @@ func (m *azureProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNam func (m *azureProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { if !m.config.isSupportedAPI(apiName) { - return types.ActionContinue, m.config.handleUnsupportedAPI() + return types.ActionContinue, errUnsupportedApiName } return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log) } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/baichuan.go b/plugins/wasm-go/extensions/ai-proxy/provider/baichuan.go index d0b8e90586..d04c5c7d85 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/baichuan.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/baichuan.go @@ -51,7 +51,7 @@ func (m *baichuanProvider) GetProviderType() string { func (m *baichuanProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { if !m.config.isSupportedAPI(apiName) { - return m.config.handleUnsupportedAPI() + return errUnsupportedApiName } m.config.handleRequestHeaders(m, ctx, apiName, log) return nil @@ -59,7 +59,7 @@ func (m *baichuanProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName Api func (m *baichuanProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { if !m.config.isSupportedAPI(apiName) { - return types.ActionContinue, m.config.handleUnsupportedAPI() + return types.ActionContinue, errUnsupportedApiName } return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log) } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/baidu.go b/plugins/wasm-go/extensions/ai-proxy/provider/baidu.go index f23f185ed3..27bf5aaecc 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/baidu.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/baidu.go @@ -52,7 +52,7 @@ func (g *baiduProvider) GetProviderType() 
string { func (g *baiduProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { if !g.config.isSupportedAPI(apiName) { - return g.config.handleUnsupportedAPI() + return errUnsupportedApiName } g.config.handleRequestHeaders(g, ctx, apiName, log) return nil @@ -60,7 +60,7 @@ func (g *baiduProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNam func (g *baiduProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { if !g.config.isSupportedAPI(apiName) { - return types.ActionContinue, g.config.handleUnsupportedAPI() + return types.ActionContinue, errUnsupportedApiName } return g.config.handleRequestBody(g, g.contextCache, ctx, apiName, body, log) } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/claude.go b/plugins/wasm-go/extensions/ai-proxy/provider/claude.go index 2e1f95f072..8d75f5cae0 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/claude.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/claude.go @@ -112,7 +112,7 @@ func (c *claudeProvider) GetProviderType() string { func (c *claudeProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { if !c.config.isSupportedAPI(apiName) { - return c.config.handleUnsupportedAPI() + return errUnsupportedApiName } c.config.handleRequestHeaders(c, ctx, apiName, log) return nil @@ -133,7 +133,7 @@ func (c *claudeProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiNam func (c *claudeProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { if !c.config.isSupportedAPI(apiName) { - return types.ActionContinue, c.config.handleUnsupportedAPI() + return types.ActionContinue, errUnsupportedApiName } return c.config.handleRequestBody(c, c.contextCache, ctx, apiName, body, log) } @@ -169,6 +169,10 @@ func (c *claudeProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name A if isLastChunk || 
len(chunk) == 0 { return nil, nil } + // only process the response from chat completion, skip other responses + if name != ApiNameChatCompletion { + return chunk, nil + } responseBuilder := &strings.Builder{} lines := strings.Split(string(chunk), "\n") diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/cloudflare.go b/plugins/wasm-go/extensions/ai-proxy/provider/cloudflare.go index 59d9ce5492..22b9cd4286 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/cloudflare.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/cloudflare.go @@ -50,7 +50,7 @@ func (c *cloudflareProvider) GetProviderType() string { func (c *cloudflareProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { if !c.config.isSupportedAPI(apiName) { - return c.config.handleUnsupportedAPI() + return errUnsupportedApiName } c.config.handleRequestHeaders(c, ctx, apiName, log) return nil @@ -58,7 +58,7 @@ func (c *cloudflareProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName A func (c *cloudflareProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { if !c.config.isSupportedAPI(apiName) { - return types.ActionContinue, c.config.handleUnsupportedAPI() + return types.ActionContinue, errUnsupportedApiName } return c.config.handleRequestBody(c, c.contextCache, ctx, apiName, body, log) } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/cohere.go b/plugins/wasm-go/extensions/ai-proxy/provider/cohere.go index 55ab910084..a21964e497 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/cohere.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/cohere.go @@ -67,7 +67,7 @@ func (m *cohereProvider) GetProviderType() string { func (m *cohereProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { if !m.config.isSupportedAPI(apiName) { - return m.config.handleUnsupportedAPI() + return errUnsupportedApiName } m.config.handleRequestHeaders(m, 
ctx, apiName, log) return nil @@ -75,7 +75,7 @@ func (m *cohereProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNa func (m *cohereProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { if !m.config.isSupportedAPI(apiName) { - return types.ActionContinue, m.config.handleUnsupportedAPI() + return types.ActionContinue, errUnsupportedApiName } return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log) } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/deepl.go b/plugins/wasm-go/extensions/ai-proxy/provider/deepl.go index 02da91f977..812bd32557 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/deepl.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/deepl.go @@ -84,7 +84,7 @@ func (d *deeplProvider) GetProviderType() string { func (d *deeplProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { if !d.config.isSupportedAPI(apiName) { - return d.config.handleUnsupportedAPI() + return errUnsupportedApiName } d.config.handleRequestHeaders(d, ctx, apiName, log) return nil @@ -97,7 +97,7 @@ func (d *deeplProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName func (d *deeplProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { if !d.config.isSupportedAPI(apiName) { - return types.ActionContinue, d.config.handleUnsupportedAPI() + return types.ActionContinue, errUnsupportedApiName } return d.config.handleRequestBody(d, d.contextCache, ctx, apiName, body, log) } @@ -119,6 +119,9 @@ func (d *deeplProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, api } func (d *deeplProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) { + if apiName != ApiNameChatCompletion { + return body, nil + } deeplResponse := &deeplResponse{} if err := json.Unmarshal(body, deeplResponse); err != nil { 
return nil, fmt.Errorf("unable to unmarshal deepl response: %v", err) diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/deepseek.go b/plugins/wasm-go/extensions/ai-proxy/provider/deepseek.go index eb56b3ab0e..c8eca82a5c 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/deepseek.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/deepseek.go @@ -53,7 +53,7 @@ func (m *deepseekProvider) GetProviderType() string { func (m *deepseekProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { if !m.config.isSupportedAPI(apiName) { - return m.config.handleUnsupportedAPI() + return errUnsupportedApiName } m.config.handleRequestHeaders(m, ctx, apiName, log) return nil @@ -61,7 +61,7 @@ func (m *deepseekProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName Api func (m *deepseekProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { if !m.config.isSupportedAPI(apiName) { - return types.ActionContinue, m.config.handleUnsupportedAPI() + return types.ActionContinue, errUnsupportedApiName } return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log) } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/dify.go b/plugins/wasm-go/extensions/ai-proxy/provider/dify.go index c6427312ad..030d8ba2b4 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/dify.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/dify.go @@ -52,7 +52,7 @@ func (d *difyProvider) GetProviderType() string { func (d *difyProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { if apiName != ApiNameChatCompletion { - return d.config.handleUnsupportedAPI() + return errUnsupportedApiName } d.config.handleRequestHeaders(d, ctx, apiName, log) return nil @@ -78,7 +78,7 @@ func (d *difyProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName func (d *difyProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body 
[]byte, log wrapper.Log) (types.Action, error) { if apiName != ApiNameChatCompletion { - return types.ActionContinue, d.config.handleUnsupportedAPI() + return types.ActionContinue, errUnsupportedApiName } return d.config.handleRequestBody(d, d.contextCache, ctx, apiName, body, log) } @@ -99,6 +99,9 @@ func (d *difyProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiN } func (d *difyProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) { + if apiName != ApiNameChatCompletion { + return body, nil + } difyResponse := &DifyChatResponse{} if err := json.Unmarshal(body, difyResponse); err != nil { return nil, fmt.Errorf("unable to unmarshal dify response: %v", err) @@ -150,6 +153,9 @@ func (d *difyProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name Api if isLastChunk || len(chunk) == 0 { return nil, nil } + if name != ApiNameChatCompletion { + return chunk, nil + } // sample event response: // data: {"event": "agent_thought", "id": "8dcf3648-fbad-407a-85dd-73a6f43aeb9f", "task_id": "9cf1ddd7-f94b-459b-b942-b77b26c59e9b", "message_id": "1fb10045-55fd-4040-99e6-d048d07cbad3", "position": 1, "thought": "", "observation": "", "tool": "", "tool_input": "", "created_at": 1705639511, "message_files": [], "conversation_id": "c216c595-2d89-438c-b33c-aae5ddddd142"} diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/doubao.go b/plugins/wasm-go/extensions/ai-proxy/provider/doubao.go index 1f421b63a4..a896078e12 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/doubao.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/doubao.go @@ -51,7 +51,7 @@ func (m *doubaoProvider) GetProviderType() string { func (m *doubaoProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { if !m.config.isSupportedAPI(apiName) { - return m.config.handleUnsupportedAPI() + return errUnsupportedApiName } m.config.handleRequestHeaders(m, ctx, apiName, log) return nil 
@@ -59,7 +59,7 @@ func (m *doubaoProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNa func (m *doubaoProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { if !m.config.isSupportedAPI(apiName) { - return types.ActionContinue, m.config.handleUnsupportedAPI() + return types.ActionContinue, errUnsupportedApiName } return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log) } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/gemini.go b/plugins/wasm-go/extensions/ai-proxy/provider/gemini.go index d67a2a584d..1f8d877ea1 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/gemini.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/gemini.go @@ -58,7 +58,7 @@ func (g *geminiProvider) GetProviderType() string { func (g *geminiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { if !g.config.isSupportedAPI(apiName) { - return g.config.handleUnsupportedAPI() + return errUnsupportedApiName } g.config.handleRequestHeaders(g, ctx, apiName, log) // Delay the header processing to allow changing streaming mode in OnRequestBody @@ -72,7 +72,7 @@ func (g *geminiProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiNam func (g *geminiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { if !g.config.isSupportedAPI(apiName) { - return types.ActionContinue, g.config.handleUnsupportedAPI() + return types.ActionContinue, errUnsupportedApiName } return g.config.handleRequestBody(g, g.contextCache, ctx, apiName, body, log) } @@ -115,6 +115,9 @@ func (g *geminiProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name A if isLastChunk || len(chunk) == 0 { return nil, nil } + if name != ApiNameChatCompletion { + return chunk, nil + } // sample end event response: // data: {"candidates": [{"content": {"parts": [{"text": "我是 Gemini,一个大型多模态模型,由 Google 
训练。我的职责是尽我所能帮助您,并尽力提供全面且信息丰富的答复。"}],"role": "model"},"finishReason": "STOP","index": 0,"safetyRatings": [{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_HATE_SPEECH","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_HARASSMENT","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_DANGEROUS_CONTENT","probability": "NEGLIGIBLE"}]}],"usageMetadata": {"promptTokenCount": 2,"candidatesTokenCount": 35,"totalTokenCount": 37}} responseBuilder := &strings.Builder{} diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/github.go b/plugins/wasm-go/extensions/ai-proxy/provider/github.go index 71793b473a..e8a05cc1c9 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/github.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/github.go @@ -53,7 +53,7 @@ func (m *githubProvider) GetProviderType() string { func (m *githubProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { if !m.config.isSupportedAPI(apiName) { - return m.config.handleUnsupportedAPI() + return errUnsupportedApiName } m.config.handleRequestHeaders(m, ctx, apiName, log) // Delay the header processing to allow changing streaming mode in OnRequestBody @@ -62,7 +62,7 @@ func (m *githubProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNa func (m *githubProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { if !m.config.isSupportedAPI(apiName) { - return types.ActionContinue, m.config.handleUnsupportedAPI() + return types.ActionContinue, errUnsupportedApiName } return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log) } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/groq.go b/plugins/wasm-go/extensions/ai-proxy/provider/groq.go index b6004d3ce4..c415a707b8 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/groq.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/groq.go @@ -50,7 +50,7 @@ 
func (g *groqProvider) GetProviderType() string { func (g *groqProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { if !g.config.isSupportedAPI(apiName) { - return g.config.handleUnsupportedAPI() + return errUnsupportedApiName } g.config.handleRequestHeaders(g, ctx, apiName, log) return nil @@ -58,7 +58,7 @@ func (g *groqProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName func (g *groqProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { if !g.config.isSupportedAPI(apiName) { - return types.ActionContinue, g.config.handleUnsupportedAPI() + return types.ActionContinue, errUnsupportedApiName } return g.config.handleRequestBody(g, g.contextCache, ctx, apiName, body, log) } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/hunyuan.go b/plugins/wasm-go/extensions/ai-proxy/provider/hunyuan.go index c9849f333c..fc3fddca2f 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/hunyuan.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/hunyuan.go @@ -137,7 +137,7 @@ func (m *hunyuanProvider) useOpenAICompatibleAPI() bool { func (m *hunyuanProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { if !m.config.isSupportedAPI(apiName) { - return m.config.handleUnsupportedAPI() + return errUnsupportedApiName } m.config.handleRequestHeaders(m, ctx, apiName, log) // Delay the header processing to allow changing streaming mode in OnRequestBody @@ -161,7 +161,7 @@ func (m *hunyuanProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiNa // hunyuan 的 OnRequestBody 逻辑中包含了对 headers 签名的逻辑,并且插入 context 以后还要重新计算签名,因此无法复用 handleRequestBody 方法 func (m *hunyuanProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { if !m.config.isSupportedAPI(apiName) { - return types.ActionContinue, m.config.handleUnsupportedAPI() + return types.ActionContinue, 
errUnsupportedApiName } if m.useOpenAICompatibleAPI() { return types.ActionContinue, nil @@ -321,7 +321,7 @@ func (m *hunyuanProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, a } func (m *hunyuanProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error) { - if m.config.IsOriginal() || m.useOpenAICompatibleAPI() { + if m.config.IsOriginal() || m.useOpenAICompatibleAPI() || name != ApiNameChatCompletion { return chunk, nil } @@ -440,6 +440,9 @@ func (m *hunyuanProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName if m.config.IsOriginal() || m.useOpenAICompatibleAPI() { return body, nil } + if apiName != ApiNameChatCompletion { + return body, nil + } log.Debugf("#debug nash5# onRespBody's resp is: %s", string(body)) hunyuanResponse := &hunyuanTextGenResponseNonStreaming{} if err := json.Unmarshal(body, hunyuanResponse); err != nil { diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/minimax.go b/plugins/wasm-go/extensions/ai-proxy/provider/minimax.go index 19c011393a..0f5e0d3695 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/minimax.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/minimax.go @@ -75,7 +75,7 @@ func (m *minimaxProvider) GetProviderType() string { func (m *minimaxProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { if !m.config.isSupportedAPI(apiName) { - return m.config.handleUnsupportedAPI() + return errUnsupportedApiName } m.config.handleRequestHeaders(m, ctx, apiName, log) // Delay the header processing to allow changing streaming mode in OnRequestBody @@ -90,7 +90,7 @@ func (m *minimaxProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiNa func (m *minimaxProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { if !m.config.isSupportedAPI(apiName) { - return types.ActionContinue, m.config.handleUnsupportedAPI() + 
return types.ActionContinue, errUnsupportedApiName } if minimaxApiTypePro == m.config.minimaxApiType { // Use chat completion Pro API. @@ -167,6 +167,9 @@ func (m *minimaxProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name if isLastChunk || len(chunk) == 0 { return nil, nil } + if name != ApiNameChatCompletion { + return chunk, nil + } // Sample event response: // data: {"created":1689747645,"model":"abab6.5s-chat","reply":"","choices":[{"messages":[{"sender_type":"BOT","sender_name":"MM智能助理","text":"am from China."}]}],"output_sensitive":false} @@ -200,6 +203,9 @@ func (m *minimaxProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name // TransformResponseBody handles the final response body from the Minimax service only for requests using the OpenAI protocol and corresponding to the chat completion Pro API. func (m *minimaxProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) { + if apiName != ApiNameChatCompletion { + return body, nil + } minimaxResp := &minimaxChatCompletionProResp{} if err := json.Unmarshal(body, minimaxResp); err != nil { return nil, fmt.Errorf("unable to unmarshal minimax response: %v", err) diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/mistral.go b/plugins/wasm-go/extensions/ai-proxy/provider/mistral.go index 7cc992fb65..3f361a27ac 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/mistral.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/mistral.go @@ -49,7 +49,7 @@ func (m *mistralProvider) GetProviderType() string { func (m *mistralProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { if !m.config.isSupportedAPI(apiName) { - return m.config.handleUnsupportedAPI() + return errUnsupportedApiName } m.config.handleRequestHeaders(m, ctx, apiName, log) return nil @@ -57,7 +57,7 @@ func (m *mistralProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiN func (m *mistralProvider) 
OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { if !m.config.isSupportedAPI(apiName) { - return types.ActionContinue, m.config.handleUnsupportedAPI() + return types.ActionContinue, errUnsupportedApiName } return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log) } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/moonshot.go b/plugins/wasm-go/extensions/ai-proxy/provider/moonshot.go index 646f970098..776e250836 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/moonshot.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/moonshot.go @@ -65,7 +65,7 @@ func (m *moonshotProvider) GetProviderType() string { func (m *moonshotProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { if !m.config.isSupportedAPI(apiName) { - return m.config.handleUnsupportedAPI() + return errUnsupportedApiName } m.config.handleRequestHeaders(m, ctx, apiName, log) return nil @@ -82,7 +82,7 @@ func (m *moonshotProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiN // moonshot 的 body 没有修改,无须实现TransformRequestBody,使用默认的 defaultTransformRequestBody 方法 func (m *moonshotProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { if !m.config.isSupportedAPI(apiName) { - return types.ActionContinue, m.config.handleUnsupportedAPI() + return types.ActionContinue, errUnsupportedApiName } // 非chat类型的请求,不做处理 if apiName != ApiNameChatCompletion { @@ -165,6 +165,9 @@ func (m *moonshotProvider) sendRequest(method, path, body, apiKey string, callba } func (m *moonshotProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error) { + if name != ApiNameChatCompletion { + return chunk, nil + } receivedBody := chunk if bufferedStreamingBody, has := ctx.GetContext(ctxKeyStreamingBody).([]byte); has { receivedBody = 
append(bufferedStreamingBody, chunk...) diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/ollama.go b/plugins/wasm-go/extensions/ai-proxy/provider/ollama.go index e15fd89b88..57ad424ffa 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/ollama.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/ollama.go @@ -55,7 +55,7 @@ func (m *ollamaProvider) GetProviderType() string { func (m *ollamaProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { if !m.config.isSupportedAPI(apiName) { - return m.config.handleUnsupportedAPI() + return errUnsupportedApiName } m.config.handleRequestHeaders(m, ctx, apiName, log) return nil @@ -63,7 +63,7 @@ func (m *ollamaProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNa func (m *ollamaProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { if !m.config.isSupportedAPI(apiName) { - return types.ActionContinue, m.config.handleUnsupportedAPI() + return types.ActionContinue, errUnsupportedApiName } return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log) } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go index eb984cc28a..ae2763c0aa 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go @@ -267,7 +267,7 @@ type ProviderConfig struct { capabilities map[string]string // @Title zh-CN 是否开启透传 // @Description zh-CN 如果是插件不支持的API,是否透传请求, 默认为false - passthrsough bool + passthrough bool } func (c *ProviderConfig) GetId() string { @@ -453,11 +453,8 @@ func (c *ProviderConfig) ReplaceByCustomSettings(body []byte) ([]byte, error) { return ReplaceByCustomSettings(body, c.customSettings) } -func (c *ProviderConfig) handleUnsupportedAPI() error { - if c.passthrsough { - return nil - } - return errUnsupportedApiName +func (c *ProviderConfig) 
PassthroughUnsupportedAPI() bool { + return c.passthrough } func CreateProvider(pc ProviderConfig) (Provider, error) { diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/qwen.go b/plugins/wasm-go/extensions/ai-proxy/provider/qwen.go index 0d4807fcfd..e5650f355f 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/qwen.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/qwen.go @@ -105,7 +105,7 @@ func (m *qwenProvider) GetProviderType() string { func (m *qwenProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { if !m.config.isSupportedAPI(apiName) { - return m.config.handleUnsupportedAPI() + return errUnsupportedApiName } m.config.handleRequestHeaders(m, ctx, apiName, log) @@ -150,7 +150,7 @@ func (m *qwenProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, b } if !m.config.isSupportedAPI(apiName) { - return types.ActionContinue, m.config.handleUnsupportedAPI() + return types.ActionContinue, errUnsupportedApiName } return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log) } @@ -290,7 +290,7 @@ func (m *qwenProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName Ap if m.config.isSupportedAPI(apiName) { return body, nil } - return nil, m.config.handleUnsupportedAPI() + return nil, errUnsupportedApiName } func (m *qwenProvider) onChatCompletionResponseBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) ([]byte, error) { diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/spark.go b/plugins/wasm-go/extensions/ai-proxy/provider/spark.go index 5ad514eb19..bac72e0239 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/spark.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/spark.go @@ -75,7 +75,7 @@ func (p *sparkProvider) GetProviderType() string { func (p *sparkProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { if !p.config.isSupportedAPI(apiName) { - return p.config.handleUnsupportedAPI() + return 
errUnsupportedApiName } p.config.handleRequestHeaders(p, ctx, apiName, log) return nil @@ -83,12 +83,15 @@ func (p *sparkProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNam func (p *sparkProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { if !p.config.isSupportedAPI(apiName) { - return types.ActionContinue, p.config.handleUnsupportedAPI() + return types.ActionContinue, errUnsupportedApiName } return p.config.handleRequestBody(p, p.contextCache, ctx, apiName, body, log) } func (p *sparkProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) { + if apiName != ApiNameChatCompletion { + return body, nil + } sparkResponse := &sparkResponse{} if err := json.Unmarshal(body, sparkResponse); err != nil { return nil, fmt.Errorf("unable to unmarshal spark response: %v", err) @@ -104,6 +107,9 @@ func (p *sparkProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name Ap if isLastChunk || len(chunk) == 0 { return nil, nil } + if name != ApiNameChatCompletion { + return chunk, nil + } responseBuilder := &strings.Builder{} lines := strings.Split(string(chunk), "\n") for _, data := range lines { diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/stepfun.go b/plugins/wasm-go/extensions/ai-proxy/provider/stepfun.go index 71e208710a..71315621a4 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/stepfun.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/stepfun.go @@ -50,7 +50,7 @@ func (m *stepfunProvider) GetProviderType() string { func (m *stepfunProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { if !m.config.isSupportedAPI(apiName) { - return m.config.handleUnsupportedAPI() + return errUnsupportedApiName } m.config.handleRequestHeaders(m, ctx, apiName, log) return nil @@ -58,7 +58,7 @@ func (m *stepfunProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiN func (m 
*stepfunProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { if !m.config.isSupportedAPI(apiName) { - return types.ActionContinue, m.config.handleUnsupportedAPI() + return types.ActionContinue, errUnsupportedApiName } return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log) } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/together_ai.go b/plugins/wasm-go/extensions/ai-proxy/provider/together_ai.go index 33a4203df9..dfbeb401ca 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/together_ai.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/together_ai.go @@ -49,7 +49,7 @@ func (m *togetherAIProvider) GetProviderType() string { func (m *togetherAIProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { if !m.config.isSupportedAPI(apiName) { - return m.config.handleUnsupportedAPI() + return errUnsupportedApiName } m.config.handleRequestHeaders(m, ctx, apiName, log) return nil @@ -57,7 +57,7 @@ func (m *togetherAIProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName A func (m *togetherAIProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { if !m.config.isSupportedAPI(apiName) { - return types.ActionContinue, m.config.handleUnsupportedAPI() + return types.ActionContinue, errUnsupportedApiName } return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log) } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/yi.go b/plugins/wasm-go/extensions/ai-proxy/provider/yi.go index 80b0fbd743..3c3db4d5de 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/yi.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/yi.go @@ -49,7 +49,7 @@ func (m *yiProvider) GetProviderType() string { func (m *yiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { if !m.config.isSupportedAPI(apiName) { - return 
m.config.handleUnsupportedAPI() + return errUnsupportedApiName } m.config.handleRequestHeaders(m, ctx, apiName, log) return nil @@ -57,7 +57,7 @@ func (m *yiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, func (m *yiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { if !m.config.isSupportedAPI(apiName) { - return types.ActionContinue, m.config.handleUnsupportedAPI() + return types.ActionContinue, errUnsupportedApiName } return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log) } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/zhipuai.go b/plugins/wasm-go/extensions/ai-proxy/provider/zhipuai.go index 12936667f7..e95e99fc8b 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/zhipuai.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/zhipuai.go @@ -51,7 +51,7 @@ func (m *zhipuAiProvider) GetProviderType() string { func (m *zhipuAiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { if !m.config.isSupportedAPI(apiName) { - return m.config.handleUnsupportedAPI() + return errUnsupportedApiName } m.config.handleRequestHeaders(m, ctx, apiName, log) return nil @@ -59,7 +59,7 @@ func (m *zhipuAiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiN func (m *zhipuAiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { if !m.config.isSupportedAPI(apiName) { - return types.ActionContinue, m.config.handleUnsupportedAPI() + return types.ActionContinue, errUnsupportedApiName } return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log) }