Migrate database metadata to declarative base

This commit is contained in:
Raoul Snyman 2023-03-13 21:06:39 -07:00
parent 1fc1bd4124
commit 3a72f07520
3 changed files with 56 additions and 47 deletions

View File

@ -26,16 +26,27 @@ import json
import logging
import os
from copy import copy
from pathlib import Path
from types import ModuleType
from typing import Optional, Tuple, Union
from urllib.parse import quote_plus as urlquote
from alembic.migration import MigrationContext
from alembic.operations import Operations
from sqlalchemy import Column, ForeignKey, Integer, MetaData, Table, Unicode, UnicodeText, create_engine, types
from sqlalchemy import Column, ForeignKey, MetaData, create_engine
from sqlalchemy.engine.url import URL, make_url
from sqlalchemy.exc import DBAPIError, InvalidRequestError, OperationalError, ProgrammingError, SQLAlchemyError
from sqlalchemy.ext.declarative import declared_attr
from sqlalchemy.orm import backref, mapper, relationship, scoped_session, sessionmaker
from sqlalchemy.orm import Session, backref, relationship, scoped_session, sessionmaker
from sqlalchemy.pool import NullPool
from sqlalchemy.types import Integer, TypeDecorator, Unicode, UnicodeText
# Maintain backwards compatibility with older versions of SQLAlchemy while supporting SQLAlchemy 1.4+
try:
from sqlalchemy.orm import declarative_base, declared_attr
from sqlalchemy.orm.decl_api import DeclarativeMeta
except ImportError:
from sqlalchemy.ext.declarative import declarative_base, declared_attr
from sqlalchemy.ext.declarative.api import DeclarativeMeta
from openlp.core.common import delete_file
from openlp.core.common.applocation import AppLocation
@ -70,7 +81,7 @@ def _get_scalar_result(engine, sql):
return conn.scalar(sql)
def _sqlite_file_exists(database):
def _sqlite_file_exists(database: str) -> bool:
if not os.path.isfile(database) or os.path.getsize(database) < 100:
return False
@ -139,7 +150,8 @@ def database_exists(url):
return False
def init_db(url, auto_flush=True, auto_commit=False, base=None):
def init_db(url: str, auto_flush: bool = True, auto_commit: bool = False, base: Optional[DeclarativeMeta] = None) \
-> Tuple[Session, MetaData]:
"""
Initialise and return the session and metadata for a database
@ -158,7 +170,7 @@ def init_db(url, auto_flush=True, auto_commit=False, base=None):
return session, metadata
def get_db_path(plugin_name, db_file_name=None):
def get_db_path(plugin_name: str, db_file_name: Union[Path, str, None] = None) -> str:
"""
Create a path to a database from the plugin name and database name
@ -176,7 +188,7 @@ def get_db_path(plugin_name, db_file_name=None):
return 'sqlite:///{path}/{name}'.format(path=AppLocation.get_section_data_path(plugin_name), name=db_file_name)
def handle_db_error(plugin_name, db_file_path):
def handle_db_error(plugin_name: str, db_file_path: Path):
"""
Log and report to the user that a database cannot be loaded
@ -191,7 +203,7 @@ def handle_db_error(plugin_name, db_file_path):
'OpenLP cannot load your database.\n\nDatabase: {db}').format(db=db_path))
def init_url(plugin_name, db_file_name=None):
def init_url(plugin_name: str, db_file_name: Union[Path, str, None] = None) -> str:
"""
Construct the connection string for a database.
@ -214,7 +226,7 @@ def init_url(plugin_name, db_file_name=None):
return db_url
def get_upgrade_op(session):
def get_upgrade_op(session: Session) -> Operations:
"""
Create a migration context and an operations object for performing upgrades.
@ -282,12 +294,12 @@ class BaseModel(object):
return instance
class PathType(types.TypeDecorator):
class PathType(TypeDecorator):
"""
Create a PathType for storing Path objects with SQLAlchemy. Behind the scenes we convert the Path object to a JSON
representation and store it as a Unicode type
"""
impl = types.Unicode
impl = Unicode
cache_ok = True
def coerce_compared_value(self, op, value):
@ -329,34 +341,32 @@ class PathType(types.TypeDecorator):
return json.loads(value, cls=OpenLPJSONDecoder, base_path=data_path)
def upgrade_db(url, upgrade):
def upgrade_db(url: str, upgrade: ModuleType) -> Tuple[int, int]:
"""
Upgrade a database.
:param url: The url of the database to upgrade.
:param upgrade: The python module that contains the upgrade instructions.
"""
log.debug('Checking upgrades for DB {db}'.format(db=url))
if not database_exists(url):
log.warning("Database {db} doesn't exist - skipping upgrade checks".format(db=url))
return 0, 0
log.debug('Checking upgrades for DB {db}'.format(db=url))
Base = declarative_base(MetaData)
session, metadata = init_db(url)
class Metadata(BaseModel):
class Metadata(Base):
"""
Provides a class for the metadata table.
"""
pass
__tablename__ = 'metadata'
key = Column(Unicode(64), primary_key=True)
value = Column(UnicodeText(), default=None)
session, metadata = init_db(url, base=Base)
metadata.create_all(checkfirst=True)
metadata_table = Table(
'metadata', metadata,
Column('key', types.Unicode(64), primary_key=True),
Column('value', types.UnicodeText(), default=None)
)
metadata_table.create(checkfirst=True)
mapper(Metadata, metadata_table)
version_meta = session.query(Metadata).get('version')
if version_meta:
version = int(version_meta.value)
@ -364,7 +374,7 @@ def upgrade_db(url, upgrade):
# Due to issues with other checks, if the version is not set in the DB then default to 0
# and let the upgrade function handle the checks
version = 0
version_meta = Metadata.populate(key='version', value=version)
version_meta = Metadata(key='version', value=version)
session.add(version_meta)
session.commit()
if version > upgrade.__version__:
@ -387,7 +397,7 @@ def upgrade_db(url, upgrade):
'"upgrade_{version:d}", upgrade process has been halted.'.format(version=version))
break
except (SQLAlchemyError, DBAPIError):
version_meta = Metadata.populate(key='version', value=int(upgrade.__version__))
version_meta = Metadata(key='version', value=int(upgrade.__version__))
session.commit()
upgrade_version = upgrade.__version__
version = int(version_meta.value)
@ -395,7 +405,7 @@ def upgrade_db(url, upgrade):
return version, upgrade_version
def delete_database(plugin_name, db_file_name=None):
def delete_database(plugin_name: str, db_file_name: Optional[str] = None):
"""
Remove a database file from the system.
@ -429,11 +439,8 @@ class Manager(object):
self.is_dirty = False
self.session = None
self.db_url = None
if db_file_path:
log.debug('Manager: Creating new DB url')
self.db_url = init_url(plugin_name, str(db_file_path)) # TOdO :PATHLIB
else:
self.db_url = init_url(plugin_name)
log.debug('Manager: Creating new DB url')
self.db_url = init_url(plugin_name, db_file_path)
if not session:
try:
self.session = init_schema(self.db_url)

