From b3073d226bdcf122eed6538a362e179bd6f141dd Mon Sep 17 00:00:00 2001 From: maleicacid Date: Sun, 23 Feb 2025 21:48:20 +0800 Subject: [PATCH] classify errors in chatbot --- openlrc/chatbot.py | 71 +++++++++++++++++++++++++---------------- openlrc/translate.py | 4 +-- pyproject.toml | 6 +++- tests/test_translate.py | 19 ++++------- 4 files changed, 57 insertions(+), 43 deletions(-) diff --git a/openlrc/chatbot.py b/openlrc/chatbot.py index 34686c0..26804c3 100644 --- a/openlrc/chatbot.py +++ b/openlrc/chatbot.py @@ -7,7 +7,8 @@ import re import time from copy import deepcopy -from typing import List, Union, Dict, Callable, Optional +from dataclasses import dataclass, field +from typing import List, Set, Union, Dict, Callable, Optional import anthropic import google.generativeai as genai @@ -75,8 +76,38 @@ def route_chatbot(model: str) -> (type, str): return model2chatbot[model], model + +@dataclass +class ErrorGroup: + errs: Set[Exception] + sleep_time: int | Callable[[], int] # TODO: define the whole retry model rather than only sleep time + + def get_sleep_time(self): + if callable(self.sleep_time): + return self.sleep_time() + else: + return self.sleep_time + +@dataclass +class ClassifiedErrors: + long_wait: ErrorGroup = field(default_factory=lambda: ErrorGroup(set(), lambda: random.randint(30, 60))) + short_wait: ErrorGroup = field(default_factory=lambda: ErrorGroup(set(), 3)) + default: ErrorGroup = field(default_factory=lambda: ErrorGroup(set(), 1)) + + def get_sleep_time(self, err): + if err in self.long_wait.errs: + return self.long_wait.get_sleep_time() + elif err in self.short_wait.errs: + return self.short_wait.get_sleep_time() + else: + return self.default.get_sleep_time() + +default_chatbot_classified_errors = ClassifiedErrors() +default_chatbot_classified_errors.long_wait.errs.add(openai.RateLimitError) +default_chatbot_classified_errors.short_wait.errs.add(openai.APITimeoutError) + class ChatBot: - def __init__(self, model_name, temperature=1, top_p=1, retry=8, max_async=16, fee_limit=0.8, beta=False): + def __init__(self, model_name, temperature=1, top_p=1, retry=8, max_async=16, fee_limit=0.8, beta=False, classified_errors=default_chatbot_classified_errors): try: self.model_info = Models.get_model(model_name, beta) self.model_name = model_name @@ -88,9 +119,9 @@ def __init__(self, model_name, temperature=1, top_p=1, retry=8, max_async=16, fe self.retry = retry self.max_async = max_async self.fee_limit = fee_limit + self.classified_errors = classified_errors self.api_fees = [] - def estimate_fee(self, messages: List[Dict]): """ Estimate the total fee for the given messages. @@ -171,11 +202,10 @@ def message(self, messages_list: Union[List[Dict], List[List[Dict]]], stop_seque def __str__(self): return f'ChatBot ({self.model_name})' - @_register_chatbot class GPTBot(ChatBot): def __init__(self, model_name='gpt-4o-mini', temperature=1, top_p=1, retry=8, max_async=16, json_mode=False, - fee_limit=0.05, proxy=None, base_url_config=None, api_key=None): + fee_limit=0.05, proxy=None, base_url_config=None, api_key=None, classified_errors=default_chatbot_classified_errors): # clamp temperature to 0-2 temperature = max(0, min(2, temperature)) @@ -184,7 +214,7 @@ def __init__(self, model_name='gpt-4o-mini', temperature=1, top_p=1, retry=8, ma if base_url_config and base_url_config['openai'] == 'https://api.deepseek.com/beta': is_beta = True - super().__init__(model_name, temperature, top_p, retry, max_async, fee_limit, is_beta) + super().__init__(model_name, temperature, top_p, retry, max_async, fee_limit, is_beta, classified_errors) self.async_client = AsyncGPTClient( api_key=api_key or os.environ['OPENAI_API_KEY'], @@ -236,7 +266,7 @@ async def _create_achat(self, messages: List[Dict], stop_sequences: Optional[Lis break except (openai.RateLimitError, openai.APITimeoutError, openai.APIConnectionError, openai.APIError) as e: - sleep_time = self._get_sleep_time(e) + sleep_time = self.classified_errors.get_sleep_time(e) logger.warning(f'{type(e).__name__}: {e}. Wait {sleep_time}s before retry. Retry num: {i + 1}.') time.sleep(sleep_time) @@ -245,25 +275,19 @@ async def _create_achat(self, messages: List[Dict], stop_sequences: Optional[Lis return response - @staticmethod - def _get_sleep_time(error): - if isinstance(error, openai.RateLimitError): - return random.randint(30, 60) - elif isinstance(error, openai.APITimeoutError): - return 3 - else: - return 15 - +default_claudebot_classified_errors = ClassifiedErrors() +default_claudebot_classified_errors.long_wait.errs.add(anthropic.RateLimitError) +default_claudebot_classified_errors.short_wait.errs.add(anthropic.APITimeoutError) @_register_chatbot class ClaudeBot(ChatBot): def __init__(self, model_name='claude-3-5-sonnet-20241022', temperature=1, top_p=1, retry=8, max_async=16, - fee_limit=0.8, proxy=None, base_url_config=None, api_key=None): + fee_limit=0.8, proxy=None, base_url_config=None, api_key=None, classified_errors=default_claudebot_classified_errors): # clamp temperature to 0-1 temperature = max(0, min(1, temperature)) - super().__init__(model_name, temperature, top_p, retry, max_async, fee_limit) + super().__init__(model_name, temperature, top_p, retry, max_async, fee_limit, classified_errors=classified_errors) self.async_client = AsyncAnthropic( api_key=api_key or os.environ['ANTHROPIC_API_KEY'], @@ -321,7 +345,7 @@ async def _create_achat(self, messages: List[Dict], stop_sequences: Optional[Lis except ( anthropic.RateLimitError, anthropic.APITimeoutError, anthropic.APIConnectionError, anthropic.APIError) as e: - sleep_time = self._get_sleep_time(e) + sleep_time = self.classified_errors.get_sleep_time(e) logger.warning(f'{type(e).__name__}: {e}. Wait {sleep_time}s before retry. Retry num: {i + 1}.') time.sleep(sleep_time) @@ -330,15 +354,6 @@ async def _create_achat(self, messages: List[Dict], stop_sequences: Optional[Lis return response - def _get_sleep_time(self, error): - if isinstance(error, anthropic.RateLimitError): - return random.randint(30, 60) - elif isinstance(error, anthropic.APITimeoutError): - return 3 - else: - return 15 - - @_register_chatbot class GeminiBot(ChatBot): def __init__(self, model_name='gemini-2.0-flash-exp', temperature=1, top_p=1, retry=8, max_async=16, fee_limit=0.8, diff --git a/openlrc/translate.py b/openlrc/translate.py index 3d0c659..1a0374b 100644 --- a/openlrc/translate.py +++ b/openlrc/translate.py @@ -299,7 +299,7 @@ def _save_intermediate_results(self, compare_path: Path, compare_list: List[dict with open(compare_path, 'w', encoding='utf-8') as f: json.dump(compare_results, f, indent=4, ensure_ascii=False) - def atomic_translate(self, chatbot_model: Union[str, ModelConfig], texts: List[str], src_lang: str, + def atomic_translate(self, texts: List[str], src_lang: str, target_lang: str) -> List[str]: """ Perform atomic translation for each text individually. @@ -319,7 +319,7 @@ def atomic_translate(self, chatbot_model: Union[str, ModelConfig], texts: List[s Raises: AssertionError: If the number of translated texts doesn't match the input. """ - chatbot = ChunkedTranslatorAgent(src_lang, target_lang, TranslateInfo(), chatbot_model, self.fee_limit, + chatbot = ChunkedTranslatorAgent(src_lang, target_lang, TranslateInfo(), self.chatbot_model, self.fee_limit, self.proxy, self.base_url_config).chatbot diff --git a/pyproject.toml b/pyproject.toml index b72cb1a..78155d5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -68,6 +68,10 @@ matplotlib = "^3.9.2" typing-extensions = "^4.12.2" onnxruntime = "^1.20.0" +[tool.poetry.group.dev.dependencies] +torch = "^2.6.0" +torchvision = "^0.21.0" +torchaudio = "^2.6.0" #torch = ">=2.0.0, !=2.0.1" #torchaudio = "^2.0.0" #torchvision = "^0.17.1" @@ -89,4 +93,4 @@ priority = "primary" "Bug Tracker" = "https://github.com/zh-plus/Open-Lyrics/issues" [tool.poetry.scripts] -openlrc = "openlrc.cli:main" +openlrc = "openlrc.cli:main" \ No newline at end of file diff --git a/tests/test_translate.py b/tests/test_translate.py index 6c00b98..15b3fc7 100644 --- a/tests/test_translate.py +++ b/tests/test_translate.py @@ -11,7 +11,7 @@ from openlrc.utils import get_similarity test_models = ['gpt-4o-mini', 'claude-3-5-haiku-20241022'] - +test_translators = [LLMTranslator(m) for m in test_models] class TestLLMTranslator(unittest.TestCase): @@ -20,25 +20,22 @@ def tearDown(self) -> None: compare_path.unlink(missing_ok=True) def test_single_chunk_translation(self): - for chatbot_model in test_models: + for translator in test_translators: text = 'Hello, how are you?' - translator = LLMTranslator(chatbot_model) translation = translator.translate(text, 'en', 'es')[0] self.assertGreater(get_similarity(translation, 'Hola, ¿cómo estás?'), 0.5) def test_multiple_chunk_translation(self): - for chatbot_model in test_models: + for translator in test_translators: texts = ['Hello, how are you?', 'I am fine, thank you.'] - translator = LLMTranslator(chatbot_model) translations = translator.translate(texts, 'en', 'es') self.assertGreater(get_similarity(translations[0], 'Hola, ¿cómo estás?'), 0.5) self.assertGreater(get_similarity(translations[1], 'Estoy bien, gracias.'), 0.5) def test_different_language_translation(self): - for chatbot_model in test_models: + for translator in test_translators: text = 'Hello, how are you?' - translator = LLMTranslator(chatbot_model) try: translation = translator.translate(text, 'en', 'ja')[0] self.assertTrue( @@ -49,17 +46,15 @@ def test_different_language_translation(self): pass def test_empty_text_list_translation(self): - for chatbot_model in test_models: + for translator in test_translators: texts = [] - translator = LLMTranslator(chatbot_model) translations = translator.translate(texts, 'en', 'es') self.assertEqual(translations, []) def test_atomic_translate(self): - for chatbot_model in test_models: + for translator in test_translators: texts = ['Hello, how are you?', 'I am fine, thank you.'] - translator = LLMTranslator(chatbot_model) - translations = translator.atomic_translate(chatbot_model, texts, 'en', 'zh') + translations = translator.atomic_translate( texts, 'en', 'zh') self.assertGreater(get_similarity(translations[0], '你好,你好吗?'), 0.5) self.assertGreater(get_similarity(translations[1], '我很好,谢谢。'), 0.5)