mautrix-python/mautrix/util/async_db/aiosqlite.py

202 lines
6.8 KiB
Python

# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations
from typing import Any, AsyncContextManager
from contextlib import asynccontextmanager
import asyncio
import logging
import os
import re
import sqlite3
from yarl import URL
import aiosqlite
from .connection import LoggingConnection
from .database import Database
from .scheme import Scheme
from .upgrade import UpgradeTable
# Matches positional query placeholders written as ``$1``, ``$2``, ...; group 1
# captures the digits so the placeholder can be rewritten to SQLite's ``?1`` form.
POSITIONAL_PARAM_PATTERN = re.compile(r"\$(\d+)")
class TxnConnection(aiosqlite.Connection):
    """aiosqlite connection with positional-argument query helpers.

    Queries may use ``$1``-style placeholders; they are rewritten to SQLite's
    numbered ``?1`` form before being executed. The ``timeout`` keyword that the
    query methods accept is not used by this class.
    """

    def __init__(self, path: str, **kwargs) -> None:
        """Prepare a connection to the database at ``path``.

        Args:
            path: Database path handed to :func:`sqlite3.connect`.
            **kwargs: Extra keyword arguments forwarded to :func:`sqlite3.connect`.
        """

        def connector() -> sqlite3.Connection:
            # isolation_level=None turns off the sqlite3 module's implicit
            # transaction handling; transactions are opened explicitly by
            # transaction() below.
            return sqlite3.connect(
                path, detect_types=sqlite3.PARSE_DECLTYPES, isolation_level=None, **kwargs
            )

        super().__init__(connector, iter_chunk_size=64)

    @asynccontextmanager
    async def transaction(self) -> None:
        """Run the enclosed block between ``BEGIN TRANSACTION`` and commit/rollback.

        The transaction is committed when the block finishes normally. If the
        block raises, or the surrounding task is cancelled, the transaction is
        rolled back and the exception is re-raised.
        """
        await self.execute("BEGIN TRANSACTION")
        try:
            yield
        # CancelledError is a BaseException (not an Exception) since Python 3.8.
        # Without catching it explicitly, a cancelled task would leave the
        # transaction open on a connection that is about to be reused.
        except (Exception, asyncio.CancelledError):
            await self.rollback()
            raise
        else:
            await self.commit()

    def __execute(self, query: str, *args: Any):
        """Rewrite ``$N`` placeholders and hand the query to aiosqlite.

        Returns whatever :meth:`aiosqlite.Connection.execute` returns; callers in
        this class either await it or use it with ``async with``.
        """
        query = POSITIONAL_PARAM_PATTERN.sub(r"?\1", query)
        return super().execute(query, args)

    async def execute(
        self, query: str, *args: Any, timeout: float | None = None
    ) -> aiosqlite.Cursor:
        """Execute ``query`` with positional ``args`` and return the cursor.

        Args:
            query: SQL text, optionally using ``$N`` placeholders.
            *args: Values bound to the placeholders.
            timeout: Accepted but not used.
        """
        return await self.__execute(query, *args)

    async def executemany(
        self, query: str, *args: Any, timeout: float | None = None
    ) -> aiosqlite.Cursor:
        """Execute ``query`` via aiosqlite's ``executemany``.

        Args:
            query: SQL text, optionally using ``$N`` placeholders.
            *args: Forwarded positionally to ``aiosqlite.Connection.executemany``.
            timeout: Accepted but not used.
        """
        query = POSITIONAL_PARAM_PATTERN.sub(r"?\1", query)
        return await super().executemany(query, *args)

    async def fetch(
        self, query: str, *args: Any, timeout: float | None = None
    ) -> list[sqlite3.Row]:
        """Execute ``query`` and return all result rows as a list.

        Args:
            query: SQL text, optionally using ``$N`` placeholders.
            *args: Values bound to the placeholders.
            timeout: Accepted but not used.
        """
        async with self.__execute(query, *args) as cursor:
            return list(await cursor.fetchall())

    async def fetchrow(
        self, query: str, *args: Any, timeout: float | None = None
    ) -> sqlite3.Row | None:
        """Execute ``query`` and return the first result row, or ``None`` if there is none.

        Args:
            query: SQL text, optionally using ``$N`` placeholders.
            *args: Values bound to the placeholders.
            timeout: Accepted but not used.
        """
        async with self.__execute(query, *args) as cursor:
            return await cursor.fetchone()

    async def fetchval(
        self, query: str, *args: Any, column: int = 0, timeout: float | None = None
    ) -> Any:
        """Execute ``query`` and return one column of the first result row.

        Args:
            query: SQL text, optionally using ``$N`` placeholders.
            *args: Values bound to the placeholders.
            column: Index of the column to return from the first row.
            timeout: Accepted but not used.

        Returns:
            The selected value, or ``None`` when the query produced no rows.
        """
        row = await self.fetchrow(query, *args)
        if row is None:
            return None
        return row[column]
