Skip to content

Commit 1f6defb

Browse files
author
Val Kharitonov
committed
add mistralai
1 parent 9ce23b0 commit 1f6defb

File tree

5 files changed

+57
-1
lines changed

5 files changed

+57
-1
lines changed

gptcli/assistant.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from gptcli.completion import CompletionProvider, ModelOverrides, Message
88
from gptcli.google import GoogleCompletionProvider
99
from gptcli.llama import LLaMACompletionProvider
10+
from gptcli.mistral import MistralCompletionProvider
1011
from gptcli.openai import OpenAICompletionProvider
1112
from gptcli.anthropic import AnthropicCompletionProvider
1213

@@ -64,6 +65,8 @@ def get_completion_provider(model: str) -> CompletionProvider:
6465
return LLaMACompletionProvider()
6566
elif model.startswith("chat-bison"):
6667
return GoogleCompletionProvider()
68+
elif model.startswith("mistral"):
69+
return MistralCompletionProvider()
6770
else:
6871
raise ValueError(f"Unknown model: {model}")
6972

gptcli/config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ class GptCliConfig:
2020
show_price: bool = True
2121
api_key: Optional[str] = os.environ.get("OPENAI_API_KEY")
2222
openai_api_key: Optional[str] = os.environ.get("OPENAI_API_KEY")
23+
mistral_api_key: Optional[str] = os.environ.get("MISTRAL_API_KEY")
2324
anthropic_api_key: Optional[str] = os.environ.get("ANTHROPIC_API_KEY")
2425
google_api_key: Optional[str] = os.environ.get("GOOGLE_API_KEY")
2526
log_file: Optional[str] = None

gptcli/gpt.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import datetime
1616
import google.generativeai as genai
1717
import gptcli.anthropic
18+
import gptcli.mistral
1819
from gptcli.assistant import (
1920
Assistant,
2021
DEFAULT_ASSISTANTS,
@@ -178,6 +179,9 @@ def main():
178179
)
179180
sys.exit(1)
180181

182+
if config.mistral_api_key:
183+
gptcli.mistral.api_key = config.mistral_api_key
184+
181185
if config.anthropic_api_key:
182186
gptcli.anthropic.api_key = config.anthropic_api_key
183187

gptcli/mistral.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
from typing import Iterator, List
2+
import os
3+
from gptcli.completion import CompletionProvider, Message
4+
from mistralai.client import MistralClient, ChatMessage
5+
6+
api_key = os.environ.get("MISTRAL_API_KEY")
7+
8+
9+
class MistralCompletionProvider(CompletionProvider):
10+
def __init__(self):
11+
self.client = MistralClient(api_key=api_key)
12+
13+
def complete(
14+
self, messages: List[Message], args: dict, stream: bool = False
15+
) -> Iterator[str]:
16+
kwargs = {}
17+
if "temperature" in args:
18+
kwargs["temperature"] = args["temperature"]
19+
if "top_p" in args:
20+
kwargs["top_p"] = args["top_p"]
21+
22+
if stream:
23+
response_iter = self.client.chat_stream(
24+
model=args["model"],
25+
messages=[
26+
ChatMessage(role=msg["role"], content=msg["content"])
27+
for msg in messages
28+
],
29+
**kwargs,
30+
)
31+
32+
for response in response_iter:
33+
next_choice = response.choices[0]
34+
if next_choice.finish_reason is None and next_choice.delta.content:
35+
yield next_choice.delta.content
36+
else:
37+
response = self.client.chat(
38+
model=args["model"],
39+
messages=[
40+
ChatMessage(role=msg["role"], content=msg["content"])
41+
for msg in messages
42+
],
43+
**kwargs,
44+
)
45+
next_choice = response.choices[0]
46+
if next_choice.message.content:
47+
yield next_choice.message.content

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ dependencies = [
2020
"anthropic==0.7.7",
2121
"attrs==23.1.0",
2222
"black==23.1.0",
23+
"mistralai==0.0.8",
2324
"google-generativeai==0.1.0",
2425
"openai==1.3.8",
2526
"prompt-toolkit==3.0.41",
@@ -28,7 +29,7 @@ dependencies = [
2829
"rich==13.7.0",
2930
"tiktoken==0.5.2",
3031
"tokenizers==0.15.0",
31-
"typing_extensions==4.5.0",
32+
"typing_extensions==4.9.0",
3233
]
3334

3435
[project.optional-dependencies]

0 commit comments

Comments
 (0)