
Commit bc86b0c

add thinking_budget and fix run_async
1 parent b9c43c9 commit bc86b0c

File tree: 6 files changed, +78 -39 lines


matrix/app_server/app_api.py (+2 -2)

@@ -38,7 +38,7 @@
 from matrix.client.endpoint_cache import EndpointCache
 from matrix.common.cluster_info import ClusterInfo, get_head_http_host
 from matrix.utils.json import convert_to_json_compatible
-from matrix.utils.os import lock_file
+from matrix.utils.os import lock_file, run_async
 from matrix.utils.ray import (
     ACTOR_NAME_SPACE,
     Action,
@@ -335,7 +335,7 @@ async def dummy_updater():
             ttl=endpoint_ttl_sec,
             serve_app=serve_app,
         )
-        workers = asyncio.run(endpoint_cache())
+        workers = run_async(endpoint_cache())
         metadata["endpoints"] = {
             "head": head,
             "workers": workers,

matrix/app_server/deploy_utils.py (+4)

@@ -61,6 +61,8 @@
     "aws_account",
     "aws_region",
     "endpoint_name",
+    "anthropic_version",
+    "thinking_budget",
 ]

 vllm_app_template = """
@@ -116,6 +118,7 @@
   args:
     model: {{ app.model_name }}
     api_key: {{ app.api_key }}
+    thinking_budget: {{ app.thinking_budget }}
   deployments:
   - name: GeminiDeployment
     max_ongoing_requests: {{ app.max_ongoing_requests }}
@@ -394,6 +397,7 @@ def get_yaml_for_deployment(
     default_params = {
         "name": "gemini",
         "max_ongoing_requests": 10,
+        "thinking_budget": 1024,
     }
     app.update({k: v for k, v in default_params.items() if k not in app})
     assert "api_key" in app, "add api_key to gemini app"

matrix/app_server/llm/gemini_proxy.py (+41 -7)

@@ -5,9 +5,11 @@
 # LICENSE file in the root directory of this source tree.

 import logging
+import re
 from argparse import ArgumentParser
 from typing import Any, Dict, List, Optional

+import packaging
 from fastapi import FastAPI, HTTPException
 from google import genai
 from ray import serve
@@ -19,6 +21,11 @@
 app = FastAPI()


+def _extract_version(name: str) -> packaging.version.Version | None:
+    match = re.search(r"gemini-(\d+\.\d+)", name)
+    return packaging.version.parse(match.group(1)) if match else None
+
+
 @serve.deployment(
     autoscaling_config={
         "min_replicas": 1,
@@ -33,9 +40,15 @@ def __init__(
         self,
         api_key: str,
         model_name: str,
+        thinking_budget: int,
     ):
         self.model_name = model_name
         self.client = genai.Client(api_key=api_key)
+        self.thinking_budget = thinking_budget
+        version = _extract_version(model_name)
+        self.reasoning = version is not None and version >= packaging.version.parse(
+            "2.5"
+        )

     def _transform_messages(
         self, messages: List[Dict[str, str]]
@@ -98,7 +111,7 @@ async def create_chat_completion(
             completion_request.get("messages", [])
         )

-        request_params = {
+        request_params: Dict[str, Any] = {
             "contents": messages_transformed,
             "config": {
                 "temperature": completion_request.get("temperature", 0.6),
@@ -110,22 +123,36 @@
                 "system_instruction": system_instruction_content,
             },
         }
+        if self.reasoning:
+            request_params["config"]["thinking_config"] = {
+                "thinking_budget": self.thinking_budget
+            }
         try:
             response = await self.client.aio.models.generate_content(
                 model=self.model_name, **request_params
             )
         except genai.errors.APIError as e:
             raise HTTPException(status_code=e.code, detail=str(e))

-        completion_response: Dict[str, Any] = {
-            "id": response.response_id,
-            "usage": {
+        if response.usage_metadata:
+            usage = {
                 "prompt_tokens": response.usage_metadata.prompt_token_count,
                 "total_tokens": response.usage_metadata.total_token_count,
                 "completion_tokens": response.usage_metadata.candidates_token_count,
-            },
+            }
+        else:
+            usage = {
+                "prompt_tokens": 0,
+                "total_tokens": 0,
+                "completion_tokens": 0,
+            }
+
+        completion_response: Dict[str, Any] = {
+            "id": response.response_id,
+            "usage": usage,
         }
         if response.candidates is None:
+            error: Dict[str, Any] = {}
             if response.prompt_feedback and response.prompt_feedback.block_reason:
                 error = { # Adding an error field for clarity, not standard OpenAI format
                     "message": f"Content blocked due to: {response.prompt_feedback.block_reason.name}",
@@ -143,17 +170,22 @@ async def create_chat_completion(
         choices = [
             {
                 "index": i,
-                "finish_reason": [candidate.finish_reason.value],
+                "finish_reason": (
+                    candidate.finish_reason.value
+                    if candidate.finish_reason is not None
+                    else ""
+                ),
                 "message": {
                     "content": "".join(
                         part.text
-                        for part in candidate.content.parts
+                        for part in (candidate.content.parts or [])
                         if part.text is not None
                     ),
                     "role": "assistant",
                 },
             }
             for i, candidate in enumerate(response.candidates)
+            if candidate.content
         ]
         completion_response["choices"] = choices

@@ -168,6 +200,7 @@ def build_app(cli_args: Dict[str, str]) -> serve.Application:
     argparse = ArgumentParser()
     argparse.add_argument("--api_key", type=str, required=True)
     argparse.add_argument("--model_name", type=str, required=True)
+    argparse.add_argument("--thinking_budget", type=int, default=1024)

     arg_strings = []
     for key, value in cli_args.items():
@@ -187,4 +220,5 @@ def build_app(cli_args: Dict[str, str]) -> serve.Application:
     ).bind(
         args.api_key,
         args.model_name,
+        args.thinking_budget,
     )
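A standalone sketch of how the new version gate behaves: _extract_version and the 2.5 cutoff are taken from the diff above; the model names and printed output are illustrative only.

import re
import packaging.version

def _extract_version(name):
    # Pull major.minor out of names like "gemini-2.5-pro".
    match = re.search(r"gemini-(\d+\.\d+)", name)
    return packaging.version.parse(match.group(1)) if match else None

for model_name in ["gemini-2.5-pro", "gemini-2.0-flash", "some-other-model"]:
    version = _extract_version(model_name)
    reasoning = version is not None and version >= packaging.version.parse("2.5")
    config = {"temperature": 0.6}
    if reasoning:
        # Only Gemini >= 2.5 models get a thinking budget attached to the request config.
        config["thinking_config"] = {"thinking_budget": 1024}
    print(model_name, reasoning)
# gemini-2.5-pro True, gemini-2.0-flash False, some-other-model False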

matrix/client/query_llm.py (+2 -29)

@@ -27,6 +27,7 @@
 from matrix.app_server.llm import openai_pb2, openai_pb2_grpc
 from matrix.client.client_utils import get_an_endpoint_url, save_to_jsonl
 from matrix.client.endpoint_cache import EndpointCache
+from matrix.utils.os import run_async

 CHAR_PER_TOKEN = 3.61
 logging.basicConfig(
@@ -451,35 +452,7 @@ async def _process_requests():
            *[make_request(url, model, request, **kwargs) for request in requests]
        )

-    # Get the event loop
-    try:
-        loop = asyncio.get_event_loop()
-    except RuntimeError:
-        # No event loop in this thread, create a new one
-        loop = asyncio.new_event_loop()
-        asyncio.set_event_loop(loop)
-
-    # Check if we're already in an async context
-    if loop.is_running():
-        # We're in an async context and can't use run_until_complete
-        # Create a new thread to run our async code
-        import concurrent.futures
-        import threading
-
-        def run_in_new_loop():
-            # Create a new event loop for this thread
-            new_loop = asyncio.new_event_loop()
-            try:
-                return new_loop.run_until_complete(_process_requests())
-            finally:
-                new_loop.close()
-
-        # Run in an executor to avoid blocking the current event loop
-        with concurrent.futures.ThreadPoolExecutor() as pool:
-            return pool.submit(run_in_new_loop).result()
-    else:
-        # We're in a sync context, use the current loop
-        return loop.run_until_complete(_process_requests())
+    return run_async(_process_requests())


 async def main(

matrix/utils/os.py (+28)

@@ -4,6 +4,8 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.

+import asyncio
+import concurrent
 import os
 import select
 import signal
@@ -233,3 +235,29 @@ def lock_file(filepath, mode, timeout=10, poll_interval=0.1):
                 f"Could not acquire lock for {filepath} within {timeout} seconds."
             )
         time.sleep(poll_interval)
+
+
+def run_async(coro: tp.Awaitable[tp.Any]) -> tp.Any:
+    """
+    Run an async coroutine from a synchronous context.
+    Handles cases where an event loop is already running (e.g., Jupyter, FastAPI).
+    """
+    try:
+        loop = asyncio.get_event_loop()
+    except RuntimeError:
+        loop = asyncio.new_event_loop()
+        asyncio.set_event_loop(loop)
+
+    if loop.is_running():
+
+        def run_in_new_loop():
+            new_loop = asyncio.new_event_loop()
+            try:
+                return new_loop.run_until_complete(coro)
+            finally:
+                new_loop.close()
+
+        with concurrent.futures.ThreadPoolExecutor() as pool:
+            return pool.submit(run_in_new_loop).result()
+    else:
+        return loop.run_until_complete(coro)
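Usage sketch for the new helper; the fetch_value coroutine is a made-up stand-in for coroutines such as endpoint_cache() or _process_requests().

import asyncio

from matrix.utils.os import run_async

async def fetch_value():
    await asyncio.sleep(0.1)
    return 42

# From plain synchronous code, run_async drives the coroutine on an event loop.
print(run_async(fetch_value()))  # 42

# If a loop is already running (e.g. Jupyter or a Ray Serve / FastAPI handler),
# run_async runs the coroutine on a fresh loop in a worker thread instead of
# failing the way a bare asyncio.run() call would.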

pyproject.toml (+1 -1)

@@ -23,7 +23,7 @@ dependencies = [
     "pyyaml",
     "portalocker",
     "boto3",
-    "google-genai==1.9.0",
+    "google-genai>=1.13.0",
     "datasketch",
 ]
 # zip_safe = false
