Skip to content

Commit 80655f1

Browse files
author
Val Kharitonov
committed
add mistralai
1 parent a957b4b commit 80655f1

File tree

4 files changed

+55
-0
lines changed

4 files changed

+55
-0
lines changed

gptcli/assistant.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from gptcli.completion import CompletionProvider, ModelOverrides, Message
88
from gptcli.google import GoogleCompletionProvider
99
from gptcli.llama import LLaMACompletionProvider
10+
from gptcli.mistral import MistralCompletionProvider
1011
from gptcli.openai import OpenAICompletionProvider
1112
from gptcli.anthropic import AnthropicCompletionProvider
1213

@@ -64,6 +65,8 @@ def get_completion_provider(model: str) -> CompletionProvider:
6465
return LLaMACompletionProvider()
6566
elif model.startswith("chat-bison"):
6667
return GoogleCompletionProvider()
68+
elif model.startswith("mistral"):
69+
return MistralCompletionProvider()
6770
else:
6871
raise ValueError(f"Unknown model: {model}")
6972

gptcli/config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ class GptCliConfig:
2020
show_price: bool = True
2121
api_key: Optional[str] = os.environ.get("OPENAI_API_KEY")
2222
openai_api_key: Optional[str] = os.environ.get("OPENAI_API_KEY")
23+
mistral_api_key: Optional[str] = os.environ.get("MISTRAL_API_KEY")
2324
anthropic_api_key: Optional[str] = os.environ.get("ANTHROPIC_API_KEY")
2425
google_api_key: Optional[str] = os.environ.get("GOOGLE_API_KEY")
2526
log_file: Optional[str] = None

gptcli/gpt.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import datetime
1616
import google.generativeai as genai
1717
import gptcli.anthropic
18+
import gptcli.mistral
1819
from gptcli.assistant import (
1920
Assistant,
2021
DEFAULT_ASSISTANTS,
@@ -178,6 +179,9 @@ def main():
178179
)
179180
sys.exit(1)
180181

182+
if config.mistral_api_key:
183+
gptcli.mistral.api_key = config.mistral_api_key
184+
181185
if config.anthropic_api_key:
182186
gptcli.anthropic.api_key = config.anthropic_api_key
183187

gptcli/mistral.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
from typing import Iterator, List
2+
import os
3+
from gptcli.completion import CompletionProvider, Message
4+
from mistralai.client import MistralClient
5+
from mistralai.models.chat_completion import ChatMessage
6+
7+
api_key = os.environ.get("MISTRAL_API_KEY")
8+
9+
10+
class MistralCompletionProvider(CompletionProvider):
11+
def __init__(self):
12+
self.client = MistralClient(api_key=api_key)
13+
14+
def complete(
15+
self, messages: List[Message], args: dict, stream: bool = False
16+
) -> Iterator[str]:
17+
kwargs = {}
18+
if "temperature" in args:
19+
kwargs["temperature"] = args["temperature"]
20+
if "top_p" in args:
21+
kwargs["top_p"] = args["top_p"]
22+
23+
messages = [
24+
ChatMessage(role=msg["role"], content=msg["content"])
25+
for msg in messages
26+
]
27+
28+
if stream:
29+
response_iter = self.client.chat_stream(
30+
model=args["model"],
31+
messages=messages,
32+
**kwargs,
33+
)
34+
35+
for response in response_iter:
36+
next_choice = response.choices[0]
37+
if next_choice.finish_reason is None and next_choice.delta.content:
38+
yield next_choice.delta.content
39+
else:
40+
response = self.client.chat(
41+
model=args["model"],
42+
messages=messages,
43+
**kwargs,
44+
)
45+
next_choice = response.choices[0]
46+
if next_choice.message.content:
47+
yield next_choice.message.content

0 commit comments

Comments
 (0)