# mautrix-python/mautrix/crypto/device_lists.py (Python, 347 lines, 15 KiB)

# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations
from mautrix.errors import DeviceValidationError
from mautrix.types import (
CrossSigner,
CrossSigningKeys,
CrossSigningUsage,
DeviceID,
DeviceIdentity,
DeviceKeys,
EncryptionKeyAlgorithm,
IdentityKey,
KeyID,
QueryKeysResponse,
SigningKey,
SyncToken,
TrustState,
UserID,
)
from .base import BaseOlmMachine, verify_signature_json
class DeviceListMachine(BaseOlmMachine):
    """Olm machine part that tracks device lists and cross-signing keys.

    It queries device keys from the homeserver, validates them, persists the
    devices and the signatures found on them in the crypto store, and resolves
    the trust state of a device from the stored cross-signing signatures.
    """

    async def _fetch_keys(
        self, users: list[UserID], since: SyncToken = "", include_untracked: bool = False
    ) -> dict[UserID, dict[DeviceID, DeviceIdentity]]:
        """Query device keys for the given users and store the validated results.

        Args:
            users: The users whose keys should be queried.
            since: Sync token passed to the key query as ``token``.
            include_untracked: If false, the user list is first narrowed down with
                ``crypto_store.filter_tracked_users``.

        Returns:
            The validated devices of every requested user that the response
            contained devices for, keyed by user ID and then device ID. Empty if
            there was nobody to query.
        """
        if not include_untracked:
            users = await self.crypto_store.filter_tracked_users(users)
        if not users:
            return {}
        requested = set(users)

        self.log.trace(f"Querying keys for {requested}")
        resp = await self.client.query_keys(requested, token=since)

        for server, err in resp.failures.items():
            self.log.warning(f"Query keys failure for {server}: {err}")

        missing_users = requested.copy()
        data = {}
        for user_id, devices in resp.device_keys.items():
            if user_id not in requested:
                # The response is untrusted input. An unexpected user must not abort
                # the update for everyone else (set.remove() would raise KeyError),
                # and we have no reason to store devices we never asked for.
                self.log.warning(f"Ignoring devices of {user_id} as they weren't requested")
                continue
            missing_users.discard(user_id)

            new_devices = {}
            existing_devices = (await self.crypto_store.get_devices(user_id)) or {}

            self.log.trace(
                f"Updating devices for {user_id}, got {len(devices)}, "
                f"have {len(existing_devices)} in store"
            )
            changed = False
            ssks = resp.self_signing_keys.get(user_id)
            ssk = ssks.first_ed25519_key if ssks else None
            for device_id, device_keys in devices.items():
                try:
                    existing = existing_devices[device_id]
                except KeyError:
                    # A device we haven't seen before counts as a device list change.
                    existing = None
                    changed = True
                self.log.trace(f"Validating device {device_keys} of {user_id}")
                try:
                    new_device = await self._validate_device(
                        user_id, device_id, device_keys, existing
                    )
                except DeviceValidationError as e:
                    # Invalid devices are left out of the stored list instead of
                    # failing the whole query.
                    self.log.warning(f"Failed to validate device {device_id} of {user_id}: {e}")
                else:
                    if new_device:
                        new_devices[device_id] = new_device
                        await self._store_device_self_signatures(device_keys, ssk)
            self.log.debug(
                f"Storing new device list for {user_id} containing {len(new_devices)} devices"
            )
            await self.crypto_store.put_devices(user_id, new_devices)
            data[user_id] = new_devices

            # Either a new device appeared or the number of valid devices changed.
            if changed or len(new_devices) != len(existing_devices):
                if self.delete_keys_on_device_delete:
                    for device_id in existing_devices.keys() - new_devices.keys():
                        device = existing_devices[device_id]
                        removed_ids = await self.crypto_store.redact_group_sessions(
                            room_id=None, sender_key=device.identity_key, reason="device removed"
                        )
                        self.log.info(
                            "Redacted megolm sessions sent by removed device "
                            f"{device.user_id}/{device.device_id}: {removed_ids}"
                        )
                await self.on_devices_changed(user_id)

        for user_id in missing_users:
            self.log.warning(f"Didn't get any devices for user {user_id}")

        for user_id in requested:
            await self._store_cross_signing_keys(resp, user_id)

        return data

    async def _store_device_self_signatures(
        self, device_keys: DeviceKeys, self_signing_key: SigningKey | None
    ) -> None:
        """Store the signatures that a user has made on one of their own devices.

        The device's signature on itself is stored as-is. The signature made with
        the user's self-signing key (if any) is only stored if it verifies.

        Args:
            device_keys: The device keys as returned by the key query.
            self_signing_key: The user's self-signing key, or ``None`` if the
                response didn't contain one.
        """
        device_desc = f"Device {device_keys.user_id}/{device_keys.device_id}"
        try:
            self_signatures = device_keys.signatures[device_keys.user_id].copy()
        except KeyError:
            self.log.warning(f"{device_desc} doesn't have any signatures from the user")
            return
        if len(device_keys.signatures) > 1:
            # Formatted with an f-string only: mixing an f-string with %-style
            # arguments breaks as soon as the interpolated IDs contain a "%".
            other_signers = set(device_keys.signatures.keys()) - {device_keys.user_id}
            self.log.debug(f"{device_desc} has signatures from other users ({other_signers})")

        device_self_sig = self_signatures.pop(
            KeyID(EncryptionKeyAlgorithm.ED25519, device_keys.device_id)
        )
        target = CrossSigner(device_keys.user_id, device_keys.ed25519)
        # This one is already validated by _validate_device
        await self.crypto_store.put_signature(target, target, device_self_sig)

        try:
            cs_self_sig = self_signatures.pop(
                KeyID(EncryptionKeyAlgorithm.ED25519, self_signing_key)
            )
        except KeyError:
            self.log.warning(f"{device_desc} isn't cross-signed")
        else:
            is_valid_self_sig = verify_signature_json(
                device_keys.serialize(), device_keys.user_id, self_signing_key, self_signing_key
            )
            if is_valid_self_sig:
                signer = CrossSigner(device_keys.user_id, self_signing_key)
                await self.crypto_store.put_signature(target, signer, cs_self_sig)
            else:
                self.log.warning(f"{device_desc} doesn't have a valid cross-signing signature")

        # Whatever is left was made by neither the device nor the self-signing key.
        if len(self_signatures) > 0:
            unexpected_keys = set(self_signatures.keys())
            self.log.debug(f"{device_desc} has signatures from unexpected keys ({unexpected_keys})")

    async def _store_cross_signing_keys(self, resp: QueryKeysResponse, user_id: UserID) -> None:
        """Store a user's cross-signing keys and the valid signatures on them.

        Nothing is stored unless the response contains both a master key and a
        self-signing key for the user; the user-signing key is optional.
        Signatures made by a previously stored key are dropped when that key has
        been replaced by a different one.

        Args:
            resp: The key query response to read the keys from.
            user_id: The user whose cross-signing keys should be stored.
        """
        new_keys: dict[CrossSigningUsage, CrossSigningKeys] = {}
        try:
            master = new_keys[CrossSigningUsage.MASTER] = resp.master_keys[user_id]
        except KeyError:
            self.log.debug(f"Didn't get a cross-signing master key for {user_id}")
            return
        try:
            new_keys[CrossSigningUsage.SELF] = resp.self_signing_keys[user_id]
        except KeyError:
            self.log.debug(f"Didn't get a cross-signing self-signing key for {user_id}")
            return
        try:
            new_keys[CrossSigningUsage.USER] = resp.user_signing_keys[user_id]
        except KeyError:
            pass

        current_keys = await self.crypto_store.get_cross_signing_keys(user_id)
        for usage, key in current_keys.items():
            if usage in new_keys and key.key != new_keys[usage].first_ed25519_key:
                num = await self.crypto_store.drop_signatures_by_key(CrossSigner(user_id, key.key))
                if num >= 0:
                    self.log.debug(
                        f"Dropped {num} signatures made by key {user_id}/{key.key} ({usage})"
                        " as it has been replaced"
                    )

        for usage, key in new_keys.items():
            actual_key = key.first_ed25519_key
            self.log.debug(f"Storing cross-signing key for {user_id}: {actual_key} (type {usage})")
            await self.crypto_store.put_cross_signing_key(user_id, usage, actual_key)

            # A key without any signatures from its owner must end up in the warning
            # below rather than raising KeyError and aborting the whole key update.
            if usage != CrossSigningUsage.MASTER and (
                KeyID(EncryptionKeyAlgorithm.ED25519, master.first_ed25519_key)
                not in key.signatures.get(user_id, {})
            ):
                self.log.warning(
                    f"Cross-signing key {user_id}/{actual_key}/{usage}"
                    " doesn't seem to have a signature from the master key"
                )

            for signer_user_id, signatures in key.signatures.items():
                for key_id, signature in signatures.items():
                    # By default the key ID itself is treated as the public key.
                    signing_key = SigningKey(key_id.key_id)
                    if signer_user_id == user_id:
                        # If the key ID names one of the user's own devices in this
                        # response, the signature was made with that device's key.
                        try:
                            device = resp.device_keys[signer_user_id][DeviceID(key_id.key_id)]
                            signing_key = device.ed25519
                        except KeyError:
                            pass
                    # 43 is the length of an unpadded base64-encoded 32-byte key.
                    # A device without an ed25519 key yields a falsy value here, which
                    # is treated as an unknown key instead of crashing in len().
                    if not signing_key or len(signing_key) != 43:
                        self.log.debug(
                            f"Cross-signing key {user_id}/{actual_key} has a signature from "
                            f"an unknown key {key_id}"
                        )
                        continue
                    signing_key_log = signing_key
                    if signing_key != key_id.key_id:
                        signing_key_log = f"{signing_key} ({key_id})"
                    self.log.debug(
                        f"Verifying cross-signing key {user_id}/{actual_key} "
                        f"with key {signer_user_id}/{signing_key_log}"
                    )
                    is_valid_sig = verify_signature_json(
                        key.serialize(), signer_user_id, key_id.key_id, signing_key
                    )
                    if is_valid_sig:
                        self.log.debug(f"Signature from {signing_key_log} for {key_id} verified")
                        await self.crypto_store.put_signature(
                            target=CrossSigner(user_id, actual_key),
                            signer=CrossSigner(signer_user_id, signing_key),
                            signature=signature,
                        )
                    else:
                        self.log.warning(f"Invalid signature from {signing_key_log} for {key_id}")

    async def get_or_fetch_device(
        self, user_id: UserID, device_id: DeviceID
    ) -> DeviceIdentity | None:
        """Get a device from the crypto store, querying the server if it isn't stored.

        Args:
            user_id: The owner of the device.
            device_id: The ID of the device.

        Returns:
            The device identity, or ``None`` if it wasn't found even after fetching.
        """
        device = await self.crypto_store.get_device(user_id, device_id)
        if device is not None:
            return device
        devices = await self._fetch_keys([user_id], include_untracked=True)
        try:
            return devices[user_id][device_id]
        except KeyError:
            return None

    async def get_or_fetch_device_by_key(
        self, user_id: UserID, identity_key: IdentityKey
    ) -> DeviceIdentity | None:
        """Find a device by identity key, querying the server if it isn't stored.

        Args:
            user_id: The owner of the device.
            identity_key: The identity key of the device.

        Returns:
            The device identity, or ``None`` if no device of the user has that key
            even after fetching.
        """
        device = await self.crypto_store.find_device_by_key(user_id, identity_key)
        if device is not None:
            return device
        devices = await self._fetch_keys([user_id], include_untracked=True)
        for device in devices.get(user_id, {}).values():
            if device.identity_key == identity_key:
                return device
        return None

    async def on_devices_changed(self, user_id: UserID) -> None:
        """Remove outbound group sessions of rooms shared with a user whose devices changed.

        Does nothing if ``disable_device_change_key_rotation`` is set.

        Args:
            user_id: The user whose device list changed.
        """
        if self.disable_device_change_key_rotation:
            return
        shared_rooms = await self.state_store.find_shared_rooms(user_id)
        self.log.debug(
            f"Devices of {user_id} changed, invalidating group session in {shared_rooms}"
        )
        await self.crypto_store.remove_outbound_group_sessions(shared_rooms)

    @staticmethod
    async def _validate_device(
        user_id: UserID,
        device_id: DeviceID,
        device_keys: DeviceKeys,
        existing: DeviceIdentity | None = None,
    ) -> DeviceIdentity:
        """Validate device keys received from the server.

        Args:
            user_id: The user ID the keys were listed under.
            device_id: The device ID the keys were listed under.
            device_keys: The device keys to validate.
            existing: The previously stored identity of the device, if any.

        Returns:
            A new, unverified device identity built from the keys.

        Raises:
            DeviceValidationError: If the user or device ID doesn't match, a key is
                missing, the signing key differs from the stored one, or the
                device's signature on its own keys is invalid.
        """
        if user_id != device_keys.user_id:
            raise DeviceValidationError(
                f"mismatching user ID (expected {user_id}, got {device_keys.user_id})"
            )
        elif device_id != device_keys.device_id:
            raise DeviceValidationError(
                f"mismatching device ID (expected {device_id}, got {device_keys.device_id})"
            )

        signing_key = device_keys.ed25519
        if not signing_key:
            raise DeviceValidationError("didn't find ed25519 signing key")
        identity_key = device_keys.curve25519
        if not identity_key:
            raise DeviceValidationError("didn't find curve25519 identity key")

        # The signing key of a known device is never allowed to change.
        if existing and existing.signing_key != signing_key:
            raise DeviceValidationError(
                f"received update for device with different signing key "
                f"(expected {existing.signing_key}, got {signing_key})"
            )

        if not verify_signature_json(device_keys.serialize(), user_id, device_id, signing_key):
            raise DeviceValidationError("invalid signature on device keys")

        name = device_keys.unsigned.device_display_name or device_id

        return DeviceIdentity(
            user_id=user_id,
            device_id=device_id,
            identity_key=identity_key,
            signing_key=signing_key,
            trust=TrustState.UNVERIFIED,
            name=name,
            deleted=False,
        )

    async def resolve_trust(self, device: DeviceIdentity) -> TrustState:
        """Resolve the trust state of a device without ever raising.

        Args:
            device: The device to resolve the trust state of.

        Returns:
            The resolved trust state, or ``TrustState.UNVERIFIED`` if resolving
            failed with an exception (which is logged).
        """
        try:
            return await self._try_resolve_trust(device)
        except Exception:
            self.log.exception(f"Failed to resolve trust of {device.user_id}/{device.device_id}")
            return TrustState.UNVERIFIED

    async def _try_resolve_trust(self, device: DeviceIdentity) -> TrustState:
        """Resolve the trust state of a device from stored cross-signing signatures.

        Devices that are already verified or blacklisted keep their state. Otherwise
        the device counts as cross-signed only if the owner's self-signing key is
        signed by their master key and the device is signed by that self-signing key.

        Args:
            device: The device to resolve the trust state of.

        Returns:
            The resolved trust state.
        """
        if device.trust in (TrustState.VERIFIED, TrustState.BLACKLISTED):
            return device.trust
        their_keys = await self.crypto_store.get_cross_signing_keys(device.user_id)
        if len(their_keys) == 0 and device.user_id not in self._cs_fetch_attempted:
            self.log.debug(f"Didn't find any cross-signing keys for {device.user_id}, fetching...")
            async with self._fetch_keys_lock:
                # Check again after acquiring the lock so that the keys of a user
                # are fetched at most once.
                if device.user_id not in self._cs_fetch_attempted:
                    self._cs_fetch_attempted.add(device.user_id)
                    await self._fetch_keys([device.user_id])
            their_keys = await self.crypto_store.get_cross_signing_keys(device.user_id)
        try:
            msk = their_keys[CrossSigningUsage.MASTER]
            ssk = their_keys[CrossSigningUsage.SELF]
        except KeyError as e:
            self.log.error(f"Didn't find cross-signing key {e.args[0]} of {device.user_id}")
            return TrustState.UNVERIFIED
        ssk_signed = await self.crypto_store.is_key_signed_by(
            target=CrossSigner(device.user_id, ssk.key),
            signer=CrossSigner(device.user_id, msk.key),
        )
        if not ssk_signed:
            self.log.warning(
                f"Self-signing key of {device.user_id} is not signed by their master key"
            )
            return TrustState.UNVERIFIED
        device_signed = await self.crypto_store.is_key_signed_by(
            target=CrossSigner(device.user_id, device.signing_key),
            signer=CrossSigner(device.user_id, ssk.key),
        )
        if device_signed:
            if await self.is_user_trusted(device.user_id):
                return TrustState.CROSS_SIGNED_TRUSTED
            elif msk.key == msk.first:
                # The master key is still the first one stored for this user.
                return TrustState.CROSS_SIGNED_TOFU
            return TrustState.CROSS_SIGNED_UNTRUSTED
        return TrustState.UNVERIFIED

    async def is_user_trusted(self, user_id: UserID) -> bool:
        """Check whether a user is trusted. Currently a stub that always returns ``False``.

        Args:
            user_id: The user to check.

        Returns:
            Always ``False`` for now.
        """
        # TODO implement once own cross-signing key stuff is ready
        return False