core/homeassistant/components/mqtt/models.py

425 lines
14 KiB
Python

"""Models used by multiple MQTT modules."""
from __future__ import annotations
from ast import literal_eval
import asyncio
from collections import deque
from collections.abc import Callable
from dataclasses import dataclass, field
from enum import StrEnum
import logging
from typing import TYPE_CHECKING, Any, TypedDict
from homeassistant.const import ATTR_ENTITY_ID, ATTR_NAME, Platform
from homeassistant.core import CALLBACK_TYPE, callback
from homeassistant.exceptions import ServiceValidationError, TemplateError
from homeassistant.helpers import template
from homeassistant.helpers.entity import Entity
from homeassistant.helpers.service_info.mqtt import ReceivePayloadType
from homeassistant.helpers.typing import (
ConfigType,
DiscoveryInfoType,
TemplateVarsType,
VolSchemaType,
)
from homeassistant.util.hass_dict import HassKey
if TYPE_CHECKING:
from paho.mqtt.client import MQTTMessage
from .client import MQTT, Subscription
from .debug_info import TimestampedPublishMessage
from .device_trigger import Trigger
from .discovery import MQTTDiscoveryPayload
from .tag import MQTTTagScanner
from .const import DOMAIN, TEMPLATE_ERRORS
class PayloadSentinel(StrEnum):
    """Sentinel for `async_render_with_possible_json_value`.

    ``NONE`` is the default of the ``default`` argument of
    ``MqttValueTemplate.async_render_with_possible_json_value`` and marks that
    the caller supplied no default, so the template is rendered without one.
    ``DEFAULT`` is not referenced in this module; NOTE(review): presumably
    used by callers elsewhere — confirm its meaning there.
    """

    NONE = "none"
    DEFAULT = "default"
# Module-level logger.
_LOGGER = logging.getLogger(__name__)

# Template variable name under which MqttCommandTemplate and MqttValueTemplate
# expose the bound entity's template state object.
ATTR_THIS = "this"

# Payload types accepted for an outgoing (published) MQTT message.
type PublishPayloadType = str | bytes | int | float | None
def convert_outgoing_mqtt_payload(
    payload: PublishPayloadType,
) -> PublishPayloadType:
    """Ensure correct raw MQTT payload is passed as bytes for publishing.

    A string that looks like a Python bytes literal (it starts with ``b'`` or
    ``b"``) is evaluated with ``ast.literal_eval``. When that yields a
    ``bytes`` object, those bytes are returned.

    Every other payload is returned unchanged. This includes non-strings,
    strings that fail to evaluate, and strings that evaluate to something
    other than ``bytes``.
    """
    if not isinstance(payload, str) or not payload.startswith(("b'", 'b"')):
        return payload
    # literal_eval only accepts Python literal syntax; malformed input is
    # passed through as-is instead of raising.
    try:
        evaluated = literal_eval(payload)
    except (ValueError, TypeError, SyntaxError, MemoryError):
        return payload
    return evaluated if isinstance(evaluated, bytes) else payload
@dataclass
class PublishMessage:
    """MQTT Message for publishing.

    Plain (mutable) dataclass holding the fields of one outgoing message.
    """

    # Topic the message is published to.
    topic: str
    # Payload; see PublishPayloadType for the accepted types.
    payload: PublishPayloadType
    # MQTT quality-of-service level.
    qos: int
    # MQTT retain flag.
    retain: bool
# eq=False so equality (and hashing) fall back to object identity (id()),
# since the client will only generate one instance of this object per
# message/subscribed_topic.
@dataclass(slots=True, frozen=True, eq=False)
class ReceiveMessage:
    """MQTT Message received.

    Immutable (frozen) record of a received message as delivered for one
    subscribed topic.
    """

    topic: str
    payload: ReceivePayloadType
    qos: int
    retain: bool
    # Topic of the subscription this instance was created for (presumably the
    # subscription filter, which may differ from ``topic`` — confirm).
    subscribed_topic: str
    # Receive time; NOTE(review): unit/clock not visible here — confirm.
    timestamp: float
