Migrate database metadata to declarative base

This commit is contained in:
Raoul Snyman 2023-03-13 21:06:39 -07:00
parent 1fc1bd4124
commit 3a72f07520
3 changed files with 56 additions and 47 deletions

View File

@ -26,16 +26,27 @@ import json
import logging import logging
import os import os
from copy import copy from copy import copy
from pathlib import Path
from types import ModuleType
from typing import Optional, Tuple, Union
from urllib.parse import quote_plus as urlquote from urllib.parse import quote_plus as urlquote
from alembic.migration import MigrationContext from alembic.migration import MigrationContext
from alembic.operations import Operations from alembic.operations import Operations
from sqlalchemy import Column, ForeignKey, Integer, MetaData, Table, Unicode, UnicodeText, create_engine, types from sqlalchemy import Column, ForeignKey, MetaData, create_engine
from sqlalchemy.engine.url import URL, make_url from sqlalchemy.engine.url import URL, make_url
from sqlalchemy.exc import DBAPIError, InvalidRequestError, OperationalError, ProgrammingError, SQLAlchemyError from sqlalchemy.exc import DBAPIError, InvalidRequestError, OperationalError, ProgrammingError, SQLAlchemyError
from sqlalchemy.ext.declarative import declared_attr from sqlalchemy.orm import Session, backref, relationship, scoped_session, sessionmaker
from sqlalchemy.orm import backref, mapper, relationship, scoped_session, sessionmaker
from sqlalchemy.pool import NullPool from sqlalchemy.pool import NullPool
from sqlalchemy.types import Integer, TypeDecorator, Unicode, UnicodeText
# Maintain backwards compatibility with older versions of SQLAlchemy while supporting SQLAlchemy 1.4+
try:
from sqlalchemy.orm import declarative_base, declared_attr
from sqlalchemy.orm.decl_api import DeclarativeMeta
except ImportError:
from sqlalchemy.ext.declarative import declarative_base, declared_attr
from sqlalchemy.ext.declarative.api import DeclarativeMeta
from openlp.core.common import delete_file from openlp.core.common import delete_file
from openlp.core.common.applocation import AppLocation from openlp.core.common.applocation import AppLocation
@ -70,7 +81,7 @@ def _get_scalar_result(engine, sql):
return conn.scalar(sql) return conn.scalar(sql)
def _sqlite_file_exists(database): def _sqlite_file_exists(database: str) -> bool:
if not os.path.isfile(database) or os.path.getsize(database) < 100: if not os.path.isfile(database) or os.path.getsize(database) < 100:
return False return False
@ -139,7 +150,8 @@ def database_exists(url):
return False return False
def init_db(url, auto_flush=True, auto_commit=False, base=None): def init_db(url: str, auto_flush: bool = True, auto_commit: bool = False, base: Optional[DeclarativeMeta] = None) \
-> Tuple[Session, MetaData]:
""" """
Initialise and return the session and metadata for a database Initialise and return the session and metadata for a database
@ -158,7 +170,7 @@ def init_db(url, auto_flush=True, auto_commit=False, base=None):
return session, metadata return session, metadata
def get_db_path(plugin_name, db_file_name=None): def get_db_path(plugin_name: str, db_file_name: Union[Path, str, None] = None) -> str:
""" """
Create a path to a database from the plugin name and database name Create a path to a database from the plugin name and database name
@ -176,7 +188,7 @@ def get_db_path(plugin_name, db_file_name=None):
return 'sqlite:///{path}/{name}'.format(path=AppLocation.get_section_data_path(plugin_name), name=db_file_name) return 'sqlite:///{path}/{name}'.format(path=AppLocation.get_section_data_path(plugin_name), name=db_file_name)
def handle_db_error(plugin_name, db_file_path): def handle_db_error(plugin_name: str, db_file_path: Path):
""" """
Log and report to the user that a database cannot be loaded Log and report to the user that a database cannot be loaded
@ -191,7 +203,7 @@ def handle_db_error(plugin_name, db_file_path):
'OpenLP cannot load your database.\n\nDatabase: {db}').format(db=db_path)) 'OpenLP cannot load your database.\n\nDatabase: {db}').format(db=db_path))
def init_url(plugin_name, db_file_name=None): def init_url(plugin_name: str, db_file_name: Union[Path, str, None] = None) -> str:
""" """
Construct the connection string for a database. Construct the connection string for a database.
@ -214,7 +226,7 @@ def init_url(plugin_name, db_file_name=None):
return db_url return db_url
def get_upgrade_op(session): def get_upgrade_op(session: Session) -> Operations:
""" """
Create a migration context and an operations object for performing upgrades. Create a migration context and an operations object for performing upgrades.
@ -282,12 +294,12 @@ class BaseModel(object):
return instance return instance
class PathType(types.TypeDecorator): class PathType(TypeDecorator):
""" """
Create a PathType for storing Path objects with SQLAlchemy. Behind the scenes we convert the Path object to a JSON Create a PathType for storing Path objects with SQLAlchemy. Behind the scenes we convert the Path object to a JSON
representation and store it as a Unicode type representation and store it as a Unicode type
""" """
impl = types.Unicode impl = Unicode
cache_ok = True cache_ok = True
def coerce_compared_value(self, op, value): def coerce_compared_value(self, op, value):
@ -329,34 +341,32 @@ class PathType(types.TypeDecorator):
return json.loads(value, cls=OpenLPJSONDecoder, base_path=data_path) return json.loads(value, cls=OpenLPJSONDecoder, base_path=data_path)
def upgrade_db(url, upgrade): def upgrade_db(url: str, upgrade: ModuleType) -> Tuple[int, int]:
""" """
Upgrade a database. Upgrade a database.
:param url: The url of the database to upgrade. :param url: The url of the database to upgrade.
:param upgrade: The python module that contains the upgrade instructions. :param upgrade: The python module that contains the upgrade instructions.
""" """
log.debug('Checking upgrades for DB {db}'.format(db=url))
if not database_exists(url): if not database_exists(url):
log.warning("Database {db} doesn't exist - skipping upgrade checks".format(db=url)) log.warning("Database {db} doesn't exist - skipping upgrade checks".format(db=url))
return 0, 0 return 0, 0
log.debug('Checking upgrades for DB {db}'.format(db=url)) Base = declarative_base(MetaData)
session, metadata = init_db(url) class Metadata(Base):
class Metadata(BaseModel):
""" """
Provides a class for the metadata table. Provides a class for the metadata table.
""" """
pass __tablename__ = 'metadata'
key = Column(Unicode(64), primary_key=True)
value = Column(UnicodeText(), default=None)
session, metadata = init_db(url, base=Base)
metadata.create_all(checkfirst=True)
metadata_table = Table(
'metadata', metadata,
Column('key', types.Unicode(64), primary_key=True),
Column('value', types.UnicodeText(), default=None)
)
metadata_table.create(checkfirst=True)
mapper(Metadata, metadata_table)
version_meta = session.query(Metadata).get('version') version_meta = session.query(Metadata).get('version')
if version_meta: if version_meta:
version = int(version_meta.value) version = int(version_meta.value)
@ -364,7 +374,7 @@ def upgrade_db(url, upgrade):
# Due to issues with other checks, if the version is not set in the DB then default to 0 # Due to issues with other checks, if the version is not set in the DB then default to 0
# and let the upgrade function handle the checks # and let the upgrade function handle the checks
version = 0 version = 0
version_meta = Metadata.populate(key='version', value=version) version_meta = Metadata(key='version', value=version)
session.add(version_meta) session.add(version_meta)
session.commit() session.commit()
if version > upgrade.__version__: if version > upgrade.__version__:
@ -387,7 +397,7 @@ def upgrade_db(url, upgrade):
'"upgrade_{version:d}", upgrade process has been halted.'.format(version=version)) '"upgrade_{version:d}", upgrade process has been halted.'.format(version=version))
break break
except (SQLAlchemyError, DBAPIError): except (SQLAlchemyError, DBAPIError):
version_meta = Metadata.populate(key='version', value=int(upgrade.__version__)) version_meta = Metadata(key='version', value=int(upgrade.__version__))
session.commit() session.commit()
upgrade_version = upgrade.__version__ upgrade_version = upgrade.__version__
version = int(version_meta.value) version = int(version_meta.value)
@ -395,7 +405,7 @@ def upgrade_db(url, upgrade):
return version, upgrade_version return version, upgrade_version
def delete_database(plugin_name, db_file_name=None): def delete_database(plugin_name: str, db_file_name: Optional[str] = None):
""" """
Remove a database file from the system. Remove a database file from the system.
@ -429,11 +439,8 @@ class Manager(object):
self.is_dirty = False self.is_dirty = False
self.session = None self.session = None
self.db_url = None self.db_url = None
if db_file_path: log.debug('Manager: Creating new DB url')
log.debug('Manager: Creating new DB url') self.db_url = init_url(plugin_name, db_file_path)
self.db_url = init_url(plugin_name, str(db_file_path)) # TOdO :PATHLIB
else:
self.db_url = init_url(plugin_name)
if not session: if not session:
try: try:
self.session = init_schema(self.db_url) self.session = init_schema(self.db_url)

