mautrix-python/mautrix/crypto/store/asyncpg/store.py

711 lines
27 KiB
Python

# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations
from collections import defaultdict
from datetime import timedelta
from asyncpg import UniqueViolationError
from mautrix.client.state_store import SyncStore
from mautrix.client.state_store.asyncpg import PgStateStore
from mautrix.errors import GroupSessionWithheldError
from mautrix.types import (
CrossSigner,
CrossSigningUsage,
DeviceID,
DeviceIdentity,
EventID,
IdentityKey,
RoomID,
RoomKeyWithheldCode,
SessionID,
SigningKey,
SyncToken,
TOFUSigningKey,
TrustState,
UserID,
)
from mautrix.util.async_db import Database, Scheme
from mautrix.util.logging import TraceLogger
from ... import InboundGroupSession, OlmAccount, OutboundGroupSession, RatchetSafety, Session
from ..abstract import CryptoStore, StateStore
from .upgrade import upgrade_table
# sqlite3/aiosqlite are optional. If either import fails, the three names below are replaced
# with placeholders so the rest of the module can still reference them.
try:
from sqlite3 import IntegrityError, sqlite_version_info as sqlite_version
from aiosqlite import Cursor
except ImportError:
# Users of Cursor check for None before passing it to isinstance().
Cursor = None
# Compares lower than any real version tuple, so version-gated SQLite paths are not taken.
sqlite_version = (0, 0, 0)
# Stand-in so that `except IntegrityError` clauses stay valid without sqlite3.
class IntegrityError(Exception):
pass
class PgCryptoStateStore(PgStateStore, StateStore):
    """
    Glue class that inherits from both the client module's ``PgStateStore`` and the crypto
    module's ``StateStore`` interface.

    It adds no members of its own; it exists so that ``PgStateStore`` is treated as an
    implementation of the ``StateStore`` methods that the crypto module needs.
    """
class PgCryptoStore(CryptoStore, SyncStore):
# SQL-backed implementation of CryptoStore and SyncStore; every query goes through self.db.
# Upgrade table imported from the sibling .upgrade module.
upgrade_table = upgrade_table
# Database wrapper used for all queries.
db: Database
# Value compared against the account_id column of the account-scoped tables.
account_id: str
# Passphrase handed to pickle()/from_pickle() for accounts and sessions.
pickle_key: str
# Logger; set from db.log in __init__.
log: TraceLogger
# In-memory copies of values that are also stored in the crypto_account row.
_sync_token: SyncToken | None
_device_id: DeviceID | None
_account: OlmAccount | None
# Unpickled Olm sessions, keyed by sender identity key and then by session ID.
_olm_cache: dict[IdentityKey, dict[SessionID, Session]]
def __init__(self, account_id: str, pickle_key: str, db: Database) -> None:
    """
    Create a store bound to one account on the given database.

    Args:
        account_id: Value used in ``account_id=...`` conditions of this store's queries.
        pickle_key: Passphrase used when pickling and unpickling the account and sessions.
        db: Database wrapper; its ``log`` attribute becomes this store's logger.
    """
    # Values supplied by the caller.
    self.account_id = account_id
    self.pickle_key = pickle_key
    self.db = db
    self.log = db.log
    # In-memory state, filled in as data is read from or written to the database.
    self._account = None
    self._sync_token = None
    self._device_id = DeviceID("")
    self._olm_cache = defaultdict(dict)
async def delete(self) -> None:
    """
    Delete this account's rows from the account, Olm session and outbound Megolm session
    tables. All three deletes run inside one transaction.
    """
    async with self.db.acquire() as conn:
        async with conn.transaction():
            for table_name in (
                "crypto_account",
                "crypto_olm_session",
                "crypto_megolm_outbound_session",
            ):
                query = f"DELETE FROM {table_name} WHERE account_id=$1"
                await conn.execute(query, self.account_id)
async def get_device_id(self) -> DeviceID | None:
    """
    Read the device ID from this account's ``crypto_account`` row.

    A truthy value from the database replaces the cached device ID. Otherwise the cached
    value is left as it was. The cached value is returned in both cases.
    """
    stored = await self.db.fetchval(
        "SELECT device_id FROM crypto_account WHERE account_id=$1", self.account_id
    )
    if stored:
        self._device_id = stored
    return self._device_id
async def put_device_id(self, device_id: DeviceID) -> None:
    """
    Write the device ID to this account's ``crypto_account`` row and cache it.

    Args:
        device_id: The device ID to store.
    """
    q = "UPDATE crypto_account SET device_id=$1 WHERE account_id=$2"
    # An UPDATE yields no value to fetch, so use execute() like the other write methods.
    await self.db.execute(q, device_id, self.account_id)
    self._device_id = device_id
