mautrix-python/mautrix/bridge/bridge.py

355 lines
13 KiB
Python

# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations
from typing import Any
from abc import ABC, abstractmethod
from enum import Enum
import asyncio
import sys
from aiohttp import web
from mautrix import __version__ as __mautrix_version__
from mautrix.api import HTTPAPI
from mautrix.appservice import AppService, ASStateStore
from mautrix.client.state_store.asyncpg import PgStateStore as PgClientStateStore
from mautrix.errors import MExclusive, MUnknownToken
from mautrix.types import RoomID, UserID
from mautrix.util.async_db import Database, DatabaseException, UpgradeTable
from mautrix.util.bridge_state import BridgeState, BridgeStateEvent, GlobalBridgeState
from mautrix.util.program import Program
from .. import bridge as br
from .state_store.asyncpg import PgBridgeStateStore
try:
import uvloop
except ImportError:
uvloop = None
class HomeserverSoftware(Enum):
    """Known homeserver software flavours, keyed by their string value."""

    STANDARD = "standard"
    ASMUX = "asmux"
    HUNGRY = "hungry"

    @property
    def is_hungry(self) -> bool:
        """``True`` only for the :attr:`HUNGRY` member."""
        return self is HomeserverSoftware.HUNGRY

    @property
    def is_asmux(self) -> bool:
        """``True`` only for the :attr:`ASMUX` member."""
        return self is HomeserverSoftware.ASMUX
