267 lines
10 KiB
Python
267 lines
10 KiB
Python
# mautrix-telegram - A Matrix-Telegram puppeting bridge
|
|
# Copyright (C) 2021 Tulir Asokan
|
|
#
|
|
# This program is free software: you can redistribute it and/or modify
|
|
# it under the terms of the GNU Affero General Public License as published by
|
|
# the Free Software Foundation, either version 3 of the License, or
|
|
# (at your option) any later version.
|
|
#
|
|
# This program is distributed in the hope that it will be useful,
|
|
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
# GNU Affero General Public License for more details.
|
|
#
|
|
# You should have received a copy of the GNU Affero General Public License
|
|
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
|
from __future__ import annotations
|
|
|
|
from typing import TYPE_CHECKING, ClassVar, Iterable
|
|
import asyncio
|
|
import datetime
|
|
|
|
from telethon import utils
|
|
from telethon.crypto import AuthKey
|
|
from telethon.sessions import MemorySession
|
|
from telethon.tl.types import PeerChannel, PeerChat, PeerUser, updates
|
|
|
|
from mautrix.util.async_db import Database, Scheme
|
|
|
|
# Placeholder for PgSession.db: a real Database only for static type checkers;
# at runtime this is None and the actual database must be assigned later.
if TYPE_CHECKING:
    fake_db = Database.create("")
else:
    fake_db = None
