mautrix-python/mautrix/api.py

512 lines
18 KiB
Python

# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations
from typing import ClassVar, Literal, Mapping
from enum import Enum
from json.decoder import JSONDecodeError
from urllib.parse import quote as urllib_quote, urljoin as urllib_join
import asyncio
import inspect
import json
import logging
import platform
import time
from aiohttp import ClientResponse, ClientSession, __version__ as aiohttp_version
from aiohttp.client_exceptions import ClientError, ContentTypeError
from yarl import URL
from mautrix import __optional_imports__, __version__ as mautrix_version
from mautrix.errors import MatrixConnectionError, MatrixRequestError, make_request_error
from mautrix.util.async_body import AsyncBody, async_iter_bytes
from mautrix.util.logging import TraceLogger
from mautrix.util.opt_prometheus import Counter
if __optional_imports__:
# Safe to import, but it's not actually needed, so don't force-import the whole types module.
from mautrix.types import JSON, DeviceID, UserID
# Counts every attempted Matrix client API request. HTTPAPI.request increments
# it once per attempt (so retried requests are counted again), labelled with
# the caller-supplied ``metrics_method`` string.
API_CALLS = Counter(
    name="bridge_matrix_api_calls",
    documentation="The number of Matrix client API calls made",
    labelnames=("method",),
)
# Counts request attempts that ended in an exception, using the same "method"
# label as API_CALLS.
API_CALLS_FAILED = Counter(
    name="bridge_matrix_api_calls_failed",
    documentation="The number of Matrix client API calls which failed",
    labelnames=("method",),
)
class APIPath(Enum):
    """
    Path prefixes of the Matrix APIs known to this module.

    The values intentionally have no leading slash so that they can be used
    nicely with yarl.
    """

    CLIENT = "_matrix/client"
    MEDIA = "_matrix/media"
    SYNAPSE_ADMIN = "_synapse/admin"

    def __str__(self) -> str:
        # Render as the bare prefix so a member can be dropped straight into a path.
        return self._value_

    def __repr__(self) -> str:
        # Same text as str(): the raw prefix rather than ``<APIPath.X: ...>``.
        return self._value_
class Method(Enum):
    """The HTTP methods that can be used for a request."""

    GET = "GET"
    POST = "POST"
    PUT = "PUT"
    DELETE = "DELETE"
    PATCH = "PATCH"

    def __str__(self) -> str:
        # The plain verb, so ``str(method)`` can be handed directly to an HTTP client.
        return self._value_

    def __repr__(self) -> str:
        # Same text as str(): the verb rather than ``<Method.X: ...>``.
        return self._value_
