5
5
# LICENSE file in the root directory of this source tree.
6
6
7
7
import logging
8
+ import re
8
9
from argparse import ArgumentParser
9
10
from typing import Any , Dict , List , Optional
10
11
12
+ import packaging
11
13
from fastapi import FastAPI , HTTPException
12
14
from google import genai
13
15
from ray import serve
19
21
app = FastAPI ()
20
22
21
23
24
+ def _extract_version (name : str ) -> packaging .version .Version | None :
25
+ match = re .search (r"gemini-(\d+\.\d+)" , name )
26
+ return packaging .version .parse (match .group (1 )) if match else None
27
+
28
+
22
29
@serve .deployment (
23
30
autoscaling_config = {
24
31
"min_replicas" : 1 ,
@@ -33,9 +40,15 @@ def __init__(
33
40
self ,
34
41
api_key : str ,
35
42
model_name : str ,
43
+ thinking_budget : int ,
36
44
):
37
45
self .model_name = model_name
38
46
self .client = genai .Client (api_key = api_key )
47
+ self .thinking_budget = thinking_budget
48
+ version = _extract_version (model_name )
49
+ self .reasoning = version is not None and version >= packaging .version .parse (
50
+ "2.5"
51
+ )
39
52
40
53
def _transform_messages (
41
54
self , messages : List [Dict [str , str ]]
@@ -98,7 +111,7 @@ async def create_chat_completion(
98
111
completion_request .get ("messages" , [])
99
112
)
100
113
101
- request_params = {
114
+ request_params : Dict [ str , Any ] = {
102
115
"contents" : messages_transformed ,
103
116
"config" : {
104
117
"temperature" : completion_request .get ("temperature" , 0.6 ),
@@ -110,22 +123,36 @@ async def create_chat_completion(
110
123
"system_instruction" : system_instruction_content ,
111
124
},
112
125
}
126
+ if self .reasoning :
127
+ request_params ["config" ]["thinking_config" ] = {
128
+ "thinking_budget" : self .thinking_budget
129
+ }
113
130
try :
114
131
response = await self .client .aio .models .generate_content (
115
132
model = self .model_name , ** request_params
116
133
)
117
134
except genai .errors .APIError as e :
118
135
raise HTTPException (status_code = e .code , detail = str (e ))
119
136
120
- completion_response : Dict [str , Any ] = {
121
- "id" : response .response_id ,
122
- "usage" : {
137
+ if response .usage_metadata :
138
+ usage = {
123
139
"prompt_tokens" : response .usage_metadata .prompt_token_count ,
124
140
"total_tokens" : response .usage_metadata .total_token_count ,
125
141
"completion_tokens" : response .usage_metadata .candidates_token_count ,
126
- },
142
+ }
143
+ else :
144
+ usage = {
145
+ "prompt_tokens" : 0 ,
146
+ "total_tokens" : 0 ,
147
+ "completion_tokens" : 0 ,
148
+ }
149
+
150
+ completion_response : Dict [str , Any ] = {
151
+ "id" : response .response_id ,
152
+ "usage" : usage ,
127
153
}
128
154
if response .candidates is None :
155
+ error : Dict [str , Any ] = {}
129
156
if response .prompt_feedback and response .prompt_feedback .block_reason :
130
157
error = { # Adding an error field for clarity, not standard OpenAI format
131
158
"message" : f"Content blocked due to: { response .prompt_feedback .block_reason .name } " ,
@@ -143,17 +170,22 @@ async def create_chat_completion(
143
170
choices = [
144
171
{
145
172
"index" : i ,
146
- "finish_reason" : [candidate .finish_reason .value ],
173
+ "finish_reason" : (
174
+ candidate .finish_reason .value
175
+ if candidate .finish_reason is not None
176
+ else ""
177
+ ),
147
178
"message" : {
148
179
"content" : "" .join (
149
180
part .text
150
- for part in candidate .content .parts
181
+ for part in ( candidate .content .parts or [])
151
182
if part .text is not None
152
183
),
153
184
"role" : "assistant" ,
154
185
},
155
186
}
156
187
for i , candidate in enumerate (response .candidates )
188
+ if candidate .content
157
189
]
158
190
completion_response ["choices" ] = choices
159
191
@@ -168,6 +200,7 @@ def build_app(cli_args: Dict[str, str]) -> serve.Application:
168
200
argparse = ArgumentParser ()
169
201
argparse .add_argument ("--api_key" , type = str , required = True )
170
202
argparse .add_argument ("--model_name" , type = str , required = True )
203
+ argparse .add_argument ("--thinking_budget" , type = int , default = 1024 )
171
204
172
205
arg_strings = []
173
206
for key , value in cli_args .items ():
@@ -187,4 +220,5 @@ def build_app(cli_args: Dict[str, str]) -> serve.Application:
187
220
).bind (
188
221
args .api_key ,
189
222
args .model_name ,
223
+ args .thinking_budget ,
190
224
)
0 commit comments