async def put_next_batch(self, next_batch: SyncToken) -> None:
    """Cache the given sync token and write it to this account's ``crypto_account`` row."""
    self._sync_token = next_batch
    await self.db.execute(
        "UPDATE crypto_account SET sync_token=$1 WHERE account_id=$2",
        next_batch,
        self.account_id,
    )
async def get_next_batch(self) -> SyncToken:
    """
    Return the sync token for this account.

    The database is only queried while the cached token is ``None``; the fetched value
    (which may itself be ``None``) is then cached and returned.
    """
    if self._sync_token is not None:
        return self._sync_token
    token = await self.db.fetchval(
        "SELECT sync_token FROM crypto_account WHERE account_id=$1", self.account_id
    )
    self._sync_token = token
    return token
async def put_account(self, account: OlmAccount) -> None:
    """
    Cache the given Olm account and upsert its pickled form into ``crypto_account``.

    When a row for this account ID already exists, only ``shared``, ``sync_token`` and
    ``account`` are overwritten; the stored ``device_id`` is left untouched.
    """
    self._account = account
    pickled = account.pickle(self.pickle_key)
    upsert = """
INSERT INTO crypto_account (account_id, device_id, shared, sync_token, account)
VALUES ($1, $2, $3, $4, $5)
ON CONFLICT (account_id) DO UPDATE
SET shared=excluded.shared, sync_token=excluded.sync_token, account=excluded.account
"""
    # Unset device ID / sync token are written as empty strings.
    device_id = self._device_id or ""
    sync_token = self._sync_token or ""
    await self.db.execute(
        upsert, self.account_id, device_id, account.shared, sync_token, pickled
    )
async def get_account(self) -> OlmAccount:
    """
    Return the Olm account, unpickling it from the database if it is not cached yet.

    If nothing is cached and no ``crypto_account`` row exists for this account ID, the
    result is ``None``.
    """
    if self._account is not None:
        return self._account
    row = await self.db.fetchrow(
        "SELECT shared, account, device_id FROM crypto_account WHERE account_id=$1",
        self.account_id,
    )
    if row is None:
        # Nothing stored: the cache stays empty.
        return self._account
    self._account = OlmAccount.from_pickle(
        row["account"], passphrase=self.pickle_key, shared=row["shared"]
    )
    return self._account
async def has_session(self, key: IdentityKey) -> bool:
    """
    Check whether at least one Olm session exists for the given sender key.

    The in-memory cache is consulted first; the database is only queried when the cache has
    no sessions for the key.

    Args:
        key: The sender identity key to look up.

    Returns:
        ``True`` if a session is cached or a matching row exists, ``False`` otherwise.
    """
    # Use .get() rather than indexing: _olm_cache is a defaultdict, so indexing would insert
    # an empty dict for every key that is merely probed.
    if self._olm_cache.get(key):
        return True
    q = "SELECT session_id FROM crypto_olm_session WHERE sender_key=$1 AND account_id=$2"
    val = await self.db.fetchval(q, key, self.account_id)
    return val is not None
async def get_sessions(self, key: IdentityKey) -> list[Session]:
    """
    Return all Olm sessions stored for the given sender key, ordered by ``last_decrypted``
    descending.

    A session already in the cache is returned from there. Any other row is unpickled and
    added to the cache.
    """
    select = """
SELECT session_id, session, created_at, last_encrypted, last_decrypted
FROM crypto_olm_session WHERE sender_key=$1 AND account_id=$2
ORDER BY last_decrypted DESC
"""
    found = []
    for record in await self.db.fetch(select, key, self.account_id):
        try:
            found.append(self._olm_cache[key][record["session_id"]])
        except KeyError:
            unpickled = Session.from_pickle(
                record["session"],
                passphrase=self.pickle_key,
                creation_time=record["created_at"],
                last_encrypted=record["last_encrypted"],
                last_decrypted=record["last_decrypted"],
            )
            self._olm_cache[key][SessionID(unpickled.id)] = unpickled
            found.append(unpickled)
    return found
async def get_latest_session(self, key: IdentityKey) -> Session | None:
    """
    Return the Olm session with the highest ``last_decrypted`` value for the given sender
    key, or ``None`` if no session is stored for it.

    The cached instance is returned when there is one. Otherwise the row is unpickled and
    cached.
    """
    select = """
SELECT session_id, session, created_at, last_encrypted, last_decrypted
FROM crypto_olm_session WHERE sender_key=$1 AND account_id=$2
ORDER BY last_decrypted DESC LIMIT 1
"""
    record = await self.db.fetchrow(select, key, self.account_id)
    if record is None:
        return None
    try:
        latest = self._olm_cache[key][record["session_id"]]
    except KeyError:
        latest = Session.from_pickle(
            record["session"],
            passphrase=self.pickle_key,
            creation_time=record["created_at"],
            last_encrypted=record["last_encrypted"],
            last_decrypted=record["last_decrypted"],
        )
        self._olm_cache[key][SessionID(latest.id)] = latest
    return latest
