Skip to content

Commit

Permalink
fix: set max tokens
Browse files Browse the repository at this point in the history
  • Loading branch information
missuo committed Apr 19, 2024
1 parent bf14762 commit 6a68a2c
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 5 deletions.
26 changes: 24 additions & 2 deletions main.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
/*
* @Author: Vincent Yang
* @Date: 2024-04-16 22:58:22
* @LastEditors: Vincent Yang
* @LastEditTime: 2024-04-18 20:13:27
* @LastEditors: Vincent Young
* @LastEditTime: 2024-04-19 03:45:05
* @FilePath: /cohere2openai/main.go
* @Telegram: https://t.me/missuo
* @GitHub: https://github.com/missuo
Expand Down Expand Up @@ -45,6 +45,7 @@ func cohereRequest(c *gin.Context, openAIReq OpenAIRequest) {
ChatHistory: []ChatMessage{},
Message: "",
Stream: openAIReq.Stream,
MaxTokens: openAIReq.MaxTokens,
}

for _, msg := range openAIReq.Messages {
Expand All @@ -67,6 +68,7 @@ func cohereRequest(c *gin.Context, openAIReq OpenAIRequest) {
}

reqBody, _ := json.Marshal(cohereReq)
fmt.Println(string(reqBody))
req, err := http.NewRequest("POST", "https://api.cohere.ai/v1/chat", bytes.NewBuffer(reqBody))
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
Expand Down Expand Up @@ -180,6 +182,7 @@ func cohereNonStreamRequest(c *gin.Context, openAIReq OpenAIRequest) {
ChatHistory: []ChatMessage{},
Message: "",
Stream: openAIReq.Stream,
MaxTokens: openAIReq.MaxTokens,
}

for _, msg := range openAIReq.Messages {
Expand Down Expand Up @@ -261,6 +264,25 @@ func handler(c *gin.Context) {
if !isInSlice(openAIReq.Model, allowModels) {
openAIReq.Model = "command-r-plus"
}

// Set max tokens based on model
switch openAIReq.Model {
case "command-light":
openAIReq.MaxTokens = 4000
case "command":
openAIReq.MaxTokens = 4000
case "command-light-nightly":
openAIReq.MaxTokens = 4000
case "command-nightly":
openAIReq.MaxTokens = 4000
case "command-r":
openAIReq.MaxTokens = 4000
case "command-r-plus":
openAIReq.MaxTokens = 4000
default:
openAIReq.MaxTokens = 4096
}

if openAIReq.Stream {
cohereRequest(c, openAIReq)
} else {
Expand Down
8 changes: 5 additions & 3 deletions types.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
/*
* @Author: Vincent Yang
* @Date: 2024-04-16 22:58:27
* @LastEditors: Vincent Yang
* @LastEditTime: 2024-04-18 04:34:55
* @LastEditors: Vincent Young
* @LastEditTime: 2024-04-19 03:34:12
* @FilePath: /cohere2openai/types.go
* @Telegram: https://t.me/missuo
* @GitHub: https://github.com/missuo
Expand All @@ -18,14 +18,16 @@ type OpenAIRequest struct {
Role string `json:"role"`
Content string `json:"content"`
} `json:"messages"`
Stream bool `json:"stream"`
Stream bool `json:"stream"`
MaxTokens int64 `json:"max_tokens"`
}

type CohereRequest struct {
Model string `json:"model"`
ChatHistory []ChatMessage `json:"chat_history"`
Message string `json:"message"`
Stream bool `json:"stream"`
MaxTokens int64 `json:"max_tokens"`
}

type ChatMessage struct {
Expand Down

0 comments on commit 6a68a2c

Please sign in to comment.