172 lines
6.2 KiB
Python
172 lines
6.2 KiB
Python
# Copyright (c) 2022 Tulir Asokan
|
|
#
|
|
# This Source Code Form is subject to the terms of the Mozilla Public
|
|
# License, v. 2.0. If a copy of the MPL was not distributed with this
|
|
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
|
from __future__ import annotations
|
|
|
|
from typing import AsyncContextManager, AsyncIterator, Callable
|
|
from contextlib import asynccontextmanager
|
|
import json
|
|
import os
|
|
import pathlib
|
|
import random
|
|
import string
|
|
import time
|
|
|
|
import asyncpg
|
|
import pytest
|
|
|
|
from mautrix.types import EncryptionAlgorithm, Member, Membership, RoomID, StateEvent, UserID
|
|
from mautrix.util.async_db import Database
|
|
|
|
from .. import MemoryStateStore, StateStore
|
|
from ..asyncpg import PgStateStore
|
|
|
|
|
|
@asynccontextmanager
async def async_postgres_store() -> AsyncIterator[PgStateStore]:
    """Yield a :class:`PgStateStore` backed by a throwaway Postgres schema.

    The database URL is read from the ``MEOW_TEST_PG_URL`` environment variable; when it is
    not set, the calling test is skipped. A randomly named schema is created, used as the
    pool's ``search_path``, and dropped again (``CASCADE``) on exit. Cleanup runs even when
    the test body or the store startup raises, so failing tests do not leak schemas or
    connections.
    """
    try:
        pg_url = os.environ["MEOW_TEST_PG_URL"]
    except KeyError:
        # pytest.skip() raises, so nothing below runs without a database URL.
        pytest.skip("Skipped Postgres tests (MEOW_TEST_PG_URL not specified)")

    conn: asyncpg.Connection = await asyncpg.connect(pg_url)
    try:
        schema_name = "".join(random.choices(string.ascii_lowercase, k=8))
        # Only lowercase letters, digits and underscores end up in the name, which is why
        # interpolating it into the DDL statements below is safe.
        schema_name = f"test_schema_{schema_name}_{int(time.time())}"
        await conn.execute(f"CREATE SCHEMA {schema_name}")
        try:
            db = Database.create(
                pg_url,
                upgrade_table=PgStateStore.upgrade_table,
                db_args={
                    "min_size": 1,
                    "max_size": 3,
                    "server_settings": {"search_path": schema_name},
                },
            )
            store = PgStateStore(db)
            await db.start()
            try:
                yield store
            finally:
                await db.stop()
        finally:
            await conn.execute(f"DROP SCHEMA {schema_name} CASCADE")
    finally:
        await conn.close()
|
|
|
|
|
|
@asynccontextmanager
async def async_sqlite_store() -> AsyncIterator[PgStateStore]:
    """Yield a :class:`PgStateStore` backed by an in-memory SQLite database.

    The same store class and upgrade table as the Postgres variant are used, so the
    database-backed store is exercised on both engines. The database is stopped on exit,
    even when the test body raises.
    """
    db = Database.create(
        "sqlite::memory:", upgrade_table=PgStateStore.upgrade_table, db_args={"min_size": 1}
    )
    store = PgStateStore(db)
    await db.start()
    try:
        yield store
    finally:
        await db.stop()
|
|
|
|
|
|
@asynccontextmanager
async def memory_store() -> AsyncIterator[MemoryStateStore]:
    """Yield a brand-new in-memory state store; nothing needs tearing down afterwards."""
    fresh_store = MemoryStateStore()
    yield fresh_store
|
|
|
|
|
|
@pytest.fixture(params=[async_postgres_store, async_sqlite_store, memory_store])
async def store(request) -> AsyncIterator[StateStore]:
    """Parametrized fixture that runs each test once per state store backend.

    ``request.param`` is one of the async context manager factories listed in ``params``.
    The store it yields is handed to the test, and the context is exited once the test ends.
    """
    make_store: Callable[[], AsyncContextManager[StateStore]] = request.param
    store_context = make_store()
    async with store_context as backend:
        yield backend
|
|
|
|
|
|
def read_state_file(request, file) -> dict[RoomID, list[StateEvent]]:
    """Load a JSON fixture of room state events stored next to the current test module.

    :param request: The pytest ``request`` fixture; ``request.node.fspath`` locates the
        test module, and the fixture file is looked up in the same directory.
    :param file: File name of the JSON fixture. The top-level object maps room IDs to lists
        of serialized state events.
    :return: A dict mapping each room ID to its deserialized :class:`StateEvent` objects.
        Each event's ``room_id`` is filled in from the dict key, overriding any value in
        the serialized event.
    """
    path = pathlib.Path(request.node.fspath).with_name(file)
    # JSON is UTF-8 by spec; don't depend on the platform's locale-default encoding.
    with path.open(encoding="utf-8") as fp:
        content = json.load(fp)
    return {
        room_id: [StateEvent.deserialize({**evt, "room_id": room_id}) for evt in events]
        for room_id, events in content.items()
    }