async def add_session(self, key: IdentityKey, session: Session) -> None:
    """
    Cache a new Olm session for the given sender key and insert it into
    ``crypto_olm_session``.

    A warning is logged if the cache already holds a session with the same ID. The cache
    entry is overwritten and the insert is attempted regardless.
    """
    cached_for_key = self._olm_cache[key]
    if session.id in cached_for_key:
        self.log.warning(f"Cache already contains Olm session with ID {session.id}")
    cached_for_key[SessionID(session.id)] = session
    pickled = session.pickle(self.pickle_key)
    insert = """
INSERT INTO crypto_olm_session (
session_id, sender_key, session, created_at, last_encrypted, last_decrypted, account_id
) VALUES ($1, $2, $3, $4, $5, $6, $7)
"""
    values = (
        session.id,
        key,
        pickled,
        session.creation_time,
        session.last_encrypted,
        session.last_decrypted,
        self.account_id,
    )
    await self.db.execute(insert, *values)
async def update_session(self, key: IdentityKey, session: Session) -> None:
    """
    Persist the current state of an existing Olm session.

    Before writing, the session is compared with the cached one for the same key and ID.
    A missing or unequal cache entry only produces a warning; the row is updated either way.

    Args:
        key: The sender identity key the session belongs to.
        session: The session whose pickle and timestamps should be saved.
    """
    # Compare explicitly instead of using `assert`: assert statements are removed when
    # Python runs with -O, which would silently disable the mismatch warning.
    try:
        matches = self._olm_cache[key][SessionID(session.id)] == session
        detail = ""
    except KeyError as e:
        matches = False
        detail = str(e)
    if not matches:
        self.log.warning(
            f"Cached olm session with ID {session.id} "
            f"isn't equal to the one being saved to the database ({detail})"
        )
    pickle = session.pickle(self.pickle_key)
    q = """
UPDATE crypto_olm_session SET session=$1, last_encrypted=$2, last_decrypted=$3
WHERE session_id=$4 AND account_id=$5
"""
    await self.db.execute(
        q, pickle, session.last_encrypted, session.last_decrypted, session.id, self.account_id
    )
async def put_group_session(
    self,
    room_id: RoomID,
    sender_key: IdentityKey,
    session_id: SessionID,
    session: InboundGroupSession,
) -> None:
    """
    Upsert an inbound Megolm session into ``crypto_megolm_inbound_session``.

    If a row with the same ``(session_id, account_id)`` exists, its columns are overwritten
    and ``withheld_code`` / ``withheld_reason`` are reset to NULL. ``IntegrityError`` and
    ``UniqueViolationError`` from the database are logged and not re-raised.
    """
    pickled = session.pickle(self.pickle_key)
    # The forwarding chain is stored as one comma-separated string.
    chain_csv = ",".join(session.forwarding_chain)
    upsert = """
INSERT INTO crypto_megolm_inbound_session (
session_id, sender_key, signing_key, room_id, session, forwarding_chains,
ratchet_safety, received_at, max_age, max_messages, is_scheduled, account_id
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)
ON CONFLICT (session_id, account_id) DO UPDATE
SET withheld_code=NULL, withheld_reason=NULL, sender_key=excluded.sender_key,
signing_key=excluded.signing_key, room_id=excluded.room_id, session=excluded.session,
forwarding_chains=excluded.forwarding_chains, ratchet_safety=excluded.ratchet_safety,
received_at=excluded.received_at, max_age=excluded.max_age,
max_messages=excluded.max_messages, is_scheduled=excluded.is_scheduled
"""
    try:
        # max_age is stored in milliseconds; a falsy max_age is stored as NULL.
        if session.max_age:
            max_age_ms = int(session.max_age.total_seconds() * 1000)
        else:
            max_age_ms = None
        values = (
            session_id,
            sender_key,
            session.signing_key,
            room_id,
            pickled,
            chain_csv,
            session.ratchet_safety.json(),
            session.received_at,
            max_age_ms,
            session.max_messages,
            session.is_scheduled,
            self.account_id,
        )
        await self.db.execute(upsert, *values)
    except (IntegrityError, UniqueViolationError):
        self.log.exception(f"Failed to insert megolm session {session_id}")
