diff --git a/bot/bot_factory.py b/bot/bot_factory.py index 2046da71b..d3a2fa1ae 100644 --- a/bot/bot_factory.py +++ b/bot/bot_factory.py @@ -56,5 +56,8 @@ def create_bot(bot_type): from bot.zhipuai.zhipuai_bot import ZHIPUAIBot return ZHIPUAIBot() + elif bot_type == const.MISTRAL: + from bot.mistral.mistralai_bot import MistralAIBot + return MistralAIBot() raise RuntimeError diff --git a/bot/mistral/mistralai_bot.py b/bot/mistral/mistralai_bot.py new file mode 100644 index 000000000..69b1bd602 --- /dev/null +++ b/bot/mistral/mistralai_bot.py @@ -0,0 +1,112 @@ +# encoding:utf-8 + +import time + +from mistralai.client import MistralClient +from mistralai.models.chat_completion import ChatMessage + +from bot.bot import Bot +from bot.mistral.mistralai_session import MistralAISession +from bot.session_manager import SessionManager +from bridge.context import ContextType +from bridge.reply import Reply, ReplyType +from common.log import logger +from config import conf + +user_session = dict() + + +# OpenAI对话模型API (可用) +class MistralAIBot(Bot): + def __init__(self): + super().__init__() + api_key = conf().get("mistralai_api_key") + self.client = MistralClient(api_key=api_key) + self.system_prompt = conf().get("character_desc", "") + self.sessions = SessionManager(MistralAISession, model=conf().get("model") or "mistral-large-latest") + self.model = conf().get("model") or "mistral-large-latest" # 对话模型的名称 + self.temperature = conf().get("temperature", 0.7) # 值在[0,1]之间,越大表示回复越具有不确定性 + self.top_p = conf().get("top_p", 1) + self.safe_prompt = True + logger.info("[MISTRAL_AI] Create finish.") + + def reply(self, query, context=None): + # acquire reply content + if context and context.type: + if context.type == ContextType.TEXT: + logger.info("[MISTRAL_AI] query={}".format(query)) + session_id = context["session_id"] + reply = None + if query == "#清除记忆": + self.sessions.clear_session(session_id) + reply = Reply(ReplyType.INFO, "记忆已清除") + elif query == "#清除所有": + 
self.sessions.clear_all_session() + reply = Reply(ReplyType.INFO, "所有人记忆已清除") + else: + session = self.sessions.session_query(query, session_id) + result = self.reply_text(session) + total_tokens, completion_tokens, reply_content = ( + result["total_tokens"], + result["completion_tokens"], + result["content"], + ) + logger.debug( + "[MISTRAL_AI] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(str(session), session_id, reply_content, completion_tokens) + ) + + if total_tokens == 0: + reply = Reply(ReplyType.ERROR, reply_content) + else: + self.sessions.session_reply(reply_content, session_id, total_tokens) + reply = Reply(ReplyType.TEXT, reply_content) + return reply + else: + logger.info("[MISTRAL_AI] context={}".format(context)) + + def reply_text(self, session: MistralAISession): + try: + messages = self._convert_to_mistral_messages(self._filter_messages(session.messages)) + response = self.client.chat(messages, temperature=self.temperature, model=self.model, + top_p=self.top_p, safe_prompt=self.safe_prompt) + res_content = response.choices[0].message.content + total_tokens = response.usage.total_tokens + completion_tokens = response.usage.completion_tokens + logger.info("[MISTRAL_AI] reply={}".format(res_content)) + return { + "total_tokens": total_tokens, + "completion_tokens": completion_tokens, + "content": res_content, + } + except Exception as e: + result = {"total_tokens": 0, "completion_tokens": 0, "content": "我刚刚开小差了,请稍后再试一下"} + logger.warn("[MISTRAL_AI] Exception: {}".format(e)) + return result + + def _convert_to_mistral_messages(self, messages: list): + res = [] + res.append(ChatMessage(role="system", content=self.system_prompt)) + for msg in messages: + if msg.get("role") == "user": + role = "user" + elif msg.get("role") == "assistant": + role = "assistant" + else: + continue + res.append( + ChatMessage(role=role, content=msg.get("content"))) + return res + + def _filter_messages(self, messages: list): + res = [] + turn = "user" 
+ for i in range(len(messages) - 1, -1, -1): + message = messages[i] + if message.get("role") != turn: + continue + res.insert(0, message) + if turn == "user": + turn = "assistant" + elif turn == "assistant": + turn = "user" + return res \ No newline at end of file diff --git a/bot/mistral/mistralai_session.py b/bot/mistral/mistralai_session.py new file mode 100644 index 000000000..6dcb809e3 --- /dev/null +++ b/bot/mistral/mistralai_session.py @@ -0,0 +1,76 @@ +from bot.session_manager import Session +from common.log import logger + + +class MistralAISession(Session): + def __init__(self, session_id, system_prompt=None, model="text-davinci-003"): + super().__init__(session_id, system_prompt) + self.model = model + self.reset() + + def __str__(self): + # 构造对话模型的输入 + """ + e.g. Q: xxx + A: xxx + Q: xxx + """ + prompt = "" + for item in self.messages: + if item["role"] == "system": + prompt += item["content"] + "<|endoftext|>\n\n\n" + elif item["role"] == "user": + prompt += "Q: " + item["content"] + "\n" + elif item["role"] == "assistant": + prompt += "\n\nA: " + item["content"] + "<|endoftext|>\n" + + if len(self.messages) > 0 and self.messages[-1]["role"] == "user": + prompt += "A: " + return prompt + + def discard_exceeding(self, max_tokens, cur_tokens=None): + precise = True + try: + cur_tokens = self.calc_tokens() + except Exception as e: + precise = False + if cur_tokens is None: + raise e + logger.debug("Exception when counting tokens precisely for query: {}".format(e)) + while cur_tokens > max_tokens: + if len(self.messages) > 1: + self.messages.pop(0) + elif len(self.messages) == 1 and self.messages[0]["role"] == "assistant": + self.messages.pop(0) + if precise: + cur_tokens = self.calc_tokens() + else: + cur_tokens = len(str(self)) + break + elif len(self.messages) == 1 and self.messages[0]["role"] == "user": + logger.warn("user question exceed max_tokens. 
total_tokens={}".format(cur_tokens)) + break + else: + logger.debug("max_tokens={}, total_tokens={}, len(conversation)={}".format(max_tokens, cur_tokens, len(self.messages))) + break + if precise: + cur_tokens = self.calc_tokens() + else: + cur_tokens = len(str(self)) + return cur_tokens + + def calc_tokens(self): + return num_tokens_from_string(str(self), self.model) + + +# refer to https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb +def num_tokens_from_string(string: str, model: str) -> int: + """Returns the number of tokens in a text string.""" + import tiktoken + try: + encoding = tiktoken.encoding_for_model(model) + except KeyError: + logger.warn("Warning: model not found. Using cl100k_base encoding.") + encoding = tiktoken.get_encoding("cl100k_base") + num_tokens = len(encoding.encode(string, disallowed_special=())) + return num_tokens diff --git a/bridge/bridge.py b/bridge/bridge.py index 88e6b18c7..30b0e6466 100644 --- a/bridge/bridge.py +++ b/bridge/bridge.py @@ -33,6 +33,9 @@ def __init__(self): self.btype["chat"] = const.GEMINI if model_type in [const.ZHIPU_AI]: self.btype["chat"] = const.ZHIPU_AI + if model_type in [const.MODEL_MISTRAL_LARGE, const.MODEL_MISTRAL_MEDIUM, const.MODEL_MISTRAL_SMALL, + const.MODEL_MISTRAL_OPEN_7B, const.MODEL_MISTRAL_OPEN_8X7B]: + self.btype["chat"] = const.MISTRAL if conf().get("use_linkai") and conf().get("linkai_api_key"): self.btype["chat"] = const.LINKAI diff --git a/common/const.py b/common/const.py index aeb9dcc4e..183c6d956 100644 --- a/common/const.py +++ b/common/const.py @@ -9,7 +9,7 @@ QWEN = "qwen" GEMINI = "gemini" ZHIPU_AI = "glm-4" - +MISTRAL = "mistral" # model GPT35 = "gpt-3.5-turbo" @@ -19,10 +19,15 @@ WHISPER_1 = "whisper-1" TTS_1 = "tts-1" TTS_1_HD = "tts-1-hd" +MODEL_MISTRAL_OPEN_7B = "open-mistral-7b" +MODEL_MISTRAL_OPEN_8X7B = "open-mixtral-8x7b" +MODEL_MISTRAL_SMALL = "mistral-small-latest" +MODEL_MISTRAL_MEDIUM = "mistral-medium-latest" 
+MODEL_MISTRAL_LARGE = "mistral-large-latest" MODEL_LIST = ["gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4", "wenxin", "wenxin-4", "xunfei", "claude", "gpt-4-turbo", - "gpt-4-turbo-preview", "gpt-4-1106-preview", GPT4_TURBO_PREVIEW, QWEN, GEMINI, ZHIPU_AI] + "gpt-4-turbo-preview", "gpt-4-1106-preview", GPT4_TURBO_PREVIEW, QWEN, GEMINI, ZHIPU_AI, MISTRAL] # channel FEISHU = "feishu" -DINGTALK = "dingtalk" +DINGTALK = "dingtalk" diff --git a/config.py b/config.py index 154c633fb..451353b66 100644 --- a/config.py +++ b/config.py @@ -75,6 +75,8 @@ "qwen_node_id": "", # 流程编排模型用到的id,如果没有用到qwen_node_id,请务必保持为空字符串 # Google Gemini Api Key "gemini_api_key": "", + # Mistral AI API Key + "mistralai_api_key": "", # wework的通用配置 "wework_smart": True, # 配置wework是否使用已登录的企业微信,False为多开 # 语音设置 diff --git a/requirements-optional.txt b/requirements-optional.txt index 74f1780e0..ab220bf0d 100644 --- a/requirements-optional.txt +++ b/requirements-optional.txt @@ -40,3 +40,6 @@ dingtalk_stream # zhipuai zhipuai>=2.0.1 + +#mistralai +mistralai \ No newline at end of file