Source code for fmn.database.main

# SPDX-FileCopyrightText: Contributors to the Fedora Project
#
# SPDX-License-Identifier: MIT

from sqlalchemy import MetaData, create_engine, select
from sqlalchemy.engine import URL, Engine, make_url
from sqlalchemy.exc import NoResultFound
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine
from sqlalchemy.orm import declarative_base, sessionmaker

from ..core.config import get_settings


# use custom base for common convenience methods
class CustomBase:
    """Base class for the declarative models, adding async lookup helpers."""

    @classmethod
    async def async_get(cls, db_session: AsyncSession, **attrs) -> "Base":
        """Get an object from the database.

        :param db_session: The SQLAlchemy session to use
        :param attrs: Attribute values the object must match; they are
            passed on to ``filter_by()``
        :return: the object
        :raises sqlalchemy.exc.NoResultFound: if no object matches
        :raises sqlalchemy.exc.MultipleResultsFound: if more than one
            object matches
        """
        # scalar_one() insists on exactly one row, see the exceptions above.
        return (await db_session.execute(select(cls).filter_by(**attrs))).scalar_one()

    @classmethod
    async def async_get_or_create(cls, db_session: AsyncSession, **attrs) -> "Base":
        """Get an object from the database or create it if missing.

        :param db_session: The SQLAlchemy session to use
        :param attrs: Attribute values used both to look the object up and,
            if it doesn't exist, to construct it
        :return: the object

        The returned object will have an (ephemeral) boolean attribute
        ``_obj_created`` which allows finding out if it existed previously
        or not.

        A newly created object is added to the session and flushed, but the
        session is not committed here.
        """
        try:
            obj = await cls.async_get(db_session, **attrs)
        except NoResultFound:
            obj = cls(**attrs)
            db_session.add(obj)
            obj._obj_created = True
            await db_session.flush()
        else:
            obj._obj_created = False

        return obj
# Custom metadata carrying an explicit naming convention for indexes (ix),
# unique (uq), check (ck), foreign key (fk) and primary key (pk) constraints.
naming_convention = dict(
    ix="%(column_0_N_label)s_index",
    uq="%(table_name)s_%(column_0_N_name)s_key",
    ck="%(table_name)s_%(constraint_name)s_check",
    fk="%(table_name)s_%(column_0_N_name)s_%(referred_table_name)s_fkey",
    pk="%(table_name)s_pkey",
)
metadata = MetaData(naming_convention=naming_convention)

# Declarative base built on CustomBase and the metadata above.
Base = declarative_base(metadata=metadata, cls=CustomBase)

# Session factories are created without an engine; init_async_model() and
# init_sync_model() bind one later through configure().
async_session_maker = sessionmaker(
    future=True,
    expire_on_commit=False,
    class_=AsyncSession,
)
sync_session_maker = sessionmaker(
    expire_on_commit=False,
    future=True,
)
def init_sync_model(sync_engine: Engine | None = None) -> None:
    """Bind the synchronous session factory to a database engine.

    :param sync_engine: The engine to bind ``sync_session_maker`` to. If
        omitted, a new engine is obtained from ``get_sync_engine()``.
    """
    # Compare against None explicitly: only a missing engine should trigger
    # the creation of a default one.
    if sync_engine is None:
        sync_engine = get_sync_engine()
    sync_session_maker.configure(bind=sync_engine)
async def init_async_model(async_engine: AsyncEngine | None = None) -> None:
    """Bind the asynchronous session factory to a database engine.

    This is a coroutine function even though it awaits nothing itself, so
    it has to be awaited by the caller.

    :param async_engine: The engine to bind ``async_session_maker`` to. If
        omitted, a new engine is obtained from ``get_async_engine()``.
    """
    # Compare against None explicitly: only a missing engine should trigger
    # the creation of a default one.
    if async_engine is None:
        async_engine = get_async_engine()
    async_session_maker.configure(bind=async_engine)
def get_sync_engine():
    """Create a synchronous SQLAlchemy engine from the application settings.

    The keyword arguments for ``create_engine()`` are taken from the
    ``database.sqlalchemy`` section of the settings. If that section does
    not set ``isolation_level``, ``SERIALIZABLE`` is used.

    :return: the result of ``create_engine()``
    """
    settings = get_settings().dict()
    engine_kwargs = settings["database"]["sqlalchemy"]
    if "isolation_level" not in engine_kwargs:
        engine_kwargs["isolation_level"] = "SERIALIZABLE"
    return create_engine(**engine_kwargs)
def _async_from_sync_url(url: URL | str) -> URL:
    """Create an async DB URL from a conventional one.

    Only the dialect part of the driver name is kept; any driver named in
    the original URL is replaced with the asyncio driver for that dialect.

    :param url: The conventional (synchronous) database URL
    :return: a copy of the URL using the asyncio driver
    :raises ValueError: if no asyncio driver is known for the dialect
    """
    async_drivers = {
        "sqlite": "aiosqlite",
        "postgresql": "asyncpg",
    }

    parsed = make_url(url)
    # "dialect+driver" -> "dialect"; a bare "dialect" is kept as it is.
    dialect = parsed.drivername.partition("+")[0]

    driver = async_drivers.get(dialect)
    if driver is None:
        raise ValueError(f"Don't know asyncio driver for dialect {dialect}")

    return parsed.set(drivername=f"{dialect}+{driver}")
def get_async_engine():
    """Create an asynchronous SQLAlchemy engine from the application settings.

    The keyword arguments for ``create_async_engine()`` are taken from the
    ``database.sqlalchemy`` section of the settings. If that section does
    not set ``isolation_level``, ``SERIALIZABLE`` is used, and the
    configured URL is converted with ``_async_from_sync_url()``.

    :return: the result of ``create_async_engine()``
    """
    settings = get_settings().dict()
    engine_kwargs = settings["database"]["sqlalchemy"]
    if "isolation_level" not in engine_kwargs:
        engine_kwargs["isolation_level"] = "SERIALIZABLE"
    # The settings hold a conventional URL; swap in the asyncio driver.
    async_url = _async_from_sync_url(engine_kwargs["url"])
    engine_kwargs["url"] = async_url
    return create_async_engine(**engine_kwargs)