async def get_group_session(
    self, room_id: RoomID, session_id: SessionID
) -> InboundGroupSession | None:
    """
    Load an inbound Megolm session by room and session ID.

    Returns:
        The unpickled session, or ``None`` if no matching row exists.

    Raises:
        GroupSessionWithheldError: If the row has a non-NULL ``withheld_code``.
    """
    select = """
SELECT
sender_key, signing_key, session, forwarding_chains, withheld_code,
ratchet_safety, received_at, max_age, max_messages, is_scheduled
FROM crypto_megolm_inbound_session
WHERE room_id=$1 AND session_id=$2 AND account_id=$3
"""
    record = await self.db.fetchrow(select, room_id, session_id, self.account_id)
    if record is None:
        return None
    withheld_code = record["withheld_code"]
    if withheld_code is not None:
        raise GroupSessionWithheldError(session_id, withheld_code)
    # forwarding_chains is a comma-separated string; empty/NULL means no chain.
    chain = []
    if record["forwarding_chains"]:
        chain = record["forwarding_chains"].split(",")
    # max_age is stored in milliseconds; a falsy value means no max age.
    max_age = None
    if record["max_age"]:
        max_age = timedelta(milliseconds=record["max_age"])
    return InboundGroupSession.from_pickle(
        record["session"],
        passphrase=self.pickle_key,
        signing_key=record["signing_key"],
        sender_key=record["sender_key"],
        room_id=room_id,
        forwarding_chain=chain,
        ratchet_safety=RatchetSafety.parse_json(record["ratchet_safety"] or "{}"),
        received_at=record["received_at"],
        max_age=max_age,
        max_messages=record["max_messages"],
        is_scheduled=record["is_scheduled"],
    )
async def redact_group_session(
    self, room_id: RoomID, session_id: SessionID, reason: str
) -> None:
    """
    Redact one inbound Megolm session: clear its pickle and forwarding chain and mark it as
    withheld with the ``BEEPER_REDACTED`` code.

    Only rows that still have a session pickle are touched. The row is matched on session ID
    and account ID; ``room_id`` is accepted but not used in the query.
    """
    withheld_reason = f"Session redacted: {reason}"
    withheld_code = RoomKeyWithheldCode.BEEPER_REDACTED.value
    update = """
UPDATE crypto_megolm_inbound_session
SET withheld_code=$1, withheld_reason=$2, session=NULL, forwarding_chains=NULL
WHERE session_id=$3 AND account_id=$4 AND session IS NOT NULL
"""
    await self.db.execute(update, withheld_code, withheld_reason, session_id, self.account_id)
async def redact_group_sessions(
    self, room_id: RoomID, sender_key: IdentityKey, reason: str
) -> list[SessionID]:
    """
    Redact every matching inbound Megolm session and return the affected session IDs.

    An empty ``room_id`` or ``sender_key`` disables that filter. Rows are only redacted if
    they still have a pickle, are not scheduled and have a ``received_at`` value.

    Raises:
        ValueError: If both ``room_id`` and ``sender_key`` are falsy.
    """
    if not (room_id or sender_key):
        raise ValueError("Either room_id or sender_key must be provided")
    update = """
UPDATE crypto_megolm_inbound_session
SET withheld_code=$1, withheld_reason=$2, session=NULL, forwarding_chains=NULL
WHERE (room_id=$3 OR $3='') AND (sender_key=$4 OR $4='') AND account_id=$5
AND session IS NOT NULL AND is_scheduled=false AND received_at IS NOT NULL
RETURNING session_id
"""
    withheld_code = RoomKeyWithheldCode.BEEPER_REDACTED.value
    withheld_reason = f"Session redacted: {reason}"
    redacted = await self.db.fetch(
        update, withheld_code, withheld_reason, room_id, sender_key, self.account_id
    )
    return [record["session_id"] for record in redacted]
