Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feature: allow ai-proxy to forward standard AI capabilities that are natively supported #1704

Merged
merged 13 commits into from
Feb 12, 2025
Merged
3 changes: 2 additions & 1 deletion plugins/wasm-go/extensions/ai-proxy/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ description: AI 代理插件配置参考
| `customSettings` | array of customSetting | 非必填 | - | 为AI请求指定覆盖或者填充参数 |
| `failover` | object | 非必填 | - | 配置 apiToken 的 failover 策略,当 apiToken 不可用时,将其移出 apiToken 列表,待健康检测通过后重新添加回 apiToken 列表 |
| `retryOnFailure` | object | 非必填 | - | 当请求失败时立即进行重试 |

| `capabilities` | map of string | 非必填 | - | 部分 provider 的部分 AI 能力原生兼容 openai/v1 格式,无需重写即可直接转发;通过此配置项开启转发。key 表示采用的厂商协议能力,value 表示厂商该能力对应的真实 API path。厂商协议能力当前支持:openai/v1/chatcompletions, openai/v1/embeddings, openai/v1/imagegeneration, openai/v1/audiospeech, cohere/v1/rerank |
| `passthrough` | bool | 非必填 | - | 对所有不支持的 API 能力直接转发。此配置是 capabilities 配置的放大版本,允许任意 API 透传,效果等同于未启用 ai-proxy 插件 |

`context`的配置字段说明如下:

| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
Expand Down
10 changes: 10 additions & 0 deletions plugins/wasm-go/extensions/ai-proxy/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -274,5 +274,15 @@ func getOpenAiApiName(path string) provider.ApiName {
if strings.HasSuffix(path, "/v1/embeddings") {
return provider.ApiNameEmbeddings
}
if strings.HasSuffix(path, "/v1/audio/speech") {
return provider.ApiNameAudioSpeech
}
if strings.HasSuffix(path, "/v1/images/generations") {
return provider.ApiNameImageGeneration
}
// rerank
if strings.HasSuffix(path, "/v1/rerank") {
return provider.ApiNameCohereV1Rerank
johnlanni marked this conversation as resolved.
Show resolved Hide resolved
}
return ""
}
17 changes: 13 additions & 4 deletions plugins/wasm-go/extensions/ai-proxy/provider/ai360.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,13 @@ type ai360Provider struct {
contextCache *contextCache
}

// DefaultCapabilities returns the capability map that CreateProvider hands to
// setDefaultCapabilities for the ai360 provider: each supported API name is
// mapped to the request path constant used for it.
func (m *ai360ProviderInitializer) DefaultCapabilities() map[string]string {
	defaults := make(map[string]string, 2)
	defaults[string(ApiNameChatCompletion)] = PathOpenAIChatCompletions
	defaults[string(ApiNameEmbeddings)] = PathOpenAIEmbeddings
	return defaults
}

