Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feature: allow ai-proxy to forward standard AI capabilities that are natively supported #1704

Merged
merged 13 commits into from
Feb 12, 2025
Merged
3 changes: 2 additions & 1 deletion plugins/wasm-go/extensions/ai-proxy/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ description: AI 代理插件配置参考
| `customSettings` | array of customSetting | 非必填 | - | 为AI请求指定覆盖或者填充参数 |
| `failover` | object | 非必填 | - | 配置 apiToken 的 failover 策略,当 apiToken 不可用时,将其移出 apiToken 列表,待健康检测通过后重新添加回 apiToken 列表 |
| `retryOnFailure` | object | 非必填 | - | 当请求失败时立即进行重试 |

| `capabilities` | map of string | 非必填 | - | 部分provider的部分ai能力原生兼容openai/v1格式,不需要重写,可以直接转发,通过此配置项指定来开启转发, key表示的是采用的厂商协议能力,values表示的真实的厂商该能力的api path, 厂商协议能力当前支持: openai/v1/chatcompletions, openai/v1/embeddings, openai/v1/imagegeneration, openai/v1/audiospeech, cohere/v1/rerank |
| `passthrough` | bool | 非必填 | - | 只要是不支持的API能力都直接转发, 此配置是capabilities配置的放大版本,允许任意api透传,就像没有ai-proxy插件一样 |
`context`的配置字段说明如下:

| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
Expand Down
42 changes: 31 additions & 11 deletions plugins/wasm-go/extensions/ai-proxy/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConf

