From 506dcafdab4c6e985121110cbeea14c04228f79d Mon Sep 17 00:00:00 2001 From: "yu.deng" Date: Fri, 24 Jan 2025 10:26:32 +0800 Subject: [PATCH] fix: correcting some errors in 'rewriting logic' --- plugins/wasm-go/extensions/ai-proxy/main.go | 33 ++++++++++++------- .../extensions/ai-proxy/provider/ai360.go | 4 +-- .../extensions/ai-proxy/provider/azure.go | 4 +-- .../extensions/ai-proxy/provider/baichuan.go | 4 +-- .../extensions/ai-proxy/provider/baidu.go | 4 +-- .../extensions/ai-proxy/provider/claude.go | 8 +++-- .../ai-proxy/provider/cloudflare.go | 4 +-- .../extensions/ai-proxy/provider/cohere.go | 4 +-- .../extensions/ai-proxy/provider/deepl.go | 7 ++-- .../extensions/ai-proxy/provider/deepseek.go | 4 +-- .../extensions/ai-proxy/provider/dify.go | 10 ++++-- .../extensions/ai-proxy/provider/doubao.go | 4 +-- .../extensions/ai-proxy/provider/gemini.go | 7 ++-- .../extensions/ai-proxy/provider/github.go | 4 +-- .../extensions/ai-proxy/provider/groq.go | 4 +-- .../extensions/ai-proxy/provider/hunyuan.go | 9 +++-- .../extensions/ai-proxy/provider/minimax.go | 10 ++++-- .../extensions/ai-proxy/provider/mistral.go | 4 +-- .../extensions/ai-proxy/provider/moonshot.go | 7 ++-- .../extensions/ai-proxy/provider/ollama.go | 4 +-- .../extensions/ai-proxy/provider/provider.go | 9 ++--- .../extensions/ai-proxy/provider/qwen.go | 6 ++-- .../extensions/ai-proxy/provider/spark.go | 10 ++++-- .../extensions/ai-proxy/provider/stepfun.go | 4 +-- .../ai-proxy/provider/together_ai.go | 4 +-- .../extensions/ai-proxy/provider/yi.go | 4 +-- .../extensions/ai-proxy/provider/zhipuai.go | 4 +-- 27 files changed, 110 insertions(+), 70 deletions(-) diff --git a/plugins/wasm-go/extensions/ai-proxy/main.go b/plugins/wasm-go/extensions/ai-proxy/main.go index 6244a0e60e..c35f8d73c4 100644 --- a/plugins/wasm-go/extensions/ai-proxy/main.go +++ b/plugins/wasm-go/extensions/ai-proxy/main.go @@ -78,7 +78,7 @@ func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConf rawPath := ctx.Path() path, _ := url.Parse(rawPath) - apiName := getOpenAiApiName(path.Path) + apiName := getApiName(path.Path) providerConfig := pluginConfig.GetProviderConfig() if providerConfig.IsOriginal() { if handler, ok := activeProvider.(provider.ApiNameHandler); ok { @@ -103,20 +103,24 @@ func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConf // Set the apiToken for the current request. providerConfig.SetApiTokenInUse(ctx, log) - hasRequestBody := wrapper.HasRequestBody() err := handler.OnRequestHeaders(ctx, apiName, log) - if err == nil { - if hasRequestBody { - proxywasm.RemoveHttpRequestHeader("Content-Length") - ctx.SetRequestBodyBufferLimit(defaultMaxBodyBytes) - // Delay the header processing to allow changing in OnRequestBody - return types.HeaderStopIteration + if err != nil { + if providerConfig.PassthroughUnsupportedAPI() { + log.Warnf("[onHttpRequestHeader] passthrough unsupported API: %v", err) + return types.ActionContinue } - ctx.DontReadRequestBody() + util.ErrorHandler("ai-proxy.proc_req_headers_failed", fmt.Errorf("failed to process request headers: %v", err)) return types.ActionContinue } - util.ErrorHandler("ai-proxy.proc_req_headers_failed", fmt.Errorf("failed to process request headers: %v", err)) + hasRequestBody := wrapper.HasRequestBody() + if hasRequestBody { + proxywasm.RemoveHttpRequestHeader("Content-Length") + ctx.SetRequestBodyBufferLimit(defaultMaxBodyBytes) + // Delay the header processing to allow changing in OnRequestBody + return types.HeaderStopIteration + } + ctx.DontReadRequestBody() return types.ActionContinue } @@ -151,6 +155,10 @@ func onHttpRequestBody(ctx wrapper.HttpContext, pluginConfig config.PluginConfig if err == nil { return action } + if pluginConfig.GetProviderConfig().PassthroughUnsupportedAPI() { + log.Warnf("[onHttpRequestBody] passthrough unsupported API: %v", err) + return types.ActionContinue + } util.ErrorHandler("ai-proxy.proc_req_body_failed", fmt.Errorf("failed to process request body: %v", err)) } return types.ActionContinue @@ -267,7 +275,8 @@ func checkStream(ctx wrapper.HttpContext, log wrapper.Log) { } } -func getOpenAiApiName(path string) provider.ApiName { +func getApiName(path string) provider.ApiName { + // openai style if strings.HasSuffix(path, "/v1/chat/completions") { return provider.ApiNameChatCompletion } @@ -280,7 +289,7 @@ func getOpenAiApiName(path string) provider.ApiName { if strings.HasSuffix(path, "/v1/images/generations") { return provider.ApiNameImageGeneration } - // rerank + // cohere style if strings.HasSuffix(path, "/v1/rerank") { return provider.ApiNameCohereV1Rerank } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/ai360.go b/plugins/wasm-go/extensions/ai-proxy/provider/ai360.go index 52751a0600..57b092cd82 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/ai360.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/ai360.go @@ -50,7 +50,7 @@ func (m *ai360Provider) GetProviderType() string { func (m *ai360Provider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { if !m.config.isSupportedAPI(apiName) { - return m.config.handleUnsupportedAPI() + return errUnsupportedApiName } m.config.handleRequestHeaders(m, ctx, apiName, log) // Delay the header processing to allow changing streaming mode in OnRequestBody @@ -59,7 +59,7 @@ func (m *ai360Provider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNam func (m *ai360Provider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { if !m.config.isSupportedAPI(apiName) { - return types.ActionContinue, m.config.handleUnsupportedAPI() + return types.ActionContinue, errUnsupportedApiName } return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log) } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/azure.go b/plugins/wasm-go/extensions/ai-proxy/provider/azure.go index b663a59dff..5fcc378d47 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/azure.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/azure.go @@ -64,7 +64,7 @@ func (m *azureProvider) GetProviderType() string { func (m *azureProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { if !m.config.isSupportedAPI(apiName) { - return m.config.handleUnsupportedAPI() + return errUnsupportedApiName } m.config.handleRequestHeaders(m, ctx, apiName, log) return nil @@ -72,7 +72,7 @@ func (m *azureProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNam func (m *azureProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { if !m.config.isSupportedAPI(apiName) { - return types.ActionContinue, m.config.handleUnsupportedAPI() + return types.ActionContinue, errUnsupportedApiName } return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log) } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/baichuan.go b/plugins/wasm-go/extensions/ai-proxy/provider/baichuan.go index d0b8e90586..d04c5c7d85 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/baichuan.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/baichuan.go @@ -51,7 +51,7 @@ func (m *baichuanProvider) GetProviderType() string { func (m *baichuanProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { if !m.config.isSupportedAPI(apiName) { - return m.config.handleUnsupportedAPI() + return errUnsupportedApiName } m.config.handleRequestHeaders(m, ctx, apiName, log) return nil @@ -59,7 +59,7 @@ func (m *baichuanProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName Api func (m *baichuanProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { if !m.config.isSupportedAPI(apiName) { - return types.ActionContinue, m.config.handleUnsupportedAPI() + return types.ActionContinue, errUnsupportedApiName } return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log) } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/baidu.go b/plugins/wasm-go/extensions/ai-proxy/provider/baidu.go index f23f185ed3..27bf5aaecc 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/baidu.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/baidu.go @@ -52,7 +52,7 @@ func (g *baiduProvider) GetProviderType() string { func (g *baiduProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { if !g.config.isSupportedAPI(apiName) { - return g.config.handleUnsupportedAPI() + return errUnsupportedApiName } g.config.handleRequestHeaders(g, ctx, apiName, log) return nil @@ -60,7 +60,7 @@ func (g *baiduProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNam func (g *baiduProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { if !g.config.isSupportedAPI(apiName) { - return types.ActionContinue, g.config.handleUnsupportedAPI() + return types.ActionContinue, errUnsupportedApiName } return g.config.handleRequestBody(g, g.contextCache, ctx, apiName, body, log) } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/claude.go b/plugins/wasm-go/extensions/ai-proxy/provider/claude.go index 2e1f95f072..8d75f5cae0 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/claude.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/claude.go @@ -112,7 +112,7 @@ func (c *claudeProvider) GetProviderType() string { func (c *claudeProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { if !c.config.isSupportedAPI(apiName) { - return c.config.handleUnsupportedAPI() + return errUnsupportedApiName } c.config.handleRequestHeaders(c, ctx, apiName, log) return nil @@ -133,7 +133,7 @@ func (c *claudeProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiNam func (c *claudeProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { if !c.config.isSupportedAPI(apiName) { - return types.ActionContinue, c.config.handleUnsupportedAPI() + return types.ActionContinue, errUnsupportedApiName } return c.config.handleRequestBody(c, c.contextCache, ctx, apiName, body, log) } @@ -169,6 +169,10 @@ func (c *claudeProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name A if isLastChunk || len(chunk) == 0 { return nil, nil } + // only process the response from chat completion, skip other responses + if name != ApiNameChatCompletion { + return chunk, nil + } responseBuilder := &strings.Builder{} lines := strings.Split(string(chunk), "\n") diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/cloudflare.go b/plugins/wasm-go/extensions/ai-proxy/provider/cloudflare.go index 59d9ce5492..22b9cd4286 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/cloudflare.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/cloudflare.go @@ -50,7 +50,7 @@ func (c *cloudflareProvider) GetProviderType() string { func (c *cloudflareProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { if !c.config.isSupportedAPI(apiName) { - return c.config.handleUnsupportedAPI() + return errUnsupportedApiName } c.config.handleRequestHeaders(c, ctx, apiName, log) return nil @@ -58,7 +58,7 @@ func (c *cloudflareProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName A func (c *cloudflareProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { if !c.config.isSupportedAPI(apiName) { - return types.ActionContinue, c.config.handleUnsupportedAPI() + return types.ActionContinue, errUnsupportedApiName } return c.config.handleRequestBody(c, c.contextCache, ctx, apiName, body, log) } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/cohere.go b/plugins/wasm-go/extensions/ai-proxy/provider/cohere.go index 55ab910084..a21964e497 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/cohere.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/cohere.go @@ -67,7 +67,7 @@ func (m *cohereProvider) GetProviderType() string { func (m *cohereProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { if !m.config.isSupportedAPI(apiName) { - return m.config.handleUnsupportedAPI() + return errUnsupportedApiName } m.config.handleRequestHeaders(m, ctx, apiName, log) return nil @@ -75,7 +75,7 @@ func (m *cohereProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNa func (m *cohereProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { if !m.config.isSupportedAPI(apiName) { - return types.ActionContinue, m.config.handleUnsupportedAPI() + return types.ActionContinue, errUnsupportedApiName } return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log) } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/deepl.go b/plugins/wasm-go/extensions/ai-proxy/provider/deepl.go index 02da91f977..812bd32557 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/deepl.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/deepl.go @@ -84,7 +84,7 @@ func (d *deeplProvider) GetProviderType() string { func (d *deeplProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { if !d.config.isSupportedAPI(apiName) { - return d.config.handleUnsupportedAPI() + return errUnsupportedApiName } d.config.handleRequestHeaders(d, ctx, apiName, log) return nil @@ -97,7 +97,7 @@ func (d *deeplProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName func (d *deeplProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { if !d.config.isSupportedAPI(apiName) { - return types.ActionContinue, d.config.handleUnsupportedAPI() + return types.ActionContinue, errUnsupportedApiName } return d.config.handleRequestBody(d, d.contextCache, ctx, apiName, body, log) } @@ -119,6 +119,9 @@ func (d *deeplProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, api } func (d *deeplProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) { + if apiName != ApiNameChatCompletion { + return body, nil + } deeplResponse := &deeplResponse{} if err := json.Unmarshal(body, deeplResponse); err != nil { return nil, fmt.Errorf("unable to unmarshal deepl response: %v", err) diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/deepseek.go b/plugins/wasm-go/extensions/ai-proxy/provider/deepseek.go index eb56b3ab0e..c8eca82a5c 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/deepseek.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/deepseek.go @@ -53,7 +53,7 @@ func (m *deepseekProvider) GetProviderType() string { func (m *deepseekProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { if !m.config.isSupportedAPI(apiName) { - return m.config.handleUnsupportedAPI() + return errUnsupportedApiName } m.config.handleRequestHeaders(m, ctx, apiName, log) return nil @@ -61,7 +61,7 @@ func (m *deepseekProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName Api func (m *deepseekProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { if !m.config.isSupportedAPI(apiName) { - return types.ActionContinue, m.config.handleUnsupportedAPI() + return types.ActionContinue, errUnsupportedApiName } return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log) } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/dify.go b/plugins/wasm-go/extensions/ai-proxy/provider/dify.go index c6427312ad..030d8ba2b4 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/dify.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/dify.go @@ -52,7 +52,7 @@ func (d *difyProvider) GetProviderType() string { func (d *difyProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { if apiName != ApiNameChatCompletion { - return d.config.handleUnsupportedAPI() + return errUnsupportedApiName } d.config.handleRequestHeaders(d, ctx, apiName, log) return nil @@ -78,7 +78,7 @@ func (d *difyProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName func (d *difyProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { if apiName != ApiNameChatCompletion { - return types.ActionContinue, d.config.handleUnsupportedAPI() + return types.ActionContinue, errUnsupportedApiName } return d.config.handleRequestBody(d, d.contextCache, ctx, apiName, body, log) } @@ -99,6 +99,9 @@ func (d *difyProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiN } func (d *difyProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) { + if apiName != ApiNameChatCompletion { + return body, nil + } difyResponse := &DifyChatResponse{} if err := json.Unmarshal(body, difyResponse); err != nil { return nil, fmt.Errorf("unable to unmarshal dify response: %v", err) @@ -150,6 +153,9 @@ func (d *difyProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name Api if isLastChunk || len(chunk) == 0 { return nil, nil } + if name != ApiNameChatCompletion { + return chunk, nil + } // sample event response: // data: {"event": "agent_thought", "id": "8dcf3648-fbad-407a-85dd-73a6f43aeb9f", "task_id": "9cf1ddd7-f94b-459b-b942-b77b26c59e9b", "message_id": "1fb10045-55fd-4040-99e6-d048d07cbad3", "position": 1, "thought": "", "observation": "", "tool": "", "tool_input": "", "created_at": 1705639511, "message_files": [], "conversation_id": "c216c595-2d89-438c-b33c-aae5ddddd142"} diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/doubao.go b/plugins/wasm-go/extensions/ai-proxy/provider/doubao.go index 1f421b63a4..a896078e12 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/doubao.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/doubao.go @@ -51,7 +51,7 @@ func (m *doubaoProvider) GetProviderType() string { func (m *doubaoProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { if !m.config.isSupportedAPI(apiName) { - return m.config.handleUnsupportedAPI() + return errUnsupportedApiName } m.config.handleRequestHeaders(m, ctx, apiName, log) return nil @@ -59,7 +59,7 @@ func (m *doubaoProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNa func (m *doubaoProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { if !m.config.isSupportedAPI(apiName) { - return types.ActionContinue, m.config.handleUnsupportedAPI() + return types.ActionContinue, errUnsupportedApiName } return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log) } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/gemini.go b/plugins/wasm-go/extensions/ai-proxy/provider/gemini.go index d67a2a584d..1f8d877ea1 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/gemini.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/gemini.go @@ -58,7 +58,7 @@ func (g *geminiProvider) GetProviderType() string { func (g *geminiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { if !g.config.isSupportedAPI(apiName) { - return g.config.handleUnsupportedAPI() + return errUnsupportedApiName } g.config.handleRequestHeaders(g, ctx, apiName, log) // Delay the header processing to allow changing streaming mode in OnRequestBody @@ -72,7 +72,7 @@ func (g *geminiProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiNam func (g *geminiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { if !g.config.isSupportedAPI(apiName) { - return types.ActionContinue, g.config.handleUnsupportedAPI() + return types.ActionContinue, errUnsupportedApiName } return g.config.handleRequestBody(g, g.contextCache, ctx, apiName, body, log) } @@ -115,6 +115,9 @@ func (g *geminiProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name A if isLastChunk || len(chunk) == 0 { return nil, nil } + if name != ApiNameChatCompletion { + return chunk, nil + } // sample end event response: // data: {"candidates": [{"content": {"parts": [{"text": "我是 Gemini,一个大型多模态模型,由 Google 训练。我的职责是尽我所能帮助您,并尽力提供全面且信息丰富的答复。"}],"role": "model"},"finishReason": "STOP","index": 0,"safetyRatings": [{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_HATE_SPEECH","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_HARASSMENT","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_DANGEROUS_CONTENT","probability": "NEGLIGIBLE"}]}],"usageMetadata": {"promptTokenCount": 2,"candidatesTokenCount": 35,"totalTokenCount": 37}} responseBuilder := &strings.Builder{} diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/github.go b/plugins/wasm-go/extensions/ai-proxy/provider/github.go index 71793b473a..e8a05cc1c9 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/github.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/github.go @@ -53,7 +53,7 @@ func (m *githubProvider) GetProviderType() string { func (m *githubProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { if !m.config.isSupportedAPI(apiName) { - return m.config.handleUnsupportedAPI() + return errUnsupportedApiName } m.config.handleRequestHeaders(m, ctx, apiName, log) // Delay the header processing to allow changing streaming mode in OnRequestBody @@ -62,7 +62,7 @@ func (m *githubProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNa func (m *githubProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { if !m.config.isSupportedAPI(apiName) { - return types.ActionContinue, m.config.handleUnsupportedAPI() + return types.ActionContinue, errUnsupportedApiName } return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log) } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/groq.go b/plugins/wasm-go/extensions/ai-proxy/provider/groq.go index b6004d3ce4..c415a707b8 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/groq.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/groq.go @@ -50,7 +50,7 @@ func (g *groqProvider) GetProviderType() string { func (g *groqProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { if !g.config.isSupportedAPI(apiName) { - return g.config.handleUnsupportedAPI() + return errUnsupportedApiName } g.config.handleRequestHeaders(g, ctx, apiName, log) return nil @@ -58,7 +58,7 @@ func (g *groqProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName func (g *groqProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { if !g.config.isSupportedAPI(apiName) { - return types.ActionContinue, g.config.handleUnsupportedAPI() + return types.ActionContinue, errUnsupportedApiName } return g.config.handleRequestBody(g, g.contextCache, ctx, apiName, body, log) } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/hunyuan.go b/plugins/wasm-go/extensions/ai-proxy/provider/hunyuan.go index c9849f333c..fc3fddca2f 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/hunyuan.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/hunyuan.go @@ -137,7 +137,7 @@ func (m *hunyuanProvider) useOpenAICompatibleAPI() bool { func (m *hunyuanProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { if !m.config.isSupportedAPI(apiName) { - return m.config.handleUnsupportedAPI() + return errUnsupportedApiName } m.config.handleRequestHeaders(m, ctx, apiName, log) // Delay the header processing to allow changing streaming mode in OnRequestBody @@ -161,7 +161,7 @@ func (m *hunyuanProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiNa // hunyuan 的 OnRequestBody 逻辑中包含了对 headers 签名的逻辑,并且插入 context 以后还要重新计算签名,因此无法复用 handleRequestBody 方法 func (m *hunyuanProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { if !m.config.isSupportedAPI(apiName) { - return types.ActionContinue, m.config.handleUnsupportedAPI() + return types.ActionContinue, errUnsupportedApiName } if m.useOpenAICompatibleAPI() { return types.ActionContinue, nil @@ -321,7 +321,7 @@ func (m *hunyuanProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, a } func (m *hunyuanProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error) { - if m.config.IsOriginal() || m.useOpenAICompatibleAPI() { + if m.config.IsOriginal() || m.useOpenAICompatibleAPI() || name != ApiNameChatCompletion { return chunk, nil } @@ -440,6 +440,9 @@ func (m *hunyuanProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName if m.config.IsOriginal() || m.useOpenAICompatibleAPI() { return body, nil } + if apiName != ApiNameChatCompletion { + return body, nil + } log.Debugf("#debug nash5# onRespBody's resp is: %s", string(body)) hunyuanResponse := &hunyuanTextGenResponseNonStreaming{} if err := json.Unmarshal(body, hunyuanResponse); err != nil { diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/minimax.go b/plugins/wasm-go/extensions/ai-proxy/provider/minimax.go index 19c011393a..0f5e0d3695 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/minimax.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/minimax.go @@ -75,7 +75,7 @@ func (m *minimaxProvider) GetProviderType() string { func (m *minimaxProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { if !m.config.isSupportedAPI(apiName) { - return m.config.handleUnsupportedAPI() + return errUnsupportedApiName } m.config.handleRequestHeaders(m, ctx, apiName, log) // Delay the header processing to allow changing streaming mode in OnRequestBody @@ -90,7 +90,7 @@ func (m *minimaxProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiNa func (m *minimaxProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { if !m.config.isSupportedAPI(apiName) { - return types.ActionContinue, m.config.handleUnsupportedAPI() + return types.ActionContinue, errUnsupportedApiName } if minimaxApiTypePro == m.config.minimaxApiType { // Use chat completion Pro API. @@ -167,6 +167,9 @@ func (m *minimaxProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name if isLastChunk || len(chunk) == 0 { return nil, nil } + if name != ApiNameChatCompletion { + return chunk, nil + } // Sample event response: // data: {"created":1689747645,"model":"abab6.5s-chat","reply":"","choices":[{"messages":[{"sender_type":"BOT","sender_name":"MM智能助理","text":"am from China."}]}],"output_sensitive":false} @@ -200,6 +203,9 @@ func (m *minimaxProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name // TransformResponseBody handles the final response body from the Minimax service only for requests using the OpenAI protocol and corresponding to the chat completion Pro API. func (m *minimaxProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) { + if apiName != ApiNameChatCompletion { + return body, nil + } minimaxResp := &minimaxChatCompletionProResp{} if err := json.Unmarshal(body, minimaxResp); err != nil { return nil, fmt.Errorf("unable to unmarshal minimax response: %v", err) diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/mistral.go b/plugins/wasm-go/extensions/ai-proxy/provider/mistral.go index 7cc992fb65..3f361a27ac 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/mistral.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/mistral.go @@ -49,7 +49,7 @@ func (m *mistralProvider) GetProviderType() string { func (m *mistralProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { if !m.config.isSupportedAPI(apiName) { - return m.config.handleUnsupportedAPI() + return errUnsupportedApiName } m.config.handleRequestHeaders(m, ctx, apiName, log) return nil @@ -57,7 +57,7 @@ func (m *mistralProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiN func (m *mistralProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { if !m.config.isSupportedAPI(apiName) { - return types.ActionContinue, m.config.handleUnsupportedAPI() + return types.ActionContinue, errUnsupportedApiName } return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log) } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/moonshot.go b/plugins/wasm-go/extensions/ai-proxy/provider/moonshot.go index 646f970098..776e250836 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/moonshot.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/moonshot.go @@ -65,7 +65,7 @@ func (m *moonshotProvider) GetProviderType() string { func (m *moonshotProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { if !m.config.isSupportedAPI(apiName) { - return m.config.handleUnsupportedAPI() + return errUnsupportedApiName } m.config.handleRequestHeaders(m, ctx, apiName, log) return nil @@ -82,7 +82,7 @@ func (m *moonshotProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiN // moonshot 的 body 没有修改,无须实现TransformRequestBody,使用默认的 defaultTransformRequestBody 方法 func (m *moonshotProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { if !m.config.isSupportedAPI(apiName) { - return types.ActionContinue, m.config.handleUnsupportedAPI() + return types.ActionContinue, errUnsupportedApiName } // 非chat类型的请求,不做处理 if apiName != ApiNameChatCompletion { @@ -165,6 +165,9 @@ func (m *moonshotProvider) sendRequest(method, path, body, apiKey string, callba } func (m *moonshotProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error) { + if name != ApiNameChatCompletion { + return chunk, nil + } receivedBody := chunk if bufferedStreamingBody, has := ctx.GetContext(ctxKeyStreamingBody).([]byte); has { receivedBody = append(bufferedStreamingBody, chunk...) diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/ollama.go b/plugins/wasm-go/extensions/ai-proxy/provider/ollama.go index e15fd89b88..57ad424ffa 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/ollama.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/ollama.go @@ -55,7 +55,7 @@ func (m *ollamaProvider) GetProviderType() string { func (m *ollamaProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { if !m.config.isSupportedAPI(apiName) { - return m.config.handleUnsupportedAPI() + return errUnsupportedApiName } m.config.handleRequestHeaders(m, ctx, apiName, log) return nil @@ -63,7 +63,7 @@ func (m *ollamaProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNa func (m *ollamaProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { if !m.config.isSupportedAPI(apiName) { - return types.ActionContinue, m.config.handleUnsupportedAPI() + return types.ActionContinue, errUnsupportedApiName } return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log) } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go index eb984cc28a..ae2763c0aa 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go @@ -267,7 +267,7 @@ type ProviderConfig struct { capabilities map[string]string // @Title zh-CN 是否开启透传 // @Description zh-CN 如果是插件不支持的API,是否透传请求, 默认为false - passthrsough bool + passthrough bool } func (c *ProviderConfig) GetId() string { @@ -453,11 +453,8 @@ func (c *ProviderConfig) ReplaceByCustomSettings(body []byte) ([]byte, error) { return ReplaceByCustomSettings(body, c.customSettings) } -func (c *ProviderConfig) handleUnsupportedAPI() error { - if c.passthrsough { - return nil - } - return errUnsupportedApiName +func (c *ProviderConfig) PassthroughUnsupportedAPI() bool { + return c.passthrough } func CreateProvider(pc ProviderConfig) (Provider, error) { diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/qwen.go b/plugins/wasm-go/extensions/ai-proxy/provider/qwen.go index 0d4807fcfd..e5650f355f 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/qwen.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/qwen.go @@ -105,7 +105,7 @@ func (m *qwenProvider) GetProviderType() string { func (m *qwenProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { if !m.config.isSupportedAPI(apiName) { - return m.config.handleUnsupportedAPI() + return errUnsupportedApiName } m.config.handleRequestHeaders(m, ctx, apiName, log) @@ -150,7 +150,7 @@ func (m *qwenProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, b } if !m.config.isSupportedAPI(apiName) { - return types.ActionContinue, m.config.handleUnsupportedAPI() + return types.ActionContinue, errUnsupportedApiName } return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log) } @@ -290,7 +290,7 @@ func (m *qwenProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName Ap if m.config.isSupportedAPI(apiName) { return body, nil } - return nil, m.config.handleUnsupportedAPI() + return nil, errUnsupportedApiName } func (m *qwenProvider) onChatCompletionResponseBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) ([]byte, error) { diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/spark.go b/plugins/wasm-go/extensions/ai-proxy/provider/spark.go index 5ad514eb19..bac72e0239 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/spark.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/spark.go @@ -75,7 +75,7 @@ func (p *sparkProvider) GetProviderType() string { func (p *sparkProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { if !p.config.isSupportedAPI(apiName) { - return p.config.handleUnsupportedAPI() + return errUnsupportedApiName } p.config.handleRequestHeaders(p, ctx, apiName, log) return nil @@ -83,12 +83,15 @@ func (p *sparkProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNam func (p *sparkProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { if !p.config.isSupportedAPI(apiName) { - return types.ActionContinue, p.config.handleUnsupportedAPI() + return types.ActionContinue, errUnsupportedApiName } return p.config.handleRequestBody(p, p.contextCache, ctx, apiName, body, log) } func (p *sparkProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) { + if apiName != ApiNameChatCompletion { + return body, nil + } sparkResponse := &sparkResponse{} if err := json.Unmarshal(body, sparkResponse); err != nil { return nil, fmt.Errorf("unable to unmarshal spark response: %v", err) @@ -104,6 +107,9 @@ func (p *sparkProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name Ap if isLastChunk || len(chunk) == 0 { return nil, nil } + if name != ApiNameChatCompletion { + return chunk, nil + } responseBuilder := &strings.Builder{} lines := strings.Split(string(chunk), "\n") for _, data := range lines { diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/stepfun.go b/plugins/wasm-go/extensions/ai-proxy/provider/stepfun.go index 71e208710a..71315621a4 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/stepfun.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/stepfun.go @@ -50,7 +50,7 @@ func (m *stepfunProvider) GetProviderType() string { func (m *stepfunProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { if !m.config.isSupportedAPI(apiName) { - return m.config.handleUnsupportedAPI() + return errUnsupportedApiName } m.config.handleRequestHeaders(m, ctx, apiName, log) return nil @@ -58,7 +58,7 @@ func (m *stepfunProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiN func (m *stepfunProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { if !m.config.isSupportedAPI(apiName) { - return types.ActionContinue, m.config.handleUnsupportedAPI() + return types.ActionContinue, errUnsupportedApiName } return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log) } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/together_ai.go b/plugins/wasm-go/extensions/ai-proxy/provider/together_ai.go index 33a4203df9..dfbeb401ca 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/together_ai.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/together_ai.go @@ -49,7 +49,7 @@ func (m *togetherAIProvider) GetProviderType() string { func (m *togetherAIProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { if !m.config.isSupportedAPI(apiName) { - return m.config.handleUnsupportedAPI() + return errUnsupportedApiName } m.config.handleRequestHeaders(m, ctx, apiName, log) return nil @@ -57,7 +57,7 @@ func (m *togetherAIProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName A func (m *togetherAIProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { if !m.config.isSupportedAPI(apiName) { - return types.ActionContinue, m.config.handleUnsupportedAPI() + return types.ActionContinue, errUnsupportedApiName } return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log) } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/yi.go b/plugins/wasm-go/extensions/ai-proxy/provider/yi.go index 80b0fbd743..3c3db4d5de 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/yi.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/yi.go @@ -49,7 +49,7 @@ func (m *yiProvider) GetProviderType() string { func (m *yiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { if !m.config.isSupportedAPI(apiName) { - return m.config.handleUnsupportedAPI() + return errUnsupportedApiName } m.config.handleRequestHeaders(m, ctx, apiName, log) return nil @@ -57,7 +57,7 @@ func (m *yiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, func (m *yiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { if !m.config.isSupportedAPI(apiName) { - return types.ActionContinue, m.config.handleUnsupportedAPI() + return types.ActionContinue, errUnsupportedApiName } return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log) } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/zhipuai.go b/plugins/wasm-go/extensions/ai-proxy/provider/zhipuai.go index 12936667f7..e95e99fc8b 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/zhipuai.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/zhipuai.go @@ -51,7 +51,7 @@ func (m *zhipuAiProvider) GetProviderType() string { func (m *zhipuAiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { if !m.config.isSupportedAPI(apiName) { - return m.config.handleUnsupportedAPI() + return errUnsupportedApiName } m.config.handleRequestHeaders(m, ctx, apiName, log) return nil @@ -59,7 +59,7 @@ func (m *zhipuAiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiN func (m *zhipuAiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { if !m.config.isSupportedAPI(apiName) { - return types.ActionContinue, m.config.handleUnsupportedAPI() + return types.ActionContinue, errUnsupportedApiName } return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log) }