mautrix-python/mautrix/util/async_db/upgrade.py

184 lines
6.7 KiB
Python

# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations
from typing import Awaitable, Callable, Optional, cast
import functools
import inspect
import logging
from mautrix.util.logging import TraceLogger
from .. import async_db
from .connection import LoggingConnection
from .errors import UnsupportedDatabaseVersion
from .scheme import Scheme
Upgrade = Callable[[LoggingConnection, Scheme], Awaitable[Optional[int]]]
UpgradeWithoutScheme = Callable[[LoggingConnection], Awaitable[Optional[int]]]
async def noop_upgrade(_: LoggingConnection, _2: Scheme) -> None:
pass
def _wrap_upgrade(fn: UpgradeWithoutScheme | Upgrade) -> Upgrade:
params = inspect.signature(fn).parameters
if len(params) == 1:
_wrapped: UpgradeWithoutScheme = cast(UpgradeWithoutScheme, fn)
@functools.wraps(_wrapped)
async def _wrapper(conn: LoggingConnection, _: Scheme) -> Optional[int]:
return await _wrapped(conn)
return _wrapper
else:
return fn
class UpgradeTable:
upgrades: list[Upgrade]
allow_unsupported: bool
database_name: str
version_table_name: str
log: TraceLogger
def __init__(
self,
allow_unsupported: bool = False,
version_table_name: str = "version",
database_name: str = "database",
log: logging.Logger | TraceLogger | None = None,
) -> None:
self.upgrades = []
self.allow_unsupported = allow_unsupported
self.version_table_name = version_table_name
self.database_name = database_name
self.log = log or logging.getLogger("mau.db.upgrade")
def register(
self,
_outer_fn: Upgrade | UpgradeWithoutScheme | None = None,
*,
index: int = -1,
description: str = "",
transaction: bool = True,
upgrades_to: int | Upgrade | None = None,
) -> Upgrade | Callable[[Upgrade | UpgradeWithoutScheme], Upgrade]:
if isinstance(index, str):
description = index
index = -1
def actually_register(fn: Upgrade | UpgradeWithoutScheme) -> Upgrade:
fn = _wrap_upgrade(fn)
fn.__mau_db_upgrade_description__ = description
fn.__mau_db_upgrade_transaction__ = transaction
fn.__mau_db_upgrade_destination__ = (
upgrades_to
if not upgrades_to or isinstance(upgrades_to, int)
else _wrap_upgrade(upgrades_to)
)
if index == -1 or index == len(self.upgrades):
self.upgrades.append(fn)
else:
if len(self.upgrades) <= index:
self.upgrades += [noop_upgrade] * (index - len(self.upgrades) + 1)
self.upgrades[index] = fn
return fn
return actually_register(_outer_fn) if _outer_fn else actually_register
async def _save_version(self, conn: LoggingConnection, version: int) -> None:
self.log.trace(f"Saving current version (v{version}) to database")
await conn.execute(f"DELETE FROM {self.version_table_name}")
await conn.execute(f"INSERT INTO {self.version_table_name} (version) VALUES ($1)", version)
async def upgrade(self, db: async_db.Database) -> None:
await db.execute(
f"""CREATE TABLE IF NOT EXISTS {self.version_table_name} (
version INTEGER PRIMARY KEY
)"""
)
row = await db.fetchrow(f"SELECT version FROM {self.version_table_name} LIMIT 1")
version = row["version"] if row else 0
if len(self.upgrades) < version:
unsupported_version_error = UnsupportedDatabaseVersion(
self.database_name, version, len(self.upgrades)
)
if not self.allow_unsupported:
raise unsupported_version_error
else:
self.log.warning(str(unsupported_version_error))
return
elif len(self.upgrades) == version:
self.log.debug(f"Database at v{version}, not upgrading")
return
async with db.acquire() as conn:
while version < len(self.upgrades):
old_version = version
upgrade = self.upgrades[version]
new_version = (
getattr(upgrade, "__mau_db_upgrade_destination__", None) or version + 1
)
if callable(new_version):
new_version = await new_version(conn, db.scheme)
desc = getattr(upgrade, "__mau_db_upgrade_description__", None)
suffix = f": {desc}" if desc else ""
self.log.debug(
f"Upgrading {self.database_name} from v{old_version} to v{new_version}{suffix}"
)
if getattr(upgrade, "__mau_db_upgrade_transaction__", True):
async with conn.transaction():
version = await upgrade(conn, db.scheme) or new_version
await self._save_version(conn, version)
else:
version = await upgrade(conn, db.scheme) or new_version
await self._save_version(conn, version)
if version != new_version:
self.log.warning(
f"Upgrading {self.database_name} actually went from v{old_version} "
f"to v{version}"
)
upgrade_tables: dict[str, UpgradeTable] = {}
def register_upgrade_table_parent_module(name: str) -> None:
upgrade_tables[name] = UpgradeTable()
def _find_upgrade_table(fn: Upgrade) -> UpgradeTable:
try:
module = fn.__module__
except AttributeError as e:
raise ValueError(
"Registering upgrades without an UpgradeTable requires the function "
"to have the __module__ attribute."
) from e
parts = module.split(".")
used_parts = []
last_error = None
for part in parts:
used_parts.append(part)
try:
return upgrade_tables[".".join(used_parts)]
except KeyError as e:
last_error = e
raise KeyError(
"Registering upgrades without an UpgradeTable requires you to register a parent "
"module with register_upgrade_table_parent_module first."
) from last_error
def register_upgrade(index: int = -1, description: str = "") -> Callable[[Upgrade], Upgrade]:
def actually_register(fn: Upgrade) -> Upgrade:
return _find_upgrade_table(fn).register(fn, index=index, description=description)
return actually_register