10
10
from contextlib import contextmanager
11
11
from datetime import datetime
12
12
from typing import Any
13
- from typing import ContextManager
13
+ from typing import AsyncContextManager
14
14
15
15
import asyncpg # type: ignore
16
16
import boto3
46
46
from onyx .utils .logger import setup_logger
47
47
from shared_configs .configs import MULTI_TENANT
48
48
from shared_configs .configs import POSTGRES_DEFAULT_SCHEMA
49
+ from shared_configs .configs import POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE
49
50
from shared_configs .configs import TENANT_ID_PREFIX
50
51
from shared_configs .contextvars import CURRENT_TENANT_ID_CONTEXTVAR
51
52
from shared_configs .contextvars import get_current_tenant_id
@@ -189,45 +190,6 @@ class SqlEngine:
189
190
_lock : threading .Lock = threading .Lock ()
190
191
_app_name : str = POSTGRES_UNKNOWN_APP_NAME
191
192
192
- # NOTE(rkuo) - this appears to be unused, clean it up?
193
- # @classmethod
194
- # def _init_engine(cls, **engine_kwargs: Any) -> Engine:
195
- # connection_string = build_connection_string(
196
- # db_api=SYNC_DB_API, app_name=cls._app_name + "_sync", use_iam=USE_IAM_AUTH
197
- # )
198
-
199
- # # Start with base kwargs that are valid for all pool types
200
- # final_engine_kwargs: dict[str, Any] = {}
201
-
202
- # if POSTGRES_USE_NULL_POOL:
203
- # # if null pool is specified, then we need to make sure that
204
- # # we remove any passed in kwargs related to pool size that would
205
- # # cause the initialization to fail
206
- # final_engine_kwargs.update(engine_kwargs)
207
-
208
- # final_engine_kwargs["poolclass"] = pool.NullPool
209
- # if "pool_size" in final_engine_kwargs:
210
- # del final_engine_kwargs["pool_size"]
211
- # if "max_overflow" in final_engine_kwargs:
212
- # del final_engine_kwargs["max_overflow"]
213
- # else:
214
- # final_engine_kwargs["pool_size"] = 20
215
- # final_engine_kwargs["max_overflow"] = 5
216
- # final_engine_kwargs["pool_pre_ping"] = POSTGRES_POOL_PRE_PING
217
- # final_engine_kwargs["pool_recycle"] = POSTGRES_POOL_RECYCLE
218
-
219
- # # any passed in kwargs override the defaults
220
- # final_engine_kwargs.update(engine_kwargs)
221
-
222
- # logger.info(f"Creating engine with kwargs: {final_engine_kwargs}")
223
- # # echo=True here for inspecting all emitted db queries
224
- # engine = create_engine(connection_string, **final_engine_kwargs)
225
-
226
- # if USE_IAM_AUTH:
227
- # event.listen(engine, "do_connect", provide_iam_token)
228
-
229
- # return engine
230
-
231
193
@classmethod
232
194
def init_engine (
233
195
cls ,
@@ -410,18 +372,6 @@ def provide_iam_token_async(
410
372
return _ASYNC_ENGINE
411
373
412
374
413
- # Listen for events on the synchronous Session class
414
- @event .listens_for (Session , "after_begin" )
415
- def _set_search_path (
416
- session : Session , transaction : Any , connection : Any , * args : Any , ** kwargs : Any
417
- ) -> None :
418
- """Every time a new transaction is started,
419
- set the search_path from the session's info."""
420
- tenant_id = session .info .get ("tenant_id" )
421
- if tenant_id :
422
- connection .exec_driver_sql (f'SET search_path = "{ tenant_id } "' )
423
-
424
-
425
375
engine = get_sqlalchemy_async_engine ()
426
376
AsyncSessionLocal = sessionmaker ( # type: ignore
427
377
bind = engine ,
@@ -430,33 +380,6 @@ def _set_search_path(
430
380
)
431
381
432
382
433
- @asynccontextmanager
434
- async def get_async_session_with_tenant (
435
- tenant_id : str | None = None ,
436
- ) -> AsyncGenerator [AsyncSession , None ]:
437
- if tenant_id is None :
438
- tenant_id = get_current_tenant_id ()
439
-
440
- if not is_valid_schema_name (tenant_id ):
441
- logger .error (f"Invalid tenant ID: { tenant_id } " )
442
- raise ValueError ("Invalid tenant ID" )
443
-
444
- async with AsyncSessionLocal () as session :
445
- session .sync_session .info ["tenant_id" ] = tenant_id
446
-
447
- if POSTGRES_IDLE_SESSIONS_TIMEOUT :
448
- await session .execute (
449
- text (
450
- f"SET idle_in_transaction_session_timeout = { POSTGRES_IDLE_SESSIONS_TIMEOUT } "
451
- )
452
- )
453
-
454
- try :
455
- yield session
456
- finally :
457
- pass
458
-
459
-
460
383
@contextmanager
461
384
def get_session_with_current_tenant () -> Generator [Session , None , None ]:
462
385
tenant_id = get_current_tenant_id ()
@@ -474,17 +397,24 @@ def get_session_with_shared_schema() -> Generator[Session, None, None]:
474
397
CURRENT_TENANT_ID_CONTEXTVAR .reset (token )
475
398
476
399
400
def _set_search_path_on_checkout__listener(
    dbapi_conn: Any, connection_record: Any, connection_proxy: Any
) -> None:
    """Checkout listener that points the connection at the current tenant schema.

    Reads the tenant from get_current_tenant_id() and issues `SET search_path`
    through a cursor on the raw DBAPI connection. Nothing is executed when the
    tenant id is empty or fails is_valid_schema_name.
    """
    schema = get_current_tenant_id()
    if not schema or not is_valid_schema_name(schema):
        return

    with dbapi_conn.cursor() as cur:
        cur.execute(f'SET search_path TO "{schema}"')
408
+
409
+
477
410
@contextmanager
478
411
def get_session_with_tenant (* , tenant_id : str ) -> Generator [Session , None , None ]:
479
412
"""
480
413
Generate a database session for a specific tenant.
481
414
"""
482
- if tenant_id is None :
483
- tenant_id = POSTGRES_DEFAULT_SCHEMA
484
-
485
415
engine = get_sqlalchemy_engine ()
486
416
487
- event .listen (engine , "checkout" , set_search_path_on_checkout )
417
+ event .listen (engine , "checkout" , _set_search_path_on_checkout__listener )
488
418
489
419
if not is_valid_schema_name (tenant_id ):
490
420
raise HTTPException (status_code = 400 , detail = "Invalid tenant ID" )
@@ -519,57 +449,84 @@ def get_session_with_tenant(*, tenant_id: str) -> Generator[Session, None, None]
519
449
cursor .close ()
520
450
521
451
522
- def set_search_path_on_checkout (
523
- dbapi_conn : Any , connection_record : Any , connection_proxy : Any
524
- ) -> None :
452
def get_session() -> Generator[Session, None, None]:
    """Yield a database session for the current tenant; for FastAPI `Depends`.

    Adds validation on top of get_session_context_manager (the two should
    likely be merged in the future).

    Raises:
        BasicAuthenticationError: MULTI_TENANT is set and the current tenant
            is still POSTGRES_DEFAULT_SCHEMA.
        HTTPException: 400 when the current tenant id is not a valid schema
            name.
    """
    current_tenant = get_current_tenant_id()

    # multi-tenant deployments reject requests still pointing at the default schema
    if MULTI_TENANT and current_tenant == POSTGRES_DEFAULT_SCHEMA:
        raise BasicAuthenticationError(detail="User must authenticate")

    if not is_valid_schema_name(current_tenant):
        raise HTTPException(status_code=400, detail="Invalid tenant ID")

    with get_session_context_manager() as session:
        yield session
530
466
531
- def get_session_generator_with_tenant () -> Generator [Session , None , None ]:
467
+
468
@contextmanager
def get_session_context_manager() -> Generator[Session, None, None]:
    """Context manager yielding a session scoped to the current tenant.

    Thin wrapper: resolves the tenant via get_current_tenant_id() and
    delegates to get_session_with_tenant.
    """
    with get_session_with_tenant(tenant_id=get_current_tenant_id()) as db_session:
        yield db_session
535
474
536
475
537
- def get_session () -> Generator [Session , None , None ]:
538
- tenant_id = get_current_tenant_id ()
539
- if tenant_id == POSTGRES_DEFAULT_SCHEMA and MULTI_TENANT :
540
- raise BasicAuthenticationError (detail = "User must authenticate" )
476
def _set_search_path_on_transaction__listener(
    session: Session, transaction: Any, connection: Any, *args: Any, **kwargs: Any
) -> None:
    """Session "after_begin" listener: apply the tenant stored in session.info.

    Looks up "tenant_id" in session.info and, when it is present and
    non-empty, runs `SET search_path` on the transaction's connection.

    The value is interpolated into the statement as-is, so whoever populates
    session.info["tenant_id"] must have validated it beforehand.
    """
    schema = session.info.get("tenant_id")
    if not schema:
        return

    connection.exec_driver_sql(f'SET search_path = "{schema}"')
541
484
542
- engine = get_sqlalchemy_engine ()
543
485
544
- with Session (engine , expire_on_commit = False ) as session :
545
- if MULTI_TENANT :
546
- if not is_valid_schema_name (tenant_id ):
547
- raise HTTPException (status_code = 400 , detail = "Invalid tenant ID" )
548
- session .execute (text (f'SET search_path = "{ tenant_id } "' ))
549
- yield session
486
async def get_async_session(
    tenant_id: str | None = None,
) -> AsyncGenerator[AsyncSession, None]:
    """For use w/ Depends for *async* FastAPI endpoints.

    For standard `async with ... as ...` use, use get_async_session_context_manager.

    Args:
        tenant_id: Schema to run against. When None, falls back to
            get_current_tenant_id().

    Yields:
        An AsyncSession bound to the async engine, with its search path set
        to the tenant schema where that is required.

    Raises:
        HTTPException: 400 if the tenant id is not a valid schema name.
    """
    if tenant_id is None:
        tenant_id = get_current_tenant_id()

    # validate before any DB work: the tenant id is interpolated into SQL
    # below and is also handed to the after_begin listener
    if not is_valid_schema_name(tenant_id):
        raise HTTPException(status_code=400, detail="Invalid tenant ID")

    # don't need to set the search path for self-hosted + default schema
    # this is also true for sync sessions, but just not adding it there for
    # now to simplify / not change too much
    needs_search_path = (
        MULTI_TENANT or tenant_id != POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE
    )

    engine = get_sqlalchemy_async_engine()

    async with AsyncSession(engine, expire_on_commit=False) as async_session:
        if needs_search_path:
            # the after_begin listener reads the tenant from session.info;
            # without this entry it never issues SET search_path
            async_session.sync_session.info["tenant_id"] = tenant_id

        # set the search path on sync session as well to be extra safe
        event.listen(
            async_session.sync_session,
            "after_begin",
            _set_search_path_on_transaction__listener,
        )

        if POSTGRES_IDLE_SESSIONS_TIMEOUT:
            await async_session.execute(
                text(
                    f"SET idle_in_transaction_session_timeout = {POSTGRES_IDLE_SESSIONS_TIMEOUT}"
                )
            )

        if needs_search_path:
            await async_session.execute(text(f'SET search_path = "{tenant_id}"'))

        yield async_session
567
524
568
- def get_session_factory () -> sessionmaker [ Session ]:
569
- global SessionFactory
570
- if SessionFactory is None :
571
- SessionFactory = sessionmaker ( bind = get_sqlalchemy_engine ())
572
- return SessionFactory
525
+
526
def get_async_session_context_manager(
    tenant_id: str | None = None,
) -> AsyncContextManager[AsyncSession]:
    """Wrap get_async_session for use with `async with ... as session`.

    tenant_id is forwarded unchanged to get_async_session.
    """
    as_context_manager = asynccontextmanager(get_async_session)
    return as_context_manager(tenant_id)
573
530
574
531
575
532
async def warm_up_connections (
0 commit comments