mautrix-python/mautrix/crypto/store/tests/store_test.py

133 lines
4.8 KiB
Python

# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations
from typing import AsyncContextManager, AsyncIterator, Callable
from contextlib import asynccontextmanager
import os
import random
import string
import time
import asyncpg
import pytest
from mautrix.client.state_store import SyncStore
from mautrix.crypto import InboundGroupSession, OlmAccount, OutboundGroupSession
from mautrix.types import DeviceID, EventID, RoomID, SessionID, SyncToken
from mautrix.util.async_db import Database
from .. import CryptoStore, MemoryCryptoStore, PgCryptoStore
@asynccontextmanager
async def async_postgres_store() -> AsyncIterator[PgCryptoStore]:
    """Yield a :class:`PgCryptoStore` backed by a throwaway Postgres schema.

    The database URL is read from the ``MEOW_TEST_PG_URL`` environment variable; when it is
    not set, the requesting test is skipped. A uniquely named schema is created for the store
    and set as the pool's ``search_path``.

    Teardown runs in ``finally`` blocks, so it happens even if the test body (or setup) raises:
    the database is stopped, the schema is dropped with ``CASCADE``, and the admin connection
    is closed.
    """
    try:
        pg_url = os.environ["MEOW_TEST_PG_URL"]
    except KeyError:
        pytest.skip("Skipped Postgres tests (MEOW_TEST_PG_URL not specified)")
        # pytest.skip() raises; this return only keeps pg_url provably bound for linters.
        return
    conn: asyncpg.Connection = await asyncpg.connect(pg_url)
    try:
        # A random suffix plus a timestamp keeps separate test runs from colliding. The name
        # contains only lowercase ASCII letters, digits and underscores, so interpolating it
        # into the SQL statements below is safe.
        schema_name = "".join(random.choices(string.ascii_lowercase, k=8))
        schema_name = f"test_schema_{schema_name}_{int(time.time())}"
        await conn.execute(f"CREATE SCHEMA {schema_name}")
        try:
            db = Database.create(
                pg_url,
                upgrade_table=PgCryptoStore.upgrade_table,
                db_args={
                    "min_size": 1,
                    "max_size": 3,
                    "server_settings": {"search_path": schema_name},
                },
            )
            store = PgCryptoStore("", "test", db)
            await db.start()
            try:
                yield store
            finally:
                await db.stop()
        finally:
            # Drop the schema even if the test or database startup failed.
            await conn.execute(f"DROP SCHEMA {schema_name} CASCADE")
    finally:
        await conn.close()
@asynccontextmanager
async def async_sqlite_store() -> AsyncIterator[PgCryptoStore]:
    """Yield a :class:`PgCryptoStore` running on an in-memory SQLite database.

    This exercises the same ``PgCryptoStore`` class used for the Postgres fixture, but
    against SQLite. The database is stopped when the context exits, even if the test body
    raised.
    """
    db = Database.create(
        "sqlite::memory:", upgrade_table=PgCryptoStore.upgrade_table, db_args={"min_size": 1}
    )
    store = PgCryptoStore("", "test", db)
    await db.start()
    try:
        yield store
    finally:
        await db.stop()
@asynccontextmanager
async def memory_store() -> AsyncIterator[MemoryCryptoStore]:
    """Yield a fresh :class:`MemoryCryptoStore`; no cleanup is performed afterwards."""
    store = MemoryCryptoStore("", "test")
    yield store
@pytest.fixture(params=[async_postgres_store, async_sqlite_store, memory_store])
async def crypto_store(request) -> AsyncIterator[CryptoStore]:
    """Parametrized fixture that runs each test once per crypto store backend.

    ``request.param`` is one of the async context-manager factories listed in ``params``.
    The store it yields is handed to the test. When the test finishes, the factory's
    context exit runs.
    """
    make_store: Callable[[], AsyncContextManager[CryptoStore]] = request.param
    store_context = make_store()
    async with store_context as store:
        yield store
async def test_basic(crypto_store: CryptoStore) -> None:
    """Round-trip an account, a device ID and, for sync-capable stores, a next-batch token."""
    account = OlmAccount()
    expected_keys = account.identity_keys
    # Only stores that also implement SyncStore can persist the sync token.
    supports_sync = isinstance(crypto_store, SyncStore)

    await crypto_store.put_account(account)
    await crypto_store.put_device_id(DeviceID("TEST"))
    if supports_sync:
        await crypto_store.put_next_batch(SyncToken("TEST"))

    assert await crypto_store.get_device_id() == "TEST"
    stored_account = await crypto_store.get_account()
    assert stored_account.identity_keys == expected_keys
    if supports_sync:
        assert await crypto_store.get_next_batch() == "TEST"
def _make_group_sess(
    acc: OlmAccount, room_id: RoomID
) -> tuple[InboundGroupSession, OutboundGroupSession]:
    """Create a matching inbound/outbound group session pair for ``room_id``.

    The inbound session is built from the outbound session's key and attributed to ``acc``
    via its signing and identity keys.

    Returns:
        A ``(inbound, outbound)`` tuple.
    """
    out_session = OutboundGroupSession(room_id)
    in_session = InboundGroupSession(
        session_key=out_session.session_key,
        signing_key=acc.signing_key,
        sender_key=acc.identity_key,
        room_id=room_id,
    )
    return in_session, out_session
async def test_validate_message_index(crypto_store: CryptoStore) -> None:
    """Check that a message index stays bound to the first event ID and timestamp validated.

    Re-validating with the same event ID and timestamp keeps succeeding. Any other event ID
    or timestamp for that index is rejected, repeatedly, without disturbing the originally
    recorded details.
    """
    account = OlmAccount()
    inbound, outbound = _make_group_sess(account, RoomID("!foo:bar.com"))
    outbound.shared = True
    message = "hello world"
    encrypted = outbound.encrypt(message)
    timestamp = int(time.time() * 1000)  # milliseconds
    decrypted, msg_index = inbound.decrypt(encrypted)
    assert decrypted == message

    async def check(event_id: str, ts: int) -> bool:
        # Validate msg_index of this session against the given event ID and timestamp.
        return await crypto_store.validate_message_index(
            account.identity_key, SessionID(inbound.id), EventID(event_id), msg_index, ts
        )

    assert await check("$foo", timestamp), "Initial validation returns True"
    assert await check("$foo", timestamp), "Validating the same details again returns True"
    assert not await check("$bar", timestamp), "Different event ID causes validation to fail"
    assert not await check(
        "$foo", timestamp + 1
    ), "Different timestamp causes validation to fail"
    assert not await check("$foo", timestamp + 1), "Validating incorrect details twice fails"
    assert await check(
        "$foo", timestamp
    ), "Validating the same details after fails still returns True"
# TODO: add tests for device identity storage, group session storage,
# and cross-signing key/signature storage.