Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feature: allow ai-proxy to forward standard AI capabilities that are natively supported #1704

Merged
merged 13 commits into from
Feb 12, 2025
Merged
4 changes: 2 additions & 2 deletions plugins/wasm-go/extensions/ai-proxy/provider/ai360.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ func (m *ai360Provider) GetProviderType() string {

func (m *ai360Provider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if !m.config.isSupportedAPI(apiName) {
return errUnsupportedApiName
return m.config.handleUnsupportedAPI()
}
m.config.handleRequestHeaders(m, ctx, apiName, log)
// Delay the header processing to allow changing streaming mode in OnRequestBody
Expand All @@ -59,7 +59,7 @@ func (m *ai360Provider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNam

func (m *ai360Provider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if !m.config.isSupportedAPI(apiName) {
return types.ActionContinue, errUnsupportedApiName
return types.ActionContinue, m.config.handleUnsupportedAPI()
}
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
}
Expand Down
4 changes: 2 additions & 2 deletions plugins/wasm-go/extensions/ai-proxy/provider/azure.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,15 +64,15 @@ func (m *azureProvider) GetProviderType() string {

func (m *azureProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if !m.config.isSupportedAPI(apiName) {
return errUnsupportedApiName
return m.config.handleUnsupportedAPI()
}
m.config.handleRequestHeaders(m, ctx, apiName, log)
return nil
}

func (m *azureProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if !m.config.isSupportedAPI(apiName) {
return types.ActionContinue, errUnsupportedApiName
return types.ActionContinue, m.config.handleUnsupportedAPI()
}
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
}
Expand Down
4 changes: 2 additions & 2 deletions plugins/wasm-go/extensions/ai-proxy/provider/baichuan.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,15 +51,15 @@ func (m *baichuanProvider) GetProviderType() string {

func (m *baichuanProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if !m.config.isSupportedAPI(apiName) {
return errUnsupportedApiName
return m.config.handleUnsupportedAPI()
}
m.config.handleRequestHeaders(m, ctx, apiName, log)
return nil
}

func (m *baichuanProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if !m.config.isSupportedAPI(apiName) {
return types.ActionContinue, errUnsupportedApiName
return types.ActionContinue, m.config.handleUnsupportedAPI()
}
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
}
Expand Down
4 changes: 2 additions & 2 deletions plugins/wasm-go/extensions/ai-proxy/provider/baidu.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,15 +52,15 @@ func (g *baiduProvider) GetProviderType() string {

func (g *baiduProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if !g.config.isSupportedAPI(apiName) {
return errUnsupportedApiName
return g.config.handleUnsupportedAPI()
}
g.config.handleRequestHeaders(g, ctx, apiName, log)
return nil
}

func (g *baiduProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if !g.config.isSupportedAPI(apiName) {
return types.ActionContinue, errUnsupportedApiName
return types.ActionContinue, g.config.handleUnsupportedAPI()
}
return g.config.handleRequestBody(g, g.contextCache, ctx, apiName, body, log)
}
Expand Down
4 changes: 2 additions & 2 deletions plugins/wasm-go/extensions/ai-proxy/provider/claude.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ func (c *claudeProvider) GetProviderType() string {

func (c *claudeProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if !c.config.isSupportedAPI(apiName) {
return errUnsupportedApiName
return c.config.handleUnsupportedAPI()
}
c.config.handleRequestHeaders(c, ctx, apiName, log)
return nil
Expand All @@ -133,7 +133,7 @@ func (c *claudeProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiNam

func (c *claudeProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if !c.config.isSupportedAPI(apiName) {
return types.ActionContinue, errUnsupportedApiName
return types.ActionContinue, c.config.handleUnsupportedAPI()
}
return c.config.handleRequestBody(c, c.contextCache, ctx, apiName, body, log)
}
Expand Down
4 changes: 2 additions & 2 deletions plugins/wasm-go/extensions/ai-proxy/provider/cloudflare.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,15 +50,15 @@ func (c *cloudflareProvider) GetProviderType() string {

func (c *cloudflareProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if !c.config.isSupportedAPI(apiName) {
return errUnsupportedApiName
return c.config.handleUnsupportedAPI()
}
c.config.handleRequestHeaders(c, ctx, apiName, log)
return nil
}

func (c *cloudflareProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if !c.config.isSupportedAPI(apiName) {
return types.ActionContinue, errUnsupportedApiName
return types.ActionContinue, c.config.handleUnsupportedAPI()
}
return c.config.handleRequestBody(c, c.contextCache, ctx, apiName, body, log)
}
Expand Down
4 changes: 2 additions & 2 deletions plugins/wasm-go/extensions/ai-proxy/provider/cohere.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,15 +67,15 @@ func (m *cohereProvider) GetProviderType() string {

func (m *cohereProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if !m.config.isSupportedAPI(apiName) {
return errUnsupportedApiName
return m.config.handleUnsupportedAPI()
}
m.config.handleRequestHeaders(m, ctx, apiName, log)
return nil
}

func (m *cohereProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if !m.config.isSupportedAPI(apiName) {
return types.ActionContinue, errUnsupportedApiName
return types.ActionContinue, m.config.handleUnsupportedAPI()
}
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
}
Expand Down
4 changes: 2 additions & 2 deletions plugins/wasm-go/extensions/ai-proxy/provider/deepl.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ func (d *deeplProvider) GetProviderType() string {

func (d *deeplProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if !d.config.isSupportedAPI(apiName) {
return errUnsupportedApiName
return d.config.handleUnsupportedAPI()
}
d.config.handleRequestHeaders(d, ctx, apiName, log)
return nil
Expand All @@ -97,7 +97,7 @@ func (d *deeplProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName

func (d *deeplProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if !d.config.isSupportedAPI(apiName) {
return types.ActionContinue, errUnsupportedApiName
return types.ActionContinue, d.config.handleUnsupportedAPI()
}
return d.config.handleRequestBody(d, d.contextCache, ctx, apiName, body, log)
}
Expand Down
4 changes: 2 additions & 2 deletions plugins/wasm-go/extensions/ai-proxy/provider/deepseek.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,15 +53,15 @@ func (m *deepseekProvider) GetProviderType() string {

func (m *deepseekProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if !m.config.isSupportedAPI(apiName) {
return errUnsupportedApiName
return m.config.handleUnsupportedAPI()
}
m.config.handleRequestHeaders(m, ctx, apiName, log)
return nil
}

func (m *deepseekProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if !m.config.isSupportedAPI(apiName) {
return types.ActionContinue, errUnsupportedApiName
return types.ActionContinue, m.config.handleUnsupportedAPI()
}
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
}
Expand Down
4 changes: 2 additions & 2 deletions plugins/wasm-go/extensions/ai-proxy/provider/dify.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ func (d *difyProvider) GetProviderType() string {

func (d *difyProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if apiName != ApiNameChatCompletion {
return errUnsupportedApiName
return d.config.handleUnsupportedAPI()
}
d.config.handleRequestHeaders(d, ctx, apiName, log)
return nil
Expand All @@ -78,7 +78,7 @@ func (d *difyProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName

func (d *difyProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
return types.ActionContinue, d.config.handleUnsupportedAPI()
}
return d.config.handleRequestBody(d, d.contextCache, ctx, apiName, body, log)
}
Expand Down
4 changes: 2 additions & 2 deletions plugins/wasm-go/extensions/ai-proxy/provider/doubao.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,15 +51,15 @@ func (m *doubaoProvider) GetProviderType() string {

func (m *doubaoProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if !m.config.isSupportedAPI(apiName) {
return errUnsupportedApiName
return m.config.handleUnsupportedAPI()
}
m.config.handleRequestHeaders(m, ctx, apiName, log)
return nil
}

func (m *doubaoProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if !m.config.isSupportedAPI(apiName) {
return types.ActionContinue, errUnsupportedApiName
return types.ActionContinue, m.config.handleUnsupportedAPI()
}
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
}
Expand Down
4 changes: 2 additions & 2 deletions plugins/wasm-go/extensions/ai-proxy/provider/gemini.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ func (g *geminiProvider) GetProviderType() string {

func (g *geminiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if !g.config.isSupportedAPI(apiName) {
return errUnsupportedApiName
return g.config.handleUnsupportedAPI()
}
g.config.handleRequestHeaders(g, ctx, apiName, log)
// Delay the header processing to allow changing streaming mode in OnRequestBody
Expand All @@ -72,7 +72,7 @@ func (g *geminiProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiNam

func (g *geminiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if !g.config.isSupportedAPI(apiName) {
return types.ActionContinue, errUnsupportedApiName
return types.ActionContinue, g.config.handleUnsupportedAPI()
}
return g.config.handleRequestBody(g, g.contextCache, ctx, apiName, body, log)
}
Expand Down
4 changes: 2 additions & 2 deletions plugins/wasm-go/extensions/ai-proxy/provider/github.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ func (m *githubProvider) GetProviderType() string {

func (m *githubProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if !m.config.isSupportedAPI(apiName) {
return errUnsupportedApiName
return m.config.handleUnsupportedAPI()
}
m.config.handleRequestHeaders(m, ctx, apiName, log)
// Delay the header processing to allow changing streaming mode in OnRequestBody
Expand All @@ -62,7 +62,7 @@ func (m *githubProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNa

func (m *githubProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if !m.config.isSupportedAPI(apiName) {
return types.ActionContinue, errUnsupportedApiName
return types.ActionContinue, m.config.handleUnsupportedAPI()
}
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
}
Expand Down
4 changes: 2 additions & 2 deletions plugins/wasm-go/extensions/ai-proxy/provider/groq.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,15 +50,15 @@ func (g *groqProvider) GetProviderType() string {

func (g *groqProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if !g.config.isSupportedAPI(apiName) {
return errUnsupportedApiName
return g.config.handleUnsupportedAPI()
}
g.config.handleRequestHeaders(g, ctx, apiName, log)
return nil
}

func (g *groqProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if !g.config.isSupportedAPI(apiName) {
return types.ActionContinue, errUnsupportedApiName
return types.ActionContinue, g.config.handleUnsupportedAPI()
}
return g.config.handleRequestBody(g, g.contextCache, ctx, apiName, body, log)
}
Expand Down
4 changes: 2 additions & 2 deletions plugins/wasm-go/extensions/ai-proxy/provider/hunyuan.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ func (m *hunyuanProvider) useOpenAICompatibleAPI() bool {

func (m *hunyuanProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if !m.config.isSupportedAPI(apiName) {
return errUnsupportedApiName
return m.config.handleUnsupportedAPI()
}
m.config.handleRequestHeaders(m, ctx, apiName, log)
// Delay the header processing to allow changing streaming mode in OnRequestBody
Expand All @@ -161,7 +161,7 @@ func (m *hunyuanProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiNa
// hunyuan 的 OnRequestBody 逻辑中包含了对 headers 签名的逻辑,并且插入 context 以后还要重新计算签名,因此无法复用 handleRequestBody 方法
func (m *hunyuanProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if !m.config.isSupportedAPI(apiName) {
return types.ActionContinue, errUnsupportedApiName
return types.ActionContinue, m.config.handleUnsupportedAPI()
}
if m.useOpenAICompatibleAPI() {
return types.ActionContinue, nil
Expand Down
4 changes: 2 additions & 2 deletions plugins/wasm-go/extensions/ai-proxy/provider/minimax.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ func (m *minimaxProvider) GetProviderType() string {

func (m *minimaxProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if !m.config.isSupportedAPI(apiName) {
return errUnsupportedApiName
return m.config.handleUnsupportedAPI()
}
m.config.handleRequestHeaders(m, ctx, apiName, log)
// Delay the header processing to allow changing streaming mode in OnRequestBody
Expand All @@ -90,7 +90,7 @@ func (m *minimaxProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiNa

func (m *minimaxProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if !m.config.isSupportedAPI(apiName) {
return types.ActionContinue, errUnsupportedApiName
return types.ActionContinue, m.config.handleUnsupportedAPI()
}
if minimaxApiTypePro == m.config.minimaxApiType {
// Use chat completion Pro API.
Expand Down
4 changes: 2 additions & 2 deletions plugins/wasm-go/extensions/ai-proxy/provider/mistral.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,15 +49,15 @@ func (m *mistralProvider) GetProviderType() string {

func (m *mistralProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if !m.config.isSupportedAPI(apiName) {
return errUnsupportedApiName
return m.config.handleUnsupportedAPI()
}
m.config.handleRequestHeaders(m, ctx, apiName, log)
return nil
}

func (m *mistralProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if !m.config.isSupportedAPI(apiName) {
return types.ActionContinue, errUnsupportedApiName
return types.ActionContinue, m.config.handleUnsupportedAPI()
}
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
}
Expand Down
4 changes: 2 additions & 2 deletions plugins/wasm-go/extensions/ai-proxy/provider/moonshot.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ func (m *moonshotProvider) GetProviderType() string {

func (m *moonshotProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if !m.config.isSupportedAPI(apiName) {
return errUnsupportedApiName
return m.config.handleUnsupportedAPI()
}
m.config.handleRequestHeaders(m, ctx, apiName, log)
return nil
Expand All @@ -82,7 +82,7 @@ func (m *moonshotProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiN
// moonshot 的 body 没有修改,无须实现TransformRequestBody,使用默认的 defaultTransformRequestBody 方法
func (m *moonshotProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if !m.config.isSupportedAPI(apiName) {
return types.ActionContinue, errUnsupportedApiName
return types.ActionContinue, m.config.handleUnsupportedAPI()
}
// 非chat类型的请求,不做处理
if apiName != ApiNameChatCompletion {
Expand Down
4 changes: 2 additions & 2 deletions plugins/wasm-go/extensions/ai-proxy/provider/ollama.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,15 +55,15 @@ func (m *ollamaProvider) GetProviderType() string {

func (m *ollamaProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if !m.config.isSupportedAPI(apiName) {
return errUnsupportedApiName
return m.config.handleUnsupportedAPI()
}
m.config.handleRequestHeaders(m, ctx, apiName, log)
return nil
}

func (m *ollamaProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if !m.config.isSupportedAPI(apiName) {
return types.ActionContinue, errUnsupportedApiName
return types.ActionContinue, m.config.handleUnsupportedAPI()
}
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
}
Expand Down
10 changes: 10 additions & 0 deletions plugins/wasm-go/extensions/ai-proxy/provider/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,9 @@ type ProviderConfig struct {
// @Title zh-CN 额外支持的ai能力
// @Description zh-CN 开放的ai能力和urlpath映射,例如: {"openai/v1/chatcompletions": "/v1/chat/completions"}
capabilities map[string]string
// @Title zh-CN 是否开启透传
// @Description zh-CN 如果是插件不支持的API,是否透传请求, 默认为false
passthrsough bool
}

func (c *ProviderConfig) GetId() string {
Expand Down Expand Up @@ -450,6 +453,13 @@ func (c *ProviderConfig) ReplaceByCustomSettings(body []byte) ([]byte, error) {
return ReplaceByCustomSettings(body, c.customSettings)
}

func (c *ProviderConfig) handleUnsupportedAPI() error {
if c.passthrsough {
return nil
johnlanni marked this conversation as resolved.
Show resolved Hide resolved
}
return errUnsupportedApiName
}

func CreateProvider(pc ProviderConfig) (Provider, error) {
initializer, has := providerInitializers[pc.typ]
if !has {
Expand Down
6 changes: 3 additions & 3 deletions plugins/wasm-go/extensions/ai-proxy/provider/qwen.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ func (m *qwenProvider) GetProviderType() string {

func (m *qwenProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if !m.config.isSupportedAPI(apiName) {
return errUnsupportedApiName
return m.config.handleUnsupportedAPI()
}

m.config.handleRequestHeaders(m, ctx, apiName, log)
Expand Down Expand Up @@ -150,7 +150,7 @@ func (m *qwenProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, b
}

if !m.config.isSupportedAPI(apiName) {
return types.ActionContinue, errUnsupportedApiName
return types.ActionContinue, m.config.handleUnsupportedAPI()
}
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
}
Expand Down Expand Up @@ -290,7 +290,7 @@ func (m *qwenProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName Ap
if m.config.isSupportedAPI(apiName) {
return body, nil
}
return nil, errUnsupportedApiName
return nil, m.config.handleUnsupportedAPI()
}

func (m *qwenProvider) onChatCompletionResponseBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) ([]byte, error) {
Expand Down
Loading
Loading