async def redact_expired_group_sessions(self) -> list[SessionID]:
    """
    Redact unscheduled inbound Megolm sessions whose ``received_at`` plus twice their
    ``max_age`` lies in the past, and return the affected session IDs.

    The cutoff expression is dialect-specific, so a separate query exists for SQLite and
    for Postgres/CockroachDB.

    NOTE(review): the SQLite query compares against ``unixepoch(date('now'))`` while the
    Postgres query uses ``now()`` -- confirm the date-only cutoff on SQLite is intended.

    Raises:
        RuntimeError: If the database scheme is none of the supported ones.
    """
    scheme = self.db.scheme
    if scheme in (Scheme.POSTGRES, Scheme.COCKROACH):
        update = """
UPDATE crypto_megolm_inbound_session
SET withheld_code=$1, withheld_reason=$2, session=NULL, forwarding_chains=NULL
WHERE account_id=$3 AND session IS NOT NULL AND is_scheduled=false
AND received_at IS NOT NULL and max_age IS NOT NULL
AND received_at + 2 * (max_age * interval '1 millisecond') < now()
RETURNING session_id
"""
    elif scheme == Scheme.SQLITE:
        update = """
UPDATE crypto_megolm_inbound_session
SET withheld_code=$1, withheld_reason=$2, session=NULL, forwarding_chains=NULL
WHERE account_id=$3 AND session IS NOT NULL AND is_scheduled=false
AND received_at IS NOT NULL and max_age IS NOT NULL
AND unixepoch(received_at) + (2 * max_age / 1000) < unixepoch(date('now'))
RETURNING session_id
"""
    else:
        raise RuntimeError(f"Unsupported dialect {self.db.scheme}")
    redacted = await self.db.fetch(
        update,
        RoomKeyWithheldCode.BEEPER_REDACTED.value,
        "Session redacted: expired",
        self.account_id,
    )
    return [record["session_id"] for record in redacted]
async def redact_outdated_group_sessions(self) -> list[SessionID]:
    """
    Redact every inbound Megolm session of this account that still has a pickle but no
    ``received_at`` value, and return the affected session IDs.
    """
    update = """
UPDATE crypto_megolm_inbound_session
SET withheld_code=$1, withheld_reason=$2, session=NULL, forwarding_chains=NULL
WHERE account_id=$3 AND session IS NOT NULL AND received_at IS NULL
RETURNING session_id
"""
    withheld_code = RoomKeyWithheldCode.BEEPER_REDACTED.value
    redacted = await self.db.fetch(
        update, withheld_code, "Session redacted: outdated", self.account_id
    )
    return [record["session_id"] for record in redacted]
async def has_group_session(self, room_id: RoomID, session_id: SessionID) -> bool:
    """
    Tell whether a non-redacted inbound Megolm session (one whose ``session`` column is not
    NULL) is stored for the given room and session ID.
    """
    count_query = """
SELECT COUNT(session) FROM crypto_megolm_inbound_session
WHERE room_id=$1 AND session_id=$2 AND account_id=$3 AND session IS NOT NULL
"""
    matches = await self.db.fetchval(count_query, room_id, session_id, self.account_id)
    return matches > 0
async def add_outbound_group_session(self, session: OutboundGroupSession) -> None:
    """
    Upsert an outbound Megolm session.

    The conflict target is ``(account_id, room_id)``, so an existing row for the same room
    is overwritten with the new session's values.
    """
    pickled = session.pickle(self.pickle_key)
    # max_age is stored as integer milliseconds.
    max_age_ms = int(session.max_age.total_seconds() * 1000)
    upsert = """
INSERT INTO crypto_megolm_outbound_session (
room_id, session_id, session, shared, max_messages, message_count,
max_age, created_at, last_used, account_id
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
ON CONFLICT (account_id, room_id) DO UPDATE
SET session_id=excluded.session_id, session=excluded.session, shared=excluded.shared,
max_messages=excluded.max_messages, message_count=excluded.message_count,
max_age=excluded.max_age, created_at=excluded.created_at, last_used=excluded.last_used
"""
    values = (
        session.room_id,
        session.id,
        pickled,
        session.shared,
        session.max_messages,
        session.message_count,
        max_age_ms,
        session.creation_time,
        session.use_time,
        self.account_id,
    )
    await self.db.execute(upsert, *values)
async def update_outbound_group_session(self, session: OutboundGroupSession) -> None:
    """
    Save the pickle, message count and last-used time of an existing outbound Megolm
    session, matched by room ID, session ID and account ID.
    """
    pickled = session.pickle(self.pickle_key)
    update = """
UPDATE crypto_megolm_outbound_session SET session=$1, message_count=$2, last_used=$3
WHERE room_id=$4 AND session_id=$5 AND account_id=$6
"""
    values = (
        pickled,
        session.message_count,
        session.use_time,
        session.room_id,
        session.id,
        self.account_id,
    )
    await self.db.execute(update, *values)
async def get_outbound_group_session(self, room_id: RoomID) -> OutboundGroupSession | None:
    """
    Load the outbound Megolm session stored for a room.

    Returns:
        The unpickled session, or ``None`` if there is no row for the room.
    """
    select = """
SELECT room_id, session_id, session, shared, max_messages, message_count, max_age,
created_at, last_used
FROM crypto_megolm_outbound_session WHERE room_id=$1 AND account_id=$2
"""
    record = await self.db.fetchrow(select, room_id, self.account_id)
    if record is None:
        return None
    # max_age is stored in milliseconds.
    max_age = timedelta(milliseconds=record["max_age"])
    return OutboundGroupSession.from_pickle(
        record["session"],
        passphrase=self.pickle_key,
        room_id=record["room_id"],
        shared=record["shared"],
        max_messages=record["max_messages"],
        message_count=record["message_count"],
        max_age=max_age,
        use_time=record["last_used"],
        creation_time=record["created_at"],
    )