|
|
|
|
|
|
async def store_room_state(request, store: StateStore) -> None:
    """Feed every state event from ``new_state.json`` into ``store.update_state``, in order."""
    changes_by_room = read_state_file(request, "new_state.json")
    all_events = (evt for room_events in changes_by_room.values() for evt in room_events)
    for state_event in all_events:
        await store.update_state(state_event)
|
|
|
|
|
|
async def get_all_members(request, store: StateStore) -> None:
    """Pass each room's members from ``members.json`` to ``store.set_members``.

    Each room's member events are reduced to a ``{state_key: content}`` mapping before
    being handed to the store.
    """
    member_state = read_state_file(request, "members.json")
    for room, events in member_state.items():
        members = {}
        for event in events:
            members[event.state_key] = event.content
        await store.set_members(room, members)
|
|
|
|
|
|
async def get_joined_members(request, store: StateStore) -> None:
    """Load ``joined_members.json`` (next to the test module) into *store* as joined members.

    The fixture maps room IDs to ``{user_id: {"display_name": ..., "avatar_url": ...}}``
    objects; a missing profile field becomes an empty string. Every member is stored with
    ``Membership.JOIN`` and ``set_members`` is called with ``only_membership=Membership.JOIN``.
    """
    path = pathlib.Path(request.node.fspath).with_name("joined_members.json")
    # JSON is UTF-8 by spec; don't depend on the platform's locale-default encoding
    # (display names may contain non-ASCII characters).
    with path.open(encoding="utf-8") as fp:
        content = json.load(fp)
    for room_id, members in content.items():
        parsed_members = {
            user_id: Member(
                membership=Membership.JOIN,
                displayname=member.get("display_name", ""),
                avatar_url=member.get("avatar_url", ""),
            )
            for user_id, member in members.items()
        }
        await store.set_members(room_id, parsed_members, only_membership=Membership.JOIN)
|
|
|
|
|
|
async def test_basic(store: StateStore) -> None:
    """Check encryption/join queries on a fresh store and the answers for an unseen room."""
    room_id = RoomID("!foo:example.com")
    user_id = UserID("@tulir:example.com")
    unseen_room = RoomID("!unknown-room:example.com")

    # Nothing has been stored about the room yet.
    assert not await store.is_encrypted(room_id)
    assert not await store.is_joined(room_id, user_id)

    # Recording the join must be visible immediately.
    await store.joined(room_id, user_id)
    assert await store.is_joined(room_id, user_id)

    # For an unseen room nothing is cached, and is_encrypted returns None rather than False.
    assert not await store.has_encryption_info_cached(unseen_room)
    assert await store.is_encrypted(unseen_room) is None
|
|
|
|
|
|
async def test_basic_updated(request, store: StateStore) -> None:
    """Check encryption state after applying the ``new_state.json`` fixture."""
    await store_room_state(request, store)
    test_group = RoomID("!telegram-group:example.com")
    assert await store.is_encrypted(test_group)
    encryption_info = await store.get_encryption_info(test_group)
    # Fail with a clear assertion instead of an AttributeError when the info is missing.
    assert encryption_info is not None
    assert encryption_info.algorithm == EncryptionAlgorithm.MEGOLM_V1
    assert not await store.is_encrypted(RoomID("!unencrypted-room:example.com"))
|
|
|
|
|
|
async def test_updates(request, store: StateStore) -> None:
    """Check member listings as state events, full member lists and joined lists are loaded."""
    await store_room_state(request, store)
    room_id = RoomID("!telegram-group:example.com")

    async def member_ids(*memberships: Membership) -> set[str]:
        # With no arguments, rely on get_members' own default membership filter.
        if not memberships:
            return set(await store.get_members(room_id))
        return set(await store.get_members(room_id, memberships=memberships))

    initial_members = {"@tulir:example.com", "@telegram_84359547:example.com"}
    joined_members = {
        *initial_members,
        "@telegrambot:example.com",
        "@telegram_5647382910:example.com",
        "@telegram_374880943:example.com",
        "@telegram_987654321:example.com",
        "@telegram_123456789:example.com",
    }
    left_members = {"@telegram_476034259:example.com", "@whatsappbot:example.com"}
    full_members = {*joined_members, *left_members}
    any_membership = (
        Membership.JOIN,
        Membership.INVITE,
        Membership.LEAVE,
        Membership.BAN,
        Membership.KNOCK,
    )
    leave_memberships = (Membership.BAN, Membership.LEAVE)

    # Only the state events from new_state.json have been stored at this point.
    assert await member_ids() == initial_members

    # After loading the full member list, non-joined members are tracked as well.
    await get_all_members(request, store)
    assert await member_ids() == joined_members
    assert await member_ids(*any_membership) == full_members

    # Loading the joined-only list must keep the members who left or were banned.
    await get_joined_members(request, store)
    assert await member_ids() == joined_members
    assert await member_ids(*any_membership) == full_members
    assert await member_ids(*leave_memberships) == left_members

    # With the "@telegram_" prefix / ":example.com" suffix exclusion applied, only the
    # WhatsApp bot remains among the left members.
    filtered = await store.get_members_filtered(
        room_id,
        memberships=leave_memberships,
        not_id="",
        not_prefix="@telegram_",
        not_suffix=":example.com",
    )
    assert set(filtered) == {"@whatsappbot:example.com"}
|