mautrix-python/mautrix/util/async_body.py

96 lines
3.2 KiB
Python

# Copyright (c) 2023 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations
from typing import AsyncGenerator, Union
import logging
import aiohttp
# The buffer types a single body chunk may be.
_BodyChunk = Union[bytes, bytearray, memoryview]
# An async generator that yields body chunks and accepts nothing via send().
AsyncBody = AsyncGenerator[_BodyChunk, None]
async def async_iter_bytes(data: bytearray | bytes, chunk_size: int = 1024**2) -> AsyncBody:
    """
    Yield consecutive slices of a byte buffer as memory views.

    The slices are views into ``data`` rather than copies, which is used to prevent aiohttp from
    copying the entire request body.

    Args:
        data: The buffer to split into chunks.
        chunk_size: The length of each yielded chunk; the final chunk may be shorter.

    Returns:
        An async generator producing the contents of ``data`` chunk by chunk.
    """
    with memoryview(data) as view:
        # The view is released when the generator finishes or is closed.
        starts = range(0, len(data), chunk_size)
        for start in starts:
            stop = start + chunk_size
            yield view[start:stop]
class FileTooLargeError(Exception):
    """Raised when a file is larger than the allowed maximum size."""

    def __init__(self, max_size: int) -> None:
        """
        Build the error message from the size limit.

        Args:
            max_size: The size limit in bytes; the message shows it in MiB.
        """
        limit_mib = max_size / 1024 / 1024
        message = f"File size larger than maximum ({limit_mib} MiB)"
        super().__init__(message)
# Logger used by read_response_chunks when the caller does not pass its own.
_default_dl_log = logging.getLogger("mau.util.download")
async def read_response_chunks(
    resp: aiohttp.ClientResponse, max_size: int, log: logging.Logger = _default_dl_log
) -> bytearray:
    """
    Read the body from an aiohttp response in chunks into a mutable bytearray.

    If the response declares a Content-Length, a buffer of that size is allocated up front and
    filled in place. If more bytes arrive than declared (or no length was declared), the buffer
    grows as chunks arrive; if fewer arrive, the unused tail is trimmed. The returned array
    therefore always holds exactly the bytes that were read.

    Args:
        resp: The aiohttp response object to read the body from.
        max_size: The maximum size to read. FileTooLargeError will be raised if the Content-Length
            is higher than this, or if the body exceeds this size during reading.
        log: A logger for logging download status.

    Returns:
        The body data as a byte array.

    Raises:
        FileTooLargeError: if the body is larger than the provided max_size.
    """
    content_length = int(resp.headers.get("Content-Length", "0"))
    # The header-based check only applies to positive limits; the check inside the loop
    # compares the limit against the bytes actually received.
    if 0 < max_size < content_length:
        raise FileTooLargeError(max_size)
    size_str = "unknown length" if content_length == 0 else f"{content_length} bytes"
    log.info(f"Reading file download response with {size_str} (max: {max_size})")
    data = bytearray(content_length)
    # Writing through a memoryview fills the preallocated buffer in place.
    mv = memoryview(data) if content_length > 0 else None
    read_size = 0
    try:
        while True:
            block = await resp.content.readany()
            if not block:
                break
            new_size = read_size + len(block)
            if new_size > max_size:
                # max_size is left untouched so the error reports the caller's actual limit.
                raise FileTooLargeError(max_size)
            if len(data) >= new_size:
                # The whole block fits inside the preallocated buffer.
                mv[read_size:new_size] = block
            elif len(data) > read_size:
                # The block straddles the end of the preallocated buffer: fill what is left,
                # then append the remainder.
                log.warning("File being downloaded is bigger than expected")
                fits = len(data) - read_size
                mv[read_size:] = block[:fits]
                # A bytearray cannot be resized while a memoryview of it is still exported.
                mv.release()
                mv = None
                data.extend(block[fits:])
            else:
                # The buffer is already full (or was never preallocated): just append.
                if mv is not None:
                    mv.release()
                    mv = None
                data.extend(block)
            read_size = new_size
    finally:
        # Release the view on error paths too, so the buffer is never left locked.
        if mv is not None:
            mv.release()
    if read_size < len(data):
        # Fewer bytes arrived than Content-Length declared; drop the zero-filled tail so the
        # caller does not receive padding that was never part of the body.
        log.warning("File being downloaded is smaller than expected")
        del data[read_size:]
    log.info(f"Successfully read {read_size} bytes of file download response")
    return data
# Public API of this module. Every name listed here must be defined in the module, otherwise
# `from ... import *` raises AttributeError.
__all__ = ["AsyncBody", "FileTooLargeError", "async_iter_bytes", "read_response_chunks"]