async def remove_outbound_group_session(self, room_id: RoomID) -> None:
    """Delete this account's outbound Megolm session row for the given room."""
    await self.db.execute(
        "DELETE FROM crypto_megolm_outbound_session WHERE room_id=$1 AND account_id=$2",
        room_id,
        self.account_id,
    )
async def remove_outbound_group_sessions(self, rooms: list[RoomID]) -> None:
    """
    Delete this account's outbound Megolm sessions for several rooms at once.

    Postgres and CockroachDB receive the room list as a single array parameter. Other
    schemes get an ``IN (...)`` list with one placeholder per room.

    Args:
        rooms: Room IDs whose outbound sessions should be removed. An empty list is a no-op.
    """
    if not rooms:
        # Nothing to delete. This also avoids building an empty `IN ()` list, which is not
        # valid SQL in every dialect.
        return
    if self.db.scheme in (Scheme.POSTGRES, Scheme.COCKROACH):
        q = """
DELETE FROM crypto_megolm_outbound_session WHERE account_id=$1 AND room_id=ANY($2)
"""
        await self.db.execute(q, self.account_id, rooms)
    else:
        params = ",".join(["?"] * len(rooms))
        q = f"""
DELETE FROM crypto_megolm_outbound_session WHERE account_id=? AND room_id IN ({params})
"""
        await self.db.execute(q, self.account_id, *rooms)
# Upsert used by validate_message_index() on databases that support RETURNING: it inserts
# the (sender_key, session_id, index) row if it is new and returns the stored row either way.
_validate_message_index_query = """
INSERT INTO crypto_message_index (sender_key, session_id, "index", event_id, timestamp)
VALUES ($1, $2, $3, $4, $5)
-- have to update something so that RETURNING * always returns the row
ON CONFLICT (sender_key, session_id, "index") DO UPDATE SET sender_key=excluded.sender_key
RETURNING *
"""
async def validate_message_index(
    self,
    sender_key: IdentityKey,
    session_id: SessionID,
    event_id: EventID,
    index: int,
    timestamp: int,
) -> bool:
    """
    Record a Megolm message index, or check it against the one already recorded.

    Returns:
        ``True`` if the index was not stored before, or if the stored event ID and timestamp
        equal the given ones. ``False`` if the stored values differ.
    """
    # RETURNING was added in SQLite 3.35.0 https://www.sqlite.org/lang_returning.html
    can_use_returning = self.db.scheme in (Scheme.POSTGRES, Scheme.COCKROACH) or (
        self.db.scheme == Scheme.SQLITE and sqlite_version >= (3, 35)
    )
    if not can_use_returning:
        # Fallback for databases without RETURNING: look the row up, insert it if absent.
        existing = await self.db.fetchrow(
            "SELECT event_id, timestamp FROM crypto_message_index "
            'WHERE sender_key=$1 AND session_id=$2 AND "index"=$3',
            sender_key,
            session_id,
            index,
        )
        if existing is None:
            await self.db.execute(
                "INSERT INTO crypto_message_index(sender_key, session_id, "
                ' "index", event_id, timestamp) '
                "VALUES ($1, $2, $3, $4, $5)",
                sender_key,
                session_id,
                index,
                event_id,
                timestamp,
            )
            return True
        return existing["event_id"] == event_id and existing["timestamp"] == timestamp
    # Single round trip: the upsert returns the stored row, whether old or newly inserted.
    stored = await self.db.fetchrow(
        self._validate_message_index_query,
        sender_key,
        session_id,
        index,
        event_id,
        timestamp,
    )
    return stored["event_id"] == event_id and stored["timestamp"] == timestamp
async def get_devices(self, user_id: UserID) -> dict[DeviceID, DeviceIdentity] | None:
    """
    Return the stored devices of a user, keyed by device ID.

    Returns:
        ``None`` if the user has no row in ``crypto_tracked_user``. Otherwise a dict built
        from the user's ``crypto_device`` rows, which may be empty.
    """
    tracked = await self.db.fetchval(
        "SELECT user_id FROM crypto_tracked_user WHERE user_id=$1", user_id
    )
    if tracked is None:
        return None
    select = """
SELECT device_id, identity_key, signing_key, trust, deleted, name
FROM crypto_device WHERE user_id=$1
"""
    records = await self.db.fetch(select, user_id)
    return {
        record["device_id"]: DeviceIdentity(
            user_id=user_id,
            device_id=record["device_id"],
            identity_key=record["identity_key"],
            signing_key=record["signing_key"],
            trust=TrustState(record["trust"]),
            deleted=record["deleted"],
            name=record["name"],
        )
        for record in records
    }