class Bridge(Program, ABC):
    """Abstract base class for a Matrix appservice bridge program.

    Builds on :class:`Program` and adds the parts visible in this class: a
    database, an :class:`AppService`, a state store and a Matrix handler.
    Concrete bridges provide the class-level attributes below (or pass the
    matching constructor arguments) and implement the abstract lookup
    methods near the end of the class.
    """

    # Database handle; assigned in prepare_db().
    db: Database
    # Appservice instance; assigned in prepare_appservice().
    az: AppService
    # State store implementation to instantiate; overridable via __init__.
    state_store_class: type[ASStateStore] = PgBridgeStateStore
    # State store instance; assigned in make_state_store().
    state_store: ASStateStore
    # Schema upgrade table handed to Database.create(); prepare_db() raises
    # RuntimeError if it is missing or falsy.
    upgrade_table: UpgradeTable
    # Config class instantiated in prepare_config(), and the resulting instance.
    config_class: type[br.BaseBridgeConfig]
    config: br.BaseBridgeConfig
    # Matrix handler class instantiated in prepare_bridge(), and the instance.
    matrix_class: type[br.BaseMatrixHandler]
    matrix: br.BaseMatrixHandler
    # Declared here but not read anywhere in this class.
    repo_url: str
    markdown_version: str
    # Manhole state; None until something assigns it, closed again in stop().
    manhole: br.commands.manhole.ManholeState | None
    # Parsed from the "homeserver.software" config value in prepare().
    homeserver_software: HomeserverSoftware
    # Declared with None defaults; not read anywhere in this class.
    beeper_network_name: str | None = None
    beeper_service_name: str | None = None

    def __init__(
        self,
        module: str | None = None,
        name: str | None = None,
        description: str | None = None,
        command: str | None = None,
        version: str | None = None,
        config_class: type[br.BaseBridgeConfig] | None = None,
        matrix_class: type[br.BaseMatrixHandler] | None = None,
        state_store_class: type[ASStateStore] | None = None,
    ) -> None:
        """Initialise the bridge program.

        The first six arguments are forwarded unchanged to
        :class:`Program`. ``matrix_class`` and ``state_store_class``, when
        truthy, override the class-level attributes of the same name.
        """
        super().__init__(module, name, description, command, version, config_class)
        if matrix_class:
            self.matrix_class = matrix_class
        if state_store_class:
            self.state_store_class = state_store_class
        self.manhole = None

    def prepare_arg_parser(self) -> None:
        """Register the bridge-specific command line flags on ``self.parser``."""
        super().prepare_arg_parser()
        self.parser.add_argument(
            "-g",
            "--generate-registration",
            action="store_true",
            help="generate registration and quit",
        )
        self.parser.add_argument(
            "-r",
            "--registration",
            type=str,
            default="registration.yaml",
            metavar="<path>",
            help=(
                "the path to save the generated registration to "
                "(not needed for running the bridge)"
            ),
        )
        self.parser.add_argument(
            "--ignore-unsupported-database",
            action="store_true",
            help="Run even if the database schema is too new",
        )
        self.parser.add_argument(
            "--ignore-foreign-tables",
            action="store_true",
            help="Run even if the database contains tables from other programs (like Synapse)",
        )

    def preinit(self) -> None:
        """Run the parent pre-init, then handle ``--generate-registration``.

        When the flag is set, the registration is generated and the process
        exits with status 0 without continuing startup.
        """
        super().preinit()
        if self.args.generate_registration:
            self.generate_registration()
            sys.exit(0)

    def prepare(self) -> None:
        """Validate ``homeserver.software`` and build the database, appservice and bridge.

        Exits with status 11 if the configured homeserver software value
        cannot be turned into a :class:`HomeserverSoftware` member.
        """
        if self.config.env:
            self.log.debug(
                "Loaded config overrides from environment: %s", list(self.config.env.keys())
            )
        super().prepare()
        try:
            self.homeserver_software = HomeserverSoftware(self.config["homeserver.software"])
        except Exception:
            self.log.fatal("Invalid value for homeserver.software in config")
            sys.exit(11)
        # Order matters: the state store built in prepare_appservice() may
        # take self.db, and the matrix handler receives the whole bridge.
        self.prepare_db()
        self.prepare_appservice()
        self.prepare_bridge()

    def prepare_config(self) -> None:
        """Instantiate ``config_class`` and load the config.

        The environment override prefix is the upper-cased module name.
        """
        self.config = self.config_class(
            self.args.config,
            self.args.registration,
            self.base_config_path,
            env_prefix=self.module.upper(),
        )
        if self.args.generate_registration:
            # Disable the config's token check while generating a registration.
            # NOTE(review): presumably because the tokens are not filled in
            # yet at that point -- confirm against the config class.
            self.config._check_tokens = False
        self.load_and_update_config()

    def generate_registration(self) -> None:
        """Generate the appservice registration, save the config and report the path."""
        self.config.generate_registration()
        self.config.save()
        print(f"Registration generated and saved to {self.config.registration_path}")

    def make_state_store(self) -> None:
        """Instantiate ``state_store_class`` into ``self.state_store``.

        Subclasses of :class:`PgBridgeStateStore` are constructed with the
        database and the puppet lookup callbacks; any other class is
        constructed without arguments.

        Raises:
            RuntimeError: if ``state_store_class`` is ``None``.
        """
        if self.state_store_class is None:
            raise RuntimeError("state_store_class is not set")
        elif issubclass(self.state_store_class, PgBridgeStateStore):
            self.state_store = self.state_store_class(
                self.db, self.get_puppet, self.get_double_puppet
            )
        else:
            self.state_store = self.state_store_class()

    def prepare_appservice(self) -> None:
        """Create the state store and the :class:`AppService` from the config.

        Also registers the ``com.beeper.bridge_state`` POST route served by
        :meth:`get_bridge_state`.
        """
        self.make_state_store()
        # appservice.max_body_size is multiplied by this, i.e. it is read as MiB.
        mb = 1024**2
        default_http_retry_count = self.config.get("homeserver.http_retry_count", None)
        # Prefix the shared user agent with this bridge's name/version, but
        # only once: the name check avoids stacking the prefix on repeat calls.
        if self.name not in HTTPAPI.default_ua:
            HTTPAPI.default_ua = f"{self.name}/{self.version} {HTTPAPI.default_ua}"
        self.az = AppService(
            server=self.config["homeserver.address"],
            domain=self.config["homeserver.domain"],
            verify_ssl=self.config["homeserver.verify_ssl"],
            connection_limit=self.config["homeserver.connection_limit"],
            id=self.config["appservice.id"],
            as_token=self.config["appservice.as_token"],
            hs_token=self.config["appservice.hs_token"],
            tls_cert=self.config.get("appservice.tls_cert", None),
            tls_key=self.config.get("appservice.tls_key", None),
            bot_localpart=self.config["appservice.bot_username"],
            ephemeral_events=self.config["appservice.ephemeral_events"],
            encryption_events=self.config["bridge.encryption.appservice"],
            default_ua=HTTPAPI.default_ua,
            default_http_retry_count=default_http_retry_count,
            log="mau.as",
            loop=self.loop,
            state_store=self.state_store,
            bridge_name=self.name,
            aiohttp_params={"client_max_size": self.config["appservice.max_body_size"] * mb},
        )
        self.az.app.router.add_post("/_matrix/app/com.beeper.bridge_state", self.get_bridge_state)

    def prepare_db(self) -> None:
        """Create ``self.db`` from the ``appservice.database*`` config values.

        Raises:
            RuntimeError: if ``upgrade_table`` is missing or falsy.
        """
        if not hasattr(self, "upgrade_table") or not self.upgrade_table:
            raise RuntimeError("upgrade_table is not set")
        self.db = Database.create(
            self.config["appservice.database"],
            upgrade_table=self.upgrade_table,
            db_args=self.config["appservice.database_opts"],
            owner_name=self.name,
            ignore_foreign_tables=self.args.ignore_foreign_tables,
        )

    def prepare_bridge(self) -> None:
        """Instantiate the Matrix handler, passing this bridge to it."""
        self.matrix = self.matrix_class(bridge=self)

    def _log_db_error(self, e: Exception) -> None:
        """Log a database initialisation failure and exit with status 25.

        If ``e`` is a :class:`DatabaseException` with a truthy
        ``explanation``, that explanation is logged as well.
        """
        self.log.critical("Failed to initialize database", exc_info=e)
        if isinstance(e, DatabaseException) and e.explanation:
            self.log.info(e.explanation)
        sys.exit(25)

    async def start_db(self) -> None:
        """Start the database and apply the auxiliary schema upgrades.

        Does nothing unless ``self.db`` exists and is a :class:`Database`.
        Any exception raised while starting is passed to
        :meth:`_log_db_error`, which terminates the process.
        """
        if hasattr(self, "db") and isinstance(self.db, Database):
            self.log.debug("Starting database...")
            # The same --ignore-unsupported-database flag is applied to
            # every upgrade table touched below.
            ignore_unsupported = self.args.ignore_unsupported_database
            self.db.upgrade_table.allow_unsupported = ignore_unsupported
            try:
                await self.db.start()
                if isinstance(self.state_store, PgClientStateStore):
                    self.state_store.upgrade_table.allow_unsupported = ignore_unsupported
                    await self.state_store.upgrade_table.upgrade(self.db)
                if self.matrix.e2ee:
                    self.matrix.e2ee.crypto_db.allow_unsupported = ignore_unsupported
                    self.matrix.e2ee.crypto_db.override_pool(self.db)
            except Exception as e:
                self._log_db_error(e)

    async def stop_db(self) -> None:
        """Stop the database if ``self.db`` exists and is a :class:`Database`."""
        if hasattr(self, "db") and isinstance(self.db, Database):
            await self.db.stop()

    async def start(self) -> None:
        """Start the database and appservice, then verify the homeserver connection.

        Exits with status 16 if waiting for the homeserver connection
        raises; the logged message depends on the exception type.

        After startup, if ``homeserver.status_endpoint`` is set and
        :meth:`count_logged_in_users` returns 0, an ``UNCONFIGURED`` bridge
        state is sent to that endpoint, retrying every 5 seconds until the
        send reports success.
        """
        await self.start_db()

        self.log.debug("Starting appservice...")
        await self.az.start(self.config["appservice.hostname"], self.config["appservice.port"])
        try:
            await self.matrix.wait_for_connection()
        except MUnknownToken:
            self.log.critical(
                "The as_token was not accepted. Is the registration file installed "
                "in your homeserver correctly?"
            )
            sys.exit(16)
        except MExclusive:
            self.log.critical(
                "The as_token was accepted, but the /register request was not. "
                "Are the homeserver domain and username template in the config "
                "correct, and do they match the values in the registration?"
            )
            sys.exit(16)
        except Exception:
            self.log.critical("Failed to check connection to homeserver", exc_info=True)
            sys.exit(16)

        await self.matrix.init_encryption()
        self.add_startup_actions(self.matrix.init_as_bot())
        await super().start()
        self.az.ready = True

        status_endpoint = self.config["homeserver.status_endpoint"]
        if status_endpoint and await self.count_logged_in_users() == 0:
            state = BridgeState(state_event=BridgeStateEvent.UNCONFIGURED).fill()
            while not await state.send(status_endpoint, self.az.as_token, self.log):
                await asyncio.sleep(5)

    async def system_exit(self) -> None:
        """Stop the database, logging that this is due to ``SystemExit``.

        If ``self.db`` is set to something truthy that is not a
        :class:`Database`, only a trace message is logged.
        """
        if hasattr(self, "db") and isinstance(self.db, Database):
            self.log.debug("Stopping database due to SystemExit")
            await self.db.stop()
            self.log.debug("Database stopped")
        elif getattr(self, "db", None):
            self.log.trace("Database not started at SystemExit")

    async def stop(self) -> None:
        """Shut down the manhole, appservice, program, e2ee helper and database, in that order."""
        if self.manhole:
            self.manhole.close()
            self.manhole = None
        await self.az.stop()
        await super().stop()
        if self.matrix.e2ee:
            await self.matrix.e2ee.stop()
        await self.stop_db()

    async def get_bridge_state(self, req: web.Request) -> web.Response:
        """HTTP handler returning the bridge state for one user as JSON.

        The user is taken from the ``user_id`` query parameter.

        Responses:
            401: the appservice token check on the request failed.
            404: ``user_id`` is missing or no such user exists.
            501: the user object raises ``NotImplementedError`` for bridge states.
            200: a serialized :class:`GlobalBridgeState`, with the per-remote
                states keyed by ``remote_id`` and a ``RUNNING`` global state.
        """
        if not self.az._check_token(req):
            return web.json_response({"error": "Invalid auth token"}, status=401)
        try:
            user = await self.get_user(UserID(req.url.query["user_id"]), create=False)
        except KeyError:
            # Missing query parameter is reported the same way as an unknown user.
            user = None
        if user is None:
            return web.json_response({"error": "User not found"}, status=404)
        try:
            states = await user.get_bridge_states()
        except NotImplementedError:
            return web.json_response({"error": "Bridge status not implemented"}, status=501)
        for state in states:
            await user.fill_bridge_state(state)
        global_state = BridgeState(state_event=BridgeStateEvent.RUNNING).fill()
        evt = GlobalBridgeState(
            remote_states={state.remote_id: state for state in states}, bridge_state=global_state
        )
        return web.json_response(evt.serialize())

    @abstractmethod
    async def get_user(self, user_id: UserID, create: bool = True) -> br.BaseUser | None:
        """Look up the bridge user for a Matrix user ID; may return ``None``."""
        pass

    @abstractmethod
    async def get_portal(self, room_id: RoomID) -> br.BasePortal | None:
        """Look up the portal for a Matrix room ID; may return ``None``."""
        pass

    @abstractmethod
    async def get_puppet(self, user_id: UserID, create: bool = False) -> br.BasePuppet | None:
        """Look up the puppet for a Matrix user ID; may return ``None``."""
        pass

    @abstractmethod
    async def get_double_puppet(self, user_id: UserID) -> br.BasePuppet | None:
        """Look up the double puppet for a Matrix user ID; may return ``None``."""
        pass

    @abstractmethod
    def is_bridge_ghost(self, user_id: UserID) -> bool:
        """Report whether the given Matrix user ID is a ghost of this bridge."""
        pass

    @abstractmethod
    async def count_logged_in_users(self) -> int:
        """Return the number of logged-in users; this base body returns 0."""
        return 0

    async def manhole_global_namespace(self, user_id: UserID) -> dict[str, Any]:
        """Build the global namespace for a manhole session.

        Args:
            user_id: Matrix user ID used to look up ``own_user``.

        Returns:
            A dict with the keys ``bridge``, ``manhole``, ``own_user`` and
            ``own_puppet``. ``own_user`` is ``None`` when the lookup finds
            nothing; ``own_puppet`` is ``None`` in that case and also when
            the user object raises ``NotImplementedError`` for puppets.
        """
        own_user = await self.get_user(user_id, create=False)
        own_puppet = None
        # get_user(create=False) may return None; without this guard the
        # puppet lookup would fail with AttributeError.
        if own_user is not None:
            try:
                own_puppet = await own_user.get_puppet()
            except NotImplementedError:
                own_puppet = None
        return {
            "bridge": self,
            "manhole": self.manhole,
            "own_user": own_user,
            "own_puppet": own_puppet,
        }

    @property
    def manhole_banner_python_version(self) -> str:
        """Banner line naming the Python version and platform."""
        return f"Python {sys.version} on {sys.platform}"

    @property
    def manhole_banner_program_version(self) -> str:
        """Banner line naming this program's version and the mautrix-python version."""
        return f"{self.name} {self.version} with mautrix-python {__mautrix_version__}"

    def manhole_banner(self, user_id: UserID) -> str:
        """Return the full manhole banner, ending with who opened the session."""
        return (
            f"{self.manhole_banner_python_version}\n"
            f"{self.manhole_banner_program_version}\n\n"
            f"Manhole opened by {user_id}\n"
        )