54 lines
1.8 KiB
Python
54 lines
1.8 KiB
Python
# Copyright (c) 2022 Tulir Asokan
|
|
#
|
|
# This Source Code Form is subject to the terms of the Mozilla Public
|
|
# License, v. 2.0. If a copy of the MPL was not distributed with this
|
|
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
|
from __future__ import annotations

from typing import Awaitable, Callable, Optional
from abc import ABC

from mautrix import __optional_imports__
from mautrix.bridge.portal import BasePortal
from mautrix.crypto import StateStore
from mautrix.types import RoomEncryptionStateEventContent, RoomID, UserID
from mautrix.util.async_db import Database
|
|
|
|
# Async callback that resolves a Matrix room ID to the bridge portal for that
# room. The result may be None (BaseCryptoStateStore.is_encrypted treats a
# falsy result as "no portal"), so the awaited type is Optional. Optional is
# used instead of `BasePortal | None` because this alias is evaluated at
# runtime, not as a postponed annotation.
GetPortalFunc = Callable[[RoomID], Awaitable[Optional[BasePortal]]]
|
|
|
|
|
|
class BaseCryptoStateStore(StateStore, ABC):
    """Abstract crypto ``StateStore`` that consults bridge portals.

    Room encryption status is answered by resolving the room's portal through
    the ``get_portal`` callback given at construction time. The class is
    abstract; the remaining ``StateStore`` queries are left to subclasses.
    """

    # Async callback mapping a room ID to its portal (or a falsy value).
    get_portal: GetPortalFunc

    def __init__(self, get_portal: GetPortalFunc):
        self.get_portal = get_portal

    async def is_encrypted(self, room_id: RoomID) -> bool:
        """Report whether the portal for ``room_id`` is flagged as encrypted.

        If the callback yields no portal (any falsy value), the room is
        reported as not encrypted.
        """
        portal_for_room = await self.get_portal(room_id)
        if not portal_for_room:
            return False
        return portal_for_room.encrypted
|
|
|
|
|
|
class PgCryptoStateStore(BaseCryptoStateStore):
    """Crypto state store backed by the bridge's SQL database.

    Per-room encryption state is read from ``mx_room_state``; shared-room
    lookups join ``mx_user_profile`` against the bridge's ``portal`` table.
    """

    # Database handle used for every query issued by this store.
    db: Database

    def __init__(self, db: Database, get_portal: GetPortalFunc) -> None:
        super().__init__(get_portal)
        self.db = db

    async def find_shared_rooms(self, user_id: UserID) -> list[RoomID]:
        """Return encrypted portal rooms in which ``user_id`` has a profile row.

        Only rooms whose portal row has ``encrypted=true`` are returned.
        """
        rows = await self.db.fetch(
            # Every column is table-qualified: the ``portal`` schema belongs to
            # each individual bridge, and an unqualified ``room_id``/``user_id``
            # would become ambiguous if a portal table had a same-named column.
            # The explicit alias keeps the result key ``room_id`` regardless of
            # how the driver names qualified columns. INNER JOIN states the
            # intent directly; the ``encrypted`` filter already dropped rows
            # without a matching portal.
            "SELECT mx_user_profile.room_id AS room_id FROM mx_user_profile "
            "INNER JOIN portal ON portal.mxid=mx_user_profile.room_id "
            "WHERE mx_user_profile.user_id=$1 AND portal.encrypted=true",
            user_id,
        )
        return [row["room_id"] for row in rows]

    async def get_encryption_info(self, room_id: RoomID) -> RoomEncryptionStateEventContent | None:
        """Return the stored encryption state event content for ``room_id``.

        Returns ``None`` when the stored ``encryption`` value is falsy (for
        example missing or NULL); otherwise the value is parsed from JSON.
        """
        val = await self.db.fetchval(
            "SELECT encryption FROM mx_room_state WHERE room_id=$1", room_id
        )
        if not val:
            return None
        return RoomEncryptionStateEventContent.parse_json(val)
|