async def get_device(self, user_id: UserID, device_id: DeviceID) -> DeviceIdentity | None:
    """
    Look up a single stored device by user ID and device ID.

    Returns:
        The device identity, or ``None`` if no matching ``crypto_device`` row exists.
    """
    select = """
SELECT identity_key, signing_key, trust, deleted, name FROM crypto_device
WHERE user_id=$1 AND device_id=$2
"""
    record = await self.db.fetchrow(select, user_id, device_id)
    if record is not None:
        return DeviceIdentity(
            user_id=user_id,
            device_id=device_id,
            name=record["name"],
            identity_key=record["identity_key"],
            signing_key=record["signing_key"],
            trust=TrustState(record["trust"]),
            deleted=record["deleted"],
        )
    return None
async def find_device_by_key(
    self, user_id: UserID, identity_key: IdentityKey
) -> DeviceIdentity | None:
    """
    Look up a stored device of a user by its identity key.

    Returns:
        The device identity, or ``None`` if no matching ``crypto_device`` row exists.
    """
    select = """
SELECT device_id, signing_key, trust, deleted, name FROM crypto_device
WHERE user_id=$1 AND identity_key=$2
"""
    record = await self.db.fetchrow(select, user_id, identity_key)
    if record is not None:
        return DeviceIdentity(
            user_id=user_id,
            device_id=record["device_id"],
            name=record["name"],
            identity_key=identity_key,
            signing_key=record["signing_key"],
            trust=TrustState(record["trust"]),
            deleted=record["deleted"],
        )
    return None
async def put_devices(self, user_id: UserID, devices: dict[DeviceID, DeviceIdentity]) -> None:
    """
    Replace the stored device list of a user.

    Within one transaction the user is marked as tracked, all of their existing
    ``crypto_device`` rows are deleted, and the given devices are inserted. On Postgres the
    insert uses ``copy_records_to_table``; on other schemes it uses ``executemany``.
    """
    columns = [
        "user_id",
        "device_id",
        "identity_key",
        "signing_key",
        "trust",
        "deleted",
        "name",
    ]
    # One tuple per device, in the same order as `columns`.
    records = [
        (
            user_id,
            dev_id,
            dev.identity_key,
            dev.signing_key,
            dev.trust,
            dev.deleted,
            dev.name,
        )
        for dev_id, dev in devices.items()
    ]
    async with self.db.acquire() as conn:
        async with conn.transaction():
            track_user = """
INSERT INTO crypto_tracked_user (user_id) VALUES ($1) ON CONFLICT (user_id) DO NOTHING
"""
            await conn.execute(track_user, user_id)
            await conn.execute("DELETE FROM crypto_device WHERE user_id=$1", user_id)
            if self.db.scheme != Scheme.POSTGRES:
                insert_device = """
INSERT INTO crypto_device (
user_id, device_id, identity_key, signing_key, trust, deleted, name
) VALUES ($1, $2, $3, $4, $5, $6, $7)
"""
                await conn.executemany(insert_device, records)
            else:
                await conn.copy_records_to_table(
                    "crypto_device", records=records, columns=columns
                )
async def filter_tracked_users(self, users: list[UserID]) -> list[UserID]:
    """
    Return the subset of the given users that have a row in ``crypto_tracked_user``.

    Postgres and CockroachDB receive the user list as a single array parameter. Other
    schemes get an ``IN (...)`` list with one placeholder per user.

    Args:
        users: The user IDs to check. An empty list yields an empty result without a query.

    Returns:
        The user IDs found in the table, in the order the database returned them.
    """
    if not users:
        # Nothing to look up. This also avoids building an empty `IN ()` list, which is not
        # valid SQL in every dialect.
        return []
    if self.db.scheme in (Scheme.POSTGRES, Scheme.COCKROACH):
        q = "SELECT user_id FROM crypto_tracked_user WHERE user_id = ANY($1)"
        rows = await self.db.fetch(q, users)
    else:
        params = ",".join(["?"] * len(users))
        q = f"SELECT user_id FROM crypto_tracked_user WHERE user_id IN ({params})"
        rows = await self.db.fetch(q, *users)
    return [row["user_id"] for row in rows]
