# Mirror of https://github.com/home-assistant/core (assist_satellite entity module)
"""Assist satellite entity."""
|
|
|
|
from abc import abstractmethod
|
|
import asyncio
|
|
from collections.abc import AsyncIterable
|
|
import contextlib
|
|
from dataclasses import dataclass
|
|
from enum import StrEnum
|
|
import logging
|
|
import time
|
|
from typing import Any, Final, Literal, final
|
|
|
|
from homeassistant.components import media_source, stt, tts
|
|
from homeassistant.components.assist_pipeline import (
|
|
OPTION_PREFERRED,
|
|
AudioSettings,
|
|
PipelineEvent,
|
|
PipelineEventType,
|
|
PipelineStage,
|
|
async_get_pipeline,
|
|
async_get_pipelines,
|
|
async_pipeline_from_audio_stream,
|
|
vad,
|
|
)
|
|
from homeassistant.components.media_player import async_process_play_media_url
|
|
from homeassistant.components.tts import (
|
|
generate_media_source_id as tts_generate_media_source_id,
|
|
)
|
|
from homeassistant.core import Context, callback
|
|
from homeassistant.helpers import entity
|
|
from homeassistant.helpers.entity import EntityDescription
|
|
|
|
from .const import AssistSatelliteEntityFeature
|
|
from .errors import AssistSatelliteError, SatelliteBusyError
|
|
|
|
# The conversation id is discarded after this much idle time, starting a fresh
# conversation on the next pipeline run (see async_accept_pipeline_from_satellite).
_CONVERSATION_TIMEOUT_SEC: Final = 5 * 60  # 5 minutes

_LOGGER = logging.getLogger(__name__)
|
|
|
|
|
|
class AssistSatelliteState(StrEnum):
|
|
"""Valid states of an Assist satellite entity."""
|
|
|
|
IDLE = "idle"
|
|
"""Device is waiting for user input, such as a wake word or a button press."""
|
|
|
|
LISTENING = "listening"
|
|
"""Device is streaming audio with the voice command to Home Assistant."""
|
|
|
|
PROCESSING = "processing"
|
|
"""Home Assistant is processing the voice command."""
|
|
|
|
RESPONDING = "responding"
|
|
"""Device is speaking the response."""
|
|
|
|
|
|
# NOTE(review): frozen_or_thawed is a keyword handled by EntityDescription's
# metaclass machinery (not visible here) — presumably it controls whether the
# generated dataclass is frozen; confirm against homeassistant.helpers.entity.
class AssistSatelliteEntityDescription(EntityDescription, frozen_or_thawed=True):
    """A class that describes Assist satellite entities."""
|
|
|
|
|
|
@dataclass(frozen=True)
class AssistSatelliteWakeWord:
    """Available wake word model.

    Attributes:
        id: Unique id for wake word model.
        wake_word: Wake word phrase.
        trained_languages: List of languages that the wake word was trained on.
    """

    id: str
    wake_word: str
    trained_languages: list[str]
|
|
|
|
|
|
@dataclass
class AssistSatelliteConfiguration:
    """Satellite configuration."""

    available_wake_words: list[AssistSatelliteWakeWord]
    """List of available wake word models."""

    active_wake_words: list[str]
    """List of active wake word ids."""

    max_active_wake_words: int
    """Maximum number of simultaneous wake words allowed (0 for no limit)."""
|
|
|
|
|
|
@dataclass
class AssistSatelliteAnnouncement:
    """Announcement to be made."""

    # Text of the announcement (may be empty when only media is played).
    message: str

    # Resolved media URL/id to play on the satellite.
    media_id: str

    # How media_id was obtained: a plain URL, a media-source id, or TTS synthesis.
    media_id_source: Literal["url", "media_id", "tts"]
