From ad59d1c3e003ec2a892a3612b7255aedd81196f4 Mon Sep 17 00:00:00 2001 From: Antonio Date: Fri, 2 Jun 2023 01:37:40 +0800 Subject: [PATCH] Async support (default) --- setup.py | 4 +-- src/Bard.py | 72 +++++++++++++++++++++++++++++++++++++++++------------ 2 files changed, 58 insertions(+), 18 deletions(-) diff --git a/setup.py b/setup.py index 0779995..de735b3 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name="GoogleBard", - version="1.2.2", + version="1.3.0", license="MIT License", author="Antonio Cheong", author_email="acheong@student.dalat.org", @@ -12,7 +12,7 @@ package_dir={"": "src"}, url="https://github.com/acheong08/Bard", project_urls={"Bug Report": "https://github.com/acheong08/Bard/issues/new"}, - install_requires=["requests", "prompt_toolkit", "rich"], + install_requires=["prompt_toolkit", "rich", "httpx[socks]"], long_description=open("README.md", encoding="utf-8").read(), long_description_content_type="text/markdown", py_modules=["Bard"], diff --git a/src/Bard.py b/src/Bard.py index b388998..5d130fa 100644 --- a/src/Bard.py +++ b/src/Bard.py @@ -5,9 +5,9 @@ import re import string import sys -import time -import requests +import httpx +import asyncio from prompt_toolkit import prompt from prompt_toolkit import PromptSession from prompt_toolkit.auto_suggest import AutoSuggestFromHistory @@ -49,6 +49,41 @@ def __get_input( class Chatbot: + """ + Synchronous wrapper for the AsyncChatbot class. + """ + + def __init__( + self, + session_id: str, + proxy: dict = None, + timeout: int = 20, + ): + self.loop = asyncio.get_event_loop() + self.async_chatbot = self.loop.run_until_complete( + AsyncChatbot.create(session_id, proxy, timeout) + ) + + def save_conversation(self, file_path: str, conversation_name: str): + return self.loop.run_until_complete( + self.async_chatbot.save_conversation(file_path, conversation_name) + ) + + def load_conversations(self, file_path: str) -> List[Dict]: + return self.loop.run_until_complete( + self.async_chatbot.load_conversations(file_path) + ) + + def load_conversation(self, file_path: str, conversation_name: str) -> bool: + return self.loop.run_until_complete( + self.async_chatbot.load_conversation(file_path, conversation_name) + ) + + def ask(self, message: str) -> dict: + return self.loop.run_until_complete(self.async_chatbot.ask(message)) + + +class AsyncChatbot: """ A class to interact with Google Bard. Parameters @@ -57,8 +92,6 @@ class Chatbot: proxy: str timeout: int Request timeout in seconds. - session: requests.Session - Requests session object. """ __slots__ = [ @@ -79,7 +112,6 @@ def __init__( session_id: str, proxy: dict = None, timeout: int = 20, - session: requests.Session = None, ): headers = { "Host": "bard.google.com", @@ -95,13 +127,23 @@ def __init__( self.response_id = "" self.choice_id = "" self.session_id = session_id - self.session = session or requests.Session() + self.session = httpx.AsyncClient(proxies=self.proxy) self.session.headers = headers self.session.cookies.set("__Secure-1PSID", session_id) - self.SNlM0e = self.__get_snlm0e() self.timeout = timeout - def save_conversation(self, file_path: str, conversation_name: str): + @classmethod + async def create( + cls, + session_id: str, + proxy: dict = None, + timeout: int = 20, + ) -> "AsyncChatbot": + instance = cls(session_id, proxy, timeout) + instance.SNlM0e = await instance.__get_snlm0e() + return instance + + async def save_conversation(self, file_path: str, conversation_name: str): conversations = self.load_conversations(file_path) conversation_details = { { @@ -118,14 +160,14 @@ def save_conversation(self, file_path: str, conversation_name: str): with open(file_path, "w", encoding="utf-8") as f: json.dump(conversations, f, indent=4) - def load_conversations(self, file_path: str) -> List[Dict]: + async def load_conversations(self, file_path: str) -> List[Dict]: # Check if file exists if not os.path.isfile(file_path): return [] with open(file_path, encoding="utf-8") as f: return json.load(f) - def load_conversation(self, file_path: str, conversation_name: str) -> bool: + async def load_conversation(self, file_path: str, conversation_name: str) -> bool: """ Loads a conversation from history file. Returns whether the conversation was found. """ @@ -140,16 +182,15 @@ def load_conversation(self, file_path: str, conversation_name: str) -> bool: return True return False - def __get_snlm0e(self): + async def __get_snlm0e(self): # Find "SNlM0e":"" if not self.session_id or self.session_id[-1] != ".": raise Exception( "__Secure-1PSID value must end with a single dot. Enter correct __Secure-1PSID value.", ) - resp = self.session.get( + resp = await self.session.get( "https://bard.google.com/", timeout=10, - proxies=self.proxy, ) if resp.status_code != 200: raise Exception( @@ -162,7 +203,7 @@ def __get_snlm0e(self): ) return SNlM0e.group(1) - def ask(self, message: str) -> dict: + async def ask(self, message: str) -> dict: """ Send a message to Google Bard and return the response. :param message: The message to send to Google Bard. @@ -185,12 +226,11 @@ def ask(self, message: str) -> dict: "f.req": json.dumps([None, json.dumps(message_struct)]), "at": self.SNlM0e, } - resp = self.session.post( + resp = await self.session.post( "https://bard.google.com/_/BardChatUi/data/assistant.lamda.BardFrontendService/StreamGenerate", params=params, data=data, timeout=self.timeout, - proxies=self.proxy, ) chat_data = json.loads(resp.content.splitlines()[3])[0][2] if not chat_data: