401 lines
13 KiB
Python
401 lines
13 KiB
Python
# mautrix-telegram - A Matrix-Telegram puppeting bridge
|
|
# Copyright (C) 2019 Tulir Asokan
|
|
#
|
|
# This program is free software: you can redistribute it and/or modify
|
|
# it under the terms of the GNU Affero General Public License as published by
|
|
# the Free Software Foundation, either version 3 of the License, or
|
|
# (at your option) any later version.
|
|
#
|
|
# This program is distributed in the hope that it will be useful,
|
|
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
# GNU Affero General Public License for more details.
|
|
#
|
|
# You should have received a copy of the GNU Affero General Public License
|
|
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
|
from __future__ import annotations
|
|
|
|
from typing import AsyncGenerator, Awaitable, Union, cast
|
|
from collections import defaultdict
|
|
import asyncio
|
|
import hashlib
|
|
import logging
|
|
import math
|
|
import time
|
|
|
|
from aiohttp import ClientResponse
|
|
from telethon import helpers, utils
|
|
from telethon.crypto import AuthKey
|
|
from telethon.network import MTProtoSender
|
|
from telethon.tl.alltlobjects import LAYER
|
|
from telethon.tl.functions import InvokeWithLayerRequest
|
|
from telethon.tl.functions.auth import ExportAuthorizationRequest, ImportAuthorizationRequest
|
|
from telethon.tl.functions.upload import (
|
|
GetFileRequest,
|
|
SaveBigFilePartRequest,
|
|
SaveFilePartRequest,
|
|
)
|
|
from telethon.tl.types import (
|
|
Document,
|
|
InputDocumentFileLocation,
|
|
InputFile,
|
|
InputFileBig,
|
|
InputFileLocation,
|
|
InputPeerPhotoFileLocation,
|
|
InputPhotoFileLocation,
|
|
TypeInputFile,
|
|
)
|
|
|
|
from mautrix.appservice import IntentAPI
|
|
from mautrix.types import ContentURI, EncryptedFile
|
|
from mautrix.util.logging import TraceLogger
|
|
|
|
from ..db import TelegramFile as DBTelegramFile
|
|
from ..tgclient import MautrixTelegramClient
|
|
|
|
# Attachment encryption is optional: if the crypto helper cannot be imported,
# the name is bound to None and callers must check it before encrypting.
try:
    from mautrix.crypto.attachments import async_encrypt_attachment
except ImportError:
    async_encrypt_attachment = None

# Module logger. The cast tells type checkers it is a TraceLogger, because this
# module calls log.trace() in addition to the standard logging methods.
log: TraceLogger = cast(TraceLogger, logging.getLogger("mau.util"))

# Every kind of Telegram file reference accepted as a download location by the
# classes and functions in this module.
TypeLocation = Union[
    Document,
    InputDocumentFileLocation,
    InputPeerPhotoFileLocation,
    InputFileLocation,
    InputPhotoFileLocation,
]
|
|
|
|
|
|
class DownloadSender:
    """Fetch a fixed number of chunks of one file over a single connection.

    The sender keeps one reusable ``GetFileRequest`` and, after every
    successful fetch, advances its offset by ``stride`` bytes so that several
    senders with staggered starting offsets can cover a file between them.
    """

    sender: MTProtoSender
    request: GetFileRequest
    # Number of chunks this sender still has to fetch.
    remaining: int
    # Byte distance between two consecutive chunks fetched by this sender.
    stride: int

    def __init__(
        self,
        sender: MTProtoSender,
        file: TypeLocation,
        offset: int,
        limit: int,
        stride: int,
        count: int,
    ) -> None:
        """Prepare the request for the first chunk.

        :param sender: Connected sender used for every request.
        :param file: Location of the file to download.
        :param offset: Byte offset of the first chunk fetched by this sender.
        :param limit: Size in bytes requested per chunk.
        :param stride: Amount added to the offset after each chunk.
        :param count: Total number of chunks this sender will fetch.
        """
        self.remaining = count
        self.stride = stride
        self.request = GetFileRequest(file, offset=offset, limit=limit)
        self.sender = sender

    async def next(self) -> bytes | None:
        """Fetch the next chunk, or return ``None`` once all chunks are done."""
        if self.remaining:
            response = await self.sender.send(self.request)
            # Bookkeeping happens only after a successful send, so a failed
            # request leaves the sender pointing at the same chunk.
            self.remaining -= 1
            self.request.offset += self.stride
            return response.bytes
        return None

    def disconnect(self) -> Awaitable[None]:
        """Return the awaitable that closes the underlying connection."""
        return self.sender.disconnect()
