core/homeassistant/components/assist_satellite/websocket_api.py

206 lines
6.8 KiB
Python

"""Assist satellite Websocket API."""
import asyncio
from dataclasses import asdict, replace
from typing import Any
import voluptuous as vol
from homeassistant.components import websocket_api
from homeassistant.core import HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.entity_component import EntityComponent
from homeassistant.util import uuid as uuid_util
from .connection_test import CONNECTION_TEST_URL_BASE
from .const import (
CONNECTION_TEST_DATA,
DATA_COMPONENT,
DOMAIN,
AssistSatelliteEntityFeature,
)
from .entity import AssistSatelliteEntity
# Maximum time, in seconds, that websocket_test_connection waits (via
# asyncio.timeout) for the connection test event before reporting "timeout".
CONNECTION_TEST_TIMEOUT = 30
@callback
def async_register_websocket_api(hass: HomeAssistant) -> None:
    """Register all assist_satellite websocket command handlers.

    Handlers are registered in the order listed below.
    """
    handlers = (
        websocket_intercept_wake_word,
        websocket_get_configuration,
        websocket_set_wake_words,
        websocket_test_connection,
    )
    for handler in handlers:
        websocket_api.async_register_command(hass, handler)
@websocket_api.websocket_command(
    {
        vol.Required("type"): "assist_satellite/intercept_wake_word",
        vol.Required("entity_id"): cv.entity_domain(DOMAIN),
    }
)
@websocket_api.require_admin
@websocket_api.async_response
async def websocket_intercept_wake_word(
    hass: HomeAssistant,
    connection: websocket_api.connection.ActiveConnection,
    msg: dict[str, Any],
) -> None:
    """Subscribe the caller to the next wake word detected by a satellite.

    Replies with a "not found" error when the entity id does not resolve to a
    satellite. Otherwise a task is started that waits for the satellite's next
    wake word, the command is acknowledged with a result message, and the
    phrase is later delivered as an event message carrying
    ``{"wake_word_phrase": ...}``.
    """
    entity = hass.data[DATA_COMPONENT].get_entity(msg["entity_id"])
    request_id = msg["id"]

    if entity is None:
        connection.send_error(
            request_id, websocket_api.ERR_NOT_FOUND, "Entity not found"
        )
        return

    async def _forward_wake_word() -> None:
        """Wait for the wake word and push it to the websocket client."""
        try:
            phrase = await entity.async_intercept_wake_word()
            event = websocket_api.event_message(
                request_id,
                {"wake_word_phrase": phrase},
            )
            connection.send_message(event)
        except HomeAssistantError as err:
            # Report the failure on the same message id as the subscription.
            connection.send_error(request_id, "home_assistant_error", str(err))

    intercept_task = hass.async_create_task(
        _forward_wake_word(), "intercept_wake_word"
    )
    # Store the task's cancel method as this message's subscription handler;
    # presumably invoked on unsubscribe/disconnect (verify in websocket_api).
    connection.subscriptions[request_id] = intercept_task.cancel
    connection.send_message(websocket_api.result_message(request_id))
@callback
@websocket_api.websocket_command(
    {
        vol.Required("type"): "assist_satellite/get_configuration",
        vol.Required("entity_id"): cv.entity_domain(DOMAIN),
    }
)
def websocket_get_configuration(
    hass: HomeAssistant,
    connection: websocket_api.connection.ActiveConnection,
    msg: dict[str, Any],
) -> None:
    """Return the satellite's current configuration to the caller.

    The result is the satellite configuration dataclass converted to a dict,
    extended with the ``pipeline_entity_id`` and ``vad_entity_id`` keys taken
    from the entity. A "not found" error is sent when the entity id does not
    resolve to a satellite.
    """
    entity = hass.data[DATA_COMPONENT].get_entity(msg["entity_id"])
    if entity is None:
        connection.send_error(
            msg["id"], websocket_api.ERR_NOT_FOUND, "Entity not found"
        )
        return

    payload = {
        **asdict(entity.async_get_configuration()),
        "pipeline_entity_id": entity.pipeline_entity_id,
        # The websocket key is "vad_entity_id"; the entity attribute is
        # named vad_sensitivity_entity_id.
        "vad_entity_id": entity.vad_sensitivity_entity_id,
    }
    connection.send_result(msg["id"], payload)
