Skip to content

Commit

Permalink
feature: allow ai-proxy to forward standard AI capabilities that are … (
Browse files Browse the repository at this point in the history
  • Loading branch information
pepesi authored Feb 12, 2025
1 parent 477e44b commit a84a382
Show file tree
Hide file tree
Showing 32 changed files with 517 additions and 158 deletions.
3 changes: 2 additions & 1 deletion plugins/wasm-go/extensions/ai-proxy/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ description: AI 代理插件配置参考
| `customSettings` | array of customSetting | 非必填 | - | 为AI请求指定覆盖或者填充参数 |
| `failover` | object | 非必填 | - | 配置 apiToken 的 failover 策略,当 apiToken 不可用时,将其移出 apiToken 列表,待健康检测通过后重新添加回 apiToken 列表 |
| `retryOnFailure` | object | 非必填 | - | 当请求失败时立即进行重试 |

| `capabilities` | map of string | 非必填 | - | 部分provider的部分ai能力原生兼容openai/v1格式,不需要重写,可以直接转发,通过此配置项指定来开启转发, key表示的是采用的厂商协议能力,values表示的真实的厂商该能力的api path, 厂商协议能力当前支持: openai/v1/chatcompletions, openai/v1/embeddings, openai/v1/imagegeneration, openai/v1/audiospeech, cohere/v1/rerank |
| `passthrough` | bool | 非必填 | - | 只要是不支持的API能力都直接转发, 此配置是capabilities配置的放大版本,允许任意api透传,就像没有ai-proxy插件一样 |
`context`的配置字段说明如下:

| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
Expand Down
42 changes: 31 additions & 11 deletions plugins/wasm-go/extensions/ai-proxy/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConf

rawPath := ctx.Path()
path, _ := url.Parse(rawPath)
apiName := getOpenAiApiName(path.Path)
apiName := getApiName(path.Path)
providerConfig := pluginConfig.GetProviderConfig()
if providerConfig.IsOriginal() {
if handler, ok := activeProvider.(provider.ApiNameHandler); ok {
Expand All @@ -103,20 +103,25 @@ func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConf
// Set the apiToken for the current request.
providerConfig.SetApiTokenInUse(ctx, log)

hasRequestBody := wrapper.HasRequestBody()
err := handler.OnRequestHeaders(ctx, apiName, log)
if err == nil {
if hasRequestBody {
proxywasm.RemoveHttpRequestHeader("Content-Length")
ctx.SetRequestBodyBufferLimit(defaultMaxBodyBytes)
// Delay the header processing to allow changing in OnRequestBody
return types.HeaderStopIteration
if err != nil {
if providerConfig.PassthroughUnsupportedAPI() {
log.Warnf("[onHttpRequestHeader] passthrough unsupported API: %v", err)
ctx.DontReadRequestBody()
return types.ActionContinue
}
ctx.DontReadRequestBody()
util.ErrorHandler("ai-proxy.proc_req_headers_failed", fmt.Errorf("failed to process request headers: %v", err))
return types.ActionContinue
}

util.ErrorHandler("ai-proxy.proc_req_headers_failed", fmt.Errorf("failed to process request headers: %v", err))
hasRequestBody := wrapper.HasRequestBody()
if hasRequestBody {
proxywasm.RemoveHttpRequestHeader("Content-Length")
ctx.SetRequestBodyBufferLimit(defaultMaxBodyBytes)
// Delay the header processing to allow changing in OnRequestBody
return types.HeaderStopIteration
}
ctx.DontReadRequestBody()
return types.ActionContinue
}

Expand Down Expand Up @@ -151,6 +156,10 @@ func onHttpRequestBody(ctx wrapper.HttpContext, pluginConfig config.PluginConfig
if err == nil {
return action
}
if pluginConfig.GetProviderConfig().PassthroughUnsupportedAPI() {
log.Warnf("[onHttpRequestBody] passthrough unsupported API: %v", err)
return types.ActionContinue
}
util.ErrorHandler("ai-proxy.proc_req_body_failed", fmt.Errorf("failed to process request body: %v", err))
}
return types.ActionContinue
Expand Down Expand Up @@ -267,12 +276,23 @@ func checkStream(ctx wrapper.HttpContext, log wrapper.Log) {
}
}

func getOpenAiApiName(path string) provider.ApiName {
func getApiName(path string) provider.ApiName {
// openai style
if strings.HasSuffix(path, "/v1/chat/completions") {
return provider.ApiNameChatCompletion
}
if strings.HasSuffix(path, "/v1/embeddings") {
return provider.ApiNameEmbeddings
}
if strings.HasSuffix(path, "/v1/audio/speech") {
return provider.ApiNameAudioSpeech
}
if strings.HasSuffix(path, "/v1/images/generations") {
return provider.ApiNameImageGeneration
}
// cohere style
if strings.HasSuffix(path, "/v1/rerank") {
return provider.ApiNameCohereV1Rerank
}
return ""
}
13 changes: 11 additions & 2 deletions plugins/wasm-go/extensions/ai-proxy/provider/ai360.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,13 @@ type ai360Provider struct {
contextCache *contextCache
}

func (m *ai360ProviderInitializer) DefaultCapabilities() map[string]string {
return map[string]string{
string(ApiNameChatCompletion): PathOpenAIChatCompletions,
string(ApiNameEmbeddings): PathOpenAIEmbeddings,
}
}

func (m *ai360ProviderInitializer) ValidateConfig(config *ProviderConfig) error {
if config.apiTokens == nil || len(config.apiTokens) == 0 {
return errors.New("no apiToken found in provider config")
Expand All @@ -30,6 +37,7 @@ func (m *ai360ProviderInitializer) ValidateConfig(config *ProviderConfig) error
}

func (m *ai360ProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
config.setDefaultCapabilities(m.DefaultCapabilities())
return &ai360Provider{
config: config,
contextCache: createContextCache(&config),
Expand All @@ -41,7 +49,7 @@ func (m *ai360Provider) GetProviderType() string {
}

func (m *ai360Provider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if apiName != ApiNameChatCompletion && apiName != ApiNameEmbeddings {
if !m.config.isSupportedAPI(apiName) {
return errUnsupportedApiName
}
m.config.handleRequestHeaders(m, ctx, apiName, log)
Expand All @@ -50,13 +58,14 @@ func (m *ai360Provider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNam
}

func (m *ai360Provider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if apiName != ApiNameChatCompletion && apiName != ApiNameEmbeddings {
if !m.config.isSupportedAPI(apiName) {
return types.ActionContinue, errUnsupportedApiName
}
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
}

func (m *ai360Provider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
util.OverwriteRequestHostHeader(headers, ai360Domain)
util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), m.config.capabilities)
util.OverwriteRequestAuthorizationHeader(headers, m.config.GetApiTokenInUse(ctx))
}
13 changes: 11 additions & 2 deletions plugins/wasm-go/extensions/ai-proxy/provider/azure.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,14 @@ import (
type azureProviderInitializer struct {
}

func (m *azureProviderInitializer) DefaultCapabilities() map[string]string {
return map[string]string{
// TODO: azure's pattern is the same as openai, just need to handle the prefix, can be done in TransformRequestHeaders to support general capabilities
string(ApiNameChatCompletion): PathOpenAIChatCompletions,
string(ApiNameEmbeddings): PathOpenAIEmbeddings,
}
}

func (m *azureProviderInitializer) ValidateConfig(config *ProviderConfig) error {
if config.azureServiceUrl == "" {
return errors.New("missing azureServiceUrl in provider config")
Expand All @@ -35,6 +43,7 @@ func (m *azureProviderInitializer) CreateProvider(config ProviderConfig) (Provid
} else {
serviceUrl = u
}
config.setDefaultCapabilities(m.DefaultCapabilities())
return &azureProvider{
config: config,
serviceUrl: serviceUrl,
Expand All @@ -54,15 +63,15 @@ func (m *azureProvider) GetProviderType() string {
}

func (m *azureProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if apiName != ApiNameChatCompletion {
if !m.config.isSupportedAPI(apiName) {
return errUnsupportedApiName
}
m.config.handleRequestHeaders(m, ctx, apiName, log)
return nil
}

func (m *azureProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if apiName != ApiNameChatCompletion {
if !m.config.isSupportedAPI(apiName) {
return types.ActionContinue, errUnsupportedApiName
}
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
Expand Down
17 changes: 12 additions & 5 deletions plugins/wasm-go/extensions/ai-proxy/provider/baichuan.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@ import (
// baichuanProvider is the provider for baichuan Ai service.

const (
baichuanDomain = "api.baichuan-ai.com"
baichuanChatCompletionPath = "/v1/chat/completions"
baichuanDomain = "api.baichuan-ai.com"
)

type baichuanProviderInitializer struct {
Expand All @@ -26,7 +25,15 @@ func (m *baichuanProviderInitializer) ValidateConfig(config *ProviderConfig) err
return nil
}

func (m *baichuanProviderInitializer) DefaultCapabilities() map[string]string {
return map[string]string{
string(ApiNameChatCompletion): PathOpenAIChatCompletions,
string(ApiNameEmbeddings): PathOpenAIEmbeddings,
}
}

func (m *baichuanProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
config.setDefaultCapabilities(m.DefaultCapabilities())
return &baichuanProvider{
config: config,
contextCache: createContextCache(&config),
Expand All @@ -43,22 +50,22 @@ func (m *baichuanProvider) GetProviderType() string {
}

func (m *baichuanProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if apiName != ApiNameChatCompletion {
if !m.config.isSupportedAPI(apiName) {
return errUnsupportedApiName
}
m.config.handleRequestHeaders(m, ctx, apiName, log)
return nil
}

func (m *baichuanProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if apiName != ApiNameChatCompletion {
if !m.config.isSupportedAPI(apiName) {
return types.ActionContinue, errUnsupportedApiName
}
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
}

func (m *baichuanProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
util.OverwriteRequestPathHeader(headers, baichuanChatCompletionPath)
util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), m.config.capabilities)
util.OverwriteRequestHostHeader(headers, baichuanDomain)
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx))
headers.Del("Content-Length")
Expand Down
15 changes: 12 additions & 3 deletions plugins/wasm-go/extensions/ai-proxy/provider/baidu.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
const (
baiduDomain = "qianfan.baidubce.com"
baiduChatCompletionPath = "/v2/chat/completions"
baiduEmbeddings = "/v2/embeddings"
)

type baiduProviderInitializer struct{}
Expand All @@ -25,7 +26,15 @@ func (g *baiduProviderInitializer) ValidateConfig(config *ProviderConfig) error
return nil
}

func (g *baiduProviderInitializer) DefaultCapabilities() map[string]string {
return map[string]string{
string(ApiNameChatCompletion): baiduChatCompletionPath,
string(ApiNameEmbeddings): baiduEmbeddings,
}
}

func (g *baiduProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
config.setDefaultCapabilities(g.DefaultCapabilities())
return &baiduProvider{
config: config,
contextCache: createContextCache(&config),
Expand All @@ -42,22 +51,22 @@ func (g *baiduProvider) GetProviderType() string {
}

func (g *baiduProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if apiName != ApiNameChatCompletion {
if !g.config.isSupportedAPI(apiName) {
return errUnsupportedApiName
}
g.config.handleRequestHeaders(g, ctx, apiName, log)
return nil
}

func (g *baiduProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if apiName != ApiNameChatCompletion {
if !g.config.isSupportedAPI(apiName) {
return types.ActionContinue, errUnsupportedApiName
}
return g.config.handleRequestBody(g, g.contextCache, ctx, apiName, body, log)
}

func (g *baiduProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
util.OverwriteRequestPathHeader(headers, baiduChatCompletionPath)
util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), g.config.capabilities)
util.OverwriteRequestHostHeader(headers, baiduDomain)
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+g.config.GetApiTokenInUse(ctx))
headers.Del("Content-Length")
Expand Down
25 changes: 22 additions & 3 deletions plugins/wasm-go/extensions/ai-proxy/provider/claude.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,16 @@ func (c *claudeProviderInitializer) ValidateConfig(config *ProviderConfig) error
return nil
}

func (c *claudeProviderInitializer) DefaultCapabilities() map[string]string {
return map[string]string{
string(ApiNameChatCompletion): claudeChatCompletionPath,
// docs: https://docs.anthropic.com/en/docs/build-with-claude/embeddings#voyage-http-api
string(ApiNameEmbeddings): PathOpenAIEmbeddings,
}
}

func (c *claudeProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
config.setDefaultCapabilities(c.DefaultCapabilities())
return &claudeProvider{
config: config,
contextCache: createContextCache(&config),
Expand All @@ -102,15 +111,15 @@ func (c *claudeProvider) GetProviderType() string {
}

func (c *claudeProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if apiName != ApiNameChatCompletion {
if !c.config.isSupportedAPI(apiName) {
return errUnsupportedApiName
}
c.config.handleRequestHeaders(c, ctx, apiName, log)
return nil
}

func (c *claudeProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
util.OverwriteRequestPathHeader(headers, claudeChatCompletionPath)
util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), c.config.capabilities)
util.OverwriteRequestHostHeader(headers, claudeDomain)

headers.Set("x-api-key", c.config.GetApiTokenInUse(ctx))
Expand All @@ -123,13 +132,16 @@ func (c *claudeProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiNam
}

func (c *claudeProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if apiName != ApiNameChatCompletion {
if !c.config.isSupportedAPI(apiName) {
return types.ActionContinue, errUnsupportedApiName
}
return c.config.handleRequestBody(c, c.contextCache, ctx, apiName, body, log)
}

func (c *claudeProvider) TransformRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) {
if apiName != ApiNameChatCompletion {
return c.config.defaultTransformRequestBody(ctx, apiName, body, log)
}
request := &chatCompletionRequest{}
if err := c.config.parseRequestAndMapModel(ctx, request, body, log); err != nil {
return nil, err
Expand All @@ -139,6 +151,9 @@ func (c *claudeProvider) TransformRequestBody(ctx wrapper.HttpContext, apiName A
}

func (c *claudeProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) {
if apiName != ApiNameChatCompletion {
return body, nil
}
claudeResponse := &claudeTextGenResponse{}
if err := json.Unmarshal(body, claudeResponse); err != nil {
return nil, fmt.Errorf("unable to unmarshal claude response: %v", err)
Expand All @@ -154,6 +169,10 @@ func (c *claudeProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name A
if isLastChunk || len(chunk) == 0 {
return nil, nil
}
// only process the response from chat completion, skip other responses
if name != ApiNameChatCompletion {
return chunk, nil
}

responseBuilder := &strings.Builder{}
lines := strings.Split(string(chunk), "\n")
Expand Down
10 changes: 8 additions & 2 deletions plugins/wasm-go/extensions/ai-proxy/provider/cloudflare.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,14 @@ func (c *cloudflareProviderInitializer) ValidateConfig(config *ProviderConfig) e
}
return nil
}
func (c *cloudflareProviderInitializer) DefaultCapabilities() map[string]string {
return map[string]string{
string(ApiNameChatCompletion): cloudflareChatCompletionPath,
}
}

func (c *cloudflareProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
config.setDefaultCapabilities(c.DefaultCapabilities())
return &cloudflareProvider{
config: config,
contextCache: createContextCache(&config),
Expand All @@ -43,15 +49,15 @@ func (c *cloudflareProvider) GetProviderType() string {
}

func (c *cloudflareProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if apiName != ApiNameChatCompletion {
if !c.config.isSupportedAPI(apiName) {
return errUnsupportedApiName
}
c.config.handleRequestHeaders(c, ctx, apiName, log)
return nil
}

func (c *cloudflareProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if apiName != ApiNameChatCompletion {
if !c.config.isSupportedAPI(apiName) {
return types.ActionContinue, errUnsupportedApiName
}
return c.config.handleRequestBody(c, c.contextCache, ctx, apiName, body, log)
Expand Down
Loading

0 comments on commit a84a382

Please sign in to comment.