|
|
|
|
|
|
class UploadSender:
    """Upload a strided subset of a file's parts over a single connection.

    One request object is reused for every part: its payload is replaced and
    its ``file_part`` index is advanced by ``stride`` after each send.
    :meth:`next` waits for the previously scheduled part before scheduling the
    new one as a background task, so at most one send is in flight per sender
    while the caller is free to prepare the following part.
    """

    sender: MTProtoSender
    # Either request type may be used, depending on the ``big`` flag.
    request: SaveFilePartRequest | SaveBigFilePartRequest
    part_count: int
    stride: int
    # Task sending the most recently scheduled part, if any.
    previous: asyncio.Task | None
    # Stored for callers; not used by any method of this class.
    loop: asyncio.AbstractEventLoop

    def __init__(
        self,
        sender: MTProtoSender,
        file_id: int,
        part_count: int,
        big: bool,
        index: int,
        stride: int,
        loop: asyncio.AbstractEventLoop,
    ) -> None:
        """Prepare the reusable request for the first part.

        :param sender: Connected sender used for every request.
        :param file_id: Identifier of the file being uploaded.
        :param part_count: Total number of parts in the whole file.
        :param big: Use ``SaveBigFilePartRequest`` instead of
            ``SaveFilePartRequest``.
        :param index: Index of the first part sent by this sender.
        :param stride: Amount added to the part index after each send.
        :param loop: Event loop reference, stored on the instance.
        """
        self.sender = sender
        self.part_count = part_count
        if big:
            self.request = SaveBigFilePartRequest(file_id, index, part_count, b"")
        else:
            self.request = SaveFilePartRequest(file_id, index, b"")
        self.stride = stride
        self.previous = None
        self.loop = loop

    async def next(self, data: bytes) -> None:
        """Schedule ``data`` as the next part once the previous part is done.

        Any exception raised while sending the previous part propagates from
        here, because the previous task is awaited first.
        """
        if self.previous:
            await self.previous
        self.previous = asyncio.create_task(self._next(data))

    async def _next(self, data: bytes) -> None:
        """Send one part and advance the request to this sender's next index."""
        self.request.bytes = data
        log.trace(
            f"Sending file part {self.request.file_part}/{self.part_count}"
            f" with {len(data)} bytes"
        )
        await self.sender.send(self.request)
        self.request.file_part += self.stride

    async def disconnect(self) -> None:
        """Wait for the last scheduled part, then close the connection."""
        if self.previous:
            await self.previous
        return await self.sender.disconnect()
