diff --git a/plugins/wasm-go/extensions/ai-proxy/main.go b/plugins/wasm-go/extensions/ai-proxy/main.go index 4f08ab82c0..6244a0e60e 100644 --- a/plugins/wasm-go/extensions/ai-proxy/main.go +++ b/plugins/wasm-go/extensions/ai-proxy/main.go @@ -280,5 +280,9 @@ func getOpenAiApiName(path string) provider.ApiName { if strings.HasSuffix(path, "/v1/images/generations") { return provider.ApiNameImageGeneration } + // rerank + if strings.HasSuffix(path, "/v1/rerank") { + return provider.ApiNameCohereV1Rerank + } return "" } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/azure.go b/plugins/wasm-go/extensions/ai-proxy/provider/azure.go index bddfccdf09..5fcc378d47 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/azure.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/azure.go @@ -17,8 +17,7 @@ type azureProviderInitializer struct { func (m *azureProviderInitializer) DefaultCapabilities() map[string]string { return map[string]string{ - // azure 此配置无实质作用,只是为了保持和其他provider的一致性 - // TODO: azure的模式和openai是一致的,只是需要处理前缀,可以在TransformRequestHeaders中处理,以支持通用能力 + // TODO: azure's pattern is the same as openai, just need to handle the prefix, can be done in TransformRequestHeaders to support general capabilities string(ApiNameChatCompletion): PathOpenAIChatCompletions, string(ApiNameEmbeddings): PathOpenAIEmbeddings, } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/baichuan.go b/plugins/wasm-go/extensions/ai-proxy/provider/baichuan.go index 7e0c362148..d04c5c7d85 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/baichuan.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/baichuan.go @@ -27,7 +27,6 @@ func (m *baichuanProviderInitializer) ValidateConfig(config *ProviderConfig) err func (m *baichuanProviderInitializer) DefaultCapabilities() map[string]string { return map[string]string{ - // 百川AI的chat和embeddings接口和OpenAI的chat和embeddings接口一样 string(ApiNameChatCompletion): PathOpenAIChatCompletions, string(ApiNameEmbeddings): PathOpenAIEmbeddings, } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/cohere.go b/plugins/wasm-go/extensions/ai-proxy/provider/cohere.go index 931d1b677c..a21964e497 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/cohere.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/cohere.go @@ -13,7 +13,7 @@ import ( const ( cohereDomain = "api.cohere.com" - // TODO: 现在cohere有v2, 也有embeddings, 考虑更多支持: https://docs.cohere.com/v2/reference/rerank + // TODO: support more capabilities, upgrade to v2, docs: https://docs.cohere.com/v2/reference/chat cohereChatCompletionPath = "/v1/chat" cohereRerankPath = "/v1/rerank" ) @@ -100,7 +100,7 @@ func (m *cohereProvider) buildCohereRequest(origin *chatCompletionRequest) *cohe } func (m *cohereProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { - util.OverwriteRequestPathHeader(headers, cohereChatCompletionPath) + util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), m.config.capabilities) util.OverwriteRequestHostHeader(headers, cohereDomain) util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx)) headers.Del("Content-Length") diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/coze.go b/plugins/wasm-go/extensions/ai-proxy/provider/coze.go index 874f7a48ae..8626b226af 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/coze.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/coze.go @@ -22,9 +22,7 @@ func (m *cozeProviderInitializer) ValidateConfig(config *ProviderConfig) error { } func (m *cozeProviderInitializer) DefaultCapabilities() map[string]string { - return map[string]string{ - // 此配置暂时无实质作用,只是为了保持和其他provider的一致性 - } + return map[string]string{} } func (m *cozeProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/deepseek.go b/plugins/wasm-go/extensions/ai-proxy/provider/deepseek.go index 81ba9e8562..c8eca82a5c 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/deepseek.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/deepseek.go @@ -13,7 +13,8 @@ import ( const ( deepseekDomain = "api.deepseek.com" - // TODO: 根据文档 docs: https://api-docs.deepseek.com/api/create-chat-completion, path应该是 /chat/completions, 待验证 + // TODO: docs: https://api-docs.deepseek.com/api/create-chat-completion + // accourding to the docs, the path should be /chat/completions, need to be verified deepseekChatCompletionPath = "/v1/chat/completions" ) diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/gemini.go b/plugins/wasm-go/extensions/ai-proxy/provider/gemini.go index 7e456115c0..757f3937d7 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/gemini.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/gemini.go @@ -36,11 +36,7 @@ func (g *geminiProviderInitializer) ValidateConfig(config *ProviderConfig) error } func (g *geminiProviderInitializer) DefaultCapabilities() map[string]string { - return map[string]string{ - // path在gemini中没有实际意义,只是为了保持和其他provider的一致性 - string(ApiNameChatCompletion): "_", - string(ApiNameEmbeddings): "_", - } + return map[string]string{} } func (g *geminiProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/mistral.go b/plugins/wasm-go/extensions/ai-proxy/provider/mistral.go index 9f7a71f06f..3f361a27ac 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/mistral.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/mistral.go @@ -24,8 +24,7 @@ func (m *mistralProviderInitializer) ValidateConfig(config *ProviderConfig) erro func (m *mistralProviderInitializer) DefaultCapabilities() map[string]string { return map[string]string{ - // mistral的chat接口和OpenAI的chat接口一样 - // docs: https://docs.mistral.ai/api/ + // The chat interface of mistral is the same as that of OpenAI. docs: https://docs.mistral.ai/api/ string(ApiNameChatCompletion): PathOpenAIChatCompletions, string(ApiNameEmbeddings): PathOpenAIEmbeddings, } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go index 6c810f6dbe..d2d9efbbc1 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go @@ -384,7 +384,8 @@ func (c *ProviderConfig) FromJson(json gjson.Result) { case string(ApiNameChatCompletion), string(ApiNameEmbeddings), string(ApiNameImageGeneration), - string(ApiNameAudioSpeech): + string(ApiNameAudioSpeech), + string(ApiNameCohereV1Rerank): c.capabilities[capability] = pathJson.String() } }