View File

@ -53,11 +53,13 @@ def plugin_manager_env(registry, state):
Registry().register('settings', MagicMock()) Registry().register('settings', MagicMock())
def test_bootstrap_initialise(settings, state): @patch('openlp.core.lib.pluginmanager.Plugin.__subclasses__')
def test_bootstrap_initialise(mocked_subclasses, settings, state):
""" """
Test the PluginManager.bootstrap_initialise() method Test the PluginManager.bootstrap_initialise() method
""" """
# GIVEN: A plugin manager with some mocked out methods # GIVEN: A plugin manager with some mocked out methods
mocked_subclasses.return_value = [MagicMock()]
State().add_service('mediacontroller', 0) State().add_service('mediacontroller', 0)
State().update_pre_conditions('mediacontroller', True) State().update_pre_conditions('mediacontroller', True)
manager = PluginManager() manager = PluginManager()

View File

@ -85,6 +85,20 @@ def add_records(projector_db, test):
return added return added
@pytest.fixture()
def projector(temp_folder, settings):
"""
Set up anything necessary for all tests
"""
tmpdb_url = 'sqlite:///{db}'.format(db=os.path.join(temp_folder, TEST_DB))
with patch('openlp.core.projectors.db.init_url') as mocked_init_url:
mocked_init_url.return_value = tmpdb_url
proj = ProjectorDB()
yield proj
proj.session.close()
del proj
def test_upgrade_old_projector_db(temp_folder): def test_upgrade_old_projector_db(temp_folder):
""" """
Test that we can upgrade a version 1 db to the current schema Test that we can upgrade a version 1 db to the current schema
@ -102,20 +116,6 @@ def test_upgrade_old_projector_db(temp_folder):
assert updated_to_version == latest_version, 'The projector DB should have been upgrade to the latest version' assert updated_to_version == latest_version, 'The projector DB should have been upgrade to the latest version'
@pytest.fixture()
def projector(temp_folder, settings):
"""
Set up anything necessary for all tests
"""
tmpdb_url = 'sqlite:///{db}'.format(db=os.path.join(temp_folder, TEST_DB))
with patch('openlp.core.projectors.db.init_url') as mocked_init_url:
mocked_init_url.return_value = tmpdb_url
proj = ProjectorDB()
yield proj
proj.session.close()
del proj
def test_find_record_by_ip(projector): def test_find_record_by_ip(projector):
""" """
Test find record by IP Test find record by IP