@websocket_api.websocket_command(
    {
        vol.Required("type"): "assist_satellite/set_wake_words",
        vol.Required("entity_id"): cv.entity_domain(DOMAIN),
        vol.Required("wake_word_ids"): [str],
    }
)
@websocket_api.require_admin
@websocket_api.async_response
async def websocket_set_wake_words(
    hass: HomeAssistant,
    connection: websocket_api.connection.ActiveConnection,
    msg: dict[str, Any],
) -> None:
    """Replace the satellite's active wake words with the requested ids.

    Error replies:
    - "not found" when the entity id does not resolve to a satellite;
    - "not supported" when more ids are requested than the configuration's
      ``max_active_wake_words`` allows;
    - "not supported" naming the first requested id that is not among the
      configuration's ``available_wake_words``.

    On success the configuration is stored with ``active_wake_words`` replaced
    and an empty result is sent.
    """
    entity = hass.data[DATA_COMPONENT].get_entity(msg["entity_id"])
    if entity is None:
        connection.send_error(
            msg["id"], websocket_api.ERR_NOT_FOUND, "Entity not found"
        )
        return

    current = entity.async_get_configuration()
    requested_ids = msg["wake_word_ids"]

    # Refuse requests that exceed the number of simultaneously active wake words.
    if len(requested_ids) > current.max_active_wake_words:
        connection.send_error(
            msg["id"],
            websocket_api.ERR_NOT_SUPPORTED,
            f"Maximum number of active wake words is {current.max_active_wake_words}",
        )
        return

    # Find the first requested id (in request order) the satellite doesn't offer.
    # Ids are validated as str by the schema, so None is a safe sentinel.
    known_ids = {wake_word.id for wake_word in current.available_wake_words}
    unsupported_id = next(
        (candidate for candidate in requested_ids if candidate not in known_ids),
        None,
    )
    if unsupported_id is not None:
        connection.send_error(
            msg["id"],
            websocket_api.ERR_NOT_SUPPORTED,
            f"Wake word id is not supported: {unsupported_id}",
        )
        return

    updated = replace(current, active_wake_words=requested_ids)
    await entity.async_set_configuration(updated)
    connection.send_result(msg["id"])
@websocket_api.websocket_command(
    {
        vol.Required("type"): "assist_satellite/test_connection",
        vol.Required("entity_id"): cv.entity_domain(DOMAIN),
    }
)
@websocket_api.async_response
async def websocket_test_connection(
    hass: HomeAssistant,
    connection: websocket_api.connection.ActiveConnection,
    msg: dict[str, Any],
) -> None:
    """Test the connection between the device and Home Assistant.

    Send an announcement to the device with a special media id, then wait up
    to CONNECTION_TEST_TIMEOUT seconds for the matching connection test event
    to be set.

    Replies:
    - "not found" error when the entity id does not resolve to a satellite;
    - "not supported" error when the entity lacks the ANNOUNCE feature;
    - result ``{"status": "success"}`` when the event is set in time;
    - result ``{"status": "timeout"}`` otherwise.
    """
    # Look the component up under DATA_COMPONENT, the same key every other
    # handler in this module uses, instead of the bare DOMAIN string.
    component: EntityComponent[AssistSatelliteEntity] = hass.data[DATA_COMPONENT]
    satellite = component.get_entity(msg["entity_id"])
    if satellite is None:
        connection.send_error(
            msg["id"], websocket_api.ERR_NOT_FOUND, "Entity not found"
        )
        return

    # supported_features may be None; treat that as "no features".
    if not (satellite.supported_features or 0) & AssistSatelliteEntityFeature.ANNOUNCE:
        connection.send_error(
            msg["id"],
            websocket_api.ERR_NOT_SUPPORTED,
            "Entity does not support announce",
        )
        return

    # Announce and wait for event
    connection_test_data = hass.data[CONNECTION_TEST_DATA]
    connection_id = uuid_util.random_uuid_hex()
    connection_test_event = asyncio.Event()
    # Register the event before announcing so it is already present when the
    # device reacts. NOTE(review): presumably the handler behind
    # CONNECTION_TEST_URL_BASE sets this event -- confirm in connection_test.
    connection_test_data[connection_id] = connection_test_event

    # The announcement runs in the background; this handler only waits on the
    # event, not on the announcement finishing.
    hass.async_create_background_task(
        satellite.async_internal_announce(
            media_id=f"{CONNECTION_TEST_URL_BASE}/{connection_id}"
        ),
        f"assist_satellite_connection_test_{msg['entity_id']}",
    )

    try:
        async with asyncio.timeout(CONNECTION_TEST_TIMEOUT):
            await connection_test_event.wait()
            connection.send_result(msg["id"], {"status": "success"})
    except TimeoutError:
        connection.send_result(msg["id"], {"status": "timeout"})
    finally:
        # Always drop the per-test event so entries do not accumulate.
        connection_test_data.pop(connection_id, None)