mautrix-python/mautrix/crypto/key_request.py

159 lines
5.8 KiB
Python

# Copyright (c) 2023 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from typing import Dict, List, Optional, Union
import asyncio
import uuid
from mautrix.types import (
DecryptedOlmEvent,
DeviceID,
EncryptionAlgorithm,
EventType,
ForwardedRoomKeyEventContent,
IdentityKey,
KeyRequestAction,
RequestedKeyInfo,
RoomID,
RoomKeyRequestEventContent,
SessionID,
UserID,
)
from .base import BaseOlmMachine
from .sessions import InboundGroupSession
class KeyRequestingMachine(BaseOlmMachine):
    """Requests Megolm session keys from other devices and imports forwarded keys."""

    async def request_room_key(
        self,
        room_id: RoomID,
        sender_key: IdentityKey,
        session_id: SessionID,
        from_devices: Dict[UserID, List[DeviceID]],
        timeout: Optional[Union[int, float]] = None,
    ) -> bool:
        """
        Request keys for a Megolm group session from other devices.

        Once the keys are received, or if this task is cancelled (via the ``timeout`` parameter),
        a cancel request event is sent to the remaining devices. If the ``timeout`` is set to zero
        or less, this will return immediately, and the extra key requests will not be cancelled.

        The ``from_devices`` argument is never modified; the devices that still need a
        cancel request are tracked in a private copy.

        Args:
            room_id: The room where the session is used.
            sender_key: The key of the user who created the session.
            session_id: The ID of the session.
            from_devices: A dict from user ID to list of device IDs whom to ask for the keys.
            timeout: The maximum number of seconds to wait for the keys. If the timeout is
                ``None``, the wait time is not limited, but the task can still be cancelled.
                If it's zero or less, this returns immediately and will never cancel requests.

        Returns:
            ``True`` if the keys were received and are now in the crypto store,
            ``False`` otherwise (including if the method didn't wait at all).
        """
        request_id = str(uuid.uuid1())
        request = RoomKeyRequestEventContent(
            action=KeyRequestAction.REQUEST,
            body=RequestedKeyInfo(
                algorithm=EncryptionAlgorithm.MEGOLM_V1,
                room_id=room_id,
                sender_key=sender_key,
                session_id=session_id,
            ),
            request_id=request_id,
            requesting_device_id=self.client.device_id,
        )

        wait = timeout is None or timeout > 0
        fut: Optional[asyncio.Future] = None
        if wait:
            # The waiter is registered before the request is sent, so a reply that
            # arrives while the send is still in flight already finds it.
            fut = asyncio.get_running_loop().create_future()
            self._key_request_waiters[session_id] = fut

        # Private copy of the targets: the responder is removed from this copy rather
        # than from the caller's dict.
        remaining: Dict[UserID, List[DeviceID]] = {
            user_id: list(devices) for user_id, devices in from_devices.items()
        }
        got_keys = False
        try:
            await self.client.send_to_device(
                EventType.ROOM_KEY_REQUEST,
                {
                    user_id: {device_id: request for device_id in devices}
                    for user_id, devices in from_devices.items()
                },
            )
            if fut is None:
                # Timeout is set and <=0, don't wait for keys
                return False

            try:
                responder_id, responder_device = await asyncio.wait_for(fut, timeout=timeout)
            except (asyncio.CancelledError, asyncio.TimeoutError):
                # Deliberate: both a timeout and a cancellation fall through to the
                # cancel request below and make this method return False.
                pass
            else:
                got_keys = True
                # The responding device has fulfilled the request, so it doesn't need
                # a cancel event. The device lists are lists, so filter by value
                # (deleting by a device ID index would raise TypeError).
                if responder_id in remaining:
                    others = [dev for dev in remaining[responder_id] if dev != responder_device]
                    if others:
                        remaining[responder_id] = others
                    else:
                        del remaining[responder_id]
        finally:
            # Only drop the waiter if it is still ours: a concurrent request for the
            # same session may have replaced or already removed the entry.
            if fut is not None and self._key_request_waiters.get(session_id) is fut:
                del self._key_request_waiters[session_id]

        if remaining:
            cancel = RoomKeyRequestEventContent(
                action=KeyRequestAction.CANCEL,
                request_id=request_id,
                requesting_device_id=self.client.device_id,
            )
            await self.client.send_to_device(
                EventType.ROOM_KEY_REQUEST,
                {
                    user_id: {device_id: cancel for device_id in devices}
                    for user_id, devices in remaining.items()
                },
            )
        return got_keys

    async def _receive_forwarded_room_key(self, evt: DecryptedOlmEvent) -> None:
        """
        Import a forwarded Megolm session and wake up the pending key request, if any.

        The event is ignored if the crypto store already has the session, or if the ID of
        the imported session doesn't match the session ID claimed in the event content.

        Args:
            evt: The decrypted Olm event whose content is the forwarded room key.
        """
        key: ForwardedRoomKeyEventContent = evt.content
        if await self.crypto_store.has_group_session(key.room_id, key.session_id):
            self.log.debug(
                f"Ignoring received session {key.session_id} from {evt.sender}/"
                f"{evt.sender_device}, as crypto store says we have it already"
            )
            return
        if not key.beeper_max_messages or not key.beeper_max_age_ms:
            await self._fill_encryption_info(key)
        # Record the device that forwarded the key to us as the last link of the chain.
        key.forwarding_key_chain.append(evt.sender_key)
        sess = InboundGroupSession.import_session(
            key.session_key,
            key.signing_key,
            key.sender_key,
            key.room_id,
            key.forwarding_key_chain,
            max_age=key.beeper_max_age_ms,
            max_messages=key.beeper_max_messages,
            is_scheduled=key.beeper_is_scheduled,
        )
        if key.session_id != sess.id:
            self.log.warning(
                f"Mismatched session ID while importing forwarded key from "
                f"{evt.sender}/{evt.sender_device}: '{key.session_id}' != '{sess.id}'"
            )
            return
        await self.crypto_store.put_group_session(
            key.room_id, key.sender_key, key.session_id, sess
        )
        self._mark_session_received(key.session_id)
        self.log.debug(
            f"Imported {key.session_id} for {key.room_id} "
            f"from {evt.sender}/{evt.sender_device}"
        )

        waiter = self._key_request_waiters.get(key.session_id)
        # The waiter may already be done (e.g. cancelled when the requester timed out
        # but hasn't removed it yet); set_result() on a done future raises
        # InvalidStateError, so only resolve futures that are still pending.
        if waiter is not None and not waiter.done():
            waiter.set_result((evt.sender, evt.sender_device))