|
|
|
|
|
|
class ParallelTransferrer:
    """Download or upload one file over several parallel connections.

    A transferrer opens a number of extra ``MTProtoSender`` connections to the
    target data center, splits the file's parts between them in an interleaved
    pattern (sender ``i`` handles parts ``i``, ``i + n``, ``i + 2n``, ... for
    ``n`` connections), and disconnects the senders when the transfer ends.
    """

    client: MautrixTelegramClient
    loop: asyncio.AbstractEventLoop
    dc_id: int
    # None whenever no transfer is in progress.
    senders: list[DownloadSender | UploadSender] | None
    # None until the first sender has imported authorization on a foreign DC.
    auth_key: AuthKey | None
    # Index of the sender that receives the next uploaded part.
    upload_ticker: int

    def __init__(self, client: MautrixTelegramClient, dc_id: int | None = None) -> None:
        """Bind the transferrer to a client and a data center.

        :param client: Client whose session and connection settings are reused.
        :param dc_id: Target data center; defaults to the session's own DC.
        """
        self.client = client
        self.loop = self.client.loop
        self.dc_id = dc_id or self.client.session.dc_id
        # The session's auth key is only reused for the session's own DC. For
        # any other DC the key starts as None, which makes _create_sender()
        # export and import the authorization on the first connection.
        self.auth_key = (
            None if dc_id and self.client.session.dc_id != dc_id else self.client.session.auth_key
        )
        self.senders = None
        self.upload_ticker = 0

    async def _cleanup(self) -> None:
        """Disconnect every sender; safe to call when there are none."""
        senders, self.senders = self.senders, None
        if not senders:
            return
        await asyncio.gather(*(sender.disconnect() for sender in senders))

    @staticmethod
    def _get_connection_count(
        file_size: int, max_count: int = 20, full_size: int = 100 * 1024 * 1024
    ) -> int:
        """Scale the number of connections linearly with the file size.

        Files larger than ``full_size`` get ``max_count`` connections. The
        result is never below 1, so an empty file cannot produce a zero
        connection count (which would cause a division by zero when parts are
        distributed between download senders).
        """
        if file_size > full_size:
            return max_count
        return max(1, math.ceil((file_size / full_size) * max_count))

    @staticmethod
    def _get_part_size(file_size: int, part_size_kb: float | None) -> int:
        """Return the part size in bytes as an integer.

        Uses ``part_size_kb`` when given, otherwise asks Telethon for an
        appropriate size. The result is truncated to ``int`` so a float
        ``part_size_kb`` cannot leak float sizes or part counts into requests.
        """
        return int((part_size_kb or utils.get_appropriated_part_size(file_size)) * 1024)

    async def _init_download(
        self, connections: int, file: TypeLocation, part_count: int, part_size: int
    ) -> None:
        """Create one download sender per connection and split the parts.

        Every sender gets ``part_count // connections`` parts and the first
        ``part_count % connections`` senders get one extra part.
        """
        minimum, remainder = divmod(part_count, connections)
        counts = [minimum + 1 if i < remainder else minimum for i in range(connections)]
        stride = connections * part_size

        # The first cross-DC sender will export+import the authorization, so we always create it
        # before creating any other senders.
        self.senders = [
            await self._create_download_sender(file, 0, part_size, stride, counts[0]),
            *await asyncio.gather(
                *(
                    self._create_download_sender(file, i, part_size, stride, counts[i])
                    for i in range(1, connections)
                )
            ),
        ]

    async def _create_download_sender(
        self, file: TypeLocation, index: int, part_size: int, stride: int, part_count: int
    ) -> DownloadSender:
        """Connect a sender that starts at part ``index`` of ``file``."""
        return DownloadSender(
            await self._create_sender(), file, index * part_size, part_size, stride, part_count
        )

    async def _init_upload(
        self, connections: int, file_id: int, part_count: int, big: bool
    ) -> None:
        """Create one upload sender per connection.

        As with downloads, the first sender is created on its own before the
        rest are created concurrently.
        """
        self.senders = [
            await self._create_upload_sender(file_id, part_count, big, 0, connections),
            *await asyncio.gather(
                *(
                    self._create_upload_sender(file_id, part_count, big, i, connections)
                    for i in range(1, connections)
                )
            ),
        ]

    async def _create_upload_sender(
        self, file_id: int, part_count: int, big: bool, index: int, stride: int
    ) -> UploadSender:
        """Connect a sender whose first part index is ``index``."""
        return UploadSender(
            await self._create_sender(), file_id, part_count, big, index, stride, loop=self.loop
        )

    async def _create_sender(self) -> MTProtoSender:
        """Open a new connection to the target DC and authorize it if needed."""
        dc = await self.client._get_dc(self.dc_id)
        sender = MTProtoSender(self.auth_key, loggers=self.client._log)
        await sender.connect(
            self.client._connection(
                dc.ip_address, dc.port, dc.id, loggers=self.client._log, proxy=self.client._proxy
            )
        )
        if not self.auth_key:
            log.debug(f"Exporting auth to DC {self.dc_id}")
            auth = await self.client(ExportAuthorizationRequest(self.dc_id))
            # NOTE(review): this rewrites the query of the client's shared
            # _init_request object; confirm nothing else relies on its
            # previous query while a transfer is being set up.
            self.client._init_request.query = ImportAuthorizationRequest(
                id=auth.id, bytes=auth.bytes
            )
            req = InvokeWithLayerRequest(LAYER, self.client._init_request)
            await sender.send(req)
            # Later senders reuse the key, so they skip the export/import step.
            self.auth_key = sender.auth_key
        return sender

    async def init_upload(
        self,
        file_id: int,
        file_size: int,
        part_size_kb: float | None = None,
        connection_count: int | None = None,
    ) -> tuple[int, int, bool]:
        """Open the upload connections for a file.

        :param file_id: Identifier of the file being uploaded.
        :param file_size: Total size of the file in bytes.
        :param part_size_kb: Part size in KiB; chosen automatically if omitted.
        :param connection_count: Number of connections; derived from the file
            size if omitted.
        :return: ``(part_size, part_count, is_large)`` where ``part_size`` is
            in bytes and ``is_large`` is true for files over 10 MiB.
        """
        connection_count = connection_count or self._get_connection_count(file_size)
        part_size = self._get_part_size(file_size, part_size_kb)
        part_count = (file_size + part_size - 1) // part_size
        is_large = file_size > 10 * 1024 * 1024
        await self._init_upload(connection_count, file_id, part_count, is_large)
        return part_size, part_count, is_large

    async def upload(self, part: bytes) -> None:
        """Hand the next part to the senders in round-robin order."""
        await self.senders[self.upload_ticker].next(part)
        self.upload_ticker = (self.upload_ticker + 1) % len(self.senders)

    async def finish_upload(self) -> None:
        """Wait for outstanding parts and close the upload connections."""
        await self._cleanup()

    async def download(
        self,
        file: TypeLocation,
        file_size: int,
        part_size_kb: float | None = None,
        connection_count: int | None = None,
    ) -> AsyncGenerator[bytes, None]:
        """Yield the file's parts in order while fetching them in parallel.

        Each round requests one part from every sender concurrently and then
        yields the results in sender order, which is file order. Connections
        are released when the generator finishes, fails, or is closed early.

        :param file: Location of the file to download.
        :param file_size: Total size of the file in bytes.
        :param part_size_kb: Part size in KiB; chosen automatically if omitted.
        :param connection_count: Number of connections; derived from the file
            size if omitted.
        """
        connection_count = connection_count or self._get_connection_count(file_size)
        part_size = self._get_part_size(file_size, part_size_kb)
        part_count = math.ceil(file_size / part_size)
        log.debug(
            f"Starting parallel download: {connection_count} {part_size} {part_count} {file!s}"
        )
        await self._init_download(connection_count, file, part_count, part_size)

        part = 0
        tasks: list[asyncio.Task] = []
        try:
            exhausted = False
            while part < part_count and not exhausted:
                tasks = [asyncio.create_task(sender.next()) for sender in self.senders]
                for task in tasks:
                    data = await task
                    if not data:
                        # In a healthy transfer this only happens in the last
                        # round. Stopping here also prevents an endless loop
                        # when the server returns less data than expected.
                        exhausted = True
                        break
                    yield data
                    part += 1
                    log.trace(f"Part {part} downloaded")
            if part < part_count:
                log.warning(
                    f"Parallel download ended early after {part} of {part_count} parts"
                )
            log.debug("Parallel download finished, cleaning up connections")
        finally:
            # Stop requests that are still in flight and collect their results
            # so that no task is left running or holding an unretrieved error.
            for task in tasks:
                task.cancel()
            await asyncio.gather(*tasks, return_exceptions=True)
            await self._cleanup()
