core/homeassistant/components/esphome/assist_satellite.py

677 lines
25 KiB
Python

"""Support for assist satellites in ESPHome."""
from __future__ import annotations
import asyncio
from collections.abc import AsyncIterable
from functools import partial
import io
from itertools import chain
import logging
import socket
from typing import Any, cast
import wave
from aioesphomeapi import (
MediaPlayerFormatPurpose,
MediaPlayerSupportedFormat,
VoiceAssistantAnnounceFinished,
VoiceAssistantAudioSettings,
VoiceAssistantCommandFlag,
VoiceAssistantEventType,
VoiceAssistantFeature,
VoiceAssistantTimerEventType,
)
from homeassistant.components import assist_satellite, tts
from homeassistant.components.assist_pipeline import (
PipelineEvent,
PipelineEventType,
PipelineStage,
)
from homeassistant.components.intent import (
TimerEventType,
TimerInfo,
async_register_timer_handler,
)
from homeassistant.components.media_player import async_process_play_media_url
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import Platform
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers import entity_registry as er
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from .const import DOMAIN
from .entity import EsphomeAssistEntity
from .entry_data import ESPHomeConfigEntry, RuntimeEntryData
from .enum_mapper import EsphomeEnumMapper
from .ffmpeg_proxy import async_create_proxy_url
# Module-level logger for this component.
_LOGGER = logging.getLogger(__name__)

# Bidirectional mapping between ESPHome voice assistant event types and
# Home Assistant pipeline event types (used to forward pipeline events
# to the device in on_pipeline_event).
_VOICE_ASSISTANT_EVENT_TYPES: EsphomeEnumMapper[
    VoiceAssistantEventType, PipelineEventType
] = EsphomeEnumMapper(
    {
        VoiceAssistantEventType.VOICE_ASSISTANT_ERROR: PipelineEventType.ERROR,
        VoiceAssistantEventType.VOICE_ASSISTANT_RUN_START: PipelineEventType.RUN_START,
        VoiceAssistantEventType.VOICE_ASSISTANT_RUN_END: PipelineEventType.RUN_END,
        VoiceAssistantEventType.VOICE_ASSISTANT_STT_START: PipelineEventType.STT_START,
        VoiceAssistantEventType.VOICE_ASSISTANT_STT_END: PipelineEventType.STT_END,
        VoiceAssistantEventType.VOICE_ASSISTANT_INTENT_START: PipelineEventType.INTENT_START,
        VoiceAssistantEventType.VOICE_ASSISTANT_INTENT_END: PipelineEventType.INTENT_END,
        VoiceAssistantEventType.VOICE_ASSISTANT_TTS_START: PipelineEventType.TTS_START,
        VoiceAssistantEventType.VOICE_ASSISTANT_TTS_END: PipelineEventType.TTS_END,
        VoiceAssistantEventType.VOICE_ASSISTANT_WAKE_WORD_START: PipelineEventType.WAKE_WORD_START,
        VoiceAssistantEventType.VOICE_ASSISTANT_WAKE_WORD_END: PipelineEventType.WAKE_WORD_END,
        VoiceAssistantEventType.VOICE_ASSISTANT_STT_VAD_START: PipelineEventType.STT_VAD_START,
        VoiceAssistantEventType.VOICE_ASSISTANT_STT_VAD_END: PipelineEventType.STT_VAD_END,
    }
)

# Bidirectional mapping between ESPHome timer event types and Home
# Assistant intent timer event types (used in handle_timer_event).
_TIMER_EVENT_TYPES: EsphomeEnumMapper[VoiceAssistantTimerEventType, TimerEventType] = (
    EsphomeEnumMapper(
        {
            VoiceAssistantTimerEventType.VOICE_ASSISTANT_TIMER_STARTED: TimerEventType.STARTED,
            VoiceAssistantTimerEventType.VOICE_ASSISTANT_TIMER_UPDATED: TimerEventType.UPDATED,
            VoiceAssistantTimerEventType.VOICE_ASSISTANT_TIMER_CANCELLED: TimerEventType.CANCELLED,
            VoiceAssistantTimerEventType.VOICE_ASSISTANT_TIMER_FINISHED: TimerEventType.FINISHED,
        }
    )
)