rawPath := ctx.Path()
path, _ := url.Parse(rawPath)
apiName := getOpenAiApiName(path.Path)
apiName := getApiName(path.Path)
providerConfig := pluginConfig.GetProviderConfig()
if providerConfig.IsOriginal() {
if handler, ok := activeProvider.(provider.ApiNameHandler); ok {
Expand All @@ -103,20 +103,25 @@ func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConf
// Set the apiToken for the current request.
providerConfig.SetApiTokenInUse(ctx, log)

hasRequestBody := wrapper.HasRequestBody()
err := handler.OnRequestHeaders(ctx, apiName, log)
if err == nil {
if hasRequestBody {
proxywasm.RemoveHttpRequestHeader("Content-Length")
ctx.SetRequestBodyBufferLimit(defaultMaxBodyBytes)
// Delay the header processing to allow changing in OnRequestBody
return types.HeaderStopIteration
if err != nil {
johnlanni marked this conversation as resolved.
Show resolved Hide resolved
if providerConfig.PassthroughUnsupportedAPI() {
log.Warnf("[onHttpRequestHeader] passthrough unsupported API: %v", err)
ctx.DontReadRequestBody()
return types.ActionContinue
}
ctx.DontReadRequestBody()
util.ErrorHandler("ai-proxy.proc_req_headers_failed", fmt.Errorf("failed to process request headers: %v", err))
return types.ActionContinue
}

util.ErrorHandler("ai-proxy.proc_req_headers_failed", fmt.Errorf("failed to process request headers: %v", err))
hasRequestBody := wrapper.HasRequestBody()
if hasRequestBody {
proxywasm.RemoveHttpRequestHeader("Content-Length")
ctx.SetRequestBodyBufferLimit(defaultMaxBodyBytes)
// Delay the header processing to allow changing in OnRequestBody
return types.HeaderStopIteration
}
ctx.DontReadRequestBody()
return types.ActionContinue
}

Expand Down Expand Up @@ -151,6 +156,10 @@ func onHttpRequestBody(ctx wrapper.HttpContext, pluginConfig config.PluginConfig
if err == nil {
return action
}
if pluginConfig.GetProviderConfig().PassthroughUnsupportedAPI() {
log.Warnf("[onHttpRequestBody] passthrough unsupported API: %v", err)
return types.ActionContinue
}
util.ErrorHandler("ai-proxy.proc_req_body_failed", fmt.Errorf("failed to process request body: %v", err))
}
return types.ActionContinue
Expand Down Expand Up @@ -267,12 +276,23 @@ func checkStream(ctx wrapper.HttpContext, log wrapper.Log) {
}
}

func getOpenAiApiName(path string) provider.ApiName {
func getApiName(path string) provider.ApiName {
// openai style
if strings.HasSuffix(path, "/v1/chat/completions") {
return provider.ApiNameChatCompletion
}
if strings.HasSuffix(path, "/v1/embeddings") {
return provider.ApiNameEmbeddings
}
if strings.HasSuffix(path, "/v1/audio/speech") {
return provider.ApiNameAudioSpeech
}
if strings.HasSuffix(path, "/v1/images/generations") {
return provider.ApiNameImageGeneration
}
// cohere style
if strings.HasSuffix(path, "/v1/rerank") {
return provider.ApiNameCohereV1Rerank
johnlanni marked this conversation as resolved.
Show resolved Hide resolved
}
return ""
}
13 changes: 11 additions & 2 deletions plugins/wasm-go/extensions/ai-proxy/provider/ai360.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,13 @@ type ai360Provider struct {
contextCache *contextCache
}

func (m *ai360ProviderInitializer) DefaultCapabilities() map[string]string {
return map[string]string{
string(ApiNameChatCompletion): PathOpenAIChatCompletions,
string(ApiNameEmbeddings): PathOpenAIEmbeddings,
}
}

func (m *ai360ProviderInitializer) ValidateConfig(config *ProviderConfig) error {
if config.apiTokens == nil || len(config.apiTokens) == 0 {
return errors.New("no apiToken found in provider config")
Expand All @@ -30,6 +37,7 @@ func (m *ai360ProviderInitializer) ValidateConfig(config *ProviderConfig) error
}

func (m *ai360ProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
config.setDefaultCapabilities(m.DefaultCapabilities())
return &ai360Provider{
config: config,
contextCache: createContextCache(&config),
Expand All @@ -41,7 +49,7 @@ func (m *ai360Provider) GetProviderType() string {
}

func (m *ai360Provider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if apiName != ApiNameChatCompletion && apiName != ApiNameEmbeddings {
if !m.config.isSupportedAPI(apiName) {
return errUnsupportedApiName
}
m.config.handleRequestHeaders(m, ctx, apiName, log)
Expand All @@ -50,13 +58,14 @@ func (m *ai360Provider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNam
}

func (m *ai360Provider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if apiName != ApiNameChatCompletion && apiName != ApiNameEmbeddings {
if !m.config.isSupportedAPI(apiName) {
return types.ActionContinue, errUnsupportedApiName
}
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
}

func (m *ai360Provider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
util.OverwriteRequestHostHeader(headers, ai360Domain)
util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), m.config.capabilities)
util.OverwriteRequestAuthorizationHeader(headers, m.config.GetApiTokenInUse(ctx))
}
13 changes: 11 additions & 2 deletions plugins/wasm-go/extensions/ai-proxy/provider/azure.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,14 @@ import (
type azureProviderInitializer struct {
}

func (m *azureProviderInitializer) DefaultCapabilities() map[string]string {
return map[string]string{
// TODO: azure's pattern is the same as openai, just need to handle the prefix, can be done in TransformRequestHeaders to support general capabilities
string(ApiNameChatCompletion): PathOpenAIChatCompletions,
string(ApiNameEmbeddings): PathOpenAIEmbeddings,
}
}

func (m *azureProviderInitializer) ValidateConfig(config *ProviderConfig) error {
if config.azureServiceUrl == "" {
return errors.New("missing azureServiceUrl in provider config")
Expand All @@ -35,6 +43,7 @@ func (m *azureProviderInitializer) CreateProvider(config ProviderConfig) (Provid
} else {
serviceUrl = u
}
config.setDefaultCapabilities(m.DefaultCapabilities())
return &azureProvider{
config: config,
serviceUrl: serviceUrl,
Expand All @@ -54,15 +63,15 @@ func (m *azureProvider) GetProviderType() string {
}

func (m *azureProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if apiName != ApiNameChatCompletion {
if !m.config.isSupportedAPI(apiName) {
return errUnsupportedApiName
}
m.config.handleRequestHeaders(m, ctx, apiName, log)
return nil
}

func (m *azureProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if apiName != ApiNameChatCompletion {
if !m.config.isSupportedAPI(apiName) {
return types.ActionContinue, errUnsupportedApiName
}
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
Expand Down
17 changes: 12 additions & 5 deletions plugins/wasm-go/extensions/ai-proxy/provider/baichuan.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@ import (
// baichuanProvider is the provider for baichuan Ai service.

const (
baichuanDomain = "api.baichuan-ai.com"
baichuanChatCompletionPath = "/v1/chat/completions"
baichuanDomain = "api.baichuan-ai.com"
)

type baichuanProviderInitializer struct {
Expand All @@ -26,7 +25,15 @@ func (m *baichuanProviderInitializer) ValidateConfig(config *ProviderConfig) err
return nil
}

func (m *baichuanProviderInitializer) DefaultCapabilities() map[string]string {
return map[string]string{
string(ApiNameChatCompletion): PathOpenAIChatCompletions,
string(ApiNameEmbeddings): PathOpenAIEmbeddings,
}
}

func (m *baichuanProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
config.setDefaultCapabilities(m.DefaultCapabilities())
return &baichuanProvider{
config: config,
contextCache: createContextCache(&config),
Expand All @@ -43,22 +50,22 @@ func (m *baichuanProvider) GetProviderType() string {
}

func (m *baichuanProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if apiName != ApiNameChatCompletion {
if !m.config.isSupportedAPI(apiName) {
return errUnsupportedApiName
}
m.config.handleRequestHeaders(m, ctx, apiName, log)
return nil
}

func (m *baichuanProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if apiName != ApiNameChatCompletion {
if !m.config.isSupportedAPI(apiName) {
return types.ActionContinue, errUnsupportedApiName
}
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
}

func (m *baichuanProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
util.OverwriteRequestPathHeader(headers, baichuanChatCompletionPath)
util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), m.config.capabilities)
util.OverwriteRequestHostHeader(headers, baichuanDomain)
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx))
headers.Del("Content-Length")
Expand Down
15 changes: 12 additions & 3 deletions plugins/wasm-go/extensions/ai-proxy/provider/baidu.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
const (
baiduDomain = "qianfan.baidubce.com"
baiduChatCompletionPath = "/v2/chat/completions"
baiduEmbeddings = "/v2/embeddings"
)

type baiduProviderInitializer struct{}
Expand All @@ -25,7 +26,15 @@ func (g *baiduProviderInitializer) ValidateConfig(config *ProviderConfig) error
return nil
}

func (g *baiduProviderInitializer) DefaultCapabilities() map[string]string {
return map[string]string{
string(ApiNameChatCompletion): baiduChatCompletionPath,
string(ApiNameEmbeddings): baiduEmbeddings,
}
}

func (g *baiduProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
config.setDefaultCapabilities(g.DefaultCapabilities())
return &baiduProvider{
config: config,
contextCache: createContextCache(&config),
Expand All @@ -42,22 +51,22 @@ func (g *baiduProvider) GetProviderType() string {
}

func (g *baiduProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if apiName != ApiNameChatCompletion {
if !g.config.isSupportedAPI(apiName) {
return errUnsupportedApiName
}
g.config.handleRequestHeaders(g, ctx, apiName, log)
return nil
}

func (g *baiduProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if apiName != ApiNameChatCompletion {
if !g.config.isSupportedAPI(apiName) {
return types.ActionContinue, errUnsupportedApiName
}
return g.config.handleRequestBody(g, g.contextCache, ctx, apiName, body, log)
}

func (g *baiduProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
util.OverwriteRequestPathHeader(headers, baiduChatCompletionPath)
util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), g.config.capabilities)
util.OverwriteRequestHostHeader(headers, baiduDomain)
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+g.config.GetApiTokenInUse(ctx))
headers.Del("Content-Length")
Expand Down
25 changes: 22 additions & 3 deletions plugins/wasm-go/extensions/ai-proxy/provider/claude.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,16 @@ func (c *claudeProviderInitializer) ValidateConfig(config *ProviderConfig) error
return nil
}

func (c *claudeProviderInitializer) DefaultCapabilities() map[string]string {
return map[string]string{
string(ApiNameChatCompletion): claudeChatCompletionPath,
// docs: https://docs.anthropic.com/en/docs/build-with-claude/embeddings#voyage-http-api
string(ApiNameEmbeddings): PathOpenAIEmbeddings,
}
}

func (c *claudeProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
config.setDefaultCapabilities(c.DefaultCapabilities())
return &claudeProvider{
config: config,
contextCache: createContextCache(&config),
Expand All @@ -102,15 +111,15 @@ func (c *claudeProvider) GetProviderType() string {
}

func (c *claudeProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if apiName != ApiNameChatCompletion {
if !c.config.isSupportedAPI(apiName) {
return errUnsupportedApiName
}
c.config.handleRequestHeaders(c, ctx, apiName, log)
return nil
}

func (c *claudeProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
util.OverwriteRequestPathHeader(headers, claudeChatCompletionPath)
util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), c.config.capabilities)
util.OverwriteRequestHostHeader(headers, claudeDomain)

headers.Set("x-api-key", c.config.GetApiTokenInUse(ctx))
Expand All @@ -123,13 +132,16 @@ func (c *claudeProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiNam
}

func (c *claudeProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if apiName != ApiNameChatCompletion {
if !c.config.isSupportedAPI(apiName) {
return types.ActionContinue, errUnsupportedApiName
}
return c.config.handleRequestBody(c, c.contextCache, ctx, apiName, body, log)
}

func (c *claudeProvider) TransformRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) {
if apiName != ApiNameChatCompletion {
return c.config.defaultTransformRequestBody(ctx, apiName, body, log)
}
request := &chatCompletionRequest{}
if err := c.config.parseRequestAndMapModel(ctx, request, body, log); err != nil {
return nil, err
Expand All @@ -139,6 +151,9 @@ func (c *claudeProvider) TransformRequestBody(ctx wrapper.HttpContext, apiName A
}

func (c *claudeProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) {
if apiName != ApiNameChatCompletion {
return body, nil
}
claudeResponse := &claudeTextGenResponse{}
if err := json.Unmarshal(body, claudeResponse); err != nil {
return nil, fmt.Errorf("unable to unmarshal claude response: %v", err)
Expand All @@ -154,6 +169,10 @@ func (c *claudeProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name A
if isLastChunk || len(chunk) == 0 {
return nil, nil
}
// only process the response from chat completion, skip other responses
if name != ApiNameChatCompletion {
return chunk, nil
}

responseBuilder := &strings.Builder{}
lines := strings.Split(string(chunk), "\n")
Expand Down
10 changes: 8 additions & 2 deletions plugins/wasm-go/extensions/ai-proxy/provider/cloudflare.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,14 @@ func (c *cloudflareProviderInitializer) ValidateConfig(config *ProviderConfig) e
}
return nil
}
func (c *cloudflareProviderInitializer) DefaultCapabilities() map[string]string {
return map[string]string{
string(ApiNameChatCompletion): cloudflareChatCompletionPath,
}
}

func (c *cloudflareProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
config.setDefaultCapabilities(c.DefaultCapabilities())
return &cloudflareProvider{
config: config,
contextCache: createContextCache(&config),
Expand All @@ -43,15 +49,15 @@ func (c *cloudflareProvider) GetProviderType() string {
}

func (c *cloudflareProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if apiName != ApiNameChatCompletion {
if !c.config.isSupportedAPI(apiName) {
return errUnsupportedApiName
}
c.config.handleRequestHeaders(c, ctx, apiName, log)
return nil
}

func (c *cloudflareProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if apiName != ApiNameChatCompletion {
if !c.config.isSupportedAPI(apiName) {
return types.ActionContinue, errUnsupportedApiName
}
return c.config.handleRequestBody(c, c.contextCache, ctx, apiName, body, log)
Expand Down
Loading
Loading