# Source: mautrix-python/mautrix/client/store_updater.py (Python, 299 lines, 11 KiB)

# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations
import asyncio
from mautrix.errors import MForbidden, MNotFound
from mautrix.types import (
JSON,
EventID,
EventType,
Member,
Membership,
MemberStateEventContent,
RoomAlias,
RoomID,
StateEvent,
StateEventContent,
SyncToken,
UserID,
)
from .api import ClientAPI
from .state_store import StateStore
class StoreUpdatingAPI(ClientAPI):
    """
    StoreUpdatingAPI is a wrapper around the medium-level ClientAPI that optionally updates
    a client state store with outgoing state events (after they're successfully sent).

    Every overridden method first delegates to the ClientAPI implementation and only touches
    the state store after that call has returned without raising. If :attr:`state_store` is
    ``None``, the store updates are skipped.
    """

    # The store that is kept up to date, or None to disable store updating entirely.
    state_store: StateStore | None

    def __init__(self, *args, state_store: StateStore | None = None, **kwargs) -> None:
        """
        Args:
            *args: Positional arguments passed through to the ClientAPI constructor.
            state_store: The state store to update, or ``None`` to not update any store.
            **kwargs: Keyword arguments passed through to the ClientAPI constructor.
        """
        super().__init__(*args, **kwargs)
        self.state_store = state_store

    async def join_room_by_id(
        self,
        room_id: RoomID,
        third_party_signed: JSON = None,
        extra_content: dict[str, JSON] | None = None,
    ) -> RoomID:
        """
        Join a room by ID and mark this client's user as joined in the state store.

        The store is only updated here when ``extra_content`` is empty.

        Returns:
            The room ID returned by the ClientAPI implementation.
        """
        room_id = await super().join_room_by_id(
            room_id, third_party_signed=third_party_signed, extra_content=extra_content
        )
        # NOTE(review): with extra_content the parent presumably sends the member event via
        # send_state_event, which already updates the store -- confirm against ClientAPI.
        if room_id and not extra_content and self.state_store:
            await self.state_store.set_membership(room_id, self.mxid, Membership.JOIN)
        return room_id

    async def join_room(
        self,
        room_id_or_alias: RoomID | RoomAlias,
        servers: list[str] | None = None,
        third_party_signed: JSON = None,
        max_retries: int = 4,
    ) -> RoomID:
        """
        Join a room by ID or alias and mark this client's user as joined in the state store.

        Returns:
            The room ID returned by the ClientAPI implementation.
        """
        room_id = await super().join_room(
            room_id_or_alias, servers, third_party_signed, max_retries
        )
        if room_id and self.state_store:
            await self.state_store.set_membership(room_id, self.mxid, Membership.JOIN)
        return room_id

    async def leave_room(
        self,
        room_id: RoomID,
        reason: str | None = None,
        extra_content: dict[str, JSON] | None = None,
        raise_not_in_room: bool = False,
    ) -> None:
        """
        Leave a room and mark this client's user as having left in the state store.

        The store is only updated here when ``extra_content`` is empty.
        """
        await super().leave_room(room_id, reason, extra_content, raise_not_in_room)
        if not extra_content and self.state_store:
            await self.state_store.set_membership(room_id, self.mxid, Membership.LEAVE)

    async def knock_room(
        self,
        room_id_or_alias: RoomID | RoomAlias,
        reason: str | None = None,
        servers: list[str] | None = None,
    ) -> RoomID:
        """
        Knock on a room and record the knock membership of this client's user in the store.

        Returns:
            The room ID returned by the ClientAPI implementation.
        """
        room_id = await super().knock_room(room_id_or_alias, reason, servers)
        if room_id and self.state_store:
            await self.state_store.set_membership(room_id, self.mxid, Membership.KNOCK)
        return room_id

    async def invite_user(
        self,
        room_id: RoomID,
        user_id: UserID,
        reason: str | None = None,
        extra_content: dict[str, JSON] | None = None,
    ) -> None:
        """
        Invite a user to a room and mark them as invited in the state store.

        The store is only updated here when ``extra_content`` is empty.
        """
        await super().invite_user(room_id, user_id, reason, extra_content=extra_content)
        if not extra_content and self.state_store:
            await self.state_store.set_membership(room_id, user_id, Membership.INVITE)

    async def kick_user(
        self,
        room_id: RoomID,
        user_id: UserID,
        reason: str = "",
        extra_content: dict[str, JSON] | None = None,
    ) -> None:
        """
        Kick a user from a room and mark them as having left in the state store.

        The store is only updated here when ``extra_content`` is empty.
        """
        await super().kick_user(room_id, user_id, reason=reason, extra_content=extra_content)
        if not extra_content and self.state_store:
            await self.state_store.set_membership(room_id, user_id, Membership.LEAVE)

    async def ban_user(
        self,
        room_id: RoomID,
        user_id: UserID,
        reason: str = "",
        extra_content: dict[str, JSON] | None = None,
    ) -> None:
        """
        Ban a user from a room and mark them as banned in the state store.

        The store is only updated here when ``extra_content`` is empty.
        """
        await super().ban_user(room_id, user_id, reason=reason, extra_content=extra_content)
        if not extra_content and self.state_store:
            await self.state_store.set_membership(room_id, user_id, Membership.BAN)

    async def unban_user(
        self,
        room_id: RoomID,
        user_id: UserID,
        reason: str = "",
        extra_content: dict[str, JSON] | None = None,
    ) -> None:
        """
        Unban a user and set their membership to leave in the state store.
        """
        await super().unban_user(room_id, user_id, reason=reason, extra_content=extra_content)
        # NOTE(review): unlike kick_user/ban_user, this updates the store even when
        # extra_content is set -- confirm whether that difference is intentional.
        if self.state_store:
            await self.state_store.set_membership(room_id, user_id, Membership.LEAVE)

    async def get_state(self, room_id: RoomID) -> list[StateEvent]:
        """
        Fetch the state of a room and write all of it into the state store.

        Member events are passed to the store in one ``set_members`` call (keyed by state key),
        all other events go through ``update_state`` individually.

        Returns:
            The state events returned by the ClientAPI implementation.
        """
        state = await super().get_state(room_id)
        if self.state_store:
            update_members = self.state_store.set_members(
                room_id,
                {evt.state_key: evt.content for evt in state if evt.type == EventType.ROOM_MEMBER},
            )
            # The member update and the per-event updates are awaited together.
            await asyncio.gather(
                update_members,
                *[
                    self.state_store.update_state(evt)
                    for evt in state
                    if evt.type != EventType.ROOM_MEMBER
                ],
            )
        return state

    async def create_room(self, *args, **kwargs) -> RoomID:
        """
        Create a room and seed the state store with the invitees and initial state.

        Only the ``invitees``, ``initial_state`` and ``beeper_auto_join_invites`` keyword
        arguments are inspected for the store update; values passed positionally are forwarded
        to the ClientAPI implementation but are not seen here.

        Returns:
            The ID of the created room.
        """
        room_id = await super().create_room(*args, **kwargs)
        if self.state_store:
            invitee_membership = Membership.INVITE
            if kwargs.get("beeper_auto_join_invites"):
                invitee_membership = Membership.JOIN
            # `or []` also covers an explicit invitees=None / initial_state=None, which would
            # otherwise raise TypeError after the room has already been created.
            for user_id in kwargs.get("invitees") or []:
                await self.state_store.set_membership(room_id, user_id, invitee_membership)
            for evt in kwargs.get("initial_state") or []:
                await self.state_store.update_state(
                    StateEvent(
                        type=EventType.find(evt["type"], t_class=EventType.Class.STATE),
                        room_id=room_id,
                        # The real event ID isn't known here, so a placeholder is stored.
                        event_id=EventID("$fake-create-id"),
                        sender=self.mxid,
                        state_key=evt.get("state_key", ""),
                        timestamp=0,
                        content=evt["content"],
                    )
                )
        return room_id

    async def send_state_event(
        self,
        room_id: RoomID,
        event_type: EventType,
        content: StateEventContent | dict[str, JSON],
        state_key: str = "",
        **kwargs,
    ) -> EventID:
        """
        Send a state event and store a locally constructed copy of it in the state store.

        The stored copy uses the returned event ID, this client's user ID as the sender and
        a timestamp of 0.

        Returns:
            The event ID returned by the ClientAPI implementation.
        """
        event_id = await super().send_state_event(
            room_id, event_type, content, state_key, **kwargs
        )
        if self.state_store:
            fake_event = StateEvent(
                type=event_type,
                room_id=room_id,
                event_id=event_id,
                sender=self.mxid,
                state_key=state_key,
                timestamp=0,
                content=content,
            )
            await self.state_store.update_state(fake_event)
        return event_id

    async def get_state_event(
        self, room_id: RoomID, event_type: EventType, state_key: str = ""
    ) -> StateEventContent:
        """
        Fetch the content of a single state event and store it in the state store.

        Only the content is returned by the ClientAPI implementation, so the event written to
        the store has an empty event ID, an empty sender and a timestamp of 0.

        Returns:
            The state event content returned by the ClientAPI implementation.
        """
        event = await super().get_state_event(room_id, event_type, state_key)
        if self.state_store:
            fake_event = StateEvent(
                type=event_type,
                room_id=room_id,
                event_id=EventID(""),
                sender=UserID(""),
                state_key=state_key,
                timestamp=0,
                content=event,
            )
            await self.state_store.update_state(fake_event)
        return event

    async def get_joined_members(self, room_id: RoomID) -> dict[UserID, Member]:
        """
        Fetch the joined members of a room and store them as the room's joined members.

        Returns:
            The members returned by the ClientAPI implementation.
        """
        members = await super().get_joined_members(room_id)
        if self.state_store:
            await self.state_store.set_members(room_id, members, only_membership=Membership.JOIN)
        return members

    async def get_members(
        self,
        room_id: RoomID,
        at: SyncToken | None = None,
        membership: Membership | None = None,
        not_membership: Membership | None = None,
    ) -> list[StateEvent]:
        """
        Fetch member events of a room and store them in the state store.

        The store is not updated when ``not_membership`` is :attr:`Membership.JOIN`. Otherwise
        the fetched members are stored with ``only_membership`` set to the ``membership``
        filter that was requested.

        Returns:
            The member events returned by the ClientAPI implementation.
        """
        member_events = await super().get_members(room_id, at, membership, not_membership)
        if self.state_store and not_membership != Membership.JOIN:
            await self.state_store.set_members(
                room_id,
                {evt.state_key: evt.content for evt in member_events},
                only_membership=membership,
            )
        return member_events

    async def fill_member_event(
        self,
        room_id: RoomID,
        user_id: UserID,
        content: MemberStateEventContent,
    ) -> MemberStateEventContent | None:
        """
        Fill a membership event content that is going to be sent in :meth:`send_member_event`.

        This is used to set default fields like the displayname and avatar, which are usually set
        by the server in the sugar membership endpoints like /join and /invite, but are not set
        automatically when sending member events manually.

        This implementation in StoreUpdatingAPI will first try to call the default implementation
        (which calls :attr:`fill_member_event_callback`). If that doesn't return anything, this
        will try to get the profile from the current member event, and then fall back to fetching
        the global profile from the server. If no state store is configured, the member event
        lookup is skipped and the profile is fetched directly.

        Args:
            room_id: The room where the member event is going to be sent.
            user_id: The user whose membership is changing.
            content: The new member event content.

        Returns:
            The filled member event content, or ``None`` if nothing was filled.
        """
        callback_content = await super().fill_member_event(room_id, user_id, content)
        if callback_content is not None:
            self.log.trace("Filled new member event for %s using callback", user_id)
            return callback_content

        if content.displayname is None and content.avatar_url is None:
            # The state store is optional: without one there is no stored member event to
            # copy the profile from, so go straight to the profile fetch below.
            existing_member = None
            if self.state_store:
                existing_member = await self.state_store.get_member(room_id, user_id)
            if existing_member is not None:
                self.log.trace(
                    "Found existing member event %s to fill new member event for %s",
                    existing_member,
                    user_id,
                )
                content.displayname = existing_member.displayname
                content.avatar_url = existing_member.avatar_url
                return content

            try:
                profile = await self.get_profile(user_id)
            except (MNotFound, MForbidden):
                # A missing or inaccessible profile just means there's nothing to fill in.
                profile = None
            if profile:
                self.log.trace(
                    "Fetched profile %s to fill new member event of %s", profile, user_id
                )
                content.displayname = profile.displayname
                content.avatar_url = profile.avatar_url
                return content
            else:
                self.log.trace("Didn't find profile info to fill new member event of %s", user_id)
        else:
            self.log.trace(
                "Member event for %s already contains displayname or avatar, not re-filling",
                user_id,
            )
        return None