classify errors in chatbot #65

Open · wants to merge 1 commit into master
openlrc/chatbot.py (71 changes: 43 additions & 28 deletions)
@@ -7,7 +7,8 @@
 import re
 import time
 from copy import deepcopy
-from typing import List, Union, Dict, Callable, Optional
+from dataclasses import dataclass, field
+from typing import List, Set, Union, Dict, Callable, Optional
 
 import anthropic
 import google.generativeai as genai
@@ -75,8 +76,38 @@
         return model2chatbot[model], model
 
 
+
+@dataclass
+class ErrorGroup:
+    errs: Set[Exception]  # holds exception *classes*, e.g. openai.RateLimitError
+    sleep_time: int | Callable[[], int]  # TODO: define the whole retry model rather than only sleep time
+
+    def get_sleep_time(self):
+        if callable(self.sleep_time):
+            return self.sleep_time()
+        else:
+            return self.sleep_time
+
+@dataclass
+class ClassifiedErrors:
+    long_wait: ErrorGroup = field(default_factory=lambda: ErrorGroup(set(), lambda: random.randint(30, 60)))
+    short_wait: ErrorGroup = field(default_factory=lambda: ErrorGroup(set(), 3))
+    default: ErrorGroup = field(default_factory=lambda: ErrorGroup(set(), 1))
+
+    def get_sleep_time(self, err):
+        # err is a caught exception *instance* while errs holds exception
+        # classes, so match by isinstance rather than set membership.
+        if isinstance(err, tuple(self.long_wait.errs)):
+            return self.long_wait.get_sleep_time()
+        elif isinstance(err, tuple(self.short_wait.errs)):
+            return self.short_wait.get_sleep_time()
+        else:
+            return self.default.get_sleep_time()
+
+default_chatbot_classified_errors = ClassifiedErrors()
+default_chatbot_classified_errors.long_wait.errs.add(openai.RateLimitError)
+default_chatbot_classified_errors.short_wait.errs.add(openai.APITimeoutError)
+
 class ChatBot:
-    def __init__(self, model_name, temperature=1, top_p=1, retry=8, max_async=16, fee_limit=0.8, beta=False):
+    def __init__(self, model_name, temperature=1, top_p=1, retry=8, max_async=16, fee_limit=0.8, beta=False, classified_errors=default_chatbot_classified_errors):
         try:
             self.model_info = Models.get_model(model_name, beta)
             self.model_name = model_name
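
For context, a minimal usage sketch of the new extension point — hypothetical caller code, not part of this diff; it assumes `OPENAI_API_KEY` is set and that `ErrorGroup`/`ClassifiedErrors` are importable from `openlrc.chatbot` as added above:

```python
# Hypothetical caller code: build a custom retry classification and hand it
# to a bot instead of the module-level default.
import random

import openai

from openlrc.chatbot import ClassifiedErrors, ErrorGroup, GPTBot

custom_errors = ClassifiedErrors(
    # Rate limits back off 30-60 s; a callable sleep_time is re-evaluated on each retry.
    long_wait=ErrorGroup({openai.RateLimitError}, lambda: random.randint(30, 60)),
    # Timeouts retry after a fixed 3 s.
    short_wait=ErrorGroup({openai.APITimeoutError}, 3),
    # Anything else waits 1 s.
    default=ErrorGroup(set(), 1),
)

bot = GPTBot(model_name='gpt-4o-mini', classified_errors=custom_errors)
```

Because `sleep_time` may be either an `int` or a zero-argument callable, jittered backoff and fixed waits share one code path through `ErrorGroup.get_sleep_time`.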
@@ -88,9 +119,9 @@
         self.retry = retry
         self.max_async = max_async
         self.fee_limit = fee_limit
+        self.classified_errors = classified_errors
 
         self.api_fees = []
-
     def estimate_fee(self, messages: List[Dict]):
         """
         Estimate the total fee for the given messages.
@@ -171,11 +202,10 @@
     def __str__(self):
         return f'ChatBot ({self.model_name})'
 
-
 @_register_chatbot
 class GPTBot(ChatBot):
     def __init__(self, model_name='gpt-4o-mini', temperature=1, top_p=1, retry=8, max_async=16, json_mode=False,
-                 fee_limit=0.05, proxy=None, base_url_config=None, api_key=None):
+                 fee_limit=0.05, proxy=None, base_url_config=None, api_key=None, classified_errors=default_chatbot_classified_errors):
 
         # clamp temperature to 0-2
         temperature = max(0, min(2, temperature))
@@ -184,7 +214,7 @@
         if base_url_config and base_url_config['openai'] == 'https://api.deepseek.com/beta':
             is_beta = True
 
-        super().__init__(model_name, temperature, top_p, retry, max_async, fee_limit, is_beta)
+        super().__init__(model_name, temperature, top_p, retry, max_async, fee_limit, is_beta, classified_errors)
 
         self.async_client = AsyncGPTClient(
             api_key=api_key or os.environ['OPENAI_API_KEY'],
@@ -236,7 +266,7 @@
 
                 break
             except (openai.RateLimitError, openai.APITimeoutError, openai.APIConnectionError, openai.APIError) as e:
-                sleep_time = self._get_sleep_time(e)
+                sleep_time = self.classified_errors.get_sleep_time(e)
                 logger.warning(f'{type(e).__name__}: {e}. Wait {sleep_time}s before retry. Retry num: {i + 1}.')
                 time.sleep(sleep_time)
 
@@ -245,25 +275,19 @@
 
         return response
 
-    @staticmethod
-    def _get_sleep_time(error):
-        if isinstance(error, openai.RateLimitError):
-            return random.randint(30, 60)
-        elif isinstance(error, openai.APITimeoutError):
-            return 3
-        else:
-            return 15
-
-
+default_claudebot_classified_errors = ClassifiedErrors()
+default_claudebot_classified_errors.long_wait.errs.add(anthropic.RateLimitError)
+default_claudebot_classified_errors.short_wait.errs.add(anthropic.APITimeoutError)
 @_register_chatbot
 class ClaudeBot(ChatBot):
     def __init__(self, model_name='claude-3-5-sonnet-20241022', temperature=1, top_p=1, retry=8, max_async=16,
-                 fee_limit=0.8, proxy=None, base_url_config=None, api_key=None):
+                 fee_limit=0.8, proxy=None, base_url_config=None, api_key=None, classified_errors=default_claudebot_classified_errors):
 
         # clamp temperature to 0-1
         temperature = max(0, min(1, temperature))
 
-        super().__init__(model_name, temperature, top_p, retry, max_async, fee_limit)
+        super().__init__(model_name, temperature, top_p, retry, max_async, fee_limit, classified_errors=classified_errors)
 
         self.async_client = AsyncAnthropic(
             api_key=api_key or os.environ['ANTHROPIC_API_KEY'],
@@ -321,7 +345,7 @@
             except (
                     anthropic.RateLimitError, anthropic.APITimeoutError, anthropic.APIConnectionError,
                     anthropic.APIError) as e:
-                sleep_time = self._get_sleep_time(e)
+                sleep_time = self.classified_errors.get_sleep_time(e)
                 logger.warning(f'{type(e).__name__}: {e}. Wait {sleep_time}s before retry. Retry num: {i + 1}.')
                 time.sleep(sleep_time)
 
@@ -330,15 +354,6 @@
 
         return response
 
-    def _get_sleep_time(self, error):
-        if isinstance(error, anthropic.RateLimitError):
-            return random.randint(30, 60)
-        elif isinstance(error, anthropic.APITimeoutError):
-            return 3
-        else:
-            return 15
-
-
 @_register_chatbot
 class GeminiBot(ChatBot):
     def __init__(self, model_name='gemini-2.0-flash-exp', temperature=1, top_p=1, retry=8, max_async=16, fee_limit=0.8,
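
The Gemini hunk above is truncated, and this PR leaves GeminiBot's retry handling untouched. If it were migrated to the same mechanism, a per-provider default could be declared analogously — a sketch only; the `google.api_core` exception choices are my assumption, not something this diff contains:

```python
# Sketch: a Gemini-flavoured classification in the style of the
# OpenAI/Anthropic defaults above. Not part of this PR.
from google.api_core import exceptions as g_exceptions

from openlrc.chatbot import ClassifiedErrors

default_geminibot_classified_errors = ClassifiedErrors()
default_geminibot_classified_errors.long_wait.errs.add(g_exceptions.ResourceExhausted)  # quota / 429
default_geminibot_classified_errors.short_wait.errs.add(g_exceptions.DeadlineExceeded)  # timeouts
```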
openlrc/translate.py (4 changes: 2 additions & 2 deletions)
@@ -299,7 +299,7 @@ def _save_intermediate_results(self, compare_path: Path, compare_list: List[dict
         with open(compare_path, 'w', encoding='utf-8') as f:
             json.dump(compare_results, f, indent=4, ensure_ascii=False)
 
-    def atomic_translate(self, chatbot_model: Union[str, ModelConfig], texts: List[str], src_lang: str,
+    def atomic_translate(self, texts: List[str], src_lang: str,
                          target_lang: str) -> List[str]:
         """
         Perform atomic translation for each text individually.
@@ -319,7 +319,7 @@ def atomic_translate(self, chatbot_model: Union[str, ModelConfig], texts: List[s
         Raises:
             AssertionError: If the number of translated texts doesn't match the input.
         """
-        chatbot = ChunkedTranslatorAgent(src_lang, target_lang, TranslateInfo(), chatbot_model, self.fee_limit,
+        chatbot = ChunkedTranslatorAgent(src_lang, target_lang, TranslateInfo(), self.chatbot_model, self.fee_limit,
                                          self.proxy,
                                          self.base_url_config).chatbot
 
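
With this change, `atomic_translate` always uses the model the translator was constructed with rather than a per-call argument. A quick usage sketch — it assumes `LLMTranslator` lives in `openlrc.translate` (the tests below use it directly) and that the relevant API key is set in the environment:

```python
from openlrc.translate import LLMTranslator

# The chatbot model is fixed at construction time; atomic_translate no longer
# takes a chatbot_model argument.
translator = LLMTranslator('gpt-4o-mini')
translations = translator.atomic_translate(['Hello, how are you?'], 'en', 'zh')
```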
pyproject.toml (6 changes: 5 additions & 1 deletion)
@@ -68,6 +68,10 @@ matplotlib = "^3.9.2"
 typing-extensions = "^4.12.2"
 onnxruntime = "^1.20.0"
 
+[tool.poetry.group.dev.dependencies]
+torch = "^2.6.0"
+torchvision = "^0.21.0"
+torchaudio = "^2.6.0"
 #torch = ">=2.0.0, !=2.0.1"
 #torchaudio = "^2.0.0"
 #torchvision = "^0.17.1"
@@ -89,4 +93,4 @@
 "Bug Tracker" = "https://github.com/zh-plus/Open-Lyrics/issues"
 
 [tool.poetry.scripts]
-openlrc = "openlrc.cli:main"
\ No newline at end of file
+openlrc = "openlrc.cli:main"
tests/test_translate.py (19 changes: 7 additions & 12 deletions)
@@ -11,7 +11,7 @@
 from openlrc.utils import get_similarity
 
 test_models = ['gpt-4o-mini', 'claude-3-5-haiku-20241022']
-
+test_translators = [LLMTranslator(m) for m in test_models]
 
 class TestLLMTranslator(unittest.TestCase):
 
@@ -20,25 +20,22 @@ def tearDown(self) -> None:
         compare_path.unlink(missing_ok=True)
 
     def test_single_chunk_translation(self):
-        for chatbot_model in test_models:
+        for translator in test_translators:
             text = 'Hello, how are you?'
-            translator = LLMTranslator(chatbot_model)
             translation = translator.translate(text, 'en', 'es')[0]
 
             self.assertGreater(get_similarity(translation, 'Hola, ¿cómo estás?'), 0.5)
 
     def test_multiple_chunk_translation(self):
-        for chatbot_model in test_models:
+        for translator in test_translators:
             texts = ['Hello, how are you?', 'I am fine, thank you.']
-            translator = LLMTranslator(chatbot_model)
             translations = translator.translate(texts, 'en', 'es')
             self.assertGreater(get_similarity(translations[0], 'Hola, ¿cómo estás?'), 0.5)
             self.assertGreater(get_similarity(translations[1], 'Estoy bien, gracias.'), 0.5)
 
     def test_different_language_translation(self):
-        for chatbot_model in test_models:
+        for translator in test_translators:
             text = 'Hello, how are you?'
-            translator = LLMTranslator(chatbot_model)
             try:
                 translation = translator.translate(text, 'en', 'ja')[0]
                 self.assertTrue(
@@ -49,17 +46,15 @@ def test_different_language_translation(self):
                 pass
 
     def test_empty_text_list_translation(self):
-        for chatbot_model in test_models:
+        for translator in test_translators:
             texts = []
-            translator = LLMTranslator(chatbot_model)
             translations = translator.translate(texts, 'en', 'es')
             self.assertEqual(translations, [])
 
     def test_atomic_translate(self):
-        for chatbot_model in test_models:
+        for translator in test_translators:
             texts = ['Hello, how are you?', 'I am fine, thank you.']
-            translator = LLMTranslator(chatbot_model)
-            translations = translator.atomic_translate(chatbot_model, texts, 'en', 'zh')
+            translations = translator.atomic_translate(texts, 'en', 'zh')
             self.assertGreater(get_similarity(translations[0], '你好,你好吗?'), 0.5)
             self.assertGreater(get_similarity(translations[1], '我很好,谢谢。'), 0.5)