# Signature of a callback that handles one received MQTT message.
type MessageCallbackType = Callable[[ReceiveMessage], None]
class SubscriptionDebugInfo(TypedDict):
    """Class for holding subscription debug info."""

    # Received messages kept for debugging; any length bound is set by
    # whoever creates the deque.
    messages: deque[ReceiveMessage]
    # Message counter; NOTE(review): presumably the total received, which can
    # exceed len(messages) if the deque is bounded — confirm.
    count: int
class EntityDebugInfo(TypedDict):
    """Class for holding entity based debug info."""

    # Subscription debug info keyed by a string (presumably the subscribed
    # topic — confirm).
    subscriptions: dict[str, SubscriptionDebugInfo]
    # Discovery data of the entity.
    discovery_data: DiscoveryInfoType
    # Timestamped published messages nested under two string keys;
    # NOTE(review): the meaning of the keys is not visible here — confirm.
    transmitted: dict[str, dict[str, deque[TimestampedPublishMessage]]]
class TriggerDebugInfo(TypedDict):
    """Class for holding trigger based debug info."""

    # Id of the device the trigger is associated with.
    device_id: str
    # Discovery data of the trigger.
    discovery_data: DiscoveryInfoType
class PendingDiscovered(TypedDict):
    """Pending discovered items."""

    # Queued discovery payloads that have not been processed yet.
    pending: deque[MQTTDiscoveryPayload]
    # Unsubscribe callback; NOTE(review): what it detaches is decided by the
    # creator of this entry — confirm at the call site.
    unsub: CALLBACK_TYPE
class MqttOriginInfo(TypedDict, total=False):
    """Integration info of discovered entity.

    Declared with ``total=False``, so every key is optional.
    """

    name: str
    manufacturer: str
    # Software version.
    sw_version: str
    # Hardware version.
    hw_version: str
    support_url: str
class MqttCommandTemplateException(ServiceValidationError):
    """Raised when rendering an MQTT command template fails.

    Carries translation metadata (domain, key and placeholders) for the
    user-facing error. ``str()`` returns a preformatted, log-friendly message
    naming:

    - the underlying error;
    - the entity (if any);
    - the template;
    - the payload value.
    """

    _message: str

    def __init__(
        self,
        *args: object,
        base_exception: Exception,
        command_template: str,
        value: PublishPayloadType,
        entity_id: str | None = None,
    ) -> None:
        """Initialize exception.

        Args:
            base_exception: Error raised while rendering. It is passed to the
                base class as the first positional argument.
            command_template: Source text of the template that failed.
            value: Value that was being rendered.
            entity_id: Entity the template belongs to, if any.
        """
        super().__init__(base_exception, *args)
        self.translation_domain = DOMAIN
        self.translation_key = "command_template_error"
        self.translation_placeholders = {
            "error": str(base_exception),
            "entity_id": str(entity_id),
            "command_template": command_template,
        }
        error_name = type(base_exception).__name__
        entity_part = f" for entity '{entity_id}'" if entity_id is not None else ""
        self._message = (
            f"{error_name}: {base_exception} rendering template{entity_part}"
            f", template: '{command_template}' and payload: {value!s}"
        )

    def __str__(self) -> str:
        """Return the preformatted message built in ``__init__``."""
        return self._message
