143 lines
4.6 KiB
Python
143 lines
4.6 KiB
Python
# Copyright (c) 2023 Tulir Asokan
|
|
#
|
|
# This Source Code Form is subject to the terms of the Mozilla Public
|
|
# License, v. 2.0. If a copy of the MPL was not distributed with this
|
|
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
|
from __future__ import annotations
|
|
|
|
from typing import Any, Awaitable, Callable, TypedDict
|
|
import asyncio
|
|
import functools
|
|
import json
|
|
|
|
import olm
|
|
|
|
from mautrix.errors import MForbidden, MNotFound
|
|
from mautrix.types import (
|
|
DeviceID,
|
|
EncryptionKeyAlgorithm,
|
|
EventType,
|
|
IdentityKey,
|
|
KeyID,
|
|
RequestedKeyInfo,
|
|
RoomEncryptionStateEventContent,
|
|
RoomID,
|
|
RoomKeyEventContent,
|
|
SessionID,
|
|
SigningKey,
|
|
TrustState,
|
|
UserID,
|
|
)
|
|
from mautrix.util.logging import TraceLogger
|
|
|
|
from .. import client as cli, crypto
|
|
|
|
|
|
class SignedObject(TypedDict):
    """Dict shape of a JSON object that carries signatures.

    ``verify_signature_json`` in this module strips both of these keys before
    serializing the remainder of the object for verification.
    """

    # Signatures, looked up in this module as ``signatures[user_id][key_id]``.
    signatures: dict[UserID, dict[str, str]]
    # Arbitrary extra data; it is removed before the object is verified.
    unsigned: Any
|
|
|
|
|
|
class BaseOlmMachine:
    """Shared attribute declarations and session-waiting helpers for the Olm machine.

    This class only declares its attributes; nothing here assigns them, so they
    must be set up elsewhere before the methods below are used.
    """

    # Collaborators used by the methods below.
    client: cli.Client
    log: TraceLogger
    crypto_store: crypto.CryptoStore
    state_store: crypto.StateStore

    # NOTE(review): no ``account`` module is imported in the visible part of this
    # file; the annotation is never evaluated because annotations are lazy here
    # (``from __future__ import annotations``) -- confirm the intended reference.
    account: account.OlmAccount

    # Trust thresholds and the key share callback; not read in this class body.
    send_keys_min_trust: TrustState
    share_keys_min_trust: TrustState
    allow_key_share: Callable[[crypto.DeviceIdentity, RequestedKeyInfo], Awaitable[bool]]

    # Boolean behaviour switches; not read in this class body.
    delete_outbound_keys_on_ack: bool
    dont_store_outbound_keys: bool
    delete_previous_keys_on_receive: bool
    ratchet_keys_on_decrypt: bool
    delete_fully_used_keys_on_decrypt: bool
    delete_keys_on_device_delete: bool
    disable_device_change_key_rotation: bool

    # Futures that wait for responses to a key request
    _key_request_waiters: dict[SessionID, asyncio.Future]
    # Futures that wait for a session to be received (either normally or through a key request)
    _inbound_session_waiters: dict[SessionID, asyncio.Future]

    # Internal bookkeeping; not read in this class body.
    _prev_unwedge: dict[IdentityKey, float]
    _fetch_keys_lock: asyncio.Lock
    _megolm_decrypt_lock: asyncio.Lock
    _share_keys_lock: asyncio.Lock
    _last_key_share: float
    _cs_fetch_attempted: set[UserID]

    async def wait_for_session(
        self, room_id: RoomID, session_id: SessionID, timeout: float = 3
    ) -> bool:
        """Wait until the given session is marked as received, or until the timeout.

        Args:
            room_id: Room passed to the crypto store lookup used after a timeout.
            session_id: Session to wait for; also the key of the waiter future.
            timeout: Maximum number of seconds to wait on the waiter future.

        Returns:
            The waiter future's result if it resolves in time, otherwise whatever
            ``crypto_store.has_group_session(room_id, session_id)`` returns.
        """
        waiter = self._inbound_session_waiters.get(session_id)
        if waiter is None:
            # First caller for this session: register a fresh future that later
            # callers for the same session will share.
            waiter = asyncio.get_running_loop().create_future()
            self._inbound_session_waiters[session_id] = waiter
        # The future is shielded so that this caller timing out does not cancel
        # the shared future; it stays registered in the waiter dict afterwards.
        guarded = asyncio.shield(waiter)
        try:
            return await asyncio.wait_for(guarded, timeout)
        except asyncio.TimeoutError:
            return await self.crypto_store.has_group_session(room_id, session_id)

    def _mark_session_received(self, session_id: SessionID) -> None:
        """Resolve and unregister the waiter future for a session, if there is one."""
        waiter = self._inbound_session_waiters.pop(session_id, None)
        if waiter is not None:
            waiter.set_result(True)

    async def _fill_encryption_info(self, evt: RoomKeyEventContent) -> None:
        """Fill missing ``beeper_max_*`` fields of a room key event in place.

        The room's encryption settings are taken from the state store, then from
        the server if the store has nothing, and finally from a default
        ``RoomEncryptionStateEventContent()`` if neither source provides them.

        Args:
            evt: Room key event content; only falsy ``beeper_max_age_ms`` and
                ``beeper_max_messages`` fields are overwritten.
        """
        room_encryption = await self.state_store.get_encryption_info(evt.room_id)
        if not room_encryption:
            self.log.warning(
                f"Encryption info for {evt.room_id} not found in state store, fetching from server"
            )
            try:
                room_encryption = await self.client.get_state_event(
                    evt.room_id, EventType.ROOM_ENCRYPTION
                )
            except (MNotFound, MForbidden) as e:
                self.log.warning(
                    f"Failed to get encryption info for {evt.room_id} from server: {e},"
                    " using defaults"
                )
                room_encryption = RoomEncryptionStateEventContent()
            # The server request can succeed and still yield a falsy result.
            if not room_encryption:
                self.log.warning(
                    f"Didn't find encryption info for {evt.room_id} on server either,"
                    " using defaults"
                )
                room_encryption = RoomEncryptionStateEventContent()

        # Values already present on the event take precedence over room settings.
        if not evt.beeper_max_age_ms:
            evt.beeper_max_age_ms = room_encryption.rotation_period_ms
        if not evt.beeper_max_messages:
            evt.beeper_max_messages = room_encryption.rotation_period_msgs