|
|
|
|
|
|
# One lock per parallel_id, created on first access; the transfer functions
# below hold it for the duration of a transfer.
parallel_transfer_locks: defaultdict[int, asyncio.Lock] = defaultdict(asyncio.Lock)
|
|
|
|
|
|
async def parallel_transfer_to_matrix(
    client: MautrixTelegramClient,
    intent: IntentAPI,
    loc_id: str,
    location: TypeLocation,
    filename: str,
    encrypt: bool,
    parallel_id: int,
) -> DBTelegramFile:
    """Stream a Telegram file into the Matrix media repository.

    The file is downloaded in parallel and passed to ``intent.upload_media``
    as an async stream. When ``encrypt`` is set and attachment encryption is
    available, the stream is encrypted on the fly and uploaded as
    ``application/octet-stream``.

    :param client: Telegram client used for the download.
    :param intent: Matrix intent used for the upload.
    :param loc_id: Identifier stored as the ``id`` of the returned record.
    :param location: File to transfer; its ``size`` and ``mime_type``
        attributes are read before it is converted to an input location.
    :param filename: File name passed to the Matrix upload.
    :param encrypt: Whether the uploaded content should be encrypted.
    :param parallel_id: Key selecting the lock that serializes transfers.
    :return: Database record describing the uploaded file.
    """
    size, mime_type = location.size, location.mime_type
    dc_id, location = utils.get_input_location(location)
    decryption_info = None
    # We lock the transfers because telegram has connection count limits
    async with parallel_transfer_locks[parallel_id]:
        transferrer = ParallelTransferrer(client, dc_id)
        stream = transferrer.download(location, size)
        upload_mime_type = mime_type
        if encrypt and async_encrypt_attachment:

            async def encrypted(source):
                # Forward the encrypted chunks and capture the EncryptedFile
                # metadata object that the encryptor emits in the same stream.
                nonlocal decryption_info
                async for item in async_encrypt_attachment(source):
                    if not isinstance(item, EncryptedFile):
                        yield item
                        continue
                    decryption_info = item

            stream = encrypted(stream)
            upload_mime_type = "application/octet-stream"
        # No size is passed whenever encryption was requested.
        upload_size = None if encrypt else size
        content_uri = await intent.upload_media(
            stream, mime_type=upload_mime_type, filename=filename, size=upload_size
        )
        if decryption_info:
            decryption_info.url = content_uri
    return DBTelegramFile(
        id=loc_id,
        mxc=content_uri,
        mime_type=mime_type,
        was_converted=False,
        timestamp=int(time.time()),
        size=size,
        width=None,
        height=None,
        decryption_info=decryption_info,
    )