class MqttCommandTemplate:
    """Class for rendering MQTT payload with command templates."""

    def __init__(
        self,
        command_template: template.Template | None,
        *,
        entity: Entity | None = None,
    ) -> None:
        """Instantiate a command template.

        Args:
            command_template: Template used to render outgoing payloads, or
                ``None`` to pass values through unchanged.
            entity: Optional entity whose id, name and template state are
                exposed to the template.
        """
        # Created lazily on the first render that has an entity (see async_render).
        self._template_state: template.TemplateStateFromEntityId | None = None
        self._command_template = command_template
        self._entity = entity

    @callback
    def async_render(
        self,
        value: PublishPayloadType = None,
        variables: TemplateVarsType = None,
    ) -> PublishPayloadType:
        """Render or convert the command template with given value or variables.

        Args:
            value: Exposed to the template as ``value``. It is returned
                unchanged when no command template is configured.
            variables: Extra template variables. They are applied last, so
                they can override ``value`` and the entity variables.

        Returns:
            The rendered payload, passed through
            ``convert_outgoing_mqtt_payload``. A rendered bytes literal such
            as ``b'...'`` therefore becomes ``bytes``.

        Raises:
            MqttCommandTemplateException: If rendering raises ``TemplateError``.
        """
        if self._command_template is None:
            return value

        values: dict[str, Any] = {"value": value}
        if self._entity:
            values[ATTR_ENTITY_ID] = self._entity.entity_id
            values[ATTR_NAME] = self._entity.name
            if not self._template_state and self._command_template.hass is not None:
                # Build the state wrapper from the same hass instance the guard
                # checked (the template's), as MqttValueTemplate does. Using the
                # entity's hass here would rely on a value the guard never verified.
                self._template_state = template.TemplateStateFromEntityId(
                    self._command_template.hass, self._entity.entity_id
                )
            values[ATTR_THIS] = self._template_state
        if variables is not None:
            values.update(variables)
        _LOGGER.debug(
            "Rendering outgoing payload with variables %s and %s",
            values,
            self._command_template,
        )
        try:
            rendered = self._command_template.async_render(values, parse_result=False)
        except TemplateError as exc:
            raise MqttCommandTemplateException(
                base_exception=exc,
                command_template=self._command_template.template,
                value=value,
                entity_id=self._entity.entity_id if self._entity is not None else None,
            ) from exc
        return convert_outgoing_mqtt_payload(rendered)
class MqttValueTemplateException(TemplateError):
    """Raised when rendering an MQTT value template fails.

    ``str()`` returns a single log-friendly line naming:

    - the underlying error;
    - the entity (if any);
    - the template;
    - the default value (omitted when it is ``PayloadSentinel.NONE``);
    - the payload.
    """

    _message: str

    def __init__(
        self,
        *args: object,
        base_exception: Exception,
        value_template: str,
        default: ReceivePayloadType | PayloadSentinel,
        payload: ReceivePayloadType,
        entity_id: str | None = None,
    ) -> None:
        """Initialize exception.

        Args:
            base_exception: Error raised while rendering. It is passed to the
                base class as the first positional argument.
            value_template: Source text of the template that failed.
            default: Default value used for the render, or
                ``PayloadSentinel.NONE`` when none was given.
            payload: Received payload that was being rendered.
            entity_id: Entity the template belongs to, if any.
        """
        super().__init__(base_exception, *args)
        error_name = type(base_exception).__name__
        entity_part = f" for entity '{entity_id}'" if entity_id is not None else ""
        default_part = (
            f", default value: {default!s}"
            if default is not PayloadSentinel.NONE
            else ""
        )
        self._message = (
            f"{error_name}: {base_exception} rendering template{entity_part}"
            f", template: '{value_template}'{default_part} and payload: {payload!s}"
        )

    def __str__(self) -> str:
        """Return the preformatted message built in ``__init__``."""
        return self._message