|
|
|
|
|
|
# Serializer used for data that gets signature-checked in this module:
# keys sorted, no whitespace between tokens, non-ASCII characters left unescaped.
canonical_json = functools.partial(
    json.dumps,
    sort_keys=True,
    separators=(",", ":"),
    ensure_ascii=False,
)
|
|
|
|
|
|
def verify_signature_json(
    data: "SignedObject", user_id: UserID, key_name: DeviceID | str, key: SigningKey
) -> bool:
    """Check whether ``data`` carries a valid signature from the given key.

    The ``unsigned`` and ``signatures`` keys are removed from a shallow copy of
    ``data`` (the caller's object is not modified), the rest is serialized with
    ``canonical_json`` and the signature is checked with ``olm.ed25519_verify``.

    Args:
        data: The signed object.
        user_id: User whose entry in ``signatures`` is consulted.
        key_name: Name combined with the ED25519 algorithm to build the key ID
            that is looked up under the user's signatures.
        key: Signing key handed to ``olm.ed25519_verify``.

    Returns:
        ``True`` if the signature verifies. ``False`` if the object has no
        signatures, has none for this user/key, the signatures structure is
        malformed, or verification fails.
    """
    data_copy = {**data}
    data_copy.pop("unsigned", None)
    signatures = data_copy.pop("signatures", None)
    if not signatures:
        # An object without signatures can never verify; report that as a failed
        # verification instead of raising KeyError.
        return False
    key_id = str(KeyID(EncryptionKeyAlgorithm.ED25519, key_name))
    try:
        signature = signatures[user_id][key_id]
    except (KeyError, TypeError):
        # KeyError: no signature by this user/key. TypeError: the signatures
        # structure is not the expected nested mapping. Both fail closed.
        return False
    signed_data = canonical_json(data_copy)
    try:
        olm.ed25519_verify(key, signed_data, signature)
    except olm.OlmVerifyError:
        return False
    return True
|