Skip to content

Commit

Permalink
fix: correcting some errors in 'rewriting logic'
Browse files Browse the repository at this point in the history
  • Loading branch information
pepesi committed Jan 24, 2025
1 parent 21df731 commit eb23ddf
Show file tree
Hide file tree
Showing 27 changed files with 111 additions and 70 deletions.
34 changes: 22 additions & 12 deletions plugins/wasm-go/extensions/ai-proxy/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConf

rawPath := ctx.Path()
path, _ := url.Parse(rawPath)
apiName := getOpenAiApiName(path.Path)
apiName := getApiName(path.Path)
providerConfig := pluginConfig.GetProviderConfig()
if providerConfig.IsOriginal() {
if handler, ok := activeProvider.(provider.ApiNameHandler); ok {
Expand All @@ -103,20 +103,25 @@ func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConf
// Set the apiToken for the current request.
providerConfig.SetApiTokenInUse(ctx, log)

hasRequestBody := wrapper.HasRequestBody()
err := handler.OnRequestHeaders(ctx, apiName, log)
if err == nil {
if hasRequestBody {
proxywasm.RemoveHttpRequestHeader("Content-Length")
ctx.SetRequestBodyBufferLimit(defaultMaxBodyBytes)
// Delay the header processing to allow changing in OnRequestBody
return types.HeaderStopIteration
if err != nil {
if providerConfig.PassthroughUnsupportedAPI() {
log.Warnf("[onHttpRequestHeader] passthrough unsupported API: %v", err)
ctx.DontReadRequestBody()
return types.ActionContinue
}
ctx.DontReadRequestBody()
util.ErrorHandler("ai-proxy.proc_req_headers_failed", fmt.Errorf("failed to process request headers: %v", err))
return types.ActionContinue
}

util.ErrorHandler("ai-proxy.proc_req_headers_failed", fmt.Errorf("failed to process request headers: %v", err))
hasRequestBody := wrapper.HasRequestBody()
if hasRequestBody {
proxywasm.RemoveHttpRequestHeader("Content-Length")
ctx.SetRequestBodyBufferLimit(defaultMaxBodyBytes)
// Delay the header processing to allow changing in OnRequestBody
return types.HeaderStopIteration
}
ctx.DontReadRequestBody()
return types.ActionContinue
}

Expand Down Expand Up @@ -151,6 +156,10 @@ func onHttpRequestBody(ctx wrapper.HttpContext, pluginConfig config.PluginConfig
if err == nil {
return action
}
if pluginConfig.GetProviderConfig().PassthroughUnsupportedAPI() {
log.Warnf("[onHttpRequestBody] passthrough unsupported API: %v", err)
return types.ActionContinue
}
util.ErrorHandler("ai-proxy.proc_req_body_failed", fmt.Errorf("failed to process request body: %v", err))
}
return types.ActionContinue
Expand Down Expand Up @@ -267,7 +276,8 @@ func checkStream(ctx wrapper.HttpContext, log wrapper.Log) {
}
}

func getOpenAiApiName(path string) provider.ApiName {
func getApiName(path string) provider.ApiName {
// openai style
if strings.HasSuffix(path, "/v1/chat/completions") {
return provider.ApiNameChatCompletion
}
Expand All @@ -280,7 +290,7 @@ func getOpenAiApiName(path string) provider.ApiName {
if strings.HasSuffix(path, "/v1/images/generations") {
return provider.ApiNameImageGeneration
}
// rerank
// cohere style
if strings.HasSuffix(path, "/v1/rerank") {
return provider.ApiNameCohereV1Rerank
}
Expand Down
4 changes: 2 additions & 2 deletions plugins/wasm-go/extensions/ai-proxy/provider/ai360.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ func (m *ai360Provider) GetProviderType() string {

func (m *ai360Provider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if !m.config.isSupportedAPI(apiName) {
return m.config.handleUnsupportedAPI()
return errUnsupportedApiName
}
m.config.handleRequestHeaders(m, ctx, apiName, log)
// Delay the header processing to allow changing streaming mode in OnRequestBody
Expand All @@ -59,7 +59,7 @@ func (m *ai360Provider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNam

func (m *ai360Provider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if !m.config.isSupportedAPI(apiName) {
return types.ActionContinue, m.config.handleUnsupportedAPI()
return types.ActionContinue, errUnsupportedApiName
}
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
}
Expand Down
4 changes: 2 additions & 2 deletions plugins/wasm-go/extensions/ai-proxy/provider/azure.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,15 +64,15 @@ func (m *azureProvider) GetProviderType() string {

func (m *azureProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if !m.config.isSupportedAPI(apiName) {
return m.config.handleUnsupportedAPI()
return errUnsupportedApiName
}
m.config.handleRequestHeaders(m, ctx, apiName, log)
return nil
}

func (m *azureProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if !m.config.isSupportedAPI(apiName) {
return types.ActionContinue, m.config.handleUnsupportedAPI()
return types.ActionContinue, errUnsupportedApiName
}
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
}
Expand Down
4 changes: 2 additions & 2 deletions plugins/wasm-go/extensions/ai-proxy/provider/baichuan.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,15 +51,15 @@ func (m *baichuanProvider) GetProviderType() string {

func (m *baichuanProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if !m.config.isSupportedAPI(apiName) {
return m.config.handleUnsupportedAPI()
return errUnsupportedApiName
}
m.config.handleRequestHeaders(m, ctx, apiName, log)
return nil
}

func (m *baichuanProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if !m.config.isSupportedAPI(apiName) {
return types.ActionContinue, m.config.handleUnsupportedAPI()
return types.ActionContinue, errUnsupportedApiName
}
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
}
Expand Down
4 changes: 2 additions & 2 deletions plugins/wasm-go/extensions/ai-proxy/provider/baidu.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,15 +52,15 @@ func (g *baiduProvider) GetProviderType() string {

func (g *baiduProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if !g.config.isSupportedAPI(apiName) {
return g.config.handleUnsupportedAPI()
return errUnsupportedApiName
}
g.config.handleRequestHeaders(g, ctx, apiName, log)
return nil
}

func (g *baiduProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if !g.config.isSupportedAPI(apiName) {
return types.ActionContinue, g.config.handleUnsupportedAPI()
return types.ActionContinue, errUnsupportedApiName
}
return g.config.handleRequestBody(g, g.contextCache, ctx, apiName, body, log)
}
Expand Down
8 changes: 6 additions & 2 deletions plugins/wasm-go/extensions/ai-proxy/provider/claude.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ func (c *claudeProvider) GetProviderType() string {

func (c *claudeProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if !c.config.isSupportedAPI(apiName) {
return c.config.handleUnsupportedAPI()
return errUnsupportedApiName
}
c.config.handleRequestHeaders(c, ctx, apiName, log)
return nil
Expand All @@ -133,7 +133,7 @@ func (c *claudeProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiNam

func (c *claudeProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if !c.config.isSupportedAPI(apiName) {
return types.ActionContinue, c.config.handleUnsupportedAPI()
return types.ActionContinue, errUnsupportedApiName
}
return c.config.handleRequestBody(c, c.contextCache, ctx, apiName, body, log)
}
Expand Down Expand Up @@ -169,6 +169,10 @@ func (c *claudeProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name A
if isLastChunk || len(chunk) == 0 {
return nil, nil
}
// only process the response from chat completion, skip other responses
if name != ApiNameChatCompletion {
return chunk, nil
}

responseBuilder := &strings.Builder{}
lines := strings.Split(string(chunk), "\n")
Expand Down
4 changes: 2 additions & 2 deletions plugins/wasm-go/extensions/ai-proxy/provider/cloudflare.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,15 +50,15 @@ func (c *cloudflareProvider) GetProviderType() string {

func (c *cloudflareProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if !c.config.isSupportedAPI(apiName) {
return c.config.handleUnsupportedAPI()
return errUnsupportedApiName
}
c.config.handleRequestHeaders(c, ctx, apiName, log)
return nil
}

func (c *cloudflareProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if !c.config.isSupportedAPI(apiName) {
return types.ActionContinue, c.config.handleUnsupportedAPI()
return types.ActionContinue, errUnsupportedApiName
}
return c.config.handleRequestBody(c, c.contextCache, ctx, apiName, body, log)
}
Expand Down
4 changes: 2 additions & 2 deletions plugins/wasm-go/extensions/ai-proxy/provider/cohere.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,15 +67,15 @@ func (m *cohereProvider) GetProviderType() string {

func (m *cohereProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if !m.config.isSupportedAPI(apiName) {
return m.config.handleUnsupportedAPI()
return errUnsupportedApiName
}
m.config.handleRequestHeaders(m, ctx, apiName, log)
return nil
}

func (m *cohereProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if !m.config.isSupportedAPI(apiName) {
return types.ActionContinue, m.config.handleUnsupportedAPI()
return types.ActionContinue, errUnsupportedApiName
}
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
}
Expand Down
7 changes: 5 additions & 2 deletions plugins/wasm-go/extensions/ai-proxy/provider/deepl.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ func (d *deeplProvider) GetProviderType() string {

func (d *deeplProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if !d.config.isSupportedAPI(apiName) {
return d.config.handleUnsupportedAPI()
return errUnsupportedApiName
}
d.config.handleRequestHeaders(d, ctx, apiName, log)
return nil
Expand All @@ -97,7 +97,7 @@ func (d *deeplProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName

func (d *deeplProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if !d.config.isSupportedAPI(apiName) {
return types.ActionContinue, d.config.handleUnsupportedAPI()
return types.ActionContinue, errUnsupportedApiName
}
return d.config.handleRequestBody(d, d.contextCache, ctx, apiName, body, log)
}
Expand All @@ -119,6 +119,9 @@ func (d *deeplProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, api
}

func (d *deeplProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) {
if apiName != ApiNameChatCompletion {
return body, nil
}
deeplResponse := &deeplResponse{}
if err := json.Unmarshal(body, deeplResponse); err != nil {
return nil, fmt.Errorf("unable to unmarshal deepl response: %v", err)
Expand Down
4 changes: 2 additions & 2 deletions plugins/wasm-go/extensions/ai-proxy/provider/deepseek.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,15 +53,15 @@ func (m *deepseekProvider) GetProviderType() string {

func (m *deepseekProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if !m.config.isSupportedAPI(apiName) {
return m.config.handleUnsupportedAPI()
return errUnsupportedApiName
}
m.config.handleRequestHeaders(m, ctx, apiName, log)
return nil
}

func (m *deepseekProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if !m.config.isSupportedAPI(apiName) {
return types.ActionContinue, m.config.handleUnsupportedAPI()
return types.ActionContinue, errUnsupportedApiName
}
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
}
Expand Down
10 changes: 8 additions & 2 deletions plugins/wasm-go/extensions/ai-proxy/provider/dify.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ func (d *difyProvider) GetProviderType() string {

func (d *difyProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if apiName != ApiNameChatCompletion {
return d.config.handleUnsupportedAPI()
return errUnsupportedApiName
}
d.config.handleRequestHeaders(d, ctx, apiName, log)
return nil
Expand All @@ -78,7 +78,7 @@ func (d *difyProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName

func (d *difyProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if apiName != ApiNameChatCompletion {
return types.ActionContinue, d.config.handleUnsupportedAPI()
return types.ActionContinue, errUnsupportedApiName
}
return d.config.handleRequestBody(d, d.contextCache, ctx, apiName, body, log)
}
Expand All @@ -99,6 +99,9 @@ func (d *difyProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiN
}

func (d *difyProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) {
if apiName != ApiNameChatCompletion {
return body, nil
}
difyResponse := &DifyChatResponse{}
if err := json.Unmarshal(body, difyResponse); err != nil {
return nil, fmt.Errorf("unable to unmarshal dify response: %v", err)
Expand Down Expand Up @@ -150,6 +153,9 @@ func (d *difyProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name Api
if isLastChunk || len(chunk) == 0 {
return nil, nil
}
if name != ApiNameChatCompletion {
return chunk, nil
}
// sample event response:
// data: {"event": "agent_thought", "id": "8dcf3648-fbad-407a-85dd-73a6f43aeb9f", "task_id": "9cf1ddd7-f94b-459b-b942-b77b26c59e9b", "message_id": "1fb10045-55fd-4040-99e6-d048d07cbad3", "position": 1, "thought": "", "observation": "", "tool": "", "tool_input": "", "created_at": 1705639511, "message_files": [], "conversation_id": "c216c595-2d89-438c-b33c-aae5ddddd142"}

Expand Down
4 changes: 2 additions & 2 deletions plugins/wasm-go/extensions/ai-proxy/provider/doubao.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,15 +51,15 @@ func (m *doubaoProvider) GetProviderType() string {

func (m *doubaoProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if !m.config.isSupportedAPI(apiName) {
return m.config.handleUnsupportedAPI()
return errUnsupportedApiName
}
m.config.handleRequestHeaders(m, ctx, apiName, log)
return nil
}

func (m *doubaoProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if !m.config.isSupportedAPI(apiName) {
return types.ActionContinue, m.config.handleUnsupportedAPI()
return types.ActionContinue, errUnsupportedApiName
}
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
}
Expand Down
7 changes: 5 additions & 2 deletions plugins/wasm-go/extensions/ai-proxy/provider/gemini.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ func (g *geminiProvider) GetProviderType() string {

func (g *geminiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if !g.config.isSupportedAPI(apiName) {
return g.config.handleUnsupportedAPI()
return errUnsupportedApiName
}
g.config.handleRequestHeaders(g, ctx, apiName, log)
// Delay the header processing to allow changing streaming mode in OnRequestBody
Expand All @@ -72,7 +72,7 @@ func (g *geminiProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiNam

func (g *geminiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if !g.config.isSupportedAPI(apiName) {
return types.ActionContinue, g.config.handleUnsupportedAPI()
return types.ActionContinue, errUnsupportedApiName
}
return g.config.handleRequestBody(g, g.contextCache, ctx, apiName, body, log)
}
Expand Down Expand Up @@ -115,6 +115,9 @@ func (g *geminiProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name A
if isLastChunk || len(chunk) == 0 {
return nil, nil
}
if name != ApiNameChatCompletion {
return chunk, nil
}
// sample end event response:
// data: {"candidates": [{"content": {"parts": [{"text": "我是 Gemini,一个大型多模态模型,由 Google 训练。我的职责是尽我所能帮助您,并尽力提供全面且信息丰富的答复。"}],"role": "model"},"finishReason": "STOP","index": 0,"safetyRatings": [{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_HATE_SPEECH","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_HARASSMENT","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_DANGEROUS_CONTENT","probability": "NEGLIGIBLE"}]}],"usageMetadata": {"promptTokenCount": 2,"candidatesTokenCount": 35,"totalTokenCount": 37}}
responseBuilder := &strings.Builder{}
Expand Down
4 changes: 2 additions & 2 deletions plugins/wasm-go/extensions/ai-proxy/provider/github.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ func (m *githubProvider) GetProviderType() string {

func (m *githubProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if !m.config.isSupportedAPI(apiName) {
return m.config.handleUnsupportedAPI()
return errUnsupportedApiName
}
m.config.handleRequestHeaders(m, ctx, apiName, log)
// Delay the header processing to allow changing streaming mode in OnRequestBody
Expand All @@ -62,7 +62,7 @@ func (m *githubProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNa

func (m *githubProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if !m.config.isSupportedAPI(apiName) {
return types.ActionContinue, m.config.handleUnsupportedAPI()
return types.ActionContinue, errUnsupportedApiName
}
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
}
Expand Down
4 changes: 2 additions & 2 deletions plugins/wasm-go/extensions/ai-proxy/provider/groq.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,15 +50,15 @@ func (g *groqProvider) GetProviderType() string {

func (g *groqProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if !g.config.isSupportedAPI(apiName) {
return g.config.handleUnsupportedAPI()
return errUnsupportedApiName
}
g.config.handleRequestHeaders(g, ctx, apiName, log)
return nil
}

func (g *groqProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if !g.config.isSupportedAPI(apiName) {
return types.ActionContinue, g.config.handleUnsupportedAPI()
return types.ActionContinue, errUnsupportedApiName
}
return g.config.handleRequestBody(g, g.contextCache, ctx, apiName, body, log)
}
Expand Down
Loading

0 comments on commit eb23ddf

Please sign in to comment.