2
2
# All rights reserved.
3
3
4
4
import asyncio
5
+ import json
5
6
import os
6
7
import random
7
8
import re
16
17
from anthropic import AsyncAnthropic
17
18
from anthropic ._types import NOT_GIVEN
18
19
from anthropic .types import Message
19
- from google . generativeai import GenerationConfig
20
- from google .generativeai . types import AsyncGenerateContentResponse , GenerateContentResponse , HarmCategory , \
21
- HarmBlockThreshold
20
+ from google import genai
21
+ from google .genai import types
22
+ from google . genai . types import HarmCategory , HarmBlockThreshold
22
23
from openai import AsyncClient as AsyncGPTClient
23
24
from openai .types .chat import ChatCompletion
24
25
@@ -57,10 +58,7 @@ def route_chatbot(model: str) -> (type, str):
57
58
chatbot_type , chatbot_model = re .match (r'(.+):(.+)' , model ).groups ()
58
59
chatbot_type , chatbot_model = chatbot_type .strip ().lower (), chatbot_model .strip ()
59
60
60
- try :
61
- Models .get_model (chatbot_model )
62
- except ValueError :
63
- raise ValueError (f'Invalid model { chatbot_model } .' )
61
+ Models .get_model (chatbot_model )
64
62
65
63
if chatbot_type == 'openai' :
66
64
return GPTBot , chatbot_model
@@ -235,7 +233,8 @@ async def _create_achat(self, messages: List[Dict], stop_sequences: Optional[Lis
235
233
continue
236
234
237
235
break
238
- except (openai .RateLimitError , openai .APITimeoutError , openai .APIConnectionError , openai .APIError ) as e :
236
+ except (openai .RateLimitError , openai .APITimeoutError , openai .APIConnectionError , openai .APIError ,
237
+ json .decoder .JSONDecodeError ) as e :
239
238
sleep_time = self ._get_sleep_time (e )
240
239
logger .warning (f'{ type (e ).__name__ } : { e } . Wait { sleep_time } s before retry. Retry num: { i + 1 } .' )
241
240
time .sleep (sleep_time )
@@ -251,6 +250,8 @@ def _get_sleep_time(error):
251
250
return random .randint (30 , 60 )
252
251
elif isinstance (error , openai .APITimeoutError ):
253
252
return 3
253
+ elif isinstance (error , json .decoder .JSONDecodeError ):
254
+ return 1
254
255
else :
255
256
return 15
256
257
@@ -349,26 +350,45 @@ def __init__(self, model_name='gemini-2.0-flash-exp', temperature=1, top_p=1, re
349
350
350
351
self .model_name = model_name
351
352
352
- genai .configure (api_key = api_key or os .environ ['GOOGLE_API_KEY' ])
353
- self .config = GenerationConfig (temperature = self .temperature , top_p = self .top_p )
353
+ # genai.configure(api_key=api_key or os.environ['GOOGLE_API_KEY'])
354
+ self .client = genai .Client (
355
+ api_key = api_key or os .environ ['GOOGLE_API_KEY' ]
356
+ )
357
+
354
358
# Should not block any translation-related content.
355
- self .safety_settings = {
356
- HarmCategory .HARM_CATEGORY_HATE_SPEECH : HarmBlockThreshold .BLOCK_NONE ,
357
- HarmCategory .HARM_CATEGORY_HARASSMENT : HarmBlockThreshold .BLOCK_NONE ,
358
- HarmCategory .HARM_CATEGORY_SEXUALLY_EXPLICIT : HarmBlockThreshold .BLOCK_NONE ,
359
- HarmCategory .HARM_CATEGORY_DANGEROUS_CONTENT : HarmBlockThreshold .BLOCK_NONE
360
- }
359
+ # self.safety_settings = {
360
+ # HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
361
+ # HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
362
+ # HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
363
+ # HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE
364
+ # }
365
+ self .safety_settings = [
366
+ types .SafetySetting (
367
+ category = HarmCategory .HARM_CATEGORY_HATE_SPEECH , threshold = HarmBlockThreshold .BLOCK_NONE
368
+ ),
369
+ types .SafetySetting (
370
+ category = HarmCategory .HARM_CATEGORY_HARASSMENT , threshold = HarmBlockThreshold .BLOCK_NONE
371
+ ),
372
+ types .SafetySetting (
373
+ category = HarmCategory .HARM_CATEGORY_SEXUALLY_EXPLICIT , threshold = HarmBlockThreshold .BLOCK_NONE
374
+ ),
375
+ types .SafetySetting (
376
+ category = HarmCategory .HARM_CATEGORY_DANGEROUS_CONTENT , threshold = HarmBlockThreshold .BLOCK_NONE
377
+ )
378
+ ]
379
+ self .config = types .GenerateContentConfig (temperature = self .temperature , top_p = self .top_p ,
380
+ safety_settings = self .safety_settings )
361
381
362
382
if proxy :
363
383
logger .warning ('Google Gemini SDK does not support proxy, try using the system-level proxy if needed.' )
364
384
365
385
if base_url_config :
366
386
logger .warning ('Google Gemini SDK does not support changing base_url.' )
367
387
368
- def update_fee (self , response : Union [ GenerateContentResponse , AsyncGenerateContentResponse ] ):
388
+ def update_fee (self , response : types . GenerateContentResponse ):
369
389
model_info = self .model_info
370
390
prompt_tokens = response .usage_metadata .prompt_token_count
371
- completion_tokens = response .usage_metadata .candidates_token_count
391
+ completion_tokens = response .usage_metadata .candidates_token_count or 0
372
392
373
393
self .api_fees [- 1 ] += (prompt_tokens * model_info .input_price +
374
394
completion_tokens * model_info .output_price ) / 1000000
@@ -401,31 +421,41 @@ async def _create_achat(self, messages: List[Dict], stop_sequences: Optional[Lis
401
421
history_messages [i ]['parts' ] = [{'text' : content }]
402
422
403
423
self .config .stop_sequences = stop_sequences
404
- generative_model = genai .GenerativeModel (model_name = self .model_name , generation_config = self .config ,
405
- safety_settings = self .safety_settings , system_instruction = system_msg )
406
- client = genai .ChatSession (generative_model , history = history_messages )
424
+ # generative_model = genai.GenerativeModel(model_name=self.model_name, generation_config=self.config,
425
+ # safety_settings=self.safety_settings, system_instruction=system_msg)
426
+ # client = genai.ChatSession(generative_model, history=history_messages)
427
+ self .config .system_instruction = system_msg
407
428
408
429
response = None
409
430
for i in range (self .retry ):
410
- try :
411
- # send_message_async is buggy, so we use send_message instead as a workaround
412
- response = client .send_message (user_msg , safety_settings = self .safety_settings )
413
- self .update_fee (response )
414
- if not output_checker (user_msg , response .text ):
415
- logger .warning (f'Invalid response format. Retry num: { i + 1 } .' )
416
- continue
417
-
418
- if not response ._done :
419
- logger .warning (f'Failed to get a complete response. Retry num: { i + 1 } .' )
420
- continue
421
-
422
- break
423
- except (genai .types .BrokenResponseError , genai .types .IncompleteIterationError ,
424
- genai .types .StopCandidateException ) as e :
425
- logger .warning (f'{ type (e ).__name__ } : { e } . Retry num: { i + 1 } .' )
426
- except genai .types .generation_types .BlockedPromptException as e :
427
- logger .warning (f'Prompt blocked: { e } .\n Retry in 30s.' )
428
- time .sleep (30 )
431
+ # try:
432
# send_message_async was buggy in the old SDK, so send_message was used as a workaround; the new client call below is awaited directly.
433
+ # response = client.send_message(user_msg, safety_settings=self.safety_settings)
434
+ response = await self .client .aio .models .generate_content (
435
+ model = self .model_name ,
436
+ contents = user_msg ,
437
+ config = self .config ,
438
+ )
439
+ self .update_fee (response )
440
+ if not response .text :
441
+ logger .warning (f'Get None response. Wait 15s. Retry num: { i + 1 } .' )
442
+ time .sleep (15 )
443
+ continue
444
+
445
+ if not output_checker (user_msg , response .text ):
446
+ logger .warning (f'Invalid response format. Retry num: { i + 1 } .' )
447
+ continue
448
+
449
+ if not response :
450
+ logger .warning (f'Failed to get a complete response. Retry num: { i + 1 } .' )
451
+ continue
452
+
453
+ break
454
+ # except Exception as e:
455
+ # logger.warning(f'{type(e).__name__}: {e}. Retry num: {i + 1}.')
456
+ # time.sleep(3)
457
+ # except genai.types.generation_types.BlockedPromptException as e:
458
+ # logger.warning(f'Prompt blocked: {e}.\n Retry in 30s.')
429
459
430
460
if not response :
431
461
raise ChatBotException ('Failed to create a chat.' )
0 commit comments