# Maximum time to wait for the device to report that an announcement
# finished playing (see async_announce).
_ANNOUNCEMENT_TIMEOUT_SEC = 5 * 60  # 5 minutes

# Timeout for fetching the wake word configuration from the device
# (see _update_satellite_config).
_CONFIG_TIMEOUT_SEC = 5
async def async_setup_entry(
    hass: HomeAssistant,
    entry: ESPHomeConfigEntry,
    async_add_entities: AddEntitiesCallback,
) -> None:
    """Set up Assist satellite entity."""
    entry_data = entry.runtime_data
    assert entry_data.device_info is not None

    # Only create the satellite entity when the device advertises at
    # least one voice assistant feature flag.
    feature_flags = entry_data.device_info.voice_assistant_feature_flags_compat(
        entry_data.api_version
    )
    if not feature_flags:
        return

    async_add_entities([EsphomeAssistSatellite(entry, entry_data)])
class EsphomeAssistSatellite(
    EsphomeAssistEntity, assist_satellite.AssistSatelliteEntity
):
    """Satellite running ESPHome.

    Bridges a Home Assistant assist pipeline to an ESPHome voice
    assistant device: receives microphone audio (TCP API or UDP),
    forwards pipeline events to the device, and streams or proxies TTS
    and announcement audio back to it.
    """

    entity_description = assist_satellite.AssistSatelliteEntityDescription(
        key="assist_satellite", translation_key="assist_satellite"
    )

    def __init__(
        self,
        config_entry: ConfigEntry,
        entry_data: RuntimeEntryData,
    ) -> None:
        """Initialize satellite."""
        super().__init__(entry_data)
        self.config_entry = config_entry
        self.entry_data = entry_data
        self.cli = self.entry_data.client

        self._is_running: bool = True
        self._pipeline_task: asyncio.Task | None = None
        # Incoming microphone audio; None is the end-of-stream sentinel.
        self._audio_queue: asyncio.Queue[bytes | None] = asyncio.Queue()
        self._tts_streaming_task: asyncio.Task | None = None
        self._udp_server: VoiceAssistantUDPServer | None = None

        # Empty config. Updated when added to HA.
        self._satellite_config = assist_satellite.AssistSatelliteConfiguration(
            available_wake_words=[], active_wake_words=[], max_active_wake_words=1
        )

    @property
    def pipeline_entity_id(self) -> str | None:
        """Return the entity ID of the pipeline to use for the next conversation."""
        assert self.entry_data.device_info is not None
        ent_reg = er.async_get(self.hass)
        return ent_reg.async_get_entity_id(
            Platform.SELECT,
            DOMAIN,
            f"{self.entry_data.device_info.mac_address}-pipeline",
        )

    @property
    def vad_sensitivity_entity_id(self) -> str | None:
        """Return the entity ID of the VAD sensitivity to use for the next conversation."""
        assert self.entry_data.device_info is not None
        ent_reg = er.async_get(self.hass)
        return ent_reg.async_get_entity_id(
            Platform.SELECT,
            DOMAIN,
            f"{self.entry_data.device_info.mac_address}-vad_sensitivity",
        )

    @callback
    def async_get_configuration(
        self,
    ) -> assist_satellite.AssistSatelliteConfiguration:
        """Get the current satellite configuration."""
        return self._satellite_config

    async def async_set_configuration(
        self, config: assist_satellite.AssistSatelliteConfiguration
    ) -> None:
        """Set the current satellite configuration (active wake words)."""
        await self.cli.set_voice_assistant_configuration(
            active_wake_words=config.active_wake_words
        )
        _LOGGER.debug("Set active wake words: %s", config.active_wake_words)

        # Ensure configuration is updated
        await self._update_satellite_config()

    async def _update_satellite_config(self) -> None:
        """Get the latest satellite configuration from the device.

        On timeout, the placeholder configuration from __init__ is kept.
        """
        try:
            config = await self.cli.get_voice_assistant_configuration(
                _CONFIG_TIMEOUT_SEC
            )
        except TimeoutError:
            # Placeholder config will be used
            return

        # Update available/active wake words
        self._satellite_config.available_wake_words = [
            assist_satellite.AssistSatelliteWakeWord(
                id=model.id,
                wake_word=model.wake_word,
                trained_languages=list(model.trained_languages),
            )
            for model in config.available_wake_words
        ]
        self._satellite_config.active_wake_words = list(config.active_wake_words)
        self._satellite_config.max_active_wake_words = config.max_active_wake_words
        _LOGGER.debug("Received satellite configuration: %s", self._satellite_config)

    async def async_added_to_hass(self) -> None:
        """Run when entity about to be added to hass.

        Subscribes to voice assistant callbacks (with or without TCP
        audio depending on feature flags), registers timer handlers, and
        fetches the wake word configuration when announcements are
        supported.
        """
        await super().async_added_to_hass()

        assert self.entry_data.device_info is not None
        feature_flags = (
            self.entry_data.device_info.voice_assistant_feature_flags_compat(
                self.entry_data.api_version
            )
        )
        if feature_flags & VoiceAssistantFeature.API_AUDIO:
            # TCP audio: device streams microphone audio over the API.
            self.async_on_remove(
                self.cli.subscribe_voice_assistant(
                    handle_start=self.handle_pipeline_start,
                    handle_stop=self.handle_pipeline_stop,
                    handle_audio=self.handle_audio,
                    handle_announcement_finished=self.handle_announcement_finished,
                )
            )
        else:
            # UDP audio: a UDP server is started per pipeline run instead.
            self.async_on_remove(
                self.cli.subscribe_voice_assistant(
                    handle_start=self.handle_pipeline_start,
                    handle_stop=self.handle_pipeline_stop,
                    handle_announcement_finished=self.handle_announcement_finished,
                )
            )

        if feature_flags & VoiceAssistantFeature.TIMERS:
            # Device supports timers
            assert (self.registry_entry is not None) and (
                self.registry_entry.device_id is not None
            )
            self.async_on_remove(
                async_register_timer_handler(
                    self.hass, self.registry_entry.device_id, self.handle_timer_event
                )
            )

        if feature_flags & VoiceAssistantFeature.ANNOUNCE:
            # Device supports announcements
            self._attr_supported_features |= (
                assist_satellite.AssistSatelliteEntityFeature.ANNOUNCE
            )

            # Block until config is retrieved.
            # If the device supports announcements, it will return a config.
            _LOGGER.debug("Waiting for satellite configuration")
            await self._update_satellite_config()

        if not (feature_flags & VoiceAssistantFeature.SPEAKER):
            # Will use media player for TTS/announcements
            self._update_tts_format()

    async def async_will_remove_from_hass(self) -> None:
        """Run when entity will be removed from hass."""
        await super().async_will_remove_from_hass()
        self._is_running = False
        self._stop_pipeline()

    def on_pipeline_event(self, event: PipelineEvent) -> None:
        """Handle pipeline events.

        Translates each Home Assistant pipeline event into the matching
        ESPHome event and forwards it to the device, attaching the event
        payload the device expects for that event type.
        """
        try:
            event_type = _VOICE_ASSISTANT_EVENT_TYPES.from_hass(event.type)
        except KeyError:
            _LOGGER.debug("Received unknown pipeline event type: %s", event.type)
            return

        data_to_send: dict[str, Any] = {}
        if event_type == VoiceAssistantEventType.VOICE_ASSISTANT_STT_START:
            self.entry_data.async_set_assist_pipeline_state(True)
        elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_STT_END:
            assert event.data is not None
            data_to_send = {"text": event.data["stt_output"]["text"]}
        elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_INTENT_END:
            assert event.data is not None
            data_to_send = {
                "conversation_id": event.data["intent_output"]["conversation_id"] or "",
            }
        elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_TTS_START:
            assert event.data is not None
            data_to_send = {"text": event.data["tts_input"]}
        elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_TTS_END:
            assert event.data is not None
            if tts_output := event.data["tts_output"]:
                path = tts_output["url"]
                url = async_process_play_media_url(self.hass, path)
                data_to_send = {"url": url}

                assert self.entry_data.device_info is not None
                feature_flags = (
                    self.entry_data.device_info.voice_assistant_feature_flags_compat(
                        self.entry_data.api_version
                    )
                )
                if feature_flags & VoiceAssistantFeature.SPEAKER:
                    # Device has no media player: stream raw TTS audio
                    # to its speaker in the background.
                    media_id = tts_output["media_id"]
                    self._tts_streaming_task = (
                        self.config_entry.async_create_background_task(
                            self.hass,
                            self._stream_tts_audio(media_id),
                            "esphome_voice_assistant_tts",
                        )
                    )
        elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_WAKE_WORD_END:
            assert event.data is not None
            if not event.data["wake_word_output"]:
                # Report a wake word timeout as an error event.
                event_type = VoiceAssistantEventType.VOICE_ASSISTANT_ERROR
                data_to_send = {
                    "code": "no_wake_word",
                    "message": "No wake word detected",
                }
        elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_ERROR:
            assert event.data is not None
            data_to_send = {
                "code": event.data["code"],
                "message": event.data["message"],
            }
        elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_RUN_END:
            if self._tts_streaming_task is None:
                # No TTS
                self.entry_data.async_set_assist_pipeline_state(False)

        self.cli.send_voice_assistant_event(event_type, data_to_send)

    async def async_announce(
        self, announcement: assist_satellite.AssistSatelliteAnnouncement
    ) -> None:
        """Announce media on the satellite.

        Should block until the announcement is done playing.
        """
        _LOGGER.debug(
            "Waiting for announcement to finish (message=%s, media_id=%s)",
            announcement.message,
            announcement.media_id,
        )
        media_id = announcement.media_id
        if announcement.media_id_source != "tts":
            # Route non-TTS media through the ffmpeg proxy so it is
            # transcoded into a format the device's media player accepts.
            format_to_use: MediaPlayerSupportedFormat | None = None
            for supported_format in chain(
                *self.entry_data.media_player_formats.values()
            ):
                if supported_format.purpose == MediaPlayerFormatPurpose.ANNOUNCEMENT:
                    format_to_use = supported_format
                    break

            if format_to_use is not None:
                assert (self.registry_entry is not None) and (
                    self.registry_entry.device_id is not None
                )
                proxy_url = async_create_proxy_url(
                    self.hass,
                    self.registry_entry.device_id,
                    media_id,
                    media_format=format_to_use.format,
                    rate=format_to_use.sample_rate or None,
                    channels=format_to_use.num_channels or None,
                    width=format_to_use.sample_bytes or None,
                )
                media_id = async_process_play_media_url(self.hass, proxy_url)

        await self.cli.send_voice_assistant_announcement_await_response(
            media_id, _ANNOUNCEMENT_TIMEOUT_SEC, announcement.message
        )

    async def handle_pipeline_start(
        self,
        conversation_id: str,
        flags: int,
        audio_settings: VoiceAssistantAudioSettings,
        wake_word_phrase: str | None,
    ) -> int | None:
        """Handle pipeline run request.

        Returns the UDP port to stream audio to, or 0 when audio is
        carried over the TCP API (or no audio output is needed).
        """
        # Clear audio queue of any stale chunks from a previous run.
        while not self._audio_queue.empty():
            await self._audio_queue.get()

        if self._tts_streaming_task is not None:
            # Cancel current TTS response
            self._tts_streaming_task.cancel()
            self._tts_streaming_task = None

        # API or UDP output audio
        port: int = 0
        assert self.entry_data.device_info is not None
        feature_flags = (
            self.entry_data.device_info.voice_assistant_feature_flags_compat(
                self.entry_data.api_version
            )
        )
        if (feature_flags & VoiceAssistantFeature.SPEAKER) and not (
            feature_flags & VoiceAssistantFeature.API_AUDIO
        ):
            port = await self._start_udp_server()
            _LOGGER.debug("Started UDP server on port %s", port)

        # Device triggered pipeline (wake word, etc.)
        if flags & VoiceAssistantCommandFlag.USE_WAKE_WORD:
            start_stage = PipelineStage.WAKE_WORD
        else:
            start_stage = PipelineStage.STT

        end_stage = PipelineStage.TTS

        if feature_flags & VoiceAssistantFeature.SPEAKER:
            # Stream WAV audio (16 kHz, 16-bit mono, matching
            # _stream_tts_audio's defaults).
            self._attr_tts_options = {
                tts.ATTR_PREFERRED_FORMAT: "wav",
                tts.ATTR_PREFERRED_SAMPLE_RATE: 16000,
                tts.ATTR_PREFERRED_SAMPLE_CHANNELS: 1,
                tts.ATTR_PREFERRED_SAMPLE_BYTES: 2,
            }
        else:
            # ANNOUNCEMENT format from media player
            self._update_tts_format()

        # Run the pipeline
        _LOGGER.debug("Running pipeline from %s to %s", start_stage, end_stage)
        self._pipeline_task = self.config_entry.async_create_background_task(
            self.hass,
            self.async_accept_pipeline_from_satellite(
                audio_stream=self._wrap_audio_stream(),
                start_stage=start_stage,
                end_stage=end_stage,
                wake_word_phrase=wake_word_phrase,
            ),
            "esphome_assist_satellite_pipeline",
        )
        self._pipeline_task.add_done_callback(
            lambda _future: self.handle_pipeline_finished()
        )

        return port

    async def handle_audio(self, data: bytes) -> None:
        """Handle incoming audio chunk from API."""
        self._audio_queue.put_nowait(data)

    async def handle_pipeline_stop(self, abort: bool) -> None:
        """Handle request for pipeline to stop."""
        if abort:
            self._abort_pipeline()
        else:
            self._stop_pipeline()

    def handle_pipeline_finished(self) -> None:
        """Handle when pipeline has finished running."""
        self._stop_udp_server()
        _LOGGER.debug("Pipeline finished")

    def handle_timer_event(
        self, event_type: TimerEventType, timer_info: TimerInfo
    ) -> None:
        """Handle timer events by forwarding them to the device."""
        try:
            native_event_type = _TIMER_EVENT_TYPES.from_hass(event_type)
        except KeyError:
            _LOGGER.debug("Received unknown timer event type: %s", event_type)
            return

        self.cli.send_voice_assistant_timer_event(
            native_event_type,
            timer_info.id,
            timer_info.name,
            timer_info.created_seconds,
            timer_info.seconds_left,
            timer_info.is_active,
        )

    async def handle_announcement_finished(
        self, announce_finished: VoiceAssistantAnnounceFinished
    ) -> None:
        """Handle announcement finished message (also sent for TTS)."""
        self.tts_response_finished()

    def _update_tts_format(self) -> None:
        """Update the TTS format from the first media player.

        Uses the first supported format whose purpose is ANNOUNCEMENT;
        optional fields are only applied when the device reports a
        positive value for them.
        """
        for supported_format in chain(*self.entry_data.media_player_formats.values()):
            # Find first announcement format
            if supported_format.purpose == MediaPlayerFormatPurpose.ANNOUNCEMENT:
                self._attr_tts_options = {
                    tts.ATTR_PREFERRED_FORMAT: supported_format.format,
                }

                if supported_format.sample_rate > 0:
                    self._attr_tts_options[tts.ATTR_PREFERRED_SAMPLE_RATE] = (
                        supported_format.sample_rate
                    )
                # Each optional field is gated on its own value (the
                # previous code re-checked sample_rate for all three).
                if supported_format.num_channels > 0:
                    self._attr_tts_options[tts.ATTR_PREFERRED_SAMPLE_CHANNELS] = (
                        supported_format.num_channels
                    )
                if supported_format.sample_bytes > 0:
                    self._attr_tts_options[tts.ATTR_PREFERRED_SAMPLE_BYTES] = (
                        supported_format.sample_bytes
                    )

                break

    async def _stream_tts_audio(
        self,
        media_id: str,
        sample_rate: int = 16000,
        sample_width: int = 2,
        sample_channels: int = 1,
        samples_per_chunk: int = 512,
    ) -> None:
        """Stream TTS audio chunks to device via API or UDP.

        Only 16 kHz, 16-bit mono WAV is accepted; anything else is
        logged and dropped. Brackets the stream with TTS_STREAM_START /
        TTS_STREAM_END events.
        """
        self.cli.send_voice_assistant_event(
            VoiceAssistantEventType.VOICE_ASSISTANT_TTS_STREAM_START, {}
        )

        try:
            if not self._is_running:
                return

            extension, data = await tts.async_get_media_source_audio(
                self.hass,
                media_id,
            )
            if extension != "wav":
                _LOGGER.error("Only WAV audio can be streamed, got %s", extension)
                return

            with io.BytesIO(data) as wav_io, wave.open(wav_io, "rb") as wav_file:
                if (
                    (wav_file.getframerate() != sample_rate)
                    or (wav_file.getsampwidth() != sample_width)
                    or (wav_file.getnchannels() != sample_channels)
                ):
                    _LOGGER.error("Can only stream 16Khz 16-bit mono WAV")
                    return

                _LOGGER.debug("Streaming %s audio samples", wav_file.getnframes())

                while self._is_running:
                    chunk = wav_file.readframes(samples_per_chunk)
                    if not chunk:
                        break

                    if self._udp_server is not None:
                        self._udp_server.send_audio_bytes(chunk)
                    else:
                        self.cli.send_voice_assistant_audio(chunk)

                    # Wait for 90% of the duration of the audio that was
                    # sent for it to be played. This will overrun the
                    # device's buffer for very long audio, so using a media
                    # player is preferred.
                    samples_in_chunk = len(chunk) // (sample_width * sample_channels)
                    seconds_in_chunk = samples_in_chunk / sample_rate
                    await asyncio.sleep(seconds_in_chunk * 0.9)
        except asyncio.CancelledError:
            return  # Don't trigger state change
        finally:
            self.cli.send_voice_assistant_event(
                VoiceAssistantEventType.VOICE_ASSISTANT_TTS_STREAM_END, {}
            )

        # State change
        self.tts_response_finished()
        self.entry_data.async_set_assist_pipeline_state(False)

    async def _wrap_audio_stream(self) -> AsyncIterable[bytes]:
        """Yield audio chunks from the queue until an empty chunk or None."""
        while True:
            chunk = await self._audio_queue.get()
            if not chunk:
                break

            yield chunk

    def _stop_pipeline(self) -> None:
        """Request pipeline to be stopped by ending the audio stream and continue processing."""
        self._audio_queue.put_nowait(None)
        _LOGGER.debug("Requested pipeline stop")

    def _abort_pipeline(self) -> None:
        """Request pipeline to be aborted (no further processing)."""
        _LOGGER.debug("Requested pipeline abort")
        self._audio_queue.put_nowait(None)
        if self._pipeline_task is not None:
            self._pipeline_task.cancel()

    async def _start_udp_server(self) -> int:
        """Start a UDP server on a random free port and return the port."""
        sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        sock.setblocking(False)
        sock.bind(("", 0))  # random free port

        (
            _transport,
            protocol,
        ) = await asyncio.get_running_loop().create_datagram_endpoint(
            partial(VoiceAssistantUDPServer, self._audio_queue), sock=sock
        )

        assert isinstance(protocol, VoiceAssistantUDPServer)
        self._udp_server = protocol

        # Return port
        return cast(int, sock.getsockname()[1])

    def _stop_udp_server(self) -> None:
        """Stop the UDP server if it's running."""
        if self._udp_server is None:
            return

        try:
            self._udp_server.close()
        finally:
            self._udp_server = None

        _LOGGER.debug("Stopped UDP server")
