feat: AI Proxy Wasm Plugin Integration with GitHub Models #1304 #1362

Merged (12 commits) on Oct 6, 2024
plugins/wasm-go/extensions/ai-proxy/README.md (107 changes: 106 additions & 1 deletion)
@@ -143,6 +143,10 @@ The `type` corresponding to Groq is `groq`. It has no provider-specific configuration fields.

The `type` corresponding to 360 Zhinao is `ai360`. It has no provider-specific configuration fields.

#### GitHub Models

The `type` corresponding to GitHub Models is `github`. It has no provider-specific configuration fields.

#### Mistral

The `type` corresponding to Mistral is `mistral`. It has no provider-specific configuration fields.
@@ -1018,6 +1022,107 @@ provider:
}
```

### Using the OpenAI Protocol to Proxy the GitHub Models Service

**Configuration**

```yaml
provider:
  type: github
  apiTokens:
    - "YOUR_GITHUB_ACCESS_TOKEN"
  modelMapping:
    "gpt-4o": "gpt-4o"
    "gpt-4": "Phi-3.5-MoE-instruct"
    "gpt-3.5": "cohere-command-r-08-2024"
    "text-embedding-3-large": "text-embedding-3-large"
```

**Request Example**

```json
{
  "messages": [
    {
      "role": "system",
      "content": "You are a helpful assistant."
    },
    {
      "role": "user",
      "content": "What is the capital of France?"
    }
  ],
  "stream": true,
  "temperature": 1.0,
  "top_p": 1.0,
  "max_tokens": 1000,
  "model": "gpt-4o"
}
```

**Response Example**
```json
{
  "choices": [
    {
      "finish_reason": "stop",
      "index": 0,
      "logprobs": null,
      "message": {
        "content": "The capital of France is Paris.",
        "role": "assistant"
      }
    }
  ],
  "created": 1728131051,
  "id": "chatcmpl-AEy7PU2JImdsD1W6Jw8GigZSEnM2u",
  "model": "gpt-4o-2024-08-06",
  "object": "chat.completion",
  "system_fingerprint": "fp_67802d9a6d",
  "usage": {
    "completion_tokens": 7,
    "prompt_tokens": 24,
    "total_tokens": 31
  }
}
```
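
For reference, below is a minimal Go client sketch that sends the chat completion request above through the gateway. The gateway address (`http://localhost:8080`) and the OpenAI-compatible `/v1/chat/completions` route are assumptions and depend on your Higress deployment and route configuration:

```go
package main

import (
	"bytes"
	"fmt"
	"io"
	"net/http"
)

func main() {
	// The body mirrors the request example above; "gpt-4o" is the
	// gateway-facing model name, which the plugin rewrites via modelMapping.
	// Streaming is disabled here so the whole response can be read at once.
	body := []byte(`{
  "messages": [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What is the capital of France?"}
  ],
  "stream": false,
  "model": "gpt-4o"
}`)

	// Assumed gateway address and route; adjust to your environment.
	req, err := http.NewRequest("POST", "http://localhost:8080/v1/chat/completions", bytes.NewReader(body))
	if err != nil {
		panic(err)
	}
	req.Header.Set("Content-Type", "application/json")

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	out, _ := io.ReadAll(resp.Body)
	fmt.Println(string(out))
}
```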

**Text Embedding Request Example**

```json
{
  "input": ["first phrase", "second phrase", "third phrase"],
  "model": "text-embedding-3-large"
}
```

Response example:

```json
{
  "object": "list",
  "data": [
    {
      "object": "embedding",
      "index": 0,
      "embedding": [
        -0.0012583479,
        0.0020349282,
        ...
        0.012051377,
        -0.0053306012,
        0.0060688322
      ]
    }
  ],
  "model": "text-embedding-3-large",
  "usage": {
    "prompt_tokens": 6,
    "total_tokens": 6
  }
}
```

### Using the OpenAI Protocol to Proxy the 360 Zhinao Service

**Configuration**
@@ -1026,7 +1131,7 @@
provider:
  type: ai360
  apiTokens:
-   - "YOUR_MINIMAX_API_TOKEN"
+   - "YOUR_360_API_TOKEN"
  modelMapping:
    "gpt-4o": "360gpt-turbo-responsibility-8k"
    "gpt-4": "360gpt2-pro"
plugins/wasm-go/extensions/ai-proxy/provider/github.go (112 changes: 112 additions & 0 deletions)
@@ -0,0 +1,112 @@
package provider

import (
	"encoding/json"
	"errors"
	"fmt"

	"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
	"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"

	"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
	"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
)

// githubProvider is the provider for the GitHub Models (OpenAI-compatible) service.
const (
	githubDomain         = "models.inference.ai.azure.com"
	githubCompletionPath = "/chat/completions"
	githubEmbeddingPath  = "/embeddings"
)

type githubProviderInitializer struct {
}

type githubProvider struct {
	config       ProviderConfig
	contextCache *contextCache
}

func (m *githubProviderInitializer) ValidateConfig(config ProviderConfig) error {
	if config.apiTokens == nil || len(config.apiTokens) == 0 {
		return errors.New("no apiToken found in provider config")
	}
	return nil
}

func (m *githubProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
	return &githubProvider{
		config:       config,
		contextCache: createContextCache(&config),
	}, nil
}

func (m *githubProvider) GetProviderType() string {
	return providerTypeGithub
}

func (m *githubProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
	if apiName != ApiNameChatCompletion && apiName != ApiNameEmbeddings {
		return types.ActionContinue, errUnsupportedApiName
	}
	_ = util.OverwriteRequestHost(githubDomain)
	if apiName == ApiNameChatCompletion {
		_ = util.OverwriteRequestPath(githubCompletionPath)
	}
	if apiName == ApiNameEmbeddings {
		_ = util.OverwriteRequestPath(githubEmbeddingPath)
	}
	_ = proxywasm.RemoveHttpRequestHeader("Accept-Encoding")
	_ = proxywasm.RemoveHttpRequestHeader("Content-Length")
	_ = proxywasm.ReplaceHttpRequestHeader("Authorization", m.config.GetRandomToken())
	// Delay the header processing to allow changing streaming mode in OnRequestBody
	return types.HeaderStopIteration, nil
}

func (m *githubProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
	if apiName != ApiNameChatCompletion && apiName != ApiNameEmbeddings {
		return types.ActionContinue, errUnsupportedApiName
	}
	if apiName == ApiNameChatCompletion {
		return m.onChatCompletionRequestBody(ctx, body, log)
	}
	if apiName == ApiNameEmbeddings {
		return m.onEmbeddingsRequestBody(ctx, body, log)
	}
	return types.ActionContinue, errUnsupportedApiName
}

func (m *githubProvider) onChatCompletionRequestBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) (types.Action, error) {
	request := &chatCompletionRequest{}
	if err := decodeChatCompletionRequest(body, request); err != nil {
		return types.ActionContinue, err
	}
	if request.Model == "" {
		return types.ActionContinue, errors.New("missing model in chat completion request")
	}
	// Map the requested model name to the provider-side model name.
	mappedModel := getMappedModel(request.Model, m.config.modelMapping, log)
	if mappedModel == "" {
		return types.ActionContinue, errors.New("model becomes empty after applying the configured mapping")
	}
	ctx.SetContext(ctxKeyFinalRequestModel, mappedModel)
	request.Model = mappedModel
	return types.ActionContinue, replaceJsonRequestBody(request, log)
}

func (m *githubProvider) onEmbeddingsRequestBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) (types.Action, error) {
	request := &embeddingsRequest{}
	if err := json.Unmarshal(body, request); err != nil {
		return types.ActionContinue, fmt.Errorf("unable to unmarshal request: %v", err)
	}
	if request.Model == "" {
		return types.ActionContinue, errors.New("missing model in embeddings request")
	}
	// Map the requested model name to the provider-side model name.
	mappedModel := getMappedModel(request.Model, m.config.modelMapping, log)
	if mappedModel == "" {
		return types.ActionContinue, errors.New("model becomes empty after applying the configured mapping")
	}
	ctx.SetContext(ctxKeyFinalRequestModel, mappedModel)
	request.Model = mappedModel
	return types.ActionContinue, replaceJsonRequestBody(request, log)
}
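
As a quick sanity check on the token validation above, here is a minimal test sketch. The file name (`provider/github_test.go`), the assumption that `apiTokens` is a `[]string` (suggested by the nil/len checks in `ValidateConfig`), and constructing `ProviderConfig` as a bare struct literal are all assumptions; the test must live in the same `provider` package to set the unexported field:

```go
package provider

import "testing"

// Checks that githubProviderInitializer rejects a config without API tokens
// and accepts one that has at least one token configured.
func TestGithubProviderValidateConfig(t *testing.T) {
	initializer := &githubProviderInitializer{}

	// No tokens configured: ValidateConfig should return an error.
	if err := initializer.ValidateConfig(ProviderConfig{}); err == nil {
		t.Fatal("expected an error when no apiTokens are configured")
	}

	// apiTokens is assumed to be []string here.
	if err := initializer.ValidateConfig(ProviderConfig{apiTokens: []string{"ghp_example_token"}}); err != nil {
		t.Fatalf("unexpected error with a token configured: %v", err)
	}
}
```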
plugins/wasm-go/extensions/ai-proxy/provider/provider.go (5 changes: 4 additions & 1 deletion)
@@ -5,9 +5,10 @@ import (
	"math/rand"
	"strings"

	"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
	"github.com/tidwall/gjson"

	"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
)

type ApiName string
@@ -20,6 +21,7 @@ const (
	providerTypeMoonshot = "moonshot"
	providerTypeAzure    = "azure"
	providerTypeAi360    = "ai360"
	providerTypeGithub   = "github"
	providerTypeQwen     = "qwen"
	providerTypeOpenAI   = "openai"
	providerTypeGroq     = "groq"
@@ -78,6 +80,7 @@
		providerTypeMoonshot: &moonshotProviderInitializer{},
		providerTypeAzure:    &azureProviderInitializer{},
		providerTypeAi360:    &ai360ProviderInitializer{},
		providerTypeGithub:   &githubProviderInitializer{},
		providerTypeQwen:     &qwenProviderInitializer{},
		providerTypeOpenAI:   &openaiProviderInitializer{},
		providerTypeGroq:     &groqProviderInitializer{},