# mautrix-python/mautrix/client/state_store/tests/store_test.py
#
# 172 lines
# 6.2 KiB
# Python

# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations
from typing import AsyncContextManager, AsyncIterator, Callable
from contextlib import asynccontextmanager
import json
import os
import pathlib
import random
import string
import time
import asyncpg
import pytest
from mautrix.types import EncryptionAlgorithm, Member, Membership, RoomID, StateEvent, UserID
from mautrix.util.async_db import Database
from .. import MemoryStateStore, StateStore
from ..asyncpg import PgStateStore
@asynccontextmanager
async def async_postgres_store() -> AsyncIterator[PgStateStore]:
    """Yield a started ``PgStateStore`` that lives in a throwaway Postgres schema.

    The connection URL is read from the ``MEOW_TEST_PG_URL`` environment
    variable; when it is not set, the requesting test is skipped. A randomly
    named schema is created for the store and dropped again on exit.

    Cleanup (stopping the database, dropping the schema, closing the admin
    connection) runs even if starting the database or the body of the
    ``async with`` block raises, so failed runs do not leave schemas behind.
    """
    pg_url = os.environ.get("MEOW_TEST_PG_URL")
    if pg_url is None:
        # pytest.skip() raises, so nothing below runs without a URL.
        pytest.skip("Skipped Postgres tests (MEOW_TEST_PG_URL not specified)")

    # Separate admin connection used only to create and drop the schema.
    conn: asyncpg.Connection = await asyncpg.connect(pg_url)
    try:
        schema_name = "".join(random.choices(string.ascii_lowercase, k=8))
        schema_name = f"test_schema_{schema_name}_{int(time.time())}"
        # The name is built from ASCII letters and digits only, so
        # interpolating it into the statement is safe.
        await conn.execute(f"CREATE SCHEMA {schema_name}")
        try:
            db = Database.create(
                pg_url,
                upgrade_table=PgStateStore.upgrade_table,
                db_args={
                    "min_size": 1,
                    "max_size": 3,
                    # Point the pool's connections at the temporary schema.
                    "server_settings": {"search_path": schema_name},
                },
            )
            store = PgStateStore(db)
            await db.start()
            try:
                yield store
            finally:
                await db.stop()
        finally:
            await conn.execute(f"DROP SCHEMA {schema_name} CASCADE")
    finally:
        await conn.close()
@asynccontextmanager
async def async_sqlite_store() -> AsyncIterator[PgStateStore]:
    """Yield a started ``PgStateStore`` backed by the ``sqlite::memory:`` database.

    The database is stopped on exit, including when the body of the
    ``async with`` block raises.
    """
    db = Database.create(
        "sqlite::memory:", upgrade_table=PgStateStore.upgrade_table, db_args={"min_size": 1}
    )
    store = PgStateStore(db)
    await db.start()
    try:
        yield store
    finally:
        await db.stop()
@asynccontextmanager
async def memory_store() -> AsyncIterator[MemoryStateStore]:
    """Yield a new ``MemoryStateStore``; there is nothing to start or stop."""
    state_store = MemoryStateStore()
    yield state_store
@pytest.fixture(params=[async_postgres_store, async_sqlite_store, memory_store])
async def store(request) -> AsyncIterator[StateStore]:
    """Parametrized fixture that runs each test once per state store backend.

    ``request.param`` is one of the context manager factories listed in
    ``params``; the store it produces is handed to the test and the context
    manager is exited afterwards.
    """
    open_store: Callable[[], AsyncContextManager[StateStore]] = request.param
    async with open_store() as backend:
        yield backend
def read_state_file(request, file) -> dict[RoomID, list[StateEvent]]:
    """Load a JSON fixture that sits next to the requesting test file.

    Args:
        request: The pytest ``request`` object; ``request.node.fspath`` is
            used to locate the directory of the test module.
        file: Name of the JSON file in that directory.

    Returns:
        A mapping of room ID to the list of deserialized state events for
        that room. The room ID key is injected into each event as
        ``room_id`` before deserializing.
    """
    path = pathlib.Path(request.node.fspath).with_name(file)
    # Read as UTF-8 explicitly instead of relying on the platform's default
    # text encoding, which is not UTF-8 everywhere.
    with path.open(encoding="utf-8") as fp:
        content = json.load(fp)
    return {
        room_id: [StateEvent.deserialize({**evt, "room_id": room_id}) for evt in events]
        for room_id, events in content.items()
    }
async def store_room_state(request, store: StateStore) -> None:
    """Pass every event from ``new_state.json`` to ``store.update_state``, in file order."""
    changes_by_room = read_state_file(request, "new_state.json")
    pending_events = (
        event for room_events in changes_by_room.values() for event in room_events
    )
    for event in pending_events:
        await store.update_state(event)
