mautrix-python/mautrix/crypto/decrypt_olm.py

135 lines
5.5 KiB
Python

# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from typing import Optional
import asyncio
import olm
from mautrix.errors import DecryptionError, MatchingSessionDecryptionError
from mautrix.types import (
DecryptedOlmEvent,
EncryptedOlmEventContent,
EncryptionAlgorithm,
IdentityKey,
OlmCiphertext,
OlmMsgType,
ToDeviceEvent,
UserID,
)
from mautrix.util import background_task
from .base import BaseOlmMachine
from .sessions import Session
class OlmDecryptionMachine(BaseOlmMachine):
    """Decrypts Olm-encrypted to-device events.

    Olm ciphertexts are decrypted with a session already stored for the sender's identity
    key when one works; otherwise a pre-key message is used to create a new inbound session.
    Several failure paths schedule :meth:`_unwedge_session` in the background, which this
    class leaves unimplemented (presumably provided by a subclass).
    """

    async def _decrypt_olm_event(self, evt: ToDeviceEvent) -> DecryptedOlmEvent:
        """Decrypt and validate an Olm-encrypted to-device event.

        Args:
            evt: The to-device event whose content should be ``EncryptedOlmEventContent``.

        Returns:
            The parsed inner event, with ``sender_key`` and ``source`` filled in from ``evt``.

        Raises:
            DecryptionError: If the content class or algorithm is unsupported, there is no
                ciphertext for this device's identity key, decryption fails, or the
                sender / recipient / recipient ed25519 key in the payload does not match.
            Exception: Whatever ``DecryptedOlmEvent.parse_json`` raises for a plaintext it
                cannot parse is re-raised unchanged.
        """
        if not isinstance(evt.content, EncryptedOlmEventContent):
            raise DecryptionError("unsupported event content class")
        elif evt.content.algorithm != EncryptionAlgorithm.OLM_V1:
            raise DecryptionError("unsupported event encryption algorithm")
        try:
            own_content = evt.content.ciphertext[self.account.identity_key]
        except KeyError:
            # The KeyError adds nothing beyond this message; keep it out of the traceback.
            raise DecryptionError(
                "olm event doesn't contain ciphertext for this device"
            ) from None
        self.log.debug(
            f"Decrypting to-device olm event from {evt.sender}/{evt.content.sender_key}"
        )
        plaintext = await self._decrypt_olm_ciphertext(
            evt.sender, evt.content.sender_key, own_content
        )
        try:
            decrypted_evt: DecryptedOlmEvent = DecryptedOlmEvent.parse_json(plaintext)
        except Exception:
            # The decrypted plaintext is only dumped at trace level; the error propagates.
            self.log.trace("Failed to parse olm event plaintext: %s", plaintext)
            raise
        # Cross-check the encrypted payload against the outer event and this account
        # before accepting it.
        if decrypted_evt.sender != evt.sender:
            raise DecryptionError("mismatched sender in olm payload")
        elif decrypted_evt.recipient != self.client.mxid:
            raise DecryptionError("mismatched recipient in olm payload")
        elif decrypted_evt.recipient_keys.ed25519 != self.account.signing_key:
            raise DecryptionError("mismatched recipient key in olm payload")
        decrypted_evt.sender_key = evt.content.sender_key
        decrypted_evt.source = evt
        self.log.debug(
            f"Successfully decrypted olm event from {evt.sender}/{decrypted_evt.sender_device} "
            f"(sender key: {decrypted_evt.sender_key}) into a {decrypted_evt.type}"
        )
        return decrypted_evt

    async def _decrypt_olm_ciphertext(
        self, sender: UserID, sender_key: IdentityKey, message: OlmCiphertext
    ) -> str:
        """Decrypt a single Olm ciphertext from ``sender`` / ``sender_key``.

        Stored sessions for ``sender_key`` are tried first. If none of them decrypts the
        message and it is a pre-key message, a new inbound session is created from it, used
        to decrypt, and saved.

        Args:
            sender: The user the ciphertext came from (used for logging and unwedging).
            sender_key: The sender's identity key; stored sessions are looked up by it.
            message: The ciphertext addressed to this device.

        Returns:
            The decrypted plaintext.

        Raises:
            DecryptionError: If the message type is unsupported, a normal message cannot be
                decrypted by any stored session, or creating/using a new inbound session
                fails.
            MatchingSessionDecryptionError: If a stored session matches a pre-key message but
                fails to decrypt it.
        """
        if message.type not in (OlmMsgType.PREKEY, OlmMsgType.MESSAGE):
            raise DecryptionError("unsupported olm message type")

        try:
            plaintext = await self._try_decrypt_olm_ciphertext(sender_key, message)
        except MatchingSessionDecryptionError:
            self.log.warning(
                f"Found matching session yet decryption failed for sender {sender}"
                f" with key {sender_key}"
            )
            background_task.create(self._unwedge_session(sender, sender_key))
            raise

        # ``None`` is the "no stored session worked" signal. Compare explicitly so an empty
        # (but successfully decrypted) plaintext isn't treated as a failure and pushed
        # through a brand-new inbound session.
        if plaintext is None:
            if message.type != OlmMsgType.PREKEY:
                # Only pre-key messages can create a new session here, so a normal message
                # that no stored session handles can't be recovered; request an unwedge.
                background_task.create(self._unwedge_session(sender, sender_key))
                raise DecryptionError("Decryption failed for normal message")

            self.log.trace(f"Trying to create inbound session for {sender}/{sender_key}")
            try:
                session = await self._create_inbound_session(sender_key, message.body)
            except olm.OlmSessionError as e:
                background_task.create(self._unwedge_session(sender, sender_key))
                raise DecryptionError("Failed to create new session from prekey message") from e
            self.log.debug(
                f"Created inbound session {session.id} for {sender} (sender key: {sender_key})"
            )

            try:
                plaintext = session.decrypt(message)
            except olm.OlmSessionError as e:
                raise DecryptionError(
                    "Failed to decrypt olm event with session created from prekey message"
                ) from e

            # Persist the session after use, mirroring the success path of
            # _try_decrypt_olm_ciphertext.
            await self.crypto_store.update_session(sender_key, session)

        return plaintext

    async def _try_decrypt_olm_ciphertext(
        self, sender_key: IdentityKey, message: OlmCiphertext
    ) -> Optional[str]:
        """Try to decrypt ``message`` with the sessions already stored for ``sender_key``.

        Pre-key messages are only tried against sessions whose ``matches()`` accepts the
        message body. Normal messages are tried against every stored session until one
        succeeds. The session that succeeds is written back to the crypto store.

        Returns:
            The plaintext, or ``None`` if no stored session decrypted the message.

        Raises:
            MatchingSessionDecryptionError: If a session matching a pre-key message fails to
                decrypt it.
        """
        sessions = await self.crypto_store.get_sessions(sender_key)
        for session in sessions:
            if message.type == OlmMsgType.PREKEY and not session.matches(message.body):
                continue
            try:
                plaintext = session.decrypt(message)
            except olm.OlmSessionError as e:
                if message.type == OlmMsgType.PREKEY:
                    raise MatchingSessionDecryptionError(
                        "decryption failed with matching session"
                    ) from e
                # Normal message: this session didn't work, fall through to the next one.
            else:
                await self.crypto_store.update_session(sender_key, session)
                return plaintext
        return None

    async def _create_inbound_session(self, sender_key: IdentityKey, ciphertext: str) -> Session:
        """Create an inbound session for ``sender_key`` from a pre-key message body.

        The account is saved before the new session is added to the store. Presumably this
        is because creating the session changes account state (e.g. one-time key usage);
        confirm against ``account.new_inbound_session``.

        Raises:
            Whatever ``account.new_inbound_session`` raises; the caller expects
            ``olm.OlmSessionError`` on failure.
        """
        session = self.account.new_inbound_session(sender_key, ciphertext)
        await self.crypto_store.put_account(self.account)
        await self.crypto_store.add_session(sender_key, session)
        return session

    async def _unwedge_session(self, sender: UserID, sender_key: IdentityKey) -> None:
        """Attempt to recover a broken Olm session with ``sender`` / ``sender_key``.

        Not implemented here. It is scheduled with ``background_task.create`` from the
        decryption failure paths above, so a subclass is presumably expected to override it.
        """
        raise NotImplementedError()