130 lines
3.8 KiB
Python
130 lines
3.8 KiB
Python
from __future__ import annotations
|
|
|
|
from typing import Awaitable, Callable, TypeVar
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
import time
|
|
import urllib.request
|
|
|
|
from aiohttp import ClientConnectionError
|
|
from yarl import URL
|
|
|
|
from mautrix.util.logging import TraceLogger
|
|
|
|
try:
    from aiohttp_socks import ProxyConnectionError, ProxyError, ProxyTimeoutError
except ImportError:
    # aiohttp_socks is an optional dependency. Without it, define a single
    # stand-in exception class and alias the other two names to it so that
    # code referring to any of the three names keeps working.

    class ProxyError(Exception):
        pass

    ProxyConnectionError = ProxyError
    ProxyTimeoutError = ProxyError
|
|
|
|
# Exception types that proxy_with_retry() treats as retryable by default
# (used as the default value of its ``retryable_exceptions`` parameter).
RETRYABLE_PROXY_EXCEPTIONS = (
    ProxyError,
    ProxyTimeoutError,
    ProxyConnectionError,
    ClientConnectionError,
    ConnectionError,
    asyncio.TimeoutError,
)
|
|
|
|
|
|
class ProxyHandler:
|
|
current_proxy_url: str | None = None
|
|
log = logging.getLogger("mau.proxy")
|
|
|
|
def __init__(self, api_url: str | None) -> None:
|
|
self.api_url = api_url
|
|
|
|
def get_proxy_url_from_api(self, reason: str | None = None) -> str | None:
|
|
assert self.api_url is not None
|
|
|
|
api_url = str(URL(self.api_url).update_query({"reason": reason} if reason else {}))
|
|
|
|
# NOTE: using urllib.request to intentionally block the whole bridge until the proxy change applied
|
|
request = urllib.request.Request(api_url, method="GET")
|
|
self.log.debug("Requesting proxy from: %s", api_url)
|
|
|
|
try:
|
|
with urllib.request.urlopen(request) as f:
|
|
response = json.loads(f.read().decode())
|
|
except Exception:
|
|
self.log.exception("Failed to retrieve proxy from API")
|
|
return self.current_proxy_url
|
|
else:
|
|
return response["proxy_url"]
|
|
|
|
def update_proxy_url(self, reason: str | None = None) -> bool:
|
|
old_proxy = self.current_proxy_url
|
|
new_proxy = None
|
|
|
|
if self.api_url is not None:
|
|
new_proxy = self.get_proxy_url_from_api(reason)
|
|
else:
|
|
new_proxy = urllib.request.getproxies().get("http")
|
|
|
|
if old_proxy != new_proxy:
|
|
self.log.debug("Set new proxy URL: %s", new_proxy)
|
|
self.current_proxy_url = new_proxy
|
|
return True
|
|
|
|
self.log.debug("Got same proxy URL: %s", new_proxy)
|
|
return False
|
|
|
|
def get_proxy_url(self) -> str | None:
|
|
if not self.current_proxy_url:
|
|
self.update_proxy_url()
|
|
|
|
return self.current_proxy_url
|
|
|
|
|
|
# Result type of the awaitable passed to proxy_with_retry().
T = TypeVar("T")
|
|
|
|
|
|
async def proxy_with_retry(
    name: str,
    func: Callable[[], Awaitable[T]],
    logger: TraceLogger,
    proxy_handler: ProxyHandler,
    on_proxy_change: Callable[[], Awaitable[None]],
    max_retries: int = 10,
    min_wait_seconds: int = 0,
    max_wait_seconds: int = 60,
    multiply_wait_seconds: int = 10,
    retryable_exceptions: tuple[type[Exception], ...] = RETRYABLE_PROXY_EXCEPTIONS,
    reset_after_seconds: int | None = None,
) -> T:
    """Await ``func()`` until it succeeds, retrying on retryable errors.

    Args:
        name: Description of the operation, used in log messages and in the
            reason passed to ``proxy_handler.update_proxy_url()``.
        func: Zero-argument callable returning the awaitable to run.
        logger: Logger that receives a warning for every retried failure.
        proxy_handler: Asked for a new proxy URL from the second counted
            failure onward.
        on_proxy_change: Awaited whenever the proxy handler reports that the
            proxy URL changed.
        max_retries: Number of failures tolerated; the next one is re-raised.
        min_wait_seconds: Lower bound of the delay before a retry.
        max_wait_seconds: Upper bound of the delay before a retry (applied
            after the lower bound).
        multiply_wait_seconds: The delay is the failure count times this value.
        retryable_exceptions: Exception classes that trigger a retry; anything
            else propagates immediately.
        reset_after_seconds: If set, the failure count is reset to zero when
            more than this many seconds passed since the previous failure.

    Returns:
        Whatever ``func()`` resolves to.

    Raises:
        The last retryable exception once ``max_retries`` is exceeded.
    """
    errors = 0
    last_error = 0

    while True:
        try:
            return await func()
        except retryable_exceptions as e:
            errors += 1
            if errors > max_retries:
                raise
            wait = errors * multiply_wait_seconds
            wait = max(wait, min_wait_seconds)
            wait = min(wait, max_wait_seconds)
            logger.warning(
                "%s while trying to %s, retrying in %d seconds",
                e.__class__.__name__,
                name,
                wait,
            )
            if errors > 1 and proxy_handler.update_proxy_url(
                f"{e.__class__.__name__} while trying to {name}"
            ):
                await on_proxy_change()

            # If sufficient time has passed since the previous error, reset the
            # error count. Useful for long running tasks with rare failures.
            if reset_after_seconds is not None:
                now = time.time()
                if last_error and now - last_error > reset_after_seconds:
                    errors = 0
                last_error = now

            # Actually honor the announced backoff. Without this the loop
            # retried immediately and the *_wait_seconds parameters were only
            # cosmetic. Sleeping last keeps the timestamp above tied to the
            # moment of the failure rather than the end of the delay.
            await asyncio.sleep(wait)
|