async def get_all_members(request, store: StateStore) -> None:
    """Load ``members.json`` and hand each room's members to ``store.set_members``.

    Despite the ``get_`` name, this writes to the store and returns nothing.
    Each room's member events are turned into a ``state_key -> content``
    mapping.
    """
    for room_id, member_events in read_state_file(request, "members.json").items():
        members = {}
        for event in member_events:
            members[event.state_key] = event.content
        await store.set_members(room_id, members)
async def get_joined_members(request, store: StateStore) -> None:
    """Load ``joined_members.json`` and store its members as joined members.

    Despite the ``get_`` name, this writes to the store and returns nothing.
    Each entry in the file becomes a ``Member`` with ``Membership.JOIN``,
    taking ``display_name`` and ``avatar_url`` from the entry (empty string
    when absent). Each room is then passed to ``store.set_members`` with
    ``only_membership=Membership.JOIN``.

    Args:
        request: The pytest ``request`` object, used to locate the fixture
            file next to the test module.
        store: The state store to update.
    """
    path = pathlib.Path(request.node.fspath).with_name("joined_members.json")
    # Read as UTF-8 explicitly instead of relying on the platform's default
    # text encoding, which is not UTF-8 everywhere.
    with path.open(encoding="utf-8") as fp:
        content = json.load(fp)
    for room_id, members in content.items():
        parsed_members = {
            user_id: Member(
                membership=Membership.JOIN,
                displayname=member.get("display_name", ""),
                avatar_url=member.get("avatar_url", ""),
            )
            for user_id, member in members.items()
        }
        await store.set_members(room_id, parsed_members, only_membership=Membership.JOIN)
async def test_basic(store: StateStore) -> None:
    """Check join tracking and encryption lookups on an empty store."""
    room = RoomID("!foo:example.com")
    user = UserID("@tulir:example.com")
    unknown_room = RoomID("!unknown-room:example.com")

    # Nothing is known about the room or the user yet.
    encrypted_before = await store.is_encrypted(room)
    assert not encrypted_before
    joined_before = await store.is_joined(room, user)
    assert not joined_before

    await store.joined(room, user)
    joined_after = await store.is_joined(room, user)
    assert joined_after

    # A room the store has never seen has no cached encryption info.
    has_cached_info = await store.has_encryption_info_cached(unknown_room)
    assert not has_cached_info
    unknown_encrypted = await store.is_encrypted(unknown_room)
    assert unknown_encrypted is None
async def test_basic_updated(request, store: StateStore) -> None:
    """Check encryption state after the events in ``new_state.json`` are applied."""
    await store_room_state(request, store)

    encrypted_room = RoomID("!telegram-group:example.com")
    is_encrypted = await store.is_encrypted(encrypted_room)
    assert is_encrypted

    encryption_info = await store.get_encryption_info(encrypted_room)
    assert encryption_info.algorithm == EncryptionAlgorithm.MEGOLM_V1

    plain_room = RoomID("!unencrypted-room:example.com")
    plain_is_encrypted = await store.is_encrypted(plain_room)
    assert not plain_is_encrypted
async def test_updates(request, store: StateStore) -> None:
    """Check member lists as state events and member snapshots are applied.

    The store is filled in three steps: individual state events, then the
    ``members.json`` snapshot, then the ``joined_members.json`` snapshot.
    After each step the member sets returned for the Telegram group room
    are compared with the expected sets.
    """
    await store_room_state(request, store)
    room_id = RoomID("!telegram-group:example.com")

    async def members(**filters) -> set:
        # Shorthand for querying the room under test and comparing as a set.
        return set(await store.get_members(room_id, **filters))

    expected_initial = {"@tulir:example.com", "@telegram_84359547:example.com"}
    expected_joined = expected_initial.union(
        {
            "@telegrambot:example.com",
            "@telegram_5647382910:example.com",
            "@telegram_374880943:example.com",
            "@telegram_987654321:example.com",
            "@telegram_123456789:example.com",
        }
    )
    expected_left = {"@telegram_476034259:example.com", "@whatsappbot:example.com"}
    expected_everyone = expected_joined.union(expected_left)

    every_membership = (
        Membership.JOIN,
        Membership.INVITE,
        Membership.LEAVE,
        Membership.BAN,
        Membership.KNOCK,
    )
    gone_memberships = (Membership.BAN, Membership.LEAVE)

    # Step 1: only the individual state events have been applied.
    assert await members() == expected_initial

    # Step 2: the members.json snapshot.
    await get_all_members(request, store)
    assert await members() == expected_joined
    assert await members(memberships=every_membership) == expected_everyone

    # Step 3: the joined-only snapshot; the expected sets are the same as after step 2.
    await get_joined_members(request, store)
    assert await members() == expected_joined
    assert await members(memberships=every_membership) == expected_everyone
    assert await members(memberships=gone_memberships) == expected_left

    filtered = await store.get_members_filtered(
        room_id,
        memberships=gone_memberships,
        not_id="",
        not_prefix="@telegram_",
        not_suffix=":example.com",
    )
    assert set(filtered) == {"@whatsappbot:example.com"}