class SQLiteDatabase(Database):
    """SQLite implementation of :class:`Database` backed by a fixed-size connection pool."""

    scheme = Scheme.SQLITE
    # Another SQLiteDatabase whose pool is used instead of this one's, if set.
    _parent: SQLiteDatabase | None
    # Idle connections waiting to be acquired.
    _pool: asyncio.Queue[TxnConnection]
    # Set by stop(); a stopped pool refuses to start again or hand out connections.
    _stopped: bool
    # Number of connections opened by start() and not yet closed by stop().
    _conns: int
    # Statements executed on every newly opened connection.
    _init_commands: list[str]

    def __init__(
        self,
        url: URL,
        upgrade_table: UpgradeTable,
        db_args: dict[str, Any] | None = None,
        log: logging.Logger | None = None,
        owner_name: str | None = None,
        ignore_foreign_tables: bool = True,
    ) -> None:
        """Set up the pool state; no connection is opened until :meth:`start`.

        Besides the arguments forwarded unchanged to ``Database.__init__``, the
        following ``db_args`` keys are consumed here and the remaining keys are
        passed to :class:`TxnConnection` when connecting:

        * ``min_size``: number of connections in the pool (default 1).
        * ``max_size``: removed and ignored.
        * ``init_commands``: statements to run on each new connection.
        """
        super().__init__(
            url,
            db_args=db_args,
            upgrade_table=upgrade_table,
            log=log,
            owner_name=owner_name,
            ignore_foreign_tables=ignore_foreign_tables,
        )
        self._parent = None
        self._path = url.path
        # The queue's maxsize doubles as the number of connections start() opens.
        self._pool = asyncio.Queue(self._db_args.pop("min_size", 1))
        self._db_args.pop("max_size", None)
        self._stopped = False
        self._conns = 0
        self._init_commands = self._add_missing_pragmas(self._db_args.pop("init_commands", []))

    @staticmethod
    def _add_missing_pragmas(init_commands: list[str]) -> list[str]:
        """Return ``init_commands`` followed by default pragmas that are not already set.

        Defaults added when no command mentions the respective pragma:
        ``foreign_keys = ON``, ``journal_mode = WAL``, ``busy_timeout = 5000``,
        and ``synchronous = NORMAL`` (the latter only when the exact command
        ``PRAGMA journal_mode = WAL`` is present).

        Args:
            init_commands: Configured init commands. The list is not modified.

        Returns:
            A new list with the configured commands plus any added defaults.
        """
        # Copy first so the caller's list is never mutated.
        commands = list(init_commands)
        has_foreign_keys = False
        has_journal_mode = False
        has_synchronous = False
        has_busy_timeout = False
        for cmd in commands:
            # SQL keywords and pragma names are case-insensitive, so compare in
            # lowercase; otherwise "pragma journal_mode = delete" would be
            # missed and overridden by the default appended below.
            lowered = cmd.lower()
            if "pragma" not in lowered:
                continue
            if "foreign_keys" in lowered:
                has_foreign_keys = True
            elif "journal_mode" in lowered:
                has_journal_mode = True
            elif "synchronous" in lowered:
                has_synchronous = True
            elif "busy_timeout" in lowered:
                has_busy_timeout = True
        if not has_foreign_keys:
            commands.append("PRAGMA foreign_keys = ON")
        if not has_journal_mode:
            commands.append("PRAGMA journal_mode = WAL")
        if not has_synchronous and "PRAGMA journal_mode = WAL" in commands:
            commands.append("PRAGMA synchronous = NORMAL")
        if not has_busy_timeout:
            commands.append("PRAGMA busy_timeout = 5000")
        return commands

    def override_pool(self, db: Database) -> None:
        """Make this database use ``db``'s connections instead of its own pool.

        Args:
            db: Must be a :class:`SQLiteDatabase`.
        """
        assert isinstance(db, SQLiteDatabase)
        self._parent = db

    async def start(self) -> None:
        """Open the pool's connections, then delegate to ``Database.start()``.

        When a parent pool has been set with :meth:`override_pool`, no
        connections are opened and only ``Database.start()`` is called.

        Raises:
            RuntimeError: If the pool was already started or has been stopped.
        """
        if self._parent:
            await super().start()
            return
        if self._conns:
            raise RuntimeError("database pool has already been started")
        elif self._stopped:
            raise RuntimeError("database pool can't be restarted")
        self.log.debug(f"Connecting to {self.url}")
        self.log.debug(f"Database connection init commands: {self._init_commands}")
        # Only warn about likely permission problems; connecting is still attempted.
        if os.path.exists(self._path):
            if not os.access(self._path, os.W_OK):
                self.log.warning("Database file doesn't seem writable")
        elif not os.access(os.path.dirname(os.path.abspath(self._path)), os.W_OK):
            self.log.warning("Database file doesn't exist and directory doesn't seem writable")
        for _ in range(self._pool.maxsize):
            conn = await TxnConnection(self._path, **self._db_args)
            try:
                if self._init_commands:
                    cur = await conn.cursor()
                    for command in self._init_commands:
                        self.log.trace("Executing init command: %s", command)
                        await cur.execute(command)
                    await conn.commit()
            except (Exception, asyncio.CancelledError):
                # This connection is not in the pool yet, so stop() would never
                # close it; close it here before propagating the failure.
                await conn.close()
                raise
            conn.row_factory = sqlite3.Row
            self._pool.put_nowait(conn)
            self._conns += 1
        await super().start()

    async def stop(self) -> None:
        """Close every pooled connection and mark the pool as stopped.

        Waits for connections that are currently acquired to be returned. Does
        nothing when a parent pool is in use.
        """
        if self._parent:
            return
        self._stopped = True
        while self._conns > 0:
            conn = await self._pool.get()
            self._conns -= 1
            await conn.close()

    def acquire(self) -> AsyncContextManager[LoggingConnection]:
        """Return a context manager yielding a connection from the (parent's) pool."""
        if self._parent:
            return self._parent.acquire()
        return self._acquire()

    @asynccontextmanager
    async def _acquire(self) -> LoggingConnection:
        """Take a connection from this pool, yield it wrapped, and put it back afterwards.

        Raises:
            RuntimeError: If the pool has been stopped.
        """
        if self._stopped:
            raise RuntimeError("database pool has been stopped")
        conn = await self._pool.get()
        try:
            yield LoggingConnection(self.scheme, conn, self.log)
        finally:
            self._pool.put_nowait(conn)
# Register SQLiteDatabase under both the "sqlite" and "sqlite3" scheme keys.
Database.schemes["sqlite"] = Database.schemes["sqlite3"] = SQLiteDatabase