class PathBuilder:
    """
    A utility class to build API paths.

    Every operation returns a new builder; the existing one is never modified.
    Attribute access (``builder.rooms``) appends the attribute name verbatim,
    while item access (``builder[room_id]``) appends the percent-encoded value.

    Examples:
        >>> from mautrix.api import Path
        >>> room_id = "!foo:example.com"
        >>> event_id = "$bar:example.com"
        >>> str(Path.v3.rooms[room_id].event[event_id])
        "_matrix/client/v3/rooms/%21foo%3Aexample.com/event/%24bar%3Aexample.com"
    """

    def __init__(self, path: str | APIPath = "") -> None:
        """
        Args:
            path: The initial path. It is converted with ``str()``.
        """
        self.path: str = str(path)

    def __str__(self) -> str:
        return self.path

    def __repr__(self) -> str:
        return self.path

    def __getattr__(self, append: str) -> PathBuilder:
        """
        Append an attribute name to the path as a literal (unquoted) segment.

        Raises:
            AttributeError: If ``append`` is a dunder name (``__x__``).
        """
        if append is None:
            return self
        # __getattr__ is only called for names that normal lookup didn't find.
        # Dunder names are protocol probes (e.g. copy looking for __deepcopy__,
        # pickle looking for __setstate__), not path segments. Answering them
        # with a PathBuilder makes those callers try to call a non-callable.
        if append.startswith("__") and append.endswith("__"):
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {append!r}"
            )
        return PathBuilder(f"{self.path}/{append}")

    def raw(self, append: str) -> PathBuilder:
        """
        Directly append a string to the path.

        No separator is inserted and nothing is quoted.

        Args:
            append: The string to append. If ``None``, the builder itself is returned.
        """
        if append is None:
            return self
        return PathBuilder(self.path + append)

    def __eq__(self, other: PathBuilder | str) -> bool:
        """Compare by path; another builder or a plain string may be given."""
        return other.path == self.path if isinstance(other, PathBuilder) else other == self.path

    def __hash__(self) -> int:
        # Defining __eq__ alone would make instances unhashable. Hashing the
        # path string keeps the hash consistent with equality against both
        # other builders and plain strings.
        return hash(self.path)

    @staticmethod
    def _quote(string: str) -> str:
        """Percent-encode a string, including ``/`` (no character is treated as safe)."""
        return urllib_quote(string, safe="")

    def __getitem__(self, append: str | int) -> PathBuilder:
        """
        Append a value to the path as a percent-encoded segment.

        Args:
            append: The value to append; it is converted with ``str()`` and then
                quoted. If ``None``, the builder itself is returned.
        """
        if append is None:
            return self
        return PathBuilder(f"{self.path}/{self._quote(str(append))}")

    def replace(self, find: str, replace: str) -> PathBuilder:
        """Return a new builder with every occurrence of ``find`` in the path replaced."""
        return PathBuilder(self.path.replace(find, replace))
# Ready-made builders for each known API prefix. The ``__doc__`` assignments
# attach per-instance documentation to the builder objects.
ClientPath = PathBuilder(APIPath.CLIENT)
ClientPath.__doc__ = """
A path builder with the standard client prefix (``/_matrix/client``, :attr:`APIPath.CLIENT`).
"""

# A second, independent builder with the same client prefix under a shorter name.
Path = PathBuilder(APIPath.CLIENT)
Path.__doc__ = """A shorter alias for :attr:`ClientPath`"""

MediaPath = PathBuilder(APIPath.MEDIA)
MediaPath.__doc__ = """
A path builder with the standard media prefix (``/_matrix/media``, :attr:`APIPath.MEDIA`)

Examples:
    >>> from mautrix.api import MediaPath
    >>> str(MediaPath.v3.config)
    "_matrix/media/v3/config"
"""

