133 lines
4.8 KiB
Python
133 lines
4.8 KiB
Python
# Copyright (c) 2022 Tulir Asokan
|
|
#
|
|
# This Source Code Form is subject to the terms of the Mozilla Public
|
|
# License, v. 2.0. If a copy of the MPL was not distributed with this
|
|
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
|
from __future__ import annotations
|
|
|
|
from typing import AsyncContextManager, AsyncIterator, Callable
|
|
from contextlib import asynccontextmanager
|
|
import os
|
|
import random
|
|
import string
|
|
import time
|
|
|
|
import asyncpg
|
|
import pytest
|
|
|
|
from mautrix.client.state_store import SyncStore
|
|
from mautrix.crypto import InboundGroupSession, OlmAccount, OutboundGroupSession
|
|
from mautrix.types import DeviceID, EventID, RoomID, SessionID, SyncToken
|
|
from mautrix.util.async_db import Database
|
|
|
|
from .. import CryptoStore, MemoryCryptoStore, PgCryptoStore
|
|
|
|
|
|
@asynccontextmanager
async def async_postgres_store() -> AsyncIterator[PgCryptoStore]:
    """Yield a ``PgCryptoStore`` backed by a throwaway Postgres schema.

    The database URL is read from the ``MEOW_TEST_PG_URL`` environment
    variable; when it is not set, the calling test is skipped.  A schema with
    a random, timestamped name is created over a dedicated connection, and the
    ``Database`` handed to the store is configured with that schema as its
    ``search_path``.

    Cleanup runs in ``finally`` blocks so that the database is stopped, the
    schema dropped and the connection closed even when the body of the
    ``async with`` (i.e. the test) raises or startup fails part-way.
    """
    try:
        pg_url = os.environ["MEOW_TEST_PG_URL"]
    except KeyError:
        pytest.skip("Skipped Postgres tests (MEOW_TEST_PG_URL not specified)")
        return
    conn: asyncpg.Connection = await asyncpg.connect(pg_url)
    try:
        schema_name = "".join(random.choices(string.ascii_lowercase, k=8))
        schema_name = f"test_schema_{schema_name}_{int(time.time())}"
        await conn.execute(f"CREATE SCHEMA {schema_name}")
        try:
            db = Database.create(
                pg_url,
                upgrade_table=PgCryptoStore.upgrade_table,
                db_args={
                    "min_size": 1,
                    "max_size": 3,
                    "server_settings": {"search_path": schema_name},
                },
            )
            store = PgCryptoStore("", "test", db)
            await db.start()
            try:
                yield store
            finally:
                # Only stop what was successfully started.
                await db.stop()
        finally:
            # Drop the schema even if db.start() or the test body failed.
            await conn.execute(f"DROP SCHEMA {schema_name} CASCADE")
    finally:
        await conn.close()
|
|
|
|
|
|
@asynccontextmanager
async def async_sqlite_store() -> AsyncIterator[PgCryptoStore]:
    """Yield a ``PgCryptoStore`` backed by a ``sqlite::memory:`` database.

    The database is started before the store is yielded and stopped in a
    ``finally`` block, so it is shut down even when the body of the
    ``async with`` (i.e. the test) raises.
    """
    db = Database.create(
        "sqlite::memory:", upgrade_table=PgCryptoStore.upgrade_table, db_args={"min_size": 1}
    )
    store = PgCryptoStore("", "test", db)
    await db.start()
    try:
        yield store
    finally:
        await db.stop()
|
|
|
|
|
|
@asynccontextmanager
async def memory_store() -> AsyncIterator[MemoryCryptoStore]:
    """Yield a new ``MemoryCryptoStore``; no setup or cleanup surrounds the yield."""
    store = MemoryCryptoStore("", "test")
    yield store
|
|
|
|
|
|
@pytest.fixture(params=[async_postgres_store, async_sqlite_store, memory_store])
async def crypto_store(request) -> AsyncIterator[CryptoStore]:
    """Parametrized fixture yielding one store per factory listed in ``params``.

    ``request.param`` is one of the async context manager factories above; it
    is entered for the duration of the test so that its cleanup runs once the
    test is finished with the store.
    """
    store_factory: Callable[[], AsyncContextManager[CryptoStore]] = request.param
    async with store_factory() as store:
        yield store
|
|
|
|
|
|
async def test_basic(crypto_store: CryptoStore) -> None:
    """Store an account and device ID, then check that they read back unchanged.

    When the store is also a ``SyncStore``, the next-batch token is
    round-tripped as well.
    """
    account = OlmAccount()
    expected_keys = account.identity_keys
    # Evaluated once; the store's type does not change during the test.
    is_sync_store = isinstance(crypto_store, SyncStore)

    # Write phase.
    await crypto_store.put_account(account)
    await crypto_store.put_device_id(DeviceID("TEST"))
    if is_sync_store:
        await crypto_store.put_next_batch(SyncToken("TEST"))

    # Read-back phase.
    stored_device_id = await crypto_store.get_device_id()
    assert stored_device_id == "TEST"
    stored_account = await crypto_store.get_account()
    assert stored_account.identity_keys == expected_keys
    if is_sync_store:
        stored_token = await crypto_store.get_next_batch()
        assert stored_token == "TEST"
|
|
|
|
|
|
def _make_group_sess(
    acc: OlmAccount, room_id: RoomID
) -> tuple[InboundGroupSession, OutboundGroupSession]:
    """Create an outbound group session for ``room_id`` and a matching inbound one.

    The inbound session is constructed from the outbound session's key together
    with the signing and identity keys of ``acc``.

    Returns:
        An ``(inbound, outbound)`` tuple.
    """
    out_session = OutboundGroupSession(room_id)
    in_session_args = {
        "session_key": out_session.session_key,
        "signing_key": acc.signing_key,
        "sender_key": acc.identity_key,
        "room_id": room_id,
    }
    in_session = InboundGroupSession(**in_session_args)
    return in_session, out_session
|
|
|
|
|
|
async def test_validate_message_index(crypto_store: CryptoStore) -> None:
    """Exercise ``validate_message_index`` with matching and mismatching details.

    One message is encrypted and decrypted to obtain a message index. The same
    (sender key, session ID, index) is then validated repeatedly while the
    event ID and timestamp are varied; the expected result of each call is
    spelled out in its assertion message.
    """
    account = OlmAccount()

    in_sess, out_sess = _make_group_sess(account, RoomID("!foo:bar.com"))
    out_sess.shared = True
    sent_text = "hello world"
    encrypted = out_sess.encrypt(sent_text)
    timestamp = int(time.time() * 1000)
    received_text, msg_index = in_sess.decrypt(encrypted)
    assert received_text == sent_text

    async def validate(event_id: str, event_ts: int):
        # Every call shares the sender key, session ID and message index;
        # only the event ID and timestamp vary between assertions.
        return await crypto_store.validate_message_index(
            account.identity_key, SessionID(in_sess.id), EventID(event_id), msg_index, event_ts
        )

    assert await validate("$foo", timestamp), "Initial validation returns True"
    assert await validate("$foo", timestamp), "Validating the same details again returns True"
    assert not await validate("$bar", timestamp), "Different event ID causes validation to fail"
    assert not await validate(
        "$foo", timestamp + 1
    ), "Different timestamp causes validation to fail"
    assert not await validate("$foo", timestamp + 1), "Validating incorrect details twice fails"
    assert await validate(
        "$foo", timestamp
    ), "Validating the same details after fails still returns True"
|
|
|
|
|
|
# TODO: add tests for device identity storage, group session storage,
# and cross-signing key/signature storage.
|