355 lines
13 KiB
Python
355 lines
13 KiB
Python
# Copyright (c) 2022 Tulir Asokan
|
|
#
|
|
# This Source Code Form is subject to the terms of the Mozilla Public
|
|
# License, v. 2.0. If a copy of the MPL was not distributed with this
|
|
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
|
from __future__ import annotations
|
|
|
|
from typing import Any
|
|
from abc import ABC, abstractmethod
|
|
from enum import Enum
|
|
import asyncio
|
|
import sys
|
|
|
|
from aiohttp import web
|
|
|
|
from mautrix import __version__ as __mautrix_version__
|
|
from mautrix.api import HTTPAPI
|
|
from mautrix.appservice import AppService, ASStateStore
|
|
from mautrix.client.state_store.asyncpg import PgStateStore as PgClientStateStore
|
|
from mautrix.errors import MExclusive, MUnknownToken
|
|
from mautrix.types import RoomID, UserID
|
|
from mautrix.util.async_db import Database, DatabaseException, UpgradeTable
|
|
from mautrix.util.bridge_state import BridgeState, BridgeStateEvent, GlobalBridgeState
|
|
from mautrix.util.program import Program
|
|
|
|
from .. import bridge as br
|
|
from .state_store.asyncpg import PgBridgeStateStore
|
|
|
|
try:
|
|
import uvloop
|
|
except ImportError:
|
|
uvloop = None
|
|
|
|
|
|
class HomeserverSoftware(Enum):
    """Homeserver flavour the bridge talks to, parsed from ``homeserver.software`` in the config."""

    STANDARD = "standard"
    ASMUX = "asmux"
    HUNGRY = "hungry"

    @property
    def is_hungry(self) -> bool:
        """Whether this value is :attr:`HomeserverSoftware.HUNGRY`."""
        # Compare against the class member rather than ``self.HUNGRY``: accessing one enum
        # member through another is discouraged (and was deprecated in some Python versions).
        return self is HomeserverSoftware.HUNGRY

    @property
    def is_asmux(self) -> bool:
        """Whether this value is :attr:`HomeserverSoftware.ASMUX`."""
        return self is HomeserverSoftware.ASMUX
