# Source: mautrix-python/mautrix/bridge/e2ee.py (375 lines, 16 KiB, Python)

# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations
import asyncio
import logging
import sys
from mautrix import __optional_imports__
from mautrix.appservice import AppService
from mautrix.client import Client, InternalEventType, SyncStore
from mautrix.crypto import CryptoStore, OlmMachine, PgCryptoStore, RejectKeyShare, StateStore
from mautrix.errors import EncryptionError, MForbidden, MNotFound, SessionNotFound
from mautrix.types import (
JSON,
DeviceIdentity,
EncryptedEvent,
EncryptedMegolmEventContent,
EventFilter,
EventType,
Filter,
LoginType,
MessageEvent,
RequestedKeyInfo,
RoomEventFilter,
RoomFilter,
RoomID,
RoomKeyWithheldCode,
Serializable,
StateEvent,
StateFilter,
TrustState,
)
from mautrix.util import background_task
from mautrix.util.async_db import Database
from mautrix.util.logging import TraceLogger
from .. import bridge as br
from .crypto_state_store import PgCryptoStateStore
class EncryptionManager:
    """Owns the crypto client, Olm machine and stores used for end-to-bridge encryption.

    The manager logs in as the bridge bot with a dedicated device, wires the
    :class:`OlmMachine` to either appservice transactions or a ``/sync`` loop,
    and exposes :meth:`encrypt` / :meth:`decrypt` helpers for megolm events.
    """

    loop: asyncio.AbstractEventLoop
    log: TraceLogger = logging.getLogger("mau.bridge.e2ee")

    client: Client
    crypto: OlmMachine
    crypto_store: CryptoStore | SyncStore
    crypto_db: Database | None
    state_store: StateStore

    min_send_trust: TrustState
    key_sharing_enabled: bool
    appservice_mode: bool
    periodically_delete_expired_keys: bool
    delete_outdated_inbound: bool

    bridge: br.Bridge
    az: AppService
    # Device display name passed to the appservice login (taken from ``bridge.name``).
    device_name: str
    _id_prefix: str
    _id_suffix: str

    # One event per room while a group session share is in flight; see _share_session_lock().
    _share_session_events: dict[RoomID, asyncio.Event]
    _key_delete_task: asyncio.Task | None

    def __init__(
        self,
        bridge: br.Bridge,
        homeserver_address: str,
        user_id_prefix: str,
        user_id_suffix: str,
        db_url: str,
    ) -> None:
        """Build the crypto database, stores, client and Olm machine from bridge config.

        Args:
            bridge: The bridge instance; supplies the event loop, appservice,
                config, state store and portal/user lookups.
            homeserver_address: Base URL handed to the crypto :class:`Client`.
            user_id_prefix: Prefix of bridge ghost MXIDs (see :meth:`_ignore_user`).
            user_id_suffix: Suffix of bridge ghost MXIDs (see :meth:`_ignore_user`).
            db_url: URL of the database that holds the crypto store.
        """
        self.loop = bridge.loop or asyncio.get_event_loop()
        self.bridge = bridge
        self.az = bridge.az
        self.device_name = bridge.name
        self._id_prefix = user_id_prefix
        self._id_suffix = user_id_suffix
        self._share_session_events = {}
        pickle_key = "mautrix.bridge.e2ee"
        self.crypto_db = Database.create(
            url=db_url,
            upgrade_table=PgCryptoStore.upgrade_table,
            log=logging.getLogger("mau.crypto.db"),
        )
        self.crypto_store = PgCryptoStore("", pickle_key, self.crypto_db)
        self.state_store = PgCryptoStateStore(self.crypto_db, bridge.get_portal)
        default_http_retry_count = bridge.config.get("homeserver.http_retry_count", None)
        self.client = Client(
            base_url=homeserver_address,
            mxid=self.az.bot_mxid,
            loop=self.loop,
            sync_store=self.crypto_store,
            log=self.log.getChild("client"),
            default_retry_count=default_http_retry_count,
            state_store=self.bridge.state_store,
        )
        self.crypto = OlmMachine(self.client, self.crypto_store, self.state_store)
        self.client.add_event_handler(InternalEventType.SYNC_STOPPED, self._exit_on_sync_fail)
        self.crypto.allow_key_share = self.allow_key_share

        verification_levels = bridge.config["bridge.encryption.verification_levels"]
        self.min_send_trust = TrustState.parse(verification_levels["send"])
        self.crypto.share_keys_min_trust = TrustState.parse(verification_levels["share"])
        # NOTE(review): the config's "receive" level feeds `send_keys_min_trust`;
        # presumably it governs which devices the bridge sends keys to. Confirm
        # against the OlmMachine documentation before changing this mapping.
        self.crypto.send_keys_min_trust = TrustState.parse(verification_levels["receive"])
        self.key_sharing_enabled = bridge.config["bridge.encryption.allow_key_sharing"]
        self.appservice_mode = bridge.config["bridge.encryption.appservice"]
        if self.appservice_mode:
            # In appservice mode the crypto data arrives through the appservice
            # handlers instead of the client's sync loop (see start()).
            self.az.otk_handler = self.crypto.handle_as_otk_counts
            self.az.device_list_handler = self.crypto.handle_as_device_lists
            self.az.to_device_handler = self.crypto.handle_as_to_device_event

        # Defaults used when the delete_keys config section is empty/falsy.
        self.periodically_delete_expired_keys = False
        self.delete_outdated_inbound = False
        self._key_delete_task = None
        del_cfg = bridge.config["bridge.encryption.delete_keys"]
        if del_cfg:
            self.crypto.delete_outbound_keys_on_ack = del_cfg["delete_outbound_on_ack"]
            self.crypto.dont_store_outbound_keys = del_cfg["dont_store_outbound"]
            self.crypto.delete_previous_keys_on_receive = del_cfg["delete_prev_on_new_session"]
            self.crypto.ratchet_keys_on_decrypt = del_cfg["ratchet_on_decrypt"]
            self.crypto.delete_fully_used_keys_on_decrypt = del_cfg["delete_fully_used_on_decrypt"]
            self.crypto.delete_keys_on_device_delete = del_cfg["delete_on_device_delete"]
            self.periodically_delete_expired_keys = del_cfg["periodically_delete_expired"]
            self.delete_outdated_inbound = del_cfg["delete_outdated_inbound"]
        self.crypto.disable_device_change_key_rotation = bridge.config[
            "bridge.encryption.rotation.disable_device_change_key_rotation"
        ]

    async def _exit_on_sync_fail(self, data) -> None:
        """Terminate the process (exit code 32) when the sync loop stopped with an error.

        Registered as the handler for ``InternalEventType.SYNC_STOPPED``.
        ``data`` must contain an ``"error"`` entry; a falsy value is ignored.
        """
        if data["error"]:
            self.log.critical("Exiting due to crypto sync error")
            sys.exit(32)

    async def allow_key_share(self, device: DeviceIdentity, request: RequestedKeyInfo) -> bool:
        """Decide whether a room key request from ``device`` should be fulfilled.

        Installed as ``OlmMachine.allow_key_share`` by ``__init__``.

        Returns:
            ``True`` when the key may be shared, ``False`` when key sharing is
            disabled and the request should be ignored.

        Raises:
            RejectKeyShare: When the device is blacklisted or insufficiently
                trusted, the room is not a portal, or the user is not in it.
        """
        if not self.key_sharing_enabled:
            self.log.debug(
                f"Key sharing not enabled, ignoring key request from "
                f"{device.user_id}/{device.device_id}"
            )
            return False
        elif device.trust == TrustState.BLACKLISTED:
            raise RejectKeyShare(
                f"Rejecting key request from blacklisted device "
                f"{device.user_id}/{device.device_id}",
                code=RoomKeyWithheldCode.BLACKLISTED,
                reason="Your device has been blacklisted by the bridge",
            )
        elif await self.crypto.resolve_trust(device) >= self.crypto.share_keys_min_trust:
            portal = await self.bridge.get_portal(request.room_id)
            if portal is None:
                raise RejectKeyShare(
                    f"Rejecting key request for {request.session_id} from "
                    f"{device.user_id}/{device.device_id}: room is not a portal",
                    code=RoomKeyWithheldCode.UNAVAILABLE,
                    reason="Requested room is not a portal",
                )
            user = await self.bridge.get_user(device.user_id)
            # NOTE(review): get_user() presumably can return None for some MXIDs
            # (confirm against Bridge.get_user). Treat a missing user like a
            # non-member instead of crashing with AttributeError.
            if user is None or not await user.is_in_portal(portal):
                raise RejectKeyShare(
                    f"Rejecting key request for {request.session_id} from "
                    f"{device.user_id}/{device.device_id}: user is not in portal",
                    code=RoomKeyWithheldCode.UNAUTHORIZED,
                    reason="You're not in that portal",
                )
            self.log.debug(
                f"Accepting key request for {request.session_id} from "
                f"{device.user_id}/{device.device_id}"
            )
            return True
        else:
            raise RejectKeyShare(
                f"Rejecting key request from unverified device "
                f"{device.user_id}/{device.device_id}",
                code=RoomKeyWithheldCode.UNVERIFIED,
                reason="Your device is not trusted by the bridge",
            )

    def _ignore_user(self, user_id: str) -> bool:
        """Return True if ``user_id`` matches the ghost prefix/suffix and is not the bridge bot."""
        return (
            user_id.startswith(self._id_prefix)
            and user_id.endswith(self._id_suffix)
            and user_id != self.az.bot_mxid
        )

    async def handle_member_event(self, evt: StateEvent) -> None:
        """Forward a membership event to the Olm machine unless it concerns a ghost user."""
        if self._ignore_user(evt.state_key):
            # We don't want to invalidate group sessions because a ghost left or joined
            return
        await self.crypto.handle_member_event(evt)

    async def _share_session_lock(self, room_id: RoomID) -> bool:
        """Claim the right to share the group session for ``room_id``, or wait for the owner.

        Returns:
            ``True`` if the caller registered the per-room event and is now
            responsible for removing and setting it once sharing is done;
            ``False`` if another caller held it and this call waited until that
            event was set.
        """
        try:
            event = self._share_session_events[room_id]
        except KeyError:
            # No await between the lookup and the insert, so claiming is atomic
            # with respect to other coroutines on the same loop.
            self._share_session_events[room_id] = asyncio.Event()
            return True
        else:
            await event.wait()
            return False

    async def encrypt(
        self, room_id: RoomID, event_type: EventType, content: Serializable | JSON
    ) -> tuple[EventType, EncryptedMegolmEventContent]:
        """Encrypt ``content`` with megolm, sharing a group session first if needed.

        On :class:`EncryptionError` the group session is shared with the room
        members returned by the state store (at most once concurrently per
        room) and the encryption is retried a single time.

        Returns:
            ``(EventType.ROOM_ENCRYPTED, encrypted_content)``.
        """
        try:
            encrypted = await self.crypto.encrypt_megolm_event(room_id, event_type, content)
        except EncryptionError:
            self.log.debug("Got EncryptionError, sharing group session and trying again")
            if await self._share_session_lock(room_id):
                try:
                    users = await self.az.state_store.get_members_filtered(
                        room_id, self._id_prefix, self._id_suffix, self.az.bot_mxid
                    )
                    await self.crypto.share_group_session(room_id, users)
                finally:
                    # Always release waiters, even if sharing failed, so they
                    # don't block forever on the event.
                    self._share_session_events.pop(room_id).set()
            encrypted = await self.crypto.encrypt_megolm_event(room_id, event_type, content)
        return EventType.ROOM_ENCRYPTED, encrypted

    async def decrypt(self, evt: EncryptedEvent, wait_session_timeout: int = 5) -> MessageEvent:
        """Decrypt a megolm event, optionally waiting for a missing session to arrive.

        Args:
            evt: The encrypted event.
            wait_session_timeout: Seconds to wait for the session after a
                :class:`SessionNotFound`; a falsy value disables waiting.

        Raises:
            SessionNotFound: If the session is unknown and waiting is disabled
                or the session did not arrive within the timeout.
        """
        try:
            decrypted = await self.crypto.decrypt_megolm_event(evt)
        except SessionNotFound as e:
            if not wait_session_timeout:
                raise
            self.log.debug(
                f"Couldn't find session {e.session_id} trying to decrypt {evt.event_id},"
                f" waiting {wait_session_timeout} seconds..."
            )
            got_keys = await self.crypto.wait_for_session(
                evt.room_id, e.session_id, timeout=wait_session_timeout
            )
            if got_keys:
                self.log.debug(
                    f"Got session {e.session_id} after waiting, "
                    f"trying to decrypt {evt.event_id} again"
                )
                decrypted = await self.crypto.decrypt_megolm_event(evt)
            else:
                # Re-raise the original SessionNotFound.
                raise
        self.log.trace("Decrypted event %s: %s", evt.event_id, decrypted)
        return decrypted

    async def start(self) -> None:
        """Log in as the bridge bot, load crypto state and start background work.

        Exits the process with code 30 when the homeserver does not support
        appservice login.
        """
        flows = await self.client.get_login_flows()
        if not flows.supports_type(LoginType.APPSERVICE):
            self.log.critical(
                "Encryption enabled in config, but homeserver does not support appservice login"
            )
            sys.exit(30)
        self.log.debug("Logging in with bridge bot user")
        if self.crypto_db:
            try:
                await self.crypto_db.start()
            except Exception as e:
                self.bridge._log_db_error(e)
        await self.crypto_store.open()
        device_id = await self.crypto_store.get_device_id()
        if device_id:
            self.log.debug(f"Found device ID in database: {device_id}")
        # We set the API token to the AS token here to authenticate the appservice login
        # It'll get overridden after the login
        self.client.api.token = self.az.as_token
        await self.client.login(
            login_type=LoginType.APPSERVICE,
            device_name=self.device_name,
            device_id=device_id,
            store_access_token=True,
            update_hs_url=False,
        )
        await self.crypto.load()
        if not device_id:
            await self.crypto_store.put_device_id(self.client.device_id)
            self.log.debug(f"Logged in with new device ID {self.client.device_id}")
        elif self.crypto.account.shared:
            await self._verify_keys_are_on_server()
        if self.appservice_mode:
            self.log.info("End-to-bridge encryption support is enabled (appservice mode)")
        else:
            # The returned handle is intentionally discarded; stop() ends syncing.
            _ = self.client.start(self._filter)
            self.log.info("End-to-bridge encryption support is enabled (sync mode)")
        if self.delete_outdated_inbound:
            deleted = await self.crypto_store.redact_outdated_group_sessions()
            if deleted:
                self.log.debug(
                    f"Deleted {len(deleted)} inbound keys which lacked expiration metadata"
                )
        if self.periodically_delete_expired_keys:
            self._key_delete_task = background_task.create(self._periodically_delete_keys())
        background_task.create(self._resync_encryption_info())

    async def _resync_encryption_info(self) -> None:
        """Refetch the encryption state event for rooms flagged with ``{"resync":true}``.

        For each flagged room: if the state event can't be fetched (not found /
        forbidden) the stored encryption info is cleared; otherwise inbound
        megolm sessions of that room that lack rotation metadata are updated
        with the room's rotation period settings.
        """
        if not self.crypto_db:
            # crypto_db is declared optional and guarded in start()/stop();
            # without a database there is nothing to resync.
            return
        rows = await self.crypto_db.fetch(
            """SELECT room_id FROM mx_room_state WHERE encryption='{"resync":true}'"""
        )
        room_ids = [row["room_id"] for row in rows]
        if not room_ids:
            return
        self.log.debug(f"Resyncing encryption state event in rooms: {room_ids}")
        for room_id in room_ids:
            try:
                evt = await self.client.get_state_event(room_id, EventType.ROOM_ENCRYPTION)
            except (MNotFound, MForbidden) as e:
                self.log.debug(f"Failed to get encryption state in {room_id}: {e}")
                q = """
                    UPDATE mx_room_state SET encryption=NULL
                    WHERE room_id=$1 AND encryption='{"resync":true}'
                """
                await self.crypto_db.execute(q, room_id)
            else:
                self.log.debug(f"Resynced encryption state in {room_id}: {evt}")
                q = """
                    UPDATE crypto_megolm_inbound_session SET max_age=$1, max_messages=$2
                    WHERE room_id=$3 AND max_age IS NULL and max_messages IS NULL
                """
                await self.crypto_db.execute(
                    q, evt.rotation_period_ms, evt.rotation_period_msgs, room_id
                )

    async def _verify_keys_are_on_server(self) -> None:
        """Ensure the homeserver still has device keys for the stored device.

        Exits with code 33 if the key query itself fails. If the device has no
        keys on the server, the crypto store is deleted, all sessions are
        logged out and the process exits with code 34.
        """
        self.log.debug("Making sure keys are still on server")
        try:
            resp = await self.client.query_keys([self.client.mxid])
        except Exception:
            self.log.critical(
                "Failed to query own keys to make sure device still exists", exc_info=True
            )
            sys.exit(33)
        try:
            own_keys = resp.device_keys[self.client.mxid][self.client.device_id]
            if len(own_keys.keys) > 0:
                return
        except KeyError:
            # A missing user/device entry is handled the same as an empty key list.
            pass
        self.log.critical("Existing device doesn't have keys on server, resetting crypto")
        await self.crypto.crypto_store.delete()
        await self.client.logout_all()
        sys.exit(34)

    async def stop(self) -> None:
        """Cancel the key deletion task, stop syncing and close the crypto store and database."""
        if self._key_delete_task:
            self._key_delete_task.cancel()
            self._key_delete_task = None
        self.client.stop()
        await self.crypto_store.close()
        if self.crypto_db:
            await self.crypto_db.stop()

    @property
    def _filter(self) -> Filter:
        """Sync filter that excludes presence and all room events, keeping account data."""
        all_events = EventType.find("*")
        return Filter(
            account_data=EventFilter(types=[all_events]),
            presence=EventFilter(not_types=[all_events]),
            room=RoomFilter(
                include_leave=False,
                state=StateFilter(not_types=[all_events]),
                timeline=RoomEventFilter(not_types=[all_events]),
                account_data=RoomEventFilter(not_types=[all_events]),
                ephemeral=RoomEventFilter(not_types=[all_events]),
            ),
        )

    async def _periodically_delete_keys(self) -> None:
        """Redact expired megolm sessions once every 24 hours until cancelled.

        A failure in one cycle is logged and the loop keeps running, so a
        transient store error doesn't disable key deletion for the rest of the
        process lifetime. Cancellation (see :meth:`stop`) still ends the loop.
        """
        while True:
            try:
                deleted = await self.crypto_store.redact_expired_group_sessions()
            except asyncio.CancelledError:
                # Explicit re-raise keeps cancellation working on Python
                # versions where CancelledError derives from Exception.
                raise
            except Exception:
                self.log.exception("Failed to delete expired megolm sessions")
            else:
                if deleted:
                    self.log.info(f"Deleted expired megolm sessions: {deleted}")
                else:
                    self.log.debug("No expired megolm sessions found")
            await asyncio.sleep(24 * 60 * 60)