|
|
|
|
|
|
class PgSession(MemorySession):
    """Telethon session whose persistent state lives in the bridge database.

    Connection/auth info is stored in ``telethon_sessions``; the entity cache and
    update states are stored in ``telethon_entities`` / ``telethon_update_state``.
    All rows are keyed by ``session_id``.

    ``db`` is a class-level handle. At runtime its default is ``None`` (see
    ``fake_db``), so it must be assigned before any query method is used.
    """

    db: ClassVar[Database] = fake_db

    session_id: str
    _dc_id: int
    _server_address: str | None
    _port: int | None
    _auth_key: AuthKey | None
    _takeout_id: int | None
    _process_entities_lock: asyncio.Lock

    def __init__(
        self,
        session_id: str,
        dc_id: int = 0,
        server_address: str | None = None,
        port: int | None = None,
        auth_key: AuthKey | None = None,
        takeout_id: int | None = None,
    ) -> None:
        super().__init__()
        self.session_id = session_id
        self._dc_id = dc_id
        self._server_address = server_address
        self._port = port
        self._auth_key = auth_key
        self._takeout_id = takeout_id
        # Serializes entity upserts for this session object (see process_entities).
        self._process_entities_lock = asyncio.Lock()

    def clone(self, to_instance=None) -> MemorySession:
        """Return an in-memory copy of this session.

        ``to_instance`` is deliberately ignored: clones are used for temporarily
        connecting to different DCs and their data should not be stored.
        """
        return super().clone(MemorySession())

    @property
    def auth_key_bytes(self) -> bytes | None:
        """Raw auth key bytes, or ``None`` when there is no (truthy) auth key."""
        return self._auth_key.key if self._auth_key else None

    @classmethod
    async def get(cls, session_id: str) -> PgSession:
        """Load the session ``session_id`` from the database.

        Returns a fresh, empty session object if no row exists yet.
        """
        q = (
            "SELECT session_id, dc_id, server_address, port, auth_key FROM telethon_sessions "
            "WHERE session_id=$1"
        )
        row = await cls.db.fetchrow(q, session_id)
        if row is None:
            return cls(session_id)
        data = {**row}
        auth_key = AuthKey(data.pop("auth_key", None))
        return cls(**data, auth_key=auth_key)

    @classmethod
    async def has(cls, session_id: str) -> bool:
        """Return whether a ``telethon_sessions`` row exists for ``session_id``."""
        q = "SELECT COUNT(*) FROM telethon_sessions WHERE session_id=$1"
        count = await cls.db.fetchval(q, session_id)
        return count > 0

    async def save(self) -> None:
        """Insert or update the DC, server address, port and auth key of this session."""
        q = (
            "INSERT INTO telethon_sessions (session_id, dc_id, server_address, port, auth_key) "
            "VALUES ($1, $2, $3, $4, $5) ON CONFLICT (session_id) "
            "DO UPDATE SET dc_id=$2, server_address=$3, port=$4, auth_key=$5"
        )
        await self.db.execute(
            q, self.session_id, self.dc_id, self.server_address, self.port, self.auth_key_bytes
        )

    # Every table that holds per-session rows; used by delete().
    _tables: ClassVar[tuple[str, ...]] = (
        "telethon_sessions",
        "telethon_entities",
        "telethon_sent_files",
        "telethon_update_state",
    )

    async def delete(self) -> None:
        """Delete all rows of this session from every table in ``_tables``, in one transaction."""
        async with self.db.acquire() as conn, conn.transaction():
            for table in self._tables:
                # ``table`` comes from the fixed class-level tuple, never from input,
                # so interpolating it into the SQL text is safe.
                await conn.execute(f"DELETE FROM {table} WHERE session_id=$1", self.session_id)

    async def close(self) -> None:
        # Nothing to do here, DB connection is global
        pass

    @staticmethod
    def _row_to_state(row) -> updates.State:
        """Build an ``updates.State`` from a ``telethon_update_state`` row.

        ``date`` is stored as a Unix timestamp and is turned into a timezone-aware
        UTC datetime. A naive datetime (what ``utcfromtimestamp`` returns) would be
        read as *local* time by ``datetime.timestamp()`` when the state is saved
        again, shifting it by the host's UTC offset on every round trip.
        """
        return updates.State(
            row["pts"],
            row["qts"],
            datetime.datetime.fromtimestamp(row["date"], tz=datetime.timezone.utc),
            row["seq"],
            row["unread_count"],
        )

    async def get_update_state(self, entity_id: int) -> updates.State | None:
        """Return the stored update state for ``entity_id``, or ``None`` if there is none."""
        q = (
            "SELECT pts, qts, date, seq, unread_count FROM telethon_update_state "
            "WHERE session_id=$1 AND entity_id=$2"
        )
        row = await self.db.fetchrow(q, self.session_id, entity_id)
        if row is None:
            return None
        return self._row_to_state(row)

    # Single-row upsert; also reused with executemany on non-Postgres databases.
    _set_update_state_q = """
        INSERT INTO telethon_update_state (session_id, entity_id, pts, qts, date, seq, unread_count)
        VALUES ($1, $2, $3, $4, $5, $6, $7)
        ON CONFLICT (session_id, entity_id) DO UPDATE SET
            pts=excluded.pts, qts=excluded.qts, date=excluded.date, seq=excluded.seq,
            unread_count=excluded.unread_count
    """

    async def set_update_state(self, entity_id: int, row: updates.State) -> None:
        """Insert or replace the update state for ``entity_id`` (date stored as a Unix timestamp)."""
        q = self._set_update_state_q
        ts = row.date.timestamp()
        await self.db.execute(
            q, self.session_id, entity_id, row.pts, row.qts, ts, row.seq, row.unread_count
        )

    async def set_update_states(self, rows: list[tuple[int, updates.State]]) -> None:
        """Insert or replace many ``(entity_id, state)`` pairs at once.

        On Postgres this is one ``unnest``-based statement; other databases fall
        back to ``executemany`` with the single-row upsert query. An empty input
        is a no-op.
        """
        values = [
            (
                self.session_id,
                entity_id,
                state.pts,
                state.qts,
                state.date.timestamp(),
                state.seq,
                state.unread_count,
            )
            for entity_id, state in rows
        ]
        if not values:
            # Unpacking zip(*[]) below would raise ValueError on the Postgres path.
            return
        if self.db.scheme == Scheme.POSTGRES:
            q = """
                INSERT INTO telethon_update_state (
                    session_id, entity_id, pts, qts, date, seq, unread_count
                )
                VALUES (
                    $1,
                    unnest($2::bigint[]), unnest($3::bigint[]), unnest($4::bigint[]),
                    unnest($5::bigint[]), unnest($6::bigint[]), unnest($7::integer[])
                )
                ON CONFLICT (session_id, entity_id) DO UPDATE SET
                    pts=excluded.pts, qts=excluded.qts, date=excluded.date, seq=excluded.seq,
                    unread_count=excluded.unread_count
            """
            # Transpose rows into one column per array parameter; the session_id
            # column is dropped because it is passed once as $1.
            _, entity_ids, ptses, qtses, timestamps, seqs, unread_counts = zip(*values)
            await self.db.execute(
                q, self.session_id, entity_ids, ptses, qtses, timestamps, seqs, unread_counts
            )
        else:
            await self.db.executemany(self._set_update_state_q, values)

    async def delete_update_state(self, entity_id: int) -> None:
        """Remove the stored update state for ``entity_id``, if any."""
        q = "DELETE FROM telethon_update_state WHERE session_id=$1 AND entity_id=$2"
        await self.db.execute(q, self.session_id, entity_id)

    async def get_update_states(self) -> Iterable[tuple[int, updates.State]]:
        """Return ``(entity_id, state)`` pairs for every stored update state of this session."""
        q = (
            "SELECT entity_id, pts, qts, date, seq, unread_count FROM telethon_update_state "
            "WHERE session_id=$1"
        )
        rows = await self.db.fetch(q, self.session_id)
        return ((row["entity_id"], self._row_to_state(row)) for row in rows)

    def _entity_values_to_row(
        self, id: int, hash: int, username: str | None, phone: str | int | None, name: str | None
    ) -> tuple[str, int, int, str | None, str | None, str | None]:
        """Build a ``telethon_entities`` row: prefixes ``session_id`` and stringifies ``phone``.

        A falsy ``phone`` is stored as ``None``.
        """
        return self.session_id, id, hash, username, str(phone) if phone else None, name

    async def process_entities(self, tlo) -> None:
        """Cache the entities contained in ``tlo`` in ``telethon_entities``."""
        # Postgres likes to deadlock on simultaneous upserts, so just lock the whole thing here
        # TODO: make sure postgres doesn't deadlock on upserts when session_id is different
        async with self._process_entities_lock:
            await self._locked_process_entities(tlo)

    async def _locked_process_entities(self, tlo) -> None:
        """Upsert the entity rows from ``tlo``; must be called with the entities lock held."""
        rows: list[tuple[str, int, int, str | None, str | None, str | None]] = (
            self._entities_to_rows(tlo)
        )
        if not rows:
            return
        # Collapse duplicate entity IDs, keeping the last row for each, which is what
        # sequential upserts would leave behind. A single Postgres
        # INSERT ... ON CONFLICT DO UPDATE may not touch the same row twice.
        rows = list({row[1]: row for row in rows}.values())
        if self.db.scheme == Scheme.POSTGRES:
            q = (
                "INSERT INTO telethon_entities (session_id, id, hash, username, phone, name) "
                "VALUES ($1, unnest($2::bigint[]), unnest($3::bigint[]), "
                "        unnest($4::text[]), unnest($5::text[]), unnest($6::text[])) "
                "ON CONFLICT (session_id, id) DO UPDATE"
                "  SET hash=excluded.hash, username=excluded.username,"
                "      phone=excluded.phone, name=excluded.name"
            )
            _, ids, hashes, usernames, phones, names = zip(*rows)
            await self.db.execute(q, self.session_id, ids, hashes, usernames, phones, names)
        else:
            q = (
                "INSERT INTO telethon_entities (session_id, id, hash, username, phone, name) "
                "VALUES ($1, $2, $3, $4, $5, $6) "
                "ON CONFLICT (session_id, id) DO UPDATE "
                "  SET hash=$3, username=$4, phone=$5, name=$6"
            )
            await self.db.executemany(q, rows)

    async def _select_entity(
        self, constraint: str, *args: str | int | tuple[int, ...]
    ) -> tuple[int, int] | None:
        """Return ``(id, hash)`` of the first entity of this session matching ``constraint``.

        ``constraint`` is trusted SQL from this class; its parameters start at ``$2``.
        """
        q = f"SELECT id, hash FROM telethon_entities WHERE session_id=$1 AND {constraint}"
        row = await self.db.fetchrow(q, self.session_id, *args)
        if row is None:
            return None
        return row["id"], row["hash"]

    async def get_entity_rows_by_phone(self, key: str | int) -> tuple[int, int] | None:
        """Look up a cached entity by phone number (compared as a string)."""
        return await self._select_entity("phone=$2", str(key))

    async def get_entity_rows_by_username(self, key: str) -> tuple[int, int] | None:
        """Look up a cached entity by username."""
        return await self._select_entity("username=$2", key)

    async def get_entity_rows_by_name(self, key: str) -> tuple[int, int] | None:
        """Look up a cached entity by name."""
        return await self._select_entity("name=$2", key)

    async def get_entity_rows_by_id(self, key: int, exact: bool = True) -> tuple[int, int] | None:
        """Look up a cached entity by ID.

        With ``exact=False``, ``key`` is treated as a bare ID and matched against
        its marked user, chat and channel peer IDs.
        """
        if exact:
            return await self._select_entity("id=$2", key)

        ids = (
            utils.get_peer_id(PeerUser(key)),
            utils.get_peer_id(PeerChat(key)),
            utils.get_peer_id(PeerChannel(key)),
        )
        if self.db.scheme in (Scheme.POSTGRES, Scheme.COCKROACH):
            return await self._select_entity("id=ANY($2)", ids)
        else:
            return await self._select_entity("id IN ($2, $3, $4)", *ids)
|