159 lines
5.8 KiB
Python
159 lines
5.8 KiB
Python
# Copyright (c) 2023 Tulir Asokan
|
|
#
|
|
# This Source Code Form is subject to the terms of the Mozilla Public
|
|
# License, v. 2.0. If a copy of the MPL was not distributed with this
|
|
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
|
from typing import Dict, List, Optional, Union
|
|
import asyncio
|
|
import uuid
|
|
|
|
from mautrix.types import (
|
|
DecryptedOlmEvent,
|
|
DeviceID,
|
|
EncryptionAlgorithm,
|
|
EventType,
|
|
ForwardedRoomKeyEventContent,
|
|
IdentityKey,
|
|
KeyRequestAction,
|
|
RequestedKeyInfo,
|
|
RoomID,
|
|
RoomKeyRequestEventContent,
|
|
SessionID,
|
|
UserID,
|
|
)
|
|
|
|
from .base import BaseOlmMachine
|
|
from .sessions import InboundGroupSession
|
|
|
|
|
|
class KeyRequestingMachine(BaseOlmMachine):
|
|
async def request_room_key(
|
|
self,
|
|
room_id: RoomID,
|
|
sender_key: IdentityKey,
|
|
session_id: SessionID,
|
|
from_devices: Dict[UserID, List[DeviceID]],
|
|
timeout: Optional[Union[int, float]] = None,
|
|
) -> bool:
|
|
"""
|
|
Request keys for a Megolm group session from other devices.
|
|
|
|
Once the keys are received, or if this task is cancelled (via the ``timeout`` parameter),
|
|
a cancel request event is sent to the remaining devices. If the ``timeout`` is set to zero
|
|
or less, this will return immediately, and the extra key requests will not be cancelled.
|
|
|
|
Args:
|
|
room_id: The room where the session is used.
|
|
sender_key: The key of the user who created the session.
|
|
session_id: The ID of the session.
|
|
from_devices: A dict from user ID to list of device IDs whom to ask for the keys.
|
|
timeout: The maximum number of seconds to wait for the keys. If the timeout is
|
|
``None``, the wait time is not limited, but the task can still be cancelled.
|
|
If it's zero or less, this returns immediately and will never cancel requests.
|
|
|
|
Returns:
|
|
``True`` if the keys were received and are now in the crypto store,
|
|
``False`` otherwise (including if the method didn't wait at all).
|
|
"""
|
|
request_id = str(uuid.uuid1())
|
|
request = RoomKeyRequestEventContent(
|
|
action=KeyRequestAction.REQUEST,
|
|
body=RequestedKeyInfo(
|
|
algorithm=EncryptionAlgorithm.MEGOLM_V1,
|
|
room_id=room_id,
|
|
sender_key=sender_key,
|
|
session_id=session_id,
|
|
),
|
|
request_id=request_id,
|
|
requesting_device_id=self.client.device_id,
|
|
)
|
|
|
|
wait = timeout is None or timeout > 0
|
|
fut: Optional[asyncio.Future] = None
|
|
if wait:
|
|
fut = asyncio.get_running_loop().create_future()
|
|
self._key_request_waiters[session_id] = fut
|
|
await self.client.send_to_device(
|
|
EventType.ROOM_KEY_REQUEST,
|
|
{
|
|
user_id: {device_id: request for device_id in devices}
|
|
for user_id, devices in from_devices.items()
|
|
},
|
|
)
|
|
if not wait:
|
|
# Timeout is set and <=0, don't wait for keys
|
|
return False
|
|
assert fut is not None
|
|
got_keys = False
|
|
try:
|
|
user_id, device_id = await asyncio.wait_for(fut, timeout=timeout)
|
|
got_keys = True
|
|
try:
|
|
del from_devices[user_id][device_id]
|
|
if len(from_devices[user_id]) == 0:
|
|
del from_devices[user_id]
|
|
except KeyError:
|
|
pass
|
|
except (asyncio.CancelledError, asyncio.TimeoutError):
|
|
pass
|
|
del self._key_request_waiters[session_id]
|
|
if len(from_devices) > 0:
|
|
cancel = RoomKeyRequestEventContent(
|
|
action=KeyRequestAction.CANCEL,
|
|
request_id=str(request_id),
|
|
requesting_device_id=self.client.device_id,
|
|
)
|
|
await self.client.send_to_device(
|
|
EventType.ROOM_KEY_REQUEST,
|
|
{
|
|
user_id: {device_id: cancel for device_id in devices}
|
|
for user_id, devices in from_devices.items()
|
|
},
|
|
)
|
|
return got_keys
|
|
|
|
async def _receive_forwarded_room_key(self, evt: DecryptedOlmEvent) -> None:
|
|
key: ForwardedRoomKeyEventContent = evt.content
|
|
if await self.crypto_store.has_group_session(key.room_id, key.session_id):
|
|
self.log.debug(
|
|
f"Ignoring received session {key.session_id} from {evt.sender}/"
|
|
f"{evt.sender_device}, as crypto store says we have it already"
|
|
)
|
|
return
|
|
if not key.beeper_max_messages or not key.beeper_max_age_ms:
|
|
await self._fill_encryption_info(key)
|
|
key.forwarding_key_chain.append(evt.sender_key)
|
|
sess = InboundGroupSession.import_session(
|
|
key.session_key,
|
|
key.signing_key,
|
|
key.sender_key,
|
|
key.room_id,
|
|
key.forwarding_key_chain,
|
|
max_age=key.beeper_max_age_ms,
|
|
max_messages=key.beeper_max_messages,
|
|
is_scheduled=key.beeper_is_scheduled,
|
|
)
|
|
if key.session_id != sess.id:
|
|
self.log.warning(
|
|
f"Mismatched session ID while importing forwarded key from "
|
|
f"{evt.sender}/{evt.sender_device}: '{key.session_id}' != '{sess.id}'"
|
|
)
|
|
return
|
|
|
|
await self.crypto_store.put_group_session(
|
|
key.room_id, key.sender_key, key.session_id, sess
|
|
)
|
|
self._mark_session_received(key.session_id)
|
|
|
|
self.log.debug(
|
|
f"Imported {key.session_id} for {key.room_id} "
|
|
f"from {evt.sender}/{evt.sender_device}"
|
|
)
|
|
|
|
try:
|
|
task = self._key_request_waiters[key.session_id]
|
|
except KeyError:
|
|
pass
|
|
else:
|
|
task.set_result((evt.sender, evt.sender_device))
|