222 lines
7.8 KiB
Python
222 lines
7.8 KiB
Python
# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
|
from __future__ import annotations
|
|
|
|
from mautrix.client.state_store import SyncStore
|
|
from mautrix.types import (
|
|
CrossSigner,
|
|
CrossSigningUsage,
|
|
DeviceID,
|
|
DeviceIdentity,
|
|
EventID,
|
|
IdentityKey,
|
|
RoomID,
|
|
SessionID,
|
|
SigningKey,
|
|
SyncToken,
|
|
TOFUSigningKey,
|
|
UserID,
|
|
)
|
|
|
|
from ..account import OlmAccount
|
|
from ..sessions import InboundGroupSession, OutboundGroupSession, Session
|
|
from .abstract import CryptoStore
|
|
|
|
|
|
class MemoryCryptoStore(CryptoStore, SyncStore):
    """
    In-memory implementation of :class:`CryptoStore` and :class:`SyncStore`.

    All state is kept in plain attributes and dicts on the instance; no method defined in
    this class writes it anywhere else.
    """

    _device_id: DeviceID | None
    _sync_token: SyncToken | None
    _account: OlmAccount | None
    # (sender key, session ID, message index) -> (event ID, timestamp)
    _message_indices: dict[tuple[IdentityKey, SessionID, int], tuple[EventID, int]]
    # user ID -> device ID -> device identity
    _devices: dict[UserID, dict[DeviceID, DeviceIdentity]]
    # identity key -> sessions in the order they were added
    _olm_sessions: dict[IdentityKey, list[Session]]
    # (room ID, session ID) -> inbound group session
    _inbound_sessions: dict[tuple[RoomID, SessionID], InboundGroupSession]
    # room ID -> outbound group session (at most one per room)
    _outbound_sessions: dict[RoomID, OutboundGroupSession]
    # signer -> target -> signature
    _signatures: dict[CrossSigner, dict[CrossSigner, str]]
    # user ID -> key usage -> signing key (current key plus first seen key)
    _cross_signing_keys: dict[UserID, dict[CrossSigningUsage, TOFUSigningKey]]

    def __init__(self, account_id: str, pickle_key: str) -> None:
        """
        Create an empty store.

        Args:
            account_id: Stored as ``self.account_id``; not read by any method defined in
                this class.
            pickle_key: Stored as ``self.pickle_key``; not read by any method defined in
                this class.
        """
        self.account_id = account_id
        self.pickle_key = pickle_key

        self._sync_token = None
        self._device_id = None
        self._account = None
        self._message_indices = {}
        self._devices = {}
        self._olm_sessions = {}
        self._inbound_sessions = {}
        self._outbound_sessions = {}
        self._signatures = {}
        self._cross_signing_keys = {}

    async def get_device_id(self) -> DeviceID | None:
        """Return the stored device ID, or ``None`` if none has been stored."""
        return self._device_id

    async def put_device_id(self, device_id: DeviceID) -> None:
        """Store the given device ID, replacing any previous one."""
        self._device_id = device_id

    async def put_next_batch(self, next_batch: SyncToken) -> None:
        """Store the given sync token, replacing any previous one."""
        self._sync_token = next_batch

    async def get_next_batch(self) -> SyncToken | None:
        """Return the stored sync token, or ``None`` if none has been stored."""
        return self._sync_token

    async def delete(self) -> None:
        """
        Reset the account, device ID, Olm sessions and outbound group sessions.

        The sync token, message indices, devices, inbound group sessions, signatures and
        cross-signing keys are left untouched.
        """
        self._account = None
        self._device_id = None
        self._olm_sessions = {}
        self._outbound_sessions = {}

    async def put_account(self, account: OlmAccount) -> None:
        """Store the given account object, replacing any previous one."""
        self._account = account

    async def get_account(self) -> OlmAccount | None:
        """Return the stored account, or ``None`` if none has been stored."""
        return self._account

    async def has_session(self, key: IdentityKey) -> bool:
        """Return whether there is a session list entry for the given identity key."""
        return key in self._olm_sessions

    async def get_sessions(self, key: IdentityKey) -> list[Session]:
        """
        Return the sessions stored for the given identity key.

        When sessions exist, the returned list is the store's own list rather than a copy.
        When none exist, a new empty list is returned.
        """
        return self._olm_sessions.get(key, [])

    async def get_latest_session(self, key: IdentityKey) -> Session | None:
        """Return the most recently added session for the key, or ``None`` if there is none."""
        try:
            return self._olm_sessions[key][-1]
        except (KeyError, IndexError):
            # KeyError: no entry for this key. IndexError: the entry is an empty list.
            return None

    async def add_session(self, key: IdentityKey, session: Session) -> None:
        """Append the session to the list for the given identity key."""
        self._olm_sessions.setdefault(key, []).append(session)

    async def update_session(self, key: IdentityKey, session: Session) -> None:
        """Do nothing; see the comment in the body."""
        # This is a no-op as the session object is the same one previously added.
        pass

    async def put_group_session(
        self,
        room_id: RoomID,
        sender_key: IdentityKey,
        session_id: SessionID,
        session: InboundGroupSession,
    ) -> None:
        """
        Store an inbound group session under ``(room_id, session_id)``.

        ``sender_key`` is accepted but not used as part of the storage key.
        """
        self._inbound_sessions[(room_id, session_id)] = session

    async def get_group_session(
        self, room_id: RoomID, session_id: SessionID
    ) -> InboundGroupSession | None:
        """Return the inbound group session for the room and session ID, or ``None``."""
        return self._inbound_sessions.get((room_id, session_id))

    async def redact_group_session(
        self, room_id: RoomID, session_id: SessionID, reason: str
    ) -> None:
        """
        Remove the inbound group session for the room and session ID, if present.

        ``reason`` is accepted but not recorded by this store.
        """
        self._inbound_sessions.pop((room_id, session_id), None)

    async def redact_group_sessions(
        self, room_id: RoomID, sender_key: IdentityKey, reason: str
    ) -> list[SessionID]:
        """
        Remove all inbound group sessions matching the given room and/or sender key.

        A falsy ``room_id`` or ``sender_key`` disables that filter. Matching is done against
        the ``room_id`` and ``sender_key`` attributes of the stored session objects.
        ``reason`` is accepted but not recorded by this store.

        Returns:
            The IDs of the sessions that were removed.

        Raises:
            ValueError: If both ``room_id`` and ``sender_key`` are falsy.
        """
        if not room_id and not sender_key:
            raise ValueError("Either room_id or sender_key must be provided")
        deleted = []
        # Iterate over a snapshot, as entries are removed from the dict inside the loop.
        for key, item in list(self._inbound_sessions.items()):
            room_matches = not room_id or item.room_id == room_id
            if room_matches and (not sender_key or item.sender_key == sender_key):
                deleted.append(SessionID(item.id))
                del self._inbound_sessions[key]
        return deleted

    async def redact_expired_group_sessions(self) -> list[SessionID]:
        """Not supported by this store; always raises :class:`NotImplementedError`."""
        raise NotImplementedError()

    async def redact_outdated_group_sessions(self) -> list[SessionID]:
        """Not supported by this store; always raises :class:`NotImplementedError`."""
        raise NotImplementedError()

    async def has_group_session(self, room_id: RoomID, session_id: SessionID) -> bool:
        """Return whether an inbound group session is stored for the room and session ID."""
        return (room_id, session_id) in self._inbound_sessions

    async def add_outbound_group_session(self, session: OutboundGroupSession) -> None:
        """Store the session under its ``room_id``, replacing any previous one for the room."""
        self._outbound_sessions[session.room_id] = session

    async def update_outbound_group_session(self, session: OutboundGroupSession) -> None:
        """Do nothing; see the comment in the body."""
        # This is a no-op as the session object is the same one previously added.
        pass

    async def get_outbound_group_session(self, room_id: RoomID) -> OutboundGroupSession | None:
        """Return the outbound group session for the room, or ``None`` if there is none."""
        return self._outbound_sessions.get(room_id)

    async def remove_outbound_group_session(self, room_id: RoomID) -> None:
        """Remove the outbound group session for the room, if present."""
        self._outbound_sessions.pop(room_id, None)

    async def remove_outbound_group_sessions(self, rooms: list[RoomID]) -> None:
        """Remove the outbound group sessions for all of the given rooms, if present."""
        for room_id in rooms:
            self._outbound_sessions.pop(room_id, None)

    async def validate_message_index(
        self,
        sender_key: IdentityKey,
        session_id: SessionID,
        event_id: EventID,
        index: int,
        timestamp: int,
    ) -> bool:
        """
        Check a message index against what was previously recorded for it.

        The first time a ``(sender_key, session_id, index)`` combination is seen, the event
        ID and timestamp are recorded and the index is accepted.

        Returns:
            ``True`` if the combination is new, or if the recorded event ID and timestamp
            are equal to the given ones. ``False`` otherwise.
        """
        try:
            return self._message_indices[(sender_key, session_id, index)] == (event_id, timestamp)
        except KeyError:
            self._message_indices[(sender_key, session_id, index)] = (event_id, timestamp)
            return True

    async def get_devices(self, user_id: UserID) -> dict[DeviceID, DeviceIdentity] | None:
        """Return the stored device dict for the user, or ``None`` if none has been stored."""
        return self._devices.get(user_id)

    async def get_device(self, user_id: UserID, device_id: DeviceID) -> DeviceIdentity | None:
        """Return the stored identity of the user's device, or ``None`` if it isn't stored."""
        return self._devices.get(user_id, {}).get(device_id)

    async def find_device_by_key(
        self, user_id: UserID, identity_key: IdentityKey
    ) -> DeviceIdentity | None:
        """Return the user's first stored device with the given identity key, or ``None``."""
        for device in self._devices.get(user_id, {}).values():
            if device.identity_key == identity_key:
                return device
        return None

    async def put_devices(self, user_id: UserID, devices: dict[DeviceID, DeviceIdentity]) -> None:
        """Replace the stored device dict for the user with the given one."""
        self._devices[user_id] = devices

    async def filter_tracked_users(self, users: list[UserID]) -> list[UserID]:
        """Return the given users that have a stored device dict, in the original order."""
        return [user_id for user_id in users if user_id in self._devices]

    async def put_cross_signing_key(
        self, user_id: UserID, usage: CrossSigningUsage, key: SigningKey
    ) -> None:
        """
        Store a cross-signing key for the user and usage.

        The first key stored for a user/usage pair is recorded as both ``key`` and
        ``first``. Later calls only replace ``key``, leaving ``first`` as it was.
        """
        try:
            current = self._cross_signing_keys[user_id][usage]
        except KeyError:
            self._cross_signing_keys.setdefault(user_id, {})[usage] = TOFUSigningKey(
                key=key, first=key
            )
        else:
            current.key = key

    async def get_cross_signing_keys(
        self, user_id: UserID
    ) -> dict[CrossSigningUsage, TOFUSigningKey]:
        """
        Return the stored cross-signing keys of the user.

        When keys exist, the returned dict is the store's own dict rather than a copy.
        When none exist, a new empty dict is returned.
        """
        return self._cross_signing_keys.get(user_id, {})

    async def put_signature(
        self, target: CrossSigner, signer: CrossSigner, signature: str
    ) -> None:
        """Store the signature that ``signer`` made on ``target``, replacing any previous one."""
        self._signatures.setdefault(signer, {})[target] = signature

    async def is_key_signed_by(self, target: CrossSigner, signer: CrossSigner) -> bool:
        """Return whether a signature by ``signer`` on ``target`` is stored."""
        return target in self._signatures.get(signer, {})

    async def drop_signatures_by_key(self, signer: CrossSigner) -> int:
        """
        Remove all stored signatures made by the given signer.

        Returns:
            The number of signatures that were removed; zero if the signer had none.
        """
        deleted = self._signatures.pop(signer, None)
        # pop() yields None when the signer has no stored signatures. Calling len() on
        # that would raise TypeError, so report zero dropped signatures instead.
        return len(deleted) if deleted is not None else 0