mirror of https://github.com/home-assistant/core
486 lines
16 KiB
Python
486 lines
16 KiB
Python
"""Provide entity classes for group entities."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from abc import abstractmethod
|
|
from collections.abc import Callable, Collection, Mapping
|
|
import logging
|
|
from typing import Any
|
|
|
|
from homeassistant.const import ATTR_ASSUMED_STATE, ATTR_ENTITY_ID, STATE_OFF, STATE_ON
|
|
from homeassistant.core import (
|
|
CALLBACK_TYPE,
|
|
Event,
|
|
EventStateChangedData,
|
|
HomeAssistant,
|
|
State,
|
|
callback,
|
|
split_entity_id,
|
|
)
|
|
from homeassistant.helpers import start
|
|
from homeassistant.helpers.entity import Entity, async_generate_entity_id
|
|
from homeassistant.helpers.entity_component import EntityComponent
|
|
from homeassistant.helpers.event import async_track_state_change_event
|
|
|
|
from .const import ATTR_AUTO, ATTR_ORDER, DATA_COMPONENT, DOMAIN, GROUP_ORDER, REG_KEY
|
|
from .registry import GroupIntegrationRegistry, SingleStateType
|
|
|
|
# Template for group entity ids: "<group domain>.<object id>".
ENTITY_ID_FORMAT = f"{DOMAIN}.{{}}"

# Logger for the whole group package (handed to the EntityComponent).
_PACKAGE_LOGGER = logging.getLogger(__package__)

# Logger for this module.
_LOGGER = logging.getLogger(__name__)
|
|
|
|
|
|
class GroupEntity(Entity):
    """Representation of a Group of entities.

    Base class for entities whose state is derived from a set of member
    entities. Subclasses are expected to provide ``_entity_ids`` and to
    implement ``async_update_group_state``; they may override
    ``async_update_supported_features`` to react to member state changes.
    """

    # The member list attribute is excluded from recording.
    _unrecorded_attributes = frozenset({ATTR_ENTITY_ID})

    # State is pushed from member state-change events, never polled.
    _attr_should_poll = False

    # Member entity ids; only annotated here, expected to be set by subclasses.
    _entity_ids: list[str]

    @callback
    def async_start_preview(
        self,
        preview_callback: Callable[[str, Mapping[str, Any]], None],
    ) -> CALLBACK_TYPE:
        """Render a preview.

        Seeds supported-feature data from the members' current states, sends
        an initial rendering to ``preview_callback`` and re-renders on every
        member state change.

        Returns the callback that stops listening for member changes.
        """
        for member_id in self._entity_ids:
            member_state = self.hass.states.get(member_id)
            if member_state is not None:
                self.async_update_supported_features(member_id, member_state)

        @callback
        def _render(event: Event[EventStateChangedData] | None) -> None:
            """Recompute the group state and hand it to the preview callback."""
            self.async_update_group_state()
            if event:
                data = event.data
                self.async_update_supported_features(
                    data["entity_id"], data["new_state"]
                )
            rendered = self._async_calculate_state()
            preview_callback(rendered.state, rendered.attributes)

        # Initial render before any member event arrives.
        _render(None)
        return async_track_state_change_event(self.hass, self._entity_ids, _render)

    async def async_added_to_hass(self) -> None:
        """Register listeners.

        Seeds supported-feature data from the members' current states, then
        registers (via ``async_on_remove``) a member state-change listener and
        a start-up hook that performs the first full update.
        """
        for member_id in self._entity_ids:
            member_state = self.hass.states.get(member_id)
            if member_state is not None:
                self.async_update_supported_features(member_id, member_state)

        @callback
        def _on_member_change(event: Event[EventStateChangedData]) -> None:
            """Record the member change and refresh the group if running."""
            self.async_set_context(event.context)
            data = event.data
            self.async_update_supported_features(data["entity_id"], data["new_state"])
            self.async_defer_or_update_ha_state()

        unsub_members = async_track_state_change_event(
            self.hass, self._entity_ids, _on_member_change
        )
        self.async_on_remove(unsub_members)
        self.async_on_remove(start.async_at_start(self.hass, self._update_at_start))

    @callback
    def _update_at_start(self, _: HomeAssistant) -> None:
        """Update the group state at start and write it."""
        self.async_update_group_state()
        self.async_write_ha_state()

    @callback
    def async_defer_or_update_ha_state(self) -> None:
        """Only update once at start.

        While Home Assistant is not running the update is skipped; the
        start-up hook ``_update_at_start`` performs it instead.
        """
        if self.hass.is_running:
            self.async_update_group_state()
            self.async_write_ha_state()

    @abstractmethod
    @callback
    def async_update_group_state(self) -> None:
        """Abstract method to update the entity."""

    @callback
    def async_update_supported_features(
        self,
        entity_id: str,
        new_state: State | None,
    ) -> None:
        """Update dictionaries with supported features.

        No-op in the base class; subclasses override it.
        """
|
|
|
|
|
|
class Group(Entity):
    """Track a group of entity ids.

    The group state is derived from its members. ``mode`` (``any`` or ``all``)
    is applied to the members' on/off flags. The result is reported as the
    members' shared on/off state pair when there is exactly one such pair,
    and as STATE_ON/STATE_OFF otherwise.
    """

    # These attributes are excluded from recording.
    _unrecorded_attributes = frozenset({ATTR_ENTITY_ID, ATTR_ORDER, ATTR_AUTO})

    _attr_should_poll = False
    # All configured member ids, lower-cased.
    tracking: tuple[str, ...]
    # Members whose domain is not excluded; these are actually watched.
    trackable: tuple[str, ...]
    # Shared on/off state pair of all members, or None if they differ.
    single_state_type_key: SingleStateType | None
    # Group integration registry; assigned in async_added_to_hass.
    _registry: GroupIntegrationRegistry

    def __init__(
        self,
        hass: HomeAssistant,
        name: str,
        *,
        created_by_service: bool,
        entity_ids: Collection[str] | None,
        icon: str | None,
        mode: bool | None,
        order: int | None,
    ) -> None:
        """Initialize a group.

        This Object has factory function for creation
        (``async_create_group_entity`` / ``async_create_group``).

        A truthy ``mode`` requires all members to be on; otherwise any one
        member being on is enough.
        """
        self.hass = hass
        self._attr_name = name
        self._state: str | None = None
        self._attr_icon = icon
        self._entity_ids = entity_ids
        # entity_id -> whether that member currently counts as "on".
        self._on_off: dict[str, bool] = {}
        # entity_id -> whether that member reports an assumed state.
        self._assumed: dict[str, bool] = {}
        # Every "on" state value seen among the members' domains.
        self._on_states: set[str] = set()
        self.created_by_service = created_by_service
        self.mode = all if mode else any
        self._order = order
        self._assumed_state = False
        self._async_unsub_state_changed: CALLBACK_TYPE | None = None

    @staticmethod
    @callback
    def async_create_group_entity(
        hass: HomeAssistant,
        name: str,
        *,
        created_by_service: bool,
        entity_ids: Collection[str] | None,
        icon: str | None,
        mode: bool | None,
        object_id: str | None,
        order: int | None,
    ) -> Group:
        """Create a group entity (not yet added to Home Assistant).

        When ``order`` is None, the next value of a counter kept in
        ``hass.data[GROUP_ORDER]`` is used and the counter is incremented.
        The entity id is generated from ``object_id`` or, failing that, ``name``.
        """
        if order is None:
            # Keep track of the group order without iterating every state in
            # the state machine every time we set up a new group.
            order = hass.data.setdefault(GROUP_ORDER, 0)
            hass.data[GROUP_ORDER] += 1

        group = Group(
            hass,
            name,
            created_by_service=created_by_service,
            entity_ids=entity_ids,
            icon=icon,
            mode=mode,
            order=order,
        )

        group.entity_id = async_generate_entity_id(
            ENTITY_ID_FORMAT, object_id or name, hass=hass
        )

        return group

    @staticmethod
    async def async_create_group(
        hass: HomeAssistant,
        name: str,
        *,
        created_by_service: bool,
        entity_ids: Collection[str] | None,
        icon: str | None,
        mode: bool | None,
        object_id: str | None,
        order: int | None,
    ) -> Group:
        """Create a group entity and add it to the group entity component.

        This method must be run in the event loop.
        """
        group = Group.async_create_group_entity(
            hass,
            name,
            created_by_service=created_by_service,
            entity_ids=entity_ids,
            icon=icon,
            mode=mode,
            object_id=object_id,
            order=order,
        )

        # If called before the platform async_setup is called (test cases)
        await async_get_component(hass).async_add_entities([group])
        return group

    def set_name(self, value: str) -> None:
        """Set Group name."""
        self._attr_name = value

    @property
    def state(self) -> str | None:
        """Return the state of the group."""
        return self._state

    def set_icon(self, value: str | None) -> None:
        """Set Icon for group."""
        self._attr_icon = value

    @property
    def extra_state_attributes(self) -> dict[str, Any]:
        """Return the state attributes for the group."""
        data = {ATTR_ENTITY_ID: self.tracking, ATTR_ORDER: self._order}
        if self.created_by_service:
            data[ATTR_AUTO] = True

        return data

    @property
    def assumed_state(self) -> bool:
        """Test if any member has an assumed state."""
        return self._assumed_state

    @callback
    def async_update_tracked_entity_ids(
        self, entity_ids: Collection[str] | None
    ) -> None:
        """Update the member entity IDs.

        Stops the current member listener, recomputes the tracked members and
        restarts tracking, which also writes the new state.

        This method must be run in the event loop.
        """
        self._async_stop()
        self._set_tracked(entity_ids)
        # _async_start() resets the tracked member state itself before it
        # starts listening, so no separate reset is needed here.
        self._async_start()

    def _set_tracked(self, entity_ids: Collection[str] | None) -> None:
        """Compute the tuples of entities to be tracked.

        Sets ``tracking`` (all members, lower-cased), ``trackable`` (members
        whose domain is not excluded) and ``single_state_type_key``. The
        resulting single state type is published to, or withdrawn from, the
        registry under this group's entity id, so groups nesting this one can
        use it.
        """
        if not entity_ids:
            self.tracking = ()
            self.trackable = ()
            self.single_state_type_key = None
            # Withdraw a previously published state type. The registry is
            # only attached in async_added_to_hass, so it may be missing here.
            if (registry := getattr(self, "_registry", None)) is not None:
                registry.state_group_mapping.pop(self.entity_id, None)
            return

        registry = self._registry
        excluded_domains = registry.exclude_domains
        state_group_mapping = registry.state_group_mapping

        tracking: list[str] = []
        trackable: list[str] = []
        single_state_type_set: set[SingleStateType] = set()
        for ent_id in entity_ids:
            ent_id_lower = ent_id.lower()
            domain = split_entity_id(ent_id_lower)[0]
            tracking.append(ent_id_lower)
            if domain not in excluded_domains:
                trackable.append(ent_id_lower)
            if domain in state_group_mapping:
                single_state_type_set.add(state_group_mapping[domain])
            elif domain == DOMAIN:
                # If a group contains another group we check if that group
                # has a specific single state type. Nested groups publish
                # under their entity id (lower-case in Home Assistant), so
                # look up the normalized id.
                if ent_id_lower in state_group_mapping:
                    single_state_type_set.add(state_group_mapping[ent_id_lower])
            else:
                single_state_type_set.add(SingleStateType(STATE_ON, STATE_OFF))

        if len(single_state_type_set) == 1:
            self.single_state_type_key = next(iter(single_state_type_set))
            # To support groups with nested groups we store the state type
            # per group entity_id if there is a single state type
            state_group_mapping[self.entity_id] = self.single_state_type_key
        else:
            self.single_state_type_key = None
            # Members no longer share one state type: drop any stale entry so
            # parent groups do not keep using the old on/off pair.
            state_group_mapping.pop(self.entity_id, None)

        self.trackable = tuple(trackable)
        self.tracking = tuple(tracking)

    @callback
    def _async_deregister(self) -> None:
        """Deregister group entity from the registry."""
        self._registry.state_group_mapping.pop(self.entity_id, None)

    @callback
    def _async_start(self, _: HomeAssistant | None = None) -> None:
        """Start tracking members and write state."""
        self._reset_tracked_state()
        self._async_start_tracking()
        self.async_write_ha_state()

    @callback
    def _async_start_tracking(self) -> None:
        """Start tracking members.

        Subscribes to state changes of ``trackable`` members unless already
        subscribed, then recomputes the group state.

        This method must be run in the event loop.
        """
        if self.trackable and self._async_unsub_state_changed is None:
            self._async_unsub_state_changed = async_track_state_change_event(
                self.hass, self.trackable, self._async_state_changed_listener
            )

        self._async_update_group_state()

    @callback
    def _async_stop(self) -> None:
        """Unsubscribe from member state changes.

        This method must be run in the event loop.
        """
        if self._async_unsub_state_changed:
            self._async_unsub_state_changed()
            self._async_unsub_state_changed = None

    @callback
    def async_update_group_state(self) -> None:
        """Query all members and determine current group state."""
        self._state = None
        self._async_update_group_state()

    async def async_added_to_hass(self) -> None:
        """Handle addition to Home Assistant.

        Attaches the registry, computes tracked members, defers the first
        start of tracking until Home Assistant has started, and registers
        deregistration on removal.
        """
        self._registry = self.hass.data[REG_KEY]
        self._set_tracked(self._entity_ids)
        self.async_on_remove(start.async_at_start(self.hass, self._async_start))
        self.async_on_remove(self._async_deregister)

    async def async_will_remove_from_hass(self) -> None:
        """Handle removal from Home Assistant."""
        self._async_stop()

    async def _async_state_changed_listener(
        self, event: Event[EventStateChangedData]
    ) -> None:
        """Respond to a member state changing.

        This method must be run in the event loop.
        """
        # The group was stopped/removed after the event was queued.
        if self._async_unsub_state_changed is None:
            return

        self.async_set_context(event.context)

        if (new_state := event.data["new_state"]) is None:
            # The state was removed from the state machine
            self._reset_tracked_state()

        self._async_update_group_state(new_state)
        self.async_write_ha_state()

    def _reset_tracked_state(self) -> None:
        """Rebuild the per-member bookkeeping from the current state machine."""
        self._on_off = {}
        self._assumed = {}
        self._on_states = set()

        for entity_id in self.trackable:
            if (state := self.hass.states.get(entity_id)) is not None:
                self._see_state(state)

    def _see_state(self, new_state: State) -> None:
        """Record a member's on/off flag, assumed flag and its domain's on states."""
        entity_id = new_state.entity_id
        domain = new_state.domain
        state = new_state.state
        registry = self._registry
        self._assumed[entity_id] = bool(new_state.attributes.get(ATTR_ASSUMED_STATE))

        if domain not in registry.on_states_by_domain:
            # Handle the group of a group case
            if state in registry.on_off_mapping:
                self._on_states.add(state)
            elif state in registry.off_on_mapping:
                self._on_states.add(registry.off_on_mapping[state])
            self._on_off[entity_id] = state in registry.on_off_mapping
        else:
            entity_on_state = registry.on_states_by_domain[domain]
            self._on_states.update(entity_on_state)
            self._on_off[entity_id] = state in entity_on_state

    @callback
    def _async_update_group_state(self, tr_state: State | None = None) -> None:
        """Update group state.

        Optionally you can provide the only state changed since last update
        allowing this method to take shortcuts.

        This method must be run in the event loop.
        """
        # To store current states of group entities. Might not be needed.
        if tr_state:
            self._see_state(tr_state)

        # No member state seen yet: leave the current state untouched.
        if not self._on_off:
            return

        if tr_state is None or (
            self._assumed_state and not tr_state.attributes.get(ATTR_ASSUMED_STATE)
        ):
            # Full recompute: no single changed state was given, or the group
            # was assumed and the changed member no longer reports assumed.
            self._assumed_state = self.mode(self._assumed.values())
        elif tr_state.attributes.get(ATTR_ASSUMED_STATE):
            # Shortcut: the changed member reports an assumed state.
            self._assumed_state = True

        # If we do not have an on state for any domains
        # we use None (which will be STATE_UNKNOWN)
        if not self._on_states:
            self._state = None
            return

        # Use the members' shared on/off pair when there is exactly one;
        # otherwise fall back to STATE_ON/STATE_OFF.
        if self.single_state_type_key:
            on_state = self.single_state_type_key.on_state
            off_state = self.single_state_type_key.off_state
        else:
            on_state = STATE_ON
            off_state = STATE_OFF

        self._state = on_state if self.mode(self._on_off.values()) else off_state
|
|
|
|
|
|
def async_get_component(hass: HomeAssistant) -> EntityComponent[Group]:
    """Get the group entity component.

    The component is created on first use and cached in
    ``hass.data[DATA_COMPONENT]``; later calls return the cached instance.
    """
    component = hass.data.get(DATA_COMPONENT)
    if component is not None:
        return component

    component = EntityComponent[Group](_PACKAGE_LOGGER, DOMAIN, hass)
    hass.data[DATA_COMPONENT] = component
    return component
|