Skip to content

Commit

Permalink
feature: support hunyuan openai compatibale api
Browse files Browse the repository at this point in the history
  • Loading branch information
pepesi committed Jan 22, 2025
1 parent 3ad26e3 commit ed5b04b
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 10 deletions.
40 changes: 31 additions & 9 deletions plugins/wasm-go/extensions/ai-proxy/provider/hunyuan.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ const (

hunyuanAuthKeyLen = 32
hunyuanAuthIdLen = 36

// docs: https://cloud.tencent.com/document/product/1729/111007
hunyuanOpenAiDomain = "api.hunyuan.cloud.tencent.com"
hunyuanOpenAiRequestPath = "/v1/chat/completions"
)

type hunyuanProviderInitializer struct {
Expand Down Expand Up @@ -86,6 +90,10 @@ type hunyuanChatMessage struct {
}

func (m *hunyuanProviderInitializer) ValidateConfig(config *ProviderConfig) error {
// 允许 hunyuanauthid 和 hunyuanauthkey 为空, 当他们都为空的时候,认为是使用openai的 兼容接口
if len(config.hunyuanAuthId) == 0 && len(config.hunyuanAuthKey) == 0 {
return nil
}
// 校验hunyuan id 和 key的合法性
if len(config.hunyuanAuthId) != hunyuanAuthIdLen || len(config.hunyuanAuthKey) != hunyuanAuthKeyLen {
return errors.New("hunyuanAuthId / hunyuanAuthKey is illegal in config file")
Expand All @@ -95,7 +103,7 @@ func (m *hunyuanProviderInitializer) ValidateConfig(config *ProviderConfig) erro

func (m *hunyuanProviderInitializer) DefaultCapabilities() map[string]string {
return map[string]string{
string(ApiNameChatCompletion): hunyuanRequestPath,
string(ApiNameChatCompletion): hunyuanOpenAiRequestPath,
}
}

Expand All @@ -121,6 +129,10 @@ func (m *hunyuanProvider) GetProviderType() string {
return providerTypeHunyuan
}

func (m *hunyuanProvider) useOpenAICompatibleAPI() bool {
return len(m.config.hunyuanAuthId) == 0 && len(m.config.hunyuanAuthKey) == 0
}

func (m *hunyuanProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if !m.config.isSupportedAPI(apiName) {
return errUnsupportedApiName
Expand All @@ -131,20 +143,27 @@ func (m *hunyuanProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiN
}

func (m *hunyuanProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
util.OverwriteRequestHostHeader(headers, hunyuanDomain)
util.OverwriteRequestPathHeader(headers, hunyuanRequestPath)
util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), m.config.capabilities)

// 添加 hunyuan 需要的自定义字段
headers.Set(actionKey, hunyuanChatCompletionTCAction)
headers.Set(versionKey, versionValue)
if m.useOpenAICompatibleAPI() {
util.OverwriteRequestHostHeader(headers, hunyuanOpenAiDomain)
util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), m.config.capabilities)
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx))
} else {
util.OverwriteRequestHostHeader(headers, hunyuanDomain)
util.OverwriteRequestPathHeader(headers, hunyuanRequestPath)
// 添加 hunyuan 需要的自定义字段
headers.Set(actionKey, hunyuanChatCompletionTCAction)
headers.Set(versionKey, versionValue)
}
}

// hunyuan 的 OnRequestBody 逻辑中包含了对 headers 签名的逻辑,并且插入 context 以后还要重新计算签名,因此无法复用 handleRequestBody 方法
func (m *hunyuanProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if !m.config.isSupportedAPI(apiName) {
return types.ActionContinue, errUnsupportedApiName
}
if m.useOpenAICompatibleAPI() {
return types.ActionContinue, nil
}

// 为header添加时间戳字段 (因为需要根据body进行签名时依赖时间戳,故于body处理部分创建时间戳)
var timestamp int64 = time.Now().Unix()
Expand Down Expand Up @@ -272,6 +291,9 @@ func (m *hunyuanProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName

// hunyuan 的 TransformRequestBodyHeaders 方法只在 failover 健康检查的时候会调用
func (m *hunyuanProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header, log wrapper.Log) ([]byte, error) {
if m.useOpenAICompatibleAPI() {
return body, nil
}
request := &chatCompletionRequest{}
err := m.config.parseRequestAndMapModel(ctx, request, body, log)
if err != nil {
Expand All @@ -297,7 +319,7 @@ func (m *hunyuanProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, a
}

func (m *hunyuanProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error) {
if m.config.protocol == protocolOriginal {
if m.config.IsOriginal() || m.useOpenAICompatibleAPI() {
return chunk, nil
}

Expand Down
2 changes: 1 addition & 1 deletion plugins/wasm-go/extensions/ai-proxy/provider/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -571,7 +571,7 @@ func (c *ProviderConfig) handleRequestBody(
provider Provider, contextCache *contextCache, ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log,
) (types.Action, error) {
// use original protocol
if c.protocol == protocolOriginal {
if c.IsOriginal() {
return types.ActionContinue, nil
}

Expand Down

0 comments on commit ed5b04b

Please sign in to comment.