|
|
|
|
|
|
class Bridge(Program, ABC):
    """Base class for a mautrix-python appservice bridge.

    Wires together the config, database, state store, :class:`AppService` and the
    Matrix handler. Subclasses must provide ``upgrade_table`` and implement the
    abstract ``get_*``, ``is_bridge_ghost`` and ``count_logged_in_users`` hooks;
    ``config_class``, ``matrix_class`` and ``state_store_class`` may be set on the
    subclass or passed to ``__init__``.
    """

    db: Database
    az: AppService
    # Default state store class; overridable per subclass or via ``__init__``.
    state_store_class: type[ASStateStore] = PgBridgeStateStore
    state_store: ASStateStore
    upgrade_table: UpgradeTable
    config_class: type[br.BaseBridgeConfig]
    config: br.BaseBridgeConfig
    matrix_class: type[br.BaseMatrixHandler]
    matrix: br.BaseMatrixHandler
    repo_url: str
    markdown_version: str
    # Currently open manhole session, if any; closed in ``stop()``.
    manhole: br.commands.manhole.ManholeState | None
    homeserver_software: HomeserverSoftware
    beeper_network_name: str | None = None
    beeper_service_name: str | None = None

    def __init__(
        self,
        module: str | None = None,
        name: str | None = None,
        description: str | None = None,
        command: str | None = None,
        version: str | None = None,
        config_class: type[br.BaseBridgeConfig] | None = None,
        matrix_class: type[br.BaseMatrixHandler] | None = None,
        state_store_class: type[ASStateStore] | None = None,
    ) -> None:
        """Initialise the program metadata and optionally override handler/store classes.

        ``matrix_class`` and ``state_store_class`` only replace the class-level
        defaults when a truthy value is given.
        """
        super().__init__(module, name, description, command, version, config_class)
        if matrix_class:
            self.matrix_class = matrix_class
        if state_store_class:
            self.state_store_class = state_store_class
        self.manhole = None

    def prepare_arg_parser(self) -> None:
        """Add the bridge-specific command-line flags on top of the base program's."""
        super().prepare_arg_parser()
        self.parser.add_argument(
            "-g",
            "--generate-registration",
            action="store_true",
            help="generate registration and quit",
        )
        self.parser.add_argument(
            "-r",
            "--registration",
            type=str,
            default="registration.yaml",
            metavar="<path>",
            help=(
                "the path to save the generated registration to "
                "(not needed for running the bridge)"
            ),
        )
        self.parser.add_argument(
            "--ignore-unsupported-database",
            action="store_true",
            help="Run even if the database schema is too new",
        )
        self.parser.add_argument(
            "--ignore-foreign-tables",
            action="store_true",
            help="Run even if the database contains tables from other programs (like Synapse)",
        )

    def preinit(self) -> None:
        """Run base pre-init; with ``--generate-registration``, write the registration and exit 0."""
        super().preinit()
        if self.args.generate_registration:
            self.generate_registration()
            sys.exit(0)

    def prepare(self) -> None:
        """Validate the homeserver software and set up the database, appservice and Matrix handler.

        Exits with status 11 if ``homeserver.software`` is not a valid
        :class:`HomeserverSoftware` value.
        """
        if self.config.env:
            self.log.debug(
                "Loaded config overrides from environment: %s", list(self.config.env.keys())
            )
        super().prepare()
        try:
            self.homeserver_software = HomeserverSoftware(self.config["homeserver.software"])
        except Exception:
            self.log.fatal("Invalid value for homeserver.software in config")
            sys.exit(11)
        self.prepare_db()
        self.prepare_appservice()
        self.prepare_bridge()

    def prepare_config(self) -> None:
        """Instantiate ``config_class`` from the CLI paths and load/update it.

        Token checking is disabled when generating a registration, since the
        tokens are about to be created.
        """
        self.config = self.config_class(
            self.args.config,
            self.args.registration,
            self.base_config_path,
            env_prefix=self.module.upper(),
        )
        if self.args.generate_registration:
            self.config._check_tokens = False
        self.load_and_update_config()

    def generate_registration(self) -> None:
        """Generate the appservice registration, save the config and report the output path."""
        self.config.generate_registration()
        self.config.save()
        print(f"Registration generated and saved to {self.config.registration_path}")

    def make_state_store(self) -> None:
        """Instantiate ``state_store_class`` into ``self.state_store``.

        :class:`PgBridgeStateStore` subclasses receive the database and the puppet
        lookup callbacks; any other class is constructed without arguments.

        Raises:
            RuntimeError: if ``state_store_class`` is ``None``.
        """
        if self.state_store_class is None:
            raise RuntimeError("state_store_class is not set")
        elif issubclass(self.state_store_class, PgBridgeStateStore):
            self.state_store = self.state_store_class(
                self.db, self.get_puppet, self.get_double_puppet
            )
        else:
            self.state_store = self.state_store_class()

    def prepare_appservice(self) -> None:
        """Create the :class:`AppService` from config and register the bridge-state endpoint."""
        self.make_state_store()
        mb = 1024**2
        default_http_retry_count = self.config.get("homeserver.http_retry_count", None)
        # Prefix the shared user agent with this bridge's name/version exactly once.
        if self.name not in HTTPAPI.default_ua:
            HTTPAPI.default_ua = f"{self.name}/{self.version} {HTTPAPI.default_ua}"
        self.az = AppService(
            server=self.config["homeserver.address"],
            domain=self.config["homeserver.domain"],
            verify_ssl=self.config["homeserver.verify_ssl"],
            connection_limit=self.config["homeserver.connection_limit"],
            id=self.config["appservice.id"],
            as_token=self.config["appservice.as_token"],
            hs_token=self.config["appservice.hs_token"],
            tls_cert=self.config.get("appservice.tls_cert", None),
            tls_key=self.config.get("appservice.tls_key", None),
            bot_localpart=self.config["appservice.bot_username"],
            ephemeral_events=self.config["appservice.ephemeral_events"],
            encryption_events=self.config["bridge.encryption.appservice"],
            default_ua=HTTPAPI.default_ua,
            default_http_retry_count=default_http_retry_count,
            log="mau.as",
            loop=self.loop,
            state_store=self.state_store,
            bridge_name=self.name,
            # appservice.max_body_size is multiplied by 1024**2, i.e. treated as MiB.
            aiohttp_params={"client_max_size": self.config["appservice.max_body_size"] * mb},
        )
        self.az.app.router.add_post("/_matrix/app/com.beeper.bridge_state", self.get_bridge_state)

    def prepare_db(self) -> None:
        """Create (but don't start) the database from ``appservice.database`` config.

        Raises:
            RuntimeError: if the subclass did not set a truthy ``upgrade_table``.
        """
        if not getattr(self, "upgrade_table", None):
            raise RuntimeError("upgrade_table is not set")
        self.db = Database.create(
            self.config["appservice.database"],
            upgrade_table=self.upgrade_table,
            db_args=self.config["appservice.database_opts"],
            owner_name=self.name,
            ignore_foreign_tables=self.args.ignore_foreign_tables,
        )

    def prepare_bridge(self) -> None:
        """Instantiate the Matrix handler from ``matrix_class``."""
        self.matrix = self.matrix_class(bridge=self)

    def _log_db_error(self, e: Exception) -> None:
        """Log a database initialisation failure (plus any explanation) and exit with status 25."""
        self.log.critical("Failed to initialize database", exc_info=e)
        if isinstance(e, DatabaseException) and e.explanation:
            self.log.info(e.explanation)
        sys.exit(25)

    async def start_db(self) -> None:
        """Start the database and upgrade dependent schemas.

        Also upgrades the client state store when it is Postgres-backed and hands
        the pool to the E2EE crypto store if encryption is enabled. Any failure
        is routed to :meth:`_log_db_error`.
        """
        if isinstance(getattr(self, "db", None), Database):
            self.log.debug("Starting database...")
            ignore_unsupported = self.args.ignore_unsupported_database
            self.db.upgrade_table.allow_unsupported = ignore_unsupported
            try:
                await self.db.start()
                if isinstance(self.state_store, PgClientStateStore):
                    self.state_store.upgrade_table.allow_unsupported = ignore_unsupported
                    await self.state_store.upgrade_table.upgrade(self.db)
                if self.matrix.e2ee:
                    self.matrix.e2ee.crypto_db.allow_unsupported = ignore_unsupported
                    self.matrix.e2ee.crypto_db.override_pool(self.db)
            except Exception as e:
                self._log_db_error(e)

    async def stop_db(self) -> None:
        """Stop the database if one was created."""
        if isinstance(getattr(self, "db", None), Database):
            await self.db.stop()

    async def start(self) -> None:
        """Start the database and appservice, verify the homeserver connection, then run.

        Exits with status 16 if the homeserver connection check fails. After
        startup, if a status endpoint is configured and no users are logged in,
        an UNCONFIGURED bridge state is sent, retrying every 5 seconds until the
        send succeeds.
        """
        await self.start_db()

        self.log.debug("Starting appservice...")
        await self.az.start(self.config["appservice.hostname"], self.config["appservice.port"])
        try:
            await self.matrix.wait_for_connection()
        except MUnknownToken:
            self.log.critical(
                "The as_token was not accepted. Is the registration file installed "
                "in your homeserver correctly?"
            )
            sys.exit(16)
        except MExclusive:
            self.log.critical(
                "The as_token was accepted, but the /register request was not. "
                "Are the homeserver domain and username template in the config "
                "correct, and do they match the values in the registration?"
            )
            sys.exit(16)
        except Exception:
            self.log.critical("Failed to check connection to homeserver", exc_info=True)
            sys.exit(16)

        await self.matrix.init_encryption()
        self.add_startup_actions(self.matrix.init_as_bot())
        await super().start()
        self.az.ready = True

        status_endpoint = self.config["homeserver.status_endpoint"]
        if status_endpoint and await self.count_logged_in_users() == 0:
            state = BridgeState(state_event=BridgeStateEvent.UNCONFIGURED).fill()
            while not await state.send(status_endpoint, self.az.as_token, self.log):
                await asyncio.sleep(5)

    async def system_exit(self) -> None:
        """Stop the database when the program is exiting via SystemExit."""
        if isinstance(getattr(self, "db", None), Database):
            self.log.debug("Stopping database due to SystemExit")
            await self.db.stop()
            self.log.debug("Database stopped")
        elif getattr(self, "db", None):
            self.log.trace("Database not started at SystemExit")

    async def stop(self) -> None:
        """Close the manhole, stop the appservice, base program, E2EE and database, in that order."""
        if self.manhole:
            self.manhole.close()
            self.manhole = None
        await self.az.stop()
        await super().stop()
        if self.matrix.e2ee:
            await self.matrix.e2ee.stop()
        await self.stop_db()

    async def get_bridge_state(self, req: web.Request) -> web.Response:
        """Handle ``POST /_matrix/app/com.beeper.bridge_state`` for the ``user_id`` query param.

        Returns 401 if the token check fails, 404 if ``user_id`` is missing or the
        user doesn't exist, 501 if the user class doesn't implement bridge states,
        and otherwise the serialized :class:`GlobalBridgeState`.
        """
        if not self.az._check_token(req):
            return web.json_response({"error": "Invalid auth token"}, status=401)
        # Only the query lookup may be missing; keep that separate from get_user() so a
        # KeyError raised inside get_user() isn't misreported as "User not found".
        raw_user_id = req.url.query.get("user_id")
        user = (
            await self.get_user(UserID(raw_user_id), create=False)
            if raw_user_id is not None
            else None
        )
        if user is None:
            return web.json_response({"error": "User not found"}, status=404)
        try:
            states = await user.get_bridge_states()
        except NotImplementedError:
            return web.json_response({"error": "Bridge status not implemented"}, status=501)
        for state in states:
            await user.fill_bridge_state(state)
        global_state = BridgeState(state_event=BridgeStateEvent.RUNNING).fill()
        evt = GlobalBridgeState(
            remote_states={state.remote_id: state for state in states}, bridge_state=global_state
        )
        return web.json_response(evt.serialize())

    @abstractmethod
    async def get_user(self, user_id: UserID, create: bool = True) -> br.BaseUser | None:
        """Look up (and, if ``create``, create) the bridge user for a Matrix user ID."""
        pass

    @abstractmethod
    async def get_portal(self, room_id: RoomID) -> br.BasePortal | None:
        """Look up the portal for a Matrix room ID."""
        pass

    @abstractmethod
    async def get_puppet(self, user_id: UserID, create: bool = False) -> br.BasePuppet | None:
        """Look up (and, if ``create``, create) the puppet for a Matrix user ID."""
        pass

    @abstractmethod
    async def get_double_puppet(self, user_id: UserID) -> br.BasePuppet | None:
        """Look up the double puppet for a Matrix user ID."""
        pass

    @abstractmethod
    def is_bridge_ghost(self, user_id: UserID) -> bool:
        """Whether ``user_id`` belongs to this bridge's ghost users."""
        pass

    @abstractmethod
    async def count_logged_in_users(self) -> int:
        """Return the number of logged-in users (used to decide whether to report UNCONFIGURED)."""
        return 0

    async def manhole_global_namespace(self, user_id: UserID) -> dict[str, Any]:
        """Build the globals exposed in a manhole session opened by ``user_id``.

        ``own_user`` is ``None`` if the user doesn't exist; ``own_puppet`` is
        ``None`` if there is no user or the user class doesn't implement
        ``get_puppet``.
        """
        own_user = await self.get_user(user_id, create=False)
        own_puppet = None
        # get_user(create=False) may return None; don't crash opening the manhole.
        if own_user is not None:
            try:
                own_puppet = await own_user.get_puppet()
            except NotImplementedError:
                own_puppet = None
        return {
            "bridge": self,
            "manhole": self.manhole,
            "own_user": own_user,
            "own_puppet": own_puppet,
        }

    @property
    def manhole_banner_python_version(self) -> str:
        """Python interpreter version/platform line for the manhole banner."""
        return f"Python {sys.version} on {sys.platform}"

    @property
    def manhole_banner_program_version(self) -> str:
        """Bridge and mautrix-python version line for the manhole banner."""
        return f"{self.name} {self.version} with mautrix-python {__mautrix_version__}"

    def manhole_banner(self, user_id: UserID) -> str:
        """Return the banner text shown when ``user_id`` opens a manhole."""
        return (
            f"{self.manhole_banner_python_version}\n"
            f"{self.manhole_banner_program_version}\n\n"
            f"Manhole opened by {user_id}\n"
        )
|