356 lines
13 KiB
Python
356 lines
13 KiB
Python
# Copyright (c) 2023 Tulir Asokan
|
|
#
|
|
# This Source Code Form is subject to the terms of the Mozilla Public
|
|
# License, v. 2.0. If a copy of the MPL was not distributed with this
|
|
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
|
# Partly based on github.com/Cadair/python-appservice-framework (MIT license)
|
|
from __future__ import annotations
|
|
|
|
from typing import Any, Awaitable, Callable
|
|
from json import JSONDecodeError
|
|
import json
|
|
import logging
|
|
|
|
from aiohttp import web
|
|
|
|
from mautrix.types import (
|
|
JSON,
|
|
ASToDeviceEvent,
|
|
DeviceID,
|
|
DeviceLists,
|
|
DeviceOTKCount,
|
|
EphemeralEvent,
|
|
Event,
|
|
EventType,
|
|
RoomAlias,
|
|
SerializerError,
|
|
UserID,
|
|
)
|
|
from mautrix.util import background_task
|
|
|
|
# Signature of an event handler: a callable that takes one Event and returns an awaitable.
HandlerFunc = Callable[[Event], Awaitable]
|
|
|
|
|
|
class AppServiceServerMixin:
    """Mixin with the aiohttp handlers for requests a homeserver sends to an appservice.

    It registers routes for transactions, user and room alias queries and pings,
    authenticates every request against ``hs_token`` and dispatches the contents
    of each transaction to the registered handlers.
    """

    log: logging.Logger

    hs_token: str
    ephemeral_events: bool
    encryption_events: bool
    synchronous_handlers: bool

    # Both query callbacks are awaited by the HTTP handlers, so they must return awaitables.
    query_user: Callable[[UserID], Awaitable[JSON]]
    query_alias: Callable[[RoomAlias], Awaitable[JSON]]

    # IDs of transactions that have already been handled. Entries are never removed here.
    transactions: set[str]
    event_handlers: list[HandlerFunc]
    to_device_handler: HandlerFunc | None
    otk_handler: Callable[[dict[UserID, dict[DeviceID, DeviceOTKCount]]], Awaitable] | None
    device_list_handler: Callable[[DeviceLists], Awaitable] | None

    def __init__(
        self,
        ephemeral_events: bool = False,
        encryption_events: bool = False,
        log: logging.Logger | None = None,
        hs_token: str | None = None,
    ) -> None:
        """Initialise handler state.

        Args:
            ephemeral_events: Whether the ``ephemeral`` field of transactions is read.
            encryption_events: Whether to-device events, device list changes and
                one-time key counts are read from transactions.
            log: Logger to use. When ``None``, ``self.log`` is left untouched.
            hs_token: Token that incoming requests must present. When ``None``,
                ``self.hs_token`` is left untouched.
        """
        if log is not None:
            self.log = log
        if hs_token is not None:
            self.hs_token = hs_token
        self.transactions = set()
        self.event_handlers = []
        self.to_device_handler = None
        self.otk_handler = None
        self.device_list_handler = None
        self.ephemeral_events = ephemeral_events
        self.encryption_events = encryption_events
        self.synchronous_handlers = False

        async def default_query_handler(_):
            # A falsy result makes the query endpoints answer with 404.
            return None

        self.query_user = default_query_handler
        self.query_alias = default_query_handler

    def register_routes(self, app: web.Application) -> None:
        """Register all handlers on ``app``, both with and without the ``/_matrix/app/v1`` prefix.

        The ping endpoint is only registered under the prefixed path.
        """
        app.router.add_route(
            "PUT", "/transactions/{transaction_id}", self._http_handle_transaction
        )
        app.router.add_route("GET", "/rooms/{alias}", self._http_query_alias)
        app.router.add_route("GET", "/users/{user_id}", self._http_query_user)
        app.router.add_route(
            "PUT", "/_matrix/app/v1/transactions/{transaction_id}", self._http_handle_transaction
        )
        app.router.add_route("GET", "/_matrix/app/v1/rooms/{alias}", self._http_query_alias)
        app.router.add_route("GET", "/_matrix/app/v1/users/{user_id}", self._http_query_user)
        app.router.add_route("POST", "/_matrix/app/v1/ping", self._http_ping)

    def _check_token(self, request: web.Request) -> bool:
        """Return whether ``request`` carries a token equal to ``self.hs_token``.

        The ``access_token`` query parameter takes precedence; the ``Authorization``
        header (with an optional ``Bearer `` prefix stripped) is the fallback.
        """
        # Function-scope import so this class does not depend on a new module-level import.
        import hmac

        try:
            token = request.rel_url.query["access_token"]
        except KeyError:
            try:
                token = request.headers["Authorization"].removeprefix("Bearer ")
            except KeyError:
                self.log.debug("No access_token nor Authorization header in request")
                return False

        # Compare in constant time so response timing does not leak how much of the
        # token matched. compare_digest() only accepts non-ASCII input as bytes, and
        # "surrogatepass" keeps encoding from failing on lone surrogates that may
        # result from decoding malformed request bytes.
        token_matches = hmac.compare_digest(
            token.encode("utf-8", "surrogatepass"),
            self.hs_token.encode("utf-8", "surrogatepass"),
        )
        if not token_matches:
            self.log.debug("Incorrect hs_token in request")
            return False

        return True

    async def _http_query_user(self, request: web.Request) -> web.Response:
        """Handle a user query by delegating to ``self.query_user``.

        Responds with 401 on a bad token, 500 if the callback raises, 404 if the
        callback returns a falsy value, and the callback's result as JSON otherwise.
        """
        if not self._check_token(request):
            return web.json_response({"error": "Invalid auth token"}, status=401)

        try:
            user_id = request.match_info["user_id"]
        except KeyError:
            return web.json_response({"error": "Missing user_id parameter"}, status=400)

        try:
            response = await self.query_user(user_id)
        except Exception:
            self.log.exception("Exception in user query handler")
            return web.json_response({"error": "Internal appservice error"}, status=500)

        if not response:
            return web.json_response({}, status=404)
        return web.json_response(response)

    async def _http_query_alias(self, request: web.Request) -> web.Response:
        """Handle a room alias query by delegating to ``self.query_alias``.

        Responds with 401 on a bad token, 500 if the callback raises, 404 if the
        callback returns a falsy value, and the callback's result as JSON otherwise.
        """
        if not self._check_token(request):
            return web.json_response({"error": "Invalid auth token"}, status=401)

        try:
            alias = request.match_info["alias"]
        except KeyError:
            return web.json_response({"error": "Missing alias parameter"}, status=400)

        try:
            response = await self.query_alias(alias)
        except Exception:
            self.log.exception("Exception in alias query handler")
            return web.json_response({"error": "Internal appservice error"}, status=500)

        if not response:
            return web.json_response({}, status=404)
        return web.json_response(response)

    async def _http_ping(self, request: web.Request) -> web.Response:
        """Handle a ping: log the ``transaction_id`` from the body and reply with ``{}``.

        Raises:
            web.HTTPUnauthorized: The request token is missing or wrong.
            web.HTTPBadRequest: The body is not JSON, or is JSON but not an object.
        """
        if not self._check_token(request):
            raise web.HTTPUnauthorized(
                content_type="application/json",
                text=json.dumps({"error": "Invalid auth token", "errcode": "M_UNKNOWN_TOKEN"}),
            )
        try:
            body = await request.json()
        except JSONDecodeError:
            raise web.HTTPBadRequest(
                content_type="application/json",
                text=json.dumps({"error": "Body is not JSON", "errcode": "M_NOT_JSON"}),
            )
        # Valid JSON is not necessarily an object; without this check .get() below
        # would fail with an AttributeError and surface as a 500.
        if not isinstance(body, dict):
            raise web.HTTPBadRequest(
                content_type="application/json",
                text=json.dumps({"error": "Body is not a JSON object", "errcode": "M_BAD_JSON"}),
            )
        txn_id = body.get("transaction_id")
        self.log.info(f"Received ping from homeserver with transaction ID {txn_id}")
        return web.json_response({})

    @staticmethod
    def _get_with_fallback(
        json: dict[str, Any], field: str, unstable_prefix: str, default: Any = None
    ) -> Any:
        """Pop ``field`` from ``json``, falling back to ``"{unstable_prefix}.{field}"``.

        Only the key that is found is removed. ``default`` is returned when neither
        key is present. Note that the ``json`` parameter shadows the ``json`` module
        inside this function.
        """
        try:
            return json.pop(field)
        except KeyError:
            try:
                return json.pop(f"{unstable_prefix}.{field}")
            except KeyError:
                return default

    async def _read_transaction_header(self, request: web.Request) -> tuple[str, dict[str, Any]]:
        """Authenticate a transaction request and return its ID and parsed body.

        Raises:
            web.HTTPUnauthorized: The request token is missing or wrong.
            web.HTTPOk: The transaction ID was already handled; raising short-circuits
                the caller with an empty successful response.
            web.HTTPBadRequest: The body is not JSON, or is JSON but not an object.
        """
        if not self._check_token(request):
            raise web.HTTPUnauthorized(
                content_type="application/json",
                text=json.dumps({"error": "Invalid auth token", "errcode": "M_UNKNOWN_TOKEN"}),
            )

        transaction_id = request.match_info["transaction_id"]
        if transaction_id in self.transactions:
            raise web.HTTPOk(content_type="application/json", text="{}")

        try:
            data = await request.json()
        except JSONDecodeError:
            raise web.HTTPBadRequest(
                content_type="application/json",
                text=json.dumps({"error": "Body is not JSON", "errcode": "M_NOT_JSON"}),
            )
        # The caller pops keys from the body, which only works on a JSON object.
        if not isinstance(data, dict):
            raise web.HTTPBadRequest(
                content_type="application/json",
                text=json.dumps({"error": "Body is not a JSON object", "errcode": "M_BAD_JSON"}),
            )
        return transaction_id, data

    async def _http_handle_transaction(self, request: web.Request) -> web.Response:
        """Handle a transaction: split the body into its parts and call ``handle_transaction``.

        ``events`` is required. Ephemeral and encryption-related fields are only read
        when the corresponding flag is enabled; otherwise they stay in the body and
        are passed along as part of ``extra_data``. Exceptions raised by
        ``handle_transaction`` are logged and the transaction is still recorded as
        handled, so the response is successful either way.

        Raises:
            web.HTTPBadRequest: The body has no ``events`` key.
        """
        transaction_id, data = await self._read_transaction_header(request)

        # Human-readable summary of the transaction contents, used for logging only.
        txn_content_log = []
        try:
            events = data.pop("events")
            if events:
                txn_content_log.append(f"{len(events)} PDUs")
        except KeyError:
            raise web.HTTPBadRequest(
                content_type="application/json",
                text=json.dumps(
                    {"error": "Missing events object in body", "errcode": "M_BAD_JSON"}
                ),
            )

        if self.ephemeral_events:
            ephemeral = self._get_with_fallback(data, "ephemeral", "de.sorunome.msc2409")
            if ephemeral:
                txn_content_log.append(f"{len(ephemeral)} EDUs")
        else:
            ephemeral = None
        if self.encryption_events:
            to_device = self._get_with_fallback(data, "to_device", "de.sorunome.msc2409")
            device_lists = DeviceLists.deserialize(
                self._get_with_fallback(data, "device_lists", "org.matrix.msc3202")
            )
            otk_counts = {
                user_id: {
                    device_id: DeviceOTKCount.deserialize(count)
                    for device_id, count in devices.items()
                }
                for user_id, devices in self._get_with_fallback(
                    data, "device_one_time_keys_count", "org.matrix.msc3202", default={}
                ).items()
            }
            if to_device:
                txn_content_log.append(f"{len(to_device)} to-device events")
            if device_lists.changed:
                txn_content_log.append(f"{len(device_lists.changed)} device list changes")
            if otk_counts:
                txn_content_log.append(
                    f"{sum(len(vals) for vals in otk_counts.values())} OTK counts"
                )
        else:
            otk_counts = {}
            device_lists = None
            to_device = None

        # Format the summary as "a, b and c".
        if len(txn_content_log) > 2:
            txn_content_log = [", ".join(txn_content_log[:-1]), txn_content_log[-1]]
        if not txn_content_log:
            txn_description = "nothing?"
        else:
            txn_description = " and ".join(txn_content_log)
        self.log.debug(f"Handling transaction {transaction_id} with {txn_description}")

        try:
            output = await self.handle_transaction(
                transaction_id,
                events=events,
                extra_data=data,
                ephemeral=ephemeral,
                to_device=to_device,
                device_lists=device_lists,
                otk_counts=otk_counts,
            )
        except Exception:
            self.log.exception("Exception in transaction handler")
            output = None
        finally:
            self.log.debug(f"Finished handling transaction {transaction_id}")

        # Recorded even after a failure, so a retry of this ID is answered with an
        # empty success response instead of being processed again.
        self.transactions.add(transaction_id)

        return web.json_response(output or {})

    @staticmethod
    def _fix_prev_content(raw_event: JSON) -> None:
        """Normalise ``raw_event`` in place so ``prev_content`` lives under ``unsigned``.

        A ``null`` ``unsigned`` value is removed first. If ``unsigned.prev_content``
        is then missing and the event has a top-level ``prev_content``, the latter is
        copied into ``unsigned`` (creating ``unsigned`` if necessary).
        """
        try:
            if raw_event["unsigned"] is None:
                del raw_event["unsigned"]
        except KeyError:
            pass
        try:
            raw_event["unsigned"]["prev_content"]
        except KeyError:
            try:
                raw_event.setdefault("unsigned", {})["prev_content"] = raw_event["prev_content"]
            except KeyError:
                pass

    async def handle_transaction(
        self,
        txn_id: str,
        *,
        events: list[JSON],
        extra_data: JSON,
        ephemeral: list[JSON] | None = None,
        to_device: list[JSON] | None = None,
        otk_counts: dict[UserID, dict[DeviceID, DeviceOTKCount]] | None = None,
        device_lists: DeviceLists | None = None,
    ) -> JSON:
        """Dispatch the contents of one transaction to the registered handlers.

        Parts are processed in this order: to-device events, device list changes,
        one-time key counts, ephemeral events, then room events. Events that fail to
        deserialize are logged and skipped. ``txn_id`` and ``extra_data`` are not
        used by this implementation.

        Returns:
            An empty dict.
        """
        # Like the device list and OTK handlers below, to-device events are only
        # processed when a handler is registered. Without this guard, every event
        # would trigger a logged TypeError from awaiting a call on None.
        if to_device and self.to_device_handler:
            for raw_td in to_device:
                try:
                    td = ASToDeviceEvent.deserialize(raw_td)
                except SerializerError:
                    self.log.exception("Failed to deserialize to-device event %s", raw_td)
                else:
                    try:
                        await self.to_device_handler(td)
                    except Exception:
                        self.log.exception("Exception in Matrix to-device event handler")
        if device_lists and self.device_list_handler:
            try:
                await self.device_list_handler(device_lists)
            except Exception:
                self.log.exception("Exception in Matrix device list change handler")
        if otk_counts and self.otk_handler:
            try:
                await self.otk_handler(otk_counts)
            except Exception:
                self.log.exception("Exception in Matrix OTK count handler")
        for raw_edu in ephemeral or []:
            try:
                edu = EphemeralEvent.deserialize(raw_edu)
            except SerializerError:
                self.log.exception("Failed to deserialize ephemeral event %s", raw_edu)
            else:
                await self.handle_matrix_event(edu, ephemeral=True)
        for raw_event in events:
            try:
                self._fix_prev_content(raw_event)
                event = Event.deserialize(raw_event)
            except SerializerError:
                self.log.exception("Failed to deserialize event %s", raw_event)
            else:
                await self.handle_matrix_event(event)
        return {}

    async def handle_matrix_event(self, event: Event, ephemeral: bool = False) -> None:
        """Set the class of ``event.type`` and pass the event to every registered handler.

        The type class is EPHEMERAL when ``ephemeral`` is true, STATE when the event
        has a non-``None`` ``state_key``, and MESSAGE otherwise.

        With ``synchronous_handlers`` enabled, handlers are awaited one after another
        and an exception raised by a handler propagates to the caller. Otherwise each
        handler is started as a background task and its exceptions are only logged.
        """
        if ephemeral:
            event.type = event.type.with_class(EventType.Class.EPHEMERAL)
        elif getattr(event, "state_key", None) is not None:
            event.type = event.type.with_class(EventType.Class.STATE)
        else:
            event.type = event.type.with_class(EventType.Class.MESSAGE)

        async def try_handle(handler_func: HandlerFunc):
            try:
                await handler_func(event)
            except Exception:
                self.log.exception("Exception in Matrix event handler")

        if self.synchronous_handlers:
            for handler in self.event_handlers:
                await handler(event)
        else:
            for handler in self.event_handlers:
                background_task.create(try_handle(handler))

    def matrix_event_handler(self, func: HandlerFunc) -> HandlerFunc:
        """Register ``func`` as an event handler and return it, so it works as a decorator."""
        self.event_handlers.append(func)
        return func
|