mirror of https://github.com/home-assistant/core
382 lines
12 KiB
Python
382 lines
12 KiB
Python
"""Test fixtures for voice assistant."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from collections.abc import AsyncIterable, Generator
|
|
from pathlib import Path
|
|
from typing import Any
|
|
from unittest.mock import AsyncMock
|
|
|
|
import pytest
|
|
|
|
from homeassistant.components import stt, tts, wake_word
|
|
from homeassistant.components.assist_pipeline import DOMAIN, select as assist_select
|
|
from homeassistant.components.assist_pipeline.const import (
|
|
BYTES_PER_CHUNK,
|
|
SAMPLE_CHANNELS,
|
|
SAMPLE_RATE,
|
|
SAMPLE_WIDTH,
|
|
)
|
|
from homeassistant.components.assist_pipeline.pipeline import (
|
|
PipelineData,
|
|
PipelineStorageCollection,
|
|
)
|
|
from homeassistant.config_entries import ConfigEntry, ConfigFlow
|
|
from homeassistant.const import Platform
|
|
from homeassistant.core import HomeAssistant
|
|
from homeassistant.helpers import device_registry as dr
|
|
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
|
from homeassistant.setup import async_setup_component
|
|
|
|
from tests.common import (
|
|
MockConfigEntry,
|
|
MockModule,
|
|
MockPlatform,
|
|
mock_config_flow,
|
|
mock_integration,
|
|
mock_platform,
|
|
)
|
|
from tests.components.stt.common import MockSTTProvider, MockSTTProviderEntity
|
|
from tests.components.tts.common import MockTTSProvider
|
|
|
|
_TRANSCRIPT = "test transcript"
|
|
|
|
BYTES_ONE_SECOND = SAMPLE_RATE * SAMPLE_WIDTH * SAMPLE_CHANNELS
|
|
|
|
|
|
@pytest.fixture(autouse=True)
|
|
def mock_tts_cache_dir_autouse(mock_tts_cache_dir: Path) -> None:
|
|
"""Mock the TTS cache dir with empty dir."""
|
|
|
|
|
|
class MockTTSPlatform(MockPlatform):
|
|
"""A mock TTS platform."""
|
|
|
|
PLATFORM_SCHEMA = tts.PLATFORM_SCHEMA
|
|
|
|
def __init__(self, *, async_get_engine, **kwargs: Any) -> None:
|
|
"""Initialize the tts platform."""
|
|
super().__init__(**kwargs)
|
|
self.async_get_engine = async_get_engine
|
|
|
|
|
|
@pytest.fixture
|
|
async def mock_tts_provider() -> MockTTSProvider:
|
|
"""Mock TTS provider."""
|
|
provider = MockTTSProvider("en")
|
|
provider._supported_languages = ["en-US"]
|
|
return provider
|
|
|
|
|
|
@pytest.fixture
|
|
async def mock_stt_provider() -> MockSTTProvider:
|
|
"""Mock STT provider."""
|
|
return MockSTTProvider(supported_languages=["en-US"], text=_TRANSCRIPT)
|
|
|
|
|
|
@pytest.fixture
|
|
def mock_stt_provider_entity() -> MockSTTProviderEntity:
|
|
"""Test provider entity fixture."""
|
|
entity = MockSTTProviderEntity(supported_languages=["en-US"], text=_TRANSCRIPT)
|
|
entity._attr_name = "Mock STT"
|
|
return entity
|
|
|
|
|
|
class MockSttPlatform(MockPlatform):
|
|
"""Provide a fake STT platform."""
|
|
|
|
def __init__(self, *, async_get_engine, **kwargs: Any) -> None:
|
|
"""Initialize the stt platform."""
|
|
super().__init__(**kwargs)
|
|
self.async_get_engine = async_get_engine
|
|
|
|
|
|
class MockWakeWordEntity(wake_word.WakeWordDetectionEntity):
|
|
"""Mock wake word entity."""
|
|
|
|
fail_process_audio = False
|
|
url_path = "wake_word.test"
|
|
_attr_name = "test"
|
|
|
|
alternate_detections = False
|
|
detected_wake_word_index = 0
|
|
|
|
async def get_supported_wake_words(self) -> list[wake_word.WakeWord]:
|
|
"""Return a list of supported wake words."""
|
|
return [
|
|
wake_word.WakeWord(id="test_ww", name="Test Wake Word"),
|
|
wake_word.WakeWord(id="test_ww_2", name="Test Wake Word 2"),
|
|
]
|
|
|
|
async def _async_process_audio_stream(
|
|
self, stream: AsyncIterable[tuple[bytes, int]], wake_word_id: str | None
|
|
) -> wake_word.DetectionResult | None:
|
|
"""Try to detect wake word(s) in an audio stream with timestamps."""
|
|
wake_words = await self.get_supported_wake_words()
|
|
|
|
if self.alternate_detections:
|
|
detected_id = wake_words[self.detected_wake_word_index].id
|
|
detected_name = wake_words[self.detected_wake_word_index].name
|
|
self.detected_wake_word_index = (self.detected_wake_word_index + 1) % len(
|
|
wake_words
|
|
)
|
|
else:
|
|
detected_id = wake_words[0].id
|
|
detected_name = wake_words[0].name
|
|
|
|
async for chunk, timestamp in stream:
|
|
if chunk.startswith(b"wake word"):
|
|
return wake_word.DetectionResult(
|
|
wake_word_id=detected_id,
|
|
wake_word_phrase=detected_name,
|
|
timestamp=timestamp,
|
|
queued_audio=[(b"queued audio", 0)],
|
|
)
|
|
|
|
# Not detected
|
|
return None
|
|
|
|
|
|
class MockWakeWordEntity2(wake_word.WakeWordDetectionEntity):
|
|
"""Second mock wake word entity to test cooldown."""
|
|
|
|
fail_process_audio = False
|
|
url_path = "wake_word.test2"
|
|
_attr_name = "test2"
|
|
|
|
async def get_supported_wake_words(self) -> list[wake_word.WakeWord]:
|
|
"""Return a list of supported wake words."""
|
|
return [wake_word.WakeWord(id="test_ww", name="Test Wake Word")]
|
|
|
|
async def _async_process_audio_stream(
|
|
self, stream: AsyncIterable[tuple[bytes, int]], wake_word_id: str | None
|
|
) -> wake_word.DetectionResult | None:
|
|
"""Try to detect wake word(s) in an audio stream with timestamps."""
|
|
wake_words = await self.get_supported_wake_words()
|
|
|
|
async for chunk, timestamp in stream:
|
|
if chunk.startswith(b"wake word"):
|
|
return wake_word.DetectionResult(
|
|
wake_word_id=wake_words[0].id,
|
|
wake_word_phrase=wake_words[0].name,
|
|
timestamp=timestamp,
|
|
queued_audio=[(b"queued audio", 0)],
|
|
)
|
|
|
|
# Not detected
|
|
return None
|
|
|
|
|
|
@pytest.fixture
|
|
async def mock_wake_word_provider_entity() -> MockWakeWordEntity:
|
|
"""Mock wake word provider."""
|
|
return MockWakeWordEntity()
|
|
|
|
|
|
@pytest.fixture
|
|
async def mock_wake_word_provider_entity2() -> MockWakeWordEntity2:
|
|
"""Mock wake word provider."""
|
|
return MockWakeWordEntity2()
|
|
|
|
|
|
class MockFlow(ConfigFlow):
|
|
"""Test flow."""
|
|
|
|
|
|
@pytest.fixture
|
|
def config_flow_fixture(hass: HomeAssistant) -> Generator[None]:
|
|
"""Mock config flow."""
|
|
mock_platform(hass, "test.config_flow")
|
|
|
|
with mock_config_flow("test", MockFlow):
|
|
yield
|
|
|
|
|
|
@pytest.fixture
|
|
async def init_supporting_components(
|
|
hass: HomeAssistant,
|
|
mock_stt_provider: MockSTTProvider,
|
|
mock_stt_provider_entity: MockSTTProviderEntity,
|
|
mock_tts_provider: MockTTSProvider,
|
|
mock_wake_word_provider_entity: MockWakeWordEntity,
|
|
mock_wake_word_provider_entity2: MockWakeWordEntity2,
|
|
config_flow_fixture,
|
|
):
|
|
"""Initialize relevant components with empty configs."""
|
|
|
|
async def async_setup_entry_init(
|
|
hass: HomeAssistant, config_entry: ConfigEntry
|
|
) -> bool:
|
|
"""Set up test config entry."""
|
|
await hass.config_entries.async_forward_entry_setups(
|
|
config_entry, [Platform.STT, Platform.WAKE_WORD]
|
|
)
|
|
return True
|
|
|
|
async def async_unload_entry_init(
|
|
hass: HomeAssistant, config_entry: ConfigEntry
|
|
) -> bool:
|
|
"""Unload up test config entry."""
|
|
await hass.config_entries.async_unload_platforms(
|
|
config_entry, [Platform.STT, Platform.WAKE_WORD]
|
|
)
|
|
return True
|
|
|
|
async def async_setup_entry_stt_platform(
|
|
hass: HomeAssistant,
|
|
config_entry: ConfigEntry,
|
|
async_add_entities: AddEntitiesCallback,
|
|
) -> None:
|
|
"""Set up test stt platform via config entry."""
|
|
async_add_entities([mock_stt_provider_entity])
|
|
|
|
async def async_setup_entry_wake_word_platform(
|
|
hass: HomeAssistant,
|
|
config_entry: ConfigEntry,
|
|
async_add_entities: AddEntitiesCallback,
|
|
) -> None:
|
|
"""Set up test wake word platform via config entry."""
|
|
async_add_entities(
|
|
[mock_wake_word_provider_entity, mock_wake_word_provider_entity2]
|
|
)
|
|
|
|
mock_integration(
|
|
hass,
|
|
MockModule(
|
|
"test",
|
|
async_setup_entry=async_setup_entry_init,
|
|
async_unload_entry=async_unload_entry_init,
|
|
),
|
|
)
|
|
mock_platform(
|
|
hass,
|
|
"test.tts",
|
|
MockTTSPlatform(
|
|
async_get_engine=AsyncMock(return_value=mock_tts_provider),
|
|
),
|
|
)
|
|
mock_platform(
|
|
hass,
|
|
"test.stt",
|
|
MockSttPlatform(
|
|
async_get_engine=AsyncMock(return_value=mock_stt_provider),
|
|
async_setup_entry=async_setup_entry_stt_platform,
|
|
),
|
|
)
|
|
mock_platform(
|
|
hass,
|
|
"test.wake_word",
|
|
MockPlatform(
|
|
async_setup_entry=async_setup_entry_wake_word_platform,
|
|
),
|
|
)
|
|
mock_platform(hass, "test.config_flow")
|
|
|
|
assert await async_setup_component(hass, "homeassistant", {})
|
|
assert await async_setup_component(hass, tts.DOMAIN, {"tts": {"platform": "test"}})
|
|
assert await async_setup_component(hass, stt.DOMAIN, {"stt": {"platform": "test"}})
|
|
assert await async_setup_component(hass, "media_source", {})
|
|
|
|
config_entry = MockConfigEntry(domain="test")
|
|
config_entry.add_to_hass(hass)
|
|
assert await hass.config_entries.async_setup(config_entry.entry_id)
|
|
await hass.async_block_till_done()
|
|
|
|
|
|
@pytest.fixture
|
|
async def init_components(hass: HomeAssistant, init_supporting_components):
|
|
"""Initialize relevant components with empty configs."""
|
|
|
|
assert await async_setup_component(hass, "assist_pipeline", {})
|
|
|
|
|
|
@pytest.fixture
|
|
async def assist_device(
|
|
hass: HomeAssistant, device_registry: dr.DeviceRegistry, init_components
|
|
) -> dr.DeviceEntry:
|
|
"""Create an assist device."""
|
|
config_entry = MockConfigEntry(domain="test_assist_device")
|
|
config_entry.add_to_hass(hass)
|
|
|
|
device = device_registry.async_get_or_create(
|
|
name="Test Device",
|
|
config_entry_id=config_entry.entry_id,
|
|
identifiers={("test_assist_device", "test")},
|
|
)
|
|
|
|
async def async_setup_entry_init(
|
|
hass: HomeAssistant, config_entry: ConfigEntry
|
|
) -> bool:
|
|
"""Set up test config entry."""
|
|
await hass.config_entries.async_forward_entry_setups(
|
|
config_entry, [Platform.SELECT]
|
|
)
|
|
return True
|
|
|
|
async def async_unload_entry_init(
|
|
hass: HomeAssistant, config_entry: ConfigEntry
|
|
) -> bool:
|
|
"""Unload up test config entry."""
|
|
await hass.config_entries.async_unload_platforms(
|
|
config_entry, [Platform.SELECT]
|
|
)
|
|
return True
|
|
|
|
async def async_setup_entry_select_platform(
|
|
hass: HomeAssistant,
|
|
config_entry: ConfigEntry,
|
|
async_add_entities: AddEntitiesCallback,
|
|
) -> None:
|
|
"""Set up test select platform via config entry."""
|
|
entities = [
|
|
assist_select.AssistPipelineSelect(
|
|
hass, "test_assist_device", "test-prefix"
|
|
),
|
|
assist_select.VadSensitivitySelect(hass, "test-prefix"),
|
|
]
|
|
for ent in entities:
|
|
ent._attr_device_info = dr.DeviceInfo(
|
|
identifiers={("test_assist_device", "test")},
|
|
)
|
|
async_add_entities(entities)
|
|
|
|
mock_integration(
|
|
hass,
|
|
MockModule(
|
|
"test_assist_device",
|
|
async_setup_entry=async_setup_entry_init,
|
|
async_unload_entry=async_unload_entry_init,
|
|
),
|
|
)
|
|
mock_platform(
|
|
hass,
|
|
"test_assist_device.select",
|
|
MockPlatform(
|
|
async_setup_entry=async_setup_entry_select_platform,
|
|
),
|
|
)
|
|
mock_platform(hass, "test_assist_device.config_flow")
|
|
|
|
with mock_config_flow("test_assist_device", ConfigFlow):
|
|
assert await hass.config_entries.async_setup(config_entry.entry_id)
|
|
await hass.async_block_till_done()
|
|
|
|
return device
|
|
|
|
|
|
@pytest.fixture
|
|
def pipeline_data(hass: HomeAssistant, init_components) -> PipelineData:
|
|
"""Return pipeline data."""
|
|
return hass.data[DOMAIN]
|
|
|
|
|
|
@pytest.fixture
|
|
def pipeline_storage(pipeline_data) -> PipelineStorageCollection:
|
|
"""Return pipeline storage collection."""
|
|
return pipeline_data.pipeline_store
|
|
|
|
|
|
def make_10ms_chunk(header: bytes) -> bytes:
|
|
"""Return 10ms of zeros with the given header."""
|
|
return header + bytes(BYTES_PER_CHUNK - len(header))
|