async def put_cross_signing_key(
    self, user_id: UserID, usage: CrossSigningUsage, key: SigningKey
) -> None:
    """
    Store a user's cross-signing key for the given usage.

    On first insert the key is written to both ``key`` and ``first_seen_key``. If a row for
    ``(user_id, usage)`` already exists, only ``key`` is replaced. Any exception from the
    database call is logged and not re-raised.
    """
    upsert = """
INSERT INTO crypto_cross_signing_keys (user_id, usage, key, first_seen_key)
VALUES ($1, $2, $3, $4)
ON CONFLICT (user_id, usage) DO UPDATE SET key=excluded.key
"""
    try:
        # `key` is passed twice: once for the key column, once for first_seen_key.
        values = (user_id, usage.value, key, key)
        await self.db.execute(upsert, *values)
    except Exception:
        self.log.exception(f"Failed to store cross-signing key {user_id}/{key}/{usage}")
async def get_cross_signing_keys(
    self, user_id: UserID
) -> dict[CrossSigningUsage, TOFUSigningKey]:
    """
    Return the stored cross-signing keys of a user, keyed by usage.

    Each value carries the current key together with the first key seen for that usage.
    The dict is empty when nothing is stored for the user.
    """
    select = "SELECT usage, key, first_seen_key FROM crypto_cross_signing_keys WHERE user_id=$1"
    records = await self.db.fetch(select, user_id)
    keys = {}
    for record in records:
        usage = CrossSigningUsage(record["usage"])
        keys[usage] = TOFUSigningKey(
            key=SigningKey(record["key"]),
            first=SigningKey(record["first_seen_key"]),
        )
    return keys
async def put_signature(
    self, target: CrossSigner, signer: CrossSigner, signature: str
) -> None:
    """
    Store a signature made by ``signer`` on ``target``, replacing any signature already
    stored for the same signed-key/signer-key pair.

    Both arguments are unpacked into a ``(user_id, key)`` pair. Any exception from the
    database call is logged and not re-raised.
    """
    signer_user_id, signer_key = signer
    signed_user_id, signed_key = target
    upsert = """
INSERT INTO crypto_cross_signing_signatures (
signed_user_id, signed_key, signer_user_id, signer_key, signature
) VALUES ($1, $2, $3, $4, $5)
ON CONFLICT (signed_user_id, signed_key, signer_user_id, signer_key)
DO UPDATE SET signature=excluded.signature
"""
    values = (signed_user_id, signed_key, signer_user_id, signer_key, signature)
    try:
        await self.db.execute(upsert, *values)
    except Exception:
        self.log.exception(
            f"Failed to store signature from {signer_user_id}/{signer_key} "
            f"for {signed_user_id}/{signed_key}"
        )
async def is_key_signed_by(self, target: CrossSigner, signer: CrossSigner) -> bool:
    """
    Check whether a signature by ``signer`` on ``target`` is stored.

    Args:
        target: ``(user_id, key)`` pair of the signed key.
        signer: ``(user_id, key)`` pair of the signing key.

    Returns:
        ``True`` if a matching row exists in ``crypto_cross_signing_signatures``.
    """
    q = """
SELECT EXISTS(
SELECT 1 FROM crypto_cross_signing_signatures
WHERE signed_user_id=$1 AND signed_key=$2 AND signer_user_id=$3 AND signer_key=$4
)
"""
    signed_user_id, signed_key = target
    signer_user_id, signer_key = signer
    exists = await self.db.fetchval(q, signed_user_id, signed_key, signer_user_id, signer_key)
    # EXISTS comes back as an integer 0/1 from SQLite, so coerce it to match the declared
    # bool return type on every backend.
    return bool(exists)
async def drop_signatures_by_key(self, signer: CrossSigner) -> int:
    """
    Delete all stored signatures made by the given ``(user_id, key)`` signer.

    Returns:
        The number of deleted rows when it can be read from the result (a cursor's
        ``rowcount``, or the count in a ``"DELETE <n>"`` status string). ``-1`` if the
        delete failed or the result has an unrecognised form.
    """
    signer_user_id, signer_key = signer
    delete = (
        "DELETE FROM crypto_cross_signing_signatures WHERE signer_user_id=$1 AND signer_key=$2"
    )
    try:
        result = await self.db.execute(delete, signer_user_id, signer_key)
    except Exception:
        self.log.exception(
            f"Failed to drop old signatures made by replaced key {signer_user_id}/{signer_key}"
        )
        return -1
    # Cursor is None when aiosqlite could not be imported.
    if Cursor is not None and isinstance(result, Cursor):
        return result.rowcount
    # Otherwise expect a status string of the form "DELETE <count>".
    prefix = "DELETE "
    if not isinstance(result, str) or not result.startswith(prefix):
        return -1
    count_text = result[len(prefix) :]
    return int(count_text) if count_text.isdecimal() else -1