|
|
|
|
|
|
class AssistSatelliteEntity(entity.Entity):
    """Entity encapsulating the state and functionality of an Assist satellite."""

    entity_description: AssistSatelliteEntityDescription
    _attr_should_poll = False
    _attr_supported_features = AssistSatelliteEntityFeature(0)
    # Optional select entities that pick the pipeline / VAD sensitivity.
    _attr_pipeline_entity_id: str | None = None
    _attr_vad_sensitivity_entity_id: str | None = None

    # Conversation id is reused across runs and reset after
    # _CONVERSATION_TIMEOUT_SEC of inactivity (tracked via time.monotonic()).
    _conversation_id: str | None = None
    _conversation_id_time: float | None = None

    # True once the current pipeline run emitted a TTS_START event; controls
    # whether RUN_END returns the entity to IDLE immediately.
    _run_has_tts: bool = False
    # Guards against overlapping announcements.
    _is_announcing = False
    # Pending wake word interception; resolved (or failed) in
    # async_accept_pipeline_from_satellite.
    _wake_word_intercept_future: asyncio.Future[str | None] | None = None
    _attr_tts_options: dict[str, Any] | None = None
    # Background task for the currently running pipeline, if any.
    _pipeline_task: asyncio.Task | None = None

    __assist_satellite_state = AssistSatelliteState.IDLE

    @final
    @property
    def state(self) -> str | None:
        """Return state of the entity."""
        return self.__assist_satellite_state

    @property
    def pipeline_entity_id(self) -> str | None:
        """Entity ID of the pipeline to use for the next conversation."""
        return self._attr_pipeline_entity_id

    @property
    def vad_sensitivity_entity_id(self) -> str | None:
        """Entity ID of the VAD sensitivity to use for the next conversation."""
        return self._attr_vad_sensitivity_entity_id

    @property
    def tts_options(self) -> dict[str, Any] | None:
        """Options passed for text-to-speech."""
        return self._attr_tts_options

    @callback
    @abstractmethod
    def async_get_configuration(self) -> AssistSatelliteConfiguration:
        """Get the current satellite configuration."""

    @abstractmethod
    async def async_set_configuration(
        self, config: AssistSatelliteConfiguration
    ) -> None:
        """Set the current satellite configuration."""

    async def async_intercept_wake_word(self) -> str | None:
        """Intercept the next wake word from the satellite.

        Returns the detected wake word phrase or None.

        Raises SatelliteBusyError if an interception is already in progress.
        """
        if self._wake_word_intercept_future is not None:
            raise SatelliteBusyError("Wake word interception already in progress")

        # Will cause next wake word to be intercepted in
        # async_accept_pipeline_from_satellite
        self._wake_word_intercept_future = asyncio.Future()

        _LOGGER.debug("Next wake word will be intercepted: %s", self.entity_id)

        try:
            return await self._wake_word_intercept_future
        finally:
            # Always clear the future so a new interception can begin later.
            self._wake_word_intercept_future = None

    async def async_internal_announce(
        self,
        message: str | None = None,
        media_id: str | None = None,
    ) -> None:
        """Play and show an announcement on the satellite.

        If media_id is not provided, message is synthesized to
        audio with the selected pipeline.

        If media_id is provided, it is played directly. It is possible
        to omit the message and the satellite will not show any text.

        Calls async_announce with message and media id.

        Raises SatelliteBusyError if an announcement is already in progress.
        """
        # An announcement supersedes any pipeline currently running.
        await self._cancel_running_pipeline()

        media_id_source: Literal["url", "media_id", "tts"] | None = None

        if message is None:
            message = ""

        if not media_id:
            media_id_source = "tts"
            # Synthesize audio and get URL
            pipeline_id = self._resolve_pipeline()
            pipeline = async_get_pipeline(self.hass, pipeline_id)

            tts_options: dict[str, Any] = {}
            if pipeline.tts_voice is not None:
                tts_options[tts.ATTR_VOICE] = pipeline.tts_voice

            # Entity-level TTS options take precedence over the pipeline voice.
            if self.tts_options is not None:
                tts_options.update(self.tts_options)

            media_id = tts_generate_media_source_id(
                self.hass,
                message,
                engine=pipeline.tts_engine,
                language=pipeline.tts_language,
                options=tts_options,
            )

        if media_source.is_media_source_id(media_id):
            if not media_id_source:
                media_id_source = "media_id"
            # Resolve a media-source id (including the TTS one generated
            # above) to a playable URL.
            media = await media_source.async_resolve_media(
                self.hass,
                media_id,
                None,
            )
            media_id = media.url

        if not media_id_source:
            media_id_source = "url"

        # Resolve to full URL
        media_id = async_process_play_media_url(self.hass, media_id)

        if self._is_announcing:
            raise SatelliteBusyError

        self._is_announcing = True
        self._set_state(AssistSatelliteState.RESPONDING)

        try:
            # Block until announcement is finished
            await self.async_announce(
                AssistSatelliteAnnouncement(message, media_id, media_id_source)
            )
        finally:
            # Return to idle even if the platform's announce raised.
            self._is_announcing = False
            self._set_state(AssistSatelliteState.IDLE)

    async def async_announce(self, announcement: AssistSatelliteAnnouncement) -> None:
        """Announce media on the satellite.

        Should block until the announcement is done playing.

        Platforms must override this; the base implementation raises
        NotImplementedError.
        """
        raise NotImplementedError

    async def async_accept_pipeline_from_satellite(
        self,
        audio_stream: AsyncIterable[bytes],
        start_stage: PipelineStage = PipelineStage.STT,
        end_stage: PipelineStage = PipelineStage.TTS,
        wake_word_phrase: str | None = None,
    ) -> None:
        """Triggers an Assist pipeline in Home Assistant from a satellite.

        If a wake word interception is pending (async_intercept_wake_word),
        the run is short-circuited and the phrase is handed to the waiting
        future instead of starting a pipeline.
        """
        await self._cancel_running_pipeline()

        if self._wake_word_intercept_future and start_stage in (
            PipelineStage.WAKE_WORD,
            PipelineStage.STT,
        ):
            if start_stage == PipelineStage.WAKE_WORD:
                # Interception requires the device to detect the wake word
                # itself and hand over the phrase (i.e. start at STT).
                self._wake_word_intercept_future.set_exception(
                    AssistSatelliteError(
                        "Only on-device wake words currently supported"
                    )
                )
                return

            # Intercepting wake word and immediately end pipeline
            _LOGGER.debug(
                "Intercepted wake word: %s (entity_id=%s)",
                wake_word_phrase,
                self.entity_id,
            )

            if wake_word_phrase is None:
                self._wake_word_intercept_future.set_exception(
                    AssistSatelliteError("No wake word phrase provided")
                )
            else:
                self._wake_word_intercept_future.set_result(wake_word_phrase)
            # Emit a synthetic RUN_END so the state machine returns to idle.
            self._internal_on_pipeline_event(PipelineEvent(PipelineEventType.RUN_END))
            return

        device_id = self.registry_entry.device_id if self.registry_entry else None

        # Refresh context if necessary
        if (
            (self._context is None)
            or (self._context_set is None)
            or ((time.time() - self._context_set) > entity.CONTEXT_RECENT_TIME_SECONDS)
        ):
            self.async_set_context(Context())

        assert self._context is not None

        # Reset conversation id if necessary
        if self._conversation_id_time and (
            (time.monotonic() - self._conversation_id_time) > _CONVERSATION_TIMEOUT_SEC
        ):
            self._conversation_id = None
            self._conversation_id_time = None

        # Set entity state based on pipeline events
        self._run_has_tts = False

        assert self.platform.config_entry is not None
        self._pipeline_task = self.platform.config_entry.async_create_background_task(
            self.hass,
            async_pipeline_from_audio_stream(
                self.hass,
                context=self._context,
                event_callback=self._internal_on_pipeline_event,
                stt_metadata=stt.SpeechMetadata(
                    language="",  # set in async_pipeline_from_audio_stream
                    format=stt.AudioFormats.WAV,
                    codec=stt.AudioCodecs.PCM,
                    bit_rate=stt.AudioBitRates.BITRATE_16,
                    sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
                    channel=stt.AudioChannels.CHANNEL_MONO,
                ),
                stt_stream=audio_stream,
                pipeline_id=self._resolve_pipeline(),
                conversation_id=self._conversation_id,
                device_id=device_id,
                tts_audio_output=self.tts_options,
                wake_word_phrase=wake_word_phrase,
                audio_settings=AudioSettings(
                    silence_seconds=self._resolve_vad_sensitivity()
                ),
                start_stage=start_stage,
                end_stage=end_stage,
            ),
            f"{self.entity_id}_pipeline",
        )

        try:
            await self._pipeline_task
        finally:
            self._pipeline_task = None

    async def _cancel_running_pipeline(self) -> None:
        """Cancel the current pipeline if it's running."""
        if self._pipeline_task is not None:
            self._pipeline_task.cancel()
            # Wait for the task to actually finish; the cancellation itself
            # is expected and suppressed.
            with contextlib.suppress(asyncio.CancelledError):
                await self._pipeline_task

            self._pipeline_task = None

    @abstractmethod
    def on_pipeline_event(self, event: PipelineEvent) -> None:
        """Handle pipeline events."""

    @callback
    def _internal_on_pipeline_event(self, event: PipelineEvent) -> None:
        """Set state based on pipeline stage."""
        if event.type is PipelineEventType.WAKE_WORD_START:
            self._set_state(AssistSatelliteState.IDLE)
        elif event.type is PipelineEventType.STT_START:
            self._set_state(AssistSatelliteState.LISTENING)
        elif event.type is PipelineEventType.INTENT_START:
            self._set_state(AssistSatelliteState.PROCESSING)
        elif event.type is PipelineEventType.INTENT_END:
            assert event.data is not None
            # Update timeout
            self._conversation_id_time = time.monotonic()
            self._conversation_id = event.data["intent_output"]["conversation_id"]
        elif event.type is PipelineEventType.TTS_START:
            # Wait until tts_response_finished is called to return to waiting state
            self._run_has_tts = True
            self._set_state(AssistSatelliteState.RESPONDING)
        elif event.type is PipelineEventType.RUN_END:
            if not self._run_has_tts:
                self._set_state(AssistSatelliteState.IDLE)

        # Always forward the event to the platform-specific handler.
        self.on_pipeline_event(event)

    @callback
    def _set_state(self, state: AssistSatelliteState) -> None:
        """Set the entity's state."""
        self.__assist_satellite_state = state
        self.async_write_ha_state()

    @callback
    def tts_response_finished(self) -> None:
        """Tell entity that the text-to-speech response has finished playing."""
        self._set_state(AssistSatelliteState.IDLE)

    @callback
    def _resolve_pipeline(self) -> str | None:
        """Resolve pipeline from select entity to id.

        Return None to make async_get_pipeline look up the preferred pipeline.

        Raises RuntimeError if the configured select entity has no state.
        """
        if not (pipeline_entity_id := self.pipeline_entity_id):
            return None

        if (pipeline_entity_state := self.hass.states.get(pipeline_entity_id)) is None:
            raise RuntimeError("Pipeline entity not found")

        if pipeline_entity_state.state != OPTION_PREFERRED:
            # Resolve pipeline by name
            for pipeline in async_get_pipelines(self.hass):
                if pipeline.name == pipeline_entity_state.state:
                    return pipeline.id

        # Fall back to the preferred pipeline (selected or name not found).
        return None

    @callback
    def _resolve_vad_sensitivity(self) -> float:
        """Resolve VAD sensitivity from select entity to enum.

        Returns the silence duration in seconds; falls back to the default
        sensitivity when no select entity is configured.

        Raises RuntimeError if the configured select entity has no state.
        """
        vad_sensitivity = vad.VadSensitivity.DEFAULT

        if vad_sensitivity_entity_id := self.vad_sensitivity_entity_id:
            if (
                vad_sensitivity_state := self.hass.states.get(vad_sensitivity_entity_id)
            ) is None:
                raise RuntimeError("VAD sensitivity entity not found")

            vad_sensitivity = vad.VadSensitivity(vad_sensitivity_state.state)

        return vad.VadSensitivity.to_seconds(vad_sensitivity)
|