10
10
import time
11
11
import asyncio
12
12
import concurrent .futures
13
+ import random
13
14
from typing import List , Dict , Any , Optional , Union
14
15
from pathlib import Path
15
16
from google .cloud import bigquery , storage
@@ -55,14 +56,22 @@ def __init__(
55
56
self .bq_client = None
56
57
self .storage_client = None
57
58
58
- # Client pools for scaling
59
+ # Client pools for scaling - Use a list for random access instead of pop/append
59
60
self .bq_client_pool = []
60
- self .bq_pool_lock = asyncio .Lock ()
61
+
62
+ # Semaphore to prevent too many concurrent BigQuery connections
63
+ # This is more efficient than a lock for high-concurrency scenarios
64
+ self .bq_pool_semaphore = asyncio .Semaphore (bq_pool_size )
61
65
self .bq_pool_initialized = False
62
66
63
- # Thread pool for executing BigQuery operations
67
+ # Use an optimized thread pool with a larger size to handle concurrency better
68
+ # Setting max_workers higher than client pool to allow for better parallelism
69
+ thread_pool_size = max (
70
+ bq_pool_size * 2 , 40
71
+ ) # At least 2x the client pool or 40
64
72
self .thread_pool = concurrent .futures .ThreadPoolExecutor (
65
- max_workers = self .bq_pool_size , thread_name_prefix = "bq_worker"
73
+ max_workers = thread_pool_size ,
74
+ thread_name_prefix = "bq_worker" ,
66
75
)
67
76
68
77
# Initialize credentials if provided
@@ -107,13 +116,16 @@ def _initialize_bq_client_pool(self) -> None:
107
116
logger .error ("Cannot initialize BigQuery client pool: No credentials" )
108
117
return
109
118
110
- # Create multiple BigQuery clients
119
+ # Create multiple BigQuery clients with optimized settings
111
120
for i in range (self .bq_pool_size ):
121
+ # Configure each client with optimized settings for high concurrency
112
122
client = bigquery .Client (
113
123
credentials = self .credentials ,
114
124
project = self .project_id ,
115
- # Configure BigQuery client for better performance
116
- # These settings help manage resource usage under load
125
+ # Configure client with connection pooling settings
126
+ client_options = bigquery .ClientOptions (
127
+ api_endpoint = "https://bigquery.googleapis.com" ,
128
+ ),
117
129
)
118
130
self .bq_client_pool .append (client )
119
131
@@ -127,23 +139,28 @@ def _initialize_bq_client_pool(self) -> None:
127
139
128
140
async def get_bq_client(self):
    """
    Get a BigQuery client while holding one permit of the pool semaphore.

    One permit of ``self.bq_pool_semaphore`` is acquired before *any* client
    is returned, including the fallback to ``self.bq_client``. The permit is
    intentionally not released here: the caller hands the semaphore to the
    query executor, which releases it once the query has finished.

    Returns:
        A randomly chosen client from ``self.bq_client_pool``, or
        ``self.bq_client`` when the pool is not initialized or is empty.
    """
    # Acquire before looking at the pool so that every return path leaves the
    # caller holding exactly one permit. The caller releases unconditionally;
    # returning a fallback client without a permit would push the semaphore
    # above its configured capacity and silently weaken the concurrency limit.
    await self.bq_pool_semaphore.acquire()

    if not self.bq_pool_initialized:
        # Fall back to the single client if the pool isn't initialized
        return self.bq_client

    if not self.bq_client_pool:
        logger.error("BigQuery client pool is empty")
        return self.bq_client

    # Random selection spreads load without mutating the shared list, so no
    # lock is needed around the pool itself. The pool is known to be
    # non-empty here and there is no await in between, so choice() cannot fail.
    return random.choice(self.bq_client_pool)
147
164
148
165
def download_file_from_gcs (
149
166
self ,
@@ -187,13 +204,15 @@ def _execute_bigquery_search(
187
204
client ,
188
205
query : str ,
189
206
job_config = None ,
207
+ semaphore = None ,
190
208
) -> pd .DataFrame :
191
209
"""
192
210
Execute a BigQuery query with a specific client
193
211
Args:
194
212
client: BigQuery client to use
195
213
query: Query to execute
196
214
job_config: Optional job configuration
215
+ semaphore: Optional semaphore to release after execution
197
216
Returns:
198
217
DataFrame with query results
199
218
"""
@@ -202,59 +221,66 @@ def _execute_bigquery_search(
202
221
max_retries = 3
203
222
retry_delay = 0.5 # Start with 0.5 second delay
204
223
205
- while retry_count <= max_retries :
206
- try :
207
- # Execute query with timeout and retry settings
208
- query_job = client .query (query , job_config = job_config )
209
-
210
- # Set a timeout for the query execution to prevent hanging
211
- timeout = 25 # seconds
212
- start_wait = time .time ()
224
+ try :
225
+ while retry_count <= max_retries :
226
+ try :
227
+ # Execute query with timeout and retry settings
228
+ query_job = client .query (query , job_config = job_config )
213
229
214
- # Wait for the job to complete with timeout
215
- while not query_job . done () and ( time . time () - start_wait ) < timeout :
216
- time .sleep ( 0.1 )
230
+ # Set a timeout for the query execution to prevent hanging
231
+ timeout = 25 # seconds
232
+ start_wait = time .time ( )
217
233
218
- if not query_job .done ():
219
- raise TimeoutError (f"Query execution timed out after { timeout } s" )
234
+ # Wait for the job to complete with timeout
235
+ while not query_job .done () and (time .time () - start_wait ) < timeout :
236
+ time .sleep (0.1 )
220
237
221
- # Check for errors
222
- if query_job .errors :
223
- raise Exception (f"Query failed with errors: { query_job .errors } " )
238
+ if not query_job .done ():
239
+ raise TimeoutError (
240
+ f"Query execution timed out after { timeout } s"
241
+ )
224
242
225
- # Convert to DataFrame
226
- results = query_job .to_dataframe ()
243
+ # Check for errors
244
+ if query_job .errors :
245
+ raise Exception (f"Query failed with errors: { query_job .errors } " )
227
246
228
- duration = time .time () - start_time
229
- logger .info (
230
- f"BigQuery query execution took { duration * 1000 :.2f} ms after { retry_count } retries"
231
- )
232
- return results
247
+ # Convert to DataFrame
248
+ results = query_job .to_dataframe ()
233
249
234
- except Exception as e :
235
- retry_count += 1
236
- if retry_count > max_retries :
237
250
duration = time .time () - start_time
238
- logger .error (
239
- f"BigQuery query failed after { duration * 1000 :.2f} ms and { retry_count - 1 } retries: { e } "
251
+ logger .info (
252
+ f"BigQuery query execution took { duration * 1000 :.2f} ms after { retry_count } retries"
240
253
)
241
- raise
242
-
243
- # Implement exponential backoff
244
- sleep_time = retry_delay * (
245
- 2 ** (retry_count - 1 )
246
- ) # Exponential backoff
247
- logger .warning (
248
- f"BigQuery query attempt { retry_count } failed: { e } . Retrying in { sleep_time :.2f} s..."
249
- )
250
- time .sleep (sleep_time )
254
+ return results
255
+
256
+ except Exception as e :
257
+ retry_count += 1
258
+ if retry_count > max_retries :
259
+ duration = time .time () - start_time
260
+ logger .error (
261
+ f"BigQuery query failed after { duration * 1000 :.2f} ms and { retry_count - 1 } retries: { e } "
262
+ )
263
+ raise
264
+
265
+ # Implement exponential backoff
266
+ sleep_time = retry_delay * (
267
+ 2 ** (retry_count - 1 )
268
+ ) # Exponential backoff
269
+ logger .warning (
270
+ f"BigQuery query attempt { retry_count } failed: { e } . Retrying in { sleep_time :.2f} s..."
271
+ )
272
+ time .sleep (sleep_time )
273
+ finally :
274
+ # Always release the semaphore, even if an exception occurred
275
+ if semaphore :
276
+ semaphore .release ()
251
277
252
278
async def bigquery_vector_search_async (
253
279
self ,
254
280
embedding : List [float ],
255
281
top_k : int = 5 ,
256
282
distance_type : str = "COSINE" ,
257
- options : str = '{"fraction_lists_to_search": 0.1 , "use_brute_force": false}' ,
283
+ options : str = '{"fraction_lists_to_search": 0.15 , "use_brute_force": false}' ,
258
284
) -> pd .DataFrame :
259
285
"""
260
286
Perform vector search in BigQuery asynchronously
@@ -268,13 +294,14 @@ async def bigquery_vector_search_async(
268
294
"""
269
295
start_time = time .time ()
270
296
271
- # Get client from pool
297
+ # Get client from pool (will acquire semaphore)
272
298
client = await self .get_bq_client ()
273
299
274
300
# Convert embedding to string for SQL query
275
301
embedding_str = "[" + ", " .join (str (x ) for x in embedding ) + "]"
276
302
277
- # Construct query with timeout settings
303
+ # Optimize query for better caching and performance
304
+ # Use query parameters for better cache performance
278
305
query = f"""
279
306
SELECT
280
307
base.text,
@@ -303,13 +330,14 @@ async def bigquery_vector_search_async(
303
330
},
304
331
)
305
332
306
- # Execute in thread pool
333
+ # Execute in thread pool, passing the semaphore to be released after execution
307
334
results = await asyncio .get_event_loop ().run_in_executor (
308
335
self .thread_pool ,
309
336
self ._execute_bigquery_search ,
310
337
client ,
311
338
query ,
312
339
job_config ,
340
+ self .bq_pool_semaphore , # Pass semaphore to release after execution
313
341
)
314
342
315
343
# Calculate total time including query construction
@@ -321,6 +349,9 @@ async def bigquery_vector_search_async(
321
349
322
350
return results
323
351
except Exception as e :
352
+ # Release semaphore in case of error
353
+ self .bq_pool_semaphore .release ()
354
+
324
355
# Log error with timing information
325
356
duration = time .time () - start_time
326
357
logger .error (
@@ -333,7 +364,7 @@ def bigquery_vector_search(
333
364
embedding : List [float ],
334
365
top_k : int = 5 ,
335
366
distance_type : str = "COSINE" ,
336
- options : str = '{"fraction_lists_to_search": 0.1 , "use_brute_force": false}' ,
367
+ options : str = '{"fraction_lists_to_search": 0.15 , "use_brute_force": false}' ,
337
368
) -> pd .DataFrame :
338
369
"""
339
370
Perform vector search in BigQuery (synchronous wrapper for the async version)
@@ -409,6 +440,11 @@ def bigquery_vector_search(
409
440
410
441
async def close(self, drain_timeout: Optional[float] = 120.0):
    """
    Clean up resources when shutting down.

    First tries to take every permit of ``self.bq_pool_semaphore`` so that no
    new query can start, then shuts the worker thread pool down and waits for
    running work to finish. The permits taken here are kept, as before.

    Args:
        drain_timeout: Maximum number of seconds to wait for all permits.
            ``None`` waits indefinitely (the previous behavior). The default
            exceeds the worst case of one query with retries as configured
            in ``_execute_bigquery_search`` (4 attempts x 25 s plus backoff).
    """

    async def _drain_permits():
        # Holding all bq_pool_size permits means no query holds one.
        for _ in range(self.bq_pool_size):
            await self.bq_pool_semaphore.acquire()

    try:
        # Bounded wait: permits are released from worker threads, and asyncio
        # primitives are not thread-safe, so a wake-up can be missed or a
        # permit leaked. Without a limit, shutdown could hang forever.
        await asyncio.wait_for(_drain_permits(), timeout=drain_timeout)
    except asyncio.TimeoutError:
        logger.warning(
            f"Timed out after {drain_timeout}s waiting for in-progress "
            "BigQuery queries; shutting down the thread pool anyway"
        )

    if self.thread_pool:
        # shutdown(wait=True) blocks until the workers exit; run it off the
        # event loop so other coroutines keep making progress meanwhile.
        loop = asyncio.get_running_loop()
        await loop.run_in_executor(
            None, lambda: self.thread_pool.shutdown(wait=True)
        )
        logger.info("Shut down BigQuery thread pool")
0 commit comments