core/homeassistant/components/group/entity.py

486 lines
16 KiB
Python

"""Provide entity classes for group entities."""
from __future__ import annotations
from abc import abstractmethod
from collections.abc import Callable, Collection, Mapping
import logging
from typing import Any
from homeassistant.const import ATTR_ASSUMED_STATE, ATTR_ENTITY_ID, STATE_OFF, STATE_ON
from homeassistant.core import (
CALLBACK_TYPE,
Event,
EventStateChangedData,
HomeAssistant,
State,
callback,
split_entity_id,
)
from homeassistant.helpers import start
from homeassistant.helpers.entity import Entity, async_generate_entity_id
from homeassistant.helpers.entity_component import EntityComponent
from homeassistant.helpers.event import async_track_state_change_event
from .const import ATTR_AUTO, ATTR_ORDER, DATA_COMPONENT, DOMAIN, GROUP_ORDER, REG_KEY
from .registry import GroupIntegrationRegistry, SingleStateType
ENTITY_ID_FORMAT = DOMAIN + ".{}"
_PACKAGE_LOGGER = logging.getLogger(__package__)
_LOGGER = logging.getLogger(__name__)
class GroupEntity(Entity):
"""Representation of a Group of entities."""
_unrecorded_attributes = frozenset({ATTR_ENTITY_ID})
_attr_should_poll = False
_entity_ids: list[str]
@callback
def async_start_preview(
self,
preview_callback: Callable[[str, Mapping[str, Any]], None],
) -> CALLBACK_TYPE:
"""Render a preview."""
for entity_id in self._entity_ids:
if (state := self.hass.states.get(entity_id)) is None:
continue
self.async_update_supported_features(entity_id, state)
@callback
def async_state_changed_listener(
event: Event[EventStateChangedData] | None,
) -> None:
"""Handle child updates."""
self.async_update_group_state()
if event:
self.async_update_supported_features(
event.data["entity_id"], event.data["new_state"]
)
calculated_state = self._async_calculate_state()
preview_callback(calculated_state.state, calculated_state.attributes)
async_state_changed_listener(None)
return async_track_state_change_event(
self.hass, self._entity_ids, async_state_changed_listener
)
async def async_added_to_hass(self) -> None:
"""Register listeners."""
for entity_id in self._entity_ids:
if (state := self.hass.states.get(entity_id)) is None:
continue
self.async_update_supported_features(entity_id, state)
@callback
def async_state_changed_listener(
event: Event[EventStateChangedData],
) -> None:
"""Handle child updates."""
self.async_set_context(event.context)
self.async_update_supported_features(
event.data["entity_id"], event.data["new_state"]
)
self.async_defer_or_update_ha_state()
self.async_on_remove(
async_track_state_change_event(
self.hass, self._entity_ids, async_state_changed_listener
)
)
self.async_on_remove(start.async_at_start(self.hass, self._update_at_start))
@callback
def _update_at_start(self, _: HomeAssistant) -> None:
"""Update the group state at start."""
self.async_update_group_state()
self.async_write_ha_state()
@callback
def async_defer_or_update_ha_state(self) -> None:
"""Only update once at start."""
if not self.hass.is_running:
return
self.async_update_group_state()
self.async_write_ha_state()
@abstractmethod
@callback
def async_update_group_state(self) -> None:
"""Abstract method to update the entity."""
@callback
def async_update_supported_features(
self,
entity_id: str,
new_state: State | None,
) -> None:
"""Update dictionaries with supported features."""
class Group(Entity):
"""Track a group of entity ids."""
_unrecorded_attributes = frozenset({ATTR_ENTITY_ID, ATTR_ORDER, ATTR_AUTO})
_attr_should_poll = False
tracking: tuple[str, ...]
trackable: tuple[str, ...]
single_state_type_key: SingleStateType | None
_registry: GroupIntegrationRegistry
def __init__(
self,
hass: HomeAssistant,
name: str,
*,
created_by_service: bool,
entity_ids: Collection[str] | None,
icon: str | None,
mode: bool | None,
order: int | None,
) -> None:
"""Initialize a group.
This Object has factory function for creation.
"""
self.hass = hass
self._attr_name = name
self._state: str | None = None
self._attr_icon = icon
self._entity_ids = entity_ids
self._on_off: dict[str, bool] = {}
self._assumed: dict[str, bool] = {}
self._on_states: set[str] = set()
self.created_by_service = created_by_service
self.mode = any
if mode:
self.mode = all
self._order = order
self._assumed_state = False
self._async_unsub_state_changed: CALLBACK_TYPE | None = None
@staticmethod
@callback
def async_create_group_entity(
hass: HomeAssistant,
name: str,
*,
created_by_service: bool,
entity_ids: Collection[str] | None,
icon: str | None,
mode: bool | None,
object_id: str | None,
order: int | None,
) -> Group:
"""Create a group entity."""
if order is None:
hass.data.setdefault(GROUP_ORDER, 0)
order = hass.data[GROUP_ORDER]
# Keep track of the group order without iterating
# every state in the state machine every time
# we setup a new group
hass.data[GROUP_ORDER] += 1
group = Group(
hass,
name,
created_by_service=created_by_service,
entity_ids=entity_ids,
icon=icon,
mode=mode,
order=order,
)
group.entity_id = async_generate_entity_id(
ENTITY_ID_FORMAT, object_id or name, hass=hass
)
return group
@staticmethod
async def async_create_group(
hass: HomeAssistant,
name: str,
*,
created_by_service: bool,
entity_ids: Collection[str] | None,
icon: str | None,
mode: bool | None,
object_id: str | None,
order: int | None,
) -> Group:
"""Initialize a group.
This method must be run in the event loop.
"""
group = Group.async_create_group_entity(
hass,
name,
created_by_service=created_by_service,
entity_ids=entity_ids,
icon=icon,
mode=mode,
object_id=object_id,
order=order,
)
# If called before the platform async_setup is called (test cases)
await async_get_component(hass).async_add_entities([group])
return group
def set_name(self, value: str) -> None:
"""Set Group name."""
self._attr_name = value
@property
def state(self) -> str | None:
"""Return the state of the group."""
return self._state
def set_icon(self, value: str | None) -> None:
"""Set Icon for group."""
self._attr_icon = value
@property
def extra_state_attributes(self) -> dict[str, Any]:
"""Return the state attributes for the group."""
data = {ATTR_ENTITY_ID: self.tracking, ATTR_ORDER: self._order}
if self.created_by_service:
data[ATTR_AUTO] = True
return data
@property
def assumed_state(self) -> bool:
"""Test if any member has an assumed state."""
return self._assumed_state
@callback
def async_update_tracked_entity_ids(
self, entity_ids: Collection[str] | None
) -> None:
"""Update the member entity IDs.
This method must be run in the event loop.
"""
self._async_stop()
self._set_tracked(entity_ids)
self._reset_tracked_state()
self._async_start()
def _set_tracked(self, entity_ids: Collection[str] | None) -> None:
"""Tuple of entities to be tracked."""
# tracking are the entities we want to track
# trackable are the entities we actually watch
if not entity_ids:
self.tracking = ()
self.trackable = ()
self.single_state_type_key = None
return
registry = self._registry
excluded_domains = registry.exclude_domains
tracking: list[str] = []
trackable: list[str] = []
single_state_type_set: set[SingleStateType] = set()
for ent_id in entity_ids:
ent_id_lower = ent_id.lower()
domain = split_entity_id(ent_id_lower)[0]
tracking.append(ent_id_lower)
if domain not in excluded_domains:
trackable.append(ent_id_lower)
if domain in registry.state_group_mapping:
single_state_type_set.add(registry.state_group_mapping[domain])
elif domain == DOMAIN:
# If a group contains another group we check if that group
# has a specific single state type
if ent_id in registry.state_group_mapping:
single_state_type_set.add(registry.state_group_mapping[ent_id])
else:
single_state_type_set.add(SingleStateType(STATE_ON, STATE_OFF))
if len(single_state_type_set) == 1:
self.single_state_type_key = next(iter(single_state_type_set))
# To support groups with nested groups we store the state type
# per group entity_id if there is a single state type
registry.state_group_mapping[self.entity_id] = self.single_state_type_key
else:
self.single_state_type_key = None
self.trackable = tuple(trackable)
self.tracking = tuple(tracking)
@callback
def _async_deregister(self) -> None:
"""Deregister group entity from the registry."""
registry = self._registry
if self.entity_id in registry.state_group_mapping:
registry.state_group_mapping.pop(self.entity_id)
@callback
def _async_start(self, _: HomeAssistant | None = None) -> None:
"""Start tracking members and write state."""
self._reset_tracked_state()
self._async_start_tracking()
self.async_write_ha_state()
@callback
def _async_start_tracking(self) -> None:
"""Start tracking members.
This method must be run in the event loop.
"""
if self.trackable and self._async_unsub_state_changed is None:
self._async_unsub_state_changed = async_track_state_change_event(
self.hass, self.trackable, self._async_state_changed_listener
)
self._async_update_group_state()
@callback
def _async_stop(self) -> None:
"""Unregister the group from Home Assistant.
This method must be run in the event loop.
"""
if self._async_unsub_state_changed:
self._async_unsub_state_changed()
self._async_unsub_state_changed = None
@callback
def async_update_group_state(self) -> None:
"""Query all members and determine current group state."""
self._state = None
self._async_update_group_state()
async def async_added_to_hass(self) -> None:
"""Handle addition to Home Assistant."""
self._registry = self.hass.data[REG_KEY]
self._set_tracked(self._entity_ids)
self.async_on_remove(start.async_at_start(self.hass, self._async_start))
self.async_on_remove(self._async_deregister)
async def async_will_remove_from_hass(self) -> None:
"""Handle removal from Home Assistant."""
self._async_stop()
async def _async_state_changed_listener(
self, event: Event[EventStateChangedData]
) -> None:
"""Respond to a member state changing.
This method must be run in the event loop.
"""
# removed
if self._async_unsub_state_changed is None:
return
self.async_set_context(event.context)
if (new_state := event.data["new_state"]) is None:
# The state was removed from the state machine
self._reset_tracked_state()
self._async_update_group_state(new_state)
self.async_write_ha_state()
def _reset_tracked_state(self) -> None:
"""Reset tracked state."""
self._on_off = {}
self._assumed = {}
self._on_states = set()
for entity_id in self.trackable:
if (state := self.hass.states.get(entity_id)) is not None:
self._see_state(state)
def _see_state(self, new_state: State) -> None:
"""Keep track of the state."""
entity_id = new_state.entity_id
domain = new_state.domain
state = new_state.state
registry = self._registry
self._assumed[entity_id] = bool(new_state.attributes.get(ATTR_ASSUMED_STATE))
if domain not in registry.on_states_by_domain:
# Handle the group of a group case
if state in registry.on_off_mapping:
self._on_states.add(state)
elif state in registry.off_on_mapping:
self._on_states.add(registry.off_on_mapping[state])
self._on_off[entity_id] = state in registry.on_off_mapping
else:
entity_on_state = registry.on_states_by_domain[domain]
if domain in registry.on_states_by_domain:
self._on_states.update(entity_on_state)
self._on_off[entity_id] = state in entity_on_state
@callback
def _async_update_group_state(self, tr_state: State | None = None) -> None:
"""Update group state.
Optionally you can provide the only state changed since last update
allowing this method to take shortcuts.
This method must be run in the event loop.
"""
# To store current states of group entities. Might not be needed.
if tr_state:
self._see_state(tr_state)
if not self._on_off:
return
if (
tr_state is None
or self._assumed_state
and not tr_state.attributes.get(ATTR_ASSUMED_STATE)
):
self._assumed_state = self.mode(self._assumed.values())
elif tr_state.attributes.get(ATTR_ASSUMED_STATE):
self._assumed_state = True
num_on_states = len(self._on_states)
# If all the entity domains we are tracking
# have the same on state we use this state
# and its hass.data[REG_KEY].on_off_mapping to off
if num_on_states == 1:
on_state = next(iter(self._on_states))
# If we do not have an on state for any domains
# we use None (which will be STATE_UNKNOWN)
elif num_on_states == 0:
self._state = None
return
if self.single_state_type_key:
on_state = self.single_state_type_key.on_state
# If the entity domains have more than one
# on state, we use STATE_ON/STATE_OFF
else:
on_state = STATE_ON
group_is_on = self.mode(self._on_off.values())
if group_is_on:
self._state = on_state
elif self.single_state_type_key:
self._state = self.single_state_type_key.off_state
else:
self._state = STATE_OFF
def async_get_component(hass: HomeAssistant) -> EntityComponent[Group]:
"""Get the group entity component."""
if (component := hass.data.get(DATA_COMPONENT)) is None:
component = hass.data[DATA_COMPONENT] = EntityComponent[Group](
_PACKAGE_LOGGER, DOMAIN, hass
)
return component