class MqttValueTemplate:
    """Class for rendering MQTT value template with possible json values."""

    def __init__(
        self,
        value_template: template.Template | None,
        *,
        entity: Entity | None = None,
        config_attributes: TemplateVarsType = None,
    ) -> None:
        """Instantiate a value template.

        Args:
            value_template: Template applied to received payloads, or ``None``
                to pass payloads through unchanged.
            entity: Optional entity whose id, name and template state are
                exposed to the template.
            config_attributes: Optional extra template variables merged into
                every render.
        """
        # Created lazily on the first render that has an entity.
        self._template_state: template.TemplateStateFromEntityId | None = None
        self._value_template = value_template
        self._config_attributes = config_attributes
        self._entity = entity

    @callback
    def async_render_with_possible_json_value(
        self,
        payload: ReceivePayloadType,
        default: ReceivePayloadType | PayloadSentinel = PayloadSentinel.NONE,
        variables: TemplateVarsType = None,
    ) -> ReceivePayloadType:
        """Render with possible json value or pass-through a received MQTT value.

        Template variables are merged in this order, with later ones
        overriding earlier ones:

        1. ``variables``;
        2. the configured attributes;
        3. the entity variables (``ATTR_ENTITY_ID``, ``ATTR_NAME``,
           ``ATTR_THIS``).

        Args:
            payload: Received payload to render.
            default: Default value forwarded to the template helper, or
                ``PayloadSentinel.NONE`` to call it without a default.
            variables: Extra template variables.

        Returns:
            ``payload`` unchanged when no template is configured; otherwise
            the template's rendered result.

        Raises:
            MqttValueTemplateException: If rendering raises one of
                ``TEMPLATE_ERRORS``.
        """
        if self._value_template is None:
            return payload

        values: dict[str, Any] = {}
        if variables is not None:
            values.update(variables)
        if self._config_attributes is not None:
            values.update(self._config_attributes)
        if self._entity:
            values[ATTR_ENTITY_ID] = self._entity.entity_id
            values[ATTR_NAME] = self._entity.name
            if not self._template_state and self._value_template.hass:
                self._template_state = template.TemplateStateFromEntityId(
                    self._value_template.hass, self._entity.entity_id
                )
            values[ATTR_THIS] = self._template_state

        # Positional arguments for the template helper. ``default`` is only
        # forwarded when the caller supplied one; with PayloadSentinel.NONE the
        # helper is called without a default argument at all.
        render_args: tuple[Any, ...]
        if default is PayloadSentinel.NONE:
            _LOGGER.debug(
                "Rendering incoming payload '%s' with variables %s and %s",
                payload,
                values,
                self._value_template,
            )
            render_args = (payload,)
        else:
            _LOGGER.debug(
                (
                    "Rendering incoming payload '%s' with variables %s with default value"
                    " '%s' and %s"
                ),
                payload,
                values,
                default,
                self._value_template,
            )
            render_args = (payload, default)

        try:
            return self._value_template.async_render_with_possible_json_value(
                *render_args, variables=values
            )
        except TEMPLATE_ERRORS as exc:
            raise MqttValueTemplateException(
                base_exception=exc,
                value_template=self._value_template.template,
                default=default,
                payload=payload,
                entity_id=self._entity.entity_id if self._entity else None,
            ) from exc
class EntityTopicState:
    """Manage entity state write requests for subscribed topics.

    ``write_state_request`` collects requests keyed by ``entity_id``.
    ``process_write_state_requests`` flushes them. Repeated requests for the
    same entity before a flush collapse into a single write.
    """

    def __init__(self) -> None:
        """Start with no pending write requests."""
        # Pending requests: entity_id -> entity whose state should be written.
        self.subscribe_calls: dict[str, Entity] = {}

    @callback
    def process_write_state_requests(self, msg: MQTTMessage) -> None:
        """Write the state of every entity that has a pending request.

        Pending entries are popped one at a time until none remain, so
        requests added while flushing are handled in the same call. An
        exception from one entity is logged, together with ``msg``'s topic
        and payload, and does not stop the remaining writes.
        """
        while self.subscribe_calls:
            pending_id, pending_entity = self.subscribe_calls.popitem()
            try:
                pending_entity.async_write_ha_state()
            # Broad on purpose: one failing entity must not block the others.
            except Exception:
                _LOGGER.exception(
                    "Exception raised while updating state of %s, topic: "
                    "'%s' with payload: %s",
                    pending_id,
                    msg.topic,
                    msg.payload,
                )

    @callback
    def write_state_request(self, entity: Entity) -> None:
        """Queue a state write for ``entity``, replacing any pending one."""
        key = entity.entity_id
        self.subscribe_calls[key] = entity
