347 lines
15 KiB
Python
347 lines
15 KiB
Python
# Copyright (c) 2022 Tulir Asokan
|
|
#
|
|
# This Source Code Form is subject to the terms of the Mozilla Public
|
|
# License, v. 2.0. If a copy of the MPL was not distributed with this
|
|
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
|
from __future__ import annotations
|
|
|
|
from mautrix.errors import DeviceValidationError
|
|
from mautrix.types import (
|
|
CrossSigner,
|
|
CrossSigningKeys,
|
|
CrossSigningUsage,
|
|
DeviceID,
|
|
DeviceIdentity,
|
|
DeviceKeys,
|
|
EncryptionKeyAlgorithm,
|
|
IdentityKey,
|
|
KeyID,
|
|
QueryKeysResponse,
|
|
SigningKey,
|
|
SyncToken,
|
|
TrustState,
|
|
UserID,
|
|
)
|
|
|
|
from .base import BaseOlmMachine, verify_signature_json
|
|
|
|
|
|
class DeviceListMachine(BaseOlmMachine):
    """Device-list handling for the Olm machine.

    Queries the homeserver for users' device keys and cross-signing keys,
    validates them, persists them via ``self.crypto_store`` and resolves the
    trust state of individual devices.
    """

    async def _fetch_keys(
        self, users: list[UserID], since: SyncToken = "", include_untracked: bool = False
    ) -> dict[UserID, dict[DeviceID, DeviceIdentity]]:
        """Query device and cross-signing keys for ``users`` and store the results.

        Args:
            users: The users whose keys should be queried.
            since: Sync token forwarded to ``client.query_keys`` as ``token``.
            include_untracked: If ``False`` (default), the list is first narrowed
                with ``crypto_store.filter_tracked_users``.

        Returns:
            A mapping from user ID to that user's successfully validated devices.
            Users the server returned no device list for are absent.
        """
        if not include_untracked:
            users = await self.crypto_store.filter_tracked_users(users)
        # Don't send an empty /keys/query, whether or not we filtered above.
        if not users:
            return {}
        users = set(users)

        self.log.trace(f"Querying keys for {users}")
        resp = await self.client.query_keys(users, token=since)
        missing_users = users.copy()

        for server, err in resp.failures.items():
            self.log.warning(f"Query keys failure for {server}: {err}")

        data = {}
        for user_id, devices in resp.device_keys.items():
            # discard() instead of remove(): a response entry for a user we didn't
            # ask about must not abort the whole update with a KeyError.
            missing_users.discard(user_id)

            new_devices = {}
            existing_devices = (await self.crypto_store.get_devices(user_id)) or {}

            self.log.trace(
                f"Updating devices for {user_id}, got {len(devices)}, "
                f"have {len(existing_devices)} in store"
            )
            changed = False
            ssks = resp.self_signing_keys.get(user_id)
            ssk = ssks.first_ed25519_key if ssks else None
            for device_id, device_keys in devices.items():
                try:
                    existing = existing_devices[device_id]
                except KeyError:
                    # A device we haven't seen before counts as a change.
                    existing = None
                    changed = True
                self.log.trace(f"Validating device {device_keys} of {user_id}")
                try:
                    new_device = await self._validate_device(
                        user_id, device_id, device_keys, existing
                    )
                except DeviceValidationError as e:
                    self.log.warning(f"Failed to validate device {device_id} of {user_id}: {e}")
                else:
                    if new_device:
                        new_devices[device_id] = new_device
                        await self._store_device_self_signatures(device_keys, ssk)
            self.log.debug(
                f"Storing new device list for {user_id} containing {len(new_devices)} devices"
            )
            await self.crypto_store.put_devices(user_id, new_devices)
            data[user_id] = new_devices

            # A size difference catches removed (or newly invalid) devices, which
            # the per-device loop above cannot flag by itself.
            if changed or len(new_devices) != len(existing_devices):
                if self.delete_keys_on_device_delete:
                    for device_id in existing_devices.keys() - new_devices.keys():
                        device = existing_devices[device_id]
                        removed_ids = await self.crypto_store.redact_group_sessions(
                            room_id=None, sender_key=device.identity_key, reason="device removed"
                        )
                        self.log.info(
                            "Redacted megolm sessions sent by removed device "
                            f"{device.user_id}/{device.device_id}: {removed_ids}"
                        )
                await self.on_devices_changed(user_id)

        for user_id in missing_users:
            self.log.warning(f"Didn't get any devices for user {user_id}")

        for user_id in users:
            await self._store_cross_signing_keys(resp, user_id)

        return data

    async def _store_device_self_signatures(
        self, device_keys: DeviceKeys, self_signing_key: SigningKey | None
    ) -> None:
        """Persist the signatures a device carries from its own user.

        Stores the device's self-signature (already checked by
        ``_validate_device``). If present and valid, it also stores the
        signature made by the user's self-signing key. Anything else is only
        logged.

        Args:
            device_keys: The device keys as returned by the key query.
            self_signing_key: The owner's self-signing public key, or ``None`` if
                the response contained none.
        """
        device_desc = f"Device {device_keys.user_id}/{device_keys.device_id}"
        try:
            self_signatures = device_keys.signatures[device_keys.user_id].copy()
        except KeyError:
            self.log.warning(f"{device_desc} doesn't have any signatures from the user")
            return
        if len(device_keys.signatures) > 1:
            self.log.debug(
                f"{device_desc} has signatures from other users (%s)",
                set(device_keys.signatures.keys()) - {device_keys.user_id},
            )

        device_self_sig = self_signatures.pop(
            KeyID(EncryptionKeyAlgorithm.ED25519, device_keys.device_id)
        )
        target = CrossSigner(device_keys.user_id, device_keys.ed25519)
        # This one is already validated by _validate_device
        await self.crypto_store.put_signature(target, target, device_self_sig)

        cs_self_sig = None
        if self_signing_key is not None:
            # Without a self-signing key there is nothing to look up; don't build
            # a bogus "ed25519:None" key ID.
            cs_self_sig = self_signatures.pop(
                KeyID(EncryptionKeyAlgorithm.ED25519, self_signing_key), None
            )
        if cs_self_sig is None:
            self.log.warning(f"{device_desc} isn't cross-signed")
        else:
            is_valid_self_sig = verify_signature_json(
                device_keys.serialize(), device_keys.user_id, self_signing_key, self_signing_key
            )
            if is_valid_self_sig:
                signer = CrossSigner(device_keys.user_id, self_signing_key)
                await self.crypto_store.put_signature(target, signer, cs_self_sig)
            else:
                self.log.warning(f"{device_desc} doesn't have a valid cross-signing signature")

        if len(self_signatures) > 0:
            self.log.debug(
                f"{device_desc} has signatures from unexpected keys (%s)",
                set(self_signatures.keys()),
            )

    async def _store_cross_signing_keys(self, resp: QueryKeysResponse, user_id: UserID) -> None:
        """Store ``user_id``'s cross-signing keys from ``resp`` and their signatures.

        Nothing is stored unless both a master and a self-signing key are
        present. A user-signing key is optional. When a stored key of some usage
        is replaced by a different one, signatures made by the old key are
        dropped. Each signature on the new keys is verified and stored if valid.

        Args:
            resp: The key query response.
            user_id: The user whose cross-signing keys to process.
        """
        new_keys: dict[CrossSigningUsage, CrossSigningKeys] = {}
        try:
            master = new_keys[CrossSigningUsage.MASTER] = resp.master_keys[user_id]
        except KeyError:
            self.log.debug(f"Didn't get a cross-signing master key for {user_id}")
            return
        try:
            new_keys[CrossSigningUsage.SELF] = resp.self_signing_keys[user_id]
        except KeyError:
            self.log.debug(f"Didn't get a cross-signing self-signing key for {user_id}")
            return
        try:
            new_keys[CrossSigningUsage.USER] = resp.user_signing_keys[user_id]
        except KeyError:
            pass

        current_keys = await self.crypto_store.get_cross_signing_keys(user_id)
        for usage, key in current_keys.items():
            if usage in new_keys and key.key != new_keys[usage].first_ed25519_key:
                # The key was rotated: signatures made by the old key are stale.
                num = await self.crypto_store.drop_signatures_by_key(CrossSigner(user_id, key.key))
                if num > 0:
                    self.log.debug(
                        f"Dropped {num} signatures made by key {user_id}/{key.key} ({usage})"
                        " as it has been replaced"
                    )

        master_key_id = KeyID(EncryptionKeyAlgorithm.ED25519, master.first_ed25519_key)
        for usage, key in new_keys.items():
            actual_key = key.first_ed25519_key
            self.log.debug(f"Storing cross-signing key for {user_id}: {actual_key} (type {usage})")
            await self.crypto_store.put_cross_signing_key(user_id, usage, actual_key)

            # .get(): a key without any signatures from its owner must produce the
            # warning below, not a KeyError that aborts the whole key fetch.
            if usage != CrossSigningUsage.MASTER and (
                master_key_id not in key.signatures.get(user_id, {})
            ):
                self.log.warning(
                    f"Cross-signing key {user_id}/{actual_key}/{usage}"
                    " doesn't seem to have a signature from the master key"
                )

            for signer_user_id, signatures in key.signatures.items():
                for key_id, signature in signatures.items():
                    signing_key = SigningKey(key_id.key_id)
                    if signer_user_id == user_id:
                        # Signatures from the owner may be made by one of their
                        # devices, in which case the key ID is a device ID.
                        try:
                            device = resp.device_keys[signer_user_id][DeviceID(key_id.key_id)]
                            signing_key = device.ed25519
                        except KeyError:
                            pass
                    # 43 chars = unpadded base64 of a 32-byte ed25519 public key;
                    # anything else couldn't be resolved to an actual key.
                    if len(signing_key) != 43:
                        self.log.debug(
                            f"Cross-signing key {user_id}/{actual_key} has a signature from "
                            f"an unknown key {key_id}"
                        )
                        continue
                    signing_key_log = signing_key
                    if signing_key != key_id.key_id:
                        signing_key_log = f"{signing_key} ({key_id})"
                    self.log.debug(
                        f"Verifying cross-signing key {user_id}/{actual_key} "
                        f"with key {signer_user_id}/{signing_key_log}"
                    )
                    is_valid_sig = verify_signature_json(
                        key.serialize(), signer_user_id, key_id.key_id, signing_key
                    )
                    if is_valid_sig:
                        self.log.debug(f"Signature from {signing_key_log} for {key_id} verified")
                        await self.crypto_store.put_signature(
                            target=CrossSigner(user_id, actual_key),
                            signer=CrossSigner(signer_user_id, signing_key),
                            signature=signature,
                        )
                    else:
                        self.log.warning(f"Invalid signature from {signing_key_log} for {key_id}")

    async def get_or_fetch_device(
        self, user_id: UserID, device_id: DeviceID
    ) -> DeviceIdentity | None:
        """Return a device from the store, querying the server if it isn't stored.

        The query includes untracked users. Returns ``None`` if the device is
        still unknown afterwards.
        """
        device = await self.crypto_store.get_device(user_id, device_id)
        if device is not None:
            return device
        devices = await self._fetch_keys([user_id], include_untracked=True)
        try:
            return devices[user_id][device_id]
        except KeyError:
            return None

    async def get_or_fetch_device_by_key(
        self, user_id: UserID, identity_key: IdentityKey
    ) -> DeviceIdentity | None:
        """Find ``user_id``'s device with the given curve25519 identity key.

        Looks in the store first. Otherwise it queries the server (including
        untracked users) and scans the returned devices. Returns ``None`` if no
        device matches.
        """
        device = await self.crypto_store.find_device_by_key(user_id, identity_key)
        if device is not None:
            return device
        devices = await self._fetch_keys([user_id], include_untracked=True)
        for device in devices.get(user_id, {}).values():
            if device.identity_key == identity_key:
                return device
        return None

    async def on_devices_changed(self, user_id: UserID) -> None:
        """React to a change in ``user_id``'s device list.

        Unless ``disable_device_change_key_rotation`` is set, this removes the
        outbound group sessions of every room shared with the user. Presumably
        this is so that a fresh session, shared with the new device set, is
        created on next send.
        """
        if self.disable_device_change_key_rotation:
            return
        shared_rooms = await self.state_store.find_shared_rooms(user_id)
        self.log.debug(
            f"Devices of {user_id} changed, invalidating group session in {shared_rooms}"
        )
        await self.crypto_store.remove_outbound_group_sessions(shared_rooms)

    @staticmethod
    async def _validate_device(
        user_id: UserID,
        device_id: DeviceID,
        device_keys: DeviceKeys,
        existing: DeviceIdentity | None = None,
    ) -> DeviceIdentity:
        """Validate ``device_keys`` and build an unverified ``DeviceIdentity``.

        Args:
            user_id: The user the keys were listed under in the query response.
            device_id: The device ID the keys were listed under.
            device_keys: The device keys to validate.
            existing: The previously stored identity of this device, if any.

        Returns:
            A ``DeviceIdentity`` with trust ``UNVERIFIED``. Its name is the
            device display name, or ``device_id`` if there is none.

        Raises:
            DeviceValidationError: If the user or device ID doesn't match, the
                ed25519 or curve25519 key is missing, the signing key differs
                from ``existing``, or the self-signature is invalid.
        """
        if user_id != device_keys.user_id:
            raise DeviceValidationError(
                f"mismatching user ID (expected {user_id}, got {device_keys.user_id})"
            )
        elif device_id != device_keys.device_id:
            raise DeviceValidationError(
                f"mismatching device ID (expected {device_id}, got {device_keys.device_id})"
            )

        signing_key = device_keys.ed25519
        if not signing_key:
            raise DeviceValidationError("didn't find ed25519 signing key")
        identity_key = device_keys.curve25519
        if not identity_key:
            raise DeviceValidationError("didn't find curve25519 identity key")

        # A known device must never change its signing key.
        if existing and existing.signing_key != signing_key:
            raise DeviceValidationError(
                f"received update for device with different signing key "
                f"(expected {existing.signing_key}, got {signing_key})"
            )

        if not verify_signature_json(device_keys.serialize(), user_id, device_id, signing_key):
            raise DeviceValidationError("invalid signature on device keys")

        name = device_keys.unsigned.device_display_name or device_id

        return DeviceIdentity(
            user_id=user_id,
            device_id=device_id,
            identity_key=identity_key,
            signing_key=signing_key,
            trust=TrustState.UNVERIFIED,
            name=name,
            deleted=False,
        )

    async def resolve_trust(self, device: DeviceIdentity) -> TrustState:
        """Return the effective trust state of ``device``.

        Any exception raised while resolving is logged, and ``UNVERIFIED`` is
        returned instead.
        """
        try:
            return await self._try_resolve_trust(device)
        except Exception:
            self.log.exception(f"Failed to resolve trust of {device.user_id}/{device.device_id}")
            return TrustState.UNVERIFIED

    async def _try_resolve_trust(self, device: DeviceIdentity) -> TrustState:
        """Resolve the trust state of ``device`` from stored cross-signing data.

        Explicit ``VERIFIED``/``BLACKLISTED`` states are returned as-is. If no
        cross-signing keys are stored for the user, a key query is attempted
        once per user. The device counts as cross-signed only if the
        self-signing key is signed by the master key and the device is signed
        by the self-signing key.
        """
        if device.trust in (TrustState.VERIFIED, TrustState.BLACKLISTED):
            return device.trust
        their_keys = await self.crypto_store.get_cross_signing_keys(device.user_id)
        if len(their_keys) == 0 and device.user_id not in self._cs_fetch_attempted:
            self.log.debug(f"Didn't find any cross-signing keys for {device.user_id}, fetching...")
            async with self._fetch_keys_lock:
                # Re-check under the lock so concurrent callers fetch only once.
                if device.user_id not in self._cs_fetch_attempted:
                    self._cs_fetch_attempted.add(device.user_id)
                    await self._fetch_keys([device.user_id])
            their_keys = await self.crypto_store.get_cross_signing_keys(device.user_id)
        try:
            msk = their_keys[CrossSigningUsage.MASTER]
            ssk = their_keys[CrossSigningUsage.SELF]
        except KeyError as e:
            self.log.error(f"Didn't find cross-signing key {e.args[0]} of {device.user_id}")
            return TrustState.UNVERIFIED
        ssk_signed = await self.crypto_store.is_key_signed_by(
            target=CrossSigner(device.user_id, ssk.key),
            signer=CrossSigner(device.user_id, msk.key),
        )
        if not ssk_signed:
            self.log.warning(
                f"Self-signing key of {device.user_id} is not signed by their master key"
            )
            return TrustState.UNVERIFIED
        device_signed = await self.crypto_store.is_key_signed_by(
            target=CrossSigner(device.user_id, device.signing_key),
            signer=CrossSigner(device.user_id, ssk.key),
        )
        if device_signed:
            if await self.is_user_trusted(device.user_id):
                return TrustState.CROSS_SIGNED_TRUSTED
            # NOTE(review): msk.first presumably holds the first master key seen
            # for this user (trust-on-first-use) — confirm against the store.
            elif msk.key == msk.first:
                return TrustState.CROSS_SIGNED_TOFU
            return TrustState.CROSS_SIGNED_UNTRUSTED
        return TrustState.UNVERIFIED

    async def is_user_trusted(self, user_id: UserID) -> bool:
        """Return whether ``user_id`` is trusted via our own cross-signing keys.

        Not implemented yet: this always returns ``False``.
        """
        # TODO implement once own cross-signing key stuff is ready
        return False