Commit 38dabef

[fix] concurrency issues for bq
1 parent 70cb4b5 commit 38dabef

3 files changed (+158, -66 lines)

src_deploy/servers/server_fastapi.py (+39)

@@ -97,6 +97,45 @@ async def lifespan(app: FastAPI):
     )


+# Add middleware to limit concurrent requests
+# This helps prevent overwhelming BigQuery with too many concurrent requests
+class ConcurrencyLimitMiddleware:
+    def __init__(self, app, max_concurrent_requests=100):
+        self.app = app
+        self.semaphore = asyncio.Semaphore(max_concurrent_requests)
+        self.active_requests = 0
+        self.lock = asyncio.Lock()
+        self.max_concurrent = max_concurrent_requests
+        logger.info(
+            f"Initialized concurrency limiter with max {max_concurrent_requests} concurrent requests"
+        )
+
+    async def __call__(self, scope, receive, send):
+        if scope["type"] != "http":
+            await self.app(scope, receive, send)
+            return
+
+        async with self.semaphore:
+            async with self.lock:
+                self.active_requests += 1
+                current = self.active_requests
+
+            try:
+                await self.app(scope, receive, send)
+            finally:
+                async with self.lock:
+                    self.active_requests -= 1
+
+
+# Calculate optimal concurrency based on configuration
+bq_pool_size = int(config.get_moderation_service_config().get("BQ_POOL_SIZE", 20))
+# Set server concurrency limit to 3x the BigQuery pool size for optimal throughput
+max_concurrent_requests = bq_pool_size * 3
+app.add_middleware(
+    ConcurrencyLimitMiddleware, max_concurrent_requests=max_concurrent_requests
+)
+
+
 async def get_service() -> ModerationService:
     """
     Dependency to get the moderation service
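
For context, a minimal standalone sketch of the request-limiting idea above; this is not code from the commit. The trimmed-down limiter class, the /ping endpoint, and the httpx-based driver are all assumptions made for the demo.

# Illustrative sketch: a simplified ASGI concurrency limiter exercised via httpx's ASGI transport.
import asyncio
import httpx
from fastapi import FastAPI

class ConcurrencyLimit:
    """Simplified middleware: at most N requests inside the app at once."""
    def __init__(self, app, max_concurrent_requests=3):
        self.app = app
        self.semaphore = asyncio.Semaphore(max_concurrent_requests)

    async def __call__(self, scope, receive, send):
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return
        async with self.semaphore:          # excess requests queue here
            await self.app(scope, receive, send)

app = FastAPI()

@app.get("/ping")
async def ping():
    await asyncio.sleep(0.1)                # stand-in for a BigQuery-bound call
    return {"ok": True}

app.add_middleware(ConcurrencyLimit, max_concurrent_requests=3)

async def main():
    transport = httpx.ASGITransport(app=app)
    async with httpx.AsyncClient(transport=transport, base_url="http://test") as client:
        # 10 concurrent requests, but only 3 run inside the app at any moment
        responses = await asyncio.gather(*(client.get("/ping") for _ in range(10)))
        print([r.status_code for r in responses])

asyncio.run(main())

Requests beyond the limit wait on the semaphore instead of all reaching BigQuery at once, which is the behaviour the middleware in the diff is after.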

src_deploy/services/moderation_service.py (+20, -3)

@@ -12,6 +12,7 @@
 import re
 import numpy as np
 import time
+import json
 from typing import Dict, Any, Optional, Union, List
 from pathlib import Path
 from dataclasses import dataclass

@@ -110,10 +111,15 @@ async def initialize(self) -> bool:

         # Initialize GCP utils
         gcp_credentials = self.config.get("GCP_CREDENTIALS")
-        # Configure BigQuery connection pool size from config or use a reasonable default
-        bq_pool_size = int(self.config.get("BQ_POOL_SIZE", 20))
+
+        # Calculate optimal BigQuery pool size based on system resources and config
+        # Default pool size is now calculated based on expected concurrency
+        cpu_count = os.cpu_count() or 4
+        default_pool_size = min(cpu_count * 5, 40)  # Scale with CPU but cap at 40
+        bq_pool_size = int(self.config.get("BQ_POOL_SIZE", default_pool_size))
+
         logger.info(
-            f"Initializing GCP utils with BigQuery pool size: {bq_pool_size}"
+            f"Initializing GCP utils with BigQuery pool size: {bq_pool_size} (CPU cores: {cpu_count})"
         )

         self.gcp_utils = GCPUtils(
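
As a quick illustration of the sizing heuristic introduced above (a sketch, not repo code):

# Sketch of the default-pool-size heuristic: scale with CPU count, cap at 40.
import os

def default_bq_pool_size() -> int:
    cpu_count = os.cpu_count() or 4   # fall back to 4 if the core count is undetectable
    return min(cpu_count * 5, 40)     # e.g. 2 cores -> 10, 8 cores -> 40, 32 cores -> 40 (capped)

print(default_bq_pool_size())
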
@@ -526,9 +532,20 @@ async def moderate_content(self, request: ModerationRequest) -> ModerationRespon
         # 2. Get similar examples using BigQuery vector search
         # Use the new async BigQuery implementation directly
         bigquery_start = time.time()
+
+        # Optimize vector search options based on concurrency
+        # Adjust search parameters for better performance under load
+        vector_search_options = {
+            # Increase search fraction for better recall at high concurrency
+            "fraction_lists_to_search": 0.15,
+            # Don't use brute force by default for better scalability
+            "use_brute_force": False,
+        }
+
         similar_examples = await self.gcp_utils.bigquery_vector_search_async(
             embedding=embedding_list,
             top_k=request.num_examples,
+            options=json.dumps(vector_search_options),
         )
         bigquery_time_ms = (time.time() - bigquery_start) * 1000
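
A small sketch of how the options dict above becomes the JSON string that bigquery_vector_search_async expects (the same shape as the default string visible in gcp_utils.py below); the values come from the diff, the rest is illustrative.

# Sketch: serializing the vector-search options to the JSON string the helper takes.
import json

vector_search_options = {
    "fraction_lists_to_search": 0.15,  # search a larger fraction of lists for better recall
    "use_brute_force": False,          # keep approximate search for scalability
}

options_str = json.dumps(vector_search_options)
print(options_str)  # {"fraction_lists_to_search": 0.15, "use_brute_force": false}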

src_deploy/utils/gcp_utils.py (+99, -63)

@@ -10,6 +10,7 @@
 import time
 import asyncio
 import concurrent.futures
+import random
 from typing import List, Dict, Any, Optional, Union
 from pathlib import Path
 from google.cloud import bigquery, storage

@@ -55,14 +56,22 @@ def __init__(
         self.bq_client = None
         self.storage_client = None

-        # Client pools for scaling
+        # Client pools for scaling - Use a list for random access instead of pop/append
         self.bq_client_pool = []
-        self.bq_pool_lock = asyncio.Lock()
+
+        # Semaphore to prevent too many concurrent BigQuery connections
+        # This is more efficient than a lock for high-concurrency scenarios
+        self.bq_pool_semaphore = asyncio.Semaphore(bq_pool_size)
         self.bq_pool_initialized = False

-        # Thread pool for executing BigQuery operations
+        # Use an optimized thread pool with a larger size to handle concurrency better
+        # Setting max_workers higher than client pool to allow for better parallelism
+        thread_pool_size = max(
+            bq_pool_size * 2, 40
+        )  # At least 2x the client pool or 40
         self.thread_pool = concurrent.futures.ThreadPoolExecutor(
-            max_workers=self.bq_pool_size, thread_name_prefix="bq_worker"
+            max_workers=thread_pool_size,
+            thread_name_prefix="bq_worker",
         )

         # Initialize credentials if provided

@@ -107,13 +116,16 @@ def _initialize_bq_client_pool(self) -> None:
             logger.error("Cannot initialize BigQuery client pool: No credentials")
             return

-        # Create multiple BigQuery clients
+        # Create multiple BigQuery clients with optimized settings
         for i in range(self.bq_pool_size):
+            # Configure each client with optimized settings for high concurrency
             client = bigquery.Client(
                 credentials=self.credentials,
                 project=self.project_id,
-                # Configure BigQuery client for better performance
-                # These settings help manage resource usage under load
+                # Configure client with connection pooling settings
+                client_options=bigquery.ClientOptions(
+                    api_endpoint="https://bigquery.googleapis.com",
+                ),
             )
             self.bq_client_pool.append(client)

@@ -127,23 +139,28 @@ def _initialize_bq_client_pool(self) -> None:

     async def get_bq_client(self):
         """
-        Get a BigQuery client from the pool
+        Get a BigQuery client from the pool with semaphore protection
         Returns:
             A BigQuery client from the pool
         """
-        async with self.bq_pool_lock:
-            if not self.bq_pool_initialized:
-                # Fall back to the single client if pool isn't initialized
-                return self.bq_client
+        if not self.bq_pool_initialized:
+            # Fall back to the single client if pool isn't initialized
+            return self.bq_client

-            if not self.bq_client_pool:
-                logger.error("BigQuery client pool is empty")
-                return self.bq_client
+        if not self.bq_client_pool:
+            logger.error("BigQuery client pool is empty")
+            return self.bq_client

-            # Simple round-robin selection from the pool
-            client = self.bq_client_pool.pop(0)
-            self.bq_client_pool.append(client)
-            return client
+        # Acquire semaphore to limit concurrent access
+        await self.bq_pool_semaphore.acquire()
+        try:
+            # Use random client selection instead of round-robin to avoid lock contention
+            client_index = random.randrange(len(self.bq_client_pool))
+            return self.bq_client_pool[client_index]
+        except Exception as e:
+            logger.error(f"Error selecting BigQuery client: {e}")
+            return self.bq_client
+        # Don't release semaphore here - it will be released after query execution

     def download_file_from_gcs(
         self,
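
A self-contained sketch of the checkout pattern get_bq_client relies on: acquire a semaphore slot, hand out a randomly chosen client, and free the slot only once the work is done. Here acquire and release live in the same coroutine for simplicity, whereas the commit releases from the query executor; the ClientPool name and the string "clients" are made up for the demo.

# Sketch: semaphore-guarded client checkout with random selection.
import asyncio
import random

class ClientPool:
    def __init__(self, clients):
        self.clients = clients
        self.semaphore = asyncio.Semaphore(len(clients))

    async def checkout(self):
        await self.semaphore.acquire()      # blocks once every slot is in use
        return random.choice(self.clients)  # random pick avoids round-robin bookkeeping

    def release(self):
        self.semaphore.release()

async def run_query(pool, i):
    client = await pool.checkout()
    try:
        await asyncio.sleep(0.05)           # stand-in for the BigQuery call
        return f"query {i} via {client}"
    finally:
        pool.release()                      # always free the slot

async def main():
    pool = ClientPool(["client-a", "client-b", "client-c"])
    print(await asyncio.gather(*(run_query(pool, i) for i in range(6))))

asyncio.run(main())
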
@@ -187,13 +204,15 @@ def _execute_bigquery_search(
         client,
         query: str,
         job_config=None,
+        semaphore=None,
     ) -> pd.DataFrame:
         """
         Execute a BigQuery query with a specific client
         Args:
             client: BigQuery client to use
             query: Query to execute
             job_config: Optional job configuration
+            semaphore: Optional semaphore to release after execution
         Returns:
             DataFrame with query results
         """

@@ -202,59 +221,66 @@ def _execute_bigquery_search(
         max_retries = 3
         retry_delay = 0.5  # Start with 0.5 second delay

-        while retry_count <= max_retries:
-            try:
-                # Execute query with timeout and retry settings
-                query_job = client.query(query, job_config=job_config)
-
-                # Set a timeout for the query execution to prevent hanging
-                timeout = 25  # seconds
-                start_wait = time.time()
+        try:
+            while retry_count <= max_retries:
+                try:
+                    # Execute query with timeout and retry settings
+                    query_job = client.query(query, job_config=job_config)

-                # Wait for the job to complete with timeout
-                while not query_job.done() and (time.time() - start_wait) < timeout:
-                    time.sleep(0.1)
+                    # Set a timeout for the query execution to prevent hanging
+                    timeout = 25  # seconds
+                    start_wait = time.time()

-                if not query_job.done():
-                    raise TimeoutError(f"Query execution timed out after {timeout}s")
+                    # Wait for the job to complete with timeout
+                    while not query_job.done() and (time.time() - start_wait) < timeout:
+                        time.sleep(0.1)

-                # Check for errors
-                if query_job.errors:
-                    raise Exception(f"Query failed with errors: {query_job.errors}")
+                    if not query_job.done():
+                        raise TimeoutError(
+                            f"Query execution timed out after {timeout}s"
+                        )

-                # Convert to DataFrame
-                results = query_job.to_dataframe()
+                    # Check for errors
+                    if query_job.errors:
+                        raise Exception(f"Query failed with errors: {query_job.errors}")

-                duration = time.time() - start_time
-                logger.info(
-                    f"BigQuery query execution took {duration*1000:.2f}ms after {retry_count} retries"
-                )
-                return results
+                    # Convert to DataFrame
+                    results = query_job.to_dataframe()

-            except Exception as e:
-                retry_count += 1
-                if retry_count > max_retries:
                     duration = time.time() - start_time
-                    logger.error(
-                        f"BigQuery query failed after {duration*1000:.2f}ms and {retry_count-1} retries: {e}"
+                    logger.info(
+                        f"BigQuery query execution took {duration*1000:.2f}ms after {retry_count} retries"
                     )
-                    raise
-
-                # Implement exponential backoff
-                sleep_time = retry_delay * (
-                    2 ** (retry_count - 1)
-                )  # Exponential backoff
-                logger.warning(
-                    f"BigQuery query attempt {retry_count} failed: {e}. Retrying in {sleep_time:.2f}s..."
-                )
-                time.sleep(sleep_time)
+                    return results
+
+            except Exception as e:
+                retry_count += 1
+                if retry_count > max_retries:
+                    duration = time.time() - start_time
+                    logger.error(
+                        f"BigQuery query failed after {duration*1000:.2f}ms and {retry_count-1} retries: {e}"
+                    )
+                    raise
+
+                # Implement exponential backoff
+                sleep_time = retry_delay * (
+                    2 ** (retry_count - 1)
+                )  # Exponential backoff
+                logger.warning(
+                    f"BigQuery query attempt {retry_count} failed: {e}. Retrying in {sleep_time:.2f}s..."
+                )
+                time.sleep(sleep_time)
+        finally:
+            # Always release the semaphore, even if an exception occurred
+            if semaphore:
+                semaphore.release()

     async def bigquery_vector_search_async(
         self,
         embedding: List[float],
         top_k: int = 5,
         distance_type: str = "COSINE",
-        options: str = '{"fraction_lists_to_search": 0.1, "use_brute_force": false}',
+        options: str = '{"fraction_lists_to_search": 0.15, "use_brute_force": false}',
     ) -> pd.DataFrame:
         """
         Perform vector search in BigQuery asynchronously
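
For reference, the backoff schedule the retry loop above produces with its constants (retry_delay = 0.5, max_retries = 3); just the arithmetic, not repo code.

# Sketch: delays from retry_delay * 2 ** (retry_count - 1) with the constants above.
retry_delay = 0.5
for retry_count in range(1, 4):
    print(f"attempt {retry_count} failed -> sleep {retry_delay * 2 ** (retry_count - 1):.1f}s")
# attempt 1 failed -> sleep 0.5s
# attempt 2 failed -> sleep 1.0s
# attempt 3 failed -> sleep 2.0s
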
@@ -268,13 +294,14 @@ async def bigquery_vector_search_async(
         """
         start_time = time.time()

-        # Get client from pool
+        # Get client from pool (will acquire semaphore)
         client = await self.get_bq_client()

         # Convert embedding to string for SQL query
         embedding_str = "[" + ", ".join(str(x) for x in embedding) + "]"

-        # Construct query with timeout settings
+        # Optimize query for better caching and performance
+        # Use query parameters for better cache performance
         query = f"""
         SELECT
             base.text,

@@ -303,13 +330,14 @@
                 },
             )

-            # Execute in thread pool
+            # Execute in thread pool, passing the semaphore to be released after execution
             results = await asyncio.get_event_loop().run_in_executor(
                 self.thread_pool,
                 self._execute_bigquery_search,
                 client,
                 query,
                 job_config,
+                self.bq_pool_semaphore,  # Pass semaphore to release after execution
             )

             # Calculate total time including query construction

@@ -321,6 +349,9 @@

             return results
         except Exception as e:
+            # Release semaphore in case of error
+            self.bq_pool_semaphore.release()
+
             # Log error with timing information
             duration = time.time() - start_time
             logger.error(

@@ -333,7 +364,7 @@ def bigquery_vector_search(
         embedding: List[float],
         top_k: int = 5,
         distance_type: str = "COSINE",
-        options: str = '{"fraction_lists_to_search": 0.1, "use_brute_force": false}',
+        options: str = '{"fraction_lists_to_search": 0.15, "use_brute_force": false}',
     ) -> pd.DataFrame:
         """
         Perform vector search in BigQuery (synchronous wrapper for the async version)

@@ -409,6 +440,11 @@ def bigquery_vector_search(

     async def close(self):
         """Clean up resources when shutting down"""
+        # Wait for any in-progress queries to complete
+        # by acquiring the full semaphore capacity
+        for _ in range(self.bq_pool_size):
+            await self.bq_pool_semaphore.acquire()
+
         if self.thread_pool:
-            self.thread_pool.shutdown(wait=False)
+            self.thread_pool.shutdown(wait=True)
         logger.info("Shut down BigQuery thread pool")
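
Finally, a standalone sketch of the drain-then-shutdown idea close() uses above: acquiring every semaphore slot only succeeds once all in-flight holders have released, after which the executor can be shut down safely. All names here are illustrative.

# Sketch: draining a semaphore before shutdown so in-flight work finishes first.
import asyncio

async def worker(sem: asyncio.Semaphore, i: int):
    async with sem:
        await asyncio.sleep(0.1)          # stand-in for an in-flight query
        print(f"query {i} finished")

async def drain_and_shutdown(sem: asyncio.Semaphore, capacity: int):
    # Acquiring every slot blocks until all holders have released,
    # and prevents any new work from starting afterwards.
    for _ in range(capacity):
        await sem.acquire()
    print("drained - safe to shut down the thread pool")

async def main():
    capacity = 3
    sem = asyncio.Semaphore(capacity)
    tasks = [asyncio.create_task(worker(sem, i)) for i in range(capacity)]
    await asyncio.sleep(0)                # let the workers grab their slots
    await drain_and_shutdown(sem, capacity)
    await asyncio.gather(*tasks)

asyncio.run(main())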
