mirror of https://github.com/home-assistant/core
123 lines
3.4 KiB
Python
123 lines
3.4 KiB
Python
"""Utilities to help with aiohttp."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from http import HTTPStatus
|
|
import io
|
|
from typing import Any
|
|
from urllib.parse import parse_qsl
|
|
|
|
from aiohttp import payload, web
|
|
from aiohttp.typedefs import JSONDecoder
|
|
from multidict import CIMultiDict, MultiDict
|
|
|
|
from .json import json_loads
|
|
|
|
|
|
class MockStreamReader:
    """In-memory stand-in for an aiohttp stream reader."""

    def __init__(self, content: bytes) -> None:
        """Buffer *content* so it can be consumed by successive reads."""
        self._content = io.BytesIO(content)

    async def read(self, byte_count: int = -1) -> bytes:
        """Return the next *byte_count* bytes, or all remaining bytes for -1."""
        buffer = self._content
        return buffer.read() if byte_count == -1 else buffer.read(byte_count)
|
|
|
|
|
|
class MockPayloadWriter:
    """Stand-in payload writer whose hooks deliberately do nothing."""

    def enable_chunking(self) -> None:
        """Accept a request to enable chunking and ignore it."""
        return None

    async def write_headers(self, *args: Any, **kwargs: Any) -> None:
        """Accept headers to write and discard them."""
        return None
|
|
|
|
|
|
# One shared writer instance handed to every MockRequest; sharing is safe
# because MockPayloadWriter holds no state.
_MOCK_PAYLOAD_WRITER: MockPayloadWriter = MockPayloadWriter()
|
|
|
|
|
|
class MockRequest:
    """Mock an aiohttp request.

    Holds an in-memory body together with request attributes (method, url,
    headers, query string, ...) so request-handling code can be exercised
    without a real HTTP connection.
    """

    # Class-level default; every instance overwrites it in __init__.
    mock_source: str | None = None

    def __init__(
        self,
        content: bytes,
        mock_source: str,
        method: str = "GET",
        status: int = HTTPStatus.OK,
        headers: dict[str, str] | None = None,
        query_string: str | None = None,
        url: str = "",
    ) -> None:
        """Initialize a request.

        Args:
            content: Raw request body.
            mock_source: Stored as-is on ``mock_source``.
            method: HTTP method, e.g. ``"GET"``.
            status: Status value stored on the request.
            headers: Request headers, stored in a case-insensitive multidict.
            query_string: Raw query string, parsed on demand by ``query``.
            url: Request URL.
        """
        self.method = method
        self.url = url
        self.status = status
        self.headers: CIMultiDict[str] = CIMultiDict(headers or {})
        self.query_string = query_string or ""
        self.keep_alive = False
        self.version = (1, 1)
        self._content = content
        self.mock_source = mock_source
        # NOTE(review): presumably consumed by aiohttp when a response is
        # prepared against this request — confirm.
        self._payload_writer = _MOCK_PAYLOAD_WRITER
        # Created lazily by ``content`` and then reused, so successive reads
        # advance through the body instead of restarting at byte 0.
        self._stream_reader: MockStreamReader | None = None

    async def _prepare_hook(self, response: Any) -> None:
        """Prepare hook (intentionally a no-op)."""

    @property
    def query(self) -> MultiDict[str]:
        """Return a dictionary with the query variables."""
        return MultiDict(parse_qsl(self.query_string, keep_blank_values=True))

    @property
    def _text(self) -> str:
        """Return the body decoded as UTF-8 text."""
        return self._content.decode("utf-8")

    @property
    def content(self) -> MockStreamReader:
        """Return a stream reader over the raw body.

        The same reader is returned on every access, so consecutive reads
        consume the body like a real request stream rather than starting over
        (which would make chunked read loops never terminate).
        """
        if self._stream_reader is None:
            self._stream_reader = MockStreamReader(self._content)
        return self._stream_reader

    @property
    def body_exists(self) -> bool:
        """Return True if request has HTTP BODY, False otherwise."""
        # Inspect the raw bytes: decoding is unnecessary for an emptiness
        # check and would raise UnicodeDecodeError on non-UTF-8 bodies.
        return bool(self._content)

    async def json(self, loads: JSONDecoder = json_loads) -> Any:
        """Return the body parsed as JSON using *loads*."""
        return loads(self._text)

    async def post(self) -> MultiDict[str]:
        """Return the body parsed as URL-encoded form parameters."""
        return MultiDict(parse_qsl(self._text, keep_blank_values=True))

    async def text(self) -> str:
        """Return the body as text."""
        return self._text
|
|
|
|
|
|
def serialize_response(response: web.Response) -> dict[str, Any]:
    """Serialize an aiohttp response to a dictionary.

    Args:
        response: Response whose status, body and headers are captured.

    Returns:
        A dict with keys ``status``, ``body`` (the body decoded to ``str``,
        or None when the response has no body) and ``headers`` (a plain dict).

    Raises:
        TypeError: If the body is not raw bytes/bytearray or a StringPayload.
    """
    if (body := response.body) is None:
        body_decoded = None
    elif isinstance(body, (bytes, bytearray)):
        # Raw bodies (aiohttp keeps bytes and bytearray as-is) are decoded
        # with the response charset, falling back to UTF-8.
        body_decoded = body.decode(response.charset or "utf-8")
    elif isinstance(body, payload.StringPayload):
        # The payload holds its text already encoded; decode it back using
        # the payload's own encoding.
        body_decoded = body._value.decode(body.encoding or "utf-8")  # noqa: SLF001
    else:
        raise TypeError("Unknown payload encoding")

    return {
        "status": response.status,
        "body": body_decoded,
        # NOTE(review): dict() keeps a single value per repeated header name.
        "headers": dict(response.headers),
    }
|