# Source: mautrix-python/mautrix/client/state_store/asyncpg/store.py
# (upstream listing: 258 lines, 10 KiB, Python)

# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations
from typing import Any, NamedTuple
import json
from mautrix.types import (
Member,
Membership,
MemberStateEventContent,
PowerLevelStateEventContent,
RoomEncryptionStateEventContent,
RoomID,
Serializable,
UserID,
)
from mautrix.util.async_db import Database, Scheme
from ..abstract import StateStore
from .upgrade import upgrade_table
class RoomState(NamedTuple):
    """Immutable record of the per-room state fields handled by this module.

    The field names match the ``mx_room_state`` columns that the queries in
    this file read and write.
    """

    # Whether the room is flagged as encrypted.
    is_encrypted: bool
    # Whether the complete member list of the room has been stored.
    has_full_member_list: bool
    # Content of the room's encryption state event.
    encryption: RoomEncryptionStateEventContent
    # Content of the room's power level state event.
    power_levels: PowerLevelStateEventContent
class PgStateStore(StateStore):
    """State store that persists room and membership state through a ``Database``.

    Member profiles live in the ``mx_user_profile`` table and per-room state
    (encryption, power levels, member-list completeness) in ``mx_room_state``.
    Queries that need a list parameter use ``ANY($n)`` on Postgres/CockroachDB
    and an expanded ``IN (?,...)`` placeholder list on every other scheme.
    """

    upgrade_table = upgrade_table
    db: Database

    def __init__(self, db: Database) -> None:
        """Keep the database handle that every query in this class goes through."""
        self.db = db

    async def get_member(self, room_id: RoomID, user_id: UserID) -> Member | None:
        """Return the stored profile of ``user_id`` in ``room_id``.

        Returns:
            A ``Member`` built from the stored membership, displayname and
            avatar URL, or ``None`` when no row exists for the pair.
        """
        res = await self.db.fetchrow(
            "SELECT membership, displayname, avatar_url "
            "FROM mx_user_profile WHERE room_id=$1 AND user_id=$2",
            room_id,
            user_id,
        )
        if res is None:
            return None
        return Member(
            membership=Membership.deserialize(res["membership"]),
            displayname=res["displayname"],
            avatar_url=res["avatar_url"],
        )

    async def set_member(
        self, room_id: RoomID, user_id: UserID, member: Member | MemberStateEventContent
    ) -> None:
        """Insert or overwrite the full profile of ``user_id`` in ``room_id``.

        Membership, displayname and avatar URL are all replaced on conflict.
        """
        q = (
            "INSERT INTO mx_user_profile (room_id, user_id, membership, displayname, avatar_url) "
            "VALUES ($1, $2, $3, $4, $5)"
            "ON CONFLICT (room_id, user_id) DO UPDATE SET membership=$3, displayname=$4,"
            " avatar_url=$5"
        )
        await self.db.execute(
            q, room_id, user_id, member.membership.value, member.displayname, member.avatar_url
        )

    async def set_membership(
        self, room_id: RoomID, user_id: UserID, membership: Membership
    ) -> None:
        """Insert or update only the membership of ``user_id`` in ``room_id``.

        On conflict just the ``membership`` column is updated; the other
        profile columns of an existing row are left as they are.
        """
        q = (
            "INSERT INTO mx_user_profile (room_id, user_id, membership) VALUES ($1, $2, $3) "
            "ON CONFLICT (room_id, user_id) DO UPDATE SET membership=$3"
        )
        await self.db.execute(q, room_id, user_id, membership.value)

    async def get_members(
        self,
        room_id: RoomID,
        memberships: tuple[Membership, ...] = (Membership.JOIN, Membership.INVITE),
    ) -> list[UserID]:
        """Return the user IDs in ``room_id`` whose membership is in ``memberships``."""
        membership_values = [membership.value for membership in memberships]
        if self.db.scheme in (Scheme.POSTGRES, Scheme.COCKROACH):
            q = "SELECT user_id FROM mx_user_profile WHERE room_id=$1 AND membership=ANY($2)"
            res = await self.db.fetch(q, room_id, membership_values)
        else:
            # No array parameter here: emit one "?" placeholder per membership value.
            membership_placeholders = ",".join("?" * len(membership_values))
            q = (
                "SELECT user_id FROM mx_user_profile "
                f"WHERE room_id=? AND membership IN ({membership_placeholders})"
            )
            res = await self.db.fetch(q, room_id, *membership_values)
        return [profile["user_id"] for profile in res]

    async def get_member_profiles(
        self,
        room_id: RoomID,
        memberships: tuple[Membership, ...] = (Membership.JOIN, Membership.INVITE),
    ) -> dict[UserID, Member]:
        """Return the profiles in ``room_id`` whose membership is in ``memberships``.

        Returns:
            A mapping from user ID to the ``Member`` deserialized from that row.
        """
        membership_values = [membership.value for membership in memberships]
        if self.db.scheme in (Scheme.POSTGRES, Scheme.COCKROACH):
            q = (
                "SELECT user_id, membership, displayname, avatar_url FROM mx_user_profile "
                "WHERE room_id=$1 AND membership=ANY($2)"
            )
            res = await self.db.fetch(q, room_id, membership_values)
        else:
            # No array parameter here: emit one "?" placeholder per membership value.
            membership_placeholders = ",".join("?" * len(membership_values))
            q = (
                "SELECT user_id, membership, displayname, avatar_url FROM mx_user_profile "
                f"WHERE room_id=? AND membership IN ({membership_placeholders})"
            )
            res = await self.db.fetch(q, room_id, *membership_values)
        return {profile["user_id"]: Member.deserialize(profile) for profile in res}

    async def get_members_filtered(
        self,
        room_id: RoomID,
        not_prefix: str,
        not_suffix: str,
        not_id: str,
        memberships: tuple[Membership, ...] = (Membership.JOIN, Membership.INVITE),
    ) -> list[UserID]:
        """Return member user IDs of ``room_id``, excluding one ID and one ID pattern.

        Args:
            room_id: The room whose members are listed.
            not_prefix: Start of the user IDs to exclude.
            not_suffix: End of the user IDs to exclude.
            not_id: A single user ID to exclude.
            memberships: Only rows with one of these memberships are returned.
        """
        # NOTE(review): "%" and "_" inside not_prefix/not_suffix are interpolated
        # unescaped, so they act as LIKE wildcards rather than literal characters.
        not_like = f"{not_prefix}%{not_suffix}"
        membership_values = [membership.value for membership in memberships]
        if self.db.scheme in (Scheme.POSTGRES, Scheme.COCKROACH):
            q = (
                "SELECT user_id FROM mx_user_profile "
                "WHERE room_id=$1 AND membership=ANY($2)"
                "AND user_id != $3 AND user_id NOT LIKE $4"
            )
            res = await self.db.fetch(q, room_id, membership_values, not_id, not_like)
        else:
            # No array parameter here: emit one "?" placeholder per membership value.
            membership_placeholders = ",".join("?" * len(membership_values))
            q = (
                "SELECT user_id FROM mx_user_profile "
                f"WHERE room_id=? AND membership IN ({membership_placeholders})"
                "AND user_id != ? AND user_id NOT LIKE ?"
            )
            res = await self.db.fetch(q, room_id, *membership_values, not_id, not_like)
        return [profile["user_id"] for profile in res]

    async def set_members(
        self,
        room_id: RoomID,
        members: dict[UserID, Member | MemberStateEventContent],
        only_membership: Membership | None = None,
    ) -> None:
        """Replace stored member profiles of ``room_id`` with ``members``.

        Everything runs in one transaction: matching old rows are deleted, the
        new rows are inserted, and the room may be flagged as having a full
        member list.

        Args:
            room_id: The room whose member rows are replaced.
            members: Mapping from user ID to the profile to store.
            only_membership: When ``None``, every existing row of the room is
                deleted first. Otherwise only rows with this membership, plus
                rows for the user IDs in ``members``, are deleted.
        """
        columns = ["room_id", "user_id", "membership", "displayname", "avatar_url"]
        records = [
            (room_id, user_id, str(member.membership), member.displayname, member.avatar_url)
            for user_id, member in members.items()
        ]
        async with self.db.acquire() as conn, conn.transaction():
            del_q = "DELETE FROM mx_user_profile WHERE room_id=$1"
            if only_membership is None:
                await conn.execute(del_q, room_id)
            elif self.db.scheme in (Scheme.POSTGRES, Scheme.COCKROACH):
                del_q = f"{del_q} AND (membership=$2 OR user_id = ANY($3))"
                await conn.execute(del_q, room_id, only_membership.value, list(members))
            else:
                # No array parameter here: emit one "?" placeholder per user ID.
                member_placeholders = ",".join("?" * len(members))
                del_q = f"{del_q} AND (membership=? OR user_id IN ({member_placeholders}))"
                await conn.execute(del_q, room_id, only_membership.value, *members)

            # Bulk copy is only used on Postgres; every other scheme (including
            # CockroachDB) inserts row by row through executemany.
            if self.db.scheme == Scheme.POSTGRES:
                await conn.copy_records_to_table(
                    "mx_user_profile", records=records, columns=columns
                )
            else:
                q = (
                    "INSERT INTO mx_user_profile (room_id, user_id, membership, "
                    "displayname, avatar_url) VALUES ($1, $2, $3, $4, $5)"
                )
                await conn.executemany(q, records)

            # The flag is set after an unfiltered replace or a replace of the
            # joined members. This is an UPDATE, so it only touches a room that
            # already has a row in mx_room_state.
            if not only_membership or only_membership == Membership.JOIN:
                await conn.execute(
                    "UPDATE mx_room_state SET has_full_member_list=true WHERE room_id=$1",
                    room_id,
                )

    async def find_shared_rooms(self, user_id: UserID) -> list[RoomID]:
        """Return the encrypted rooms that have a profile row for ``user_id``.

        The ``is_encrypted=true`` condition drops rooms without a matching
        ``mx_room_state`` row, so the LEFT JOIN behaves like an inner join.
        The query does not filter on the user's membership value.
        """
        q = (
            "SELECT mx_user_profile.room_id FROM mx_user_profile "
            "LEFT JOIN mx_room_state ON mx_room_state.room_id=mx_user_profile.room_id "
            "WHERE user_id=$1 AND mx_room_state.is_encrypted=true"
        )
        rows = await self.db.fetch(q, user_id)
        return [row["room_id"] for row in rows]

    async def has_full_member_list(self, room_id: RoomID) -> bool:
        """Return whether ``room_id`` is flagged as having its full member list stored.

        A falsy query result (such as a NULL column) is reported as ``False``.
        """
        return bool(
            await self.db.fetchval(
                "SELECT has_full_member_list FROM mx_room_state WHERE room_id=$1", room_id
            )
        )

    async def has_power_levels_cached(self, room_id: RoomID) -> bool:
        """Return whether power level content is stored for ``room_id``."""
        return bool(
            await self.db.fetchval(
                "SELECT power_levels IS NOT NULL FROM mx_room_state WHERE room_id=$1", room_id
            )
        )

    async def get_power_levels(self, room_id: RoomID) -> PowerLevelStateEventContent | None:
        """Return the stored power levels of ``room_id``, or ``None`` if nothing is stored."""
        power_levels_json = await self.db.fetchval(
            "SELECT power_levels FROM mx_room_state WHERE room_id=$1", room_id
        )
        if power_levels_json is None:
            return None
        return PowerLevelStateEventContent.parse_json(power_levels_json)

    async def set_power_levels(
        self, room_id: RoomID, content: PowerLevelStateEventContent | dict[str, Any]
    ) -> None:
        """Insert or overwrite the power levels of ``room_id``.

        ``content`` is stored as a JSON string; ``Serializable`` objects are
        serialized first, plain dicts are dumped as they are.
        """
        await self.db.execute(
            "INSERT INTO mx_room_state (room_id, power_levels) VALUES ($1, $2) "
            "ON CONFLICT (room_id) DO UPDATE SET power_levels=$2",
            room_id,
            json.dumps(content.serialize() if isinstance(content, Serializable) else content),
        )

    async def has_encryption_info_cached(self, room_id: RoomID) -> bool:
        """Return whether encryption event content is stored for ``room_id``."""
        # The test must be IS NOT NULL (same shape as has_power_levels_cached):
        # with IS NULL the result is true exactly when nothing is stored.
        return bool(
            await self.db.fetchval(
                "SELECT encryption IS NOT NULL FROM mx_room_state WHERE room_id=$1", room_id
            )
        )

    async def is_encrypted(self, room_id: RoomID) -> bool | None:
        """Return the stored ``is_encrypted`` value of ``room_id`` as fetched."""
        return await self.db.fetchval(
            "SELECT is_encrypted FROM mx_room_state WHERE room_id=$1", room_id
        )

    async def get_encryption_info(self, room_id: RoomID) -> RoomEncryptionStateEventContent | None:
        """Return the stored encryption event content of ``room_id``.

        Returns:
            The parsed content, or ``None`` when the room has no row, is not
            flagged as encrypted, or has no encryption content stored.
        """
        row = await self.db.fetchrow(
            "SELECT is_encrypted, encryption FROM mx_room_state WHERE room_id=$1", room_id
        )
        if row is None or not row["is_encrypted"]:
            return None
        encryption_json = row["encryption"]
        if encryption_json is None:
            # Flagged as encrypted but no content stored: there is nothing to parse.
            return None
        return RoomEncryptionStateEventContent.parse_json(encryption_json)

    async def set_encryption_info(
        self, room_id: RoomID, content: RoomEncryptionStateEventContent | dict[str, Any]
    ) -> None:
        """Mark ``room_id`` as encrypted and insert or overwrite its encryption content.

        ``content`` is stored as a JSON string; ``Serializable`` objects are
        serialized first, plain dicts are dumped as they are.
        """
        q = (
            "INSERT INTO mx_room_state (room_id, is_encrypted, encryption) VALUES ($1, true, $2) "
            "ON CONFLICT (room_id) DO UPDATE SET is_encrypted=true, encryption=$2"
        )
        await self.db.execute(
            q,
            room_id,
            json.dumps(content.serialize() if isinstance(content, Serializable) else content),
        )