135 lines
5.5 KiB
Python
135 lines
5.5 KiB
Python
# Copyright (c) 2022 Tulir Asokan
|
|
#
|
|
# This Source Code Form is subject to the terms of the Mozilla Public
|
|
# License, v. 2.0. If a copy of the MPL was not distributed with this
|
|
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
|
from typing import Optional
|
|
import asyncio
|
|
|
|
import olm
|
|
|
|
from mautrix.errors import DecryptionError, MatchingSessionDecryptionError
|
|
from mautrix.types import (
|
|
DecryptedOlmEvent,
|
|
EncryptedOlmEventContent,
|
|
EncryptionAlgorithm,
|
|
IdentityKey,
|
|
OlmCiphertext,
|
|
OlmMsgType,
|
|
ToDeviceEvent,
|
|
UserID,
|
|
)
|
|
from mautrix.util import background_task
|
|
|
|
from .base import BaseOlmMachine
|
|
from .sessions import Session
|
|
|
|
|
|
class OlmDecryptionMachine(BaseOlmMachine):
    """Mixin that decrypts Olm-encrypted to-device events.

    Uses attributes provided elsewhere (presumably by ``BaseOlmMachine`` or the
    concrete machine class): ``self.account``, ``self.client``,
    ``self.crypto_store`` and ``self.log``.
    """

    async def _decrypt_olm_event(self, evt: ToDeviceEvent) -> DecryptedOlmEvent:
        """Decrypt an Olm-encrypted to-device event addressed to this device.

        Args:
            evt: The encrypted to-device event as received.

        Returns:
            The decrypted inner event, with ``sender_key`` set to the outer
            event's ``content.sender_key`` and ``source`` set to ``evt``.

        Raises:
            DecryptionError: If the content class or algorithm is unsupported,
                there is no ciphertext for this device's identity key,
                decryption fails, or the decrypted payload's sender, recipient
                or recipient ed25519 key do not match what is expected.
            Exception: Whatever ``DecryptedOlmEvent.parse_json`` raises when the
                plaintext cannot be parsed (re-raised after a trace log).
        """
        if not isinstance(evt.content, EncryptedOlmEventContent):
            raise DecryptionError("unsupported event content class")
        elif evt.content.algorithm != EncryptionAlgorithm.OLM_V1:
            raise DecryptionError("unsupported event encryption algorithm")
        # The ciphertext map is looked up by our own identity key; only that
        # entry is decryptable by this device.
        try:
            own_content = evt.content.ciphertext[self.account.identity_key]
        except KeyError:
            # A missing entry is an expected condition; the KeyError adds no
            # useful context to the traceback.
            raise DecryptionError(
                "olm event doesn't contain ciphertext for this device"
            ) from None

        self.log.debug(
            f"Decrypting to-device olm event from {evt.sender}/{evt.content.sender_key}"
        )
        plaintext = await self._decrypt_olm_ciphertext(
            evt.sender, evt.content.sender_key, own_content
        )

        try:
            decrypted_evt: DecryptedOlmEvent = DecryptedOlmEvent.parse_json(plaintext)
        except Exception:
            self.log.trace("Failed to parse olm event plaintext: %s", plaintext)
            raise
        # Bind the decrypted payload to this delivery: the inner sender must be
        # the outer sender, and it must be addressed to our user and our
        # signing key.
        if decrypted_evt.sender != evt.sender:
            raise DecryptionError("mismatched sender in olm payload")
        elif decrypted_evt.recipient != self.client.mxid:
            raise DecryptionError("mismatched recipient in olm payload")
        elif decrypted_evt.recipient_keys.ed25519 != self.account.signing_key:
            raise DecryptionError("mismatched recipient key in olm payload")
        decrypted_evt.sender_key = evt.content.sender_key
        decrypted_evt.source = evt
        self.log.debug(
            f"Successfully decrypted olm event from {evt.sender}/{decrypted_evt.sender_device} "
            f"(sender key: {decrypted_evt.sender_key}) into a {decrypted_evt.type}"
        )
        return decrypted_evt

    async def _decrypt_olm_ciphertext(
        self, sender: UserID, sender_key: IdentityKey, message: OlmCiphertext
    ) -> str:
        """Decrypt a single Olm ciphertext from ``sender``/``sender_key``.

        Stored sessions with ``sender_key`` are tried first. If none of them
        yields plaintext and the message is a pre-key message, a new inbound
        session is created from it and used to decrypt.

        Args:
            sender: Matrix user ID of the sender (used for logging and for
                scheduling ``_unwedge_session``).
            sender_key: The sender's identity key.
            message: The ciphertext entry addressed to this device.

        Returns:
            The decrypted plaintext.

        Raises:
            DecryptionError: If the message type is unsupported, a normal
                message cannot be decrypted by any stored session, or a new
                inbound session cannot be created from (or cannot decrypt) a
                pre-key message.
            MatchingSessionDecryptionError: If a stored session matching a
                pre-key message fails to decrypt it (re-raised after an unwedge
                is scheduled).
        """
        if message.type not in (OlmMsgType.PREKEY, OlmMsgType.MESSAGE):
            raise DecryptionError("unsupported olm message type")

        try:
            plaintext = await self._try_decrypt_olm_ciphertext(sender_key, message)
        except MatchingSessionDecryptionError:
            self.log.warning(
                f"Found matching session yet decryption failed for sender {sender}"
                f" with key {sender_key}"
            )
            # Scheduled via background_task.create rather than awaited here.
            background_task.create(self._unwedge_session(sender, sender_key))
            raise

        # A falsy result means no stored session produced plaintext.
        if not plaintext:
            # Only pre-key messages can establish a new session; a normal message
            # that no stored session can decrypt is a hard failure.
            if message.type != OlmMsgType.PREKEY:
                background_task.create(self._unwedge_session(sender, sender_key))
                raise DecryptionError("Decryption failed for normal message")

            self.log.trace(f"Trying to create inbound session for {sender}/{sender_key}")
            try:
                session = await self._create_inbound_session(sender_key, message.body)
            except olm.OlmSessionError as e:
                background_task.create(self._unwedge_session(sender, sender_key))
                raise DecryptionError("Failed to create new session from prekey message") from e
            self.log.debug(
                f"Created inbound session {session.id} for {sender} (sender key: {sender_key})"
            )

            try:
                plaintext = session.decrypt(message)
            except olm.OlmSessionError as e:
                raise DecryptionError(
                    "Failed to decrypt olm event with session created from prekey message"
                ) from e

            # Persist the session again after decrypting with it.
            await self.crypto_store.update_session(sender_key, session)

        return plaintext

    async def _try_decrypt_olm_ciphertext(
        self, sender_key: IdentityKey, message: OlmCiphertext
    ) -> Optional[str]:
        """Try to decrypt ``message`` with the stored sessions for ``sender_key``.

        Args:
            sender_key: The sender's identity key; selects which stored sessions
                are tried.
            message: The ciphertext entry addressed to this device.

        Returns:
            The plaintext from the first session that decrypts the message (that
            session is then saved with ``crypto_store.update_session``), or
            ``None`` if no stored session succeeded.

        Raises:
            MatchingSessionDecryptionError: If the message is a pre-key message
                and a session whose ``matches`` check passes fails to decrypt it.
        """
        sessions = await self.crypto_store.get_sessions(sender_key)
        for session in sessions:
            # For pre-key messages, only sessions that report a match are tried.
            if message.type == OlmMsgType.PREKEY and not session.matches(message.body):
                continue

            try:
                plaintext = session.decrypt(message)
            except olm.OlmSessionError as e:
                # A matching session that cannot decrypt a pre-key message is an
                # error; for normal messages, fall through to the next session.
                if message.type == OlmMsgType.PREKEY:
                    raise MatchingSessionDecryptionError(
                        "decryption failed with matching session"
                    ) from e
            else:
                # Save the session after a successful decrypt (presumably its
                # ratchet state changed).
                await self.crypto_store.update_session(sender_key, session)
                return plaintext
        return None

    async def _create_inbound_session(self, sender_key: IdentityKey, ciphertext: str) -> Session:
        """Create and persist a new inbound session from a pre-key message body.

        Both the account and the new session are written to the crypto store.
        The account is saved presumably because creating the session mutates it
        (e.g. consuming a one-time key) -- confirm against
        ``account.new_inbound_session``.

        Args:
            sender_key: The sender's identity key.
            ciphertext: The body of the pre-key message.

        Returns:
            The newly created session.

        Errors from ``account.new_inbound_session`` propagate unchanged; the
        caller in this class handles ``olm.OlmSessionError``.
        """
        session = self.account.new_inbound_session(sender_key, ciphertext)
        await self.crypto_store.put_account(self.account)
        await self.crypto_store.add_session(sender_key, session)
        return session

    async def _unwedge_session(self, sender: UserID, sender_key: IdentityKey) -> None:
        """Hook scheduled when decryption with ``sender``/``sender_key`` fails.

        Not implemented in this class; presumably overridden by a subclass.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError()
|