Source code for sqlalchemy_helpers.flask_ext
# SPDX-FileCopyrightText: 2023 Contributors to the Fedora Project
#
# SPDX-License-Identifier: LGPL-3.0-or-later
"""
Flask integration of database management.
"""
import os
from typing import Any, Callable, cast, TypeVar
import click
from flask import abort, current_app, Flask, has_app_context
from flask.cli import AppGroup
from sqlalchemy.orm import DeclarativeBase, Session
from sqlalchemy.sql.expression import Select
from werkzeug.utils import find_modules, import_string
from .manager import DatabaseManager, SyncResult
def _get_manager(
    engine_args: dict[str, Any] | None = None, app: Flask | None = None
) -> DatabaseManager:
    """Build a :class:`DatabaseManager` from a Flask application's settings.

    Args:
        engine_args: optional arguments passed to the manager as ``engine_args``.
        app: the Flask application to read the settings from. The current
            application is used when it is not given.

    Returns:
        A new manager built from the ``SQLALCHEMY_DATABASE_URI`` and
        ``DB_ALEMBIC_LOCATION`` settings and the base model that the extension
        stored in ``app.extensions``.
    """
    flask_app = app or current_app
    settings = flask_app.config
    return DatabaseManager(
        settings["SQLALCHEMY_DATABASE_URI"],
        settings["DB_ALEMBIC_LOCATION"],
        engine_args=engine_args,
        base_model=flask_app.extensions[DatabaseExtension._app_base_model_name],
    )
def _syncdb() -> None:
"""Run :meth:`DatabaseManager.sync` on the command-line."""
manager = _get_manager()
result = manager.sync()
if result == SyncResult.CREATED:
click.echo("Database created.")
elif result == SyncResult.UPGRADED:
click.echo("Database upgraded.")
elif result == SyncResult.ALREADY_UP_TO_DATE:
click.echo("Database already up-to-date.")
else:
click.echo(f"Unexpected sync result: {result}", err=True)
# Ref: https://flask.palletsprojects.com/en/2.0.x/extensiondev/
[docs]
class DatabaseExtension:
    """A Flask extension to configure the database manager according to the app's configuration.

    It cleans up database connections at the end of the requests, and creates the CLI endpoint to
    sync the database schema.

    Args:
        app: the Flask application to initialize right away, or ``None`` to call
            :meth:`init_app` later.
        base_model: the base model passed to the database manager.
    """

    # Keys under which the per-application state is stored in ``app.extensions``.
    _app_manager_name = "_sqlah_database_manager"
    _app_base_model_name = "_sqlah_base_model"

    def __init__(self, app: Flask | None = None, base_model: DeclarativeBase | None = None):
        self.app = app
        self._base_model = base_model
        if app is not None:
            self.init_app(app, base_model=self._base_model)

    def init_app(self, app: Flask, base_model: DeclarativeBase | None = None) -> None:
        """Initialize the extension on the provided Flask app.

        Sets the configuration defaults, registers the request hooks and the ``db sync``
        CLI command, and imports the modules that define the models.

        Args:
            app (flask.Flask): the Flask application.
            base_model: the base model to use for this app. Falls back to the one given to
                the constructor.
        """
        base_model = base_model or self._base_model

        # Set config defaults
        app.config.setdefault("SQLALCHEMY_DATABASE_URI", "sqlite:///:memory:")
        app.config.setdefault("DB_ALEMBIC_LOCATION", os.path.join(app.root_path, "migrations"))
        main_module = app.import_name.removesuffix(".app")
        app.config.setdefault("DB_MODELS_LOCATION", f"{main_module}.models")

        # Connect hook
        app.before_request(self.before_request)
        # Disconnect hook
        app.teardown_appcontext(self.teardown)

        # Store the base_model
        app.extensions[self._app_base_model_name] = base_model

        # CLI
        db_cli = AppGroup("db", help="Database operations.")
        db_cli.command("sync", help="Create or migrate the database.")(_syncdb)
        app.cli.add_command(db_cli)

        # Import all modules here that might define models so that
        # they will be registered properly on the metadata.
        models_location = app.config["DB_MODELS_LOCATION"]
        models_root = import_string(models_location)
        # Only packages have a ``__path__``. When the location is just a module, importing it
        # above was enough. Checking explicitly (rather than catching the ``ValueError`` that
        # ``find_modules`` raises for a non-package) keeps a ``ValueError`` raised while
        # importing a models module from being swallowed, which would silently skip models.
        if getattr(models_root, "__path__", None) is not None:
            for module in find_modules(models_location, include_packages=True, recursive=True):
                import_string(module)

    def teardown(self, exception: BaseException | None) -> None:
        """Release the database session when the application context is torn down.

        Calls ``Session.remove()`` on the current app's manager, if one has been created.

        Args:
            exception: the exception passed by Flask to teardown callbacks; unused.
        """
        if self._app_manager_name in current_app.extensions:
            current_app.extensions[self._app_manager_name].Session.remove()

    def before_request(self) -> None:
        """Prepare the database manager at the start of each request.

        This is necessary to allow access to the ``Model.get_*`` methods.
        """
        # Just create the manager
        self.manager  # noqa: B018

    @property
    def session(self) -> Session:
        """sqlalchemy.session.Session: the database Session instance to use."""
        return self.manager.Session()

    @property
    def manager(self) -> DatabaseManager:
        """DatabaseManager: the instance of the database manager.

        It is created on first access and then cached in the current app's ``extensions``.
        """
        if self._app_manager_name not in current_app.extensions:
            current_app.extensions[self._app_manager_name] = _get_manager()
        return cast(DatabaseManager, current_app.extensions[self._app_manager_name])
