Skip to content

Commit da09049

Browse files
committed
Enhance chatbot model handling for unrecorded models.
1 parent 3e57fbc commit da09049

9 files changed

+260 -58 lines changed

openlrc/agents.py

+4 -3
Original file line number | Diff line number | Diff line change
@@ -43,7 +43,7 @@ def _initialize_chatbot(self, chatbot_model: Union[str, ModelConfig], fee_limit:
4343
if isinstance(chatbot_model, str):
4444
chatbot_cls: Union[Type[ClaudeBot], Type[GPTBot], Type[GeminiBot]]
4545
chatbot_cls, model_name = route_chatbot(chatbot_model)
46-
return chatbot_cls(model_name=model_name, fee_limit=fee_limit, proxy=proxy, retry=2,
46+
return chatbot_cls(model_name=model_name, fee_limit=fee_limit, proxy=proxy, retry=4,
4747
temperature=self.TEMPERATURE, base_url_config=base_url_config)
4848
elif isinstance(chatbot_model, ModelConfig):
4949
chatbot_cls = provider2chatbot[chatbot_model.provider]
@@ -58,7 +58,7 @@ def _initialize_chatbot(self, chatbot_model: Union[str, ModelConfig], fee_limit:
5858
base_url_config = None
5959
logger.warning(f'Unsupported base_url configuration for provider: {chatbot_model.provider}')
6060

61-
return chatbot_cls(model_name=chatbot_model.name, fee_limit=fee_limit, proxy=proxy, retry=2,
61+
return chatbot_cls(model_name=chatbot_model.name, fee_limit=fee_limit, proxy=proxy, retry=4,
6262
temperature=self.TEMPERATURE, base_url_config=base_url_config,
6363
api_key=chatbot_model.api_key)
6464

@@ -190,7 +190,8 @@ def translate_chunk(self, chunk_id: int, chunk: List[Tuple[int, str]],
190190
guideline = context.guideline if use_glossary else context.non_glossary_guideline
191191
messages_list = [
192192
{'role': 'system', 'content': self.prompter.system()},
193-
{'role': 'user', 'content': self.prompter.user(chunk_id, user_input, context.summary, guideline)},
193+
{'role': 'user',
194+
'content': self.prompter.user(chunk_id, user_input, context.previous_summaries, guideline)},
194195
]
195196
resp = self.chatbot.message(messages_list, output_checker=self.prompter.check_format)[0]
196197
translations, summary, scene = self._parse_responses(resp)

openlrc/chatbot.py

+70 -40
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,7 @@
22
# All rights reserved.
33

44
import asyncio
5+
import json
56
import os
67
import random
78
import re
@@ -16,9 +17,9 @@
1617
from anthropic import AsyncAnthropic
1718
from anthropic._types import NOT_GIVEN
1819
from anthropic.types import Message
19-
from google.generativeai import GenerationConfig
20-
from google.generativeai.types import AsyncGenerateContentResponse, GenerateContentResponse, HarmCategory, \
21-
HarmBlockThreshold
20+
from google import genai
21+
from google.genai import types
22+
from google.genai.types import HarmCategory, HarmBlockThreshold
2223
from openai import AsyncClient as AsyncGPTClient
2324
from openai.types.chat import ChatCompletion
2425

@@ -57,10 +58,7 @@ def route_chatbot(model: str) -> (type, str):
5758
chatbot_type, chatbot_model = re.match(r'(.+):(.+)', model).groups()
5859
chatbot_type, chatbot_model = chatbot_type.strip().lower(), chatbot_model.strip()
5960

60-
try:
61-
Models.get_model(chatbot_model)
62-
except ValueError:
63-
raise ValueError(f'Invalid model {chatbot_model}.')
61+
Models.get_model(chatbot_model)
6462

6563
if chatbot_type == 'openai':
6664
return GPTBot, chatbot_model
@@ -235,7 +233,8 @@ async def _create_achat(self, messages: List[Dict], stop_sequences: Optional[Lis
235233
continue
236234

237235
break
238-
except (openai.RateLimitError, openai.APITimeoutError, openai.APIConnectionError, openai.APIError) as e:
236+
except (openai.RateLimitError, openai.APITimeoutError, openai.APIConnectionError, openai.APIError,
237+
json.decoder.JSONDecodeError) as e:
239238
sleep_time = self._get_sleep_time(e)
240239
logger.warning(f'{type(e).__name__}: {e}. Wait {sleep_time}s before retry. Retry num: {i + 1}.')
241240
time.sleep(sleep_time)
@@ -251,6 +250,8 @@ def _get_sleep_time(error):
251250
return random.randint(30, 60)
252251
elif isinstance(error, openai.APITimeoutError):
253252
return 3
253+
elif isinstance(error, json.decoder.JSONDecodeError):
254+
return 1
254255
else:
255256
return 15
256257

@@ -349,26 +350,45 @@ def __init__(self, model_name='gemini-2.0-flash-exp', temperature=1, top_p=1, re
349350

350351
self.model_name = model_name
351352

352-
genai.configure(api_key=api_key or os.environ['GOOGLE_API_KEY'])
353-
self.config = GenerationConfig(temperature=self.temperature, top_p=self.top_p)
353+
# genai.configure(api_key=api_key or os.environ['GOOGLE_API_KEY'])
354+
self.client = genai.Client(
355+
api_key=api_key or os.environ['GOOGLE_API_KEY']
356+
)
357+
354358
# Should not block any translation-related content.
355-
self.safety_settings = {
356-
HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
357-
HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
358-
HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
359-
HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE
360-
}
359+
# self.safety_settings = {
360+
# HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
361+
# HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
362+
# HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
363+
# HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE
364+
# }
365+
self.safety_settings = [
366+
types.SafetySetting(
367+
category=HarmCategory.HARM_CATEGORY_HATE_SPEECH, threshold=HarmBlockThreshold.BLOCK_NONE
368+
),
369+
types.SafetySetting(
370+
category=HarmCategory.HARM_CATEGORY_HARASSMENT, threshold=HarmBlockThreshold.BLOCK_NONE
371+
),
372+
types.SafetySetting(
373+
category=HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, threshold=HarmBlockThreshold.BLOCK_NONE
374+
),
375+
types.SafetySetting(
376+
category=HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, threshold=HarmBlockThreshold.BLOCK_NONE
377+
)
378+
]
379+
self.config = types.GenerateContentConfig(temperature=self.temperature, top_p=self.top_p,
380+
safety_settings=self.safety_settings)
361381

362382
if proxy:
363383
logger.warning('Google Gemini SDK does not support proxy, try using the system-level proxy if needed.')
364384

365385
if base_url_config:
366386
logger.warning('Google Gemini SDK does not support changing base_url.')
367387

368-
def update_fee(self, response: Union[GenerateContentResponse, AsyncGenerateContentResponse]):
388+
def update_fee(self, response: types.GenerateContentResponse):
369389
model_info = self.model_info
370390
prompt_tokens = response.usage_metadata.prompt_token_count
371-
completion_tokens = response.usage_metadata.candidates_token_count
391+
completion_tokens = response.usage_metadata.candidates_token_count or 0
372392

373393
self.api_fees[-1] += (prompt_tokens * model_info.input_price +
374394
completion_tokens * model_info.output_price) / 1000000
@@ -401,31 +421,41 @@ async def _create_achat(self, messages: List[Dict], stop_sequences: Optional[Lis
401421
history_messages[i]['parts'] = [{'text': content}]
402422

403423
self.config.stop_sequences = stop_sequences
404-
generative_model = genai.GenerativeModel(model_name=self.model_name, generation_config=self.config,
405-
safety_settings=self.safety_settings, system_instruction=system_msg)
406-
client = genai.ChatSession(generative_model, history=history_messages)
424+
# generative_model = genai.GenerativeModel(model_name=self.model_name, generation_config=self.config,
425+
# safety_settings=self.safety_settings, system_instruction=system_msg)
426+
# client = genai.ChatSession(generative_model, history=history_messages)
427+
self.config.system_instruction = system_msg
407428

408429
response = None
409430
for i in range(self.retry):
410-
try:
411-
# send_message_async is buggy, so we use send_message instead as a workaround
412-
response = client.send_message(user_msg, safety_settings=self.safety_settings)
413-
self.update_fee(response)
414-
if not output_checker(user_msg, response.text):
415-
logger.warning(f'Invalid response format. Retry num: {i + 1}.')
416-
continue
417-
418-
if not response._done:
419-
logger.warning(f'Failed to get a complete response. Retry num: {i + 1}.')
420-
continue
421-
422-
break
423-
except (genai.types.BrokenResponseError, genai.types.IncompleteIterationError,
424-
genai.types.StopCandidateException) as e:
425-
logger.warning(f'{type(e).__name__}: {e}. Retry num: {i + 1}.')
426-
except genai.types.generation_types.BlockedPromptException as e:
427-
logger.warning(f'Prompt blocked: {e}.\n Retry in 30s.')
428-
time.sleep(30)
431+
# try:
432+
# send_message_async is buggy, so we use send_message instead as a workaround
433+
# response = client.send_message(user_msg, safety_settings=self.safety_settings)
434+
response = await self.client.aio.models.generate_content(
435+
model=self.model_name,
436+
contents=user_msg,
437+
config=self.config,
438+
)
439+
self.update_fee(response)
440+
if not response.text:
441+
logger.warning(f'Get None response. Wait 15s. Retry num: {i + 1}.')
442+
time.sleep(15)
443+
continue
444+
445+
if not output_checker(user_msg, response.text):
446+
logger.warning(f'Invalid response format. Retry num: {i + 1}.')
447+
continue
448+
449+
if not response:
450+
logger.warning(f'Failed to get a complete response. Retry num: {i + 1}.')
451+
continue
452+
453+
break
454+
# except Exception as e:
455+
# logger.warning(f'{type(e).__name__}: {e}. Retry num: {i + 1}.')
456+
# time.sleep(3)
457+
# except genai.types.generation_types.BlockedPromptException as e:
458+
# logger.warning(f'Prompt blocked: {e}.\n Retry in 30s.')
429459

430460
if not response:
431461
raise ChatBotException('Failed to create a chat.')

openlrc/context.py

+3 -2
Original file line number | Diff line number | Diff line change
@@ -1,14 +1,15 @@
1-
# Copyright (C) 2024. Hao Zheng
1+
# Copyright (C) 2025. Hao Zheng
22
# All rights reserved.
33
import re
4-
from typing import Optional, Union
4+
from typing import Optional, Union, List
55

66
from pydantic import BaseModel
77

88
from openlrc import ModelConfig
99

1010

1111
class TranslationContext(BaseModel):
12+
previous_summaries: Optional[List[str]] = None
1213
summary: Optional[str] = ''
1314
scene: Optional[str] = ''
1415
model: Optional[Union[str, ModelConfig]] = None

0 commit comments

Comments
 (0)