369 lines
13 KiB
Python
369 lines
13 KiB
Python
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import binascii
|
|
from functools import partial
|
|
import logging
|
|
from struct import Struct
|
|
from typing import TYPE_CHECKING, Any
|
|
|
|
from chacha20poly1305_reuseable import ChaCha20Poly1305Reusable
|
|
from cryptography.exceptions import InvalidTag
|
|
from noise.backends.default import DefaultNoiseBackend # type: ignore[import-untyped]
|
|
from noise.backends.default.ciphers import ( # type: ignore[import-untyped]
|
|
ChaCha20Cipher,
|
|
CryptographyCipher,
|
|
)
|
|
from noise.connection import NoiseConnection # type: ignore[import-untyped]
|
|
from noise.state import CipherState # type: ignore[import-untyped]
|
|
|
|
from ..core import (
|
|
APIConnectionError,
|
|
BadNameAPIError,
|
|
HandshakeAPIError,
|
|
InvalidEncryptionKeyAPIError,
|
|
ProtocolAPIError,
|
|
)
|
|
from .base import _LOGGER, APIFrameHelper
|
|
|
|
if TYPE_CHECKING:
|
|
from ..connection import APIConnection
|
|
|
|
|
|
PACK_NONCE = partial(Struct("<LQ").pack, 0)
|
|
|
|
_bytes = bytes
|
|
|
|
|
|
class ChaCha20CipherReuseable(ChaCha20Cipher): # type: ignore[misc]
|
|
"""ChaCha20 cipher that can be reused."""
|
|
|
|
format_nonce = PACK_NONCE
|
|
|
|
@property
|
|
def klass(self) -> type[ChaCha20Poly1305Reusable]:
|
|
return ChaCha20Poly1305Reusable
|
|
|
|
|
|
class ESPHomeNoiseBackend(DefaultNoiseBackend): # type: ignore[misc]
|
|
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
|
super().__init__(*args, **kwargs)
|
|
self.ciphers["ChaChaPoly"] = ChaCha20CipherReuseable
|
|
|
|
|
|
ESPHOME_NOISE_BACKEND = ESPHomeNoiseBackend()
|
|
|
|
|
|
# This is effectively an enum but we don't want to use an enum
|
|
# because we have a simple dispatch in the data_received method
|
|
# that would be more complicated with an enum and we want to add
|
|
# cdefs for each different state so we have a good test for each
|
|
# state receiving data since we found that the protractor event
|
|
# loop will send use a bytearray instead of bytes was not handled
|
|
# correctly.
|
|
NOISE_STATE_HELLO = 1
|
|
NOISE_STATE_HANDSHAKE = 2
|
|
NOISE_STATE_READY = 3
|
|
NOISE_STATE_CLOSED = 4
|
|
|
|
|
|
NOISE_HELLO = b"\x01\x00\x00"
|
|
|
|
int_ = int
|
|
|
|
|
|
class EncryptCipher:
|
|
"""Wrapper around the ChaCha20Poly1305 cipher for encryption."""
|
|
|
|
__slots__ = ("_nonce", "_encrypt")
|
|
|
|
def __init__(self, cipher_state: CipherState) -> None:
|
|
"""Initialize the cipher wrapper."""
|
|
crypto_cipher: CryptographyCipher = cipher_state.cipher
|
|
cipher: ChaCha20Poly1305Reusable = crypto_cipher.cipher
|
|
self._nonce: int = cipher_state.n
|
|
self._encrypt = cipher.encrypt
|
|
|
|
def encrypt(self, data: _bytes) -> bytes:
|
|
"""Encrypt a frame."""
|
|
ciphertext = self._encrypt(PACK_NONCE(self._nonce), data, None)
|
|
self._nonce += 1
|
|
return ciphertext
|
|
|
|
|
|
class DecryptCipher:
|
|
"""Wrapper around the ChaCha20Poly1305 cipher for decryption."""
|
|
|
|
__slots__ = ("_nonce", "_decrypt")
|
|
|
|
def __init__(self, cipher_state: CipherState) -> None:
|
|
"""Initialize the cipher wrapper."""
|
|
crypto_cipher: CryptographyCipher = cipher_state.cipher
|
|
cipher: ChaCha20Poly1305Reusable = crypto_cipher.cipher
|
|
self._nonce: int = cipher_state.n
|
|
self._decrypt = cipher.decrypt
|
|
|
|
def decrypt(self, data: _bytes) -> bytes:
|
|
"""Decrypt a frame."""
|
|
plaintext = self._decrypt(PACK_NONCE(self._nonce), data, None)
|
|
self._nonce += 1
|
|
return plaintext
|
|
|
|
|
|
class APINoiseFrameHelper(APIFrameHelper):
|
|
"""Frame helper for noise encrypted connections."""
|
|
|
|
__slots__ = (
|
|
"_noise_psk",
|
|
"_expected_name",
|
|
"_state",
|
|
"_server_name",
|
|
"_proto",
|
|
"_encrypt_cipher",
|
|
"_decrypt_cipher",
|
|
)
|
|
|
|
def __init__(
|
|
self,
|
|
connection: APIConnection,
|
|
noise_psk: str,
|
|
expected_name: str | None,
|
|
client_info: str,
|
|
log_name: str,
|
|
) -> None:
|
|
"""Initialize the API frame helper."""
|
|
super().__init__(connection, client_info, log_name)
|
|
self._noise_psk = noise_psk
|
|
self._expected_name = expected_name
|
|
self._state = NOISE_STATE_HELLO
|
|
self._server_name: str | None = None
|
|
self._encrypt_cipher: EncryptCipher | None = None
|
|
self._decrypt_cipher: DecryptCipher | None = None
|
|
self._setup_proto()
|
|
|
|
def close(self) -> None:
|
|
"""Close the connection."""
|
|
# Make sure we set the ready event if its not already set
|
|
# so that we don't block forever on the ready event if we
|
|
# are waiting for the handshake to complete.
|
|
self._set_ready_future_exception(
|
|
APIConnectionError(f"{self._log_name}: Connection closed")
|
|
)
|
|
self._state = NOISE_STATE_CLOSED
|
|
super().close()
|
|
|
|
def _handle_error(self, exc: Exception) -> None:
|
|
"""Handle an error, and provide a good message when during hello."""
|
|
if self._state == NOISE_STATE_HELLO and isinstance(exc, ConnectionResetError):
|
|
original_exc: Exception = exc
|
|
exc = HandshakeAPIError(
|
|
f"{self._log_name}: The connection dropped immediately after encrypted hello; "
|
|
"Try enabling encryption on the device or turning off "
|
|
f"encryption on the client ({self._client_info})."
|
|
)
|
|
exc.__cause__ = original_exc
|
|
elif isinstance(exc, InvalidTag):
|
|
original_exc = exc
|
|
exc = InvalidEncryptionKeyAPIError(
|
|
f"{self._log_name}: Invalid encryption key", self._server_name
|
|
)
|
|
exc.__cause__ = original_exc
|
|
super()._handle_error(exc)
|
|
|
|
def connection_made(self, transport: asyncio.BaseTransport) -> None:
|
|
"""Handle a new connection."""
|
|
super().connection_made(transport)
|
|
self._send_hello_handshake()
|
|
|
|
def data_received(self, data: bytes | bytearray | memoryview) -> None:
|
|
self._add_to_buffer(data)
|
|
while self._buffer_len:
|
|
self._pos = 0
|
|
if (header := self._read(3)) is None:
|
|
return
|
|
preamble = header[0]
|
|
if preamble != 0x01:
|
|
self._handle_error_and_close(
|
|
ProtocolAPIError(
|
|
f"{self._log_name}: Marker byte invalid: {header[0]}"
|
|
)
|
|
)
|
|
return
|
|
msg_size_high = header[1]
|
|
msg_size_low = header[2]
|
|
if (frame := self._read((msg_size_high << 8) | msg_size_low)) is None:
|
|
# The complete frame is not yet available, wait for more data
|
|
# to arrive before continuing, since callback_packet has not
|
|
# been called yet the buffer will not be cleared and the next
|
|
# call to data_received will continue processing the packet
|
|
# at the start of the frame.
|
|
return
|
|
|
|
# asyncio already runs data_received in a try block
|
|
# which will call connection_lost if an exception is raised
|
|
if self._state == NOISE_STATE_READY:
|
|
self._handle_frame(frame)
|
|
elif self._state == NOISE_STATE_HELLO:
|
|
self._handle_hello(frame)
|
|
elif self._state == NOISE_STATE_HANDSHAKE:
|
|
self._handle_handshake(frame)
|
|
else:
|
|
self._handle_closed(frame)
|
|
|
|
self._remove_from_buffer()
|
|
|
|
def _send_hello_handshake(self) -> None:
|
|
"""Send a ClientHello to the server."""
|
|
handshake_frame = self._proto.write_message()
|
|
frame_len = len(handshake_frame) + 1
|
|
header = bytes((0x01, (frame_len >> 8) & 0xFF, frame_len & 0xFF))
|
|
self._write_bytes(
|
|
(NOISE_HELLO, header, b"\x00", handshake_frame),
|
|
_LOGGER.isEnabledFor(logging.DEBUG),
|
|
)
|
|
|
|
def _handle_hello(self, server_hello: bytes) -> None:
|
|
"""Perform the handshake with the server."""
|
|
if not server_hello:
|
|
self._handle_error_and_close(
|
|
HandshakeAPIError(f"{self._log_name}: ServerHello is empty")
|
|
)
|
|
return
|
|
|
|
# First byte of server hello is the protocol the server chose
|
|
# for this session. Currently only 0x01 (Noise_NNpsk0_25519_ChaChaPoly_SHA256)
|
|
# exists.
|
|
chosen_proto = server_hello[0]
|
|
if chosen_proto != 0x01:
|
|
self._handle_error_and_close(
|
|
HandshakeAPIError(
|
|
f"{self._log_name}: Unknown protocol selected by client {chosen_proto}"
|
|
)
|
|
)
|
|
return
|
|
|
|
# Check name matches expected name (for noise sessions, this is done
|
|
# during hello phase before a connection is set up)
|
|
# Server name is encoded as a string followed by a zero byte after the chosen proto byte
|
|
server_name_i = server_hello.find(b"\0", 1)
|
|
if server_name_i != -1:
|
|
# server name found, this extension was added in 2022.2
|
|
server_name = server_hello[1:server_name_i].decode()
|
|
self._server_name = server_name
|
|
|
|
if self._expected_name is not None and self._expected_name != server_name:
|
|
self._handle_error_and_close(
|
|
BadNameAPIError(
|
|
f"{self._log_name}: Server sent a different name '{server_name}'",
|
|
server_name,
|
|
)
|
|
)
|
|
return
|
|
|
|
self._state = NOISE_STATE_HANDSHAKE
|
|
|
|
def _decode_noise_psk(self) -> bytes:
|
|
"""Decode the given noise psk from base64 format to raw bytes."""
|
|
psk = self._noise_psk
|
|
server_name = self._server_name
|
|
try:
|
|
psk_bytes = binascii.a2b_base64(psk)
|
|
except ValueError:
|
|
raise InvalidEncryptionKeyAPIError(
|
|
f"{self._log_name}: Malformed PSK `{psk}`, expected "
|
|
"base64-encoded value",
|
|
server_name,
|
|
)
|
|
if len(psk_bytes) != 32:
|
|
raise InvalidEncryptionKeyAPIError(
|
|
f"{self._log_name}:Malformed PSK `{psk}`, expected"
|
|
f" 32-bytes of base64 data",
|
|
server_name,
|
|
)
|
|
return psk_bytes
|
|
|
|
def _setup_proto(self) -> None:
|
|
"""Set up the noise protocol."""
|
|
proto = NoiseConnection.from_name(
|
|
b"Noise_NNpsk0_25519_ChaChaPoly_SHA256", backend=ESPHOME_NOISE_BACKEND
|
|
)
|
|
proto.set_as_initiator()
|
|
proto.set_psks(self._decode_noise_psk())
|
|
proto.set_prologue(b"NoiseAPIInit\x00\x00")
|
|
proto.start_handshake()
|
|
self._proto = proto
|
|
|
|
def _error_on_incorrect_preamble(self, msg: bytes) -> None:
|
|
"""Handle an incorrect preamble."""
|
|
explanation = msg[1:].decode()
|
|
if explanation != "Handshake MAC failure":
|
|
exc = HandshakeAPIError(
|
|
f"{self._log_name}: Handshake failure: {explanation}"
|
|
)
|
|
else:
|
|
exc = InvalidEncryptionKeyAPIError(
|
|
f"{self._log_name}: Invalid encryption key", self._server_name
|
|
)
|
|
self._handle_error_and_close(exc)
|
|
|
|
def _handle_handshake(self, msg: bytes) -> None:
|
|
if msg[0] != 0:
|
|
self._error_on_incorrect_preamble(msg)
|
|
return
|
|
self._proto.read_message(msg[1:])
|
|
self._state = NOISE_STATE_READY
|
|
noise_protocol = self._proto.noise_protocol
|
|
self._decrypt_cipher = DecryptCipher(noise_protocol.cipher_state_decrypt) # pylint: disable=no-member
|
|
self._encrypt_cipher = EncryptCipher(noise_protocol.cipher_state_encrypt) # pylint: disable=no-member
|
|
self.ready_future.set_result(None)
|
|
|
|
def write_packets(
|
|
self, packets: list[tuple[int, bytes]], debug_enabled: bool
|
|
) -> None:
|
|
"""Write a packets to the socket.
|
|
|
|
Packets are in the format of tuple[protobuf_type, protobuf_data]
|
|
"""
|
|
if TYPE_CHECKING:
|
|
assert self._encrypt_cipher is not None, "Handshake should be complete"
|
|
|
|
out: list[bytes] = []
|
|
for packet in packets:
|
|
type_: int = packet[0]
|
|
data: bytes = packet[1]
|
|
data_len = len(data)
|
|
data_header = bytes(
|
|
(
|
|
(type_ >> 8) & 0xFF,
|
|
type_ & 0xFF,
|
|
(data_len >> 8) & 0xFF,
|
|
data_len & 0xFF,
|
|
)
|
|
)
|
|
frame = self._encrypt_cipher.encrypt(data_header + data)
|
|
frame_len = len(frame)
|
|
header = bytes((0x01, (frame_len >> 8) & 0xFF, frame_len & 0xFF))
|
|
out.append(header)
|
|
out.append(frame)
|
|
|
|
self._write_bytes(out, debug_enabled)
|
|
|
|
def _handle_frame(self, frame: bytes) -> None:
|
|
"""Handle an incoming frame."""
|
|
if TYPE_CHECKING:
|
|
assert self._decrypt_cipher is not None, "Handshake should be complete"
|
|
msg = self._decrypt_cipher.decrypt(frame)
|
|
# Message layout is
|
|
# 2 bytes: message type
|
|
# 2 bytes: message length
|
|
# N bytes: message data
|
|
type_high = msg[0]
|
|
type_low = msg[1]
|
|
msg_type = (type_high << 8) | type_low
|
|
payload = msg[4:]
|
|
self._connection.process_packet(msg_type, payload)
|
|
|
|
def _handle_closed(self, frame: bytes) -> None: # pylint: disable=unused-argument
|
|
"""Handle a closed frame."""
|
|
self._handle_error(ProtocolAPIError(f"{self._log_name}: Connection closed"))
|