Skip to content

Commit cc21604

Browse files
committed
Fix custom schema
1 parent 511d0e7 commit cc21604

File tree

5 files changed

+86
-128
lines changed

5 files changed

+86
-128
lines changed

backend/ee/onyx/server/saml.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from onyx.db.auth import get_user_count
2929
from onyx.db.auth import get_user_db
3030
from onyx.db.engine import get_async_session
31+
from onyx.db.engine import get_async_session_context_manager
3132
from onyx.db.engine import get_session
3233
from onyx.db.models import User
3334
from onyx.utils.logger import setup_logger
@@ -49,13 +50,10 @@ async def upsert_saml_user(email: str) -> User:
4950
Identity Provider, but we need a valid password to satisfy system requirements.
5051
"""
5152
logger.debug(f"Attempting to upsert SAML user with email: {email}")
52-
get_async_session_context = contextlib.asynccontextmanager(
53-
get_async_session
54-
) # type:ignore
5553
get_user_db_context = contextlib.asynccontextmanager(get_user_db)
5654
get_user_manager_context = contextlib.asynccontextmanager(get_user_manager)
5755

58-
async with get_async_session_context() as session:
56+
async with get_async_session_context_manager() as session:
5957
async with get_user_db_context(session) as user_db:
6058
async with get_user_manager_context(user_db) as user_manager:
6159
try:

backend/onyx/auth/users.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@
9292
from onyx.db.auth import get_user_db
9393
from onyx.db.auth import SQLAlchemyUserAdminDB
9494
from onyx.db.engine import get_async_session
95-
from onyx.db.engine import get_async_session_with_tenant
95+
from onyx.db.engine import get_async_session_context_manager
9696
from onyx.db.engine import get_session_with_tenant
9797
from onyx.db.models import AccessToken
9898
from onyx.db.models import OAuthAccount
@@ -253,7 +253,7 @@ async def get_by_email(self, user_email: str) -> User:
253253
tenant_id = fetch_ee_implementation_or_noop(
254254
"onyx.server.tenants.user_mapping", "get_tenant_id_for_email", None
255255
)(user_email)
256-
async with get_async_session_with_tenant(tenant_id) as db_session:
256+
async with get_async_session_context_manager(tenant_id) as db_session:
257257
if MULTI_TENANT:
258258
tenant_user_db = SQLAlchemyUserAdminDB[User, uuid.UUID](
259259
db_session, User, OAuthAccount
@@ -296,7 +296,7 @@ async def create(
296296
)
297297
user: User
298298

299-
async with get_async_session_with_tenant(tenant_id) as db_session:
299+
async with get_async_session_context_manager(tenant_id) as db_session:
300300
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
301301
verify_email_is_invited(user_create.email)
302302
verify_email_domain(user_create.email)
@@ -402,7 +402,7 @@ async def oauth_callback(
402402

403403
# Proceed with the tenant context
404404
token = None
405-
async with get_async_session_with_tenant(tenant_id) as db_session:
405+
async with get_async_session_context_manager(tenant_id) as db_session:
406406
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
407407

408408
verify_email_in_whitelist(account_email, tenant_id)
@@ -642,7 +642,7 @@ async def authenticate(
642642
return None
643643

644644
# Create a tenant-specific session
645-
async with get_async_session_with_tenant(tenant_id) as tenant_session:
645+
async with get_async_session_context_manager(tenant_id) as tenant_session:
646646
tenant_user_db: SQLAlchemyUserDatabase = SQLAlchemyUserDatabase(
647647
tenant_session, User
648648
)

backend/onyx/db/auth.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from onyx.auth.schemas import UserRole
1818
from onyx.db.api_key import get_api_key_email_pattern
1919
from onyx.db.engine import get_async_session
20-
from onyx.db.engine import get_async_session_with_tenant
20+
from onyx.db.engine import get_async_session_context_manager
2121
from onyx.db.models import AccessToken
2222
from onyx.db.models import OAuthAccount
2323
from onyx.db.models import User
@@ -55,7 +55,7 @@ def get_total_users_count(db_session: Session) -> int:
5555

5656

5757
async def get_user_count(only_admin_users: bool = False) -> int:
58-
async with get_async_session_with_tenant() as session:
58+
async with get_async_session_context_manager() as session:
5959
count_stmt = func.count(User.id) # type: ignore
6060
stmt = select(count_stmt)
6161
if only_admin_users:

backend/onyx/db/engine.py

+73-116
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from contextlib import contextmanager
1111
from datetime import datetime
1212
from typing import Any
13-
from typing import ContextManager
13+
from typing import AsyncContextManager
1414

1515
import asyncpg # type: ignore
1616
import boto3
@@ -46,6 +46,7 @@
4646
from onyx.utils.logger import setup_logger
4747
from shared_configs.configs import MULTI_TENANT
4848
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
49+
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE
4950
from shared_configs.configs import TENANT_ID_PREFIX
5051
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
5152
from shared_configs.contextvars import get_current_tenant_id
@@ -189,45 +190,6 @@ class SqlEngine:
189190
_lock: threading.Lock = threading.Lock()
190191
_app_name: str = POSTGRES_UNKNOWN_APP_NAME
191192

192-
# NOTE(rkuo) - this appears to be unused, clean it up?
193-
# @classmethod
194-
# def _init_engine(cls, **engine_kwargs: Any) -> Engine:
195-
# connection_string = build_connection_string(
196-
# db_api=SYNC_DB_API, app_name=cls._app_name + "_sync", use_iam=USE_IAM_AUTH
197-
# )
198-
199-
# # Start with base kwargs that are valid for all pool types
200-
# final_engine_kwargs: dict[str, Any] = {}
201-
202-
# if POSTGRES_USE_NULL_POOL:
203-
# # if null pool is specified, then we need to make sure that
204-
# # we remove any passed in kwargs related to pool size that would
205-
# # cause the initialization to fail
206-
# final_engine_kwargs.update(engine_kwargs)
207-
208-
# final_engine_kwargs["poolclass"] = pool.NullPool
209-
# if "pool_size" in final_engine_kwargs:
210-
# del final_engine_kwargs["pool_size"]
211-
# if "max_overflow" in final_engine_kwargs:
212-
# del final_engine_kwargs["max_overflow"]
213-
# else:
214-
# final_engine_kwargs["pool_size"] = 20
215-
# final_engine_kwargs["max_overflow"] = 5
216-
# final_engine_kwargs["pool_pre_ping"] = POSTGRES_POOL_PRE_PING
217-
# final_engine_kwargs["pool_recycle"] = POSTGRES_POOL_RECYCLE
218-
219-
# # any passed in kwargs override the defaults
220-
# final_engine_kwargs.update(engine_kwargs)
221-
222-
# logger.info(f"Creating engine with kwargs: {final_engine_kwargs}")
223-
# # echo=True here for inspecting all emitted db queries
224-
# engine = create_engine(connection_string, **final_engine_kwargs)
225-
226-
# if USE_IAM_AUTH:
227-
# event.listen(engine, "do_connect", provide_iam_token)
228-
229-
# return engine
230-
231193
@classmethod
232194
def init_engine(
233195
cls,
@@ -410,18 +372,6 @@ def provide_iam_token_async(
410372
return _ASYNC_ENGINE
411373

412374

413-
# Listen for events on the synchronous Session class
414-
@event.listens_for(Session, "after_begin")
415-
def _set_search_path(
416-
session: Session, transaction: Any, connection: Any, *args: Any, **kwargs: Any
417-
) -> None:
418-
"""Every time a new transaction is started,
419-
set the search_path from the session's info."""
420-
tenant_id = session.info.get("tenant_id")
421-
if tenant_id:
422-
connection.exec_driver_sql(f'SET search_path = "{tenant_id}"')
423-
424-
425375
engine = get_sqlalchemy_async_engine()
426376
AsyncSessionLocal = sessionmaker( # type: ignore
427377
bind=engine,
@@ -430,33 +380,6 @@ def _set_search_path(
430380
)
431381

432382

433-
@asynccontextmanager
434-
async def get_async_session_with_tenant(
435-
tenant_id: str | None = None,
436-
) -> AsyncGenerator[AsyncSession, None]:
437-
if tenant_id is None:
438-
tenant_id = get_current_tenant_id()
439-
440-
if not is_valid_schema_name(tenant_id):
441-
logger.error(f"Invalid tenant ID: {tenant_id}")
442-
raise ValueError("Invalid tenant ID")
443-
444-
async with AsyncSessionLocal() as session:
445-
session.sync_session.info["tenant_id"] = tenant_id
446-
447-
if POSTGRES_IDLE_SESSIONS_TIMEOUT:
448-
await session.execute(
449-
text(
450-
f"SET idle_in_transaction_session_timeout = {POSTGRES_IDLE_SESSIONS_TIMEOUT}"
451-
)
452-
)
453-
454-
try:
455-
yield session
456-
finally:
457-
pass
458-
459-
460383
@contextmanager
461384
def get_session_with_current_tenant() -> Generator[Session, None, None]:
462385
tenant_id = get_current_tenant_id()
@@ -474,17 +397,24 @@ def get_session_with_shared_schema() -> Generator[Session, None, None]:
474397
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
475398

476399

400+
def _set_search_path_on_checkout__listener(
    dbapi_conn: Any, connection_record: Any, connection_proxy: Any
) -> None:
    """Point a connection at the current tenant's schema.

    Written to the (dbapi_conn, connection_record, connection_proxy) listener
    signature. The tenant ID comes from get_current_tenant_id(); the connection
    is left untouched when that ID is empty or fails is_valid_schema_name.

    Args:
        dbapi_conn: raw DBAPI connection; must expose ``cursor()``.
        connection_record: unused, required by the listener signature.
        connection_proxy: unused, required by the listener signature.
    """
    tenant_id = get_current_tenant_id()

    # Nothing to do without a usable schema name; never interpolate an
    # unvalidated identifier into the statement below.
    if not tenant_id or not is_valid_schema_name(tenant_id):
        return

    with dbapi_conn.cursor() as cur:
        cur.execute(f'SET search_path TO "{tenant_id}"')
408+
409+
477410
@contextmanager
478411
def get_session_with_tenant(*, tenant_id: str) -> Generator[Session, None, None]:
479412
"""
480413
Generate a database session for a specific tenant.
481414
"""
482-
if tenant_id is None:
483-
tenant_id = POSTGRES_DEFAULT_SCHEMA
484-
485415
engine = get_sqlalchemy_engine()
486416

487-
event.listen(engine, "checkout", set_search_path_on_checkout)
417+
event.listen(engine, "checkout", _set_search_path_on_checkout__listener)
488418

489419
if not is_valid_schema_name(tenant_id):
490420
raise HTTPException(status_code=400, detail="Invalid tenant ID")
@@ -519,57 +449,84 @@ def get_session_with_tenant(*, tenant_id: str) -> Generator[Session, None, None]
519449
cursor.close()
520450

521451

522-
def set_search_path_on_checkout(
523-
dbapi_conn: Any, connection_record: Any, connection_proxy: Any
524-
) -> None:
452+
def get_session() -> Generator[Session, None, None]:
    """Yield a sync session for the current tenant (FastAPI ``Depends`` use).

    Adds validation on top of get_session_context_manager, and likely should
    be merged with it in the future.

    Yields:
        The session produced by get_session_context_manager.

    Raises:
        BasicAuthenticationError: when MULTI_TENANT is set and the current
            tenant is still POSTGRES_DEFAULT_SCHEMA.
        HTTPException: status 400 when the tenant ID fails
            is_valid_schema_name.
    """
    tenant_id = get_current_tenant_id()

    if MULTI_TENANT and tenant_id == POSTGRES_DEFAULT_SCHEMA:
        raise BasicAuthenticationError(detail="User must authenticate")
    if not is_valid_schema_name(tenant_id):
        raise HTTPException(status_code=400, detail="Invalid tenant ID")

    with get_session_context_manager() as session:
        yield session
530466

531-
def get_session_generator_with_tenant() -> Generator[Session, None, None]:
467+
468+
@contextlib.contextmanager
def get_session_context_manager() -> Generator[Session, None, None]:
    """Yield a sync database session for the tenant in the current context.

    Thin wrapper: resolves the tenant with get_current_tenant_id() and
    delegates session creation to get_session_with_tenant.
    """
    with get_session_with_tenant(tenant_id=get_current_tenant_id()) as db_session:
        yield db_session
535474

536475

537-
def get_session() -> Generator[Session, None, None]:
538-
tenant_id = get_current_tenant_id()
539-
if tenant_id == POSTGRES_DEFAULT_SCHEMA and MULTI_TENANT:
540-
raise BasicAuthenticationError(detail="User must authenticate")
476+
def _set_search_path_on_transaction__listener(
    session: Session, transaction: Any, connection: Any, *args: Any, **kwargs: Any
) -> None:
    """Set the search_path at the start of each new transaction.

    Written to the "after_begin" session-event signature. The schema is read
    from ``session.info["tenant_id"]``; when that key is missing or falsy the
    connection is left untouched.

    Args:
        session: session whose ``info`` mapping carries the tenant ID.
        transaction: unused, required by the listener signature.
        connection: connection the transaction runs on; receives the SET.
        *args: ignored.
        **kwargs: ignored.
    """
    tenant_id = session.info.get("tenant_id")
    if not tenant_id:
        return

    connection.exec_driver_sql(f'SET search_path = "{tenant_id}"')
541484

542-
engine = get_sqlalchemy_engine()
543485

544-
with Session(engine, expire_on_commit=False) as session:
545-
if MULTI_TENANT:
546-
if not is_valid_schema_name(tenant_id):
547-
raise HTTPException(status_code=400, detail="Invalid tenant ID")
548-
session.execute(text(f'SET search_path = "{tenant_id}"'))
549-
yield session
486+
async def get_async_session(
    tenant_id: str | None = None,
) -> AsyncGenerator[AsyncSession, None]:
    """For use w/ Depends for *async* FastAPI endpoints.

    For standard `async with ... as ...` use, use get_async_session_context_manager.

    Args:
        tenant_id: schema to run against. When None, the value of
            get_current_tenant_id() is used.

    Yields:
        An AsyncSession created with ``expire_on_commit=False`` on the shared
        async engine.

    Raises:
        HTTPException: status 400 when the tenant ID fails
            is_valid_schema_name.
    """
    if tenant_id is None:
        tenant_id = get_current_tenant_id()

    # Validate up front so that no session is opened and no statement is sent
    # on behalf of an invalid tenant.
    if not is_valid_schema_name(tenant_id):
        raise HTTPException(status_code=400, detail="Invalid tenant ID")

    # don't need to set the search path for self-hosted + default schema
    # this is also true for sync sessions, but just not adding it there for
    # now to simplify / not change too much
    needs_search_path = (
        MULTI_TENANT or tenant_id != POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE
    )

    engine = get_sqlalchemy_async_engine()

    async with AsyncSession(engine, expire_on_commit=False) as async_session:
        if needs_search_path:
            # The after_begin listener takes the schema from session.info, so
            # the tenant must be recorded there or the listener does nothing.
            # With it set, every transaction this session begins (including
            # ones started after a commit) re-applies the search path.
            async_session.sync_session.info["tenant_id"] = tenant_id
            event.listen(
                async_session.sync_session,
                "after_begin",
                _set_search_path_on_transaction__listener,
            )

        if POSTGRES_IDLE_SESSIONS_TIMEOUT:
            await async_session.execute(
                text(
                    f"SET idle_in_transaction_session_timeout = {POSTGRES_IDLE_SESSIONS_TIMEOUT}"
                )
            )

        if needs_search_path:
            # Explicit set as well, to be extra safe.
            await async_session.execute(text(f'SET search_path = "{tenant_id}"'))

        yield async_session
567524

568-
def get_session_factory() -> sessionmaker[Session]:
569-
global SessionFactory
570-
if SessionFactory is None:
571-
SessionFactory = sessionmaker(bind=get_sqlalchemy_engine())
572-
return SessionFactory
525+
526+
def get_async_session_context_manager(
    tenant_id: str | None = None,
) -> AsyncContextManager[AsyncSession]:
    """Wrap get_async_session for plain ``async with ... as session`` use.

    Args:
        tenant_id: forwarded unchanged to get_async_session.

    Returns:
        An async context manager that yields the AsyncSession produced by
        get_async_session.
    """
    session_cm_factory = asynccontextmanager(get_async_session)
    return session_cm_factory(tenant_id)
573530

574531

575532
async def warm_up_connections(

backend/shared_configs/configs.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,10 @@ def validate_cors_origin(origin: str) -> None:
142142
# Multi-tenancy configuration
143143
MULTI_TENANT = os.environ.get("MULTI_TENANT", "").lower() == "true"
144144

145-
POSTGRES_DEFAULT_SCHEMA = os.environ.get("POSTGRES_DEFAULT_SCHEMA") or "public"
145+
POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE = "public"
146+
POSTGRES_DEFAULT_SCHEMA = (
147+
os.environ.get("POSTGRES_DEFAULT_SCHEMA") or POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE
148+
)
146149
DEFAULT_REDIS_PREFIX = os.environ.get("DEFAULT_REDIS_PREFIX") or "default"
147150

148151

0 commit comments

Comments
 (0)