View File

@ -53,11 +53,13 @@ def plugin_manager_env(registry, state):
Registry().register('settings', MagicMock())
def test_bootstrap_initialise(settings, state):
@patch('openlp.core.lib.pluginmanager.Plugin.__subclasses__')
def test_bootstrap_initialise(mocked_subclasses, settings, state):
"""
Test the PluginManager.bootstrap_initialise() method
"""
# GIVEN: A plugin manager with some mocked out methods
mocked_subclasses.return_value = [MagicMock()]
State().add_service('mediacontroller', 0)
State().update_pre_conditions('mediacontroller', True)
manager = PluginManager()

View File

@ -85,6 +85,20 @@ def add_records(projector_db, test):
return added
@pytest.fixture()
def projector(temp_folder, settings):
"""
Set up anything necessary for all tests
"""
tmpdb_url = 'sqlite:///{db}'.format(db=os.path.join(temp_folder, TEST_DB))
with patch('openlp.core.projectors.db.init_url') as mocked_init_url:
mocked_init_url.return_value = tmpdb_url
proj = ProjectorDB()
yield proj
proj.session.close()
del proj
def test_upgrade_old_projector_db(temp_folder):
"""
Test that we can upgrade a version 1 db to the current schema
@ -102,20 +116,6 @@ def test_upgrade_old_projector_db(temp_folder):
assert updated_to_version == latest_version, 'The projector DB should have been upgrade to the latest version'
@pytest.fixture()
def projector(temp_folder, settings):
"""
Set up anything necessary for all tests
"""
tmpdb_url = 'sqlite:///{db}'.format(db=os.path.join(temp_folder, TEST_DB))
with patch('openlp.core.projectors.db.init_url') as mocked_init_url:
mocked_init_url.return_value = tmpdb_url
proj = ProjectorDB()
yield proj
proj.session.close()
del proj
def test_find_record_by_ip(projector):
"""
Test find record by IP