class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
"""Receive UDP packets and forward them to the audio queue."""
transport: asyncio.DatagramTransport | None = None
remote_addr: tuple[str, int] | None = None
def __init__(
self, audio_queue: asyncio.Queue[bytes | None], *args: Any, **kwargs: Any
) -> None:
"""Initialize protocol."""
super().__init__(*args, **kwargs)
self._audio_queue = audio_queue
def connection_made(self, transport: asyncio.BaseTransport) -> None:
"""Store transport for later use."""
self.transport = cast(asyncio.DatagramTransport, transport)
def datagram_received(self, data: bytes, addr: tuple[str, int]) -> None:
"""Handle incoming UDP packet."""
if self.remote_addr is None:
self.remote_addr = addr
self._audio_queue.put_nowait(data)
def error_received(self, exc: Exception) -> None:
"""Handle when a send or receive operation raises an OSError.
(Other than BlockingIOError or InterruptedError.)
"""
_LOGGER.error("ESPHome Voice Assistant UDP server error received: %s", exc)
# Stop pipeline
self._audio_queue.put_nowait(None)
def close(self) -> None:
"""Close the receiver."""
if self.transport is not None:
self.transport.close()
self.remote_addr = None
def send_audio_bytes(self, data: bytes) -> None:
"""Send bytes to the device via UDP."""
if self.transport is None:
_LOGGER.error("No transport to send audio to")
return
if self.remote_addr is None:
_LOGGER.error("No address to send audio to")
return
self.transport.sendto(data, self.remote_addr)