# View helpers
# Type variable linking the model class (or selected entity) passed to the helpers
# below with the type of the record they return.
M = TypeVar("M")
[docs]
def get_or_404(Model: type[M], pk: Any, description: str | None = None) -> M:
    """Fetch a record by primary key, aborting with a 404 error when it does not exist.

    Args:
        Model: a model class providing a ``get_by_pk`` method.
        pk: the primary key of the desired record.
        description: a message for the 404 error if not found.

    Returns:
        The record returned by ``Model.get_by_pk(pk)``.
    """
    record: M | None = Model.get_by_pk(pk)  # type: ignore
    if record is not None:
        return record
    abort(404, description=description)
[docs]
def first_or_404(
    query: Select[tuple[M]], description: str | None = None, session: Session | None = None
) -> M:
    """Like ``session.scalars(query).first()`` but aborts with 404 if not found.

    Args:
        query: a query to retrieve.
        description: a message for the 404 error if no records are found.
        session: the database session, or ``None`` to use the default session.

    Returns:
        The first result of the query.
    """
    if session is None:
        extensions = current_app.extensions
        manager_key = DatabaseExtension._app_manager_name
        # The manager is normally created by ``DatabaseExtension.before_request``. When that
        # hook has not run (e.g. in an app context outside of a request), create it on demand
        # the same way ``DatabaseExtension.manager`` does instead of failing with a KeyError.
        if manager_key not in extensions:
            extensions[manager_key] = _get_manager()
        manager: DatabaseManager = extensions[manager_key]
        session = manager.Session()
    rv = session.scalars(query).first()
    if rv is None:
        abort(404, description=description)
    return rv
# Useful in alembic's env.py
[docs]
def get_url_from_app(app_factory: Callable[..., Flask]) -> str:
    """Return the database URI found in the application's configuration.

    When no application context is active, the application is created by calling the
    factory. This is useful in Alembic's ``env.py``.

    Args:
        app_factory: the Flask application factory, to be called if this function is
            called outside of an application context.

    Returns:
        The value of the ``SQLALCHEMY_DATABASE_URI`` setting.
    """
    app = current_app if has_app_context() else app_factory()
    return cast(str, app.config["SQLALCHEMY_DATABASE_URI"])