func (m *ai360ProviderInitializer) ValidateConfig(config *ProviderConfig) error {
if config.apiTokens == nil || len(config.apiTokens) == 0 {
return errors.New("no apiToken found in provider config")
Expand All @@ -30,6 +37,7 @@ func (m *ai360ProviderInitializer) ValidateConfig(config *ProviderConfig) error
}

func (m *ai360ProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
config.setDefaultCapabilities(m.DefaultCapabilities())
return &ai360Provider{
config: config,
contextCache: createContextCache(&config),
Expand All @@ -41,22 +49,23 @@ func (m *ai360Provider) GetProviderType() string {
}

// OnRequestHeaders returns early for an API name the provider does not accept;
// otherwise it calls m.config.handleRequestHeaders and returns nil.
// NOTE(review): this span is a rendered diff. The hard-coded apiName comparison
// and its errUnsupportedApiName return look like the removed lines, superseded
// by the isSupportedAPI / handleUnsupportedAPI pair right after them — confirm
// against the merged file.
func (m *ai360Provider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if apiName != ApiNameChatCompletion && apiName != ApiNameEmbeddings {
return errUnsupportedApiName
if !m.config.isSupportedAPI(apiName) {
return m.config.handleUnsupportedAPI()
}
m.config.handleRequestHeaders(m, ctx, apiName, log)
// Delay the header processing to allow changing streaming mode in OnRequestBody
return nil
}

// OnRequestBody returns types.ActionContinue plus an error for an API name the
// provider does not accept; otherwise it returns whatever
// m.config.handleRequestBody produces for the body.
// NOTE(review): this span is a rendered diff. The hard-coded apiName comparison
// and its errUnsupportedApiName return look like the removed lines, superseded
// by the isSupportedAPI / handleUnsupportedAPI pair right after them — confirm
// against the merged file.
func (m *ai360Provider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if apiName != ApiNameChatCompletion && apiName != ApiNameEmbeddings {
return types.ActionContinue, errUnsupportedApiName
if !m.config.isSupportedAPI(apiName) {
return types.ActionContinue, m.config.handleUnsupportedAPI()
}
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
}

// TransformRequestHeaders rewrites the outgoing request headers in place, in
// three steps: the host is overwritten with ai360Domain, the path is
// overwritten from the configured capabilities using the API name as the key,
// and the authorization header is overwritten with the API token currently in
// use (passed as-is, with no scheme prefix added here).
func (m *ai360Provider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
	util.OverwriteRequestHostHeader(headers, ai360Domain)

	capabilityKey := string(apiName)
	util.OverwriteRequestPathHeaderByCapability(headers, capabilityKey, m.config.capabilities)

	token := m.config.GetApiTokenInUse(ctx)
	util.OverwriteRequestAuthorizationHeader(headers, token)
}
17 changes: 13 additions & 4 deletions plugins/wasm-go/extensions/ai-proxy/provider/azure.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,14 @@ import (
type azureProviderInitializer struct {
}

// DefaultCapabilities returns the capability map that CreateProvider hands to
// setDefaultCapabilities for the azure provider: each supported API name is
// mapped to the request path constant used for it.
func (m *azureProviderInitializer) DefaultCapabilities() map[string]string {
	// TODO: azure's pattern is the same as openai, just need to handle the prefix, can be done in TransformRequestHeaders to support general capabilities
	defaults := make(map[string]string, 2)
	defaults[string(ApiNameChatCompletion)] = PathOpenAIChatCompletions
	defaults[string(ApiNameEmbeddings)] = PathOpenAIEmbeddings
	return defaults
}

func (m *azureProviderInitializer) ValidateConfig(config *ProviderConfig) error {
if config.azureServiceUrl == "" {
return errors.New("missing azureServiceUrl in provider config")
Expand All @@ -35,6 +43,7 @@ func (m *azureProviderInitializer) CreateProvider(config ProviderConfig) (Provid
} else {
serviceUrl = u
}
config.setDefaultCapabilities(m.DefaultCapabilities())
return &azureProvider{
config: config,
serviceUrl: serviceUrl,
Expand All @@ -54,16 +63,16 @@ func (m *azureProvider) GetProviderType() string {
}

// OnRequestHeaders returns early for an API name the provider does not accept;
// otherwise it calls m.config.handleRequestHeaders and returns nil.
// NOTE(review): this span is a rendered diff. The hard-coded apiName comparison
// and its errUnsupportedApiName return look like the removed lines, superseded
// by the isSupportedAPI / handleUnsupportedAPI pair right after them — confirm
// against the merged file.
func (m *azureProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if apiName != ApiNameChatCompletion {
return errUnsupportedApiName
if !m.config.isSupportedAPI(apiName) {
return m.config.handleUnsupportedAPI()
}
m.config.handleRequestHeaders(m, ctx, apiName, log)
return nil
}

// OnRequestBody returns types.ActionContinue plus an error for an API name the
// provider does not accept; otherwise it returns whatever
// m.config.handleRequestBody produces for the body.
// NOTE(review): this span is a rendered diff. The hard-coded apiName comparison
// and its errUnsupportedApiName return look like the removed lines, superseded
// by the isSupportedAPI / handleUnsupportedAPI pair right after them — confirm
// against the merged file.
func (m *azureProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
if !m.config.isSupportedAPI(apiName) {
return types.ActionContinue, m.config.handleUnsupportedAPI()
}
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
}
Expand Down
21 changes: 14 additions & 7 deletions plugins/wasm-go/extensions/ai-proxy/provider/baichuan.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@ import (
// baichuanProvider is the provider for baichuan Ai service.

const (
baichuanDomain = "api.baichuan-ai.com"
baichuanChatCompletionPath = "/v1/chat/completions"
baichuanDomain = "api.baichuan-ai.com"
)

type baichuanProviderInitializer struct {
Expand All @@ -26,7 +25,15 @@ func (m *baichuanProviderInitializer) ValidateConfig(config *ProviderConfig) err
return nil
}

// DefaultCapabilities returns the capability map that CreateProvider hands to
// setDefaultCapabilities for the baichuan provider: each supported API name is
// mapped to the request path constant used for it.
func (m *baichuanProviderInitializer) DefaultCapabilities() map[string]string {
	defaults := make(map[string]string, 2)
	defaults[string(ApiNameChatCompletion)] = PathOpenAIChatCompletions
	defaults[string(ApiNameEmbeddings)] = PathOpenAIEmbeddings
	return defaults
}

func (m *baichuanProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
config.setDefaultCapabilities(m.DefaultCapabilities())
return &baichuanProvider{
config: config,
contextCache: createContextCache(&config),
Expand All @@ -43,22 +50,22 @@ func (m *baichuanProvider) GetProviderType() string {
}

// OnRequestHeaders returns early for an API name the provider does not accept;
// otherwise it calls m.config.handleRequestHeaders and returns nil.
// NOTE(review): this span is a rendered diff. The hard-coded apiName comparison
// and its errUnsupportedApiName return look like the removed lines, superseded
// by the isSupportedAPI / handleUnsupportedAPI pair right after them — confirm
// against the merged file.
func (m *baichuanProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if apiName != ApiNameChatCompletion {
return errUnsupportedApiName
if !m.config.isSupportedAPI(apiName) {
return m.config.handleUnsupportedAPI()
}
m.config.handleRequestHeaders(m, ctx, apiName, log)
return nil
}

// OnRequestBody returns types.ActionContinue plus an error for an API name the
// provider does not accept; otherwise it returns whatever
// m.config.handleRequestBody produces for the body.
// NOTE(review): this span is a rendered diff. The hard-coded apiName comparison
// and its errUnsupportedApiName return look like the removed lines, superseded
// by the isSupportedAPI / handleUnsupportedAPI pair right after them — confirm
// against the merged file.
func (m *baichuanProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
if !m.config.isSupportedAPI(apiName) {
return types.ActionContinue, m.config.handleUnsupportedAPI()
}
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
}

func (m *baichuanProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
util.OverwriteRequestPathHeader(headers, baichuanChatCompletionPath)
util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), m.config.capabilities)
util.OverwriteRequestHostHeader(headers, baichuanDomain)
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx))
headers.Del("Content-Length")
Expand Down
19 changes: 14 additions & 5 deletions plugins/wasm-go/extensions/ai-proxy/provider/baidu.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
const (
baiduDomain = "qianfan.baidubce.com"
baiduChatCompletionPath = "/v2/chat/completions"
baiduEmbeddings = "/v2/embeddings"
)

type baiduProviderInitializer struct{}
Expand All @@ -25,7 +26,15 @@ func (g *baiduProviderInitializer) ValidateConfig(config *ProviderConfig) error
return nil
}

// DefaultCapabilities returns the capability map that CreateProvider hands to
// setDefaultCapabilities for the baidu provider: each supported API name is
// mapped to the baidu-specific path constant used for it.
func (g *baiduProviderInitializer) DefaultCapabilities() map[string]string {
	defaults := make(map[string]string, 2)
	defaults[string(ApiNameChatCompletion)] = baiduChatCompletionPath
	defaults[string(ApiNameEmbeddings)] = baiduEmbeddings
	return defaults
}

func (g *baiduProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
config.setDefaultCapabilities(g.DefaultCapabilities())
return &baiduProvider{
config: config,
contextCache: createContextCache(&config),
Expand All @@ -42,22 +51,22 @@ func (g *baiduProvider) GetProviderType() string {
}

// OnRequestHeaders returns early for an API name the provider does not accept;
// otherwise it calls g.config.handleRequestHeaders and returns nil.
// NOTE(review): this span is a rendered diff. The hard-coded apiName comparison
// and its errUnsupportedApiName return look like the removed lines, superseded
// by the isSupportedAPI / handleUnsupportedAPI pair right after them — confirm
// against the merged file.
func (g *baiduProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if apiName != ApiNameChatCompletion {
return errUnsupportedApiName
if !g.config.isSupportedAPI(apiName) {
return g.config.handleUnsupportedAPI()
}
g.config.handleRequestHeaders(g, ctx, apiName, log)
return nil
}

// OnRequestBody returns types.ActionContinue plus an error for an API name the
// provider does not accept; otherwise it returns whatever
// g.config.handleRequestBody produces for the body.
// NOTE(review): this span is a rendered diff. The hard-coded apiName comparison
// and its errUnsupportedApiName return look like the removed lines, superseded
// by the isSupportedAPI / handleUnsupportedAPI pair right after them — confirm
// against the merged file.
func (g *baiduProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
if !g.config.isSupportedAPI(apiName) {
return types.ActionContinue, g.config.handleUnsupportedAPI()
}
return g.config.handleRequestBody(g, g.contextCache, ctx, apiName, body, log)
}

func (g *baiduProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
util.OverwriteRequestPathHeader(headers, baiduChatCompletionPath)
util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), g.config.capabilities)
util.OverwriteRequestHostHeader(headers, baiduDomain)
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+g.config.GetApiTokenInUse(ctx))
headers.Del("Content-Length")
Expand Down
25 changes: 20 additions & 5 deletions plugins/wasm-go/extensions/ai-proxy/provider/claude.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,16 @@ func (c *claudeProviderInitializer) ValidateConfig(config *ProviderConfig) error
return nil
}

// DefaultCapabilities returns the capability map that CreateProvider hands to
// setDefaultCapabilities for the claude provider: chat completion uses the
// claude-specific path constant, embeddings the OpenAI-style one.
func (c *claudeProviderInitializer) DefaultCapabilities() map[string]string {
	defaults := make(map[string]string, 2)
	defaults[string(ApiNameChatCompletion)] = claudeChatCompletionPath
	// docs: https://docs.anthropic.com/en/docs/build-with-claude/embeddings#voyage-http-api
	defaults[string(ApiNameEmbeddings)] = PathOpenAIEmbeddings
	return defaults
}

func (c *claudeProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
config.setDefaultCapabilities(c.DefaultCapabilities())
return &claudeProvider{
config: config,
contextCache: createContextCache(&config),
Expand All @@ -102,15 +111,15 @@ func (c *claudeProvider) GetProviderType() string {
}

// OnRequestHeaders returns early for an API name the provider does not accept;
// otherwise it calls c.config.handleRequestHeaders and returns nil.
// NOTE(review): this span is a rendered diff. The hard-coded apiName comparison
// and its errUnsupportedApiName return look like the removed lines, superseded
// by the isSupportedAPI / handleUnsupportedAPI pair right after them — confirm
// against the merged file.
func (c *claudeProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if apiName != ApiNameChatCompletion {
return errUnsupportedApiName
if !c.config.isSupportedAPI(apiName) {
return c.config.handleUnsupportedAPI()
}
c.config.handleRequestHeaders(c, ctx, apiName, log)
return nil
}

func (c *claudeProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
util.OverwriteRequestPathHeader(headers, claudeChatCompletionPath)
util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), c.config.capabilities)
util.OverwriteRequestHostHeader(headers, claudeDomain)

headers.Set("x-api-key", c.config.GetApiTokenInUse(ctx))
Expand All @@ -123,13 +132,16 @@ func (c *claudeProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiNam
}

// OnRequestBody returns types.ActionContinue plus an error for an API name the
// provider does not accept; otherwise it returns whatever
// c.config.handleRequestBody produces for the body.
// NOTE(review): this span is a rendered diff. The hard-coded apiName comparison
// and its errUnsupportedApiName return look like the removed lines, superseded
// by the isSupportedAPI / handleUnsupportedAPI pair right after them — confirm
// against the merged file.
func (c *claudeProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
if !c.config.isSupportedAPI(apiName) {
return types.ActionContinue, c.config.handleUnsupportedAPI()
}
return c.config.handleRequestBody(c, c.contextCache, ctx, apiName, body, log)
}

func (c *claudeProvider) TransformRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) {
if apiName != ApiNameChatCompletion {
return c.config.defaultTransformRequestBody(ctx, apiName, body, log)
}
request := &chatCompletionRequest{}
if err := c.config.parseRequestAndMapModel(ctx, request, body, log); err != nil {
return nil, err
Expand All @@ -139,6 +151,9 @@ func (c *claudeProvider) TransformRequestBody(ctx wrapper.HttpContext, apiName A
}

func (c *claudeProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) {
if apiName != ApiNameChatCompletion {
return body, nil
}
claudeResponse := &claudeTextGenResponse{}
if err := json.Unmarshal(body, claudeResponse); err != nil {
return nil, fmt.Errorf("unable to unmarshal claude response: %v", err)
Expand Down
14 changes: 10 additions & 4 deletions plugins/wasm-go/extensions/ai-proxy/provider/cloudflare.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,14 @@ func (c *cloudflareProviderInitializer) ValidateConfig(config *ProviderConfig) e
}
return nil
}
// DefaultCapabilities returns the capability map that CreateProvider hands to
// setDefaultCapabilities for the cloudflare provider; chat completion, mapped to
// the cloudflare-specific path constant, is its only entry.
func (c *cloudflareProviderInitializer) DefaultCapabilities() map[string]string {
	defaults := make(map[string]string, 1)
	defaults[string(ApiNameChatCompletion)] = cloudflareChatCompletionPath
	return defaults
}

func (c *cloudflareProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
config.setDefaultCapabilities(c.DefaultCapabilities())
return &cloudflareProvider{
config: config,
contextCache: createContextCache(&config),
Expand All @@ -43,16 +49,16 @@ func (c *cloudflareProvider) GetProviderType() string {
}

// OnRequestHeaders returns early for an API name the provider does not accept;
// otherwise it calls c.config.handleRequestHeaders and returns nil.
// NOTE(review): this span is a rendered diff. The hard-coded apiName comparison
// and its errUnsupportedApiName return look like the removed lines, superseded
// by the isSupportedAPI / handleUnsupportedAPI pair right after them — confirm
// against the merged file.
func (c *cloudflareProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if apiName != ApiNameChatCompletion {
return errUnsupportedApiName
if !c.config.isSupportedAPI(apiName) {
return c.config.handleUnsupportedAPI()
}
c.config.handleRequestHeaders(c, ctx, apiName, log)
return nil
}

// OnRequestBody returns types.ActionContinue plus an error for an API name the
// provider does not accept; otherwise it returns whatever
// c.config.handleRequestBody produces for the body.
// NOTE(review): this span is a rendered diff. The hard-coded apiName comparison
// and its errUnsupportedApiName return look like the removed lines, superseded
// by the isSupportedAPI / handleUnsupportedAPI pair right after them — confirm
// against the merged file.
func (c *cloudflareProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
if !c.config.isSupportedAPI(apiName) {
return types.ActionContinue, c.config.handleUnsupportedAPI()
}
return c.config.handleRequestBody(c, c.contextCache, ctx, apiName, body, log)
}
Expand Down
25 changes: 19 additions & 6 deletions plugins/wasm-go/extensions/ai-proxy/provider/cohere.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,10 @@ import (
)

const (
cohereDomain = "api.cohere.com"
cohereDomain = "api.cohere.com"
// TODO: support more capabilities, upgrade to v2, docs: https://docs.cohere.com/v2/reference/chat
cohereChatCompletionPath = "/v1/chat"
cohereRerankPath = "/v1/rerank"
)

type cohereProviderInitializer struct{}
Expand All @@ -25,7 +27,15 @@ func (m *cohereProviderInitializer) ValidateConfig(config *ProviderConfig) error
return nil
}

// DefaultCapabilities returns the capability map that CreateProvider hands to
// setDefaultCapabilities for the cohere provider: chat completion and rerank,
// each mapped to its cohere-specific path constant.
func (m *cohereProviderInitializer) DefaultCapabilities() map[string]string {
	defaults := make(map[string]string, 2)
	defaults[string(ApiNameChatCompletion)] = cohereChatCompletionPath
	defaults[string(ApiNameCohereV1Rerank)] = cohereRerankPath
	return defaults
}

func (m *cohereProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
config.setDefaultCapabilities(m.DefaultCapabilities())
return &cohereProvider{
config: config,
contextCache: createContextCache(&config),
Expand Down Expand Up @@ -56,16 +66,16 @@ func (m *cohereProvider) GetProviderType() string {
}

// OnRequestHeaders returns early for an API name the provider does not accept;
// otherwise it calls m.config.handleRequestHeaders and returns nil.
// NOTE(review): this span is a rendered diff. The hard-coded apiName comparison
// and its errUnsupportedApiName return look like the removed lines, superseded
// by the isSupportedAPI / handleUnsupportedAPI pair right after them — confirm
// against the merged file.
func (m *cohereProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if apiName != ApiNameChatCompletion {
return errUnsupportedApiName
if !m.config.isSupportedAPI(apiName) {
return m.config.handleUnsupportedAPI()
}
m.config.handleRequestHeaders(m, ctx, apiName, log)
return nil
}

// OnRequestBody returns types.ActionContinue plus an error for an API name the
// provider does not accept; otherwise it returns whatever
// m.config.handleRequestBody produces for the body.
// NOTE(review): this span is a rendered diff. The hard-coded apiName comparison
// and its errUnsupportedApiName return look like the removed lines, superseded
// by the isSupportedAPI / handleUnsupportedAPI pair right after them — confirm
// against the merged file.
func (m *cohereProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
if !m.config.isSupportedAPI(apiName) {
return types.ActionContinue, m.config.handleUnsupportedAPI()
}
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
}
Expand All @@ -90,13 +100,16 @@ func (m *cohereProvider) buildCohereRequest(origin *chatCompletionRequest) *cohe
}

// TransformRequestHeaders rewrites the outgoing request headers in place: it
// overwrites the path and host (cohereDomain), sets a "Bearer "-prefixed
// authorization value from the API token in use, and deletes Content-Length.
// NOTE(review): this span is a rendered diff. The OverwriteRequestPathHeader
// call with the fixed cohereChatCompletionPath looks like the removed line,
// superseded by the capability-based call right after it — confirm against the
// merged file.
func (m *cohereProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
util.OverwriteRequestPathHeader(headers, cohereChatCompletionPath)
util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), m.config.capabilities)
util.OverwriteRequestHostHeader(headers, cohereDomain)
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx))
headers.Del("Content-Length")
}

func (m *cohereProvider) TransformRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) {
if apiName != ApiNameChatCompletion {
return m.config.defaultTransformRequestBody(ctx, apiName, body, log)
}
request := &chatCompletionRequest{}
if err := m.config.parseRequestAndMapModel(ctx, request, body, log); err != nil {
return nil, err
Expand Down
Loading
Loading