316 lines
14 KiB
Python
316 lines
14 KiB
Python
# Copyright (c) 2023 Tulir Asokan
|
|
#
|
|
# This Source Code Form is subject to the terms of the Mozilla Public
|
|
# License, v. 2.0. If a copy of the MPL was not distributed with this
|
|
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
|
from __future__ import annotations
|
|
|
|
from typing import Optional
|
|
import asyncio
|
|
import logging
|
|
import time
|
|
|
|
from mautrix import client as cli
|
|
from mautrix.errors import GroupSessionWithheldError
|
|
from mautrix.types import (
|
|
ASToDeviceEvent,
|
|
DecryptedOlmEvent,
|
|
DeviceID,
|
|
DeviceLists,
|
|
DeviceOTKCount,
|
|
EncryptionAlgorithm,
|
|
EncryptionKeyAlgorithm,
|
|
EventType,
|
|
Member,
|
|
Membership,
|
|
StateEvent,
|
|
ToDeviceEvent,
|
|
TrustState,
|
|
UserID,
|
|
)
|
|
from mautrix.util import background_task
|
|
from mautrix.util.logging import TraceLogger
|
|
|
|
from .account import OlmAccount
|
|
from .decrypt_megolm import MegolmDecryptionMachine
|
|
from .encrypt_megolm import MegolmEncryptionMachine
|
|
from .key_request import KeyRequestingMachine
|
|
from .key_share import KeySharingMachine
|
|
from .store import CryptoStore, StateStore
|
|
from .unwedge import OlmUnwedgingMachine
|
|
|
|
|
|
class OlmMachine(
    MegolmEncryptionMachine,
    MegolmDecryptionMachine,
    OlmUnwedgingMachine,
    KeySharingMachine,
    KeyRequestingMachine,
):
    """
    OlmMachine is the main class for handling things related to Matrix end-to-end encryption with
    Olm and Megolm. Users primarily need :meth:`encrypt_megolm_event`, :meth:`share_group_session`,
    and :meth:`decrypt_megolm_event`. Tracking device lists, establishing Olm sessions and handling
    Megolm group sessions is handled internally.
    """

    client: cli.Client
    log: TraceLogger
    crypto_store: CryptoStore
    state_store: StateStore

    # None until :meth:`load` has been awaited.
    account: Optional[OlmAccount]

    def __init__(
        self,
        client: cli.Client,
        crypto_store: CryptoStore,
        state_store: StateStore,
        log: Optional[TraceLogger] = None,
    ) -> None:
        """
        Create an OlmMachine and register its event handlers on ``client``.

        The Olm account is not loaded here (``self.account`` stays ``None``); await :meth:`load`
        before handling events that need the account.

        Args:
            client: The Matrix client whose events are handled and which is used to upload keys.
            crypto_store: Store holding the Olm account and group sessions.
            state_store: Store used e.g. to check whether a room is encrypted.
            log: Logger to use. Defaults to the ``mau.crypto`` logger.
        """
        # Cooperative init so every mixin base class gets initialised.
        super().__init__()
        self.client = client
        self.log = log or logging.getLogger("mau.crypto")
        self.crypto_store = crypto_store
        self.state_store = state_store
        self.account = None

        # Trust thresholds and key-sharing policy; presumably consumed by the mixin machines
        # (not visible in this class) — confirm there before changing defaults.
        self.send_keys_min_trust = TrustState.UNVERIFIED
        self.share_keys_min_trust = TrustState.CROSS_SIGNED_TOFU
        self.allow_key_share = self.default_allow_key_share

        # Behaviour toggles, all disabled by default.
        self.delete_outbound_keys_on_ack = False
        self.dont_store_outbound_keys = False
        self.delete_previous_keys_on_receive = False
        self.ratchet_keys_on_decrypt = False
        self.delete_fully_used_keys_on_decrypt = False
        self.delete_keys_on_device_delete = False
        self.disable_device_change_key_rotation = False

        self._fetch_keys_lock = asyncio.Lock()
        self._megolm_decrypt_lock = asyncio.Lock()
        self._share_keys_lock = asyncio.Lock()
        # Backdate by 60 s so the "recent key share" re-check window in _share_keys is already
        # closed at startup.
        self._last_key_share = time.monotonic() - 60
        self._key_request_waiters = {}
        self._inbound_session_waiters = {}
        self._prev_unwedge = {}
        self._cs_fetch_attempted = set()

        self.client.add_event_handler(
            cli.InternalEventType.DEVICE_OTK_COUNT, self.handle_otk_count, wait_sync=True
        )
        self.client.add_event_handler(cli.InternalEventType.DEVICE_LISTS, self.handle_device_lists)
        self.client.add_event_handler(EventType.TO_DEVICE_ENCRYPTED, self.handle_to_device_event)
        self.client.add_event_handler(EventType.ROOM_KEY_REQUEST, self.handle_room_key_request)
        self.client.add_event_handler(EventType.BEEPER_ROOM_KEY_ACK, self.handle_beep_room_key_ack)
        # NOTE: withheld-key handling is intentionally not registered yet.
        # self.client.add_event_handler(EventType.ROOM_KEY_WITHHELD, self.handle_room_key_withheld)
        # self.client.add_event_handler(EventType.ORG_MATRIX_ROOM_KEY_WITHHELD,
        #                               self.handle_room_key_withheld)
        self.client.add_event_handler(EventType.ROOM_MEMBER, self.handle_member_event)

    async def load(self) -> None:
        """Load the Olm account into memory, or create one if the store doesn't have one stored."""
        self.account = await self.crypto_store.get_account()
        if self.account is None:
            self.account = OlmAccount()
            # Persist the freshly created account immediately so it survives a restart.
            await self.crypto_store.put_account(self.account)

    async def handle_as_otk_counts(
        self, otk_counts: dict[UserID, dict[DeviceID, DeviceOTKCount]]
    ) -> None:
        """
        Handle a per-user, per-device mapping of one-time key counts.

        Presumably used when counts arrive through appservice transactions rather than sync.
        Only the entry for this client's own user and device is forwarded to
        :meth:`handle_otk_count`; any other entry is logged as a warning and ignored.
        """
        for user_id, devices in otk_counts.items():
            for device_id, count in devices.items():
                if user_id == self.client.mxid and device_id == self.client.device_id:
                    await self.handle_otk_count(count)
                else:
                    self.log.warning(f"Got OTK count for unknown device {user_id}/{device_id}")

    async def handle_as_device_lists(self, device_lists: DeviceLists) -> None:
        """
        Schedule :meth:`handle_device_lists` as a background task.

        This returns without waiting for the device list update to finish.
        """
        background_task.create(self.handle_device_lists(device_lists))

    async def handle_as_to_device_event(self, evt: ASToDeviceEvent) -> None:
        """
        Dispatch a to-device event addressed to a specific user/device.

        Events addressed to anything other than this client's own user and device are logged
        as a warning and dropped. Encrypted events, room key requests and Beeper room key acks
        are routed to their handlers; any other type is logged at debug level.
        """
        if evt.to_user_id != self.client.mxid or evt.to_device_id != self.client.device_id:
            self.log.warning(
                f"Got to-device event for unknown device {evt.to_user_id}/{evt.to_device_id}"
            )
            return
        if evt.type == EventType.TO_DEVICE_ENCRYPTED:
            await self.handle_to_device_event(evt)
        elif evt.type == EventType.ROOM_KEY_REQUEST:
            await self.handle_room_key_request(evt)
        elif evt.type == EventType.BEEPER_ROOM_KEY_ACK:
            await self.handle_beep_room_key_ack(evt)
        else:
            self.log.debug(f"Got unknown to-device event {evt.type} from {evt.sender}")

    async def handle_otk_count(self, otk_count: DeviceOTKCount) -> None:
        """
        Handle the ``device_one_time_keys_count`` data in a sync response.

        This is automatically registered as an event handler and therefore called if the client you
        passed to the OlmMachine is syncing. You shouldn't need to call this yourself unless you
        do syncing in some manual way.

        New keys are shared once the server-side count of signed curve25519 keys drops below
        half of the account's ``max_one_time_keys``. Requires :meth:`load` to have been awaited.
        """
        if otk_count.signed_curve25519 < self.account.max_one_time_keys // 2:
            self.log.debug(
                f"Sync response said we have {otk_count.signed_curve25519} signed"
                " curve25519 keys left, sharing new ones..."
            )
            await self.share_keys(otk_count.signed_curve25519)

    async def handle_device_lists(self, device_lists: DeviceLists) -> None:
        """
        Handle the ``device_lists`` data in a sync response.

        This is automatically registered as an event handler and therefore called if the client you
        passed to the OlmMachine is syncing. You shouldn't need to call this yourself unless you
        do syncing in some manual way.

        Only users listed in ``changed`` are re-fetched, and only if they are already tracked
        (``include_untracked=False``). Fetches are serialised with ``_fetch_keys_lock``.
        """
        # Truthiness check: also tolerates a missing (None) ``changed`` list.
        if not device_lists.changed:
            return
        async with self._fetch_keys_lock:
            await self._fetch_keys(device_lists.changed, include_untracked=False)

    async def handle_member_event(self, evt: StateEvent) -> None:
        """
        Handle a new member event.

        This is automatically registered as an event handler and therefore called if the client you
        passed to the OlmMachine is syncing. You shouldn't need to call this yourself unless you
        receive events in some manual way (e.g. through appservice transactions)

        In encrypted rooms, a membership change invalidates the outbound group session. The
        exceptions are no-op transitions (invite->join, ban<->leave) and duplicate events already
        reflected in the cached previous membership.
        """
        if not await self.state_store.is_encrypted(evt.room_id):
            return
        prev = evt.prev_content.membership
        cur = evt.content.membership
        # Transitions that don't change who can read the room's messages.
        ignored_changes = {
            Membership.INVITE: Membership.JOIN,
            Membership.BAN: Membership.LEAVE,
            Membership.LEAVE: Membership.BAN,
        }
        if prev == cur or ignored_changes.get(prev) == cur:
            return
        src = getattr(evt, "source", None)
        prev_cache = evt.unsigned.get("mautrix_prev_membership")
        cached_is_member = isinstance(prev_cache, Member)
        if cached_is_member and prev_cache.membership == cur:
            self.log.debug(
                f"Got duplicate membership state event in {evt.room_id} changing {evt.state_key} "
                f"from {prev} to {cur}, cached state was {prev_cache} (event ID: {evt.event_id}, "
                f"sync source: {src})"
            )
            return
        # Only dereference ``.membership`` on a real Member; any other cached value is logged
        # as-is so a malformed cache entry cannot abort the session invalidation below.
        cached_membership = prev_cache.membership if cached_is_member else prev_cache
        self.log.debug(
            f"Got membership state event in {evt.room_id} changing {evt.state_key} from "
            f"{prev} to {cur} (event ID: {evt.event_id}, sync source: {src}, "
            f"cached: {cached_membership}), invalidating group session"
        )
        await self.crypto_store.remove_outbound_group_session(evt.room_id)

    async def handle_to_device_event(self, evt: ToDeviceEvent) -> None:
        """
        Handle an encrypted to-device event.

        This is automatically registered as an event handler and therefore called if the client you
        passed to the OlmMachine is syncing. You shouldn't need to call this yourself unless you
        do syncing in some manual way.

        After Olm decryption, room keys and forwarded room keys are imported; other decrypted
        event types are ignored here.
        """
        self.log.trace(
            f"Handling encrypted to-device event from {evt.content.sender_key} ({evt.sender})"
        )
        decrypted_evt = await self._decrypt_olm_event(evt)
        if decrypted_evt.type == EventType.ROOM_KEY:
            await self._receive_room_key(decrypted_evt)
        elif decrypted_evt.type == EventType.FORWARDED_ROOM_KEY:
            await self._receive_forwarded_room_key(decrypted_evt)

    async def _receive_room_key(self, evt: DecryptedOlmEvent) -> None:
        """
        Import a Megolm room key received over Olm as an inbound group session.

        Keys that are not ``MEGOLM_V1`` or lack an ed25519 signing key are ignored. If the
        Beeper max-messages/max-age limits are missing, they are filled in first. When
        ``delete_previous_keys_on_receive`` is set and the key isn't a scheduled one, earlier
        sessions from the same sender key in the room are redacted.
        """
        # TODO nio had a comment saying "handle this better"
        #      for the case where evt.Keys.Ed25519 is none?
        if evt.content.algorithm != EncryptionAlgorithm.MEGOLM_V1 or not evt.keys.ed25519:
            return
        if not evt.content.beeper_max_messages or not evt.content.beeper_max_age_ms:
            await self._fill_encryption_info(evt.content)
        if self.delete_previous_keys_on_receive and not evt.content.beeper_is_scheduled:
            removed_ids = await self.crypto_store.redact_group_sessions(
                evt.content.room_id, evt.sender_key, reason="received new key from device"
            )
            self.log.info(f"Redacted previous megolm sessions: {removed_ids}")
        await self._create_group_session(
            evt.sender_key,
            evt.keys.ed25519,
            evt.content.room_id,
            evt.content.session_id,
            evt.content.session_key,
            max_age=evt.content.beeper_max_age_ms,
            max_messages=evt.content.beeper_max_messages,
            is_scheduled=evt.content.beeper_is_scheduled,
        )

    async def handle_beep_room_key_ack(self, evt: ToDeviceEvent) -> None:
        """
        Handle a Beeper room key acknowledgement.

        Acks for redacted (withheld) or unknown sessions are ignored. If the acked session was
        created by this account, ``delete_outbound_keys_on_ack`` is enabled, and the ack is for
        message index 0, the stored inbound copy of the session is redacted.
        """
        try:
            sess = await self.crypto_store.get_group_session(
                evt.content.room_id, evt.content.session_id
            )
        except GroupSessionWithheldError:
            self.log.debug(
                f"Ignoring room key ack for session {evt.content.session_id}"
                " that was already redacted"
            )
            return
        if not sess:
            self.log.debug(f"Ignoring room key ack for unknown session {evt.content.session_id}")
            return
        if (
            sess.sender_key == self.account.identity_key
            and self.delete_outbound_keys_on_ack
            and evt.content.first_message_index == 0
        ):
            self.log.debug("Redacting inbound copy of outbound group session after ack")
            await self.crypto_store.redact_group_session(
                evt.content.room_id, evt.content.session_id, reason="outbound session acked"
            )
        else:
            self.log.debug(f"Received room key ack for {sess.id}")

    async def share_keys(self, current_otk_count: int | None = None) -> None:
        """
        Share any keys that need to be shared. This is automatically called from
        :meth:`handle_otk_count`, so you should not need to call this yourself.

        Args:
            current_otk_count: The current number of signed curve25519 keys present on the server.
                If omitted, the count will be fetched from the server.
        """
        async with self._share_keys_lock:
            await self._share_keys(current_otk_count)

    async def _share_keys(self, current_otk_count: int | None) -> None:
        """
        Upload device keys (first time only) and fresh one-time keys; caller must hold the lock.

        The server count is re-queried when no count was given, or when keys were shared less
        than 60 s ago and the reported count is below 10. On success the account is marked as
        shared and persisted.
        """
        if current_otk_count is None or (
            # If the last key share was recent and the new count is very low, re-check the count
            # from the server to avoid any race conditions.
            self._last_key_share + 60 > time.monotonic()
            and current_otk_count < 10
        ):
            self.log.debug("Checking OTK count on server")
            current_otk_count = (await self.client.upload_keys()).get(
                EncryptionKeyAlgorithm.SIGNED_CURVE25519, 0
            )
        # Device keys are only uploaded until the account has been shared once.
        device_keys = (
            self.account.get_device_keys(self.client.mxid, self.client.device_id)
            if not self.account.shared
            else None
        )
        one_time_keys = self.account.get_one_time_keys(
            self.client.mxid, self.client.device_id, current_otk_count
        )
        if not device_keys and not one_time_keys:
            self.log.warning("No one-time keys nor device keys got when trying to share keys")
            return
        if device_keys:
            self.log.debug("Going to upload initial account keys")
        self.log.debug(f"Uploading {len(one_time_keys)} one-time keys")
        resp = await self.client.upload_keys(one_time_keys=one_time_keys, device_keys=device_keys)
        self.account.shared = True
        self._last_key_share = time.monotonic()
        await self.crypto_store.put_account(self.account)
        self.log.debug(f"Shared keys and saved account, new keys: {resp}")