|
|
|
|
|
|
async def _internal_transfer_to_telegram(
    client: MautrixTelegramClient, response: ClientResponse
) -> tuple[TypeInputFile, int]:
    """Stream an HTTP response body to Telegram as a parallel upload.

    The body is re-chunked into parts of exactly ``part_size`` bytes (the last
    part may be shorter) regardless of how the response delivers its data. An
    MD5 digest is computed only for files that are not uploaded as "big"
    files, since only ``InputFile`` carries a checksum.

    :param client: Telegram client used for the upload connections.
    :param response: Open HTTP response whose body is the file content. Its
        ``content_length`` is used as the file size.
        NOTE(review): presumably this is None when the server sends no
        Content-Length header, which would fail in init_upload; confirm.
    :return: The input file handle for Telegram and the file size in bytes.
    """
    file_id = helpers.generate_random_long()
    file_size = response.content_length

    hash_md5 = hashlib.md5()
    uploader = ParallelTransferrer(client)
    part_size, part_count, is_large = await uploader.init_upload(file_id, file_size)
    buffer = bytearray()
    try:
        async for data in response.content:
            if not is_large:
                hash_md5.update(data)
            if not buffer and len(data) == part_size:
                # Fast path: the chunk is exactly one part, skip the copy.
                await uploader.upload(data)
                continue
            buffer.extend(data)
            # Drain every complete part. A single chunk may hold more than one
            # part, so this has to loop rather than cut the buffer only once.
            while len(buffer) >= part_size:
                await uploader.upload(bytes(buffer[:part_size]))
                del buffer[:part_size]
        if buffer:
            await uploader.upload(bytes(buffer))
    except BaseException:
        # Release the upload connections without hiding the original error.
        try:
            await uploader.finish_upload()
        except Exception:
            log.exception("Failed to clean up upload connections after a failed transfer")
        raise
    await uploader.finish_upload()
    if is_large:
        return InputFileBig(file_id, part_count, "upload"), file_size
    else:
        return InputFile(file_id, part_count, "upload", hash_md5.hexdigest()), file_size
|
|
|
|
|
|
async def parallel_transfer_to_telegram(
    client: MautrixTelegramClient, intent: IntentAPI, uri: ContentURI, parallel_id: int
) -> tuple[TypeInputFile, int]:
    """Download a Matrix media file and upload it to Telegram in parallel.

    The lock for ``parallel_id`` is held for the whole transfer, and the HTTP
    request for the media is only started once the lock has been acquired.

    :param client: Telegram client used for the upload.
    :param intent: Matrix intent whose API session fetches the media.
    :param uri: Matrix content URI of the media to transfer.
    :param parallel_id: Key selecting the lock that serializes transfers.
    :return: The input file handle for Telegram and the file size in bytes.
    """
    download_url = intent.api.get_download_url(uri)
    transfer_lock = parallel_transfer_locks[parallel_id]
    async with transfer_lock, intent.api.session.get(download_url) as response:
        return await _internal_transfer_to_telegram(client, response)
|