197 lines
6.6 KiB
Python
197 lines
6.6 KiB
Python
from __future__ import annotations
|
|
|
|
from abc import abstractmethod
|
|
import asyncio
|
|
from collections.abc import Iterable
|
|
import logging
|
|
from typing import TYPE_CHECKING, Callable, cast
|
|
|
|
from ..core import SocketClosedAPIError
|
|
|
|
if TYPE_CHECKING:
|
|
from ..connection import APIConnection
|
|
|
|
# Module-level logger, named after this module.
_LOGGER = logging.getLogger(__name__)

# Exception types grouped under a single name so they can be handled together.
SOCKET_ERRORS = (ConnectionResetError, asyncio.IncompleteReadError, OSError, TimeoutError)

# Aliases for the builtin ``int`` and ``bytes`` types; they are referenced in
# annotations further down this module.
_int, _bytes = int, bytes
|
|
|
|
|
|
class APIFrameHelper:
    """Helper class to handle the API frame protocol.

    Provides the pieces shared by concrete frame helpers: a receive buffer
    with a read cursor (``_buffer`` / ``_buffer_len`` / ``_pos``), varuint
    decoding, transport bookkeeping, and error reporting to the owning
    connection. ``write_packets`` must be supplied by a subclass.
    """

    __slots__ = (
        "_loop",
        "_connection",
        "_transport",
        "_writelines",
        "ready_future",
        "_buffer",
        "_buffer_len",
        "_pos",
        "_client_info",
        "_log_name",
    )

    def __init__(
        self,
        connection: APIConnection,
        client_info: str,
        log_name: str,
    ) -> None:
        """Initialize the API frame helper.

        Args:
            connection: Owner that is notified of fatal errors through
                ``report_fatal_error``.
            client_info: Stored as-is on the instance; not used by this class.
            log_name: Prefix used in log and error messages.
        """
        loop = asyncio.get_event_loop()
        self._loop = loop
        self._connection = connection
        # Both are populated in connection_made() and cleared in close().
        self._transport: asyncio.Transport | None = None
        self._writelines: (
            None | (Callable[[Iterable[bytes | bytearray | memoryview[int]]], None])
        ) = None
        # Failed via _set_ready_future_exception() when an error is handled;
        # nothing in this class sets a result on it.
        self.ready_future = self._loop.create_future()
        # Receive buffer: _buffer_len is the number of buffered bytes and
        # _pos is the read cursor into _buffer.
        self._buffer: bytes | None = None
        self._buffer_len = 0
        self._pos = 0
        self._client_info = client_info
        self._log_name = log_name

    def set_log_name(self, log_name: str) -> None:
        """Set the log name."""
        self._log_name = log_name

    def _set_ready_future_exception(self, exc: Exception | type[Exception]) -> None:
        """Fail ``ready_future`` with ``exc`` unless it is already done."""
        if not self.ready_future.done():
            self.ready_future.set_exception(exc)

    def _add_to_buffer(self, data: bytes | bytearray | memoryview) -> None:
        """Add data to the buffer.

        Anything that is not exactly ``bytes`` is copied into a ``bytes``
        object first, so ``_buffer`` always holds plain ``bytes``.
        """
        # This should not be isinstance(data, bytes) because we want to
        # explicitly check for bytes and not for subclasses of bytes
        if type(data) is not bytes:  # pylint: disable=unidiomatic-typecheck
            # Protractor sends a bytearray, so we need to convert it to bytes
            # https://github.com/esphome/issues/issues/5117
            bytes_data = bytes(data)
        else:
            bytes_data = data
        if self._buffer_len == 0:
            # This is the best case scenario, we don't have to copy the data
            # and can just use the buffer directly. This is the most common
            # case as well.
            self._buffer = bytes_data
        else:
            if TYPE_CHECKING:
                assert self._buffer is not None, "Buffer should be set"
            # This is the worst case scenario, we have to copy the bytes_data
            # and can't just use the buffer directly. This is also very
            # uncommon since we usually read the entire frame at once.
            self._buffer += bytes_data
        self._buffer_len += len(bytes_data)

    def _remove_from_buffer(self) -> None:
        """Remove the bytes before the read cursor from the buffer.

        ``_pos`` itself is left untouched by this method.
        """
        end_of_frame_pos = self._pos
        self._buffer_len -= end_of_frame_pos
        if self._buffer_len == 0:
            # This is the best case scenario, we can just set the buffer to None
            # and don't have to copy the data. This is the most common case as well.
            self._buffer = None
            return
        if TYPE_CHECKING:
            assert self._buffer is not None, "Buffer should be set"
        # This is the worst case scenario, we have to copy the data
        # and can't just use the buffer directly. This should only happen
        # when we read multiple frames at once because the event loop
        # is blocked and we cannot pull the data out of the buffer fast enough.
        self._buffer = self._buffer[end_of_frame_pos:]

    def _read(self, length: _int) -> bytes | None:
        """Read exactly length bytes from the buffer or None if all the bytes are not yet available.

        The read cursor only advances when the read succeeds.
        """
        original_pos = self._pos
        new_pos = original_pos + length
        if self._buffer_len < new_pos:
            return None
        self._pos = new_pos
        if TYPE_CHECKING:
            assert self._buffer is not None, "Buffer should be set"
        return self._buffer[original_pos:new_pos]

    def _read_varuint(self) -> _int:
        """Read a varuint from the buffer or -1 if the buffer runs out of bytes.

        Each byte contributes its low 7 bits, least-significant group first;
        a clear high bit marks the final byte. The read cursor advances past
        every byte consumed, including when -1 is returned.
        """
        if TYPE_CHECKING:
            assert self._buffer is not None, "Buffer should be set"
        result = 0
        bitpos = 0
        while self._buffer_len > self._pos:
            val = self._buffer[self._pos]
            self._pos += 1
            result |= (val & 0x7F) << bitpos
            if (val & 0x80) == 0:
                return result
            bitpos += 7
        return -1

    @abstractmethod
    def write_packets(
        self, packets: list[tuple[int, bytes]], debug_enabled: bool
    ) -> None:
        """Write packets to the socket.

        Packets are in the format of tuple[protobuf_type, protobuf_data]

        This base implementation does nothing. The class declares no ABC
        base, so the ``abstractmethod`` marker is not enforced at
        instantiation time.
        """

    def connection_made(self, transport: asyncio.BaseTransport) -> None:
        """Handle a new connection by storing the transport and its ``writelines``."""
        self._transport = cast(asyncio.Transport, transport)
        self._writelines = self._transport.writelines

    def _handle_error_and_close(self, exc: Exception) -> None:
        """Report ``exc`` via ``_handle_error`` and then close the transport."""
        self._handle_error(exc)
        self.close()

    def _handle_error(self, exc: Exception) -> None:
        """Fail ``ready_future`` with ``exc`` and report it to the connection."""
        self._set_ready_future_exception(exc)
        self._connection.report_fatal_error(exc)

    def connection_lost(self, exc: Exception | None) -> None:
        """Handle the connection being lost.

        When ``exc`` is None a ``SocketClosedAPIError`` is reported instead.
        """
        self._handle_error(
            exc or SocketClosedAPIError(f"{self._log_name}: Connection lost")
        )

    def eof_received(self) -> bool | None:
        """Handle EOF received by reporting a ``SocketClosedAPIError``.

        Always returns False.
        """
        self._handle_error(SocketClosedAPIError(f"{self._log_name}: EOF received"))
        return False

    def close(self) -> None:
        """Close the connection and drop the transport and writer references."""
        if self._transport:
            self._transport.close()
        self._transport = None
        self._writelines = None

    def pause_writing(self) -> None:
        """Stub: intentionally does nothing."""

    def resume_writing(self) -> None:
        """Stub: intentionally does nothing."""

    def _write_bytes(self, data: Iterable[_bytes], debug_enabled: bool) -> None:
        """Write bytes to the socket.

        Args:
            data: Chunks handed to the transport's ``writelines``.
            debug_enabled: When true, the joined frame is logged as hex
                before it is written.
        """
        if debug_enabled:
            if not isinstance(data, (list, tuple)):
                # Joining for the log line would exhaust a one-shot iterable
                # and leave nothing to write, so materialise it once here.
                data = list(data)
            _LOGGER.debug(
                "%s: Sending frame: [%s]", self._log_name, b"".join(data).hex()
            )

        if TYPE_CHECKING:
            assert self._writelines is not None, "Writer is not set"

        self._writelines(data)
|