@dataclass
class MqttData:
    """Keep the MQTT entry data.

    Runtime state shared by the MQTT modules. The ``DATA_MQTT`` key in this
    module is typed for this class. Container fields use
    ``field(default_factory=...)`` so that every instance gets its own objects.
    """

    # MQTT client (``.client.MQTT``); required.
    client: MQTT
    # Configuration items; required. NOTE(review): their source is not
    # visible here — confirm.
    config: list[ConfigType]
    # Debug info per entity, keyed by a string (presumably entity_id).
    debug_info_entities: dict[str, EntityDebugInfo] = field(default_factory=dict)
    # Debug info per trigger, keyed by a string pair (presumably a discovery
    # hash — confirm).
    debug_info_triggers: dict[tuple[str, str], TriggerDebugInfo] = field(
        default_factory=dict
    )
    # Device triggers keyed by a string id.
    device_triggers: dict[str, Trigger] = field(default_factory=dict)
    # Lock; by its name it guards config-flow data updates — confirm at call sites.
    data_config_flow_lock: asyncio.Lock = field(default_factory=asyncio.Lock)
    # Discovery bookkeeping keyed by string pairs (presumably discovery hashes).
    discovery_already_discovered: set[tuple[str, str]] = field(default_factory=set)
    discovery_pending_discovered: dict[tuple[str, str], PendingDiscovered] = field(
        default_factory=dict
    )
    discovery_registry_hooks: dict[tuple[str, str], CALLBACK_TYPE] = field(
        default_factory=dict
    )
    # Unsubscribe callbacks.
    discovery_unsubscribe: list[CALLBACK_TYPE] = field(default_factory=list)
    integration_unsubscribe: dict[str, CALLBACK_TYPE] = field(default_factory=dict)
    # Presumably the time of the last discovery; unit/clock not visible here.
    last_discovery: float = 0.0
    # Platforms recorded as loaded.
    platforms_loaded: set[Platform | str] = field(default_factory=set)
    # Reload support: dispatcher callbacks, handlers and schemas.
    reload_dispatchers: list[CALLBACK_TYPE] = field(default_factory=list)
    reload_handlers: dict[str, CALLBACK_TYPE] = field(default_factory=dict)
    reload_schema: dict[str, VolSchemaType] = field(default_factory=dict)
    # Batches entity state-write requests (see EntityTopicState).
    state_write_requests: EntityTopicState = field(default_factory=EntityTopicState)
    # Subscriptions to restore; presumably re-subscribed later — confirm.
    subscriptions_to_restore: set[Subscription] = field(default_factory=set)
    # Tag scanners in a two-level mapping keyed by strings.
    tags: dict[str, dict[str, MQTTTagScanner]] = field(default_factory=dict)
@dataclass(slots=True)
class MqttComponentConfig:
    """(component, object_id, node_id, discovery_payload).

    Bundles one discovered component's identifiers with its discovery payload.
    """

    # Component name; NOTE(review): presumably a platform name — confirm.
    component: str
    object_id: str
    # Optional node id; None when not present.
    node_id: str | None
    discovery_payload: MQTTDiscoveryPayload
# Typed key for the integration's MqttData.
DATA_MQTT: HassKey[MqttData] = HassKey("mqtt")
# Typed key for a future that resolves to a bool; NOTE(review): by its name it
# signals MQTT client availability — confirm at call sites.
DATA_MQTT_AVAILABLE: HassKey[asyncio.Future[bool]] = HassKey("mqtt_client_available")