SynapseAdminPath = PathBuilder(APIPath.SYNAPSE_ADMIN)
SynapseAdminPath.__doc__ = """
A path builder for synapse-specific admin API paths
(``/_synapse/admin``, :attr:`APIPath.SYNAPSE_ADMIN`)

Examples:
    >>> from mautrix.api import SynapseAdminPath
    >>> user_id = "@user:example.com"
    >>> str(SynapseAdminPath.v1.users[user_id].login)
    "_synapse/admin/v1/users/%40user%3Aexample.com/login"
"""
_req_id = 0
def _next_global_req_id() -> int:
global _req_id
_req_id += 1
return _req_id
class HTTPAPI:
    """HTTPAPI is a simple asyncio Matrix API request sender."""

    default_ua: ClassVar[str] = (
        f"mautrix-python/{mautrix_version} aiohttp/{aiohttp_version} "
        f"Python/{platform.python_version()}"
    )
    """
    The default value for the ``User-Agent`` header.

    You should prepend your program name and version here before creating any HTTPAPI instances
    in order to have proper user agents for all requests.
    """
    global_default_retry_count: ClassVar[int] = 0
    """The default retry count to use if an instance-specific value is not passed."""

    base_url: URL
    """The base URL of the homeserver's client-server API to use."""
    token: str
    """The access token to use in requests."""
    log: TraceLogger
    """The :class:`logging.Logger` instance to log requests with."""
    session: ClientSession
    """The aiohttp ClientSession instance to make requests with."""
    txn_id: int | None
    """A counter used for generating transaction IDs."""
    default_retry_count: int
    """The default retry count to use if a custom value is not passed to :meth:`request`"""
    as_user_id: UserID | None
    """An optional user ID to set as the user_id query parameter for appservice requests."""
    as_device_id: DeviceID | None
    """
    An optional device ID to set as the device_id query parameter for appservice requests
    (MSC3202).
    """

    def __init__(
        self,
        base_url: URL | str,
        token: str = "",
        *,
        client_session: ClientSession | None = None,
        default_retry_count: int | None = None,
        txn_id: int | None = 0,
        log: TraceLogger | None = None,
        loop: asyncio.AbstractEventLoop | None = None,
        as_user_id: UserID | None = None,
        as_device_id: DeviceID | None = None,
    ) -> None:
        """
        Args:
            base_url: The base URL of the homeserver's client-server API to use.
            token: The access token to use.
            client_session: The aiohttp client session to use. If omitted, a new session
                with :attr:`default_ua` as its ``User-Agent`` header is created.
            txn_id: The outgoing transaction ID to start with. If ``None``, the
                ``txn_id`` attribute is not assigned by this constructor.
            log: The :class:`logging.Logger` instance to log requests with.
                Defaults to the ``mau.http`` logger.
            loop: Passed to the :class:`ClientSession` constructor when a new session
                is created. Unused if ``client_session`` is given.
            default_retry_count: Default number of retries to do when encountering network errors.
                Defaults to :attr:`global_default_retry_count`.
            as_user_id: An optional user ID to set as the user_id query parameter for
                appservice requests.
            as_device_id: An optional device ID to set as the device_id query parameter for
                appservice requests (MSC3202).
        """
        self.base_url = URL(base_url)
        self.token = token
        self.log = log or logging.getLogger("mau.http")
        self.session = client_session or ClientSession(
            loop=loop, headers={"User-Agent": self.default_ua}
        )
        self.as_user_id = as_user_id
        self.as_device_id = as_device_id
        # NOTE(review): the attribute is deliberately left unassigned when txn_id
        # is None, presumably so a subclass can supply ``txn_id`` some other way.
        if txn_id is not None:
            self.txn_id = txn_id
        if default_retry_count is not None:
            self.default_retry_count = default_retry_count
        else:
            self.default_retry_count = self.global_default_retry_count

    async def _send(
        self,
        method: Method,
        url: URL,
        content: bytes | bytearray | str | AsyncBody,
        query_params: dict[str, str],
        headers: dict[str, str],
    ) -> tuple[JSON, ClientResponse]:
        """
        Send a single HTTP request and parse the response body as JSON.

        Returns:
            A tuple of the parsed response JSON and the aiohttp response object.

        Raises:
            The error built by ``make_request_error`` if the response status is
            outside the 2xx range.
        """
        request = self.session.request(
            str(method), url, data=content, params=query_params, headers=headers
        )
        async with request as response:
            if response.status < 200 or response.status >= 300:
                errcode = unstable_errcode = message = None
                try:
                    response_data = await response.json()
                    errcode = response_data["errcode"]
                    message = response_data["error"]
                    unstable_errcode = response_data.get("org.matrix.msc3848.unstable.errcode")
                except (JSONDecodeError, ContentTypeError, KeyError, TypeError):
                    # Best-effort parsing: the body may not be JSON, may lack the
                    # standard fields, or may be valid JSON that isn't an object
                    # (hence TypeError). Whatever was extracted before the failure
                    # is kept and the raw text is still passed along below.
                    pass
                raise make_request_error(
                    http_status=response.status,
                    text=await response.text(),
                    errcode=errcode,
                    message=message,
                    unstable_errcode=unstable_errcode,
                )
            return await response.json(), response

    def _log_request(
        self,
        method: Method,
        url: URL,
        content: str | bytes | bytearray | AsyncBody | None,
        orig_content,
        query_params: dict[str, str],
        headers: dict[str, str],
        req_id: int,
        sensitive: bool,
    ) -> None:
        """
        Log an outgoing request.

        Requests whose URL path ends with ``/v3/sync`` are logged at level 5, everything
        else at level 10.

        Args:
            method: The HTTP method of the request.
            url: The request URL (as it should appear in the log).
            content: The body that will be sent.
            orig_content: The body before JSON serialization. If it's a dict or list and the
                request isn't sensitive, it's included as-is in the structured log data.
            query_params: The query parameters; only ``user_id`` is read.
            headers: The request headers; only ``Content-Length`` is read.
            req_id: The request number to include in the log line.
            sensitive: If True, only the size of the content is logged.
        """
        if not self.log:
            return
        if isinstance(content, (bytes, bytearray)):
            log_content = f"<{len(content)} bytes>"
        elif inspect.isasyncgen(content):
            size = headers.get("Content-Length", None)
            log_content = f"<{size} async bytes>" if size else f"<stream with unknown length>"
        elif sensitive and content is not None:
            # The None check matters: requests without a body (every GET) would
            # otherwise crash on len(None) when marked sensitive.
            log_content = f"<{len(content)} sensitive bytes>"
        else:
            log_content = content
        as_user = query_params.get("user_id", None)
        level = 5 if url.path.endswith("/v3/sync") else 10
        self.log.log(
            level,
            f"req #{req_id}: {method} {url} {log_content}".strip(" "),
            extra={
                "matrix_http_request": {
                    "req_id": req_id,
                    "method": str(method),
                    "url": str(url),
                    "content": (
                        orig_content
                        if isinstance(orig_content, (dict, list)) and not sensitive
                        else log_content
                    ),
                    "user": as_user,
                }
            },
        )

    def _log_request_done(
        self, path: PathBuilder | str, req_id: int, duration: float, status: int
    ) -> None:
        """
        Log the completion of a request.

        Args:
            path: The path that was requested. A ``/_matrix/client`` prefix is removed
                before logging.
            req_id: The request number that was used when logging the request.
            duration: How long the request took, in seconds. Durations under one second
                are shown in milliseconds.
            status: The HTTP status code of the response.
        """
        # Same guard as _log_request, so the two logging helpers behave alike.
        if not self.log:
            return
        level = 5 if path == Path.v3.sync else 10
        duration_str = f"{duration * 1000:.1f}ms" if duration < 1 else f"{duration:.3f}s"
        path_without_prefix = f"/{path}".replace("/_matrix/client", "")
        self.log.log(
            level,
            f"req #{req_id} ({path_without_prefix}) completed in {duration_str} "
            f"with status {status}",
        )

    def _full_path(self, path: PathBuilder | str) -> str:
        """
        Join an API path onto the path component of :attr:`base_url`.

        Args:
            path: The API path. One leading slash is removed so that the path is
                resolved relative to the base path.

        Returns:
            The joined path.
        """
        path = str(path)
        if path and path[0] == "/":
            path = path[1:]
        base_path = self.base_url.raw_path
        # endswith() instead of indexing so that an empty base path can't raise IndexError.
        if not base_path.endswith("/"):
            base_path += "/"
        return urllib_join(base_path, path)

    def log_download_request(self, url: URL, query_params: dict[str, str]) -> int:
        """
        Log a GET request that is sent outside :meth:`request`.

        Args:
            url: The URL being requested.
            query_params: The query parameters of the request.

        Returns:
            The request number that was allocated; pass it to
            :meth:`log_download_request_done` afterwards.
        """
        req_id = _next_global_req_id()
        self._log_request(Method.GET, url, None, None, query_params, {}, req_id, False)
        return req_id

    def log_download_request_done(
        self, url: URL, req_id: int, duration: float, status: int
    ) -> None:
        """
        Log the completion of a request that was logged with :meth:`log_download_request`.

        Args:
            url: The URL that was requested. A ``/_matrix/media/`` prefix is removed from
                its path before logging.
            req_id: The request number returned by :meth:`log_download_request`.
            duration: How long the request took, in seconds.
            status: The HTTP status code of the response.
        """
        self._log_request_done(url.path.removeprefix("/_matrix/media/"), req_id, duration, status)

    async def request(
        self,
        method: Method,
        path: PathBuilder | str,
        content: dict | list | bytes | bytearray | str | AsyncBody | None = None,
        headers: dict[str, str] | None = None,
        query_params: Mapping[str, str] | None = None,
        retry_count: int | None = None,
        metrics_method: str = "",
        min_iter_size: int = 25 * 1024 * 1024,
        sensitive: bool = False,
    ) -> JSON:
        """
        Make a raw Matrix API request.

        Args:
            method: The HTTP method to use.
            path: The full API endpoint to call (including the _matrix/... prefix)
            content: The content to post as a dict/list (will be serialized as JSON)
                     or bytes/str (will be sent as-is). Ignored for GET requests.
            headers: A dict of HTTP headers to send. For non-GET requests, if the headers
                     don't contain ``Content-Type``, it'll be set to ``application/json``.
                     The ``Authorization`` header is always overridden if :attr:`token` is
                     set. The passed dict itself is not modified.
            query_params: A dict of query parameters to send. If it's a dict, entries with
                          a ``None`` value are dropped.
            retry_count: Number of times to retry if the homeserver isn't reachable
                         or responds with HTTP 502, 503 or 504.
                         Defaults to :attr:`default_retry_count`. Forced to zero if the
                         content is an async generator.
            metrics_method: Name of the method to include in Prometheus timing metrics.
            min_iter_size: If the request body is larger than this value, it will be passed to
                           aiohttp as an async iterable to stop it from copying the whole thing
                           in memory.
            sensitive: If True, the request content will not be logged.

        Returns:
            The parsed response JSON.

        Raises:
            MatrixRequestError: If the server responded with an error status and no
                retries are left (or the status isn't retryable).
            MatrixConnectionError: If the request failed with an aiohttp client error and
                no retries are left.
        """
        # Work on a copy: the caller's dict must not gain Authorization (which
        # contains the access token), Content-Type or Content-Length entries.
        headers = dict(headers) if headers else {}
        if self.token:
            headers["Authorization"] = f"Bearer {self.token}"
        query_params = query_params or {}
        if isinstance(query_params, dict):
            query_params = {k: v for k, v in query_params.items() if v is not None}
        if self.as_user_id:
            query_params["user_id"] = self.as_user_id
        if self.as_device_id:
            query_params["org.matrix.msc3202.device_id"] = self.as_device_id
            query_params["device_id"] = self.as_device_id

        if method != Method.GET:
            content = content or {}
            if "Content-Type" not in headers:
                headers["Content-Type"] = "application/json"
            # Keep the unserialized body around for structured logging.
            orig_content = content
            is_json = headers.get("Content-Type", None) == "application/json"
            if is_json and isinstance(content, (dict, list)):
                content = json.dumps(content)
        else:
            orig_content = content = None
        full_url = self.base_url.with_path(self._full_path(path), encoded=True)
        req_id = _next_global_req_id()

        if retry_count is None:
            retry_count = self.default_retry_count

        if inspect.isasyncgen(content):
            # Can't retry with non-static body
            retry_count = 0

        do_fake_iter = content and hasattr(content, "__len__") and len(content) > min_iter_size
        if do_fake_iter:
            headers["Content-Length"] = str(len(content))
        backoff = 4
        log_url = full_url.with_query(query_params)
        while True:
            self._log_request(
                method, log_url, content, orig_content, query_params, headers, req_id, sensitive
            )
            API_CALLS.labels(method=metrics_method).inc()
            # A fresh iterator is made for every attempt so that retries can resend the body.
            req_content = async_iter_bytes(content) if do_fake_iter else content
            start = time.monotonic()
            try:
                resp_data, resp = await self._send(
                    method, full_url, req_content, query_params, headers
                )
                self._log_request_done(path, req_id, time.monotonic() - start, resp.status)
                return resp_data
            except MatrixRequestError as e:
                API_CALLS_FAILED.labels(method=metrics_method).inc()
                if retry_count > 0 and e.http_status in (502, 503, 504):
                    self.log.warning(
                        f"Request #{req_id} failed with HTTP {e.http_status}, "
                        f"retrying in {backoff} seconds"
                    )
                else:
                    self._log_request_done(path, req_id, time.monotonic() - start, e.http_status)
                    raise
            except ClientError as e:
                API_CALLS_FAILED.labels(method=metrics_method).inc()
                if retry_count > 0:
                    self.log.warning(
                        f"Request #{req_id} failed with {e}, retrying in {backoff} seconds"
                    )
                else:
                    raise MatrixConnectionError(str(e)) from e
            except Exception:
                API_CALLS_FAILED.labels(method=metrics_method).inc()
                raise
            # Only reached when a retry was scheduled above: wait, then double the delay.
            await asyncio.sleep(backoff)
            backoff *= 2
            retry_count -= 1

    def get_txn_id(self) -> str:
        """Get a new unique transaction ID."""
        self.txn_id += 1
        return f"mautrix-python_{time.time_ns()}_{self.txn_id}"

    def get_download_url(
        self,
        mxc_uri: str,
        download_type: Literal["download", "thumbnail"] = "download",
        file_name: str | None = None,
    ) -> URL:
        """
        Get the full HTTP URL to download a ``mxc://`` URI.

        Args:
            mxc_uri: The MXC URI whose full URL to get.
            download_type: The type of download ("download" or "thumbnail").
            file_name: Optionally, a file name to include in the download URL.

        Returns:
            The full HTTP URL.

        Raises:
            ValueError: If `mxc_uri` doesn't begin with ``mxc://`` or doesn't consist of
                exactly a server name and a media ID.

        Examples:
            >>> api = HTTPAPI(base_url="https://matrix-client.matrix.org", ...)
            >>> api.get_download_url("mxc://matrix.org/pqjkOuKZ1ZKRULWXgz2IVZV6")
            "https://matrix-client.matrix.org/_matrix/media/v3/download/matrix.org/pqjkOuKZ1ZKRULWXgz2IVZV6"
            >>> api.get_download_url("mxc://matrix.org/pqjkOuKZ1ZKRULWXgz2IVZV6", file_name="hello.png")
            "https://matrix-client.matrix.org/_matrix/media/v3/download/matrix.org/pqjkOuKZ1ZKRULWXgz2IVZV6/hello.png"
        """
        server_name, media_id = self.parse_mxc_uri(mxc_uri)
        url = self.base_url / str(APIPath.MEDIA) / "v3" / download_type / server_name / media_id
        if file_name:
            url /= file_name
        return url

    @staticmethod
    def parse_mxc_uri(mxc_uri: str) -> tuple[str, str]:
        """
        Parse a ``mxc://`` URI.

        Args:
            mxc_uri: The MXC URI to parse.

        Returns:
            A tuple containing the server and media ID of the MXC URI.

        Raises:
            ValueError: If `mxc_uri` doesn't begin with ``mxc://``, or if the part after
                the prefix isn't exactly ``<server name>/<media ID>``.
        """
        if not mxc_uri.startswith("mxc://"):
            raise ValueError("MXC URI did not begin with `mxc://`")
        parts = mxc_uri[6:].split("/")
        # Checked explicitly so that malformed URIs get a descriptive message
        # instead of a tuple-unpacking error.
        if len(parts) != 2:
            raise ValueError("MXC URI must be in the form `mxc://<server name>/<media ID>`")
        server_name, media_id = parts
        return server_name, media_id